如何将一个轴与多个numpy数组匹配

how to match one axis with multiple numpy arrays

背景: 我有一个 rgb 图像,具有三个缩小(W,H,C),其中 C = 3。我想在此图像中屏蔽一些颜色,如 (0,0,255) , (0,255,255) 。问题变成了将图像的最后一个轴与我定义的颜色列表相匹配。 color_list = [[255,0,0], [255,255,0], [255,0,255]] # just an example

一种颜色很容易做到,

mask = np.all(image == [255,0,0], axis = 2)

但如果我有多种颜色,我必须 运行 一个 for 循环。

masks = [np.all(image == color, axis = 2) for color in color_list]
mask = np.any(masks, axis=0)

问题: 有什么优雅的方法可以得到多种颜色的面具吗?

我有一种使用广播的方法,这种方法效率更高,因为它会在 C 中循环。基本上使数组具有可比性。开始时可能看起来很难,但是一旦您知道它是如何工作的,这就是您将使用的所有内容 [条件适用]...

import numpy as np
x = np.array([[[255,   0,   0],[  0, 255,   0], [  0, 255,   0], [  0, 255,   0]], [[255,   0,   0],[  0, 255,   0], [  0, 255,   0], [  0, 255,   0]]])
print(x.shape)
# (2, 4, 3)

color_list = np.array([[255,0,0], [255,255,0], [255,0,255]])
print(color_list.shape)
# (3, 3)

# make array compatible
x = x[:, :, np.newaxis, :]

### Analogy for interpreting broadcasting
# Here repeating is for analogy and does not mean it will allocate new copy of memory
# element wise comparision, possibler due to broadcast
# shape of x is (2, 4, 1, 3)
# By broadcasting conceptually x will be repeated along axis=2 this will make (2, 4, 3, 3)
# color_list will be repeated over (2, 4) making it (2, 4, 3, 3) and they will have same shape also the final shape after == will be (2, 4, 3, 3)
f1 = np.all(x[:, :, np.newaxis, :] == color_list, axis=3)
#array([[[ True, False, False],
#        [False, False, False],
#        [False, False, False],
#        [False, False, False]],
#
#       [[ True, False, False],
#        [False, False, False],
#        [False, False, False],
#        [False, False, False]]])

mask = np.any(f1, axis=2)

我们有形状为 (W, H, C) == (2, 4, 3) 的目标数组,我们需要找到大小为 color_list == [[255,0,0], [255,255,0], [255,0,255]]

的 3 数组

理想情况下我们想做 cross comparison,我的意思是如果一侧有 M 和另一侧 N 条目,那么经过一些操作我们想要 M * N 结果。这看起来像是每 N 次重复 M 个条目并进行比较。虽然乍一看这似乎不可能,但 numpy 提供了 广播 。这将概念上像你的for循环一样重复条目(实际上它具有很高的内存效率,它不会创建实际副本)

所以我们需要广播,让这两个数组兼容,但它们不兼容,如broadcasting rules中提到的形状是从右到左比较的,它们需要相同或者其中之一必须是1.

color_list形状为(3, 3),x形状为(2, 4, 3)。我们将在 x 中添加新轴以使其与广播兼容,即 x[:, :, np.newaxis, :],其形状为 (2, 4, 1, 3)。现在两者兼容,我们可以比较。

沿最后一个轴进行比较,即颜色通道轴 = 3,然后在最后一个轴上进行比较,其轴 = 2 将给出 (W, H) 布尔值,其中如果颜色通道三元组在 color_list.

这种技术是完全相同的,可以用来计算距离矩阵当给定两个点数组像这里Fast way to calculate min distance between two numpy arrays of 3D points