Update 'slovak_punction2.py'
parent eec002d873
commit 40117c648d
@@ -1,54 +1,201 @@
# coding: utf-8

def convert(text, indices, vals, puns):
    # Create a new list so the original text is not modified
    modified_text = list(text)

    # Add the corresponding punctuation to the modified text
    for val, i in zip(vals, indices):
        modified_text.insert(val, puns[i - 1])

    return modified_text
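# Illustrative example with made-up label data (not from the training set):
#   convert(['ako', 'sa', 'voláš'], indices=[3], vals=[3], puns=['.', ',', '?', '!', ':', '-'])
# inserts puns[3 - 1] == '?' at position 3 and returns ['ako', 'sa', 'voláš', '?'].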

# Libraries and modules
import re  # text manipulation
import json
import torch
import nltk
from nltk.tokenize import word_tokenize, sent_tokenize
from transformers import RobertaTokenizer, RobertaForMaskedLM
from transformers import DataCollatorForLanguageModeling, AdamW
from torch.utils.data import DataLoader

# Download the NLTK tokenizer data (only needed on the first run)
# nltk.download('punkt')

# Masked language model and its tokenizer
tokenizer = RobertaTokenizer.from_pretrained('gerulata/slovakbert')
model = RobertaForMaskedLM.from_pretrained('gerulata/slovakbert')

def fine_tuning(texts, model, tokenizer):
    # Nothing to train on: return the model unchanged
    if len(texts) == 0:
        return model

    # Text preprocessing
    def preprocess_for_punctuation(texts):
        processed_texts = []
        for text in texts:
            # Mask the punctuation with the tokenizer's mask token
            # (SlovakBERT is a RoBERTa model, so its mask token is '<mask>';
            # the literal string '[MASK]' would be tokenized as plain text)
            text = re.sub(r'[.,?!:-]', tokenizer.mask_token, text)
            processed_texts.append(text)
        return processed_texts
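    # For example (illustrative only): 'Ahoj, ako sa máš?' would become
    # 'Ahoj<mask> ako sa máš<mask>' after the substitution above.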

    # Apply the preprocessing to the input texts
    texts = preprocess_for_punctuation(texts)

    # Tokenize and encode the processed texts
    encodings = tokenizer(texts, truncation=True, padding='max_length', max_length=512)

    # Definition of a custom dataset over the encodings
    class MLM_Dataset(torch.utils.data.Dataset):
        def __init__(self, encodings):
            self.encodings = encodings

        def __len__(self):
            return len(self.encodings['input_ids'])

        def __getitem__(self, idx):
            return {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
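        # __getitem__ turns the tokenizer's plain Python lists into tensors,
        # yielding one {'input_ids': ..., 'attention_mask': ...} dict per example.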

    # Create the custom dataset
    dataset = MLM_Dataset(encodings)

    # Data collator for masked language modelling
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=True, mlm_probability=0.15)

    # DataLoader for batching during training
    dataloader = DataLoader(dataset, batch_size=16, shuffle=True, collate_fn=data_collator)
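    # Note: with mlm=True the collator additionally masks a random 15% of the
    # tokens in every batch and builds the matching `labels` tensor, so the
    # loss below is computed on those randomly masked positions.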

    # Optimizer for training
    optimizer = AdamW(model.parameters(), lr=5e-5)

    # Number of training epochs
    epochs = 1

    print("Starting model training...")

    # Training loop
    for epoch in range(epochs):
        model.train()
        for batch in dataloader:
            # Zero the gradients before the backward pass
            optimizer.zero_grad()

            # Move the inputs to the model's device
            inputs = {k: v.to(model.device) for k, v in batch.items()}

            outputs = model(**inputs)
            loss = outputs.loss
            loss.backward()
            optimizer.step()

    print("Training finished.")

    # Return the fine-tuned model
    return model
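# Optional sketch (not part of the original script): the fine-tuned weights can
# be persisted with the standard Hugging Face API, so training need not be
# repeated on every run. The directory name here is made up:
#   model.save_pretrained('slovakbert-punct')
#   tokenizer.save_pretrained('slovakbert-punct')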

# Punctuation restoration
def restore_pun(text, model):
    # Tokenize the input text
    words = nltk.word_tokenize(text)

    # Walk over the words one by one
    for i in range(1, len(words)):
        current = words[i]

        # Decide whether this position already carries punctuation
        if current not in [".", ",", "?", "!", ":", "-"]:
            words[i] += " <mask>"
            current_pun = "no"
        else:
            current_pun = current
            words[i] = " <mask>"

        # Join the words back into a string
        x = " ".join(words)

        # Encode the input with the tokenizer
        encoded_input = tokenizer(x, return_tensors='pt')

        # Run the masked language model
        output = model(**encoded_input)

        # Find the index of the masked token in the input
        mask_token_index = torch.where(encoded_input["input_ids"][0] == tokenizer.mask_token_id)[0]

        # Extract the logits for the masked token
        mask_token_logits = output.logits[0, mask_token_index, :]

        # Find the token with the highest probability
        predicted_token_id = torch.argmax(mask_token_logits).item()
        predicted_token = tokenizer.decode([predicted_token_id])

        # Update the word based on the predicted token
        if current_pun == "no" and predicted_token in ['.', ',', '?', '!', ':', '-']:
            words[i] = current + predicted_token
        elif current_pun != "no" and predicted_token in ['.', ',', '?', '!', ':', '-']:
            words[i] = predicted_token
        else:
            words[i] = current

    # Join the words into the resulting string
    out = " ".join(words)
    return out
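# Illustrative call, using the sample sentence from the original script:
#   print("Output:", restore_pun("ako sa voláš", model))
# which should ideally append the missing question mark.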

# Choose what the program should do
while True:
    option = input('1 => train the model  2 => fix punctuation in a text  3 => quit ')

    # 1: Training
    if option == '1':
        file_path = input('Enter the path to the data file: ')

        # Read and parse every line as a separate JSON object
        json_objects = []
        with open(file_path, 'r') as file:
            for line in file:
                try:
                    json_object = json.loads(line)
                    json_objects.append(json_object)
                except json.JSONDecodeError:
                    continue

        # Punctuation marks used for training
        puns = ['.', ',', '?', '!', ':', '-']
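        # Assumed line format of the data file (field names taken from the code
        # below; the concrete values are made up for illustration):
        #   {"text": ["ako", "sa", "voláš"], "labels": [0, 0, 0, 3]}
        # where a label value v > 0 denotes the punctuation mark puns[v - 1].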

        # Processing and training
        texts = []
        for i in range(len(json_objects)):
            # Label values > 0 pick the punctuation mark; their positions give the insertion points
            indices = [value for index, value in enumerate(json_objects[i]['labels']) if value > 0]
            val = [index for index, value in enumerate(json_objects[i]['labels']) if value > 0]

            # Rebuild the text with its punctuation
            json_objects[i]['text'] = convert(json_objects[i]['text'], indices, val, puns)

            # Add the rebuilt text to the list
            texts.append(" ".join(json_objects[i]['text']))

        # Fine-tune the model
        model = fine_tuning(texts, model, tokenizer)

    # 2: Punctuation restoration
    elif option == '2':
        # Text without punctuation, or with wrong punctuation
        test = input('Enter your text: ')

        # Print the restored text
        print("Output:", restore_pun(test, model))

    # 3: Quit the program
    else:
        break