Commit Graph

743 Commits

Author SHA1 Message Date
Kurt Mohler
510c398a4f Add max_pool3d backward pass for MPS (#157498)
Note on backward precision over fp16:

A float16 number has 10 bits of mantissa, 5 bits of exponent, and 1 bit for the sign. If the sign bit is positive, then with a mantissa $m$ and exponent $e$ represented in base 10, the number that the float16 format represents is $(1 + m / 1024)  \exp2(e)$. ([source](https://en.wikipedia.org/wiki/Half-precision_floating-point_format))

Consider adding two numbers $a$ and $b$ which have arbitrary mantissas, and say their exponents are $e_a = 1$ (so $2 \le a \lt 4$) and $e_b=-3$ (so $0.175 \le b \lt 0.25$). Assume that the result has the same exponent as $a$. Since the exponents differ by 4, we'll effectively need to truncate the 4 rightmost bits of $b$'s mantissa, which would introduce a maximum error on the order of $(2^4 / 1024)  \exp2(-3) \approx 0.002$.

The error is nearly the same if $e_b = -2$ (so $0.25 \le b \lt 0.5$), where the 3 rightmost bits are truncated, giving a maximum error on the order of $(2^3 / 1024)  \exp2(-2) \approx 0.002$. Same for $e_b=-1$.

So if we're adding up nine different numbers that all have exponents -3, -2, or -1, and they sum to a number with exponent 1, then we would expect a maximum error of several times greater than 0.002. In my comments above, summing those particular nine numbers in different ways gave results that ranged between 3.1816 and 3.1758, a difference of $0.0058 \approx 2.9  * 0.002$.

That's within the acceptable bounds, and we can safely just increase the error tolerance used in test_output_grad_match for the case of max_pool3d_backward with float16.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157498
Approved by: https://github.com/malfet
2025-07-07 19:46:44 +00:00
Manuel Candales
d56f11a1f2 [MPS] Implement logcumsumexp metal kernel (#156858)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/156858
Approved by: https://github.com/malfet
ghstack dependencies: #157512
2025-07-03 18:16:25 +00:00
Nikita Shulga
5e636d664a [BE] @serialTest decorator must be called (#157388)
Otherwise it turns test into a trivial one(that always succeeds), as following example demonstrates
```python
import torch
from torch.testing._internal.common_utils import serialTest, run_tests, TestCase

class MegaTest(TestCase):
    @serialTest
    def test_foo(self):
        if hasattr(self.test_foo, "pytestmark"):
            print("foo has attr and it is", self.test_foo.pytestmark)
        print("foo")

    @serialTest()
    def test_bar(self):
        if hasattr(self.test_bar, "pytestmark"):
            print("bar has attr and it is", self.test_bar.pytestmark)
        print("bar")

if __name__ == "__main__":
    run_tests()
```

That will print
```
test_bar (__main__.MegaTest.test_bar) ... bar has attr and it is [Mark(name='serial', args=(), kwargs={})]
bar
ok
test_foo (__main__.MegaTest.test_foo) ... ok

----------------------------------------------------------------------
Ran 2 tests in 0.013s

```

Added assert that arg is boolean in the decorator to prevent such silent skips in the future

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157388
Approved by: https://github.com/clee2000
2025-07-02 19:15:19 +00:00
Nikita Shulga
019e30e3b8 [BE] Decorate LargeTensorTest with serialTests (#157382)
May be it'll help make M2-15 jobs more stable, as that was the last test run before OOM
Pull Request resolved: https://github.com/pytorch/pytorch/pull/157382
Approved by: https://github.com/clee2000
2025-07-01 20:35:42 +00:00
Isalia20
a1282b1823 [MPS] Add boilerplate sparse code support (#157238)
This PR makes minimal changes to support sparse tensors on MPS. In the followup PRs I'll start adding different operations slowly so we can fix the issue of
https://github.com/pytorch/pytorch/issues/129842
which is highly requested(I assume because of whisper using sparse tensors)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157238
Approved by: https://github.com/malfet
2025-06-30 01:53:45 +00:00
Nikita Shulga
a1e4f1f98a [MPS] Reimplement tri[ul] as Metal shaders (#157179)
And add in-place flavor, as it is currently broken for non-contig tensors
Pull Request resolved: https://github.com/pytorch/pytorch/pull/157179
Approved by: https://github.com/dcci
2025-06-28 01:33:18 +00:00
Isalia20
653c52fe52 [MPS] Fix batch norm incorrect gradient (#156867)
Fixes #156555

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156867
Approved by: https://github.com/malfet
2025-06-25 23:05:49 +00:00
Joona Havukainen
20a74c370b Add error message with assert to topK if ndims() - dim > 4 (#155475)
Addressing #154890

Not really a proper fix but at least it's more informative than the current crash.

For a more long term solution I'm testing if we can use the TopK API released in MacOS14 as it does not have the same MPSScan op issue that the Sort and ArgSort are hitting.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155475
Approved by: https://github.com/kulinseth
2025-06-13 21:10:06 +00:00
Nikita Shulga
dd41a3907c [MPS] Fix unary/binary ops for 2**32+ elem tensors (#155183)
By using `TensorIterator::with_32bit_indexing()` primitive

Add `bind_tensors` helper function that correctly sets up MPS tensors originating from TensorIterator

TODO: Add comments to bind_tensors as well asunit test, based on
```
python  -c "import torch;print((torch.rand(1, 1024, 1024, dtype=torch.bfloat16, device='mps') + torch.rand(5000, 1, 1, dtype=torch.bfloat16, device='mps')).sin())"
```

Fixes https://github.com/pytorch/pytorch/issues/154828
Pull Request resolved: https://github.com/pytorch/pytorch/pull/155183
Approved by: https://github.com/cyyever, https://github.com/dcci, https://github.com/Skylion007
ghstack dependencies: #155150, #155178, #155184
2025-06-05 18:57:14 +00:00
Roy Hvaara
9a4c08ddfc [MPS] Parametrize test_scaled_dot_product_attention_autocast (#155005)
Also moving comments inside the function scope for some of my previous regression tests.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/155005
Approved by: https://github.com/Skylion007, https://github.com/malfet
2025-06-05 13:24:53 +00:00
Nikita Shulga
9cdce682a1 [MPS][BE] Reimplement log1p as Metal shader (#154936)
That should make it faster than MPSGraph implementation, but also
improves accuracy for small inputs, by using the algorithm described in [What Every Computer Scientist Should Know About Floating-Point Arithmetic](https://docs.oracle.com/cd/E19957-01/806-3568/ncg_goldberg.html#1202), i.e. $log(1+x) = \frac{x * log(1+x)}{(1 + x) - 1}$ if $1 +x \neq 1$ else just $x$

Also tried using first 3 elements of Taylor series in Horner's form which also seems to work fine, i.e. $log(1+x) \approx x * (1 -x (\frac{1}{2} -  \frac{x}{3}))$

Replaced less accurate log1p implementation in `c10/metal/special_math.h` with generic one.

Parametrize and modify regression test to check for accuracy of small values

TODOs:
 - Do proper implementation for complex values as well, perhaps using 0408ba0a76/mlx/backend/metal/kernels/utils.h (L339)
 - May be implement it using Remez-like algorithm documented here 207f3b2b25/lib/msun/src/s_log1pf.c (L37)
 - Or use llvm's implementation from f393986b53/libclc/clc/lib/generic/math/clc_log1p.inc (L22)
 - Benchmark which algorithm is faster and delivers better accuracy
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154936
Approved by: https://github.com/dcci, https://github.com/Skylion007
2025-06-03 14:10:13 +00:00
Joona Havukainen
981bdb39ca Enable ConvTranspose3D for FP32 and Complex64 (#154696)
Fixes #154615

Enables using ConvTranspose3D since it seems support exists both on MacOS 14 and 15.

For the half dtypes the discrepancy of CPU and GPU implementations is too large to conclude whether there is a bug in the implementation or not without a more rigorous study on what bounds are there to the expected error. So they are left unsupported for now and an assert is added to notify the user if the op is called with fp16 or bf16 inputs.

Tests for ConvTranspose3D were enabled for the supported data types.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154696
Approved by: https://github.com/malfet
2025-06-02 16:24:03 +00:00
Isalia20
41092cb86c [MPS] index copy impl (#154326)
Second most requested op according to #154052

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154326
Approved by: https://github.com/malfet
2025-05-29 16:57:43 +00:00
Xuehai Pan
7ae204c3b6 [BE][CI][Easy] Run lintrunner on generated .pyi stub files (#150732)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150732
Approved by: https://github.com/malfet, https://github.com/cyyever, https://github.com/aorenste
2025-05-27 14:58:02 +00:00
Nikita Shulga
975bbc63db [MPS][BE] Move fmod/remainder to Metal ops (#154280)
This accomplishes following:
 - Fixes correctness problem with large integer types (though probably makes it slower, but this could not be avoided if one wants to compute accurate answer)
 - Makes op faster for floating point types (as Metal kernel invocation is faster than creating MPSGraph)
 - Eliminates need for several correctness workarounds

Fixes https://github.com/pytorch/pytorch/issues/154171
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154280
Approved by: https://github.com/dcci
ghstack dependencies: #154275, #154290
2025-05-24 01:45:33 +00:00
Nikita Shulga
633ed01145 [MPS] Add support for two more isin variants (#154010)
`isin_Tensor_Scalar_out` is just a redispatch to eq/neq
`isin_Scalar_Tensor_out` redispatches back to generic `isin` op, but needs a small tweak to handle float scalars
Make sure that `out` is resized to an expected value in `isin_Tensor_Tensor_out_mps`

Add unittests to validate that, but skip them on MacOS-13, where MPS op just returns garbage

Before this change both of those failed
```python
>>> import torch
>>> t = torch.tensor([0, 1, 2], device='mps')
>>> torch.isin(t, 1)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
NotImplementedError: The operator 'aten::isin.Tensor_Scalar_out' is not currently implemented for the MPS device. If you want this op to be considered for addition please comment on https://github.com/pytorch/pytorch/issues/141287 and mention use-case, that resulted in missing op as well as commit hash 3b875c25ea6d8802a0c53af9eb961ddf2f058188. As a temporary fix, you can set the environment variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a fallback for this op. WARNING: this will be slower than running natively on MPS.
>>> torch.isin(1, t)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
NotImplementedError: The operator 'aten::isin.Scalar_Tensor_out' is not currently implemented for the MPS device. If you want this op to be considered for addition please comment on https://github.com/pytorch/pytorch/issues/141287 and mention use-case, that resulted in missing op as well as commit hash 3b875c25ea6d8802a0c53af9eb961ddf2f058188. As a temporary fix, you can set the environment variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a fallback for this op. WARNING: this will be slower than running natively on MPS.
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154010
Approved by: https://github.com/Skylion007, https://github.com/dcci, https://github.com/manuelcandales
ghstack dependencies: #153970, #153971, #153997
2025-05-22 17:59:35 +00:00
Nikita Shulga
d5ddc5ab20 [MPS] Fix float64 scalar tensor handling (#153582)
Current implementation causes silent correction problem with torch.compile when someone tries to `torch.compile` function where one of the arguments is say `np.exp(.3)`, which will be represented as torch.float64 scalar tensor

Add regssion test for this behavior
Pull Request resolved: https://github.com/pytorch/pytorch/pull/153582
Approved by: https://github.com/dcci
2025-05-15 05:15:14 +00:00
Nikita Shulga
8749fe8439 [CI][MPS] Speedup test_large_bmm (#153562)
By computing matmuls of only one random non-zero batch on CPU

This reduces test runtime from 11 minutes to 14 sec
```
 % python3 test/test_mps.py -v -k test_large_bmm_
test_large_bmm_bfloat16 (__main__.TestMPS.test_large_bmm_bfloat16) ... ok
test_large_bmm_float16 (__main__.TestMPS.test_large_bmm_float16) ... ok

----------------------------------------------------------------------
Ran 2 tests in 27.495s

```

TODO: Compute it over two slices when https://github.com/pytorch/pytorch/issues/153560 is fixed
Pull Request resolved: https://github.com/pytorch/pytorch/pull/153562
Approved by: https://github.com/Skylion007, https://github.com/clee2000
2025-05-14 18:49:42 +00:00
Isalia20
56492bfcb9 [MPS] SDPA specialized kernels (#152781)
Paritally fixes #139668 and #152550

Still work in progress. Following needs to be addressed:
- [x] Some tests are failing and need to check why and bugfix
- [x] Benchmark the new kernels and  add to this PR for varying sequence lengths head dimensions(the ones that get dispatched to kernels)
- [x] Add tests to cover the specialized paths(if applicable)
- [x] Code cleanup

**Tested on Macbook M1 Pro**
### Vector Fast Path (q_len=1, k_len=256)
- Old: 0.378 ms
- New: 0.260 ms
- **31.2% speed improvement**

### Vector 2-pass (q_len=1, k_len=4096)
- Old: 0.627 ms
- New: 0.370 ms
- **41.0% speed improvement**

### Vector Fast Path (q_len=8, k_len=256)
- Old: 0.545 ms
- New: 0.322 ms
- **40.9% speed improvement**

### Vector 2-pass (q_len=8, k_len=4096)
- Old: 1.318 ms
- New: 1.057 ms
- **19.8% speed improvement**

Script to get perf:
```
import torch
import time

def benchmark_sdpa(config, iterations=100):
    device = config.get("device", "cpu")
    batch = config["batch"]
    heads = config["heads"]
    q_len = config["q_len"]
    k_len = config["k_len"]
    head_dim = config["head_dim"]

    q = torch.randn(batch, heads, q_len, head_dim, device=device, dtype=torch.float32)
    k = torch.randn(batch, heads, k_len, head_dim, device=device, dtype=torch.float32)
    v = torch.randn(batch, heads, k_len, head_dim, device=device, dtype=torch.float32)

    for _ in range(5):
        _ = torch.nn.functional.scaled_dot_product_attention(q, k, v)
        if device == "mps":
            torch.mps.synchronize()

    total_time = 0.0
    for i in range(iterations):
        start = time.perf_counter()
        _ = torch.nn.functional.scaled_dot_product_attention(q, k, v)
        if device == "mps":
            torch.mps.synchronize()
        end = time.perf_counter()
        total_time += end - start

    avg_time = total_time / iterations
    print(f"[{config['name']}] Avg time per run: {avg_time * 1000:.3f} ms over {iterations} iterations")
    return avg_time

def main():
    device = "mps" if torch.backends.mps.is_available() else "cpu"
    print(f"Running benchmarks on device: {device}")

    benchmarks = [
        {
            "name": "Vector Fast - Small q_len & moderate k_len",
            "batch": 1,
            "heads": 8,
            "q_len": 1,      # small query sequence length triggers vector fast path
            "k_len": 256,    # moderate key length
            "head_dim": 64,
            "device": device,
        },
        {
            "name": "Vector 2-pass - Small q_len & long k_len",
            "batch": 1,
            "heads": 8,
            "q_len": 1,      # small query sequence length
            "k_len": 4096,   # long key length triggers the 2-pass variant
            "head_dim": 64,
            "device": device,
        },
        # {
        #     "name": "Full Attention - Moderate q_len/k_len",
        #     "batch": 1,
        #     "heads": 8,
        #     "q_len": 128,    # longer query sequence length
        #     "k_len": 8192,    # matching key length for full attention paths
        #     "head_dim": 64,
        #     "device": device,
        # },
        # {
        #     "name": "Full Attention - Longer q_len/k_len",
        #     "batch": 1,
        #     "heads": 8,
        #     "q_len": 128,    # very long sequence length
        #     "k_len": 8192,
        #     "head_dim": 64,
        #     "device": device,
        # },
    ]

    iterations = 100
    for config in benchmarks:
        benchmark_sdpa(config, iterations=iterations)

if __name__ == "__main__":
    main()

```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/152781
Approved by: https://github.com/malfet
2025-05-07 00:40:11 +00:00
Nikita Shulga
0ffd31dc8a [MPS] Migrate div roudning modes (#152758)
By implementing `div_floor` and `div_trunc` . Do not mark `div_trunc` as OPMATH, to align following output with CPU(if division is performed in fp32, than result will be truncated to 25
```
import torch
print(torch.tensor([[-7.4688, -3.1289]], dtype=torch.float16,device="cpu").div(torch.tensor([-0.2988, -0.8789], dtype=torch.bfloat16,device="cpu"), rounding_mode="trunc"))
tensor([[24.,  3.]])
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152758
Approved by: https://github.com/dcci
ghstack dependencies: #152663, #152515, #152737, #152743
2025-05-05 03:02:29 +00:00
Isalia20
99c42722f6 [MPS] fix memory leak in sdpa float32 (#152371)
Fixes #152344

Leak seems to be on the MPS Graph side, even though there is an identity tensor it seems like it's no longer enough to bypass the SDPA sequence which seems to leak memory.

Even adding 0.0f seems to be optimized to be ignored and still take the sdpa sequence(that's the reason for adding 1e-20)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152371
Approved by: https://github.com/malfet

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
2025-04-29 04:51:10 +00:00
Isalia20
899eec665c [MPS] col2im kernel implementation (#152282)
Fixes #151820
Also requested in #141287

Mainly based on the cuda kernel implementations

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152282
Approved by: https://github.com/malfet
2025-04-28 03:48:41 +00:00
Nikita Shulga
3ef6d6924a [BE] Switch TestConsistency to MPS device (#147893)
Which will eventually allow move decorators away more `common_mps.py`

Adjust tolerances accordingly. XFAIL a bunch of tests on MacOS-13, which is going to be deprecated anyway

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147893
Approved by: https://github.com/atalman
ghstack dependencies: #152204
2025-04-26 01:19:21 +00:00
Isalia20
5e9bdc9b86 [MPS] layernorm forward kernel (#152010)
Implements layernorm forward pass as a metal kernel instead of MPSGraph ops. Speed ups are indicated on the chart below:
![Figure_1](https://github.com/user-attachments/assets/27a4d2ef-b3e4-4650-9ce3-b939c080321e)

Script for generating times, need to build torch with old/new codebase and then run this with different file name indicated at the end of the script
```python
import csv
import time

import numpy as np

import torch
import torch.nn.functional as F

matrix_sizes = [32, 64, 128, 256, 512, 1024, 2048, 4096, 8192]
batch_sizes = [1]
elementwise_affine = [False, True]
num_runs = 50
warmup_runs = 3

def create_input_tensor(n, batch_size):
    torch.manual_seed(42)
    return torch.randn(batch_size, n, dtype=torch.float32)

def run_layer_norm(A, normalized_shape, elementwise_affine):
    torch.mps.synchronize()
    start = time.perf_counter()
    out = F.layer_norm(A, normalized_shape)
    torch.mps.synchronize()
    end = time.perf_counter()
    return out, end - start

results = {"N": [], "elementwise_affine": [], "batch_size": [], "mean_time": [], "std_time": []}

for el_aff in elementwise_affine:
    for n in matrix_sizes:
        for batch_size in batch_sizes:
            print(f"\nBenchmarking LayerNorm for input size N={n}, batch_size={batch_size}, elementwise_affine={el_aff}")

            try:
                A_cpu = create_input_tensor(n, batch_size)
                A_mps = A_cpu.to("mps")

                normalized_shape = (n,)

                for _ in range(warmup_runs):
                    _, _ = run_layer_norm(A_mps, normalized_shape, el_aff)

                times = []
                for _ in range(num_runs):
                    _, t = run_layer_norm(A_mps, normalized_shape, el_aff)
                    times.append(t)

                mean_time = np.mean(times)
                std_time = np.std(times)

                results["N"].append(n)
                results["elementwise_affine"].append(el_aff)
                results["batch_size"].append(batch_size)
                results["mean_time"].append(mean_time)
                results["std_time"].append(std_time)

                print(f"Mean time: {mean_time:.4f}s ± {std_time:.4f}s")

            except RuntimeError as e:
                print(f"Error for N={n}, batch_size={batch_size}: {e}")
                continue

with open("layernorm_benchmark_times_new.csv", "w", newline="") as f:
    writer = csv.writer(f)
    writer.writerow(["N", "elementwise_affine", "batch_size", "mean_time", "std_time"])
    for i in range(len(results["N"])):
        writer.writerow(
            [
                results["N"][i],
                results["elementwise_affine"][i],
                results["batch_size"][i],
                results["mean_time"][i],
                results["std_time"][i],
            ]
        )

```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152010
Approved by: https://github.com/malfet

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
2025-04-24 05:07:46 +00:00
Nikita Shulga
3aecf2dc52 [MPS] Extend index_put to half precision floats (#151869)
By reusing `c10/metal/atomic.h`
This also fixes `GPUTests.test_index_put_fallback[12]_mps` that is unrolled by inductor, so no need for dedicated atomic_add support

TODOs:
 - Get rid of indexing kernel and compute it directly when kernel is run
 - Simulate atomic_add for int64 types as series of int32 atomic-add-and-fetch
 - Setup tolerances correctly to pass float16/bfloat16 tests (as CPU always takes sequential strategy)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151869
Approved by: https://github.com/Skylion007, https://github.com/dcci
2025-04-22 22:00:08 +00:00
Li-Huai (Allan) Lin
fbd29527d8 [MPS] Move ops modifiers to testing utils so other tests can reuse (#151781)
Test collection check:
```
python -m pytest test/test_mps.py --collect-only
```
Before:
```
6390 tests collected in 8.34s
```

After:
```
6390 tests collected in 7.71s
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/151781
Approved by: https://github.com/malfet
2025-04-22 19:19:52 +00:00
Nikita Shulga
f37e138bc4 [MPS] Enable log1p and sigmoid for int64 (#151791)
It works on MacOS-15, but likely will need a skip for MacOS-13

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151791
Approved by: https://github.com/Skylion007
ghstack dependencies: #151790
2025-04-21 18:30:04 +00:00
Davide Italiano
470132c6a1 [MPS] Add support for hermite_polynomial_he (inductor/eager). (#151754)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/151754
Approved by: https://github.com/malfet, https://github.com/jansel
2025-04-20 17:44:40 +00:00
Nikita Shulga
14293c2377 [MPS] Allow isin for mixed types (#151600)
To follow pattern set by CPU and CUDA impls: define common_dtype and optionally casts `elements` and `test_elements` to common dtype if needed

- Add regression test, though skip it on MacOS-13, as `isin` seems to produce garbage there even for same dtypes
```
>>> import torch
>>> x=torch.arange(4.0, device='mps')
>>> y=torch.arange(1.0, 3.0, device='mps')
>>> x, y, torch.isin(x, y), torch.isin(y, x)
(tensor([0., 1., 2., 3.], device='mps:0'), tensor([1., 2.], device='mps:0'), tensor([False,  True, False, False], device='mps:0'), tensor([False, False], device='mps:0'))
>>> torch.__version__
'2.6.0'
```
- Cleanup code a bit

Fixes https://github.com/pytorch/pytorch/issues/151443
Pull Request resolved: https://github.com/pytorch/pytorch/pull/151600
Approved by: https://github.com/Skylion007, https://github.com/dcci, https://github.com/kulinseth
2025-04-18 12:30:32 +00:00
Nikita Shulga
1ffaa00ad7 [MPS] Migrate bitwise_not to unary operator (#151460)
That kills to birds with one stone:
 - Makes implementations more standartized (and faster for strided inputs/outputs)
 - Fixes bug strided inplace bitwise_not

I.e. before this change
```python
import torch
x=torch.arange(32, device="mps")
x[::2].bitwise_not_()
print(x)
```
produced
```
tensor([ -1,  -2,  -3,  -4,  -5,  -6,  -7,  -8,  -9, -10, -11, -12, -13, -14,
        -15, -16,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,  26,  27,
         28,  29,  30,  31], device='mps:0')
```
after, it generates reasonable output
```
tensor([ -1,   1,  -3,   3,  -5,   5,  -7,   7,  -9,   9, -11,  11, -13,  13,
        -15,  15, -17,  17, -19,  19, -21,  21, -23,  23, -25,  25, -27,  27,
        -29,  29, -31,  31], device='mps:0')
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/151460
Approved by: https://github.com/dcci, https://github.com/qqaatw, https://github.com/Skylion007
2025-04-16 21:34:45 +00:00
Nikita Shulga
b8a2824755 [MPS] Fix logit output for half/bfloat (#151282)
Which also fixes MPSInductor pointwise test
TODO: (as followup PRs): get rid of special native_function.yaml dispatches and use stub
Pull Request resolved: https://github.com/pytorch/pytorch/pull/151282
Approved by: https://github.com/dcci
ghstack dependencies: #151224, #151246, #151272
2025-04-15 06:25:00 +00:00
Li-Huai (Allan) Lin
ddfc14b3ae [MPS] Fix where (#151176)
Fixes #150967
Pull Request resolved: https://github.com/pytorch/pytorch/pull/151176
Approved by: https://github.com/kulinseth, https://github.com/malfet
2025-04-13 20:44:50 +00:00
Nikita Shulga
bc47d539fc [MPS] Support ArgumentBuffer bindings from C++/Python (#150780)
To workaround limitation of 32-arguments per kernel and being able to eventually compile something like
```python
import torch

def foo(*args):
  rc = torch.empty_like(args[0])
  for arg in args:
      rc += arg
  return rc

tensors = torch.rand(100, 32, device='mps').unbind(0)
print(torch.compile(foo)(*tensors))
```

For now, introduce `at::native:🤘:get_tensor_gpu_address` and use it from both C++ test and compile_shader to convert list of tensors to list of pointers valid on GPU.

Initially this binding were done via `id< MTLArgumentEncoder>`, but according to [Improving CPU Performance by Using Argument Buffers](https://developer.apple.com/documentation/metal/improving-cpu-performance-by-using-argument-buffers?language=objc#Encode-Resources-into-Argument-Buffers) article, this is not necessary when targeting Tier2-only devices (which is true of all devices on MacOS-13 or newer):
> To directly encode the argument buffer resources on these Tier 2 devices, write the [MTLBuffer](https://developer.apple.com/documentation/metal/mtlbuffer?language=objc).[gpuAddress](https://developer.apple.com/documentation/metal/mtlbuffer/gpuaddress?language=objc) property — and for other resource types (samplers, textures, and acceleration structures), the [gpuResourceID](https://developer.apple.com/documentation/metal/mtlcomputepipelinestate/gpuresourceid?language=objc) property — into the corresponding structure member. To encode offsets, treat these property values as uint64 types and add the offset to them.

Add both C++ and PyThon unittests that validate that this works.
Please note, that using either ArgumentEncoder or directly encoding the data does not guarantee buffer will not be freed until shader execution is complete. On the other hand, this should already be guaranteed by MPSCachingAllocator that would only free the memory after all streams completed its execution.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150780
Approved by: https://github.com/dcci
2025-04-09 04:24:37 +00:00
Isalia20
49f6cce736 [MPS] grad scaler (#150255)
Fixes #142397

Basic implementation is done. What's left:
- [x] Different dtype/device tensors in the TensorList
- [x] fast path for grouping the foreach kernel
- [x] Tests

Regarding tests, I found some tests in `test/test_torch.py` for GradScaler but I couldn't figure out what is the best way to enable the test for MPS device.

By removing `@onlyNativeDeviceTypes`, one enables the tests for MPS but also enables tests for all other devices which are not included in the native device types. If I put:
`instantiate_device_type_tests(TestTorchDeviceType, globals(), allow_mps=True)`

This enables lots of tests in that class for MPS which were not(?) being tested before? This part needs some clarification

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150255
Approved by: https://github.com/malfet

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
2025-04-06 17:06:55 +00:00
Isalia20
cfea55dbec [MPS] fix inverse bug for N>1024 (#146754)
Fixes #138200

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146754
Approved by: https://github.com/malfet

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
2025-04-05 21:49:21 +00:00
Nikita Shulga
7ac8186851 [MPSInductor] Speedup sum/prod reductions (#150566)
By using cooperative `simd_sum`/`simd_product` instead of a C-style for loop for threadgroup reductions. This also allows significantly reduce amount of shared memory needed to perform those reductions

Using such reduction increases the `torch.compile` performance for gpt-fast using `stories110M` from 29 tokens/sec to 630 tokens/sec on M4 and changes perf of torch.rand as follows:
|size| before | after |
|------------------------|------------|-------------|
| 512x512         | 202.1       | 131.8       |
| 1024x1024   |   780.6    | 176.9       |
| 2048x2048    |   1423.4       | 339.9      |
| 4096x4097    |    2982.2 | 1047.2      |

Unfortunately, none of the SIMDgroup operations are available for 64-bit integers, but one can simulate the behavior using using `simd_shuffle_down` of 64-bit values represented as `int2` types, that yields reduction in $log_2(threadgroup\\_size)$ steps. [`mlx/kernels/reduction/ops.h](86389bf970/mlx/backend/metal/kernels/reduction/ops.h (L15-L18)) contains an implementation of such algorithm, but alas it yields wrong results on M1/M2(and may be M3 machines) if not all threads in the simdgroup are active which could be observed by running
```python
import torch
lib=torch.mps.compile_shader("""
kernel void do_sum(device int* out, constant int* in, uint idx [[thread_position_in_grid]]) {
  out[idx] = metal::simd_shuffle_down(in[idx], 8);
}
""")
x=torch.arange(22, device='mps', dtype=torch.int32)
y=torch.empty_like(x)
lib.do_sum(y, x)
print(y)
```
that returns following on M4
```
tensor([ 8,  9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21,  0,  0,  0,  0, 0,  0,  0,  0], device='mps:0', dtype=torch.int32)
```
but same kernel running on M1 returns
```
tensor([ 8,  9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 14, 15, 16, 17, 18, 19, 20, 21], device='mps:0', dtype=torch.int32)
```
This discrepancy in behavior can be addressed by using `simd_shuffle_and_fill_down`, but any kernels using simd_shuffle_and_fill_down cause an internal compiler error on MacOS-13.2. Considering that OS is to be EOL soon, skip the offending tests.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150566
Approved by: https://github.com/manuelcandales
ghstack dependencies: #150452, #150457
2025-04-05 02:47:27 +00:00
Nikita Shulga
827b730f4e [CI] Skip test_copy_large_tensor on M2-15 runners (#150377)
They have more than 12Gb memory, but may be running this test causes OOM in CI
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150377
Approved by: https://github.com/atalman
2025-04-01 02:33:43 +00:00
Davide Italiano
b48505a8a1 [MPS] Add support for hermite_polynomial_h. (#150279)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150279
Approved by: https://github.com/malfet

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
Co-authored-by: Aaron Gokaslan <aaronGokaslan@gmail.com>
2025-03-31 23:30:19 +00:00
Nikita Shulga
7c65911b11 [MPS] Fix dot/mm for conj_tensors (#150157)
- Distinguish between conjugated/non_conjugated inputs by appending conjugation to the operator key
- For matmul or dot, add `conjugateWithTensor:name:` calls before running the op
- Enable testing for conjugated ops by passing `include_conjugated_inputs` to opinfo
- Filter  `include_conjugated_inputs` argument from `sample_inputs_window` (probably should have landed as separate PR)
- Preserve conj property when gathering the views, that fixes `cov` operator

Fixes https://github.com/pytorch/pytorch/issues/148156
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150157
Approved by: https://github.com/dcci
2025-03-28 20:36:44 +00:00
Nikita Shulga
ef1cb6b646 [BE] Suppress user_warnings while running opinfo tests (#150115)
Some of the samples are constructed in a way that are expected to trigger those, but what's the point displaying them
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150115
Approved by: https://github.com/dcci
ghstack dependencies: #150060
2025-03-27 22:36:27 +00:00
Nikita Shulga
6aca002d82 [MPS] Add chebyshev_polynomial_[uvw] (#150060)
For both eager and inductor

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150060
Approved by: https://github.com/dcci, https://github.com/jansel
2025-03-26 23:35:05 +00:00
Nikita Shulga
de68ddc68e [MPS] Fix metal ops with different dtypes (#149974)
By implementing `_cast_` flavors of both dense and strided ops. Add regression tests that tests `fmax`/`fmin` for mixed dtypes.

Been dreaded to write this PR for a while, as it end up to be pretty bulky:
 - Adds 1C10_METAL_ALL_TYPES_FUNCTOR` and `c10:🤘:ScalarType` to `c10/metal/common.h` and test that its values always match `c10::ScalarType`
 - Add `c10:🤘:cast_to` to `c10/metal/utils.h` which could be used to cast any scalar metal dtype to any other one, including complex values
 - Implement `val_at_offs<T>(constant void *, long offs, ScalarType dtype)` that is used to dynamically cast types
 - Add `binary_strided_cast` and `binary_dense_cast` that are invoked for output dtype and cast both inputs to that output before performing the op

Benchmark collected on M2Pro that runs fmax for 1 mln element tensors (Times are in microseconds.)

|                                           |  dense-dense  |  transp-transp  |  dense-transp  |  transp-dense  |  dense-scalar  |  dense-bcast |
|-------------------------|---------------|----------------|----------------|----------------|---------------|--------------- |
|      fmax (torch.float16, torch.float16)  |     160.9     |      159.9      |     270.5      |     270.9      |     236.6      |     293.0
|      fmax (torch.float32, torch.float32)  |     176.9     |      171.0      |     273.7      |     293.5      |     242.6      |     294.2
|      fmax (torch.float32, torch.float16)  |     171.4     |      170.9      |     283.6      |     303.0      |     253.7      |     302.3
|      add (torch.float16, torch.float16)   |     218.0     |      223.6      |     221.0      |     222.0      |     214.9      |     218.3
|      add (torch.float32, torch.float32)   |     227.4     |      233.9      |     228.8      |     231.9      |     218.9      |     221.4
|      add (torch.float32, torch.float16)   |     226.1     |      227.5      |     227.5      |     226.9      |     177.0      |     190.8

TODOS:
 - Include input and output dtype in non-cast kernel name
 - Make TensorFactory.h use `C10_METAL_ALL_TYPES_FUNCTOR`
- Extend mixed_dytpes testing via OpInfo

Fixes https://github.com/pytorch/pytorch/issues/149951
Pull Request resolved: https://github.com/pytorch/pytorch/pull/149974
Approved by: https://github.com/manuelcandales
2025-03-26 07:03:21 +00:00
Isalia20
ba46643df1 [MPS] tril op not handling infs correctly (#149866)
Fixes #149813

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149866
Approved by: https://github.com/malfet
2025-03-24 23:38:41 +00:00
Davide Italiano
9179178728 [MPS] Add support for chebyshev_polynomial_t in eager. (#149816)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/149816
Approved by: https://github.com/malfet
2025-03-24 19:19:55 +00:00
Isalia20
248487f455 [MPS] nanmedian with dims (#149680)
Third most voted op from #77764

Tests were deleted because they are covered by the regular test_output_match tests so those were redundant and were added in the last PR before the nanmedian dim version would be implemented

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149680
Approved by: https://github.com/malfet

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
2025-03-24 03:49:16 +00:00
Davide Italiano
b9a5e1d038 [MPS] Add support for scaled_modified_bessel_k1 to eager. (#149783)
Another day another op

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149783
Approved by: https://github.com/malfet
2025-03-22 02:13:41 +00:00
Davide Italiano
bdc132d0e1 [MPS] Add support for scaled_modified_bessel_k0 for eager. (#149705)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/149705
Approved by: https://github.com/malfet
2025-03-21 16:14:29 +00:00
Davide Italiano
0ed34210b2 [MPS] Add support for modified_bessel_k1 to eager and inductor. (#149687)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/149687
Approved by: https://github.com/malfet
2025-03-21 04:59:06 +00:00
Isalia20
95e71765f2 [MPS] nanmedian implementation (#149407)
Implements nanmedian on MPS. This implementation only implements `torch.nanmedian(tensor)` without `keepdim` and `dim`
Will implement nanmedian with dim and keepdim in a followup

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149407
Approved by: https://github.com/malfet
2025-03-20 03:50:26 +00:00
Davide Italiano
88c2fe533f [MPS] Add modified_bessel_k0 support to eager. (#149563)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/149563
Approved by: https://github.com/malfet
2025-03-19 23:10:55 +00:00
Nikita Shulga
2e0c98ff05 [MPS] Add bicubic2d_aa (#149378)
Which is currently the most frequently requested op in https://github.com/pytorch/pytorch/issues/141287

Mostly done by refactoring `upsample_bilinear2d_aa` to accept Functor as one of the template arguments, which closely ideas from eec43cfbc0/src/libImaging/Resample.c as well as
bb42e4d137/aten/src/ATen/native/cuda/UpSampleBilinear2d.cu (L472-L478)

Populate unit tests by copying upsample_bilinear_2d_aa and reusing it as upsample_bicubic2d_aa

At that point, only difference between upsample_bilinear2d_aa and upsample_bicubic2d_aa are convolution kernel function and size: for bilinear it's 3x3, for bicubic it's 5x5
Pull Request resolved: https://github.com/pytorch/pytorch/pull/149378
Approved by: https://github.com/dcci
2025-03-18 05:35:41 +00:00
Davide Italiano
c43e35d6f7 [MPS] Implement support for modified_bessel_i1 in eager. (#149368)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/149368
Approved by: https://github.com/malfet

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
2025-03-18 03:29:10 +00:00
Davide Italiano
186cc7327c [MPS/BE] Remove decorator that skipped test on macOS 12. (#149365)
macOS 12 is not really supported anymore.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149365
Approved by: https://github.com/malfet
2025-03-18 00:58:08 +00:00
Davide Italiano
9f33c6f0a0 [MPS] Add support for modified_bessel_i0 in eager. (#149264)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/149264
Approved by: https://github.com/malfet

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
2025-03-16 04:45:49 +00:00
Nikita Shulga
96795e9533 [BE] Parametrize TestMPS.test_binops_dtype_precedence (#149234)
No op change, just splits a longer tests into a series of a smaller ones
Pull Request resolved: https://github.com/pytorch/pytorch/pull/149234
Approved by: https://github.com/atalman, https://github.com/dcci
ghstack dependencies: #149216, #149233
2025-03-15 00:37:11 +00:00
Isalia20
dd6e9df3d0 [MPS] fix attention enable_gqa crash on mps (#149147)
Fixes #149132

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149147
Approved by: https://github.com/malfet
2025-03-14 21:25:54 +00:00
Nikita Shulga
f2221b2fce [MPS] Add support for i1e (#149203)
Followup after https://github.com/pytorch/pytorch/pull/149174
Pull Request resolved: https://github.com/pytorch/pytorch/pull/149203
Approved by: https://github.com/dcci
2025-03-14 17:33:52 +00:00
cyy
a9aae05a6b Remove test decorations on MacOS 12 (#148942)
MacOS 12 may reach EOL, as from https://endoflife.date/macos
Pull Request resolved: https://github.com/pytorch/pytorch/pull/148942
Approved by: https://github.com/malfet
2025-03-14 17:22:37 +00:00
Davide Italiano
706c22549c [MPS] Add support for i0e in eager. (#149174)
Add `special.i0e` to XFAIL_GRADLIST for now, as its backward op is not yet implemented
Pull Request resolved: https://github.com/pytorch/pytorch/pull/149174
Approved by: https://github.com/malfet

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
2025-03-14 14:43:46 +00:00
PyTorch MergeBot
be4e6c1c8e Revert "[MPS] Add support for i0e in eager. (#149174)"
This reverts commit b4745db904.

Reverted https://github.com/pytorch/pytorch/pull/149174 on behalf of https://github.com/malfet due to MPS are red on trunk ([comment](https://github.com/pytorch/pytorch/pull/149174#issuecomment-2723774600))
2025-03-14 06:35:01 +00:00
Nikita Shulga
db6d72213b [MPS] Add torch.special.bessel_[jy][01] implementations (#149123)
By copy-n-pasting functions from
f59064f2b7/aten/src/ATen/native/cuda/Math.cuh (L1463)

With an  ugly workaround for `bessel_y[01]` to avoid internal compiler exception on M1/M2 machines (see FB16863363 /  https://gist.github.com/malfet/e7785e4b572e7740887a83a2386ef769 )
Pull Request resolved: https://github.com/pytorch/pytorch/pull/149123
Approved by: https://github.com/Skylion007, https://github.com/dcci
2025-03-14 05:13:55 +00:00
Davide Italiano
b4745db904 [MPS] Add support for i0e in eager. (#149174)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/149174
Approved by: https://github.com/malfet
2025-03-14 02:51:28 +00:00
Nikita Shulga
924a247fbb [MPS] Enable angle and atan2 for torch.long (#149017)
This check was added by https://github.com/pytorch/pytorch/pull/85817, that introduced no unit-tests and its content seems to be totally unrelated to title/subject of that PR. Anyway, right now it seems to be working fine on MacOS-13+

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149017
Approved by: https://github.com/dcci
2025-03-12 04:48:52 +00:00
Nikita Shulga
c18858d633 [MPS] Make torch.mps.compile_shader public (#148972)
It was a private method in 2.6, but nothin changes in its API for 2.7
and it will likely remain the same in 2.8, so time to remove underscore
from its name

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148972
Approved by: https://github.com/Skylion007, https://github.com/atalman, https://github.com/seemethere, https://github.com/albanD, https://github.com/dcci
2025-03-11 20:20:58 +00:00
Nikita Shulga
b95889042c [MPS] Introduce strides unary op (#148468)
By adding following template
```metal
template <typename T, typename F>
kernel void unary_strided(
    device result_of<F, T>* output [[buffer(0)]],
    constant T* input [[buffer(1)]],
    constant long* sizes [[buffer(2)]],
    constant long* input_strides [[buffer(3)]],
    constant long* output_strides [[buffer(4)]],
    constant uint& ndim,
    uint index [[thread_position_in_grid]]) {
  F f;
  int pos[max_ndim];
  pos_from_thread_index(int(index), pos, sizes, ndim);
  const auto input_offs = offset_from_coord(pos, input_strides, ndim);
  const auto output_offs = offset_from_coord(pos, output_strides, ndim);
  output[output_offs] = f(input[input_offs]);
}
```
and instantiating it for all existing unary shaders, which eliminates the need to any intermediate copies.
No extra testing are needed as those cases are already covered by `test_output_grad_match_corrcoef_cpu_float32` as well as `test_unary_ops_storage_offset_strided`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/148468
Approved by: https://github.com/dcci
2025-03-09 22:30:51 +00:00
Nikita Shulga
da923afdc7 [MPS][BE] Align bitshift behavior with CPU (#148719)
By casting the argument to output type
Pull Request resolved: https://github.com/pytorch/pytorch/pull/148719
Approved by: https://github.com/Skylion007
ghstack dependencies: #148685, #148686
2025-03-07 18:28:14 +00:00
Nikita Shulga
f84710aef4 [MPS] Fix scalar to tensors bitshifts (#148686)
By introducing a concept of non-commutative binary op and renaming all op templates from `bitwise_foo_tensor` and `bitwise_foo_scalar` to `bitwise_foo_tensor_tensor` and `bitwise_foo_tensor_scalar`

Add regression tests

Please note, that for some undefined values MPS and CPU behaviors are different, for example
```
>>> import torch
>>> 4095 >> torch.arange(12, device="mps", dtype=torch.uint8)
tensor([255, 255, 255, 255, 255, 127,  63,  31,  15,   7,   3,   1],
       device='mps:0', dtype=torch.uint8)
>>> 4095 >> torch.arange(12, device="cpu", dtype=torch.uint8)
tensor([255, 127,  63,  31,  15,   7,   3,   1,   0,   0,   0,   0],
       dtype=torch.uint8)
```
Because on CPU scalar is cast to output dtype before operation is performed, but on MPS this happens after the op is done

Fixes https://github.com/pytorch/pytorch/issues/147889
Pull Request resolved: https://github.com/pytorch/pytorch/pull/148686
Approved by: https://github.com/albanD
ghstack dependencies: #148685
2025-03-07 18:28:14 +00:00
Isalia20
02e1580e39 [MPS] fix crash for mse loss with 0 numel inputs (#148608)
Fixes #148589

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148608
Approved by: https://github.com/malfet
2025-03-06 03:32:34 +00:00
Nikita Shulga
864b75dd50 [MPS] Fix unary_kernel_strided logic (#148512)
Fixes bug introduced by https://github.com/pytorch/pytorch/pull/148350
Before this change
```
% python3 -c "import torch; x, y = torch.arange(128.0, device='mps').reshape(2, 8, 8).unbind(0); print(torch.sqrt(x[::2, ::2], out=y[::2, ::2]))"
tensor([[  0.0000,   1.4142,   2.0000,   2.4495],
        [ 80.0000,  82.0000,  84.0000,  86.0000],
        [ 96.0000,  98.0000, 100.0000, 102.0000],
        [112.0000, 114.0000, 116.0000, 118.0000]], device='mps:0')
```
After this change
```
% python3 -c "import torch; x, y = torch.arange(128.0, device='mps').reshape(2, 8, 8).unbind(0); print(torch.sqrt(x[::2, ::2], out=y[::2, ::2]))"
tensor([[0.0000, 1.4142, 2.0000, 2.4495],
        [4.0000, 4.2426, 4.4721, 4.6904],
        [5.6569, 5.8310, 6.0000, 6.1644],
        [6.9282, 7.0711, 7.2111, 7.3485]], device='mps:0')
```
One can not avoid copies if both input and output tensors have the same strides, one needs to make sure that they are dense-in-storage (transposed tensor would be dense, but say selecting every odd and even column wouldn't)

Add regression test to prevent those from happening again

Also, no need to check that sizes match, luckily it is checked by the structured op (and `out` for unary ops does not support broadcasting, I just checked)

Revived needs_copy_logic, though it  will become irrelevant after https://github.com/pytorch/pytorch/pull/148468 is landed

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148512
Approved by: https://github.com/janeyx99
2025-03-05 15:57:54 +00:00
Isalia20
0c0a4baddd [MPS] unary kernels - avoid copying tensors if they have same stride (#148350)
I was a bit concerned when I saw in #148272 that metal unary kernel was 0.02x of the performance of what we had with MPS Graphs for sqrt(for non contiguous) tensors. This change makes it so that copying is only done if we don't have same strided tensors(for input/output). So if out tensor is not provided then we don't do copy(don't call contiguous) at all and dispatch the kernel as is. After making this change the script that I listed at the end of the above PR has the same execution time as the non-transposed one.

Times for reference(on transposed tensor where matrix is NxN matrix):

| N     | time_old           | time_new           |
|-------|--------------------|--------------------|
| 100   | 0.0002241021       | 0.0001548659       |
| 1000  | 0.0005934822       | 0.0002150342       |
| 10000 | 0.3242016407       | 0.0045755033       |

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148350
Approved by: https://github.com/janeyx99
2025-03-04 23:20:26 +00:00
Isalia20
439395c0ae [MPS] add slogdet and logdet implementations to mps (#148287)
Low hanging fruits, all ops for these are implemented so just adding them to native functions adds the functionality on mps. Probably next op I should add should be lu solve seeing as how many ops need it for the grad calculation

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148287
Approved by: https://github.com/malfet
2025-03-04 19:49:23 +00:00
Nikita Shulga
84502baaff [MPS] Fix sqrt and other for torch.chalf (#148285)
Those kernels, instead of being instantiated for half2 (which corresponds to ComplexHalf) were instnatiated for short2, which resuled in the following test
```
% python3 -c "import torch; print(torch.rand(6, device='mps', dtype=torch.chalf).sqrt())"
```
Fail with
```
RuntimeError: Failed to create function state object for: sqrt_complex_half_half
```

As sqrt is not implemented for CPU, add explicit test to `test_sqrt`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/148285
Approved by: https://github.com/dcci
2025-03-03 16:03:54 +00:00
Isalia20
19de523de6 [MPS] metal unary kernel for sqrt (#148272)
Issue #148219 highlighted the high dispatch times of ops which ran with MPS Graph on smaller tensors. This PR rewrites the sqrt with metal kernel to mitigate that issue

## Speedups:

Matrix size means NxN matrix here.
![speedup_sqrt](https://github.com/user-attachments/assets/db0a705b-1a0e-42b4-bd42-4e7960415c81)

Code to generate the times(needs building the torch with old time and new time):
```python
import torch
import numpy as np
import time
import csv

matrix_sizes = [1, 100, 1000, 10_000]
num_runs = 1000
warmup_runs = 3

def run_sqrt(A):
    torch.mps.synchronize()
    start = time.perf_counter()
    c = torch.sqrt(A)
    torch.mps.synchronize()
    end = time.perf_counter()
    return c, end - start

results = {
    'N': [],
    'mean_time': [],
    'std_time': []
}

for n in matrix_sizes:
    print(f"\nBenchmarking N={n}")

    try:
        A_mps = torch.rand((n, n), dtype=torch.float32, device="mps")

        for _ in range(warmup_runs):
            _, _ = run_sqrt(A_mps)

        times = []
        for _ in range(num_runs):
            _, t = run_sqrt(A_mps)
            times.append(t)

        mean_time = np.mean(times)
        std_time = np.std(times)

        results['N'].append(n)
        results['mean_time'].append(mean_time)
        results['std_time'].append(std_time)

        print(f"Mean time: {mean_time:.4f}s ± {std_time:.4f}s")

    except RuntimeError as e:
        print(f"Error for N={n}: {e}")
        continue

with open('sqrt_benchmark_times_new.csv', 'w', newline='') as f:
    writer = csv.writer(f)
    writer.writerow(['N', 'mean_time', 'std_time'])
    for i in range(len(results['N'])):
        writer.writerow([
            results['N'][i],
            results['mean_time'][i],
            results['std_time'][i]
        ])

```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148272
Approved by: https://github.com/malfet
2025-03-02 00:45:45 +00:00
Nikita Shulga
3a0c9f7f9d [MPS] Fix SDPA crash (#148239)
If operation is invoked with mask twice it will crash, as mask expansion logic was implemented inside cache creation block, which is executed only once for all shapes

Fixes https://github.com/pytorch/pytorch/issues/148194 which is a regression introduced by https://github.com/pytorch/pytorch/pull/147545
Pull Request resolved: https://github.com/pytorch/pytorch/pull/148239
Approved by: https://github.com/dcci
2025-03-01 13:06:51 +00:00
Nikita Shulga
735d7b1af6 [EZ][BE] Increase tolerances for interpolate op (#148224)
Not sure why tolerances were set like that, this logic was added in https://github.com/pytorch/pytorch/pull/104181 without much explanation
But if I'm to make a guess, it's likely due to the inaccuracy of bilinear op, that has since been replaced by shader
Pull Request resolved: https://github.com/pytorch/pytorch/pull/148224
Approved by: https://github.com/Skylion007, https://github.com/dcci
ghstack dependencies: #148154, #148187, #148211
2025-03-01 13:03:59 +00:00
Isalia20
08434df1f2 [MPS] fix empty place holder error for smooth l1 loss (#148133)
Fixes #123171

And parametrizes the tests for it

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148133
Approved by: https://github.com/malfet

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
2025-03-01 02:32:45 +00:00
Nikita Shulga
e5e31050d3 [MPS] Implement linear1d as shader (#148154)
And get rid of MPS call, as for some reason implementation via MPSGraph
API call is 100x+ times slower that Metal shader, at least according to
the following benchmark
```python
import torch
import time
import subprocess

def benchmark(device, dtype):
    # Create example inputs
    x = torch.testing.make_tensor(3, 5, 65536, device=device, dtype=dtype)
    sf = .5

    # Check output
    y = torch.nn.functional.interpolate(x, scale_factor=sf, mode="linear")
    z = torch.nn.functional.interpolate(x.cpu(), scale_factor=sf, mode="linear")
    outputs_match = torch.allclose(y.cpu(), z)
    if not outputs_match:
       atol = (y.cpu() - z).abs().max()
       rtol = ((y.cpu() - z)[z!=0]/z[z!=0]).abs().max()
       print(f"atol={atol} rtol={rtol}")

    # Measure time manually
    start_time = time.time() * 1000
    for _ in range(1000):
        y = torch.nn.functional.interpolate(x, scale_factor=sf, mode="linear")
    torch.mps.synchronize
    end_time = time.time() * 1000
    manual_delta = (end_time - start_time)
    average_time = f"{manual_delta:6.1f}"

    return "True " if outputs_match else "False", average_time

outputs_match_list = []
average_time_list = []
for device in ["mps", "cpu"]:
    for dtype in [torch.float32, torch.float16, torch.bfloat16]:
        outputs_match, average_time = benchmark(device, dtype)
        outputs_match_list.append(str(outputs_match))
        average_time_list.append(average_time)

brand_string = subprocess.check_output(['sysctl', '-n', 'machdep.cpu.brand_string']).decode("utf-8").strip()
print(f"\nBenchmarking Results (collected on {brand_string}):")
print("-"*40)
print("Device            :                MPS        |               CPU")
print("Dtype             :   FP32  |  FP16  |  BF16  |  FP32  |  FP16  |  BF16  ")
print(f"Outputs Match     :  ", " |  ".join(outputs_match_list))
print(f"Average Time (us) :", "  |".join(average_time_list))
```

Benchmark results after the change
```
Benchmarking Results (collected on Apple M2 Pro):
----------------------------------------
Device            :                MPS        |               CPU
Dtype             :   FP32  |  FP16  |  BF16  |  FP32  |  FP16  |  BF16
Outputs Match     :   True  |  True  |  True  |  True  |  True  |  True
Average Time (us) :    2.5  |   2.1  |   2.2  | 161.4  | 115.0  | 161.1
```
And before the change
```
Benchmarking Results (collected on Apple M2 Pro):
----------------------------------------
Device            :                MPS        |               CPU
Dtype             :   FP32  |  FP16  |  BF16  |  FP32  |  FP16  |  BF16
Outputs Match     :   True  |  True  |  True  |  True  |  True  |  True
Average Time (us) :  354.0  | 336.0  | 332.4  | 145.5  | 114.7  | 148.3
```

Fixes https://github.com/pytorch/pytorch/issues/144245
Pull Request resolved: https://github.com/pytorch/pytorch/pull/148154
Approved by: https://github.com/dcci
2025-02-28 16:47:42 +00:00
Davide Italiano
683e083e8d [MPS] Add support for entr() in eager. (#147948)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/147948
Approved by: https://github.com/malfet
2025-02-26 19:55:02 +00:00
Nikita Shulga
00732c3f7e [MPS] Implemented masked_fill_scalar as shader (#147369)
- Move `pos_from_thread_index and `offset_from_pos` from `UnfoldBackward.metal` into `c10/metal/indexing.h` header
- Initial idea were to implement `StridedTensor` and `ConstStridedTensor` and use them to have masked_fill kernel a something simple as the following loop
```metal
ConstStridedTensor<bool> mask(mask_data, sizes, mask_strides, ndim);
if (mask[thread_index]) {
  StridedTensor<T> input(input_data, sizes, input_strides, ndim);
  input[thread_index] = val;
}
```
But though it looks elegant and works correctly, performance wise it's much slower that the existing MPS shader (see table below), as int64 divisions on M2 GPU are really slow

- Solved performance issue by implementing 3 flavors of the same shader: `dense`, that is used when both input and mask are dense tensors of the same size, `broadcast`, which is used when `mask` is leading dimensions expandable into input tensor and `strided`  which is a general purpose fallback, but still computes position in the tensors only ones. As result, perf is even better than existing MPS shader for dense and broadcast able tensors.

Performance measured on M2Pro thru different iterations of the same shader

| dtype | MPS | int64-idx | int64-inlined | 32-bit strided | 32-bit broadcasted |
| ------|------| -----|   ---- | --- | ---- |
| float32 | 2.8 msec  | 41.6 msec | 26.9 msec | 5 msec | 2.4 msec |
| float16 | 1.86 msec | 38.2 msec| 26.6 msec | 4.6 msec | 1.9 msec |
|bfloat16|1.86 msec |38.3 msec | 26.6 msec | 4.6 msec | 1.9 msec |

And benchmark script
```python
import torch

from timeit import default_timer
from itertools import product
from torch.utils.benchmark import Measurement, Timer

def bench_mask_fill(
    n,
    binary_func,
    dtype=torch.float32,
) -> Measurement:
    t = Timer(
        stmt=f"x.masked_fill(y, -17.0); torch.mps.synchronize()",
        setup=f"x,y = torch.rand(1, 20, {n}, {n}, dtype={dtype}, device='mps'), torch.ones({n}, {n}, device='mps').triu().bool()",
        globals = {'f': binary_func},
        language="python", timer=default_timer
    )
    return t.blocked_autorange()

if __name__ == "__main__":
    n = 1024
    for dtype in [torch.float32, torch.float16, torch.bfloat16]:
        eager_t = bench_mask_fill(n, torch.fmax, dtype)
        use_msec = eager_t.mean > 1e-4
        multiplier = 1e3 if use_msec else 1e6
        uname = "msec" if use_msec else "usec"
        print(f"torch.masked_fill_() {str(dtype):>14} {eager_t.mean*multiplier:>7.2f} {uname}")
```
Fixes https://github.com/pytorch/pytorch/issues/143477
Pull Request resolved: https://github.com/pytorch/pytorch/pull/147369
Approved by: https://github.com/dcci
ghstack dependencies: #147977
2025-02-26 18:39:15 +00:00
Nikita Shulga
9ed40af917 [BE][EZ] Delete MacOS-12.3 xfail list (#147905)
As PyTorch requires at least MacOS-13 (and Metal-3) to work, delete any pre-MacoS13 checks from test script
Pull Request resolved: https://github.com/pytorch/pytorch/pull/147905
Approved by: https://github.com/dcci
ghstack dependencies: #147892
2025-02-26 05:08:09 +00:00
Nikita Shulga
346bbefa63 [BE] Parameterize TestSDPA in test_mps.py (#147856)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/147856
Approved by: https://github.com/Skylion007
2025-02-25 16:07:24 +00:00
Isalia20
a695aae89b [MPS] fix attention for >4d tensors (#147545)
Fixes #147443

and adds tests for >4d tensors
Pull Request resolved: https://github.com/pytorch/pytorch/pull/147545
Approved by: https://github.com/malfet

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
2025-02-25 13:55:28 +00:00
Davide Italiano
4e934ee5a7 [MPS] Add eager support for xlog1py. (#147687)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/147687
Approved by: https://github.com/malfet
2025-02-24 01:23:59 +00:00
Nikita Shulga
f03e7f3801 [MPS] Workaround rng bug for 5D tensors (#147667)
For some reason MPSGraph returns repeated values is tensor dimention is
larger than 4, which can be clearly seen by running following
```swift
import Metal
import MetalPerformanceShadersGraph

func randMPS(device: MTLDevice, obuf: MTLBuffer, nelem: Int, ndim: Int = 5) {
  let graph = MPSGraph()
  var dims = Array(repeating: 1, count: ndim)
  dims[0] = nelem
  let shape = dims.map { NSNumber(value: $0) }
  let randNode = graph.randomUniformTensor(withShape: shape, seed: 42, name: nil)
  let mpsOutputBuffer = MPSGraphTensorData(obuf, shape: shape, dataType: .float32)
  guard let queue = device.makeCommandQueue() else { fatalError("Can't make queue") }
  graph.run(with: queue, feeds: [:], targetOperations: nil, resultsDictionary: [randNode: mpsOutputBuffer])
}

func printBuf(_ prefix: String, buf: MTLBuffer, nelem: Int) {
  let buf_data = buf.contents().assumingMemoryBound(to: Float.self)
  print(prefix)
  for i in 0..<nelem {
      print(buf_data[i], terminator: i != nelem - 1 ? " " : "\n")
  }
}

guard let device = MTLCopyAllDevices().first else { fatalError("Not Metal device found") }
print("Using device \(device.name)")

let nelem = 2
guard let buf = device.makeBuffer(length:nelem * MemoryLayout<Float>.size, options: [.storageModeShared]) else { fatalError("Can't alloc") }

randMPS(device: device, obuf: buf, nelem: nelem, ndim: 4)
printBuf("4D uniform", buf: buf, nelem: nelem)

randMPS(device: device, obuf: buf, nelem: nelem, ndim: 5)
printBuf("5D uniform", buf: buf, nelem: nelem)
```

Workaround by flatting the tensor if it's contiguous

Fixes https://github.com/pytorch/pytorch/issues/147624
Pull Request resolved: https://github.com/pytorch/pytorch/pull/147667
Approved by: https://github.com/dcci
2025-02-23 16:52:01 +00:00
Nikita Shulga
198ffbdf11 [MPS] Implement and test round.decimals (#147266)
If inductor can do it, why not eager
Pull Request resolved: https://github.com/pytorch/pytorch/pull/147266
Approved by: https://github.com/Skylion007
ghstack dependencies: #147286
2025-02-16 23:17:13 +00:00
tim
b9a22b3f37 bug fix: ensure 4d input in _scaled_dot_product_attention_math_mps (#146623)
This pr addresses the issue in the MPS backend for `_scaled_dot_product_attention_math_mps` where a 3d input like (num_heads, seq_len, query_dim) cannot be automatically treated as (1, num_heads, seq_len, query_dim), which can be inferred on cpu or cuda, which can be circumvented by adding a util function to ensure a 4d shape.

The issue was found in https://github.com/hiyouga/LLaMA-Factory/issues/6835, in [transformers qwen2_vl](1590c66430/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py (L373C14-L373C93)), 3d q/k/v were passed into sdpa function, which lead to an error.

Considering consistency, since this pattern might pop up elsewhere in the transformers codebase, I think it makes more sense to maintain the same intuition across all platforms.

---
reproduce code:
```
import torch
import torch.nn.functional as F

head_num, seq_len, embed_dim = 16, 16, 80
bsz = 1

q = torch.randn(head_num, seq_len, embed_dim)
k = torch.randn(head_num, seq_len, embed_dim)
v = torch.randn(head_num, seq_len, embed_dim)
attention_mask = torch.ones(1, seq_len, seq_len)

oo_cpu = F.scaled_dot_product_attention(
    q.to("cpu"),
    k.to("cpu"),
    v.to("cpu"),
    attention_mask.to("cpu"),
    dropout_p=0.0
)

if torch.backends.mps.is_available():
    oo_mps = F.scaled_dot_product_attention(
        q.to("mps"),
        k.to("mps"),
        v.to("mps"),
        attention_mask.to("mps"),
        dropout_p=0.0
    )
    assert torch.allclose(oo_cpu, oo_mps.to("cpu"), atol=1e-5)
```

error outputs:
```
Traceback (most recent call last):
  File "/opt/homebrew/Caskroom/miniconda/base/envs/torch-dev/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3577, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-2-5169b8d2c5dd>", line 21, in <module>
    oo_mps = F.scaled_dot_product_attention(
IndexError: Dimension out of range (expected to be in range of [-3, 2], but got 3)
```

hardware and envs:
```
torch               2.6.0
apple m3 max
```

---

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146623
Approved by: https://github.com/malfet

Co-authored-by: Aaron Gokaslan <aaronGokaslan@gmail.com>
Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
2025-02-13 07:00:51 +00:00
Isalia20
17a808557c [MPS] cholesky ex version (#146799)
PR #145701 didn't have experimental version of cholesky. This PR adds that version

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146799
Approved by: https://github.com/malfet
2025-02-13 07:00:21 +00:00
Isalia20
d763093b49 [MPS] fix lu factor for large tensors with bs>1 (#146753)
Try this:
```python
import torch

batch_size = 2
A = torch.eye(256, device="mps")[None, :, :].expand(batch_size, -1, -1) + 0.1 * torch.randn((batch_size, 256, 256), device="mps")
A_cpu = A.cpu()
LU_cpu, pivots_cpu = torch.linalg.lu_factor(A_cpu)
LU, pivots = torch.linalg.lu_factor(A)
torch.testing.assert_close(LU.cpu(), LU_cpu)
```
You'll get huge difference in LU tensors
<img width="706" alt="Screenshot 2025-02-08 at 12 14 39" src="https://github.com/user-attachments/assets/b45f2b3c-e0a5-49c8-aa07-42792150b781" />

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146753
Approved by: https://github.com/malfet

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
2025-02-11 00:37:07 +00:00
Davide Italiano
dfe3b64282 [mps] Implement eager support for spherical_bessel_j0 (#146818)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146818
Approved by: https://github.com/malfet

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
2025-02-10 16:58:05 +00:00
Nikita Shulga
611ca163fd [MPS] Add bilineard2d_aa implementation (#145526)
Interesting quirk of the algorithm, that is not very well documented, is that value of align_corners is ignored in antialias mode, see arguments of
e8304f08fe/aten/src/ATen/native/cpu/UpSampleKernel.cpp (L747-L751)

Error out on  uint8 implementation(as it relies on a very fragile integer integer arithmetic), as it's not implemented on any other Accelerator devices at the moment.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145526
Approved by: https://github.com/dcci
2025-02-10 15:03:14 +00:00
Isalia20
0ab67299c3 [MPS] lu unpack (#146681)
Implements lu unpack function on MPS. Haven't added new tests because they are covered by removing the lu_unpack from UNIMPLEMENTED_XFAILLIST in test_mps with `test_output_match` function
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146681
Approved by: https://github.com/malfet

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
2025-02-08 00:16:17 +00:00
Nikita Shulga
624d94bdb8 [MPS] Extend torch.special.sinc to complex (#146648)
And to integral data types as well

Was too lazy to deduce the formula myself(or write a sympy script), but ChatGPT did a decent job of doing it, though it forgot that input must be multiplied by $$\pi$$:
```math
\text{Re}\left(\text{sinc}(x + i y)\right) = \frac{\sin(x)\cosh(y) x - \cos(x)\sinh(y) y}{x^2 + y^2}
```
```math
\text{Im}\left(\text{sinc}(x + i y)\right) = \frac{\cos(x)\sinh(y) x + \sin(x)\cosh(y) y}{x^2 + y^2}
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146648
Approved by: https://github.com/dcci
2025-02-07 01:12:37 +00:00
Davide Italiano
46390e9a37 [mps] Implement support for sinc() operator (inductor and eager). (#146539)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146539
Approved by: https://github.com/malfet, https://github.com/jansel

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
2025-02-06 16:37:27 +00:00
Isalia20
0dc03134d9 [MPS] linalg solve implementation (#146531)
Fixes #98222

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146531
Approved by: https://github.com/malfet

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
2025-02-06 00:57:49 +00:00
Davide Italiano
8a2000fd42 [MPS] Implement support for zeta (both eager and inductor). (#146465)
A test was failing in inductor (`test_pointwise_zeta`) -- and I realized the operation was missing also from eager.
Implemented for both, leveraging the kernel. Happy to split in two (one PR for eager, one for inductor) if folks prefer.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146465
Approved by: https://github.com/malfet
2025-02-05 13:55:50 +00:00
Nikita Shulga
aafaf4016f [MPS] Add error checking when dispatching kernel (#146458)
That thread-group size should not exceed maximum thread group size
Add regression test to validate that
Make failures like https://github.com/pytorch/pytorch/issues/146430 much easier to detect
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146458
Approved by: https://github.com/dcci
2025-02-05 02:56:40 +00:00
Isalia20
e3643e1e0e [MPS] Add linalg det and fix lu factor for non contiguous tensors (#146279)
Requested in #77764

This PR adds support for linalg.det on MPS and fixes lu factor for non contiguous tensors, current implementation crashed on any kind of non-contiguous tensor with an error:
```
-[AGXG13XFamilyCommandBuffer blitCommandEncoderCommon:]:833: failed assertion `A command encoder is already encoding to this command buffer'
zsh: abort      python det.py
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146279
Approved by: https://github.com/malfet

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
2025-02-03 06:06:43 +00:00
Isalia20
5d55a6585d [MPS] lu factor ex implementation (#144651)
Implements `torch.linalg.lu_factor_ex`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144651
Approved by: https://github.com/malfet
2025-02-02 15:09:49 +00:00
Nikita Shulga
99a0940991 [MPS] Fix regression in con-contig bitwise ops (#146085)
Caused by https://github.com/pytorch/pytorch/pull/128393 that change semantic of `needsGather`, which resulted in silent correctness errors on MacOS-15+ if output tensor is non-contiguous

Fixes https://github.com/pytorch/pytorch/issues/145203

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146085
Approved by: https://github.com/dcci
2025-01-30 22:36:56 +00:00
Nikita Shulga
1fdb4d65c0 [MPS] Extend torch.mm/torch.bmm to integral types (#145809)
By using `naive_mm` kernel, but make sure that accumulation is done over int32 for smaller int types (and float for half and bfloat) as well as adding `navie_bmm` that follows the same pattern.
Remove stale restriction on `torch.dot` (which works fine on MacOS-14/15)
This also enables integer op flavors for:
- `addmv`
- `einsum`
- `inner`
- `linalg.multi_dot`
- `matmul`
- `mv`
- `tensordot`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145809
Approved by: https://github.com/dcci
2025-01-30 19:35:25 +00:00