feat: get the LSTM running

This commit is contained in:
Skudalen 2021-07-07 15:28:15 +02:00
parent 61e5c74cde
commit cd6a1f56fd

View File

@ -19,11 +19,11 @@ def load_data(data_path):
# convert lists to numpy arraysls
X = np.array(data['mfcc'])
X = X.reshape(X.shape[0], 1, X.shape[1])
print(X.shape)
#print(X.shape)
y = np.array(data["labels"])
y = y.reshape(y.shape[0], 1)
print(y.shape)
#print(y.shape)
print("Data succesfully loaded!")
@ -109,7 +109,7 @@ if __name__ == "__main__":
print(X_train.shape[1], X_train.shape[2])
# create network
input_shape = (X_train.shape[1], X_train.shape[2]) # 1, 208
input_shape = (X_train.shape[1], X_train.shape[2]) # (~2800), 1, 208
model = build_model(input_shape)
# compile model
@ -121,7 +121,7 @@ if __name__ == "__main__":
model.summary()
# train model
history = model.fit(X_train, y_train, validation_data=(X_validation, y_validation), batch_size=128, epochs=30)
history = model.fit(X_train, y_train, validation_data=(X_validation, y_validation), batch_size=64, epochs=30)
# plot accuracy/error for training and validation
plot_history(history)