双精度和单精度扩散核的Numba性能



我在使numba代码达到高性能方面遇到了困难。在我的机器上,C++或Julia中的等效代码在大约80毫秒内执行双精度,在大约40毫秒内执行单精度,但numba代码提供:

双倍精度:

In [2]: %timeit diff(at, a, float_type(0.1), float_type(0.1), float_type(0.1), float_type(0.1), itot, jtot, ktot) 291 ms ± 5.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

单精度:

In [2]: %timeit diff(at, a, float_type(0.1), float_type(0.1), float_type(0.1), float_type(0.1), itot, jtot, ktot) 330 ms ± 3.02 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

Numba没有正确捕获类型,即使每个输入变量都有一个明确的类型。为什么这个代码比C++慢3倍多,为什么单精度工作不好?

import numpy as np
from numba import jit, prange
@jit(nopython=True, nogil=True)
def diff(at, a, visc, dxidxi, dyidyi, dzidzi, itot, jtot, ktot):
for k in range(1, ktot-1):
for j in range(1, jtot-1):
for i in range(1, itot-1):
at[k, j, i] += visc * ( 
+ ( (a[k+1, j  , i  ] - a[k  , j  , i  ])  
- (a[k  , j  , i  ] - a[k-1, j  , i  ]) ) * dxidxi
+ ( (a[k  , j+1, i  ] - a[k  , j  , i  ])  
- (a[k  , j  , i  ] - a[k  , j-1, i  ]) ) * dyidyi
+ ( (a[k  , j  , i+1] - a[k  , j  , i  ])  
- (a[k  , j  , i  ] - a[k  , j  , i-1]) ) * dzidzi )
float_type = np.float32
# float_type = np.float64
itot = 384;
jtot = 384;
ktot = 384;
ncells = itot*jtot*ktot;
at = np.zeros((ktot, jtot, itot), dtype=float_type)
a = np.random.rand(ktot, jtot, itot)
a = a.astype(float_type)
diff(at, a, float_type(0.1), float_type(0.1), float_type(0.1), float_type(0.1), itot, jtot, ktot)

为了进行比较,这是相应的Julia代码:

## Packages
using BenchmarkTools
using LoopVectorization
## Diffusion kernel
function diff!(
at, a,
visc, dxidxi, dyidyi, dzidzi,
itot, jtot, ktot)
@tturbo unroll=8 for k in 2:ktot-1
for j in 2:jtot-1
for i in 2:itot-1
at[i, j, k] += visc * (
( (a[i+1, j  , k  ] - a[i, j, k]) - (a[i, j, k] - a[i-1, j  , k  ]) ) * dxidxi +
( (a[i  , j+1, k  ] - a[i, j, k]) - (a[i, j, k] - a[i  , j-1, k  ]) ) * dyidyi +
( (a[i  , j  , k+1] - a[i, j, k]) - (a[i, j, k] - a[i  , j  , k-1]) ) * dzidzi )
end
end
end
end
## Set the grid size.
itot = 384
jtot = 384
ktot = 384
## Solve the problem in double precision.
visc = 0.1
dxidxi = 0.1
dyidyi = 0.1
dzidzi = 0.1
a = rand(Float64, (itot, jtot, ktot))
at = zeros(Float64, (itot, jtot, ktot))
@btime diff!(
$at, $a,
$visc, $dxidxi, $dyidyi, $dzidzi,
$itot, $jtot, $ktot)
## Solve the problem in single precision.
visc_f = Float32(visc)
dxidxi_f = Float32(dxidxi)
dyidyi_f = Float32(dyidyi)
dzidzi_f = Float32(dzidzi)
a_f = rand(Float32, (itot, jtot, ktot))
at_f = zeros(Float32, (itot, jtot, ktot))
@btime diff!(
$at_f, $a_f,
$visc_f, $dxidxi_f, $dyidyi_f, $dzidzi_f,
$itot, $jtot, $ktot)

避免环绕检查

Numba试图表现得和Python一样。其中一些功能可以很容易地关闭,例如使用error_model="numpy"进行的除以零检查,但我不知道如何关闭环绕检查。

本例中的另一个问题是float32/64行为。您可以使用func.nopython_signatures检查检测到的签名。或者,您可以使用显式签名。

解决方案1

一种可能性是让Numba明白,负指数并没有发生。

@nb.njit(["(float32[:,:,::1])(float32[:,:,::1], float32[:,:,::1], float32, float32, float32, float32)",
"(float64[:,:,::1])(float64[:,:,::1], float64[:,:,::1], float64, float64, float64, float64)"])
def diff_2(at, a, visc, dxidxi, dyidyi, dzidzi):
ktot, jtot, itot=at.shape
for k in range(ktot-2):
for j in range(jtot-2):
for i in range(itot-2):
at[k+1, j+1, i+1] += visc * ( 
+ ( (a[k+2, j+1, i+1] - a[k+1, j+1 , i+1])  
- (a[k+1, j+1, i+1] - a[k  , j+1 , i+1]) ) * dxidxi
+ ( (a[k+1, j+2, i+1] - a[k+1, j+1 , i+1])  
- (a[k+1, j+1, i+1] - a[k+1, j   , i+1]) ) * dyidyi
+ ( (a[k+1, j+1, i+2] - a[k+1, j+1 , i+1])  
- (a[k+1, j+1, i+1] - a[k+1, j+1 , i  ]) ) * dzidzi )
return at
#float32
#%timeit diff_2(at, a, 0.1, 0.1, 0.1, 0.1)
#31 ms ± 708 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
#float64
#%timeit diff_2(at, a, 0.1, 0.1, 0.1, 0.1)
#65.4 ms ± 805 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

解决方案2

@nb.njit(["(float32[:,:,::1])(float32[:,:,::1], float32[:,:,::1], float32, float32, float32, float32)",
"(float64[:,:,::1])(float64[:,:,::1], float64[:,:,::1], float64, float64, float64, float64)"],)
def diff_2(at, a, visc, dxidxi, dyidyi, dzidzi):
ktot, jtot, itot=at.shape
for kk in range(ktot-2):
for jj in range(jtot-2):
for ii in range(itot-2):
i=ii+1
j=jj+1
k=kk+1
at[k, j, i] += visc * ( 
+ ( (a[k+1, j  , i  ] - a[k  , j  , i  ])  
- (a[k  , j  , i  ] - a[k-1, j  , i  ]) ) * dxidxi
+ ( (a[k  , j+1, i  ] - a[k  , j  , i  ])  
- (a[k  , j  , i  ] - a[k  , j-1, i  ]) ) * dyidyi
+ ( (a[k  , j  , i+1] - a[k  , j  , i  ])  
- (a[k  , j  , i  ] - a[k  , j  , i-1]) ) * dzidzi )
return at

最新更新