Update 'slovak_punction2.py'
This commit is contained in:
parent
2ae4f9ed63
commit
eec002d873
@ -12,52 +12,6 @@ import re
|
||||
|
||||
nltk.download('punkt')
|
||||
|
||||
text="text pre trenovanie neuronovej siete"
|
||||
|
||||
# Example: tokenizing a list of text strings
|
||||
texts = sent_tokenize(text , "slovene")
|
||||
encodings = tokenizer(texts, truncation=True, padding='max_length', max_length=512)
|
||||
|
||||
import torch
|
||||
|
||||
class MLM_Dataset(torch.utils.data.Dataset):
|
||||
def __init__(self, encodings):
|
||||
self.encodings = encodings
|
||||
|
||||
def __len__(self):
|
||||
return len(self.encodings['input_ids'])
|
||||
|
||||
def __getitem__(self, idx):
|
||||
return {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
|
||||
|
||||
dataset = MLM_Dataset(encodings)
|
||||
|
||||
from transformers import DataCollatorForLanguageModeling
|
||||
|
||||
data_collator = DataCollatorForLanguageModeling(
|
||||
tokenizer=tokenizer, mlm=True, mlm_probability=0.15
|
||||
)
|
||||
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
dataloader = DataLoader(dataset, batch_size=16, shuffle=True, collate_fn=data_collator)
|
||||
|
||||
from transformers import AdamW
|
||||
|
||||
optimizer = AdamW(model.parameters(), lr=5e-5)
|
||||
|
||||
epochs = 1
|
||||
for epoch in range(epochs):
|
||||
model.train()
|
||||
for batch in dataloader:
|
||||
optimizer.zero_grad()
|
||||
outputs = model(**{k: v.to(model.device) for k, v in batch.items()})
|
||||
loss = outputs.loss
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
model.save_pretrained('path/to/save/model')
|
||||
|
||||
input="ako sa voláš"
|
||||
|
||||
def restore_pun(text):
|
||||
|
Loading…
Reference in New Issue
Block a user