将 Detectron2 模型导出到火炬脚本时"Could not export Python function call '_ScaleGradient'"



在Detectron2训练了一个模块后,我试图将模型导出到TorchScript,然后我得到了以下错误:

无法导出Python函数调用"_ScaleGradient"。删除对Python函数的调用>出口前。你忘了添加@script或@script_method注释了吗?如果这是>nn。ModuleList,将其添加到__constants_

我发现代码在detectron2/modeling/roi_heads/cascade_rcnn.py 中

class _ScaleGradient(Function):
@staticmethod
def forward(ctx, input, scale):
ctx.scale = scale
return input
@staticmethod
def backward(ctx, grad_output):
return grad_output * ctx.scale, None

所以我把@statcmethod annos改为@torch.jit.script_method,之后,我得到了一个"ScriptMethodStub的对象是不可调用的"错误

我不熟悉torchscript,如何解决这个问题?

提前谢谢。

在推理阶段似乎不需要_ScaleGradient方法,所以我只将以下代码添加到cacasde_rcnn.py 中

if self.training:
#call _ScaleGradient.apply
else:
#don't call _ScaleGradient.apply

相关内容

最新更新