我有一个函数(见下文),它修改了损失函数,以便它只返回最小损失的minibatch中K个样本的损失。我们的想法是在每个优化步骤中关注这些示例。
所以我首先做一个前向传递来获得小批量中每个样本的损失值,然后通过fn调整损失。"get_adapted_loss_for_minibatch"。
由于适应损失只考虑了minibatch中一定比例的样本(我目前使用的是60%的样本),因此我期望在训练期间也能获得可测量的加速,因为只需要对minibatch中的一小部分样本进行后退。
但不幸的是,情况并非如此,当我在minibatch中使用所有样本时,训练所需的时间几乎相同(因此,当我不适应损失时)。我使用的是"densenet121"网络,训练是在CIFAR-100上完成的。
我做错了什么吗?我应该在小批量中手动禁用一些样品的自动渐变吗?我以为' topk '函数会自动完成。
def get_adapted_loss_for_minibatch(loss):
# Returns the loss containing only the samples of the mini-batch with the _lowest_ loss
# Parameter 'loss' must be a vector containing the per-sample loss for all samples in the (original) minibatch
minibatch_size = loss.size()[0]
r = 0.6 * minibatch_size
# round r to integer, safeguard if r is 0
r = max(round(r), 1)
# The 'topk' function returns the loss for the 'r' samples with the _lowest_ loss in the minibtach
# See documentation at https://pytorch.org/docs/stable/generated/torch.topk.html
# Note the 'topk' operation is differentiable, see https://stackoverflow.com/questions/67570529/derive-the-gradient-through-torch-topk
# and https://math.stackexchange.com/questions/4146359/derivative-for-masked-matrix-hadamard-multiplication
loss_adapted = torch.topk(loss, r, largest = False, sorted = False, dim = 0)[0]
# return it
return loss_adapted
您在训练速度上没有看到差异的原因是您正在使用批处理规范化。反过来,这意味着你的渐变仍然依赖于整个批,即使您只使用批处理内容的一部分来计算最终损失项并反向传播。
从数学上讲,在每个批归一化层中测量的运行统计数据将涉及批中的所有元素。
如果你看一下平均计算(当然批规范也涉及标准偏差测量)。直观地说,当你用给定向量的平均值归一化时,结果向量的元素将取决于初始向量的所有元素,因为所有元素被精确地用于计算平均值。
如果你想了解更多,你可以阅读这篇关于通过x / x.mean(0)
反向传播的文章。
根据我们的讨论,这里是用GroupNorm
替换BatchNorm2d
层的一种方法。遍历网络的子模块,然后用新的初始化的GroupNorm
实例替换BatchNorm2d
的所有实例。我将给你一个Bottleneck
:
>>> net = Bottleneck(10, 2)
>>> for name, module in net.named_children():
... if isinstance(module, nn.BatchNorm2d):
... setattr(net, name,
... nn.GroupNorm(num_channels=module.num_features, num_groups=num_groups))
这样net
看起来就像:
>>> net
Bottleneck(
(bn1): GroupNorm(2, 10, eps=1e-05, affine=True)
(conv1): Conv2d(10, 8, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn2): GroupNorm(2, 8, eps=1e-05, affine=True)
(conv2): Conv2d(8, 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
)