从 tensorflow-dataset 获取特征时出错
Error when getting features from tensorflow-dataset
我在尝试加载 Caltech tensorflow-dataset 时遇到错误。我使用的是 tensorflow-datasets GitHub
中的标准代码
错误是这样的:
tensorflow.python.framework.errors_impl.InvalidArgumentError: Cannot batch tensors with different shapes in component 0. First element had shape [204,300,3] and element 1 had shape [153,300,3]. [Op:IteratorGetNextSync]
错误指向行for features in ds_train.take(1)
代码:
ds_train, ds_test = tfds.load(name="caltech101", split=["train", "test"])
ds_train = ds_train.shuffle(1000).batch(128).prefetch(10)
for features in ds_train.take(1):
image, label = features["image"], features["label"]
问题出在数据集包含可变大小的图像这一事实(请参阅数据集描述 here)。 Tensorflow 只能将具有相同形状的事物组合在一起,因此您首先需要将图像重塑为常见形状(例如,网络的输入形状)或相应地填充它们。
如果要调整大小,请使用 tf.image.resize_images:
def preprocess(features, label):
features['image'] = tf.image.resize_images(features['image'], YOUR_TARGET_SIZE)
# Other possible transformations needed (e.g., converting to float, normalizing to [0,1]
return features, label
如果您想要填充,请使用 tf.image.pad_to_bounding_box(只需在上面的 preprocess
函数中替换它并根据需要调整参数)。
通常,对于我所知道的大多数网络,都会使用调整大小。
最后,将函数映射到您的数据集:
ds_train = (ds_train
.map(prepocess)
.shuffle(1000)
.batch(128)
.prefetch(10))
注意:错误代码中的变量形状来自shuffle
调用。
我在尝试加载 Caltech tensorflow-dataset 时遇到错误。我使用的是 tensorflow-datasets GitHub
中的标准代码错误是这样的:
tensorflow.python.framework.errors_impl.InvalidArgumentError: Cannot batch tensors with different shapes in component 0. First element had shape [204,300,3] and element 1 had shape [153,300,3]. [Op:IteratorGetNextSync]
错误指向行for features in ds_train.take(1)
代码:
ds_train, ds_test = tfds.load(name="caltech101", split=["train", "test"])
ds_train = ds_train.shuffle(1000).batch(128).prefetch(10)
for features in ds_train.take(1):
image, label = features["image"], features["label"]
问题出在数据集包含可变大小的图像这一事实(请参阅数据集描述 here)。 Tensorflow 只能将具有相同形状的事物组合在一起,因此您首先需要将图像重塑为常见形状(例如,网络的输入形状)或相应地填充它们。
如果要调整大小,请使用 tf.image.resize_images:
def preprocess(features, label):
features['image'] = tf.image.resize_images(features['image'], YOUR_TARGET_SIZE)
# Other possible transformations needed (e.g., converting to float, normalizing to [0,1]
return features, label
如果您想要填充,请使用 tf.image.pad_to_bounding_box(只需在上面的 preprocess
函数中替换它并根据需要调整参数)。
通常,对于我所知道的大多数网络,都会使用调整大小。
最后,将函数映射到您的数据集:
ds_train = (ds_train
.map(prepocess)
.shuffle(1000)
.batch(128)
.prefetch(10))
注意:错误代码中的变量形状来自shuffle
调用。