我如何检测回调是否在pytorch中触发?



我正在微调BERT模型。首先,我想要冻结图层并训练一下。当某个回调被触发时(假设是ReduceLROnPlateau),我想要解冻图层。我该怎么做呢?

恐怕PyTorch中的学习率调度器不提供钩子。看看这里的ReduceLROnPlateau的实现,当调度程序被触发时,两个属性被重置(i.e.当它识别到平台并降低学习率时):

if self.num_bad_epochs > self.patience:
self._reduce_lr(epoch)
self.cooldown_counter = self.cooldown
self.num_bad_epochs = 0

基于此,您可以包装调度程序步骤调用,并通过检查scheduler.cooldown_counter == scheduler.cooldownscheduler.num_bad_epochs == 0是否为真来确定_reduce_lr是否被触发。

相关内容

  • 没有找到相关文章

最新更新