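# Fine-tunes a T5-family seq2seq model for extractive question answering on
# pooled English (SQuAD v2), Slovak (SKQuAD) and Polish (PoQuAD) training data.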
import json
import string
import warnings

import nltk
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import transformers
from sklearn.model_selection import train_test_split
from torch.optim import Adam
from torch.utils.data import DataLoader, Dataset, RandomSampler, SubsetRandomSampler
from tqdm import tqdm

#from transformers import T5Tokenizer, T5Model, T5ForConditionalGeneration, T5TokenizerFast
from transformers import AutoTokenizer, T5ForConditionalGeneration
warnings.filterwarnings("ignore")
|
||
|
|
||
|
print("Imports succesfully done")
|
||
|
|
||
|
# Run on the first CUDA device when one exists; otherwise fall back to CPU so
# the script still starts on a machine without a GPU.
DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# Tokenizer and weights MUST come from the same checkpoint: token ids only mean
# something for the vocabulary the embedding matrix was trained with. Mixing a
# umT5 tokenizer with mT5 weights silently maps ids to the wrong embeddings.
# umT5 is the checkpoint the save directories at the end of the script are
# named after. AutoModelForSeq2SeqLM selects the architecture class that matches
# the checkpoint's config (UMT5 support requires a transformers release that
# includes it).
MODEL_NAME = 'google/umt5-small'

TOKENIZER = AutoTokenizer.from_pretrained(MODEL_NAME)
# Extra token used to separate the context from the question in the encoder input.
TOKENIZER.add_tokens('<sep>')

MODEL = transformers.AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME).to(DEVICE)

# Resize the embedding matrix to the tokenizer's new length so '<sep>' gets a row.
# This is done before the optimizer is created so the optimizer holds the
# resized parameters.
MODEL.resize_token_embeddings(len(TOKENIZER))

# Learning rate 1e-5.
OPTIMIZER = Adam(MODEL.parameters(), lr=0.00001)

Q_LEN = 256  # max tokens of the encoder input (context + '<sep>' + question)
T_LEN = 32  # max tokens of the target answer
BATCH_SIZE = 4  # examples per batch

print("Model succesfully loaded")
from datasets import load_dataset

# Fetch the three QA corpora from the Hugging Face Hub, in the order
# English, Slovak, Polish. All three are later flattened by
# prepare_data_english, which reads their "train" split.
dataset_english, dataset_slovak, dataset_polish = (
    load_dataset(hub_id)
    for hub_id in ("squad_v2", "TUKE-DeutscheTelekom/skquad", "clarin-pl/poquad")
)
def prepare_data_english(data, split="train", sep="<sep>"):
    """Flatten one SQuAD-style dataset split into ``{"input", "answer"}`` records.

    Despite the name, this script uses it for the English, Polish and Slovak
    datasets alike; it only relies on the fields read below.

    Args:
        data: Mapping of split name -> iterable of examples. Each example must
            provide ``context``, ``question`` and ``answers``, where
            ``answers`` holds parallel ``answer_start`` / ``text`` lists.
        split: Split of ``data`` to read. Defaults to the training split.
        sep: Separator placed between the context and the question.

    Returns:
        A list of dicts ``{"input": context + sep + question, "answer": span}``.
        ``span`` is the first annotated answer, cut out of the context at its
        annotated offset. Examples with no annotated answer (empty
        ``answer_start`` or ``text`` list) are skipped.
    """
    articles = []
    for item in tqdm(data[split], desc="Preparing training datas"):
        context = item["context"]
        question = item["question"]
        answers = item["answers"]

        # Unanswerable examples carry empty answer lists. Check both lists so a
        # record with only one of them filled is skipped rather than crashing.
        if not answers["answer_start"] or not answers["text"]:
            continue

        start_position = answers["answer_start"][0]
        text_length = len(answers["text"][0])
        # Target = the context span at the annotated offset, with the length
        # of the annotated answer text.
        target_text = context[start_position : start_position + text_length]

        articles.append({"input": context + sep + question, "answer": target_text})
    return articles
# Flatten each corpus into {"input", "answer"} records. The same helper serves
# every language. Calls happen in the order English, Polish, Slovak.
data_english = prepare_data_english(dataset_english)
data_polish = prepare_data_english(dataset_polish)
data_slovak = prepare_data_english(dataset_slovak)

# Pool all records into one list: Slovak first, then English, then Polish.
train_data = [*data_slovak, *data_english, *data_polish]
print("Training Samples : ", len(train_data))

# Tabular view: one row per example, columns "input" and "answer".
data = pd.DataFrame(train_data)
class QA_Dataset(Dataset):
    """Map-style dataset that tokenizes ``input`` / ``answer`` pairs on access.

    Args:
        tokenizer: Tokenizer callable that returns ``input_ids`` and
            ``attention_mask`` lists for a string.
        dataframe: DataFrame with ``input`` and ``answer`` columns.
        q_len: Length the encoder input is padded/truncated to.
        t_len: Length the target answer is padded/truncated to.

    Rows are looked up with ``Series[idx]``. For an integer index this is a
    lookup by index *label*, so samplers should yield labels of
    ``dataframe.index``.
    """

    def __init__(self, tokenizer, dataframe, q_len, t_len):
        self.tokenizer = tokenizer
        self.q_len = q_len
        self.t_len = t_len
        self.data = dataframe
        self.input = self.data['input']
        self.answer = self.data['answer']

    def __len__(self):
        # Number of rows in the underlying DataFrame.
        return len(self.data)

    def __getitem__(self, idx):
        """Return the tokenized tensors for row ``idx``.

        Keys: ``input_ids`` and ``attention_mask`` (length ``q_len``), plus
        ``labels`` and ``decoder_attention_mask`` (length ``t_len``). Padded
        label positions are set to -100.
        """
        source_text = self.input[idx]
        target_text = self.answer[idx]

        encoded_source = self.tokenizer(
            source_text,
            max_length=self.q_len,
            padding="max_length",
            truncation=True,
            add_special_tokens=True,
        )
        encoded_target = self.tokenizer(
            target_text,
            max_length=self.t_len,
            padding="max_length",
            truncation=True,
            add_special_tokens=True,
        )

        labels = torch.tensor(encoded_target["input_ids"], dtype=torch.long)
        decoder_attention_mask = torch.tensor(encoded_target["attention_mask"], dtype=torch.long)
        # -100 marks positions to leave out of the seq2seq loss. Selecting them
        # via the target's attention mask hits exactly the padding, without
        # assuming a particular pad token id.
        labels[decoder_attention_mask == 0] = -100

        return {
            "input_ids": torch.tensor(encoded_source["input_ids"], dtype=torch.long),
            "attention_mask": torch.tensor(encoded_source["attention_mask"], dtype=torch.long),
            "labels": labels,
            "decoder_attention_mask": decoder_attention_mask,
        }
# 80/20 split of the pooled DataFrame. train_test_split keeps the original index
# labels, so each split can be addressed inside the single shared dataset below.
train_data, val_data = train_test_split(data, test_size=0.2, random_state=42)

# SubsetRandomSampler yields the given index *labels* in random order.
# RandomSampler(train_data.index) would yield positions 0..len-1 instead. Both
# loaders would then read the head of `data`: validation would overlap training
# and the tail rows would never be trained on.
train_sampler = SubsetRandomSampler(train_data.index.tolist())
val_sampler = SubsetRandomSampler(val_data.index.tolist())

# One dataset over all rows; the samplers above choose which labels each loader sees.
qa_dataset = QA_Dataset(TOKENIZER, data, Q_LEN, T_LEN)

train_loader = DataLoader(qa_dataset, batch_size=BATCH_SIZE, sampler=train_sampler)
val_loader = DataLoader(qa_dataset, batch_size=BATCH_SIZE, sampler=val_sampler)

print("Loaders working fine")
### TRAINING (46MINS ACCORDING THE V1_DATA)

# TODO
# Make a great epochs number
# Evaluate results and find out how to calculate a real rouge metric

EPOCHS = 2  # number of full passes over the training split


def _forward(batch):
    """Move one collated batch to DEVICE and run MODEL on it.

    Returns the model output. Its ``loss`` is computed against ``labels``;
    under the Hugging Face convention, label positions equal to -100 are
    excluded from that loss.
    """
    return MODEL(
        input_ids=batch["input_ids"].to(DEVICE),
        attention_mask=batch["attention_mask"].to(DEVICE),
        labels=batch["labels"].to(DEVICE),
        decoder_attention_mask=batch["decoder_attention_mask"].to(DEVICE),
    )


for epoch in range(EPOCHS):
    # Reset the accumulators every epoch so the printed numbers are per-epoch
    # means, not running means over all epochs so far.
    train_loss = 0
    val_loss = 0
    train_batch_count = 0
    val_batch_count = 0

    MODEL.train()
    for batch in tqdm(train_loader, desc="Training batches"):
        outputs = _forward(batch)
        OPTIMIZER.zero_grad()
        outputs.loss.backward()
        OPTIMIZER.step()
        train_loss += outputs.loss.item()
        train_batch_count += 1

    # Evaluation: forward passes only. There is no backward()/step() here,
    # because updating the weights on the validation split would make its loss
    # meaningless. no_grad() also skips building the autograd graph.
    MODEL.eval()
    with torch.no_grad():
        for batch in tqdm(val_loader, desc="Validation batches"):
            outputs = _forward(batch)
            val_loss += outputs.loss.item()
            val_batch_count += 1

    # max(..., 1) avoids ZeroDivisionError if a loader yields no batches.
    print(f"{epoch+1}/{EPOCHS} -> Train loss: {train_loss / max(train_batch_count, 1)}\tValidation loss: {val_loss / max(val_batch_count, 1)}")
print("Training done succesfully")

## SAVE FINE_TUNED MODEL
# Persist the trained weights and the tokenizer (which carries the added
# '<sep>' token) into separate directories, model first, then tokenizer.
for artifact, out_dir in (
    (MODEL, "qa_model_umT5_small_3LANG"),
    (TOKENIZER, 'qa_tokenizer_umT5_small_3LANG'),
):
    artifact.save_pretrained(out_dir)