# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import builtins
import collections
import copy
import fnmatch
import functools
import inspect
import keyword
import re
import weakref
from importlib import import_module
from inspect import getcallargs, getmembers, isclass, ismethod
from itertools import chain
from types import FunctionType
from typing import (
    Any,
    Callable,
    Dict,
    Iterable,
    List,
    Optional,
    Sequence,
    Tuple,
    Type,
    Union,
)

from .. import functional as F
from .. import get_logger
from .. import module as M
from ..core._imperative_rt.core2 import Tensor as RawTensor
from ..core._imperative_rt.core2 import (
    apply,
    is_tracing_module,
    set_module_tracing,
    unset_module_tracing,
)
from ..core._trace_option import set_symbolic_shape
from ..core.ops.builtin import Copy
from ..module import Module
from ..module import external as MExternal
from ..module.qat import QATModule
from ..quantization.fake_quant import LSQ, TQT, FakeQuantize, _FakeQuantize
from ..quantization.observer import (
    ExponentialMovingAverageObserver,
    HistogramObserver,
    MinMaxObserver,
    Observer,
    PassiveObserver,
    SyncExponentialMovingAverageObserver,
    SyncMinMaxObserver,
)
from ..tensor import Tensor
from ..utils.max_recursion_limit import max_recursion_limit
from ..version import __version__
from .expr import (
    Apply,
    CallFunction,
    CallMethod,
    Constant,
    Expr,
    GetAttr,
    Input,
    get_suffix_name,
    is_apply_def,
    is_call_function,
    is_call_module,
    is_call_tensor_method,
    is_constant,
    is_getattr,
    is_input,
)
from .fake_quant import FakeQuantize as TM_FakeQuant
from .module_tracer import (
    PatchedFn,
    Patcher,
    active_module_tracer,
    get_tensor_wrapable_method,
    module_tracer,
    set_active_module_tracer,
)
from .node import ModuleNode, Node, NodeMixin, TensorNode
from .pytree import (
    USER_REGISTERED_CONTAINER_TYPE,
    USER_REGISTERED_LEAF_TYPE,
    ArgsIndex,
    TreeDef,
    _register_supported_type,
    tree_flatten,
)
from .serialization import (
    _ModuleState,
    load_apply_expr,
    load_call_module_expr,
    load_call_tensor_method_expr,
    load_functional,
)
from .tm_config import (
    _exclude_from_trace,
    _get_default_checker,
    _get_expr_checker,
    _graph_surgery_mode,
    _set_graph_surgery_mode,
)
from .utils import (
    _check_builtin_module_attr,
    _check_obj_attr,
    _convert_kwargs_to_args,
    replace_container_with_module_container,
)

logger = get_logger(__name__)


def _is_builtin_name(name: str) -> bool:
    """Return True if ``name`` collides with a Python builtin, a keyword,
    or one of the reserved literal spellings ("inf", "nan", "NoneType").

    Used by the namespace machinery to avoid generating node/graph names
    that would shadow language-level identifiers.
    """
    return (
        name in builtins.__dict__
        or name in keyword.kwlist
        or name in {"inf", "nan", "NoneType"}
    )


def _is_leaf(node):
    """Leaf predicate for ``tree_flatten``: every leaf in a traced return
    value must be a ``RawTensor``; anything else is rejected with an
    explanatory assertion (users must register custom types explicitly).
    """
    assert isinstance(
        node, RawTensor
    ), 'doesn\'t support {} in return values, MUST use Tensor or use "register_supported_type" method to register self-defined type'.format(
        type(node)
    )
    return isinstance(node, RawTensor)


def _node_to_tensor(*args, **kwargs):
    """Replace every ``TensorNode`` leaf in ``(args, kwargs)`` with a
    concrete ``Tensor`` so the underlying op can actually execute.

    If a node has no cached ``value``, a placeholder of matching
    shape/dtype is materialized with ``F.zeros`` outside of tracing and
    graph-surgery mode (both are temporarily disabled so the zeros op is
    not itself recorded). The returned structure mirrors the input pytree.
    """
    tensors = []
    nodes, tree_def = tree_flatten((args, kwargs))
    for n in nodes:
        if isinstance(n, TensorNode):
            # A node owned by a parent graph becomes a free input of the
            # scope currently being built.
            if n.top_graph is not None:
                active_module_tracer().current_scope()._add_input(n)
            value = n.value
            if value is None:
                flag = _set_graph_surgery_mode(False)
                with _exclude_from_trace():
                    value = F.zeros(shape=n._shape, dtype=n._dtype)
                _set_graph_surgery_mode(flag)
            # Re-associate the tensor with this node unless it already
            # belongs to a "setitem" result (in-place update bookkeeping).
            orig_n = NodeMixin.get(value, None)
            if orig_n is None or "setitem" not in orig_n._name:
                NodeMixin.wrap_safe(value, n)
            tensors.append(value)
        else:
            tensors.append(n)
    tensors = tree_def.unflatten(tensors)
    return tensors


def _tensor_to_node(tensors):
    """Inverse of :func:`_node_to_tensor`: map each ``Tensor`` leaf back
    to its associated ``TensorNode`` (caching the tensor on the node's
    ``value``), leaving non-tensor leaves untouched. ``None`` passes
    through unchanged.
    """
    if tensors is None:
        return None
    nodes = []
    tensors, out_def = tree_flatten(tensors)
    for t in tensors:
        if isinstance(t, Tensor):
            n = NodeMixin.get(t, None)
            if isinstance(n, TensorNode):
                n.value = t
                nodes.append(n)
            else:
                nodes.append(t)
        else:
            nodes.append(t)
    nodes = out_def.unflatten(nodes)
    return nodes


def _name_setter(node: Node, new_name: str):
    """Property setter patched onto ``Node.name`` during graph surgery.

    Renames ``node`` within the namespace of the graph that owns it (the
    top graph when the node is registered there, otherwise its own
    ``top_graph``), asserting that ``new_name`` is not already in use.
    Graph-surgery mode is suspended for the duration and restored on exit.
    """
    surgery_mode = _set_graph_surgery_mode(False)
    graph = active_module_tracer().current_scope()
    if node.top_graph is not None:
        top_graph = active_module_tracer().top_scope()
        if node is top_graph._namespace.used_names.get(node._name, None):
            graph = top_graph
        else:
            graph = node.top_graph
    assert (
        graph._namespace.used_names.get(new_name, None) is None
    ), "The name(%s) is already in use. Please try a different one again." % (new_name)
    graph._namespace.unassociate_name_with_obj(node)
    node._name = graph._namespace.create_unique_name(new_name, node)
    _set_graph_surgery_mode(surgery_mode)


def _wrap_method_to_tensor_node():
    """Patch every tensor-wrapable method onto ``TensorNode`` so that
    calling the method on a node is traced: arguments are converted
    node->tensor, the real ``Tensor`` method runs, and results are
    converted tensor->node. Also patches ``Node.name`` with
    :func:`_name_setter`. Returns the list of ``PatchedFn`` objects so
    the caller can restore the originals afterwards.
    """

    def _any_method(name, func):
        def _any(*args, **kwargs):
            if is_tracing_module() and _graph_surgery_mode():
                args, kwargs = _node_to_tensor(*args, **kwargs)
                attr = getattr(args[0], name)
                outs = attr
                if callable(attr):
                    outs = attr(*(args[1:]), **kwargs)
                if name == "__setitem__":
                    # __setitem__ mutates in place and returns None; only
                    # the node bookkeeping side effect is needed.
                    _node_to_tensor(outs)
                    return None
                outs = _tensor_to_node(outs)
                return outs
            else:
                # Outside surgery mode fall back to the original attribute,
                # which may be a plain callable or a property descriptor.
                outs = func
                if callable(func):
                    outs = func(*args, **kwargs)
                if isinstance(func, property):
                    outs = func.__get__(*args, **kwargs)
                return outs

        return _any

    tensor_method_patch = []
    for method in get_tensor_wrapable_method():
        patch = PatchedFn(TensorNode, method)
        if type(getattr(Tensor, method)) == property:
            # Only support property.getter
            patch.set_func(property(_any_method(method, patch.origin_fn)))
        else:
            patch.set_func(_any_method(method, patch.origin_fn))
        tensor_method_patch.append(patch)
    patch = PatchedFn(Node, "name")
    patch.set_func(property(patch.origin_fn.fget, _name_setter))
    tensor_method_patch.append(patch)
    return tensor_method_patch


def _convert_node_and_tensor(orig_func):
    """Decorator bridging nodes and tensors around ``orig_func``.

    While tracing in graph-surgery mode, inputs are converted
    node->tensor, ``orig_func`` is invoked (receiving this wrapper via
    the ``method_func`` keyword so nested dispatch can identify it), and
    the result is converted tensor->node. Otherwise ``orig_func`` is
    called unchanged.
    """

    @functools.wraps(orig_func)
    def _convert(*args, **kwargs):
        if is_tracing_module() and _graph_surgery_mode():
            args, kwargs = _node_to_tensor(*args, **kwargs)
            rst = orig_func(*args, **kwargs, method_func=_convert)
            rst = _tensor_to_node(rst)
            return rst
        else:
            rst = orig_func(*args, **kwargs)
            return rst

    return _convert


def 
_wrap_mnode_getattr(orig_getattr):
    # Wrapper factory for ModuleNode.__getattr__ used while tracing in
    # graph-surgery mode: attribute access on a ModuleNode is redirected
    # to the owning module, submodules are wrapped into
    # TracedModuleBuilder on first access, and a GetAttr expr is recorded
    # for Module/Tensor attributes. Outside surgery mode, it falls back
    # to plain object attribute lookup on the node itself.
    @functools.wraps(orig_getattr)
    def wraped_fn(self, name):
        if is_tracing_module() and _graph_surgery_mode():
            obj = self.owner
            current_graph = active_module_tracer().current_scope()
            # A node owned by a parent graph becomes an input of the
            # scope currently being built.
            if self.top_graph is not None:
                current_graph._add_input(self)
            attr = getattr(obj, name)
            node = attr
            if not isinstance(attr, TracedModuleBuilder):
                if isinstance(attr, Module):
                    # Lazily wrap submodules so their forward calls are
                    # traced; the wrapped builder replaces the original
                    # attribute on the owner.
                    attr = TracedModuleBuilder(attr)
                    setattr(obj, name, attr)
                if isinstance(attr, (NodeMixin, RawTensor)):
                    # Record the attribute access as a GetAttr expr; the
                    # lambda defers node creation until actually needed.
                    NodeMixin.wrap(
                        attr,
                        lambda: GetAttr.make(
                            self,
                            type=NodeMixin.get_wrapped_type(attr),
                            attr_name=name,
                            name="",
                        ),
                    )
            if isinstance(attr, (NodeMixin, RawTensor)):
                node = NodeMixin.get(attr)
            if isinstance(node, ModuleNode) and isinstance(attr, (NodeMixin, Module)):
                # Keep a weak back-reference so the node does not keep the
                # module alive.
                node._owner = weakref.ref(attr)
            return node
        else:
            node = object.__getattribute__(self, name)
            return node

    return wraped_fn


def _wrap_mnode_call(orig_call):
    # Wrapper factory for ModuleNode.__call__: while tracing in
    # graph-surgery mode, calling a ModuleNode invokes the owning module
    # (registering the node as a graph input when it belongs to a parent
    # graph); otherwise a ModuleNode is not callable.
    @functools.wraps(orig_call)
    def wraped_fn(self, *args, **kwargs):
        if is_tracing_module() and _graph_surgery_mode():
            obj = self.owner
            if self.top_graph is not None:
                active_module_tracer().current_scope()._add_input(self)
            rst = obj(*args, **kwargs)
        else:
            raise TypeError("'ModuleNode' object is not callable")
        return rst

    return wraped_fn


class _InsertExprs:
    """Context manager backing ``InternalGraph.insert_exprs``.

    On ``__enter__`` it switches the runtime into module-tracing /
    graph-surgery mode, patches ModuleNode and TracedModuleBuilder so
    functions and modules executed inside the ``with`` body are traced
    into a temporary scope (``global_scope``). On ``__exit__`` the
    patches are undone and the newly traced exprs are spliced into the
    target graph after ``expr`` (or after the last input-producing expr).
    """

    def __init__(self, graph, expr: Optional[Expr] = None):
        self.graph = graph
        # Walk up to the root graph; id counters and final reference
        # updates live there.
        while graph.top_graph is not None:
            graph = graph.top_graph
        assert graph.inputs[0].owner._is_top
        self.root_graph = graph
        self.global_scope = InternalGraph(self.graph._name, self.graph._qualname)
        # Share the target graph's namespace so new nodes get non-clashing
        # names.
        self.global_scope._namespace.merge(self.graph._namespace)
        self.expr = expr
        self._tensor_method_patch = None

    def __enter__(self):
        self.use_sym_shape = set_symbolic_shape(True)
        # Continue numbering from the root graph's counters so inserted
        # nodes/exprs get fresh ids.
        node_id, expr_id = self.root_graph._total_ids
        Node._set_next_id(node_id)
        Expr._set_next_id(expr_id)
        set_module_tracing()
        _set_graph_surgery_mode(True)
        assert active_module_tracer() is None
        set_active_module_tracer(
            module_tracer(lambda x: _convert_node_and_tensor(_wrapped_function(x)))
        )
        active_module_tracer().patcher.__enter__()
        for cls, name, func in [
            [ModuleNode, "__getattr__", _wrap_mnode_getattr],
            [ModuleNode, "__call__", _wrap_mnode_call],
            [TracedModuleBuilder, "__call__", _convert_node_and_tensor],
        ]:
            active_module_tracer().patcher.patch_function(cls, name, func)
        self._tensor_method_patch = _wrap_method_to_tensor_node()
        active_module_tracer().push_scope(self.global_scope)

    def __exit__(self, ty, va, tr):
        # Propagate exceptions raised inside the with-body unchanged.
        if va is not None:
            return False
        active_module_tracer().patcher.__exit__(ty, va, tr)
        # Restore every Tensor/Node method patched in __enter__.
        while self._tensor_method_patch:
            pf = self._tensor_method_patch.pop()
            pf.set_func(pf.origin_fn)
        # delete ModuleNode.__call__ to avoid entering the
        # ModuleNode.__init__ method when call a ModuleNode object.
        delattr(ModuleNode, "__call__")
        module = self.graph.inputs[0].owner

        def build_traced_module(
            module: TracedModuleBuilder, target_module: TracedModule
        ):
            # Recursively replace builder attributes with built
            # TracedModules on the target.
            for k, v in module.__dict__.items():
                if isinstance(v, TracedModuleBuilder):
                    traced_v = v.build()
                    build_traced_module(v, traced_v)
                    setattr(target_module, k, traced_v)

        build_traced_module(module, module)
        set_symbolic_shape(self.use_sym_shape)
        _set_graph_surgery_mode(False)
        set_active_module_tracer(None)
        unset_module_tracing()
        # Inserted exprs must come after every expr that produces one of
        # their extra inputs; compute that lower bound first.
        extra_inp_nodes = set(self.global_scope.inputs)
        max_inp_expr_idx = -1
        for node in extra_inp_nodes:
            assert (
                node.top_graph == self.graph
            ), "The input node ({}) is not in the graph ({})".format(node, self.graph)
            if node.expr in self.graph._exprs:
                max_inp_expr_idx = max(
                    max_inp_expr_idx, self.graph._exprs.index(node.expr)
                )
        max_inp_expr_idx += 1
        # Requested insertion point: right after self.expr when given,
        # else the start of the expr list — clamped to the lower bound.
        insert_index = -1
        if self.expr in self.graph._exprs:
            insert_index = self.graph._exprs.index(self.expr)
        insert_index += 1
        if insert_index < max_inp_expr_idx:
            insert_index = max_inp_expr_idx
        for expr in self.global_scope._exprs:
            self.graph._exprs.insert(insert_index, expr)
            insert_index += 1
        self.graph._namespace.merge(self.global_scope._namespace)
        self.root_graph._total_ids = (Node._get_next_id(),
Expr._get_next_id()) self.root_graph.inputs[0].owner._update_ref() for node in self.root_graph.nodes(): if isinstance(node, TensorNode): node.value = None return True class NameSpace: def __init__(self, name, qualname): self.name = name self.qualname = qualname self._used_names = {} def create_unique_name(self, name: str, node: Any = None) -> str: assert isinstance(name, str), "The name must be a string" if name in self._used_names and (self._used_names[name] is node): return name name = re.sub("[^0-9a-zA-Z_]+", "_", name) if name[0].isdigit(): name = "_{}".format(name) while ( name in self._used_names and self._used_names[name] is not None ) or _is_builtin_name(name): match = re.match(r"(.*)_(\d+)$", name) if match is None: name = name + "_1" else: base, num = match.group(1, 2) name = "{}_{}".format(base, int(num) + 1) self._used_names.setdefault(name) if node is not None: self.associate_name_with_obj(name, node) return name def auto_naming_for_outputs(self, expr: Expr): _add_suffix = lambda x: x + "_out" if is_call_module(expr): call_node = expr.inputs[0] qualname = "%s.[out]" % (call_node.qualname) name = call_node.name elif is_call_tensor_method(expr): name = expr.method.strip("_") qualname = "{}.[{}]".format( self.qualname, self.create_unique_name("method_%s" % (name)), ) elif is_call_function(expr): name = expr.func.__name__ qualname = "{}.[{}]".format( self.qualname, self.create_unique_name("func_%s" % name), ) elif is_apply_def(expr): name = str(expr.opdef).lower() qualname = "{}.[{}]".format( self.qualname, self.create_unique_name("def_%s" % name), ) elif is_getattr(expr): qualname = "{}.{}".format(expr.inputs[0].qualname, expr.name) name = get_suffix_name(self.qualname, qualname) _add_suffix = lambda x: x elif is_constant(expr) or is_input(expr): name = ( expr.name if expr.name else "const_" + type(expr.value).__name__.lower() ) qualname = "{}.[{}]".format(self.qualname, name) _add_suffix = lambda x: x for node in expr.outputs: cur_name = node._name if 
node._name else _add_suffix(name) node._name = self.create_unique_name(cur_name, node) if node._qualname == "": node._qualname = qualname assert get_suffix_name(self.qualname, qualname) is not None def merge(self, other: "NameSpace"): self._used_names.update(other.used_names) def associate_name_with_obj(self, name: str, node: Node): assert name in self.used_names assert self.used_names[name] is None, "The name(%s) is already in use" % (name) self._used_names[name] = node def unassociate_name_with_obj(self, node: Node): assert node.name in self.used_names # assert self.used_names[node.name] is node self._used_names[node.name] = None @property def used_names(self): return self._used_names class InternalGraph: r"""``InternalGraph`` is the main data structure used in the TracedModule. It is used to represent the execution procedure of Module's forward method. For example, the following code .. code-block:: import megengine.random as rand import megengine.functional as F import megengine.module as M import megengine.traced_module as tm class MyModule(M.Module): def __init__(self): super().__init__() self.param = rand.normal(size=(3, 4)) self.linear = M.Linear(4, 5) def forward(self, x): return F.relu(self.linear(x + self.param)) net = MyModule() inp = F.zeros(shape = (3, 4)) traced_module = tm.trace_module(net, inp) Will produce the following ``InternalGraph``:: print(traced_module.graph) .. 
code-block:: text MyModule.Graph (self, x) { %2: linear = getattr(self, "linear") -> (Linear) %3: param = getattr(self, "param") -> (Tensor) %4: add_out = x.__add__(param, ) %5: linear_out = linear(add_out, ) %6: relu_out = nn.relu(linear_out, ) return relu_out } """ _exprs = None # type: List[Expr] _inputs = None # type: List[Node] _outputs = None # type: List[Node] _top_graph = None # type: InternalGraph _total_ids = None # type: List[int] def __init__(self, name: str, qualname: str): self._exprs = [] self._inputs = [] self._outputs = [] self._watch_point = [] self._end_point = [] self._namespace = NameSpace(name, qualname) self._rst = collections.defaultdict(list) self._name = name self._qualname = qualname def _insert(self, expr): self._exprs.append(expr) @property def name(self) -> str: r"""Get the name of this graph.""" return self._name @name.setter def name(self, new_name: str): r"""Set a new name to this graph.""" mod = self.inputs[0].owner graph = self.top_graph assert graph is not None or mod._is_top, "The parent graph cannot be None." if graph is not None: assert graph._namespace.used_names.get(new_name, None) is None, ( "The name(%s) is already in use. Please try a different one again." % (new_name) ) new_name = graph._namespace.create_unique_name(new_name, self) self._name = new_name @property def qualname(self) -> str: r"""Get the `qualname` of this graph. The `qualname` can be used to get the submodule from the traced Module or Module. Example: .. 
code-block:: import megengine.module as M import megengine.traced_module as tm import megengine as mge class block(M.Module): def __init__(self): super().__init__() self.relu = M.ReLU() def forward(self, x): return self.relu(x) class module(M.Module): def __init__(self): super().__init__() self.block = block() def forward(self, x): x = self.block(x) return x net = module() traced_net = tm.trace_module(net, mge.Tensor([0.])) qualname = traced_net.block.graph.qualname # qualname = "module.block" qualname = qualname.split(".", 1)[-1] # qualname = "block" assert qualname in list(map(lambda x: x[0], net.named_modules())) assert qualname in list(map(lambda x: x[0], traced_net.named_modules())) """ return self._qualname @property def inputs(self) -> List[Node]: r"""Get the list of input Nodes of this graph. Returns: A list of ``Node``. """ return self._inputs @property def outputs(self) -> List[Node]: r"""Get the list of output Nodes of this graph. Returns: A list of ``Node``. """ return self._outputs @property def top_graph(self): r"""Get the parent graph of this graph. Returns: An ``InternalGraph``. """ if self._top_graph: return self._top_graph() return None def exprs(self, recursive=True): r"""Get the Exprs that constitute this graph. Args: recursive: whether to get the Exprs in the subgraph. Default: True Returns: A ``ExprFilter`` containing all Exprs of this graph. """ return ExprFilter(_expr_iter(self, recursive)) def nodes(self, recursive=True): r"""Get the Nodes that constitute this graph. Args: recursive: whether to get the Nodes in the subgraph. Default: True Returns: A ``NodeFilter`` containing all Nodes of this graph. """ return NodeFilter(_node_iter(self, recursive)) def get_function_by_type(self, func: Callable = None, recursive=True): r"""Filter Exprs by the type of ``CallFunction``. Args: func: a built-in function, such as ``F.relu``. recursive: whether to get the Exprs in the subgraph. 
Default: True Returns: A :class:`~.TracedModule.ExprFilterCallFunction`. """ return self.exprs(recursive).call_function(func) def get_method_by_type(self, method: str = None, recursive=True): r"""Filter Exprs by the type of ``CallMethod``. Args: method: a method string, such as "__add__". recursive: whether to get the Exprs in the subgraph. Default: True Returns: A :class:`~.TracedModule.ExprFilterCallMethod`. """ return self.exprs(recursive).call_method(method) def get_expr_by_id(self, expr_id: List[int] = None, recursive=True): r"""Filter Exprs by their ``id``. Args: expr_id: a list of :class:`int`. recursive: whether to get the Exprs in the subgraph. Default: True Returns: A :class:`~.TracedModule.ExprFilterExprId`. """ return self.exprs(recursive).expr_id(expr_id) def get_module_by_type(self, module_cls: Module, recursive=True): r"""Filter Nodes by the ``module_type`` of ``ModuleNode``. Args: module_cls: a subclass of :class:`~.Module`. recursive: whether to get the Nodes in the subgraph. Default: True Returns: A :class:`~.TracedModule.NodeFilterType`. """ return self.nodes(recursive).type(module_cls) def get_node_by_id(self, node_id: List[int] = None, recursive=True): r"""Filter Nodes by their ``id``. The ``id`` of the ``Node`` can be obtained by the following code .. code-block:: # node : Node print("{:i}".format(node)) print(node.__format__("i")) # graph : InternalGraph print("{:i}".format(graph)) print(graph.__format__("i")) Args: node_id: a list of :class:`int`. recursive: whether to get the Nodes in the subgraph. Default: True Returns: A :class:`~.TracedModule.NodeFilterNodeId`. """ return self.nodes(recursive).node_id(node_id) def get_node_by_name( self, name: str = None, ignorecase: bool = True, recursive=True ): r"""Filter Nodes by their full name. The full name of the ``Node`` can be obtained by the following code .. 
code-block:: # node : Node print("{:p}".format(node)) print(node.__format__("p")) # graph : InternalGraph print("{:p}".format(graph)) print(graph.__format__("p")) Args: name: a string in glob syntax that can contain ``?`` and ``*`` to match a single or arbitrary characters. ignorecase: whether to ignroe case. Default: True recursive: whether to get the Nodes in the subgraph. Default: True Returns: A :class:`~.TracedModule.NodeFilterName`. """ return self.nodes(recursive).name(name, ignorecase) def _add_input(self, i): self._inputs.append(i) def _add_output(self, o): self._outputs.append(o) def get_dep_exprs(self, nodes: Sequence[Node]) -> List[Expr]: r"""Get the dependent Exprs of the ``nodes``. Args: nodes: a list of :class:`Node`. Returns: A list of dependent :class:`Expr`. """ if not isinstance(nodes, Sequence): nodes = (nodes,) ret = list() queue = list(nodes) visited_queue = list() while queue: node = queue.pop() visited_queue.append(node) expr = node.expr if expr not in ret: ret.append(expr) for i in expr.inputs: if i not in queue and i not in visited_queue: queue.append(i) return ret def reset_inputs(self, *args, **kwargs): forma_mnode = self.inputs[0] moudle = forma_mnode.owner assert moudle._is_top, "reset_inputs only supports top graph" inputs, tree_def = tree_flatten(((moudle, *args), kwargs)) def create_node(val: Tensor): name = self._namespace.create_unique_name("args") node = Input( type=TensorNode, name=name, qualname="%s.[%s]" % (self._qualname, name) ).outputs[0] self._namespace.associate_name_with_obj(node.name, node) node.shape = val.shape node.dtype = val.dtype return node formal_node_inputs = [ forma_mnode, ] org_argdef = list(moudle.argdef_graph_map.keys())[0] for v in inputs[1:]: assert isinstance(v, RawTensor) formal_node_inputs.append(create_node(v)) self._inputs[:] = formal_node_inputs moudle.argdef_graph_map[tree_def] = moudle.argdef_graph_map.pop(org_argdef) moudle.argdef_outdef_map[tree_def] = moudle.argdef_outdef_map.pop(org_argdef) 
return formal_node_inputs[1:] def add_input_node( self, shape: Tuple[int], dtype: str = "float32", name: str = "args" ): r"""Add an input node to the graph. The new Node will be the last of the positional arguments. Args: shape: the shape of the new input Node. dtype: the dtype of the new input Node. Default: float32 name: the name of the new input Node. When the name is used in the graph, a suffix will be added to it. """ forma_mnode = self.inputs[0] moudle = forma_mnode.owner assert moudle._is_top, "add_input_node only supports top graph" def create_node(name=None): name = self._namespace.create_unique_name(name) node = Input( type=TensorNode, name=name, qualname="%s.[%s]" % (self._qualname, name) ).outputs[0] self._namespace.associate_name_with_obj(node.name, node) node.shape = shape node.dtype = dtype return node org_argdef = list(moudle.argdef_graph_map.keys())[0] args, kwargs = org_argdef.unflatten(self._inputs) formal_inp_node = create_node(name) inputs, tree_def = tree_flatten( ((*args, formal_inp_node), kwargs), is_const_leaf=lambda x: not isinstance(x, (TensorNode, ModuleNode)), ) self._inputs[:] = inputs[:] moudle.argdef_graph_map[tree_def] = moudle.argdef_graph_map.pop(org_argdef) moudle.argdef_outdef_map[tree_def] = moudle.argdef_outdef_map.pop(org_argdef) return formal_inp_node def reset_outputs(self, outputs): r"""Reset the output Nodes of the graph. .. note:: This method only supports resetting the output of graphs that do not have a parent graph. Args: outputs: an object which inner element is Node. Support tuple, list dict, etc. For example, the following code .. 
code-block:: import megengine.functional as F import megengine.module as M import megengine.traced_module as tm class MyModule(M.Module): def forward(self, x): x = x + 1 return x net = MyModule() inp = F.zeros(shape = (1, )) traced_module = tm.trace_module(net, inp) graph = traced_module.graph inp_node = graph.inputs[1] out_node = graph.outputs[0] graph.reset_outputs((out_node, {"input": inp_node})) out = traced_module(inp) Will produce the following ``InternalGraph`` and ``out``:: print(graph) print(out) .. code-block:: text MyModule.Graph (self, x) { %2: add_out = x.__add__(1, ) return add_out, x } (Tensor([1.], device=xpux:0), {'input': Tensor([0.], device=xpux:0)}) """ outputs, out_def = tree_flatten( outputs, is_leaf=lambda x: isinstance(x, TensorNode), ) forma_mnode = self.inputs[0] moudle = forma_mnode.owner assert moudle._is_top, "reset_outputs only supports top graph" tree_def = list(moudle.argdef_graph_map.keys())[0] self._outputs[:] = outputs moudle.argdef_outdef_map[tree_def] = out_def def add_output_node(self, node: TensorNode): r"""Add an output node to the Graph. The Graph output will become a ``tuple`` after calling ``add_output_node``. The first element of the ``tuple`` is the original output, and the second is the ``node``. For example, the following code .. code-block:: import megengine.functional as F import megengine.module as M import megengine.traced_module as tm class MyModule(M.Module): def forward(self, x): x = x + 1 return x net = MyModule() inp = F.zeros(shape = (1, )) traced_module = tm.trace_module(net, inp) graph = traced_module.graph inp_node = graph.inputs[1] out_node = graph.outputs[0] graph.add_output_node(inp_node) graph.add_output_node(out_node) out = traced_module(inp) Will produce the following ``InternalGraph`` and ``out``:: print(graph) print(out) .. 
code-block:: text MyModule.Graph (self, x) { %2: add_out = x.__add__(1, ) return add_out, x, add_out } ((Tensor([1.], device=xpux:0), Tensor([0.], device=xpux:0)), Tensor([1.], device=xpux:0)) """ forma_mnode = self.inputs[0] moudle = forma_mnode.owner assert moudle._is_top, "add_output_node only supports top graph" tree_def = list(moudle.argdef_graph_map.keys())[0] org_out_def = moudle.argdef_outdef_map[tree_def] org_outs = org_out_def.unflatten(self._outputs) outputs, out_def = tree_flatten( (org_outs, node), is_leaf=lambda x: isinstance(x, TensorNode), ) self._outputs[:] = outputs moudle.argdef_outdef_map[tree_def] = out_def def insert_exprs(self, expr: Optional[Expr] = None): r"""Initialize the trace mode and insertion position. When used within a 'with' statement, this will temporary set the trace mode and then restore normal mode when the with statement exits:: with graph.insert_exprs(e): # set the trace mode ... # trace function or module ... # inert exprs into graph and resotre normal mode Args: expr: the ``expr`` after which to insert. If None, the insertion position will be automatically set based on the input node. Returns: A resource manager that will initialize trace mode on ``__enter__`` and restore normal mode on ``__exit__``. """ if expr is not None: assert expr.top_graph == self, "Expr to insert after is not in graph." return _InsertExprs(self, expr) def replace_node(self, repl_dict: Dict[Node, Node]): r"""Replace the Nodes in the graph. Args: repl_dict: the map {old_Node: new_Node} that specifies how to replace the Nodes. 
""" while repl_dict: node, repl_node = repl_dict.popitem() assert type(node) == type( repl_node ), "The type of {}({}) and {}({}) are not the same".format( node, type(node).__name__, repl_node, type(repl_node).__name__ ) # check graph inputs and outputs for i, n in enumerate(self.outputs): if n is node: self.outputs[i] = repl_node # update users of node and repl_node # update inputs of expr in node.users graph = repl_node.top_graph assert graph is not None assert graph is self index = -1 if not isinstance(repl_node.expr, Input): index = graph._exprs.index(repl_node.expr) dep_exprs = self.get_dep_exprs(repl_node) i = 0 while i < len(node.users): n = node.users[i] if n in graph._exprs and index >= graph._exprs.index(n): i += 1 continue if n in dep_exprs: logger.info("Find a loop: ignore this replacement once") logger.info("node: %s" % node.__repr__()) logger.info("expr: %s" % n.__repr__()) i += 1 continue repl_node.users.append(n) node.users.pop(i) idx = n.inputs.index(node) n.inputs[idx] = repl_node def _merge_getattr_expr(self): getattr_nodes_map = dict() # Dcit[(Node, str), Node] node_to_attrname = dict() # Dict[Node, (Node, Str)] for expr in filter(lambda x: isinstance(x, GetAttr), self._exprs): base_node, attr_name = expr.inputs[0], expr.name if expr.inputs[0] in node_to_attrname: base_node, base_name = node_to_attrname[expr.inputs[0]] attr_name = "{}.{}".format(base_name, expr.name) if get_suffix_name(self.qualname, expr.outputs[0].qualname) != attr_name: expected_qualname = base_node.qualname + "." + attr_name logger.warning( "{}.qualname expects {}, got {} actually. 
You can re-trace this " "TracedModel to make the name correct.".format( expr.outputs[0], expected_qualname, expr.outputs[0].qualname ) ) expr.outputs[0]._qualname = expected_qualname key = (base_node, attr_name) node_to_attrname[expr.outputs[0]] = key if key in getattr_nodes_map: existed_node = getattr_nodes_map[key] repl_node = expr.outputs[0] for expr in repl_node.users: existed_node.users.append(expr) idx = expr.inputs.index(repl_node) expr.inputs[idx] = existed_node repl_node.users = [] else: if attr_name != expr.name: expr.name = attr_name expr.inputs[0].users.remove(expr) self.inputs[0].users.append(expr) expr.inputs[0] = self.inputs[0] getattr_nodes_map[key] = expr.outputs[0] def compile(self): r"""Delete unused expr.""" self._merge_getattr_expr() dep_exprs = self.get_dep_exprs(self.outputs) i = 0 while i < len(self._exprs): expr = self._exprs[i] if expr in dep_exprs or expr._disable_remove: i += 1 continue for n in expr.inputs: n.users.remove(expr) self._exprs.remove(expr) for n in expr.outputs: self._namespace.unassociate_name_with_obj(n) def _reset_ids(self): for total_expr_id, expr in enumerate(self.exprs()): expr._id = total_expr_id for total_node_id, node in enumerate(self.nodes()): node._id = total_node_id self._total_ids = (total_node_id + 1, total_expr_id + 1) def _re_associate_name(self): self._namespace.used_names.clear() for node in self.nodes(False): node._name = self._namespace.create_unique_name(node.name, node) def interpret(self, *inputs): node2value = {} end_nodes_set = set(self._end_point) endnode2value = {} def get_all_endnode_val(n, v): if n in end_nodes_set: endnode2value[n] = v end_nodes_set.remove(n) return not end_nodes_set return False ref_count = lambda n: len(n.users) + (1 if n in self._outputs else 0) for n, v in zip(self._inputs, inputs): if ref_count(n) > 0: node2value[n] = [v, ref_count(n)] if n in self._watch_point: self._rst[n].append(v) if n in self._end_point and get_all_endnode_val(n, v): return list(endnode2value[i] for 
i in self._end_point) for expr in self._exprs: values = expr.interpret(*list(node2value[i][0] for i in expr.inputs)) for n in expr.inputs: node2value[n][1] -= 1 if node2value[n][1] == 0: node2value.pop(n) if values is not None: assert len(values) == len(expr.outputs) for n, v in zip(expr.outputs, values): if ref_count(n) > 0: node2value[n] = [v, ref_count(n)] if n in self._watch_point: self._rst[n] = v if self._end_point and get_all_endnode_val(n, v): return list(endnode2value[i] for i in self._end_point) return list(node2value[i][0] for i in self._outputs) def eval(self, *inputs: Tuple[Tensor]): r"""Call this method to execute the graph. Args: inputs: the tensors corresponding to the ``graph.inputs[1:]``. """ assert len(inputs) == len(self._inputs) - 1 inp = [self._inputs[0].owner] + list(inputs) return self.interpret(*inp) def __repr__(self): return self.__format__() def __format__(self, format_spec: str = "") -> str: saved_format_spec = Node._set_format_spec(format_spec) name = "" if self._name: name = "%s.Graph" % self._name res = "{} ({}) {{\n\t{}\n\treturn {}\n}}".format( name, ", ".join(str(i) for i in self._inputs), "\n\t".join("{}".format(str(i)) for i in self._exprs), ", ".join(str(i) for i in self._outputs), ) Node._set_format_spec(saved_format_spec) return res def __getstate__(self): state = { "_exprs": self._exprs, "_inputs": self._inputs, "_outputs": self._outputs, "_watch_point": [], "_end_point": [], "_namespace": self._namespace, "_rst": collections.defaultdict(list), "_name": self._name, "_qualname": self._qualname, } if self._total_ids: state["_total_ids"] = self._total_ids _check_obj_attr(state) return state def __setstate__(self, state): old_version = False if "_module_name" in state: old_version = True state["_qualname"] = state.pop("_module_name") prefix_name = state.pop("_prefix_name") if prefix_name: state["_name"] = "{}_{}".format(prefix_name, state["_name"]) self.__dict__.update(state) if old_version: self.inputs[0]._qualname = 
self._qualname for e in self.exprs(False): if isinstance(e, GetAttr): e.outputs[0]._qualname = "{}.{}".format( e.inputs[0]._qualname, e.name ) for n in self.nodes(False): if isinstance(n.expr, CallMethod) and isinstance( n.expr.inputs[0], ModuleNode ): n._qualname = n.expr.inputs[0]._qualname + ".[out]" continue if ( not isinstance(n.expr, GetAttr) and isinstance(n, TensorNode) and n._qualname ): n._qualname = "{}.{}".format(self._qualname, n._qualname) self._namespace = NameSpace(self._name, self._qualname) self._re_associate_name() def __copy__(self): cls = self.__class__ result = cls.__new__(cls) result.__dict__.update(self.__dict__) return result def __deepcopy__(self, memo): with max_recursion_limit(): if id(self) in memo: return memo[id(self)] cls = self.__class__ result = cls.__new__(cls) state = {} memo[id(self)] = result for k, v in self.__dict__.items(): if not isinstance(v, weakref.ReferenceType): state[k] = copy.deepcopy(v, memo) result.__dict__.update(state) return result def _get_meth_name(obj, func): tp = obj if isinstance(obj, type) else type(obj) for cls in tp.mro(): for k, v in cls.__dict__.items(): if v == func: return k return None def _wrapped_function(orig_func): @functools.wraps(orig_func) def wrapped_fn(*args, **kwargs): method_func = kwargs.pop("method_func", wrapped_fn) if not is_tracing_module(): return orig_func(*args, **kwargs) with _exclude_from_trace(): inputs, tree_def = tree_flatten((args, kwargs)) for i in inputs: if not NodeMixin.get(i, None): if isinstance(i, (RawTensor, NodeMixin)): NodeMixin.wrap_safe(i, Constant.make(i)) args, kwargs = _convert_kwargs_to_args(orig_func, args, kwargs) meth_name = _get_meth_name(args[0], method_func) arg_type = args[0] if isinstance(args[0], type) else type(args[0]) if meth_name and arg_type and issubclass(arg_type, RawTensor): inputs, tree_def = tree_flatten((args, kwargs)) self = inputs[0] if meth_name == "__new__": if all([not isinstance(i, RawTensor) for i in inputs]): # only trace 
    def __init__(self, mod, is_top_module=False):
        r"""Wrap ``mod`` so that calling it records Exprs instead of (only) running it.

        Args:
            mod: the :class:`Module` instance being traced.
            is_top_module: whether this builder wraps the root module of the
                trace (the root later receives the global id counters in
                :meth:`build`).
        """
        super(TracedModuleBuilder, self).__init__()
        assert isinstance(mod, Module)
        self._mod = mod
        self._body = None  # InternalGraph, filled lazily in __call__
        self._is_top = is_top_module
        # Observers / fake-quant modules are always treated as builtin leaves;
        # everything else defers to the tracer's builtin registry.
        self._is_builtin = (
            True
            if isinstance(mod, (Observer, _FakeQuantize))
            else module_tracer.is_builtin(mod)
        )
        if isinstance(self._mod, QATModule):
            with _exclude_from_trace():
                # Replace non-builtin observers/fake-quant with traceable
                # TM_FakeQuant equivalents before tracing starts.
                self._check_qat_module(self._mod)
        self._argdef_graph_map = {}
        self._argdef_outdef_map = {}
        self.nodes = set()
        # The builder will be passed to self._mod.forward as 'self' argument.
        # If the 'forward' uses super().xxx to call method of its base classes,
        # the trace procedure will throw exception, because the builder doesn't
        # inherit from self._mod.__bases__.
        # modify self.__class__ and let the builder inherit from
        # TracedModuleBuilder and mod.__class__.
        self.__class__ = type(
            "TracedModuleBuilder",
            (TracedModuleBuilder, mod.__class__),
            dict(TracedModuleBuilder.__dict__),
        )
    def __call__(self, *args, **kwargs):
        r"""Run ``self._mod.forward`` while recording a ``CallMethod`` Expr.

        For builtin modules (or an already-traced arg signature) the module is
        executed outside the trace and recorded as a single opaque call; for
        everything else a fresh ``InternalGraph`` is opened, inputs are bound
        to ``Input`` nodes, and the forward is re-entered so its body is traced
        expr-by-expr.  Returns whatever ``forward`` returns.
        """
        assert isinstance(self._mod, Module)
        # Graph-surgery mode is suspended while re-tracing an existing
        # TracedModule and restored at the end of this call.
        is_graph_surgery_mode = _graph_surgery_mode()
        if isinstance(self._mod, TracedModule) and is_graph_surgery_mode:
            _set_graph_surgery_mode(False)
        # prepare args and kwargs for inner graph
        if "method_func" in kwargs:
            kwargs.pop("method_func")
        args, kwargs = _convert_kwargs_to_args(self._mod.forward, args, kwargs, True)

        def mark_constant(x):
            # Any input without a Node yet is captured as a Constant Expr.
            node = NodeMixin.get(x, None)
            if node is None:  # capture as constant
                NodeMixin.wrap(x, lambda: Constant.make(x))

        inputs, tree_def = tree_flatten(((self, *args), kwargs))
        for i in inputs:
            mark_constant(i)
        callnode = CallMethod.make(NodeMixin.get(self))
        callnode.add_inputs(inputs[1:])
        callnode.arg_def = tree_def
        if self._is_builtin or tree_def in self._argdef_graph_map:
            # Builtin module or known signature: execute untraced and record
            # the whole call as one Expr.
            with _exclude_from_trace():
                rst = self._mod(*args, **kwargs)
                outputs, out_def = tree_flatten(rst, is_leaf=_is_leaf)
                if _get_expr_checker():
                    tmp = self.build()
                    active_module_tracer().checker.check_builtin_module(
                        tmp, callnode, outputs
                    )
            if self._is_builtin:
                self._body = None
            elif tree_def in self._argdef_graph_map:
                self._body = self._argdef_graph_map[tree_def]
        else:
            # New signature on a traceable module: open a sub-graph scope and
            # trace the forward body.
            orig_self = NodeMixin.get(self)
            parent_graph = active_module_tracer().current_scope()
            module_qualname = orig_self._qualname
            self._body = InternalGraph(
                name=parent_graph._namespace.create_unique_name(module_qualname),
                qualname=module_qualname,
            )
            parent_graph._namespace.associate_name_with_obj(self._body.name, self._body)
            active_module_tracer().push_scope(self._body)
            # rebind self to new input node
            NodeMixin.wrap_safe(
                self,
                Input.make(
                    name="self",
                    qualname=module_qualname,
                    type=NodeMixin.get_wrapped_type(self),
                ),
            )
            origin_inp_node = [NodeMixin.get(i, None) for i in inputs[1:]]
            # prepare args and kwargs for inner graph
            # Map each flattened input position back to the forward parameter
            # name via getcallargs on ArgsIndex placeholders.
            index_args, index_kwargs = tree_def.unflatten(
                [
                    ArgsIndex(0),
                    *list(ArgsIndex(i + 1) for i in range(len(origin_inp_node))),
                ]
            )
            key2idx = getcallargs(type(self._mod).forward, *index_args, **index_kwargs)
            idx2key = {}
            for k, v in key2idx.items():
                if isinstance(v, ArgsIndex):
                    idx2key[v.index] = k
                else:
                    # Parameter bound to a container (e.g. *args): suffix the
                    # name with the position inside the container.
                    flatten_argidx, _ = tree_flatten(v)
                    for _i, v in enumerate(flatten_argidx):
                        if isinstance(v, ArgsIndex):
                            idx2key[v.index] = k + "_%d" % _i

            def wrap(x, name):
                # Bind tensor-like inputs of the sub-graph to Input nodes.
                if isinstance(x, (RawTensor, NodeMixin)):
                    NodeMixin.wrap(
                        x,
                        lambda: Input.make(
                            type=NodeMixin.get_wrapped_type(x),
                            name=name,
                            qualname="%s.[%s]" % (module_qualname, name),
                        ),
                    )
                return x

            args = [self]
            # When re-tracing a TracedModule, reuse its original input names.
            orig_traced_inputs = (
                None
                if not isinstance(self._mod, TracedModule)
                else self._mod.argdef_graph_map[tree_def].inputs
            )
            ind = 1
            for v in inputs[1:]:
                if isinstance(v, (RawTensor, NodeMixin)):
                    args_name = (
                        orig_traced_inputs[ind]._name
                        if orig_traced_inputs
                        else idx2key[ind]
                    )
                    ind += 1
                    args.append(wrap(v, args_name))
                else:
                    args.append(v)
            args, kwargs = tree_def.unflatten(args)
            # Patch functions referenced from forward's globals so nested
            # calls are traced too.
            active_module_tracer().patcher.auto_patch(
                getattr(getattr(self._mod, "forward", self._mod), "__globals__", {})
            )
            rst = type(self._mod).forward(*args, **kwargs)
            if _graph_surgery_mode():
                rst = _node_to_tensor(rst)[0][0]
            outputs, out_def = tree_flatten(rst, is_leaf=_is_leaf)
            for i in (
                outputs if isinstance(outputs, collections.abc.Sequence) else (outputs,)
            ):
                mark_constant(i)
                active_module_tracer().current_scope()._add_output(NodeMixin.get(i))
            # Restore the nodes that were rebound for the sub-graph scope.
            NodeMixin.wrap_safe(self, orig_self)
            for arg, node in zip(inputs[1:], origin_inp_node):
                if node:
                    NodeMixin.wrap_safe(arg, node)
            active_module_tracer().pop_scope()
        # rebind output to outer graph
        callnode.out_def = out_def
        callnode.add_outputs(outputs)
        self._argdef_graph_map[callnode.arg_def] = self._body
        self._argdef_outdef_map[callnode.arg_def] = out_def
        _set_graph_surgery_mode(is_graph_surgery_mode)
        return rst
    def __getattribute__(self, name):
        r"""Route attribute access between builder internals and the wrapped module.

        Builder-private attributes (listed in ``__builder_attributes__``) are
        served directly; everything else is fetched, cross-checked against the
        wrapped module's ``__dict__`` to detect unsupported attribute
        mutation, and wrapped in a ``GetAttr`` Expr when it is tensor- or
        node-like.
        """
        if name in TracedModuleBuilder.__builder_attributes__:
            return object.__getattribute__(self, name)
        else:
            wrapped = object.__getattribute__(self, name)
            class_members = dict(inspect.getmembers(self.__class__))
            if name in self._mod.__dict__:
                mod_attr = getattr(self._mod, name)
                if name in class_members:
                    # Attribute was shadowed on the builder but no longer
                    # matches the module's value: re-fetch through __getattr__
                    # so it is re-wrapped consistently.
                    if (
                        not isinstance(wrapped, TracedModuleBuilder)
                        and wrapped is not mod_attr
                    ):
                        wrapped = self.__getattr__(name)
                if isinstance(wrapped, TracedModuleBuilder):
                    # Containers and QAT modules are rebuilt by the builder,
                    # so identity is only enforced for plain sub-modules.
                    if not isinstance(mod_attr, (List, Dict, QATModule)):
                        assert (
                            mod_attr is wrapped._mod
                        ), "TracedModule do not support modify module attributes, please check your code."
                if isinstance(wrapped, RawTensor):
                    assert (
                        mod_attr is wrapped
                    ), "TracedModule do not support modify tensor attributes, please check your code."
            if isinstance(wrapped, (NodeMixin, RawTensor)):
                # Record the access as a GetAttr Expr in the current graph.
                NodeMixin.wrap(
                    wrapped,
                    lambda: GetAttr.make(
                        NodeMixin.get(self),
                        type=NodeMixin.get_wrapped_type(wrapped),
                        attr_name=name,
                        name="",
                    ),
                )
            return wrapped
""" rst = self.as_list() assert len(rst) == 1, "{} elements found".format(len(rst)) (elem,) = self return elem def as_count(self): r"""Consume this iterator and get the number of elements.""" return sum(1 for _ in self) class ExprFilter(BaseFilter): """Filter on Expr iterator. This class is an iterator of :class:`.Expr` objects and multiple filtering conditions and mappers can be chained. """ def call_function(self, func): r"""Filter by specific ``CallFunction.func``. See :meth:`~.InternalGraph.get_function_by_type` for details. """ return ExprFilterCallFunction(self, func) def call_method(self, method): r"""Filter by specific ``CallMethod.method``. See :meth:`~.InternalGraph.get_function_by_type` for details. """ return ExprFilterCallMethod(self, method) def expr_id(self, expr_id: List[int]): r"""Filter Exprs by their ``id``. See :meth:`~.InternalGraph.get_function_by_type` for details. """ return ExprFilterExprId(self, expr_id) class NodeFilter(BaseFilter): """Filter on Node iterator. This class is an iterator of :class:`~.traced_module.Node` objects and multiple filtering conditions and mappers can be chained. """ def type(self, owner_type): r"""Filter by specific Module type. See :meth:`~.InternalGraph.get_module_by_type` for details. """ return NodeFilterType(self, owner_type) def node_id(self, node_id: List[int]): r"""Filter Nodes by their ``id``. See :meth:`~.InternalGraph.get_node_by_id` for details. """ return NodeFilterNodeId(self, node_id) def name(self, name: str, ignorecase: bool = True): r"""Filter Nodes by their full name. See :meth:`~.InternalGraph.get_node_by_name` for details. 
""" return NodeFilterName(self, name, ignorecase) class NodeFilterType(NodeFilter): """See :meth:`~.InternalGraph.get_module_by_type`""" def __init__(self, expr_iter, owner_type): super().__init__(expr_iter) self.owner_type = owner_type def __iter__(self): for node in self._iter: if not isinstance(node, ModuleNode): continue if not hasattr(node, "owner"): continue if isinstance(node.owner, self.owner_type): yield node class NodeFilterNodeId(NodeFilter): """See :meth:`~.InternalGraph.get_node_by_id`""" def __init__(self, expr_iter, node_id: List[int]): super().__init__(expr_iter) if not isinstance(node_id, Sequence): node_id = [node_id] self.node_id = node_id def __iter__(self): for node in self._iter: if node._id in self.node_id: yield node class NodeFilterName(NodeFilter): """See :meth:`~.InternalGraph.get_node_by_name`""" _re = None def __init__(self, node_iter, pattern, ignorecase): super().__init__(node_iter) self.pattern = pattern self._re = self.make_re(pattern, ignorecase) @classmethod def make_re(cls, pattern, ignorecase=True): assert isinstance(pattern, str), "bad pattern: {!r}".format(pattern) assert isinstance(ignorecase, bool) flags = 0 if ignorecase: flags |= re.IGNORECASE return re.compile(fnmatch.translate(pattern), flags=flags) def __iter__(self): for i in self._iter: graph = i.top_graph name = "{}_{}".format(graph._name, i._name) if self.pattern == name or self._re.match(name): yield i class ExprFilterCallFunction(ExprFilter): """See :meth:`~.InternalGraph.get_function_by_type`""" def __init__(self, expr_iter, func: Callable = None): super().__init__(expr_iter) self.func = func def __iter__(self): for expr in self._iter: if not isinstance(expr, CallFunction): continue if self.func is None or expr.func == self.func: yield expr class ExprFilterCallMethod(ExprFilter): """See :meth:`~.InternalGraph.get_method_by_type`""" def __init__(self, expr_iter, method: str = None): super().__init__(expr_iter) self.method = method def __iter__(self): for expr in 
class ExprFilterExprId(ExprFilter):
    """See :meth:`~.InternalGraph.get_expr_by_id`"""

    def __init__(self, expr_iter, expr_id: List[int]):
        super().__init__(expr_iter)
        # Normalize a single id to a one-element list.
        if not isinstance(expr_id, Sequence):
            expr_id = [expr_id]
        self.expr_id = expr_id

    def __iter__(self):
        # Yield only the Exprs whose id is in the requested id list.
        yield from (expr for expr in self._iter if expr._id in self.expr_id)
outputs = self.argdef_graph_map[treedef].interpret(*inputs) if self.watch_points: self.watch_node_value = {} for n in self.watch_points: self.watch_node_value[n] = n.top_graph._rst.pop(n) if self.end_points: return outputs out_def = self.argdef_outdef_map[treedef] outputs = out_def.unflatten(outputs) return outputs def set_watch_points(self, nodes): r"""Initialize the :attr:`~.TracedModule.watch_points`. You can call this function to get the ``Tensor/Module`` corresponding to a ``Node`` at runtime. Args: nodes: a list of ``Node``. For example, the following code .. code-block:: import megengine.module as M import megengine as mge import megengine.traced_module as tm class MyModule(M.Module): def forward(self, x): x = x + 1 + 2 return x net = MyModule() inp = mge.Tensor([0]) traced_module = tm.trace_module(net, inp) add_1_node = traced_module.graph.get_node_by_id(2).as_unique() traced_module.set_watch_points(add_1_node) out = traced_module(inp) Will get the following ``watch_node_value``:: print(traced_module.watch_node_value) .. code-block:: text {add_out: Tensor([1.], device=xpux:0)} """ if not isinstance(nodes, Sequence): nodes = [nodes] self.watch_points = nodes if nodes: nodes[0].top_graph._watch_point = [] for n in nodes: n.top_graph._watch_point.append(n) def clear_watch_points(self): r"""Clear the :attr:`~.TracedModule.watch_points` and :attr:`~.TracedModule.watch_node_value`. """ for n in self.watch_points: n.top_graph._watch_point = [] self.watch_points = [] self.watch_node_value = {} def set_end_points(self, nodes: Sequence[Node]): r"""Initialize the :attr:`~.TracedModule.end_points`. When all the ``nodes`` are generated, the Module will stop execution and return directly. Args: nodes: a list of ``Node``. For example, the following code .. 
code-block:: import megengine.module as M import megengine as mge import megengine.traced_module as tm class MyModule(M.Module): def forward(self, x): x = x + 1 + 2 return x net = MyModule() inp = mge.Tensor([0]) traced_module = tm.trace_module(net, inp) add_1_node = traced_module.graph.get_node_by_id(2).as_unique() traced_module.set_end_points(add_1_node) out = traced_module(inp) Will get the following ``out``:: print(out) .. code-block:: text [Tensor([1.], device=xpux:0)] """ if not isinstance(nodes, Sequence): nodes = [nodes] self.end_points = nodes graphs = list(self.argdef_graph_map.values()) for n in nodes: assert n.top_graph in graphs n.top_graph._end_point.append(n) def clear_end_points(self): r"""Clear the :attr:`~.TracedModule.end_points`. """ for n in self.end_points: n.top_graph._end_point = [] self.end_points = [] @property def graph(self) -> InternalGraph: """Return the ``InternalGraph`` of this ``TracedModule``. """ assert len(self.argdef_graph_map) == 1 return list(self.argdef_graph_map.values())[0] def _update_ref(self, actual_node_map: Union[Dict] = None, top_graph=None): for inp_def, graph in self.argdef_graph_map.items(): if top_graph is not None: graph._top_graph = weakref.ref(top_graph) for n in graph._inputs + graph._outputs: n.expr._top_graph = weakref.ref(graph) n._top_graph = weakref.ref(graph) graph._inputs[0]._owner = weakref.ref(self) for i, n in enumerate(graph._inputs): n.actual_node = [] if actual_node_map is not None and inp_def in actual_node_map.keys(): n.actual_node = list(list(zip(*(actual_node_map[inp_def])))[i]) node2obj = {} next_actual_node_map = collections.defaultdict( lambda: collections.defaultdict(list) ) node2obj[graph._inputs[0]] = self for expr in graph._exprs: for n in expr.inputs + expr.outputs: n._top_graph = weakref.ref(graph) expr._top_graph = weakref.ref(graph) if isinstance(expr, GetAttr) and isinstance( expr.outputs[0], ModuleNode ): obj = expr.interpret(node2obj[expr.inputs[0]])[0] expr.outputs[0]._owner = 
    def flatten(self):
        r"""Get a new TracedModule, which eliminates ``GetAttr`` and has no hierarchy.

        Returns:
            A new :class:`TracedModule`.
        """
        # Work on a deep copy so the original module/graph stay untouched.
        new_module = copy.deepcopy(self)

        def _replace_inputs_and_outputs(expr: Expr, repl_dict: Dict[Node, Node]):
            # Rewrite expr's input/output nodes through repl_dict (sub-graph
            # node -> outer-graph node).
            inputs, outputs = expr.inputs, expr.outputs
            for i, node in enumerate(inputs):
                if node in repl_dict:
                    inputs[i] = repl_dict[node]
            for i, node in enumerate(outputs):
                if node in repl_dict:
                    outputs[i] = repl_dict[node]
                    outputs[i].expr = expr

        def _flatten_subgraph(
            parent_graph: InternalGraph,
            graph: InternalGraph,
            call: CallMethod,
            module: Module,
        ):
            # Inline ``graph`` (invoked via ``call`` on ``module``) into its
            # parent, returning the flattened list of Exprs.
            repl_dict, node2obj, rename_blacklist = {}, {}, []
            if call is not None:
                graph = copy.deepcopy(graph)
                node2obj[call.inputs[0]] = module
                repl_dict = dict(zip(graph._inputs, call.inputs))
                for ind, out in enumerate(graph.outputs):
                    if isinstance(out.expr, Input):
                        # Output is a pass-through of an input: rewire every
                        # user of the call output directly to the input node.
                        assert out in repl_dict
                        call_out = call.outputs[ind]
                        for expr in call.outputs[ind].users:
                            for index, inp in enumerate(expr.inputs):
                                if inp is call_out:
                                    expr.inputs[index] = repl_dict[out]
                                    repl_dict[out].users.append(expr)
                        if parent_graph is not None:
                            for index, parent_out in enumerate(parent_graph._outputs):
                                if parent_out is call_out:
                                    parent_graph._outputs[index] = repl_dict[out]
                        continue
                    repl_dict[out] = call.outputs[ind]
                    if isinstance(out, TensorNode):
                        call.outputs[ind]._qualname = out._qualname
                for node, repl_node in repl_dict.items():
                    assert node in graph._inputs or node in graph._outputs
                    repl_node.users.extend(node.users)
                # Nodes belonging to the call itself keep their names.
                rename_blacklist = list(chain(call.inputs, call.outputs))
            node2obj[graph._inputs[0]] = module
            prefix_name = call.inputs[0]._name if call else ""
            flattened_exprs = []
            for expr in graph._exprs:
                exprs = [expr]
                if call is not None:
                    _replace_inputs_and_outputs(expr, repl_dict)
                if isinstance(expr, GetAttr):
                    # Resolve the owning object so nested CallMethods can be
                    # inlined below; GetAttr itself is eliminated by compile().
                    mnode = expr.inputs[0]
                    node2obj[expr.outputs[0]] = expr.interpret(node2obj[mnode])[0]
                if isinstance(expr, CallMethod):
                    obj_node = expr.inputs[0]
                    if isinstance(obj_node, ModuleNode) and isinstance(
                        obj_node.expr, GetAttr
                    ):
                        obj = node2obj[obj_node]
                        expr_graph = (
                            obj.argdef_graph_map[expr.arg_def]
                            if hasattr(obj, "argdef_graph_map")
                            else None
                        )
                        # Recursively inline traced sub-modules (QAT modules
                        # are kept opaque).
                        if expr_graph is not None and not obj.is_qat:
                            exprs = _flatten_subgraph(graph, expr_graph, expr, obj)
                if parent_graph is not None:
                    # Prefix inlined node names with the callee name and keep
                    # them unique in the parent namespace.
                    for node in expr.outputs:
                        name = node._name
                        if node not in rename_blacklist:
                            name = "{}_{}".format(prefix_name, name)
                        node._name = parent_graph._namespace.create_unique_name(
                            name, node
                        )
                flattened_exprs.extend(exprs)
            if call is not None:
                for i in call.inputs:
                    i.users.remove(call)
            return flattened_exprs

        new_module.graph._exprs = _flatten_subgraph(
            None, new_module.graph, None, new_module
        )
        new_module.graph._re_associate_name()
        new_module.graph.compile()
        new_module._update_ref()
        new_module.graph._reset_ids()
        return new_module
".format(type(d[k])) d[k] = _ModuleState.get_module_state(d[k]) dump_info = { "version": __version__, "register_type": USER_REGISTERED_LEAF_TYPE, "register_container_type": USER_REGISTERED_CONTAINER_TYPE, "register_mdule": USER_REGISTERED_MODULE, "register_function": USER_REGISTERED_FUNCTION, } d["dump_info"] = dump_info return d def __setstate__(self, state): for k, v in state.items(): if isinstance(v, _ModuleState): state[k] = v.to_module() super().__setstate__(state) self._update_ref() for _, graph in self.argdef_graph_map.items(): for expr in graph._exprs: if isinstance(expr, CallFunction): load_functional(expr) if isinstance(expr, CallMethod): if expr.method == "__call__": load_call_module_expr(expr) else: load_call_tensor_method_expr(expr) if isinstance(expr, Apply): load_apply_expr(expr) for _, graph in self.argdef_graph_map.items(): ind = 0 while ind < len(graph._exprs): cur_expr = graph._exprs[ind] has_new_expr = False for i in cur_expr.inputs: if i.expr not in graph._exprs and not isinstance(i.expr, Input): graph._exprs.insert(ind, i.expr) has_new_expr = True if not has_new_expr: ind += 1 for expr in graph._exprs: for i in expr.inputs: if expr.inputs.count(i) != i.users.count(expr): add_or_del_count = expr.inputs.count(i) - i.users.count(expr) if add_or_del_count > 0: i.users.extend([expr] * add_or_del_count) else: [i.users.remove(expr) for i in range(-add_or_del_count)] for o in expr.outputs: if o.expr is not expr: assert o not in o.expr.outputs o.expr = expr for node in graph.nodes(False): # remove users of node which doesn't use node as input node.users = [e for e in node.users if node in e.inputs] for expr in graph._exprs: graph._namespace.auto_naming_for_outputs(expr) self._update_ref() for _, graph in self.argdef_graph_map.items(): graph._reset_ids() def __copy__(self): cls = self.__class__ result = cls.__new__(cls) result.__dict__.update(self.__dict__) return result def __deepcopy__(self, memo): with max_recursion_limit(): cls = self.__class__ 
def wrap(func: Callable):
    r"""Call this function to register ``func`` as a builtin function.

    This function can be called at module-level scope to register ``func`` as a
    builtin function. A builtin function will be converted to a
    :class:`CallFunction` Expr in tracing::

        def my_func(x, y):
            return x + y

        import megengine.traced_module as tm
        tm.wrap(my_func)

    This function can also equivalently be used as a decorator::

        @tm.wrap
        def my_func(x, y):
            return x + y

    Args:
        func: the function of the global function to insert into the graph when
            it's called.

    Returns:
        ``func`` itself, so the call can be used as a decorator.
    """
    USER_REGISTERED_FUNCTION.append((func.__module__, func.__qualname__))
    assert callable(func), "func must be a callable"
    assert hasattr(func, "__code__")
    fn_name = func.__code__.co_name
    currentframe = inspect.currentframe()
    assert currentframe is not None
    f = currentframe.f_back
    assert f is not None
    # Bug fix: the caller-frame check compared co_name against "" which can
    # never be true — CPython names module-level code objects "<module>" —
    # so the assertion always failed even for correct top-level calls.
    assert (
        f.f_code.co_name == "<module>"
    ), "wrap must be called at the top level of a module"
    # Record (caller globals, function name) so the Patcher can swap in the
    # traced wrapper for this name during tracing.
    Patcher._builtin_functions.append((f.f_globals, fn_name))
    return func
""" assert active_module_tracer() is None assert isinstance(mod, Module) try: net_name = mod._name if mod._name else mod.__class__.__name__ use_sym_shape = set_symbolic_shape(True) set_active_module_tracer(module_tracer(_wrapped_function)) set_module_tracing() for cls in [Expr, Node]: cls._set_next_id(0) with active_module_tracer().patcher: global_scope = InternalGraph(name="top", qualname=net_name) active_module_tracer().push_scope(global_scope) builder = TracedModuleBuilder(mod, True) NodeMixin.wrap_safe( builder, Input.make(name="top", type=ModuleNode, qualname=net_name) ) forward_argspec = ( mod.argspec if hasattr(mod, "argspec") else inspect.getfullargspec(mod.forward) ) args, kwargs = _convert_kwargs_to_args(forward_argspec, args, kwargs, True) inputs, _ = tree_flatten((args, kwargs)) for _, i in enumerate(inputs): # assert isinstance(i, Tensor), "not support " if isinstance(i, RawTensor): NodeMixin.wrap_safe( i, Input.make( name="arg_{}".format(_), type=NodeMixin.get_wrapped_type(i), qualname="{}.[{}]".format(net_name, "arg_{}".format(_)), ), ) rst = builder(*copy.deepcopy(args), **copy.deepcopy(kwargs)) active_module_tracer().pop_scope() traced_mod = builder.build() traced_mod.argspec = forward_argspec traced_mod.graph._reset_ids() has_expr_not_check = False if _get_expr_checker(): has_expr_not_check = ( active_module_tracer().checker.check_node_not_in_scope() ) if _get_default_checker() or has_expr_not_check: with _exclude_from_trace(): tm_res = traced_mod(*args, **kwargs) tm_res, _ = tree_flatten(tm_res, is_leaf=_is_leaf) rst, _ = tree_flatten(rst, is_leaf=_is_leaf) active_module_tracer().checker.check_net_outputs(tm_res, rst) return traced_mod finally: set_symbolic_shape(use_sym_shape) unset_module_tracing() for t in mod.tensors(recursive=True): NodeMixin.clear_node(t) for t in inputs: NodeMixin.clear_node(t) set_active_module_tracer(None)