import io
import pickle

import numpy as np

import megengine.functional as F
import megengine.module as M
import megengine.utils.comp_graph_tools as cgtools
from megengine.core._trace_option import set_symbolic_shape
from megengine.jit import trace
from megengine.traced_module import trace_module

# Enabled once at import time and never reset in this module; PreProcess
# unpacks ``data.shape`` while being traced, presumably relying on symbolic
# (tensor-valued) shapes here.
set_symbolic_shape(True)


class Main(M.Module):
    """Identity network; traced and used as the inner sub-module of ``Net``."""

    def forward(self, x):
        return x


class PreProcess(M.Module):
    """Warp an NHWC batch by per-ROI scale/translate matrices; output NCHW float32.

    For every ROI a 3x3 matrix

        [[scale, 0,     xmin],
         [0,     scale, ymin],
         [0,     0,     1   ]]

    is built and passed to ``F.warp_perspective``.
    """

    def __init__(self):
        super().__init__()
        # Template tensors broadcast to the batch size in ``forward``.
        # Attribute names are left unchanged so traced/stored state that
        # refers to them is unaffected.
        self.I = F.ones((1,))
        self.M = F.zeros((1,))

    def forward(self, data, idx, roi):
        """Warp ``data`` with one scale-and-translate matrix per ROI.

        Args:
            data: image batch in NHWC layout (``format="NHWC"`` below).
            idx: forwarded as ``mat_idx``, selecting the ``data`` sample each
                matrix applies to.
            roi: boxes indexed as ``roi[:, corner, axis]``; corner 0 holds
                (xmin, ymin) and corner 1 holds (xmax, ymax).

        Returns:
            The warped batch with the same spatial size ``(H, W)`` as
            ``data``, transposed to NCHW and cast to float32.
        """
        # Channel count is not needed; ``_`` avoids an unused local.
        n, h, w, _ = data.shape
        xmax = roi[:, 1, 0]
        xmin = roi[:, 0, 0]
        ymax = roi[:, 1, 1]
        ymin = roi[:, 0, 1]
        # One isotropic scale per ROI: the larger of the width/height ratios.
        scale = F.maximum((xmax - xmin) / w, (ymax - ymin) / h)

        # Named ``ones``/``mat`` (not ``I``/``M``) so the local does not shadow
        # the ``megengine.module as M`` import.
        ones = F.broadcast_to(self.I, (n,))
        mat = F.broadcast_to(self.M, (n, 3, 3))
        mat[:, 0, 0] = scale
        mat[:, 0, 2] = xmin
        mat[:, 1, 1] = scale
        mat[:, 1, 2] = ymin
        mat[:, 2, 2] = ones

        warped = F.warp_perspective(
            data, mat, (h, w), mat_idx=idx, border_mode="CONSTANT", format="NHWC"
        )
        # NHWC -> NCHW, then cast to float32.
        return warped.transpose(0, 3, 1, 2).astype(np.float32)


class Net(M.Module):
    """Run ``PreProcess`` followed by an already-traced sub-module."""

    def __init__(self, traced_module):
        super().__init__()
        self.pre_process = PreProcess()
        self.traced_module = traced_module

    def forward(self, data, idx, roi):
        x = self.pre_process(data, idx, roi)
        return self.traced_module(x)


def test_preprocess():
    """Check ``Net`` gives matching results across execution paths.

    The paths are: eager execution, ``trace_module``, ``jit.trace``, and
    ``dump`` + ``GraphInference``. The inner network is a ``TracedModule``
    that has been round-tripped through pickle.
    """
    data = F.ones((1, 14, 8, 8), dtype=np.uint8)

    # Trace the identity network and round-trip it through pickle.
    inner = pickle.loads(pickle.dumps(trace_module(Main(), data)))

    net = Net(inner)
    net.eval()
    idx = F.zeros((1,), dtype=np.int32)
    # NOTE(review): an all-ones roi gives xmax - xmin == ymax - ymin == 0, so
    # ``scale`` is 0 and only a degenerate warp matrix is exercised; consider
    # a non-degenerate roi for stronger coverage.
    roi = F.ones((1, 2, 2), dtype=np.float32)
    expected = net(data, idx, roi)

    # 1. The whole network traced into a TracedModule.
    traced_net = trace_module(net, data, idx, roi)
    np.testing.assert_array_equal(traced_net(data, idx, roi), expected)

    # 2. jit.trace of the TracedModule; capture_as_const so it can be dumped.
    func = trace(traced_net, capture_as_const=True)
    np.testing.assert_array_equal(func(data, idx, roi), expected)

    # 3. Dump the traced graph and run it through GraphInference.
    model = io.BytesIO()
    func.dump(model, arg_names=("data", "idx", "roi"))
    model.seek(0)
    infer_cg = cgtools.GraphInference(model)
    outputs = infer_cg.run(
        inp_dict={"data": data.numpy(), "idx": idx.numpy(), "roi": roi.numpy()}
    )
    np.testing.assert_allclose(next(iter(outputs.values())), expected, atol=1e-6)