Summary:
Fixes https://github.com/pytorch/pytorch/issues/43622
- Moves the model loading part of `torch.hub.load()` into a new `torch.hub.load_local()` function that takes a path to a local directory containing a `hubconf.py`, instead of a repo name (usage sketched below).
- Refactors `torch.hub.load()` so that it now calls `torch.hub.load_local()` after downloading and extracting the repo.
- Updates `torch.hub` docs to include the new function + minor fixes.
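A minimal usage sketch of the split (the local path below is just an illustration):
```
import torch

# torch.hub.load() behaves as before: download + extract the repo, then load.
model = torch.hub.load('pytorch/vision', 'resnet18', pretrained=True)

# torch.hub.load_local() skips the download and loads straight from a local
# directory that already contains a hubconf.py (the path here is hypothetical).
model = torch.hub.load_local('/path/to/local/vision', 'resnet18', pretrained=True)
```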
Pull Request resolved: https://github.com/pytorch/pytorch/pull/44204
Reviewed By: malfet
Differential Revision: D23817429
Pulled By: ailzhang
fbshipit-source-id: 788fd83c87a94f487b558715b2809d346ead02b2
Summary:
Fixes https://github.com/pytorch/pytorch/issues/38401
* `torch.hub.load_state_dict_from_url()` now downloads to `$TORCH_HOME/hub/checkpoints` instead of `$TORCH_HOME/checkpoints`, matching `torch.hub.load()` and the other hub functions.
* Make `hub_dir` private; add `get_dir()` and use it instead.
Also updated docs. Did not see a need for additional unit tests.
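A rough sketch of the resulting behavior (the exact locations depend on `$TORCH_HOME`):
```
import torch

# get_dir() replaces direct use of the now-private hub_dir.
hub_dir = torch.hub.get_dir()  # e.g. ~/.cache/torch/hub with default settings

# Checkpoints fetched this way now land under <hub_dir>/checkpoints,
# next to the repos downloaded by torch.hub.load().
state_dict = torch.hub.load_state_dict_from_url(
    'https://download.pytorch.org/models/resnet18-5c106cde.pth')
```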
Pull Request resolved: https://github.com/pytorch/pytorch/pull/38969
Differential Revision: D21725880
Pulled By: ailzhang
fbshipit-source-id: 58cc6b32ddbda91e58c1c1433cc3916223556ea1
Summary:
Resubmit of https://github.com/pytorch/pytorch/pull/25980.
Our old serialization format was tar (for example, `resnet18-5c106cde.pth` is in this format), so let's only support automatic unzipping when the checkpoint is a zipfile.
We could still make it work with tarfiles, but let's hold off on that until someone asks for it.
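Roughly, the detection boils down to a `zipfile.is_zipfile()` check; the helper below is an illustrative sketch, not the exact hub code:
```
import os
import zipfile

import torch

def load_checkpoint(cached_file, model_dir, map_location=None):
    # Auto-extract only when the download is actually a zipfile; legacy
    # tar-based checkpoints such as resnet18-5c106cde.pth are loaded as-is.
    if zipfile.is_zipfile(cached_file):
        with zipfile.ZipFile(cached_file) as f:
            # Assume the archive wraps a single serialized file.
            extracted_name = f.infolist()[0].filename
            f.extractall(model_dir)
        cached_file = os.path.join(model_dir, extracted_name)
    return torch.load(cached_file, map_location=map_location)
```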
Pull Request resolved: https://github.com/pytorch/pytorch/pull/26723
Differential Revision: D17551795
Pulled By: ailzhang
fbshipit-source-id: 00b4e7621f1e753ca9aa07b1fe356278c6693a1e
Summary:
This PR does a few small improvements to hub:
- add support for a `verbose` option in `torch.hub.load`. Note that this mutes the "hitting cache" message but keeps the first-download message, as suggested. fixes https://github.com/pytorch/pytorch/issues/24791
- add support for loading a state dict from a tar or zip file in `torch.hub.load_state_dict_from_url`.
- add `torch.hub.download_url_to_file` as a public API, and keep a backward-compatibility shim for `_download_url_to_file`.
- make the hash check in the filename optional through `check_hash` (usage sketched below); many users don't have control over the naming, and relaxing this constraint can avoid users duplicating download code on their end.
- move PyTorch CI off `pytorch/vision` and use `ailzhang/torchhub_example` as a dedicated test repo. fixes https://github.com/pytorch/pytorch/issues/25865
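A short sketch of the new knobs (using the usual torchvision checkpoint URL):
```
import torch

# Public download helper; the old _download_url_to_file keeps working for BC.
torch.hub.download_url_to_file(
    'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet18.pth')

# check_hash is opt-in: when True, the filename must carry the hash prefix
# (the "5c106cde" part) and the downloaded file is verified against it.
state_dict = torch.hub.load_state_dict_from_url(
    'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    check_hash=True)
```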
Pull Request resolved: https://github.com/pytorch/pytorch/pull/25980
Differential Revision: D17495679
Pulled By: ailzhang
fbshipit-source-id: 695df3e803ad5f9ca33cfbcf62f1a4f8cde0dbbe
Summary:
A few improvements made while working on the BERT model.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/19247
Differential Revision: D14989345
Pulled By: ailzhang
fbshipit-source-id: f4846813f62b6d497fbe74e8552c9714bd8dc3c7
Summary:
* `torch.hub.list('pytorch/vision')` - show all available hub models in `pytorch/vision`
* `torch.hub.show('pytorch/vision', 'resnet18')` - show docstring & example for `resnet18` in `pytorch/vision`
* Moved `torch.utils.model_zoo.load_url` to `torch.hub.load_state_dict_from_url` and deprecated `torch.utils.model_zoo` (a before/after sketch follows the example below).
* We have too many environment variables controlling where the cache dir is, which isn't really necessary. I actually want to unify `TORCH_HUB_DIR`, `TORCH_HOME` and `TORCH_MODEL_ZOO`, but haven't done it yet (more suggestions are welcome!).
* Simplify the `pytorch/vision` example in the docs; it was meant to show how a hub entrypoint can be written, so it had some confusing, unnecessary args.
An example of hub usage is shown below
```
In [1]: import torch
In [2]: torch.hub.list('pytorch/vision', force_reload=True)
Downloading: "https://github.com/pytorch/vision/archive/master.zip" to /private/home/ailzhang/.torch/hub/master.zip
Out[2]: ['resnet18', 'resnet50']
In [3]: torch.hub.show('pytorch/vision', 'resnet18')
Using cache found in /private/home/ailzhang/.torch/hub/vision_master
Resnet18 model
pretrained (bool): a recommended kwargs for all entrypoints
args & kwargs are arguments for the function
In [4]: model = torch.hub.load('pytorch/vision', 'resnet18', pretrained=True)
Using cache found in /private/home/ailzhang/.torch/hub/vision_master
```
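As a before/after sketch of the `model_zoo` migration mentioned above:
```
import torch
import torch.utils.model_zoo  # deprecated path, still importable

url = 'https://download.pytorch.org/models/resnet18-5c106cde.pth'

# Before: the model_zoo helper.
state_dict = torch.utils.model_zoo.load_url(url)

# After: the same functionality lives under torch.hub.
state_dict = torch.hub.load_state_dict_from_url(url)
```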
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18758
Differential Revision: D14883651
Pulled By: ailzhang
fbshipit-source-id: 6db6ab708a74121782a9154c44b0e190b23e8309
Summary:
Added a few examples and explanations of how to publish/load models.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14862
Differential Revision: D13384790
Pulled By: ailzhang
fbshipit-source-id: 008166e84e59dcb62c0be38a87982579524fb20e
Summary:
[Edit: after applying colesbury's suggestions]
* The hub module enables users to share code + pretrained weights through GitHub repos.
Example usage:
```
from torch import hub

hub_model = hub.load(
    'ailzhang/vision:hub',  # repo_owner/repo_name:branch
    'wrapper1',             # entrypoint
    1234,                   # args for callable [not applicable to resnet18]
    pretrained=True)        # kwargs for callable
```
* Protocol on repo owner side: example https://github.com/ailzhang/vision/tree/hub
* The "published" models should be at least in a branch/tag. It can't be a random commit.
* Repo owner should have the following field defined in `hubconf.py`
* function/entrypoint with function signature `def wrapper1(pretrained=False, *args, **kwargs):`
* `pretrained` allows users to load pretrained weights from repo owner.
* `args` and `kwargs` are passed to the callable `resnet18`, repo owner should clearly specify their help message in the docstring
```
def wrapper1(pretrained=False, *args, **kwargs):
    """
    pretrained (bool): a recommended kwargs for all entrypoints
    args & kwargs are arguments for the function
    """
    from torch.utils import model_zoo  # needed for load_url below
    from torchvision.models.resnet import resnet18

    model = resnet18(*args, **kwargs)
    checkpoint = 'https://download.pytorch.org/models/resnet18-5c106cde.pth'
    if pretrained:
        model.load_state_dict(model_zoo.load_url(checkpoint, progress=False))
    return model
```
* Hub dir
* `hub_dir` specifies where the intermediate files/folders will be saved. By default this is `~/.torch/hub`.
* Users can change it by either setting the environment variable `TORCH_HUB_DIR` or calling `hub.set_dir(PATH_TO_HUB_DIR)`.
* By default, we don't clean up files after loading so that users can reuse the cache next time.
* Cache logic:
* We use the cache by default if it exists in `hub_dir`.
* Users can force a fresh reload by calling `hub.load(..., force_reload=True)`, as sketched below.
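A short sketch of the `hub_dir` and cache controls (the directory below is just an example):
```
from torch import hub

# Override the default ~/.torch/hub location (setting TORCH_HUB_DIR works too).
hub.set_dir('/tmp/my_hub_dir')

# First call downloads the repo into hub_dir; later calls reuse the cache.
model = hub.load('ailzhang/vision:hub', 'wrapper1', pretrained=True)

# Bypass the cache and re-download the repo.
model = hub.load('ailzhang/vision:hub', 'wrapper1', pretrained=True,
                 force_reload=True)
```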
Pull Request resolved: https://github.com/pytorch/pytorch/pull/12228
Differential Revision: D10511470
Pulled By: ailzhang
fbshipit-source-id: 12ac27f01d33653f06b2483655546492f82cce38