131 lines
4.3 KiB
Python
131 lines
4.3 KiB
Python
from __future__ import print_function
|
|
|
|
import tensorflow.compat.v1 as tf
|
|
tf.disable_v2_behavior()
|
|
|
|
import pandas as pd
|
|
import os
|
|
import json
|
|
import sys
|
|
import numpy as np
|
|
|
|
import librosa
|
|
|
|
import urllib
|
|
sys.path.append('vggish')
|
|
import vggish_input
|
|
import vggish_params
|
|
import vggish_slim
|
|
|
|
# Audio analysis constants shared by the feature-extraction script.
SR = 22050  # sample rate (Hz) used when loading the audio files
SR_VGG = 16000  # VGG pretrained model sample rate (Hz)
FRAME_LEN = SR // 10  # samples per frame: 100 ms
HOP = FRAME_LEN // 2  # hop in samples: 50% overlap, i.e. 50 ms
|
|
|
|
|
|
def download(url, dst_dir):
    """Download a file into ``dst_dir`` unless it is already there.

    The local file name is the last ``/``-separated component of ``url``.
    If a file with that name already exists in ``dst_dir`` nothing is
    fetched and the existing path is returned.

    Args:
        url: Web location of the file.
        dst_dir: Directory to store the file in; created if it is missing.

    Returns:
        Path to the downloaded (or already present) file.
    """
    # Import the submodule explicitly: a bare ``import urllib`` does not
    # guarantee that ``urllib.request`` is loaded on Python 3.
    try:
        from urllib.request import urlretrieve
    except ImportError:  # Python 2 fallback
        from urllib import urlretrieve

    filename = url.split('/')[-1]
    filepath = os.path.join(dst_dir, filename)
    if os.path.exists(filepath):
        return filepath

    # urlretrieve cannot create missing parent directories itself.
    if dst_dir and not os.path.isdir(dst_dir):
        os.makedirs(dst_dir)

    def _progress(count, block_size, total_size):
        received = count * block_size
        if total_size > 0:
            # The final block may overshoot the real size; cap at 100%.
            percent = min(100.0, float(received) / float(total_size) * 100.0)
            sys.stdout.write('\r>> Downloading %s %.1f%%' % (filename, percent))
        else:
            # Size is unknown (-1) or zero: report bytes instead of dividing.
            sys.stdout.write('\r>> Downloading %s %d bytes' % (filename, received))
        sys.stdout.flush()

    filepath, _ = urlretrieve(url, filepath, _progress)
    # Finish the '\r'-style progress line before the summary message.
    sys.stdout.write('\n')
    statinfo = os.stat(filepath)
    print('Successfully downloaded:', filename, statinfo.st_size, 'bytes.')
    return filepath
|
|
|
|
def sta_fun_2(npdata):
    """Summarise an array by its mean and standard deviation along axis 0.

    :param npdata: the data to extract the features from; statistics are
        taken over the first axis. A 1-D array is reduced to a single mean
        and a single standard deviation.
    :type npdata: numpy.ndarray
    :return: the mean values followed by the standard deviations, flattened
        into a single row (shape ``(1, -1)``)
    :rtype: numpy.ndarray
    :raises ValueError: if ``npdata`` is None or contains no elements
    """
    # perform a sanity check
    if npdata is None:
        raise ValueError("Input array cannot be None")
    data = np.asarray(npdata)
    if data.size == 0:
        # mean/std of nothing would silently produce NaN features
        raise ValueError("Input array cannot be empty")

    # Reducing a 1-D input yields 0-d scalars, which np.concatenate rejects;
    # atleast_1d leaves the results for higher-rank inputs untouched.
    mean = np.atleast_1d(np.mean(data, axis=0))
    std = np.atleast_1d(np.std(data, axis=0))

    # finally return the features in a concatenated array (as a vector)
    return np.concatenate((mean, std), axis=0).reshape(1, -1)
|
|
|
|
print("\nTesting your install of VGGish\n")

# Location of the pre-trained VGGish checkpoint on disk.
checkpoint_path = "vggish/vggish_model.ckpt"

# Fetch the checkpoint the first time the script runs on this machine.
if not os.path.exists(checkpoint_path):
    url = 'https://storage.googleapis.com/audioset/vggish_model.ckpt'
    download(url, './vggish/')
|
|
|
|
|
|
if __name__ == "__main__":
    # Extract one VGGish-based feature vector per audio file found in the
    # folder given on the command line, then save the features together with
    # the labels read from labels/<folder name>.csv into vgg_features/.

    # data path (raw_files\devel OR test OR train folder)
    if len(sys.argv) < 2:
        sys.exit("usage: %s <path-to-audio-folder>" % sys.argv[0])
    path = sys.argv[1]
    # The folder's base name selects the label file and names the outputs.
    split_name = os.path.basename(os.path.normpath(path))

    x_data = []
    y_uid = []

    ## feature extraction
    with tf.Graph().as_default(), tf.Session() as sess:
        # load pre-trained model
        vggish_slim.define_vggish_slim()
        vggish_slim.load_vggish_slim_checkpoint(sess, checkpoint_path)
        features_tensor = sess.graph.get_tensor_by_name(
            vggish_params.INPUT_TENSOR_NAME
        )
        embedding_tensor = sess.graph.get_tensor_by_name(
            vggish_params.OUTPUT_TENSOR_NAME
        )

        # extract features; sorted() makes the processing order deterministic
        # (os.listdir returns entries in arbitrary order)
        for file_name in sorted(os.listdir(path)):
            sample_path = os.path.join(path, file_name)
            try:
                y, sr = librosa.load(sample_path, sr=SR, mono=True)
            except IOError:
                print("file doesn't exist: %s" % sample_path)
                continue

            # drop leading/trailing silence, then skip clips shorter than 2 s
            yt, _ = librosa.effects.trim(
                y, frame_length=FRAME_LEN, hop_length=HOP
            )
            if librosa.get_duration(y=yt, sr=sr) < 2:
                continue

            # Pass the rate the audio was actually loaded at so VGGish can
            # resample it itself; labelling SR-rate audio as SR_VGG would
            # skip the resampling and distort the spectrogram.
            input_batch = vggish_input.waveform_to_examples(yt, sr)
            # ?x96x64 --> ?x128
            [features] = sess.run(
                [embedding_tensor], feed_dict={features_tensor: input_batch}
            )

            # collapse the per-example embeddings into one mean/std row
            x_data.append(sta_fun_2(features).tolist())
            y_uid.append(file_name)

    # save features in numpy.array
    x_data = np.array(x_data)
    # os.path.join keeps the labels path portable (no hard-coded backslash)
    labels_path = os.path.join('labels', split_name + '.csv')
    df = pd.read_csv(labels_path, sep=',')
    y_label = df.label

    # NOTE(review): labels are taken from the CSV as-is, while skipped files
    # produce no feature row, so rows may not line up with the labels.
    # Aligning them needs the CSV's file-name column -- confirm the schema.
    if len(y_label) != len(y_uid):
        print(
            "warning: %d feature rows but %d labels in %s"
            % (len(y_uid), len(y_label), labels_path),
            file=sys.stderr,
        )

    out_dir = 'vgg_features'
    if not os.path.isdir(out_dir):
        os.makedirs(out_dir)
    np.save(os.path.join(out_dir, "x_" + split_name + "_data_vgg.npy"), x_data)
    np.save(os.path.join(out_dir, "y_" + split_name + "_label_vgg.npy"), y_label)
|