BakalarskaPraca/BART_eng/METRIKY.py

155 lines
6.1 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import nltk
from rouge_score import rouge_scorer
from bert_score import score
import warnings
import logging
from transformers.utils import logging as hf_logging
# ----------------------------------------------------------
# 1. Environment setup: silence warnings and log noise
# ----------------------------------------------------------
# Download NLTK's 'punkt' data only when it is not already installed.
# NOTE(review): nothing visible in this script calls an NLTK tokenizer
# directly; confirm whether 'punkt' is still required.
try:
    nltk.data.find("tokenizers/punkt")
except LookupError:
    nltk.download("punkt")

# Ignore Python warnings whose text begins with this model-loading notice
# (the `message` argument is a regex matched against the start of the text).
# NOTE(review): if transformers emits this notice through its logger rather
# than the `warnings` module, it is the verbosity settings below that hide it.
warnings.filterwarnings(
    action="ignore",
    message="Some weights of RobertaModel were not initialized",
)

# Keep only ERROR-and-above output from the transformers library and from
# the specific transformers sub-loggers listed here.
hf_logging.set_verbosity_error()
for _logger_name in (
    "transformers.modeling_utils",
    "transformers.configuration_utils",
    "transformers.tokenization_utils_base",
):
    logging.getLogger(_logger_name).setLevel(logging.ERROR)
del _logger_name
# ----------------------------------------------------------
# 2. Load the patients' reference texts
# ----------------------------------------------------------
def load_texts_from_folder(folder_path):
    """Return the contents of every ``.txt`` file directly inside *folder_path*.

    Entries are visited in sorted filename order. ``os.listdir`` itself
    returns names in arbitrary order, so sorting keeps the returned list, and
    every metric averaged over it, reproducible between runs and machines.
    Entries whose name ends in ``.txt`` but which are not regular files
    (e.g. a directory called ``notes.txt``) are skipped instead of crashing
    ``open``. The scan is not recursive.

    Args:
        folder_path: Directory to scan.

    Returns:
        list[str]: The UTF-8 decoded text of each file, with leading and
        trailing whitespace stripped, in sorted filename order.

    Raises:
        OSError: propagated from ``os.listdir`` (e.g. FileNotFoundError) when
            *folder_path* does not exist or is not a directory.
    """
    texts = []
    for filename in sorted(os.listdir(folder_path)):
        path = os.path.join(folder_path, filename)
        # Only regular files (symlinks to files included) with a .txt suffix.
        if not filename.endswith(".txt") or not os.path.isfile(path):
            continue
        with open(path, "r", encoding="utf-8") as file:
            texts.append(file.read().strip())
    return texts
# Transcripts of patients with Alzheimer's disease live in
# <project root>/text/chory_eng_cleaned, where the project root is the
# parent of the directory containing this script.
script_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.normpath(os.path.join(script_dir, os.pardir))
dementia_folder = os.path.join(project_root, "text", "chory_eng_cleaned")
reference_texts_dementia = load_texts_from_folder(dementia_folder)

# Abort early when no .txt transcript was found in the folder.
if not reference_texts_dementia:
    raise ValueError("V zložke nie sú žiadne .txt súbory! Skontroluj cestu.")

print(f"Načítaných {len(reference_texts_dementia)} chorých textov na porovnanie.")
# ----------------------------------------------------------
# 3. Samples to compare (original healthy text and generated text)
# ----------------------------------------------------------
# Healthy speaker's description; the report treats LOWER similarity of the
# generated text to this sample as better.
original_text_healthy = (
    "there's a boy reaching into the cookie jar while standing on a stool\n"
    "and he already took one cookie and is handing it to the girl next to him\n"
    "and girl seems to be telling him to be quiet so their mom doesn't notice\n"
    "and the sink is overflowing with water and nobody is paying attention\n"
    "meanwhile, their mother is turned away, busy with something else"
)

# Generated text under evaluation, kept verbatim.
# NOTE(review): "exc" looks like a transcription marker; confirm its meaning.
generated_text = (
    "okay exc there's a boy taking cookies out of the cookie jar\n"
    "and he's standing on a stool\n"
    "and he's handing one to his sister\n"
    "and she's laughing at him\n"
    "and I don't know what else you want exc\n"
    "I don't know exc\n"
    "what else do you want exc\n"
    "uh... that's all I see exc"
)
# ----------------------------------------------------------
# 4. ROUGE metric (token/phrase-level overlap)
# ----------------------------------------------------------
def evaluate_text(reference_texts, generated_text):
    """Score *generated_text* against each reference with ROUGE-1 and ROUGE-L.

    Every reference is scored independently (with the Porter stemmer
    enabled). The per-reference F-measures are then aggregated into a
    best (max) and a mean value.

    Args:
        reference_texts: Iterable of reference strings. It must yield at
            least one text.
        generated_text: The candidate text being evaluated.

    Returns:
        dict[str, float]: ``ROUGE-1-best`` / ``ROUGE-L-best`` hold the highest
        F-measure over all references. ``ROUGE-1-mean`` / ``ROUGE-L-mean``
        hold the arithmetic mean of the F-measures.

    Raises:
        ValueError: if *reference_texts* is empty, since there is nothing to
            take a max or mean over. The check runs before any scorer is built.
    """
    # Materialize once so generators work and the length is known.
    references = list(reference_texts)
    if not references:
        raise ValueError(
            "Zoznam referenčných textov je prázdny, nie je s čím porovnávať."
        )

    scorer = rouge_scorer.RougeScorer(["rouge1", "rougeL"], use_stemmer=True)
    rouge1_scores = []
    rougeL_scores = []
    for ref_text in references:
        # Reference first, candidate second: score(target, prediction).
        scores = scorer.score(ref_text, generated_text)
        rouge1_scores.append(scores["rouge1"].fmeasure)
        rougeL_scores.append(scores["rougeL"].fmeasure)

    return {
        "ROUGE-1-best": max(rouge1_scores),
        "ROUGE-1-mean": sum(rouge1_scores) / len(rouge1_scores),
        "ROUGE-L-best": max(rougeL_scores),
        "ROUGE-L-mean": sum(rougeL_scores) / len(rougeL_scores),
    }
# ----------------------------------------------------------
# 5. BERTScore metric (meaning-level similarity)
# ----------------------------------------------------------
def compute_bert_score(reference_texts, generated_text):
    """Compute BERTScore F1 of *generated_text* against each reference.

    The candidate is paired with every reference (one pair per reference)
    and scored with the model bert_score selects for ``lang="en"``. Baseline
    rescaling is disabled.

    Args:
        reference_texts: Iterable of reference strings. It must yield at
            least one text.
        generated_text: The candidate text being evaluated.

    Returns:
        dict[str, float]: ``BERTScore-F1`` is the mean F1 over all pairs.
        ``BERTScore-F1-best`` is the highest single-pair F1, mirroring the
        best/mean split reported for ROUGE. If bert_score raises (e.g. the
        model cannot be downloaded or loaded), the error is printed and both
        values are NaN. A failure therefore never looks like a real score; the
        previous ``0`` looked like a perfect result wherever lower is better.

    Raises:
        ValueError: if *reference_texts* is empty.
    """
    references = list(reference_texts)
    if not references:
        raise ValueError(
            "Zoznam referenčných textov je prázdny, nie je s čím porovnávať."
        )

    try:
        # A flat list of references is paired element-wise with the repeated
        # candidate. Wrapping each reference in its own list is not required.
        _, _, f1 = score(
            cands=[generated_text] * len(references),
            refs=references,
            lang="en",
            rescale_with_baseline=False,
        )
        return {
            "BERTScore-F1": f1.mean().item(),
            "BERTScore-F1-best": f1.max().item(),
        }
    except Exception as e:  # deliberate best-effort: keep the report running
        print(f"Chyba pri výpočte BERTScore: {e}")
        nan = float("nan")
        return {"BERTScore-F1": nan, "BERTScore-F1-best": nan}
# ----------------------------------------------------------
# 6. Evaluate the generated text
# ----------------------------------------------------------
# First against the Alzheimer's transcripts (higher similarity is better)...
results_to_dementia = evaluate_text(reference_texts_dementia, generated_text)
bert_results_dementia = compute_bert_score(reference_texts_dementia, generated_text)
# ...then against the single healthy original (lower similarity is better).
results_to_original = evaluate_text([original_text_healthy], generated_text)
bert_results_original = compute_bert_score([original_text_healthy], generated_text)

# ----------------------------------------------------------
# 7. Print the results for the bachelor's thesis evaluation
# ----------------------------------------------------------
print(
    "\nPorovnanie s textami pacientov s Alzheimerovou chorobou (ČÍM VYŠŠIE, TÝM LEPŠIE):"
)
for _metric in ("ROUGE-1", "ROUGE-L"):
    print(
        f"{_metric} (najvyššie): {results_to_dementia[_metric + '-best']:.4f}"
        f" | (priemer): {results_to_dementia[_metric + '-mean']:.4f}"
    )
print(f"BERTScore-F1 (priemer): {bert_results_dementia['BERTScore-F1']:.4f}\n")

print("Porovnanie s pôvodným zdravým textom (ČÍM NIŽŠIE, TÝM LEPŠIE):")
# With a single reference, best and mean coincide, so only "best" is printed.
print("(Porovnanie s jedným textom hodnoty 'naj' a 'priemer' sú rovnaké)")
for _metric in ("ROUGE-1", "ROUGE-L"):
    print(f"{_metric}: {results_to_original[_metric + '-best']:.4f}")
print(f"BERTScore-F1: {bert_results_original['BERTScore-F1']:.4f}")
del _metric