Python 快速实现 3D 数组的卷积/互相关



我正在计算3D图像的卷积(互相关(。由于问题的性质,不需要基于FFT的卷积近似(例如scipy fftconvolve(,并且"直接和"是要走的路。图像的大小为 ~(150, 150, 150(,最大的内核大小为 ~(40, 40, 40(。图像是周期性的(具有周期性边界条件,或者需要用同一图像填充(,因为一次分析必须进行~100次这样的卷积,因此卷积函数的速度至关重要。

我已经实现并测试了几个函数,包括使用"method = direct"进行卷积的scipy实现,结果如下所示。我使用了一个(100,100,100(图像和一个(7,7,7(内核来对这里的方法进行基准测试:

import numpy as np
import time
from scipy import signal
image = np.random.rand(Nx,Ny,Nz)
kernel = np.random.rand(3,5,7)
signal.convolve(image,kernel, mode='same',method = "direct")

花费:8.198秒

然后我基于数组加法编写了自己的函数

def shift_array(array, a,b,c):
A = np.roll(array,a,axis = 0)
B = np.roll(A,b,axis = 1)
C = np.roll(B,c,axis = 2)
return C
def matrix_convolve2(image,kernel, mode = "periodic"):
if mode not in ["periodic"]:
raise NotImplemented
if mode is "periodic":
Nx, Ny, Nz = image.shape
nx, ny, nz = kernel.shape
rx = nx//2
ry = ny//2
rz = nz//2
result = np.zeros((Nx, Ny, Nz))
for i in range(nx):
for j in range(ny):
for k in range(nz):
result += kernel[i,j,k] * shift_array(image, rx-i, ry-j, rz-k) 
return result

matrix_convolve2(image,kernel)

花费:6.324秒

在这种情况下,这里的限制因素似乎是周期性边界条件的 np.roll 函数,所以我试图通过耕种输入图像来规避这一点

def matrix_convolve_center(image,kernel):
# Only get convolve result for the "central" block
nx, ny, nz = kernel.shape
rx = nx//2
ry = ny//2
rz = nz//2
result = np.zeros((Nx, Ny, Nz))
for i in range(nx):
for j in range(ny):
for k in range(nz):
result += kernel[i,j,k] * image[Nx+i-rx:2*Nx+i-rx,Ny+j-ry:2*Ny+j-ry,Nz+k-rz:2*Nz+k-rz]
return result
def matrix_convolve3(image,kernel):
Nx, Ny, Nz = image.shape
nx, ny, nz = kernel.shape
extended_image = np.tile(image,(3,3,3))
result = matrix_convolve_center(extended_image,kernel,Nx, Ny, Nz)
return result
matrix_convolve3(image,kernel)

采取:2.639秒

这种方法提供了迄今为止最好的性能,但对于实际应用来说仍然太慢了。

我做了一些研究,似乎使用"Numba"可以显着提高性能,或者也许以并行方式编写相同的函数也会有所帮助,但我对 Numba 不熟悉,也不是 python 并行化(我对multiprocess库有一些不好的体验......它似乎跳过迭代或有时突然停止(

你们能在这里帮我吗?任何改进将不胜感激。多谢!

这远非结论性,但对于我检查的示例fft确实比朴素(顺序(求和更准确。因此,除非您有充分的理由相信您的数据有所不同,否则我的建议是:省去麻烦并使用fft

更新:添加了我自己的直接方法,注意确保它使用成对求和。这设法比 ff 更准确一些,但仍然非常慢。

测试脚本:

import numpy as np
from scipy import stats, signal, fftpack
def matrix_convolve_center(image,kernel,Nx,Ny,Nz):
# Only get convolve result for the "central" block
nx, ny, nz = kernel.shape
rx = nx//2
ry = ny//2
rz = nz//2
result = np.zeros((Nx, Ny, Nz))
for i in range(nx):
for j in range(ny):
for k in range(nz):
result += kernel[i,j,k] * image[Nx+i-rx:2*Nx+i-rx,Ny+j-ry:2*Ny+j-ry,Nz+k-rz:2*Nz+k-rz]
return result
def matrix_convolve3(image,kernel):
Nx, Ny, Nz = image.shape
nx, ny, nz = kernel.shape
extended_image = np.tile(image,(3,3,3))
result = matrix_convolve_center(extended_image,kernel,Nx, Ny, Nz)
return result
P=0   # parity
CH=10 # chunk size
# make integer example, so exact soln is readily available
image = np.random.randint(0,100,(8*CH+P,8*CH+P,8*CH+P))
kernel = np.random.randint(0,100,(2*CH+P,2*CH+P,2*CH+P))
kerpad = np.zeros_like(image)
kerpad[3*CH:-3*CH,3*CH:-3*CH,3*CH:-3*CH]=kernel[::-1,::-1,::-1]
cexa = np.round(fftpack.fftshift(fftpack.ifftn(fftpack.fftn(fftpack.ifftshift(image))*fftpack.fftn(fftpack.ifftshift(kerpad)))).real).astype(int)
# sanity check
assert cexa.sum() == kernel.sum() * image.sum()
# normalize to preclude integer arithmetic during the actual test
image = image / image.sum()
kernel = kernel / kernel.sum()
cexa = cexa / cexa.sum()
# fft method
kerpad = np.zeros_like(image)
kerpad[3*CH:-3*CH,3*CH:-3*CH,3*CH:-3*CH]=kernel[::-1,::-1,::-1]
cfft = fftpack.fftshift(fftpack.ifftn(fftpack.fftn(fftpack.ifftshift(image))*fftpack.fftn(fftpack.ifftshift(kerpad))))
def direct_pp(image,kernel):
nx,ny,nz = image.shape
kx,ky,kz = kernel.shape
out = np.zeros_like(image)
image = np.concatenate([image[...,-kz//2+1:],image,image[...,:kz//2+P]],axis=2)
image = np.concatenate([image[:,-ky//2+1:],image,image[:,:ky//2+P]],axis=1)
image = np.concatenate([image[-kx//2+1:],image,image[:kx//2+P]],axis=0)
mx,my,mz = image.shape
ox,oy,oz = 2*mx-nx,2*my-ny,2*mz-nz
aux = np.empty((ox,oy,kx,ky),image.dtype)
s0,s1,s2,s3 = aux.strides
aux2 = np.lib.stride_tricks.as_strided(aux[kx-1:,ky-1:],(mx,my,kx,ky),(s0,s1,s2-s0,s3-s1))
for z in range(nz):
aux2[...] = np.einsum('ijm,klm',image[...,z:z+kz],kernel)
out[...,z] = aux[kx-1:kx-1+nx,ky-1:ky-1+ny].sum((2,3))
return out
# direct methods
print("How about a coffee? (This may take some time...)")
from time import perf_counter as pc
T = []
T.append(pc())
cdirpp = direct_pp(image,kernel)
T.append(pc())
cdir = np.roll(matrix_convolve3(image,kernel),P-1,(0,1,2))
T.append(pc())
# compare squared error
nrm = (cexa**2).sum()
print('accuracy')
print('fft   ',((cexa-cfft)*(cexa-cfft.conj())).real.sum()/nrm)
print('direct',((cexa-cdir)**2).sum()/nrm)
print('dir pp',((cexa-cdirpp)**2).sum()/nrm)
print('duration direct methods')
print('pp {} OP {}'.format(*np.diff(T)))

示例运行:

How about a coffee? (This may take some time...)
accuracy
fft    5.690597572945596e-32
direct 8.518853759493871e-30
dir pp 1.3317651721034386e-32
duration direct methods
pp 5.817311848048121 OP 20.05021938495338

最新更新