使用 Tensorflow 数据进行批处理和填充 API
Batching and padding using the Tensorflow data API
我无法理解 TensorFlow 数据 API (tensorflow.data.Dataset) 的工作原理。我的
输入是我要批处理、填充的整数列表列表
并连接起来。例如我的数据看起来像这样
data = [[1, 2, 3, 4, 5, 6, 7],
[1, 2, 3, 4],
[1]]
批量大小为 3 时应变为:
[[[1, 2, 3], [4, 5, 6], [7, 0, 0]],
[[1, 2, 3], [4, 0, 0]],
[[1, 0, 0]]]
最后:
[[1, 2, 3], [4, 5, 6], [7, 0, 0],
[1, 2, 3], [4, 0, 0], [1, 0, 0]]
这并不容易,但我终于成功了:
def batch_each(x):
return Dataset.from_tensor_slices(x).batch(3)
data = [[1, 2, 3, 4, 5, 6, 7],
[1, 2, 3, 4],
[1]]
rt = tf.ragged.constant(data)
ds = Dataset \
.from_tensor_slices(rt) \
.flat_map(batch_each) \
.padded_batch(1, padded_shapes = (3,)) \
.unbatch()
for e in ds:
print(e)
我无法理解 TensorFlow 数据 API (tensorflow.data.Dataset) 的工作原理。我的 输入是我要批处理、填充的整数列表列表 并连接起来。例如我的数据看起来像这样
data = [[1, 2, 3, 4, 5, 6, 7],
[1, 2, 3, 4],
[1]]
批量大小为 3 时应变为:
[[[1, 2, 3], [4, 5, 6], [7, 0, 0]],
[[1, 2, 3], [4, 0, 0]],
[[1, 0, 0]]]
最后:
[[1, 2, 3], [4, 5, 6], [7, 0, 0],
[1, 2, 3], [4, 0, 0], [1, 0, 0]]
这并不容易,但我终于成功了:
def batch_each(x):
return Dataset.from_tensor_slices(x).batch(3)
data = [[1, 2, 3, 4, 5, 6, 7],
[1, 2, 3, 4],
[1]]
rt = tf.ragged.constant(data)
ds = Dataset \
.from_tensor_slices(rt) \
.flat_map(batch_each) \
.padded_batch(1, padded_shapes = (3,)) \
.unbatch()
for e in ds:
print(e)