如何使嵌套的 numpy 'where' 表现得像嵌套的 if 语句?

How to make nested numpy 'where' behave like nested if statements?

要计算以下代码中的belif语句运行仅当前面的if语句returnsFalse

我的第一个问题是关于 c。我们能否调整第二个 where 比较,使其 运行 仅针对 a_abs 的那些元素,而第一个 where 比较是 False

我们可以做类似于第三段代码的事情吗?也就是说,运行 第二个条件仅适用于 a 的那些元素,其中第一个条件是 False

目标是提高大型数组的执行速度。

import numpy as np

a = np.array([[ 0.1, -0.9, -0.7], \
              [-3.0,  0.0,  1.1], \
              [ 0.5,  0.19, 0.95]])
# 1st block
def f1(x):
    x_abs = np.abs(x)
    if   x_abs < 0.2: return 0.0
    elif x_abs > 0.8: return np.sign(x)
    else:                return x
f1_arr = np.frompyfunc(f1, 1, 1)
b = f1_arr(a)

# 2nd block
a_abs = np.abs(a)
c = np.where(a_abs < 0.2, 0.0,        \
    np.where(a_abs > 0.8, np.sign(a), \
                          a))
# 3rd block
a[a_abs < 0.2] = 0.0
a[a_abs > 0.8] = np.sign[a]

你的第二个区块不应该是 a_abs 而不是 a 吗?

# 2nd block
a_abs = np.abs(a)
c = np.where(a_abs < 0.2, 0.0,        \
    np.where(a_abs > 0.8, np.sign(a_abs), \
                          a_abs))
In [145]: b
Out[145]: 
array([[0.0, -1.0, -0.7],
       [-1.0, 0.0, 1.0],
       [0.5, 0.0, 1.0]], dtype=object)

与 2 相同的结果 where:

In [146]: a_abs = np.abs(a)
In [147]: temp = np.where(a_abs<0.2, 0, a)
In [148]: temp = np.where(a_abs>0.8, np.sign(a),a)
In [149]: temp
Out[149]: 
array([[ 0.1 , -1.  , -0.7 ],
       [-1.  ,  0.  ,  1.  ],
       [ 0.5 ,  0.19,  1.  ]])

要记住的关键是where 不是迭代器where 的参数在传递给函数之前进行评估。这是基本的 Python 语法。

所以 a_abs<0.2a_abs>0.8 都被评估了。 sign 也在整个数组上计算。

我们可以构建一个结合 2 个测试的条件:

In [154]: ~(a_abs<0.2)
Out[154]: 
array([[False,  True,  True],
       [ True, False,  True],
       [ True, False,  True]])
In [155]: ~(a_abs<0.2)&(a_abs>0.8)
Out[155]: 
array([[False,  True, False],
       [ True, False,  True],
       [False, False,  True]])

尽管这不会添加任何内容,因为如果 a_abs>0.8 它也是 >0.2:

In [156]: temp = np.where(a_abs<0.2, 0, a)
In [157]: np.where(_155, np.sign(temp), temp)
Out[157]: 
array([[ 0. , -1. , -0.7],
       [-1. ,  0. ,  1. ],
       [ 0.5,  0. ,  1. ]])

这仍在对所有数组求 np.signufunc 确实有自己的 where,它确实在选定的子集上评估函数。因此 [148] 可以写成:

In [159]: np.sign(temp, where=a_abs>0.8, out=temp)
Out[159]: 
array([[ 0. , -1. , -0.7],
       [-1. ,  0. ,  1. ],
       [ 0.5,  0. ,  1. ]])

我不认为它节省时间(虽然我没有做过测试);它在 dividelog 中更有用,其中某些数值会引发错误或警告。

一般来说,编译后的 numpy 方法不会“短路”。它们在整个阵列上效果最好,我们接受较小的时间(和内存)损失,以便通过使用编译方法获得更多收益。如果您想以类似 c 的迭代方式微调性能,请使用 numba 等工具。