feat: get the LSTM running

This commit is contained in:
Skudalen 2021-07-07 15:28:15 +02:00
parent 61e5c74cde
commit cd6a1f56fd

View File

@ -19,11 +19,11 @@ def load_data(data_path):
# convert lists to numpy arraysls # convert lists to numpy arraysls
X = np.array(data['mfcc']) X = np.array(data['mfcc'])
X = X.reshape(X.shape[0], 1, X.shape[1]) X = X.reshape(X.shape[0], 1, X.shape[1])
print(X.shape) #print(X.shape)
y = np.array(data["labels"]) y = np.array(data["labels"])
y = y.reshape(y.shape[0], 1) y = y.reshape(y.shape[0], 1)
print(y.shape) #print(y.shape)
print("Data succesfully loaded!") print("Data succesfully loaded!")
@ -109,7 +109,7 @@ if __name__ == "__main__":
print(X_train.shape[1], X_train.shape[2]) print(X_train.shape[1], X_train.shape[2])
# create network # create network
input_shape = (X_train.shape[1], X_train.shape[2]) # 1, 208 input_shape = (X_train.shape[1], X_train.shape[2]) # (~2800), 1, 208
model = build_model(input_shape) model = build_model(input_shape)
# compile model # compile model
@ -121,7 +121,7 @@ if __name__ == "__main__":
model.summary() model.summary()
# train model # train model
history = model.fit(X_train, y_train, validation_data=(X_validation, y_validation), batch_size=128, epochs=30) history = model.fit(X_train, y_train, validation_data=(X_validation, y_validation), batch_size=64, epochs=30)
# plot accuracy/error for training and validation # plot accuracy/error for training and validation
plot_history(history) plot_history(history)