# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") # # Copyright (c) 2014-2021 Megvii Inc. All rights reserved. # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. import operator from collections import defaultdict from typing import Any, Callable, List from ... import functional as F from ... import module as M from ...logger import get_logger from ...tensor import Parameter, Tensor from ...utils.bn_fusion import fold_weight_bias from ..expr import Expr, is_call_function from ..utils import assign_attr, get_subattr from .matcher import PatternMatcher from .pass_base import BackwardPass, register_pass from .pattern import ExprPattern, any_node, is_const, is_op, is_var from .utils import get_const_value, register_obj logger = get_logger(__name__) @register_pass("FuseAddMul") class FuseAddMul(BackwardPass): """Fold adjacent const add or mul binary operations. For example, the following code .. code-block:: x = x + 1 x = 2 + x x = x * 4 x = x * 0.25 will be changed to .. code-block:: x = x + 3 """ name = "FuseAddMul" required_pass = ["NormElemWise"] run_once = False def __init__(self,): super().__init__() def _make_pattern(op_0, op_1) -> ExprPattern: x = is_var().check_users(False) if op_0 not in [operator.add, operator.mul]: op_0 = is_op(op_0) if op_1 not in [operator.add, operator.mul]: op_1 = is_op(op_1) pattern = op_0(x, is_const()) | op_0(x, "*") pattern = op_1(pattern, is_const()) | op_1(pattern, "*") return pattern self.pattern_dict = {} for op, func in zip([operator.add, F.pow], [self.fold_add, self.fold_pow],): self.pattern_dict[_make_pattern(op, op)] = func for op_0 in [F.neg, operator.mul]: for op_1 in [F.neg, operator.mul]: self.pattern_dict[_make_pattern(op_0, op_1)] = self.fold_mul def run_transform(self, expr: Expr): matcher = PatternMatcher() for pattern, func in self.pattern_dict.items(): res = matcher.match(pattern, expr) if res: break if not res: return expr return func(expr) def _fold_helper(self, expr: Expr, op_c: Callable, op_t: Callable): const_0 = self.get_const_value(expr) # todo: support more shape if isinstance(const_0, Tensor) and const_0._tuple_shape not in [(1,), tuple()]: return expr const_1 = self.get_const_value(expr.inputs[0].expr) if isinstance(const_1, Tensor) and const_1._tuple_shape not in [(1,), tuple()]: return expr inp_node = expr.inputs[0].expr.inputs[0] const = op_c(const_0, const_1) graph = expr.top_graph if (const == 1 and op_t in [operator.pow, operator.mul]) or ( const == 0 and op_t in [operator.add] ): graph.replace_node({expr.outputs[0]: inp_node}) graph.compile() return expr with expr.top_graph.insert_exprs(): out_node = op_t(inp_node, const) graph.replace_node({expr.outputs[0]: out_node}) graph.compile() return out_node.expr def fold_add(self, expr: Expr): return self._fold_helper(expr, operator.add, operator.add) def fold_mul(self, expr): return self._fold_helper(expr, operator.mul, operator.mul) def fold_pow(self, expr): return self._fold_helper(expr, operator.mul, F.pow) def get_const_value(self, expr: Expr): if is_call_function(expr, F.neg): return -1 if len(expr.inputs) == 2: value = get_const_value(expr.inputs[1].expr, None) assert value is not None, " " return value value = expr.const_val[0][-1] return value @register_pass("FuseConvBn") class FuseConvBn(BackwardPass): r"""Fuse BN layers into conv2d.""" name = "FuseConvBn" required_pass = ["AttrToConstant"] run_once = True def __init__(self): super().__init__() self.used_name = defaultdict(int) def run_transform(self, expr: Expr): conv_pat_0 = is_op(M.Conv2d) conv_pat_1 = is_op(F.conv2d) bn_pat_0 = is_op(M.BatchNorm2d)(conv_pat_0 | conv_pat_1) bn_pat_1 = is_op(F.batch_norm) # inp, running_mean, running_var, weight, bias bn_inps = ( conv_pat_0 | conv_pat_1, is_const(), is_const(), is_const(), is_const(), ) bn_pat = ( (bn_pat_1(*bn_inps[:3])) | (bn_pat_1(*bn_inps[:4])) | (bn_pat_1(*bn_inps)) | bn_pat_0 ) matcher = PatternMatcher() if not matcher.match(bn_pat, expr): return expr matched_exprs = matcher.matched_exprs if conv_pat_0 in matched_exprs: return self.fold_convm_bn(matched_exprs[conv_pat_0], matched_exprs[bn_pat]) else: return self.fold_convf_bn(matched_exprs[conv_pat_1], matched_exprs[bn_pat]) def fold_convm_bn(self, conv: Expr, bn: Expr): mnode, inp_node = conv.inputs[:2] self_node = mnode.expr.inputs[0] attr_name = conv.inputs[0].expr.name graph = conv.top_graph if len(mnode.users) > 1: self.used_name[mnode.qualname] += 1 attr_name = "{}_{}".format(attr_name, self.used_name[mnode.qualname]) logger.warning( "{} is used {} times and its name will be reset to {}.{}".format( mnode.qualname, len(mnode.users), graph.qualname, attr_name ) ) conv_module = mnode.owner weight, bias = conv_module.weight, conv_module.bias mean, var, gamma, beta, eps = self.get_bn_params(bn) weight, bias = fold_weight_bias(weight, bias, gamma, beta, mean, var, eps) new_conv = M.Conv2d( in_channels=conv_module.in_channels, out_channels=conv_module.out_channels, kernel_size=conv_module.kernel_size, stride=conv_module.stride, padding=conv_module.padding, dilation=conv_module.dilation, groups=conv_module.groups, bias=conv_module.bias is not None, conv_mode=conv_module.conv_mode, compute_mode=conv_module.compute_mode, name=conv_module.name, ) new_conv.weight = Parameter(weight) new_conv.bias = Parameter(bias) new_conv.training = conv_module.training assign_attr(new_conv, self_node.owner, attr_name) with graph.insert_exprs(mnode.expr): out_node = get_subattr(self_node, attr_name)(inp_node) graph.replace_node({bn.outputs[0]: out_node}) graph.compile() out_node.name = conv.outputs[0].name return out_node.expr def fold_convf_bn(self, conv: Expr, bn: Expr): named_args = conv.named_args weight = get_const_value(named_args["weight"], named_args["weight"]) bias = get_const_value(named_args["bias"], named_args["bias"]) mean, var, gamma, beta, eps = self.get_bn_params(bn) weight, bias = fold_weight_bias(weight, bias, gamma, beta, mean, var, eps) named_args["weight"] = weight named_args["bias"] = bias graph = conv.top_graph with graph.insert_exprs(): out_node = F.conv2d(**named_args) graph.replace_node({bn.outputs[0]: out_node}) graph.compile() out_node.name = conv.outputs[0].name return out_node.expr def get_bn_params(self, bn: Expr): if is_call_function(bn): named_args = bn.named_args mean = get_const_value( named_args["running_mean"], named_args["running_mean"] ) var = get_const_value(named_args["running_var"], named_args["running_var"]) gamma = get_const_value(named_args["weight"], named_args["weight"]) beta = get_const_value(named_args["bias"], named_args["bias"]) eps = named_args["eps"] return mean, var, gamma, beta, eps else: bn_module = bn.inputs[0].owner mean = bn_module.running_mean var = bn_module.running_var gamma = bn_module.weight beta = bn_module.bias eps = bn_module.eps return mean, var, gamma, beta, eps