This should fix the hang in https://fb.workplace.com/groups/1075192433118967/permalink/1603268720311333/
The argument here is that:
(1) in general, it is not safe for the partitioner to sometimes choose to recompute collectives in the backward. Why? If we are running a distributed job, where many ranks are compiling at the same time, we need every rank to make a consistent decision about which collectives are recomputed for backward. If we let each compiler instance make its own choice without any cross-rank communication, they can make different choices and cause NCCL hangs (see the link above)
(2) later on, we'll want an `spmd_mode` flag that causes the compiler to issue collectives and communicate info across ranks. Once we have such a config, then turning it on should make it safe for the partitioner to potentially choose to recompute collectives (and agree on the binary "recompute-or-save" choice across all ranks)
(3) even without an `spmd_mode`, users can override this choice by using `torch.utils.checkpoint()` in their user code. User checkpointing generally always overrides the partitioner, and this should be safe because we expect the user to apply checkpointing consistently across ranks
Pull Request resolved: https://github.com/pytorch/pytorch/pull/147561
Approved by: https://github.com/zou3519
Fixes https://github.com/pytorch/pytorch/issues/144095
open to suggestions: the `hint_int(..., fallback=...)` API feels like a bit of a footgun, because:
(1) we use the same guess for every unbacked symint (both symbols, and compound expressions)
(2) the user may have established some relationship between some unbacked symints that we are not taking into account.
I'm not sure how real of an issue (2) is - is it common to e.g. generate two unbacked symints, and then add a runtime assert that they are unequal?
Instead I did something simpler that's just enough to fix the linked issue: if we have a sympy expression containing an unbacked symbol (e.g. `u0 + 1`), then the partitioner will now fill in the symbol with our guess instead of the expression (plugging in `u0=4096` gets us 4097). This was important for an internal custom op, that had some logic like this:
```
def custom_op(x: [u0], y: [u0 + 1]):
assert x.shape[0] = y.shape[0] - 1
...
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144097
Approved by: https://github.com/laithsakka
This was added in https://github.com/pytorch/pytorch/pull/126320. It's a very nice feature, which can be used to predict memory usage for different budget values.
However, it had some limitations, notably in terms of resolution (it only sampled 21 points across the whole range thus missed many threshold values) and in distributed settings.
Here I fix those by using recursive binary searches to identify all thresholds (up to a resolution of 1e-3, which can be made configurable) and output them in SVG (to be able to discern different points), plus I add the rank to the filename and store it in a user-define directory.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/148678
Approved by: https://github.com/Chillee, https://github.com/fmassa
Earlier, with inline flag we were lifting id-guarded tensors to the inputs to the Fx graph. But this offers no benefit. Main idea behind lifting parameters as inputs was to reuse the compilation units across many instances of the nn-module. However, if we are guarding on the `id`, we are explicitly specializing the compiled artifact to the parameter.
This PR installs the parameters back into the graph. The benefit is removal of all pre-graph bytecode to extract the id-guarded tensors from locals/globals. This increases speedup from 1.67x to 1.75x for an internal model that has large number of optimizer parameters.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/147824
Approved by: https://github.com/jansel
Co-authored-by: Jason Ansel <jansel@meta.com>
TODO:
- [x] Add handling for when forward is invoked multiple times without invoking backward, so that the fwd/backward states are out of sync
- [x] Update rng state initialization to take from correct device
- [x] Tests
- [x] handling of retain_graph
- [x] respect fallback random
Fix for https://github.com/pytorch/pytorch/issues/130123.
Updates the aot_eager and cudagraph compilation of `run_and_save_rng_state` to use the new mechanism added by https://github.com/pytorch/pytorch/pull/114068 for CUDAGraph safe rng states.
We have a pair of rng states for the fwd and backward respectively. In both forward and backward the rng op will get run with `graphsafe_run_with_rng_state` which takes in RNG state and it hooks onto the current RNG generator before running the operator. The rng states for fwd/backward are initialized with the same value. We ensure that for any given run of the forward, the corresponding backward run will have the same rng states for the op as was observed in the forward.
```
===== Forward graph 1 =====
/data/users/eellison/pytorch/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
def forward(self, primals_1: "f32[4, 4][4, 1]cuda:0", primals_2: "f32[4, 4][4, 1]cuda:0", fwd_rng_state_0):
sin: "f32[4, 4][4, 1]cuda:0" = torch.ops.aten.sin.default(primals_1)
# No stacktrace found for following nodes
graphsafe_run_with_rng_state = torch.ops.higher_order.graphsafe_run_with_rng_state(torch.ops.aten.rand.default, [4, 4], dtype = torch.float32, device = device(type='cuda', index=0), pin_memory = False, rng_state = fwd_rng_state_0); fwd_rng_state_0 = None
...
===== Backward graph 1 =====
def forward(self, primals_1: "f32[4, 4][4, 1]cuda:0", primals_2: "f32[4, 4][4, 1]cuda:0", tangents_1: "f32[4, 4][4, 1]cuda:0", bwd_rng_state_0):
sin: "f32[4, 4][4, 1]cuda:0" = torch.ops.aten.sin.default(primals_1)
# No stacktrace found for following nodes
graphsafe_run_with_rng_state = torch.ops.higher_order.graphsafe_run_with_rng_state(torch.ops.aten.rand.default, [4, 4], dtype = torch.float32, device = device(type='cuda', index=0), pin_memory = False, rng_state = bwd_rng_state_0); bwd_rng_state_0 = None
```
There is some extra complication when a user either calls backward with retain_graph, or calls the backward in a different order as they called the forward. If a user has state fwd_rng_state0, bwd_rng_state0 and calls:
- fwd0: fwd_rng_state0 -> fwd_rng_state1
- fwd1: fwd_rng_state1 -> fwd_rng_state2
- bwd1
- bwd0
Then naively, when bwd1 is invoked the bwd rng states would not be equal to the same states that were observed in fwd1. I added handling of this in the aot runtime wrappers to detect pending backward invocations, and the current position of the bwd rng states, and to update when necesssary.
Other notes:
Because nodes which appear later in the forward appear earlier in the backward, we need a separate rng state for each operator. If we reused the rng across ops, the forward and backward would be run with different rng states. I.e., not applied in the same order.
Questions for reviewers:
This does change numerics, bc the rng of the op is now taken from the input rng state instead of whatever the rng would be midway through running the graph. Technically, we only need this for cuda graph. But, I'd prefer to not have a rng divergence just for cudagraph. I am making it respect `fallback_random`.
Edit: decided to apply to non cudagraphs as well, so long as fallback_random is not set
I'm initializing the rng states by cloning the current state. If you had something like 5 different rands in the model with the same shape, theyd all get the same value. This doesn't seem great. I could use some other initialization scheme like taking seed from graph position, or etc etc. Not sure. Let me know thoughts.
Edit: updated to be taken from randint()
Update: initializing rng states from torch.randint..
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146878
Approved by: https://github.com/anijain2305, https://github.com/bdhirsh
TODO:
- [x] Add handling for when forward is invoked multiple times without invoking backward, so that the fwd/backward states are out of sync
- [x] Update rng state initialization to take from correct device
- [x] Tests
- [x] handling of retain_graph
- [x] respect fallback random
Fix for https://github.com/pytorch/pytorch/issues/130123.
Updates the aot_eager and cudagraph compilation of `run_and_save_rng_state` to use the new mechanism added by https://github.com/pytorch/pytorch/pull/114068 for CUDAGraph safe rng states.
We have a pair of rng states for the fwd and backward respectively. In both forward and backward the rng op will get run with `graphsafe_run_with_rng_state` which takes in RNG state and it hooks onto the current RNG generator before running the operator. The rng states for fwd/backward are initialized with the same value. We ensure that for any given run of the forward, the corresponding backward run will have the same rng states for the op as was observed in the forward.
```
===== Forward graph 1 =====
/data/users/eellison/pytorch/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
def forward(self, primals_1: "f32[4, 4][4, 1]cuda:0", primals_2: "f32[4, 4][4, 1]cuda:0", fwd_rng_state_0):
sin: "f32[4, 4][4, 1]cuda:0" = torch.ops.aten.sin.default(primals_1)
# No stacktrace found for following nodes
graphsafe_run_with_rng_state = torch.ops.higher_order.graphsafe_run_with_rng_state(torch.ops.aten.rand.default, [4, 4], dtype = torch.float32, device = device(type='cuda', index=0), pin_memory = False, rng_state = fwd_rng_state_0); fwd_rng_state_0 = None
...
===== Backward graph 1 =====
def forward(self, primals_1: "f32[4, 4][4, 1]cuda:0", primals_2: "f32[4, 4][4, 1]cuda:0", tangents_1: "f32[4, 4][4, 1]cuda:0", bwd_rng_state_0):
sin: "f32[4, 4][4, 1]cuda:0" = torch.ops.aten.sin.default(primals_1)
# No stacktrace found for following nodes
graphsafe_run_with_rng_state = torch.ops.higher_order.graphsafe_run_with_rng_state(torch.ops.aten.rand.default, [4, 4], dtype = torch.float32, device = device(type='cuda', index=0), pin_memory = False, rng_state = bwd_rng_state_0); bwd_rng_state_0 = None
```
There is some extra complication when a user either calls backward with retain_graph, or calls the backward in a different order as they called the forward. If a user has state fwd_rng_state0, bwd_rng_state0 and calls:
- fwd0: fwd_rng_state0 -> fwd_rng_state1
- fwd1: fwd_rng_state1 -> fwd_rng_state2
- bwd1
- bwd0
Then naively, when bwd1 is invoked the bwd rng states would not be equal to the same states that were observed in fwd1. I added handling of this in the aot runtime wrappers to detect pending backward invocations, and the current position of the bwd rng states, and to update when necesssary.
Other notes:
Because nodes which appear later in the forward appear earlier in the backward, we need a separate rng state for each operator. If we reused the rng across ops, the forward and backward would be run with different rng states. I.e., not applied in the same order.
Questions for reviewers:
This does change numerics, bc the rng of the op is now taken from the input rng state instead of whatever the rng would be midway through running the graph. Technically, we only need this for cuda graph. But, I'd prefer to not have a rng divergence just for cudagraph. I am making it respect `fallback_random`.
Edit: decided to apply to non cudagraphs as well, so long as fallback_random is not set
I'm initializing the rng states by cloning the current state. If you had something like 5 different rands in the model with the same shape, theyd all get the same value. This doesn't seem great. I could use some other initialization scheme like taking seed from graph position, or etc etc. Not sure. Let me know thoughts.
Edit: updated to be taken from randint()
Update: initializing rng states from torch.randint..
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146878
Approved by: https://github.com/anijain2305, https://github.com/bdhirsh
Bug was reported by internal user.
AOTD classified outputs that are aliases of intermediates of the graph in different categories.
...
- output is alias of intermediate which base is already output
- output is alias of intermediate which base is not in output
If we look at the fn:
```
def fn(x):
ix = x + 1
a = ix.transpose(0, 1)
return a.detach(), a
```
output 0: detach view of alias a, where a is already output
output 1: alias of intermediate ix, then additional output ix will be added internally
output 0 base is TensorAlias(a) in this case, but could be Tensor.
Adding runtime unwrapping solves this problem.
Alternatively we should track base of a.detach() all the way to ix, in that case the base will be always a Tensor, not TensorAlias.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/147638
Approved by: https://github.com/bdhirsh
Since the functional autograd + compiled autograd migration, we don't trace into nodes anymore, and everything is lifted. We can't support this flag which tries to inline make_fx style in CA initial pass. There's no more usage internally.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146720
Approved by: https://github.com/zou3519
Use it to unwrap any functorch-wrapped tensor. I don't recommend using
the output in a program since it breaks the semantics of the transforms,
but it seems useful for debugging.
I will note that some people have wanted to get intermediate values out
of an e.g. grad transform, so this might be a way to do that...
Test Plan:
- tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146528
Approved by: https://github.com/Chillee
In an attempt to make partitioning more deterministic, change all sets in partitioners.py to OrderedSets. Note that this change does not fix the non-determinism we're seeing in the internal model. But let's at least eliminate this potential source of non-determinism before investigating any changes to the mincut approach?
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146102
Approved by: https://github.com/oulgen
Fixes https://github.com/pytorch/pytorch/issues/143876
Open to other suggestions - we have an invariant that all nodes in our ATen graphs should have a `meta['val']` field, but I don't think this is actually true in all cases, so I just hardcoded the invariant to ignore `_assert_scalar()` (which is a "special" op used in dynamic shapes for runtime asserts, and doesn't have a meta['val'] field)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143877
Approved by: https://github.com/zou3519
Fixes https://github.com/pytorch/pytorch/issues/145081
This looks like it was a source of quadratic compile times in the torchtitan CP graphs. There's some code in the partitioner that iteratively adds users of a node to a heap, and pops the earliest user. If you have long parallel chains of fusible ops that all eventually feed into some shared ops, then this can result in:
(1) a node getting added to the heap many times
(2) each time we pop that node, we add (duplicates of) each of that node users to the heap
(3) repeat with each user
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145082
Approved by: https://github.com/xmfan
Summary:
## Why
To make it possible to run torch dispatch mode inside compiled modules. This is to enable running MemoryTrackerMode (in next diff) to collect memory usage of compiled modules.
## What
Add a backend aot_eager_decomp_partition_with_mode.
Add an enable_log to the backend to control the compilation logging (which can be very verbose and slow the run of mode)
Test Plan:
unittest
E2e tested in the next diff which shows the memory read from the mode passed to this backend is very close to the actual job's memory snapshot.
Differential Revision: D67227144
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143250
Approved by: https://github.com/bdhirsh
We will always proxy autograd.Function nodes in compiled autograd's
initial graph capture (previously there was an
option to proxy vs trace into the autograd.Function)
We have some requirements for the AOTBackward. Compiled Autograd runs
accumulate grad reordering passes on the AOTBackward graph directly
after the initial graph capture, so we can't just proxy a single node for it.
Instead, we:
- proxy the AOTBackward prologue function into the CA graph
- copy-paste the AOTBackward graph into the CA graph
- trace directly through the epilogue (the traced nodes go into the CA
graph).
Tracing through the epilogue is safe (assuming no Tensor subclasses)
because the only thing the epilogue does is drop some outputs. The
Tensor subclass situation was already broken so this doesn't regress
anything but this PR sets it up to be fixed (in a followup, where we
will proxy "make_subclass" calls into the graph from the epilogue).
Test Plan:
- existing tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143405
Approved by: https://github.com/jansel, https://github.com/xmfan
ghstack dependencies: #143296, #143304, #143387
Fix: #141974
This PR makes `ViewMeta` sequence, present in functional tensors,
serializable with pickle. In order to accomplish that, it makes
`ViewMeta` an abstract class with overridable `forward` and `reverse`
functions. In this context, each operation that once instanciated
`ViewMeta`, should now create a new specialized class that inherits from
`ViewMeta. Therefore, this PR also uses codegen for creating these
specializations.
In summary, these are the changes this PR introduces:
- `ViewMeta` is turned into an abstract class (see
_FunctionalStorageImpl.cpp_). `forward` and `reverse` are pure virtual
functions that need to be implemented. `to_out_index` should be
implemented by operations that might return more than 1 output.
- New `ViewMeta` specializations for `resize_` and `_unsafe_view` are
created (see _FunctionalizeFallbackKernel.h_).
- New templates _ViewMetaClasses.{cpp,h}_ are created. They hold the
declaration and definition of the `ViewMeta` specializations, which
are automatically generated in the ATen codegen (see _gen.py_).
- New `_functionalization` Python sub-module is created (see
_Module.cpp_). It serves as namespace for the `ViewMeta`
specializations and `InverseReturnMode` enum.
- New template _ViewMetaClassesPythonBinding.cpp_ is created. It holds
the automatically generated Python bindings for the `ViewMeta`
specialization, which are generated in the torch codegen (see
_generate_code.py_).
Note that this PR makes use of codegen at 2 different moments:
- ATen codegen (_gen.py_): generates the `ViewMeta` specialized classes.
- Torch codegen (_generate_code.py_): generated the Python bindings for
them.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143712
Approved by: https://github.com/bdhirsh
Differential Revision: [D68173093](https://our.internmc.facebook.com/intern/diff/D68173093/)
This diff allows any function in torch_non_c_binding_in_graph_functions to be safe to cache. These functions should be safe to cache because they are part of the torch API, and do not save global state (or if they do, dynamo creates unique guards around the constants they return).
A function that's allowed in a dynamo graph is safe to cache for AOTAutograd purposes as long as:
- It's functional (i.e. does not access global state);
- or its value is constant folded away (and guarded against by dynamo)
The tricky cases are functions that dynamo uses special handlers to track. These special handlers can sometimes close over stuff that's safe for dynamo locally, but isn't encoded anywhere when cached across processes. An example of this is `DTensor.from_local`, where various DeviceMesh information doesn't change in the same dynamo process, but can change across multiple processes. The handler for `DTensor.from_local` closes over these and dynamo creates a proxy for the function call. This is not safe to cache.
That said, most special handlers are in fact functional and safe. So I add a unit test to test_trace_rules.py that confirms that any function with special handlers in dynamo added to this list needs to be audited to be safe to cache.
The list of safe handlers there either:
- Don't access global state;
- Guard on global state; or
- Always returns a constant that never changes
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144802
Approved by: https://github.com/bdhirsh