我有一个更快的R CNN探测器,我用pytorch闪电在一个噪音很大但很大的数据集上训练过它。我预计在1个历元的训练之后,模型将只输出数据集中的标签,在我的情况下是0到56。然而,它给了我诸如64和89之类的标签。这是怎么回事?它从哪里冒出了这些从未受过训练的标签?
无法共享任何代码,因为这个问题可能与我的数据集有关,而不是我的代码。使用COCO预训练模型,它可以很好地工作。
问题不是我的数据或模型。问题是pytorchnn.module.load_state_dict()
方法。该方法有一个参数strict
,它本应允许用户在没有完全相同的权重键的情况下加载state_dict,但它实际上导致加载的模型完全错误。我强烈建议在pytorch中加载带有load_state_dict()
的模型时不要使用strict=False
。