广播 + 重塑 VS 列表理解速度 python

Broadcasting + reshaping VS list comprehension speed in python

给定以下 numpy 个数组:

import numpy as np
import time

v1 = np.linspace(20, 250, 100000000)
a = np.array([12.592,16.320])
m = np.array([3, 5])

列表理解怎么可能:

start = time.time()
v2 = np.max(
    [10 ** _a * np.power(v1.astype(float), -_m) for _m, _a in zip(m, a)],
    axis=0
)
end = time.time()
print(end - start)  # prints 5.822041034698486

numpy 快两倍多 broadcasting?

start = time.time()
v2 = np.max(
    np.power(10, a) * np.power(v1.astype(float)[:, None], -m).reshape(
        v1.shape[0],-1),
    axis=1
)
end = time.time()
print(end - start)  # prints 12.292157173156738

计算 v2 的“最快方法”是什么?

两者都是低效的。实际上,您可以使用自然对数和指数的 数学 属性 更有效地重写操作。这是一个可以进一步优化的等效(未优化)表达式:

v2 = np.max(
    [np.exp(_a * np.log(10) - _m * np.log(v1.astype(float))) for _m, _a in zip(m, a)],
    axis=0
)

因为np.log(v1.astype(float)))是常数,所以可以pre-compute。实际上,v1 项已经是 float 类型(请注意,float 意味着 Numpy 的 np.float64 而不是 CPython float 对象类型)。 np.log(v1) 将正确完成工作(除非 v1 设置为 np.float32 的数组)。此外,np.exp 只能根据最终结果计算(即在计算 np.max 之后),因为 a < b 等同于 e**a < e**b。最后,您可以使用 in-place 操作来避免创建多个昂贵的 临时数组 。这可以使用许多函数的 out 参数来完成,例如 np.subtractnp.multiply

这是结果代码:

log_v1 = np.log(v1)
tmp = np.empty((len(a), v1.size), dtype=np.float64)
v2 = np.exp(np.max(
    [np.subtract(_a * np.log(10), np.multiply(log_v1, _m, out=_tmp), out=_tmp) for _m, _a, _tmp in zip(m, a, tmp)],
    axis=0, 
    out=tmp[0]
), out=tmp[0])

为了获得更快的性能,您可以简单地使用 Numba 以避免在内存中写入巨大的数组(这很慢)。 Numba 还可以使用 多线程 来大大加快速度。这是一个例子:

import numba as nb

# Note that fastmath=True can cause issues with values likes NaN Inf.
# Please disable it if your input array contains such spacial values.
@nb.njit('float64[::1](float64[::1], float64[::1], float64[::1])', fastmath=True, parallel=True)
def compute(v1, m, a):
    assert a.size > 0 and a.size == m.size
    out = np.empty(v1.size, dtype=np.float64)
    log10 = np.log(10)
    for i in nb.prange(v1.size):
        log_v1 = np.log(v1[i])
        maxi = a[0] * log10 - m[0] * log_v1
        for j in range(1, len(a)):
            value = a[j] * log10 - m[j] * log_v1
            if value > maxi:
                maxi = value
        out[i] = np.exp(maxi)
    return out

v2 = compute(v1, m.astype(float), a)

基准

这是我的 6 核机器上的结果:

Initial code with big arrays:    10.013 s    (inefficient)
Initial code with lists:          8.147 s    (inefficient)
Optimized Numpy code:             2.009 s    (memory-bound)
Optimized Numba code:             0.300 s    (compute-bound)

如您所见,Numba 比其他方法快得多:大约 30 倍