Commit Graph

94246 Commits

Author SHA1 Message Date
Yuanyuan Chen
ce6b589545 Enable B904 check of flake8 (#165047)
The description of `B904` is `Within an except clause, raise exceptions with raise ... from err or raise ... from None to distinguish them from errors in exception handling. `

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165047
Approved by: https://github.com/Lucaskabela
2025-10-10 03:08:01 +00:00
Dzmitry Huba
ae25dd51fc Simplifying computation of the final result for equals op on DTensor (#164999)
Instead of collecting local results using all_gather_object followed by local reduction, with this change we switch to using a single all_reduce with MIN reduction operation to compute the final equals result.

This change is needed to enable LocalTensor work (all_gather_object introduces challenges in for DTensor and LocalTensor integration).

topic: not user facing

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164999
Approved by: https://github.com/ezyang
2025-10-10 03:01:28 +00:00
Simon Fan
a61d0de9f9 [hop] support local_map filtered gradients (#164437)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164437
Approved by: https://github.com/ezyang
ghstack dependencies: #164296, #164321, #164419, #164420, #164340, #163602, #164431, #164433
2025-10-10 02:34:27 +00:00
Simon Fan
3ad88924ad [hop] support local_map None placements (#164433)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164433
Approved by: https://github.com/ezyang
ghstack dependencies: #164296, #164321, #164419, #164420, #164340, #163602, #164431
2025-10-10 02:34:27 +00:00
Simon Fan
3241b9c15f [hop] support local_map None gradients (#164431)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164431
Approved by: https://github.com/bdhirsh
ghstack dependencies: #164296, #164321, #164419, #164420, #164340, #163602
2025-10-10 02:34:27 +00:00
Simon Fan
25d4d5107e [dynamo] trace local_map with local shapes for AP (#163602)
Context is in https://www.internalfb.com/excalidraw/EX519691 and https://docs.google.com/document/d/1qnuXLZk_GYt_PksHTwkn7L2ELRDnYlIRPkHAlXTyuhw/edit?tab=t.0. And the description of the previous PR: https://github.com/pytorch/pytorch/pull/164340.

The previous PR adds the support on the HOP side for eager execution and AOTAutograd. Dynamo is still passing the HOP a subgraph with wrong shapes. This PR fixes that. This is similar to the HOP implementation, however we additionally need to manually keep the TensorVariable metadata in sync.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163602
Approved by: https://github.com/ydwu4
ghstack dependencies: #164296, #164321, #164419, #164420, #164340
2025-10-10 02:34:27 +00:00
Simon Fan
e4fe811be8 [hop] trace local_map with local shapes in fake key (#164340)
Context is in https://www.internalfb.com/excalidraw/EX519691 and https://docs.google.com/document/d/1qnuXLZk_GYt_PksHTwkn7L2ELRDnYlIRPkHAlXTyuhw/edit?tab=t.0.

So for Autoparallel initial trace, we want to trace the graph with global shapes initially. But, for the local_map region, we are forced to trace with the expected local tensors. To the tracers, this looks weird, because it's a plain tensor input (representing DTensor's full tensor .to_local()) that we need to "redistribute".

After hacking a miserable version that had cross-key dependencies, @ydwu4 proposed this simpler approach to override the fake key. This means the shape conversion will be invisible to all dispatch keys above fake, this covers all current tracing mechanisms. This manifests as the joint graph for the HOP body being traced with local shapes:
```python
# HOP forward, note local shapes (10, 80)
class GraphModule(torch.nn.Module):
    def forward(self, primals_0: "f32[10, 80]"):
        # No stacktrace found for following nodes
        view: "f32[800]" = torch.ops.aten.view.default(primals_0, [-1]);  primals_0 = None
        add: "f32[800]" = torch.ops.aten.add.Tensor(view, 10);  view = None
        view_1: "f32[10, 80]" = torch.ops.aten.view.default(add, [10, 80]);  add = None
        return (view_1,)

# HOP backward, note local shapes (10, 80)
class GraphModule(torch.nn.Module):
    def forward(self, tangents_0: "f32[10, 80]"):
        # No stacktrace found for following nodes
        clone: "f32[10, 80]" = torch.ops.aten.clone.default(tangents_0);  tangents_0 = None
        return (clone,)
```

while the rest of the graph is still traced with global shapes:
```python
# Parent graph joint, note global shapes (80, 80)
class inner_f(torch.nn.Module):
    def forward(self, primals, tangents):
        primals_1: "f32[80, 80]"; tangents_1: "f32[80, 80]";

        primals_1, tangents_1, = fx_pytree.tree_flatten_spec([primals, tangents], self._in_spec)
         # File: /home/xmfan/core/a/pytorch/test/higher_order_ops/test_local_map.py:597 in forward, code: return fn(x)
        call_local_map = torch._higher_order_ops.local_map.call_local_map(primals_1);  primals_1 = None
        getitem: "f32[80, 80]" = call_local_map[0];  call_local_map = None
        call_local_map_1 = torch._higher_order_ops.local_map.call_local_map(tangents_1);  tangents_1 = None
        getitem_1: "f32[80, 80]" = call_local_map_1[0];  call_local_map_1 = None
        return pytree.tree_unflatten([getitem, getitem_1], self._out_spec)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164340
Approved by: https://github.com/ydwu4
ghstack dependencies: #164296, #164321, #164419, #164420
2025-10-10 02:34:27 +00:00
Simon Fan
82c71af59a [hop] local_map validate partitioned fw/bw wrt placements (#164420)
Reviewed GPT-5 Summary:

**Summary / Goal**
Add validation that partitioned forward/backward graphs respect placements.

**Details**
- Validates placement alignment in local_map.
- The HOP's autograd key gets called when we are tracing the joint, we need to validate:
  - the inputs to the HOP's fwd gm (typically this is the dynamo rewritten inputs)
  - the inputs to the HOP partitioned fwd/bwd gm
  - the outputs of the HOP partitioned fwd/bwd gm

**Motivation**
Catch mismatch errors earlier, improve debugging.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164420
Approved by: https://github.com/ezyang
ghstack dependencies: #164296, #164321, #164419
2025-10-10 02:34:27 +00:00
Simon Fan
7bd704a346 [hop] local_map fix fw_gm/bw_gm naming (#164419)
Reviewed GPT5 summary:

**Summary / Goal**
Fix inconsistent variable naming for forward/backward graphs.

**Details**
- Those methods are actually for both fw and bw graphs now that we reuse the same op for fw/bw

**Motivation**
Improves clarity, avoids confusion.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164419
Approved by: https://github.com/bdhirsh
ghstack dependencies: #164296, #164321
2025-10-10 02:34:27 +00:00
Simon Fan
ae139b73e0 [dynamo] Better error message for local_map subgraph mismatches number of inputs/outputs with placement info (#164321)
Reviewed GPT5 summary:

**Summary / Goal**
Improve error reporting when local_map subgraph input/output counts mismatch placement info.

**Details**
- Adds descriptive runtime error messages.

**Motivation**
Helps debug local_map misalignments.

```python
AssertionError: Expecting 2 inputs to local_map function based on placements, but found 1. If the count matches for eager, Dynamo may have flattened inputs to the function or found additional tensors used via closures. Please adjust the input placements to match what the traced graph sees:
class GraphModule(torch.nn.Module):
    def forward(self, l_args_0_: "f32[8, 8, 16]"):
         # File: /home/xmfan/core/a/pytorch/test/higher_order_ops/test_local_map.py:523 in mismatch_input, code: return x + scalar, scalar
        child: "f32[8, 8, 16]" = l_args_0_ + 10;  l_args_0_ = None
        return (child,)
        .
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164321
Approved by: https://github.com/ezyang, https://github.com/mlazos
ghstack dependencies: #164296
2025-10-10 02:34:27 +00:00
Simon Fan
cbaa07e438 [dtensor] add util to compute expected local sizes/strides for even sharding (#164296)
Reviewed GPT5 summary:

**Summary / Goal**
Add a utility to compute expected local tensor sizes and strides under *even sharding* in dtensor.

**Details**
- New function in `torch/distributed/tensor/_utils.py`.
- Computes local sizes/strides given global shape, mesh, and placements.
- Enforces divisibility of global dimension by mesh size (strict even sharding).
- Complements `compute_global_tensor_info`.

**Motivation**
Ensures correctness for stride/layout computations in distributed tensors.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164296
Approved by: https://github.com/ezyang
2025-10-10 02:34:27 +00:00
Yuanyuan Chen
bc0e2a0d2b Fix a condition error in torch/_inductor/codegen/debug_utils.py (#165033)
This PR fixes the condition
```
if arg_signatures is None and self.kernel_type == "cpp" or "extern"
```
which is interpreted as
```
if (arg_signatures is None and self.kernel_type == "cpp") or ("extern"):
```
and it is always evaluated to `True`. According to the context the intention was
```
if arg_signatures is None and (self.kernel_type == "cpp" or self.kernel_type == "extern"):
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165033
Approved by: https://github.com/Skylion007
2025-10-10 02:20:00 +00:00
drisspg
0747d95994 Add Loads from fixed inputs (#162031)
## TODO
Check on multi indices
```Python

    @cute.jit
    def score_mod(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, buffers):
        in_ptr4 = buffers[0]
        tmp0 = tSrS_ssa
        tmp1 = b_idx
        tmp2 = h_idx
        tmp3 = cute.make_fragment(1, cutlass.Int32)
        tmp4 = tmp3.store(32*tmp1 + tmp2)
        tmp5 = cute.make_fragment(1, cutlass.BFloat16)
        tmp6 = tmp3[0]
        tmp7 = tmp5[0] = (in_ptr4[tmp6])
        tmp8 = (tmp5.load()).to(cutlass.Float32)
        tmp9 = (tmp0 + tmp8)
        tSrS_ssa = tmp9

        return tSrS_ssa

 ```

I dont think that
```
        tmp4 = tmp3.store(32*tmp1 + tmp2)
        tmp5 = cute.make_fragment(1, cutlass.BFloat16)
        tmp6 = tmp3[0]
        tmp7 = tmp5[0] = (in_ptr4[tmp6]

```

 is right since this tmp6 value will be larger than the actual index dim int his case its B -> see if its possible to 1d index

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162031
Approved by: https://github.com/v0i0
ghstack dependencies: #161118
2025-10-10 01:23:37 +00:00
drisspg
0a2cde2f06 Add Flash Attention support to FlexAttention (#161118)
Relies on this PR in Flash Attention: https://github.com/Dao-AILab/flash-attention/pull/1840

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161118
Approved by: https://github.com/v0i0
2025-10-10 01:23:37 +00:00
Jithun Nair
c7b57d9349 Add gfx1100 to build target for ROCm docker builds (#165103)
Fixes issue of gfx1100 test jobs timing out

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165103
Approved by: https://github.com/jeffdaily
2025-10-10 01:18:56 +00:00
PyTorch MergeBot
7614338b69 Revert "Add SVE128 ISA (#158932)"
This reverts commit 92284fb2ff.

Reverted https://github.com/pytorch/pytorch/pull/158932 on behalf of https://github.com/malfet due to Hmm, but from OSS point of view, this is a no-op ([comment](https://github.com/pytorch/pytorch/pull/158932#issuecomment-3387961238))
2025-10-10 01:17:02 +00:00
Edward Z. Yang
a6fa4f9c28 Do not decompose in functionalization/proxy tensor if autograd wouldn't have decomposed (#164939)
This fixes AOTAutograd rms_norm not being bitwise equivalent to
eager, because it avoids a decomposition.  You can force the
decomposition by having the decomposition in the dispatch table,
but if eager mode wouldn't have decomposed (because it went to the fused
one), we now default to preserving the fused call by default.

This largely reverts https://github.com/pytorch/pytorch/pull/103275/ for view ops. This means that in inference mode we could hit the wrong C++ kernel; if this occurs we should just SymInt'ify the C++ kernel.

Another neat side effect of this change is that Inductor's generated kernels for rms_norm now have rms_norm in their name.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164939
Approved by: https://github.com/bdhirsh
2025-10-10 00:15:00 +00:00
Shunting Zhang
344e6365a0 [inductor][eazy] change how torch.use_deterministic_algorithms affect inductor (#164905)
Previously when torch.are_deterministic_algorithms_enabled() is True Inductor will
- skip autotuning pointwise kernels
- pick a fixed (and quite arbitrary) config for reduction

This PR change the behavior to
- for pointwise kernels, we still do autotuning
- for reduction kernels, we use the recent added heuristic to pick a config

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164905
Approved by: https://github.com/jansel, https://github.com/v0i0
ghstack dependencies: #164801, #164532, #164904
2025-10-10 00:00:58 +00:00
Shunting Zhang
a3c700656f [inductor] verify determinism with inductor benchmark script (#164904)
Verify the deterministic mode with torch.compile benchmark scripts.

Here is what my testing script does (pasted in the end):
- run a model in default mode, save it's result
- run the model again in default mode, but distort the benchmarking results. Compare it with the saved result.
- Do the above again in deterministic mode.

I tried to test a few modes
- BertForMaskedLM and GoogleFnet: I can repro the numeric change by distorting the benchnmark result in the default mode. The non-determinism is gone in the deterministic mode
- DistillGPT2: I can not repro the numeric change by distorting the benchmarking result in the default mode. It does not surprise me much. Reduction order change does not always cause numeric change.

```
model=GoogleFnet

export TORCHINDUCTOR_WRITE_ARE_DETERMINISTIC_ALGORITHMS_ENABLED=0
export TORCHINDUCTOR_FORCE_DISABLE_CACHES=1  # disable autotune cache
export TORCHINDUCTOR_FX_GRAPH_REMOTE_CACHE=0
export TORCHINDUCTOR_FX_GRAPH_CACHE=0
export TORCHINDUCTOR_CACHE_DIR=/tmp/torchinductor_shunting/
export TORCHINDUCTOR_BENCHMARK_KERNEL=1
export TORCHINDUCTOR_UNIQUE_KERNEL_NAMES=1
export INDUCTOR_TEST_DISABLE_FRESH_CACHE=1

# Non deterministic mode
# --float32 rather than --amp to make it easier to repro non-deterministic
echo "Save results for non-deterministic mode"
python benchmarks/dynamo/huggingface.py --backend inductor --float32 --accuracy --only $model --training --disable-cudagraphs --save-model-outputs-to=/tmp/saved-non-deterministic.pkl

echo "Compare results with distorted benchmarking in non-deterministic mode"
TORCHINDUCTOR_DISTORT_BENCHMARKING_RESULT=inverse python benchmarks/dynamo/huggingface.py --backend inductor --float32 --accuracy --only $model --training --disable-cudagraphs --compare-model-outputs-with=/tmp/saved-non-deterministic.pkl

echo "Save results for deterministic mode"
TORCHINDUCTOR_DETERMINISTIC=1 python benchmarks/dynamo/huggingface.py --backend inductor --float32 --accuracy --only $model --training --disable-cudagraphs --save-model-outputs-to=/tmp/saved-deterministic.pkl

echo "Compare results with distorted benchmarking in deterministic mode"
TORCHINDUCTOR_DETERMINISTIC=1 TORCHINDUCTOR_DISTORT_BENCHMARKING_RESULT=inverse python benchmarks/dynamo/huggingface.py --backend inductor --float32 --accuracy --only $model --training --disable-cudagraphs --compare-model-outputs-with=/tmp/saved-deterministic.pkl
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164904
Approved by: https://github.com/jansel, https://github.com/v0i0
ghstack dependencies: #164801, #164532
2025-10-10 00:00:58 +00:00
Yidi Wu
600db525bd [easy][while_loop] use copy_input instead of clone in _clone_aliased_inputs (#164955)
Compared with clone, ExternKernel.copy_input additionally realize the buffer, which downstream assumes the input buffer are realized.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164955
Approved by: https://github.com/BoyuanFeng
2025-10-09 23:39:00 +00:00
Animesh Jain
f6de195616 [dynamo][trace_rules] Add ao.quantization (#165069)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165069
Approved by: https://github.com/tugsbayasgalan, https://github.com/mlazos
2025-10-09 23:08:42 +00:00
angelayi
4a0df39f81 Symintify fused_scaled_matmul_reduce_scatter (#165086)
Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165086
Approved by: https://github.com/zou3519, https://github.com/Skylion007
2025-10-09 23:07:40 +00:00
PyTorch MergeBot
34ac9b61cb Revert "[export] Turn on install_free_tensors flag (#164691)"
This reverts commit 0e9b3a772a.

Reverted https://github.com/pytorch/pytorch/pull/164691 on behalf of https://github.com/izaitsevfb due to breaks tests internally, author asked to revert, see [D84230990](https://www.internalfb.com/diff/D84230990) ([comment](https://github.com/pytorch/pytorch/pull/164691#issuecomment-3387718323))
2025-10-09 22:53:50 +00:00
Jeff Daily
9aa92f246f Hotfix test scaled matmul cuda (#165104)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165104
Approved by: https://github.com/jeffdaily

Co-authored-by: Jeff Daily <jeff.daily@amd.com>
2025-10-09 22:51:30 +00:00
Tugsbayasgalan Manlaibaatar
a57a14868d Better handling of restore_state_dict (#164401)
After lean export, we might want to be able to restore the original fqn. This PR refactors one util function in export that sort of does this. Note that strict_export has some complicated logic of updating the graph signature as well which we don't want. I think we can gradually make this util more refined by handling constants, non persistent buffers etc and change how strict_export does it today.

Differential Revision: [D83687844](https://www.internalfb.com/diff/D83687844)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164401
Approved by: https://github.com/avikchaudhuri
2025-10-09 22:39:11 +00:00
PyTorch MergeBot
47956196d9 Revert "Call internal log_compilation_event if it exists (#164855)"
This reverts commit 98a081a24c.

Reverted https://github.com/pytorch/pytorch/pull/164855 on behalf of https://github.com/albanD due to We should not land this kind of code in core ([comment](https://github.com/pytorch/pytorch/pull/164855#issuecomment-3387692988))
2025-10-09 22:38:45 +00:00
Nikita Shulga
6d27a8e509 [CD] Do not propagate download.pytorch.org IP into container (#165075)
Followup after https://github.com/pytorch/pytorch/pull/164969

Should fix binary build test failures
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165075
Approved by: https://github.com/seemethere, https://github.com/huydhn
ghstack dependencies: #164968, #164969
2025-10-09 21:59:31 +00:00
Eddie Yan
cd62a73dcb [cuDNN][SDPA] Handle noncontig nested tensors in cuDNN SDPA (#164958)
Previously we hardcoded the assumption in cuDNN that the inputs would be dense which breaks when e.g., the user is chunking tensors yielding noncontig inputs

New test added to check this  when `TORCH_CUDNN_SDPA_NESTED_TENSOR_ENABLED=1` is set in `test/test_transformers.py`

One issue I noticed was that the old gating of nested tensor in `sdp_utils.cpp` seems to be a no-op? All of the inputs are reported as "dense" by the time that function is called in the nested tensor tests in `test/test_nestedtensor.py -k sdpa`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164958
Approved by: https://github.com/Skylion007, https://github.com/drisspg
2025-10-09 21:58:54 +00:00
PyTorch MergeBot
4d7f9f3aed Revert "[ATen] Fix CUDA reduction warp shuffle order (#164790)"
This reverts commit 8e1f409b8c.

Reverted https://github.com/pytorch/pytorch/pull/164790 on behalf of https://github.com/jeffdaily due to broke cuda and rocm ci ([comment](https://github.com/pytorch/pytorch/pull/164790#issuecomment-3387558806))
2025-10-09 21:36:10 +00:00
William Wen
2b9ff99535 [flex attention] change "==" to "is" in inspect parameter comparison (#165003)
Patch for https://github.com/pytorch/pytorch/issues/164760.

This doesn't actually fix the underlying torch function issue though.

Explanation: `is` is traced differently compared to `__eq__`, so we end up avoiding the issue where we attempt to evaluate `torch.eq(tensor, inspect._empty)` in the first place.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165003
Approved by: https://github.com/mlazos
2025-10-09 21:18:05 +00:00
Sam Larsen
98a081a24c Call internal log_compilation_event if it exists (#164855)
Summary: For internal conda on mast jobs, call the internal version of log_compilation_event if it exists.

Test Plan: Ran a simple test job that just calls the API: https://fburl.com/scuba/dynamo_compile/dqx8d10g
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164855
Approved by: https://github.com/c00w
2025-10-09 21:15:11 +00:00
Lakshay Garg
6c0125dbc0 Mark functions const in CUDACachingAllocator (#165007)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165007
Approved by: https://github.com/eqy
2025-10-09 20:53:58 +00:00
Murray Steele
0fd976b65c Enable mimalloc on non-Windows platforms and make default for AArch64 builds (#164741)
This change removes the Windows requirement for mimalloc builds, and makes mimalloc the default c10 system allocator for AArch64 builds. This significantly improves the performance of AArch64 builds of PyTorch as large allocations are better cached by mimalloc than glibc.

**Updated Results**

Torchbench FP32 eager Inference, 16 threads:
<img width="1510" height="733" alt="mimalloc-v2-fp32-diff" src="https://github.com/user-attachments/assets/7fe3ea0c-3b52-42e7-879b-612444479c90" />

Torchbench BF16 eager Inference, 16 threads:
<img width="1510" height="733" alt="mimalloc-v2-bf16-diff" src="https://github.com/user-attachments/assets/56469a72-9e06-4d57-ae2a-aeb139ca79a3" />

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164741
Approved by: https://github.com/fadara01, https://github.com/aditew01, https://github.com/malfet
2025-10-09 20:49:46 +00:00
Maggie Moss
9944cac6e6 Add suppressions to torch/_inductor (#165062)
Adds suppressions to pyrefly will typecheck clean: https://github.com/pytorch/pytorch/issues/163283

Split this directory into two PRs to keep them from being too large.

Test plan:
dmypy restart && python3 scripts/lintrunner.py -a
pyrefly check

step 1: delete lines in the pyrefly.toml file from the project-excludes field
step 2: run pyrefly check
step 3: add suppressions, clean up unused suppressions
before: https://gist.github.com/maggiemoss/4b3bf2037014e116bc00706a16aef199

after:
INFO 0 errors (6,884 ignored)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165062
Approved by: https://github.com/oulgen, https://github.com/mlazos
2025-10-09 20:34:20 +00:00
Nikita Shulga
e7fd296930 [CI] Add full debug build to trunk (#164974)
But not test, just import torch, as regression test for https://github.com/pytorch/pytorch/issues/164297

Test plan: Re-apply #164974 on top of this change and observer the failure in the workflows: https://github.com/pytorch/pytorch/actions/runs/18383302153/job/52375282838
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164974
Approved by: https://github.com/seemethere, https://github.com/clee2000, https://github.com/atalman
ghstack dependencies: #164968, #164969
2025-10-09 20:12:16 +00:00
Sam Larsen
fac85fcfb5 [inductor] custom_graph_pass.get_hash_for_files: don't hash paths (#165020)
Summary: We have an internal user where caching broke because the paths that are unzipped are probably different per host. We can't think of a use case where a path change matters when the file content has not changed, so removing this part

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165020
Approved by: https://github.com/oulgen
2025-10-09 20:07:53 +00:00
Natalia Gimelshein
228973df7f Fix channels-last dimension mapping in CUDA parallel_cat (#165023)
Fixes #164849
`dimension` was updated in-place, so for more than one batch of channels-last tensors the concat `dimension` for the second kernel launch was wrong

## Testing
- python -m compileall test/test_tensor_creation_ops.py

------
https://chatgpt.com/codex/tasks/task_e_68e708879b30832f89b10ae55faa68e8
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165023
Approved by: https://github.com/ezyang
2025-10-09 20:04:32 +00:00
PyTorch MergeBot
ed2d514ad8 Revert "Fix truediv numerics between eager and compile (#164144)"
This reverts commit 724463d5a2.

Reverted https://github.com/pytorch/pytorch/pull/164144 on behalf of https://github.com/malfet due to Not sure if it's related, but looks it triggered fuzzer compiler test failure, see a2f29bcd63/1 ([comment](https://github.com/pytorch/pytorch/pull/164144#issuecomment-3387288464))
2025-10-09 19:53:38 +00:00
Tianren Gao
a2f29bcd63 [inductor] Remove Repeated Code in Subgraph (#164892)
Discovered some repeated code blocks in the subgraph.py

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164892
Approved by: https://github.com/PaulZhang12
2025-10-09 19:16:02 +00:00
FFFrog
5390324984 [CodeClean] Replace std::runtime_error with TORCH_CHECK (#164129)
As the title stated.

**Changes**:
- torch/csrc/Module.cpp
- torch/csrc/utils.cpp
- torch/csrc/stable
- torch/lib/libshm
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164129
Approved by: https://github.com/albanD
2025-10-09 19:01:07 +00:00
Avik Chaudhuri
ae25ec569c reorder wrappers in aot_stage2_inference to match forward compile in aot_stage2_autograd (#165016)
In aot_stage2_autograd:
Before calling fw_compiler, we run pre_compile for the following wrappers:
* FakifiedOutWrapper
* FunctionalizedRngRuntimeWrapper

After, we run post_compile for the following wrappers:
 * EffectTokensWrapper
 * AOTDispatchSubclassWrapper
 * FunctionalizedRngRuntimeWrapper
 * FakifiedOutWrapper

In aot_stage2_inference:
Before calling inference compiler, we run pre_compile for the following wrappers (same as above):
 * FakifiedOutWrapper
 * FunctionalizedRngRuntimeWrapper

After, we run post_compile for the following wrappers  (different than above):
 * FunctionalizedRngRuntimeWrapper
 * FakifiedOutWrapper
 * EffectTokensWrapper
 * AOTDispatchSubclassWrapper

This PR makes both do the post_compiles in the same order.

Differential Revision: D84213657

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165016
Approved by: https://github.com/zhxchen17, https://github.com/bdhirsh
2025-10-09 18:36:04 +00:00
PaulZhang12
8e1f409b8c [ATen] Fix CUDA reduction warp shuffle order (#164790)
Typical warp shuffle reduction has the following pattern:
<img width="1138" height="501" alt="image" src="https://github.com/user-attachments/assets/3bd176dc-0ad2-4df6-90c7-06e467337166" />

which is exhibited in Triton generated by torch.compile:
<img width="663" height="403" alt="image" src="https://github.com/user-attachments/assets/7f9f36cd-b9eb-44c1-879e-b469668a2ea8" />

Switch the warp shuffle order to make bitwise equivalence between the 2 easier.
PTX difference between old and new, we see a few extra instructions: https://www.diffchecker.com/h6ly3INC/

Comparing the performance on different reduction operations, we see minimal differences. New represents the changes in this PR, old represents the past warp shuffle order:
```
Tensor Shape              Operation            New all dims (ms)       New dim=0 (ms)      New dim=1 (ms)     Old all dims (ms)    Old dim=0 (ms)      Old dim=1 (ms)
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 1024)              mean                 0.015817             0.016259             0.013642             0.015990             0.016258             0.013631
(1024, 1024)              sum                  0.015917             0.015906             0.013359             0.015707             0.016266             0.013226
(1024, 1024)              min                  0.016021             0.024625             0.015631             0.015761             0.024485             0.015317
(1024, 1024)              max                  0.016349             0.024971             0.015972             0.015771             0.025001             0.015314
(1024, 1024)              argmin               0.018070             0.024448             0.015578             0.018135             0.025370             0.015322
(1024, 1024)              argmax               0.018427             0.024859             0.015932             0.018164             0.024452             0.015639
(1024, 1024)              var                  0.020078             0.026413             0.020295             0.020199             0.026381             0.020214
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(2048, 2048)              mean                 0.023826             0.023726             0.022273             0.023236             0.023776             0.022248
(2048, 2048)              sum                  0.023840             0.023355             0.021974             0.023294             0.023354             0.021884
(2048, 2048)              min                  0.024519             0.041263             0.024620             0.023292             0.041491             0.024358
(2048, 2048)              max                  0.024509             0.041670             0.024277             0.023334             0.041231             0.024395
(2048, 2048)              argmin               0.026125             0.041282             0.024567             0.026772             0.041773             0.024296
(2048, 2048)              argmax               0.026117             0.041487             0.024572             0.026412             0.041477             0.024273
(2048, 2048)              var                  0.026603             0.048581             0.031308             0.027587             0.048603             0.030860
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(4096, 4096)              mean                 0.053927             0.057070             0.054073             0.053028             0.057544             0.053935
(4096, 4096)              sum                  0.053604             0.057410             0.054451             0.053076             0.057033             0.054266
(4096, 4096)              min                  0.054293             0.109122             0.058363             0.053821             0.108689             0.058382
(4096, 4096)              max                  0.054258             0.108035             0.058703             0.053492             0.110552             0.058376
(4096, 4096)              argmin               0.056805             0.111167             0.058301             0.056836             0.112325             0.058292
(4096, 4096)              argmax               0.056488             0.110958             0.058636             0.056844             0.111000             0.057928
(4096, 4096)              var                  0.058936             0.141755             0.068693             0.059735             0.141284             0.068500
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 8192)              mean                 0.145552             0.148082             0.138647             0.145364             0.147818             0.138207
(8192, 8192)              sum                  0.145985             0.147900             0.138714             0.145755             0.148031             0.138616
(8192, 8192)              min                  0.146566             0.205359             0.192739             0.145611             0.205237             0.182335
(8192, 8192)              max                  0.146526             0.204844             0.193050             0.146073             0.205457             0.182697
(8192, 8192)              argmin               0.150190             0.206605             0.192543             0.150654             0.206847             0.182007
(8192, 8192)              argmax               0.150481             0.206368             0.192535             0.150845             0.206430             0.182022
(8192, 8192)              var                  0.150884             0.184546             0.203900             0.151594             0.184172             0.197983
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1, 1024, 128)            mean                 0.014293             0.008119             0.014533             0.013861             0.008022             0.014449
(1, 1024, 128)            sum                  0.014039             0.007877             0.014111             0.014219             0.008227             0.014045
(1, 1024, 128)            min                  0.014159             0.011354             0.023493             0.014271             0.010862             0.023644
(1, 1024, 128)            max                  0.014154             0.011027             0.023368             0.014259             0.011234             0.023692
(1, 1024, 128)            argmin               0.016403             0.005677             0.023328             0.016273             0.005683             0.024073
(1, 1024, 128)            argmax               0.016734             0.005675             0.023437             0.016580             0.005318             0.023331
(1, 1024, 128)            var                  0.018338             0.009549             0.025538             0.018528             0.009391             0.024777
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(5, 1024, 128)            mean                 0.014873             0.010131             0.015546             0.015123             0.010131             0.015481
(5, 1024, 128)            sum                  0.015334             0.009673             0.015824             0.014736             0.009671             0.015438
(5, 1024, 128)            min                  0.015047             0.013252             0.024573             0.014803             0.013163             0.024551
(5, 1024, 128)            max                  0.015050             0.013339             0.024197             0.014810             0.013525             0.024230
(5, 1024, 128)            argmin               0.017341             0.012737             0.024306             0.017471             0.012379             0.024991
(5, 1024, 128)            argmax               0.017345             0.012411             0.024421             0.017422             0.012471             0.024237
(5, 1024, 128)            var                  0.019973             0.011453             0.026188             0.020050             0.011438             0.026282
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(10, 1024, 128)           mean                 0.016976             0.011575             0.016831             0.016722             0.011927             0.017173
(10, 1024, 128)           sum                  0.017039             0.011841             0.017159             0.016385             0.011860             0.016753
(10, 1024, 128)           min                  0.017036             0.015331             0.026770             0.016944             0.015205             0.027166
(10, 1024, 128)           max                  0.017369             0.015348             0.027077             0.016531             0.015716             0.026819
(10, 1024, 128)           argmin               0.019203             0.014447             0.026813             0.018994             0.014497             0.027313
(10, 1024, 128)           argmax               0.019563             0.014795             0.027140             0.019460             0.014912             0.026733
(10, 1024, 128)           var                  0.020529             0.014316             0.030405             0.020719             0.013960             0.029964
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(100, 1024, 128)          mean                 0.045046             0.039168             0.046082             0.044839             0.039217             0.045782
(100, 1024, 128)          sum                  0.045094             0.039150             0.045777             0.044496             0.039542             0.046083
(100, 1024, 128)          min                  0.045768             0.054466             0.076244             0.044915             0.053943             0.076599
(100, 1024, 128)          max                  0.045748             0.054459             0.076188             0.044931             0.053949             0.076856
(100, 1024, 128)          argmin               0.048275             0.054046             0.076647             0.048694             0.054105             0.077004
(100, 1024, 128)          argmax               0.048267             0.054395             0.077401             0.048691             0.054131             0.076751
(100, 1024, 128)          var                  0.049710             0.043254             0.083077             0.050971             0.043251             0.082378
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1000, 1000, 100)         mean                 0.202312             0.196723             0.197765             0.201774             0.196641             0.197459
(1000, 1000, 100)         sum                  0.202651             0.196682             0.197736             0.202175             0.196313             0.197523
(1000, 1000, 100)         min                  0.203022             0.264762             0.269200             0.202729             0.264129             0.268694
(1000, 1000, 100)         max                  0.202864             0.264396             0.269388             0.202486             0.263896             0.268720
(1000, 1000, 100)         argmin               0.226727             0.263781             0.268651             0.226597             0.264676             0.268983
(1000, 1000, 100)         argmax               0.226412             0.264469             0.269090             0.226570             0.264595             0.269178
(1000, 1000, 100)         var                  0.243223             0.204079             0.216096             0.241942             0.204079             0.215925
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(10000, 100)              mean                 0.016193             0.020277             0.014316             0.016152             0.020324             0.013712
(10000, 100)              sum                  0.016289             0.020237             0.014034             0.016168             0.020265             0.013708
(10000, 100)              min                  0.016046             0.030872             0.019609             0.016208             0.030867             0.018627
(10000, 100)              max                  0.016369             0.030835             0.019257             0.016218             0.030861             0.018209
(10000, 100)              argmin               0.017957             0.031171             0.019517             0.018050             0.031556             0.018077
(10000, 100)              argmax               0.017961             0.031658             0.019521             0.018060             0.031564             0.018087
(10000, 100)              var                  0.020393             0.035652             0.019339             0.020144             0.035987             0.019171
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(100000, 10)              mean                 0.015718             0.016576             0.016555             0.015999             0.016246             0.014869
(100000, 10)              sum                  0.015833             0.016247             0.016572             0.016007             0.016627             0.014872
(100000, 10)              min                  0.015888             0.020510             0.023920             0.015671             0.020821             0.021417
(100000, 10)              max                  0.015889             0.020479             0.023918             0.016077             0.020386             0.021421
(100000, 10)              argmin               0.018233             0.020863             0.023647             0.017574             0.020864             0.021103
(100000, 10)              argmax               0.017896             0.020527             0.023296             0.017569             0.020447             0.021098
(100000, 10)              var                  0.020005             0.024198             0.024372             0.020075             0.024167             0.022415
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1023, 1023, 1023)        mean                 1.874816             1.963506             1.903909             1.873279             1.963859             1.903230
(1023, 1023, 1023)        sum                  1.875030             1.965716             1.902458             1.873566             1.960730             1.901642
(1023, 1023, 1023)        min                  1.878563             2.473455             2.179092             1.875174             2.482086             2.183027
(1023, 1023, 1023)        max                  1.879128             2.474803             2.178895             1.874831             2.482253             2.183884
(1023, 1023, 1023)        argmin               1.921800             2.476629             2.174831             1.923987             2.472641             2.170453
(1023, 1023, 1023)        argmax               1.922605             2.476688             2.177927             1.923366             2.472808             2.172979
(1023, 1023, 1023)        var                  1.972606             3.088695             2.758797             1.978679             3.095658             2.762243
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1023, 1023, 255)         mean                 0.489984             0.500954             0.492957             0.489891             0.500654             0.491971
(1023, 1023, 255)         sum                  0.490228             0.500764             0.492289             0.489624             0.501089             0.492824
(1023, 1023, 255)         min                  0.491457             0.563560             0.553334             0.490355             0.564709             0.554754
(1023, 1023, 255)         max                  0.491396             0.563628             0.553345             0.490017             0.565004             0.554947
(1023, 1023, 255)         argmin               0.503666             0.561512             0.551831             0.503845             0.560972             0.551017
(1023, 1023, 255)         argmax               0.503602             0.561185             0.551407             0.504328             0.561267             0.551448
(1023, 1023, 255)         var                  0.510844             0.709452             0.701630             0.512693             0.710365             0.701965
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1023, 1023, 377)         mean                 0.707439             0.727646             0.712019             0.706769             0.727101             0.711632
(1023, 1023, 377)         sum                  0.707780             0.727453             0.711554             0.706807             0.726656             0.711729
(1023, 1023, 377)         min                  0.709423             0.819809             0.794379             0.707847             0.822086             0.796664
(1023, 1023, 377)         max                  0.709297             0.819780             0.794308             0.707566             0.821913             0.796690
(1023, 1023, 377)         argmin               0.725028             0.817088             0.791695             0.726039             0.816445             0.790828
(1023, 1023, 377)         argmax               0.725301             0.817011             0.791420             0.726040             0.816917             0.791143
(1023, 1023, 377)         var                  0.740859             1.034165             1.006712             0.743413             1.035506             1.007638
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164790
Approved by: https://github.com/ngimel, https://github.com/eqy
2025-10-09 18:08:30 +00:00
Jithun Nair
ee6a1ecb0a [ROCm] Enable MI355 CI on PRs, and run full set of UTs on PRs (#160215)
Useful to have PR testing for PRs such as https://github.com/pytorch/pytorch/pull/151360

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160215
Approved by: https://github.com/malfet, https://github.com/atalman

Co-authored-by: Jeff Daily <jeff.daily@amd.com>
2025-10-09 18:03:12 +00:00
Lakshay Garg
3c0577bd15 Remove shared_ptr from MHAGraphCache (#164895)
This commit makes several cleanup changes to MHA.cpp, the main
one of which is removal of shared_ptr from MHAGraphCache as the
cache does not actually intend to share ownership. The changes are:

1. Remove shared_ptr from MHAGraphCache
2. Remove template arguments from MHAGraphCache
3. Remove unnecessary optional<shared_ptr<...>> vars
4. Change some functions with auto return type to the actual type

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164895
Approved by: https://github.com/eqy
2025-10-09 17:44:28 +00:00
PyTorch MergeBot
688efd9741 Revert "Enable mimalloc on non-Windows platforms and make default for AArch64 builds (#164741)"
This reverts commit 87eccf10e8.

Reverted https://github.com/pytorch/pytorch/pull/164741 on behalf of https://github.com/malfet due to But it breaks MacOS builds, see https://github.com/pytorch/pytorch/actions/runs/18382886648/job/52373781138 ([comment](https://github.com/pytorch/pytorch/pull/164741#issuecomment-3386859778))
2025-10-09 17:30:25 +00:00
PyTorch MergeBot
91040f4934 Revert "[Code Clean] Remove support of python3.9 (#163846)"
This reverts commit bc1690c7e8.

Reverted https://github.com/pytorch/pytorch/pull/163846 on behalf of https://github.com/izaitsevfb due to breaks distributed tests ([comment](https://github.com/pytorch/pytorch/pull/163846#issuecomment-3386855437))
2025-10-09 17:27:08 +00:00
Murray Steele
87eccf10e8 Enable mimalloc on non-Windows platforms and make default for AArch64 builds (#164741)
This change removes the Windows requirement for mimalloc builds, and makes mimalloc the default c10 system allocator for AArch64 builds. This significantly improves the performance of AArch64 builds of PyTorch as large allocations are better cached by mimalloc than glibc.

**Updated Results**

Torchbench FP32 eager Inference, 16 threads:
<img width="1510" height="733" alt="mimalloc-v2-fp32-diff" src="https://github.com/user-attachments/assets/7fe3ea0c-3b52-42e7-879b-612444479c90" />

Torchbench BF16 eager Inference, 16 threads:
<img width="1510" height="733" alt="mimalloc-v2-bf16-diff" src="https://github.com/user-attachments/assets/56469a72-9e06-4d57-ae2a-aeb139ca79a3" />

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164741
Approved by: https://github.com/fadara01, https://github.com/aditew01, https://github.com/malfet
2025-10-09 16:45:31 +00:00
Ryo Suzuki
5d459dd609 avoid bit cast for bfloat16_t (#159946)
using bit_cast<bfloat16_t> triggers a static_assert, so replace it with intrinsics.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159946
Approved by: https://github.com/aditew01, https://github.com/malfet
2025-10-09 16:42:49 +00:00
albanD
24d69c57cb Add view support for library custom Function (#164520)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164520
Approved by: https://github.com/soulitzer, https://github.com/ezyang
2025-10-09 16:17:48 +00:00
Catherine Lee
eaa02655ea [CI] Run cpp tests on windows in one run_tests call (#164861)
The windows cpp tests take ~1 hour according to logs.  Each has run_test called on them individually, so I tried batching them together so it's just one run_test call for all of them.  I believe it now takes 30min.  I turned off TD since I don't think cpp tests are included in TD stuff.

As always with batch, I'm not sure if the errorlevel/error surfacing stuff is correct

This code is written with a lot of help from chatgpu and copilot
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164861
Approved by: https://github.com/huydhn
2025-10-09 16:07:28 +00:00