# Copyright 2022 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
"""Helpers that generate paths for model artifacts."""

import pathlib
import urllib.parse

from e2e_test_artifacts import utils
from e2e_test_framework.definitions import common_definitions

MODEL_ARTIFACT_PREFIX = "model"
# Archive extensions used to pack models. Trailing suffixes found in this list
# are dropped from the source URL when deriving the artifact name.
ARCHIVE_FILE_EXTENSIONS = [".tar", ".gz"]


def get_model_path(
    model: common_definitions.Model, root_path: pathlib.PurePath = pathlib.PurePath()
) -> pathlib.PurePath:
    """Returns the path of a model artifact file or directory.

    The artifact name is built from `model.name` plus the extension(s) of the
    file referenced by `model.source_url`, after removing any trailing archive
    extensions listed in ARCHIVE_FILE_EXTENSIONS. For example, a source URL
    ending in `.tar.gz` contributes no extension, while one ending in `.tflite`
    contributes `.tflite`. Note that every dot-separated suffix of the URL's
    file name is kept (e.g. `foo.v1.tflite` contributes `.v1.tflite`).

    Args:
      model: source model.
      root_path: path of the root artifact directory, on which the returned path
        will base.
    Returns:
      Path of the model artifact:
      `root_path / f"{MODEL_ARTIFACT_PREFIX}_{safe_name}"`, where `safe_name` is
      `utils.get_safe_name(model.name + model_ext)`.
    """
    model_url = urllib.parse.urlparse(model.source_url)
    # URL paths always use "/" as the separator regardless of the host OS, so
    # parse them with PurePosixPath instead of the platform-dependent PurePath.
    file_exts = pathlib.PurePosixPath(model_url.path).suffixes
    # Drop trailing archive extensions (e.g. ".tar.gz" leaves nothing).
    while file_exts and file_exts[-1] in ARCHIVE_FILE_EXTENSIONS:
        file_exts.pop()
    model_ext = "".join(file_exts)

    # Model path:
    #   <root_path>/<MODEL_ARTIFACT_PREFIX>_<safe name of (model.name + model_ext)>
    filename = utils.get_safe_name(model.name + model_ext)
    return root_path / f"{MODEL_ARTIFACT_PREFIX}_{filename}"