import os

import torch
from transformers import BartTokenizer, BartForConditionalGeneration

# ----------------------------------------------------------
# Load the trained BART model and tokenizer
# ----------------------------------------------------------

# Path to the directory containing the trained model
script_dir = os.path.dirname(os.path.abspath(__file__))
model_path = os.path.abspath(
    os.path.join(script_dir, "..", "BART_eng", "trained_bart_model_CLEAN_15")
)

# Load the model and tokenizer from disk
tokenizer = BartTokenizer.from_pretrained(model_path)
model = BartForConditionalGeneration.from_pretrained(model_path)

# Select a device (GPU if available, otherwise CPU)
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

print("Model loaded successfully and is ready for use.")

# ----------------------------------------------------------
# Input text from a healthy speaker to be transformed
# ----------------------------------------------------------
healthy_text = """there's a boy reaching into the cookie jar while standing on a stool and he already took one cookie and is handing it to the girl next to him and girl seems to be telling him to be quiet so their mom doesn't notice and the sink is overflowing with water and nobody is paying attention meanwhile, their mother is turned away, busy with something else
"""

# Tokenize the input text
inputs = tokenizer(healthy_text, return_tensors="pt").to(device)
input_token_length = len(tokenizer.encode(healthy_text, add_special_tokens=False))

# Compute dynamic output-length limits relative to the input length
min_length_ratio = 0.6
max_length_ratio = 1.1
min_length = int(input_token_length * min_length_ratio)
max_length = int(input_token_length * max_length_ratio)

print(f"Number of input tokens: {input_token_length}")
print(f"Minimum output length: {min_length}, maximum: {max_length}")

# ----------------------------------------------------------
# Text generation configurations
# ----------------------------------------------------------
generation_configs = [
    {
        "max_length": max_length,
        "min_length": min_length,
        "num_beams": 5,
        "no_repeat_ngram_size": 4,
        "repetition_penalty": 2.2,
        "length_penalty": 0.6,
        "early_stopping": True,
    },
    {
        "max_length": max_length + 20,
        "min_length": min_length + 10,
        "num_beams": 7,
        "no_repeat_ngram_size": 5,
        "repetition_penalty": 2.2,
        "length_penalty": 0.8,
        "early_stopping": True,
    },
    {
        "max_length": max_length + 40,
        "min_length": min_length + 20,
        "num_beams": 9,
        "no_repeat_ngram_size": 6,
        "repetition_penalty": 2.2,
        "length_penalty": 1.0,
        "early_stopping": True,
    },
]

# ----------------------------------------------------------
# Generate outputs for each configuration
# ----------------------------------------------------------
generated_texts = []

for i, config in enumerate(generation_configs):
    print(f"\nGeneration run no. {i + 1}:")

    # Generate without computing gradients
    with torch.no_grad():
        outputs = model.generate(
            inputs["input_ids"],
            # Pass the attention mask explicitly to avoid padding ambiguity
            attention_mask=inputs["attention_mask"],
            max_length=config["max_length"],
            min_length=config["min_length"],
            num_beams=config["num_beams"],
            no_repeat_ngram_size=config["no_repeat_ngram_size"],
            repetition_penalty=config["repetition_penalty"],
            length_penalty=config["length_penalty"],
            early_stopping=config["early_stopping"],
        )

    # Decode the output token IDs back into text
    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    generated_texts.append(generated_text)

    print("Generated text:")
    print(generated_text)
    print("-" * 80)