125 lines
5.0 KiB
Python
125 lines
5.0 KiB
Python
import os
|
|
import torch
|
|
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
|
|
|
# ----------------------------------------------------------
|
|
# Načítanie predtrénovaného modelu a tokenizátora z disku
|
|
# ----------------------------------------------------------
|
|
|
|
# Cesta k priečinku s trénovanou slovenskou T5 verziou
|
|
script_dir = os.path.dirname(os.path.abspath(__file__))
|
|
model_path = os.path.join(script_dir, "trained_slovak_t5_cpu")
|
|
|
|
# Kontrola existencie modelu na disku
|
|
if not os.path.exists(model_path):
|
|
raise FileNotFoundError(f"Priečinok s modelom nebol nájdený: {model_path}")
|
|
|
|
# Načítanie modelu a tokenizátora z daného priečinka
|
|
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
|
model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
|
|
|
|
# Výber zariadenia (GPU ak je dostupné, inak CPU)
|
|
device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
model.to(device)
|
|
|
|
print("Model bol úspešne načítaný a je pripravený na použitie.")
|
|
|
|
# ----------------------------------------------------------
|
|
# Vstupný text zdravého človeka na transformáciu
|
|
# ----------------------------------------------------------
|
|
|
|
healthy_text = """no mamička ktorá má všetko sama na starosť. varí drží plačúce dieťa na rukách. telefonuje čiže asi aj pracuje alebo vybavuje alebo ska a nie rozmýšľa tam. nakupuje topánky cez telefón. na polici nám sedí mačka ktorá vysypala. neviem či sú to klinčeky alebo niečo. padá nám to do na počítač. máme tam rozložené jedlo. mamička samozrejme nestíha. vypráža vajíčka robí hemendex. áno lebo letí tam šunka aj vajíčka tam sú. na polici sú tri knihy. poobede čiže nestíhajú obed. už je po dvanástej hodine. a dieťa je hladné strašne plače. není naň čas. mamička je proste v jednom kole stále."""
|
|
|
|
# Tokenizácia vstupného textu pre vstup do modelu
|
|
inputs = tokenizer(healthy_text, return_tensors="pt").to(device)
|
|
input_token_length = len(tokenizer.encode(healthy_text, add_special_tokens=False))
|
|
|
|
# Výpočet dynamických limitov pre generovanie výstupu
|
|
min_length_ratio = 0.6
|
|
max_new_tokens_ratio = 0.5
|
|
|
|
min_length = int(input_token_length * min_length_ratio)
|
|
max_new_tokens = min(75, int(input_token_length * max_new_tokens_ratio))
|
|
|
|
print(f"Počet tokenov vo vstupe: {input_token_length}")
|
|
print(
|
|
f"Minimálna dĺžka výstupu: {min_length}, maximálny počet nových tokenov: {max_new_tokens}"
|
|
)
|
|
|
|
# ----------------------------------------------------------
|
|
# Konfigurácie generovania textu (rôzne nastavenia)
|
|
# ----------------------------------------------------------
|
|
|
|
generation_configs = [
|
|
{
|
|
"max_new_tokens": max_new_tokens,
|
|
"do_sample": False,
|
|
"num_beams": 5,
|
|
"length_penalty": 1.5,
|
|
"no_repeat_ngram_size": 4,
|
|
"repetition_penalty": 3.5,
|
|
"early_stopping": True,
|
|
},
|
|
{
|
|
"max_new_tokens": max_new_tokens,
|
|
"do_sample": False,
|
|
"num_beams": 7,
|
|
"length_penalty": 1.5,
|
|
"no_repeat_ngram_size": 4,
|
|
"repetition_penalty": 3.5,
|
|
"early_stopping": True,
|
|
},
|
|
{
|
|
"max_new_tokens": max_new_tokens,
|
|
"do_sample": False,
|
|
"num_beams": 8,
|
|
"length_penalty": 1.5,
|
|
"no_repeat_ngram_size": 4,
|
|
"repetition_penalty": 3.5,
|
|
"early_stopping": True,
|
|
},
|
|
]
|
|
|
|
# ----------------------------------------------------------
|
|
# Vlastné generovanie textu podľa konfigurácií
|
|
# ----------------------------------------------------------
|
|
|
|
generated_texts = []
|
|
for i, config in enumerate(generation_configs):
|
|
print(f"\nGenerácia textu č. {i+1}:")
|
|
|
|
# Parametre pre funkciu generate()
|
|
gen_kwargs = {
|
|
"min_length": min(min_length, config["max_new_tokens"] - 10),
|
|
"max_new_tokens": config["max_new_tokens"],
|
|
"repetition_penalty": config["repetition_penalty"],
|
|
"num_beams": config.get("num_beams", 1),
|
|
"length_penalty": config.get("length_penalty", 1.0),
|
|
"early_stopping": config.get("early_stopping", True),
|
|
}
|
|
|
|
# Nepovinný parameter: zákaz opakovania n-gramov
|
|
if "no_repeat_ngram_size" in config:
|
|
gen_kwargs["no_repeat_ngram_size"] = config["no_repeat_ngram_size"]
|
|
|
|
# Sampling (ak by bol použitý, tu je vypnutý)
|
|
if config.get("do_sample", False):
|
|
gen_kwargs["do_sample"] = True
|
|
if "top_p" in config:
|
|
gen_kwargs["top_p"] = config["top_p"]
|
|
if "temperature" in config:
|
|
gen_kwargs["temperature"] = config["temperature"]
|
|
|
|
# Generovanie bez výpočtu gradientov
|
|
with torch.no_grad():
|
|
outputs = model.generate(inputs["input_ids"], **gen_kwargs)
|
|
|
|
# Dekódovanie výstupných tokenov na čitateľný text
|
|
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
|
generated_texts.append(generated_text)
|
|
|
|
# Výpis výsledku
|
|
print("Vygenerovaný text:")
|
|
print(generated_text)
|
|
print("-" * 80)
|