修改多维 numpy 数组中的对角线

Modifying diagonals in multidimensional numpy arrays

我有一个形状为 (7, 3, 7, 3) 的多维 numpy 数组,我想修改轴 0 和轴 2 重合的广义对角线。这个广义对角线将被定义为数组中第 0 个和第 2 个索引重合且形状为 (3, 3, 7) 的元素。

正在做:

arr.diagonal(axis1=0, axis2=2)

我可以访问对角线的元素,但我不能修改它们'in place',至少在 numpy 的 1.8.2 版本中是这样。

Numpy documentation 解释说在 1.10 版本中这可能是可能的。但是,由于我依赖于使用相同代码的其他人,因此无法更新到 numpy 1.10。文档还建议使用 .copy() 以获得可移植的解决方案,但是 .copy() 会复制数组,但如果我想修改原始数组的对角线,这没有帮助。

或者,我尝试直接索引对角线元素[使用来自 numpy.indices((7, 3, 7, 3))] 的输入,但没有成功。

如何访问广义对角线的元素来修改 numpy 1.8.2 中的原始数组?

创建这种广义对角线视图的一种方法是使用模块 numpy.lib.stride_tricks 中的 as_strided 函数。与两个轴的对角线关联的轴的步幅是这些轴的步幅之和。

例如:

In [196]: from numpy.lib.stride_tricks import as_strided

创建一个形状为 (7, 3, 7, 3) 的数组:

In [197]: a = np.arange(21*21).reshape(7,3,7,3)

In [198]: a[5, :, 5, :]
Out[198]: 
array([[330, 331, 332],
       [351, 352, 353],
       [372, 373, 374]])

创建与轴 0 和 2 关联的 "diagonal" 的视图。视图的形状为 (3, 3, 7):

In [199]: d = as_strided(a, strides=(a.strides[1], a.strides[3], a.strides[0] + a.strides[2]), shape=(3, 3, 7))

例如,检查 d[:, :, 5] 是否与 a[5, :, 5, :] 相同:

In [200]: d[:, :, 5]
Out[200]: 
array([[330, 331, 332],
       [351, 352, 353],
       [372, 373, 374]])

验证 da 的视图,方法是修改 d 并查看 a 已更改:

In [201]: d[1, 1, 5] = -1

In [202]: a[5, :, 5, :]
Out[202]: 
array([[330, 331, 332],
       [351,  -1, 353],
       [372, 373, 374]])

小心as_strided!如果参数错误,您可以写入 a 之外的内存,可能导致 python 崩溃。