使用 scipy.sosfilt 过滤多维数据

Filtering multidimensional data with scipy.sosfilt

我正在尝试使用 'scipy.sosfilt' 过滤尺寸为 19 通道 x 10,000 个样本的二维数组,但收到有关我的初始过滤器形状的错误。我阅读了发布的解决方案 here and ,但我正在寻找比 2 通道解决方案更通用的解决方案以及不涉及循环每个通道的解决方案。

这是我尝试过的:

import numpy as np
import scipy
from scipy import signal

# data
# 19 channels x 10,000 samples
data = np.random.rand(19, 10000)

# sampling rate (Hz)
sampling_rate = 500 

# filters
nyq = sampling_rate / 2

# create bandpass filter
band_low = 0.01
band_lowcut = band_low / nyq
band_high = 50.0
band_highcut = band_high / nyq
band_sos = scipy.signal.butter(N = 3, Wn = [band_lowcut, band_highcut], btype = 'bandpass', fs = sampling_rate, output = 'sos')
band_z = scipy.signal.sosfilt_zi(band_sos)

# apply bandpass filter across the columns i.e. each of 19 channels bandpassed individually
bandpassed_data, band_z = scipy.signal.sosfilt(sos = band_sos, x = data, zi = band_z, axis = 1)

我收到以下错误:

ValueError: Invalid zi shape. With axis=1, an input with shape (19, 10000), and an sos array with 3 sections, zi must have shape (3, 19, 2), got (3, 2).

谢谢!

可以这样重塑为 (3, 19, 2),假设您希望所有 19 个通道都相同:

band_z = np.repeat(np.expand_dims(band_z, axis=1), 19, axis=1)

我在最新的 scipy 版本中尝试过,这很有效。

干杯