How to do NumPy-like index selection in TensorFlow?

Consider the following code:

#!/usr/bin/env python3
# encoding: utf-8
import numpy as np, tensorflow as tf # tf.__version__==2.7.0
sample_array=np.random.uniform(size=(2**10, 120, 20))
to_select=[5, 6, 9, 4]
sample_tensor=tf.convert_to_tensor(value=sample_array)

sample_array[:, :, to_select] # Works okay
sample_tensor[:, :, to_select] # TypeError. How to do this in tensor? 
tf.convert_to_tensor(value=sample_tensor.numpy()[:, :, to_select]) # Ugly workaround

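Basically, how do I get these elements as a tensor of the appropriate dimensions, the way NumPy does? I tried tf.slice and tf.gather, but could not work out the right arguments to pass.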

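I could convert the tensor to NumPy and back, but I am not sure whether that would hurt efficiency or whether it would work as part of a custom training loop.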

The simplest solution is to use tf.concat, although it may not be that efficient:

import numpy as np
import tensorflow as tf

sample_array = np.random.uniform(size=(2, 2, 20))
to_select = [5, 6, 9, 4]
sample_tensor = tf.convert_to_tensor(value = sample_array)
numpy_way = sample_array[:, :, to_select]
# Slice the tensor (not the NumPy array) one index at a time, then concatenate on the last axis
tf_way = tf.concat([tf.expand_dims(sample_tensor[:, :, s], axis=-1) for s in to_select], axis=-1)

print(numpy_way)
print(tf_way)
[[[0.81208086 0.03873406 0.89959868 0.97896671]
  [0.57569184 0.33659472 0.32566287 0.58383079]]

 [[0.59984846 0.43405048 0.42366314 0.25505199]
  [0.16180442 0.5903358  0.21302399 0.86569914]]]
tf.Tensor(
[[[0.81208086 0.03873406 0.89959868 0.97896671]
  [0.57569184 0.33659472 0.32566287 0.58383079]]

 [[0.59984846 0.43405048 0.42366314 0.25505199]
  [0.16180442 0.5903358  0.21302399 0.86569914]]], shape=(2, 2, 4), dtype=float64)
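Since the question mentions tf.gather, here is a minimal sketch of that route for comparison. It relies on the axis argument of tf.gather to select along the last dimension. The helper select_last_axis is a hypothetical name introduced only for illustration, to show the same call inside tf.function, where a NumPy round trip would not be possible.

```python
import numpy as np
import tensorflow as tf

sample_array = np.random.uniform(size=(2, 2, 20))
to_select = [5, 6, 9, 4]
sample_tensor = tf.convert_to_tensor(sample_array)

# tf.gather picks the listed indices along the chosen axis,
# mirroring sample_array[:, :, to_select] in NumPy.
gathered = tf.gather(sample_tensor, to_select, axis=-1)

# Hypothetical helper (not from the original post): the same call traced
# by tf.function, where symbolic tensors have no .numpy() method.
@tf.function
def select_last_axis(tensor, indices):
    return tf.gather(tensor, indices, axis=-1)

gathered_in_graph = select_last_axis(sample_tensor, tf.constant(to_select))

print(gathered.shape)
print(gathered_in_graph.shape)
```

Both results should have the same shape and values as numpy_way above, and no Python loop over the indices is needed.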

A more involved but efficient solution uses tf.meshgrid and tf.gather_nd. Here is an example based on your question:

# Note: the reshapes below are hard-coded to this example
# (a (2, 2, 20) tensor and 4 selected indices).
to_select = tf.expand_dims(tf.constant([5, 6, 9, 4]), axis=0)
to_select_shape = tf.shape(to_select)
sample_tensor_shape = tf.shape(sample_tensor)
# Repeat the selected indices once per (i, j) position and add a trailing axis
to_select = tf.expand_dims(tf.reshape(tf.tile(to_select, [1, to_select_shape[1]]), (sample_tensor_shape[0], sample_tensor_shape[0] * to_select_shape[1])), axis=-1)

# (i, j) coordinates for the first two dimensions
ij = tf.stack(tf.meshgrid(
    tf.range(sample_tensor_shape[0], dtype=tf.int32),
    tf.range(sample_tensor_shape[1], dtype=tf.int32),
    indexing='ij'), axis=-1)

# Append the selected index to every (i, j) pair to form (i, j, k) triples
gather_indices = tf.concat([tf.repeat(ij, repeats=to_select_shape[1], axis=1), to_select], axis=-1)
gather_indices = tf.reshape(gather_indices, (to_select_shape[1], to_select_shape[1], 3))

result = tf.gather_nd(sample_tensor, gather_indices, batch_dims=0)
result = tf.reshape(result, (result.shape[0]//2, result.shape[0]//2, result.shape[1]))
print(result)
tf.Tensor(
[[[0.81208086 0.03873406 0.89959868 0.97896671]
  [0.57569184 0.33659472 0.32566287 0.58383079]]

 [[0.59984846 0.43405048 0.42366314 0.25505199]
  [0.16180442 0.5903358  0.21302399 0.86569914]]], shape=(2, 2, 4), dtype=float64)
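The reshapes in the example above depend on the specific shapes used there. As a rough sketch of the same tf.meshgrid plus tf.gather_nd idea that does not depend on those sizes, one could build the full (i, j, k) coordinate grid directly. This assumes a rank-3 tensor and int32 indices; the (3, 5, 20) shape is chosen only for illustration.

```python
import numpy as np
import tensorflow as tf

sample_array = np.random.uniform(size=(3, 5, 20))
sample_tensor = tf.convert_to_tensor(sample_array)
to_select = tf.constant([5, 6, 9, 4], dtype=tf.int32)

shape = tf.shape(sample_tensor)

# Coordinate grids over the first two axes and the selected last-axis indices.
# With indexing='ij', each grid has shape (shape[0], shape[1], len(to_select)).
i, j, k = tf.meshgrid(
    tf.range(shape[0]),
    tf.range(shape[1]),
    to_select,
    indexing='ij')

# Stack into (i, j, k) triples: one full coordinate per output element.
gather_indices = tf.stack([i, j, k], axis=-1)

# Each triple addresses a single element, so the result takes the
# shape of gather_indices without its last axis.
result = tf.gather_nd(sample_tensor, gather_indices)
print(result.shape)
```

Here result[a, b, c] equals sample_tensor[a, b, to_select[c]], which matches the NumPy expression sample_array[:, :, to_select].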