In which cases is numpy's out parameter faster?

I'm not sure whether the out parameter is worth the hassle for my project, so I'm running a series of tests.

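But so far, in every case I've tested, the out parameter has made no real difference compared with the simpler implementation, and where it differs at all it seems to make things slightly slower. I don't understand why.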

For example, test1 and test11 are supposed to be equivalent, but the latter uses the out parameter to avoid allocating a new array on every call.

The same goes for test2 and test22.

import numpy as np
from timeit import timeit

N = 1000  # each N x N float32 array is about 4 MB

weights = np.arange(N * N).astype('f4').reshape((N, N))

inputs = np.arange(N).astype('f4')

# Preallocated output buffers, allocated once and reused by the out= variants
cache_out1 = np.empty((N, N), dtype='f4')
cache_out2 = np.empty(N, dtype='f4')

def test1():
    return (weights * inputs).sum(axis=1)

def test11():
    np.multiply(weights, inputs, out=cache_out1)
    np.sum(cache_out1, axis=1, out=cache_out2)
    return cache_out2

def test2():
    return (weights * inputs[:, np.newaxis]).sum(axis=0)

def test22():
    np.multiply(weights, inputs[:, np.newaxis], out=cache_out1)
    np.sum(cache_out1, axis=0, out=cache_out2)
    return cache_out2


print('test1:', timeit(test1,  number=1000))
print('test11:', timeit(test11, number=1000))
print('test2:', timeit(test2,  number=1000))
print('test22:', timeit(test22, number=1000))

Output:

test1: 1.1015455439919606
test11: 1.0834621820104076
test2: 1.1083468289871234
test22: 1.1045935050060507
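One way to confirm that the out versions really do reuse the preallocated buffers (rather than silently allocating new ones) is to check object identity: a ufunc such as np.multiply called with out= returns the very array it wrote into, and np.sum with out= likewise returns a reference to out. A minimal sketch, using a tiny N purely for illustration:

```python
import numpy as np

N = 4  # tiny size, only for the demonstration

weights = np.arange(N * N).astype('f4').reshape((N, N))
inputs = np.arange(N).astype('f4')

cache_out1 = np.empty((N, N), dtype='f4')
cache_out2 = np.empty(N, dtype='f4')

# With out=, the ufunc writes into cache_out1 and returns that same object.
product = np.multiply(weights, inputs, out=cache_out1)
# np.sum with out= also returns a reference to the buffer it filled.
row_sums = np.sum(cache_out1, axis=1, out=cache_out2)

print(product is cache_out1)   # True: no new N x N array was created
print(row_sums is cache_out2)  # True: no new length-N array was created

# Without out=, every call produces a brand-new result array.
fresh = (weights * inputs).sum(axis=1)
print(fresh is cache_out2)               # False
print(np.array_equal(fresh, row_sums))   # True: same values either way
```

A side effect worth keeping in mind: because test11 returns cache_out2 itself, each call overwrites the result returned by the previous call.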

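It makes a bigger difference when your arrays are larger, because allocation then takes longer. Using the out parameter lets you amortize the allocation cost by reusing that memory. For example: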

import numpy as np
from timeit import timeit

N = 4096  # each N x N float32 array is now 64 MiB

weights = np.arange(N * N).astype('f4').reshape((N, N))

inputs = np.arange(N).astype('f4')

cache_out1 = np.empty((N, N), dtype='f4')
cache_out2 = np.empty(N, dtype='f4')

def test1():
    return (weights * inputs).sum(axis=1)

def test11():
    np.multiply(weights, inputs, out=cache_out1)
    np.sum(cache_out1, axis=1, out=cache_out2)
    return cache_out2

def test2():
    return (weights * inputs[:, np.newaxis]).sum(axis=0)

def test22():
    np.multiply(weights, inputs[:, np.newaxis], out=cache_out1)
    np.sum(cache_out1, axis=0, out=cache_out2)
    return cache_out2


n = 100
print('test1:', timeit(test1,  number=n))
print('test11:', timeit(test11, number=n))
print('test2:', timeit(test2,  number=n))
print('test22:', timeit(test22, number=n))

Output:

test1: 2.5047981239913497
test11: 1.7144565229973523
test2: 2.4683585959865013
test22: 1.6845238849928137
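To see where the crossover happens on a particular machine, the same comparison can be swept over several array sizes. The helper compare_alloc_vs_out below is hypothetical (it just wraps the test1/test11 pair from above for a given size); absolute timings depend on the hardware, the memory allocator and the NumPy version, so no numbers are claimed here.

```python
import numpy as np
from timeit import timeit


def compare_alloc_vs_out(n, number=10):
    """Hypothetical helper: time the allocating version against the out= version.

    Returns (seconds_allocating, seconds_with_out) for `number` calls each,
    for an n x n float32 problem.
    """
    weights = np.arange(n * n).astype('f4').reshape((n, n))
    inputs = np.arange(n).astype('f4')
    # Buffers are allocated once, outside the timed functions, and reused.
    cache_out1 = np.empty((n, n), dtype='f4')
    cache_out2 = np.empty(n, dtype='f4')

    def allocating():
        # Allocates a fresh n x n temporary and a fresh length-n result per call.
        return (weights * inputs).sum(axis=1)

    def with_out():
        # Writes into the preallocated buffers instead of allocating.
        np.multiply(weights, inputs, out=cache_out1)
        np.sum(cache_out1, axis=1, out=cache_out2)
        return cache_out2

    # Sanity check: both versions compute the same row sums.
    assert np.allclose(allocating(), with_out())
    return timeit(allocating, number=number), timeit(with_out, number=number)


if __name__ == '__main__':
    for n in (256, 1024, 4096):
        t_alloc, t_out = compare_alloc_vs_out(n)
        print(f'N={n}: allocating {t_alloc:.4f}s, out= {t_out:.4f}s, '
              f'ratio {t_alloc / t_out:.2f}')
```

If the answer above holds on your machine, the ratio should grow as N grows, since each allocation (and the first touch of the fresh memory) gets more expensive.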

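If you really want to squeeze out the last nanoseconds, look into Numba or Cython.

Especially if your operations can be parallelized, you can get very significant speedups.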
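A minimal sketch of what the Numba route might look like for the test1 computation, assuming Numba is installed. The kernel name rowwise_dot is hypothetical: it fuses the multiply and the sum into a single loop, so no N x N temporary is needed at all, and numba.prange together with njit(parallel=True) lets the independent rows be split across threads. The ImportError fallback is an assumption added only so the sketch still runs without Numba; in that case the loop executes as plain Python and is far slower, so it is only useful for checking correctness.

```python
import numpy as np

try:
    from numba import njit, prange
except ImportError:
    # Assumption: fallback only so this sketch runs without Numba.
    # The loop then runs as plain Python, i.e. much slower, not faster.
    prange = range

    def njit(*args, **kwargs):
        def decorator(func):
            return func
        return decorator


@njit(parallel=True)
def rowwise_dot(weights, inputs, out):
    """Hypothetical kernel: out[i] = sum_j weights[i, j] * inputs[j].

    Same result as (weights * inputs).sum(axis=1), but computed in one pass
    without an N x N temporary. Each row is independent, which is what makes
    the outer loop safe to run in parallel with prange.
    """
    n_rows, n_cols = weights.shape
    for i in prange(n_rows):
        acc = 0.0  # per-row accumulator, private to each iteration
        for j in range(n_cols):
            # Accumulate in float64 to limit rounding error.
            acc += float(weights[i, j]) * float(inputs[j])
        out[i] = acc  # written back into the float32 output buffer
    return out


if __name__ == '__main__':
    N = 1000
    weights = np.arange(N * N).astype('f4').reshape((N, N))
    inputs = np.arange(N).astype('f4')
    cache_out2 = np.empty(N, dtype='f4')

    result = rowwise_dot(weights, inputs, cache_out2)
    expected = (weights * inputs).sum(axis=1)
    print(np.allclose(result, expected))
```

With Numba, the first call includes JIT compilation time, so any timing should be done on subsequent calls.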