Commit Graph

517 Commits

Author SHA1 Message Date
Joona Havukainen
5b96a552df Add a check and error message for no support on MPS for conv with output_channels > 2^16 (#129484)
Fixes the silent correctness issue in #129207 by preventing the user from calling the convolution op on MPS device with an unsupported value.

The fix for the missing support will come later, as it requires work on the kernel side and will take some more time.
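
For illustration, a hedged sketch of the kind of call that should now raise a clear error instead of silently producing wrong results (hypothetical repro; assumes an MPS-capable machine and that the check fires for output channels above 2^16):
```python
import torch

if torch.backends.mps.is_available():
    # out_channels above 2^16 is the unsupported configuration the new check targets
    conv = torch.nn.Conv2d(in_channels=3, out_channels=1 << 17, kernel_size=1).to("mps")
    x = torch.randn(1, 3, 8, 8, device="mps")
    try:
        conv(x)
    except RuntimeError as e:
        print("Raised as expected:", e)
```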
Pull Request resolved: https://github.com/pytorch/pytorch/pull/129484
Approved by: https://github.com/kulinseth
2024-06-28 20:57:40 +00:00
Manuel Candales
eabe6574c0 [metal] Parameterize group_size in int4_mm test, fix int4mm shader for group_size > 128 (#129628)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/129628
Approved by: https://github.com/kimishpatel
2024-06-28 15:01:30 +00:00
Nikita Shulga
bc68907caa [EZ][BE] Replace assertTrue with more appropriate checks (#129569)
Based on https://github.com/pytorch/pytorch/pull/129340#issuecomment-2191228046, i.e.:
- `assertTrue(x == y)` -> `assertEqual(x, y)`
- `assertTrue(not x)` -> `assertFalse(x)`
- `assertTrue(x > y)` -> `assertGreater(x, y)`
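
For illustration, the same transformation in a standalone (hypothetical) unittest; the dedicated assertions produce more informative failure messages:
```python
import unittest

class ExampleTest(unittest.TestCase):
    def test_checks(self):
        x, y = 3, 3
        # Before: self.assertTrue(x == y), self.assertTrue(not (x > y)), self.assertTrue(x + 1 > y)
        # After: dedicated assertions report the compared values on failure
        self.assertEqual(x, y)
        self.assertFalse(x > y)
        self.assertGreater(x + 1, y)

if __name__ == "__main__":
    unittest.main()
```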

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129569
Approved by: https://github.com/jeanschmidt, https://github.com/Skylion007
2024-06-26 16:29:59 +00:00
PyTorch MergeBot
b045878f81 Revert "Remove test_mps_allocator_module XFAIL (#129340)"
This reverts commit c888ee3632.

Reverted https://github.com/pytorch/pytorch/pull/129340 on behalf of https://github.com/huydhn due to The test is now failing again in trunk after a day or so of staying green, we need to continue the investigation ([comment](https://github.com/pytorch/pytorch/pull/129340#issuecomment-2189701706))
2024-06-25 18:37:54 +00:00
Isuru Fernando
e6bfa2958b Add aten._unsafe_masked_index (#116491)
To generate masked indexing operations that produce masked loads in the generated Triton code.
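
A rough Python sketch of the intended semantics (the exact ATen signature may differ; the `fill_value` parameter and the index clamping below are assumptions for illustration only):
```python
import torch

def unsafe_masked_index_reference(x, mask, indices, fill_value=0.0):
    # Where `mask` is True, gather x[indices]; where it is False, return `fill_value`
    # and clamp the index to a valid position so the gather itself stays in bounds.
    safe = tuple(torch.where(mask, idx, torch.zeros_like(idx)) for idx in indices)
    gathered = x[safe]
    return torch.where(mask, gathered, torch.full_like(gathered, fill_value))

x = torch.arange(6.0)
idx = torch.tensor([0, 2, 10])            # last index is out of bounds
mask = torch.tensor([True, True, False])  # ...but it is masked out
print(unsafe_masked_index_reference(x, mask, (idx,)))  # tensor([0., 2., 0.])
```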

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116491
Approved by: https://github.com/lezcano, https://github.com/peterbell10
2024-06-25 02:45:02 +00:00
Isuru Fernando
5f912f480c Fix max_pool2d decomposition for empty list and integer limits (#129106)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/129106
Approved by: https://github.com/peterbell10, https://github.com/lezcano, https://github.com/malfet
ghstack dependencies: #129096, #129097
2024-06-24 22:19:42 +00:00
Huy Do
c888ee3632 Remove test_mps_allocator_module XFAIL (#129340)
Not sure why this test started to fail (maybe a runner update) 8a2fed7e6a/1 or why it was XFAIL in this old PR https://github.com/pytorch/pytorch/pull/97151, but the test is passing locally for me now.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129340
Approved by: https://github.com/kit1980
2024-06-24 16:26:38 +00:00
Manuel Candales
749c03406c [metal] Add int4mm weight packing mps kernel, and improved int4mm shader (#128965)
Adds a `_convert_weight_to_int4pack` MPS kernel.
Replaces the previous int4mm Metal shader with a shader authored by @kimishpatel, which improves perf by ~40%.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128965
Approved by: https://github.com/malfet
2024-06-23 02:10:46 +00:00
Li-Huai (Allan) Lin
799acd31b4 [MPS] Add lu_factor (#99269)

Added MPS support and autograd formulas for LU factorization of tensors. Implemented the `linalg_lu_factor` and `linalg_lu_factor.out` functions for the MPS backend in `LinearAlgebra.mm` and added tests in `test_mps.py`. Added the corresponding dispatch entries in `native_functions.yaml` and the backward and forward formulas in `derivatives.yaml`.
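
For illustration, a minimal usage sketch (assumes an MPS-capable machine and falls back to CPU otherwise; verification is done on CPU):
```python
import torch

device = "mps" if torch.backends.mps.is_available() else "cpu"
A = torch.randn(4, 4, device=device)
LU, pivots = torch.linalg.lu_factor(A)

# Verify the factorization on CPU: P @ L @ U should reconstruct A
P, L, U = torch.lu_unpack(LU.cpu(), pivots.cpu())
print(torch.allclose(P @ L @ U, A.cpu(), atol=1e-4))
```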

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99269
Approved by: https://github.com/kulinseth, https://github.com/lezcano
2024-06-20 07:35:29 +00:00
Li-Huai (Allan) Lin
9a7e2519d3 [MPS] Fused Adam & AdamW (#127242)
Summary:

This PR adds fused Adam and AdamW implementations.

Benchmark on a MacBook Pro with M1 Max chip and 64GB unified memory:
**Fast math enabled:**
```
[---------------------------------------------- Fused Adam ----------------------------------------------]
                                                                           |  Fused: True  |  Fused: False
1 threads: -----------------------------------------------------------------------------------------------
      amsgrad: True, adamWflag: True, numel: 1024, num_tensors: 100        |       10      |       100
      amsgrad: False, adamWflag: True, numel: 1024, num_tensors: 100       |        9      |        89
      amsgrad: True, adamWflag: False, numel: 1024, num_tensors: 100       |        9      |        90
      amsgrad: False, adamWflag: False, numel: 1024, num_tensors: 100      |        9      |        83
      amsgrad: True, adamWflag: True, numel: 65536, num_tensors: 100       |       12      |        94
      amsgrad: False, adamWflag: True, numel: 65536, num_tensors: 100      |       11      |        88
      amsgrad: True, adamWflag: False, numel: 65536, num_tensors: 100      |       12      |        90
      amsgrad: False, adamWflag: False, numel: 65536, num_tensors: 100     |       11      |       100
      amsgrad: True, adamWflag: True, numel: 1048576, num_tensors: 100     |       27      |       100
      amsgrad: False, adamWflag: True, numel: 1048576, num_tensors: 100    |       23      |       100
      amsgrad: True, adamWflag: False, numel: 1048576, num_tensors: 100    |       27      |       100
      amsgrad: False, adamWflag: False, numel: 1048576, num_tensors: 100   |       23      |        98
      amsgrad: True, adamWflag: True, numel: 1024, num_tensors: 500        |       82      |       480
      amsgrad: False, adamWflag: True, numel: 1024, num_tensors: 500       |       72      |       450
      amsgrad: True, adamWflag: False, numel: 1024, num_tensors: 500       |       82      |       450
      amsgrad: False, adamWflag: False, numel: 1024, num_tensors: 500      |       73      |       420
      amsgrad: True, adamWflag: True, numel: 65536, num_tensors: 500       |       91      |       500
      amsgrad: False, adamWflag: True, numel: 65536, num_tensors: 500      |       83      |       400
      amsgrad: True, adamWflag: False, numel: 65536, num_tensors: 500      |       94      |       500
      amsgrad: False, adamWflag: False, numel: 65536, num_tensors: 500     |       78      |       400
      amsgrad: True, adamWflag: True, numel: 1048576, num_tensors: 500     |      170      |       500
      amsgrad: False, adamWflag: True, numel: 1048576, num_tensors: 500    |      140      |       600
      amsgrad: True, adamWflag: False, numel: 1048576, num_tensors: 500    |      170      |       600
      amsgrad: False, adamWflag: False, numel: 1048576, num_tensors: 500   |      140      |       500
      amsgrad: True, adamWflag: True, numel: 1024, num_tensors: 1000       |      250      |       890
      amsgrad: False, adamWflag: True, numel: 1024, num_tensors: 1000      |      220      |       850
      amsgrad: True, adamWflag: False, numel: 1024, num_tensors: 1000      |      250      |       830
      amsgrad: False, adamWflag: False, numel: 1024, num_tensors: 1000     |      220      |       770
      amsgrad: True, adamWflag: True, numel: 65536, num_tensors: 1000      |      270      |       870
      amsgrad: False, adamWflag: True, numel: 65536, num_tensors: 1000     |      230      |       840
      amsgrad: True, adamWflag: False, numel: 65536, num_tensors: 1000     |      270      |       810
      amsgrad: False, adamWflag: False, numel: 65536, num_tensors: 1000    |      240      |       800
      amsgrad: True, adamWflag: True, numel: 1048576, num_tensors: 1000    |      400      |      1000
      amsgrad: False, adamWflag: True, numel: 1048576, num_tensors: 1000   |      360      |      2000
      amsgrad: True, adamWflag: False, numel: 1048576, num_tensors: 1000   |      430      |      2000
      amsgrad: False, adamWflag: False, numel: 1048576, num_tensors: 1000  |      360      |      1300

Times are in milliseconds (ms).
```

**Fast math disabled:**
```
[---------------------------------------------- Fused Adam ----------------------------------------------]
                                                                           |  Fused: True  |  Fused: False
1 threads: -----------------------------------------------------------------------------------------------
      amsgrad: True, adamWflag: True, numel: 1024, num_tensors: 100        |       10      |       100
      amsgrad: False, adamWflag: True, numel: 1024, num_tensors: 100       |        9      |        84
      amsgrad: True, adamWflag: False, numel: 1024, num_tensors: 100       |        9      |        84
      amsgrad: False, adamWflag: False, numel: 1024, num_tensors: 100      |        9      |        79
      amsgrad: True, adamWflag: True, numel: 65536, num_tensors: 100       |       11      |        93
      amsgrad: False, adamWflag: True, numel: 65536, num_tensors: 100      |       10      |        90
      amsgrad: True, adamWflag: False, numel: 65536, num_tensors: 100      |       11      |        91
      amsgrad: False, adamWflag: False, numel: 65536, num_tensors: 100     |       11      |        81
      amsgrad: True, adamWflag: True, numel: 1048576, num_tensors: 100     |       34      |       100
      amsgrad: False, adamWflag: True, numel: 1048576, num_tensors: 100    |       31      |       100
      amsgrad: True, adamWflag: False, numel: 1048576, num_tensors: 100    |       34      |        95
      amsgrad: False, adamWflag: False, numel: 1048576, num_tensors: 100   |       31      |       100
      amsgrad: True, adamWflag: True, numel: 1024, num_tensors: 500        |       94      |       500
      amsgrad: False, adamWflag: True, numel: 1024, num_tensors: 500       |       82      |       430
      amsgrad: True, adamWflag: False, numel: 1024, num_tensors: 500       |       92      |       430
      amsgrad: False, adamWflag: False, numel: 1024, num_tensors: 500      |       81      |       390
      amsgrad: True, adamWflag: True, numel: 65536, num_tensors: 500       |       98      |       500
      amsgrad: False, adamWflag: True, numel: 65536, num_tensors: 500      |       88      |       430
      amsgrad: True, adamWflag: False, numel: 65536, num_tensors: 500      |      100      |       500
      amsgrad: False, adamWflag: False, numel: 65536, num_tensors: 500     |       88      |       400
      amsgrad: True, adamWflag: True, numel: 1048576, num_tensors: 500     |      210      |       500
      amsgrad: False, adamWflag: True, numel: 1048576, num_tensors: 500    |      190      |       610
      amsgrad: True, adamWflag: False, numel: 1048576, num_tensors: 500    |      210      |       510
      amsgrad: False, adamWflag: False, numel: 1048576, num_tensors: 500   |      190      |       500
      amsgrad: True, adamWflag: True, numel: 1024, num_tensors: 1000       |      300      |       900
      amsgrad: False, adamWflag: True, numel: 1024, num_tensors: 1000      |      260      |       850
      amsgrad: True, adamWflag: False, numel: 1024, num_tensors: 1000      |      295      |       900
      amsgrad: False, adamWflag: False, numel: 1024, num_tensors: 1000     |      260      |       800
      amsgrad: True, adamWflag: True, numel: 65536, num_tensors: 1000      |      320      |       910
      amsgrad: False, adamWflag: True, numel: 65536, num_tensors: 1000     |      280      |       900
      amsgrad: True, adamWflag: False, numel: 65536, num_tensors: 1000     |      320      |       900
      amsgrad: False, adamWflag: False, numel: 65536, num_tensors: 1000    |      300      |       900
      amsgrad: True, adamWflag: True, numel: 1048576, num_tensors: 1000    |      500      |      2000
      amsgrad: False, adamWflag: True, numel: 1048576, num_tensors: 1000   |      480      |      2000
      amsgrad: True, adamWflag: False, numel: 1048576, num_tensors: 1000   |      540      |      1500
      amsgrad: False, adamWflag: False, numel: 1048576, num_tensors: 1000  |      480      |      1200

Times are in milliseconds (ms).
```

```python
def profile_fused_adam():
    import itertools

    import torch
    import torch.utils.benchmark as benchmark
    from torch.optim import adam, adamw

    def profile(fn, params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, amsgrad, fused):
        fn(
            params,
            grads,
            exp_avgs,
            exp_avg_sqs,
            max_exp_avg_sqs,
            state_steps,
            foreach=False,
            capturable=False,
            fused=fused,
            amsgrad=amsgrad,
            beta1=0.9,
            beta2=0.99,
            lr=1e-3,
            weight_decay=.0,
            eps=1e-5,
            maximize=False,
            grad_scale=None,
            found_inf=None,
        )
        torch.mps.synchronize()

    device = "mps"

    results = []

    for num_tensors, numel, adamWflag, amsgrad in itertools.product([100, 500, 1000], [1024, 65536, 1048576], [True, False], [True, False]):
        print(f"amsgrad: {amsgrad}, adamWflag: {adamWflag}, numel: {numel}, num_tensors: {num_tensors}")
        params, grads, exp_avgs, exp_avg_sqs = [[torch.arange(numel, dtype=torch.float32, device=device) + (numel * i) for i in range(num_tensors)] for _ in range(4)]
        max_exp_avg_sqs = [torch.arange(numel, dtype=torch.float32, device=device) for _ in range(num_tensors)] if amsgrad else []
        state_steps = [torch.tensor([5], dtype=torch.float32, device=device) for _ in range(num_tensors)]
        if adamWflag:
            fn = adamw.adamw
        else:
            fn = adam.adam

        for fused in [True, False]:

            t = benchmark.Timer(
                    stmt='profile(fn, params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, amsgrad, fused)',
                    label='Fused Adam',
                    sub_label=f"amsgrad: {amsgrad}, adamWflag: {adamWflag}, numel: {numel}, num_tensors: {num_tensors}",
                    globals=locals(),
                    description= f"Fused: {fused}",
                ).blocked_autorange(min_run_time=5)
            results.append(t)

    compare = benchmark.Compare(results)
    compare.trim_significant_figures()
    compare.colorize(rowwise=True)
    compare.print()
```
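
For reference, a minimal sketch of opting into the fused path from user code (hypothetical toy model; assumes an MPS-capable machine, otherwise falls back to CPU):
```python
import torch

device = "mps" if torch.backends.mps.is_available() else "cpu"
model = torch.nn.Linear(16, 4).to(device)
opt = torch.optim.AdamW(model.parameters(), lr=1e-3, fused=True)

loss = model(torch.randn(8, 16, device=device)).sum()
loss.backward()
opt.step()
opt.zero_grad()
```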

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127242
Approved by: https://github.com/kulinseth, https://github.com/janeyx99
2024-06-18 19:59:50 +00:00
Joona Havukainen
d9eaa224f2 Fixes #128429: NaN in triu op on MPS (#128575)
Fixes the triu op when k > 0 and the lower triangle of the input tensor contains inf, which led to NaNs when the result was computed through the complement. Fixed by using the select API instead.

Fixes #128429
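
A hypothetical repro of the reported behavior (assumes an MPS device):
```python
import torch

x = torch.tensor([[1.0, 2.0], [float("-inf"), 3.0]], device="mps")
# Before the fix this could produce NaNs; expected result is [[0., 2.], [0., 0.]]
print(torch.triu(x, diagonal=1))
```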

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128575
Approved by: https://github.com/kulinseth
2024-06-18 03:44:42 +00:00
Nikita Shulga
9035fff2de [BE] Do not test deprecated torch.nn.utils.weight_norm (#128727)
Test `torch.nn.utils.parametrizations.weight_norm` instead
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128727
Approved by: https://github.com/kit1980
ghstack dependencies: #128726
2024-06-14 19:14:44 +00:00
Nikita Shulga
27458cc097 [BE] Refactor repeated code in test_weight_norm (#128726)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128726
Approved by: https://github.com/kit1980
2024-06-14 19:14:44 +00:00
Tom Ritchford
edb45dce85 Add OpInfo entry for as_strided_copy (#127231)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/127231
Approved by: https://github.com/lezcano
2024-06-13 13:58:47 +00:00
Nikita Shulga
0678742924 [MPS] Add Metal implementation of exp op (#128421)
To improve accuracy, use `precise::exp()` (and `precise::sin()`/`precise::cos()` for complex flavor)
Reuse `test_exp1` to check that the accuracy of the `exp` ops is sometimes closer to the CPU results.

Fix bug in non-contiguous tensors handling

Fixes https://github.com/pytorch/pytorch/issues/84936
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128421
Approved by: https://github.com/kulinseth
ghstack dependencies: #128373, #128375
2024-06-13 06:53:17 +00:00
Kulin Seth
8df56afc20 Add support in Python API for the recommended max working set size. (#128289)
Adds a way for users to request the recommended max working set size for Metal on Mac. It plumbs through
https://developer.apple.com/documentation/metal/mtldevice/2369280-recommendedmaxworkingsetsize?language=objc

Can be used like
```
max_memory = torch.mps.recommended_max_memory()
print("Recommended Max Memory : ", (max_memory / (1024 * 1024 * 1024)), "GB")
```

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128289
Approved by: https://github.com/malfet
2024-06-12 16:03:57 +00:00
Tom Ritchford
2386045e4f Add OpInfo entry for alias_copy (#127232) (#128142)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128142
Approved by: https://github.com/lezcano
2024-06-12 09:39:58 +00:00
Joona Havukainen
a5ba9b2858 Fix for addcdiv contiguous problem (#124442)
Fixes #118115
Co-authored-by: Siddharth Kotapati <skotapati@apple.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124442
Approved by: https://github.com/kulinseth
2024-06-06 16:09:18 +00:00
Huy Do
8992141dba Restore MPS testing on MacOS 13 and m2 metal (#127853)
The runners are now ready (https://github.com/organizations/pytorch/settings/actions/runners?qr=label%3Amacos-m1-13); we want to keep some MacOS 13 runners for MPS coverage until MacOS 15 is out.

This also fixes the `macos-m2-14` mistake from https://github.com/pytorch/pytorch/pull/127582.

The current `macos-m2-14` runner is on 14.2 while our `macos-m1-14` has 14.4.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127853
Approved by: https://github.com/malfet
2024-06-05 14:44:00 +00:00
PyTorch MergeBot
d1fad416a8 Revert "Add aten._unsafe_masked_index (#116491)"
This reverts commit f03f8bc901.

Reverted https://github.com/pytorch/pytorch/pull/116491 on behalf of https://github.com/PaliC due to breaking onnx tests ([comment](https://github.com/pytorch/pytorch/pull/116491#issuecomment-2145557724))
2024-06-03 15:51:50 +00:00
Isuru Fernando
f03f8bc901 Add aten._unsafe_masked_index (#116491)
To generate masked indexing operations that produce masked loads in the generated Triton code.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116491
Approved by: https://github.com/lezcano, https://github.com/peterbell10
2024-06-03 14:44:03 +00:00
Nikita Shulga
045309aa35 [MPS] Enable torch.mm and friends for complex dtypes (#127241)
- Add `supportedFloatingOrComplexType`
- Change dtype check to those
- Extend low-precision fp32 list to complex types
- Mark conv2d as supported now, as it was previously failing due to tighter accuracy constraints than the same op for the float32 dtype

Fixes https://github.com/pytorch/pytorch/issues/127178
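
For illustration, a minimal check that complex matmul now runs on MPS (assumes an MPS-capable machine; tensors are created on CPU first since complex random generation on MPS is not assumed):
```python
import torch

a = torch.randn(3, 3, dtype=torch.cfloat).to("mps")
b = torch.randn(3, 3, dtype=torch.cfloat).to("mps")
out = torch.mm(a, b)
print(torch.allclose(out.cpu(), torch.mm(a.cpu(), b.cpu()), atol=1e-5))
```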

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127241
Approved by: https://github.com/janeyx99
2024-05-28 17:56:13 +00:00
Nikita Shulga
4ff9113e3d [MPS] Add _weight_int8pack_mm tests (#127041)
As well as extend the test to cover MV cases (where the A matrix is 1xM). Limit int8 op testing to 32x32 matrix sizes for now.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127041
Approved by: https://github.com/larryliu0820, https://github.com/manuelcandales
2024-05-24 16:08:06 +00:00
jhavukainen
6a539e80dd Update descriptor fields to resolve fft precision issue (#125328)
Fixes #124096
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125328
Approved by: https://github.com/kulinseth, https://github.com/malfet
2024-05-22 21:48:49 +00:00
jhavukainen
d28868c7e8 Change skipIfs to xfails in test_mps.py for test_isin (#125412)
Follow-up to #124896 to move the added test to use expectedFailure instead of skip.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125412
Approved by: https://github.com/kulinseth
2024-05-20 20:23:53 +00:00
Nikita Shulga
b8a706a321 [EZ][BE] Use untyped_storage in tests (#125838)
Gets rid of the following warning:
```
/Users/shenke/workspace/pytorch/test/test_mps.py:9229: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly.  To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()
  if base.storage().data_ptr() != other.storage().data_ptr():
```

(noticed while looking at https://github.com/pytorch/pytorch/issues/96153#issuecomment-2101876484 )

The respective change to view ops landed back in 2022; see https://github.com/pytorch/pytorch/pull/91414

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125838
Approved by: https://github.com/albanD
2024-05-09 14:04:21 +00:00
Nikita Shulga
4e29e80bf0 Run MPS tests on MacOS Sonoma (#125801)
Those runners are running 14.4.1, so I wonder if they actually pass CI.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125801
Approved by: https://github.com/kit1980, https://github.com/huydhn
2024-05-09 13:43:12 +00:00
Denis Vieriu
58e045d03c [MPS] Fix strided ELU op (#125692)
Fixes https://github.com/pytorch/pytorch/issues/124834

Summary of changes:

In case of non-contiguous input, the output would be non-contiguous too. At the moment it's not supported to save the result to a non-contiguous buffer, so we need two steps: one to allocate a contiguous buffer and a second one to scatter the result back to the original output.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125692
Approved by: https://github.com/kulinseth
2024-05-08 01:34:40 +00:00
Denis Vieriu
ba27548679 [MPS] Remove in place views (causes too many crashes) (#124895)
Fixes https://github.com/pytorch/pytorch/issues/96153

Remove in-place views, as they are a general cause of many crashes.
A proper fix to handle views without copies will come in a different PR.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124895
Approved by: https://github.com/kulinseth
2024-05-08 01:00:37 +00:00
Denis Vieriu
3fb53bb6a7 [MPS] Fix strided mse_loss (#125696)
Fixes https://github.com/pytorch/pytorch/issues/124621

Summary of changes:
- In case of non-contiguous input, the output would be non-contiguous too. At the moment it's not supported to save the result to a non-contiguous buffer, so we need two steps: one to allocate a contiguous buffer and a second one to scatter the result back to the original output.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125696
Approved by: https://github.com/kulinseth
2024-05-08 00:52:26 +00:00
Nikita Shulga
0fd1fc17c3 [MPS] Fix abs for complex types (#125662)
By calling `realPartOfTensor:` if the input type is complex on Sonoma, and falling back to the `at::view_as_real` trick on Ventura.

Split `unary_op` template into `unary_op` and `unary_op_noresize`, which skips resize and empty checks

Marked the `abs`, `isclose` and `nn.functional.softsign` OpInfo tests as supported for complex types

Fixes https://github.com/pytorch/pytorch/issues/125135

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125662
Approved by: https://github.com/kulinseth
2024-05-07 22:15:20 +00:00
Nikita Shulga
30610251ec [MPS] Add naive quantized intmm and .gputrace capture hooks (#125163)
- Implement a very straightforward Metal copy of CPU int4mm kernel
- Implement int8mm kernel by constructing a graph consisting of upcast, transpose and mm
- Add `isCapturing`, `isCaptureEnabled`, `startCapture` and `stopCapture` methods to `MPSProfile` which can be used to help one debug/profile Metal kernels by wrapping the calls with the following
  ```cpp
   if (getMPSProfiler().isCaptureEnabled()) {
     getMPSProfiler().startCapture(__func__, mpsStream);
   }
   ...
   if (getMPSProfiler().isCapturing()) {
     getMPSProfiler().stopCapture(mpsStream);
   }
  ```
  that, if invoked with the `MTL_CAPTURE_ENABLED` environment variable set to one, will produce .gputrace files in the current working directory, which can later be loaded and used to debug or profile the kernel

- Added `test_int4mm` to TestLinalgMPS, which is mostly copy-n-paste of the test from `test_linalg`

TODOs:
 - Add weight pack
 - Perf-tune both kernels
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125163
Approved by: https://github.com/mikekgfb
2024-05-03 15:20:39 +00:00
Denis Vieriu
a40d6df448 [MPS] Native nonzero implementation (#125355)
Fixes https://github.com/pytorch/pytorch/issues/124850

Replace the previous MPSGraph nonzero construction with a native nonzero op. For older OSes, fall back to the CPU (the previous implementation was not reliable and was comparable to the CPU in speed).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125355
Approved by: https://github.com/kulinseth
2024-05-03 03:50:58 +00:00
Roy Hvaara
e15da7856c [MPS] Fix overflow in cumsum when dtype is bool (#125318)
`cumsum` and `cumprod` were (are?) buggy for MPS: c8d2a55273/aten/src/ATen/native/mps/operations/UnaryOps.mm (L435-L436)

A workaround casts the input to int32 prior to performing the op to prevent overflow for certain numeric types.

It turns out this issue also affects boolean types:

```python
import torch
print(torch.ones(128, dtype=torch.bool, device="mps").cumsum(0)[-1])
# tensor(-128, device='mps:0')
```

In this PR I'm adding logic to also cast bool dtypes to int32 prior to `cumsum` and `cumprod`, although output is guaranteed not to overflow for the latter with bools. I'm also adding a test to prevent regressions.

Fixes #96614 #106112 #109166

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125318
Approved by: https://github.com/malfet
2024-05-03 01:19:24 +00:00
Joona Havukainen
c451d108da Implemented isin_Tensor_Tensor_out for MPS backend (#124896)
Addresses issue #124518 by adding `isin_Tensor_Tensor_out`.

Tests added to test_mps.py.
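
For illustration, a minimal usage example (assumes an MPS device):
```python
import torch

elements = torch.tensor([1, 2, 3, 4], device="mps")
test_elements = torch.tensor([2, 4], device="mps")
print(torch.isin(elements, test_elements))
# tensor([False,  True, False,  True], device='mps:0')
```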

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124896
Approved by: https://github.com/malfet, https://github.com/kulinseth
2024-05-01 23:14:05 +00:00
Nikita Shulga
5944a53555 [MPS] Fix nextafter for negative values (#125029)
By changing the logic on older MacOS to:
```cpp
bits += ((input > 0) ^ (input > other)) ? 1 : -1;
```
And use native `nextafter` on MacOS Sonoma (i.e. if Metal 3.1 is available)

TODO:
  - Add tests for infs and denorms

Fixes https://github.com/pytorch/pytorch/issues/124985

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125029
Approved by: https://github.com/Skylion007
2024-04-27 02:58:05 +00:00
Nikita Shulga
db3a2d751c [MPS][BE] Error-check linear (#124952)
Validate that all arguments are on MPS devices and that their dtypes are as expected

Fixes cryptic messages like
```
% python3 -c "import torch;print(torch.nn.functional.linear(torch.rand(32, 32), torch.rand((32, 32), device='mps')))"
RuntimeError: Placeholder storage has not been allocated on MPS device!
```
And hard crashes like
```
% python3 -c "import torch;print(torch.nn.functional.linear(torch.rand(32, 32, device='mps'), torch.randint(-10, 10, (32, 32), dtype=torch.int8, device='mps')))"
```

Fixes https://github.com/pytorch/pytorch/issues/123995

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124952
Approved by: https://github.com/Skylion007
2024-04-25 23:25:20 +00:00
Nikita Shulga
abf3f90781 [MPS] Fix large copy (#124635)
By slicing `copyFromBuffer:sourceOffset:toBuffer:destinationOffset:size:` into 2Gb chunks

Add a regression test, but limit it to machines with 12GB of RAM or more and MacOS 14+, as on MacOS 13 an attempt to allocate a 4GB tensor fails with:
```
/AppleInternal/Library/BuildRoots/c651a45f-806e-11ed-a221-7ef33c48bc85/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShaders/MPSCore/Types/MPSNDArray.mm:724: failed assertion `[MPSNDArray initWithDevice:descriptor:] Error: total bytes of NDArray > 2**32'
```

Fixes https://github.com/pytorch/pytorch/issues/124335

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124635
Approved by: https://github.com/kulinseth
2024-04-22 23:43:11 +00:00
Aaron Gokaslan
5a1216bb2e [BE]: Update ruff to 0.4.1 (#124549)
Update ruff to 0.4.1.
This version fixes a lot of false negatives/false positives, is 20-40% faster, and has various other bug fixes.

Below is a before-and-after table showing the execution time of ruff lint and ruff format in milliseconds, courtesy of https://astral.sh/blog/ruff-v0.4.0

| Repository                                         | Linter (v0.3) | Linter (v0.4) | Formatter (v0.3) | Formatter (v0.4) |
|----------------------------------------------------|---------------|---------------|------------------|------------------|
| [pytorch/pytorch](https://github.com/pytorch/pytorch) | 328.7         | 251.8         | 351.1            | 274.9            |

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124549
Approved by: https://github.com/ezyang
2024-04-21 14:06:23 +00:00
Joël Tang
a6a3f2e06b [MPS] Fixes GELU, LeakyRELU and MISH on non-contiguous tensors (#123049)
Fixes the GELU, LeakyRELU and MISH activation functions on non-contiguous tensors (for instance, when a transpose operation was applied to the tensors prior to the MPS operator), for both the forward and backward passes.

I also extended the tests for the 3 activation functions to check full precision and half precision, contiguous and non-contiguous inputs, and several tensor dims: scalars, 1D, empty, 2D, > 3D.

I had issues with the Mish and GELU activations when asserting the gradients vs. CPU with sum() in some cases, so I reverted to the previous setup by setting a gradient parameter on .backward().
This PR also fixes an issue with LeakyRELU on empty tensors.

Fixes #98212 huggingface/transformers#22468 huggingface/transformers#19353
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123049
Approved by: https://github.com/kulinseth
2024-04-21 00:12:32 +00:00
Nikita Shulga
5677128cb8 [MPS] Fix crash with binary_cross_entropy is invoked for half dtypes (#124258)
By creating constants using the input tensor's dtype

One line reproducer:
```
python -c "import torch; x=torch.arange(3, dtype=torch.float16,device='mps');print(torch.nn.functional.binary_cross_entropy(x, x))"
```

Before the change
```
loc("mps_subtract"("(mpsFileLoc): /AppleInternal/Library/BuildRoots/ce725a5f-c761-11ee-a4ec-b6ef2fd8d87b/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphUtilities.mm":233:0)): error: input types 'tensor<f32>' and 'tensor<3xf16>' are not broadcast compatible
LLVM ERROR: Failed to infer result type(s).
```
After
```
tensor(-33.7812, device='mps:0', dtype=torch.float16)
```

Fixes https://github.com/pytorch/pytorch/issues/124252

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124258
Approved by: https://github.com/kulinseth
2024-04-18 15:21:01 +00:00
xinan.lin
6fcbeb3489 [ATen] Add CPU fp16 support for nll_loss and cross_entropy_loss (#123256)
Add CPU FP16 support for nll_loss and cross_entropy_loss.
Resolve issue #123328.
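
For illustration, a minimal example that exercises the new CPU FP16 path:
```python
import torch
import torch.nn.functional as F

logits = torch.randn(4, 10, dtype=torch.half)
target = torch.randint(0, 10, (4,))
loss = F.cross_entropy(logits, target)
print(loss.dtype, loss.item())
```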

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123256
Approved by: https://github.com/jgong5, https://github.com/EikanWang, https://github.com/malfet
2024-04-18 11:44:38 +00:00
Pearu Peterson
d2b0c0a34e Fix index_reduce sampler filter when op_info.variant_test_name is specified (#123375)
As in the title: `index_reduce` sample must correspond to reduction type specified by `variant_test_name`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123375
Approved by: https://github.com/zou3519, https://github.com/peterbell10
2024-04-17 15:31:28 +00:00
FFFrog
acc466751b Add bfloat16 support to binary_cross_entropy for CPU (#123823)
Fixes #123715

As the title states.

But maybe we should pay attention to https://github.com/pytorch/pytorch/pull/33206, which removed half support for CPU about 4 years ago.
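
For illustration, a minimal example of the newly supported dtype on CPU:
```python
import torch
import torch.nn.functional as F

pred = torch.rand(4, dtype=torch.bfloat16)
target = torch.rand(4, dtype=torch.bfloat16)
print(F.binary_cross_entropy(pred, target))
```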

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123823
Approved by: https://github.com/Skylion007, https://github.com/malfet
2024-04-17 09:44:07 +00:00
Joona Havukainen
05289a278c Fix for MPS regression in #122016 and #123178 (#123234)
Fixes #122016 and #123178. This regression is related to an OS-side change that requires a slight adjustment on the PyTorch side to restore the previous behavior. Additionally, we cleared out pre-MacOS 13 workarounds.

Before the fix on MacOS 14.4:

```
python -c "import torch;x=torch.zeros(3, device='mps');x[1] = 1; x[2] = 3; print(x)"
tensor([0., 3., 3.], device='mps:0')
```

After the fix:
```
python -c "import torch;x=torch.zeros(3, device='mps');x[1] = 1; x[2] = 3; print(x)"
tensor([0., 1., 3.], device='mps:0')
```

This also fixes complex number initialization and as such makes `nn.functional.rms_norm` pass on MacOS-14+

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123234
Approved by: https://github.com/malfet, https://github.com/kulinseth
2024-04-03 23:00:57 +00:00
PyTorch MergeBot
feabb645a7 Revert "Handle transposes in second batch of matrices in bmm (#122194)"
This reverts commit 251ad1232b.

Reverted https://github.com/pytorch/pytorch/pull/122194 on behalf of https://github.com/malfet due to Broke lint ([comment](https://github.com/pytorch/pytorch/pull/122194#issuecomment-2032806360))
2024-04-02 18:49:28 +00:00
Kulin Seth
251ad1232b Handle transposes in second batch of matrices in bmm (#122194)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122194
Approved by: https://github.com/DenisVieriu97
2024-04-02 17:48:35 +00:00
Nikita Shulga
4c70ab26ef [MPS] Enable index_select for complex types (#122590)
Surprisingly, as of MacOS-14.14, MPS `gatherWithUpdatesTensor:indicesTensor:axis:batchDimensions:name:` still does not support complex types, so emulate them by using the `at::view_as_real` trick.

Fixes https://github.com/pytorch/pytorch/issues/122427
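
For illustration, a minimal example (assumes an MPS device; the complex tensor is created on CPU first):
```python
import torch

x = torch.randn(4, 3, dtype=torch.cfloat).to("mps")
idx = torch.tensor([0, 2], device="mps")
print(torch.index_select(x, 0, idx).cpu())
```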

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122590
Approved by: https://github.com/Skylion007
2024-03-25 16:57:35 +00:00
andrewor14
773ae817f7 Batch Norm Consolidation (#116092)
**Summary:**

This commit simplifies the existing decomposition hierarchy
of batch norm ops by adding a single, backend agnostic op:
`batch_norm_with_update`. The existing hierarchy looks like:

```
aten.batch_norm ->
aten._batch_norm_impl_index ->
[
  aten.native_batch_norm ->
  aten._native_batch_norm_legit (export only) ->
  _batch_norm_legit_cpu/cuda (kernels, export only) ->
  _batch_norm_cpu/cuda (kernels)
] OR
[ aten.cudnn_batch_norm ] OR
[ aten.miopen_batch_norm ]
```

Aside from complexity, an important problem with the
above decomposition hierarchy is cuda numerics in
export flows. We observed significantly worse convergence
when training a mobilenetv2-like model when using the
`_batch_norm_cuda` kernel instead of the `cudnn_batch_norm`
kernel. This means users who export their models on CPU
first then move the models to cuda later may silently
see worse accuracies even when cudnn is installed,
because they are using the worse kernel. This issue is
summarized in https://github.com/pytorch/pytorch/issues/111384.

Instead, the new hierarchy proposed by consolidating
existing batch norm ops will look like:

```
aten.batch_norm ->
aten.batch_norm_with_update ->
[ _batch_norm_cpu (kernel) ] OR
[ _batch_norm_cuda (kernel) ] OR
[ cudnn_batch_norm (kernel) ] OR
[ miopen_batch_norm (kernel) ]
```

The new op `batch_norm_with_update` hides backend
implementation details and automatically picks the right
kernel based on what is installed. This commit also adds
the following variants to this op:

```
batch_norm_with_update_functional
batch_norm_with_update.out
batch_norm_no_update
batch_norm_no_update.out
batch_norm_backward
```

Note that this commit only adds this op and its variants,
but does not actually change the decomps to produce these
ops in the graph. This will be done after the 2 week FC
window, and the ops used in the old stack are planned to
be removed after the 6 month BC window.

Test Plan: `OpInfo` tests for `batch_norm_with_update`.

Reviewers: albanD, bdhirsh

Subscribers: albanD, bdhirsh, supriyar

Tasks: https://github.com/pytorch/pytorch/issues/111384

Differential Revision: [D54805279](https://our.internmc.facebook.com/intern/diff/D54805279)
Co-authored-by: Tugsbayasgalan Manlaibaatar <tmanlaibaatar@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/116092
Approved by: https://github.com/bdhirsh, https://github.com/albanD
2024-03-18 21:01:30 +00:00
Roger Lam
40acc84aaf Fix torch.clamp in MPS to handle NaN correctly (#121381)
Fixes #120899

So this is interesting. There are methods that specifically propagate NaN instead of clamping to real numbers.
https://developer.apple.com/documentation/metalperformanceshadersgraph/mpsgraph/3857573-maximumwithnanpropagationwithpri
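
A hypothetical check of the expected behavior after the fix (assumes an MPS device):
```python
import torch

x = torch.tensor([1.0, float("nan"), 3.0], device="mps")
# NaN should propagate through clamp rather than being replaced by a bound
print(torch.clamp(x, min=0.0, max=2.0))  # expected: tensor([1., nan, 2.], device='mps:0')
```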

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/121381
Approved by: https://github.com/malfet
2024-03-18 19:38:15 +00:00
PyTorch MergeBot
0cc60a05da Revert "Fix torch.clamp in MPS to handle NaN correctly (#121381)"
This reverts commit ca80d07ac7.

Reverted https://github.com/pytorch/pytorch/pull/121381 on behalf of https://github.com/huydhn due to Sorry for reverting your change but I think its test is failing in trunk https://github.com/pytorch/pytorch/actions/runs/8302739752/job/22725865151#step:7:644, we should have ciflow/mps to run the test on PR.  Please take a look a reland the change ([comment](https://github.com/pytorch/pytorch/pull/121381#issuecomment-2000685856))
2024-03-15 23:53:05 +00:00
Roger Lam
ca80d07ac7 Fix torch.clamp in MPS to handle NaN correctly (#121381)
Fixes #120899

So this is interesting. There are methods that specifically propagate NaN instead of clamping to real numbers.
https://developer.apple.com/documentation/metalperformanceshadersgraph/mpsgraph/3857573-maximumwithnanpropagationwithpri

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121381
Approved by: https://github.com/malfet
2024-03-15 21:54:50 +00:00
Nikita Shulga
5498804ec2 [MPS] Fix naive matmul for BFloat16 (#121731)
Will only work on MacOS14 or newer, so compile the shader with `MTLLanguageVersion_3_1` when appropriate

Fixes https://github.com/pytorch/pytorch/issues/121583
Pull Request resolved: https://github.com/pytorch/pytorch/pull/121731
Approved by: https://github.com/albanD
2024-03-13 14:34:03 +00:00
Nikita Shulga
07330ff7b6 [MPS][BE] Define _compute_tolerances (#121754)
Right now the logic is mostly duplicated between `test_output_match` and `test_output_gradient_match`.
So move the tolerance definition logic into a shared `_compute_tolerances` function and
only keep the differences (for example, grad checks are completely skipped for `torch.unique`) in the respective test functions.

Also, increase the tolerance for `pow` and `__rpow__` only on MacOS 13.3 or older and remove the GRAD xfail list entries for those.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121754
Approved by: https://github.com/albanD
2024-03-13 04:08:06 +00:00
PyTorch MergeBot
fd0dbcd891 Revert "Batch Norm Consolidation (#116092)"
This reverts commit 7b4f70eda5.

Reverted https://github.com/pytorch/pytorch/pull/116092 on behalf of https://github.com/osalpekar due to Causes build failure in //caffe2:aten-hip (AMD build) target. See [D54707318](https://www.internalfb.com/diff/D54707318) for more details, may require internal build system changes to resolve. ([comment](https://github.com/pytorch/pytorch/pull/116092#issuecomment-1989542965))
2024-03-11 22:22:41 +00:00
Boyuan Feng
35d3adb4b0 Add ATen Op _chunk_cat and _chunk_cat.out (#121081)
# Motivation

In the backward pass of per-parameter sharding FSDP, each rank performs a reduce-scatter to sync gradients across ranks. A rank chunks each gradient tensor into `world_size` slices along the 0-th dimension and concatenates all slices along the 1st dimension. Gradient tensors will be padded before concatenation when tensor.size(0) % world_size != 0.

### Example 1
Consider `world_size=3` and tensors A (2x4), B (3x3), C (1x2):

Input tensors:
```
AAAA   BBB   CC
AAAA   BBB
       BBB
```

Reduce-scatter-copy-in Output:
```
AAAABBBCC
AAAABBB00
0000BBB00
```

### Example 2
Consider `world_size=2` and tensors A (2x4), B (3x3), C(1x2), D(4x2):

Input tensors:
```
AAAA   BBB   CC   DD
AAAA   BBB   00   DD
       BBB        DD
       000        DD
```

Reduce-scatter-copy-in first pad:
```
AAAA   BBB   CC   DD
AAAA   BBB   00   DD
       BBB        DD
       000        DD
```

Then chunk and cat along dim as the output:
```
AAAABBBBBBCCDDDD
AAAABBB00000DDDD
```

The performance of reduce-scatter-copy-in is critical to per-parameter sharding FSDP. However, implementing reduce-scatter-copy-in by composing existing ATen ops involves `cat` and irregular `pad`, leading to redundant data copies and unsatisfactory performance.

# PR
We provide aten native support for reduce-scatter-copy-in, namely `_chunk_cat()`:

```
_chunk_cat(Tensor[] tensors, int dim, int num_chunks) -> Tensor
```

This PR includes the registration of `_chunk_cat` and `_chunk_cat.out`, OpInfo tests, and basic implementation composing existing ATen ops.
In the next PR, we will add the CUDA implementation. Compared with baselines composed of existing ATen ops, the `_chunk_cat()` CUDA implementation improves copy bandwidth from 498 GB/s to 966 GB/s on a production benchmark.
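
A rough reference implementation composed of existing ops, matching the dim=0 examples above (`chunk_cat_reference` is a hypothetical name for illustration only, not the actual ATen implementation):
```python
import torch
import torch.nn.functional as F

def chunk_cat_reference(tensors, dim, num_chunks):
    rows = [[] for _ in range(num_chunks)]
    for t in tensors:
        pad = (-t.size(dim)) % num_chunks            # pad `dim` up to a multiple of num_chunks
        if pad:
            pad_spec = [0, 0] * (t.dim() - dim - 1) + [0, pad]
            t = F.pad(t, pad_spec)
        for i, chunk in enumerate(t.chunk(num_chunks, dim=dim)):
            rows[i].append(chunk.flatten(start_dim=dim))  # flatten each chunk past `dim`
    # one row per chunk index; each row is the concatenation over all input tensors
    return torch.stack([torch.cat(r, dim=-1) for r in rows], dim=dim)

# Example 1 above: A (2x4), B (3x3), C (1x2), world_size (num_chunks) = 3
A, B, C = torch.ones(2, 4), 2 * torch.ones(3, 3), 3 * torch.ones(1, 2)
print(chunk_cat_reference([A, B, C], dim=0, num_chunks=3).shape)  # torch.Size([3, 9])
```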

## Requirements on input

1. If input tensors have different ndims, dim should be non-negative and less than the ndims of every input tensor. If all input tensors have the same ndims, we support both negative and non-negative dim.
2. For wrapped_dim, all tensors should have the same size for 0,...,wrapped_dim-1 dimensions. No requirements for (wrapped_dim, ...)-th dimension.
3. Expect positive num_chunks
4. Expect non-empty input tensor list and each input tensor should have at least 1 element

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121081
Approved by: https://github.com/albanD
2024-03-08 21:48:12 +00:00
Nikita Shulga
9b03a06288 [BE] [MPS] Fix out resize logic in torch.where (#121476)
By deleting `where_mps` and registering an MPS dispatch for `where_kernel`.
As a result of this change, the resizing and type-checking logic is shared between the MPS, CPU and CUDA backends.

Add a test case to `TestMPS.test_where` (it should eventually be removed when `out` OpInfo testing is enabled for MPS).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/121476
Approved by: https://github.com/albanD, https://github.com/Skylion007
ghstack dependencies: #121473, #121494
2024-03-08 18:59:37 +00:00
andrewor14
7b4f70eda5 Batch Norm Consolidation (#116092)
**Summary:**

This commit simplifies the existing decomposition hierarchy
of batch norm ops by adding a single, backend agnostic op:
`batch_norm_with_update`. The existing hierarchy looks like:

```
aten.batch_norm ->
aten._batch_norm_impl_index ->
[
  aten.native_batch_norm ->
  aten._native_batch_norm_legit (export only) ->
  _batch_norm_legit_cpu/cuda (kernels, export only) ->
  _batch_norm_cpu/cuda (kernels)
] OR
[ aten.cudnn_batch_norm ] OR
[ aten.miopen_batch_norm ]
```

Aside from complexity, an important problem with the
above decomposition hierarchy is cuda numerics in
export flows. We observed significantly worse convergence
when training a mobilenetv2-like model when using the
`_batch_norm_cuda` kernel instead of the `cudnn_batch_norm`
kernel. This means users who export their models on CPU
first then move the models to cuda later may silently
see worse accuracies even when cudnn is installed,
because they are using the worse kernel. This issue is
summarized in https://github.com/pytorch/pytorch/issues/111384.

Instead, the new hierarchy proposed by consolidating
existing batch norm ops will look like:

```
aten.batch_norm ->
aten.batch_norm_with_update ->
[ _batch_norm_cpu (kernel) ] OR
[ _batch_norm_cuda (kernel) ] OR
[ cudnn_batch_norm (kernel) ] OR
[ miopen_batch_norm (kernel) ]
```

The new op `batch_norm_with_update` hides backend
implementation details and automatically picks the right
kernel based on what is installed. This commit also adds
the following variants to this op:

```
batch_norm_with_update_functional
batch_norm_with_update.out
batch_norm_no_update
batch_norm_no_update.out
batch_norm_backward
```

Note that this commit only adds this op and its variants,
but does not actually change the decomps to produce these
ops in the graph. This will be done after the 2 week FC
window, and the ops used in the old stack are planned to
be removed after the 6 month BC window.

Test Plan: `OpInfo` tests for `batch_norm_with_update`.

Reviewers: albanD, bdhirsh

Subscribers: albanD, bdhirsh, supriyar

Tasks: https://github.com/pytorch/pytorch/issues/111384

Co-authored-by: Tugsbayasgalan Manlaibaatar <tmanlaibaatar@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/116092
Approved by: https://github.com/bdhirsh, https://github.com/albanD
2024-03-08 15:07:15 +00:00
PyTorch MergeBot
b529c19bdf Revert "Batch Norm Consolidation (#116092)"
This reverts commit 5680f565d5.

Reverted https://github.com/pytorch/pytorch/pull/116092 on behalf of https://github.com/jeffdaily due to broke ROCm, PR signal was clean but trunk was not, the merge should have been blocked but wasn't ([comment](https://github.com/pytorch/pytorch/pull/116092#issuecomment-1981373237))
2024-03-06 17:10:01 +00:00
Tugsbayasgalan Manlaibaatar
5680f565d5 Batch Norm Consolidation (#116092)
**Summary:**

This commit simplifies the existing decomposition hierarchy
of batch norm ops by adding a single, backend agnostic op:
`batch_norm_with_update`. The existing hierarchy looks like:

```
aten.batch_norm ->
aten._batch_norm_impl_index ->
[
  aten.native_batch_norm ->
  aten._native_batch_norm_legit (export only) ->
  _batch_norm_legit_cpu/cuda (kernels, export only) ->
  _batch_norm_cpu/cuda (kernels)
] OR
[ aten.cudnn_batch_norm ] OR
[ aten.miopen_batch_norm ]
```

Aside from complexity, an important problem with the
above decomposition hierarchy is cuda numerics in
export flows. We observed significantly worse convergence
when training a mobilenetv2-like model when using the
`_batch_norm_cuda` kernel instead of the `cudnn_batch_norm`
kernel. This means users who export their models on CPU
first then move the models to cuda later may silently
see worse accuracies even when cudnn is installed,
because they are using the worse kernel. This issue is
summarized in https://github.com/pytorch/pytorch/issues/111384.

Instead, the new hierarchy proposed by consolidating
existing batch norm ops will look like:

```
aten.batch_norm ->
aten.batch_norm_with_update ->
[ _batch_norm_cpu (kernel) ] OR
[ _batch_norm_cuda (kernel) ] OR
[ cudnn_batch_norm (kernel) ] OR
[ miopen_batch_norm (kernel) ]
```

The new op `batch_norm_with_update` hides backend
implementation details and automatically picks the right
kernel based on what is installed. This commit also adds
the following variants to this op:

```
batch_norm_with_update_functional
batch_norm_with_update.out
batch_norm_no_update
batch_norm_no_update.out
batch_norm_backward
```

Note that this commit only adds this op and its variants,
but does not actually change the decomps to produce these
ops in the graph. This will be done after the 2 week FC
window, and the ops used in the old stack are planned to
be removed after the 6 month BC window.

Test Plan: `OpInfo` tests for `batch_norm_with_update`.

Reviewers: albanD, bdhirsh

Subscribers: albanD, bdhirsh, supriyar

Tasks: https://github.com/pytorch/pytorch/issues/111384

Co-authored-by: Tugsbayasgalan Manlaibaatar <tmanlaibaatar@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/116092
Approved by: https://github.com/bdhirsh, https://github.com/albanD
2024-03-06 04:50:46 +00:00
Kai
c59b14163b Implement aten::upsample_linear1d on mps (#115031)
Related to #77764

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115031
Approved by: https://github.com/malfet
2024-02-26 23:04:52 +00:00
Nikita Shulga
53bfae2c06 [MPS] Add torch.fft. support (#119670)
Increase tolerance for `fft` ops; that warrants further investigation, as it grows larger with larger matrix dimensions (see https://github.com/pytorch/pytorch/issues/120237 )

When compiling on MacOS13, implement `+[FakeMPSGraphFFTDescriptor descriptor]` as a redispatch to a real thing.

Fixes https://github.com/pytorch/pytorch/issues/78044
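
For illustration, a minimal round-trip check (assumes an MPS device):
```python
import torch

x = torch.randn(1024, device="mps")
X = torch.fft.rfft(x)
x_rec = torch.fft.irfft(X, n=x.numel())
print(torch.allclose(x.cpu(), x_rec.cpu(), atol=1e-4))
```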
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119670
Approved by: https://github.com/kulinseth, https://github.com/albanD
2024-02-20 18:23:06 +00:00
Nikita Shulga
eb9a3383c2 [MPS] Add naive std_mean implementation (#119777)
By just calling `std_mps` and `mean` in sequence

Move the `var_mean` decomp to `ReduceOps.mm`, as it should be faster to skip dispatching to Python, which one can validate by running the following script:
```python
from timeit import default_timer

import torch
from torch.utils.benchmark import Measurement, Timer

def bench_var_mean(
    m, n, k,
    dtype = torch.float32,
    device:str = "cpu",
) -> Measurement:
    setup = f"""
     x = torch.rand({m}, {n}, {k}, dtype={dtype}, device="{device}")
    """

    t = Timer(
        stmt="torch.var_mean(x, dim=1)", setup=setup, language="python", timer=default_timer
    )
    return t.blocked_autorange()

for x in [100, 1000]:
    rc = bench_var_mean(1000, x, 100, device="mps")
    print(f"{x:5} : {rc.mean*1e6:.2f} usec")
```
which before the change reports 681 and 1268 usec, and after the change 668 and 684 usec (which probably means that the GPU is not saturated, but the overhead from switching between the native and interpreted runtimes is shorter).

Fixes https://github.com/pytorch/pytorch/issues/119663

TODOs:
 - Refactor the codebase and implement proper composite function (that must be faster)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119777
Approved by: https://github.com/albanD
2024-02-13 21:51:29 +00:00
Nikita Shulga
15ef52a015 [MPS] Enable conj and conj_physical (#119669)
The former is only available on MacOS 14+, but at least on older MacOS versions it will raise an exception rather than returning a non-conjugated tensor.

Preliminary step for enabling FFT ops (without it `ifft` would never work)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119669
Approved by: https://github.com/albanD
ghstack dependencies: #119681
2024-02-13 02:27:51 +00:00
Nikita Shulga
8d8fb9783c [MPS][EZ] Fix cfloat->chalf conversion on MacOS13 (#119681)
By using `view_as_real` when type casting between two complex types
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119681
Approved by: https://github.com/Skylion007, https://github.com/albanD
2024-02-12 19:09:10 +00:00
Pearu Peterson
2c91e13afc Add lowerings to special functions (#119187)
As in the title.

In addition, the PR introduces infrastructure for lowerings of pointwise functions that have both cpp and triton implementations available.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119187
Approved by: https://github.com/peterbell10
2024-02-11 16:35:40 +00:00
Nikita Shulga
4ee8aac432 [MPS] Enable bfloat16 support on MacOS 14 (#119641)
Per [MPSDataType](https://developer.apple.com/documentation/metalperformanceshaders/mpsdatatype/mpsdatatypebfloat16?changes=_11&language=objc) documentation bfloat16 are supported in MacOS Sonoma or later

Added missing `MPSDataTypeBFloat16` and `MTLLanguageVersion3_1` enums to `MPSGraphSonomaOps.h`

TODO: Enable more testing

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119641
Approved by: https://github.com/Skylion007
2024-02-11 16:25:29 +00:00
Nikita Shulga
1d61011c11 [MPS] Add support for complex scalars (#119318)
- Switch to native complex support if running on MacOS Monterey or newer for binary ops.
- Python complex scalars are always represented in PyTorch as ComplexDouble, but MPS does not yet support double-precision types, so downcast them to floats
- Also add `cf`(for complex float)  and `ch`(for complex half) to MPSScalar value union
- Fix complex scalar to view promotion by introducing a `legacy_complex_as_view` helper function that converts non-float types to complex and promotes CPU complex scalars to MPS before turning them into a view.
- Add `test_tensor_scalar_binops`

Fixes https://github.com/pytorch/pytorch/issues/119088

Test plan: CI (have quite a lot of tests, see new unexpected successes) +  `python -c "import torch;x,y=torch.rand(2, 2, dtype=torch.cfloat, device='mps'),torch.tensor(2+3j,dtype=torch.chalf);print(y+x)"`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119318
Approved by: https://github.com/albanD
2024-02-08 18:10:59 +00:00
watarungurunnn
d444a3b443 [MPS] fix float32 error on mps, in linalg.matrix_rank and linalg.pinv (#114771)
Fixes #114285

However, this still raises a NotImplementedError:
```
NotImplementedError: The operator 'aten::_linalg_svd.U' is not currently implemented for the MPS device. If you want this op to be added in priority during the prototype phase of this feature, please comment on https://github.com/pytorch/pytorch/issues/77764. As a temporary fix, you can set the environment variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a fallback for this op. WARNING: this will be slower than running natively on MPS.
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114771
Approved by: https://github.com/lezcano
2024-02-05 15:36:55 +00:00
lancerts
26a2743162 Fix placeholder tensor is empty for relu in mps (#118965)
Fixes #118845
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118965
Approved by: https://github.com/malfet
2024-02-03 23:50:35 +00:00
Nikita Shulga
24dd9f42ce [MPS] Fix use_metal_mm condition (#118830)
One should not only look at the stride size but at the dimensions as well, as the strides of `torch.rand(65536, 1)` are `(1, 1)`

Extend the test to account for this situation.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118830
Approved by: https://github.com/huydhn
2024-02-01 17:53:42 +00:00
Yifu Wang
a1280f0cc6 Add an OpInfo test for split_with_sizes_copy (#118512)
Adding an `OpInfo` test for `split_with_sizes_copy` so we can use it to test [CUDA fast path for split_with_sizes_copy.out](https://github.com/pytorch/pytorch/pull/117203). Since the `OpInfo` test doesn't exist yet and introducing it requires modifications to the `CompositeExplicitAutograd` impl, adding the `OpInfo` test in a separate PR to establish a healthy baseline.

Changes made:
- Registered a batching rule for `split_with_sizes_copy`.
- Registered a decomposition for `split_with_sizes_copy`.
- Registered a DTensor prop rule for `split_with_sizes_copy`.
- Added required dtype and device checks to the composite impl.
- Added output resize to the composite impl.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118512
Approved by: https://github.com/albanD
2024-02-01 07:09:27 +00:00
Sun, Jiayi
2dd4a254a0 add Half support for interpolate operators on CPU (#105648)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105648
Approved by: https://github.com/jgong5, https://github.com/cpuhrsch
2024-01-18 09:07:16 +00:00
Nikita Shulga
1872834247 [MPS] Fix torch.mm correctness for large matrices (#117549)
Currently `matrixMultiplicationWithPrimaryTensor:secondaryTensor:` returns incorrect results if one of the matrix dimensions is greater than 32K.
Solve it by providing a very naive matrix multiplication Metal shader and calling it if the stride size is greater than 32768 elements. Slicing inside the MPSGraph doesn't work either, since `-sliceTensor:starts:ends:strides:` somehow affects matmul as well, if tiling is done as follows:
```objc
  NSMutableArray<MPSGraphTensor*>* rows = [NSMutableArray new];
  for (int64_t i = 0; i < M; i += tile_size) {
    const auto i_end = std::min(i + tile_size, M);
    NSMutableArray<MPSGraphTensor*>* row_chunks = [NSMutableArray new];
    for (int64_t j = 0; j < K; j += tile_size) {
      const auto j_end = std::min(j + tile_size, K);
      MPSGraphTensor* tile = nil;
      for (int64_t k = 0; k < N; k += tile_size) {
        const auto k_end = std::min(k + tile_size, N);
        auto selfChunk = [graph sliceTensor:selfTensor
                                     starts:@[ @(i), @(k) ]
                                       ends:@[ @(i_end), @(k_end) ]
                                    strides:@[ @(1), @(1) ]
                                       name:nil];
        auto otherChunk = [graph sliceTensor:otherTensor
                                      starts:@[ @(k), @(j) ]
                                        ends:@[ @(k_end), @(j_end) ]
                                     strides:@[ @(1), @(1) ]
                                        name:nil];
        auto chunkMM = [graph matrixMultiplicationWithPrimaryTensor:selfChunk secondaryTensor:otherChunk name:nil];

        tile = tile ? [graph additionWithPrimaryTensor:tile secondaryTensor:chunkMM name:nil] : chunkMM;
      }
      [row_chunks addObject:tile];
    }
    auto row = row_chunks.count > 1 ? [graph concatTensors:row_chunks dimension:1 name:nil] : row_chunks.firstObject;
    [rows addObject:row];
  }
  return rows.count > 1 ? [graph concatTensors:rows dimension:0 name:nil] : rows.firstObject;
```

One can always use the Metal matmul by defining the `PYTORCH_MPS_PREFER_METAL` environment variable
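A hedged sketch of opting into that path (shapes are hypothetical; the variable is assumed to be read by the MPS backend at runtime):
```python
import os
os.environ["PYTORCH_MPS_PREFER_METAL"] = "1"  # prefer the naive Metal matmul shader

import torch

if torch.backends.mps.is_available():
    a = torch.randn(40000, 64, device="mps")
    b = torch.randn(64, 64, device="mps")
    print((a @ b).shape)  # torch.Size([40000, 64])
```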
Fixes https://github.com/pytorch/pytorch/issues/116769
Pull Request resolved: https://github.com/pytorch/pytorch/pull/117549
Approved by: https://github.com/kulinseth
2024-01-17 01:33:08 +00:00
Nikita Shulga
6784030df4 [MPS] Add support for 64-bit index operations (#116942)
But enable it only if `iter.can_use_32bit_indexing()` is False. Add a test for index_select, but enable it only on Sonoma, as all attempts to create a 4GB+ tensor on Ventura and older fail
Pull Request resolved: https://github.com/pytorch/pytorch/pull/116942
Approved by: https://github.com/Skylion007, https://github.com/kulinseth
ghstack dependencies: #116903, #116904, #116915, #116940
2024-01-09 16:56:49 +00:00
Nikita Shulga
ff0f79d3c7 [MPS] Mark torch.[all|any] as working with complex on MacOS14 (#116907)
It was enabled by https://github.com/pytorch/pytorch/pulls/116457, but at the time that PR landed, Sonoma testing was not yet enabled

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116907
Approved by: https://github.com/osalpekar, https://github.com/kit1980
2024-01-06 01:10:11 +00:00
Nikita Shulga
b0393ebe9b [MPS] Make test_mps.py passable on Sonoma (#116764)
- Enable Sonoma testing on M2 machines
- Add 70+ ops to the list of supported ones on MacOS Sonoma
- Enable nn.functional.
- Add explicit `TORCH_CHECK` to mark scatter/gather, index_select and linalg ops as not yet supporting Complex, as attempts to call those will crash with various MPS asserts such as:
```
(mpsFileLoc): /AppleInternal/Library/BuildRoots/0032d1ee-80fd-11ee-8227-6aecfccc70fe/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphUtilities.mm:96:0: error: 'mps.reduction_min' op operand #0 must be tensor of MPS type values or memref of MPS type values, but got 'tensor<5x5xcomplex<f32>>'
(mpsFileLoc): /AppleInternal/Library/BuildRoots/0032d1ee-80fd-11ee-8227-6aecfccc70fe/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphUtilities.mm:96:0: note: see current operation: %3 = "mps.reduction_min"(%1, %2) <{keep_dims}> : (tensor<5x5xcomplex<f32>>, tensor<2xsi32>) -> tensor<1x1xcomplex<f32>>
```
- Treat bools as int8 to fix a regression that re-surfaced in `index_fill` (used to be broken in Monterey, then fixed in Ventura and broken again in Sonoma)
- `nn.functional.max_pool2d` results now match CPU output for uint8 dtype in Sonoma

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116764
Approved by: https://github.com/kulinseth, https://github.com/seemethere
2024-01-05 00:25:47 +00:00
Gao Tianlin
6793b99107 [BugFix] Fix SegFault when torch.all/any dispatched to mps or other backends (#116457)
The old implementation resulted in an infinite recursive loop, leading to a stack overflow and a segfault.

With TORCH_SHOW_DISPATCH_TRACE on and a debug build of PyTorch, we can see the following endless output in the terminal:
```
[call] op=[aten::quantize_per_tensor], key=[AutogradCPU]
  [redispatch] op=[aten::quantize_per_tensor], key=[CPU]
 [call] op=[aten::any.dims], key=[AutogradCPU]
  [redispatch] op=[aten::any.dims], key=[QuantizedCPU]
   [call] op=[aten::empty.memory_format], key=[BackendSelect]
    [redispatch] op=[aten::empty.memory_format], key=[CPU]
   [call] op=[aten::any.dims_out], key=[QuantizedCPU]
    [call] op=[aten::any.dims], key=[QuantizedCPU]
     [call] op=[aten::empty.memory_format], key=[BackendSelect]
      [redispatch] op=[aten::empty.memory_format], key=[CPU]
     [call] op=[aten::any.dims_out], key=[QuantizedCPU]
      [call] op=[aten::any.dims], key=[QuantizedCPU]
       [call] op=[aten::empty.memory_format], key=[BackendSelect]
        [redispatch] op=[aten::empty.memory_format], key=[CPU]
       [call] op=[aten::any.dims_out], key=[QuantizedCPU]
        [call] op=[aten::any.dims], key=[QuantizedCPU]
         [call] op=[aten::empty.memory_format], key=[BackendSelect]
          [redispatch] op=[aten::empty.memory_format], key=[CPU]
         [call] op=[aten::any.dims_out], key=[QuantizedCPU]
          [call] op=[aten::any.dims], key=[QuantizedCPU]
           [call] op=[aten::empty.memory_format], key=[BackendSelect]
            [redispatch] op=[aten::empty.memory_format], key=[CPU]
           [call] op=[aten::any.dims_out], key=[QuantizedCPU]
            [call] op=[aten::any.dims], key=[QuantizedCPU]
             [call] op=[aten::empty.memory_format], key=[BackendSelect]
              [redispatch] op=[aten::empty.memory_format], key=[CPU]
             [call] op=[aten::any.dims_out], key=[QuantizedCPU]
              [call] op=[aten::any.dims], key=[QuantizedCPU]
               [call] op=[aten::empty.memory_format], key=[BackendSelect]
                [redispatch] op=[aten::empty.memory_format], key=[CPU]
               [call] op=[aten::any.dims_out], key=[QuantizedCPU]
                [call] op=[aten::any.dims], key=[QuantizedCPU]
                 [call] op=[aten::empty.memory_format], key=[BackendSelect]
                  [redispatch] op=[aten::empty.memory_format], key=[CPU]
                 [call] op=[aten::any.dims_out], key=[QuantizedCPU]
                  [call] op=[aten::any.dims], key=[QuantizedCPU]
.....
.....
.....
```
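A heavily hedged, hypothetical repro matching the trace above (a reduction dispatched to QuantizedCPU); after the fix the call should either succeed or fail with a clean error instead of crashing:
```python
import torch

q = torch.quantize_per_tensor(torch.randn(4), scale=0.1, zero_point=0, dtype=torch.qint8)
try:
    print(torch.any(q))  # previously recursed between any.dims and any.dims_out
except (RuntimeError, NotImplementedError) as e:
    print("reported cleanly instead of crashing:", e)
```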

Fixes #116452
Fixes #116451

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116457
Approved by: https://github.com/malfet
2024-01-04 17:37:17 +00:00
Aaron Gokaslan
3fe437b24b [BE]: Update flake8 to v6.1.0 and fix lints (#116591)
Updates flake8 to v6.1.0 and fixes a few lints using sed and some ruff tooling.
- Replace `assert(0)` with `raise AssertionError()`
- Remove extraneous parenthesis i.e.
  - `assert(a == b)` -> `assert a == b`
  - `if(x > y or y < z):`->`if x > y or y < z:`
  - And `return('...')` -> `return '...'`

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116591
Approved by: https://github.com/albanD, https://github.com/malfet
2024-01-03 06:04:44 +00:00
Nikita Shulga
09ee96b69d [MPS] Fix CrossEntropyLoss for float16 (#116597)
Looks like neither [`divisionNoNaNWithPrimaryTensor:`](https://developer.apple.com/documentation/metalperformanceshadersgraph/mpsgraph/3675593-divisionnonanwithprimarytensor) nor `oneHotWithIndicesTensor:` works for `MPSDataTypeFloat16`, so provide an explicit cast for the one-hot tensor and an alternative implementation using the formula from the official doc, i.e.
> `resultTensor = select(secondaryTensor, primaryTensor / secondaryTensor, 0)`
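A hedged PyTorch-level analogue of that formula (illustrative values; not the Metal implementation itself):
```python
# Divide where the denominator is non-zero, otherwise produce 0 instead of Inf/NaN.
import torch

x = torch.tensor([1.0, 2.0, 3.0], dtype=torch.half)
y = torch.tensor([2.0, 0.0, 4.0], dtype=torch.half)
safe = torch.where(y != 0, x / y, torch.zeros_like(x))
print(safe)  # tensor([0.5000, 0.0000, 0.7500], dtype=torch.float16)
```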

Alas, at the moment it cannot be tested via `test_modules.py`, as that only runs `torch.float32` and `torch.float64` tests (and the `torch.half` implementation is not available for CPU)

Fixes https://github.com/pytorch/pytorch/issues/116095

TODO: Enable testing via TestModules, but will do in separate PR
Pull Request resolved: https://github.com/pytorch/pytorch/pull/116597
Approved by: https://github.com/kulinseth
2024-01-03 05:58:26 +00:00
Aaron Gokaslan
bd10fea79a [BE]: Enable F821 and fix bugs (#116579)
Fixes #112371

I tried to fix as many of the bugs as I could; for a few I could not figure out the proper fix, so I left them with noqas.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116579
Approved by: https://github.com/ezyang
2024-01-01 08:40:46 +00:00
PyTorch MergeBot
0978482afa Revert "Implement aten::upsample_linear1d on mps (#115031)"
This reverts commit c6969cb8a9.

Reverted https://github.com/pytorch/pytorch/pull/115031 on behalf of https://github.com/malfet due to Broke lint, will fwd fix and re-land ([comment](https://github.com/pytorch/pytorch/pull/115031#issuecomment-1869693081))
2023-12-26 18:01:49 +00:00
Kai
c6969cb8a9 Implement aten::upsample_linear1d on mps (#115031)
Related to #77764

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115031
Approved by: https://github.com/malfet
2023-12-26 15:44:21 +00:00
Aaron Gokaslan
6de28e92d2 [BE]: Apply FURB118 (prev): replaces unnecessary lambdas with operator. (#116027)
This replaces a bunch of unnecessary lambdas with the operator package. This is semantically equivalent, but the operator package is faster, and arguably more readable. When the FURB rules are taken out of preview, I will enable it as a ruff check.
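An illustrative before/after of the kind of rewrite applied here (example data is made up, not taken from the PR):
```python
import operator

pairs = [(2, "b"), (1, "a"), (3, "c")]
# before: sorted(pairs, key=lambda p: p[0])
print(sorted(pairs, key=operator.itemgetter(0)))  # [(1, 'a'), (2, 'b'), (3, 'c')]
```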

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116027
Approved by: https://github.com/malfet
2023-12-20 19:35:08 +00:00
Sun, Jiayi
c173a9d9b3 add Half support for layer_norm on CPU (#99590)
### Testing
Single socket (icx, 32cores):
| shape | fp32 forward (ms) | fp16 forward (ms) | mixed fp32 fp16 forward (ms) | fp32 backward (ms) | fp16 backward (ms) | mixed fp32 fp16 backward (ms) |
| -- | -- | -- | -- | -- | -- | -- |
| (1, 8, 16) | 0.012 | 0.011 | 0.011 | 0.051 | 0.051 | 0.050 |
| (8 ,8, 16) | 0.013 | 0.013 | 0.013 | 0.054 | 0.053 | 0.051 |
| (32, 8, 16) | 0.015 | 0.014 | 0.014 | 0.059 | 0.054 | 0.052 |
| (64, 128, 56, 56) | 1.875 | 0.790 | 1.016 | 12.845 | 7.151 | 6.985 |
| (64, 128, 256, 256) | 50.226 | 25.462 | 35.736 | 328.957 | 179.615 | 175.618 |

Single core (icx):

| shape | fp32 forward (ms) | fp16 forward (ms) | mixed fp32 fp16 forward (ms) | fp32 backward (ms) | fp16 backward (ms) | mixed fp32 fp16 backward (ms) |
| -- | -- | -- | -- | -- | -- | -- |
| (1, 8, 16) | 0.012 | 0.011 | 0.011 | 0.040 | 0.041 | 0.041 |
| (8 ,8, 16) | 0.012 | 0.012 | 0.012 | 0.042 | 0.042 | 0.042 |
| (32, 8, 16) | 0.027 | 0.014 | 0.014 | 0.048 | 0.048 | 0.046 |
| (64, 128, 56, 56) | 58.054 | 11.034 | 17.928 | 108.603 | 48.816 | 50.244 |
| (64, 128, 256, 256) | 1327.758 | 352.394 | 496.994 | 2846.182 | 1224.247 | 1218.422 |
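For reference, a minimal half-precision usage sketch of the CPU path being enabled (hypothetical shapes):
```python
import torch

x = torch.randn(8, 16, dtype=torch.half)
ln = torch.nn.LayerNorm(16).to(torch.half)
out = ln(x)
print(out.dtype, out.shape)  # torch.float16 torch.Size([8, 16])
```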

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99590
Approved by: https://github.com/mingfeima, https://github.com/jgong5, https://github.com/cpuhrsch
2023-12-20 01:11:15 +00:00
Nikita Shulga
9dda4b20a0 [MPS] Enable select/[broad]cast ops for complex dtypes (#115727)
By representing `torch.cfloat`/`torch.chalf` as `float2`/`half2` Metal types and modifying `SCATTER_OPS_TEMPLATE`/`GATHER_OPS_TEMPLATE` to accept a third argument: a fully specialized `cast` function, which is a no-op for regular types but special-cased for float->complex and complex->float
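A hedged usage sketch of the kind of complex-dtype select this enables (values are made up; requires an MPS-capable machine):
```python
import torch

if torch.backends.mps.is_available():
    z = torch.tensor([1 + 2j, 3 + 4j], dtype=torch.cfloat, device="mps")
    print(z[0])                   # select on a complex MPS tensor
    print(torch.view_as_real(z))  # the underlying float2-style layout
```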

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115727
Approved by: https://github.com/kulinseth
2023-12-19 02:25:28 +00:00
Peter Pham
74dfdc567b [MPS] aten::erfinv bug fix: add storage offset buffers to handle slicing (#105801)
A bug fix of a recently merged PR per comment: https://github.com/pytorch/pytorch/pull/101507#discussion_r1271393706

The following test would fail without this bug fix:

```
import torch
def test_erfinv():
    for device in ['cpu', 'mps']:
        x = torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5], device=device)
        y = x[2:].erfinv()

        x2 = torch.tensor([0.3, 0.4, 0.5], device=device)
        y2 = x2.erfinv()

        print(y)
        print(y2)

        torch.testing.assert_close(y, y2)
        print(f"{device} passes.")

test_erfinv()
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105801
Approved by: https://github.com/malfet
2023-12-15 23:14:03 +00:00
Lucas Steuernagel
2e517b20d9 [MPS] Add Conv3D support for MPS (#114183)
Fixes #77818

I saw that PR #99246 was approved, but no one fixed the rebase conflicts, so I am bringing this up again to be merged.
I am leveraging @mattiaspaul work. Quoting the description here:

> * this pull request enables 3D convolutions (forward/backward) for MPS (Apple Silicon) within the same Convolution.mm file as conv2d.
> * does not support channel_last (since pytorch doesn't implement channel_last for 3D tensors)
> * does not support conv3d_transpose and treats depth-separable convolutions not as normal case (there are no MPS kernels available for either of those so far)
> * requires MacOS >=13.2 (Ventura)

Please, let me know if there are any other changes needed and I'll be happy to implement them.
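A minimal usage sketch of the newly enabled op (hypothetical shapes; macOS >= 13.2 and an MPS device assumed per the notes above):
```python
import torch

if torch.backends.mps.is_available():
    conv = torch.nn.Conv3d(2, 4, kernel_size=3).to("mps")
    x = torch.randn(1, 2, 8, 8, 8, device="mps")
    print(conv(x).shape)  # torch.Size([1, 4, 6, 6, 6])
```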

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114183
Approved by: https://github.com/malfet
2023-12-15 23:05:01 +00:00
mingfeima
a8acd6c410 Add Half support for AvgPool2d on CPU (#109578)
Add Half support for AvgPool2d (both channels last and channels first) on CPU

Pull Request resolved: https://github.com/pytorch/pytorch/pull/109578
Approved by: https://github.com/mingfeima, https://github.com/albanD
2023-12-12 12:59:47 +00:00
igm503
f017a1af3f [MPS] add complex_out to MPS backend (#110851)
Adds support for at::complex_out to the MPS backend

Implemented in a binary kernel using the view_as_real pattern for handling complex dtypes in the mps backend.
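A short usage sketch of the op being added (values are made up; MPS device assumed):
```python
import torch

if torch.backends.mps.is_available():
    re = torch.tensor([1.0, 2.0], device="mps")
    im = torch.tensor([3.0, 4.0], device="mps")
    print(torch.complex(re, im))  # tensor([1.+3.j, 2.+4.j], device='mps:0')
```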
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110851
Approved by: https://github.com/kulinseth
2023-12-11 13:37:55 +00:00
Li-Huai (Allan) Lin
38e1440bae [MPS] Remove redundant topk test and move all pad tests inside a class (#113313)
Summary:
1. The removed `topk` test is essentially a duplicate of the following test, so I removed it:
```python
def test_topk(self):
    def helper(shape):
        cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
        x = cpu_x.detach().clone().to('mps')
        for largest_val in [True, False]:
            if (type(shape) == tuple):
                for curr_dim in range(0, len(shape)):
                    dim_size = shape[curr_dim]
                    for k in range(1, dim_size + 1):
                        topk_values, topk_indices = torch.topk(x, k, dim=curr_dim, largest=largest_val)
                        topk_values_cpu, topk_indices_cpu = torch.topk(cpu_x, k, dim=curr_dim, largest=largest_val)
                        self.assertEqual(topk_values, topk_values_cpu)
                        self.assertEqual(topk_indices, topk_indices_cpu)
            else:
                for k in range(1, shape):
                    topk_values, topk_indices = torch.topk(x, k, dim=0, largest=largest_val)
                    topk_values_cpu, topk_indices_cpu = torch.topk(cpu_x, k, dim=0, largest=largest_val)
                    self.assertEqual(topk_values, topk_values_cpu)
                    self.assertEqual(topk_indices, topk_indices_cpu)

    helper(2)
    helper((5, 1))
    helper((1, 5))
    helper((5, 9, 7, 4))
    helper((50, 20, 7, 4))
297c26bb8e/test/test_mps.py (L8054-L8091)

2. Move all pad tests to one standalone class.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113313
Approved by: https://github.com/kulinseth
ghstack dependencies: #113312
2023-12-01 06:52:07 +00:00
Li-Huai (Allan) Lin
88a659e752 [MPS] Move non-nll loss tests outside TestNLLLoss (#113312)
The diff looks messy, but this PR essentially does one thing: move the non-NLL loss tests from the `TestNLLLoss` class to the `TestMPS` class. After doing so, there end up being two stack tests with the same name `test_stack`; therefore, I renamed one of them to `test_stack_storage_offset`, which is what the test actually does.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113312
Approved by: https://github.com/kulinseth
2023-12-01 06:52:07 +00:00
Nikita Shulga
1b27eae65e [MPS] Fix out-of-bounds fill to sliced tensor (#114838)
This fixes a regression introduced by https://github.com/pytorch/pytorch/pull/81951 that caused an out-of-bounds access when a sliced tensor is filled with zeros

Remove bogus `TORCH_INTERNAL_ASSERT(length >= offset)` as [NSMakeRange](https://developer.apple.com/documentation/foundation/1417188-nsmakerange?language=objc) arguments are location and length rather than start and end offset.

In `fill_mps_tensor_`:
- Pass `value` argument to `MPSStream::fill`
- Pass `self.nbytes()` rather than `self.storage().nbytes()` as the length of the buffer to fill, as the latter will always result in an out-of-bounds write if the offset within the storage is non-zero

Add regression test
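A hedged sketch of the kind of sliced-tensor fill that used to write out of bounds (sizes are hypothetical; MPS device assumed):
```python
import torch

if torch.backends.mps.is_available():
    t = torch.ones(8, device="mps")
    t[4:].fill_(0)  # fill through a view with a non-zero storage offset
    print(t)        # expected: tensor([1., 1., 1., 1., 0., 0., 0., 0.], device='mps:0')
```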

Fixes https://github.com/pytorch/pytorch/issues/114692

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114838
Approved by: https://github.com/atalman, https://github.com/kulinseth
2023-12-01 06:24:42 +00:00
Khushi Agrawal
cff84871ce [reland][opinfo][fix] conv3d & fix conv{1, 2}d for neg dilation|groups & add ErrorInputs for conv ops (#114589)
Previous PR: #113885

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114589
Approved by: https://github.com/lezcano
2023-11-27 14:45:44 +00:00
PyTorch MergeBot
150aaf46ca Revert "[opinfo][fix] conv3d & fix conv{1, 2}d for neg dilation|groups & add ErrorInputs for conv ops (#113885)"
This reverts commit 4fa1ff8404.

Reverted https://github.com/pytorch/pytorch/pull/113885 on behalf of https://github.com/huydhn due to Sorry for reverting your change, but its TestCommonCUDA::test_compare_cpu_nn_functional_conv3d test is failing in trunk 4fa1ff8404 ([comment](https://github.com/pytorch/pytorch/pull/113885#issuecomment-1827268473))
2023-11-27 07:33:00 +00:00
Khushi Agrawal
4fa1ff8404 [opinfo][fix] conv3d & fix conv{1, 2}d for neg dilation|groups & add ErrorInputs for conv ops (#113885)
Previous PR: https://github.com/pytorch/pytorch/pull/85202

Also, cc'ing @lezcano @kshitij12345 @zou3519, who reviewed my previous PR. Thanks!
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113885
Approved by: https://github.com/lezcano
2023-11-26 13:44:30 +00:00
Nikita Shulga
324cde59b2 [MPS] Fix test_copy_cast_no_leak (#114313)
When running on MacOS-13.2, the test always fails on the first run but succeeds on the second, as presumably some memory is reserved to cache the f32->f16 graph. Make it resilient against such failures by adding a warmup step in which one conversion is performed before recording driver memory utilization.

Fixes https://github.com/pytorch/pytorch/issues/114305
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114313
Approved by: https://github.com/huydhn
2023-11-22 14:48:24 +00:00
Nikita Shulga
b5dd37f23e [MPS] Fix memory leak in copy_from_mps_ (#114197)
By always calling `[destBuffer release]` before leaving the scope in which it was allocated.
Leak was introduced by https://github.com/pytorch/pytorch/pull/84928
Add regression test.
Before the change:
```
% python ../test/test_mps.py -v -k test_copy_cast_no_leak --repeat 10
test_copy_cast_no_leak (__main__.TestMemoryLeak) ... FAIL

======================================================================
FAIL: test_copy_cast_no_leak (__main__.TestMemoryLeak)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/Users/nshulga/git/pytorch/pytorch/torch/testing/_internal/common_utils.py", line 2554, in wrapper
    method(*args, **kwargs)
  File "/Users/nshulga/git/pytorch/pytorch/build/../test/test_mps.py", line 1064, in test_copy_cast_no_leak
    self.assertTrue(driver_before == driver_after, f"Detected {driver_after-driver_before} bytes leak of GPU memory")
AssertionError: False is not true : Detected 65536 bytes leak of GPU memory

To execute this test, run the following from the base repo dir:
     python test/test_mps.py -k test_copy_cast_no_leak

This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0

----------------------------------------------------------------------
Ran 1 test in 1.102s

FAILED (failures=1)
```
After:
```
% python ../test/test_mps.py -k test_copy_cast_no_leak --repeat 10
.
----------------------------------------------------------------------
Ran 1 test in 0.819s

OK
.
----------------------------------------------------------------------
Ran 1 test in 0.001s

OK
.
----------------------------------------------------------------------
Ran 1 test in 0.002s

OK
...
```

Fixes https://github.com/pytorch/pytorch/issues/114096

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114197
Approved by: https://github.com/kit1980
2023-11-21 14:52:55 +00:00
Li-Huai (Allan) Lin
538114db65 [MPS] Fix and refactor unary/binary ops with non-zero offset or non-contiguous output (#97085)
Fixes #100764

This PR fixes the unary ops implementation and refactors the binary ops implementation a bit.

For unary ops:
Previously we didn't take into account unary ops that have a non-contiguous/storage-offset output, causing an incorrect result (because the MPS graph kernel always writes the buffer contiguously). Therefore, this PR creates a temporary output tensor for the graph first and then copies the result back to the original output tensor. I don't think we currently have a better fix than this.
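A hedged repro sketch of the non-contiguous-output case described above (shapes are hypothetical; MPS device assumed):
```python
import torch

if torch.backends.mps.is_available():
    src = torch.ones(4, 2, device="mps")
    out = torch.zeros(2, 4, device="mps").t()  # non-contiguous output with matching shape
    torch.exp(src, out=out)
    print(out)  # expected to match torch.exp on CPU for the same input
```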

For binary ops, see https://github.com/pytorch/pytorch/pull/97085#discussion_r1140999125

See the added test for repro.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/97085
Approved by: https://github.com/malfet
2023-11-14 22:03:21 +00:00
Nikita Shulga
265d6aac0b [MPS] Fix crashes during Conv backward pass (#113398)
By adding weights tensor to the MPSGraph cache key.
Add regression test to validate that collision no longer happens

Fixes https://github.com/pytorch/pytorch/issues/112998

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113398
Approved by: https://github.com/kulinseth
2023-11-10 04:29:33 +00:00