This commit is contained in:
Dárius Lindvai 2020-06-05 14:49:42 +02:00
parent 179efb972e
commit d919ca93c0
5 changed files with 142 additions and 69 deletions

View File

@ -1,3 +1,9 @@
## Update 05.06.2020
- pridaný čas začiatku a čas ukončenia trénovania, aby bolo možné určit, ako dlho trénovanie trvalo
- upravený skript na úpravu textu do vhodnej podoby (skombinoval som môj vlastný skript s jedným voľne dostupným na internete, aby bola úprava textu presnejšia)
- pridaný tag na identifikáciu čísel v texte ("N"), čo by teoreticky mohlo zvýšiť presnosť modelu
- vyriešený výpočet precision, recall a f-score (problém som vyriešil tak, že som najprv zo skutočných hodnôt urobil tensor, ktorý som následne konvertoval na numpy pole)
## Update 05.05.2020
- upravený skript "punc.py" tak, že model načítava dáta zo súboru/ov
- vytvorený skript "text.py", ktorý upraví dáta do vhodnej podoby (5 krokov)

View File

@ -1,5 +1,4 @@
import os
import re
if os.path.exists('tags.txt'):
os.remove('tags.txt')
@ -11,15 +10,15 @@ with open('text.txt', 'r') as input_file:
if (word == '.PER'):
word = word.replace(word, 'P')
output_file.write(word + ' ')
elif (word == ',COM'):
word = word.replace(word, 'C')
output_file.write(word + ' ')
elif(word == '?QUE'):
word = word.replace(word, 'Q')
output_file.write(word + ' ')
elif(word == '<NUM>'):
word = word.replace(word, 'N')
output_file.write(word + ' ')
else:
word = word.replace(word, 'S')
output_file.write(word + ' ')

View File

@ -0,0 +1,73 @@
from __future__ import division, print_function
from nltk.tokenize import word_tokenize
import nltk
import os
from io import open
import re
import sys
nltk.download('punkt')
NUM = '<NUM>'
PUNCTS = {".": ".PER", ",": ".COM", "?": "?QUE", "!": ".PER", ":": ",COM", ";": ".PER", "-": ",COM"}
forbidden_symbols = re.compile(r"[\[\]\(\)\/\\\>\<\=\+\_\*]")
numbers = re.compile(r"\d")
multiple_punct = re.compile(r'([\.\?\!\,\:\;\-])(?:[\.\?\!\,\:\;\-]){1,}')
is_number = lambda x: len(numbers.sub("", x)) / len(x) < 0.6
def untokenize(line):
return line.replace(" '", "'").replace(" n't", "n't").replace("can not", "cannot")
def skip(line):
if line.strip() == '':
return True
last_symbol = line[-1]
if not last_symbol in PUNCTS:
return True
if forbidden_symbols.search(line) is not None:
return True
return False
def process_line(line):
tokens = word_tokenize(line)
output_tokens = []
for token in tokens:
if token in PUNCTS:
output_tokens.append(PUNCTS[token])
elif is_number(token):
output_tokens.append(NUM)
else:
output_tokens.append(token.lower())
return untokenize(" ".join(output_tokens) + " ")
skipped = 0
with open(sys.argv[2], 'w', encoding='utf-8') as out_txt:
with open(sys.argv[1], 'r', encoding='utf-8') as text:
for line in text:
line = line.replace("\"", "").strip()
line = multiple_punct.sub(r"\g<1>", line)
if skip(line):
skipped += 1
continue
line = process_line(line)
out_txt.write(line)
print("Skipped %d lines" % skipped)

View File

@ -1,14 +1,13 @@
import numpy as np
import torch
import torch.autograd as autograd
import torch.nn as nn
import torch.optim as optim
from sklearn import metrics
from datetime import datetime
torch.manual_seed(1)
def argmax(vec):
# return the argmax as a python int
_, idx = torch.max(vec, 1)
@ -27,10 +26,6 @@ def log_sum_exp(vec):
return max_score + \
torch.log(torch.sum(torch.exp(vec - max_score_broadcast)))
class BiLSTM_CRF(nn.Module):
def __init__(self, vocab_size, tag_to_ix, embedding_dim, hidden_dim):
@ -65,7 +60,7 @@ class BiLSTM_CRF(nn.Module):
torch.randn(2, 1, self.hidden_dim // 2))
def _forward_alg(self, feats):
# Forward algorithm to compute the partition function
# Do the forward algorithm to compute the partition function
init_alphas = torch.full((1, self.tagset_size), -10000.)
# START_TAG has all of the score.
init_alphas[0][self.tag_to_ix[START_TAG]] = 0.
@ -77,13 +72,18 @@ class BiLSTM_CRF(nn.Module):
for feat in feats:
alphas_t = [] # The forward tensors at this timestep
for next_tag in range(self.tagset_size):
# broadcast the emission score: it is the same regardless of the previous tag
emit_score = feat[next_tag].view(1, -1).expand(1, self.tagset_size)
# the ith entry of trans_score is the score of transitioning to next_tag from i
# broadcast the emission score: it is the same regardless of
# the previous tag
emit_score = feat[next_tag].view(
1, -1).expand(1, self.tagset_size)
# the ith entry of trans_score is the score of transitioning to
# next_tag from i
trans_score = self.transitions[next_tag].view(1, -1)
# The ith entry of next_tag_var is the value for the edge (i -> next_tag) before we do log-sum-exp
# The ith entry of next_tag_var is the value for the
# edge (i -> next_tag) before we do log-sum-exp
next_tag_var = forward_var + trans_score + emit_score
# The forward variable for this tag is log-sum-exp of all the scores.
# The forward variable for this tag is log-sum-exp of all the
# scores.
alphas_t.append(log_sum_exp(next_tag_var).view(1))
forward_var = torch.cat(alphas_t).view(1, -1)
terminal_var = forward_var + self.transitions[self.tag_to_ix[STOP_TAG]]
@ -158,7 +158,7 @@ class BiLSTM_CRF(nn.Module):
gold_score = self._score_sentence(feats, tags)
return forward_score - gold_score
def forward(self, sentence):
def forward(self, sentence): # dont confuse this with _forward_alg above.
# Get the emission scores from the BiLSTM
lstm_feats = self._get_lstm_features(sentence)
@ -166,25 +166,12 @@ class BiLSTM_CRF(nn.Module):
score, tag_seq = self._viterbi_decode(lstm_feats)
return score, tag_seq
START_TAG = "<START>"
STOP_TAG = "<STOP>"
EMBEDDING_DIM = 5
HIDDEN_DIM = 4
'''
training_data = [(
"hovorí sa ,COM že ľudstvo postihuje nová epidémia ,COM šíriaca sa závratnou rýchlosťou .PER preto je dôležité vedieť čo to je ,COM ako jej predísť alebo ako ju odstrániť .PER".split(),
"S S C S S S S S C S S S S P S S S S S S S C S S S S S S S P".split()
), (
"nárast obezity je spôsobený najmä spôsobom života .PER tuky zlepšujú chuť do jedla a dávajú lepší pocit sýtosti ,COM uvedomte si však ,COM že všetky tuky sa Vám ukladajú ,COM pokiaľ ich nespálite .PER".split(),
"S S S S S S S P S S S S S S S S S S C S S S C S S S S S S C S S S P".split()
)]
'''
# Make up some training data
with open('/home/dlindvai/work/text.txt', 'r') as text2:
with open('/home/dlindvai/work/tags.txt', 'r') as tags2:
text1 = text2.read().splitlines()
@ -204,20 +191,27 @@ for sentence, tags in training_data:
if word not in word_to_ix:
word_to_ix[word] = len(word_to_ix)
tag_to_ix = {"S": 0, "C": 1, "P": 2, "Q": 3, START_TAG: 4, STOP_TAG: 5}
tag_to_ix = {"S": 0, "P": 1, "C": 2, "Q": 3, "N": 4, START_TAG: 5, STOP_TAG: 6}
model = BiLSTM_CRF(len(word_to_ix), tag_to_ix, EMBEDDING_DIM, HIDDEN_DIM)
optimizer = optim.SGD(model.parameters(), lr=0.01, weight_decay=1e-4)
# Check predictions before training
with torch.no_grad():
precheck_sent = prepare_sequence(training_data[0][0], word_to_ix)
precheck_tags = torch.tensor([tag_to_ix[t] for t in training_data[0][1]], dtype=torch.long)
print("Predicted output before training: ", model(precheck_sent))
#print(model(precheck_sent))
for epoch in range(30): # normally you would NOT do 300 epochs, but this is small dataset
# Print start time
start = datetime.now()
start_time = start.strftime("%H:%M:%S")
print("Start time = ", start_time)
for epoch in range(50):
for sentence, tags in training_data:
# Step 1. Remember that Pytorch accumulates gradients.
# We need to clear them out before each instance
# Step 1. Remember that Pytorch accumulates gradients. We need to clear them out before each instance
model.zero_grad()
# Step 2. Get our inputs ready for the network, that is, turn them into Tensors of word indices.
@ -231,7 +225,22 @@ for epoch in range(30): # normally you would NOT do 300 epochs, but this is sma
loss.backward()
optimizer.step()
# Check predictions after training
with torch.no_grad():
precheck_sent = prepare_sequence(training_data[0][0], word_to_ix)
print("Predicted output after training: ", model(precheck_sent))
#print(model(precheck_sent))
# Error calculator
var = model(precheck_sent)
y_true = np.array(targets)
y_pred = np.array(var[1])
print(metrics.confusion_matrix(y_true, y_pred))
print(metrics.classification_report(y_true, y_pred, digits=3))
# Print finish time
finish = datetime.now()
finish_time = finish.strftime("%H:%M:%S")
print("Finish time = ", finish_time)

View File

@ -1,14 +0,0 @@
import re
import os
if os.path.exists('text.txt'):
os.remove('text.txt')
with open('/home/dlindvai/work/train.txt', 'r') as input_file:
with open('/home/dlindvai/work/text.txt', 'a') as output_file:
for line in input_file:
line = line.replace('\n', '')
line = re.sub(r"([\w/'+$\s-]+|[^\w/'+$\s-]+)\s*", r"\1 ", line)
line = line.lower()
line = line.replace('.','.PER').replace(',',',COM').replace('?','?QUE')
output_file.write(line)