使用 tf.data.Dataset() 向量化映射时出错
Error when vectorizing mapping with tf.data.Dataset()
我有一个通过 tf.data.Dataset.list_files()
检索到的图像数据集。
在我的 .map()
函数中,我读取并解码图像,如下所示:
def map_function(filepath):
image = tf.io.read_file(filename=filepath)
image = tf.image.decode_jpeg(image, channels=3)
image = tf.image.convert_image_dtype(image, tf.float32)
image = tf.image.resize(image, [IMAGE_WIDTH, IMAGE_HEIGHT])
return image
如果我使用(下面这个有效)
dataset = tf.data.Dataset.list_files(file_pattern=...)
dataset = dataset.map(map_function)
for image in dataset.as_numpy_iterator():
#Correctly outputs the numpy array, no error is displayed/encountered
print(image)
但是,如果我使用(下面会抛出错误):
dataset = tf.data.Dataset.list_files(file_pattern=...)
dataset = dataset.batch(32).map(map_function)
for image in dataset.as_numpy_iterator():
#Error is displayed
print(image)
ValueError: Shape must be rank 0 but is rank 1 for 'ReadFile' (op:
'ReadFile') with input shapes: [?].
现在,根据这个:https://www.tensorflow.org/guide/data_performance#vectorizing_mapping,代码应该不会失败并且应该优化预处理步骤(批处理与一次性处理)。
我的代码哪里出错了?
*** 如果我使用 map().batch()
它工作正常
错误发生是因为 map_function
需要未批处理的元素,但在第二个示例中你给了它批处理的元素。
https://www.tensorflow.org/guide/data_performance 中的示例很棘手,它定义了一个 increment
函数,该函数可以应用于批处理和非批处理元素,因为将 1 添加到 [1, 2, 3] 这样的批处理元素将导致 [2, 3, 4].
def increment(x):
return x+1
要使用向量化,您需要编写一个 vectorized_map_function
,它接受一个未批处理元素的向量,将 map 函数应用于向量中的每个元素,然后 returns 一个向量结果。
但在您的情况下,我认为向量化不会产生明显的影响,因为读取和解码文件的成本远高于调用函数的开销。当 map 函数非常便宜时,向量化最有影响力,以至于函数调用所花费的时间与在 map 函数中进行实际工作所花费的时间相当。
我有一个通过 tf.data.Dataset.list_files()
检索到的图像数据集。
在我的 .map()
函数中,我读取并解码图像,如下所示:
def map_function(filepath):
image = tf.io.read_file(filename=filepath)
image = tf.image.decode_jpeg(image, channels=3)
image = tf.image.convert_image_dtype(image, tf.float32)
image = tf.image.resize(image, [IMAGE_WIDTH, IMAGE_HEIGHT])
return image
如果我使用(下面这个有效)
dataset = tf.data.Dataset.list_files(file_pattern=...)
dataset = dataset.map(map_function)
for image in dataset.as_numpy_iterator():
#Correctly outputs the numpy array, no error is displayed/encountered
print(image)
但是,如果我使用(下面会抛出错误):
dataset = tf.data.Dataset.list_files(file_pattern=...)
dataset = dataset.batch(32).map(map_function)
for image in dataset.as_numpy_iterator():
#Error is displayed
print(image)
ValueError: Shape must be rank 0 but is rank 1 for 'ReadFile' (op: 'ReadFile') with input shapes: [?].
现在,根据这个:https://www.tensorflow.org/guide/data_performance#vectorizing_mapping,代码应该不会失败并且应该优化预处理步骤(批处理与一次性处理)。
我的代码哪里出错了?
*** 如果我使用 map().batch()
它工作正常
错误发生是因为 map_function
需要未批处理的元素,但在第二个示例中你给了它批处理的元素。
https://www.tensorflow.org/guide/data_performance 中的示例很棘手,它定义了一个 increment
函数,该函数可以应用于批处理和非批处理元素,因为将 1 添加到 [1, 2, 3] 这样的批处理元素将导致 [2, 3, 4].
def increment(x):
return x+1
要使用向量化,您需要编写一个 vectorized_map_function
,它接受一个未批处理元素的向量,将 map 函数应用于向量中的每个元素,然后 returns 一个向量结果。
但在您的情况下,我认为向量化不会产生明显的影响,因为读取和解码文件的成本远高于调用函数的开销。当 map 函数非常便宜时,向量化最有影响力,以至于函数调用所花费的时间与在 map 函数中进行实际工作所花费的时间相当。