transformer
This commit is contained in:
parent
76bf877800
commit
2d1376cd6d
296
results/trans_hand_vgg/devel.predictions.csv
Normal file
296
results/trans_hand_vgg/devel.predictions.csv
Normal file
@ -0,0 +1,296 @@
|
||||
filename,prediction
|
||||
devel_001.wav,False
|
||||
devel_002.wav,False
|
||||
devel_003.wav,False
|
||||
devel_004.wav,False
|
||||
devel_005.wav,False
|
||||
devel_006.wav,False
|
||||
devel_007.wav,False
|
||||
devel_008.wav,False
|
||||
devel_009.wav,False
|
||||
devel_010.wav,False
|
||||
devel_011.wav,False
|
||||
devel_012.wav,False
|
||||
devel_013.wav,False
|
||||
devel_014.wav,False
|
||||
devel_015.wav,False
|
||||
devel_016.wav,False
|
||||
devel_017.wav,False
|
||||
devel_018.wav,False
|
||||
devel_019.wav,False
|
||||
devel_020.wav,False
|
||||
devel_021.wav,False
|
||||
devel_022.wav,False
|
||||
devel_023.wav,False
|
||||
devel_024.wav,False
|
||||
devel_025.wav,False
|
||||
devel_026.wav,False
|
||||
devel_027.wav,False
|
||||
devel_028.wav,False
|
||||
devel_029.wav,False
|
||||
devel_030.wav,False
|
||||
devel_031.wav,False
|
||||
devel_032.wav,False
|
||||
devel_033.wav,False
|
||||
devel_034.wav,False
|
||||
devel_035.wav,False
|
||||
devel_036.wav,False
|
||||
devel_037.wav,False
|
||||
devel_038.wav,False
|
||||
devel_039.wav,False
|
||||
devel_040.wav,False
|
||||
devel_041.wav,False
|
||||
devel_042.wav,False
|
||||
devel_043.wav,False
|
||||
devel_044.wav,False
|
||||
devel_045.wav,False
|
||||
devel_046.wav,False
|
||||
devel_047.wav,False
|
||||
devel_048.wav,False
|
||||
devel_049.wav,False
|
||||
devel_050.wav,False
|
||||
devel_051.wav,False
|
||||
devel_052.wav,False
|
||||
devel_053.wav,False
|
||||
devel_054.wav,False
|
||||
devel_055.wav,False
|
||||
devel_056.wav,False
|
||||
devel_057.wav,False
|
||||
devel_058.wav,False
|
||||
devel_059.wav,False
|
||||
devel_060.wav,False
|
||||
devel_061.wav,False
|
||||
devel_062.wav,False
|
||||
devel_063.wav,False
|
||||
devel_064.wav,False
|
||||
devel_065.wav,False
|
||||
devel_066.wav,False
|
||||
devel_067.wav,False
|
||||
devel_068.wav,False
|
||||
devel_069.wav,False
|
||||
devel_070.wav,False
|
||||
devel_071.wav,False
|
||||
devel_072.wav,False
|
||||
devel_073.wav,False
|
||||
devel_074.wav,False
|
||||
devel_075.wav,False
|
||||
devel_076.wav,False
|
||||
devel_077.wav,False
|
||||
devel_078.wav,False
|
||||
devel_079.wav,False
|
||||
devel_080.wav,False
|
||||
devel_081.wav,False
|
||||
devel_082.wav,False
|
||||
devel_083.wav,False
|
||||
devel_084.wav,False
|
||||
devel_085.wav,False
|
||||
devel_086.wav,False
|
||||
devel_087.wav,False
|
||||
devel_088.wav,False
|
||||
devel_089.wav,False
|
||||
devel_090.wav,False
|
||||
devel_091.wav,False
|
||||
devel_092.wav,False
|
||||
devel_093.wav,False
|
||||
devel_094.wav,False
|
||||
devel_095.wav,False
|
||||
devel_096.wav,False
|
||||
devel_097.wav,False
|
||||
devel_098.wav,False
|
||||
devel_099.wav,False
|
||||
devel_100.wav,False
|
||||
devel_101.wav,False
|
||||
devel_102.wav,False
|
||||
devel_103.wav,False
|
||||
devel_104.wav,False
|
||||
devel_105.wav,False
|
||||
devel_106.wav,False
|
||||
devel_107.wav,False
|
||||
devel_108.wav,False
|
||||
devel_109.wav,False
|
||||
devel_110.wav,False
|
||||
devel_111.wav,False
|
||||
devel_112.wav,False
|
||||
devel_113.wav,False
|
||||
devel_114.wav,False
|
||||
devel_115.wav,False
|
||||
devel_116.wav,False
|
||||
devel_117.wav,False
|
||||
devel_118.wav,False
|
||||
devel_119.wav,False
|
||||
devel_120.wav,False
|
||||
devel_121.wav,False
|
||||
devel_122.wav,False
|
||||
devel_123.wav,False
|
||||
devel_124.wav,False
|
||||
devel_125.wav,False
|
||||
devel_126.wav,False
|
||||
devel_127.wav,False
|
||||
devel_128.wav,False
|
||||
devel_129.wav,False
|
||||
devel_130.wav,False
|
||||
devel_131.wav,False
|
||||
devel_132.wav,False
|
||||
devel_133.wav,False
|
||||
devel_134.wav,False
|
||||
devel_135.wav,False
|
||||
devel_136.wav,False
|
||||
devel_137.wav,False
|
||||
devel_138.wav,False
|
||||
devel_139.wav,False
|
||||
devel_140.wav,False
|
||||
devel_141.wav,False
|
||||
devel_142.wav,False
|
||||
devel_143.wav,False
|
||||
devel_144.wav,False
|
||||
devel_145.wav,False
|
||||
devel_146.wav,False
|
||||
devel_147.wav,False
|
||||
devel_148.wav,False
|
||||
devel_149.wav,False
|
||||
devel_150.wav,False
|
||||
devel_151.wav,False
|
||||
devel_152.wav,False
|
||||
devel_153.wav,False
|
||||
devel_154.wav,False
|
||||
devel_155.wav,False
|
||||
devel_156.wav,False
|
||||
devel_157.wav,False
|
||||
devel_158.wav,False
|
||||
devel_159.wav,False
|
||||
devel_160.wav,False
|
||||
devel_161.wav,False
|
||||
devel_162.wav,False
|
||||
devel_163.wav,False
|
||||
devel_164.wav,False
|
||||
devel_165.wav,False
|
||||
devel_166.wav,False
|
||||
devel_167.wav,False
|
||||
devel_168.wav,False
|
||||
devel_169.wav,False
|
||||
devel_170.wav,False
|
||||
devel_171.wav,False
|
||||
devel_172.wav,False
|
||||
devel_173.wav,False
|
||||
devel_174.wav,False
|
||||
devel_175.wav,False
|
||||
devel_176.wav,False
|
||||
devel_177.wav,False
|
||||
devel_178.wav,False
|
||||
devel_179.wav,False
|
||||
devel_180.wav,False
|
||||
devel_181.wav,False
|
||||
devel_182.wav,False
|
||||
devel_183.wav,False
|
||||
devel_184.wav,False
|
||||
devel_185.wav,False
|
||||
devel_186.wav,False
|
||||
devel_187.wav,False
|
||||
devel_188.wav,False
|
||||
devel_189.wav,False
|
||||
devel_190.wav,False
|
||||
devel_191.wav,False
|
||||
devel_192.wav,False
|
||||
devel_193.wav,False
|
||||
devel_194.wav,False
|
||||
devel_195.wav,False
|
||||
devel_196.wav,False
|
||||
devel_197.wav,False
|
||||
devel_198.wav,False
|
||||
devel_199.wav,False
|
||||
devel_200.wav,False
|
||||
devel_201.wav,False
|
||||
devel_202.wav,False
|
||||
devel_203.wav,False
|
||||
devel_204.wav,False
|
||||
devel_205.wav,False
|
||||
devel_206.wav,False
|
||||
devel_207.wav,False
|
||||
devel_208.wav,False
|
||||
devel_209.wav,False
|
||||
devel_210.wav,False
|
||||
devel_211.wav,False
|
||||
devel_212.wav,False
|
||||
devel_213.wav,False
|
||||
devel_214.wav,False
|
||||
devel_215.wav,False
|
||||
devel_216.wav,False
|
||||
devel_217.wav,False
|
||||
devel_218.wav,False
|
||||
devel_219.wav,False
|
||||
devel_220.wav,False
|
||||
devel_221.wav,False
|
||||
devel_222.wav,False
|
||||
devel_223.wav,False
|
||||
devel_224.wav,False
|
||||
devel_225.wav,False
|
||||
devel_226.wav,False
|
||||
devel_227.wav,False
|
||||
devel_228.wav,False
|
||||
devel_229.wav,False
|
||||
devel_230.wav,False
|
||||
devel_231.wav,False
|
||||
devel_232.wav,False
|
||||
devel_233.wav,False
|
||||
devel_234.wav,False
|
||||
devel_235.wav,False
|
||||
devel_236.wav,False
|
||||
devel_237.wav,False
|
||||
devel_238.wav,False
|
||||
devel_239.wav,False
|
||||
devel_240.wav,False
|
||||
devel_241.wav,False
|
||||
devel_242.wav,False
|
||||
devel_243.wav,False
|
||||
devel_244.wav,False
|
||||
devel_245.wav,False
|
||||
devel_246.wav,False
|
||||
devel_247.wav,False
|
||||
devel_248.wav,False
|
||||
devel_249.wav,False
|
||||
devel_250.wav,False
|
||||
devel_251.wav,False
|
||||
devel_252.wav,False
|
||||
devel_253.wav,False
|
||||
devel_254.wav,False
|
||||
devel_255.wav,False
|
||||
devel_256.wav,False
|
||||
devel_257.wav,False
|
||||
devel_258.wav,False
|
||||
devel_259.wav,False
|
||||
devel_260.wav,False
|
||||
devel_261.wav,False
|
||||
devel_262.wav,False
|
||||
devel_263.wav,False
|
||||
devel_264.wav,False
|
||||
devel_265.wav,False
|
||||
devel_266.wav,False
|
||||
devel_267.wav,False
|
||||
devel_268.wav,False
|
||||
devel_269.wav,False
|
||||
devel_270.wav,False
|
||||
devel_271.wav,False
|
||||
devel_272.wav,False
|
||||
devel_273.wav,False
|
||||
devel_274.wav,False
|
||||
devel_275.wav,False
|
||||
devel_276.wav,False
|
||||
devel_277.wav,False
|
||||
devel_278.wav,False
|
||||
devel_279.wav,False
|
||||
devel_280.wav,False
|
||||
devel_281.wav,False
|
||||
devel_282.wav,False
|
||||
devel_283.wav,False
|
||||
devel_284.wav,False
|
||||
devel_285.wav,False
|
||||
devel_286.wav,False
|
||||
devel_287.wav,False
|
||||
devel_288.wav,False
|
||||
devel_289.wav,False
|
||||
devel_290.wav,False
|
||||
devel_291.wav,False
|
||||
devel_292.wav,False
|
||||
devel_293.wav,False
|
||||
devel_294.wav,False
|
||||
devel_295.wav,False
|
|
284
results/trans_hand_vgg/test.predictions.csv
Normal file
284
results/trans_hand_vgg/test.predictions.csv
Normal file
@ -0,0 +1,284 @@
|
||||
filename,prediction
|
||||
test_001.wav,False
|
||||
test_002.wav,False
|
||||
test_003.wav,False
|
||||
test_004.wav,False
|
||||
test_005.wav,False
|
||||
test_006.wav,False
|
||||
test_007.wav,False
|
||||
test_008.wav,False
|
||||
test_009.wav,False
|
||||
test_010.wav,False
|
||||
test_011.wav,False
|
||||
test_012.wav,False
|
||||
test_013.wav,False
|
||||
test_014.wav,False
|
||||
test_015.wav,False
|
||||
test_016.wav,False
|
||||
test_017.wav,False
|
||||
test_018.wav,False
|
||||
test_019.wav,False
|
||||
test_020.wav,False
|
||||
test_021.wav,False
|
||||
test_022.wav,False
|
||||
test_023.wav,False
|
||||
test_024.wav,False
|
||||
test_025.wav,False
|
||||
test_026.wav,False
|
||||
test_027.wav,False
|
||||
test_028.wav,False
|
||||
test_029.wav,False
|
||||
test_030.wav,False
|
||||
test_031.wav,False
|
||||
test_032.wav,False
|
||||
test_033.wav,False
|
||||
test_034.wav,False
|
||||
test_035.wav,False
|
||||
test_036.wav,False
|
||||
test_037.wav,False
|
||||
test_038.wav,False
|
||||
test_039.wav,False
|
||||
test_040.wav,False
|
||||
test_041.wav,False
|
||||
test_042.wav,False
|
||||
test_043.wav,False
|
||||
test_044.wav,False
|
||||
test_045.wav,False
|
||||
test_046.wav,False
|
||||
test_047.wav,False
|
||||
test_048.wav,False
|
||||
test_049.wav,False
|
||||
test_050.wav,False
|
||||
test_051.wav,False
|
||||
test_052.wav,False
|
||||
test_053.wav,False
|
||||
test_054.wav,False
|
||||
test_055.wav,False
|
||||
test_056.wav,False
|
||||
test_057.wav,False
|
||||
test_058.wav,False
|
||||
test_059.wav,False
|
||||
test_060.wav,False
|
||||
test_061.wav,False
|
||||
test_062.wav,False
|
||||
test_063.wav,False
|
||||
test_064.wav,False
|
||||
test_065.wav,False
|
||||
test_066.wav,False
|
||||
test_067.wav,False
|
||||
test_068.wav,False
|
||||
test_069.wav,False
|
||||
test_070.wav,False
|
||||
test_071.wav,False
|
||||
test_072.wav,False
|
||||
test_073.wav,False
|
||||
test_074.wav,False
|
||||
test_075.wav,False
|
||||
test_076.wav,False
|
||||
test_077.wav,False
|
||||
test_078.wav,False
|
||||
test_079.wav,False
|
||||
test_080.wav,False
|
||||
test_081.wav,False
|
||||
test_082.wav,False
|
||||
test_083.wav,False
|
||||
test_084.wav,False
|
||||
test_085.wav,False
|
||||
test_086.wav,False
|
||||
test_087.wav,False
|
||||
test_088.wav,False
|
||||
test_089.wav,False
|
||||
test_090.wav,False
|
||||
test_091.wav,False
|
||||
test_092.wav,False
|
||||
test_093.wav,False
|
||||
test_094.wav,False
|
||||
test_095.wav,False
|
||||
test_096.wav,False
|
||||
test_097.wav,False
|
||||
test_098.wav,False
|
||||
test_099.wav,False
|
||||
test_100.wav,False
|
||||
test_101.wav,False
|
||||
test_102.wav,False
|
||||
test_103.wav,False
|
||||
test_104.wav,False
|
||||
test_105.wav,False
|
||||
test_106.wav,False
|
||||
test_107.wav,False
|
||||
test_108.wav,False
|
||||
test_109.wav,False
|
||||
test_110.wav,False
|
||||
test_111.wav,False
|
||||
test_112.wav,False
|
||||
test_113.wav,False
|
||||
test_114.wav,False
|
||||
test_115.wav,False
|
||||
test_116.wav,False
|
||||
test_117.wav,False
|
||||
test_118.wav,False
|
||||
test_119.wav,False
|
||||
test_120.wav,False
|
||||
test_121.wav,False
|
||||
test_122.wav,False
|
||||
test_123.wav,False
|
||||
test_124.wav,False
|
||||
test_125.wav,False
|
||||
test_126.wav,False
|
||||
test_127.wav,False
|
||||
test_128.wav,False
|
||||
test_129.wav,False
|
||||
test_130.wav,False
|
||||
test_131.wav,False
|
||||
test_132.wav,False
|
||||
test_133.wav,False
|
||||
test_134.wav,False
|
||||
test_135.wav,False
|
||||
test_136.wav,False
|
||||
test_137.wav,False
|
||||
test_138.wav,False
|
||||
test_139.wav,False
|
||||
test_140.wav,False
|
||||
test_141.wav,False
|
||||
test_142.wav,False
|
||||
test_143.wav,False
|
||||
test_144.wav,False
|
||||
test_145.wav,False
|
||||
test_146.wav,False
|
||||
test_147.wav,False
|
||||
test_148.wav,False
|
||||
test_149.wav,False
|
||||
test_150.wav,False
|
||||
test_151.wav,False
|
||||
test_152.wav,False
|
||||
test_153.wav,False
|
||||
test_154.wav,False
|
||||
test_155.wav,False
|
||||
test_156.wav,False
|
||||
test_157.wav,False
|
||||
test_158.wav,False
|
||||
test_159.wav,False
|
||||
test_160.wav,False
|
||||
test_161.wav,False
|
||||
test_162.wav,False
|
||||
test_163.wav,False
|
||||
test_164.wav,False
|
||||
test_165.wav,False
|
||||
test_166.wav,False
|
||||
test_167.wav,False
|
||||
test_168.wav,False
|
||||
test_169.wav,False
|
||||
test_170.wav,False
|
||||
test_171.wav,False
|
||||
test_172.wav,False
|
||||
test_173.wav,False
|
||||
test_174.wav,False
|
||||
test_175.wav,False
|
||||
test_176.wav,False
|
||||
test_177.wav,False
|
||||
test_178.wav,False
|
||||
test_179.wav,False
|
||||
test_180.wav,False
|
||||
test_181.wav,False
|
||||
test_182.wav,False
|
||||
test_183.wav,False
|
||||
test_184.wav,False
|
||||
test_185.wav,False
|
||||
test_186.wav,False
|
||||
test_187.wav,False
|
||||
test_188.wav,False
|
||||
test_189.wav,False
|
||||
test_190.wav,False
|
||||
test_191.wav,False
|
||||
test_192.wav,False
|
||||
test_193.wav,False
|
||||
test_194.wav,False
|
||||
test_195.wav,False
|
||||
test_196.wav,False
|
||||
test_197.wav,False
|
||||
test_198.wav,False
|
||||
test_199.wav,False
|
||||
test_200.wav,False
|
||||
test_201.wav,False
|
||||
test_202.wav,False
|
||||
test_203.wav,False
|
||||
test_204.wav,False
|
||||
test_205.wav,False
|
||||
test_206.wav,False
|
||||
test_207.wav,False
|
||||
test_208.wav,False
|
||||
test_209.wav,False
|
||||
test_210.wav,False
|
||||
test_211.wav,False
|
||||
test_212.wav,False
|
||||
test_213.wav,False
|
||||
test_214.wav,False
|
||||
test_215.wav,False
|
||||
test_216.wav,False
|
||||
test_217.wav,False
|
||||
test_218.wav,False
|
||||
test_219.wav,False
|
||||
test_220.wav,False
|
||||
test_221.wav,False
|
||||
test_222.wav,False
|
||||
test_223.wav,False
|
||||
test_224.wav,False
|
||||
test_225.wav,False
|
||||
test_226.wav,False
|
||||
test_227.wav,False
|
||||
test_228.wav,False
|
||||
test_229.wav,False
|
||||
test_230.wav,False
|
||||
test_231.wav,False
|
||||
test_232.wav,False
|
||||
test_233.wav,False
|
||||
test_234.wav,False
|
||||
test_235.wav,False
|
||||
test_236.wav,False
|
||||
test_237.wav,False
|
||||
test_238.wav,False
|
||||
test_239.wav,False
|
||||
test_240.wav,False
|
||||
test_241.wav,False
|
||||
test_242.wav,False
|
||||
test_243.wav,False
|
||||
test_244.wav,False
|
||||
test_245.wav,False
|
||||
test_246.wav,False
|
||||
test_247.wav,False
|
||||
test_248.wav,False
|
||||
test_249.wav,False
|
||||
test_250.wav,False
|
||||
test_251.wav,False
|
||||
test_252.wav,False
|
||||
test_253.wav,False
|
||||
test_254.wav,False
|
||||
test_255.wav,False
|
||||
test_256.wav,False
|
||||
test_257.wav,False
|
||||
test_258.wav,False
|
||||
test_259.wav,False
|
||||
test_260.wav,False
|
||||
test_261.wav,False
|
||||
test_262.wav,False
|
||||
test_263.wav,False
|
||||
test_264.wav,False
|
||||
test_265.wav,False
|
||||
test_266.wav,False
|
||||
test_267.wav,False
|
||||
test_268.wav,False
|
||||
test_269.wav,False
|
||||
test_270.wav,False
|
||||
test_271.wav,False
|
||||
test_272.wav,False
|
||||
test_273.wav,False
|
||||
test_274.wav,False
|
||||
test_275.wav,False
|
||||
test_276.wav,False
|
||||
test_277.wav,False
|
||||
test_278.wav,False
|
||||
test_279.wav,False
|
||||
test_280.wav,False
|
||||
test_281.wav,False
|
||||
test_282.wav,False
|
||||
test_283.wav,False
|
|
278
src/transformer_hand_vgg.py
Normal file
278
src/transformer_hand_vgg.py
Normal file
@ -0,0 +1,278 @@
|
||||
import numpy as np
|
||||
from keras import backend as K
|
||||
from sklearn.metrics import classification_report, confusion_matrix, recall_score, make_scorer, plot_confusion_matrix
|
||||
import tensorflow as tf
|
||||
import pandas as pd
|
||||
import matplotlib.pyplot as plt
|
||||
import os
|
||||
|
||||
def non_nan_average(x):
|
||||
# Computes the average of all elements that are not NaN in a rank 1 tensor
|
||||
nan_mask = tf.math.is_nan(x)
|
||||
x = tf.boolean_mask(x, tf.logical_not(nan_mask))
|
||||
return K.mean(x)
|
||||
|
||||
|
||||
def uar_accuracy(y_true, y_pred):
|
||||
# Calculate the label from one-hot encoding
|
||||
pred_class_label = K.argmax(y_pred, axis=-1)
|
||||
true_class_label = K.argmax(y_true, axis=-1)
|
||||
|
||||
cf_mat = tf.math.confusion_matrix(true_class_label, pred_class_label )
|
||||
|
||||
diag = tf.linalg.tensor_diag_part(cf_mat)
|
||||
|
||||
# Calculate the total number of data examples for each class
|
||||
total_per_class = tf.reduce_sum(cf_mat, axis=1)
|
||||
|
||||
acc_per_class = diag / tf.maximum(1, total_per_class)
|
||||
uar = non_nan_average(acc_per_class)
|
||||
|
||||
return uar
|
||||
|
||||
# load features and labels
|
||||
devel_X_vgg = np.load(
|
||||
"./features/vgg_features/x_devel_data_vgg.npy", allow_pickle=True
|
||||
)
|
||||
|
||||
test_X_vgg = np.load(
|
||||
"./features/vgg_features/x_test_data_vgg.npy", allow_pickle=True
|
||||
)
|
||||
|
||||
train_X_vgg = np.load(
|
||||
"./features/vgg_features/x_train_data_vgg.npy", allow_pickle=True
|
||||
)
|
||||
|
||||
devel_X_hand = np.load(
|
||||
"./features/hand_features/x_devel_data.npy", allow_pickle=True
|
||||
)
|
||||
|
||||
test_X_hand = np.load(
|
||||
"./features/hand_features/x_test_data.npy", allow_pickle=True
|
||||
)
|
||||
|
||||
train_X_hand = np.load(
|
||||
"./features/hand_features/x_train_data.npy", allow_pickle=True
|
||||
)
|
||||
|
||||
devel_y = np.load(
|
||||
"./features/vgg_features/y_devel_label_vgg.npy", allow_pickle=True
|
||||
)
|
||||
|
||||
test_y = np.load(
|
||||
"./features/vgg_features/y_test_label_vgg.npy", allow_pickle=True
|
||||
)
|
||||
|
||||
train_y = np.load(
|
||||
"./features/vgg_features/y_train_label_vgg.npy", allow_pickle=True
|
||||
)
|
||||
|
||||
devel_names = np.load(
|
||||
"./features/hand_features/devel_names.npy", allow_pickle=True
|
||||
)
|
||||
|
||||
test_names = np.load(
|
||||
"./features/hand_features/test_names.npy", allow_pickle=True
|
||||
)
|
||||
|
||||
train_X_vgg = np.squeeze(train_X_vgg)
|
||||
devel_X_vgg = np.squeeze(devel_X_vgg)
|
||||
test_X_vgg = np.squeeze(test_X_vgg)
|
||||
|
||||
devel_X=np.concatenate(
|
||||
(
|
||||
devel_X_hand,
|
||||
devel_X_vgg
|
||||
),
|
||||
axis=1,
|
||||
)
|
||||
|
||||
test_X=np.concatenate(
|
||||
(
|
||||
test_X_hand,
|
||||
test_X_vgg
|
||||
),
|
||||
axis=1,
|
||||
)
|
||||
|
||||
train_X=np.concatenate(
|
||||
(
|
||||
train_X_hand,
|
||||
train_X_vgg
|
||||
),
|
||||
axis=1,
|
||||
)
|
||||
|
||||
X = np.append(train_X, devel_X, axis=0)
|
||||
y = np.append(train_y, devel_y, axis=0)
|
||||
|
||||
x = X.reshape((X.shape[0], X.shape[1], 1))
|
||||
x_train = train_X.reshape((train_X.shape[0], train_X.shape[1], 1))
|
||||
x_test = test_X.reshape((test_X.shape[0], test_X.shape[1], 1))
|
||||
devel_X = devel_X.reshape((devel_X.shape[0], devel_X.shape[1], 1))
|
||||
|
||||
n_classes = len(np.unique(y))
|
||||
|
||||
train_y[train_y == "positive"] = 1
|
||||
train_y[train_y == "negative"] = 0
|
||||
|
||||
y[y == "positive"] = 1
|
||||
y[y == "negative"] = 0
|
||||
|
||||
devel_y[devel_y == "positive"] = 1
|
||||
devel_y[devel_y == "negative"] = 0
|
||||
|
||||
test_y[test_y == "positive"] = 1
|
||||
test_y[test_y == "negative"] = 0
|
||||
|
||||
"""
|
||||
## Build the model
|
||||
Our model processes a tensor of shape `(batch size, sequence length, features)`,
|
||||
where `sequence length` is the number of time steps and `features` is each input
|
||||
timeseries.
|
||||
You can replace your classification RNN layers with this one: the
|
||||
inputs are fully compatible!
|
||||
"""
|
||||
|
||||
from tensorflow import keras
|
||||
from tensorflow.keras import layers
|
||||
|
||||
"""
|
||||
We include residual connections, layer normalization, and dropout.
|
||||
The resulting layer can be stacked multiple times.
|
||||
The projection layers are implemented through `keras.layers.Conv1D`.
|
||||
"""
|
||||
|
||||
|
||||
def transformer_encoder(inputs, head_size, num_heads, ff_dim, dropout=0):
|
||||
# Attention and Normalization
|
||||
x = layers.MultiHeadAttention(
|
||||
key_dim=head_size, num_heads=num_heads, dropout=dropout
|
||||
)(inputs, inputs)
|
||||
x = layers.Dropout(dropout)(x)
|
||||
x = layers.LayerNormalization(epsilon=1e-6)(x)
|
||||
res = x + inputs
|
||||
|
||||
# Feed Forward Part
|
||||
x = layers.Conv1D(filters=ff_dim, kernel_size=1, activation="relu")(res)
|
||||
x = layers.Dropout(dropout)(x)
|
||||
x = layers.Conv1D(filters=inputs.shape[-1], kernel_size=1)(x)
|
||||
x = layers.LayerNormalization(epsilon=1e-6)(x)
|
||||
return x + res
|
||||
|
||||
|
||||
"""
|
||||
The main part of our model is now complete. We can stack multiple of those
|
||||
`transformer_encoder` blocks and we can also proceed to add the final
|
||||
Multi-Layer Perceptron classification head. Apart from a stack of `Dense`
|
||||
layers, we need to reduce the output tensor of the `TransformerEncoder` part of
|
||||
our model down to a vector of features for each data point in the current
|
||||
batch. A common way to achieve this is to use a pooling layer. For
|
||||
this example, a `GlobalAveragePooling1D` layer is sufficient.
|
||||
"""
|
||||
|
||||
|
||||
def build_model(
|
||||
input_shape,
|
||||
head_size,
|
||||
num_heads,
|
||||
ff_dim,
|
||||
num_transformer_blocks,
|
||||
mlp_units,
|
||||
dropout=0,
|
||||
mlp_dropout=0,
|
||||
):
|
||||
inputs = keras.Input(shape=input_shape)
|
||||
x = inputs
|
||||
for _ in range(num_transformer_blocks):
|
||||
x = transformer_encoder(x, head_size, num_heads, ff_dim, dropout)
|
||||
|
||||
x = layers.GlobalAveragePooling1D(data_format="channels_first")(x)
|
||||
for dim in mlp_units:
|
||||
x = layers.Dense(dim, activation="relu")(x)
|
||||
x = layers.Dropout(mlp_dropout)(x)
|
||||
outputs = layers.Dense(1, activation="sigmoid")(x)
|
||||
return keras.Model(inputs, outputs)
|
||||
|
||||
|
||||
"""
|
||||
## Train and evaluate
|
||||
"""
|
||||
|
||||
input_shape = x_train.shape[1:]
|
||||
|
||||
model = build_model(
|
||||
input_shape,
|
||||
head_size=256,
|
||||
num_heads=4,
|
||||
ff_dim=4,
|
||||
num_transformer_blocks=4,
|
||||
mlp_units=[128],
|
||||
mlp_dropout=0.4,
|
||||
dropout=0.25,
|
||||
)
|
||||
|
||||
model.compile(
|
||||
loss="binary_crossentropy",
|
||||
optimizer=keras.optimizers.Adam(learning_rate=1e-4),
|
||||
metrics=[uar_accuracy],
|
||||
)
|
||||
|
||||
model.summary()
|
||||
|
||||
callbacks = [keras.callbacks.EarlyStopping(patience=10, restore_best_weights=True)]
|
||||
|
||||
model.fit(
|
||||
np.asarray(x_train).astype(np.float32),
|
||||
np.asarray(train_y).astype(np.float32),
|
||||
validation_split=0.2,
|
||||
epochs=20,
|
||||
batch_size=64,
|
||||
callbacks=callbacks,
|
||||
)
|
||||
|
||||
devel_y_pred = model.predict(np.asarray(devel_X).astype(np.float32), verbose=1)
|
||||
devel_y_pred = devel_y_pred.argmax(axis=-1)
|
||||
|
||||
devel_y_pred = devel_y_pred.astype('bool')
|
||||
devel_y = devel_y.astype('bool')
|
||||
|
||||
# devel metrics
|
||||
print('DEVEL')
|
||||
uar = recall_score(devel_y, devel_y_pred, average='macro')
|
||||
cm = confusion_matrix(devel_y, devel_y_pred)
|
||||
print(f'UAR: {uar}\n{classification_report(devel_y, devel_y_pred)}\n\nConfusion Matrix:\n\n{cm}')
|
||||
|
||||
|
||||
model.fit(
|
||||
np.asarray(x).astype(np.float32),
|
||||
np.asarray(y).astype(np.float32),
|
||||
validation_split=0.2,
|
||||
epochs=20,
|
||||
batch_size=64,
|
||||
callbacks=callbacks,
|
||||
)
|
||||
|
||||
test_y_pred = model.predict(np.asarray(x_test).astype(np.float32), verbose=1)
|
||||
test_y_pred = test_y_pred.argmax(axis=-1)
|
||||
|
||||
test_y_pred = test_y_pred.astype('bool')
|
||||
test_y = test_y.astype('bool')
|
||||
|
||||
# devel metrics
|
||||
print('DEVEL')
|
||||
uar = recall_score(devel_y, devel_y_pred, average='macro')
|
||||
cm = confusion_matrix(devel_y, devel_y_pred)
|
||||
print(f'UAR: {uar}\n{classification_report(devel_y, devel_y_pred)}\n\nConfusion Matrix:\n\n{cm}')
|
||||
|
||||
df_predictions = pd.DataFrame({'filename': devel_names.tolist(), 'prediction': devel_y_pred.tolist()})
|
||||
df_predictions.to_csv(os.path.join('./results/trans_hand_vgg/', 'devel.predictions.csv'), index=False)
|
||||
|
||||
# test metrics
|
||||
print('TEST')
|
||||
uar = recall_score(test_y, test_y_pred, average='macro')
|
||||
cm = confusion_matrix(test_y, test_y_pred)
|
||||
print(f'UAR: {uar}\n{classification_report(test_y, test_y_pred)}\n\nConfusion Matrix:\n\n{cm}')
|
||||
|
||||
df_predictions = pd.DataFrame({'filename': test_names.tolist(), 'prediction': test_y_pred.tolist()})
|
||||
df_predictions.to_csv(os.path.join('./results/trans_hand_vgg/', 'test.predictions.csv'), index=False)
|
Loading…
Reference in New Issue
Block a user