feat: make log possibility for inverse cross-val too

This commit is contained in:
Skudalen 2021-07-19 10:34:33 +02:00
parent 5e80e465ad
commit 1d1ec433d8
2 changed files with 13 additions and 5 deletions

BIN
.DS_Store vendored

Binary file not shown.

View File

@ -289,7 +289,7 @@ def session_cross_validation(model_name:str, X, y, session_lengths, nr_sessions,
# the average of networks trained om them
# Input: raw data, session_lengths list, total nr of sessions, batch_size, and nr of epochs
# Ouput: tuple(cross validation average, list(result for each dataset(len=nr_sessions)))
def inverse_session_cross_validation(model_name:str, X, y, session_lengths, nr_sessions, batch_size=64, epochs=30):
def inverse_session_cross_validation(model_name:str, X, y, session_lengths, nr_sessions, log_to_csv=True, batch_size=64, epochs=30):
session_training_results = []
for i in range(nr_sessions):
@ -316,8 +316,9 @@ def inverse_session_cross_validation(model_name:str, X, y, session_lengths, nr_s
train(model, X_train_session, y_train_session, verbose=1, batch_size=batch_size, epochs=epochs)
test_loss, test_acc = model.evaluate(X_test_session, y_test_session, verbose=0)
session_training_results.append(test_acc)
#if log_to_csv:
#prediction_csv_logger(X_test_session, y_test_session, model_name, model, i)
if log_to_csv:
custom_path = '/{}_train_session{}_log.csv'
prediction_csv_logger(X_test_session, y_test_session, model_name, model, i, custom_path)
del model
K.clear_session()
#print('Session', i, 'as test data gives accuracy:', test_acc)
@ -329,9 +330,12 @@ def inverse_session_cross_validation(model_name:str, X, y, session_lengths, nr_s
# Takes in test data and logs input data and the prediction from a model
# Input: raw data, session_lengths list, total nr of sessions, batch_size, and nr of epochs
# Ouput: tuple(cross validation average, list(result for each dataset(len=nr_sessions)))
def prediction_csv_logger(X, y, model_name, model, session_nr):
def prediction_csv_logger(X, y, model_name, model, session_nr, custom_path=None):
csv_path = str(Path.cwd()) + '/logs/{}/{}_session{}_log.csv'.format(model_name, model_name, session_nr+1)
if custom_path:
path = str(Path.cwd()) + '/logs/{}' + custom_path
csv_path = path.format(model_name, model_name, session_nr+1)
layerOutput = model.predict(X, verbose=0)
@ -502,16 +506,20 @@ if __name__ == "__main__":
#'''
# ----- Inverse cross-validation ------
# Trained on one session, tested on three
average_GRU = inverse_session_cross_validation('GRU', X, y, session_lengths, nr_sessions=NR_SESSIONS,
average_GRU = inverse_session_cross_validation('GRU', X, y, session_lengths, nr_sessions=NR_SESSIONS,
log_to_csv=True,
batch_size=BATCH_SIZE,
epochs=EPOCHS)
average_LSTM = inverse_session_cross_validation('LSTM', X, y, session_lengths, nr_sessions=NR_SESSIONS,
log_to_csv=True,
batch_size=BATCH_SIZE,
epochs=EPOCHS)
average_FFN = inverse_session_cross_validation('FFN', X, y, session_lengths, nr_sessions=NR_SESSIONS,
log_to_csv=True,
batch_size=BATCH_SIZE,
epochs=EPOCHS)
average_CNN = inverse_session_cross_validation('CNN_1D', X, y, session_lengths, nr_sessions=NR_SESSIONS,
log_to_csv=True,
batch_size=BATCH_SIZE,
epochs=EPOCHS)