Commit Graph

510 Commits

Author SHA1 Message Date
James Wu
1b772de397 Still run TritonBundler with BundledAOTAutogradCache, save autotune results (#158048)
When running BundledAOTAutogradCache with precompile, we still need to run Triton bundling so that the precompiled CompiledFxGraph has its Triton CUDA kernels. We also pre-save the autotune results in the precompile artifact.

It would be even better to pre-trim the CUDA kernels on save and apply them, which we can work on later.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158048
Approved by: https://github.com/zhxchen17
2025-07-22 14:12:21 +00:00
PyTorch MergeBot
bc379aebe2 Revert "Still run TritonBundler with BundledAOTAutogradCache, save autotune results (#158048)"
This reverts commit 8e57cdb746.

Reverted https://github.com/pytorch/pytorch/pull/158048 on behalf of https://github.com/jeffdaily due to rocm failures due to unit test introduced in this PR, but no pre-merge signal available ([comment](https://github.com/pytorch/pytorch/pull/158048#issuecomment-3098746624))
2025-07-21 20:45:21 +00:00
Benjamin Glass
22920c9138 Grab bag of (mostly) typing improvements (#158075)
Collects some scattershot improvements made while attempting to enable training for AOTInductor. Non-typing changes are:

1. Swapping a few custom searches for the output node in an FX graph for calling `graph.output_node()`.
2. Removing two unused parameters from `torch.export._unlift._unlift`.
3. Switching handles to constants in `cpp_wrapper_cpu` to use C++ references for memory efficiency.
4. Cleaning out unused, unexported imports from `torch/export/__init__.py`, and adding one missing export to `__all__`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158075
Approved by: https://github.com/Skylion007
2025-07-21 19:17:01 +00:00
James Wu
8e57cdb746 Still run TritonBundler with BundledAOTAutogradCache, save autotune results (#158048)
When running BundledAOTAutogradCache with precompile, we still need to run Triton bundling so that the precompiled CompiledFxGraph has its Triton CUDA kernels. We also pre-save the autotune results in the precompile artifact.

It would be even better to pre-trim the CUDA kernels on save and apply them, which we can work on later.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158048
Approved by: https://github.com/zhxchen17
2025-07-21 13:35:46 +00:00
Shangdi Yu
1e86fa2e5b Add stack trace to Inductor IR nodes if inductor.config.trace.provenance_tracing=True (#158576)
Summary:
- Split `create_mapping` into `create_mapping_pre_post_grad_nodes` and `create_node_mapping_kernel_to_post_grad`
- Store a mapping from pre_grad graph node names to stack traces in `_inductor_pre_grad_node_stack_trace`
- Add a `stack_traces` member to ir.Node and include it in the string representation of ir.Node
- When we create an IR node, if `inductor.config.trace.provenance_tracing=True`, we populate `stack_traces` from `origins`. The nodes in `origins` are post-grad graph nodes. If a node has `node.stack_trace`, we store that stack trace directly. This is particularly important for backward graph nodes because they don't have a mapping to pre-grad graph nodes. If a node doesn't have `.stack_trace` (such as `linear` -> `addmm` nodes), we use the stack traces of the pre-grad graph nodes that it maps to (see the sketch below).
  - A post-grad graph node might not have a stack trace if it corresponds to multiple pre-grad graph nodes, e.g. [GroupLinearFusion](a00442421a/torch/_inductor/fx_passes/group_batch_fusion.py (L299))
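A minimal sketch of that fallback (the helper name and the `pre_grad_nodes` metadata key are hypothetical, not the actual implementation):

```python
def collect_stack_traces(origins, pre_grad_stack_traces):
    """Gather stack traces for an IR node from its post-grad `origins`."""
    stack_traces = {}
    for node in origins:  # post-grad FX nodes
        if node.stack_trace:
            # e.g. backward nodes: keep the post-grad trace directly.
            stack_traces[node.name] = node.stack_trace
        else:
            # e.g. `linear` -> `addmm`: fall back to the traces of the
            # pre-grad nodes this post-grad node maps to, if any.
            for pre_grad_name in node.meta.get("pre_grad_nodes", []):
                trace = pre_grad_stack_traces.get(pre_grad_name)
                if trace:
                    stack_traces[pre_grad_name] = trace
    return stack_traces
```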

Example:

```
scheduling ExternKernelOut(
  python_kernel_name='extern_kernels.mm',
  name=buf0,
  layout=FixedLayout('cuda:0', torch.float32, size=[8, 16], stride=[16, 1]),
  inputs=[InputBuffer(name='arg2_1', layout=FixedLayout('cuda:0', torch.float32, size=[8, 10], stride=[10, 1])), ReinterpretView(
    StorageBox(
      ConstantBuffer(name='fc1_weight', layout=FixedLayout('cuda:0', torch.float32, size=[16, 10], stride=[10, 1]))
    ),
    FixedLayout('cuda:0', torch.float32, size=[10, 16], stride=[1, 10]),
    origins=OrderedSet([mm_default_1]),
    stack_traces = {,
    File "/data/users/shangdiy/fbsource/buck-out/v2/gen/fbcode/7b4b7a52e15abb17/scripts/shangdiy/__aot__/aot#link-tree/scripts/shangdiy/aot.py", line 29, in forward,
        x = self.fc1(x),
      File "/data/users/shangdiy/fbsource/buck-out/v2/gen/fbcode/7b4b7a52e15abb17/scripts/shangdiy/__aot__/aot#link-tree/torch/nn/modules/linear.py", line 125, in forward,
        return F.linear(input, self.weight, self.bias),
    }
  )],
  constant_args=(),
  kwargs={},
  output_view=None,
  python_kernel_name=extern_kernels.mm,
  cpp_kernel_name=at::mm_out,
  ordered_kwargs_for_cpp_kernel=(),
  op_overload=None,
  arg_properties=[{}, {}],
  allarg_properties={},
  kwarg_properties=None,
  unbacked_bindings={},
  mutation_outputs=[],
  origin_node=mm_default_1,
  origins=OrderedSet([mm_default_1]),
  stack_traces = {,
  File "/data/users/shangdiy/fbsource/buck-out/v2/gen/fbcode/7b4b7a52e15abb17/scripts/shangdiy/__aot__/aot#link-tree/scripts/shangdiy/aot.py", line 29, in forward,
      x = self.fc1(x),
    File "/data/users/shangdiy/fbsource/buck-out/v2/gen/fbcode/7b4b7a52e15abb17/scripts/shangdiy/__aot__/aot#link-tree/torch/nn/modules/linear.py", line 125, in forward,
      return F.linear(input, self.weight, self.bias),
  }
)
```

Test Plan:
```
buck2 run mode/dev-nosan fbcode//caffe2/test/inductor:provenance_tracing
```

Rollback Plan:

Differential Revision: D78365534

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158576
Approved by: https://github.com/angelayi
2025-07-18 04:05:17 +00:00
Shangdi Yu
82a1ee1135 Refactor Provenance Tracking (#158399)
Summary:
As inductor provenance tracking is getting more use cases, we want to separate the provenance tracking guard flag from the general `trace.enabled`, so we can enable provenance tracking without all the overhead of `trace.enabled`.

- Change the guard flag from `trace.enabled` to `trace.provenance_tracking`. It is turned on by either `TORCH_COMPILE_DEBUG=1` or `INDUCTOR_PROVENANCE=1`.
- Move the provenance tracking logic and variables out of DebugContext, because DebugContext is only enabled with `trace.enabled`. Since the variables are now global variables, added `reset_provenance_globals()` context manager to reset them for each `compile_fx()` call.
- Move `set_kernel_post_grad_provenance_tracing` from `util.py` to `debug.py` so now all provenance related logic is in `debug.py`.

In the future, if we want to enable it further, we can change the provenance tracking flag to be enabled when `TORCH_TRACE` is set. I think we should do that in a separate PR, so it's easier to revert if this flag change causes any problems.
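For reference, a minimal way to exercise the new flag (assuming the flag name from this PR and that it lives under `torch._inductor.config.trace`):

```python
import os

# Either env var turns provenance tracking on, per this PR; set before compiling.
os.environ["INDUCTOR_PROVENANCE"] = "1"  # or TORCH_COMPILE_DEBUG=1

import torch
from torch._inductor import config as inductor_config

# Equivalent programmatic toggle; flag name taken from this PR.
inductor_config.trace.provenance_tracking = True

@torch.compile
def f(x):
    return torch.relu(x) + 1

f(torch.randn(8))
```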

See more motivation in internal Diff

Test Plan:
```
buck2 run mode/dev-nosan fbcode//caffe2/test:fx -- -r test_graph_transform_observer
buck run mode/dev-nosan  fbcode//caffe2/test:fx -- -r graph_provenance
buck2 run mode/dev-nosan fbcode//caffe2/test/inductor:provenance_tracing
```

Differential Revision: D78287976

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158399
Approved by: https://github.com/angelayi
2025-07-17 00:23:00 +00:00
Shangdi Yu
4781d72faa [AOTI] codegen for static linkage (#157129)
Design doc: https://docs.google.com/document/d/1ncV7RpJ8xDwy8-_aCBfvZmpTTL824C-aoNPBLLVkOHM/edit?tab=t.0 (internal)

- Add codegen for static linkage
- refactor test code for test_compile_after_package tests

For now, the following option must be used together with `"aot_inductor.compile_standalone": True`:
`"aot_inductor.package_cpp_only": True`

A follow-up PR will change `"aot_inductor.package_cpp_only"` to be set to True automatically.

```
python test/inductor/test_aot_inductor_package.py -k test_compile_after_package
python test/inductor/test_aot_inductor_package.py -k test_run_static_linkage_model
```
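A hedged sketch of passing these options together through `aoti_compile_and_package` (the model is illustrative, and option behavior may change as noted above):

```python
import torch
from torch._inductor import aoti_compile_and_package

class M(torch.nn.Module):
    def forward(self, x):
        return torch.nn.functional.relu(x @ x.T)

ep = torch.export.export(M(), (torch.randn(4, 8),))

# Options named in this PR; for now compile_standalone requires
# package_cpp_only to be set as well.
package_path = aoti_compile_and_package(
    ep,
    inductor_configs={
        "aot_inductor.compile_standalone": True,
        "aot_inductor.package_cpp_only": True,
    },
)
```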

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157129
Approved by: https://github.com/desertfire
2025-07-10 16:03:50 +00:00
Shangdi Yu
effe376db0 Adding aoti_standalone config (#157731)
Summary: When `compile_standalone` is True, we set `package_cpp_only` to True as well. We raise an error if `package_cpp_only` is explicitly set to False in the config.

Test Plan:
```
buck2 run  mode/dev-nosan fbcode//caffe2/test/inductor:test_aot_inductor -- -r  TestAOTInductorConfig
```

Rollback Plan:

Differential Revision: D77889754

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157731
Approved by: https://github.com/desertfire
2025-07-09 04:30:04 +00:00
Sam Larsen
7a41f20794 [inductor] Quiesce Triton compile worker pool after each dynamo compile (#156187)
For internal usages, keeping the Triton compile worker pool active for the lifetime of the process has caused some challenges, e.g., it slows down and muddies profiling due to the huge number of threads on a box: N threads = 8 ranks * 32 subprocs * M threads started by torch. Also, each subproc can use more than 1GB. This PR adds the functionality to shut down worker subprocs after each dynamo compile when using the SubprocPool implementation. The idea is to leave the main sidecar process running, but signal it to tear down its internal ProcessPoolExecutor when compile is finished. Restarting the ProcessPoolExecutor is relatively fast, e.g., 500ms, because the ProcessPoolExecutor forks from the sidecar.

Changes:
* Do not start the ProcessPoolExecutor automatically when compile_fx is imported. Instead, start the sidecar process only. The sidecar process imports torch, so is still slow to start.
* Introduce wakeup() and quiesce() calls to the implementation to start and stop the ProcessPoolExecutor.
* Add a context manager to automatically quiesce() at the end of dynamo compilation (sketched below).
* Signal a wakeup() in compile_fx only when we have cuda devices.
* Add a killswitch so we can turn off quiescing.
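A rough sketch of the wakeup/quiesce lifecycle described above (class and function names are illustrative, not the actual internal API):

```python
import contextlib
from concurrent.futures import ProcessPoolExecutor

class WorkerPool:
    """Stand-in for the sidecar-backed SubprocPool: the sidecar stays alive,
    only the inner ProcessPoolExecutor is started and stopped."""
    def __init__(self, max_workers=4):
        self.max_workers = max_workers
        self._executor = None

    def wakeup(self):
        # The real pool forks from the warm sidecar, so restart is cheap;
        # here we just create a fresh executor lazily.
        if self._executor is None:
            self._executor = ProcessPoolExecutor(self.max_workers)

    def quiesce(self):
        if self._executor is not None:
            self._executor.shutdown()
            self._executor = None

@contextlib.contextmanager
def quiesce_after_compile(pool, enabled=True):
    """Tear down the worker executor once a dynamo compile finishes."""
    try:
        yield
    finally:
        if enabled:  # killswitch from the PR: skip quiescing when disabled
            pool.quiesce()
```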

Testing:
For correctness, the stacked change at https://github.com/pytorch/pytorch/pull/156534 enables the feature for OSS so it's exercised in CI.

For performance, because of recent compile-time variance (see https://github.com/pytorch/pytorch/issues/152566), it's pretty hard to glean whether there's a regression....

* Training: https://hud.pytorch.org/benchmark/compilers?dashboard=torchinductor&startTime=Tue%2C%2017%20Jun%202025%2021%3A32%3A04%20GMT&stopTime=Tue%2C%2024%20Jun%202025%2021%3A32%3A04%20GMT&granularity=hour&mode=training&dtype=amp&deviceName=cuda%20(h100)&lBranch=gh/masnesral/210/head&lCommit=1b7315031c3bfad66a1a01700167a9ca1a2ae5f1&rBranch=main&rCommit=eab45643f22e58ee12d95d8b0162d51ca0a50801
* Inference: https://hud.pytorch.org/benchmark/compilers?dashboard=torchinductor&startTime=Tue%2C%2017%20Jun%202025%2021%3A32%3A04%20GMT&stopTime=Tue%2C%2024%20Jun%202025%2021%3A32%3A04%20GMT&granularity=hour&mode=inference&dtype=bfloat16&deviceName=cuda%20(h100)&lBranch=gh/masnesral/210/head&lCommit=1b7315031c3bfad66a1a01700167a9ca1a2ae5f1&rBranch=main&rCommit=eab45643f22e58ee12d95d8b0162d51ca0a50801

The wins (mostly for inference) don't make sense, but I'm also skeptical of the losses (mostly for training). I can't repro any of the slowdowns locally. Furthermore, check out the benchmarking results for the stacked diff, which actually enables the quiescing functionality for OSS. That should only slow down compile since there can only be overhead to stop and start the workers. But the results are somehow better:

* Training: https://hud.pytorch.org/benchmark/compilers?dashboard=torchinductor&startTime=Tue%2C%2017%20Jun%202025%2021%3A32%3A04%20GMT&stopTime=Tue%2C%2024%20Jun%202025%2021%3A32%3A04%20GMT&granularity=hour&mode=training&dtype=amp&deviceName=cuda%20(h100)&lBranch=gh/masnesral/214/head&lCommit=41943253882a019b8ceafcd2bf4cd6acbe0cbca9&rBranch=main&rCommit=eab45643f22e58ee12d95d8b0162d51ca0a50801
* Inference: https://hud.pytorch.org/benchmark/compilers?dashboard=torchinductor&startTime=Tue%2C%2017%20Jun%202025%2021%3A32%3A04%20GMT&stopTime=Tue%2C%2024%20Jun%202025%2021%3A32%3A04%20GMT&granularity=hour&mode=inference&dtype=bfloat16&deviceName=cuda%20(h100)&lBranch=gh/masnesral/214/head&lCommit=41943253882a019b8ceafcd2bf4cd6acbe0cbca9&rBranch=main&rCommit=eab45643f22e58ee12d95d8b0162d51ca0a50801

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156187
Approved by: https://github.com/aorenste, https://github.com/jansel
2025-07-08 22:53:13 +00:00
Shangdi Yu
5b4e0255d7 Check FakeScriptObject in _resolve_name_collision (#157736)
Summary:
Fix https://github.com/pytorch/pytorch/issues/157401

torch.equal cannot handle FakeScriptObject inputs.

Test Plan:
```
buck run fbcode//mode/dev-nosan //caffe2/test/inductor:torchbind -- -r  test_aoti_torchbind_name_collision
```

Rollback Plan:

Differential Revision: D77894081

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157736
Approved by: https://github.com/angelayi
2025-07-08 17:51:46 +00:00
bobrenjc93
d58ed04d89 [async-compile] add progressive compile mode (#157305)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/157305
Approved by: https://github.com/aorenste
2025-07-04 04:18:50 +00:00
James Wu
e7a66166ce [precompile] When using BundledAOTAutogradCache, disable FXGraphCache (#156611)
The goal of this PR is to fix a specific bug when turning precompile on/off between caching runs.

If you try to turn on BundledAOTAutogradCacheEntry today in between local runs, the FXGraphCache may randomly hit *between* the two runs, because FXGraphCache knows nothing about AOTAutogradCache's config. When FXGraphCache hits, it immediately calls make_launchers() on the Triton code, which then causes an assertion failure because pickle should not be called after make_launchers.

One way to resolve the bug is just to add whether precompile is enabled to the FxGraph cache key. The better fix, however, is higher level/philosophical:

When using BundledAOTAutogradCacheEntry, the entire CompiledFxGraph is saved directly to the cache entry, and we expect the two caches to work in sync, i.e. as one cache. So to simplify the programming model, we disable FxGraphCache when BundledAOTAutogradCache is turned on.

BundledAOTAutogradCacheEntry is only used for precompile use cases now; if we wanted to use BundledAOTAutogradCache for traditional caching use cases, there's a bunch of further work, one piece of which would be to re-enable FxGraphCache in the event that BundledAOTAutogradCache has to bypass. However, for precompile, this is not a scenario that should happen: we should always expect the entire callable to be saveable, and we should expect to never bypass. So we don't make that change for now.

Added a unit test demonstrating this behavior. Also updated existing unit tests to show that all fx graph cache operations are now 0 (but all tests still pass).
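A tiny sketch of the resulting programming model (config names here are illustrative, not the actual flags):

```python
# Illustrative only: with BundledAOTAutogradCache, the CompiledFxGraph lives
# inside the AOTAutograd cache entry, so the standalone FxGraphCache is skipped.
def fx_graph_cache_enabled(bundled_autograd_cache: bool, fx_graph_cache: bool) -> bool:
    if bundled_autograd_cache:
        return False  # one cache, one source of truth
    return fx_graph_cache
```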

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156611
Approved by: https://github.com/zhxchen17
2025-06-25 21:01:42 +00:00
Shangdi Yu
eaf704914e [aoti] package weights to disk and dedup (#155241)
We package the weights and save them in `data/weights/` (`WEIGHTS_DIR`). In addition, we store a `weights_config.json` in the model folder for each model to specify which weight file corresponds to which weight name.

Models can share weights. We dedup the weights based on their underlying storage (`tensor.untyped_storage()`); see the sketch below.

- Use `"aot_inductor.package_constants_on_disk": True` config to produce the `Weights` in aot_compile
- If we see `Weights` in aoti_files, we'll automatically package them to disk
- `"aot_inductor.package_constants_on_disk"` config and `"aot_inductor.package_constants_in_so"` config work independently.
- Use `load_pt2(package_path, load_weights_from_disk=True)` to load the weights from disk. `load_weights_from_disk` defaults to False.
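A minimal sketch of the storage-based dedup (the on-disk layout and config format here are illustrative, not the actual `weights_config.json` schema):

```python
import torch

def dedup_weights(named_tensors):
    """Group tensors that share an underlying storage so each storage is
    written under data/weights/ only once."""
    storage_to_file = {}
    weights_config = {}
    for name, tensor in named_tensors:
        storage = tensor.untyped_storage()
        key = (storage.data_ptr(), storage.nbytes())
        file_idx = storage_to_file.setdefault(key, len(storage_to_file))
        weights_config[name] = {
            "file": f"weight_{file_idx}",
            "shape": list(tensor.shape),
            "stride": list(tensor.stride()),
            "storage_offset": tensor.storage_offset(),
        }
    return weights_config

# Two modules sharing one weight dedup to a single stored file.
a = torch.nn.Linear(4, 4, bias=False)
b = torch.nn.Linear(4, 4, bias=False)
b.weight = a.weight
print(dedup_weights([("a.weight", a.weight), ("b.weight", b.weight)]))
```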

Test Plan:
```
buck2 run @//mode/dev-nosan //caffe2/test/inductor:aot_inductor_package -- -r "test_package_shared_weights"
```

Tested with whisper at https://github.com/pytorch-labs/torchnative/pull/7

Rollback Plan:

Differential Revision: D74747190

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155241
Approved by: https://github.com/desertfire
2025-06-19 17:17:17 +00:00
Oguz Ulgen
a2a75be0f8 Rename inductor cache (#156128)
Requested by Simon on a different PR

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156128
Approved by: https://github.com/xmfan
2025-06-17 03:57:18 +00:00
Aaron Orenstein
e95e8eed0a mypy 1.16.0 (#155821)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/155821
Approved by: https://github.com/ezyang, https://github.com/zou3519
2025-06-14 18:18:43 +00:00
Marcin Pioch
ce79056471 Custom FX pass for inductor's backend registration (#154841)
This PR is related to RFC #153532. It extends Inductor's backend registration interface to allow the backend to register custom FX passes.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154841
Approved by: https://github.com/jansel

Co-authored-by: Jason Ansel <jansel@jansel.net>
2025-06-14 17:29:54 +00:00
Animesh Jain
c9e9a0c823 [inductor][invoke_subgraph] Mark invoke_subgraph outputs as user_visible to constrain output strides (#155395)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/155395
Approved by: https://github.com/zou3519
2025-06-12 03:58:16 +00:00
Oguz Ulgen
d1947a8707 Migrate from lru_cache to cache (#155613)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/155613
Approved by: https://github.com/ezyang
ghstack dependencies: #155612
2025-06-11 19:44:18 +00:00
penknife6153
59eb61b2d1 [inductor] Improve GEMM logging to display batch size for batched operations (#155544)
Improves the GEMM overview logging in PyTorch Inductor to properly display batch size information for batched matrix operations like `torch.bmm` and `torch.baddbmm`.

**Fixes #155307**

## Problem

The current GEMM logging for `torch.bmm` shows:
```python
# Repro
import os
os.environ["TORCH_LOGS"] = "inductor"
import torch

M, N, K = 1024, 1024, 1024
dtype = torch.bfloat16
A = torch.randn(10, M, K, device="cuda", dtype=dtype)
B = torch.randn(10, K, N, device="cuda", dtype=dtype)

compiled_model = torch.compile(torch.bmm, fullgraph=True)
_ = compiled_model(A, B)
```

**Before:**
```
Name                 | M                    | N                    | K                    | Count
----------------------------------------------------------------------------------------------------
aten.bmm             | 1024                 | 1024                 | 1024                 | 1
----------------------------------------------------------------------------------------------------
```

The batch size (10) is missing from the logs, making it unclear what the actual operation dimensions were.

## Solution

**After:**
```
Name                           | B                    | M                    | N                    | K                    | Count
----------------------------------------------------------------------------------------------------------------------------------
aten.bmm                      | 10                   | 1024                 | 1024                 | 1024                 | 1
aten.mm                       | -                    | 1024                 | 1024                 | 1024                 | 2
----------------------------------------------------------------------------------------------------------------------------------
```

## Changes Made

### 1. Enhanced Parsing Logic in compile_fx.py
- Detects batched operations by checking if operation name ends with `'bmm'` or `'baddbmm'`
- For batched operations: takes last 4 parts as `batch, m, n, k`
- For non-batched operations: takes last 3 parts as `m, n, k`
- **Dedicated "B" column**: Added separate column for batch size instead of embedding in operation name
- Shows the batch size for batched operations and "-" for non-batched operations (see the parsing sketch below)

### 2. Updated All MM Operations for Consistency
- **bmm.py**:
  - Extract batch size from `mat1.get_size()[0]` for both `tuned_bmm` and `tuned_baddbmm`
  - Use positional counter keys: `aten.bmm_{batch_size}_{m}_{n}_{k}`
  - Enhanced log messages to include batch size information

- **mm.py**: Updated counter keys for consistency:
  - `aten.mm_{m}_{n}_{k}` (no batch dimension)
  - `aten.addmm_{m}_{n}_{k}` (no batch dimension)
  - `aten._int_mm_{m}_{n}_{k}` (no batch dimension)
  - `aten._scaled_mm.default_{m}_{n}_{k}` (no batch dimension)
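A rough sketch of the counter-key parsing described under "Enhanced Parsing Logic" (the real code in `compile_fx.py` may differ in details):

```python
def parse_mm_counter_key(key: str):
    """Split keys like 'aten.bmm_10_1024_1024_1024' (batched) or
    'aten._int_mm_1024_1024_1024' (non-batched) into (name, B, M, N, K)."""
    parts = key.split("_")
    batched_name = "_".join(parts[:-4])
    if batched_name.endswith(("bmm", "baddbmm")):
        # Batched ops: the last four parts are batch, m, n, k.
        return (batched_name, *parts[-4:])
    # Non-batched ops: the last three parts are m, n, k; "-" fills the B column.
    return ("_".join(parts[:-3]), "-", *parts[-3:])

print(parse_mm_counter_key("aten.bmm_10_1024_1024_1024"))
# ('aten.bmm', '10', '1024', '1024', '1024')
print(parse_mm_counter_key("aten._scaled_mm.default_1024_1024_1024"))
# ('aten._scaled_mm.default', '-', '1024', '1024', '1024')
```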

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155544
Approved by: https://github.com/jansel, https://github.com/BoyuanFeng
2025-06-11 16:57:40 +00:00
PyTorch MergeBot
79bdafe5b6 Revert "Custom FX pass for inductor's backend registration (#154841)"
This reverts commit e694280d12.

Reverted https://github.com/pytorch/pytorch/pull/154841 on behalf of https://github.com/clee2000 due to failing some tests internally D76135706 ([comment](https://github.com/pytorch/pytorch/pull/154841#issuecomment-2956357711))
2025-06-09 16:56:45 +00:00
Jovian Anthony Jaison
1ccc57e428 Log backward no-op to tlparse and pt2 compile events. (#154544)
Summary: Log backward no-op to tlparse and pt2 compile events.

Test Plan:
$ rm -rf /tmp/r && TORCH_TRACE=/tmp/r buck2 run //scripts/jovian:backward_noop_repro_compile

Used print statements to verify we enter the logging code region.

Differential Revision: D75231665

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154544
Approved by: https://github.com/c00w
2025-06-06 18:08:19 +00:00
Marcin Pioch
e694280d12 Custom FX pass for inductor's backend registration (#154841)
This PR is related to RFC #153532. It extends Inductor's backend registration interface to allow the backend to register custom FX passes.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154841
Approved by: https://github.com/jansel

Co-authored-by: Jason Ansel <jansel@jansel.net>
2025-06-06 06:49:44 +00:00
Simon Fan
28796f71d0 Redo D75092426: [internal] Expose additional metadata to compilation callbacks (#155063)
Originally https://github.com/pytorch/pytorch/pull/153596
---------------

Summary:
via reverting D75708685

gate the ROCm failure

Test Plan:
Unit tests in OSS, sandcastle

Rollback Plan:

Differential Revision: D75894349

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155063
Approved by: https://github.com/masnesral
2025-06-05 23:40:31 +00:00
PyTorch MergeBot
35fc5c49b4 Revert "[internal] Expose additional metadata to compilation callbacks (#153596)"
This reverts commit f889dea97d.

Reverted https://github.com/pytorch/pytorch/pull/153596 on behalf of https://github.com/izaitsevfb due to introduces bunch of callback-related failures on rocm ([comment](https://github.com/pytorch/pytorch/pull/153596#issuecomment-2923139061))
2025-05-30 18:39:27 +00:00
Simon Fan
f889dea97d [internal] Expose additional metadata to compilation callbacks (#153596)
These hooks are used by internal stuck job detection to associate compilation events with the compile lease. Previously, we only had events for Dynamo and Inductor compilation. And recently, the callback handler was updated to ignore nested events. So the Inductor event was only really used by lazy backward.

Here, I remove the inductor event, and add an explicit lazy backward one. Additionally, I add other runtime compilation events: autotuning and cudagraphs. I also expose the CompileId as a string to avoid imports, this will let internal UIs track each graph's contribution to the timeout.

```python
class CallbackTrigger(enum.Enum):
    # most common case, dynamo attempts to trace a new frame
    DYNAMO = 1
    # backward compilation can be deferred to runtime
    LAZY_BACKWARD = 2
    # some backends autotune at runtime
    TRITON_AUTOTUNING = 3
    # cudagraphs record at runtime
    CUDAGRAPH_RECORDING = 4
```

Differential Revision: [D75092426](https://our.internmc.facebook.com/intern/diff/D75092426)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/153596
Approved by: https://github.com/masnesral
2025-05-30 08:07:04 +00:00
eellison
d6e29bf875 Reflect back mutation if we clone misaligned tensors (#154442)
Fix for https://github.com/pytorch/pytorch/issues/152425

Inductor specializes on whether or not a tensor is 16-bit aligned at the first invocation. Then, on subsequent invocations, if we inferred alignment but are passed a non-aligned tensor, we clone the tensor.

If we infer alignment, then run with an unaligned tensor and mutate the input, we need to reflect the mutation back to the input. This PR adds back that mutation.

We could also have been less aggressive about inferring alignment for mutated tensors, but that has a pretty big perf hit. See the following benchmark:
```
import torch

t = torch.rand(4096 * 4096, device="cuda", dtype=torch.float16)

@torch.compile(dynamic=False)
def foo(x):
    return x.add_(1)

import triton

print(triton.testing.do_bench(lambda: foo(t[:-1])))
torch._dynamo.reset()
print(triton.testing.do_bench(lambda: foo(t[1:])))
```
gives
```
0.04063070610165596
0.07613472988113162
```
That's almost twice as slow for non-aligned tensors. Tensors changing alignment is a relatively rare case.

In the future, we could consider a multi-kernel approach, or codegen a Triton kernel that does most of the loads with aligned instructions and handles the unaligned portion in a prologue/epilogue. But it remains to be seen whether this is a big issue.
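A simplified model of the clone-and-copy-back behavior this PR adds (not the actual Inductor code path; the function name and alignment constant are illustrative):

```python
import torch

def call_with_alignment_guard(compiled_fn, arg, assumed_aligned=True):
    """If the kernel was specialized for aligned inputs but `arg` is not
    aligned, run on a clone and reflect any in-place mutation back."""
    misaligned = assumed_aligned and arg.data_ptr() % 16 != 0
    if not misaligned:
        return compiled_fn(arg)
    work = arg.clone()
    out = compiled_fn(work)
    arg.copy_(work)  # reflect the mutation back to the original input
    return out
```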

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154442
Approved by: https://github.com/bobrenjc93, https://github.com/bdhirsh
2025-05-29 13:36:48 +00:00
Shangdi Yu
04a6fe7914 Update provenance tracking doc (#154062)
Summary: Update the doc to reflect the changes in https://github.com/pytorch/pytorch/pull/153584/files#diff-e0cdb58c0f84f56f20c5433339b6d83c470dcde47847e2328effea6bedd4cd27 and https://github.com/pytorch/tlparse/pull/110

Test Plan: CI

Differential Revision: D75155981

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154062
Approved by: https://github.com/svekars, https://github.com/desertfire
2025-05-23 17:09:52 +00:00
Rachel Guo
cad0727fe1 Rename the provenance tracing artifact name for kernel <-> post_grad nodes mapping (#154046)
Summary:
Context:

Recently we've added a couple more kernel types support other than inductor generated triton kernels,

such as cpu cpp kernels, extern kernels.

The name appeared in tlparse chrome link can be confusing to users.

Rename from

`inductor_triton_kernel_to_post_grad_nodes.json`

to `inductor_generated_kernel_to_post_grad_nodes.json`

Test Plan: CI

Differential Revision: D75159042

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154046
Approved by: https://github.com/yushangdi
2025-05-22 19:20:56 +00:00
Gabriel Ferns
254293b777 Add flag _metrics_log_runtime to disable runtime metric logging by default (#153506)
https://github.com/pytorch/pytorch/pull/152708 expanded support of `get_estimated_runtime` to many more types of `SchedulerNodes`. This caused an increase in compile time because we're always calling `get_estimated_runtime` to populate the metrics table. This PR adds a flag for this logging, which reduces the instruction count by 8%. Long term, we should probably merge metrics.py with TORCH_LOGS/tlparse (suggestion from @xmfan).

Update: added support for TORCH_LOGS for the metrics logging.
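A minimal usage sketch, assuming the flag lives in `torch._inductor.config` under the name from this PR:

```python
from torch._inductor import config as inductor_config

# Runtime-metric logging is now off by default; flip the flag back on
# (or use the TORCH_LOGS support mentioned above) when the table is needed.
inductor_config._metrics_log_runtime = True
```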

Test Plan:
mm_loop.py and many existing tests cover.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153506
Approved by: https://github.com/eellison
2025-05-22 01:02:11 +00:00
PaulZhang12
a7c01d7f13 [Inductor] Subgraph check output strides (#153755)
Make sure the output strides of the subgraph are consistent with the original gm. Without checking strides, it was possible for the subgraph to produce NaNs when a reinterpret_tensor was applied to the subgraph's output and that output was not contiguous.
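A hedged sketch of the kind of stride check being enforced (not the actual Inductor implementation):

```python
def check_output_strides(original_outputs, subgraph_outputs):
    """Require each subgraph output to keep the stride layout of the
    corresponding output in the original graph module."""
    for i, (orig, sub) in enumerate(zip(original_outputs, subgraph_outputs)):
        assert orig.stride() == sub.stride(), (
            f"output {i}: subgraph stride {sub.stride()} != expected {orig.stride()}"
        )
```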

Differential Revision: [D74691119](https://our.internmc.facebook.com/intern/diff/D74691119/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153755
Approved by: https://github.com/eellison
ghstack dependencies: #153754
2025-05-20 16:07:18 +00:00
Menglu Yu
701e22112d [PT2][Optimus][Observability] Refactor the logging to avoid excessive tlparse log (#153584)
Summary: context: https://fb.workplace.com/groups/943185660584207/permalink/1215335930035844/

Test Plan:
before: aps-aps-ig_v4_2t_2_make_baseline_30batch-735703723-f735706162

tlparse: https://manifold.edge.x2p.facebook.net/v0/read/tree/logs/aps-aps-ig_v4_2t_2_make_baseline_30batch-735703723-f735706162/attempt_0/version_0/rank_0/index.html?bucketName=tlparse_reports&apiKey=tlparse_reports-key&withPayload=1&timeoutMsec=10000&fbclid=IwZXh0bgNhZW0CMTEAAR575JfJZUtE7kQCqzIZVCYomv1q03JzuMFVok8qDA_FuGC8oZ6rhhb2EziSQA_aem_abITQJZQP45t51_r-J-cFw

Differential Revision: D74776025

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153584
Approved by: https://github.com/jamesjwu
2025-05-19 22:57:29 +00:00
PyTorch MergeBot
3443627e07 Revert "[BE]: Enable RUFF TRY400 rule - log.exception (#153473)"
This reverts commit 4f4ecc583e.

Reverted https://github.com/pytorch/pytorch/pull/153473 on behalf of https://github.com/jeanschmidt due to seems to have broken internal signals, @albanD may I count on you to help the author merge his PR? D74837988 ([comment](https://github.com/pytorch/pytorch/pull/153473#issuecomment-2886017075))
2025-05-16 08:29:26 +00:00
Aaron Gokaslan
4f4ecc583e [BE]: Enable RUFF TRY400 rule - log.exception (#153473)
Change logging.error to logging.exception to log additional information when relevant.  A few places have slipped in logging.errors in try except since I last did a clean up here and the rule is stabilized so I am enabling it codebase wide. I have NOQA'd much of our custom exception stack trace handling for RPC calls and distributed and tried to a fix a few errors based on whether we immediately reraised it or if we didn't print any exception handling where it could be useful.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153473
Approved by: https://github.com/albanD, https://github.com/cyyever
2025-05-15 13:36:59 +00:00
clr
85f97b5a8c compile_fx: make a compile event that corresponds to the fx_compile waitcounter (#152983)
This is a pretty minor change, but by having exact correspondence, we can easily confirm data differences between perfetto and wait counters.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152983
Approved by: https://github.com/jansel, https://github.com/masnesral
2025-05-14 01:54:42 +00:00
Animesh Jain
7fdd754136 [compile-time traces] Profile large missing gaps in compile time (#151256)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/151256
Approved by: https://github.com/bdhirsh, https://github.com/masnesral, https://github.com/zou3519, https://github.com/jansel
2025-05-13 14:44:51 +00:00
PyTorch MergeBot
01bb249978 Revert "has_triton: Use the device interface for detecting Triton availability (#139171)"
This reverts commit 48bfe9afc7.

Reverted https://github.com/pytorch/pytorch/pull/139171 on behalf of https://github.com/masnesral due to Performance regression for huggingface ([comment](https://github.com/pytorch/pytorch/pull/139171#issuecomment-2868939790))
2025-05-10 14:46:23 +00:00
George White
48bfe9afc7 has_triton: Use the device interface for detecting Triton availability (#139171)
This PR replaces the `has_triton()` global method which was previously used for this task.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139171
Approved by: https://github.com/jansel, https://github.com/shink
2025-05-07 12:23:10 +00:00
Aaron Orenstein
7a0781eaad Improve cache key graph printing performance (#151928)
Teach the graph printer how to allow overriding printing SymTypes (`SymInt`, `SymFloat`, `SymBool`) and then use that to reuse the fast SymNode printing from `torch._inductor.utils.sympy_str()` to make computing the cache key faster.

On my computer the repro from #151823 goes from 480s -> 80s (still terrible... but better).

Fixes #151823

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151928
Approved by: https://github.com/laithsakka
2025-05-06 17:39:53 +00:00
Jovian Anthony Jaison
5d36485b4a Log aot and idx waitcounters. (#152444)
Summary:
Added for create_aot_dispatcher_function and compile_fx_inner.

Note:
Log wait counters flag is already set for:
1. async_compile.precompile
2. remote_fx_graph_cache_get
3. remote_fx_graph_cache_put

Test Plan: contbuild

Differential Revision: D73866124

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152444
Approved by: https://github.com/ppanchalia, https://github.com/masnesral
2025-05-06 16:21:58 +00:00
Animesh Jain
cc254eaa7c [inductor][refactor] Refactor the fetching of subgraph names (#152770)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/152770
Approved by: https://github.com/jansel, https://github.com/zou3519
ghstack dependencies: #152772
2025-05-06 02:55:34 +00:00
PyTorch MergeBot
172a7c942e Revert "Log aot and idx waitcounters. (#152444)"
This reverts commit ea9ea02959.

Reverted https://github.com/pytorch/pytorch/pull/152444 on behalf of https://github.com/jovianjaison due to needs a fix ([comment](https://github.com/pytorch/pytorch/pull/152444#issuecomment-2851905261))
2025-05-05 18:11:37 +00:00
Jovian Anthony Jaison
ea9ea02959 Log aot and idx waitcounters. (#152444)
Summary:
Added for create_aot_dispatcher_function and compile_fx_inner.

Note:
Log wait counters flag is already set for:
1. async_compile.precompile
2. remote_fx_graph_cache_get
3. remote_fx_graph_cache_put

Test Plan: contbuild

Differential Revision: D73866124

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152444
Approved by: https://github.com/ppanchalia, https://github.com/masnesral
2025-05-05 17:35:29 +00:00
Laith Sakka
38a9a8b7f7 Fix: Consider input defined unbacked during inductor codegen for runtime asserts (#152231)
When we use mark_unbacked, the graph has unbacked input SymInts. Right now, deferred runtime assertions that use those symbols are never generated.

This PR changes that: in the forward graph we now consider those symbols and generate the corresponding runtime assertions. We still ignore them for backward, which is not ideal.

The way we generate runtime assertions is by emitting them once all the unbacked symbols they use have been seen.

We previously skipped placeholders because, for backward, we have a wacky approach where we ignore input-defined unbacked symbols, assume that assertions using them were already emitted in forward, and try to emit all other runtime assertions again; see [Note [Backwards runtime asserts]].

Doing that, we end up only emitting the runtime assertions that depend on things defined solely in backward, but we could miss checks that span inputs defined in both forward and backward (i.e., one symbol defined in forward and passed as input to backward, and another defined in backward). This is not ideal; a better approach could be something like https://github.com/pytorch/pytorch/pull/151919, but that requires more work.
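For context, a small repro shape of the scenario (illustrative; whether the assert is emitted in the generated code is exactly what this PR changes):

```python
import torch

def f(x):
    # With an unbacked input size, this becomes a deferred runtime assert.
    torch._check(x.shape[0] >= 2)
    return x.reshape(-1) + 1

x = torch.randn(8, 4)
# Make dim 0 an unbacked SymInt input instead of a backed/specialized size.
torch._dynamo.decorators.mark_unbacked(x, 0)
torch.compile(f, fullgraph=True)(x)
```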

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152231
Approved by: https://github.com/aorenste
2025-05-02 07:01:48 +00:00
henrylhtsang
1845df05c6 [inductor][BE] Add more debug logs for why fx graph cache doesn't happen (#152487)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/152487
Approved by: https://github.com/Skylion007, https://github.com/eellison
2025-05-01 17:25:28 +00:00
Blaine Burton Rister
7c63ddd817 [Inductor] Wrapper code refactors to prepare for FX codegen (#152391)
This PR contains some refactors from https://github.com/pytorch/pytorch/pull/146942, which help to enable Wrapper FX codegen:
1. Remove `OutputLine`, which is unused.
2. Add an attribute to the backend classes specifying whether they support caching.
3. Before compiling a graph, query the registered backends and check whether caching is supported.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152391
Approved by: https://github.com/jansel
2025-05-01 09:14:55 +00:00
PyTorch MergeBot
49a72011cc Revert "[inductor][BE] Add more debug logs for why fx graph cache doesn't happen (#152487)"
This reverts commit 76331657d2.

Reverted https://github.com/pytorch/pytorch/pull/152487 on behalf of https://github.com/malfet due to And it broke those tests, not sure why signal was ignored ([comment](https://github.com/pytorch/pytorch/pull/152487#issuecomment-2843333471))
2025-04-30 21:35:17 +00:00
rzou
22ecaeb145 [standalone_compile] fix dynamic shapes with config_patches (#152462)
compile_fx with config_patches goes down another path where we need to
propagate the kwarg...

Test Plan:
- updated test

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152462
Approved by: https://github.com/oulgen
2025-04-30 21:02:14 +00:00
henrylhtsang
76331657d2 [inductor][BE] Add more debug logs for why fx graph cache doesn't happen (#152487)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/152487
Approved by: https://github.com/Skylion007, https://github.com/eellison
2025-04-30 20:05:21 +00:00
Brian Hirsh
4a63cab624 [cudagraphs] Fix issue in collecting static_input_idxs (#152287)
related to https://github.com/pytorch/pytorch/issues/152275

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152287
Approved by: https://github.com/bdhirsh, https://github.com/eellison

Co-authored-by: Brian Hirsh <hirsheybar@fb.com>
2025-04-30 03:24:05 +00:00
PyTorch MergeBot
a6d19fcfac Revert "[cudagraphs] Fix issue in collecting static_input_idxs (#152287)"
This reverts commit 75a564608a.

Reverted https://github.com/pytorch/pytorch/pull/152287 on behalf of https://github.com/wdvr due to causing ao failures - discussed with author ([comment](https://github.com/pytorch/pytorch/pull/152287#issuecomment-2837686127))
2025-04-29 06:57:06 +00:00