找不到高效的 p​​ytorch 广播

Efficient pytorch broadcasting not found

我的实现中有以下代码片段。有一个包含 3 个循环的嵌套 for 循环。在主代码中,原始系统的 3D 坐标被堆叠为连续堆叠点的一维向量,对于坐标为 (x,y,z) 的点,样本单元看起来像

Predictions =[...x,y,z,...]

而对于我的计算,我需要 reshaped_prediction 向量作为具有 prediction_reshaped[i][0]=x, prediction_reshaped[i][1]=y prediction_reshaped[i][2]=z 的二维矩阵 其中 i 是矩阵 prediction_reshaped 中的任意样本行。 下面的代码展示了逻辑

prediction_reshaped=torch.zeros([batch,num_node,dimesion])
    for i in range(batch):
        for j in range(num_node):
            for k in range(dimesion):
                prediction_reshaped[i][j][k]=prediction[i][3*j+k]

他们是否有任何有效的广播来避免这三个嵌套循环?它正在减慢我的代码。 torch.reshape 不符合我的目的。 该代码是使用 pytorch 实现的,所有矩阵都作为 pytorch 张量,但任何 numpy 解决方案也会有所帮助。

这应该可以完成工作。

import torch
batch = 2
num_nodes = 4
x = torch.rand(batch, num_nodes * 3)
# tensor([[0.8076, 0.2572, 0.7100, 0.4180, 0.6420, 0.4668, 0.8915, 0.0366,  0.5704,
#          0.0834, 0.3313, 0.9080],
#         [0.2925, 0.7367, 0.8013, 0.4516, 0.5470, 0.5123, 0.1929, 0.4191,  0.1174,
#          0.0076, 0.2864, 0.9151]])
x = x.reshape(batch, num_nodes, 3)
# tensor([[[0.8076, 0.2572, 0.7100],
#         [0.4180, 0.6420, 0.4668],
#         [0.8915, 0.0366, 0.5704],
#         [0.0834, 0.3313, 0.9080]],
#
#        [[0.2925, 0.7367, 0.8013],
#         [0.4516, 0.5470, 0.5123],
#         [0.1929, 0.4191, 0.1174],
#         [0.0076, 0.2864, 0.9151]]])