Commit Graph

345 Commits

Author SHA1 Message Date
Kulin Seth
76cff18242 [MPS] Add test consistency from OpInfo based tests from PR 78504 (#79532)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79532
Approved by: https://github.com/albanD, https://github.com/malfet
2022-07-04 06:41:39 +00:00
Ramin Azarmehr
0e3953fc52 MPS: Fix handling of 1D tensors in linear backward (#80759)
Fixes https://github.com/pytorch/pytorch/issues/79784
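
A minimal repro sketch for the 1D case (illustrative only; assumes an MPS-capable build):

```python
import torch

# 1D (unbatched) input to nn.Linear; the backward pass should produce a 1D grad.
lin = torch.nn.Linear(4, 3).to("mps")
x = torch.randn(4, device="mps", requires_grad=True)
lin(x).sum().backward()
print(x.grad.shape)  # expected: torch.Size([4])
```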

Pull Request resolved: https://github.com/pytorch/pytorch/pull/80759
Approved by: https://github.com/ezyang
2022-07-04 02:06:14 +00:00
Kulin Seth
b744e1c8ef Add scatter support for view operations (#79939)
* Add scatter support for view operations; #78074, #78886, #79672
* Update test_slicing_replace_column to properly test different sizes
* Handle in-place changes for binary ops; add new testcase
* Add new view ops testing scatter; add MPSDebugConfig.h config file for debugging purposes
* Merge gatherViewTensor and scatterViewTensor into a generic function
* Add scatter on demand in scatterViewOperation instead of caching it into a generic graph
* Create separate graphs for scatter and gather
* Create scatter graph at scatter time
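
A rough sketch of the kind of view writes this targets (illustrative shapes; assumes an MPS-capable build):

```python
import torch

# Writing into a non-contiguous view (a column slice) and an in-place op on a
# transposed view both exercise the scatter path for view tensors.
x = torch.zeros(4, 4, device="mps")
x[:, 1] = torch.arange(4, dtype=torch.float32, device="mps")
y = x.t()
y += 1
print(x)
```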

Pull Request resolved: https://github.com/pytorch/pytorch/pull/79939
Approved by: https://github.com/razarmehr
2022-07-01 15:10:56 +00:00
PyTorch MergeBot
b1943e01e2 Revert "[MPS] Add test consistency from OpInfo based tests from PR 78504 (#79532)"
This reverts commit c71886e048.

Reverted https://github.com/pytorch/pytorch/pull/79532 on behalf of https://github.com/malfet due to Unintended submodules updates
2022-06-30 16:37:11 +00:00
qqaatw
ae6f07e7d5 [MPS] Fix std/var cache issue (#80502)
Use `getTensorsStringKey`, which includes the tensor shape as part of the key, to prevent cache lookup issues when the shape of the input tensor changes.
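
An illustrative check (assumes an MPS-capable build): the same reduction on differently shaped inputs must not reuse a graph cached for the first shape.

```python
import torch

a = torch.randn(5, device="mps")
b = torch.randn(7, 3, device="mps")
# Before the fix, the second call could hit a graph cached for the first shape.
print(a.std(), b.var())
```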

Fixes #80499

Pull Request resolved: https://github.com/pytorch/pytorch/pull/80502
Approved by: https://github.com/malfet, https://github.com/kulinseth
2022-06-30 12:56:56 +00:00
qqaatw
c980fc3d3c [MPS] Add glu (#79866)
Adds mps op for `aten::glu.out`.
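
A short usage sketch (illustrative; assumes an MPS-capable build):

```python
import torch
import torch.nn.functional as F

# glu splits the chosen dimension in half and gates one half with the sigmoid
# of the other, so that dimension must be even.
x = torch.randn(2, 8, device="mps")
print(F.glu(x, dim=-1).shape)  # torch.Size([2, 4])
```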

Pull Request resolved: https://github.com/pytorch/pytorch/pull/79866
Approved by: https://github.com/kulinseth, https://github.com/albanD
2022-06-30 08:58:42 +00:00
Kulin Seth
c71886e048 [MPS] Add test consistency from OpInfo based tests from PR 78504 (#79532)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79532
Approved by: https://github.com/albanD
2022-06-30 01:50:17 +00:00
qqaatw
5943aaa0c4 [MPS] Add logical ops (#80216)
This PR adds `logical_not`, `logical_and`, `logical_or`, `logical_xor`.
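
A short usage sketch (illustrative; assumes an MPS-capable build):

```python
import torch

a = torch.tensor([True, False, True], device="mps")
b = torch.tensor([True, True, False], device="mps")
print(torch.logical_not(a))
print(torch.logical_and(a, b))
print(torch.logical_or(a, b))
print(torch.logical_xor(a, b))
```
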
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80216
Approved by: https://github.com/albanD, https://github.com/kulinseth
2022-06-29 02:44:35 +00:00
qqaatw
c4da23ed1b [MPS] Add flip (#80214)
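
A short usage sketch (illustrative; assumes an MPS-capable build):

```python
import torch

x = torch.arange(6, dtype=torch.float32, device="mps").reshape(2, 3)
print(torch.flip(x, dims=[0]))     # reverse rows
print(torch.flip(x, dims=[0, 1]))  # reverse rows and columns
```
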
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80214
Approved by: https://github.com/DenisVieriu97, https://github.com/albanD
2022-06-28 19:51:45 +00:00
qqaatw
e1b15b7a04 [MPS] add aten::normal.Tensor_float aten::normal.float_Tensor aten::normal.Tensor_Tensor (#80297)
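
A short usage sketch covering the three overloads (illustrative; assumes an MPS-capable build):

```python
import torch

mean = torch.zeros(5, device="mps")
std = torch.ones(5, device="mps")
print(torch.normal(mean, 1.0))  # normal.Tensor_float
print(torch.normal(0.0, std))   # normal.float_Tensor
print(torch.normal(mean, std))  # normal.Tensor_Tensor
```
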
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80297
Approved by: https://github.com/albanD, https://github.com/kulinseth
2022-06-28 15:19:39 +00:00
Nikita Shulga
f11cce309b [MPS] Add equal operator (#80195)
Which is, in essence, a composite of `eq`->`all`->`item`.
`native/mps/operators/Equal.cpp` is an almost verbatim copy of `native/cuda/Equal.cpp`

Fix codegen by generating MPSFunctions headers
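
A sketch of the described composition (illustrative; assumes an MPS-capable build):

```python
import torch

a = torch.arange(4.0, device="mps")
b = torch.arange(4.0, device="mps")
# torch.equal behaves like the eq -> all -> item composition (plus shape/dtype checks).
print(torch.equal(a, b))      # True
print((a == b).all().item())  # True
```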

Pull Request resolved: https://github.com/pytorch/pytorch/pull/80195
Approved by: https://github.com/albanD
2022-06-25 12:40:52 +00:00
Nikita Shulga
06f874e276 [MPS] Fix binary ops between int32 tensor with int64 scalar (#80220)
For some reason, tensor *op* scalar does not follow the normal binary promotion rules, so cast the output tensor to the expected type if needed.
It seems one should instead have cast the input tensors to the expected output type, but that does not really work for boolean binary ops, so this approach is used.
Add output tensor type/shape to the cached graph key.
Extend `TestMPS.test_add_scalars` to test for this regression.
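
A minimal repro sketch of the regression being tested (illustrative; assumes an MPS-capable build):

```python
import torch

# An int32 MPS tensor combined with a Python int scalar should stay int32
# and match the CPU result.
x = torch.arange(4, dtype=torch.int32, device="mps")
y = x + 2
print(y, y.dtype)  # expected dtype: torch.int32
```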

Fixes #79835

Pull Request resolved: https://github.com/pytorch/pytorch/pull/80220
Approved by: https://github.com/albanD
2022-06-25 02:21:34 +00:00
qqaatw
ff44bfa1ea [MPS] Add L1 loss test (#80010)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80010
Approved by: https://github.com/albanD
2022-06-24 17:18:31 +00:00
Nikita Shulga
4390546f86 [MPS] Fix torch.uint8 support (#80049)
`ScalarType.Byte` should be cast to `MPSDataTypeUInt8`.
Also add support for `torch.int8`, and test those conversions in `TestMPS.test_to`.
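
A short conversion sketch (illustrative; assumes an MPS-capable build):

```python
import torch

x = torch.tensor([0.0, 127.9, 255.0], device="mps")
print(x.to(torch.uint8))  # ScalarType.Byte -> MPSDataTypeUInt8
print(x.to(torch.int8))
```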

Fixes #80006

Pull Request resolved: https://github.com/pytorch/pytorch/pull/80049
Approved by: https://github.com/albanD
2022-06-22 18:41:21 +00:00
Abhishek Pathak
074dc7465e MPS: Add amax and amin Ops with tests (#79682)
* Add amax and amin with tests
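
A short usage sketch (illustrative; assumes an MPS-capable build):

```python
import torch

x = torch.randn(3, 4, device="mps")
print(torch.amax(x, dim=1))
print(torch.amin(x, dim=(0, 1)))
```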

Pull Request resolved: https://github.com/pytorch/pytorch/pull/79682
Approved by: https://github.com/albanD
2022-06-18 00:14:05 +00:00
Kulin Seth
4615f6aa97 [MPS]: Add fix for squeezed input axes handling in BCE loss (#79676)
Fixes #79527

Pull Request resolved: https://github.com/pytorch/pytorch/pull/79676
Approved by: https://github.com/razarmehr, https://github.com/albanD
2022-06-16 20:21:31 +00:00
Kulin Seth
355a1c8c3f MPS: TopK raise an error if K>16 (#79677)
* Error out in TopK when k>16.
* Add a test case too.
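
A sketch of the expected behavior (illustrative; the error type is assumed to be the usual `RuntimeError` from a failed check):

```python
import torch

x = torch.randn(32, device="mps")
print(x.topk(16).values.shape)  # k <= 16 is supported
try:
    x.topk(17)                  # k > 16 should error out on MPS
except RuntimeError as e:
    print("raised:", e)
```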

Fixes #78915

Pull Request resolved: https://github.com/pytorch/pytorch/pull/79677
Approved by: https://github.com/albanD
2022-06-16 16:06:45 +00:00
Nikita Shulga
81cd276d61 [MPS] Support stride of stride
Fixes https://github.com/pytorch/pytorch/issues/79181

Pull Request resolved: https://github.com/pytorch/pytorch/pull/79521

Approved by: https://github.com/kulinseth
2022-06-14 18:49:44 +00:00
Alban Desmaison
0a651a231d Add full support for serialization of MPS Tensors (#79465)
Fix https://github.com/pytorch/pytorch/issues/79384
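
A round-trip sketch (illustrative; assumes an MPS-capable build):

```python
import io
import torch

t = torch.randn(3, device="mps")
buf = io.BytesIO()
torch.save(t, buf)
buf.seek(0)
loaded = torch.load(buf)  # loads back onto "mps"
cpu_copy = torch.load(io.BytesIO(buf.getvalue()), map_location="cpu")
print(loaded.device, cpu_copy.device)
```
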
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79465
Approved by: https://github.com/kulinseth, https://github.com/malfet
2022-06-14 17:54:30 +00:00
PyTorch MergeBot
ce6ce74703 Revert "Add full support for serialization of MPS Tensors (#79465)"
This reverts commit 64c2a275c4.

Reverted https://github.com/pytorch/pytorch/pull/79465 on behalf of https://github.com/malfet due to Unintended submodules updates... wait
2022-06-14 16:42:36 +00:00
Alban Desmaison
64c2a275c4 Add full support for serialization of MPS Tensors (#79465)
Fix https://github.com/pytorch/pytorch/issues/79384
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79465
Approved by: https://github.com/kulinseth, https://github.com/malfet
2022-06-14 14:20:09 +00:00
Kulin Seth
77b6885a22 MPS: add layer_norm_backward (#79189)
Layernorm backward

Pull Request resolved: https://github.com/pytorch/pytorch/pull/79189
Approved by: https://github.com/razarmehr, https://github.com/albanD
2022-06-10 13:25:41 +00:00
Kulin Seth
83239351c5 MPS: add exponential op (#79188)
Add exponential distribution

Pull Request resolved: https://github.com/pytorch/pytorch/pull/79188
Approved by: https://github.com/razarmehr, https://github.com/albanD
2022-06-10 13:16:21 +00:00
Kulin Seth
50f7b40ad9 MPS: Binary cast fix by proper type promotion and remove spurious copy warning (#79185)
Fixes #78019, #78020
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79185
Approved by: https://github.com/albanD, https://github.com/razarmehr
2022-06-09 17:33:06 +00:00
Nikita Shulga
97594a24b4 Print output during MPS test import tests (#79163)
Simplify `test_no_warnings_on_input` to capture any output.
Copy its implementation to `test_testing.py`, as this is not specific to MPS.
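
A rough sketch of the capture-any-output idea (names and details are illustrative, not the test's actual implementation):

```python
import subprocess
import sys

# Import torch in a clean interpreter and assert the import prints nothing.
out = subprocess.run(
    [sys.executable, "-W", "all", "-c", "import torch"],
    capture_output=True, text=True, check=True,
)
assert out.stdout == "" and out.stderr == "", f"unexpected output: {out.stdout or out.stderr}"
```
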
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79163
Approved by: https://github.com/janeyx99, https://github.com/kulinseth
2022-06-09 13:07:05 +00:00
Philip Meier
32593ef2dd move MPS compat into common comparison machinery (#77836)
Addresses https://github.com/pytorch/pytorch/issues/77144#issuecomment-1128168082.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/77836
Approved by: https://github.com/albanD
2022-06-08 08:09:18 +00:00
Kulin Seth
a6347f5467 MPS: Fixes (#78930)
* Cast integer to float in UnaryOps
* Add tensor dtype in key generation
* Enable FP16 scalars and use placeholder for alpha tensor in add/sum ops

Pull Request resolved: https://github.com/pytorch/pytorch/pull/78930
Approved by: https://github.com/albanD
2022-06-07 18:22:10 +00:00
Nikita Shulga
55cac22cdf [MPS] Add arange_mps_out implementation (#78789)
Mostly by factoring out shader logic from the `linspace_out_mps` implementation.
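
A short usage sketch (illustrative; assumes an MPS-capable build):

```python
import torch

print(torch.arange(0, 10, 2, dtype=torch.float32, device="mps"))
print(torch.arange(5, dtype=torch.int32, device="mps"))
```
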
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78789
Approved by: https://github.com/albanD, https://github.com/kulinseth
2022-06-03 21:54:41 +00:00
Kulin Seth
4858c56334 MPS: Fix issues with view tensors and linspace. (#78690)
Fixes: https://github.com/pytorch/pytorch/issues/78642, https://github.com/pytorch/pytorch/issues/78511
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78690
Approved by: https://github.com/razarmehr, https://github.com/DenisVieriu97
2022-06-02 06:17:19 +00:00
Kulin Seth
a3bdafece3 MPS: add linspace op (#78570)
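
A short usage sketch (illustrative; assumes an MPS-capable build):

```python
import torch

print(torch.linspace(0, 1, steps=5, device="mps"))
```
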
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78570
Approved by: https://github.com/malfet
2022-06-01 13:47:14 +00:00
Ramin Azarmehr
aa62b3e003 Add test case for issue: https://github.com/pytorch/pytorch/issues/77851 (#78547)
The test works fine now.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/78547
Approved by: https://github.com/kulinseth
2022-05-31 19:15:45 +00:00
Rohan Mitchell
f42b42d3eb MPS: Implement aten::count_nonzero.dim_IntList (#78169)
- See: #77764

Implements the `aten::count_nonzero.dim_IntList` operator (as used by [torch.count_nonzero](https://pytorch.org/docs/stable/generated/torch.count_nonzero.html)) for [MPS](https://pytorch.org/blog/introducing-accelerated-pytorch-training-on-mac/).
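
A short usage sketch (illustrative; assumes an MPS-capable build):

```python
import torch

x = torch.tensor([[0.0, 1.0, 2.0],
                  [3.0, 0.0, 0.0]], device="mps")
print(torch.count_nonzero(x))              # total: 3
print(torch.count_nonzero(x, dim=0))       # per column
print(torch.count_nonzero(x, dim=(0, 1)))  # dim_IntList overload
```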

Pull Request resolved: https://github.com/pytorch/pytorch/pull/78169
Approved by: https://github.com/malfet, https://github.com/kulinseth, https://github.com/albanD
2022-05-31 18:23:25 +00:00
Kulin Seth
017b0ae943 MPS: Fix crashes in view tensors due to buffer size mismatch (#78496)
Fixes #78247, #77886

Pull Request resolved: https://github.com/pytorch/pytorch/pull/78496
Approved by: https://github.com/albanD, https://github.com/malfet
2022-05-31 02:09:03 +00:00
Alban Desmaison
bde246fcc6 Speed up test_mps from 9min to 25s
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78488

Approved by: https://github.com/kulinseth
2022-05-30 18:16:53 +00:00
Alban Desmaison
02551a0025 Remove prints and add proper asserts
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78454

Approved by: https://github.com/kulinseth
2022-05-30 18:16:53 +00:00
Kulin Seth
d63db52349 MPS: Fixes the as_strided_mps implementation for contiguous view operations (#78440)
Fixes https://github.com/pytorch/pytorch/issues/78107; https://github.com/pytorch/pytorch/issues/77750

Pull Request resolved: https://github.com/pytorch/pytorch/pull/78440
Approved by: https://github.com/malfet
2022-05-28 14:41:56 +00:00
Nikita Shulga
437ecfc461 [MPS] Fix copy_kernel_mps (#78428)
By passing `storage_offset` of the source and destination Tensors.
This fixes the following simple use case:
```
python3 -c "import torch;x=torch.zeros(3, 3, device='mps'); x[1, 1]=1;print(x)"
```

Add test to validate it would not regress in the future

Pull Request resolved: https://github.com/pytorch/pytorch/pull/78428
Approved by: https://github.com/kulinseth
2022-05-27 20:46:53 +00:00
Kulin Seth
8552acbd74 MPS: Eye op (#78408)
This PR can be used as a reference for adding an op to the MPS backend.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/78408
Approved by: https://github.com/albanD
2022-05-27 17:07:02 +00:00
Kulin Seth
2e32d5fcd8 MPS: Add adaptive max pool2d op (#78410)
Adaptive max pool 2d forward and backward with test
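
A short forward/backward sketch (illustrative; assumes an MPS-capable build):

```python
import torch

pool = torch.nn.AdaptiveMaxPool2d((2, 2))
x = torch.randn(1, 3, 8, 8, device="mps", requires_grad=True)
out = pool(x)          # forward
out.sum().backward()   # backward
print(out.shape, x.grad.shape)
```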

Pull Request resolved: https://github.com/pytorch/pytorch/pull/78410
Approved by: https://github.com/albanD
2022-05-27 11:59:07 +00:00
Nikita Shulga
705082656a Fix typo in testname (#78258)
`test_linear2D_no_bias_backwarwd` -> `test_linear2D_no_bias_backward`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/78258
Approved by: https://github.com/kulinseth, https://github.com/janeyx99
2022-05-25 16:23:10 +00:00
Lukas Hoenig
a52bfe2c5d Convert MPS Tensor data using MPSGraph API (#78092)
Fixes #78091
If you are already working on this, simply disregard this or take what may be helpful. This is my attempt at MPS-native Tensor datatype conversion. It works for everything tested ~~but is currently only implemented for MPS-to-MPS copy, not MPS-to-X or X-to-MPS, but the same approach could easily be used~~.

Before:
```python
In [5]: pt.full((40,), -10.3, device="mps")
Out[5]:
tensor([-10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000,
        -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000,
        -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000,
        -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000,
        -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000,
        -10.3000, -10.3000, -10.3000, -10.3000, -10.3000], device='mps:0')

In [6]: pt.full((40,), -10.3, device="mps").int()
Out[6]:
tensor([-1054552883, -1054552883, -1054552883, -1054552883, -1054552883,
        -1054552883, -1054552883, -1054552883, -1054552883, -1054552883,
        -1054552883, -1054552883, -1054552883, -1054552883, -1054552883,
        -1054552883, -1054552883, -1054552883, -1054552883, -1054552883,
        -1054552883, -1054552883, -1054552883, -1054552883, -1054552883,
        -1054552883, -1054552883, -1054552883, -1054552883, -1054552883,
        -1054552883, -1054552883, -1054552883, -1054552883, -1054552883,
        -1054552883, -1054552883, -1054552883, -1054552883, -1054552883],
       device='mps:0', dtype=torch.int32)

In [7]: pt.full((40,), -10.3, device="mps").int().float()
Out[7]:
tensor([-10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000,
        -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000,
        -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000,
        -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000,
        -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000,
        -10.3000, -10.3000, -10.3000, -10.3000, -10.3000], device='mps:0')

In [8]: pt.full((40,), -10.3, device="mps").int().float().bool()
Out[8]:
tensor([ True, False, False,  True,  True, False, False,  True,  True, False,
        False,  True,  True, False, False,  True,  True, False, False,  True,
         True, False, False,  True,  True, False, False,  True,  True, False,
        False,  True,  True, False, False,  True,  True, False, False,  True],
       device='mps:0')
```

After:
```python
In [3]: pt.full((40,), -10.3, device="mps")
Out[3]:
tensor([-10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000,
        -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000,
        -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000,
        -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000,
        -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000,
        -10.3000, -10.3000, -10.3000, -10.3000, -10.3000], device='mps:0')

In [4]: pt.full((40,), -10.3, device="mps").int()
Out[4]:
tensor([-10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10,
        -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10,
        -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10],
       device='mps:0', dtype=torch.int32)

In [5]: pt.full((40,), -10.3, device="mps").int().float()
Out[5]:
tensor([-10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10.,
        -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10.,
        -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10.,
        -10., -10., -10., -10.], device='mps:0')

In [6]: pt.full((40,), -10.3, device="mps").int().float().bool()
Out[6]:
tensor([True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True], device='mps:0')
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78092
Approved by: https://github.com/kulinseth, https://github.com/malfet
2022-05-24 20:09:45 +00:00
Alban Desmaison
04ac80c73a Fix a few issues on assert/double error/legacy constructor (#77966)
Fixes https://github.com/pytorch/pytorch/issues/77960, https://github.com/pytorch/pytorch/issues/77957, https://github.com/pytorch/pytorch/issues/77781
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77966
Approved by: https://github.com/soulitzer, https://github.com/kulinseth
2022-05-20 20:25:12 +00:00
Kulin Seth
3d83321b44 MPS Fixes: copy operations, addmm and baddmm (#77791)
Fixes for the copy operations and GEMM operations on MPS backend.

Fixes https://github.com/pytorch/pytorch/issues/77819
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77791
Approved by: https://github.com/albanD
2022-05-20 03:18:11 +00:00
Kulin Seth
978304fc9c MPS: fixes (#77462)
- Fix the is_available flag for x86 machines
- Fix the tensor creation for older MacOS platforms
- Addmm fixes for transposition

Pull Request resolved: https://github.com/pytorch/pytorch/pull/77462
Approved by: https://github.com/albanD
2022-05-14 13:33:16 +00:00
Kulin Seth
e011a8e18b Enable PyTorch operations on MPS Backend. (#77343)
Add PyTorch operations to MPS backend.

- https://github.com/pytorch/pytorch/issues/77394
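
A typical device-selection sketch once the backend is available (illustrative only):

```python
import torch

device = "mps" if torch.backends.mps.is_available() else "cpu"
x = torch.randn(2, 3, device=device)
y = torch.nn.functional.relu(x @ x.T)
print(y.device)
```
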
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77343
Approved by: https://github.com/albanD
2022-05-13 18:28:53 +00:00