Update 'slovak_punction2.py'

This commit is contained in:
Adrián Remiaš 2024-01-03 08:56:32 +00:00
parent 2ae4f9ed63
commit eec002d873

View File

@ -12,52 +12,6 @@ import re
nltk.download('punkt') nltk.download('punkt')
# Sample corpus used to train the network (runtime string kept verbatim).
text = "text pre trenovanie neuronovej siete"

# Split the corpus into sentences, then tokenize each sentence to a fixed
# length of 512 tokens: longer sentences are truncated, shorter ones padded.
# NOTE(review): the "slovene" Punkt model is applied here although the file
# appears to target Slovak text — confirm this choice is intentional.
texts = sent_tokenize(text, language="slovene")
encodings = tokenizer(
    texts,
    truncation=True,
    padding='max_length',
    max_length=512,
)
import torch
class MLM_Dataset(torch.utils.data.Dataset):
    """Map-style dataset over pre-tokenized encodings.

    ``encodings`` is a mapping from field name to a per-example sequence of
    values; it must contain an ``'input_ids'`` entry, which determines the
    dataset length.
    """

    def __init__(self, encodings):
        self.encodings = encodings

    def __len__(self):
        """Return the number of examples, taken from the 'input_ids' field."""
        input_ids = self.encodings['input_ids']
        return len(input_ids)

    def __getitem__(self, idx):
        """Return example ``idx`` as a dict of tensors, one per field."""
        example = {}
        for field, values in self.encodings.items():
            example[field] = torch.tensor(values[idx])
        return example
# --- Masked-language-model fine-tuning -------------------------------------
# Wrap the tokenized sentences so a DataLoader can index them.
dataset = MLM_Dataset(encodings)

from transformers import DataCollatorForLanguageModeling

# The collator batches examples and applies random masking for the MLM
# objective; 15% of the tokens are selected for masking.
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer, mlm=True, mlm_probability=0.15
)

from torch.utils.data import DataLoader

dataloader = DataLoader(
    dataset, batch_size=16, shuffle=True, collate_fn=data_collator
)

# `transformers.AdamW` is deprecated and has been removed from recent
# transformers releases; use the PyTorch implementation instead.
# eps and weight_decay are set explicitly to the defaults the transformers
# optimizer used (1e-6 and 0.0), so the update rule stays the same.
from torch.optim import AdamW

optimizer = AdamW(model.parameters(), lr=5e-5, eps=1e-6, weight_decay=0.0)

epochs = 1
# The model is not moved during training, so resolve its device once.
device = model.device
for epoch in range(epochs):
    model.train()
    for batch in dataloader:
        optimizer.zero_grad()
        # Move every tensor of the batch to the device the model lives on.
        outputs = model(**{k: v.to(device) for k, v in batch.items()})
        loss = outputs.loss
        loss.backward()
        optimizer.step()

# Persist the fine-tuned weights once training has finished.
model.save_pretrained('path/to/save/model')
input="ako sa voláš" input="ako sa voláš"
def restore_pun(text): def restore_pun(text):