[BE] Switch TestConsistency to MPS device (#147893)
This will eventually allow moving more decorators away from `common_mps.py`. Tolerances are adjusted accordingly, and a bunch of tests are XFAILed on MacOS-13, which is going to be deprecated anyway.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/147893
Approved by: https://github.com/atalman
ghstack dependencies: #152204
This commit is contained in:
parent 73f11e3365
commit 3ef6d6924a
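At a high level, the consistency test now generates each OpInfo sample on the MPS device, mirrors it onto the CPU, runs the operator on both, and compares the results under per-op tolerance overrides. A minimal sketch of that pattern, with illustrative names only (this is not the actual test body):

import torch

def check_consistency(op, mps_sample, cpu_sample, atol=None, rtol=None):
    # Run the same operator on the MPS sample and on its CPU mirror.
    mps_out = op(mps_sample.input, *mps_sample.args, **mps_sample.kwargs)
    cpu_out = op(cpu_sample.input, *cpu_sample.args, **cpu_sample.kwargs)
    # Compare on the CPU, within whatever slack the op's DecorateInfo entries
    # (see the op_db hunks further down) grant to TestConsistency on MPS.
    torch.testing.assert_close(mps_out.cpu(), cpu_out, atol=atol, rtol=rtol)
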
@@ -11706,23 +11706,24 @@ MPS_DTYPES = [t for t in get_all_dtypes() if t not in MPS_UNSUPPORTED_TYPES]
 MPS_GRAD_DTYPES = [torch.float32, torch.float16]
 
 
-def transform_opinfo_sample_to_mps(sample):
-    """Transforms opinfo.core.SampleInput from CPU to MPS"""
+def transform_opinfo_sample_to_cpu(sample):
+    """Transforms opinfo.core.SampleInput from MPS to CPU"""
     def transform_sample(x):
         if not isinstance(x, torch.Tensor):
             return x
         requires_grad = x.requires_grad
         conjugated = x.is_conj()
         rc = x.detach()
-        rc = rc.to("mps") if not conjugated else x.conj().to("mps").conj()
+        rc = rc.cpu() if not conjugated else x.conj().cpu().conj()
         return rc.requires_grad_(x.requires_grad)
 
-    mps_sample = sample.transform(transform_sample)
+    cpu_sample = sample.transform(transform_sample)
 
-    # Transform kwargs `device="cpu"` to `device="mps"`
-    if mps_sample.kwargs.get("device", "") == "cpu":
-        mps_sample.kwargs["device"] = "mps"
-    return mps_sample
+    # Transform kwargs `device="mps:0"` to `device="cpu"`
+    if cpu_sample.kwargs.get("device", "") == "mps:0":
+        cpu_sample.kwargs["device"] = "cpu"
+
+    return cpu_sample
 
 class TestConsistency(TestCaseMPS):
     # TODO: This is only used while some ops are being added.

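The conjugate handling is the subtle part of the helper above: `.conj()` produces a lazy conjugate view, and round-tripping through `conj()` on either side of the device copy keeps the mirrored tensor's `is_conj()` flag consistent with the original sample, as does re-applying `requires_grad`. A standalone sketch of the same pattern (hypothetical helper name; running it requires an MPS-capable machine):

import torch

def mirror_to_cpu(x: torch.Tensor) -> torch.Tensor:
    # Detach so the mirrored tensor is not tied to the original autograd graph.
    rc = x.detach()
    # For conjugate views, undo the conjugation before the copy and re-apply it
    # afterwards so the result reports the same is_conj() state as the input.
    rc = rc.cpu() if not x.is_conj() else x.conj().cpu().conj()
    # Restore gradient tracking to match the original sample.
    return rc.requires_grad_(x.requires_grad)

if torch.backends.mps.is_available():
    t = torch.randn(3, dtype=torch.complex64, device="mps").conj()
    m = mirror_to_cpu(t)
    assert m.is_conj() == t.is_conj() and m.device.type == "cpu"
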
@@ -11838,8 +11839,10 @@ class TestConsistency(TestCaseMPS):
 
     @ops(mps_ops_modifier(test_consistency_op_db), allowed_dtypes=MPS_DTYPES)
     def test_output_match(self, device, dtype, op):
-        self.assertEqual(device, "cpu")
+        self.assertEqual(device, "mps:0")
         include_conjugated_inputs = dtype.is_complex and op.test_conjugated_samples
+        if op.name.endswith("svd") and MACOS_VERSION < 14.0 and dtype == torch.complex64:
+            raise unittest.SkipTest("Can't even generate complex samples on MacOS-13")
 
         def get_samples():
             return op.sample_inputs(

@@ -11847,16 +11850,14 @@ class TestConsistency(TestCaseMPS):
                 dtype,
                 requires_grad=(dtype.is_floating_point or dtype.is_complex),
                 include_conjugated_inputs=include_conjugated_inputs,
-                # TODO: Enable per-sample seed setting and tweak tolerances / fix xfails
-                set_seed=False,
+                set_seed=True,
             )
-        cpu_samples = get_samples()
 
-        for cpu_sample in cpu_samples:
+        for mps_sample in get_samples():
             #
             # Forward check
             #
-            mps_sample = transform_opinfo_sample_to_mps(cpu_sample)
+            cpu_sample = transform_opinfo_sample_to_cpu(mps_sample)
 
             cpu_args = [cpu_sample.input] + list(cpu_sample.args)
             cpu_kwargs = cpu_sample.kwargs

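For reference, `SampleInput.transform` (from `torch.testing._internal.opinfo.core`) maps a callable over the sample's input, args, and kwargs and returns a new `SampleInput`; the callable also receives non-tensor leaves, which is why the helper above passes them through unchanged. A small usage sketch (assumes an MPS-capable machine; the sample contents are illustrative):

import torch
from torch.testing._internal.opinfo.core import SampleInput

if torch.backends.mps.is_available():
    mps_sample = SampleInput(torch.randn(3, device="mps"),
                             args=(torch.randn(3, device="mps"),))
    # Mirror every tensor leaf onto the CPU; leave everything else alone.
    cpu_sample = mps_sample.transform(
        lambda x: x.cpu() if isinstance(x, torch.Tensor) else x
    )
    assert cpu_sample.input.device.type == "cpu"
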
@@ -11896,7 +11897,7 @@ class TestConsistency(TestCaseMPS):
 
     @ops(mps_ops_grad_modifier(copy.deepcopy(test_consistency_op_db)), allowed_dtypes=MPS_GRAD_DTYPES)
     def test_output_grad_match(self, device, dtype, op):
-        self.assertEqual(device, "cpu")
+        self.assertEqual(device, "mps:0")
 
         def get_samples():
             return op.sample_inputs(

@@ -11906,13 +11907,12 @@ class TestConsistency(TestCaseMPS):
                 # TODO: Enable per-sample seed setting and tweak tolerances / fix xfails
                 set_seed=False,
             )
-        cpu_samples = get_samples()
 
-        for cpu_sample in cpu_samples:
+        for mps_sample in get_samples():
             #
             # Forward check
             #
-            mps_sample = transform_opinfo_sample_to_mps(cpu_sample)
+            cpu_sample = transform_opinfo_sample_to_cpu(mps_sample)
 
             cpu_args = [cpu_sample.input] + list(cpu_sample.args)
             cpu_kwargs = cpu_sample.kwargs

@@ -12290,7 +12290,7 @@ class TestMetalLibrary(TestCaseMPS):
 # This requires mps to be properly registered in the device generic test framework which is not the
 # case right now. We can probably use `allow_mps` introduced in https://github.com/pytorch/pytorch/pull/87342
 # to achieve this.
-instantiate_device_type_tests(TestConsistency, globals(), only_for="cpu")
+instantiate_device_type_tests(TestConsistency, globals(), allow_mps=True, only_for="mps")
 instantiate_device_type_tests(TestErrorInputs, globals(), allow_mps=True, only_for="mps")
 instantiate_device_type_tests(TestCommon, globals(), allow_mps=True, only_for="mps")
 instantiate_device_type_tests(TestLinalgMPS, globals(), allow_mps=True, only_for="mps")

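`instantiate_device_type_tests` stamps out a concrete per-device class from the generic one and passes each test the primary device string, which is why the assertions above now expect "mps:0" instead of "cpu". A minimal, self-contained sketch of the same mechanism (hypothetical test class; only does something useful on an MPS build of PyTorch):

import torch
from torch.testing._internal.common_utils import TestCase, run_tests
from torch.testing._internal.common_device_type import instantiate_device_type_tests

class TestDeviceString(TestCase):
    def test_primary_device(self, device):
        # With only_for="mps" the framework hands tests the MPS primary device.
        self.assertTrue(device.startswith("mps"))
        torch.zeros(1, device=device)  # the string is usable directly

instantiate_device_type_tests(TestDeviceString, globals(), allow_mps=True, only_for="mps")

if __name__ == "__main__":
    run_tests()
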
@@ -39,7 +39,7 @@ from torch.testing._internal.common_utils import (
     TEST_WITH_ROCM, IS_FBCODE, IS_WINDOWS, IS_MACOS, IS_S390X, TEST_SCIPY,
     torch_to_numpy_dtype_dict, numpy_to_torch_dtype, TEST_WITH_ASAN,
     GRADCHECK_NONDET_TOL, slowTest, TEST_WITH_SLOW,
-    TEST_WITH_TORCHINDUCTOR
+    TEST_WITH_TORCHINDUCTOR, MACOS_VERSION
 )
 from torch.testing._utils import wrapper_set_seed
 

@@ -12172,6 +12172,8 @@ op_db: list[OpInfo] = [
                'TestSchemaCheckModeOpInfo',
                'test_schema_correctness',
                dtypes=(torch.complex64, torch.complex128)),
+           DecorateInfo(toleranceOverride({torch.float16: tol(atol=1e-3, rtol=2e-3)}),
+               "TestConsistency", "test_output_grad_match", device_type="mps"),
        )),
    OpInfo('addmm',
        # When alpha=beta=1 as compile-time constants, JIT will decompose addmm into mm and add.

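The pattern repeated throughout the op_db hunks below is a `DecorateInfo` wrapping a `toleranceOverride`: the override relaxes atol/rtol for specific dtypes, and `DecorateInfo` scopes it to one test class, test name, and device type, so only the MPS consistency runs see the looser thresholds. A standalone construction sketch (these are private testing helpers; the import paths below are the usual internal locations):

import torch
from torch.testing._internal.common_device_type import tol, toleranceOverride
from torch.testing._internal.opinfo.core import DecorateInfo

# Loosen float16 tolerances, but only for
# TestConsistency.test_output_grad_match running on the "mps" device type.
mps_fp16_grad_tolerance = DecorateInfo(
    toleranceOverride({torch.float16: tol(atol=1e-3, rtol=2e-3)}),
    "TestConsistency",
    "test_output_grad_match",
    device_type="mps",
)
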
@@ -12206,6 +12208,10 @@ op_db: list[OpInfo] = [
            DecorateInfo(
                toleranceOverride({torch.half: tol(atol=1e-5, rtol=3e-3)}),
                'TestInductorOpInfo', 'test_comprehensive', device_type='cpu'),
+           DecorateInfo(toleranceOverride({torch.float32: tol(atol=2e-5, rtol=3e-6)}),
+               "TestConsistency", "test_output_match", device_type="mps"),
+           DecorateInfo(toleranceOverride({torch.float32: tol(atol=2e-5, rtol=3e-6)}),
+               "TestConsistency", "test_output_grad_match", device_type="mps"),
        ],
        sample_inputs_func=sample_inputs_addmv),
    OpInfo('addbmm',

@@ -12232,7 +12238,8 @@ op_db: list[OpInfo] = [
                    torch.complex64: tol(atol=1e-05, rtol=1.2e-03)}),
                'TestCommon', 'test_numpy_ref_mps'),
            DecorateInfo(
-               toleranceOverride({torch.float32: tol(atol=1e-5, rtol=1e-5)}),
+               toleranceOverride({torch.float32: tol(atol=1e-5, rtol=1e-5),
+                                  torch.bfloat16: tol(atol=2e-1, rtol=6e-1)}),
                'TestConsistency',
                'test_output_match',
            ),

@@ -12328,7 +12335,15 @@ op_db: list[OpInfo] = [
            # NVIDIA only assures that bfloat16 is supported by bmm if SM >= 5.3
            DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_dtypes', device_type='cuda', active_if=not SM53OrLater),
            DecorateInfo(toleranceOverride({torch.float32: tol(atol=1e-5, rtol=1e-5)}),
-               "TestCommon", "test_out")
+               "TestCommon", "test_out"),
+           # Fast math on MacOS-13?
+           DecorateInfo(
+               toleranceOverride({torch.float32: tol(atol=2e-5, rtol=5e-6)}),
+               'TestConsistency',
+               'test_output_match',
+               active_if=lambda _: MACOS_VERSION < 14.0,
+               device_type='mps',
+               dtypes=(torch.float32,)),
        ),
        sample_inputs_func=sample_inputs_bmm),
    OpInfo('mv',

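`active_if` makes a decorator conditional: combined with the newly imported `MACOS_VERSION`, the extra slack above (and the outright skips in later hunks) only applies on MacOS releases older than 14. A sketch of the version-gated skip variant in the same style (same caveat about private import paths):

import unittest
import torch
from torch.testing._internal.common_utils import MACOS_VERSION
from torch.testing._internal.opinfo.core import DecorateInfo

# Skip float16 consistency checks on MPS, but only on MacOS releases older
# than 14; on newer releases the decorator is inert.
macos13_skip = DecorateInfo(
    unittest.skip("Broken on MacOS13"),
    "TestConsistency",
    "test_output_match",
    device_type="mps",
    dtypes=(torch.float16,),
    active_if=lambda _: MACOS_VERSION < 14.0,
)
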
@@ -12507,6 +12522,8 @@ op_db: list[OpInfo] = [
                "test_comprehensive",
                device_type="cuda"
            ),
+           DecorateInfo(toleranceOverride({torch.float16: tol(atol=1e-3, rtol=2e-3)}),
+               "TestConsistency", "test_output_grad_match", device_type="mps"),
        ],
        supports_inplace_autograd=False,
        supports_forward_ad=True,

@@ -13030,6 +13047,8 @@ op_db: list[OpInfo] = [
            DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
            DecorateInfo(toleranceOverride({torch.float16: tol(atol=8e-3, rtol=1.4e-3)}),
                "TestInductorOpInfo", "test_comprehensive", device_type="cpu"),
+           DecorateInfo(toleranceOverride({torch.float32: tol(atol=3e-4, rtol=1e-4)}),
+               "TestConsistency", "test_output_grad_match", device_type="mps"),
        )),
    OpInfo('cross',
        dtypes=all_types_and_complex_and(torch.half, torch.bfloat16),

@@ -13181,6 +13200,12 @@ op_db: list[OpInfo] = [
                'test_fn_grad',
                dtypes=(torch.float64,),
                device_type='cpu'),
+           DecorateInfo(unittest.skip("Broken on MacOS13"),
+               'TestConsistency',
+               'test_output_match',
+               device_type='mps',
+               dtypes=(torch.float16,),
+               active_if=lambda _: MACOS_VERSION < 14.0),
        )),
    BinaryUfuncInfo('true_divide',
        dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),

@@ -13409,6 +13434,12 @@ op_db: list[OpInfo] = [
                "test_comprehensive",
                device_type="cuda"
            ),
+           DecorateInfo(unittest.skip("Broken on MacOS13"),
+               'TestConsistency',
+               'test_output_match',
+               device_type='mps',
+               dtypes=(torch.float16,),
+               active_if=lambda _: MACOS_VERSION < 14.0),
        )),
    UnaryUfuncInfo('frac',
        ref=lambda x: np.modf(x)[0],

@@ -15048,6 +15079,8 @@ op_db: list[OpInfo] = [
                'test_variant_consistency_jit',
                dtypes=(torch.float32,)
            ),
+           DecorateInfo(toleranceOverride({torch.float32: tol(atol=2e-5, rtol=3e-6)}),
+               "TestConsistency", "test_output_match", device_type="mps"),
        ),
    ),
    UnaryUfuncInfo(

@@ -16905,6 +16938,14 @@ op_db: list[OpInfo] = [
                'TestSchemaCheckModeOpInfo',
                'test_schema_correctness',
                dtypes=(torch.complex64, torch.complex128)),
+           # Fast math on MacOS-13?
+           DecorateInfo(
+               toleranceOverride({torch.float32: tol(atol=2e-5, rtol=5e-6)}),
+               'TestConsistency',
+               'test_output_match',
+               active_if=lambda _: MACOS_VERSION < 14.0,
+               device_type='mps',
+               dtypes=(torch.float32,)),
        )),
    OpInfo('mode',
        op=torch.mode,

@@ -17334,6 +17375,8 @@ op_db: list[OpInfo] = [
                dtypes=(torch.cfloat, torch.cdouble,), device_type='cpu', active_if=IS_WINDOWS),
            DecorateInfo(unittest.skip("Skipped! sparse backward not supported"),
                'TestSparseUnaryUfuncs', 'test_sparse_fn_grad'),
+           DecorateInfo(toleranceOverride({torch.float16: tol(atol=1e-3, rtol=2e-3)}),
+               "TestConsistency", "test_output_grad_match", device_type="mps"),
        ),
        decorators=(precisionOverride({torch.bfloat16: 1e-2}),)),
    UnaryUfuncInfo('sinc',

@@ -17704,6 +17747,10 @@ op_db: list[OpInfo] = [
                dtypes=(torch.float16,),
                device_type="cuda",
            ),
+           DecorateInfo(toleranceOverride({torch.complex64: tol(atol=3e-5, rtol=7e-6)}),
+               "TestConsistency", "test_output_match", device_type="mps"),
+           DecorateInfo(toleranceOverride({torch.float16: tol(atol=1e-3, rtol=2e-3)}),
+               "TestConsistency", "test_output_grad_match", device_type="mps"),
        ),
        # tan(pi/2 * odd_number) is nan
        reference_numerics_filter=NumericsFilter(

@@ -17738,6 +17785,8 @@ op_db: list[OpInfo] = [
                active_if=(IS_MACOS or IS_WINDOWS)),
            DecorateInfo(unittest.skip("Skipped! sparse backward not supported"),
                'TestSparseUnaryUfuncs', 'test_sparse_fn_grad'),
+           DecorateInfo(toleranceOverride({torch.complex64: tol(atol=3e-5, rtol=7e-6)}),
+               "TestConsistency", "test_output_match", device_type="mps"),
        ),
        # tan(j * pi/2 * odd_number) is nan
        reference_numerics_filter=NumericsFilter(

@@ -17941,6 +17990,8 @@ op_db: list[OpInfo] = [
                active_if=IS_MACOS),
            DecorateInfo(unittest.skip("Skipped! sparse backward not supported"),
                'TestSparseUnaryUfuncs', 'test_sparse_fn_grad'),
+           DecorateInfo(toleranceOverride({torch.complex64: tol(atol=2e-5, rtol=3e-6)}),
+               "TestConsistency", "test_output_match", device_type="mps"),
        )),
    UnaryUfuncInfo('square',
        ref=np.square,

@@ -19619,6 +19670,8 @@ op_db: list[OpInfo] = [
            # Decomp max diff: 1.8187482915266173e-06
            DecorateInfo(unittest.skip("Inconsistent accuracy"), 'TestDecomp', 'test_comprehensive',
                device_type='cpu', dtypes=(torch.float16,)),
+           DecorateInfo(toleranceOverride({torch.float16: tol(atol=3e-4, rtol=3e-6)}),
+               "TestConsistency", "test_output_match", device_type="mps"),
        )),
    ShapeFuncInfo('repeat',
        op=lambda x, dims: x.repeat(dims),

@@ -21066,6 +21119,8 @@ op_db: list[OpInfo] = [
                device_type='cuda', dtypes=[torch.float16]),
            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_ref_extremal_values',
                device_type='cuda', dtypes=[torch.complex64]),
+           DecorateInfo(toleranceOverride({torch.float16: tol(atol=2e-5, rtol=4e-2)}),
+               "TestConsistency", "test_output_match", device_type="mps"),
        ),
    ),
    ReductionOpInfo(

@@ -21253,6 +21308,8 @@ op_db: list[OpInfo] = [
            # possibly bad low precision reference in numpy
            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_ref_small_input',
                dtypes=[torch.float16]),
+           DecorateInfo(toleranceOverride({torch.float16: tol(atol=3e-3, rtol=4e-2)}),
+               "TestConsistency", "test_output_match", device_type="mps"),
        ),
    ),
    OpInfo(

@@ -21352,6 +21409,10 @@ op_db: list[OpInfo] = [
            # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":270,
            # please report a bug to PyTorch.
            DecorateInfo(unittest.skip("Skipped!"), "TestJit", "test_variant_consistency_jit", dtypes=(torch.float32,),),
+           DecorateInfo(toleranceOverride({torch.float16: tol(atol=8e-3, rtol=2e-3)}),
+               "TestConsistency", "test_output_match", device_type="mps"),
+           DecorateInfo(toleranceOverride({torch.float16: tol(atol=8e-3, rtol=2e-3)}),
+               "TestConsistency", "test_output_grad_match", device_type="mps"),
        ),
    ),
    OpInfo(

@@ -282,6 +282,8 @@ if torch.backends.mps.is_available():
    }
    # Those ops worked on MacOS12, but broken on MacOS13, see https://github.com/pytorch/pytorch/issues/85758
    MACOS_BEFORE_13_3_XFAILLIST = {
+       # float16 seems horribly wrong on MacOS13
+       "floor_divide": [torch.float16],
        # Failures due to precision issues (due to fast-math). These has been fixed in MacOS 13.3+
        "tan": [torch.float32],
        "cdist": [torch.float32],

@@ -312,6 +314,15 @@ if torch.backends.mps.is_available():
        "masked.cumsum": [torch.int64],
        "masked.cumprod": [torch.int64],
        "linalg.vander": [torch.int64],
+       # Fail with `Expected 1.0 but got nan.` for empty tensors
+       # Caused by sample input at index 23: SampleInput(
+       #     input=Tensor[size=(), device="mps:0", dtype=torch.float32],
+       #     args=(0),
+       #     kwargs={'mask': 'Tensor[size=(), device="mps:0", dtype=torch.bool]'},
+       #     broadcasts_input=False, name='')
+       "masked.softmin": [torch.float32, torch.float16],
+       "masked.softmax": [torch.float32, torch.float16],
+       "masked.log_softmax": [torch.float32, torch.float16],
    }
 
    MACOS_AFTER_13_1_XFAILLIST = {

@@ -455,6 +466,7 @@ if torch.backends.mps.is_available():
        "_segment_reducelengths": None,
        "_segment_reduceoffsets": None,
        "sparse.mm": None,
+       "sparse.sampled_addmm": None,
        "sparse.mmreduce": None,
        "special.airy_ai": None,
        "special.erfcx": None,

@@ -586,10 +598,6 @@ if torch.backends.mps.is_available():
        # round not working properly for float16 and bfloat16
        "round": [torch.float16, torch.bfloat16],
        "rounddecimals_0": [torch.bfloat16],
-       # bfloat16 have weird issues with rounding
-       "divfloor_rounding": [torch.bfloat16],
-       "floor_divide": [torch.bfloat16],
-       "remainder": [torch.bfloat16],
        # atomic operations not supported
        "_unsafe_masked_index_put_accumulate": [
            torch.bool,

@@ -654,7 +662,6 @@ if torch.backends.mps.is_available():
            torch.int64,
            torch.uint8,
            torch.int8,
-           torch.bfloat16,
        ],
        # Failures due to random output that they generate using
        # Philox engine causing mismatch with CPU results

@@ -857,9 +864,6 @@ if torch.backends.mps.is_available():
 
    def mps_ops_grad_modifier(ops: Sequence[OpInfo]) -> Sequence[OpInfo]:
        XFAILLIST_GRAD = {
-           # precision issues
-           "special.polygammaspecial_polygamma_n_0": [torch.float16],
-           "polygammapolygamma_n_0": [torch.float16],
            # Unimplemented ops
            "_segment_reduce": [torch.float16, torch.float32],
            "_chunk_cat": [torch.float16, torch.float32],

@@ -911,8 +915,6 @@ if torch.backends.mps.is_available():
            "equal": [torch.float16, torch.float32],
            # 'float' object is not iterable
            "item": [torch.float16, torch.float32],
-           # "mse_backward_cpu_out" not implemented for 'Half'
-           "nn.functional.mse_loss": [torch.float16],
            # "smooth_l1_backward_cpu_out" not implemented for 'Half'
            "nn.functional.smooth_l1_loss": [torch.float16],
            # cpu error: grad requires non-empty inputs

@@ -934,13 +936,18 @@ if torch.backends.mps.is_available():
            "fmod": [torch.float16],
            # round not working properly for float16
            "round": [torch.float16],
+           # topk fails with duplicate indices
+           "topk": [torch.float16],
        }
 
        MACOS_BEFORE_13_3_XFAILLIST_GRAD = {
-           # Failures due to precision issues (due to fast-math). These has been fixed in MacOS 13.3+
+           # Failures due to precision issues (may be fast-math). These has been fixed in MacOS 14
            "masked.softmin": [torch.float32, torch.float16],
            "masked.softmax": [torch.float32, torch.float16],
            "masked.log_softmax": [torch.float32, torch.float16],
+           "atanh": [torch.float16],
+           "__rmod__": [torch.float16],
+           "triangular_solve": [torch.float32],
            # Unsupported Border padding mode, forward pass success as fallback to cpu
            "grid_sampler_2d": [torch.float32, torch.float16, torch.bfloat16],
            # Same issue as `argsort` and `sort` with duplicate elements (undefined behaviour).

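The dictionaries above map op names (with variant suffixes, e.g. "rounddecimals_0") to the dtypes expected to fail; `mps_ops_modifier` / `mps_ops_grad_modifier` consume them by attaching decorators to the matching OpInfo entries before the `@ops(...)` decorator sees them. The real modifiers in `common_mps.py` handle several lists and conditions; the following is only a simplified sketch of the mechanism, with illustrative names:

import unittest
from torch.testing._internal.opinfo.core import DecorateInfo

def apply_xfail_list(ops, xfail_list, test_name="test_output_match"):
    # Attach an expectedFailure decorator, scoped to the MPS device type and to
    # the listed dtypes (None means "all dtypes"), to each matching op entry.
    for op in ops:
        key = op.name + op.variant_test_name  # e.g. "round" + "decimals_0"
        if key in xfail_list:
            op.decorators = list(op.decorators) + [
                DecorateInfo(
                    unittest.expectedFailure,
                    "TestConsistency",
                    test_name,
                    device_type="mps",
                    dtypes=xfail_list[key],
                ),
            ]
    return ops
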
@@ -1554,6 +1554,14 @@ op_db: list[OpInfo] = [
        supports_fwgrad_bwgrad=True,
        check_batched_grad=False,
        decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack, with_tf32_off],
+       skips=(
+           DecorateInfo(
+               toleranceOverride({torch.float32: tol(atol=8e-5, rtol=2e-6)}),
+               "TestConsistency",
+               "test_output_grad_match",
+               device_type="mps",
+           ),
+       ),
        sample_inputs_func=sample_inputs_linalg_matrix_power,
    ),
    OpInfo(

@@ -2284,6 +2292,12 @@ op_db: list[OpInfo] = [
                "test_noncontiguous_samples",
                device_type="cpu",
            ),
+           DecorateInfo(
+               toleranceOverride({torch.float32: tol(atol=2e-04, rtol=3e-06)}),
+               "TestConsistency",
+               "test_output_match",
+               device_type="mps",
+           ),
        ],
        skips=(
            DecorateInfo(