使用 Pytorch 模型参数进行计算时出现结果类型转换错误

Result type cast error when doing calculations with Pytorch model parameters

当我运行下面的代码时:

import torchvision

model = torchvision.models.densenet201(num_classes=10)
params = model.state_dict()
for var in params:
    params[var] *= 0.1

报告运行时错误:

RuntimeError: result type Float can't be cast to the desired output type Long

但是当我将params[var] *= 0.1更改为params[var] = params[var] * 0.1时,错误消失了。

为什么会这样?

我以为params[var] *= 0.1params[var] = params[var] * 0.1的效果是一样的。

首先,让我们知道densenet201中的第一个long-type参数,你会发现features.norm0.num_batches_tracked表示训练期间mini-batches的数量,用于计算如果模型中有 BatchNormalization 层,则为均值和方差。 This parameter is a long-type number and cannot be float type because it behaves like a counter.

其次,在PyTorch中,有两种类型的操作:

  • Non-Inplace 操作: 您将计算后的新输出分配给变量的新副本,例如x = x + 1 或 x = x / 2. x 赋值前的内存位置不等于赋值后的内存位置,因为你有原始变量的副本。
  • 就地操作:当计算直接应用于变量的原始副本而不在此处进行任何复制时,例如x += 1 或 x /= 2.

让我们转到您的示例以了解发生了什么:

  1. Non-Inplcae操作:

    model = torchvision.models.densenet201(num_classes=10)
    params = model.state_dict()
    name = 'features.norm0.num_batches_tracked'
    
    print(id(params[name]))  # 140247785908560
    params[name] = params[name] + 0.1
    print(id(params[name]))  # 140247785908368  
    print(params[name].type()) # changed to torch.FloatTensor
    
  2. 就地操作:

    print(id(params[name]))  # 140247785908560
    params[name] += 1
    print(id(params[name]))  # 140247785908560 
    print(params[name].type()) # still torch.LongTensor
    
    params[name] += 0.1     # you want to change the original copy type to float ,you got an error
    

最后说几点:

  • In-place 操作可以节省一些内存,但在计算导数时可能会出现问题,因为会立即丢失历史记录。因此,不鼓励使用它们。 Source
  • 决定使用 in-place 操作时应谨慎,因为它们会覆盖原始内容。
  • 如果用pandas,这个有点类似于pandas中的inplace=True :).

这是阅读更多关于 in-place 操作 source and read also this discussion source 的好资源。