Python 中的连续数组警告 (numba)

Contiguous array warning in Python (numba)

我有以下代码片段,其中我使用了 numba 来加速我的代码:

import numpy as np
from numba import jit

Sigma = np.array([
                  [1, 1, 0.5, 0.25],
                  [1, 2.5, 1, 0.5],
                  [0.5, 1, 0.5, 0.25],
                  [0.25, 0.5, 0.25, 0.25]
])
Z = np.array([0.111, 0.00658])

@jit(nopython=True)
def mean(Sigma, Z):
  return np.dot(np.dot(Sigma[0:2, 2:4], linalg.inv(Sigma[2:4, 2:4])), Z)

print(mean(Sigma, Z))

但是,numba 正在抱怨

NumbaPerformanceWarning: np.dot() is faster on contiguous arrays, called on (array(float64, 2d, A), array(float64, 2d, F))
  return np.dot(np.dot(Sigma[0:2, 2:4], linalg.inv(Sigma[2:4, 2:4])), Z)

如果我没记错的话(在阅读 this 之后),由于 Sigma 的子矩阵切片(即“Sigma[0:2, 2:4]").这个对吗?如果是这样,有什么办法可以解决此警告?我相信解决此警告将有助于加快我的代码速度,这是我的主要目标。谢谢

您收到此错误是因为 dotinv 针对连续数组进行了优化。然而,对于小输入尺寸,这不是一个大问题。不过,您至少可以使用装饰器 @jit(...) 中的签名 'float64[:](float64[:,::1], float64[::1])' 指定输入数组是连续的。这也导致函数被急于编译。

此函数中最大的性能问题是创建了很少的临时数组和对 linalg.inv 的调用,这并不是为非常小的矩阵而设计的。可以通过基于行列式计算一个简单的表达式来获得逆矩阵。

这是结果代码:

import numba as nb

@nb.njit('float64[:](float64[:,::1], float64[::1])')
def fast_mean(Sigma, Z):
    # Compute the inverse matrix
    mat_a = Sigma[2, 2]
    mat_b = Sigma[2, 3]
    mat_c = Sigma[3, 2]
    mat_d = Sigma[3, 3]
    invDet = 1.0 / (mat_a*mat_d - mat_b*mat_c)
    inv_a = invDet * mat_d
    inv_b = -invDet * mat_b
    inv_c = -invDet * mat_c
    inv_d = invDet * mat_a

    # Compute the matrix multiplication
    mat_a = Sigma[0, 2]
    mat_b = Sigma[0, 3]
    mat_c = Sigma[1, 2]
    mat_d = Sigma[1, 3]
    tmp_a = mat_a*inv_a + mat_b*inv_c
    tmp_b = mat_a*inv_b + mat_b*inv_d
    tmp_c = mat_c*inv_a + mat_d*inv_c
    tmp_d = mat_c*inv_b + mat_d*inv_d

    # Final dot product
    z0, z1 = Z
    result = np.empty(2, dtype=np.float64)
    result[0] = tmp_a*z0 + tmp_b*z1
    result[1] = tmp_c*z0 + tmp_d*z1
    return result

这比我的机器快 3 倍。请注意,>60% 的时间花费在调用 Numba 函数和创建输出临时数组的开销上。因此,在调用函数中使用 Numba 可能是明智的,这样可以消除这种开销。

您可以将 result 数组作为参数传递,以避免创建@max9111 指出的非常昂贵的数组。仅当您可以在调用函数中预分配输出缓冲区时(如果可能,一次),这才有用。这几乎快 6 倍