import os
import warnings

import torch
import uvicorn
from dotenv import load_dotenv
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoTokenizer, T5ForConditionalGeneration

#from ece import compute_ECE

# Silence library warnings (e.g. transformers deprecation notices).
warnings.filterwarnings("ignore")

DEVICE = 'cpu'

# Server and model configuration come from a local .env file.
load_dotenv()
host = os.getenv("HOST")
port = int(os.getenv("PORT", "8000"))  # uvicorn needs an int port; 8000 is an assumed fallback

model_dir = os.getenv("QA_MODEL")
#model_dir = "C:/Users/david/Desktop/T5_JUPYTER/qa_model"
tokenizer_dir = os.getenv("QA_TOKENIZER")
#tokenizer_dir = "C:/Users/david/Desktop/T5_JUPYTER/qa_tokenizer"
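
# A minimal .env consumed by load_dotenv() above (illustrative values; the
# real paths and ports are whatever your deployment uses):
#   HOST=0.0.0.0
#   PORT=8000
#   QA_MODEL=./qa_model
#   QA_TOKENIZER=./qa_tokenizer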

# Load the fine-tuned T5 QA model and its tokenizer onto the target device.
MODEL = T5ForConditionalGeneration.from_pretrained(model_dir, from_tf=False, return_dict=True).to(DEVICE)
print("Model successfully loaded!")

TOKENIZER = AutoTokenizer.from_pretrained(tokenizer_dir, use_fast=True)
print("Tokenizer successfully loaded!")

Q_LEN = 512  # maximum token length for the question + context input

# Register the separator token and give it an embedding row, so generation
# never indexes past the embedding matrix.
TOKENIZER.add_tokens('<sep>')
MODEL.resize_token_embeddings(len(TOKENIZER))

app = FastAPI()

# Request body schema (Pydantic base model).
class InputData(BaseModel):
    context: str
    question: str

@app.post("/predict")
async def predict(input_data: InputData):
    # Encode question + context as a single padded sequence.
    inputs = TOKENIZER(input_data.question, input_data.context, max_length=Q_LEN,
                       padding="max_length", truncation=True, add_special_tokens=True)
    input_ids = torch.tensor(inputs["input_ids"], dtype=torch.long).to(DEVICE).unsqueeze(0)
    attention_mask = torch.tensor(inputs["attention_mask"], dtype=torch.long).to(DEVICE).unsqueeze(0)

    # Generate the answer; scores are requested but only the token ids are used.
    outputs = MODEL.generate(input_ids=input_ids, attention_mask=attention_mask,
                             return_dict_in_generate=True, output_scores=True, max_length=512)
    predicted_ids = outputs.sequences.cpu().numpy()
    predicted_text = TOKENIZER.decode(predicted_ids[0], skip_special_tokens=True)

    return {'prediction': predicted_text}
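
# Example call against the endpoint above (illustrative payload; adjust
# host/port to match your .env):
#   curl -X POST http://localhost:8000/predict \
#        -H "Content-Type: application/json" \
#        -d '{"context": "Berlin is the capital of Germany.", "question": "What is the capital of Germany?"}'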

if __name__ == "__main__":
    uvicorn.run(app, host=host, port=port)
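
# Alternatively, run under the uvicorn CLI (assuming this file is saved as app.py):
#   uvicorn app:app --host 0.0.0.0 --port 8000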