From 2d1376cd6d0dd3b350eb47c785d4cbaee18c807f Mon Sep 17 00:00:00 2001 From: em474re Date: Tue, 21 Sep 2021 13:17:22 +0200 Subject: [PATCH] transformer --- results/trans_hand_vgg/devel.predictions.csv | 296 +++++++++++++++++++ results/trans_hand_vgg/test.predictions.csv | 284 ++++++++++++++++++ src/transformer_hand_vgg.py | 278 +++++++++++++++++ 3 files changed, 858 insertions(+) create mode 100644 results/trans_hand_vgg/devel.predictions.csv create mode 100644 results/trans_hand_vgg/test.predictions.csv create mode 100644 src/transformer_hand_vgg.py diff --git a/results/trans_hand_vgg/devel.predictions.csv b/results/trans_hand_vgg/devel.predictions.csv new file mode 100644 index 0000000..c17e4a8 --- /dev/null +++ b/results/trans_hand_vgg/devel.predictions.csv @@ -0,0 +1,296 @@ +filename,prediction +devel_001.wav,False +devel_002.wav,False +devel_003.wav,False +devel_004.wav,False +devel_005.wav,False +devel_006.wav,False +devel_007.wav,False +devel_008.wav,False +devel_009.wav,False +devel_010.wav,False +devel_011.wav,False +devel_012.wav,False +devel_013.wav,False +devel_014.wav,False +devel_015.wav,False +devel_016.wav,False +devel_017.wav,False +devel_018.wav,False +devel_019.wav,False +devel_020.wav,False +devel_021.wav,False +devel_022.wav,False +devel_023.wav,False +devel_024.wav,False +devel_025.wav,False +devel_026.wav,False +devel_027.wav,False +devel_028.wav,False +devel_029.wav,False +devel_030.wav,False +devel_031.wav,False +devel_032.wav,False +devel_033.wav,False +devel_034.wav,False +devel_035.wav,False +devel_036.wav,False +devel_037.wav,False +devel_038.wav,False +devel_039.wav,False +devel_040.wav,False +devel_041.wav,False +devel_042.wav,False +devel_043.wav,False +devel_044.wav,False +devel_045.wav,False +devel_046.wav,False +devel_047.wav,False +devel_048.wav,False +devel_049.wav,False +devel_050.wav,False +devel_051.wav,False +devel_052.wav,False +devel_053.wav,False +devel_054.wav,False +devel_055.wav,False +devel_056.wav,False +devel_057.wav,False +devel_058.wav,False +devel_059.wav,False +devel_060.wav,False +devel_061.wav,False +devel_062.wav,False +devel_063.wav,False +devel_064.wav,False +devel_065.wav,False +devel_066.wav,False +devel_067.wav,False +devel_068.wav,False +devel_069.wav,False +devel_070.wav,False +devel_071.wav,False +devel_072.wav,False +devel_073.wav,False +devel_074.wav,False +devel_075.wav,False +devel_076.wav,False +devel_077.wav,False +devel_078.wav,False +devel_079.wav,False +devel_080.wav,False +devel_081.wav,False +devel_082.wav,False +devel_083.wav,False +devel_084.wav,False +devel_085.wav,False +devel_086.wav,False +devel_087.wav,False +devel_088.wav,False +devel_089.wav,False +devel_090.wav,False +devel_091.wav,False +devel_092.wav,False +devel_093.wav,False +devel_094.wav,False +devel_095.wav,False +devel_096.wav,False +devel_097.wav,False +devel_098.wav,False +devel_099.wav,False +devel_100.wav,False +devel_101.wav,False +devel_102.wav,False +devel_103.wav,False +devel_104.wav,False +devel_105.wav,False +devel_106.wav,False +devel_107.wav,False +devel_108.wav,False +devel_109.wav,False +devel_110.wav,False +devel_111.wav,False +devel_112.wav,False +devel_113.wav,False +devel_114.wav,False +devel_115.wav,False +devel_116.wav,False +devel_117.wav,False +devel_118.wav,False +devel_119.wav,False +devel_120.wav,False +devel_121.wav,False +devel_122.wav,False +devel_123.wav,False +devel_124.wav,False +devel_125.wav,False +devel_126.wav,False +devel_127.wav,False +devel_128.wav,False +devel_129.wav,False +devel_130.wav,False 
+devel_131.wav,False +devel_132.wav,False +devel_133.wav,False +devel_134.wav,False +devel_135.wav,False +devel_136.wav,False +devel_137.wav,False +devel_138.wav,False +devel_139.wav,False +devel_140.wav,False +devel_141.wav,False +devel_142.wav,False +devel_143.wav,False +devel_144.wav,False +devel_145.wav,False +devel_146.wav,False +devel_147.wav,False +devel_148.wav,False +devel_149.wav,False +devel_150.wav,False +devel_151.wav,False +devel_152.wav,False +devel_153.wav,False +devel_154.wav,False +devel_155.wav,False +devel_156.wav,False +devel_157.wav,False +devel_158.wav,False +devel_159.wav,False +devel_160.wav,False +devel_161.wav,False +devel_162.wav,False +devel_163.wav,False +devel_164.wav,False +devel_165.wav,False +devel_166.wav,False +devel_167.wav,False +devel_168.wav,False +devel_169.wav,False +devel_170.wav,False +devel_171.wav,False +devel_172.wav,False +devel_173.wav,False +devel_174.wav,False +devel_175.wav,False +devel_176.wav,False +devel_177.wav,False +devel_178.wav,False +devel_179.wav,False +devel_180.wav,False +devel_181.wav,False +devel_182.wav,False +devel_183.wav,False +devel_184.wav,False +devel_185.wav,False +devel_186.wav,False +devel_187.wav,False +devel_188.wav,False +devel_189.wav,False +devel_190.wav,False +devel_191.wav,False +devel_192.wav,False +devel_193.wav,False +devel_194.wav,False +devel_195.wav,False +devel_196.wav,False +devel_197.wav,False +devel_198.wav,False +devel_199.wav,False +devel_200.wav,False +devel_201.wav,False +devel_202.wav,False +devel_203.wav,False +devel_204.wav,False +devel_205.wav,False +devel_206.wav,False +devel_207.wav,False +devel_208.wav,False +devel_209.wav,False +devel_210.wav,False +devel_211.wav,False +devel_212.wav,False +devel_213.wav,False +devel_214.wav,False +devel_215.wav,False +devel_216.wav,False +devel_217.wav,False +devel_218.wav,False +devel_219.wav,False +devel_220.wav,False +devel_221.wav,False +devel_222.wav,False +devel_223.wav,False +devel_224.wav,False +devel_225.wav,False +devel_226.wav,False +devel_227.wav,False +devel_228.wav,False +devel_229.wav,False +devel_230.wav,False +devel_231.wav,False +devel_232.wav,False +devel_233.wav,False +devel_234.wav,False +devel_235.wav,False +devel_236.wav,False +devel_237.wav,False +devel_238.wav,False +devel_239.wav,False +devel_240.wav,False +devel_241.wav,False +devel_242.wav,False +devel_243.wav,False +devel_244.wav,False +devel_245.wav,False +devel_246.wav,False +devel_247.wav,False +devel_248.wav,False +devel_249.wav,False +devel_250.wav,False +devel_251.wav,False +devel_252.wav,False +devel_253.wav,False +devel_254.wav,False +devel_255.wav,False +devel_256.wav,False +devel_257.wav,False +devel_258.wav,False +devel_259.wav,False +devel_260.wav,False +devel_261.wav,False +devel_262.wav,False +devel_263.wav,False +devel_264.wav,False +devel_265.wav,False +devel_266.wav,False +devel_267.wav,False +devel_268.wav,False +devel_269.wav,False +devel_270.wav,False +devel_271.wav,False +devel_272.wav,False +devel_273.wav,False +devel_274.wav,False +devel_275.wav,False +devel_276.wav,False +devel_277.wav,False +devel_278.wav,False +devel_279.wav,False +devel_280.wav,False +devel_281.wav,False +devel_282.wav,False +devel_283.wav,False +devel_284.wav,False +devel_285.wav,False +devel_286.wav,False +devel_287.wav,False +devel_288.wav,False +devel_289.wav,False +devel_290.wav,False +devel_291.wav,False +devel_292.wav,False +devel_293.wav,False +devel_294.wav,False +devel_295.wav,False diff --git a/results/trans_hand_vgg/test.predictions.csv 
b/results/trans_hand_vgg/test.predictions.csv new file mode 100644 index 0000000..103ca05 --- /dev/null +++ b/results/trans_hand_vgg/test.predictions.csv @@ -0,0 +1,284 @@ +filename,prediction +test_001.wav,False +test_002.wav,False +test_003.wav,False +test_004.wav,False +test_005.wav,False +test_006.wav,False +test_007.wav,False +test_008.wav,False +test_009.wav,False +test_010.wav,False +test_011.wav,False +test_012.wav,False +test_013.wav,False +test_014.wav,False +test_015.wav,False +test_016.wav,False +test_017.wav,False +test_018.wav,False +test_019.wav,False +test_020.wav,False +test_021.wav,False +test_022.wav,False +test_023.wav,False +test_024.wav,False +test_025.wav,False +test_026.wav,False +test_027.wav,False +test_028.wav,False +test_029.wav,False +test_030.wav,False +test_031.wav,False +test_032.wav,False +test_033.wav,False +test_034.wav,False +test_035.wav,False +test_036.wav,False +test_037.wav,False +test_038.wav,False +test_039.wav,False +test_040.wav,False +test_041.wav,False +test_042.wav,False +test_043.wav,False +test_044.wav,False +test_045.wav,False +test_046.wav,False +test_047.wav,False +test_048.wav,False +test_049.wav,False +test_050.wav,False +test_051.wav,False +test_052.wav,False +test_053.wav,False +test_054.wav,False +test_055.wav,False +test_056.wav,False +test_057.wav,False +test_058.wav,False +test_059.wav,False +test_060.wav,False +test_061.wav,False +test_062.wav,False +test_063.wav,False +test_064.wav,False +test_065.wav,False +test_066.wav,False +test_067.wav,False +test_068.wav,False +test_069.wav,False +test_070.wav,False +test_071.wav,False +test_072.wav,False +test_073.wav,False +test_074.wav,False +test_075.wav,False +test_076.wav,False +test_077.wav,False +test_078.wav,False +test_079.wav,False +test_080.wav,False +test_081.wav,False +test_082.wav,False +test_083.wav,False +test_084.wav,False +test_085.wav,False +test_086.wav,False +test_087.wav,False +test_088.wav,False +test_089.wav,False +test_090.wav,False +test_091.wav,False +test_092.wav,False +test_093.wav,False +test_094.wav,False +test_095.wav,False +test_096.wav,False +test_097.wav,False +test_098.wav,False +test_099.wav,False +test_100.wav,False +test_101.wav,False +test_102.wav,False +test_103.wav,False +test_104.wav,False +test_105.wav,False +test_106.wav,False +test_107.wav,False +test_108.wav,False +test_109.wav,False +test_110.wav,False +test_111.wav,False +test_112.wav,False +test_113.wav,False +test_114.wav,False +test_115.wav,False +test_116.wav,False +test_117.wav,False +test_118.wav,False +test_119.wav,False +test_120.wav,False +test_121.wav,False +test_122.wav,False +test_123.wav,False +test_124.wav,False +test_125.wav,False +test_126.wav,False +test_127.wav,False +test_128.wav,False +test_129.wav,False +test_130.wav,False +test_131.wav,False +test_132.wav,False +test_133.wav,False +test_134.wav,False +test_135.wav,False +test_136.wav,False +test_137.wav,False +test_138.wav,False +test_139.wav,False +test_140.wav,False +test_141.wav,False +test_142.wav,False +test_143.wav,False +test_144.wav,False +test_145.wav,False +test_146.wav,False +test_147.wav,False +test_148.wav,False +test_149.wav,False +test_150.wav,False +test_151.wav,False +test_152.wav,False +test_153.wav,False +test_154.wav,False +test_155.wav,False +test_156.wav,False +test_157.wav,False +test_158.wav,False +test_159.wav,False +test_160.wav,False +test_161.wav,False +test_162.wav,False +test_163.wav,False +test_164.wav,False +test_165.wav,False +test_166.wav,False +test_167.wav,False +test_168.wav,False 
+test_169.wav,False
+test_170.wav,False
+test_171.wav,False
+test_172.wav,False
+test_173.wav,False
+test_174.wav,False
+test_175.wav,False
+test_176.wav,False
+test_177.wav,False
+test_178.wav,False
+test_179.wav,False
+test_180.wav,False
+test_181.wav,False
+test_182.wav,False
+test_183.wav,False
+test_184.wav,False
+test_185.wav,False
+test_186.wav,False
+test_187.wav,False
+test_188.wav,False
+test_189.wav,False
+test_190.wav,False
+test_191.wav,False
+test_192.wav,False
+test_193.wav,False
+test_194.wav,False
+test_195.wav,False
+test_196.wav,False
+test_197.wav,False
+test_198.wav,False
+test_199.wav,False
+test_200.wav,False
+test_201.wav,False
+test_202.wav,False
+test_203.wav,False
+test_204.wav,False
+test_205.wav,False
+test_206.wav,False
+test_207.wav,False
+test_208.wav,False
+test_209.wav,False
+test_210.wav,False
+test_211.wav,False
+test_212.wav,False
+test_213.wav,False
+test_214.wav,False
+test_215.wav,False
+test_216.wav,False
+test_217.wav,False
+test_218.wav,False
+test_219.wav,False
+test_220.wav,False
+test_221.wav,False
+test_222.wav,False
+test_223.wav,False
+test_224.wav,False
+test_225.wav,False
+test_226.wav,False
+test_227.wav,False
+test_228.wav,False
+test_229.wav,False
+test_230.wav,False
+test_231.wav,False
+test_232.wav,False
+test_233.wav,False
+test_234.wav,False
+test_235.wav,False
+test_236.wav,False
+test_237.wav,False
+test_238.wav,False
+test_239.wav,False
+test_240.wav,False
+test_241.wav,False
+test_242.wav,False
+test_243.wav,False
+test_244.wav,False
+test_245.wav,False
+test_246.wav,False
+test_247.wav,False
+test_248.wav,False
+test_249.wav,False
+test_250.wav,False
+test_251.wav,False
+test_252.wav,False
+test_253.wav,False
+test_254.wav,False
+test_255.wav,False
+test_256.wav,False
+test_257.wav,False
+test_258.wav,False
+test_259.wav,False
+test_260.wav,False
+test_261.wav,False
+test_262.wav,False
+test_263.wav,False
+test_264.wav,False
+test_265.wav,False
+test_266.wav,False
+test_267.wav,False
+test_268.wav,False
+test_269.wav,False
+test_270.wav,False
+test_271.wav,False
+test_272.wav,False
+test_273.wav,False
+test_274.wav,False
+test_275.wav,False
+test_276.wav,False
+test_277.wav,False
+test_278.wav,False
+test_279.wav,False
+test_280.wav,False
+test_281.wav,False
+test_282.wav,False
+test_283.wav,False
diff --git a/src/transformer_hand_vgg.py b/src/transformer_hand_vgg.py
new file mode 100644
index 0000000..9789b44
--- /dev/null
+++ b/src/transformer_hand_vgg.py
@@ -0,0 +1,278 @@
+import numpy as np
+from tensorflow.keras import backend as K
+from sklearn.metrics import classification_report, confusion_matrix, recall_score
+import tensorflow as tf
+import pandas as pd
+import os
+
+
+def non_nan_average(x):
+    # Computes the average of all elements that are not NaN in a rank 1 tensor
+    nan_mask = tf.math.is_nan(x)
+    x = tf.boolean_mask(x, tf.logical_not(nan_mask))
+    return K.mean(x)
+
+
+def uar_accuracy(y_true, y_pred):
+    # Derive hard class labels. The model below has a single sigmoid output,
+    # so threshold the scores at 0.5; an argmax over a single column would
+    # always return 0.
+    pred_class_label = K.cast(K.flatten(y_pred) > 0.5, "int32")
+    true_class_label = K.cast(K.flatten(y_true) > 0.5, "int32")
+
+    cf_mat = tf.math.confusion_matrix(true_class_label, pred_class_label)
+
+    diag = tf.linalg.tensor_diag_part(cf_mat)
+
+    # Calculate the total number of data examples for each class
+    total_per_class = tf.reduce_sum(cf_mat, axis=1)
+
+    acc_per_class = diag / tf.maximum(1, total_per_class)
+    uar = non_nan_average(acc_per_class)
+
+    return uar
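+
+
+# Quick sanity check (illustrative sketch only, the values are made up):
+# for hard 0/1 predictions this metric should agree with
+# sklearn.metrics.recall_score(..., average='macro'), which the evaluation
+# code below reports as UAR, e.g.
+#   uar_accuracy(tf.constant([[1.], [0.], [1.]]), tf.constant([[0.9], [0.2], [0.4]]))
+#   recall_score([1, 0, 1], [1, 0, 0], average='macro')
+# both give 0.75 (class 0 recall 1.0, class 1 recall 0.5).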
"./features/vgg_features/x_devel_data_vgg.npy", allow_pickle=True +) + +test_X_vgg = np.load( + "./features/vgg_features/x_test_data_vgg.npy", allow_pickle=True +) + +train_X_vgg = np.load( + "./features/vgg_features/x_train_data_vgg.npy", allow_pickle=True +) + +devel_X_hand = np.load( + "./features/hand_features/x_devel_data.npy", allow_pickle=True +) + +test_X_hand = np.load( + "./features/hand_features/x_test_data.npy", allow_pickle=True +) + +train_X_hand = np.load( + "./features/hand_features/x_train_data.npy", allow_pickle=True +) + +devel_y = np.load( + "./features/vgg_features/y_devel_label_vgg.npy", allow_pickle=True +) + +test_y = np.load( + "./features/vgg_features/y_test_label_vgg.npy", allow_pickle=True +) + +train_y = np.load( + "./features/vgg_features/y_train_label_vgg.npy", allow_pickle=True +) + +devel_names = np.load( + "./features/hand_features/devel_names.npy", allow_pickle=True +) + +test_names = np.load( + "./features/hand_features/test_names.npy", allow_pickle=True +) + +train_X_vgg = np.squeeze(train_X_vgg) +devel_X_vgg = np.squeeze(devel_X_vgg) +test_X_vgg = np.squeeze(test_X_vgg) + +devel_X=np.concatenate( + ( + devel_X_hand, + devel_X_vgg + ), + axis=1, +) + +test_X=np.concatenate( + ( + test_X_hand, + test_X_vgg + ), + axis=1, +) + +train_X=np.concatenate( + ( + train_X_hand, + train_X_vgg + ), + axis=1, +) + +X = np.append(train_X, devel_X, axis=0) +y = np.append(train_y, devel_y, axis=0) + +x = X.reshape((X.shape[0], X.shape[1], 1)) +x_train = train_X.reshape((train_X.shape[0], train_X.shape[1], 1)) +x_test = test_X.reshape((test_X.shape[0], test_X.shape[1], 1)) +devel_X = devel_X.reshape((devel_X.shape[0], devel_X.shape[1], 1)) + +n_classes = len(np.unique(y)) + +train_y[train_y == "positive"] = 1 +train_y[train_y == "negative"] = 0 + +y[y == "positive"] = 1 +y[y == "negative"] = 0 + +devel_y[devel_y == "positive"] = 1 +devel_y[devel_y == "negative"] = 0 + +test_y[test_y == "positive"] = 1 +test_y[test_y == "negative"] = 0 + +""" +## Build the model +Our model processes a tensor of shape `(batch size, sequence length, features)`, +where `sequence length` is the number of time steps and `features` is each input +timeseries. +You can replace your classification RNN layers with this one: the +inputs are fully compatible! +""" + +from tensorflow import keras +from tensorflow.keras import layers + +""" +We include residual connections, layer normalization, and dropout. +The resulting layer can be stacked multiple times. +The projection layers are implemented through `keras.layers.Conv1D`. +""" + + +def transformer_encoder(inputs, head_size, num_heads, ff_dim, dropout=0): + # Attention and Normalization + x = layers.MultiHeadAttention( + key_dim=head_size, num_heads=num_heads, dropout=dropout + )(inputs, inputs) + x = layers.Dropout(dropout)(x) + x = layers.LayerNormalization(epsilon=1e-6)(x) + res = x + inputs + + # Feed Forward Part + x = layers.Conv1D(filters=ff_dim, kernel_size=1, activation="relu")(res) + x = layers.Dropout(dropout)(x) + x = layers.Conv1D(filters=inputs.shape[-1], kernel_size=1)(x) + x = layers.LayerNormalization(epsilon=1e-6)(x) + return x + res + + +""" +The main part of our model is now complete. We can stack multiple of those +`transformer_encoder` blocks and we can also proceed to add the final +Multi-Layer Perceptron classification head. Apart from a stack of `Dense` +layers, we need to reduce the output tensor of the `TransformerEncoder` part of +our model down to a vector of features for each data point in the current +batch. 
+
+
+"""
+The main part of our model is now complete. We can stack multiple of those
+`transformer_encoder` blocks and we can also proceed to add the final
+Multi-Layer Perceptron classification head. Apart from a stack of `Dense`
+layers, we need to reduce the output tensor of the `TransformerEncoder` part of
+our model down to a vector of features for each data point in the current
+batch. A common way to achieve this is to use a pooling layer. For
+this example, a `GlobalAveragePooling1D` layer is sufficient.
+"""
+
+
+def build_model(
+    input_shape,
+    head_size,
+    num_heads,
+    ff_dim,
+    num_transformer_blocks,
+    mlp_units,
+    dropout=0,
+    mlp_dropout=0,
+):
+    inputs = keras.Input(shape=input_shape)
+    x = inputs
+    for _ in range(num_transformer_blocks):
+        x = transformer_encoder(x, head_size, num_heads, ff_dim, dropout)
+
+    x = layers.GlobalAveragePooling1D(data_format="channels_first")(x)
+    for dim in mlp_units:
+        x = layers.Dense(dim, activation="relu")(x)
+        x = layers.Dropout(mlp_dropout)(x)
+    outputs = layers.Dense(1, activation="sigmoid")(x)
+    return keras.Model(inputs, outputs)
+
+
+"""
+## Train and evaluate
+"""
+
+input_shape = x_train.shape[1:]
+
+model = build_model(
+    input_shape,
+    head_size=256,
+    num_heads=4,
+    ff_dim=4,
+    num_transformer_blocks=4,
+    mlp_units=[128],
+    mlp_dropout=0.4,
+    dropout=0.25,
+)
+
+model.compile(
+    loss="binary_crossentropy",
+    optimizer=keras.optimizers.Adam(learning_rate=1e-4),
+    metrics=[uar_accuracy],
+)
+
+model.summary()
+
+callbacks = [keras.callbacks.EarlyStopping(patience=10, restore_best_weights=True)]
+
+# train on the training partition only
+model.fit(
+    np.asarray(x_train).astype(np.float32),
+    np.asarray(train_y).astype(np.float32),
+    validation_split=0.2,
+    epochs=20,
+    batch_size=64,
+    callbacks=callbacks,
+)
+
+devel_y_pred = model.predict(np.asarray(devel_X).astype(np.float32), verbose=1)
+# single sigmoid output: threshold the scores at 0.5
+# (an argmax over the last axis would always return 0)
+devel_y_pred = (devel_y_pred > 0.5).ravel()
+
+devel_y_pred = devel_y_pred.astype('bool')
+devel_y = devel_y.astype('bool')
+
+# devel metrics
+print('DEVEL')
+uar = recall_score(devel_y, devel_y_pred, average='macro')
+cm = confusion_matrix(devel_y, devel_y_pred)
+print(f'UAR: {uar}\n{classification_report(devel_y, devel_y_pred)}\n\nConfusion Matrix:\n\n{cm}')
+
+df_predictions = pd.DataFrame({'filename': devel_names.tolist(), 'prediction': devel_y_pred.tolist()})
+df_predictions.to_csv(os.path.join('./results/trans_hand_vgg/', 'devel.predictions.csv'), index=False)
+
+# refit on train + devel before predicting on test
+model.fit(
+    np.asarray(x).astype(np.float32),
+    np.asarray(y).astype(np.float32),
+    validation_split=0.2,
+    epochs=20,
+    batch_size=64,
+    callbacks=callbacks,
+)
+
+test_y_pred = model.predict(np.asarray(x_test).astype(np.float32), verbose=1)
+test_y_pred = (test_y_pred > 0.5).ravel()
+
+test_y_pred = test_y_pred.astype('bool')
+test_y = test_y.astype('bool')
+
+# test metrics
+print('TEST')
+uar = recall_score(test_y, test_y_pred, average='macro')
+cm = confusion_matrix(test_y, test_y_pred)
+print(f'UAR: {uar}\n{classification_report(test_y, test_y_pred)}\n\nConfusion Matrix:\n\n{cm}')
+
+df_predictions = pd.DataFrame({'filename': test_names.tolist(), 'prediction': test_y_pred.tolist()})
+df_predictions.to_csv(os.path.join('./results/trans_hand_vgg/', 'test.predictions.csv'), index=False)
\ No newline at end of file