改进 min/max 下采样

Improve min/max downsampling

我有一些大型数组(约 1 亿个点)需要进行交互式绘图。我目前正在使用 Matplotlib。按原样绘制数组变得非常慢并且是一种浪费,因为无论如何您都无法想象那么多点。

所以我做了一个 min/max 抽取函数,我绑定到轴的 'xlim_changed' 回调。我采用 min/max 方法,因为数据包含快速峰值,我不想通过单步执行数据而错过这些峰值。有更多包装器裁剪到 x 限制,并在某些条件下跳过处理,但相关部分如下:

def min_max_downsample(x,y,num_bins):
    """ Break the data into num_bins and returns min/max for each bin"""
    pts_per_bin = x.size // num_bins    

    #Create temp to hold the reshaped & slightly cropped y
    y_temp = y[:num_bins*pts_per_bin].reshape((num_bins, pts_per_bin))
    y_out      = np.empty((num_bins,2))
    #Take the min/max by rows.
    y_out[:,0] = y_temp.max(axis=1)
    y_out[:,1] = y_temp.min(axis=1)
    y_out = y_out.ravel()

    #This duplicates the x-value for each min/max y-pair
    x_out = np.empty((num_bins,2))
    x_out[:] = x[:num_bins*pts_per_bin:pts_per_bin,np.newaxis]
    x_out = x_out.ravel()
    return x_out, y_out

这工作得很好并且足够快(在 1e8 点和 2k bin 上大约 80 毫秒)。几乎没有延迟,因为它会定期重新计算和更新行的 x 和 y 数据。

但是,我唯一的抱怨是 x 数据。此代码复制每个 bin 左边缘的 x 值,而不是 return y min/max 对的真实 x 位置。我通常将 bin 的数量设置为轴像素宽度的两倍。所以你真的看不出区别,因为垃圾箱太小了……但我知道它在那里……这让我很烦。

因此请尝试第 2 次,它会 return 每个 min/max 对的实际 x 值。然而,它慢了大约 5 倍。

def min_max_downsample_v2(x,y,num_bins):
    pts_per_bin = x.size // num_bins
    #Create temp to hold the reshaped & slightly cropped y
    y_temp = y[:num_bins*pts_per_bin].reshape((num_bins, pts_per_bin))
    #use argmax/min to get column locations
    cc_max = y_temp.argmax(axis=1)
    cc_min = y_temp.argmin(axis=1)    
    rr = np.arange(0,num_bins)
    #compute the flat index to where these are
    flat_max = cc_max + rr*pts_per_bin
    flat_min = cc_min + rr*pts_per_bin
    #Create a boolean mask of these locations
    mm_mask  = np.full((x.size,), False)
    mm_mask[flat_max] = True
    mm_mask[flat_min] = True  
    x_out = x[mm_mask]    
    y_out = y[mm_mask]  
    return x_out, y_out

这在我的机器上大约需要 400 多毫秒,这变得非常明显。所以我的问题基本上是有没有办法更快地提供相同的结果?瓶颈主要在 numpy.argminnumpy.argmax 函数中,它们比 numpy.minnumpy.max.

慢很多

答案可能是只使用版本 #1,因为它在视觉上并不重要。或者也许尝试加快它的速度,比如 cython(我从未使用过)。

仅供参考,在 Windows 上使用 Python 3.6.4 ...示例用法如下所示:

x_big = np.linspace(0,10,100000000)
y_big = np.cos(x_big )
x_small, y_small = min_max_downsample(x_big ,y_big ,2000) #Fast but not exactly correct.
x_small, y_small = min_max_downsample_v2(x_big ,y_big ,2000) #correct but not exactly fast.

我设法通过直接使用 arg(min|max) 的输出来索引数据数组来提高性能。这是以额外调用 np.sort 为代价的,但是要排序的轴只有两个元素(最小/最大索引)并且整个数组相当小(bin 数):

def min_max_downsample_v3(x, y, num_bins):
    pts_per_bin = x.size // num_bins

    x_view = x[:pts_per_bin*num_bins].reshape(num_bins, pts_per_bin)
    y_view = y[:pts_per_bin*num_bins].reshape(num_bins, pts_per_bin)
    i_min = np.argmin(y_view, axis=1)
    i_max = np.argmax(y_view, axis=1)

    r_index = np.repeat(np.arange(num_bins), 2)
    c_index = np.sort(np.stack((i_min, i_max), axis=1)).ravel()

    return x_view[r_index, c_index], y_view[r_index, c_index]

我检查了你的例子的时间,我得到:

  • min_max_downsample_v1: 110 毫秒 ± 5 毫秒
  • min_max_downsample_v2:240 毫秒±8.01 毫秒
  • min_max_downsample_v3: 164 毫秒 ± 1.23 毫秒

我还检查了在调用 arg(min|max) 之后直接返回,结果同样是 164 毫秒,也就是说,在那之后就没有真正的开销了。

所以这并没有解决加速所讨论的特定功能的问题,但它确实展示了一些在某种程度上有效地绘制具有大量点的线的方法。 这假设 x 点是有序的并且是均匀(或接近均匀)采样的。

设置

from pylab import *

这是我喜欢的一个函数,它通过在每个间隔中随机选择一个点来减少点数。 它不能保证显示数据中的每个峰值,但它没有直接抽取数据那么多问题,而且速度很快。

def calc_rand(y, factor):
    split = y[:len(y)//factor*factor].reshape(-1, factor)
    idx = randint(0, split.shape[-1], split.shape[0])
    return split[arange(split.shape[0]), idx]

这是查看信号包络的最小值和最大值

def calc_env(y, factor):
    """
    y : 1D signal
    factor : amount to reduce y by (actually returns twice this for min and max)
    Calculate envelope (interleaved min and max points) for y
    """
    split = y[:len(y)//factor*factor].reshape(-1, factor)
    upper = split.max(axis=-1)
    lower = split.min(axis=-1)
    return c_[upper, lower].flatten()

以下函数可以采用其中任何一个,并使用它们来减少绘制的数据。 实际取点数默认为5000,应该远远超过显示器的分辨率。 数据在减少后被缓存。 内存可能是个问题,尤其是有大量数据时,但不应超过原始信号所需的数量。

def plot_bigly(x, y, *, ax=None, M=5000, red=calc_env, **kwargs):
    """
    x : the x data
    y : the y data
    ax : axis to plot on
    M : The maximum number of line points to display at any given time
    kwargs : passed to line
    """
    assert x.shape == y.shape, "x and y data must have same shape!"
    if ax is None:
        ax = gca()

    cached = {}

    # Setup line to be drawn beforehand, note this doesn't increment line properties so
    #  style needs to be passed in explicitly
    line = plt.Line2D([],[], **kwargs)
    def update(xmin, xmax):
        """
        Update line data

        precomputes and caches entire line at each level, so initial
        display may be slow but panning and zooming should speed up after that
        """
        # Find nearest power of two as a factor to downsample by
        imin = max(np.searchsorted(x, xmin)-1, 0)
        imax = min(np.searchsorted(x, xmax) + 1, y.shape[0])
        L = imax - imin + 1
        factor = max(2**int(round(np.log(L/M) / np.log(2))), 1)

        # only calculate reduction if it hasn't been cached, do reduction using nearest cached version if possible
        if factor not in cached:
            cached[factor] = red(y, factor=factor)

        ## Make sure lengths match correctly here, by ensuring at least
        #   "factor" points for each x point, then matching y length
        #  this assumes x has uniform sample spacing - but could be modified
        newx = x[imin:imin + ((imax-imin)//factor)* factor:factor]
        start = imin//factor
        newy = cached[factor][start:start + newx.shape[-1]]
        assert newx.shape == newy.shape, "decimation error {}/{}!".format(newx.shape, newy.shape)

        ## Update line data
        line.set_xdata(newx)
        line.set_ydata(newy)

    update(x[0], x[-1])
    ax.add_line(line)
    ## Manually update limits of axis, as adding line doesn't do this
    #   if drawing multiple lines this can quickly slow things down, and some
    #   sort of check should be included to prevent unnecessary changes in limits
    #   when a line is first drawn.
    ax.set_xlim(min(ax.get_xlim()[0], x[0]), max(ax.get_xlim()[1], x[1]))
    ax.set_ylim(min(ax.get_ylim()[0], np.min(y)), max(ax.get_ylim()[1], np.max(y)))

    def callback(*ignore):
        lims = ax.get_xlim()
        update(*lims)

    ax.callbacks.connect('xlim_changed', callback)

    return [line]

这是一些测试代码

L=int(100e6)
x=linspace(0,1,L)
y=0.1*randn(L)+sin(2*pi*18*x)
plot_bigly(x,y, red=calc_env)

在我的机器上,这个显示速度非常快。缩放有一点滞后,尤其是当缩放量很大时。平移没有问题。使用随机选择而不是最小值和最大值要快得多,并且只有在非常高的缩放级别上才会出现问题。

编辑:将 parallel=True 添加到 numba ...甚至更快

我最终混合了单遍 argmin+max 例程和改进的索引,从@a_guest 的答案和 link 到 this related simultaneous min max question

这个版本 returns 每个 min/max y 对的正确 x 值并且由于 numba 实际上比 "fast but not quite correct" 版本快一点。

from numba import jit, prange
@jit(parallel=True)
def min_max_downsample_v4(x, y, num_bins):
    pts_per_bin = x.size // num_bins
    x_view = x[:pts_per_bin*num_bins].reshape(num_bins, pts_per_bin)
    y_view = y[:pts_per_bin*num_bins].reshape(num_bins, pts_per_bin)    
    i_min = np.zeros(num_bins,dtype='int64')
    i_max = np.zeros(num_bins,dtype='int64')

    for r in prange(num_bins):
        min_val = y_view[r,0]
        max_val = y_view[r,0]
        for c in range(pts_per_bin):
            if y_view[r,c] < min_val:
                min_val = y_view[r,c]
                i_min[r] = c
            elif y_view[r,c] > max_val:
                max_val = y_view[r,c]
                i_max[r] = c                
    r_index = np.repeat(np.arange(num_bins), 2)
    c_index = np.sort(np.stack((i_min, i_max), axis=1)).ravel()        
    return x_view[r_index, c_index], y_view[r_index, c_index]

比较使用 timeit 的速度显示 numba 代码大约快 2.6 倍,并且提供比 v1.1 更好的结果。它比连续执行 numpy 的 argmin 和 argmax 快 10 倍多一点。

%timeit min_max_downsample_v1(x_big ,y_big ,2000)
96 ms ± 2.46 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

%timeit min_max_downsample_v2(x_big ,y_big ,2000)
507 ms ± 4.75 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

%timeit min_max_downsample_v3(x_big ,y_big ,2000)
365 ms ± 1.27 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

%timeit min_max_downsample_v4(x_big ,y_big ,2000)
36.2 ms ± 487 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)

您是否尝试过使用 pyqtgraph 进行交互式绘图?它比 matplotlib 响应更快。

我用于下采样的一个技巧是使用 array_split 并计算分割的最小值和最大值。拆分是根据绘图区域每个像素的样本数完成的。