Commit Graph

647 Commits

Author SHA1 Message Date
Laith Sakka
1aef88c72d Avoid DDE in narrow with unbacked start (#166361)
Slice knows how to handle an unbacked start, so we do not need to offset start before calling slice; we can leave that to slice.
The only edge case is when start < 0 and start + length == 0: in that case slice and narrow would deviate,
so for that case we pass dim_size instead of start + length.
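
For illustration, a minimal sketch of the decomposition described above (the helper name and the eager-mode `isinstance` check are illustrative, not the actual change in Inductor):

```py
import torch

def narrow_via_slice(x, dim, start, length):
    # Let slice handle the (possibly unbacked) start directly instead of
    # normalizing it up front.
    end = start + length
    # Edge case from above: with start < 0 and start + length == 0, slice's
    # end == 0 means "empty slice", while narrow means "to the end of the
    # dimension", so substitute dim_size for the end.
    if isinstance(start, int) and start < 0 and end == 0:
        end = x.size(dim)
    return torch.ops.aten.slice.Tensor(x, dim, start, end)
```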

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166361
Approved by: https://github.com/aorenste
2025-11-01 07:10:23 +00:00
Yuanyuan Chen
2de4cf2102 [1/N] Remove unused loop variables (#166258)
This PR removes unused loop variables.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166258
Approved by: https://github.com/Lucaskabela, https://github.com/mlazos
2025-10-30 12:22:25 +00:00
Shangdi Yu
d2eff5d454 Add python stack trace to AOTI generated code (#160539)
Summary:
We add a thread_local KernelContext object so Strobelight (and other potential profilers) can read the stack trace information of the running kernel.

This will bring extra overhead, so we guard this behind the `cpp.enable_kernel_profile` flag.
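
For reference, opting in looks like this (assuming the flag is toggled through `torch._inductor.config`, as with other Inductor cpp options):

```py
import torch._inductor.config as inductor_config

# Opt into the per-kernel context objects (adds some overhead, as noted above).
inductor_config.cpp.enable_kernel_profile = True
```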

Example output code:

```cpp
#include <torch/csrc/inductor/aoti_runtime/kernel_context_tls.h>
namespace torch::aot_inductor {
thread_local KernelContext* tls_kernel_context = nullptr;
}
// Other code .....
void AOTInductorModel::run_impl(
    AtenTensorHandle*
        input_handles, // array of input AtenTensorHandle; handles
                        // are stolen; the array itself is borrowed
    AtenTensorHandle*
        output_handles, // array for writing output AtenTensorHandle; handles
                        // will be stolen by the caller; the array itself is
                        // borrowed
    DeviceStreamType stream,
    AOTIProxyExecutorHandle proxy_executor
) {
    __check_inputs_outputs(input_handles, output_handles);
    auto inputs = steal_from_raw_handles_to_raii_handles(input_handles, 4);
    auto arg2_1 = std::move(inputs[0]);
    auto arg3_1 = std::move(inputs[1]);
    auto arg4_1 = std::move(inputs[2]);
    auto arg5_1 = std::move(inputs[3]);
    [[maybe_unused]] auto& fc1_weight = constants_->at(0);
    [[maybe_unused]] auto& fc1_bias = constants_->at(1);
    inputs.clear();
    [[maybe_unused]] auto& kernels = static_cast<AOTInductorModelKernels&>(*this->kernels_.get());
    static constexpr int64_t int_array_0[] = {8L, 16L};
    static constexpr int64_t int_array_1[] = {16L, 1L};
    AtenTensorHandle buf0_handle;
    AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_empty_strided(2, int_array_0, int_array_1, cached_torch_dtype_float32, cached_torch_device_type_cpu, this->device_idx_, &buf0_handle));
    RAIIAtenTensorHandle buf0(buf0_handle);
    // Topologically Sorted Source Nodes: [linear], Original ATen: [aten.t, aten.addmm]
    // [Provenance debug handles] aoti_torch_cpu_addmm_out:1
    static constexpr int64_t int_array_2[] = {10L, 16L};
    static constexpr int64_t int_array_3[] = {1L, 10L};
    {
    KernelContextGuard _ctx("aoti_torch_cpu_addmm_out", R"(
File "/data/users/shangdiy/fbsource/buck-out/v2/gen/fbcode/cba6f4fb5faa5f79/caffe2/test/inductor/__provenance_tracing__/provenance_tracing#link-tree/caffe2/test/inductor/test_provenance_tracing.py", line 829, in forward
    x = self.fc1(x)
  File "/data/users/shangdiy/fbsource/buck-out/v2/gen/fbcode/cba6f4fb5faa5f79/caffe2/test/inductor/__provenance_tracing__/provenance_tracing#link-tree/torch/nn/modules/linear.py", line 134, in forward
    return F.linear(input, self.weight, self.bias)
)");
    RAIIAtenRecordFunctionHandle record_aoti_torch_cpu_addmm_out_("aoti_torch_cpu_addmm_out", nullptr);
    AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_cpu_addmm_out(buf0, fc1_bias, arg2_1, wrap_with_raii_handle_if_needed(reinterpret_tensor_wrapper(fc1_weight, 2, int_array_2, int_array_3, 0L)), 1L, 1L));
    }
    arg2_1.reset();
    auto buf1 = std::move(buf0);  // reuse
    static constexpr int64_t int_array_4[] = {10L, 20L};
    static constexpr int64_t int_array_5[] = {20L, 1L};
    AtenTensorHandle buf2_handle;
    AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_empty_strided(2, int_array_4, int_array_5, cached_torch_dtype_float32, cached_torch_device_type_cpu, this->device_idx_, &buf2_handle));
    RAIIAtenTensorHandle buf2(buf2_handle);
    // [Provenance debug handles] cpp_fused_mul_relu_sigmoid_0:2
    {
    KernelContextGuard _ctx("cpp_fused_mul_relu_sigmoid_0", R"(
File "/data/users/shangdiy/fbsource/buck-out/v2/gen/fbcode/cba6f4fb5faa5f79/caffe2/test/inductor/__provenance_tracing__/provenance_tracing#link-tree/caffe2/test/inductor/test_provenance_tracing.py", line 831, in forward
    x = self.sigmoid(x)
  File "/data/users/shangdiy/fbsource/buck-out/v2/gen/fbcode/cba6f4fb5faa5f79/caffe2/test/inductor/__provenance_tracing__/provenance_tracing#link-tree/torch/nn/modules/activation.py", line 359, in forward
    return torch.sigmoid(input)
File "/data/users/shangdiy/fbsource/buck-out/v2/gen/fbcode/cba6f4fb5faa5f79/caffe2/test/inductor/__provenance_tracing__/provenance_tracing#link-tree/caffe2/test/inductor/test_provenance_tracing.py", line 830, in forward
    x = self.relu(x)
  File "/data/users/shangdiy/fbsource/buck-out/v2/gen/fbcode/cba6f4fb5faa5f79/caffe2/test/inductor/__provenance_tracing__/provenance_tracing#link-tree/torch/nn/modules/activation.py", line 144, in forward
    return F.relu(input, inplace=self.inplace)
File "/data/users/shangdiy/fbsource/buck-out/v2/gen/fbcode/cba6f4fb5faa5f79/caffe2/test/inductor/__provenance_tracing__/provenance_tracing#link-tree/caffe2/test/inductor/test_provenance_tracing.py", line 832, in forward
    d = a * 3.14
)");
    cpp_fused_mul_relu_sigmoid_0((float*)(buf1.data_ptr()), (const float*)(arg3_1.data_ptr()), (float*)(buf2.data_ptr()));
    }
    arg3_1.reset();
    static constexpr int64_t int_array_6[] = {10L, 30L};
    static constexpr int64_t int_array_7[] = {30L, 1L};
    AtenTensorHandle buf3_handle;
    AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_empty_strided(2, int_array_6, int_array_7, cached_torch_dtype_float32, cached_torch_device_type_cpu, this->device_idx_, &buf3_handle));
    RAIIAtenTensorHandle buf3(buf3_handle);
    // Topologically Sorted Source Nodes: [mul, addmm], Original ATen: [aten.mul, aten.addmm]
    // [Provenance debug handles] aoti_torch_cpu_addmm_out:3
    {
    KernelContextGuard _ctx("aoti_torch_cpu_addmm_out", R"(
File "/data/users/shangdiy/fbsource/buck-out/v2/gen/fbcode/cba6f4fb5faa5f79/caffe2/test/inductor/__provenance_tracing__/provenance_tracing#link-tree/caffe2/test/inductor/test_provenance_tracing.py", line 833, in forward
    y = torch.addmm(c, d, b)
)");
    RAIIAtenRecordFunctionHandle record_aoti_torch_cpu_addmm_out_("aoti_torch_cpu_addmm_out", nullptr);
    AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_cpu_addmm_out(buf3, arg5_1, buf2, arg4_1, 1L, 1L));
    }
    arg4_1.reset();
    arg5_1.reset();
    buf2.reset();
    auto buf4 = std::move(buf3);  // reuse
    // [Provenance debug handles] cpp_fused_gelu_1:4
    {
    KernelContextGuard _ctx("cpp_fused_gelu_1", R"(
File "/data/users/shangdiy/fbsource/buck-out/v2/gen/fbcode/cba6f4fb5faa5f79/caffe2/test/inductor/__provenance_tracing__/provenance_tracing#link-tree/caffe2/test/inductor/test_provenance_tracing.py", line 834, in forward
    z = torch.nn.functional.gelu(y)
)");
    cpp_fused_gelu_1((float*)(buf4.data_ptr()));
    }
    output_handles[0] = buf1.release();
    output_handles[1] = buf4.release();
} // AOTInductorModel::run_impl
```

Test Plan:
```
buck run mode/dev-nosan fbcode//caffe2/test/inductor:provenance_tracing -- -r  stack_traces
```

Rollback Plan:

Differential Revision: D78436007

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160539
Approved by: https://github.com/yiming0416
2025-10-29 22:47:52 +00:00
PyTorch MergeBot
1dd6b76914 Revert "[1/N] Remove unused loop variables (#166258)"
This reverts commit 76b2c37045.

Reverted https://github.com/pytorch/pytorch/pull/166258 on behalf of https://github.com/atalman due to breaks test/distributed/test_serialization.py::TestSerialization::test_weights_only [GH job link](https://github.com/pytorch/pytorch/actions/runs/18894311802/job/53929321703) [HUD commit link](76b2c37045) ([comment](https://github.com/pytorch/pytorch/pull/166258#issuecomment-3460964612))
2025-10-29 11:10:37 +00:00
Yuanyuan Chen
76b2c37045 [1/N] Remove unused loop variables (#166258)
This PR removes unused loop variables.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166258
Approved by: https://github.com/Lucaskabela, https://github.com/mlazos
2025-10-29 01:34:15 +00:00
Oguz Ulgen
8d4e48831e Remove JITFunction constexpr and some arg_names (#166280)
https://github.com/triton-lang/triton/pull/8536 breaks torch.compile integration. This PR attempts to fix it.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166280
Approved by: https://github.com/jansel
2025-10-27 09:29:03 +00:00
Maggie Moss
9940e894ea Fix pyrefly ignore syntax in _inductor (#166247)
Ensures that pyrefly ignore comments only suppress the intended error code.

pyrefly check
lintrunner

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166247
Approved by: https://github.com/oulgen
2025-10-27 02:48:42 +00:00
Blaine Burton Rister
f2bb22ff84 [Inductor-FX] Support Tensor.item (#165599)
# Feature
This PR supports compiling `Tensor.item` with Inductor's FX backend. This maps to a custom WrapperCodeGen method called `codegen_dynamic_scalar`.

# Implementation
The implementation is fairly mechanical, following the usual flow for these types of PRs.
1. Introduce a new Wrapper IR line for this, called `DynamicScalarLine`.
2. Split `PythonWrapperCodegen.codegen_dynamic_scalar` into 2 parts: a public method which generates the Wrapper IR line, and a private one generating Python from Wrapper IR.
3. Implement an FX codegen method for the wrapper IR line. This one calls `aten.where.Scalar` to handle code like `1 if x.item() else 0`, which is a bit tricky. It also calls `aten.item.default` to convert tensors to scalars.
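
A rough sketch of the FX-side codegen in step 3, as a standalone helper (graph construction only; the helper name and the surrounding converter state are illustrative):

```py
import torch
from torch.fx import Graph

def emit_dynamic_scalar(graph: Graph, tensor_node, is_bool: bool):
    if is_bool:
        # Route booleans through aten.where.Scalar so code like
        # `1 if x.item() else 0` becomes a 1/0-valued tensor first.
        tensor_node = graph.call_function(
            torch.ops.aten.where.Scalar, (tensor_node, 1, 0)
        )
    # Convert the 0-d tensor to a Python scalar.
    return graph.call_function(torch.ops.aten.item.default, (tensor_node,))
```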

# Test plan
Added CI tests mirroring the AOTI ones. They test float, int and bool types, the latter taking a distinct codegen path.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165599
Approved by: https://github.com/angelayi, https://github.com/jansel
2025-10-21 07:09:56 +00:00
Mu-Chu Lee
9fccbdd4f0 Fix incorrect function signature in template (#165567)
Summary:
In https://github.com/pytorch/pytorch/pull/148305 we refactored the grid
argument out, but it's not reflected in our template.

Test Plan:
Included in commit.
python test/inductor/test_aot_inductor.py
AOTInductorTestABICompatibleGpu.test_cond_symint_input_disable_one_pass_cuda

Reviewers:

Subscribers:

Tasks:

Tags:

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165567
Approved by: https://github.com/desertfire
2025-10-17 02:40:56 +00:00
Maggie Moss
5641de7b6b Add suppressions for _inductor/codegen (#165659)
Adds suppressions so that pyrefly will typecheck clean: https://github.com/pytorch/pytorch/issues/163283

Test plan:
dmypy restart && python3 scripts/lintrunner.py -a
pyrefly check

step 1: delete lines in the pyrefly.toml file from the project-excludes field
step 2: run pyrefly check
step 3: add suppressions, clean up unused suppressions
before: https://gist.github.com/maggiemoss/4b3bf2037014e116bc00706a16aef199

after:
INFO 0 errors (6,884 ignored)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165659
Approved by: https://github.com/oulgen
2025-10-16 21:37:37 +00:00
Boyuan Feng
f071f17911 [Graph Partition] fix partition x memory plan issue (#165514)
For `test_graph_partition_with_memory_plan_reuse`, before this PR, when using graph partition, it would error ([P1992728479](https://www.internalfb.com/phabricator/paste/view/P1992728479)):

```
def partition_0(args):
    ...
    del buf0
    return (buf3, buf4, buf5, buf2, primals_4, )

...

  File "/tmp/torchinductor_boyuan/ww/cwwc7ukfqscg2vy6ankby2fizdb377tvgyx3fwdgddrxe3g47jg6.py", line 132, in partition_0
    return (buf3, buf4, buf5, buf2, primals_4, )
                              ^^^^
NameError: name 'buf2' is not defined. Did you mean: 'buf0'?
```

When not using graph partition, it would work and give the following code ([P1992997521](https://www.internalfb.com/phabricator/paste/view/P1992997521)):

```
def call(self, args):
    ...
    buf2 = buf0; del buf0  # reuse
    ...
```

Note that the issue is that buf0 is not reused for buf2 when using graph partition.

Why? Because the codegen runs `run_wrapper_ir_passes` and `memory_plan_reuse`, which pop trailing `MemoryPlanningLine`s unless they are in the graph output, checked via `V.graph.get_output_names()`. However, for graph partition, we should check the outputs of the current partition instead of the outputs of the graph before partitioning.
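
In pseudocode, the check described above (illustrative names, not the exact Inductor API):

```py
def keep_trailing_planning_line(line, graph, current_partition=None):
    # memory_plan_reuse pops trailing MemoryPlanningLines unless the planned
    # buffer escapes; once the graph is partitioned, "escapes" has to mean
    # "output of the current partition", not "output of the whole graph".
    if current_partition is not None:
        output_names = set(current_partition.output_names)
    else:
        output_names = set(graph.get_output_names())
    return line.node.get_name() in output_names
```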

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165514
Approved by: https://github.com/ProExpertProg, https://github.com/eellison
2025-10-15 21:52:16 +00:00
Colin Peppler
a33f85e791 Add tlparse artifact for autotune_at_compile_time (#164984)
This is useful for inspecting autotuning code when `autotune_at_compile_time=True`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164984
Approved by: https://github.com/yushangdi, https://github.com/desertfire
2025-10-12 23:38:11 +00:00
Pian Pawakapan
474d07554a [dynamic shapes] unbacked-safe slicing (#161414)
Summary:
Generates new unbacked symbols for slice output size & storage offset, when appropriate semantics are unclear. Teaches inductor to codegen the slice with flexible semantics.

Test Plan:
contbuild & OSS CI, see 56218d85e2

Rollback Plan:

Differential Revision: D80948073

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161414
Approved by: https://github.com/laithsakka
2025-09-30 01:15:19 +00:00
Blaine Burton Rister
604da4bb9a [Inductor-FX] Support unbacked symbol definitions (#163729)
# Problem
Inductor sometimes generates unbacked symints to handle things like mismatched branches of `torch.cond`. This code is represented by `pytree.KeyPath`, with special codegen logic to convert it to Python and C++. This was not previously supported by the FX backend.

# Feature
This PR adds support for unbacked symbol declarations to the FX backend. The implementation is fairly straightforward.
1. Instead of raw Python/C++, update the wrapper codegen method to emit a new Wrapper IR line called `UnbackedSymbolDefsLine`. This contains all the information needed to  generate the Python and C++ code.
2. Move the existing Python/C++ codegen to a private method, which is invoked by `UnbackedSymbolDefsLine.codegen()`.
3. Implement a method to generate FX IR from unbacked symbol definitions. The implementation is based on recursive descent, consuming some keypath entries, emitting an FX IR node, and recursing to the rest of the keypath. It is conceptually identical to the existing algorithm for Python and C++, except it generates FX nodes.
4. The FX backend currently relies on size hints to generate autotuning arguments, and consequently autotuning does not support unbacked SymInts. At some point, we would like to generalize the autotuning logic to support these. But for now, simply emit a warning and skip autotuning when we see them.
5. The new test case exposed some tricky issues reconciling Triton call args with constants stored in `triton_meta`. This PR rewrites the relevant helper function to do this in a more principled way.
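
A small sketch of the recursive descent in step 3, showing only the sequence-index keypath case (names are illustrative):

```py
import operator
from torch.fx import Graph
from torch.utils._pytree import SequenceKey

def codegen_unbacked_def(graph: Graph, root_node, keypath):
    # Base case: the keypath is exhausted, so root_node produces the value
    # that defines the unbacked symbol.
    if not keypath:
        return root_node
    entry, *rest = keypath
    if isinstance(entry, SequenceKey):
        # Consume one keypath entry by emitting a getitem node, then recurse.
        node = graph.call_function(operator.getitem, (root_node, entry.idx))
        return codegen_unbacked_def(graph, node, rest)
    raise NotImplementedError(f"unhandled keypath entry: {entry!r}")
```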

# Test plan
This PR imports an existing control flow test to the FX backend's test suite. The test uses unbacked symbol definitions to handle mismatched dynamic shapes coming from `torch.cond` branches.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163729
Approved by: https://github.com/jansel
2025-09-29 18:10:37 +00:00
dolpm
3746039b47 [inductor] fix: 'get_raw_stream' undefined (#163707)
Summary:
ran into this when precompiling baidu/ERNIE-4.5-21B-A3B-PT

codegen after fix:
```py
import triton
import triton.language as tl
from torch._inductor.runtime.triton_heuristics import start_graph, end_graph
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
with torch.cuda._DeviceGuard(0):
    stream0 = get_raw_stream(0)
...
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163707
Approved by: https://github.com/jamesjwu
2025-09-29 15:48:16 +00:00
can-gaa-hou
eb4361a801 [Fix] Adding missing f prefixes to formatted strings [1/N] (#164065)
As stated in the title.

* #164068
* #164067
* #164066
* __->__ #164065

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164065
Approved by: https://github.com/Skylion007
2025-09-29 04:53:00 +00:00
Shangdi Yu
520fca82c8 Refactor Provenance Tracking (#163378)
Summary:
- Move the `provenance_level` flag check to inside the `set_kernel_post_grad_provenance_tracing` call to simplify the code

- Move the `set_kernel_post_grad_provenance_tracing` call and `write_provenance_debug_handle` call to `codegen_comment`.

- If a `call_kernel` call site doesn't have a preceding `codegen_comment` call, add one. Now all `call_kernel` call sites are accompanied by a `codegen_comment` call.

- Add a `codegen_comment` method to BaseScheduling and remove the noop `codegen_comment` method in Scheduling

- Remove `debug_handle` from `call_kernel`.

Test Plan:
CI

```
buck run @//mode/opt-split-dwarf fbcode//caffe2/test/inductor:provenance_tracing
```

Differential Revision: D82839271

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163378
Approved by: https://github.com/angelayi
2025-09-25 22:55:59 +00:00
Bin Bao
be6c127927 [AOTI] Pass comments from metadata to the autotune block (#163600)
Summary: When generating Triton kernels in the compile-time autotune blocks, it is useful to generate source information as code comments. Previously we ignored these comments for autotune code blocks because the generated main output code contains the same information, but that doesn't help if the generated autotune code crashes.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163600
Approved by: https://github.com/yushangdi
2025-09-24 02:01:59 +00:00
Blaine Burton Rister
e56dd5d770 [Inductor-FX] Support torch.cond (#163234)
# Feature

Support `torch.cond` in the FX converter. The generated FX IR is conceptually identical to what would come from `torch.export`:
- Submodules are stored as attributes and accessed via `getattr`.
- The conditional is represented as `torch.ops.higher_order.cond`, which takes in the subgraphs, a predicate and submodule inputs.

# Implementation overview

The FX backend generates code for subgraphs using the following steps:
1. When `codegen_conditional` is called in `WrapperFxCodegen`, we emit a `ConditionalLine`.
   a. We also codegen the true/false subgraphs at this time, storing their subgms for later.
2. At the beginning of FX conversion, generate `get_attr` nodes accessing each subgraph. It's important to do this at the start, before registering the node metadata hook. This also matches the convention followed by torch.export.
3. When we see the `ConditionalLine` in the FX converter, we generate a corresponding `torch.ops.higher_order.cond`.
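
A rough sketch of steps 2-3 as FX graph construction (a standalone helper; the real converter threads its own state and node metadata through):

```py
import torch
from torch.fx import Graph

def emit_conditional(graph: Graph, predicate_node, operand_nodes,
                     true_attr="true_graph_0", false_attr="false_graph_0"):
    # Step 2: subgraphs live as attributes on the owning module and are
    # fetched with get_attr nodes, matching torch.export's convention.
    true_gm = graph.get_attr(true_attr)
    false_gm = graph.get_attr(false_attr)
    # Step 3: the ConditionalLine becomes a single higher-order cond call.
    return graph.call_function(
        torch.ops.higher_order.cond,
        (predicate_node, true_gm, false_gm, tuple(operand_nodes)),
    )
```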

# Implementation details
This ended up being a substantial change, as wrapper codegen has some special logic for subgraphs.

Certain methods of `PythonWrapperCodegen` are overridden by `SubgraphPythonWrapperCodegen`. To apply these overrides, we use multiple inheritance with the registered subclass of `WrapperFxCodegen`.

Unlike most other wrapper codegen methods, which map 1:1 to Wrapper IR lines, subgraph codegen generates a number of wrapper lines including `EnterSubgraphLine` and `ExitSubgraphLine`, along with Python or C++ code calling the subgraph as a function. These lines are used for some backends' memory planning.

In contrast, FX IR typically represents a subgraph call as a single HOP node, or a `call_module` op. To account for this difference, this PR introduces a new wrapper IR line called `ConditionalLine`, which is only used by the FX backend. We override the `codegen_conditional` method to emit this line. This sidesteps having to port the existing subgraph codegen and associated memory planning to Wrapper IR. (In principle, it seems possible to adapt the existing backends to `ConditionalLine`, but it could be a larger refactor, since we'd also have to update the memory planning.)

Some of the lower-level subgraph codegen methods are still shared between the FX and Python backends, such as `generate_subgraph_common`. Those were easier to port to Wrapper IR.

This also required generalizing the way the FX converter handles graph inputs and outputs. Previously, it assumed the IO signature was the same as `V.graph.module`, but this is only true for the parent graph, and not subgraphs. Instead, we need to call `get_graph_inputs` and `get_graph_outputs` to populate the inputs and outputs for subgraphs.

# Test plan
This PR adds a couple of tests using torch.cond. Here's an example graph generated by one of them:
```
graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
    %true_graph_0 : [num_users=1] = get_attr[target=true_graph_0]
    %false_graph_0 : [num_users=1] = get_attr[target=false_graph_0]
    %cond : [num_users=1] = call_function[target=torch.ops.higher_order.cond](args = (%arg0_1, %true_graph_0, %false_graph_0, (%arg1_1,)), kwargs = {})
    %buf1 : [num_users=2] = call_function[target=operator.getitem](args = (%cond, 0), kwargs = {})
    %triton_kernel_wrapper_mutation : [num_users=0] = call_function[target=torch.ops.higher_order.triton_kernel_wrapper_mutation](args = (), kwargs = {kernel_idx: 6, constant_args_idx: 6, grid: [(1, 1, 1)], tma_descriptor_metadata: {}, kwargs: {in_out_ptr0: %buf1, xnumel: 6, XBLOCK: 8}})
    return buf1
```

It also removes an existing negative test which checked that a certain error was raised when subgraphs were encountered.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163234
Approved by: https://github.com/angelayi, https://github.com/jansel
2025-09-20 03:52:31 +00:00
Boyuan Feng
505ee42570 [Graph Partition] allow sharing default device context (#162873)
Entering a device context takes 30 us and exiting a device context takes 11 us. If all graph partitions and cudagraph-unsafe ops happen on the same device, we can share the device context.

## Trace

Use vLLM as an example. The first trace shows dynamo graph partition.
<img width="1338" height="453" alt="image" src="https://github.com/user-attachments/assets/b81815fd-cdcb-4024-846a-5b64164f8bac" />

The second trace shows inductor graph partition prior to this PR.
<img width="1331" height="270" alt="image" src="https://github.com/user-attachments/assets/8d98b127-2053-4eae-9a31-5491661f14d8" />

Comparing with fx graph partition, we can see inductor graph partition shows extra overhead from enter/exit device contexts (13+6 us -> 30+11 us), but smaller runtime overhead (13 us -> 7 us). This motivates the PR to share default device context.

The third trace shows Inductor graph partition after this PR. We observe that the extra overhead from entering/exiting device contexts has been eliminated, while the smaller runtime overhead is preserved.
<img width="1336" height="276" alt="image" src="https://github.com/user-attachments/assets/77be2237-34dd-4bac-ad9c-d9af3be36417" />

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162873
Approved by: https://github.com/shunting314
2025-09-16 19:36:42 +00:00
Blaine Burton Rister
9aca0ba027 [Inductor-FX] Support IndexPutFallback (#162863)
# Feature

This PR supports lowering `IndexPutFallback` through Inductor's FX converter. The approach is very similar to the one taken in https://github.com/pytorch/pytorch/pull/162686.

Compared to `ScatterFallback`, this required one additional change: the value of `self.op_overload` for `IndexPutFallback` was inaccurate. Previously, it used `aten.index_put`, which would result in unsound FX IR. The existing Python/C++ codegen uses `aten.index_put_`, since the fallback mutates its input. This PR changes `self.op_overload` to match that.
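
For concreteness, the two overloads in question (the in-place one matches what the fallback actually does):

```py
import torch

out_of_place = torch.ops.aten.index_put.default  # returns a new tensor
in_place = torch.ops.aten.index_put_.default     # mutates `self` in place

x = torch.zeros(4)
in_place(x, [torch.tensor([1, 3])], torch.tensor([5.0, 7.0]))
print(x)  # tensor([0., 5., 0., 7.])
```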

# Test plan
Added a CI test lowering deterministic index put via the FX converter.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162863
Approved by: https://github.com/angelayi
2025-09-16 08:52:47 +00:00
Blaine Burton Rister
a7bbc5fea7 [Inductor-FX] Support ScatterFallback (#162686)
# Problem
Inductor has a `ScatterFallback` op with custom Python and C++ wrapper codegen macros. This is used in certain situations where the default Triton codegen doesn't apply, and especially for reductions which need to be deterministic. Since this op used direct Python/C++ codegen, it wasn't compatible with the FX backend.

# Feature
This PR refactors the associated wrapper codegen to support `ScatterFallback`. This follows the same basic steps that were used for other fallback ops including `MultiOutput` and `ExternKernel`:

1. Create a new wrapper IR op called `ScatterFallbackLine`. Move the logic in `ScatterFallback.codegen` to `ScatterFallbackLine.codegen`, to prevent it from affecting the FX backend. This logic is unsafe for FX because it may generate Python or C++ strings with methods like `codegen_reference()`.
2. To eliminate the dependence on `V.graph`, move language-specific logic to the respective wrapper codegen subclasses. In this case, C++ codegen has some special logic, which is moved to `CppWrapperCpu`.
3. Create a new method in `FXWrapperCodegen` to handle `ScatterFallbackLine`.
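
A minimal, self-contained sketch of the Wrapper IR pattern in steps 1-3 (class and method names are illustrative, not the exact Inductor definitions):

```py
from dataclasses import dataclass

class WrapperLine:
    """One backend-agnostic line of Wrapper IR."""
    def codegen(self, backend) -> None:
        raise NotImplementedError

@dataclass
class ScatterFallbackLine(WrapperLine):
    # Step 1: carry the fallback's data instead of pre-rendered source text,
    # so string backends and the FX backend can each consume it (step 3).
    op_name: str
    args: tuple

    def codegen(self, backend) -> None:
        backend.generate_scatter_fallback(self.op_name, self.args)

class PythonBackend:
    # Step 2: language-specific rendering lives on the backend subclass.
    def generate_scatter_fallback(self, op_name, args):
        print(f"{op_name}({', '.join(map(str, args))})")

# The same IR line drives whichever backend consumes it.
ScatterFallbackLine("aten.scatter_reduce_", ("buf0", 0, "buf1", "buf2")).codegen(PythonBackend())
```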

# Test plan
Added a couple of CI tests for the FX backend with scatter fallbacks.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162686
Approved by: https://github.com/jansel
2025-09-12 08:41:50 +00:00
Colin Peppler
94755e81c4 [inductor] Enable combo kernels with unbacked inputs (#162442)
An internal user tried enabling combo kernels but ran into "Cannot convert symbols to int". This PR enables combo kernels on inputs with data-dependent shapes.

### Example exception

```
  File "/data/users/colinpeppler/pytorch/torch/_inductor/codegen/triton.py", line 4997, in benchmark_combo_kernel
    kernel_code_list = self.generate_combo_kernel_code(
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/colinpeppler/pytorch/torch/_inductor/codegen/simd.py", line 1849, in generate_combo_kernel_code
    src_code = kernel.codegen_kernel()
               ^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/colinpeppler/pytorch/torch/_inductor/codegen/triton_combo_kernel.py", line 802, in codegen_kernel
    code.splice(self.codegen_kernel_benchmark(num_gb=0))
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/colinpeppler/pytorch/torch/_inductor/codegen/triton_combo_kernel.py", line 852, in codegen_kernel_benchmark
    var_names.extend(self.kernel_benchmark_extra_args())
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/colinpeppler/pytorch/torch/_inductor/codegen/triton_combo_kernel.py", line 733, in kernel_benchmark_extra_args
    extra_args.append(str(V.graph.sizevars.size_hint(tree.numel)))
                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/colinpeppler/pytorch/torch/_inductor/sizevars.py", line 584, in size_hint
    return int(out)
           ^^^^^^^^
  File "/home/colinpeppler/.conda/envs/pytorch/lib/python3.12/site-packages/sympy/core/expr.py", line 307, in __int__
    raise TypeError("Cannot convert symbols to int")
torch._inductor.exc.InductorError: TypeError: Cannot convert symbols to int
```
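
One way to make the failing `size_hint` lookup tolerant of unbacked symbols, as a sketch (whether this matches the PR's exact fix is an assumption):

```py
def benchmark_numel_hint(sizevars, numel_expr, fallback=8192):
    # size_hint() int()-converts the expression; for an unbacked symbol sympy
    # raises "Cannot convert symbols to int", so fall back to a fixed hint
    # that is only used to build the benchmark harness.
    try:
        return sizevars.size_hint(numel_expr)
    except TypeError:
        return fallback
```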

Differential Revision: [D82042230](https://our.internmc.facebook.com/intern/diff/D82042230)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162442
Approved by: https://github.com/jansel
2025-09-10 20:49:38 +00:00
Yidi Wu
48e3be3ab6 [while_loop][autograd] add hop while_loop_stack_output (#160467)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/160467
Approved by: https://github.com/zou3519
ghstack dependencies: #160548
2025-09-06 21:26:33 +00:00
ruisizhang123
291cd11f2d [inductor] estimate peak memory in codegen only when buffer reuse (#162300)
As titled, this PR ensures peak memory is estimated only when buffer reuse is enabled. Without this config, some nodes' successor nodes are eliminated from memory estimation after inductor bucketing, which can cause errors.

The original codegen peak memory estimation code is from this PR: https://github.com/pytorch/pytorch/pull/159530

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162300
Approved by: https://github.com/eellison, https://github.com/v0i0
2025-09-06 01:30:38 +00:00
Nan Zhang
3a207816cc Forward fix for user defined triton kernel grid calc (#162162)
Summary:

This change fixes the test: inductor:fxir_backend - test_custom_triton_autotune_dynamic which was broken by https://github.com/pytorch/pytorch/pull/160997

Test Plan:
inductor:fxir_backend - test_custom_triton_autotune_dynamic

Rollback Plan:

Differential Revision: D81679217

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162162
Approved by: https://github.com/eellison, https://github.com/jansel
2025-09-04 22:51:23 +00:00
Oguz Ulgen
81aeefa657 Add torch.compile support for triton.constexpr_function (#162106)
Fixes #161868

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162106
Approved by: https://github.com/jansel, https://github.com/zou3519
2025-09-04 16:46:55 +00:00
Chong Gu
c024b1f5a1 [AMD] [Reland] Fix AMD User Defined Kernel Autotune (#161521)
Summary: This is a reland of D80285441, fixed the unit test.

Test Plan:
```
buck2 run mode/opt-amd-gpu -m rocm641 -c fbcode.split-dwarf=true -c fbcode.use_link_groups=true -c fbcode.enable_gpu_sections=true //hpc/new/models/feed/benchmark:feed_lower_benchmark -- --load=manifold://ads_storage_fblearner/tree/user/facebook/fblearner/predictor/894698382/0/gpu_lowering/new_input8 --skip-eager --skip-flop-estimation --sync-mode=0 --lower-backend=AOT_INDUCTOR

```
will succeed after this diff.

Rollback Plan:

Differential Revision: D80971224

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161521
Approved by: https://github.com/frank-wei
2025-09-04 08:41:18 +00:00
Yidi Wu
bf6aaba0f7 [while_loop] avoid aliasing when body_fn never executes (#160670)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/160670
Approved by: https://github.com/zou3519
ghstack dependencies: #160548, #160669
2025-08-29 19:36:37 +00:00
Shangdi Yu
9a12bab0d3 Add debug handle to inductor provenance tracking (#161110)
Summary:
Use debug handle on kernel names to distinguish different calls to the same kernel.

Previous kernel name: kernel_name

New kernel name: kernel_name:debug_handle

We add the debug handle to the tlparse artifacts: `inductor_provenance_tracking_node_mappings` and `inductor_provenance_tracking_kernel_stack_traces`.

We also add debug handles in the comments of the generated code so we can map to them in the provenance tracking highlighter tool: https://github.com/pytorch/tlparse/pull/134

Example output code is below. If a kernel doesn't have a debug handle, the `[Provenance debug handles]` comment line will not be written.

```
        # Topologically Sorted Source Nodes: [y, z], Original ATen: [aten.addmm, aten.gelu]
        # [Provenance debug handles] triton_poi_fused_addmm_gelu_2:3
        stream0 = get_raw_stream(0)
        triton_poi_fused_addmm_gelu_2.run(buf4, primals_5, 300, stream=stream0)
```

The debug handles will also be used by downstream profilers such as zoomer.

Test Plan:
```
buck run mode/opt fbcode//caffe2/test/inductor:provenance_tracing
```

Rollback Plan:

Differential Revision: D78994959

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161110
Approved by: https://github.com/angelayi
2025-08-27 04:56:11 +00:00
PyTorch MergeBot
40c0e700a4 Revert "[AMD] Fix AMD User Defined Kernel Autotune (#160671)"
This reverts commit 431846a632.

Reverted https://github.com/pytorch/pytorch/pull/160671 on behalf of https://github.com/atalman due to new test is failing: inductor/test_aot_inductor.py::AOTInductorTestABICompatibleGpu::test_rocm_triton_autotuning_cuda [GH job link](https://github.com/pytorch/pytorch/actions/runs/17172795679/job/48725235301) [HUD commit link](431846a632) ([comment](https://github.com/pytorch/pytorch/pull/160671#issuecomment-3220442141))
2025-08-25 14:07:48 +00:00
Chong Gu
431846a632 [AMD] Fix AMD User Defined Kernel Autotune (#160671)
Summary: AMD-specific kwargs need to be removed from the guard; otherwise a KeyError will be raised when executing the kernel.
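
Schematically, the fix looks like this (the exact set of ROCm-only kwargs is an assumption for illustration):

```py
ROCM_ONLY_LAUNCH_KWARGS = {"waves_per_eu", "matrix_instr_nonkdim", "kpack"}

def strip_amd_kwargs(launch_kwargs: dict) -> dict:
    # Drop the AMD-specific tuning kwargs before they reach the guard, so the
    # guard's key lookup doesn't raise a KeyError at kernel execution time.
    return {k: v for k, v in launch_kwargs.items()
            if k not in ROCM_ONLY_LAUNCH_KWARGS}
```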

Test Plan:
```
buck2 run mode/opt-amd-gpu -m rocm641 -c fbcode.split-dwarf=true -c fbcode.use_link_groups=true -c fbcode.enable_gpu_sections=true //hpc/new/models/feed/benchmark:feed_lower_benchmark -- --load=manifold://ads_storage_fblearner/tree/user/facebook/fblearner/predictor/894698382/0/gpu_lowering/new_input8 --skip-eager --skip-flop-estimation --sync-mode=0 --lower-backend=AOT_INDUCTOR
```
can succeed after this change.

Rollback Plan:

Differential Revision: D80285441

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160671
Approved by: https://github.com/muchulee8
2025-08-23 07:23:09 +00:00
PyTorch MergeBot
3f1a97a99c Revert "[dynamic shapes] unbacked-safe slicing (#157944)"
This reverts commit 44549c7146.

Reverted https://github.com/pytorch/pytorch/pull/157944 on behalf of https://github.com/pianpwk due to this PR & internal diff landed out of sync, just reverted internal with D80720654, will revert this & reland as codev ([comment](https://github.com/pytorch/pytorch/pull/157944#issuecomment-3215610135))
2025-08-22 20:48:46 +00:00
Pian Pawakapan
44549c7146 [dynamic shapes] unbacked-safe slicing (#157944)
Generates new unbacked symbols for slice output size & storage offset, when appropriate semantics are unclear. Teaches inductor to codegen the slice with flexible semantics.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157944
Approved by: https://github.com/laithsakka
2025-08-20 22:52:56 +00:00
PyTorch MergeBot
6ea4be1e2e Revert "[dynamic shapes] unbacked-safe slicing (#157944)"
This reverts commit 2f0cba934d.

Reverted https://github.com/pytorch/pytorch/pull/157944 on behalf of https://github.com/seemethere due to This is blocking internal sync due to merge conflicts ([comment](https://github.com/pytorch/pytorch/pull/157944#issuecomment-3206833193))
2025-08-20 15:16:45 +00:00
Markus Hoehnerbach
65d21dae18 [inductor] dont reuse buffers if it affects peak (#145883) (#159530)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159530
Approved by: https://github.com/eellison
2025-08-19 19:02:56 +00:00
Pian Pawakapan
2f0cba934d [dynamic shapes] unbacked-safe slicing (#157944)
Generates new unbacked symbols for slice output size & storage offset, when appropriate semantics are unclear. Teaches inductor to codegen the slice with flexible semantics.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157944
Approved by: https://github.com/laithsakka
2025-08-19 17:32:47 +00:00
PyTorch MergeBot
5e98d9f9ba Revert "[dynamic shapes] unbacked-safe slicing (#157944)"
This reverts commit 56218d85e2.

Reverted https://github.com/pytorch/pytorch/pull/157944 on behalf of https://github.com/huydhn due to Sorry for reverting your change but I think this is failing test_draft_export in trunk 56218d85e2 ([comment](https://github.com/pytorch/pytorch/pull/157944#issuecomment-3198874677))
2025-08-19 01:16:17 +00:00
Pian Pawakapan
56218d85e2 [dynamic shapes] unbacked-safe slicing (#157944)
Generates new unbacked symbols for slice output size & storage offset, when appropriate semantics are unclear. Teaches inductor to codegen the slice with flexible semantics.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157944
Approved by: https://github.com/laithsakka
2025-08-18 22:38:16 +00:00
PyTorch MergeBot
9df07ecfbe Revert "[inductor] dont reuse buffers if it affects peak (#145883) (#159530)"
This reverts commit 3be70dc30e.

Reverted https://github.com/pytorch/pytorch/pull/159530 on behalf of https://github.com/clee2000 due to newly added test fail internally D80316528, probably just a targets change, but also imo the tests should probably go into a testcase class from common or inductor utils.  While I'm pretty sure CI can run the globally defined ones, theres some CI related functionality that on the testcase class that CI benefits from ([comment](https://github.com/pytorch/pytorch/pull/159530#issuecomment-3191947506))
2025-08-15 15:49:04 +00:00
Shangdi Yu
aa99e0958f Separate provenance tracking to different levels (#160383)
Summary: As titled. We've received requests from various parties who are interested in turning on provenance tracking by default. In this PR, we prepare to turn on by default the part of provenance tracking that doesn't have too much overhead.

- Change `provenance_tracking` config to `provenance_tracking_level`
- turn on the following provenance tracking by default when `basic_provenance_tracking`=True
    - `set_kernel_post_grad_provenance_tracing` for kernels, this add mapping between triton kernels and post_grad nodes
    - `dump_inductor_provenance_info` if we're dumping tlparse log
    - `get_graph_provenance_json` and dump `create_mapping_pre_post_grad_nodes`. This creates a mapping between pre_grad and post_grad nodes. Since we're not turning on the provenance tracking in GraphTransformObserver by default, the mapping here may be incomplete/limited.
    - add stack trace from post grad nodes to inductor IR nodes
    - add exception swallowing for all functions above

Test Plan:
CI

Rollback Plan:

Differential Revision: D80031559

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160383
Approved by: https://github.com/angelayi
2025-08-15 04:59:35 +00:00
Markus Hoehnerbach
3be70dc30e [inductor] dont reuse buffers if it affects peak (#145883) (#159530)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159530
Approved by: https://github.com/eellison
2025-08-14 21:14:36 +00:00
Boyuan Feng
5f1010fbb3 [Graph Partition] Pass all OSS unit tests (#154667)
Graph partition leads to 6.2% speedup on vision_maskrcnn, 5.8% speedup on yolov3. [P1819700563](https://www.internalfb.com/phabricator/paste/view/P1819700563), 39.5% speedup on speech_transformer inference [P1830602200](https://www.internalfb.com/phabricator/paste/view/P1830602200), 85% speedup on speech_transformer training [P1831115315](https://www.internalfb.com/phabricator/paste/view/P1831115315).

We ran the same diff on two different days, and both runs show a speedup on average.

[first TorchInductor Benchmark ci run](https://hud.pytorch.org/benchmark/compilers?dashboard=torchinductor&startTime=Mon%2C%2021%20Jul%202025%2016%3A37%3A55%20GMT&stopTime=Mon%2C%2028%20Jul%202025%2016%3A37%3A55%20GMT&granularity=hour&mode=inference&dtype=bfloat16&deviceName=cuda%20(h100)&lBranch=bf/partition-turn-on&lCommit=75ef90fe89b82c967362a2d40fdf1af047202bc2&rBranch=main&rCommit=abcb24f4de11f8fedf2c2c9ff53b6092ef42306d)
<img width="1885" height="752" alt="image" src="https://github.com/user-attachments/assets/13bba9fc-5dbf-42ad-8558-d54f7e367b41" />

[second TorchInductorBenchmark ci run](https://hud.pytorch.org/benchmark/compilers?dashboard=torchinductor&startTime=Wed%2C%2023%20Jul%202025%2016%3A38%3A27%20GMT&stopTime=Wed%2C%2030%20Jul%202025%2016%3A38%3A27%20GMT&granularity=hour&mode=inference&dtype=bfloat16&deviceName=cuda%20(h100)&lBranch=bf/partition-turn-on&lCommit=66de27e29338c26b1be94733049868cb0309ea52&rBranch=main&rCommit=70d2e9ba455c3c910f6f95b24171c8eee7bc00bf)
<img width="2513" height="1030" alt="image" src="https://github.com/user-attachments/assets/3a413dcb-2314-4292-919a-7ca181f9eeac" />

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154667
Approved by: https://github.com/eellison
2025-08-12 04:37:58 +00:00
PyTorch MergeBot
09381f5dac Revert "[Graph Partition] Pass all OSS unit tests (#154667)"
This reverts commit ca7315c171.

Reverted https://github.com/pytorch/pytorch/pull/154667 on behalf of https://github.com/clee2000 due to broke inductor/test_memory.py::TestOperatorReorderForPeakMemory::test_reorder_peak_memory_lpmf [GH job link](https://github.com/pytorch/pytorch/actions/runs/16885961204/job/47836769279) [HUD commit link](ca7315c171) note to self: bad TD ([comment](https://github.com/pytorch/pytorch/pull/154667#issuecomment-3176805477))
2025-08-11 20:34:27 +00:00
Shangdi Yu
9ccd0f5e31 Fix unbacked symint and memory leak in inductor memory planning (#159839)
Summary:

In memory planning, some allocation sizes involve unbacked symints. These unbacked symints are not known before they are computed in run time, so **allocation pools that involve unbacked symints cannot be allocated until we have the values of the unbacked symints** .

So we add a notion of `earliest_available` to Allocation nodes. If an allocation node has unbacked symint, it is available at only when its live range begin.

Then in AllocationPool, if a pool involves an Allocation node that has an earliest available time, we restrict its life range.

If a block's earliest available time is later than a pool's life range's start time, we cannot allocate it from the pool.
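
A compact sketch of the rule just described (field and class names are illustrative):

```py
from dataclasses import dataclass, field

@dataclass
class Allocation:
    live_range: tuple              # (start, end) in schedule order
    has_unbacked_symint: bool = False

    @property
    def earliest_available(self) -> int:
        # Sizes involving unbacked symints are only known once the live range
        # begins, so the block cannot be carved out of a pool any earlier.
        return self.live_range[0] if self.has_unbacked_symint else 0

@dataclass
class AllocationPool:
    start: int                     # when the pool's backing buffer is allocated
    blocks: list = field(default_factory=list)

    def can_add(self, block: Allocation) -> bool:
        # A block that only becomes available after the pool's start of life
        # cannot be allocated from this pool.
        return block.earliest_available <= self.start
```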

We also fix a memory leak caused by allocating a tensor without wrapping it in an RAIIAtenTensorHandle.

In the python wrapper for JIT Inductor, `codegen_alloc_from_pool` doesn't actually write the alloc lines to the wrapper; it just returns the allocation string. However, in cpp_wrapper, `codegen_alloc_from_pool` actually writes to the wrapper. Specifically, it writes the following and returns the `RAIIAtenTensorHandle` as a string.

```
AtenTensorHandle handle_name;
AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch__alloc_from_pool(....);
```

This is bug prone. **If you write aoti_torch__alloc_from_pool lines, you must write the RAIIAtenTensorHandle as well**; otherwise you get memory leaks.

We remove the alloc_from_pool call from codegen_create, because this doesn't work for AOTI. In python wrapper, we can generate the same alloc_from_pool variable name for the same block, but cpp_wrapper will generate a different variable name for each call to alloc_from_pool.

Test Plan:
```
 python test/inductor/test_memory_planning.py
```

Rollback Plan:

Differential Revision: D79603119

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159839
Approved by: https://github.com/jansel
2025-08-11 17:16:15 +00:00
Boyuan Feng
ca7315c171 [Graph Partition] Pass all OSS unit tests (#154667)
Graph partition leads to 6.2% speedup on vision_maskrcnn, 5.8% speedup on yolov3. [P1819700563](https://www.internalfb.com/phabricator/paste/view/P1819700563), 39.5% speedup on speech_transformer inference [P1830602200](https://www.internalfb.com/phabricator/paste/view/P1830602200), 85% speedup on speech_transformer training [P1831115315](https://www.internalfb.com/phabricator/paste/view/P1831115315).

We ran the same diff on two different days, and both runs show a speedup on average.

[first TorchInductor Benchmark ci run](https://hud.pytorch.org/benchmark/compilers?dashboard=torchinductor&startTime=Mon%2C%2021%20Jul%202025%2016%3A37%3A55%20GMT&stopTime=Mon%2C%2028%20Jul%202025%2016%3A37%3A55%20GMT&granularity=hour&mode=inference&dtype=bfloat16&deviceName=cuda%20(h100)&lBranch=bf/partition-turn-on&lCommit=75ef90fe89b82c967362a2d40fdf1af047202bc2&rBranch=main&rCommit=abcb24f4de11f8fedf2c2c9ff53b6092ef42306d)
<img width="1885" height="752" alt="image" src="https://github.com/user-attachments/assets/13bba9fc-5dbf-42ad-8558-d54f7e367b41" />

[second TorchInductorBenchmark ci run](https://hud.pytorch.org/benchmark/compilers?dashboard=torchinductor&startTime=Wed%2C%2023%20Jul%202025%2016%3A38%3A27%20GMT&stopTime=Wed%2C%2030%20Jul%202025%2016%3A38%3A27%20GMT&granularity=hour&mode=inference&dtype=bfloat16&deviceName=cuda%20(h100)&lBranch=bf/partition-turn-on&lCommit=66de27e29338c26b1be94733049868cb0309ea52&rBranch=main&rCommit=70d2e9ba455c3c910f6f95b24171c8eee7bc00bf)
<img width="2513" height="1030" alt="image" src="https://github.com/user-attachments/assets/3a413dcb-2314-4292-919a-7ca181f9eeac" />

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154667
Approved by: https://github.com/eellison
2025-08-11 16:25:12 +00:00
Markus Hoehnerbach
e167c7d0f3 [inductor] allocate non-blocking copy destinations in pinned memory (#155121) (#158758)
Fixes #155121

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158758
Approved by: https://github.com/EikanWang, https://github.com/eellison
2025-08-07 17:07:26 +00:00
eellison
eb25a95a6e Fix inductor memory estimation when a single buf has multiple mutations. Add runtime verification of mem tracking (#159569)
With fsdp, we sometimes have multiple, non-overlapping views of a single buffer which are all mutated. Previously we considered the original buffer as an allocation and made the mutated buffer the deallocation. With multiple mutations of the same buffer, we need to consider the original buffer as deallocated only when all of its aliases die (and avoid double counting the input buffer size). See comment inline:

```
    When an operation mutates a buffer in-place, the scheduler creates a new buffer name
    to track the "before" and "after" states, even though they share the same memory.
    The mutated buffer represents a rename with zero allocation and deallocation cost.
    During dependency tracking, we transfer dependencies from the mutated name back to
    the original buffer, ensuring the original memory is only freed when all aliases
    are done.
    This handles cases where a buffer has multiple non-overlapping aliases - rather than
    trying to assign free costs to individual aliases, we forward all alias dependencies
    to the original buffer.
    Consider:
        buf0 = op0()
        buf1 = mutation_op_(buf0)
        del buf0
        ...
        op(buf1)
        del buf1
    The only memory events are the creation prior to op0, and the deletion following buf1.
```

As @IvanKobzarev's logs in https://github.com/pytorch/pytorch/pull/158361/files#diff-e173a1d52aff49959c9f6d17ecc09946d8a616fc5909df884e62a15e1ebd1d41R1776-R1807 show, it can be a bit of a pain to pinpoint which part of our memory calculation is incorrect.

This pr also adds a runtime verifier `config.test_configs.track_memory_lifecycle` which tracks buffer allocation and deallocation, and errors if their lifetime does not match our expectations.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159569
Approved by: https://github.com/IvanKobzarev
2025-08-05 19:58:11 +00:00
Michael Lazos
7ba996bbaa [Cutlass] Fix wrapper code generation breakage (#159760)
Fixes issues introduced by https://github.com/pytorch/pytorch/pull/159355

The issue got past OSS CI because the H100 tag wasn't added. I'm not sure how to prevent these kinds of issues in the future; perhaps we should run H100 on Inductor PRs?

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159760
Approved by: https://github.com/angelayi
2025-08-04 23:03:03 +00:00
PyTorch MergeBot
83ba3f1101 Revert "[inductor] allocate non-blocking copy destinations in pinned memory (#155121) (#158758)"
This reverts commit 6085bf7565.

Reverted https://github.com/pytorch/pytorch/pull/158758 on behalf of https://github.com/davidberard98 due to I need to revert #158462 (it causes device-side asserts), and this PR causes a merge conflict in the test file. Sorry about that! ([comment](https://github.com/pytorch/pytorch/pull/158758#issuecomment-3152490371))
2025-08-04 21:47:11 +00:00