# Punctuation restoration for Slovak text using SlovakBERT (gerulata/slovakbert).
# -*- coding: utf-8 -*-
"""Restore missing punctuation in Slovak text with a masked language model."""

import re

import nltk
import torch
from nltk.tokenize import sent_tokenize, word_tokenize
from transformers import RobertaForMaskedLM, RobertaTokenizer

# Pretrained Slovak RoBERTa and its tokenizer (fetched from the HF hub on
# first use, cached afterwards).
tokenizer = RobertaTokenizer.from_pretrained('gerulata/slovakbert')
model = RobertaForMaskedLM.from_pretrained('gerulata/slovakbert')

# Tokenizer data required by nltk.word_tokenize.
nltk.download('punkt')

# NOTE(review): this shadows the builtin `input`; the name is kept because the
# driver code at the bottom of the file refers to it.
input = "ako sa voláš"  # sample sentence without punctuation
|
|
|
|
# Punctuation marks the model is allowed to predict.
_PUNCT = ['.', '!', ',', ':', '?', '-', ';']


def restore_pun(text):
    """Restore punctuation in ``text`` using the SlovakBERT masked-LM.

    For every token after the first, a ``<mask>`` is inserted either after
    the word (to test whether punctuation should follow it) or in place of
    an existing punctuation mark (to re-predict it).  The mask is filled
    with the model's single most likely token; if that token is not a
    punctuation mark, the original token is kept unchanged.

    Args:
        text: Input sentence, possibly missing punctuation.

    Returns:
        The processed text with tokens re-joined by single spaces.
    """
    words = nltk.word_tokenize(text)
    for i in range(1, len(words)):
        current = words[i]
        if current not in _PUNCT:
            # Ask the model whether punctuation should follow this word.
            words[i] = current + " <mask>"
            current_pun = "no"
        else:
            # Existing punctuation: remember it and let the model re-predict.
            # (Bug fix: the original re-assigned current_pun = words[i] AFTER
            # overwriting words[i] with " <mask>", losing the saved mark.)
            current_pun = current
            words[i] = " <mask>"
        masked_text = " ".join(words)

        encoded_input = tokenizer(masked_text, return_tensors='pt')
        # Inference only — skip gradient bookkeeping.
        with torch.no_grad():
            output = model(**encoded_input)

        # Position of the (single) <mask> token in the encoded input.
        mask_token_index = torch.where(
            encoded_input["input_ids"][0] == tokenizer.mask_token_id
        )[0]

        # Logits for the masked position only.
        mask_token_logits = output.logits[0, mask_token_index, :]

        # Token with the highest probability at the mask.
        predicted_token_id = torch.argmax(mask_token_logits).item()
        predicted_token = tokenizer.decode([predicted_token_id])

        if current_pun == "no" and predicted_token in _PUNCT:
            # New punctuation predicted after a plain word: append it.
            words[i] = current + predicted_token
        elif current_pun != "no" and predicted_token in _PUNCT:
            # Re-predicted punctuation replaces the old mark.
            words[i] = predicted_token
        else:
            # No punctuation predicted: restore the original token.
            words[i] = current
    return " ".join(words)
|
|
|
|
# nltk is already imported above; the re-import is a harmless no-op kept so
# this section reads standalone.  The redundant second nltk.download('punkt')
# call was removed — the data is fetched once during setup.
import nltk

# Run the demo on the sample sentence defined at the top of the file.
print("input : ", input)
print("output :", restore_pun(input))
|