"""Tests for ``megengine.traced_module.trace_module``.

Covers tracing of Tensor construction from traced inputs, containers of
submodules (list / OrderedDict), function-valued attributes, the various
positional/keyword calling conventions, re-tracing an already traced module,
and the expressions recorded for a raw ``apply`` call.
"""
from collections import OrderedDict

import numpy as np

import megengine.functional as F
import megengine.module as M
from megengine import Tensor
from megengine.core._imperative_rt.core2 import apply
from megengine.core.ops import builtin
from megengine.module import Module
from megengine.traced_module import TracedModule, enable_expr_checker, trace_module
from megengine.traced_module.expr import Apply, CallFunction, Constant


class MyModule1(M.Module):
    """Builds ``Tensor(x)`` from a traced input, then updates it with ``+=``."""

    def forward(self, x):
        y = Tensor(x)
        y += 1
        x = x + 2
        return x, y


class MyModule2(M.Module):
    """Builds a Tensor from a Python list mixing constants and a traced input."""

    def forward(self, x):
        y = Tensor([1, x, 1])
        y += 1
        x = x + 2
        return x, y


class MyModule3(M.Module):
    """Holds submodules in a plain list that also nests an ``OrderedDict``.

    NOTE(review): the attribute name ``modules`` may shadow a ``Module``
    method of the same name; it is kept because the test inspects
    ``tm3.modules`` after tracing.
    """

    def __init__(self):
        super().__init__()
        self.modules = [
            M.Elemwise("ADD"),
            M.Elemwise("ADD"),
            OrderedDict([("a", M.Elemwise("ADD")), ("b", M.Elemwise("ADD"))]),
            M.Elemwise("RELU"),
            M.Elemwise("RELU"),
        ]

    def forward(self, a, b):
        x = self.modules[0](a, b)
        y = self.modules[1](a, b)
        # Insertion order of the nested OrderedDict must survive tracing.
        assert list(self.modules[2].keys()) == ["a", "b"]
        for _, m in self.modules[2].items():
            y = m(x, y)
        for m in self.modules[3:]:
            y = m(y)
        return y


class MyModule4(M.Module):
    """Stores a plain function (``F.add``) as an attribute and calls it."""

    def __init__(self):
        super().__init__()
        self.add = F.add

    def forward(self, x, y):
        return self.add(x, y)


def _check_add_call_styles(traced, a, b):
    """Assert ``traced`` returns 3 for ``(a, b)`` under every call convention.

    Calls it with both arguments positional, with ``y`` as a keyword, and with
    both ``x`` and ``y`` as keywords (the parameter names of
    ``MyModule4.forward``).
    """
    np.testing.assert_equal(traced(a, b).numpy(), 3)
    np.testing.assert_equal(traced(a, y=b).numpy(), 3)
    np.testing.assert_equal(traced(x=a, y=b).numpy(), 3)


def test_trace_module():
    enable_expr_checker()
    x = Tensor(1)
    m1 = MyModule1()
    tm1 = trace_module(m1, x)

    m2 = MyModule2()
    tm2 = trace_module(m2, x)
    inp = Tensor(2)

    gt = m1(inp)
    output = tm1(inp)
    for traced_out, expected in zip(output, gt):
        np.testing.assert_equal(traced_out.numpy(), expected.numpy())

    gt1 = m2(inp)
    output1 = tm2(inp)
    for traced_out, expected in zip(output1, gt1):
        np.testing.assert_equal(traced_out.numpy(), expected.numpy())

    a, b = Tensor(1), Tensor(2)
    m3 = MyModule3()
    gt = m3(a, b)
    tm3 = trace_module(m3, a, b)
    out = tm3(a, b)
    np.testing.assert_equal(out.numpy(), gt.numpy())
    # List entries become attributes named by index; the nested OrderedDict
    # becomes a TracedModule whose entries are named by key.
    assert isinstance(tm3.modules.__dict__["0"], M.Elemwise)
    assert isinstance(tm3.modules.__dict__["2"], TracedModule)
    assert isinstance(tm3.modules.__dict__["2"].a, M.Elemwise)
    assert isinstance(tm3.modules.__dict__["3"], M.Elemwise)

    # (positional args, keyword args) used when tracing: all positional,
    # mixed, and all keyword.
    trace_conventions = [((a, b), {}), ((a,), {"y": b}), ((), {"x": a, "y": b})]

    m4 = MyModule4()
    for args, kwargs in trace_conventions:
        tm4 = trace_module(m4, *args, **kwargs)
        _check_add_call_styles(tm4, a, b)

    # Re-tracing an already traced module (the last ``tm4``, traced with
    # keywords only) must accept the same conventions.
    for args, kwargs in trace_conventions:
        tm5 = trace_module(tm4, *args, **kwargs)
        _check_add_call_styles(tm5, a, b)

    # The function-valued attribute is recorded as a single CallFunction.
    assert len(tm4.graph._exprs) == 1
    assert isinstance(tm4.graph._exprs[0], CallFunction)

    class MyModule5(Module):
        def __init__(self):
            super().__init__()
            self.m1 = tm4

        def forward(self, x, y):
            return self.m1(x, y)

    # Once nested inside another traced module, tm4 is no longer top-level.
    tm6 = trace_module(MyModule5(), a, b)
    assert tm6.m1.argspec is None
    assert tm6.m1._is_top is False


def test_trace_module_2():
    class Model(M.Module):
        def forward(self, x):
            out = x.shape
            out = apply(builtin.Elemwise(mode="ADD"), out, Tensor(1))
            return out

    traced_model = trace_module(Model(), Tensor([1]))

    # Expected graph: GetVarShape apply, the Tensor(1) constant, Elemwise apply.
    assert isinstance(traced_model.graph._exprs[0], Apply) and isinstance(
        traced_model.graph._exprs[0].opdef, builtin.GetVarShape
    )
    assert isinstance(traced_model.graph._exprs[1], Constant)
    assert isinstance(traced_model.graph._exprs[2], Apply) and isinstance(
        traced_model.graph._exprs[2].opdef, builtin.Elemwise
    )
    assert int(traced_model(Tensor([1, 2]))[0]) == 3