mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[MPSInductor] More is_dtype_supported gating (#144981)
This makes `GPUTest.test_scalar_cpu_tensor_arg_mps` pass Pull Request resolved: https://github.com/pytorch/pytorch/pull/144981 Approved by: https://github.com/dcci ghstack dependencies: #144971
This commit is contained in:
parent
94c0f15302
commit
42c64bd35c
|
|
@ -66,6 +66,7 @@ class MPSBasicTests(TestCase):
|
|||
test_remove_no_ops = CommonTemplate.test_remove_no_ops
|
||||
test_reflection_pad2d = CommonTemplate.test_reflection_pad2d
|
||||
test_rsqrt = CommonTemplate.test_rsqrt
|
||||
test_scalar_cpu_tensor_arg = CommonTemplate.test_scalar_cpu_tensor_arg
|
||||
test_scalar_output = CommonTemplate.test_scalar_output
|
||||
test_setitem_with_int_parameter = CommonTemplate.test_setitem_with_int_parameter
|
||||
test_signbit = CommonTemplate.test_signbit
|
||||
|
|
|
|||
|
|
@ -11778,6 +11778,8 @@ class CommonTemplate:
|
|||
torch.bfloat16,
|
||||
]
|
||||
for cpu_dtype in test_dtypes:
|
||||
if not self.is_dtype_supported(cpu_dtype):
|
||||
continue
|
||||
x = torch.rand([20], device=GPU_TYPE)
|
||||
y = torch.rand([4], device="cpu", dtype=cpu_dtype)
|
||||
self.common(
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user