在 NumPy 中有效计算给定向量元素的所有成对乘积

Efficiently computing all pairwise products of a given vector's elements in NumPy

我正在寻找一种 "optimal" 方法来计算给定向量元素的所有成对乘积。如果向量的大小为 N,则输出将是大小为 N * (N + 1) // 2 的向量,并包含所有 (i, j)i <= jx[i] * x[j] 值。计算这个的简单方法如下:

import numpy as np

def get_pairwise_products_naive(vec: np.ndarray):
    k, size = 0, vec.size
    output = np.empty(size * (size + 1) // 2)
    for i in range(size):
        for j in range(i, size):
            output[k] = vec[i] * vec[j]
            k += 1
    return output

需要的东西:

我一直在玩 outertriu_indiceseinsum 等套路以及一些 indexing/view 技巧,但一直找不到合适的符合上述要求的解决方案。

我可能会计算 M = vTv 然后展平这个矩阵的下三角部分或高三角部分。

def pairwise_products(v: np.ndarray):
    assert len(v.shape) == 1
    n = v.shape[0]
    m = v.reshape(n, 1) @ v.reshape(1, n)
    return m[np.tril_indices_from(m)].ravel()

我还想提一下 numba,这将使您的 'naive' 方法很可能比这个方法更快。

import numba

@numba.njit
def pairwise_products_numba(vec: np.ndarray):
    k, size = 0, vec.size
    output = np.empty(size * (size + 1) // 2)
    for i in range(size):
        for j in range(i, size):
            output[k] = vec[i] * vec[j]
            k += 1
    return output

仅测试上述 pairwise_products(np.arange(5000)) 需要约 0.3 秒,而 numba 版本需要约 0.05 秒(忽略用于即时编译函数的第一个 运行)。

方法 #1

对于使用 NumPy 的矢量化,您可以在使用外乘法获得所有成对乘法后使用掩码,就像这样 -

def pairwise_multiply_masking(a):
    return (a[:,None]*a)[~np.tri(len(a),k=-1,dtype=bool)]

方法 #2

对于非常大的输入一维数组,我们可能希望求助于使用单循环的迭代 slicing 方法 -

def pairwise_multiply_iterative_slicing(a):
    n = len(a)
    N = (n*(n+1))//2
    out = np.empty(N, dtype=a.dtype)
    c = np.r_[0,np.arange(n,0,-1)].cumsum()
    for ii,(i,j) in enumerate(zip(c[:-1],c[1:])):
        out[i:j] = a[ii:]*a[ii]
    return out

基准测试

我们将在设置中包含

使用 benchit 包(几个基准测试工具打包在一起;免责声明:我是它的作者)对提议的解决方案进行基准测试。

import benchit
funcs = [pairwise_multiply_masking, pairwise_multiply_iterative_slicing, pairwise_products_numba, pairwise_products]
in_ = [np.random.rand(n) for n in [10,50,100,200,500,1000,5000]]
t = benchit.timings(funcs, in_)
t.plot(logx=True, save='timings.png')
t.speedups(-1).plot(logx=True, logy=False, save='speedups.png')

结果(超过 pairwise_products 的时间和加速)-

从绘图趋势可以看出,对于非常大的数组,基于切片的数组将开始获胜,否则向量化的数组会做得很好。

建议

  • 我们还可以研究 numexpr 以更有效地执行大型数组的外部乘法。

您也可以将此算法并行化。如果有可能只分配一次足够大的数组(此数组上的较小视图几乎不需要任何成本)并在之后覆盖它,则可以实现更大的加速。

例子

@numba.njit(parallel=True)
def pairwise_products_numba_2_with_allocation(vec):
    k, size = 0, vec.size
    k_vec=np.empty(vec.size,dtype=np.int64)
    output = np.empty(size * (size + 1) // 2)

    #precalculate the indices
    for i in range(size):
        k_vec[i] = k
        k+=(size-i)

    for i in numba.prange(size):
        k=k_vec[i]
        for j in range(size-i):
            output[k+j] = vec[i] * vec[j+i]

    return output

@numba.njit(parallel=True)
def pairwise_products_numba_2_without_allocation(vec,output):
    k, size = 0, vec.size
    k_vec=np.empty(vec.size,dtype=np.int64)

    #precalculate the indices
    for i in range(size):
        k_vec[i] = k
        k+=(size-i)

    for i in numba.prange(size):
        k=k_vec[i]
        for j in range(size-i):
            output[k+j] = vec[i] * vec[j+i]

    return output

计时

A=np.arange(5000)
k, size = 0, A.size
output = np.empty(size * (size + 1) // 2)

%timeit res_1=pairwise_products_numba_2_without_allocation(A,output)
#7.84 ms ± 116 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit res_2=pairwise_products_numba_2_with_allocation(A)
#16.9 ms ± 325 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit res_3=pairwise_products_numba(A) #@orlp
#43.3 ms ± 134 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)