当只需要第一个元素时,为什么要创建一个新轴?

Why is a new axis created, when only the first element is needed?

首先,对于模糊的标题感到抱歉

由于我有兴趣了解有关 TensorFlow 和图像分割的更多信息,所以我一直在学习他们的教程 (https://www.tensorflow.org/tutorials/images/segmentation)。但是,我注意到一些我无法完全理解的东西,在谷歌搜索之后也没有。

本节内容:

def create_mask(pred_mask):
    pred_mask = tf.argmax(pred_mask, axis=-1)
    pred_mask = pred_mask[..., tf.newaxis]
    return pred_mask[0]

首先为 pred_mask 向量创建一个新轴的原因是什么,然后才只选择第一个元素?为什么和我想象的不一样,如下图:

def create_mask(pred_mask):
    pred_mask = tf.argmax(pred_mask, axis=-1)
    return pred_mask

axis=-1 调用 tf.argmax 使张量松开最后一个通道。这是通过 tf.newaxis.

作为单例通道添加回来的

然后你 return 批处理的第一个元素。简而言之:

(batch_size, height, width, channels)  # original tensor shape
(batch_size, height, width)            # after argmax
(batch_size, height, width, 1)         # after unsqueeze
(height, width, 1)                     # this is what you are returning

只是为了保持图像是 3D 张量。例如,如果您有形状为 (1, 256, 256, 10) 的预测(一批 256x256 图像 10 类),在 tf.argmax() 之后您将收到一个形状为 (1, 256, 256) 的张量(一批一张没有通道的 256x256 图像)。但通常情况下,使用 HWC 格式 (Height, Width, Channel) 而不是 (Height, Width) 的图像会更容易。例如,如果你使用matplotlib或者OpenCV,你通常需要HWC图像。

在图像分割中,U-Net 架构的输入是 (samples, H, W, channels),你得到的输出是 (samples, H, W, n-classes) --(sample is the batch of images)-- 让我们为这个样本拍摄一张图像,它将是 (1, H, W,n_classes) 这意味着图像的每个像素都可以属于 n-classes具有不同的概率,所以我们想知道每个像素的最高 class 概率,因此我们将此像素设置为特定的 class。使用 argmax(axis=-1) 将为您提供 (1, H, W) 的真值和假值掩码,n-classes 被丢弃,我们有 selected 像素但丢弃了它的 class,使用这个掩码你将无法从 (1, H, W,n-classes) 中 select,我们只想添加返回更多维度到我们拥有的 (1, H, W) 掩码的末尾。 在 pred_mask[..., tf.newaxis]tf.expand_dims(a,axis=1) 我们可以将掩码构造为 (1,H,W,TheHeighestprobablityclass)

np.random.seed(2)
pixel=np.random.random((1,2,2,3)) #
pixel[0,0,1,:] #pixel(1,1) #with three classes probability  it should belong to class 0

输出

array([0.43532239, 0.4203678 , 0.33033482]) #0 is heighest

检查像素(2x2 图像)每个像素有 3 classes

pixel

输出

array([[[[0.4359949 , 0.02592623, 0.54966248],
         [0.43532239, 0.4203678 , 0.33033482]],

        [[0.20464863, 0.61927097, 0.29965467],
         [0.26682728, 0.62113383, 0.52914209]]]])

形状 (1, 2, 2, 3)

面具

mask=pixel.argmax(-1)
mask #you can't select from above with this mask cuz of it's shape add more dim to match the image

输出

     array([[[2, 0],
                [1, 1]]], dtype=int64)
      

mask.shape (1, 2, 2)

mask=mask[..., np.newaxis] #pixel (0,1) mask say 0 which is What we expect
mask.shape
mask.shape #now we can use this mask to select max class  from each pixel of  our 2x2 image
    

输出

   array([[[[2],
             [0]],
    
            [[1],
             [1]]]], dtype=int64)

形状 (1, 2, 2, 1) 现在您可以将蒙版与图像一起使用,这就是他添加更多暗淡的原因。