Commit Graph

51 Commits

Author SHA1 Message Date
Jason Ansel
e90cf4abcf [inductor] Add some typing to common.py (#145691)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145691
Approved by: https://github.com/malfet
ghstack dependencies: #145690
2025-01-27 06:27:13 +00:00
Nikita Shulga
71caac2b30 [MPSInductor] Add rand support (#145705)
Using Philox4 as PRNG

Test plan (other that CI)
Run
```python
mport torch
from torch._inductor.utils import run_and_get_code
from contextlib import nullcontext

def foo(x):
   return x * torch.randn_like(x)

foo_c = torch.compile(foo)

x = torch.ones(100, 100, device="mps")

y = foo_c(x)

print(y.mean().item(), y.std().item())
for i in range(25):
  print(y[i].mean(), y[i].std())
```
And observe that printed values are close to 0 and 1

TODO: Better `randint` algorithm for large ranges

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145705
Approved by: https://github.com/dcci, https://github.com/jansel
2025-01-27 06:07:36 +00:00
Davide Italiano
57591edca1 [mps/inductor] Add support for erfinv. (#145643)
After several rounds of refactoring, this seems to be done now.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145643
Approved by: https://github.com/malfet, https://github.com/jansel
2025-01-24 22:55:44 +00:00
Nikita Shulga
70ccbade83 [MPSInductor] Add gamma op (#145341)
By moving `gamma` and `log_gamma` implementation from `Gamma.metal` to `c10/metal/special_math.h`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145341
Approved by: https://github.com/Skylion007, https://github.com/dcci
ghstack dependencies: #145309
2025-01-22 19:37:45 +00:00
Nikita Shulga
980c75fe6e [MPSInductor] Add TrueDiv and Round[Int|Decimal] (#145160)
That fixes `test_builtins_round_float_ndigits_neg` and `test_builtins_round`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145160
Approved by: https://github.com/jansel, https://github.com/dcci
2025-01-20 04:29:42 +00:00
Davide Italiano
8cc415774f [mps/inductor] Introduce a metal approx for erf() and use it. (#145161)
Probably we can do better, but this is a start.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145161
Approved by: https://github.com/malfet
2025-01-19 02:29:05 +00:00
Nikita Shulga
cede43e06b [MPSInductor][BE] NaN-propagating min/max to header (#145157)
May be to be later reused from eager op as well

Also, didn't know that Metal already have type_traits
And use `metal::isunorderder(a, b)` instead of `metal::isnan(a + b)` is it is defined as function that is equivalent  `a != a || b != b`, but I suspect it might have a best native implementation for the specific architecture

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145157
Approved by: https://github.com/dcci
2025-01-18 22:52:44 +00:00
Nikita Shulga
8a57234033 [MPSInductor] Implement i0 and i1 ops (#145092)
Using shared definitions with eager op

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145092
Approved by: https://github.com/dcci, https://github.com/jansel
ghstack dependencies: #145023, #145087
2025-01-18 15:41:02 +00:00
Nikita Shulga
41ec2e8d3e [MPSInductor] Fix codegen regression (#144924)
Caused by https://github.com/pytorch/pytorch/pull/144649

Do not try to insert anything into the header if wrapper is not ready yet

Fixes `test_sort_mps`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144924
Approved by: https://github.com/dcci
ghstack dependencies: #144827, #144917
2025-01-16 02:12:42 +00:00
Nikita Shulga
05505771a0 [MPSInductor] Properly convert index (#144917)
By calling `self.index_to_str` from `load`,`store` and `check_bounds` in order to properly handle sizevars variables renames

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144917
Approved by: https://github.com/dcci
ghstack dependencies: #144827
2025-01-16 02:12:41 +00:00
Nikita Shulga
904641769e [MPSInductor] Implement pow() (#144827)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144827
Approved by: https://github.com/dcci, https://github.com/jansel
2025-01-15 20:11:34 +00:00
Nikita Shulga
d2ca8163c0 [MPSInductor] Support abs in MetalPrintExpr (#144826)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144826
Approved by: https://github.com/dcci
ghstack dependencies: #144509, #144798, #144795, #144796
2025-01-15 05:01:25 +00:00
Nikita Shulga
e2251fffbb [MPSInductor] Add min/max to MetalExprPrinter (#144798)
After that `GPUTests::test_avg_pool2d8_mps` and `GPUTests::test_avg_pool2d5_mps` passes

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144798
Approved by: https://github.com/dcci
ghstack dependencies: #144509
2025-01-15 01:43:42 +00:00
Davide Italiano
35b46a75f1 [mps/inductor] Add support for round() (#144731)
With this change, inductor/test_view_on_aliased passes.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144731
Approved by: https://github.com/malfet
2025-01-14 05:56:13 +00:00
Davide Italiano
de9d6a25d7 [mps/inductor] Add support for ceil (#144715)
inductor/test_index_dynamic_shapes passes after this change.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144715
Approved by: https://github.com/malfet
2025-01-14 01:16:47 +00:00
Nikita Shulga
c40d917182 [MPSInductor] Fix maximum/minimum for int types (#144665)
`metal::isnan` is only defined for floats, so provide a generic wrapper
that is false for integral types

TODO: Figure out why type propagantion is not working (or should it?)

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144665
Approved by: https://github.com/dcci
2025-01-13 15:14:01 +00:00
Davide Italiano
417354d953 [mps/inductor] Add support for truncdiv(). (#144666)
Two other inductor tests pass after this change.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144666
Approved by: https://github.com/malfet
2025-01-13 13:39:38 +00:00
Nikita Shulga
7e2239f1f0 [MPSInductor] Better error when kernel fails to compile (#144649)
Now error message looks as follows:
```
% python ../test/inductor/test_torchinductor.py -v -k test_cat_unbacked_2d_mps
test_cat_unbacked_2d_mps (__main__.GPUTests) ... inline_call []
stats [('calls_captured', 6)]
inductor [('extern_calls', 2), ('fxgraph_cache_miss', 1)]
aot_autograd [('total', 1), ('autograd_cache_bypass', 1), ('not_ok', 1)]
ERROR

======================================================================
ERROR: test_cat_unbacked_2d_mps (__main__.GPUTests)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/Users/malfet/git/pytorch/pytorch/torch/testing/_internal/common_utils.py", line 3126, in wrapper
    method(*args, **kwargs)
  File "/Users/malfet/git/pytorch/pytorch/build/../test/inductor/test_torchinductor.py", line 12254, in new_test
    return value(self)
  File "/Users/malfet/miniconda3/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/Users/malfet/git/pytorch/pytorch/build/../test/inductor/test_torchinductor.py", line 5885, in test_cat_unbacked_2d
    self.common(
  File "/Users/malfet/miniconda3/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/Users/malfet/git/pytorch/pytorch/build/../test/inductor/test_torchinductor.py", line 620, in check_model_gpu
    check_model(
  File "/Users/malfet/git/pytorch/pytorch/build/../test/inductor/test_torchinductor.py", line 461, in check_model
    actual = run(*example_inputs, **kwargs)
  File "/Users/malfet/git/pytorch/pytorch/torch/_dynamo/eval_frame.py", line 580, in _fn
    raise e.remove_dynamo_frames() from None  # see TORCHDYNAMO_VERBOSE=1
  File "/Users/malfet/git/pytorch/pytorch/torch/_inductor/compile_fx.py", line 704, in _compile_fx_inner
    raise InductorError(e, currentframe()).with_traceback(
  File "/Users/malfet/git/pytorch/pytorch/torch/_inductor/compile_fx.py", line 689, in _compile_fx_inner
    mb_compiled_graph = fx_codegen_and_compile(
  File "/Users/malfet/git/pytorch/pytorch/torch/_inductor/compile_fx.py", line 1149, in fx_codegen_and_compile
    return scheme.codegen_and_compile(gm, example_inputs, inputs_to_check, graph_kwargs)
  File "/Users/malfet/git/pytorch/pytorch/torch/_inductor/compile_fx.py", line 1064, in codegen_and_compile
    compiled_fn = graph.compile_to_module().call
  File "/Users/malfet/git/pytorch/pytorch/torch/_inductor/graph.py", line 1977, in compile_to_module
    return self._compile_to_module()
  File "/Users/malfet/git/pytorch/pytorch/torch/_inductor/graph.py", line 2018, in _compile_to_module
    mod = PyCodeCache.load_by_key_path(
  File "/Users/malfet/git/pytorch/pytorch/torch/_inductor/codecache.py", line 2768, in load_by_key_path
    mod = _reload_python_module(key, path)
  File "/Users/malfet/git/pytorch/pytorch/torch/_inductor/runtime/compile_tasks.py", line 51, in _reload_python_module
    exec(code, mod.__dict__, mod.__dict__)
  File "/var/folders/sc/2thx6_x95h7_h9qs8s48yh140000gn/T/tmpmyfz2ju8/lt/cltm34ognlgcc6oxoe6bexvtbwcdtdfgnkjj5miz7vhkemitacp7.py", line 40, in <module>
  File "/var/folders/sc/2thx6_x95h7_h9qs8s48yh140000gn/T/tmpmyfz2ju8/lt/cltm34ognlgcc6oxoe6bexvtbwcdtdfgnkjj5miz7vhkemitacp7.py", line 32, in _compile_mps_shader
torch._inductor.exc.InductorError: SyntaxError: failed to compile
    kernel void generated_kernel(
        device float* out_ptr0,
        constant float* in_ptr0,
        uint xindex [[thread_position_in_grid]]
    ) {
        long x1 = (xindex) / (3);
        auto tmp0 = x1;
        auto tmp1 = static_cast<long>(tmp0);
        auto tmp2 = 0;
        auto tmp3 = tmp1 >= tmp2;
        auto tmp4 = 2;
        auto tmp5 = tmp1 < tmp4;
        long x0 = (xindex) % (3);
        auto tmp6 = in_ptr0[x0 + 3*(x1)];
        auto tmp7 = tmp5 ? tmp6 : 0.0;
        auto tmp8 = tmp1 >= tmp4;
        auto tmp9 = 2 + ks0;
        auto tmp10 = static_cast<long>(tmp9);
        auto tmp11 = tmp1 < tmp10;
        auto tmp12 = 1.0;
        auto tmp13 = tmp8 ? tmp12 : 0.0;
        auto tmp14 = tmp5 ? tmp7 : tmp13;
        long x2 = xindex;
        out_ptr0[x2] = static_cast<float>(tmp14);
    }
 with program_source:18:25: error: use of undeclared identifier 'ks0'
        auto tmp9 = 2 + ks0;
                        ^

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information

You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True

To execute this test, run the following from the base repo dir:
    python test/inductor/test_torchinductor.py GPUTests.test_cat_unbacked_2d_mps

This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0

----------------------------------------------------------------------
Ran 1 test in 0.472s

FAILED (errors=1)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144649
Approved by: https://github.com/Skylion007, https://github.com/jansel, https://github.com/dcci
ghstack dependencies: #144647, #144648
2025-01-13 13:38:03 +00:00
Nikita Shulga
a08bd8154e [MPSInductor] Add support for sizevars (#144662)
Just pass them as kernel arguments

After this change  `pytest test/inductor/test_torchinduct.py -v -k _mps` reports 330 failed, 429 passed  after and 335 failed, 424 passed before

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144662
Approved by: https://github.com/jansel
2025-01-13 06:22:38 +00:00
Nikita Shulga
91a65cbd31 [MPSInductor] Implement check_bounds (#144635)
Although at the moment it returns rather than rasises assert due to https://github.com/pytorch/pytorch/pull/144632

`pytest test/inductor/test_torchinductor.py -v -k _mps` score is `368
failed, 391 passed, 32 skipped`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144635
Approved by: https://github.com/jansel
2025-01-12 21:01:20 +00:00
Nikita Shulga
cec245806e [MPSInductor] Implement bitcasts (#144638)
That will be used to compile something like `torch.rand(32, device='mps').view(dtype=torch.int32)`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144638
Approved by: https://github.com/dcci
2025-01-12 06:11:28 +00:00
Nikita Shulga
32a91dedc5 [MPSInductor] Properly generate index expressions (#144632)
Now test_slice_scatter4_mps passes

Before this change test_torchinductor.py reported 422 failed and 337 passed, after this change 412 failed 347 passed.

Fixes https://github.com/pytorch/pytorch/issues/144630

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144632
Approved by: https://github.com/dcci
2025-01-12 06:10:05 +00:00
Davide Italiano
e0f67405a1 [mps/inductor] Add support for exp(). (#144606)
inductor/test_silu now passes after this change.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144606
Approved by: https://github.com/malfet

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
2025-01-12 00:38:11 +00:00
Davide Italiano
5e858254d2 [mps/inductor] Add support for trunc(). (#144629)
inductor/test_div1 passes after this change.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144629
Approved by: https://github.com/malfet, https://github.com/jansel
2025-01-12 00:11:03 +00:00
PyTorch MergeBot
4f406d22a2 Revert "[mps/inductor] Add support for exp(). (#144606)"
This reverts commit 2ccbacfa24.

Reverted https://github.com/pytorch/pytorch/pull/144606 on behalf of https://github.com/malfet due to It now passes MPS-not-supported test ([comment](https://github.com/pytorch/pytorch/pull/144606#issuecomment-2585482477))
2025-01-11 23:51:35 +00:00
Davide Italiano
2ccbacfa24 [mps/inductor] Add support for exp(). (#144606)
inductor/test_silu now passes after this change.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144606
Approved by: https://github.com/malfet
2025-01-11 18:09:33 +00:00
Nikita Shulga
c7f12a4a7b [MPSInductor] Speedup maximum/minumum ops (#144581)
By relying on the fact that if either `a` or `b` is NaN (or both), than `a + b` would also be NaN.

I.e. it replaces
```metal
auto tmp2 = metal::any(metal::isnan(static_cast<decltype(tmp0+tmp1)>(tmp0))) | metal::any(metal::isnan(static_cast<decltype(tmp0+tmp1)>(tmp1))) ? static_cast<decltype(tmp0+tmp1)>(NAN) : metal::max(static_cast<decltype(tmp0+tmp1)>(tmp0), static_cast<decltype(tmp0+tmp1)>(tmp1));
```
with
```metal
auto tmp2 = metal::isnan(tmp0 + tmp1) ? tmp0 + tmp1 : metal::max(static_cast<decltype(tmp0+tmp1)>(tmp0), static_cast<decltype(tmp0+tmp1)>(tmp1));
```

which according to MetalProfiler takes fewer instructions:
<img width="520" alt="image" src="https://github.com/user-attachments/assets/54659392-012b-453e-9c02-c3c5f332074a" />
vs
<img width="1031" alt="image" src="https://github.com/user-attachments/assets/55fcfa78-1ea5-4b0a-8154-d79b3e3cc400" />

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144581
Approved by: https://github.com/dcci, https://github.com/jhavukainen
2025-01-10 22:58:00 +00:00
Nikita Shulga
91cbeb7db9 [MPSInductor] Fix masked/where for inf values (#144500)
Move constant to value logic to `value_to_metal` function (similar to `value_to_cpp`)

Call it from `constant` as well as `where` ops (which is in turn being called from `masked` op

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144500
Approved by: https://github.com/dcci
2025-01-09 23:11:06 +00:00
Davide Italiano
1353f3beb4 [mps/inductor] Add support for fmod(). (#144449)
397 -> 395 tests failing. `static_cast<>` is because there are several overloads of `fmod()` that's otherwise ambiguous. I wonder if we should take in account NaN propagation (maybe it's not tested).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144449
Approved by: https://github.com/malfet
2025-01-09 15:47:41 +00:00
Davide Italiano
6f28e466f3 [mps/inductor] Add support for tanh(). (#144443)
Fixes test_tanh() in the inductor testsuite.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144443
Approved by: https://github.com/malfet
2025-01-09 06:14:03 +00:00
Davide Italiano
8fc0ffe54b [mps/inductor] Add support for rsqrt(). (#144374)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144374
Approved by: https://github.com/malfet
2025-01-08 13:58:05 +00:00
Davide Italiano
551f104153 [mps/inductor] Add support for sign(). (#144298)
Drive-by fix of a test name while I was at it.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144298
Approved by: https://github.com/malfet
2025-01-07 03:33:26 +00:00
Nikita Shulga
16c1b1048b [MPSInductor] Add nan constant generation (#144281)
If val is not equal to self, it's a nan (which is spelled as `NAN` in Metal)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144281
Approved by: https://github.com/atalman, https://github.com/dcci
2025-01-06 22:13:23 +00:00
Nikita Shulga
7d5249dbc2 [EZ][BE] Fix E226 flake8 violation (#144282)
Not sure why CI did not complain about it, but it my local runs it clearly says
```
Advice (FLAKE8) E226
    missing whitespace around arithmetic operator
    See https://www.flake8rules.com/rules/E226.html

        268  |            with code.indent():
        269  |                if len(idx_var_names) > 1:
        270  |                    for idx, name in enumerate(idx_var_names):
    >>> 271  |                        code.writeline(f"auto {name} = thread_pos.{chr(120+idx)};")
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144282
Approved by: https://github.com/Skylion007
2025-01-06 22:12:21 +00:00
Davide Italiano
23e2953cd3 [mps/inductor] Add support for floor(). (#144195)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144195
Approved by: https://github.com/jansel
2025-01-06 02:07:17 +00:00
Joona Havukainen
811c714911 Fix nan propagation for minimum() and maximum() in MPS (#144086)
Fixes #143976

- Moves minimum and maximum operations to use the NaN propagating call into MPSGraph instead of the default one.
 - Adds test for the NaN propagating case to `test_mps.py`.
- Adjusts the inductor metal backend implementation for minimum and maximum to also respect the nan propagation.

Additions by @malfet:
 - Introduce MPSGraph+PyTorchFixups interface following [Customizing existing classes](https://developer.apple.com/library/archive/documentation/Cocoa/Conceptual/ProgrammingWithObjectiveC/CustomizingExistingClasses/CustomizingExistingClasses.html) tutorial and implement `minimumWithNaNPropagationAndIntFallbackWithPrimaryTensor:` as `minimumWithNaNPropagationWithPrimaryTensor:` segfaults when called for integral types

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144086
Approved by: https://github.com/malfet

Co-authored-by: Nikita Shulga <nshulga@meta.com>
2025-01-04 18:48:24 +00:00
Nikita Shulga
b5b1e9456a [MPSInductor] Add masked implementation (#144084)
More or less borrowed from
22580f160e/torch/_inductor/codegen/halide.py (L549-L563)

`pytest test/inductor/test_torchinductor.py -k _mps` score is 408 failed, 347 passed, 32 skipped

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144084
Approved by: https://github.com/Skylion007, https://github.com/jansel
ghstack dependencies: #144167, #144162, #144083
2025-01-04 04:30:07 +00:00
Davide Italiano
479d6f2199 [mps/inductor] Add support for log(). (#144169)
Tested via:

```
 % pytest test/inductor/test_mps_basic.py
 ```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144169
Approved by: https://github.com/jansel, https://github.com/malfet
2025-01-04 03:07:56 +00:00
Nikita Shulga
464b50dbd7 [MPSInductor] Add floor_div and index_expr implementation (#144083)
Simply copy-n-pasted from CPPInductor

`pytest test/inductor/test_torchinductor.py -k _mps` score is 418 failed, 337 passed, 32 skipped

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144083
Approved by: https://github.com/jansel
ghstack dependencies: #144167, #144162
2025-01-04 01:10:01 +00:00
Nikita Shulga
6d25938540 [MPSInductor] Add remainder op (#144162)
For it to return correct result for half precision type it must be
upcast to float

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144162
Approved by: https://github.com/jansel
ghstack dependencies: #144167
2025-01-04 00:47:40 +00:00
Nikita Shulga
f8e1eacf2f [MPSInductor] Extend constant to bool type (#144167)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144167
Approved by: https://github.com/jansel
2025-01-04 00:47:40 +00:00
Nikita Shulga
ad09395674 [MPSInductor] Fix multi rangevar kernel invocation (#144050)
By changing `thread_position_in_grid` type to uint{n} and passing
dimentions during the kernel call

`pytest test/inductor/test_torchinductor.py -k _mps` score is 445 failed, 309 passed, 32 skipped

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144050
Approved by: https://github.com/jansel
ghstack dependencies: #144055, #144051, #144122, #144105, #144156
2025-01-03 19:32:43 +00:00
Nikita Shulga
52e107a7ca [MPSInductor] Add constant, isinf and isnan ops (#144156)
Per Table 6.5 of [Metal Language Specification](https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf) infinity is `HUGE_VALF`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144156
Approved by: https://github.com/Skylion007, https://github.com/jansel
ghstack dependencies: #144055, #144051, #144122, #144105
2025-01-03 19:32:43 +00:00
Davide Italiano
56f6289f6a [mps/inductor] Add support for atanh(). (#144121)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144121
Approved by: https://github.com/jansel, https://github.com/malfet
2025-01-03 18:55:05 +00:00
Nikita Shulga
a7b61c5b49 [MPSInductor] Add signbit op support (#144105)
By mapping it to `metal::signbit`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144105
Approved by: https://github.com/jansel, https://github.com/Skylion007
ghstack dependencies: #144055, #144051, #144122
2025-01-03 18:34:46 +00:00
Nikita Shulga
f7644efa79 [MPSInductor][EZ] Fix logical_[or|end] ops (#144122)
For boolean operands it does not really matter whether `&` or `&&` is
used, but if one ever to rely on operator precedence, then bitwise ops
should have higher precendence than logical ones

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144122
Approved by: https://github.com/huydhn
ghstack dependencies: #144055, #144051
2025-01-03 15:28:07 +00:00
Nikita Shulga
b336d72dae [MPSInductor] Preserve dtype during load (#144051)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144051
Approved by: https://github.com/Skylion007
ghstack dependencies: #144055
2025-01-03 15:17:33 +00:00
Nikita Shulga
5ef0de7615 [MPSInductor] Fix multiple kernel generation (#143998)
At the moment by generating multiple MetalLibraries

`pytest test/inductor/test_torchinductor.py -k _mps` score is 434 failed, 317 passed, 32 skipped

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143998
Approved by: https://github.com/jansel, https://github.com/ruidazeng
ghstack dependencies: #143948, #143949, #143973, #143977
2024-12-31 13:51:50 +00:00
Nikita Shulga
f0f09bb3c2 [MPSInductor] Implement minimum and maximum ops (#143977)
By calling `metal::min` and `metal::max` respectively with argument
typecast to a common type to avoid ambiguous calls errors

TODO: Implement NaN propagation for both eager and compile, see https://github.com/pytorch/pytorch/issues/143976

`pytest test/inductor/test_torchinductor.py -k _mps` score is 460 failed, 291 passed, 32 skipped

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143977
Approved by: https://github.com/jansel
ghstack dependencies: #143948, #143949, #143973
2024-12-31 13:51:50 +00:00
Nikita Shulga
11bb94b7ea [MPSInductor] Fix index generation for transpose (#143973)
Alas, PythonPrinter would not work here, not would CppPrinter, so start building MetalPrinter.

`pytest test/inductor/test_torchinductor.py -k _mps` score is 474 failed, 277 passed, 32 skipped
Before this change:
`pytest test/inductor/test_torchinductor.py -k _mps` reported 506 failed, 245 passed, 32 skipped

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143973
Approved by: https://github.com/jansel
ghstack dependencies: #143948, #143949
2024-12-31 02:04:50 +00:00