Child 使用 PyTorch 模型执行推理时进程挂起

Child process hangs when performing inference with PyTorch model

我有一个 PyTorch 模型 (class Net),连同它保存的权重/状态字典 (net.pth),我想在多处理环境中执行推理。

我注意到我不能简单地创建一个模型实例,加载权重,然后与 child 进程共享模型(尽管我认为这是可能的,因为 copy-on-write) . child 在 y = model(x) 上挂起,最终整个程序挂起(由于 parent 的 waitpid)。

以下是可重现的最小示例:

def handler():
    with torch.no_grad():
        x = torch.rand(1, 3, 32, 32)
        y = model(x)

    return y


model = Net()
model.load_state_dict(torch.load("./net.pth"))

pid = os.fork()

if pid == 0:
    # this doesn't get printed as handler() hangs for the child process
    print('child:', handler())
else:
    # everything is fine here
    print('parent:', handler())
    os.waitpid(pid, 0)

如果 parent 和 child 的模型加载是独立完成的,即没有共享,那么一切都会按预期进行。我也试过在模型的张量上调用 share_memory_,但无济于事。

我是不是做错了什么?

似乎共享状态字典并在每个进程中执行加载操作解决了问题:

LOADED = False 

def handler():
    global LOADED
    if not LOADED:
        # each process loads state independently
        model.load_state_dict(state)
        LOADED = True

    with torch.no_grad():
        x = torch.rand(1, 3, 32, 32)
        y = model(x)

    return y


model = Net()

# share the state rather than loading the state dict in parent
# model.load_state_dict(torch.load("./net.pth"))
state = torch.load("./net.pth")

pid = os.fork()

if pid == 0:
    print('child:', handler())
else:
    print('parent:', handler())
    os.waitpid(pid, 0)