我试图用numba加速批处理矩阵乘法问题,但它一直告诉我它使用连续代码更快。
注意:我使用numba版本0.55.1,和numpy版本1.21.5
问题来了:
import numpy as np
import numba as nb
def numbaFastMatMult(mat,vec):
result = np.zeros_like(vec)
for n in nb.prange(vec.shape[0]):
result[n,:] = np.dot(vec[n,:], mat[n,:,:])
return result
D,N = 10,1000
mat = np.random.normal(0,1,(N,D,D))
vec = np.random.normal(0,1,(N,D))
result = numbaFastMatMult(mat,vec)
print(mat.data.contiguous)
print(vec.data.contiguous)
print(mat[n,:,:].data.contiguous)
print(vec[n,:].data.contiguous)
显然所有相关数据都是连续的(运行上面的代码片段并查看print()的结果…
但是,当我运行此代码时,我得到以下警告:
NumbaPerformanceWarning: np.dot() is faster on contiguous arrays, called on (array(float64, 1d, C), array(float64, 2d, A))
result[n,:] = np.dot(vec[n,:], mat[n,:,:])
2额外注释:
- 这只是一个复制的玩具问题。我实际上使用的是有更多数据点的东西,所以希望这能加快速度。
- 我认为"权利"解决这个问题的方法是用np。tensordot。但是,我想了解发生了什么,以便将来参考。例如,本文讨论了一个类似的问题,但据我所知,并没有解决为什么会直接出现警告。
我试着添加一个装饰符:
nb.float64[:,::1](nb.float64[:,:,::1],nb.float64[:,::1]),
我已经尝试重新排序数组,所以批索引是第一个(n在上面的代码)我试着打印是否"匹配"。变量从函数
开始连续我就不写了,不过我想出来了:
在numba函数外:
mat[n,:,:].data.contiguous==True
但在numba内部,mat[n,:,:]
不再连续
将上面的代码更改为np.dot(vec[n], mat[n])
删除了警告。
我把这个改成"正确的"因为它解决了我的问题。然而,根据max9111的响应,这种行为可能是一个bug!