BakalarskaPraca/T5_slovak_4_3_1/T5_4.py

153 lines
5.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import torch
from transformers import (
AutoTokenizer,
AutoModelForSeq2SeqLM,
Trainer,
TrainingArguments,
)
# ----------------------------------------------------------
# 1. Načítanie modelu a tokenizátora zo súborového systému
# ----------------------------------------------------------
# Zistenie aktuálnej cesty k projektu
script_dir = os.path.dirname(os.path.abspath(__file__))
# Načítanie slovenského modelu T5 a tokenizátora
tokenizer = AutoTokenizer.from_pretrained("TUKE-KEMT/slovak-t5-base", legacy=False)
tokenizer.pad_token = tokenizer.eos_token # Nastavenie koncového tokenu ako výplňového
model = AutoModelForSeq2SeqLM.from_pretrained("TUKE-KEMT/slovak-t5-base")
# Výber výpočtového zariadenia (v tomto prípade CPU)
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
# ----------------------------------------------------------
# 2. Načítanie vstupných dát (zdravé a cieľové texty)
# ----------------------------------------------------------
# Cesty k priečinkom s textovými súbormi
healthy_folder = r"project\text\zdr_cleaned\4"
dementia_folder = r"project\text\chory_cleaned\4"
# Pomocná funkcia na načítanie všetkých textov z daného priečinka
def load_texts_from_folder(folder_path):
texts = []
for filename in os.listdir(folder_path):
if filename.endswith(".txt"):
with open(
os.path.join(folder_path, filename), "r", encoding="utf-8"
) as file:
texts.append(file.read().strip())
return texts
# Načítanie vstupných textov
healthy_texts = load_texts_from_folder(healthy_folder)
dementia_texts = load_texts_from_folder(dementia_folder)
# Kontrola: počet textov v oboch priečinkoch musí byť rovnaký
assert len(healthy_texts) == len(dementia_texts), "Počet textov musí byť rovnaký."
# Vytvorenie zoznamu trénovacích dvojíc (vstup → cieľ)
train_data = [{"input": h, "output": d} for h, d in zip(healthy_texts, dementia_texts)]
# ----------------------------------------------------------
# 3. Tokenizácia vstupných dát (na formát požadovaný modelom)
# ----------------------------------------------------------
def tokenize_function(example):
inputs = tokenizer(
example["input"], padding="max_length", truncation=True, max_length=128
)
outputs = tokenizer(
example["output"], padding="max_length", truncation=True, max_length=128
)
return {
"input_ids": inputs["input_ids"],
"attention_mask": inputs["attention_mask"],
"labels": outputs["input_ids"],
}
# Vykonanie tokenizácie pre všetky trénovacie dvojice
tokenized_train_data = {key: [] for key in ["input_ids", "attention_mask", "labels"]}
for example in train_data:
tokenized = tokenize_function(example)
for key in tokenized_train_data:
tokenized_train_data[key].append(tokenized[key])
# ----------------------------------------------------------
# 4. Vytvorenie datasetu pre PyTorch
# ----------------------------------------------------------
class CustomDataset(torch.utils.data.Dataset):
def __init__(self, data):
self.data = data
def __getitem__(self, idx):
return {
"input_ids": torch.tensor(self.data["input_ids"][idx], dtype=torch.long),
"labels": torch.tensor(self.data["labels"][idx], dtype=torch.long),
"attention_mask": torch.tensor(
self.data["attention_mask"][idx], dtype=torch.long
),
}
def __len__(self):
return len(self.data["input_ids"])
# Inicializácia trénovacieho datasetu
train_dataset = CustomDataset(tokenized_train_data)
# ----------------------------------------------------------
# 5. Nastavenie cieľovej cesty pre uloženie finálneho modelu
# ----------------------------------------------------------
model_save_path = os.path.join(script_dir, "trained_slovak_t5_cpu")
# ----------------------------------------------------------
# 6. Definícia parametrov trénovania modelu
# ----------------------------------------------------------
training_args = TrainingArguments(
save_strategy="no", # Automatické ukladanie checkpointov je vypnuté (model sa uloží len na konci ručne)
logging_steps=50, # Frekvencia logovania každých 50 krokov sa vypíšu metriky a stav
per_device_train_batch_size=2, # Veľkosť dávky pre jeden krok tréningu (na jedno zariadenie)
gradient_accumulation_steps=4, # Počet krokov, počas ktorých sa akumulujú gradienty pred aktualizáciou váh
# Efektívna veľkosť dávky = 2 × 4 = 8 príkladov
num_train_epochs=15, # Počet úplných prechodov cez celý trénovací dataset (15 epôch)
learning_rate=3e-5, # Rýchlosť učenia ovplyvňuje mieru zmien váh počas učenia
weight_decay=0.01, # Regularizácia váh zabraňuje pretrénovaniu tým, že penalizuje veľké váhy
optim="adamw_torch", # Optimalizátor použitý počas trénovania (AdamW implementovaný v PyTorch)
dataloader_num_workers=0, # Počet subprocessov použitých pri načítavaní dát (0 = hlavný proces)
logging_first_step=True, # Zabezpečí logovanie už po prvom kroku trénovania
disable_tqdm=False, # Zapína vizuálny progres bar (tqdm) počas trénovania v konzole
)
# ----------------------------------------------------------
# 7. Inicializácia a spustenie trénovania modelu
# ----------------------------------------------------------
trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_dataset,
)
# Spustenie procesu trénovania
trainer.train()
# ----------------------------------------------------------
# 8. Uloženie finálneho modelu na disk
# ----------------------------------------------------------
model.save_pretrained(model_save_path)
tokenizer.save_pretrained(model_save_path)