Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-07 00:21:07 +01:00).
Here's the command I used to invoke autopep8 (in parallel!):
git ls-files | grep '\.py$' | xargs -n1 -P`nproc` autopep8 -i
Several rules are ignored in setup.cfg. The goal is to let autopep8
handle everything which it can handle safely, and to disable any rules
which are tricky or controversial to address. We may want to come back
and re-enable some of these rules later, but I'm trying to make this
patch as safe as possible.
Also configures flake8 to match pep8's behavior.
Also configures TravisCI to check the whole project for lint.
110 lines · 3.5 KiB · Python
import errno
import hashlib
import os
import re
import shutil
import sys
import tempfile

import torch

if sys.version_info[0] == 2:
    from urlparse import urlparse
    from urllib2 import urlopen
else:
    from urllib.request import urlopen
    from urllib.parse import urlparse

try:
    from tqdm import tqdm
except ImportError:
    tqdm = None  # defined below


# matches bfd8deac from resnet18-bfd8deac.pth
HASH_REGEX = re.compile(r'-([a-f0-9]*)\.')


def load_url(url, model_dir=None, map_location=None):
    r"""Loads the Torch serialized object at the given URL.

    If the object is already present in `model_dir`, it's deserialized and
    returned. The filename part of the URL should follow the naming convention
    ``filename-<sha256>.ext`` where ``<sha256>`` is the first eight or more
    digits of the SHA256 hash of the contents of the file. The hash is used to
    ensure unique names and to verify the contents of the file.

    The default value of `model_dir` is ``$TORCH_HOME/models`` where
    ``$TORCH_HOME`` defaults to ``~/.torch``. The default directory can be
    overridden with the ``$TORCH_MODEL_ZOO`` environment variable.

    Args:
        url (string): URL of the object to download
        model_dir (string, optional): directory in which to save the object
        map_location (optional): a function, string or dict specifying how to
            remap storage locations; passed unchanged to :func:`torch.load`
            (default: ``None``, i.e. the same behavior as ``torch.load``)

    Returns:
        The object deserialized by :func:`torch.load` from the cached file.

    Raises:
        RuntimeError: if the file is not cached yet and its name has no
            ``-<hex digits>.`` hash suffix, or if the download fails hash
            verification.

    Example:
        >>> state_dict = torch.utils.model_zoo.load_url('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth')

    """
    if model_dir is None:
        torch_home = os.path.expanduser(os.getenv('TORCH_HOME', '~/.torch'))
        model_dir = os.getenv('TORCH_MODEL_ZOO', os.path.join(torch_home, 'models'))

    # Create the cache directory without a check-then-create race: several
    # processes may call this concurrently. An already-existing directory is
    # fine; anything else (including a regular file at that path) is an error.
    try:
        os.makedirs(model_dir)
    except OSError as e:
        if e.errno != errno.EEXIST or not os.path.isdir(model_dir):
            raise

    parts = urlparse(url)
    filename = os.path.basename(parts.path)
    cached_file = os.path.join(model_dir, filename)
    if not os.path.exists(cached_file):
        match = HASH_REGEX.search(filename)
        if match is None:
            # Without a hash in the name there is nothing to verify the
            # download against; fail before touching the network.
            raise RuntimeError(
                'cannot find a "-<sha256 prefix>." hash in filename "{}" '
                '(URL: {})'.format(filename, url))
        sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
        # NOTE: a name like "model-.pth" yields an empty prefix, which matches
        # every digest, i.e. verification is effectively skipped.
        _download_url_to_file(url, cached_file, match.group(1))
    return torch.load(cached_file, map_location=map_location)


def _download_url_to_file(url, dst, hash_prefix):
    """Download ``url`` to ``dst``, checking the SHA256 digest of the contents.

    The response is streamed in 8 KiB chunks into a temporary file while the
    digest is computed. The temporary file is moved to ``dst`` only after the
    hash check passes. The temporary file is removed and the HTTP response is
    closed on every exit path.

    Args:
        url (string): URL to download.
        dst (string): destination path for the downloaded file.
        hash_prefix (string or None): expected leading hex digits of the
            SHA256 digest of the contents. ``None`` skips verification; an
            empty string matches any digest.

    Raises:
        RuntimeError: if the digest does not start with ``hash_prefix``.
    """
    u = urlopen(url)
    try:
        meta = u.info()
        if hasattr(meta, 'getheaders'):
            # Python 2 (mimetools.Message): returns a possibly-empty list.
            content_length = meta.getheaders("Content-Length")
        else:
            # Python 3 (email.message.Message): returns None if absent.
            content_length = meta.get_all("Content-Length")
        # The header is optional; None tells the progress bar the size is unknown.
        file_size = int(content_length[0]) if content_length else None

        f = tempfile.NamedTemporaryFile(delete=False)
        try:
            sha256 = hashlib.sha256()
            with tqdm(total=file_size) as pbar:
                while True:
                    buffer = u.read(8192)
                    if len(buffer) == 0:
                        break
                    f.write(buffer)
                    sha256.update(buffer)
                    pbar.update(len(buffer))

            f.close()
            if hash_prefix is not None:
                digest = sha256.hexdigest()
                if digest[:len(hash_prefix)] != hash_prefix:
                    raise RuntimeError('invalid hash value (expected "{}", got "{}")'
                                       .format(hash_prefix, digest))
            shutil.move(f.name, dst)
        finally:
            f.close()
            # After a successful move the temp path is gone; otherwise drop
            # the partial/unverified download.
            if os.path.exists(f.name):
                os.remove(f.name)
    finally:
        # Release the network connection even if the download failed.
        u.close()


if tqdm is None:
    # fake tqdm if it's not installed
    class tqdm(object):
        """Minimal stand-in for ``tqdm.tqdm`` used when tqdm is not installed.

        Implements only what this module uses: a ``total`` argument,
        ``update(n)`` and use as a context manager. Progress is written to
        stderr, redrawing the current line via a carriage return.
        """

        def __init__(self, total=None):
            # total: expected number of units; None or 0 means the size is
            # unknown, so a running count is shown instead of a percentage.
            self.total = total
            self.n = 0

        def update(self, n):
            """Advance the counter by ``n`` and redraw the progress line."""
            self.n += n
            if self.total:
                sys.stderr.write("\r{0:.1f}%".format(100 * self.n / float(self.total)))
            else:
                # A percentage is undefined without a positive total.
                sys.stderr.write("\r{0:d} bytes".format(self.n))
            sys.stderr.flush()

        def __enter__(self):
            return self

        def __exit__(self, exc_type, exc_val, exc_tb):
            # Finish the progress line so later output starts on a new line.
            sys.stderr.write('\n')