Commit Graph

1801 Commits

Author SHA1 Message Date
Jason Ansel
e90cf4abcf [inductor] Add some typing to common.py (#145691)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145691
Approved by: https://github.com/malfet
ghstack dependencies: #145690
2025-01-27 06:27:13 +00:00
Jason Ansel
ddae87f792 [inductor] Add some typing to simd.py (#145690)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145690
Approved by: https://github.com/malfet
2025-01-27 06:27:13 +00:00
Nikita Shulga
71caac2b30 [MPSInductor] Add rand support (#145705)
Using Philox4 as PRNG
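
For context, here is a minimal Python sketch of the Philox4x32-10 counter-based PRNG from Salmon et al.; the Metal implementation in this PR may differ in detail (key-schedule placement, output mapping):
```python
MASK32 = 0xFFFFFFFF

def mulhilo(a, b):
    # 32x32-bit multiply, split into high/low 32-bit halves
    p = a * b
    return (p >> 32) & MASK32, p & MASK32

def philox4x32(ctr, key, rounds=10):
    c0, c1, c2, c3 = ctr
    k0, k1 = key
    for _ in range(rounds):
        hi0, lo0 = mulhilo(0xD2511F53, c0)
        hi1, lo1 = mulhilo(0xCD9E8D57, c2)
        c0, c1, c2, c3 = hi1 ^ c1 ^ k0, lo1, hi0 ^ c3 ^ k1, lo0
        k0 = (k0 + 0x9E3779B9) & MASK32  # Weyl-sequence key schedule
        k1 = (k1 + 0xBB67AE85) & MASK32
    return c0, c1, c2, c3

# four uniform samples in [0, 1) per (counter, key) pair
print([v / 2**32 for v in philox4x32((0, 0, 0, 1), (42, 0))])
```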

Test plan (other than CI)
Run
```python
import torch
from torch._inductor.utils import run_and_get_code
from contextlib import nullcontext

def foo(x):
   return x * torch.randn_like(x)

foo_c = torch.compile(foo)

x = torch.ones(100, 100, device="mps")

y = foo_c(x)

print(y.mean().item(), y.std().item())
for i in range(25):
  print(y[i].mean(), y[i].std())
```
And observe that the printed values are close to 0 and 1.

TODO: Better `randint` algorithm for large ranges

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145705
Approved by: https://github.com/dcci, https://github.com/jansel
2025-01-27 06:07:36 +00:00
Jason Ansel
9007eb5f8e [inductor] Kernel memory analysis for use in heuristics (#142026)
This computes statistics about each kernel's memory usage that should allow us to write more precise heuristics.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/142026
Approved by: https://github.com/eellison
2025-01-25 04:58:54 +00:00
Aaron Gokaslan
f3304571fc [BE][Ez]: FURB148 - remove useless enumerate calls (#145619)
Remove useless enumerate calls
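
For illustration, the FURB148 pattern looks like this (hypothetical example, not taken from the PR):
```python
items = ["a", "b", "c"]

# Before: the index from enumerate is never used
for _i, item in enumerate(items):
    print(item)

# After: iterate directly
for item in items:
    print(item)
```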

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145619
Approved by: https://github.com/drisspg
2025-01-24 23:37:15 +00:00
Davide Italiano
57591edca1 [mps/inductor] Add support for erfinv. (#145643)
After several rounds of refactoring, this seems to be done now.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145643
Approved by: https://github.com/malfet, https://github.com/jansel
2025-01-24 22:55:44 +00:00
c8ef
a989a0b13a [NFC] Fix some minor typos. (#145599)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145599
Approved by: https://github.com/Skylion007
2025-01-24 18:58:59 +00:00
Bin Bao
b8087747f5 [inductor][BE] Enable test_cpu_cpp_wrapper in fbcode (#145373)
Differential Revision: D68278174

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145373
Approved by: https://github.com/Skylion007
2025-01-24 17:59:13 +00:00
David Peixotto
97c0b7cb0a Add unique identifer to bmm thread_mm functions (#145303)
Summary:
The bmm template generates code like this

```
template<bool accum>
void cpp_fused_bmm_66_micro_gemm(...) {
    ...
}

void single_thread_mm() {
    ...
    cpp_fused_bmm_66_micro_gemm(...)
    ...
}

void threaded_mm() {
    ...
    cpp_fused_bmm_66_micro_gemm(...)
    ...
}

void cpp_fused_bmm_66(...)
{
    ...
    single_thread_mm(...);
    ...
    threaded_mm(...);
    ...
}
```

The generated `fused_bmm` and `fused_bmm_microgemm` functions both have unique identifiers added to their names, but `single_thread_mm` and `threaded_mm` do not.

This diff adds unique identifiers to those generated functions as well. The identifier is based on the kernel name, so for the example above we would generate a name like `cpp_fused_bmm_66_single_thread_mm()`.
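
A minimal sketch of the naming scheme (illustrative only; inductor's actual codegen plumbing differs):
```python
def helper_name(kernel_name: str, helper: str) -> str:
    # e.g. ("cpp_fused_bmm_66", "single_thread_mm")
    #   -> "cpp_fused_bmm_66_single_thread_mm"
    return f"{kernel_name}_{helper}"

assert helper_name("cpp_fused_bmm_66", "single_thread_mm") == "cpp_fused_bmm_66_single_thread_mm"
```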

Differential Revision: D68364772

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145303
Approved by: https://github.com/leslie-fang-intel, https://github.com/frost-intel, https://github.com/hl475
2025-01-24 17:35:50 +00:00
PyTorch MergeBot
9d6927715f Revert "Fix triton masked loading for non-block tl.loads (#144782)"
This reverts commit 31c2f36989.

Reverted https://github.com/pytorch/pytorch/pull/144782 on behalf of https://github.com/ezyang due to This regresses compile time for one of our internal models by 20%, internal xref https://fb.workplace.com/groups/1075192433118967/posts/1591490218155850 ([comment](https://github.com/pytorch/pytorch/pull/144782#issuecomment-2612660287))
2025-01-24 14:28:48 +00:00
Benjamin Glass
d5629889f1 cpp_wrapper: Properly handle scalars when input to tensor arguments (#144910)
Additionally, reduce code duplication in `cpp_wrapper_cpu_array_ref.py`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144910
Approved by: https://github.com/desertfire
2025-01-24 02:06:35 +00:00
David Berard
b2c89bc115 [inductor][2/N] triton support post-#5512, user-defined triton kernels (#145348)
Triton commit 5220 adds tuple support in Triton (changing the indexing format in AttrsDescriptor) and commit 5512 replaces AttrsDescriptor with raw tuples. This PR fixes user-defined triton kernel handling (in most cases) for these new triton commits.

What this PR fixes:
* in triton_kernel_wrap.py, AST->TTIR parsing had to be updated for the new triton API
* ir.py - don't remove None args when using newer triton versions
* wrapper.py - update signature & constant handling

What this doesn't fix:
* correct None handling - I want to take a closer look at constant handling (including None, equal_to_1, and other constants).
* cpp wrapper (which needs to be fixed for both user-defined triton kernels and inductor-generated kernels)

test/inductor/test_triton_kernels.py passed on triton commit 74de6b46, with the exception of three tests (those shown here: 1374074098)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145348
Approved by: https://github.com/jansel
ghstack dependencies: #145051
2025-01-24 00:34:01 +00:00
David Berard
b963ab5325 [inductor][1/N] triton support post-#5512, main components (#145051)
Triton commit 5220 adds tuple support in Triton (changing the indexing format in AttrsDescriptor) and commit 5512 replaces AttrsDescriptor with raw tuples. This is an initial PR to add support for Triton versions after commit 5512 landed.

The main changes in 5220 and 5512 that need to be supported:
* AttrsDescriptor() gets replaced with a raw dict. The raw dict has the format `{(TUPLES): [["tt.divisibility", 16]]}`, where `(TUPLES)` is a tuple of indices, e.g. `((0,), (1,), (3,))`, to indicate that args 0, 1, and 3 are divisible by 16. These indices are themselves represented as tuples to support nested inputs (e.g. an argument that's a tuple), but support for tuples is not implemented right now. See the sketch after this list.
* "signature" changes: the signature now contains _all_ args, including constexpr and constant args.
* ASTSource now takes "constexprs" instead of "constants" - for example, equal-to-1 args are constants but not constexprs so we don't need to pass these args as "constants".
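
As a concrete illustration of the raw-dict format from the first bullet above (literal values taken from the example in the text):
```python
# post-#5512 format: a raw dict instead of AttrsDescriptor. Keys are tuples
# of index paths; each path is itself a tuple so that elements of
# tuple-valued args could be addressed later (e.g. (3, 0)).
attrs = {((0,), (1,), (3,)): [["tt.divisibility", 16]]}

# recover the flat arg indices annotated as divisible by 16
divisible_by_16 = [
    path[0]
    for paths, hints in attrs.items()
    if ["tt.divisibility", 16] in hints
    for path in paths
]
assert divisible_by_16 == [0, 1, 3]
```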

What this PR supports:
* Triton versions before Dec 9, 2024, and (partial support for) Triton versions after Jan 1, 2025
* (triton jan 1+) typical inductor-generated triton: updated AttrsDescriptor, signatures, constexpr/constant handling.

What this PR doesn't support (TODO in follow-up PRs):
* Triton versions between Dec 9, 2024 and before Jan 1, 2025
* (triton jan 1+) user-defined triton kernel support (this is implemented already in @anmyachev's patch)
* (triton jan 1+) triton_helper support (failing in triton codegen - needs investigation)
* (triton jan 1+) AOTI / cpp wrapper

thanks to @anmyachev for the patch at https://github.com/intel/intel-xpu-backend-for-triton/blob/main/scripts/pytorch.patch, which contains most of these changes already

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145051
Approved by: https://github.com/jansel
2025-01-24 00:34:01 +00:00
PyTorch MergeBot
ce4a097bf7 Revert "Added swizzle searching, disabled fp16 accum, and enabled ping-pong for cutlass (#144829)"
This reverts commit 55084443ca.

Reverted https://github.com/pytorch/pytorch/pull/144829 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/144829#issuecomment-2610855579))
2025-01-23 19:37:54 +00:00
Shunting Zhang
d3f196909d [inductor] let inplace-padding support cpp-wrapper (#145325)
Some context: inplace padding is an optimization to do padding in place. E.g., suppose a tensor has size [2048, 2047] and stride [2048, 1]. When we need to pad one extra element to the end of each row (e.g., during mm padding), we can just reuse the original tensor and do the padding in place. This saves memory and bandwidth. One caveat for this optimization: PyTorch does not allocate 2048 elements for the last row of the original tensor; it only allocates 2047. So assuming the last row has enough space for 2048 elements may be wrong and cause out-of-bounds memory access (although I have never seen this happen, maybe due to over-allocation in the CUDACachingAllocator, it should be fixed).

The fix is, when we allocate the tensor, instead of doing something like:
```
  buf0 = randn_strided([2048, 2047], [2048, 1])
```
we over-allocate slightly:
```
  buf0 = randn_strided([2048, 2048], [2048, 1]).as_strided([2048, 2047], [2048, 1])
```
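
To see why the original allocation is one element short, here is the standard minimum-storage arithmetic for a strided tensor (1 + sum of (size_i - 1) * stride_i):
```python
def storage_elems(size, stride):
    # minimum storage a strided tensor needs: last reachable index + 1
    return 1 + sum((s - 1) * st for s, st in zip(size, stride))

print(storage_elems([2048, 2047], [2048, 1]))  # 4194303
print(storage_elems([2048, 2048], [2048, 1]))  # 4194304
# The [2048, 2047] tensor's last row ends exactly at the end of storage,
# so writing a 2048th element into it would be out of bounds by one.
```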

cpp_wrapper needs special handling since its memory allocation goes through a different code path than the Python wrapper.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145325
Approved by: https://github.com/desertfire, https://github.com/jansel
ghstack dependencies: #140249
2025-01-23 09:22:38 +00:00
Nikita Shulga
70ccbade83 [MPSInductor] Add gamma op (#145341)
By moving `gamma` and `log_gamma` implementation from `Gamma.metal` to `c10/metal/special_math.h`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145341
Approved by: https://github.com/Skylion007, https://github.com/dcci
ghstack dependencies: #145309
2025-01-22 19:37:45 +00:00
Isuru Fernando
31c2f36989 Fix triton masked loading for non-block tl.loads (#144782)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144782
Approved by: https://github.com/eellison
2025-01-22 14:30:56 +00:00
Shunting Zhang
3a58512613 [Inductor] inplace padding (#140249)
https://github.com/pytorch/pytorch/issues/139865

This PR may change the semantics of constant_pad_nd from 'clone' to 'view'. I tried a few tests doing in-place updates; it looks like, thanks to functionalization, this works fine.

Perf for `test_linear_and_cel`:
```
# TORCHINDUCTOR_INPLACE_PADDING=0 DO_PERF_TEST=1 python test/inductor/test_inplace_padding.py -k test_linear_and_cel
inductor_config.inplace_padding=False ms=83.311

# TORCHINDUCTOR_INPLACE_PADDING=1 DO_PERF_TEST=1 python test/inductor/test_inplace_padding.py -k test_linear_and_cel
inductor_config.inplace_padding=True ms=79.827
```

The saving is about 4ms (slightly less, since we need to fill 0 for the padding area). Similar savings for llm.c:
- Without the feature: 182.151ms per batch, 180.9K tokens/s
- With the feature: 178.278ms per batch, 183.9K tokens/s, a 3K tokens/s increase

The perf test shows a compilation time regression. I'm not sure if that's real; will debug more. But a good thing is, there is no accuracy failure: [link](https://hud.pytorch.org/benchmark/compilers?dashboard=torchinductor&startTime=Mon%2C%2004%20Nov%202024%2020%3A23%3A22%20GMT&stopTime=Mon%2C%2011%20Nov%202024%2020%3A23%3A22%20GMT&granularity=hour&suite=torchbench&mode=training&dtype=amp&deviceName=cuda%20(a100)&lBranch=gh/shunting314/186/head&lCommit=03fd924ff382958daf5055dc8425d279e4e10a1e&rBranch=main&rCommit=c03324de2dfbbf0006818c86b88c92a3378f46b7).

UPDATE: the perf test regression seems not to be real. Here is a rerun: [link](https://hud.pytorch.org/benchmark/compilers?dashboard=torchinductor&startTime=Thu%2C%2007%20Nov%202024%2001%3A29%3A55%20GMT&stopTime=Thu%2C%2021%20Nov%202024%2001%3A29%3A55%20GMT&granularity=hour&suite=torchbench&mode=training&dtype=amp&deviceName=cuda%20(a100)&lBranch=gh/shunting314/186/head&lCommit=7e2c8e5d9256ac06205e7cd5e740c9e20ce804d0&rBranch=main&rCommit=565a7942eee1ddc23067cdbae597443d0f2290a0). Our dashboard has not been that reliable recently due to the AWS migration.

Differential Revision: [D68340248](https://our.internmc.facebook.com/intern/diff/D68340248)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/140249
Approved by: https://github.com/jansel, https://github.com/eellison
2025-01-22 03:37:06 +00:00
sanchitintel
46851022ff [Inductor][CPU] Add auto-tuning support for da8w8 sym act sym wgt GEMM (#143187)
## Summary

Templated `int8xint8->int32` GEMM that uses the AMX ISA (present on Intel Xeon Gen 4 & above). Any epilogues such as weight scale, activation scale, and bias are applied per output block in a fused manner.
Performs well for large values of the `M` dimension (assuming canonical dimensions [`M, K`] and [`K, N`] for the activation & weight matrices'/tensors' sizes) when the activation is quantized per-token.
Also supports the SmoothQuant GEMM pattern when the activation is quantized per-tensor (scalar scale) or per-token (the vector scale is applied as an epilogue in this case).

Also increased coverage of the GEMM template to uint8-activation, int8-weight GEMM UTs where the activation zero point is a 1D tensor (the existing implementation only accepted 0D tensors). However, some such UTs have to be explicitly enabled with the `max-autotune` Inductor config.
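
For reference, a plain-PyTorch sketch of the computation being templated (not the AMX micro-kernel itself; in the actual template the epilogues are fused per output block):
```python
import torch

def da8w8_gemm_ref(act_i8, wgt_i8, act_scale, wgt_scale, bias=None):
    # int8 x int8 -> int32 accumulation
    acc = act_i8.to(torch.int32) @ wgt_i8.to(torch.int32)
    # epilogues: activation scale (scalar for per-tensor, [M, 1] for
    # per-token), per-channel weight scale ([N]), optional bias
    out = acc.to(torch.float32) * act_scale * wgt_scale
    return out + bias if bias is not None else out

a = torch.randint(-128, 128, (32, 4096), dtype=torch.int8)
w = torch.randint(-128, 128, (4096, 14336), dtype=torch.int8)
y = da8w8_gemm_ref(a, w, torch.rand(32, 1), torch.rand(14336))
```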

## Performance data

The templated, codegened fused GEMM with M=32, K=4096, N=14336 (as used in LLaMA3) exhibits a more than 2x perf gain compared to oneDNN qlinear + mul (for the activation's scale) with 48 cores of one socket of a 4th gen Xeon SP Platinum 8468 when per-token quantization is used.

For M=1, K=4096, N=14336, the perf gain was more than 3x regardless of whether per-tensor or per-token quantization was used for the activation.

Intel OpenMP & libtcmalloc had been preloaded. All cores used by the workload corresponded to distinct physical cores.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143187
Approved by: https://github.com/jansel, https://github.com/leslie-fang-intel, https://github.com/jgong5

Co-authored-by: Leslie Fang <leslie.fang@intel.com>
2025-01-22 02:27:53 +00:00
Nikita Shulga
980c75fe6e [MPSInductor] Add TrueDiv and Round[Int|Decimal] (#145160)
That fixes `test_builtins_round_float_ndigits_neg` and `test_builtins_round`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145160
Approved by: https://github.com/jansel, https://github.com/dcci
2025-01-20 04:29:42 +00:00
Davide Italiano
8cc415774f [mps/inductor] Introduce a metal approx for erf() and use it. (#145161)
Probably we can do better, but this is a start.
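
The PR doesn't spell out which approximation is used; a classic candidate for a fast erf is the Abramowitz & Stegun 7.1.26 polynomial (max error ~1.5e-7), sketched here in Python:
```python
import math

def erf_approx(x: float) -> float:
    # Abramowitz & Stegun 7.1.26; valid for x >= 0, extended by odd symmetry
    sign = -1.0 if x < 0 else 1.0
    x = abs(x)
    t = 1.0 / (1.0 + 0.3275911 * x)
    poly = t * (0.254829592 + t * (-0.284496736 + t * (1.421413741
           + t * (-1.453152027 + t * 1.061405429))))
    return sign * (1.0 - poly * math.exp(-x * x))

print(erf_approx(1.0), math.erf(1.0))  # agree to ~1e-7
```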

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145161
Approved by: https://github.com/malfet
2025-01-19 02:29:05 +00:00
Nikita Shulga
cede43e06b [MPSInductor][BE] NaN-propagating min/max to header (#145157)
Maybe it can later be reused from the eager op as well.

Also, I didn't know that Metal already has type_traits.
And use `metal::isunordered(a, b)` instead of `metal::isnan(a + b)`, as it is defined as a function equivalent to `a != a || b != b`, though I suspect it might have a better native implementation for the specific architecture.
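
For clarity, the NaN-propagating semantics the header implements, sketched in Python (the actual code is templated Metal):
```python
import math

def nan_propagating_min(a: float, b: float) -> float:
    # isunordered(a, b) <=> a != a or b != b, i.e. either operand is NaN
    if math.isnan(a) or math.isnan(b):
        return math.nan
    return min(a, b)

print(nan_propagating_min(1.0, math.nan))  # nan, unlike min(1.0, nan) -> 1.0
```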

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145157
Approved by: https://github.com/dcci
2025-01-18 22:52:44 +00:00
Nikita Shulga
8a57234033 [MPSInductor] Implement i0 and i1 ops (#145092)
Using shared definitions with eager op

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145092
Approved by: https://github.com/dcci, https://github.com/jansel
ghstack dependencies: #145023, #145087
2025-01-18 15:41:02 +00:00
Aaron Orenstein
2bf772d1ba PEP585 update - torch/_inductor/codegen (#145106)
See #145101 for details.
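
For reference, the PEP 585 change swaps `typing` generics for builtin ones:
```python
# Before (pre-PEP 585):
from typing import Dict, List

def group(xs: List[str]) -> Dict[str, List[str]]: ...

# After (PEP 585, Python 3.9+ or with `from __future__ import annotations`):
def group(xs: list[str]) -> dict[str, list[str]]: ...
```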

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145106
Approved by: https://github.com/bobrenjc93
2025-01-18 06:56:03 +00:00
Sam Larsen
55084443ca Added swizzle searching, disabled fp16 accum, and enabled ping-pong for cutlass (#144829)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144829
Approved by: https://github.com/Chillee
2025-01-18 02:39:22 +00:00
Bin Bao
0b151f260f [AOTI] Add an option to skip optimizing generated wrapper code (#144866)
Summary: In some cases, the generated wrapper code faces long C++ compilation times. As a mitigation, this PR adds an option to skip C++ compiler optimizations for the generated main wrapper function body.

Differential Revision: D68174038

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144866
Approved by: https://github.com/chenyang78, https://github.com/hl475
2025-01-18 01:44:21 +00:00
Jason Ansel
7c1fb9b1ae [inductor] Refactor CachingAutotuner so that it can pickle (#144044)
These are refactors needed for #144288

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144044
Approved by: https://github.com/eellison
2025-01-18 01:44:16 +00:00
PyTorch MergeBot
94c0f15302 Revert "cpp_wrapper: Move #includes to per-device header files (#143909)"
This reverts commit d62b3979da.

Reverted https://github.com/pytorch/pytorch/pull/143909 on behalf of https://github.com/kit1980 due to breaking internal builds because of the removal of torch/_inductor/codegen/aoti_runtime/implementation.cpp ([comment](https://github.com/pytorch/pytorch/pull/143909#issuecomment-2597188669))
2025-01-17 00:36:38 +00:00
Mwiza Kunda
0e6d44df3f Add heuristic to fail block pointer match early (#144681)
This PR adds a heuristic to potentially fail the block pointer match early. Expressions like the one below take a long time to match using sympy (e.g. > 100 seconds):
```python
# torch._inductor.config.triton.use_block_ptr = True
# torch._inductor.config.triton.prefer_nd_tiling = True
# Expression from pytest -k test_max_pool2d1_dynamic_shapes_cuda:
 ((xindex//ps1))*((s2 - 3//2))**2 + 2*((xindex//ps1))*((s2 - 3//2)) + ((xindex//ps1)) + ((s2 - 3//2))*(ModularIndexing(xindex, ps0, ps0)) + (ModularIndexing(xindex, 1, ps0)) + (ModularIndexing(xindex, ps0, ps0))
```
Additionally, the heuristic for the number of dimensions based on the indexing expression is refined to only add dimensions for FloorDiv(index, denom) and ModularIndexing(index, denom, modulo) instead of including FloorDiv/ModularIndexing expressions that don't involve the index.
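
A sketch of the refined dimension-counting rule (assuming inductor's sympy helpers in `torch.utils._sympy.functions`; the real matcher is more involved):
```python
import sympy
from torch.utils._sympy.functions import FloorDiv, ModularIndexing

def dim_candidates(index_expr: sympy.Expr, index: sympy.Symbol):
    # Only FloorDiv/ModularIndexing terms whose arguments involve the index
    # should contribute a tiling dimension; terms built purely from size
    # symbols (e.g. (s2 - 3) // 2) are constants for matching purposes.
    return [
        a
        for a in index_expr.atoms(FloorDiv, ModularIndexing)
        if index in a.free_symbols
    ]
```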

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144681
Approved by: https://github.com/jansel
2025-01-16 21:57:30 +00:00
Nikita Shulga
41ec2e8d3e [MPSInductor] Fix codegen regression (#144924)
Caused by https://github.com/pytorch/pytorch/pull/144649

Do not try to insert anything into the header if the wrapper is not ready yet

Fixes `test_sort_mps`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144924
Approved by: https://github.com/dcci
ghstack dependencies: #144827, #144917
2025-01-16 02:12:42 +00:00
Nikita Shulga
05505771a0 [MPSInductor] Properly convert index (#144917)
By calling `self.index_to_str` from `load`, `store`, and `check_bounds` in order to properly handle sizevar renames

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144917
Approved by: https://github.com/dcci
ghstack dependencies: #144827
2025-01-16 02:12:41 +00:00
Benjamin Glass
d62b3979da cpp_wrapper: Move #includes to per-device header files (#143909)
This prepares us for the next PR in the stack, where we introduce pre-compiled per-device header files to save compilation time.

Differential Revision: [D67938955](https://our.internmc.facebook.com/intern/diff/D67938955)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143909
Approved by: https://github.com/desertfire
2025-01-15 21:14:02 +00:00
Nikita Shulga
904641769e [MPSInductor] Implement pow() (#144827)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144827
Approved by: https://github.com/dcci, https://github.com/jansel
2025-01-15 20:11:34 +00:00
Nikita Shulga
d2ca8163c0 [MPSInductor] Support abs in MetalPrintExpr (#144826)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144826
Approved by: https://github.com/dcci
ghstack dependencies: #144509, #144798, #144795, #144796
2025-01-15 05:01:25 +00:00
Nikita Shulga
e2251fffbb [MPSInductor] Add min/max to MetalExprPrinter (#144798)
After that, `GPUTests::test_avg_pool2d8_mps` and `GPUTests::test_avg_pool2d5_mps` pass

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144798
Approved by: https://github.com/dcci
ghstack dependencies: #144509
2025-01-15 01:43:42 +00:00
Henry Tsang
8c2aa0c533 [cutlass backend] cexpr the arg before writing to cpp file (#144714)
Summary: The problem is that for certain shapes (see the unit test), one of the dimensions is an expression like `s0 // 2`. With the cutlass backend, this gets written verbatim to the C++ file, which leads to a C++ compilation error.
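
The failure mode is easy to see with strings alone; Python's floor-division operator starts a comment in C++ (a hedged illustration, not inductor's actual `cexpr` printer):
```python
# The Python repr of the symbolic dim:
dim = "s0 // 2"
print(f"int64_t d = {dim};")
# -> int64_t d = s0 // 2;
# In C++, `//` starts a comment: the statement becomes `int64_t d = s0` with
# the rest of the line (semicolon included) commented out, so it won't compile.

# A C++-aware printer emits valid syntax instead, e.g. (for nonnegative s0,
# where floor division and C++ truncating division agree):
print(f"int64_t d = {dim.replace('//', '/')};")
# -> int64_t d = s0 / 2;
```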

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144714
Approved by: https://github.com/ColinPeppler, https://github.com/chenyang78, https://github.com/desertfire
2025-01-14 23:09:44 +00:00
Nikhil Gupta
e666807653 [Fix]: Enable support for Arm Neon & SVE support for FP32 Gemm Wrapper (#144327)
**Performance Improvements**:
Linear Layer [ 1x512 * 512x512 ] ->  2x - 4x
Linear Layer [ 3x512 * 512x512 ] -> 2x - 4x

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144327
Approved by: https://github.com/leslie-fang-intel, https://github.com/jgong5, https://github.com/cfRod, https://github.com/malfet

Co-authored-by: Crefeda Rodrigues <crefeda.Rodrigues@arm.com>
2025-01-14 17:52:00 +00:00
leslie-fang-intel
9d98b66e7b [Inductor][CPP] Enable Epilogue Fusion for Grouped GEMM Template (#143897)
**Summary**
In this PR, we enable the epilogues fusion and code generation for Grouped GEMM. Here are the high-level description of how we implement it.

**Fusion**

- The Grouped GEMM Template produces a `Template Buffer` with a `MultiOutputLayout` and a set of `MultiOutput Buffers`, where each buffer corresponds to a specific GEMM.
- During the initial round of fusion, the `Template Buffer` and all associated `MultiOutput Buffers` are fused into a `FusedSchedulerNode` by extending the existing fusion design.
- In subsequent fusion rounds, this `FusedSchedulerNode` can further fuse with its epilogues, following the original fusion design principles.

**Code Gen**
We maintain a list of epilogues and codegen them one by one.

- If any of the GEMMs has a bias, we create an extra `bias_add` epilogue and prepend it to the front of the epilogue list.
- If any of the GEMMs has no epilogue, we create a `to_bf16` copy epilogue and append it to the end of the epilogue list (see the sketch below).
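
A minimal sketch of that normalization (illustrative names, not inductor's actual data structures):
```python
def normalize_epilogues(gemms):
    # gemms: list of objects with `.bias` (tensor or None) and
    # `.epilogues` (list of epilogue node names)
    for gemm in gemms:
        if gemm.bias is not None:
            gemm.epilogues.insert(0, "bias_add")  # prepend bias add
        if not gemm.epilogues:
            gemm.epilogues.append("to_bf16")      # ensure a copy epilogue
    return gemms
```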

**Test Plan**
```
python -u -m pytest -s -v test/inductor/test_cpu_select_algorithm.py -k test_grouped_linear_epilogue
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143897
Approved by: https://github.com/jansel, https://github.com/jgong5
ghstack dependencies: #143796
2025-01-14 06:07:50 +00:00
leslie-fang-intel
25de671ea8 [Inductor][CPP] Enable Grouped GEMM Template (#143796)
**Summary**
Enable the CPP Grouped GEMM Fusion, lowering and Grouped GEMM Template following the RFC: https://github.com/pytorch/pytorch/issues/144012

- Support flexible number of GEMMs
- Share activation across GEMMs
  - The Grouped GEMM Template supports independent activations
  - However, the pattern matcher requires an anchor node, which serves as the shared activation across GEMMs
- Each GEMM can have a unique weight but same sizes
- Each GEMM can have a unique bias or None
  - Current PR does not yet support biases; this will be addressed in a follow-up epilogue fusion PR
- Each GEMM has its own epilogues
  - Epilogue fusion is not yet supported in this PR and will be enabled in an upcoming follow-up epilogue fusion PR

**Test Plan**
```
python -u -m pytest -s -v test/inductor/test_cpu_select_algorithm.py -k test_grouped_linear
python -u -m pytest -s -v test/inductor/test_cpu_select_algorithm.py -k test_grouped_linear_invalid
python -u -m pytest -s -v test/inductor/test_cpu_cpp_wrapper.py -k test_grouped_linear
```

**Example**
Here is an example and the generated code:
```
batch_size = 4
in_features = 512
out_features = 1024
dtype = torch.bfloat16
bias = False  # referenced below when constructing M

class M(torch.nn.Module):
    def __init__(self, bias):
        super().__init__()
        self.linear0 = torch.nn.Linear(in_features, out_features, bias=False)
        self.linear1 = torch.nn.Linear(in_features, out_features, bias=False)

    def forward(self, x):
        return self.linear0(x), self.linear1(x)

if __name__ == "__main__":
    with torch.no_grad():
        input = torch.randn(batch_size, in_features, dtype=dtype)
        m = M(bias=bias).to(dtype=dtype).eval()
        cm = torch.compile(m)
        act_res = cm(input)
```

Generated Code:  https://gist.github.com/leslie-fang-intel/ed2e8d23aeb3586eb504feeace692e16#file-grouped-gemm-generated-code-py

**Next Step**

- Support Epilogue fusion

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143796
Approved by: https://github.com/jgong5, https://github.com/jansel
2025-01-14 05:59:07 +00:00
Davide Italiano
35b46a75f1 [mps/inductor] Add support for round() (#144731)
With this change, inductor/test_view_on_aliased passes.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144731
Approved by: https://github.com/malfet
2025-01-14 05:56:13 +00:00
Davide Italiano
de9d6a25d7 [mps/inductor] Add support for ceil (#144715)
inductor/test_index_dynamic_shapes passes after this change.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144715
Approved by: https://github.com/malfet
2025-01-14 01:16:47 +00:00
Nikita Shulga
c40d917182 [MPSInductor] Fix maximum/minimum for int types (#144665)
`metal::isnan` is only defined for floats, so provide a generic wrapper
that is false for integral types

TODO: Figure out why type propagation is not working (or should it?)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144665
Approved by: https://github.com/dcci
2025-01-13 15:14:01 +00:00
Isuru Fernando
8633845090 Support nanj in inductor (#144064)
Fixes https://github.com/pytorch/pytorch/issues/144029
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144064
Approved by: https://github.com/amjames, https://github.com/eellison
2025-01-13 14:29:38 +00:00
Davide Italiano
417354d953 [mps/inductor] Add support for truncdiv(). (#144666)
Two other inductor tests pass after this change.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144666
Approved by: https://github.com/malfet
2025-01-13 13:39:38 +00:00
Nikita Shulga
7e2239f1f0 [MPSInductor] Better error when kernel fails to compile (#144649)
Now the error message looks as follows:
```
% python ../test/inductor/test_torchinductor.py -v -k test_cat_unbacked_2d_mps
test_cat_unbacked_2d_mps (__main__.GPUTests) ... inline_call []
stats [('calls_captured', 6)]
inductor [('extern_calls', 2), ('fxgraph_cache_miss', 1)]
aot_autograd [('total', 1), ('autograd_cache_bypass', 1), ('not_ok', 1)]
ERROR

======================================================================
ERROR: test_cat_unbacked_2d_mps (__main__.GPUTests)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/Users/malfet/git/pytorch/pytorch/torch/testing/_internal/common_utils.py", line 3126, in wrapper
    method(*args, **kwargs)
  File "/Users/malfet/git/pytorch/pytorch/build/../test/inductor/test_torchinductor.py", line 12254, in new_test
    return value(self)
  File "/Users/malfet/miniconda3/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/Users/malfet/git/pytorch/pytorch/build/../test/inductor/test_torchinductor.py", line 5885, in test_cat_unbacked_2d
    self.common(
  File "/Users/malfet/miniconda3/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/Users/malfet/git/pytorch/pytorch/build/../test/inductor/test_torchinductor.py", line 620, in check_model_gpu
    check_model(
  File "/Users/malfet/git/pytorch/pytorch/build/../test/inductor/test_torchinductor.py", line 461, in check_model
    actual = run(*example_inputs, **kwargs)
  File "/Users/malfet/git/pytorch/pytorch/torch/_dynamo/eval_frame.py", line 580, in _fn
    raise e.remove_dynamo_frames() from None  # see TORCHDYNAMO_VERBOSE=1
  File "/Users/malfet/git/pytorch/pytorch/torch/_inductor/compile_fx.py", line 704, in _compile_fx_inner
    raise InductorError(e, currentframe()).with_traceback(
  File "/Users/malfet/git/pytorch/pytorch/torch/_inductor/compile_fx.py", line 689, in _compile_fx_inner
    mb_compiled_graph = fx_codegen_and_compile(
  File "/Users/malfet/git/pytorch/pytorch/torch/_inductor/compile_fx.py", line 1149, in fx_codegen_and_compile
    return scheme.codegen_and_compile(gm, example_inputs, inputs_to_check, graph_kwargs)
  File "/Users/malfet/git/pytorch/pytorch/torch/_inductor/compile_fx.py", line 1064, in codegen_and_compile
    compiled_fn = graph.compile_to_module().call
  File "/Users/malfet/git/pytorch/pytorch/torch/_inductor/graph.py", line 1977, in compile_to_module
    return self._compile_to_module()
  File "/Users/malfet/git/pytorch/pytorch/torch/_inductor/graph.py", line 2018, in _compile_to_module
    mod = PyCodeCache.load_by_key_path(
  File "/Users/malfet/git/pytorch/pytorch/torch/_inductor/codecache.py", line 2768, in load_by_key_path
    mod = _reload_python_module(key, path)
  File "/Users/malfet/git/pytorch/pytorch/torch/_inductor/runtime/compile_tasks.py", line 51, in _reload_python_module
    exec(code, mod.__dict__, mod.__dict__)
  File "/var/folders/sc/2thx6_x95h7_h9qs8s48yh140000gn/T/tmpmyfz2ju8/lt/cltm34ognlgcc6oxoe6bexvtbwcdtdfgnkjj5miz7vhkemitacp7.py", line 40, in <module>
  File "/var/folders/sc/2thx6_x95h7_h9qs8s48yh140000gn/T/tmpmyfz2ju8/lt/cltm34ognlgcc6oxoe6bexvtbwcdtdfgnkjj5miz7vhkemitacp7.py", line 32, in _compile_mps_shader
torch._inductor.exc.InductorError: SyntaxError: failed to compile
    kernel void generated_kernel(
        device float* out_ptr0,
        constant float* in_ptr0,
        uint xindex [[thread_position_in_grid]]
    ) {
        long x1 = (xindex) / (3);
        auto tmp0 = x1;
        auto tmp1 = static_cast<long>(tmp0);
        auto tmp2 = 0;
        auto tmp3 = tmp1 >= tmp2;
        auto tmp4 = 2;
        auto tmp5 = tmp1 < tmp4;
        long x0 = (xindex) % (3);
        auto tmp6 = in_ptr0[x0 + 3*(x1)];
        auto tmp7 = tmp5 ? tmp6 : 0.0;
        auto tmp8 = tmp1 >= tmp4;
        auto tmp9 = 2 + ks0;
        auto tmp10 = static_cast<long>(tmp9);
        auto tmp11 = tmp1 < tmp10;
        auto tmp12 = 1.0;
        auto tmp13 = tmp8 ? tmp12 : 0.0;
        auto tmp14 = tmp5 ? tmp7 : tmp13;
        long x2 = xindex;
        out_ptr0[x2] = static_cast<float>(tmp14);
    }
 with program_source:18:25: error: use of undeclared identifier 'ks0'
        auto tmp9 = 2 + ks0;
                        ^

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information

You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True

To execute this test, run the following from the base repo dir:
    python test/inductor/test_torchinductor.py GPUTests.test_cat_unbacked_2d_mps

This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0

----------------------------------------------------------------------
Ran 1 test in 0.472s

FAILED (errors=1)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144649
Approved by: https://github.com/Skylion007, https://github.com/jansel, https://github.com/dcci
ghstack dependencies: #144647, #144648
2025-01-13 13:38:03 +00:00
Nikita Shulga
a08bd8154e [MPSInductor] Add support for sizevars (#144662)
Just pass them as kernel arguments

After this change, `pytest test/inductor/test_torchinductor.py -v -k _mps` reports 330 failed, 429 passed (before: 335 failed, 424 passed)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144662
Approved by: https://github.com/jansel
2025-01-13 06:22:38 +00:00
Nikita Shulga
91a65cbd31 [MPSInductor] Implement check_bounds (#144635)
Although at the moment it returns rather than raises an assert, due to https://github.com/pytorch/pytorch/pull/144632

`pytest test/inductor/test_torchinductor.py -v -k _mps` score is `368 failed, 391 passed, 32 skipped`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144635
Approved by: https://github.com/jansel
2025-01-12 21:01:20 +00:00
Nikita Shulga
cec245806e [MPSInductor] Implement bitcasts (#144638)
That will be used to compile something like `torch.rand(32, device='mps').view(dtype=torch.int32)`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144638
Approved by: https://github.com/dcci
2025-01-12 06:11:28 +00:00
Nikita Shulga
32a91dedc5 [MPSInductor] Properly generate index expressions (#144632)
Now test_slice_scatter4_mps passes

Before this change, test_torchinductor.py reported 422 failed and 337 passed; after this change, 412 failed and 347 passed.

Fixes https://github.com/pytorch/pytorch/issues/144630

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144632
Approved by: https://github.com/dcci
2025-01-12 06:10:05 +00:00
Davide Italiano
e0f67405a1 [mps/inductor] Add support for exp(). (#144606)
inductor/test_silu now passes after this change.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144606
Approved by: https://github.com/malfet

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
2025-01-12 00:38:11 +00:00