如何在应用地图后保持 tf.dataset 的形状?

How to keep the shape of tf.dataset after applying map?

我正在创建一个 tf.dataset 对象,其中包含 2 个图像作为输入和一个蒙版作为目标。它们都是 3D 的。应用自定义地图后,对象的形状从 <RepeatDataset shapes: (((), ()), ()), types: ((tf.string, tf.string), tf.string)> 变为 <PrefetchDataset shapes: (<unknown>, <unknown>, <unknown>), types: (tf.float32, tf.float32, tf.int32)>,当我拟合数据时,我的模型抛出错误,因为它只检测到一个输入而不是 2 个。 这是我正在做的事情:

x, y = get_filenames(train_data_path, img_type='FLAIR')
z = get_filenames(train_data_path, img_type='mask')

path_dataset = tf.data.Dataset.from_tensor_slices((x, y))
mask_dataset = tf.data.Dataset.from_tensor_slices(z)
dataset = tf.data.Dataset.zip((path_dataset, mask_dataset)).shuffle(50).repeat(10)

ds = dataset. \
    map(lambda xx, zz: ((tf.py_function(load, [xx], [tf.float32, tf.float32])),
                        tf.py_function(load_mask, [zz], [tf.int32])),
        num_parallel_calls=tf.data.AUTOTUNE)

ds = ds.map(lambda xx, zz: (tf.py_function(random_crop_flip, [xx, zz],
                                           [tf.float32, tf.float32, tf.int32])),
            num_parallel_calls=tf.data.AUTOTUNE)
ds = ds.batch(2)
ds = ds.prefetch(tf.data.AUTOTUNE)

我无法单独映射图像和蒙版,因为它们需要相同的种子来进行随机裁剪和翻转。是否可以在贴图之后更改形状,以便我可以将其提供给我的 2 输入模型?

编辑:

我的random_crop_flip函数如下:

def random_crop_flip(images, mask, width=128, height=128, depth=128):
    img_bl, img_fu = images
    img_bl = img_bl.numpy()
    img_fu = img_fu.numpy()
    mask = mask.numpy()
    x_rand = random.randint(0, img_bl.shape[2] - width)
    y_rand = random.randint(0, img_bl.shape[1] - height)
    z_rand = random.randint(0, img_bl.shape[3] - depth)
    img_bl_f = img_bl[:, y_rand:y_rand + height, x_rand:x_rand + width, z_rand:z_rand + depth, :]
    img_fu_f = img_fu[:, y_rand:y_rand + height, x_rand:x_rand + width, z_rand:z_rand + depth, :]
    mask_f = mask[:, y_rand:y_rand + height, x_rand:x_rand + width, z_rand:z_rand + depth, :]
    flip_x = random.choice([True, False])
    flip_y = random.choice([True, False])
    flip_z = random.choice([True, False])

    if flip_x:
        img_bl_f = np.flip(img_bl_f, axis=2)
        img_fu_f = np.flip(img_fu_f, axis=2)
        mask_f = np.flip(mask_f, axis=2)

    if flip_y:
        img_bl_f = np.flip(img_bl_f, axis=1)
        img_fu_f = np.flip(img_fu_f, axis=1)
        mask_f = np.flip(mask_f, axis=1)

    if flip_z:
        img_bl_f = np.flip(img_bl_f, axis=3)
        img_fu_f = np.flip(img_fu_f, axis=3)
        mask_f = np.flip(mask_f, axis=3)

    images = zip(img_bl_f, img_fu_f)
    return images, mask_f

zip 无法解决我的问题。是否可以修改 return 以获得我想要的输出?

我设法通过“展平”(消除括号)random_crop_flip 的 return 并在它们之上应用另一个贴图来解决这个问题,我在其中指定了形状和return编辑了我想要的结构 (x ,y), z:

def _set_shapes(img_bl, img_fu, mask):
    img_bl.set_shape([128, 128, 128, 1])
    img_fu.set_shape([128, 128, 128, 1])
    mask.set_shape([128, 128, 128, 1])

    return (img_bl, img_fu), mask 

那么我的代码如下所示:

x, y = get_filenames(train_data_path, img_type='FLAIR')
z = get_filenames(train_data_path, img_type='mask')

path_dataset = tf.data.Dataset.from_tensor_slices((x, y))
mask_dataset = tf.data.Dataset.from_tensor_slices(z)
dataset = tf.data.Dataset.zip((path_dataset, mask_dataset)).shuffle(50).repeat(10)

ds = dataset. \
    map(lambda xx, zz: ((tf.py_function(load, [xx], [tf.float32, tf.float32])),
                        tf.py_function(load_mask, [zz], [tf.int32])),
        num_parallel_calls=tf.data.AUTOTUNE)

ds = ds.map(lambda xx, zz: (tf.py_function(random_crop_flip, [xx, zz],
                                           [tf.float32, tf.float32, tf.int32])),
            num_parallel_calls=tf.data.AUTOTUNE)
ds = ds.map(_set_shapes)
ds = ds.batch(2)
ds = ds.prefetch(tf.data.AUTOTUNE)