具有 TensorFlow TFRecord 数据集错误的 Keras 模型——等级未定义

Keras model with TensorFlow TFRecord Dataset error -- rank is undefined

我使用的是相当标准的 TFRecord 数据集。记录是 Example protobufs。 “图像”特征是由 tf.io.serialize_tensor.

序列化的 28 x 28 张量
feature_description = {
    "image": tf.io.FixedLenFeature((), tf.string),
    "label": tf.io.FixedLenFeature((), tf.int64)}

image_shape = (28, 28)

def preprocess(example):
    example = tf.io.parse_single_example(example, feature_description)
    image, label = example["image"], example["label"]
    image = tf.io.parse_tensor(image, out_type=tf.float64)
    return image, label

batch_size = 32
dataset = tf.data.TFRecordDataset("data/train.tfrecord")\
                 .map(preprocess).batch(batch_size).prefetch(1)

但是,我有以下简单的 Keras 模型:

model = tf.keras.models.Sequential()
model.add(tf.keras.layers.Flatten(input_shape=image_shape))
model.add(tf.keras.layers.Dense(10, activation="softmax"))
model.compile(loss="sparse_categorical_crossentropy", optimizer="sgd", metrics=["accuracy"])

每当我尝试用数据集拟合或预测这个模型时

model.fit(dataset)
model.predict(dataset)

我收到以下错误:

ValueError: Input 0 of layer sequential is incompatible with the layer: its rank is undefined, but the layer requires a defined rank.

奇怪的是,如果我改为通过 tf.data.Dataset.from_tensor_slices(images) 创建一个等效数据集,尽管它产生完全相同的项目,但不会发生错误。

模型需要推断单个输入形状。但是 preprocess 解析任何形状的序列化图像张量,这是在记录流式传输时即时完成的,因此不可能推断出所有数据的输入形状。

这很容易通过添加断言张量形状的 TF 函数来解决,tf.ensure_shape:

def preprocess(example):
    example = tf.io.parse_single_example(example, feature_description)
    image, label = example["image"], example["label"]
    image = tf.io.parse_tensor(image, out_type=tf.float64)
    image = tf.ensure_shape(image, image_shape)    # THE FIX
    return image, label