If you have tensor arrays of different lengths across multiple GPU ranks, the default all_gather method does not work, as it requires the lengths to be the same.
For example, if you have:

if gpu == 0:
    q = torch.tensor([1.5, 2.3], device=torch.device(gpu))
else:
    q = torch.tensor([5.3], device=torch.device(gpu))
and you need to gather these two tensor arrays as follows:

all_q = [torch.tensor([1.5, 2.3]), torch.tensor([5.3])]

the default torch.distributed.all_gather does not work, because the lengths (2 and 1) differ.
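To see why, recall that dist.all_gather writes into a pre-allocated output list whose tensors must match the input's shape on every rank. A minimal sketch of the stock usage, assuming the process group is already initialized:

import torch
import torch.distributed as dist

# Assumes dist.init_process_group(...) has already run (e.g. nccl backend).
ws = dist.get_world_size()
device = torch.device("cuda", dist.get_rank())

# Stock usage: every rank must contribute a tensor of the SAME shape,
# and the output list is pre-allocated to that common shape.
q = torch.tensor([1.5, 2.3], device=device)     # shape (2,) on every rank
out = [torch.zeros_like(q) for _ in range(ws)]  # ws buffers, each shape (2,)
dist.all_gather(out, q)                         # fine: shapes agree everywhere

# With ragged inputs (shape (2,) on rank 0, shape (1,) on rank 1) there is no
# single common shape to pre-allocate, so this call cannot be used directly.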
Since we cannot gather directly with the built-in method, we need to write a custom function with the following steps:

- use dist.all_gather to get the sizes of all arrays
- find the max size
- pad the local array to the max size using zeros/constants
- use dist.all_gather to get all the padded arrays
- unpad the added zeros/constants using the sizes found in step 1

The function below does this:
import torch
import torch.distributed as dist

def all_gather(q, ws, device):
    """
    Gathers tensor arrays of different lengths across multiple gpus

    Parameters
    ----------
    q : tensor array
    ws : world size
    device : current gpu device

    Returns
    -------
    all_q : list of gathered tensor arrays from all the gpus
    """
    # Step 1: share each rank's length so every rank knows all the sizes
    local_size = torch.tensor(q.size(), device=device)
    all_sizes = [torch.zeros_like(local_size) for _ in range(ws)]
    dist.all_gather(all_sizes, local_size)
    max_size = max(all_sizes)

    # Step 2: zero-pad the local array up to the maximum length
    size_diff = max_size.item() - local_size.item()
    if size_diff:
        padding = torch.zeros(size_diff, device=device, dtype=q.dtype)
        q = torch.cat((q, padding))

    # Step 3: gather the now equally-sized padded arrays
    all_qs_padded = [torch.zeros_like(q) for _ in range(ws)]
    dist.all_gather(all_qs_padded, q)

    # Step 4: trim each gathered array back to its true length
    all_qs = []
    for q, size in zip(all_qs_padded, all_sizes):
        all_qs.append(q[:size])
    return all_qs
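For reference, here is a minimal sketch of how this function might be driven end to end on two GPUs; the rendezvous details (tcp:// address, port, nccl backend) are assumptions and will vary with your launcher:

import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def worker(gpu, ws):
    # Assumed rendezvous setup; adjust address/port/backend for your environment.
    dist.init_process_group(
        "nccl", init_method="tcp://127.0.0.1:29500", rank=gpu, world_size=ws
    )
    device = torch.device(gpu)
    if gpu == 0:
        q = torch.tensor([1.5, 2.3], device=device)
    else:
        q = torch.tensor([5.3], device=device)
    all_q = all_gather(q, ws, device)
    if gpu == 0:
        print(all_q)  # [tensor([1.5, 2.3], ...), tensor([5.3], ...)]
    dist.destroy_process_group()

if __name__ == "__main__":
    mp.spawn(worker, args=(2,), nprocs=2)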
Once we are able to do the above, we can then easily use torch.cat to further concatenate into a single array if needed:

torch.cat(all_q)

tensor([1.5, 2.3, 5.3])
Adapted from: github

Here is an extension of @omsrisagar's solution that supports tensors of any number of dimensions (not only 1-dimensional tensors).
def all_gather_nd(tensor):
    """
    Gathers tensor arrays of different lengths in a list.
    The length dimension is 0. This supports any number of extra dimensions in the tensors.
    All the other dimensions should be equal between the tensors.

    Args:
        tensor (Tensor): Tensor to be broadcast from current process.

    Returns:
        (Tensor): output list of tensors that can be of different sizes
    """
    world_size = dist.get_world_size()
    # Share the full shape of every rank's tensor
    local_size = torch.tensor(tensor.size(), device=tensor.device)
    all_sizes = [torch.zeros_like(local_size) for _ in range(world_size)]
    dist.all_gather(all_sizes, local_size)

    # Pad only along dim 0, up to the largest length across ranks
    max_length = max(size[0] for size in all_sizes)
    length_diff = max_length.item() - local_size[0].item()
    if length_diff:
        pad_size = (length_diff, *tensor.size()[1:])
        padding = torch.zeros(pad_size, device=tensor.device, dtype=tensor.dtype)
        tensor = torch.cat((tensor, padding))

    # Gather the padded tensors, then trim each back to its true length
    all_tensors_padded = [torch.zeros_like(tensor) for _ in range(world_size)]
    dist.all_gather(all_tensors_padded, tensor)
    all_tensors = []
    for tensor_, size in zip(all_tensors_padded, all_sizes):
        all_tensors.append(tensor_[:size[0]])
    return all_tensors
Note that this requires all the tensors to have the same number of dimensions, and that all their dimensions except the first be equal.
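As an illustration, inside a worker like the one sketched above, ragged 2-D batches could be gathered like this (the shapes are illustrative; only dim 0 may differ across ranks):

# Rank 0 holds 3 feature rows, rank 1 holds 1; both have 4 columns.
if gpu == 0:
    t = torch.randn(3, 4, device=device)
else:
    t = torch.randn(1, 4, device=device)

gathered = all_gather_nd(t)          # [tensor of shape (3, 4), tensor of shape (1, 4)]
merged = torch.cat(gathered, dim=0)  # single tensor of shape (4, 4)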