Commit Graph

680 Commits

Author SHA1 Message Date
Kulin Seth
4615f6aa97 [MPS]: Add fix for squeezed input axes handling in BCE loss (#79676)
Fixes #79527

Pull Request resolved: https://github.com/pytorch/pytorch/pull/79676
Approved by: https://github.com/razarmehr, https://github.com/albanD
2022-06-16 20:21:31 +00:00
Kulin Seth
355a1c8c3f MPS: TopK raise an error if K>16 (#79677)
* Error out in TopK when k>16.
* Add a test case too.

Fixes #78915

Pull Request resolved: https://github.com/pytorch/pytorch/pull/79677
Approved by: https://github.com/albanD
2022-06-16 16:06:45 +00:00
Nikita Shulga
81cd276d61 [MPS] Support stride of stride
Fixes https://github.com/pytorch/pytorch/issues/79181

Pull Request resolved: https://github.com/pytorch/pytorch/pull/79521

Approved by: https://github.com/kulinseth
2022-06-14 18:49:44 +00:00
Alban Desmaison
0a651a231d Add full support for serialization of MPS Tensors (#79465)
Fix https://github.com/pytorch/pytorch/issues/79384
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79465
Approved by: https://github.com/kulinseth, https://github.com/malfet
2022-06-14 17:54:30 +00:00
PyTorch MergeBot
ce6ce74703 Revert "Add full support for serialization of MPS Tensors (#79465)"
This reverts commit 64c2a275c4.

Reverted https://github.com/pytorch/pytorch/pull/79465 on behalf of https://github.com/zengk95 due to this broke X linux-xenial-py3.7-clang7-onnx / test (default, 1, 2, linux.2xlarge). Not sure why since it passed on pull.
2022-06-14 16:42:36 +00:00
Alban Desmaison
64c2a275c4 Add full support for serialization of MPS Tensors (#79465)
Fix https://github.com/pytorch/pytorch/issues/79384
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79465
Approved by: https://github.com/kulinseth, https://github.com/malfet
2022-06-14 14:20:09 +00:00
Kulin Seth
77b6885a22 MPS: add layer_norm_backward (#79189)
Layernorm backward

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/79189
Approved by: https://github.com/razarmehr, https://github.com/albanD
2022-06-10 13:25:41 +00:00
Kulin Seth
83239351c5 MPS: add exponential op (#79188)
Add exponential distribution

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/79188
Approved by: https://github.com/razarmehr, https://github.com/albanD
2022-06-10 13:16:21 +00:00
Kulin Seth
50f7b40ad9 MPS: Binary cast fix by proper type promotion and remove spurious copy warning (#79185)
Fixes #78019, #78020
Fixes https://github.com/pytorch/pytorch/pull/79185
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79185
Approved by: https://github.com/albanD, https://github.com/razarmehr
2022-06-09 17:33:06 +00:00
Nikita Shulga
97594a24b4 Print output during MPS test import tests (#79163)
Simplify `test_no_warnings_on_input` to simply capture any output.
Copy its implementation to `test_testing.py` as this is not specific to MPS
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79163
Approved by: https://github.com/janeyx99, https://github.com/kulinseth
2022-06-09 13:07:05 +00:00
Philip Meier
32593ef2dd move MPS compat into common comparison machinery (#77836)
Addresses https://github.com/pytorch/pytorch/issues/77144#issuecomment-1128168082.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/77836
Approved by: https://github.com/albanD
2022-06-08 08:09:18 +00:00
Kulin Seth
a6347f5467 MPS: Fixes (#78930)
Cast integer to float in UnaryOps
Add tensor dtype in key generation
Enable FP16 scalars and use placeholder for alpha tensor in add/sum ops

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/78930
Approved by: https://github.com/albanD
2022-06-07 18:22:10 +00:00
Nikita Shulga
55cac22cdf [MPS] Add arange_mps_out implementation (#78789)
Mostly by factoring out shader logic from `linspace_out_mps` implementation
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78789
Approved by: https://github.com/albanD, https://github.com/kulinseth
2022-06-03 21:54:41 +00:00
Kulin Seth
4858c56334 MPS: Fix issues with view tensors and linspace. (#78690)
Fixes: #https://github.com/pytorch/pytorch/issues/78642, https://github.com/pytorch/pytorch/issues/78511
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78690
Approved by: https://github.com/razarmehr, https://github.com/DenisVieriu97
2022-06-02 06:17:19 +00:00
Kulin Seth
a3bdafece3 MPS: add linespace op (#78570)
Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/78570
Approved by: https://github.com/malfet
2022-06-01 13:47:14 +00:00
Ramin Azarmehr
aa62b3e003 Add test case for issue: https://github.com/pytorch/pytorch/issues/77851 (#78547)
The test works fine now.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/78547
Approved by: https://github.com/kulinseth
2022-05-31 19:15:45 +00:00
Rohan Mitchell
f42b42d3eb MPS: Implement aten::count_nonzero.dim_IntList (#78169)
- See: #77764

Implements the `aten::count_nonzero.dim_IntList` operator (as used by [torch.count_nonzero](https://pytorch.org/docs/stable/generated/torch.count_nonzero.html)) for [MPS](https://pytorch.org/blog/introducing-accelerated-pytorch-training-on-mac/).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/78169
Approved by: https://github.com/malfet, https://github.com/kulinseth, https://github.com/albanD
2022-05-31 18:23:25 +00:00
Kulin Seth
017b0ae943 MPS: Fix crashes in view tensors due to buffer size mismatch (#78496)
Fixes #78247, #77886

Pull Request resolved: https://github.com/pytorch/pytorch/pull/78496
Approved by: https://github.com/albanD, https://github.com/malfet
2022-05-31 02:09:03 +00:00
Alban Desmaison
bde246fcc6 Speed up test_mps from 9min to 25s
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78488

Approved by: https://github.com/kulinseth
2022-05-30 18:16:53 +00:00
Alban Desmaison
02551a0025 Remove prints and add proper asserts
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78454

Approved by: https://github.com/kulinseth
2022-05-30 18:16:53 +00:00
Kulin Seth
d63db52349 MPS: Fixes the as_strided_mps implementation for contiguous view operations (#78440)
Fixes https://github.com/pytorch/pytorch/issues/78107; https://github.com/pytorch/pytorch/issues/77750

Pull Request resolved: https://github.com/pytorch/pytorch/pull/78440
Approved by: https://github.com/malfet
2022-05-28 14:41:56 +00:00
Nikita Shulga
437ecfc461 [MPS] Fix copy_kernel_mps (#78428)
By passing `storage_offset` of source and destination Tensors
This fixes following simple usecase:
```
python3` -c "import torch;x=torch.zeros(3, 3, device='mps'); x[1, 1]=1;print(x)"
```

Add test to validate it would not regress in the future

Pull Request resolved: https://github.com/pytorch/pytorch/pull/78428
Approved by: https://github.com/kulinseth
2022-05-27 20:46:53 +00:00
Kulin Seth
8552acbd74 MPS: Eye op (#78408)
This can be used as a reference PR was to add Op in MPS backend.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/78408
Approved by: https://github.com/albanD
2022-05-27 17:07:02 +00:00
Kulin Seth
2e32d5fcd8 MPS: Add adaptive max pool2d op (#78410)
Adaptive max pool 2d forward and backward with test

Pull Request resolved: https://github.com/pytorch/pytorch/pull/78410
Approved by: https://github.com/albanD
2022-05-27 11:59:07 +00:00
Nikita Shulga
705082656a Fix typo in testname (#78258)
`test_linear2D_no_bias_backwarwd` -> `test_linear2D_no_bias_backward`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/78258
Approved by: https://github.com/kulinseth, https://github.com/janeyx99
2022-05-25 16:23:10 +00:00
Lukas Hoenig
a52bfe2c5d Convert MPS Tensor data using MPSGraph API (#78092)
Fixes #78091
If you are already working on this, simply disregard this or take what may be helpful. This is my attempt at MPS-native Tensor datatype conversion. It works for everything tested ~~but is currently only implemented for MPS-to-MPS copy, not MPS-to-X or X-to-MPS, but the same approach could easily be used~~.

Before:
```python
In [5]: pt.full((40,), -10.3, device="mps")
Out[5]:
tensor([-10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000,
        -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000,
        -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000,
        -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000,
        -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000,
        -10.3000, -10.3000, -10.3000, -10.3000, -10.3000], device='mps:0')

In [6]: pt.full((40,), -10.3, device="mps").int()
Out[6]:
tensor([-1054552883, -1054552883, -1054552883, -1054552883, -1054552883,
        -1054552883, -1054552883, -1054552883, -1054552883, -1054552883,
        -1054552883, -1054552883, -1054552883, -1054552883, -1054552883,
        -1054552883, -1054552883, -1054552883, -1054552883, -1054552883,
        -1054552883, -1054552883, -1054552883, -1054552883, -1054552883,
        -1054552883, -1054552883, -1054552883, -1054552883, -1054552883,
        -1054552883, -1054552883, -1054552883, -1054552883, -1054552883,
        -1054552883, -1054552883, -1054552883, -1054552883, -1054552883],
       device='mps:0', dtype=torch.int32)

In [7]: pt.full((40,), -10.3, device="mps").int().float()
Out[7]:
tensor([-10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000,
        -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000,
        -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000,
        -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000,
        -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000,
        -10.3000, -10.3000, -10.3000, -10.3000, -10.3000], device='mps:0')

In [8]: pt.full((40,), -10.3, device="mps").int().float().bool()
Out[8]:
tensor([ True, False, False,  True,  True, False, False,  True,  True, False,
        False,  True,  True, False, False,  True,  True, False, False,  True,
         True, False, False,  True,  True, False, False,  True,  True, False,
        False,  True,  True, False, False,  True,  True, False, False,  True],
       device='mps:0')
```

After:
```python
In [3]: pt.full((40,), -10.3, device="mps")
Out[3]:
tensor([-10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000,
        -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000,
        -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000,
        -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000,
        -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000,
        -10.3000, -10.3000, -10.3000, -10.3000, -10.3000], device='mps:0')

In [4]: pt.full((40,), -10.3, device="mps").int()
Out[4]:
tensor([-10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10,
        -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10,
        -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10],
       device='mps:0', dtype=torch.int32)

In [5]: pt.full((40,), -10.3, device="mps").int().float()
Out[5]:
tensor([-10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10.,
        -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10.,
        -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10.,
        -10., -10., -10., -10.], device='mps:0')

In [6]: pt.full((40,), -10.3, device="mps").int().float().bool()
Out[6]:
tensor([True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True], device='mps:0')
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78092
Approved by: https://github.com/kulinseth, https://github.com/malfet
2022-05-24 20:09:45 +00:00
Alban Desmaison
04ac80c73a Fix a few issues on assert/double error/legacy constructor (#77966)
Fixes https://github.com/pytorch/pytorch/issues/77960, https://github.com/pytorch/pytorch/issues/77957, https://github.com/pytorch/pytorch/issues/77781
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77966
Approved by: https://github.com/soulitzer, https://github.com/kulinseth
2022-05-20 20:25:12 +00:00
Kulin Seth
3d83321b44 MPS Fixes: copy operations, addmm and baddmm (#77791)
Fixes for the copy operations and GEMM operations on MPS backend.

Fixes https://github.com/pytorch/pytorch/issues/77819
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77791
Approved by: https://github.com/albanD
2022-05-20 03:18:11 +00:00
Kulin Seth
978304fc9c MPS: fixes (#77462)
- Fix the is_available flag for x86 machines
- Fix the tensor creation for older MacOS platforms
- Addmm fixes for transposition

Pull Request resolved: https://github.com/pytorch/pytorch/pull/77462
Approved by: https://github.com/albanD
2022-05-14 13:33:16 +00:00
Kulin Seth
e011a8e18b Enable PyTorch operations on MPS Backend. (#77343)
Add PyTorch operations to MPS backend.

- https://github.com/pytorch/pytorch/issues/77394
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77343
Approved by: https://github.com/albanD
2022-05-13 18:28:53 +00:00