Update 'slovak_punction2.py'
This commit is contained in:
parent 54e835669c
commit 2ae4f9ed63
@@ -1,42 +1,100 @@
-# import modules
-from transformers import pipeline
-from nltk.tokenize import sent_tokenize
-import nltk
-from nltk.tokenize import word_tokenize, sent_tokenize
-import re
-
-nltk.download('punkt')
-
-# function to restore punctuation
-def restore_punctuation(text):
-    sents = sent_tokenize(text)
-    new_text = ""
-    labels = ['.', '!', ',', ':', '?', '-', ';']
-
-    for sent in sents:
-        # strip the existing punctuation, then re-predict it word by word
-        sent = ''.join(ch for ch in sent if ch not in labels)
-        text_word = sent.split()
-        words = text_word[:]
-
-        unmasker = pipeline('fill-mask', model='gerulata/slovakbert')
-
-        for i in range(1, len(text_word) + 1):
-            text_word.insert(i, '<mask>')
-            sent = " ".join(text_word)
-            text_with_punc = unmasker(sent)
-
-            if text_with_punc[0]['token_str'] in labels:
-                words[i - 1] = words[i - 1] + text_with_punc[0]['token_str']
-
-            text_word = words[:]
-
-        new_text += " ".join(words)
-
-    return new_text
-
-# prompt for the text whose punctuation should be corrected
-input_text = input("Zadajte text na opravu interpunkcie: ")
-output_text = restore_punctuation(input_text)
-
-# print the original and the corrected text
-print("Pôvodný text:", input_text)
-print("Opravený text:", output_text)
+# -*- coding: utf-8 -*-
+from transformers import RobertaTokenizer, RobertaForMaskedLM
+import torch
+import nltk
+from nltk.tokenize import sent_tokenize
+
+nltk.download('punkt')
+
+tokenizer = RobertaTokenizer.from_pretrained('gerulata/slovakbert')
+model = RobertaForMaskedLM.from_pretrained('gerulata/slovakbert')
+
+# small corpus used for masked-language-model fine-tuning
+text = "text pre trenovanie neuronovej siete"
+
+# split the corpus into sentences; punkt ships no Slovak model,
+# so the Slovene model is used as the closest available substitute
+texts = sent_tokenize(text, "slovene")
+encodings = tokenizer(texts, truncation=True, padding='max_length', max_length=512)
+
+
+class MLM_Dataset(torch.utils.data.Dataset):
+    """Wraps the tokenized sentences so the DataLoader can index them."""
+
+    def __init__(self, encodings):
+        self.encodings = encodings
+
+    def __len__(self):
+        return len(self.encodings['input_ids'])
+
+    def __getitem__(self, idx):
+        return {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
+
+
+dataset = MLM_Dataset(encodings)
+
+from transformers import DataCollatorForLanguageModeling
+
+# dynamically mask 15 % of the tokens in every batch
+data_collator = DataCollatorForLanguageModeling(
+    tokenizer=tokenizer, mlm=True, mlm_probability=0.15
+)
+
+from torch.utils.data import DataLoader
+
+dataloader = DataLoader(dataset, batch_size=16, shuffle=True, collate_fn=data_collator)
+
+from transformers import AdamW  # torch.optim.AdamW is the non-deprecated equivalent
+
+optimizer = AdamW(model.parameters(), lr=5e-5)
+
+# one epoch of masked-LM fine-tuning
+epochs = 1
+for epoch in range(epochs):
+    model.train()
+    for batch in dataloader:
+        optimizer.zero_grad()
+        outputs = model(**{k: v.to(model.device) for k, v in batch.items()})
+        loss = outputs.loss
+        loss.backward()
+        optimizer.step()
+
+model.save_pretrained('path/to/save/model')
+
+
+input_text = "ako sa voláš"
+
+
+def restore_pun(text):
+    """Insert a <mask> after each word and keep the prediction if it is punctuation."""
+    labels = ['.', '!', ',', ':', '?', '-', ';']
+    words = nltk.word_tokenize(text)
+    model.eval()
+
+    for i in range(1, len(words)):
+        current = words[i]
+        if words[i] not in labels:
+            # ask the model what should follow this word
+            words[i] += " <mask>"
+            current_pun = "no"
+        else:
+            # an existing punctuation mark: let the model re-predict it
+            current_pun = words[i]
+            words[i] = "<mask>"
+        x = " ".join(words)
+
+        encoded_input = tokenizer(x, return_tensors='pt')
+        output = model(**encoded_input)
+        mask_token_index = torch.where(encoded_input["input_ids"][0] == tokenizer.mask_token_id)[0]
+
+        # extract the logits for the masked position
+        mask_token_logits = output.logits[0, mask_token_index, :]
+
+        # take the most probable token; strip a possible leading BPE space
+        # so it can be compared against the bare punctuation marks
+        predicted_token_id = torch.argmax(mask_token_logits).item()
+        predicted_token = tokenizer.decode([predicted_token_id]).strip()
+
+        if current_pun == "no" and predicted_token in labels:
+            words[i] = current + predicted_token
+        elif current_pun != "no" and predicted_token in labels:
+            words[i] = predicted_token
+        else:
+            words[i] = current
+
+    return " ".join(words)
+
+
+print("input : ", input_text)
+print("output :", restore_pun(input_text))
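If the weights written by `model.save_pretrained('path/to/save/model')` need to be reused in a later session, they can be reloaded the same way the base checkpoint is loaded at the top of the script. A minimal sketch, assuming the placeholder save path from the script and that it runs in the same module where `restore_pun` is defined:

```python
from transformers import RobertaTokenizer, RobertaForMaskedLM

# The tokenizer is unchanged by fine-tuning, so it still comes from the hub;
# the model weights come from the placeholder directory used above.
tokenizer = RobertaTokenizer.from_pretrained('gerulata/slovakbert')
model = RobertaForMaskedLM.from_pretrained('path/to/save/model')
model.eval()

print(restore_pun("ako sa voláš"))
```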