# DP2024/slovak_punction2.py (Python, 101 lines, 2.8 KiB)
# Revision history: 2023-11-28 13:31:11 +00:00, 2023-12-06 12:46:41 +00:00
# -*- coding: utf-8 -*-
# Fine-tune SlovakBERT (masked-language modelling) on a small Slovak corpus
# and use it to restore punctuation in Slovak text.
import re

import nltk
import torch
from nltk.tokenize import word_tokenize, sent_tokenize
from transformers import RobertaTokenizer, RobertaForMaskedLM

# NLTK >= 3.9 loads the Punkt sentence/word tokenizer data from the
# 'punkt_tab' resource; older releases use the pickled 'punkt' resource.
# Fetch both so sent_tokenize / word_tokenize work on either version.
nltk.download('punkt', quiet=True)
nltk.download('punkt_tab', quiet=True)

tokenizer = RobertaTokenizer.from_pretrained('gerulata/slovakbert')
model = RobertaForMaskedLM.from_pretrained('gerulata/slovakbert')

# Placeholder training corpus.
text = "text pre trenovanie neuronovej siete"
# Split the corpus into sentences. The Slovene Punkt model is used,
# presumably because NLTK has no Slovak one.
texts = sent_tokenize(text, language="slovene")
# Every sentence is truncated/padded to exactly 512 tokens.
encodings = tokenizer(texts, truncation=True, padding='max_length', max_length=512)
import torch
class MLM_Dataset(torch.utils.data.Dataset):
    """Map-style dataset over a tokenizer batch encoding.

    Item ``idx`` is a dict that maps every key of the encoding (for example
    ``input_ids`` or ``attention_mask``) to a tensor built from that key's
    ``idx``-th row.
    """

    def __init__(self, encodings):
        # Mapping of key -> list of per-example sequences.
        self.encodings = encodings

    def __len__(self):
        # The number of examples is the number of ``input_ids`` rows.
        rows = self.encodings['input_ids']
        return len(rows)

    def __getitem__(self, idx):
        item = {}
        for name, rows in self.encodings.items():
            item[name] = torch.tensor(rows[idx])
        return item
dataset = MLM_Dataset(encodings)

from transformers import DataCollatorForLanguageModeling
from torch.utils.data import DataLoader

# Builds MLM batches: each non-special token is selected for prediction with
# probability 0.15, and a matching ``labels`` tensor is produced.
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer, mlm=True, mlm_probability=0.15
)

dataloader = DataLoader(dataset, batch_size=16, shuffle=True, collate_fn=data_collator)

# transformers.AdamW is deprecated and dropped in newer transformers
# releases; torch.optim.AdamW replaces it. eps and weight_decay are pinned
# to the old transformers defaults (1e-6, 0.0) because torch's defaults
# differ (1e-8, 0.01).
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5, eps=1e-6, weight_decay=0.0)

epochs = 1
for epoch in range(epochs):
    model.train()
    total_loss = 0.0
    for batch in dataloader:
        optimizer.zero_grad()
        outputs = model(**{k: v.to(model.device) for k, v in batch.items()})
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    # max(..., 1) guards against an empty dataloader.
    mean_loss = total_loss / max(len(dataloader), 1)
    print(f"epoch {epoch + 1}/{epochs}: mean loss {mean_loss:.4f}")

# Training is done: disable dropout so later inference is deterministic.
model.eval()

SAVE_DIR = 'path/to/save/model'  # placeholder output directory
model.save_pretrained(SAVE_DIR)
# Save the tokenizer alongside the weights so the directory can be reloaded
# on its own with from_pretrained(SAVE_DIR).
tokenizer.save_pretrained(SAVE_DIR)
# Demo sentence. The name ``input`` (which shadows the builtin) is kept
# because the print calls at the end of the script read it.
input = "ako sa voláš"

# Tokens treated as punctuation, both when reading the input and when
# accepting a model prediction.
_PUNCTUATION = frozenset(['.', '!', ',', ':', '?', '-', ';'])


def _predict_mask(masked_text):
    """Return the model's top token for the first ``<mask>`` in *masked_text*.

    The decoded token is whitespace-stripped. Returns None when no mask
    survives tokenization (e.g. it was cut off by truncation).
    """
    encoded = tokenizer(masked_text, return_tensors='pt', truncation=True)
    logits = model(**encoded).logits
    positions = torch.where(encoded["input_ids"][0] == tokenizer.mask_token_id)[0]
    if positions.numel() == 0:
        return None
    # Take the argmax of the single mask row; argmax over a 2-D slice would
    # return a flattened index.
    token_id = int(logits[0, positions[0]].argmax())
    # BPE tokens can carry a leading-space marker that decodes to " ";
    # strip it so e.g. " ," still matches ",".
    return tokenizer.decode([token_id]).strip()


def restore_pun(text):
    """Restore punctuation in *text* using the masked-language model.

    The text is split with ``nltk.word_tokenize``. Tokens are visited left
    to right, and a ``<mask>`` is put in the punctuation slot of each one:

    * after a word, ``<mask>`` is inserted. If the model's top prediction is
      punctuation, that mark is appended to the word.
    * an existing punctuation token is itself replaced by ``<mask>``. It
      becomes the predicted mark if that is punctuation, otherwise it is
      kept unchanged.

    A word directly followed by a punctuation token is skipped, because that
    slot is already filled and gets its own turn. Without the skip the word
    could end up with two marks.

    Returns the resulting tokens joined by single spaces ("" for empty input).
    The model is put in eval mode with gradients disabled for the duration of
    the call, and its previous train/eval mode is restored afterwards.
    """
    words = nltk.word_tokenize(text)
    was_training = model.training
    model.eval()  # disable dropout so predictions are deterministic
    try:
        with torch.no_grad():
            for i in range(len(words)):  # include the first token as well
                current = words[i]
                is_punct = current in _PUNCTUATION
                next_is_punct = i + 1 < len(words) and words[i + 1] in _PUNCTUATION
                if not is_punct and next_is_punct:
                    continue
                words[i] = "<mask>" if is_punct else current + " <mask>"
                predicted = _predict_mask(" ".join(words))
                # Restore the token, then apply the prediction if it is
                # punctuation.
                words[i] = current
                if predicted in _PUNCTUATION:
                    words[i] = predicted if is_punct else current + predicted
    finally:
        model.train(was_training)
    return " ".join(words)
import nltk

# word_tokenize (used by restore_pun) needs the Punkt data. NLTK >= 3.9
# reads it from 'punkt_tab' rather than the pickled 'punkt' resource, so
# fetch both.
nltk.download('punkt', quiet=True)
nltk.download('punkt_tab', quiet=True)

# Run the punctuation restorer on the demo sentence ``input``.
print("input : ", input)
print("output :", restore_pun(input))