将 tensorflow tf.data.Dataset FlatMapDataset 转换为 TensorSliceDataset

Convert a tensorflow tf.data.Dataset FlatMapDataset to TensorSliceDataset

我想将 tf.Strings 的列表传递给 .map(_parse_function) 函数。

 def _parse_function(self, img_path):
        img_str = tf.read_file(img_path)
        img_decode = tf.image.decode_jpeg(img_str, channels=3)
        img_decode = tf.divide(tf.cast(img_decode , tf.float32),255)
        return img_decode

tf.data.DatasetTensorSliceDataset 类型时,

dataset_from_slices = tf.data.Dataset.from_tensor_slices((tensor_with_filenames))

我可以简单地做 dataset_from_slices.map(_parse_function),有效。

但是,dataset_from_generator = tf.data.Dataset.from_generator(...) returns 一个 DatasetFlatMapDataset 类型的一个实例并且 dataset_from_generator.map(_parse_function) 给出了以下错误:

InvalidArgumentError: Input filename tensor must be scalar, but had shape: [32]

如果我将第一行更改为:

img_str = tf.read_file(img_path[0])

这也有效,但我只得到第一张图片,这不是我要找的。有什么建议吗?

听起来你的 dataset_from_generator 的元素是批处理的。最简单的补救措施是使用 tf.contrib.data.unbatch() 将它们转换回单个元素:

# Each element is a vector of strings.
dataset_from_generator = tf.data.Dataset.from_generator(...)

# Converts each vector of strings into multiple individual elements.
dataset = dataset_from_generator.apply(tf.contrib.data.unbatch())

dataset = dataset.map(_parse_function)