如何在 Tensorflow 中使用 train_date.take(1)

How to use train_date.take(1) with Tensorflow

我正在使用 tensorflow 进行测试。我将我的数据集放入两个文件夹中。我为 train_data 配置了 batch_sizeheightwidth,但是我无法使用 matplotlib 看到它们或在模型中使用它们。

#Import dataset
import pathlib
import os

data_dir = pathlib.Path(r'C:\Users\vion1\Ele\Engie\Exercices\DL\Pikachu\dataset')
image_count = len(list(data_dir.glob('*/*')))
print(image_count)
#374

batch_size = 32
img_height = 256
img_width = 256

train_data = tf.keras.preprocessing.image_dataset_from_directory(
  data_dir,
  validation_split=0.2,
  subset="training",
  seed=42,
  image_size=(img_height, img_width),
  batch_size=batch_size,
  )

class_names = train_data.class_names
print(train_data)
#Found 374 files belonging to 2 classes.
#Using 300 files for training.
#<BatchDataset shapes: ((None, 256, 256, 3), (None,)), types: (tf.float32, tf.int32)>

plt.figure(figsize=(10, 10))
for images, labels in train_data.take(1):
  for i in range(3):
    ax = plt.subplot(1, 3, i + 1)
    plt.imshow(images[i].numpy().astype("uint8"))
    plt.axis("off")

错误是:

InvalidArgumentError: Unknown image file format. One of JPEG, PNG, GIF, BMP required.
     [[{{node decode_image/DecodeImage}}]] [Op:IteratorGetNext]

我认为 train_date.take(1) 不接受文件,但我不明白为什么以及如何修复它,知道吗?

您提到的代码看起来正确,失败的主要原因可能是根据错误,您的 tf.data.Dataset 中的一个或多个文件不属于任何提到的文件扩展名。 要检查损坏的文件,您可以参考以下代码。 在这里,我使用 document

中提到的示例数据集
import matplotlib.pyplot as plt
import numpy as np
import os
import PIL
import tensorflow as tf

from tensorflow import keras

import pathlib
dataset_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz"
data_dir = tf.keras.utils.get_file('flower_photos', origin=dataset_url, untar=True)
data_dir = pathlib.Path(data_dir)

roses = list(data_dir.glob('roses/*'))

现在,让我们检查 roses 目录中的唯一文件名。

file_names = [str(i) for i in roses]
unique_files = set(i.split('.')[-1] for i in file_names)
print(unique_files)

Output:
{'jpg'}

如果您在输出目录中得到允许的文件类型以外的任何文件类型,则需要重新检查您的数据。 否则,您可以按照 this 文档执行相同的过程。