191 lines
6.5 KiB
Python
191 lines
6.5 KiB
Python
import torch
|
|
import json
|
|
from tqdm import tqdm
|
|
import torch.nn as nn
|
|
from torch.optim import Adam
|
|
import nltk
|
|
import spacy
|
|
import string
|
|
import evaluate # Bleu
|
|
from torch.utils.data import Dataset, DataLoader, RandomSampler
|
|
import pandas as pd
|
|
import numpy as np
|
|
import transformers
|
|
from sklearn.model_selection import train_test_split
|
|
import matplotlib.pyplot as plt
|
|
#from transformers import T5Tokenizer, T5Model, T5ForConditionalGeneration, T5TokenizerFast
|
|
from transformers import AutoTokenizer, T5ForConditionalGeneration
|
|
import warnings
|
|
warnings.filterwarnings("ignore")

print("Imports succesfully done")

# Use the first GPU when one is visible; otherwise fall back to the CPU instead
# of failing on the first `.to(DEVICE)` call.
DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# Other checkpoints tried with this script: "ApoTro/slovak-t5-small", "google/mt5-base".
TOKENIZER = AutoTokenizer.from_pretrained('google/mt5-small')
# '<sep>' is the separator placed between context and question in the model input
# (see prep_data / prepare_data).
TOKENIZER.add_tokens('<sep>')

MODEL = T5ForConditionalGeneration.from_pretrained("google/mt5-small").to(DEVICE)
# Resize the embedding matrix to the tokenizer's vocabulary now that '<sep>' was added.
MODEL.resize_token_embeddings(len(TOKENIZER))

# Learning rate 1e-5. Created after the resize so it holds the resized parameters.
OPTIMIZER = Adam(MODEL.parameters(), lr=0.00001)

Q_LEN = 256  # max token length of the model input (context + question)
T_LEN = 32  # max token length of the target answer
BATCH_SIZE = 4  # examples per batch

print("Model succesfully loaded")

from datasets import load_dataset

# SQuAD v2 from the Hugging Face hub; its "train" split is what prep_data consumes below.
dataset = load_dataset("squad_v2")
print(dataset["train"][0])

# Local JSON dataset (poquad-train.json), expected in the nested
# data -> paragraphs -> qas layout that prepare_data walks. NOTE: it is only used by
# the commented-out `prepare_data(data)` call, and `data` is later rebound to a
# DataFrame.
# Previously used: '/home/omasta/T5_JUPYTER/skquad-221017/train-v1.json'
path_train = "poquad-train.json"
# JSON text is UTF-8; don't depend on the platform's locale encoding.
with open(path_train, encoding="utf-8") as f:
    data = json.load(f)


def nahradit_znaky(retezec):
    """Return ``retezec`` with every ``[`` and every ``]`` replaced by a single space.

    (``nahradit_znaky`` roughly means "replace characters".)
    """
    brackets_to_spaces = str.maketrans("[]", "  ")
    return retezec.translate(brackets_to_spaces)


def prepare_data(data):
    """Flatten a nested QA JSON dict into ``{"input", "answer"}`` records.

    Args:
        data: Parsed JSON laid out as ``data["data"][*]["paragraphs"][*]["qas"][*]``.
            Each paragraph has a ``"context"``. Each qa has a ``"question"`` and an
            ``"answers"`` list of ``{"text": ...}`` dicts.

    Returns:
        One dict per qa. ``"input"`` is ``context + '<sep>' + question``.
        ``"answer"`` is the text of the first answer. It is ``""`` when the answer
        list is empty (e.g. unanswerable questions in SQuAD v2-style files), which
        is also what prep_data produces for an empty answer list.
    """
    articles = []
    for article in data["data"]:
        for paragraph in article["paragraphs"]:
            context = paragraph["context"]
            for qa in paragraph["qas"]:
                answers = qa["answers"]
                # Only the first reference answer is used as the training target;
                # indexing [0] on an empty list would raise IndexError.
                answer = answers[0]["text"] if answers else ""
                articles.append({"input": context + '<sep>' + qa["question"], "answer": answer})
    return articles

def prep_data(data):
    """Flatten SQuAD v2-style rows into ``{"input", "answer"}`` records.

    Args:
        data: Iterable of row mappings with ``"question"``, ``"context"`` and
            ``"answers"`` keys, where ``row["answers"]["text"]`` is a list of
            strings (e.g. a Hugging Face ``squad_v2`` split).

    Returns:
        A list of dicts:

        * ``"input"`` is ``context + "<sep>" + question``.
        * ``"answer"`` is every answer text joined with ``", "``, with square
          brackets replaced by spaces via ``nahradit_znaky``. An empty list of
          answer texts yields ``""``.

        Rows with no ``answers["text"]`` entry are skipped.
    """
    records = []
    # Iterate the rows directly instead of indexing data[i] repeatedly.
    for row in data:
        question = row["question"]
        try:
            answer_texts = row["answers"]["text"]
        except KeyError:
            continue  # no answer field: skip the row
        answer = nahradit_znaky(', '.join(answer_texts))
        records.append({"input": row["context"] + "<sep>" + question, "answer": answer})
    return records


# Turn the SQuAD v2 training split into {"input", "answer"} records.
# To train on the local JSON file instead: prepared_data = prepare_data(data)
prepared_data = prep_data(dataset["train"])
print(prepared_data[0])

# One DataFrame row per record, with pandas' default 0..n-1 index. This rebinds
# `data`, which until now held the raw JSON dict.
data = pd.DataFrame(data=prepared_data)


class QA_Dataset(Dataset):
    """Map-style dataset that tokenizes (input, answer) text pairs for seq2seq training.

    Args:
        tokenizer: Tokenizer applied to both the model input and the target text.
        dataframe: DataFrame with an ``'input'`` column (model input text) and an
            ``'answer'`` column (target text).
        q_len: Max token length of the input; shorter inputs are padded and
            longer ones truncated.
        t_len: Max token length of the target; padded and truncated the same way.

    Items are addressed by row *position* (``0 .. len(self) - 1``), independent of
    the DataFrame's index labels.
    """

    def __init__(self, tokenizer, dataframe, q_len, t_len):
        self.tokenizer = tokenizer
        self.q_len = q_len
        self.t_len = t_len
        self.data = dataframe
        self.input = self.data['input']
        self.answer = self.data['answer']

    def __len__(self):
        return len(self.input)

    def __getitem__(self, idx):
        """Return the example at row position ``idx`` as a dict of long tensors.

        Keys are ``input_ids`` and ``attention_mask`` (length ``q_len``),
        ``labels`` (length ``t_len``, padding set to -100) and
        ``decoder_attention_mask`` (length ``t_len``).
        """
        source_text = self.input.iloc[idx]
        target_text = self.answer.iloc[idx]

        # padding="max_length" already pads to max_length, so the deprecated
        # pad_to_max_length flag is not passed.
        input_tokenized = self.tokenizer(source_text, max_length=self.q_len, padding="max_length",
                                         truncation=True, add_special_tokens=True)
        answer_tokenized = self.tokenizer(target_text, max_length=self.t_len, padding="max_length",
                                          truncation=True, add_special_tokens=True)

        labels = torch.tensor(answer_tokenized["input_ids"], dtype=torch.long)
        # Padding positions become -100, the index the model's loss ignores.
        labels[labels == self.tokenizer.pad_token_id] = -100

        return {
            "input_ids": torch.tensor(input_tokenized["input_ids"], dtype=torch.long),
            "attention_mask": torch.tensor(input_tokenized["attention_mask"], dtype=torch.long),
            "labels": labels,
            "decoder_attention_mask": torch.tensor(answer_tokenized["attention_mask"], dtype=torch.long),
        }


## DATA LOADERS
from torch.utils.data import SubsetRandomSampler

# 80/20 split of the rows. `data` is built by pd.DataFrame(prepared_data), so it has
# the default 0..n-1 index: the index labels below are also row positions in qa_dataset.
train_data, val_data = train_test_split(data, test_size=0.2, random_state=42)

# Use SubsetRandomSampler, not RandomSampler: RandomSampler(x) only yields a
# permutation of range(len(x)), never the values in x. Both loaders would then
# draw from the first rows of `data`, and the validation rows would overlap the
# training rows. SubsetRandomSampler yields exactly the given indices, shuffled.
# .tolist() hands it plain ints rather than a pandas Index.
train_sampler = SubsetRandomSampler(train_data.index.tolist())
val_sampler = SubsetRandomSampler(val_data.index.tolist())

qa_dataset = QA_Dataset(TOKENIZER, data, Q_LEN, T_LEN)

train_loader = DataLoader(qa_dataset, batch_size=BATCH_SIZE, sampler=train_sampler)
val_loader = DataLoader(qa_dataset, batch_size=BATCH_SIZE, sampler=val_sampler)
print("Loaders working fine")


### TRAINING (about 46 min on the v1 data)
NUM_EPOCHS = 2


def _model_loss(batch):
    """Move one collated batch to DEVICE, run MODEL on it and return its loss tensor."""
    return MODEL(
        input_ids=batch["input_ids"].to(DEVICE),
        attention_mask=batch["attention_mask"].to(DEVICE),
        labels=batch["labels"].to(DEVICE),
        decoder_attention_mask=batch["decoder_attention_mask"].to(DEVICE),
    ).loss


for epoch in range(NUM_EPOCHS):
    # Reset the accumulators every epoch so each printed loss covers this epoch only.
    train_loss = 0
    val_loss = 0
    train_batch_count = 0
    val_batch_count = 0

    MODEL.train()
    for batch in tqdm(train_loader, desc="Training batches"):
        loss = _model_loss(batch)
        OPTIMIZER.zero_grad()
        loss.backward()
        OPTIMIZER.step()
        train_loss += loss.item()
        train_batch_count += 1

    # Evaluation: forward passes only. No backward pass and no optimizer step, so the
    # validation data never updates the weights. no_grad skips building the
    # autograd graph.
    MODEL.eval()
    with torch.no_grad():
        for batch in tqdm(val_loader, desc="Validation batches"):
            val_loss += _model_loss(batch).item()
            val_batch_count += 1

    print(f"{epoch+1}/{NUM_EPOCHS} -> Train loss: {train_loss / train_batch_count}\tValidation loss: {val_loss/val_batch_count}")


print("Training done succesfully")

## SAVE FINE-TUNED MODEL
# Write the fine-tuned weights and the tokenizer (with its added '<sep>' token)
# so both can be reloaded later with from_pretrained().
MODEL_DIR = "qa_model_mT5_english"
TOKENIZER_DIR = 'qa_tokenizer_mT5_english'
MODEL.save_pretrained(MODEL_DIR)
TOKENIZER.save_pretrained(TOKENIZER_DIR)