Remove incorrect usages of skipIfTorchDynamo (#117114)

Using bare `@skipIfTorchDynamo` is wrong; the correct usage is
`@skipIfTorchDynamo()` or `@skipIfTorchDynamo("msg")`. Because
`skipIfTorchDynamo` is a decorator factory, applying it without parentheses
passes the test function in as `msg` and rebinds the test's name to the inner
`decorator` closure, so the test silently stops existing.
Added an assertion to catch this misuse and fixed the incorrect call sites.
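
To illustrate the failure mode, here is a minimal, self-contained sketch of the decorator-factory pattern (simplified; the real skipping logic is elided):

```python
import functools

def skipIfTorchDynamo(msg="test doesn't currently work with dynamo"):
    # Decorator *factory*: it must be called to produce the actual decorator.
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            # The real helper raises unittest.SkipTest under Dynamo; elided here.
            return fn(*args, **kwargs)
        return wrapper
    return decorator

@skipIfTorchDynamo()        # correct: the factory is called, the test is wrapped
def test_ok():
    assert 1 + 1 == 2

@skipIfTorchDynamo          # wrong: the test function itself is passed as `msg`
def test_broken():
    assert False            # never runs

print(test_ok.__name__)      # 'test_ok'
print(test_broken.__name__)  # 'decorator' -- the test body is unreachable
```

Inside a `TestCase`, the misdecorated method just returns a wrapper without ever running its body, so the test "passes" vacuously.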
Pull Request resolved: https://github.com/pytorch/pytorch/pull/117114
Approved by: https://github.com/voznesenskym
Author: rzou
Date: 2024-01-10 13:03:23 -08:00
Committed by: PyTorch MergeBot
Commit: 79e6d2ae9d (parent: d6540038c0)
13 changed files with 32 additions and 25 deletions


@@ -3152,7 +3152,7 @@ class TestComposability(TestCase):
         torch.vmap(torch.sin)
 
     # Some of these pass, some of these don't
-    @skipIfTorchDynamo
+    @skipIfTorchDynamo()
     @parametrize('transform', [
         'grad', 'jacrev', 'jacfwd', 'grad_and_value', 'hessian', 'functionalize'
     ])
@@ -3398,7 +3398,7 @@ class TestComposability(TestCase):
         transform(MySin.apply)(x)
 
     # Some of these pass, some of these don't
-    @skipIfTorchDynamo
+    @skipIfTorchDynamo()
     @parametrize('transform', [
         'vmap', 'grad', 'jacrev', 'jacfwd', 'grad_and_value', 'hessian', 'functionalize'
     ])


@@ -2558,7 +2558,7 @@ class TestVmapOperators(Namespace.TestVmapBase):
         test(lambda x: op(x, (2, 3)), (torch.rand(B0, 1, 1),))
         test(lambda x: op(x, (2, 3)), (torch.rand(1, B0, 1),), in_dims=1)
 
-    @skipIfTorchDynamo
+    @skipIfTorchDynamo()
     def test_slogdet(self):
         test = functools.partial(self._vmap_test, check_propagates_grad=False)
         B0 = 7
@@ -3869,7 +3869,7 @@ class TestVmapOperatorsOpInfo(TestCase):
         self.opinfo_vmap_test(device, torch.float, op, check_has_batch_rule=True,
                               postprocess_fn=compute_A)
 
-    @skipIfTorchDynamo
+    @skipIfTorchDynamo()
     def test_slogdet(self, device):
         # There's no OpInfo for this
         def test():
@@ -5101,7 +5101,7 @@ class TestRandomness(TestCase):
 
 @markDynamoStrictTest
 class TestTransformFailure(TestCase):
-    @skipIfTorchDynamo
+    @skipIfTorchDynamo()
     @parametrize('transform', ['vmap', 'grad', 'grad_and_value', 'vjp', 'jvp', 'jacrev', 'jacfwd'])
     def test_fails_with_autograd_function(self, device, transform):
         failed_build_envs = ('linux-focal-py3.8-clang10', 'linux-focal-py3.11-clang10')


@@ -100,26 +100,26 @@ class TestAutocastCPU(TestCase):
         else:
             return op_with_args[0], op_with_args[1], op_with_args[2]
 
-    @skipIfTorchDynamo
+    @skipIfTorchDynamo()
     def test_autocast_torch_expect_builtin_promote(self):
         for op, args1, args2, out_type in self.autocast_lists.torch_expect_builtin_promote:
             self._run_autocast_outofplace(op, args1, torch.float32, out_type=out_type)
             self._run_autocast_outofplace(op, args2, torch.float32, out_type=out_type, amp_dtype=torch.float16)
 
-    @skipIfTorchDynamo
+    @skipIfTorchDynamo()
     def test_autocast_methods_expect_builtin_promote(self):
         for op, args1, args2, out_type in self.autocast_lists.methods_expect_builtin_promote:
             self._run_autocast_outofplace(op, args1, torch.float32, module=None, out_type=out_type)
             self._run_autocast_outofplace(op, args2, torch.float32, module=None, out_type=out_type, amp_dtype=torch.float16)
 
-    @skipIfTorchDynamo
+    @skipIfTorchDynamo()
     def test_autocast_torch_16(self):
         for op_with_args in self.autocast_lists.torch_16:
             op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args)
             self._run_autocast_outofplace(op, args, torch.bfloat16, add_kwargs=maybe_kwargs)
             self._run_autocast_outofplace(op, args, torch.float16, add_kwargs=maybe_kwargs, amp_dtype=torch.float16)
 
-    @skipIfTorchDynamo
+    @skipIfTorchDynamo()
     def test_autocast_nn_16(self):
         for op_with_args in self.autocast_lists.nn_16:
             op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args)
@@ -135,14 +135,14 @@ class TestAutocastCPU(TestCase):
                 amp_dtype=torch.float16,
             )
 
-    @skipIfTorchDynamo
+    @skipIfTorchDynamo()
     def test_autocast_torch_fp32(self):
         for op_with_args in self.autocast_lists.torch_fp32:
             op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args)
             self._run_autocast_outofplace(op, args, torch.float32, add_kwargs=maybe_kwargs)
             self._run_autocast_outofplace(op, args, torch.float32, add_kwargs=maybe_kwargs, amp_dtype=torch.float16)
 
-    @skipIfTorchDynamo
+    @skipIfTorchDynamo()
     def test_autocast_nn_fp32(self):
         for op_with_args in self.autocast_lists.nn_fp32:
             op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args)
@@ -158,7 +158,7 @@ class TestAutocastCPU(TestCase):
                 amp_dtype=torch.float16,
             )
 
-    @skipIfTorchDynamo
+    @skipIfTorchDynamo()
     def test_autocast_torch_need_autocast_promote(self):
         for op, args1, args2 in self.autocast_lists.torch_need_autocast_promote:
             self._run_autocast_outofplace(op, args1, torch.float32)


@@ -3487,7 +3487,7 @@ class TestBinaryUfuncs(TestCase):
         )
         _test_helper(a, b)
 
-    @skipIfTorchDynamo  # complex infs/nans differ under Dynamo/Inductor
+    @skipIfTorchDynamo()  # complex infs/nans differ under Dynamo/Inductor
     @dtypesIfCUDA(torch.float32, torch.float64, torch.bfloat16)
     @dtypes(torch.float32, torch.float64, torch.bfloat16, torch.complex64, torch.complex128)
     def test_logaddexp(self, device, dtype):


@@ -229,7 +229,7 @@ class TestStreamWrapper(TestCase):
         for api in ['open', 'read', 'close']:
             self.assertTrue(api in s)
 
-    @skipIfTorchDynamo
+    @skipIfTorchDynamo()
     def test_api(self):
         fd = TestStreamWrapper._FakeFD("")
         wrap_fd = StreamWrapper(fd)
@@ -1408,7 +1408,7 @@ class TestFunctionalIterDataPipe(TestCase):
         _helper(lambda data: (data[0] + 1, data[1], data[2]), Add1Callable(), 0)
 
     @suppress_warnings  # Suppress warning for lambda fn
-    @skipIfTorchDynamo
+    @skipIfTorchDynamo()
     def test_map_dict_with_col_iterdatapipe(self):
         def fn_11(d):
             return -d


@@ -4051,7 +4051,7 @@ class TestAsArray(TestCase):
             t = torch.asarray(e)
             self.assertEqual(t, original)
 
-    @skipIfTorchDynamo
+    @skipIfTorchDynamo()
     @onlyCPU
     def test_numpy_scalars(self, device):
         scalar = np.float64(0.5)


@@ -267,7 +267,7 @@ class TestUnaryUfuncs(TestCase):
     # 1D tensors and a large 2D tensor with interesting and extremal values
     # and noncontiguities.
     @suppress_warnings
-    @skipIfTorchDynamo  # really flaky
+    @skipIfTorchDynamo()  # really flaky
     @ops(reference_filtered_ops)
     def test_reference_numerics_normal(self, device, dtype, op):
         tensors = generate_elementwise_unary_tensors(
@@ -276,7 +276,7 @@ class TestUnaryUfuncs(TestCase):
         self._test_reference_numerics(dtype, op, tensors)
 
     @suppress_warnings
-    @skipIfTorchDynamo  # really flaky
+    @skipIfTorchDynamo()  # really flaky
     @ops(reference_filtered_ops)
     def test_reference_numerics_small(self, device, dtype, op):
         if dtype in (torch.bool,):
@@ -288,7 +288,7 @@ class TestUnaryUfuncs(TestCase):
         self._test_reference_numerics(dtype, op, tensors)
 
     @suppress_warnings
-    @skipIfTorchDynamo  # really flaky
+    @skipIfTorchDynamo()  # really flaky
     @ops(reference_filtered_ops)
     def test_reference_numerics_large(self, device, dtype, op):
         if dtype in (torch.bool, torch.uint8, torch.int8):


@@ -395,7 +395,7 @@ class TestIndexing(TestCase):
         a = a.reshape(-1, 1)
         assert_(a[b, 0].flags.f_contiguous)
 
-    @skipIfTorchDynamo  # XXX: flaky, depends on implementation details
+    @skipIfTorchDynamo()  # XXX: flaky, depends on implementation details
     def test_small_regressions(self):
         # Reference count of intp for index checks
         a = np.array([0])


@@ -162,7 +162,7 @@ class TestScalarTypeNames(TestCase):
         """Test that names correspond to where the type is under ``np.``"""
         assert getattr(np, t.__name__) is t
 
-    @skipIfTorchDynamo  # XXX: weird, some names are not OK
+    @skipIfTorchDynamo()  # XXX: weird, some names are not OK
     @parametrize("t", numeric_types)
     def test_names_are_undersood_by_dtype(self, t):
         """Test the dtype constructor maps names back to the type"""


@@ -130,7 +130,7 @@ class TestTypes(TestCase):
         b = atype([1, 2, 3])
         assert_equal(a, b)
 
-    @skipIfTorchDynamo  # freezes under torch.Dynamo (loop unrolling, huh)
+    @skipIfTorchDynamo()  # freezes under torch.Dynamo (loop unrolling, huh)
     def test_leak(self):
         # test leak of scalar objects
         # a leak would show up in valgrind as still-reachable of ~2.6MB


@@ -2577,7 +2577,7 @@ class TestBincount(TestCase):
             lambda: np.bincount(x, minlength=-1),
         )
 
-    @skipIfTorchDynamo  # flaky test
+    @skipIfTorchDynamo()  # flaky test
     @skipif(not HAS_REFCOUNT, reason="Python lacks refcounts")
     def test_dtype_reference_leaks(self):
         # gh-6805


@@ -645,7 +645,7 @@ class TestNoExtraMethods(TestCase):
 
 
 class TestIter(TestCase):
-    @skipIfTorchDynamo
+    @skipIfTorchDynamo()
     def test_iter_1d(self):
         # numpy generates array scalars, we do 0D arrays
         a = np.arange(5)


@@ -1342,6 +1342,14 @@ def xfailIfTorchDynamo(func):
 
 
 def skipIfTorchDynamo(msg="test doesn't currently work with dynamo"):
+    """
+    Usage:
+    @skipIfTorchDynamo(msg)
+    def test_blah(self):
+        ...
+    """
+    assert isinstance(msg, str), "Are you using skipIfTorchDynamo correctly?"
+
     def decorator(fn):
         if not isinstance(fn, type):
             @wraps(fn)
@@ -1359,7 +1367,6 @@ def skipIfTorchDynamo(msg="test doesn't currently work with dynamo"):
         return fn
 
     return decorator
 
 def skipIfTorchInductor(msg="test doesn't currently work with torchinductor",
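
With the new assertion in place, the bare form fails loudly at decoration time instead of silently producing a no-op test. A minimal sketch of the behavior (skipping logic again elided):

```python
def skipIfTorchDynamo(msg="test doesn't currently work with dynamo"):
    assert isinstance(msg, str), "Are you using skipIfTorchDynamo correctly?"
    def decorator(fn):
        return fn  # real skipping logic elided
    return decorator

@skipIfTorchDynamo("known dynamo failure")  # fine: msg is a string
def test_good():
    pass

try:
    @skipIfTorchDynamo  # bare usage: the function arrives as `msg`
    def test_bad():
        pass
except AssertionError as e:
    print(e)  # Are you using skipIfTorchDynamo correctly?
```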