From 5e80e465adbec64486a7a1b7c77330aa8c0d5d69 Mon Sep 17 00:00:00 2001 From: Skudalen Date: Mon, 19 Jul 2021 09:39:56 +0200 Subject: [PATCH] feat: add cross-validation that trains on one session only --- .DS_Store | Bin 8196 -> 6148 bytes Neural_Network_Analysis.py | 65 +++++++++++++++++++++++++++++-------- 2 files changed, 51 insertions(+), 14 deletions(-) diff --git a/.DS_Store b/.DS_Store index 72c945af9ad5f4ccd8105b118be89b7cef749865..3576c0f8eef457d53eb19a238be92b358e923238 100644 GIT binary patch delta 112 zcmZp1XfcprU|?W$DortDU=RQ@Ie-{Mvv5r;6q~50$jG}fU^g=(?`9r>`;4343-K{d yHWRU(SblLaI|qj#Gf*WE2yg=lSCFQSh2NPc^UHXGj9_4b7zeU}VRJms9A*HDM-!R= delta 426 zcmZoMXmOBWU|?W$DortDU;r^WfEYvza8E20o2aMAD6}zPH}hr%jz7$c**Q2SHn1=X zZRTOQ&sZ%0Yhp^adJ*l zetr(nd*9mCa5l&#r&)NK(Rfy3{4Cu2;pamH~ zi}HYeE(ZA&2$O*#@j$CnfniX>Pznr(REA=>^Cy2}bJ$$WGn<)9f*a@)S5Sy=7UcNO YJegm_bFx1V2L~f0G8i_;^UPre0K(B#yZ`_I diff --git a/Neural_Network_Analysis.py b/Neural_Network_Analysis.py index 2f9831f..e5a55e5 100644 --- a/Neural_Network_Analysis.py +++ b/Neural_Network_Analysis.py @@ -285,6 +285,47 @@ def session_cross_validation(model_name:str, X, y, session_lengths, nr_sessions, return average_result, session_training_results +# Retrieves data sets for each session as train set and evalutes on the others. +# the average of networks trained om them +# Input: raw data, session_lengths list, total nr of sessions, batch_size, and nr of epochs +# Ouput: tuple(cross validation average, list(result for each dataset(len=nr_sessions))) +def inverse_session_cross_validation(model_name:str, X, y, session_lengths, nr_sessions, batch_size=64, epochs=30): + session_training_results = [] + for i in range(nr_sessions): + + X_test_session, X_train_session, y_test_session, y_train_session = prepare_datasets_sessions(X, y, session_lengths, i) + + # Model: + if model_name == 'LSTM': + model = LSTM(input_shape=(1, 208)) + + elif model_name == 'GRU': + model = GRU(input_shape=(1, 208)) + + elif model_name == 'CNN_1D': + X_train_session = np.reshape(X_train_session, (X_train_session.shape[0], 208, 1)) + X_test_session = np.reshape(X_test_session, (X_test_session.shape[0], 208, 1)) + model = CNN_1D(input_shape=(208, 1)) + + elif model_name == 'FFN': + model = FFN(input_shape=(1, 208)) + + else: + raise Exception('Model not found') + + train(model, X_train_session, y_train_session, verbose=1, batch_size=batch_size, epochs=epochs) + test_loss, test_acc = model.evaluate(X_test_session, y_test_session, verbose=0) + session_training_results.append(test_acc) + #if log_to_csv: + #prediction_csv_logger(X_test_session, y_test_session, model_name, model, i) + del model + K.clear_session() + #print('Session', i, 'as test data gives accuracy:', test_acc) + + average_result = statistics.mean((session_training_results)) + + return average_result, session_training_results + # Takes in test data and logs input data and the prediction from a model # Input: raw data, session_lengths list, total nr of sessions, batch_size, and nr of epochs # Ouput: tuple(cross validation average, list(result for each dataset(len=nr_sessions))) @@ -458,31 +499,27 @@ if __name__ == "__main__": print('\n') ''' - ''' + #''' # ----- Inverse cross-validation ------ # Trained on one session, tested on three - average_GRU = session_cross_validation('GRU', X, y, session_lengths, nr_sessions=NR_SESSIONS, - log_to_csv=LOG, + average_GRU = inverse_session_cross_validation('GRU', X, y, session_lengths, nr_sessions=NR_SESSIONS, batch_size=BATCH_SIZE, epochs=EPOCHS) - average_LSTM = session_cross_validation('LSTM', X, y, session_lengths, nr_sessions=NR_SESSIONS, - log_to_csv=LOG, + average_LSTM = inverse_session_cross_validation('LSTM', X, y, session_lengths, nr_sessions=NR_SESSIONS, batch_size=BATCH_SIZE, epochs=EPOCHS) - average_FFN = session_cross_validation('FFN', X, y, session_lengths, nr_sessions=NR_SESSIONS, - log_to_csv=LOG, + average_FFN = inverse_session_cross_validation('FFN', X, y, session_lengths, nr_sessions=NR_SESSIONS, batch_size=BATCH_SIZE, epochs=EPOCHS) - average_CNN = session_cross_validation('CNN_1D', X, y, session_lengths, nr_sessions=NR_SESSIONS, - log_to_csv=LOG, + average_CNN = inverse_session_cross_validation('CNN_1D', X, y, session_lengths, nr_sessions=NR_SESSIONS, batch_size=BATCH_SIZE, epochs=EPOCHS) print('\n') - print('Crossvalidated GRU:', average_GRU) - print('Crossvalidated LSTM:', average_LSTM) - print('Crossvalidated FFN:', average_FFN) - print('Cross-validated CNN_1D:', average_CNN) + print('Cross-validated one-session-train GRU:', average_GRU) + print('Cross-validated one-session-train LSTM:', average_LSTM) + print('Cross-validated one-session-train FFN:', average_FFN) + print('Cross-validated one-session-train CNN_1D:', average_CNN) print('\n') - ''' + #'''