From 2ae4f9ed63345eede8b9b1d37fb4b3f7e23de7bd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adri=C3=A1n=20Remia=C5=A1?= Date: Wed, 6 Dec 2023 12:46:41 +0000 Subject: [PATCH] Update 'slovak_punction2.py' --- slovak_punction2.py | 114 +++++++++++++++++++++++++++++++++----------- 1 file changed, 86 insertions(+), 28 deletions(-) diff --git a/slovak_punction2.py b/slovak_punction2.py index 018f8e0..c5c6d69 100644 --- a/slovak_punction2.py +++ b/slovak_punction2.py @@ -1,42 +1,100 @@ -# import modulov -from transformers import pipeline -from nltk.tokenize import sent_tokenize +# -*- coding: utf-8 -*- + +from transformers import RobertaTokenizer, RobertaForMaskedLM + +tokenizer = RobertaTokenizer.from_pretrained('gerulata/slovakbert') +model = RobertaForMaskedLM.from_pretrained('gerulata/slovakbert') + +import torch import nltk +from nltk.tokenize import word_tokenize ,sent_tokenize +import re nltk.download('punkt') -# funkcia na obnovu interpunkcie -def restore_punctuation(text): - sents = sent_tokenize(text) - new_text = "" - labels = ['.', '!', ',', ':', '?', '-', ";"] +text="text pre trenovanie neuronovej siete" - for sent in sents: - sent = ''.join(ch for ch in sent if ch not in labels) - text_word = sent.split() - words = text_word[:] +# Example: tokenizing a list of text strings +texts = sent_tokenize(text , "slovene") +encodings = tokenizer(texts, truncation=True, padding='max_length', max_length=512) - unmasker = pipeline('fill-mask', model='gerulata/slovakbert') +import torch - for i in range(1, len(text_word) + 1): - text_word.insert(i, '') - sent = " ".join(text_word) - text_with_punc = unmasker(sent) +class MLM_Dataset(torch.utils.data.Dataset): + def __init__(self, encodings): + self.encodings = encodings - if text_with_punc[0]['token_str'] in labels: - words[i - 1] = words[i - 1] + text_with_punc[0]['token_str'] + def __len__(self): + return len(self.encodings['input_ids']) - text_word = words[:] + def __getitem__(self, idx): + return {key: torch.tensor(val[idx]) for key, val in self.encodings.items()} - new_text += " ".join(words) +dataset = MLM_Dataset(encodings) - return new_text +from transformers import DataCollatorForLanguageModeling -# Zadanie textu pre opravu interpunkcie -input_text = input("Zadajte text na opravu interpunkcie: ") -output_text = restore_punctuation(input_text) +data_collator = DataCollatorForLanguageModeling( + tokenizer=tokenizer, mlm=True, mlm_probability=0.15 +) -# Výpis pôvodného a opraveného textu -print("Pôvodný text:", input_text) -print("Opravený text:", output_text) +from torch.utils.data import DataLoader +dataloader = DataLoader(dataset, batch_size=16, shuffle=True, collate_fn=data_collator) + +from transformers import AdamW + +optimizer = AdamW(model.parameters(), lr=5e-5) + +epochs = 1 +for epoch in range(epochs): + model.train() + for batch in dataloader: + optimizer.zero_grad() + outputs = model(**{k: v.to(model.device) for k, v in batch.items()}) + loss = outputs.loss + loss.backward() + optimizer.step() + +model.save_pretrained('path/to/save/model') + +input="ako sa voláš" + +def restore_pun(text): + words=nltk.word_tokenize(text ) + for i in range (1,len(words)): + current=words[i] + if words[i] not in ['.', '!', ',', ':', '?', '-', ";"]: + words[i] +=" " + current_pun="no" + else : + current_pun=words[i] + words[i]=" " + current_pun=words[i] + x=" ".join(words) + + encoded_input = tokenizer(x, return_tensors='pt') + output = model(**encoded_input) + mask_token_index = torch.where(encoded_input["input_ids"][0] == tokenizer.mask_token_id)[0] + + # Extract the logits for the masked token + mask_token_logits = output.logits[0, mask_token_index, :] + + # Find the token with the highest probability + predicted_token_id = torch.argmax(mask_token_logits).item() + predicted_token = tokenizer.decode([predicted_token_id]) + + if current_pun=="no" and predicted_token in ['.', '!', ',', ':', '?', '-', ";"] : + words[i]=current+ predicted_token + elif current_pun!="no" and predicted_token in ['.', '!', ',', ':', '?', '-', ";"] : + words[i]= predicted_token + else : + words[i]=current + out=" ".join(words) + return out + +import nltk +nltk.download('punkt') + +print("input : " , input) +print ("output :" ,restore_pun(input))