diff --git a/test/inductor/test_mps_basic.py b/test/inductor/test_mps_basic.py index 6028b01ba77..a5ea79116e3 100644 --- a/test/inductor/test_mps_basic.py +++ b/test/inductor/test_mps_basic.py @@ -128,6 +128,10 @@ for test_name in [ "test_builtins_round_float_ndigits_neg", "test_cat_empty", "test_cat_unbacked_empty_1d", + "test_consecutive_split_cumprod", + "test_consecutive_split_cumsum", + "test_constant_pad_float64", + "test_cumsum_inf", "test_custom_op_2", "test_div1", "test_div3", @@ -141,6 +145,7 @@ for test_name in [ "test_isinf", "test_isinf2", "test_lgamma", + "test_linear_float64", "test_log_fp64", "test_low_memory_max_pool", "test_max_min", diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index 4f09864f508..3267d28a985 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -2043,9 +2043,20 @@ class CommonTemplate: b = b.view(-1) return torch.cumsum(a, 0) + torch.cumsum(b, 0) - a = make_tensor(10, 3, 352, 352, low=0, dtype=torch.float32, device=self.device) - b = make_tensor(10, 3, 352, 352, low=0, dtype=torch.float64, device=self.device) - self.common(fn, (a, b), rtol=1e-4, atol=1e-5, check_lowp=False) + dtype_a = torch.float32 + dtype_b = torch.float64 + + ctx = ( + contextlib.nullcontext() + if self.is_dtype_supported(dtype_a) and self.is_dtype_supported(dtype_b) + else self.assertRaises(TypeError) + ) + + with ctx: + a = make_tensor(10, 3, 352, 352, low=0, dtype=dtype_a, device=self.device) + b = make_tensor(10, 3, 352, 352, low=0, dtype=dtype_b, device=self.device) + + self.common(fn, (a, b), rtol=1e-4, atol=1e-5, check_lowp=False) @config.patch(max_autotune_pointwise=True) def test_split_cumsum_index(self): @@ -2097,13 +2108,20 @@ class CommonTemplate: def fn(a, b): return torch.cumprod(a, 0) + torch.cumprod(b, 0) - a = _large_cumprod_input( - (10000,), dim=0, dtype=torch.float32, device=self.device + dtype_a = torch.float32 + dtype_b = torch.float64 + + ctx = ( + contextlib.nullcontext() + if self.is_dtype_supported(dtype_a) and self.is_dtype_supported(dtype_b) + else self.assertRaises(TypeError) ) - b = _large_cumprod_input( - (10000,), dim=0, dtype=torch.float64, device=self.device - ) - self.common(fn, (a, b), atol=1e-5, rtol=1e-5, check_lowp=False) + + with ctx: + a = _large_cumprod_input((10000,), dim=0, dtype=dtype_a, device=self.device) + b = _large_cumprod_input((10000,), dim=0, dtype=dtype_b, device=self.device) + + self.common(fn, (a, b), atol=1e-5, rtol=1e-5, check_lowp=False) @skipCUDAIf(TEST_WITH_ROCM, "associative_scan is not supported on ROCm") @skip_if_halide # scan ops @@ -2492,18 +2510,22 @@ class CommonTemplate: def fn(x): return x.cumsum(-1) + _dtype = torch.float64 + def make_tensor(shape): - return torch.full( - shape, float("inf"), device=self.device, dtype=torch.float64 - ) + return torch.full(shape, float("inf"), device=self.device, dtype=_dtype) - cfn = torch.compile(fn) + ctx = ( + contextlib.nullcontext() + if self.is_dtype_supported(_dtype) + else self.assertRaises(TypeError) + ) + with ctx: + cfn = torch.compile(fn) - for n in [100, 10, 100]: - inp = torch.full( - (2, n), float("inf"), device=self.device, dtype=torch.float64 - ) - self.assertEqual(cfn(inp), fn(inp)) + for n in [100, 10, 100]: + inp = torch.full((2, n), float("inf"), device=self.device, dtype=_dtype) + self.assertEqual(cfn(inp), fn(inp)) @xfail_if_triton_cpu def test_logcumsumexp(self): @@ -3527,9 +3549,16 @@ class CommonTemplate: @skipCUDAIf(True, "cuda failed for float64 linear") @skipIfXpu(msg="Double and complex datatype matmul is not supported in oneDNN") def test_linear_float64(self): - mod = torch.nn.Sequential(torch.nn.Linear(8, 16).to(torch.float64)).eval() - with torch.no_grad(): - self.common(mod, (torch.randn(2, 8).to(torch.float64),)) + _dtype = torch.float64 + ctx = ( + contextlib.nullcontext() + if self.is_dtype_supported(_dtype) + else self.assertRaises(TypeError) + ) + with ctx: + mod = torch.nn.Sequential(torch.nn.Linear(8, 16).to(_dtype)).eval() + with torch.no_grad(): + self.common(mod, (torch.randn(2, 8).to(_dtype),)) def test_linear1(self): mod = torch.nn.Sequential( @@ -7083,8 +7112,16 @@ class CommonTemplate: v1 = torch.nn.functional.pad(input, pad=(1, 0)) return torch.gt(v1, input) - x = torch.rand([1, 2, 2, 1], dtype=torch.float64) - self.common(fn, (x,)) + _dtype = torch.float64 + + ctx = ( + contextlib.nullcontext() + if self.is_dtype_supported(_dtype) + else self.assertRaises(TypeError) + ) + x = torch.rand([1, 2, 2, 1], dtype=_dtype) + with ctx: + self.common(fn, (x,)) def test_constant_pad_nd_inplace(self): def fn(a):