#!/usr/bin/env python3 from megskull.network import RawNetworkBuilder import megskull.opr.all as O from megskull.opr.external import TensorRTRuntimeOpr from meghair.utils.io import dump import argparse def str2tuple(x): x = x.split(',') x = [int(a) for a in x] x = tuple(x) return x def make_network(model, isize): data = [O.DataProvider('input{}'.format(i), shape=isizes[i]) for i in range(len(isizes))] f = open(model, 'rb') engine = f.read() opr = TensorRTRuntimeOpr(data, engine, 1) net = RawNetworkBuilder(inputs=[data], outputs=opr.outputs) return net if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument(dest = 'model') parser.add_argument(dest = 'output') parser.add_argument('--isize', help='input sizes. ' 'e.g. for models with two (1,3,224,224) inputs, ' 'the option --isize="1,3,224,224;1,3,224,224" should be used') args = parser.parse_args() isizes = [str2tuple(x) for x in args.isize.split(';')] net = make_network(args.model, isizes) dump(net, args.output)