Commit Graph

643 Commits

Author SHA1 Message Date
Nikita Shulga
2e0c98ff05 [MPS] Add bicubic2d_aa (#149378)
Which is currently the most frequently requested op in https://github.com/pytorch/pytorch/issues/141287

Mostly done by refactoring `upsample_bilinear2d_aa` to accept Functor as one of the template arguments, which closely ideas from eec43cfbc0/src/libImaging/Resample.c as well as
bb42e4d137/aten/src/ATen/native/cuda/UpSampleBilinear2d.cu (L472-L478)

Populate unit tests by copying upsample_bilinear_2d_aa and reusing it as upsample_bicubic2d_aa

At that point, only difference between upsample_bilinear2d_aa and upsample_bicubic2d_aa are convolution kernel function and size: for bilinear it's 3x3, for bicubic it's 5x5
Pull Request resolved: https://github.com/pytorch/pytorch/pull/149378
Approved by: https://github.com/dcci
2025-03-18 05:35:41 +00:00
Davide Italiano
c43e35d6f7 [MPS] Implement support for modified_bessel_i1 in eager. (#149368)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/149368
Approved by: https://github.com/malfet

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
2025-03-18 03:29:10 +00:00
Davide Italiano
186cc7327c [MPS/BE] Remove decorator that skipped test on macOS 12. (#149365)
macOS 12 is not really supported anymore.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149365
Approved by: https://github.com/malfet
2025-03-18 00:58:08 +00:00
Davide Italiano
9f33c6f0a0 [MPS] Add support for modified_bessel_i0 in eager. (#149264)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/149264
Approved by: https://github.com/malfet

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
2025-03-16 04:45:49 +00:00
Nikita Shulga
96795e9533 [BE] Parametrize TestMPS.test_binops_dtype_precedence (#149234)
No op change, just splits a longer tests into a series of a smaller ones
Pull Request resolved: https://github.com/pytorch/pytorch/pull/149234
Approved by: https://github.com/atalman, https://github.com/dcci
ghstack dependencies: #149216, #149233
2025-03-15 00:37:11 +00:00
Isalia20
dd6e9df3d0 [MPS] fix attention enable_gqa crash on mps (#149147)
Fixes #149132

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149147
Approved by: https://github.com/malfet
2025-03-14 21:25:54 +00:00
Nikita Shulga
f2221b2fce [MPS] Add support for i1e (#149203)
Followup after https://github.com/pytorch/pytorch/pull/149174
Pull Request resolved: https://github.com/pytorch/pytorch/pull/149203
Approved by: https://github.com/dcci
2025-03-14 17:33:52 +00:00
cyy
a9aae05a6b Remove test decorations on MacOS 12 (#148942)
MacOS 12 may reach EOL, as from https://endoflife.date/macos
Pull Request resolved: https://github.com/pytorch/pytorch/pull/148942
Approved by: https://github.com/malfet
2025-03-14 17:22:37 +00:00
Davide Italiano
706c22549c [MPS] Add support for i0e in eager. (#149174)
Add `special.i0e` to XFAIL_GRADLIST for now, as its backward op is not yet implemented
Pull Request resolved: https://github.com/pytorch/pytorch/pull/149174
Approved by: https://github.com/malfet

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
2025-03-14 14:43:46 +00:00
PyTorch MergeBot
be4e6c1c8e Revert "[MPS] Add support for i0e in eager. (#149174)"
This reverts commit b4745db904.

Reverted https://github.com/pytorch/pytorch/pull/149174 on behalf of https://github.com/malfet due to MPS are red on trunk ([comment](https://github.com/pytorch/pytorch/pull/149174#issuecomment-2723774600))
2025-03-14 06:35:01 +00:00
Nikita Shulga
db6d72213b [MPS] Add torch.special.bessel_[jy][01] implementations (#149123)
By copy-n-pasting functions from
f59064f2b7/aten/src/ATen/native/cuda/Math.cuh (L1463)

With an  ugly workaround for `bessel_y[01]` to avoid internal compiler exception on M1/M2 machines (see FB16863363 /  https://gist.github.com/malfet/e7785e4b572e7740887a83a2386ef769 )
Pull Request resolved: https://github.com/pytorch/pytorch/pull/149123
Approved by: https://github.com/Skylion007, https://github.com/dcci
2025-03-14 05:13:55 +00:00
Davide Italiano
b4745db904 [MPS] Add support for i0e in eager. (#149174)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/149174
Approved by: https://github.com/malfet
2025-03-14 02:51:28 +00:00
Nikita Shulga
924a247fbb [MPS] Enable angle and atan2 for torch.long (#149017)
This check was added by https://github.com/pytorch/pytorch/pull/85817, that introduced no unit-tests and its content seems to be totally unrelated to title/subject of that PR. Anyway, right now it seems to be working fine on MacOS-13+

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149017
Approved by: https://github.com/dcci
2025-03-12 04:48:52 +00:00
Nikita Shulga
c18858d633 [MPS] Make torch.mps.compile_shader public (#148972)
It was a private method in 2.6, but nothin changes in its API for 2.7
and it will likely remain the same in 2.8, so time to remove underscore
from its name

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148972
Approved by: https://github.com/Skylion007, https://github.com/atalman, https://github.com/seemethere, https://github.com/albanD, https://github.com/dcci
2025-03-11 20:20:58 +00:00
Nikita Shulga
b95889042c [MPS] Introduce strides unary op (#148468)
By adding following template
```metal
template <typename T, typename F>
kernel void unary_strided(
    device result_of<F, T>* output [[buffer(0)]],
    constant T* input [[buffer(1)]],
    constant long* sizes [[buffer(2)]],
    constant long* input_strides [[buffer(3)]],
    constant long* output_strides [[buffer(4)]],
    constant uint& ndim,
    uint index [[thread_position_in_grid]]) {
  F f;
  int pos[max_ndim];
  pos_from_thread_index(int(index), pos, sizes, ndim);
  const auto input_offs = offset_from_coord(pos, input_strides, ndim);
  const auto output_offs = offset_from_coord(pos, output_strides, ndim);
  output[output_offs] = f(input[input_offs]);
}
```
and instantiating it for all existing unary shaders, which eliminates the need to any intermediate copies.
No extra testing are needed as those cases are already covered by `test_output_grad_match_corrcoef_cpu_float32` as well as `test_unary_ops_storage_offset_strided`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/148468
Approved by: https://github.com/dcci
2025-03-09 22:30:51 +00:00
Nikita Shulga
da923afdc7 [MPS][BE] Align bitshift behavior with CPU (#148719)
By casting the argument to output type
Pull Request resolved: https://github.com/pytorch/pytorch/pull/148719
Approved by: https://github.com/Skylion007
ghstack dependencies: #148685, #148686
2025-03-07 18:28:14 +00:00
Nikita Shulga
f84710aef4 [MPS] Fix scalar to tensors bitshifts (#148686)
By introducing a concept of non-commutative binary op and renaming all op templates from `bitwise_foo_tensor` and `bitwise_foo_scalar` to `bitwise_foo_tensor_tensor` and `bitwise_foo_tensor_scalar`

Add regression tests

Please note, that for some undefined values MPS and CPU behaviors are different, for example
```
>>> import torch
>>> 4095 >> torch.arange(12, device="mps", dtype=torch.uint8)
tensor([255, 255, 255, 255, 255, 127,  63,  31,  15,   7,   3,   1],
       device='mps:0', dtype=torch.uint8)
>>> 4095 >> torch.arange(12, device="cpu", dtype=torch.uint8)
tensor([255, 127,  63,  31,  15,   7,   3,   1,   0,   0,   0,   0],
       dtype=torch.uint8)
```
Because on CPU scalar is cast to output dtype before operation is performed, but on MPS this happens after the op is done

Fixes https://github.com/pytorch/pytorch/issues/147889
Pull Request resolved: https://github.com/pytorch/pytorch/pull/148686
Approved by: https://github.com/albanD
ghstack dependencies: #148685
2025-03-07 18:28:14 +00:00
Isalia20
02e1580e39 [MPS] fix crash for mse loss with 0 numel inputs (#148608)
Fixes #148589

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148608
Approved by: https://github.com/malfet
2025-03-06 03:32:34 +00:00
Nikita Shulga
864b75dd50 [MPS] Fix unary_kernel_strided logic (#148512)
Fixes bug introduced by https://github.com/pytorch/pytorch/pull/148350
Before this change
```
% python3 -c "import torch; x, y = torch.arange(128.0, device='mps').reshape(2, 8, 8).unbind(0); print(torch.sqrt(x[::2, ::2], out=y[::2, ::2]))"
tensor([[  0.0000,   1.4142,   2.0000,   2.4495],
        [ 80.0000,  82.0000,  84.0000,  86.0000],
        [ 96.0000,  98.0000, 100.0000, 102.0000],
        [112.0000, 114.0000, 116.0000, 118.0000]], device='mps:0')
```
After this change
```
% python3 -c "import torch; x, y = torch.arange(128.0, device='mps').reshape(2, 8, 8).unbind(0); print(torch.sqrt(x[::2, ::2], out=y[::2, ::2]))"
tensor([[0.0000, 1.4142, 2.0000, 2.4495],
        [4.0000, 4.2426, 4.4721, 4.6904],
        [5.6569, 5.8310, 6.0000, 6.1644],
        [6.9282, 7.0711, 7.2111, 7.3485]], device='mps:0')
```
One can not avoid copies if both input and output tensors have the same strides, one needs to make sure that they are dense-in-storage (transposed tensor would be dense, but say selecting every odd and even column wouldn't)

Add regression test to prevent those from happening again

Also, no need to check that sizes match, luckily it is checked by the structured op (and `out` for unary ops does not support broadcasting, I just checked)

Revived needs_copy_logic, though it  will become irrelevant after https://github.com/pytorch/pytorch/pull/148468 is landed

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148512
Approved by: https://github.com/janeyx99
2025-03-05 15:57:54 +00:00
Isalia20
0c0a4baddd [MPS] unary kernels - avoid copying tensors if they have same stride (#148350)
I was a bit concerned when I saw in #148272 that metal unary kernel was 0.02x of the performance of what we had with MPS Graphs for sqrt(for non contiguous) tensors. This change makes it so that copying is only done if we don't have same strided tensors(for input/output). So if out tensor is not provided then we don't do copy(don't call contiguous) at all and dispatch the kernel as is. After making this change the script that I listed at the end of the above PR has the same execution time as the non-transposed one.

Times for reference(on transposed tensor where matrix is NxN matrix):

| N     | time_old           | time_new           |
|-------|--------------------|--------------------|
| 100   | 0.0002241021       | 0.0001548659       |
| 1000  | 0.0005934822       | 0.0002150342       |
| 10000 | 0.3242016407       | 0.0045755033       |

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148350
Approved by: https://github.com/janeyx99
2025-03-04 23:20:26 +00:00
Isalia20
439395c0ae [MPS] add slogdet and logdet implementations to mps (#148287)
Low hanging fruits, all ops for these are implemented so just adding them to native functions adds the functionality on mps. Probably next op I should add should be lu solve seeing as how many ops need it for the grad calculation

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148287
Approved by: https://github.com/malfet
2025-03-04 19:49:23 +00:00
Nikita Shulga
84502baaff [MPS] Fix sqrt and other for torch.chalf (#148285)
Those kernels, instead of being instantiated for half2 (which corresponds to ComplexHalf) were instnatiated for short2, which resuled in the following test
```
% python3 -c "import torch; print(torch.rand(6, device='mps', dtype=torch.chalf).sqrt())"
```
Fail with
```
RuntimeError: Failed to create function state object for: sqrt_complex_half_half
```

As sqrt is not implemented for CPU, add explicit test to `test_sqrt`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/148285
Approved by: https://github.com/dcci
2025-03-03 16:03:54 +00:00
Isalia20
19de523de6 [MPS] metal unary kernel for sqrt (#148272)
Issue #148219 highlighted the high dispatch times of ops which ran with MPS Graph on smaller tensors. This PR rewrites the sqrt with metal kernel to mitigate that issue

## Speedups:

Matrix size means NxN matrix here.
![speedup_sqrt](https://github.com/user-attachments/assets/db0a705b-1a0e-42b4-bd42-4e7960415c81)

Code to generate the times(needs building the torch with old time and new time):
```python
import torch
import numpy as np
import time
import csv

matrix_sizes = [1, 100, 1000, 10_000]
num_runs = 1000
warmup_runs = 3

def run_sqrt(A):
    torch.mps.synchronize()
    start = time.perf_counter()
    c = torch.sqrt(A)
    torch.mps.synchronize()
    end = time.perf_counter()
    return c, end - start

results = {
    'N': [],
    'mean_time': [],
    'std_time': []
}

for n in matrix_sizes:
    print(f"\nBenchmarking N={n}")

    try:
        A_mps = torch.rand((n, n), dtype=torch.float32, device="mps")

        for _ in range(warmup_runs):
            _, _ = run_sqrt(A_mps)

        times = []
        for _ in range(num_runs):
            _, t = run_sqrt(A_mps)
            times.append(t)

        mean_time = np.mean(times)
        std_time = np.std(times)

        results['N'].append(n)
        results['mean_time'].append(mean_time)
        results['std_time'].append(std_time)

        print(f"Mean time: {mean_time:.4f}s ± {std_time:.4f}s")

    except RuntimeError as e:
        print(f"Error for N={n}: {e}")
        continue

with open('sqrt_benchmark_times_new.csv', 'w', newline='') as f:
    writer = csv.writer(f)
    writer.writerow(['N', 'mean_time', 'std_time'])
    for i in range(len(results['N'])):
        writer.writerow([
            results['N'][i],
            results['mean_time'][i],
            results['std_time'][i]
        ])

```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148272
Approved by: https://github.com/malfet
2025-03-02 00:45:45 +00:00
Nikita Shulga
3a0c9f7f9d [MPS] Fix SDPA crash (#148239)
If operation is invoked with mask twice it will crash, as mask expansion logic was implemented inside cache creation block, which is executed only once for all shapes

Fixes https://github.com/pytorch/pytorch/issues/148194 which is a regression introduced by https://github.com/pytorch/pytorch/pull/147545
Pull Request resolved: https://github.com/pytorch/pytorch/pull/148239
Approved by: https://github.com/dcci
2025-03-01 13:06:51 +00:00
Nikita Shulga
735d7b1af6 [EZ][BE] Increase tolerances for interpolate op (#148224)
Not sure why tolerances were set like that, this logic was added in https://github.com/pytorch/pytorch/pull/104181 without much explanation
But if I'm to make a guess, it's likely due to the inaccuracy of bilinear op, that has since been replaced by shader
Pull Request resolved: https://github.com/pytorch/pytorch/pull/148224
Approved by: https://github.com/Skylion007, https://github.com/dcci
ghstack dependencies: #148154, #148187, #148211
2025-03-01 13:03:59 +00:00
Isalia20
08434df1f2 [MPS] fix empty place holder error for smooth l1 loss (#148133)
Fixes #123171

And parametrizes the tests for it

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148133
Approved by: https://github.com/malfet

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
2025-03-01 02:32:45 +00:00
Nikita Shulga
e5e31050d3 [MPS] Implement linear1d as shader (#148154)
And get rid of MPS call, as for some reason implementation via MPSGraph
API call is 100x+ times slower that Metal shader, at least according to
the following benchmark
```python
import torch
import time
import subprocess

def benchmark(device, dtype):
    # Create example inputs
    x = torch.testing.make_tensor(3, 5, 65536, device=device, dtype=dtype)
    sf = .5

    # Check output
    y = torch.nn.functional.interpolate(x, scale_factor=sf, mode="linear")
    z = torch.nn.functional.interpolate(x.cpu(), scale_factor=sf, mode="linear")
    outputs_match = torch.allclose(y.cpu(), z)
    if not outputs_match:
       atol = (y.cpu() - z).abs().max()
       rtol = ((y.cpu() - z)[z!=0]/z[z!=0]).abs().max()
       print(f"atol={atol} rtol={rtol}")

    # Measure time manually
    start_time = time.time() * 1000
    for _ in range(1000):
        y = torch.nn.functional.interpolate(x, scale_factor=sf, mode="linear")
    torch.mps.synchronize
    end_time = time.time() * 1000
    manual_delta = (end_time - start_time)
    average_time = f"{manual_delta:6.1f}"

    return "True " if outputs_match else "False", average_time

outputs_match_list = []
average_time_list = []
for device in ["mps", "cpu"]:
    for dtype in [torch.float32, torch.float16, torch.bfloat16]:
        outputs_match, average_time = benchmark(device, dtype)
        outputs_match_list.append(str(outputs_match))
        average_time_list.append(average_time)

brand_string = subprocess.check_output(['sysctl', '-n', 'machdep.cpu.brand_string']).decode("utf-8").strip()
print(f"\nBenchmarking Results (collected on {brand_string}):")
print("-"*40)
print("Device            :                MPS        |               CPU")
print("Dtype             :   FP32  |  FP16  |  BF16  |  FP32  |  FP16  |  BF16  ")
print(f"Outputs Match     :  ", " |  ".join(outputs_match_list))
print(f"Average Time (us) :", "  |".join(average_time_list))
```

Benchmark results after the change
```
Benchmarking Results (collected on Apple M2 Pro):
----------------------------------------
Device            :                MPS        |               CPU
Dtype             :   FP32  |  FP16  |  BF16  |  FP32  |  FP16  |  BF16
Outputs Match     :   True  |  True  |  True  |  True  |  True  |  True
Average Time (us) :    2.5  |   2.1  |   2.2  | 161.4  | 115.0  | 161.1
```
And before the change
```
Benchmarking Results (collected on Apple M2 Pro):
----------------------------------------
Device            :                MPS        |               CPU
Dtype             :   FP32  |  FP16  |  BF16  |  FP32  |  FP16  |  BF16
Outputs Match     :   True  |  True  |  True  |  True  |  True  |  True
Average Time (us) :  354.0  | 336.0  | 332.4  | 145.5  | 114.7  | 148.3
```

Fixes https://github.com/pytorch/pytorch/issues/144245
Pull Request resolved: https://github.com/pytorch/pytorch/pull/148154
Approved by: https://github.com/dcci
2025-02-28 16:47:42 +00:00
Davide Italiano
683e083e8d [MPS] Add support for entr() in eager. (#147948)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/147948
Approved by: https://github.com/malfet
2025-02-26 19:55:02 +00:00
Nikita Shulga
00732c3f7e [MPS] Implemented masked_fill_scalar as shader (#147369)
- Move `pos_from_thread_index and `offset_from_pos` from `UnfoldBackward.metal` into `c10/metal/indexing.h` header
- Initial idea were to implement `StridedTensor` and `ConstStridedTensor` and use them to have masked_fill kernel a something simple as the following loop
```metal
ConstStridedTensor<bool> mask(mask_data, sizes, mask_strides, ndim);
if (mask[thread_index]) {
  StridedTensor<T> input(input_data, sizes, input_strides, ndim);
  input[thread_index] = val;
}
```
But though it looks elegant and works correctly, performance wise it's much slower that the existing MPS shader (see table below), as int64 divisions on M2 GPU are really slow

- Solved performance issue by implementing 3 flavors of the same shader: `dense`, that is used when both input and mask are dense tensors of the same size, `broadcast`, which is used when `mask` is leading dimensions expandable into input tensor and `strided`  which is a general purpose fallback, but still computes position in the tensors only ones. As result, perf is even better than existing MPS shader for dense and broadcast able tensors.

Performance measured on M2Pro thru different iterations of the same shader

| dtype | MPS | int64-idx | int64-inlined | 32-bit strided | 32-bit broadcasted |
| ------|------| -----|   ---- | --- | ---- |
| float32 | 2.8 msec  | 41.6 msec | 26.9 msec | 5 msec | 2.4 msec |
| float16 | 1.86 msec | 38.2 msec| 26.6 msec | 4.6 msec | 1.9 msec |
|bfloat16|1.86 msec |38.3 msec | 26.6 msec | 4.6 msec | 1.9 msec |

And benchmark script
```python
import torch

from timeit import default_timer
from itertools import product
from torch.utils.benchmark import Measurement, Timer

def bench_mask_fill(
    n,
    binary_func,
    dtype=torch.float32,
) -> Measurement:
    t = Timer(
        stmt=f"x.masked_fill(y, -17.0); torch.mps.synchronize()",
        setup=f"x,y = torch.rand(1, 20, {n}, {n}, dtype={dtype}, device='mps'), torch.ones({n}, {n}, device='mps').triu().bool()",
        globals = {'f': binary_func},
        language="python", timer=default_timer
    )
    return t.blocked_autorange()

if __name__ == "__main__":
    n = 1024
    for dtype in [torch.float32, torch.float16, torch.bfloat16]:
        eager_t = bench_mask_fill(n, torch.fmax, dtype)
        use_msec = eager_t.mean > 1e-4
        multiplier = 1e3 if use_msec else 1e6
        uname = "msec" if use_msec else "usec"
        print(f"torch.masked_fill_() {str(dtype):>14} {eager_t.mean*multiplier:>7.2f} {uname}")
```
Fixes https://github.com/pytorch/pytorch/issues/143477
Pull Request resolved: https://github.com/pytorch/pytorch/pull/147369
Approved by: https://github.com/dcci
ghstack dependencies: #147977
2025-02-26 18:39:15 +00:00
Nikita Shulga
9ed40af917 [BE][EZ] Delete MacOS-12.3 xfail list (#147905)
As PyTorch requires at least MacOS-13 (and Metal-3) to work, delete any pre-MacoS13 checks from test script
Pull Request resolved: https://github.com/pytorch/pytorch/pull/147905
Approved by: https://github.com/dcci
ghstack dependencies: #147892
2025-02-26 05:08:09 +00:00
Nikita Shulga
346bbefa63 [BE] Parameterize TestSDPA in test_mps.py (#147856)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/147856
Approved by: https://github.com/Skylion007
2025-02-25 16:07:24 +00:00
Isalia20
a695aae89b [MPS] fix attention for >4d tensors (#147545)
Fixes #147443

and adds tests for >4d tensors
Pull Request resolved: https://github.com/pytorch/pytorch/pull/147545
Approved by: https://github.com/malfet

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
2025-02-25 13:55:28 +00:00
Davide Italiano
4e934ee5a7 [MPS] Add eager support for xlog1py. (#147687)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/147687
Approved by: https://github.com/malfet
2025-02-24 01:23:59 +00:00
Nikita Shulga
f03e7f3801 [MPS] Workaround rng bug for 5D tensors (#147667)
For some reason MPSGraph returns repeated values is tensor dimention is
larger than 4, which can be clearly seen by running following
```swift
import Metal
import MetalPerformanceShadersGraph

func randMPS(device: MTLDevice, obuf: MTLBuffer, nelem: Int, ndim: Int = 5) {
  let graph = MPSGraph()
  var dims = Array(repeating: 1, count: ndim)
  dims[0] = nelem
  let shape = dims.map { NSNumber(value: $0) }
  let randNode = graph.randomUniformTensor(withShape: shape, seed: 42, name: nil)
  let mpsOutputBuffer = MPSGraphTensorData(obuf, shape: shape, dataType: .float32)
  guard let queue = device.makeCommandQueue() else { fatalError("Can't make queue") }
  graph.run(with: queue, feeds: [:], targetOperations: nil, resultsDictionary: [randNode: mpsOutputBuffer])
}

func printBuf(_ prefix: String, buf: MTLBuffer, nelem: Int) {
  let buf_data = buf.contents().assumingMemoryBound(to: Float.self)
  print(prefix)
  for i in 0..<nelem {
      print(buf_data[i], terminator: i != nelem - 1 ? " " : "\n")
  }
}

guard let device = MTLCopyAllDevices().first else { fatalError("Not Metal device found") }
print("Using device \(device.name)")

let nelem = 2
guard let buf = device.makeBuffer(length:nelem * MemoryLayout<Float>.size, options: [.storageModeShared]) else { fatalError("Can't alloc") }

randMPS(device: device, obuf: buf, nelem: nelem, ndim: 4)
printBuf("4D uniform", buf: buf, nelem: nelem)

randMPS(device: device, obuf: buf, nelem: nelem, ndim: 5)
printBuf("5D uniform", buf: buf, nelem: nelem)
```

Workaround by flatting the tensor if it's contiguous

Fixes https://github.com/pytorch/pytorch/issues/147624
Pull Request resolved: https://github.com/pytorch/pytorch/pull/147667
Approved by: https://github.com/dcci
2025-02-23 16:52:01 +00:00
Nikita Shulga
198ffbdf11 [MPS] Implement and test round.decimals (#147266)
If inductor can do it, why not eager
Pull Request resolved: https://github.com/pytorch/pytorch/pull/147266
Approved by: https://github.com/Skylion007
ghstack dependencies: #147286
2025-02-16 23:17:13 +00:00
tim
b9a22b3f37 bug fix: ensure 4d input in _scaled_dot_product_attention_math_mps (#146623)
This pr addresses the issue in the MPS backend for `_scaled_dot_product_attention_math_mps` where a 3d input like (num_heads, seq_len, query_dim) cannot be automatically treated as (1, num_heads, seq_len, query_dim), which can be inferred on cpu or cuda, which can be circumvented by adding a util function to ensure a 4d shape.

The issue was found in https://github.com/hiyouga/LLaMA-Factory/issues/6835, in [transformers qwen2_vl](1590c66430/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py (L373C14-L373C93)), 3d q/k/v were passed into sdpa function, which lead to an error.

Considering consistency, since this pattern might pop up elsewhere in the transformers codebase, I think it makes more sense to maintain the same intuition across all platforms.

---
reproduce code:
```
import torch
import torch.nn.functional as F

head_num, seq_len, embed_dim = 16, 16, 80
bsz = 1

q = torch.randn(head_num, seq_len, embed_dim)
k = torch.randn(head_num, seq_len, embed_dim)
v = torch.randn(head_num, seq_len, embed_dim)
attention_mask = torch.ones(1, seq_len, seq_len)

oo_cpu = F.scaled_dot_product_attention(
    q.to("cpu"),
    k.to("cpu"),
    v.to("cpu"),
    attention_mask.to("cpu"),
    dropout_p=0.0
)

if torch.backends.mps.is_available():
    oo_mps = F.scaled_dot_product_attention(
        q.to("mps"),
        k.to("mps"),
        v.to("mps"),
        attention_mask.to("mps"),
        dropout_p=0.0
    )
    assert torch.allclose(oo_cpu, oo_mps.to("cpu"), atol=1e-5)
```

error outputs:
```
Traceback (most recent call last):
  File "/opt/homebrew/Caskroom/miniconda/base/envs/torch-dev/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3577, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-2-5169b8d2c5dd>", line 21, in <module>
    oo_mps = F.scaled_dot_product_attention(
IndexError: Dimension out of range (expected to be in range of [-3, 2], but got 3)
```

hardware and envs:
```
torch               2.6.0
apple m3 max
```

---

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146623
Approved by: https://github.com/malfet

Co-authored-by: Aaron Gokaslan <aaronGokaslan@gmail.com>
Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
2025-02-13 07:00:51 +00:00
Isalia20
17a808557c [MPS] cholesky ex version (#146799)
PR #145701 didn't have experimental version of cholesky. This PR adds that version

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146799
Approved by: https://github.com/malfet
2025-02-13 07:00:21 +00:00
Isalia20
d763093b49 [MPS] fix lu factor for large tensors with bs>1 (#146753)
Try this:
```python
import torch

batch_size = 2
A = torch.eye(256, device="mps")[None, :, :].expand(batch_size, -1, -1) + 0.1 * torch.randn((batch_size, 256, 256), device="mps")
A_cpu = A.cpu()
LU_cpu, pivots_cpu = torch.linalg.lu_factor(A_cpu)
LU, pivots = torch.linalg.lu_factor(A)
torch.testing.assert_close(LU.cpu(), LU_cpu)
```
You'll get huge difference in LU tensors
<img width="706" alt="Screenshot 2025-02-08 at 12 14 39" src="https://github.com/user-attachments/assets/b45f2b3c-e0a5-49c8-aa07-42792150b781" />

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146753
Approved by: https://github.com/malfet

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
2025-02-11 00:37:07 +00:00
Davide Italiano
dfe3b64282 [mps] Implement eager support for spherical_bessel_j0 (#146818)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146818
Approved by: https://github.com/malfet

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
2025-02-10 16:58:05 +00:00
Nikita Shulga
611ca163fd [MPS] Add bilineard2d_aa implementation (#145526)
Interesting quirk of the algorithm, that is not very well documented, is that value of align_corners is ignored in antialias mode, see arguments of
e8304f08fe/aten/src/ATen/native/cpu/UpSampleKernel.cpp (L747-L751)

Error out on  uint8 implementation(as it relies on a very fragile integer integer arithmetic), as it's not implemented on any other Accelerator devices at the moment.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145526
Approved by: https://github.com/dcci
2025-02-10 15:03:14 +00:00
Isalia20
0ab67299c3 [MPS] lu unpack (#146681)
Implements lu unpack function on MPS. Haven't added new tests because they are covered by removing the lu_unpack from UNIMPLEMENTED_XFAILLIST in test_mps with `test_output_match` function
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146681
Approved by: https://github.com/malfet

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
2025-02-08 00:16:17 +00:00
Nikita Shulga
624d94bdb8 [MPS] Extend torch.special.sinc to complex (#146648)
And to integral data types as well

Was too lazy to deduce the formula myself(or write a sympy script), but ChatGPT did a decent job of doing it, though it forgot that input must be multiplied by $$\pi$$:
```math
\text{Re}\left(\text{sinc}(x + i y)\right) = \frac{\sin(x)\cosh(y) x - \cos(x)\sinh(y) y}{x^2 + y^2}
```
```math
\text{Im}\left(\text{sinc}(x + i y)\right) = \frac{\cos(x)\sinh(y) x + \sin(x)\cosh(y) y}{x^2 + y^2}
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146648
Approved by: https://github.com/dcci
2025-02-07 01:12:37 +00:00
Davide Italiano
46390e9a37 [mps] Implement support for sinc() operator (inductor and eager). (#146539)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146539
Approved by: https://github.com/malfet, https://github.com/jansel

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
2025-02-06 16:37:27 +00:00
Isalia20
0dc03134d9 [MPS] linalg solve implementation (#146531)
Fixes #98222

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146531
Approved by: https://github.com/malfet

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
2025-02-06 00:57:49 +00:00
Davide Italiano
8a2000fd42 [MPS] Implement support for zeta (both eager and inductor). (#146465)
A test was failing in inductor (`test_pointwise_zeta`) -- and I realized the operation was missing also from eager.
Implemented for both, leveraging the kernel. Happy to split in two (one PR for eager, one for inductor) if folks prefer.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146465
Approved by: https://github.com/malfet
2025-02-05 13:55:50 +00:00
Nikita Shulga
aafaf4016f [MPS] Add error checking when dispatching kernel (#146458)
That thread-group size should not exceed maximum thread group size
Add regression test to validate that
Make failures like https://github.com/pytorch/pytorch/issues/146430 much easier to detect
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146458
Approved by: https://github.com/dcci
2025-02-05 02:56:40 +00:00
Isalia20
e3643e1e0e [MPS] Add linalg det and fix lu factor for non contiguous tensors (#146279)
Requested in #77764

This PR adds support for linalg.det on MPS and fixes lu factor for non contiguous tensors, current implementation crashed on any kind of non-contiguous tensor with an error:
```
-[AGXG13XFamilyCommandBuffer blitCommandEncoderCommon:]:833: failed assertion `A command encoder is already encoding to this command buffer'
zsh: abort      python det.py
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146279
Approved by: https://github.com/malfet

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
2025-02-03 06:06:43 +00:00
Isalia20
5d55a6585d [MPS] lu factor ex implementation (#144651)
Implements `torch.linalg.lu_factor_ex`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144651
Approved by: https://github.com/malfet
2025-02-02 15:09:49 +00:00
Nikita Shulga
99a0940991 [MPS] Fix regression in con-contig bitwise ops (#146085)
Caused by https://github.com/pytorch/pytorch/pull/128393 that change semantic of `needsGather`, which resulted in silent correctness errors on MacOS-15+ if output tensor is non-contiguous

Fixes https://github.com/pytorch/pytorch/issues/145203

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146085
Approved by: https://github.com/dcci
2025-01-30 22:36:56 +00:00
Nikita Shulga
1fdb4d65c0 [MPS] Extend torch.mm/torch.bmm to integral types (#145809)
By using `naive_mm` kernel, but make sure that accumulation is done over int32 for smaller int types (and float for half and bfloat) as well as adding `navie_bmm` that follows the same pattern.
Remove stale restriction on `torch.dot` (which works fine on MacOS-14/15)
This also enables integer op flavors for:
- `addmv`
- `einsum`
- `inner`
- `linalg.multi_dot`
- `matmul`
- `mv`
- `tensordot`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145809
Approved by: https://github.com/dcci
2025-01-30 19:35:25 +00:00