mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 00:20:18 +01:00
[mps/inductor] Adjust more tests that expect float64 as input. (#146366)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146366 Approved by: https://github.com/malfet
This commit is contained in:
parent
2f40f789da
commit
cf6c5b8fa8
|
|
@ -128,6 +128,10 @@ for test_name in [
|
|||
"test_builtins_round_float_ndigits_neg",
|
||||
"test_cat_empty",
|
||||
"test_cat_unbacked_empty_1d",
|
||||
"test_consecutive_split_cumprod",
|
||||
"test_consecutive_split_cumsum",
|
||||
"test_constant_pad_float64",
|
||||
"test_cumsum_inf",
|
||||
"test_custom_op_2",
|
||||
"test_div1",
|
||||
"test_div3",
|
||||
|
|
@ -141,6 +145,7 @@ for test_name in [
|
|||
"test_isinf",
|
||||
"test_isinf2",
|
||||
"test_lgamma",
|
||||
"test_linear_float64",
|
||||
"test_log_fp64",
|
||||
"test_low_memory_max_pool",
|
||||
"test_max_min",
|
||||
|
|
|
|||
|
|
@ -2043,9 +2043,20 @@ class CommonTemplate:
|
|||
b = b.view(-1)
|
||||
return torch.cumsum(a, 0) + torch.cumsum(b, 0)
|
||||
|
||||
a = make_tensor(10, 3, 352, 352, low=0, dtype=torch.float32, device=self.device)
|
||||
b = make_tensor(10, 3, 352, 352, low=0, dtype=torch.float64, device=self.device)
|
||||
self.common(fn, (a, b), rtol=1e-4, atol=1e-5, check_lowp=False)
|
||||
dtype_a = torch.float32
|
||||
dtype_b = torch.float64
|
||||
|
||||
ctx = (
|
||||
contextlib.nullcontext()
|
||||
if self.is_dtype_supported(dtype_a) and self.is_dtype_supported(dtype_b)
|
||||
else self.assertRaises(TypeError)
|
||||
)
|
||||
|
||||
with ctx:
|
||||
a = make_tensor(10, 3, 352, 352, low=0, dtype=dtype_a, device=self.device)
|
||||
b = make_tensor(10, 3, 352, 352, low=0, dtype=dtype_b, device=self.device)
|
||||
|
||||
self.common(fn, (a, b), rtol=1e-4, atol=1e-5, check_lowp=False)
|
||||
|
||||
@config.patch(max_autotune_pointwise=True)
|
||||
def test_split_cumsum_index(self):
|
||||
|
|
@ -2097,13 +2108,20 @@ class CommonTemplate:
|
|||
def fn(a, b):
|
||||
return torch.cumprod(a, 0) + torch.cumprod(b, 0)
|
||||
|
||||
a = _large_cumprod_input(
|
||||
(10000,), dim=0, dtype=torch.float32, device=self.device
|
||||
dtype_a = torch.float32
|
||||
dtype_b = torch.float64
|
||||
|
||||
ctx = (
|
||||
contextlib.nullcontext()
|
||||
if self.is_dtype_supported(dtype_a) and self.is_dtype_supported(dtype_b)
|
||||
else self.assertRaises(TypeError)
|
||||
)
|
||||
b = _large_cumprod_input(
|
||||
(10000,), dim=0, dtype=torch.float64, device=self.device
|
||||
)
|
||||
self.common(fn, (a, b), atol=1e-5, rtol=1e-5, check_lowp=False)
|
||||
|
||||
with ctx:
|
||||
a = _large_cumprod_input((10000,), dim=0, dtype=dtype_a, device=self.device)
|
||||
b = _large_cumprod_input((10000,), dim=0, dtype=dtype_b, device=self.device)
|
||||
|
||||
self.common(fn, (a, b), atol=1e-5, rtol=1e-5, check_lowp=False)
|
||||
|
||||
@skipCUDAIf(TEST_WITH_ROCM, "associative_scan is not supported on ROCm")
|
||||
@skip_if_halide # scan ops
|
||||
|
|
@ -2492,18 +2510,22 @@ class CommonTemplate:
|
|||
def fn(x):
|
||||
return x.cumsum(-1)
|
||||
|
||||
_dtype = torch.float64
|
||||
|
||||
def make_tensor(shape):
|
||||
return torch.full(
|
||||
shape, float("inf"), device=self.device, dtype=torch.float64
|
||||
)
|
||||
return torch.full(shape, float("inf"), device=self.device, dtype=_dtype)
|
||||
|
||||
cfn = torch.compile(fn)
|
||||
ctx = (
|
||||
contextlib.nullcontext()
|
||||
if self.is_dtype_supported(_dtype)
|
||||
else self.assertRaises(TypeError)
|
||||
)
|
||||
with ctx:
|
||||
cfn = torch.compile(fn)
|
||||
|
||||
for n in [100, 10, 100]:
|
||||
inp = torch.full(
|
||||
(2, n), float("inf"), device=self.device, dtype=torch.float64
|
||||
)
|
||||
self.assertEqual(cfn(inp), fn(inp))
|
||||
for n in [100, 10, 100]:
|
||||
inp = torch.full((2, n), float("inf"), device=self.device, dtype=_dtype)
|
||||
self.assertEqual(cfn(inp), fn(inp))
|
||||
|
||||
@xfail_if_triton_cpu
|
||||
def test_logcumsumexp(self):
|
||||
|
|
@ -3527,9 +3549,16 @@ class CommonTemplate:
|
|||
@skipCUDAIf(True, "cuda failed for float64 linear")
|
||||
@skipIfXpu(msg="Double and complex datatype matmul is not supported in oneDNN")
|
||||
def test_linear_float64(self):
|
||||
mod = torch.nn.Sequential(torch.nn.Linear(8, 16).to(torch.float64)).eval()
|
||||
with torch.no_grad():
|
||||
self.common(mod, (torch.randn(2, 8).to(torch.float64),))
|
||||
_dtype = torch.float64
|
||||
ctx = (
|
||||
contextlib.nullcontext()
|
||||
if self.is_dtype_supported(_dtype)
|
||||
else self.assertRaises(TypeError)
|
||||
)
|
||||
with ctx:
|
||||
mod = torch.nn.Sequential(torch.nn.Linear(8, 16).to(_dtype)).eval()
|
||||
with torch.no_grad():
|
||||
self.common(mod, (torch.randn(2, 8).to(_dtype),))
|
||||
|
||||
def test_linear1(self):
|
||||
mod = torch.nn.Sequential(
|
||||
|
|
@ -7083,8 +7112,16 @@ class CommonTemplate:
|
|||
v1 = torch.nn.functional.pad(input, pad=(1, 0))
|
||||
return torch.gt(v1, input)
|
||||
|
||||
x = torch.rand([1, 2, 2, 1], dtype=torch.float64)
|
||||
self.common(fn, (x,))
|
||||
_dtype = torch.float64
|
||||
|
||||
ctx = (
|
||||
contextlib.nullcontext()
|
||||
if self.is_dtype_supported(_dtype)
|
||||
else self.assertRaises(TypeError)
|
||||
)
|
||||
x = torch.rand([1, 2, 2, 1], dtype=_dtype)
|
||||
with ctx:
|
||||
self.common(fn, (x,))
|
||||
|
||||
def test_constant_pad_nd_inplace(self):
|
||||
def fn(a):
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user