为什么相同的 Tensorflow 模型适用于数组列表但不适用于 tf.data.Dataset unbatched?

Why does the same Tensorflow model work with a list of arrays but doesn't work with tf.data.Dataset unbatched?

我有以下简单设置:

import tensorflow as tf

def make_mymodel():
  class MyModel(tf.keras.models.Model):

    def __init__(self):
      super(MyModel, self).__init__()

      self.model = tf.keras.Sequential([
        tf.keras.layers.Input(shape=(1, 2)),
        tf.keras.layers.Dense(1)
      ])

    def call(self, x):
      return self.model(x)

  mymodel = MyModel()

  return mymodel


model = make_mymodel()

X = [[[1, 1]],
     [[2, 2]],
     [[10, 10]],
     [[20, 20]],
     [[50, 50]]]
y = [1, 2, 10, 20, 50]

# ds_n_X = tf.data.Dataset.from_tensor_slices(X)
# ds_n_Y = tf.data.Dataset.from_tensor_slices(y)
# ds = tf.data.Dataset.zip((ds_n_X, ds_n_Y))
#
# for input, label in ds:
#   print(input.numpy(), label.numpy())

loss_fn = tf.keras.losses.BinaryCrossentropy(from_logits=False)

model.build((1, 2))
model.compile(optimizer='adam',
              loss=loss_fn)
model.summary()

model.fit(X, y, epochs=10)

print(model.predict([
  [[25, 25]]
]))

这很好用(尽管我得到了奇怪的预测),但是当我取消注释 ds 行并将 model.fit(X, y, epochs=10) 更改为 model.fit(ds, epochs=10) 时,我收到以下错误:

Traceback (most recent call last):
  File "example_dataset.py", line 51, in <module>
    model.fit(ds, epochs=10)
  ...


    ValueError: slice index 0 of dimension 0 out of bounds. for '{{node strided_slice}} = StridedSlice[Index=DT_INT32, T=DT_INT32, begin_mask=0, ellipsis_mask=0, end_mask=0, new_axis_mask=0, shrink_axis_mask=1](Shape, strided_slice/stack, strided_slice/stack_1, strided_slice/stack_2)' with input shapes: [0], [1], [1], [1] and with computed input tensors: input[1] = <0>, input[2] = <1>, input[3] = <1>.

当我 运行 model.fit(ds.batch(2), epochs=10) 时错误得到解决(我向数据集添加了 batch 指令)。

我希望能够交替使用数组列表和 tf.data.Dataset,但出于某种原因,我需要向数据集添加批次维度才能使用 tf.data.Dataset。这是预期的行为还是我在概念上遗漏了什么?

因为模型期望输入为 (batch_dim, input_dim)。因此,对于您的数据,模型的每个输入都应该像 (None, 1, 2).

让我们按数组和数据集探索数据的维度。当您将输入定义为数组时,形状为:

>>> print(np.array(X).shape)
(5, 1, 2)

它符合模型的预期。但是当您使用数组定义数据集时,形状为:

>>> for input, label in ds.take(1):
        print(input.numpy().shape)
(1, 2)

这与模型的预期不兼容,如果我们对数据进行批处理:

>>> ds = ds.batch(1)
>>> for input, label in ds.take(1):
        print(input.numpy().shape)
(1, 1, 2)

然后,将数据集传递给model.fit()就可以了。