Commit Graph

23 Commits

Author SHA1 Message Date
Nikita Shulga
e2a5c42e7e [BE][MPS] Build metal kernels of MacOS-14+ (#159733)
Which makes `#if __METAL_VERSION__ >= 310` guards for `bfloat` use support unnecessary.
Rename `kernels_bfloat.metallib` into `kernels_basic` and remove custom build/selection logic.

Part of https://github.com/pytorch/pytorch/issues/159275
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159733
Approved by: https://github.com/dcci
ghstack dependencies: #159731, #159732
2025-08-03 20:53:58 +00:00
Nikita Shulga
f946b25865 [MPS] Speedup argmax/argmin (#159524)
By using efficient `threadgroup_arg[max|min]` primitives.
- Fixed bug in `simd_argmax` when result of the `simd_ballot` were prematurely cast to `ushort` and adjusted unit test
- Fixed nan handling in compiled argmax, but can't reliably test it as MPS(eager) implementaiton of argmax is buggy

Now according to `bench_mps_ops.py` `max(x, dim=0)` is reliably faster than eager implementaiton:
```
[---------------------------------------------------------------------------------------------  --------------------------------------------------------------------------------------------]
                           |  eager-512x512  |  compile-512x512  |  eager-1024x1024  |  compile-1024x1024  |  eager-2048x2048  |  compile-2048x2048  |  eager-4096x4096  |  compile-4096x4096
1 threads: ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
      max (torch.float16)  |      285.8      |       272.2       |       422.3       |        354.5        |       721.6       |        683.5        |       2224.0      |        1979.1
      max (torch.float32)  |      300.2      |       267.0       |       389.6       |        342.5        |       769.4       |        682.6        |       2995.7      |        2609.8
      max (torch.int32)    |      299.6      |       275.4       |       390.0       |        361.7        |       758.7       |        686.1        |       3103.4      |        2646.5
      max (torch.int64)    |      297.5      |       275.5       |       417.0       |        382.1        |       856.1       |        722.6        |       5467.7      |        3156.8

```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159524
Approved by: https://github.com/Skylion007, https://github.com/dcci
ghstack dependencies: #158990
2025-07-31 16:18:32 +00:00
Nikita Shulga
1293405c8d [MPS] Add simd_[arg][max|min] (#158990)
And add eager tests for those.
Re-implement `threadgroup_[max|min]` using those function as they are significantly faster (though much slower than eager, due to the arg part) than before, which could be verified by running the following script
```python
import itertools
import timeit
import torch
from torch.utils.benchmark import Compare, Measurement, Timer

def bench_unary_op(func, x, label) -> Measurement:
    sync_cmd = "torch.mps.synchronize()" if "mps" in str(x.device) else ""
    t = Timer(
        stmt=f"f(x);{sync_cmd}",
        globals={"f": func, "x": x},
        language="python",
        timer=timeit.default_timer,
        sub_label=f"{func.__name__} ({str(x.dtype)})",
        description=label,
        env=torch.__version__,
    )
    return t.blocked_autorange()

def bench_reduction(
    reduction_func, device: str = "mps", dtype: torch.dtype = torch.float32
) -> list[Measurement]:
    rc = []

    # Bench 2D with reduction over dim=0
    def f(t):
        return reduction_func(t, dim=0)[0]

    f.__name__ = reduction_func.__name__
    f_c = torch.compile(f, dynamic=False, fullgraph=True)

    for size in (512, 1024, 2048, 4096):
        x = torch.testing.make_tensor(size, size, device=device, dtype=dtype)
        rc_c, rc_e = f(x), f_c(x)
        rc_c, rc_e = (rc_c[0], rc_e[0]) if isinstance(rc_c, tuple) else (rc_c, rc_e)
        rc.append(bench_unary_op(f, x, f"eager-{size}x{size}"))
        rc.append(bench_unary_op(f_c, x, f"compile-{size}x{size}"))
    return rc

def main() -> None:
    #dtypes = [torch.float16, torch.float32, torch.bfloat16, torch.int32, torch.int64]
    dtypes = [torch.float32, torch.int32, torch.int64]

    # Profile reduction ops
    rc = []
    for op, dtype in itertools.product([torch.max], dtypes):
        rc.extend(bench_reduction(op, dtype=dtype))
    Compare(rc).print()

if __name__ == "__main__":
    torch._dynamo.config.cache_size_limit = 2**16
    main()
```

Produces the following table before
```
[---------------------------------------------------------------------------------------------  --------------------------------------------------------------------------------------------]
                           |  eager-512x512  |  compile-512x512  |  eager-1024x1024  |  compile-1024x1024  |  eager-2048x2048  |  compile-2048x2048  |  eager-4096x4096  |  compile-4096x4096
1 threads: ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
      max (torch.float32)  |      297.3      |       531.6       |       394.1       |        2550.5       |       773.0       |        4904.7       |       3647.2      |        9682.0
      max (torch.int32)    |      297.8      |       359.2       |       387.7       |        1179.4       |       768.2       |        2175.0       |       3677.1      |        4495.9
      max (torch.int64)    |      278.7      |       541.4       |       410.2       |        2873.3       |       858.9       |        5620.4       |       6107.2      |       11176.1

Times are in microseconds (us).
```
And after
```
[---------------------------------------------------------------------------------------------  --------------------------------------------------------------------------------------------]
                           |  eager-512x512  |  compile-512x512  |  eager-1024x1024  |  compile-1024x1024  |  eager-2048x2048  |  compile-2048x2048  |  eager-4096x4096  |  compile-4096x4096
1 threads: ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
      max (torch.float32)  |      307.9      |       265.3       |       401.0       |        340.8        |       766.5       |        661.9        |       3463.5      |        2829.5
      max (torch.int32)    |      293.5      |       263.1       |       405.0       |        338.8        |       761.4       |        672.5        |       3050.0      |        2688.6
      max (torch.int64)    |      308.2      |       255.7       |       417.4       |        341.4        |       877.0       |        695.0        |       5812.2      |        5762.2

```

`argmax`/`argmin` are much tricker due to the nan-handling logic that need to be added there.

Also fixes `torch.max/min` compilation for half-precision types, added regression types for it.

This PR also introduces a bunch of helper functions, such as `simd_broadcast` that works for int64 and `c10:🤘:pair` template, which are used by `simd_argmax` to return both value and index

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158990
Approved by: https://github.com/dcci, https://github.com/Skylion007
2025-07-30 21:57:25 +00:00
Nikita Shulga
39a8f66d59 [BE] Use simdgroup_size constexpr (#157751)
Instead of every shader defining it separately, move it to `c10/metal/common.h`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/157751
Approved by: https://github.com/Skylion007, https://github.com/dcci
ghstack dependencies: #157746
2025-07-08 03:46:20 +00:00
Nikita Shulga
9699cc3eb9 [MPSInductor] Fix larger-than-threadgroup Welford reductions (#151152)
By using `welford_combine` primitive in the loop
This fixes `GPUTests.test_multilayer_var_lowp_mps`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151152
Approved by: https://github.com/jansel
ghstack dependencies: #151042, #150824, #151151
2025-04-12 21:44:51 +00:00
PyTorch MergeBot
7762bddd87 Revert "[MPSInductor] Fix larger-than-threadgroup Welford reductions (#151152)"
This reverts commit 71073caa00.

Reverted https://github.com/pytorch/pytorch/pull/151152 on behalf of https://github.com/malfet due to Another lint failure ([comment](https://github.com/pytorch/pytorch/pull/151152#issuecomment-2799027274))
2025-04-12 20:27:48 +00:00
Nikita Shulga
71073caa00 [MPSInductor] Fix larger-than-threadgroup Welford reductions (#151152)
By using `welford_combine` primitive in the loop
This fixes `GPUTests.test_multilayer_var_lowp_mps`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151152
Approved by: https://github.com/jansel
ghstack dependencies: #151042, #150824, #151151
2025-04-12 19:16:33 +00:00
Nikita Shulga
397d37acc5 [MPSInductor] Naive welford_reduce implementation (#150824)
Literal Python-to-Metal translation of
85549fe6de/torch/_inductor/runtime/triton_helpers.py (L217-L225)

Fixed missing barrier in `welford_combine`
And this is sufficient to make `GPUTests.test_batch_norm_2d_2_mps` to pass

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150824
Approved by: https://github.com/dcci, https://github.com/jansel
ghstack dependencies: #151042
2025-04-12 03:11:38 +00:00
PyTorch MergeBot
77407b38a9 Revert "[MPSInductor] Naive welford_reduce implementation (#150824)"
This reverts commit 575f348965.

Reverted https://github.com/pytorch/pytorch/pull/150824 on behalf of https://github.com/malfet due to Linter fails again, landrace this time? ([comment](https://github.com/pytorch/pytorch/pull/150824#issuecomment-2798392241))
2025-04-12 02:22:22 +00:00
Nikita Shulga
575f348965 [MPSInductor] Naive welford_reduce implementation (#150824)
Literal Python-to-Metal translation of
85549fe6de/torch/_inductor/runtime/triton_helpers.py (L217-L225)

Fixed missing barrier in `welford_combine`
And this is sufficient to make `GPUTests.test_batch_norm_2d_2_mps` to pass

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150824
Approved by: https://github.com/dcci, https://github.com/jansel
ghstack dependencies: #151042
2025-04-12 00:46:01 +00:00
PyTorch MergeBot
83f14c0b06 Revert "[MPSInductor] Naive welford_reduce implementation (#150824)"
This reverts commit 5edfb4c4fa.

Reverted https://github.com/pytorch/pytorch/pull/150824 on behalf of https://github.com/malfet due to I should have waited for lint ([comment](https://github.com/pytorch/pytorch/pull/150824#issuecomment-2798249264))
2025-04-12 00:21:14 +00:00
Nikita Shulga
5edfb4c4fa [MPSInductor] Naive welford_reduce implementation (#150824)
Literal Python-to-Metal translation of
85549fe6de/torch/_inductor/runtime/triton_helpers.py (L217-L225)

Fixed missing barrier in `welford_combine`
And this is sufficient to make `GPUTests.test_batch_norm_2d_2_mps` to pass

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150824
Approved by: https://github.com/dcci, https://github.com/jansel
ghstack dependencies: #151042
2025-04-11 23:21:35 +00:00
Nikita Shulga
7ac8186851 [MPSInductor] Speedup sum/prod reductions (#150566)
By using cooperative `simd_sum`/`simd_product` instead of a C-style for loop for threadgroup reductions. This also allows significantly reduce amount of shared memory needed to perform those reductions

Using such reduction increases the `torch.compile` performance for gpt-fast using `stories110M` from 29 tokens/sec to 630 tokens/sec on M4 and changes perf of torch.rand as follows:
|size| before | after |
|------------------------|------------|-------------|
| 512x512         | 202.1       | 131.8       |
| 1024x1024   |   780.6    | 176.9       |
| 2048x2048    |   1423.4       | 339.9      |
| 4096x4097    |    2982.2 | 1047.2      |

Unfortunately, none of the SIMDgroup operations are available for 64-bit integers, but one can simulate the behavior using using `simd_shuffle_down` of 64-bit values represented as `int2` types, that yields reduction in $log_2(threadgroup\\_size)$ steps. [`mlx/kernels/reduction/ops.h](86389bf970/mlx/backend/metal/kernels/reduction/ops.h (L15-L18)) contains an implementation of such algorithm, but alas it yields wrong results on M1/M2(and may be M3 machines) if not all threads in the simdgroup are active which could be observed by running
```python
import torch
lib=torch.mps.compile_shader("""
kernel void do_sum(device int* out, constant int* in, uint idx [[thread_position_in_grid]]) {
  out[idx] = metal::simd_shuffle_down(in[idx], 8);
}
""")
x=torch.arange(22, device='mps', dtype=torch.int32)
y=torch.empty_like(x)
lib.do_sum(y, x)
print(y)
```
that returns following on M4
```
tensor([ 8,  9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21,  0,  0,  0,  0, 0,  0,  0,  0], device='mps:0', dtype=torch.int32)
```
but same kernel running on M1 returns
```
tensor([ 8,  9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 14, 15, 16, 17, 18, 19, 20, 21], device='mps:0', dtype=torch.int32)
```
This discrepancy in behavior can be addressed by using `simd_shuffle_and_fill_down`, but any kernels using simd_shuffle_and_fill_down cause an internal compiler error on MacOS-13.2. Considering that OS is to be EOL soon, skip the offending tests.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150566
Approved by: https://github.com/manuelcandales
ghstack dependencies: #150452, #150457
2025-04-05 02:47:27 +00:00
Nikita Shulga
52135db69a [BE] Fix signed/unsigned comparison warning (#150246)
One will see them only if compilation fails, but still
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150246
Approved by: https://github.com/cyyever, https://github.com/jansel
2025-03-29 15:12:42 +00:00
Nikita Shulga
61a64c20c4 [MPSInductor] Move threadfence at the right location (#149437)
Not sure how it worked in the past, but fence should be before first read from the shared memory, not after it.
This bug was exposed by https://github.com/pytorch/pytorch/pull/148969 which removed unnecessary barrier before calling `threadgroup_reduce` functions
Test plan:
```
% python3 generate.py --checkpoint_path checkpoints/stories15M/model.pth --prompt "Once upon a time" --device mps --compile
```
Before that it produced gibberish, now it works fine
Pull Request resolved: https://github.com/pytorch/pytorch/pull/149437
Approved by: https://github.com/manuelcandales, https://github.com/dcci
2025-03-18 23:27:19 +00:00
Nikita Shulga
758522d56a [MPSInductor][EZ] Fix argmin/max signatures (#149020)
threadgroup_argmin used to return input type, which is wrong, it should have returned `int` or `long`

Change signatures of both thredgroup_argmin and threadgroup_argmax to return int, as group size is small, no need to carry over large integeres
Pull Request resolved: https://github.com/pytorch/pytorch/pull/149020
Approved by: https://github.com/jansel
ghstack dependencies: #148969, #148975, #149004
2025-03-12 04:39:29 +00:00
Nikita Shulga
2328dcccb9 [MPSInductor] Implement Welford reduction (#146703)
Still work in progress, though fallback works as expected, but custom shader is not

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146703
Approved by: https://github.com/jansel, https://github.com/dcci
2025-02-08 05:00:00 +00:00
Nikita Shulga
495049860b [BE][Metal] Fix signed unsigned comparison warning (#146549)
I wish I knew how to extract Metal warnings during JIT compilation but https://developer.apple.com/documentation/metal/mtldevice/makelibrary(source:options:)?changes=_7&language=objc is a lie as `error:` stays `nil` unless shader compilation fails. But when it does following warnings are thrown
```
program_source:666:26: warning: comparison of integers of different signs: 'int' and 'unsigned int' [-Wsign-compare]
  for (auto idx = 1; idx < size; ++idx) {
                     ~~~ ^ ~~~~
program_source:677:26: warning: comparison of integers of different signs: 'int' and 'unsigned int' [-Wsign-compare]
  for (auto idx = 1; idx < size; ++idx) {
                     ~~~ ^ ~~~~
program_source:688:26: warning: comparison of integers of different signs: 'int' and 'unsigned int' [-Wsign-compare]
  for (auto idx = 1; idx < size; ++idx) {
                     ~~~ ^ ~~~~
program_source:699:26: warning: comparison of integers of different signs: 'int' and 'unsigned int' [-Wsign-compare]
  for (auto idx = 1; idx < size; ++idx) {
                     ~~~ ^ ~~~~
program_source:710:26: warning: comparison of integers of different signs: 'int' and 'unsigned int' [-Wsign-compare]
  for (auto idx = 1; idx < size; ++idx) {
                     ~~~ ^ ~~~~
program_source:723:26: warning: comparison of integers of different signs: 'int' and 'unsigned int' [-Wsign-compare]
  for (auto idx = 1; idx < size; ++idx) {

```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146549
Approved by: https://github.com/dcci
2025-02-06 00:40:17 +00:00
Nikita Shulga
3525b834f0 [MPSInductor] Implement argmax/argmin (#146429)
TODOs:
 - Find test with NaN
 - Report internal compiler error when running `test_argmax_argmin1` (which is actually not enough shared memory)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146429
Approved by: https://github.com/dcci
ghstack dependencies: #146423, #146428
2025-02-04 19:16:06 +00:00
Nikita Shulga
7d60235aa6 [Metal] Small speedup for sum/prod (#146428)
As they can not really be invoked over empty arrays
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146428
Approved by: https://github.com/Skylion007, https://github.com/dcci
ghstack dependencies: #146423
2025-02-04 19:10:33 +00:00
Nikita Shulga
5d81bc3696 [MPSInductor] Implement prod reduction (#146396)
Mostly reusing `sum` reduction logic

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146396
Approved by: https://github.com/dcci
ghstack dependencies: #146369, #146370, #146380, #146389
2025-02-04 14:08:04 +00:00
Nikita Shulga
bbe95341d9 [MPSInductor] Implement min and max reductions (#146389)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146389
Approved by: https://github.com/jansel, https://github.com/dcci
ghstack dependencies: #146369, #146370, #146380
2025-02-04 14:04:10 +00:00
Nikita Shulga
54ceb7c565 [MPSInductor] Add support for sum reduction (#146380)
- Add `threadgroup_sum` template to `c10/metal/reduction_utils.h` that so far uses barrier to compute the reductions

TODOs:
 - Implement efficient reduction using cooperative functions such as `simd_shuffle_down`
 - Figure out how to merge several sum reduction together
 - Implement `reduction_store` that will only write results from the first thread

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146380
Approved by: https://github.com/jansel, https://github.com/dcci
ghstack dependencies: #146369, #146370
2025-02-04 06:23:44 +00:00