使用 tf.gather 基于另一个张量按行(第一维)按行提取张量

Use tf.gather to extract tensors row-wise based on another tensor row-wisely (first dimension)

我有两个张量,尺寸分别为 A:[B,3000,3]C:[B,4000]。我想使用 tf.gather() 将张量 C 中的每一行用作索引,并将张量 A 中的每一行用作参数,以获得大小为 [B,4000,3].

的结果

这里有一个例子可以让这个更容易理解:假设我有张量

A = [[1,2,3],[4,5,6],[7,8,9]],

C = [0,2,1,2,1],

result = [[1,2,3],[7,8,9],[4,5,6],[7,8,9],[4,5,6]],

通过使用 tf.gather(A,C)。应用于维数小于 3 的张量时都可以。

但如开头描述的情况,通过应用tf.gather(A,C,axis=1),结果张量的形状为

[B,B,4000,3]

似乎 tf.gather() 只是为张量 C 中的每个元素做了工作,作为在张量 A 中收集元素的索引。我正在考虑的唯一解决方案是使用 for 循环,但是通过使用 tf.gather(A[i,...],C[i,...]) 来获得张量的正确大小

会极大地降低计算能力

[B,4000,3]

因此,是否有任何函数可以类似地完成此任务?

您需要使用tf.gather_nd:

import tensorflow as tf

A = ...  # B x 3000 x 3
C = ...  # B x 4000
s = tf.shape(C)
B, cols = s[0], s[1]
# Make indices for first dimension
idx = tf.tile(tf.expand_dims(tf.range(B, dtype=C.dtype), 1), [1, cols])
# Complete index for gather_nd
gather_idx = tf.stack([idx, C], axis=-1)
# Gather result
result = tf.gather_nd(A, gather_idx)