Update 'slovak_punction2.py'
This commit is contained in: parent 54e835669c, commit 2ae4f9ed63
@@ -1,42 +1,100 @@
# -*- coding: utf-8 -*-
# slovak_punction2.py: punctuation restoration for Slovak text with SlovakBERT

# module imports
import re
import torch
import nltk
from nltk.tokenize import word_tokenize, sent_tokenize
from torch.utils.data import DataLoader
from transformers import (
    pipeline,
    RobertaTokenizer,
    RobertaForMaskedLM,
    DataCollatorForLanguageModeling,
    AdamW,  # newer transformers releases drop this; torch.optim.AdamW is the drop-in replacement
)

nltk.download('punkt')

# load the pre-trained SlovakBERT model and its tokenizer
tokenizer = RobertaTokenizer.from_pretrained('gerulata/slovakbert')
model = RobertaForMaskedLM.from_pretrained('gerulata/slovakbert')

# fill-mask pipeline over the same model
unmasker = pipeline('fill-mask', model='gerulata/slovakbert')

# punctuation marks the script tries to restore
labels = ['.', '!', ',', ':', '?', '-', ';']
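# illustrative check (not in the original flow, sample sentence made up): the
# pipeline returns a list of candidate dicts sorted by score, each carrying
# 'token_str' and 'score'; restore_punctuation below reads the top 'token_str'
for cand in unmasker("Ako sa máš<mask>")[:3]:
    print(cand['token_str'], round(cand['score'], 3))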
# function to restore punctuation (fill-mask pipeline approach)
def restore_punctuation(text):
    sents = sent_tokenize(text)
    new_text = ""
    for sent in sents:
        # strip any punctuation that is already present
        sent = ''.join(ch for ch in sent if ch not in labels)
        text_word = sent.split()
        words = text_word[:]
        # insert <mask> after each word in turn and let the model decide
        # whether a punctuation mark belongs there
        for i in range(1, len(text_word) + 1):
            text_word.insert(i, '<mask>')
            masked_sent = " ".join(text_word)
            text_with_punc = unmasker(masked_sent)
            # decoding can prefix the token with a space, hence the strip()
            top_token = text_with_punc[0]['token_str'].strip()
            if top_token in labels:
                words[i - 1] = words[i - 1] + top_token
            # drop the mask again before trying the next position
            text_word = words[:]
        new_text += " ".join(words)
    return new_text


# --- data preparation for MLM fine-tuning ---
text = "text pre trenovanie neuronovej siete"

# Example: tokenizing a list of text strings
# (NLTK's punkt has no Slovak model, so 'slovene' is used as the closest match)
texts = sent_tokenize(text, "slovene")
encodings = tokenizer(texts, truncation=True, padding='max_length', max_length=512)


class MLM_Dataset(torch.utils.data.Dataset):
    def __init__(self, encodings):
        self.encodings = encodings

    def __len__(self):
        return len(self.encodings['input_ids'])

    def __getitem__(self, idx):
        return {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}


dataset = MLM_Dataset(encodings)
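# illustrative check: every dataset item is a dict of 512-element tensors
print(len(dataset), dataset[0]['input_ids'].shape)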
# --- MLM fine-tuning ---
# randomly mask 15 % of the tokens; the collator also builds the labels tensor
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer, mlm=True, mlm_probability=0.15
)

dataloader = DataLoader(dataset, batch_size=16, shuffle=True, collate_fn=data_collator)
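# illustrative check: a collated batch holds input_ids, attention_mask and
# labels, each of shape (batch, 512); batch size is capped by the dataset size
sample_batch = next(iter(dataloader))
print({k: tuple(v.shape) for k, v in sample_batch.items()})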
optimizer = AdamW(model.parameters(), lr=5e-5)

# a single pass over the data is enough for a demonstration run
epochs = 1
for epoch in range(epochs):
    model.train()
    for batch in dataloader:
        optimizer.zero_grad()
        outputs = model(**{k: v.to(model.device) for k, v in batch.items()})
        loss = outputs.loss
        loss.backward()
        optimizer.step()

model.save_pretrained('path/to/save/model')
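# the fine-tuned weights can later be loaded back the same way, e.g.
# (sketch, reusing the placeholder directory passed to save_pretrained above):
#   model = RobertaForMaskedLM.from_pretrained('path/to/save/model')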
# function to restore punctuation (model-logits approach)
def restore_pun(text):
    words = nltk.word_tokenize(text)
    for i in range(1, len(words)):
        current = words[i]
        if words[i] not in labels:
            # no punctuation here yet: append a <mask> after the word
            words[i] += " <mask>"
            current_pun = "no"
        else:
            # an existing mark: remember it, then replace it with a <mask>
            current_pun = words[i]
            words[i] = " <mask>"
        x = " ".join(words)

        encoded_input = tokenizer(x, return_tensors='pt')
        with torch.no_grad():
            output = model(**encoded_input)
        mask_token_index = torch.where(encoded_input["input_ids"][0] == tokenizer.mask_token_id)[0]

        # extract the logits for the masked token
        mask_token_logits = output.logits[0, mask_token_index, :]

        # find the token with the highest probability
        predicted_token_id = torch.argmax(mask_token_logits).item()
        # decoding can prefix the token with a space, hence the strip()
        predicted_token = tokenizer.decode([predicted_token_id]).strip()

        if current_pun == "no" and predicted_token in labels:
            words[i] = current + predicted_token
        elif current_pun != "no" and predicted_token in labels:
            words[i] = predicted_token
        else:
            words[i] = current
    out = " ".join(words)
    return out
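# usage sketch with a made-up sentence: the function only inserts or replaces
# single punctuation tokens, so plain unpunctuated text is the expected input
print(restore_pun("dobrý deň ako sa máte"))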
# demo: prompt for text, restore its punctuation and print both versions
input_text = input("Zadajte text na opravu interpunkcie: ")
output_text = restore_punctuation(input_text)

print("Pôvodný text:", input_text)
print("Opravený text:", output_text)

# sample input for the logits-based function
# (named so it does not shadow the built-in input() used above)
sample_text = "ako sa voláš"
print("input : ", sample_text)
print("output :", restore_pun(sample_text))