Upload files to «BART_eng»

This commit is contained in:
Vladyslav Korzun 2025-05-22 13:10:08 +00:00
commit 7ab100541a
3 changed files with 401 additions and 0 deletions

BART_eng/BART_CLEAN_15.py (new file, 135 lines)

@@ -0,0 +1,135 @@
import os
import torch
from transformers import (
BartTokenizer,
BartForConditionalGeneration,
Trainer,
TrainingArguments,
)
# ----------------------------------------------------------
# 1. Determine the script's current directory
# ----------------------------------------------------------
# Allows the resulting model to be saved in the same folder as this script
script_dir = os.path.dirname(os.path.abspath(__file__))
# ----------------------------------------------------------
# 2. Helper function to load all .txt files from a folder
# ----------------------------------------------------------
def load_texts_from_folder(folder_path):
    texts = []
    # Sort the listing so that pairing between the two folders is deterministic
    # (os.listdir order is filesystem-dependent)
    for filename in sorted(os.listdir(folder_path)):
        if filename.endswith(".txt"):
            with open(
                os.path.join(folder_path, filename), "r", encoding="utf-8"
            ) as file:
                texts.append(file.read().strip())
    return texts
# ----------------------------------------------------------
# 3. Paths to the data: healthy and dementia texts
# ----------------------------------------------------------
healthy_folder = r"project\text\zdr_eng_cleaned"
dementia_folder = r"project\text\chory_eng_cleaned"
# Load texts from both folders
healthy_texts = load_texts_from_folder(healthy_folder)
dementia_texts = load_texts_from_folder(dementia_folder)
# Check: the number of texts must be the same in both folders
assert len(healthy_texts) == len(
    dementia_texts
), "The number of texts in both folders must match!"
# Build training pairs (input: healthy text, output: text with dementia symptoms)
train_data = [{"input": h, "output": d} for h, d in zip(healthy_texts, dementia_texts)]
# ----------------------------------------------------------
# 4. Load the pretrained BART model and tokenizer
# ----------------------------------------------------------
local_model_path = r"project\BART_eng\facebook_bart_large"
tokenizer = BartTokenizer.from_pretrained(local_model_path)
model = BartForConditionalGeneration.from_pretrained(local_model_path)
# ----------------------------------------------------------
# 5. Tokenize the data for model input
# ----------------------------------------------------------
# Tokenization with truncation to at most 256 tokens
def tokenize_function(examples):
inputs = tokenizer(
examples["input"], padding="max_length", truncation=True, max_length=256
)
outputs = tokenizer(
examples["output"], padding="max_length", truncation=True, max_length=256
)
inputs["labels"] = outputs["input_ids"]
return inputs
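# Note: with padding="max_length", pad token ids also land in "labels" and are
# counted in the loss. One common mitigation (a sketch, not applied here) is to
# mask them with -100 inside tokenize_function:
#   inputs["labels"] = [
#       (t if t != tokenizer.pad_token_id else -100) for t in outputs["input_ids"]
#   ]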
# Apply tokenization to all inputs
tokenized_data = list(map(tokenize_function, train_data))
# ----------------------------------------------------------
# 6. Build a custom PyTorch-compatible dataset
# ----------------------------------------------------------
class CustomDataset(torch.utils.data.Dataset):
def __init__(self, encodings):
self.encodings = encodings
def __getitem__(self, idx):
return {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
def __len__(self):
return len(self.encodings["input_ids"])
# Training dataset
train_dataset = CustomDataset(
{
"input_ids": [item["input_ids"] for item in tokenized_data],
"labels": [item["labels"] for item in tokenized_data],
"attention_mask": [item["attention_mask"] for item in tokenized_data],
}
)
# ----------------------------------------------------------
# 7. Training configuration
# ----------------------------------------------------------
# No checkpoints, logs, or validation: just plain training
training_args = TrainingArguments(
    output_dir=".",  # Required, but not used
    evaluation_strategy="no",  # No evaluation set
    save_strategy="no",  # No intermediate checkpoints
    per_device_train_batch_size=4,  # Batch size per device (CPU/GPU)
    num_train_epochs=15,  # Number of epochs
    learning_rate=3e-5,  # Learning rate
    weight_decay=0.01,  # Penalizes large weights (regularization)
)
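# Note: newer transformers releases rename evaluation_strategy to
# eval_strategy; if this argument raises a TypeError, use the new name.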
# ----------------------------------------------------------
# 8. Initialize training via the Trainer API
# ----------------------------------------------------------
trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_dataset,
)
# ----------------------------------------------------------
# 9. Run model training
# ----------------------------------------------------------
print("Model training in progress...")
trainer.train()
print("Training finished.")
# ----------------------------------------------------------
# 10. Save the final model and tokenizer to disk
# ----------------------------------------------------------
model_save_path = os.path.join(script_dir, "trained_bart_model_CLEAN_15")
model.save_pretrained(model_save_path)
tokenizer.save_pretrained(model_save_path)
print(f"Model and tokenizer saved to: {model_save_path}")

BART_eng/CLEAN_15.py (new file, 112 lines)

@@ -0,0 +1,112 @@
import os
import torch
from transformers import BartTokenizer, BartForConditionalGeneration
# ----------------------------------------------------------
# Load the trained BART model and tokenizer
# ----------------------------------------------------------
# Path to the folder containing the trained model
script_dir = os.path.dirname(os.path.abspath(__file__))
model_path = os.path.abspath(
os.path.join(script_dir, "..", "BART_eng", "trained_bart_model_CLEAN_15")
)
# Load the model and tokenizer from disk
tokenizer = BartTokenizer.from_pretrained(model_path)
model = BartForConditionalGeneration.from_pretrained(model_path)
# Select a device (GPU if available, otherwise CPU)
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
print("Model bol úspešne načítaný a je pripravený na použitie.")
# ----------------------------------------------------------
# Input text from a healthy speaker to be transformed
# ----------------------------------------------------------
healthy_text = """there's a boy reaching into the cookie jar while standing on a stool
and he already took one cookie and is handing it to the girl next to him
and girl seems to be telling him to be quiet so their mom doesn't notice
and the sink is overflowing with water and nobody is paying attention
meanwhile, their mother is turned away, busy with something else
"""
# Tokenize the input text
inputs = tokenizer(healthy_text, return_tensors="pt").to(device)
input_token_length = len(tokenizer.encode(healthy_text, add_special_tokens=False))
# Compute dynamic length limits
min_length_ratio = 0.6
max_length_ratio = 1.1
min_length = int(input_token_length * min_length_ratio)
max_length = int(input_token_length * max_length_ratio)
print(f"Počet tokenov vo vstupe: {input_token_length}")
print(f"Minimálna dĺžka výstupu: {min_length}, maximálna: {max_length}")
# ----------------------------------------------------------
# Konfigurácie generovania textu
# ----------------------------------------------------------
generation_configs = [
{
"max_length": max_length,
"min_length": min_length,
"num_beams": 5,
"no_repeat_ngram_size": 4,
"repetition_penalty": 2.2,
"length_penalty": 0.6,
"early_stopping": True,
},
{
"max_length": max_length + 20,
"min_length": min_length + 10,
"num_beams": 7,
"no_repeat_ngram_size": 5,
"repetition_penalty": 2.2,
"length_penalty": 0.8,
"early_stopping": True,
},
{
"max_length": max_length + 40,
"min_length": min_length + 20,
"num_beams": 9,
"no_repeat_ngram_size": 6,
"repetition_penalty": 2.2,
"length_penalty": 1.0,
"early_stopping": True,
},
]
# ----------------------------------------------------------
# Generate outputs under the different configurations
# ----------------------------------------------------------
generated_texts = []
for i, config in enumerate(generation_configs):
    print(f"\nText generation #{i+1}:")
    # Generate without computing gradients
    with torch.no_grad():
        outputs = model.generate(
            inputs["input_ids"],
            attention_mask=inputs["attention_mask"],  # pass the mask explicitly
            max_length=config["max_length"],
            min_length=config["min_length"],
            num_beams=config["num_beams"],
            no_repeat_ngram_size=config["no_repeat_ngram_size"],
            repetition_penalty=config["repetition_penalty"],
            length_penalty=config["length_penalty"],
            early_stopping=config["early_stopping"],
        )
# Decode the output into text
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
generated_texts.append(generated_text)
print("Vygenerovaný text:")
print(generated_text)
print("-" * 80)

BART_eng/METRIKY.py (new file, 154 lines)

@@ -0,0 +1,154 @@
import os
import nltk
from rouge_score import rouge_scorer
from bert_score import score
import warnings
import logging
from transformers.utils import logging as hf_logging
# ----------------------------------------------------------
# 1. Environment setup: suppress warnings and logs
# ----------------------------------------------------------
# Download 'punkt' automatically if it is not yet available
try:
nltk.data.find("tokenizers/punkt")
except LookupError:
nltk.download("punkt")
# Hide BERTScore warnings emitted while loading models
warnings.filterwarnings(
"ignore", message="Some weights of RobertaModel were not initialized"
)
# Suppress logs from the transformers library
hf_logging.set_verbosity_error()
logging.getLogger("transformers.modeling_utils").setLevel(logging.ERROR)
logging.getLogger("transformers.configuration_utils").setLevel(logging.ERROR)
logging.getLogger("transformers.tokenization_utils_base").setLevel(logging.ERROR)
# ----------------------------------------------------------
# 2. Load reference patient texts
# ----------------------------------------------------------
# Helper function to load the contents of all .txt files in a folder
def load_texts_from_folder(folder_path):
texts = []
for filename in os.listdir(folder_path):
if filename.endswith(".txt"):
with open(
os.path.join(folder_path, filename), "r", encoding="utf-8"
) as file:
texts.append(file.read().strip())
return texts
# Path to texts of patients with Alzheimer's disease
script_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.abspath(os.path.join(script_dir, ".."))
dementia_folder = os.path.join(project_root, "text", "chory_eng_cleaned")
reference_texts_dementia = load_texts_from_folder(dementia_folder)
# Verify that the folder contains files
if not reference_texts_dementia:
    raise ValueError("There are no .txt files in the folder! Check the path.")
print(f"Loaded {len(reference_texts_dementia)} dementia texts for comparison.")
# ----------------------------------------------------------
# 3. Samples to compare (original healthy text and generated text)
# ----------------------------------------------------------
original_text_healthy = """there's a boy reaching into the cookie jar while standing on a stool
and he already took one cookie and is handing it to the girl next to him
and girl seems to be telling him to be quiet so their mom doesn't notice
and the sink is overflowing with water and nobody is paying attention
meanwhile, their mother is turned away, busy with something else"""
generated_text = """okay exc there's a boy taking cookies out of the cookie jar
and he's standing on a stool
and he's handing one to his sister
and she's laughing at him
and I don't know what else you want exc
I don't know exc
what else do you want exc
uh... that's all I see exc"""
# ----------------------------------------------------------
# 4. ROUGE metric computation (comparison at token/phrase level)
# ----------------------------------------------------------
def evaluate_text(reference_texts, generated_text):
scorer = rouge_scorer.RougeScorer(["rouge1", "rougeL"], use_stemmer=True)
rouge1_scores = []
rougeL_scores = []
for ref_text in reference_texts:
scores = scorer.score(ref_text, generated_text)
rouge1_scores.append(scores["rouge1"].fmeasure)
rougeL_scores.append(scores["rougeL"].fmeasure)
return {
"ROUGE-1-best": max(rouge1_scores),
"ROUGE-1-mean": sum(rouge1_scores) / len(rouge1_scores),
"ROUGE-L-best": max(rougeL_scores),
"ROUGE-L-mean": sum(rougeL_scores) / len(rougeL_scores),
}
# ----------------------------------------------------------
# 5. BERTScore metric computation (comparison at meaning level)
# ----------------------------------------------------------
def compute_bert_score(reference_texts, generated_text):
    try:
        P, R, F1 = score(
            # The same candidate is compared against every reference
            cands=[generated_text] * len(reference_texts),
            # Each reference must be wrapped in its own list
            refs=[[ref] for ref in reference_texts],
            lang="en",
            rescale_with_baseline=False,
        )
        return {"BERTScore-F1": F1.mean().item()}
    except Exception as e:
        print(f"Error while computing BERTScore: {e}")
        return {"BERTScore-F1": 0}
# ----------------------------------------------------------
# 6. Evaluate the quality of the generated text
# ----------------------------------------------------------
results_to_dementia = evaluate_text(reference_texts_dementia, generated_text)
bert_results_dementia = compute_bert_score(reference_texts_dementia, generated_text)
results_to_original = evaluate_text([original_text_healthy], generated_text)
bert_results_original = compute_bert_score([original_text_healthy], generated_text)
# ----------------------------------------------------------
# 7. Print results for evaluation in the bachelor's thesis
# ----------------------------------------------------------
print(
    "\nComparison with texts of patients with Alzheimer's disease (HIGHER IS BETTER):"
)
print(
    f"ROUGE-1 (best): {results_to_dementia['ROUGE-1-best']:.4f} | (mean): {results_to_dementia['ROUGE-1-mean']:.4f}"
)
print(
    f"ROUGE-L (best): {results_to_dementia['ROUGE-L-best']:.4f} | (mean): {results_to_dementia['ROUGE-L-mean']:.4f}"
)
print(f"BERTScore-F1 (mean): {bert_results_dementia['BERTScore-F1']:.4f}\n")
print("Comparison with the original healthy text (LOWER IS BETTER):")
print("(With a single reference text, the 'best' and 'mean' values are identical)")
print(f"ROUGE-1: {results_to_original['ROUGE-1-best']:.4f}")
print(f"ROUGE-L: {results_to_original['ROUGE-L-best']:.4f}")
print(f"BERTScore-F1: {bert_results_original['BERTScore-F1']:.4f}")