[MPSInductor] More is_dtype_supported gating (#144981)

This makes `GPUTest.test_scalar_cpu_tensor_arg_mps` pass

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144981
Approved by: https://github.com/dcci
ghstack dependencies: #144971
This commit is contained in:
Nikita Shulga 2025-01-16 13:37:04 -08:00 committed by PyTorch MergeBot
parent 94c0f15302
commit 42c64bd35c
2 changed files with 3 additions and 0 deletions

View File

@ -66,6 +66,7 @@ class MPSBasicTests(TestCase):
test_remove_no_ops = CommonTemplate.test_remove_no_ops
test_reflection_pad2d = CommonTemplate.test_reflection_pad2d
test_rsqrt = CommonTemplate.test_rsqrt
test_scalar_cpu_tensor_arg = CommonTemplate.test_scalar_cpu_tensor_arg
test_scalar_output = CommonTemplate.test_scalar_output
test_setitem_with_int_parameter = CommonTemplate.test_setitem_with_int_parameter
test_signbit = CommonTemplate.test_signbit

View File

@ -11778,6 +11778,8 @@ class CommonTemplate:
torch.bfloat16,
]
for cpu_dtype in test_dtypes:
if not self.is_dtype_supported(cpu_dtype):
continue
x = torch.rand([20], device=GPU_TYPE)
y = torch.rand([4], device="cpu", dtype=cpu_dtype)
self.common(