Compare commits

...

4 Commits

201
slovak_punction2.py Normal file
View File

@ -0,0 +1,201 @@
# coding: utf-8
def convert(text, indices, vals, puns):
# Zabezpecenei aby sa text nezmenil vytovrenim noveho zoznamu
modified_text = text
for val, i in zip(vals, indices):
# Pridanie zodpovedajucej interpunkcie v upravenom texte
modified_text.insert(val, puns[i - 1])
return modified_text
# kniznice
from transformers import RobertaTokenizer, RobertaForMaskedLM
from transformers import DataCollatorForLanguageModeling
#maskovacei modely
tokenizer = RobertaTokenizer.from_pretrained('gerulata/slovakbert')
model = RobertaForMaskedLM.from_pretrained('gerulata/slovakbert')
import torch
import nltk
from nltk.tokenize import word_tokenize, sent_tokenize
# importovanie modulu pre manipuláciu s textom
import re
# Stiahnutie obsahu tokenizerov
# nltk.download('punkt')
# Importovanie kniznic a modulov
from transformers import DataCollatorForLanguageModeling, AdamW
from torch.utils.data import DataLoader
from nltk.tokenize import sent_tokenize
def fine_tuning(texts, model, tokenizer):
# Kontrola textu či je spravna
if len(texts) == 0:
return model
# Spracovanie textu
def preprocess_for_punctuation(texts):
processed_texts = []
for text in texts:
# Maskovanie interpunkcie pomocou tokenov
text = re.sub(r'[.,?!:-]', '[MASK]', text)
processed_texts.append(text)
return processed_texts
# Aplikuje spracovanie na vstupne texty
texts = preprocess_for_punctuation(texts)
# Tokenizuje a encoduje spravoané texty
encodings = tokenizer(texts, truncation=True, padding='max_length', max_length=512)
# Definicia vlastneho datasetu
class MLM_Dataset(torch.utils.data.Dataset):
def __init__(self, encodings):
self.encodings = encodings
def __len__(self):
return len(self.encodings['input_ids'])
def __getitem__(self, idx):
return {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
# Vytvorenie valstneho datasetu
dataset = MLM_Dataset(encodings)
# Vytvorenie dat pre MLM
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=True, mlm_probability=0.15)
# Vytvorenie DataLoader pre trenovanie modelu
dataloader = DataLoader(dataset, batch_size=16, shuffle=True, collate_fn=data_collator)
# Optimalizaotr pre trenovanie
optimizer = AdamW(model.parameters(), lr=5e-5)
# Nastavenie epoch na trenovanie
epochs = 1
print("Zaciatok trenovania modelu...")
# Trenovanie
for epoch in range(epochs):
model.train()
for batch in dataloader:
# Vynuluje pred spätným prechodom
optimizer.zero_grad()
# Presunutie vstupov
inputs = {k: v.to(model.device) for k, v in batch.items()}
outputs = model(**inputs)
loss = outputs.loss
loss.backward()
optimizer.step()
print("Ucenie dokoncene.")
# Vratenie sa k modelu
return model
# Obnovenie interpunkcie
def restore_pun(text, model):
# Tokenizacia vstupného textu
words = nltk.word_tokenize(text)
# Opakovanie slov
for i in range(1, len(words)):
current = words[i]
# Rozpoznáva ci dane slovo ma mat interpunkciu alebo nie
if current not in [".", ",", "?", "!" ,":","-"]:
words[i] += " <mask>"
current_pun = "no"
else:
current_pun = current
words[i] = " <mask>"
# Spojenie slov do retazca
x = " ".join(words)
# Encodovanie vstupu pomocou tokenizera
encoded_input = tokenizer(x, return_tensors='pt')
# vystup cez encode
output = model(**encoded_input)
# najdenie indexu maskovaneho tokenu vo vstupe
mask_token_index = torch.where(encoded_input["input_ids"][0] == tokenizer.mask_token_id)[0]
mask_token_logits = output.logits[0, mask_token_index, :]
# Najdenie tokeu s najvecsou pravdepodobnostou
predicted_token_id = torch.argmax(mask_token_logits).item()
predicted_token = tokenizer.decode([predicted_token_id])
# Aktualizuje slovo na zaklade tokenu
if current_pun == "no" and predicted_token in ['.', ',', '?' , '!',':' ,'-' ]:
words[i] = current + predicted_token
elif current_pun != "no" and predicted_token in ['.', ',', '?' , '!',':' ,'-' ]:
words[i] = predicted_token
else:
words[i] = current
# Spojenie slov do reťazca s vysledkom
out = " ".join(words)
return out
# Vybranie co chceme s programom robit
while True:
option = input('1=> Ucenie programu 2=> Oprava interpunkcie v texte 3=> koniec programu ')
#1 Trenovanie
if option == '1':
file_path = input('Zadajte subor s datami')
# importovanie json
import json
# Cita a analyzuje kazdy riadok ako samsotatny json objekt
json_objects = []
with open(file_path, 'r') as file:
for line in file:
try:
json_object = json.loads(line)
json_objects.append(json_object)
except json.JSONDecodeError:
continue
# Definovanie interpunkcie na trenovanie
puns = ['.', ',', '?', '!', ':', '-']
# Spracovanie a ucenie
texts = []
for i in range(len(json_objects)):
indices = [value for index, value in enumerate(json_objects[i]['labels']) if value > 0]
val = [index for index, value in enumerate(json_objects[0]['labels']) if value > 0]
# Uprava textu
json_objects[i]['text'] = convert(json_objects[i]['text'], indices, val, puns)
# Pridanie upraveneho textu d ozoznamu
texts.append(" ".join(json_objects[i]['text']))
# doladovanie modelu
model = fine_tuning(texts[:], model, tokenizer)
#2: Oprava interpunkciet
elif option == '2':
# Vlozenie textu bez interpunkcie alebo so zlou interpunkciou
test = input('Enter your text: ')
# Vypisanie textu
print("Output:", restore_pun(test, model))
#3: Ukoncenei programu
else:
break