Starmap combined with tqdm?



我正在做一些并行处理,如下:

with mp.Pool(8) as tmpPool:
        results = tmpPool.starmap(my_function, inputs)

输入看起来像: [(1,0.2312(,(5,0.52(...]即,int和float的元组。

代码运行良好,但我似乎无法将其包裹在加载栏(TQDM(上,例如可以使用IMAP方法来完成的,如下所示:

tqdm.tqdm(mp.imap(some_function,some_inputs))

这也可以为starmap做吗?

谢谢!

starmap()不可能,但是加上添加Pool.istarmap()的补丁是可能的。它基于imap()的代码。您要做的就是创建istarmap.py -FILE并导入模块以在制作常规多处理 - IMPORTS之前应用补丁。

Python< 3.8

# istarmap.py for Python <3.8
import multiprocessing.pool as mpp

def istarmap(self, func, iterable, chunksize=1):
    """starmap-version of imap
    """
    if self._state != mpp.RUN:
        raise ValueError("Pool not running")
    if chunksize < 1:
        raise ValueError(
            "Chunksize must be 1+, not {0:n}".format(
                chunksize))
    task_batches = mpp.Pool._get_tasks(func, iterable, chunksize)
    result = mpp.IMapIterator(self._cache)
    self._taskqueue.put(
        (
            self._guarded_task_generation(result._job,
                                          mpp.starmapstar,
                                          task_batches),
            result._set_length
        ))
    return (item for chunk in result for item in chunk)

mpp.Pool.istarmap = istarmap

Python 3.8

# istarmap.py for Python 3.8+
import multiprocessing.pool as mpp

def istarmap(self, func, iterable, chunksize=1):
    """starmap-version of imap
    """
    self._check_running()
    if chunksize < 1:
        raise ValueError(
            "Chunksize must be 1+, not {0:n}".format(
                chunksize))
    task_batches = mpp.Pool._get_tasks(func, iterable, chunksize)
    result = mpp.IMapIterator(self)
    self._taskqueue.put(
        (
            self._guarded_task_generation(result._job,
                                          mpp.starmapstar,
                                          task_batches),
            result._set_length
        ))
    return (item for chunk in result for item in chunk)

mpp.Pool.istarmap = istarmap

然后在您的脚本中:

import istarmap  # import to apply patch
from multiprocessing import Pool
import tqdm    

def foo(a, b):
    for _ in range(int(50e6)):
        pass
    return a, b    

if __name__ == '__main__':
    with Pool(4) as pool:
        iterable = [(i, 'x') for i in range(10)]
        for _ in tqdm.tqdm(pool.istarmap(foo, iterable),
                           total=len(iterable)):
            pass

最简单的方法可能是在输入周围应用tqdm((,而不是映射函数。例如:

inputs = zip(param1, param2, param3)
with mp.Pool(8) as pool:
    results = pool.starmap(my_function, tqdm.tqdm(inputs, total=len(param1)))

请注意,在调用my_function时,而不是返回时,将更新栏。如果这种区别很重要,则可以考虑像其他答案所暗示的那样重写个性疾病。否则,这是一个简单有效的替代方案。

如Darkonaut所述,在写作时,没有istarmap本地可用。如果要避免修补,则可以作为解决方法添加简单的 * _star功能。(此解决方案灵感来自本教程。(

import tqdm
import multiprocessing
def my_function(arg1, arg2, arg3):
  return arg1 + arg2 + arg3
def my_function_star(args):
    return my_function(*args)
jobs = 4
with multiprocessing.Pool(jobs) as pool:
    args = [(i, i, i) for i in range(10000)]
    results = list(tqdm.tqdm(pool.imap(my_function_star, args), total=len(args))

一些笔记:

我也非常喜欢科里的答案。它更干净,尽管进度栏似乎没有像我的答案那样顺利更新。请注意,Corey的答案是我在上面发布的chunksize=1(默认值(上发布的代码加快了几个数量级。我猜这是由于多处理序列化造成的,因为增加chunksize(或更昂贵的my_function(使他们的运行时间可比。

我的申请答案是因为我的序列化/功能成本比非常低。

临时解决方案:用imap重写要合行的方法。

相关内容

  • 没有找到相关文章

最新更新