feat: add cross-validation that trains on one session only
This commit is contained in:
parent
3d512addb8
commit
5e80e465ad
@ -285,6 +285,47 @@ def session_cross_validation(model_name:str, X, y, session_lengths, nr_sessions,
|
|||||||
|
|
||||||
return average_result, session_training_results
|
return average_result, session_training_results
|
||||||
|
|
||||||
|
# Retrieves data sets for each session as train set and evalutes on the others.
|
||||||
|
# the average of networks trained om them
|
||||||
|
# Input: raw data, session_lengths list, total nr of sessions, batch_size, and nr of epochs
|
||||||
|
# Ouput: tuple(cross validation average, list(result for each dataset(len=nr_sessions)))
|
||||||
|
def inverse_session_cross_validation(model_name:str, X, y, session_lengths, nr_sessions, batch_size=64, epochs=30):
|
||||||
|
session_training_results = []
|
||||||
|
for i in range(nr_sessions):
|
||||||
|
|
||||||
|
X_test_session, X_train_session, y_test_session, y_train_session = prepare_datasets_sessions(X, y, session_lengths, i)
|
||||||
|
|
||||||
|
# Model:
|
||||||
|
if model_name == 'LSTM':
|
||||||
|
model = LSTM(input_shape=(1, 208))
|
||||||
|
|
||||||
|
elif model_name == 'GRU':
|
||||||
|
model = GRU(input_shape=(1, 208))
|
||||||
|
|
||||||
|
elif model_name == 'CNN_1D':
|
||||||
|
X_train_session = np.reshape(X_train_session, (X_train_session.shape[0], 208, 1))
|
||||||
|
X_test_session = np.reshape(X_test_session, (X_test_session.shape[0], 208, 1))
|
||||||
|
model = CNN_1D(input_shape=(208, 1))
|
||||||
|
|
||||||
|
elif model_name == 'FFN':
|
||||||
|
model = FFN(input_shape=(1, 208))
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise Exception('Model not found')
|
||||||
|
|
||||||
|
train(model, X_train_session, y_train_session, verbose=1, batch_size=batch_size, epochs=epochs)
|
||||||
|
test_loss, test_acc = model.evaluate(X_test_session, y_test_session, verbose=0)
|
||||||
|
session_training_results.append(test_acc)
|
||||||
|
#if log_to_csv:
|
||||||
|
#prediction_csv_logger(X_test_session, y_test_session, model_name, model, i)
|
||||||
|
del model
|
||||||
|
K.clear_session()
|
||||||
|
#print('Session', i, 'as test data gives accuracy:', test_acc)
|
||||||
|
|
||||||
|
average_result = statistics.mean((session_training_results))
|
||||||
|
|
||||||
|
return average_result, session_training_results
|
||||||
|
|
||||||
# Takes in test data and logs input data and the prediction from a model
|
# Takes in test data and logs input data and the prediction from a model
|
||||||
# Input: raw data, session_lengths list, total nr of sessions, batch_size, and nr of epochs
|
# Input: raw data, session_lengths list, total nr of sessions, batch_size, and nr of epochs
|
||||||
# Ouput: tuple(cross validation average, list(result for each dataset(len=nr_sessions)))
|
# Ouput: tuple(cross validation average, list(result for each dataset(len=nr_sessions)))
|
||||||
@ -458,31 +499,27 @@ if __name__ == "__main__":
|
|||||||
print('\n')
|
print('\n')
|
||||||
'''
|
'''
|
||||||
|
|
||||||
'''
|
#'''
|
||||||
# ----- Inverse cross-validation ------
|
# ----- Inverse cross-validation ------
|
||||||
# Trained on one session, tested on three
|
# Trained on one session, tested on three
|
||||||
average_GRU = session_cross_validation('GRU', X, y, session_lengths, nr_sessions=NR_SESSIONS,
|
average_GRU = inverse_session_cross_validation('GRU', X, y, session_lengths, nr_sessions=NR_SESSIONS,
|
||||||
log_to_csv=LOG,
|
|
||||||
batch_size=BATCH_SIZE,
|
batch_size=BATCH_SIZE,
|
||||||
epochs=EPOCHS)
|
epochs=EPOCHS)
|
||||||
average_LSTM = session_cross_validation('LSTM', X, y, session_lengths, nr_sessions=NR_SESSIONS,
|
average_LSTM = inverse_session_cross_validation('LSTM', X, y, session_lengths, nr_sessions=NR_SESSIONS,
|
||||||
log_to_csv=LOG,
|
|
||||||
batch_size=BATCH_SIZE,
|
batch_size=BATCH_SIZE,
|
||||||
epochs=EPOCHS)
|
epochs=EPOCHS)
|
||||||
average_FFN = session_cross_validation('FFN', X, y, session_lengths, nr_sessions=NR_SESSIONS,
|
average_FFN = inverse_session_cross_validation('FFN', X, y, session_lengths, nr_sessions=NR_SESSIONS,
|
||||||
log_to_csv=LOG,
|
|
||||||
batch_size=BATCH_SIZE,
|
batch_size=BATCH_SIZE,
|
||||||
epochs=EPOCHS)
|
epochs=EPOCHS)
|
||||||
average_CNN = session_cross_validation('CNN_1D', X, y, session_lengths, nr_sessions=NR_SESSIONS,
|
average_CNN = inverse_session_cross_validation('CNN_1D', X, y, session_lengths, nr_sessions=NR_SESSIONS,
|
||||||
log_to_csv=LOG,
|
|
||||||
batch_size=BATCH_SIZE,
|
batch_size=BATCH_SIZE,
|
||||||
epochs=EPOCHS)
|
epochs=EPOCHS)
|
||||||
|
|
||||||
print('\n')
|
print('\n')
|
||||||
print('Crossvalidated GRU:', average_GRU)
|
print('Cross-validated one-session-train GRU:', average_GRU)
|
||||||
print('Crossvalidated LSTM:', average_LSTM)
|
print('Cross-validated one-session-train LSTM:', average_LSTM)
|
||||||
print('Crossvalidated FFN:', average_FFN)
|
print('Cross-validated one-session-train FFN:', average_FFN)
|
||||||
print('Cross-validated CNN_1D:', average_CNN)
|
print('Cross-validated one-session-train CNN_1D:', average_CNN)
|
||||||
print('\n')
|
print('\n')
|
||||||
'''
|
#'''
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user