Python: execute a function when a child process exits



I have a memoizing wrapper around a function, with hit and miss counters. Because I can't access non-local variables from inside the function, I use a dictionary to count the hits and misses.

The function runs in ~1000 parallel processes on 48 cores, more than a million times per core, so I'm using a Manager.dict to manage the score.

Just keeping the score triples my execution time, so I'd like to do something smarter: keep a local counter, which is just a plain dict, and when the process exits, add that score to the common score dict managed by the manager.

Is there a way to execute a function when a child process exits? Something like atexit that works for spawned children.

Relevant code (note MAGICAL_AT_PROCESS_EXIT_CLASS, which is what I'm looking for):

from functools import wraps
from multiprocessing import Manager
import atexit
import pickle

manager = Manager()
global_score = manager.dict({
    "hits": 0,
    "misses": 0
})


def memoize(func):
    local_score = {
        "hits": 0,
        "misses": 0
    }
    cache = {}

    def process_exit_handler():
        global_score["hits"] += local_score["hits"]
        global_score["misses"] += local_score["misses"]

    MAGICAL_AT_PROCESS_EXIT_CLASS.register(process_exit_handler)  # <-- what I'm looking for

    @wraps(func)
    def wrap(*args):
        cache_key = pickle.dumps(args)
        if cache_key not in cache:
            local_score["misses"] += 1
            cache[cache_key] = func(*args)
        else:
            local_score["hits"] += 1
        return cache[cache_key]

    return wrap


def exit_handler():
    print("Cache", global_score)


atexit.register(exit_handler)

(Yes, I know it caches independently in each process. Yes, that is the desired behavior.)

Current solution: this is only relevant to my specific use case of the function. I run the function once per process, and each run spins itself about a million times. I changed the wrapper method in the following way:

# local_score now additionally starts with an "open" counter ("open": 0);
# "score" refers to the Manager-managed dict (global_score in the snippet above)
@wraps(func)
def wrap(*args):
    cache_key = pickle.dumps(args)
    if cache_key not in cache:
        local_score["misses"] += 1
        local_score["open"] += 1
        cache[cache_key] = func(*args)
        local_score["open"] -= 1
    else:
        local_score["hits"] += 1
    if local_score["open"] == 0:
        score["hits"] += local_score["hits"]
        score["misses"] += local_score["misses"]
        local_score["hits"] = 0
        local_score["misses"] = 0
    return cache[cache_key]

Instead of synchronizing writes a few hundred million times, it only needs to synchronize as many times as there are processes (1000).
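
A variant of the same idea, shown only as a sketch: assuming, as described above, that each process makes a single top-level call to the memoized function, the flush can live in a small driver function that the pool maps over instead of tracking an "open" counter. The names compute, global_score and local_score are assumptions standing in for the memoized function and the two dicts from the snippets above:

def run_and_flush(arg):
    """Driver mapped by the pool: one synchronized write per task
    instead of per call (sketch, not the asker's code)."""
    try:
        return compute(arg)  # memoized function that recurses internally
    finally:
        # flush the plain-dict counters into the Manager dict once
        global_score["hits"] += local_score["hits"]
        global_score["misses"] += local_score["misses"]
        local_score["hits"] = local_score["misses"] = 0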

It would be relatively easy to achieve this by subclassing Process, enhancing it with the memoization and building your own pool from it, but since you want to use multiprocessing.Pool it gets more complicated. Pool does not enable this by choice, and we have to reach into its guts to make it possible. Make sure no child processes are watching while you read on.


There are two problems to solve.

  1. Make the child processes call the exit handler when the process terminates.
  2. Prevent Pool from terminating the children before their exit handlers have finished.

To use fork as the start method for the child processes, I found it necessary to monkey-patch multiprocessing.pool.worker. We could use atexit with the start method "spawn" (the default on Windows), but that would spare us only little and deprive us of the benefits of forking, so the following code does not use atexit. The patch is a wrapper around worker that calls our custom at_exit function when the worker returns, which happens when the process is about to exit.

# at_exit_pool.py
import os
import threading
from functools import wraps
import multiprocessing.pool
from multiprocessing.pool import worker, TERMINATE, Pool
from multiprocessing import util, Barrier
from functools import partial


def finalized(worker):
    """Extend worker function with at_exit call."""
    @wraps(worker)
    def wrapper(*args, **kwargs):
        result = worker(*args, **kwargs)
        at_exit()  # <-- patch
        return result
    return wrapper


worker = finalized(worker)
multiprocessing.pool.worker = worker  # patch
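
For comparison, here is a minimal sketch of the atexit route mentioned above (an assumption, not part of the final solution): with the "spawn" start method a pool worker exits through normal interpreter shutdown, so a handler registered via atexit in the pool initializer gets a chance to run, provided the workers exit normally via close()/join() rather than being terminated:

# sketch_atexit_spawn.py -- assumes start method "spawn"
import atexit
import os
from multiprocessing import Pool, set_start_method


def flush_score():
    # placeholder: this is where local_score would be flushed
    print(os.getpid(), "worker exiting")


def init_worker():
    atexit.register(flush_score)  # runs on normal interpreter shutdown


if __name__ == '__main__':
    set_start_method('spawn')
    pool = Pool(processes=4, initializer=init_worker)
    pool.map(str, range(20))
    pool.close()   # let workers exit normally so atexit handlers can run
    pool.join()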

This solution also subclasses Pool to deal with both problems. PatientPool introduces two mandatory arguments, at_exit and at_exit_args. at_exit takes the exit handler, and PatientPool piggybacks on the standard Pool's initializer to register the exit handler in the child processes. Here are the functions dealing with registering the exit handler:

# at_exit_pool.py
def at_exit(func=None, barrier=None, *args):
    """Call at_exit function and wait on barrier."""
    func(*args)
    print(os.getpid(), 'barrier waiting')  # DEBUG
    barrier.wait()


def register_at_exit(func, barrier, *args):
    """Register at_exit function."""
    global at_exit
    at_exit = partial(at_exit, func, barrier, *args)


def combi_initializer(at_exit_args, initializer, initargs):
    """Piggyback initializer with register_at_exit."""
    if initializer:
        initializer(*initargs)
    register_at_exit(*at_exit_args)

As you can see in at_exit, we are going to use a multiprocessing.Barrier. Using this synchronization primitive is the solution to the second problem: preventing Pool from terminating the child processes before the exit handlers have finished their job.

The way a barrier works is that it blocks any process calling .wait() as long as fewer than its number of "parties" processes have called .wait().
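
A tiny standalone illustration of this behavior (a sketch, unrelated to the pool code): with parties set to 3, neither the two children nor the parent get past .wait() until all three have reached the barrier:

import os
import time
from multiprocessing import Barrier, Process


def child(barrier):
    time.sleep(1)
    print(os.getpid(), 'reached the barrier')
    barrier.wait()  # blocks until the parent has also called wait()
    print(os.getpid(), 'released')


if __name__ == '__main__':
    barrier = Barrier(3)  # 2 children + the parent process
    children = [Process(target=child, args=(barrier,)) for _ in range(2)]
    for c in children:
        c.start()
    barrier.wait()  # parent is the last party to arrive
    for c in children:
        c.join()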

PatientPool initializes such a barrier and passes it to its child processes. The parties parameter of this barrier is set to the number of child processes + 1. The child processes call .wait() on this barrier as soon as they are done with at_exit, and PatientPool itself also calls .wait() on it. This happens within the _terminate_pool method, which we override in Pool for this purpose. Doing so prevents the pool from terminating the child processes too early, because all processes calling .wait() are released only once all child processes have reached the barrier.

# at_exit_pool.py
class PatientPool(Pool):
    """Pool class which awaits completion of exit handlers in child processes
    before terminating the processes."""

    def __init__(self, at_exit, at_exit_args=(), processes=None,
                 initializer=None, initargs=(), maxtasksperchild=None,
                 context=None):
        # changed--------------------------------------------------------------
        self._barrier = self._get_barrier(processes)
        at_exit_args = (at_exit, self._barrier) + at_exit_args
        initargs = (at_exit_args, initializer, initargs)
        super().__init__(
            processes, initializer=combi_initializer, initargs=initargs,
            maxtasksperchild=maxtasksperchild, context=context
        )
        # ---------------------------------------------------------------------

    @staticmethod
    def _get_barrier(processes):
        """Get Barrier object for use in _terminate_pool and
        child processes."""
        if processes is None:  # this will be repeated in super().__init__(...)
            processes = os.cpu_count() or 1
        if processes < 1:
            raise ValueError("Number of processes must be at least 1")
        return Barrier(processes + 1)

    def _terminate_pool(self, taskqueue, inqueue, outqueue, pool,
                        worker_handler, task_handler, result_handler, cache):
        """changed from classmethod to normal method"""
        # this is guaranteed to only be called once
        util.debug('finalizing pool')

        worker_handler._state = TERMINATE
        task_handler._state = TERMINATE

        util.debug('helping task handler/workers to finish')
        self.__class__._help_stuff_finish(inqueue, task_handler, len(pool))  # changed

        assert result_handler.is_alive() or len(cache) == 0

        result_handler._state = TERMINATE
        outqueue.put(None)  # sentinel

        # We must wait for the worker handler to exit before terminating
        # workers because we don't want workers to be restarted behind our back.
        util.debug('joining worker handler')
        if threading.current_thread() is not worker_handler:
            worker_handler.join()

        # patch ---------------------------------------------------------------
        print('_terminate_pool barrier waiting')  # DEBUG
        self._barrier.wait()  # <- blocks until all processes have called wait()
        print('_terminate_pool barrier crossed')  # DEBUG
        # ---------------------------------------------------------------------

        # Terminate workers which haven't already finished.
        if pool and hasattr(pool[0], 'terminate'):
            util.debug('terminating workers')
            for p in pool:
                if p.exitcode is None:
                    p.terminate()

        util.debug('joining task handler')
        if threading.current_thread() is not task_handler:
            task_handler.join()

        util.debug('joining result handler')
        if threading.current_thread() is not result_handler:
            result_handler.join()

        if pool and hasattr(pool[0], 'terminate'):
            util.debug('joining pool workers')
            for p in pool:
                if p.is_alive():
                    # worker has not yet exited
                    util.debug('cleaning up worker %d' % p.pid)
                    p.join()

Now in your main module, all you have to do is switch from Pool to PatientPool and pass the required at_exit arguments. For simplicity, my exit handler appends local_score to a toml file. Note that local_score needs to be a global variable so the exit handler can access it.

import os
from functools import wraps
# from multiprocessing import log_to_stderr, set_start_method
# import logging
import toml
from at_exit_pool import register_at_exit, PatientPool


local_score = {
    "hits": 0,
    "misses": 0
}


def memoize(func):
    cache = {}

    @wraps(func)
    def wrap(*args):
        cache_key = str(args)  # ~14% faster than pickle.dumps(args)
        if cache_key not in cache:
            local_score["misses"] += 1
            cache[cache_key] = func(*args)
        else:
            local_score["hits"] += 1
        return cache[cache_key]
    return wrap


@memoize
def foo(x):
    for _ in range(int(x)):
        x - 1
    return x


def dump_score(pathfile):
    with open(pathfile, 'a') as fh:
        toml.dump({str(os.getpid()): local_score}, fh)


if __name__ == '__main__':
    # set_start_method('spawn')
    # logger = log_to_stderr()
    # logger.setLevel(logging.DEBUG)
    PATHFILE = 'score.toml'
    N_WORKERS = 4
    arguments = [10e6 + i for i in range(10)] * 5
    # print(arguments[:10])
    with PatientPool(at_exit=dump_score, at_exit_args=(PATHFILE,),
                     processes=N_WORKERS) as pool:
        results = pool.map(foo, arguments, chunksize=3)
        # print(results[:10])

Running this example produces terminal output like the one below, where "_terminate_pool barrier crossed" will always execute last, while the flow of the lines before it may vary:

555 barrier waiting
_terminate_pool barrier waiting
554 barrier waiting
556 barrier waiting
557 barrier waiting
_terminate_pool barrier crossed
Process finished with exit code 0

The toml file containing the scores of this run looks like this:

[555]
hits = 3
misses = 8
[554]
hits = 3
misses = 9
[556]
hits = 2
misses = 10
[557]
hits = 5
misses = 10
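
To get back the aggregate counters the question was after, the per-PID entries can simply be summed up afterwards; a small sketch assuming the score.toml layout shown above:

import toml

# each top-level table is one worker PID with its hits/misses
scores = toml.load('score.toml')
total = {
    "hits": sum(entry["hits"] for entry in scores.values()),
    "misses": sum(entry["misses"] for entry in scores.values()),
}
print(total)  # for the run above: {'hits': 13, 'misses': 37}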
