在 pytorch 中追溯已弃用的警告



我在这里使用以下代码在我的数据上训练 yolov3: https://github.com/cfotache/pytorch_custom_yolo_training/

但是我收到这个烦人的弃用警告

Warning: indexing with dtype torch.uint8 is now deprecated, please use a dtype torch.bool instead. (expandTensors at /pytorch/aten/src/ATen/native/IndexingUtils.h:20)

我尝试使用python3 -W ignore train.py我尝试添加:

import warnings
warnings.filterwarnings('ignore')

但警告仍然存在。

我在堆栈溢出上找到了这段代码,该代码在警告上打印了该堆栈,

import traceback
import warnings
import sys
def warn_with_traceback(message, category, filename, lineno, file=None, line=None):
log = file if hasattr(file,'write') else sys.stderr
traceback.print_stack(file=log)
log.write(warnings.formatwarning(message, category, filename, lineno, line))
warnings.showwarning = warn_with_traceback

这是我得到的:

File "/content/pytorch_custom_yolo_training/train.py", line 102, in <module>
loss = model(imgs, targets)
File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 532, in __call__
result = self.forward(*input, **kwargs)
File "/content/pytorch_custom_yolo_training/models.py", line 267, in forward
x, *losses = module[0](x, targets)
File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 532, in __call__
result = self.forward(*input, **kwargs)
File "/content/pytorch_custom_yolo_training/models.py", line 203, in forward
loss_x = self.mse_loss(x[mask], tx[mask])
File "/usr/lib/python3.6/warnings.py", line 99, in _showwarnmsg
msg.file, msg.line)
File "/content/pytorch_custom_yolo_training/train.py", line 29, in warn_with_traceback
traceback.print_stack(file=log)
/pytorch/aten/src/ATen/native/IndexingUtils.h:20: UserWarning: indexing with dtype torch.uint8 is now deprecated, please use a dtype torch.bool instead.

转到堆栈中提到的文件和函数,我找不到任何uint8。 我该如何解决问题,甚至停止收到这些警告?

发现问题。 行 :loss_x = self.mse_loss(x[mask], tx[mask])mask变量是一个已弃用的ByteTensor。只是用BoolTensor代替了它

这适用于我的情况:添加

obj_mask = obj_mask.type(torch.BoolTensor)
noobj_mask = noobj_mask.type(torch.BoolTensor)

以前

loss_x = self.mse_loss(x[obj_mask], tx[obj_mask])

in models.py.

这工作正常

obj_mask = obj_mask.bool()
noobj_mask = noobj_mask.bool()

相关内容

  • 没有找到相关文章

最新更新