如何使嵌套的 numpy 'where' 表现得像嵌套的 if 语句?
How to make nested numpy 'where' behave like nested if statements?
要计算以下代码中的b
,elif
语句运行仅当前面的if
语句returnsFalse
时
我的第一个问题是关于 c
。我们能否调整第二个 where
比较,使其 运行 仅针对 a_abs
的那些元素,而第一个 where
比较是 False
?
我们可以做类似于第三段代码的事情吗?也就是说,运行 第二个条件仅适用于 a
的那些元素,其中第一个条件是 False
。
目标是提高大型数组的执行速度。
import numpy as np
a = np.array([[ 0.1, -0.9, -0.7], \
[-3.0, 0.0, 1.1], \
[ 0.5, 0.19, 0.95]])
# 1st block
def f1(x):
x_abs = np.abs(x)
if x_abs < 0.2: return 0.0
elif x_abs > 0.8: return np.sign(x)
else: return x
f1_arr = np.frompyfunc(f1, 1, 1)
b = f1_arr(a)
# 2nd block
a_abs = np.abs(a)
c = np.where(a_abs < 0.2, 0.0, \
np.where(a_abs > 0.8, np.sign(a), \
a))
# 3rd block
a[a_abs < 0.2] = 0.0
a[a_abs > 0.8] = np.sign[a]
你的第二个区块不应该是 a_abs
而不是 a
吗?
# 2nd block
a_abs = np.abs(a)
c = np.where(a_abs < 0.2, 0.0, \
np.where(a_abs > 0.8, np.sign(a_abs), \
a_abs))
In [145]: b
Out[145]:
array([[0.0, -1.0, -0.7],
[-1.0, 0.0, 1.0],
[0.5, 0.0, 1.0]], dtype=object)
与 2 相同的结果 where
:
In [146]: a_abs = np.abs(a)
In [147]: temp = np.where(a_abs<0.2, 0, a)
In [148]: temp = np.where(a_abs>0.8, np.sign(a),a)
In [149]: temp
Out[149]:
array([[ 0.1 , -1. , -0.7 ],
[-1. , 0. , 1. ],
[ 0.5 , 0.19, 1. ]])
要记住的关键是where
不是迭代器。 where
的参数在传递给函数之前进行评估。这是基本的 Python 语法。
所以 a_abs<0.2
和 a_abs>0.8
都被评估了。 sign
也在整个数组上计算。
我们可以构建一个结合 2 个测试的条件:
In [154]: ~(a_abs<0.2)
Out[154]:
array([[False, True, True],
[ True, False, True],
[ True, False, True]])
In [155]: ~(a_abs<0.2)&(a_abs>0.8)
Out[155]:
array([[False, True, False],
[ True, False, True],
[False, False, True]])
尽管这不会添加任何内容,因为如果 a_abs>0.8
它也是 >0.2
:
In [156]: temp = np.where(a_abs<0.2, 0, a)
In [157]: np.where(_155, np.sign(temp), temp)
Out[157]:
array([[ 0. , -1. , -0.7],
[-1. , 0. , 1. ],
[ 0.5, 0. , 1. ]])
这仍在对所有数组求 np.sign
。 ufunc
确实有自己的 where
,它确实在选定的子集上评估函数。因此 [148] 可以写成:
In [159]: np.sign(temp, where=a_abs>0.8, out=temp)
Out[159]:
array([[ 0. , -1. , -0.7],
[-1. , 0. , 1. ],
[ 0.5, 0. , 1. ]])
我不认为它节省时间(虽然我没有做过测试);它在 divide
或 log
中更有用,其中某些数值会引发错误或警告。
一般来说,编译后的 numpy 方法不会“短路”。它们在整个阵列上效果最好,我们接受较小的时间(和内存)损失,以便通过使用编译方法获得更多收益。如果您想以类似 c
的迭代方式微调性能,请使用 numba
等工具。
要计算以下代码中的b
,elif
语句运行仅当前面的if
语句returnsFalse
时
我的第一个问题是关于 c
。我们能否调整第二个 where
比较,使其 运行 仅针对 a_abs
的那些元素,而第一个 where
比较是 False
?
我们可以做类似于第三段代码的事情吗?也就是说,运行 第二个条件仅适用于 a
的那些元素,其中第一个条件是 False
。
目标是提高大型数组的执行速度。
import numpy as np
a = np.array([[ 0.1, -0.9, -0.7], \
[-3.0, 0.0, 1.1], \
[ 0.5, 0.19, 0.95]])
# 1st block
def f1(x):
x_abs = np.abs(x)
if x_abs < 0.2: return 0.0
elif x_abs > 0.8: return np.sign(x)
else: return x
f1_arr = np.frompyfunc(f1, 1, 1)
b = f1_arr(a)
# 2nd block
a_abs = np.abs(a)
c = np.where(a_abs < 0.2, 0.0, \
np.where(a_abs > 0.8, np.sign(a), \
a))
# 3rd block
a[a_abs < 0.2] = 0.0
a[a_abs > 0.8] = np.sign[a]
你的第二个区块不应该是 a_abs
而不是 a
吗?
# 2nd block
a_abs = np.abs(a)
c = np.where(a_abs < 0.2, 0.0, \
np.where(a_abs > 0.8, np.sign(a_abs), \
a_abs))
In [145]: b
Out[145]:
array([[0.0, -1.0, -0.7],
[-1.0, 0.0, 1.0],
[0.5, 0.0, 1.0]], dtype=object)
与 2 相同的结果 where
:
In [146]: a_abs = np.abs(a)
In [147]: temp = np.where(a_abs<0.2, 0, a)
In [148]: temp = np.where(a_abs>0.8, np.sign(a),a)
In [149]: temp
Out[149]:
array([[ 0.1 , -1. , -0.7 ],
[-1. , 0. , 1. ],
[ 0.5 , 0.19, 1. ]])
要记住的关键是where
不是迭代器。 where
的参数在传递给函数之前进行评估。这是基本的 Python 语法。
所以 a_abs<0.2
和 a_abs>0.8
都被评估了。 sign
也在整个数组上计算。
我们可以构建一个结合 2 个测试的条件:
In [154]: ~(a_abs<0.2)
Out[154]:
array([[False, True, True],
[ True, False, True],
[ True, False, True]])
In [155]: ~(a_abs<0.2)&(a_abs>0.8)
Out[155]:
array([[False, True, False],
[ True, False, True],
[False, False, True]])
尽管这不会添加任何内容,因为如果 a_abs>0.8
它也是 >0.2
:
In [156]: temp = np.where(a_abs<0.2, 0, a)
In [157]: np.where(_155, np.sign(temp), temp)
Out[157]:
array([[ 0. , -1. , -0.7],
[-1. , 0. , 1. ],
[ 0.5, 0. , 1. ]])
这仍在对所有数组求 np.sign
。 ufunc
确实有自己的 where
,它确实在选定的子集上评估函数。因此 [148] 可以写成:
In [159]: np.sign(temp, where=a_abs>0.8, out=temp)
Out[159]:
array([[ 0. , -1. , -0.7],
[-1. , 0. , 1. ],
[ 0.5, 0. , 1. ]])
我不认为它节省时间(虽然我没有做过测试);它在 divide
或 log
中更有用,其中某些数值会引发错误或警告。
一般来说,编译后的 numpy 方法不会“短路”。它们在整个阵列上效果最好,我们接受较小的时间(和内存)损失,以便通过使用编译方法获得更多收益。如果您想以类似 c
的迭代方式微调性能,请使用 numba
等工具。