[BE] Switch TestConsistency to MPS device (#147893)

This will eventually allow moving more decorators away from `common_mps.py`.

Adjust tolerances accordingly (the decorator pattern involved is sketched below), and XFAIL a number of tests on MacOS-13, which is going to be deprecated anyway.
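For reference, the tolerance adjustments in this diff all use the stock OpInfo decorator pattern. A minimal illustrative sketch, not part of the change itself (it assumes a PyTorch dev checkout where `torch.testing._internal` is importable; `mps_grad_tol` is a made-up name):

```python
import torch
from torch.testing._internal.common_device_type import tol, toleranceOverride
from torch.testing._internal.opinfo.core import DecorateInfo

# Loosen float16 tolerances for a single op, scoped to the MPS device and to
# TestConsistency.test_output_grad_match; every other test keeps the defaults.
mps_grad_tol = DecorateInfo(
    toleranceOverride({torch.float16: tol(atol=1e-3, rtol=2e-3)}),
    "TestConsistency",
    "test_output_grad_match",
    device_type="mps",
)
```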

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147893
Approved by: https://github.com/atalman
ghstack dependencies: #152204
Author: Nikita Shulga, 2025-04-25 17:42:28 -07:00, committed by PyTorch MergeBot
parent 73f11e3365
commit 3ef6d6924a
4 changed files with 115 additions and 33 deletions
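Context for the first file's diff: the helper that used to copy CPU-generated samples to MPS now runs the other way, mirroring MPS-generated samples to CPU so the CPU run serves as the reference. A standalone sketch of that mirroring step (illustrative only; `mirror_to_cpu` is a made-up name, and the demo needs an MPS-capable machine):

```python
import torch

def mirror_to_cpu(x):
    """Copy an MPS tensor to CPU, preserving requires_grad and the conj bit."""
    if not isinstance(x, torch.Tensor):
        return x
    rc = x.detach()
    # Resolve and re-apply conjugation around the copy so the CPU tensor
    # keeps the same conjugate-view flag as the MPS original.
    rc = rc.cpu() if not rc.is_conj() else rc.conj().cpu().conj()
    return rc.requires_grad_(x.requires_grad)

if torch.backends.mps.is_available():
    mps_x = torch.randn(4, device="mps", requires_grad=True)
    cpu_x = mirror_to_cpu(mps_x)
    # The CPU result is the reference that the MPS output is compared against.
    torch.testing.assert_close(torch.exp(mps_x).cpu(), torch.exp(cpu_x))
```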

File 1 of 4

@@ -11706,23 +11706,24 @@ MPS_DTYPES = [t for t in get_all_dtypes() if t not in MPS_UNSUPPORTED_TYPES]
 MPS_GRAD_DTYPES = [torch.float32, torch.float16]
 
-def transform_opinfo_sample_to_mps(sample):
-    """Transforms opinfo.core.SampleInput from CPU to MPS"""
+def transform_opinfo_sample_to_cpu(sample):
+    """Transforms opinfo.core.SampleInput from MPS to CPU"""
     def transform_sample(x):
         if not isinstance(x, torch.Tensor):
             return x
         requires_grad = x.requires_grad
         conjugated = x.is_conj()
         rc = x.detach()
-        rc = rc.to("mps") if not conjugated else x.conj().to("mps").conj()
+        rc = rc.cpu() if not conjugated else x.conj().cpu().conj()
         return rc.requires_grad_(x.requires_grad)
-    mps_sample = sample.transform(transform_sample)
+    cpu_sample = sample.transform(transform_sample)
 
-    # Transform kwargs `device="cpu"` to `device="mps"`
-    if mps_sample.kwargs.get("device", "") == "cpu":
-        mps_sample.kwargs["device"] = "mps"
-    return mps_sample
+    # Transform kwargs `device="mps:0"` to `device="cpu"`
+    if cpu_sample.kwargs.get("device", "") == "mps:0":
+        cpu_sample.kwargs["device"] = "cpu"
+    return cpu_sample
 
 class TestConsistency(TestCaseMPS):
     # TODO: This is only used while some ops are being added.
@@ -11838,8 +11839,10 @@ class TestConsistency(TestCaseMPS):
     @ops(mps_ops_modifier(test_consistency_op_db), allowed_dtypes=MPS_DTYPES)
     def test_output_match(self, device, dtype, op):
-        self.assertEqual(device, "cpu")
+        self.assertEqual(device, "mps:0")
         include_conjugated_inputs = dtype.is_complex and op.test_conjugated_samples
+        if op.name.endswith("svd") and MACOS_VERSION < 14.0 and dtype == torch.complex64:
+            raise unittest.SkipTest("Can't even generate complex samples on MacOS-13")
 
         def get_samples():
             return op.sample_inputs(
@@ -11847,16 +11850,14 @@ class TestConsistency(TestCaseMPS):
                 dtype,
                 requires_grad=(dtype.is_floating_point or dtype.is_complex),
                 include_conjugated_inputs=include_conjugated_inputs,
-                # TODO: Enable per-sample seed setting and tweak tolerances / fix xfails
-                set_seed=False,
+                set_seed=True,
             )
 
-        cpu_samples = get_samples()
-        for cpu_sample in cpu_samples:
+        for mps_sample in get_samples():
             #
             # Forward check
             #
-            mps_sample = transform_opinfo_sample_to_mps(cpu_sample)
+            cpu_sample = transform_opinfo_sample_to_cpu(mps_sample)
             cpu_args = [cpu_sample.input] + list(cpu_sample.args)
             cpu_kwargs = cpu_sample.kwargs
@@ -11896,7 +11897,7 @@ class TestConsistency(TestCaseMPS):
     @ops(mps_ops_grad_modifier(copy.deepcopy(test_consistency_op_db)), allowed_dtypes=MPS_GRAD_DTYPES)
     def test_output_grad_match(self, device, dtype, op):
-        self.assertEqual(device, "cpu")
+        self.assertEqual(device, "mps:0")
 
         def get_samples():
             return op.sample_inputs(
@@ -11906,13 +11907,12 @@ class TestConsistency(TestCaseMPS):
                 # TODO: Enable per-sample seed setting and tweak tolerances / fix xfails
                 set_seed=False,
             )
 
-        cpu_samples = get_samples()
-        for cpu_sample in cpu_samples:
+        for mps_sample in get_samples():
             #
             # Forward check
             #
-            mps_sample = transform_opinfo_sample_to_mps(cpu_sample)
+            cpu_sample = transform_opinfo_sample_to_cpu(mps_sample)
             cpu_args = [cpu_sample.input] + list(cpu_sample.args)
             cpu_kwargs = cpu_sample.kwargs
@@ -12290,7 +12290,7 @@ class TestMetalLibrary(TestCaseMPS):
 # This requires mps to be properly registered in the device generic test framework which is not the
 # case right now. We can probably use `allow_mps` introduced in https://github.com/pytorch/pytorch/pull/87342
 # to achieve this.
-instantiate_device_type_tests(TestConsistency, globals(), only_for="cpu")
+instantiate_device_type_tests(TestConsistency, globals(), allow_mps=True, only_for="mps")
 instantiate_device_type_tests(TestErrorInputs, globals(), allow_mps=True, only_for="mps")
 instantiate_device_type_tests(TestCommon, globals(), allow_mps=True, only_for="mps")
 instantiate_device_type_tests(TestLinalgMPS, globals(), allow_mps=True, only_for="mps")

File 2 of 4

@@ -39,7 +39,7 @@ from torch.testing._internal.common_utils import (
     TEST_WITH_ROCM, IS_FBCODE, IS_WINDOWS, IS_MACOS, IS_S390X, TEST_SCIPY,
     torch_to_numpy_dtype_dict, numpy_to_torch_dtype, TEST_WITH_ASAN,
     GRADCHECK_NONDET_TOL, slowTest, TEST_WITH_SLOW,
-    TEST_WITH_TORCHINDUCTOR
+    TEST_WITH_TORCHINDUCTOR, MACOS_VERSION
 )
 from torch.testing._utils import wrapper_set_seed
@@ -12172,6 +12172,8 @@ op_db: list[OpInfo] = [
                         'TestSchemaCheckModeOpInfo',
                         'test_schema_correctness',
                         dtypes=(torch.complex64, torch.complex128)),
+           DecorateInfo(toleranceOverride({torch.float16: tol(atol=1e-3, rtol=2e-3)}),
+                        "TestConsistency", "test_output_grad_match", device_type="mps"),
        )),
     OpInfo('addmm',
            # When alpha=beta=1 as compile-time constants, JIT will decompose addmm into mm and add.
@@ -12206,6 +12208,10 @@ op_db: list[OpInfo] = [
            DecorateInfo(
                toleranceOverride({torch.half: tol(atol=1e-5, rtol=3e-3)}),
                'TestInductorOpInfo', 'test_comprehensive', device_type='cpu'),
+           DecorateInfo(toleranceOverride({torch.float32: tol(atol=2e-5, rtol=3e-6)}),
+                        "TestConsistency", "test_output_match", device_type="mps"),
+           DecorateInfo(toleranceOverride({torch.float32: tol(atol=2e-5, rtol=3e-6)}),
+                        "TestConsistency", "test_output_grad_match", device_type="mps"),
        ],
        sample_inputs_func=sample_inputs_addmv),
     OpInfo('addbmm',
@@ -12232,7 +12238,8 @@ op_db: list[OpInfo] = [
                                   torch.complex64: tol(atol=1e-05, rtol=1.2e-03)}),
                'TestCommon', 'test_numpy_ref_mps'),
            DecorateInfo(
-               toleranceOverride({torch.float32: tol(atol=1e-5, rtol=1e-5)}),
+               toleranceOverride({torch.float32: tol(atol=1e-5, rtol=1e-5),
+                                  torch.bfloat16: tol(atol=2e-1, rtol=6e-1)}),
                'TestConsistency',
                'test_output_match',
            ),
@@ -12328,7 +12335,15 @@ op_db: list[OpInfo] = [
            # NVIDIA only assures that bfloat16 is supported by bmm if SM >= 5.3
            DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_dtypes', device_type='cuda', active_if=not SM53OrLater),
            DecorateInfo(toleranceOverride({torch.float32: tol(atol=1e-5, rtol=1e-5)}),
-                        "TestCommon", "test_out")
+                        "TestCommon", "test_out"),
+           # Fast math on MacOS-13?
+           DecorateInfo(
+               toleranceOverride({torch.float32: tol(atol=2e-5, rtol=5e-6)}),
+               'TestConsistency',
+               'test_output_match',
+               active_if=lambda _: MACOS_VERSION < 14.0,
+               device_type='mps',
+               dtypes=(torch.float32,)),
        ),
        sample_inputs_func=sample_inputs_bmm),
     OpInfo('mv',
@@ -12507,6 +12522,8 @@ op_db: list[OpInfo] = [
                "test_comprehensive",
                device_type="cuda"
            ),
+           DecorateInfo(toleranceOverride({torch.float16: tol(atol=1e-3, rtol=2e-3)}),
+                        "TestConsistency", "test_output_grad_match", device_type="mps"),
        ],
        supports_inplace_autograd=False,
        supports_forward_ad=True,
@@ -13030,6 +13047,8 @@ op_db: list[OpInfo] = [
            DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
            DecorateInfo(toleranceOverride({torch.float16: tol(atol=8e-3, rtol=1.4e-3)}),
                         "TestInductorOpInfo", "test_comprehensive", device_type="cpu"),
+           DecorateInfo(toleranceOverride({torch.float32: tol(atol=3e-4, rtol=1e-4)}),
+                        "TestConsistency", "test_output_grad_match", device_type="mps"),
        )),
     OpInfo('cross',
            dtypes=all_types_and_complex_and(torch.half, torch.bfloat16),
@@ -13181,6 +13200,12 @@ op_db: list[OpInfo] = [
                         'test_fn_grad',
                         dtypes=(torch.float64,),
                         device_type='cpu'),
+           DecorateInfo(unittest.skip("Broken on MacOS13"),
+                        'TestConsistency',
+                        'test_output_match',
+                        device_type='mps',
+                        dtypes=(torch.float16,),
+                        active_if=lambda _: MACOS_VERSION < 14.0),
        )),
     BinaryUfuncInfo('true_divide',
                     dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
@@ -13409,6 +13434,12 @@ op_db: list[OpInfo] = [
                "test_comprehensive",
                device_type="cuda"
            ),
+           DecorateInfo(unittest.skip("Broken on MacOS13"),
+                        'TestConsistency',
+                        'test_output_match',
+                        device_type='mps',
+                        dtypes=(torch.float16,),
+                        active_if=lambda _: MACOS_VERSION < 14.0),
        )),
     UnaryUfuncInfo('frac',
                    ref=lambda x: np.modf(x)[0],
@@ -15048,6 +15079,8 @@ op_db: list[OpInfo] = [
                'test_variant_consistency_jit',
                dtypes=(torch.float32,)
            ),
+           DecorateInfo(toleranceOverride({torch.float32: tol(atol=2e-5, rtol=3e-6)}),
+                        "TestConsistency", "test_output_match", device_type="mps"),
        ),
     ),
     UnaryUfuncInfo(
@@ -16905,6 +16938,14 @@ op_db: list[OpInfo] = [
                         'TestSchemaCheckModeOpInfo',
                         'test_schema_correctness',
                         dtypes=(torch.complex64, torch.complex128)),
+           # Fast math on MacOS-13?
+           DecorateInfo(
+               toleranceOverride({torch.float32: tol(atol=2e-5, rtol=5e-6)}),
+               'TestConsistency',
+               'test_output_match',
+               active_if=lambda _: MACOS_VERSION < 14.0,
+               device_type='mps',
+               dtypes=(torch.float32,)),
        )),
     OpInfo('mode',
            op=torch.mode,
@@ -17334,6 +17375,8 @@ op_db: list[OpInfo] = [
                         dtypes=(torch.cfloat, torch.cdouble,), device_type='cpu', active_if=IS_WINDOWS),
            DecorateInfo(unittest.skip("Skipped! sparse backward not supported"),
                         'TestSparseUnaryUfuncs', 'test_sparse_fn_grad'),
+           DecorateInfo(toleranceOverride({torch.float16: tol(atol=1e-3, rtol=2e-3)}),
+                        "TestConsistency", "test_output_grad_match", device_type="mps"),
        ),
        decorators=(precisionOverride({torch.bfloat16: 1e-2}),)),
     UnaryUfuncInfo('sinc',
@@ -17704,6 +17747,10 @@ op_db: list[OpInfo] = [
                dtypes=(torch.float16,),
                device_type="cuda",
            ),
+           DecorateInfo(toleranceOverride({torch.complex64: tol(atol=3e-5, rtol=7e-6)}),
+                        "TestConsistency", "test_output_match", device_type="mps"),
+           DecorateInfo(toleranceOverride({torch.float16: tol(atol=1e-3, rtol=2e-3)}),
+                        "TestConsistency", "test_output_grad_match", device_type="mps"),
        ),
        # tan(pi/2 * odd_number) is nan
        reference_numerics_filter=NumericsFilter(
@@ -17738,6 +17785,8 @@ op_db: list[OpInfo] = [
                         active_if=(IS_MACOS or IS_WINDOWS)),
            DecorateInfo(unittest.skip("Skipped! sparse backward not supported"),
                         'TestSparseUnaryUfuncs', 'test_sparse_fn_grad'),
+           DecorateInfo(toleranceOverride({torch.complex64: tol(atol=3e-5, rtol=7e-6)}),
+                        "TestConsistency", "test_output_match", device_type="mps"),
        ),
        # tan(j * pi/2 * odd_number) is nan
        reference_numerics_filter=NumericsFilter(
@@ -17941,6 +17990,8 @@ op_db: list[OpInfo] = [
                         active_if=IS_MACOS),
            DecorateInfo(unittest.skip("Skipped! sparse backward not supported"),
                         'TestSparseUnaryUfuncs', 'test_sparse_fn_grad'),
+           DecorateInfo(toleranceOverride({torch.complex64: tol(atol=2e-5, rtol=3e-6)}),
+                        "TestConsistency", "test_output_match", device_type="mps"),
        )),
     UnaryUfuncInfo('square',
                    ref=np.square,
@@ -19619,6 +19670,8 @@ op_db: list[OpInfo] = [
            # Decomp max diff: 1.8187482915266173e-06
            DecorateInfo(unittest.skip("Inconsistent accuracy"), 'TestDecomp', 'test_comprehensive',
                         device_type='cpu', dtypes=(torch.float16,)),
+           DecorateInfo(toleranceOverride({torch.float16: tol(atol=3e-4, rtol=3e-6)}),
+                        "TestConsistency", "test_output_match", device_type="mps"),
        )),
     ShapeFuncInfo('repeat',
                   op=lambda x, dims: x.repeat(dims),
@@ -21066,6 +21119,8 @@ op_db: list[OpInfo] = [
                         device_type='cuda', dtypes=[torch.float16]),
            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_ref_extremal_values',
                         device_type='cuda', dtypes=[torch.complex64]),
+           DecorateInfo(toleranceOverride({torch.float16: tol(atol=2e-5, rtol=4e-2)}),
+                        "TestConsistency", "test_output_match", device_type="mps"),
        ),
     ),
     ReductionOpInfo(
@@ -21253,6 +21308,8 @@ op_db: list[OpInfo] = [
            # possibly bad low precision reference in numpy
            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_ref_small_input',
                         dtypes=[torch.float16]),
+           DecorateInfo(toleranceOverride({torch.float16: tol(atol=3e-3, rtol=4e-2)}),
+                        "TestConsistency", "test_output_match", device_type="mps"),
        ),
     ),
     OpInfo(
@@ -21352,6 +21409,10 @@ op_db: list[OpInfo] = [
            # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":270,
            # please report a bug to PyTorch.
            DecorateInfo(unittest.skip("Skipped!"), "TestJit", "test_variant_consistency_jit", dtypes=(torch.float32,),),
+           DecorateInfo(toleranceOverride({torch.float16: tol(atol=8e-3, rtol=2e-3)}),
+                        "TestConsistency", "test_output_match", device_type="mps"),
+           DecorateInfo(toleranceOverride({torch.float16: tol(atol=8e-3, rtol=2e-3)}),
+                        "TestConsistency", "test_output_grad_match", device_type="mps"),
        ),
     ),
     OpInfo(

File 3 of 4

@@ -282,6 +282,8 @@ if torch.backends.mps.is_available():
     }
 
     # Those ops worked on MacOS12, but broken on MacOS13, see https://github.com/pytorch/pytorch/issues/85758
     MACOS_BEFORE_13_3_XFAILLIST = {
+        # float16 seems horribly wrong on MacOS13
+        "floor_divide": [torch.float16],
         # Failures due to precision issues (due to fast-math). These has been fixed in MacOS 13.3+
         "tan": [torch.float32],
         "cdist": [torch.float32],
@@ -312,6 +314,15 @@ if torch.backends.mps.is_available():
         "masked.cumsum": [torch.int64],
         "masked.cumprod": [torch.int64],
         "linalg.vander": [torch.int64],
+        # Fail with `Expected 1.0 but got nan.` for empty tensors
+        # Caused by sample input at index 23: SampleInput(
+        #     input=Tensor[size=(), device="mps:0", dtype=torch.float32],
+        #     args=(0),
+        #     kwargs={'mask': 'Tensor[size=(), device="mps:0", dtype=torch.bool]'},
+        #     broadcasts_input=False, name='')
+        "masked.softmin": [torch.float32, torch.float16],
+        "masked.softmax": [torch.float32, torch.float16],
+        "masked.log_softmax": [torch.float32, torch.float16],
     }
 
     MACOS_AFTER_13_1_XFAILLIST = {
@@ -455,6 +466,7 @@ if torch.backends.mps.is_available():
         "_segment_reducelengths": None,
         "_segment_reduceoffsets": None,
         "sparse.mm": None,
+        "sparse.sampled_addmm": None,
         "sparse.mmreduce": None,
         "special.airy_ai": None,
         "special.erfcx": None,
@@ -586,10 +598,6 @@ if torch.backends.mps.is_available():
         # round not working properly for float16 and bfloat16
         "round": [torch.float16, torch.bfloat16],
         "rounddecimals_0": [torch.bfloat16],
-        # bfloat16 have weird issues with rounding
-        "divfloor_rounding": [torch.bfloat16],
-        "floor_divide": [torch.bfloat16],
-        "remainder": [torch.bfloat16],
         # atomic operations not supported
         "_unsafe_masked_index_put_accumulate": [
             torch.bool,
@@ -654,7 +662,6 @@ if torch.backends.mps.is_available():
             torch.int64,
             torch.uint8,
             torch.int8,
-            torch.bfloat16,
         ],
         # Failures due to random output that they generate using
         # Philox engine causing mismatch with CPU results
@@ -857,9 +864,6 @@ if torch.backends.mps.is_available():
     def mps_ops_grad_modifier(ops: Sequence[OpInfo]) -> Sequence[OpInfo]:
         XFAILLIST_GRAD = {
             # precision issues
-            "special.polygammaspecial_polygamma_n_0": [torch.float16],
-            "polygammapolygamma_n_0": [torch.float16],
             # Unimplemented ops
             "_segment_reduce": [torch.float16, torch.float32],
             "_chunk_cat": [torch.float16, torch.float32],
@@ -911,8 +915,6 @@ if torch.backends.mps.is_available():
             "equal": [torch.float16, torch.float32],
             # 'float' object is not iterable
             "item": [torch.float16, torch.float32],
-            # "mse_backward_cpu_out" not implemented for 'Half'
-            "nn.functional.mse_loss": [torch.float16],
             # "smooth_l1_backward_cpu_out" not implemented for 'Half'
             "nn.functional.smooth_l1_loss": [torch.float16],
             # cpu error: grad requires non-empty inputs
@@ -934,13 +936,18 @@ if torch.backends.mps.is_available():
             "fmod": [torch.float16],
             # round not working properly for float16
             "round": [torch.float16],
+            # topk fails with duplicate indices
+            "topk": [torch.float16],
         }
 
         MACOS_BEFORE_13_3_XFAILLIST_GRAD = {
-            # Failures due to precision issues (due to fast-math). These has been fixed in MacOS 13.3+
+            # Failures due to precision issues (may be fast-math). These has been fixed in MacOS 14
            "masked.softmin": [torch.float32, torch.float16],
            "masked.softmax": [torch.float32, torch.float16],
            "masked.log_softmax": [torch.float32, torch.float16],
+            "atanh": [torch.float16],
+            "__rmod__": [torch.float16],
+            "triangular_solve": [torch.float32],
            # Unsupported Border padding mode, forward pass success as fallback to cpu
            "grid_sampler_2d": [torch.float32, torch.float16, torch.bfloat16],
            # Same issue as `argsort` and `sort` with duplicate elements (undefined behaviour).

File 4 of 4

@@ -1554,6 +1554,14 @@ op_db: list[OpInfo] = [
         supports_fwgrad_bwgrad=True,
         check_batched_grad=False,
         decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack, with_tf32_off],
+        skips=(
+            DecorateInfo(
+                toleranceOverride({torch.float32: tol(atol=8e-5, rtol=2e-6)}),
+                "TestConsistency",
+                "test_output_grad_match",
+                device_type="mps",
+            ),
+        ),
         sample_inputs_func=sample_inputs_linalg_matrix_power,
     ),
     OpInfo(
@@ -2284,6 +2292,12 @@ op_db: list[OpInfo] = [
                 "test_noncontiguous_samples",
                 device_type="cpu",
             ),
+            DecorateInfo(
+                toleranceOverride({torch.float32: tol(atol=2e-04, rtol=3e-06)}),
+                "TestConsistency",
+                "test_output_match",
+                device_type="mps",
+            ),
         ],
         skips=(
             DecorateInfo(