feat: make log possibility for inverse cross-val too
This commit is contained in:
parent
5e80e465ad
commit
1d1ec433d8
@ -289,7 +289,7 @@ def session_cross_validation(model_name:str, X, y, session_lengths, nr_sessions,
|
|||||||
# the average of networks trained om them
|
# the average of networks trained om them
|
||||||
# Input: raw data, session_lengths list, total nr of sessions, batch_size, and nr of epochs
|
# Input: raw data, session_lengths list, total nr of sessions, batch_size, and nr of epochs
|
||||||
# Ouput: tuple(cross validation average, list(result for each dataset(len=nr_sessions)))
|
# Ouput: tuple(cross validation average, list(result for each dataset(len=nr_sessions)))
|
||||||
def inverse_session_cross_validation(model_name:str, X, y, session_lengths, nr_sessions, batch_size=64, epochs=30):
|
def inverse_session_cross_validation(model_name:str, X, y, session_lengths, nr_sessions, log_to_csv=True, batch_size=64, epochs=30):
|
||||||
session_training_results = []
|
session_training_results = []
|
||||||
for i in range(nr_sessions):
|
for i in range(nr_sessions):
|
||||||
|
|
||||||
@ -316,8 +316,9 @@ def inverse_session_cross_validation(model_name:str, X, y, session_lengths, nr_s
|
|||||||
train(model, X_train_session, y_train_session, verbose=1, batch_size=batch_size, epochs=epochs)
|
train(model, X_train_session, y_train_session, verbose=1, batch_size=batch_size, epochs=epochs)
|
||||||
test_loss, test_acc = model.evaluate(X_test_session, y_test_session, verbose=0)
|
test_loss, test_acc = model.evaluate(X_test_session, y_test_session, verbose=0)
|
||||||
session_training_results.append(test_acc)
|
session_training_results.append(test_acc)
|
||||||
#if log_to_csv:
|
if log_to_csv:
|
||||||
#prediction_csv_logger(X_test_session, y_test_session, model_name, model, i)
|
custom_path = '/{}_train_session{}_log.csv'
|
||||||
|
prediction_csv_logger(X_test_session, y_test_session, model_name, model, i, custom_path)
|
||||||
del model
|
del model
|
||||||
K.clear_session()
|
K.clear_session()
|
||||||
#print('Session', i, 'as test data gives accuracy:', test_acc)
|
#print('Session', i, 'as test data gives accuracy:', test_acc)
|
||||||
@ -329,9 +330,12 @@ def inverse_session_cross_validation(model_name:str, X, y, session_lengths, nr_s
|
|||||||
# Takes in test data and logs input data and the prediction from a model
|
# Takes in test data and logs input data and the prediction from a model
|
||||||
# Input: raw data, session_lengths list, total nr of sessions, batch_size, and nr of epochs
|
# Input: raw data, session_lengths list, total nr of sessions, batch_size, and nr of epochs
|
||||||
# Ouput: tuple(cross validation average, list(result for each dataset(len=nr_sessions)))
|
# Ouput: tuple(cross validation average, list(result for each dataset(len=nr_sessions)))
|
||||||
def prediction_csv_logger(X, y, model_name, model, session_nr):
|
def prediction_csv_logger(X, y, model_name, model, session_nr, custom_path=None):
|
||||||
|
|
||||||
csv_path = str(Path.cwd()) + '/logs/{}/{}_session{}_log.csv'.format(model_name, model_name, session_nr+1)
|
csv_path = str(Path.cwd()) + '/logs/{}/{}_session{}_log.csv'.format(model_name, model_name, session_nr+1)
|
||||||
|
if custom_path:
|
||||||
|
path = str(Path.cwd()) + '/logs/{}' + custom_path
|
||||||
|
csv_path = path.format(model_name, model_name, session_nr+1)
|
||||||
|
|
||||||
layerOutput = model.predict(X, verbose=0)
|
layerOutput = model.predict(X, verbose=0)
|
||||||
|
|
||||||
@ -502,16 +506,20 @@ if __name__ == "__main__":
|
|||||||
#'''
|
#'''
|
||||||
# ----- Inverse cross-validation ------
|
# ----- Inverse cross-validation ------
|
||||||
# Trained on one session, tested on three
|
# Trained on one session, tested on three
|
||||||
average_GRU = inverse_session_cross_validation('GRU', X, y, session_lengths, nr_sessions=NR_SESSIONS,
|
average_GRU = inverse_session_cross_validation('GRU', X, y, session_lengths, nr_sessions=NR_SESSIONS,
|
||||||
|
log_to_csv=True,
|
||||||
batch_size=BATCH_SIZE,
|
batch_size=BATCH_SIZE,
|
||||||
epochs=EPOCHS)
|
epochs=EPOCHS)
|
||||||
average_LSTM = inverse_session_cross_validation('LSTM', X, y, session_lengths, nr_sessions=NR_SESSIONS,
|
average_LSTM = inverse_session_cross_validation('LSTM', X, y, session_lengths, nr_sessions=NR_SESSIONS,
|
||||||
|
log_to_csv=True,
|
||||||
batch_size=BATCH_SIZE,
|
batch_size=BATCH_SIZE,
|
||||||
epochs=EPOCHS)
|
epochs=EPOCHS)
|
||||||
average_FFN = inverse_session_cross_validation('FFN', X, y, session_lengths, nr_sessions=NR_SESSIONS,
|
average_FFN = inverse_session_cross_validation('FFN', X, y, session_lengths, nr_sessions=NR_SESSIONS,
|
||||||
|
log_to_csv=True,
|
||||||
batch_size=BATCH_SIZE,
|
batch_size=BATCH_SIZE,
|
||||||
epochs=EPOCHS)
|
epochs=EPOCHS)
|
||||||
average_CNN = inverse_session_cross_validation('CNN_1D', X, y, session_lengths, nr_sessions=NR_SESSIONS,
|
average_CNN = inverse_session_cross_validation('CNN_1D', X, y, session_lengths, nr_sessions=NR_SESSIONS,
|
||||||
|
log_to_csv=True,
|
||||||
batch_size=BATCH_SIZE,
|
batch_size=BATCH_SIZE,
|
||||||
epochs=EPOCHS)
|
epochs=EPOCHS)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user