# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from importlib import import_module
from typing import Dict, Tuple

from ..core._imperative_rt import OpDef
from ..core.ops import builtin
from ..tensor import Tensor
from ..version import __version__
from .utils import _convert_kwargs_to_args

# Registries of compatibility loaders, filled by the ``register_*_loader``
# decorators below and consulted by the ``load_*`` functions.
OPDEF_LOADER = {}
FUNCTIONAL_LOADER = {}
TENSORMETHOD_LOADER = {}
MODULE_LOADER = {}


def _import_by_qualname(module_name, qualname):
    """Import ``module_name`` and resolve ``qualname`` inside it.

    ``qualname`` may be dotted (e.g. ``"Outer.Inner"`` for a nested class or
    ``"Class.method"``); every component is looked up with ``getattr`` in
    turn, so nested definitions resolve the same way top-level ones do.

    Raises:
        ImportError: if ``module_name`` cannot be imported.
        AttributeError: if a component of ``qualname`` does not exist.
    """
    obj = import_module(module_name)
    for attr in qualname.split("."):
        obj = getattr(obj, attr)
    return obj


class _ModuleState:
    """Snapshot of a module object: its type, its ``__dict__`` and a version.

    ``module`` is a ``(type.__module__, type.__qualname__)`` pair, ``state``
    is a copy of the instance ``__dict__`` and ``version`` is the version
    string recorded when the snapshot was taken.
    """

    # Cache for the object rebuilt by ``to_module``. It is not part of the
    # dict returned by ``__getstate__``.
    obj = None

    def __init__(self, module: Tuple, state: Dict, version: str):
        self.module = module
        self.state = state
        self.version = version

    @classmethod
    def get_module_state(cls, module):
        """Create or refresh the ``_ModuleState`` attached to ``module``.

        The snapshot is stored on the module itself under the attribute
        ``_m_dump_modulestate``; that attribute is excluded from the copied
        state so the snapshot never contains itself.

        Returns:
            The ``_ModuleState`` stored on ``module``.
        """
        typem = (type(module).__module__, type(module).__qualname__)
        state = module.__dict__.copy()
        state.pop("_m_dump_modulestate", None)
        if hasattr(module, "_m_dump_modulestate"):
            assert isinstance(module._m_dump_modulestate, cls)
            # Re-initialise the existing snapshot in place so that references
            # to it held elsewhere observe the refreshed contents.
            module._m_dump_modulestate.__init__(typem, state, __version__)
        else:
            module.__dict__["_m_dump_modulestate"] = _ModuleState(
                typem, state, __version__
            )
        return module._m_dump_modulestate

    def __getstate__(self):
        """Return the three persisted fields as a plain dict."""
        return {"module": self.module, "state": self.state, "version": self.version}

    def to_module(self):
        """Rebuild the module object described by this snapshot, once.

        The type is located from ``self.module``, instantiated with
        ``__new__`` (``__init__`` is not called) and populated through its
        ``__setstate__``. The result is cached in ``self.obj`` and returned
        on later calls.
        """
        if self.obj is None:
            # ``self.module[1]`` is a ``__qualname__`` and therefore may be
            # dotted for nested classes; resolve it component by component.
            typem = _import_by_qualname(self.module[0], self.module[1])
            m_obj = typem.__new__(typem)
            m_obj.__setstate__(self.state)
            self.obj = m_obj
        return self.obj


def register_opdef_loader(*opdefs):
    """Return a decorator registering ``loader`` in ``OPDEF_LOADER``.

    Each key in ``opdefs`` must not be registered yet (checked with
    ``assert``). The decorated loader is returned unchanged.
    """

    def callback(loader):
        for opdef in opdefs:
            assert opdef not in OPDEF_LOADER
            OPDEF_LOADER[opdef] = loader
        return loader

    return callback


def register_functional_loader(*funcs):
    """Return a decorator registering ``loader`` in ``FUNCTIONAL_LOADER``.

    Each key in ``funcs`` must not be registered yet (checked with
    ``assert``). ``load_functional`` looks keys up as
    ``(module name, qualified name)`` tuples.
    """

    def callback(loader):
        for func in funcs:
            assert func not in FUNCTIONAL_LOADER
            FUNCTIONAL_LOADER[func] = loader
        return loader

    return callback


def register_module_loader(*module_types):
    """Return a decorator registering ``loader`` in ``MODULE_LOADER``.

    Each key in ``module_types`` must not be registered yet (checked with
    ``assert``). ``load_call_module_expr`` looks keys up as
    ``(module name, qualified name)`` tuples.
    """

    def callback(loader):
        for module_type in module_types:
            assert module_type not in MODULE_LOADER
            MODULE_LOADER[module_type] = loader
        return loader

    return callback


def register_tensor_method_loader(*methods):
    """Return a decorator registering ``loader`` in ``TENSORMETHOD_LOADER``.

    Each key in ``methods`` must not be registered yet (checked with
    ``assert``). ``load_call_tensor_method_expr`` looks keys up by
    ``expr.method``.
    """

    def callback(loader):
        for method in methods:
            assert method not in TENSORMETHOD_LOADER
            TENSORMETHOD_LOADER[method] = loader
        return loader

    return callback


def _replace_args_kwargs(expr, new_args, new_kwargs):
    """Install ``new_args``/``new_kwargs`` on ``expr`` if their layout changed.

    ``expr.set_args_kwargs`` is called only when the number of positional
    arguments or the set of keyword names differs from what ``expr``
    currently holds; otherwise ``expr`` is left untouched.
    """
    if len(new_args) != len(expr.args) or set(new_kwargs.keys()) != set(
        expr.kwargs.keys()
    ):
        expr.set_args_kwargs(*new_args, **new_kwargs)


def load_functional(expr):
    """Resolve ``expr.func`` to a callable and update its arguments if needed.

    ``expr.func`` may already be a callable or a
    ``(module name, qualified name)`` tuple. A loader registered for that
    tuple in ``FUNCTIONAL_LOADER`` is run first. The function is then
    (re-)imported by name and stored back on ``expr.func``.
    """
    func = (
        (expr.func.__module__, expr.func.__qualname__)
        if callable(expr.func)
        else expr.func
    )
    assert isinstance(func, tuple)
    if func in FUNCTIONAL_LOADER:
        loader = FUNCTIONAL_LOADER[func]
        loader(expr)
    mname, fname = func
    expr.func = _import_by_qualname(mname, fname)
    assert callable(expr.func)
    # Expressions without a version, or recorded under a different version,
    # get their args/kwargs re-derived against the current signature.
    if not hasattr(expr, "version") or expr.version != __version__:
        args, kwargs = _convert_kwargs_to_args(expr.func, expr.args, expr.kwargs)
        _replace_args_kwargs(expr, args, kwargs)


def load_call_module_expr(expr):
    """Resolve the module type of a call-module expr and update its arguments.

    ``expr.inputs[0].module_type`` may be a class or a
    ``(module name, qualified name)`` tuple. A loader registered for that
    tuple in ``MODULE_LOADER`` is run first; a tuple-valued ``module_type``
    is then replaced by the imported class.
    """
    m_type = expr.inputs[0].module_type
    if isinstance(m_type, type):
        m_type = (m_type.__module__, m_type.__qualname__)
    if m_type in MODULE_LOADER:
        MODULE_LOADER[m_type](expr)
    # Re-read the attribute: the loader above may have replaced it.
    if isinstance(expr.inputs[0].module_type, tuple):
        mname, classname = expr.inputs[0].module_type
        # ``classname`` is a qualified name and may be dotted (nested class).
        expr.inputs[0].module_type = _import_by_qualname(mname, classname)
    if not hasattr(expr, "version") or expr.version != __version__:
        fwd_func = getattr(expr.inputs[0].module_type, "forward")
        args, kwargs = _convert_kwargs_to_args(fwd_func, expr.args, expr.kwargs)
        _replace_args_kwargs(expr, args, kwargs)


def load_call_tensor_method_expr(expr):
    """Run the registered loader for a tensor-method expr and update its arguments.

    A loader registered for ``expr.method`` in ``TENSORMETHOD_LOADER`` is
    run first. For expressions without a version, or with a version other
    than the current one, the arguments are re-derived from the method's
    signature.
    """
    if expr.method in TENSORMETHOD_LOADER:
        loader = TENSORMETHOD_LOADER[expr.method]
        loader(expr)
    if not hasattr(expr, "version") or expr.version != __version__:
        # When the first argument is itself a class, look the method up on
        # that class; otherwise fall back to ``Tensor``.
        tmethod = (
            getattr(expr.args[0], expr.method)
            if isinstance(expr.args[0], type)
            else getattr(Tensor, expr.method)
        )
        args, kwargs = _convert_kwargs_to_args(tmethod, expr.args, expr.kwargs)
        _replace_args_kwargs(expr, args, kwargs)


def load_apply_expr(expr):
    """Rebuild ``expr.opdef`` from ``expr.opdef_state``.

    A loader registered for ``type(expr.opdef)`` in ``OPDEF_LOADER`` is run
    first. The callable stored under ``"opdef_type"`` in the state is then
    invoked without arguments and the remaining entries are passed to the
    new object's ``__setstate__``.
    """
    opdef_type = type(expr.opdef)
    if opdef_type in OPDEF_LOADER:
        OPDEF_LOADER[opdef_type](expr)
    opdef_state = expr.opdef_state
    # NOTE: ``pop`` removes "opdef_type" from ``expr.opdef_state`` itself, so
    # the stored state no longer carries that key after this call.
    opdef_obj = opdef_state.pop("opdef_type")()
    opdef_obj.__setstate__(opdef_state)
    expr.opdef = opdef_obj