将具有多维输出的 tf.data 数据集提供给 Keras 模型

Feeding tf.data Dataset with multidimensional output to Keras model

我想将 tf.data 数据集提供给 Keras 模型,但出现以下错误:

AttributeError: 'DatasetV1Adapter' object has no attribute 'ndim'

此数据集将用于解决分割问题,因此输入和输出均为图像(3D 张量)

数据集是使用以下代码创建的:

dataset = tf.data.Dataset.list_files(TRAIN_PATH + "*.png",shuffle=False)

def process_path(file_path):
  img = tf.io.read_file(file_path)
  img = tf.image.decode_png(img, channels=3)
  train_image_path=tf.strings.regex_replace(file_path,"image","mask")
  mask = tf.io.read_file(train_image_path)
  mask = tf.image.decode_png(mask, channels=1)
  mask = tf.squeeze(mask)
  mask = tf.one_hot(tf.cast(mask, tf.int32), Num_Classes, axis = -1)
  return img,mask

dataset = dataset.map(process_path)
dataset = dataset.batch(32,drop_remainder=True)

从数据集中取一个项目显示我得到一个包含输入张量和输出张量的元组,其维度是正确的:

输入:(批量大小、图像高度、图像宽度、3 个通道)

输出:(批量大小、图像高度、图像宽度、4 个通道)

拟合模型时出现错误:

model.fit(dataset, epochs = 50)

我已经解决了迁移到 Keras 2.4.3 和 Tensorflow 2.2 的问题

一切正常,但显然之前版本的 Keras 没有正确处理 tf.data。

Here 我发现这方面的教程非常有用。