根据 numpy.argmin 个结果切片

Slice based on numpy.argmin results

让我们有一个形状等于 (36, 2, 400, 400) 的 numpy 数组(浮点数)。假设 400 x 400 表示图像。然后对于每个像素,我想找到两个值(第二维),它们是在第二维上取范数时的值,相对于第一维是最低的。所以我最终得到一个形状数组 (2, 400, 400).

使用 np.argmin(np.linalg.norm(array, axis=1), axis=0) 我能够为每个 2 x 400 x 400 像素获取索引,这几乎是我想要的。但是现在我想使用这个数字在第一维对原始数组进行切片,所以我剩下一个形状为 (2, 400, 400) 的数组。

我能做的是遍历所有索引并逐像素构建结果,但我相信有更聪明的方法。谁能用更聪明的方法帮助我?

要求的最小可重现示例,其中 distances 是数组:

shape = (400, 400)
centers = np.random.randint(400, size=(36, 2))

distances = np.array([np.indices(shape) - np.array(center)[:, None, None] for center in centers])
nearest_center_index = np.argmin(np.linalg.norm(distances, axis=1), axis=0)

print(distances.shape)
print(nearest_center_index.shape)
plt.imshow(nearest_center_index)

输出:

(36, 2, 400, 400)
(400, 400)

在评论的帮助下,我能够给出一个有点难看的答案,这有助于我进一步理解问题。让我详细说明。可以做的是将图像和 argmin 结果展平,然后使用带有 argmin 和图像索引的高级索引来生成结果。

flatten_indices = nearest_center_index.reshape(400**2)
image_indices = range(400**2)
results = distances.reshape(36, 2, 400**2)[flatten_indices, :, image_indices].reshape(400, 400, 2).swapaxes(0, 2)

但是,我认为经常会出现这样的情况,即您的索引形状为维度的子集,并且其值包含另一个维度的索引。我希望有一个通用方法来对此进行切片。

因此让我们有一个 n 维的数组,形状为 (x1, x2, ..., xn) 并且假设我们有一个表示维度索引的数组,例如 xi,其形状为是原始数组形状的子集,不包含 xi。然后我希望有一个方法来切片这个数组。

我要找的函数是numpy.take_along_axis()

对于具体示例,唯一需要做的就是确保 nearest_center_indexargmin 的输出)与要切片的数组具有相同的维数。在示例中,这可以通过将 keepdims=True 传递给 normargmin 来实现,然后可以将其直接用作 numpy 函数的第二个参数。第三个参数应该是 xi 轴(在示例轴 0 中)。

在不通过 keepdims=True 的情况下,按照确切的示例,所述 objective 可以通过以下方式实现:

result = np.take_along_axis(distances, nearest_center_index[None,None,:,:], 0)[0]