[mps/inductor] Adjust more tests that expect float64 as input. (#146366)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146366
Approved by: https://github.com/malfet
Author: Davide Italiano
Date: 2025-02-04 00:48:02 +00:00
Committed by: PyTorch MergeBot
Commit: cf6c5b8fa8
Parent: 2f40f789da
2 changed files with 65 additions and 23 deletions
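
The first file's hunks add the affected test names to a name list; the second file's hunks rewrite the tests themselves around a single gating pattern: when the device under test cannot represent a dtype (MPS has no float64), the test body is wrapped so that it either runs normally or asserts that building the float64 input raises TypeError. A minimal self-contained sketch of that pattern — the is_dtype_supported helper here is a stand-in for the suite's own, and the device selection is an assumption, not from the diff:

    import contextlib
    import unittest

    import torch

    # Assumption: prefer MPS when available so the float64 gate is exercised.
    DEVICE = "mps" if torch.backends.mps.is_available() else "cpu"


    class Float64GatingExample(unittest.TestCase):
        device = DEVICE

        def is_dtype_supported(self, dtype: torch.dtype) -> bool:
            # Stand-in for the suite's helper: MPS cannot hold float64 tensors.
            return not (self.device == "mps" and dtype == torch.float64)

        def test_cumsum_float64(self):
            _dtype = torch.float64
            # The PR's pattern: run the body normally when the dtype is
            # supported, otherwise expect input construction to raise TypeError.
            ctx = (
                contextlib.nullcontext()
                if self.is_dtype_supported(_dtype)
                else self.assertRaises(TypeError)
            )
            with ctx:
                x = torch.rand(8, dtype=_dtype, device=self.device)
                torch.testing.assert_close(x.cumsum(0)[-1], x.sum())


    if __name__ == "__main__":
        unittest.main()

On a backend without float64, the tensor factory itself raises, the assertRaises context swallows the exception, and the test still counts as a pass.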


@@ -128,6 +128,10 @@ for test_name in [
     "test_builtins_round_float_ndigits_neg",
     "test_cat_empty",
     "test_cat_unbacked_empty_1d",
+    "test_consecutive_split_cumprod",
+    "test_consecutive_split_cumsum",
+    "test_constant_pad_float64",
+    "test_cumsum_inf",
     "test_custom_op_2",
     "test_div1",
     "test_div3",
@@ -141,6 +145,7 @@ for test_name in [
     "test_isinf",
     "test_isinf2",
     "test_lgamma",
+    "test_linear_float64",
     "test_log_fp64",
     "test_low_memory_max_pool",
     "test_max_min",

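The hunks above only extend a plain list of test names; presumably that list drives dynamic test registration elsewhere in the file. A hypothetical sketch of such a loop — make_test_case, MPSBasicTests, and run_template_test are illustrative names, not taken from the repo:

    import unittest


    def make_test_case(test_name: str):
        # Build a test method that defers to a shared dispatch helper.
        def test(self):
            self.run_template_test(test_name)  # hypothetical dispatcher

        test.__name__ = test_name
        return test


    class MPSBasicTests(unittest.TestCase):
        # Hypothetical: replay the named shared test on the MPS backend.
        def run_template_test(self, name: str) -> None:
            pass


    for test_name in [
        "test_consecutive_split_cumsum",
        "test_linear_float64",
    ]:
        setattr(MPSBasicTests, test_name, make_test_case(test_name))

The second changed file, whose hunks follow, holds the shared test bodies themselves.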

@@ -2043,9 +2043,20 @@ class CommonTemplate:
             b = b.view(-1)
             return torch.cumsum(a, 0) + torch.cumsum(b, 0)

-        a = make_tensor(10, 3, 352, 352, low=0, dtype=torch.float32, device=self.device)
-        b = make_tensor(10, 3, 352, 352, low=0, dtype=torch.float64, device=self.device)
-        self.common(fn, (a, b), rtol=1e-4, atol=1e-5, check_lowp=False)
+        dtype_a = torch.float32
+        dtype_b = torch.float64
+        ctx = (
+            contextlib.nullcontext()
+            if self.is_dtype_supported(dtype_a) and self.is_dtype_supported(dtype_b)
+            else self.assertRaises(TypeError)
+        )
+        with ctx:
+            a = make_tensor(10, 3, 352, 352, low=0, dtype=dtype_a, device=self.device)
+            b = make_tensor(10, 3, 352, 352, low=0, dtype=dtype_b, device=self.device)
+            self.common(fn, (a, b), rtol=1e-4, atol=1e-5, check_lowp=False)

     @config.patch(max_autotune_pointwise=True)
     def test_split_cumsum_index(self):
@@ -2097,13 +2108,20 @@ class CommonTemplate:
         def fn(a, b):
             return torch.cumprod(a, 0) + torch.cumprod(b, 0)

-        a = _large_cumprod_input(
-            (10000,), dim=0, dtype=torch.float32, device=self.device
-        )
-        b = _large_cumprod_input(
-            (10000,), dim=0, dtype=torch.float64, device=self.device
-        )
-        self.common(fn, (a, b), atol=1e-5, rtol=1e-5, check_lowp=False)
+        dtype_a = torch.float32
+        dtype_b = torch.float64
+        ctx = (
+            contextlib.nullcontext()
+            if self.is_dtype_supported(dtype_a) and self.is_dtype_supported(dtype_b)
+            else self.assertRaises(TypeError)
+        )
+        with ctx:
+            a = _large_cumprod_input((10000,), dim=0, dtype=dtype_a, device=self.device)
+            b = _large_cumprod_input((10000,), dim=0, dtype=dtype_b, device=self.device)
+            self.common(fn, (a, b), atol=1e-5, rtol=1e-5, check_lowp=False)

     @skipCUDAIf(TEST_WITH_ROCM, "associative_scan is not supported on ROCm")
     @skip_if_halide  # scan ops
@@ -2492,18 +2510,22 @@ class CommonTemplate:
         def fn(x):
             return x.cumsum(-1)

+        _dtype = torch.float64
+
         def make_tensor(shape):
-            return torch.full(
-                shape, float("inf"), device=self.device, dtype=torch.float64
-            )
+            return torch.full(shape, float("inf"), device=self.device, dtype=_dtype)

-        cfn = torch.compile(fn)
+        ctx = (
+            contextlib.nullcontext()
+            if self.is_dtype_supported(_dtype)
+            else self.assertRaises(TypeError)
+        )
+        with ctx:
+            cfn = torch.compile(fn)

-        for n in [100, 10, 100]:
-            inp = torch.full(
-                (2, n), float("inf"), device=self.device, dtype=torch.float64
-            )
-            self.assertEqual(cfn(inp), fn(inp))
+            for n in [100, 10, 100]:
+                inp = torch.full((2, n), float("inf"), device=self.device, dtype=_dtype)
+                self.assertEqual(cfn(inp), fn(inp))

     @xfail_if_triton_cpu
     def test_logcumsumexp(self):
@@ -3527,9 +3549,16 @@ class CommonTemplate:
     @skipCUDAIf(True, "cuda failed for float64 linear")
     @skipIfXpu(msg="Double and complex datatype matmul is not supported in oneDNN")
     def test_linear_float64(self):
-        mod = torch.nn.Sequential(torch.nn.Linear(8, 16).to(torch.float64)).eval()
-        with torch.no_grad():
-            self.common(mod, (torch.randn(2, 8).to(torch.float64),))
+        _dtype = torch.float64
+        ctx = (
+            contextlib.nullcontext()
+            if self.is_dtype_supported(_dtype)
+            else self.assertRaises(TypeError)
+        )
+        with ctx:
+            mod = torch.nn.Sequential(torch.nn.Linear(8, 16).to(_dtype)).eval()
+            with torch.no_grad():
+                self.common(mod, (torch.randn(2, 8).to(_dtype),))

     def test_linear1(self):
         mod = torch.nn.Sequential(
@@ -7083,8 +7112,16 @@ class CommonTemplate:
             v1 = torch.nn.functional.pad(input, pad=(1, 0))
             return torch.gt(v1, input)

-        x = torch.rand([1, 2, 2, 1], dtype=torch.float64)
-        self.common(fn, (x,))
+        _dtype = torch.float64
+        ctx = (
+            contextlib.nullcontext()
+            if self.is_dtype_supported(_dtype)
+            else self.assertRaises(TypeError)
+        )
+        x = torch.rand([1, 2, 2, 1], dtype=_dtype)
+        with ctx:
+            self.common(fn, (x,))

     def test_constant_pad_nd_inplace(self):
         def fn(a):
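
For reference, a dtype-capability probe of the shape is_dtype_supported implies could look like the following — a sketch under the assumption that MPS is the only backend in play without 64-bit float support; the suite's real helper may consult more device properties:

    import torch


    def is_dtype_supported(device: str, dtype: torch.dtype) -> bool:
        # Sketch: MPS cannot represent 64-bit floats or complex doubles,
        # so float64 test inputs raise TypeError there.
        if torch.device(device).type == "mps":
            return dtype not in (torch.float64, torch.complex128)
        return True


    # Usage: gate a test body exactly as the hunks above do.
    assert is_dtype_supported("cpu", torch.float64)
    assert not is_dtype_supported("mps", torch.float64)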