# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

from collections import OrderedDict, defaultdict
from copy import deepcopy
from typing import Any, Dict, List, Set

from ... import functional as F
from ... import module as M
from ...core.ops.builtin import GetVarShape
from ...logger import get_logger
from ...tensor import Parameter, Tensor
from ..expr import (
    Expr,
    is_apply_def,
    is_call_function,
    is_call_module,
    is_call_tensor_method,
    is_constant,
    is_getattr,
)
from ..traced_module import InternalGraph
from ..utils import assign_attr, get_subattr
from .matcher import PatternMatcher
from .pass_base import BackwardPass, register_pass
from .pattern import is_const, is_op, is_var
from .utils import get_const_value

logger = get_logger(__name__)


@register_pass("BackwardFoldScale")
class BackwardFoldScale(BackwardPass):
    r"""Backward fold const scaling into weights of conv2d.

    For example, the following code

    .. code-block::

        x = conv(x, w, b)
        x = relu(x)
        x1 = x + 3
        x2 = x + 4
        y = (x1 + x2) * 3

    will be changed to

    .. code-block::

        x = conv(x, w * 3, b * 3)
        x = relu(x)
        x1 = x + 9
        x2 = x + 12
        y = x1 + x2

    Only scalar scales are folded: a ``Tensor`` constant whose shape is
    neither ``(1,)`` nor ``()`` stops the propagation
    (see :meth:`before_visit_graph`).
    """
    name = "BackwardFoldScale"
    required_pass = ["AttrToConstant", "NormElemWise"]
    run_once = True

    def __init__(self):
        super().__init__()
        # TODO: support more axis
        # Maps an expr to the scale that has to be folded through it, or to
        # ``None`` when nothing may be folded through that expr.
        self.scale_message = OrderedDict()
        # Per-qualname counter used to build unique attribute names when a
        # shared conv module has to be duplicated.
        self.used_names = defaultdict(int)

    def run_transform(self, expr: Expr) -> Expr:
        """Rewrite ``expr`` according to the scale recorded for it.

        Exprs without an entry in ``self.scale_message`` or not matching one
        of the handled patterns (conv, add-const, mul-const/neg) are returned
        unchanged. Otherwise the matching ``fold_*`` helper is applied and its
        result is returned.
        """
        if expr not in self.scale_message:
            return expr

        var = is_var().check_users(False)
        mul_const_pattern = var * is_const() | var * "*" | is_op(F.neg)
        add_const_pattern = var + is_const() | var + "*"
        conv_pattern = is_op(F.conv2d) | is_op(M.Conv2d)
        pattern = conv_pattern | add_const_pattern | mul_const_pattern
        matcher = PatternMatcher()

        if not matcher.match(pattern, expr):
            return expr
        matched_exprs = matcher.matched_exprs

        if conv_pattern in matched_exprs:
            return self.fold_conv_mul(expr)

        if mul_const_pattern in matched_exprs:
            return self.fold_mul(expr)

        if add_const_pattern in matched_exprs:
            return self.fold_add_mul(expr)

        return expr

    def fold_add_mul(self, expr: Expr) -> Expr:
        """Scale the constant operand of an add expr by the recorded scale.

        ``x + c`` is replaced by a newly inserted ``x + c * scale``; the new
        output node takes over the name of the replaced one. Returns the new
        expr, or ``expr`` itself when its recorded scale is ``None``.
        """
        if self.scale_message[expr] is None:
            return expr
        scale = self.scale_message[expr]
        if len(expr.inputs) == 1:
            # The constant operand is stored on the expr, not as an input node.
            const = expr.const_val[0][-1]
        else:
            const = get_const_value(expr.inputs[1])

        const = const * scale
        inp_node = expr.inputs[0]
        graph = expr.top_graph
        with graph.insert_exprs():
            add_node = inp_node + const

        graph.replace_node({expr.outputs[0]: add_node})
        graph.compile()
        add_node.name = expr.outputs[0].name
        return add_node.expr

    def fold_mul(self, expr: Expr) -> Expr:
        """Remove a scaling expr whose factor is folded into its producers.

        All uses of the expr's output are redirected to its first input.
        Returns ``expr`` in every case; nothing is changed when the recorded
        scale is ``None``.
        """
        if self.scale_message[expr] is None:
            return expr
        graph = expr.top_graph
        graph.replace_node({expr.outputs[0]: expr.inputs[0]})
        graph.compile()
        return expr

    def fold_conv_mul(self, expr: Expr) -> Expr:
        """Multiply the weight (and bias, if any) of a conv by the recorded scale.

        * ``F.conv2d`` call: a new ``F.conv2d`` call with scaled constant
          ``weight``/``bias`` is inserted and replaces the old output node.
        * ``M.Conv2d`` call: the module's parameters are scaled. When the
          module node has more than one user, a deep copy of the module is
          scaled instead, attached to the owner of the parent module node
          under a new attribute name, and ``expr`` is rewired to use it.

        Returns the (possibly new) conv expr, or ``expr`` unchanged when the
        recorded scale is ``None``.
        """
        graph = expr.top_graph
        scale = self.scale_message[expr]

        if scale is None:
            return expr

        if is_call_function(expr, F.conv2d):
            named_args = expr.named_args
            weight = get_const_value(named_args["weight"], named_args["weight"]) * scale
            bias = get_const_value(named_args["bias"], named_args["bias"])
            # A conv without bias carries ``None`` here; scaling it would raise.
            if bias is not None:
                bias = bias * scale
            named_args["weight"] = weight
            named_args["bias"] = bias
            with graph.insert_exprs():
                out_node = F.conv2d(**named_args)
            graph.replace_node({expr.outputs[0]: out_node})
            graph.compile()
            out_node.name = expr.outputs[0].name
            return out_node.expr

        mnode = expr.inputs[0]
        attr_name = mnode.expr.name
        # A module node with several users must not be scaled in place: the
        # other users would see the scaled parameters as well.
        is_shared = len(mnode.users) > 1
        if is_shared:
            self.used_names[mnode.qualname] += 1
            attr_name = "{}_{}".format(attr_name, self.used_names[mnode.qualname])
            logger.warning(
                "{} is used {} times and its name will be reset to {}.{}".format(
                    mnode.qualname, len(mnode.users), graph.qualname, attr_name
                )
            )
        conv_module = mnode.owner
        if is_shared:
            conv_module = deepcopy(conv_module)
            conv_module._name = None
        conv_module.weight = Parameter(conv_module.weight * scale)
        if conv_module.bias is not None:
            conv_module.bias = Parameter(conv_module.bias * scale)

        if is_shared:
            # Register the scaled copy on the parent module and make this
            # call use it through a freshly inserted getattr expr.
            self_node = mnode.expr.inputs[0]
            assign_attr(conv_module, self_node.owner, attr_name)
            with graph.insert_exprs(mnode.expr):
                new_conv_node = get_subattr(self_node, attr_name)
            expr.replace_inputs({mnode: new_conv_node})
        return expr

    def reset_expr_message_to_none(
        self, expr: Expr, scale_message: Dict[Expr, Any], skip_exprs: Set[Expr],
    ):
        """Mark ``expr`` and, recursively, its recorded users as not foldable.

        Sets ``scale_message[expr]`` to ``None`` and repeats this for every
        user of ``expr``'s outputs that already has an entry in
        ``scale_message``. Exprs in ``skip_exprs`` are left untouched, and the
        recursion does not continue past a conv expr.
        """
        if expr in skip_exprs:
            return
        scale_message[expr] = None
        if is_call_function(expr, F.conv2d) or is_call_module(expr, M.Conv2d):
            return
        for out_node in expr.outputs:
            for user in out_node.users:
                if user in scale_message:
                    self.reset_expr_message_to_none(user, scale_message, skip_exprs)

    def before_visit_graph(self, graph: InternalGraph):
        """Compute, for every expr of ``graph``, the scale to fold through it.

        The exprs are walked in reverse order of ``graph._exprs`` so that the
        users of an expr are handled before the expr itself. The results are
        merged into ``self.scale_message``.
        """
        var = is_var().check_users(False)
        mul_const_pattern = var * is_const() | var * "*" | is_op(F.neg)
        relu_pattern = (
            is_op(F.relu) | is_op(M.ReLU) | is_op(F.leaky_relu) | is_op(M.LeakyReLU)
        )

        # The param of conv must be const, not support dynamic conv
        conv_pattern = (
            is_op(F.conv2d)(var, is_const(), is_const())
            | is_op(F.conv2d)(var, is_const())
            | is_op(M.Conv2d)
        )

        pattern = mul_const_pattern | relu_pattern | conv_pattern
        # Ops a scale is allowed to be propagated through.
        # NOTE(review): only the relu family is guarded against negative
        # scales below; max/min style ops are not - confirm this is intended.
        for op in [
            "__add__",
            F.reshape,
            "reshape",
            F.transpose,
            "transpose",
            F.min,
            "min",
            F.max,
            "max",
            F.max_pool2d,
            M.MaxPool2d,
            F.avg_pool2d,
            M.AvgPool2d,
            F.adaptive_avg_pool2d,
            M.AdaptiveAvgPool2d,
            F.adaptive_max_pool2d,
            M.AdaptiveMaxPool2d,
            F.expand_dims,
            F.concat,
            "__getitem__",
        ]:
            pattern |= is_op(op)

        matcher = PatternMatcher()

        scale_message = OrderedDict()
        # Conv scales are kept apart and merged last, because the conv's entry
        # in ``scale_message`` is reset to ``None`` when it absorbs a scale.
        mem_conv_scale_message = OrderedDict()
        skip_exprs = self.init_skip_exprs(graph)
        for expr in reversed(graph._exprs):
            if expr in skip_exprs:
                continue

            if len(expr.outputs) > 1 or not matcher.match(pattern, expr):
                self.reset_expr_message_to_none(expr, scale_message, skip_exprs)
                # The reset does not recurse past a conv, so the users of an
                # unmatched functional conv are reset explicitly.
                if is_call_function(expr, F.conv2d):
                    for user in expr.outputs[0].users:
                        self.reset_expr_message_to_none(user, scale_message, skip_exprs)
                continue

            matched_exprs = matcher.matched_exprs

            # Scale contributed by this expr itself (only for mul/neg exprs).
            const = None
            if mul_const_pattern in matched_exprs:
                if is_call_function(expr, F.neg):
                    const = -1
                elif len(expr.inputs) == 1:
                    const = expr.const_val[0][-1]
                else:
                    const = get_const_value(expr.inputs[1])

            # Only scalar scales are supported.
            if isinstance(const, Tensor) and const._tuple_shape not in [(1,), tuple()]:
                self.reset_expr_message_to_none(expr, scale_message, skip_exprs)
                continue

            users_const = [
                scale_message[e] for e in expr.outputs[0].users if e not in skip_exprs
            ]

            if len(users_const) == 0:
                scale_message[expr] = const
                continue

            # All users must agree on one scale, otherwise nothing can be
            # folded through this expr; its own scale may still start a chain.
            if any(c is None or c != users_const[0] for c in users_const):
                self.reset_expr_message_to_none(expr, scale_message, skip_exprs)
                scale_message[expr] = const
                continue

            const = 1 if const is None else const
            const = const * users_const[0]
            # A negative scale is not folded through the relu family.
            if relu_pattern in matched_exprs and const < 0:
                self.reset_expr_message_to_none(expr, scale_message, skip_exprs)
                continue

            if conv_pattern in matched_exprs:
                # The conv absorbs the scale; nothing is propagated further up.
                self.reset_expr_message_to_none(expr, scale_message, skip_exprs)
                mem_conv_scale_message[expr] = const
                continue

            scale_message[expr] = const

        self.scale_message.update(scale_message)
        self.scale_message.update(mem_conv_scale_message)

    def init_skip_exprs(self, graph: InternalGraph) -> Set[Expr]:
        """Collect the exprs of ``graph`` that the scale analysis ignores.

        These are ``GetVarShape`` applications, getattr exprs, constant exprs,
        and every expr whose inputs are all produced by already skipped exprs.
        """
        skip_exprs = set()
        for expr in graph._exprs:
            if is_apply_def(expr, GetVarShape):
                skip_exprs.add(expr)
            # NOTE(review): ``expr`` has not been added at this point, so this
            # branch looks unreachable unless ``graph._exprs`` repeats an
            # expr; possibly the producer of the indexed tensor was meant.
            elif is_call_tensor_method(expr, "__getitem__") and expr in skip_exprs:
                skip_exprs.add(expr)
            elif is_getattr(expr):
                skip_exprs.add(expr)
            elif is_constant(expr):
                skip_exprs.add(expr)
            elif all(n.expr in skip_exprs for n in expr.inputs):
                skip_exprs.add(expr)
        return skip_exprs