Update 'slovak_punction2.py'

This commit is contained in:
Adrián Remiaš 2023-12-06 12:46:41 +00:00
parent 54e835669c
commit 2ae4f9ed63

View File

@ -1,42 +1,100 @@
# import modulov # -*- coding: utf-8 -*-
from transformers import pipeline
from nltk.tokenize import sent_tokenize from transformers import RobertaTokenizer, RobertaForMaskedLM
tokenizer = RobertaTokenizer.from_pretrained('gerulata/slovakbert')
model = RobertaForMaskedLM.from_pretrained('gerulata/slovakbert')
import torch
import nltk import nltk
from nltk.tokenize import word_tokenize ,sent_tokenize
import re
nltk.download('punkt') nltk.download('punkt')
# funkcia na obnovu interpunkcie text="text pre trenovanie neuronovej siete"
def restore_punctuation(text):
sents = sent_tokenize(text)
new_text = ""
labels = ['.', '!', ',', ':', '?', '-', ";"]
for sent in sents: # Example: tokenizing a list of text strings
sent = ''.join(ch for ch in sent if ch not in labels) texts = sent_tokenize(text , "slovene")
text_word = sent.split() encodings = tokenizer(texts, truncation=True, padding='max_length', max_length=512)
words = text_word[:]
unmasker = pipeline('fill-mask', model='gerulata/slovakbert') import torch
for i in range(1, len(text_word) + 1): class MLM_Dataset(torch.utils.data.Dataset):
text_word.insert(i, '<mask>') def __init__(self, encodings):
sent = " ".join(text_word) self.encodings = encodings
text_with_punc = unmasker(sent)
if text_with_punc[0]['token_str'] in labels: def __len__(self):
words[i - 1] = words[i - 1] + text_with_punc[0]['token_str'] return len(self.encodings['input_ids'])
text_word = words[:] def __getitem__(self, idx):
return {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
new_text += " ".join(words) dataset = MLM_Dataset(encodings)
return new_text from transformers import DataCollatorForLanguageModeling
# Zadanie textu pre opravu interpunkcie data_collator = DataCollatorForLanguageModeling(
input_text = input("Zadajte text na opravu interpunkcie: ") tokenizer=tokenizer, mlm=True, mlm_probability=0.15
output_text = restore_punctuation(input_text) )
# Výpis pôvodného a opraveného textu from torch.utils.data import DataLoader
print("Pôvodný text:", input_text)
print("Opravený text:", output_text)
# Fine-tune the masked language model on the tokenized sentences.
# The collator applies random masking to each batch and supplies the labels
# from which the model computes its loss.
dataloader = DataLoader(dataset, batch_size=16, shuffle=True, collate_fn=data_collator)

# torch.optim.AdamW replaces transformers.AdamW, which is deprecated in favour
# of the PyTorch implementation. eps and weight_decay are pinned to the values
# the transformers optimizer used by default so the update rule is unchanged.
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5, eps=1e-6, weight_decay=0.0)

epochs = 1
for epoch in range(epochs):
    model.train()
    for batch in dataloader:
        optimizer.zero_grad()
        # Move every tensor of the batch to the device the model lives on.
        outputs = model(**{k: v.to(model.device) for k, v in batch.items()})
        loss = outputs.loss
        loss.backward()
        optimizer.step()

# Leave training mode so dropout is disabled for any inference that follows.
model.eval()
# NOTE(review): placeholder path - replace with a real output directory.
model.save_pretrained('path/to/save/model')
# Sample sentence fed to restore_pun() by the print calls at the end of the file.
# NOTE(review): this name shadows the built-in input(); renaming it requires
# updating those print calls in the same change.
input="ako sa voláš"
# Punctuation marks the model is allowed to insert or substitute.
_PUNCTUATION = frozenset({'.', '!', ',', ':', '?', '-', ';'})


def restore_pun(text):
    """Restore punctuation in *text* with the module-level masked language model.

    The text is split with ``nltk.word_tokenize``. For each token in turn a
    single ``<mask>`` is placed after the token (or in place of it when the
    token is already a punctuation mark), the sequence is run through
    ``model``, and the top prediction for the mask is accepted only if it is
    one of the marks in ``_PUNCTUATION``.

    Args:
        text: Input string whose punctuation should be restored.

    Returns:
        The tokens joined by single spaces, with accepted punctuation appended
        to (or substituted for) the corresponding token. An empty input yields
        an empty string.
    """
    words = nltk.word_tokenize(text)

    # The training loop leaves the model in train mode; switch to eval so
    # dropout does not make predictions non-deterministic.
    model.eval()

    # Start at index 0 so the position after the first word is probed too.
    for i, current in enumerate(words):
        is_punct = current in _PUNCTUATION
        # Existing punctuation is replaced by the mask; a word gets the mask
        # appended after it.
        words[i] = " <mask>" if is_punct else current + " <mask>"

        encoded_input = tokenizer(" ".join(words), return_tensors='pt').to(model.device)
        with torch.no_grad():  # inference only, no gradients needed
            output = model(**encoded_input)

        mask_positions = torch.where(
            encoded_input["input_ids"][0] == tokenizer.mask_token_id
        )[0]
        # Only one mask is inserted per pass; use the first mask position so
        # the argmax is always taken over a single vocabulary row.
        mask_logits = output.logits[0, mask_positions[0], :]
        predicted_token_id = torch.argmax(mask_logits).item()
        # Strip whitespace: a decoded token may carry a leading space, which
        # would otherwise never match a bare punctuation mark.
        predicted_token = tokenizer.decode([predicted_token_id]).strip()

        if predicted_token in _PUNCTUATION:
            words[i] = predicted_token if is_punct else current + predicted_token
        else:
            # Prediction is not punctuation: restore the token unchanged.
            words[i] = current

    return " ".join(words)
# Demo: run restore_pun() on the sample sentence and print input and output.
# NOTE(review): nltk is imported and 'punkt' is downloaded a second time here;
# both calls also appear near the top of the file, so these two lines look redundant.
import nltk
nltk.download('punkt')
print("input : " , input)
print ("output :" ,restore_pun(input))