# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import multiprocessing as mp
import threading
import time
from collections import defaultdict
from functools import partial
from socketserver import ThreadingMixIn
from xmlrpc.client import ServerProxy
from xmlrpc.server import SimpleXMLRPCServer

from ..core._imperative_rt.utils import create_mm_server
from ..utils.future import Future


class Methods:
    r"""Distributed Server Method.

    Used to exchange information between distributed nodes. An instance is
    registered on an XML-RPC server, so every public method is callable
    remotely by :class:`Client`.

    Args:
        mm_server_port: multiple machine rpc server port.
    """

    def __init__(self, mm_server_port):
        # Guards all of the dicts below; handlers run on separate threads
        # because the server mixes in ThreadingMixIn.
        self.lock = threading.Lock()
        self.mm_server_port = mm_server_port
        # Per-key Futures; the boolean passed to Future is its constructor
        # flag (semantics defined in ..utils.future).
        self.dict_is_grad = defaultdict(partial(Future, True))
        self.dict_remote_tracer = defaultdict(partial(Future, True))
        self.dict_pack_list = defaultdict(partial(Future, False))
        # Arrival counter and release event for each barrier key.
        self.dict_barrier_counter = defaultdict(int)
        self.dict_barrier_event = defaultdict(threading.Event)
        self.user_dict = defaultdict(partial(Future, False))
        # key -> [Future holding the broadcast value, remaining receivers]
        self.bcast_dict = {}

    def connect(self):
        r"""Method for checking connection success."""
        return True

    def get_mm_server_port(self):
        r"""Get multiple machine rpc server port."""
        return self.mm_server_port

    def set_is_grad(self, key, is_grad):
        r"""Mark whether send/recv needs gradients, by key.

        Args:
            key: key to match send/recv op.
            is_grad: whether this op needs grad.
        """
        with self.lock:
            future = self.dict_is_grad[key]
        future.set(is_grad)
        return True

    def check_is_grad(self, key):
        r"""Check whether send/recv needs gradients.

        The entry for ``key`` is removed after its value has been read.

        Args:
            key: key to match send/recv op.
        """
        with self.lock:
            future = self.dict_is_grad[key]
        # Wait outside the lock so the matching set_is_grad can proceed.
        ret = future.get()
        with self.lock:
            del self.dict_is_grad[key]
        return ret

    def set_remote_tracer(self, key, tracer_set):
        r"""Set tracer dict for tracing send/recv op.

        Args:
            key: key to match send/recv op.
            tracer_set: valid tracer set.
        """
        with self.lock:
            future = self.dict_remote_tracer[key]
        future.set(tracer_set)
        return True

    def check_remote_tracer(self, key):
        r"""Get tracer dict for send/recv op.

        The entry for ``key`` is removed after its value has been read.

        Args:
            key: key to match send/recv op.
        """
        with self.lock:
            future = self.dict_remote_tracer[key]
        # Wait outside the lock so the matching set_remote_tracer can proceed.
        ret = future.get()
        with self.lock:
            del self.dict_remote_tracer[key]
        return ret

    def group_barrier(self, key, size):
        r"""A barrier that waits for all group members.

        Args:
            key: group key to match each other.
            size: group size.
        """
        with self.lock:
            self.dict_barrier_counter[key] += 1
            counter = self.dict_barrier_counter[key]
            event = self.dict_barrier_event[key]
            if counter == size:
                # Last arrival: drop the bookkeeping so the key can be reused
                # for a later barrier. Earlier arrivals still hold `event`.
                del self.dict_barrier_counter[key]
                del self.dict_barrier_event[key]
                event.set()
        if counter != size:
            event.wait()
        return True

    def user_set(self, key, val):
        r"""Set user defined key-value pairs across processes."""
        with self.lock:
            future = self.user_dict[key]
        future.set(val)
        return True

    def user_get(self, key):
        r"""Get user defined key-value pairs across processes.

        Waits on the key's Future; presumably blocks until ``user_set`` has
        been called for ``key`` (see ..utils.future).
        """
        with self.lock:
            future = self.user_dict[key]
        return future.get()

    def bcast_val(self, val, key, size):
        r"""Broadcast ``val`` to the other ``size`` callers sharing ``key``.

        The caller passing a non-None ``val`` is the sender and receives
        ``None``; every other caller (passing ``None``) receives the value.

        Args:
            val: value to broadcast, or ``None`` on receiving callers.
            key: broadcast key shared by all participants.
            size: total number of callers for this key.
        """
        with self.lock:
            if key not in self.bcast_dict:
                self.bcast_dict[key] = [Future(False), size]
            arr = self.bcast_dict[key]
        if val is not None:
            arr[0].set(val)
            val = None
        else:
            val = arr[0].get()
        with self.lock:
            cnt = arr[1] - 1
            arr[1] = cnt
            if cnt == 0:
                # Every participant has passed through; release the entry.
                del self.bcast_dict[key]
        return val

    def _del(self, key):
        # Leading underscore keeps this out of the XML-RPC namespace.
        with self.lock:
            del self.user_dict[key]

    # thread safe function
    # NOTE(review): get and delete take the lock separately, so two concurrent
    # user_pop calls on the same key may race (the second del can raise
    # KeyError) — confirm whether callers guarantee a single popper.
    def user_pop(self, key):
        r"""Get a user defined value, then delete its entry."""
        ret = self.user_get(key)
        self._del(key)
        return ret


class ThreadXMLRPCServer(ThreadingMixIn, SimpleXMLRPCServer):
    r"""XML-RPC server that handles each request on its own thread."""

    pass


def _start_server(py_server_port, queue):
    r"""Start python distributed server and multiple machine server.

    Intended to run in a child process. On success, puts the tuple
    ``(py_server_port, mm_server_port)`` (the actually bound ports) into
    ``queue`` and then serves forever; on failure, puts the exception.

    Args:
        py_server_port: python server port (``0`` lets the OS pick one).
        queue: receives the bound ports, or the exception if startup fails.
    """
    try:
        mm_server_port = create_mm_server("0.0.0.0", 0)
        server = ThreadXMLRPCServer(
            ("0.0.0.0", py_server_port), logRequests=False, allow_none=True
        )
        server.register_instance(Methods(mm_server_port))
        # Read back the real port in case 0 was requested.
        _, py_server_port = server.server_address
        queue.put((py_server_port, mm_server_port))
        server.serve_forever()
    except Exception as e:
        queue.put(e)


class Server:
    r"""Distributed Server for distributed training.

    Should be running at master node. The server runs in a daemon child
    process; the constructor blocks until it reports its ports.

    Args:
        port: python server port.

    Raises:
        Exception: whatever exception the child process reported on startup.
    """

    def __init__(self, port=0):
        q = mp.Queue()
        self.proc = mp.Process(target=_start_server, args=(port, q), daemon=True)
        self.proc.start()
        ret = q.get()
        if isinstance(ret, Exception):
            raise ret
        else:
            self.py_server_port, self.mm_server_port = ret

    def __del__(self):
        # `proc` is absent if __init__ failed before creating it.
        proc = getattr(self, "proc", None)
        if proc is not None:
            proc.terminate()


class Client:
    r"""Distributed Client for distributed training.

    The constructor blocks, retrying, until the server at
    ``master_ip:port`` answers.

    Args:
        master_ip: ip address of master node.
        port: port of server at master node.
    """

    def __init__(self, master_ip, port):
        self.master_ip = master_ip
        self.port = port
        self.connect()
        # Per-key call counter used to derive unique broadcast keys.
        self.bcast_dict = defaultdict(int)

    def connect(self):
        r"""Check connection success, retrying every second until it works."""
        while True:
            try:
                self.proxy = ServerProxy(
                    "http://{}:{}".format(self.master_ip, self.port), allow_none=True
                )
                if self.proxy.connect():
                    break
            except Exception:
                # Server not reachable yet; wait and retry. Catching Exception
                # (not everything) keeps Ctrl-C / SystemExit working.
                time.sleep(1)

    def get_mm_server_port(self):
        r"""Get multiple machine server port, retrying on failure."""
        while True:
            try:
                return self.proxy.get_mm_server_port()
            except Exception:
                time.sleep(0.5)

    def set_is_grad(self, key, is_grad):
        r"""Mark whether send/recv needs gradients, by key.

        Args:
            key: key to match send/recv op.
            is_grad: whether this op needs grad.
        """
        self.proxy.set_is_grad(key, is_grad)

    def check_is_grad(self, key):
        r"""Check whether send/recv needs gradients.

        Args:
            key: key to match send/recv op.
        """
        return self.proxy.check_is_grad(key)

    def set_remote_tracer(self, key, tracer_set):
        r"""Set tracer dict for tracing send/recv op.

        Args:
            key: key to match send/recv op.
            tracer_set: valid tracer set.
        """
        self.proxy.set_remote_tracer(key, tracer_set)

    def check_remote_tracer(self, key):
        r"""Get tracer dict for send/recv op.

        Args:
            key: key to match send/recv op.
        """
        return self.proxy.check_remote_tracer(key)

    def group_barrier(self, key, size):
        r"""A barrier that waits for all group members.

        Args:
            key: group key to match each other.
            size: group size.
        """
        # FIXME: group_barrier is not idempotent
        while True:
            try:
                self.proxy.group_barrier(key, size)
                return
            except Exception:
                time.sleep(0.5)

    def user_set(self, key, val):
        r"""Set user defined key-value pairs across processes."""
        return self.proxy.user_set(key, val)

    def user_get(self, key):
        r"""Get user defined key-value pairs across processes."""
        return self.proxy.user_get(key)

    def user_pop(self, key):
        r"""Get user defined key-value pairs and delete the resources when the get is done"""
        return self.proxy.user_pop(key)

    def bcast_val(self, val, key, size):
        r"""Broadcast ``val`` among ``size`` clients sharing ``key``.

        Each call uses a fresh server-side key (``key`` plus a per-client call
        counter), so repeated broadcasts under the same ``key`` do not collide
        — this assumes all participants call in the same order.
        """
        idx = self.bcast_dict[key] + 1
        self.bcast_dict[key] = idx
        key = key + "_bcast_" + str(idx)
        return self.proxy.bcast_val(val, key, size)


def _str2bool(value):
    r"""Parse a command-line boolean such as ``true``/``false``/``1``/``0``.

    Raises:
        ValueError: if ``value`` is not a recognised boolean spelling
            (argparse reports this as an invalid argument).
    """
    lowered = value.strip().lower()
    if lowered in ("1", "true", "t", "yes", "y", "on"):
        return True
    if lowered in ("0", "false", "f", "no", "n", "off"):
        return False
    raise ValueError(value)


def main(port=0, verbose=True):
    r"""Run the distributed server in the foreground; never returns.

    Args:
        port: python server port (``0`` lets the OS pick one).
        verbose: whether the XML-RPC server logs each request.
    """
    mm_server_port = create_mm_server("0.0.0.0", 0)
    # allow_none must match _start_server and Client: bcast_val returns None
    # to the sending rank and user values may be None.
    server = ThreadXMLRPCServer(
        ("0.0.0.0", port), logRequests=verbose, allow_none=True
    )
    server.register_instance(Methods(mm_server_port))
    _, port = server.server_address
    print("serving on port", port)
    server.serve_forever()


if __name__ == "__main__":
    import argparse

    ap = argparse.ArgumentParser()
    ap.add_argument("-p", "--port", type=int, default=0)
    # `type=bool` would map every non-empty string (even "False") to True.
    ap.add_argument("-v", "--verbose", type=_str2bool, default=True)
    args = ap.parse_args()
    main(port=args.port, verbose=args.verbose)