# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import numpy as np

from ..functional import matmul, relu
from ..tensor import Parameter
from . import init
from .module import Module


class BatchMatMulActivation(Module):
    r"""Batched :func:`~.matmul` with activation(only :func:`~.relu` supported), no transpose anywhere.

    The module owns a ``weight`` parameter of shape
    ``(batch, out_features, in_features)`` and, optionally, a ``bias``
    parameter of shape ``(out_features,)``. ``forward(x)`` computes
    ``matmul(weight, x)``, adds the bias when there is one, and applies
    :func:`~.relu` when ``nonlinear_mode`` is ``"relu"``.

    Args:
        batch: size of the leading (batch) dimension of ``weight``.
        in_features: size of the last dimension of ``weight``; also used as
            the fan-in when initialising the weight.
        out_features: size of the middle dimension of ``weight`` and length
            of ``bias``.
        bias: whether to create a ``bias`` parameter. When ``False``,
            ``self.bias`` is ``None``.
        nonlinear_mode: name of the activation, compared case-insensitively.
            Only ``"relu"`` triggers an activation; any other value leaves
            the result unchanged.
        **kwargs: forwarded unchanged to :class:`Module`'s constructor.
    """

    def __init__(
        self,
        batch: int,
        in_features: int,
        out_features: int,
        bias: bool = True,
        nonlinear_mode: str = "identity",
        **kwargs
    ):
        super().__init__(**kwargs)
        self.batch = batch
        self.out_features = out_features
        self.in_features = in_features

        # Parameters are allocated as zeros here; reset_parameters() below
        # overwrites them with their real initial values.
        w_shape = (batch, out_features, in_features)
        self.weight = Parameter(np.zeros(w_shape, dtype=np.float32))

        self.bias = None
        if bias:
            # NOTE(review): a 1-D bias of length out_features is added to the
            # matmul result with ``+=`` and therefore relies on broadcasting;
            # confirm it lines up with the intended axis of the output.
            b_shape = (out_features,)
            self.bias = Parameter(np.zeros(b_shape, dtype=np.float32))

        # Normalised once so later comparisons are case-insensitive.
        self.nonlinear_mode = nonlinear_mode.lower()
        self.reset_parameters()

    def _get_fanin(self):
        """Return the fan-in used to scale the weight initialisation."""
        return self.in_features

    def reset_parameters(self) -> None:
        """Re-initialise ``weight`` and ``bias`` in place.

        ``weight`` is filled through ``init.normal_`` with ``0.0`` and
        ``sqrt(1 / fan_in)``; ``bias``, when present, is zeroed.
        """
        fanin = self._get_fanin()
        std = np.sqrt(1 / fanin)
        init.normal_(self.weight, 0.0, std)
        if self.bias is not None:
            init.zeros_(self.bias)

    def _calc_linear(self, x, weight, bias):
        """Compute ``matmul(weight, x)``, then the optional bias and activation.

        Args:
            x: right-hand operand of the batched matmul.
                Presumably shaped ``(batch, in_features, n)`` -- TODO confirm.
            weight: left-hand operand of the batched matmul.
            bias: value added to the product, or ``None`` to skip the addition.

        Returns:
            The matmul result, plus ``bias`` when given, passed through
            :func:`~.relu` when ``self.nonlinear_mode`` is ``"relu"``.
        """
        res = matmul(weight, x)
        # Guard on the argument rather than ``self.bias``: the bias that is
        # actually added is the one passed in, so a caller that substitutes
        # its own bias (or ``None``) must be honoured.
        if bias is not None:
            res += bias
        if self.nonlinear_mode == "relu":
            res = relu(res)
        return res

    def forward(self, x):
        """Apply the module to ``x`` using its own ``weight`` and ``bias``."""
        return self._calc_linear(x, self.weight, self.bias)

    def _module_info_string(self) -> str:
        """Return a short summary of the module's configuration."""
        return "batch={}, in_features={}, out_features={}, bias={}".format(
            self.batch, self.in_features, self.out_features, self.bias is not None
        )