Dataset with variable-sized items from tf.dynamic_partition

Similar to a related question, I want to build a TF dataset from a list whose elements have different sizes. However, unlike the linked question, I would like to generate the dataset from the output of tf.dynamic_partition, which is a list of tensors.

My setup:

import tensorflow as tf
D = tf.data.Dataset # shorthand notation

x = tf.range(9) # Array to be partitioned
p = tf.constant([1,0,2,0,0,0,2,2,1]) # Defines partitions

So the dataset should have three elements, containing [1 3 4 5], [0 8] and [2 6 7] respectively.
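To check what tf.dynamic_partition actually returns for this setup, you can run it eagerly. This is a small sketch that assumes TensorFlow 2.x with eager execution, unlike the TF 1.x session code used in the rest of the question:

```python
import tensorflow as tf  # assumes TensorFlow 2.x (eager execution)

x = tf.range(9)                               # values to partition
p = tf.constant([1, 0, 2, 0, 0, 0, 2, 2, 1])  # partition index for each value

# dynamic_partition returns a Python list with num_partitions tensors,
# each holding the values of x whose partition index equals its position.
parts = tf.dynamic_partition(x, p, num_partitions=3)
for i, part in enumerate(parts):
    print(i, part.numpy())
```

The three tensors have shapes [4], [2] and [3], which is why they cannot be stacked into a single dense tensor.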

As expected, the direct approach fails:

dataset = D.from_tensor_slices(tf.dynamic_partition(x,p,3))
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
with tf.Session() as sess:
    nl = sess.run(next_element)

tensorflow.python.framework.errors_impl.InvalidArgumentError: Shapes of all inputs must match: values[0].shape = [4] != values[1].shape = [2]

Next, I tried applying from_generator:

dataset = D.from_generator(lambda: tf.dynamic_partition(x,p,3), tf.int32, output_shapes=[None])
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
with tf.Session() as sess:
    nl = sess.run(next_element)

tensorflow.python.framework.errors_impl.InvalidArgumentError: exceptions.ValueError: setting an array element with a sequence.

How can I create a dataset with variable-sized items from the output of tf.dynamic_partition?

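from_generator doesn't work because it expects the generator function to yield NumPy arrays, not tensors. Here, tf.dynamic_partition returns a list of symbolic tf.Tensor objects, which cannot be converted into array values.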
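A minimal sketch of a working from_generator variant follows. It assumes TensorFlow 2.4 or later (for the output_signature argument) with eager execution, so each partition can be turned into a NumPy array with .numpy() before it is yielded. Under TF 1.x graph mode you would instead have to sess.run the partitions first and yield the resulting arrays. The generator name gen is only illustrative.

```python
import tensorflow as tf  # assumes TensorFlow >= 2.4 with eager execution

x = tf.range(9)
p = tf.constant([1, 0, 2, 0, 0, 0, 2, 2, 1])
parts = tf.dynamic_partition(x, p, 3)

def gen():
    # Yield one NumPy array per partition instead of returning the
    # whole list of tensors at once.
    for part in parts:
        yield part.numpy()

dataset = tf.data.Dataset.from_generator(
    gen,
    # Each element is a 1-D int32 vector of unknown length.
    output_signature=tf.TensorSpec(shape=[None], dtype=tf.int32),
)

for element in dataset:
    print(element.numpy())
```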

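One way to solve your problem is to create a dataset for each element of the partition. In your case, you split the data into 3 groups, so you would create 3 datasets and combine them with tf.data.Dataset.concatenate():
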
import tensorflow as tf  # TF 1.x API (tf.Session, make_one_shot_iterator)

x = tf.range(9)  # Array to be partitioned
p = tf.constant([1, 0, 2, 0, 0, 0, 2, 2, 1])  # Defines partitions

partition = tf.dynamic_partition(x, p, 3)

dataset = tf.data.Dataset.from_tensors(partition[0])
for i in range(1, 3):
    dataset_bis = tf.data.Dataset.from_tensors(partition[i])
    dataset = dataset.concatenate(dataset_bis)

iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()


with tf.Session() as sess:
    for i in range(3):
        nl = sess.run(next_element)
        print(nl)
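For TensorFlow 2.x, here is a minimal sketch of the same concatenate idea without sessions or iterators, together with an alternative that is not part of the answer above: building each element with tf.boolean_mask inside Dataset.map. Both assume eager execution; the names concat_ds and mask_ds are illustrative.

```python
import functools

import tensorflow as tf  # assumes TensorFlow 2.x (eager execution)

x = tf.range(9)
p = tf.constant([1, 0, 2, 0, 0, 0, 2, 2, 1])
num_partitions = 3
parts = tf.dynamic_partition(x, p, num_partitions)

# Same technique as the answer: one single-element dataset per partition,
# concatenated in partition order. concatenate relaxes the differing
# static shapes ([4], [2], [3]) to a common shape of [None].
concat_ds = functools.reduce(
    lambda a, b: a.concatenate(b),
    [tf.data.Dataset.from_tensors(part) for part in parts],
)

# Alternative technique: map over partition indices and select the
# matching values of x with a boolean mask.
mask_ds = tf.data.Dataset.range(num_partitions).map(
    lambda i: tf.boolean_mask(x, tf.equal(p, tf.cast(i, p.dtype)))
)

for element in concat_ds:
    print(element.numpy())
```

The boolean_mask variant scans the whole input once per partition, so it is best suited to a small number of partitions, while the concatenate variant relies on a single tf.dynamic_partition call.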