import os import torch from transformers import ( BartTokenizer, BartForConditionalGeneration, Trainer, TrainingArguments, ) # ---------------------------------------------------------- # 1. Zistenie aktuálneho adresára skriptu # ---------------------------------------------------------- # Umožňuje uložiť výsledný model do rovnakého priečinka ako tento skript script_dir = os.path.dirname(os.path.abspath(__file__)) # ---------------------------------------------------------- # 2. Pomocná funkcia na načítanie všetkých .txt súborov z priečinka # ---------------------------------------------------------- def load_texts_from_folder(folder_path): texts = [] for filename in os.listdir(folder_path): if filename.endswith(".txt"): with open( os.path.join(folder_path, filename), "r", encoding="utf-8" ) as file: texts.append(file.read().strip()) return texts # ---------------------------------------------------------- # 3. Cesty k dátam – zdravé a choré texty # ---------------------------------------------------------- healthy_folder = r"project\text\zdr_eng_cleaned" dementia_folder = r"project\text\chory_eng_cleaned" # Načítanie textov z oboch priečinkov healthy_texts = load_texts_from_folder(healthy_folder) dementia_texts = load_texts_from_folder(dementia_folder) # Kontrola: počet textov musí byť rovnaký assert len(healthy_texts) == len( dementia_texts ), "Počet textov v oboch priečinkoch sa musí zhodovať!" # Vytvorenie trénovacích dvojíc (vstup: zdravý text, výstup: text s príznakmi demencie) train_data = [{"input": h, "output": d} for h, d in zip(healthy_texts, dementia_texts)] # ---------------------------------------------------------- # 4. Načítanie predtrénovaného BART modelu a tokenizátora # ---------------------------------------------------------- local_model_path = r"project\BART_eng\facebook_bart_large" tokenizer = BartTokenizer.from_pretrained(local_model_path) model = BartForConditionalGeneration.from_pretrained(local_model_path) # ---------------------------------------------------------- # 5. Tokenizácia dát pre vstup do modelu # ---------------------------------------------------------- # Tokenizácia so skrátením na max. 256 tokenov def tokenize_function(examples): inputs = tokenizer( examples["input"], padding="max_length", truncation=True, max_length=256 ) outputs = tokenizer( examples["output"], padding="max_length", truncation=True, max_length=256 ) inputs["labels"] = outputs["input_ids"] return inputs # Aplikácia tokenizácie na všetky vstupy tokenized_data = list(map(tokenize_function, train_data)) # ---------------------------------------------------------- # 6. Vytvorenie vlastného datasetu kompatibilného s PyTorch # ---------------------------------------------------------- class CustomDataset(torch.utils.data.Dataset): def __init__(self, encodings): self.encodings = encodings def __getitem__(self, idx): return {key: torch.tensor(val[idx]) for key, val in self.encodings.items()} def __len__(self): return len(self.encodings["input_ids"]) # Dataset na trénovanie train_dataset = CustomDataset( { "input_ids": [item["input_ids"] for item in tokenized_data], "labels": [item["labels"] for item in tokenized_data], "attention_mask": [item["attention_mask"] for item in tokenized_data], } ) # ---------------------------------------------------------- # 7. Konfigurácia trénovania modelu # ---------------------------------------------------------- # Žiadne checkpointy, logy, ani validácia – len čisté učenie training_args = TrainingArguments( output_dir=".", # Povinné, ale nepoužíva sa evaluation_strategy="no", # Bez evaluačnej množiny save_strategy="no", # Bez ukladania medzivýsledkov (checkpointov) per_device_train_batch_size=4, # Veľkosť dávky na zariadenie (CPU/GPU) num_train_epochs=15, # Počet epôch learning_rate=3e-5, # Rýchlosť učenia weight_decay=0.01, # Penalizácia veľkých váh (regularizácia) ) # ---------------------------------------------------------- # 8. Inicializácia trénovania pomocou Trainer API # ---------------------------------------------------------- trainer = Trainer( model=model, args=training_args, train_dataset=train_dataset, ) # ---------------------------------------------------------- # 9. Spustenie trénovania modelu # ---------------------------------------------------------- print("Trénovanie modelu prebieha...") trainer.train() print("Trénovanie dokončené.") # ---------------------------------------------------------- # 10. Uloženie finálneho modelu a tokenizátora na disk # ---------------------------------------------------------- model_save_path = os.path.join(script_dir, "trained_bart_model_CLEAN_15") model.save_pretrained(model_save_path) tokenizer.save_pretrained(model_save_path) print(f"Model a tokenizátor boli uložené do: {model_save_path}")