In the PyTorch Distributed Data Parallel (DDP) tutorial, how does `setup` know its rank?
In the tutorial "Getting Started with Distributed Data Parallel", how does the `setup()` function know the rank when `mp.spawn()` doesn't pass it?
```python
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'

    # initialize the process group
    dist.init_process_group("gloo", rank=rank, world_size=world_size)


def demo_basic(rank, world_size):
    print(f"Running basic DDP example on rank {rank}.")
    setup(rank, world_size)
    # .......


def run_demo(demo_fn, world_size):
    mp.spawn(demo_fn,
             args=(world_size,),
             nprocs=world_size,
             join=True)


if __name__ == "__main__":
    n_gpus = torch.cuda.device_count()
    assert n_gpus >= 2, f"Requires at least 2 GPUs to run, but got {n_gpus}"
    world_size = n_gpus
    run_demo(demo_basic, world_size)
```
`mp.spawn` does pass the rank to the function it calls.
From the `torch.multiprocessing.spawn` documentation:
> `torch.multiprocessing.spawn(fn, args=(), nprocs=1, join=True, daemon=False, start_method='spawn')`
>
> ...
>
> fn (function) – Function is called as the entrypoint of the spawned process. This function must be defined at the top level of a module so it can be pickled and spawned. This is a requirement imposed by multiprocessing. The function is called as `fn(i, *args)`, where `i` is the process index and `args` is the passed through tuple of arguments.
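So when `spawn` calls `fn`, it passes the process index as the first argument. In the tutorial, `fn` is `demo_basic`, so its `rank` parameter receives that index (0 through `world_size - 1`, since `nprocs=world_size`), while `world_size` comes from `args=(world_size,)`. `demo_basic` then forwards both values to `setup(rank, world_size)`.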
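To see why `demo_basic` receives `rank` even though `args=(world_size,)` contains only `world_size`, the calling convention can be modeled without GPUs or torch. The sketch below uses a hypothetical `fake_spawn` helper that simply calls `fn(i, *args)` in a loop inside the current process. It is not the real `mp.spawn`: the real function launches `nprocs` separate processes and does not hand back the callbacks' return values, so this only illustrates how the arguments flow.

```python
from typing import Any, Callable, List, Tuple


def fake_spawn(fn: Callable[..., Any], args: Tuple[Any, ...] = (),
               nprocs: int = 1) -> List[Any]:
    """Hypothetical single-process model of mp.spawn's calling convention.

    Unlike torch.multiprocessing.spawn, no processes are created; each call
    runs sequentially here, and the results are collected for inspection.
    """
    results = []
    for i in range(nprocs):
        # i is the process index; it is prepended to the user-supplied args tuple
        results.append(fn(i, *args))
    return results


def demo_basic(rank: int, world_size: int) -> str:
    # rank is the index supplied by the spawner; world_size comes from args
    return f"Running basic DDP example on rank {rank} of {world_size}."


if __name__ == "__main__":
    world_size = 2
    for line in fake_spawn(demo_basic, args=(world_size,), nprocs=world_size):
        print(line)
```

With `world_size = 2`, this prints one line for rank 0 and one for rank 1, mirroring what each spawned process in the tutorial would print.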