# Copyright 2022 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import re
import subprocess
import tempfile
import time

from common.benchmark_command import BenchmarkCommand

# Regexes for retrieving memory information. Each one captures the integer
# kB value of the corresponding field in `/proc/<pid>/status`.
_VMHWM_REGEX = re.compile(r".*?VmHWM:.*?(\d+) kB.*")
_VMRSS_REGEX = re.compile(r".*?VmRSS:.*?(\d+) kB.*")
_RSSFILE_REGEX = re.compile(r".*?RssFile:.*?(\d+) kB.*")


def _read_proc_status(pid: int) -> str:
    """Returns the text of `/proc/<pid>/status`, or "" if it cannot be read."""
    try:
        with open(f"/proc/{pid}/status", encoding="utf-8", errors="replace") as f:
            return f.read()
    except OSError:
        # The process may have exited between the poll and the read, or the
        # platform may have no /proc. Either way there is nothing to sample.
        return ""


def run_command(benchmark_command: BenchmarkCommand) -> list[float]:
    """Runs `benchmark_command` and polls for memory consumption statistics.

    While the benchmark process is alive, `/proc/<pid>/status` is sampled every
    0.5 seconds. The sample with the highest VmHWM seen so far determines the
    reported VmHWM, VmRSS and RssFile values. If the process exits before the
    first successful sample, all three memory values stay at 0.

    Args:
      benchmark_command: A `BenchmarkCommand` object containing information on
        how to run the benchmark and parse the output.

    Returns:
      An array containing values for [`latency`, `vmhwm`, `vmrss`, `rssfile`].
      The memory values are in kB as reported by `/proc/<pid>/status`; the
      latency is whatever `parse_latency_from_output` returns. If the command
      exits with a non-zero return code, `[0, 0, 0, 0]` is returned.
    """
    command = benchmark_command.generate_benchmark_command()
    print("\n\nRunning command:\n" + " ".join(command))

    # Capture output in a temporary file instead of a pipe. A pipe that is not
    # drained while we poll fills up once the child has written more than the
    # OS pipe buffer holds; the child then blocks on write and never exits.
    with tempfile.TemporaryFile() as output_file:
        benchmark_process = subprocess.Popen(
            command, stdout=output_file, stderr=subprocess.STDOUT
        )

        # Keep a record of the highest VmHWM and the VmRSS and RssFile values
        # observed in the same sample.
        vmhwm = 0
        vmrss = 0
        rssfile = 0
        while benchmark_process.poll() is None:
            status = _read_proc_status(benchmark_process.pid)
            vmhwm_matches = _VMHWM_REGEX.search(status)
            vmrss_matches = _VMRSS_REGEX.search(status)
            rssfile_matches = _RSSFILE_REGEX.search(status)

            # Only accept a sample in which all three fields are present.
            if vmhwm_matches and vmrss_matches and rssfile_matches:
                curr_vmhwm = float(vmhwm_matches.group(1))
                if curr_vmhwm > vmhwm:
                    vmhwm = curr_vmhwm
                    vmrss = float(vmrss_matches.group(1))
                    rssfile = float(rssfile_matches.group(1))

            time.sleep(0.5)

        # The child wrote through the same open file description, so rewind
        # before reading back everything it produced.
        output_file.seek(0)
        stdout_data = output_file.read()

    if benchmark_process.returncode != 0:
        print(
            f"Warning! Benchmark command failed with return code:"
            f" {benchmark_process.returncode}"
        )
        return [0, 0, 0, 0]

    output_text = stdout_data.decode()
    print(output_text)
    latency_ms = benchmark_command.parse_latency_from_output(output_text)
    return [latency_ms, vmhwm, vmrss, rssfile]