forked from KEMT/zpwiki

update

This commit is contained in:
parent 179efb972e
commit d919ca93c0
| @@ -1,3 +1,9 @@ | ||||
| ## Update 05.06.2020 | ||||
| - added the training start time and finish time, so that it is possible to determine how long training took | ||||
| - reworked the text-preprocessing script (I combined my own script with one freely available online, to make the text clean-up more accurate) | ||||
| - added a tag ("N") for identifying numbers in the text, which could in theory improve the model's accuracy | ||||
| - fixed the computation of precision, recall and f-score (I solved the problem by first building a tensor from the true values and then converting it to a numpy array; see the sketch below) | ||||
| 
 | ||||
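For reference, a minimal sketch of that metrics fix; the variable names here are illustrative, the real ones appear in the training script further down:

    import numpy as np
    import torch
    from sklearn import metrics

    # hypothetical gold/predicted tag indices for one short sentence
    gold = torch.tensor([0, 2, 0, 1], dtype=torch.long)  # first build a tensor from the true values...
    pred = [0, 2, 0, 0]

    y_true = np.array(gold)  # ...then convert that tensor to a numpy array
    y_pred = np.array(pred)
    print(metrics.classification_report(y_true, y_pred, digits=3))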
| ## Update 05.05.2020 | ||||
| - modified the "punc.py" script so that the model loads its data from a file (or files); a minimal loading sketch follows below | ||||
| - created the "text.py" script, which transforms the data into a suitable form (5 steps) | ||||
|  | ||||
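A minimal sketch of that file-based loading, assuming the parallel text/tag files used later in this commit (one sentence per line in text.txt, one tag sequence per line in tags.txt):

    # read parallel files and zip them into (sentence, tags) training pairs
    with open('text.txt', 'r') as f_text, open('tags.txt', 'r') as f_tags:
        sentences = f_text.read().splitlines()
        tag_lines = f_tags.read().splitlines()

    training_data = [(s.split(), t.split()) for s, t in zip(sentences, tag_lines)]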
| @@ -1,5 +1,4 @@ | ||||
| import os | ||||
| import re | ||||
| 
 | ||||
| if os.path.exists('tags.txt'): | ||||
| 	os.remove('tags.txt') | ||||
| @@ -11,15 +10,15 @@ with open('text.txt', 'r') as input_file: | ||||
| 				if (word == '.PER'): | ||||
| 					word = word.replace(word, 'P') | ||||
| 					output_file.write(word + ' ') | ||||
| 
 | ||||
| 				elif (word == ',COM'): | ||||
| 					word = word.replace(word, 'C') | ||||
| 					output_file.write(word + ' ') | ||||
| 
 | ||||
| 				elif(word == '?QUE'): | ||||
| 					word = word.replace(word, 'Q') | ||||
| 					output_file.write(word + ' ') | ||||
| 
 | ||||
| 				elif(word == '<NUM>'): | ||||
| 					word = word.replace(word, 'N') | ||||
| 					output_file.write(word + ' ') | ||||
| 				else: | ||||
| 					word = word.replace(word, 'S') | ||||
| 					output_file.write(word + ' ') | ||||
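The chain of elif branches above amounts to a dictionary lookup; a compact equivalent sketch (TAG_MAP is my name, not one from the repo):

    # map each punctuation token to its one-letter tag; any other token is an ordinary word ('S')
    TAG_MAP = {'.PER': 'P', ',COM': 'C', '?QUE': 'Q', '<NUM>': 'N'}
    tag = TAG_MAP.get(word, 'S')
    output_file.write(tag + ' ')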
pages/students/2016/darius_lindvai/dp2021/prepare_text.py (73 lines, new file)
| @@ -0,0 +1,73 @@ | ||||
| from __future__ import division, print_function | ||||
| from nltk.tokenize import word_tokenize | ||||
| 
 | ||||
| import nltk | ||||
| import os | ||||
| from io import open | ||||
| import re | ||||
| import sys | ||||
| 
 | ||||
| nltk.download('punkt') | ||||
| 
 | ||||
| NUM = '<NUM>' | ||||
| 
 | ||||
| PUNCTS = {".": ".PER", ",": ",COM", "?": "?QUE", "!": ".PER", ":": ",COM", ";": ".PER", "-": ",COM"}  # map raw punctuation to its tag token | ||||
| 
 | ||||
| forbidden_symbols = re.compile(r"[\[\]\(\)\/\\\>\<\=\+\_\*]") | ||||
| numbers = re.compile(r"\d") | ||||
| multiple_punct = re.compile(r'([\.\?\!\,\:\;\-])(?:[\.\?\!\,\:\;\-]){1,}') | ||||
| 
 | ||||
| # a token counts as a number if digits make up more than 40% of its characters | ||||
| is_number = lambda x: len(numbers.sub("", x)) / len(x) < 0.6 | ||||
| 
 | ||||
| def untokenize(line): | ||||
|     # undo tokenizer artifacts: re-attach apostrophes and "n't", and normalize "can not" | ||||
|     return line.replace(" '", "'").replace(" n't", "n't").replace("can not", "cannot") | ||||
| 
 | ||||
| def skip(line): | ||||
| 
 | ||||
|     if line.strip() == '': | ||||
|         return True | ||||
| 
 | ||||
|     last_symbol = line[-1] | ||||
|     if last_symbol not in PUNCTS:  # keep only lines that end with a known punctuation mark | ||||
|         return True | ||||
| 
 | ||||
|     if forbidden_symbols.search(line) is not None: | ||||
|         return True | ||||
| 
 | ||||
|     return False | ||||
| 
 | ||||
| def process_line(line): | ||||
| 
 | ||||
|     tokens = word_tokenize(line) | ||||
|     output_tokens = [] | ||||
| 
 | ||||
|     for token in tokens: | ||||
| 
 | ||||
|         if token in PUNCTS: | ||||
|             output_tokens.append(PUNCTS[token]) | ||||
|         elif is_number(token): | ||||
|             output_tokens.append(NUM) | ||||
|         else: | ||||
|             output_tokens.append(token.lower()) | ||||
| 
 | ||||
|     return untokenize(" ".join(output_tokens) + " ") | ||||
| 
 | ||||
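| # Illustrative example of the pipeline above (assumed behaviour, not part of the original script): | ||||
| #   process_line("Hello, world. It is 2020.") | ||||
| #   returns "hello ,COM world .PER it is <NUM> .PER " | ||||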
| skipped = 0 | ||||
| 
 | ||||
| with open(sys.argv[2], 'w', encoding='utf-8') as out_txt: | ||||
|     with open(sys.argv[1], 'r', encoding='utf-8') as text: | ||||
| 
 | ||||
|         for line in text: | ||||
| 
 | ||||
|             line = line.replace("\"", "").strip() | ||||
|             line = multiple_punct.sub(r"\g<1>", line) | ||||
| 
 | ||||
|             if skip(line): | ||||
|                 skipped += 1 | ||||
|                 continue | ||||
| 
 | ||||
|             line = process_line(line) | ||||
| 
 | ||||
|             out_txt.write(line) | ||||
| 
 | ||||
| print("Skipped %d lines" % skipped) | ||||
| @@ -1,14 +1,13 @@ | ||||
| import numpy as np | ||||
| import torch | ||||
| import torch.autograd as autograd | ||||
| import torch.nn as nn | ||||
| import torch.optim as optim | ||||
| from sklearn import metrics | ||||
| from datetime import datetime | ||||
| 
 | ||||
| torch.manual_seed(1) | ||||
| 
 | ||||
| def argmax(vec): | ||||
|     # return the argmax as a python int | ||||
|     _, idx = torch.max(vec, 1) | ||||
| @@ -27,10 +26,6 @@ def log_sum_exp(vec): | ||||
|     return max_score + \ | ||||
|         torch.log(torch.sum(torch.exp(vec - max_score_broadcast))) | ||||
| 
 | ||||
| class BiLSTM_CRF(nn.Module): | ||||
| 
 | ||||
|     def __init__(self, vocab_size, tag_to_ix, embedding_dim, hidden_dim): | ||||
| @@ -65,7 +60,7 @@ class BiLSTM_CRF(nn.Module): | ||||
|                 torch.randn(2, 1, self.hidden_dim // 2)) | ||||
| 
 | ||||
|     def _forward_alg(self, feats): | ||||
|         # Do the forward algorithm to compute the partition function | ||||
|         init_alphas = torch.full((1, self.tagset_size), -10000.) | ||||
|         # START_TAG has all of the score. | ||||
|         init_alphas[0][self.tag_to_ix[START_TAG]] = 0. | ||||
| @@ -77,13 +72,18 @@ class BiLSTM_CRF(nn.Module): | ||||
|         for feat in feats: | ||||
|             alphas_t = []  # The forward tensors at this timestep | ||||
|             for next_tag in range(self.tagset_size): | ||||
|                 # broadcast the emission score: it is the same regardless of the previous tag | ||||
|                 emit_score = feat[next_tag].view(1, -1).expand(1, self.tagset_size) | ||||
|                 # the ith entry of trans_score is the score of transitioning to next_tag from i | ||||
|                 trans_score = self.transitions[next_tag].view(1, -1) | ||||
|                 # the ith entry of next_tag_var is the value for the edge (i -> next_tag) before we do log-sum-exp | ||||
|                 next_tag_var = forward_var + trans_score + emit_score | ||||
|                 # the forward variable for this tag is log-sum-exp of all the scores | ||||
|                 alphas_t.append(log_sum_exp(next_tag_var).view(1)) | ||||
|             forward_var = torch.cat(alphas_t).view(1, -1) | ||||
|         terminal_var = forward_var + self.transitions[self.tag_to_ix[STOP_TAG]] | ||||
| @@ -158,7 +158,7 @@ class BiLSTM_CRF(nn.Module): | ||||
|         gold_score = self._score_sentence(feats, tags) | ||||
|         return forward_score - gold_score | ||||
| 
 | ||||
|     def forward(self, sentence):  # don't confuse this with _forward_alg above | ||||
|         # Get the emission scores from the BiLSTM | ||||
|         lstm_feats = self._get_lstm_features(sentence) | ||||
| 
 | ||||
| @@ -166,25 +166,12 @@ class BiLSTM_CRF(nn.Module): | ||||
|         score, tag_seq = self._viterbi_decode(lstm_feats) | ||||
|         return score, tag_seq | ||||
| 
 | ||||
| START_TAG = "<START>" | ||||
| STOP_TAG = "<STOP>" | ||||
| EMBEDDING_DIM = 5 | ||||
| HIDDEN_DIM = 4 | ||||
| 
 | ||||
| ''' | ||||
| training_data = [( | ||||
|     "hovorí sa ,COM že ľudstvo postihuje nová epidémia ,COM šíriaca sa závratnou rýchlosťou .PER preto je dôležité vedieť čo to je ,COM ako jej predísť alebo ako ju odstrániť .PER".split(), | ||||
|     "S S C S S S S S C S S S S P S S S S S S S C S S S S S S S P".split() | ||||
| ), ( | ||||
|     "nárast obezity je spôsobený najmä spôsobom života .PER tuky zlepšujú chuť do jedla a dávajú lepší pocit sýtosti ,COM uvedomte si však ,COM že všetky tuky sa Vám ukladajú ,COM pokiaľ ich nespálite .PER".split(), | ||||
|     "S S S S S S S P S S S S S S S S S S C S S S C S S S S S S C S S S P".split() | ||||
| )] | ||||
| ''' | ||||
| 
 | ||||
| # Load the training data from the prepared text/tag files | ||||
| with open('/home/dlindvai/work/text.txt', 'r') as text2: | ||||
| 	with open('/home/dlindvai/work/tags.txt', 'r') as tags2: | ||||
| 		text1 = text2.read().splitlines() | ||||
| @@ -204,20 +191,27 @@ for sentence, tags in training_data: | ||||
| 		if word not in word_to_ix: | ||||
| 			word_to_ix[word] = len(word_to_ix) | ||||
| 
 | ||||
| tag_to_ix = {"S": 0, "P": 1, "C": 2, "Q": 3, "N": 4, START_TAG: 5, STOP_TAG: 6} | ||||
| 
 | ||||
| model = BiLSTM_CRF(len(word_to_ix), tag_to_ix, EMBEDDING_DIM, HIDDEN_DIM) | ||||
| optimizer = optim.SGD(model.parameters(), lr=0.01, weight_decay=1e-4) | ||||
| 
 | ||||
| # Check predictions before training | ||||
| with torch.no_grad(): | ||||
| 	precheck_sent = prepare_sequence(training_data[0][0], word_to_ix) | ||||
| 	precheck_tags = torch.tensor([tag_to_ix[t] for t in training_data[0][1]], dtype=torch.long) | ||||
| 	print("Predicted output before training: ", model(precheck_sent)) | ||||
| 
 | ||||
| # Print start time | ||||
| start = datetime.now() | ||||
| start_time = start.strftime("%H:%M:%S") | ||||
| print("Start time = ", start_time) | ||||
| 
 | ||||
| for epoch in range(50): | ||||
| 	for sentence, tags in training_data: | ||||
| 		# Step 1. Remember that Pytorch accumulates gradients. We need to clear them out before each instance | ||||
| 		model.zero_grad() | ||||
| 
 | ||||
| 		# Step 2. Get our inputs ready for the network, that is, turn them into Tensors of word indices. | ||||
| @@ -231,7 +225,22 @@ for epoch in range(30):  # normally you would NOT do 300 epochs, but this is a small dataset | ||||
| 		loss.backward() | ||||
| 		optimizer.step() | ||||
| 
 | ||||
| # Check predictions after training | ||||
| with torch.no_grad(): | ||||
| 	precheck_sent = prepare_sequence(training_data[0][0], word_to_ix) | ||||
| 	print("Predicted output after training: ", model(precheck_sent)) | ||||
| 
 | ||||
| # Error calculator: precision, recall and f-score from the gold and predicted tag sequences | ||||
| var = model(precheck_sent) | ||||
| y_true = np.array(precheck_tags)  # gold tags of the evaluated sentence (a tensor converted to a numpy array), not the leftover 'targets' from the training loop | ||||
| y_pred = np.array(var[1]) | ||||
| 
 | ||||
| print(metrics.confusion_matrix(y_true, y_pred)) | ||||
| print(metrics.classification_report(y_true, y_pred, digits=3)) | ||||
| 
 | ||||
| # Print finish time | ||||
| finish = datetime.now() | ||||
| finish_time = finish.strftime("%H:%M:%S") | ||||
| print("Finish time = ", finish_time) | ||||
|  | ||||
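The update notes say the start and finish times were added to determine how long training took; a one-line extension (not part of this commit) would print the elapsed time directly:

    print("Duration = ", finish - start)  # subtracting two datetimes yields a timedelta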
| @@ -1,14 +0,0 @@ | ||||
| import re | ||||
| import os | ||||
| 
 | ||||
| if os.path.exists('text.txt'): | ||||
| 	os.remove('text.txt') | ||||
| 
 | ||||
| with open('/home/dlindvai/work/train.txt', 'r') as input_file: | ||||
| 	with open('/home/dlindvai/work/text.txt', 'a') as output_file: | ||||
| 		for line in input_file: | ||||
| 			line = line.replace('\n', '') | ||||
| 			line = re.sub(r"([\w/'+$\s-]+|[^\w/'+$\s-]+)\s*", r"\1 ", line) | ||||
| 			line = line.lower() | ||||
| 			line = line.replace('.','.PER').replace(',',',COM').replace('?','?QUE') | ||||
| 			output_file.write(line) | ||||