Update 'slovak_punction2.py'
This commit is contained in:
parent eec002d873
commit 40117c648d

@@ -1,54 +1,201 @@
# coding: utf-8

def convert(text, indices, vals, puns):
    # Copy the token list so the original text is not modified
    modified_text = list(text)

    for val, i in zip(vals, indices):
        # Insert the corresponding punctuation mark into the modified text
        modified_text.insert(val, puns[i - 1])

    return modified_text
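
# Quick usage sketch (hypothetical values; `puns` mirrors the list used for
# training further below):
# convert(['ako', 'sa', 'volas'], [1], [3], ['.', ',', '?', '!', ':', '-'])
# returns ['ako', 'sa', 'volas', '.']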

# Libraries
import json
import re  # text manipulation
import torch
import nltk
from nltk.tokenize import word_tokenize, sent_tokenize
from torch.utils.data import DataLoader
from torch.optim import AdamW  # transformers.AdamW is deprecated
from transformers import RobertaTokenizer, RobertaForMaskedLM
from transformers import DataCollatorForLanguageModeling

# Download the tokenizer data (only needed once)
# nltk.download('punkt')

# Masked language model and tokenizer
tokenizer = RobertaTokenizer.from_pretrained('gerulata/slovakbert')
model = RobertaForMaskedLM.from_pretrained('gerulata/slovakbert')

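# Note: SlovakBERT is a RoBERTa-style model, so its mask token is "<mask>"
# (tokenizer.mask_token), not the BERT-style "[MASK]" string; the code below
# relies on this when masking punctuation.
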
def fine_tuning(texts, model, tokenizer):
    # Nothing to train on
    if len(texts) == 0:
        return model

    # Text preprocessing
    def preprocess_for_punctuation(texts):
        processed_texts = []
        for text in texts:
            # Replace punctuation with the tokenizer's mask token
            text = re.sub(r'[.,?!:-]', tokenizer.mask_token, text)
            processed_texts.append(text)
        return processed_texts

    # Apply the preprocessing to the input texts
    texts = preprocess_for_punctuation(texts)

    # Tokenize and encode the preprocessed texts
    encodings = tokenizer(texts, truncation=True, padding='max_length', max_length=512)

    # Custom dataset definition
    class MLM_Dataset(torch.utils.data.Dataset):
        def __init__(self, encodings):
            self.encodings = encodings

        def __len__(self):
            return len(self.encodings['input_ids'])

        def __getitem__(self, idx):
            return {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}

    # Create the custom dataset
    dataset = MLM_Dataset(encodings)

    # Collator that builds the MLM training batches
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=True, mlm_probability=0.15)

    # DataLoader for training the model
    dataloader = DataLoader(dataset, batch_size=16, shuffle=True, collate_fn=data_collator)

    # Optimizer for training
    optimizer = AdamW(model.parameters(), lr=5e-5)

    # Number of training epochs
    epochs = 1

    print("Starting model training...")
    # Training loop
    for epoch in range(epochs):
        model.train()
        for batch in dataloader:
            # Zero the gradients before the backward pass
            optimizer.zero_grad()
            # Move the inputs to the model's device
            inputs = {k: v.to(model.device) for k, v in batch.items()}
            outputs = model(**inputs)
            loss = outputs.loss
            loss.backward()
            optimizer.step()

    print("Training finished.")

    # Return the fine-tuned model
    return model

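# Minimal usage sketch for the fine-tuning step (hypothetical sentences; the
# real training data comes from the JSONL file loaded in option 1 below):
# sample_texts = ["Ahoj, ako sa mas?", "Dnes je pekny den."]
# model = fine_tuning(sample_texts, model, tokenizer)
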
# Punctuation restoration
def restore_pun(text, model):
    # Tokenize the input text
    words = nltk.word_tokenize(text)

    # Iterate over the words
    for i in range(1, len(words)):
        current = words[i]

        # Decide whether the current word already carries punctuation
        if current not in [".", ",", "?", "!", ":", "-"]:
            words[i] += " <mask>"
            current_pun = "no"
        else:
            current_pun = current
            words[i] = " <mask>"

        # Join the words back into a single string
        x = " ".join(words)

        # Encode the input with the tokenizer
        encoded_input = tokenizer(x, return_tensors='pt')

        # Run the model on the encoded input
        output = model(**encoded_input)

        # Find the index of the masked token in the input
        mask_token_index = torch.where(encoded_input["input_ids"][0] == tokenizer.mask_token_id)[0]

        # Extract the logits for the masked token
        mask_token_logits = output.logits[0, mask_token_index, :]

        # Find the token with the highest probability
        predicted_token_id = torch.argmax(mask_token_logits).item()
        predicted_token = tokenizer.decode([predicted_token_id])

        # Update the word based on the predicted token
        if current_pun == "no" and predicted_token in ['.', ',', '?', '!', ':', '-']:
            words[i] = current + predicted_token
        elif current_pun != "no" and predicted_token in ['.', ',', '?', '!', ':', '-']:
            words[i] = predicted_token
        else:
            words[i] = current

    # Join the words into the resulting string
    out = " ".join(words)
    return out

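# Usage sketch (the exact output depends on the fine-tuned weights, so the
# restored punctuation may differ):
# restore_pun("ako sa volas", model)  # -> e.g. "ako sa volas?"
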
# Choose what the program should do
while True:
    option = input('1 => train the model  2 => fix punctuation in a text  3 => quit ')

    # 1: Training
    if option == '1':
        file_path = input('Enter the path to the data file: ')

        # Read and parse each line as a standalone JSON object
        json_objects = []
        with open(file_path, 'r') as file:
            for line in file:
                try:
                    json_object = json.loads(line)
                    json_objects.append(json_object)
                except json.JSONDecodeError:
                    continue
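
        # Each parsed line is assumed to be a record like the following
        # (hypothetical example): a label value k > 0 at index j means
        # "insert puns[k - 1] at position j of the token list".
        # {"text": ["ako", "sa", "volas"], "labels": [0, 0, 0, 3]}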

        # Punctuation marks used for training
        puns = ['.', ',', '?', '!', ':', '-']

        # Preprocessing and training
        texts = []
        for i in range(len(json_objects)):
            indices = [value for index, value in enumerate(json_objects[i]['labels']) if value > 0]
            val = [index for index, value in enumerate(json_objects[i]['labels']) if value > 0]

            # Re-insert the punctuation into the token list
            json_objects[i]['text'] = convert(json_objects[i]['text'], indices, val, puns)

            # Add the modified text to the list
            texts.append(" ".join(json_objects[i]['text']))

        # Fine-tune the model
        model = fine_tuning(texts[:], model, tokenizer)

    # 2: Punctuation restoration
    elif option == '2':
        # Text without punctuation, or with wrong punctuation
        test = input('Enter your text: ')

        # Print the restored text
        print("Output:", restore_pun(test, model))

    # 3: Quit
    else:
        break