38 lines
855 B
Python
38 lines
855 B
Python
import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
|
|
|
|
app = FastAPI()

# Single source of truth for the checkpoint: the tokenizer and the model must
# come from the same pretrained name, so it is spelled out exactly once.
MODEL_NAME = "facebook/bart-large-cnn"

# Loaded once, at import time, and shared by every request handled below.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
|
|
|
|
|
|
class TextRequest(BaseModel):
    """JSON request body for ``POST /summarize``."""

    # The document to summarize; sent to the tokenizer as-is.
    text: str
|
|
|
|
|
|
@app.get("/health")
async def health():
    """Return a fixed status payload; no model or dependency checks are made."""
    payload = {"status": "ok"}
    return payload
|
|
|
|
|
|
@app.post("/summarize")
def summarize(req: TextRequest):
    """Summarize ``req.text`` with the module-level seq2seq model.

    Kept as a plain ``def`` (not ``async def``): FastAPI runs synchronous path
    operations in a worker threadpool, so the blocking tokenizer / generate
    calls do not stall the event loop.

    Returns:
        ``{"summary": <str>}`` -- the decoded top beam, special tokens removed.

    Raises:
        HTTPException: 422 when ``req.text`` is empty or whitespace-only.
            Without this guard, ``min_length=40`` forces the model to emit
            at least 40 tokens for an input with no content, i.e. a
            fabricated "summary".
    """
    if not req.text.strip():
        raise HTTPException(status_code=422, detail="text must not be empty")

    # Inputs longer than 1024 tokens are truncated rather than rejected
    # (1024 presumably matches the checkpoint's max input length -- confirm
    # against model.config.max_position_embeddings).
    inputs = tokenizer(
        req.text,
        max_length=1024,
        truncation=True,
        return_tensors="pt",
    )

    # Pure inference: inference_mode disables autograd tracking entirely.
    with torch.inference_mode():
        summary_ids = model.generate(
            **inputs,
            max_length=150,
            # NOTE(review): min_length also applies to very short inputs,
            # forcing output longer than the source itself; consider clamping.
            min_length=40,
            num_beams=4,
            # > 0 biases beam search toward longer candidates.
            length_penalty=2.0,
        )

    summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
    return {"summary": summary}
|