#!/usr/bin/env python3
"""Build a ResNet-50 network with megskull and dump it to a pkl model file.

Usage: dump_resnet50.py -b BATCH -f NCHW -o OUTPUT.pkl
Only the float NCHW format is supported; NCHW4 is rejected at runtime.
"""

import argparse
import fnmatch
import functools
import re
import sys

import numpy as np

import megbrain as mgb
import megskull as mgsk
import megskull.opr.helper.param_init as pinit
from megskull.graph import FpropEnv, iter_dep_opr
from megskull.network import NetworkVisitor, RawNetworkBuilder
from megskull.opr.all import (
    DataProvider, Conv2D, Pooling2D, FullyConnected, Softmax, Dropout,
    BatchNormalization, CrossEntropyLoss, ElementwiseAffine, WarpPerspective,
    WarpPerspectiveWeightProducer, WeightDecay, ParamProvider,
    ConvBiasActivation, ElemwiseMultiType)
# NOTE: the imports below intentionally come after megskull.opr.all so that
# the cnn/netsrc/arith variants are the ones finally bound (same final
# bindings as the original, duplicated import list).
from megskull.opr.arith import Add, ReLU
from megskull.opr.cnn import (
    Conv2D, Pooling2D, FullyConnected, Softmax, Conv2DImplHelper)
from megskull.opr.compatible.caffepool import CaffePooling2D
from megskull.opr.helper.elemwise_trans import Identity
from megskull.opr.loss import CrossEntropyLoss
from megskull.opr.netsrc import ConstProvider, DataProvider
from megskull.opr.regularizer import Dropout, BatchNormalization
from megskull.utils.debug import CallbackInjector
from megskull.utils.misc import get_2dshape


def create_bn_relu_float(conv_name, f_in, ksize, stride, pad, num_outputs,
                         has_relu, args):
    """Create a float Conv2D (identity nonlinearity), optionally followed by ReLU.

    :param conv_name: name assigned to the Conv2D operator
    :param f_in: input feature node
    :param ksize: convolution kernel size
    :param stride: convolution stride
    :param pad: convolution padding
    :param num_outputs: number of output channels
    :param has_relu: append a ReLU after the convolution if True
    :param args: parsed CLI args (unused here; kept for interface symmetry)
    :return: output feature node
    """
    f = Conv2D(conv_name, f_in, kernel_shape=ksize, stride=stride,
               padding=pad, output_nr_channel=num_outputs,
               nonlinearity=mgsk.opr.helper.elemwise_trans.Identity())
    if has_relu:
        f = ReLU(f)
    return f


def get_num_inputs(feature, format):
    """Return the logical channel count of *feature* for the given format.

    For NCHW4 the channel axis holds groups of 4 packed channels, so the
    logical count is ``shape[1] * 4``.
    """
    if format == 'NCHW':
        return feature.partial_shape[1]
    else:
        assert format == 'NCHW4'
        return feature.partial_shape[1] * 4


def create_bn_relu(prefix, f_in, ksize, stride, pad, num_outputs, has_relu,
                   conv_name_fun, args):
    """Create a conv(+ReLU) unit, deriving the conv name from *prefix*.

    :param conv_name_fun: optional callable mapping *prefix* to the conv
        operator name; when None, *prefix* itself is used.
    """
    if conv_name_fun:
        conv_name = conv_name_fun(prefix)
    else:
        conv_name = prefix
    return create_bn_relu_float(conv_name, f_in, ksize, stride, pad,
                                num_outputs, has_relu, args)


def create_bottleneck(prefix, f_in, stride, num_outputs1, num_outputs2, args,
                      has_proj=False):
    """Create one ResNet bottleneck block (1x1 -> 3x3 -> 1x1 + shortcut).

    :param stride: stride applied in the 3x3 conv (and the projection, if any)
    :param num_outputs1: channel count of the two narrow convs
    :param num_outputs2: channel count of the expanding 1x1 conv / projection
    :param has_proj: use a 1x1 projection shortcut instead of identity
    :return: output feature node after the residual add and final ReLU
    """
    proj = f_in
    if has_proj:
        # Projection shortcut: matches channel count (and stride) of the
        # main branch so the residual add is shape-compatible.
        proj = create_bn_relu(
            prefix, f_in, ksize=1, stride=stride, pad=0,
            num_outputs=num_outputs2, has_relu=False,
            conv_name_fun=lambda p: "interstellar{}_branch1".format(p),
            args=args)

    f = create_bn_relu(
        prefix, f_in, ksize=1, stride=1, pad=0,
        num_outputs=num_outputs1, has_relu=True,
        conv_name_fun=lambda p: "interstellar{}_branch2a".format(p),
        args=args)

    f = create_bn_relu(
        prefix, f, ksize=3, stride=stride, pad=1,
        num_outputs=num_outputs1, has_relu=True,
        conv_name_fun=lambda p: "interstellar{}_branch2b".format(p),
        args=args)

    f = create_bn_relu(
        prefix, f, ksize=1, stride=1, pad=0,
        num_outputs=num_outputs2, has_relu=False,
        conv_name_fun=lambda p: "interstellar{}_branch2c".format(p),
        args=args)

    f = ReLU(f + proj)
    return f


def get(args):
    """Build the full ResNet-50 graph and return it as a RawNetworkBuilder.

    Input is a (batch, 3, 224, 224) data node; output is the 1000-way
    softmax classification head.
    """
    img_size = 224
    num_inputs = 3
    data = DataProvider(
        'data', shape=(args.batch_size, num_inputs, img_size, img_size))
    inp = data

    f = create_bn_relu("conv1", inp, ksize=7, stride=2, pad=3,
                       num_outputs=64, has_relu=True,
                       conv_name_fun=None, args=args)
    f = Pooling2D("pool1", f, window=3, stride=2, padding=1, mode="MAX",
                  format=args.format)

    # Stage layout of ResNet-50: 3/4/6/3 bottlenecks; the first block of
    # each stage (except stage 2) downsamples with stride 2 and uses a
    # projection shortcut.
    pre = [2, 3, 4, 5]
    stages = [3, 4, 6, 3]
    mid_outputs = [64, 128, 256, 512]
    enable_stride = [False, True, True, True]

    for p, s, o, es in zip(pre, stages, mid_outputs, enable_stride):
        for i in range(s):
            has_proj = False if i > 0 else True
            stride = 1 if not es or i > 0 else 2
            prefix = "{}{}".format(p, chr(ord("a") + i))
            f = create_bottleneck(prefix, f, stride, o, o * 4, args, has_proj)
            print("{}\t{}".format(prefix, f.partial_shape))

    f = Pooling2D("pool5", f, window=7, stride=7, padding=0, mode="AVERAGE",
                  format=args.format)
    f = FullyConnected(
        "fc1000", f, output_dim=1000,
        nonlinearity=mgsk.opr.helper.elemwise_trans.Identity())
    f = Softmax("cls_softmax", f)
    f.init_weights()

    net = RawNetworkBuilder(inputs=[data], outputs=[f])
    return net


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='dump pkl model for resnet50',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    # type=int is required: without it a user-supplied "-b N" arrives as a
    # string and produces an invalid DataProvider shape.
    parser.add_argument('-b', '--batch-size',
                        help='batch size of the model', type=int, default=1)
    parser.add_argument('-f', '--format', choices=['NCHW', 'NCHW4'],
                        help='format of conv', default='NCHW')
    parser.add_argument('-o', '--output', help='output pkl path',
                        required=True)
    args = parser.parse_args()

    if args.format != 'NCHW':
        print('Only support NCHW for float model')
        parser.print_help()
        sys.exit(1)

    from meghair.utils import io
    io.dump(get(args), args.output)