mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Hub improvements (#26723)
Summary: Resubmit of https://github.com/pytorch/pytorch/pull/25980. Our old serialization was in tar (like `resnet18-5c106cde.pth` was in this format) so let's only support automatically unzip if checkpoints are zipfiles. We can still manage to get it work with tarfile, but let's delay it when there's an ask. Pull Request resolved: https://github.com/pytorch/pytorch/pull/26723 Differential Revision: D17551795 Pulled By: ailzhang fbshipit-source-id: 00b4e7621f1e753ca9aa07b1fe356278c6693a1e
This commit is contained in:
parent
61dd485b3a
commit
0f1fbc0eb2
|
|
@ -89,6 +89,10 @@ show docstring and examples through ``torch.hub.help()`` and load the pre-traine
|
|||
|
||||
.. autofunction:: load
|
||||
|
||||
.. autofunction:: download_url_to_file
|
||||
|
||||
.. autofunction:: load_state_dict_from_url
|
||||
|
||||
Running a loaded model:
|
||||
^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ from torch.utils.checkpoint import checkpoint, checkpoint_sequential
|
|||
import torch.hub as hub
|
||||
from torch.autograd._functions.utils import prepare_onnx_paddings
|
||||
from torch.autograd._functions.utils import check_onnx_broadcast
|
||||
from common_utils import skipIfRocm, load_tests
|
||||
from common_utils import skipIfRocm, load_tests, IS_SANDCASTLE
|
||||
|
||||
# load_tests from common_utils is used to automatically filter tests for
|
||||
# sharding on sandcastle. This line silences flake warnings
|
||||
|
|
@ -511,50 +511,64 @@ class TestONNXUtils(TestCase):
|
|||
try_check_onnx_broadcast(dims1, dims2, True, False)
|
||||
|
||||
|
||||
def sum_of_model_parameters(model):
|
||||
def sum_of_state_dict(state_dict):
|
||||
s = 0
|
||||
for p in model.parameters():
|
||||
s += p.sum()
|
||||
for _, v in state_dict.items():
|
||||
s += v.sum()
|
||||
return s
|
||||
|
||||
SUM_OF_PRETRAINED_RESNET18_PARAMS = -12703.992365
|
||||
SUM_OF_HUB_EXAMPLE = 431080
|
||||
TORCHHUB_EXAMPLE_RELEASE_URL = 'https://github.com/ailzhang/torchhub_example/releases/download/0.1/mnist_init_ones'
|
||||
|
||||
@unittest.skipIf(IS_SANDCASTLE, 'Sandcastle cannot ping external')
|
||||
class TestHub(TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
# Only run this check ONCE before all tests start.
|
||||
# - If torchvision is imported before all tests start, e.g. we might find _C.so
|
||||
# which doesn't exist in downloaded zip but in the installed wheel.
|
||||
# - After the first test is run, torchvision is already in sys.modules due to
|
||||
# Python cache as we run all hub tests in the same python process.
|
||||
if 'torchvision' in sys.modules:
|
||||
raise RuntimeError('TestHub must start without torchvision imported')
|
||||
|
||||
def test_load_from_github(self):
|
||||
hub_model = hub.load(
|
||||
'pytorch/vision',
|
||||
'resnet18',
|
||||
'ailzhang/torchhub_example',
|
||||
'mnist',
|
||||
pretrained=True,
|
||||
progress=False)
|
||||
self.assertEqual(sum_of_model_parameters(hub_model),
|
||||
SUM_OF_PRETRAINED_RESNET18_PARAMS)
|
||||
verbose=False)
|
||||
self.assertEqual(sum_of_state_dict(hub_model.state_dict()),
|
||||
SUM_OF_HUB_EXAMPLE)
|
||||
|
||||
def test_set_dir(self):
|
||||
temp_dir = tempfile.gettempdir()
|
||||
hub.set_dir(temp_dir)
|
||||
hub_model = hub.load(
|
||||
'pytorch/vision',
|
||||
'resnet18',
|
||||
'ailzhang/torchhub_example',
|
||||
'mnist',
|
||||
pretrained=True,
|
||||
progress=False)
|
||||
self.assertEqual(sum_of_model_parameters(hub_model),
|
||||
SUM_OF_PRETRAINED_RESNET18_PARAMS)
|
||||
assert os.path.exists(temp_dir + '/pytorch_vision_master')
|
||||
shutil.rmtree(temp_dir + '/pytorch_vision_master')
|
||||
verbose=False)
|
||||
self.assertEqual(sum_of_state_dict(hub_model.state_dict()),
|
||||
SUM_OF_HUB_EXAMPLE)
|
||||
assert os.path.exists(temp_dir + '/ailzhang_torchhub_example_master')
|
||||
shutil.rmtree(temp_dir + '/ailzhang_torchhub_example_master')
|
||||
|
||||
def test_list_entrypoints(self):
|
||||
entry_lists = hub.list('pytorch/vision', force_reload=True)
|
||||
self.assertObjectIn('resnet18', entry_lists)
|
||||
entry_lists = hub.list('ailzhang/torchhub_example', force_reload=True)
|
||||
self.assertObjectIn('mnist', entry_lists)
|
||||
|
||||
def test_download_url_to_file(self):
|
||||
temp_file = os.path.join(tempfile.gettempdir(), 'temp')
|
||||
hub.download_url_to_file(TORCHHUB_EXAMPLE_RELEASE_URL, temp_file, progress=False)
|
||||
loaded_state = torch.load(temp_file)
|
||||
self.assertEqual(sum_of_state_dict(loaded_state),
|
||||
SUM_OF_HUB_EXAMPLE)
|
||||
|
||||
def test_load_state_dict_from_url(self):
|
||||
loaded_state = hub.load_state_dict_from_url(TORCHHUB_EXAMPLE_RELEASE_URL)
|
||||
self.assertEqual(sum_of_state_dict(loaded_state),
|
||||
SUM_OF_HUB_EXAMPLE)
|
||||
|
||||
def test_load_zip_checkpoint(self):
|
||||
hub_model = hub.load(
|
||||
'ailzhang/torchhub_example',
|
||||
'mnist_zip',
|
||||
pretrained=True,
|
||||
verbose=False)
|
||||
self.assertEqual(sum_of_state_dict(hub_model.state_dict()),
|
||||
SUM_OF_HUB_EXAMPLE)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
run_tests()
|
||||
|
|
|
|||
112
torch/hub.py
112
torch/hub.py
|
|
@ -91,19 +91,6 @@ def _git_archive_link(repo_owner, repo_name, branch):
|
|||
return 'https://github.com/{}/{}/archive/{}.zip'.format(repo_owner, repo_name, branch)
|
||||
|
||||
|
||||
def _download_archive_zip(url, filename):
|
||||
sys.stderr.write('Downloading: \"{}\" to {}\n'.format(url, filename))
|
||||
# We use a different API for python2 since urllib(2) doesn't recognize the CA
|
||||
# certificates in older Python
|
||||
response = urlopen(url)
|
||||
with open(filename, 'wb') as f:
|
||||
while True:
|
||||
data = response.read(READ_DATA_CHUNK)
|
||||
if len(data) == 0:
|
||||
break
|
||||
f.write(data)
|
||||
|
||||
|
||||
def _load_attr_from_module(module, func_name):
|
||||
# Check if callable is defined in the module
|
||||
if func_name not in dir(module):
|
||||
|
|
@ -142,7 +129,7 @@ def _parse_repo_info(github):
|
|||
return repo_owner, repo_name, branch
|
||||
|
||||
|
||||
def _get_cache_or_reload(github, force_reload):
|
||||
def _get_cache_or_reload(github, force_reload, verbose=True):
|
||||
# Parse github repo information
|
||||
repo_owner, repo_name, branch = _parse_repo_info(github)
|
||||
|
||||
|
|
@ -155,13 +142,15 @@ def _get_cache_or_reload(github, force_reload):
|
|||
use_cache = (not force_reload) and os.path.exists(repo_dir)
|
||||
|
||||
if use_cache:
|
||||
sys.stderr.write('Using cache found in {}\n'.format(repo_dir))
|
||||
if verbose:
|
||||
sys.stderr.write('Using cache found in {}\n'.format(repo_dir))
|
||||
else:
|
||||
cached_file = os.path.join(hub_dir, branch + '.zip')
|
||||
_remove_if_exists(cached_file)
|
||||
|
||||
url = _git_archive_link(repo_owner, repo_name, branch)
|
||||
_download_archive_zip(url, cached_file)
|
||||
sys.stderr.write('Downloading: \"{}\" to {}\n'.format(url, cached_file))
|
||||
download_url_to_file(url, cached_file, progress=False)
|
||||
|
||||
with zipfile.ZipFile(cached_file) as cached_zipfile:
|
||||
extraced_repo_name = cached_zipfile.infolist()[0].filename
|
||||
|
|
@ -255,7 +244,7 @@ def set_dir(d):
|
|||
|
||||
|
||||
Args:
|
||||
d: path to a local folder to save downloaded models & weights.
|
||||
d (string): path to a local folder to save downloaded models & weights.
|
||||
"""
|
||||
global hub_dir
|
||||
hub_dir = d
|
||||
|
|
@ -266,10 +255,10 @@ def list(github, force_reload=False):
|
|||
List all entrypoints available in `github` hubconf.
|
||||
|
||||
Args:
|
||||
github: Required, a string with format "repo_owner/repo_name[:tag_name]" with an optional
|
||||
github (string): a string with format "repo_owner/repo_name[:tag_name]" with an optional
|
||||
tag/branch. The default branch is `master` if not specified.
|
||||
Example: 'pytorch/vision[:hub]'
|
||||
force_reload: Optional, whether to discard the existing cache and force a fresh download.
|
||||
force_reload (bool, optional): whether to discard the existing cache and force a fresh download.
|
||||
Default is `False`.
|
||||
Returns:
|
||||
entrypoints: a list of available entrypoint names
|
||||
|
|
@ -280,7 +269,7 @@ def list(github, force_reload=False):
|
|||
# Setup hub_dir to save downloaded files
|
||||
_setup_hubdir()
|
||||
|
||||
repo_dir = _get_cache_or_reload(github, force_reload)
|
||||
repo_dir = _get_cache_or_reload(github, force_reload, True)
|
||||
|
||||
sys.path.insert(0, repo_dir)
|
||||
|
||||
|
|
@ -299,11 +288,11 @@ def help(github, model, force_reload=False):
|
|||
Show the docstring of entrypoint `model`.
|
||||
|
||||
Args:
|
||||
github: Required, a string with format <repo_owner/repo_name[:tag_name]> with an optional
|
||||
github (string): a string with format <repo_owner/repo_name[:tag_name]> with an optional
|
||||
tag/branch. The default branch is `master` if not specified.
|
||||
Example: 'pytorch/vision[:hub]'
|
||||
model: Required, a string of entrypoint name defined in repo's hubconf.py
|
||||
force_reload: Optional, whether to discard the existing cache and force a fresh download.
|
||||
model (string): a string of entrypoint name defined in repo's hubconf.py
|
||||
force_reload (bool, optional): whether to discard the existing cache and force a fresh download.
|
||||
Default is `False`.
|
||||
Example:
|
||||
>>> print(torch.hub.help('pytorch/vision', 'resnet18', force_reload=True))
|
||||
|
|
@ -311,7 +300,7 @@ def help(github, model, force_reload=False):
|
|||
# Setup hub_dir to save downloaded files
|
||||
_setup_hubdir()
|
||||
|
||||
repo_dir = _get_cache_or_reload(github, force_reload)
|
||||
repo_dir = _get_cache_or_reload(github, force_reload, True)
|
||||
|
||||
sys.path.insert(0, repo_dir)
|
||||
|
||||
|
|
@ -333,14 +322,17 @@ def load(github, model, *args, **kwargs):
|
|||
Load a model from a github repo, with pretrained weights.
|
||||
|
||||
Args:
|
||||
github: Required, a string with format "repo_owner/repo_name[:tag_name]" with an optional
|
||||
github (string): a string with format "repo_owner/repo_name[:tag_name]" with an optional
|
||||
tag/branch. The default branch is `master` if not specified.
|
||||
Example: 'pytorch/vision[:hub]'
|
||||
model: Required, a string of entrypoint name defined in repo's hubconf.py
|
||||
*args: Optional, the corresponding args for callable `model`.
|
||||
force_reload: Optional, whether to force a fresh download of github repo unconditionally.
|
||||
model (string): a string of entrypoint name defined in repo's hubconf.py
|
||||
*args (optional): the corresponding args for callable `model`.
|
||||
force_reload (bool, optional): whether to force a fresh download of github repo unconditionally.
|
||||
Default is `False`.
|
||||
**kwargs: Optional, the corresponding kwargs for callable `model`.
|
||||
verbose (bool, optional): If False, mute messages about hitting local caches. Note that the message
|
||||
about first download is cannot be muted.
|
||||
Default is `True`.
|
||||
**kwargs (optional): the corresponding kwargs for callable `model`.
|
||||
|
||||
Returns:
|
||||
a single model with corresponding pretrained weights.
|
||||
|
|
@ -353,8 +345,10 @@ def load(github, model, *args, **kwargs):
|
|||
|
||||
force_reload = kwargs.get('force_reload', False)
|
||||
kwargs.pop('force_reload', None)
|
||||
verbose = kwargs.get('verbose', True)
|
||||
kwargs.pop('verbose', None)
|
||||
|
||||
repo_dir = _get_cache_or_reload(github, force_reload)
|
||||
repo_dir = _get_cache_or_reload(github, force_reload, verbose)
|
||||
|
||||
sys.path.insert(0, repo_dir)
|
||||
|
||||
|
|
@ -369,7 +363,21 @@ def load(github, model, *args, **kwargs):
|
|||
return model
|
||||
|
||||
|
||||
def _download_url_to_file(url, dst, hash_prefix, progress):
|
||||
def download_url_to_file(url, dst, hash_prefix=None, progress=True):
|
||||
r"""Download object at the given URL to a local path.
|
||||
|
||||
Args:
|
||||
url (string): URL of the object to download
|
||||
dst (string): Full path where object will be saved, e.g. `/tmp/temporary_file`
|
||||
hash_prefix (string, optional): If not None, the SHA256 downloaded file should start with `hash_prefix`.
|
||||
Default: None
|
||||
progress (bool, optional): whether or not to display a progress bar to stderr
|
||||
Default: True
|
||||
|
||||
Example:
|
||||
>>> torch.hub.download_url_to_file('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth', '/tmp/temporary_file')
|
||||
|
||||
"""
|
||||
file_size = None
|
||||
# We use a different API for python2 since urllib(2) doesn't recognize the CA
|
||||
# certificates in older Python
|
||||
|
|
@ -385,6 +393,7 @@ def _download_url_to_file(url, dst, hash_prefix, progress):
|
|||
# We deliberately save it in a temp file and move it after
|
||||
# download is complete. This prevents a local working checkpoint
|
||||
# being overriden by a broken download.
|
||||
dst = os.path.expanduser(dst)
|
||||
dst_dir = os.path.dirname(dst)
|
||||
f = tempfile.NamedTemporaryFile(delete=False, dir=dst_dir)
|
||||
|
||||
|
|
@ -414,16 +423,20 @@ def _download_url_to_file(url, dst, hash_prefix, progress):
|
|||
if os.path.exists(f.name):
|
||||
os.remove(f.name)
|
||||
|
||||
def _download_url_to_file(url, dst, hash_prefix=None, progress=True):
|
||||
warnings.warn('torch.hub._download_url_to_file has been renamed to\
|
||||
torch.hub.download_url_to_file to be a public API,\
|
||||
_download_url_to_file will be removed in after 1.3 release')
|
||||
download_url_to_file(url, dst, hash_prefix, progress)
|
||||
|
||||
def load_state_dict_from_url(url, model_dir=None, map_location=None, progress=True):
|
||||
def load_state_dict_from_url(url, model_dir=None, map_location=None, progress=True, check_hash=False):
|
||||
r"""Loads the Torch serialized object at the given URL.
|
||||
|
||||
If the object is already present in `model_dir`, it's deserialized and
|
||||
returned. The filename part of the URL should follow the naming convention
|
||||
``filename-<sha256>.ext`` where ``<sha256>`` is the first eight or more
|
||||
digits of the SHA256 hash of the contents of the file. The hash is used to
|
||||
ensure unique names and to verify the contents of the file.
|
||||
If downloaded file is a zip file, it will be automatically
|
||||
decompressed.
|
||||
|
||||
If the object is already present in `model_dir`, it's deserialized and
|
||||
returned.
|
||||
The default value of `model_dir` is ``$TORCH_HOME/checkpoints`` where
|
||||
environment variable ``$TORCH_HOME`` defaults to ``$XDG_CACHE_HOME/torch``.
|
||||
``$XDG_CACHE_HOME`` follows the X Design Group specification of the Linux
|
||||
|
|
@ -433,7 +446,13 @@ def load_state_dict_from_url(url, model_dir=None, map_location=None, progress=Tr
|
|||
url (string): URL of the object to download
|
||||
model_dir (string, optional): directory in which to save the object
|
||||
map_location (optional): a function or a dict specifying how to remap storage locations (see torch.load)
|
||||
progress (bool, optional): whether or not to display a progress bar to stderr
|
||||
progress (bool, optional): whether or not to display a progress bar to stderr.
|
||||
Default: True
|
||||
check_hash(bool, optional): If True, the filename part of the URL should follow the naming convention
|
||||
``filename-<sha256>.ext`` where ``<sha256>`` is the first eight or more
|
||||
digits of the SHA256 hash of the contents of the file. The hash is used to
|
||||
ensure unique names and to verify the contents of the file.
|
||||
Default: False
|
||||
|
||||
Example:
|
||||
>>> state_dict = torch.hub.load_state_dict_from_url('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth')
|
||||
|
|
@ -462,6 +481,19 @@ def load_state_dict_from_url(url, model_dir=None, map_location=None, progress=Tr
|
|||
cached_file = os.path.join(model_dir, filename)
|
||||
if not os.path.exists(cached_file):
|
||||
sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
|
||||
hash_prefix = HASH_REGEX.search(filename).group(1)
|
||||
_download_url_to_file(url, cached_file, hash_prefix, progress=progress)
|
||||
hash_prefix = HASH_REGEX.search(filename).group(1) if check_hash else None
|
||||
download_url_to_file(url, cached_file, hash_prefix, progress=progress)
|
||||
|
||||
# Note: extractall() defaults to overwrite file if exists. No need to clean up beforehand.
|
||||
# We deliberately don't handle tarfile here since our legacy serialization format was in tar.
|
||||
# E.g. resnet18-5c106cde.pth which is widely used.
|
||||
if zipfile.is_zipfile(cached_file):
|
||||
with zipfile.ZipFile(cached_file) as cached_zipfile:
|
||||
members = cached_zipfile.infolist()
|
||||
if len(members) != 1:
|
||||
raise RuntimeError('Only one file(not dir) is allowed in the zipfile')
|
||||
cached_zipfile.extractall(model_dir)
|
||||
extraced_name = members[0].filename
|
||||
cached_file = os.path.join(model_dir, extraced_name)
|
||||
|
||||
return torch.load(cached_file, map_location=map_location)
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user