# Fine-tune a local BART model to rewrite healthy-speech transcripts into
# dementia-style text (paired healthy/dementia .txt corpora, Trainer API).
# Standard library.
import os

# Third-party: PyTorch and Hugging Face Transformers.
import torch
from transformers import (
    BartTokenizer,
    BartForConditionalGeneration,
    Trainer,
    TrainingArguments,
)
# ----------------------------------------------------------
# 1. Determine the directory of this script
# ----------------------------------------------------------
# Lets the final model be saved into the same folder as this script
# (see the save step at the bottom of the file).
script_dir = os.path.dirname(os.path.abspath(__file__))
# ----------------------------------------------------------
# 2. Helper to load all .txt files from a folder
# ----------------------------------------------------------
def load_texts_from_folder(folder_path):
    """Read every ``.txt`` file in *folder_path* and return stripped contents.

    Files are read in sorted filename order so the result is deterministic.
    (``os.listdir`` returns entries in arbitrary order, which would silently
    scramble the positional healthy/dementia pairing done with ``zip()``
    later in this script.)

    Args:
        folder_path: Directory to scan; non-``.txt`` entries are ignored.

    Returns:
        list[str]: File contents with surrounding whitespace stripped,
        one entry per ``.txt`` file, in sorted filename order.
    """
    texts = []
    for filename in sorted(os.listdir(folder_path)):
        if filename.endswith(".txt"):
            with open(
                os.path.join(folder_path, filename), "r", encoding="utf-8"
            ) as file:
                texts.append(file.read().strip())
    return texts
# ----------------------------------------------------------
# 3. Data paths - healthy and dementia texts
# ----------------------------------------------------------
# NOTE(review): Windows-style relative paths; assumes the script runs from
# the directory that contains "project" -- confirm the working directory.
healthy_folder = r"project\text\zdr_eng_cleaned"
dementia_folder = r"project\text\chory_eng_cleaned"

# Load the texts from both folders.
healthy_texts = load_texts_from_folder(healthy_folder)
dementia_texts = load_texts_from_folder(dementia_folder)

# Sanity check: both folders must contain the same number of texts, because
# pairs are built positionally with zip() below. Explicit raise instead of
# `assert`, which is silently stripped when Python runs with -O.
if len(healthy_texts) != len(dementia_texts):
    raise ValueError("Počet textov v oboch priečinkoch sa musí zhodovať!")

# Build training pairs (input: healthy text, output: dementia-style text).
train_data = [{"input": h, "output": d} for h, d in zip(healthy_texts, dementia_texts)]
# ----------------------------------------------------------
# 4. Load the pretrained BART model and tokenizer
# ----------------------------------------------------------
# NOTE(review): expects a local copy of facebook/bart-large at this
# Windows-style relative path -- confirm against the working directory.
local_model_path = r"project\BART_eng\facebook_bart_large"
tokenizer = BartTokenizer.from_pretrained(local_model_path)
model = BartForConditionalGeneration.from_pretrained(local_model_path)
# ----------------------------------------------------------
# 5. Tokenize the data for the model
# ----------------------------------------------------------
# Tokenization with truncation to at most 256 tokens.
def tokenize_function(examples):
    """Tokenize one ``{"input": ..., "output": ...}`` pair for seq2seq training.

    Both sides are padded/truncated to 256 tokens. Padding positions in the
    labels are replaced with -100 so the cross-entropy loss ignores them;
    otherwise the model would also be trained to predict <pad> tokens
    (per Hugging Face's translation/summarization fine-tuning guides).

    Args:
        examples: Dict with "input" (healthy text) and "output" (target text).

    Returns:
        The tokenizer encoding of the input, with a "labels" key added.
    """
    inputs = tokenizer(
        examples["input"], padding="max_length", truncation=True, max_length=256
    )
    outputs = tokenizer(
        examples["output"], padding="max_length", truncation=True, max_length=256
    )
    pad_id = tokenizer.pad_token_id
    inputs["labels"] = [
        -100 if token_id == pad_id else token_id for token_id in outputs["input_ids"]
    ]
    return inputs
||
# Apply tokenization to every training pair (eagerly, fully in memory).
tokenized_data = list(map(tokenize_function, train_data))
# ----------------------------------------------------------
# 6. Custom dataset compatible with PyTorch
# ----------------------------------------------------------
class CustomDataset(torch.utils.data.Dataset):
    """Thin map-style dataset over a dict of parallel per-example lists."""

    def __init__(self, encodings):
        # encodings: mapping of field name -> list of per-example values,
        # all lists of equal length.
        self.encodings = encodings

    def __getitem__(self, idx):
        # Convert each field's idx-th entry to a tensor on demand.
        example = {}
        for field, values in self.encodings.items():
            example[field] = torch.tensor(values[idx])
        return example

    def __len__(self):
        # All fields are parallel; "input_ids" defines the dataset length.
        return len(self.encodings["input_ids"])
# Training dataset: regroup the per-example encodings into one list per
# field and wrap the result in CustomDataset.
train_dataset = CustomDataset(
    {
        field: [example[field] for example in tokenized_data]
        for field in ("input_ids", "labels", "attention_mask")
    }
)
# ----------------------------------------------------------
# 7. Training configuration
# ----------------------------------------------------------
# No checkpoints, no logs, no validation -- just plain training.
# NOTE(review): newer transformers versions renamed `evaluation_strategy`
# to `eval_strategy` -- confirm against the installed version.
training_args = TrainingArguments(
    output_dir=".",  # Required by the API, but effectively unused here
    evaluation_strategy="no",  # No evaluation set
    save_strategy="no",  # No intermediate checkpoints
    per_device_train_batch_size=4,  # Batch size per device (CPU/GPU)
    num_train_epochs=15,  # Number of epochs
    learning_rate=3e-5,  # Learning rate
    weight_decay=0.01,  # Penalize large weights (regularization)
)
# ----------------------------------------------------------
# 8. Set up training via the Trainer API
# ----------------------------------------------------------
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
)
# ----------------------------------------------------------
# 9. Run model training
# ----------------------------------------------------------
print("Trénovanie modelu prebieha...")
trainer.train()
print("Trénovanie dokončené.")
# ----------------------------------------------------------
# 10. Save the final model and tokenizer to disk
# ----------------------------------------------------------
# Saved next to this script, using script_dir computed at the top.
model_save_path = os.path.join(script_dir, "trained_bart_model_CLEAN_15")
model.save_pretrained(model_save_path)
tokenizer.save_pretrained(model_save_path)
print(f"Model a tokenizátor boli uložené do: {model_save_path}")