import os import torch from transformers import AutoTokenizer, AutoModelForSeq2SeqLM # ---------------------------------------------------------- # Načítanie predtrénovaného modelu a tokenizátora z disku # ---------------------------------------------------------- # Cesta k priečinku s trénovanou slovenskou T5 verziou script_dir = os.path.dirname(os.path.abspath(__file__)) model_path = os.path.join(script_dir, "trained_slovak_t5_cpu") # Kontrola existencie modelu na disku if not os.path.exists(model_path): raise FileNotFoundError(f"Priečinok s modelom nebol nájdený: {model_path}") # Načítanie modelu a tokenizátora z daného priečinka tokenizer = AutoTokenizer.from_pretrained(model_path) model = AutoModelForSeq2SeqLM.from_pretrained(model_path) # Výber zariadenia (GPU ak je dostupné, inak CPU) device = "cuda" if torch.cuda.is_available() else "cpu" model.to(device) print("Model bol úspešne načítaný a je pripravený na použitie.") # ---------------------------------------------------------- # Vstupný text zdravého človeka na transformáciu # ---------------------------------------------------------- healthy_text = """no mamička ktorá má všetko sama na starosť. varí drží plačúce dieťa na rukách. telefonuje čiže asi aj pracuje alebo vybavuje alebo ska a nie rozmýšľa tam. nakupuje topánky cez telefón. na polici nám sedí mačka ktorá vysypala. neviem či sú to klinčeky alebo niečo. padá nám to do na počítač. máme tam rozložené jedlo. mamička samozrejme nestíha. vypráža vajíčka robí hemendex. áno lebo letí tam šunka aj vajíčka tam sú. na polici sú tri knihy. poobede čiže nestíhajú obed. už je po dvanástej hodine. a dieťa je hladné strašne plače. není naň čas. mamička je proste v jednom kole stále.""" # Tokenizácia vstupného textu pre vstup do modelu inputs = tokenizer(healthy_text, return_tensors="pt").to(device) input_token_length = len(tokenizer.encode(healthy_text, add_special_tokens=False)) # Výpočet dynamických limitov pre generovanie výstupu min_length_ratio = 0.6 max_new_tokens_ratio = 0.5 min_length = int(input_token_length * min_length_ratio) max_new_tokens = min(75, int(input_token_length * max_new_tokens_ratio)) print(f"Počet tokenov vo vstupe: {input_token_length}") print( f"Minimálna dĺžka výstupu: {min_length}, maximálny počet nových tokenov: {max_new_tokens}" ) # ---------------------------------------------------------- # Konfigurácie generovania textu (rôzne nastavenia) # ---------------------------------------------------------- generation_configs = [ { "max_new_tokens": max_new_tokens, "do_sample": False, "num_beams": 5, "length_penalty": 1.5, "no_repeat_ngram_size": 4, "repetition_penalty": 3.5, "early_stopping": True, }, { "max_new_tokens": max_new_tokens, "do_sample": False, "num_beams": 7, "length_penalty": 1.5, "no_repeat_ngram_size": 4, "repetition_penalty": 3.5, "early_stopping": True, }, { "max_new_tokens": max_new_tokens, "do_sample": False, "num_beams": 8, "length_penalty": 1.5, "no_repeat_ngram_size": 4, "repetition_penalty": 3.5, "early_stopping": True, }, ] # ---------------------------------------------------------- # Vlastné generovanie textu podľa konfigurácií # ---------------------------------------------------------- generated_texts = [] for i, config in enumerate(generation_configs): print(f"\nGenerácia textu č. {i+1}:") # Parametre pre funkciu generate() gen_kwargs = { "min_length": min(min_length, config["max_new_tokens"] - 10), "max_new_tokens": config["max_new_tokens"], "repetition_penalty": config["repetition_penalty"], "num_beams": config.get("num_beams", 1), "length_penalty": config.get("length_penalty", 1.0), "early_stopping": config.get("early_stopping", True), } # Nepovinný parameter: zákaz opakovania n-gramov if "no_repeat_ngram_size" in config: gen_kwargs["no_repeat_ngram_size"] = config["no_repeat_ngram_size"] # Sampling (ak by bol použitý, tu je vypnutý) if config.get("do_sample", False): gen_kwargs["do_sample"] = True if "top_p" in config: gen_kwargs["top_p"] = config["top_p"] if "temperature" in config: gen_kwargs["temperature"] = config["temperature"] # Generovanie bez výpočtu gradientov with torch.no_grad(): outputs = model.generate(inputs["input_ids"], **gen_kwargs) # Dekódovanie výstupných tokenov na čitateľný text generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True) generated_texts.append(generated_text) # Výpis výsledku print("Vygenerovaný text:") print(generated_text) print("-" * 80)