import importlib
import os
import shutil
import sys
import zipfile

if sys.version_info[0] == 2:
    from urlparse import urlparse
    from urllib2 import urlopen  # noqa f811
else:
    from urllib.request import urlopen
    from urllib.parse import urlparse  # noqa: F401

MASTER_BRANCH = 'master'
ENV_TORCH_HUB_DIR = 'TORCH_HUB_DIR'
DEFAULT_TORCH_HUB_DIR = '~/.torch/hub'
READ_DATA_CHUNK = 8192
# Name of the entrypoint module that every hub repo provides at its root.
MODULE_HUBCONF = 'hubconf'
hub_dir = None


def _check_module_exists(name):
    """Return True if module `name` can be located, False otherwise.

    On Python 3.4+, looking up a dotted name imports its parent package;
    a missing parent is reported as False rather than raising.
    """
    if sys.version_info >= (3, 4):
        # Bind only `find_spec` locally: `import importlib.util` here would
        # make `importlib` a function-local name and break the 3.3 branch.
        from importlib.util import find_spec
        try:
            return find_spec(name) is not None
        except ImportError:
            return False
    elif sys.version_info >= (3, 3):
        # Special case for python3.3: `find_loader` is a function of the
        # (module-level) importlib package, not an importable submodule.
        return importlib.find_loader(name) is not None
    else:
        # NB: imp doesn't handle hierarchical module names (names contains dots).
        try:
            import imp
            imp.find_module(name)
        except ImportError:
            return False
        return True


def _remove_if_exists(path):
    """Delete `path` if it exists: files are removed, directories removed recursively."""
    if os.path.exists(path):
        if os.path.isfile(path):
            os.remove(path)
        else:
            shutil.rmtree(path)


def _git_archive_link(repo, branch):
    """Return the GitHub URL of the zip archive of `repo` ("owner/name") at `branch`."""
    return 'https://github.com/' + repo + '/archive/' + branch + '.zip'


def _download_url_to_file(url, filename):
    """Stream the content at `url` into the local file `filename`, chunk by chunk."""
    sys.stderr.write('Downloading: \"{}\" to {}\n'.format(url, filename))
    response = urlopen(url)
    try:
        with open(filename, 'wb') as f:
            shutil.copyfileobj(response, f, READ_DATA_CHUNK)
    finally:
        # Release the underlying connection even if writing fails.
        response.close()


def _load_attr_from_module(module_name, func_name):
    """Import `module_name` and return its attribute `func_name`, or None if absent."""
    m = importlib.import_module(module_name)
    return getattr(m, func_name, None)


def set_dir(d):
    r"""
    Optionally set hub_dir to a local dir to save downloaded models & weights.

    If this function is not called, the env variable `TORCH_HUB_DIR` is
    consulted first, and `~/.torch/hub` is created and used as a fallback.

    Args:
        d: path to a local folder to save downloaded models & weights.
    """
    global hub_dir
    hub_dir = d


def _setup_hub_dir():
    """Resolve the global hub_dir (set_dir > TORCH_HUB_DIR > default), expand '~', create it."""
    global hub_dir
    if hub_dir is None:
        hub_dir = os.getenv(ENV_TORCH_HUB_DIR, DEFAULT_TORCH_HUB_DIR)
    # expanduser only touches a leading '~', so calling it unconditionally is safe.
    hub_dir = os.path.expanduser(hub_dir)
    if not os.path.exists(hub_dir):
        os.makedirs(hub_dir)


def _parse_repo_info(github):
    """Split "repo_owner/repo_name[:tag_name]" into (repo_owner, repo_name, branch).

    Raises:
        ValueError: if `github` does not match that format.
    """
    error = ValueError('Invalid github argument "{}": expected format '
                       '"repo_owner/repo_name[:tag_name]"'.format(github))
    parts = github.split(':')
    if len(parts) > 2:
        raise error
    branch = parts[1] if len(parts) == 2 else MASTER_BRANCH
    owner_and_name = parts[0].split('/')
    if len(owner_and_name) != 2 or not all(owner_and_name) or not branch:
        raise error
    repo_owner, repo_name = owner_and_name
    return repo_owner, repo_name, branch


def _fetch_repo(repo_info, branch, repo_dir):
    """Download `repo_info` at `branch` from GitHub and unpack it into `repo_dir`."""
    url = _git_archive_link(repo_info, branch)
    cached_file = os.path.join(hub_dir, branch + '.zip')
    _remove_if_exists(cached_file)
    _download_url_to_file(url, cached_file)

    # Github uses '{repo_name}-{branch_name}' as folder name which is not importable,
    # and renames folder repo-v1.x.x to repo-1.x.x, so read the real top-level
    # folder name from the archive and rename it to `repo_dir` afterwards.
    with zipfile.ZipFile(cached_file) as cached_zipfile:
        extracted_repo_name = cached_zipfile.infolist()[0].filename
        extracted_repo = os.path.join(hub_dir, extracted_repo_name)
        _remove_if_exists(extracted_repo)
        cached_zipfile.extractall(hub_dir)
    # The archive is closed before deletion (required on Windows).
    _remove_if_exists(cached_file)

    _remove_if_exists(repo_dir)
    shutil.move(extracted_repo, repo_dir)  # rename the repo


def load(github, model, force_reload=False, *args, **kwargs):
    r"""
    Load a model from a github repo, with pretrained weights.

    The repo is cached under `<hub_dir>/<repo_owner>_<repo_name>_<branch>` and
    reused on later calls unless `force_reload` is set.

    Args:
        github: Required, a string with format "repo_owner/repo_name[:tag_name]" with an optional
            tag/branch. The default branch is `master` if not specified.
            Example: 'pytorch/vision[:hub]'
        model: Required, a string of entrypoint name defined in repo's hubconf.py
        force_reload: Optional, whether to discard the existing cache and force a fresh download.
            Default is `False`.
        *args: Optional, the corresponding args for callable `model`.
        **kwargs: Optional, the corresponding kwargs for callable `model`.

    Returns:
        the object returned by calling the `model` entrypoint (a model with
        corresponding pretrained weights).

    Raises:
        ValueError: if `model` is not a string or `github` is malformed.
        RuntimeError: if dependencies listed in hubconf.py are missing, or
            `model` is not defined in hubconf.py or is not callable.
    """
    if not isinstance(model, str):
        raise ValueError('Invalid input: model should be a string of function name')

    # Validate the repo spec before touching the filesystem or network.
    repo_owner, repo_name, branch = _parse_repo_info(github)
    repo_info = repo_owner + '/' + repo_name

    # Setup hub_dir to save downloaded files
    _setup_hub_dir()

    # The owner is part of the folder name so same-named repos from different
    # owners do not share (and clobber) one cache entry.
    repo_dir = os.path.join(hub_dir, repo_owner + '_' + repo_name + '_' + branch)
    use_cache = (not force_reload) and os.path.exists(repo_dir)

    if use_cache:
        sys.stderr.write('Using cache found in {}\n'.format(repo_dir))
    else:
        _fetch_repo(repo_info, branch, repo_dir)

    # Make Python interpreter aware of the repo; move it to the front without
    # accumulating duplicate sys.path entries across calls.
    if repo_dir in sys.path:
        sys.path.remove(repo_dir)
    sys.path.insert(0, repo_dir)
    # Forget any hubconf imported from a previously loaded repo, otherwise
    # import_module would return that stale module.
    sys.modules.pop(MODULE_HUBCONF, None)
    if hasattr(importlib, 'invalidate_caches'):
        # Newly written files may not be visible to cached path finders.
        importlib.invalidate_caches()

    dependencies = _load_attr_from_module(MODULE_HUBCONF, 'dependencies')
    if dependencies is not None:
        missing_deps = [pkg for pkg in dependencies if not _check_module_exists(pkg)]
        if missing_deps:
            raise RuntimeError('Missing dependencies: {}'.format(', '.join(missing_deps)))

    func = _load_attr_from_module(MODULE_HUBCONF, model)
    if func is None:
        raise RuntimeError('Cannot find callable {} in hubconf'.format(model))

    # Check if func is callable
    if not callable(func):
        raise RuntimeError('{} is not callable'.format(func))

    # Call the function
    return func(*args, **kwargs)