we use dummy tensors in our initial trace, so we should never inline. the subclass dispatch might not support the dummy tensor, e.g. DTensor accumulate grad will check that both param and grad are DTensors
Pull Request resolved: https://github.com/pytorch/pytorch/pull/149014
Approved by: https://github.com/jansel
ghstack dependencies: #149064
i'm changing CA initial trace to always trace as dynamic, fixes these errors:
```python
This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0
FAILED [0.2139s] test/inductor/test_compiled_autograd.py::TestAutogradWithCompiledAutograd::test_autograd_python_custom_function_inplace - RuntimeError: !has_symbolic_sizes_strides_ INTERNAL ASSERT FAILED at "/home/xmfan/core/a/pytorch/aten/src/ATen/TensorGeometry.h":63, please report a bug to PyTorch.
To execute this test, run the following from the base repo dir:
python test/test_autograd.py TestAutogradWithCompiledAutograd.test_autograd_python_custom_function_inplace
This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0
FAILED [0.0057s] test/inductor/test_compiled_autograd.py::TestAutogradWithCompiledAutograd::test_copy_slices_graph_task_updates - RuntimeError: !has_symbolic_sizes_strides_ INTERNAL ASSERT FAILED at "/home/xmfan/core/a/pytorch/aten/src/ATen/TensorGeometry.h":63, please report a bug to PyTorch.
To execute this test, run the following from the base repo dir:
python test/test_autograd.py TestAutogradWithCompiledAutograd.test_copy_slices_graph_task_updates
This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0
FAILED [0.9662s] test/inductor/test_compiled_autograd.py::TestAutogradWithCompiledAutograd::test_inplace_on_view_weak_grad_fn - RuntimeError: !has_symbolic_sizes_strides_ INTERNAL ASSERT FAILED at "/home/xmfan/core/a/pytorch/aten/src/ATen/TensorGeometry.h":63, please report a bug to PyTorch.
To execute this test, run the following from the base repo dir:
python test/test_autograd.py TestAutogradWithCompiledAutograd.test_inplace_on_view_weak_grad_fn
This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0
FAILED [0.0077s] test/inductor/test_compiled_autograd.py::TestAutogradWithCompiledAutograd::test_leaf_assignment - RuntimeError: !has_symbolic_sizes_strides_ INTERNAL ASSERT FAILED at "/home/xmfan/core/a/pytorch/aten/src/ATen/TensorGeometry.h":63, please report a bug to PyTorch.
To execute this test, run the following from the base repo dir:
python test/test_autograd.py TestAutogradWithCompiledAutograd.test_leaf_assignment
This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0
FAILED [5.0485s] test/inductor/test_compiled_autograd.py::TestAutogradWithCompiledAutograd::test_setitem_mask - RuntimeError: !has_symbolic_sizes_strides_ INTERNAL ASSERT FAILED at "/home/xmfan/core/a/pytorch/aten/src/ATen/TensorGeometry.h":63, please report a bug to PyTorch.
To execute this test, run the following from the base repo dir:
python test/test_autograd.py TestAutogradWithCompiledAutograd.test_setitem_mask
This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0
FAILED [0.0102s] test/inductor/test_compiled_autograd.py::TestAutogradWithCompiledAutograd::test_tensor_hooks_inplace_over_view - RuntimeError: !has_symbolic_sizes_strides_ INTERNAL ASSERT FAILED at "/home/xmfan/core/a/pytorch/aten/src/ATen/TensorGeometry.h":63, please report a bug to PyTorch.
To execute this test, run the following from the base repo dir:
python test/test_autograd.py TestAutogradWithCompiledAutograd.test_tensor_hooks_inplace_over_view
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/148799
Approved by: https://github.com/jansel, https://github.com/zou3519
## Before
Previously, CA will always unpack all saved variables stored in the autograd graph before executing it. This meant that we can't capture unpack hooks as part of the CA graph, and they would fire out of order wrt to other backward hooks. For memory saving APIs built on top of saved tensor hooks like non-reentrant checkpointing and offloading, we couldn't achieve any savings because all activations would be recomputed/loaded and active at the same time, resulting in no-op.
## After
We add unpack hooks into the CA graph so that they can be executed progressively. The python hook and hook input themselves are wrapped by non-traceable code, so CA polyfills the wrapping as:
```python
# pseudocode
class SavedVariable:
def unpack(self):
if self.hook:
return self.hook(self.packed_data)
else:
return self.packed_data
# This approach won't directly work when we add support for Forward AD or double-backward.
```
Directly executing the CA graph (without torch.compiling it) under checkpointing/offloading, memory profile is expected to stay the same as when using the eager autograd engine. If AOT backward is in the autograd graph, memory profile is expected to be better than the eager autograd engine, since we can now delay saved activations unpacking into the AOT backward's execution.
All tests pass when running the CA graph directly, the remaining issues are in Dynamo.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/147242
Approved by: https://github.com/jansel
## Before
Previously, CA will always unpack all saved variables stored in the autograd graph before executing it. This meant that we can't capture unpack hooks as part of the CA graph, and they would fire out of order wrt to other backward hooks. For memory saving APIs built on top of saved tensor hooks like non-reentrant checkpointing and offloading, we couldn't achieve any savings because all activations would be recomputed/loaded and active at the same time, resulting in no-op.
## After
We add unpack hooks into the CA graph so that they can be executed progressively. The python hook and hook input themselves are wrapped by non-traceable code, so CA polyfills the wrapping as:
```python
# pseudocode
class SavedVariable:
def unpack(self):
if self.hook:
return self.hook(self.packed_data)
else:
return self.packed_data
# This approach won't directly work when we add support for Forward AD or double-backward.
```
Directly executing the CA graph (without torch.compiling it) under checkpointing/offloading, memory profile is expected to stay the same as when using the eager autograd engine. If AOT backward is in the autograd graph, memory profile is expected to be better than the eager autograd engine, since we can now delay saved activations unpacking into the AOT backward's execution.
All tests pass when running the CA graph directly, the remaining issues are in Dynamo.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/147242
Approved by: https://github.com/jansel
PR#141153 exposed the option to collect sizes as dynamic. After this
change, the function set_autograd_compiler returns PyTuple object which
is populated using PyTuple_SET_ITEM function. Yet, that function steals
reference to the object and doesn't INCREF it. So currently we are
missing INCREF on prior_compiler when it is Py_None and INCREF on
prior_dynamic which is either Py_False or Py_True. This bug may lead to
the possible memory corruption.
@xmfan @jansel @albanD
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145482
Approved by: https://github.com/albanD, https://github.com/jansel
The previous PRs built up to this. We change compiled autograd's initial
trace to stop baking in metadata.
While tracing, we allocate some weirdly shaped tensors that we can put
proxies on. The initial trace should not be accessing any metadata of
these tensors (it will likely error out if it does because of how weird
the shapes are).
This involved fixing some various sites where we do specialize on the
metadata, like:
- we change CopySlices's apply_with_saved to proxy some calls
into the graph (this change is fairly hard to split out by itself).
- we stop calling InputBuffer::add
- we delete the weird metadata from the graph so that no graph passes
can make use of it.
Test Plan:
- tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143417
Approved by: https://github.com/jansel, https://github.com/xmfan
ghstack dependencies: #143296, #143304, #143387, #143405
We define a functional version of a C++ torch::autograd::Function. The
functional version reconstructs the ctx object and then calls
backward with it.
Some more details:
- we define how to pack/unpack ctx.saved_data into an IValue. It's a
Dict[str, IValue], so it wasn't difficult.
- every call to CppNode::apply_with_saved binds a new function to
Python. This is because we're unable to reuse the a previously bound
function for reasons (the schema may change depending on what the user
actually puts into their Dict[str, IValue]).
Test Plan:
- existing tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143387
Approved by: https://github.com/jansel, https://github.com/xmfan
ghstack dependencies: #143296, #143304
This PR is on the way to getting compiled autograd's initial capture to
stop specializing on Tensor metadata.
This PR changes compiled autograd's initial capture to proxy an opaque
(w.r.t. Dynamo) function into the graph for all built-in codegen'ed
autograd nodes and validate_outputs.
We changed each codegen'ed apply_with_saved (e.g.
MulBackward0::apply_with_saved) to call into Python to proxy a function
(compiled_autograd.ops.MulBackward0) into the graph. Then, we use the
node's InputMetadata to "guess" at the properties of the output Tensors
to create some new FakeTensors.
Some details:
- MulBackward0::apply_with_saved lives in libtorch_cpu, but needs to be
call to Python via libtorch_python. There is an indirection
(PyCompilerInterface) to do this.
- MulBackward0::apply_with_saved passes a C++ function to Python. To make
our lives easier, every codegen'ed apply_with_saved passes a C++
function with the same signature
`(variable_list, ivalue_list) -> variable_list`.
- We define how to pack arbitrary C++ types into IValue via a helper
IValuePacker struct and codegen functional variants of each builtin
C++ autograd node (e.g. MulBackward0_apply_functional_ivalue).
MulBackward0 before this PR:
https://gist.github.com/zou3519/a80381d5fa38e970e413fcd91b0530de
MulBackward0 after this PR:
https://gist.github.com/zou3519/0c2eee8b3d8d96232b51ef430b53c5b0
Test Plan:
- existing tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143296
Approved by: https://github.com/jansel
Partially addresses https://github.com/pytorch/pytorch/issues/130170 for float scalars saved from forward pass of a custom c++ autograd function. With this PR, compiled autograd no longer recaptures when the float value changes, but downstream support isn't there yet: 4bdb4bbd86/torch/_dynamo/config.py (L58-L61)
Currently, any non-tensors passed in ctx->saved_data is specialized on by compiled autograd. To stop specializing on float values, we lift the float. We also require user code to use IValue::toSymFloat instead of IValue::toDouble in order to swap the SymFloat to proxy during compiled autograd tracing
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133048
Approved by: https://github.com/jansel
ghstack dependencies: #132771
Addresses https://github.com/pytorch/pytorch/issues/130170 for int scalars saved from forward pass of a custom c++ autograd function
Currently, any non-tensors passed in ctx->saved_data is specialized on by compiled autograd. To stop specializing on int values, we lift the ints. We also require user code to use IValue::toSymInt instead of IValue::toInt in order to swap the SymInt to proxy during compiled autograd tracing
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132771
Approved by: https://github.com/jansel