Child 使用 PyTorch 模型执行推理时进程挂起
Child process hangs when performing inference with PyTorch model
我有一个 PyTorch 模型 (class Net
),连同它保存的权重/状态字典 (net.pth
),我想在多处理环境中执行推理。
我注意到我不能简单地创建一个模型实例,加载权重,然后与 child 进程共享模型(尽管我认为这是可能的,因为 copy-on-write) . child 在 y = model(x)
上挂起,最终整个程序挂起(由于 parent 的 waitpid
)。
以下是可重现的最小示例:
def handler():
with torch.no_grad():
x = torch.rand(1, 3, 32, 32)
y = model(x)
return y
model = Net()
model.load_state_dict(torch.load("./net.pth"))
pid = os.fork()
if pid == 0:
# this doesn't get printed as handler() hangs for the child process
print('child:', handler())
else:
# everything is fine here
print('parent:', handler())
os.waitpid(pid, 0)
如果 parent 和 child 的模型加载是独立完成的,即没有共享,那么一切都会按预期进行。我也试过在模型的张量上调用 share_memory_
,但无济于事。
我是不是做错了什么?
似乎共享状态字典并在每个进程中执行加载操作解决了问题:
LOADED = False
def handler():
global LOADED
if not LOADED:
# each process loads state independently
model.load_state_dict(state)
LOADED = True
with torch.no_grad():
x = torch.rand(1, 3, 32, 32)
y = model(x)
return y
model = Net()
# share the state rather than loading the state dict in parent
# model.load_state_dict(torch.load("./net.pth"))
state = torch.load("./net.pth")
pid = os.fork()
if pid == 0:
print('child:', handler())
else:
print('parent:', handler())
os.waitpid(pid, 0)
我有一个 PyTorch 模型 (class Net
),连同它保存的权重/状态字典 (net.pth
),我想在多处理环境中执行推理。
我注意到我不能简单地创建一个模型实例,加载权重,然后与 child 进程共享模型(尽管我认为这是可能的,因为 copy-on-write) . child 在 y = model(x)
上挂起,最终整个程序挂起(由于 parent 的 waitpid
)。
以下是可重现的最小示例:
def handler():
with torch.no_grad():
x = torch.rand(1, 3, 32, 32)
y = model(x)
return y
model = Net()
model.load_state_dict(torch.load("./net.pth"))
pid = os.fork()
if pid == 0:
# this doesn't get printed as handler() hangs for the child process
print('child:', handler())
else:
# everything is fine here
print('parent:', handler())
os.waitpid(pid, 0)
如果 parent 和 child 的模型加载是独立完成的,即没有共享,那么一切都会按预期进行。我也试过在模型的张量上调用 share_memory_
,但无济于事。
我是不是做错了什么?
似乎共享状态字典并在每个进程中执行加载操作解决了问题:
LOADED = False
def handler():
global LOADED
if not LOADED:
# each process loads state independently
model.load_state_dict(state)
LOADED = True
with torch.no_grad():
x = torch.rand(1, 3, 32, 32)
y = model(x)
return y
model = Net()
# share the state rather than loading the state dict in parent
# model.load_state_dict(torch.load("./net.pth"))
state = torch.load("./net.pth")
pid = os.fork()
if pid == 0:
print('child:', handler())
else:
print('parent:', handler())
os.waitpid(pid, 0)