## Copyright 2022 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
"""Helpers that build CMake rules.

Each function takes a list of parameters and returns a string ready to be
included in a CMakeLists.txt file. Builder functions handle optional arguments,
lists, formatting, etc.

For example:

build_iree_fetch_artifact(
  target_name="abcd",
  source_url="https://example.com/abcd.tflite",
  output="./abcd.tflite",
  unpack=True)

Outputs:

iree_fetch_artifact(
  NAME "abcd"
  SOURCE_URL "https://example.com/abcd.tflite"
  OUTPUT "./abcd.tflite"
  UNPACK
)
"""

from typing import Dict, List, Optional, Sequence

INDENT_SPACES = " " * 2


def _get_string_list(values: Sequence[str], quote: bool = True) -> List[str]:
    """Return `values` as a new list, each wrapped in double quotes if `quote`.

    No escaping is applied to the values; they are emitted verbatim.
    """
    if quote:
        return [f'"{value}"' for value in values]
    return list(values)


def _get_block_body(body: List[str]) -> List[str]:
    """Return a copy of `body` with every line indented by one level."""
    return [INDENT_SPACES + line for line in body]


def _get_string_arg_block(
    keyword: str, value: Optional[str], quote: bool = True
) -> List[str]:
    """Build the single-line block `KEYWORD value`.

    Returns an empty block when `value` is None so that the argument is
    omitted from the generated rule. The value is double-quoted if `quote`.
    """
    if value is None:
        return []
    if quote:
        value = f'"{value}"'
    return [f"{keyword} {value}"]


def _get_string_list_arg_block(
    keyword: str, values: Sequence[str], quote: bool = True
) -> List[str]:
    """Build a multi-line block: the keyword followed by one indented value
    per line.

    Returns an empty block when `values` is empty so that the argument is
    omitted from the generated rule.
    """
    if not values:
        return []
    body = _get_string_list(values, quote)
    return [keyword] + _get_block_body(body)


def _get_option_arg_block(keyword: str, value: Optional[bool]) -> List[str]:
    """Build a flag-style block that contains only the keyword.

    The keyword is emitted only when `value` is exactly True; False, None and
    any other value produce an empty block.
    """
    if value is True:
        return [keyword]
    return []


def _build_call_rule(
    rule_name: str, parameter_blocks: Sequence[List[str]]
) -> List[str]:
    """Assemble the lines of a call `rule_name(...)`.

    Each non-empty parameter block is indented one level inside the
    parentheses; empty blocks are skipped.
    """
    output = [f"{rule_name}("]
    for block in parameter_blocks:
        if not block:
            continue
        output.extend(_get_block_body(block))
    output.append(")")
    return output


def _convert_block_to_string(block: List[str]) -> str:
    """Join the lines of `block` with newlines, including a trailing newline."""
    # Hack to append the terminating newline: adding an empty element copies
    # only the list instead of the whole joined string.
    return "\n".join(block + [""])


def build_iree_bytecode_module(
    target_name: str,
    src: str,
    module_name: str,
    flags: Sequence[str] = (),
    compile_tool_target: Optional[str] = None,
    c_identifier: Optional[str] = None,
    static_lib_path: Optional[str] = None,
    deps: Sequence[str] = (),
    friendly_name: Optional[str] = None,
    testonly: bool = False,
    public: bool = True,
) -> str:
    """Return the text of an `iree_bytecode_module(...)` call.

    Arguments map to CMake keywords as follows: target_name -> NAME,
    src -> SRC, module_name -> MODULE_FILE_NAME, c_identifier -> C_IDENTIFIER,
    compile_tool_target -> COMPILE_TOOL, static_lib_path -> STATIC_LIB_PATH,
    flags -> FLAGS, friendly_name -> FRIENDLY_NAME, deps -> DEPS,
    testonly -> TESTONLY, public -> PUBLIC.

    Arguments that are None, empty sequences, or options that are not True are
    omitted from the output.
    """
    name_block = _get_string_arg_block("NAME", target_name)
    src_block = _get_string_arg_block("SRC", src)
    module_name_block = _get_string_arg_block("MODULE_FILE_NAME", module_name)
    c_identifier_block = _get_string_arg_block("C_IDENTIFIER", c_identifier)
    static_lib_block = _get_string_arg_block("STATIC_LIB_PATH", static_lib_path)
    compile_tool_target_block = _get_string_arg_block(
        "COMPILE_TOOL", compile_tool_target
    )
    flags_block = _get_string_list_arg_block("FLAGS", flags)
    deps_block = _get_string_list_arg_block("DEPS", deps)
    friendly_name_block = _get_string_arg_block("FRIENDLY_NAME", friendly_name)
    testonly_block = _get_option_arg_block("TESTONLY", testonly)
    public_block = _get_option_arg_block("PUBLIC", public)
    return _convert_block_to_string(
        _build_call_rule(
            rule_name="iree_bytecode_module",
            parameter_blocks=[
                name_block,
                src_block,
                module_name_block,
                c_identifier_block,
                compile_tool_target_block,
                static_lib_block,
                flags_block,
                friendly_name_block,
                deps_block,
                testonly_block,
                public_block,
            ],
        )
    )


def build_iree_fetch_artifact(
    target_name: str, source_url: str, output: str, unpack: bool
) -> str:
    """Return the text of an `iree_fetch_artifact(...)` call.

    Emits NAME, SOURCE_URL and OUTPUT, plus the UNPACK option when `unpack` is
    True.
    """
    name_block = _get_string_arg_block("NAME", target_name)
    source_url_block = _get_string_arg_block("SOURCE_URL", source_url)
    output_block = _get_string_arg_block("OUTPUT", output)
    unpack_block = _get_option_arg_block("UNPACK", unpack)
    return _convert_block_to_string(
        _build_call_rule(
            rule_name="iree_fetch_artifact",
            parameter_blocks=[
                name_block,
                source_url_block,
                output_block,
                unpack_block,
            ],
        )
    )


def build_iree_import_tf_model(
    target_path: str, source: str, import_flags: List[str], output_mlir_file: str
) -> str:
    """Return the text of an `iree_import_tf_model(...)` call.

    `target_path` is emitted as the TARGET_NAME argument. IMPORT_FLAGS is
    omitted when `import_flags` is empty.
    """
    target_name_block = _get_string_arg_block("TARGET_NAME", target_path)
    source_block = _get_string_arg_block("SOURCE", source)
    import_flags_block = _get_string_list_arg_block("IMPORT_FLAGS", import_flags)
    output_mlir_file_block = _get_string_arg_block(
        "OUTPUT_MLIR_FILE", output_mlir_file
    )
    return _convert_block_to_string(
        _build_call_rule(
            rule_name="iree_import_tf_model",
            parameter_blocks=[
                target_name_block,
                source_block,
                import_flags_block,
                output_mlir_file_block,
            ],
        )
    )


def build_iree_import_tflite_model(
    target_path: str, source: str, import_flags: List[str], output_mlir_file: str
) -> str:
    """Return the text of an `iree_import_tflite_model(...)` call.

    `target_path` is emitted as the TARGET_NAME argument. IMPORT_FLAGS is
    omitted when `import_flags` is empty.
    """
    target_name_block = _get_string_arg_block("TARGET_NAME", target_path)
    source_block = _get_string_arg_block("SOURCE", source)
    import_flags_block = _get_string_list_arg_block("IMPORT_FLAGS", import_flags)
    output_mlir_file_block = _get_string_arg_block(
        "OUTPUT_MLIR_FILE", output_mlir_file
    )
    return _convert_block_to_string(
        _build_call_rule(
            rule_name="iree_import_tflite_model",
            parameter_blocks=[
                target_name_block,
                source_block,
                import_flags_block,
                output_mlir_file_block,
            ],
        )
    )


def build_iree_benchmark_suite_module_test(
    target_name: str,
    driver: str,
    expected_output: str,
    platform_module_map: Dict[str, str],
    runner_args: Sequence[str],
    timeout_secs: Optional[int] = None,
    labels: Sequence[str] = (),
    xfail_platforms: Sequence[str] = (),
) -> str:
    """Return the text of an `iree_benchmark_suite_module_test(...)` call.

    Each entry of `platform_module_map` is emitted under MODULES as
    "<platform>=<path>", in the mapping's iteration order. `timeout_secs`, when
    given, is emitted as a quoted string after TIMEOUT. Empty sequences and a
    None timeout are omitted from the output.
    """
    name_block = _get_string_arg_block("NAME", target_name)
    driver_block = _get_string_arg_block("DRIVER", driver)
    expected_output_block = _get_string_arg_block(
        "EXPECTED_OUTPUT", expected_output
    )
    modules_block = _get_string_list_arg_block(
        "MODULES",
        [f"{platform}={path}" for platform, path in platform_module_map.items()],
    )
    timeout_block = _get_string_arg_block(
        "TIMEOUT", str(timeout_secs) if timeout_secs is not None else None
    )
    runner_args_block = _get_string_list_arg_block("RUNNER_ARGS", runner_args)
    labels_block = _get_string_list_arg_block("LABELS", labels)
    xfail_platforms_block = _get_string_list_arg_block(
        "XFAIL_PLATFORMS", xfail_platforms
    )
    return _convert_block_to_string(
        _build_call_rule(
            rule_name="iree_benchmark_suite_module_test",
            parameter_blocks=[
                name_block,
                driver_block,
                expected_output_block,
                timeout_block,
                modules_block,
                runner_args_block,
                labels_block,
                xfail_platforms_block,
            ],
        )
    )


def build_add_dependencies(target: str, deps: List[str]) -> str:
    """Return the text of an `add_dependencies(target dep...)` call.

    The target and dependencies are emitted unquoted, one dependency per line.

    Raises:
      ValueError: if `deps` is empty.
    """
    if not deps:
        raise ValueError("Target dependencies can't be empty.")
    deps_list = _get_string_list(deps, quote=False)
    return _convert_block_to_string(
        [f"add_dependencies({target}"] + _get_block_body(deps_list) + [")"]
    )


def build_set(variable_name: str, value: str) -> str:
    """Return the text of a `set(variable_name value)` call.

    `value` is emitted verbatim (unquoted) on its own indented line.
    """
    return _convert_block_to_string(
        [f"set({variable_name}"] + _get_block_body([value]) + [")"]
    )