DIPLOMOVA_PRACA/api.py
2023-11-10 22:53:03 +00:00

48 lines
1.7 KiB
Python

import asyncio
import json
import os
import random
import warnings
from functools import reduce

import torch
import torch.nn.functional as F
import uvicorn
from fastapi import FastAPI
from pydantic import BaseModel
from torch.utils.data import DataLoader
from transformers import MT5Tokenizer,AutoTokenizer, AutoModel ,T5ForConditionalGeneration

#from ece import compute_ECE
warnings.filterwarnings("ignore")

# Device the model runs on. Can be overridden with the QA_DEVICE environment
# variable (e.g. "cuda"); defaults to CPU as before.
DEVICE = os.environ.get("QA_DEVICE", "cpu")

# Directories holding the fine-tuned checkpoint and its tokenizer. They can be
# overridden with environment variables so the service is not tied to a single
# machine; the defaults are the original hard-coded locations.
model_dir = os.environ.get("QA_MODEL_DIR", "C:/Users/david/Desktop/T5_JUPYTER/qa_model")
tokenizer_dir = os.environ.get("QA_TOKENIZER_DIR", "C:/Users/david/Desktop/T5_JUPYTER/qa_tokenizer")

MODEL = T5ForConditionalGeneration.from_pretrained(model_dir, from_tf=False, return_dict=True).to(DEVICE)
print("Model succesfully loaded!")
TOKENIZER = AutoTokenizer.from_pretrained(tokenizer_dir, use_fast=True)
print("Tokenizer succesfully loaded!")

# Token budget for both the encoded (question, context) input and the
# generated answer.
Q_LEN = 512

# Register the "<sep>" token. add_tokens() returns how many tokens were newly
# added. If the vocabulary grew past the model's embedding table, enlarge the
# table; otherwise the new id would raise an out-of-range error in the
# embedding lookup.
if TOKENIZER.add_tokens('<sep>') and len(TOKENIZER) > MODEL.get_input_embeddings().num_embeddings:
    MODEL.resize_token_embeddings(len(TOKENIZER))
print('model loaded')

app = FastAPI()
# Request body schema for the /predict endpoint.
class InputData(BaseModel):
    """JSON body of a ``POST /predict`` request: a passage and a question about it."""

    # Passage text; predict() passes it to the tokenizer as the second
    # sequence of the (question, context) pair.
    context: str
    # Question text; predict() passes it to the tokenizer as the first sequence.
    question: str
@app.post("/predict")
async def predict(input_data: InputData):
    """Generate an answer to ``input_data.question`` given ``input_data.context``.

    The question and context are encoded as a tokenizer text pair, truncated to
    ``Q_LEN`` tokens. An answer of at most ``Q_LEN`` tokens is generated with
    the model's default decoding settings.

    Returns:
        ``{"prediction": <decoded answer text>}`` with special tokens removed.
    """
    # Encode straight to tensors. No max_length padding is needed for a single
    # example: the attention mask already covers the real length, and padding
    # to 512 only adds encoder work.
    inputs = TOKENIZER(
        input_data.question,
        input_data.context,
        max_length=Q_LEN,
        truncation=True,
        add_special_tokens=True,
        return_tensors="pt",
    ).to(DEVICE)

    # generate() is CPU/GPU-bound and can take seconds. Running it directly in
    # this coroutine would block the event loop and stall every other request,
    # so it runs in the default thread-pool executor instead.
    # Per-step scores are not requested: they were never used and can take
    # hundreds of MB for long generations over a large vocabulary.
    loop = asyncio.get_running_loop()
    outputs = await loop.run_in_executor(
        None,
        lambda: MODEL.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            return_dict_in_generate=True,
            max_length=Q_LEN,
        ),
    )

    # decode() accepts a tensor directly; avoiding .numpy() also keeps this
    # working when DEVICE is a GPU.
    predicted_text = TOKENIZER.decode(outputs.sequences[0], skip_special_tokens=True)
    return {'prediction': predicted_text}
if __name__ == "__main__":
    # Serve the API with uvicorn on the loopback interface only.
    bind_host, bind_port = "127.0.0.1", 8090
    uvicorn.run(app, host=bind_host, port=bind_port)