BakalarskaPraca/T5_slovak_4_3_1/modelClean.py

125 lines
5.0 KiB
Python

import os
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
# ----------------------------------------------------------
# Načítanie predtrénovaného modelu a tokenizátora z disku
# ----------------------------------------------------------
# Cesta k priečinku s trénovanou slovenskou T5 verziou
script_dir = os.path.dirname(os.path.abspath(__file__))
model_path = os.path.join(script_dir, "trained_slovak_t5_cpu")
# Kontrola existencie modelu na disku
if not os.path.exists(model_path):
raise FileNotFoundError(f"Priečinok s modelom nebol nájdený: {model_path}")
# Načítanie modelu a tokenizátora z daného priečinka
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
# Výber zariadenia (GPU ak je dostupné, inak CPU)
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
print("Model bol úspešne načítaný a je pripravený na použitie.")
# ----------------------------------------------------------
# Vstupný text zdravého človeka na transformáciu
# ----------------------------------------------------------
healthy_text = """no mamička ktorá má všetko sama na starosť. varí drží plačúce dieťa na rukách. telefonuje čiže asi aj pracuje alebo vybavuje alebo ska a nie rozmýšľa tam. nakupuje topánky cez telefón. na polici nám sedí mačka ktorá vysypala. neviem či sú to klinčeky alebo niečo. padá nám to do na počítač. máme tam rozložené jedlo. mamička samozrejme nestíha. vypráža vajíčka robí hemendex. áno lebo letí tam šunka aj vajíčka tam sú. na polici sú tri knihy. poobede čiže nestíhajú obed. už je po dvanástej hodine. a dieťa je hladné strašne plače. není naň čas. mamička je proste v jednom kole stále."""
# Tokenizácia vstupného textu pre vstup do modelu
inputs = tokenizer(healthy_text, return_tensors="pt").to(device)
input_token_length = len(tokenizer.encode(healthy_text, add_special_tokens=False))
# Výpočet dynamických limitov pre generovanie výstupu
min_length_ratio = 0.6
max_new_tokens_ratio = 0.5
min_length = int(input_token_length * min_length_ratio)
max_new_tokens = min(75, int(input_token_length * max_new_tokens_ratio))
print(f"Počet tokenov vo vstupe: {input_token_length}")
print(
f"Minimálna dĺžka výstupu: {min_length}, maximálny počet nových tokenov: {max_new_tokens}"
)
# ----------------------------------------------------------
# Konfigurácie generovania textu (rôzne nastavenia)
# ----------------------------------------------------------
generation_configs = [
{
"max_new_tokens": max_new_tokens,
"do_sample": False,
"num_beams": 5,
"length_penalty": 1.5,
"no_repeat_ngram_size": 4,
"repetition_penalty": 3.5,
"early_stopping": True,
},
{
"max_new_tokens": max_new_tokens,
"do_sample": False,
"num_beams": 7,
"length_penalty": 1.5,
"no_repeat_ngram_size": 4,
"repetition_penalty": 3.5,
"early_stopping": True,
},
{
"max_new_tokens": max_new_tokens,
"do_sample": False,
"num_beams": 8,
"length_penalty": 1.5,
"no_repeat_ngram_size": 4,
"repetition_penalty": 3.5,
"early_stopping": True,
},
]
# ----------------------------------------------------------
# Vlastné generovanie textu podľa konfigurácií
# ----------------------------------------------------------
generated_texts = []
for i, config in enumerate(generation_configs):
print(f"\nGenerácia textu č. {i+1}:")
# Parametre pre funkciu generate()
gen_kwargs = {
"min_length": min(min_length, config["max_new_tokens"] - 10),
"max_new_tokens": config["max_new_tokens"],
"repetition_penalty": config["repetition_penalty"],
"num_beams": config.get("num_beams", 1),
"length_penalty": config.get("length_penalty", 1.0),
"early_stopping": config.get("early_stopping", True),
}
# Nepovinný parameter: zákaz opakovania n-gramov
if "no_repeat_ngram_size" in config:
gen_kwargs["no_repeat_ngram_size"] = config["no_repeat_ngram_size"]
# Sampling (ak by bol použitý, tu je vypnutý)
if config.get("do_sample", False):
gen_kwargs["do_sample"] = True
if "top_p" in config:
gen_kwargs["top_p"] = config["top_p"]
if "temperature" in config:
gen_kwargs["temperature"] = config["temperature"]
# Generovanie bez výpočtu gradientov
with torch.no_grad():
outputs = model.generate(inputs["input_ids"], **gen_kwargs)
# Dekódovanie výstupných tokenov na čitateľný text
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
generated_texts.append(generated_text)
# Výpis výsledku
print("Vygenerovaný text:")
print(generated_text)
print("-" * 80)