与 Fortran 或 C 相比,numpy.einsum 是否高效?
Is numpy.einsum efficient compared to fortran or C?
我写了一个非常耗时的numpy程序。剖析后发现大部分时间都花在了numpy.einsum
上。
虽然 numpy 是 LAPACK 或 BLAS 的包装器,但我不知道 numpy.einsum
的性能是否与 LAPACK 或 BLAS 中的对应物相当。
那么,如果我切换到 Fortran 或 C,我的性能会得到很大提升吗?
Numpy 仅对 BLAS 指定的原始操作使用 BLAS 包装。这包括 dot
、innerproduct
、vdot
、matmul
(1.10 中的新功能)以及依赖于它的函数(tensordot
等)。 einsum
,另一方面,只调用 BLAS 进行允许回退到它的操作(从 Numpy 1.14.0 开始)。
如果您的问题可以分解为多个 BLAS 操作,那么我建议您首先在 Numpy 本身中尝试。它可能需要中间的一些临时数组(即使您要编写使用 BLAS 的 C/FORTRAN 也是如此)。您可以使用函数的 out=
参数消除某些数组创建开销。
但大多数时候,您使用的是 einsum
,因为它无法在 BLAS 中表达。看一个简单的例子:
a = np.arange(60.).reshape(3,4,5)
b = np.arange(24.).reshape(4,3,2)
c = np.einsum('ijk,jil->kl', a, b)
要用原始操作来表达上面的内容,你需要交换b
中的前两个轴,对前两个维度进行元素乘法,然后将它们相加,对于每个索引k
和 l
.
c2 = np.ndarray((5, 2))
b2 = np.swapaxes(b, 0, 1)
def manualeinsum(c2, a, b):
ny, nx = c2.shape
for k in range(ny):
for l in range(nx):
c2[k, l] = np.sum(a[..., k]*b2[...,l])
manualeinsum(c2, a, b2)
你不能 BLAS 那个。 更新:上面的问题可以表示为矩阵乘法,可以使用 BLAS 加速。请参阅@ali_m 的评论。对于足够大的阵列,BLAS 方法更快。
同时,请注意 einsum
本身是用 C 编写的,为给定的索引创建了一个特定于维度的迭代器,并且还针对 SSE 进行了优化。
我写了一个非常耗时的numpy程序。剖析后发现大部分时间都花在了numpy.einsum
上。
虽然 numpy 是 LAPACK 或 BLAS 的包装器,但我不知道 numpy.einsum
的性能是否与 LAPACK 或 BLAS 中的对应物相当。
那么,如果我切换到 Fortran 或 C,我的性能会得到很大提升吗?
Numpy 仅对 BLAS 指定的原始操作使用 BLAS 包装。这包括 dot
、innerproduct
、vdot
、matmul
(1.10 中的新功能)以及依赖于它的函数(tensordot
等)。 einsum
,另一方面,只调用 BLAS 进行允许回退到它的操作(从 Numpy 1.14.0 开始)。
如果您的问题可以分解为多个 BLAS 操作,那么我建议您首先在 Numpy 本身中尝试。它可能需要中间的一些临时数组(即使您要编写使用 BLAS 的 C/FORTRAN 也是如此)。您可以使用函数的 out=
参数消除某些数组创建开销。
但大多数时候,您使用的是 einsum
,因为它无法在 BLAS 中表达。看一个简单的例子:
a = np.arange(60.).reshape(3,4,5)
b = np.arange(24.).reshape(4,3,2)
c = np.einsum('ijk,jil->kl', a, b)
要用原始操作来表达上面的内容,你需要交换b
中的前两个轴,对前两个维度进行元素乘法,然后将它们相加,对于每个索引k
和 l
.
c2 = np.ndarray((5, 2))
b2 = np.swapaxes(b, 0, 1)
def manualeinsum(c2, a, b):
ny, nx = c2.shape
for k in range(ny):
for l in range(nx):
c2[k, l] = np.sum(a[..., k]*b2[...,l])
manualeinsum(c2, a, b2)
你不能 BLAS 那个。 更新:上面的问题可以表示为矩阵乘法,可以使用 BLAS 加速。请参阅@ali_m 的评论。对于足够大的阵列,BLAS 方法更快。
同时,请注意 einsum
本身是用 C 编写的,为给定的索引创建了一个特定于维度的迭代器,并且还针对 SSE 进行了优化。