PyTorch dist.broadcast not working - causes the broadcasting rank to hang

I am currently running an experiment similar to the tutorial here: https://pytorch.org/tutorials/intermediate/dist_tuto.html, where I am trying to broadcast an updated tensor from one rank to the others. Here is the code:

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def run(rank, size):
    tensor = torch.zeros(1)
    if rank == 0:
        tensor += 1
        # Send the tensor to process 1
        # line causing problem below
        dist.broadcast(tensor=tensor, src=0)
        dist.barrier()
    print('Rank ', rank, ' has data ', tensor[0])

def init_process(rank, size, fn, backend='gloo'):
    """ Initialize the distributed environment. """
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group(backend, rank=rank, world_size=size)
    fn(rank, size)

if __name__ == "__main__":
    size = 5
    processes = []
    mp.set_start_method("spawn")
    for rank in range(size):
        p = mp.Process(target=init_process, args=(rank, size, run))
        p.start()
        processes.append(p)
    for p in processes:
        p.join()

Now, ideally, the print statement at the end of run would print "Rank n has data tensor(1.)" for every rank n. But when I actually run it, the rank 0 process stops as soon as it reaches the broadcast line, and the remaining ranks print "Rank n has data tensor(0.)", meaning the update was never received. I am not sure what is causing this, so any help would be appreciated. My PyTorch version is 1.9.0.

Move the barrier outside the if: dist.barrier() waits for all processes in the group to reach it, which the other ranks never do. The same reasoning applies to dist.broadcast, which is also a collective operation: every rank has to call it (rank 0 as the sender, the others as receivers), so it belongs outside the if as well. Otherwise rank 0 blocks forever waiting for the rest of the group.
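
For illustration, here is a minimal sketch of a corrected run function, with the rest of the question's script unchanged: only the tensor update stays behind the rank check, while the collective calls are executed by every rank.

import torch
import torch.distributed as dist

def run(rank, size):
    tensor = torch.zeros(1)
    if rank == 0:
        # only the update itself is rank-specific
        tensor += 1
    # collective operations: every rank in the group must call them;
    # rank 0 sends its tensor, all other ranks receive into theirs
    dist.broadcast(tensor=tensor, src=0)
    dist.barrier()
    print('Rank ', rank, ' has data ', tensor[0])

Run this way, every rank should print "Rank n has data tensor(1.)".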