Commit Graph

727 Commits

Author SHA1 Message Date
Kurt Mohler
20149080f2 [MPS] Compute offset2bag/bag_size/max_indices in _embedding_bag (#163281)
Part of #162270

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163281
Approved by: https://github.com/malfet
2025-09-23 22:30:48 +00:00
joshuamarkovic
559e8d1c20 [doc]: Small typos (#162982)
Small typo fixes

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162982
Approved by: https://github.com/ezyang, https://github.com/zou3519
2025-09-16 17:42:19 +00:00
Nikita Shulga
d25c35d2b2 [MPS] Fix [nan]median output for empty tensors (#162846)
It should be `NaN` rather than 0

Added respective checks to `test_empty_tensor`
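
For illustration, a minimal sketch of the expected behavior (expected values follow the commit description; not part of the PR):
```python
import torch

x = torch.empty(0, device="mps")
print(torch.median(x))     # expected: NaN (previously 0 on MPS)
print(torch.nanmedian(x))  # expected: NaN
```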

Fixes https://github.com/pytorch/pytorch/issues/162798
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162846
Approved by: https://github.com/dcci
2025-09-12 22:26:29 +00:00
PyTorch MergeBot
468c1f9e9d Revert "[nn] Assert parsed iterable arguments are an appropriate length (#162340)"
This reverts commit b5e6e58050.

Reverted https://github.com/pytorch/pytorch/pull/162340 on behalf of https://github.com/huydhn due to Sorry for reverting your change but it seems to break an MPS tests on ExecuTorch ([comment](https://github.com/pytorch/pytorch/pull/162340#issuecomment-3282676242))
2025-09-11 21:22:57 +00:00
Benjamin Glass
b5e6e58050 [nn] Assert parsed iterable arguments are an appropriate length (#162340)
Fixes #162327
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162340
Approved by: https://github.com/Skylion007
2025-09-10 15:15:49 +00:00
Kurt Mohler
583bbf7761 [MPS] Add native_dropout and native_dropout_backward (#162108)
Fixes #162002
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162108
Approved by: https://github.com/malfet
2025-09-09 01:44:06 +00:00
Isalia20
dcf385395d [MPS] Move sparsemps testing from test_mps to test_sparse (#161852)
Moves Sparse MPS testing from test_mps to test_sparse. There are lots of skips for now, but I expect to remove them iteratively as ops are implemented.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161852
Approved by: https://github.com/malfet
2025-09-02 19:04:11 +00:00
Isalia20
f3697b033e [MPS] add bunch of unary funcs for sparse tensors (#161846)
Adds a bunch of unary functions for sparse tensors.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161846
Approved by: https://github.com/malfet
2025-08-30 21:13:05 +00:00
Irakli Salia
8627a19adf [MPS] sparse add unary funcs + add for sparse tensors (#160839)
Adds several unary functions and `add`. Enables tests for unary functions in test_sparse, but does not enable other tests yet; more ops are needed before we fully migrate to testing SparseMPS with `test_sparse.py`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160839
Approved by: https://github.com/malfet

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
2025-08-30 01:09:00 +00:00
Nikita Shulga
7c30a9d7fc [MPS] Add slow version of kthvalue (#161817)
The implementation heavily borrows logic from `topk`.
As this method is non-deterministic, the CPU-vs-MPS indices comparison was relaxed to a simple equality check, since the random numbers picked for the input tensor by default produce quite a lot of duplicate values.
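
For illustration only (not part of the PR), a minimal sketch of why an index comparison is fragile here; the input values are made up:
```python
import torch

x = torch.tensor([3., 1., 3., 2.], device="mps")
val, idx = torch.kthvalue(x, 3)   # third-smallest value
# val is deterministic (3.0); idx may legitimately be 0 or 2 because of the duplicate,
# hence the equality-based check rather than an index comparison.
print(val.item(), idx.item())
```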
Pull Request resolved: https://github.com/pytorch/pytorch/pull/161817
Approved by: https://github.com/dcci
2025-08-30 00:44:29 +00:00
Isalia20
3daf20f8e1 [MPS] fix empty input in posneg functions (#161824)
Fixes the posneg functions for empty input on MPS. On the main branch, the following snippet:
```python
import torch

input_tensor = torch.empty(0, device="mps")
out_pos = torch.isposinf(input_tensor)
```

Gives:
```
RuntimeError: [srcBuf length] > 0 INTERNAL ASSERT FAILED at "/Users/Irakli_Salia/Desktop/pytorch/aten/src/ATen/native/mps/OperationUtils.mm":551, please report a bug to PyTorch. Placeholder tensor is empty!
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161824
Approved by: https://github.com/malfet
2025-08-29 23:12:04 +00:00
PyTorch MergeBot
f6368e934e Revert "[MPS] sparse add unary funcs + add for sparse tensors (#160839)"
This reverts commit 93c5112f46.

Reverted https://github.com/pytorch/pytorch/pull/160839 on behalf of https://github.com/atalman due to test_sparse_csr.py::TestSparseCompressedCPU::test_consistency_SparseCSR_asinh_cpu_complex64 [GH job link](https://github.com/pytorch/pytorch/actions/runs/17329155095/job/49201551217) [HUD commit link](93c5112f46) ([comment](https://github.com/pytorch/pytorch/pull/160839#issuecomment-3238093296))
2025-08-29 19:55:39 +00:00
Irakli Salia
93c5112f46 [MPS] sparse add unary funcs + add for sparse tensors (#160839)
Adds several unary functions and `add`. Enables tests for unary functions in test_sparse, but does not enable other tests yet; more ops are needed before we fully migrate to testing SparseMPS with `test_sparse.py`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160839
Approved by: https://github.com/malfet

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
2025-08-29 16:28:58 +00:00
rebeccajae
ee0ec21191 Ensure that tensors are contiguous before using no-graph MPS impl (#161641)
Fixes #161640

Check if tensors are contiguous before using the no-graph implementation. Using the script from the issue above, with this change I get the expected results:

```
MPS contiguous result sample: tensor([ 1.3600, -2.9516,  1.3207, -3.5132,  1.7061], device='mps:0')
MPS non-contig result sample: tensor([ 1.3600, -2.9516,  1.3207, -3.5132,  1.7061], device='mps:0')
CPU non-contig result sample: tensor([ 1.3600, -2.9516,  1.3207, -3.5132,  1.7061])
```
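
For illustration, a minimal sketch of the contiguity concern (not part of the PR; `torch.exp` stands in for whichever no-graph op the issue exercised):
```python
import torch

x = torch.randn(8, 8, device="mps")
x_noncontig = x.t()                      # transposed view, not contiguous
x_contig = x_noncontig.contiguous()
y1 = torch.exp(x_noncontig)              # stand-in op
y2 = torch.exp(x_contig)
print(x_noncontig.is_contiguous(), torch.allclose(y1, y2))  # False True
```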
Pull Request resolved: https://github.com/pytorch/pytorch/pull/161641
Approved by: https://github.com/malfet

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
2025-08-27 22:31:57 +00:00
Kurt Mohler
121afd6a8f [MPS] Update avg_pool2d to use Metal kernel when ceil_mode=True (#161011)
Fixes #160743

The MPS impl of `avg_pool2d` seems to give incorrect results only when `ceil_mode=True`. I wrote a performance measurement script (0ee6e58643/avg_pool_mps/perf_2d.py) which tests a bunch of different cases and also marks the cases where the MPS and CPU results do not match.

I found that if I update `avg_pool2d` to use the new Metal kernel in all cases, that fixes all the mismatches, but it also decreases performance for some of the `ceil_mode=False` cases. So I opted to run the new Metal kernel only when `ceil_mode=True`, which does not significantly decrease performance in any of the cases tested.
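
A minimal sketch of the kind of case this targets (shapes and pooling parameters are assumptions, not from the PR):
```python
import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 6, 6)   # 6x6 input makes ceil_mode change the output size
ref = F.avg_pool2d(x, kernel_size=3, stride=2, ceil_mode=True)
out = F.avg_pool2d(x.to("mps"), kernel_size=3, stride=2, ceil_mode=True).cpu()
print(torch.allclose(ref, out, atol=1e-6))  # expected: True after this change
```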
Pull Request resolved: https://github.com/pytorch/pytorch/pull/161011
Approved by: https://github.com/malfet
2025-08-23 02:36:22 +00:00
can-gaa-hou
cee72119b2 [Test] Adding a testcase for constant_pad_nd (#161259)
Fixes #161066

This PR adds a simple testcase for constant_pad_nd on MPS as mentioned in https://github.com/pytorch/pytorch/pull/161149#issuecomment-3211701274

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161259
Approved by: https://github.com/malfet

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
2025-08-23 01:00:50 +00:00
Nikita Shulga
b0071c65e2 [MPS] Fix error check for torch.var on scalar (#160889)
Fixes https://github.com/pytorch/pytorch/issues/160738
Pull Request resolved: https://github.com/pytorch/pytorch/pull/160889
Approved by: https://github.com/Skylion007
ghstack dependencies: #160850
2025-08-18 17:36:42 +00:00
Kurt Mohler
6382302990 [MPS] Add grid_sampler_3d for MPS (#160541)
This PR adds support for `grid_sampler_3d` for MPS with "bilinear" interpolation.

NOTE: "nearest" interpolation is not yet supported
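
A minimal usage sketch (shapes are assumptions; not part of the PR):
```python
import torch
import torch.nn.functional as F

inp = torch.randn(1, 2, 4, 4, 4, device="mps")           # N, C, D, H, W
grid = torch.rand(1, 3, 3, 3, 3, device="mps") * 2 - 1   # N, D_out, H_out, W_out, 3 in [-1, 1]
out = F.grid_sample(inp, grid, mode="bilinear", align_corners=False)
print(out.shape)   # torch.Size([1, 2, 3, 3, 3])
```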

Fixes #159882
Pull Request resolved: https://github.com/pytorch/pytorch/pull/160541
Approved by: https://github.com/malfet
2025-08-15 16:19:25 +00:00
Nikita Shulga
7d87e358ac Fix MPS conv3d autocast bias dtype mismatch (#160423)
## Summary
- register conv3d with MPS autocast to ensure bias dtypes match under AMP
- add regression test chaining two Conv3d layers on MPS autocast
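
A sketch of the regression scenario (layer sizes are assumptions; the actual test lives in the PR):
```python
import torch

m = torch.nn.Sequential(
    torch.nn.Conv3d(2, 4, 3, padding=1),
    torch.nn.Conv3d(4, 4, 3, padding=1),
).to("mps")
x = torch.randn(1, 2, 8, 8, 8, device="mps")
with torch.autocast(device_type="mps", dtype=torch.float16):
    y = m(x)   # previously the second conv could hit a bias dtype mismatch
print(y.dtype)  # torch.float16
```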

Written by Codex, see https://chatgpt.com/codex/tasks/task_e_689b64192df883278648935963d2776d

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160423
Approved by: https://github.com/dcci
2025-08-13 16:23:21 +00:00
Nikita Shulga
7d2ec704e4 Fix MPS autocast for ConvTranspose3d (#160345)
## Summary
- ensure ConvTranspose3d uses fp32 under MPS autocast
- add MPS autocast test for ConvTranspose3d
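
A sketch of the behavior described above (sizes assumed; the expected dtype follows the summary, not a verified run):
```python
import torch

m = torch.nn.ConvTranspose3d(2, 3, kernel_size=3).to("mps")
x = torch.randn(1, 2, 5, 5, 5, device="mps")
with torch.autocast(device_type="mps", dtype=torch.float16):
    y = m(x)
print(y.dtype)  # expected torch.float32, since the op is kept in fp32 under autocast
```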

Generated by Codex, see https://chatgpt.com/codex/tasks/task_e_689a360388288327a2cac6f55bbfc42c

Fixes https://github.com/pytorch/pytorch/issues/160332

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160345
Approved by: https://github.com/dcci
2025-08-11 21:01:52 +00:00
Nikita Shulga
d25c4f954d [MPS] Type-promote tensor-iterator common dtype (#160334)
Otherwise, calls like `torch.add(FloatTensor, IntTensor, alpha=2)` could be dispatched to different kernels depending on how the common dtype was computed.

Fixes https://github.com/pytorch/pytorch/issues/160208
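
A minimal sketch of the mixed-dtype case (values are made up; not part of the PR):
```python
import torch

a = torch.rand(4, device="mps")               # float32
b = torch.randint(0, 5, (4,), device="mps")   # int64
out = torch.add(a, b, alpha=2)
print(out.dtype, torch.allclose(out.cpu(), torch.add(a.cpu(), b.cpu(), alpha=2)))
```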
Pull Request resolved: https://github.com/pytorch/pytorch/pull/160334
Approved by: https://github.com/Skylion007, https://github.com/dcci
2025-08-11 17:53:56 +00:00
Isalia20
a84b60c0c4 [MPS] Sparse coalesce more dtypes to match cpu (#160254)
Adds more dtypes to sparse coalesce, to match the CPU.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160254
Approved by: https://github.com/malfet
2025-08-10 12:25:18 +00:00
Isalia20
7f4cb4a3e0 [MPS] coalesce for sparse tensors (#159729)
MPS coalesce function for sparse tensors

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159729
Approved by: https://github.com/malfet

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
2025-08-08 13:49:55 +00:00
angelayi
74a754aae9 Add meta kernel for sdpa_math_for_mps (#159695)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159695
Approved by: https://github.com/malfet
ghstack dependencies: #159456
2025-08-05 22:27:06 +00:00
Nikita Shulga
f946b25865 [MPS] Speedup argmax/argmin (#159524)
By using efficient `threadgroup_arg[max|min]` primitives.
- Fixed a bug in `simd_argmax` where the result of `simd_ballot` was prematurely cast to `ushort`, and adjusted the unit test
- Fixed NaN handling in compiled argmax, but it can't be reliably tested, as the MPS (eager) implementation of argmax is buggy

Now, according to `bench_mps_ops.py`, `max(x, dim=0)` is reliably faster than the eager implementation:
```
[---------------------------------------------------------------------------------------------  --------------------------------------------------------------------------------------------]
                           |  eager-512x512  |  compile-512x512  |  eager-1024x1024  |  compile-1024x1024  |  eager-2048x2048  |  compile-2048x2048  |  eager-4096x4096  |  compile-4096x4096
1 threads: ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
      max (torch.float16)  |      285.8      |       272.2       |       422.3       |        354.5        |       721.6       |        683.5        |       2224.0      |        1979.1
      max (torch.float32)  |      300.2      |       267.0       |       389.6       |        342.5        |       769.4       |        682.6        |       2995.7      |        2609.8
      max (torch.int32)    |      299.6      |       275.4       |       390.0       |        361.7        |       758.7       |        686.1        |       3103.4      |        2646.5
      max (torch.int64)    |      297.5      |       275.5       |       417.0       |        382.1        |       856.1       |        722.6        |       5467.7      |        3156.8

```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159524
Approved by: https://github.com/Skylion007, https://github.com/dcci
ghstack dependencies: #158990
2025-07-31 16:18:32 +00:00
Nikita Shulga
1293405c8d [MPS] Add simd_[arg][max|min] (#158990)
And add eager tests for those.
Re-implement `threadgroup_[max|min]` using those functions, as they are significantly faster than before (though still much slower than eager, due to the arg part), which can be verified by running the following script:
```python
import itertools
import timeit
import torch
from torch.utils.benchmark import Compare, Measurement, Timer

def bench_unary_op(func, x, label) -> Measurement:
    sync_cmd = "torch.mps.synchronize()" if "mps" in str(x.device) else ""
    t = Timer(
        stmt=f"f(x);{sync_cmd}",
        globals={"f": func, "x": x},
        language="python",
        timer=timeit.default_timer,
        sub_label=f"{func.__name__} ({str(x.dtype)})",
        description=label,
        env=torch.__version__,
    )
    return t.blocked_autorange()

def bench_reduction(
    reduction_func, device: str = "mps", dtype: torch.dtype = torch.float32
) -> list[Measurement]:
    rc = []

    # Bench 2D with reduction over dim=0
    def f(t):
        return reduction_func(t, dim=0)[0]

    f.__name__ = reduction_func.__name__
    f_c = torch.compile(f, dynamic=False, fullgraph=True)

    for size in (512, 1024, 2048, 4096):
        x = torch.testing.make_tensor(size, size, device=device, dtype=dtype)
        rc_c, rc_e = f(x), f_c(x)
        rc_c, rc_e = (rc_c[0], rc_e[0]) if isinstance(rc_c, tuple) else (rc_c, rc_e)
        rc.append(bench_unary_op(f, x, f"eager-{size}x{size}"))
        rc.append(bench_unary_op(f_c, x, f"compile-{size}x{size}"))
    return rc

def main() -> None:
    #dtypes = [torch.float16, torch.float32, torch.bfloat16, torch.int32, torch.int64]
    dtypes = [torch.float32, torch.int32, torch.int64]

    # Profile reduction ops
    rc = []
    for op, dtype in itertools.product([torch.max], dtypes):
        rc.extend(bench_reduction(op, dtype=dtype))
    Compare(rc).print()

if __name__ == "__main__":
    torch._dynamo.config.cache_size_limit = 2**16
    main()
```

Produces the following table before
```
[---------------------------------------------------------------------------------------------  --------------------------------------------------------------------------------------------]
                           |  eager-512x512  |  compile-512x512  |  eager-1024x1024  |  compile-1024x1024  |  eager-2048x2048  |  compile-2048x2048  |  eager-4096x4096  |  compile-4096x4096
1 threads: ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
      max (torch.float32)  |      297.3      |       531.6       |       394.1       |        2550.5       |       773.0       |        4904.7       |       3647.2      |        9682.0
      max (torch.int32)    |      297.8      |       359.2       |       387.7       |        1179.4       |       768.2       |        2175.0       |       3677.1      |        4495.9
      max (torch.int64)    |      278.7      |       541.4       |       410.2       |        2873.3       |       858.9       |        5620.4       |       6107.2      |       11176.1

Times are in microseconds (us).
```
And after
```
[---------------------------------------------------------------------------------------------  --------------------------------------------------------------------------------------------]
                           |  eager-512x512  |  compile-512x512  |  eager-1024x1024  |  compile-1024x1024  |  eager-2048x2048  |  compile-2048x2048  |  eager-4096x4096  |  compile-4096x4096
1 threads: ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
      max (torch.float32)  |      307.9      |       265.3       |       401.0       |        340.8        |       766.5       |        661.9        |       3463.5      |        2829.5
      max (torch.int32)    |      293.5      |       263.1       |       405.0       |        338.8        |       761.4       |        672.5        |       3050.0      |        2688.6
      max (torch.int64)    |      308.2      |       255.7       |       417.4       |        341.4        |       877.0       |        695.0        |       5812.2      |        5762.2

```

`argmax`/`argmin` are much trickier due to the NaN-handling logic that needs to be added there.

Also fixes `torch.max`/`min` compilation for half-precision types; added regression tests for it.

This PR also introduces a bunch of helper functions, such as `simd_broadcast` that works for int64, and the `c10::metal::pair` template, which are used by `simd_argmax` to return both value and index.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158990
Approved by: https://github.com/dcci, https://github.com/Skylion007
2025-07-30 21:57:25 +00:00
Kurt Mohler
70d2e9ba45 [MPS] Avoid outputing zeros from exponential_ for MPS (#159386)
Fixes #159103
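
A quick check sketch (not from the PR): samples from `exponential_` should be strictly positive.
```python
import torch

x = torch.empty(100_000, device="mps").exponential_()
print((x == 0).any().item())   # expected: False after this change
```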
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159386
Approved by: https://github.com/malfet
2025-07-30 00:20:31 +00:00
Nikita Shulga
15bb81ea4f [2/N][CI] Remove MacOS-13 workarounds from tests (#159304)
Part of https://github.com/pytorch/pytorch/issues/159275

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159304
Approved by: https://github.com/dcci, https://github.com/cyyever
ghstack dependencies: #159277, #159278
2025-07-29 23:12:13 +00:00
Nikita Shulga
d0c00d9a69 [MPS] Do not crash if tensor dim > INT_MAX (#158824)
Looks like all MPS operations will crash if one of the tensor dimensions is greater than `2**31-1`.

Change it into a structured exception by checking the tensor size before attempting to create the MPS Tensor.

Add a regression test for it. Before this change, running the following would abort with an exception:
```
% python3 -c "import torch; torch.randint(0, 10, (2**31,), dtype=torch.uint8, device='mps')"
/AppleInternal/Library/BuildRoots/1c8f7852-1ca9-11f0-b28b-226177e5bb69/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShaders/MPSCore/Types/MPSNDArray.mm:829: failed assertion `[MPSNDArray initWithDevice:descriptor:isTextureBacked:] Error: NDArray dimension length > INT_MAX'
zsh: abort      python3 -c·
```
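
After the change, the failure should surface as a catchable Python exception instead of an abort (a sketch; the exact exception type and message are assumptions):
```python
import torch

try:
    torch.randint(0, 10, (2**31,), dtype=torch.uint8, device="mps")
except RuntimeError as e:   # assumed error type for the new check
    print("raised instead of aborting:", type(e).__name__)
```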

Skip the test on MacOS-13, as it crashes somewhere deep in MPSGraph framework with
```
/AppleInternal/Library/BuildRoots/c651a45f-806e-11ed-a221-7ef33c48bc85/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShaders/MPSCore/Types/MPSNDArray.mm:724: failed assertion `[MPSTemporaryNDArray initWithDevice:descriptor:] Error: total bytes of NDArray > 2**32'
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158824
Approved by: https://github.com/dcci
ghstack dependencies: #158690, #158823
2025-07-22 15:12:26 +00:00
Joona Havukainen
194539e9c3 Address NaNs if SDPA is called with all values masked from query (#157727)
Fixes #156707

Detect if all values along the softmax axis are infs and overwrite the outputs for those computations with zeros before the final matmul. The behavior should be aligned with the CPU implementation.

These cases, where all values along the attention-mask dimension are false and softmax produces undefined outputs, occur with left-padded batches during generation in HF transformers, according to the original issue.
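
A minimal sketch of the failure mode (shapes are assumptions; the expected zeros follow the description above):
```python
import torch
import torch.nn.functional as F

q = torch.randn(1, 1, 2, 4, device="mps")
k = torch.randn(1, 1, 3, 4, device="mps")
v = torch.randn(1, 1, 3, 4, device="mps")
mask = torch.tensor([[True, True, True],
                     [False, False, False]], device="mps").view(1, 1, 2, 3)
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
print(out[0, 0, 1])  # second query has every key masked: expected zeros, not NaN
```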
Pull Request resolved: https://github.com/pytorch/pytorch/pull/157727
Approved by: https://github.com/malfet
2025-07-14 22:09:35 +00:00
Nikita Shulga
beed033b6e [MPS] Fix index_kernel for large tensors (#158064)
Move the `MetalShaderLibrary::bind_tensors` private method to OperatorUtils.h and extract an `iter_tensor_offset` method that returns the offset from the start of the storage associated with a given tensor inside the iterator.

Migrated `index` and `index_put[_accumulate][_serial]` to the new paradigm, which requires neither an additional tensor for indices nor special handling for 32- vs 64-bit offsets; this resulted in an almost 2x perf gain for a 2000x2000 tensor. Results before:
```
[------------------------------------------------------------  -----------------------------------------------------------]
                                                |  11x50x50  |  11x100x100  |  11x500x500  |  11x1000x1000  |  11x2000x2000
1 threads: ----------------------------------------------------------------------------------------------------------------
      __getitem__ (torch.int8, torch.int64)     |   383.5    |    379.8     |    470.9     |     1232.9     |     4410.3
      __getitem__ (torch.float16, torch.int64)  |   379.6    |    354.5     |    533.2     |     1290.3     |     4442.2
      __getitem__ (torch.float32, torch.int64)  |   360.8    |    338.6     |    478.6     |     1348.9     |     4870.4

Times are in microseconds (us).
```
And after:
```
[------------------------------------------------------------  -----------------------------------------------------------]
                                                |  11x50x50  |  11x100x100  |  11x500x500  |  11x1000x1000  |  11x2000x2000
1 threads: ----------------------------------------------------------------------------------------------------------------
      __getitem__ (torch.int8, torch.int64)     |   349.8    |    330.5     |    432.6     |     764.5      |     1961.2
      __getitem__ (torch.float16, torch.int64)  |   342.5    |    330.7     |    434.7     |     741.0      |     1969.4
      __getitem__ (torch.float32, torch.int64)  |   332.2    |    326.1     |    445.4     |     751.3      |     1972.6

Times are in microseconds (us).
```

While migrating, also fixed index_put_accumulate for boolean types by using a compare_and_exchange trick over uint.

Fixes https://github.com/pytorch/pytorch/issues/153560
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158064
Approved by: https://github.com/dcci
2025-07-11 22:35:44 +00:00
Daisy Deng
8088958793 port 4 dynamo test files to Intel GPU (#157779)
For https://github.com/pytorch/pytorch/issues/114850, we will port test cases to Intel GPU. Six dynamo test files were ported in PRs [#156056](https://github.com/pytorch/pytorch/pull/156056) and [#156575](https://github.com/pytorch/pytorch/pull/156575). In this PR we port 4 more dynamo test files.
We enable Intel GPU with the following methods, trying our best to keep the original code style:

- instantiate_device_type_tests()
- use "torch.accelerator.current_accelerator()" to determine the accelerator backend
- added XPU support in decorators like @requires_gpu
- enabled XPU for some test paths
- added xfailIfXPU to skip XPU tests when there is a bug.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157779
Approved by: https://github.com/guangyey, https://github.com/jansel
2025-07-11 10:11:49 +00:00
Xuehai Pan
fc0376e8b1 [BE][2/6] fix typos in test/ (test/test_*.py) (#157636)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/157636
Approved by: https://github.com/yewentao256, https://github.com/mlazos
ghstack dependencies: #156311, #156609
2025-07-09 11:02:23 +00:00
Nikita Shulga
a5c61eb78d [MPS][BE] Delete as_strided_tensorimpl_mps (#157772)
Because it's just a copy-and-paste of `as_strided_tensorimpl` with a call to `updateTensorBaseShape`, which is not called/used anywhere else.

Fixes https://github.com/pytorch/pytorch/issues/152701
Pull Request resolved: https://github.com/pytorch/pytorch/pull/157772
Approved by: https://github.com/Skylion007
2025-07-08 17:02:36 +00:00
Kurt Mohler
510c398a4f Add max_pool3d backward pass for MPS (#157498)
Note on backward precision over fp16:

A float16 number has 10 bits of mantissa, 5 bits of exponent, and 1 bit for the sign. If the sign bit is positive, then with a mantissa $m$ and exponent $e$ represented in base 10, the number that the float16 format represents is $(1 + m/1024) \cdot 2^{e}$. ([source](https://en.wikipedia.org/wiki/Half-precision_floating-point_format))

Consider adding two numbers $a$ and $b$ which have arbitrary mantissas, and say their exponents are $e_a = 1$ (so $2 \le a \lt 4$) and $e_b=-3$ (so $0.125 \le b \lt 0.25$). Assume that the result has the same exponent as $a$. Since the exponents differ by 4, we'll effectively need to truncate the 4 rightmost bits of $b$'s mantissa, which would introduce a maximum error on the order of $(2^4 / 1024) \cdot 2^{-3} \approx 0.002$.

The error is nearly the same if $e_b = -2$ (so $0.25 \le b \lt 0.5$), where the 3 rightmost bits are truncated, giving a maximum error on the order of $(2^3 / 1024) \cdot 2^{-2} \approx 0.002$. Same for $e_b=-1$.

So if we're adding up nine different numbers that all have exponents -3, -2, or -1, and they sum to a number with exponent 1, then we would expect a maximum error several times greater than 0.002. In my comments above, summing those particular nine numbers in different ways gave results that ranged between 3.1816 and 3.1758, a difference of $0.0058 \approx 2.9 \times 0.002$.

That's within the acceptable bounds, and we can safely just increase the error tolerance used in test_output_grad_match for the case of max_pool3d_backward with float16.
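
A small numeric sketch of the order-dependence argument above (the values are made up; they sum to roughly 3, matching the exponent-1 case discussed):
```python
import torch

vals = torch.tensor([0.33, 0.27, 0.41, 0.19, 0.36, 0.44, 0.22, 0.48, 0.31],
                    dtype=torch.float16)
fwd = torch.tensor(0.0, dtype=torch.float16)
for v in vals:                 # left-to-right fp16 accumulation
    fwd = fwd + v
rev = torch.tensor(0.0, dtype=torch.float16)
for v in vals.flip(0):         # right-to-left fp16 accumulation
    rev = rev + v
print(fwd.item(), rev.item())  # the two orders can disagree by around 1e-3
```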

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157498
Approved by: https://github.com/malfet
2025-07-07 19:46:44 +00:00
Manuel Candales
d56f11a1f2 [MPS] Implement logcumsumexp metal kernel (#156858)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/156858
Approved by: https://github.com/malfet
ghstack dependencies: #157512
2025-07-03 18:16:25 +00:00
Nikita Shulga
5e636d664a [BE] @serialTest decorator must be called (#157388)
Otherwise it turns the test into a trivial one (that always succeeds), as the following example demonstrates:
```python
import torch
from torch.testing._internal.common_utils import serialTest, run_tests, TestCase

class MegaTest(TestCase):
    @serialTest
    def test_foo(self):
        if hasattr(self.test_foo, "pytestmark"):
            print("foo has attr and it is", self.test_foo.pytestmark)
        print("foo")

    @serialTest()
    def test_bar(self):
        if hasattr(self.test_bar, "pytestmark"):
            print("bar has attr and it is", self.test_bar.pytestmark)
        print("bar")

if __name__ == "__main__":
    run_tests()
```

That will print
```
test_bar (__main__.MegaTest.test_bar) ... bar has attr and it is [Mark(name='serial', args=(), kwargs={})]
bar
ok
test_foo (__main__.MegaTest.test_foo) ... ok

----------------------------------------------------------------------
Ran 2 tests in 0.013s

```

Added an assert that the arg is a boolean in the decorator, to prevent such silent skips in the future.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157388
Approved by: https://github.com/clee2000
2025-07-02 19:15:19 +00:00
Nikita Shulga
019e30e3b8 [BE] Decorate LargeTensorTest with serialTests (#157382)
Maybe it'll help make M2-15 jobs more stable, as that was the last test run before the OOM.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/157382
Approved by: https://github.com/clee2000
2025-07-01 20:35:42 +00:00
Isalia20
a1282b1823 [MPS] Add boilerplate sparse code support (#157238)
This PR makes minimal changes to support sparse tensors on MPS. In follow-up PRs I'll start adding different operations slowly, so we can fix https://github.com/pytorch/pytorch/issues/129842, which is highly requested (I assume because Whisper uses sparse tensors).
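
A minimal sketch of what the boilerplate enables (construction only; specific ops arrive in follow-ups, and the exact supported surface is an assumption):
```python
import torch

i = torch.tensor([[0, 1, 1],
                  [2, 0, 2]], device="mps")
v = torch.tensor([3.0, 4.0, 5.0], device="mps")
s = torch.sparse_coo_tensor(i, v, (2, 3))
print(s.device, s.is_sparse)   # mps:0 True
```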

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157238
Approved by: https://github.com/malfet
2025-06-30 01:53:45 +00:00
Nikita Shulga
a1e4f1f98a [MPS] Reimplement tri[ul] as Metal shaders (#157179)
And add an in-place flavor, as it is currently broken for non-contiguous tensors.
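
A minimal check sketch for the non-contiguous in-place case (not part of the PR):
```python
import torch

x_mps = torch.randn(5, 5, device="mps").t()   # transposed view -> non-contiguous
x_cpu = x_mps.cpu().clone()
x_mps.triu_()
x_cpu.triu_()
print(torch.equal(x_mps.cpu(), x_cpu))  # expected: True with the Metal shader path
```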
Pull Request resolved: https://github.com/pytorch/pytorch/pull/157179
Approved by: https://github.com/dcci
2025-06-28 01:33:18 +00:00
Isalia20
653c52fe52 [MPS] Fix batch norm incorrect gradient (#156867)
Fixes #156555

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156867
Approved by: https://github.com/malfet
2025-06-25 23:05:49 +00:00
Joona Havukainen
20a74c370b Add error message with assert to topK if ndims() - dim > 4 (#155475)
Addressing #154890

Not really a proper fix but at least it's more informative than the current crash.

For a longer-term solution, I'm testing whether we can use the TopK API released in macOS 14, as it does not have the same MPSScan op issue that Sort and ArgSort are hitting.
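
A sketch of the kind of shape that should now error out instead of crashing (the exact error type is an assumption):
```python
import torch

x = torch.randn(2, 2, 2, 2, 2, 2, device="mps")   # 6-D input
try:
    torch.topk(x, k=1, dim=0)                      # ndims() - dim = 6 > 4
except RuntimeError as e:                          # assumed error type for the new assert
    print("topk raised:", str(e)[:60])
```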

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155475
Approved by: https://github.com/kulinseth
2025-06-13 21:10:06 +00:00
Nikita Shulga
dd41a3907c [MPS] Fix unary/binary ops for 2**32+ elem tensors (#155183)
By using the `TensorIterator::with_32bit_indexing()` primitive

Add a `bind_tensors` helper function that correctly sets up MPS tensors originating from TensorIterator

TODO: Add comments to bind_tensors as well as a unit test, based on:
```
python  -c "import torch;print((torch.rand(1, 1024, 1024, dtype=torch.bfloat16, device='mps') + torch.rand(5000, 1, 1, dtype=torch.bfloat16, device='mps')).sin())"
```

Fixes https://github.com/pytorch/pytorch/issues/154828
Pull Request resolved: https://github.com/pytorch/pytorch/pull/155183
Approved by: https://github.com/cyyever, https://github.com/dcci, https://github.com/Skylion007
ghstack dependencies: #155150, #155178, #155184
2025-06-05 18:57:14 +00:00
Roy Hvaara
9a4c08ddfc [MPS] Parametrize test_scaled_dot_product_attention_autocast (#155005)
Also moving comments inside the function scope for some of my previous regression tests.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/155005
Approved by: https://github.com/Skylion007, https://github.com/malfet
2025-06-05 13:24:53 +00:00
Nikita Shulga
9cdce682a1 [MPS][BE] Reimplement log1p as Metal shader (#154936)
That should make it faster than the MPSGraph implementation and also improve accuracy for small inputs, by using the algorithm described in [What Every Computer Scientist Should Know About Floating-Point Arithmetic](https://docs.oracle.com/cd/E19957-01/806-3568/ncg_goldberg.html#1202), i.e. $\log(1+x) = \frac{x \cdot \log(1+x)}{(1 + x) - 1}$ if $1 + x \neq 1$, else just $x$

Also tried using the first 3 terms of the Taylor series in Horner's form, which also seems to work fine, i.e. $\log(1+x) \approx x \cdot (1 - x(\frac{1}{2} - \frac{x}{3}))$
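
A quick sketch of the cited algorithm in plain Python (not the Metal code from the PR):
```python
import math

def log1p_accurate(x: float) -> float:
    # log(1+x) = x * log(u) / (u - 1) with u = fl(1 + x); returns x when u == 1.
    # The division cancels the rounding error introduced when forming 1 + x.
    u = 1.0 + x
    if u == 1.0:
        return x
    return x * math.log(u) / (u - 1.0)

print(log1p_accurate(1e-12), math.log1p(1e-12))  # both ~1e-12; naive log(1 + 1e-12) loses digits
```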

Replaced less accurate log1p implementation in `c10/metal/special_math.h` with generic one.

Parametrize and modify regression test to check for accuracy of small values

TODOs:
 - Do a proper implementation for complex values as well, perhaps using 0408ba0a76/mlx/backend/metal/kernels/utils.h (L339)
 - Maybe implement it using the Remez-like algorithm documented here 207f3b2b25/lib/msun/src/s_log1pf.c (L37)
 - Or use llvm's implementation from f393986b53/libclc/clc/lib/generic/math/clc_log1p.inc (L22)
 - Benchmark which algorithm is faster and delivers better accuracy
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154936
Approved by: https://github.com/dcci, https://github.com/Skylion007
2025-06-03 14:10:13 +00:00
Joona Havukainen
981bdb39ca Enable ConvTranspose3D for FP32 and Complex64 (#154696)
Fixes #154615

Enables ConvTranspose3D, since support seems to exist on both macOS 14 and 15.

For the half dtypes, the discrepancy between the CPU and GPU implementations is too large to conclude whether or not there is a bug in the implementation without a more rigorous study of what bounds to expect on the error. So they are left unsupported for now, and an assert is added to notify the user if the op is called with fp16 or bf16 inputs.

Tests for ConvTranspose3D were enabled for the supported data types.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154696
Approved by: https://github.com/malfet
2025-06-02 16:24:03 +00:00
Isalia20
41092cb86c [MPS] index copy impl (#154326)
Second most requested op according to #154052
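
A minimal usage sketch of the newly supported op on MPS (shapes and values are made up):
```python
import torch

x = torch.zeros(5, 3, device="mps")
src = torch.arange(6, dtype=torch.float32, device="mps").reshape(2, 3)
idx = torch.tensor([0, 4], device="mps")
print(x.index_copy(0, idx, src).cpu())   # rows 0 and 4 replaced with src's rows
```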

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154326
Approved by: https://github.com/malfet
2025-05-29 16:57:43 +00:00
Xuehai Pan
7ae204c3b6 [BE][CI][Easy] Run lintrunner on generated .pyi stub files (#150732)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150732
Approved by: https://github.com/malfet, https://github.com/cyyever, https://github.com/aorenste
2025-05-27 14:58:02 +00:00
Nikita Shulga
975bbc63db [MPS][BE] Move fmod/remainder to Metal ops (#154280)
This accomplishes the following:
 - Fixes a correctness problem with large integer types (though it probably makes them slower, but this could not be avoided if one wants to compute the accurate answer)
 - Makes op faster for floating point types (as Metal kernel invocation is faster than creating MPSGraph)
 - Eliminates need for several correctness workarounds
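
A sketch of the large-integer case this fixes (the specific values are assumptions):
```python
import torch

a = torch.tensor([2**40 + 3], dtype=torch.int64, device="mps")
b = torch.tensor([7], dtype=torch.int64, device="mps")
print(torch.fmod(a, b).cpu(), torch.fmod(a.cpu(), b.cpu()))   # expected to match now
```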

Fixes https://github.com/pytorch/pytorch/issues/154171
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154280
Approved by: https://github.com/dcci
ghstack dependencies: #154275, #154290
2025-05-24 01:45:33 +00:00
Nikita Shulga
633ed01145 [MPS] Add support for two more isin variants (#154010)
`isin_Tensor_Scalar_out` is just a redispatch to eq/neq
`isin_Scalar_Tensor_out` redispatches back to the generic `isin` op, but needs a small tweak to handle float scalars
Make sure that `out` is resized to the expected size in `isin_Tensor_Tensor_out_mps`

Add unit tests to validate that, but skip them on MacOS-13, where the MPS op just returns garbage

Before this change, both of these failed:
```python
>>> import torch
>>> t = torch.tensor([0, 1, 2], device='mps')
>>> torch.isin(t, 1)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
NotImplementedError: The operator 'aten::isin.Tensor_Scalar_out' is not currently implemented for the MPS device. If you want this op to be considered for addition please comment on https://github.com/pytorch/pytorch/issues/141287 and mention use-case, that resulted in missing op as well as commit hash 3b875c25ea6d8802a0c53af9eb961ddf2f058188. As a temporary fix, you can set the environment variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a fallback for this op. WARNING: this will be slower than running natively on MPS.
>>> torch.isin(1, t)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
NotImplementedError: The operator 'aten::isin.Scalar_Tensor_out' is not currently implemented for the MPS device. If you want this op to be considered for addition please comment on https://github.com/pytorch/pytorch/issues/141287 and mention use-case, that resulted in missing op as well as commit hash 3b875c25ea6d8802a0c53af9eb961ddf2f058188. As a temporary fix, you can set the environment variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a fallback for this op. WARNING: this will be slower than running natively on MPS.
```
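
After the change, the expected behavior is along these lines (output reprs are illustrative, not captured from a run):
```python
import torch

t = torch.tensor([0, 1, 2], device='mps')
print(torch.isin(t, 1))   # tensor([False,  True, False], device='mps:0')
print(torch.isin(1, t))   # tensor(True, device='mps:0')
```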

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154010
Approved by: https://github.com/Skylion007, https://github.com/dcci, https://github.com/manuelcandales
ghstack dependencies: #153970, #153971, #153997
2025-05-22 17:59:35 +00:00