import torch
import uvicorn
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoTokenizer, T5ForConditionalGeneration
import warnings

warnings.filterwarnings("ignore")
DEVICE = "cpu"

model_dir = "C:/Users/david/Desktop/T5_JUPYTER/qa_model"
tokenizer_dir = "C:/Users/david/Desktop/T5_JUPYTER/qa_tokenizer"

MODEL = T5ForConditionalGeneration.from_pretrained(model_dir, return_dict=True).to(DEVICE)
print("Model successfully loaded!")

TOKENIZER = AutoTokenizer.from_pretrained(tokenizer_dir, use_fast=True)
print("Tokenizer successfully loaded!")

Q_LEN = 512  # max question+context length, in tokens

# Register the custom <sep> token used by the fine-tuned model and resize the
# embedding matrix so the new token id stays in range (no-op if already present).
TOKENIZER.add_tokens("<sep>")
MODEL.resize_token_embeddings(len(TOKENIZER))

app = FastAPI()
# Request schema (pydantic BaseModel)
class InputData(BaseModel):
    context: str
    question: str
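# Example JSON body this schema accepts (illustrative values only):
#   {"context": "The Eiffel Tower is in Paris.", "question": "Where is the Eiffel Tower?"}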

@app.post("/predict")
async def predict(input_data: InputData):
    # Tokenize question + context as a pair, padded/truncated to Q_LEN
    inputs = TOKENIZER(
        input_data.question,
        input_data.context,
        max_length=Q_LEN,
        padding="max_length",
        truncation=True,
        add_special_tokens=True,
    )
    input_ids = torch.tensor(inputs["input_ids"], dtype=torch.long).unsqueeze(0).to(DEVICE)
    attention_mask = torch.tensor(inputs["attention_mask"], dtype=torch.long).unsqueeze(0).to(DEVICE)

    outputs = MODEL.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        return_dict_in_generate=True,
        output_scores=True,
        max_length=512,
    )
    # .cpu() is a no-op while DEVICE is "cpu", but keeps this correct on GPU
    predicted_ids = outputs.sequences.cpu().numpy()
    predicted_text = TOKENIZER.decode(predicted_ids[0], skip_special_tokens=True)

    return {"prediction": predicted_text}

if __name__ == "__main__":
    uvicorn.run(app, host="127.0.0.1", port=8090)
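
# Example client call (a sketch; assumes the server is running on the host/port
# above and that the `requests` package is installed -- it is not used elsewhere
# in this file):
#
#   import requests
#   resp = requests.post(
#       "http://127.0.0.1:8090/predict",
#       json={"context": "The Eiffel Tower is in Paris.",
#             "question": "Where is the Eiffel Tower?"},
#   )
#   print(resp.json())  # e.g. {"prediction": "Paris"}
#
# Equivalent curl:
#   curl -X POST http://127.0.0.1:8090/predict \
#        -H "Content-Type: application/json" \
#        -d '{"context": "The Eiffel Tower is in Paris.", "question": "Where is the Eiffel Tower?"}'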