Fixes#142397
Basic implementation is done. What's left:
- [x] Different dtype/device tensors in the TensorList
- [x] fast path for grouping the foreach kernel
- [x] Tests
Regarding tests, I found some tests in `test/test_torch.py` for GradScaler but I couldn't figure out what is the best way to enable the test for MPS device.
By removing `@onlyNativeDeviceTypes`, one enables the tests for MPS but also enables tests for all other devices which are not included in the native device types. If I put:
`instantiate_device_type_tests(TestTorchDeviceType, globals(), allow_mps=True)`
This enables lots of tests in that class for MPS which were not(?) being tested before? This part needs some clarification
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150255
Approved by: https://github.com/malfet
Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
By using cooperative `simd_sum`/`simd_product` instead of a C-style for loop for threadgroup reductions. This also allows significantly reduce amount of shared memory needed to perform those reductions
Using such reduction increases the `torch.compile` performance for gpt-fast using `stories110M` from 29 tokens/sec to 630 tokens/sec on M4 and changes perf of torch.rand as follows:
|size| before | after |
|------------------------|------------|-------------|
| 512x512 | 202.1 | 131.8 |
| 1024x1024 | 780.6 | 176.9 |
| 2048x2048 | 1423.4 | 339.9 |
| 4096x4097 | 2982.2 | 1047.2 |
Unfortunately, none of the SIMDgroup operations are available for 64-bit integers, but one can simulate the behavior using using `simd_shuffle_down` of 64-bit values represented as `int2` types, that yields reduction in $log_2(threadgroup\\_size)$ steps. [`mlx/kernels/reduction/ops.h](86389bf970/mlx/backend/metal/kernels/reduction/ops.h (L15-L18)) contains an implementation of such algorithm, but alas it yields wrong results on M1/M2(and may be M3 machines) if not all threads in the simdgroup are active which could be observed by running
```python
import torch
lib=torch.mps.compile_shader("""
kernel void do_sum(device int* out, constant int* in, uint idx [[thread_position_in_grid]]) {
out[idx] = metal::simd_shuffle_down(in[idx], 8);
}
""")
x=torch.arange(22, device='mps', dtype=torch.int32)
y=torch.empty_like(x)
lib.do_sum(y, x)
print(y)
```
that returns following on M4
```
tensor([ 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 0, 0, 0, 0, 0, 0, 0, 0], device='mps:0', dtype=torch.int32)
```
but same kernel running on M1 returns
```
tensor([ 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 14, 15, 16, 17, 18, 19, 20, 21], device='mps:0', dtype=torch.int32)
```
This discrepancy in behavior can be addressed by using `simd_shuffle_and_fill_down`, but any kernels using simd_shuffle_and_fill_down cause an internal compiler error on MacOS-13.2. Considering that OS is to be EOL soon, skip the offending tests.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150566
Approved by: https://github.com/manuelcandales
ghstack dependencies: #150452, #150457
- Distinguish between conjugated/non_conjugated inputs by appending conjugation to the operator key
- For matmul or dot, add `conjugateWithTensor:name:` calls before running the op
- Enable testing for conjugated ops by passing `include_conjugated_inputs` to opinfo
- Filter `include_conjugated_inputs` argument from `sample_inputs_window` (probably should have landed as separate PR)
- Preserve conj property when gathering the views, that fixes `cov` operator
Fixes https://github.com/pytorch/pytorch/issues/148156
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150157
Approved by: https://github.com/dcci
By implementing `_cast_` flavors of both dense and strided ops. Add regression tests that tests `fmax`/`fmin` for mixed dtypes.
Been dreaded to write this PR for a while, as it end up to be pretty bulky:
- Adds 1C10_METAL_ALL_TYPES_FUNCTOR` and `c10:🤘:ScalarType` to `c10/metal/common.h` and test that its values always match `c10::ScalarType`
- Add `c10:🤘:cast_to` to `c10/metal/utils.h` which could be used to cast any scalar metal dtype to any other one, including complex values
- Implement `val_at_offs<T>(constant void *, long offs, ScalarType dtype)` that is used to dynamically cast types
- Add `binary_strided_cast` and `binary_dense_cast` that are invoked for output dtype and cast both inputs to that output before performing the op
Benchmark collected on M2Pro that runs fmax for 1 mln element tensors (Times are in microseconds.)
| | dense-dense | transp-transp | dense-transp | transp-dense | dense-scalar | dense-bcast |
|-------------------------|---------------|----------------|----------------|----------------|---------------|--------------- |
| fmax (torch.float16, torch.float16) | 160.9 | 159.9 | 270.5 | 270.9 | 236.6 | 293.0
| fmax (torch.float32, torch.float32) | 176.9 | 171.0 | 273.7 | 293.5 | 242.6 | 294.2
| fmax (torch.float32, torch.float16) | 171.4 | 170.9 | 283.6 | 303.0 | 253.7 | 302.3
| add (torch.float16, torch.float16) | 218.0 | 223.6 | 221.0 | 222.0 | 214.9 | 218.3
| add (torch.float32, torch.float32) | 227.4 | 233.9 | 228.8 | 231.9 | 218.9 | 221.4
| add (torch.float32, torch.float16) | 226.1 | 227.5 | 227.5 | 226.9 | 177.0 | 190.8
TODOS:
- Include input and output dtype in non-cast kernel name
- Make TensorFactory.h use `C10_METAL_ALL_TYPES_FUNCTOR`
- Extend mixed_dytpes testing via OpInfo
Fixes https://github.com/pytorch/pytorch/issues/149951
Pull Request resolved: https://github.com/pytorch/pytorch/pull/149974
Approved by: https://github.com/manuelcandales
Third most voted op from #77764
Tests were deleted because they are covered by the regular test_output_match tests so those were redundant and were added in the last PR before the nanmedian dim version would be implemented
Pull Request resolved: https://github.com/pytorch/pytorch/pull/149680
Approved by: https://github.com/malfet
Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
Implements nanmedian on MPS. This implementation only implements `torch.nanmedian(tensor)` without `keepdim` and `dim`
Will implement nanmedian with dim and keepdim in a followup
Pull Request resolved: https://github.com/pytorch/pytorch/pull/149407
Approved by: https://github.com/malfet
By adding following template
```metal
template <typename T, typename F>
kernel void unary_strided(
device result_of<F, T>* output [[buffer(0)]],
constant T* input [[buffer(1)]],
constant long* sizes [[buffer(2)]],
constant long* input_strides [[buffer(3)]],
constant long* output_strides [[buffer(4)]],
constant uint& ndim,
uint index [[thread_position_in_grid]]) {
F f;
int pos[max_ndim];
pos_from_thread_index(int(index), pos, sizes, ndim);
const auto input_offs = offset_from_coord(pos, input_strides, ndim);
const auto output_offs = offset_from_coord(pos, output_strides, ndim);
output[output_offs] = f(input[input_offs]);
}
```
and instantiating it for all existing unary shaders, which eliminates the need to any intermediate copies.
No extra testing are needed as those cases are already covered by `test_output_grad_match_corrcoef_cpu_float32` as well as `test_unary_ops_storage_offset_strided`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/148468
Approved by: https://github.com/dcci
By introducing a concept of non-commutative binary op and renaming all op templates from `bitwise_foo_tensor` and `bitwise_foo_scalar` to `bitwise_foo_tensor_tensor` and `bitwise_foo_tensor_scalar`
Add regression tests
Please note, that for some undefined values MPS and CPU behaviors are different, for example
```
>>> import torch
>>> 4095 >> torch.arange(12, device="mps", dtype=torch.uint8)
tensor([255, 255, 255, 255, 255, 127, 63, 31, 15, 7, 3, 1],
device='mps:0', dtype=torch.uint8)
>>> 4095 >> torch.arange(12, device="cpu", dtype=torch.uint8)
tensor([255, 127, 63, 31, 15, 7, 3, 1, 0, 0, 0, 0],
dtype=torch.uint8)
```
Because on CPU scalar is cast to output dtype before operation is performed, but on MPS this happens after the op is done
Fixes https://github.com/pytorch/pytorch/issues/147889
Pull Request resolved: https://github.com/pytorch/pytorch/pull/148686
Approved by: https://github.com/albanD
ghstack dependencies: #148685
Fixes bug introduced by https://github.com/pytorch/pytorch/pull/148350
Before this change
```
% python3 -c "import torch; x, y = torch.arange(128.0, device='mps').reshape(2, 8, 8).unbind(0); print(torch.sqrt(x[::2, ::2], out=y[::2, ::2]))"
tensor([[ 0.0000, 1.4142, 2.0000, 2.4495],
[ 80.0000, 82.0000, 84.0000, 86.0000],
[ 96.0000, 98.0000, 100.0000, 102.0000],
[112.0000, 114.0000, 116.0000, 118.0000]], device='mps:0')
```
After this change
```
% python3 -c "import torch; x, y = torch.arange(128.0, device='mps').reshape(2, 8, 8).unbind(0); print(torch.sqrt(x[::2, ::2], out=y[::2, ::2]))"
tensor([[0.0000, 1.4142, 2.0000, 2.4495],
[4.0000, 4.2426, 4.4721, 4.6904],
[5.6569, 5.8310, 6.0000, 6.1644],
[6.9282, 7.0711, 7.2111, 7.3485]], device='mps:0')
```
One can not avoid copies if both input and output tensors have the same strides, one needs to make sure that they are dense-in-storage (transposed tensor would be dense, but say selecting every odd and even column wouldn't)
Add regression test to prevent those from happening again
Also, no need to check that sizes match, luckily it is checked by the structured op (and `out` for unary ops does not support broadcasting, I just checked)
Revived needs_copy_logic, though it will become irrelevant after https://github.com/pytorch/pytorch/pull/148468 is landed
Pull Request resolved: https://github.com/pytorch/pytorch/pull/148512
Approved by: https://github.com/janeyx99
I was a bit concerned when I saw in #148272 that metal unary kernel was 0.02x of the performance of what we had with MPS Graphs for sqrt(for non contiguous) tensors. This change makes it so that copying is only done if we don't have same strided tensors(for input/output). So if out tensor is not provided then we don't do copy(don't call contiguous) at all and dispatch the kernel as is. After making this change the script that I listed at the end of the above PR has the same execution time as the non-transposed one.
Times for reference(on transposed tensor where matrix is NxN matrix):
| N | time_old | time_new |
|-------|--------------------|--------------------|
| 100 | 0.0002241021 | 0.0001548659 |
| 1000 | 0.0005934822 | 0.0002150342 |
| 10000 | 0.3242016407 | 0.0045755033 |
Pull Request resolved: https://github.com/pytorch/pytorch/pull/148350
Approved by: https://github.com/janeyx99
Low hanging fruits, all ops for these are implemented so just adding them to native functions adds the functionality on mps. Probably next op I should add should be lu solve seeing as how many ops need it for the grad calculation
Pull Request resolved: https://github.com/pytorch/pytorch/pull/148287
Approved by: https://github.com/malfet
Those kernels, instead of being instantiated for half2 (which corresponds to ComplexHalf) were instnatiated for short2, which resuled in the following test
```
% python3 -c "import torch; print(torch.rand(6, device='mps', dtype=torch.chalf).sqrt())"
```
Fail with
```
RuntimeError: Failed to create function state object for: sqrt_complex_half_half
```
As sqrt is not implemented for CPU, add explicit test to `test_sqrt`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/148285
Approved by: https://github.com/dcci
Issue #148219 highlighted the high dispatch times of ops which ran with MPS Graph on smaller tensors. This PR rewrites the sqrt with metal kernel to mitigate that issue
## Speedups:
Matrix size means NxN matrix here.

Code to generate the times(needs building the torch with old time and new time):
```python
import torch
import numpy as np
import time
import csv
matrix_sizes = [1, 100, 1000, 10_000]
num_runs = 1000
warmup_runs = 3
def run_sqrt(A):
torch.mps.synchronize()
start = time.perf_counter()
c = torch.sqrt(A)
torch.mps.synchronize()
end = time.perf_counter()
return c, end - start
results = {
'N': [],
'mean_time': [],
'std_time': []
}
for n in matrix_sizes:
print(f"\nBenchmarking N={n}")
try:
A_mps = torch.rand((n, n), dtype=torch.float32, device="mps")
for _ in range(warmup_runs):
_, _ = run_sqrt(A_mps)
times = []
for _ in range(num_runs):
_, t = run_sqrt(A_mps)
times.append(t)
mean_time = np.mean(times)
std_time = np.std(times)
results['N'].append(n)
results['mean_time'].append(mean_time)
results['std_time'].append(std_time)
print(f"Mean time: {mean_time:.4f}s ± {std_time:.4f}s")
except RuntimeError as e:
print(f"Error for N={n}: {e}")
continue
with open('sqrt_benchmark_times_new.csv', 'w', newline='') as f:
writer = csv.writer(f)
writer.writerow(['N', 'mean_time', 'std_time'])
for i in range(len(results['N'])):
writer.writerow([
results['N'][i],
results['mean_time'][i],
results['std_time'][i]
])
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/148272
Approved by: https://github.com/malfet
- Move `pos_from_thread_index and `offset_from_pos` from `UnfoldBackward.metal` into `c10/metal/indexing.h` header
- Initial idea were to implement `StridedTensor` and `ConstStridedTensor` and use them to have masked_fill kernel a something simple as the following loop
```metal
ConstStridedTensor<bool> mask(mask_data, sizes, mask_strides, ndim);
if (mask[thread_index]) {
StridedTensor<T> input(input_data, sizes, input_strides, ndim);
input[thread_index] = val;
}
```
But though it looks elegant and works correctly, performance wise it's much slower that the existing MPS shader (see table below), as int64 divisions on M2 GPU are really slow
- Solved performance issue by implementing 3 flavors of the same shader: `dense`, that is used when both input and mask are dense tensors of the same size, `broadcast`, which is used when `mask` is leading dimensions expandable into input tensor and `strided` which is a general purpose fallback, but still computes position in the tensors only ones. As result, perf is even better than existing MPS shader for dense and broadcast able tensors.
Performance measured on M2Pro thru different iterations of the same shader
| dtype | MPS | int64-idx | int64-inlined | 32-bit strided | 32-bit broadcasted |
| ------|------| -----| ---- | --- | ---- |
| float32 | 2.8 msec | 41.6 msec | 26.9 msec | 5 msec | 2.4 msec |
| float16 | 1.86 msec | 38.2 msec| 26.6 msec | 4.6 msec | 1.9 msec |
|bfloat16|1.86 msec |38.3 msec | 26.6 msec | 4.6 msec | 1.9 msec |
And benchmark script
```python
import torch
from timeit import default_timer
from itertools import product
from torch.utils.benchmark import Measurement, Timer
def bench_mask_fill(
n,
binary_func,
dtype=torch.float32,
) -> Measurement:
t = Timer(
stmt=f"x.masked_fill(y, -17.0); torch.mps.synchronize()",
setup=f"x,y = torch.rand(1, 20, {n}, {n}, dtype={dtype}, device='mps'), torch.ones({n}, {n}, device='mps').triu().bool()",
globals = {'f': binary_func},
language="python", timer=default_timer
)
return t.blocked_autorange()
if __name__ == "__main__":
n = 1024
for dtype in [torch.float32, torch.float16, torch.bfloat16]:
eager_t = bench_mask_fill(n, torch.fmax, dtype)
use_msec = eager_t.mean > 1e-4
multiplier = 1e3 if use_msec else 1e6
uname = "msec" if use_msec else "usec"
print(f"torch.masked_fill_() {str(dtype):>14} {eager_t.mean*multiplier:>7.2f} {uname}")
```
Fixes https://github.com/pytorch/pytorch/issues/143477
Pull Request resolved: https://github.com/pytorch/pytorch/pull/147369
Approved by: https://github.com/dcci
ghstack dependencies: #147977