PyTorch内存泄漏引用周期在for循环



当使用PyTorch mps接口在Mac M1 GPU上迭代更新PyTorch中的张量时,我面临内存泄漏。以下是复制该行为的最小可复制示例:

import torch 
def leak_example(p1, device):

t1 = torch.rand_like(p1, device = device) # torch.cat((torch.diff(ubar.detach(), dim=0).detach().clone(), torch.zeros_like(ubar.detach()[:1,:,:,:], dtype = torch.float32)), dim = 0)
u1 = p1.detach() + 2 * (t1.detach())

B = torch.rand_like(u1, device = device)
mask = u1 < B

a1 = u1.detach().clone()
a1[~mask] = torch.rand_like(a1)[~mask]
return a1
if torch.cuda.is_available(): # cuda gpus
device = torch.device("cuda")
elif torch.backends.mps.is_available(): # mac gpus
device = torch.device("mps")
torch.set_grad_enabled(False)

p1 = torch.rand(5, 5, 224, 224, device = device)
for i in range(10000):
p1 = leak_example(p1, device)    

当我执行这个循环时,我Mac的GPU内存稳步增长。我试过在谷歌Colab的CUDA GPU上运行它,它的行为似乎类似,GPU的活动内存,不可释放内存和分配内存随着循环的进行而增加。

我试过分离和克隆张量,并使用弱参考,但无济于事。有趣的是,如果我不把leak_example的输出重新赋值给p1,这个行为就消失了,所以它看起来真的和递归赋值有关。有人知道我该怎么解决这个问题吗?

我想我找到泄漏的原因了,是掩码赋值。用等效的torch.where()语句替换它可以使泄漏消失。我想这与masked_scatter没有在PyTorch中实现MPS支持有关(尚未)?

相关内容

  • 没有找到相关文章

最新更新