如何在Numba中使用np.dot()使用连续数组设置批处理矩阵乘法



我试图用numba加速批处理矩阵乘法问题,但它一直告诉我它使用连续代码更快。

注意:我使用numba版本0.55.1,和numpy版本1.21.5

问题来了:

import numpy as np
import numba as nb
def numbaFastMatMult(mat,vec):
result = np.zeros_like(vec)
for n in nb.prange(vec.shape[0]):
result[n,:] = np.dot(vec[n,:], mat[n,:,:])
return result
D,N = 10,1000
mat = np.random.normal(0,1,(N,D,D))
vec = np.random.normal(0,1,(N,D))
result = numbaFastMatMult(mat,vec)
print(mat.data.contiguous)
print(vec.data.contiguous)
print(mat[n,:,:].data.contiguous)
print(vec[n,:].data.contiguous)

显然所有相关数据都是连续的(运行上面的代码片段并查看print()的结果…

但是,当我运行此代码时,我得到以下警告:

NumbaPerformanceWarning: np.dot() is faster on contiguous arrays, called on (array(float64, 1d, C), array(float64, 2d, A))
result[n,:] = np.dot(vec[n,:], mat[n,:,:])

2额外注释:

  1. 这只是一个复制的玩具问题。我实际上使用的是有更多数据点的东西,所以希望这能加快速度。
  2. 我认为"权利"解决这个问题的方法是用np。tensordot。但是,我想了解发生了什么,以便将来参考。例如,本文讨论了一个类似的问题,但据我所知,并没有解决为什么会直接出现警告。

我试着添加一个装饰符:

nb.float64[:,::1](nb.float64[:,:,::1],nb.float64[:,::1]),

我已经尝试重新排序数组,所以批索引是第一个(n在上面的代码)我试着打印是否"匹配"。变量从函数

开始连续

我就不写了,不过我想出来了:

在numba函数外:

mat[n,:,:].data.contiguous==True

但在numba内部,mat[n,:,:]不再连续

将上面的代码更改为np.dot(vec[n], mat[n])删除了警告。

我把这个改成"正确的"因为它解决了我的问题。然而,根据max9111的响应,这种行为可能是一个bug!