numba未将Numpy数组形状识别为int



我想基于一些数组生成一个numpy矩阵,并使用jit或理想情况下的njit加速生成。如果nopython=False(启用了nopython,它失败了(,它会不断发送以下两个警告,我无法理解:

:14:Numba警告:编译返回到对象模式并启用环举,因为函数"process_ stuffs";类型推理失败,原因是:没有从转换"inp"的数组(int32,2d,C(到数组(int64,2d,A(,在None 中定义

文件"&";,第23行:defprocess_stuffs(输出,inp,route1,route2,zoneidx(:

input_pallets, _ = inp.shape
^

期间:在(23(键入参数

文件"&";,第23行:defprocess_stuffs(输出,inp,route1,route2,zoneidx(:

input_pallets, _ = inp.shape
^

@jit(nopython=False,:14:Numba警告:编译正在回退到对象模式WITH由于功能";process_ stuffs";失败类型推断原因:无法确定<班'numba.core.dispatcher.LiftedLoop'>

文件"&";,第25行:defprocess_stuffs(输出,inp,route1,route2,zoneidx(:

for minute in range(input_pallets):
^

@jit(nopython=False,C: \Anaconda3\envs\dev38\lib\site packages\numba\core\object_mode_passes.py:151:Numba警告:函数";process_ stuffs";是在对象模式下编译的without forceobj=True,但已提升循环。

虽然函数确实使用复杂类型,但在确定inp数组的长度时,它在非常的开头失败,然后它不想生成循环,尽管我已经看到了很多例子。

我试图通过使用locals指定类型来纠正错误,但正如你所看到的,这没有帮助。

这是一个最小的工作代码:

zoneidx=Dict.empty(key_type=types.unicode_type,value_type=types.int8)
zoneidx["A"]=np.int8(0)
zoneidx["B"]=np.int8(1)
zoneidx["C"]=np.int8(2)
zoneidx["D"]=np.int8(3)
zoneidx["E"]=np.int8(4)

output = np.zeros(shape=(110,5),dtype=np.int64)
inp = np.random.randint(0,2,size=(100,2))
route1 = np.random.choice(list('ABCDE'),size=10)
route2 = np.random.choice(list('ABCDE'),size=10)
@jit(nopython=False,
locals={'input_pallets':numba.int64,
'step':numba.int64,
'inp':numba.types.int64[:,:],
'route1':numba.types.unicode_type[:],
'route2':numba.types.unicode_type[:],
'output':numba.types.int64[:,:]})
def process_stuffs(output,inp,route1, route2, zoneidx):
input_pallets, _ = inp.shape
for minute in range(input_pallets):
prod1, prod2 = inp[minute]
if prod1+prod2 <1:
continue
if prod1:
routing = route1
number_of_pallets = prod1
number_of_steps = route1.shape[0]
else:
routing = route2
number_of_pallets = prod2
number_of_steps = route2.shape[0]
for step in range(number_of_steps):
zone = routing[step]
output[minute+step,zoneidx[zone]]+=number_of_pallets
return output


numba.__version__ == 0.53.1
numpy.__version__ == 1.19.2

我的代码出了什么问题?

NB:我对我的代码输出的正确性不感兴趣,我知道"路由2";将被忽略;路由1";由";inp";。我只是想把它编译一下。

警告消息具有误导性。事实上,输入的类型确实没有正确给定,这与.shape方法无关。

我的解决方案是使用numba.typeof函数来告诉它需要什么类型。例如,预期的是int32,而不是"0"的64;inp";。并且";unichr";应为,而不是unicode。

以下是我的最小示例的工作版本:

zoneidx=Dict.empty(key_type=numba.typeof(route1).dtype,value_type=types.int8)
zoneidx["A"]=np.int8(0)
zoneidx["B"]=np.int8(1)
zoneidx["C"]=np.int8(2)
zoneidx["D"]=np.int8(3)
zoneidx["E"]=np.int8(4)

output = np.zeros(shape=(110,5),dtype=np.int64)
inp = np.random.randint(0,2,size=(100,2))
route1 = np.random.choice(list('ABCDE'),size=10)
route2 = np.random.choice(list('ABCDE'),size=10)
@jit(nopython=False,
locals={'input_pallets':numba.int64,
'step':numba.int64,
'inp':numba.types.int32[:,:],
'route1':numba.typeof(route1),
'route2':numba.typeof(route1),
'output':numba.types.int64[:,:]})
def process_stuffs(output,inp,route1, route2, zoneidx):
input_pallets, _ = inp.shape
for minute in range(input_pallets):
prod1, prod2 = inp[minute]
if prod1+prod2 <1:
continue
if prod1:
routing = route1
number_of_pallets = prod1
number_of_steps = route1.shape[0]
else:
routing = route2
number_of_pallets = prod2
number_of_steps = route2.shape[0]
for step in range(number_of_steps):
zone = routing[step]
output[minute+step,zoneidx[zone]]+=number_of_pallets
return output

最新更新