[MPS] Enable conditional indexing tests (#97871)

The tests seem to be working now.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/97871
Approved by: https://github.com/kulinseth
This commit is contained in:
Li-Huai (Allan) Lin 2023-04-01 16:15:08 +00:00 committed by PyTorch MergeBot
parent e8d39606eb
commit db8abde9b6

View File

@ -9237,21 +9237,20 @@ class TestAdvancedIndexing(TestCaseMPS):
# FIXME: use supported_dtypes once uint8 is fixed
[helper(dtype) for dtype in [torch.float32, torch.float16, torch.int64, torch.int32, torch.int16]]
# FIXME: conditional indexing not working
# def test_boolean_array_indexing_1(self):
# def helper(dtype):
# x_cpu = torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]], dtype=dtype)
# x_mps = x_cpu.detach().clone().to("mps")
def test_boolean_array_indexing(self):
def helper(dtype):
x_cpu = torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]], dtype=dtype)
x_mps = x_cpu.detach().clone().to("mps")
# res_cpu = x_cpu[x_cpu > 5]
# res_mps = x_mps[x_mps > 5]
# print(res_cpu)
# print(res_mps)
# self.assertEqual(res_cpu, res_mps, str(dtype))
# [helper(dtype) for dtype in self.supported_dtypes]
res_cpu = x_cpu[x_cpu > 5]
res_mps = x_mps[x_mps > 5]
self.assertEqual(res_cpu, res_mps, str(dtype))
for dtype in self.supported_dtypes:
# MPS support binary op with uint8 natively starting from macOS 13.0
if product_version < 13.0 and dtype == torch.uint8:
continue
helper(dtype)
def test_advanced_indexing_3D_get(self):
def helper(x_cpu):
@ -9566,24 +9565,23 @@ class TestAdvancedIndexing(TestCaseMPS):
r = v[c > 0]
self.assertEqual(r.shape, (num_ones, 3))
# FIXME: conditional indexing not working
# def test_jit_indexing(self, device="mps"):
# def fn1(x):
# x[x < 50] = 1.0
# return x
def test_jit_indexing(self, device="mps"):
def fn1(x):
x[x < 50] = 1.0
return x
# def fn2(x):
# x[0:50] = 1.0
# return x
def fn2(x):
x[0:50] = 1.0
return x
# scripted_fn1 = torch.jit.script(fn1)
# scripted_fn2 = torch.jit.script(fn2)
# data = torch.arange(100, device=device, dtype=torch.float)
# out = scripted_fn1(data.detach().clone())
# ref = torch.tensor(np.concatenate((np.ones(50), np.arange(50, 100))), device=device, dtype=torch.float)
# self.assertEqual(out, ref)
# out = scripted_fn2(data.detach().clone())
# self.assertEqual(out, ref)
scripted_fn1 = torch.jit.script(fn1)
scripted_fn2 = torch.jit.script(fn2)
data = torch.arange(100, device=device, dtype=torch.float)
out = scripted_fn1(data.detach().clone())
ref = torch.tensor(np.concatenate((np.ones(50), np.arange(50, 100))), device=device, dtype=torch.float)
self.assertEqual(out, ref)
out = scripted_fn2(data.detach().clone())
self.assertEqual(out, ref)
def test_int_indices(self, device="mps"):
v = torch.randn(5, 7, 3, device=device)