Tensorflow 数据:将函数应用于批处理

Tensorflow data : apply function TO batch

我正在使用 tf.data 从大型文本语料库迭代批处理。

我只想将函数应用于数据子集(或批次子集),而不是一个一个地应用。 具体来说,我的数据迭代器产生 query, reply 与批次。它们都是正对,所以我只想洗牌下一批的子集(在这种情况下,只有 "reply" 批次)以生成随机负数。

例如, 输入:

query1 reply1

query2 reply2

query3 reply3

...

输出:

当然可以使用 python 随机播放文本,但我想使用 tf.data 提高效率,因为数据量太大。

假设你有 queriesreplies 作为两个张量。你需要的是我想像下面这样的东西,你可以将它与原始批次连接起来。

batch_size = 10
def reply_shuffle(queries, replies):
   shuffled_indices = tf.random_uniform(minval=0, maxval=batch_size+1, shape=[batch_size], dtype=tf.int32)
   shuffled_replies = tf.gather_nd(replies, shuffled_indices) 
   return queries, shuffled_replies