https://github.com/pytorch/pytorch/issues/148222
Goal:
At the moment autograd saved tensors hooks are run in eager after compiled forward.
They are executed at the same time for all saved tensors.
Hooks can be used to reduce amout of memory used for saved tensors, doing quantization or offloading to cpu.
This is suboptimal for optimization of peak memory.
Better solution will be to put the hooks in the graph, as close as possible to the last usage of the tensor.
To get user specified autograd saved tensors hooks in the graph.
Logic:
UX:
If user specifies with torch.autograd.graph.saved_tensors_hooks(pack_gm, unpack_gm).
Where pack_gm and unpack_gm are torch.fx.GraphModule.
Then AotAutograd will retrace those graph modules, doing decompositions and functionalization in aot_autograd, inlining the result graphs in forward epilogue and backward prologue.
User may want to use control logic in the hooks, for example applying quantization only for specific dtypes and sizes.
This is also possible, user can put it into torch.fx.wrap function and use symbolic trace to make a GraphModule.
In that case AotAutograd cahing will work only in case when user explicitly set to the torch.fx.wrap call_function node "user_cache_hash" metadata.
If this metadata set - then aot_autograd cache can use saved cache artifact.
If metadata is not set - then cache is bypassed.
Dynamo:
Dynamo traces pack and unpack hooks and installs them as subgraph and explicitly adds to the output_graph. (As those subgraphs are not used and will not be copied in the result by default).
The complexity here is that at this moment we do not have example of inputs for the hooks.
We trace pack_hook with some Tensor from the inputs.
The result subgraphs are added to the hashing of AotAutograd Cache.
In AotAutograd we retrace the graph with the true saved tensors coming from partitioner.
Backwards Compatibility:
As current hooks are executed in eager mode and not all of them will be traceable - we only try to put in the graph hooks, explicitly marked by user with annotation (@_inlineable_saved_tensors_hooks).
For other hooks or if compiled autograd is enabled - keep the same logic.
Recompilations:
Hooks are guarded with lambda guard matching function id to cause recompilation if user reruns compiled function.
Aot_autograd:
After partitioner prepared forward and backward module - we trace prepared at Dynamo graphs for pack and unpack hooks and inline them in epilogue of forward and prologue of backward. Forward outputs and backward inputs are changed, transparently for user.
We do not try to put it close the last usage etc., relying on inductor to do this optimization.
```
INFO: TRACED GRAPH
===== Forward graph pre saved_tensors_hooks inlining 3 =====
/data/users/ivankobzarev/a/pytorch/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
def forward(self, primals_1: "Sym(s0)", primals_2: "Sym(s1)", primals_3: "f32[s0, s1][s1, 1]cuda:0"):
# File: /data/users/ivankobzarev/a/pytorch/test/functorch/test_aotdispatch.py:6660 in simple_fn, code: x = x + 1
add: "f32[s0, s1][s1, 1]cuda:0" = torch.ops.aten.add.Tensor(primals_3, 1); primals_3 = None
# File: /data/users/ivankobzarev/a/pytorch/test/functorch/test_aotdispatch.py:6661 in simple_fn, code: x = SAF.apply(x)
view: "f32[s0, s1][s1, 1]cuda:0" = torch.ops.aten.view.default(add, [primals_1, primals_2])
return (view, add, primals_1, primals_2)
INFO: TRACED GRAPH
===== Backward graph pre saved_tensors_hooks inlining 3 =====
/data/users/ivankobzarev/a/pytorch/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
def forward(self, primals_1: "Sym(s0)", primals_2: "Sym(s1)", primals_3: "f32[s0, s1][s1, 1]cuda:0"):
# File: /data/users/ivankobzarev/a/pytorch/test/functorch/test_aotdispatch.py:6660 in simple_fn, code: x = x + 1
add: "f32[s0, s1][s1, 1]cuda:0" = torch.ops.aten.add.Tensor(primals_3, 1); primals_3 = None
# File: /data/users/ivankobzarev/a/pytorch/test/functorch/test_aotdispatch.py:6661 in simple_fn, code: x = SAF.apply(x)
view: "f32[s0, s1][s1, 1]cuda:0" = torch.ops.aten.view.default(add, [primals_1, primals_2])
return (view, add, primals_1, primals_2)
INFO: TRACED GRAPH
===== saved_tensors_pack_hook add 3 =====
/data/users/ivankobzarev/a/pytorch/torch/fx/_lazy_graph_module.py class pack_float8(torch.nn.Module):
def forward(self, x_1: "f32[s0, s1][s1, 1]cuda:0"):
# No stacktrace found for following nodes
_to_copy: "f8e4m3fn[s0, s1][s1, 1]cuda:0" = torch.ops.aten._to_copy.default(x_1, dtype = torch.float8_e4m3fn); x_1 = None
return (torch.float32, _to_copy)
INFO: TRACED GRAPH
===== saved_tensors_unpack_hook add 3 =====
<eval_with_key>.22 from /data/users/ivankobzarev/a/pytorch/torch/fx/experimental/proxy_tensor.py:1225 in wrapped class pack_float8(torch.nn.Module):
def forward(self, x_1: "f32[s0, s1][s1, 1]cuda:0"):
# No stacktrace found for following nodes
_to_copy: "f8e4m3fn[s0, s1][s1, 1]cuda:0" = torch.ops.aten._to_copy.default(x_1, dtype = torch.float8_e4m3fn); x_1 = None
return (torch.float32, _to_copy)
INFO: TRACED GRAPH
===== Forward graph 3 =====
/data/users/ivankobzarev/a/pytorch/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
def forward(self, primals_1: "Sym(s0)", primals_2: "Sym(s1)", primals_3: "f32[s0, s1][s1, 1]cuda:0"):
# File: /data/users/ivankobzarev/a/pytorch/test/functorch/test_aotdispatch.py:6660 in simple_fn, code: x = x + 1
add: "f32[s0, s1][s1, 1]cuda:0" = torch.ops.aten.add.Tensor(primals_3, 1); primals_3 = None
# No stacktrace found for following nodes
_to_copy: "f8e4m3fn[s0, s1][s1, 1]cuda:0" = torch.ops.aten._to_copy.default(add, dtype = torch.float8_e4m3fn)
# File: /data/users/ivankobzarev/a/pytorch/test/functorch/test_aotdispatch.py:6661 in simple_fn, code: x = SAF.apply(x)
view: "f32[s0, s1][s1, 1]cuda:0" = torch.ops.aten.view.default(add, [primals_1, primals_2]); add = None
return (view, _to_copy, primals_1, primals_2)
INFO: TRACED GRAPH
===== Backward graph 3 =====
<eval_with_key>.21 class GraphModule(torch.nn.Module):
def forward(self, primals_1: "Sym(s0)", primals_2: "Sym(s1)", add_packed_2: "f8e4m3fn[s0, s1][s1, 1]cuda:0", tangents_1: "f32[s0, s1][s1, 1]cuda:0"):
# No stacktrace found for following nodes
_to_copy: "f32[s0, s1][s1, 1]cuda:0" = torch.ops.aten._to_copy.default(add_packed_2, dtype = torch.float32); add_packed_2 = None
# File: /data/users/ivankobzarev/a/pytorch/test/functorch/test_aotdispatch.py:6661 in simple_fn, code: x = SAF.apply(x)
add_7: "f32[s0, s1][s1, 1]cuda:0" = torch.ops.aten.add.Tensor(tangents_1, _to_copy); tangents_1 = _to_copy = None
return (None, None, add_7)
```
Differential Revision: [D72187044](https://our.internmc.facebook.com/intern/diff/D72187044)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150032
Approved by: https://github.com/bdhirsh
### Background
This PR adds the functionality to xfail / skip on a per-`SampleInput` basis for `OpInfo` tests. See #89354 and #82669 for some requests asking for this type of functionality.
This was originally landed for NJT in #138370 and is generalized and slightly tweaked here.
### Design
#### Principles
* Clean separation among `SampleInput` generation logic, test logic that uses the `SampleInput`s, and xfail / skip logic (which will change as bugs are addressed).
* Flexibility in xfail / skip predicate specification - ideally each bug can be handled by a single skip / xfail, even if it surfaces across a specific class of ops.
* This is important in practice for NJT, where it's common to have a bug that affects all binary ops, for example.
* Opt-in with minimal test logic changes + no substantial impact on other tests.
#### Details
The core new concept is a `SampleRule`, which can be either an `XFailRule` or `SkipRule`.
```python
@dataclass
class SampleRule(ABC):
# function to indicate whether the rule applies to this op; return True if so
# NB: str arg of callable is device_type
op_match_fn: Callable[[str, OpInfo], bool] = None
# function to indicate whether the rule applies to this sample; return True if so
sample_match_fn: Callable[[torch.device, SampleInput], bool] = None
# optional name for identifying the rule
name: str = ""
@dataclass
class XFailRule(SampleRule):
# expected error type
error_type: TypeVar = Exception
# expected error message
error_msg: str = ".*"
@dataclass
class SkipRule(SampleRule):
...
```
* See below for example usage details, but at a high level: each test should have a corresponding list of `sample_skips_and_xfails`.
* The list of `sample_skips_and_xfails` is traversed in order, and the first rule that matches (if any) is applied, so order can matter.
* The PR includes a logging mechanism for matched rules accessible by setting the loglevel to `DEBUG`.
* The split between `op_match_fn` and `sample_match_fn` is made to allow pre-filtering of the list of rules to get only those that apply to the op under test.
* Each `SampleInput` is run within a subtest context so they can be individually skipped / xfailed as needed. This also means that a test will no longer stop after the first erroring `SampleInput`; all samples will be run through test logic.
### Example Usage
Consider the following OpInfo test:
```python
class MyTestCase(TestCase):
@ops(op_db)
def test_foo(self, device, dtype, op):
for sample in op.sample_inputs(device, dtype, requires_grad=False):
# do some SampleInput-based test logic
output = op.op(sample.input, *sample.args, **sample.kwargs)
...
```
This is a common pattern for such tests; simply generate a list of `SampleInputs` and run them through the op. Now say you want to xfail one of these `SampleInput`s for a given op. Today, you have to xfail the entire test or hack around this in the test logic.
This PR lets you do this to get very flexible xfail / skips based on op / sample input properties:
```python
# NB: Define rules for per-SampleInput xfails / skips. These can also be defined in-line in the @ops decorator, but
# it can be more readable to maintain these somewhere else. These are attempted to be matched in order and
# the first one that matches applies, so order can matter.
FOO_SKIPS_AND_XFAILS = [
XFailRule(
error_type=ValueError,
error_mg="2D inputs not supported",
op_match_fn=lambda device, op: (
# NB: logic for which ops this rule applies to goes here
op.full_name == "add"
),
sample_match_fn=lambda device, sample: (
# NB: logic which samples this rule applies to goes here
sample.input.dim() == 2
),
# NB: optional rule identifier can help with debugging matched rules
name="add_with_2D_inputs_not_supported",
),
# NB: This follows a similar structure as XFailRule but without error_type / error_msg. Obviously
# this skips a particular SampleInput instead of xfailing :)
SkipRule(...),
...
]
class MyTestCase(TestCase):
@ops(op_db)
@sample_skips_and_xfails(FOO_SKIPS_AND_XFAILS)
# NB: the @ops decorator automatically filters out any rules that don't apply to this op
def test_foo(self, device, dtype, op):
for sample, subtest_ctx in op.sample_inputs(
# NB: use_subtests=True is required for skips / xfails to work. If skips / xfails are defined and use_subtests != True,
# an informative error will be thrown.
device, dtype, requires_grad=False, use_subtests=True
):
# NB: this subtest context manager runs each sample input as a "subtest" and handles skips / xfails appropriately
with subtest_ctx(self):
# do some SampleInput-based test logic
output = op.op(sample.input, *sample.args, **sample.kwargs)
...
```
More examples can be seen in `test/test_nestedtensor.py`, where this system is used in practice.
I also demonstrate usage of syntactic sugar over this system in `test/functorch/test_vmap.py`. Here, a skip for the `to()` operator is replaced with a granular xfail for `test_vmap_exhaustive()`:
```python
...
# pre-existing xfail
xfail("item"),
# new granular xfail using syntactic sugar over the general system
xfailIf(
"to",
lambda sample: (
sample.kwargs["memory_format"] == torch.channels_last
),
),
...
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/140443
Approved by: https://github.com/janeyx99, https://github.com/zou3519
ghstack dependencies: #140160, #138370
Fixes#130284Fixes#130653
- Add `torch.library.register_vmap` to custom ops
- Add `register_vmap` for operators in ops in custom_op_db.
- Make `torch.autograd.Function` support kwarg-only kwargs for vmap
- test operators in op_db with `tests/test_vmap`.
- change `test_vmap` to allow custom `out_dim` and allow "None" in `out_dim` when testing.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130589
Approved by: https://github.com/zou3519
Fixes#130284Fixes#130653
- Add `torch.library.register_vmap` to custom ops
- Add `register_vmap` for operators in ops in custom_op_db.
- Make `torch.autograd.Function` support kwarg-only kwargs for vmap
- test operators in op_db with `tests/test_vmap`.
- change `test_vmap` to allow custom `out_dim` and allow "None" in `out_dim` when testing.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130589
Approved by: https://github.com/zou3519
The `usort` config in `pyproject.toml` has no effect due to a typo. Fixing the typo make `usort` do more and generate the changes in the PR. Except `pyproject.toml`, all changes are generated by `lintrunner -a --take UFMT --all-files`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/127126
Approved by: https://github.com/kit1980
The `usort` config in `pyproject.toml` has no effect due to a typo. Fixing the typo make `usort` do more and generate the changes in the PR. Except `pyproject.toml`, all changes are generated by `lintrunner -a --take UFMT --all-files`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/127126
Approved by: https://github.com/kit1980
ghstack dependencies: #127122, #127123, #127124, #127125
The `usort` config in `pyproject.toml` has no effect due to a typo. Fixing the typo make `usort` do more and generate the changes in the PR. Except `pyproject.toml`, all changes are generated by `lintrunner -a --take UFMT --all-files`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/127125
Approved by: https://github.com/Skylion007
ghstack dependencies: #127122, #127123, #127124
Updates flake8 to v6.1.0 and fixes a few lints using sed and some ruff tooling.
- Replace `assert(0)` with `raise AssertionError()`
- Remove extraneous parenthesis i.e.
- `assert(a == b)` -> `assert a == b`
- `if(x > y or y < z):`->`if x > y or y < z:`
- And `return('...')` -> `return '...'`
Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/116591
Approved by: https://github.com/albanD, https://github.com/malfet
* Enable PERF402. Makes code more efficient and succinct by removing useless list copies that could be accomplished either via a list constructor or extend call. All test cases have noqa added since performance is not as sensitive in that folder.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115505
Approved by: https://github.com/malfet
This PR adds the following OpInfo tests:
- vmap x vjp x vmap
- vjp x vmap x vmap
- vjp x vjp x vmap
These OpInfo tests only run for the autograd_function_db. In general,
testing composition of two transforms is sufficient to convince
ourselves that functorch works on a given operator.
The autograd.Function testing (especially the upcoming
generate_vmap_rule) didn't feel rigorous enough to me, so I added these
additional tests to convince myself.
Test Plan:
- new tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90962
Approved by: https://github.com/samdow, https://github.com/soulitzer
Adds a set of generated tests for `AOTAutograd` using the `ModuleInfo` db, analogous to the `OpInfo`-based tests. Includes the following changes:
* Adds a `TestEagerFusionModuleInfo` test class, with both symbolic and non-symbolic tests, just like the OpInfo tests.
* Test logic "functionalizes" the module under test and calls into the now-factored-out verification logic the OpInfo tests use to compare compiled vs. non-compiled function outputs / grads.
* Adds a `decorateForModules(decorator, module_set)` utility to `test/functorch/common_utils.py` to handle xfails, skips, etc. The pre-existing logic is specific to ops, and I didn't want to duplicate all that, so I kept additions minimal with this function.
* Bunch of xfails to get everything passing; haven't looked deeply into all these yet. #90500 is relevant for the RNN failures.
* Fixes a bug in the `ModuleInfo` entry for `NLLLoss` to ensure sample input has the requested `requires_grad` setting (was causing spurious test failures).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90980
Approved by: https://github.com/ezyang
Happy to split this PR more if it helps.
This PR adds functorch.grad support for autograd.Function. There's a lot
going on; here is the high level picture and there are more details as
comments in the code.
Mechanism (PyOperator)
- Somehow, autograd.Function needs to dispatch with functorch. This is
necessary because every layer of functorch needs to see the
autograd.Function; grad layers need to preserve the backward pass.
- The mechanism for this is via PyOperator. If functorch transforms are
active, then we wrap the autograd.Function in a `custom_function_call`
PyOperator where we are able to define various rules for functorch
transforms.
- `custom_function_call` has a rule for the functorch grad transform.
autograd.Function changes
- I needed to make some changes to autograd.Function to make this work.
- First, this PR splits autograd.Function into a _SingleLevelFunction
(that works with a single level of functorch transform) and
autograd.Function (which works with multiple levels). This is necessary
because functorch's grad rule needs some way of specifying a backward
pass for that level only.
- This PR changes autograd.Function's apply to eitehr call
`custom_function_call` (if functorch is active) or super().apply (if
functorch isn't active).
Testing
- Most of this PR is just testing. It creates an autograd.Function
OpInfo database that then gets passed to the functorch grad-based tests
(grad, vjp, vjpvjp).
- Since functorch transform tests are autogenerated from OpInfo tests,
this is the easiest way to test various autograd.Function with
functorch.
Future
- jvp and vmap support coming next
- better error message (functorch only supports autograd.Function that
have the optional setup_context staticmethod)
- documentation to come when we remove the feature flag
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89860
Approved by: https://github.com/soulitzer