Connecting BatchDataset with Keras VGG16 preprocess_input
I am using tf.keras.preprocessing.image_dataset_from_directory to get a BatchDataset, where the dataset has 10 classes.
I am trying to integrate this BatchDataset with a Keras VGG16 (docs) network. From the docs:
Note: each Keras Application expects a specific kind of input preprocessing. For VGG16, call tf.keras.applications.vgg16.preprocess_input on your inputs before passing them to the model.
However, I am struggling to get this preprocess_input working with a BatchDataset. Can you please help me figure out how to connect these two dots?
Please see the code below:
train_ds = tf.keras.preprocessing.image_dataset_from_directory(train_data_dir, image_size=(224, 224))
train_ds = tf.keras.applications.vgg16.preprocess_input(train_ds)
This raises TypeError: 'BatchDataset' object is not subscriptable:
Traceback (most recent call last):
  ...
  File "/path/to/venv/lib/python3.10/site-packages/keras/applications/vgg16.py", line 232, in preprocess_input
    return imagenet_utils.preprocess_input(
  File "/path/to/venv/lib/python3.10/site-packages/keras/applications/imagenet_utils.py", line 117, in preprocess_input
    return _preprocess_symbolic_input(
  File "/path/to/venv/lib/python3.10/site-packages/keras/applications/imagenet_utils.py", line 278, in _preprocess_symbolic_input
    x = x[..., ::-1]
TypeError: 'BatchDataset' object is not subscriptable
An answer to TypeError: 'DatasetV1Adapter' object is not subscriptable (linked from BatchDataset not subscriptable when trying to format Python dictionary as table) suggests using:
train_ds = tf.keras.applications.vgg16.preprocess_input(
    list(train_ds.as_numpy_iterator())
)
However, this also fails:
Traceback (most recent call last):
  ...
  File "/path/to/venv/lib/python3.10/site-packages/keras/applications/vgg16.py", line 232, in preprocess_input
    return imagenet_utils.preprocess_input(
  File "/path/to/venv/lib/python3.10/site-packages/keras/applications/imagenet_utils.py", line 117, in preprocess_input
    return _preprocess_symbolic_input(
  File "/path/to/venv/lib/python3.10/site-packages/keras/applications/imagenet_utils.py", line 278, in _preprocess_symbolic_input
    x = x[..., ::-1]
TypeError: list indices must be integers or slices, not tuple
This is all using Python==3.10.3 and tensorflow==2.8.0.
How can I get this working? Thank you in advance.
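Okay, I figured it out. I needed to pass a tf.Tensor, not a tf.data.Dataset. A Tensor can be obtained by iterating over the Dataset: each element is an (images, labels) pair of batched tensors.
This can be done in a few ways: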
train_ds = tf.keras.preprocessing.image_dataset_from_directory(...)
# Option 1: preprocess only the first batch of images
batch_images = next(iter(train_ds))[0]
preprocessed_images = tf.keras.applications.vgg16.preprocess_input(batch_images)
# Option 2: preprocess every batch in turn
for batch_images, batch_labels in train_ds:
    preprocessed_images = tf.keras.applications.vgg16.preprocess_input(batch_images)
If you convert option 2 into a generator, it can be passed directly into the downstream model.fit. Cheers!
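For reference, here is a minimal sketch of the generator idea from option 2, plus an alternative that keeps everything as a tf.data.Dataset by using Dataset.map. The helper name preprocessed_batches and the random stand-in data are assumptions for illustration only; in practice train_ds would come from image_dataset_from_directory as in the question.

```python
import numpy as np
import tensorflow as tf

preprocess_input = tf.keras.applications.vgg16.preprocess_input


def preprocessed_batches(dataset):
    """Hypothetical helper: yield (preprocessed_images, labels) for each batch."""
    for batch_images, batch_labels in dataset:
        yield preprocess_input(batch_images), batch_labels


# Stand-in for the output of image_dataset_from_directory (an assumption):
# 8 random RGB images of size 224x224 with integer labels for 10 classes.
images = np.random.uniform(0, 255, size=(8, 224, 224, 3)).astype("float32")
labels = np.random.randint(0, 10, size=(8,)).astype("int32")
train_ds = tf.data.Dataset.from_tensor_slices((images, labels)).batch(4)

# Generator route (option 2 wrapped in a generator function).
first_images, first_labels = next(preprocessed_batches(train_ds))

# Alternative route: apply preprocess_input per batch with Dataset.map,
# so the result is still a tf.data.Dataset.
mapped_ds = train_ds.map(lambda x, y: (preprocess_input(x), y))
mapped_images, mapped_labels = next(iter(mapped_ds))
```

Note that a plain Python generator is exhausted after one pass over the data, so for multi-epoch training the mapped dataset may be the more convenient of the two to hand to model.fit.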