Numpy 布尔语句 - 帮助在语句中使用 a.any() 和 a.all()

Numpy boolean statement - help on using a.any() and a.all() in statement

假设我有一个变量 a,它是一个 numpy 数组。当a小于某个值时我想应用某个函数,当它大于这个值时我会应用不同的函数。

我尝试使用布尔 if 语句来执行此操作,但 return 出现以下错误:

ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

我从 this 的回答中知道我需要使用 numpy a.any() 和 a.all() 但我不清楚 how/where 我会在循环。我在下面提供了一个非常简单的示例:

import numpy as np

a = np.linspace(1, 10, num=9)

def sortfoo(a):
    if a < 5:
        b = a*3
    else:
        b = a/2
    return b

result = sortfoo(a)
print(result)

所以我想我是在问一个例子,说明我需要在哪里以及如何使用上面的 any() 和 all() 。

非常基本的问题,但出于某种原因,我的大脑不太清楚。非常感谢任何帮助。

在 numpy 中使用简单的语句,你可以做到这一点:

import numpy as np
a = np.linspace(1, 10, num=9)
s = a < 5 # Test a < 5
a[s] = a[s] * 3
a[s == False] = a[s == False] / 2

根据描述,这看起来像是 np.where()

的用例
a = np.linspace(1, 10, num=9)

b = np.where(a<5,a*3,a/2)

b
array([ 3.    ,  6.375 ,  9.75  , 13.125 ,  2.75  ,  3.3125,  3.875 ,
    4.4375,  5.    ])

既然你也提到要应用不同的功能,那么你可以使用相同的语法

def f1(n):
    return n*3

def f2(n):
    return n/2

np.where(a<5,f1(a),f2(a))

array([ 3.    ,  6.375 ,  9.75  , 13.125 ,  2.75  ,  3.3125,  3.875 ,
        4.4375,  5.    ])