# -*- coding: utf-8 -*-
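# Fine-tunes the SlovakBERT masked language model (gerulata/slovakbert) on a small
# text sample, then uses it to restore punctuation in an unpunctuated Slovak sentence.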
from transformers import RobertaTokenizer, RobertaForMaskedLM

tokenizer = RobertaTokenizer.from_pretrained('gerulata/slovakbert')
model = RobertaForMaskedLM.from_pretrained('gerulata/slovakbert')

import torch
import nltk
from nltk.tokenize import word_tokenize, sent_tokenize
import re

nltk.download('punkt')
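# Placeholder training text ("text for training the neural network"); a real run
# would fine-tune on a much larger Slovak corpus.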
text = "text pre trenovanie neuronovej siete"

# Example: tokenizing a list of text strings.
# NLTK's punkt has no Slovak model, so "slovene" is used as the closest available language.
texts = sent_tokenize(text, "slovene")
encodings = tokenizer(texts, truncation=True, padding='max_length', max_length=512)
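# Minimal torch Dataset that wraps the tokenizer output so the DataLoader can index it.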
class MLM_Dataset(torch.utils.data.Dataset):
    def __init__(self, encodings):
        self.encodings = encodings

    def __len__(self):
        return len(self.encodings['input_ids'])

    def __getitem__(self, idx):
        return {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}


dataset = MLM_Dataset(encodings)
from transformers import DataCollatorForLanguageModeling
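# The collator masks 15% of tokens in each batch on the fly and builds the matching
# labels needed for the masked-language-modelling loss.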
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer, mlm=True, mlm_probability=0.15
)
from torch.utils.data import DataLoader
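# Shuffled batches of 16 encoded sentences; masking is applied per batch by the collator.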
dataloader = DataLoader(dataset, batch_size=16, shuffle=True, collate_fn=data_collator)
# transformers.AdamW is deprecated in favour of torch.optim.AdamW, which is used here instead.
from torch.optim import AdamW
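# AdamW with lr=5e-5, a learning rate commonly used for fine-tuning BERT-style models.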
optimizer = AdamW(model.parameters(), lr=5e-5)
epochs = 1
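# Standard fine-tuning loop: each batch already contains the masked input_ids,
# attention_mask and labels from the collator, so outputs.loss is the MLM loss.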
for epoch in range(epochs):
    model.train()
    for batch in dataloader:
        optimizer.zero_grad()
        outputs = model(**{k: v.to(model.device) for k, v in batch.items()})
        loss = outputs.loss
        loss.backward()
        optimizer.step()

model.save_pretrained('path/to/save/model')
input_text = "ako sa voláš"  # "what is your name", without punctuation
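# restore_pun: for each position, place a <mask> after the word (or in place of the
# existing punctuation mark) and keep the model's prediction only if it is punctuation.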
def restore_pun(text):
    words = nltk.word_tokenize(text)
    for i in range(1, len(words)):
        current = words[i]
        if words[i] not in ['.', '!', ',', ':', '?', '-', ';']:
            # No punctuation at this position yet: append a mask after the word.
            words[i] += " <mask>"
            current_pun = "no"
        else:
            # Existing punctuation: remember it and mask it for re-prediction.
            current_pun = words[i]
            words[i] = " <mask>"
        x = " ".join(words)

        encoded_input = tokenizer(x, return_tensors='pt')
        output = model(**encoded_input)
        mask_token_index = torch.where(encoded_input["input_ids"][0] == tokenizer.mask_token_id)[0]

        # Extract the logits for the masked token
        mask_token_logits = output.logits[0, mask_token_index, :]

        # Find the token with the highest probability
        predicted_token_id = torch.argmax(mask_token_logits).item()
        predicted_token = tokenizer.decode([predicted_token_id])

        # Keep the prediction only if it is a punctuation mark; otherwise restore the original word.
        if current_pun == "no" and predicted_token in ['.', '!', ',', ':', '?', '-', ';']:
            words[i] = current + predicted_token
        elif current_pun != "no" and predicted_token in ['.', '!', ',', ':', '?', '-', ';']:
            words[i] = predicted_token
        else:
            words[i] = current
    out = " ".join(words)
    return out
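# Demo: print the raw input and the sentence with restored punctuation.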
print("input : ", input_text)
print("output :", restore_pun(input_text))