Hub improvements (#26723)

Summary:
Resubmit of https://github.com/pytorch/pytorch/pull/25980.
Our old serialization format was tar (for example, `resnet18-5c106cde.pth` is in this format), so let's only support automatic unzipping when checkpoints are zipfiles.
We could still make it work with tarfiles, but let's delay that until there is a concrete ask.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/26723

Differential Revision: D17551795

Pulled By: ailzhang

fbshipit-source-id: 00b4e7621f1e753ca9aa07b1fe356278c6693a1e
This commit is contained in:
Ailing Zhang 2019-09-25 08:20:20 -07:00 committed by Facebook Github Bot
parent 61dd485b3a
commit 0f1fbc0eb2
3 changed files with 119 additions and 69 deletions

View File

@ -89,6 +89,10 @@ show docstring and examples through ``torch.hub.help()`` and load the pre-traine
.. autofunction:: load
.. autofunction:: download_url_to_file
.. autofunction:: load_state_dict_from_url
Running a loaded model:
^^^^^^^^^^^^^^^^^^^^^^^

View File

@ -14,7 +14,7 @@ from torch.utils.checkpoint import checkpoint, checkpoint_sequential
import torch.hub as hub
from torch.autograd._functions.utils import prepare_onnx_paddings
from torch.autograd._functions.utils import check_onnx_broadcast
from common_utils import skipIfRocm, load_tests
from common_utils import skipIfRocm, load_tests, IS_SANDCASTLE
# load_tests from common_utils is used to automatically filter tests for
# sharding on sandcastle. This line silences flake warnings
@ -511,50 +511,64 @@ class TestONNXUtils(TestCase):
try_check_onnx_broadcast(dims1, dims2, True, False)
def sum_of_model_parameters(model):
def sum_of_state_dict(state_dict):
s = 0
for p in model.parameters():
s += p.sum()
for _, v in state_dict.items():
s += v.sum()
return s
SUM_OF_PRETRAINED_RESNET18_PARAMS = -12703.992365
SUM_OF_HUB_EXAMPLE = 431080
TORCHHUB_EXAMPLE_RELEASE_URL = 'https://github.com/ailzhang/torchhub_example/releases/download/0.1/mnist_init_ones'
@unittest.skipIf(IS_SANDCASTLE, 'Sandcastle cannot ping external')
class TestHub(TestCase):
@classmethod
def setUpClass(cls):
# Only run this check ONCE before all tests start.
# - If torchvision is imported before all tests start, e.g. we might find _C.so
# which doesn't exist in downloaded zip but in the installed wheel.
# - After the first test is run, torchvision is already in sys.modules due to
# Python cache as we run all hub tests in the same python process.
if 'torchvision' in sys.modules:
raise RuntimeError('TestHub must start without torchvision imported')
def test_load_from_github(self):
hub_model = hub.load(
'pytorch/vision',
'resnet18',
'ailzhang/torchhub_example',
'mnist',
pretrained=True,
progress=False)
self.assertEqual(sum_of_model_parameters(hub_model),
SUM_OF_PRETRAINED_RESNET18_PARAMS)
verbose=False)
self.assertEqual(sum_of_state_dict(hub_model.state_dict()),
SUM_OF_HUB_EXAMPLE)
def test_set_dir(self):
temp_dir = tempfile.gettempdir()
hub.set_dir(temp_dir)
hub_model = hub.load(
'pytorch/vision',
'resnet18',
'ailzhang/torchhub_example',
'mnist',
pretrained=True,
progress=False)
self.assertEqual(sum_of_model_parameters(hub_model),
SUM_OF_PRETRAINED_RESNET18_PARAMS)
assert os.path.exists(temp_dir + '/pytorch_vision_master')
shutil.rmtree(temp_dir + '/pytorch_vision_master')
verbose=False)
self.assertEqual(sum_of_state_dict(hub_model.state_dict()),
SUM_OF_HUB_EXAMPLE)
assert os.path.exists(temp_dir + '/ailzhang_torchhub_example_master')
shutil.rmtree(temp_dir + '/ailzhang_torchhub_example_master')
def test_list_entrypoints(self):
entry_lists = hub.list('pytorch/vision', force_reload=True)
self.assertObjectIn('resnet18', entry_lists)
entry_lists = hub.list('ailzhang/torchhub_example', force_reload=True)
self.assertObjectIn('mnist', entry_lists)
def test_download_url_to_file(self):
temp_file = os.path.join(tempfile.gettempdir(), 'temp')
hub.download_url_to_file(TORCHHUB_EXAMPLE_RELEASE_URL, temp_file, progress=False)
loaded_state = torch.load(temp_file)
self.assertEqual(sum_of_state_dict(loaded_state),
SUM_OF_HUB_EXAMPLE)
def test_load_state_dict_from_url(self):
loaded_state = hub.load_state_dict_from_url(TORCHHUB_EXAMPLE_RELEASE_URL)
self.assertEqual(sum_of_state_dict(loaded_state),
SUM_OF_HUB_EXAMPLE)
def test_load_zip_checkpoint(self):
hub_model = hub.load(
'ailzhang/torchhub_example',
'mnist_zip',
pretrained=True,
verbose=False)
self.assertEqual(sum_of_state_dict(hub_model.state_dict()),
SUM_OF_HUB_EXAMPLE)
if __name__ == '__main__':
run_tests()

View File

@ -91,19 +91,6 @@ def _git_archive_link(repo_owner, repo_name, branch):
return 'https://github.com/{}/{}/archive/{}.zip'.format(repo_owner, repo_name, branch)
def _download_archive_zip(url, filename):
sys.stderr.write('Downloading: \"{}\" to {}\n'.format(url, filename))
# We use a different API for python2 since urllib(2) doesn't recognize the CA
# certificates in older Python
response = urlopen(url)
with open(filename, 'wb') as f:
while True:
data = response.read(READ_DATA_CHUNK)
if len(data) == 0:
break
f.write(data)
def _load_attr_from_module(module, func_name):
# Check if callable is defined in the module
if func_name not in dir(module):
@ -142,7 +129,7 @@ def _parse_repo_info(github):
return repo_owner, repo_name, branch
def _get_cache_or_reload(github, force_reload):
def _get_cache_or_reload(github, force_reload, verbose=True):
# Parse github repo information
repo_owner, repo_name, branch = _parse_repo_info(github)
@ -155,13 +142,15 @@ def _get_cache_or_reload(github, force_reload):
use_cache = (not force_reload) and os.path.exists(repo_dir)
if use_cache:
sys.stderr.write('Using cache found in {}\n'.format(repo_dir))
if verbose:
sys.stderr.write('Using cache found in {}\n'.format(repo_dir))
else:
cached_file = os.path.join(hub_dir, branch + '.zip')
_remove_if_exists(cached_file)
url = _git_archive_link(repo_owner, repo_name, branch)
_download_archive_zip(url, cached_file)
sys.stderr.write('Downloading: \"{}\" to {}\n'.format(url, cached_file))
download_url_to_file(url, cached_file, progress=False)
with zipfile.ZipFile(cached_file) as cached_zipfile:
extraced_repo_name = cached_zipfile.infolist()[0].filename
@ -255,7 +244,7 @@ def set_dir(d):
Args:
d: path to a local folder to save downloaded models & weights.
d (string): path to a local folder to save downloaded models & weights.
"""
global hub_dir
hub_dir = d
@ -266,10 +255,10 @@ def list(github, force_reload=False):
List all entrypoints available in `github` hubconf.
Args:
github: Required, a string with format "repo_owner/repo_name[:tag_name]" with an optional
github (string): a string with format "repo_owner/repo_name[:tag_name]" with an optional
tag/branch. The default branch is `master` if not specified.
Example: 'pytorch/vision[:hub]'
force_reload: Optional, whether to discard the existing cache and force a fresh download.
force_reload (bool, optional): whether to discard the existing cache and force a fresh download.
Default is `False`.
Returns:
entrypoints: a list of available entrypoint names
@ -280,7 +269,7 @@ def list(github, force_reload=False):
# Setup hub_dir to save downloaded files
_setup_hubdir()
repo_dir = _get_cache_or_reload(github, force_reload)
repo_dir = _get_cache_or_reload(github, force_reload, True)
sys.path.insert(0, repo_dir)
@ -299,11 +288,11 @@ def help(github, model, force_reload=False):
Show the docstring of entrypoint `model`.
Args:
github: Required, a string with format <repo_owner/repo_name[:tag_name]> with an optional
github (string): a string with format <repo_owner/repo_name[:tag_name]> with an optional
tag/branch. The default branch is `master` if not specified.
Example: 'pytorch/vision[:hub]'
model: Required, a string of entrypoint name defined in repo's hubconf.py
force_reload: Optional, whether to discard the existing cache and force a fresh download.
model (string): a string of entrypoint name defined in repo's hubconf.py
force_reload (bool, optional): whether to discard the existing cache and force a fresh download.
Default is `False`.
Example:
>>> print(torch.hub.help('pytorch/vision', 'resnet18', force_reload=True))
@ -311,7 +300,7 @@ def help(github, model, force_reload=False):
# Setup hub_dir to save downloaded files
_setup_hubdir()
repo_dir = _get_cache_or_reload(github, force_reload)
repo_dir = _get_cache_or_reload(github, force_reload, True)
sys.path.insert(0, repo_dir)
@ -333,14 +322,17 @@ def load(github, model, *args, **kwargs):
Load a model from a github repo, with pretrained weights.
Args:
github: Required, a string with format "repo_owner/repo_name[:tag_name]" with an optional
github (string): a string with format "repo_owner/repo_name[:tag_name]" with an optional
tag/branch. The default branch is `master` if not specified.
Example: 'pytorch/vision[:hub]'
model: Required, a string of entrypoint name defined in repo's hubconf.py
*args: Optional, the corresponding args for callable `model`.
force_reload: Optional, whether to force a fresh download of github repo unconditionally.
model (string): a string of entrypoint name defined in repo's hubconf.py
*args (optional): the corresponding args for callable `model`.
force_reload (bool, optional): whether to force a fresh download of github repo unconditionally.
Default is `False`.
**kwargs: Optional, the corresponding kwargs for callable `model`.
verbose (bool, optional): If False, mute messages about hitting local caches. Note that the message
about the first download cannot be muted.
Default is `True`.
**kwargs (optional): the corresponding kwargs for callable `model`.
Returns:
a single model with corresponding pretrained weights.
@ -353,8 +345,10 @@ def load(github, model, *args, **kwargs):
force_reload = kwargs.get('force_reload', False)
kwargs.pop('force_reload', None)
verbose = kwargs.get('verbose', True)
kwargs.pop('verbose', None)
repo_dir = _get_cache_or_reload(github, force_reload)
repo_dir = _get_cache_or_reload(github, force_reload, verbose)
sys.path.insert(0, repo_dir)
@ -369,7 +363,21 @@ def load(github, model, *args, **kwargs):
return model
def _download_url_to_file(url, dst, hash_prefix, progress):
def download_url_to_file(url, dst, hash_prefix=None, progress=True):
r"""Download object at the given URL to a local path.
Args:
url (string): URL of the object to download
dst (string): Full path where object will be saved, e.g. `/tmp/temporary_file`
hash_prefix (string, optional): If not None, the SHA256 downloaded file should start with `hash_prefix`.
Default: None
progress (bool, optional): whether or not to display a progress bar to stderr
Default: True
Example:
>>> torch.hub.download_url_to_file('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth', '/tmp/temporary_file')
"""
file_size = None
# We use a different API for python2 since urllib(2) doesn't recognize the CA
# certificates in older Python
@ -385,6 +393,7 @@ def _download_url_to_file(url, dst, hash_prefix, progress):
# We deliberately save it in a temp file and move it after
# download is complete. This prevents a local working checkpoint
# being overriden by a broken download.
dst = os.path.expanduser(dst)
dst_dir = os.path.dirname(dst)
f = tempfile.NamedTemporaryFile(delete=False, dir=dst_dir)
@ -414,16 +423,20 @@ def _download_url_to_file(url, dst, hash_prefix, progress):
if os.path.exists(f.name):
os.remove(f.name)
def _download_url_to_file(url, dst, hash_prefix=None, progress=True):
    """Deprecated alias of :func:`download_url_to_file`.

    Kept for backward compatibility; emits a deprecation warning and
    forwards all arguments unchanged.

    Args:
        url (string): URL of the object to download.
        dst (string): full path where the object will be saved.
        hash_prefix (string, optional): if not None, the SHA256 of the
            downloaded file should start with this prefix. Default: None
        progress (bool, optional): whether to display a progress bar to
            stderr. Default: True
    """
    # Plain implicit string concatenation (no backslash line continuations)
    # so the warning text does not embed runs of indentation whitespace.
    warnings.warn('torch.hub._download_url_to_file has been renamed to '
                  'torch.hub.download_url_to_file to be a public API, '
                  '_download_url_to_file will be removed after the 1.3 release')
    download_url_to_file(url, dst, hash_prefix, progress)
def load_state_dict_from_url(url, model_dir=None, map_location=None, progress=True):
def load_state_dict_from_url(url, model_dir=None, map_location=None, progress=True, check_hash=False):
r"""Loads the Torch serialized object at the given URL.
If the object is already present in `model_dir`, it's deserialized and
returned. The filename part of the URL should follow the naming convention
``filename-<sha256>.ext`` where ``<sha256>`` is the first eight or more
digits of the SHA256 hash of the contents of the file. The hash is used to
ensure unique names and to verify the contents of the file.
If downloaded file is a zip file, it will be automatically
decompressed.
If the object is already present in `model_dir`, it's deserialized and
returned.
The default value of `model_dir` is ``$TORCH_HOME/checkpoints`` where
environment variable ``$TORCH_HOME`` defaults to ``$XDG_CACHE_HOME/torch``.
``$XDG_CACHE_HOME`` follows the X Desktop Group (XDG) specification of the Linux
@ -433,7 +446,13 @@ def load_state_dict_from_url(url, model_dir=None, map_location=None, progress=Tr
url (string): URL of the object to download
model_dir (string, optional): directory in which to save the object
map_location (optional): a function or a dict specifying how to remap storage locations (see torch.load)
progress (bool, optional): whether or not to display a progress bar to stderr
progress (bool, optional): whether or not to display a progress bar to stderr.
Default: True
check_hash (bool, optional): If True, the filename part of the URL should follow the naming convention
``filename-<sha256>.ext`` where ``<sha256>`` is the first eight or more
digits of the SHA256 hash of the contents of the file. The hash is used to
ensure unique names and to verify the contents of the file.
Default: False
Example:
>>> state_dict = torch.hub.load_state_dict_from_url('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth')
@ -462,6 +481,19 @@ def load_state_dict_from_url(url, model_dir=None, map_location=None, progress=Tr
cached_file = os.path.join(model_dir, filename)
if not os.path.exists(cached_file):
sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
hash_prefix = HASH_REGEX.search(filename).group(1)
_download_url_to_file(url, cached_file, hash_prefix, progress=progress)
hash_prefix = HASH_REGEX.search(filename).group(1) if check_hash else None
download_url_to_file(url, cached_file, hash_prefix, progress=progress)
# Note: extractall() defaults to overwrite file if exists. No need to clean up beforehand.
# We deliberately don't handle tarfile here since our legacy serialization format was in tar.
# E.g. resnet18-5c106cde.pth which is widely used.
if zipfile.is_zipfile(cached_file):
with zipfile.ZipFile(cached_file) as cached_zipfile:
members = cached_zipfile.infolist()
if len(members) != 1:
raise RuntimeError('Only one file(not dir) is allowed in the zipfile')
cached_zipfile.extractall(model_dir)
extraced_name = members[0].filename
cached_file = os.path.join(model_dir, extraced_name)
return torch.load(cached_file, map_location=map_location)