import os

import torch

from transformers import BartTokenizer, BartForConditionalGeneration

# ----------------------------------------------------------
# Load the trained BART model and tokenizer
# ----------------------------------------------------------

# Path to the directory containing the trained model
script_dir = os.path.dirname(os.path.abspath(__file__))
model_path = os.path.abspath(
    os.path.join(script_dir, "..", "BART_eng", "trained_bart_model_CLEAN_15")
)

# Load the model and tokenizer from disk
tokenizer = BartTokenizer.from_pretrained(model_path)
model = BartForConditionalGeneration.from_pretrained(model_path)

# Select the device (GPU if available, otherwise CPU)
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
model.eval()  # disable dropout so generation behaves deterministically

print("Model loaded successfully and is ready to use.")

# ----------------------------------------------------------
# Input text from a healthy speaker to be transformed
# ----------------------------------------------------------

healthy_text = """there's a boy reaching into the cookie jar while standing on a stool
and he already took one cookie and is handing it to the girl next to him
and girl seems to be telling him to be quiet so their mom doesn't notice
and the sink is overflowing with water and nobody is paying attention
meanwhile, their mother is turned away, busy with something else
"""

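# The passage matches the "Cookie Theft" picture description task used in
# standard aphasia assessments (boy on a stool at the cookie jar, girl,
# overflowing sink, distracted mother).
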
# Tokenize the input text
inputs = tokenizer(healthy_text, return_tensors="pt").to(device)
input_token_length = len(tokenizer.encode(healthy_text, add_special_tokens=False))

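# Note: `inputs` is a BatchEncoding holding "input_ids" and "attention_mask";
# the separate encode() call above counts tokens without the special tokens
# (<s>, </s>), so input_token_length is 2 lower than inputs["input_ids"].shape[1].
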
# Compute dynamic output-length limits from the input length
min_length_ratio = 0.6
max_length_ratio = 1.1

min_length = int(input_token_length * min_length_ratio)
max_length = int(input_token_length * max_length_ratio)

print(f"Number of input tokens: {input_token_length}")
print(f"Minimum output length: {min_length}, maximum: {max_length}")

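# Worked example with a hypothetical input size (not from an actual run):
# for an 80-token input, min_length = int(80 * 0.6) = 48 and
# max_length = int(80 * 1.1) = 88, so the output is constrained to roughly
# 60-110 % of the input length.
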
# ----------------------------------------------------------
# Text generation configurations
# ----------------------------------------------------------

generation_configs = [
    {
        "max_length": max_length,
        "min_length": min_length,
        "num_beams": 5,
        "no_repeat_ngram_size": 4,
        "repetition_penalty": 2.2,
        "length_penalty": 0.6,
        "early_stopping": True,
    },
    {
        "max_length": max_length + 20,
        "min_length": min_length + 10,
        "num_beams": 7,
        "no_repeat_ngram_size": 5,
        "repetition_penalty": 2.2,
        "length_penalty": 0.8,
        "early_stopping": True,
    },
    {
        "max_length": max_length + 40,
        "min_length": min_length + 20,
        "num_beams": 9,
        "no_repeat_ngram_size": 6,
        "repetition_penalty": 2.2,
        "length_penalty": 1.0,
        "early_stopping": True,
    },
]

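# The three configurations progressively widen the search: more beams
# (5 -> 7 -> 9), longer length limits, larger no-repeat n-grams (4 -> 5 -> 6),
# and a length_penalty rising from 0.6 to 1.0 (higher values make beam search
# favour longer outputs), while repetition_penalty stays fixed at 2.2.
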
# ----------------------------------------------------------
# Generate outputs for the different configurations
# ----------------------------------------------------------

generated_texts = []
for i, config in enumerate(generation_configs):
    print(f"\nText generation #{i+1}:")

    # Generate without computing gradients
    with torch.no_grad():
        outputs = model.generate(
            inputs["input_ids"],
            attention_mask=inputs["attention_mask"],  # pass the mask explicitly
            max_length=config["max_length"],
            min_length=config["min_length"],
            num_beams=config["num_beams"],
            no_repeat_ngram_size=config["no_repeat_ngram_size"],
            repetition_penalty=config["repetition_penalty"],
            length_penalty=config["length_penalty"],
            early_stopping=config["early_stopping"],
        )

    # Decode the output tokens back to text
    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    generated_texts.append(generated_text)

    print("Generated text:")
    print(generated_text)
    print("-" * 80)

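# Optional follow-up: persist the outputs for side-by-side comparison. This is
# a minimal sketch; the filename "generated_outputs.txt" is an assumption, not
# part of the original pipeline.
output_path = os.path.join(script_dir, "generated_outputs.txt")
with open(output_path, "w", encoding="utf-8") as f:
    for i, text in enumerate(generated_texts):
        f.write(f"--- Configuration {i + 1} ---\n{text}\n\n")
print(f"Saved {len(generated_texts)} generated texts to {output_path}")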