使用 TorchScript 类作为 pytorch 模块中的成员



我正在尝试使一些现有的pytorch模型支持TorchScript jit编译器,但是我遇到了非原始类型的成员的问题。

这个小例子说明了这个问题:

import torch
@torch.jit.script
class Factory(object):
def __init__(self):
pass
def create(self, x: float) -> torch.Tensor:
return torch.tensor([x])
class Foo(torch.nn.Module):
def __init__(self):
super(Foo, self).__init__()
self.factory: Factory = Factory()
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.factory.create(0)
mod = torch.jit.script(Foo())

运行时,jit 编译器给出错误

RuntimeError:
module has no attribute 'factory':
at example.py:17:15
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.factory.create(0)
~~~~~~~~~~~~ <--- HERE

我已经测试过Factory类可用于forward方法中的 jit,但是当我将其存储为成员时,它不会确认它。这是为什么呢?有什么方法可以让 jit 编译器将这种成员保存到编译的模块中?

这是PyTorch 中的一个错误,在您发布问题后不久就解决了:https://discuss.pytorch.org/t/jit-scripted-attributes-inside-module/60645,https://github.com/pytorch/pytorch/issues/27495。

更新 PyTorch 应该可以修复它。

最新更新