Tensorflow:如何使用不规则张量作为普通张量的索引?

Tensorflow: How to use a Ragged Tensor as an index into a normal tensor?

我有一个 2D RaggedTensor,由我想要的完整张量的每一行的索引组成,例如:

[
    [0,4],
    [1,2,3],
    [5]
]

进入

[
    [200, 305, 400, 20, 20, 105],
    [200, 315, 401, 20, 20, 167],
    [200, 7, 402, 20, 20, 105],
]

给予

[
    [200,20],
    [315,401,20],
    [105]
]

我怎样才能以最有效的方式实现这一点(最好只使用 tf 函数)?我相信 gather_nd 之类的东西能够采用 RaggedTensors,但我不知道它是如何工作的。

您可以将 tf.gatherbatch_dims 关键字参数一起使用:

>>> tf.gather(tensor,indices,batch_dims=1)
<tf.RaggedTensor [[200, 20], [315, 401, 20], [105]]>