Update 'slovak_punction2.py'
This commit is contained in:
parent
2ae4f9ed63
commit
eec002d873
@ -12,52 +12,6 @@ import re
|
|||||||
|
|
||||||
nltk.download('punkt')
|
nltk.download('punkt')
|
||||||
|
|
||||||
text="text pre trenovanie neuronovej siete"
|
|
||||||
|
|
||||||
# Example: tokenizing a list of text strings
|
|
||||||
texts = sent_tokenize(text , "slovene")
|
|
||||||
encodings = tokenizer(texts, truncation=True, padding='max_length', max_length=512)
|
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
class MLM_Dataset(torch.utils.data.Dataset):
|
|
||||||
def __init__(self, encodings):
|
|
||||||
self.encodings = encodings
|
|
||||||
|
|
||||||
def __len__(self):
|
|
||||||
return len(self.encodings['input_ids'])
|
|
||||||
|
|
||||||
def __getitem__(self, idx):
|
|
||||||
return {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
|
|
||||||
|
|
||||||
dataset = MLM_Dataset(encodings)
|
|
||||||
|
|
||||||
from transformers import DataCollatorForLanguageModeling
|
|
||||||
|
|
||||||
data_collator = DataCollatorForLanguageModeling(
|
|
||||||
tokenizer=tokenizer, mlm=True, mlm_probability=0.15
|
|
||||||
)
|
|
||||||
|
|
||||||
from torch.utils.data import DataLoader
|
|
||||||
|
|
||||||
dataloader = DataLoader(dataset, batch_size=16, shuffle=True, collate_fn=data_collator)
|
|
||||||
|
|
||||||
from transformers import AdamW
|
|
||||||
|
|
||||||
optimizer = AdamW(model.parameters(), lr=5e-5)
|
|
||||||
|
|
||||||
epochs = 1
|
|
||||||
for epoch in range(epochs):
|
|
||||||
model.train()
|
|
||||||
for batch in dataloader:
|
|
||||||
optimizer.zero_grad()
|
|
||||||
outputs = model(**{k: v.to(model.device) for k, v in batch.items()})
|
|
||||||
loss = outputs.loss
|
|
||||||
loss.backward()
|
|
||||||
optimizer.step()
|
|
||||||
|
|
||||||
model.save_pretrained('path/to/save/model')
|
|
||||||
|
|
||||||
input="ako sa voláš"
|
input="ako sa voláš"
|
||||||
|
|
||||||
def restore_pun(text):
|
def restore_pun(text):
|
||||||
|
Loading…
Reference in New Issue
Block a user