# Copyright 2021 The IREE Authors # # Licensed under the Apache License v2.0 with LLVM Exceptions. # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception # This tool handles mirroring tflite testing files from their source to the # the iree-model-artifacts test bucket. This avoids taking dependency on # external test files that may change or no longer be available. # # To update all files: # python update_tflite_models.py --file all # # To update a specific file: # python update_tflite_models.py --file posenet_i8_input.jpg # # Note you must have write permission to the iree-model-artifacts GCS bucket # with local gcloud authentication. from absl import app from absl import flags from google.cloud import storage from google_auth_oauthlib import flow import tempfile import urllib FLAGS = flags.FLAGS flags.DEFINE_string("file", "", "file to update") file_dict = dict( { "mobilenet_v1.tflite": "https://tfhub.dev/tensorflow/lite-model/mobilenet_v1_1.0_160/1/default/1?lite-format=tflite", "posenet_i8.tflite": "https://tfhub.dev/google/lite-model/movenet/singlepose/lightning/tflite/int8/4?lite-format=tflite", "posenet_i8_input.jpg": "https://github.com/tensorflow/examples/raw/master/lite/examples/pose_estimation/raspberry_pi/test_data/image3.jpeg", } ) BUCKET_NAME = "iree-model-artifacts" FOLDER_NAME = "tflite-integration-tests" def upload_model(source, destination, tmpfile): """Uploads a file to the bucket.""" urllib.request.urlretrieve(source, tmpfile) storage_client = storage.Client() bucket = storage_client.get_bucket(BUCKET_NAME) blob = bucket.blob("/".join([FOLDER_NAME, destination])) blob.upload_from_filename(tmpfile) def main(argv): tf = tempfile.NamedTemporaryFile() items = file_dict.items() if FLAGS.file in file_dict: items = [(FLAGS.file, file_dict[FLAGS.file])] elif FLAGS.file != "all": print("Unknown file to upload: ", '"' + FLAGS.file + '"') exit() for dst, src in items: upload_model(src, dst, tf.name) if __name__ == "__main__": app.run(main)