fix: fix the inverse wavelet func

This commit is contained in:
Skudalen 2021-06-24 15:45:33 +02:00
parent 04498a79e6
commit dfc1b4173a

View File

@ -99,11 +99,14 @@ def load_user_emg_data():
return csv_handler.data_container_dict return csv_handler.data_container_dict
# Takes in a df and outputs np arrays for x and y values # Takes in a df and outputs np arrays for x and y values
def prep_df(df:DataFrame): def get_xory_from_df(x_or_y, df:DataFrame):
min, duration = Handler.get_min_max_timestamp(df) swither = {
y = df.iloc[:,1].to_numpy() 'x': df.iloc[:,0].to_numpy(),
return y, duration 'y': df.iloc[:,1].to_numpy()
}
return swither.get(x_or_y, 0)
# Normalizes a ndarray of a signal to the scale of int16(32767) # Normalizes a ndarray of a signal to the scale of int16(32767)
def normalize_wave(y_values): def normalize_wave(y_values):
@ -112,7 +115,7 @@ def normalize_wave(y_values):
# Takes the FFT of a DataFrame object # Takes the FFT of a DataFrame object
def fft_of_df(df:DataFrame): def fft_of_df(df:DataFrame):
y_values, duration = prep_df(df) y_values = get_xory_from_df('y', df)
N = y_values.size N = y_values.size
norm = normalize_wave(y_values) norm = normalize_wave(y_values)
N_trans = fftfreq(N, 1 / SAMPLE_RATE) N_trans = fftfreq(N, 1 / SAMPLE_RATE)
@ -121,7 +124,7 @@ def fft_of_df(df:DataFrame):
# Removes noise with db4 wavelet function # Removes noise with db4 wavelet function
def wavelet_db4_denoising(df:DataFrame): def wavelet_db4_denoising(df:DataFrame):
y_values, duration = prep_df(df) y_values = get_xory_from_df('y', df)
#y_values = normalize_wave(y_values) #y_values = normalize_wave(y_values)
wavelet = pywt.Wavelet('db4') wavelet = pywt.Wavelet('db4')
cA, cD = pywt.dwt(y_values, wavelet) cA, cD = pywt.dwt(y_values, wavelet)
@ -140,10 +143,17 @@ def soft_threshold_filter(cA, cD):
cD_filt = cD cD_filt = cD
return cA_filt, cD_filt return cA_filt, cD_filt
# Inverse dwt for brining denoise signal back to the time domain # Inverse dwt for brining denoise signal back to the time domainfi
def inverse_wavelet(cA_filt, cD_filt): def inverse_wavelet(df, cA_filt, cD_filt):
wavelet = pywt.Wavelet('db4') wavelet = pywt.Wavelet('db4')
y_new_values = pywt.idwt(cA_filt, cD_filt, wavelet) y_new_values = pywt.idwt(cA_filt, cD_filt, wavelet)
new_len = len(y_new_values)
old_len = len(get_xory_from_df('y', df))
if new_len > old_len:
while new_len > old_len:
y_new_values = y_new_values[:-1]
new_len = len(y_new_values)
old_len = len(get_xory_from_df('y', df))
return y_new_values return y_new_values
# Plots DataFrame objects # Plots DataFrame objects
@ -161,11 +171,16 @@ def plot_arrays(N, y):
handler = Handler.CSV_handler() handler = Handler.CSV_handler()
file = "/Exp20201205_2myo_hardTypePP/HaluskaMarek_20201207_1810/myoLeftEmg.csv" file = "/Exp20201205_2myo_hardTypePP/HaluskaMarek_20201207_1810/myoLeftEmg.csv"
df = handler.get_time_emg_table(file, 1) df = handler.get_time_emg_table(file, 1)
N = np.array(range(int(df.iloc[:,1].size + 1))) N = get_xory_from_df('x', df)
plot_df(df) #plot_df(df)
print(len(N))
print(len(get_xory_from_df('y', df)))
x, cA, cD = wavelet_db4_denoising(df) x, cA, cD = wavelet_db4_denoising(df)
plot_arrays(x, cA) #plot_arrays(x, cA)
print(len(cA))
cA_filt, cD_filt = soft_threshold_filter(cA, cD) cA_filt, cD_filt = soft_threshold_filter(cA, cD)
plot_arrays(x, cA_filt) #plot_arrays(x, cA_filt)
y_new_values = inverse_wavelet(cA, cD) print(len(cA_filt))
y_new_values = inverse_wavelet(df, cA, cD)
print(len(y_new_values))
plot_arrays(N, y_new_values) plot_arrays(N, y_new_values)