from __future__ import absolute_import, division, print_function, unicode_literals

import errno
import hashlib
import os
import re
import shutil
import sys
import tempfile

import torch

try:
    from requests.utils import urlparse
    from requests import get as urlopen
    requests_available = True
except ImportError:
    requests_available = False
    if sys.version_info[0] == 2:
        from urlparse import urlparse  # noqa f811
        from urllib2 import urlopen  # noqa f811
    else:
        from urllib.request import urlopen
        from urllib.parse import urlparse

try:
    from tqdm import tqdm
except ImportError:
    tqdm = None  # defined below

# matches bfd8deac from resnet18-bfd8deac.pth
HASH_REGEX = re.compile(r'-([a-f0-9]*)\.')


def load_url(url, model_dir=None, map_location=None, progress=True):
    r"""Loads the Torch serialized object at the given URL.

    If the object is already present in `model_dir`, it's deserialized and
    returned. The filename part of the URL should follow the naming convention
    ``filename-<sha256>.ext`` where ``<sha256>`` is the first eight or more
    digits of the SHA256 hash of the contents of the file. The hash is used to
    ensure unique names and to verify the contents of the file. If the
    filename carries no such hash, the file is downloaded without
    verification.

    The default value of `model_dir` is ``$TORCH_HOME/models`` where
    ``$TORCH_HOME`` defaults to ``~/.torch``. The default directory can be
    overridden with the ``$TORCH_MODEL_ZOO`` environment variable.

    Args:
        url (string): URL of the object to download
        model_dir (string, optional): directory in which to save the object
        map_location (optional): a function or a dict specifying how to remap
            storage locations (see torch.load)
        progress (bool, optional): whether or not to display a progress bar
            to stderr

    Returns:
        The object produced by ``torch.load`` on the cached file.

    Raises:
        RuntimeError: if the downloaded content does not match the hash
            embedded in the filename.
        OSError: if `model_dir` cannot be created.

    Example:
        >>> state_dict = torch.utils.model_zoo.load_url('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth')

    """
    if model_dir is None:
        torch_home = os.path.expanduser(os.getenv('TORCH_HOME', '~/.torch'))
        model_dir = os.getenv('TORCH_MODEL_ZOO', os.path.join(torch_home, 'models'))

    try:
        os.makedirs(model_dir)
    except OSError as e:
        if e.errno != errno.EEXIST:
            # Unexpected OSError, re-raise.
            raise
        # Directory already exists, ignore.

    parts = urlparse(url)
    filename = os.path.basename(parts.path)
    cached_file = os.path.join(model_dir, filename)
    if not os.path.exists(cached_file):
        sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
        # A filename without a "-<hash>." component yields no match; pass
        # None so the downloader skips verification instead of crashing on
        # ``None.group``.
        match = HASH_REGEX.search(filename)
        hash_prefix = match.group(1) if match is not None else None
        _download_url_to_file(url, cached_file, hash_prefix, progress=progress)
    return torch.load(cached_file, map_location=map_location)


def _download_url_to_file(url, dst, hash_prefix, progress):
    """Stream `url` into the file `dst`, optionally verifying its SHA256.

    The payload is first written to a temporary file located in the same
    directory as `dst` and only moved into place after the download (and the
    hash check, if any) succeeded, so a failed or interrupted download never
    leaves a partial file at `dst`.

    Args:
        url (string): URL to fetch
        dst (string): destination path; its directory must already exist
        hash_prefix (string or None): expected leading hex digits of the
            SHA256 digest of the content, or None to skip verification
        progress (bool): whether to display a progress bar on stderr

    Raises:
        RuntimeError: if `hash_prefix` is given and the digest of the
            downloaded content does not start with it.
    """
    file_size = None
    if requests_available:
        response = urlopen(url, stream=True)
        # ``headers`` is a mapping: look the key up rather than testing for
        # an attribute of that name (which never exists).
        content_length = response.headers.get("Content-Length")
        if content_length is not None:
            file_size = int(content_length)
        stream = response.raw
    else:
        response = stream = urlopen(url)
        meta = response.info()
        if hasattr(meta, 'getheaders'):
            content_length = meta.getheaders("Content-Length")
        else:
            content_length = meta.get_all("Content-Length")
        if content_length is not None and len(content_length) > 0:
            file_size = int(content_length[0])

    sha256 = hashlib.sha256() if hash_prefix is not None else None
    try:
        # Creating the temporary file next to the destination keeps the final
        # move on a single filesystem (a rename instead of a copy).
        dst_dir = os.path.dirname(os.path.abspath(dst))
        f = tempfile.NamedTemporaryFile(delete=False, dir=dst_dir)
        try:
            with tqdm(total=file_size, disable=not progress) as pbar:
                while True:
                    buffer = stream.read(8192)
                    if len(buffer) == 0:
                        break
                    f.write(buffer)
                    if sha256 is not None:
                        sha256.update(buffer)
                    pbar.update(len(buffer))

            f.close()
            if sha256 is not None:
                digest = sha256.hexdigest()
                if digest[:len(hash_prefix)] != hash_prefix:
                    raise RuntimeError('invalid hash value (expected "{}", got "{}")'
                                       .format(hash_prefix, digest))
            shutil.move(f.name, dst)
        finally:
            f.close()
            # Still present only if the move did not happen (error path).
            if os.path.exists(f.name):
                os.remove(f.name)
    finally:
        # Release the underlying connection / file handle.
        response.close()


if tqdm is None:
    # fake tqdm if it's not installed
    class tqdm(object):
        """Minimal stand-in for ``tqdm.tqdm`` writing progress to stderr.

        Supports only what this module uses: the ``total`` and ``disable``
        keyword arguments, ``update`` and the context-manager protocol.
        """

        def __init__(self, total=None, disable=False):
            self.total = total
            self.disable = disable
            self.n = 0

        def update(self, n):
            """Advance the counter by `n` bytes and redraw the progress line."""
            if self.disable:
                return

            self.n += n
            if not self.total:
                # Unknown (None) or zero total: a percentage cannot be
                # computed, so report the raw byte count instead.
                sys.stderr.write("\r{0:.1f} bytes".format(self.n))
            else:
                sys.stderr.write("\r{0:.1f}%".format(100 * self.n / float(self.total)))
            sys.stderr.flush()

        def __enter__(self):
            return self

        def __exit__(self, exc_type, exc_val, exc_tb):
            if self.disable:
                return

            sys.stderr.write('\n')