feat: get the LSTM running
This commit is contained in:
parent
61e5c74cde
commit
cd6a1f56fd
@ -19,11 +19,11 @@ def load_data(data_path):
|
||||
# convert lists to numpy arraysls
|
||||
X = np.array(data['mfcc'])
|
||||
X = X.reshape(X.shape[0], 1, X.shape[1])
|
||||
print(X.shape)
|
||||
#print(X.shape)
|
||||
|
||||
y = np.array(data["labels"])
|
||||
y = y.reshape(y.shape[0], 1)
|
||||
print(y.shape)
|
||||
#print(y.shape)
|
||||
|
||||
|
||||
print("Data succesfully loaded!")
|
||||
@ -109,7 +109,7 @@ if __name__ == "__main__":
|
||||
print(X_train.shape[1], X_train.shape[2])
|
||||
# create network
|
||||
|
||||
input_shape = (X_train.shape[1], X_train.shape[2]) # 1, 208
|
||||
input_shape = (X_train.shape[1], X_train.shape[2]) # (~2800), 1, 208
|
||||
model = build_model(input_shape)
|
||||
|
||||
# compile model
|
||||
@ -121,7 +121,7 @@ if __name__ == "__main__":
|
||||
model.summary()
|
||||
|
||||
# train model
|
||||
history = model.fit(X_train, y_train, validation_data=(X_validation, y_validation), batch_size=128, epochs=30)
|
||||
history = model.fit(X_train, y_train, validation_data=(X_validation, y_validation), batch_size=64, epochs=30)
|
||||
|
||||
# plot accuracy/error for training and validation
|
||||
plot_history(history)
|
||||
|
Loading…
Reference in New Issue
Block a user