chore: improve cross-validation
This commit is contained in:
		
							parent
							
								
									439238e070
								
							
						
					
					
						commit
						c3fd1fc415
					
				| @ -377,7 +377,7 @@ if __name__ == "__main__": | ||||
|     NR_SUBJECTS = 5 | ||||
|     NR_SESSIONS = 4 | ||||
|     BATCH_SIZE = 64 | ||||
|     EPOCHS = 5 | ||||
|     EPOCHS = 30 | ||||
| 
 | ||||
|     TEST_SESSION_NR = 4 | ||||
|     VERBOSE = 1 | ||||
| @ -430,8 +430,9 @@ if __name__ == "__main__": | ||||
|     ''' | ||||
| 
 | ||||
| 
 | ||||
|     #''' | ||||
|     ''' | ||||
|     # ----- Cross validation ------ | ||||
|     # Trained on three sessions, tested on one | ||||
|     average_GRU = session_cross_validation('GRU', X, y, session_lengths, nr_sessions=NR_SESSIONS,  | ||||
|                                                                         log_to_csv=LOG, | ||||
|                                                                         batch_size=BATCH_SIZE,  | ||||
| @ -455,5 +456,33 @@ if __name__ == "__main__": | ||||
|     print('Crossvalidated FFN:', average_FFN) | ||||
|     print('Cross-validated CNN_1D:', average_CNN) | ||||
|     print('\n') | ||||
|     #''' | ||||
|     ''' | ||||
| 
 | ||||
|     ''' | ||||
|     # ----- Inverse cross-validation ------ | ||||
|     # Trained on one session, tested on three | ||||
|     average_GRU = session_cross_validation('GRU', X, y, session_lengths, nr_sessions=NR_SESSIONS,  | ||||
|                                                                         log_to_csv=LOG, | ||||
|                                                                         batch_size=BATCH_SIZE,  | ||||
|                                                                         epochs=EPOCHS) | ||||
|     average_LSTM = session_cross_validation('LSTM', X, y, session_lengths, nr_sessions=NR_SESSIONS, | ||||
|                                                                         log_to_csv=LOG, | ||||
|                                                                         batch_size=BATCH_SIZE,  | ||||
|                                                                         epochs=EPOCHS) | ||||
|     average_FFN = session_cross_validation('FFN', X, y, session_lengths, nr_sessions=NR_SESSIONS, | ||||
|                                                                         log_to_csv=LOG, | ||||
|                                                                         batch_size=BATCH_SIZE,  | ||||
|                                                                         epochs=EPOCHS) | ||||
|     average_CNN = session_cross_validation('CNN_1D', X, y, session_lengths, nr_sessions=NR_SESSIONS, | ||||
|                                                                         log_to_csv=LOG, | ||||
|                                                                         batch_size=BATCH_SIZE,  | ||||
|                                                                         epochs=EPOCHS) | ||||
| 
 | ||||
|     print('\n') | ||||
|     print('Crossvalidated GRU:', average_GRU) | ||||
|     print('Crossvalidated LSTM:', average_LSTM) | ||||
|     print('Crossvalidated FFN:', average_FFN) | ||||
|     print('Cross-validated CNN_1D:', average_CNN) | ||||
|     print('\n') | ||||
|     ''' | ||||
| 
 | ||||
|  | ||||
		Loading…
	
		Reference in New Issue
	
	Block a user