在 PyTorch 分布式数据并行 (DDP) 教程中,"设置"如何知道它的排名?



关于分布式数据并行入门教程

mp.spawn()没有通过秩时,setup()函数如何知道秩?

def setup(rank, world_size):
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '12355'
# initialize the process group
dist.init_process_group("gloo", rank=rank, world_size=world_size)
def demo_basic(rank, world_size):
print(f"Running basic DDP example on rank {rank}.")
setup(rank, world_size)
.......
def run_demo(demo_fn, world_size):
mp.spawn(demo_fn,
args=(world_size,),
nprocs=world_size,
join=True)
if __name__ == "__main__":
n_gpus = torch.cuda.device_count()
assert n_gpus >= 2, f"Requires at least 2 GPUs to run, but got {n_gpus}"
world_size = n_gpus
run_demo(demo_basic, world_size)

mp.spawn确实将秩传递给它调用的函数。

来自torch.multiprocessing.spawn文档

torch.multiprocessing.spawn(fn,args=((,nprocs=1,join=True,daemon=False,start_method='spawn'(

  • fn(函数(-

    函数被调用为派生进程的入口点。这函数必须在模块的顶层定义,这样它才能腌制和繁殖。这是多处理所提出的要求函数被称为fn(i, *args),其中i是进程索引并且CCD_ 7是参数的传递元组

因此,当spawn调用fn时,它会将流程索引作为第一个参数传递给它。

最新更新