# coding: utf-8

# Libraries and modules
import json
import re  # regular expressions for text manipulation

import torch
import nltk
from nltk.tokenize import word_tokenize, sent_tokenize
from torch.utils.data import DataLoader
from transformers import RobertaTokenizer, RobertaForMaskedLM
from transformers import DataCollatorForLanguageModeling, AdamW

# Download the tokenizer data (run once)
# nltk.download('punkt')

# Masked language model and tokenizer
tokenizer = RobertaTokenizer.from_pretrained('gerulata/slovakbert')
model = RobertaForMaskedLM.from_pretrained('gerulata/slovakbert')


def convert(text, indices, vals, puns):
    # Copy the token list so the original text is not modified
    modified_text = list(text)
    for val, i in zip(vals, indices):
        # Insert the corresponding punctuation mark into the modified text
        modified_text.insert(val, puns[i - 1])
    return modified_text


def fine_tuning(texts, model, tokenizer):
    # Check whether there is any text to train on
    if len(texts) == 0:
        return model

    # Text preprocessing
    def preprocess_for_punctuation(texts):
        processed_texts = []
        for text in texts:
            # Mask the punctuation with the tokenizer's mask token
            # (RoBERTa uses '<mask>', not BERT's '[MASK]')
            text = re.sub(r'[.,?!:-]', tokenizer.mask_token, text)
            processed_texts.append(text)
        return processed_texts

    # Apply the preprocessing to the input texts
    texts = preprocess_for_punctuation(texts)

    # Tokenize and encode the preprocessed texts
    encodings = tokenizer(texts, truncation=True, padding='max_length', max_length=512)

    # Custom dataset definition
    class MLM_Dataset(torch.utils.data.Dataset):
        def __init__(self, encodings):
            self.encodings = encodings

        def __len__(self):
            return len(self.encodings['input_ids'])

        def __getitem__(self, idx):
            return {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}

    # Create the custom dataset
    dataset = MLM_Dataset(encodings)

    # Data collator for masked language modelling
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=True, mlm_probability=0.15)

    # DataLoader for training the model
    dataloader = DataLoader(dataset, batch_size=16, shuffle=True, collate_fn=data_collator)

    # Optimizer for training
    optimizer = AdamW(model.parameters(), lr=5e-5)

    # Number of training epochs
    epochs = 1

    print("Starting model training...")

    # Training loop
    for epoch in range(epochs):
        model.train()
        for batch in dataloader:
            # Reset the gradients before the backward pass
            optimizer.zero_grad()
            # Move the inputs to the model's device
            inputs = {k: v.to(model.device) for k, v in batch.items()}
            outputs = model(**inputs)
            loss = outputs.loss
            loss.backward()
            optimizer.step()

    print("Training finished.")

    # Return the fine-tuned model
    return model
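
# Illustrative sketch (toy values, not taken from any dataset): convert() inserts
# puns[label - 1] at the requested position in a copy of the token list, e.g.
#
#   convert(['Ahoj', 'ako', 'sa', 'mas'], [1], [4], ['.', ',', '?', '!', ':', '-'])
#   # -> ['Ahoj', 'ako', 'sa', 'mas', '.']
#
# fine_tuning() then expects plain strings, which is why the training branch of
# the main loop below joins the rebuilt token lists with spaces before calling it.
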
,":","-"]: words[i] += " " current_pun = "no" else: current_pun = current words[i] = " " # Spojenie slov do retazca x = " ".join(words) # Encodovanie vstupu pomocou tokenizera encoded_input = tokenizer(x, return_tensors='pt') # vystup cez encode output = model(**encoded_input) # najdenie indexu maskovaneho tokenu vo vstupe mask_token_index = torch.where(encoded_input["input_ids"][0] == tokenizer.mask_token_id)[0] mask_token_logits = output.logits[0, mask_token_index, :] # Najdenie tokeu s najvecsou pravdepodobnostou predicted_token_id = torch.argmax(mask_token_logits).item() predicted_token = tokenizer.decode([predicted_token_id]) # Aktualizuje slovo na zaklade tokenu if current_pun == "no" and predicted_token in ['.', ',', '?' , '!',':' ,'-' ]: words[i] = current + predicted_token elif current_pun != "no" and predicted_token in ['.', ',', '?' , '!',':' ,'-' ]: words[i] = predicted_token else: words[i] = current # Spojenie slov do reťazca s vysledkom out = " ".join(words) return out # Vybranie co chceme s programom robit while True: option = input('1=> Ucenie programu 2=> Oprava interpunkcie v texte 3=> koniec programu ') #1 Trenovanie if option == '1': file_path = input('Zadajte subor s datami') # importovanie json import json # Cita a analyzuje kazdy riadok ako samsotatny json objekt json_objects = [] with open(file_path, 'r') as file: for line in file: try: json_object = json.loads(line) json_objects.append(json_object) except json.JSONDecodeError: continue # Definovanie interpunkcie na trenovanie puns = ['.', ',', '?', '!', ':', '-'] # Spracovanie a ucenie texts = [] for i in range(len(json_objects)): indices = [value for index, value in enumerate(json_objects[i]['labels']) if value > 0] val = [index for index, value in enumerate(json_objects[0]['labels']) if value > 0] # Uprava textu json_objects[i]['text'] = convert(json_objects[i]['text'], indices, val, puns) # Pridanie upraveneho textu d ozoznamu texts.append(" ".join(json_objects[i]['text'])) # doladovanie modelu model = fine_tuning(texts[:], model, tokenizer) #2: Oprava interpunkciet elif option == '2': # Vlozenie textu bez interpunkcie alebo so zlou interpunkciou test = input('Enter your text: ') # Vypisanie textu print("Output:", restore_pun(test, model)) #3: Ukoncenei programu else: break