Upload files to ''
parent 206ff3ce64
commit c5e55256db
48 api.py Normal file
@@ -0,0 +1,48 @@
import torch
import uvicorn
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import MT5Tokenizer, AutoTokenizer, AutoModel, T5ForConditionalGeneration
import warnings
import json
import random
import torch.nn.functional as F
#from ece import compute_ECE
from torch.utils.data import DataLoader
from functools import reduce

warnings.filterwarnings("ignore")

DEVICE = 'cpu'
model_dir = "C:/Users/david/Desktop/T5_JUPYTER/qa_model"
tokenizer_dir = "C:/Users/david/Desktop/T5_JUPYTER/qa_tokenizer"

# Load the fine-tuned model and tokenizer from local directories
MODEL = T5ForConditionalGeneration.from_pretrained(model_dir, from_tf=False, return_dict=True).to(DEVICE)
print("Model successfully loaded!")
TOKENIZER = AutoTokenizer.from_pretrained(tokenizer_dir, use_fast=True)
print("Tokenizer successfully loaded!")
Q_LEN = 512
TOKENIZER.add_tokens('<sep>')

print('model loaded')

app = FastAPI()

# BASE MODEL
class InputData(BaseModel):
    context: str
    question: str

@app.post("/predict")
async def predict(input_data: InputData):
    inputs = TOKENIZER(input_data.question, input_data.context, max_length=512, padding="max_length", truncation=True, add_special_tokens=True)
    input_ids = torch.tensor(inputs["input_ids"], dtype=torch.long).to(DEVICE).unsqueeze(0)
    attention_mask = torch.tensor(inputs["attention_mask"], dtype=torch.long).to(DEVICE).unsqueeze(0)
    outputs = MODEL.generate(input_ids=input_ids, attention_mask=attention_mask, return_dict_in_generate=True, output_scores=True, max_length=512)
    predicted_ids = outputs.sequences.numpy()
    predicted_text = TOKENIZER.decode(predicted_ids[0], skip_special_tokens=True)

    return {'prediction': predicted_text}


if __name__ == "__main__":
    uvicorn.run(app, host="127.0.0.1", port=8090)
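
A minimal sketch of querying the /predict endpoint above once the server is running; the context and question strings are made-up placeholders:

import requests

payload = {
    "context": "Bratislava is the capital of Slovakia.",
    "question": "What is the capital of Slovakia?",
}
response = requests.post("http://127.0.0.1:8090/predict", json=payload)
print(response.json())  # expected shape: {'prediction': '...'}
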
33 aplication.py Normal file
@@ -0,0 +1,33 @@
import requests
import json
import streamlit as st


def predict(context, question):
    url = 'http://localhost:8090/predict'
    data = {'context': context, 'question': question}
    json_data = json.dumps(data)
    headers = {'Content-type': 'application/json'}
    response = requests.post(url, data=json_data, headers=headers)
    result = response.json()
    return result


def main():
    st.title("T5 model inference")

    # Create input fields for the values
    context = st.text_input("context:")
    question = st.text_input("question:")
    # Note: this runs on every Streamlit rerun, i.e. before the button is pressed
    prediction = predict(context, question)
    # Create a button to execute the action
    if st.button("Execute"):
        st.json({
            'context': context,
            'question': question,
            'prediction': prediction
        })


if __name__ == "__main__":
    main()
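
The UI can be launched with "streamlit run aplication.py"; it assumes the FastAPI service above is already listening on localhost:8090.
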
91 condadiplo.yaml Normal file
@@ -0,0 +1,91 @@
name: DIPLOcondaEnviroment
channels:
  - defaults
dependencies:
  - bzip2=1.0.8=he774522_0
  - ca-certificates=2023.01.10=haa95532_0
  - libffi=3.4.2=hd77b12b_6
  - openssl=1.1.1t=h2bbff1b_0
  - pip=23.0.1=py311haa95532_0
  - python=3.11.2=h966fe2a_0
  - setuptools=66.0.0=py311haa95532_0
  - sqlite=3.41.2=h2bbff1b_0
  - tk=8.6.12=h2bbff1b_0
  - vc=14.2=h21ff451_1
  - vs2015_runtime=14.27.29016=h5e58377_2
  - wheel=0.38.4=py311haa95532_0
  - xz=5.2.10=h8cc25b3_1
  - zlib=1.2.13=h8cc25b3_0
  - pip:
      - altair==4.2.2
      - anyio==3.6.2
      - attrs==23.1.0
      - blinker==1.6.2
      - cachetools==5.3.0
      - certifi==2022.12.7
      - charset-normalizer==3.1.0
      - click==8.1.3
      - colorama==0.4.6
      - comtypes==1.1.14
      - decorator==5.1.1
      - entrypoints==0.4
      - fastapi==0.92.0
      - filelock==3.11.0
      - gitdb==4.0.10
      - gitpython==3.1.31
      - h11==0.14.0
      - huggingface-hub==0.12.1
      - idna==3.4
      - importlib-metadata==6.4.1
      - jinja2==3.1.2
      - jsonschema==4.17.3
      - markdown-it-py==2.2.0
      - markupsafe==2.1.2
      - mdurl==0.1.2
      - mouseinfo==0.1.3
      - mpmath==1.3.0
      - networkx==3.1
      - numpy==1.24.2
      - packaging==23.1
      - pandas==1.5.3
      - pillow==9.5.0
      - protobuf==3.20.3
      - pyarrow==11.0.0
      - pydantic==1.10.7
      - pydeck==0.8.1b0
      - pygments==2.15.0
      - pympler==1.0.1
      - pyperclip==1.8.2
      - pyrsistent==0.19.3
      - python-dateutil==2.8.2
      - pytz==2023.3
      - pytz-deprecation-shim==0.1.0.post0
      - pywin32==305
      - pywinauto==0.6.8
      - pyyaml==6.0
      - regex==2023.3.23
      - requests==2.28.2
      - rich==13.3.4
      - six==1.16.0
      - smmap==5.0.0
      - sniffio==1.3.0
      - starlette==0.25.0
      - streamlit==1.21.0
      - sympy==1.11.1
      - tokenizers==0.13.2
      - toml==0.10.2
      - toolz==0.12.0
      - torch==2.0.0
      - torchaudio==2.0.1
      - torchvision==0.15.1
      - tornado==6.2
      - tqdm==4.65.0
      - transformers==4.26.1
      - typing-extensions==4.5.0
      - tzdata==2023.3
      - tzlocal==4.3
      - urllib3==1.26.15
      - uvicorn==0.20.0
      - validators==0.20.0
      - watchdog==3.0.0
      - zipp==3.15.0
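
The environment can be recreated with "conda env create -f condadiplo.yaml", which installs the pinned conda packages and then the pip section.
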
72 requirements.txt Normal file
@@ -0,0 +1,72 @@
altair==4.2.2
anyio==3.6.2
attrs==23.1.0
blinker==1.6.2
cachetools==5.3.0
certifi==2022.12.7
charset-normalizer==3.1.0
click==8.1.3
colorama==0.4.6
comtypes==1.1.14
decorator==5.1.1
entrypoints==0.4
fastapi==0.92.0
filelock==3.11.0
gitdb==4.0.10
GitPython==3.1.31
h11==0.14.0
huggingface-hub==0.12.1
idna==3.4
importlib-metadata==6.4.1
Jinja2==3.1.2
jsonschema==4.17.3
markdown-it-py==2.2.0
MarkupSafe==2.1.2
mdurl==0.1.2
MouseInfo==0.1.3
mpmath==1.3.0
networkx==3.1
numpy==1.24.2
packaging==23.1
pandas==1.5.3
Pillow==9.5.0
protobuf==3.20.3
pyarrow==11.0.0
pydantic==1.10.7
pydeck==0.8.1b0
Pygments==2.15.0
Pympler==1.0.1
pyperclip==1.8.2
pyrsistent==0.19.3
python-dateutil==2.8.2
pytz==2023.3
pytz-deprecation-shim==0.1.0.post0
pywin32==305
pywinauto==0.6.8
PyYAML==6.0
regex==2023.3.23
requests==2.28.2
rich==13.3.4
six==1.16.0
smmap==5.0.0
sniffio==1.3.0
starlette==0.25.0
streamlit==1.21.0
sympy==1.11.1
tokenizers==0.13.2
toml==0.10.2
toolz==0.12.0
torch==2.0.0
torchaudio==2.0.1
torchvision==0.15.1
tornado==6.2
tqdm==4.65.0
transformers==4.26.1
typing_extensions==4.5.0
tzdata==2023.3
tzlocal==4.3
urllib3==1.26.15
uvicorn==0.20.0
validators==0.20.0
watchdog==3.0.0
zipp==3.15.0
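
Outside conda, the same pip pins can be installed into any Python 3.11 environment with "pip install -r requirements.txt".
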
175 train.py Normal file
@@ -0,0 +1,175 @@
import torch
import json
from tqdm import tqdm
import torch.nn as nn
from torch.optim import Adam
import nltk
import spacy
import string
import evaluate  # Bleu
from torch.utils.data import Dataset, DataLoader, RandomSampler
import pandas as pd
import numpy as np
import transformers
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
#from transformers import T5Tokenizer, T5Model, T5ForConditionalGeneration, T5TokenizerFast
from transformers import AutoTokenizer, T5ForConditionalGeneration
import warnings
warnings.filterwarnings("ignore")

print("Imports successfully done")

DEVICE = 'cuda:0'
#TOKENIZER = AutoTokenizer.from_pretrained("ApoTro/slovak-t5-small")
TOKENIZER = AutoTokenizer.from_pretrained('google/mt5-small')
#TOKENIZER = AutoTokenizer.from_pretrained('google/mt5-base')
TOKENIZER.add_tokens('<sep>')
#MODEL = T5ForConditionalGeneration.from_pretrained("ApoTro/slovak-t5-small").to(DEVICE)
MODEL = T5ForConditionalGeneration.from_pretrained("google/mt5-small").to(DEVICE)
#MODEL = T5ForConditionalGeneration.from_pretrained("google/mt5-base").to(DEVICE)

# resize the embedding matrix for the added <sep> token
MODEL.resize_token_embeddings(len(TOKENIZER))
# lr = learning rate = 1e-5
OPTIMIZER = Adam(MODEL.parameters(), lr=0.00001)
Q_LEN = 256    # Question Length
T_LEN = 32     # Target Length
BATCH_SIZE = 4 # data batch size
EPOCHS = 4     # number of training epochs
print("Model successfully loaded")

path_train = '/home/omasta/T5_JUPYTER/skquad-221017/train-v1.json'

with open(path_train) as f:
    data = json.load(f)


def prepare_data(data):
    articles = []
    for article in data["data"]:
        for paragraph in article["paragraphs"]:
            for qa in paragraph["qas"]:
                question = qa["question"]
                answer = qa["answers"][0]["text"]
                #input_ = 'generuj_odpoved : ' + paragraph["context"] + ' <sep>' + question + ' <sep>'
                inputs = {"input": paragraph["context"] + '<sep>' + question, "answer": answer}
                #inputs = {'context': input_, 'answer': answer}
                articles.append(inputs)
    return articles


prepared_data = prepare_data(data)
print(prepared_data[0])
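# For reference, the JSON structure prepare_data() expects, as implied by the
# key accesses above (SQuAD-style; the literal values are placeholders):
#
# {"data": [
#     {"paragraphs": [
#         {"context": "...",
#          "qas": [{"question": "...", "answers": [{"text": "..."}]}]}
#     ]}
# ]}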
# Dataframe
data = pd.DataFrame(prepared_data)


class QA_Dataset(Dataset):
    def __init__(self, tokenizer, dataframe, q_len, t_len):
        self.tokenizer = tokenizer
        self.q_len = q_len
        self.t_len = t_len
        self.data = dataframe
        self.input = self.data['input']
        #self.context = self.data["context"]
        self.answer = self.data['answer']

    def __len__(self):
        # number of QA pairs in the dataframe
        return len(self.data)

    def __getitem__(self, idx):
        input = self.input[idx]
        answer = self.answer[idx]

        input_tokenized = self.tokenizer(input, max_length=self.q_len, padding="max_length",
                                         truncation=True, add_special_tokens=True)
        answer_tokenized = self.tokenizer(answer, max_length=self.t_len, padding="max_length",
                                          truncation=True, add_special_tokens=True)

        labels = torch.tensor(answer_tokenized["input_ids"], dtype=torch.long)
        labels[labels == 0] = -100  # ignore padding positions in the loss

        return {
            "input_ids": torch.tensor(input_tokenized["input_ids"], dtype=torch.long),
            "attention_mask": torch.tensor(input_tokenized["attention_mask"], dtype=torch.long),
            "labels": labels,
            "decoder_attention_mask": torch.tensor(answer_tokenized["attention_mask"], dtype=torch.long)
        }


## DATA LOADERS
train_data, val_data = train_test_split(data, test_size=0.2, random_state=42)
train_sampler = RandomSampler(train_data.index)
val_sampler = RandomSampler(val_data.index)
qa_dataset = QA_Dataset(TOKENIZER, data, Q_LEN, T_LEN)

train_loader = DataLoader(qa_dataset, batch_size=BATCH_SIZE, sampler=train_sampler)
val_loader = DataLoader(qa_dataset, batch_size=BATCH_SIZE, sampler=val_sampler)
print("Loaders working fine")


### TRAINING (~46 MIN ON THE V1 DATA)

train_loss = 0
val_loss = 0
train_batch_count = 0
val_batch_count = 0

for epoch in range(EPOCHS):
    MODEL.train()
    for batch in tqdm(train_loader, desc="Training batches"):
        input_ids = batch["input_ids"].to(DEVICE)
        attention_mask = batch["attention_mask"].to(DEVICE)
        labels = batch["labels"].to(DEVICE)
        decoder_attention_mask = batch["decoder_attention_mask"].to(DEVICE)

        outputs = MODEL(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
            decoder_attention_mask=decoder_attention_mask
        )

        OPTIMIZER.zero_grad()
        outputs.loss.backward()
        OPTIMIZER.step()
        train_loss += outputs.loss.item()
        train_batch_count += 1

    # Evaluation (inference only: no backward pass or optimizer step)
    MODEL.eval()
    with torch.no_grad():
        for batch in tqdm(val_loader, desc="Validation batches"):
            input_ids = batch["input_ids"].to(DEVICE)
            attention_mask = batch["attention_mask"].to(DEVICE)
            labels = batch["labels"].to(DEVICE)
            decoder_attention_mask = batch["decoder_attention_mask"].to(DEVICE)

            outputs = MODEL(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels,
                decoder_attention_mask=decoder_attention_mask
            )

            val_loss += outputs.loss.item()
            val_batch_count += 1

    print(f"{epoch+1}/{EPOCHS} -> Train loss: {train_loss / train_batch_count}\tValidation loss: {val_loss / val_batch_count}")


print("Training done successfully")

## SAVE FINE-TUNED MODEL
MODEL.save_pretrained("qa_model_mT5_small")
TOKENIZER.save_pretrained('qa_tokenizer_mT5_small')
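
A minimal sketch of reloading the artifacts saved above for inference (directory names taken from the save_pretrained calls; device placement is left to the caller):

from transformers import AutoTokenizer, T5ForConditionalGeneration

# load the fine-tuned weights and tokenizer written by train.py
model = T5ForConditionalGeneration.from_pretrained("qa_model_mT5_small")
tokenizer = AutoTokenizer.from_pretrained("qa_tokenizer_mT5_small")
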
139 usecase.py Normal file
@@ -0,0 +1,139 @@
## IMPORT NECESSARY PACKAGES
from transformers import T5ForConditionalGeneration, T5Tokenizer, AutoTokenizer
import torch
import evaluate  # Bleu
import json
import random
import statistics
from sklearn.metrics import precision_score, recall_score, f1_score
## TURN WARNINGS OFF
import warnings
warnings.filterwarnings("ignore")
## 13/03/23 added
from rouge import Rouge

# Model name
DEVICE = 'cuda:0'


#T5 MODEL
#model_name = 'T5_SK_model'
#model_dir = "/home/omasta/T5_JUPYTER/qa_model"
#tokenizer_dir = "/home/omasta/T5_JUPYTER/qa_tokenizer"

#mT5 SMALL MODEL
model_name = 'mT5_SMALL'
model_dir = '/home/omasta/T5_JUPYTER/qa_model_mT5_small'
tokenizer_dir = '/home/omasta/T5_JUPYTER/qa_tokenizer_mT5_small'

# Load the model from the directory
MODEL = T5ForConditionalGeneration.from_pretrained(model_dir, from_tf=False, return_dict=True).to(DEVICE)
print("Model successfully loaded!")
TOKENIZER = AutoTokenizer.from_pretrained(tokenizer_dir, use_fast=True)
print("Tokenizer successfully loaded!")
Q_LEN = 512
TOKENIZER.add_tokens('<sep>')
MODEL.resize_token_embeddings(len(TOKENIZER))


def predict_answer(data, ref_answer=None, random=None):
    predictions = []
    for i in data:
        inputs = TOKENIZER(i['input'], max_length=Q_LEN, padding="max_length", truncation=True, add_special_tokens=True)
        input_ids = torch.tensor(inputs["input_ids"], dtype=torch.long).to(DEVICE).unsqueeze(0)
        attention_mask = torch.tensor(inputs["attention_mask"], dtype=torch.long).to(DEVICE).unsqueeze(0)
        outputs = MODEL.generate(input_ids=input_ids, attention_mask=attention_mask)
        predicted_answer = TOKENIZER.decode(outputs.flatten(), skip_special_tokens=True)
        ref_answer = i['answer'].lower()
        #print(ref_answer)
        if ref_answer:
            # Load the Bleu metric
            bleu = evaluate.load("google_bleu")
            #precision = list(precision_score(ref_answer, predicted_answer))
            #recall = list(recall_score(ref_answer, predicted_answer))
            #f1 = list(f1_score(ref_answer, predicted_answer))
            score = bleu.compute(predictions=[predicted_answer],
                                 references=[ref_answer])
            predictions.append({'prediction': predicted_answer, 'ref_answer': ref_answer, 'score': score['google_bleu']})
    return predictions


def prepare_data(data):
    articles = []
    for article in data["data"]:
        for paragraph in article["paragraphs"]:
            for qa in paragraph["qas"]:
                question = qa["question"]
                answer = qa["answers"][0]["text"]
                inputs = {"input": paragraph["context"] + "<sep>" + question, "answer": answer}
                articles.append(inputs)
    return articles


dev_data_path = '/home/omasta/T5_JUPYTER/skquad-221017/dev-v1.json'
with open(dev_data_path, 'r') as f:
    data = json.load(f)
#print('data imported')

dev_data = prepare_data(data)

#print('data prepared')
print(f'Number of dev samples {len(dev_data)}')
print(dev_data[0])
bleu_score = []
precisions = []
f1_scores = []
recall_scores = []
rouge_1 = []
rouge_2 = []
#X = 150
# named eval_results so the list does not shadow the `evaluate` module
eval_results = predict_answer(dev_data)
rouge = Rouge()
for item in eval_results:
    bleu_score.append(item['score'])
    try:
        #scores = rouge.get_scores(item['prediction'], item['ref_answer'], avg=True)
        # character-level precision/recall/F1 between reference and prediction
        precision = precision_score(list(item['ref_answer']), list(item['prediction']), average='macro')
        recall = recall_score(list(item['ref_answer']), list(item['prediction']), average='macro')
        f1 = f1_score(list(item['ref_answer']), list(item['prediction']), average='macro')
    except ValueError:
        precision = 0
        recall = 0
        f1 = 0
    precisions.append(precision)
    f1_scores.append(f1)
    recall_scores.append(recall)


def rouge_eval(dict_x):
    rouge = Rouge()
    rouge_scores = []
    for item in dict_x:
        if item['prediction'] and item['ref_answer']:
            rouge_score = rouge.get_scores(item['prediction'], item['ref_answer'])
            rouge_scores.append(rouge_score)
        else:
            continue
    return rouge_scores
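# Note: Rouge().get_scores(hyp, ref) from the `rouge` package returns a list
# with one dict per hypothesis/reference pair, of the form
# [{'rouge-1': {'r': ..., 'p': ..., 'f': ...}, 'rouge-2': {...}, 'rouge-l': {...}}],
# which is why the aggregation below indexes score[0]['rouge-1']['f'].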
print('EVALUATION OF RESULTS: ------------------------')
#print(eval_results)
#bleu_score_total = statistics.mean(bleu_score)
#recall_score_total = statistics.mean(recall_scores)
#f1_score_total = statistics.mean(f1_scores)
#precision_total = statistics.mean(precisions)
#print(f'Bleu_score of model {model_name} : ', bleu_score_total)
#print(f'Recall of model {model_name}: ', recall_score_total)
#print(f'F1 of model {model_name} : ', f1_score_total)
#print(f'Precision of model {model_name}: :', precision_total)
#print(rouge_eval(eval_results))
print(f'{model_name} results')
rouge_scores = rouge_eval(eval_results)
rouge_values = [score[0]['rouge-1']['f'] for score in rouge_scores]
mean_rouge_score = statistics.mean(rouge_values)
print(f'Rouge-1: {mean_rouge_score}')

rouge2_values = [score[0]['rouge-2']['f'] for score in rouge_scores]
mean_rouge_score = statistics.mean(rouge2_values)
print(f'Rouge-2: {mean_rouge_score}')