How the old retains_grad hooks was implemented:
- retains_grad hooks are stored on the autograd_meta, as entries in a vector
- upon registration, a wrapper hook CppFunctionTensorPreHook is created to wrap that vector, and then that wrapper hook is registered to the grad_fn, i.e., by appending it to a vector of retains_grad hooks on the grad_fn
- upon in-place, for the old grad_fn we set the retains_grad hook to nullptr, so that even though the old grad_fn still references the vector, the vector contains a single nullptr. For the new grad_fn, we create a new wrapper hook around the vector (storing the single retains_grad hook) on autograd_meta.
The new retains_grad hook implementation:
- we store std::function by value, and we store it on the grad_fn rather than the autograd_meta
- a single grad_fn can have multiple outputs, so it can potentially hold multiple retains_grad hooks. We use an unordered_map (previously a vector).
- on in-place we remove the hook from the old grad_fn and put it in the new grad_fn (small implication of this change is that we we now need to have access to both the old grad_fn and new grad_fn, this isn't a problem)
Other details:
- CppFunctionTensorPreHook took a shared_ptr to vector of std::function. In our new implementation, we add a new wrapper hook CppFunctionSingleTensorPreHook, which takes a single std::function.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/92604
Approved by: https://github.com/albanD
This reverts commit e525f433e1.
Original PR: #85849
Fixes #ISSUE_NUMBER
In addition to reverting the revert, this PR:
- defines the virtual destructor of FunctionPreHook in the header. Why? Presumably the internal build imports the header from somewhere, but does not have function_hooks.cpp (where the virtual destructor was previously defined) in the same compilation unit.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/92559
Approved by: https://github.com/albanD
Addresses: https://github.com/pytorch/pytorch/issues/35802
Design doc: https://docs.google.com/document/d/19xSib7FFknRQ5f3ptGFUmiOt3BrgXSUlTQH2xMcZJYg/edit#
### Changes in this PR
#### Implementation
- We have now have 3 fields: pre_hooks, retains_grad_hooks, and tensor_pre_hooks so that we can more precisely define their ordering and when they are executed.
- Since retains grad uses an entirely new field, we cannot reuse the old retains grad, logic. We refactor retains grad to call directly into the variable.cpp logic. Other logic in variable.cpp that handle cpp hooks must also be updated.
#### Hooks ordering and execution:
- Defines pre-hooks registered on tensor to run before pre-hooks registered on grad_fn
- Updates pre-hooks registered on tensor to always run, even if they are the inputs= to .grad()
- Post hooks (and pre hooks) can now observe the modifications to gradient by the tensor pre hook
#### Retains grad hooks
- retains grad hooks always execute last, even if there are other tensor pre-hooks registered
#### Unchanged:
- pre_hooks registered to grad_fn aren't expected to execute if they are the inputs= to .grad()
Follow ups:
- simplify retains_grad field to not be a vector, since it always holds a single hook
- potentially merge capture hooks with tensor pre hooks, this would involve some additional refactoring since
- python hooks registered to tensor behavior on in-place is still wrong
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85849
Approved by: https://github.com/albanD
CUDA 12 introduces behavioral changes in `cudaSetDevice`. In the old version it would just set the device to be used for kernel launches and memory allocations without creating a CUDA context. Now, in CUDA 12, every time `cudaSetDevice` is called for the first time it creates a CUDA context. See issue #91122.
The autograd engine iterates over all devices and sets them:
f8b348c1fc/torch/csrc/autograd/engine.cpp (L1399-L1402)f8b348c1fc/torch/csrc/autograd/engine.cpp (L349)
Which causes pollution of CUDA contexts on sibling devices.
This PR introduces a workaround this issue by conditionally setting the device.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91191
Approved by: https://github.com/ngimel
Not sure, what I was thinking when writing something like:
```
auto foo = std::getenv("BAR");
if (!foo) {
foo = "baz";
}
```
as `std::getenv` return `char *` (i.e. mutable string), but string literals are immutable. (i.e. `const char *`)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87949
Approved by: https://github.com/kit1980
In this PR:
- graph_task stores graph roots on construction so that we can later traverse through the graph
- before the nodes are returned, they needed to be converted from raw_ptr to shared_ptr, and this should be OK because the graph is guaranteed to be alive
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87507
Approved by: https://github.com/albanD
Addresses: https://github.com/pytorch/pytorch/issues/83617
This PR a way to query the TLS graph task's exec_info which is a map mapping the Node to a bool indicating whether it will be executed in the current backward pass (as determined by the inputs= argument for .grad of .backward).
- this works with both custom Function nodes and normal codegened nodes
- to be able to verify whether the pyobject passed is an actual node, we now store pointers to PyTypeObjects into a set on registration.
- error out when .backward without inputs= to avoid silently returning True
Alternatives:
- not sure if it is possible to bind to Python from a raw pointer to Node. At least we wouldn't be able to use existing logic, and the Python object should only hold a weak reference to the Node.
- other solutions to the motivating issue seem to require more extensive modification to the engine
See the issue linked for an example of usage
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84773
Approved by: https://github.com/albanD
### Introduction
<!-- What did you change and why was it needed? -->
Removing unnecessary weight gradient calculation is very important for applications that need high-order derivatives during training. However, this is not supported by the current Autograd engine.
For more detail: The backward function of a `matmul` operator (e.g., `linear` `addmm` `mm`), has two matmuls, one for `input gradient` and another for `weight gradient`. For a typical neural network (nn) with a few linear layers and activation functions, if the user calls `torch.autograd.grad()` to calculate the derivative of the nn output `y` w.r.t the nn input `x`, only the `input gradient` of the `matmul` operator is needed, and the `weight gradient` is discarded. However, the current PyTorch autograd engine will always calculate the `weight gradient` if `weight` requires gradient (the calculation of the high-order derivative is performed during training).
The figure attached shows the autograd graph of the following code snippet:
```py
y = torch.nn.functional.linear(x, weight, bias)
y = y.pow(2)
# first order derivative
y__x, = torch.autograd.grad(y, x, grad_outputs=grad_outputs, create_graph=True)
# first order derivative
y__x__x, = torch.autograd.grad(y__x, x, grad_outputs=grad_outputs, create_graph=True)
```
The path with ❌ is not needed when calculating derivatives.
<img width="50%" alt="image" src="https://user-images.githubusercontent.com/9999318/182018117-719c5a23-bcc6-4a63-8e8d-1bca3ebda2e3.png">
### Issue
<!-- Link to Issue ticket or RFP -->
Related issue: https://github.com/pytorch/pytorch/issues/56500
### Method
When calling `torch.autograd.grad`, `exec_info_` is created for each GraphTask, which allows filtering paths on the graph that are not needed. However, when the GraphTask calls into the node, the node still does not know whether the edges are needed or not. In the case of matmul, `weight.requires_grad is True` so the weight gradient is always calculated.
Following https://github.com/pytorch/pytorch/issues/56500#issuecomment-825694656, this PR passes the graph task's thread_local `exec_info_` into the node, so it could trim unnecessary edges during `torch.autograd.grad` calls.
### Benchmark
Benchmark script: https://gist.github.com/yueyericardo/24158433a2021c51eeef9c3e2722df99
Benchmark result:
6 hidden layers, batch size 10000, on A100
FP32 result
| hessian benchmark | FP32 (before) | FP32 (After) | FP32 (Functorch v0.1.1) |
| ----------------------------- | ------------- | ----------------- | ----------------------- |
| Linear + ReLU (no backward) | 55.658 ms | 29.392 ms (1.90X) | 29.547 ms (1.90X) |
| Linear + ReLU (with backward) | 81.173 ms | 54.917 ms (1.47X) | 68.988 ms (1.18X) |
TF32 result
| hessian benchmark | TF32 (before) | TF32 (after) | TF32 (Functorch v0.1.1) |
| ----------------------------- | ------------- | ----------------- | ----------------------- |
| Linear + ReLU (no backward) | 19.801 ms | 11.259 ms (1.76X) | 10.754 ms (1.84X) |
| Linear + ReLU (with backward) | 29.167 ms | 20.466 ms (1.42X) | 22.784 ms (1.28X) |
For FP32 result, we could get 1.9X speed up for hessian calculation, and 1.47X speed up during training, which is even faster than functorch `vmap(jacfwd(jacrev` implementation. (functorch has performance regression on v0.2.0, https://github.com/pytorch/functorch/issues/989, so we are using v0.1.1 for benchmark)
@zou3519 does functorch also includes similar optimizations during hessian calculation? If not, what do we need to do so the functorch could also benefit from this PR?
### Testing
<!-- How did you test your change? -->
- [x] we need to figure out a way for unittest
### Thanks
Thanks for the great blog: [How Computational Graphs are Executed in PyTorch | PyTorch](https://pytorch.org/blog/how-computational-graphs-are-executed-in-pytorch/)
cc @zasdfgbnm @albanD
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82544
Approved by: https://github.com/soulitzer
In preparation of adopting future rocblas library options, it is necessary to track when the backward pass of training is executing. The scope-based helper class `BackwardPassGuard` is provided to toggle state.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/71881
Approved by: https://github.com/albanD
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76366
caffe2 is not currently being built for XROS.
Test Plan: CI
Reviewed By: kimishpatel
Differential Revision: D35923922
fbshipit-source-id: 260dacadf0bd5b6bab7833a4ce81e896d280b053
(cherry picked from commit 8370b8dd2519d55a79fa8d45e7951ca8dc0b21a8)
This pull request enables accumulating gradients for the CSR tensor.
Functions that work and are tested:
- tensor.abs()
- tensor.neg()
- tensor.conj_physical()
- torch.addmm
`torch.mm` also works, but tests will be added later.
In addition, this PR adds throwing an error when trying to access strides, storage, and contiguity info on a CSR tensor.
`tensor.to_sparse_csr().to_sparse_csr()` was failing and now fixed.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/75435
Approved by: https://github.com/cpuhrsch
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72689
Fix https://github.com/pytorch/pytorch/issues/69839
Should we add a private python binding to check if the bad fork guard has been set and add test in CI to make sure that it is never set on our CPU-only CI build? Not sure how flaky that will be out of CI for people that run CPU build on a machine that cuda installed...
EDIT: turns out, we already had such tests in test_multiprocessing. So should be tested and enforced now!
Test Plan: Imported from OSS
Reviewed By: soulitzer
Differential Revision: D34180243
Pulled By: albanD
fbshipit-source-id: 3284db52dcf4568362244b60e3c5657153e64fa4
(cherry picked from commit 6e23f7a33a)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72688
Refactor how we know what to run on the cpu queue.
The Lazy Tensor moved there as it is always present as a device guard and would make the number of devices 1 all the time (forcing the creation of a thread).
FYI wconstab you most likely don't care about this unless you ever use multiple Lazy device?
This should slightly improve the perf if you run backward with Lazy Tensors as the work will be done in the main thread and not a worker thread.
Test Plan: Imported from OSS
Reviewed By: soulitzer
Differential Revision: D34180245
Pulled By: albanD
fbshipit-source-id: 88c5d5bdd631ad01bf271d720d1eab69aba84fc0
(cherry picked from commit da7e9b902f)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72687
Not sure who would use that. It is not used in the code base as far as I can see. And I don't know of anyone working with the engine directly out of tree. So tentatively removing it.
Test Plan: Imported from OSS
Reviewed By: soulitzer
Differential Revision: D34180244
Pulled By: albanD
fbshipit-source-id: 678ba1c4a1cbd9a0458d33be97664d1e3d1bd86b
(cherry picked from commit 3968ca3a38)
Summary:
As issue https://github.com/pytorch/pytorch/issues/59750 is fixed, this PR is to remove the workaround implemented for it on ROCm.
Enabled hasPrimaryContext() related PyTorch unit tests on ROCm.
cc: amathews-amd, jithunnair-amd
cc jeffdaily sunway513 jithunnair-amd ROCmSupport KyleCZH
Pull Request resolved: https://github.com/pytorch/pytorch/pull/71146
Reviewed By: anjali411
Differential Revision: D33754615
Pulled By: albanD
fbshipit-source-id: b3a5c65a20c6d52d5f2ffc9e6f9628c819329b5d
(cherry picked from commit cfdd12166c)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/69041
`TH_CONCAT_{N}` is still being used by THP so I've moved that into
it's own header but all the compiled code is gone.
Test Plan: Imported from OSS
Reviewed By: anjali411
Differential Revision: D32872477
Pulled By: ngimel
fbshipit-source-id: 06c82d8f96dbcee0715be407c61dfc7d7e8be47a
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/68266
* Use `if...endif` to adjust pyTorch internals towards XROS
Test Plan: CI
Reviewed By: kkosik20
Differential Revision: D32190771
fbshipit-source-id: cce073dea53c2b5681d913321101cd83c6472019
Summary:
Fixes https://github.com/pytorch/pytorch/issues/50209
This adds a new warning handler that stores all warnings in a shared
queue, which can be "replayed" at a later time and, crucially, on
another thread. Then, I use this inside the autograd engine to ensure
that warnings are processed by the handler registered on the main
thread.
For testing, I also add an operator that always warns in the backward
pass and test that the warning is a normal Python warning.
cc ezyang albanD zou3519 gqchen pearu nikitaved soulitzer Lezcano Varal7
Pull Request resolved: https://github.com/pytorch/pytorch/pull/66235
Reviewed By: ejguan
Differential Revision: D31505413
Pulled By: albanD
fbshipit-source-id: 1a7f60b038f55c20591c0748b9e86735b3fec2f9
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/65610
- Replace HIP_PLATFORM_HCC with USE_ROCM
- Dont rely on CUDA_VERSION or HIP_VERSION and use USE_ROCM and ROCM_VERSION.
- In the next PR
- Will be removing the mapping from CUDA_VERSION to HIP_VERSION and CUDA to HIP in hipify.
- HIP_PLATFORM_HCC is deprecated, so will add HIP_PLATFORM_AMD to support HIP host code compilation on gcc.
cc jeffdaily sunway513 jithunnair-amd ROCmSupport amathews-amd
Reviewed By: jbschlosser
Differential Revision: D30909053
Pulled By: ezyang
fbshipit-source-id: 224a966ebf1aaec79beccbbd686fdf3d49267e06
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/65235
1. Updated the legacy type checks in `torch/csrc/autograd/engine.cpp` to individually validate the dtype, device, and layout equality for grad and tensor.
2. Removed device field from `InputMetadata` since it's already stored via storing options. Also, added `dtype()` and `layout()` methods to `InputMetadata`. To make this change, some calls had to be updated due to the change in constructor.
3. To fix https://github.com/pytorch/pytorch/issues/65016:
a. Added a `is_tensor_subclass` field in `InputMetadata` to skip device checks for grad and tensor when the tensor has
python key set on it (tensor subclass).
Test Plan: Imported from OSS
Reviewed By: jbschlosser
Differential Revision: D31117318
Pulled By: anjali411
fbshipit-source-id: 825401df98695c48bf9b320be54585f6aff500bd
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/65235
1. Updated the legacy type checks in `torch/csrc/autograd/engine.cpp` to individually validate the dtype, device, and layout equality for grad and tensor.
2. Removed device field from `InputMetadata` since it's already stored via storing options. Also, added `dtype()` and `layout()` methods to `InputMetadata`. To make this change, some calls had to be updated due to the change in constructor.
3. To fix https://github.com/pytorch/pytorch/issues/65016:
a. Added a `is_tensor_subclass` field in `InputMetadata` to skip device checks for grad and tensor when the tensor has
python key set on it (tensor subclass).
Test Plan: Imported from OSS
Reviewed By: pbelevich
Differential Revision: D31082693
Pulled By: anjali411
fbshipit-source-id: cb551cd438c6ca40b0f18a4d0009e0861cf0fd4e
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63619
Adds a RECORD_FUNCTION with the function that is being valuate as part
of backwards execution. This has been useful in picking up some operations
in the backwards pass that otherwise would not show up, for example custom cpp
functions that use custom C++ code.
ghstack-source-id: 137041723
Test Plan:
CI
benchmark:
buck run mode/opt //scripts/rvarm1/ddp:bench
Reviewed By: albanD
Differential Revision: D30439492
fbshipit-source-id: 955917770cdf2a2edb0303223ace710b668ba388
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63116
This PR removes the special flag to disable grad mode tracking on the ThreadLocalState and replaces it with an explicit setter that users can use.
This allows to reduce complexity of ThreadLocalState.
Test Plan: Imported from OSS
Reviewed By: ngimel
Differential Revision: D30388098
Pulled By: albanD
fbshipit-source-id: 85641b3d711179fb78ff6a41ed077548dc821a2f
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63115
This actually changes:
- callbacks now run with proper grad mode even in worker threads
- graphtask's Future callbacks now run with proper TLS when erroring
out from a worker thread
Test Plan: Imported from OSS
Reviewed By: ngimel
Differential Revision: D30388100
Pulled By: albanD
fbshipit-source-id: 7ae9c461c2f0040548dd9e1e314f25e8da0c2e67
Summary:
As GoogleTest `TEST` macro is non-compliant with it as well as `DEFINE_DISPATCH`
All changes but the ones to `.clang-tidy` are generated using following script:
```
for i in `find . -type f -iname "*.c*" -or -iname "*.h"|xargs grep cppcoreguidelines-avoid-non-const-global-variables|cut -f1 -d:|sort|uniq`; do sed -i "/\/\/ NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)/d" $i; done
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/62008
Reviewed By: driazati, r-barnes
Differential Revision: D29838584
Pulled By: malfet
fbshipit-source-id: 1b2f8602c945bd4ce50a9bfdd204755556e31d13
Summary:
The thread local state of backward thread is not aligned to the GraphTask's `thread_local_` when calling the hooks in backward.
This is required for profiling the statistics c10d operation of `DistributedDataParallel` module.
Is there any concern to add the thread local state guard when calling the hooks in backward? ezyang
Pull Request resolved: https://github.com/pytorch/pytorch/pull/60067
Reviewed By: ezyang
Differential Revision: D29654599
Pulled By: albanD
fbshipit-source-id: 656c4f91017184fd40f1a184de24757a13387e37
Summary:
Before https://github.com/pytorch/pytorch/pull/57833, calls to backward() or grad() synced only the calling thread's default stream with autograd leaf streams at the end of backward. This made the following weird pattern safe:
```python
with torch.cuda.stream(s):
# imagine forward used many streams, so backward leaf nodes may run on many streams
loss.backward()
# no sync
use grads
```
but a more benign-looking pattern was unsafe:
```python
with torch.cuda.stream(s):
# imagine forward used a lot of streams, so backward leaf nodes may run on many streams
loss.backward()
# backward() syncs the default stream with all the leaf streams, but does not sync s with anything,
# so counterintuitively (even though we're in the same stream context as backward()!)
# it is NOT SAFE to use grads here, and there's no easy way to make it safe,
# unless you manually sync on all the streams you used in forward,
# or move "use grads" back to default stream outside the context.
use grads
```
mruberry ngimel and I decided backward() should have the [same user-facing stream semantics as any cuda op](https://pytorch.org/docs/master/notes/cuda.html#stream-semantics-of-backward-passes).** In other words, the weird pattern should be unsafe, and the benign-looking pattern should be safe. Implementationwise, this meant backward() should sync its calling thread's current stream, not default stream, with the leaf streams.
After https://github.com/pytorch/pytorch/pull/57833, backward syncs the calling thread's current stream AND default stream with all leaf streams at the end of backward. The default stream syncs were retained for temporary backward compatibility.
This PR finishes https://github.com/pytorch/pytorch/pull/57833's work by deleting syncs on the default stream.
With this PR, graph-capturing an entire backward() call should be possible (see the [test_graph_grad_scaling diffs](https://github.com/pytorch/pytorch/compare/master...mcarilli:streaming_backwards_remove_default_syncs?expand=1#diff-893b1eea27352f336f4cd832919e48d721e4e90186e63400b8596db6b82e7450R3641-R3642)).
** first paragraph has a formatting error which this PR should also fix.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/60421
Reviewed By: albanD
Differential Revision: D29370344
Pulled By: ngimel
fbshipit-source-id: 3248bc5fb92fc517db0c15c897e5d7250f67d7fe
Summary:
Before https://github.com/pytorch/pytorch/pull/57833, calls to backward() or grad() synced only the calling thread's default stream with autograd leaf streams at the end of backward. This made the following weird pattern safe:
```python
with torch.cuda.stream(s):
# imagine forward used many streams, so backward leaf nodes may run on many streams
loss.backward()
# no sync
use grads
```
but a more benign-looking pattern was unsafe:
```python
with torch.cuda.stream(s):
# imagine forward used a lot of streams, so backward leaf nodes may run on many streams
loss.backward()
# backward() syncs the default stream with all the leaf streams, but does not sync s with anything,
# so counterintuitively (even though we're in the same stream context as backward()!)
# it is NOT SAFE to use grads here, and there's no easy way to make it safe,
# unless you manually sync on all the streams you used in forward,
# or move "use grads" back to default stream outside the context.
use grads
```
mruberry ngimel and I decided backward() should have the [same user-facing stream semantics as any cuda op](https://pytorch.org/docs/master/notes/cuda.html#stream-semantics-of-backward-passes).** In other words, the weird pattern should be unsafe, and the benign-looking pattern should be safe. Implementationwise, this meant backward() should sync its calling thread's current stream, not default stream, with the leaf streams.
After https://github.com/pytorch/pytorch/pull/57833, backward syncs the calling thread's current stream AND default stream with all leaf streams at the end of backward. The default stream syncs were retained for temporary backward compatibility.
This PR finishes https://github.com/pytorch/pytorch/pull/57833's work by deleting syncs on the default stream.
With this PR, graph-capturing an entire backward() call should be possible (see the [test_graph_grad_scaling diffs](https://github.com/pytorch/pytorch/compare/master...mcarilli:streaming_backwards_remove_default_syncs?expand=1#diff-893b1eea27352f336f4cd832919e48d721e4e90186e63400b8596db6b82e7450R3641-R3642)).
** first paragraph has a formatting error which this PR should also fix.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/60421
Reviewed By: VitalyFedyunin, albanD
Differential Revision: D29342234
Pulled By: ngimel
fbshipit-source-id: 98e6be7fdd8550872f0a78f9a66cb8dfe75abf63
Summary:
Fixes https://github.com/pytorch/pytorch/issues/59844.
Streaming backwards collects "leaf streams" for AccumulateGrad functions that stash or accumulate .grad attributes for autograd leaf tensors, and syncs those streams with some ambient stream(s) so later ops can safely consume the grads on the ambient stream(s).
But, currently, streaming backwards does not collect leaf streams for grads produced out-of-place (ie, not stashed onto a .grad attribute) by `torch.autograd.grad`, because these out-of-place grads are "captured" and returned before they reach an AccumulateGrad function. Some out-of-place grads might not even have an AccumulateGrad function to go to, because `torch.autograd.grad` can be told to make grads for non-leaf temporaries.[1]
The upshot is, when streaming backwards makes ops that produce out-of-place gradients run on side streams, no ambient stream is told to sync on these side streams, so `torch.autograd.grad` doesn't offer the same post-call safe-use guarantees for grads as the leaf accumulation of `torch.autograd.backward`.
This PR ensures `torch.autograd.grad` gives the same safe-use guarantees as `torch.autograd.backward` by also stashing leaf streams for grads created out-of-place.
I augmented a streaming backwards test to include a torch.autograd.grad attempt. The test fails on current master[2] and passes with the engine.cpp diffs.
I have no idea if this bug or its fix matter to distributed autograd. pritamdamania mrshenli should take a look before it's merged.
[1] example:
```python
leaf = torch.tensor(..., requires_grad=True)
tmp = leaf * 2
loss = tmp.sum()
torch.autograd.grad(loss, inputs=(tmp, leaf))
```
Technically, because `torch.autograd.grad` can be told to produce grads for non-leaf temporaries, these streams might NOT be "leaf streams". Maybe I should rename `leaf_streams`?
[2] the way the test currently fails is fun: it reports
```
AssertionError: False is not true : Tensors failed to compare as equal!With rtol=1.3e-06 and atol=1e-05, found 0 element(s) (out of 25) whose difference(s) exceeded the margin of error (including 0 nan comparisons). The greatest difference was 0.0 (5.0 vs. 5.0), which occurred at index (0, 0).
```
I suspect this [kafka trap](https://en.wiktionary.org/wiki/Kafkatrap) happens because assertEqual does a comparison test on the device, syncs on some bool result, sees failure and prints the tensors post-sync at which point is IS safe to access the values.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/60127
Reviewed By: mrshenli
Differential Revision: D29276581
Pulled By: albanD
fbshipit-source-id: a9f797e2fd76e2f884cce5a32ecf5d9b704c88ee