使用 Tensorflow 数据进行批处理和填充 API

Batching and padding using the Tensorflow data API

我无法理解 TensorFlow 数据 API (tensorflow.data.Dataset) 的工作原理。我的 输入是我要批处理、填充的整数列表列表 并连接起来。例如我的数据看起来像这样

data = [[1, 2, 3, 4, 5, 6, 7],
        [1, 2, 3, 4],
        [1]]

批量大小为 3 时应变为:

[[[1, 2, 3], [4, 5, 6], [7, 0, 0]],
 [[1, 2, 3], [4, 0, 0]],
 [[1, 0, 0]]]

最后:

[[1, 2, 3], [4, 5, 6], [7, 0, 0],
 [1, 2, 3], [4, 0, 0], [1, 0, 0]]

这并不容易,但我终于成功了:

def batch_each(x):
    return Dataset.from_tensor_slices(x).batch(3)
data = [[1, 2, 3, 4, 5, 6, 7],
        [1, 2, 3, 4],
        [1]]
rt = tf.ragged.constant(data)
ds = Dataset \
    .from_tensor_slices(rt) \
    .flat_map(batch_each) \
    .padded_batch(1, padded_shapes = (3,)) \
    .unbatch()
for e in ds:
    print(e)