用于网格旋转的 Numpy einsum()
Numpy einsum() for rotation of meshgrid
我有一组使用 meshgrid() 生成的 3d 坐标。我希望能够围绕 3 个轴旋转它们。
我尝试解开网格并在每个点上进行旋转,但网格很大,我 运行 内存不足。
在 2d 中使用 einsum() 解决了这个问题,但在将它扩展到 3d 时我无法弄清楚字符串格式。
我已经阅读了其他几页关于 einsum() 及其格式字符串的内容,但一直无法弄明白。
编辑:
我将我的网格轴称为 X、Y 和 Z,每个轴的形状均为 (213, 48, 37)。此外,当我试图将结果放回网格时,实际的内存错误出现了。
当我尝试'unravel'它进行逐点旋转时,我使用了以下函数:
def mg2coords(X, Y, Z):
return np.vstack([X.ravel(), Y.ravel(), Z.ravel()]).T
我用以下循环遍历了结果:
def rotz(angle, point):
rad = np.radians(angle)
sin = np.sin(rad)
cos = np.cos(rad)
rot = [[cos, -sin, 0],
[sin, cos, 0],
[0, 0, 1]]
return np.dot(rot, point)
旋转后我将使用这些点进行插值。
使用您的定义:
In [840]: def mg2coords(X, Y, Z):
return np.vstack([X.ravel(), Y.ravel(), Z.ravel()]).T
In [841]: def rotz(angle):
rad = np.radians(angle)
sin = np.sin(rad)
cos = np.cos(rad)
rot = [[cos, -sin, 0],
[sin, cos, 0],
[0, 0, 1]]
return np.array(rot)
# just to the rotation matrix
定义示例网格:
In [842]: X,Y,Z=np.meshgrid([0,1,2],[0,1,2,3],[0,1,2],indexing='ij')
In [843]: xyz=mg2coords(X,Y,Z)
逐行旋转:
In [844]: xyz1=np.array([np.dot(rot,i) for i in xyz])
相当于einsum
逐行计算:
In [845]: xyz2=np.einsum('ij,kj->ki',rot,xyz)
他们匹配:
In [846]: np.allclose(xyz2,xyz1)
Out[846]: True
或者我可以将 3 个数组收集到一个 4d 数组中,然后用 einsum
旋转它。这里 np.array
在开头添加一个维度。所以 dot
和 j
维度是第 1,数组的第 3d 如下:
In [871]: XYZ=np.array((X,Y,Z))
In [872]: XYZ2=np.einsum('ij,jabc->iabc',rot,XYZ)
In [873]: np.allclose(xyz2[:,0], XYZ2[0,...].ravel())
Out[873]: True
1
和 2
类似。
或者我可以将 XYZ2
拆分为 3 个组件数组:
In [882]: X2,Y2,Z2 = XYZ2
In [883]: np.allclose(X2,xyz2[:,0].reshape(X.shape))
Out[883]: True
如果您想向另一个方向旋转,请使用 ji
而不是 ij
,即使用 rot.T
.
我有一组使用 meshgrid() 生成的 3d 坐标。我希望能够围绕 3 个轴旋转它们。
我尝试解开网格并在每个点上进行旋转,但网格很大,我 运行 内存不足。
我已经阅读了其他几页关于 einsum() 及其格式字符串的内容,但一直无法弄明白。
编辑:
我将我的网格轴称为 X、Y 和 Z,每个轴的形状均为 (213, 48, 37)。此外,当我试图将结果放回网格时,实际的内存错误出现了。
当我尝试'unravel'它进行逐点旋转时,我使用了以下函数:
def mg2coords(X, Y, Z):
return np.vstack([X.ravel(), Y.ravel(), Z.ravel()]).T
我用以下循环遍历了结果:
def rotz(angle, point):
rad = np.radians(angle)
sin = np.sin(rad)
cos = np.cos(rad)
rot = [[cos, -sin, 0],
[sin, cos, 0],
[0, 0, 1]]
return np.dot(rot, point)
旋转后我将使用这些点进行插值。
使用您的定义:
In [840]: def mg2coords(X, Y, Z):
return np.vstack([X.ravel(), Y.ravel(), Z.ravel()]).T
In [841]: def rotz(angle):
rad = np.radians(angle)
sin = np.sin(rad)
cos = np.cos(rad)
rot = [[cos, -sin, 0],
[sin, cos, 0],
[0, 0, 1]]
return np.array(rot)
# just to the rotation matrix
定义示例网格:
In [842]: X,Y,Z=np.meshgrid([0,1,2],[0,1,2,3],[0,1,2],indexing='ij')
In [843]: xyz=mg2coords(X,Y,Z)
逐行旋转:
In [844]: xyz1=np.array([np.dot(rot,i) for i in xyz])
相当于einsum
逐行计算:
In [845]: xyz2=np.einsum('ij,kj->ki',rot,xyz)
他们匹配:
In [846]: np.allclose(xyz2,xyz1)
Out[846]: True
或者我可以将 3 个数组收集到一个 4d 数组中,然后用 einsum
旋转它。这里 np.array
在开头添加一个维度。所以 dot
和 j
维度是第 1,数组的第 3d 如下:
In [871]: XYZ=np.array((X,Y,Z))
In [872]: XYZ2=np.einsum('ij,jabc->iabc',rot,XYZ)
In [873]: np.allclose(xyz2[:,0], XYZ2[0,...].ravel())
Out[873]: True
1
和 2
类似。
或者我可以将 XYZ2
拆分为 3 个组件数组:
In [882]: X2,Y2,Z2 = XYZ2
In [883]: np.allclose(X2,xyz2[:,0].reshape(X.shape))
Out[883]: True
如果您想向另一个方向旋转,请使用 ji
而不是 ij
,即使用 rot.T
.