如何获取张量中的特定索引(列)并使用 TensorFlow 合并它们

How to get specific index (Column) in Tensors and merge them using TensorFlow

我正在尝试编写一个模型,并且有两个形状为 = (None, 8, 384) 的输入张量,我需要 select 它们基于第二个位置的索引并组合他们得到八个张量大小 (None, 2, 384).

例如,假设T1 的大小为(None, 8, 384),对应第一个变量有8 个城市和384 天。 T2 的大小为 (None, 8, 384),对应第二个变量有 8 个城市和 384 天。 我想 select 来自 T1 和 T2 的第一个城市 (None, 1, 348) 并将它们组合成一个大小为 (None, 2, 384) 的新张量。

column_indices = tf.concat([tf.gather(T1, [0], 轴=1),tf.gather(T2, [0], 轴=1 )], 轴=1)