numpy where 根据条件替换为 numpy 数组

numpy where replace with numpy array depending on condition

我有:

import numpy as np

a = np.array([[1,1,1],[-1,-1,-1]])

print (a)

输出:

[[ 1  1  1]
 [-1 -1 -1]]

我能做到:

b = np.where(np.mean(a,axis=1) > 0,1,0)

print (b)

正确的结果是:

[1 0]

但是当我这样做时:

b = np.where(np.mean(a,axis=1) > 0,np.array([1,1]),np.array([0,0]))

print (b)

结果相同:

[1 0]

我想要的是:

[[1 1]
 [0 0]]

详细来说,我想用数组而不是单个整数替换基于沿轴 1 的平均值的 ndarray 元素。所以二维数组的输出应该是一个二维数组。

In [254]: np.mean(a,axis=1) > 0,np.array([1,1]),np.array([0,0])
Out[254]: (array([ True, False]), array([1, 1]), array([0, 0]))

3个参数是(2,)形数组。他们互相广播到 return 一个 (2,) 数组。

关键是,broadcasting。它根据条件的相应元素从 2 个数组中挑选元素。这不是两者之间的批发选择。

如果条件是 (2,1) 形状的,将针对 (2,) 进行广播以产生 (2,2) 结果

In [255]: (np.mean(a,axis=1) > 0)[:,None],np.array([1,1]),np.array([0,0])
Out[255]: 
(array([[ True],
        [False]]),
 array([1, 1]),
 array([0, 0]))

In [256]: np.where((np.mean(a,axis=1) > 0)[:,None],np.array([1,1]),np.array([0,0]))
Out[256]: 
array([[1, 1],
       [0, 0]])