# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import types

import numpy as np
import pytest

import megengine as mge
import megengine.functional as F
import megengine.module as M
import megengine.traced_module as tm


class myconv(M.Conv2d):
    """Trivial ``M.Conv2d`` subclass with no overrides.

    Parametrized alongside ``M.Conv2d`` in the tests, presumably to check that
    the optimization passes also recognise user subclasses of Conv2d.
    """


class mybn(M.BatchNorm2d):
    """Trivial ``M.BatchNorm2d`` subclass with no overrides.

    Parametrized alongside ``M.BatchNorm2d`` in the tests, presumably to check
    that the optimization passes also recognise user subclasses of BatchNorm2d.
    """


class MyBlock(M.Module):
    """Two parallel conv -> bn -> relu branches, each scaled by a constant.

    The first branch is multiplied by ``self.scale[0]`` and the second by
    ``self.scale[1]``. The branches are summed, shifted by ``4`` and by
    ``self.scale[0]``, passed through a final relu and multiplied by 3.

    The multiplications give ``BackwardFoldScale`` something to fold. The
    conv/bn pairs give ``FuseConvBn`` something to fuse.

    Args:
        conv_cls: convolution module class to instantiate (e.g. ``M.Conv2d``).
        bn_cls: batch-norm module class to instantiate (e.g. ``M.BatchNorm2d``).
    """

    def __init__(self, conv_cls, bn_cls):
        super().__init__()
        self.conv = conv_cls(3, 3, 1, 1, 0)
        self.bn = bn_cls(3)
        self.conv2 = conv_cls(3, 3, 1, 1, 0)
        self.bn2 = bn_cls(3)
        # Plain (non-Parameter) tensor whose elements scale the two branches.
        self.scale = mge.Tensor([3, 4])

    def forward(self, x):
        x1 = self.conv(x)
        x1 = self.bn(x1)
        x1 = F.relu(x1)
        x1 = x1 * self.scale[0]

        x2 = self.conv2(x)
        x2 = self.bn2(x2)
        x2 = F.relu(x2)
        x2 = x2 * self.scale[1]

        y = x1 + x2
        y = y + 4
        y = self.scale[0] + y
        y = F.relu(y) * 3
        return y


class MyModule(M.Module):
    """Two independent ``MyBlock`` instances applied to the same input.

    Their outputs are summed, flattened to 1-D and multiplied by 3.

    Args:
        conv_cls: convolution module class forwarded to each ``MyBlock``.
        bn_cls: batch-norm module class forwarded to each ``MyBlock``.
    """

    def __init__(self, conv_cls, bn_cls):
        super().__init__()
        self.block_0 = MyBlock(conv_cls, bn_cls)
        self.block_1 = MyBlock(conv_cls, bn_cls)

    def forward(self, x):
        x1 = self.block_0(x)
        x2 = self.block_1(x)
        y = x1 + x2
        # Flatten to 1-D; ``(-1,)`` is a one-element shape tuple.
        y = F.reshape(y, (-1,))
        y = y * 3
        return y


@pytest.mark.parametrize("conv_cls", [M.Conv2d, myconv])
@pytest.mark.parametrize("bn_cls", [M.BatchNorm2d, mybn])
def test_backward_fold_scale(conv_cls, bn_cls):
    """``BackwardFoldScale`` must preserve outputs and remove every mul.

    The module is traced, flattened and optimized with the pass. The optimized
    output must match the eager module's output within ``atol=1e-4``. The
    optimized graph must contain no ``__mul__`` method calls.
    """
    module = MyModule(conv_cls, bn_cls)
    module.eval()
    inp = mge.Tensor(np.random.random((1, 3, 32, 32)))
    desired = module(inp)

    traced_net = tm.trace_module(module, inp)
    traced_net = traced_net.flatten()
    optimized_net = tm.optimize(traced_net, "BackwardFoldScale")

    actual = optimized_net(inp)
    np.testing.assert_allclose(desired=desired, actual=actual, atol=1e-4)
    # Every multiplication should have been folded into the convs.
    mul_list = optimized_net.graph.get_method_by_type("__mul__").as_list()
    assert len(mul_list) == 0
@pytest.mark.parametrize("conv_cls", [M.Conv2d, myconv])
@pytest.mark.parametrize("bn_cls", [M.BatchNorm2d, mybn])
def test_fuse_bn(conv_cls, bn_cls):
    """``FuseConvBn`` must preserve outputs and remove every BatchNorm.

    The module is traced, flattened and optimized with the pass. The optimized
    output must match the eager module's output within ``atol=1e-4``. The
    optimized graph must contain neither ``F.batch_norm`` function calls nor
    ``M.BatchNorm2d`` module calls.
    """
    module = MyModule(conv_cls, bn_cls)
    module.eval()
    inp = mge.Tensor(np.random.random((1, 3, 32, 32)))
    desired = module(inp)

    traced_net = tm.trace_module(module, inp)
    traced_net = traced_net.flatten()
    optimized_net = tm.optimize(traced_net, "FuseConvBn")

    actual = optimized_net(inp)
    np.testing.assert_allclose(desired=desired, actual=actual, atol=1e-4)
    # Every bn should have been fused into its preceding conv: neither the
    # functional form nor the module form of batch norm may remain.
    bn_func_list = optimized_net.graph.get_function_by_type(F.batch_norm).as_list()
    assert len(bn_func_list) == 0
    bn_module_list = optimized_net.graph.get_module_by_type(M.BatchNorm2d).as_list()
    assert len(bn_module_list) == 0