65 lines
1.7 KiB
Python
65 lines
1.7 KiB
Python
|
from numpy import array
|
||
|
from pickle import dump
|
||
|
from keras.utils import to_categorical
|
||
|
from keras.models import Sequential
|
||
|
from keras.layers import Dense
|
||
|
from keras.layers import LSTM
|
||
|
from keras.callbacks import CSVLogger
|
||
|
|
||
|
|
||
|
# load doc into memory
def load_doc(filename, encoding=None):
    """Read a whole text file and return its contents as one string.

    Args:
        filename: Path of the file to read.
        encoding: Text encoding handed to ``open``. ``None`` (the default)
            keeps the platform default encoding, matching earlier behaviour.

    Returns:
        The complete file contents as ``str``.
    """
    # The context manager closes the handle even if read() raises.
    with open(filename, 'r', encoding=encoding) as handle:
        return handle.read()

# Load the prepared character sequences: one training sequence per line.
in_filename = 'char_sequences.txt'
raw_text = load_doc(in_filename)
# Skip empty lines. A trailing newline at end of file (or a blank line) would
# otherwise produce an empty sequence, making the batch ragged so it could not
# be stacked into a 2-D array when inputs and targets are separated below.
lines = [line for line in raw_text.split('\n') if line]

# integer encode sequences of characters
# Build the vocabulary from every distinct character in the corpus. Sorting
# makes the character -> index assignment deterministic from run to run.
# NOTE: raw_text still contains its '\n' separators, so '\n' also gets an index.
chars = sorted(set(raw_text))
mapping = {ch: idx for idx, ch in enumerate(chars)}
# Translate each line into the list of its character indices.
sequences = [[mapping[ch] for ch in line] for line in lines]

# vocabulary size
vocab_size = len(mapping)
print('Vocabulary Size: %d' % vocab_size)

# separate into input and output
# Every line must encode to the same length, and that length must be at least 2
# (one or more input characters plus the target character); otherwise fail here
# with a clear message instead of an opaque NumPy shape/indexing error.
seq_lengths = {len(seq) for seq in sequences}
if len(seq_lengths) != 1 or min(seq_lengths) < 2:
    raise ValueError(
        'Expected equal-length sequences of at least 2 characters, '
        'got lengths %s' % sorted(seq_lengths)
    )
sequences = array(sequences)
# All columns but the last are the model input; the last column is the target.
X, y = sequences[:, :-1], sequences[:, -1]
# One-hot encode each input row, stacking to (n_sequences, seq_len - 1, vocab_size).
X = array([to_categorical(x, num_classes=vocab_size) for x in X])
y = to_categorical(y, num_classes=vocab_size)

# define model
# Three stacked LSTM layers. Every layer except the last returns the full
# sequence so the next LSTM receives one vector per time step. A softmax Dense
# layer then produces a probability for each character in the vocabulary.
lstm_units = 250
model = Sequential()
model.add(LSTM(lstm_units, input_shape=(X.shape[1], X.shape[2]), return_sequences=True))
model.add(LSTM(lstm_units, return_sequences=True))
model.add(LSTM(lstm_units))
model.add(Dense(vocab_size, activation='softmax'))
# compile model
# The targets y are one-hot vectors, hence categorical cross-entropy.
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])

# fit model
# Record per-epoch metrics in log.csv (';'-separated). With append=True, an
# existing log is extended rather than overwritten.
training_log = CSVLogger('log.csv', append=True, separator=';')
model.fit(
    X,
    y,
    epochs=30,
    verbose=2,
    callbacks=[training_log],
)

# save the model to file
model.save('model.h5')
# save the mapping
# Persist the character -> index mapping alongside the model. The context
# manager closes (and flushes) the file deterministically instead of relying
# on garbage collection of an anonymous file handle.
with open('mapping.pkl', 'wb') as mapping_file:
    dump(mapping, mapping_file)