mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[MPS] Enable conditional indexing tests (#97871)
The tests seem to be working now. Pull Request resolved: https://github.com/pytorch/pytorch/pull/97871 Approved by: https://github.com/kulinseth
This commit is contained in:
parent
e8d39606eb
commit
db8abde9b6
|
|
@ -9237,21 +9237,20 @@ class TestAdvancedIndexing(TestCaseMPS):
|
|||
# FIXME: use supported_dtypes once uint8 is fixed
|
||||
[helper(dtype) for dtype in [torch.float32, torch.float16, torch.int64, torch.int32, torch.int16]]
|
||||
|
||||
# FIXME: conditional indexing not working
|
||||
# def test_boolean_array_indexing_1(self):
|
||||
# def helper(dtype):
|
||||
# x_cpu = torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]], dtype=dtype)
|
||||
# x_mps = x_cpu.detach().clone().to("mps")
|
||||
def test_boolean_array_indexing(self):
|
||||
def helper(dtype):
|
||||
x_cpu = torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]], dtype=dtype)
|
||||
x_mps = x_cpu.detach().clone().to("mps")
|
||||
|
||||
# res_cpu = x_cpu[x_cpu > 5]
|
||||
# res_mps = x_mps[x_mps > 5]
|
||||
|
||||
# print(res_cpu)
|
||||
# print(res_mps)
|
||||
|
||||
# self.assertEqual(res_cpu, res_mps, str(dtype))
|
||||
# [helper(dtype) for dtype in self.supported_dtypes]
|
||||
res_cpu = x_cpu[x_cpu > 5]
|
||||
res_mps = x_mps[x_mps > 5]
|
||||
|
||||
self.assertEqual(res_cpu, res_mps, str(dtype))
|
||||
for dtype in self.supported_dtypes:
|
||||
# MPS support binary op with uint8 natively starting from macOS 13.0
|
||||
if product_version < 13.0 and dtype == torch.uint8:
|
||||
continue
|
||||
helper(dtype)
|
||||
|
||||
def test_advanced_indexing_3D_get(self):
|
||||
def helper(x_cpu):
|
||||
|
|
@ -9566,24 +9565,23 @@ class TestAdvancedIndexing(TestCaseMPS):
|
|||
r = v[c > 0]
|
||||
self.assertEqual(r.shape, (num_ones, 3))
|
||||
|
||||
# FIXME: conditional indexing not working
|
||||
# def test_jit_indexing(self, device="mps"):
|
||||
# def fn1(x):
|
||||
# x[x < 50] = 1.0
|
||||
# return x
|
||||
def test_jit_indexing(self, device="mps"):
|
||||
def fn1(x):
|
||||
x[x < 50] = 1.0
|
||||
return x
|
||||
|
||||
# def fn2(x):
|
||||
# x[0:50] = 1.0
|
||||
# return x
|
||||
def fn2(x):
|
||||
x[0:50] = 1.0
|
||||
return x
|
||||
|
||||
# scripted_fn1 = torch.jit.script(fn1)
|
||||
# scripted_fn2 = torch.jit.script(fn2)
|
||||
# data = torch.arange(100, device=device, dtype=torch.float)
|
||||
# out = scripted_fn1(data.detach().clone())
|
||||
# ref = torch.tensor(np.concatenate((np.ones(50), np.arange(50, 100))), device=device, dtype=torch.float)
|
||||
# self.assertEqual(out, ref)
|
||||
# out = scripted_fn2(data.detach().clone())
|
||||
# self.assertEqual(out, ref)
|
||||
scripted_fn1 = torch.jit.script(fn1)
|
||||
scripted_fn2 = torch.jit.script(fn2)
|
||||
data = torch.arange(100, device=device, dtype=torch.float)
|
||||
out = scripted_fn1(data.detach().clone())
|
||||
ref = torch.tensor(np.concatenate((np.ones(50), np.arange(50, 100))), device=device, dtype=torch.float)
|
||||
self.assertEqual(out, ref)
|
||||
out = scripted_fn2(data.detach().clone())
|
||||
self.assertEqual(out, ref)
|
||||
|
||||
def test_int_indices(self, device="mps"):
|
||||
v = torch.randn(5, 7, 3, device=device)
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user