mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/17705 Differential Revision: D14338380 Pulled By: ailzhang fbshipit-source-id: d53eece30bede88a642e718ee6f829ba29c7d1c4
102 lines
3.6 KiB
ReStructuredText
torch.hub
===================================

PyTorch Hub is a pre-trained model repository designed to facilitate research reproducibility.

Publishing models
-----------------

PyTorch Hub supports publishing pre-trained models (model definitions and pre-trained weights)
to a GitHub repository by adding a simple ``hubconf.py`` file.

``hubconf.py`` can have multiple entrypoints. Each entrypoint is defined as a Python function with
the following signature.

::

    def entrypoint_name(pretrained=False, *args, **kwargs):
        ...

How to implement an entrypoint?
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Here is a code snippet from the pytorch/vision repository, which specifies an entrypoint
for the ``resnet18`` model. You can see the full script in the
`pytorch/vision repo <https://github.com/pytorch/vision/blob/master/hubconf.py>`_.

::

    import torch.utils.model_zoo as model_zoo  # provides load_url() used below

    dependencies = ['torch', 'math']

    def resnet18(pretrained=False, *args, **kwargs):
        """
        Resnet18 model
        pretrained (bool): a recommended kwargs for all entrypoints
        args & kwargs are arguments for the function
        """
        ######## Call the model in the repo ###############
        from torchvision.models.resnet import resnet18 as _resnet18
        model = _resnet18(*args, **kwargs)
        ######## End of call ##############################
        # The following logic is REQUIRED
        if pretrained:
            # For weights saved in local repo
            # model.load_state_dict(torch.load(<path_to_saved_file>))

            # For weights saved elsewhere
            checkpoint = 'https://download.pytorch.org/models/resnet18-5c106cde.pth'
            model.load_state_dict(model_zoo.load_url(checkpoint, progress=False))
        return model

- The ``dependencies`` variable is a **list** of package names required to run the model.
- Pretrained weights can either be stored locally in the GitHub repo or be loadable by
  ``model_zoo.load_url()``.
- ``pretrained`` controls whether to load the pre-trained weights provided by the repo owners.
- ``args`` and ``kwargs`` are passed along to the real callable function.
- The docstring of the function works as a help message, explaining what the model does and what
  the allowed arguments are.
- An entrypoint function should **ALWAYS** return a model (``nn.Module``).

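As a rough illustration of what the ``dependencies`` list is for, the sketch below checks whether each declared package can be imported in the current environment. The helper name ``missing_dependencies`` and the placeholder package name are hypothetical, and this is not presented as how ``torch.hub`` itself handles the list.

```python
import importlib.util


def missing_dependencies(dependencies):
    """Return the names in `dependencies` that cannot be imported here."""
    return [name for name in dependencies
            if importlib.util.find_spec(name) is None]


# Standard-library modules stand in for real requirements such as 'torch'.
# The last name is a made-up package that is assumed not to be installed.
declared = ['math', 'json', 'surely_not_installed_pkg_xyz']
missing = missing_dependencies(declared)
if missing:
    print('Missing dependencies: ' + ', '.join(missing))
```

A repo owner could run a check like this against the ``dependencies`` list in ``hubconf.py`` before publishing, to confirm the list matches what the entrypoints actually import.
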
Important Notice
^^^^^^^^^^^^^^^^

- Published models should be at least in a branch or tag. They can't be a random commit.

Loading models from Hub
-----------------------

Users can load the pre-trained models using the ``torch.hub.load()`` API.


.. automodule:: torch.hub
.. autofunction:: load

Here's an example of loading the ``resnet18`` entrypoint from the ``pytorch/vision`` repo.

::

    from torch import hub

    hub_model = hub.load(
        'pytorch/vision:master', # repo_owner/repo_name:branch
        'resnet18', # entrypoint
        1234, # args for callable [not applicable to resnet]
        pretrained=True) # kwargs for callable

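The way extra arguments reach an entrypoint can be sketched without downloading anything. In the snippet below, ``hubconf`` is a stand-in module built in memory, and ``toy_model`` and ``call_entrypoint`` are hypothetical names used only for illustration; the real lookup inside ``hub.load()`` may differ.

```python
import types

# Hypothetical stand-in for a repository's hubconf.py module.
hubconf = types.ModuleType('hubconf')


def toy_model(pretrained=False, *args, **kwargs):
    """Toy entrypoint that just records what it received."""
    return {'pretrained': pretrained, 'args': args, 'kwargs': kwargs}


hubconf.toy_model = toy_model


def call_entrypoint(module, name, *args, **kwargs):
    """Look up the entrypoint called `name` in `module` and forward all arguments."""
    entry = getattr(module, name, None)
    if entry is None or not callable(entry):
        raise RuntimeError('Cannot find callable {} in hubconf'.format(name))
    return entry(*args, **kwargs)


result = call_entrypoint(hubconf, 'toy_model', pretrained=True, width=64)
print(result)
```

Because ``pretrained`` is the first parameter in the recommended signature, a positional argument binds to it before anything else. Passing ``pretrained`` by keyword, and passing other options by keyword as well, avoids a ``TypeError`` about multiple values for ``pretrained``.
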
Where are my downloaded models & weights saved?
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

The locations are checked in the following order:

- ``hub_dir``: a user-specified path. It can be set in the following ways:

  - Setting the environment variable ``TORCH_HUB_DIR``
  - Calling ``hub.set_dir(<PATH_TO_HUB_DIR>)``

- ``~/.torch/hub``

.. autofunction:: set_dir

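The lookup order can be sketched as a small pure-Python helper. ``resolve_hub_dir`` is a hypothetical name, and the sketch assumes that an explicitly set path wins over ``TORCH_HUB_DIR``; the text above does not specify which of the two takes priority.

```python
import os

DEFAULT_HUB_DIR = '~/.torch/hub'


def resolve_hub_dir(explicit_dir=None, environ=None):
    """Pick the hub directory: explicit path, then TORCH_HUB_DIR, then the default.

    Assumption: the explicit path (as set via hub.set_dir) takes priority
    over the environment variable.
    """
    environ = os.environ if environ is None else environ
    chosen = explicit_dir or environ.get('TORCH_HUB_DIR') or DEFAULT_HUB_DIR
    return os.path.expanduser(chosen)


# With nothing configured, the default location under the home directory is used.
print(resolve_hub_dir(environ={}))
```
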
Caching logic
^^^^^^^^^^^^^

By default, downloaded files are not cleaned up after loading. Hub uses the cached copy
if it already exists in ``hub_dir``.

Users can force a reload by calling ``hub.load(..., force_reload=True)``. This deletes
the existing GitHub folder and downloaded weights, then starts a fresh download. This is useful
when updates are published to the same branch, so that users can keep up with the latest release.

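The caching behaviour described above can be approximated with a short, self-contained sketch. ``fetch_with_cache`` and ``fake_download`` are hypothetical helpers, and a temporary directory stands in for ``hub_dir``; this illustrates the idea and is not the actual ``torch.hub`` implementation.

```python
import os
import shutil
import tempfile


def fetch_with_cache(hub_dir, name, download, force_reload=False):
    """Return the cached path for `name`, downloading it only when needed."""
    target = os.path.join(hub_dir, name)
    if force_reload and os.path.exists(target):
        shutil.rmtree(target)  # drop the stale copy first
    if not os.path.exists(target):
        os.makedirs(target)
        download(target)  # caller-supplied fetch step
    return target


# A fake download that counts how often it runs.
calls = []


def fake_download(target):
    calls.append(target)
    with open(os.path.join(target, 'weights.bin'), 'w') as f:
        f.write('v{}'.format(len(calls)))


hub_dir = tempfile.mkdtemp()
fetch_with_cache(hub_dir, 'owner_repo_branch', fake_download)  # first download
fetch_with_cache(hub_dir, 'owner_repo_branch', fake_download)  # cache hit, no download
fetch_with_cache(hub_dir, 'owner_repo_branch', fake_download,
                 force_reload=True)  # deletes the cache and downloads again
print(len(calls))
```

Only the first and the forced call invoke the download step; the second call reuses the files already present in the cache directory.
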