Update 'slovak_punction2.py'

parent eec002d873
commit 40117c648d

@@ -1,54 +1,201 @@
# coding: utf-8

def convert(text, indices, vals, puns):
    # Copy the token list so the original text is not modified
    modified_text = list(text)

    for val, i in zip(vals, indices):
        # Insert the corresponding punctuation mark into the modified text
        modified_text.insert(val, puns[i - 1])

    return modified_text

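# Illustrative sketch of how convert() behaves (the example values are made up):
# given a token list, label values, insertion positions and the punctuation table,
# it inserts puns[label - 1] at the given position, e.g.
#   convert(["ako", "sa", "volas"], [2], [3], ['.', ',', '?', '!', ':', '-'])
#   -> ["ako", "sa", "volas", ","]
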
# Libraries
from transformers import RobertaTokenizer, RobertaForMaskedLM

from transformers import DataCollatorForLanguageModeling
# Masking model and its tokenizer
tokenizer = RobertaTokenizer.from_pretrained('gerulata/slovakbert')
model = RobertaForMaskedLM.from_pretrained('gerulata/slovakbert')

import torch

import nltk
from nltk.tokenize import word_tokenize, sent_tokenize

# Import the module for text manipulation
import re

# Download the tokenizer data (only needed on the first run)
# nltk.download('punkt')

# Import libraries and modules
from transformers import DataCollatorForLanguageModeling, AdamW
from torch.utils.data import DataLoader
from nltk.tokenize import sent_tokenize

def fine_tuning(texts, model, tokenizer):
    # Check that the input text list is valid (non-empty)
    if len(texts) == 0:
        return model

    # Text preprocessing
    def preprocess_for_punctuation(texts):
        processed_texts = []
        for text in texts:
            # Mask the punctuation marks with placeholder tokens
            text = re.sub(r'[.,?!:-]', '[MASK]', text)
            processed_texts.append(text)
        return processed_texts

    # Apply the preprocessing to the input texts
    texts = preprocess_for_punctuation(texts)

    # Tokenize and encode the processed texts
    encodings = tokenizer(texts, truncation=True, padding='max_length', max_length=512)

    # Definition of a custom dataset for masked language modelling
    class MLM_Dataset(torch.utils.data.Dataset):
        def __init__(self, encodings):
            self.encodings = encodings

        def __len__(self):
            return len(self.encodings['input_ids'])

        def __getitem__(self, idx):
            return {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}

    # Create the custom dataset
    dataset = MLM_Dataset(encodings)

    # Create the data collator for masked language modelling
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=True, mlm_probability=0.15)

    # Create a DataLoader for training the model
    dataloader = DataLoader(dataset, batch_size=16, shuffle=True, collate_fn=data_collator)

    # Optimizer for training
    optimizer = AdamW(model.parameters(), lr=5e-5)

    # Number of training epochs
    epochs = 1

    print("Starting model training...")
    # Training loop
    for epoch in range(epochs):
        model.train()
        for batch in dataloader:
            # Zero the gradients before the backward pass
            optimizer.zero_grad()
            # Move the inputs to the model's device
            inputs = {k: v.to(model.device) for k, v in batch.items()}
            outputs = model(**inputs)
            loss = outputs.loss
            loss.backward()
            optimizer.step()

    print("Training finished.")

    # Return the fine-tuned model
    return model

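# Illustrative usage sketch of fine_tuning() (the sample texts are made-up examples):
#   sample_texts = ["Ahoj , ako sa mas ?", "Dnes je pekny den ."]
#   model = fine_tuning(sample_texts, model, tokenizer)
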
# Punctuation restoration
def restore_pun(text, model):
    # Tokenize the input text
    words = nltk.word_tokenize(text)

    # Iterate over the words
    for i in range(1, len(words)):
        current = words[i]

        # Decide whether the current word should be followed by punctuation or not
        if current not in [".", ",", "?", "!", ":", "-"]:
            words[i] += " <mask>"
            current_pun = "no"
        else:
            current_pun = current
            words[i] = " <mask>"

        # Join the words into a single string
        x = " ".join(words)

        # Encode the input with the tokenizer
        encoded_input = tokenizer(x, return_tensors='pt')

        # Run the model on the encoded input
        output = model(**encoded_input)

        # Find the index of the masked token in the input
        mask_token_index = torch.where(encoded_input["input_ids"][0] == tokenizer.mask_token_id)[0]

        # Extract the logits for the masked token
        mask_token_logits = output.logits[0, mask_token_index, :]

        # Find the token with the highest probability
        predicted_token_id = torch.argmax(mask_token_logits).item()
        predicted_token = tokenizer.decode([predicted_token_id])

        # Update the word based on the predicted token
        if current_pun == "no" and predicted_token in ['.', ',', '?', '!', ':', '-']:
            words[i] = current + predicted_token
        elif current_pun != "no" and predicted_token in ['.', ',', '?', '!', ':', '-']:
            words[i] = predicted_token
        else:
            words[i] = current

    # Join the words into the resulting string
    out = " ".join(words)
    return out

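# Illustrative usage sketch of restore_pun() (input and output are made-up examples;
# the actual prediction depends on the fine-tuned model):
#   restore_pun("ahoj ako sa mas", model)
#   -> "ahoj , ako sa mas ?"
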
# Choose what the program should do
while True:
    option = input('1 => train the model  2 => fix punctuation in a text  3 => exit the program ')

    # 1: Training
    if option == '1':
        file_path = input('Enter the path to the data file: ')

        # Import json
        import json

        # Read and parse each line as a separate JSON object
        json_objects = []
        with open(file_path, 'r') as file:
            for line in file:
                try:
                    json_object = json.loads(line)
                    json_objects.append(json_object)
                except json.JSONDecodeError:
                    continue

        # Punctuation marks used for training
        puns = ['.', ',', '?', '!', ':', '-']

        # Preprocessing and training
        texts = []
        for i in range(len(json_objects)):
            indices = [value for index, value in enumerate(json_objects[i]['labels']) if value > 0]
            val = [index for index, value in enumerate(json_objects[i]['labels']) if value > 0]

            # Modify the text (insert the punctuation)
            json_objects[i]['text'] = convert(json_objects[i]['text'], indices, val, puns)

            # Add the modified text to the list
            texts.append(" ".join(json_objects[i]['text']))

        # Fine-tune the model
        model = fine_tuning(texts[:], model, tokenizer)

    # 2: Punctuation restoration
    elif option == '2':
        # Enter a text with missing or incorrect punctuation
        test = input('Enter your text: ')

        # Print the restored text
        print("Output:", restore_pun(test, model))

    # 3: Exit the program
    else:
        break
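
# Assumed training data format for option 1 (an illustrative sketch inferred from how
# the code reads the file, not taken from the actual dataset): one JSON object per
# line, with a token list under 'text' and a parallel list of integer labels under
# 'labels', where a non-zero value k selects the punctuation mark puns[k - 1]:
#   {"text": ["ahoj", "ako", "sa", "mas"], "labels": [0, 2, 0, 1]}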