Commit Graph

443 Commits

Author SHA1 Message Date
Jason Ansel
fed37dbfbc [inductor] Cooperative reductions (#137756)
Example generated code for `(x+y).sum()`:
```py
@triton.jit
def triton_unk_fused_add_sum_0(in_ptr0, in_ptr1, out_ptr0, ws_ptr, semaphores_ptr, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr, RSPLIT : tl.constexpr):
    xnumel = 1
    rnumel = 1048576
    rsplit_id = tl.program_id(0)
    num_rblocks = (rnumel + RBLOCK - 1) // RBLOCK
    rsplit_chunk = (num_rblocks + RSPLIT - 1) // RSPLIT * RBLOCK
    rsplit_start = rsplit_chunk * rsplit_id
    rsplit_end = rsplit_chunk * (rsplit_id + 1)
    xoffset = tl.program_id(1) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = tl.full([XBLOCK, RBLOCK], True, tl.int1)
    rbase = tl.arange(0, RBLOCK)[None, :]
    _tmp4 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
    for roffset in range(rsplit_start, rsplit_end, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r0 = rindex
        tmp0 = tl.load(in_ptr0 + (r0), rmask, eviction_policy='evict_first', other=0.0)
        tmp1 = tl.load(in_ptr1 + (r0), rmask, eviction_policy='evict_first', other=0.0)
        tmp2 = tmp0 + tmp1
        tmp3 = tl.broadcast_to(tmp2, [XBLOCK, RBLOCK])
        tmp5 = _tmp4 + tmp3
        _tmp4 = tl.where(rmask, tmp5, _tmp4)
    tmp4 = tl.sum(_tmp4, 1)[:, None]
    if RSPLIT > 1:
        tmp4_ws = (ws_ptr + 0).to(tl.pointer_type(tl.float32))
        tl.store(tmp4_ws + (xindex * RSPLIT + rsplit_id), tmp4, None)
    if RSPLIT > 1:
        triton_helpers.gpu_barrier(semaphores_ptr + (2 * tl.program_id(1) + 0), RSPLIT, True)
    if RSPLIT > 1:
        tmp4_peers = tl.load(tmp4_ws + (xindex * RSPLIT + tl.arange(0, RSPLIT)[None,:]), None, eviction_policy='evict_first')
        tmp4 = tl.sum(tmp4_peers, 1)[:, None]
    if rsplit_id == (0 % RSPLIT):
        tl.store(out_ptr0 + (tl.full([XBLOCK, 1], 0, tl.int32)), tmp4, None)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137756
Approved by: https://github.com/eellison
ghstack dependencies: #138970
2024-10-27 16:31:38 +00:00
David Berard
94e341c6a3 [user triton] fix codegen for tl.constexpr globals (#138757)
Fixes #138509

tl.constexpr globals would be codegen-ed as `constexpr()` instead of `tl.constexpr()` if they were un-annotated. This fixes the issue (and adds a test). The correct handling was already added but the corrected string was not being used in the un-annotated branch.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/138757
Approved by: https://github.com/oulgen
2024-10-25 03:00:42 +00:00
Animesh Jain
dd4dd85210 [hierarchical-compilation][inductor] Support invoke_subgraph HOP (#138031)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/138031
Approved by: https://github.com/eellison
ghstack dependencies: #137538, #138036, #137965
2024-10-23 21:32:14 +00:00
Animesh Jain
77868697b7 [inductor][subgraph] Add size asserts (#138424)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/138424
Approved by: https://github.com/eellison
ghstack dependencies: #137555
2024-10-21 22:43:49 +00:00
Sam Ginzburg
c1ead6fba3 Bugfix for passing None args to user defined Triton kernel (#138472)
add test

fewer failing tests

more tests passing

tests passing

lint

Pull Request resolved: https://github.com/pytorch/pytorch/pull/138472
Approved by: https://github.com/aakhundov
2024-10-21 20:00:04 +00:00
Jason Ansel
4632594546 [inductor] Move V.graph.scheduler.current_device to V.graph.current_device (#138252)
There are some places where it would be nice to use this, but the scheduler hasn't yet been created.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/138252
Approved by: https://github.com/eellison
ghstack dependencies: #138170
2024-10-18 23:05:54 +00:00
Jason Ansel
85a6a782e5 [inductor] Generalize WorkspaceArg for graph-level semaphores (#138170)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/138170
Approved by: https://github.com/Chillee
2024-10-18 23:05:54 +00:00
Adnan Akhundov
d116d007ee Add host-side Triton TMA support to Inductor (#137950)
This adds Dynamo tracing support for the host-side Triton TMA API (see `create_2d_tma_descriptor` calls on the host in the [Triton tutorial](https://triton-lang.org/main/getting-started/tutorials/09-persistent-matmul.html#sphx-glr-getting-started-tutorials-09-persistent-matmul-py)). A few notes:

- Here we assume the availability of the host-side TMA API added to upstream Triton in https://github.com/triton-lang/triton/pull/4498. As of time of writing, this is not a part of the PT2 OSS Triton pin (although back-ported internally). OSS Triton pin update should be done in December 2024.
- Due to Dynamo support implemented in the previous PR, the `tma_descriptor_metadata` dict is delivered to the `triton_kerenl_wrap_` lowering and passed to the `ir.UserDefinedTritonKernel` as additional argument.
- Looking into the `tma_descriptor_metadata`, `ir.UserDefinedTritonKernel` substitutes the corresponding `TensorBox` arguments of the kernel (swapped upstream in Dynamo) by the new `ir.TMADescriptor` nodes implementing TMA descriptors in Inductor IR.
- `ir.TMADescriptor.__init__` provides the wiring between the upstream underlying `ir.TensorBox` and the downstream `ir.UserDefinedTritonKernel` kernel. In particular, we use `ir.NonOwnedLayout` wrapping `ir.ReinterpretView` to avoid the upstream tensor's buffer being deleted prematurely (before the TMA descriptor is used in the Triton kernel).
- Via `ir.TMADescriptor.codegen`, the Triton's `create_{1d,2d}_tma_descriptor` function call is codegened in the wrapper (in the host code).
- New `TMADescriptorArg` dataclass is added to handle the Triton kernel metadata pertinent to host-side TMA.
- AOT Inductor support will be implemented in a follow-up PR.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137950
Approved by: https://github.com/eellison
ghstack dependencies: #137677
2024-10-18 06:27:24 +00:00
Bin Bao
2e67d7cc35 [AOTI] Remove the non-ABI-compatible mode (part 1) (#138009)
Summary: The ABI-compatible mode has been turned on as default in https://github.com/pytorch/pytorch/pull/136534. Removing the non-ABI-compatible logic to greatly simplify the wrapper codegen logic.

Differential Revision: [D64439676](https://our.internmc.facebook.com/intern/diff/D64439676)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/138009
Approved by: https://github.com/chenyang78
ghstack dependencies: #137982, #138016
2024-10-17 02:48:26 +00:00
Jason Ansel
5fee1ee3f4 [inductor] Refactor generate_workspace_allocation (#137673)
And some other small changes

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137673
Approved by: https://github.com/Chillee
ghstack dependencies: #137754
2024-10-13 01:25:14 +00:00
Animesh Jain
04adb74d08 [inductor][cond] Remove redundant prefix (#137718)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137718
Approved by: https://github.com/eellison
ghstack dependencies: #137200
2024-10-11 18:13:18 +00:00
Animesh Jain
cd02c85ba4 [inductor][subgraph][python-wrapper] Lift subgraph code into functions (#137200)
Earlier the subgraphs were getting inlined into the output code. This PR lifts the subgraphs into a function, and then we just call the function in the output code.

This is the output code for test `test_cond_reintepret_view_inputs_outputs`

Before this PR - https://www.internalfb.com/intern/paste/P1632948905/
With this PR - https://www.internalfb.com/intern/paste/P1632946348/

A relevant snippet from the above paste is

~~~

def false_graph_0(args):
    false_graph_0_arg0_1, false_graph_0_arg1_1, s0 = args
    args.clear()
    s0 = s0
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        false_graph_0_buf0 = empty_strided_cuda(((-1) + s0, 20), (20, 1), torch.float32)
        false_graph_0_buf1 = empty_strided_cuda(((-1) + s0, 20), (20, 1), torch.float32)
        # Unsorted Source Nodes: [cond, z1, z2], Original ATen: [aten.sub, aten.add]
        triton_poi_fused_add_sub_1_xnumel = (-20) + (20*s0)
        stream0 = get_raw_stream(0)
        triton_poi_fused_add_sub_1.run(false_graph_0_arg0_1, false_graph_0_arg1_1, false_graph_0_buf0, false_graph_0_buf1, triton_poi_fused_add_sub_1_xnumel, grid=grid(triton_poi_fused_add_sub_1_xnumel), stream=stream0)
        del false_graph_0_arg0_1
        del false_graph_0_arg1_1
    return (reinterpret_tensor(false_graph_0_buf0, ((-3) + s0, 20), (20, 1), 40), reinterpret_tensor(false_graph_0_buf1, ((-1) + s0, 16), (20, 1), 4), )

async_compile.wait(globals())
del async_compile

def call(args):
    arg0_1, arg1_1, arg2_1, arg3_1 = args
    args.clear()
    s0 = arg0_1
    assert_size_stride(arg1_1, (s0, 20), (20, 1))
    assert_size_stride(arg2_1, (s0, 20), (20, 1))
    assert_size_stride(arg3_1, (), ())
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        buf0 = [None] * 2
        buf0 = [None] * 2
        if arg3_1.item():
            # subgraph: true_graph_0
            true_graph_0_arg0_1 = reinterpret_tensor(arg1_1, ((-1) + s0, 20), (20, 1), 0)
            true_graph_0_arg1_1 = reinterpret_tensor(arg2_1, ((-1) + s0, 20), (20, 1), 0)
            (true_graph_0_buf0, true_graph_0_buf1) = true_graph_0([true_graph_0_arg0_1, true_graph_0_arg1_1, s0])
            buf0[0] = true_graph_0_buf0
            buf0[1] = true_graph_0_buf1
        else:
            # subgraph: false_graph_0
            false_graph_0_arg0_1 = reinterpret_tensor(arg1_1, ((-1) + s0, 20), (20, 1), 0)
            false_graph_0_arg1_1 = reinterpret_tensor(arg2_1, ((-1) + s0, 20), (20, 1), 0)
            (false_graph_0_buf0, false_graph_0_buf1) = false_graph_0([false_graph_0_arg0_1, false_graph_0_arg1_1, s0])
            buf0[0] = false_graph_0_buf0
            buf0[1] = false_graph_0_buf1
        del arg1_1
        del arg2_1
        del arg3_1
        buf1 = buf0[0]
        buf2 = buf0[1]
        del buf0
    return (buf1, buf2, )

~~~

The key change is to recursively call `codegen` for the subgraph and rely on `SubgraphPythonWrapper` to generate just the subgraph `fn`. The resulting subgraph_code is then inserted into the parent wrapper.

Note that this PR only works for python wrapper. For cpp wrapper, we need a lot of refactor to ensure that we don't duplicate the global variables in the outpute_code. So, for now, I fallback to the old way of inlining for cpp wrapper. I am hoping someone with more familiarity with cpp wrapper can support subgraph lifting (cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang @aakhundov).

This work will unblock hierarchical compilation (or cold start compile time work).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137200
Approved by: https://github.com/desertfire, https://github.com/eellison
2024-10-11 17:57:10 +00:00
Colin Peppler
9690cacd61 [aotinductor] Add helper fn to atomically apply size_hint to an expr w/ unbacked symints (#137537)
### Context
Fixes CUDA IMA in autotune_at_compile_time, where we would generate an example tensor with an incorrect stride.

In the case below, the stride should be (u0 * 128, 128, 1). However, we apply the fallback on the entire expr (i.e. u0 * 128).
```
# buf817 = tensor(size=(s0, u0, 128), stride=(u0 * 128, 128, 1))

buf812 = generate_example_value(
    (64, 8192, 128), (8192, 128, 1), "cuda:0", torch.bfloat16, 0
)
```

The fix is to apply the fallback on each symbol.

### Test
```
PYTORCH_NO_CUDA_MEMORY_CACHING=1 compute-sanitizer python test_aot_inductor.py -k test_stride_with_unbacked_expr_abi_compatible_cuda

========= Invalid __global__ write of size 2 bytes
```

Differential Revision: [D64074561](https://our.internmc.facebook.com/intern/diff/D64074561)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137537
Approved by: https://github.com/jingsh
2024-10-10 17:11:24 +00:00
Animesh Jain
04e48ac562 [inductor] Refactor prefix to make it easy to create subclass of PythonWrapper (#137198)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137198
Approved by: https://github.com/jansel
ghstack dependencies: #137191, #137193
2024-10-07 17:20:58 +00:00
Bin Bao
15c3479db7 [AOTI] Fix _scaled_mm ABI-compatible codegen (#137132)
Summary: Similar to https://github.com/pytorch/pytorch/pull/137008, but for supporting _scaled_mm in the ABI-compatible mode.

Differential Revision: [D63757729](https://our.internmc.facebook.com/intern/diff/D63757729)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137132
Approved by: https://github.com/chenyang78
ghstack dependencies: #137008
2024-10-04 14:05:18 +00:00
David Berard
54094c0c26 [inductor][user triton] Check size hints to determine indexing dtype (#137234)
Previously, all integer inputs to user-defined triton kernels were assumed to be int32. This would result in errors if your input was actually an int64.

This PR checks the value to determine which dtype to use for indexing: if it is known to be < int_max, then use int32 (and add guards if relevant); if we can't check (e.g. unbacked symint), then use int64.

Differential Revision: [D63797975](https://our.internmc.facebook.com/intern/diff/D63797975)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137234
Approved by: https://github.com/eellison
2024-10-03 22:07:26 +00:00
Edward Z. Yang
cc8f1cddd4 Turn on type-checking in torch.fx.experimental.symbolic_shapes (#136972)
Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136972
Approved by: https://github.com/Skylion007
ghstack dependencies: #136934, #136935
2024-10-01 13:22:10 +00:00
niklasz
3f457ee1f6 Fix AOT Graph capture not propagating non_blocking copy parameter to … (#136513)
…inductor codegen.

Fixes #136260

**Note**: this is my first code contribution to torch so please let me know if there's anything I need to fix/some other convention I should follow.

Regarding the bug, re-running the issue's reproduction code:
```
import torch

def fn(x):
    return x.to(device="cuda", non_blocking=True)

inp = torch.randn(3, 4)

torch.compile(fn)(inp)
```

We now have the non_blocking being passed on to codegen properly:

```
V0922 20:33:25.393000 679839 torch/fx/passes/runtime_assert.py:114] [0/0] [__graph_code] TRACED GRAPH
V0922 20:33:25.393000 679839 torch/fx/passes/runtime_assert.py:114] [0/0] [__graph_code]  ===== pre insert_deferred_runtime_asserts __compiled_fn_1 =====
V0922 20:33:25.393000 679839 torch/fx/passes/runtime_assert.py:114] [0/0] [__graph_code]  <eval_with_key>.0 class GraphModule(torch.nn.Module):
V0922 20:33:25.393000 679839 torch/fx/passes/runtime_assert.py:114] [0/0] [__graph_code]     def forward(self, L_x_: "f32[3, 4]"):
V0922 20:33:25.393000 679839 torch/fx/passes/runtime_assert.py:114] [0/0] [__graph_code]         l_x_ = L_x_
V0922 20:33:25.393000 679839 torch/fx/passes/runtime_assert.py:114] [0/0] [__graph_code]
V0922 20:33:25.393000 679839 torch/fx/passes/runtime_assert.py:114] [0/0] [__graph_code]          # File: /home/niklasz/Desktop/pytorch/temp/reproduction.py:4 in fn, code: return x.to(device="cuda", non_blocking=True)
V0922 20:33:25.393000 679839 torch/fx/passes/runtime_assert.py:114] [0/0] [__graph_code]         to: "f32[3, 4]" = l_x_.to(device = 'cuda', non_blocking = True);  l_x_ = None
V0922 20:33:25.393000 679839 torch/fx/passes/runtime_assert.py:114] [0/0] [__graph_code]         return (to,)
V0922 20:33:25.393000 679839 torch/fx/passes/runtime_assert.py:114] [0/0] [__graph_code]
V0922 20:33:25.393000 679839 torch/fx/passes/runtime_assert.py:114] [0/0] [__graph_code]
V0922 20:33:25.394000 679839 torch/_dynamo/output_graph.py:1340] [0/0] [__graph_code] TRACED GRAPH
V0922 20:33:25.394000 679839 torch/_dynamo/output_graph.py:1340] [0/0] [__graph_code]  ===== __compiled_fn_1 =====
V0922 20:33:25.394000 679839 torch/_dynamo/output_graph.py:1340] [0/0] [__graph_code]  /home/niklasz/Desktop/pytorch/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
V0922 20:33:25.394000 679839 torch/_dynamo/output_graph.py:1340] [0/0] [__graph_code]     def forward(self, L_x_: "f32[3, 4][4, 1]cpu"):
V0922 20:33:25.394000 679839 torch/_dynamo/output_graph.py:1340] [0/0] [__graph_code]         l_x_ = L_x_
V0922 20:33:25.394000 679839 torch/_dynamo/output_graph.py:1340] [0/0] [__graph_code]
V0922 20:33:25.394000 679839 torch/_dynamo/output_graph.py:1340] [0/0] [__graph_code]          # File: /home/niklasz/Desktop/pytorch/temp/reproduction.py:4 in fn, code: return x.to(device="cuda", non_blocking=True)
V0922 20:33:25.394000 679839 torch/_dynamo/output_graph.py:1340] [0/0] [__graph_code]         to: "f32[3, 4][4, 1]cuda:0" = l_x_.to(device = 'cuda', non_blocking = True);  l_x_ = None
V0922 20:33:25.394000 679839 torch/_dynamo/output_graph.py:1340] [0/0] [__graph_code]         return (to,)
V0922 20:33:25.394000 679839 torch/_dynamo/output_graph.py:1340] [0/0] [__graph_code]
V0922 20:33:25.394000 679839 torch/_dynamo/output_graph.py:1340] [0/0] [__graph_code]
V0922 20:33:25.404000 679839 torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py:114] [0/0] [__aot_graphs] aot_config id: 0, fw_metadata=ViewAndMutationMeta(input_info=[InputAliasInfo(is_leaf=True, mutates_data=False, mutates_metadata=False, mutations_hidden_from_autograd=True, mutations_under_no_grad_or_inference_mode=False, mutation_inductor_storage_resize=False, mutates_storage_metadata=False, requires_grad=False, keep_input_mutations=True)], output_info=[OutputAliasInfo(output_type=<OutputType.non_alias: 1>, raw_type=<class 'torch._subclasses.functional_tensor.FunctionalTensor'>, base_idx=None, dynamic_dims=set(), requires_grad=False, functional_tensor=None)], num_intermediate_bases=0, keep_input_mutations=True, traced_tangents=[], subclass_inp_meta=[0], subclass_fw_graph_out_meta=[0], subclass_tangent_meta=[], is_train=False, traced_tangent_metas=None, num_symints_saved_for_bw=None, grad_enabled_mutation=None, deterministic=None, static_input_indices=[], tokens={}, indices_of_inputs_that_requires_grad_with_mutations_in_bw=[], bw_donated_idxs=None, num_backward_tokens=0),subclass_metadata=None
I0922 20:33:25.409000 679839 torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py:204] [0/0] [__aot_graphs] TRACED GRAPH
I0922 20:33:25.409000 679839 torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py:204] [0/0] [__aot_graphs]  ===== Forward graph 0 =====
I0922 20:33:25.409000 679839 torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py:204] [0/0] [__aot_graphs]  /home/niklasz/Desktop/pytorch/torch/fx/_lazy_graph_module.py class <lambda>(torch.nn.Module):
I0922 20:33:25.409000 679839 torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py:204] [0/0] [__aot_graphs]     def forward(self, arg0_1: "f32[3, 4][4, 1]cpu"):
I0922 20:33:25.409000 679839 torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py:204] [0/0] [__aot_graphs]          # File: /home/niklasz/Desktop/pytorch/temp/reproduction.py:4 in fn, code: return x.to(device="cuda", non_blocking=True)
I0922 20:33:25.409000 679839 torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py:204] [0/0] [__aot_graphs]         device_put: "f32[3, 4][4, 1]cuda:0" = torch.ops.prims.device_put.default(arg0_1, device(type='cuda', index=0), True);  arg0_1 = None
I0922 20:33:25.409000 679839 torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py:204] [0/0] [__aot_graphs]         convert_element_type: "f32[3, 4][4, 1]cuda:0" = torch.ops.prims.convert_element_type.default(device_put, torch.float32);  device_put = None
I0922 20:33:25.409000 679839 torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py:204] [0/0] [__aot_graphs]         return (convert_element_type,)
I0922 20:33:25.409000 679839 torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py:204] [0/0] [__aot_graphs]
I0922 20:33:25.409000 679839 torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py:204] [0/0] [__aot_graphs]
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1134] [0/0] [__output_code] Output code written to: /tmp/torchinductor_niklasz/ha/chaai264g6ribfw3q2qhl6ayjtaqaavku5wivxtzw4nabgd6htsv.py
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] Output code:
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] # AOT ID: ['0_inference']
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] from ctypes import c_void_p, c_long, c_int
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] import torch
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] import math
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] import random
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] import os
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] import tempfile
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] from math import inf, nan
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] from torch._inductor.hooks import run_intermediate_hooks
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] from torch._inductor.utils import maybe_profile
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] from torch._inductor.codegen.memory_planning import _align as align
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] from torch import device, empty_strided
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] from torch._inductor.async_compile import AsyncCompile
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] from torch._inductor.select_algorithm import extern_kernels
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] from torch._inductor.codegen.multi_kernel import MultiKernelCall
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code]
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] aten = torch.ops.aten
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] inductor_ops = torch.ops.inductor
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] _quantized = torch.ops._quantized
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] assert_size_stride = torch._C._dynamo.guards.assert_size_stride
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] alloc_from_pool = torch.ops.inductor._alloc_from_pool
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] async_compile = AsyncCompile()
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code]
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code]
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] async_compile.wait(globals())
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] del async_compile
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code]
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] def call(args):
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code]     arg0_1, = args
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code]     args.clear()
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code]     assert_size_stride(arg0_1, (3, 4), (4, 1))
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code]     with torch.cuda._DeviceGuard(0):
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code]         torch.cuda.set_device(0)
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code]         buf0 = empty_strided_cuda((3, 4), (4, 1), torch.float32)
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code]         buf0.copy_(arg0_1, True)
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code]         del arg0_1
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code]     return (buf0, )
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code]
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code]
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] def benchmark_compiled_module(times=10, repeat=10):
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code]     from torch._dynamo.testing import rand_strided
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code]     from torch._inductor.utils import print_performance
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code]     arg0_1 = rand_strided((3, 4), (4, 1), device='cpu', dtype=torch.float32)
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code]     fn = lambda: call([arg0_1])
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code]     return print_performance(fn, times=times, repeat=repeat)
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code]
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code]
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] if __name__ == "__main__":
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code]     from torch._inductor.wrapper_benchmark import compiled_module_main
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code]     compiled_module_main('None', benchmark_compiled_module)
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code]
```
See above line `buf0.copy_(arg0_1, True)`. Specific log setting used: `export TORCH_LOGS="graph_code,aot_graphs,output_code"`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136513
Approved by: https://github.com/eellison
2024-10-01 00:32:47 +00:00
eellison
18e707645c Substitute unbacked symints in expressions (#137020)
Differential Revision: [D63647095](https://our.internmc.facebook.com/intern/diff/D63647095)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137020
Approved by: https://github.com/ezyang
2024-09-30 23:07:22 +00:00
PyTorch MergeBot
8982906502 Revert "Turn on type-checking in torch.fx.experimental.symbolic_shapes (#136972)"
This reverts commit 3ff2d93d9f.

Reverted https://github.com/pytorch/pytorch/pull/136972 on behalf of https://github.com/ezyang due to need to back out for merge conflict ([comment](https://github.com/pytorch/pytorch/pull/136972#issuecomment-2384182244))
2024-09-30 21:35:08 +00:00
Jez Ng
71aac59e93 Add Triton CPU as an Inductor backend (#133408)
The goal is to use Inductor-generated kernels to stress test the new Triton CPU backend.

Differential Revision: [D63298968](https://our.internmc.facebook.com/intern/diff/D63298968)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133408
Approved by: https://github.com/jansel, https://github.com/blaine-rister, https://github.com/malfet
2024-09-30 20:24:52 +00:00
Edward Z. Yang
3ff2d93d9f Turn on type-checking in torch.fx.experimental.symbolic_shapes (#136972)
Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136972
Approved by: https://github.com/Skylion007
ghstack dependencies: #136917, #136934, #136935
2024-09-30 18:04:36 +00:00
PyTorch MergeBot
36428f91e9 Revert "Add Triton CPU as an Inductor backend (#133408)"
This reverts commit 31c0467594.

Reverted https://github.com/pytorch/pytorch/pull/133408 on behalf of https://github.com/int3 due to internal tests failing ([comment](https://github.com/pytorch/pytorch/pull/133408#issuecomment-2379692517))
2024-09-27 16:54:27 +00:00
Jez Ng
31c0467594 Add Triton CPU as an Inductor backend (#133408)
The goal is to use Inductor-generated kernels to stress test the new Triton CPU backend.

Differential Revision: [D63298968](https://our.internmc.facebook.com/intern/diff/D63298968)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133408
Approved by: https://github.com/jansel, https://github.com/blaine-rister, https://github.com/malfet
2024-09-26 15:35:26 +00:00
Bin Bao
20a855bf01 [AOTI] Move stack_allocation logic from PythonWrapperCodegen (#136463)
Summary: Move stack_allocation logic from PythonWrapperCodegen to CppWrapperCpuArrayRef

Differential Revision: [D63319970](https://our.internmc.facebook.com/intern/diff/D63319970)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/136463
Approved by: https://github.com/chenyang78
ghstack dependencies: #136062, #136461, #136462
2024-09-25 14:06:33 +00:00
Jokeren
cabfbef6cf [pytorch][PR] [inductor] More fixes on the keys of constants and signature dictionaries (#136514)
Summary: Previous PR forgets to change two other places that also create `constants` and `signature`.

Test Plan:
Imported from GitHub, without a `Test Plan:` line.
 {F1884584338}

Differential Revision: D63027728

Pulled By: Myrthan

Co-authored-by: Jokeren <robinho364@gmail.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136514
Approved by: https://github.com/jansel

Co-authored-by: Jokeren <robinho364@gmail.com>
2024-09-25 09:34:14 +00:00
Bin Bao
95c0f7493f [Inductor] Rename WrapperCodeGen to PythonWrapperCodegen (#136062)
Summary: Rename WrapperCodeGen to PythonWrapperCodegen to make its meaning more explicit.

Differential Revision: [D63300358](https://our.internmc.facebook.com/intern/diff/D63300358)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136062
Approved by: https://github.com/angelayi, https://github.com/chenyang78
2024-09-24 21:02:51 +00:00
Shangdi Yu
3bc073d728 [aoti] Fix workspace generation for triton (#135552)
Fixes #131337

- add `arg_type` for workspace_arg, the type is consistent with the type in `generate_workspace_allocation()`.
- do not generate example tensors for `workspace`, and use `generate_workspace_allocation()` instead.
- add workspace allocation generation code to `kernel_autotune_calls`. e.g.
```python
    workspace = empty_strided_cuda((1280, ), (1, ), torch.uint8)
    workspace.zero_()
    .....
    triton_spl_fused_add_cumprod_0.run(buf2, arg0_1, arg1_1, workspace, 1, 10000, grid=split_scan_grid(1, 10000), stream=stream0)
    del buf2, arg0_1, arg1_1, workspace
```
-  add `empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda` to the header of triton autotune code.

The generated cpp has lines like below, so we also implement a `zero_()` for ` AtenTensorHandle `.

```cpp
    static constexpr int64_t int_array_0[] = {1280L, };
    static constexpr int64_t int_array_1[] = {1L, };
    AtenTensorHandle workspace_handle;
    AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_empty_strided(1, int_array_0, int_array_1, cached_torch_dtype_uint8, cached_torch_device_type_cuda,  0, &workspace_handle));

        RAIIAtenTensorHandle workspace(workspace_handle);
        workspace.zero_();
```

- Fix handle grid_fn  for grid computation. Pass in "RBLOCK" to `split_scan_grid`
-  Fix dynamic shapes:
Without the fix we generate code that looks like this `workspace = empty_strided_cuda((32*((255 + s0) // 256), ), (1, ), torch.uint8)` when doing triton autotune and `s0` is not defined.

The solution approach is to use `V.graph.sizevars.size_hint(nbytes)` to realize the workspace size for triton autotune. Note that we only realize it for triton autotune code, but not for the cpp cuda code.

- We also generate slightly different cpp code depending on if `abi_compatible` is turned on.
```cpp
RAIIAtenTensorHandle workspace(workspace_handle);
AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_zero_(workspace.get()));
```
vs

```cpp
    at::Tensor workspace = at::detail::empty_strided_cuda({8L*(c10::div_floor_integer(static_cast<int64_t>((255L + s0)), static_cast<int64_t>(256L))), }, {1L, }, at::kByte, c10::DeviceType::CUDA);
    workspace.zero_();
```

Test Plan:

```
TORCHINDUCTOR_ABI_COMPATIBLE=1 TORCHINDUCTOR_CPP_WRAPPER=1  python test/inductor/test_torchinductor.py -k GPUTests.test_consecutive_split_cumprod_cuda
python test/inductor/test_cuda_cpp_wrapper.py TestCudaWrapper.test_consecutive_split_cumprod_cuda_cuda_wrapper
python test/inductor/test_cuda_cpp_wrapper.py DynamicShapesCudaWrapperCudaTests.test_consecutive_split_cumprod_cuda_dynamic_shapes_cuda_wrapper
TORCHINDUCTOR_ABI_COMPATIBLE=1 python test/inductor/test_cuda_cpp_wrapper.py TestCudaWrapper.test_consecutive_split_cumprod_cuda_cuda_wrapper
TORCHINDUCTOR_CPP_WRAPPER=1  python test/inductor/test_torchinductor.py -k GPUTests.test_consecutive_split_cumprod_cuda
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135552
Approved by: https://github.com/desertfire
2024-09-22 04:51:37 +00:00
PyTorch MergeBot
d0cebedb31 Revert "Add Triton CPU as an Inductor backend (#133408)"
This reverts commit e498b02b47.

Reverted https://github.com/pytorch/pytorch/pull/133408 on behalf of https://github.com/jeanschmidt due to Broke internal signals, see D62737208 for more details ([comment](https://github.com/pytorch/pytorch/pull/133408#issuecomment-2353623816))
2024-09-16 18:33:33 +00:00
PyTorch MergeBot
0199fd4d7e Revert "[inductor] More fixes on the keys of constants and signature dictionaries (#135406)"
This reverts commit e54b559e88.

Reverted https://github.com/pytorch/pytorch/pull/135406 on behalf of https://github.com/jeanschmidt due to Reverting as it is breaking triton_mtia internal signals @jansel could you have a look and help get those changes merged? ([comment](https://github.com/pytorch/pytorch/pull/135406#issuecomment-2353557481))
2024-09-16 17:58:02 +00:00
Jez Ng
e498b02b47 Add Triton CPU as an Inductor backend (#133408)
The goal is to use Inductor-generated kernels to stress test the new Triton CPU backend.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133408
Approved by: https://github.com/jansel
2024-09-14 21:45:19 +00:00
PyTorch MergeBot
18f9331e5d Revert "[aoti] Fix workspace generation for triton (#135552)"
This reverts commit d383325392.

Reverted https://github.com/pytorch/pytorch/pull/135552 on behalf of https://github.com/izaitsevfb due to blocks revert of #135313, internal failures, see D62511427 ([comment](https://github.com/pytorch/pytorch/pull/135552#issuecomment-2349641372))
2024-09-13 17:47:36 +00:00
Rachel Guo
7834c0bb2c [AOTI][Tooling] Add stats summary (mean/min/max, etc) for jit inductor tensor value printing (#135887)
Summary:
As title. Follow up to add stats summary (mean/min/max, etc) for jit inductor tensor value printing as well.

The inductor python wrapper code level printing would look something like this:

 {F1859224287}

Test Plan: CI

Reviewed By: chenyang78

Differential Revision: D62415575

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135887
Approved by: https://github.com/chenyang78
2024-09-13 17:19:25 +00:00
Jokeren
e54b559e88 [inductor] More fixes on the keys of constants and signature dictionaries (#135406)
Previous PR forgets to change two other places that also create `constants` and `signature`. https://github.com/pytorch/pytorch/pull/135170

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135406
Approved by: https://github.com/jansel
2024-09-13 04:10:41 +00:00
Shangdi Yu
d383325392 [aoti] Fix workspace generation for triton (#135552)
Fixes #131337

- add `arg_type` for workspace_arg, the type is consistent with the type in `generate_workspace_allocation()`.
- do not generate example tensors for `workspace`, and use `generate_workspace_allocation()` instead.
- add workspace allocation generation code to `kernel_autotune_calls`. e.g.
```python
    workspace = empty_strided_cuda((1280, ), (1, ), torch.uint8)
    workspace.zero_()
    .....
    triton_spl_fused_add_cumprod_0.run(buf2, arg0_1, arg1_1, workspace, 1, 10000, grid=split_scan_grid(1, 10000), stream=stream0)
    del buf2, arg0_1, arg1_1, workspace
```
-  add `empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda` to the header of triton autotune code.

The generated cpp has lines like below, so we also implement a `zero_()` for ` AtenTensorHandle `.

```cpp
    static constexpr int64_t int_array_0[] = {1280L, };
    static constexpr int64_t int_array_1[] = {1L, };
    AtenTensorHandle workspace_handle;
    AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_empty_strided(1, int_array_0, int_array_1, cached_torch_dtype_uint8, cached_torch_device_type_cuda,  0, &workspace_handle));

        RAIIAtenTensorHandle workspace(workspace_handle);
        workspace.zero_();
```

- Fix handle grid_fn  for grid computation. Pass in "RBLOCK" to `split_scan_grid`
-  Fix dynamic shapes:
Without the fix we generate code that looks like this `workspace = empty_strided_cuda((32*((255 + s0) // 256), ), (1, ), torch.uint8)` when doing triton autotune and `s0` is not defined.

The solution approach is to use `V.graph.sizevars.size_hint(nbytes)` to realize the workspace size for triton autotune. Note that we only realize it for triton autotune code, but not for the cpp cuda code.

- We also generate slightly different cpp code depending on if `abi_compatible` is turned on.
```cpp
RAIIAtenTensorHandle workspace(workspace_handle);
AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_zero_(workspace.get()));
```
vs

```cpp
    at::Tensor workspace = at::detail::empty_strided_cuda({8L*(c10::div_floor_integer(static_cast<int64_t>((255L + s0)), static_cast<int64_t>(256L))), }, {1L, }, at::kByte, c10::DeviceType::CUDA);
    workspace.zero_();
```

Test Plan:

```
TORCHINDUCTOR_ABI_COMPATIBLE=1 TORCHINDUCTOR_CPP_WRAPPER=1  python test/inductor/test_torchinductor.py -k GPUTests.test_consecutive_split_cumprod_cuda
python test/inductor/test_cuda_cpp_wrapper.py TestCudaWrapper.test_consecutive_split_cumprod_cuda_cuda_wrapper
python test/inductor/test_cuda_cpp_wrapper.py DynamicShapesCudaWrapperCudaTests.test_consecutive_split_cumprod_cuda_dynamic_shapes_cuda_wrapper
TORCHINDUCTOR_ABI_COMPATIBLE=1 python test/inductor/test_cuda_cpp_wrapper.py TestCudaWrapper.test_consecutive_split_cumprod_cuda_cuda_wrapper
TORCHINDUCTOR_CPP_WRAPPER=1  python test/inductor/test_torchinductor.py -k GPUTests.test_consecutive_split_cumprod_cuda
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135552
Approved by: https://github.com/desertfire
2024-09-12 23:53:09 +00:00
Rachel Guo
c1277945d3 [AOTI][Tooling] Support debug printing for inductor level extern kernel call such as externkernel.addmm, bmm, etc. (#135731)
Summary:
As title.

Effect after merging this diff would look something like this:

```
        print('inductor: before_launch - triton_poi_fused_0 - buf0', buf0)
        triton_poi_fused_0.run(buf0, 6, grid=grid(6), stream=stream0)
        print('inductor: after_launch - triton_poi_fused_0 - buf0', buf0)
        buf1 = empty_strided_cuda((16, 6), (6, 1), torch.float32)
        # Topologically Sorted Source Nodes: [linear], Original ATen: [aten.addmm]
        print('inductor: before_launch - extern_kernels.addmm - buf0', buf0)
        extern_kernels.addmm(buf0, reinterpret_tensor(arg2_1, (16, 16), (16, 1), 0), reinterpret_tensor(L__self___weight, (16, 6), (1, 16), 0), alpha=1, beta=1, out=buf1)
        print('inductor: after_launch - extern_kernels.addmm - buf0', buf0)
```

Context: D62272588 only support major triton kernel jit inductor debug printing codegen

Test Plan: CI & OSS CI

Reviewed By: chenyang78, ColinPeppler

Differential Revision: D62397017

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135731
Approved by: https://github.com/ColinPeppler
2024-09-12 17:31:10 +00:00
xinan.lin
13ee85ca5e [Inductor] Generalize cuda cpp wrapper as common triton based GPU cpp wrapper, will be reused by xpu in next PR. (#135312)
[Inductor] Generalize cuda cpp wrapper as common triton based GPU cpp wrapper, will be reused by xpu in next PR.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135312
Approved by: https://github.com/jansel, https://github.com/desertfire, https://github.com/eellison
2024-09-11 23:59:54 +00:00
Rachel Guo
1f15973657 [AOTI][Tooling][7/n] Add debug printing support for JIT inductor codegen path as well (#135285)
Summary:
1.  Add the debug printer call to a level lower for triton kernel python wrapper codegen path
2. Add `torch.save()` for jit inductor as well
3. This also fixes the issue introduced in D61949020 (at python wrapper code level for triton kernel not printing)

Test Plan:
```
AOT_INDUCTOR_DEBUG_INTERMEDIATE_VALUE_PRINTER=1  TORCHINDUCTOR_FORCE_DISABLE_CACHES=1 TORCHINDUCTOR_ABI_COMPATIBLE=1 TORCH_COMPILE_DEBUG=1 TORCH_LOGS="+graph, inductor, +schedule, output_code" buck2 run -c fbcode.enable_gpu_sections=true -c fbcode.nvcc_arch=h100 @//mode/opt fbcode//caffe2/test/inductor:test_aot_inductor -- -r test_addmm_abi_compatible_cuda
```

Differential Revision: D62272588

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135285
Approved by: https://github.com/chenyang78
2024-09-10 19:24:58 +00:00
xinan.lin
ca16956b20 [Inductor] Generalize device guard codegen for cpp_wrapper mode. (#134761)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134761
Approved by: https://github.com/jansel, https://github.com/EikanWang
ghstack dependencies: #134693
2024-09-10 10:11:52 +00:00
Bin Bao
9c6dff4941 [AOTI] Add C shim for aten.mkldnn_rnn_layer in cpp wrapper (#134857)
Summary: Support aten.mkldnn_rnn_layer in the ABI-compatible mode. Because aten.mkldnn_rnn_layer is an aten op, it is easier to add a C shim function for it.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134857
Approved by: https://github.com/angelayi
2024-09-09 16:54:12 +00:00
Henry Tsang
3988b3468b [aoti][easy] remove breakpoint() in wrapper.py (#134807)
Differential Revision: D61687146

Remove an unintended breakpoint in code.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134807
Approved by: https://github.com/YUNQIUGUO
2024-09-06 17:25:05 +00:00
qchip
f6398eb0fa dynamic shapes for combo_kenel/foreach_kernel (#134477)
This PR add dynamic shapes support to foreach and combo kernels for horizontal fusion.
A flag `combo_kernel_foreach_dynamic_shapes` (default False to avoid disturb production workflows) is added to _inductor/config.py. Setting it to True enables automatic dynamic shapes for foreach kernels. It is always enabled for combo kernels cases. Added unit cases.

This PR also fixes a flaky test case for [T198833257](https://www.internalfb.com/intern/tasks/?t=198833257)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134477
Approved by: https://github.com/mlazos
2024-08-30 19:58:20 +00:00
Yueming Hao
4f9c68454a [inductor]Let output or input_as_strided match exact strides (#130956)
Fixes #130394

TorchInductor doesn't respect original strides of outputs. It opens up optimization opportunities like changing up memory layout. But for some cases, such as the case in https://github.com/pytorch/pytorch/issues/130394, we do need the output match the exact stride as required. The correctness is the first priority goal. So, this PR adds a new API `ir.ExternKernel.require_exact_strides(x, exact_strides, allow_padding=False)` to fix the issue.  This PR enables dense and non-dense outputs' strides follow the strides required by semantics.

The comparison between the original and after this fix for the test is the below.

```python
@triton.jit
def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 128
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex % 8
    x1 = (xindex // 8)
-   x2 = xindex
    tmp0 = tl.load(in_ptr0 + (x0 + (16*x1)), xmask)
    tmp1 = tmp0 + tmp0
-   tl.store(out_ptr0 + (x2), tmp1, xmask)
+   tl.store(out_ptr0 + (x0 + (16*x1)), tmp1, xmask)

def call(args):
    arg0_1, = args
    args.clear()
    assert_size_stride(arg0_1, (16, 8), (16, 1))
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
-       buf1 = empty_strided_cuda((16, 8), (8, 1), torch.float32)
+       buf1 = empty_strided_cuda((16, 8), (16, 1), torch.float32)
        stream0 = get_raw_stream(0)
        triton_poi_fused_add_copy_0.run(arg0_1, buf1, 128, grid=grid(128), stream=stream0)
        del arg0_1
    return (buf1, )
```

The buf1 is created with exact stride required by users, and its values are written in same stride with the input.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130956
Approved by: https://github.com/eellison, https://github.com/blaine-rister, https://github.com/desertfire
2024-08-29 03:06:58 +00:00
Rachel Guo
89929d9abc [AOTI][Tooling][4/n] Add torch.save() for individual intermediate tensor (#133871)
Differential Revision: D61415304

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133871
Approved by: https://github.com/ColinPeppler
2024-08-28 04:48:00 +00:00
Bin Bao
a4b44dd2ef [AOTI] Introduce DeferredCudaGridLine for cuda cpp wrapper (#129268)
Summary: Similar to https://github.com/pytorch/pytorch/pull/129135, use DeferredCudaGridLine to create a deferred grid computation line when generating cpp wrapper.

Differential Revision: [D61800622](https://our.internmc.facebook.com/intern/diff/D61800622)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/129268
Approved by: https://github.com/angelayi
2024-08-27 19:23:25 +00:00
Jason Ansel
2c8fc3f4ce [inductor] Move imports to top of file in generated code (#134195)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134195
Approved by: https://github.com/eellison
ghstack dependencies: #134194
2024-08-24 00:35:57 +00:00
Mwiza Kunda
be207af6e1 Disable unwrapping scalar tensors when used as outputs (#132859)
If the scalar tensor is an output tensor, it shouldn't be unwrapped (i.e. `.item()` called) since `tl.store` requires a pointer type for outputs. This issue only occurs for mutated buffers: the input tensor is also used as an output tensor.

Fixes #ISSUE_NUMBER

@yanboliang @jansel @ngimel

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132859
Approved by: https://github.com/jansel
2024-08-16 21:40:45 +00:00
Max Podkorytov
3d45717219 [ROCm][CK][Inductor] enable dynamic shapes for CK backend to gemm max autotune (#133285)
This PR enables dynamic shapes for the CK backend for gemm max autotune (see #125453).

This is achieved via unhardcoding the problem sizes from the template body and passing them as parameters instead.

We handle passing the problem sizes for the kernel call as well as for the benchmark call.

# Testing

`pytest test/inductor/test_ck_backend.py [-k dynamic]`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133285
Approved by: https://github.com/ColinPeppler
2024-08-16 06:05:23 +00:00
Xuehai Pan
758a0a88a2 [BE][Easy] enable ruff rule PIE790: unnecessary pass statement (#133200)
This PR removes unnecessary `pass` statement. This is semanticly safe because the bytecode for the Python code does not change.

Note that if there is a docstring in the function, a empty function does not need a `pass` statement as placeholder.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133200
Approved by: https://github.com/malfet, https://github.com/eqy, https://github.com/kit1980
2024-08-15 15:50:19 +00:00
Rachel Guo
c17d26c3c1 [AOTI][Tooling] A couple fixes / minor updates for initial debug printer (#133016)
Summary:
Follow up small diff to fix a couple issues:
-  add condition for cuda/gpu case to only print kernel name list in the second pass i.e. when we do the cpp wrapper codegen

- other minor fixes around `AOT_INDUCTOR_FILTERED_KERNELS_TO_PRINT` option

Test Plan:
```
AOT_INDUCTOR_FILTERED_KERNELS_TO_PRINT="triton_poi_fused_0" AOT_INDUCTOR_DEBUG_INTERMEDIATE_VALUE_PRINTER=1 TORCHINDUCTOR_FORCE_DISABLE_CACHES=1  TORCHINDUCTOR_ABI_COMPATIBLE=1 TORCH_COMPILE_DEBUG=1 TORCH_LOGS="+graph, inductor, +schedule, output_code" buck2 run -c fbcode.enable_gpu_sections=true -c fbcode.nvcc_arch=h100 @//mode/opt fbcode//caffe2/test/inductor:test_aot_inductor -- -r test_addmm_abi_compatible_cuda
```

Differential Revision: D60954888

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133016
Approved by: https://github.com/ColinPeppler
2024-08-13 23:00:29 +00:00