#! /usr/bin/env python3
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
"""Command-line tool that analyzes a JSON profile dump produced by MegEngine
and prints summary / per-operator timing tables.

Run with ``--help`` for usage.
"""
import argparse
import collections
import json
import re
import textwrap

import numpy as np
from tabulate import tabulate

from megengine.utils.profile_analyzer import (
    NonExistNum,
    ProfileAnalyzer,
    TimeFuncHelper,
)


def _tabulate_ml(tab, **kwargs):
    r"""Tabulate profile output with multi-line support.

    Cells may contain embedded newlines; each logical row is expanded into
    several physical rows (one per text line), and the horizontal separators
    that :func:`tabulate` draws between those continuation rows are blanked
    so each logical row still reads as one visual unit.

    :param tab: sequence of rows; each cell is stringified and may be multi-line.
    :param kwargs: forwarded to :func:`tabulate.tabulate`.
    :return: formatted table string.
    """
    new_tab = []
    new_tab_is_row = []  # True for the first physical row of each logical row
    for row in tab:
        col_lines = [str(i).split("\n") for i in row]
        max_nr_line = max(map(len, col_lines))
        new_tab_is_row.append(True)
        if max_nr_line > 1:
            new_tab_is_row.extend([False] * (max_nr_line - 1))
            # pad every column to the same number of lines so zip() is rectangular
            for i in col_lines:
                if len(i) < max_nr_line:
                    i.extend([""] * (max_nr_line - len(i)))
            new_tab.extend(zip(*col_lines))
        else:
            new_tab.append(row)
    assert len(new_tab_is_row) == len(new_tab)
    ret = [i + "\n" for i in tabulate(new_tab, **kwargs).split("\n")]
    # tabulate emits "header, separator, row, separator, row, ..."; physical
    # row k sits at output line 2*k+2, and the separator directly above a
    # continuation row is erased.
    for idx, val in enumerate(new_tab_is_row):
        if not val:
            ret[idx * 2 + 2] = ""
    return "".join(ret)[:-1]


def _tabulate_confluence(tab, **kwargs):
    r"""Tabulate profile output in confluence-markdown-compatible form.

    Renders with the ``orgtbl`` format and rewrites the header separator so
    confluence recognizes the first row as a header.
    """
    kwargs.pop("tablefmt", None)
    s = tabulate(tab, tablefmt="orgtbl", **kwargs)
    lines = s.split("\n")
    lines[1] = lines[1].replace("+", "|")
    return "\n".join(lines)


def main(passed_args=None):  # pylint: disable=too-many-statements
    r"""Analyses profile info from :mod:`~.utils.profile_analyzer` .

    Run this file with ``--help`` to get more usage.

    :param passed_args: optional argv list (defaults to ``sys.argv[1:]``);
        useful for calling programmatically / from tests.
    """
    parser = argparse.ArgumentParser(
        description="analyze analyzer result",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument("dump")
    parser.add_argument(
        "-t",
        "--top",
        type=int,
        default=3,
        help="number of most time-consuming operators to print",
    )
    parser.add_argument(
        "--type", action="append", help="filter oprs in the top list by type"
    )
    parser.add_argument(
        "--aggregate-by",
        default=None,
        choices=["type"],
        help="aggregate profiling result by",
    )
    parser.add_argument(
        "--opr-name", help="filter oprs in the top list by regex of name"
    )
    parser.add_argument(
        "--input-dtype", type=str, help="filter oprs in the top list by input dtype"
    )
    parser.add_argument(
        "--top-end-key",
        default="end",
        choices=["end", "kern"],
        help="how time in top is calculated; end corresponds "
        "to total device time, and kern corresponds to only "
        "wait time",
    )
    parser.add_argument(
        "--aggregate",
        default=None,
        help="aggregate operations",
        choices=["max", "min", "sum", "mean"],
    )
    parser.add_argument(
        "--order-by",
        default="time",
        help="sort result according to given column; the param can be "
        "<col> or +<col>, meaning sorting in descending or "
        "ascending order respectively",
    )
    parser.add_argument(
        "--copy-time", action="store_true", help="show copy time related result"
    )
    parser.add_argument(
        "--min-time",
        type=float,
        default=float("-inf"),
        help="minimal time of a result to be printed",
    )
    parser.add_argument(
        "--max-time",
        type=float,
        default=float("inf"),
        help="maximal time of a result to be printed",
    )
    parser.add_argument(
        "--show-host", action="store_true", help="show host profiling info"
    )
    parser.add_argument(
        "--dump-only-opr",
        action="store_true",
        help="only dump operator info as plaintext; useful "
        "for diff between two filtered profile results",
    )
    parser.add_argument(
        "--confluence",
        "--wiki",
        action="store_true",
        help="output confluence-markdown-compatible table",
    )
    parser.add_argument(
        "--print-only",
        choices={"summary", "device", "host"},
        help="print only chosen info",
    )
    args = parser.parse_args(passed_args)

    # Build the operator filter: each sub-filter receives
    # (opr_info, inputs, outputs); all active filters must pass.
    opr_filters = []
    if args.type:
        opr_filters.append(lambda o, a, b: o["type"] in args.type)
    if args.opr_name:
        # bind the compiled regex as a default so it is compiled only once
        opr_filters.append(
            lambda o, a, b, r=re.compile(args.opr_name): r.match(o["name"])
        )
    if args.input_dtype:
        opr_filters.append(
            lambda o, a, b: any(
                i["mem_plan"]["layout"]["dtype"] == args.input_dtype for i in a
            )
        )

    if not opr_filters:

        def opr_filter(o, a, b):  # pylint: disable=unused-argument
            """Accept every operator when no filter option is given."""
            return True

    else:

        def opr_filter(o, a, b):
            """Accept an operator only if all requested filters match."""
            return all(i(o, a, b) for i in opr_filters)

    with open(args.dump) as fin:
        dump = json.load(fin)

    # analyzer sees only the filtered oprs; analyzer_tot sees everything and
    # is used to compute the totals that ratios are reported against.
    analyzer = ProfileAnalyzer(dump, opr_filter)
    analyzer_tot = ProfileAnalyzer(dump, lambda _, __, ___: True)

    def summary():
        """Build the summary table; return (table_str, tot_dev_time, tot_host_time)."""
        device_end_func = TimeFuncHelper.eval_time_func("device", "end", np.max)
        device_kern_func = TimeFuncHelper.eval_time_func("device", "kern", np.max)
        host_end_func = TimeFuncHelper.eval_time_func("host", "end", np.max)

        def get_tot_time(func):
            # total over all (unfiltered) operators; "N/A" when nothing matched
            rec = analyzer_tot.select(func, aggregate=np.sum)
            if not rec:
                return "N/A"
            rec = rec[0]
            return rec.time

        tab = []
        tot_dev_time = get_tot_time(device_end_func)
        tot_host_time = get_tot_time(host_end_func)
        tab.append(("total device time", tot_dev_time))
        tab.append(("total host time", tot_host_time))
        if args.copy_time:

            def fmt(a, b):
                # a/b are single-element record lists (sum / mean aggregates)
                a = a[0]
                b = b[0]
                return "tot={:.4f} avg={:.4f}".format(a.time, b.time)

            tab.append(
                (
                    "copy time",
                    fmt(
                        analyzer.select(
                            device_end_func,
                            lambda opr: opr.opr_info["type"] == "Copy",
                            aggregate=np.sum,
                        ),
                        analyzer.select(
                            device_end_func,
                            lambda opr: opr.opr_info["type"] == "Copy",
                            aggregate=np.mean,
                        ),
                    ),
                )
            )
            tab.append(
                (
                    "copy wait time",
                    fmt(
                        analyzer.select(
                            device_kern_func,
                            lambda opr: opr.opr_info["type"] == "Copy",
                            aggregate=np.sum,
                        ),
                        analyzer.select(
                            device_kern_func,
                            lambda opr: opr.opr_info["type"] == "Copy",
                            aggregate=np.mean,
                        ),
                    ),
                )
            )

        if args.confluence:
            tab_str = _tabulate_confluence(tab, headers=["name", "value"])
        else:
            tab_str = tabulate(tab)
        return tab_str, tot_dev_time, tot_host_time

    def prof_details(prof_type, tot_time):
        """Build the per-operator detail table for ``prof_type`` ("device"/"host")."""
        tab = []

        def func(
            opr,
            *,
            f0=TimeFuncHelper.eval_time_func(prof_type, args.top_end_key, np.max)
        ):
            # drop records outside the requested [min-time, max-time] window
            t = f0(opr)
            if t is not None and (t < args.min_time or t > args.max_time):
                return None
            return t

        records = analyzer.select(
            func,
            aggregate=args.aggregate,
            aggregate_by=args.aggregate_by,
            top_k=args.top,
            sort_by=args.order_by,
        )

        if args.dump_only_opr:
            # plaintext operator info only; useful for diff-ing two results
            return "\n".join(" ".join(i.info.values()) for i in records)

        def format_shapes(shapes, layouts=None, sep="\n"):
            """Render shapes (and optionally layouts) as {a,b,c}[-[strides]] items."""
            if isinstance(shapes, NonExistNum) or shapes is None:
                return repr(shapes)
            if layouts is None:
                layouts = [None] * len(shapes)

            comp = []
            for i, j in zip(shapes, layouts):
                i = "{" + ",".join(map(str, i)) + "}"
                if j:
                    i += "\n -[" + ",".join(map(str, j)) + "]"
                comp.append(i)
            return sep.join(comp)

        def fix_num_and_find_unit(x, base):
            """Scale ``x`` down by ``base`` and return (scaled, SI prefix)."""
            if isinstance(x, NonExistNum) or (
                isinstance(x, float) and not np.isfinite(x)
            ):
                return x, ""
            unit = iter(["", "K", "M", "G", "T", "P"])
            # consume one prefix per division; the next yielded prefix is the
            # one matching the final magnitude
            while x >= base:
                x /= base
                next(unit)
            return x, next(unit)

        def get_number_with_unit(num, unit, base, sep="\n"):
            num, unit_prefix = fix_num_and_find_unit(num, base)
            if isinstance(unit, list):
                # e.g. ["byte", "iB"]: plain unit when unscaled, binary prefix unit otherwise
                unit = unit[int(unit_prefix != "")]
            return ("{:.2f}" + sep + "{}{}").format(num, unit_prefix, unit)

        if args.confluence:
            rows = []
            cum_time = 0
            max_time = max([r.time for r in records])
            max_bandwidth = max([r.bandwidth for r in records])
            # default=0 guards against every record lacking flops info
            # (NonExistNum); max_flops is unused in that case anyway.
            max_flops = max(
                [r.flops for r in records if not isinstance(r.flops, NonExistNum)],
                default=0,
            )
            bar_length = 15
            for idx, record in enumerate(records):
                cum_time += record.time
                opr_info = [("opr " + k, v) for k, v in record.info.items()]
                row = collections.OrderedDict(
                    [
                        ("#", idx),
                        ("time", "{:.3}".format(record.time)),
                        ("ratio", "{:.1f}%".format(record.time / tot_time * 100)),
                        ("time bar", "#" * int(record.time / max_time * bar_length)),
                        ("cum-time", cum_time),
                        ("cum-time ratio", cum_time / tot_time),
                    ]
                    + opr_info
                    + [
                        (
                            "computation (MFLO)",
                            "{:.1f}".format(record.computation / 1000 ** 2),
                        ),
                        ("MFLOPS", "{:.1f}".format(record.flops / 1000 ** 2)),
                        (
                            "MFLOPS-bar",
                            ""
                            if isinstance(record.flops, NonExistNum)
                            else ("#" * int(record.flops / max_flops * bar_length)),
                        ),
                        ("memory (MB)", "{:.1f}".format(record.memory / 1024 ** 2)),
                        (
                            "bandwidth (MiB/s)",
                            "{:.1f}".format(record.bandwidth / 1024 ** 2),
                        ),
                        (
                            "bandwidth bar",
                            "#" * int(record.bandwidth / max_bandwidth * bar_length),
                        ),
                        (
                            "in_shapes",
                            format_shapes(
                                record.in_shapes, record.in_layouts, sep=", "
                            ),
                        ),
                        ("out_shapes", format_shapes(record.out_shapes, sep=", ")),
                    ]
                )
                rows.append(row)

            headers = list(rows[0].keys())
            tab = [[row[i] for i in headers] for row in rows]
            return _tabulate_confluence(tab, headers=headers)
        else:
            cum_time = 0
            for idx, record in enumerate(records):
                cum_time += record.time
                tab.append(
                    (
                        "#{}\n{:.3}\n{:.1f}%".format(
                            idx, record.time, record.time / tot_time * 100
                        ),
                        "{:.3}\n{:.1f}%".format(cum_time, cum_time / tot_time * 100),
                        "\n".join(
                            "\n- ".join(textwrap.wrap(str(i), width=30))
                            for i in record.info.values()
                        ),
                        get_number_with_unit(record.computation, "FLO", 1000),
                        get_number_with_unit(record.flops, "FLOPS", 1000),
                        get_number_with_unit(record.memory, ["byte", "iB"], 1024),
                        get_number_with_unit(
                            record.bandwidth, ["byte/s", "iB/s"], 1024
                        ),
                        format_shapes(record.in_shapes, record.in_layouts),
                        format_shapes(record.out_shapes),
                    )
                )
            return _tabulate_ml(
                tab,
                headers=[
                    "{} self time".format(prof_type),
                    "cumulative",
                    "operator info",
                    "computation",
                    "FLOPS",
                    "memory",
                    "bandwidth",
                    "in_shapes",
                    "out_shapes",
                ],
                tablefmt="fancy_grid",
            )

    summary_tab, tot_dev_time, tot_host_time = summary()

    if args.print_only:
        print(
            {
                "summary": lambda: summary_tab,
                "device": lambda: prof_details("device", tot_dev_time),
                "host": lambda: prof_details("host", tot_host_time),
            }[args.print_only]()
        )
    else:
        print(summary_tab)
        print()
        print(prof_details("device", tot_dev_time))
        if args.show_host:
            print()
            print(prof_details("host", tot_host_time))


if __name__ == "__main__":
    main()