# DP2024/slovak_punction2.py
# -*- coding: utf-8 -*-
"""Restore punctuation in Slovak text with SlovakBERT's masked-LM head."""
import nltk
import torch
from transformers import RobertaTokenizer, RobertaForMaskedLM

# SlovakBERT is a RoBERTa-style masked language model pretrained on Slovak;
# its tokenizer uses <mask> as the mask token.
tokenizer = RobertaTokenizer.from_pretrained('gerulata/slovakbert')
model = RobertaForMaskedLM.from_pretrained('gerulata/slovakbert')

# Punkt tokenizer models are needed by nltk.word_tokenize below.
nltk.download('punkt')

# Punctuation marks the model is allowed to restore.
PUNCTUATION = ['.', '!', ',', ':', '?', '-', ';']

input_text = "ako sa voláš"
def restore_pun(text):
    """Restore punctuation in `text` by querying the masked LM at each position."""
    words = nltk.word_tokenize(text)
    for i in range(1, len(words)):
        current = words[i]
        if words[i] not in PUNCTUATION:
            # Plain word: ask the model what should come right after it.
            words[i] += " <mask>"
            current_pun = "no"
        else:
            # Existing mark: remember it and let the model re-predict it.
            current_pun = words[i]
            words[i] = "<mask>"
x=" ".join(words)
encoded_input = tokenizer(x, return_tensors='pt')
output = model(**encoded_input)
mask_token_index = torch.where(encoded_input["input_ids"][0] == tokenizer.mask_token_id)[0]
# Extract the logits for the masked token
mask_token_logits = output.logits[0, mask_token_index, :]
# Find the token with the highest probability
predicted_token_id = torch.argmax(mask_token_logits).item()
predicted_token = tokenizer.decode([predicted_token_id])
        if current_pun == "no" and predicted_token in PUNCTUATION:
            # The model proposed punctuation after a plain word: append it.
            words[i] = current + predicted_token
        elif current_pun != "no" and predicted_token in PUNCTUATION:
            # The model re-predicted (or corrected) an existing mark.
            words[i] = predicted_token
        else:
            # No punctuation predicted: keep the original token.
            words[i] = current
    out = " ".join(words)
    return out
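
# As a quick sanity check of the masked-LM step in isolation, the same kind
# of prediction can be reproduced with the transformers fill-mask pipeline
# (a sketch independent of restore_pun; uncomment to try it):
#
#   from transformers import pipeline
#   fill_mask = pipeline('fill-mask', model='gerulata/slovakbert')
#   for candidate in fill_mask("ako sa voláš <mask>"):
#       print(candidate['token_str'], round(candidate['score'], 3))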
print("input : " , input)
print ("output :" ,restore_pun(input))