TensorFlow 2 guide: "Consuming sets of files" throws slice index out of bounds

Running the code from the guide's "Consuming sets of files" section (see below) throws "slice index -1 of dimension 0 out of bounds" (see the output below).

Has anyone gotten this code to work?

from __future__ import absolute_import, division, print_function, unicode_literals
import tensorflow as tf
import pathlib
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
def process_path(file_path):
  label = tf.strings.split(file_path, '/')[-2]
  return tf.io.read_file(file_path), label
flowers_root = tf.keras.utils.get_file(
    'flower_photos',
    'https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',
    untar=True)
flowers_root = pathlib.Path(flowers_root)
for item in flowers_root.glob("*"):
  print(item.name)
list_ds = tf.data.Dataset.list_files(str(flowers_root/'*/*'))
for f in list_ds.take(5):
  print(f.numpy())
labeled_ds = list_ds.map(process_path)
for image_raw, label_text in labeled_ds.take(1):  # this line throws
  print(repr(image_raw.numpy()[:100]))
  print()
  print(label_text.numpy())


$ py consume.py
daisy
dandelion
LICENSE.txt
roses
sunflowers
tulips
b'C:\Users\ray\.keras\datasets\flower_photos\dandelion\5024965767_230f140d60_n.jpg'
b'C:\Users\ray\.keras\datasets\flower_photos\dandelion\6012046444_fd80afb63a_n.jpg'
b'C:\Users\ray\.keras\datasets\flower_photos\roses\2677417735_a697052d2d_n.jpg'
b'C:\Users\ray\.keras\datasets\flower_photos\tulips\8677713853_1312f65e71.jpg'
b'C:\Users\ray\.keras\datasets\flower_photos\tulips\14235021006_dd001ea8ed_n.jpg'
2019-11-16 18:06:05.468670: W tensorflow/core/framework/op_kernel.cc:1622] OP_REQUIRES failed at strided_slice_op.cc:108 : Invalid argument: slice index -1 of dimension 0 out of bounds.
2019-11-16 18:06:05.481203: W tensorflow/core/framework/op_kernel.cc:1622] OP_REQUIRES failed at iterator_ops.cc:929 : Invalid argument: {{function_node __inference_Dataset_map_process_path_106}} slice index -1 of dimension 0 out of bounds.
         [[{{node strided_slice}}]]
Traceback (most recent call last):
  File "consume.py", line 21, in <module>
    for image_raw, label_text in labeled_ds.take(1):
  File "d:\Anaconda3\lib\site-packages\tensorflow_core\python\data\ops\iterator_ops.py", line 622, in __next__
    return self.next()
  File "d:\Anaconda3\lib\site-packages\tensorflow_core\python\data\ops\iterator_ops.py", line 666, in next
    return self._next_internal()
  File "d:\Anaconda3\lib\site-packages\tensorflow_core\python\data\ops\iterator_ops.py", line 651, in _next_internal
    output_shapes=self._flat_output_shapes)
  File "d:\Anaconda3\lib\site-packages\tensorflow_core\python\ops\gen_dataset_ops.py", line 2672, in iterator_get_next_sync
    _six.raise_from(_core._status_to_exception(e.code, message), None)
  File "<string>", line 3, in raise_from
tensorflow.python.framework.errors_impl.InvalidArgumentError: {{function_node __inference_Dataset_map_process_path_106}} slice index -1 of dimension 0 out of bounds.
         [[{{node strided_slice}}]] [Op:IteratorGetNextSync]

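I think this is because you are running on Windows. As the printed file names show, list_files returns paths with backslash separators there, so splitting on '/' yields a one-element tensor (the whole path) and index -2 is out of bounds. Try changing the following line

label = tf.strings.split(file_path, '/')[-2]

to

label = tf.strings.split(file_path, '\\')[-2]

The backslash must be escaped: '\' on its own is an unterminated Python string literal. To keep the code working on Linux and macOS as well, split on os.sep (after import os) instead of a hard-coded separator.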
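A minimal sketch of a platform-independent process_path, assuming TensorFlow 2.x in eager mode. Splitting on os.sep is one possible fix, not the guide's own code, and get_label is a hypothetical helper name introduced here. Because os.sep is '\\' on Windows and '/' elsewhere, the same code handles the backslash paths that list_files returns on Windows and the forward-slash paths it returns on Linux/macOS.

```python
import os
import tensorflow as tf

def get_label(file_path):
    # Split on the separator of the OS the script runs on ('\\' on Windows,
    # '/' on Linux/macOS) instead of a hard-coded '/'.
    parts = tf.strings.split(file_path, os.sep)
    # The class label is the name of the file's parent directory,
    # e.g. .../flower_photos/daisy/xyz.jpg -> b'daisy'.
    return parts[-2]

def process_path(file_path):
    # Return the raw (undecoded) file bytes together with the label.
    return tf.io.read_file(file_path), get_label(file_path)
```

With this definition, labeled_ds = list_ds.map(process_path) from the question can stay unchanged.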