找不到高效的 pytorch 广播
Efficient pytorch broadcasting not found
我的实现中有以下代码片段。有一个包含 3 个循环的嵌套 for 循环。在主代码中,原始系统的 3D 坐标被堆叠为连续堆叠点的一维向量,对于坐标为 (x,y,z) 的点,样本单元看起来像
Predictions =[...x,y,z,...]
而对于我的计算,我需要 reshaped_prediction 向量作为具有 prediction_reshaped[i][0]=x, prediction_reshaped[i][1]=y prediction_reshaped[i][2]=z
的二维矩阵
其中 i 是矩阵 prediction_reshaped
中的任意样本行。
下面的代码展示了逻辑
prediction_reshaped=torch.zeros([batch,num_node,dimesion])
for i in range(batch):
for j in range(num_node):
for k in range(dimesion):
prediction_reshaped[i][j][k]=prediction[i][3*j+k]
他们是否有任何有效的广播来避免这三个嵌套循环?它正在减慢我的代码。 torch.reshape 不符合我的目的。
该代码是使用 pytorch 实现的,所有矩阵都作为 pytorch 张量,但任何 numpy 解决方案也会有所帮助。
这应该可以完成工作。
import torch
batch = 2
num_nodes = 4
x = torch.rand(batch, num_nodes * 3)
# tensor([[0.8076, 0.2572, 0.7100, 0.4180, 0.6420, 0.4668, 0.8915, 0.0366, 0.5704,
# 0.0834, 0.3313, 0.9080],
# [0.2925, 0.7367, 0.8013, 0.4516, 0.5470, 0.5123, 0.1929, 0.4191, 0.1174,
# 0.0076, 0.2864, 0.9151]])
x = x.reshape(batch, num_nodes, 3)
# tensor([[[0.8076, 0.2572, 0.7100],
# [0.4180, 0.6420, 0.4668],
# [0.8915, 0.0366, 0.5704],
# [0.0834, 0.3313, 0.9080]],
#
# [[0.2925, 0.7367, 0.8013],
# [0.4516, 0.5470, 0.5123],
# [0.1929, 0.4191, 0.1174],
# [0.0076, 0.2864, 0.9151]]])
我的实现中有以下代码片段。有一个包含 3 个循环的嵌套 for 循环。在主代码中,原始系统的 3D 坐标被堆叠为连续堆叠点的一维向量,对于坐标为 (x,y,z) 的点,样本单元看起来像
Predictions =[...x,y,z,...]
而对于我的计算,我需要 reshaped_prediction 向量作为具有 prediction_reshaped[i][0]=x, prediction_reshaped[i][1]=y prediction_reshaped[i][2]=z
的二维矩阵
其中 i 是矩阵 prediction_reshaped
中的任意样本行。
下面的代码展示了逻辑
prediction_reshaped=torch.zeros([batch,num_node,dimesion])
for i in range(batch):
for j in range(num_node):
for k in range(dimesion):
prediction_reshaped[i][j][k]=prediction[i][3*j+k]
他们是否有任何有效的广播来避免这三个嵌套循环?它正在减慢我的代码。 torch.reshape 不符合我的目的。 该代码是使用 pytorch 实现的,所有矩阵都作为 pytorch 张量,但任何 numpy 解决方案也会有所帮助。
这应该可以完成工作。
import torch
batch = 2
num_nodes = 4
x = torch.rand(batch, num_nodes * 3)
# tensor([[0.8076, 0.2572, 0.7100, 0.4180, 0.6420, 0.4668, 0.8915, 0.0366, 0.5704,
# 0.0834, 0.3313, 0.9080],
# [0.2925, 0.7367, 0.8013, 0.4516, 0.5470, 0.5123, 0.1929, 0.4191, 0.1174,
# 0.0076, 0.2864, 0.9151]])
x = x.reshape(batch, num_nodes, 3)
# tensor([[[0.8076, 0.2572, 0.7100],
# [0.4180, 0.6420, 0.4668],
# [0.8915, 0.0366, 0.5704],
# [0.0834, 0.3313, 0.9080]],
#
# [[0.2925, 0.7367, 0.8013],
# [0.4516, 0.5470, 0.5123],
# [0.1929, 0.4191, 0.1174],
# [0.0076, 0.2864, 0.9151]]])