numpy数组的共享字典?



我想创建一个存放多个numpy数组的字典,并在多个进程之间共享它。

import ctypes
import multiprocessing
from typing import Dict, Any
import numpy as np
# Manager-backed dict proxy that get_numpy() uses as a cache of arrays by key.
# NOTE(review): values stored through a manager proxy are presumably pickled
# into the manager process, so every lookup would hand back a fresh copy rather
# than the original array -- confirm against the multiprocessing.managers docs.
dict_of_np: Dict[Any, np.ndarray] = multiprocessing.Manager().dict()

def get_numpy(key):
    """Return the array stored under ``key`` in ``dict_of_np``.

    On the first request for a key, a five-element ``int32`` array backed by a
    ``multiprocessing.Array`` is created and stored. The result is always read
    back from ``dict_of_np`` rather than returned directly.

    NOTE(review): the surrounding text reports that writes to the returned
    array are not visible on the next lookup; this function is the
    reproduction of that problem, so its behaviour is left as-is.
    """
    if key in dict_of_np:
        return dict_of_np[key]

    backing = multiprocessing.Array(ctypes.c_int32, 5)
    dict_of_np[key] = np.frombuffer(backing.get_obj(), dtype=np.int32)
    return dict_of_np[key]

if __name__ == "__main__":
    # First lookup creates the entry; mutate the array that was handed back.
    first_lookup = get_numpy("5")
    first_lookup[1] = 5
    print(first_lookup)  # prints [0 5 0 0 0]

    # Second lookup of the same key: the write above is not reflected.
    second_lookup = get_numpy("5")
    print(second_lookup)  # prints [0 0 0 0 0]

我按照“在共享内存中使用numpy数组进行多进程处理”一文中的说明,用缓冲区创建了numpy数组,但是把得到的numpy数组保存到字典中之后,它就不起作用了。如上所示,再次通过键访问字典时,之前对numpy数组所做的更改并没有被保留。

如何共享numpy数组的字典?我需要字典和数组共享并使用相同的内存。

根据我们在这个问题下的讨论,我可能想出了一个解决方案:在主进程中用一个线程来负责实例化multiprocessing.shared_memory.SharedMemory对象,这样就能确保对共享内存对象的引用一直保留,底层内存不会被过早删除。这只解决了Windows特有的问题,即一旦不再有任何引用指向它,底层文件就会被删除。它并不能解决另一个问题:只要还需要底层的memoryview,就必须一直持有每个已打开的SharedMemory实例。

这个管理器线程在一个作为输入的multiprocessing.Queue上“监听”消息,并创建共享内存对象或返回有关它们的数据。锁用于确保每个响应都由发出请求的那个进程读取(否则各进程的响应可能会混淆)。

所有共享内存对象都先由主进程创建,并一直保留到被显式删除为止,以便其他进程可以访问它们。

示例:

import multiprocessing
from multiprocessing import shared_memory, Queue, Process, Lock
from threading import Thread
import numpy as np
class Exit_Flag:
    """Sentinel message: an instance on the message queue ends the service loop."""

class SHMController:
    """Keep SharedMemory handles open in one process and serve requests about them.

    Requests arrive on ``mq`` as ``"<method>:<argument>"`` strings and every
    request gets exactly one reply on ``rq``. Replies start with ``"ok:"`` on
    success and ``"ko:"`` on failure. Clients are expected to hold ``lock``
    around each put/get pair so replies are not mixed up between processes.

    Supported methods:
        ``exists:<name>``        -> ``ok:true`` / ``ok:false``
        ``size:<name>``          -> ``ok:<buffer length>`` / ``ko:-1``
        ``create:<name>,<size>`` -> ``ok:<name>`` / ``ko:<reason>``
        ``create:<size>``        -> same, with an automatically chosen name
        ``destroy:<name>``       -> ``ok:'<name>' destroyed`` / ``ko:<reason>``

    Segments that are still registered when the controller stops are not
    released by this class; send ``destroy`` for them first.
    """

    def __init__(self):
        self._shm_objects = {}  # segment name -> SharedMemory handle kept open here
        self.mq = Queue()  # message input queue
        self.rq = Queue()  # response output queue
        self.lock = Lock()  # only let one child talk to you at a time
        self._processing_thread = Thread(target=self.process_messages)

    def start(self):
        """Start the service thread (to be called after all child processes are started)."""
        self._processing_thread.start()

    def stop(self):
        """Ask the service thread to finish by queueing an ``Exit_Flag``."""
        self.mq.put(Exit_Flag())

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.stop()

    def process_messages(self):
        """Serve requests from ``mq`` until an ``Exit_Flag`` arrives, then close both queues."""
        while True:
            message_obj = self.mq.get()
            if isinstance(message_obj, Exit_Flag):
                break
            if isinstance(message_obj, str):
                response = self.handle_message(message_obj)
            else:
                # Always answer: a requester blocks on rq.get() while holding
                # the lock, so a dropped message would stall every client.
                response = f"ko:unsupported message type '{type(message_obj).__name__}'"
            self.rq.put(response)
        self.mq.close()
        self.rq.close()

    def handle_message(self, message):
        """Execute one ``"<method>:<argument>"`` request and return the reply string.

        Malformed requests and unknown methods produce a ``"ko:"`` reply
        instead of raising, so one bad request cannot kill the service thread.
        """
        method, separator, arg = message.partition(":")
        if not separator:
            return f"ko:malformed message '{message}'"
        if method == "exists":
            # if shm.name exists or not
            return "ok:true" if arg in self._shm_objects else "ok:false"
        if method == "size":
            if arg in self._shm_objects:
                return f"ok:{len(self._shm_objects[arg].buf)}"
            return "ko:-1"
        if method == "create":
            return self._create(arg)
        if method == "destroy":
            if arg in self._shm_objects:
                self._shm_objects[arg].close()
                self._shm_objects[arg].unlink()
                del self._shm_objects[arg]
                return f"ok:'{arg}' destroyed"
            return f"ko:'{arg}' does not exist"
        return f"ko:unknown method '{method}'"

    def _create(self, arg):
        """Create a segment from ``"name,size"`` or ``"size"`` and register its handle."""
        args = arg.split(",")  # name, size or just size
        if len(args) == 1:
            name, size_text = None, args[0]
        elif len(args) == 2:
            name, size_text = args
        else:
            return f"ko:expected 'name,size' or 'size', got '{arg}'"
        try:
            size = int(size_text)
        except ValueError:
            return f"ko:invalid size '{size_text}'"
        if size <= 0:
            return f"ko:size must be positive, got {size}"
        if name in self._shm_objects:
            return f"ko:'{name}' already created"
        try:
            shm = shared_memory.SharedMemory(name=name, create=True, size=size)
        except FileExistsError:
            return f"ko:'{name}' already exists"
        except (OSError, ValueError) as exc:
            return f"ko:could not create '{name}': {exc}"
        # Key by the name reported by the handle so automatically chosen names work too.
        self._shm_objects[shm.name] = shm
        return f"ok:{shm.name}"

def create(mq, rq, lock):
    """Ask the controller for segment ``key123`` (8 bytes) and write ``(1, 2)`` into it.

    ``mq``/``rq`` are the controller's request and reply queues; ``lock`` is
    held around each request/reply pair. Every reply is printed. If either the
    create or the size request is refused, a short message is printed and the
    function returns without touching shared memory.
    """
    # helper functions here could make access less verbose
    with lock:
        mq.put("create:key123,8")
        reply = rq.get()
    print(reply)
    if reply[:2] != "ok":
        print("Uh oh....")
        return

    name = reply.split(':')[1]
    with lock:
        mq.put(f"size:{name}")
        reply = rq.get()
    print(reply)
    if reply[:2] != "ok":
        print("Oh no....")
        return

    size = int(reply.split(":")[1])
    shm = shared_memory.SharedMemory(name=name, create=False, size=size)
    # View the segment as two int32 values and fill them in.
    arr = np.ndarray((2,), buffer=shm.buf, dtype=np.int32)
    arr[:] = (1, 2)
    print(arr)
    shm.close()

def modify(mq, rq, lock):
    """Wait for segment ``key123`` to exist, then add 5 to its first int32 value.

    ``mq``/``rq`` are the controller's request and reply queues; ``lock`` is
    held around each request/reply pair. The existence check is repeated until
    the controller answers ``ok:true``. If the size request is refused, a short
    message is printed and the function returns without touching shared memory.
    """
    reply = None
    while reply != "ok:true":  # until the shm exists
        with lock:
            mq.put("exists:key123")
            reply = rq.get()
    print("key:exists")

    with lock:
        mq.put("size:key123")
        reply = rq.get()
    print(reply)
    if reply[:2] != "ok":
        print("Oh no....")
        return

    size = int(reply.split(":")[1])
    shm = shared_memory.SharedMemory(name="key123", create=False, size=size)
    # View the segment as two int32 values and bump the first one.
    arr = np.ndarray((2,), buffer=shm.buf, dtype=np.int32)
    arr[0] += 5
    print(arr)
    shm.close()

def delete(mq, rq, lock):
    """Placeholder worker with the same signature as the others; does nothing yet."""
    # TODO make a test for this?
    return None

if __name__ == "__main__":
    # because I'm mixing threads and processes
    multiprocessing.set_start_method("spawn")
    with SHMController() as controller:
        worker_args = (controller.mq, controller.rq, controller.lock)
        # Run the workers strictly one after the other: each one is joined
        # before the next one is started.
        for worker in (create, modify):
            task = Process(target=worker, args=worker_args)
            task.start()
            task.join()
    print("finished")

为了让每个shm对象的存活时间与对应的数组一样长,您必须保留对该特定shm对象的引用。把这个引用作为属性附加到自定义的数组子类上(代码复制自numpy的子类化指南),就可以很简单地让引用与数组一起保存:

class SHMArray(np.ndarray):
    """ndarray view that carries a ``shm`` attribute alongside the data.

    Storing the SharedMemory handle on the array keeps a reference to it for
    as long as the array (or anything derived from it) is referenced.
    """
    # copied from https://numpy.org/doc/stable/user/basics.subclassing.html#slightly-more-realistic-example-attribute-added-to-existing-array

    def __new__(cls, input_array, shm=None):
        """View ``input_array`` as ``SHMArray`` and attach ``shm`` to the view."""
        source = np.asarray(input_array)
        wrapped = source.view(cls)
        wrapped.shm = shm
        return wrapped

    def __array_finalize__(self, obj):
        """Copy ``shm`` from the array this one was derived from, if there is one."""
        if obj is not None:
            self.shm = getattr(obj, 'shm', None)
# Example: attach to an existing segment by name, then wrap an int32 view of
# its buffer in SHMArray so the SharedMemory handle is stored on the array.
# NOTE(review): `name` and `shape` are not defined in this fragment -- the
# caller must supply them.
shm = shared_memory.SharedMemory(name=name)
np_array = SHMArray(np.ndarray(shape, buffer=shm.buf, dtype=np.int32), shm)

最新更新