嗨~我现在使用来自https://github.com/KaiyangZhou/pytorch-center-loss的实现,中心被初始化为self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim))
。我很困惑为什么这种初始化保证最终中心是给定的特征/嵌入在某个类的中心?
我尝试了这个中心损失如上所述,但困惑为什么它在理论上工作?
因为随机初始化不会改变最终结果。它为反向传播提供了一个方向,即嵌入点会越来越靠近中心,这与你选择中心的哪个位置无关。它最终会达到目的的。
这发生在训练阶段。
在测试阶段,使用您训练的特征提取模型,而根本不使用中心位置信息。