扁平化张量流数据集中的图像元组

Flattening tuple of images in tensorflow dataset

我有一个从 tfrecords 读取的三重图像数据集,我已使用以下代码将其转换为数据集

    def parse_dataset(record):
        def convert_raw_to_image_tensor(raw):
            raw = tf.io.decode_base64(raw)
            image_shape = tf.stack([299, 299, 3])
            decoded = tf.io.decode_image(raw, channels=3, 
                                dtype=tf.uint8, expand_animations=False)
            decoded = tf.cast(decoded, tf.float32)
            decoded = tf.reshape(decoded, image_shape)
            decoded = tf.math.divide(decoded, 255.)
            return decoded

        features = {
            'n': tf.io.FixedLenFeature([], tf.string),
            'p': tf.io.FixedLenFeature([], tf.string),
            'q': tf.io.FixedLenFeature([], tf.string)
        }
        sample = tf.io.parse_single_example(record, features)
        neg_image = sample['n']
        pos_image = sample['p']
        query_image = sample['q']

        neg_decoded = convert_raw_to_image_tensor(neg_image)
        pos_decoded = convert_raw_to_image_tensor(pos_image)
        query_decoded = convert_raw_to_image_tensor(query_image)
        return (neg_decoded, pos_decoded, query_decoded)

    record_dataset = tf.data.TFRecordDataset(filenames=path_dataset, num_parallel_reads=4)
    record_dataset = record_dataset.map(parse_dataset)

此结果数据集的形状是

<MapDataset shapes: ((299, 299, 3), (299, 299, 3), (299, 299, 3)), types: (tf.float32, tf.float32, tf.float32)>

我认为这意味着每个条目包含 3 张图像(我通过遍历数据集并打印第 1、第 2 和第 3 个元素来确认)。我想把它展平,所以我得到了一个不包含任何元组但只包含一个平面图像列表的数据集。我试过使用 flat_map 但这只是将图像转换为 (299, 3) 我试过遍历数据集,将每个图像附加到列表,然后调用 convert_to_tensor_slices 但这真的很低效.

我读过 this question 但似乎没有帮助。

顺便说一句,这是我试过的flat_map代码

record_dataset = record_dataset.flat_map(lambda *x: tf.data.Dataset.from_tensor_slices(x))

并且生成的数据集具有此形状

<FlatMapDataset shapes: ((299, 3), (299, 3), (299, 3)), types: (tf.float32, tf.float32, tf.float32)>

我认为你只是错误地解包元组。

应该这样做:

def flatten(*x):
  return tf.data.Dataset.from_tensor_slices([i for i in x])

flattened = record_dataset.flat_map(flatten)

这样:

for i in flattened:
  print(i.shape)

给出:

(299, 299, 3)
(299, 299, 3)
(299, 299, 3)
(299, 299, 3)
...

符合预期