Tensorflow:如何使用不规则张量作为普通张量的索引?
Tensorflow: How to use a Ragged Tensor as an index into a normal tensor?
我有一个 2D RaggedTensor,由我想要的完整张量的每一行的索引组成,例如:
[
[0,4],
[1,2,3],
[5]
]
进入
[
[200, 305, 400, 20, 20, 105],
[200, 315, 401, 20, 20, 167],
[200, 7, 402, 20, 20, 105],
]
给予
[
[200,20],
[315,401,20],
[105]
]
我怎样才能以最有效的方式实现这一点(最好只使用 tf
函数)?我相信 gather_nd
之类的东西能够采用 RaggedTensors,但我不知道它是如何工作的。
您可以将 tf.gather
与 batch_dims
关键字参数一起使用:
>>> tf.gather(tensor,indices,batch_dims=1)
<tf.RaggedTensor [[200, 20], [315, 401, 20], [105]]>
我有一个 2D RaggedTensor,由我想要的完整张量的每一行的索引组成,例如:
[
[0,4],
[1,2,3],
[5]
]
进入
[
[200, 305, 400, 20, 20, 105],
[200, 315, 401, 20, 20, 167],
[200, 7, 402, 20, 20, 105],
]
给予
[
[200,20],
[315,401,20],
[105]
]
我怎样才能以最有效的方式实现这一点(最好只使用 tf
函数)?我相信 gather_nd
之类的东西能够采用 RaggedTensors,但我不知道它是如何工作的。
您可以将 tf.gather
与 batch_dims
关键字参数一起使用:
>>> tf.gather(tensor,indices,batch_dims=1)
<tf.RaggedTensor [[200, 20], [315, 401, 20], [105]]>