By adding the following template
```metal
template <typename T, typename F>
kernel void unary_strided(
    device result_of<F, T>* output [[buffer(0)]],
    constant T* input [[buffer(1)]],
    constant long* sizes [[buffer(2)]],
    constant long* input_strides [[buffer(3)]],
    constant long* output_strides [[buffer(4)]],
    constant uint& ndim [[buffer(5)]],
    uint index [[thread_position_in_grid]]) {
  F f;
  int pos[max_ndim];
  pos_from_thread_index(int(index), pos, sizes, ndim);
  const auto input_offs = offset_from_coord(pos, input_strides, ndim);
  const auto output_offs = offset_from_coord(pos, output_strides, ndim);
  output[output_offs] = f(input[input_offs]);
}
```
and instantiating it for all existing unary shaders, which eliminates the need for any intermediate copies.
No extra testing is needed, as those cases are already covered by `test_output_grad_match_corrcoef_cpu_float32` as well as `test_unary_ops_storage_offset_strided`
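For example, a unary op on a non-contiguous view like the one below can now run the strided kernel directly instead of gathering into a contiguous intermediate first (illustrative snippet, not part of the PR):
```python
import torch

# Non-contiguous (strided) input: every other row/column of an MPS tensor.
x = torch.arange(64.0, device="mps").reshape(8, 8)
y = torch.sqrt(x[::2, ::2])  # handled by unary_strided, no intermediate copy
print(y)
```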
Pull Request resolved: https://github.com/pytorch/pytorch/pull/148468
Approved by: https://github.com/dcci
By introducing the concept of a non-commutative binary op and renaming all op templates from `bitwise_foo_tensor` and `bitwise_foo_scalar` to `bitwise_foo_tensor_tensor` and `bitwise_foo_tensor_scalar`
Add regression tests
Please note that for some undefined values MPS and CPU behaviors are different, for example
```
>>> import torch
>>> 4095 >> torch.arange(12, device="mps", dtype=torch.uint8)
tensor([255, 255, 255, 255, 255, 127, 63, 31, 15, 7, 3, 1],
       device='mps:0', dtype=torch.uint8)
>>> 4095 >> torch.arange(12, device="cpu", dtype=torch.uint8)
tensor([255, 127, 63, 31, 15, 7, 3, 1, 0, 0, 0, 0],
       dtype=torch.uint8)
```
Because on CPU the scalar is cast to the output dtype before the operation is performed, but on MPS this happens after the op is done
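A plain-Python sketch of the two orderings (illustrative only, not the backend code):
```python
# CPU-like: cast the scalar to uint8 first, then shift.
# MPS-like: shift the full-width scalar first, then truncate to uint8.
shifts = range(12)
cpu_like = [(4095 & 0xFF) >> s for s in shifts]
mps_like = [(4095 >> s) & 0xFF for s in shifts]
print(cpu_like)  # [255, 127, 63, 31, 15, 7, 3, 1, 0, 0, 0, 0]
print(mps_like)  # [255, 255, 255, 255, 255, 127, 63, 31, 15, 7, 3, 1]
```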
Fixes https://github.com/pytorch/pytorch/issues/147889
Pull Request resolved: https://github.com/pytorch/pytorch/pull/148686
Approved by: https://github.com/albanD
ghstack dependencies: #148685
Fixes bug introduced by https://github.com/pytorch/pytorch/pull/148350
Before this change
```
% python3 -c "import torch; x, y = torch.arange(128.0, device='mps').reshape(2, 8, 8).unbind(0); print(torch.sqrt(x[::2, ::2], out=y[::2, ::2]))"
tensor([[  0.0000,   1.4142,   2.0000,   2.4495],
        [ 80.0000,  82.0000,  84.0000,  86.0000],
        [ 96.0000,  98.0000, 100.0000, 102.0000],
        [112.0000, 114.0000, 116.0000, 118.0000]], device='mps:0')
```
After this change
```
% python3 -c "import torch; x, y = torch.arange(128.0, device='mps').reshape(2, 8, 8).unbind(0); print(torch.sqrt(x[::2, ::2], out=y[::2, ::2]))"
tensor([[0.0000, 1.4142, 2.0000, 2.4495],
        [4.0000, 4.2426, 4.4721, 4.6904],
        [5.6569, 5.8310, 6.0000, 6.1644],
        [6.9282, 7.0711, 7.2111, 7.3485]], device='mps:0')
```
One cannot avoid copies just because both input and output tensors have the same strides; one also needs to make sure that they are dense in storage (a transposed tensor would be dense, but, say, a view selecting every other column wouldn't be)
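A rough illustration of that distinction (the actual check in the kernel dispatch may be implemented differently):
```python
import torch

def storage_span(t):
    # Number of storage elements between the first and last element the view touches.
    return 1 + sum((sz - 1) * st for sz, st in zip(t.shape, t.stride()))

x = torch.arange(64.0, device="mps").reshape(8, 8)
t = x.t()          # transposed: non-contiguous, but dense in storage
s = x[::2, ::2]    # every other row/column: leaves holes in storage
print(t.numel(), storage_span(t))  # 64 64 -> dense, copy can be skipped
print(s.numel(), storage_span(s))  # 16 55 -> not dense, copy still required
```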
Add a regression test to prevent those from happening again
Also, no need to check that sizes match; luckily it is checked by the structured op (and `out` for unary ops does not support broadcasting, I just checked)
Revived the needs_copy logic, though it will become irrelevant after https://github.com/pytorch/pytorch/pull/148468 is landed
Pull Request resolved: https://github.com/pytorch/pytorch/pull/148512
Approved by: https://github.com/janeyx99
I was a bit concerned when I saw in #148272 that the Metal unary kernel had 0.02x the performance of what we had with MPS Graphs for sqrt (for non-contiguous tensors). This change makes it so that a copy is only done if the input and output tensors do not have the same strides. So if an out tensor is not provided then we don't copy (don't call contiguous) at all and dispatch the kernel as is. After making this change, the script listed at the end of the above PR has the same execution time as the non-transposed one.
Times for reference (on a transposed NxN matrix):
| N | time_old (s) | time_new (s) |
|-------|--------------------|--------------------|
| 100 | 0.0002241021 | 0.0001548659 |
| 1000 | 0.0005934822 | 0.0002150342 |
| 10000 | 0.3242016407 | 0.0045755033 |
Pull Request resolved: https://github.com/pytorch/pytorch/pull/148350
Approved by: https://github.com/janeyx99
Low-hanging fruit: all the ops these need are already implemented, so just adding them to native functions enables the functionality on MPS. The next op I should probably add is lu_solve, seeing how many ops need it for the grad calculation
Pull Request resolved: https://github.com/pytorch/pytorch/pull/148287
Approved by: https://github.com/malfet
Those kernels, instead of being instantiated for half2 (which corresponds to ComplexHalf), were instantiated for short2, which resulted in the following test
```
% python3 -c "import torch; print(torch.rand(6, device='mps', dtype=torch.chalf).sqrt())"
```
Fail with
```
RuntimeError: Failed to create function state object for: sqrt_complex_half_half
```
As sqrt is not implemented for ComplexHalf on CPU, add an explicit test to `test_sqrt`
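A sketch of the kind of coverage this adds (comparison against a wider CPU dtype, since ComplexHalf sqrt is unavailable on CPU; the exact test body may differ):
```python
import torch

x = torch.rand(6, dtype=torch.chalf, device="mps")
expected = x.cpu().to(torch.cfloat).sqrt()
torch.testing.assert_close(x.sqrt().cpu().to(torch.cfloat), expected, rtol=1e-3, atol=1e-3)
```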
Pull Request resolved: https://github.com/pytorch/pytorch/pull/148285
Approved by: https://github.com/dcci
Issue #148219 highlighted the high dispatch times of ops that ran through MPS Graph on smaller tensors. This PR rewrites sqrt as a Metal kernel to mitigate that issue
## Speedups:
Matrix size means NxN matrix here.

Code to generate the times (requires building torch with and without this change to get the old and new timings):
```python
import torch
import numpy as np
import time
import csv

matrix_sizes = [1, 100, 1000, 10_000]
num_runs = 1000
warmup_runs = 3

def run_sqrt(A):
    torch.mps.synchronize()
    start = time.perf_counter()
    c = torch.sqrt(A)
    torch.mps.synchronize()
    end = time.perf_counter()
    return c, end - start

results = {
    'N': [],
    'mean_time': [],
    'std_time': []
}

for n in matrix_sizes:
    print(f"\nBenchmarking N={n}")
    try:
        A_mps = torch.rand((n, n), dtype=torch.float32, device="mps")

        for _ in range(warmup_runs):
            _, _ = run_sqrt(A_mps)

        times = []
        for _ in range(num_runs):
            _, t = run_sqrt(A_mps)
            times.append(t)

        mean_time = np.mean(times)
        std_time = np.std(times)

        results['N'].append(n)
        results['mean_time'].append(mean_time)
        results['std_time'].append(std_time)

        print(f"Mean time: {mean_time:.4f}s ± {std_time:.4f}s")
    except RuntimeError as e:
        print(f"Error for N={n}: {e}")
        continue

with open('sqrt_benchmark_times_new.csv', 'w', newline='') as f:
    writer = csv.writer(f)
    writer.writerow(['N', 'mean_time', 'std_time'])
    for i in range(len(results['N'])):
        writer.writerow([
            results['N'][i],
            results['mean_time'][i],
            results['std_time'][i]
        ])
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/148272
Approved by: https://github.com/malfet
- Move `pos_from_thread_index` and `offset_from_pos` from `UnfoldBackward.metal` into the `c10/metal/indexing.h` header
- The initial idea was to implement `StridedTensor` and `ConstStridedTensor` and use them to make the masked_fill kernel something as simple as the following loop
```metal
ConstStridedTensor<bool> mask(mask_data, sizes, mask_strides, ndim);
if (mask[thread_index]) {
  StridedTensor<T> input(input_data, sizes, input_strides, ndim);
  input[thread_index] = val;
}
```
But though it looks elegant and works correctly, performance-wise it's much slower than the existing MPS shader (see table below), as int64 divisions on the M2 GPU are really slow
- Solved the performance issue by implementing 3 flavors of the same shader: `dense`, which is used when both input and mask are dense tensors of the same size; `broadcast`, which is used when `mask` is expandable along leading dimensions into the input tensor; and `strided`, which is a general-purpose fallback, but still computes the position in the tensors only once. As a result, perf is even better than the existing MPS shader for dense and broadcastable tensors.
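A rough Python-level sketch of how the three flavors get picked (hypothetical helper; the real dispatch logic lives in the C++/Metal backend and may differ):
```python
def pick_masked_fill_flavor(input, mask):
    # Hypothetical dispatch mirroring the description above.
    both_dense = input.is_contiguous() and mask.is_contiguous()
    if both_dense and input.shape == mask.shape:
        return "dense"        # thread index is the storage offset, no stride math at all
    if both_dense and input.shape[input.dim() - mask.dim():] == mask.shape:
        return "broadcast"    # mask expands over the leading dims of the input
    return "strided"          # general fallback; still computes positions only once
```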
Performance measured on an M2 Pro through different iterations of the same shader:
| dtype | MPS | int64-idx | int64-inlined | 32-bit strided | 32-bit broadcasted |
| ------ | ------ | ----- | ---- | --- | ---- |
| float32 | 2.8 msec | 41.6 msec | 26.9 msec | 5 msec | 2.4 msec |
| float16 | 1.86 msec | 38.2 msec | 26.6 msec | 4.6 msec | 1.9 msec |
| bfloat16 | 1.86 msec | 38.3 msec | 26.6 msec | 4.6 msec | 1.9 msec |
And the benchmark script:
```python
import torch
from timeit import default_timer
from itertools import product
from torch.utils.benchmark import Measurement, Timer

def bench_mask_fill(
    n,
    binary_func,
    dtype=torch.float32,
) -> Measurement:
    t = Timer(
        stmt=f"x.masked_fill(y, -17.0); torch.mps.synchronize()",
        setup=f"x,y = torch.rand(1, 20, {n}, {n}, dtype={dtype}, device='mps'), torch.ones({n}, {n}, device='mps').triu().bool()",
        globals={'f': binary_func},
        language="python", timer=default_timer
    )
    return t.blocked_autorange()

if __name__ == "__main__":
    n = 1024
    for dtype in [torch.float32, torch.float16, torch.bfloat16]:
        eager_t = bench_mask_fill(n, torch.fmax, dtype)
        use_msec = eager_t.mean > 1e-4
        multiplier = 1e3 if use_msec else 1e6
        uname = "msec" if use_msec else "usec"
        print(f"torch.masked_fill_() {str(dtype):>14} {eager_t.mean*multiplier:>7.2f} {uname}")
```
Fixes https://github.com/pytorch/pytorch/issues/143477
Pull Request resolved: https://github.com/pytorch/pytorch/pull/147369
Approved by: https://github.com/dcci
ghstack dependencies: #147977
This PR addresses an issue in the MPS backend for `_scaled_dot_product_attention_math_mps` where a 3d input like (num_heads, seq_len, query_dim) cannot be automatically treated as (1, num_heads, seq_len, query_dim), as it is on CPU or CUDA; this is addressed by adding a util function to ensure a 4d shape.
The issue was found in https://github.com/hiyouga/LLaMA-Factory/issues/6835, in [transformers qwen2_vl](1590c66430/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py (L373C14-L373C93)), where 3d q/k/v were passed into the sdpa function, which led to an error.
For consistency, since this pattern might pop up elsewhere in the transformers codebase, I think it makes more sense to maintain the same behavior across all platforms.
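The idea of the fix, roughly (Python pseudocode for illustration; the actual change lives in the MPS backend and its helper names may differ):
```python
def ensure_4d(t):
    # Treat a 3d (num_heads, seq_len, head_dim) input as (1, num_heads, seq_len, head_dim).
    return t.unsqueeze(0) if t.dim() == 3 else t

# q, k, v (and the mask) are promoted to 4d before the kernel runs,
# and the extra batch dimension is squeezed back out of the output afterwards.
```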
---
reproduce code:
```
import torch
import torch.nn.functional as F
head_num, seq_len, embed_dim = 16, 16, 80
bsz = 1
q = torch.randn(head_num, seq_len, embed_dim)
k = torch.randn(head_num, seq_len, embed_dim)
v = torch.randn(head_num, seq_len, embed_dim)
attention_mask = torch.ones(1, seq_len, seq_len)
oo_cpu = F.scaled_dot_product_attention(
    q.to("cpu"),
    k.to("cpu"),
    v.to("cpu"),
    attention_mask.to("cpu"),
    dropout_p=0.0
)
if torch.backends.mps.is_available():
    oo_mps = F.scaled_dot_product_attention(
        q.to("mps"),
        k.to("mps"),
        v.to("mps"),
        attention_mask.to("mps"),
        dropout_p=0.0
    )
    assert torch.allclose(oo_cpu, oo_mps.to("cpu"), atol=1e-5)
```
error outputs:
```
Traceback (most recent call last):
File "/opt/homebrew/Caskroom/miniconda/base/envs/torch-dev/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3577, in run_code
exec(code_obj, self.user_global_ns, self.user_ns)
File "<ipython-input-2-5169b8d2c5dd>", line 21, in <module>
oo_mps = F.scaled_dot_product_attention(
IndexError: Dimension out of range (expected to be in range of [-3, 2], but got 3)
```
hardware and envs:
```
torch 2.6.0
apple m3 max
```
---
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146623
Approved by: https://github.com/malfet
Co-authored-by: Aaron Gokaslan <aaronGokaslan@gmail.com>
Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
Implements the lu_unpack function on MPS. Haven't added new tests because they are covered by removing lu_unpack from UNIMPLEMENTED_XFAILLIST in test_mps with the `test_output_match` function
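Example usage that this coverage exercises (illustrative snippet, assuming lu_factor is available on MPS as in the related PRs):
```python
import torch

A = torch.randn(4, 4, device="mps")
LU, pivots = torch.linalg.lu_factor(A)
P, L, U = torch.lu_unpack(LU, pivots)
torch.testing.assert_close(P @ L @ U, A)
```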
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146681
Approved by: https://github.com/malfet
Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
And to integral data types as well
Was too lazy to deduce the formula myself (or write a sympy script), but ChatGPT did a decent job of it, though it forgot that the input must be multiplied by $$\pi$$:
```math
\text{Re}\left(\text{sinc}(x + i y)\right) = \frac{\sin(x)\cosh(y)\, x + \cos(x)\sinh(y)\, y}{x^2 + y^2}
```
```math
\text{Im}\left(\text{sinc}(x + i y)\right) = \frac{\cos(x)\sinh(y)\, x - \sin(x)\cosh(y)\, y}{x^2 + y^2}
```
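A quick numerical sanity check of the formulas (plain Python, independent of the Metal implementation; the $$\pi$$ factor is applied up front):
```python
import cmath
import math

def sinc_complex(x, y):
    z = cmath.pi * complex(x, y)  # sinc(z) = sin(pi*z) / (pi*z)
    xr, yr = z.real, z.imag
    denom = xr * xr + yr * yr
    re = (math.sin(xr) * math.cosh(yr) * xr + math.cos(xr) * math.sinh(yr) * yr) / denom
    im = (math.cos(xr) * math.sinh(yr) * xr - math.sin(xr) * math.cosh(yr) * yr) / denom
    return complex(re, im)

z = cmath.pi * complex(0.3, 0.4)
print(sinc_complex(0.3, 0.4))  # formula above
print(cmath.sin(z) / z)        # reference value: should match
```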
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146648
Approved by: https://github.com/dcci
A test was failing in inductor (`test_pointwise_zeta`), and I realized the operation was also missing from eager.
Implemented for both, leveraging the kernel. Happy to split in two (one PR for eager, one for inductor) if folks prefer.
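Example of the op now working in eager on MPS (illustrative; values chosen so the Hurwitz zeta is well-defined):
```python
import torch

x = torch.rand(4, device="mps") + 2.0  # exponent > 1
q = torch.rand(4, device="mps") + 1.0  # shift > 0
print(torch.special.zeta(x, q))
```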
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146465
Approved by: https://github.com/malfet
Requested in #77764
This PR adds support for linalg.det on MPS and fixes lu_factor for non-contiguous tensors; the previous implementation crashed on any kind of non-contiguous tensor with an error:
```
-[AGXG13XFamilyCommandBuffer blitCommandEncoderCommon:]:833: failed assertion `A command encoder is already encoding to this command buffer'
zsh: abort python det.py
```
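A small illustrative example of both the new op and a non-contiguous input of the kind that previously hit that assertion:
```python
import torch

A = torch.randn(3, 3, device="mps")
print(torch.linalg.det(A))
print(torch.linalg.det(A.t()))  # non-contiguous (transposed) input
```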
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146279
Approved by: https://github.com/malfet
Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
By using the `naive_mm` kernel, but making sure that accumulation is done over int32 for smaller int types (and float for half and bfloat), as well as adding `naive_bmm` that follows the same pattern.
Remove the stale restriction on `torch.dot` (which works fine on MacOS-14/15)
This also enables integer op flavors for the following (see the example after the list):
- `addmv`
- `einsum`
- `inner`
- `linalg.multi_dot`
- `matmul`
- `mv`
- `tensordot`
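A small usage sketch of the newly enabled integer flavors (illustrative only):
```python
import torch

# Integer matmul on MPS; accumulation happens in int32, so int16 products
# do not overflow prematurely.
a = torch.randint(-10, 10, (4, 8), dtype=torch.int16, device="mps")
b = torch.randint(-10, 10, (8, 3), dtype=torch.int16, device="mps")
print(a @ b)
print(torch.einsum("ij,jk->ik", a, b))
```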
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145809
Approved by: https://github.com/dcci