这个np.einsum('kij',A)的结果怎么理解呢?

How to understand the result of this np.einsum('kij',A)?

例如,

A = np.arange(24).reshape((2, 3, 4))
print np.einsum('ijk', A)

这仍然是 A 没有问题。

但是如果我这样做 print np.einsum('kij', A) 形状是 (3, 4, 2)。不应该是(4, 2, 3)吗?

print np.einsum('cab', A) 形状的结果是 (4, 2, 3) 也没有问题。为什么 print np.einsum('kij', A) 不一样?

如果您仅指定一组下标,这些下标将被解释为 input 数组相对于 output[=21] 的维度顺序=],反之亦然。

例如:

import numpy as np

A = np.arange(24).reshape((2, 3, 4))
B = np.einsum('kij', A)

i, j, k = np.indices(B.shape)

print(np.all(B[i, j, k] == A[k, i, j]))
# True

正如@hpaulj在评论中指出的,你可以通过指定两组下标来使输入和输出维度之间的对应关系更加明确:

# this is equivalent to np.einsum('kij', A)
print(np.einsum('kij->ijk', A).shape)
# (3, 4, 2)

# this is the behavior you are expecting
print(np.einsum('ijk->kij', A).shape)
# (4, 2, 3)