Commit Graph

307 Commits

Author SHA1 Message Date
PyTorch MergeBot
d59a6864fb Revert "[BE]: Update ruff to 0.285 (#107519)"
This reverts commit 88ab3e4322.

Reverted https://github.com/pytorch/pytorch/pull/107519 on behalf of https://github.com/ZainRizvi due to Sorry, but this PR breaks internal tests. @ezyang, can you please hep them get unblocked? It seems like one of the strings was prob accidentally modified ([comment](https://github.com/pytorch/pytorch/pull/107519#issuecomment-1688833480))
2023-08-22 19:53:32 +00:00
Aaron Gokaslan
88ab3e4322 [BE]: Update ruff to 0.285 (#107519)
This updates ruff to 0.285 which is faster, better, and have fixes a bunch of false negatives with regards to fstrings.

I also enabled RUF017 which looks for accidental quadratic list summation. Luckily, seems like there are no instances of it in our codebase, so enabling it so that it stays like that. :)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107519
Approved by: https://github.com/ezyang
2023-08-20 01:36:18 +00:00
Jason Lu
bc88028e8e Back out "Reland "Make adding buffers more like adding parameters (#104069)" (#106224)" (#106743)
Summary:
Original commit changeset: 81319beb97f3

Original Phabricator Diff: D47961182

Test Plan: revert to maintain backward compat with legacy ads_dper3 production package. Read details in: S357822

Reviewed By: atuljangra

Differential Revision: D48131623

@diff-train-skip-merge
(D48131623 landed internally)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106743
Approved by: https://github.com/malfet
2023-08-08 15:27:34 +00:00
Mikayla Gawarecki
d8e5f2aa6d Reland "Make adding buffers more like adding parameters (#104069)" (#106224)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106224
Approved by: https://github.com/atalman, https://github.com/albanD
2023-07-31 17:18:56 +00:00
Mikayla Gawarecki
ca7ece9b50 [easy] improve hint on error message in nn.Module.load_state_dict (#106042)
Fix #105963

Pull Request resolved: https://github.com/pytorch/pytorch/pull/106042
Approved by: https://github.com/albanD
2023-07-27 19:56:02 +00:00
Aaron Gokaslan
6d43c89f37 [BE]: Update Ruff to 0.0.280 (#105724)
Removes unusued loop values in python dictionary iteration. Automated fix from Ruff master

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105724
Approved by: https://github.com/ezyang, https://github.com/janeyx99
2023-07-22 23:03:34 +00:00
Justin Chu
4cc1745b13 [BE] f-stringify torch/ and scripts (#105538)
This PR is a follow up on the pyupgrade series to convert more strings to use f-strings using `flynt`.

- https://docs.python.org/3/reference/lexical_analysis.html#f-strings
- https://pypi.org/project/flynt/

Command used:

```
flynt torch/ -ll 120
flynt scripts/ -ll 120
flynt tools/ -ll 120
```

and excluded `collect_env.py`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105538
Approved by: https://github.com/ezyang, https://github.com/malfet
2023-07-21 19:35:24 +00:00
Justin Chu
79c5e33349 [BE] Enable ruff's UP rules and autoformat nn/ mps/ and torch/ (#105436)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105436
Approved by: https://github.com/malfet, https://github.com/albanD
2023-07-21 07:38:46 +00:00
Andrey Talman
c6653b65d8 Back out "Make adding buffers more like adding parameters (#104069)" (#105581)
Summary:
D47537831 is breaking pyper tests: https://fb.workplace.com/groups/802176577445480/posts/1018902842439518/

with `TypeError: register_buffer() takes 3 positional arguments but 4 were given`

Original commit changeset: d4b4069fbd38

Original Phabricator Diff: D47537831

Test Plan:
```
buck2 run //caffe2/torch/fb/training_toolkit/integration_tests/training_lifecycle/cogwheel_tests/pyper_release_v2:cogwheel_smallworld_inline_cvr_infer_pyper_pyper__canary_offline_training-launcher -- --run-harness-in-tupperware --build-fbpkg ads_dper3 --build-fbpkg training_platform
```

Reviewed By: atalman

Differential Revision: D47600140

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105581
Approved by: https://github.com/mikaylagawarecki
2023-07-20 03:39:53 +00:00
ekamiti
32d422f335 Make adding buffers more like adding parameters (#104069)
Add similar semantics for creating a buffer object similar to creating a parameter. This is done by introducing a new `Buffer` class that can be used for type disambiguation. The underlying functionality of registering a buffer remains the same as the `register_buffer` method has not been changed. The `persistent` parameter in the `Buffer` type is to indicate whether a buffer object should be persistent or not. Other non-test changes have to do with getting the new `Buffer` type recognized by inductor and dynamo. Remaining changes are test changes to make sure that the `Buffer` type can be used as a drop in replacement for `register_buffer` as it just leads to `register_buffer` being called. The addition of this new functionality still allows for normal tensors to be used as buffers so these changes are intended to be backwards compatible.

Fixes #35735

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104069
Approved by: https://github.com/mikaylagawarecki
2023-07-17 17:59:05 +00:00
Jenny
e095716161 Add a note for Incorrect signature in nn.Module.register_full_backwar… (#104964)
…d_pre_hook

Fixes #102645

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104964
Approved by: https://github.com/albanD
2023-07-11 16:24:13 +00:00
Mikayla Gawarecki
1ad435772b Added option to always call nn.Module global/non-global forward hooks (#104278)
Fix #103997

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104278
Approved by: https://github.com/albanD
2023-07-10 18:58:07 +00:00
Sergei Vorobev
ee19121931 Change nn.Module.__getattr__ return type to Any (#104321)
When working with a highly-dynamic python code it's not always possible to express the static types. However if we consider the end-user experience for somebody who uses both pytorch and a static type checker (mypy, pyright), we should error on the side of being ergonomic and not technically correct.

The  `nn.Module.__getattr__` is one of the such examples: on paper the return type is correct. In practice the community would benefit from having `Any` as a return type because it would avoid littering the idiomatic pytorch code with `cast`, `# type: ignore`, `assert`, `isinstance`, etc.

Some evidences:
- linked in the comment thread on pyright bug tracker https://github.com/microsoft/pyright/issues/4213
- `pyre` type checker steps outside of the normal type checking practices and special-cases `registrer_buffer()` in part to avoid this problem. https://pyre-check.org/docs/features/ This is not a very scalable solution since type-checkers generally aim at adhering to the spec (various typing PEPs).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104321
Approved by: https://github.com/kit1980, https://github.com/albanD
2023-06-28 16:14:36 +00:00
Mikayla Gawarecki
b93ed8164e Add non-recursive module.to_empty option (#104197)
Fixes https://github.com/pytorch/pytorch/issues/97049, related to https://github.com/pytorch/pytorch/issues/104187

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104197
Approved by: https://github.com/albanD
2023-06-26 21:47:22 +00:00
Mikayla Gawarecki
d1cecd9c32 Add assign kwarg to module.load_state_dict (#102212)
Fixes #64601 and #98906

Adds an `assign` argument to `load_state_dict` that loads params/buffers by assignment instead of doing `param.copy_(param_from_state_dict)`.

Primarily intended to remove the need for the `.to_empty()` in

```
with torch.device('meta'):
    m = SomeModule()
m.to_empty()
state_dict = torch.load('...pth')
m.load_state_dict(state_dict)
```

so we can instead do

```
with torch.device('meta'):
    m = SomeModule()
state_dict = torch.load('...pth')
m.load_state_dict(state_dict, assign=True)
```

**A problem with this PR for the case where the model is initialized on meta is what happens to nonpersistent buffers/params corresponding to keys missing from the state dict?**
What happens in the case where `load_state_dict(state_dict, strict=False, assign=True)` and the state_dict is missing some keys? The corresponding params missing from the `state_dict` and nonpersistent buffers would still be on `meta` and need to be manually initialized. However, I don't think we offer an API that would initialize these.

One solution would be to make these empty tensors but it might not be semantically correct...

Pull Request resolved: https://github.com/pytorch/pytorch/pull/102212
Approved by: https://github.com/albanD
2023-06-15 18:41:00 +00:00
Mark Saroufim
bf52d570d9 torch.save/load torch.compiled models (#97565)
Opening this so I can discuss with @albanD

I built a proof of concept of an in place API for an nn.Module that allows us to save and load a torch.compiled model with no issues https://github.com/msaroufim/mlsys-experiments/blob/main/save-compiled-model.py

So users can run` model.compile()` and then run `torch.save(model, "model.pt")` and `torch.load(model, "model.pt)` with no issues unlike the rather strange current suggestion we give to users which is `opt_mod = torch.compile(mod); torch.save(mod, "model.pt")`

Right now I'm trying to extend this to work for nn.modules more generally

TODO: Failing tests
* [x] torch.jit.load -> issue was because of aliasing `__call__` to `_call_impl`, _call_impl used to be skipped when now it lo longer is so expanded the skip check. I added an explicit `torch.jit.load()` test now which @davidberard98 suggested
* [x] functorch seems to be a flake - ran locally and it worked `pytest functorch/test_eager_transforms.py`
* [x] a test infra flake - `test_testing.py::TestImports::test_no_mutate_global_logging_on_import_path_functorch`
* [x] It seems like I broke inlining in dynamo though `python -m pytest test/dynamo/test_dynamic_shapes.py -k test_issue175` chatting with Voz about it but still not entirely sure how to fix - found a workaround after chatting with @yanboliang
* [x] `pytest test/dynamo/test_modules.py` and `test/dynamo/test_dynamic_shapes` `test/dynamo/test_misc.py` seem to be failing in CI but trying it out locally they all pass tests passed with 0 failures
* [x] `pytest test/profiler/test_profiler_tree.py ` these tests have ProfilerTrees explicitly printed and will now break if __call__ is not in tree - ran with `EXPECT_ACCEPT=1`
* [x] `pytest test/test_torch.py::TestTorch::test_typed_storage_deprecation_warning` a flake, ran this locally and it works fine
* [x] I reverted my changes to `_dynamo/nn_module.py` since it looks like @wconstab is now directly handling `_call_impl` there but this is triggering an infinite inlining which is crashing
* [x] Tried out to instead override `__call__`, python doesnt like this though https://github.com/pytorch/pytorch/pull/97565#issuecomment-1524570439

Pull Request resolved: https://github.com/pytorch/pytorch/pull/97565
Approved by: https://github.com/aaronenyeshi, https://github.com/albanD, https://github.com/voznesenskym
2023-05-05 03:57:49 +00:00
PyTorch MergeBot
04d67e20a7 Revert "torch.save/load torch.compiled models (#97565)"
This reverts commit 87f08d717e.

Reverted https://github.com/pytorch/pytorch/pull/97565 on behalf of https://github.com/clee2000 due to sorry but I think this breaks dynamo tests 87f08d717e ([comment](https://github.com/pytorch/pytorch/pull/97565#issuecomment-1535103171))
2023-05-04 17:07:33 +00:00
Mark Saroufim
87f08d717e torch.save/load torch.compiled models (#97565)
Opening this so I can discuss with @albanD

I built a proof of concept of an in place API for an nn.Module that allows us to save and load a torch.compiled model with no issues https://github.com/msaroufim/mlsys-experiments/blob/main/save-compiled-model.py

So users can run` model.compile()` and then run `torch.save(model, "model.pt")` and `torch.load(model, "model.pt)` with no issues unlike the rather strange current suggestion we give to users which is `opt_mod = torch.compile(mod); torch.save(mod, "model.pt")`

Right now I'm trying to extend this to work for nn.modules more generally

TODO: Failing tests
* [x] torch.jit.load -> issue was because of aliasing `__call__` to `_call_impl`, _call_impl used to be skipped when now it lo longer is so expanded the skip check. I added an explicit `torch.jit.load()` test now which @davidberard98 suggested
* [x] functorch seems to be a flake - ran locally and it worked `pytest functorch/test_eager_transforms.py`
* [x] a test infra flake - `test_testing.py::TestImports::test_no_mutate_global_logging_on_import_path_functorch`
* [x] It seems like I broke inlining in dynamo though `python -m pytest test/dynamo/test_dynamic_shapes.py -k test_issue175` chatting with Voz about it but still not entirely sure how to fix - found a workaround after chatting with @yanboliang
* [x] `pytest test/dynamo/test_modules.py` and `test/dynamo/test_dynamic_shapes` `test/dynamo/test_misc.py` seem to be failing in CI but trying it out locally they all pass tests passed with 0 failures
* [x] `pytest test/profiler/test_profiler_tree.py ` these tests have ProfilerTrees explicitly printed and will now break if __call__ is not in tree - ran with `EXPECT_ACCEPT=1`
* [x] `pytest test/test_torch.py::TestTorch::test_typed_storage_deprecation_warning` a flake, ran this locally and it works fine
* [x] I reverted my changes to `_dynamo/nn_module.py` since it looks like @wconstab is now directly handling `_call_impl` there but this is triggering an infinite inlining which is crashing
* [x] Tried out to instead override `__call__`, python doesnt like this though https://github.com/pytorch/pytorch/pull/97565#issuecomment-1524570439

Pull Request resolved: https://github.com/pytorch/pytorch/pull/97565
Approved by: https://github.com/aaronenyeshi, https://github.com/albanD
2023-05-04 16:23:12 +00:00
Yanli Zhao
9bc03db670 Move nn.module state dict pre hook (#98964)
Some modules like lazyModule may override '_save_to_state_dict()', in this case, pre_state_dict hook will not be called. So move the pre_state_dict hook out from '_save_to_state_dict()' to make sure the pre hook could be called

Pull Request resolved: https://github.com/pytorch/pytorch/pull/98964
Approved by: https://github.com/albanD
2023-04-26 16:51:13 +00:00
Kazuaki Ishizaki
a531a464fd Fix typos under torch/nn directory (#97594)
This PR fixes typos in comments of `.py` files under `torch/nn` directory

Pull Request resolved: https://github.com/pytorch/pytorch/pull/97594
Approved by: https://github.com/dagitses, https://github.com/kit1980
2023-04-10 22:07:15 +00:00
Sergii Dymchenko
477f3f555f Simplify by using yield from (#97831)
The issues were found by SIM104 flake8-simplify in a local run.

I'll take a look on adding the check to the CI separately.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/97831
Approved by: https://github.com/Skylion007
2023-03-29 19:15:24 +00:00
Will Constable
2f6a371ae9 Revert "Optimize nn.Module __call__ fast path for dynamo (#95931)" (#96242)
Reverting due to concerns over silent unsoundness (skipped hooks) if users have directly added hooks dicts without using official torch APIs.

This reverts commit 26045336ca.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/96242
Approved by: https://github.com/albanD
2023-03-10 01:05:01 +00:00
PyTorch MergeBot
6bbae86253 Revert "Fix hooks handling for unpickled nnmodule (#96224)"
This reverts commit 8ca264ef36.

Reverted https://github.com/pytorch/pytorch/pull/96224 on behalf of https://github.com/ezyang due to inductor regression
2023-03-08 13:01:16 +00:00
Will Constable
8ca264ef36 Fix hooks handling for unpickled nnmodule (#96224)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/96224
Approved by: https://github.com/albanD
2023-03-08 05:33:15 +00:00
Will Constable
26045336ca Optimize nn.Module __call__ fast path for dynamo (#95931)
This PR optimizes the guards overhead introduced by dynamo tracing module forward hooks.

It can and maybe should be followed by a wider change proposed by @voznesenskym to optimize specialized nnmodules by 'observing' any user mutations and directly invalidating the root guard, obviating the need to install other nnmodule guards.  (But this observer change seems more involved...)

Idea: maintain a flag, and keep it up to date whenever adding or
removing hooks. Use the flag rather than dict checks to enter the call fast path.
  - need to extend RemovableHandle to keep a ref to nnModule so it can update the flag on removal.
  - also need to handle the flag in ScriptModule which still uses the python call impl when called from python.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/95931
Approved by: https://github.com/ezyang, https://github.com/voznesenskym
2023-03-04 15:09:40 +00:00
Jane Xu
e5b9d98752 Rephrase zero_grad docs (#95643)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95643
Approved by: https://github.com/albanD
2023-02-28 22:04:23 +00:00
Xuehai Pan
b005ec62b9 [BE] Remove dependency on six and future (#94709)
Remove the Python 2 and 3 compatibility library [six](https://pypi.org/project/six) and [future](https://pypi.org/project/future) and `torch._six`. We only support Python 3.8+ now. It's time to retire them.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94709
Approved by: https://github.com/malfet, https://github.com/Skylion007
2023-02-14 09:14:14 +00:00
Xuehai Pan
5b1cedacde [BE] [2/3] Rewrite super() calls in functorch and torch (#94588)
Rewrite Python built-in class `super()` calls. Only non-semantic changes should be applied.

- #94587
- #94588
- #94592

Also, methods with only a `super()` call are removed:

```diff
class MyModule(nn.Module):
-   def __init__(self):
-       super().__init__()
-
    def forward(self, ...):
        ...
```

Some cases that change the semantics should be kept unchanged. E.g.:

f152a79be9/caffe2/python/net_printer.py (L184-L190)

f152a79be9/test/test_jit_fuser_te.py (L2628-L2635)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94588
Approved by: https://github.com/ezyang, https://github.com/albanD
2023-02-10 21:16:33 +00:00
Aaron Gokaslan
748bac8757 [BE]: Apply pyupgrade yield from and unit test alias upgrades (#94309)
Applies some more harmless pyupgrades. This one gets rid of deprecated aliases in unit_tests and more upgrades yield for loops into yield from generators which are more performance and propagates more information / exceptions from original generator. This is the modern recommended way of forwarding generators.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94309
Approved by: https://github.com/albanD
2023-02-07 20:08:58 +00:00
Vivswan Shah
8c1ee89f19 Added super init to Module (#91819)
Added super init to Module for complex user modules derived from multiple python classes.
And by adding the super __init__ call at the end so it doesn't change any functionality of Module class.

I am working on building a module for simulating analog neural network on PyTorch.
and this small change is really useful for that and we can definitely think of many other useful cases especially for more module or mro hierarchy.

Issues: https://github.com/pytorch/pytorch/issues/28746, https://github.com/pytorch/pytorch/issues/48626, https://github.com/pytorch/pytorch/issues/61662, https://github.com/pytorch/pytorch/issues/74036
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91819
Approved by: https://github.com/albanD
2023-02-01 22:17:59 +00:00
Jane Xu
b90496eef5 [nn] zero_grad() set_to_none default True (#92731)
Attempts to fix #92656

BC-breaking! This changes the default of zero_grad in optim and in nn to default set grads to None instead of zero tensors. We are changing the default because there are proven perf wins and existing code has typically not regressed due to this change. (will probably have to flesh out this note more).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/92731
Approved by: https://github.com/ngimel
2023-01-26 01:04:28 +00:00
PyTorch MergeBot
09eb4c2a70 Revert "Update Module.__setattr__ to respect property setters (#92044)"
This reverts commit 0c8f4b5893.

Reverted https://github.com/pytorch/pytorch/pull/92044 on behalf of https://github.com/saitcakmak due to Caused regressions in a Meta internal model
2023-01-21 02:39:21 +00:00
kshitij12345
387ca598a1 [nn] full_backward{_pre}_hook: warning for Module returning dict, list, etc (#87547)
Fixes https://github.com/pytorch/pytorch/issues/87540

Pull Request resolved: https://github.com/pytorch/pytorch/pull/87547
Approved by: https://github.com/albanD
2023-01-18 06:28:00 +00:00
Sait Cakmak
0c8f4b5893 Update Module.__setattr__ to respect property setters (#92044)
Fixes #52664. Checks if the attribute is a property that defines a setter and uses fset in __setattr__ rather than registering an inaccessible module / parameter.

This is BC-breaking as the attribute setters on nn.Module properties used to be ignored and now will be called properly.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/92044
Approved by: https://github.com/albanD
2023-01-17 20:00:06 +00:00
Matthew Hoffman
a26e5e21b5 Improve type hints for Module forward hooks (#92061)
Fixes #91654.

Currently, the `hook` parameters of `nn.Module.register_forward_pre_hook` and `nn.Module.register_forward_hook` are typed as `Callable[..., None]`, which 1) does not enable the validation of the signature of `hook` and 2) incorrectly restricts the return type of `hook`, which the docstrings of these methods themselves state can be non-`None`.

The typing of the first parameter of `hook` as `TypeVar("T", bound="Module")` allows the binding of `Callable` whose first parameter is a subclass of `Module`.

---

Here are some examples of:
1. forward hooks and pre-hook hooks being accepted by mypy according to the new type hints
2. mypy throwing errors d.t. incorrect `hook` signatures
3. false negatives of pre-hooks being accepted as forward hooks
4. false negatives of hooks with kwargs being accepted irrespective of the value provided for `with_kwargs`

```python
from typing import Any, Dict, Tuple

import torch
from torch import nn

def forward_pre_hook(
    module: nn.Linear,
    args: Tuple[torch.Tensor, ...],
) -> None:
    ...

def forward_pre_hook_return_input(
    module: nn.Linear,
    args: Tuple[torch.Tensor, ...],
) -> Tuple[torch.Tensor, ...]:
    ...

def forward_pre_hook_with_kwargs(
    module: nn.Linear,
    args: Tuple[torch.Tensor, ...],
    kwargs: Dict[str, Any],
) -> None:
    ...

def forward_pre_hook_with_kwargs_return_input(
    module: nn.Linear,
    args: Tuple[torch.Tensor, ...],
    kwargs: Dict[str, Any],
) -> Tuple[Tuple[torch.Tensor, ...], Dict[str, Any]]:
    ...

def forward_hook(
    module: nn.Linear,
    args: Tuple[torch.Tensor, ...],
    output: torch.Tensor,
) -> None:
    ...

def forward_hook_return_output(
    module: nn.Linear,
    args: Tuple[torch.Tensor, ...],
    output: torch.Tensor,
) -> torch.Tensor:
    ...

def forward_hook_with_kwargs(
    module: nn.Linear,
    args: Tuple[torch.Tensor, ...],
    kwargs: Dict[str, Any],
    output: torch.Tensor,
) -> None:
    ...

def forward_hook_with_kwargs_return_output(
    module: nn.Linear,
    args: Tuple[torch.Tensor, ...],
    kwargs: Dict[str, Any],
    output: torch.Tensor,
) -> torch.Tensor:
    ...

model = nn.Module()

# OK
model.register_forward_pre_hook(forward_pre_hook)
model.register_forward_pre_hook(forward_pre_hook_return_input)
model.register_forward_pre_hook(forward_pre_hook_with_kwargs, with_kwargs=True)
model.register_forward_pre_hook(forward_pre_hook_with_kwargs_return_input, with_kwargs=True)

model.register_forward_hook(forward_hook)
model.register_forward_hook(forward_hook_return_output)
model.register_forward_hook(forward_hook_with_kwargs, with_kwargs=True)
model.register_forward_hook(forward_hook_with_kwargs_return_output, with_kwargs=True)

# mypy(error): [arg-type]
model.register_forward_pre_hook(forward_hook)
model.register_forward_pre_hook(forward_hook_return_output)
model.register_forward_pre_hook(forward_hook_with_kwargs)
model.register_forward_pre_hook(forward_hook_with_kwargs_return_output)

model.register_forward_hook(forward_pre_hook)
model.register_forward_hook(forward_pre_hook_return_input)

# false negatives
model.register_forward_hook(forward_pre_hook_with_kwargs)
model.register_forward_hook(forward_pre_hook_with_kwargs_return_input)

model.register_forward_pre_hook(forward_pre_hook_with_kwargs, with_kwargs=False)
model.register_forward_pre_hook(forward_pre_hook_with_kwargs_return_input, with_kwargs=False)
...
```

---

Though it is not functional as of mypy 0.991, the ideal typing of these methods would use [`typing.Literal`](https://mypy.readthedocs.io/en/stable/literal_types.html#literal-types):

```python
T = TypeVar("T", bound="Module")

class Module:

    @overload
    def register_forward_hook(
        self,
        hook: Callable[[T, Tuple[Any, ...], Any], Optional[Any]],
        *,
        prepend: bool = ...,
        with_kwargs: Literal[False] = ...,
    ) -> RemovableHandle:
        ...

    @overload
    def register_forward_hook(
        self,
        hook: Callable[[T, Tuple[Any, ...], Dict[str, Any], Any], Optional[Any]],
        *,
        prepend: bool = ...,
        with_kwargs: Literal[True] = ...,
    ) -> RemovableHandle:
        ...

    def register_forward_hook(...):
        ...

```

which would:

1. validate the signature of `hook` according to the corresponding literal value provided for `with_kwargs` (and fix the false negative examples above)
2. implicitly define the [fallback `bool` signature](https://github.com/python/mypy/issues/6113#issuecomment-1266186192) e.g. to handle if a non-literal is provided for `with_kwargs`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/92061
Approved by: https://github.com/albanD
2023-01-13 15:45:42 +00:00
joncrall
ad782ff7df Enable xdoctest runner in CI for real this time (#83816)
Builds on #83317 and enables running the doctests. Just need to figure out what is causing the failures.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/83816
Approved by: https://github.com/ezyang, https://github.com/malfet
2022-12-29 05:32:42 +00:00
Douwe den Blanken
b285f1080f Fix small typo in comment (#91247)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91247
Approved by: https://github.com/albanD
2022-12-21 19:45:39 +00:00
Rohan Varma
9c80f13692 [Resubmit] state_dict_pre_hook (#90435)
Resubmit of https://github.com/pytorch/pytorch/pull/88541 which got stale.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/90435
Approved by: https://github.com/fegin
2022-12-08 07:54:14 +00:00
Shen Li
f5d18574a3 Allow Module forward-pre and forward hooks to take kwargs (#89389)
closes #35643

This PR is mostly borrowed from #82042. Thanks @Padarn for implementing
the first version and debugging into the errors.

Based on the discussion in #82042 this PR adds a with_kwargs
argument to register_forward_pre_hook and register_forward_hook
methods. When the arg is set to true, the provided hook must accept
kwargs args. Under the hook, this PR adds a
`_forward_pre_hooks_with_kwargs` and a `_forward_hook_with_kwargs`
set to keep track of which hooks accept kwargs.

Differential Revision: [D41431111](https://our.internmc.facebook.com/intern/diff/D41431111)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/89389
Approved by: https://github.com/soulitzer
2022-11-23 02:43:32 +00:00
soulitzer
6b521bbf35 Prevent module full_backward_hook from erroring in double backward (#88357)
Also clarifies documentation to say "execute if and only if gradients wrt outputs are computed" (previously, "execute every time gradients wrt inputs are computed")

See https://docs.google.com/document/d/1tFZKYdsSzRBJ7Di7SWt8X8fSg-E3eiUPwomMF10UyhM/edit for more details regarding the question: 'should module full_backward_hooks be called every time the gradients wrt module inputs are called, or should module full_backward_hooks only be called when the "backward for the module" have been computed?'

Fixes https://github.com/pytorch/pytorch/issues/88312

Pull Request resolved: https://github.com/pytorch/pytorch/pull/88357
Approved by: https://github.com/albanD
2022-11-16 19:27:30 +00:00
Samantha Andow
87238e6491 [nn] add remove_duplicate flag to named_parameters (#759) (#88090)
Summary:
X-link: https://github.com/pytorch/torchrec/pull/759

Since the remove_duplicate flag was added to named_buffers in D39493161 (c12f829cce), this adds the same flag to named_parameters

Test Plan:
python test/test_nn.py -k test_buffers_and_named_buffers

OSS Tests

Differential Revision: D40801899

Pull Request resolved: https://github.com/pytorch/pytorch/pull/88090
Approved by: https://github.com/albanD
2022-11-09 00:09:20 +00:00
Kazuaki Ishizaki
2ddefbdc3c Fix typos used in documents under torch directory (#88300)
This PR fixes typos, in comments of Python files, that are found from a search box at https://pytorch.org/docs/master/search.html

Pull Request resolved: https://github.com/pytorch/pytorch/pull/88300
Approved by: https://github.com/lezcano
2022-11-02 09:38:13 +00:00
Shen Li
82698b8954 Add prepend argument to nn.Module hooks (#87370)
cc @ezyang @gchanan
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87370
Approved by: https://github.com/soulitzer
2022-10-25 19:18:04 +00:00
Antonio Kim
6b59d9b566 Fix registration hooks (#87369)
There is a bug in the implementation of the registration hooks introduced in https://github.com/pytorch/pytorch/pull/86148 whereby if the hook returns a tensor, then the short circuiting logic:
```
value = hook(self, name, value) or value
```
Raises an exception
```
RuntimeError: Boolean value of Tensor with more than one value is ambiguous
```
Fixing the logic so that it only checks to see if the value is `None` before overriding

Fixes #85837

CC: @albanD @jbschlosser
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87369
Approved by: https://github.com/albanD
2022-10-21 05:12:25 +00:00
Kshiteej K
54ee95c8ec [nn] module: full_backward_pre_hook (#86700)
Fixes https://github.com/pytorch/pytorch/issues/42824

* [x] Test
* [x] Doc
Pull Request resolved: https://github.com/pytorch/pytorch/pull/86700
Approved by: https://github.com/soulitzer
2022-10-13 17:36:39 +00:00
Antonio Kim
09a676f639 Add hooks for register_buffer/module/parameter (#86148)
As described in the issue, this PR adds hooks to be run when `register_parameter`, `register_buffer` and `register_module` are called.

Fixes #85837

cc @albanD @mruberry @jbschlosser @walterddr @kshitij12345 @saketh-are
Pull Request resolved: https://github.com/pytorch/pytorch/pull/86148
Approved by: https://github.com/albanD
2022-10-12 20:57:22 +00:00
Jerry Zhang
c12f829cce [nn] Add remove_duplicate flag to named_buffers (#674) (#85903)
Summary:
X-link: https://github.com/pytorch/torchrec/pull/674

Pull Request resolved: https://github.com/pytorch/pytorch/pull/84984

this is to allow named_buffers to return the same buffer objects with different names multiple times, needed by internal use cases
ghstack-source-id: 168589597

Test Plan:
python test/test_nn.py -k test_buffers_and_named_buffers

Imported from OSS

Reviewed By: albanD

Differential Revision: D39493161

Pull Request resolved: https://github.com/pytorch/pytorch/pull/85903
Approved by: https://github.com/albanD
2022-10-11 18:49:09 +00:00
Weiyi Zheng
b2311192e6 [NN module] speed up _load_from_state_dict (#85743)
Fixes #61398

The original implementation is very slow when the state_dict.keys() is long. This PR only passes relevant keys to the child module.

existing test passes: `pytest test/test_nn.py -k state_dict`
I couldn't figure out a good way to write a new test for this new behavior. Had a new snippet, but it will be flaky if integrated into the main CI because it's a timing based check.
But I can verify that the test took 30s to run, after this PR it only takes 0.5s.

```python
    def test_load_state_dict_large(self):
        # construct a module with 4 levels of module, 10 linear each, leads to 10k items in the dictionary
        import copy
        import time
        base_module = nn.Linear(1,1)
        model = base_module
        for level in range(4):
           model = nn.Sequential(*[copy.deepcopy(model) for _ in range(10)])
        state_dict = model.state_dict()
        self.assertEqual(len(state_dict), 20000)
        st = time.time()
        model.load_state_dict(state_dict, strict=True)
        strict_load_time = time.time() - st
        # it took 0.5 seconds to
        self.assertLess(strict_load_time, 10)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85743
Approved by: https://github.com/albanD
2022-09-28 15:26:03 +00:00
joncrall
b136f3f310 More doctest refinements. (#83317)
Follow up to #82797

Now that the doctests themselves are in a better state, we should be able to enable xdoctest on the CI so they stay that way.

@ezyang @vadimkantorov
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83317
Approved by: https://github.com/ezyang
2022-08-22 20:07:26 +00:00
Ian Barber
76d5699e13 Fix use-generator lint warnings in module.py (#83700)
% pylint --disable=all --enable=R1729 torch/nn/modules/module.py
Verified in pylint 2.14.5

--------------------------------------------------------------------
Your code has been rated at 10.00/10 (previous run: 10.00/10, +0.00)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/83700
Approved by: https://github.com/kit1980, https://github.com/albanD
2022-08-19 02:51:44 +00:00