Efficient computation of moving linear regression with numpy/numba

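I am trying to create a moving linear regression indicator and I want to make use of numba. However, due to my lack of experience, I am struggling with the second part.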

This is what I have so far using numpy. It works, but without numba it becomes very slow once you throw large arrays at it.

import numpy as np


def ols_1d(y, window):
    y_roll = np.lib.stride_tricks.sliding_window_view(y, window_shape=window)

    m = list()
    c = list()
    for row in np.arange(y_roll.shape[0]):
        A = np.vstack([np.arange(1, window + 1), np.ones(window)]).T
        tmp_m, tmp_c = np.linalg.lstsq(A, y_roll[row], rcond=None)[0]
        m.append(tmp_m)
        c.append(tmp_c)

    m, c = np.array([m, c])

    return np.hstack((np.full((window - 1), np.nan), m * window + c))


def ols_2d(y, window):
    out = list()
    for col in range(y.shape[1]):
        out.append(ols_1d(y=y[:, col], window=window))

    return np.array(out).T


if __name__ == "__main__":
    a = np.random.randn(
        10000, 10
    )  # function is slow once you really increase the number of columns (let's say 1 million)

    print(ols_2d(a, 10))

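The indicator essentially applies np.linalg.lstsq (https://numpy.org/doc/stable/reference/generated/numpy.linalg.lstsq.html) to compute a linear regression over a given window length. It outputs the last point of the regression line, moves on to the next window, and computes the linear regression again. In the end, ols_1d returns the last point of every regression line in a single array.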

Now I need help applying numba to this. I'm unfamiliar with numba, but as far as my own trial and error goes, there might be a problem with using np.lib.stride_tricks.sliding_window_view().

EDIT: Follow-up question

Following @aerobiomat's suggestion, I only had to modify np.cumsum slightly so that it handles matrices instead of vectors.

def window_sum(x, w):
    c = np.cumsum(x, axis=0)  # inserted axis argument
    s = c[w - 1:]
    s[1:] -= c[:-w]
    return s

def window_lin_reg(x, y, w):
    sx = window_sum(x, w)
    sy = window_sum(y, w)
    sx2 = window_sum(x**2, w)
    sxy = window_sum(x * y, w)
    slope = (w * sxy - sx * sy) / (w * sx2 - sx**2)
    intercept = (sy - slope * sx) / w
    return slope, intercept

Now, let's create a 20x3 matrix where the columns represent individual time series.

timeseries_count = 3
x = np.arange(start=1,stop=21).reshape(-1, 1)
y = np.random.randn(20, timeseries_count)

slope, intercept = window_lin_reg(x, y, 10)

This works fine. However, once I introduce some np.nan values, I run into problems.

y[0, 0] = np.nan
y[5, 0] = np.nan
y[10, 2] = np.nan

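To compute the rolling regression, I would need to remove all np.nan values from each column and process the columns one at a time. That doesn't lend itself to vectorization, does it? Each column in the real dataset may contain a different number of np.nan values. How can this be handled elegantly? I need speed, as the dataset can be quite large (around 10000 x 10000).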

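Instead of computing a separate linear regression for each window, which repeats many intermediate calculations, you can compute the sums required by the formula for all windows at once and then perform a single vectorized calculation for all the regressions: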

def window_sum(x, w):
    # Faster than np.lib.stride_tricks.sliding_window_view(x, w).sum(axis=1)
    c = np.cumsum(x)
    s = c[w - 1:]
    s[1:] -= c[:-w]
    return s

def window_lin_reg(x, y, w):
    sx = window_sum(x, w)
    sy = window_sum(y, w)
    sx2 = window_sum(x**2, w)
    sxy = window_sum(x * y, w)
    slope = (w * sxy - sx * sy) / (w * sx2 - sx**2)
    intercept = (sy - slope * sx) / w
    return slope, intercept
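For reference, the sums computed in window_lin_reg plug into the standard closed-form least-squares solution for a straight line (presumably the formula linked above). For one window of w points, with every sum taken over that window:

```latex
\[
\text{slope} = \frac{w \sum x_i y_i - \sum x_i \sum y_i}{w \sum x_i^2 - \left(\sum x_i\right)^2},
\qquad
\text{intercept} = \frac{\sum y_i - \text{slope} \cdot \sum x_i}{w}
\]
```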

For example:

>>> w = 5      # Window
>>> x = np.arange(15)
>>> y = 0.1 * x**2 - x + 10
>>> slope, intercept = window_lin_reg(x, y, w)
>>> print(slope)
[-0.6 -0.4 -0.2  0.   0.2  0.4  0.6  0.8  1.   1.2  1.4]
>>> print(intercept)
[ 9.8  9.3  8.6  7.7  6.6  5.3  3.8  2.1  0.2 -1.9 -4.2]
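The cumulative-sum trick inside window_sum can be checked on a tiny input. The sketch below is only an illustration and not part of the original answer: it copies the slice before subtracting (a defensive choice, so the cumulative sum itself is left untouched) and compares the result with summing sliding_window_view over its last axis.

```python
import numpy as np


def window_sum(x, w):
    # Running total: c[i] = x[0] + ... + x[i]
    c = np.cumsum(x)
    # Sum of the window ending at i is c[i] - c[i - w]; the first window is just c[w - 1]
    s = c[w - 1:].copy()
    s[1:] -= c[:-w]
    return s


x = np.arange(1, 7)          # 1, 2, 3, 4, 5, 6
rolling = window_sum(x, 3)
print(rolling.tolist())      # [6, 9, 12, 15]

# Reference: materialize every window and sum each one
reference = np.lib.stride_tricks.sliding_window_view(x, 3).sum(axis=1)
print(np.array_equal(rolling, reference))
```

The reference version does w additions per window, while the cumulative-sum version does a constant amount of work per window, which is where the speed-up comes from.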

For comparison, a version using np.linalg.lstsq() and a loop:

def numpy_lin_reg(x, y, w):
    n_windows = len(x) - w + 1
    slope = np.empty(n_windows)
    intercept = np.empty(n_windows)
    for i in range(n_windows):
        # Design matrix [x, 1] for the window starting at i
        A = np.vstack(((x[i:i + w]), np.ones(w))).T
        m, c = np.linalg.lstsq(A, y[i:i + w], rcond=None)[0]
        slope[i] = m
        intercept[i] = c
    return slope, intercept

Same example, same results:

>>> slope2, intercept2 = numpy_lin_reg(x, y, w)
>>> with np.printoptions(precision=2, suppress=True):
...     print(np.array(slope2))
...     print(np.array(intercept2))
[-0.6 -0.4 -0.2  0.   0.2  0.4  0.6  0.8  1.   1.2  1.4]
[ 9.8  9.3  8.6  7.7  6.6  5.3  3.8  2.1  0.2 -1.9 -4.2]
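The question's indicator is the last point of each regression line, padded with NaN at the start. A minimal sketch of how that could be derived from the slope and intercept follows. The name rolling_regression_endpoint is hypothetical, and it assumes x = 1, 2, ..., len(y) as a global time axis; the fitted value at the end of a window does not depend on where x starts, so this should match m * window + c from ols_1d.

```python
import numpy as np


def window_sum(x, w):
    # Rolling sum over w consecutive elements via a cumulative sum
    c = np.cumsum(x)
    s = c[w - 1:].copy()
    s[1:] -= c[:-w]
    return s


def window_lin_reg(x, y, w):
    sx = window_sum(x, w)
    sy = window_sum(y, w)
    sx2 = window_sum(x**2, w)
    sxy = window_sum(x * y, w)
    slope = (w * sxy - sx * sy) / (w * sx2 - sx**2)
    intercept = (sy - slope * sx) / w
    return slope, intercept


def rolling_regression_endpoint(y, w):
    """Hypothetical helper: last fitted value of each window, NaN-padded."""
    y = np.asarray(y, dtype=float)
    x = np.arange(1, len(y) + 1, dtype=float)   # assumed time axis
    slope, intercept = window_lin_reg(x, y, w)
    # Evaluate each regression line at the last x of its window
    endpoint = slope * x[w - 1:] + intercept
    # The first w - 1 positions have no complete window
    return np.concatenate((np.full(w - 1, np.nan), endpoint))
```

The output has the same length as y, so it can be aligned directly with the input series.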

Some timings with a larger example:

>>> w = 20
>>> x = np.arange(10000)
>>> y = 0.1 * x**2 - x + 10

>>> %timeit numpy_lin_reg(x, y, w)
324 ms ± 11.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
>>> %timeit window_lin_reg(x, y, w)
189 µs ± 3.8 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

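That's a performance gain of three orders of magnitude. The functions are already vectorized, so there's little left for Numba to do. It's "only" twice as fast when the functions are decorated with @nb.njit:

>>> %timeit window_lin_reg(x, y, w)
96.4 µs ± 350 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)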

I'm adding this as an answer to @Andi's follow-up question, in which the input data may contain nans.

These are the updated functions:

def window_sum(x, w):
    c = np.nancumsum(x)     # Nans no longer affect the summation
    s = c[w - 1:]
    s[1:] -= c[:-w]
    return s

def window_lin_reg(x, y, w):
    # Invalidate both x and y values when there's a nan in one of them.
    # Note: this modifies x and y in place, so both must be float arrays.
    valid = np.isfinite(x) & np.isfinite(y)
    x[~valid] = np.nan
    y[~valid] = np.nan

    # Sums for each window
    n = window_sum(valid, w)    # Count only valid points in the window
    sx = window_sum(x, w)
    sy = window_sum(y, w)
    sx2 = window_sum(x ** 2, w)
    sxy = window_sum(x * y, w)

    # Avoid ugly warnings
    with np.errstate(divide='ignore', invalid='ignore'):
        slope = (n * sxy - sx * sy) / (n * sx2 - sx ** 2)
        intercept = (sy - slope * sx) / n

    # Replace infinities by nans. Not necessary, but cleaner.
    invalid_results = n < 2
    slope[invalid_results] = np.nan
    intercept[invalid_results] = np.nan

    return slope, intercept

Test with nans:

>>> w = 5      # Window
>>> x = np.arange(15.)
>>> y = 0.1 * x**2 - x + 10
>>> x[3] = np.nan
>>> y[7:12] = np.nan
>>> slope, intercept = window_lin_reg(x, y, w)
>>> with np.printoptions(precision=2, suppress=True):
...     print(slope)
...     print(intercept)
[-0.59 -0.4  -0.21 -0.   -0.    0.1    nan   nan   nan  1.5   1.6 ]
[ 9.8   9.35  8.69  7.57  7.57  7.     nan   nan   nan -5.6  -6.83]

A version using np.linalg.lstsq and a loop, only to compare the results:

def numpy_lin_reg(x, y, w):
    valid = np.isfinite(x) & np.isfinite(y)
    x[~valid] = np.nan
    y[~valid] = np.nan
    n_windows = len(x) - w + 1
    slope = np.empty(n_windows)
    intercept = np.empty(n_windows)
    for i in range(n_windows):
        window = slice(i, i+w)
        valid_ = valid[window]
        x_ = x[window][valid_]
        y_ = y[window][valid_]
        n = valid_.sum()
        if n < 2:
            slope[i] = np.nan
            intercept[i] = np.nan
        else:
            A = np.vstack((x_, np.ones(n))).T
            m, c = np.linalg.lstsq(A, y_, rcond=None)[0]
            slope[i] = m
            intercept[i] = c
    return slope, intercept

Test with the same data:

>>> slope2, intercept2 = numpy_lin_reg(x, y, w)
>>> with np.printoptions(precision=2, suppress=True):
...     print(np.array(slope2))
...     print(np.array(intercept2))
[-0.59 -0.4  -0.21  0.    0.    0.1    nan   nan   nan  1.5   1.6 ]
[ 9.8   9.35  8.69  7.57  7.57  7.     nan   nan   nan -5.6  -6.83]

To use Numba with the new implementation, a few changes are needed:

import numba as nb
import numpy as np

@nb.njit
def window_sum(x, w):
    c = np.nancumsum(x)         # Numba needs SciPy here
    s = c[w - 1:]
    s[1:] = s[1:] - c[:-w]      # Numba doesn't like -=
    return s
    return s

@nb.njit
def window_lin_reg(x, y, w):
    # Invalidate both x and y values when there's a nan in one of them
    valid = np.isfinite(x) & np.isfinite(y)
    x[~valid] = np.nan
    y[~valid] = np.nan

    # Sums for each window
    n = window_sum(valid, w)
    sx = window_sum(x, w)
    sy = window_sum(y, w)
    sx2 = window_sum(x ** 2, w)
    sxy = window_sum(x * y, w)

    # No warnings here from Numba
    slope = (n * sxy - sx * sy) / (n * sx2 - sx ** 2)
    intercept = (sy - slope * sx) / n

    # Replace infinities by nans. Not necessary, but cleaner.
    invalid_results = n < 2
    slope[invalid_results] = np.nan
    intercept[invalid_results] = np.nan

    return slope, intercept
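The follow-up question asked about a 2-D array whose columns contain different numbers of nans. The functions above work on 1-D input, so here is a rough plain-NumPy sketch of a column-wise variant, not a tested drop-in. The names window_sum_2d and window_lin_reg_2d are hypothetical. It assumes x is a single time axis shared by all columns and broadcasts it to the shape of y, so that each column can invalidate its own x values.

```python
import numpy as np


def window_sum_2d(a, w):
    # Rolling sum along axis 0 (time); nans are counted as zero
    c = np.nancumsum(a, axis=0)
    s = c[w - 1:].copy()
    s[1:] -= c[:-w]
    return s


def window_lin_reg_2d(x, y, w):
    """Hypothetical column-wise variant: x has n elements, y has shape (n, k)."""
    y = np.array(y, dtype=float)    # work on a copy; the caller's data is untouched
    # Give every column its own copy of x so nans can be masked per column
    x = np.broadcast_to(np.asarray(x, dtype=float).reshape(-1, 1), y.shape).copy()

    valid = np.isfinite(x) & np.isfinite(y)
    x[~valid] = np.nan
    y[~valid] = np.nan

    n = window_sum_2d(valid, w)     # valid points per window and column
    sx = window_sum_2d(x, w)
    sy = window_sum_2d(y, w)
    sx2 = window_sum_2d(x ** 2, w)
    sxy = window_sum_2d(x * y, w)

    with np.errstate(divide="ignore", invalid="ignore"):
        slope = (n * sxy - sx * sy) / (n * sx2 - sx ** 2)
        intercept = (sy - slope * sx) / n

    # Fewer than two valid points cannot define a line
    too_few = n < 2
    slope[too_few] = np.nan
    intercept[too_few] = np.nan
    return slope, intercept
```

This keeps everything vectorized at the cost of memory: x is expanded to the full shape of y, and each intermediate sum is another array of nearly that size, which matters for inputs around 10000 x 10000. Because each window sum is a difference of two running totals, precision can also degrade on very long series or very large x values.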