Commit Graph

20 Commits

Author SHA1 Message Date
Shen Li
f5d18574a3 Allow Module forward-pre and forward hooks to take kwargs (#89389)
closes #35643

This PR is mostly borrowed from #82042. Thanks @Padarn for implementing
the first version and debugging into the errors.

Based on the discussion in #82042 this PR adds a with_kwargs
argument to register_forward_pre_hook and register_forward_hook
methods. When the arg is set to true, the provided hook must accept
kwargs args. Under the hook, this PR adds a
`_forward_pre_hooks_with_kwargs` and a `_forward_hook_with_kwargs`
set to keep track of which hooks accept kwargs.

Differential Revision: [D41431111](https://our.internmc.facebook.com/intern/diff/D41431111)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/89389
Approved by: https://github.com/soulitzer
2022-11-23 02:43:32 +00:00
soulitzer
6b521bbf35 Prevent module full_backward_hook from erroring in double backward (#88357)
Also clarifies documentation to say "execute if and only if gradients wrt outputs are computed" (previously, "execute every time gradients wrt inputs are computed")

See https://docs.google.com/document/d/1tFZKYdsSzRBJ7Di7SWt8X8fSg-E3eiUPwomMF10UyhM/edit for more details regarding the question: 'should module full_backward_hooks be called every time the gradients wrt module inputs are called, or should module full_backward_hooks only be called when the "backward for the module" have been computed?'

Fixes https://github.com/pytorch/pytorch/issues/88312

Pull Request resolved: https://github.com/pytorch/pytorch/pull/88357
Approved by: https://github.com/albanD
2022-11-16 19:27:30 +00:00
Kshiteej K
54ee95c8ec [nn] module: full_backward_pre_hook (#86700)
Fixes https://github.com/pytorch/pytorch/issues/42824

* [x] Test
* [x] Doc
Pull Request resolved: https://github.com/pytorch/pytorch/pull/86700
Approved by: https://github.com/soulitzer
2022-10-13 17:36:39 +00:00
anjali411
0183c1e336 Add __all__ to torch.utils submodules (#85331)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85331
Approved by: https://github.com/albanD
2022-09-27 14:45:26 +00:00
soulitzer
1cafb1027f Fix leak when create_graph and full backward hook registered (#82788)
Fixes #82528
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82788
Approved by: https://github.com/albanD
2022-08-05 15:35:36 +00:00
Rodrigo Berriel
b71f01f70d Fix full backward hook when grad is disabled (#65335)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/59901. See discussion in the issue.

cc albanD soulitzer

Pull Request resolved: https://github.com/pytorch/pytorch/pull/65335

Reviewed By: malfet

Differential Revision: D31055865

Pulled By: albanD

fbshipit-source-id: 53605df62bc73c99d8908248087ab400b81ac495
2021-09-20 13:31:19 -07:00
Sigmund_Rolfsjord
8b12c8e8b3 Fixes: register_full_backward_hook crash if first argument don't require a gradient (#57944) (#57945)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/57944

Pull Request resolved: https://github.com/pytorch/pytorch/pull/57945

Reviewed By: mruberry

Differential Revision: D28351929

Pulled By: albanD

fbshipit-source-id: d0db898e6bf13d1877cd81892a5a65c7854c8102
2021-05-11 15:07:35 -07:00
albanD
22b151a3ba Make sure full backward hook fire when no input requires grad (#56693)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/56380

BC-breaking note:
This changes the behavior of full backward hooks as they will now fire properly even if no input to the Module require gradients.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/56693

Reviewed By: ezyang

Differential Revision: D27947030

Pulled By: albanD

fbshipit-source-id: e8353d769ba5a2c1b6bdf3b64e2d61308cf624a2
2021-04-23 08:46:49 -07:00
Chester Liu
f6df18f6ca Clean up future imports for Python 2 (#53349)
Summary:
See https://github.com/pytorch/pytorch/issues/42919

Pull Request resolved: https://github.com/pytorch/pytorch/pull/53349

Reviewed By: malfet

Differential Revision: D27039089

Pulled By: bugra

fbshipit-source-id: 8063dc184248604506a8dbb1bcb73da8ec85bb18
2021-03-14 15:56:13 -07:00
albanD
ccd646696b Fix Module backward hooks for all Tensor inputs/outputs (#46163)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/598

This is BC-breaking as we now explicitly don't call the hook when there are not Tensors at the top level of the output.
This feature was not working anyways as the returned grad_input/grad_output were wrong (not respecting the output structure and wrong inputs for multi-Node Module).

This is also BC-breaking as we now report the correct gradients for `nn.Module`s that contain multiple autograd `Node`s while we use to return bad results before.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/46163

Reviewed By: ailzhang, mruberry

Differential Revision: D24894180

Pulled By: albanD

fbshipit-source-id: e1b5d193d2818eb2f51e2a2722c7405c8bd13c2b
2020-12-18 09:04:36 -08:00
Xiang Gao
20ac736200 Remove py2 compatible future imports (#44735)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/44735

Reviewed By: mruberry

Differential Revision: D23731306

Pulled By: ezyang

fbshipit-source-id: 0ba009a99e475ddbe22981be8ac636f8a1c8b02f
2020-09-16 12:55:57 -07:00
Ralf Gommers
da32bf4cc6 Move type annotations for remaining torch.utils stub files inline (#43406)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/43406

Reviewed By: mruberry

Differential Revision: D23319736

Pulled By: malfet

fbshipit-source-id: e25fbb49f27aa4893590b022441303d6d98263a9
2020-08-31 18:44:09 -07:00
olramde
d770fbc1d2 Some modifications to improve readability (#31352)
Summary:
In the long string, formalstring thinks it is good to have a name.

When using dict, literal is better for readability and faster than dict constructor.

I always appreciate your efforts in creating the world's best frameworks.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/31352

Differential Revision: D19191967

Pulled By: ngimel

fbshipit-source-id: 21f063b163b67de8cf9761a4db5991f74318e991
2020-01-02 12:48:34 -08:00
Yangqing Jia
c47f680086 arc lint torch/utils (#13141)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/13141

This is an example diff to show what lint rules are being applied.

Reviewed By: mingzhe09088

Differential Revision: D10858478

fbshipit-source-id: cbeb013f10f755b0095478adf79366e7cf7836ff
2018-10-25 14:59:03 -07:00
Edward Yang
3bfa7258b3 Don't serialize hooks (#11705)
Summary:
Fixes #11683.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/11705

Differential Revision: D9833057

Pulled By: ezyang

fbshipit-source-id: 18af9bcd77b088326738d567100fbe4a4c869dd6
2018-10-16 20:11:03 -07:00
Sam Gross
ea563c1df1 Make weight norm pickleable (#2066) 2017-07-12 17:21:22 -04:00
Sam Gross
6336300880 Fix bug where adding a hook could replace an existing hook.
We were keying hooks by RemovableHandle id. However, we don't hold onto
handles and ids of dead objects can be reused. This replaces id(handle)
with a global counter.
2017-03-06 12:47:53 -08:00
Luke Yeager
e7c1e6a8e3 [pep8] Fix most lint automatically with autopep8
Here's the command I used to invoke autopep8 (in parallel!):

    git ls-files | grep '\.py$' | xargs -n1 -P`nproc` autopep8 -i

Several rules are ignored in setup.cfg. The goal is to let autopep8
handle everything which it can handle safely, and to disable any rules
which are tricky or controversial to address. We may want to come back
and re-enable some of these rules later, but I'm trying to make this
patch as safe as possible.

Also configures flake8 to match pep8's behavior.

Also configures TravisCI to check the whole project for lint.
2017-01-28 01:15:51 +01:00
Sam Gross
69d8331195 Use functools.partial 2017-01-13 23:10:45 +01:00
Sam Gross
7e4ddcfe8a Remove names from register_hook calls (#446)
The register hook calls now return an object that can be used to remove
the hook. For example,

   >>> h = module.register_forward_hook(callback)
   >>> h.remove()  # removes hook

Or as a context manager:

   >>> with module.register_forward_hook(callback):
   ...     pass

This makes it easier for libraries to use hooks without worrying about
name collisions.
2017-01-13 15:57:03 -05:00