Commit Graph

161 Commits

Author SHA1 Message Date
Manuel Candales
874efa2d72 AOTI MPS Shim Implementation (#163865)
## MPS Shim API

*   Updated MPS shimification API with handles and function declarations:
    *   `AOTIMetalShaderLibraryHandle` and `AOTIMetalKernelFunctionHandle` types
    *   Library management: `aoti_torch_mps_create_shader_library`, `aoti_torch_mps_delete_shader_library`, `aoti_torch_mps_get_kernel_function`
    *   Kernel execution: `aoti_torch_mps_run_command_block`, `aoti_torch_mps_start_encoding`, `aoti_torch_mps_dispatch` variants, etc

## MPS Shader Codegen

*   Modified to generate source constants instead of direct `DynamicMetalShaderLibrary` instantiation:
    *   **Before**: `at::native::mps::DynamicMetalShaderLibrary mps_lib_0(R"MTL(...)MTL");`
    *   **After**: `const char* mps_lib_0_source = R"MTL(...)MTL";`
*   Updated kernel call generation  to use shimified functions:
    *   Generates calls to shimified API instead of direct libtorch calls

## Before vs After Comparison

### Section 1: Shader Library
**Before (Direct Library Object)**
```cpp
at::native::mps::DynamicMetalShaderLibrary mps_lib_0(R"MTL(
    ...
)MTL");
```
**After (Source String)**
```cpp
const char* mps_lib_0_source = (R"MTL(
    ...
)MTL");
```

### Section 2: Getter Functions & RAII Management

**Before (Direct Library Access)**
```cpp
const std::shared_ptr<at::native::mps::MetalKernelFunction> get_mps_lib_0() {
    static const auto func = mps_lib_0.getKernelFunction("generated_kernel");
    return func;
}

AOTIMetalKernelFunctionHandle get_mps_lib_0_handle() {
    static const auto handle = AOTIMetalKernelFunctionHandle(get_mps_lib_0().get());
    return handle;
}
```

**After (Shim API + RAII Wrapper)**
```cpp
AOTIMetalKernelFunctionHandle get_mps_lib_0_handle() {
    static auto kernel_handle = []() {
        AOTIMetalShaderLibraryHandle lib_handle = nullptr;
        AOTIMetalKernelFunctionHandle kern_handle = nullptr;

        aoti_torch_mps_create_shader_library(mps_lib_0_source, &lib_handle);
        aoti_torch_mps_get_kernel_function(lib_handle, "generated_kernel", &kern_handle);

        // RAII wrapper with custom deleter
        auto lib_deleter = [](AOTIMetalShaderLibraryHandle h) {{
            if (h) aoti_torch_mps_delete_shader_library(h);
        }};

        using LibDeleter = decltype(lib_deleter);
        using LibPtr = std::unique_ptr<AOTIMetalShaderLibraryOpaque, LibDeleter>;

        // Return pair of kernel handle and library smart pointer for cleanup
        return std::make_pair(kern_handle, LibPtr(lib_handle, lib_deleter));
    }();
    return kernel_handle.first;
}
```

### Section 3: Runtime Execution

**Before (Direct Library Methods)**
```cpp
void AOTInductorModel::run_impl(...) {

    ...

    get_mps_lib_0()->runCommandBlock([&] {
        get_mps_lib_0()->startEncoding();
        aoti_torch_mps_set_arg_tensor(get_mps_lib_0_handle(), 0, buf0);
        aoti_torch_mps_set_arg_tensor(get_mps_lib_0_handle(), 1, arg0_1);
        aoti_torch_mps_set_arg_tensor(get_mps_lib_0_handle(), 2, arg1_1);
        get_mps_lib_0()->dispatch({static_cast<uint64_t>(10LL)});

    });

    ...

} // AOTInductorModel::run_impl
```

**After (Shim API with Lambda Pattern)**
```cpp
void AOTInductorModel::run_impl(...) {

    ...

    auto mps_lib_0_lambda_0 = [&](AOTIMetalKernelFunctionHandle handle) {
        aoti_torch_mps_start_encoding(handle);
        aoti_torch_mps_set_arg_tensor(handle, 0, buf0);
        aoti_torch_mps_set_arg_tensor(handle, 1, arg0_1);
        aoti_torch_mps_set_arg_tensor(handle, 2, arg1_1);
        aoti_torch_mps_dispatch_single(handle, static_cast<uint64_t>(10LL));
    };

    std::function<void(AOTIMetalKernelFunctionHandle)> mps_lib_0_func_wrapper_0 = mps_lib_0_lambda_0;
    aoti_torch_mps_run_command_block(get_mps_lib_0_handle(), aoti_torch_mps_shared_callback, &mps_lib_0_func_wrapper_0);

    ...

} // AOTInductorModel::run_impl
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163865
Approved by: https://github.com/angelayi, https://github.com/desertfire
2025-10-09 09:28:10 +00:00
Kurt Mohler
791eff96c8 [MPS] Add igamma/igammac ops (#161927)
Fixes #161725

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161927
Approved by: https://github.com/malfet
2025-09-02 20:52:02 +00:00
Nikita Shulga
2042d2174a [MPS] Migrate round unary op to Metal (#161712)
And actually use the right function, as [`torch.round`](https://docs.pytorch.org/docs/stable/generated/torch.round.html) doesn't use `std::round`, but rather `std::rint`, which can be easily seen by running something like
```python
import torch
print(torch.arange(-3., 3., step=.5, device='mps').round())
print(torch.arange(-3., 3., step=.5, device='mps').cpu().round())
```

Before this change it printed
```
tensor([-3., -3., -2., -2., -1., -1.,  0.,  1.,  1.,  2.,  2.,  3.], device='mps:0')
tensor([-3., -2., -2., -2., -1., -0.,  0.,  0.,  1.,  2.,  2.,  2.])
```
But after this change results match

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161712
Approved by: https://github.com/dcci
2025-08-28 16:45:07 +00:00
angelayi
23cf241039 [aoti][mps] Initialize mps kernels first (#159753)
In some cases we have mps kernels which are reused across higher-order-op subgraphs and the toplevel code. However, currently we initialize the variable for the mps kernel the first time we use it, which runs into an issue if we run into the mps kernel within a subgraph since the kernel will only be initialized within the subgraph scope. For instance:
```
if ...
    auto mps_lib_0_func = ...
    mps_lib_0_func->run()

// since we already used mps_lib_0 once, we don't re-initialize it
mps_lib_0_func->run()  // error, mps_lib_0_func not initialized
```

So the solution we took here is to initialize all the kernels at the beginning:
```
const std::shared_ptr<at::native::mps::MetalKernelFunction> get_mps_lib_0() {
    static const auto func = mps_lib_0.getKernelFunction("generated_kernel");
    return func;
}
AOTIMetalKernelFunctionHandle get_mps_lib_0_handle() {
    static const auto handle = AOTIMetalKernelFunctionHandle(get_mps_lib_0().get());
    return handle;
}
...
if ...
    get_mps_lib_0()->run()

get_mps_lib_0()->run()  // success
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159753
Approved by: https://github.com/malfet
ghstack dependencies: #159456, #159695
2025-08-06 07:54:29 +00:00
angelayi
b1ec088113 [mps] Turn on inductor dynamic shapes tests (#159456)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159456
Approved by: https://github.com/Skylion007, https://github.com/malfet
2025-08-05 22:27:06 +00:00
PyTorch MergeBot
fb8f32ef52 Revert "[mps] Turn on inductor dynamic shapes tests (#159456)"
This reverts commit 19f1f9960d.

Reverted https://github.com/pytorch/pytorch/pull/159456 on behalf of https://github.com/davidberard98 due to Sorry - this causes a merge conflict with https://github.com/pytorch/pytorch/pull/159798, which I'm trying to land with co-dev to resolve a sev ([comment](https://github.com/pytorch/pytorch/pull/159456#issuecomment-3152751821))
2025-08-04 23:11:05 +00:00
angelayi
19f1f9960d [mps] Turn on inductor dynamic shapes tests (#159456)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159456
Approved by: https://github.com/Skylion007, https://github.com/malfet
2025-08-04 22:44:31 +00:00
angelayi
25ef3d315d [aoti][mps] Dynamic reductions (#159355)
Dynamic kernel:
```cpp
[[max_total_threads_per_threadgroup(1024)]]
kernel void generated_kernel(
    device float* out_ptr0,
    constant float* in_ptr0,
    constant long& r0_numel,
    uint2 thread_pos [[thread_position_in_grid]],
    uint2 group_pos [[thread_position_in_threadgroup]]
) {
    auto xindex = thread_pos.x;
    auto r0_index = thread_pos.y;
    int x0 = xindex;
    threadgroup float tmp_acc_0[32];
    float tmp_acc_1 = 0;
    for(auto r0_1_cnt = 0; r0_1_cnt < static_cast<int>(metal::floor(static_cast<float>(0.99902343750000000 + 0.00097656250000000000*r0_numel))); ++r0_1_cnt) {
        int r0_1 = 1024 * r0_1_cnt + r0_index;
        if (r0_1 >= r0_numel) break;
        auto tmp0 = in_ptr0[x0 + 5*r0_1];
        tmp_acc_1 += tmp0;
    }
    auto tmp1 = c10:🤘:threadgroup_sum(tmp_acc_0, tmp_acc_1, r0_index * 1, metal::min(static_cast<decltype(1024+r0_numel)>(1024), static_cast<decltype(1024+r0_numel)>(r0_numel)));
    if (r0_index == 0) out_ptr0[x0] = static_cast<float>(tmp1);
}

void AOTInductorModel::run_impl(...) {
    ...
    auto arg0_1_size = arg0_1.sizes();
    int64_t s77 = arg0_1_size[0];
    inputs.clear();
    [[maybe_unused]] auto& kernels = static_cast<AOTInductorModelKernels&>(*this->kernels_.get());
    static constexpr int64_t int_array_0[] = {5LL, };
    static constexpr int64_t int_array_1[] = {1LL, };
    AtenTensorHandle buf0_handle;
    AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_empty_strided(1, int_array_0, int_array_1, cached_torch_dtype_float32, cached_torch_device_type_mps, this->device_idx_, &buf0_handle));
    RAIIAtenTensorHandle buf0(buf0_handle);
    auto mps_lib_0_func = mps_lib_0.getKernelFunction("generated_kernel");
    auto mps_lib_0_func_handle = AOTIMetalKernelFunctionHandle(mps_lib_0_func.get());
    mps_lib_0_func->runCommandBlock([&] {
        mps_lib_0_func->startEncoding();
        aoti_torch_mps_set_arg_tensor(mps_lib_0_func_handle, 0, buf0);
        aoti_torch_mps_set_arg_tensor(mps_lib_0_func_handle, 1, arg0_1);
        aoti_torch_mps_set_arg_int(mps_lib_0_func_handle, 2, s77);
        mps_lib_0_func->dispatch({static_cast<uint64_t>(5LL), static_cast<uint64_t>(std::min(static_cast<int64_t>(1024LL), static_cast<int64_t>(s77)))}, {static_cast<uint64_t>(1), static_cast<uint64_t>(std::min(static_cast<int64_t>(1024LL), static_cast<int64_t>(s77)))});

    });
    arg0_1.reset();
    output_handles[0] = buf0.release();
} // AOTInductorModel::run_impl
```

Static kernel:
```cpp
kernel void generated_kernel(
    device float* out_ptr0,
    constant float* in_ptr0,
    uint xindex [[thread_position_in_grid]]
) {
    int x0 = xindex;
    auto tmp0 = in_ptr0[x0];
    auto tmp1 = in_ptr0[5 + x0];
    auto tmp3 = in_ptr0[10 + x0];
    auto tmp5 = in_ptr0[15 + x0];
    auto tmp2 = tmp0 + tmp1;
    auto tmp4 = tmp2 + tmp3;
    auto tmp6 = tmp4 + tmp5;
    out_ptr0[x0] = static_cast<float>(tmp6);
}

void AOTInductorModel::run_impl(...) {
    ...
    static constexpr int64_t int_array_0[] = {5LL, };
    static constexpr int64_t int_array_1[] = {1LL, };
    AtenTensorHandle buf0_handle;
    AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_empty_strided(1, int_array_0, int_array_1, cached_torch_dtype_float32, cached_torch_device_type_mps, this->device_idx_, &buf0_handle));
    RAIIAtenTensorHandle buf0(buf0_handle);
    auto mps_lib_0_func = mps_lib_0.getKernelFunction("generated_kernel");
    auto mps_lib_0_func_handle = AOTIMetalKernelFunctionHandle(mps_lib_0_func.get());
    mps_lib_0_func->runCommandBlock([&] {
        mps_lib_0_func->startEncoding();
        aoti_torch_mps_set_arg_tensor(mps_lib_0_func_handle, 0, buf0);
        aoti_torch_mps_set_arg_tensor(mps_lib_0_func_handle, 1, arg0_1);
        mps_lib_0_func->dispatch({static_cast<uint64_t>(5LL)});

    });
    arg0_1.reset();
    output_handles[0] = buf0.release();
} // AOTInductorModel::run_impl
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159355
Approved by: https://github.com/malfet
2025-07-31 23:15:02 +00:00
Nikita Shulga
f946b25865 [MPS] Speedup argmax/argmin (#159524)
By using efficient `threadgroup_arg[max|min]` primitives.
- Fixed bug in `simd_argmax` when result of the `simd_ballot` were prematurely cast to `ushort` and adjusted unit test
- Fixed nan handling in compiled argmax, but can't reliably test it as MPS(eager) implementaiton of argmax is buggy

Now according to `bench_mps_ops.py` `max(x, dim=0)` is reliably faster than eager implementaiton:
```
[---------------------------------------------------------------------------------------------  --------------------------------------------------------------------------------------------]
                           |  eager-512x512  |  compile-512x512  |  eager-1024x1024  |  compile-1024x1024  |  eager-2048x2048  |  compile-2048x2048  |  eager-4096x4096  |  compile-4096x4096
1 threads: ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
      max (torch.float16)  |      285.8      |       272.2       |       422.3       |        354.5        |       721.6       |        683.5        |       2224.0      |        1979.1
      max (torch.float32)  |      300.2      |       267.0       |       389.6       |        342.5        |       769.4       |        682.6        |       2995.7      |        2609.8
      max (torch.int32)    |      299.6      |       275.4       |       390.0       |        361.7        |       758.7       |        686.1        |       3103.4      |        2646.5
      max (torch.int64)    |      297.5      |       275.5       |       417.0       |        382.1        |       856.1       |        722.6        |       5467.7      |        3156.8

```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159524
Approved by: https://github.com/Skylion007, https://github.com/dcci
ghstack dependencies: #158990
2025-07-31 16:18:32 +00:00
Nikita Shulga
1293405c8d [MPS] Add simd_[arg][max|min] (#158990)
And add eager tests for those.
Re-implement `threadgroup_[max|min]` using those function as they are significantly faster (though much slower than eager, due to the arg part) than before, which could be verified by running the following script
```python
import itertools
import timeit
import torch
from torch.utils.benchmark import Compare, Measurement, Timer

def bench_unary_op(func, x, label) -> Measurement:
    sync_cmd = "torch.mps.synchronize()" if "mps" in str(x.device) else ""
    t = Timer(
        stmt=f"f(x);{sync_cmd}",
        globals={"f": func, "x": x},
        language="python",
        timer=timeit.default_timer,
        sub_label=f"{func.__name__} ({str(x.dtype)})",
        description=label,
        env=torch.__version__,
    )
    return t.blocked_autorange()

def bench_reduction(
    reduction_func, device: str = "mps", dtype: torch.dtype = torch.float32
) -> list[Measurement]:
    rc = []

    # Bench 2D with reduction over dim=0
    def f(t):
        return reduction_func(t, dim=0)[0]

    f.__name__ = reduction_func.__name__
    f_c = torch.compile(f, dynamic=False, fullgraph=True)

    for size in (512, 1024, 2048, 4096):
        x = torch.testing.make_tensor(size, size, device=device, dtype=dtype)
        rc_c, rc_e = f(x), f_c(x)
        rc_c, rc_e = (rc_c[0], rc_e[0]) if isinstance(rc_c, tuple) else (rc_c, rc_e)
        rc.append(bench_unary_op(f, x, f"eager-{size}x{size}"))
        rc.append(bench_unary_op(f_c, x, f"compile-{size}x{size}"))
    return rc

def main() -> None:
    #dtypes = [torch.float16, torch.float32, torch.bfloat16, torch.int32, torch.int64]
    dtypes = [torch.float32, torch.int32, torch.int64]

    # Profile reduction ops
    rc = []
    for op, dtype in itertools.product([torch.max], dtypes):
        rc.extend(bench_reduction(op, dtype=dtype))
    Compare(rc).print()

if __name__ == "__main__":
    torch._dynamo.config.cache_size_limit = 2**16
    main()
```

Produces the following table before
```
[---------------------------------------------------------------------------------------------  --------------------------------------------------------------------------------------------]
                           |  eager-512x512  |  compile-512x512  |  eager-1024x1024  |  compile-1024x1024  |  eager-2048x2048  |  compile-2048x2048  |  eager-4096x4096  |  compile-4096x4096
1 threads: ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
      max (torch.float32)  |      297.3      |       531.6       |       394.1       |        2550.5       |       773.0       |        4904.7       |       3647.2      |        9682.0
      max (torch.int32)    |      297.8      |       359.2       |       387.7       |        1179.4       |       768.2       |        2175.0       |       3677.1      |        4495.9
      max (torch.int64)    |      278.7      |       541.4       |       410.2       |        2873.3       |       858.9       |        5620.4       |       6107.2      |       11176.1

Times are in microseconds (us).
```
And after
```
[---------------------------------------------------------------------------------------------  --------------------------------------------------------------------------------------------]
                           |  eager-512x512  |  compile-512x512  |  eager-1024x1024  |  compile-1024x1024  |  eager-2048x2048  |  compile-2048x2048  |  eager-4096x4096  |  compile-4096x4096
1 threads: ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
      max (torch.float32)  |      307.9      |       265.3       |       401.0       |        340.8        |       766.5       |        661.9        |       3463.5      |        2829.5
      max (torch.int32)    |      293.5      |       263.1       |       405.0       |        338.8        |       761.4       |        672.5        |       3050.0      |        2688.6
      max (torch.int64)    |      308.2      |       255.7       |       417.4       |        341.4        |       877.0       |        695.0        |       5812.2      |        5762.2

```

`argmax`/`argmin` are much tricker due to the nan-handling logic that need to be added there.

Also fixes `torch.max/min` compilation for half-precision types, added regression types for it.

This PR also introduces a bunch of helper functions, such as `simd_broadcast` that works for int64 and `c10:🤘:pair` template, which are used by `simd_argmax` to return both value and index

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158990
Approved by: https://github.com/dcci, https://github.com/Skylion007
2025-07-30 21:57:25 +00:00
angelayi
84058d1179 [aoti][mps] Fix cpu kernel generation (#158350)
In the case where we have both mps and cpu code which can be inductor compiled, we need to case on the device -- this requires the device field to be correctly passed.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158350
Approved by: https://github.com/malfet
ghstack dependencies: #158349
2025-07-23 00:54:53 +00:00
Nikita Shulga
ec816d73b4 [MPS] Add shifted_chebyshev_polynomial_[tuvw] (#157488)
For eager and inductor

As for all other chebyshev ops, logic is simply compiled from 94716db222/aten/src/ATen/native/cuda/Math.cuh (L2821)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157488
Approved by: https://github.com/dcci
2025-07-03 15:48:37 +00:00
PyTorch MergeBot
b6276a425f Revert "[MPS] Add shifted_chebyshev_polynomial_[tuvw] (#157488)"
This reverts commit 9620994067.

Reverted https://github.com/pytorch/pytorch/pull/157488 on behalf of https://github.com/clee2000 due to caused slow test config to time out [GH job link](https://github.com/pytorch/pytorch/actions/runs/16037776972/job/45254574100) [HUD commit link](e124a0d88c) ([comment](https://github.com/pytorch/pytorch/pull/157464#issuecomment-3032676989))
2025-07-03 15:24:15 +00:00
Nikita Shulga
9620994067 [MPS] Add shifted_chebyshev_polynomial_[tuvw] (#157488)
For eager and inductor

As for all other chebyshev ops, logic is simply compiled from 94716db222/aten/src/ATen/native/cuda/Math.cuh (L2821)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157488
Approved by: https://github.com/dcci
ghstack dependencies: #157464
2025-07-02 23:29:35 +00:00
angelayi
17dab018e3 [aoti][mps] Fix deduplication of kernels (#156843)
Previously I was not correctly deduplicating kernels generated by mps, so it would generate multiple of the same kernel.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156843
Approved by: https://github.com/desertfire
2025-06-26 21:03:05 +00:00
Nikita Shulga
3cbae6dde8 [MPSInductor][BE] Fix multistage reduction check (#156567)
From less than max threadgroup size to less or equal to that, which eliminates redundant trivial loops.

I.e. it changes shader code generated for
```python
import torch

def f(x):
    var, mean = torch.var_mean(x, dim=2, keepdim = True)
    return x / var, var

torch.compile(f)(torch.rand(1, 16, 1024, dtype=torch.float32, device='mps'))

```

from
```metal
[[max_total_threads_per_threadgroup(1024)]]
kernel void generated_kernel(
    device float* out_ptr1,
    device float* out_ptr2,
    constant float* in_ptr0,
    uint2 thread_pos [[thread_position_in_grid]],
    uint2 group_pos [[thread_position_in_threadgroup]]
) {
    auto xindex = thread_pos.x;
    auto r0_index = thread_pos.y;
    int x0 = xindex;
    threadgroup float3 tmp_acc_0[1024];
    tmp_acc_0[r0_index * 1] = 0.0;
    for(auto r0_1_cnt = 0; r0_1_cnt < 1; ++r0_1_cnt) {
        int r0_1 = 1 * r0_index + r0_1_cnt;
        auto tmp0 = in_ptr0[r0_1 + 1024*x0];
        tmp_acc_0[r0_index * 1] = ::c10:🤘:welford_combine(tmp_acc_0[r0_index * 1], float3(tmp0, 0.0, 1.0));
    }
    auto tmp1 = c10:🤘:threadgroup_welford_combine(tmp_acc_0, 1024);
    auto tmp2 = 1023.0;
    auto tmp3 = tmp1.y / tmp2;
    out_ptr1[x0] = static_cast<float>(tmp3);
    for(auto r0_1_cnt = 0; r0_1_cnt < 1; ++r0_1_cnt) {
        int r0_1 = 1 * r0_index + r0_1_cnt;
        auto tmp4 = in_ptr0[r0_1 + 1024*x0];
        auto tmp5 = tmp4 / tmp3;
        out_ptr2[r0_1 + 1024*x0] = static_cast<float>(tmp5);
    }
}
```
to
```metal
[[max_total_threads_per_threadgroup(1024)]]
kernel void generated_kernel(
    device float* out_ptr1,
    device float* out_ptr2,
    constant float* in_ptr0,
    uint2 thread_pos [[thread_position_in_grid]],
    uint2 group_pos [[thread_position_in_threadgroup]]
) {
    auto xindex = thread_pos.x;
    auto r0_index = thread_pos.y;
    int r0_1 = r0_index;
    int x0 = xindex;
    threadgroup float tmp_acc_0[1024];
    auto tmp0 = in_ptr0[r0_1 + 1024*x0];
    tmp_acc_0[r0_index * 1] = tmp0;
    auto tmp1 = c10:🤘:threadgroup_welford_reduce(tmp_acc_0, 1024);
    auto tmp2 = 1023.0;
    auto tmp3 = tmp1.y / tmp2;
    out_ptr1[x0] = static_cast<float>(tmp3);
    auto tmp4 = tmp0 / tmp3;
    out_ptr2[r0_1 + 1024*x0] = static_cast<float>(tmp4);
}

``

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156567
Approved by: https://github.com/dcci
ghstack dependencies: #156566
2025-06-23 14:49:26 +00:00
Nikita Shulga
4cd6e96bf0 [MPSInductor] Fix nested loop var elimination (#156566)
As reduction resuts must be kept around
Add regression test that is specific for this issue

Fixes https://github.com/pytorch/pytorch/issues/156426

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156566
Approved by: https://github.com/dcci
2025-06-23 04:35:16 +00:00
Xuehai Pan
6ff6630375 [BE][3/16] fix typos in torch/ (torch/_inductor/) (#156313)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/156313
Approved by: https://github.com/jingsh
2025-06-23 02:57:12 +00:00
PyTorch MergeBot
f1331f3f1b Revert "[BE][3/16] fix typos in torch/ (torch/_inductor/) (#156313)"
This reverts commit 3627270bdf.

Reverted https://github.com/pytorch/pytorch/pull/156313 on behalf of https://github.com/atalman due to export/test_torchbind.py::TestCompileTorchbind::test_compile_error_on_input_aliasing_contents_backend_aot_eager [GH job link](https://github.com/pytorch/pytorch/actions/runs/15804799771/job/44548489912) [HUD commit link](c95f7fa874) ([comment](https://github.com/pytorch/pytorch/pull/156313#issuecomment-2994171213))
2025-06-22 12:31:57 +00:00
Xuehai Pan
3627270bdf [BE][3/16] fix typos in torch/ (torch/_inductor/) (#156313)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/156313
Approved by: https://github.com/jingsh
2025-06-22 08:43:09 +00:00
Nikita Shulga
bb1f3d1a55 [MPSInductor] Improve _default dtype inference (#156121)
By just adding 'mps' as one of the backend options and fixing reduction op to actually return tuple of CSEVariable's rather than tuple of strings

Test plan: CI

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156121
Approved by: https://github.com/dcci
2025-06-16 23:11:53 +00:00
Nikita Shulga
b6add8c8ba [MPSInductor] Fix remainder implementation for int types (#155891)
Introduce `c10:🤘:remainder` and call it from both inductor and eager implementation, with integer specialization, which should make it much faster than before, while still compliant with Python way of rounding up negative numbers.

This allows one to remove complex type detection logic from mps codegen and rely on Metal(C++) type system to figure out input and output types.

This fixes compilation of something like
```python
@torch.compile
def f(x, y):
    return x[y % 5]
```

which beforehand failed to compile with
```
torch._inductor.exc.InductorError: SyntaxError: failed to compile
    #include <c10/metal/utils.h>
    kernel void generated_kernel(
        device float* out_ptr0,
        constant long* in_ptr0,
        constant float* in_ptr1,
        uint xindex [[thread_position_in_grid]]
    ) {
        int x0 = xindex;
        auto tmp0 = in_ptr0[x0];
        auto tmp1 = 12;
        auto tmp2 = static_cast<float>(tmp0) - static_cast<float>(tmp1) * metal::floor(static_cast<float>(tmp0) / static_cast<float>(tmp1));
        auto tmp3 = 1024;
        auto tmp4 = static_cast<long>(tmp3);
        auto tmp5 = tmp2 + tmp4;
        auto tmp6 = tmp2 < 0;
        auto tmp7 = tmp6 ? tmp5 : tmp2;
        if ((tmp7 < 0) && (tmp7 > 1024)) return;
        auto tmp9 = in_ptr1[tmp7];
        out_ptr0[x0] = static_cast<float>(tmp9);
    }
 with program_source:372:28: error: array subscript is not an integer
        auto tmp9 = in_ptr1[tmp7];
                           ^~~~~
```

This fixes fail_to_compile for GPT2ForSequenceClassification Huggingface model using `transformers==4.44.2`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155891
Approved by: https://github.com/manuelcandales
2025-06-13 16:42:56 +00:00
angelayi
4d3ecefda5 [aoti][mps] Use cpp sym-expr printer (#155646)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/155646
Approved by: https://github.com/desertfire
ghstack dependencies: #155752, #154287, #155582, #155583
2025-06-12 23:33:28 +00:00
angelayi
2e65d72e1e [aoti][mps] Fix int/symint kernel args (#155583)
Integer arguments to mps kernels need to go through a different function, since `aoti_torch_mps_set_arg` only takes a Tensor. So I added a `aoti_torch_mps_set_arg_int`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155583
Approved by: https://github.com/desertfire
ghstack dependencies: #155752, #154287, #155582
2025-06-12 23:33:28 +00:00
angelayi
ffbda61fbe [aoti][mps] Fix dynamic dispatch size (#155582)
In the case where we pass in a symint to the `dispatch` call, the compiler errors, so we need to cast the input to int64_t.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155582
Approved by: https://github.com/malfet
ghstack dependencies: #155752, #154287
2025-06-12 23:33:15 +00:00
angelayi
da50835bde [aoti] Support c10 calls (#155256)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/155256
Approved by: https://github.com/malfet
2025-06-10 00:45:59 +00:00
Scott Wolchok
8e1474d3c6 [inductor] small cleanups in torch/_inductor/codegen/mps.py (#154921)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154921
Approved by: https://github.com/jansel, https://github.com/Skylion007
2025-06-03 20:57:25 +00:00
Nikita Shulga
634ce22601 [MPSInductor] Fix codegen for nested multistage reductions (#154578)
Yet to write a unittest for it, but this fixes codegen for
```
python3 benchmarks/dynamo/torchbench.py --performance --only hf_T5  --backend inductor --inference --devices mps --float16
```

By correctly closing triple nested loop

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154578
Approved by: https://github.com/jansel, https://github.com/dcci
2025-05-29 17:09:25 +00:00
angelayi
26471fc203 [aoti] Initial Metal support (#153959)
An example generated file: P1816629015

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153959
Approved by: https://github.com/malfet, https://github.com/desertfire
ghstack dependencies: #153964
2025-05-23 05:45:35 +00:00
PyTorch MergeBot
47a01f3efb Revert "[aoti] Initial Metal support (#153959)"
This reverts commit 28bcd9eb30.

Reverted https://github.com/pytorch/pytorch/pull/153959 on behalf of https://github.com/angelayi due to previous PR broke frl build ([comment](https://github.com/pytorch/pytorch/pull/153959#issuecomment-2901825315))
2025-05-22 16:17:07 +00:00
Isuru Fernando
f419373dd3 [inductor] lowering for fractional_max_pool3d (#148630)
also a lowering with a reduction for large window_sizes for
fractional_max_pool2d

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148630
Approved by: https://github.com/eellison
2025-05-22 16:06:29 +00:00
angelayi
28bcd9eb30 [aoti] Initial Metal support (#153959)
An example generated file: P1816629015

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153959
Approved by: https://github.com/malfet, https://github.com/desertfire
ghstack dependencies: #153964
2025-05-21 21:55:59 +00:00
Nikita Shulga
58dc80dff6 [MPSInductor] Fix indexing calculation (#153997)
By using `c10:🤘:floor_divie` primitive

Which fixes `test_flip_cat_mps` test, and makes `doctr_reco_predictor` and `doctr_det_predictor` pass accuracy checks (at least locally, scheduled a workflow dispatch to validate it in CI)

Before this change following script generated different compile and eager results
```python
import torch

def foo(unsqueeze, unsqueeze_1):
    cat_1 = torch.ops.aten.cat.default([unsqueeze, unsqueeze_1], 1)
    view = torch.ops.aten.view.default(cat_1, [4])
    slice_5 = torch.ops.aten.slice.Tensor(view, 0, 0, 3)
    rev_1 = torch.ops.aten.flip.default(slice_5, [0])
    return rev_1

if __name__ == "__main__":
    x = torch.arange(1.0, 3.0, device='mps').reshape(2, 1)
    y = torch.arange(5.0, 7.0, device='mps').reshape(2, 1)

    rc, (kernel,) = torch._inductor.utils.run_and_get_kernels(torch.compile(foo), x, y)
    print(kernel)
    print("Compile: ", rc)
    print("Eager: ", foo(x, y))
```
After this change
```
'''
    #include <c10/metal/utils.h>
    kernel void generated_kernel(
        device float* out_ptr0,
        constant float* in_ptr0,
        constant float* in_ptr1,
        uint xindex [[thread_position_in_grid]]
    ) {
        int x0 = xindex;
        auto tmp6 = in_ptr0[1 + (c10:🤘:floor_divide((-1)*x0, 2))];
        auto tmp11 = in_ptr1[1 + (c10:🤘:floor_divide((-1)*x0, 2))];
        auto tmp0 = (2 + ((-1)*x0)) % (2);
        auto tmp1 = static_cast<long>(tmp0);
        auto tmp2 = 0;
        auto tmp3 = tmp1 >= tmp2;
        auto tmp4 = 1;
        auto tmp5 = tmp1 < tmp4;
        auto tmp7 = tmp5 ? tmp6 : 0.0;
        auto tmp8 = tmp1 >= tmp4;
        auto tmp9 = 2;
        auto tmp10 = tmp1 < tmp9;
        auto tmp12 = tmp8 ? tmp11 : 0.0;
        auto tmp13 = tmp5 ? tmp7 : tmp12;
        out_ptr0[x0] = static_cast<float>(tmp13);
    }
'''
Compile:  tensor([2., 5., 1.], device='mps:0')
Eager:  tensor([2., 5., 1.], device='mps:0')
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153997
Approved by: https://github.com/dcci
ghstack dependencies: #153970, #153971
2025-05-21 00:03:46 +00:00
Nikita Shulga
03859242ce [Testing] Fix test_deterministic_... on MPS (#153970)
By decorated emitted kernels with `'''` rather than `"""`

To match regex in `torch._inductor.utils.run_and_get_kernels`
This fixes `test_deterministic_codegen_mps`, `test_deterministic_codegen_on_graph_break_mps` and `test_deterministic_codegen_with_suffix_mps`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153970
Approved by: https://github.com/dcci, https://github.com/jansel
2025-05-20 21:15:14 +00:00
Nikita Shulga
db26aeaec2 [MPSInductor] Support numpy scalars handling (#153598)
By default, numpy computes results in float64 format, but when passed as an argument to MPS function, must be implicitly converted to float32, which naturally occurs in some networks, for example in speech_transformer

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153598
Approved by: https://github.com/cyyever, https://github.com/dcci
ghstack dependencies: #153582
2025-05-15 16:48:25 +00:00
Nikita Shulga
a6c5b59067 [MPSInductor] Fix multistage reduction suffixes (#153362)
By invalidating all variable created during the loop except for the context of iterator_cache, as storage can be done inside reduction loop and clear `IteratorRangeEntry` codegen cache.

Which results in the following kernel for `x / x.sum()` if x size is 2048 and max thread group size is 1024
```metal
[[max_total_threads_per_threadgroup(1024)]]
kernel void generated_kernel(
    device half* out_ptr1,
    constant half* in_ptr0,
    uint2 thread_pos [[thread_position_in_grid]],
    uint2 group_pos [[thread_position_in_threadgroup]]
) {
    auto xindex = thread_pos.x;
    auto r0_index = thread_pos.y;
    threadgroup float tmp_acc_0[32];
    float tmp_acc_1 = 0;
    for(auto r0_0_cnt = 0; r0_0_cnt < 2; ++r0_0_cnt) {
        int r0_0 = 2 * r0_index + r0_0_cnt;
        auto tmp0 = static_cast<float>(in_ptr0[r0_0]);
        tmp_acc_1 += tmp0;
    }
    auto tmp1 = c10:🤘:threadgroup_sum(tmp_acc_0, tmp_acc_1, r0_index * 1, 1024);
    for(auto r0_0_cnt = 0; r0_0_cnt < 2; ++r0_0_cnt) {
        int r0_0 = 2 * r0_index + r0_0_cnt;
        auto tmp2 = static_cast<float>(in_ptr0[r0_0]);
        auto tmp3 = tmp2 / tmp1;
        out_ptr1[r0_0] = static_cast<half>(tmp3);
    }
}
```

Fixes compilation report reported while running `GPUTests.test_pattern_matcher_multi_user_mps` and `GPUTests.test_weight_norm_bwd_mps`

Fixes https://github.com/pytorch/pytorch/issues/152155

Though inductor tests are still failing, need to keep refining the variable invalidation

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153362
Approved by: https://github.com/manuelcandales, https://github.com/dcci, https://github.com/jansel
2025-05-13 03:07:53 +00:00
Nikita Shulga
fe36d7dc44 [MPSInductor] Fix truncdiv implementation (#152788)
For integral dtypes it should be just an alias for division

Fixes `GPUTests.test_div7_mps`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152788
Approved by: https://github.com/dcci, https://github.com/jansel
ghstack dependencies: #152663, #152515, #152737, #152743, #152758
2025-05-05 13:31:51 +00:00
Nikita Shulga
d35e900c74 [MPSInductor] Make sure sizevars are computed (#152436)
Before calling the kernel

This fixes `GPUTests.test_float_repr_dynamic_shapes_mps`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152436
Approved by: https://github.com/dcci
ghstack dependencies: #152363, #152430
2025-04-29 17:53:29 +00:00
Nikita Shulga
835f95490f [MPSInductor] Fix type promotion in _print_Max (#152430)
Run into this problem while re-enabling `test_float_repr_dynamic_shapes`, where `_print_Max` were called for integer and long argument which resulted in the following compilation error
```
error: call to 'max' is ambiguous
        out_ptr0[x0 + x1*metal::max(1, ks0)] = static_cast<float>(tmp26);
                         ^~~~~~~~~~
/System/Library/PrivateFrameworks/GPUCompiler.framework/Versions/32023/Libraries/lib/clang/32023.619/include/metal/metal_integer:2477:16: note: candidate function
METAL_FUNC int max(int x, int y)
               ^
/System/Library/PrivateFrameworks/GPUCompiler.framework/Versions/32023/Libraries/lib/clang/32023.619/include/metal/metal_integer:3686:17: note: candidate function
METAL_FUNC long max(long x, long y)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152430
Approved by: https://github.com/dcci
ghstack dependencies: #152363
2025-04-29 17:53:29 +00:00
Nikita Shulga
9c7b902cb2 [MPSInductor][BE] Make all reductions cacheable (#152363)
By moving actual implementaiton to `_reduction_nocache` and make reduction a caching wrapper

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152363
Approved by: https://github.com/dcci
2025-04-29 02:49:22 +00:00
Nikita Shulga
cbcc03c2ad [MPSInductor][BE] Only include headers when needed (#152266)
Store headers used by shader in `MetalKernel.headers`
Add headers when function depending on it gets invoked
Generate majority of a special ops from template
Delete two unused functors: `entr` and `xlog1py` as they are decomposed by inductor anyway

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152266
Approved by: https://github.com/Skylion007, https://github.com/jansel, https://github.com/dcci, https://github.com/cyyever
2025-04-27 05:09:50 +00:00
Nikita Shulga
015b526a2a [MPSInductor] Warn-cast double as floats (#151963)
To support sqrt over dynamic shapes, i.e. make something like:
```python
torch.compile(dynamic=True)(lambda x: x * math.sqrt(x.size(0))
```
compilable into
```metal
// Source node to ATen node mapping:
// Graph fragment:
//   %scalar_tensor_default : [num_users=1] = call_function[target=torch.ops.aten.scalar_tensor.default](args = (%arg0_1,), kwargs = {})
//   %convert_element_type_default : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%scalar_tensor_default, torch.float64), kwargs = {})
//   %sqrt_default : [num_users=1] = call_function[target=torch.ops.aten.sqrt.default](args = (%convert_element_type_default,), kwargs = {})
//   %convert_element_type_default_1 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sqrt_default, torch.float32), kwargs = {})
//   %mul_tensor : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%arg1_1, %convert_element_type_default_1), kwargs = {})
 kernel void generated_kernel(
     device float* out_ptr0,
     constant float* in_ptr0,
     constant long& ks0,
     uint xindex [[thread_position_in_grid]]
 ) {
     int x0 = xindex;
     auto tmp0 = in_ptr0[x0];
     auto tmp1 = ks0;
     auto tmp2 = static_cast<float>(tmp1);
     auto tmp3 = metal::sqrt(tmp2);
     auto tmp4 = static_cast<float>(tmp3);
     auto tmp5 = tmp0 * tmp4;
     out_ptr0[x0] = static_cast<float>(tmp5);
 }
```

TODO:
 - Figure out if this could be tweaked in fx-passes, but overhead is probably too high

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151963
Approved by: https://github.com/dcci
ghstack dependencies: #151869, #151871, #151872
2025-04-23 00:30:45 +00:00
Davide Italiano
49b7ffbb15 [MPS] Implement _print_Trunc_to_Int (#151964)
Fixes `test_device_assert_mps`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151964
Approved by: https://github.com/malfet

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
2025-04-23 00:30:00 +00:00
Nikita Shulga
2f851ac8f8 [MPSInductor] Implement atomic_add store mode (#151871)
Which fixes `GPUTests.test_index_put2_mps`, `GPUTests. test__unsafe_masked_index_put_accumulate_mps` and dozen of scatter/gather tests that relied on atomic_add store mode

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151871
Approved by: https://github.com/jansel, https://github.com/dcci
ghstack dependencies: #151869
2025-04-22 22:00:16 +00:00
Davide Italiano
470132c6a1 [MPS] Add support for hermite_polynomial_he (inductor/eager). (#151754)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/151754
Approved by: https://github.com/malfet, https://github.com/jansel
2025-04-20 17:44:40 +00:00
Nikita Shulga
0c77af3576 [MPSInductor] Add pow, log2 and FloorToInt ops (#151449)
That enables `test_pow_by_natural_log2_dynamic_shapes_mps`

Not sure why log2 printer function suffix is `OpaqueUnaryFn_log2`, rather than just `log2`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151449
Approved by: https://github.com/jansel
2025-04-16 15:56:21 +00:00
Nikita Shulga
070357b61a [MPSInductor] Fix silent correctness in bitcast (#151272)
By using Metal `as_type` which according to documentation does exactly
that:
> Metal adds an as_type<type-id> operator to allow any scalar or vector data type (that is not
a pointer) to be reinterpreted as another scalar or vector data type of the same size. The bits in
the operand are returned directly without modification as the new type. The usual type
promotion for function arguments is not performed.

Using `reinterpret_cast` created a potential silent correctness error when dtypes of different sizes were bitcast to each other
Add expicit cast to src_type to avoid errors due to type promotion (i.e.
soemthing like (x+1).view(dtype=torch.float16) would work correctly in
eager mode for int16 dtype, but would fail in compile, as arithmetic
operations will promote int16 to int32

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151272
Approved by: https://github.com/dcci
ghstack dependencies: #151224, #151246
2025-04-14 23:39:42 +00:00
Nikita Shulga
46ce8f7df6 [MPSInductor] Cast halfs to floats (#151246)
To avoid accuracy issues when small reductions are unrolled, cast half to float during the `load` op
As `op_math_t<half>` is indeed float

This fixes `test_unroll_small_reduction` for reduced precision types

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151246
Approved by: https://github.com/dcci
ghstack dependencies: #151224
2025-04-14 19:47:04 +00:00
Nikita Shulga
9699cc3eb9 [MPSInductor] Fix larger-than-threadgroup Welford reductions (#151152)
By using `welford_combine` primitive in the loop
This fixes `GPUTests.test_multilayer_var_lowp_mps`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151152
Approved by: https://github.com/jansel
ghstack dependencies: #151042, #150824, #151151
2025-04-12 21:44:51 +00:00
PyTorch MergeBot
7762bddd87 Revert "[MPSInductor] Fix larger-than-threadgroup Welford reductions (#151152)"
This reverts commit 71073caa00.

Reverted https://github.com/pytorch/pytorch/pull/151152 on behalf of https://github.com/malfet due to Another lint failure ([comment](https://github.com/pytorch/pytorch/pull/151152#issuecomment-2799027274))
2025-04-12 20:27:48 +00:00