Pytorch DataParallel 与自定义模型

Pytorch DataParallel with custom model

我想用多个 GPU 训练模型。我正在使用以下代码

model = load_model(path)
if torch.cuda.device_count() > 1:
  print("Let's use", torch.cuda.device_count(), "GPUs!")
  # dim = 0 [30, xxx] -> [10, ...], [10, ...], [10, ...] on 3 GPUs
  model = nn.DataParallel(model)

model.to(device)

除了 DataParallel 不包含原始模型中的函数外,它运行良好,是否有解决方法?谢谢

nn.Module passed to nn.DataParallel 最终将被 class 包装以处理数据并行性。您仍然可以使用 module 属性访问您的模型。

>>> p_model = nn.DataParallel(model)
>>> p_model.module # <- model

例如,要访问基础模型的 quantize 属性,您可以:

>>> p_model.module.quantize