fix: fix the inverse wavelet func
This commit is contained in:
		
							parent
							
								
									04498a79e6
								
							
						
					
					
						commit
						dfc1b4173a
					
				@ -99,11 +99,14 @@ def load_user_emg_data():
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    return csv_handler.data_container_dict
 | 
					    return csv_handler.data_container_dict
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# Takes in a df and outputs np arrays for x and y values
 | 
					# Takes in a df and outputs np arrays for x and y values
 | 
				
			||||||
def prep_df(df:DataFrame):
 | 
					def get_xory_from_df(x_or_y, df:DataFrame):
 | 
				
			||||||
    min, duration = Handler.get_min_max_timestamp(df)
 | 
					    swither = {
 | 
				
			||||||
    y = df.iloc[:,1].to_numpy()
 | 
					        'x': df.iloc[:,0].to_numpy(),
 | 
				
			||||||
    return y, duration
 | 
					        'y': df.iloc[:,1].to_numpy()
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					    return swither.get(x_or_y, 0)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# Normalizes a ndarray of a signal to the scale of int16(32767)
 | 
					# Normalizes a ndarray of a signal to the scale of int16(32767)
 | 
				
			||||||
def normalize_wave(y_values):
 | 
					def normalize_wave(y_values):
 | 
				
			||||||
@ -112,7 +115,7 @@ def normalize_wave(y_values):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
# Takes the FFT of a DataFrame object
 | 
					# Takes the FFT of a DataFrame object
 | 
				
			||||||
def fft_of_df(df:DataFrame):
 | 
					def fft_of_df(df:DataFrame):
 | 
				
			||||||
    y_values, duration = prep_df(df)
 | 
					    y_values = get_xory_from_df('y', df)
 | 
				
			||||||
    N = y_values.size
 | 
					    N = y_values.size
 | 
				
			||||||
    norm = normalize_wave(y_values)
 | 
					    norm = normalize_wave(y_values)
 | 
				
			||||||
    N_trans = fftfreq(N, 1 / SAMPLE_RATE)
 | 
					    N_trans = fftfreq(N, 1 / SAMPLE_RATE)
 | 
				
			||||||
@ -121,7 +124,7 @@ def fft_of_df(df:DataFrame):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
# Removes noise with db4 wavelet function 
 | 
					# Removes noise with db4 wavelet function 
 | 
				
			||||||
def wavelet_db4_denoising(df:DataFrame):
 | 
					def wavelet_db4_denoising(df:DataFrame):
 | 
				
			||||||
    y_values, duration = prep_df(df)
 | 
					    y_values = get_xory_from_df('y', df)
 | 
				
			||||||
    #y_values = normalize_wave(y_values)
 | 
					    #y_values = normalize_wave(y_values)
 | 
				
			||||||
    wavelet = pywt.Wavelet('db4')
 | 
					    wavelet = pywt.Wavelet('db4')
 | 
				
			||||||
    cA, cD = pywt.dwt(y_values, wavelet)
 | 
					    cA, cD = pywt.dwt(y_values, wavelet)
 | 
				
			||||||
@ -140,10 +143,17 @@ def soft_threshold_filter(cA, cD):
 | 
				
			|||||||
    cD_filt = cD 
 | 
					    cD_filt = cD 
 | 
				
			||||||
    return cA_filt, cD_filt
 | 
					    return cA_filt, cD_filt
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# Inverse dwt for brining denoise signal back to the time domain
 | 
					# Inverse dwt for brining denoise signal back to the time domainfi
 | 
				
			||||||
def inverse_wavelet(cA_filt, cD_filt):
 | 
					def inverse_wavelet(df, cA_filt, cD_filt):
 | 
				
			||||||
    wavelet = pywt.Wavelet('db4')
 | 
					    wavelet = pywt.Wavelet('db4')
 | 
				
			||||||
    y_new_values = pywt.idwt(cA_filt, cD_filt, wavelet)
 | 
					    y_new_values = pywt.idwt(cA_filt, cD_filt, wavelet)
 | 
				
			||||||
 | 
					    new_len = len(y_new_values)
 | 
				
			||||||
 | 
					    old_len = len(get_xory_from_df('y', df))
 | 
				
			||||||
 | 
					    if new_len > old_len:
 | 
				
			||||||
 | 
					        while new_len > old_len:
 | 
				
			||||||
 | 
					            y_new_values = y_new_values[:-1]
 | 
				
			||||||
 | 
					            new_len = len(y_new_values)
 | 
				
			||||||
 | 
					            old_len = len(get_xory_from_df('y', df))
 | 
				
			||||||
    return y_new_values
 | 
					    return y_new_values
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# Plots DataFrame objects
 | 
					# Plots DataFrame objects
 | 
				
			||||||
@ -161,11 +171,16 @@ def plot_arrays(N, y):
 | 
				
			|||||||
handler = Handler.CSV_handler()
 | 
					handler = Handler.CSV_handler()
 | 
				
			||||||
file = "/Exp20201205_2myo_hardTypePP/HaluskaMarek_20201207_1810/myoLeftEmg.csv"
 | 
					file = "/Exp20201205_2myo_hardTypePP/HaluskaMarek_20201207_1810/myoLeftEmg.csv"
 | 
				
			||||||
df = handler.get_time_emg_table(file, 1)
 | 
					df = handler.get_time_emg_table(file, 1)
 | 
				
			||||||
N = np.array(range(int(df.iloc[:,1].size + 1)))
 | 
					N = get_xory_from_df('x', df)
 | 
				
			||||||
plot_df(df)
 | 
					#plot_df(df)
 | 
				
			||||||
 | 
					print(len(N))
 | 
				
			||||||
 | 
					print(len(get_xory_from_df('y', df)))
 | 
				
			||||||
x, cA, cD = wavelet_db4_denoising(df)
 | 
					x, cA, cD = wavelet_db4_denoising(df)
 | 
				
			||||||
plot_arrays(x, cA)
 | 
					#plot_arrays(x, cA)
 | 
				
			||||||
 | 
					print(len(cA))
 | 
				
			||||||
cA_filt, cD_filt = soft_threshold_filter(cA, cD)
 | 
					cA_filt, cD_filt = soft_threshold_filter(cA, cD)
 | 
				
			||||||
plot_arrays(x, cA_filt)
 | 
					#plot_arrays(x, cA_filt)
 | 
				
			||||||
y_new_values = inverse_wavelet(cA, cD)
 | 
					print(len(cA_filt))
 | 
				
			||||||
 | 
					y_new_values = inverse_wavelet(df, cA, cD)
 | 
				
			||||||
 | 
					print(len(y_new_values))
 | 
				
			||||||
plot_arrays(N, y_new_values)
 | 
					plot_arrays(N, y_new_values)
 | 
				
			||||||
		Loading…
	
		Reference in New Issue
	
	Block a user