Compare commits
4 Commits
de517dbe62
...
master
Author | SHA1 | Date | |
---|---|---|---|
40117c648d | |||
eec002d873 | |||
2ae4f9ed63 | |||
54e835669c |
201
slovak_punction2.py
Normal file
201
slovak_punction2.py
Normal file
@ -0,0 +1,201 @@
|
||||
# coding: utf-8
|
||||
|
||||
def convert(text, indices, vals, puns):
|
||||
# Zabezpecenei aby sa text nezmenil vytovrenim noveho zoznamu
|
||||
modified_text = text
|
||||
|
||||
|
||||
for val, i in zip(vals, indices):
|
||||
# Pridanie zodpovedajucej interpunkcie v upravenom texte
|
||||
modified_text.insert(val, puns[i - 1])
|
||||
|
||||
|
||||
return modified_text
|
||||
|
||||
# kniznice
|
||||
from transformers import RobertaTokenizer, RobertaForMaskedLM
|
||||
from transformers import DataCollatorForLanguageModeling
|
||||
#maskovacei modely
|
||||
tokenizer = RobertaTokenizer.from_pretrained('gerulata/slovakbert')
|
||||
model = RobertaForMaskedLM.from_pretrained('gerulata/slovakbert')
|
||||
|
||||
import torch
|
||||
|
||||
import nltk
|
||||
from nltk.tokenize import word_tokenize, sent_tokenize
|
||||
|
||||
# importovanie modulu pre manipuláciu s textom
|
||||
import re
|
||||
|
||||
# Stiahnutie obsahu tokenizerov
|
||||
# nltk.download('punkt')
|
||||
|
||||
# Importovanie kniznic a modulov
|
||||
from transformers import DataCollatorForLanguageModeling, AdamW
|
||||
from torch.utils.data import DataLoader
|
||||
from nltk.tokenize import sent_tokenize
|
||||
|
||||
def fine_tuning(texts, model, tokenizer):
|
||||
# Kontrola textu či je spravna
|
||||
if len(texts) == 0:
|
||||
return model
|
||||
|
||||
# Spracovanie textu
|
||||
def preprocess_for_punctuation(texts):
|
||||
processed_texts = []
|
||||
for text in texts:
|
||||
# Maskovanie interpunkcie pomocou tokenov
|
||||
text = re.sub(r'[.,?!:-]', '[MASK]', text)
|
||||
processed_texts.append(text)
|
||||
return processed_texts
|
||||
|
||||
# Aplikuje spracovanie na vstupne texty
|
||||
texts = preprocess_for_punctuation(texts)
|
||||
|
||||
# Tokenizuje a encoduje spravoané texty
|
||||
encodings = tokenizer(texts, truncation=True, padding='max_length', max_length=512)
|
||||
|
||||
# Definicia vlastneho datasetu
|
||||
class MLM_Dataset(torch.utils.data.Dataset):
|
||||
def __init__(self, encodings):
|
||||
self.encodings = encodings
|
||||
|
||||
def __len__(self):
|
||||
return len(self.encodings['input_ids'])
|
||||
|
||||
def __getitem__(self, idx):
|
||||
return {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
|
||||
|
||||
# Vytvorenie valstneho datasetu
|
||||
dataset = MLM_Dataset(encodings)
|
||||
|
||||
# Vytvorenie dat pre MLM
|
||||
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=True, mlm_probability=0.15)
|
||||
|
||||
# Vytvorenie DataLoader pre trenovanie modelu
|
||||
dataloader = DataLoader(dataset, batch_size=16, shuffle=True, collate_fn=data_collator)
|
||||
|
||||
# Optimalizaotr pre trenovanie
|
||||
optimizer = AdamW(model.parameters(), lr=5e-5)
|
||||
|
||||
# Nastavenie epoch na trenovanie
|
||||
epochs = 1
|
||||
|
||||
print("Zaciatok trenovania modelu...")
|
||||
# Trenovanie
|
||||
for epoch in range(epochs):
|
||||
model.train()
|
||||
for batch in dataloader:
|
||||
# Vynuluje pred spätným prechodom
|
||||
optimizer.zero_grad()
|
||||
# Presunutie vstupov
|
||||
inputs = {k: v.to(model.device) for k, v in batch.items()}
|
||||
outputs = model(**inputs)
|
||||
loss = outputs.loss
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
print("Ucenie dokoncene.")
|
||||
|
||||
# Vratenie sa k modelu
|
||||
return model
|
||||
|
||||
|
||||
|
||||
# Obnovenie interpunkcie
|
||||
def restore_pun(text, model):
|
||||
# Tokenizacia vstupného textu
|
||||
words = nltk.word_tokenize(text)
|
||||
|
||||
# Opakovanie slov
|
||||
for i in range(1, len(words)):
|
||||
current = words[i]
|
||||
|
||||
# Rozpoznáva ci dane slovo ma mat interpunkciu alebo nie
|
||||
if current not in [".", ",", "?", "!" ,":","-"]:
|
||||
words[i] += " <mask>"
|
||||
current_pun = "no"
|
||||
else:
|
||||
current_pun = current
|
||||
words[i] = " <mask>"
|
||||
|
||||
# Spojenie slov do retazca
|
||||
x = " ".join(words)
|
||||
|
||||
# Encodovanie vstupu pomocou tokenizera
|
||||
encoded_input = tokenizer(x, return_tensors='pt')
|
||||
|
||||
# vystup cez encode
|
||||
output = model(**encoded_input)
|
||||
|
||||
# najdenie indexu maskovaneho tokenu vo vstupe
|
||||
mask_token_index = torch.where(encoded_input["input_ids"][0] == tokenizer.mask_token_id)[0]
|
||||
|
||||
mask_token_logits = output.logits[0, mask_token_index, :]
|
||||
|
||||
# Najdenie tokeu s najvecsou pravdepodobnostou
|
||||
predicted_token_id = torch.argmax(mask_token_logits).item()
|
||||
predicted_token = tokenizer.decode([predicted_token_id])
|
||||
|
||||
# Aktualizuje slovo na zaklade tokenu
|
||||
if current_pun == "no" and predicted_token in ['.', ',', '?' , '!',':' ,'-' ]:
|
||||
words[i] = current + predicted_token
|
||||
elif current_pun != "no" and predicted_token in ['.', ',', '?' , '!',':' ,'-' ]:
|
||||
words[i] = predicted_token
|
||||
else:
|
||||
words[i] = current
|
||||
|
||||
# Spojenie slov do reťazca s vysledkom
|
||||
out = " ".join(words)
|
||||
return out
|
||||
|
||||
# Vybranie co chceme s programom robit
|
||||
while True:
|
||||
option = input('1=> Ucenie programu 2=> Oprava interpunkcie v texte 3=> koniec programu ')
|
||||
|
||||
#1 Trenovanie
|
||||
if option == '1':
|
||||
file_path = input('Zadajte subor s datami')
|
||||
|
||||
# importovanie json
|
||||
import json
|
||||
|
||||
# Cita a analyzuje kazdy riadok ako samsotatny json objekt
|
||||
json_objects = []
|
||||
with open(file_path, 'r') as file:
|
||||
for line in file:
|
||||
try:
|
||||
json_object = json.loads(line)
|
||||
json_objects.append(json_object)
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
|
||||
# Definovanie interpunkcie na trenovanie
|
||||
puns = ['.', ',', '?', '!', ':', '-']
|
||||
|
||||
# Spracovanie a ucenie
|
||||
texts = []
|
||||
for i in range(len(json_objects)):
|
||||
indices = [value for index, value in enumerate(json_objects[i]['labels']) if value > 0]
|
||||
val = [index for index, value in enumerate(json_objects[0]['labels']) if value > 0]
|
||||
|
||||
# Uprava textu
|
||||
json_objects[i]['text'] = convert(json_objects[i]['text'], indices, val, puns)
|
||||
|
||||
# Pridanie upraveneho textu d ozoznamu
|
||||
texts.append(" ".join(json_objects[i]['text']))
|
||||
|
||||
# doladovanie modelu
|
||||
model = fine_tuning(texts[:], model, tokenizer)
|
||||
|
||||
#2: Oprava interpunkciet
|
||||
elif option == '2':
|
||||
# Vlozenie textu bez interpunkcie alebo so zlou interpunkciou
|
||||
test = input('Enter your text: ')
|
||||
|
||||
# Vypisanie textu
|
||||
print("Output:", restore_pun(test, model))
|
||||
|
||||
#3: Ukoncenei programu
|
||||
else:
|
||||
break
|
Loading…
Reference in New Issue
Block a user