Fix setUpClass() / tearDownClass() for device-specific tests (#151129)

Finishes up the work started in #121686 + adds test

Update: this was not as straightforward as I originally imagined. Context below.

**TL;DR:** `TestFoo{CPU, CUDA}` now actually derive from `TestFoo`! Also, `{CPU, CUDA}TestBase` setup / teardown logic is now always called (it is required to set the primary device), regardless of whether `super().setUpClass()` / `super().tearDownClass()` are called or not.

**Background:** The typical way to get device-specific tests is to write a generic `TestFoo` and call `instantiate_device_type_tests(TestFoo, locals())` to get `TestFooCPU`, `TestFooCUDA`, etc. After this, generic tests (e.g. `TestFoo.test_bar()`) become `TestFooCPU.test_bar_cpu()` / `TestFooCUDA.test_bar_cuda()`.

Behind the scenes, this was historically accomplished by creating a `TestFooCUDA` that derives from both a `CUDATestBase` and an *empty class* called `TestFoo_base`. This `TestFoo_base` has the same bases as `TestFoo`, but none of the test functions (e.g. `test_bar()`). The documented reason for this is to avoid things like a derived `TestFooCUDA.test_bar()` being discovered in addition to the real device-specific test `TestFooCUDA.test_bar_cuda()`.

(1) A reason this matters is because it should be possible to call e.g. `super().setUpClass()` from a custom setup / teardown classmethod. If the generated TestFooCUDA does not derive from TestFoo, but instead derives from the empty class described above, this syntax does not work; in fact there is no way to form a proper `super()` call that works across the device-specific test variants. Here's an example that breaks in the OpInfo tests:

070f389745/test/test_ops.py (L218-L221)

(2) Further, there is some precedent within a custom `setUpClass()` impl for storing things on the `cls` object to be accessed at test time. This must be the device-specific test class (`TestFooCUDA`) and not `TestFoo` for this to work. As an example, the open device registration tests load a module during setup and use it in the test logic:

070f389745/test/test_cpp_extensions_open_device_registration.py (L63-L77)

070f389745/test/test_cpp_extensions_open_device_registration.py (L79-L80)

To accomplish both (1) and (2) at the same time, I decided to revisit the idea of utilizing a proper inheritance hierarchy for `TestFoo` -> `{TestFooCPU, TestFooCUDA}`. That is: have TestFooCPU / TestFooCUDA **actually** derive from `TestFoo`. This achieves both (1) and (2). The only thing left is to make sure the generic tests (e.g. `TestFoo.test_bar()`) are not discoverable, as was the stated reason for diverging from this in the first place. It turns out we can simply `delattr()` these generic tests from `TestFoo` once `TestFooCPU` / `TestFooCUDA` have been setup with the device-specific variants, and all works well. The `instantiate_device_type_tests(...)` logic already deletes `TestFoo` from scope, so I don't see a problem with deleting generic tests from this base class as well (CI will prove me right or wrong ofc).

**Side note:** I was encountering a weird race condition where sometimes the custom `setUpClass()` / `tearDownClass()` defined & swapped in [here](4a47dd9b3f/torch/testing/_internal/common_device_type.py (L940-L955)) would be used, and sometimes it wouldn't. This non-deterministic behavior was called out previously by @ngimel here:
4a47dd9b3f/test/inductor/test_torchinductor_dynamic_shapes.py (L128-L130)

To address this, I moved this block of logic to before the first call to `instantiate_test()`, as that method queries for the primary device, and the primary device identification logic may manually invoke `setUpClass()` (see [here](4a47dd9b3f/torch/testing/_internal/common_device_type.py (L381-L384))). Goal: define the `setUpClass()` / `tearDownClass()` we want for correctness before they're ever called. This seems to work and the behavior is deterministic now AFAICT.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151129
Approved by: https://github.com/janeyx99, https://github.com/masnesral, https://github.com/malfet
This commit is contained in:
Joel Schlosser 2025-04-15 18:14:41 -04:00 committed by PyTorch MergeBot
parent 067a7b1d4a
commit ae53510b9e
7 changed files with 77 additions and 74 deletions

View File

@ -327,7 +327,7 @@ def batch_reserve(paged_attention: PagedAttention, target_seq_len: Tensor):
class TestFlexAttention(InductorTestCase):
def setUp(self):
super(self.__class__, self).setUp()
super().setUp()
self.test_inference_only = False
if test_device[0] == "cpu":
if LONG_COMPILATION_ON_CPU:
@ -3933,7 +3933,7 @@ class GraphModule(torch.nn.Module):
class TestBlockMask(InductorTestCase):
def setUp(self):
super(self.__class__, self).setUp()
super().setUp()
if test_device[0] == "cpu":
self.skipTest(
"skip UT for CPUs as 'BlockMask' is common and covered on CUDA"
@ -4496,7 +4496,7 @@ BlockMask(shape=(1,s1,s2048,s2048),ssparsity=46.88%,s
class TestPagedAttention(InductorTestCase):
def setUp(self):
super(self.__class__, self).setUp()
super().setUp()
if test_device[0] == "cpu":
if LONG_COMPILATION_ON_CPU:
self.skipTest(

View File

@ -277,7 +277,7 @@ def batch_reserve(paged_attention: PagedAttention, target_seq_len: Tensor):
class TestFlexDecoding(InductorTestCase):
def setUp(self):
super(self.__class__, self).setUp()
super().setUp()
self.test_inference_only = False
if test_device[0] == "cpu":
if LONG_COMPILATION_ON_CPU:

View File

@ -125,12 +125,12 @@ class TestLinalg(TestCase):
del os.environ["HIPBLASLT_ALLOW_TF32"]
def setUp(self):
super(self.__class__, self).setUp()
super().setUp()
torch.backends.cuda.matmul.allow_tf32 = False
def tearDown(self):
torch.backends.cuda.matmul.allow_tf32 = True
super(self.__class__, self).tearDown()
super().tearDown()
@contextlib.contextmanager
def _tunableop_ctx(self):

View File

@ -64,12 +64,12 @@ assert torch.get_default_dtype() is torch.float32
@unittest.skipIf(IS_ARM64, "Issue with numpy version on arm")
class TestMatmulCuda(TestCase):
def setUp(self):
super(self.__class__, self).setUp()
super().setUp()
torch.backends.cuda.matmul.allow_tf32 = False
def tearDown(self):
torch.backends.cuda.matmul.allow_tf32 = True
super(self.__class__, self).tearDown()
super().tearDown()
def cublas_addmm(self, size: int, dtype: torch.dtype, reduced_precision: bool = False, fp16_accumulate: bool = False):
#

View File

@ -859,7 +859,7 @@ class TestSparseSemiStructuredCUTLASS(TestCase):
def tearDown(self):
SparseSemiStructuredTensor._FORCE_CUTLASS = False
super(self.__class__, self).tearDown()
super().tearDown()
@unittest.skipIf(TEST_WITH_ROCM or IS_WINDOWS, "ROCm and Windows doesn't support CUTLASS")
@inference_dtypes

View File

@ -440,6 +440,40 @@ if __name__ == '__main__':
op.supported_dtypes(torch.device("cuda", index=1)),
)
def test_setup_and_teardown_run_for_device_specific_tests(self, device):
# TODO: Move this (and other similar text blocks) to some fixtures/ subdir
stderr = TestCase.runWithPytorchAPIUsageStderr(f"""\
#!/usr/bin/env python3
import torch
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_utils import TestCase, run_tests
class TestFoo(TestCase):
@classmethod
def setUpClass(cls):
# store something on the test class to query during teardown
cls.stored_thing = "called with " + cls.__name__
@classmethod
def tearDownClass(cls):
# throw here so we know teardown was run
raise RuntimeError(cls.stored_thing)
def test_bar(self, device):
# make sure the test can access the stored thing
print(self.stored_thing)
instantiate_device_type_tests(TestFoo, globals(), only_for='{self.device_type}')
if __name__ == '__main__':
run_tests()
""")
expected_device_class_name = f"TestFoo{self.device_type.upper()}"
expected_error_text = f"RuntimeError: called with {expected_device_class_name}"
self.assertIn(expected_error_text, stderr)
instantiate_device_type_tests(TestTesting, globals())

View File

@ -281,35 +281,6 @@ except ModuleNotFoundError:
# they are run. This makes it useful for initializing devices and dependencies.
# Note [Overriding methods in generic tests]
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# Device generic tests look a lot like normal test classes, but they differ
# from ordinary classes in some important ways. In particular, overriding
# methods in generic tests doesn't work quite the way you expect.
#
# class TestFooDeviceType(TestCase):
# # Intention is to override
# def assertEqual(self, x, y):
# # This DOESN'T WORK!
# super().assertEqual(x, y)
#
# If you try to run this code, you'll get an error saying that TestFooDeviceType
# is not in scope. This is because after instantiating our classes, we delete
# it from the parent scope. Instead, you need to hardcode a direct invocation
# of the desired subclass call, e.g.,
#
# class TestFooDeviceType(TestCase):
# # Intention is to override
# def assertEqual(self, x, y):
# TestCase.assertEqual(x, y)
#
# However, a less error-prone way of customizing the behavior of TestCase
# is to either (1) add your functionality to TestCase and make it toggled
# by a class attribute, or (2) create your own subclass of TestCase, and
# then inherit from it for your generic test.
def _dtype_test_suffix(dtypes):
"""Returns the test suffix for a dtype, sequence of dtypes, or None."""
if isinstance(dtypes, (list, tuple)):
@ -893,20 +864,7 @@ def instantiate_device_type_tests(
# are not discoverable.
del scope[generic_test_class.__name__]
# Creates an 'empty' version of the generic_test_class
# Note: we don't inherit from the generic_test_class directly because
# that would add its tests to our test classes and they would be
# discovered (despite not being runnable). Inherited methods also
# can't be removed later, and we can't rely on load_tests because
# pytest doesn't support it (as of this writing).
empty_name = generic_test_class.__name__ + "_base"
empty_class = type(empty_name, generic_test_class.__bases__, {})
# Acquires members names
# See Note [Overriding methods in generic tests]
generic_members = set(generic_test_class.__dict__.keys()) - set(
empty_class.__dict__.keys()
)
generic_members = set(generic_test_class.__dict__.keys())
generic_tests = [x for x in generic_members if x.startswith("test")]
# Creates device-specific test cases
@ -917,7 +875,30 @@ def instantiate_device_type_tests(
# type set to Any and suppressed due to unsupport runtime class:
# https://github.com/python/mypy/wiki/Unsupported-Python-Features
device_type_test_class: Any = type(class_name, (base, empty_class), {})
device_type_test_class: Any = type(class_name, (base, generic_test_class), {})
# Arrange for setUpClass and tearDownClass methods defined both in the test template
# class and in the generic base to be called. This allows device-parameterized test
# classes to support setup and teardown.
# NB: This should be done before instantiate_test() is called as that invokes setup.
@classmethod
def _setUpClass(cls):
# This should always be called, whether or not the test class invokes
# super().setUpClass(), to set the primary device.
base.setUpClass()
# We want to call the @classmethod defined in the generic base, but pass
# it the device-specific class object (cls), hence the __func__ call.
generic_test_class.setUpClass.__func__(cls)
@classmethod
def _tearDownClass(cls):
# We want to call the @classmethod defined in the generic base, but pass
# it the device-specific class object (cls), hence the __func__ call.
generic_test_class.tearDownClass.__func__(cls)
base.tearDownClass()
device_type_test_class.setUpClass = _setUpClass
device_type_test_class.tearDownClass = _tearDownClass
for name in generic_members:
if name in generic_tests: # Instantiates test member
@ -931,30 +912,11 @@ def instantiate_device_type_tests(
)
else:
device_type_test_class.instantiate_test(name, copy.deepcopy(test))
else: # Ports non-test member
assert (
name not in device_type_test_class.__dict__
), f"Redefinition of directly defined member {name}"
# Ports non-test member. Setup / teardown have already been handled above
elif name not in device_type_test_class.__dict__:
nontest = getattr(generic_test_class, name)
setattr(device_type_test_class, name, nontest)
# The dynamically-created test class derives from the test template class
# and the empty class. Arrange for both setUpClass and tearDownClass methods
# to be called. This allows the parameterized test classes to support setup
# and teardown.
@classmethod
def _setUpClass(cls):
base.setUpClass()
empty_class.setUpClass()
@classmethod
def _tearDownClass(cls):
empty_class.tearDownClass()
base.tearDownClass()
device_type_test_class.setUpClass = _setUpClass
device_type_test_class.tearDownClass = _tearDownClass
# Mimics defining the instantiated class in the caller's file
# by setting its module to the given class's and adding
# the module to the given scope.
@ -962,6 +924,13 @@ def instantiate_device_type_tests(
device_type_test_class.__module__ = generic_test_class.__module__
scope[class_name] = device_type_test_class
# Delete the generic form of the test functions (e.g. TestFoo.test_bar()) so they're
# not discoverable. This mutates the original class (TestFoo), which was removed from
# scope above. At this point, device-specific tests (e.g. TestFooCUDA.test_bar_cuda)
# have already been created and the generic forms are no longer needed.
for name in generic_tests:
delattr(generic_test_class, name)
# Category of dtypes to run an OpInfo-based test for
# Example use: @ops(dtype=OpDTypes.supported)