feat: make log possibility for inverse cross-val too
This commit is contained in:
parent
5e80e465ad
commit
1d1ec433d8
@ -289,7 +289,7 @@ def session_cross_validation(model_name:str, X, y, session_lengths, nr_sessions,
|
||||
# the average of networks trained om them
|
||||
# Input: raw data, session_lengths list, total nr of sessions, batch_size, and nr of epochs
|
||||
# Ouput: tuple(cross validation average, list(result for each dataset(len=nr_sessions)))
|
||||
def inverse_session_cross_validation(model_name:str, X, y, session_lengths, nr_sessions, batch_size=64, epochs=30):
|
||||
def inverse_session_cross_validation(model_name:str, X, y, session_lengths, nr_sessions, log_to_csv=True, batch_size=64, epochs=30):
|
||||
session_training_results = []
|
||||
for i in range(nr_sessions):
|
||||
|
||||
@ -316,8 +316,9 @@ def inverse_session_cross_validation(model_name:str, X, y, session_lengths, nr_s
|
||||
train(model, X_train_session, y_train_session, verbose=1, batch_size=batch_size, epochs=epochs)
|
||||
test_loss, test_acc = model.evaluate(X_test_session, y_test_session, verbose=0)
|
||||
session_training_results.append(test_acc)
|
||||
#if log_to_csv:
|
||||
#prediction_csv_logger(X_test_session, y_test_session, model_name, model, i)
|
||||
if log_to_csv:
|
||||
custom_path = '/{}_train_session{}_log.csv'
|
||||
prediction_csv_logger(X_test_session, y_test_session, model_name, model, i, custom_path)
|
||||
del model
|
||||
K.clear_session()
|
||||
#print('Session', i, 'as test data gives accuracy:', test_acc)
|
||||
@ -329,9 +330,12 @@ def inverse_session_cross_validation(model_name:str, X, y, session_lengths, nr_s
|
||||
# Takes in test data and logs input data and the prediction from a model
|
||||
# Input: raw data, session_lengths list, total nr of sessions, batch_size, and nr of epochs
|
||||
# Ouput: tuple(cross validation average, list(result for each dataset(len=nr_sessions)))
|
||||
def prediction_csv_logger(X, y, model_name, model, session_nr):
|
||||
def prediction_csv_logger(X, y, model_name, model, session_nr, custom_path=None):
|
||||
|
||||
csv_path = str(Path.cwd()) + '/logs/{}/{}_session{}_log.csv'.format(model_name, model_name, session_nr+1)
|
||||
if custom_path:
|
||||
path = str(Path.cwd()) + '/logs/{}' + custom_path
|
||||
csv_path = path.format(model_name, model_name, session_nr+1)
|
||||
|
||||
layerOutput = model.predict(X, verbose=0)
|
||||
|
||||
@ -503,15 +507,19 @@ if __name__ == "__main__":
|
||||
# ----- Inverse cross-validation ------
|
||||
# Trained on one session, tested on three
|
||||
average_GRU = inverse_session_cross_validation('GRU', X, y, session_lengths, nr_sessions=NR_SESSIONS,
|
||||
log_to_csv=True,
|
||||
batch_size=BATCH_SIZE,
|
||||
epochs=EPOCHS)
|
||||
average_LSTM = inverse_session_cross_validation('LSTM', X, y, session_lengths, nr_sessions=NR_SESSIONS,
|
||||
log_to_csv=True,
|
||||
batch_size=BATCH_SIZE,
|
||||
epochs=EPOCHS)
|
||||
average_FFN = inverse_session_cross_validation('FFN', X, y, session_lengths, nr_sessions=NR_SESSIONS,
|
||||
log_to_csv=True,
|
||||
batch_size=BATCH_SIZE,
|
||||
epochs=EPOCHS)
|
||||
average_CNN = inverse_session_cross_validation('CNN_1D', X, y, session_lengths, nr_sessions=NR_SESSIONS,
|
||||
log_to_csv=True,
|
||||
batch_size=BATCH_SIZE,
|
||||
epochs=EPOCHS)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user