This should fix the hang in https://fb.workplace.com/groups/1075192433118967/permalink/1603268720311333/
The argument here is that:
(1) in general, it is not safe for the partitioner to sometimes choose to recompute collectives in the backward. Why? If we are running a distributed job, where many ranks are compiling at the same time, we need every rank to make a consistent decision about which collectives are recomputed for backward. If we let each compiler instance make its own choice without any cross-rank communication, they can make different choices and cause NCCL hangs (see the link above)
(2) later on, we'll want an `spmd_mode` flag that causes the compiler to issue collectives and communicate info across ranks. Once we have such a config, then turning it on should make it safe for the partitioner to potentially choose to recompute collectives (and agree on the binary "recompute-or-save" choice across all ranks)
(3) even without an `spmd_mode`, users can override this choice by using `torch.utils.checkpoint()` in their user code. User checkpointing generally always overrides the partitioner, and this should be safe because we expect the user to apply checkpointing consistently across ranks
Pull Request resolved: https://github.com/pytorch/pytorch/pull/147561
Approved by: https://github.com/zou3519
Fixes https://github.com/pytorch/pytorch/issues/144095
open to suggestions: the `hint_int(..., fallback=...)` API feels like a bit of a footgun, because:
(1) we use the same guess for every unbacked symint (both symbols, and compound expressions)
(2) the user may have established some relationship between some unbacked symints that we are not taking into account.
I'm not sure how real of an issue (2) is - is it common to e.g. generate two unbacked symints, and then add a runtime assert that they are unequal?
Instead I did something simpler that's just enough to fix the linked issue: if we have a sympy expression containing an unbacked symbol (e.g. `u0 + 1`), then the partitioner will now fill in the symbol with our guess instead of the expression (plugging in `u0=4096` gets us 4097). This was important for an internal custom op, that had some logic like this:
```
def custom_op(x: [u0], y: [u0 + 1]):
assert x.shape[0] = y.shape[0] - 1
...
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144097
Approved by: https://github.com/laithsakka
This was added in https://github.com/pytorch/pytorch/pull/126320. It's a very nice feature, which can be used to predict memory usage for different budget values.
However, it had some limitations, notably in terms of resolution (it only sampled 21 points across the whole range thus missed many threshold values) and in distributed settings.
Here I fix those by using recursive binary searches to identify all thresholds (up to a resolution of 1e-3, which can be made configurable) and output them in SVG (to be able to discern different points), plus I add the rank to the filename and store it in a user-define directory.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/148678
Approved by: https://github.com/Chillee, https://github.com/fmassa
TODO:
- [x] Add handling for when forward is invoked multiple times without invoking backward, so that the fwd/backward states are out of sync
- [x] Update rng state initialization to take from correct device
- [x] Tests
- [x] handling of retain_graph
- [x] respect fallback random
Fix for https://github.com/pytorch/pytorch/issues/130123.
Updates the aot_eager and cudagraph compilation of `run_and_save_rng_state` to use the new mechanism added by https://github.com/pytorch/pytorch/pull/114068 for CUDAGraph safe rng states.
We have a pair of rng states for the fwd and backward respectively. In both forward and backward the rng op will get run with `graphsafe_run_with_rng_state` which takes in RNG state and it hooks onto the current RNG generator before running the operator. The rng states for fwd/backward are initialized with the same value. We ensure that for any given run of the forward, the corresponding backward run will have the same rng states for the op as was observed in the forward.
```
===== Forward graph 1 =====
/data/users/eellison/pytorch/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
def forward(self, primals_1: "f32[4, 4][4, 1]cuda:0", primals_2: "f32[4, 4][4, 1]cuda:0", fwd_rng_state_0):
sin: "f32[4, 4][4, 1]cuda:0" = torch.ops.aten.sin.default(primals_1)
# No stacktrace found for following nodes
graphsafe_run_with_rng_state = torch.ops.higher_order.graphsafe_run_with_rng_state(torch.ops.aten.rand.default, [4, 4], dtype = torch.float32, device = device(type='cuda', index=0), pin_memory = False, rng_state = fwd_rng_state_0); fwd_rng_state_0 = None
...
===== Backward graph 1 =====
def forward(self, primals_1: "f32[4, 4][4, 1]cuda:0", primals_2: "f32[4, 4][4, 1]cuda:0", tangents_1: "f32[4, 4][4, 1]cuda:0", bwd_rng_state_0):
sin: "f32[4, 4][4, 1]cuda:0" = torch.ops.aten.sin.default(primals_1)
# No stacktrace found for following nodes
graphsafe_run_with_rng_state = torch.ops.higher_order.graphsafe_run_with_rng_state(torch.ops.aten.rand.default, [4, 4], dtype = torch.float32, device = device(type='cuda', index=0), pin_memory = False, rng_state = bwd_rng_state_0); bwd_rng_state_0 = None
```
There is some extra complication when a user either calls backward with retain_graph, or calls the backward in a different order as they called the forward. If a user has state fwd_rng_state0, bwd_rng_state0 and calls:
- fwd0: fwd_rng_state0 -> fwd_rng_state1
- fwd1: fwd_rng_state1 -> fwd_rng_state2
- bwd1
- bwd0
Then naively, when bwd1 is invoked the bwd rng states would not be equal to the same states that were observed in fwd1. I added handling of this in the aot runtime wrappers to detect pending backward invocations, and the current position of the bwd rng states, and to update when necesssary.
Other notes:
Because nodes which appear later in the forward appear earlier in the backward, we need a separate rng state for each operator. If we reused the rng across ops, the forward and backward would be run with different rng states. I.e., not applied in the same order.
Questions for reviewers:
This does change numerics, bc the rng of the op is now taken from the input rng state instead of whatever the rng would be midway through running the graph. Technically, we only need this for cuda graph. But, I'd prefer to not have a rng divergence just for cudagraph. I am making it respect `fallback_random`.
Edit: decided to apply to non cudagraphs as well, so long as fallback_random is not set
I'm initializing the rng states by cloning the current state. If you had something like 5 different rands in the model with the same shape, theyd all get the same value. This doesn't seem great. I could use some other initialization scheme like taking seed from graph position, or etc etc. Not sure. Let me know thoughts.
Edit: updated to be taken from randint()
Update: initializing rng states from torch.randint..
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146878
Approved by: https://github.com/anijain2305, https://github.com/bdhirsh
TODO:
- [x] Add handling for when forward is invoked multiple times without invoking backward, so that the fwd/backward states are out of sync
- [x] Update rng state initialization to take from correct device
- [x] Tests
- [x] handling of retain_graph
- [x] respect fallback random
Fix for https://github.com/pytorch/pytorch/issues/130123.
Updates the aot_eager and cudagraph compilation of `run_and_save_rng_state` to use the new mechanism added by https://github.com/pytorch/pytorch/pull/114068 for CUDAGraph safe rng states.
We have a pair of rng states for the fwd and backward respectively. In both forward and backward the rng op will get run with `graphsafe_run_with_rng_state` which takes in RNG state and it hooks onto the current RNG generator before running the operator. The rng states for fwd/backward are initialized with the same value. We ensure that for any given run of the forward, the corresponding backward run will have the same rng states for the op as was observed in the forward.
```
===== Forward graph 1 =====
/data/users/eellison/pytorch/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
def forward(self, primals_1: "f32[4, 4][4, 1]cuda:0", primals_2: "f32[4, 4][4, 1]cuda:0", fwd_rng_state_0):
sin: "f32[4, 4][4, 1]cuda:0" = torch.ops.aten.sin.default(primals_1)
# No stacktrace found for following nodes
graphsafe_run_with_rng_state = torch.ops.higher_order.graphsafe_run_with_rng_state(torch.ops.aten.rand.default, [4, 4], dtype = torch.float32, device = device(type='cuda', index=0), pin_memory = False, rng_state = fwd_rng_state_0); fwd_rng_state_0 = None
...
===== Backward graph 1 =====
def forward(self, primals_1: "f32[4, 4][4, 1]cuda:0", primals_2: "f32[4, 4][4, 1]cuda:0", tangents_1: "f32[4, 4][4, 1]cuda:0", bwd_rng_state_0):
sin: "f32[4, 4][4, 1]cuda:0" = torch.ops.aten.sin.default(primals_1)
# No stacktrace found for following nodes
graphsafe_run_with_rng_state = torch.ops.higher_order.graphsafe_run_with_rng_state(torch.ops.aten.rand.default, [4, 4], dtype = torch.float32, device = device(type='cuda', index=0), pin_memory = False, rng_state = bwd_rng_state_0); bwd_rng_state_0 = None
```
There is some extra complication when a user either calls backward with retain_graph, or calls the backward in a different order as they called the forward. If a user has state fwd_rng_state0, bwd_rng_state0 and calls:
- fwd0: fwd_rng_state0 -> fwd_rng_state1
- fwd1: fwd_rng_state1 -> fwd_rng_state2
- bwd1
- bwd0
Then naively, when bwd1 is invoked the bwd rng states would not be equal to the same states that were observed in fwd1. I added handling of this in the aot runtime wrappers to detect pending backward invocations, and the current position of the bwd rng states, and to update when necesssary.
Other notes:
Because nodes which appear later in the forward appear earlier in the backward, we need a separate rng state for each operator. If we reused the rng across ops, the forward and backward would be run with different rng states. I.e., not applied in the same order.
Questions for reviewers:
This does change numerics, bc the rng of the op is now taken from the input rng state instead of whatever the rng would be midway through running the graph. Technically, we only need this for cuda graph. But, I'd prefer to not have a rng divergence just for cudagraph. I am making it respect `fallback_random`.
Edit: decided to apply to non cudagraphs as well, so long as fallback_random is not set
I'm initializing the rng states by cloning the current state. If you had something like 5 different rands in the model with the same shape, theyd all get the same value. This doesn't seem great. I could use some other initialization scheme like taking seed from graph position, or etc etc. Not sure. Let me know thoughts.
Edit: updated to be taken from randint()
Update: initializing rng states from torch.randint..
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146878
Approved by: https://github.com/anijain2305, https://github.com/bdhirsh
In an attempt to make partitioning more deterministic, change all sets in partitioners.py to OrderedSets. Note that this change does not fix the non-determinism we're seeing in the internal model. But let's at least eliminate this potential source of non-determinism before investigating any changes to the mincut approach?
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146102
Approved by: https://github.com/oulgen
Fixes https://github.com/pytorch/pytorch/issues/143876
Open to other suggestions - we have an invariant that all nodes in our ATen graphs should have a `meta['val']` field, but I don't think this is actually true in all cases, so I just hardcoded the invariant to ignore `_assert_scalar()` (which is a "special" op used in dynamic shapes for runtime asserts, and doesn't have a meta['val'] field)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143877
Approved by: https://github.com/zou3519
Fixes https://github.com/pytorch/pytorch/issues/145081
This looks like it was a source of quadratic compile times in the torchtitan CP graphs. There's some code in the partitioner that iteratively adds users of a node to a heap, and pops the earliest user. If you have long parallel chains of fusible ops that all eventually feed into some shared ops, then this can result in:
(1) a node getting added to the heap many times
(2) each time we pop that node, we add (duplicates of) each of that node users to the heap
(3) repeat with each user
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145082
Approved by: https://github.com/xmfan
Summary:
## Why
To make it possible to run torch dispatch mode inside compiled modules. This is to enable running MemoryTrackerMode (in next diff) to collect memory usage of compiled modules.
## What
Add a backend aot_eager_decomp_partition_with_mode.
Add an enable_log to the backend to control the compilation logging (which can be very verbose and slow the run of mode)
Test Plan:
unittest
E2e tested in the next diff which shows the memory read from the mode passed to this backend is very close to the actual job's memory snapshot.
Differential Revision: D67227144
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143250
Approved by: https://github.com/bdhirsh
# Summary:
Full Context: https://docs.google.com/document/d/1-j5KSbfGFJQcH4sYh7BIeJXso3zYzl5G5yFQqXdKx_o/edit?usp=sharing
tl;dr
This change introduces classes which help determine a dynamic memory budget. This will mostly be helpful for models with many implicit graph breaks.
---
New Classes:
*GraphInfoProvider*
* Takes the joint_graph as well as the input memories and runtimes and parses the graph + values into usable forms for the SolverEvaluator.
*KnapsackEvaluator*
* Provides a function: Given all of the four inputs (solver function as a callable, max_dynamic_memory_budget, min_dynamic_memory_budget, dynamic_memory_budget_pareto_granularity) it returns an approximation of the knee point of the pareto distribution.
# Test Plan:
### LintRunner
LintRunner Output: P1700445547
### Unit Tests
```
$ buck test @mode/opt //caffe2/test/functorch:test_ac_knapsack
`@mode/opt` was specified, but not found. Using file at `//mode/opt`.
This behavior is being deprecated. Please use `"@//mode/opt"` instead
File changed: fbcode//caffe2/.ruff_cache/0.7.4/.tmpB6PmDS
File changed: fbsource//xplat/caffe2/test/functorch/test_ac_knapsack.py
File changed: fbcode//caffe2/.ruff_cache/0.7.4/.tmpyjCiPn
20 additional file change events
Buck UI: https://www.internalfb.com/buck2/414ead46-9ede-4192-8e1a-5d3c52bdb9cc
Test UI: https://www.internalfb.com/intern/testinfra/testrun/6473924710342830
Network: Up: 0B Down: 0B (reSessionID-159794b9-9d61-477e-8e63-9bdeaa537dca)
Analyzing targets. Remaining 0/214
Executing actions. Remaining 0/6933 0.1s exec time total
Command: test. Finished 1 local
Time elapsed: 18.5s
Tests finished: Pass 15. Fail 0. Fatal 0. Skip 0. Build failure 0
```
### Test Run
Updated the config:
```
activation_memory_budget_solver: DYNAMIC_MEMORY_BUDGET_DP
```
Confirming proper execution via: [aps-fb_fm_v4_768_01_dynamic-2a792ba8af](https://www.internalfb.com/mlhub/pipelines/runs/mast/aps-fb_fm_v4_768_01_dynamic-2a792ba8af?job_attempt=0&version=0&env=PRODUCTION)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143539
Approved by: https://github.com/jansel
Summary:
Original Context Doc: https://docs.google.com/document/d/1Gv5ZqN3UY7kTuCUd0JLU3L_onwVIr2vd3AiZyim1Muc/edit?disco=AAABZX_tqdk
### Changes
This diff restructures the Partitioners.py file in the _functorch package.
* Moves the three knapsack problem algorithems (greedy, ilp, dp) into a separate file
Test Plan:
### Unit Testing
```
$ buck test mode/opt //caffe2/test/functorch:test_ac
File changed: fbsource//xplat/caffe2/test/functorch/TARGETS
File changed: fbsource//xplat/caffe2/test/functorch
File changed: fbsource//xplat/caffe2/test
7 additional file change events
Soft Error: source_directory_includes_subpackage: Directory `v2.17.1-1` of package `fbsource//third-party/nccl` may not cover any subpackages, but includes subpackage `v2.17.1-1/src/tests`.
Soft Error: source_directory_includes_subpackage: Directory `v2.18.3-1` of package `fbsource//third-party/nccl` may not cover any subpackages, but includes subpackage `v2.18.3-1/src/tests`.
Soft Error: source_directory_includes_subpackage: Directory `v2.19.3-1` of package `fbsource//third-party/nccl` may not cover any subpackages, but includes subpackage `v2.19.3-1/src/tests`.
Buck UI: https://www.internalfb.com/buck2/a2f91f8a-5326-435e-9075-5af0de930b8b
Test UI: https://www.internalfb.com/intern/testinfra/testrun/7036874660108924
Network: Up: 28MiB Down: 3.5GiB (reSessionID-19af3d71-7528-448c-9126-5615d27b3bd7)
Jobs completed: 423656. Time elapsed: 3:57.2s.
Cache hits: 99%. Commands: 146147 (cached: 145758, remote: 317, local: 72)
Tests finished: Pass 8. Fail 0. Fatal 0. Skip 0. Build failure 0
```
### Tested on Local Training Run
```
CUDA_VISIBLE_DEVICES=5,6 AOT_PARTITIONER_DEBUG=1 PARTITIONER_MEMORY_BUDGET_PARETO=0 buck2 run mode/opt //aps_models/ads/icvr:icvr_launcher -- mode=local_fb_fm_v4 launcher.num_workers=2 2>&1 | tee log_2024-11-2320:41:50.757256__bento_trigger.txt
```
Output Summary Paste: P1685697066
Differential Revision: D65800097
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141614
Approved by: https://github.com/jansel, https://github.com/Chillee
When multiple nodes have similar sizes and are part of the `banned_nodes` (which is a `set` and not a `list`), there is non-determinism present in the partitioner due to sorting only by node-size.
This PR fixes this by also sorting by node name.
It would be good to add some tests, but I'm not sure about the best way to do it here.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141682
Approved by: https://github.com/Chillee, https://github.com/yf225
Summary: When AOT_PARTITIONER_DEBUG is set to 1 and debug logging is turned on we can now log the full input and output for each knapsack problem.
Differential Revision: D65633086
Pull Request resolved: https://github.com/pytorch/pytorch/pull/140757
Approved by: https://github.com/jansel
Summary:
Removes print statements and implements logging via the logging library.
Hopefully this will allow more control on the level of logging when running models.
Test Plan:
```
AOT_PARTITIONER_DEBUG=1 buck2 run @mode/opt //aps_models/ads/icvr:icvr_launcher -- mode=local_fb_fm_v4 launcher.num_workers=2
```
Resulting output paste: P1674535630
* Full logs paste: P1674535621
```
pastry P1674535621 | grep "functorch/partitioners.py" | pastry
```
Logging results: P1674549514
Differential Revision: D61678215
Pull Request resolved: https://github.com/pytorch/pytorch/pull/139782
Approved by: https://github.com/paryxyt, https://github.com/jansel
Summary: same as title. Plan is to pass a callable to the partitioner to perform custom autoAC via an ILP. This is the same as a previous diff D63714905 which was landed and then subsequently reverted by PyTorch Release Engineering because of a failing unit test (f7b8d36c28). We think the unit test is buggy, and we also fix the same.
Test Plan: tbd
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137785
Approved by: https://github.com/basilwong
Co-authored-by: Huy Do <huydhn@gmail.com>
Summary: same as title. Plan is to pass a callable to the partitioner to perform custom autoAC via an ILP. This is the same as a previous diff D63714905 which was landed and then subsequently reverted by PyTorch Release Engineering because of a failing unit test (f7b8d36c28). We think the unit test is buggy, and we also fix the same.
Test Plan: tbd
Differential Revision: D64246495
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137785
Approved by: https://github.com/basilwong
Summary: making it so that the config can pass `config.activation_memory_budget_solver` as a callable method and then that callable is invoked to determine the set of saved/recomputed nodes.
Test Plan: tbd
Reviewed By: Chillee, basilwong
Differential Revision: D63714905
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137314
Approved by: https://github.com/eellison, https://github.com/basilwong
Co-authored-by: Parikshit Shah <parikshit@meta.com>
Previously if we had a graph like:
```
triton_kernel_wrapper_functional_proxy = triton_kernel_wrapper_functional(...)
getitem: "f32[3][1]cuda:0" = triton_kernel_wrapper_functional_proxy['out_ptr']
getitem_1: "f32[3][1]cuda:0" = triton_kernel_wrapper_functional_proxy['out2_ptr']
sigmoid: "f32[3][1]cuda:0" = torch.ops.aten.sigmoid.default(getitem_1)
mul: "f32[3][1]cuda:0" = torch.ops.aten.mul.Tensor(tangents_1, sigmoid)
```
The partitioner would assume that the `sigmoid()` could be fused into either its user (the pointwise mul), or its producer (the user triton kernel). This could lead to a bad partitioning:
(1) If the partitioner thinks we can fuse the sigmoid with its producer triton kernel, we would keep the sigmoid compute in the forward, and have to generate two separate kernels in the forward (user triton kernel, dedicated sigmoid kernel)
(2) if the partitioner puts the sigmoid in the backward instead, we could fuse it with an existing backward kernel (the mul with a tangent)
Reviewed By: embg
Differential Revision: D63551393
Pull Request resolved: https://github.com/pytorch/pytorch/pull/136878
Approved by: https://github.com/zou3519
If node is AC region output and has a backward hook on it, we intentionally choose to save it.
This is to work around circular dependencies in Traceable FSDP2+AC.
Example:
```
out = fully_shard(utils.checkpoint(module))(x)
norm_out = layer_norm(out)
```
and there is a circular dependency:
1. In backward, grad_input of layer_norm aka. `out_grad` is actually dependent on `out`.
2. `out` depends on `out`'s backward hook created by FSDP2 (which does all-gather for `module` weights) in order to be recomputed.
3. `out`'s FSDP2 backward hook, as is the case for all eager backward hooks, depends on `out_grad` -> circular dependency with (1)!
Solution: check whether `out` has a backward hook, and if so, intentionally save `out` in forward graph outputs. With this, we can break the above circular dependency.
----
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135727
Approved by: https://github.com/Chillee
Using `fsdp.set_` for unsharded_param inplace update causes difficult-to-debug errors when enabling Traceable FSDP2 on TorchTune models. In this PR, we change it to use `fsdp.copy_` which fixes the error and also strictly follows eager semantics (i.e. if user explictly stores an alias of the unsharded_param during execution of the user's module code, that alias will get updated correctly when the unsharded_param is copy_ into; whereas if we just swap out unsharded_param storage via set_, that user-saved alias will not get updated, which is not good).
This PR also implements the graph pass to remove the resizes and copy if there is a resize_(full) -> copy_ -> resize_(0) pattern.
------
Test commands:
- `pytest -rA test/distributed/_composable/fsdp/test_fully_shard_compile.py::TestFullyShardCompile::test_transformer_backend_inductor`
- `pytest -rA test/distributed/_composable/fsdp/test_fully_shard_compile.py::TestFullyShardCompile::test_nested_fully_shard_backend_inductor`
- `pytest -rA test/distributed/_composable/fsdp/test_fully_shard_compile.py::TestFullyShardCompile::test_trace_fsdp_copy_`
- `pytest -rA test/dynamo/test_repros.py::ReproTests::test_partitioner_cse_respects_mutation_boundaries`
- `pytest -rA test/dynamo/test_repros.py::ReproTests::test_fsdp_set_input_mutation_applied_when_input_gets_no_gradients`
- `pytest -rA test/inductor/test_pattern_matcher.py::TestPatternMatcher::test_mutation_op_matching`
- `python test/inductor/test_distributed_patterns.py DistributedPatternTests.test_fake_distributed_aot_eager`
- `PYTORCH_OPINFO_SAMPLE_INPUT_INDEX=1 PYTORCH_TEST_WITH_CROSSREF=1 python test/functorch/test_aotdispatch.py TestEagerFusionOpInfoCPU.test_aot_autograd_exhaustive_norm_cpu_float32`
- `python test/distributed/test_inductor_collectives.py TestCollectivesInductor.test_backwards`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133730
Approved by: https://github.com/bdhirsh
- The new implementation (auto_functionalized_v2) is enabled by default but can be disable
using an inductor flag.
- In export mode the old implementation is used.
**Motiviation**
Previous functionalization fails to re-inplace arguments when they are view over other tensors.
see issue https://github.com/pytorch/pytorch/issues/131192
The new functionalization is easier to re-inplace for views.
**A) Functionalizations pass**
consider a program:
```
func(t)
x = t[0]
y = t[1]
foo(x, y) # custom operator with x, y mutable
return (x, y, t)
```
- To functionalize `foo` we generate a function that operates on the base tensors of the inputs; (x.base() and y.base())
and record how to regenerates the views out of the base for argument x by recording ```ViewInfo=(x.base(), x.size(), x.stride, x,storage_offset())```
- Due to some limitations on the torch.export arguments format, we have to generate alot of arguments, but this is something we can simplify in the future, for the example above we get the following function.
```
auto_functionalized = torch.ops.higher_order.auto_functionalized(torch.ops.mylib.foo.default,
_x_base_index = 0, _x_size = (), _x_stride = (), _x_storage_offset = 0 ,
_y_base_index = 0,_y_size = (), _y_stride = (), _y_storage_offset = 1 ,
_all_bases = [arg0_1])
```
- In the code above:
- _all_bases[t]: refers to a unique set of bases for all foo arguments.
- for each argument x we have _x_base_index, _x_size, _x_stride, _x_storage_offset that can be used to (1) regenerate x from _all_bases[_x_base_index] or a copy of a the base.
- the output of auto_functionalized is foo output , followed by x tensors one for each base in _all_bases, that is a copy of the base tensor after observing the mutations of the all the arguments that are views of that base.
- for each use of a base in _all_bases or a view of it , that are after the call to foo, replace it with a view of the new output
for the function above after functionalization we get :
```
def forward(self, arg0_1: "f32[2][1]cpu"):
auto_functionalized = torch.ops.higher_order.auto_functionalized(torch.ops.mylib.foo.default, _x_base_index = 0, _x_size = (), _x_stride = (), _x_storage_offset = 0, _y_base_index = 0, _y_size = (), _y_stride = (), _y_storage_offset = 1, _all_bases = [arg0_1])
getitem_1: "f32[2][1]cpu" = auto_functionalized[1]; auto_functionalized = None
copy_: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg0_1, getitem_1); arg0_1 = copy_ = None
# No stacktrace found for following nodes
select_2: "f32[][]cpu" = torch.ops.aten.select.int(getitem_1, 0, 0)
select_3: "f32[][]cpu" = torch.ops.aten.select.int(getitem_1, 0, 1); getitem_1 = None
return (select_2, select_3)
```
**B) Semantics of auto_functionalize**
The new semantics of auto_functionalize is as the following:
1. For each base in all_bases, copy the base and create all_bases copies. (if a base is inplaced we do not need to copy it)
2. For each arg, regenerate the arg from the copy of its base using the view information above.
3. return the original foo output followed by the new bases.
**C) Re-inplace pass**
since auto_functionalize not copy the bases, what we actually inplace is the bases.
(run just like before but on the beses instead of args).
1. For each base b in _all_bases check if there is any use of base (or its aliases/views) after auto_functionalize (before its overwritten with a copy) if there is not any, then inplace it (avoid copying it in step 1 above).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134409
Approved by: https://github.com/zou3519
Support of effectful operations in backward:
1/ AOTD collects metadata from forward fn only, so we can have usage of effectful ops in backward, that were not used in forward => Allowing tokens discovery during joint function .
FunctionalTensorMode holds _tokens, in Joint function after tracing forward we memoize _tokens as `_tokens_forward_output`.
2/ Tokens are added as primals inputs (forward) in EffectTokensWrapper.
Tokens that will be used in backward are in partitioner saved values. We do not have control on which positions they are saved in forward outputs.
2/ If new tokens discovered in backward after tracing joint_fn, the result graph will be manually added in the end of primals.
_aot_autograd/utils.py
3/ All effectful ops during backward are marked with 'must_be_in_backward' partitioner_tag, to prevent partiitoner to place them in forward.
For that functional_tensor_mode got new optional state `self._effects_partitioner_tag` for effectful ops, to set after tracing forward.
There are additional changes in partitioner to improve functionality of 'must_be_in_backward'
4/ Unlift tokens now should run for both forward and backward.
- As saved for backward tokens are placed on non static places - we identify input and output tokens to erase, by input and output of `with_effects` operation
- In forward we can have input tokens, discovered in backward, that are not used in with_effects ops in forward, but saved for backward. We identify them by position in forward inputs.
5/ Adding aot debug logging for graphs before unlifting and before adding additional primal for backward tokens.
Tests:
```
python test/higher_order_ops/test_with_effects.py
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132638
Approved by: https://github.com/bdhirsh
Part of #134054.
This corresponds to the pytorch mypy changes from D61493706. Updating takes so
long and touches so many files that it's impossible to land as a whole without conflicting with some other intermediate change.
So landing these 'type: ignore' for pytorch in advance of them actually being needed.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134202
Approved by: https://github.com/Skylion007
move benchmarking out of `torch._inductor.runtime.runtime_utils` and into `torch._inductor.runtime.benchmarking`, and prefer this path over directly accessing Triton's benchmarking
Fixes #ISSUE_NUMBER
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132827
Approved by: https://github.com/eellison
Summary:
When using activation_memory_budget for float8 training, two issues were noticed:
- When `aggressive_options` (https://fburl.com/code/m1yoskxw) is called , all fp8 gemms (the scaled_mm op) are saved for recomputation.
- After adding "scaled_mm" in the `compute_intensive_ops`, we got the next error from `estimate_runtime`: `mat2 must be col_major` from `meta_scaled_mm`.
To fix it, modified `materialize_arg` to also include the stride of the original tensor.
Test Plan: Run float8 training with `activation_memory_budget`.
Differential Revision: D60777297
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132687
Approved by: https://github.com/Chillee