add VGGish tool
This commit is contained in:
parent
8467945baf
commit
d366309629
BIN vggish/__pycache__/mel_features.cpython-37.pyc (Normal file, binary file not shown)
BIN vggish/__pycache__/vggish_input.cpython-37.pyc (Normal file, binary file not shown)
BIN vggish/__pycache__/vggish_params.cpython-37.pyc (Normal file, binary file not shown)
BIN vggish/__pycache__/vggish_slim.cpython-37.pyc (Normal file, binary file not shown)
446 vggish/mel_features.py (Normal file)
@@ -0,0 +1,446 @@
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Defines routines to compute mel spectrogram features from audio waveform."""

import numpy as np


def frame(data, window_length, hop_length):
  """Convert array into a sequence of successive possibly overlapping frames.

  An n-dimensional array of shape (num_samples, ...) is converted into an
  (n+1)-D array of shape (num_frames, window_length, ...), where each frame
  starts hop_length points after the preceding one.

  This is accomplished using stride_tricks, so the original data is not
  copied. However, there is no zero-padding, so any incomplete frames at the
  end are not included.

  Args:
    data: np.array of dimension N >= 1.
    window_length: Number of samples in each frame.
    hop_length: Advance (in samples) between each window.

  Returns:
    (N+1)-D np.array with as many rows as there are complete frames that can be
    extracted.
  """
  num_samples = data.shape[0]
  num_frames = 1 + int(np.floor((num_samples - window_length) / hop_length))
  shape = (num_frames, window_length) + data.shape[1:]
  strides = (data.strides[0] * hop_length,) + data.strides
  return np.lib.stride_tricks.as_strided(data, shape=shape, strides=strides)


def periodic_hann(window_length):
  """Calculate a "periodic" Hann window.

  The classic Hann window is defined as a raised cosine that starts and
  ends on zero, and where every value appears twice, except the middle
  point for an odd-length window. Matlab calls this a "symmetric" window
  and np.hanning() returns it. However, for Fourier analysis, this
  actually represents just over one cycle of a period N-1 cosine, and
  thus is not compactly expressed on a length-N Fourier basis. Instead,
  it's better to use a raised cosine that ends just before the final
  zero value - i.e. a complete cycle of a period-N cosine. Matlab
  calls this a "periodic" window. This routine calculates it.

  Args:
    window_length: The number of points in the returned window.

  Returns:
    A 1D np.array containing the periodic hann window.
  """
  return 0.5 - (0.5 * np.cos(2 * np.pi / window_length *
                             np.arange(window_length)))


def stft_magnitude(signal, fft_length,
                   hop_length=None,
                   window_length=None):
  """Calculate the short-time Fourier transform magnitude.

  Args:
    signal: 1D np.array of the input time-domain signal.
    fft_length: Size of the FFT to apply.
    hop_length: Advance (in samples) between each frame passed to FFT.
    window_length: Length of each block of samples to pass to FFT.

  Returns:
    2D np.array where each row contains the magnitudes of the fft_length/2+1
    unique values of the FFT for the corresponding frame of input samples.
  """
  frames = frame(signal, window_length, hop_length)
  # Apply frame window to each frame. We use a periodic Hann (cosine of period
  # window_length) instead of the symmetric Hann of np.hanning (period
  # window_length-1).
  window = periodic_hann(window_length)
  windowed_frames = frames * window
  return np.abs(np.fft.rfft(windowed_frames, int(fft_length)))


# Mel spectrum constants and functions.
_MEL_BREAK_FREQUENCY_HERTZ = 700.0
_MEL_HIGH_FREQUENCY_Q = 1127.0


def hertz_to_mel(frequencies_hertz):
  """Convert frequencies to mel scale using HTK formula.

  Args:
    frequencies_hertz: Scalar or np.array of frequencies in hertz.

  Returns:
    Object of same size as frequencies_hertz containing corresponding values
    on the mel scale.
  """
  return _MEL_HIGH_FREQUENCY_Q * np.log(
      1.0 + (frequencies_hertz / _MEL_BREAK_FREQUENCY_HERTZ))


def spectrogram_to_mel_matrix(num_mel_bins=20,
                              num_spectrogram_bins=129,
                              audio_sample_rate=8000,
                              lower_edge_hertz=125.0,
                              upper_edge_hertz=3800.0):
  """Return a matrix that can post-multiply spectrogram rows to make mel.

  Returns a np.array matrix A that can be used to post-multiply a matrix S of
  spectrogram values (STFT magnitudes) arranged as frames x bins to generate a
  "mel spectrogram" M of frames x num_mel_bins. M = S A.

  The classic HTK algorithm exploits the complementarity of adjacent mel bands
  to multiply each FFT bin by only one mel weight, then add it, with positive
  and negative signs, to the two adjacent mel bands to which that bin
  contributes. Here, by expressing this operation as a matrix multiply, we go
  from num_fft multiplies per frame (plus around 2*num_fft adds) to around
  num_fft^2 multiplies and adds. However, because these are all presumably
  accomplished in a single call to np.dot(), it's not clear which approach is
  faster in Python. The matrix multiplication has the attraction of being more
  general and flexible, and much easier to read.

  Args:
    num_mel_bins: How many bands in the resulting mel spectrum. This is
      the number of columns in the output matrix.
    num_spectrogram_bins: How many bins there are in the source spectrogram
      data, which is understood to be fft_size/2 + 1, i.e. the spectrogram
      only contains the nonredundant FFT bins.
    audio_sample_rate: Samples per second of the audio at the input to the
      spectrogram. We need this to figure out the actual frequencies for
      each spectrogram bin, which dictates how they are mapped into mel.
    lower_edge_hertz: Lower bound on the frequencies to be included in the mel
      spectrum. This corresponds to the lower edge of the lowest triangular
      band.
    upper_edge_hertz: The desired top edge of the highest frequency band.

  Returns:
    An np.array with shape (num_spectrogram_bins, num_mel_bins).

  Raises:
    ValueError: if frequency edges are incorrectly ordered or out of range.
  """
  nyquist_hertz = audio_sample_rate / 2.
  if lower_edge_hertz < 0.0:
    raise ValueError("lower_edge_hertz %.1f must be >= 0" % lower_edge_hertz)
  if lower_edge_hertz >= upper_edge_hertz:
    raise ValueError("lower_edge_hertz %.1f >= upper_edge_hertz %.1f" %
                     (lower_edge_hertz, upper_edge_hertz))
  if upper_edge_hertz > nyquist_hertz:
    raise ValueError("upper_edge_hertz %.1f is greater than Nyquist %.1f" %
                     (upper_edge_hertz, nyquist_hertz))
  spectrogram_bins_hertz = np.linspace(0.0, nyquist_hertz, num_spectrogram_bins)
  spectrogram_bins_mel = hertz_to_mel(spectrogram_bins_hertz)
  # The i'th mel band (starting from i=1) has center frequency
  # band_edges_mel[i], lower edge band_edges_mel[i-1], and higher edge
  # band_edges_mel[i+1]. Thus, we need num_mel_bins + 2 values in
  # the band_edges_mel arrays.
  band_edges_mel = np.linspace(hertz_to_mel(lower_edge_hertz),
                               hertz_to_mel(upper_edge_hertz), num_mel_bins + 2)
  # Matrix to post-multiply feature arrays whose rows are num_spectrogram_bins
  # of spectrogram values.
  mel_weights_matrix = np.empty((num_spectrogram_bins, num_mel_bins))
  for i in range(num_mel_bins):
    lower_edge_mel, center_mel, upper_edge_mel = band_edges_mel[i:i + 3]
    # Calculate lower and upper slopes for every spectrogram bin.
    # Line segments are linear in the *mel* domain, not hertz.
    lower_slope = ((spectrogram_bins_mel - lower_edge_mel) /
                   (center_mel - lower_edge_mel))
    upper_slope = ((upper_edge_mel - spectrogram_bins_mel) /
                   (upper_edge_mel - center_mel))
    # .. then intersect them with each other and zero.
    mel_weights_matrix[:, i] = np.maximum(0.0, np.minimum(lower_slope,
                                                          upper_slope))
  # HTK excludes the spectrogram DC bin; make sure it always gets a zero
  # coefficient.
  mel_weights_matrix[0, :] = 0.0
  return mel_weights_matrix


def log_mel_spectrogram(data,
                        audio_sample_rate=8000,
                        log_offset=0.0,
                        window_length_secs=0.025,
                        hop_length_secs=0.010,
                        **kwargs):
  """Convert waveform to a log magnitude mel-frequency spectrogram.

  Args:
    data: 1D np.array of waveform data.
    audio_sample_rate: The sampling rate of data.
    log_offset: Add this to values when taking log to avoid -Infs.
    window_length_secs: Duration of each window to analyze.
    hop_length_secs: Advance between successive analysis windows.
    **kwargs: Additional arguments to pass to spectrogram_to_mel_matrix.

  Returns:
    2D np.array of (num_frames, num_mel_bins) consisting of log mel filterbank
    magnitudes for successive frames.
  """
  window_length_samples = int(round(audio_sample_rate * window_length_secs))
  hop_length_samples = int(round(audio_sample_rate * hop_length_secs))
  fft_length = 2 ** int(np.ceil(np.log(window_length_samples) / np.log(2.0)))
  spectrogram = stft_magnitude(
      data,
      fft_length=fft_length,
      hop_length=hop_length_samples,
      window_length=window_length_samples)
  mel_spectrogram = np.dot(spectrogram, spectrogram_to_mel_matrix(
      num_spectrogram_bins=spectrogram.shape[1],
      audio_sample_rate=audio_sample_rate, **kwargs))
  return np.log(mel_spectrogram + log_offset)
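A minimal usage sketch of the pipeline above (not part of the commit), assuming mel_features.py is importable; the 16 kHz rate and 64 mel bins mirror the VGGish defaults in vggish_params.py, and a nonzero log_offset guards against log(0) in mel bands far from the tone:

import numpy as np
import mel_features

# One second of a 440 Hz tone at 16 kHz.
sr = 16000
t = np.arange(sr) / sr
tone = np.sin(2 * np.pi * 440.0 * t)

# With 25 ms windows and 10 ms hops this yields 98 frames x 64 mel bins.
log_mel = mel_features.log_mel_spectrogram(
    tone, audio_sample_rate=sr, log_offset=0.01, num_mel_bins=64)
print(log_mel.shape)  # (98, 64)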
3 vggish/readme.md (Normal file)
@@ -0,0 +1,3 @@
This code was downloaded from https://modelzoo.co/model/audioset.

The checkpoint is also needed; if it does not exist, it will be downloaded automatically when running our code.
193 vggish/vggish_input.py (Normal file)
@@ -0,0 +1,193 @@
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Compute input examples for VGGish from audio waveform."""

import numpy as np
# import resampy

import mel_features
import vggish_params

try:
  import soundfile as sf

  def wav_read(wav_file):
    wav_data, sr = sf.read(wav_file, dtype='int16')
    return wav_data, sr

except ImportError:

  def wav_read(wav_file):
    raise NotImplementedError('WAV file reading requires soundfile package.')


def waveform_to_examples(data, sample_rate):
  """Converts audio waveform into an array of examples for VGGish.

  Args:
    data: np.array of either one dimension (mono) or two dimensions
      (multi-channel, with the outer dimension representing channels).
      Each sample is generally expected to lie in the range [-1.0, +1.0],
      although this is not required.
    sample_rate: Sample rate of data.

  Returns:
    3-D np.array of shape [num_examples, num_frames, num_bands] which represents
    a sequence of examples, each of which contains a patch of log mel
    spectrogram, covering num_frames frames of audio and num_bands mel frequency
    bands, where the frame length is vggish_params.STFT_HOP_LENGTH_SECONDS.
  """
  # Convert to mono.
  if len(data.shape) > 1:
    data = np.mean(data, axis=1)
  # Resample to the rate assumed by VGGish.
  # NOTE: resampling is disabled in this commit (resampy is commented out), so
  # data is assumed to already be at vggish_params.SAMPLE_RATE (16 kHz).
  # if sample_rate != vggish_params.SAMPLE_RATE:
  #   data = resampy.resample(data, sample_rate, vggish_params.SAMPLE_RATE)

  # Compute log mel spectrogram features.
  log_mel = mel_features.log_mel_spectrogram(
      data,
      audio_sample_rate=vggish_params.SAMPLE_RATE,
      log_offset=vggish_params.LOG_OFFSET,
      window_length_secs=vggish_params.STFT_WINDOW_LENGTH_SECONDS,
      hop_length_secs=vggish_params.STFT_HOP_LENGTH_SECONDS,
      num_mel_bins=vggish_params.NUM_MEL_BINS,
      lower_edge_hertz=vggish_params.MEL_MIN_HZ,
      upper_edge_hertz=vggish_params.MEL_MAX_HZ)

  # Frame features into examples.
  features_sample_rate = 1.0 / vggish_params.STFT_HOP_LENGTH_SECONDS
  example_window_length = int(round(
      vggish_params.EXAMPLE_WINDOW_SECONDS * features_sample_rate))
  example_hop_length = int(round(
      vggish_params.EXAMPLE_HOP_SECONDS * features_sample_rate))
  log_mel_examples = mel_features.frame(
      log_mel,
      window_length=example_window_length,
      hop_length=example_hop_length)
  return log_mel_examples


def wavfile_to_examples(wav_file):
  """Convenience wrapper around waveform_to_examples() for a common WAV format.

  Args:
    wav_file: String path to a file, or a file-like object. The file
      is assumed to contain WAV audio data with signed 16-bit PCM samples.

  Returns:
    See waveform_to_examples.
  """
  wav_data, sr = wav_read(wav_file)
  assert wav_data.dtype == np.int16, 'Bad sample type: %r' % wav_data.dtype
  samples = wav_data / 32768.0  # Convert to [-1.0, +1.0]
  return waveform_to_examples(samples, sr)
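A minimal sketch of the WAV convenience path (not part of the commit), assuming a hypothetical 'audio.wav' that is already 16 kHz mono, which matters because resampling is disabled above:

import vggish_input

examples = vggish_input.wavfile_to_examples('audio.wav')  # hypothetical path
# Each example is a 0.96 s patch: 96 frames x 64 mel bands.
print(examples.shape)  # (num_examples, 96, 64)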
BIN vggish/vggish_model.ckpt (Normal file, binary file not shown)
106 vggish/vggish_params.py (Normal file)
@@ -0,0 +1,106 @@
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Global parameters for the VGGish model.

See vggish_slim.py for more information.
"""

# Architectural constants.
NUM_FRAMES = 96  # Frames in input mel-spectrogram patch.
NUM_BANDS = 64  # Frequency bands in input mel-spectrogram patch.
EMBEDDING_SIZE = 128  # Size of embedding layer.

# Hyperparameters used in feature and example generation.
SAMPLE_RATE = 16000
STFT_WINDOW_LENGTH_SECONDS = 0.025
STFT_HOP_LENGTH_SECONDS = 0.010
NUM_MEL_BINS = NUM_BANDS
MEL_MIN_HZ = 125
MEL_MAX_HZ = 7500
LOG_OFFSET = 0.01  # Offset used for stabilized log of input mel-spectrogram.
EXAMPLE_WINDOW_SECONDS = 0.96  # Each example contains 96 10ms frames
EXAMPLE_HOP_SECONDS = 0.96  # with zero overlap.

# Parameters used for embedding postprocessing.
PCA_EIGEN_VECTORS_NAME = 'pca_eigen_vectors'
PCA_MEANS_NAME = 'pca_means'
QUANTIZE_MIN_VAL = -2.0
QUANTIZE_MAX_VAL = +2.0

# Hyperparameters used in training.
INIT_STDDEV = 0.01  # Standard deviation used to initialize weights.
LEARNING_RATE = 1e-4  # Learning rate for the Adam optimizer.
ADAM_EPSILON = 1e-8  # Epsilon for the Adam optimizer.

# Names of ops, tensors, and features.
INPUT_OP_NAME = 'vggish/input_features'
INPUT_TENSOR_NAME = INPUT_OP_NAME + ':0'
OUTPUT_OP_NAME = 'vggish/embedding'
OUTPUT_TENSOR_NAME = OUTPUT_OP_NAME + ':0'
AUDIO_EMBEDDING_FEATURE_NAME = 'audio_embedding'
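As a quick consistency check on these values: EXAMPLE_WINDOW_SECONDS / STFT_HOP_LENGTH_SECONDS = 0.96 / 0.010 = 96 = NUM_FRAMES, and NUM_MEL_BINS = NUM_BANDS = 64, so each input patch is exactly 96 frames by 64 mel bands; EXAMPLE_HOP_SECONDS equal to the window length gives zero overlap between consecutive examples.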
181 vggish/vggish_postprocess.py (Normal file)
@@ -0,0 +1,181 @@
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Post-process embeddings from VGGish."""

import numpy as np

import vggish_params


class Postprocessor(object):
  """Post-processes VGGish embeddings.

  The initial release of AudioSet included 128-D VGGish embeddings for each
  segment of AudioSet. These released embeddings were produced by applying
  a PCA transformation (technically, a whitening transform is included as well)
  and 8-bit quantization to the raw embedding output from VGGish, in order to
  stay compatible with the YouTube-8M project which provides visual embeddings
  in the same format for a large set of YouTube videos. This class implements
  the same PCA (with whitening) and quantization transformations.
  """

  def __init__(self, pca_params_npz_path):
    """Constructs a postprocessor.

    Args:
      pca_params_npz_path: Path to a NumPy-format .npz file that
        contains the PCA parameters used in postprocessing.
    """
    params = np.load(pca_params_npz_path)
    self._pca_matrix = params[vggish_params.PCA_EIGEN_VECTORS_NAME]
    # Load means into a column vector for easier broadcasting later.
    self._pca_means = params[vggish_params.PCA_MEANS_NAME].reshape(-1, 1)
    assert self._pca_matrix.shape == (
        vggish_params.EMBEDDING_SIZE, vggish_params.EMBEDDING_SIZE), (
        'Bad PCA matrix shape: %r' % (self._pca_matrix.shape,))
    assert self._pca_means.shape == (vggish_params.EMBEDDING_SIZE, 1), (
        'Bad PCA means shape: %r' % (self._pca_means.shape,))

  def postprocess(self, embeddings_batch):
    """Applies postprocessing to a batch of embeddings.

    Args:
      embeddings_batch: An nparray of shape [batch_size, embedding_size]
        containing output from the embedding layer of VGGish.

    Returns:
      An nparray of the same shape as the input but of type uint8,
      containing the PCA-transformed and quantized version of the input.
    """
    assert len(embeddings_batch.shape) == 2, (
        'Expected 2-d batch, got %r' % (embeddings_batch.shape,))
    assert embeddings_batch.shape[1] == vggish_params.EMBEDDING_SIZE, (
        'Bad batch shape: %r' % (embeddings_batch.shape,))

    # Apply PCA.
    # - Embeddings come in as [batch_size, embedding_size].
    # - Transpose to [embedding_size, batch_size].
    # - Subtract pca_means column vector from each column.
    # - Premultiply by PCA matrix of shape [output_dims, input_dims]
    #   where both are equal to embedding_size in our case.
    # - Transpose result back to [batch_size, embedding_size].
    pca_applied = np.dot(self._pca_matrix,
                         (embeddings_batch.T - self._pca_means)).T

    # Quantize by:
    # - clipping to [min, max] range
    clipped_embeddings = np.clip(
        pca_applied, vggish_params.QUANTIZE_MIN_VAL,
        vggish_params.QUANTIZE_MAX_VAL)
    # - convert to 8-bit in range [0.0, 255.0]
    quantized_embeddings = (
        (clipped_embeddings - vggish_params.QUANTIZE_MIN_VAL) *
        (255.0 /
         (vggish_params.QUANTIZE_MAX_VAL - vggish_params.QUANTIZE_MIN_VAL)))
    # - cast 8-bit float to uint8
    quantized_embeddings = quantized_embeddings.astype(np.uint8)

    return quantized_embeddings
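A minimal sketch of applying the postprocessor (not part of the commit), assuming the AudioSet release's PCA parameters file 'vggish_pca_params.npz' is present; the random batch is just a stand-in for real VGGish output:

import numpy as np
import vggish_postprocess

pproc = vggish_postprocess.Postprocessor('vggish_pca_params.npz')  # assumed path
raw_embeddings = np.random.randn(4, 128).astype(np.float32)  # stand-in batch
quantized = pproc.postprocess(raw_embeddings)  # uint8, shape (4, 128)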
239 vggish/vggish_slim.py (Normal file)
@@ -0,0 +1,239 @@
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Defines the 'VGGish' model used to generate AudioSet embedding features.

The public AudioSet release (https://research.google.com/audioset/download.html)
includes 128-D features extracted from the embedding layer of a VGG-like model
that was trained on a large Google-internal YouTube dataset. Here we provide
a TF-Slim definition of the same model, without any dependences on libraries
internal to Google. We call it 'VGGish'.

Note that we only define the model up to the embedding layer, which is the
penultimate layer before the final classifier layer. We also provide various
hyperparameter values (in vggish_params.py) that were used to train this model
internally.

For comparison, here is TF-Slim's VGG definition:
https://github.com/tensorflow/models/blob/master/research/slim/nets/vgg.py
"""

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
import tf_slim as slim

import vggish_params as params


def define_vggish_slim(training=False):
  """Defines the VGGish TensorFlow model.

  All ops are created in the current default graph, under the scope 'vggish/'.

  The input is a placeholder named 'vggish/input_features' of type float32 and
  shape [batch_size, num_frames, num_bands] where batch_size is variable and
  num_frames and num_bands are constants, and [num_frames, num_bands] represents
  a log-mel-scale spectrogram patch covering num_bands frequency bands and
  num_frames time frames (where each frame step is usually 10ms). This is
  produced by computing the stabilized log(mel-spectrogram + params.LOG_OFFSET).
  The output is an op named 'vggish/embedding' which produces the activations of
  a 128-D embedding layer, which is usually the penultimate layer when used as
  part of a full model with a final classifier layer.

  Args:
    training: If true, all parameters are marked trainable.

  Returns:
    The op 'vggish/embedding'.
  """
  # Defaults:
  # - All weights are initialized to N(0, INIT_STDDEV).
  # - All biases are initialized to 0.
  # - All activations are ReLU.
  # - All convolutions are 3x3 with stride 1 and SAME padding.
  # - All max-pools are 2x2 with stride 2 and SAME padding.
  with slim.arg_scope(
      [slim.conv2d, slim.fully_connected],
      weights_initializer=tf.truncated_normal_initializer(
          stddev=params.INIT_STDDEV),
      biases_initializer=tf.zeros_initializer(),
      activation_fn=tf.nn.relu,
      trainable=training), \
       slim.arg_scope([slim.conv2d],
                      kernel_size=[3, 3], stride=1, padding='SAME'), \
       slim.arg_scope([slim.max_pool2d],
                      kernel_size=[2, 2], stride=2, padding='SAME'), \
       tf.variable_scope('vggish'):
    # Input: a batch of 2-D log-mel-spectrogram patches.
    features = tf.placeholder(
        tf.float32, shape=(None, params.NUM_FRAMES, params.NUM_BANDS),
        name='input_features')
    # Reshape to 4-D so that we can convolve a batch with conv2d().
    net = tf.reshape(features, [-1, params.NUM_FRAMES, params.NUM_BANDS, 1])

    # The VGG stack of alternating convolutions and max-pools.
    net = slim.conv2d(net, 64, scope='conv1')
    net = slim.max_pool2d(net, scope='pool1')
    net = slim.conv2d(net, 128, scope='conv2')
    net = slim.max_pool2d(net, scope='pool2')
    net = slim.repeat(net, 2, slim.conv2d, 256, scope='conv3')
    net = slim.max_pool2d(net, scope='pool3')
    net = slim.repeat(net, 2, slim.conv2d, 512, scope='conv4')
    net = slim.max_pool2d(net, scope='pool4')

    # Flatten before entering fully-connected layers
    net = slim.flatten(net)
    net = slim.repeat(net, 2, slim.fully_connected, 4096, scope='fc1')
    # The embedding layer.
    net = slim.fully_connected(net, params.EMBEDDING_SIZE, scope='fc2')
    return tf.identity(net, name='embedding')


def load_vggish_slim_checkpoint(session, checkpoint_path):
  """Loads a pre-trained VGGish-compatible checkpoint.

  This function can be used as an initialization function (referred to as
  init_fn in TensorFlow documentation) which is called in a Session after
  initializating all variables. When used as an init_fn, this will load
  a pre-trained checkpoint that is compatible with the VGGish model
  definition. Only variables defined by VGGish will be loaded.

  Args:
    session: an active TensorFlow session.
    checkpoint_path: path to a file containing a checkpoint that is
      compatible with the VGGish model definition.
  """
  # Get the list of names of all VGGish variables that exist in
  # the checkpoint (i.e., all inference-mode VGGish variables).
  with tf.Graph().as_default():
    define_vggish_slim(training=False)
    vggish_var_names = [v.name for v in tf.global_variables()]

  # Get the list of all currently existing variables that match
  # the list of variable names we just computed.
  vggish_vars = [v for v in tf.global_variables() if v.name in vggish_var_names]

  # Use a Saver to restore just the variables selected above.
  saver = tf.train.Saver(vggish_vars, name='vggish_load_pretrained',
                         write_version=1)
  saver.restore(session, checkpoint_path)
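A minimal inference sketch combining the two functions above (not part of the commit), assuming the vggish_model.ckpt checkpoint from the readme and a hypothetical 16 kHz 'audio.wav':

import tensorflow.compat.v1 as tf
import vggish_input
import vggish_params
import vggish_slim

with tf.Graph().as_default(), tf.Session() as sess:
  vggish_slim.define_vggish_slim(training=False)
  vggish_slim.load_vggish_slim_checkpoint(sess, 'vggish_model.ckpt')
  features_tensor = sess.graph.get_tensor_by_name(
      vggish_params.INPUT_TENSOR_NAME)
  embedding_tensor = sess.graph.get_tensor_by_name(
      vggish_params.OUTPUT_TENSOR_NAME)
  examples = vggish_input.wavfile_to_examples('audio.wav')  # hypothetical path
  [embeddings] = sess.run([embedding_tensor],
                          feed_dict={features_tensor: examples})
  print(embeddings.shape)  # (num_examples, 128)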
387 vggish/vggish_train_demo.py (Normal file)
@@ -0,0 +1,387 @@
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

r"""A simple demonstration of running VGGish in training mode.

This is intended as a toy example that demonstrates how to use the VGGish model
definition within a larger model that adds more layers on top, and then train
the larger model. If you let VGGish train as well, then this allows you to
fine-tune the VGGish model parameters for your application. If you don't let
VGGish train, then you use VGGish as a feature extractor for the layers above
it.

For this toy task, we are training a classifier to distinguish between three
classes: sine waves, constant signals, and white noise. We generate synthetic
waveforms from each of these classes, convert into shuffled batches of log mel
spectrogram examples with associated labels, and feed the batches into a model
that includes VGGish at the bottom and a couple of additional layers on top. We
also plumb in labels that are associated with the examples, which feed a label
loss used for training.

Usage:
  # Run training for 100 steps using a model checkpoint in the default
  # location (vggish_model.ckpt in the current directory). Allow VGGish
  # to get fine-tuned.
  $ python vggish_train_demo.py --num_batches 100

  # Same as before but run for fewer steps and don't change VGGish parameters
  # and use a checkpoint in a different location
  $ python vggish_train_demo.py --num_batches 50 \
                                --train_vggish=False \
                                --checkpoint /path/to/model/checkpoint
"""

from __future__ import print_function

from random import shuffle

import numpy as np
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
import tf_slim as slim

import vggish_input
import vggish_params
import vggish_slim

flags = tf.app.flags

flags.DEFINE_integer(
    'num_batches', 30,
    'Number of batches of examples to feed into the model. Each batch is of '
    'variable size and contains shuffled examples of each class of audio.')

flags.DEFINE_boolean(
    'train_vggish', True,
    'If True, allow VGGish parameters to change during training, thus '
    'fine-tuning VGGish. If False, VGGish parameters are fixed, thus using '
    'VGGish as a fixed feature extractor.')

flags.DEFINE_string(
    'checkpoint', 'vggish_model.ckpt',
    'Path to the VGGish checkpoint file.')

FLAGS = flags.FLAGS

_NUM_CLASSES = 3


def _get_examples_batch():
  """Returns a shuffled batch of examples of all audio classes.

  Note that this is just a toy function because this is a simple demo intended
  to illustrate how the training code might work.

  Returns:
    a tuple (features, labels) where features is a NumPy array of shape
    [batch_size, num_frames, num_bands] where the batch_size is variable and
    each row is a log mel spectrogram patch of shape [num_frames, num_bands]
    suitable for feeding VGGish, while labels is a NumPy array of shape
    [batch_size, num_classes] where each row is a multi-hot label vector that
    provides the labels for corresponding rows in features.
  """
  # Make a waveform for each class.
  num_seconds = 5
  sr = 44100  # Sampling rate.
  t = np.linspace(0, num_seconds, int(num_seconds * sr))  # Time axis.
  # Random sine wave.
  freq = np.random.uniform(100, 1000)
  sine = np.sin(2 * np.pi * freq * t)
  # Random constant signal.
  magnitude = np.random.uniform(-1, 1)
  # NOTE: magnitude * t is actually a linear ramp rather than a strictly
  # constant signal, but it still gives the classifier a separable class.
  const = magnitude * t
  # White noise.
  noise = np.random.normal(-1, 1, size=t.shape)

  # Make examples of each signal and corresponding labels.
  # Sine is class index 0, Const class index 1, Noise class index 2.
  sine_examples = vggish_input.waveform_to_examples(sine, sr)
  sine_labels = np.array([[1, 0, 0]] * sine_examples.shape[0])
  const_examples = vggish_input.waveform_to_examples(const, sr)
  const_labels = np.array([[0, 1, 0]] * const_examples.shape[0])
  noise_examples = vggish_input.waveform_to_examples(noise, sr)
  noise_labels = np.array([[0, 0, 1]] * noise_examples.shape[0])

  # Shuffle (example, label) pairs across all classes.
  all_examples = np.concatenate((sine_examples, const_examples, noise_examples))
  all_labels = np.concatenate((sine_labels, const_labels, noise_labels))
  labeled_examples = list(zip(all_examples, all_labels))
  shuffle(labeled_examples)

  # Separate and return the features and labels.
  features = [example for (example, _) in labeled_examples]
  labels = [label for (_, label) in labeled_examples]
  return (features, labels)


def main(_):
  with tf.Graph().as_default(), tf.Session() as sess:
    # Define VGGish.
    embeddings = vggish_slim.define_vggish_slim(FLAGS.train_vggish)

    # Define a shallow classification model and associated training ops on top
    # of VGGish.
    with tf.variable_scope('mymodel'):
      # Add a fully connected layer with 100 units.
      num_units = 100
      fc = slim.fully_connected(embeddings, num_units)

      # Add a classifier layer at the end, consisting of parallel logistic
      # classifiers, one per class. This allows for multi-class tasks.
      logits = slim.fully_connected(
          fc, _NUM_CLASSES, activation_fn=None, scope='logits')
      tf.sigmoid(logits, name='prediction')

      # Add training ops.
      with tf.variable_scope('train'):
        global_step = tf.Variable(
            0, name='global_step', trainable=False,
            collections=[tf.GraphKeys.GLOBAL_VARIABLES,
                         tf.GraphKeys.GLOBAL_STEP])

        # Labels are assumed to be fed as a batch multi-hot vectors, with
        # a 1 in the position of each positive class label, and 0 elsewhere.
        labels = tf.placeholder(
            tf.float32, shape=(None, _NUM_CLASSES), name='labels')

        # Cross-entropy label loss.
        # tf.nn.softmax_cross_entropy_with_logits()
        xent = tf.nn.sigmoid_cross_entropy_with_logits(
            logits=logits, labels=labels, name='xent')
        loss = tf.reduce_mean(xent, name='loss_op')
        tf.summary.scalar('loss', loss)

        # We use the same optimizer and hyperparameters as used to train VGGish.
        optimizer = tf.train.AdamOptimizer(
            learning_rate=vggish_params.LEARNING_RATE,
            epsilon=vggish_params.ADAM_EPSILON)
        optimizer.minimize(loss, global_step=global_step, name='train_op')

    # Initialize all variables in the model, and then load the pre-trained
    # VGGish checkpoint.
    sess.run(tf.global_variables_initializer())
    vggish_slim.load_vggish_slim_checkpoint(sess, FLAGS.checkpoint)

    # Locate all the tensors and ops we need for the training loop.
    features_tensor = sess.graph.get_tensor_by_name(
        vggish_params.INPUT_TENSOR_NAME)
    labels_tensor = sess.graph.get_tensor_by_name('mymodel/train/labels:0')
    global_step_tensor = sess.graph.get_tensor_by_name(
        'mymodel/train/global_step:0')
    loss_tensor = sess.graph.get_tensor_by_name('mymodel/train/loss_op:0')
    train_op = sess.graph.get_operation_by_name('mymodel/train/train_op')

    # The training loop.
    for _ in range(FLAGS.num_batches):
      (features, labels) = _get_examples_batch()
      [num_steps, loss, _] = sess.run(
          [global_step_tensor, loss_tensor, train_op],
          feed_dict={features_tensor: features, labels_tensor: labels})
      print('Step %d: loss %g' % (num_steps, loss))


if __name__ == '__main__':
  tf.app.run()