feat: add cross-validation that trains on one session only
This commit is contained in:
parent
3d512addb8
commit
5e80e465ad
@ -285,6 +285,47 @@ def session_cross_validation(model_name:str, X, y, session_lengths, nr_sessions,
|
||||
|
||||
return average_result, session_training_results
|
||||
|
||||
# Retrieves data sets for each session as train set and evalutes on the others.
|
||||
# the average of networks trained om them
|
||||
# Input: raw data, session_lengths list, total nr of sessions, batch_size, and nr of epochs
|
||||
# Ouput: tuple(cross validation average, list(result for each dataset(len=nr_sessions)))
|
||||
def inverse_session_cross_validation(model_name:str, X, y, session_lengths, nr_sessions, batch_size=64, epochs=30):
|
||||
session_training_results = []
|
||||
for i in range(nr_sessions):
|
||||
|
||||
X_test_session, X_train_session, y_test_session, y_train_session = prepare_datasets_sessions(X, y, session_lengths, i)
|
||||
|
||||
# Model:
|
||||
if model_name == 'LSTM':
|
||||
model = LSTM(input_shape=(1, 208))
|
||||
|
||||
elif model_name == 'GRU':
|
||||
model = GRU(input_shape=(1, 208))
|
||||
|
||||
elif model_name == 'CNN_1D':
|
||||
X_train_session = np.reshape(X_train_session, (X_train_session.shape[0], 208, 1))
|
||||
X_test_session = np.reshape(X_test_session, (X_test_session.shape[0], 208, 1))
|
||||
model = CNN_1D(input_shape=(208, 1))
|
||||
|
||||
elif model_name == 'FFN':
|
||||
model = FFN(input_shape=(1, 208))
|
||||
|
||||
else:
|
||||
raise Exception('Model not found')
|
||||
|
||||
train(model, X_train_session, y_train_session, verbose=1, batch_size=batch_size, epochs=epochs)
|
||||
test_loss, test_acc = model.evaluate(X_test_session, y_test_session, verbose=0)
|
||||
session_training_results.append(test_acc)
|
||||
#if log_to_csv:
|
||||
#prediction_csv_logger(X_test_session, y_test_session, model_name, model, i)
|
||||
del model
|
||||
K.clear_session()
|
||||
#print('Session', i, 'as test data gives accuracy:', test_acc)
|
||||
|
||||
average_result = statistics.mean((session_training_results))
|
||||
|
||||
return average_result, session_training_results
|
||||
|
||||
# Takes in test data and logs input data and the prediction from a model
|
||||
# Input: raw data, session_lengths list, total nr of sessions, batch_size, and nr of epochs
|
||||
# Ouput: tuple(cross validation average, list(result for each dataset(len=nr_sessions)))
|
||||
@ -458,31 +499,27 @@ if __name__ == "__main__":
|
||||
print('\n')
|
||||
'''
|
||||
|
||||
'''
|
||||
#'''
|
||||
# ----- Inverse cross-validation ------
|
||||
# Trained on one session, tested on three
|
||||
average_GRU = session_cross_validation('GRU', X, y, session_lengths, nr_sessions=NR_SESSIONS,
|
||||
log_to_csv=LOG,
|
||||
average_GRU = inverse_session_cross_validation('GRU', X, y, session_lengths, nr_sessions=NR_SESSIONS,
|
||||
batch_size=BATCH_SIZE,
|
||||
epochs=EPOCHS)
|
||||
average_LSTM = session_cross_validation('LSTM', X, y, session_lengths, nr_sessions=NR_SESSIONS,
|
||||
log_to_csv=LOG,
|
||||
average_LSTM = inverse_session_cross_validation('LSTM', X, y, session_lengths, nr_sessions=NR_SESSIONS,
|
||||
batch_size=BATCH_SIZE,
|
||||
epochs=EPOCHS)
|
||||
average_FFN = session_cross_validation('FFN', X, y, session_lengths, nr_sessions=NR_SESSIONS,
|
||||
log_to_csv=LOG,
|
||||
average_FFN = inverse_session_cross_validation('FFN', X, y, session_lengths, nr_sessions=NR_SESSIONS,
|
||||
batch_size=BATCH_SIZE,
|
||||
epochs=EPOCHS)
|
||||
average_CNN = session_cross_validation('CNN_1D', X, y, session_lengths, nr_sessions=NR_SESSIONS,
|
||||
log_to_csv=LOG,
|
||||
average_CNN = inverse_session_cross_validation('CNN_1D', X, y, session_lengths, nr_sessions=NR_SESSIONS,
|
||||
batch_size=BATCH_SIZE,
|
||||
epochs=EPOCHS)
|
||||
|
||||
print('\n')
|
||||
print('Crossvalidated GRU:', average_GRU)
|
||||
print('Crossvalidated LSTM:', average_LSTM)
|
||||
print('Crossvalidated FFN:', average_FFN)
|
||||
print('Cross-validated CNN_1D:', average_CNN)
|
||||
print('Cross-validated one-session-train GRU:', average_GRU)
|
||||
print('Cross-validated one-session-train LSTM:', average_LSTM)
|
||||
print('Cross-validated one-session-train FFN:', average_FFN)
|
||||
print('Cross-validated one-session-train CNN_1D:', average_CNN)
|
||||
print('\n')
|
||||
'''
|
||||
#'''
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user