zkt26/sk1/summarizer/main.py
2026-05-13 21:49:47 +02:00

38 lines
855 B
Python

from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
# Single source of truth for the checkpoint id, so the tokenizer and the
# model are always loaded from the same pretrained checkpoint.
MODEL_NAME = "facebook/bart-large-cnn"

# ASGI application object; the route decorators in this module attach to it.
app = FastAPI()

# Both objects are created once, when the module is imported, and are then
# shared by every call to the /summarize handler.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
class TextRequest(BaseModel):
    """Request body of ``POST /summarize``: the text to be summarized."""

    # Raw input text; the only field of the payload, required (no default).
    text: str
@app.get("/health")
async def health():
    """Health-check endpoint: always answers with a fixed ``ok`` status."""
    # The payload is constant; it does not touch the tokenizer or the model.
    payload = {"status": "ok"}
    return payload
@app.post("/summarize")
def summarize(req: TextRequest):
    """Summarize ``req.text`` with the module-level tokenizer and model.

    The text is tokenized (truncated to at most 1024 tokens), a summary is
    generated with 4-beam search constrained to 40-150 tokens, and the first
    returned sequence is decoded back to a string.

    Blank or whitespace-only text is answered with an empty summary without
    running the model.

    Args:
        req: Request body carrying the text to summarize.

    Returns:
        A dict of the form ``{"summary": <str>}``.
    """
    # Nothing to summarize: with ``min_length=40`` generation would still be
    # forced to emit at least 40 tokens, i.e. text unrelated to any input.
    # Answer directly with the same response shape instead.
    if not req.text.strip():
        return {"summary": ""}

    # ``truncation=True`` with ``max_length=1024`` cuts over-long inputs to
    # the tokenizer-side limit; ``return_tensors="pt"`` yields PyTorch
    # tensors that can be unpacked straight into ``generate``.
    inputs = tokenizer(
        req.text,
        max_length=1024,
        truncation=True,
        return_tensors="pt",
    )

    # Beam search over 4 beams; output length bounded to [40, 150] tokens.
    summary_ids = model.generate(
        **inputs,
        max_length=150,
        min_length=40,
        num_beams=4,
        length_penalty=2.0,
    )

    # One input text was encoded, so only the first sequence is decoded;
    # special tokens are dropped from the decoded string.
    summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
    return {"summary": summary}