# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") # # Copyright (c) 2014-2021 Megvii Inc. All rights reserved. # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. from copy import deepcopy from typing import List, Set from ...logger import get_logger from ..traced_module import TracedModule from .pass_base import get_default_pass_context, get_registered_pass logger = get_logger(__name__) def optimize( module: TracedModule, enabled_pass: List[str] = ["FuseConvBn"], ) -> TracedModule: r"""Performs a set of optimization passes to optimize a `TracedModule` for inference. The following passes are currently supported: * FuseConvBn: fuse BN layers into to conv2d * FuseAddMul: fold adjacent const add or mul binary operations * BackwardFoldScale: backward fold const scaling into weights of conv2d Args: module: the :class:`TracedModule` to be optimized. enabled_pass: optimization passes to be enabled during optimization. Default: ["FuseConvBn"] Returns: the optimized :class:`TracedModule`. """ defalut_passes_list = [ "FuseConvBn", "FuseAddMul", ] if isinstance(enabled_pass, str): enabled_pass = [enabled_pass] if "BackwardFoldScale" in enabled_pass: if "FuseConvBn" not in enabled_pass: logger.warning( "Since BackwardFoldScale requires FuseConvBn" ", FuseConvBn will be enabled." ) enabled_pass.append("FuseConvBn") defalut_passes_list.extend( ["BackwardFoldScale", "FuseAddMul",] ) pass_ctx = get_default_pass_context() def run_pass(mod: TracedModule): for pass_name in defalut_passes_list: if pass_name in enabled_pass: pass_func = get_registered_pass(pass_name)() mod = pass_func(mod, pass_ctx) return mod module = deepcopy(module) module = run_pass(module) return module