Pytorch:在minibatch中只对一部分样本进行反向传递时没有加速-为什么?



我有一个函数(见下文),它修改了损失函数,以便它只返回最小损失的minibatch中K个样本的损失。我们的想法是在每个优化步骤中关注这些示例。

所以我首先做一个前向传递来获得小批量中每个样本的损失值,然后通过fn调整损失。"get_adapted_loss_for_minibatch"。

由于适应损失只考虑了minibatch中一定比例的样本(我目前使用的是60%的样本),因此我期望在训练期间也能获得可测量的加速,因为只需要对minibatch中的一小部分样本进行后退。

但不幸的是,情况并非如此,当我在minibatch中使用所有样本时,训练所需的时间几乎相同(因此,当我不适应损失时)。我使用的是"densenet121"网络,训练是在CIFAR-100上完成的。

我做错了什么吗?我应该在小批量中手动禁用一些样品的自动渐变吗?我以为' topk '函数会自动完成。

def get_adapted_loss_for_minibatch(loss):
# Returns the loss containing only the samples of the mini-batch with the _lowest_ loss
# Parameter 'loss' must be a vector containing the per-sample loss for all samples in the (original) minibatch
minibatch_size = loss.size()[0]
r = 0.6 * minibatch_size
# round r to integer, safeguard if r is 0
r = max(round(r), 1)
# The 'topk' function returns the loss for the 'r' samples with the _lowest_ loss in the minibtach
# See documentation at https://pytorch.org/docs/stable/generated/torch.topk.html
# Note the 'topk' operation is differentiable, see https://stackoverflow.com/questions/67570529/derive-the-gradient-through-torch-topk
# and https://math.stackexchange.com/questions/4146359/derivative-for-masked-matrix-hadamard-multiplication
loss_adapted = torch.topk(loss, r, largest = False, sorted = False, dim = 0)[0]
# return it
return loss_adapted

您在训练速度上没有看到差异的原因是您正在使用批处理规范化。反过来,这意味着你的渐变仍然依赖于整个批,即使您只使用批处理内容的一部分来计算最终损失项并反向传播。

从数学上讲,在每个批归一化层中测量的运行统计数据将涉及批中的所有元素。

如果你看一下平均计算(当然批规范也涉及标准偏差测量)。直观地说,当你用给定向量的平均值归一化时,结果向量的元素将取决于初始向量的所有元素,因为所有元素被精确地用于计算平均值。

如果你想了解更多,你可以阅读这篇关于通过x / x.mean(0)反向传播的文章。


根据我们的讨论,这里是用GroupNorm替换BatchNorm2d层的一种方法。遍历网络的子模块,然后用新的初始化的GroupNorm实例替换BatchNorm2d的所有实例。我将给你一个Bottleneck:

的例子
>>> net = Bottleneck(10, 2)
>>> for name, module in net.named_children():
...   if isinstance(module, nn.BatchNorm2d):
...     setattr(net, name, 
...        nn.GroupNorm(num_channels=module.num_features, num_groups=num_groups))

这样net看起来就像:

>>> net
Bottleneck(
(bn1): GroupNorm(2, 10, eps=1e-05, affine=True)
(conv1): Conv2d(10, 8, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn2): GroupNorm(2, 8, eps=1e-05, affine=True)
(conv2): Conv2d(8, 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
)

相关内容

  • 没有找到相关文章

最新更新