# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") # # Copyright (c) 2014-2021 Megvii Inc. All rights reserved. # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. import abc from enum import Enum from functools import partial, update_wrapper, wraps from typing import Union import numpy as np from .. import functional as F from ..autodiff import Function from ..core._imperative_rt.core2 import apply from ..core.ops import builtin from ..core.tensor.dtype import ( QuantDtypeMeta, _builtin_quant_dtypes, create_quantized_dtype, ) from ..tensor import Tensor class Round(Function): r"""The functional round have no grad and can not use for quantization-aware-training. We use Function and STE(Straight-Through Estimator) to implement backward propagation. """ def forward(self, x): return F.round(x) def backward(self, output_grads): return output_grads def tqt_forward(qmin, qmax, inp, scale): op = builtin.TQT(qmin=qmin, qmax=qmax) (output,) = apply(op, inp, scale) return output def lsq_forward(qmin, qmax, inp, step_size, zero_point=None, scale_grad=None): if zero_point is None: zero_point = Tensor([0.0], dtype=np.float32) if scale_grad is None: scale_grad = Tensor([1.0], dtype=np.float32) op = builtin.LSQ(qmin=qmin, qmax=qmax) (output,) = apply(op, inp, step_size, zero_point, scale_grad) return output def register_method_to_class(cls): def decorator(func): @wraps(func) def wrapper(self, *args, **kwargs): return func(self, *args, **kwargs) if isinstance(func, partial): update_wrapper(func, func.func) setattr(cls, func.__name__, wrapper) return func return decorator class QuantMode(Enum): r"""Quantization mode enumerate class.""" SYMMERTIC = 1 ASYMMERTIC = 2 class QParams: r"""To standardize FakeQuant, Observer and Tensor's qparams format. If custom qparams is needed, inherit this class and add custom ``__slots__``. """ __slots__ = "mode", "dtype_meta", "scale", "zero_point" def __init__( self, mode: QuantMode, dtype_meta: QuantDtypeMeta, scale: Tensor, zero_point: Tensor, ): self.mode = mode self.dtype_meta = dtype_meta self.scale = scale self.zero_point = zero_point def update(self, qparams: "QParams"): for key in self.__slots__: setattr(self, key, getattr(qparams, key)) def __eq__(self, other): if len(self.__slots__) != len(other.__slots__): return False for key in self.__slots__: if not hasattr(other, key) or getattr(self, key) != getattr(other, key): return False return True def __repr__(self): content = ", ".join( ["{}={}".format(key, getattr(self, key)) for key in self.__slots__] ) return "QParams({})".format(content) class LSQParams(QParams): r"""LSQ qparams with extra grad_scale slot.""" __slots__ = "mode", "dtype_meta", "scale", "zero_point", "grad_scale" def __init__( self, mode: QuantMode, dtype_meta: QuantDtypeMeta, scale: Tensor, zero_point: Tensor, grad_scale: Tensor, ): super().__init__(mode, dtype_meta, scale, zero_point) self.grad_scale = grad_scale class QParamsModuleMixin(abc.ABC): def get_quantized_dtype(self): qparams = self.get_qparams() dtype = qparams.dtype_meta scale = float(qparams.scale.numpy()) if qparams.scale is not None else None zero_point = ( int(qparams.zero_point.numpy()) if qparams.zero_point is not None else None ) return create_quantized_dtype(dtype, scale, zero_point) @abc.abstractmethod def get_qparams(self) -> QParams: pass _builtin_qparams = { QuantMode.SYMMERTIC: partial(QParams, mode=QuantMode.SYMMERTIC), QuantMode.ASYMMERTIC: partial(QParams, mode=QuantMode.ASYMMERTIC), } def create_qparams( mode: QuantMode = QuantMode.SYMMERTIC, dtype_meta: Union[str, QuantDtypeMeta] = None, scale: Tensor = None, zero_point: Tensor = None, ): r""" Args: mode: QuantMode: dtype_meta: Union[str: QuantDtypeMeta]: scale: Tensor: zero_point: Tensor: """ if isinstance(dtype_meta, str): dtype_meta = _builtin_quant_dtypes[dtype_meta] if mode is None: return QParams(mode, dtype_meta, scale, zero_point) assert isinstance(mode, QuantMode) return _builtin_qparams[mode]( dtype_meta=dtype_meta, scale=scale, zero_point=zero_point ) def fake_quant_tensor(inp: Tensor, qparams: QParams) -> Tensor: """Apply fake quantization to the inp tensor. Args: inp: the input tensor which need to be faked. qparams: to get mode, qmin, qmax, scale and zero_point from. """ scale = qparams.scale if qparams.mode == QuantMode.ASYMMERTIC: zero_point = qparams.zero_point else: zero_point = Tensor([0.0], dtype=np.float32) qmin = qparams.dtype_meta.qmin qmax = qparams.dtype_meta.qmax op = builtin.FakeQuant(qmin=qmin, qmax=qmax) return apply(op, inp, scale, zero_point)[0] def fake_quant_bias(bias: Tensor, inp: Tensor, w_qat: Tensor) -> Tensor: """Apply fake quantization to bias, with the special scale from input tensor and weight tensor, the quantized type set to qint32 also. Args: bias: the bias tensor which need to be faked. inp: the input tensor which contain the quantization parameters. w_qat: the weight tensor which contain the quantization parameters. Warning: Only work for symmetric quantization method now. """ b_qat = bias if ( getattr(inp, "qparams", None) is not None and getattr(w_qat, "qparams", None) is not None and bias is not None ): inp_params = inp.qparams w_params = w_qat.qparams if inp_params.scale is not None and w_params.scale is not None: assert inp_params.mode == w_params.mode, "incompatible QuantMode" # TODO: support quint8 dtype. assert ( inp_params.dtype_meta.np_dtype_str == "int8" and w_params.dtype_meta.np_dtype_str == "int8" ), "fake_quant_bias only support int8 like dtype now" # use the same mode with weight. # TODO: avoid hardcode b_dtype = _builtin_quant_dtypes["qint32"] b_param = create_qparams( w_params.mode, b_dtype, scale=inp_params.scale * w_params.scale ) b_qat = fake_quant_tensor(bias, b_param) b_qat.qparams.update(b_param) return b_qat