Finally, this PR adds BundledAOTAutogradCacheEntry. A BundledAOTAutogradCacheEntry is an AOTAutogradCacheEntry that saves the entire CompiledFxGraph directly in the entry.
This has some advantages:
- No more dependency on FxGraphCache at all
- Clearing FxGraphCache does not result in AOTAutogradCache miss
- Simpler logic, as BundledAOTAutogradCacheEntry has everything you need to load a full compiled python wrapper from a dynamo output
We plan to use BundledAOTAutogradCacheEntry for precompile. There's also a question of whether we want to use it for regular caching — the main disadvantage of this is having to save the same CompiledFxGraph twice, once in Inductor cache and once for AOTAutogradCache. With MegaCaching, this *could* be a regression in total cache size (as well as a minor cold start regression, as you have to save the same graph twice). I will import this and measure the mega cache space complexity, and if it looks good I'll enable it by default for caching as well.
On warm start, if AOTAutogradCache hits, you won't have to load inductor at all, so warm start overhead should be unaffected.
Differential Revision: [D74593304](https://our.internmc.facebook.com/intern/diff/D74593304)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/152840
Approved by: https://github.com/zhxchen17
This PR:
- cleans up some existing comments that don't make sense anymore
- hooks up the "custom_op_default_layout_constraint" back (that seems to
have broken)
- cleans up the "lazy registration path" which seems to never get hit
anymore
- adds dislike_padding to nodes that require exact strides
Test Plan:
- tests + CI
disable padding
Pull Request resolved: https://github.com/pytorch/pytorch/pull/148104
Approved by: https://github.com/shunting314, https://github.com/eellison
Summary: This PR adds a private configuration to the partitioner that ensures that the decision taken is the same across all ranks. This is a temporary workaround, as when size_hints are also taken into account in compiler collectives this workaround will not be needed anymore.
Test Plan:
This has been tested on some internal models, but I haven't added any tests in PyTorch (yet?)
T
Differential Revision: D73666017
Pull Request resolved: https://github.com/pytorch/pytorch/pull/152264
Approved by: https://github.com/bdhirsh
This PR adds standalone_compile API that does precompilation via caching to support vLLM use case in the short term while we work on the longer term precompilation solution.
```
standalone_compile(gm, example_inputs, options) -> CompiledArtifact
CompiledArtifact.save(path, format: binary|unpacked = binary)
CompiledArtifact.load(path, format: binary|unpacked = binary)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150670
Approved by: https://github.com/jamesjwu, https://github.com/zou3519
This PR adds standalone_compile API that does precompilation via caching to support vLLM use case in the short term while we work on the longer term precompilation solution.
```
standalone_compile(gm, example_inputs, options) -> CompiledArtifact
CompiledArtifact.save(path, format: binary|unpacked = binary)
CompiledArtifact.load(path, format: binary|unpacked = binary)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150670
Approved by: https://github.com/jamesjwu, https://github.com/zou3519
This PR adds standalone_compile API that does precompilation via caching to support vLLM use case in the short term while we work on the longer term precompilation solution.
```
standalone_compile(gm, example_inputs, options) -> CompiledArtifact
CompiledArtifact.save(path, format: binary|unpacked = binary)
CompiledArtifact.load(path, format: binary|unpacked = binary)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150670
Approved by: https://github.com/jamesjwu, https://github.com/zou3519
This adds a new env var and flag,
autograd_cache_allow_custom_autograd_functions, (env var: `TORCHINDUCTOR_AUTOGRAD_CACHE_ALLOW_CUSTOM_AUTOGRAD`) which allows custom autograd functions into AOTAutogradCache.
@hirsheybar and I worked together to verify that the higher order op AutogradFunctionApply is pure with respect to the dynamo input being passed in, so this *should* be safe. I'm still putting it behind a flag and turning it on slowly, first on an internal model, though. Once we verify that it is correct on the internal model we can work to enable the flag by default.
Differential Revision: [D71633184](https://our.internmc.facebook.com/intern/diff/D71633184/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/149751
Approved by: https://github.com/bdhirsh, https://github.com/zou3519
This should fix the hang in https://fb.workplace.com/groups/1075192433118967/permalink/1603268720311333/
The argument here is that:
(1) in general, it is not safe for the partitioner to sometimes choose to recompute collectives in the backward. Why? If we are running a distributed job, where many ranks are compiling at the same time, we need every rank to make a consistent decision about which collectives are recomputed for backward. If we let each compiler instance make its own choice without any cross-rank communication, they can make different choices and cause NCCL hangs (see the link above)
(2) later on, we'll want an `spmd_mode` flag that causes the compiler to issue collectives and communicate info across ranks. Once we have such a config, then turning it on should make it safe for the partitioner to potentially choose to recompute collectives (and agree on the binary "recompute-or-save" choice across all ranks)
(3) even without an `spmd_mode`, users can override this choice by using `torch.utils.checkpoint()` in their user code. User checkpointing generally always overrides the partitioner, and this should be safe because we expect the user to apply checkpointing consistently across ranks
Pull Request resolved: https://github.com/pytorch/pytorch/pull/147561
Approved by: https://github.com/zou3519
This was added in https://github.com/pytorch/pytorch/pull/126320. It's a very nice feature, which can be used to predict memory usage for different budget values.
However, it had some limitations, notably in terms of resolution (it only sampled 21 points across the whole range thus missed many threshold values) and in distributed settings.
Here I fix those by using recursive binary searches to identify all thresholds (up to a resolution of 1e-3, which can be made configurable) and output them in SVG (to be able to discern different points), plus I add the rank to the filename and store it in a user-define directory.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/148678
Approved by: https://github.com/Chillee, https://github.com/fmassa
TODO:
- [x] Add handling for when forward is invoked multiple times without invoking backward, so that the fwd/backward states are out of sync
- [x] Update rng state initialization to take from correct device
- [x] Tests
- [x] handling of retain_graph
- [x] respect fallback random
Fix for https://github.com/pytorch/pytorch/issues/130123.
Updates the aot_eager and cudagraph compilation of `run_and_save_rng_state` to use the new mechanism added by https://github.com/pytorch/pytorch/pull/114068 for CUDAGraph safe rng states.
We have a pair of rng states for the fwd and backward respectively. In both forward and backward the rng op will get run with `graphsafe_run_with_rng_state` which takes in RNG state and it hooks onto the current RNG generator before running the operator. The rng states for fwd/backward are initialized with the same value. We ensure that for any given run of the forward, the corresponding backward run will have the same rng states for the op as was observed in the forward.
```
===== Forward graph 1 =====
/data/users/eellison/pytorch/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
def forward(self, primals_1: "f32[4, 4][4, 1]cuda:0", primals_2: "f32[4, 4][4, 1]cuda:0", fwd_rng_state_0):
sin: "f32[4, 4][4, 1]cuda:0" = torch.ops.aten.sin.default(primals_1)
# No stacktrace found for following nodes
graphsafe_run_with_rng_state = torch.ops.higher_order.graphsafe_run_with_rng_state(torch.ops.aten.rand.default, [4, 4], dtype = torch.float32, device = device(type='cuda', index=0), pin_memory = False, rng_state = fwd_rng_state_0); fwd_rng_state_0 = None
...
===== Backward graph 1 =====
def forward(self, primals_1: "f32[4, 4][4, 1]cuda:0", primals_2: "f32[4, 4][4, 1]cuda:0", tangents_1: "f32[4, 4][4, 1]cuda:0", bwd_rng_state_0):
sin: "f32[4, 4][4, 1]cuda:0" = torch.ops.aten.sin.default(primals_1)
# No stacktrace found for following nodes
graphsafe_run_with_rng_state = torch.ops.higher_order.graphsafe_run_with_rng_state(torch.ops.aten.rand.default, [4, 4], dtype = torch.float32, device = device(type='cuda', index=0), pin_memory = False, rng_state = bwd_rng_state_0); bwd_rng_state_0 = None
```
There is some extra complication when a user either calls backward with retain_graph, or calls the backward in a different order as they called the forward. If a user has state fwd_rng_state0, bwd_rng_state0 and calls:
- fwd0: fwd_rng_state0 -> fwd_rng_state1
- fwd1: fwd_rng_state1 -> fwd_rng_state2
- bwd1
- bwd0
Then naively, when bwd1 is invoked the bwd rng states would not be equal to the same states that were observed in fwd1. I added handling of this in the aot runtime wrappers to detect pending backward invocations, and the current position of the bwd rng states, and to update when necesssary.
Other notes:
Because nodes which appear later in the forward appear earlier in the backward, we need a separate rng state for each operator. If we reused the rng across ops, the forward and backward would be run with different rng states. I.e., not applied in the same order.
Questions for reviewers:
This does change numerics, bc the rng of the op is now taken from the input rng state instead of whatever the rng would be midway through running the graph. Technically, we only need this for cuda graph. But, I'd prefer to not have a rng divergence just for cudagraph. I am making it respect `fallback_random`.
Edit: decided to apply to non cudagraphs as well, so long as fallback_random is not set
I'm initializing the rng states by cloning the current state. If you had something like 5 different rands in the model with the same shape, theyd all get the same value. This doesn't seem great. I could use some other initialization scheme like taking seed from graph position, or etc etc. Not sure. Let me know thoughts.
Edit: updated to be taken from randint()
Update: initializing rng states from torch.randint..
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146878
Approved by: https://github.com/anijain2305, https://github.com/bdhirsh
TODO:
- [x] Add handling for when forward is invoked multiple times without invoking backward, so that the fwd/backward states are out of sync
- [x] Update rng state initialization to take from correct device
- [x] Tests
- [x] handling of retain_graph
- [x] respect fallback random
Fix for https://github.com/pytorch/pytorch/issues/130123.
Updates the aot_eager and cudagraph compilation of `run_and_save_rng_state` to use the new mechanism added by https://github.com/pytorch/pytorch/pull/114068 for CUDAGraph safe rng states.
We have a pair of rng states for the fwd and backward respectively. In both forward and backward the rng op will get run with `graphsafe_run_with_rng_state` which takes in RNG state and it hooks onto the current RNG generator before running the operator. The rng states for fwd/backward are initialized with the same value. We ensure that for any given run of the forward, the corresponding backward run will have the same rng states for the op as was observed in the forward.
```
===== Forward graph 1 =====
/data/users/eellison/pytorch/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
def forward(self, primals_1: "f32[4, 4][4, 1]cuda:0", primals_2: "f32[4, 4][4, 1]cuda:0", fwd_rng_state_0):
sin: "f32[4, 4][4, 1]cuda:0" = torch.ops.aten.sin.default(primals_1)
# No stacktrace found for following nodes
graphsafe_run_with_rng_state = torch.ops.higher_order.graphsafe_run_with_rng_state(torch.ops.aten.rand.default, [4, 4], dtype = torch.float32, device = device(type='cuda', index=0), pin_memory = False, rng_state = fwd_rng_state_0); fwd_rng_state_0 = None
...
===== Backward graph 1 =====
def forward(self, primals_1: "f32[4, 4][4, 1]cuda:0", primals_2: "f32[4, 4][4, 1]cuda:0", tangents_1: "f32[4, 4][4, 1]cuda:0", bwd_rng_state_0):
sin: "f32[4, 4][4, 1]cuda:0" = torch.ops.aten.sin.default(primals_1)
# No stacktrace found for following nodes
graphsafe_run_with_rng_state = torch.ops.higher_order.graphsafe_run_with_rng_state(torch.ops.aten.rand.default, [4, 4], dtype = torch.float32, device = device(type='cuda', index=0), pin_memory = False, rng_state = bwd_rng_state_0); bwd_rng_state_0 = None
```
There is some extra complication when a user either calls backward with retain_graph, or calls the backward in a different order as they called the forward. If a user has state fwd_rng_state0, bwd_rng_state0 and calls:
- fwd0: fwd_rng_state0 -> fwd_rng_state1
- fwd1: fwd_rng_state1 -> fwd_rng_state2
- bwd1
- bwd0
Then naively, when bwd1 is invoked the bwd rng states would not be equal to the same states that were observed in fwd1. I added handling of this in the aot runtime wrappers to detect pending backward invocations, and the current position of the bwd rng states, and to update when necesssary.
Other notes:
Because nodes which appear later in the forward appear earlier in the backward, we need a separate rng state for each operator. If we reused the rng across ops, the forward and backward would be run with different rng states. I.e., not applied in the same order.
Questions for reviewers:
This does change numerics, bc the rng of the op is now taken from the input rng state instead of whatever the rng would be midway through running the graph. Technically, we only need this for cuda graph. But, I'd prefer to not have a rng divergence just for cudagraph. I am making it respect `fallback_random`.
Edit: decided to apply to non cudagraphs as well, so long as fallback_random is not set
I'm initializing the rng states by cloning the current state. If you had something like 5 different rands in the model with the same shape, theyd all get the same value. This doesn't seem great. I could use some other initialization scheme like taking seed from graph position, or etc etc. Not sure. Let me know thoughts.
Edit: updated to be taken from randint()
Update: initializing rng states from torch.randint..
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146878
Approved by: https://github.com/anijain2305, https://github.com/bdhirsh
Summary:
A reland of https://github.com/pytorch/pytorch/pull/142426.
Copying the description over here:
For torch.export (strict and non-strict), we don't do functional decomposition. Instead, we preserve the custom triton ops as custom ops. This is because we want the exported program to be high-level and serializable.
The alternative:
If we decompose the custom op to a functional hop and make it a node in exported program, we need to figure out ways of serializing the hop and its arguments, which can be triton.jited python functions and triton dtypes. This is undesireble because:
it can be tedious to maintain layer that serialize the jited function (e.g. with a string) and dtypes.
changes to triton or the serialization logic for triton arguments can be BC breaking
exported program will expose the implementation detail (i.e. triton source code) for a specific backend (GPU) to users, which mixes levels of abstraction.
Future plans:
After this PR, in the short term, we expect users to have a seperate aot_compile stage that compiles the exported program into a Cubin file on the same machine that users call export, which does autotuning and removes triton dependency and serve the model with Cubin. This guarantees that triton changes won't break BC.
In the long term, we may export multiple cubins for the triton op directly.
Test Plan: see new tests.
Differential Revision: D67879685
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144284
Approved by: https://github.com/zou3519
Allow mutations mutations for subclasses that are non-contiguous.
Changes:
Removing assert in collect_metadata_analysis
Main requested testcase:
Compilation of NJT.index_put()
Adding test in test_nestedtensor.py, that compiles NJT.index_put()
It is decomposed to NJT split,unbind, which needed additional `torch._check`, `torch._check_is_size` for NJT.unbind() and guard_size_oblivious() usage in _meta_registrations and _inductor/lowering.py.
Special case:
If tangent is mutated outside of the graph, it does not participate in backward graph. Autograd in this case will set this tangent to zeros tensor.
We handle it separately in CompiledFunction.backward: not doing any processing for this tangent and broadcast to number of expected subclass unwrapped arguments.
disabling for dynamo 2 tests:
1/ For nested tensor - symbolic shapes issue on nested_tensor index operation that does splits [0, 0, 0] - there is a failure with "pending unbacked symints". This PR does not add more .tolist()/item() ops than it was before.
2/ As we do not fail with exception in collect_metadata_analysis new paths for dynamo started working and it started failing with smth strange that set_ in storage_offset (because of test for views) handling updates storage "cpu" -> "meta"
Pull Request resolved: https://github.com/pytorch/pytorch/pull/139630
Approved by: https://github.com/bdhirsh
This PR removes the functorch config that set an upper limit on the number of aliased
inputs with dynamic shapes. After moving them to be run at runtime in C++, the compilation
time and runtime (in true alias cases) improved, rendering the error no longer relevant.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141680
Approved by: https://github.com/bdhirsh
ghstack dependencies: #139554, #139555, #140013
Currently real tensor tracing raises MetadataMismatchErrors if registered fake kernels don't match the real kernels (e.g. shape, aliasing, dtype, etc.). This adds an option to use fake kernel inference to bypass mismatches - this option defaults to False for real tensor tracing, but is on for draft export.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/139766
Approved by: https://github.com/angelayi, https://github.com/zou3519
This PR enables donated buffer in OSS and handles two edge cases:
1. While donated buffer relies on storage to check alias, sparse tensor subclasses does not provide access to storage. So we skip sparse tensor subclasses for donated buffer.
2. Handles missing "val" from n.meta. This is observed from `inductor/test_fused_attention.py::SDPAPatternRewriterCpuTests::test_sdpa_rewriter_11_cpu`,
`functorch/test_aotdispatch.py::TestAOTAutograd::test_input_mutation_simple_with_none_and_nontensor`, and
`inductor/test_compiled_autograd.py::TestCompiledAutograd::test_trace_run_with_rng_state`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/139669
Approved by: https://github.com/bdhirsh
I was debugging an internal ne divergence for a while that ended up being because of a bad meta. I added an explicit a config option and an explicit backend `aot_eager_decomp_partition_crossref` to enable the FakeCrossRefMode when running the graph. I added an explicit backend bc I suspect it will be useful for internal models but I'm also happy to leave as config option.
It will only test ops that have meta to avoid memory overhead of hitting fallback path and running in eager.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/138651
Approved by: https://github.com/zou3519, https://github.com/bdhirsh
Summary: This just adds a config option and JK for turning on remote AOTAutogradCache. It does not implement anything with the new options being passed in. That will come next diff.
This PR also changes the command for turning on the local AOTAutogradCache to be more consistent to that of FXGraphCache: TORCHINDUCTOR_AUTOGRAD_CACHE
Test Plan: Existing tests should pass and should build
Reviewed By: oulgen
Differential Revision: D63321965
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137011
Approved by: https://github.com/oulgen
Implements donated buffer feature and adds unit tests. Donated buffer is a saved tensor that is not aliased with forward inputs, fw_outputs (except saved tensors), and bw_outputs. We detect donated buffers during `aot_dispatch_autograd` and store donated buffers in `ViewAndMutationMetadata`, such that it can be accssed in inductor.
Fixes#129496
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130580
Approved by: https://github.com/bdhirsh
This PR makes it so that we don't try to serialize FunctionalTensorWrappers. FunctionalTensorWrappers don't pickle well because they have no underlying storage. This should be fixable at a later point, but I might not be the right author for implementing the serialization for it. If there's a way to avoid actually saving the FunctionalTensorWrappers themselves and just saving the ViewMetadata so we can replay it, that would also work.
To do this, we disable view_replay_input_mutations when using AOTAutogradCache, and then only keep the functional tensor in the ViewAndMutationMeta if we need it for view_replay_input_mutations (i.e. the cache is off).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128335
Approved by: https://github.com/bdhirsh
This diff introduces AOTAutogradTestWithCache, which runs AOTAutogradTests with both dynamo and AOTAutogradCache.
To do this, for any verify_aot_autograd() calls in the original tests, we run compiled_f an extra time. We also turn on a new strict mode that throws any time a cache is missed due to weird reasons, like BypassAOTAutogradCache or FxGraphCacheMiss.
We use a mocked version of FXGraphCache to decrease the number of variables for these tests. The normal tests in test_aot_autograd_cache.py will still run with FXGraphCache. I might change my mind and unmock these in the future.
In total, 87 of the tests pass naturally. None of the tests fail in non strict cache mode, so the cache never crashes, it just misses more often than we'd like. The remaining 27 tests fail due to relatively simple (though not necessarily easy to fix) reasons. I'll fix the remaining test failures in the next few PRs.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128222
Approved by: https://github.com/bdhirsh
This PR implements "V0" of AOTAutogradCache. Given an input to AOTAutograd, we calculate a cache key, then save an AOTAutogradCacheEntry.
Each AOTAutogradCacheEntry has:
- A CompiledForward and optionally a CompiledBackward
- A bunch of metadata.
CompiledForward and CompiledBackward each save the *key* to the FXGraphCache associated with the compiled object. FXGraphCache populates this key field as long as it's able to return a compiled graph given a set of inputs. We then load the same object from the FXGraphCache on an AOTAutogradCache hit.
On cache miss:
- Run AOTAutograd, up to AOTAutogradDispatch.post_compile.
- Save an AOTAutogradCacheEntry to the cache after compiling the necessary portions and receiving a cache key from FXGraphCache. In this we *always* compile the backwards ahead of time. The PR above this one implements backward lazy caching, so that we only save to the cache after compiling the backward in a lazy backward scenario.
- Return the resulting object
On cache hit:
- Run AOTAutogradCacheEntry.post_compile() on the cache key.
- This attempts to load the forward and backward graphs from FXGraphCache
- As long as we successfully load from FXGraphCache, it's a hit. We then rewrap the callable with post compile wrappers using our saved metadata.
For now, we ignore the fakified out and debug wrappers. We only save to the cache if Fakified out is turned off.
V0 Guards behavior:
FXGraphCache serializes guards that are needed in the shape_env based on the symint inputs to the graph. The invariant that AOTAutograd uses here is that the sources for symints given to it by dynamo are exactly the same as the ones it passes to inductor, for both the forward and backward passes. (This does *not* mean that the tensor values passed in are the same: only that their symints are). That is, AOTAutograd and Inductor never create new guards based on symints with *different sources* than those passed to it by inductor.
We don't currently store any AOTAutograd specific guards: my hypothesis is that FXGraphCache already stores these, as any guards generated by AOTAutograd should already be in the shape_env before calling into inductor, and we don't generate new guards post inductor. If this is needed, I'll add it in another diff.
Testing:
We'll start with some basic unit tests, but I'll be adding more and more complicated testing as the next step.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126791
Approved by: https://github.com/bdhirsh
This PR implements "V0" of AOTAutogradCache. Given an input to AOTAutograd, we calculate a cache key, then save an AOTAutogradCacheEntry.
Each AOTAutogradCacheEntry has:
- A CompiledForward and optionally a CompiledBackward
- A bunch of metadata.
CompiledForward and CompiledBackward each save the *key* to the FXGraphCache associated with the compiled object. FXGraphCache populates this key field as long as it's able to return a compiled graph given a set of inputs. We then load the same object from the FXGraphCache on an AOTAutogradCache hit.
On cache miss:
- Run AOTAutograd, up to AOTAutogradDispatch.post_compile.
- Save an AOTAutogradCacheEntry to the cache after compiling the necessary portions and receiving a cache key from FXGraphCache. In this we *always* compile the backwards ahead of time. The PR above this one implements backward lazy caching, so that we only save to the cache after compiling the backward in a lazy backward scenario.
- Return the resulting object
On cache hit:
- Run AOTAutogradCacheEntry.post_compile() on the cache key.
- This attempts to load the forward and backward graphs from FXGraphCache
- As long as we successfully load from FXGraphCache, it's a hit. We then rewrap the callable with post compile wrappers using our saved metadata.
For now, we ignore the fakified out and debug wrappers. We only save to the cache if Fakified out is turned off.
V0 Guards behavior:
FXGraphCache serializes guards that are needed in the shape_env based on the symint inputs to the graph. The invariant that AOTAutograd uses here is that the sources for symints given to it by dynamo are exactly the same as the ones it passes to inductor, for both the forward and backward passes. (This does *not* mean that the tensor values passed in are the same: only that their symints are). That is, AOTAutograd and Inductor never create new guards based on symints with *different sources* than those passed to it by inductor.
We don't currently store any AOTAutograd specific guards: my hypothesis is that FXGraphCache already stores these, as any guards generated by AOTAutograd should already be in the shape_env before calling into inductor, and we don't generate new guards post inductor. If this is needed, I'll add it in another diff.
Testing:
We'll start with some basic unit tests, but I'll be adding more and more complicated testing as the next step.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126791
Approved by: https://github.com/bdhirsh
A common complaint when working with data-dependent code in PyTorch is that it's hard to tell how far you are from the finish line: every time a GuardOnDataDependentSymNode error is hit, you have to somehow fix or workaround it to see the next one.
This PR adds a new mode `torch._functorch.config.fake_tensor_propagate_real_tensors` which modifies fake tensors to also propagate real tensors. This means that when we try to guard on a data-dependent SymNode, we can actually produce a real result. We also produce a warning which you should consult to figure out what the crux points are.
I ran this on vision_maskrcnn. In the baseline (without this mode), the model has 27 graph breaks, resulting in 40 graphs. With this mode on, the model has only 11 graph breaks, resulting in 15 graphs (the remaining graph breaks are due to missing functionality for item() on float tensor and some other Dynamo missing features.) You get a list of things that would have errored like this:
```
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Max(1, u1) < 2) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Eq(Max(1, u1), 1)) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Eq(Max(1, u1), 1)) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Ne(Max(1, u1), 1)) -> False
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Max(1, u0) < 2) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Eq(Max(1, u0), 1)) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Eq(Max(1, u0), 1)) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Ne(Max(1, u0), 1)) -> False
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Max(1, u1) < 2) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Eq(Max(1, u1), 1)) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Eq(Max(1, u1), 1)) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Ne(Max(1, u1), 1)) -> False
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Max(1, u0) < 2) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Eq(Max(1, u0), 1)) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Eq(Max(1, u0), 1)) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Ne(Max(1, u0), 1)) -> False
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Max(1, u1) < 2) -> False
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Eq(Max(1, u1), 1)) -> False
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Ne(Max(1, u1), 1)) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Max(1, u0) < 2) -> False
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Eq(Max(1, u0), 1)) -> False
```
Potential later follow ups:
* Improve the warning messages (in particular, should provide user frames)
* GC real tensors when they are no longer needed by tracing. Right now, this will use A LOT of memory, equal to as if your GC was broken and every intermediate tensor was kept live
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125115
Approved by: https://github.com/IvanKobzarev
We should eventually make the non-overlapping checks faster when dynamic shapes are enabled, but this is pretty difficult to do. So for now this PR adds a config that lets us fail fast when this situation happens, instead of causing compile times to secretly come to a crawl.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123455
Approved by: https://github.com/ezyang
Given the following code/dynamo graph:
```
class GraphModule(torch.nn.Module):
def forward(self, L_x_ : torch.Tensor):
l_x_ = L_x_
_print = torch.ops.aten._print('moo')
res = l_x_ + l_x_; l_x_ = None
_print_1 = torch.ops.aten._print('moo')
return (res,)
```
AOTAutograd will trace the following program, threading tokens from the inputs, through the effectful operator calls (torch.ops.aten._print), and as an output:
```
class <lambda>(torch.nn.Module):
def forward(self, arg0_1: "f32[0]", arg1_1: "f32[2, 3]"):
with_effects = torch._higher_order_ops.effects.with_effects(arg0_1, torch.ops.aten._print.default, 'moo'); arg0_1 = None
getitem: "f32[0]" = with_effects[0]; with_effects = None
add: "f32[2, 3]" = torch.ops.aten.add.Tensor(arg1_1, arg1_1); arg1_1 = None
with_effects_1 = torch._higher_order_ops.effects.with_effects(getitem, torch.ops.aten._print.default, 'moo'); getitem = None
getitem_2: "f32[0]" = with_effects_1[0]; with_effects_1 = None
return (getitem_2, add)
```
However when we get to inductor, since we want the inductor generated code to not have any token inputs/outputs for better readability, we want to modify the aten graph by removing the tokens from inputs, and creating them through `torch.ops.aten._make_dep_token`, and sinking them through the `torch.ops.aten._sink_tokens` operators.
This has to be done *after* the partitioner, otherwise the partitioner will add the make_token/sink_token operators to the backwards graph.
```
class <lambda>(torch.nn.Module):
def forward(self, arg1_1: "f32[2, 3]"):
_make_dep_token_default: "f32[0]" = torch.ops.aten._make_dep_token.default()
with_effects = torch._higher_order_ops.effects.with_effects(_make_dep_token_default, torch.ops.aten._print.default, 'moo'); _make_dep_token_default = None
getitem: "f32[0]" = with_effects[0]; with_effects = None
add: "f32[2, 3]" = torch.ops.aten.add.Tensor(arg1_1, arg1_1); arg1_1 = None
with_effects_1 = torch._higher_order_ops.effects.with_effects(getitem, torch.ops.aten._print.default, 'moo'); getitem = None
getitem_2: "f32[0]" = with_effects_1[0]; with_effects_1 = None
_sink_tokens_default = torch.ops.aten._sink_tokens.default((getitem_2,)); getitem_2 = None
return (add,)
```
When doing inductor lowering, we convert `with_effects` calls to an `EffectfulKernel`, which just a `FallbackKernel` but with a pointer to previous effectful operator's call. During scheduling, we will create a `StarDep` between the EffectfulKernel and its previous EffectfulKernel so that they don't get reordered. The inductor generated python code looks like:
```
def call(args):
arg1_1, = args
args.clear()
assert_size_stride(arg1_1, (2, 3), (3, 1))
# Source Nodes: [_print], Original ATen: []
buf2 = aten._print.default('moo')
# Source Nodes: [_print_1], Original ATen: []
buf3 = aten._print.default('moo')
buf4 = empty_strided_cpu((2, 3), (3, 1), torch.float32)
cpp_fused_add_0(arg1_1, buf4)
del arg1_1
return (buf4, )
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122347
Approved by: https://github.com/bdhirsh