# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import time

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as TF
from tabulate import tabulate

import megengine as mge
import megengine.functional as MF
import megengine.module as MM

# Modules are constructed once and cached here, so parameter creation and
# initialization are excluded from the timed region.
module_cache = {
    "conv2d": (MM.Conv2d(32, 32, 3, 1, 0), nn.Conv2d(32, 32, 3, 1, 0).cuda()),
    "dw_conv2d": (
        MM.Conv2d(32, 32, 3, 1, 0, groups=32),
        nn.Conv2d(32, 32, 3, 1, 0, groups=32).cuda(),
    ),
    "conv3d": (MM.Conv3d(32, 32, 3, 1, 0), nn.Conv3d(32, 32, 3, 1, 0).cuda()),
    "ConvTranspose2d": (
        MM.ConvTranspose2d(32, 32, 3, 1, 0),
        nn.ConvTranspose2d(32, 32, 3, 1, 0).cuda(),
    ),
    "BatchNorm2d": (MM.BatchNorm2d(64), nn.BatchNorm2d(64).cuda()),
    "Linear": (MM.Linear(1000, 1000), nn.Linear(1000, 1000).cuda()),
}

test_cases = [
    # (name, mge op, torch op, small input shapes, large input shapes, unpack_inps, repetitions)
    (
        "adaptive_avg_pool2d",
        lambda x: MF.adaptive_avg_pool2d(x, (7, 7)),
        lambda x: TF.adaptive_avg_pool2d(x, (7, 7)),
        [(2, 32, 16, 16)],
        [(64, 512, 16, 16)],
        True,
        1000,
    ),
    (
        "adaptive_max_pool2d",
        lambda x: MF.adaptive_max_pool2d(x, (7, 7)),
        lambda x: TF.adaptive_max_pool2d(x, (7, 7)),
        [(2, 32, 16, 16)],
        [(64, 512, 16, 16)],
        True,
        1000,
    ),
    ("argsort", MF.argsort, torch.argsort, [(1000,)], [(1000, 1000)], True, 1000),
    (
        "avg_pool2d",
        lambda x: MF.avg_pool2d(x, 2),
        lambda x: TF.avg_pool2d(x, 2),
        [(2, 32, 16, 16)],
        [(64, 512, 16, 16)],
        True,
        1000,
    ),
    (
        "broadcast",
        lambda x: MF.broadcast_to(x, (5,) + x.shape),
        lambda x: torch.broadcast_to(x, (5,) + x.shape),
        [(100, 100)],
        [(64, 512, 16, 16)],
        True,
        1000,
    ),
    (
        "batchedmatmul",
        MF.matmul,
        torch.matmul,
        [(8, 64, 32), (8, 32, 64)],
        [(8, 2048, 512), (8, 512, 2048)],
        True,
        1000,
    ),
    (
        "batchnorm2d",
        lambda x: module_cache["BatchNorm2d"][0](x),
        lambda x: module_cache["BatchNorm2d"][1](x),
        [(2, 64, 16, 16)],
        [(64, 64, 128, 128)],
        True,
        1000,
    ),
    (
        "concat",
        MF.concat,
        torch.cat,
        [(20, 100), (50, 100), (30, 100)],
        [(64, 512, 16, 16), (64, 512, 16, 16), (64, 512, 16, 16)],
        False,
        1000,
    ),
    (
        "conv2d",
        lambda x: module_cache["conv2d"][0](x),
        lambda x: module_cache["conv2d"][1](x),
        [(2, 32, 16, 16)],
        [(32, 32, 128, 128)],
        True,
        1000,
    ),
    (
        "conv3d",
        lambda x: module_cache["conv3d"][0](x),
        lambda x: module_cache["conv3d"][1](x),
        [(2, 32, 8, 8, 8)],
        [(32, 32, 16, 16, 16)],
        True,
        1000,
    ),
    (
        "convTranspose2d",
        lambda x: module_cache["ConvTranspose2d"][0](x),
        lambda x: module_cache["ConvTranspose2d"][1](x),
        [(2, 32, 16, 16)],
        [(32, 32, 128, 128)],
        True,
        1000,
    ),
    (
        "dropout",
        lambda x: MF.dropout(x, 0.5),
        TF.dropout,
        [(100, 100)],
        [(64, 512, 16, 16)],
        True,
        1000,
    ),
    (
        "dw_conv2d",
        lambda x: module_cache["dw_conv2d"][0](x),
        lambda x: module_cache["dw_conv2d"][1](x),
        [(2, 32, 16, 16)],
        [(32, 32, 128, 128)],
        True,
        1000,
    ),
    (
        "elemwise.unary",
        MF.log,
        torch.log,
        [(100, 100)],
        [(64, 512, 16, 16)],
        True,
        1000,
    ),
    (
        "elemwise.binary",
        MF.add,
        torch.add,
        [(100, 100), (100, 100)],
        [(64, 512, 16, 16), (64, 512, 16, 16)],
        True,
        1000,
    ),
    (
        "expand_dims",
        lambda x: MF.expand_dims(x, 0),
        lambda x: torch.unsqueeze(x, 0),
        [(100, 100)],
        [(64, 512, 16, 16)],
        True,
        1000,
    ),
    ("gelu", MF.gelu, TF.gelu, [(100, 100)], [(64, 512, 16, 16)], True, 1000),
    ("hswish", MF.hswish, TF.hardswish, [(100, 100)], [(64, 512, 16, 16)], True, 1000),
    (
        "hsigmoid",
        MF.hsigmoid,
        TF.hardsigmoid,
        [(100, 100)],
        [(64, 512, 16, 16)],
        True,
        1000,
    ),
    ("isinf", MF.isinf, torch.isinf, [(100, 100)], [(64, 512, 16, 16)], True, 1000),
    (
        "indexingMultiAxisVec",
        lambda x: x[[1, 3, 5], [1, 3, 5], [1, 3, 5], [1, 3, 5]],
        lambda x: x[[1, 3, 5], [1, 3, 5], [1, 3, 5], [1, 3, 5]],
        [(10, 10, 10, 10)],
        [(64, 512, 16, 16)],
        True,
        1000,
    ),
    (
        "logsigmoid",
        MF.logsigmoid,
        TF.logsigmoid,
        [(100, 100)],
        [(64, 512, 16, 16)],
        True,
        1000,
    ),
    (
        "leaky_relu",
        lambda x: MF.leaky_relu(x, 0.5),
        lambda x: TF.leaky_relu(x, 0.5),
        [(100, 100)],
        [(64, 512, 16, 16)],
        True,
        1000,
    ),
    (
        "linear",
        lambda x: module_cache["Linear"][0](x),
        lambda x: module_cache["Linear"][1](x),
        [(10, 1000)],
        [(64, 128, 1000)],
        True,
        1000,
    ),
    ("matinv", MF.matinv, torch.inverse, [(10, 10)], [(30, 30)], True, 1000),
    (
        "matmul",
        MF.matmul,
        torch.matmul,
        [(64, 32), (32, 64)],
        [(2048, 1024), (1024, 2048)],
        True,
        1000,
    ),
    (
        "max_pool2d",
        lambda x: MF.max_pool2d(x, 2),
        lambda x: TF.max_pool2d(x, 2),
        [(2, 32, 16, 16)],
        [(64, 512, 16, 16)],
        True,
        1000,
    ),
    (
        "normal",
        lambda x: mge.random.normal(0, 1, x.shape),
        lambda x: torch.randn(x.shape, device="cuda"),
        [(100, 100)],
        [(64, 512, 16, 16)],
        True,
        1000,
    ),
    (
        "prelu",
        MF.prelu,
        TF.prelu,
        [(100, 100), (1,)],
        [(64, 512, 16, 16), (1,)],
        True,
        1000,
    ),
    (
        "reduce.max",
        lambda x: MF.max(x, 0),
        lambda x: torch.max(x, 0),
        [(100, 100)],
        [(64, 512, 16, 16)],
        True,
        1000,
    ),
    (
        "reduce.mean",
        lambda x: MF.mean(x, 0),
        lambda x: torch.mean(x, 0),
        [(100, 100)],
        [(64, 512, 16, 16)],
        True,
        1000,
    ),
    ("relu", MF.relu, TF.relu, [(100, 100)], [(64, 512, 16, 16)], True, 1000),
    ("relu6", MF.relu6, TF.relu6, [(100, 100)], [(64, 512, 16, 16)], True, 1000),
    (
        "repeat",
        lambda x: MF.repeat(x, 5),
        lambda x: torch.repeat_interleave(x, 5),
        [(100, 100)],
        [(64, 512, 16, 16)],
        True,
        1000,
    ),
    ("silu", MF.silu, TF.silu, [(100, 100)], [(64, 512, 16, 16)], True, 1000),
    (
        "split",
        lambda x: MF.split(x, 5),
        lambda x: torch.split(x, 5),
        [(100, 100)],
        [(64, 512, 16, 16)],
        True,
        1000,
    ),
    ("sigmoid", MF.sigmoid, TF.sigmoid, [(100, 100)], [(64, 512, 16, 16)], True, 1000),
    (
        "softmax",
        lambda x: MF.softmax(x, axis=1),
        lambda x: TF.softmax(x, dim=1),
        [(100, 100)],
        [(64, 512, 16, 16)],
        True,
        1000,
    ),
    (
        "softplus",
        MF.softplus,
        TF.softplus,
        [(100, 100)],
        [(64, 512, 16, 16)],
        True,
        1000,
    ),
    (
        "squeeze",
        lambda x: MF.squeeze(x, 0),
        lambda x: torch.squeeze(x, 0),
        [(1, 100, 100)],
        [(1, 64, 512, 16, 16)],
        True,
        1000,
    ),
    (
        "stack",
        MF.stack,
        torch.stack,
        [(100, 100), (100, 100)],
        [(64, 512, 16, 16), (64, 512, 16, 16)],
        False,
        10000,
    ),
    (
        "subtensor",
        lambda x: x[0:20, 10:60],
        lambda x: x[0:20, 10:60],
        [(100, 100)],
        [(64, 512, 16, 16)],
        True,
        1000,
    ),
    (
        "topk",
        lambda x: MF.topk(x, 10),
        lambda x: torch.topk(x, 10),
        [(100, 100)],
        [(1000, 1000)],
        True,
        1000,
    ),
    (
        "tile",
        lambda x: MF.tile(x, (2,) * len(x.shape)),
        lambda x: torch.tile(x, (2,) * len(x.shape)),
        [(100, 100)],
        [(64, 512, 16, 16)],
        True,
        1000,
    ),
    (
        "transpose",
        lambda x: MF.transpose(x, list(range(len(x.shape)))[::-1]),
        lambda x: torch.permute(x, list(range(len(x.shape)))[::-1]),
        [(100, 100)],
        [(64, 512, 16, 16)],
        True,
        1000,
    ),
    (
        "where",
        lambda x: MF.where(x > 0.5, x, x),
        lambda x: torch.where(x > 0.5, x, x),
        [(100, 100)],
        [(64, 512, 16, 16)],
        True,
        1000,
    ),
    (
        "uniform",
        lambda x: mge.random.uniform(0, 1, x.shape),
        lambda x: torch.rand(x.shape, device="cuda"),
        [(100, 100)],
        [(64, 512, 16, 16)],
        True,
        1000,
    ),
]
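
# To benchmark an additional op, append a 7-tuple in the same format. The
# entry below is an illustrative sketch, not part of the measured table:
#
#     (
#         "tanh",
#         MF.tanh,
#         torch.tanh,
#         [(100, 100)],          # small input shapes
#         [(64, 512, 16, 16)],   # large input shapes
#         True,                  # unpack inputs when calling the op
#         1000,                  # repetitions per timing run
#     ),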


def perf_func(func, inps, reps, unpack_inps, is_mge):
    # Time ``reps`` consecutive calls of ``func``. The device is synchronized
    # before the clock starts and after the last call, so asynchronous kernel
    # launches are fully included in the measurement.
    if is_mge:
        mge._full_sync()
        tik = time.time()
        for _ in range(reps):
            if unpack_inps:
                out = func(*inps)
            else:
                out = func(inps)
        mge._full_sync()
    else:
        torch.cuda.synchronize()
        with torch.no_grad():
            tik = time.time()
            for _ in range(reps):
                if unpack_inps:
                    out = func(*inps)
                else:
                    out = func(inps)
        torch.cuda.synchronize()
    return time.time() - tik


def get_avg_time(func, inps, reps, unpack_inps, is_mge):
    # warm up
    for _ in range(2):
        t = perf_func(func, inps, reps, unpack_inps, is_mge)

    times = []
    for _ in range(5):
        t = perf_func(func, inps, reps, unpack_inps, is_mge)
        times.append(t)

    return np.mean(times)


def get_perf_results(mge_func, torch_func, shapes, unpack_inps, reps):
    # Inputs are generated once and reused for both frameworks, so the two
    # measurements run on identical data.
    inps = [np.random.randn(*shape) for shape in shapes]

    inps_mge = [mge.tensor(inp, dtype="float32") for inp in inps]
    avg_time_mge = get_avg_time(mge_func, inps_mge, reps, unpack_inps, True)

    inps_torch = [torch.Tensor(inp).type(torch.float).cuda() for inp in inps]
    avg_time_torch = get_avg_time(torch_func, inps_torch, reps, unpack_inps, False)

    return avg_time_mge, avg_time_torch


if __name__ == "__main__":
    header = [
        "opr_name",
        "time(mge/pytorch; small input)",
        "time(mge/pytorch; large input)",
    ]
    table = []
    for case in test_cases:
        assert len(case) == 7
        name, mge_func, torch_func, small_shapes, large_shapes, unpack_inps, reps = case
        data = []
        data.append(name)
        print("========== op: {}".format(name))

        avg_time_mge, avg_time_torch = get_perf_results(
            mge_func, torch_func, small_shapes, unpack_inps, reps
        )
        print("mge time: {}".format(avg_time_mge))
        print("torch time: {}".format(avg_time_torch))
        data.append("{:.2f}".format(avg_time_mge / avg_time_torch))

        avg_time_mge, avg_time_torch = get_perf_results(
            mge_func, torch_func, large_shapes, unpack_inps, reps
        )
        print("mge time: {}".format(avg_time_mge))
        print("torch time: {}".format(avg_time_torch))
        data.append("{:.2f}".format(avg_time_mge / avg_time_torch))

        table.append(data)
    print(tabulate(table, header, tablefmt="github"))
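
# Each table cell is the MegEngine/PyTorch runtime ratio for one op; values
# below 1.0 mean MegEngine was faster on that input size. The helpers above
# can also be used for a quick single-op check, e.g. (illustrative sketch):
#
#     mge_t, torch_t = get_perf_results(
#         MF.relu, TF.relu, [(64, 512, 16, 16)], unpack_inps=True, reps=100
#     )
#     print(mge_t / torch_t)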