# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import abc
import copy
import weakref
from importlib import import_module
from typing import Any, Dict, List, Tuple, Type

import numpy

from .. import get_logger
from ..core._imperative_rt.core2 import Tensor as RawTensor
from ..module import Module
from ..quantization.utils import QParams
from ..tensor import Tensor
from .module_tracer import active_module_tracer
from .tm_config import _get_expr_checker
from .utils import _check_obj_attr

logger = get_logger(__name__)


class Node:
    r"""``Node`` represents the variables (``Tensor``, ``Module``) used in Module's forward method.
    They are inputs/outputs of Expr (the operations on variables).
    """

    expr = None  # type: Expr
    r"""The Expr which produces the Node."""
    # Global id counter shared by every Node subclass; always accessed through
    # ``Node`` (never ``cls``) so subclasses cannot shadow it.
    __total_id = 0  # type: int
    _id = None  # type: int
    _top_graph = None  # type: weakref.ReferenceType
    # Default format spec used by ``__repr__``/``__format__``; always read and
    # written through ``Node`` so it is a single global setting.
    _format_spec = ""  # type: str

    def __init__(self, expr, name: str, qualname: str):
        self.expr = expr
        self.users = []  # List[Expr]
        self._id = Node.__total_id
        Node.__total_id += 1
        self._name = name
        self._qualname = qualname
        self.actual_node = []  # type: List[Node]

    def __repr__(self):
        format_spec = Node._format_spec
        return self.__format__(format_spec)

    def __format__(self, format_spec: str) -> str:
        r"""Format the node name according to ``format_spec``.

        An empty spec falls back to the global ``Node._format_spec``.
        ``"p"`` prefixes the parent graph's name (when the node has a parent
        graph). ``"i"`` prefixes ``%<id>_``. Any other spec returns the bare
        name, or the node id when the name is empty.
        """
        if not format_spec:
            format_spec = Node._format_spec
        name = self._name
        if name is None:
            name = ""
        if format_spec in ["i", "p", "ip", "pi"]:
            if "p" in format_spec:
                graph = self.top_graph
                # A detached node has no parent graph. Skip the prefix instead
                # of raising AttributeError on ``None._name``.
                if graph is not None:
                    name = "{}_{}".format(graph._name, name)
            if "i" in format_spec:
                name = "%{}_{}".format(self._id, name)
            return name
        return name if name else ("%d" % self._id)

    @property
    def name(self):
        r"""Return the name of this Node."""
        return self._name

    @name.setter
    def name(self, new_name: str):
        r"""Set a new name to this Node."""
        graph = self.top_graph
        assert graph is not None, "The parent graph of this Node cannot be None."
        assert graph._namespace.used_names.get(new_name, None) is None, (
            "The name(%s) is already in use. Please try a different one again."
            % (new_name)
        )
        # Release the old name in the graph namespace before registering the new one.
        graph._namespace.unassociate_name_with_obj(self)
        self._name = graph._namespace.create_unique_name(new_name, self)

    @property
    def qualname(self):
        r"""Get the `qualname` of this Node. The `qualname` can be used to get the
        submodule from the traced Module or Module.

        Example:
            .. code-block::

                import megengine.module as M
                import megengine.functional as F
                import megengine.traced_module as tm
                import megengine as mge

                class block(M.Module):
                    def __init__(self):
                        super().__init__()
                        self.param = mge.Tensor([1.])
                        self.relu = M.ReLU()

                    def forward(self, x):
                        x = x + self.param
                        return self.relu(F.relu(x))

                class module(M.Module):
                    def __init__(self):
                        super().__init__()
                        self.block = block()

                    def forward(self, x):
                        x = self.block(x)
                        return x

                net = module()
                traced_net = tm.trace_module(net, mge.Tensor([0.]))
                traced_net = traced_net.flatten()
                out_node = traced_net.graph.outputs[0]

                # qualname : "module.block.relu.[out]"
                qualname = out_node.qualname
                # qualname : "block.relu"
                qualname = qualname.split(".", 1)[-1].rsplit(".", 1)[0]

                assert qualname in list(map(lambda x: x[0], net.named_modules()))
                assert qualname in list(map(lambda x: x[0], traced_net.named_modules()))
        """
        return self._qualname

    @property
    def top_graph(self):
        r"""Get the parent graph of this Node."""
        # ``_top_graph`` is a weak reference; dereference it if set.
        if self._top_graph:
            return self._top_graph()
        return None

    @classmethod
    def _set_format_spec(cls, str):
        r"""Set the global default format spec and return the previous one."""
        old_format_spec = Node._format_spec
        Node._format_spec = str
        return old_format_spec

    @classmethod
    def _get_next_id(cls):
        r"""Return the id the next created Node will receive."""
        return Node.__total_id

    @classmethod
    def _set_next_id(cls, id: int = 0):
        r"""Reset the global Node id counter to ``id``."""
        assert isinstance(id, int)
        Node.__total_id = id

    def __copy__(self):
        # Shallow copy: the new object shares every attribute value (including
        # the ``users`` and ``actual_node`` lists) with the original.
        cls = self.__class__
        result = cls.__new__(cls)
        result.__dict__.update(self.__dict__)
        return result

    def __deepcopy__(self, memo):
        # Deep copy every attribute except weak references (e.g. the parent
        # graph / owner links) and ``actual_node``. Those are left unset on
        # the copy, so it falls back to the class-level defaults.
        cls = self.__class__
        result = cls.__new__(cls)
        state = {}
        memo[id(self)] = result
        for k, v in self.__dict__.items():
            if not isinstance(v, weakref.ReferenceType) and k != "actual_node":
                state[k] = copy.deepcopy(v, memo)
        result.__dict__.update(state)
        return result


class ModuleNode(Node):
    r"""``ModuleNode`` represents the Module objects."""

    module_type = Module  # type: Type[Module]
    r"""The type of the Module corresponding to the ModuleNode."""
    _owner = None  # type: weakref.ReferenceType

    def __init__(self, expr, name: str = None, qualname: str = None):
        super().__init__(expr, name, qualname)

    def __getstate__(self):
        # The module type is pickled as a (module path, qualified class name)
        # pair and resolved back to a class in ``__setstate__``.
        state = {
            "expr": self.expr,
            "users": self.users,
            "_id": self._id,
            "_name": self._name,
            "_qualname": self._qualname,
            "module_type": (self.module_type.__module__, self.module_type.__qualname__),
        }
        _check_obj_attr(state)
        return state

    def __setstate__(self, state):
        # Older pickles stored the qualified name under ``_orig_name``.
        if "_orig_name" in state:
            state["_qualname"] = state.pop("_orig_name")
        self.__dict__.update(state)
        try:
            if isinstance(self.module_type, tuple):
                mname, classname = self.module_type
                mtype = getattr(import_module(mname), classname)
                self.module_type = mtype
        except Exception:
            # Best effort: if the class cannot be imported here, keep the
            # (module, classname) tuple instead of failing to unpickle.
            pass

    @property
    def owner(self):
        r"""Get the ``Module`` corresponding to this ``ModuleNode``."""
        # ``_owner`` is a weak reference; dereference it if set.
        if self._owner:
            return self._owner()
        return None


class TensorNode(Node):
    r"""``TensorNode`` represents the Tensor objects."""

    _shape = None  # type: Tuple[int]
    _dtype = None  # type: numpy.dtype
    _qparams = None  # type: QParams
    _device = None
    _value = None  # type: Tensor

    def __init__(
        self,
        expr,
        name: str = None,
        qualname: str = None,
        shape: Tuple[int] = None,
        dtype: numpy.dtype = None,
        qparams: QParams = None,
    ):
        super().__init__(expr, name, qualname)
        self._shape = shape
        self._dtype = dtype
        self._qparams = qparams

    def __getstate__(self):
        # The bound ``_value`` tensor is deliberately not pickled.
        state = {
            "expr": self.expr,
            "users": self.users,
            "_id": self._id,
            "_qparams": self._qparams,
            "_shape": self._shape,
            "_dtype": self._dtype,
            "_device": self._device,
            "_name": self._name,
            "_qualname": self._qualname,
        }
        _check_obj_attr(state)
        return state

    def __setstate__(self, state):
        # Older pickles stored ``_orig_name``. Convert it to the current
        # qualname format, where the last component is wrapped in brackets
        # unless the producing expr is a GetAttr.
        if "_orig_name" in state:
            qualname = state.pop("_orig_name")
            modulepath, sep, qualname = qualname.rpartition(".")
            expr_name = state["expr"].__class__.__name__
            if expr_name not in ["GetAttr"]:
                qualname = "[{}]".format(qualname)
            if sep:
                qualname = "{}.{}".format(modulepath, qualname)
            state["_qualname"] = qualname
        self.__dict__.update(state)

    @property
    def shape(self):
        r"""Get the shape of this Node."""
        return self._shape

    @shape.setter
    def shape(self, shape):
        self._shape = shape

    @property
    def dtype(self):
        r"""Get the dtype of this Node."""
        return self._dtype

    @dtype.setter
    def dtype(self, dtype):
        self._dtype = dtype

    @property
    def device(self):
        r"""Get the device of the Tensor this Node points to."""
        return self._device

    @device.setter
    def device(self, device):
        self._device = device

    @property
    def qparams(self):
        r"""Get the :class:`QParams` of this Node."""
        return self._qparams

    @qparams.setter
    def qparams(self, qparams):
        self._qparams = qparams

    @property
    def value(self):
        r"""Get the bound Tensor of this Node."""
        return self._value

    @value.setter
    def value(self, value):
        r"""Bind a :class:`Tensor` to this Node."""
        # If the tensor is currently wrapped by some node, clear that binding
        # before storing it here.
        if isinstance(value, RawTensor) and NodeMixin.get(value, None) is not None:
            setattr(value, "_NodeMixin__node", None)
        self._value = value


class NodeMixin(abc.ABC):
    # Class-level default. ``wrap``/``wrap_safe`` store the bound node as the
    # per-object attribute ``_NodeMixin__node``.
    __node = None

    @abc.abstractmethod
    def _record_wrapped_nodes(self, node):
        # record the nodes which had been bound to this NodeMixin
        pass

    @classmethod
    def _record_tensornode_property(cls, node, value):
        r"""Copy dtype, shape, device and qparams from ``value`` onto ``node``."""
        assert isinstance(node, TensorNode)
        assert isinstance(value, RawTensor)
        if isinstance(value, RawTensor):
            try:
                node._dtype = value.dtype
            except RuntimeError:
                # Reading dtype may raise RuntimeError; record it as unknown.
                node._dtype = None
            node._shape = (
                value._tuple_shape if isinstance(value, Tensor) else value.shape
            )
            node._device = value.device
            if hasattr(value, "_qparams") and value._qparams is not None:
                node._qparams = value.qparams

    @classmethod
    def wrap(cls, value, node):
        r"""Bind ``node`` to ``value`` if ``value`` is a NodeMixin or RawTensor.

        ``node`` may be a :class:`Node` or a zero-argument callable returning
        one. The callable is only invoked when ``value`` is wrappable.
        """
        if not isinstance(value, (NodeMixin, RawTensor)):
            return
        if not isinstance(node, Node):
            assert callable(node)
            node = node()
            assert isinstance(node, Node)
        if isinstance(value, RawTensor):
            cls._record_tensornode_property(node, value)
        if isinstance(value, NodeMixin):
            value._record_wrapped_nodes(node)
        setattr(value, "_NodeMixin__node", node)
        if _get_expr_checker():
            if isinstance(value, RawTensor):
                active_module_tracer().checker.record_node2value(node, value)
            if isinstance(value, NodeMixin):
                active_module_tracer().checker.record_nodemixin(node, value)

    @classmethod
    def wrap_safe(cls, value, node):
        r"""Bind an already-constructed ``node`` to ``value`` (must be wrappable)."""
        assert isinstance(value, (NodeMixin, RawTensor))
        if isinstance(value, RawTensor):
            cls._record_tensornode_property(node, value)
        setattr(value, "_NodeMixin__node", node)
        if _get_expr_checker():
            if isinstance(value, RawTensor):
                active_module_tracer().checker.record_node2value(node, value)
            if isinstance(value, NodeMixin):
                active_module_tracer().checker.record_nodemixin(node, value)
        if isinstance(value, NodeMixin):
            value._record_wrapped_nodes(node)

    @classmethod
    def clear_node(cls, value):
        r"""Remove the node binding stored on ``value``, if any."""
        # ``hasattr`` also sees the class-level ``_NodeMixin__node`` default,
        # so an hasattr-then-delattr check would raise AttributeError for a
        # never-wrapped NodeMixin. Attempt the delete and treat "nothing to
        # delete" as a no-op.
        try:
            delattr(value, "_NodeMixin__node")
        except AttributeError:
            pass

    @classmethod
    def get(cls, value, *default):
        r"""Return the node bound to ``value`` (or ``default`` if given and unbound)."""
        return getattr(value, "_NodeMixin__node", *default)

    @classmethod
    def get_wrapped_type(cls, value):
        r"""Return the Node subclass used to represent ``value``."""
        if isinstance(value, RawTensor):
            return TensorNode
        if isinstance(value, (Module, NodeMixin)):
            return ModuleNode
        return Node