从具有多维掩码的多维 Numpy 数组中选择

Selecting From multidimensional Numpy array with multidimensional mask

我正在尝试构建一个示例来理解图像分割, 你得到一个形状为 (1,2,2,3) 的图像,它是一个 2x2 图像,其中每个像素都有 3 个数字,表示该像素属于特定 class 的概率。 我想要的是输出 (1,2,2,1) 这是每个像素的 class 基于概率值(select 最高)下面的代码来显示问题。

np.random.seed(2)
pixel=np.random.random((1,2,2,3)) #the image
pixel[0,0,1,:] #pixel(1,1) #with three classes probability  it should belong to class 0

输出

array([0.43532239, 0.4203678 , 0.33033482]) #0 is the heighest

我制作的面具

mask=pixel.argmax(-1)
mask=mask[...,np.newaxis] #shape (1,2,2,1)

现在我有了图像和遮罩,但我不知道如何select使用它

请告诉我如何解决这个问题,以及在哪里可以学习这种切片和 NumPy 中的 select。

You can think about it

输入:(1,2,2,3) 每个像素有三个 class 的图像和 (1,2,2,1) class 到select

输出:(1,2,2,1) 每个像素只有一个 class 的图像

我想你想使用:

np.take_along_axis(pixel, mask, axis = -1)

您也可以通过以下方式在不使用遮罩的情况下获得相同的结果:

pixel.max(axis = -1, keepdims = True)