如何在 Tensorflow 中使用 train_date.take(1)
How to use train_date.take(1) with Tensorflow
我正在使用 tensorflow 进行测试。我将我的数据集放入两个文件夹中。我为 train_data
配置了 batch_size
、height
和 width
,但是我无法使用 matplotlib
看到它们或在模型中使用它们。
#Import dataset
import pathlib
import os
data_dir = pathlib.Path(r'C:\Users\vion1\Ele\Engie\Exercices\DL\Pikachu\dataset')
image_count = len(list(data_dir.glob('*/*')))
print(image_count)
#374
batch_size = 32
img_height = 256
img_width = 256
train_data = tf.keras.preprocessing.image_dataset_from_directory(
data_dir,
validation_split=0.2,
subset="training",
seed=42,
image_size=(img_height, img_width),
batch_size=batch_size,
)
class_names = train_data.class_names
print(train_data)
#Found 374 files belonging to 2 classes.
#Using 300 files for training.
#<BatchDataset shapes: ((None, 256, 256, 3), (None,)), types: (tf.float32, tf.int32)>
plt.figure(figsize=(10, 10))
for images, labels in train_data.take(1):
for i in range(3):
ax = plt.subplot(1, 3, i + 1)
plt.imshow(images[i].numpy().astype("uint8"))
plt.axis("off")
错误是:
InvalidArgumentError: Unknown image file format. One of JPEG, PNG, GIF, BMP required.
[[{{node decode_image/DecodeImage}}]] [Op:IteratorGetNext]
我认为 train_date.take(1)
不接受文件,但我不明白为什么以及如何修复它,知道吗?
您提到的代码看起来正确,失败的主要原因可能是根据错误,您的 tf.data.Dataset
中的一个或多个文件不属于任何提到的文件扩展名。
要检查损坏的文件,您可以参考以下代码。
在这里,我使用 document
中提到的示例数据集
import matplotlib.pyplot as plt
import numpy as np
import os
import PIL
import tensorflow as tf
from tensorflow import keras
import pathlib
dataset_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz"
data_dir = tf.keras.utils.get_file('flower_photos', origin=dataset_url, untar=True)
data_dir = pathlib.Path(data_dir)
roses = list(data_dir.glob('roses/*'))
现在,让我们检查 roses 目录中的唯一文件名。
file_names = [str(i) for i in roses]
unique_files = set(i.split('.')[-1] for i in file_names)
print(unique_files)
Output:
{'jpg'}
如果您在输出目录中得到允许的文件类型以外的任何文件类型,则需要重新检查您的数据。
否则,您可以按照 this 文档执行相同的过程。
我正在使用 tensorflow 进行测试。我将我的数据集放入两个文件夹中。我为 train_data
配置了 batch_size
、height
和 width
,但是我无法使用 matplotlib
看到它们或在模型中使用它们。
#Import dataset
import pathlib
import os
data_dir = pathlib.Path(r'C:\Users\vion1\Ele\Engie\Exercices\DL\Pikachu\dataset')
image_count = len(list(data_dir.glob('*/*')))
print(image_count)
#374
batch_size = 32
img_height = 256
img_width = 256
train_data = tf.keras.preprocessing.image_dataset_from_directory(
data_dir,
validation_split=0.2,
subset="training",
seed=42,
image_size=(img_height, img_width),
batch_size=batch_size,
)
class_names = train_data.class_names
print(train_data)
#Found 374 files belonging to 2 classes.
#Using 300 files for training.
#<BatchDataset shapes: ((None, 256, 256, 3), (None,)), types: (tf.float32, tf.int32)>
plt.figure(figsize=(10, 10))
for images, labels in train_data.take(1):
for i in range(3):
ax = plt.subplot(1, 3, i + 1)
plt.imshow(images[i].numpy().astype("uint8"))
plt.axis("off")
错误是:
InvalidArgumentError: Unknown image file format. One of JPEG, PNG, GIF, BMP required.
[[{{node decode_image/DecodeImage}}]] [Op:IteratorGetNext]
我认为 train_date.take(1)
不接受文件,但我不明白为什么以及如何修复它,知道吗?
您提到的代码看起来正确,失败的主要原因可能是根据错误,您的 tf.data.Dataset
中的一个或多个文件不属于任何提到的文件扩展名。
要检查损坏的文件,您可以参考以下代码。
在这里,我使用 document
import matplotlib.pyplot as plt
import numpy as np
import os
import PIL
import tensorflow as tf
from tensorflow import keras
import pathlib
dataset_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz"
data_dir = tf.keras.utils.get_file('flower_photos', origin=dataset_url, untar=True)
data_dir = pathlib.Path(data_dir)
roses = list(data_dir.glob('roses/*'))
现在,让我们检查 roses 目录中的唯一文件名。
file_names = [str(i) for i in roses]
unique_files = set(i.split('.')[-1] for i in file_names)
print(unique_files)
Output:
{'jpg'}
如果您在输出目录中得到允许的文件类型以外的任何文件类型,则需要重新检查您的数据。 否则,您可以按照 this 文档执行相同的过程。