commit 7ab100541a66fd7d8b0c6f4a540e2ed7c1e177b2
Author: Vladyslav Korzun
Date:   Thu May 22 13:10:08 2025 +0000

    Upload files to «BART_eng»

diff --git a/BART_eng/BART_CLEAN_15.py b/BART_eng/BART_CLEAN_15.py
new file mode 100644
index 0000000..d698009
--- /dev/null
+++ b/BART_eng/BART_CLEAN_15.py
@@ -0,0 +1,135 @@
+import os
+import torch
+from transformers import (
+    BartTokenizer,
+    BartForConditionalGeneration,
+    Trainer,
+    TrainingArguments,
+)
+
+# ----------------------------------------------------------
+# 1. Determine the script's current directory
+# ----------------------------------------------------------
+# Allows the resulting model to be saved in the same folder as this script
+script_dir = os.path.dirname(os.path.abspath(__file__))
+
+
+# ----------------------------------------------------------
+# 2. Helper function to load all .txt files from a folder
+# ----------------------------------------------------------
+def load_texts_from_folder(folder_path):
+    texts = []
+    for filename in os.listdir(folder_path):
+        if filename.endswith(".txt"):
+            with open(
+                os.path.join(folder_path, filename), "r", encoding="utf-8"
+            ) as file:
+                texts.append(file.read().strip())
+    return texts
+
+
+# ----------------------------------------------------------
+# 3. Data paths: healthy and dementia texts
+# ----------------------------------------------------------
+healthy_folder = r"project\text\zdr_eng_cleaned"
+dementia_folder = r"project\text\chory_eng_cleaned"
+
+# Load the texts from both folders
+healthy_texts = load_texts_from_folder(healthy_folder)
+dementia_texts = load_texts_from_folder(dementia_folder)
+
+# Check: both folders must contain the same number of texts
+assert len(healthy_texts) == len(
+    dementia_texts
+), "The number of texts in both folders must match!"
+
+# Build training pairs (input: healthy text, output: text with dementia symptoms)
+train_data = [{"input": h, "output": d} for h, d in zip(healthy_texts, dementia_texts)]
+
+# ----------------------------------------------------------
+# 4. Load the pre-trained BART model and tokenizer
+# ----------------------------------------------------------
+local_model_path = r"project\BART_eng\facebook_bart_large"
+tokenizer = BartTokenizer.from_pretrained(local_model_path)
+model = BartForConditionalGeneration.from_pretrained(local_model_path)
+
+
+# ----------------------------------------------------------
+# 5. Tokenize the data for model input
+# ----------------------------------------------------------
+# Tokenization with truncation to at most 256 tokens
+def tokenize_function(examples):
+    inputs = tokenizer(
+        examples["input"], padding="max_length", truncation=True, max_length=256
+    )
+    outputs = tokenizer(
+        examples["output"], padding="max_length", truncation=True, max_length=256
+    )
+    inputs["labels"] = outputs["input_ids"]
+    return inputs
+
+
+# Apply tokenization to every training pair
+tokenized_data = list(map(tokenize_function, train_data))
+
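+# Editor's sketch (optional refinement, not used in this run): the labels above
+# keep the tokenizer's pad token id, so padding positions also contribute to the
+# loss. Replacing pad ids with -100 makes the loss ignore them, since -100 is
+# the ignore index of the cross-entropy loss the model applies to labels.
+def mask_padding_in_labels(example, pad_id=tokenizer.pad_token_id):
+    example["labels"] = [tok if tok != pad_id else -100 for tok in example["labels"]]
+    return example
+
+# To enable the refinement, uncomment the following line:
+# tokenized_data = [mask_padding_in_labels(item) for item in tokenized_data]
+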
+
+# ----------------------------------------------------------
+# 6. Build a custom dataset compatible with PyTorch
+# ----------------------------------------------------------
+class CustomDataset(torch.utils.data.Dataset):
+    def __init__(self, encodings):
+        self.encodings = encodings
+
+    def __getitem__(self, idx):
+        return {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
+
+    def __len__(self):
+        return len(self.encodings["input_ids"])
+
+
+# Training dataset
+train_dataset = CustomDataset(
+    {
+        "input_ids": [item["input_ids"] for item in tokenized_data],
+        "labels": [item["labels"] for item in tokenized_data],
+        "attention_mask": [item["attention_mask"] for item in tokenized_data],
+    }
+)
+
+# ----------------------------------------------------------
+# 7. Model training configuration
+# ----------------------------------------------------------
+# No checkpoints, logs, or validation, just plain training
+training_args = TrainingArguments(
+    output_dir=".",  # Required, but not used
+    evaluation_strategy="no",  # No evaluation set
+    save_strategy="no",  # No intermediate checkpoints
+    per_device_train_batch_size=4,  # Batch size per device (CPU/GPU)
+    num_train_epochs=15,  # Number of epochs
+    learning_rate=3e-5,  # Learning rate
+    weight_decay=0.01,  # Penalty on large weights (regularization)
+)
+
+# ----------------------------------------------------------
+# 8. Initialize training via the Trainer API
+# ----------------------------------------------------------
+trainer = Trainer(
+    model=model,
+    args=training_args,
+    train_dataset=train_dataset,
+)
+
+# ----------------------------------------------------------
+# 9. Run model training
+# ----------------------------------------------------------
+print("Model training in progress...")
+trainer.train()
+print("Training finished.")
+
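+# Editor's sketch (illustrative, not part of the original pipeline): generate
+# one output from the first training input to eyeball the learned style shift
+# before the weights are written to disk.
+model.eval()
+sample_ids = tokenizer(
+    healthy_texts[0], return_tensors="pt", truncation=True, max_length=256
+).input_ids.to(model.device)
+with torch.no_grad():
+    preview = model.generate(sample_ids, max_length=256, num_beams=4)
+print("Preview:", tokenizer.decode(preview[0], skip_special_tokens=True))
+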
+# ----------------------------------------------------------
+# 10. Save the final model and tokenizer to disk
+# ----------------------------------------------------------
+model_save_path = os.path.join(script_dir, "trained_bart_model_CLEAN_15")
+model.save_pretrained(model_save_path)
+tokenizer.save_pretrained(model_save_path)
+print(f"The model and tokenizer were saved to: {model_save_path}")
diff --git a/BART_eng/CLEAN_15.py b/BART_eng/CLEAN_15.py
new file mode 100644
index 0000000..b980f53
--- /dev/null
+++ b/BART_eng/CLEAN_15.py
@@ -0,0 +1,112 @@
+import os
+import torch
+from transformers import BartTokenizer, BartForConditionalGeneration
+
+# ----------------------------------------------------------
+# Load the trained BART model and tokenizer
+# ----------------------------------------------------------
+
+# Path to the folder with the trained model
+script_dir = os.path.dirname(os.path.abspath(__file__))
+model_path = os.path.abspath(
+    os.path.join(script_dir, "..", "BART_eng", "trained_bart_model_CLEAN_15")
+)
+
+
+# Load the model and tokenizer from disk
+tokenizer = BartTokenizer.from_pretrained(model_path)
+model = BartForConditionalGeneration.from_pretrained(model_path)
+
+# Select the device (GPU if available, otherwise CPU)
+device = "cuda" if torch.cuda.is_available() else "cpu"
+model.to(device)
+
+print("The model was loaded successfully and is ready to use.")
+
+# ----------------------------------------------------------
+# Input text from a healthy speaker to be transformed
+# ----------------------------------------------------------
+
+healthy_text = """there's a boy reaching into the cookie jar while standing on a stool
+and he already took one cookie and is handing it to the girl next to him
+and girl seems to be telling him to be quiet so their mom doesn't notice
+and the sink is overflowing with water and nobody is paying attention
+meanwhile, their mother is turned away, busy with something else
+"""
+
+# Tokenize the input text
+inputs = tokenizer(healthy_text, return_tensors="pt").to(device)
+input_token_length = len(tokenizer.encode(healthy_text, add_special_tokens=False))
+
+# Compute dynamic length limits
+min_length_ratio = 0.6
+max_length_ratio = 1.1
+
+min_length = int(input_token_length * min_length_ratio)
+max_length = int(input_token_length * max_length_ratio)
+
+print(f"Number of tokens in the input: {input_token_length}")
+print(f"Minimum output length: {min_length}, maximum: {max_length}")
+
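+# Editor's sketch (defensive check, an assumption rather than part of the
+# original script): keep the dynamic limits within BART's positional-embedding
+# budget so generation never requests more positions than the model supports.
+model_limit = model.config.max_position_embeddings  # 1024 for facebook/bart-large
+max_length = min(max_length, model_limit)
+min_length = min(min_length, max_length)
+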
{i+1}:") + + # Generovanie bez výpočtu gradientov + with torch.no_grad(): + outputs = model.generate( + inputs["input_ids"], + max_length=config["max_length"], + min_length=config["min_length"], + num_beams=config["num_beams"], + no_repeat_ngram_size=config["no_repeat_ngram_size"], + repetition_penalty=config["repetition_penalty"], + length_penalty=config["length_penalty"], + early_stopping=config["early_stopping"], + ) + + # Dekódovanie výstupu na text + generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True) + generated_texts.append(generated_text) + + print("Vygenerovaný text:") + print(generated_text) + print("-" * 80) diff --git a/BART_eng/METRIKY.py b/BART_eng/METRIKY.py new file mode 100644 index 0000000..a28745e --- /dev/null +++ b/BART_eng/METRIKY.py @@ -0,0 +1,154 @@ +import os +import nltk +from rouge_score import rouge_scorer +from bert_score import score +import warnings +import logging +from transformers.utils import logging as hf_logging + +# ---------------------------------------------------------- +# 1. Nastavenie prostredia – potlačenie varovaní a logov +# ---------------------------------------------------------- + +# Automatické stiahnutie 'punkt' ak ešte nie je k dispozícii +try: + nltk.data.find("tokenizers/punkt") +except LookupError: + nltk.download("punkt") + +# Skrytie varovaní BERTScore pri načítavaní modelov +warnings.filterwarnings( + "ignore", message="Some weights of RobertaModel were not initialized" +) + +# Potlačenie logov z knižnice transformers +hf_logging.set_verbosity_error() +logging.getLogger("transformers.modeling_utils").setLevel(logging.ERROR) +logging.getLogger("transformers.configuration_utils").setLevel(logging.ERROR) +logging.getLogger("transformers.tokenization_utils_base").setLevel(logging.ERROR) + +# ---------------------------------------------------------- +# 2. Načítanie referenčných textov pacientov +# ---------------------------------------------------------- + + +# Pomocná funkcia na načítanie obsahu všetkých .txt súborov v priečinku +def load_texts_from_folder(folder_path): + texts = [] + for filename in os.listdir(folder_path): + if filename.endswith(".txt"): + with open( + os.path.join(folder_path, filename), "r", encoding="utf-8" + ) as file: + texts.append(file.read().strip()) + return texts + + +# Cesta k textom pacientov s Alzheimerovou chorobou +script_dir = os.path.dirname(os.path.abspath(__file__)) +project_root = os.path.abspath(os.path.join(script_dir, "..")) +dementia_folder = os.path.join(project_root, "text", "chory_eng_cleaned") +reference_texts_dementia = load_texts_from_folder(dementia_folder) + +# Overenie, že priečinok obsahuje súbory +if not reference_texts_dementia: + raise ValueError("V zložke nie sú žiadne .txt súbory! Skontroluj cestu.") + +print(f"Načítaných {len(reference_texts_dementia)} chorých textov na porovnanie.") + +# ---------------------------------------------------------- +# 3. 
diff --git a/BART_eng/METRIKY.py b/BART_eng/METRIKY.py
new file mode 100644
index 0000000..a28745e
--- /dev/null
+++ b/BART_eng/METRIKY.py
@@ -0,0 +1,154 @@
+import os
+import nltk
+from rouge_score import rouge_scorer
+from bert_score import score
+import warnings
+import logging
+from transformers.utils import logging as hf_logging
+
+# ----------------------------------------------------------
+# 1. Environment setup: suppress warnings and logs
+# ----------------------------------------------------------
+
+# Automatically download 'punkt' if it is not available yet
+try:
+    nltk.data.find("tokenizers/punkt")
+except LookupError:
+    nltk.download("punkt")
+
+# Hide BERTScore warnings emitted while loading models
+warnings.filterwarnings(
+    "ignore", message="Some weights of RobertaModel were not initialized"
+)
+
+# Suppress logs from the transformers library
+hf_logging.set_verbosity_error()
+logging.getLogger("transformers.modeling_utils").setLevel(logging.ERROR)
+logging.getLogger("transformers.configuration_utils").setLevel(logging.ERROR)
+logging.getLogger("transformers.tokenization_utils_base").setLevel(logging.ERROR)
+
+# ----------------------------------------------------------
+# 2. Load the patients' reference texts
+# ----------------------------------------------------------
+
+
+# Helper function to load the contents of all .txt files in a folder
+def load_texts_from_folder(folder_path):
+    texts = []
+    for filename in os.listdir(folder_path):
+        if filename.endswith(".txt"):
+            with open(
+                os.path.join(folder_path, filename), "r", encoding="utf-8"
+            ) as file:
+                texts.append(file.read().strip())
+    return texts
+
+
+# Path to the texts of patients with Alzheimer's disease
+script_dir = os.path.dirname(os.path.abspath(__file__))
+project_root = os.path.abspath(os.path.join(script_dir, ".."))
+dementia_folder = os.path.join(project_root, "text", "chory_eng_cleaned")
+reference_texts_dementia = load_texts_from_folder(dementia_folder)
+
+# Verify that the folder contains files
+if not reference_texts_dementia:
+    raise ValueError("There are no .txt files in the folder! Check the path.")
+
+print(f"Loaded {len(reference_texts_dementia)} dementia texts for comparison.")
+
+# ----------------------------------------------------------
+# 3. Samples to compare (original healthy text and generated text)
+# ----------------------------------------------------------
+
+original_text_healthy = """there's a boy reaching into the cookie jar while standing on a stool
+and he already took one cookie and is handing it to the girl next to him
+and girl seems to be telling him to be quiet so their mom doesn't notice
+and the sink is overflowing with water and nobody is paying attention
+meanwhile, their mother is turned away, busy with something else"""
+
+generated_text = """okay exc there's a boy taking cookies out of the cookie jar
+and he's standing on a stool
+and he's handing one to his sister
+and she's laughing at him
+and I don't know what else you want exc
+I don't know exc
+what else do you want exc
+uh... that's all I see exc"""
+
+# ----------------------------------------------------------
+# 4. Compute ROUGE metrics (comparison at the token/phrase level)
+# ----------------------------------------------------------
+
+
+def evaluate_text(reference_texts, generated_text):
+    scorer = rouge_scorer.RougeScorer(["rouge1", "rougeL"], use_stemmer=True)
+    rouge1_scores = []
+    rougeL_scores = []
+
+    for ref_text in reference_texts:
+        scores = scorer.score(ref_text, generated_text)
+        rouge1_scores.append(scores["rouge1"].fmeasure)
+        rougeL_scores.append(scores["rougeL"].fmeasure)
+
+    return {
+        "ROUGE-1-best": max(rouge1_scores),
+        "ROUGE-1-mean": sum(rouge1_scores) / len(rouge1_scores),
+        "ROUGE-L-best": max(rougeL_scores),
+        "ROUGE-L-mean": sum(rougeL_scores) / len(rougeL_scores),
+    }
+
+
+# ----------------------------------------------------------
+# 5. Compute the BERTScore metric (comparison at the level of meaning)
+# ----------------------------------------------------------
+
+
+def compute_bert_score(reference_texts, generated_text):
+    try:
+        P, R, F1 = score(
+            # The same candidate is compared against every reference
+            cands=[generated_text] * len(reference_texts),
+            # Each reference must be wrapped in its own list
+            refs=[[ref] for ref in reference_texts],
+            lang="en",
+            rescale_with_baseline=False,
+        )
+        return {"BERTScore-F1": F1.mean().item()}
+    except Exception as e:
+        print(f"Error while computing BERTScore: {e}")
+        return {"BERTScore-F1": 0}
+
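+# Editor's note (the flag itself is part of the bert_score API; recommending it
+# here is an assumption): raw BERTScore values for English tend to cluster in a
+# narrow high band. Passing rescale_with_baseline=True to score() rescales them
+# against a language baseline and spreads the values out, which can make the
+# two comparisons below easier to read.
+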
+
+# ----------------------------------------------------------
+# 6. Evaluate the quality of the generated text
+# ----------------------------------------------------------
+
+results_to_dementia = evaluate_text(reference_texts_dementia, generated_text)
+bert_results_dementia = compute_bert_score(reference_texts_dementia, generated_text)
+
+results_to_original = evaluate_text([original_text_healthy], generated_text)
+bert_results_original = compute_bert_score([original_text_healthy], generated_text)
+
+# ----------------------------------------------------------
+# 7. Print the results for evaluation in the bachelor's thesis
+# ----------------------------------------------------------
+
+print(
+    "\nComparison with texts of patients with Alzheimer's disease (HIGHER IS BETTER):"
+)
+print(
+    f"ROUGE-1 (best): {results_to_dementia['ROUGE-1-best']:.4f} | (mean): {results_to_dementia['ROUGE-1-mean']:.4f}"
+)
+print(
+    f"ROUGE-L (best): {results_to_dementia['ROUGE-L-best']:.4f} | (mean): {results_to_dementia['ROUGE-L-mean']:.4f}"
+)
+print(f"BERTScore-F1 (mean): {bert_results_dementia['BERTScore-F1']:.4f}\n")
+
+print("Comparison with the original healthy text (LOWER IS BETTER):")
+print("(Comparison against a single text, so 'best' and 'mean' values are identical)")
+print(f"ROUGE-1: {results_to_original['ROUGE-1-best']:.4f}")
+print(f"ROUGE-L: {results_to_original['ROUGE-L-best']:.4f}")
+print(f"BERTScore-F1: {bert_results_original['BERTScore-F1']:.4f}")
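+
+# ----------------------------------------------------------
+# Editor's sketch (hypothetical extension, not in the original script): evaluate
+# every generated sample stored in a folder instead of the single hard-coded
+# string above. The folder name "generated_eng" is an assumption.
+# ----------------------------------------------------------
+generated_folder = os.path.join(project_root, "text", "generated_eng")
+if os.path.isdir(generated_folder):
+    for i, gen in enumerate(load_texts_from_folder(generated_folder), start=1):
+        r = evaluate_text(reference_texts_dementia, gen)
+        b = compute_bert_score(reference_texts_dementia, gen)
+        print(
+            f"[{i}] ROUGE-1 mean: {r['ROUGE-1-mean']:.4f} | "
+            f"BERTScore-F1: {b['BERTScore-F1']:.4f}"
+        )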