使用 tf.data 批量处理顺序数据
Batch sequential data with tf.data
让我们考虑一个有序的玩具数据集,它具有两个特征:
value
(例如1, 2, 3, 4, 5, 111, 222, 333, 444, 555
)
sequence_id
(例如0, 0, 0, 0, 0, 1, 1, 1, 1, 1
)
此数据基本上由两个拼接的扁平序列组成,1, 2, 3, 4, 5
(序列0
)和111, 222, 333, 444, 555
(序列1
)。
我想生成大小为 t
(例如 3
)的序列,由来自相同序列 (sequence_id
) 的连续元素组成,我不希望序列具有属于不同 sequence_id
.
的元素
例如,没有任何洗牌,我想得到以下批次:
- 第 1 批:
1, 2, 3
、
- 第二批:
2, 3, 4
,
- 第 3 批:
3, 4, 5
、
- 第 4 批:
111, 222, 333
、
- 第 5 批:
222, 333, 444
、
- 第 6 批:
333, 444, 555
、
- 第 7 批:
1, 2, 3
、
- 等等
我知道如何使用 tf.data.Dataset.window
或 tf.data.Dataset.batch
生成序列数据,但我不知道如何防止序列包含不同 sequence_id
的混合(例如序列4, 5, 111
应该无效,因为它混合了序列 0
和序列 1
).
中的元素
以下是我失败的尝试:
import tensorflow as tf
data = tf.data.Dataset.from_tensor_slices(([1, 2, 3, 4, 5, 111, 222, 333, 444, 555],
[0, 0, 0, 0, 0, 1, 1, 1, 1, 1]))\
.window(3, 1, drop_remainder=True)\
.repeat(-1)\
.flat_map(lambda x, y: x.batch(3))\
.batch(10)
data_it = data.make_initializable_iterator()
next_element = data_it.get_next()
with tf.Session() as sess:
sess.run(data_it.initializer)
print(sess.run(next_element))
输出:
[[ 1 2 3] # good
[ 2 3 4] # good
[ 3 4 5] # good
[ 4 5 111] # bad – mix of sequence 0 (4, 5) and sequence 1 (111)
[ 5 111 222] # bad
[111 222 333] # good
[222 333 444] # good
[333 444 555] # good
[ 1 2 3] # good
[ 2 3 4]] # good
可以用filter()
判断sequence_id
是否一致。因为filter()
转换目前不支持嵌套数据集作为输入,所以需要zip()
.
import tensorflow as tf
data = tf.data.Dataset.from_tensor_slices(([1, 2, 3, 4, 5, 111, 222, 333, 444, 555],
[0, 0, 0, 0, 0, 1, 1, 1, 1, 1]))\
.window(3, 1, drop_remainder=True) \
.flat_map(lambda x, y: tf.data.Dataset.zip((x,y)).batch(3))\
.filter(lambda x,y: tf.equal(tf.size(tf.unique(y)[0]),1))\
.map(lambda x,y:x)\
.repeat(-1)\
.batch(10)
data_it = data.make_initializable_iterator()
next_element = data_it.get_next()
with tf.Session() as sess:
sess.run(data_it.initializer)
print(sess.run(next_element))
[[ 1 2 3]
[ 2 3 4]
[ 3 4 5]
[111 222 333]
[222 333 444]
[333 444 555]
[ 1 2 3]
[ 2 3 4]
[ 3 4 5]
[111 222 333]]
让我们考虑一个有序的玩具数据集,它具有两个特征:
value
(例如1, 2, 3, 4, 5, 111, 222, 333, 444, 555
)sequence_id
(例如0, 0, 0, 0, 0, 1, 1, 1, 1, 1
)
此数据基本上由两个拼接的扁平序列组成,1, 2, 3, 4, 5
(序列0
)和111, 222, 333, 444, 555
(序列1
)。
我想生成大小为 t
(例如 3
)的序列,由来自相同序列 (sequence_id
) 的连续元素组成,我不希望序列具有属于不同 sequence_id
.
例如,没有任何洗牌,我想得到以下批次:
- 第 1 批:
1, 2, 3
、 - 第二批:
2, 3, 4
, - 第 3 批:
3, 4, 5
、 - 第 4 批:
111, 222, 333
、 - 第 5 批:
222, 333, 444
、 - 第 6 批:
333, 444, 555
、 - 第 7 批:
1, 2, 3
、 - 等等
我知道如何使用 tf.data.Dataset.window
或 tf.data.Dataset.batch
生成序列数据,但我不知道如何防止序列包含不同 sequence_id
的混合(例如序列4, 5, 111
应该无效,因为它混合了序列 0
和序列 1
).
以下是我失败的尝试:
import tensorflow as tf
data = tf.data.Dataset.from_tensor_slices(([1, 2, 3, 4, 5, 111, 222, 333, 444, 555],
[0, 0, 0, 0, 0, 1, 1, 1, 1, 1]))\
.window(3, 1, drop_remainder=True)\
.repeat(-1)\
.flat_map(lambda x, y: x.batch(3))\
.batch(10)
data_it = data.make_initializable_iterator()
next_element = data_it.get_next()
with tf.Session() as sess:
sess.run(data_it.initializer)
print(sess.run(next_element))
输出:
[[ 1 2 3] # good
[ 2 3 4] # good
[ 3 4 5] # good
[ 4 5 111] # bad – mix of sequence 0 (4, 5) and sequence 1 (111)
[ 5 111 222] # bad
[111 222 333] # good
[222 333 444] # good
[333 444 555] # good
[ 1 2 3] # good
[ 2 3 4]] # good
可以用filter()
判断sequence_id
是否一致。因为filter()
转换目前不支持嵌套数据集作为输入,所以需要zip()
.
import tensorflow as tf
data = tf.data.Dataset.from_tensor_slices(([1, 2, 3, 4, 5, 111, 222, 333, 444, 555],
[0, 0, 0, 0, 0, 1, 1, 1, 1, 1]))\
.window(3, 1, drop_remainder=True) \
.flat_map(lambda x, y: tf.data.Dataset.zip((x,y)).batch(3))\
.filter(lambda x,y: tf.equal(tf.size(tf.unique(y)[0]),1))\
.map(lambda x,y:x)\
.repeat(-1)\
.batch(10)
data_it = data.make_initializable_iterator()
next_element = data_it.get_next()
with tf.Session() as sess:
sess.run(data_it.initializer)
print(sess.run(next_element))
[[ 1 2 3]
[ 2 3 4]
[ 3 4 5]
[111 222 333]
[222 333 444]
[333 444 555]
[ 1 2 3]
[ 2 3 4]
[ 3 4 5]
[111 222 333]]