Happy to split this PR more if it helps.
This PR adds functorch.grad support for autograd.Function. There's a lot
going on; here is the high level picture and there are more details as
comments in the code.
Mechanism (PyOperator)
- Somehow, autograd.Function needs to dispatch with functorch. This is
necessary because every layer of functorch needs to see the
autograd.Function; grad layers need to preserve the backward pass.
- The mechanism for this is via PyOperator. If functorch transforms are
active, then we wrap the autograd.Function in a `custom_function_call`
PyOperator where we are able to define various rules for functorch
transforms.
- `custom_function_call` has a rule for the functorch grad transform.
autograd.Function changes
- I needed to make some changes to autograd.Function to make this work.
- First, this PR splits autograd.Function into a _SingleLevelFunction
(that works with a single level of functorch transform) and
autograd.Function (which works with multiple levels). This is necessary
because functorch's grad rule needs some way of specifying a backward
pass for that level only.
- This PR changes autograd.Function's apply to either call
`custom_function_call` (if functorch is active) or super().apply (if
functorch isn't active); see the sketch at the end of this description.
Testing
- Most of this PR is just testing. It creates an autograd.Function
OpInfo database that then gets passed to the functorch grad-based tests
(grad, vjp, vjpvjp).
- Since functorch transform tests are autogenerated from OpInfo tests,
this is the easiest way to test various autograd.Functions with
functorch.
Future
- jvp and vmap support coming next
- better error message (functorch only supports autograd.Function that
have the optional setup_context staticmethod)
- documentation to come when we remove the feature flag
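For a concrete picture of the user-facing shape this enables, here is a minimal sketch. It assumes the setup_context-style autograd.Function (see the next PR below) and that the feature flag is enabled; exact import paths may differ:
```python
import torch
from functorch import grad  # functorch's grad transform

class MySin(torch.autograd.Function):
    @staticmethod
    def forward(x):
        return torch.sin(x)

    @staticmethod
    def setup_context(ctx, inputs, output):
        # ctx-specific logic lives here, not in forward()
        x, = inputs
        ctx.save_for_backward(x)

    @staticmethod
    def backward(ctx, grad_out):
        x, = ctx.saved_tensors
        return grad_out * torch.cos(x)

# grad() dispatches through the custom_function_call PyOperator, so
# MySin.backward is preserved for that grad level.
g = grad(MySin.apply)(torch.tensor(0.5))
```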
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89860
Approved by: https://github.com/soulitzer
Adds a setup_context staticmethod to autograd.Function.
If it exists, then the user splits the ctx-specific logic from the
forward() and puts it in the setup_context staticmethod.
Docs will come later when we remove the feature flag.
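A minimal before/after sketch of the split (illustrative; the exact signature may change while behind the feature flag):
```python
import torch

# Before: ctx handling is mixed into forward()
class ScaleOld(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, scale):
        ctx.scale = scale
        return x * scale

    @staticmethod
    def backward(ctx, grad_out):
        return grad_out * ctx.scale, None

# After: forward() is ctx-free; setup_context() receives the inputs and output
class ScaleNew(torch.autograd.Function):
    @staticmethod
    def forward(x, scale):
        return x * scale

    @staticmethod
    def setup_context(ctx, inputs, output):
        _, scale = inputs
        ctx.scale = scale

    @staticmethod
    def backward(ctx, grad_out):
        return grad_out * ctx.scale, None
```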
Test Plan:
- some light tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89859
Approved by: https://github.com/soulitzer
Not sure why, but a top-level `using namespace` directive causes VC++ to fail with the following error (if the C++17 standard is used, but everything is fine with C++14):
```
C:\actions-runner\_work\pytorch\pytorch\third_party\pybind11\include\pybind11\detail\../pytypes.h(1520): error C2872: 'attr': ambiguous symbol
C:\actions-runner\_work\pytorch\pytorch\aten\src\ATen/core/interned_strings.h(349): note: could be 'c10::attr'
C:\actions-runner\_work\pytorch\pytorch\torch/csrc/jit/ir/ir.h(75): note: or 'torch::jit::attr'
C:\actions-runner\_work\pytorch\pytorch\cmake\..\third_party\pybind11\include\pybind11/pybind11.h(1094): note: see reference to function template instantiation 'pybind11::str pybind11::str::format<_Ty1&>(_Ty1 &) const' being compiled
with
[
_Ty1=pybind11::handle
]
```
Solve this by replacing the global `using namespace torch::jit;` with
fully qualified references to the objects/methods from those namespaces.
Another prep change for https://github.com/pytorch/pytorch/70188
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90228
Approved by: https://github.com/kit1980, https://github.com/albanD
Now that there will be two types of Python function prehooks, I prefer to treat the PyFunction hook that takes all grad_outputs and returns all grad_inputs as the more "canonical" one.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83225
Approved by: https://github.com/albanD
Add a flag (inline_autograd) to enable inline export of models consisting of autograd functions. Currently, this flag should only be used in TrainingMode.EVAL and not for training.
An example:
If a model containing ``autograd.Function`` is as follows
```
class AutogradFunc(torch.autograd.Function):
    @staticmethod
    def forward(ctx, i):
        result = i.exp()
        result = result.log()
        ctx.save_for_backward(result)
        return result
```
Then the model is exported as
```
graph(%0 : Float):
  %1 : Float = ^AutogradFunc(%0)
  return (%1)
```
If inline_autograd is set to True, this will be exported as
```
graph(%0 : Float):
  %1 : Float = onnx::Exp(%0)
  %2 : Float = onnx::Log(%1)
  return (%2)
```
If one of the ops within the autograd module is not supported, that particular node is exported as-is, mirroring ONNX_FALLTHROUGH mode.
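For reference, a hypothetical usage sketch: `inline_autograd` is the flag named in this PR, but whether it is surfaced directly as a `torch.onnx.export` keyword argument is an assumption here:
```python
import torch

class Wrapper(torch.nn.Module):
    def forward(self, x):
        return AutogradFunc.apply(x)  # AutogradFunc as defined above

# Export in eval mode; inline_autograd=True asks the exporter to inline the
# autograd.Function body (Exp + Log) instead of emitting a ^AutogradFunc node.
torch.onnx.export(Wrapper(), torch.randn(2, 3), "model.onnx", inline_autograd=True)
```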
Fixes: #61813
Pull Request resolved: https://github.com/pytorch/pytorch/pull/74765
Approved by: https://github.com/BowenBao, https://github.com/malfet
We define specializations for pybind11 defined templates
(in particular, PYBIND11_DECLARE_HOLDER_TYPE) and consequently
it is important that these specializations *always* be #include'd
when making use of pybind11 templates whose behavior depends on
these specializations, otherwise we can cause an ODR violation.
The easiest way to ensure that all the specializations are always
loaded is to designate a header (in this case, torch/csrc/util/pybind.h)
that ensures the specializations are defined, and then add a lint
to ensure this header is included whenever pybind11 headers are
included.
The existing grep linter didn't have enough knobs to do this
conveniently, so I added some features. I'm open to suggestions
for how to structure the features better. The main changes:
- Added an --allowlist-pattern flag, which turns off the grep lint
if some other line exists. This is used to stop the grep
lint from complaining about pybind11 includes if the util
include already exists.
- Added --match-first-only flag, which lets grep only match against
the first matching line. This is because, even if there are multiple
includes that are problematic, I only need to fix one of them.
We don't /really/ need this, but when I was running lintrunner -a
to fixup the preexisting codebase it was annoying without this,
as the lintrunner overall driver fails if there are multiple edits
on the same file.
I excluded any files that didn't otherwise have a dependency on
torch/ATen; this was mostly caffe2 and the valgrind wrapper compat
bindings.
Note the grep replacement is kind of crappy, but clang-tidy lint
cleaned it up in most cases.
See also https://github.com/pybind/pybind11/issues/4099
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82552
Approved by: https://github.com/albanD
Summary:
Adds the sequence ids of input-producing ops to the NVTX markers. This feature adds
additional information to the NVTX marker string, e.g. seq_ids=[101, 102, 103]. This
indicates the sequence id of the op which produced each input tensor, based on its
position index in the array. In the above example, input tensor 0 was produced by
the node with sequence id 101, input tensor 1 by node 102, and input tensor 2 by the
node with sequence id 103. This is the same way the sizes array is
organized. If you know the sequence id of the node and the sequence ids
of the input edges, then you have enough information to construct the
network graph.
Fixes https://github.com/pytorch/pytorch/issues/66105
Pull Request resolved: https://github.com/pytorch/pytorch/pull/70264
Reviewed By: chaekit
Differential Revision: D34792707
Pulled By: robieta
fbshipit-source-id: 4407b853c929a737505803b0db77a8ecd966cce2
(cherry picked from commit cd3c0c8c9d4d63d7897f60521c407883240d1d5b)
Summary:
Minimal example that deadlocks before but not after:
```python
import torch
from torch.autograd import Function
class Foo(Function):
    @staticmethod
    def forward(ctx, x):
        return x.clone()

    @staticmethod
    def backward(ctx, gO):
        return gO.clone()

def get_out():
    inp = torch.rand(2, requires_grad=True)
    # The python function is first so that it runs
    # last in the backward pass
    right = Foo.apply(inp)
    # An op that creates new memory
    left1 = inp.clone()
    # An op that saves its input
    left2 = left1 ** 2
    # Inplace modify so that the backward for
    # left2 always raises an error
    left1 += 1
    # An op that takes both sides as input.
    # After running, both sides' last ops will be in
    # the ready queue, and the op for left will run
    # first as it was executed last during the forward
    out = left2 + right
    return out

# Nothing should be global variables here as, from what
# I can see, python leaks all the global objects
get_out().sum().backward()
```
Since this requires the python interpreter to die, it is hard to test in CI.
Let me know if you have an idea how to do it though.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/73961
Reviewed By: malfet
Differential Revision: D34752747
Pulled By: albanD
fbshipit-source-id: 1a537b1f733e161e8d3ff053cd432b37b34d432a
(cherry picked from commit 17943e4c04c782d81deab439e010195f04e75bbd)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/71866
See title. There is a minimal perf regression for the non-functorch case
(a TLS access and a null check).
Test Plan: Imported from OSS
Reviewed By: soulitzer
Differential Revision: D33825279
Pulled By: zou3519
fbshipit-source-id: afa2ad5a672cc9225d2bb6b46ee7f3f1513c1e02
(cherry picked from commit 17ae1d3e9d)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/71569
Not sure if this is the right API
Test Plan: Imported from OSS
Reviewed By: albanD
Differential Revision: D33695395
Pulled By: soulitzer
fbshipit-source-id: 652b5758f15d901f98ff0da94e977030c7f3415b
(cherry picked from commit 9421a6846a)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/66744
Modified loops in files under fbsource/fbcode/caffe2/ from the format
`for(TYPE var=x0;var<x_max;x++)`
to the format
`for(const auto var: irange(xmax))`
This was achieved by running r-barnes's loop upgrader script (D28874212) with some modification to exclude all files under /torch/jit and a number of reversions or unused variable suppression warnings added by hand.
Test Plan: Sandcastle
Reviewed By: ngimel
Differential Revision: D31705358
fbshipit-source-id: d6ea350cbaa8f452fc78f238160e5374be637a48
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/66234
Modified loops in files under fbsource/fbcode/caffe2/ from the format
`for(TYPE var=x0;var<x_max;x++)`
to the format
`for(const auto var: irange(xmax))`
This was achieved by running r-barnes's loop upgrader script (D28874212) with some modification to exclude all files under /torch/jit and a number of reversions or unused variable suppression warnings added by hand.
bypass_size_limit
allow-large-files
Test Plan: Sandcastle
Reviewed By: ngimel
Differential Revision: D30652629
fbshipit-source-id: 0ae6c4bbbb554bad42e372792a6430e1acf15e3e
Summary:
Also use a range-based loop instead of a regular one
Pull Request resolved: https://github.com/pytorch/pytorch/pull/66315
Reviewed By: albanD
Differential Revision: D31503730
Pulled By: malfet
fbshipit-source-id: f5568f7f28e15a9becd27986dd061a6fcae34651
Summary:
When generating IR for autograd.Function, if the function has multiple outputs, a TupleUnpack may be inserted after the original function node, and PyTorch only assigns proper information (tensor element type and shape) to the TupleUnpack and forgets the original function node. In contrast, if autograd.Function only produces one output, the original function node may have tensor
element type and shape in its output schema.
Before this PR:
- (simplified) IR for autograd.Function with one output: input (tensor, dtype=float32, shape=[2, 3]) -> PythonOp -> output (tensor, dtype=float32, shape=[4, 5])
- (simplified) IR for autograd.Function with multiple outputs: input (tensor, dtype=float32, shape=[2, 3]) -> PythonOp -> output_0 **(tensor)**, output_1 **(tensor)** -> TupleUnpack output_2 (tensor, dtype=float32, shape=[4, 5]), output_3 (tensor, dtype=float32, shape=[6, 7])
After this PR:
- (simplified) IR for autograd.Function with one output: input (tensor, dtype=float32, shape=[2, 3]) -> PythonOp -> output (tensor, dtype=float32, shape=[4, 5])
- (simplified) IR for autograd.Function with multiple outputs: input (tensor, dtype=float32, shape=[2, 3]) -> PythonOp -> output_0 **(tensor, dtype=float32, shape=[4, 5])**, output_1 **(tensor, dtype=float32, shape=[6, 7])** -> TupleUnpack output_2 (tensor, dtype=float32, shape=[4, 5]), output_3 (tensor, dtype=float32, shape=[6, 7])
Pull Request resolved: https://github.com/pytorch/pytorch/pull/57966
Reviewed By: zhxchen17
Differential Revision: D30208207
Pulled By: gmagogsfm
fbshipit-source-id: 42a3d1f9c0932133112a85df0c49cf4ea0afa175
Summary:
The GoogleTest `TEST` macro is non-compliant with the `cppcoreguidelines-avoid-non-const-global-variables` check, as is `DEFINE_DISPATCH`.
All changes but the ones to `.clang-tidy` are generated using the following script:
```
for i in `find . -type f -iname "*.c*" -or -iname "*.h"|xargs grep cppcoreguidelines-avoid-non-const-global-variables|cut -f1 -d:|sort|uniq`; do sed -i "/\/\/ NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)/d" $i; done
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/62008
Reviewed By: driazati, r-barnes
Differential Revision: D29838584
Pulled By: malfet
fbshipit-source-id: 1b2f8602c945bd4ce50a9bfdd204755556e31d13
Summary:
Switches most of the simple for loops outside of `jit` directories to use `c10::irange`.
Generated with D28874212.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/59481
Test Plan: Sandcastle
Reviewed By: ngimel
Differential Revision: D28909681
fbshipit-source-id: ec9ab1bd602933238d9d0f73d4d8d027b75d9d85
Summary:
For Facebook employees, this fixes some internal failures from https://www.internalfb.com/tasks/?t=92100671
This was not a problem before https://github.com/pytorch/pytorch/pull/58271 because these cycles used to just be leaked (so nothing was cleared/dealloced).
Now that we properly clean up these cycles, we have to fix the assert in the clear.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/59301
Reviewed By: jbschlosser
Differential Revision: D28841564
Pulled By: albanD
fbshipit-source-id: e2ec51f6abf44c4e3a83c293e90352295a43ba37
Summary:
Fixes https://github.com/pytorch/pytorch/issues/30696
### Release Notes
Instantiating a custom autograd function is now deprecated. Users should call `.apply()` on the class itself because it is a static method.
--end release notes--
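For illustration, a minimal sketch of the deprecated pattern versus the recommended one (the Function and tensor names are made up):
```python
import torch

class Exp(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x.exp()

    @staticmethod
    def backward(ctx, grad_out):
        x, = ctx.saved_tensors
        return grad_out * x.exp()

x = torch.randn(3, requires_grad=True)
# y = Exp()(x)     # deprecated old-style usage: instantiating the Function
y = Exp.apply(x)   # recommended: call .apply() on the class itself
```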
- There are a couple error messages that we can't entirely remove because accessing these attributes of the autograd function instance may segfault (due to cdata being nullptr). Also added a TORCH_CHECK for the name attribute which previously segfaulted.
- Error message updated to convey 1) old-style functions have been deprecated 2) this access pattern was once valid
- Updates variable -> Tensor for some error messages
Pull Request resolved: https://github.com/pytorch/pytorch/pull/57357
Reviewed By: mrshenli
Differential Revision: D28193095
Pulled By: soulitzer
fbshipit-source-id: f021b105e9a3fd4a20d6ee3dfb6a06a8c34b10ca
Summary:
In my last PR I missed the CUDA and distributed folders; this fixes that.
This change is autogenerated by `python tools/clang_tidy.py -s`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/57235
Reviewed By: janeyx99
Differential Revision: D28084444
Pulled By: malfet
fbshipit-source-id: bf222f69ee90c7872c3cb0931e8cdb84f0cb3cda
Summary:
This is an automatic change generated by the following script:
```
#!/usr/bin/env python3
from subprocess import check_output, check_call
import os

def get_compiled_files_list():
    import json
    with open("build/compile_commands.json") as f:
        data = json.load(f)
    files = [os.path.relpath(node['file']) for node in data]
    for idx, fname in enumerate(files):
        if fname.startswith('build/') and fname.endswith('.DEFAULT.cpp'):
            files[idx] = fname[len('build/'):-len('.DEFAULT.cpp')]
    return files

def run_clang_tidy(fname):
    check_call(["python3", "tools/clang_tidy.py", "-c", "build", "-x", fname, "-s"])
    changes = check_output(["git", "ls-files", "-m"])
    if len(changes) == 0:
        return
    check_call(["git", "commit", "--all", "-m", f"NOLINT stubs for {fname}"])

def main():
    git_files = check_output(["git", "ls-files"]).decode("ascii").split("\n")
    compiled_files = get_compiled_files_list()
    for idx, fname in enumerate(git_files):
        if fname not in compiled_files:
            continue
        if fname.startswith("caffe2/contrib/aten/"):
            continue
        print(f"[{idx}/{len(git_files)}] Processing {fname}")
        run_clang_tidy(fname)

if __name__ == "__main__":
    main()
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/56892
Reviewed By: H-Huang
Differential Revision: D27991944
Pulled By: malfet
fbshipit-source-id: 5415e1eb2c1b34319a4f03024bfaa087007d7179
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/55799
I'm going to change the implementation of cdata soon, so I need to
abstract over cdata access with a function. Additionally, many
users are manually casting to THPVariable to access
the member, so I can remove these unsafe casts in the client code
(the implementation, of course, is still doing an unsafe cast).
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Test Plan: Imported from OSS
Reviewed By: albanD
Differential Revision: D27712130
Pulled By: ezyang
fbshipit-source-id: 95fcc013bf3913d67f2c634068eb5b3aab144cb3
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/52422
As mentioned in https://github.com/pytorch/pytorch/issues/52415,
`torch.utils.checkpoint` doesn't support checkpointing for functions which have
non-tensor inputs and outputs.
This PR resolves this issue by ensuring the autograd machinery ignores the
non-tensor inputs and outputs and processes the tensors accordingly.
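A minimal sketch of the kind of usage this enables (the function and argument names are illustrative):
```python
import torch
from torch.utils.checkpoint import checkpoint

def run_fn(x, scale, tag):
    # `scale` (a float) and `tag` (a str) are non-tensor inputs;
    # the returned string is a non-tensor output.
    y = (x * scale).relu()
    return y, tag

x = torch.randn(4, requires_grad=True)
out, tag = checkpoint(run_fn, x, 2.0, "stage1")
out.sum().backward()  # gradients flow through the tensor arguments only
```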
ghstack-source-id: 124406867
Test Plan:
1) unit test
2) waitforbuildbot
Reviewed By: albanD
Differential Revision: D26507228
fbshipit-source-id: 0a5a1591570814176185362e83ad18dabd9c84b0
Summary:
Fixes https://github.com/pytorch/pytorch/issues/49756
## Background
Fix applied here is to remove the grad enabled check from `collect_next_edges`, unconditionally returning the actual collected edges. This pushes the responsibility for determining whether the function should be called without grad mode to its call-sites. With this update, `collect_next_edges` will no longer incorrectly return an empty list, which caused the problem described in the issue. Three call-sites depended on this behavior and have been updated.
Beyond bad printing side effects, this fix addresses the more general issue of accessing `grad_fn` with grad mode disabled after an in-place operation on a view. The included test verifies this without the use of print.
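A sketch of the access pattern described above, based on the wording here rather than on the exact regression test:
```python
import torch

base = torch.randn(2, 3, requires_grad=True).clone()  # non-leaf, so in-place on a view is allowed
view = base[0]   # a view of base
view.mul_(2)     # in-place op on the view; its grad_fn is rebased lazily

with torch.no_grad():
    # Accessing grad_fn with grad mode disabled previously went through the
    # collect_next_edges path described above (printing triggers the same access).
    print(view.grad_fn)
```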
Pull Request resolved: https://github.com/pytorch/pytorch/pull/51364
Test Plan:
```
python test/test_autograd.py TestAutogradDeviceTypeCPU.test_inplace_view_then_no_grad_cpu
```
Reviewed By: zou3519
Differential Revision: D26190451
Pulled By: jbschlosser
fbshipit-source-id: 9b004a393463f8bd4ac0690e5e53c07a609f87f0
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/46227
Follow up from https://github.com/pytorch/pytorch/issues/45419, in
this PR I've removed as many PyCFunction casts as I could from the codebase.
The only ones I didn't remove were the ones with `METH_VARARGS | METH_KEYWORDS`
which have 3 parameters instead of 2 and had to be cast. Example:
`{"copy_", (PyCFunction)(void(*)(void))THPStorage_(copy_), METH_VARARGS | METH_KEYWORDS, nullptr},`
ghstack-source-id: 114632704
Test Plan: waitforbuildbot
Reviewed By: albanD
Differential Revision: D24269435
fbshipit-source-id: 025cfd43a9a2a3e59f6b2951c1a78749193d77cf
Summary:
Fixes https://github.com/pytorch/pytorch/issues/43405.
This pull request adds a feature that prints all tracebacks if `detect_anomaly` mode detects `nan` in nested backward operations.
The way I did it is by assigning a node as a parent to all nodes it produces during its backward calculation. Then, if one of the children produces `nan`, it will print the traceback from the parent and grandparents (if any).
The parent is stored in the `parent_node_` member of the `Node` class, which is accessible in C++ via `node->parent()` and in Python via `node.parent_function`.
A node has a parent iff:
1. it is created from a backward operation, and
2. created when anomaly mode and grad mode are both enabled.
An example of this feature:
```python
import torch

def example():
    x = torch.tensor(1.0, requires_grad=True)
    y = torch.tensor(1e-8, requires_grad=True)  # small to induce nan in n-th backward
    a = x * y
    b = x * y
    z1 = a / b  # can produce nan in n-th backward as long as https://github.com/pytorch/pytorch/issues/43414 is unsolved
    z = z1 * z1
    gy , = torch.autograd.grad( z , (y,), create_graph=True)
    gy2, = torch.autograd.grad(gy , (y,), create_graph=True)
    gy3, = torch.autograd.grad(gy2, (y,), create_graph=True)
    gy4, = torch.autograd.grad(gy3, (y,), create_graph=True)
    return gy4

with torch.autograd.detect_anomaly():
    gy4 = example()
```
with output:
```
example.py:16: UserWarning: Anomaly Detection has been enabled. This mode will increase the runtime and should only be enabled for debugging.
with torch.autograd.detect_anomaly():
/home/mfkasim/anaconda2/envs/base3/lib/python3.8/site-packages/torch/autograd/__init__.py:190: UserWarning: Error detected in DivBackward0. Traceback of forward call that caused the error:
File "example.py", line 17, in <module>
gy4 = example()
File "example.py", line 12, in example
gy3, = torch.autograd.grad(gy2, (y,), create_graph=True)
File "/home/mfkasim/anaconda2/envs/base3/lib/python3.8/site-packages/torch/autograd/__init__.py", line 190, in grad
return Variable._execution_engine.run_backward(
(Triggered internally at ../torch/csrc/autograd/python_anomaly_mode.cpp:61.)
return Variable._execution_engine.run_backward(
/home/mfkasim/anaconda2/envs/base3/lib/python3.8/site-packages/torch/autograd/__init__.py:190: UserWarning:
Traceback of forward call that induces the previous calculation:
File "example.py", line 17, in <module>
gy4 = example()
File "example.py", line 11, in example
gy2, = torch.autograd.grad(gy , (y,), create_graph=True)
File "/home/mfkasim/anaconda2/envs/base3/lib/python3.8/site-packages/torch/autograd/__init__.py", line 190, in grad
return Variable._execution_engine.run_backward(
(Triggered internally at ../torch/csrc/autograd/python_anomaly_mode.cpp:65.)
return Variable._execution_engine.run_backward(
/home/mfkasim/anaconda2/envs/base3/lib/python3.8/site-packages/torch/autograd/__init__.py:190: UserWarning:
Traceback of forward call that induces the previous calculation:
File "example.py", line 17, in <module>
gy4 = example()
File "example.py", line 8, in example
z1 = a / b # can produce nan in n-th backward as long as https://github.com/pytorch/pytorch/issues/43414 is unsolved
(Triggered internally at ../torch/csrc/autograd/python_anomaly_mode.cpp:65.)
return Variable._execution_engine.run_backward(
Traceback (most recent call last):
File "example.py", line 17, in <module>
gy4 = example()
File "example.py", line 13, in example
gy4, = torch.autograd.grad(gy3, (y,), create_graph=True)
File "/home/mfkasim/anaconda2/envs/base3/lib/python3.8/site-packages/torch/autograd/__init__.py", line 190, in grad
return Variable._execution_engine.run_backward(
RuntimeError: Function 'DivBackward0' returned nan values in its 1th output.
```
cc & thanks to albanD
Pull Request resolved: https://github.com/pytorch/pytorch/pull/43626
Reviewed By: malfet
Differential Revision: D23397499
Pulled By: albanD
fbshipit-source-id: aa7435ec2a7f0d23a7a02ab7db751c198faf3b7d
Summary:
Added a new option in AutogradContext to tell autograd to not materialize output grad tensors, that is, don't expand undefined/None tensors into tensors full of zeros before passing them as input to the backward function.
This PR is the second part that closes https://github.com/pytorch/pytorch/issues/41359. The first PR is https://github.com/pytorch/pytorch/pull/41490.
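This PR targets the C++ `AutogradContext`; for illustration, here is the analogous Python-side usage, assuming `ctx.set_materialize_grads(False)` as the opt-out:
```python
import torch

class TwoOut(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.set_materialize_grads(False)  # don't expand undefined grads into zeros
        return x.clone(), x.clone()

    @staticmethod
    def backward(ctx, g0, g1):
        # With materialization off, the grad of an unused output arrives as
        # None instead of a zero-filled tensor.
        grad = None
        if g0 is not None:
            grad = g0
        if g1 is not None:
            grad = g1 if grad is None else grad + g1
        return grad

x = torch.randn(3, requires_grad=True)
a, b = TwoOut.apply(x)
a.sum().backward()  # only `a` is used, so backward sees g1 as None
```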
Pull Request resolved: https://github.com/pytorch/pytorch/pull/41821
Reviewed By: albanD
Differential Revision: D22693163
Pulled By: heitorschueroff
fbshipit-source-id: a8d060405a17ab1280a8506a06a2bbd85cb86461
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/37587
Lifting RecordFunction up into the dispatcher code
Test Plan: Imported from OSS
Differential Revision: D21374246
fbshipit-source-id: 19f9c1719e6fd3990e451c5bbd771121e91128f7
Summary:
Leave undefined tensors / None returned from custom backward functions as undefined/None instead of creating a tensor full of zeros. This change improves performance in some cases.
**This is BC-Breaking:** Custom backward functions that return None will now see it potentially being propagated all the way up to AccumulateGrad nodes. Potential impact is that .grad field of leaf tensors as well as the result of autograd.grad may be undefined/None where it used to be a tensor full of zeros. Also, autograd.grad may raise an error, if so, consider using allow_unused=True ([see doc](https://pytorch.org/docs/stable/autograd.html?highlight=autograd%20grad#torch.autograd.grad)) if it applies to your case.
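To illustrate the `allow_unused=True` mitigation mentioned above:
```python
import torch

x = torch.randn(3, requires_grad=True)
y = torch.randn(3, requires_grad=True)
out = (x * 2).sum()  # y does not contribute to out

# Without allow_unused=True this call raises an error; with it, the grad
# returned for y is simply None.
gx, gy = torch.autograd.grad(out, (x, y), allow_unused=True)
assert gy is None
```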
Pull Request resolved: https://github.com/pytorch/pytorch/pull/41490
Reviewed By: albanD
Differential Revision: D22578241
Pulled By: heitorschueroff
fbshipit-source-id: f4966f4cb520069294f8c5c1691eeea799cc0abe
Summary:
# Goals
Do the following things during a distributed backward pass.
1. Accumulate the gradient of a variable to RPC context once the gradient is ready instead of at the very end of the backward pass.
2. Run post/pre hooks installed in `AccumulateGrad` nodes once the gradient is ready for the variable. Currently, the hooks in `AccumulateGrad` are not executed because the function `AccumulateGrad` itself is not even evaluated by the local engine.
3. Make it extensible to support post hooks installed by DDP's reducer.
# Introduce GradCapturePreHook
## Why do we need this?
### Root issue:
* dist engine uses the autograd.grad-like API on the vanilla engine and then in the Future callback populates the context with the gradients. This is a bad emulation of the .backward() call on the vanilla engine.
### Practical issue:
* The leaf’s hooks are not called (because they are associated with the AccumulateGrad nodes, which are not called in the autograd.grad-like API). Modules like DDP rely on these hooks.
* The Future is marked as completed before the context is actually populated with the grads, leading to unexpected behavior on the user side.
* The Future callback is only called at the very end of the backward pass, which is too late for DDP if it wants to overlap compute/transfer.
### Proposed solution:
* Provide hooks in the autograd.grad-like API that will allow the distributed engine to populate the context and call the hooks to better emulate the .backward call.
## Who can install a grad capture pre-hook?
This will be an internal hook at the C++ level and it won’t be exposed to Python code. Only call-sites directly interacting with the local engine can install such hooks.
## Signature
The returned `grad` will be captured.
```
virtual const torch::Tensor& operator()(const torch::Tensor& grads) = 0;
```
## Where are hooks installed?
Grad capture pre-hooks are installed in GraphTask::ExecInfo::Capture. ExecInfo is per node. Every backward run will have its own GraphTask instance.
## When/How will hooks be called?
When the local engine captures the grads for a node, all grad capture pre hooks are called one by one in the order they are added. The output grads of the hooks will replace the original grads.
The output of the last hook will be used for grad capturing.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/34501
Test Plan:
All existing tests should pass.
```
python setup.py develop
python test/distributed/rpc/test_dist_autograd_spawn.py DistAutogradTestWithSpawn.test_post_hooks
```
Differential Revision: D20953673
Pulled By: hczhu
fbshipit-source-id: 543b3844823330ea9f9856bab7c5cb2679290a53