chore: improve cross-validation

This commit is contained in:
Skudalen 2021-07-16 18:11:21 +02:00
parent 439238e070
commit c3fd1fc415
2 changed files with 32 additions and 3 deletions

BIN
.DS_Store vendored

Binary file not shown.

View File

@ -377,7 +377,7 @@ if __name__ == "__main__":
NR_SUBJECTS = 5
NR_SESSIONS = 4
BATCH_SIZE = 64
EPOCHS = 5
EPOCHS = 30
TEST_SESSION_NR = 4
VERBOSE = 1
@ -430,8 +430,9 @@ if __name__ == "__main__":
'''
#'''
'''
# ----- Cross validation ------
# Trained on three sessions, tested on one
average_GRU = session_cross_validation('GRU', X, y, session_lengths, nr_sessions=NR_SESSIONS,
log_to_csv=LOG,
batch_size=BATCH_SIZE,
@ -455,5 +456,33 @@ if __name__ == "__main__":
print('Crossvalidated FFN:', average_FFN)
print('Cross-validated CNN_1D:', average_CNN)
print('\n')
#'''
'''
'''
# ----- Inverse cross-validation ------
# Trained on one session, tested on three
average_GRU = session_cross_validation('GRU', X, y, session_lengths, nr_sessions=NR_SESSIONS,
log_to_csv=LOG,
batch_size=BATCH_SIZE,
epochs=EPOCHS)
average_LSTM = session_cross_validation('LSTM', X, y, session_lengths, nr_sessions=NR_SESSIONS,
log_to_csv=LOG,
batch_size=BATCH_SIZE,
epochs=EPOCHS)
average_FFN = session_cross_validation('FFN', X, y, session_lengths, nr_sessions=NR_SESSIONS,
log_to_csv=LOG,
batch_size=BATCH_SIZE,
epochs=EPOCHS)
average_CNN = session_cross_validation('CNN_1D', X, y, session_lengths, nr_sessions=NR_SESSIONS,
log_to_csv=LOG,
batch_size=BATCH_SIZE,
epochs=EPOCHS)
print('\n')
print('Crossvalidated GRU:', average_GRU)
print('Crossvalidated LSTM:', average_LSTM)
print('Crossvalidated FFN:', average_FFN)
print('Cross-validated CNN_1D:', average_CNN)
print('\n')
'''