BakalarskaPraca/BART_eng/BART_CLEAN_15.py

136 lines
5.2 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import torch
from transformers import (
BartTokenizer,
BartForConditionalGeneration,
Trainer,
TrainingArguments,
)
# ----------------------------------------------------------
# 1. Zistenie aktuálneho adresára skriptu
# ----------------------------------------------------------
# Umožňuje uložiť výsledný model do rovnakého priečinka ako tento skript
script_dir = os.path.dirname(os.path.abspath(__file__))
# ----------------------------------------------------------
# 2. Pomocná funkcia na načítanie všetkých .txt súborov z priečinka
# ----------------------------------------------------------
def load_texts_from_folder(folder_path):
texts = []
for filename in os.listdir(folder_path):
if filename.endswith(".txt"):
with open(
os.path.join(folder_path, filename), "r", encoding="utf-8"
) as file:
texts.append(file.read().strip())
return texts
# ----------------------------------------------------------
# 3. Cesty k dátam zdravé a choré texty
# ----------------------------------------------------------
healthy_folder = r"project\text\zdr_eng_cleaned"
dementia_folder = r"project\text\chory_eng_cleaned"
# Načítanie textov z oboch priečinkov
healthy_texts = load_texts_from_folder(healthy_folder)
dementia_texts = load_texts_from_folder(dementia_folder)
# Kontrola: počet textov musí byť rovnaký
assert len(healthy_texts) == len(
dementia_texts
), "Počet textov v oboch priečinkoch sa musí zhodovať!"
# Vytvorenie trénovacích dvojíc (vstup: zdravý text, výstup: text s príznakmi demencie)
train_data = [{"input": h, "output": d} for h, d in zip(healthy_texts, dementia_texts)]
# ----------------------------------------------------------
# 4. Načítanie predtrénovaného BART modelu a tokenizátora
# ----------------------------------------------------------
local_model_path = r"project\BART_eng\facebook_bart_large"
tokenizer = BartTokenizer.from_pretrained(local_model_path)
model = BartForConditionalGeneration.from_pretrained(local_model_path)
# ----------------------------------------------------------
# 5. Tokenizácia dát pre vstup do modelu
# ----------------------------------------------------------
# Tokenizácia so skrátením na max. 256 tokenov
def tokenize_function(examples):
inputs = tokenizer(
examples["input"], padding="max_length", truncation=True, max_length=256
)
outputs = tokenizer(
examples["output"], padding="max_length", truncation=True, max_length=256
)
inputs["labels"] = outputs["input_ids"]
return inputs
# Aplikácia tokenizácie na všetky vstupy
tokenized_data = list(map(tokenize_function, train_data))
# ----------------------------------------------------------
# 6. Vytvorenie vlastného datasetu kompatibilného s PyTorch
# ----------------------------------------------------------
class CustomDataset(torch.utils.data.Dataset):
def __init__(self, encodings):
self.encodings = encodings
def __getitem__(self, idx):
return {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
def __len__(self):
return len(self.encodings["input_ids"])
# Dataset na trénovanie
train_dataset = CustomDataset(
{
"input_ids": [item["input_ids"] for item in tokenized_data],
"labels": [item["labels"] for item in tokenized_data],
"attention_mask": [item["attention_mask"] for item in tokenized_data],
}
)
# ----------------------------------------------------------
# 7. Konfigurácia trénovania modelu
# ----------------------------------------------------------
# Žiadne checkpointy, logy, ani validácia len čisté učenie
training_args = TrainingArguments(
output_dir=".", # Povinné, ale nepoužíva sa
evaluation_strategy="no", # Bez evaluačnej množiny
save_strategy="no", # Bez ukladania medzivýsledkov (checkpointov)
per_device_train_batch_size=4, # Veľkosť dávky na zariadenie (CPU/GPU)
num_train_epochs=15, # Počet epôch
learning_rate=3e-5, # Rýchlosť učenia
weight_decay=0.01, # Penalizácia veľkých váh (regularizácia)
)
# ----------------------------------------------------------
# 8. Inicializácia trénovania pomocou Trainer API
# ----------------------------------------------------------
trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_dataset,
)
# ----------------------------------------------------------
# 9. Spustenie trénovania modelu
# ----------------------------------------------------------
print("Trénovanie modelu prebieha...")
trainer.train()
print("Trénovanie dokončené.")
# ----------------------------------------------------------
# 10. Uloženie finálneho modelu a tokenizátora na disk
# ----------------------------------------------------------
model_save_path = os.path.join(script_dir, "trained_bart_model_CLEAN_15")
model.save_pretrained(model_save_path)
tokenizer.save_pretrained(model_save_path)
print(f"Model a tokenizátor boli uložené do: {model_save_path}")