Upload files to "BART_eng"
commit
7ab100541a
135
BART_eng/BART_CLEAN_15.py
Normal file
@@ -0,0 +1,135 @@
import os
import torch
from transformers import (
    BartTokenizer,
    BartForConditionalGeneration,
    Trainer,
    TrainingArguments,
)

# ----------------------------------------------------------
# 1. Determine the script's current directory
# ----------------------------------------------------------
# Allows the resulting model to be saved in the same folder as this script
script_dir = os.path.dirname(os.path.abspath(__file__))


# ----------------------------------------------------------
# 2. Helper function to load all .txt files from a folder
# ----------------------------------------------------------
def load_texts_from_folder(folder_path):
    texts = []
    for filename in os.listdir(folder_path):
        if filename.endswith(".txt"):
            with open(
                os.path.join(folder_path, filename), "r", encoding="utf-8"
            ) as file:
                texts.append(file.read().strip())
    return texts


# ----------------------------------------------------------
# 3. Paths to the data – healthy and dementia-affected texts
# ----------------------------------------------------------
healthy_folder = r"project\text\zdr_eng_cleaned"
dementia_folder = r"project\text\chory_eng_cleaned"

# Load the texts from both folders
healthy_texts = load_texts_from_folder(healthy_folder)
dementia_texts = load_texts_from_folder(dementia_folder)

# Check: the number of texts must be the same
assert len(healthy_texts) == len(
    dementia_texts
), "The number of texts in both folders must match!"

# Build training pairs (input: healthy text, output: text with dementia symptoms)
train_data = [{"input": h, "output": d} for h, d in zip(healthy_texts, dementia_texts)]

# ----------------------------------------------------------
# 4. Load the pretrained BART model and tokenizer
# ----------------------------------------------------------
local_model_path = r"project\BART_eng\facebook_bart_large"
tokenizer = BartTokenizer.from_pretrained(local_model_path)
model = BartForConditionalGeneration.from_pretrained(local_model_path)


# ----------------------------------------------------------
# 5. Tokenize the data for the model input
# ----------------------------------------------------------
# Tokenization with truncation to at most 256 tokens
def tokenize_function(examples):
    inputs = tokenizer(
        examples["input"], padding="max_length", truncation=True, max_length=256
    )
    outputs = tokenizer(
        examples["output"], padding="max_length", truncation=True, max_length=256
    )
    inputs["labels"] = outputs["input_ids"]
    return inputs


# Apply tokenization to all inputs
tokenized_data = list(map(tokenize_function, train_data))


# ----------------------------------------------------------
# 6. Build a custom dataset compatible with PyTorch
# ----------------------------------------------------------
class CustomDataset(torch.utils.data.Dataset):
    def __init__(self, encodings):
        self.encodings = encodings

    def __getitem__(self, idx):
        return {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}

    def __len__(self):
        return len(self.encodings["input_ids"])


# Training dataset
train_dataset = CustomDataset(
    {
        "input_ids": [item["input_ids"] for item in tokenized_data],
        "labels": [item["labels"] for item in tokenized_data],
        "attention_mask": [item["attention_mask"] for item in tokenized_data],
    }
)

# ----------------------------------------------------------
# 7. Training configuration
# ----------------------------------------------------------
# No checkpoints, logs, or validation – plain training only
training_args = TrainingArguments(
    output_dir=".",                 # Required, but not used
    evaluation_strategy="no",       # No evaluation set
    save_strategy="no",             # No intermediate checkpoints
    per_device_train_batch_size=4,  # Batch size per device (CPU/GPU)
    num_train_epochs=15,            # Number of epochs
    learning_rate=3e-5,             # Learning rate
    weight_decay=0.01,              # Penalty on large weights (regularization)
)

# ----------------------------------------------------------
# 8. Initialize training via the Trainer API
# ----------------------------------------------------------
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
)

# ----------------------------------------------------------
# 9. Run model training
# ----------------------------------------------------------
print("Training the model...")
trainer.train()
print("Training finished.")

# ----------------------------------------------------------
# 10. Save the final model and tokenizer to disk
# ----------------------------------------------------------
model_save_path = os.path.join(script_dir, "trained_bart_model_CLEAN_15")
model.save_pretrained(model_save_path)
tokenizer.save_pretrained(model_save_path)
print(f"The model and tokenizer were saved to: {model_save_path}")
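A note on the training pairs above: os.listdir returns filenames in an arbitrary, platform-dependent order, so zip(healthy_texts, dementia_texts) only lines up matching healthy/impaired recordings if both folders happen to enumerate identically. A minimal sketch of a deterministic variant of the loader (same helper name as in the script; it assumes corresponding files carry matching, sortable names in both folders):

def load_texts_from_folder(folder_path):
    # Sort filenames so both folders are read in the same, reproducible order
    texts = []
    for filename in sorted(os.listdir(folder_path)):
        if filename.endswith(".txt"):
            path = os.path.join(folder_path, filename)
            with open(path, "r", encoding="utf-8") as file:
                texts.append(file.read().strip())
    return texts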
112
BART_eng/CLEAN_15.py
Normal file
@@ -0,0 +1,112 @@
import os
import torch
from transformers import BartTokenizer, BartForConditionalGeneration

# ----------------------------------------------------------
# Load the trained BART model and tokenizer
# ----------------------------------------------------------

# Path to the folder with the trained model
script_dir = os.path.dirname(os.path.abspath(__file__))
model_path = os.path.abspath(
    os.path.join(script_dir, "..", "BART_eng", "trained_bart_model_CLEAN_15")
)


# Load the model and tokenizer from disk
tokenizer = BartTokenizer.from_pretrained(model_path)
model = BartForConditionalGeneration.from_pretrained(model_path)

# Select a device (GPU if available, otherwise CPU)
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

print("The model was loaded successfully and is ready for use.")

# ----------------------------------------------------------
# Input text from a healthy speaker to be transformed
# ----------------------------------------------------------

healthy_text = """there's a boy reaching into the cookie jar while standing on a stool
and he already took one cookie and is handing it to the girl next to him
and girl seems to be telling him to be quiet so their mom doesn't notice
and the sink is overflowing with water and nobody is paying attention
meanwhile, their mother is turned away, busy with something else
"""

# Tokenize the input text
inputs = tokenizer(healthy_text, return_tensors="pt").to(device)
input_token_length = len(tokenizer.encode(healthy_text, add_special_tokens=False))

# Compute dynamic length limits
min_length_ratio = 0.6
max_length_ratio = 1.1

min_length = int(input_token_length * min_length_ratio)
max_length = int(input_token_length * max_length_ratio)

print(f"Number of input tokens: {input_token_length}")
print(f"Minimum output length: {min_length}, maximum: {max_length}")

# ----------------------------------------------------------
# Text generation configurations
# ----------------------------------------------------------

generation_configs = [
    {
        "max_length": max_length,
        "min_length": min_length,
        "num_beams": 5,
        "no_repeat_ngram_size": 4,
        "repetition_penalty": 2.2,
        "length_penalty": 0.6,
        "early_stopping": True,
    },
    {
        "max_length": max_length + 20,
        "min_length": min_length + 10,
        "num_beams": 7,
        "no_repeat_ngram_size": 5,
        "repetition_penalty": 2.2,
        "length_penalty": 0.8,
        "early_stopping": True,
    },
    {
        "max_length": max_length + 40,
        "min_length": min_length + 20,
        "num_beams": 9,
        "no_repeat_ngram_size": 6,
        "repetition_penalty": 2.2,
        "length_penalty": 1.0,
        "early_stopping": True,
    },
]

# ----------------------------------------------------------
# Generate outputs for the different configurations
# ----------------------------------------------------------

generated_texts = []
for i, config in enumerate(generation_configs):
    print(f"\nText generation no. {i+1}:")

    # Generate without computing gradients
    with torch.no_grad():
        outputs = model.generate(
            inputs["input_ids"],
            max_length=config["max_length"],
            min_length=config["min_length"],
            num_beams=config["num_beams"],
            no_repeat_ngram_size=config["no_repeat_ngram_size"],
            repetition_penalty=config["repetition_penalty"],
            length_penalty=config["length_penalty"],
            early_stopping=config["early_stopping"],
        )

    # Decode the output into text
    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    generated_texts.append(generated_text)

    print("Generated text:")
    print(generated_text)
    print("-" * 80)
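The script above only prints the generated variants; METRIKY.py currently has the generated sample pasted in by hand. If the outputs are meant to be evaluated later, a small addition along these lines could persist them to disk. This is a sketch only: the folder name generated_CLEAN_15 is an assumption, not part of the commit.

# Assumed follow-up: save each generated variant next to the script (folder name is hypothetical)
output_dir = os.path.join(script_dir, "generated_CLEAN_15")
os.makedirs(output_dir, exist_ok=True)
for i, text in enumerate(generated_texts, start=1):
    with open(os.path.join(output_dir, f"generated_{i}.txt"), "w", encoding="utf-8") as f:
        f.write(text)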
154
BART_eng/METRIKY.py
Normal file
@@ -0,0 +1,154 @@
import os
import nltk
from rouge_score import rouge_scorer
from bert_score import score
import warnings
import logging
from transformers.utils import logging as hf_logging

# ----------------------------------------------------------
# 1. Environment setup – suppress warnings and logs
# ----------------------------------------------------------

# Automatically download 'punkt' if it is not yet available
try:
    nltk.data.find("tokenizers/punkt")
except LookupError:
    nltk.download("punkt")

# Hide BERTScore warnings emitted while loading models
warnings.filterwarnings(
    "ignore", message="Some weights of RobertaModel were not initialized"
)

# Suppress logs from the transformers library
hf_logging.set_verbosity_error()
logging.getLogger("transformers.modeling_utils").setLevel(logging.ERROR)
logging.getLogger("transformers.configuration_utils").setLevel(logging.ERROR)
logging.getLogger("transformers.tokenization_utils_base").setLevel(logging.ERROR)

# ----------------------------------------------------------
# 2. Load the patients' reference texts
# ----------------------------------------------------------


# Helper function to load the contents of all .txt files in a folder
def load_texts_from_folder(folder_path):
    texts = []
    for filename in os.listdir(folder_path):
        if filename.endswith(".txt"):
            with open(
                os.path.join(folder_path, filename), "r", encoding="utf-8"
            ) as file:
                texts.append(file.read().strip())
    return texts


# Path to the texts of patients with Alzheimer's disease
script_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.abspath(os.path.join(script_dir, ".."))
dementia_folder = os.path.join(project_root, "text", "chory_eng_cleaned")
reference_texts_dementia = load_texts_from_folder(dementia_folder)

# Verify that the folder contains files
if not reference_texts_dementia:
    raise ValueError("There are no .txt files in the folder! Check the path.")

print(f"Loaded {len(reference_texts_dementia)} dementia texts for comparison.")

# ----------------------------------------------------------
# 3. Samples to compare (original healthy text and generated text)
# ----------------------------------------------------------

original_text_healthy = """there's a boy reaching into the cookie jar while standing on a stool
and he already took one cookie and is handing it to the girl next to him
and girl seems to be telling him to be quiet so their mom doesn't notice
and the sink is overflowing with water and nobody is paying attention
meanwhile, their mother is turned away, busy with something else"""

generated_text = """okay exc there's a boy taking cookies out of the cookie jar
and he's standing on a stool
and he's handing one to his sister
and she's laughing at him
and I don't know what else you want exc
I don't know exc
what else do you want exc
uh... that's all I see exc"""

# ----------------------------------------------------------
# 4. Compute the ROUGE metrics (comparison at the token/phrase level)
# ----------------------------------------------------------


def evaluate_text(reference_texts, generated_text):
    scorer = rouge_scorer.RougeScorer(["rouge1", "rougeL"], use_stemmer=True)
    rouge1_scores = []
    rougeL_scores = []

    for ref_text in reference_texts:
        scores = scorer.score(ref_text, generated_text)
        rouge1_scores.append(scores["rouge1"].fmeasure)
        rougeL_scores.append(scores["rougeL"].fmeasure)

    return {
        "ROUGE-1-best": max(rouge1_scores),
        "ROUGE-1-mean": sum(rouge1_scores) / len(rouge1_scores),
        "ROUGE-L-best": max(rougeL_scores),
        "ROUGE-L-mean": sum(rougeL_scores) / len(rougeL_scores),
    }


# ----------------------------------------------------------
# 5. Compute the BERTScore metric (comparison at the meaning level)
# ----------------------------------------------------------


def compute_bert_score(reference_texts, generated_text):
    try:
        P, R, F1 = score(
            cands=[generated_text]
            * len(
                reference_texts
            ),  # the same candidate is compared against every reference
            refs=[
                [ref] for ref in reference_texts
            ],  # each reference must be in its own list
            lang="en",
            rescale_with_baseline=False,
        )
        return {"BERTScore-F1": F1.mean().item()}
    except Exception as e:
        print(f"Error while computing BERTScore: {e}")
        return {"BERTScore-F1": 0}


# ----------------------------------------------------------
# 6. Evaluate the quality of the generated text
# ----------------------------------------------------------

results_to_dementia = evaluate_text(reference_texts_dementia, generated_text)
bert_results_dementia = compute_bert_score(reference_texts_dementia, generated_text)

results_to_original = evaluate_text([original_text_healthy], generated_text)
bert_results_original = compute_bert_score([original_text_healthy], generated_text)

# ----------------------------------------------------------
# 7. Print the results for evaluation in the bachelor's thesis
# ----------------------------------------------------------

print(
    "\nComparison with texts of patients with Alzheimer's disease (HIGHER IS BETTER):"
)
print(
    f"ROUGE-1 (best): {results_to_dementia['ROUGE-1-best']:.4f} | (mean): {results_to_dementia['ROUGE-1-mean']:.4f}"
)
print(
    f"ROUGE-L (best): {results_to_dementia['ROUGE-L-best']:.4f} | (mean): {results_to_dementia['ROUGE-L-mean']:.4f}"
)
print(f"BERTScore-F1 (mean): {bert_results_dementia['BERTScore-F1']:.4f}\n")

print("Comparison with the original healthy text (LOWER IS BETTER):")
print("(Comparison with a single text – the 'best' and 'mean' values are identical)")
print(f"ROUGE-1: {results_to_original['ROUGE-1-best']:.4f}")
print(f"ROUGE-L: {results_to_original['ROUGE-L-best']:.4f}")
print(f"BERTScore-F1: {bert_results_original['BERTScore-F1']:.4f}")
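One possible refinement of the BERTScore call in METRIKY.py: raw BERTScore F1 values tend to cluster in a narrow band near the top of the [0, 1] range, which makes differences between candidates hard to read. The bert_score package accepts rescale_with_baseline=True to spread the scale using per-language baseline statistics. A sketch of that variant, shown only as an option (the commit deliberately uses rescale_with_baseline=False):

# Optional variant: baseline-rescaled BERTScore for a more readable scale
P, R, F1 = score(
    cands=[generated_text] * len(reference_texts_dementia),
    refs=[[ref] for ref in reference_texts_dementia],
    lang="en",
    rescale_with_baseline=True,  # rescales scores against a per-language baseline
)
print(f"Rescaled BERTScore-F1 (mean): {F1.mean().item():.4f}")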