问题:
- 分布式进程计算错误并将它们与
float
索引一起返回 - 当从不同的列中收集错误时,这些索引会发生冲突
- 因此,如果数据集有100个样本,GPU的数量为4,则得到的索引集的长度将为25,而不是预期的100
- 当我将每个秩的数据(预收集)写入文件时,我可以验证索引是否100%不相交
- 当我将每个等级的数据(收集后)写入文件时,问题就消失了
- 注释掉后收集调试数据文件的编写,问题返回
注意:打印出采集后的结果也";修复";问题,但对收集后的结果进行排序却没有。
因此,将收集后的数据写入文件可以解决一些分布式的恶作剧。。。有人提醒我需要flush
流以避免意外结果,但我在文档中没有看到任何必然结果。
下面是一个显示代码中发生的事情的最小示例:
# setup_distributed_stuff()
rank = torch.distributed.get_rank()
world_size = torch.distributed.get_world_size()
# Data returned from distributed computation.
# Note that there's no overlap between the different ranks.
data = torch.arange(
0 + (rank * 100 // world_size),
(rank + 1) * 100 // world_size,
)
# `data` is confirmed to be disjoint across ranks by writing to file here.
# Gather data from all ranks.
if world_size > 1:
all_data = [torch.zeros_like(data) for _ in range(world_size)]
torch.distributed.all_gather(all_data, data)
data = torch.cat(all_data, dim=0)
# By writing "data" to file for debugging, the problem goes away...
# i.e. len(set(data.numpy())) == 100!
# If I comment this out, then my gathered data collides...
# i.e. len(set(data.numpy())) == 100 // world_size
with open("debug_data.pt", "wb") as _file:
torch.save(data, _file)
# I can also simply print the indices and get the same effect...
logger.info(
"Gathered result indices: {}...{}".format(
data[:10, -1], data[-10:, -1]
)
)
# However, sorting the indices doesn't do me any good...
data = data[data[:, -1].argsort(dim=0)]
if rank == 0:
# do_something(data)
在all_gather()
调用之后添加torch.distributed.barrier()
调用以更令人满意的方式修复了问题。我没有想过这么做,因为文档中指出all_gather()
是一个阻塞调用。也许它们的意思是阻断,而不是async
;与CCD_ 7不同。
我想记录并将结果写入文件";"修复";sort
没有的问题是因为前者不是强制同步的火炬操作(因此不由分布式进程组管理)。