# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
"""Check that a traced module can be jit-traced, dumped and re-run.

The test builds a small conv/bn/relu network, converts it with
``trace_module``, wraps it in ``jit.trace`` and dumps the captured graph.
It then checks that every stage gives the same result as eager execution.
"""
import io

import numpy as np

import megengine.functional as F
import megengine.module as M
import megengine.utils.comp_graph_tools as cgtools
from megengine.jit import trace
from megengine.module import Module
from megengine.traced_module import trace_module


class MyBlock(Module):
    """A conv -> batch-norm -> relu block with a constant ``+ 1`` offset.

    Args:
        in_channels: number of input channels for the convolution.
        channels: number of output channels (also the BatchNorm width).
    """

    def __init__(self, in_channels, channels):
        super(MyBlock, self).__init__()
        # 3x3 kernel, stride 1, padding 1; bias is disabled.
        self.conv1 = M.Conv2d(in_channels, channels, 3, 1, padding=1, bias=False)
        self.bn1 = M.BatchNorm2d(channels)

    def forward(self, x):
        """Apply conv, batch-norm, relu, then add the constant 1."""
        x = self.conv1(x)
        x = self.bn1(x)
        # The trailing ``+ 1`` puts an elementwise op with a constant
        # operand into the traced graph.
        x = F.relu(x) + 1
        return x


class MyModule(Module):
    """Two chained ``MyBlock``s: 8 -> 4 -> 2 channels."""

    def __init__(self):
        super(MyModule, self).__init__()
        self.block0 = MyBlock(8, 4)
        self.block1 = MyBlock(4, 2)

    def forward(self, x):
        """Run ``x`` through ``block0`` and then ``block1``."""
        x = self.block0(x)
        x = self.block1(x)
        return x


def test_jit_trace():
    """Eager, jit-traced and dumped-graph runs must agree on the same input."""
    module = MyModule()
    # Eval mode is set before tracing so the run uses inference-mode BatchNorm.
    module.eval()
    x = F.ones((1, 8, 14, 14))
    expect = module(x).numpy()

    # TracedModule -> jit.trace. capture_as_const=True captures parameters as
    # constants, as func.dump below requires.
    traced_module = trace_module(module, x)
    func = trace(traced_module, capture_as_const=True)
    np.testing.assert_array_equal(func(x).numpy(), expect)

    # Serialize the captured graph to memory and rewind for reading.
    model = io.BytesIO()
    func.dump(model)
    model.seek(0)

    # Reload the dumped graph and run it on the same input. The graph has a
    # single output, so the first value of the result dict is compared.
    infer_cg = cgtools.GraphInference(model)
    np.testing.assert_allclose(
        list(infer_cg.run(x.numpy()).values())[0], expect, atol=1e-6
    )