我正在编写一个网格搜索实用程序,并试图使用多处理来加快计算速度。我有一个目标函数,它与一个由于内存限制而无法pickle的大型类交互(我只能pickle该类的相关属性(。
import pickle
from multiprocessing import Pool
class TestClass:
def __init__(self):
self.param = 10
def __getstate__(self):
raise RuntimeError("don't you dare pickle me!")
def __setstate__(self, state):
raise RuntimeError("don't you dare pickle me!")
def loss(self, ext_param):
return self.param*ext_param
if __name__ == '__main__':
test_instance = TestClass()
def objective_function(param):
return test_instance.loss(param)
with Pool(4) as p:
result = p.map(objective_function, range(20))
print(result)
在下面的玩具示例中,我预计在pickle objective_function的过程中,test_instance也必须进行pickle,从而引发RuntimeError(由于在__getstate__处引发异常(。然而,这并没有发生,并且代码运行顺利。
所以我的问题是——这里到底腌了什么?如果test_instance没有被腌制,那么它是如何在单个过程中重建的?
在windows+python3.8上,我无法运行原始代码,该代码将test_instance和objective_function定义为main的局部变量,错误如下
AttributeError: Can't get attribute 'objective_function' on <module '__mp_main' from 'xxx.py'>
我已经将objective_function的定义和test_instance的初始化转移到全局范围,它和您提到的一样工作得很好。然而,从这一点来看,似乎已经为不同的过程再次初始化了全局变量,而不是pickled/unpickled。
最后,我更改了您的代码如下,它触发了您预期的错误。
test_instance1 = TestClass()
test_instance2 = TestClass()
with Pool(4) as p:
result = p.map(objective_function, [test_instance1, test_instance2])
print(result)
所以p.map中的婴儿车实际上是腌制的/未腌制的。
好吧,在Wilson的帮助和进一步的挖掘下,我终于回答了自己的问题。我将插入上面修改过的代码来帮助解释:
import pickle
from multiprocessing import Pool, current_process
class TestClass:
def __init__(self):
self.param = 0
def __getstate__(self):
raise RuntimeError("don't you dare pickle me!")
def __setstate__(self, state):
raise RuntimeError("don't you dare pickle me!")
def loss(self, ext_param):
self.param += 1
print(f"{current_process().pid}: {hex(id(self))}: {self.param}: {ext_param} ")
return f"{self.param}_{ext_param}"
def objective_function(param):
return test_instance.loss(param)
if __name__ == '__main__':
test_instance = TestClass()
print(hex(id(test_instance)))
print('objective_function' in globals()) # this returns True on my MacOS+python3.7
with Pool(2) as p:
result = p.map(objective_function, range(6))
print(result)
print(test_instance.param)
# ---- RUN RESULTS BELOW ----
# 0x7f987b955e48
# True
# 10484: 0x7f987b955e48: 1: 0
# 10485: 0x7f987b955e48: 1: 1
# 10484: 0x7f987b955e48: 2: 2
# 10485: 0x7f987b955e48: 2: 3
# 10484: 0x7f987b955e48: 3: 4
# 10485: 0x7f987b955e48: 3: 5
# ['1_0', '1_1', '2_2', '2_3', '3_4', '3_5']
# 0
正如Wilson正确暗示的那样,在p.map过程中唯一被篡改的是参数本身,而不是目标函数,然而,这并没有被重新初始化,而是被复制,以及在os.fork()
过程中发生在Pool初始化中的test_instance。您可以看到,即使在每个进程内,test_instance.param
值彼此独立,它们也共享与fork之前的类的原始实例相同的虚拟内存(可以在此处看到不同进程共享相同虚拟内存的示例(。
根据最初问题的解决方案,我认为正确解决这个问题的唯一方法是在共享内存或内存管理器中分配必要的参数。