chore: improve cross-validation
This commit is contained in:
		
							parent
							
								
									439238e070
								
							
						
					
					
						commit
						c3fd1fc415
					
				| @ -377,7 +377,7 @@ if __name__ == "__main__": | |||||||
|     NR_SUBJECTS = 5 |     NR_SUBJECTS = 5 | ||||||
|     NR_SESSIONS = 4 |     NR_SESSIONS = 4 | ||||||
|     BATCH_SIZE = 64 |     BATCH_SIZE = 64 | ||||||
|     EPOCHS = 5 |     EPOCHS = 30 | ||||||
| 
 | 
 | ||||||
|     TEST_SESSION_NR = 4 |     TEST_SESSION_NR = 4 | ||||||
|     VERBOSE = 1 |     VERBOSE = 1 | ||||||
| @ -430,8 +430,9 @@ if __name__ == "__main__": | |||||||
|     ''' |     ''' | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|     #''' |     ''' | ||||||
|     # ----- Cross validation ------ |     # ----- Cross validation ------ | ||||||
|  |     # Trained on three sessions, tested on one | ||||||
|     average_GRU = session_cross_validation('GRU', X, y, session_lengths, nr_sessions=NR_SESSIONS,  |     average_GRU = session_cross_validation('GRU', X, y, session_lengths, nr_sessions=NR_SESSIONS,  | ||||||
|                                                                         log_to_csv=LOG, |                                                                         log_to_csv=LOG, | ||||||
|                                                                         batch_size=BATCH_SIZE,  |                                                                         batch_size=BATCH_SIZE,  | ||||||
| @ -455,5 +456,33 @@ if __name__ == "__main__": | |||||||
|     print('Crossvalidated FFN:', average_FFN) |     print('Crossvalidated FFN:', average_FFN) | ||||||
|     print('Cross-validated CNN_1D:', average_CNN) |     print('Cross-validated CNN_1D:', average_CNN) | ||||||
|     print('\n') |     print('\n') | ||||||
|     #''' |     ''' | ||||||
|  | 
 | ||||||
|  |     ''' | ||||||
|  |     # ----- Inverse cross-validation ------ | ||||||
|  |     # Trained on one session, tested on three | ||||||
|  |     average_GRU = session_cross_validation('GRU', X, y, session_lengths, nr_sessions=NR_SESSIONS,  | ||||||
|  |                                                                         log_to_csv=LOG, | ||||||
|  |                                                                         batch_size=BATCH_SIZE,  | ||||||
|  |                                                                         epochs=EPOCHS) | ||||||
|  |     average_LSTM = session_cross_validation('LSTM', X, y, session_lengths, nr_sessions=NR_SESSIONS, | ||||||
|  |                                                                         log_to_csv=LOG, | ||||||
|  |                                                                         batch_size=BATCH_SIZE,  | ||||||
|  |                                                                         epochs=EPOCHS) | ||||||
|  |     average_FFN = session_cross_validation('FFN', X, y, session_lengths, nr_sessions=NR_SESSIONS, | ||||||
|  |                                                                         log_to_csv=LOG, | ||||||
|  |                                                                         batch_size=BATCH_SIZE,  | ||||||
|  |                                                                         epochs=EPOCHS) | ||||||
|  |     average_CNN = session_cross_validation('CNN_1D', X, y, session_lengths, nr_sessions=NR_SESSIONS, | ||||||
|  |                                                                         log_to_csv=LOG, | ||||||
|  |                                                                         batch_size=BATCH_SIZE,  | ||||||
|  |                                                                         epochs=EPOCHS) | ||||||
|  | 
 | ||||||
|  |     print('\n') | ||||||
|  |     print('Crossvalidated GRU:', average_GRU) | ||||||
|  |     print('Crossvalidated LSTM:', average_LSTM) | ||||||
|  |     print('Crossvalidated FFN:', average_FFN) | ||||||
|  |     print('Cross-validated CNN_1D:', average_CNN) | ||||||
|  |     print('\n') | ||||||
|  |     ''' | ||||||
| 
 | 
 | ||||||
|  | |||||||
		Loading…
	
		Reference in New Issue
	
	Block a user