feat: make plot func to show network training progress
This commit is contained in:
parent
1d1ec433d8
commit
14d9b65060
@ -28,16 +28,8 @@ class Data_container:
|
||||
self.subject_name = subject_name
|
||||
self.dict_list = [{'left': [None]*8, 'right': [None]*8} for i in range(nr_sessions)]
|
||||
|
||||
|
||||
#self.data_dict_round1 = {'left': [None]*8, 'right': [None]*8}
|
||||
#self.data_dict_round2 = {'left': [None]*8, 'right': [None]*8}
|
||||
#self.data_dict_round3 = {'left': [None]*8, 'right': [None]*8}
|
||||
#self.data_dict_round4 = {'left': [None]*8, 'right': [None]*8}
|
||||
#self.dict_list = [self.data_dict_round1,
|
||||
# self.data_dict_round2,
|
||||
# self.data_dict_round3,
|
||||
# self.data_dict_round4
|
||||
# ]
|
||||
def __str__(self) -> str:
|
||||
return 'Name: {}, \tID: {}'.format(self.subject_name, self.subject_nr)
|
||||
|
||||
class CSV_handler:
|
||||
|
||||
@ -79,30 +71,6 @@ class CSV_handler:
|
||||
# Places the data correctly:
|
||||
data_container.dict_list[session-1][which_arm][emg_nr] = df
|
||||
|
||||
'''
|
||||
if session == 1:
|
||||
if which_arm == 'left':
|
||||
data_container.data_dict_round1['left'][emg_nr] = df # Zero indexed emg_nr in the dict
|
||||
else:
|
||||
data_container.data_dict_round1['right'][emg_nr] = df
|
||||
elif session == 2:
|
||||
if which_arm == 'left':
|
||||
data_container.data_dict_round2['left'][emg_nr] = df
|
||||
else:
|
||||
data_container.data_dict_round2['right'][emg_nr] = df
|
||||
elif session == 3:
|
||||
if which_arm == 'left':
|
||||
data_container.data_dict_round3['left'][emg_nr] = df
|
||||
else:
|
||||
data_container.data_dict_round3['right'][emg_nr] = df
|
||||
elif session == 4:
|
||||
if which_arm == 'left':
|
||||
data_container.data_dict_round4['left'][emg_nr] = df
|
||||
else:
|
||||
data_container.data_dict_round4['right'][emg_nr] = df
|
||||
else:
|
||||
raise IndexError('Not a valid index')
|
||||
'''
|
||||
|
||||
# Links the data container for a subject to the csv_handler object
|
||||
# Input: the subject's data_container
|
||||
@ -120,9 +88,55 @@ class CSV_handler:
|
||||
df = container.dict_list[session - 1].get(which_arm)[emg_nr - 1]
|
||||
return df
|
||||
|
||||
# Loads the data from the csv files into the storing system of the CSV_handler object
|
||||
# Input: None(CSV_handler)
|
||||
|
||||
# Loads data the to the CSV_handler(general load func). Choose data_type: hard, hardPP, soft og softPP as str.
|
||||
# Input: String(datatype you want), direction name of that type
|
||||
# Output: None -> load and stores data
|
||||
def load_data(self, type:str, type_dir_name:str):
|
||||
|
||||
data_path = self.working_dir + '/data/' + type_dir_name
|
||||
subject_id = 100
|
||||
subject_name = 'bruh'
|
||||
nr_sessions = 101
|
||||
container = None
|
||||
session_count = 0
|
||||
|
||||
for i, (path, subject_dir, session_dir) in enumerate(os.walk(data_path)):
|
||||
|
||||
if path is not data_path:
|
||||
|
||||
if subject_dir:
|
||||
session_count = 0
|
||||
subject_id = int(path[-1])
|
||||
subject_name = subject_dir[0].split('_')[0]
|
||||
nr_sessions = len(subject_dir)
|
||||
container = Data_container(subject_id, subject_name, nr_sessions)
|
||||
continue
|
||||
else:
|
||||
session_count += 1
|
||||
|
||||
for f in session_dir:
|
||||
spes_path = os.path.join(path, f)
|
||||
if f == 'myoLeftEmg.csv':
|
||||
for emg_nr in range(8):
|
||||
self.store_df_in_container(spes_path, emg_nr, 'left', container, session_count)
|
||||
elif f == 'myoRightEmg.csv':
|
||||
for emg_nr in range(8):
|
||||
self.store_df_in_container(spes_path, emg_nr, 'right', container, session_count)
|
||||
self.link_container_to_handler(container)
|
||||
self.data_type = type
|
||||
return self.data_container_dict
|
||||
|
||||
# Retrieved data. Send in loaded csv_handler and data detailes you want.
|
||||
# Input: Experiment detailes
|
||||
# Output: DataFrame, samplerate:int
|
||||
def get_data(self, subject_nr, which_arm, session, emg_nr):
|
||||
data_frame = self.get_df_from_data_dict(subject_nr, which_arm, session, emg_nr)
|
||||
samplerate = get_samplerate(data_frame)
|
||||
return data_frame, samplerate
|
||||
|
||||
|
||||
# OBSOLETE
|
||||
def load_hard_PP_emg_data(self):
|
||||
|
||||
# CSV data from subject 1
|
||||
@ -487,55 +501,7 @@ class CSV_handler:
|
||||
self.link_container_to_handler(data_container)
|
||||
self.data_type = 'soft'
|
||||
return self.data_container_dict
|
||||
|
||||
|
||||
def load_general(self, type, type_dir_name:str):
|
||||
|
||||
data_path = self.working_dir + '/data/' + type_dir_name
|
||||
|
||||
for i, (path, subject_dir, session_dir) in enumerate(os.walk(data_path)):
|
||||
|
||||
if path is not data_path:
|
||||
#print(i)
|
||||
#print(path)
|
||||
#print(subject_dir)
|
||||
#print(session_dir)
|
||||
subject_id = 100
|
||||
subject_name = 'bruh'
|
||||
nr_sessions = 101
|
||||
container = None
|
||||
session_count = 0
|
||||
|
||||
if subject_dir:
|
||||
session_count = 0
|
||||
subject_id = int(path[-1])
|
||||
subject_name = subject_dir[0].split('_')[0]
|
||||
nr_sessions = len(subject_dir)
|
||||
container = Data_container(subject_id, subject_name, nr_sessions)
|
||||
continue
|
||||
else:
|
||||
session_count += 1
|
||||
|
||||
for f in session_dir:
|
||||
spes_path = os.path.join(path, f)
|
||||
if f == 'myoLeftEmg.csv':
|
||||
print(path)
|
||||
print(spes_path)
|
||||
for emg_nr in range(8):
|
||||
self.store_df_in_container(spes_path, emg_nr, 'left', container, session_count)
|
||||
elif f == 'myoRightEmg.csv':
|
||||
for emg_nr in range(8):
|
||||
self.store_df_in_container(spes_path, emg_nr, 'right', container, session_count)
|
||||
|
||||
self.data_type = type
|
||||
return self.data_container_dict
|
||||
|
||||
|
||||
|
||||
# Loads data the to the CSV_handler(general load func). Choose data_type: hard, hardPP, soft og softPP as str.
|
||||
# Input: String(datatype you want)
|
||||
# Output: None -> load and stores data
|
||||
def load_data(self, data_type):
|
||||
def load_data_OLD(self, data_type):
|
||||
if data_type == 'hard':
|
||||
self.load_hard_original_emg_data()
|
||||
elif data_type == 'hardPP':
|
||||
@ -547,13 +513,6 @@ class CSV_handler:
|
||||
else:
|
||||
raise Exception('Wrong input')
|
||||
|
||||
# Retrieved data. Send in loaded csv_handler and data detailes you want.
|
||||
# Input: Experiment detailes
|
||||
# Output: DataFrame, samplerate:int
|
||||
def get_data(self, subject_nr, which_arm, session, emg_nr):
|
||||
data_frame = self.get_df_from_data_dict(subject_nr, which_arm, session, emg_nr)
|
||||
samplerate = get_samplerate(data_frame)
|
||||
return data_frame, samplerate
|
||||
|
||||
# NOT IMPLEMENTED
|
||||
def get_keyboard_data(self, filename:str, pres_or_release:str='pressed'):
|
||||
@ -576,19 +535,9 @@ class NN_handler:
|
||||
def __init__(self, csv_handler:CSV_handler) -> None:
|
||||
self.csv_handler = csv_handler
|
||||
# Should med 4 sessions * split nr of samples per person. Each sample is structured like this: [sample_df, samplerate]
|
||||
self.reg_samples_per_subject = {1: [],
|
||||
2: [],
|
||||
3: [],
|
||||
4: [],
|
||||
5: []
|
||||
}
|
||||
self.reg_samples_per_subject = {k+1:[] for k in range(csv_handler.nr_subjects)}
|
||||
# Should med 4 sessions * (~150, 208) of mfcc samples per person. One [DataFrame, session_length_list] per subject
|
||||
self.mfcc_samples_per_subject = {1: [],
|
||||
2: [],
|
||||
3: [],
|
||||
4: [],
|
||||
5: []
|
||||
}
|
||||
self.mfcc_samples_per_subject = {k+1:[] for k in range(csv_handler.nr_subjects)}
|
||||
|
||||
# GET method for reg_samples_dict
|
||||
def get_reg_samples_dict(self) -> dict:
|
||||
|
@ -66,6 +66,169 @@ def plot_train_history(history, val_data=False):
|
||||
|
||||
plt.show()
|
||||
|
||||
# Plots the training history of four networks inverse cross-validated
|
||||
# Input: data, nr of sessions in total, batch_size and epochs
|
||||
# Ouput: None -> plot
|
||||
def plot_4_x_inverse_cross_val(X, y, session_lengths, nr_sessions, batch_size=64, epochs=30):
|
||||
|
||||
history_dict = {'GRU': [],
|
||||
'LSTM': [],
|
||||
'FFN': [],
|
||||
'CNN_1D': []}
|
||||
|
||||
for i in range(nr_sessions):
|
||||
|
||||
X_test_session, X_train_session, y_test_session, y_train_session = prepare_datasets_sessions(X, y, session_lengths, i)
|
||||
|
||||
model_GRU = GRU(input_shape=(1, 208))
|
||||
GRU_h = train(model_GRU, X_train_session, y_train_session, 1, batch_size=batch_size, epochs=epochs)
|
||||
history_dict['GRU'].append(GRU_h)
|
||||
del model_GRU
|
||||
K.clear_session()
|
||||
|
||||
model_LSTM = LSTM(input_shape=(1, 208))
|
||||
LSTM_h = train(model_LSTM, X_train_session, y_train_session, 1, batch_size=batch_size, epochs=epochs)
|
||||
history_dict['LSTM'].append(LSTM_h)
|
||||
del model_LSTM
|
||||
K.clear_session()
|
||||
|
||||
model_FFN = FFN(input_shape=(1, 208))
|
||||
FFN_h = train(model_FFN, X_train_session, y_train_session, 1, batch_size=batch_size, epochs=epochs)
|
||||
history_dict['FFN'].append(FFN_h)
|
||||
del model_FFN
|
||||
K.clear_session()
|
||||
|
||||
model_CNN_1D = CNN_1D(input_shape=(208, 1))
|
||||
X_train_session = np.reshape(X_train_session, (X_train_session.shape[0], 208, 1))
|
||||
X_test_session = np.reshape(X_test_session, (X_test_session.shape[0], 208, 1))
|
||||
CNN_1D_h = train(model_CNN_1D, X_train_session, y_train_session, 1, batch_size=batch_size, epochs=epochs)
|
||||
history_dict['CNN_1D'].append(CNN_1D_h)
|
||||
del model_CNN_1D
|
||||
K.clear_session()
|
||||
|
||||
|
||||
fig, axs = plt.subplots(2, 2, sharey=True)
|
||||
plt.ylim(0, 1)
|
||||
|
||||
# GRU plot:
|
||||
axs[0, 0].plot(history_dict['GRU'][0].history["accuracy"])
|
||||
axs[0, 0].plot(history_dict['GRU'][1].history["accuracy"], 'tab:orange')
|
||||
axs[0, 0].plot(history_dict['GRU'][2].history["accuracy"], 'tab:green')
|
||||
axs[0, 0].plot(history_dict['GRU'][3].history["accuracy"], 'tab:red')
|
||||
axs[0, 0].set_title('GRU')
|
||||
# LSTM plot:
|
||||
axs[0, 1].plot(history_dict['LSTM'][0].history["accuracy"])
|
||||
axs[0, 1].plot(history_dict['LSTM'][1].history["accuracy"], 'tab:orange')
|
||||
axs[0, 1].plot(history_dict['LSTM'][2].history["accuracy"], 'tab:green')
|
||||
axs[0, 1].plot(history_dict['LSTM'][3].history["accuracy"], 'tab:red')
|
||||
axs[0, 1].set_title('LSTM')
|
||||
# FFN plot:
|
||||
axs[1, 0].plot(history_dict['FFN'][0].history["accuracy"])
|
||||
axs[1, 0].plot(history_dict['FFN'][1].history["accuracy"], 'tab:orange')
|
||||
axs[1, 0].plot(history_dict['FFN'][2].history["accuracy"], 'tab:green')
|
||||
axs[1, 0].plot(history_dict['FFN'][3].history["accuracy"], 'tab:red')
|
||||
axs[1, 0].set_title('FFN')
|
||||
# CNN_1D plot:
|
||||
axs[1, 1].plot(history_dict['CNN_1D'][0].history["accuracy"])
|
||||
axs[1, 1].plot(history_dict['CNN_1D'][1].history["accuracy"], 'tab:orange')
|
||||
axs[1, 1].plot(history_dict['CNN_1D'][2].history["accuracy"], 'tab:green')
|
||||
axs[1, 1].plot(history_dict['CNN_1D'][3].history["accuracy"], 'tab:red')
|
||||
axs[1, 1].set_title('CNN_1D')
|
||||
|
||||
for ax in axs.flat:
|
||||
ax.set(xlabel='Epochs', ylabel='Accuracy')
|
||||
|
||||
# Hide x labels and tick labels for top plots and y ticks for right plots.
|
||||
for ax in axs.flat:
|
||||
ax.label_outer()
|
||||
|
||||
plt.show()
|
||||
|
||||
|
||||
# Plots the average training history of four networks inverse cross-validated
|
||||
# Input: data, nr of sessions in total, batch_size and epochs
|
||||
# Ouput: None -> plot
|
||||
def plot_4_x_average_val(X, y, session_lengths, nr_sessions, batch_size=64, epochs=30):
|
||||
|
||||
history_dict = {'GRU': [],
|
||||
'LSTM': [],
|
||||
'FFN': [],
|
||||
'CNN_1D': []}
|
||||
|
||||
for i in range(nr_sessions):
|
||||
|
||||
X_val_session, X_train_session, y_val_session, y_train_session = prepare_datasets_sessions(X, y, session_lengths, i)
|
||||
|
||||
model_GRU = GRU(input_shape=(1, 208))
|
||||
GRU_h = train(model_GRU, X_train_session, y_train_session, 1, batch_size=batch_size, epochs=epochs,
|
||||
X_validation=X_val_session, y_validation=y_val_session)
|
||||
history_dict['GRU'].append(GRU_h)
|
||||
del model_GRU
|
||||
K.clear_session()
|
||||
|
||||
model_LSTM = LSTM(input_shape=(1, 208))
|
||||
LSTM_h = train(model_LSTM, X_train_session, y_train_session, 1, batch_size=batch_size, epochs=epochs,
|
||||
X_validation=X_val_session, y_validation=y_val_session)
|
||||
history_dict['LSTM'].append(LSTM_h)
|
||||
del model_LSTM
|
||||
K.clear_session()
|
||||
|
||||
model_FFN = FFN(input_shape=(1, 208))
|
||||
FFN_h = train(model_FFN, X_train_session, y_train_session, 1, batch_size=batch_size, epochs=epochs,
|
||||
X_validation=X_val_session, y_validation=y_val_session)
|
||||
history_dict['FFN'].append(FFN_h)
|
||||
del model_FFN
|
||||
K.clear_session()
|
||||
|
||||
model_CNN_1D = CNN_1D(input_shape=(208, 1))
|
||||
X_train_session = np.reshape(X_train_session, (X_train_session.shape[0], 208, 1))
|
||||
X_val_session = np.reshape(X_val_session, (X_val_session.shape[0], 208, 1))
|
||||
CNN_1D_h = train(model_CNN_1D, X_train_session, y_train_session, 1, batch_size=batch_size, epochs=epochs,
|
||||
X_validation=X_val_session, y_validation=y_val_session)
|
||||
history_dict['CNN_1D'].append(CNN_1D_h)
|
||||
del model_CNN_1D
|
||||
K.clear_session()
|
||||
|
||||
|
||||
for key, value in history_dict.items():
|
||||
print(key)
|
||||
print(value[0].history['val_accuracy'])
|
||||
|
||||
for key in history_dict:
|
||||
val_key = key + '_val'
|
||||
history_dict[key] = (np.average(x, y, z, c) for x, y, z, c in zip(history_dict[key][0]['val_accuracy'],
|
||||
history_dict[key][1]['val_accuracy'],
|
||||
history_dict[key][2]['val_accuracy'],
|
||||
history_dict[key][3]['val_accuracy']))
|
||||
train_key = key + '_train'
|
||||
history_dict[key] = (np.average(x, y, z, c) for x, y, z, c in zip(history_dict[key][0]['accuracy'],
|
||||
history_dict[key][1]['accuracy'],
|
||||
history_dict[key][2]['accuracy'],
|
||||
history_dict[key][3]['accuracy']))
|
||||
|
||||
fig, axs = plt.subplots(2, sharey=True)
|
||||
plt.ylim(0, 1)
|
||||
|
||||
# Plot:
|
||||
axs[0].plot(history_dict['GRU-val'])
|
||||
axs[0].plot(history_dict['LSTM_val'], 'tab:orange')
|
||||
axs[0].plot(history_dict['FFN_val'], 'tab:green')
|
||||
axs[0].plot(history_dict['CNN_1D_val'], 'tab:red')
|
||||
axs[0].set_title('Avarage validation accuracy with cross-session-training')
|
||||
|
||||
axs[1].plot(history_dict['GRU_train'])
|
||||
axs[1].plot(history_dict['LSTM_train'], 'tab:orange')
|
||||
axs[1].plot(history_dict['FFM_train'], 'tab:green')
|
||||
axs[1].plot(history_dict['CNN_1D_train'], 'tab:red')
|
||||
axs[1].set_title('Avarage training accuracy with cross-session-training')
|
||||
|
||||
for ax in axs.flat:
|
||||
ax.set(xlabel='Epochs', ylabel='Accuracy')
|
||||
|
||||
plt.show()
|
||||
|
||||
|
||||
|
||||
# Takes in data and labels, and splits it into train, validation and test sets by percentage
|
||||
# Input: Data, labels, whether to shuffle, % validatiion, % test
|
||||
# Ouput: X_train, X_validation, X_test, y_train, y_validation, y_test
|
||||
@ -173,7 +336,7 @@ def train( model, X_train, y_train, verbose, batch_size=64, epochs=30,
|
||||
#csv_path = str(Path.cwd()) + '/logs/{}/{}_train_log.csv'.format(MODEL_NAME, MODEL_NAME)
|
||||
#csv_logger = CSVLogger(csv_path, append=False)
|
||||
|
||||
if X_validation != None:
|
||||
if X_validation.any():
|
||||
history = model.fit(X_train,
|
||||
y_train,
|
||||
validation_data=(X_validation, y_validation),
|
||||
@ -422,12 +585,12 @@ if __name__ == "__main__":
|
||||
NR_SUBJECTS = 5
|
||||
NR_SESSIONS = 4
|
||||
BATCH_SIZE = 64
|
||||
EPOCHS = 30
|
||||
EPOCHS = 10
|
||||
|
||||
TEST_SESSION_NR = 4
|
||||
VERBOSE = 1
|
||||
MODEL_NAME = 'CNN_1D'
|
||||
LOG = True
|
||||
LOG = False
|
||||
|
||||
# ----- Get prepared data: train, validation, and test ------
|
||||
# X_train.shape = (2806-X_test, 1, 208)
|
||||
@ -496,30 +659,30 @@ if __name__ == "__main__":
|
||||
epochs=EPOCHS)
|
||||
|
||||
print('\n')
|
||||
print('Crossvalidated GRU:', average_GRU)
|
||||
print('Crossvalidated LSTM:', average_LSTM)
|
||||
print('Crossvalidated FFN:', average_FFN)
|
||||
print('Cross-validated GRU:', average_GRU)
|
||||
print('Cross-validated LSTM:', average_LSTM)
|
||||
print('Cross-validated FFN:', average_FFN)
|
||||
print('Cross-validated CNN_1D:', average_CNN)
|
||||
print('\n')
|
||||
'''
|
||||
|
||||
#'''
|
||||
'''
|
||||
# ----- Inverse cross-validation ------
|
||||
# Trained on one session, tested on three
|
||||
average_GRU = inverse_session_cross_validation('GRU', X, y, session_lengths, nr_sessions=NR_SESSIONS,
|
||||
log_to_csv=True,
|
||||
log_to_csv=LOG,
|
||||
batch_size=BATCH_SIZE,
|
||||
epochs=EPOCHS)
|
||||
average_LSTM = inverse_session_cross_validation('LSTM', X, y, session_lengths, nr_sessions=NR_SESSIONS,
|
||||
log_to_csv=True,
|
||||
log_to_csv=LOG,
|
||||
batch_size=BATCH_SIZE,
|
||||
epochs=EPOCHS)
|
||||
average_FFN = inverse_session_cross_validation('FFN', X, y, session_lengths, nr_sessions=NR_SESSIONS,
|
||||
log_to_csv=True,
|
||||
log_to_csv=LOG,
|
||||
batch_size=BATCH_SIZE,
|
||||
epochs=EPOCHS)
|
||||
average_CNN = inverse_session_cross_validation('CNN_1D', X, y, session_lengths, nr_sessions=NR_SESSIONS,
|
||||
log_to_csv=True,
|
||||
log_to_csv=LOG,
|
||||
batch_size=BATCH_SIZE,
|
||||
epochs=EPOCHS)
|
||||
|
||||
@ -529,5 +692,11 @@ if __name__ == "__main__":
|
||||
print('Cross-validated one-session-train FFN:', average_FFN)
|
||||
print('Cross-validated one-session-train CNN_1D:', average_CNN)
|
||||
print('\n')
|
||||
#'''
|
||||
'''
|
||||
|
||||
# ----- PLOTTING ------
|
||||
|
||||
#plot_4xinverse_cross_val(X, y, session_lengths, NR_SESSIONS, epochs=30)
|
||||
plot_4_x_average_val(X, y, session_lengths, NR_SESSIONS, epochs=2)
|
||||
|
||||
|
||||
|
@ -247,11 +247,14 @@ if __name__ == "__main__":
|
||||
|
||||
soft_dir_name = 'Exp20201205_2myo_softType'
|
||||
hard_dir_name = 'Exp20201205_2myo_hardType'
|
||||
JSON_TEST_NAME = 'TEST_mfcc.json'
|
||||
|
||||
csv_handler = CSV_handler(NR_SUBJECTS, NR_SESSIONS)
|
||||
dict = csv_handler.load_general('soft', soft_dir_name)
|
||||
dict = csv_handler.load_data('soft', soft_dir_name)
|
||||
|
||||
pretty(dict)
|
||||
nn_handler = NN_handler(csv_handler)
|
||||
nn_handler.store_mfcc_samples()
|
||||
nn_handler.save_json_mfcc(JSON_TEST_NAME)
|
||||
|
||||
|
||||
|
||||
|
592111
TEST_mfcc.json
Normal file
592111
TEST_mfcc.json
Normal file
File diff suppressed because it is too large
Load Diff
Binary file not shown.
Loading…
Reference in New Issue
Block a user