如何获得 N 阶张量的任何列?

How to get any columns of a rank N tensor?

考虑张量

import numpy as np
array = np.array([
    [[111, 112], [121, 122]],
    [[211, 212], [221, 222]],
])

>>> print(array[:, 0, [0, 1]])
[
 [111 112]
 [211 212]
]

>>> print(array[:, 1, [0, 1]])
[
 [121 122]
 [221 222]
]

现在,我将如何获取元素 (:, 0, 1)(:, 1, 0)

[
 [112 121]
 [212 221]
]

作为上面的 numpy ndarray?

好像

>>> print(array[:, [(1, 0), (0, 1)]])

不是正确的表示法。

通常,给定一个索引元组列表,我如何获得这些元组的 N-1 张量(-1,因为这里的第一个等级总是 :)?

如果numpy不支持,我愿意用numpy以外的库来做这个

您可以使用 [:, [0,1], [1,0]],查看更多语法 here:

array[:, [0,1], [1,0]]
#array([[112, 121],
#       [212, 221]])