chore: improve cross-validation
This commit is contained in:
parent
439238e070
commit
c3fd1fc415
@ -377,7 +377,7 @@ if __name__ == "__main__":
|
||||
NR_SUBJECTS = 5
|
||||
NR_SESSIONS = 4
|
||||
BATCH_SIZE = 64
|
||||
EPOCHS = 5
|
||||
EPOCHS = 30
|
||||
|
||||
TEST_SESSION_NR = 4
|
||||
VERBOSE = 1
|
||||
@ -430,8 +430,9 @@ if __name__ == "__main__":
|
||||
'''
|
||||
|
||||
|
||||
#'''
|
||||
'''
|
||||
# ----- Cross validation ------
|
||||
# Trained on three sessions, tested on one
|
||||
average_GRU = session_cross_validation('GRU', X, y, session_lengths, nr_sessions=NR_SESSIONS,
|
||||
log_to_csv=LOG,
|
||||
batch_size=BATCH_SIZE,
|
||||
@ -455,5 +456,33 @@ if __name__ == "__main__":
|
||||
print('Crossvalidated FFN:', average_FFN)
|
||||
print('Cross-validated CNN_1D:', average_CNN)
|
||||
print('\n')
|
||||
#'''
|
||||
'''
|
||||
|
||||
'''
|
||||
# ----- Inverse cross-validation ------
|
||||
# Trained on one session, tested on three
|
||||
average_GRU = session_cross_validation('GRU', X, y, session_lengths, nr_sessions=NR_SESSIONS,
|
||||
log_to_csv=LOG,
|
||||
batch_size=BATCH_SIZE,
|
||||
epochs=EPOCHS)
|
||||
average_LSTM = session_cross_validation('LSTM', X, y, session_lengths, nr_sessions=NR_SESSIONS,
|
||||
log_to_csv=LOG,
|
||||
batch_size=BATCH_SIZE,
|
||||
epochs=EPOCHS)
|
||||
average_FFN = session_cross_validation('FFN', X, y, session_lengths, nr_sessions=NR_SESSIONS,
|
||||
log_to_csv=LOG,
|
||||
batch_size=BATCH_SIZE,
|
||||
epochs=EPOCHS)
|
||||
average_CNN = session_cross_validation('CNN_1D', X, y, session_lengths, nr_sessions=NR_SESSIONS,
|
||||
log_to_csv=LOG,
|
||||
batch_size=BATCH_SIZE,
|
||||
epochs=EPOCHS)
|
||||
|
||||
print('\n')
|
||||
print('Crossvalidated GRU:', average_GRU)
|
||||
print('Crossvalidated LSTM:', average_LSTM)
|
||||
print('Crossvalidated FFN:', average_FFN)
|
||||
print('Cross-validated CNN_1D:', average_CNN)
|
||||
print('\n')
|
||||
'''
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user