chore: improve cross-validation
This commit is contained in:
parent
439238e070
commit
c3fd1fc415
@ -377,7 +377,7 @@ if __name__ == "__main__":
|
|||||||
NR_SUBJECTS = 5
|
NR_SUBJECTS = 5
|
||||||
NR_SESSIONS = 4
|
NR_SESSIONS = 4
|
||||||
BATCH_SIZE = 64
|
BATCH_SIZE = 64
|
||||||
EPOCHS = 5
|
EPOCHS = 30
|
||||||
|
|
||||||
TEST_SESSION_NR = 4
|
TEST_SESSION_NR = 4
|
||||||
VERBOSE = 1
|
VERBOSE = 1
|
||||||
@ -430,8 +430,9 @@ if __name__ == "__main__":
|
|||||||
'''
|
'''
|
||||||
|
|
||||||
|
|
||||||
#'''
|
'''
|
||||||
# ----- Cross validation ------
|
# ----- Cross validation ------
|
||||||
|
# Trained on three sessions, tested on one
|
||||||
average_GRU = session_cross_validation('GRU', X, y, session_lengths, nr_sessions=NR_SESSIONS,
|
average_GRU = session_cross_validation('GRU', X, y, session_lengths, nr_sessions=NR_SESSIONS,
|
||||||
log_to_csv=LOG,
|
log_to_csv=LOG,
|
||||||
batch_size=BATCH_SIZE,
|
batch_size=BATCH_SIZE,
|
||||||
@ -455,5 +456,33 @@ if __name__ == "__main__":
|
|||||||
print('Crossvalidated FFN:', average_FFN)
|
print('Crossvalidated FFN:', average_FFN)
|
||||||
print('Cross-validated CNN_1D:', average_CNN)
|
print('Cross-validated CNN_1D:', average_CNN)
|
||||||
print('\n')
|
print('\n')
|
||||||
#'''
|
'''
|
||||||
|
|
||||||
|
'''
|
||||||
|
# ----- Inverse cross-validation ------
|
||||||
|
# Trained on one session, tested on three
|
||||||
|
average_GRU = session_cross_validation('GRU', X, y, session_lengths, nr_sessions=NR_SESSIONS,
|
||||||
|
log_to_csv=LOG,
|
||||||
|
batch_size=BATCH_SIZE,
|
||||||
|
epochs=EPOCHS)
|
||||||
|
average_LSTM = session_cross_validation('LSTM', X, y, session_lengths, nr_sessions=NR_SESSIONS,
|
||||||
|
log_to_csv=LOG,
|
||||||
|
batch_size=BATCH_SIZE,
|
||||||
|
epochs=EPOCHS)
|
||||||
|
average_FFN = session_cross_validation('FFN', X, y, session_lengths, nr_sessions=NR_SESSIONS,
|
||||||
|
log_to_csv=LOG,
|
||||||
|
batch_size=BATCH_SIZE,
|
||||||
|
epochs=EPOCHS)
|
||||||
|
average_CNN = session_cross_validation('CNN_1D', X, y, session_lengths, nr_sessions=NR_SESSIONS,
|
||||||
|
log_to_csv=LOG,
|
||||||
|
batch_size=BATCH_SIZE,
|
||||||
|
epochs=EPOCHS)
|
||||||
|
|
||||||
|
print('\n')
|
||||||
|
print('Crossvalidated GRU:', average_GRU)
|
||||||
|
print('Crossvalidated LSTM:', average_LSTM)
|
||||||
|
print('Crossvalidated FFN:', average_FFN)
|
||||||
|
print('Cross-validated CNN_1D:', average_CNN)
|
||||||
|
print('\n')
|
||||||
|
'''
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user