feat: add plot func in order to compare validation
on three sessions after testing on one
This commit is contained in:
parent
14d9b65060
commit
cb7da4c657
@ -12,6 +12,7 @@ from keras.callbacks import Callback, CSVLogger, ModelCheckpoint
|
||||
from pathlib import Path
|
||||
import pandas as pd
|
||||
import matplotlib.pyplot as plt
|
||||
#from matplotlib.legend import _get_legend_handles_
|
||||
import statistics
|
||||
import csv
|
||||
|
||||
@ -150,81 +151,97 @@ def plot_4_x_inverse_cross_val(X, y, session_lengths, nr_sessions, batch_size=64
|
||||
# Ouput: None -> plot
|
||||
def plot_4_x_average_val(X, y, session_lengths, nr_sessions, batch_size=64, epochs=30):
|
||||
|
||||
history_dict = {'GRU': [],
|
||||
'LSTM': [],
|
||||
'FFN': [],
|
||||
'CNN_1D': []}
|
||||
history_dict = {'GRU_train': [],
|
||||
'LSTM_train': [],
|
||||
'FFN_train': [],
|
||||
'CNN_1D_train': []}
|
||||
history_dict_val = {'GRU_val': [],
|
||||
'LSTM_val': [],
|
||||
'FFN_val': [],
|
||||
'CNN_1D_val': []}
|
||||
|
||||
for i in range(nr_sessions):
|
||||
|
||||
# Prepare data
|
||||
X_val_session, X_train_session, y_val_session, y_train_session = prepare_datasets_sessions(X, y, session_lengths, i)
|
||||
|
||||
# GRU
|
||||
model_GRU = GRU(input_shape=(1, 208))
|
||||
GRU_h = train(model_GRU, X_train_session, y_train_session, 1, batch_size=batch_size, epochs=epochs,
|
||||
X_validation=X_val_session, y_validation=y_val_session)
|
||||
history_dict['GRU'].append(GRU_h)
|
||||
history_dict['GRU_train'].append(GRU_h.history['accuracy'])
|
||||
history_dict_val['GRU_val'].append(GRU_h.history['val_accuracy'])
|
||||
del model_GRU
|
||||
K.clear_session()
|
||||
|
||||
# LSTM
|
||||
model_LSTM = LSTM(input_shape=(1, 208))
|
||||
LSTM_h = train(model_LSTM, X_train_session, y_train_session, 1, batch_size=batch_size, epochs=epochs,
|
||||
X_validation=X_val_session, y_validation=y_val_session)
|
||||
history_dict['LSTM'].append(LSTM_h)
|
||||
history_dict['LSTM_train'].append(LSTM_h.history['accuracy'])
|
||||
history_dict_val['LSTM_val'].append(LSTM_h.history['val_accuracy'])
|
||||
del model_LSTM
|
||||
K.clear_session()
|
||||
|
||||
# FFN
|
||||
model_FFN = FFN(input_shape=(1, 208))
|
||||
FFN_h = train(model_FFN, X_train_session, y_train_session, 1, batch_size=batch_size, epochs=epochs,
|
||||
X_validation=X_val_session, y_validation=y_val_session)
|
||||
history_dict['FFN'].append(FFN_h)
|
||||
history_dict['FFN_train'].append(FFN_h.history['accuracy'])
|
||||
history_dict_val['FFN_val'].append(FFN_h.history['val_accuracy'])
|
||||
del model_FFN
|
||||
K.clear_session()
|
||||
|
||||
# CNN_1D
|
||||
model_CNN_1D = CNN_1D(input_shape=(208, 1))
|
||||
X_train_session = np.reshape(X_train_session, (X_train_session.shape[0], 208, 1))
|
||||
X_val_session = np.reshape(X_val_session, (X_val_session.shape[0], 208, 1))
|
||||
CNN_1D_h = train(model_CNN_1D, X_train_session, y_train_session, 1, batch_size=batch_size, epochs=epochs,
|
||||
X_validation=X_val_session, y_validation=y_val_session)
|
||||
history_dict['CNN_1D'].append(CNN_1D_h)
|
||||
history_dict['CNN_1D_train'].append(CNN_1D_h.history['accuracy'])
|
||||
history_dict_val['CNN_1D_val'].append(CNN_1D_h.history['val_accuracy'])
|
||||
del model_CNN_1D
|
||||
K.clear_session()
|
||||
|
||||
|
||||
for key, value in history_dict.items():
|
||||
print(key)
|
||||
print(value[0].history['val_accuracy'])
|
||||
|
||||
# Averaging out session training for each network
|
||||
for key in history_dict:
|
||||
val_key = key + '_val'
|
||||
history_dict[key] = (np.average(x, y, z, c) for x, y, z, c in zip(history_dict[key][0]['val_accuracy'],
|
||||
history_dict[key][1]['val_accuracy'],
|
||||
history_dict[key][2]['val_accuracy'],
|
||||
history_dict[key][3]['val_accuracy']))
|
||||
train_key = key + '_train'
|
||||
history_dict[key] = (np.average(x, y, z, c) for x, y, z, c in zip(history_dict[key][0]['accuracy'],
|
||||
history_dict[key][1]['accuracy'],
|
||||
history_dict[key][2]['accuracy'],
|
||||
history_dict[key][3]['accuracy']))
|
||||
history_dict[key] = list(np.average([x, y, z, c]) for x, y, z, c in list(zip(*history_dict[key])))
|
||||
for key in history_dict_val:
|
||||
history_dict_val[key] = list(np.average([x, y, z, c]) for x, y, z, c in list(zip(*history_dict_val[key])))
|
||||
|
||||
fig, axs = plt.subplots(2, sharey=True)
|
||||
plt.ylim(0, 1)
|
||||
'''
|
||||
history_dict = {'GRU_train': [0.5, 0.8],
|
||||
'LSTM_train': [0.5, 0.9],
|
||||
'FFN_train': [0.75, 0.8],
|
||||
'CNN_1D_train': [0.8, 0.95]}
|
||||
history_dict_val = {'GRU_val': [0.5, 0.8],
|
||||
'LSTM_val': [0.5, 0.9],
|
||||
'FFN_val': [0.75, 0.8],
|
||||
'CNN_1D_val': [0.8, 0.95]}
|
||||
'''
|
||||
|
||||
# Plot:
|
||||
axs[0].plot(history_dict['GRU-val'])
|
||||
axs[0].plot(history_dict['LSTM_val'], 'tab:orange')
|
||||
axs[0].plot(history_dict['FFN_val'], 'tab:green')
|
||||
axs[0].plot(history_dict['CNN_1D_val'], 'tab:red')
|
||||
axs[0].set_title('Avarage validation accuracy with cross-session-training')
|
||||
fig, axs = plt.subplots(2, sharey=True)
|
||||
plt.ylim(0, 1)
|
||||
plt.subplots_adjust(hspace=1.0, top=0.85, bottom=0.15, right=0.75)
|
||||
fig.suptitle('Avarage accuracy with cross-session-training', fontsize=16)
|
||||
|
||||
axs[1].plot(history_dict['GRU_train'])
|
||||
axs[1].plot(history_dict['LSTM_train'], 'tab:orange')
|
||||
axs[1].plot(history_dict['FFM_train'], 'tab:green')
|
||||
axs[1].plot(history_dict['CNN_1D_train'], 'tab:red')
|
||||
axs[1].set_title('Avarage training accuracy with cross-session-training')
|
||||
axs[0].plot(history_dict['GRU_train'], label='GRU')
|
||||
axs[0].plot(history_dict['LSTM_train'], 'tab:orange', label='LSTM')
|
||||
axs[0].plot(history_dict['FFN_train'], 'tab:green', label='FFN')
|
||||
axs[0].plot(history_dict['CNN_1D_train'], 'tab:red', label='CNN_1D')
|
||||
axs[0].set_title('Training accuracy')
|
||||
|
||||
|
||||
axs[1].plot(history_dict_val['GRU_val'], label='GRU')
|
||||
axs[1].plot(history_dict_val['LSTM_val'], 'tab:orange', label='LSTM')
|
||||
axs[1].plot(history_dict_val['FFN_val'], 'tab:green', label='FFN')
|
||||
axs[1].plot(history_dict_val['CNN_1D_val'], 'tab:red', label='CNN_1D')
|
||||
axs[1].set_title('Validation accuracy')
|
||||
|
||||
for ax in axs.flat:
|
||||
ax.set(xlabel='Epochs', ylabel='Accuracy')
|
||||
|
||||
plt.legend(bbox_to_anchor=(1.05, 1.5), title='Networks', loc='center left')
|
||||
plt.show()
|
||||
|
||||
|
||||
@ -697,6 +714,6 @@ if __name__ == "__main__":
|
||||
# ----- PLOTTING ------
|
||||
|
||||
#plot_4xinverse_cross_val(X, y, session_lengths, NR_SESSIONS, epochs=30)
|
||||
plot_4_x_average_val(X, y, session_lengths, NR_SESSIONS, epochs=2)
|
||||
plot_4_x_average_val(X, y, session_lengths, NR_SESSIONS, epochs=30)
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user