# Copyright 2022 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
"""Unit tests for BenchmarkDriver using a fake driver implementation."""

import dataclasses
import json
import pathlib
import tempfile
from typing import Optional
import unittest

from common import benchmark_config
from common.benchmark_suite import BenchmarkCase, BenchmarkSuite
from common.benchmark_driver import BenchmarkDriver
from common import benchmark_definition
from common.benchmark_definition import (
    IREE_DRIVERS_INFOS,
    DeviceInfo,
    PlatformType,
    BenchmarkLatency,
    BenchmarkMemory,
    BenchmarkMetrics,
)
from e2e_test_framework.definitions import common_definitions, iree_definitions


class FakeBenchmarkDriver(BenchmarkDriver):
    """Test double of BenchmarkDriver.

    Records every benchmark case it is asked to run and writes placeholder
    result/capture files; can optionally raise on one specific case to
    exercise the driver's error handling.
    """

    def __init__(
        self, *args, raise_exception_on_case: Optional[BenchmarkCase] = None, **kwargs
    ):
        """Forwards *args/**kwargs to BenchmarkDriver.

        Args:
          raise_exception_on_case: when set, run_benchmark_case raises a fake
            exception for this exact case.
        """
        super().__init__(*args, **kwargs)
        self.raise_exception_on_case = raise_exception_on_case
        # Cases actually executed, in run order.
        self.run_benchmark_cases = []

    def run_benchmark_case(
        self,
        benchmark_case: BenchmarkCase,
        benchmark_results_filename: Optional[pathlib.Path],
        capture_filename: Optional[pathlib.Path],
    ) -> None:
        """Fake benchmark run: records the case and writes stub output files."""
        if self.raise_exception_on_case == benchmark_case:
            raise Exception("fake exception")

        self.run_benchmark_cases.append(benchmark_case)

        if benchmark_results_filename:
            # All-zero metrics are sufficient: the tests only check that the
            # result file exists and that raw_data round-trips as an empty dict.
            fake_benchmark_metrics = BenchmarkMetrics(
                real_time=BenchmarkLatency(0, 0, 0, "ns"),
                cpu_time=BenchmarkLatency(0, 0, 0, "ns"),
                host_memory=BenchmarkMemory(0, 0, 0, 0, "bytes"),
                device_memory=BenchmarkMemory(0, 0, 0, 0, "bytes"),
                raw_data={},
            )
            benchmark_results_filename.write_text(
                json.dumps(fake_benchmark_metrics.to_json_object())
            )
        if capture_filename:
            capture_filename.write_text("{}")


class BenchmarkDriverTest(unittest.TestCase):
    """Tests for the BenchmarkDriver run loop, filtering, and resumption."""

    def setUp(self):
        """Builds a benchmark config and a 3-case suite.

        The suite contains two x86_64 cases (sync and task drivers) compatible
        with the configured CascadeLake device, plus one RV64 case that the
        compatibility filter should exclude.
        """
        self._tmp_dir_obj = tempfile.TemporaryDirectory()
        self._root_dir_obj = tempfile.TemporaryDirectory()

        self.tmp_dir = pathlib.Path(self._tmp_dir_obj.name)
        # Fake CMake build config enabling the drivers/loaders the cases need.
        (self.tmp_dir / "build_config.txt").write_text(
            "IREE_HAL_DRIVER_LOCAL_SYNC=ON\n"
            "IREE_HAL_DRIVER_LOCAL_TASK=ON\n"
            "IREE_HAL_EXECUTABLE_LOADER_EMBEDDED_ELF=ON\n"
        )

        self.benchmark_results_dir = (
            self.tmp_dir / benchmark_config.BENCHMARK_RESULTS_REL_PATH
        )
        self.captures_dir = self.tmp_dir / benchmark_config.CAPTURES_REL_PATH
        self.benchmark_results_dir.mkdir()
        self.captures_dir.mkdir()

        self.config = benchmark_config.BenchmarkConfig(
            tmp_dir=self.tmp_dir,
            root_benchmark_dir=benchmark_definition.ResourceLocation.build_local_path(
                self._root_dir_obj.name
            ),
            benchmark_results_dir=self.benchmark_results_dir,
            git_commit_hash="abcd",
            normal_benchmark_tool_dir=self.tmp_dir,
            trace_capture_config=benchmark_config.TraceCaptureConfig(
                traced_benchmark_tool_dir=self.tmp_dir,
                trace_capture_tool=self.tmp_dir / "capture_tool",
                capture_tarball=self.tmp_dir / "captures.tar",
                capture_tmp_dir=self.captures_dir,
            ),
            use_compatible_filter=True,
        )

        self.device_info = DeviceInfo(
            platform_type=PlatformType.LINUX,
            model="Unknown",
            cpu_abi="x86_64",
            cpu_uarch="CascadeLake",
            cpu_features=[],
            gpu_name="unknown",
        )

        model_tflite = common_definitions.Model(
            id="tflite",
            name="model_tflite",
            tags=[],
            source_type=common_definitions.ModelSourceType.EXPORTED_TFLITE,
            source_url="",
            entry_function="predict",
            input_types=["1xf32"],
        )
        device_spec = common_definitions.DeviceSpec.build(
            id="dev",
            device_name="test_dev",
            architecture=common_definitions.DeviceArchitecture.X86_64_CASCADELAKE,
            host_environment=common_definitions.HostEnvironment.LINUX_X86_64,
            device_parameters=[],
            tags=[],
        )
        compile_target = iree_definitions.CompileTarget(
            target_backend=iree_definitions.TargetBackend.LLVM_CPU,
            target_architecture=(
                common_definitions.DeviceArchitecture.X86_64_CASCADELAKE
            ),
            target_abi=iree_definitions.TargetABI.LINUX_GNU,
        )
        gen_config = iree_definitions.ModuleGenerationConfig.build(
            imported_model=iree_definitions.ImportedModel.from_model(model_tflite),
            compile_config=iree_definitions.CompileConfig.build(
                id="comp_a", tags=[], compile_targets=[compile_target]
            ),
        )
        exec_config_a = iree_definitions.ModuleExecutionConfig.build(
            id="exec_a",
            tags=["sync"],
            loader=iree_definitions.RuntimeLoader.EMBEDDED_ELF,
            driver=iree_definitions.RuntimeDriver.LOCAL_SYNC,
        )
        run_config_a = iree_definitions.E2EModelRunConfig.build(
            module_generation_config=gen_config,
            module_execution_config=exec_config_a,
            target_device_spec=device_spec,
            input_data=common_definitions.DEFAULT_INPUT_DATA,
            tool=iree_definitions.E2EModelRunTool.IREE_BENCHMARK_MODULE,
        )
        exec_config_b = iree_definitions.ModuleExecutionConfig.build(
            id="exec_b",
            tags=["task"],
            loader=iree_definitions.RuntimeLoader.EMBEDDED_ELF,
            driver=iree_definitions.RuntimeDriver.LOCAL_TASK,
        )
        run_config_b = iree_definitions.E2EModelRunConfig.build(
            module_generation_config=gen_config,
            module_execution_config=exec_config_b,
            target_device_spec=device_spec,
            input_data=common_definitions.DEFAULT_INPUT_DATA,
            tool=iree_definitions.E2EModelRunTool.IREE_BENCHMARK_MODULE,
        )
        self.case1 = BenchmarkCase(
            model_name="model_tflite",
            model_tags=[],
            bench_mode=["sync"],
            target_arch=common_definitions.DeviceArchitecture.X86_64_CASCADELAKE,
            driver_info=IREE_DRIVERS_INFOS["iree-llvm-cpu-sync"],
            module_dir=benchmark_definition.ResourceLocation.build_local_path("case1"),
            benchmark_tool_name="tool",
            run_config=run_config_a,
        )
        self.case2 = BenchmarkCase(
            model_name="model_tflite",
            model_tags=[],
            bench_mode=["task"],
            target_arch=common_definitions.DeviceArchitecture.X86_64_CASCADELAKE,
            driver_info=IREE_DRIVERS_INFOS["iree-llvm-cpu"],
            module_dir=benchmark_definition.ResourceLocation.build_local_path("case2"),
            benchmark_tool_name="tool",
            run_config=run_config_b,
        )

        # RV64 case: incompatible with the x86_64 device info above, so the
        # compatibility filter should drop it when enabled.
        compile_target_rv64 = iree_definitions.CompileTarget(
            target_backend=iree_definitions.TargetBackend.LLVM_CPU,
            target_architecture=common_definitions.DeviceArchitecture.RV64_GENERIC,
            target_abi=iree_definitions.TargetABI.LINUX_GNU,
        )
        gen_config_rv64 = iree_definitions.ModuleGenerationConfig.build(
            imported_model=iree_definitions.ImportedModel.from_model(model_tflite),
            compile_config=iree_definitions.CompileConfig.build(
                id="comp_rv64", tags=[], compile_targets=[compile_target_rv64]
            ),
        )
        device_spec_rv64 = common_definitions.DeviceSpec.build(
            id="rv64_dev",
            device_name="rv64_dev",
            architecture=common_definitions.DeviceArchitecture.RV64_GENERIC,
            host_environment=common_definitions.HostEnvironment.LINUX_X86_64,
            device_parameters=[],
            tags=[],
        )
        run_config_incompatible = iree_definitions.E2EModelRunConfig.build(
            module_generation_config=gen_config_rv64,
            module_execution_config=exec_config_b,
            target_device_spec=device_spec_rv64,
            input_data=common_definitions.DEFAULT_INPUT_DATA,
            tool=iree_definitions.E2EModelRunTool.IREE_BENCHMARK_MODULE,
        )
        self.incompatible_case = BenchmarkCase(
            model_name="model_tflite",
            model_tags=[],
            bench_mode=["task"],
            target_arch=common_definitions.DeviceArchitecture.RV64_GENERIC,
            driver_info=IREE_DRIVERS_INFOS["iree-llvm-cpu"],
            module_dir=benchmark_definition.ResourceLocation.build_local_path(
                "incompatible_case"
            ),
            benchmark_tool_name="tool",
            run_config=run_config_incompatible,
        )
        self.benchmark_suite = BenchmarkSuite(
            [
                self.case1,
                self.case2,
                self.incompatible_case,
            ]
        )

    def tearDown(self) -> None:
        self._tmp_dir_obj.cleanup()
        self._root_dir_obj.cleanup()

    def test_run(self):
        """Runs the suite and checks results, result files, and capture files."""
        driver = FakeBenchmarkDriver(
            self.device_info, self.config, self.benchmark_suite
        )

        driver.run()

        self.assertEqual(driver.get_benchmark_results().commit, "abcd")
        # Only the two compatible x86_64 cases are run.
        self.assertEqual(len(driver.get_benchmark_results().benchmarks), 2)
        self.assertEqual(
            driver.get_benchmark_results().benchmarks[0].metrics.raw_data, {}
        )
        self.assertEqual(
            driver.get_benchmark_result_filenames(),
            [
                self.benchmark_results_dir / f"{self.case1.run_config}.json",
                self.benchmark_results_dir / f"{self.case2.run_config}.json",
            ],
        )
        self.assertEqual(
            driver.get_capture_filenames(),
            [
                self.captures_dir / f"{self.case1.run_config}.tracy",
                self.captures_dir / f"{self.case2.run_config}.tracy",
            ],
        )
        self.assertEqual(driver.get_benchmark_errors(), [])

    def test_run_disable_compatible_filter(self):
        """With the filter off, the incompatible RV64 case also runs."""
        self.config.use_compatible_filter = False
        driver = FakeBenchmarkDriver(
            self.device_info, self.config, self.benchmark_suite
        )

        driver.run()

        self.assertEqual(len(driver.get_benchmark_results().benchmarks), 3)

    def test_run_with_no_capture(self):
        """Without a trace-capture config, no capture files are produced."""
        self.config.trace_capture_config = None
        driver = FakeBenchmarkDriver(
            self.device_info, self.config, self.benchmark_suite
        )

        driver.run()

        self.assertEqual(len(driver.get_benchmark_result_filenames()), 2)
        self.assertEqual(driver.get_capture_filenames(), [])

    def test_run_with_exception_and_keep_going(self):
        """keep_going records the failure and continues with remaining cases."""
        self.config.keep_going = True
        driver = FakeBenchmarkDriver(
            self.device_info,
            self.config,
            self.benchmark_suite,
            raise_exception_on_case=self.case1,
        )

        driver.run()

        self.assertEqual(len(driver.get_benchmark_errors()), 1)
        self.assertEqual(len(driver.get_benchmark_result_filenames()), 1)

    def test_run_with_previous_benchmarks_and_captures(self):
        """continue_from_previous skips cases whose output files already exist."""
        benchmark_filename = (
            self.benchmark_results_dir / f"{self.case1.run_config}.json"
        )
        benchmark_filename.touch()
        capture_filename = self.captures_dir / f"{self.case1.run_config}.tracy"
        capture_filename.touch()
        config = dataclasses.replace(self.config, continue_from_previous=True)
        driver = FakeBenchmarkDriver(
            device_info=self.device_info,
            benchmark_config=config,
            benchmark_suite=self.benchmark_suite,
        )

        driver.run()

        # case1 is skipped but its pre-existing files still count in the totals.
        self.assertEqual(len(driver.run_benchmark_cases), 1)
        self.assertEqual(len(driver.get_benchmark_result_filenames()), 2)
        self.assertEqual(len(driver.get_capture_filenames()), 2)


if __name__ == "__main__":
    unittest.main()