在 python 中将一些二维数组标签设置为零

Setting some 2d array labels to zero in python

我的目标是在不使用 for 循环的情况下将二维数组中的一些标签设置为零。没有 for 循环,有没有更快的 numpy 方法来做到这一点?理想的情况是 temp_arr[labeled_im not in labels] = 0,但它并没有真正按照我希望的方式工作。

labeled_array = np.array([[1,2,3],
                          [4,5,6],
                          [7,8,9]])

labels = [2,4,5,6,8]
temp_arr = np.zeros((labeled_array.shape)).astype(int)
for label in labels:
    temp_arr[labeled_array == label] = label

>> temp_arr
[[0 2 0]
 [4 5 6]
 [0 8 0]]

当有很多迭代要经过时,for 循环会变得很慢,因此使用 numpy 缩短执行时间很重要。

您可以将labels定义为一个集合并使用temp_arr = np.where(np.isin(labeled_array, labels), labeled_array, 0)。虽然,这么小的阵列,差别似乎并不大。

import numpy as np
import time

labeled_array = np.array([[1,2,3],
                          [4,5,6],
                          [7,8,9]])

labels = [2,4,5,6,8]

start = time.time()
temp_arr_0 = np.zeros((labeled_array.shape)).astype(int)
for label in labels:
    temp_arr_0[labeled_array == label] = label
end = time.time()

print(f"Loop takes {end - start}")

start = time.time()
temp_arr_1 = np.where(np.isin(labeled_array, labels), labeled_array, 0)
end = time.time()

print(f"np.where takes {end - start}")

labels  = {2,4,5,6,8}

start = time.time()
temp_arr_2 = np.where(np.isin(labeled_array, labels), labeled_array, 0)
end = time.time()

print(f"np.where with set takes {end - start}")

产出

Loop takes 5.3882598876953125e-05
np.where takes 0.00010514259338378906
np.where with set takes 3.314018249511719e-05

如果标签在 labels 中是唯一的(并且内存不是问题),这是一种方法。

作为第一步,我们将标签转换为 ndarray

labels = np.array(labels)

然后,我们从 labeled_arraylabels

生成两个可广播数组
labeled_row = labeled_array.ravel()[np.newaxis, :]
labels_col = labels[:, np.newaxis]

上面的代码块分别产生一个形状为(1,9)的行数组

array([[1, 2, 3, 4, 5, 6, 7, 8, 9]])

和形状为 (5,1) 的列数组

array([[2],
       [4],
       [5],
       [6],
       [8]])

现在这两个形状是可广播的(见this page),所以我们可以进行逐元素比较,例如

mask = labeled_row == labels_col

其中 returns 一个 (5,9) 形布尔掩码

array([[False,  True, False, False, False, False, False, False, False],
       [False, False, False,  True, False, False, False, False, False],
       [False, False, False, False,  True, False, False, False, False],
       [False, False, False, False, False,  True, False, False, False],
       [False, False, False, False, False, False, False,  True, False]])

在满足上述假设的情况下,每行的 True 个值等于相应标签在 labeled_array 中出现的次数。尽管如此,您也可以拥有 all-False 行,例如当 labels 中的标签从未出现在您的 labeled_array 中时。

要找出哪些标签实际出现在您的 labeled_array 中,您可以在布尔掩码

上使用 np.nonzero
indices = np.nonzero(mask)

其中 returns 包含 non-zero(即 True)元素的行和列索引的元组

(array([0, 1, 2, 3, 4], dtype=int64), array([1, 3, 4, 5, 7], dtype=int64))

通过构造,上面元组的第一个元素告诉您哪些标签实际出现在您的 labeled_array 中,例如

appeared_labels = labels[indices[0]]

(请注意,如果特定标签在 labeled_array 中出现多次,则 appeared_labels 中可以有连续的元素。

我们现在可以构建并填充输出数组:

out = np.zeros(labeled_array.size, dtype=int)
out[indices[1]] = labels[indices[0]]

然后恢复原状

out = out.reshape(*labeled_array.shape)
array([[0, 2, 0],
       [4, 5, 6],
       [0, 8, 0]])