np.where 对于二维数组,操作整行

np.where for 2d array, manipulate whole rows

我想用 numpy 广播功能重建以下逻辑,例如 np.where:从二维数组检查每行是否第一个元素满足条件。如果条件为真,则 return 前三个元素作为一行,否则最后三个元素。

我想规避的 for 循环形式的短 MWE:

import numpy as np
array = np.array([
    [1, 2, 3, 4],
    [1, 2, 4, 2],
    [2, 3, 4, 6]
])

new_array = np.zeros((array.shape[0], array.shape[1]-1))
for i, row in enumerate(array):
    if row[0] == 1: new_array[i] = row[:3]
    else: new_array[i] = row[-3:]

IIUC 你想要这样的东西:

condition = array[:,0]==1
new_array[condition,:] = array[condition,:3]
new_array[~condition,:] = array[~condition,-3:]

如果你想使用np.where:

import numpy as np
array = np.array([
    [1, 2, 3, 4],
    [1, 2, 4, 2],
    [2, 3, 4, 6]
])

cond = array[:, 0] == 1
np.where(cond[:, None], array[:,:3], array[:,-3:])

输出:

array([[1, 2, 3],
       [1, 2, 4],
       [3, 4, 6]])

编辑

稍微简洁一点的版本:

np.where(array[:, [0]] == 1, array[:,:3], array[:,-3:])