"reshape" (N, 2) 形状的 numpy 数组变成 (N, 2, 2),其中每一列(大小 2)变成一个 diag (2,2) 块?
"reshape" numpy array of (N, 2) shape into (N, 2, 2) where each column (size 2) become a diag (2,2) block?
有没有有效的方法来做到这一点?
例如我有
[[1, 2, 3],
[4, 5, 6]]
我想得到:
[[[1, 0],
[0, 4]],
[[2, 0],
[0, 5]],
[[3, 0],
[0, 6]]]
对于大型阵列,我建议 np.einsum
如下:
>>> data
array([[1, 2, 3],
[4, 5, 6]])
>>> out = np.zeros((*reversed(data.shape),2),data.dtype)
>>> np.einsum("...ii->...i",out)[...] = data.T
>>> out
array([[[1, 0],
[0, 4]],
[[2, 0],
[0, 5]],
[[3, 0],
[0, 6]]])
einsum
创建一个包含对角线元素的内存位置的可写跨步视图。这与它在 numpy 中的效率差不多。
不是跨步视图,但也许更容易理解的是 'diagonal' 填充 (3,2,2) 数组:
In [28]: arr = np.arange(1,7).reshape(2,3)
In [29]: res = np.zeros((3,2,2),int)
In [30]: res[:,np.arange(2),np.arange(2)].shape
Out[30]: (3, 2)
In [31]: res[:,np.arange(2),np.arange(2)]=arr.T
In [32]: res
Out[32]:
array([[[1, 0],
[0, 4]],
[[2, 0],
[0, 5]],
[[3, 0],
[0, 6]]])
对于这个小案例,时间差别不大。我不知道他们会如何扩展:
In [39]: timeit np.einsum("...ii->...i",out)[...] = arr.T
5.21 µs ± 5.99 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
In [40]: timeit res[:,np.arange(2),np.arange(2)]=arr.T
6.4 µs ± 21.3 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
有没有有效的方法来做到这一点? 例如我有
[[1, 2, 3],
[4, 5, 6]]
我想得到:
[[[1, 0],
[0, 4]],
[[2, 0],
[0, 5]],
[[3, 0],
[0, 6]]]
对于大型阵列,我建议 np.einsum
如下:
>>> data
array([[1, 2, 3],
[4, 5, 6]])
>>> out = np.zeros((*reversed(data.shape),2),data.dtype)
>>> np.einsum("...ii->...i",out)[...] = data.T
>>> out
array([[[1, 0],
[0, 4]],
[[2, 0],
[0, 5]],
[[3, 0],
[0, 6]]])
einsum
创建一个包含对角线元素的内存位置的可写跨步视图。这与它在 numpy 中的效率差不多。
不是跨步视图,但也许更容易理解的是 'diagonal' 填充 (3,2,2) 数组:
In [28]: arr = np.arange(1,7).reshape(2,3)
In [29]: res = np.zeros((3,2,2),int)
In [30]: res[:,np.arange(2),np.arange(2)].shape
Out[30]: (3, 2)
In [31]: res[:,np.arange(2),np.arange(2)]=arr.T
In [32]: res
Out[32]:
array([[[1, 0],
[0, 4]],
[[2, 0],
[0, 5]],
[[3, 0],
[0, 6]]])
对于这个小案例,时间差别不大。我不知道他们会如何扩展:
In [39]: timeit np.einsum("...ii->...i",out)[...] = arr.T
5.21 µs ± 5.99 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
In [40]: timeit res[:,np.arange(2),np.arange(2)]=arr.T
6.4 µs ± 21.3 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)