在 PyTorch 中进行数据增强后得到坏图像

Getting Bad Images After Data Augmentation in PyTorch

我正在研究核分割问题,试图在染色组织图像中识别核的位置。给定的训练数据集有一张染色组织的图片和一个带有细胞核位置的掩码。由于数据集很小,我想尝试在 PyTorch 中进行数据扩充,但在这样做之后,由于某种原因,当我输出我的蒙版图像时,它看起来很好,但相应的组织图像不正确。

我所有的训练图像都在 X_train 中,形状为 (128, 128, 3),对应的掩码在 Y_train 中,形状为 (128, 128, 1) 类似地,交叉验证图像和掩码在X_valY_val 分别。

Y_trainY_valdtype = np.boolX_trainX_valdtype = np.uint8

在数据增强之前,我这样检查我的图像:

fig, axis = plt.subplots(2, 2)
axis[0][0].imshow(X_train[0].astype(np.uint8))
axis[0][1].imshow(np.squeeze(Y_train[0]).astype(np.uint8))
axis[1][0].imshow(X_val[0].astype(np.uint8))
axis[1][1].imshow(np.squeeze(Y_val[0]).astype(np.uint8))

输出结果如下: Before Data Augmentation

为了数据扩充,我自定义了一个class如下:

这里我将 torchvision.transforms.functional 导入为 TFtorchvision.transforms as transformsimages_npmasks_np 是 numpy 数组的输入。

class Nuc_Seg(Dataset):
def __init__(self, images_np, masks_np):
    self.images_np = images_np
    self.masks_np = masks_np

def transform(self, image_np, mask_np):
    ToPILImage = transforms.ToPILImage()
    image = ToPILImage(image_np)
    mask = ToPILImage(mask_np.astype(np.int32))

    angle = random.uniform(-10, 10)
    width, height = image.size
    max_dx = 0.2 * width
    max_dy = 0.2 * height
    translations = (np.round(random.uniform(-max_dx, max_dx)), np.round(random.uniform(-max_dy, max_dy)))
    scale = random.uniform(0.8, 1.2)
    shear = random.uniform(-0.5, 0.5)
    image = TF.affine(image, angle = angle, translate = translations, scale = scale, shear = shear)
    mask = TF.affine(mask, angle = angle, translate = translations, scale = scale, shear = shear)

    image = TF.to_tensor(image)
    mask = TF.to_tensor(mask)
    return image, mask

def __len__(self):
    return len(self.images_np)

def __getitem__(self, idx):
    image_np = self.images_np[idx]
    mask_np = self.masks_np[idx]
    image, mask = self.transform(image_np, mask_np)

    return image, mask    

接下来是:

我用过from torch.utils.data import DataLoader

train_dataset = Nuc_Seg(X_train, Y_train)
train_loader = DataLoader(train_dataset, batch_size = 16, shuffle = True)
val_dataset = Nuc_Seg(X_val, Y_val)
val_loader = DataLoader(val_dataset, batch_size = 16, shuffle = True)

完成这一步后,我尝试使用以下方法检查我的第一个训练图像和蒙版:

%matplotlib inline

for ex_img, ex_mask in train_loader:

    img = ex_img[0]
    img = img.reshape(128, 128, 3)
    mask = ex_mask[0]
    mask = mask.reshape(128, 128)

    img = img.numpy()
    mask = mask.numpy()

    fig, (axis_1, axis_2) = plt.subplots(1, 2)
    axis_1.imshow(img.astype(np.uint8))
    axis_2.imshow(mask.astype(np.uint8))

    break

我得到这个作为我的输出: After Data Augmentation 1

当我将行 axis_1.imshow(img.astype(np.uint8)) 更改为 axis_1.imshow(img) 时,

我得到这张图片: After Data Augmentation 2

mask的图像是正确的,但是不知为何,细胞核的图像是错误的。使用.astype(np.uint8),组织图像是完全黑色的。

没有.astype(np.uint8),原子核的位置是正确的,但是配色方案全乱了(我希望图像像数据增强之前看到的那样,灰色或粉红色),由于某种原因,在网格中显示了 9 个相同图像的副本。你能帮我得到组织图像的正确输出吗?

您正在将图像转换为 PyTorch 张量,在 PyTorch 中,图像的大小为 [C, H, W]。当您将它们可视化时,您正在将张量转换回 NumPy 数组,其中图像的大小为 [H, W, C]。因此,您正在尝试重新排列维度,但您使用的是 torch.reshape,它不会交换维度,而只会以不同的方式对数据进行分区。

一个例子使这一点更清楚:

# Incrementing numbers with size 2 x 3 x 3
image = torch.arange(2 * 3 * 3).reshape(2, 3, 3)
# => tensor([[[ 0,  1,  2],
#             [ 3,  4,  5],
#             [ 6,  7,  8]],
#
#            [[ 9, 10, 11],
#             [12, 13, 14],
#             [15, 16, 17]]])

# Reshape keeps the same order of elements but for a different size
# The numbers are still incrementing from left to right
image.reshape(3, 3, 2)
# => tensor([[[ 0,  1],
#             [ 2,  3],
#             [ 4,  5]],
#
#            [[ 6,  7],
#             [ 8,  9],
#             [10, 11]],
#
#            [[12, 13],
#             [14, 15],
#             [16, 17]]])

要重新排序维度,您可以使用 permute:

# Dimensions are swapped
# Now the numbers increment from top to bottom
image.permute(1, 2, 0)
# => tensor([[[ 0,  9],
#             [ 1, 10],
#             [ 2, 11]],
#
#            [[ 3, 12],
#             [ 4, 13],
#             [ 5, 14]],
#
#            [[ 6, 15],
#             [ 7, 16],
#             [ 8, 17]]])

With the .astype(np.uint8), the tissue image is completely black.

PyTorch 图像表示为值介于 [0, 1] 之间的浮点数,但 NumPy 使用介于 [0, 255] 之间的整数值。将浮点值转换为 np.uint8 将只产生 0 和 1,其中不等于 1 的所有内容都将设置为 0,因此整个图像为黑色。

您需要将这些值乘以 255 才能将它们置于 [0, 255] 范围内。

img = img.permute(1, 2, 0) * 255
img = img.numpy().astype(np.uint8)

当您使用 transforms.ToPILImage (or with TF.to_pil_image 将张量转换为 PIL 图像时,如果您更喜欢函数式版本,此转换也会自动完成)并且 PIL 图像可以直接转换为 NumPy 数组。这样您就不必担心尺寸、值范围或类型,上面的代码可以替换为:

img = np.array(TF.to_pil_image(img))