# BakalarskaPraca/BART_eng/CLEAN_15.py
#
# (Original listing metadata: 113 lines, 3.9 KiB, Python)

import os
import torch
from transformers import BartTokenizer, BartForConditionalGeneration
# ----------------------------------------------------------
# Load the fine-tuned BART model and its tokenizer
# ----------------------------------------------------------
# Directory with the fine-tuned model, resolved relative to this script
# (script_dir/../BART_eng/trained_bart_model_CLEAN_15).
script_dir = os.path.dirname(os.path.abspath(__file__))
model_path = os.path.abspath(
    os.path.join(script_dir, "..", "BART_eng", "trained_bart_model_CLEAN_15")
)
# Fail early with a clear message. If the local directory is missing,
# from_pretrained() would instead try to resolve the path string as a Hub
# model id, producing a misleading error.
if not os.path.isdir(model_path):
    raise FileNotFoundError(
        f"Priečinok s trénovaným modelom sa nenašiel: {model_path}"
    )
# Load the tokenizer and model weights from disk
tokenizer = BartTokenizer.from_pretrained(model_path)
model = BartForConditionalGeneration.from_pretrained(model_path)
# Use the GPU when one is available, otherwise fall back to the CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
print("Model bol úspešne načítaný a je pripravený na použitie.")
# ----------------------------------------------------------
# Sample input text (healthy speaker) to be transformed by the model
# ----------------------------------------------------------
healthy_text = """there's a boy reaching into the cookie jar while standing on a stool
and he already took one cookie and is handing it to the girl next to him
and girl seems to be telling him to be quiet so their mom doesn't notice
and the sink is overflowing with water and nobody is paying attention
meanwhile, their mother is turned away, busy with something else
"""
# Model input: PyTorch tensors produced by the tokenizer, moved to `device`.
encoded = tokenizer(healthy_text, return_tensors="pt")
inputs = encoded.to(device)
# Number of content tokens in the input, i.e. without the special tokens the
# tokenizer would normally add; used to scale the output-length bounds.
input_token_length = len(
    tokenizer.encode(healthy_text, add_special_tokens=False)
)
# Output-length bounds as fractions of the input length (truncated to int).
min_length_ratio, max_length_ratio = 0.6, 1.1
min_length, max_length = (
    int(input_token_length * ratio)
    for ratio in (min_length_ratio, max_length_ratio)
)
print(f"Počet tokenov vo vstupe: {input_token_length}")
print(f"Minimálna dĺžka výstupu: {min_length}, maximálna: {max_length}")
# ----------------------------------------------------------
# Text-generation configurations
# ----------------------------------------------------------
# Each row is one decoding variant. Successive variants allow longer outputs,
# use a wider beam and a larger no-repeat n-gram window, and raise the length
# penalty. Columns:
# (extra max_length, extra min_length, num_beams, no_repeat_ngram_size,
#  length_penalty)
_config_variants = (
    (0, 0, 5, 4, 0.6),
    (20, 10, 7, 5, 0.8),
    (40, 20, 9, 6, 1.0),
)
generation_configs = [
    dict(
        max_length=max_length + extra_max,
        min_length=min_length + extra_min,
        num_beams=beams,
        no_repeat_ngram_size=ngram_size,
        repetition_penalty=2.2,
        length_penalty=len_penalty,
        early_stopping=True,
    )
    for extra_max, extra_min, beams, ngram_size, len_penalty in _config_variants
]
# ----------------------------------------------------------
# Generate one output per decoding configuration
# ----------------------------------------------------------
generated_texts = []
for i, config in enumerate(generation_configs):
    print(f"\nGenerácia textu č. {i+1}:")
    # Inference only: no gradient bookkeeping is needed.
    with torch.no_grad():
        # Pass the full tokenizer output (input_ids together with the
        # attention mask) instead of input_ids alone. Forward every key of the
        # config dict, so an option added to a config is actually applied
        # rather than silently dropped by hand-written unpacking.
        outputs = model.generate(**inputs, **config)
    # outputs[0] is the sequence returned for our single input text
    # (num_return_sequences is not set here); special tokens are stripped.
    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    generated_texts.append(generated_text)
    print("Vygenerovaný text:")
    print(generated_text)
    print("-" * 80)