Commit Graph

448 Commits

Author SHA1 Message Date
Shangdi Yu
7869196482 Fix torchbind schema str generation (#149239)
Summary: Fix Torchbind HOP schema generation when there's no input

Test Plan:
```
buck run fbcode//mode/dev-nosan //caffe2/test/inductor:torchbind -- -r schema
```

Differential Revision: D71231164

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149239
Approved by: https://github.com/zou3519
2025-03-18 04:29:56 +00:00
Shangdi Yu
cf19efd3d9 Support basic TorchBind in aot_compile and aoti_compile_and_package (#148506)
Summary:
**Codegen**

- Skip some codegen parts for torchbind (such as arg decleration) because they are loaded in proxy executor, so we do not need to declare torchbind args in cpp code
- Added a helper method to get the schema of CallTorchBind HOP. The returned schema is only the schema of `obj.method()`.

**Serialization**
Add support for torchbind object in serialization

- For CallTorchBind HOP, we need to handle it specially because of it's schema. The output serialized args is in the format of `(obj, method, *args, **kwargs)`.
- it.TorchBindObject inputs are serialized to `as_custom_obj` Argument.

**Packaging**

Add torchbind objects file and `custom_objs_config.json` file to generated files output of `aot_compile`.

The json file is stored in the `data/aotinductor/<model_name>` folder in pt2 archive.

The torchbind objects are stored in data/constants/ folder in pt2 archive.
The format of torchbind objects are `f"{CUSTOM_OBJ_FILENAME_PREFIX}{custom_obj_idx}"`. e.g. `custom_obj_0`.
CustomClassHolder objects implement their own pickle methods.

Note that this `custom_objs_config.json` file is different from the `model_constants_config.json` file produced in package_sigmoid(). The keys in `custom_objs_config` directly correspond to the arg name in extern nodes json.
The key in `model_constants_config.json` produced by `package_sigmoid` is the attribute name in the user mode code.

This is required for both internal and OSS torchbind support.
For OSS torchbind support, we also need to package torchbind_constants into the .pt2 output.

**Work Left**
We still need to add torchbind support in ProxyExecutor for inductor.aoti_load_package to work. See other diffs in the stack.

Test Plan:
```
buck run fbcode//mode/dev-nosan //caffe2/test/inductor:torchbind -- -r schema
buck run fbcode//mode/dev-nosan //caffe2/test/inductor:torchbind -- -r aot_compile
```

Differential Revision: D69490718

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148506
Approved by: https://github.com/angelayi
2025-03-11 20:55:18 +00:00
PyTorch MergeBot
c916a8efc5 Revert "Use the device interface for detecting Triton availability (#139171)"
This reverts commit 940b60db97.

Reverted https://github.com/pytorch/pytorch/pull/139171 on behalf of https://github.com/ZainRizvi due to Sorry but this is breaking internally. @jansel can you please help get these changes working? See D70946254 for more details. To validate the fixes internally, you can follow the instructions here: https://fburl.com/fixing-ghfirst-reverts ([comment](https://github.com/pytorch/pytorch/pull/139171#issuecomment-2715392451))
2025-03-11 18:49:21 +00:00
Jason Ansel
09029010e5 [inductor] Fix create_specialize_impl error in latest Triton (#148933)
```py
$ python test/inductor/test_triton_kernels.py KernelTests.test_triton_kernel_2d_autotune_grad_False_dynamic_True_backend_inductor_grid_type_1
WARNING:torch._dynamo:Encountered an exception in identify_mutated_tensors, assuming every input is mutated
Traceback (most recent call last):
  File "/home/jansel/pytorch/torch/_higher_order_ops/triton_kernel_wrap.py", line 715, in identify_mutated_tensors
    ttir_module, ordered_tensor_names = generate_ttir(kernel, kwargs)
                                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jansel/pytorch/torch/_higher_order_ops/triton_kernel_wrap.py", line 289, in generate_ttir
    specialization = _get_specialization(ordered_args.values())
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jansel/pytorch/torch/_higher_order_ops/triton_kernel_wrap.py", line 262, in _get_specialization
    specialize_impl = triton.runtime.jit.create_specialize_impl()
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: create_specialize_impl() missing 1 required positional argument: 'specialize_extra'
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148933
Approved by: https://github.com/yanboliang, https://github.com/davidberard98
2025-03-11 15:54:47 +00:00
George White
940b60db97 Use the device interface for detecting Triton availability (#139171)
This allows for each device type to check current devices for Triton compatibility and ensure their Triton backend is present.

This PR replaces the `has_triton()` global method which was previously used for this task, and moves the initial check for each Inductor backend on to their associated `BaseScheduler` subclass. This means that other backends, such as Halide, can also implement their own availability checks.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139171
Approved by: https://github.com/jansel
2025-03-11 03:56:11 +00:00
David Berard
c3b05c4a27 [triton 3.3] support both specialize_impl and create_specialize_impl (#148806)
After https://github.com/triton-lang/triton/pull/6099, we sometimes need to do `from triton.runtime.jit import specialize impl` and sometimes do `triton.runtime.jit.create_specialize_impl()`. This should fix a bunch of the new errors that appeared with the triton 3.3 / pytorch 2.7 integration (e.g. `python test/inductor/test_aot_inductor.py -k test_triton_kernel_equal_to_1_float_arg_dynamic_False_cuda`, failing at https://hud.pytorch.org/pr/pytorch/pytorch/148684#38392501220)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/148806
Approved by: https://github.com/drisspg
2025-03-08 09:31:52 +00:00
Thomas Bohnstingl
23441492f6 [scan] Refactoring of input checking and dynamo invocation (#142125)
This PR does a refactoring of the way dynamo is invoked and how the input shapes are checked for scan and for associative_scan

Pull Request resolved: https://github.com/pytorch/pytorch/pull/142125
Approved by: https://github.com/ydwu4
2025-03-06 01:06:54 +00:00
Ryan Guo
ad9a10aff0 [dynamo] Make nonstrict_trace work with some pytree.register_constant-ed instances (#148007)
As title, this enables `nonstrict_trace`-ed function to take in object
whose type has been `pytree.register_constant`-ed, as long as the object
existed outside the `torch.compile` region. This also forces Dynamo to
emit a `EQUALS_MATCH` guard on the object.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148007
Approved by: https://github.com/zou3519
ghstack dependencies: #148385
2025-03-05 21:28:26 +00:00
Thomas Bohnstingl
e4c558be1d [scan] Corrections for scan (#146110)
This PR resolves some minor issues with the scan HOP and unifies the handling of the additional_inputs in the same way as for associative_scan.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146110
Approved by: https://github.com/ydwu4
2025-03-04 20:29:08 +00:00
Sijia Chen
4995e058bf [user-triton] handle inline_asm_case (#148043)
Summary: We currently failed the mutation analysis for all inline_asm ops. In this diff, we handle the case when "is_pure" is set to True since it indicates the operation doesn't mutate the input value

Test Plan:
../buck-out/v2/gen/fbcode/854b9ed00d28c5c5/caffe2/test/inductor/__triton_kernels__/triton_kernels.par --r test_mutations_inline_asm_kernel

```
test_mutations_inline_asm_kernel_is_pure_true (caffe2.test.inductor.test_triton_kernels.MutationTests) ... W0226 18:10:34.261000 1906801 /data/users/sijiac/fbsource/fbcode/caffe2/torch/_higher_order_ops/triton_kernel_wrap.py:656] TTIR mutation analysis: Skipping pure tt.elementwise_inline_asm op (is_pure=True)
ok

----------------------------------------------------------------------
Ran 2 tests in 0.706s

OK
```

Differential Revision: D69878591

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148043
Approved by: https://github.com/zou3519
2025-02-28 20:52:51 +00:00
Thomas Bohnstingl
7c71ab1d40 [scan] User-facing reverse flag handling (#147886)
This PR removes the reverse flag from the backend implementation and resolves it via `torch.flip` in the frontend.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147886
Approved by: https://github.com/ydwu4
2025-02-26 20:04:57 +00:00
Ryan Guo
f46f0e465c [dynamo] Initial support for nonstrict_trace (#146367)
## Context
> **Note:** `mark_traceable` got renamed to `nonstrict_trace` after
> offline discussion. The reasons are (1) it aligns with `torch.export`'s
> `nonstrict` notion, and (2) it's more definitive in behavior suggestion.

1. [Overall Design](https://docs.google.com/document/d/1O-dR2ZQaJQVt_v67AVcDCw2yJLtqgkZFwoXK0buEWRg/edit?tab=t.0)
2. [Dynamo graph representation with `torch._higher_order_ops.flat_apply`](https://docs.google.com/document/d/1YHl5nPTJvYeCPE5TO9uA18DPWNgUYGE4gCn6bFvXcBM/edit?tab=t.0#heading=h.xtw3hhbro4gn)

## Summary
This patch adds a `torch._dynamo.nonstrict_trace` decorator, which
currently is an enhanced version of `torch._dynamo.allow_in_graph` (see
docstring for their differences). Specifically, this patch focuses on
the UI and functionality prototyping/plumbing.

The main enhancement is supporting more input types, and the
implementation challenge lies in reconstructing the input objects from
Dynamo `VariableTracker` (while accounting for buffered side-effects and
guards).  This patch takes a middle-ground (simple implementation with a
bit of user labor), by
1. asking the user to provide pytree registration for non-proxy-able
   input types,
2. letting Dynamo trace through `pytree_flatten` (which accounts for
   buffered side-effects and guards automatically),
3. and passing in the TreeSpec as a graph attribute constant into
   `torch._higher_order_ops.flat_apply` (which unflattens the inputs and
   invokes the underlying function).

## Next Steps
In subsequent patches, we will try to support the following:
- annotating on class method
- reads to global tensors
- inputs that contains `pytree.register_constant`-ed instances.
- function as input
- more output types (e.g., any pytree-registered type)
- `torch.nn.Module` as inputs

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146367
Approved by: https://github.com/zou3519
ghstack dependencies: #146714
2025-02-26 19:47:39 +00:00
Ryan Guo
bab84f0bd9 [hop] Support more output types for flat_apply (#146714)
This patch enables `flat_apply` to support certain non-Tensor output
types like containers and graphable types. This will in turn enable the
upcoming `mark_traceable` to support more output types.

The patch also exposes a `func_to_graphable` rather than having the
users calling the lower level `pytree.flatten(ConstantFunction(...))`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146714
Approved by: https://github.com/zou3519
2025-02-26 19:47:39 +00:00
rzou
fb566c5aea Fix auto_functionalize x inference_mode (#147925)
Fixes #147924

We were using the wrong FunctionalTensorMode to construct
FunctionalTensors. FunctionalTensors modify the FunctionalTensorMode on
construction, so that led to the wrong FunctionalTensorMode being
modified. This PR threads the FunctionalTensorMode through correctly.

Test Plan:
- new test

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147925
Approved by: https://github.com/bdhirsh
2025-02-26 18:05:30 +00:00
Yidi Wu
824474cb35 [cond] support output sizes mismatch in front end (#147130)
This PR finishes https://github.com/pytorch/pytorch/pull/137615 by addressing the TODOs and comments left there.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147130
Approved by: https://github.com/zou3519
2025-02-25 20:28:41 +00:00
David Berard
94969d0a40 [inductor][user triton] Handle scf.yield more accurately (#147762)
**TL;DR**: Previously, the mutation analysis for scf.if/scf.for would bundle all the scf.yield arguments into a single op (the scf.yield), such that a mutation on any returned value from the scf.if/scf.for would register as a mutation to _all_ of the scf.yield args. To fix this, this PR artificially introduces a new scf.yield op for each of the scf.yield args.

**Context**: The relevant kernel is something like this one (added as a test in test_triton_kernels.py)

```python
        @triton.jit
        def branch_with_multiple_yield_args(
            in_ptr0,
            in_ptr1,
            out_ptr,
            conditional_ptr,
            n_elements,
            BLOCK_SIZE: "tl.constexpr",
        ):
            pid = tl.program_id(axis=0)
            block_start = pid * BLOCK_SIZE
            offsets = block_start + tl.arange(0, BLOCK_SIZE)
            mask = offsets < n_elements
            conditional = tl.load(conditional_ptr)
            if conditional:
                in0 = in_ptr0 + 1
                in1 = in_ptr1 + 1
                out = out_ptr + 1
            else:
                in0 = in_ptr0
                in1 = in_ptr1
                out = out_ptr
            x = tl.load(in0 + offsets, mask=mask)
            y = tl.load(in1 + offsets, mask=mask)
            tl.store(out + offsets, x + y, mask=mask)
```

The mutation analysis starts with the `tl.store` - and then does a DFS backwards towards the parameters.  When a new op is encountered in the DFS, the analysis pass recurses on the op's arguments.

The if branch gets converted to TTIR like this:

```mlir
    %21:3 = scf.if %20 -> (!tt.ptr<f32>, !tt.ptr<f32>, !tt.ptr<f32>) {
      ...
      scf.yield %31, %32, %33 : !tt.ptr<f32>, !tt.ptr<f32>, !tt.ptr<f32> loc(#loc10)
    } else {
      scf.yield %arg0, %arg1, %arg2 : !tt.ptr<f32>, !tt.ptr<f32>, !tt.ptr<f32> loc(#loc11)
    } loc(#loc7)
```

and so the "source" op of the `out` variable is marked as the `scf.yield` op - and then all of the arguments to `scf.yield` are marked as mutable (including arg0, arg1, and arg2 - only one of which is actually mutated).

**This PR** we duplicate the `scf.yield` to add one `scf.yield` per return value. That way we avoid marking all the returns from the scf.if/scf.for as mutated when only some are.

Differential Revision: [D70118202](https://our.internmc.facebook.com/intern/diff/D70118202)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/147762
Approved by: https://github.com/oulgen, https://github.com/zou3519
2025-02-25 08:41:00 +00:00
Xuehai Pan
a50af71fb6 [FX] Refactor immutable collections implementation (#144640)
Get rid of dynamic class creation via `type(name, bases, ...)`. Convert it to classic static class definition for better readability and static analysis support.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144640
Approved by: https://github.com/jansel
ghstack dependencies: #147699
2025-02-24 09:14:08 +00:00
Thomas Bohnstingl
6eb795c9e8 [associative_scan] compile backend change to "eager" (#146973)
This PR fixes some issues with torch export discussed here: https://github.com/pytorch/pytorch/pull/140043#discussion_r1941932960

However, this backend change does still not resolve the failure for specific shapes mentioned here: https://github.com/pytorch/pytorch/issues/137943#issuecomment-2649564994

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146973
Approved by: https://github.com/ydwu4
2025-02-21 20:21:41 +00:00
Aaron Orenstein
db4ce78d46 PEP585: More UP006 fixes (#146392)
This should be the final PR before we can enable RUFF UP006.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146392
Approved by: https://github.com/justinchuby, https://github.com/albanD, https://github.com/Skylion007
2025-02-20 06:18:13 +00:00
Yidi Wu
77aa602871 [torchbind] Differentiate ScriptModule and ScriptObject with qualified name (#147399)
Summary:
This pr add a _is_script_object method to differentiate scriptModule and scriptObject, where the formal inherits from ScriptObject in C++ so they both passes the isinstance(obj, torch.ScriptObject) check.

The qualified name of ScriptObject (i.e. custom class) would starts with "__torch__.torch.classes", this has been a widely used assumption for dealing with custom class across our code base.

Test Plan: Add new test.

Differential Revision: D69685316

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147399
Approved by: https://github.com/yushangdi
2025-02-20 04:57:57 +00:00
Shangdi Yu
0b0da81021 Support static method of torchbind attributes in torch.compile with inductor backend (#146927)
As title.

Many changes adapted from https://github.com/pytorch/pytorch/pull/129537.

Also this diff is only for *static* method of torchbind *attributes*. Some case that's not supported/tested:
- dynamic torchbind objects
-  torchbind objects as an input to the module.

Note that in JIT Inductor, the attributes are lifted as inputs. So even if we just have torchbind objects as attributes, they will show up as inputs in the graph.

Example generated python code in torch.compile with inductor backend for the test case in `inductor/test_torchbind.py` (P1730554370):

```python
async_compile.wait(globals())
del async_compile

def call(args):
    arg1_1, arg2_1, arg3_1 = args
    args.clear()
    assert_size_stride(arg1_1, (2, 3), (3, 1))
    assert_size_stride(arg2_1, (2, 3), (3, 1))
    buf2 = empty_strided_cpu((2, 3), (3, 1), torch.float32)
    cpp_fused_add_0(arg1_1, arg2_1, buf2)
    del arg1_1
    del arg2_1
    # Topologically Sorted Source Nodes: [x, takes_foo_tuple_return], Original ATen: [aten.add]
    buf3 = torch.ops._TorchScriptTesting.takes_foo_tuple_return.default(arg3_1, buf2)
    buf4 = buf3[0]
    assert_size_stride(buf4, (2, 3), (3, 1))
    buf5 = buf3[1]
    assert_size_stride(buf5, (2, 3), (3, 1))
    buf6 = buf4; del buf4  # reuse
    cpp_fused_add_1(buf6, buf5)
    del buf5
    # Topologically Sorted Source Nodes: [y, b], Original ATen: [aten.add]
    buf7 = torch.ops._TorchScriptTesting.takes_foo.default(arg3_1, buf6)
    del buf3
    del buf6
    buf8 = buf7
    assert_size_stride(buf8, (2, 3), (3, 1))
    # Topologically Sorted Source Nodes: [c], Original ATen: []
    buf9 = torch.ops.higher_order.call_torchbind(arg3_1, 'add_tensor', buf2)
    del arg3_1
    del buf7
    buf10 = buf9
    assert_size_stride(buf10, (2, 3), (3, 1))
    del buf9
    buf11 = buf2; del buf2  # reuse
    cpp_fused_add_2(buf11, buf8, buf10)
    return (buf11, )

def benchmark_compiled_module(times=10, repeat=10):
    from torch._dynamo.testing import rand_strided
    from torch._inductor.utils import print_performance
    arg1_1 = rand_strided((2, 3), (3, 1), device='cpu', dtype=torch.float32)
    arg2_1 = rand_strided((2, 3), (3, 1), device='cpu', dtype=torch.float32)
    import pickle
    global arg3_1
    arg3_1 = pickle.loads(b'\x80\x04\x95[\x00\x00\x00\x00\x00\x00\x00\x8c\x05torch\x94\x8c\x0cScriptObject\x94\x93\x94)\x81\x94]\x94(K\nK\x14e\x8c0__torch__.torch.classes._TorchScriptTesting._Foo\x94\x86\x94b.')
    fn = lambda: call([arg1_1, arg2_1, arg3_1])
    return print_performance(fn, times=times, repeat=repeat)

if __name__ == "__main__":
    from torch._inductor.wrapper_benchmark import compiled_module_main
    compiled_module_main('None', benchmark_compiled_module)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146927
Approved by: https://github.com/angelayi
2025-02-20 03:33:19 +00:00
rzou
fea718f062 [BaseHOP] change hop(subgraph, operands) to hop(subgraph, *operands) (#146730)
Our three main users are OK with this, with two of them (foreach_map,
invoke_quant) prefering it like this.

I was originally worried about BC issues (this now means you cannot add
any positional args) but I think that's not a concern -- one can always
add kwonly args.

Test Plan
- tests

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146730
Approved by: https://github.com/ydwu4, https://github.com/mlazos
2025-02-20 02:30:36 +00:00
Yidi Wu
85a82c5bc8 [cond] make cond re-dispatch in proxy mode (#146954)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146954
Approved by: https://github.com/zou3519
2025-02-14 23:13:14 +00:00
PyTorch MergeBot
9a883007a2 Revert "Implement cuda graphs implementation of torch.cond and torch.while_loop (#140979)"
This reverts commit c7515da7b0.

Reverted https://github.com/pytorch/pytorch/pull/140979 on behalf of https://github.com/huydhn due to This change has been reported to break internal code ([comment](https://github.com/pytorch/pytorch/pull/140979#issuecomment-2657361940))
2025-02-13 18:04:26 +00:00
PyTorch MergeBot
65e8862b9a Revert "[cond] make cond re-dispatch in proxy mode (#146954)"
This reverts commit 2ce6de2415.

Reverted https://github.com/pytorch/pytorch/pull/146954 on behalf of https://github.com/huydhn due to Sorry for reverting your change, but I need to revert it to cleanly revert 140979 ([comment](https://github.com/pytorch/pytorch/pull/146954#issuecomment-2657357742))
2025-02-13 18:02:33 +00:00
Yidi Wu
2ce6de2415 [cond] make cond re-dispatch in proxy mode (#146954)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146954
Approved by: https://github.com/zou3519
2025-02-13 00:50:33 +00:00
Thomas Bohnstingl
3a29992ee6 [associative_scan] Lifted arguments (#140043)
This PR implements lifted arguments for associative_scan

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140043
Approved by: https://github.com/ydwu4
2025-02-11 23:25:55 +00:00
Daniel Galvez
c7515da7b0 Implement cuda graphs implementation of torch.cond and torch.while_loop (#140979)
This is a new PR for #130386 , which got stale and was closed. Since I force-pushed to that branch in order to rebase it on top of main, the PR can no longer be reopened, according to https://github.com/isaacs/github/issues/361

I fixed the possibly-not-warmed-up problem described here: https://github.com/pytorch/pytorch/pull/130386/files#r1690856534

Since starting this, torch.cond and torch.while_loop now apparently have support for backward passes. I will look into what it might take to support that.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140979
Approved by: https://github.com/eqy, https://github.com/eellison
2025-02-11 18:16:15 +00:00
rzou
1d81ecfc54 Rename PrimHOPBase to BaseHOP + minor changes (#146727)
This PR:
- renames PrimHOPBase to BaseHOP
- changes the backward pass to always return a tuple (to match the
  forward pass).

Test Plan:
- tests

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146727
Approved by: https://github.com/ydwu4
2025-02-11 02:43:37 +00:00
eellison
a36c22f2ed futher scheduler changes for invoke_quant: prologue low prec, (slightly) more aggressive fusion (#145104)
Respect invoke_quant low precision options, also, be more aggressive in attepmting fusion.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145104
Approved by: https://github.com/shunting314, https://github.com/jansel
ghstack dependencies: #139102
2025-02-10 15:50:19 +00:00
eellison
92b7e610ab [Inductor changes] Invoke Quant (#139102)
Adds a `invoke_quant` higher order operator as proposed [here](https://docs.google.com/document/d/1s2PfJlq6Q1F8l11CkTIC69BW1rEnGEgs6YmBC7hu8rA/edit?tab=t.0).

The primary motivations are

- Unifying scattered reasoning for quant operators throughout the code base

- Easy of pattern matching - see this very large pattern match expression [here](949fdd2997/torch/_inductor/fx_passes/post_grad.py (L390-L426). Compared to the pattern I have in the tests:

```
        @register_graph_pattern(
            CallFunction(
                torch.ops.aten.mm,
                CallFunction(
                    torch.ops.higher_order.invoke_quant,
                    Ignored(),
                    Ignored(),
                    Ignored(),
                    scheme="nf4",
                ),
                Arg(),
            ),
            pass_dict=test_pass,
        )
```

- Ability to specify inductor specific logic, like codegen'ing the operators in lower precision, or forcing fusion to a matmul.

Example graph:

``` Python
 ===== AFTER POST GRAD =====
 /data/users/eellison/pytorch/torch/fx/_lazy_graph_module.py class <lambda>(torch.nn.Module):
    def forward(self, arg0_1: "f32[8][1]cpu", arg1_1: "f32[8][1]cpu"):
         # File: /data/users/eellison/pytorch/torch/_higher_order_ops/invoke_quant.py:87 in __call__, code: return invoke_quant_tracer(*args, **kwargs, quant_options=self)  # type: ignore[call-arg]
        repeated_subgraph0 = self.repeated_subgraph0
        invoke_quant: "f32[8][1]cpu" = torch.ops.higher_order.invoke_quant(repeated_subgraph0, arg0_1, arg1_1, scheme = 'nf4');  repeated_subgraph0 = arg0_1 = arg1_1 = None
        return (invoke_quant,)

    class repeated_subgraph0(torch.nn.Module):
        def forward(self, arg0_1: "f32[8][1]cpu", arg1_1: "f32[8][1]cpu"):
             # File: /data/users/eellison/pytorch/torch/_higher_order_ops/invoke_quant.py:87 in __call__, code: return invoke_quant_tracer(*args, **kwargs, quant_options=self)  # type: ignore[call-arg]
            mul: "f32[8][1]cpu" = torch.ops.aten.mul.Tensor(arg0_1, arg1_1);  arg0_1 = None
            add: "f32[8][1]cpu" = torch.ops.aten.add.Tensor(mul, arg1_1);  mul = arg1_1 = None
            return add
```

The schema for `invoke_quant` is `torch.ops.higher_order.invoke_quant(subgraph, *args, scheme=None)` where the scheme will not always be present.

I wasn't sure exactly how the inductor specific configurations like `codgen_in_low_precision` should be passed through. I didnt want to stuff them all in as kwargs, and I didn't want to have them affect pattern matching. So they will be stored as meta of the node itself. And, following that, I wanted the invocation of the hop to match how it will show up in the graph. So I decided to have it be an object that is then invoked for the tracing.

```
invoke_quant = InvokeQuant(codegen_low_precision=True)
invoke_quant(gn, (x, y), scheme="nf4")
```
Todo - not require the packing of args in a tuple, will do following https://github.com/pytorch/pytorch/pull/139162.

Feedback welcome.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139102
Approved by: https://github.com/Chillee
2025-02-08 19:30:19 +00:00
drisspg
69feef5a94 Fix broken meta function for flex-attention backwards (#146563)
# Summary

Fixes https://github.com/pytorch/pytorch/issues/146377

So what was the original problem: we were codegening a really weird epilogue:

```Python
        # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM]
        # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM]
        xindex = index_k + 64*index_n + 64*off_hkv*ks2 + 128*off_zq*ks2
        tl.store(out_ptr0 + (tl.broadcast_to(index_k + 64*index_n + off_hkv*ks1, dk.shape)), dk, mask)
        x5 = (xindex % ks3)
        tmp2 = tl.load(out_ptr0 + (x5 + ks1*off_hkv), mask, eviction_policy='evict_last')
        tl.store(out_ptr1 + (tl.broadcast_to(xindex, dk.shape)), tmp2, mask)
 ```

 This epilogue was writing and then reading from overlapping regions of memory causing a race condition.

 ### Why were we generating this epilgoue

 During the lowering we created a buffer w/ a different size/stride from the expected return strides. I :think this added an implicit node (for doing the permutation of this wrongly strided output to the the expected one from the meta func. The scheduler for some reason thought it was okay to fuse this into the epilogue, tbh I dont know why.

 This fixes the broken meta func and the original repro. I will add a test but it is hard to pop, better than nothing

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146563
Approved by: https://github.com/Chillee
2025-02-08 04:13:52 +00:00
rzou
1bb977a2a4 [auto_functionalized] Support Tensor(a!)[]? (#145400)
Summary:
This is just updating some of the checks to allow the Tensor(a!)[]? type
through.

Fixes #144072

Test Plan:
- new tests

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145400
Approved by: https://github.com/laithsakka
2025-02-05 14:52:39 +00:00
rzou
0f768c7866 Barebones flat_apply HOP (#146060)
This PR:
- adds pytree.register_constant for registering a class to be treated as
  a constant by torch.compile/torch.fx
- adds a very barebones flat_apply HOP. This should be sufficient to get
  mark_traceable working. A lot more work is necessary to get the custom
  operator case working (when make_fx sees a custom operator with PyTree
  arg types, it needs to emit a call to the flat_apply HOP).
- I expect the flat_apply HOP to change a lot, I want to ship this in
  the current state to unblock the mark_traceable and custom ops
  work.

Test Plan:
- It's kind of difficult to test the barebones flat_apply HOP "works" so
  I added a really simple test.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146060
Approved by: https://github.com/StrongerXi, https://github.com/yanboliang
ghstack dependencies: #146059
2025-02-01 16:17:48 +00:00
Sherlock Huang
cf2de4e230 Introduce aoti_call_delegate HOP (#145630)
Summary:
Previously, aoti compile node is represented as a kernel-less custom op in the exported program. The node was not eager runnable, which is a common practice for numerical validation during lowering.

I introduce a new HOP to address this.

The schema is following
```
aoti_call_delegate(lower_moduel: AOTInductorEPModule, original_gm: fx.GraphModule, weights: List[Tensor], inputs: List[Tensor])
```

There are a few problems exposed by HOP
- AOTI expects a FX graph with weights as getattr nodes, aka stateful graph. HOP expect graph_module arguments to be stateless. Export serializer also expect a stateless graph. Currently, to make AOTI happy, I am making `original_gm` stateful, and bypassing the serialization for `original_gm`.
- As a result, the HOP is not re-traceable, as functionalization on stateful graph module argument will fail.

Test Plan: buck2 test 'fbcode//mode/opt' fbcode//deeplearning/aot_inductor/cpu/test:cpu_lowering_utils_test

Reviewed By: zhxchen17

Differential Revision: D68359391

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145630
Approved by: https://github.com/zou3519
2025-01-31 04:57:36 +00:00
Yidi Wu
7e7341bddd [hop] fix unbacked_bindings meta for while_loop (#143559)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143559
Approved by: https://github.com/zou3519
2025-01-30 21:33:09 +00:00
Thomas Bohnstingl
9f9904172d [scan] scan dim handling in user-facing scan() (#145179)
This PR introduces the capability that the scan dim is handled in the user facing scan() call. Internally, the scan dim is always shifted to dim 0 and then the scan is performed over that dim.

This is a follow-up PR from https://github.com/bohnstingl/pytorch/pull/3

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145179
Approved by: https://github.com/ydwu4
2025-01-30 21:09:07 +00:00
Yidi Wu
a3698ebd5c [while_loop] specialize when cond_fn return constants (#144515)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144515
Approved by: https://github.com/zou3519
2025-01-30 19:02:34 +00:00
Yidi Wu
d1143c4b37 [export] fix non-strict pre_dispatch exporting while_loop (#145762)
fix https://github.com/pytorch/pytorch/issues/145737.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145762
Approved by: https://github.com/tugsbayasgalan, https://github.com/zou3519, https://github.com/avikchaudhuri
2025-01-30 18:58:34 +00:00
rzou
1e57154af3 Require that all HOPs be imported at import torch time (#145939)
E.g. torch.ops.higher_order.cond does not exist until it is imported,
which is bad if it shows up in an FX graph or is used in some code
somewhere.

This PR also makes some more HOPs get imported at `import torch` time.

Test Plan:
- new tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145939
Approved by: https://github.com/ydwu4
ghstack dependencies: #145938
2025-01-29 22:27:52 +00:00
Bert Maher
ae0f305bf9 [inductor] Make triton kernel autotune config defaults backward-compatible (#145494)
If a model was torch.packaged using triton<=3.1, any user-defined
autotuned kernels will have reps/warmups burned in with the old defaults
(100/25).  If this model is loaded with triton>=3.2, inductor's checks for
unsupported non-default autotune args will fail, because triton.Autotuner's
defaults for these parameters has changed to `None`.  Let's explicitly support
those values for backward compatibility with these older models.

Differential Revision: [D68561014](https://our.internmc.facebook.com/intern/diff/D68561014/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145494
Approved by: https://github.com/aorenste
2025-01-29 00:31:39 +00:00
Thomas Bohnstingl
82859f6185 [associative_scan] scan dim handling in user-facing associative_scan() (#139864)
This PR implements the user-facing dim change, i.e., that the scan dim provided by the user is always moved to dim 0 and then the associative_scan operation always operates on dim 0.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139864
Approved by: https://github.com/ydwu4
2025-01-28 23:58:10 +00:00
Pian Pawakapan
1a26cdd5cb [cond] remove warning for unsupported tuple returns (#145766)
I guess this is supported now
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145766
Approved by: https://github.com/ydwu4, https://github.com/zou3519
2025-01-28 03:13:36 +00:00
Aaron Gokaslan
f3304571fc [BE][Ez]: FURB148 - remove useless enumerate calls (#145619)
Remove useless enumerate calls

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145619
Approved by: https://github.com/drisspg
2025-01-24 23:37:15 +00:00
David Berard
b2c89bc115 [inductor][2/N] triton support post-#5512, user-defined triton kernels (#145348)
Triton commit 5220 adds tuple support in Triton (changing the indexing format in AttrsDescriptor) and commit 5512 replaces AttrsDescriptor with raw tuples. This PR fixes user-defined triton kernel handling (in most cases) for these new triton commits.

What this PR fixes:
* in triton_kernel_wrap.py, AST->TTIR parsing was to be updated for the new triton API
* ir.py - don't remove None args when using newer triton versions
* wrapper.py - update signature & constant handling

What this doesn't fix:
* correct None handling - I want to do a closer look at constant handling (including None, equal_to_1, and other constants).
* cpp wrapper (which needs to be fixed for both user-defined triton kernels and inductor-generated kernels)

test/inductor/test_triton_kernels.py passed on triton commit 74de6b46, with the exception of three tests (those shown here: 1374074098)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145348
Approved by: https://github.com/jansel
ghstack dependencies: #145051
2025-01-24 00:34:01 +00:00
Yidi Wu
3c247ee8c4 [hop][be] add utils for more comprehensive input alias and mutation (#145298)
This PR implements  the idea of checking input mutations through tensor version and check aliasing via storage  from @zou3519. Previously, we rely on whether there's a in place op that takes placeholder input, which doesn't take views into account.

When writing the PR, I also noticed a bug in previous input mutation checking logic: we were checking the whether there are operators functionalized_f where all the mutating ops have been replaced so we won't be able to detect any thing.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145298
Approved by: https://github.com/zou3519
2025-01-23 18:12:28 +00:00
Aaron Orenstein
805c4b597a PEP585 update - torch/_higher_order_ops torch/_subclasses torch/backends torch/compiler torch/cuda torch/masked torch/mtia torch/nested (#145202)
See #145101 for details.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145202
Approved by: https://github.com/bobrenjc93
2025-01-20 22:37:26 +00:00
Aaron Orenstein
d782e46a36 [BE] typing for decorators - library (#138969)
Test Plan: unit tests

Differential Revision: D62302678

Pull Request resolved: https://github.com/pytorch/pytorch/pull/138969
Approved by: https://github.com/zou3519
2025-01-15 17:08:55 +00:00
Yiming Zhou
87843ee9ab [export] Unify single and multiple return for hops (#143227)
Summary: Introduce `is_hop_single_tensor_return` field to the `Node` class in serialization so that during deserialization when there is a single return, we know whether it is a tuple of a single element or a single element.

Test Plan:
```
buck2 run @mode/dev-nosan sigmoid/inference/test:e2e_test_cpu -- -r E2ETestCPUCond
buck2 run @mode/dev-nosan sigmoid/inference/test:test_passes -- -r test_const_folding2
```

Differential Revision: D66991624

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143227
Approved by: https://github.com/zhxchen17
2025-01-13 03:31:14 +00:00
Sam Ginzburg
074aca3ed2 [user triton] add support for @triton.heuristics after @triton.autotune (#142208)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/142208
Approved by: https://github.com/zou3519
2025-01-11 02:18:26 +00:00