# Copyright 2022 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

# All benchmarks' relative path against root build directory.

from argparse import Namespace
from dataclasses import dataclass
from typing import Optional
import pathlib
import re

from common import benchmark_definition

# Sub-directory names; build_from_args() joins them onto the per-commit
# temporary directory.
BENCHMARK_RESULTS_REL_PATH = "benchmark-results"
CAPTURES_REL_PATH = "captures"


@dataclass
class TraceCaptureConfig:
    """Settings used when traces are captured during benchmarking.

    traced_benchmark_tool_dir: the path to the tracing-enabled benchmark tool
      directory.
    trace_capture_tool: the path to the tool for collecting captured traces.
    capture_tarball: the path of the capture tar archive.
    capture_tmp_dir: the temporary directory to store captured traces.
    """

    traced_benchmark_tool_dir: pathlib.Path
    trace_capture_tool: pathlib.Path
    capture_tarball: pathlib.Path
    capture_tmp_dir: pathlib.Path


@dataclass
class BenchmarkConfig:
    """Settings that control a benchmark run.

    tmp_dir: per-commit temporary directory.
    root_benchmark_dir: the root directory path/URL containing the built
      benchmark suites.
    benchmark_results_dir: the directory to store benchmark results files.
    git_commit_hash: the git commit hash.
    normal_benchmark_tool_dir: the path to the non-traced benchmark tool
      directory.
    trace_capture_config: the config for capturing traces. Set if and only if
      the traces need to be captured.
    driver_filter: filter benchmarks to those whose driver matches this regex
      (or all if this is None).
    model_name_filter: filter benchmarks to those whose model name matches
      this regex (or all if this is None).
    mode_filter: filter benchmarks to those whose benchmarking mode matches
      this regex (or all if this is None).
    use_compatible_filter: populated from the `compatible_only` argument.
      NOTE(review): presumably restricts the run to benchmarks compatible
      with the current device -- confirm against the consumer of this field.
    keep_going: whether to proceed if an individual run fails. Exceptions
      will be logged and returned.
    benchmark_min_time: min number of seconds to run the benchmark for, if
      specified. Otherwise, the benchmark will be repeated a fixed number of
      times.
    continue_from_previous: skip the benchmarks if their results are found in
      the benchmark_results_dir.
    verify: verify the output if the model's expected output is available.
    """

    tmp_dir: pathlib.Path
    root_benchmark_dir: benchmark_definition.ResourceLocation
    benchmark_results_dir: pathlib.Path
    git_commit_hash: str

    normal_benchmark_tool_dir: Optional[pathlib.Path] = None
    trace_capture_config: Optional[TraceCaptureConfig] = None
    driver_filter: Optional[str] = None
    model_name_filter: Optional[str] = None
    mode_filter: Optional[str] = None
    use_compatible_filter: bool = False
    keep_going: bool = False
    benchmark_min_time: float = 0
    continue_from_previous: bool = False
    verify: bool = False

    @staticmethod
    def build_from_args(args: Namespace, git_commit_hash: str):
        """Create a BenchmarkConfig from parsed command-line arguments.

        Args:
          args: the command arguments.
          git_commit_hash: the git commit hash of IREE.

        Returns:
          A BenchmarkConfig whose temporary directories are rooted at
          `args.tmp_dir / git_commit_hash`.

        Raises:
          ValueError: if neither benchmark tool directory is given, or if the
            three trace-capture flags are only partially specified.
        """

        def _resolve_optional(
            maybe_path: Optional[pathlib.Path],
        ) -> Optional[pathlib.Path]:
            # Falsy input (e.g. an unset flag) is mapped to None.
            if maybe_path:
                return maybe_path.resolve()
            return None

        if not (args.normal_benchmark_tool_dir or args.traced_benchmark_tool_dir):
            raise ValueError(
                "At least one of --normal_benchmark_tool_dir or --traced_benchmark_tool_dir should be specified."
            )

        # The three trace-capture flags must be given together or not at all,
        # so each one's "unset" state has to agree with the first flag's.
        tracing_unset = args.traced_benchmark_tool_dir is None
        if (
            tracing_unset != (args.trace_capture_tool is None)
            or tracing_unset != (args.capture_tarball is None)
        ):
            raise ValueError(
                "The following 3 flags should be simultaneously all specified or all unspecified: --traced_benchmark_tool_dir, --trace_capture_tool, --capture_tarball"
            )

        commit_tmp_root: pathlib.Path = (args.tmp_dir / git_commit_hash).resolve()

        capture_settings = None
        if not tracing_unset:
            capture_settings = TraceCaptureConfig(
                traced_benchmark_tool_dir=args.traced_benchmark_tool_dir.resolve(),
                trace_capture_tool=args.trace_capture_tool.resolve(),
                capture_tarball=args.capture_tarball.resolve(),
                capture_tmp_dir=commit_tmp_root / CAPTURES_REL_PATH,
            )

        # A value of the form "<scheme>://..." is wrapped as a URL location;
        # anything else is wrapped as a local path location.
        artifacts_dir = args.e2e_test_artifacts_dir
        make_location = (
            benchmark_definition.ResourceLocation.build_url
            if re.match("^[^:]+://", str(artifacts_dir))
            else benchmark_definition.ResourceLocation.build_local_path
        )
        suite_root = make_location(artifacts_dir)

        return BenchmarkConfig(
            tmp_dir=commit_tmp_root,
            root_benchmark_dir=suite_root,
            benchmark_results_dir=commit_tmp_root / BENCHMARK_RESULTS_REL_PATH,
            git_commit_hash=git_commit_hash,
            normal_benchmark_tool_dir=_resolve_optional(args.normal_benchmark_tool_dir),
            trace_capture_config=capture_settings,
            driver_filter=args.driver_filter_regex,
            model_name_filter=args.model_name_regex,
            mode_filter=args.mode_regex,
            use_compatible_filter=args.compatible_only,
            keep_going=args.keep_going,
            benchmark_min_time=args.benchmark_min_time,
            continue_from_previous=args.continue_from_previous,
            verify=args.verify,
        )