我正在学习如何使用Numba(虽然我已经相当熟悉Cython)。我应该如何加速这段代码?请注意,该函数返回一个由两个整数元组组成的字典。我正在使用IPython笔记本。我更喜欢Numba而不是Cython。
@autojit
def generateadj(width,height):
adj = {}
for y in range(height):
for x in range(width):
s = set()
if x>0:
s.add((x-1,y))
if x<width-1:
s.add((x+1,y))
if y>0:
s.add((x,y-1))
if y<height-1:
s.add((x,y+1))
adj[x,y] = s
return adj
我设法用Cython写了这个,但我不得不放弃数据的结构方式。我不喜欢这样。我在 Numba 文档中的某处读到它可以处理列表、元组等基本内容。
%%cython
import numpy as np
def generateadj(int width, int height):
cdef int[:,:,:,:] adj = np.zeros((width,height,4,2), np.int32)
cdef int count
for y in range(height):
for x in range(width):
count = 0
if x>0:
adj[x,y,count,0] = x-1
adj[x,y,count,1] = y
count += 1
if x<width-1:
adj[x,y,count,0] = x+1
adj[x,y,count,1] = y
count += 1
if y>0:
adj[x,y,count,0] = x
adj[x,y,count,1] = y-1
count += 1
if y<height-1:
adj[x,y,count,0] = x
adj[x,y,count,1] = y+1
count += 1
for i in range(count,4):
adj[x,y,i] = adj[x,y,0]
return adj
虽然numba
支持dict
s和set
s等Python数据结构,但它在对象模式下这样做。根据numba
术语表,对象模式定义为:
生成处理所有值的代码的 Numba 编译模式 作为 Python 对象,并使用 Python C API 执行所有操作 在这些对象上。在对象模式下编译的代码通常不会运行 比 Python 解释代码更快,除非 Numba 编译器可以 利用循环抖动。
因此,在编写numba
代码时,您需要坚持使用数组等内置数据类型。下面是一些可以做到这一点的代码:
@jit
def gen_adj_loop(width, height, adj):
i = 0
for x in range(width):
for y in range(height):
if x > 0:
adj[i,0] = x
adj[i,1] = y
adj[i,2] = x - 1
adj[i,3] = y
i += 1
if x < width - 1:
adj[i,0] = x
adj[i,1] = y
adj[i,2] = x + 1
adj[i,3] = y
i += 1
if y > 0:
adj[i,0] = x
adj[i,1] = y
adj[i,2] = x
adj[i,3] = y - 1
i += 1
if y < height - 1:
adj[i,0] = x
adj[i,1] = y
adj[i,2] = x
adj[i,3] = y + 1
i += 1
return
这需要一个数组adj
。每行的格式为 x y adj_x adj_y
。所以对于 (3,4)
处的像素,我们有四行:
3 4 2 4
3 4 4 4
3 4 3 3
3 4 3 5
我们可以将上述函数包装在另一个函数中:
@jit
def gen_adj(width, height):
# each pixel has four neighbors, but some of these neighbors are
# off the grid -- 2*width + 2*height of them to be exact
n_entries = width*height*4 - 2*width - 2*height
adj = np.zeros((n_entries, 4), dtype=int)
gen_adj_loop(width, height, adj)
此功能非常快,但不完整。我们必须将adj
转换为您问题中形式的字典。问题是这是一个非常缓慢的过程。我们必须遍历adj
数组并将每个条目添加到 Python 字典中。这不能被numba
抖动。
所以底线是这样的:结果是元组字典的要求实际上限制了你可以优化这段代码的程度。