# Copyright 2022 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import abc
import re
from typing import Optional


class BenchmarkCommand(abc.ABC):
    """Abstracts a benchmark command."""

    def __init__(
        self,
        benchmark_binary: str,
        model_name: str,
        num_threads: int,
        num_runs: int,
        driver: Optional[str] = None,
        taskset: Optional[str] = None,
    ):
        self.benchmark_binary = benchmark_binary
        self.model_name = model_name
        self.taskset = taskset
        self.num_threads = num_threads
        self.num_runs = num_runs
        self.driver = driver
        self.args = []

    @property
    @abc.abstractmethod
    def runtime(self):
        pass

    @abc.abstractmethod
    def parse_latency_from_output(self, output: str) -> float:
        pass

    def generate_benchmark_command(self) -> list[str]:
        """Returns a list of strings that correspond to the command to be run."""
        command = []
        if self.taskset:
            command.append("taskset")
            command.append(str(self.taskset))
        command.append(self.benchmark_binary)
        command.extend(self.args)
        return command


class TFLiteBenchmarkCommand(BenchmarkCommand):
    """Represents a TFLite benchmark command."""

    def __init__(
        self,
        benchmark_binary: str,
        model_name: str,
        model_path: str,
        num_threads: int,
        num_runs: int,
        taskset: Optional[str] = None,
    ):
        super().__init__(
            benchmark_binary, model_name, num_threads, num_runs, taskset=taskset
        )
        self.args.append("--graph=" + model_path)
        self._latency_large_regex = re.compile(
            r".*?Inference \(avg\): (\d+\.?\d*e\+?\d*).*"
        )
        self._latency_regex = re.compile(r".*?Inference \(avg\): (\d+).*")

    @property
    def runtime(self):
        return "tflite"

    def parse_latency_from_output(self, output: str) -> float:
        # First check whether a large number has been recorded, e.g. 1.18859e+06.
        matches = self._latency_large_regex.search(output)
        if not matches:
            # Otherwise, match a regular number, e.g. 71495.6.
            matches = self._latency_regex.search(output)

        latency_ms = 0
        if matches:
            latency_ms = float(matches.group(1)) / 1000
        else:
            print("Warning! Could not parse latency. Defaulting to 0ms.")
        return latency_ms

    def generate_benchmark_command(self) -> list[str]:
        command = super().generate_benchmark_command()
        if self.driver == "gpu":
            command.append("--use_gpu=true")
        command.append("--num_threads=" + str(self.num_threads))
        command.append("--num_runs=" + str(self.num_runs))
        return command


class IreeBenchmarkCommand(BenchmarkCommand):
    """Represents an IREE benchmark command."""

    def __init__(
        self,
        benchmark_binary: str,
        model_name: str,
        model_path: str,
        num_threads: int,
        num_runs: int,
        taskset: Optional[str] = None,
    ):
        super().__init__(
            benchmark_binary, model_name, num_threads, num_runs, taskset=taskset
        )
        self.args.append("--module=" + model_path)
        self._latency_regex = re.compile(
            r".*?BM_main/process_time/real_time_mean\s+(.*?) ms.*"
        )

    @property
    def runtime(self):
        return "iree"

    def parse_latency_from_output(self, output: str) -> float:
        matches = self._latency_regex.search(output)
        latency_ms = 0
        if matches:
            latency_ms = float(matches.group(1))
        else:
            print("Warning! Could not parse latency. Defaulting to 0ms.")
        return latency_ms

    def generate_benchmark_command(self) -> list[str]:
        command = super().generate_benchmark_command()
        command.append("--device=" + self.driver)
        command.append("--task_topology_max_group_count=" + str(self.num_threads))
        command.append("--benchmark_repetitions=" + str(self.num_runs))
        return command
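

# Minimal usage sketch (illustrative only, not part of the original module).
# The binary path, module path, model name, taskset mask, and device string
# below are hypothetical placeholders; callers are expected to supply values
# for their own environment and to set `driver` before generating the command.
if __name__ == "__main__":
    iree_command = IreeBenchmarkCommand(
        benchmark_binary="/data/local/tmp/iree-benchmark-module",  # hypothetical
        model_name="mobilebert",  # hypothetical
        model_path="/data/local/tmp/mobilebert.vmfb",  # hypothetical
        num_threads=2,
        num_runs=10,
        taskset="80",
    )
    # `driver` is an attribute on the base class; it must be assigned before
    # generate_benchmark_command() concatenates it into "--device=...".
    iree_command.driver = "local-task"

    # Produces a plain argv list, suitable for e.g. subprocess.run(...).
    print(" ".join(iree_command.generate_benchmark_command()))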