# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
"""Pattern objects describing expression sub-graphs to be matched.

Patterns are created with the helpers at the bottom of this module
(``is_op``, ``is_var``, ``is_const``, ``any_node``) and combined by calling
them, applying Python operators to them, indexing them, or joining two
patterns with ``|``.
"""
from abc import abstractmethod
from typing import Any, Callable, Dict, List, Type

from ...core._imperative_rt import OpDef
from ...logger import get_logger
from ...module import Module
from ..expr import Expr
from ..node import Node

logger = get_logger(__name__)


class ExprPattern:
    """Base class of all patterns.

    Calling a pattern, applying an arithmetic operator to it, or indexing it
    does not compute anything. It builds a new :class:`CallPattern` whose
    ``op`` is the callee pattern. ``a | b`` builds an :class:`OrPattern`.
    """

    def __init__(self):
        # Whether users of the matched node should be checked. The flag is
        # only stored here; it is consumed outside this module.
        self._check_users = True
        # Patterns registered through ``_add_users``.
        self._users = []

    def __call__(self, *args):
        """Return a :class:`CallPattern` that calls ``self`` with ``args``.

        Only arguments that are :class:`ExprPattern` instances are kept
        (see :class:`CallPattern`). ``pattern(None)`` yields a call pattern
        with no argument patterns.
        """
        if len(args) == 1 and args[0] is None:
            # Unpacking ``None`` with ``*`` raises TypeError; an empty tuple
            # expresses "no argument patterns" explicitly.
            args = ()
        return CallPattern(self, *args)

    # Operator overloads: each builds a call pattern of the Tensor method with
    # the same dunder name. Non-pattern operands (e.g. Python scalars) are
    # dropped by CallPattern, so ``p + 1`` only constrains ``p``.
    def __add__(self, other):
        return is_op("__add__")(self, other)

    def __iadd__(self, other):
        return is_op("__iadd__")(self, other)

    def __radd__(self, other):
        return is_op("__radd__")(self, other)

    def __sub__(self, other):
        return is_op("__sub__")(self, other)

    def __isub__(self, other):
        return is_op("__isub__")(self, other)

    def __rsub__(self, other):
        return is_op("__rsub__")(self, other)

    def __mul__(self, other):
        return is_op("__mul__")(self, other)

    def __imul__(self, other):
        return is_op("__imul__")(self, other)

    def __rmul__(self, other):
        return is_op("__rmul__")(self, other)

    def __truediv__(self, other):
        return is_op("__truediv__")(self, other)

    def __itruediv__(self, other):
        return is_op("__itruediv__")(self, other)

    def __rtruediv__(self, other):
        return is_op("__rtruediv__")(self, other)

    def __or__(self, other):
        """Return an :class:`OrPattern` with ``self`` and ``other`` as alternatives."""
        assert isinstance(other, ExprPattern)
        return OrPattern(self, other)

    def get_output(self, index):
        raise NotImplementedError

    def check_users(self, check: bool = True):
        """Set the ``_check_users`` flag and return ``self`` for chaining."""
        self._check_users = check
        return self

    def _add_users(self, pattern: "ExprPattern"):
        self._users.append(pattern)

    def _clear_users(self):
        self._users.clear()

    def __getitem__(self, index):
        # ``index`` is normally not a pattern and is therefore dropped by
        # CallPattern; only ``self`` is kept as an argument pattern.
        return is_op("__getitem__")(self, index)

    def has_attr(self, **attrs):
        """Attribute constraints are only supported by :class:`ModulePattern`.

        On other patterns this logs a warning and returns ``self`` unchanged.
        """
        logger.warning("has_attr only support ModulePattern")
        return self

    def has_param(self, **params):
        """Parameter constraints are only supported by :class:`FunctionPattern`.

        On other patterns this logs a warning and returns ``self`` unchanged.
        """
        logger.warning("has_param only support FunctionPattern")
        return self

    @abstractmethod
    def __repr__(self) -> str:
        raise NotImplementedError


class CallPattern(ExprPattern):
    """Pattern for a call of ``op`` whose arguments match ``args``."""

    def __init__(self, op: ExprPattern, *args: Any):
        super().__init__()
        self.op = op
        # Non-pattern arguments (constants, indices, None) are discarded;
        # only sub-patterns are stored.
        self.args = [arg for arg in args if isinstance(arg, ExprPattern)]
        self._match_all_args = True

    def __repr__(self) -> str:
        return "{}({})".format(self.op, ",".join(str(x) for x in self.args))

    def not_all_args(self):
        """Clear the ``_match_all_args`` flag and return ``self``.

        Presumably this lets the matcher accept calls whose arguments are
        not all covered by ``args``. The flag is consumed outside this
        module, so confirm the exact semantics there.
        """
        self._match_all_args = False
        return self

    def check_users(self, check: bool = True):
        """Set ``_check_users`` on this pattern and on ``op``; return ``self``."""
        self._check_users = check
        self.op.check_users(check)
        return self

    def _add_users(self, pattern: "ExprPattern"):
        self._users.append(pattern)
        self.op._add_users(pattern)

    def _clear_users(self):
        self._users.clear()
        self.op._clear_users()


class OrPattern(ExprPattern):
    """Alternative of two patterns, built by ``left | right``."""

    def __init__(self, left: ExprPattern, right: ExprPattern):
        super().__init__()
        self.left = left
        self.right = right

    def __repr__(self) -> str:
        return "({}|{})".format(self.left, self.right)

    def check_users(self, check: bool = True):
        """Set ``_check_users`` here and on both alternatives; return ``self``."""
        self._check_users = check
        self.left.check_users(check)
        self.right.check_users(check)
        return self

    def _clear_users(self):
        self._users.clear()
        self.left._clear_users()
        self.right._clear_users()


class GetOutputPaterrn(ExprPattern):
    """Pattern selecting output ``index`` of ``op``."""

    def __init__(self, op, index):
        super().__init__()
        self.op = op
        self.index = index

    def __repr__(self) -> str:
        return "{}[{}]".format(self.op, self.index)


# Correctly spelled alias; the original (misspelled) name is kept so existing
# imports keep working.
GetOutputPattern = GetOutputPaterrn


class ModulePattern(ExprPattern):
    """Pattern for a call of a :class:`Module` subclass ``module_cls``."""

    def __init__(self, module_cls: Type[Module]) -> None:
        super().__init__()
        self.attrs = {}
        self.target = module_cls

    def has_attr(self, **attrs):
        """Record attribute constraints (merged into ``self.attrs``); return ``self``."""
        self.attrs.update(attrs)
        return self

    def __repr__(self) -> str:
        return "{}".format(self.target.__name__)


class FunctionPattern(ExprPattern):
    """Pattern for a call of the plain callable ``func``."""

    def __init__(self, func: Callable):
        super().__init__()
        self.params = {}
        self.target = func

    def has_param(self, **params):
        """Record parameter constraints (merged into ``self.params``); return ``self``.

        This overrides :meth:`ExprPattern.has_param`, which otherwise only
        logs a warning.
        """
        self.params.update(params)
        return self

    def has_params(self, **params):
        """Backward-compatible alias of :meth:`has_param`."""
        return self.has_param(**params)

    def __repr__(self) -> str:
        # Not every callable has ``__name__`` (e.g. functools.partial).
        return "{}".format(getattr(self.target, "__name__", repr(self.target)))


class TensorMethodPattern(ExprPattern):
    """Pattern for a Tensor method identified by its name (e.g. ``"__add__"``)."""

    def __init__(self, method: str):
        super().__init__()
        self.target = method

    def __repr__(self) -> str:
        return self.target


class ApplyDefPattern(ExprPattern):
    """Pattern for an :class:`OpDef` subclass ``opdef``."""

    def __init__(self, opdef: OpDef):
        super().__init__()
        self.target = opdef

    def __repr__(self) -> str:
        return "{}".format(self.target.__name__)


class VarPattern(ExprPattern):
    """Placeholder pattern created by :func:`is_var`."""

    def __repr__(self) -> str:
        return "var"


class ConstantPattern(ExprPattern):
    """Placeholder pattern created by :func:`is_const`."""

    def __repr__(self) -> str:
        return "const"


class AnyPattern(ExprPattern):
    """Wildcard pattern created by :func:`any_node`."""

    def __repr__(self) -> str:
        return "any"


def is_op(target):
    """Build the leaf pattern that matches ``target``.

    * a :class:`Module` subclass -> :class:`ModulePattern`
    * an :class:`OpDef` subclass -> :class:`ApplyDefPattern`
    * any other callable (not a class) -> :class:`FunctionPattern`
    * a string (Tensor method name) -> :class:`TensorMethodPattern`

    Raises:
        ValueError: for any other ``target``, including classes that are
            neither Module nor OpDef subclasses (which previously made this
            function silently return ``None``).
    """
    if isinstance(target, type):
        if issubclass(target, Module):
            return ModulePattern(target)
        if issubclass(target, OpDef):
            return ApplyDefPattern(target)
        raise ValueError("not support")
    if callable(target):
        return FunctionPattern(target)
    if isinstance(target, str):
        return TensorMethodPattern(target)
    raise ValueError("not support")


def is_const():
    """Return a :class:`ConstantPattern` with user checking disabled."""
    return ConstantPattern().check_users(False)


def any_node():
    """Return a wildcard :class:`AnyPattern`."""
    return AnyPattern()


def is_var():
    """Return a :class:`VarPattern`."""
    return VarPattern()