如何提高此函数的速度?
def foo(mri_data, radius):
mask = mri_data.copy()
ny = len(mri_data[0,:])
nx = len(mri_data[:])
for y in xrange(0, ny):
for x in xrange(0, nx):
if (mri_data[x-radius:x+radius,y-radius:y+radius] != 1.0).all():
mask[x,y] = 0.0
return mask.copy()
它以numpy数组的形式接收图像切片。遍历每个像素并测试该像素周围的边界框。如果框中没有值等于1,则我们通过将其设置为0来丢弃该像素。
有人告诉我可以使用numpy.convolve
,但我不确定这是怎么回事。
编辑:图像值在二进制范围内,因此最小值为0.0,最大值为1.0。数值介于ex之间:0.767。
可以滥用卷积的情况之一。我不会用它,但边界在其他方面很乏味。。。
from scipy.ndimage import convolve
not_one = (mri_data != 1.0) # are you sure you want to compare with float like that?!
conv = convolve(not_one, np.ones((2*radius, 2*radius)))
all_not_one = (conv == (2*radius)**2)
mask[all_not_one] = 0
真的应该做同样的事情。。。
您所做的被称为binary_dilation
,但您的代码中有一个小错误。具体来说,当x,y小于半径时,得到的是负指数。这些负数是使用numpy索引规则来解释的,这不是你想要的。更多关于索引的信息,会在图像的两个边缘上给出错误的结果。
这里有一些代码使用二进制膨胀来完成同样的事情,并修复了上面提到的错误。
import numpy as np
from scipy.ndimage import binary_dilation
def foo(mri_data, radius):
structure = np.ones((2*radius, 2*radius))
# I set the origin here to match your code
mask = binary_dilation(mri_data == 1, structure, origin=-1)
return np.where(mask, mri_data, 0)