Commit Graph

1062 Commits

Author SHA1 Message Date
Xuehai Pan
e8fadba28c [pytree] add treespec_{leaf,tuple,dict} functions for args_spec modification (#160843)
The goal of this PR is to provide a standard way to create simple treespec instances and hide the implementation details of the `PyTreeSpec` class.

Changes:

1. Add function `treespec_leaf()` to replace `LeafSpec()`.
2. Add function `treespec_tuple(...)` and `treespec_dict(...)` to create treespec for `tuple` / `dict` which is used for `*args` / `**kwargs`. This avoids direct modification to `treespec` instances that rely on the implementation details of the `PyTreeSpec` class.
3. Change `len(spec.children_specs)` to `spec.num_children`.
4. Change `isinstance(spec, LeafSpec)` to `spec.is_leaf()`.

------

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160843
Approved by: https://github.com/mlazos
2025-11-01 04:12:11 +00:00
James Wu
08f4535378 Refactor AOTAutogradCacheEntry into AOTAutogradResult (#166656)
This PR refactors the name AOTAutogradCacheEntry into AOTAutogradResult, and BundledAOTAutogradCacheEntry into BundledAOTAutogradResult. It also moves all coresponding files to a new file, `aot_autograd_result`, which is analogous to `output_code.py` from Inductor.

Having all these be called cache entries made sense when all we used them for was caching. But with AOT compile using BundledAOTAutogradCacheEntry, we want a more generalized naming structure.

This is a no-op change,  and all existing tests should pass.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166656
Approved by: https://github.com/zhxchen17
ghstack dependencies: #166650
2025-10-31 18:54:09 +00:00
James Wu
30157d30f0 Add regional aot eager support to AOTAutogradCacheEntry (#166650)
This PR does two things:

- It genericizes `BundledAOTAutogradCacheEntry` to support *any* outputcode, not just CompiledFxGraphs
- It adds a brand new OutputCode for the `aot_eager_regional_inductor` backend, i.e. a graph module that has regional inductor components in it.

This allows BundledAOTAutogradCache to just integrate nicely with inductor out of the box, but more importantly, it allows the result of aot_autograd to be fully serializable when using `aot_eager_regional_inductor`. This will allow us to AOT precompile cases where we have an eager graph that has scooped up inductor bits.

It's a bit unfortunate that the naming makes BundledAOTAutogradCacheEntry sound like its primary use is for caching, but really the more common use is going to be as an AOTAutogradOutput. It may be worth revisiting how to refactor/rename these in a later PR:

- AOTAutogradCacheEntry -> AOTAutogradResult
- BundledAOTAutogradCacheEntry -> BundledAOTAutogradResult

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166650
Approved by: https://github.com/zhxchen17
2025-10-31 18:54:09 +00:00
IvanKobzarev
b470e59c38 partitioner option to ignore partitioner_tag for abstract usage (#166725)
Partitioner functionality is appealing to use in different scenarios (E.g. Autoparallel)

We have special logic about "partitioner_tag" from meta that is only needed for forward/backward split.

Adding optional argument to avoid it and do only generic split based on inputs/outputs.

Potentially we want to make `_extract_graph_with_inputs_outputs` without underscore :)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166725
Approved by: https://github.com/bdhirsh
2025-10-31 18:50:02 +00:00
PyTorch MergeBot
85b85f6c2c Revert "[pytree] add treespec_{leaf,tuple,dict} functions for args_spec modification (#160843)"
This reverts commit 108bb224f7.

Reverted https://github.com/pytorch/pytorch/pull/160843 on behalf of https://github.com/atalman due to failing internal builds ([comment](https://github.com/pytorch/pytorch/pull/160843#issuecomment-3474354428))
2025-10-31 18:31:32 +00:00
PyTorch MergeBot
5bcfdae71d Revert "Make PT2 compile backprop through custom op without autograd key a hard error (#166367)"
This reverts commit 4acc66f119.

Reverted https://github.com/pytorch/pytorch/pull/166367 on behalf of https://github.com/atalman due to internal build failures ([comment](https://github.com/pytorch/pytorch/pull/166367#issuecomment-3473150269))
2025-10-31 13:44:05 +00:00
Xuehai Pan
108bb224f7 [pytree] add treespec_{leaf,tuple,dict} functions for args_spec modification (#160843)
The goal of this PR is to provide a standard way to create simple treespec instances and hide the implementation details of the `PyTreeSpec` class.

Changes:

1. Add function `treespec_leaf()` to replace `LeafSpec()`.
2. Add function `treespec_tuple(...)` and `treespec_dict(...)` to create treespec for `tuple` / `dict` which is used for `*args` / `**kwargs`. This avoids direct modification to `treespec` instances that rely on the implementation details of the `PyTreeSpec` class.
3. Change `len(spec.children_specs)` to `spec.num_children`.
4. Change `isinstance(spec, LeafSpec)` to `spec.is_leaf()`.

------

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160843
Approved by: https://github.com/mlazos
2025-10-31 10:33:16 +00:00
Yuanyuan Chen
694db5f549 Use 'is' in callable comparisons (#166624)
Just like we use `is/is not` for class comparisons, it is generally advised to use `is/is not` for comparisons against torch functions.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166624
Approved by: https://github.com/Lucaskabela, https://github.com/Skylion007
2025-10-30 19:00:09 +00:00
Edward Z. Yang
4acc66f119 Make PT2 compile backprop through custom op without autograd key a hard error (#166367)
Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166367
Approved by: https://github.com/bdhirsh
2025-10-30 18:43:07 +00:00
PyTorch MergeBot
3f1824742c Revert "Fix comparing inductor actual strides vs bw graph for activations should not throw DDE. (#166277)"
This reverts commit b2a0f90501.

Reverted https://github.com/pytorch/pytorch/pull/166277 on behalf of https://github.com/atalman due to Breaks internal executorch tests ([comment](https://github.com/pytorch/pytorch/pull/166277#issuecomment-3468696623))
2025-10-30 15:49:23 +00:00
Yuanyuan Chen
2de4cf2102 [1/N] Remove unused loop variables (#166258)
This PR removes unused loop variables.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166258
Approved by: https://github.com/Lucaskabela, https://github.com/mlazos
2025-10-30 12:22:25 +00:00
linhaifeng
369f2d6951 [3/N] fix typo in other folders (#166606)
fix typo in other folders

#166374
#166126

_typos.toml
```bash
[files]
extend-exclude = ["tools/linter/dictionary.txt"]
[default.extend-words]
nd = "nd"
arange = "arange"
Nd = "Nd"
GLOBALs = "GLOBALs"
hte = "hte"
iy = "iy"
PN = "PN"
Dout = "Dout"
optin = "optin"
gam = "gam"
PTD = "PTD"
Sur = "Sur"
nin = "nin"
tme = "tme"
inpt = "inpt"
mis = "mis"
Raison = "Raison"
ouput = "ouput"
nto = "nto"
Onwer = "Onwer"
callibrate = "callibrate"
ser = "ser"
Metdata = "Metdata"
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166606
Approved by: https://github.com/ezyang
2025-10-30 10:30:40 +00:00
Laith Sakka
b2a0f90501 Fix comparing inductor actual strides vs bw graph for activations should not throw DDE. (#166277)
Fix https://github.com/pytorch/pytorch/issues/163894

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166277
Approved by: https://github.com/Lucaskabela
2025-10-30 00:34:05 +00:00
PyTorch MergeBot
972030fe2e Revert "[pytree] add treespec_{leaf,tuple,dict} functions for args_spec modification (#160843)"
This reverts commit 284716a691.

Reverted https://github.com/pytorch/pytorch/pull/160843 on behalf of https://github.com/atalman due to failing internal torchrec test' ([comment](https://github.com/pytorch/pytorch/pull/160843#issuecomment-3464647878))
2025-10-29 22:46:48 +00:00
PyTorch MergeBot
1dd6b76914 Revert "[1/N] Remove unused loop variables (#166258)"
This reverts commit 76b2c37045.

Reverted https://github.com/pytorch/pytorch/pull/166258 on behalf of https://github.com/atalman due to breaks test/distributed/test_serialization.py::TestSerialization::test_weights_only [GH job link](https://github.com/pytorch/pytorch/actions/runs/18894311802/job/53929321703) [HUD commit link](76b2c37045) ([comment](https://github.com/pytorch/pytorch/pull/166258#issuecomment-3460964612))
2025-10-29 11:10:37 +00:00
Xuehai Pan
284716a691 [pytree] add treespec_{leaf,tuple,dict} functions for args_spec modification (#160843)
The goal of this PR is to provide a standard way to create simple treespec instances and hide the implementation details of the `PyTreeSpec` class.

Changes:

1. Add function `treespec_leaf()` to replace `LeafSpec()`.
2. Add function `treespec_tuple(...)` and `treespec_dict(...)` to create treespec for `tuple` / `dict` which is used for `*args` / `**kwargs`. This avoids direct modification to `treespec` instances that rely on the implementation details of the `PyTreeSpec` class.
3. Change `len(spec.children_specs)` to `spec.num_children`.
4. Change `isinstance(spec, LeafSpec)` to `spec.is_leaf()`.

------

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160843
Approved by: https://github.com/mlazos
2025-10-29 09:16:24 +00:00
Yuanyuan Chen
76b2c37045 [1/N] Remove unused loop variables (#166258)
This PR removes unused loop variables.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166258
Approved by: https://github.com/Lucaskabela, https://github.com/mlazos
2025-10-29 01:34:15 +00:00
Maggie Moss
31e42eb732 Fix pyrefly ignore syntax (#166438)
Reformats pyrefly ignore suppressions so they only ignore one error code.

pyrefly check
lintrunner

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166438
Approved by: https://github.com/Skylion007
2025-10-29 00:02:21 +00:00
Laith Sakka
009ea77234 Remove not needed code path. (#166278)
I accepted a PR that added this code, but re-examining it now, I'm questioning the approach. It seems like we're working around an issue with the inductor generating incorrect sizes. A comment suggests it might be related to unsqueezed u0 values. Removing this code didn't cause any failures, so I'll take it out and address the root issue if it arises.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166278
Approved by: https://github.com/Lucaskabela
2025-10-28 19:03:22 +00:00
Tugsbayasgalan Manlaibaatar
6096c0fc74 Export should use aot_export_joint_with_descriptors (#165931)
This diff moves export run_decompositions to use aot_export_joint_with_descriptors instead of aot_export_module. Doing so, i ran into 2 main bugs:
1) aot_export_joint_with_descriptors don't correctly pass in record_nn_module_stack flag that is needed to populate nn_module_stack by switching the internal tracer.
2) When creating symint with negative inputs, we need to pass in positive=False. This didn't matter before because aot_autograd directly returns integer inputs instead of creating symint.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165931
Approved by: https://github.com/zhxchen17
2025-10-27 19:33:33 +00:00
Yuanyuan Chen
a60d9e1f6d Fix flake8 B028 warnings (#166224)
This PR fixes flake8 B028 warning by specifying stacklevel=2 in `warnings.warn`. The advantage is that users can know more contextual information about PyTorch warnings.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166224
Approved by: https://github.com/ezyang
2025-10-26 06:18:55 +00:00
Maggie Moss
eb83c3ca23 Clean up unused Pyrefly suppressions (#166178)
Cleaning up ignores that are no longer needed in the repo and adding select suppressions so the main branch is clean.

test plan:
`lintrunner -a`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166178
Approved by: https://github.com/oulgen
2025-10-25 05:32:21 +00:00
Jason Ansel
78bcfcf870 [fx] Optimize torch.fx.Node.replace_all_uses_with (#165889)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165889
Approved by: https://github.com/aorenste
2025-10-25 03:44:41 +00:00
PyTorch MergeBot
690c8c13b9 Revert "Export should use aot_export_joint_with_descriptors (#165931)"
This reverts commit 882b834082.

Reverted https://github.com/pytorch/pytorch/pull/165931 on behalf of https://github.com/clee2000 due to breaking internal tests D85084301 for test_auto_functionalize?  I checked that they did run on OSS CI so I'm not entirely sure whats going on, I assume its the IS_FBCODE stuff ([comment](https://github.com/pytorch/pytorch/pull/165931#issuecomment-3443887361))
2025-10-24 16:02:20 +00:00
Tugsbayasgalan Manlaibaatar
882b834082 Export should use aot_export_joint_with_descriptors (#165931)
This diff moves export run_decompositions to use aot_export_joint_with_descriptors instead of aot_export_module. Doing so, i ran into 2 main bugs:
1) aot_export_joint_with_descriptors don't correctly pass in record_nn_module_stack flag that is needed to populate nn_module_stack by switching the internal tracer.
2) When creating symint with negative inputs, we need to pass in positive=False. This didn't matter before because aot_autograd directly returns integer inputs instead of creating symint.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165931
Approved by: https://github.com/zhxchen17
2025-10-23 22:42:11 +00:00
James Wu
35180fafee Allow GraphPickler to pickle graph modules containing AOTCompiled subgraphs (#165844)
This PR allows GraphPickler to pickle aot_eager graph modules that have regional inductor bits in them, with a few exceptions:
- FlexAttentionBackward isn't marked cacheable, so those tests don't work immediately since we're not sure how to serialize it. But it's safe to serialize/cache, so the next PR fixes those unit tests.
- It seems that when reloading a GraphPickled object, we don't recompile subgraphs. Will investigate this in a future PR

All unit tests in test_regional_inductor are parameterized so that we try serializing and deserializing the returned graph module before returning.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165844
Approved by: https://github.com/oulgen
ghstack dependencies: #165843
2025-10-22 17:03:49 +00:00
Yidi Wu
af4ba78543 [scan x vmap] support scan in vmap (#165580)
This is required by the chunked_with_scan work where two nested vmap(vmap) with chunk sizes > 1 are invoked, which produces a scan-> vmap -> scan -> vmap chain and we need to handle the case of vmap(scan) and scan(vmap).

The way we handle vmap(scan) is to turn it into scan(vmap(combine_fn)). The idea being that the combine_fn no longer do the combine_fn for a single slice, it vmaps over the combine_fn and do multiple combine_fns in one step. We need to need know how combine_fn propagates the batched tensor and what are the batched dims of the output. For this purpose, we use restore_vmap to give us the out_dims information.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165580
Approved by: https://github.com/zou3519
ghstack dependencies: #165675
2025-10-22 09:46:00 +00:00
angelayi
6aed378958 [export] Handle kwargs better in aot_export_joint_with_descriptors (#165334)
fx.Interpreter doesn't handle kwargs... not sure how this code worked previously

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165334
Approved by: https://github.com/tugsbayasgalan, https://github.com/ezyang
2025-10-21 15:53:05 +00:00
Avik Chaudhuri
259cb945f5 [stage 2c] make autograd and inference functions (#165668)
Add final stage of aot_stage2_compile for autograd and inference.

Differential Revision: D84844699

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165668
Approved by: https://github.com/zhxchen17, https://github.com/tugsbayasgalan
2025-10-20 23:50:31 +00:00
Shangdi Yu
efc277cac7 [annotation] add logging for debugging annotation (#165797)
Add logging for debugging annotation bugs. Log will show with `TORCH_LOGS="+annotation" `

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165797
Approved by: https://github.com/ezyang, https://github.com/Skylion007, https://github.com/SherlockNoMad
2025-10-20 21:27:38 +00:00
Yuanyuan Chen
3255e7872b Enable all flake8-logging-format rules (#164655)
These rules are enabled by removing existing suppressions.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164655
Approved by: https://github.com/janeyx99, https://github.com/mlazos
2025-10-19 00:59:28 +00:00
Yuanyuan Chen
1f43d17ce6 Fix self assignment (#165816)
This PR removes assignments of the form `var=var`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165816
Approved by: https://github.com/jansel
2025-10-18 18:51:52 +00:00
Shangdi Yu
cf3a787bbc [annotate] Annotate bw nodes before eliminate dead code (#165782)
Fixes https://github.com/pytorch/torchtitan/pull/1907

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165782
Approved by: https://github.com/SherlockNoMad
2025-10-18 01:54:31 +00:00
James Wu
dd3b48e85d Fix bug with serialization after AOTAutogradCache hit (#165474)
Fixes #165447

On AOTAutogradCache load, the serialization function we pick is just lambda: self, because the object itself is an AOTAutogradCacheEntry. However, this isn't safe, because `wrap_post_compile` will make `self` unserializable, since it needs to load triton kernels and stuff!

So instead, on AOTAutogradCache load, we preserve the bytes that were used to load the object to begin with, and return that object on a call to serialize(). This effectively makes it so that we save a copy of the pre-hydrated artifact, without needing to do an eager copy until someone actually calls `serialize`.

Test Plan:

Run

```py
import torch

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(2, 4)
        self.relu = torch.nn.ReLU()
        self.linear2 = torch.nn.Linear(4, 8)
    def forward(self, x):
        return self.linear2(self.relu(self.linear1(x)))

device = "cuda"
m = M().to(device)
sample_inputs = (torch.randn(2, 2, device=device),)
eager_out = m(*sample_inputs)

with torch._dynamo.config.patch("enable_aot_compile", True):
    compiled_fn_path = "./m.pt"
    compiled_fn = torch.compile(
        m,
        fullgraph=True
    ).forward.aot_compile((sample_inputs, {}))

    compiled_fn.save_compiled_function(compiled_fn_path)
    torch._dynamo.reset()
    with torch.compiler.set_stance("fail_on_recompile"):
        with open(compiled_fn_path, "rb") as f:
            loaded_fn = torch.compiler.load_compiled_function(f)

assert loaded_fn is not None

compiled_out = loaded_fn(m, *sample_inputs)

assert torch.allclose(eager_out, compiled_out)
```

twice, see that it succeeds.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165474
Approved by: https://github.com/yiming0416, https://github.com/zhxchen17
2025-10-17 17:47:24 +00:00
Maggie Moss
d795fb225a [RFC] Add pyrefly to lintrunner (#165179)
This will add pyrefly to lint runner as a warning only - and allow us to collect feedback about the tool before switching to pyrefly as the main type checker.

References the steps outlined here: : https://github.com/pytorch/pytorch/issues/163283:

test plan:
`lintrunner init`
`lintrunner`
confirm when pyrefly errors are present results look like: https://gist.github.com/maggiemoss/e6cb2d015dd1ded560ae1329098cf33f

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165179
Approved by: https://github.com/ezyang
2025-10-16 20:07:09 +00:00
Brian Hirsh
ed74dc054d add the option to disable functionalization in AOTDispatcher (#164577)
I'm cleaning this PR up as a proper way of disabling functionalization via config in AOTDispatcher. I removed the non-functionalization related changes from the original version:

(1) preventing proxy mode (and functionalization) from incorrectly decomposing CIA ops (Ed has a PR for it here: https://github.com/pytorch/pytorch/pull/164939)

(2) preventing python-dispatcher-based decomps above autograd from running. I'm not doing this for now, will likely do it in a followup

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164577
Approved by: https://github.com/ezyang
ghstack dependencies: #165372
2025-10-16 15:44:11 +00:00
Brian Hirsh
f33c7e1a43 add and fix OpInfo tests for the default partitioner (#165372)
I noticed the default partitioner was breaking in some dynamic shape tests, so prior to turning off functionalization I want to tweak it to pass all of our OpInfo tests

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165372
Approved by: https://github.com/ezyang
2025-10-16 15:44:11 +00:00
Avik Chaudhuri
fa1539594b consolidate fw and inference compile paths (#165457)
By design, fw compile and inference compile stages should share a bunch of code; just consolidating the duplication here.

Differential Revision: D84628978

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165457
Approved by: https://github.com/zhxchen17, https://github.com/tugsbayasgalan
2025-10-15 21:33:50 +00:00
PyTorch MergeBot
b509fb9b5d Revert "add and fix OpInfo tests for the default partitioner (#165372)"
This reverts commit bcfea48ab7.

Reverted https://github.com/pytorch/pytorch/pull/165372 on behalf of https://github.com/malfet due to Looks like it broke slow jobs, see 331b7cc054/1 ([comment](https://github.com/pytorch/pytorch/pull/165372#issuecomment-3407567748))
2025-10-15 17:38:52 +00:00
PyTorch MergeBot
7778a58e7c Revert "[export] Handle kwargs better in aot_export_joint_with_descriptors (#165334)"
This reverts commit bbb902c8dd.

Reverted https://github.com/pytorch/pytorch/pull/165334 on behalf of https://github.com/jeffdaily due to trunk CI passed here but failures on HUD after merge?  test/functorch/test_aot_joint_with_descriptors.py::TestAOTJointWithDescriptors::test_module_with_kwargs [GH job link](https://github.com/pytorch/pytorch/actions/runs/18511729262/job/52755708742) [HUD commit link](bbb902c8dd) ([comment](https://github.com/pytorch/pytorch/pull/165334#issuecomment-3404071893))
2025-10-15 00:21:49 +00:00
Brian Hirsh
bcfea48ab7 add and fix OpInfo tests for the default partitioner (#165372)
I noticed the default partitioner was breaking in some dynamic shape tests, so prior to turning off functionalization I want to tweak it to pass all of our OpInfo tests

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165372
Approved by: https://github.com/ezyang
ghstack dependencies: #165327
2025-10-14 23:34:34 +00:00
angelayi
bbb902c8dd [export] Handle kwargs better in aot_export_joint_with_descriptors (#165334)
fx.Interpreter doesn't handle kwargs... not sure how this code worked previously

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165334
Approved by: https://github.com/tugsbayasgalan, https://github.com/ezyang
2025-10-14 22:22:58 +00:00
Shangdi Yu
5eddbb5e47 [annotate] Annotation should be mapped across submod (#165202)
The match for backward nodes might be in a different submod, so we should check all submod for potential matches.

In flex attention, this could happen if `mask_mod` has operations (such as index) that increase the seq_nr of the forward graph nodes. Then the backward flex_attention nodes cannot find a match in its own subgraph.

```
python test/functorch/test_aot_joint_with_descriptors.py -k preserve_annotate
```

Also tested on torchtitan joint_graph_runner branch. The flex_attention backward nodes are annotated now.

```
NGPU=8   CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml"   LOG_RANK=0   TRAIN_FILE="torchtitan.train"   TORCHFT_LIGHTHOUSE="http://localhost:29510"   PYTORCH_ALLOC_CONF="expandable_segments:True"   torchrun     --nproc_per_node=8     --rdzv_backend c10d     --rdzv_endpoint="localhost:0"     --local-ranks-filter 0     --role rank     --tee 3     -m torchtitan.train     --job.config_file ./torchtitan/models/llama3/train_configs/debug_model.toml     --model.name joint_graph_runner.llama3     --compile.enable     --parallelism.data_parallel_shard_degree=2     --parallelism.tensor_parallel_degree=4     --model.flavor=debugmodel_flex_attn
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165202
Approved by: https://github.com/SherlockNoMad
2025-10-14 16:19:38 +00:00
Yuanyuan Chen
fbe0d20a17 [2/N] More ruff SIM fixes (#165031)
This is follow-up of #164695 to apply ruff SIM rules to more files. Most changes are about simplifying dict.get because None is already the default value.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165031
Approved by: https://github.com/mlazos
2025-10-14 14:22:54 +00:00
Animesh Jain
f3683453ae [compile] Regional inductor compilation with fx.annotate (#164776)
This PR introduces a way to compile a region of FX graph using `fx.traceback.annotate`.

### UX

1) In the user code, mark the region that you want to be compiled with inductor using `with fx_traceback.annotate({"compile_with_inductor": 0})`. As of now, we just rely on the string `compile_with_inductor` and ignore the integer. As the needs arise, we can update the logic.

Example

```
        def fn(x, y):
            sin = torch.sin(x)

            with fx_traceback.annotate({"compile_with_inductor": 0}):
                mul = sin * y
                add = mul + 1

            return torch.sin(add)
```

2) You have to instruct the compiler to use the annotations with `compile_fx_annotated_nodes_with_inductor` transformation. This is somewhat controversial, and a user might expect that just setting annotation is enough. But for now to control the blast radius, we need to explicitly do this. One such example is

```

# Set the fw and bw compiler of aot_autograd to `compile_fx_annotated_nodes_with_inductor`
def aot_eager_regional_inductor():
    return aot_autograd(
        fw_compiler=compile_fx_annotated_nodes_with_inductor,
        bw_compiler=compile_fx_annotated_nodes_with_inductor,
    )

```

3) Fixable in short-term - You have to wrap the user code in `torch.fx.traceback.preserve_node_meta` to ensure that annotations are propagated to the compiler. This is fixable, just need to make CI happy.

### Implementation

1) Relies on `CapabilityBasedPartitioner` to "scoop" out regions based on annotations, and then create subgraphs in the main graph.
2) Call `torch._inductor.standalone_compile` on these subgraphs, and jam the returned callable into the FX graph at the place of call_module

Resulting graph looks something like this - search for `torch__inductor_standalone_compile_inner`

Forward graph
```
class GraphModule(torch.nn.Module):
    def forward(self, primals_1: "f32[10]", primals_2: "f32[10]"):
         # File: /data/users/anijain/pytorch2/test/dynamo/test_regional_inductor.py:64 in fn, code: sin = torch.sin(x)
        sin: "f32[10]" = torch.ops.aten.sin.default(primals_1)

        # No stacktrace found for following nodes
        inner = torch__inductor_standalone_compile_inner(sin, primals_2)

         # File: /data/users/anijain/pytorch2/test/dynamo/test_regional_inductor.py:68 in fn, code: add = mul + 1
        getitem: "f32[10]" = inner[0];  inner = None

         # File: /data/users/anijain/pytorch2/test/dynamo/test_regional_inductor.py:70 in fn, code: return torch.sin(add)
        sin_1: "f32[10]" = torch.ops.aten.sin.default(getitem)
        return (sin_1, primals_1, primals_2, sin, getitem)
```

Backward graph
```
class GraphModule(torch.nn.Module):
    def forward(self, primals_1: "f32[10]", primals_2: "f32[10]", sin: "f32[10]", add: "f32[10]", tangents_1: "f32[10]"):
         # File: /data/users/anijain/pytorch2/test/dynamo/test_regional_inductor.py:64 in fn, code: sin = torch.sin(x)
        cos_1: "f32[10]" = torch.ops.aten.cos.default(primals_1);  primals_1 = None

         # File: /data/users/anijain/pytorch2/test/dynamo/test_regional_inductor.py:70 in fn, code: return torch.sin(add)
        cos: "f32[10]" = torch.ops.aten.cos.default(add);  add = None
        mul_1: "f32[10]" = torch.ops.aten.mul.Tensor(tangents_1, cos);  tangents_1 = cos = None

        # No stacktrace found for following nodes
        inner = torch__inductor_standalone_compile_inner(mul_1, sin, primals_2);  mul_1 = sin = primals_2 = None

         # File: /data/users/anijain/pytorch2/test/dynamo/test_regional_inductor.py:67 in fn, code: mul = sin * y
        getitem: "f32[10]" = inner[0]
        getitem_1: "f32[10]" = inner[1];  inner = None

         # File: /data/users/anijain/pytorch2/test/dynamo/test_regional_inductor.py:64 in fn, code: sin = torch.sin(x)
        mul_4: "f32[10]" = torch.ops.aten.mul.Tensor(getitem_1, cos_1);  getitem_1 = cos_1 = None
        return (mul_4, getitem)
```

### Some issue raised in the HOP meeting
1) CSE will not differentiate different meta custom nodes and do wrong thing.
2) SAC - The recomputed forward will be smaller than the forward. Will we compile a smaller region than?
3) What happens if you have a op in the middle which does not disturb the topology, is it still 1 subgraph?
4) What happens with the nesting of `fx_traceback.annotate`? Are there any ordering requirements?
5) What are we going to use the annotations for?
   a) compile flex
   b) streams
   c) nn.Module info to organize MoE components for pipelining
   d) PP stages
   e) Rename graph nodes for more debugging
   f) No nested regional compile

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164776
Approved by: https://github.com/SherlockNoMad
ghstack dependencies: #165188
2025-10-13 22:22:20 +00:00
PyTorch MergeBot
8d49cd5b26 Revert "[compile] Regional inductor compilation with fx.annotate (#164776)"
This reverts commit 1e4c7dffa3.

Reverted https://github.com/pytorch/pytorch/pull/164776 on behalf of https://github.com/malfet due to Looks like this one broke everything, not the top of the stack ([comment](https://github.com/pytorch/pytorch/pull/164776#issuecomment-3393725466))
2025-10-11 23:14:23 +00:00
Animesh Jain
1e4c7dffa3 [compile] Regional inductor compilation with fx.annotate (#164776)
This PR introduces a way to compile a region of FX graph using `fx.traceback.annotate`.

### UX

1) In the user code, mark the region that you want to be compiled with inductor using `with fx_traceback.annotate({"compile_with_inductor": 0})`. As of now, we just rely on the string `compile_with_inductor` and ignore the integer. As the needs arise, we can update the logic.

Example

```
        def fn(x, y):
            sin = torch.sin(x)

            with fx_traceback.annotate({"compile_with_inductor": 0}):
                mul = sin * y
                add = mul + 1

            return torch.sin(add)
```

2) You have to instruct the compiler to use the annotations with `compile_fx_annotated_nodes_with_inductor` transformation. This is somewhat controversial, and a user might expect that just setting annotation is enough. But for now to control the blast radius, we need to explicitly do this. One such example is

```

# Set the fw and bw compiler of aot_autograd to `compile_fx_annotated_nodes_with_inductor`
def aot_eager_regional_inductor():
    return aot_autograd(
        fw_compiler=compile_fx_annotated_nodes_with_inductor,
        bw_compiler=compile_fx_annotated_nodes_with_inductor,
    )

```

3) Fixable in short-term - You have to wrap the user code in `torch.fx.traceback.preserve_node_meta` to ensure that annotations are propagated to the compiler. This is fixable, just need to make CI happy.

### Implementation

1) Relies on `CapabilityBasedPartitioner` to "scoop" out regions based on annotations, and then create subgraphs in the main graph.
2) Call `torch._inductor.standalone_compile` on these subgraphs, and jam the returned callable into the FX graph at the place of call_module

Resulting graph looks something like this - search for `torch__inductor_standalone_compile_inner`

Forward graph
```
class GraphModule(torch.nn.Module):
    def forward(self, primals_1: "f32[10]", primals_2: "f32[10]"):
         # File: /data/users/anijain/pytorch2/test/dynamo/test_regional_inductor.py:64 in fn, code: sin = torch.sin(x)
        sin: "f32[10]" = torch.ops.aten.sin.default(primals_1)

        # No stacktrace found for following nodes
        inner = torch__inductor_standalone_compile_inner(sin, primals_2)

         # File: /data/users/anijain/pytorch2/test/dynamo/test_regional_inductor.py:68 in fn, code: add = mul + 1
        getitem: "f32[10]" = inner[0];  inner = None

         # File: /data/users/anijain/pytorch2/test/dynamo/test_regional_inductor.py:70 in fn, code: return torch.sin(add)
        sin_1: "f32[10]" = torch.ops.aten.sin.default(getitem)
        return (sin_1, primals_1, primals_2, sin, getitem)
```

Backward graph
```
class GraphModule(torch.nn.Module):
    def forward(self, primals_1: "f32[10]", primals_2: "f32[10]", sin: "f32[10]", add: "f32[10]", tangents_1: "f32[10]"):
         # File: /data/users/anijain/pytorch2/test/dynamo/test_regional_inductor.py:64 in fn, code: sin = torch.sin(x)
        cos_1: "f32[10]" = torch.ops.aten.cos.default(primals_1);  primals_1 = None

         # File: /data/users/anijain/pytorch2/test/dynamo/test_regional_inductor.py:70 in fn, code: return torch.sin(add)
        cos: "f32[10]" = torch.ops.aten.cos.default(add);  add = None
        mul_1: "f32[10]" = torch.ops.aten.mul.Tensor(tangents_1, cos);  tangents_1 = cos = None

        # No stacktrace found for following nodes
        inner = torch__inductor_standalone_compile_inner(mul_1, sin, primals_2);  mul_1 = sin = primals_2 = None

         # File: /data/users/anijain/pytorch2/test/dynamo/test_regional_inductor.py:67 in fn, code: mul = sin * y
        getitem: "f32[10]" = inner[0]
        getitem_1: "f32[10]" = inner[1];  inner = None

         # File: /data/users/anijain/pytorch2/test/dynamo/test_regional_inductor.py:64 in fn, code: sin = torch.sin(x)
        mul_4: "f32[10]" = torch.ops.aten.mul.Tensor(getitem_1, cos_1);  getitem_1 = cos_1 = None
        return (mul_4, getitem)
```

### Some issue raised in the HOP meeting
1) CSE will not differentiate different meta custom nodes and do wrong thing.
2) SAC - The recomputed forward will be smaller than the forward. Will we compile a smaller region than?
3) What happens if you have a op in the middle which does not disturb the topology, is it still 1 subgraph?
4) What happens with the nesting of `fx_traceback.annotate`? Are there any ordering requirements?
5) What are we going to use the annotations for?
   a) compile flex
   b) streams
   c) nn.Module info to organize MoE components for pipelining
   d) PP stages
   e) Rename graph nodes for more debugging
   f) No nested regional compile

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164776
Approved by: https://github.com/SherlockNoMad
2025-10-11 15:49:42 +00:00
Lucas Kabela
f363114852 [Bugfix][Inductor][Dynamo] Fix stride incorrectness issues for stride 0 tensor (#164897)
Fixes #164814 - we update to include cases where we know symbolic expression is statically one.  There are two errors here; first in graph capture, where a tensor with size 0 yet symbolic stride would attempt to keep the symbolic stride, resulting in a mismatch.  The second is in inductor code gen, where we only checked in squeeze if size == 1, missing the case where a symbolic stride equals 1.

Also fixes #164924 (@bobrenjc93  for fuzzer finding an issue affecting users : )

### Test plan:
```
python test/dynamo/test_aot_autograd.py AotAutogradFallbackTests
```

Results in:
```
..
----------------------------------------------------------------------
Ran 49 tests in 45.622s

OK (expected failures=1)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164897
Approved by: https://github.com/laithsakka
2025-10-10 21:26:57 +00:00
Avik Chaudhuri
3ed90f5a09 outline various stages from aot stage2 compile (#164808)
Splits the training and inference paths for aot stage2 compile.
1. Split `aot_stage2_autograd` into `_aot_stage2a_partition`, `_aot_stage2b_fw_compile` and `_aot_stage2b_bw_compile`, and rest.
2. Split `aot_stage2_inference` into `_aot_stage2b_inference_compile` and rest.
I'm leaving these as functions with underscore names since the I/O interfaces and the exact boundaries of these splits are somewhat in the air.

Differential Revision: D84028203

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164808
Approved by: https://github.com/SherlockNoMad
2025-10-10 17:04:36 +00:00
Yuanyuan Chen
fb64da0791 [2/N] Use "is" in python type comparison (#165142)
This is follow-up of #165037. It generally recommended to use `is/is not` to compare types. Therefore this series of changes apply this suggestion in the code base, and it aims to finally enabling related linter checks.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165142
Approved by: https://github.com/albanD
2025-10-10 15:36:44 +00:00