# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import functools
import hashlib
import os
import sys
import types
from typing import Any, List, Optional
from urllib.parse import urlparse

from megengine.utils.http_download import download_from_url

from ..distributed import is_distributed
from ..logger import get_logger
from ..serialization import load as _mge_load_serialized
from .const import (
    DEFAULT_CACHE_DIR,
    DEFAULT_GIT_HOST,
    DEFAULT_PROTOCOL,
    ENV_MGE_HOME,
    ENV_XDG_CACHE_HOME,
    HUBCONF,
    HUBDEPENDENCY,
)
from .exceptions import InvalidProtocol
from .fetcher import GitHTTPSFetcher, GitSSHFetcher
from .tools import cd, check_module_exists, load_module

logger = get_logger(__name__)


# Maps the user-facing ``protocol`` argument to the fetcher class that
# retrieves the repository.
PROTOCOLS = {
    "HTTPS": GitHTTPSFetcher,
    "SSH": GitSSHFetcher,
}


def _get_megengine_home() -> str:
    r"""Returns the MegEngine home directory.

    The setting complies with the XDG Base Directory Specification: the path
    named by the ``ENV_MGE_HOME`` environment variable is used when it is set;
    otherwise it is ``megengine`` under the directory named by the
    ``ENV_XDG_CACHE_HOME`` environment variable (or ``DEFAULT_CACHE_DIR`` when
    that is unset). A leading ``~`` in the result is expanded.
    """
    megengine_home = os.path.expanduser(
        os.getenv(
            ENV_MGE_HOME,
            os.path.join(os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), "megengine"),
        )
    )
    return megengine_home


def _get_hub_cache_dir() -> str:
    r"""Returns ``<megengine home>/hub``, creating the directory if it is missing."""
    cache_dir = os.path.expanduser(os.path.join(_get_megengine_home(), "hub"))
    os.makedirs(cache_dir, exist_ok=True)
    return cache_dir


def _get_repo(
    git_host: str,
    repo_info: str,
    use_cache: bool = False,
    commit: Optional[str] = None,
    protocol: str = DEFAULT_PROTOCOL,
) -> str:
    r"""Fetches a git repo into the hub cache directory.

    Args:
        git_host: host address of git repo. Eg: github.com
        repo_info: a string with format ``"repo_owner/repo_name[:tag_name/:branch_name]"``.
        use_cache: whether to use locally cached code or completely re-fetch.
        commit: commit id on github or gitlab.
        protocol: key of ``PROTOCOLS`` selecting the fetcher (HTTPS or SSH).

    Returns:
        the fetched repo directory, joined onto the hub cache directory.

    Raises:
        InvalidProtocol: if ``protocol`` is not one of the keys of ``PROTOCOLS``.
    """
    if protocol not in PROTOCOLS:
        raise InvalidProtocol(
            "Invalid protocol, the value should be one of {}.".format(
                ", ".join(PROTOCOLS.keys())
            )
        )
    # Create the directory here rather than relying on the caller: ``cd`` below
    # needs it to exist.
    cache_dir = _get_hub_cache_dir()
    # The fetcher runs with the cache dir as its working directory.
    with cd(cache_dir):
        fetcher = PROTOCOLS[protocol]
        repo_dir = fetcher.fetch(git_host, repo_info, use_cache, commit)
        return os.path.join(cache_dir, repo_dir)


def _check_dependencies(module: types.ModuleType) -> None:
    r"""Verifies that every dependency declared by a hub module can be found.

    Dependencies are read from the module attribute named ``HUBDEPENDENCY``;
    a missing or empty attribute means there is nothing to check.

    Raises:
        RuntimeError: naming every dependency for which ``check_module_exists``
            returned a false value.
    """
    dependencies = getattr(module, HUBDEPENDENCY, None)
    if not dependencies:
        return
    missing_deps = [m for m in dependencies if not check_module_exists(m)]
    if missing_deps:
        raise RuntimeError("Missing dependencies: {}".format(", ".join(missing_deps)))


def _init_hub(
    repo_info: str,
    git_host: str = DEFAULT_GIT_HOST,
    use_cache: bool = True,
    commit: Optional[str] = None,
    protocol: str = DEFAULT_PROTOCOL,
) -> types.ModuleType:
    r"""Imports hubmodule like python import.

    Args:
        repo_info: a string with format ``"repo_owner/repo_name[:tag_name/:branch_name]"``
            with an optional tag/branch. The default branch is ``master`` if not
            specified. Eg: ``"brain_sdk/MegBrain[:hub]"``
        git_host: host address of git repo. Eg: github.com
            Defaults to ``DEFAULT_GIT_HOST``.
        use_cache: whether to use locally cached code or completely re-fetch.
        commit: commit id on github or gitlab.
        protocol: which protocol to use to get the repo, and HTTPS protocol only
            supports public repo on github. The value should be one of HTTPS, SSH.

    Returns:
        a python module loaded from the repo's ``HUBCONF`` file.
    """
    absolute_repo_dir = _get_repo(
        git_host, repo_info, use_cache=use_cache, commit=commit, protocol=protocol
    )
    # Make the repo importable only while its hubconf is being loaded, so the
    # hubconf can import sibling modules of the repo.
    sys.path.insert(0, absolute_repo_dir)
    try:
        hubmodule = load_module(HUBCONF, os.path.join(absolute_repo_dir, HUBCONF))
    finally:
        # Always undo the sys.path change, even if loading the hubconf fails.
        sys.path.remove(absolute_repo_dir)
    return hubmodule


# Public alias of ``_init_hub``; ``functools.wraps`` copies its docstring and
# metadata onto this function.
@functools.wraps(_init_hub)
def import_module(*args, **kwargs):
    return _init_hub(*args, **kwargs)


def _get_entry(hubmodule: types.ModuleType, entry: str):
    r"""Returns the callable attribute ``entry`` of ``hubmodule``.

    Raises:
        RuntimeError: if ``hubmodule`` has no attribute ``entry`` or it is not callable.
    """
    entry_fn = getattr(hubmodule, entry, None)
    if not callable(entry_fn):
        raise RuntimeError("Cannot find callable {} in hubconf.py".format(entry))
    return entry_fn


def list(
    repo_info: str,
    git_host: str = DEFAULT_GIT_HOST,
    use_cache: bool = True,
    commit: Optional[str] = None,
    protocol: str = DEFAULT_PROTOCOL,
) -> List[str]:
    r"""Lists all entrypoints available in repo hubconf.

    Args:
        repo_info: a string with format ``"repo_owner/repo_name[:tag_name/:branch_name]"``
            with an optional tag/branch. The default branch is ``master`` if not
            specified. Eg: ``"brain_sdk/MegBrain[:hub]"``
        git_host: host address of git repo. Eg: github.com
        use_cache: whether to use locally cached code or completely re-fetch.
        commit: commit id on github or gitlab.
        protocol: which protocol to use to get the repo, and HTTPS protocol only
            supports public repo on github. The value should be one of HTTPS, SSH.

    Returns:
        names in ``dir()`` of the hubconf module that do not start with ``__``
        and refer to callables. Note this also includes callables the hubconf
        merely imports.
    """
    hubmodule = _init_hub(repo_info, git_host, use_cache, commit, protocol)
    return [
        name
        for name in dir(hubmodule)
        if not name.startswith("__") and callable(getattr(hubmodule, name))
    ]


def load(
    repo_info: str,
    entry: str,
    *args,
    git_host: str = DEFAULT_GIT_HOST,
    use_cache: bool = True,
    commit: Optional[str] = None,
    protocol: str = DEFAULT_PROTOCOL,
    **kwargs
) -> Any:
    r"""Loads model from github or gitlab repo, with pretrained weights.

    Args:
        repo_info: a string with format ``"repo_owner/repo_name[:tag_name/:branch_name]"``
            with an optional tag/branch. The default branch is ``master`` if not
            specified. Eg: ``"brain_sdk/MegBrain[:hub]"``
        entry: an entrypoint defined in hubconf.
        *args: positional arguments forwarded to the entrypoint.
        git_host: host address of git repo. Eg: github.com
        use_cache: whether to use locally cached code or completely re-fetch.
        commit: commit id on github or gitlab.
        protocol: which protocol to use to get the repo, and HTTPS protocol only
            supports public repo on github. The value should be one of HTTPS, SSH.
        **kwargs: keyword arguments forwarded to the entrypoint.

    Returns:
        whatever the entrypoint returns, e.g. a single model with corresponding
        pretrained weights.

    Raises:
        RuntimeError: if ``entry`` is not a callable in hubconf, or if a
            dependency declared by the hubconf is missing.
    """
    hubmodule = _init_hub(repo_info, git_host, use_cache, commit, protocol)
    entry_fn = _get_entry(hubmodule, entry)
    _check_dependencies(hubmodule)
    return entry_fn(*args, **kwargs)


def help(
    repo_info: str,
    entry: str,
    git_host: str = DEFAULT_GIT_HOST,
    use_cache: bool = True,
    commit: Optional[str] = None,
    protocol: str = DEFAULT_PROTOCOL,
) -> Optional[str]:
    r"""Returns the docstring of entrypoint ``entry``.

    It follows these steps:

    1. Pull the repo code specified by git and repo_info.
    2. Load the entry defined in repo's hubconf.py
    3. Return docstring of function entry.

    Args:
        repo_info: a string with format ``"repo_owner/repo_name[:tag_name/:branch_name]"``
            with an optional tag/branch. The default branch is ``master`` if not
            specified. Eg: ``"brain_sdk/MegBrain[:hub]"``
        entry: an entrypoint defined in hubconf.py
        git_host: host address of git repo. Eg: github.com
        use_cache: whether to use locally cached code or completely re-fetch.
        commit: commit id on github or gitlab.
        protocol: which protocol to use to get the repo, and HTTPS protocol only
            supports public repo on github. The value should be one of HTTPS, SSH.

    Returns:
        docstring of entrypoint ``entry`` (``None`` if it has none).

    Raises:
        RuntimeError: if ``entry`` is not a callable in hubconf.
    """
    hubmodule = _init_hub(repo_info, git_host, use_cache, commit, protocol)
    return _get_entry(hubmodule, entry).__doc__


def load_serialized_obj_from_url(url: str, model_dir: Optional[str] = None) -> Any:
    r"""Loads MegEngine serialized object from the given URL.

    If the object is already present in ``model_dir``, it's deserialized and
    returned. If no ``model_dir`` is specified, it will be ``MGE_HOME/serialized``.

    Args:
        url: url to serialized object.
        model_dir: dir to cache target serialized file. Created if missing.

    Returns:
        loaded object.
    """
    if model_dir is None:
        model_dir = os.path.join(_get_megengine_home(), "serialized")
    os.makedirs(model_dir, exist_ok=True)

    filename = os.path.basename(urlparse(url).path)
    # Use a short hash of the full URL as prefix, so that files sharing a
    # basename but coming from different URLs do not collide in the cache.
    digest = hashlib.sha256(url.encode()).hexdigest()[:6]
    cached_file = os.path.join(model_dir, digest + "_" + filename)
    logger.info(
        "load_serialized_obj_from_url: download to or using cached %s", cached_file
    )
    if not os.path.exists(cached_file):
        if is_distributed():
            logger.warning(
                "Downloading serialized object in DISTRIBUTED mode\n"
                " File may be downloaded multiple times. We recommend\n"
                " users to download in single process first."
            )
        download_from_url(url, cached_file)

    return _mge_load_serialized(cached_file)


class pretrained:
    r"""Decorator which helps to download pretrained weights from the given url.

    The weights are fetched with :func:`load_serialized_obj_from_url`. For
    example, we can decorate a resnet18 function as follows

    .. code-block::

        @hub.pretrained("https://url/to/pretrained_resnet18.pkl")
        def resnet18(**kwargs):
            ...

        model = resnet18(pretrained=True)

    Returns:
        When decorated function is called with ``pretrained=True``, MegEngine
        will automatically download and fill the returned model with pretrained
        weights. The decorated function only accepts keyword arguments besides
        ``pretrained``.
    """

    def __init__(self, url):
        self.url = url

    def __call__(self, func):
        @functools.wraps(func)
        def pretrained_model_func(
            pretrained=False, **kwargs
        ):  # pylint: disable=redefined-outer-name
            # ``pretrained`` intentionally shadows this class's name: it is the
            # public keyword callers pass to the decorated function.
            model = func(**kwargs)
            if pretrained:
                weights = load_serialized_obj_from_url(self.url)
                model.load_state_dict(weights)
            return model

        return pretrained_model_func


__all__ = [
    "list",
    "load",
    "help",
    "load_serialized_obj_from_url",
    "pretrained",
    "import_module",
]