## IMPORT NECESSARY LIBRARIES

from transformers import T5ForConditionalGeneration, T5Tokenizer, AutoTokenizer
import torch
import evaluate  # BLEU
import json
import random
import statistics
from sklearn.metrics import precision_score, recall_score, f1_score

## TURN WARNINGS OFF
import warnings
warnings.filterwarnings("ignore")

## 13/03/23 added
from rouge import Rouge
from tqdm import tqdm
from datasets import load_dataset
import re

## CUSTOM ROUGE METRIC - NEW TODO:

DEVICE = 'cuda:0'
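# NOTE: assumes a CUDA-capable GPU is available; set DEVICE = 'cpu' to run without one.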

# Model name and directories

# T5 MODEL
#model_name = 'T5_SK_model'
#model_dir = "/home/omasta/T5_JUPYTER/qa_model"
#tokenizer_dir = "/home/omasta/T5_JUPYTER/qa_tokenizer"

# mT5 SMALL MODEL
model_name = 'qa_model'
model_dir = '/home/omasta/T5_JUPYTER/qa_model_mT5_polish'
tokenizer_dir = '/home/omasta/T5_JUPYTER/qa_tokenizer_mT5_polish'

# Load the model and tokenizer from the directories above
MODEL = T5ForConditionalGeneration.from_pretrained(model_dir, from_tf=False, return_dict=True).to(DEVICE)
print("Model successfully loaded!")
TOKENIZER = AutoTokenizer.from_pretrained(tokenizer_dir, use_fast=True)
print("Tokenizer successfully loaded!")

Q_LEN = 512  # maximum input length in tokens
TOKENIZER.add_tokens('<sep>')
MODEL.resize_token_embeddings(len(TOKENIZER))
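# '<sep>' is added as a regular token, so the embedding matrix is resized to
# give it a trainable vector; inputs below join context and question with it.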


def nahradit_znaky(retezec):
    # "Replace characters": swap square brackets for spaces
    novy_retezec = retezec.replace('[', ' ').replace(']', ' ')
    return novy_retezec


def predict_answer(data):
    predictions = []
    for i in tqdm(data, desc="predicting"):
        inputs = TOKENIZER(i['input'], max_length=Q_LEN, padding="max_length", truncation=True, add_special_tokens=True)
        input_ids = torch.tensor(inputs["input_ids"], dtype=torch.long).to(DEVICE).unsqueeze(0)
        attention_mask = torch.tensor(inputs["attention_mask"], dtype=torch.long).to(DEVICE).unsqueeze(0)
        outputs = MODEL.generate(input_ids=input_ids, attention_mask=attention_mask)
        predicted_answer = TOKENIZER.decode(outputs.flatten(), skip_special_tokens=True)
        ref_answer = i['answer'].lower()
        if ref_answer:
            # BLEU scoring, currently disabled:
            #bleu = evaluate.load("google_bleu")
            #score = bleu.compute(predictions=[predicted_answer], references=[ref_answer])
            predictions.append({'prediction': predicted_answer, 'ref_answer': ref_answer})
    return predictions
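
# A minimal usage sketch (hypothetical sample, for illustration only):
#   sample = [{'input': 'Bratislava is the capital of Slovakia.<sep>What is the capital of Slovakia?',
#              'answer': 'Bratislava'}]
#   predict_answer(sample)  # -> e.g. [{'prediction': '...', 'ref_answer': 'bratislava'}]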


def prepare_data(data):
    # Flatten raw SQuAD-style JSON into {input, answer} pairs
    articles = []
    for article in data["data"]:
        for paragraph in article["paragraphs"]:
            for qa in paragraph["qas"]:
                question = qa["question"]
                answer = qa["answers"][0]["text"]
                inputs = {"input": paragraph["context"] + "<sep>" + question, "answer": answer}
                articles.append(inputs)
    return articles
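
# prepare_data expects raw SQuAD-style JSON; the evaluation below uses
# prepare_polish_data instead, which works on a HuggingFace datasets split.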
2024-02-17 19:00:49 +00:00
|
|
|
def prepare_polish_data(data):
|
|
|
|
arcs = list()
|
|
|
|
for i in range(len(data)):
|
|
|
|
questions=data[i]["question"]
|
|
|
|
try:
|
|
|
|
answer = nahradit_znaky(', '.join(data[i]["answers"]["text"]))
|
|
|
|
except KeyError:
|
|
|
|
continue
|
|
|
|
context = data[i]["context"]
|
|
|
|
inputs = {"input":context+"<sep>"+questions,"answer":answer}
|
|
|
|
arcs.append(inputs)
|
|
|
|
return arcs
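
# Note: in squad_v2 (loaded below) every row carries an "answers" dict, so the
# KeyError branch above should not fire; unanswerable questions instead yield
# an empty answer string via ', '.join([]).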


#dataset = load_dataset("clarin-pl/poquad")
dataset = load_dataset("squad_v2")
dev_data = prepare_polish_data(dataset["validation"])
print(f'Number of dev samples: {len(dev_data)}')

bleu_score = []
precisions = []
f1_scores = []
recall_scores = []
rouge_1 = []
rouge_2 = []

# Collect predictions for the whole dev set
eval_results = predict_answer(dev_data)
rouge = Rouge()
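# rouge.get_scores raises a ValueError on empty strings, so empty predictions
# fall through to the except branch below and score 0.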

for item in tqdm(eval_results, desc="evaluating"):
    try:
        scores = rouge.get_scores(item['prediction'], item['ref_answer'])
        # Character-level macro precision/recall/F1; sklearn raises ValueError
        # when the two character lists differ in length.
        precision = precision_score(list(item['ref_answer']), list(item['prediction']), average='macro')
        recall = recall_score(list(item['ref_answer']), list(item['prediction']), average='macro')
        f1 = f1_score(list(item['ref_answer']), list(item['prediction']), average='macro')
    except ValueError:
        precision = 0
        recall = 0
        f1 = 0
    precisions.append(precision)
    f1_scores.append(f1)
    recall_scores.append(recall)


def rouge_eval(dict_x):
    # ROUGE over all prediction/reference pairs, skipping empty strings
    rouge = Rouge()
    rouge_scores = []
    for item in dict_x:
        if item['prediction'] and item['ref_answer']:
            rouge_score = rouge.get_scores(item['prediction'], item['ref_answer'])
            rouge_scores.append(rouge_score)
        else:
            continue
    return rouge_scores


print('EVALUATION OF RESULTS: ------------------------')

#bleu_score_total = statistics.mean(bleu_score)
recall_score_total = statistics.mean(recall_scores)
f1_score_total = statistics.mean(f1_scores)
precision_total = statistics.mean(precisions)

#print(f'Bleu score of model {model_name}: ', bleu_score_total)
print(f'Recall of model {model_name}: ', recall_score_total)
print(f'F1 of model {model_name}: ', f1_score_total)
print(f'Precision of model {model_name}: ', precision_total)
print(model_dir)
print(rouge_eval(eval_results))

print(f'{model_name} results')
rouge_scores = rouge_eval(eval_results)

# Mean ROUGE-1 F1 over all evaluated pairs
rouge_values = [score[0]['rouge-1']['f'] for score in rouge_scores]
mean_rouge_score = statistics.mean(rouge_values)
print(f'Rouge mean score: {mean_rouge_score}')

# Mean ROUGE-2 F1 over all evaluated pairs
rouge2_values = [score[0]['rouge-2']['f'] for score in rouge_scores]
mean_rouge2_score = statistics.mean(rouge2_values)
print(f'Rouge-2 mean score: {mean_rouge2_score}')