我正在研究高光谱图像。为了减少图像中的噪声,我使用pywt包进行小波变换。当我正常地做这件事(串行处理(时,它工作得很顺利。但是,当我试图使用多个核对图像进行小波变换来实现并行处理时,我必须通过某些参数,如
- 小波族
- 阈值
- 阈值技术(硬/软(
但我无法使用pool对象传递这些参数,当我使用pool.imap((时,我只能将数据作为参数传递。但当我使用pool.apply_async((时会花费更多的时间,而且输出的顺序也不一样。我在这里添加代码以供参考:
import matplotlib.pyplot as plt
import numpy as np
import multiprocessing as mp
import os
import time
from math import log10, sqrt
import pywt
import tifffile
def spec_trans(d,wav_fam,threshold_val,thresh_type):
data=np.array(d,dtype=np.float64)
data_dec=decomposition(data,wav_fam)
data_t=thresholding(data_dec,threshold_val,thresh_type)
data_rec=reconstruction(data_t,wav_fam)
return data_rec
if __name__ == '__main__':
#input
X=tifffile.imread('data/Classification/university.tif')
#take paramaters
threshold_val=float(input("Enter the value for image thresholding: "))
print("The available wavelet functions:",pywt.wavelist())
wav_fam=input("Choose a wavelet function for transformation: ")
threshold_type=['hard','soft']
print("The available wavelet functions:",threshold_type)
thresh_type=input("Choose a type for threshholding technique: ")
start=time.time()
p = mp.Pool(4)
jobs=[]
for dataBand in xmp:
jobs.append(p.apply_async(spec_trans,args=(dataBand,wav_fam,threshold_val,thresh_type)))
transformedX=[]
for jobBit in jobs:
transformedX.append(jobBit.get())
end=time.time()
p.close()
此外,当我使用"软"技术进行阈值处理时,我会面临以下错误:
C:UsersSawonanaconda3libsite-packagespywt_thresholding.py:25: RuntimeWarning: invalid value encountered in multiply
thresholded = data * thresholded
串行执行和并行执行的结果大致相同。但我得到的结果略有不同。任何修改代码的建议都会有所帮助感谢
[这不是问题的直接答案,但与通过小评论框尝试以下内容相比,这是一个更清晰的后续查询]
作为一种快速检查,将迭代器计数器传递给spec_trans,并将其返回(以及您的结果(,然后将其推送到一个单独的列表中,transformedXseq或其他什么列表中,然后与您的输入序列进行比较。即
def spec_trans(d,wav_fam,threshold_val,thresh_type, iCount):
data=np.array(d,dtype=np.float64)
data_dec=decomposition(data,wav_fam)
data_t=thresholding(data_dec,threshold_val,thresh_type)
data_rec=reconstruction(data_t,wav_fam)
return data_rec, iCount
然后在主内
jobs=[]
iJobs = 0
for dataBand in xmp:
jobs.append(p.apply_async(spec_trans,args=(dataBand,wav_fam,threshold_val,thresh_type, iJobs)))
iJobs = iJobs + 1
transformedX=[]
transformedXseq=[]
for jobBit in jobs:
res = jobBit.get()
transformedX.append(res[0])
transformedXseq.append(res[1])
并检查transformedXseq列表,看看您是否已经按照提交作业的顺序收集了作业。它应该匹配!
假设wav_fam
、threshold_val
和thresh_type
在不同的调用中没有变化,首先将这些参数设置为工作函数spec_trans
:的第一个参数
def spec_trans(wav_fam, threshold_val, thresh_type, d):
现在我看不出在您的池创建块中在哪里定义了xmp
,但可能这是一个可迭代的。您需要按如下方式修改此代码:
from functools import partial
def compute_chunksize(pool_size, iterable_size):
chunksize, remainder = divmod(iterable_size, 4 * pool_size)
if remainder:
chunksize += 1
return chunksize
if __name__ == '__main__':
X=tifffile.imread('data/Classification/university.tif')
#take paramaters
threshold_val=float(input("Enter the value for image thresholding: "))
print("The available wavelet functions:",pywt.wavelist())
wav_fam=input("Choose a wavelet function for transformation: ")
threshold_type=['hard','soft']
print("The available wavelet functions:",threshold_type)
thresh_type=input("Choose a type for threshholding technique: ")
start=time.time()
p = mp.Pool(4)
# first 3 arguments to spec_trans will be wav_fam, threshold_val and thresh_type
worker = partial(spec_trans, wav_fam, threshold_val, thresh_type)
suitable_chunksize = compute_chunksize(4, len(xmp))
transformedX = list(p.imap(worker, xmp, chunksize=suitable_chunksize))
end=time.time()
为了获得比使用apply_async
更高的性能,必须使用";合适的块大小";值与CCD_ 7。函数compute_chunksize
可用于根据池的大小(4(和传递给imap
的可迭代的大小(len(xmp)
(来计算这样的值。如果xmp
的大小足够小,使得计算出的分块大小值为1,那么我真的看不出imap
的性能会比apply_async
高得多。
当然,你也可以使用:
transformedX = p.map(worker, xmp)
并让池计算自己合适的块大小。当可迭代性非常大并且不是已经是列表时,imap
比map
具有优势。对于map
来说,要计算合适的分块大小,首先必须将可迭代项转换为列表才能获得其长度,这可能会降低内存效率。但是,如果您知道可迭代项的长度(或近似长度(,那么通过使用imap,您可以显式地设置chunksize,而不必将可迭代项转换为列表。与map
相比,imap_unordered
的另一个优点是,当单个任务可用时,您可以处理它们的结果,而使用map
,您只有在所有提交的任务都完成时才能获得结果。
更新
如果您想捕获提交给工作函数的单个任务可能引发的异常,那么请坚持使用imap
,并使用以下代码迭代imap
返回的结果:
#transformedX = list(p.imap(worker, xmp, chunksize=suitable_chunksize))
transformedX = []
results = p.imap(worker, xmp, chunksize=suitable_chunksize)
import traceback
while True:
try:
result = next(results)
except StopIteration: # no more results
break
except Exception as e:
print('Exception occurred:', e)
traceback.print_exc() # print stacktrace
else:
transformedX.append(result)