# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
"""Fastrun persistent cache backed by redis, a local file, or memory.

When executed as a script, provides a ``clean`` sub-command that deletes the
current user's cache entries.
"""
import argparse
import contextlib
import getpass
import os
import sys
import urllib.parse

import filelock

from ..core._imperative_rt import PersistentCache as _PersistentCache
from ..logger import get_logger
from ..version import __version__, git_version


class PersistentCacheOnServer(_PersistentCache):
    """Persistent cache whose storage backends are chosen from env variables.

    Backends are registered via ``add_config`` in this order:

    1. ``redis`` -- only when ``MGE_FASTRUN_CACHE_URL`` is set and valid;
    2. ``in-file`` -- ``cache.bin`` inside :meth:`get_cache_dir`;
    3. ``in-memory`` -- always registered.

    ``MGE_FASTRUN_CACHE_TYPE`` narrows the choice: ``"FILE"`` skips redis and
    ``"MEMORY"`` skips both redis and the file cache.

    NOTE(review): how the base class picks among the registered configs is not
    visible here; presumably it uses the first one that initialises
    successfully -- confirm in ``_PersistentCache``.
    """

    def __init__(self):
        super().__init__()
        cache_type = os.getenv("MGE_FASTRUN_CACHE_TYPE")
        if cache_type not in ("FILE", "MEMORY"):
            try:
                redis_config = self.get_redis_config()
            except Exception as exc:
                # A bad or unsupported URL must not break the program: log it
                # and fall through to the in-file cache registered below.
                get_logger().error(
                    "failed to connect to cache server {!r}; try fallback to "
                    "in-file cache".format(exc)
                )
            else:
                if redis_config is not None:
                    self.add_config(
                        "redis",
                        redis_config,
                        "fastrun use redis cache",
                        "failed to connect to cache server",
                    )
        if cache_type != "MEMORY":
            path = self.get_cache_file(self.get_cache_dir())
            self.add_config(
                "in-file",
                {"path": path},
                "fastrun use in-file cache in {}".format(path),
                "failed to create cache file in {}".format(path),
            )
        self.add_config(
            "in-memory",
            {},
            "fastrun use in-memory cache",
            "failed to create in-memory cache",
        )

    def get_cache_dir(self):
        """Return the cache directory, creating it if necessary.

        Uses ``MGE_FASTRUN_CACHE_DIR`` when set and non-empty. Otherwise uses
        ``persistent_cache`` under the MegEngine home directory.
        """
        cache_dir = os.getenv("MGE_FASTRUN_CACHE_DIR")
        if not cache_dir:
            # Imported lazily; only needed when no explicit dir is configured.
            from ..hub.hub import _get_megengine_home

            cache_dir = os.path.expanduser(
                os.path.join(_get_megengine_home(), "persistent_cache")
            )
        os.makedirs(cache_dir, exist_ok=True)
        return cache_dir

    def get_cache_file(self, cache_dir):
        """Return the path of ``cache.bin`` in *cache_dir*, ensuring it exists."""
        cache_file = os.path.join(cache_dir, "cache.bin")
        # Append mode creates the file if missing without truncating it.
        with open(cache_file, "a"):
            pass
        return cache_file

    @contextlib.contextmanager
    def lock_cache_file(self, cache_dir):
        """Hold a ``filelock`` lock on ``cache.lock`` in *cache_dir* while active."""
        lock_file = os.path.join(cache_dir, "cache.lock")
        with filelock.FileLock(lock_file):
            yield

    def get_redis_config(self):
        """Build the redis backend config from ``MGE_FASTRUN_CACHE_URL``.

        Accepted URL forms:

        * ``redis://[:password@]host:port`` -- TCP connection;
        * ``redis+socket://[:password@]/path/to/socket`` -- unix socket.

        Returns:
            ``None`` when the variable is unset. Otherwise a dict with either
            ``hostname``/``port`` or ``unixsocket``, an optional ``password``,
            and a ``prefix`` derived from the current user, MegEngine version
            and git version.

        Raises:
            RuntimeError: on Windows, where the redis cache is untested.
            ValueError: if the URL is malformed or uses an unsupported scheme.

        Explicit exceptions are used instead of ``assert`` because asserts are
        stripped under ``python -O``. That would skip validation and, for an
        unknown scheme, leave ``config`` unbound.
        """
        url = os.getenv("MGE_FASTRUN_CACHE_URL")
        if url is None:
            return None
        if sys.platform == "win32":
            raise RuntimeError("redis cache on windows not tested")
        prefix = "mgbcache:{}:MGB{}:GIT:{}".format(
            getpass.getuser(), __version__, git_version
        )
        parse_result = urllib.parse.urlparse(url)
        if parse_result.username:
            raise ValueError("redis conn with username unsupported")
        if parse_result.scheme == "redis":
            if not (parse_result.hostname and parse_result.port):
                raise ValueError("invalid url")
            if parse_result.path:
                raise ValueError("invalid url: redis url must not contain a path")
            config = {
                "hostname": parse_result.hostname,
                "port": str(parse_result.port),
            }
        elif parse_result.scheme == "redis+socket":
            if parse_result.hostname or parse_result.port:
                raise ValueError(
                    "invalid url: redis+socket url must not contain host or port"
                )
            if not parse_result.path:
                raise ValueError(
                    "invalid url: redis+socket url must contain a socket path"
                )
            config = {
                "unixsocket": parse_result.path,
            }
        else:
            raise ValueError("unsupported scheme")
        if parse_result.password is not None:
            config["password"] = parse_result.password
        config["prefix"] = prefix
        return config

    def flush(self):
        """Flush the in-file cache under the cache-dir file lock.

        The lock serialises writers that share the same ``cache.bin``. For any
        other active config (or none) this method does nothing.

        NOTE(review): assumes the base class exposes the active backend as
        ``self.config`` with a ``type`` attribute -- confirm in
        ``_PersistentCache``.
        """
        if self.config is not None and self.config.type == "in-file":
            with self.lock_cache_file(self.get_cache_dir()):
                super().flush()


def _clean():
    """Delete this user's cache entries and report how many were removed."""
    nr_del = PersistentCacheOnServer().clean()
    # ``clean()`` may return None, in which case nothing is printed.
    if nr_del is not None:
        print("{} cache entries deleted".format(nr_del))


def main():
    """Command-line entry point: dispatch to the selected sub-command."""
    parser = argparse.ArgumentParser(description="manage persistent cache")
    subp = parser.add_subparsers(description="action to be performed", dest="cmd")
    subp.required = True
    subp_clean = subp.add_parser("clean", help="clean all the cache of current user")
    subp_clean.set_defaults(action=_clean)
    args = parser.parse_args()
    args.action()


if __name__ == "__main__":
    main()