来自all_gather的分布式割炬数据冲突(将all_gather结果写入文件"修复"问题)



问题:

  • 分布式进程计算错误并将它们与float索引一起返回
  • 当从不同的列中收集错误时,这些索引会发生冲突
    • 因此,如果数据集有100个样本,GPU的数量为4,则得到的索引集的长度将为25,而不是预期的100
  • 当我将每个秩的数据(预收集)写入文件时,我可以验证索引是否100%不相交
  • 当我将每个等级的数据(收集后)写入文件时,问题就消失了
  • 注释掉后收集调试数据文件的编写,问题返回

注意:打印出采集后的结果也";修复";问题,但对收集后的结果进行排序却没有。

因此,将收集后的数据写入文件可以解决一些分布式的恶作剧。。。有人提醒我需要flush流以避免意外结果,但我在文档中没有看到任何必然结果。

下面是一个显示代码中发生的事情的最小示例:

# setup_distributed_stuff()
rank = torch.distributed.get_rank()
world_size = torch.distributed.get_world_size()
# Data returned from distributed computation.
# Note that there's no overlap between the different ranks.
data = torch.arange(
0 + (rank * 100 // world_size),
(rank + 1) * 100 // world_size,
)
# `data` is confirmed to be disjoint across ranks by writing to file here.
# Gather data from all ranks.
if world_size > 1:
all_data = [torch.zeros_like(data) for _ in range(world_size)]
torch.distributed.all_gather(all_data, data)
data = torch.cat(all_data, dim=0)
# By writing "data" to file for debugging, the problem goes away...
#     i.e. len(set(data.numpy())) == 100!
# If I comment this out, then my gathered data collides...
#     i.e. len(set(data.numpy())) == 100 // world_size
with open("debug_data.pt", "wb") as _file:
torch.save(data, _file)
# I can also simply print the indices and get the same effect...
logger.info(
"Gathered result indices: {}...{}".format(
data[:10, -1], data[-10:, -1]
)
)
# However, sorting the indices doesn't do me any good...
data = data[data[:, -1].argsort(dim=0)]

if rank == 0:
# do_something(data)

all_gather()调用之后添加torch.distributed.barrier()调用以更令人满意的方式修复了问题。我没有想过这么做,因为文档中指出all_gather()是一个阻塞调用。也许它们的意思是阻断,而不是async;与CCD_ 7不同。

我想记录并将结果写入文件";"修复";sort没有的问题是因为前者不是强制同步的火炬操作(因此不由分布式进程组管理)。

最新更新