153 lines
5.9 KiB
Python
153 lines
5.9 KiB
Python
import os
|
||
import torch
|
||
from transformers import (
|
||
AutoTokenizer,
|
||
AutoModelForSeq2SeqLM,
|
||
Trainer,
|
||
TrainingArguments,
|
||
)
|
||
|
||
# ----------------------------------------------------------
|
||
# 1. Načítanie modelu a tokenizátora zo súborového systému
|
||
# ----------------------------------------------------------
|
||
|
||
# Zistenie aktuálnej cesty k projektu
|
||
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||
|
||
# Načítanie slovenského modelu T5 a tokenizátora
|
||
tokenizer = AutoTokenizer.from_pretrained("TUKE-KEMT/slovak-t5-base", legacy=False)
|
||
tokenizer.pad_token = tokenizer.eos_token # Nastavenie koncového tokenu ako výplňového
|
||
model = AutoModelForSeq2SeqLM.from_pretrained("TUKE-KEMT/slovak-t5-base")
|
||
|
||
# Výber výpočtového zariadenia (v tomto prípade CPU)
|
||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||
model.to(device)
|
||
|
||
# ----------------------------------------------------------
|
||
# 2. Načítanie vstupných dát (zdravé a cieľové texty)
|
||
# ----------------------------------------------------------
|
||
|
||
# Cesty k priečinkom s textovými súbormi
|
||
healthy_folder = r"project\text\zdr_cleaned\4"
|
||
dementia_folder = r"project\text\chory_cleaned\4"
|
||
|
||
|
||
# Pomocná funkcia na načítanie všetkých textov z daného priečinka
|
||
def load_texts_from_folder(folder_path):
|
||
texts = []
|
||
for filename in os.listdir(folder_path):
|
||
if filename.endswith(".txt"):
|
||
with open(
|
||
os.path.join(folder_path, filename), "r", encoding="utf-8"
|
||
) as file:
|
||
texts.append(file.read().strip())
|
||
return texts
|
||
|
||
|
||
# Načítanie vstupných textov
|
||
healthy_texts = load_texts_from_folder(healthy_folder)
|
||
dementia_texts = load_texts_from_folder(dementia_folder)
|
||
|
||
# Kontrola: počet textov v oboch priečinkoch musí byť rovnaký
|
||
assert len(healthy_texts) == len(dementia_texts), "Počet textov musí byť rovnaký."
|
||
|
||
# Vytvorenie zoznamu trénovacích dvojíc (vstup → cieľ)
|
||
train_data = [{"input": h, "output": d} for h, d in zip(healthy_texts, dementia_texts)]
|
||
|
||
# ----------------------------------------------------------
|
||
# 3. Tokenizácia vstupných dát (na formát požadovaný modelom)
|
||
# ----------------------------------------------------------
|
||
|
||
|
||
def tokenize_function(example):
|
||
inputs = tokenizer(
|
||
example["input"], padding="max_length", truncation=True, max_length=128
|
||
)
|
||
outputs = tokenizer(
|
||
example["output"], padding="max_length", truncation=True, max_length=128
|
||
)
|
||
return {
|
||
"input_ids": inputs["input_ids"],
|
||
"attention_mask": inputs["attention_mask"],
|
||
"labels": outputs["input_ids"],
|
||
}
|
||
|
||
|
||
# Vykonanie tokenizácie pre všetky trénovacie dvojice
|
||
tokenized_train_data = {key: [] for key in ["input_ids", "attention_mask", "labels"]}
|
||
for example in train_data:
|
||
tokenized = tokenize_function(example)
|
||
for key in tokenized_train_data:
|
||
tokenized_train_data[key].append(tokenized[key])
|
||
|
||
# ----------------------------------------------------------
|
||
# 4. Vytvorenie datasetu pre PyTorch
|
||
# ----------------------------------------------------------
|
||
|
||
|
||
class CustomDataset(torch.utils.data.Dataset):
|
||
def __init__(self, data):
|
||
self.data = data
|
||
|
||
def __getitem__(self, idx):
|
||
return {
|
||
"input_ids": torch.tensor(self.data["input_ids"][idx], dtype=torch.long),
|
||
"labels": torch.tensor(self.data["labels"][idx], dtype=torch.long),
|
||
"attention_mask": torch.tensor(
|
||
self.data["attention_mask"][idx], dtype=torch.long
|
||
),
|
||
}
|
||
|
||
def __len__(self):
|
||
return len(self.data["input_ids"])
|
||
|
||
|
||
# Inicializácia trénovacieho datasetu
|
||
train_dataset = CustomDataset(tokenized_train_data)
|
||
|
||
# ----------------------------------------------------------
|
||
# 5. Nastavenie cieľovej cesty pre uloženie finálneho modelu
|
||
# ----------------------------------------------------------
|
||
|
||
model_save_path = os.path.join(script_dir, "trained_slovak_t5_cpu")
|
||
|
||
# ----------------------------------------------------------
|
||
# 6. Definícia parametrov trénovania modelu
|
||
# ----------------------------------------------------------
|
||
|
||
training_args = TrainingArguments(
|
||
save_strategy="no", # Automatické ukladanie checkpointov je vypnuté (model sa uloží len na konci ručne)
|
||
logging_steps=50, # Frekvencia logovania – každých 50 krokov sa vypíšu metriky a stav
|
||
per_device_train_batch_size=2, # Veľkosť dávky pre jeden krok tréningu (na jedno zariadenie)
|
||
gradient_accumulation_steps=4, # Počet krokov, počas ktorých sa akumulujú gradienty pred aktualizáciou váh
|
||
# Efektívna veľkosť dávky = 2 × 4 = 8 príkladov
|
||
num_train_epochs=15, # Počet úplných prechodov cez celý trénovací dataset (15 epôch)
|
||
learning_rate=3e-5, # Rýchlosť učenia – ovplyvňuje mieru zmien váh počas učenia
|
||
weight_decay=0.01, # Regularizácia váh – zabraňuje pretrénovaniu tým, že penalizuje veľké váhy
|
||
optim="adamw_torch", # Optimalizátor použitý počas trénovania (AdamW implementovaný v PyTorch)
|
||
dataloader_num_workers=0, # Počet subprocessov použitých pri načítavaní dát (0 = hlavný proces)
|
||
logging_first_step=True, # Zabezpečí logovanie už po prvom kroku trénovania
|
||
disable_tqdm=False, # Zapína vizuálny progres bar (tqdm) počas trénovania v konzole
|
||
)
|
||
|
||
|
||
# ----------------------------------------------------------
|
||
# 7. Inicializácia a spustenie trénovania modelu
|
||
# ----------------------------------------------------------
|
||
|
||
trainer = Trainer(
|
||
model=model,
|
||
args=training_args,
|
||
train_dataset=train_dataset,
|
||
)
|
||
|
||
# Spustenie procesu trénovania
|
||
trainer.train()
|
||
|
||
# ----------------------------------------------------------
|
||
# 8. Uloženie finálneho modelu na disk
|
||
# ----------------------------------------------------------
|
||
|
||
model.save_pretrained(model_save_path)
|
||
tokenizer.save_pretrained(model_save_path)
|