如何对具有 "requires_grad = true" 的张量进行计算?

How can I do calculations on tensors that have "requires_grad = true"?

我有你在下面看到的这个程序。

import torch


def dht_calculate_transformation_of_single_joint(para_dht_parameters):
    var_a = para_dht_parameters[0]
    var_d = para_dht_parameters[1]
    var_alpha = para_dht_parameters[2]
    var_theta = para_dht_parameters[3]

    var_transformation = torch.tensor(data=[
        [torch.cos(var_theta), -1 * torch.sin(var_theta) * torch.cos(var_alpha), torch.sin(var_theta) * torch.sin(var_alpha), var_a * torch.cos(var_theta)],
        [torch.sin(var_theta), torch.cos(var_theta) * torch.cos(var_alpha), -1 * torch.cos(var_theta) * torch.sin(var_alpha), var_a * torch.sin(var_theta)],
        [0, torch.sin(var_alpha), torch.cos(var_alpha), var_d],
        [0, 0, 0, 1]
    ], dtype=torch.float32, requires_grad=True)

    return var_transformation


def dht_calculate_positions_of_all_joints(para_all_transformations_of_joints):
    var_all_positions_of_joints = torch.zeros(size=[27], dtype=torch.float32, requires_grad=True)
    var_index_all_positions_of_joints = 0
    var_transformation_to_joint = torch.zeros(size=[4, 4], dtype=torch.float32, requires_grad=True)

    for var_index_of_transformation_of_joint, var_transformation_of_joint in enumerate(para_all_transformations_of_joints):
        if var_index_of_transformation_of_joint == 0:
            var_transformation_to_joint = var_transformation_of_joint
        else:
            var_transformation_to_joint = torch.matmul(var_transformation_to_joint, var_transformation_of_joint)

        var_all_positions_of_joints[var_index_all_positions_of_joints + 0] = var_transformation_to_joint[0][3]
        var_all_positions_of_joints[var_index_all_positions_of_joints + 1] = var_transformation_to_joint[1][3]
        var_all_positions_of_joints[var_index_all_positions_of_joints + 2] = var_transformation_to_joint[2][3]
        var_index_all_positions_of_joints += 3

    return var_all_positions_of_joints


def dht_complete_calculation(para_input):
    var_input_reshaped = para_input.view(-1, 9, 4)
    var_output = torch.zeros(size=[para_input.shape[0], 27], dtype=torch.float32, requires_grad=True)  # Tensor ist x Zeilen (Datenreihen) * 27 Spalten (Positionen von Joints) groß.

    for var_index_of_current_row, var_current_row in enumerate(var_input_reshaped):
        var_all_transformations_of_joints = torch.zeros(size=[9, 4, 4], dtype=torch.float32, requires_grad=True)
        for var_index_of_current_column, var_current_column in enumerate(var_current_row):
            var_all_transformations_of_joints[var_index_of_current_column] = dht_calculate_transformation_of_single_joint(var_current_column)

        var_output[var_index_of_current_row] = dht_calculate_positions_of_all_joints(var_all_transformations_of_joints)

    return var_output


if __name__ == "__main__":
    inp = torch.tensor(data=
        [
            [5.1016, 5.2750, 5.0043, 5.2184,
             4.8471, 5.3377, 5.0113, 5.0789,
             4.8800, 5.0455, 5.0394, 4.9092,
             4.6609, 5.5003, 5.1327, 4.7121,
             4.9442, 5.0918, 4.8083, 4.3548,
             5.0163, 4.8840, 4.7491, 4.8089,
             4.8919, 5.0975, 4.9931, 5.0999,
             4.6400, 5.0069, 4.7420, 5.3347,
             4.6725, 5.0338, 5.0310, 5.0470],
            [4.9628, 5.0113, 5.0834, 4.7143,
             5.0336, 5.1864, 5.4348, 5.0918,
             5.1570, 4.8881, 4.5411, 4.6745,
             4.6072, 4.9938, 4.9655, 5.2279,
             5.5559, 5.1952, 5.2229, 5.0727,
             5.1382, 4.7613, 4.6449, 4.3832,
             5.1866, 5.6650, 4.9886, 4.8088,
             4.9390, 5.3506, 5.1028, 4.4640,
             5.1076, 5.0772, 4.8219, 5.1303]
        ]
    , requires_grad=True)

    t1 = dht_complete_calculation(inp)
    print("Endergebins \n", t1, t1.shape)

我在执行 main 时收到以下消息:

Traceback (most recent call last):
  File "dht.py", line 77, in <module>
    t1 = dht_complete_calculation(inp)
  File "dht.py", line 46, in dht_complete_calculation
    var_all_transformations_of_joints[var_index_of_current_column] = dht_calculate_transformation_of_single_joint(var_current_column)
RuntimeError: a view of a leaf Variable that requires grad is being used in an in-place operation.

问题是“dht_complete_calculation”函数将与神经网络一起使用(它不在代码片段中,与问题无关)。神经网络的输出将输入到“dht_complete_calculation”函数中。这就是为什么计算中使用的输出张量和每个张量都需要“requires_grad = true”。 “dht_complete_calculation”函数获取一个具有 x 行和 36 列的张量作为输入,并且应该输出一个具有 x 行和 27 列的张量。您看到的计算是正确的,因为如果我从每个张量中删除“requires_grad = true”,它就会起作用。 这是所需的输出:

tensor([[ 2.4727e+00, -4.4623e+00,  5.2750e+00,  6.6468e+00, -4.1351e+00,
          1.1145e+01,  1.3516e+01, -4.3618e+00,  1.2571e+01,  1.7557e+01,
         -1.0147e+01,  1.4048e+01,  1.8344e+01, -1.2500e+01,  2.0697e+01,
          2.4276e+01, -1.4575e+01,  2.3784e+01,  2.6110e+01, -2.0825e+01,
          2.6521e+01,  2.6707e+01, -2.4291e+01,  3.2371e+01,  3.1856e+01,
         -2.4376e+01,  3.6915e+01],
        [ 9.4848e-03, -4.9628e+00,  5.0113e+00,  3.1514e+00, -6.8211e+00,
          1.1249e+01,  9.8675e+00, -6.9772e+00,  1.3564e+01,  1.1752e+01,
         -9.6508e+00,  1.9519e+01,  1.1553e+01, -8.3219e+00,  2.7006e+01,
          1.4205e+01, -2.2681e+00,  2.9327e+01,  1.6872e+01, -2.0226e+00,
          3.6526e+01,  1.2353e+01, -5.7472e-01,  4.2049e+01,  1.0814e+01,
          3.8157e+00,  4.7547e+01]]) torch.Size([2, 27])

Process finished with exit code 0

然而,如果删除“requires_grad = true”,网络将不会学到任何东西,这不是我想要的。

你能帮我了解是哪部分代码触发了这个错误以及如何修复它吗?

这里的问题不在于您在 requires_grad=True 张量上进行计算。毕竟这就是获得渐变的方式!通过对这样的张量进行计算 :)

问题是您正在执行所谓的就地操作。

就地我们的意思是以前存在的变量的内存位置现在被新变量替换。结果计算图被破坏,无法实现梯度反向传播。

这看起来怎么样?我在 this Pytorch-forum question

中找到了一些快速示例

特别是:

>>> x = torch.rand(1)
>>> y = torch.rand(1)
>>> x
tensor([0.2738])
>>> id(x)
140736259305336
>>> x = x + y   # Normal operation
>>> id(x)
140726604827672 # New location
>>> x += y
>>> id(x)
140726604827672 # Existing location used (in-place)

那么,你可能会问,你在哪里做的?

其中一个地方是

var_all_positions_of_joints[var_index_all_positions_of_joints + 0] = var_transformation_to_joint[0][3]
    var_all_positions_of_joints[var_index_all_positions_of_joints + 1] = var_transformation_to_joint[1][3]
    var_all_positions_of_joints[var_index_all_positions_of_joints + 2] = var_transformation_to_joint[2][3]

您不应该这样做,而是应该将所有 var_transofrmation_to_joint 变量收集到一个列表中,然后根据您的情况执行 torch.stack 或 torch.cat。或者,如果将来您寻求重新安排张量中元素的位置,我建议使用 einops 之类的东西来获得高效且独立于框架的解决方案。