# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import os
from contextlib import contextmanager

from ._imperative_rt.core2 import get_option, set_option

# Module-level backing storage for the config properties defined below.
# "async_level" is the exception: it is stored behind get_option/set_option.
__compute_mode = "default"
__conv_format = "default"
_benchmark_kernel = False
_deterministic_kernel = False

__all__ = [
    "benchmark_kernel",
    "deterministic_kernel",
    "async_level",
    "_compute_mode",
    "_conv_format",
    "_override",
]

# NOTE(review): the ``property`` objects below take a ``mod`` argument instead
# of ``self``; presumably they are attached to a module-like class elsewhere so
# that ``mge.config.<name>`` triggers them -- confirm against the importing code.


@property
def benchmark_kernel(mod):
    r"""Whether or not run possible algorithms on real device to find the best one.
    The default option is false, which means use heuristic to choose the fastest
    algorithm.

    Examples:

        .. code-block::

           import megengine as mge

           mge.config.benchmark_kernel = True
    """
    return _benchmark_kernel


@benchmark_kernel.setter
def benchmark_kernel(mod, option: bool):
    global _benchmark_kernel
    _benchmark_kernel = option


@property
def deterministic_kernel(mod):
    r"""Whether or not the fastest algorithm chosen is reproducible. The default
    option is false, which means the algorithm is not reproducible.

    Examples:

        .. code-block::

           import megengine as mge

           mge.config.deterministic_kernel = True
    """
    return _deterministic_kernel


@deterministic_kernel.setter
def deterministic_kernel(mod, option: bool):
    global _deterministic_kernel
    _deterministic_kernel = option


@property
def async_level(mod) -> int:
    r"""Get or set config whether raise error exactly when invoking op. The default
    level is 2, which means both device and user side errors are async.

    Examples:

        .. code-block::

           import megengine as mge

           mge.config.async_level = 2
    """
    return get_option("async_level")


@async_level.setter
def async_level(mod, level: int):
    # Kept as an assert (not a raise) so the exception type seen by existing
    # callers stays AssertionError.
    assert 0 <= level <= 2, "async_level should be 0, 1 or 2"
    set_option("async_level", level)


@property
def _compute_mode(mod):
    r"""Get or set the precision of intermediate results. The default option is
    "default", which means that no special requirements will be placed on. When
    set to 'float32', it would be used for accumulator and intermediate result,
    but only effective when input and output are of float16 dtype.

    Examples:

        .. code-block::

           import megengine as mge

           mge.config._compute_mode = "default"
    """
    return __compute_mode


@_compute_mode.setter
def _compute_mode(mod, _compute_mode: str):
    global __compute_mode
    __compute_mode = _compute_mode


@property
def _conv_format(mod):
    r"""Get or set convolution data/filter/output layout format. The default
    option is "default", which means that no special format will be placed on.
    The available layout definitions are:

    * ``NCHW`` layout: ``{N, C, H, W}``
    * ``NHWC`` layout: ``{N, H, W, C}``
    * ``NHWCD4`` layout: ``{N, H, (C + 3) / 4, W, 4}``
    * ``NHWCD4I`` layout: with ``align_axis = 2``
    * ``NCHW4`` layout: ``{N, C/4, H, W, 4}``
    * ``NCHW88`` layout: ``{N, C/8, H, W, 8}``
    * ``CHWN4`` layout: ``{C/4, H, W, N, 4}``
    * ``NCHW64`` layout: ``{N, C/64, H, W, 64}``

    Examples:

        .. code-block::

           import megengine as mge

           mge.config._conv_format = "NHWC"
    """
    return __conv_format


@_conv_format.setter
def _conv_format(mod, format: str):
    global __conv_format
    __conv_format = format


def _reset_execution_config(
    benchmark_kernel=None,
    deterministic_kernel=None,
    async_level=None,
    compute_mode=None,
    conv_format=None,
):
    r"""Apply the given config values and return the values that were in effect.

    Every argument left as ``None`` keeps its current setting.

    Returns:
        A tuple ``(benchmark_kernel, deterministic_kernel, async_level,
        compute_mode, conv_format)`` holding the settings as they were before
        this call. The tuple can be unpacked back into this function to restore
        them.
    """
    global _benchmark_kernel, _deterministic_kernel, __compute_mode, __conv_format

    # Snapshot first so the caller can always restore the previous state.
    orig_flags = (
        _benchmark_kernel,
        _deterministic_kernel,
        get_option("async_level"),
        __compute_mode,
        __conv_format,
    )

    # set_option() is the only update below that can raise; the rest are plain
    # assignments. Running it first means a failure leaves the module globals
    # untouched instead of half-applied.
    if async_level is not None:
        set_option("async_level", async_level)

    if benchmark_kernel is not None:
        _benchmark_kernel = benchmark_kernel
    if deterministic_kernel is not None:
        _deterministic_kernel = deterministic_kernel
    if compute_mode is not None:
        __compute_mode = compute_mode
    if conv_format is not None:
        __conv_format = conv_format

    return orig_flags


@contextmanager
def _override(
    benchmark_kernel=None,
    deterministic_kernel=None,
    async_level=None,
    compute_mode=None,
    conv_format=None,
):
    r"""A context manager that users can opt in by attaching the decorator to set
    the config of the global variable. The previous values are restored when the
    context exits, whether normally or through an exception.

    Examples:

        .. code-block::

           import megengine as mge

           @mge.config._override(
               benchmark_kernel=True,
               deterministic_kernel=False,
               async_level=2,
               compute_mode="float32",
               conv_format="NHWC",
           )
           def train():
               ...
    """
    orig_flags = _reset_execution_config(
        benchmark_kernel, deterministic_kernel, async_level, compute_mode, conv_format,
    )
    try:
        yield
    finally:
        # recover the previous values
        _reset_execution_config(*orig_flags)


def _get_actual_op_param(function_param, config_param):
    r"""Return ``config_param`` unless it equals ``"default"``, in which case
    return ``function_param``.
    """
    return function_param if config_param == "default" else config_param