DIPLOMOVA_PRACA/api.py

48 lines
1.7 KiB
Python
Raw Normal View History

2023-11-10 22:53:03 +00:00
import json
import os
import random
import warnings
from functools import reduce

import torch
import torch.nn.functional as F
import uvicorn
from fastapi import FastAPI
from pydantic import BaseModel
from torch.utils.data import DataLoader
from transformers import MT5Tokenizer, AutoTokenizer, AutoModel, T5ForConditionalGeneration
#from ece import compute_ECE
# One-time service setup. Everything below runs at import time, so the model
# and tokenizer are loaded once and shared by every request.

# Suppress all Python warnings for the whole process.
warnings.filterwarnings("ignore")

# Device the model is moved to; request tensors are moved to the same device.
DEVICE = 'cpu'

# Directories holding the saved model and tokenizer. The QA_MODEL_DIR and
# QA_TOKENIZER_DIR environment variables override the original hard-coded
# locations, so the service can run on another machine without a code change;
# when they are unset (or empty) the original paths are used unchanged.
model_dir = os.environ.get("QA_MODEL_DIR") or "C:/Users/david/Desktop/T5_JUPYTER/qa_model"
tokenizer_dir = os.environ.get("QA_TOKENIZER_DIR") or "C:/Users/david/Desktop/T5_JUPYTER/qa_tokenizer"

MODEL = T5ForConditionalGeneration.from_pretrained(model_dir, from_tf=False, return_dict=True).to(DEVICE)
print("Model succesfully loaded!")
TOKENIZER = AutoTokenizer.from_pretrained(tokenizer_dir, use_fast=True)
print("Tokenizer succesfully loaded!")

# Token-length limit for the encoded question + context input.
Q_LEN = 512

# Register '<sep>' in the tokenizer vocabulary.
# NOTE(review): the model's embedding matrix is not resized here; presumably
# the saved model/tokenizer already account for this token -- confirm.
TOKENIZER.add_tokens('<sep>')
print('model loaded')

# ASGI application object; routes are attached to it below.
app = FastAPI()
# BASE MODEL
# Request body of the /predict endpoint: two required string fields.
# (Comments are used instead of a class docstring so the generated schema
# for this model stays exactly as it was.)
class InputData(BaseModel):
    # Context text handed to the tokenizer together with the question.
    context: str
    # Question text to be answered.
    question: str
@app.post("/predict")
async def predict(input_data: InputData):
    """Generate an answer for a question/context pair with the loaded model.

    Args:
        input_data: Request body providing ``question`` and ``context``.

    Returns:
        A dict of the form ``{'prediction': <decoded answer text>}``.
    """
    # Encode question and context as one sequence pair. return_tensors="pt"
    # yields int64 tensors that already carry the batch dimension, i.e. the
    # same (1, seq_len) inputs the manual tensor construction produced.
    encoded = TOKENIZER(
        input_data.question,
        input_data.context,
        max_length=Q_LEN,
        padding="max_length",
        truncation=True,
        add_special_tokens=True,
        return_tensors="pt",
    )
    input_ids = encoded["input_ids"].to(DEVICE)
    attention_mask = encoded["attention_mask"].to(DEVICE)

    # NOTE(review): generate() is a synchronous call inside an async handler,
    # so the event loop is occupied until it returns. Consider a plain `def`
    # handler or a thread pool if concurrent requests matter.
    # Per-step scores are no longer requested: nothing reads them, and
    # retaining them only costs memory for every generated token.
    outputs = MODEL.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        return_dict_in_generate=True,
        max_length=512,
    )

    # .tolist() works for tensors on any device, unlike .numpy(), which
    # raises for non-CPU tensors.
    predicted_ids = outputs.sequences[0].tolist()
    predicted_text = TOKENIZER.decode(predicted_ids, skip_special_tokens=True)
    return {'prediction': predicted_text}
# Script entry point: when this file is executed directly, serve the app
# with uvicorn on the loopback interface.
if __name__ == "__main__":
    uvicorn.run(
        app,
        host="127.0.0.1",
        port=8090,
    )