# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from copy import deepcopy
from typing import Union

from ..core.tensor.dtype import QuantDtypeMeta
from ..quantization.fake_quant import QParamsModuleMixin, _FakeQuantize
from ..quantization.utils import QParams, QuantMode, fake_quant_tensor


class FakeQuantize(_FakeQuantize, QParamsModuleMixin):
    r"""A module to do quant and dequant according to :attr:`~.FakeQuantize.qparams`."""

    def __init__(
        self, dtype: Union[str, QuantDtypeMeta], enable: bool = True, **kwargs
    ):
        super().__init__(dtype, enable, **kwargs)
        # Stays ``None`` until :meth:`set_qparams` is called.
        self.qparams = None

    def _check_dtype_meta(self, qparams: QParams):
        r"""Raise ``AssertionError`` if ``qparams.dtype_meta`` is not ``self.dtype``.

        An explicit ``raise`` is used instead of an ``assert`` statement so the
        check is not stripped when Python runs with ``-O``.
        """
        if qparams.dtype_meta is not self.dtype:
            raise AssertionError(
                "input qparams' dtype is not equal to self.dtype.\n"
                "qparams.dtype_meta={}\nself.dtype={}".format(
                    qparams.dtype_meta, self.dtype
                )
            )

    def fake_quant_forward(self, inp, qparams: QParams = None):
        r"""Fake-quantize ``inp`` with ``qparams``, or with the stored qparams.

        Args:
            inp: input forwarded to ``fake_quant_tensor``.
            qparams: quantization parameters to use. When ``None``, the value
                returned by :meth:`get_qparams` is used instead.

        Returns:
            the result of ``fake_quant_tensor(inp, qparams)``.

        Raises:
            AssertionError: if ``qparams.dtype_meta`` is not ``self.dtype``.
        """
        if qparams is None:
            # NOTE: if set_qparams was never called this is still None, and the
            # dtype check below fails with an AttributeError on ``None``.
            qparams = self.get_qparams()
        self._check_dtype_meta(qparams)
        return fake_quant_tensor(inp, qparams)

    def get_qparams(self):
        r"""Return the stored :attr:`~.FakeQuantize.qparams` (``None`` if unset)."""
        return self.qparams

    def set_qparams(self, qparams: QParams):
        r"""Initialize :attr:`~.FakeQuantize.qparams`.

        Args:
            qparams: used to set initial ``scale`` and ``zero_point``. If its
                ``dtype_meta`` is ``None`` it is filled in with ``self.dtype``
                (the passed object is modified in place); otherwise it must be
                ``self.dtype``.

        Raises:
            AssertionError: if ``qparams.scale`` is ``None``, or if
                ``qparams.dtype_meta`` is set but is not ``self.dtype``.
        """
        if qparams.scale is None:
            raise AssertionError("Can not get an initialized scale")
        if qparams.dtype_meta is None:
            # Kept as an in-place update for backward compatibility: callers may
            # rely on the passed qparams having its dtype filled in afterwards.
            qparams.dtype_meta = self.dtype
        else:
            self._check_dtype_meta(qparams)
        # Store a fresh QParams built from the caller's fields rather than
        # keeping a reference to the caller's object.
        self.qparams = QParams(
            qparams.mode, qparams.dtype_meta, qparams.scale, qparams.zero_point
        )