Use cooperative `simd_sum`/`simd_product` instead of a C-style for loop for threadgroup reductions. This also significantly reduces the amount of shared memory needed to perform those reductions.
Using such reductions increases `torch.compile` performance for gpt-fast with `stories110M` from 29 tokens/sec to 630 tokens/sec on M4 and changes `torch.rand` performance as follows (lower is better):
|size| before | after |
|------------------------|------------|-------------|
| 512x512 | 202.1 | 131.8 |
| 1024x1024 | 780.6 | 176.9 |
| 2048x2048 | 1423.4 | 339.9 |
| 4096x4097 | 2982.2 | 1047.2 |
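For illustration, here is a minimal sketch of such a cooperative reduction, in the same `torch.mps.compile_shader` style used later in this note; the kernel name is made up, and it assumes the 32-lane simdgroup width of Apple GPUs:
```python
import torch

# Hedged sketch: `simd_sum` makes every lane of the simdgroup contribute one
# value and receive the group-wide sum, with no threadgroup memory involved.
lib = torch.mps.compile_shader("""
kernel void coop_sum(device float* out, constant float* in,
                     uint idx [[thread_position_in_grid]]) {
    out[idx] = metal::simd_sum(in[idx]);
}
""")
x = torch.ones(32, device="mps")
y = torch.empty_like(x)
lib.coop_sum(y, x)
print(y)  # expect every lane of the fully active simdgroup to hold 32.0
```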
Unfortunately, none of the SIMDgroup operations are available for 64-bit integers, but one can simulate the behavior using `simd_shuffle_down` on 64-bit values represented as `int2`, which yields a reduction in $\log_2(\mathrm{threadgroup\_size})$ steps. [`mlx/kernels/reduction/ops.h`](86389bf970/mlx/backend/metal/kernels/reduction/ops.h (L15-L18)) contains an implementation of such an algorithm, but alas it yields wrong results on M1/M2 (and possibly M3) machines if not all threads in the simdgroup are active, which can be observed by running
```python
import torch
lib=torch.mps.compile_shader("""
kernel void do_sum(device int* out, constant int* in, uint idx [[thread_position_in_grid]]) {
out[idx] = metal::simd_shuffle_down(in[idx], 8);
}
""")
x=torch.arange(22, device='mps', dtype=torch.int32)
y=torch.empty_like(x)
lib.do_sum(y, x)
print(y)
```
which returns the following on M4
```
tensor([ 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 0, 0, 0, 0, 0, 0, 0, 0], device='mps:0', dtype=torch.int32)
```
but the same kernel running on M1 returns
```
tensor([ 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 14, 15, 16, 17, 18, 19, 20, 21], device='mps:0', dtype=torch.int32)
```
This discrepancy in behavior can be addressed by using `simd_shuffle_and_fill_down`, but any kernel that uses `simd_shuffle_and_fill_down` causes an internal compiler error on macOS 13.2. Considering that this OS will reach its end of life soon, skip the offending tests.
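For reference, a hedged sketch of the `int2`-based emulation described above; the helper `shuffle_down_long` and the kernel name are illustrative stand-ins, not the MLX code:
```python
import torch

lib = torch.mps.compile_shader("""
long shuffle_down_long(long value, ushort delta) {
    // Reinterpret the 64 bits as two 32-bit lanes, shuffle the pair in one
    // call, then reassemble. Subject to the M1/M2 caveat above when not all
    // threads in the simdgroup are active.
    return as_type<long>(metal::simd_shuffle_down(as_type<int2>(value), delta));
}
kernel void do_shuffle(device long* out, constant long* in,
                       uint idx [[thread_position_in_grid]]) {
    out[idx] = shuffle_down_long(in[idx], 8);
}
""")
x = torch.arange(32, device="mps", dtype=torch.int64)
y = torch.empty_like(x)
lib.do_shuffle(y, x)
print(y)  # lanes 24..31 read past the simdgroup edge; their values are undefined
```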
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150566
Approved by: https://github.com/manuelcandales
ghstack dependencies: #150452, #150457
Not sure how this worked in the past, but the fence should come before the first read from shared memory, not after it.
This bug was exposed by https://github.com/pytorch/pytorch/pull/148969, which removed an unnecessary barrier before calls to the `threadgroup_reduce` functions.
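A minimal sketch of the corrected ordering (the kernel name and the fixed 256-slot buffer are illustrative): every thread first writes its slot, then the fence, and only then does thread 0 read the other threads' slots.
```python
import torch

lib = torch.mps.compile_shader("""
kernel void fenced_sum(device float* out, constant float* in,
                       uint tid [[thread_position_in_threadgroup]],
                       uint tg_size [[threads_per_threadgroup]]) {
    threadgroup float shared[256];  // assumes tg_size <= 256
    shared[tid] = in[tid];
    // The fence must precede the first read of shared memory...
    metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup);
    if (tid == 0) {  // ...so these reads observe every thread's write.
        float acc = 0.0f;
        for (uint idx = 0; idx < tg_size; ++idx) {
            acc += shared[idx];
        }
        out[0] = acc;
    }
}
""")
# Host-side launch elided: run with a single tg_size-thread threadgroup.
```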
Test plan:
```
% python3 generate.py --checkpoint_path checkpoints/stories15M/model.pth --prompt "Once upon a time" --device mps --compile
```
Before this fix it produced gibberish; now it works fine.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/149437
Approved by: https://github.com/manuelcandales, https://github.com/dcci
`threadgroup_argmin` used to return the input type, which is wrong; it should have returned `int` or `long`.
Change the signatures of both `threadgroup_argmin` and `threadgroup_argmax` to return `int`: the group size is small, so there is no need to carry around larger integers.
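A hedged sketch of the fixed shape of such a signature; `illustrative_argmax` below is a stand-in, not the actual `threadgroup_argmax` implementation:
```python
import torch

lib = torch.mps.compile_shader("""
template <typename T>
int illustrative_argmax(constant T* vals, uint size) {
    // The result is a lane index, so a 32-bit int suffices
    // regardless of the input element type T.
    int max_idx = 0;
    for (uint idx = 1; idx < size; ++idx) {
        if (vals[idx] > vals[max_idx]) {
            max_idx = int(idx);
        }
    }
    return max_idx;
}
kernel void do_argmax(device int* out, constant float* in,
                      uint idx [[thread_position_in_grid]]) {
    if (idx == 0) {
        out[0] = illustrative_argmax(in, 8u);
    }
}
""")
x = torch.randn(8, device="mps")
y = torch.zeros(1, device="mps", dtype=torch.int32)
lib.do_argmax(y, x)
print(y.item(), x.argmax().item())  # the two indices should agree
```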
Pull Request resolved: https://github.com/pytorch/pytorch/pull/149020
Approved by: https://github.com/jansel
ghstack dependencies: #148969, #148975, #149004
I wish I knew how to extract Metal warnings during JIT compilation, but https://developer.apple.com/documentation/metal/mtldevice/makelibrary(source:options:)?changes=_7&language=objc is a lie, as `error:` stays `nil` unless shader compilation fails outright. But when it does fail, the following warnings are reported:
```
program_source:666:26: warning: comparison of integers of different signs: 'int' and 'unsigned int' [-Wsign-compare]
for (auto idx = 1; idx < size; ++idx) {
~~~ ^ ~~~~
program_source:677:26: warning: comparison of integers of different signs: 'int' and 'unsigned int' [-Wsign-compare]
for (auto idx = 1; idx < size; ++idx) {
~~~ ^ ~~~~
program_source:688:26: warning: comparison of integers of different signs: 'int' and 'unsigned int' [-Wsign-compare]
for (auto idx = 1; idx < size; ++idx) {
~~~ ^ ~~~~
program_source:699:26: warning: comparison of integers of different signs: 'int' and 'unsigned int' [-Wsign-compare]
for (auto idx = 1; idx < size; ++idx) {
~~~ ^ ~~~~
program_source:710:26: warning: comparison of integers of different signs: 'int' and 'unsigned int' [-Wsign-compare]
for (auto idx = 1; idx < size; ++idx) {
~~~ ^ ~~~~
program_source:723:26: warning: comparison of integers of different signs: 'int' and 'unsigned int' [-Wsign-compare]
for (auto idx = 1; idx < size; ++idx) {
```
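The warning stems from `auto idx = 1` deducing a signed `int` that is then compared against an unsigned `size`; a minimal sketch of the fix is to deduce an unsigned index with a `1u` literal (the kernel below is illustrative):
```python
import torch

lib = torch.mps.compile_shader("""
kernel void serial_sum(device float* out, constant float* in,
                       constant uint& size,
                       uint tid [[thread_position_in_grid]]) {
    if (tid != 0) return;
    float acc = in[0];
    for (auto idx = 1u; idx < size; ++idx) {  // 1u: both sides unsigned
        acc += in[idx];
    }
    out[0] = acc;
}
""")
# Host-side launch elided; the point is the unsigned loop index.
```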
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146549
Approved by: https://github.com/dcci
- Add a `threadgroup_sum` template to `c10/metal/reduction_utils.h` that, for now, uses a barrier to compute the reduction (a sketch follows the TODO list below)
TODOs:
- Implement an efficient reduction using cooperative functions such as `simd_shuffle_down`
- Figure out how to merge several sum reductions together
- Implement a `reduction_store` that writes results only from the first thread
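A hedged sketch of what such a barrier-based `threadgroup_sum` could look like; the names and the fixed 64-thread group size are illustrative, not the actual `c10/metal/reduction_utils.h` code:
```python
import torch

lib = torch.mps.compile_shader("""
template <typename T>
T illustrative_threadgroup_sum(threadgroup T* shared, uint tid, uint size) {
    // Tree reduction: halve the number of active threads each step.
    for (uint stride = size / 2; stride > 0; stride /= 2) {
        if (tid < stride) {
            shared[tid] += shared[tid + stride];
        }
        metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup);
    }
    return shared[0];
}
kernel void sum64(device float* out, constant float* in,
                  uint tid [[thread_position_in_threadgroup]]) {
    threadgroup float shared[64];
    shared[tid] = in[tid];
    // Barrier before the first cross-thread read (see the fence fix above).
    metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup);
    float total = illustrative_threadgroup_sum(shared, tid, 64u);
    if (tid == 0) {
        out[0] = total;
    }
}
""")
# Host-side launch elided: run with a single 64-thread threadgroup.
```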
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146380
Approved by: https://github.com/jansel, https://github.com/dcci
ghstack dependencies: #146369, #146370