在脚本类(torch.jit.script)中使用numpy



我想知道我是否可以在将由torch.jit.script编写脚本的函数中使用numpy API。我有这个简单的功能不起作用:

import torch
import torch.nn as nn
class MyModule(nn.Module):
def __init__(self):
super(MyModule, self).__init__()
@torch.jit.ignore
def call_np():
return torch.jit.export(np.random.choice(2, p=[.95,.05]))
def forward(self):
pass 
@torch.jit.export
def func(self):
done = self.call_np()
print (done)

scripted_module = torch.jit.script(MyModule((( scripted_module.func((

这导致:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-133-ab1ce37d6edc> in <module>()
18         print (done)
19 
---> 20 scripted_module = torch.jit.script(MyModule())
21 scripted_module.func()
C:ProgramDataAnaconda3libsite-packagestorchjit__init__.py in script(obj, optimize, _frames_up, _rcb)
1201 
1202     if isinstance(obj, torch.nn.Module):
-> 1203         return torch.jit.torch.jit._recursive.recursive_script(obj)
1204 
1205     qualified_name = _qualified_name(obj)
C:ProgramDataAnaconda3libsite-packagestorchjit_recursive.py in recursive_script(mod, exclude_methods)
171     filtered_methods = filter(ignore_overloaded, methods)
172     stubs = list(map(make_stub, filtered_methods))
--> 173     return copy_to_script_module(mod, overload_stubs + stubs)
174 
175 
C:ProgramDataAnaconda3libsite-packagestorchjit_recursive.py in copy_to_script_module(original, stubs)
93             setattr(script_module, name, item)
94 
---> 95     torch.jit._create_methods_from_stubs(script_module, stubs)
96 
97     # Now that methods have been compiled, take methods that have been compiled
C:ProgramDataAnaconda3libsite-packagestorchjit__init__.py in _create_methods_from_stubs(self, stubs)
1421     rcbs = [m.resolution_callback for m in stubs]
1422     defaults = [get_default_args(m.original_method) for m in stubs]
-> 1423     self._c._create_methods(self, defs, rcbs, defaults)
1424 
1425 # For each user-defined class that subclasses ScriptModule this meta-class,
RuntimeError: Unable to cast Python instance of type <class 'int'> to C++ type 'unsigned __int64'

我感谢任何帮助或评论。

我在pytorch论坛上得到了答案: https://discuss.pytorch.org/t/use-numpy-in-script-class-torch-jit-script/62351

最新更新