如何访问 tf.data.Dataset.list_files() 收集的文件名?

How can I access the filenames gathered by tf.data.Dataset.list_files()?

我正在使用

file_data = tf.data.Dataset.list_files("../*.png")

收集用于在 TensorFlow 中训练的图像文件,但想要访问收集的文件名列表以便我可以执行标签查找。

调用 sess.run([file_data]) 失败:

TypeError: Fetch argument <TensorSliceDataset shapes: (), types: tf.string> has invalid type <class 'tensorflow.python.data.ops.dataset_ops.TensorSliceDataset'>, must be a string or Tensor. (Can not convert a TensorSliceDataset into a Tensor or Operation.)

还有其他方法可以使用吗?

通过一些额外的实验,我找到了解决这个问题的方法:

首先,将 Dataset 变成一个迭代器:

iterator_helper = file_data.make_one_shot_iterator()

然后,遍历 tf Session 中的元素:

with tf.Session() as sess:
    filename_temp = iterator_helper.get_next()
    print(sess.run[filename_temp])

Dataset.list_files() API 使用 tf.matching_files() op 列出匹配给定模式的文件。您还可以使用该操作将文件列表作为 tf.Tensor 获取,并将其直接传递给 sess.run():

filenames_as_tensor = tf.matching_files("../*.png")
filenames_as_array = sess.run(filenames_as_tensor)

for filename in filenames_as_array:
  print(filename)

这是我在 Tensorflow 2

中的做法
def load_image_train(image_file):
  """ 
      a function to load image and return
      the image and it's address
  """
  my_image = load_image_func(image_file)    
  return my_image, image_file

然后使用tf.data.Dataset.list_files加载文件夹中的文件列表:

PATH = "path_to_dataset_folder"
train_dataset_names = tf.data.Dataset.list_files(os.path.join(PATH , 'train/*.jpg'))  

最后映射它们,这样你就可以将“文件地址”和“数据”都作为张量:

train_dataset = train_dataset_names.map(load_image_train,
                         num_parallel_calls=tf.data.AUTOTUNE)

然后你可以根据需要将它们分开,并使用文件名作为标签或其他任何东西。