来自参差不齐的张量的样本
Sample from ragged tensor
我有一个 row_lens 的参差不齐的张量,从 1 到 10k。我想从中随机 select 元素,并以可扩展的方式对每行的数量设置上限。就像这个例子:
vect = [[1,2,3],[4,5][6],[7,8,9,10,11,12,13]]
limit = 3
sample(vect, limit)
-> 输出:[[1,2,3],[4,5],[6],[7,9,11]]
我的想法是 select * 在 len_row < limit 的情况下,在另一种情况下是随机的。我想知道这是否可以通过一些张量流操作以低于 batch_size 的复杂度来完成?
您可以尝试在图形模式下使用 tf.map_fn
:
import tensorflow as tf
vect = tf.ragged.constant([[1,2,3],[4,5],[6],[7,8,9,10,11,12,13]])
@tf.function
def sample(x, samples=3):
length = tf.shape(x)[0]
x = tf.cond(tf.less_equal(length, samples), lambda: x, lambda: tf.gather(x, tf.random.shuffle(tf.range(length))[:samples]))
return x
c = tf.map_fn(sample, vect)
<tf.RaggedTensor [[1, 2, 3], [4, 5], [6], [12, 7, 9]]>
请注意,tf.vectorized_map
可能会更快,但当前 bug 关于此函数和参差不齐的张量。使用 tf.while_loop
也是一种选择。
我有一个 row_lens 的参差不齐的张量,从 1 到 10k。我想从中随机 select 元素,并以可扩展的方式对每行的数量设置上限。就像这个例子:
vect = [[1,2,3],[4,5][6],[7,8,9,10,11,12,13]]
limit = 3
sample(vect, limit)
-> 输出:[[1,2,3],[4,5],[6],[7,9,11]]
我的想法是 select * 在 len_row < limit 的情况下,在另一种情况下是随机的。我想知道这是否可以通过一些张量流操作以低于 batch_size 的复杂度来完成?
您可以尝试在图形模式下使用 tf.map_fn
:
import tensorflow as tf
vect = tf.ragged.constant([[1,2,3],[4,5],[6],[7,8,9,10,11,12,13]])
@tf.function
def sample(x, samples=3):
length = tf.shape(x)[0]
x = tf.cond(tf.less_equal(length, samples), lambda: x, lambda: tf.gather(x, tf.random.shuffle(tf.range(length))[:samples]))
return x
c = tf.map_fn(sample, vect)
<tf.RaggedTensor [[1, 2, 3], [4, 5], [6], [12, 7, 9]]>
请注意,tf.vectorized_map
可能会更快,但当前 bug 关于此函数和参差不齐的张量。使用 tf.while_loop
也是一种选择。