155 lines
6.1 KiB
Python
155 lines
6.1 KiB
Python
import os
|
||
import nltk
|
||
from rouge_score import rouge_scorer
|
||
from bert_score import score
|
||
import warnings
|
||
import logging
|
||
from transformers.utils import logging as hf_logging
|
||
|
||
# ----------------------------------------------------------
|
||
# 1. Nastavenie prostredia – potlačenie varovaní a logov
|
||
# ----------------------------------------------------------
|
||
|
||
# Fetch the NLTK 'punkt' tokenizer data only when it is not already available.
try:
    nltk.data.find("tokenizers/punkt")
except LookupError:
    nltk.download("punkt")

# Hide the warning shown while BERTScore loads its model weights.
warnings.filterwarnings(
    action="ignore",
    message="Some weights of RobertaModel were not initialized",
)

# Keep the transformers library quiet: only errors are reported, both through
# its own verbosity switch and on the individual module loggers.
hf_logging.set_verbosity_error()
for _noisy_logger in (
    "transformers.modeling_utils",
    "transformers.configuration_utils",
    "transformers.tokenization_utils_base",
):
    logging.getLogger(_noisy_logger).setLevel(logging.ERROR)
|
||
|
||
# ----------------------------------------------------------
|
||
# 2. Načítanie referenčných textov pacientov
|
||
# ----------------------------------------------------------
|
||
|
||
|
||
# Pomocná funkcia na načítanie obsahu všetkých .txt súborov v priečinku
|
||
def load_texts_from_folder(folder_path):
    """Read every ``.txt`` file located directly inside *folder_path*.

    Files are visited in sorted filename order so that the returned list is
    reproducible across runs and platforms (``os.listdir`` makes no ordering
    guarantee). Entries whose name ends in ``.txt`` but which are not regular
    files (for example a directory) are skipped instead of crashing ``open``.

    Args:
        folder_path: Directory to scan; sub-directories are not searched.

    Returns:
        A list with the UTF-8 decoded content of each file, stripped of
        leading and trailing whitespace, ordered by filename.

    Raises:
        OSError: If the directory cannot be listed or a file cannot be read.
    """
    texts = []
    for filename in sorted(os.listdir(folder_path)):
        file_path = os.path.join(folder_path, filename)
        # Guard clause: only regular files with the .txt suffix are loaded.
        if not (filename.endswith(".txt") and os.path.isfile(file_path)):
            continue
        with open(file_path, "r", encoding="utf-8") as file:
            texts.append(file.read().strip())
    return texts
|
||
|
||
|
||
# Path to the texts of patients with Alzheimer's disease:
# <parent of this script's directory>/text/chory_eng_cleaned
script_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.abspath(os.path.join(script_dir, os.pardir))
dementia_folder = os.path.join(project_root, "text", "chory_eng_cleaned")
reference_texts_dementia = load_texts_from_folder(dementia_folder)

# Stop right away when the folder yielded no texts; the comparison below
# needs at least one reference.
if not reference_texts_dementia:
    raise ValueError("V zložke nie sú žiadne .txt súbory! Skontroluj cestu.")

print(f"Načítaných {len(reference_texts_dementia)} chorých textov na porovnanie.")
|
||
|
||
# ----------------------------------------------------------
|
||
# 3. Vzorky na porovnanie (originálny zdravý a vygenerovaný text)
|
||
# ----------------------------------------------------------
|
||
|
||
# Healthy (original) picture description; one list entry per line of text.
original_text_healthy = "\n".join(
    [
        "there's a boy reaching into the cookie jar while standing on a stool",
        "and he already took one cookie and is handing it to the girl next to him",
        "and girl seems to be telling him to be quiet so their mom doesn't notice",
        "and the sink is overflowing with water and nobody is paying attention",
        "meanwhile, their mother is turned away, busy with something else",
    ]
)

# Generated text that is evaluated against the references; one list entry
# per line of text.
generated_text = "\n".join(
    [
        "okay exc there's a boy taking cookies out of the cookie jar",
        "and he's standing on a stool",
        "and he's handing one to his sister",
        "and she's laughing at him",
        "and I don't know what else you want exc",
        "I don't know exc",
        "what else do you want exc",
        "uh... that's all I see exc",
    ]
)
|
||
|
||
# ----------------------------------------------------------
|
||
# 4. Výpočet ROUGE metriky (porovnávanie na úrovni tokenov/fráz)
|
||
# ----------------------------------------------------------
|
||
|
||
|
||
def evaluate_text(reference_texts, generated_text):
    """Score *generated_text* against each reference with ROUGE-1 and ROUGE-L.

    Every reference is scored separately with a stemming ``RougeScorer`` and
    the F-measures are summarised as the best and the mean value.

    Args:
        reference_texts: Iterable of reference strings; must not be empty.
        generated_text: The candidate text to evaluate.

    Returns:
        A dict with the keys ``"ROUGE-1-best"``, ``"ROUGE-1-mean"``,
        ``"ROUGE-L-best"`` and ``"ROUGE-L-mean"`` holding F-measures.

    Raises:
        ValueError: If *reference_texts* contains no texts.
    """
    # Materialise once so one-shot iterables work and emptiness can be
    # checked before any scoring starts.
    references = list(reference_texts)
    if not references:
        # Fail fast with a clear message instead of the opaque
        # "max() arg is an empty sequence" raised further down.
        raise ValueError("reference_texts must contain at least one reference text")

    scorer = rouge_scorer.RougeScorer(["rouge1", "rougeL"], use_stemmer=True)
    # score(target, prediction): the reference is the target.
    per_reference = [scorer.score(ref_text, generated_text) for ref_text in references]
    rouge1_scores = [scores["rouge1"].fmeasure for scores in per_reference]
    rougeL_scores = [scores["rougeL"].fmeasure for scores in per_reference]

    return {
        "ROUGE-1-best": max(rouge1_scores),
        "ROUGE-1-mean": sum(rouge1_scores) / len(rouge1_scores),
        "ROUGE-L-best": max(rougeL_scores),
        "ROUGE-L-mean": sum(rougeL_scores) / len(rougeL_scores),
    }
|
||
|
||
|
||
# ----------------------------------------------------------
|
||
# 5. Výpočet BERTScore metriky (porovnávanie na úrovni významu)
|
||
# ----------------------------------------------------------
|
||
|
||
|
||
def compute_bert_score(reference_texts, generated_text):
    """Return the mean BERTScore F1 of *generated_text* over all references.

    The candidate is repeated once per reference so that candidates and
    references pair up one-to-one, and each reference is wrapped in its own
    single-element list. Scoring is best effort: any exception is reported
    on stdout and a score of 0 is returned instead of propagating.

    Args:
        reference_texts: Sequence of reference strings.
        generated_text: The candidate text to evaluate.

    Returns:
        A dict with the single key ``"BERTScore-F1"``.
    """
    try:
        # The same candidate is compared with every reference.
        candidates = [generated_text] * len(reference_texts)
        # Every reference has to sit in a list of its own.
        reference_groups = [[reference] for reference in reference_texts]
        _, _, f1_per_pair = score(
            cands=candidates,
            refs=reference_groups,
            lang="en",
            rescale_with_baseline=False,
        )
        mean_f1 = f1_per_pair.mean().item()
    except Exception as error:
        print(f"Chyba pri výpočte BERTScore: {error}")
        return {"BERTScore-F1": 0}
    return {"BERTScore-F1": mean_f1}
|
||
|
||
|
||
# ----------------------------------------------------------
|
||
# 6. Vyhodnotenie kvality vygenerovaného textu
|
||
# ----------------------------------------------------------
|
||
|
||
# Generated text versus every reference text of the dementia group
# (ROUGE first, then BERTScore).
results_to_dementia, bert_results_dementia = (
    evaluate_text(reference_texts_dementia, generated_text),
    compute_bert_score(reference_texts_dementia, generated_text),
)

# Generated text versus the single healthy original text.
results_to_original, bert_results_original = (
    evaluate_text([original_text_healthy], generated_text),
    compute_bert_score([original_text_healthy], generated_text),
)
|
||
|
||
# ----------------------------------------------------------
|
||
# 7. Výpis výsledkov pre hodnotenie v bakalárskej práci
|
||
# ----------------------------------------------------------
|
||
|
||
# Report for the dementia references; each argument becomes one output line.
print(
    "\nPorovnanie s textami pacientov s Alzheimerovou chorobou (ČÍM VYŠŠIE, TÝM LEPŠIE):",
    f"ROUGE-1 (najvyššie): {results_to_dementia['ROUGE-1-best']:.4f} | (priemer): {results_to_dementia['ROUGE-1-mean']:.4f}",
    f"ROUGE-L (najvyššie): {results_to_dementia['ROUGE-L-best']:.4f} | (priemer): {results_to_dementia['ROUGE-L-mean']:.4f}",
    f"BERTScore-F1 (priemer): {bert_results_dementia['BERTScore-F1']:.4f}\n",
    sep="\n",
)

# Report for the healthy original. There is a single reference, so only the
# "best" values are shown (they equal the means).
print(
    "Porovnanie s pôvodným zdravým textom (ČÍM NIŽŠIE, TÝM LEPŠIE):",
    "(Porovnanie s jedným textom – hodnoty 'naj' a 'priemer' sú rovnaké)",
    f"ROUGE-1: {results_to_original['ROUGE-1-best']:.4f}",
    f"ROUGE-L: {results_to_original['ROUGE-L-best']:.4f}",
    f"BERTScore-F1: {bert_results_original['BERTScore-F1']:.4f}",
    sep="\n",
)
|