Tensorflow tf.data.Dataset API, dataset unzip function?

In TensorFlow 1.12 there is the Dataset.zip function, documented here.

However, I'm wondering whether there is a dataset unzip function that can return the original two datasets.

# NOTE: The following examples use `{ ... }` to represent the
# contents of a dataset.
a = { 1, 2, 3 }
b = { 4, 5, 6 }
c = { (7, 8), (9, 10), (11, 12) }
d = { 13, 14 }

# The nested structure of the `datasets` argument determines the
# structure of elements in the resulting dataset.
Dataset.zip((a, b)) == { (1, 4), (2, 5), (3, 6) }
Dataset.zip((b, a)) == { (4, 1), (5, 2), (6, 3) }

# The `datasets` argument may contain an arbitrary number of
# datasets.
Dataset.zip((a, b, c)) == { (1, 4, (7, 8)),
                            (2, 5, (9, 10)),
                            (3, 6, (11, 12)) }

# The number of elements in the resulting dataset is the same as
# the size of the smallest dataset in `datasets`.
Dataset.zip((a, d)) == { (1, 13), (2, 14) }

I would like the following:

dataset = Dataset.zip((a, d))  # == { (1, 13), (2, 14) }
a, d = dataset.unzip()

My workaround is to just use map; I'm not sure whether there would be interest in syntactic sugar for an unzip function later on.

a = dataset.map(lambda a, b: a)
b = dataset.map(lambda a, b: b)

Building on Ouwen Huang's answer, this function seems to work for arbitrary datasets whose elements are dictionaries:

def split_datasets(dataset):
    # Assumes the dataset's element_spec is a dict of TensorSpecs,
    # e.g. a dataset built from a dict of features.
    tensors = {}
    names = list(dataset.element_spec.keys())
    for name in names:
        # Bind `name` as a default argument so each mapped pipeline
        # selects its own key.
        tensors[name] = dataset.map(lambda x, name=name: x[name])

    return tensors
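
For example, here is a minimal usage sketch; the "image" and "label" feature names are made up for illustration:

import tensorflow as tf

# Hypothetical dict-structured dataset; the "image" and "label"
# keys are illustrative only.
ds = tf.data.Dataset.from_tensor_slices({
    "image": [[1.0, 2.0], [3.0, 4.0]],
    "label": [0, 1],
})

parts = split_datasets(ds)
print(list(parts["label"].as_numpy_iterator()))  # [0, 1]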

I've written a more general unzip function for tf.data.Dataset pipelines that also handles the recursive case, where a pipeline has multiple levels of zipping.

import tensorflow as tf


def tfdata_unzip(
    tfdata: tf.data.Dataset,
    *,
    recursive: bool = False,
    eager_numpy: bool = False,
    num_parallel_calls: int = tf.data.AUTOTUNE,
):
    """
    Unzip a zipped tf.data pipeline.

    Args:
        tfdata: the :py:class:`tf.data.Dataset`
            to unzip.

        recursive: Set to ``True`` to recursively unzip
            multiple layers of zipped pipelines.
            Defaults to ``False``.

        eager_numpy: Set this to ``True`` to return
            Python lists of primitive types or
            :py:class:`numpy.array` objects. Defaults
            to ``False``.

        num_parallel_calls: The level of parallelism to
            each time we ``map()`` over a
            :py:class:`tf.data.Dataset`.

    Returns:
        A Python list of either :py:class:`tf.data.Dataset`
        objects or NumPy arrays.
    """
    if isinstance(tfdata.element_spec, tf.TensorSpec):
        if eager_numpy:
            return list(tfdata.as_numpy_iterator())
        return tfdata

    def tfdata_map(i: int) -> tf.data.Dataset:
        # Select the i-th component of each zipped element.
        return tfdata.map(
            lambda *cols: cols[i],
            deterministic=True,
            num_parallel_calls=num_parallel_calls,
        )

    if isinstance(tfdata.element_spec, tuple):
        num_columns = len(tfdata.element_spec)
        if recursive:
            return [
                tfdata_unzip(
                    tfdata_map(i),
                    recursive=recursive,
                    eager_numpy=eager_numpy,
                    num_parallel_calls=num_parallel_calls,
                )
                for i in range(num_columns)
            ]
        else:
            return [
                tfdata_map(i)
                for i in range(num_columns)
            ]

    raise ValueError(
        "Unknown tf.data.Dataset element_spec: " +
        str(tfdata.element_spec)
    )

Given these example datasets, tfdata_unzip() works as follows:

>>> import numpy as np

>>> baby = tf.data.Dataset.from_tensor_slices([
    np.array([1,2]),
    np.array([3,4]),
    np.array([5,6]),
])
>>> baby.element_spec
TensorSpec(shape=(2,), dtype=tf.int64, name=None)

>>> parent = tf.data.Dataset.zip((baby, baby))
>>> parent.element_spec
(TensorSpec(shape=(2,), dtype=tf.int64, name=None),
 TensorSpec(shape=(2,), dtype=tf.int64, name=None))

>>> grandparent = tf.data.Dataset.zip((parent, parent))
>>> grandparent.element_spec
((TensorSpec(shape=(2,), dtype=tf.int64, name=None),
  TensorSpec(shape=(2,), dtype=tf.int64, name=None)),
 (TensorSpec(shape=(2,), dtype=tf.int64, name=None),
  TensorSpec(shape=(2,), dtype=tf.int64, name=None)))

And this is what tfdata_unzip() returns for the baby, parent, and grandparent datasets above:

>>> tfdata_unzip(baby)
<TensorSliceDataset shapes: (2,), types: tf.int64>

>>> tfdata_unzip(parent)
[<ParallelMapDataset shapes: (2,), types: tf.int64>,
 <ParallelMapDataset shapes: (2,), types: tf.int64>]

>>> tfdata_unzip(grandparent)
[<ParallelMapDataset shapes: ((2,), (2,)), types: (tf.int64, tf.int64)>,
 <ParallelMapDataset shapes: ((2,), (2,)), types: (tf.int64, tf.int64)>]

>>> tfdata_unzip(grandparent, recursive=True)
[[<ParallelMapDataset shapes: (2,), types: tf.int64>,
  <ParallelMapDataset shapes: (2,), types: tf.int64>],
 [<ParallelMapDataset shapes: (2,), types: tf.int64>,
  <ParallelMapDataset shapes: (2,), types: tf.int64>]]

>>> tfdata_unzip(grandparent, recursive=True, eager_numpy=True)
[[[array([1, 2]), array([3, 4]), array([5, 6])],
  [array([1, 2]), array([3, 4]), array([5, 6])]],
 [[array([1, 2]), array([3, 4]), array([5, 6])],
  [array([1, 2]), array([3, 4]), array([5, 6])]]]
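
As a quick sanity check, the unzipped pipelines can be re-zipped and compared element-wise against the original (element order is preserved because the maps use deterministic=True):

left, right = tfdata_unzip(parent)
rezipped = tf.data.Dataset.zip((left, right))
for (l, r), (ol, orr) in zip(rezipped.as_numpy_iterator(),
                             parent.as_numpy_iterator()):
    assert (l == ol).all() and (r == orr).all()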

TensorFlow's get_single_element() is finally around, and it can be used to unzip datasets (as asked in the question above).

This avoids the need to generate and use an iterator via .map() or iter(), which could be costly for big datasets.
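
For contrast, here is a sketch of the iterator-based pattern it avoids, which materializes the whole pipeline in Python:

import tensorflow as tf

# Unzipping by iterating the whole pipeline in Python
# (slow for big datasets).
zipped = tf.data.Dataset.zip((
    tf.data.Dataset.from_tensor_slices([1, 2, 3]),
    tf.data.Dataset.from_tensor_slices([4, 5, 6]),
))
first = [l for l, r in zipped.as_numpy_iterator()]   # [1, 2, 3]
second = [r for l, r in zipped.as_numpy_iterator()]  # [4, 5, 6]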

get_single_element() returns a tensor (or a tuple or dict of tensors) encapsulating all the members of the dataset, so we first need to batch the dataset so that all of its members land in a single element.

This can be used to get the features as an array of tensors, or the features and labels as a tuple or dictionary (of arrays of tensors), depending on how the original dataset was created; a dict-based sketch follows the example below.

import tensorflow as tf

a = [ 1, 2, 3 ]
b = [ 4, 5, 6 ]
c = [ (7, 8), (9, 10), (11, 12) ]
d = [ 13, 14 ]
# Creating datasets from lists
ads = tf.data.Dataset.from_tensor_slices(a)
bds = tf.data.Dataset.from_tensor_slices(b)
cds = tf.data.Dataset.from_tensor_slices(c)
dds = tf.data.Dataset.from_tensor_slices(d)

list(tf.data.Dataset.zip((ads, bds)).as_numpy_iterator()) == [ (1, 4), (2, 5), (3, 6) ] # True
list(tf.data.Dataset.zip((bds, ads)).as_numpy_iterator()) == [ (4, 1), (5, 2), (6, 3) ] # True

# Let's zip and unzip ads and dds
x = tf.data.Dataset.zip((ads, dds))
xa, xd = tf.data.Dataset.get_single_element(x.batch(len(x)))
xa = list(xa.numpy())
xd = list(xd.numpy())
print(xa, xd) # [1, 2] [13, 14] -- note xa now differs from a, because ads was curtailed when it was zipped with the shorter dds above
d == xd # True
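
The same batching trick works when the original dataset carries dict-structured features. A minimal sketch, where the feature names f1 and f2 are made up:

# Hypothetical (features, label)-style dataset; the keys are illustrative.
fds = tf.data.Dataset.from_tensor_slices(
    ({"f1": [1, 2, 3], "f2": [4, 5, 6]}, [0, 1, 1])
)
features, labels = tf.data.Dataset.get_single_element(fds.batch(len(fds)))
print(features["f1"].numpy(), labels.numpy())  # [1 2 3] [0 1 1]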