Update 'slovak_punction2.py'

This commit is contained in:
Adrián Remiaš 2024-01-22 09:02:22 +00:00
parent eec002d873
commit 40117c648d


@ -1,54 +1,201 @@
# -*- coding: utf-8 -*-
def convert(text, indices, vals, puns):
    # Work on a copy so the caller's token list is not mutated
    modified_text = list(text)
    # Insert from the highest position down so earlier insertions do not
    # shift the remaining ones (assumes positions refer to the original list)
    for val, i in sorted(zip(vals, indices), reverse=True):
        # Insert the punctuation mark encoded by the label value
        modified_text.insert(val, puns[i - 1])
    return modified_text
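
# Quick sanity check (hypothetical data): with puns = ['.', ',', '?', '!', ':', '-'],
# convert(['ako', 'sa', 'voláš'], [3], [3], puns) inserts puns[3 - 1] == '?'
# at position 3 and returns ['ako', 'sa', 'voláš', '?'].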

# Libraries
from transformers import RobertaTokenizer, RobertaForMaskedLM
from transformers import DataCollatorForLanguageModeling, AdamW
from torch.utils.data import DataLoader
import torch
import nltk
from nltk.tokenize import word_tokenize, sent_tokenize
# Module for text manipulation
import re

# Download the NLTK tokenizer data (uncomment on the first run)
# nltk.download('punkt')

# Masked language model and its tokenizer
tokenizer = RobertaTokenizer.from_pretrained('gerulata/slovakbert')
model = RobertaForMaskedLM.from_pretrained('gerulata/slovakbert')

def fine_tuning(texts, model, tokenizer):
    # Nothing to train on: return the model unchanged
    if len(texts) == 0:
        return model
    # Text preprocessing
    def preprocess_for_punctuation(texts):
        processed_texts = []
        for text in texts:
            # Hide the punctuation behind the model's own mask token
            # (<mask> for SlovakBERT/RoBERTa, not the literal '[MASK]')
            text = re.sub(r'[.,?!:-]', tokenizer.mask_token, text)
            processed_texts.append(text)
        return processed_texts
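    # E.g. "Ako sa voláš? Dobre." -> "Ako sa voláš<mask> Dobre<mask>"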
    # Apply the preprocessing to the input texts
    texts = preprocess_for_punctuation(texts)
    # Tokenize and encode the preprocessed texts
    encodings = tokenizer(texts, truncation=True, padding='max_length', max_length=512)
    # Custom dataset over the encodings
    class MLM_Dataset(torch.utils.data.Dataset):
        def __init__(self, encodings):
            self.encodings = encodings

        def __len__(self):
            return len(self.encodings['input_ids'])

        def __getitem__(self, idx):
            return {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
    # Build the dataset
    dataset = MLM_Dataset(encodings)
    # Data collator for masked language modelling
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=True, mlm_probability=0.15)
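    # The collator additionally masks 15% of the tokens at random in every
    # batch and builds the matching `labels` tensor, so `outputs.loss` below
    # is the standard MLM cross-entropy loss.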
    # DataLoader for training
    dataloader = DataLoader(dataset, batch_size=16, shuffle=True, collate_fn=data_collator)
    # Optimizer for fine-tuning (torch.optim.AdamW has the same interface)
    optimizer = AdamW(model.parameters(), lr=5e-5)
    # Number of training epochs
    epochs = 1
    print("Starting model training...")
    # Training loop
    for epoch in range(epochs):
        model.train()
        for batch in dataloader:
            # Reset the gradients before the backward pass
            optimizer.zero_grad()
            # Move the inputs to the model's device
            inputs = {k: v.to(model.device) for k, v in batch.items()}
            outputs = model(**inputs)
            loss = outputs.loss
            loss.backward()
            optimizer.step()
    print("Training finished.")
    # Hand the fine-tuned model back
    return model
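
# A minimal usage sketch (the training sentences are illustrative):
# model = fine_tuning(["Ako sa voláš?", "Dnes je pekne."], model, tokenizer)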

# Punctuation restoration
def restore_pun(text, model):
    # Tokenize the input text
    words = nltk.word_tokenize(text)
    # Walk over the words one by one
    for i in range(1, len(words)):
        current = words[i]
        # Decide whether the word should get punctuation appended,
        # or is itself a punctuation mark to be re-predicted
        if current not in [".", ",", "?", "!", ":", "-"]:
            words[i] += " <mask>"
            current_pun = "no"
        else:
            current_pun = current
            words[i] = " <mask>"
        # Join the words back into a single string
        x = " ".join(words)
        # Encode the input with the tokenizer
        encoded_input = tokenizer(x, return_tensors='pt')
        # Run the model on the encoded input
        output = model(**encoded_input)
        # Find the index of the masked token in the input
        mask_token_index = torch.where(encoded_input["input_ids"][0] == tokenizer.mask_token_id)[0]
        # Extract the logits for the masked token
        mask_token_logits = output.logits[0, mask_token_index, :]
        # Pick the token with the highest probability
        predicted_token_id = torch.argmax(mask_token_logits).item()
        # Strip the leading space that RoBERTa BPE tokens may carry
        predicted_token = tokenizer.decode([predicted_token_id]).strip()
        # Update the word based on the predicted token
        if current_pun == "no" and predicted_token in ['.', ',', '?', '!', ':', '-']:
            words[i] = current + predicted_token
        elif current_pun != "no" and predicted_token in ['.', ',', '?', '!', ':', '-']:
            words[i] = predicted_token
        else:
            words[i] = current
    # Join the words into the resulting string
    out = " ".join(words)
    return out
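
# Example call (the restored punctuation depends on the trained model):
# restore_pun("ako sa voláš", model)  # e.g. "ako sa voláš?"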

# Main menu: choose what the program should do
while True:
    option = input('1 => train the model  2 => fix punctuation in a text  3 => quit: ')
    # 1: Training
    if option == '1':
        file_path = input('Enter the path to the data file: ')
        # JSON handling
        import json
        # Read and parse every line as a standalone JSON object
        json_objects = []
        with open(file_path, 'r') as file:
            for line in file:
                try:
                    json_object = json.loads(line)
                    json_objects.append(json_object)
                except json.JSONDecodeError:
                    continue
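        # Each line is expected to hold one JSON object with parallel 'text'
        # and 'labels' lists, e.g. (hypothetical sample):
        #   {"text": ["ako", "sa", "voláš"], "labels": [0, 0, 3]}
        # A label value v > 0 at position t makes convert() insert puns[v - 1]
        # at position t of the token list.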
        # Punctuation marks used for training
        puns = ['.', ',', '?', '!', ':', '-']
        # Preprocessing and training
        texts = []
        for i in range(len(json_objects)):
            # Label values > 0 encode which punctuation mark to insert
            indices = [value for index, value in enumerate(json_objects[i]['labels']) if value > 0]
            # Positions of those labels within the token list
            val = [index for index, value in enumerate(json_objects[i]['labels']) if value > 0]
            # Put the punctuation back into the token list
            json_objects[i]['text'] = convert(json_objects[i]['text'], indices, val, puns)
            # Append the restored text to the training data
            texts.append(" ".join(json_objects[i]['text']))
        # Fine-tune the model
        model = fine_tuning(texts, model, tokenizer)
    # 2: Punctuation restoration
    elif option == '2':
        # Text without punctuation, or with wrong punctuation
        test = input('Enter your text: ')
        # Print the restored text
        print("Output:", restore_pun(test, model))
    # 3: Quit
    else:
        break