From cd6a1f56fdb8bc71d00a063190ca4a002a25835f Mon Sep 17 00:00:00 2001 From: Skudalen Date: Wed, 7 Jul 2021 15:28:15 +0200 Subject: [PATCH] feat: get the LSTM running --- Neural_Network_Analysis.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/Neural_Network_Analysis.py b/Neural_Network_Analysis.py index 2ac3186..9fb53b8 100644 --- a/Neural_Network_Analysis.py +++ b/Neural_Network_Analysis.py @@ -19,11 +19,11 @@ def load_data(data_path): # convert lists to numpy arraysls X = np.array(data['mfcc']) X = X.reshape(X.shape[0], 1, X.shape[1]) - print(X.shape) + #print(X.shape) y = np.array(data["labels"]) y = y.reshape(y.shape[0], 1) - print(y.shape) + #print(y.shape) print("Data succesfully loaded!") @@ -109,7 +109,7 @@ if __name__ == "__main__": print(X_train.shape[1], X_train.shape[2]) # create network - input_shape = (X_train.shape[1], X_train.shape[2]) # 1, 208 + input_shape = (X_train.shape[1], X_train.shape[2]) # (~2800), 1, 208 model = build_model(input_shape) # compile model @@ -121,7 +121,7 @@ if __name__ == "__main__": model.summary() # train model - history = model.fit(X_train, y_train, validation_data=(X_validation, y_validation), batch_size=128, epochs=30) + history = model.fit(X_train, y_train, validation_data=(X_validation, y_validation), batch_size=64, epochs=30) # plot accuracy/error for training and validation plot_history(history)