如何在应用地图后保持 tf.dataset 的形状?
How to keep the shape of tf.dataset after applying map?
我正在创建一个 tf.dataset 对象,其中包含 2 个图像作为输入和一个蒙版作为目标。它们都是 3D 的。应用自定义地图后,对象的形状从 <RepeatDataset shapes: (((), ()), ()), types: ((tf.string, tf.string), tf.string)>
变为 <PrefetchDataset shapes: (<unknown>, <unknown>, <unknown>), types: (tf.float32, tf.float32, tf.int32)>
,当我拟合数据时,我的模型抛出错误,因为它只检测到一个输入而不是 2 个。
这是我正在做的事情:
x, y = get_filenames(train_data_path, img_type='FLAIR')
z = get_filenames(train_data_path, img_type='mask')
path_dataset = tf.data.Dataset.from_tensor_slices((x, y))
mask_dataset = tf.data.Dataset.from_tensor_slices(z)
dataset = tf.data.Dataset.zip((path_dataset, mask_dataset)).shuffle(50).repeat(10)
ds = dataset. \
map(lambda xx, zz: ((tf.py_function(load, [xx], [tf.float32, tf.float32])),
tf.py_function(load_mask, [zz], [tf.int32])),
num_parallel_calls=tf.data.AUTOTUNE)
ds = ds.map(lambda xx, zz: (tf.py_function(random_crop_flip, [xx, zz],
[tf.float32, tf.float32, tf.int32])),
num_parallel_calls=tf.data.AUTOTUNE)
ds = ds.batch(2)
ds = ds.prefetch(tf.data.AUTOTUNE)
我无法单独映射图像和蒙版,因为它们需要相同的种子来进行随机裁剪和翻转。是否可以在贴图之后更改形状,以便我可以将其提供给我的 2 输入模型?
编辑:
我的random_crop_flip函数如下:
def random_crop_flip(images, mask, width=128, height=128, depth=128):
img_bl, img_fu = images
img_bl = img_bl.numpy()
img_fu = img_fu.numpy()
mask = mask.numpy()
x_rand = random.randint(0, img_bl.shape[2] - width)
y_rand = random.randint(0, img_bl.shape[1] - height)
z_rand = random.randint(0, img_bl.shape[3] - depth)
img_bl_f = img_bl[:, y_rand:y_rand + height, x_rand:x_rand + width, z_rand:z_rand + depth, :]
img_fu_f = img_fu[:, y_rand:y_rand + height, x_rand:x_rand + width, z_rand:z_rand + depth, :]
mask_f = mask[:, y_rand:y_rand + height, x_rand:x_rand + width, z_rand:z_rand + depth, :]
flip_x = random.choice([True, False])
flip_y = random.choice([True, False])
flip_z = random.choice([True, False])
if flip_x:
img_bl_f = np.flip(img_bl_f, axis=2)
img_fu_f = np.flip(img_fu_f, axis=2)
mask_f = np.flip(mask_f, axis=2)
if flip_y:
img_bl_f = np.flip(img_bl_f, axis=1)
img_fu_f = np.flip(img_fu_f, axis=1)
mask_f = np.flip(mask_f, axis=1)
if flip_z:
img_bl_f = np.flip(img_bl_f, axis=3)
img_fu_f = np.flip(img_fu_f, axis=3)
mask_f = np.flip(mask_f, axis=3)
images = zip(img_bl_f, img_fu_f)
return images, mask_f
zip 无法解决我的问题。是否可以修改 return 以获得我想要的输出?
我设法通过“展平”(消除括号)random_crop_flip 的 return 并在它们之上应用另一个贴图来解决这个问题,我在其中指定了形状和return编辑了我想要的结构 (x ,y), z:
def _set_shapes(img_bl, img_fu, mask):
img_bl.set_shape([128, 128, 128, 1])
img_fu.set_shape([128, 128, 128, 1])
mask.set_shape([128, 128, 128, 1])
return (img_bl, img_fu), mask
那么我的代码如下所示:
x, y = get_filenames(train_data_path, img_type='FLAIR')
z = get_filenames(train_data_path, img_type='mask')
path_dataset = tf.data.Dataset.from_tensor_slices((x, y))
mask_dataset = tf.data.Dataset.from_tensor_slices(z)
dataset = tf.data.Dataset.zip((path_dataset, mask_dataset)).shuffle(50).repeat(10)
ds = dataset. \
map(lambda xx, zz: ((tf.py_function(load, [xx], [tf.float32, tf.float32])),
tf.py_function(load_mask, [zz], [tf.int32])),
num_parallel_calls=tf.data.AUTOTUNE)
ds = ds.map(lambda xx, zz: (tf.py_function(random_crop_flip, [xx, zz],
[tf.float32, tf.float32, tf.int32])),
num_parallel_calls=tf.data.AUTOTUNE)
ds = ds.map(_set_shapes)
ds = ds.batch(2)
ds = ds.prefetch(tf.data.AUTOTUNE)
我正在创建一个 tf.dataset 对象,其中包含 2 个图像作为输入和一个蒙版作为目标。它们都是 3D 的。应用自定义地图后,对象的形状从 <RepeatDataset shapes: (((), ()), ()), types: ((tf.string, tf.string), tf.string)>
变为 <PrefetchDataset shapes: (<unknown>, <unknown>, <unknown>), types: (tf.float32, tf.float32, tf.int32)>
,当我拟合数据时,我的模型抛出错误,因为它只检测到一个输入而不是 2 个。
这是我正在做的事情:
x, y = get_filenames(train_data_path, img_type='FLAIR')
z = get_filenames(train_data_path, img_type='mask')
path_dataset = tf.data.Dataset.from_tensor_slices((x, y))
mask_dataset = tf.data.Dataset.from_tensor_slices(z)
dataset = tf.data.Dataset.zip((path_dataset, mask_dataset)).shuffle(50).repeat(10)
ds = dataset. \
map(lambda xx, zz: ((tf.py_function(load, [xx], [tf.float32, tf.float32])),
tf.py_function(load_mask, [zz], [tf.int32])),
num_parallel_calls=tf.data.AUTOTUNE)
ds = ds.map(lambda xx, zz: (tf.py_function(random_crop_flip, [xx, zz],
[tf.float32, tf.float32, tf.int32])),
num_parallel_calls=tf.data.AUTOTUNE)
ds = ds.batch(2)
ds = ds.prefetch(tf.data.AUTOTUNE)
我无法单独映射图像和蒙版,因为它们需要相同的种子来进行随机裁剪和翻转。是否可以在贴图之后更改形状,以便我可以将其提供给我的 2 输入模型?
编辑:
我的random_crop_flip函数如下:
def random_crop_flip(images, mask, width=128, height=128, depth=128):
img_bl, img_fu = images
img_bl = img_bl.numpy()
img_fu = img_fu.numpy()
mask = mask.numpy()
x_rand = random.randint(0, img_bl.shape[2] - width)
y_rand = random.randint(0, img_bl.shape[1] - height)
z_rand = random.randint(0, img_bl.shape[3] - depth)
img_bl_f = img_bl[:, y_rand:y_rand + height, x_rand:x_rand + width, z_rand:z_rand + depth, :]
img_fu_f = img_fu[:, y_rand:y_rand + height, x_rand:x_rand + width, z_rand:z_rand + depth, :]
mask_f = mask[:, y_rand:y_rand + height, x_rand:x_rand + width, z_rand:z_rand + depth, :]
flip_x = random.choice([True, False])
flip_y = random.choice([True, False])
flip_z = random.choice([True, False])
if flip_x:
img_bl_f = np.flip(img_bl_f, axis=2)
img_fu_f = np.flip(img_fu_f, axis=2)
mask_f = np.flip(mask_f, axis=2)
if flip_y:
img_bl_f = np.flip(img_bl_f, axis=1)
img_fu_f = np.flip(img_fu_f, axis=1)
mask_f = np.flip(mask_f, axis=1)
if flip_z:
img_bl_f = np.flip(img_bl_f, axis=3)
img_fu_f = np.flip(img_fu_f, axis=3)
mask_f = np.flip(mask_f, axis=3)
images = zip(img_bl_f, img_fu_f)
return images, mask_f
zip 无法解决我的问题。是否可以修改 return 以获得我想要的输出?
我设法通过“展平”(消除括号)random_crop_flip 的 return 并在它们之上应用另一个贴图来解决这个问题,我在其中指定了形状和return编辑了我想要的结构 (x ,y), z:
def _set_shapes(img_bl, img_fu, mask):
img_bl.set_shape([128, 128, 128, 1])
img_fu.set_shape([128, 128, 128, 1])
mask.set_shape([128, 128, 128, 1])
return (img_bl, img_fu), mask
那么我的代码如下所示:
x, y = get_filenames(train_data_path, img_type='FLAIR')
z = get_filenames(train_data_path, img_type='mask')
path_dataset = tf.data.Dataset.from_tensor_slices((x, y))
mask_dataset = tf.data.Dataset.from_tensor_slices(z)
dataset = tf.data.Dataset.zip((path_dataset, mask_dataset)).shuffle(50).repeat(10)
ds = dataset. \
map(lambda xx, zz: ((tf.py_function(load, [xx], [tf.float32, tf.float32])),
tf.py_function(load_mask, [zz], [tf.int32])),
num_parallel_calls=tf.data.AUTOTUNE)
ds = ds.map(lambda xx, zz: (tf.py_function(random_crop_flip, [xx, zz],
[tf.float32, tf.float32, tf.int32])),
num_parallel_calls=tf.data.AUTOTUNE)
ds = ds.map(_set_shapes)
ds = ds.batch(2)
ds = ds.prefetch(tf.data.AUTOTUNE)