使用 tf.data.Dataset() 向量化映射时出错

Error when vectorizing mapping with tf.data.Dataset()

我有一个通过 tf.data.Dataset.list_files() 检索到的图像数据集。

在我的 .map() 函数中,我读取并解码图像,如下所示:

def map_function(filepath):
    image = tf.io.read_file(filename=filepath)
    image = tf.image.decode_jpeg(image, channels=3)
    image = tf.image.convert_image_dtype(image, tf.float32)
    image = tf.image.resize(image, [IMAGE_WIDTH, IMAGE_HEIGHT])
    return image

如果我使用(下面这个有效)

 dataset = tf.data.Dataset.list_files(file_pattern=...)
 dataset = dataset.map(map_function)
 for image in dataset.as_numpy_iterator():
    #Correctly outputs the numpy array, no error is displayed/encountered
    print(image)

但是,如果我使用(下面会抛出错误):

  dataset = tf.data.Dataset.list_files(file_pattern=...)
  dataset = dataset.batch(32).map(map_function)
  for image in dataset.as_numpy_iterator():
    #Error is displayed 
      print(image)

ValueError: Shape must be rank 0 but is rank 1 for 'ReadFile' (op: 'ReadFile') with input shapes: [?].

现在,根据这个:https://www.tensorflow.org/guide/data_performance#vectorizing_mapping,代码应该不会失败并且应该优化预处理步骤(批处理与一次性处理)。

我的代码哪里出错了?

*** 如果我使用 map().batch() 它工作正常

错误发生是因为 map_function 需要未批处理的元素,但在第二个示例中你给了它批处理的元素。

https://www.tensorflow.org/guide/data_performance 中的示例很棘手,它定义了一个 increment 函数,该函数可以应用于批处理和非批处理元素,因为将 1 添加到 [1, 2, 3] 这样的批处理元素将导致 [2, 3, 4].

def increment(x):
    return x+1

要使用向量化,您需要编写一个 vectorized_map_function,它接受一个未批处理元素的向量,将 map 函数应用于向量中的每个元素,然后 returns 一个向量结果。

但在您的情况下,我认为向量化不会产生明显的影响,因为读取和解码文件的成本远高于调用函数的开销。当 map 函数非常便宜时,向量化最有影响力,以至于函数调用所花费的时间与在 map 函数中进行实际工作所花费的时间相当。