Commit Graph

137 Commits

Author SHA1 Message Date
angelayi
ffbda61fbe [aoti][mps] Fix dynamic dispatch size (#155582)
In the case where we pass in a symint to the `dispatch` call, the compiler errors, so we need to cast the input to int64_t.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155582
Approved by: https://github.com/malfet
ghstack dependencies: #155752, #154287
2025-06-12 23:33:15 +00:00
angelayi
da50835bde [aoti] Support c10 calls (#155256)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/155256
Approved by: https://github.com/malfet
2025-06-10 00:45:59 +00:00
Scott Wolchok
8e1474d3c6 [inductor] small cleanups in torch/_inductor/codegen/mps.py (#154921)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154921
Approved by: https://github.com/jansel, https://github.com/Skylion007
2025-06-03 20:57:25 +00:00
Nikita Shulga
634ce22601 [MPSInductor] Fix codegen for nested multistage reductions (#154578)
Yet to write a unittest for it, but this fixes codegen for
```
python3 benchmarks/dynamo/torchbench.py --performance --only hf_T5  --backend inductor --inference --devices mps --float16
```

By correctly closing triple nested loop

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154578
Approved by: https://github.com/jansel, https://github.com/dcci
2025-05-29 17:09:25 +00:00
angelayi
26471fc203 [aoti] Initial Metal support (#153959)
An example generated file: P1816629015

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153959
Approved by: https://github.com/malfet, https://github.com/desertfire
ghstack dependencies: #153964
2025-05-23 05:45:35 +00:00
PyTorch MergeBot
47a01f3efb Revert "[aoti] Initial Metal support (#153959)"
This reverts commit 28bcd9eb30.

Reverted https://github.com/pytorch/pytorch/pull/153959 on behalf of https://github.com/angelayi due to previous PR broke frl build ([comment](https://github.com/pytorch/pytorch/pull/153959#issuecomment-2901825315))
2025-05-22 16:17:07 +00:00
Isuru Fernando
f419373dd3 [inductor] lowering for fractional_max_pool3d (#148630)
also a lowering with a reduction for large window_sizes for
fractional_max_pool2d

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148630
Approved by: https://github.com/eellison
2025-05-22 16:06:29 +00:00
angelayi
28bcd9eb30 [aoti] Initial Metal support (#153959)
An example generated file: P1816629015

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153959
Approved by: https://github.com/malfet, https://github.com/desertfire
ghstack dependencies: #153964
2025-05-21 21:55:59 +00:00
Nikita Shulga
58dc80dff6 [MPSInductor] Fix indexing calculation (#153997)
By using `c10:🤘:floor_divie` primitive

Which fixes `test_flip_cat_mps` test, and makes `doctr_reco_predictor` and `doctr_det_predictor` pass accuracy checks (at least locally, scheduled a workflow dispatch to validate it in CI)

Before this change following script generated different compile and eager results
```python
import torch

def foo(unsqueeze, unsqueeze_1):
    cat_1 = torch.ops.aten.cat.default([unsqueeze, unsqueeze_1], 1)
    view = torch.ops.aten.view.default(cat_1, [4])
    slice_5 = torch.ops.aten.slice.Tensor(view, 0, 0, 3)
    rev_1 = torch.ops.aten.flip.default(slice_5, [0])
    return rev_1

if __name__ == "__main__":
    x = torch.arange(1.0, 3.0, device='mps').reshape(2, 1)
    y = torch.arange(5.0, 7.0, device='mps').reshape(2, 1)

    rc, (kernel,) = torch._inductor.utils.run_and_get_kernels(torch.compile(foo), x, y)
    print(kernel)
    print("Compile: ", rc)
    print("Eager: ", foo(x, y))
```
After this change
```
'''
    #include <c10/metal/utils.h>
    kernel void generated_kernel(
        device float* out_ptr0,
        constant float* in_ptr0,
        constant float* in_ptr1,
        uint xindex [[thread_position_in_grid]]
    ) {
        int x0 = xindex;
        auto tmp6 = in_ptr0[1 + (c10:🤘:floor_divide((-1)*x0, 2))];
        auto tmp11 = in_ptr1[1 + (c10:🤘:floor_divide((-1)*x0, 2))];
        auto tmp0 = (2 + ((-1)*x0)) % (2);
        auto tmp1 = static_cast<long>(tmp0);
        auto tmp2 = 0;
        auto tmp3 = tmp1 >= tmp2;
        auto tmp4 = 1;
        auto tmp5 = tmp1 < tmp4;
        auto tmp7 = tmp5 ? tmp6 : 0.0;
        auto tmp8 = tmp1 >= tmp4;
        auto tmp9 = 2;
        auto tmp10 = tmp1 < tmp9;
        auto tmp12 = tmp8 ? tmp11 : 0.0;
        auto tmp13 = tmp5 ? tmp7 : tmp12;
        out_ptr0[x0] = static_cast<float>(tmp13);
    }
'''
Compile:  tensor([2., 5., 1.], device='mps:0')
Eager:  tensor([2., 5., 1.], device='mps:0')
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153997
Approved by: https://github.com/dcci
ghstack dependencies: #153970, #153971
2025-05-21 00:03:46 +00:00
Nikita Shulga
03859242ce [Testing] Fix test_deterministic_... on MPS (#153970)
By decorated emitted kernels with `'''` rather than `"""`

To match regex in `torch._inductor.utils.run_and_get_kernels`
This fixes `test_deterministic_codegen_mps`, `test_deterministic_codegen_on_graph_break_mps` and `test_deterministic_codegen_with_suffix_mps`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153970
Approved by: https://github.com/dcci, https://github.com/jansel
2025-05-20 21:15:14 +00:00
Nikita Shulga
db26aeaec2 [MPSInductor] Support numpy scalars handling (#153598)
By default, numpy computes results in float64 format, but when passed as an argument to MPS function, must be implicitly converted to float32, which naturally occurs in some networks, for example in speech_transformer

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153598
Approved by: https://github.com/cyyever, https://github.com/dcci
ghstack dependencies: #153582
2025-05-15 16:48:25 +00:00
Nikita Shulga
a6c5b59067 [MPSInductor] Fix multistage reduction suffixes (#153362)
By invalidating all variable created during the loop except for the context of iterator_cache, as storage can be done inside reduction loop and clear `IteratorRangeEntry` codegen cache.

Which results in the following kernel for `x / x.sum()` if x size is 2048 and max thread group size is 1024
```metal
[[max_total_threads_per_threadgroup(1024)]]
kernel void generated_kernel(
    device half* out_ptr1,
    constant half* in_ptr0,
    uint2 thread_pos [[thread_position_in_grid]],
    uint2 group_pos [[thread_position_in_threadgroup]]
) {
    auto xindex = thread_pos.x;
    auto r0_index = thread_pos.y;
    threadgroup float tmp_acc_0[32];
    float tmp_acc_1 = 0;
    for(auto r0_0_cnt = 0; r0_0_cnt < 2; ++r0_0_cnt) {
        int r0_0 = 2 * r0_index + r0_0_cnt;
        auto tmp0 = static_cast<float>(in_ptr0[r0_0]);
        tmp_acc_1 += tmp0;
    }
    auto tmp1 = c10:🤘:threadgroup_sum(tmp_acc_0, tmp_acc_1, r0_index * 1, 1024);
    for(auto r0_0_cnt = 0; r0_0_cnt < 2; ++r0_0_cnt) {
        int r0_0 = 2 * r0_index + r0_0_cnt;
        auto tmp2 = static_cast<float>(in_ptr0[r0_0]);
        auto tmp3 = tmp2 / tmp1;
        out_ptr1[r0_0] = static_cast<half>(tmp3);
    }
}
```

Fixes compilation report reported while running `GPUTests.test_pattern_matcher_multi_user_mps` and `GPUTests.test_weight_norm_bwd_mps`

Fixes https://github.com/pytorch/pytorch/issues/152155

Though inductor tests are still failing, need to keep refining the variable invalidation

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153362
Approved by: https://github.com/manuelcandales, https://github.com/dcci, https://github.com/jansel
2025-05-13 03:07:53 +00:00
Nikita Shulga
fe36d7dc44 [MPSInductor] Fix truncdiv implementation (#152788)
For integral dtypes it should be just an alias for division

Fixes `GPUTests.test_div7_mps`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152788
Approved by: https://github.com/dcci, https://github.com/jansel
ghstack dependencies: #152663, #152515, #152737, #152743, #152758
2025-05-05 13:31:51 +00:00
Nikita Shulga
d35e900c74 [MPSInductor] Make sure sizevars are computed (#152436)
Before calling the kernel

This fixes `GPUTests.test_float_repr_dynamic_shapes_mps`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152436
Approved by: https://github.com/dcci
ghstack dependencies: #152363, #152430
2025-04-29 17:53:29 +00:00
Nikita Shulga
835f95490f [MPSInductor] Fix type promotion in _print_Max (#152430)
Run into this problem while re-enabling `test_float_repr_dynamic_shapes`, where `_print_Max` were called for integer and long argument which resulted in the following compilation error
```
error: call to 'max' is ambiguous
        out_ptr0[x0 + x1*metal::max(1, ks0)] = static_cast<float>(tmp26);
                         ^~~~~~~~~~
/System/Library/PrivateFrameworks/GPUCompiler.framework/Versions/32023/Libraries/lib/clang/32023.619/include/metal/metal_integer:2477:16: note: candidate function
METAL_FUNC int max(int x, int y)
               ^
/System/Library/PrivateFrameworks/GPUCompiler.framework/Versions/32023/Libraries/lib/clang/32023.619/include/metal/metal_integer:3686:17: note: candidate function
METAL_FUNC long max(long x, long y)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152430
Approved by: https://github.com/dcci
ghstack dependencies: #152363
2025-04-29 17:53:29 +00:00
Nikita Shulga
9c7b902cb2 [MPSInductor][BE] Make all reductions cacheable (#152363)
By moving actual implementaiton to `_reduction_nocache` and make reduction a caching wrapper

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152363
Approved by: https://github.com/dcci
2025-04-29 02:49:22 +00:00
Nikita Shulga
cbcc03c2ad [MPSInductor][BE] Only include headers when needed (#152266)
Store headers used by shader in `MetalKernel.headers`
Add headers when function depending on it gets invoked
Generate majority of a special ops from template
Delete two unused functors: `entr` and `xlog1py` as they are decomposed by inductor anyway

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152266
Approved by: https://github.com/Skylion007, https://github.com/jansel, https://github.com/dcci, https://github.com/cyyever
2025-04-27 05:09:50 +00:00
Nikita Shulga
015b526a2a [MPSInductor] Warn-cast double as floats (#151963)
To support sqrt over dynamic shapes, i.e. make something like:
```python
torch.compile(dynamic=True)(lambda x: x * math.sqrt(x.size(0))
```
compilable into
```metal
// Source node to ATen node mapping:
// Graph fragment:
//   %scalar_tensor_default : [num_users=1] = call_function[target=torch.ops.aten.scalar_tensor.default](args = (%arg0_1,), kwargs = {})
//   %convert_element_type_default : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%scalar_tensor_default, torch.float64), kwargs = {})
//   %sqrt_default : [num_users=1] = call_function[target=torch.ops.aten.sqrt.default](args = (%convert_element_type_default,), kwargs = {})
//   %convert_element_type_default_1 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sqrt_default, torch.float32), kwargs = {})
//   %mul_tensor : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%arg1_1, %convert_element_type_default_1), kwargs = {})
 kernel void generated_kernel(
     device float* out_ptr0,
     constant float* in_ptr0,
     constant long& ks0,
     uint xindex [[thread_position_in_grid]]
 ) {
     int x0 = xindex;
     auto tmp0 = in_ptr0[x0];
     auto tmp1 = ks0;
     auto tmp2 = static_cast<float>(tmp1);
     auto tmp3 = metal::sqrt(tmp2);
     auto tmp4 = static_cast<float>(tmp3);
     auto tmp5 = tmp0 * tmp4;
     out_ptr0[x0] = static_cast<float>(tmp5);
 }
```

TODO:
 - Figure out if this could be tweaked in fx-passes, but overhead is probably too high

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151963
Approved by: https://github.com/dcci
ghstack dependencies: #151869, #151871, #151872
2025-04-23 00:30:45 +00:00
Davide Italiano
49b7ffbb15 [MPS] Implement _print_Trunc_to_Int (#151964)
Fixes `test_device_assert_mps`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151964
Approved by: https://github.com/malfet

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
2025-04-23 00:30:00 +00:00
Nikita Shulga
2f851ac8f8 [MPSInductor] Implement atomic_add store mode (#151871)
Which fixes `GPUTests.test_index_put2_mps`, `GPUTests. test__unsafe_masked_index_put_accumulate_mps` and dozen of scatter/gather tests that relied on atomic_add store mode

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151871
Approved by: https://github.com/jansel, https://github.com/dcci
ghstack dependencies: #151869
2025-04-22 22:00:16 +00:00
Davide Italiano
470132c6a1 [MPS] Add support for hermite_polynomial_he (inductor/eager). (#151754)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/151754
Approved by: https://github.com/malfet, https://github.com/jansel
2025-04-20 17:44:40 +00:00
Nikita Shulga
0c77af3576 [MPSInductor] Add pow, log2 and FloorToInt ops (#151449)
That enables `test_pow_by_natural_log2_dynamic_shapes_mps`

Not sure why log2 printer function suffix is `OpaqueUnaryFn_log2`, rather than just `log2`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151449
Approved by: https://github.com/jansel
2025-04-16 15:56:21 +00:00
Nikita Shulga
070357b61a [MPSInductor] Fix silent correctness in bitcast (#151272)
By using Metal `as_type` which according to documentation does exactly
that:
> Metal adds an as_type<type-id> operator to allow any scalar or vector data type (that is not
a pointer) to be reinterpreted as another scalar or vector data type of the same size. The bits in
the operand are returned directly without modification as the new type. The usual type
promotion for function arguments is not performed.

Using `reinterpret_cast` created a potential silent correctness error when dtypes of different sizes were bitcast to each other
Add expicit cast to src_type to avoid errors due to type promotion (i.e.
soemthing like (x+1).view(dtype=torch.float16) would work correctly in
eager mode for int16 dtype, but would fail in compile, as arithmetic
operations will promote int16 to int32

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151272
Approved by: https://github.com/dcci
ghstack dependencies: #151224, #151246
2025-04-14 23:39:42 +00:00
Nikita Shulga
46ce8f7df6 [MPSInductor] Cast halfs to floats (#151246)
To avoid accuracy issues when small reductions are unrolled, cast half to float during the `load` op
As `op_math_t<half>` is indeed float

This fixes `test_unroll_small_reduction` for reduced precision types

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151246
Approved by: https://github.com/dcci
ghstack dependencies: #151224
2025-04-14 19:47:04 +00:00
Nikita Shulga
9699cc3eb9 [MPSInductor] Fix larger-than-threadgroup Welford reductions (#151152)
By using `welford_combine` primitive in the loop
This fixes `GPUTests.test_multilayer_var_lowp_mps`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151152
Approved by: https://github.com/jansel
ghstack dependencies: #151042, #150824, #151151
2025-04-12 21:44:51 +00:00
PyTorch MergeBot
7762bddd87 Revert "[MPSInductor] Fix larger-than-threadgroup Welford reductions (#151152)"
This reverts commit 71073caa00.

Reverted https://github.com/pytorch/pytorch/pull/151152 on behalf of https://github.com/malfet due to Another lint failure ([comment](https://github.com/pytorch/pytorch/pull/151152#issuecomment-2799027274))
2025-04-12 20:27:48 +00:00
Nikita Shulga
71073caa00 [MPSInductor] Fix larger-than-threadgroup Welford reductions (#151152)
By using `welford_combine` primitive in the loop
This fixes `GPUTests.test_multilayer_var_lowp_mps`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151152
Approved by: https://github.com/jansel
ghstack dependencies: #151042, #150824, #151151
2025-04-12 19:16:33 +00:00
Nikita Shulga
3b86cb8dff [MPSInductor][BE] Implement reduction caching (#151151)
That avoids double/triple invocation of welford reductions when both
mean and deviation must be returned

Code has been copy-n-pasted for Halide implementation
575f348965/torch/_inductor/codegen/halide.py (L1189-L1191)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151151
Approved by: https://github.com/jansel
ghstack dependencies: #151042, #150824
2025-04-12 19:16:33 +00:00
Nikita Shulga
397d37acc5 [MPSInductor] Naive welford_reduce implementation (#150824)
Literal Python-to-Metal translation of
85549fe6de/torch/_inductor/runtime/triton_helpers.py (L217-L225)

Fixed missing barrier in `welford_combine`
And this is sufficient to make `GPUTests.test_batch_norm_2d_2_mps` to pass

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150824
Approved by: https://github.com/dcci, https://github.com/jansel
ghstack dependencies: #151042
2025-04-12 03:11:38 +00:00
PyTorch MergeBot
77407b38a9 Revert "[MPSInductor] Naive welford_reduce implementation (#150824)"
This reverts commit 575f348965.

Reverted https://github.com/pytorch/pytorch/pull/150824 on behalf of https://github.com/malfet due to Linter fails again, landrace this time? ([comment](https://github.com/pytorch/pytorch/pull/150824#issuecomment-2798392241))
2025-04-12 02:22:22 +00:00
Nikita Shulga
575f348965 [MPSInductor] Naive welford_reduce implementation (#150824)
Literal Python-to-Metal translation of
85549fe6de/torch/_inductor/runtime/triton_helpers.py (L217-L225)

Fixed missing barrier in `welford_combine`
And this is sufficient to make `GPUTests.test_batch_norm_2d_2_mps` to pass

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150824
Approved by: https://github.com/dcci, https://github.com/jansel
ghstack dependencies: #151042
2025-04-12 00:46:01 +00:00
PyTorch MergeBot
83f14c0b06 Revert "[MPSInductor] Naive welford_reduce implementation (#150824)"
This reverts commit 5edfb4c4fa.

Reverted https://github.com/pytorch/pytorch/pull/150824 on behalf of https://github.com/malfet due to I should have waited for lint ([comment](https://github.com/pytorch/pytorch/pull/150824#issuecomment-2798249264))
2025-04-12 00:21:14 +00:00
Nikita Shulga
5edfb4c4fa [MPSInductor] Naive welford_reduce implementation (#150824)
Literal Python-to-Metal translation of
85549fe6de/torch/_inductor/runtime/triton_helpers.py (L217-L225)

Fixed missing barrier in `welford_combine`
And this is sufficient to make `GPUTests.test_batch_norm_2d_2_mps` to pass

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150824
Approved by: https://github.com/dcci, https://github.com/jansel
ghstack dependencies: #151042
2025-04-11 23:21:35 +00:00
Nikita Shulga
c830c12a87 [MPSInductor] Fix tiled reduction logic (#150737)
In case of tiles, index must include both reduction dimentions

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150737
Approved by: https://github.com/dcci
2025-04-06 00:20:41 +00:00
Nikita Shulga
7ac8186851 [MPSInductor] Speedup sum/prod reductions (#150566)
By using cooperative `simd_sum`/`simd_product` instead of a C-style for loop for threadgroup reductions. This also allows significantly reduce amount of shared memory needed to perform those reductions

Using such reduction increases the `torch.compile` performance for gpt-fast using `stories110M` from 29 tokens/sec to 630 tokens/sec on M4 and changes perf of torch.rand as follows:
|size| before | after |
|------------------------|------------|-------------|
| 512x512         | 202.1       | 131.8       |
| 1024x1024   |   780.6    | 176.9       |
| 2048x2048    |   1423.4       | 339.9      |
| 4096x4097    |    2982.2 | 1047.2      |

Unfortunately, none of the SIMDgroup operations are available for 64-bit integers, but one can simulate the behavior using using `simd_shuffle_down` of 64-bit values represented as `int2` types, that yields reduction in $log_2(threadgroup\\_size)$ steps. [`mlx/kernels/reduction/ops.h](86389bf970/mlx/backend/metal/kernels/reduction/ops.h (L15-L18)) contains an implementation of such algorithm, but alas it yields wrong results on M1/M2(and may be M3 machines) if not all threads in the simdgroup are active which could be observed by running
```python
import torch
lib=torch.mps.compile_shader("""
kernel void do_sum(device int* out, constant int* in, uint idx [[thread_position_in_grid]]) {
  out[idx] = metal::simd_shuffle_down(in[idx], 8);
}
""")
x=torch.arange(22, device='mps', dtype=torch.int32)
y=torch.empty_like(x)
lib.do_sum(y, x)
print(y)
```
that returns following on M4
```
tensor([ 8,  9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21,  0,  0,  0,  0, 0,  0,  0,  0], device='mps:0', dtype=torch.int32)
```
but same kernel running on M1 returns
```
tensor([ 8,  9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 14, 15, 16, 17, 18, 19, 20, 21], device='mps:0', dtype=torch.int32)
```
This discrepancy in behavior can be addressed by using `simd_shuffle_and_fill_down`, but any kernels using simd_shuffle_and_fill_down cause an internal compiler error on MacOS-13.2. Considering that OS is to be EOL soon, skip the offending tests.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150566
Approved by: https://github.com/manuelcandales
ghstack dependencies: #150452, #150457
2025-04-05 02:47:27 +00:00
Davide Italiano
295b7e21eb [MPS/inductor] Add support for hermite_polynomial_h. (#150664)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150664
Approved by: https://github.com/malfet
2025-04-04 13:14:52 +00:00
Nikita Shulga
dee016ceb7 [MPSInductor] Add store_reduce method (#150457)
That restrict the store operation to 0th thread, which should be much better, shouldn't it
(Though I don't observe it in the benchmark)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150457
Approved by: https://github.com/jansel, https://github.com/dcci
ghstack dependencies: #150452
2025-04-02 05:12:49 +00:00
Nikita Shulga
f94ac263af [MPSInductor] Fix neg for unsigned types (#150412)
By more-or-less copy-n-pasting the fix from https://github.com/pytorch/pytorch/pull/94035

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150412
Approved by: https://github.com/jansel, https://github.com/dcci
ghstack dependencies: #150382, #150386
2025-04-01 16:52:41 +00:00
Nikita Shulga
965784eb9b [MPSInductor] Specify max_total_threads_per_threadgroup (#150247)
When generating reduction kernel, otherwise compiler can unroll loops too much that kernel could not be launched for the intended threadgroup size

Extend `c10:🤘:max` to accept different dtypes

Together this fixes `test_large_broadcast_reduction`

TODO:
  - Explore different threadgroup_sizes for best perf

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150247
Approved by: https://github.com/jansel, https://github.com/dcci
ghstack dependencies: #150246
2025-03-29 19:37:15 +00:00
Nikita Shulga
6aca002d82 [MPS] Add chebyshev_polynomial_[uvw] (#150060)
For both eager and inductor

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150060
Approved by: https://github.com/dcci, https://github.com/jansel
2025-03-26 23:35:05 +00:00
Davide Italiano
e85ce64bde [MPS/Inductor] Add support for chebyshev_polynomial_t. (#149928)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/149928
Approved by: https://github.com/malfet
2025-03-25 21:02:13 +00:00
Davide Italiano
2b848ab192 [MPS/inductor] Add support for modified_scaled_bessel_k{0,1} (#149794)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/149794
Approved by: https://github.com/malfet
2025-03-22 15:41:40 +00:00
Davide Italiano
0ed34210b2 [MPS] Add support for modified_bessel_k1 to eager and inductor. (#149687)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/149687
Approved by: https://github.com/malfet
2025-03-21 04:59:06 +00:00
Davide Italiano
595293316d [MPS/Inductor] Add support for modified_bessel_k0. (#149593)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/149593
Approved by: https://github.com/jansel
2025-03-20 04:51:44 +00:00
Davide Italiano
9cd52da45c [MPS/inductor] Add support for modified_bessel_i1. (#149379)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/149379
Approved by: https://github.com/malfet
2025-03-18 06:02:33 +00:00
Davide Italiano
e4f6e4ac84 [MPS] Add inductor support for modified_bessel_i0. (#149342)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/149342
Approved by: https://github.com/malfet
2025-03-17 21:45:51 +00:00
Nikita Shulga
d7d9a71e19 [MPSInductor] Add support for atan2 (#149216)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/149216
Approved by: https://github.com/dcci
2025-03-14 21:53:03 +00:00
Davide Italiano
0bd863a62f [MPS] Add inductor support for i1e. (#149221)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/149221
Approved by: https://github.com/malfet
2025-03-14 21:18:38 +00:00
Nikita Shulga
42e468d9b0 [MPSInductor] Adjust check_bounds (#147205)
To make upper bound inclusive, which fixes `test_vectorized_ops_masked` and results in the following code
```python
mps_lib_0 = compile_mps_shader("""
    #include <c10/metal/random.h>
    #include <c10/metal/special_math.h>
    #include <c10/metal/utils.h>
    kernel void generated_kernel(
        device float* out_ptr0,
        constant float* in_ptr0,
        uint xindex [[thread_position_in_grid]]
    ) {
        int x0 = (xindex) % (64);
        int x1 = (xindex) / (64);
        auto tmp5 = in_ptr0[x0 + 63*x1];
        int x2 = xindex;
        auto tmp0 = x0;
        auto tmp1 = static_cast<long>(tmp0);
        auto tmp2 = 63;
        auto tmp3 = tmp1 < tmp2;
        if (x0 > 63) return;
        auto tmp6 = tmp3 ? tmp5 : 7;
        out_ptr0[x2] = static_cast<float>(tmp6);
    }
""")
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147205
Approved by: https://github.com/jansel, https://github.com/dcci
ghstack dependencies: #147211
2025-03-14 17:26:00 +00:00
Davide Italiano
f2ea77c099 [MPS] Add inductor support for i0e. (#149180)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/149180
Approved by: https://github.com/malfet
2025-03-14 16:15:52 +00:00