diff --git a/test/test_mps.py b/test/test_mps.py index 67f3c784b9e..9f6a00a3fb7 100644 --- a/test/test_mps.py +++ b/test/test_mps.py @@ -9237,21 +9237,20 @@ class TestAdvancedIndexing(TestCaseMPS): # FIXME: use supported_dtypes once uint8 is fixed [helper(dtype) for dtype in [torch.float32, torch.float16, torch.int64, torch.int32, torch.int16]] - # FIXME: conditional indexing not working - # def test_boolean_array_indexing_1(self): - # def helper(dtype): - # x_cpu = torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]], dtype=dtype) - # x_mps = x_cpu.detach().clone().to("mps") + def test_boolean_array_indexing(self): + def helper(dtype): + x_cpu = torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]], dtype=dtype) + x_mps = x_cpu.detach().clone().to("mps") - # res_cpu = x_cpu[x_cpu > 5] - # res_mps = x_mps[x_mps > 5] - - # print(res_cpu) - # print(res_mps) - - # self.assertEqual(res_cpu, res_mps, str(dtype)) - # [helper(dtype) for dtype in self.supported_dtypes] + res_cpu = x_cpu[x_cpu > 5] + res_mps = x_mps[x_mps > 5] + self.assertEqual(res_cpu, res_mps, str(dtype)) + for dtype in self.supported_dtypes: + # MPS support binary op with uint8 natively starting from macOS 13.0 + if product_version < 13.0 and dtype == torch.uint8: + continue + helper(dtype) def test_advanced_indexing_3D_get(self): def helper(x_cpu): @@ -9566,24 +9565,23 @@ class TestAdvancedIndexing(TestCaseMPS): r = v[c > 0] self.assertEqual(r.shape, (num_ones, 3)) - # FIXME: conditional indexing not working - # def test_jit_indexing(self, device="mps"): - # def fn1(x): - # x[x < 50] = 1.0 - # return x + def test_jit_indexing(self, device="mps"): + def fn1(x): + x[x < 50] = 1.0 + return x - # def fn2(x): - # x[0:50] = 1.0 - # return x + def fn2(x): + x[0:50] = 1.0 + return x - # scripted_fn1 = torch.jit.script(fn1) - # scripted_fn2 = torch.jit.script(fn2) - # data = torch.arange(100, device=device, dtype=torch.float) - # out = scripted_fn1(data.detach().clone()) - # ref = torch.tensor(np.concatenate((np.ones(50), np.arange(50, 100))), device=device, dtype=torch.float) - # self.assertEqual(out, ref) - # out = scripted_fn2(data.detach().clone()) - # self.assertEqual(out, ref) + scripted_fn1 = torch.jit.script(fn1) + scripted_fn2 = torch.jit.script(fn2) + data = torch.arange(100, device=device, dtype=torch.float) + out = scripted_fn1(data.detach().clone()) + ref = torch.tensor(np.concatenate((np.ones(50), np.arange(50, 100))), device=device, dtype=torch.float) + self.assertEqual(out, ref) + out = scripted_fn2(data.detach().clone()) + self.assertEqual(out, ref) def test_int_indices(self, device="mps"): v = torch.randn(5, 7, 3, device=device)