转置 3D 数组并乘以逐元素内存连续性

transpose 3D array and multiply elementwise-memory contiguity

我有一个巨大的 3D 阵列,看起来像 A.shape = (100000, 5000, 50)。 我需要将其转置为 A.shape = (50, 5000, 100000) 形式的数组。 然后我需要对A中包含的50个矩阵中的每一个进行操作a = a.T @ a。 这给了我一个 A.shape = (50, 5000, 5000).

形式的 3D 数组

如果我使用 A.transpose(2, 1, 0) @ A.transpose(2, 0, 1) 单矩阵乘法 a = a.T @ a 结果比 a 没有从 A 中提取的情况慢一千倍。

问题是转置之后,3维数组不连续。 我尝试在转置后使用 np.ascontiguousarray()copy()。它有所改进,但速度仍然较慢,并且花费了相当多的时间进行复制。

有人能提出更好的选择吗? 特别是我正在尝试使用 np.einsum 但我不能。

您可以尝试以下方法:

A = ...
b = np.einsum('jki,jli->ikl', A, A)
print(b.shape)
# (50, 5000, 5000)