在 TensorFlow 中获取由另一个张量部分索引的切片的好方法是什么?

What is a nice way to obtain a slice partially indexed by another tensor in TensorFlow?

假设我们有一个张量 x 第一维未知(例如 [?, 32, 32, 3]),另一个张量 i 实际上是一个标量。是否有一种很好的方法来获得 x 的第 i 切片按第一维分割,例如,以获得维度 [32, 32, 3] 的张量?我是 TensorFlow 的新手,只能想出这个极其笨拙的解决方案。

index = tf.concat(0, [i, tf.constant([0, 0, 0], tf.int64)])
size = [1, x.get_shape()[1].value, x.get_shape()[2].value, x.get_shape()[3].value]
result = tf.unpack(tf.slice(x, index, size))[0]

您可以利用 -1tf.slice() size 参数的特殊参数这一事实,意思是 "all remaining elements in that dimension"。然后,假设 i 是一个标量(而不是您的代码片段中看起来的长度为 1 的向量),您可以执行以下操作:

result = tf.squeeze(tf.slice(x, tf.pack([index, 0, 0, 0]), [1, -1, -1, -1]), [0])

或者,您可以使用 tf.gather() 到 select 来自第零维张量的一个或多个切片。在这种情况下,i 必须是一个向量:

i = tf.expand_dims(i, 0)  # Converts `i` to a vector if it is a scalar.
result = tf.squeeze(tf.gather(x, i), [0])

在这两种情况下,tf.squeeze() 操作都会删除第 0 维以提供三维结果。