# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import binascii
import os
import queue
import subprocess
from multiprocessing import Queue
from typing import Optional

import pyarrow
import pyarrow.plasma as plasma

MGE_PLASMA_MEMORY = int(os.environ.get("MGE_PLASMA_MEMORY", 4000000000))  # 4GB

# Each process only needs to start one plasma store, so we keep it in a global variable.
# TODO: how to share between different processes?
MGE_PLASMA_STORE_MANAGER = None


def _clear_plasma_store():
    """Drop the global plasma store manager once no queue references it.

    `_PlasmaStoreManager.__del__` will not be called automatically in a
    subprocess, so this function should be called explicitly.
    """
    global MGE_PLASMA_STORE_MANAGER
    if MGE_PLASMA_STORE_MANAGER is not None and MGE_PLASMA_STORE_MANAGER.refcount <= 0:
        # Rebinding the global drops this module's reference to the manager,
        # which lets its `__del__` stop the store process.
        MGE_PLASMA_STORE_MANAGER = None


class _PlasmaStoreManager:
    """Own a `plasma-store-server` child process listening on a random socket."""

    __initialized = False

    def __init__(self):
        self.socket_name = "/tmp/mge_plasma_{}".format(
            binascii.hexlify(os.urandom(8)).decode()
        )
        # Unset, empty, "0", "false", "no" and "off" all mean "disabled";
        # `bool("0")` would otherwise be True and enable debug output.
        debug_flag = os.environ.get(
            "MGE_DATALOADER_PLASMA_DEBUG", "0"
        ).strip().lower() not in ("", "0", "false", "no", "off")
        # NOTE: this is a hack. Using `plasma_store` directly may make it
        # difficult for the subprocess to handle exceptions raised inside
        # `plasma-store-server`, because `plasma_store` is just a wrapper that
        # uses `os.execv` to run the `plasma-store-server` executable.
        cmd_path = os.path.join(pyarrow.__path__[0], "plasma-store-server")
        self.plasma_store = subprocess.Popen(
            [cmd_path, "-s", self.socket_name, "-m", str(MGE_PLASMA_MEMORY)],
            stdout=None if debug_flag else subprocess.DEVNULL,
            stderr=None if debug_flag else subprocess.DEVNULL,
        )
        self.__initialized = True
        # Number of live PlasmaShmQueue instances sharing this store.
        self.refcount = 1

    def __del__(self):
        # `Popen.returncode` is only refreshed by `poll()`/`wait()`, so query
        # the process state explicitly instead of trusting a stale value.
        if self.__initialized and self.plasma_store.poll() is None:
            self.plasma_store.kill()
            # Reap the killed process so it does not linger as a zombie.
            self.plasma_store.wait()


class PlasmaShmQueue:
    def __init__(self, maxsize: Optional[int] = 0):
        r"""Use pyarrow in-memory plasma store to implement shared memory queue.

        Compared to native `multiprocessing.Queue`, `PlasmaShmQueue` avoids
        pickle/unpickle and communication overhead, leading to better
        performance in multi-process applications.

        Args:
            maxsize: maximum number of items in the queue; ``0`` (or ``None``)
                means no limit. (default: ``0``)

        Raises:
            RuntimeError: if the plasma store process cannot be started.
        """
        # Lazy start the plasma store manager
        global MGE_PLASMA_STORE_MANAGER
        if MGE_PLASMA_STORE_MANAGER is None:
            try:
                MGE_PLASMA_STORE_MANAGER = _PlasmaStoreManager()
            except Exception as e:
                err_info = (
                    "Please make sure pyarrow installed correctly!\n"
                    "You can try reinstall pyarrow and see if you can run "
                    "`plasma_store -s /tmp/mge_plasma_xxx -m 1000` normally."
                )
                raise RuntimeError(
                    "Exception happened in starting plasma_store: {}\n"
                    "Tips: {}".format(str(e), err_info)
                ) from e
        else:
            MGE_PLASMA_STORE_MANAGER.refcount += 1
        self.socket_name = MGE_PLASMA_STORE_MANAGER.socket_name

        # TODO: how to catch the exception happened in `plasma.connect`?
        # The client is created lazily on first put/get.
        self.client = None
        # Guards `close()` so the store refcount is decremented only once.
        self._closed = False
        # Used to store the header for the data.(ObjectIDs)
        # `multiprocessing.Queue` treats maxsize <= 0 as unbounded but rejects None.
        self.queue = Queue(0 if maxsize is None else maxsize)  # type: Queue

    def _ensure_client(self):
        """Return the plasma client, connecting to the store on first use."""
        if self.client is None:
            self.client = plasma.connect(self.socket_name)
        return self.client

    def put(self, data, block=True, timeout=None):
        """Store ``data`` in the plasma store and enqueue its ObjectID.

        Raises:
            RuntimeError: if the plasma store is out of memory.
            queue.Full: if the header queue stays full (non-blocking put or
                timeout); the stored object is deleted before re-raising.
        """
        client = self._ensure_client()
        try:
            object_id = client.put(data)
        except plasma.PlasmaStoreFull as e:
            raise RuntimeError("plasma store out of memory!") from e
        try:
            self.queue.put(object_id, block, timeout)
        except queue.Full:
            # Do not leak the stored object when its id never reached the queue.
            client.delete([object_id])
            raise

    def get(self, block=True, timeout=None):
        """Dequeue the next ObjectID, fetch its data and delete it from the store.

        Raises:
            queue.Empty: propagated from the header queue on timeout/non-blocking get.
            RuntimeError: if the dequeued ObjectID is missing from the store.
        """
        client = self._ensure_client()
        object_id = self.queue.get(block, timeout)
        # Check first: `client.get` would otherwise wait for a missing object.
        if not client.contains(object_id):
            raise RuntimeError(
                "ObjectID: {} not found in plasma store".format(object_id)
            )
        data = client.get(object_id)
        client.delete([object_id])
        return data

    def qsize(self):
        """Return the approximate number of queued items."""
        return self.queue.qsize()

    def empty(self):
        """Return True if the header queue is (approximately) empty."""
        return self.queue.empty()

    def join(self):
        """Join the underlying queue's background feeder thread.

        `multiprocessing.Queue` has no `join()`; `join_thread()` is the
        counterpart of :meth:`cancel_join_thread`. It may only be called after
        :meth:`close`.
        """
        self.queue.join_thread()

    def disconnect_client(self):
        """Disconnect from the plasma store; a later put/get reconnects."""
        if self.client is not None:
            self.client.disconnect()
            self.client = None

    def close(self):
        """Close the queue, disconnect the client and release the shared store.

        Safe to call more than once: only the first call decrements the
        store manager's reference count.
        """
        global MGE_PLASMA_STORE_MANAGER
        if self._closed:
            return
        self._closed = True
        self.queue.close()
        self.disconnect_client()
        # The manager may be absent in this process (never created or already cleared).
        if MGE_PLASMA_STORE_MANAGER is not None:
            MGE_PLASMA_STORE_MANAGER.refcount -= 1
            _clear_plasma_store()

    def cancel_join_thread(self):
        """Prevent joining the underlying queue's feeder thread at process exit."""
        self.queue.cancel_join_thread()