使用 Pytorch 模型参数进行计算时出现结果类型转换错误
Result type cast error when doing calculations with Pytorch model parameters
当我运行下面的代码时:
import torchvision
model = torchvision.models.densenet201(num_classes=10)
params = model.state_dict()
for var in params:
params[var] *= 0.1
报告运行时错误:
RuntimeError: result type Float can't be cast to the desired output type Long
但是当我将params[var] *= 0.1
更改为params[var] = params[var] * 0.1
时,错误消失了。
为什么会这样?
我以为params[var] *= 0.1
和params[var] = params[var] * 0.1
的效果是一样的。
首先,让我们知道densenet201
中的第一个long-type参数,你会发现features.norm0.num_batches_tracked
表示训练期间mini-batches的数量,用于计算如果模型中有 BatchNormalization 层,则为均值和方差。 This parameter is a long-type number and cannot be float type because it behaves like a counter
.
其次,在PyTorch中,有两种类型的操作:
- Non-Inplace 操作: 您将计算后的新输出分配给变量的新副本,例如x = x + 1 或 x = x / 2. x 赋值前的内存位置不等于赋值后的内存位置,因为你有原始变量的副本。
- 就地操作:当计算直接应用于变量的原始副本而不在此处进行任何复制时,例如x += 1 或 x /= 2.
让我们转到您的示例以了解发生了什么:
Non-Inplcae操作:
model = torchvision.models.densenet201(num_classes=10)
params = model.state_dict()
name = 'features.norm0.num_batches_tracked'
print(id(params[name])) # 140247785908560
params[name] = params[name] + 0.1
print(id(params[name])) # 140247785908368
print(params[name].type()) # changed to torch.FloatTensor
就地操作:
print(id(params[name])) # 140247785908560
params[name] += 1
print(id(params[name])) # 140247785908560
print(params[name].type()) # still torch.LongTensor
params[name] += 0.1 # you want to change the original copy type to float ,you got an error
最后说几点:
- In-place 操作可以节省一些内存,但在计算导数时可能会出现问题,因为会立即丢失历史记录。因此,不鼓励使用它们。 Source
- 决定使用 in-place 操作时应谨慎,因为它们会覆盖原始内容。
- 如果用pandas,这个有点类似于pandas中的
inplace=True
:).
这是阅读更多关于 in-place 操作 source and read also this discussion source 的好资源。
当我运行下面的代码时:
import torchvision
model = torchvision.models.densenet201(num_classes=10)
params = model.state_dict()
for var in params:
params[var] *= 0.1
报告运行时错误:
RuntimeError: result type Float can't be cast to the desired output type Long
但是当我将params[var] *= 0.1
更改为params[var] = params[var] * 0.1
时,错误消失了。
为什么会这样?
我以为params[var] *= 0.1
和params[var] = params[var] * 0.1
的效果是一样的。
首先,让我们知道densenet201
中的第一个long-type参数,你会发现features.norm0.num_batches_tracked
表示训练期间mini-batches的数量,用于计算如果模型中有 BatchNormalization 层,则为均值和方差。 This parameter is a long-type number and cannot be float type because it behaves like a counter
.
其次,在PyTorch中,有两种类型的操作:
- Non-Inplace 操作: 您将计算后的新输出分配给变量的新副本,例如x = x + 1 或 x = x / 2. x 赋值前的内存位置不等于赋值后的内存位置,因为你有原始变量的副本。
- 就地操作:当计算直接应用于变量的原始副本而不在此处进行任何复制时,例如x += 1 或 x /= 2.
让我们转到您的示例以了解发生了什么:
Non-Inplcae操作:
model = torchvision.models.densenet201(num_classes=10) params = model.state_dict() name = 'features.norm0.num_batches_tracked' print(id(params[name])) # 140247785908560 params[name] = params[name] + 0.1 print(id(params[name])) # 140247785908368 print(params[name].type()) # changed to torch.FloatTensor
就地操作:
print(id(params[name])) # 140247785908560 params[name] += 1 print(id(params[name])) # 140247785908560 print(params[name].type()) # still torch.LongTensor params[name] += 0.1 # you want to change the original copy type to float ,you got an error
最后说几点:
- In-place 操作可以节省一些内存,但在计算导数时可能会出现问题,因为会立即丢失历史记录。因此,不鼓励使用它们。 Source
- 决定使用 in-place 操作时应谨慎,因为它们会覆盖原始内容。
- 如果用pandas,这个有点类似于pandas中的
inplace=True
:).
这是阅读更多关于 in-place 操作 source and read also this discussion source 的好资源。