Tensorflow MirroredStrategy 将二维减半,但对象中的形状保持正确
Tensorflow MirroredStrategy halves the 2nd dimension, though shape in the object remains right
我最近尝试使用 MirroredStrategy 进行训练。
相关代码为:
dataset = tf.data.Dataset.from_tensor_slices((samples, feature))
model.fit(dataset,
batch_size=500, epochs=150,
#callbacks=[tensorboard_callback]
)
数据集打印为:
Out[1]: <TensorSliceDataset element_spec=(TensorSpec(shape=(16, 17, 10), dtype=tf.float64, name=None), TensorSpec(shape=(1000,), dtype=tf.float32, name=None))>
维度正确,
但我收到以下错误:
ValueError: Input 0 of layer "sequential" is incompatible with the layer: expected shape=(None, 16, 17, 10), found shape=(8, 17, 10)
这很奇怪,正如文档所说,该策略会将第一个维度减半而不是第二个维度,它应该沿第一个轴将数据集拆分为 2。
有人知道问题出在哪里吗?
旁注:变量“feature”的名称具有误导性,应该改为标签。
Keras API 虽然在很多方面都很棒,但也很老套,特别是 fit 严重超载。
在你的情况下,不要在 fit 中使用 batch_size arg...而是通过
批处理你的数据集
数据集 = dataset.batch(500).
查看 keras fit 文档:
batch_size Integer or None. Number of samples per gradient update. If unspecified, batch_size will default to 32. Do not specify the batch_size if your data is in the form of datasets, generators, or keras.utils.Sequence instances (since they generate batches).
我最近尝试使用 MirroredStrategy 进行训练。 相关代码为:
dataset = tf.data.Dataset.from_tensor_slices((samples, feature))
model.fit(dataset,
batch_size=500, epochs=150,
#callbacks=[tensorboard_callback]
)
数据集打印为:
Out[1]: <TensorSliceDataset element_spec=(TensorSpec(shape=(16, 17, 10), dtype=tf.float64, name=None), TensorSpec(shape=(1000,), dtype=tf.float32, name=None))>
维度正确, 但我收到以下错误:
ValueError: Input 0 of layer "sequential" is incompatible with the layer: expected shape=(None, 16, 17, 10), found shape=(8, 17, 10)
这很奇怪,正如文档所说,该策略会将第一个维度减半而不是第二个维度,它应该沿第一个轴将数据集拆分为 2。
有人知道问题出在哪里吗?
旁注:变量“feature”的名称具有误导性,应该改为标签。
Keras API 虽然在很多方面都很棒,但也很老套,特别是 fit 严重超载。
在你的情况下,不要在 fit 中使用 batch_size arg...而是通过
批处理你的数据集数据集 = dataset.batch(500).
查看 keras fit 文档:
batch_size Integer or None. Number of samples per gradient update. If unspecified, batch_size will default to 32. Do not specify the batch_size if your data is in the form of datasets, generators, or keras.utils.Sequence instances (since they generate batches).