diff --git a/slovak_punction2.py b/slovak_punction2.py index 4643ea1..47f1e9a 100644 --- a/slovak_punction2.py +++ b/slovak_punction2.py @@ -1,54 +1,201 @@ -# -*- coding: utf-8 -*- +# coding: utf-8 +def convert(text, indices, vals, puns): + # Zabezpecenei aby sa text nezmenil vytovrenim noveho zoznamu + modified_text = text + + + for val, i in zip(vals, indices): + # Pridanie zodpovedajucej interpunkcie v upravenom texte + modified_text.insert(val, puns[i - 1]) + + + return modified_text + +# kniznice from transformers import RobertaTokenizer, RobertaForMaskedLM - +from transformers import DataCollatorForLanguageModeling +#maskovacei modely tokenizer = RobertaTokenizer.from_pretrained('gerulata/slovakbert') model = RobertaForMaskedLM.from_pretrained('gerulata/slovakbert') import torch + import nltk -from nltk.tokenize import word_tokenize ,sent_tokenize +from nltk.tokenize import word_tokenize, sent_tokenize + +# importovanie modulu pre manipuláciu s textom import re -nltk.download('punkt') +# Stiahnutie obsahu tokenizerov +# nltk.download('punkt') -input="ako sa voláš" +# Importovanie kniznic a modulov +from transformers import DataCollatorForLanguageModeling, AdamW +from torch.utils.data import DataLoader +from nltk.tokenize import sent_tokenize -def restore_pun(text): - words=nltk.word_tokenize(text ) - for i in range (1,len(words)): - current=words[i] - if words[i] not in ['.', '!', ',', ':', '?', '-', ";"]: - words[i] +=" " - current_pun="no" - else : - current_pun=words[i] - words[i]=" " - current_pun=words[i] - x=" ".join(words) +def fine_tuning(texts, model, tokenizer): + # Kontrola textu či je spravna + if len(texts) == 0: + return model - encoded_input = tokenizer(x, return_tensors='pt') - output = model(**encoded_input) - mask_token_index = torch.where(encoded_input["input_ids"][0] == tokenizer.mask_token_id)[0] + # Spracovanie textu + def preprocess_for_punctuation(texts): + processed_texts = [] + for text in texts: + # Maskovanie interpunkcie pomocou tokenov + text = re.sub(r'[.,?!:-]', '[MASK]', text) + processed_texts.append(text) + return processed_texts - # Extract the logits for the masked token - mask_token_logits = output.logits[0, mask_token_index, :] + # Aplikuje spracovanie na vstupne texty + texts = preprocess_for_punctuation(texts) - # Find the token with the highest probability - predicted_token_id = torch.argmax(mask_token_logits).item() - predicted_token = tokenizer.decode([predicted_token_id]) + # Tokenizuje a encoduje spravoané texty + encodings = tokenizer(texts, truncation=True, padding='max_length', max_length=512) - if current_pun=="no" and predicted_token in ['.', '!', ',', ':', '?', '-', ";"] : - words[i]=current+ predicted_token - elif current_pun!="no" and predicted_token in ['.', '!', ',', ':', '?', '-', ";"] : - words[i]= predicted_token - else : - words[i]=current - out=" ".join(words) - return out + # Definicia vlastneho datasetu + class MLM_Dataset(torch.utils.data.Dataset): + def __init__(self, encodings): + self.encodings = encodings -import nltk -nltk.download('punkt') + def __len__(self): + return len(self.encodings['input_ids']) -print("input : " , input) -print ("output :" ,restore_pun(input)) + def __getitem__(self, idx): + return {key: torch.tensor(val[idx]) for key, val in self.encodings.items()} + + # Vytvorenie valstneho datasetu + dataset = MLM_Dataset(encodings) + + # Vytvorenie dat pre MLM + data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=True, mlm_probability=0.15) + + # Vytvorenie DataLoader pre trenovanie modelu + dataloader = DataLoader(dataset, batch_size=16, shuffle=True, collate_fn=data_collator) + + # Optimalizaotr pre trenovanie + optimizer = AdamW(model.parameters(), lr=5e-5) + + # Nastavenie epoch na trenovanie + epochs = 1 + + print("Zaciatok trenovania modelu...") + # Trenovanie + for epoch in range(epochs): + model.train() + for batch in dataloader: + # Vynuluje pred spätným prechodom + optimizer.zero_grad() + # Presunutie vstupov + inputs = {k: v.to(model.device) for k, v in batch.items()} + outputs = model(**inputs) + loss = outputs.loss + loss.backward() + optimizer.step() + + print("Ucenie dokoncene.") + + # Vratenie sa k modelu + return model + + + +# Obnovenie interpunkcie +def restore_pun(text, model): + # Tokenizacia vstupného textu + words = nltk.word_tokenize(text) + + # Opakovanie slov + for i in range(1, len(words)): + current = words[i] + + # Rozpoznáva ci dane slovo ma mat interpunkciu alebo nie + if current not in [".", ",", "?", "!" ,":","-"]: + words[i] += " " + current_pun = "no" + else: + current_pun = current + words[i] = " " + + # Spojenie slov do retazca + x = " ".join(words) + + # Encodovanie vstupu pomocou tokenizera + encoded_input = tokenizer(x, return_tensors='pt') + + # vystup cez encode + output = model(**encoded_input) + + # najdenie indexu maskovaneho tokenu vo vstupe + mask_token_index = torch.where(encoded_input["input_ids"][0] == tokenizer.mask_token_id)[0] + + mask_token_logits = output.logits[0, mask_token_index, :] + + # Najdenie tokeu s najvecsou pravdepodobnostou + predicted_token_id = torch.argmax(mask_token_logits).item() + predicted_token = tokenizer.decode([predicted_token_id]) + + # Aktualizuje slovo na zaklade tokenu + if current_pun == "no" and predicted_token in ['.', ',', '?' , '!',':' ,'-' ]: + words[i] = current + predicted_token + elif current_pun != "no" and predicted_token in ['.', ',', '?' , '!',':' ,'-' ]: + words[i] = predicted_token + else: + words[i] = current + + # Spojenie slov do reťazca s vysledkom + out = " ".join(words) + return out + +# Vybranie co chceme s programom robit +while True: + option = input('1=> Ucenie programu 2=> Oprava interpunkcie v texte 3=> koniec programu ') + + #1 Trenovanie + if option == '1': + file_path = input('Zadajte subor s datami') + + # importovanie json + import json + + # Cita a analyzuje kazdy riadok ako samsotatny json objekt + json_objects = [] + with open(file_path, 'r') as file: + for line in file: + try: + json_object = json.loads(line) + json_objects.append(json_object) + except json.JSONDecodeError: + continue + + # Definovanie interpunkcie na trenovanie + puns = ['.', ',', '?', '!', ':', '-'] + + # Spracovanie a ucenie + texts = [] + for i in range(len(json_objects)): + indices = [value for index, value in enumerate(json_objects[i]['labels']) if value > 0] + val = [index for index, value in enumerate(json_objects[0]['labels']) if value > 0] + + # Uprava textu + json_objects[i]['text'] = convert(json_objects[i]['text'], indices, val, puns) + + # Pridanie upraveneho textu d ozoznamu + texts.append(" ".join(json_objects[i]['text'])) + + # doladovanie modelu + model = fine_tuning(texts[:], model, tokenizer) + + #2: Oprava interpunkciet + elif option == '2': + # Vlozenie textu bez interpunkcie alebo so zlou interpunkciou + test = input('Enter your text: ') + + # Vypisanie textu + print("Output:", restore_pun(test, model)) + + #3: Ukoncenei programu + else: + break \ No newline at end of file