数据并行大小中全局批处理的增加导致 OOM 错误

Increased Global Batch in Data Parallelism size Causes OOM Error

在 ImageNet 数据集上训练 AlexNet 模型时,随着 GPU 数量的增加,我正在增加批量大小。当我收到 OOM 错误时,它可以正常工作到 4096。我从 4 个 GPU 上的批量大小 1024 开始,然后是 8 个 GPU 上的 2048。然而,当我在 16 个 GPU 上尝试 4096 时,我得到了 OOM。理想情况下,这不应该发生,因为在数据并行性中,每个 GPU 的样本保持不变。我正在使用 ChainerMN 进行培训。

终于想通了。不要随着 GPU 数量的增加而增加批量大小。如果您将批量大小设置为 32,则每个 GPU 的批量大小将为 32。