Pytorch:为什么在 DDP 中日志记录失败?

Pytorch: why logging fails in DDP?

我想在分布式数据并行管理的进程之一中使用日志记录。但是,日志记录在以下代码中不打印任何内容(代码源自 this tutorial):

#!/usr/bin/python

import os, logging
# logging.basicConfig(level=logging.DEBUG)

import torch

def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'

    # Initialize the process group.
    dist.init_process_group('NCCL', rank=rank, world_size=world_size)


def cleanup():
    dist.destroy_process_group()


class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = nn.Linear(10, 10)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(10, 5)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))


def demo_basic(rank, world_size):
    setup(rank, world_size)

    if rank == 0:
        logger = logging.getLogger('train')
        logger.setLevel(logging.DEBUG)
        logger.info(f'Running DPP on rank={rank}.')

    # Create model and move it to GPU.
    model = ToyModel().to(rank)
    ddp_model = DDP(model, device_ids=[rank])

    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)  # optimizer takes DDP model.

    optimizer.zero_grad()
    inputs = torch.randn(20, 10)  # .to(rank)

    outputs = ddp_model(inputs)

    labels = torch.randn(20, 5).to(rank)
    loss_fn(outputs, labels).backward()

    optimizer.step()

    cleanup()


def run_demo(demo_func, world_size):
    mp.spawn(
        demo_func,
        args=(world_size,),
        nprocs=world_size,
        join=True
    )


def main():
    run_demo(demo_basic, 4)


if __name__ == "__main__":
    main()

但是,当我取消注释第 4 行时,日志记录有效。我可以知道原因以及如何修复该错误吗?

更新

让我们简要回顾一下 logging 模块中的记录器是如何工作的。

记录器以树形结构组织,即每个记录器都有一个唯一的 parent 记录器。默认情况下,它将是 root 记录器,而 root 记录器没有 parent 记录器。

当您在记录器上调用 Logger.info 方法(为简单起见忽略此处的级别检查)时,记录器会迭代其所有 handlers 并让它们处理当前记录,例如处理程序可以是可以打印到标准输出的 StreamHandler,也可以是打印到某个文件的 FileHandler)。当前记录器的所有处理程序完成其工作后,记录将被提供给它的 parent 记录器,并且 parent 记录器以相同的方式处理记录,即迭代所有处理程序 [=62] =] logger 并让他们处理记录,最后将记录传递给“grandparent”。此过程一直持续到到达当前记录器树的根,该树没有 parent.

检查下面的实现或here:

def callHandlers(self, record):
    c = self
    found = 0
    while c:
        for hdlr in c.handlers:
            found = found + 1
            if record.levelno >= hdlr.level:
                hdlr.handle(record)
        if not c.propagate:
            c = None    #break out
        else:
            c = c.parent

所以在您的情况下,您没有为 train 记录器指定任何处理程序。当您取消注释第 6 行时,即通过调用 logging.basicConfig(level=logging.DEBUG),将为 root 记录器创建一个 StreamHandler。尽管 train 记录器没有任何处理程序,但它的 parent 有一个 StreamHandler,即 root 记录器,它打印您实际看到的任何内容,而 train 记录器在这种情况下不打印任何内容。当注释的第 6 行时,甚至没有为 root 处理程序创建一个 StreamHandler,因此在这种情况下不会打印任何内容。所以其实这个问题与DDP无关。

顺便说一句,一开始我无法重现您的问题的原因是因为我使用 PyTorch 1.8,其中 logging.info 将在执行 dist.init_process_group 期间针对 MPI 以外的后端调用,隐式调用 basicConfig,为根记录器创建一个 StreamHandler,并且似乎按预期打印消息。

============================================= =========================

一个可能的原因:因为在执行dist.init_process_group的过程中,会调用_store_based_barrier,最终会调用logging.info(见源码here)。所以如果你在调用dist.init_process_group之前调用logging.basicConfig,它会被提前初始化,这使得根记录器忽略所有级别的日志。

在您的代码中不是这种情况,因为 logging.basicConfig 位于文件的顶部,它将在 dist.init_process_group 之前首先执行。实际上,在填充 nndist 等缺失的导入之后,我可以 运行 您提供的代码,但日志记录工作正常。也许您试图通过减少代码来重现问题,却在不知不觉中绕过了真正的问题?你能仔细检查一下这是否解决了你的问题吗?

我会为那些像我一样不熟悉 python 的日志记录机制的人澄清我发现的内容。

这个问题不是因为Pytorch,任何第三方包都可能触发这种现象。原因是 logging 是单例。对 basicConfig 的任何修改都会影响我们的代码和来自第三方的代码。

我目前的解决方案是创建一个模块级记录器并在不同模块之间共享它。