feat: add cross-validation that trains on one session only

This commit is contained in:
Skudalen 2021-07-19 09:39:56 +02:00
parent 3d512addb8
commit 5e80e465ad
2 changed files with 51 additions and 14 deletions

BIN
.DS_Store vendored

Binary file not shown.

View File

@ -285,6 +285,47 @@ def session_cross_validation(model_name:str, X, y, session_lengths, nr_sessions,
return average_result, session_training_results return average_result, session_training_results
# Retrieves data sets for each session as train set and evalutes on the others.
# the average of networks trained om them
# Input: raw data, session_lengths list, total nr of sessions, batch_size, and nr of epochs
# Ouput: tuple(cross validation average, list(result for each dataset(len=nr_sessions)))
def inverse_session_cross_validation(model_name:str, X, y, session_lengths, nr_sessions, batch_size=64, epochs=30):
session_training_results = []
for i in range(nr_sessions):
X_test_session, X_train_session, y_test_session, y_train_session = prepare_datasets_sessions(X, y, session_lengths, i)
# Model:
if model_name == 'LSTM':
model = LSTM(input_shape=(1, 208))
elif model_name == 'GRU':
model = GRU(input_shape=(1, 208))
elif model_name == 'CNN_1D':
X_train_session = np.reshape(X_train_session, (X_train_session.shape[0], 208, 1))
X_test_session = np.reshape(X_test_session, (X_test_session.shape[0], 208, 1))
model = CNN_1D(input_shape=(208, 1))
elif model_name == 'FFN':
model = FFN(input_shape=(1, 208))
else:
raise Exception('Model not found')
train(model, X_train_session, y_train_session, verbose=1, batch_size=batch_size, epochs=epochs)
test_loss, test_acc = model.evaluate(X_test_session, y_test_session, verbose=0)
session_training_results.append(test_acc)
#if log_to_csv:
#prediction_csv_logger(X_test_session, y_test_session, model_name, model, i)
del model
K.clear_session()
#print('Session', i, 'as test data gives accuracy:', test_acc)
average_result = statistics.mean((session_training_results))
return average_result, session_training_results
# Takes in test data and logs input data and the prediction from a model # Takes in test data and logs input data and the prediction from a model
# Input: raw data, session_lengths list, total nr of sessions, batch_size, and nr of epochs # Input: raw data, session_lengths list, total nr of sessions, batch_size, and nr of epochs
# Ouput: tuple(cross validation average, list(result for each dataset(len=nr_sessions))) # Ouput: tuple(cross validation average, list(result for each dataset(len=nr_sessions)))
@ -458,31 +499,27 @@ if __name__ == "__main__":
print('\n') print('\n')
''' '''
''' #'''
# ----- Inverse cross-validation ------ # ----- Inverse cross-validation ------
# Trained on one session, tested on three # Trained on one session, tested on three
average_GRU = session_cross_validation('GRU', X, y, session_lengths, nr_sessions=NR_SESSIONS, average_GRU = inverse_session_cross_validation('GRU', X, y, session_lengths, nr_sessions=NR_SESSIONS,
log_to_csv=LOG,
batch_size=BATCH_SIZE, batch_size=BATCH_SIZE,
epochs=EPOCHS) epochs=EPOCHS)
average_LSTM = session_cross_validation('LSTM', X, y, session_lengths, nr_sessions=NR_SESSIONS, average_LSTM = inverse_session_cross_validation('LSTM', X, y, session_lengths, nr_sessions=NR_SESSIONS,
log_to_csv=LOG,
batch_size=BATCH_SIZE, batch_size=BATCH_SIZE,
epochs=EPOCHS) epochs=EPOCHS)
average_FFN = session_cross_validation('FFN', X, y, session_lengths, nr_sessions=NR_SESSIONS, average_FFN = inverse_session_cross_validation('FFN', X, y, session_lengths, nr_sessions=NR_SESSIONS,
log_to_csv=LOG,
batch_size=BATCH_SIZE, batch_size=BATCH_SIZE,
epochs=EPOCHS) epochs=EPOCHS)
average_CNN = session_cross_validation('CNN_1D', X, y, session_lengths, nr_sessions=NR_SESSIONS, average_CNN = inverse_session_cross_validation('CNN_1D', X, y, session_lengths, nr_sessions=NR_SESSIONS,
log_to_csv=LOG,
batch_size=BATCH_SIZE, batch_size=BATCH_SIZE,
epochs=EPOCHS) epochs=EPOCHS)
print('\n') print('\n')
print('Crossvalidated GRU:', average_GRU) print('Cross-validated one-session-train GRU:', average_GRU)
print('Crossvalidated LSTM:', average_LSTM) print('Cross-validated one-session-train LSTM:', average_LSTM)
print('Crossvalidated FFN:', average_FFN) print('Cross-validated one-session-train FFN:', average_FFN)
print('Cross-validated CNN_1D:', average_CNN) print('Cross-validated one-session-train CNN_1D:', average_CNN)
print('\n') print('\n')
''' #'''