From d3dba3c42a06ab9b50f03e436e15f93a27bbac8d Mon Sep 17 00:00:00 2001
From: Joel Benjamin Schlosser
Date: Fri, 8 Jul 2022 11:38:57 -0400
Subject: [PATCH] Fix ModuleInfo skip logic (#80471)

Fixes #80247

This PR:
* Refactors the skip logic as done for OpInfo in #62713, fixing the logic error
* For tests that were wrongly skipped before and now fail:
    * Fixes `TestModule.test_cpu_gpu_parity` to support Lazy modules - this was affecting `LazyConv*`
    * Adds `@expectedFailure` decorators and a follow-up message to address `Conv*` failures on `TestModule.test_memory_format`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/80471
Approved by: https://github.com/mruberry
---
 test/test_modules.py                      |  6 +++
 torch/testing/_internal/common_modules.py | 62 ++++++++++++++++-------
 2 files changed, 49 insertions(+), 19 deletions(-)

diff --git a/test/test_modules.py b/test/test_modules.py
index 3ed5f3be76f..a62bfff8de6 100644
--- a/test/test_modules.py
+++ b/test/test_modules.py
@@ -533,6 +533,12 @@ class TestModule(TestCase):
             gpu_module = module_cls(*args, **kwargs).to(dtype).to(device)
             gpu_module.train(training)
 
+            # === Lazy modules need to see an input to initialize params ===
+            if issubclass(module_cls, torch.nn.modules.lazy.LazyModuleMixin):
+                with torch.no_grad():
+                    cpu_module(*cpu_forward_args, **cpu_forward_kwargs)
+                    gpu_module(*gpu_forward_args, **gpu_forward_kwargs)
+
             for cpu_p, gpu_p in zip(cpu_module.parameters(), gpu_module.parameters()):
                 gpu_p.data.copy_(cpu_p)
 
diff --git a/torch/testing/_internal/common_modules.py b/torch/testing/_internal/common_modules.py
index c3138dd6fd3..7a0d60b8024 100644
--- a/torch/testing/_internal/common_modules.py
+++ b/torch/testing/_internal/common_modules.py
@@ -10,7 +10,7 @@ from torch.testing import make_tensor
 from torch.testing._internal.common_cuda import TEST_CUDNN
 from torch.testing._internal.common_dtype import floating_types
 from torch.testing._internal.common_device_type import (
-    _TestParametrizer, _update_param_kwargs, skipIf, toleranceOverride, tol,
+    _TestParametrizer, _update_param_kwargs, toleranceOverride, tol,
     skipCUDAIfCudnnVersionLessThan, skipCUDAIfRocm, precisionOverride, skipMeta)
 from torch.testing._internal.common_methods_invocations import DecorateInfo
 from torch.testing._internal.common_nn import nllloss_reference, get_reduction
@@ -104,25 +104,13 @@ class modules(_TestParametrizer):
                 _update_param_kwargs(param_kwargs, 'training', training)
 
                 try:
-                    active_decorators = [set_single_threaded_if_parallel_tbb]
-                    if module_info.should_skip(generic_cls.__name__, test.__name__, device_cls.device_type, dtype):
-                        active_decorators.append(skipIf(True, "Skipped!"))
-
-                    if module_info.decorators is not None:
-                        for decorator in module_info.decorators:
-                            # Can't use isinstance as it would cause a circular import
-                            if decorator.__class__.__name__ == 'DecorateInfo':
-                                if decorator.is_active(generic_cls.__name__, test.__name__,
-                                                       device_cls.device_type, dtype):
-                                    active_decorators += decorator.decorators
-                            else:
-                                active_decorators.append(decorator)
 
                     @wraps(test)
                     def test_wrapper(*args, **kwargs):
                         return test(*args, **kwargs)
 
-                    for decorator in active_decorators:
+                    for decorator in module_info.get_decorators(generic_cls.__name__, test.__name__,
+                                                                device_cls.device_type, dtype):
                         test_wrapper = decorator(test_wrapper)
 
                     yield (test_wrapper, test_name, param_kwargs)
@@ -187,16 +175,22 @@ class ModuleInfo(object):
                  ):
         self.module_cls = module_cls
         self.module_inputs_func = module_inputs_func
-        self.skips = skips
-        self.decorators = decorators
+        self.decorators = (*(decorators if decorators else []), *(skips if skips else []))
         self.dtypes = dtypes
         self.supports_gradgrad = supports_gradgrad
         self.gradcheck_nondet_tol = gradcheck_nondet_tol
         self.module_memformat_affects_out = module_memformat_affects_out
         self.train_and_eval_differ = train_and_eval_differ
 
-    def should_skip(self, cls_name, test_name, device_type, dtype):
-        return any(si.is_active(cls_name, test_name, device_type, dtype) for si in self.skips)
+    def get_decorators(self, test_class, test_name, device, dtype):
+        result = [set_single_threaded_if_parallel_tbb]
+        for decorator in self.decorators:
+            if isinstance(decorator, DecorateInfo):
+                if decorator.is_active(test_class, test_name, device, dtype):
+                    result.extend(decorator.decorators)
+            else:
+                result.append(decorator)
+        return result
 
     @property
     def name(self):
@@ -1094,6 +1088,10 @@ module_db: List[ModuleInfo] = [
                   # Failure on ROCM for float32 issue #70125
                   DecorateInfo(skipCUDAIfRocm, 'TestModule', 'test_memory_format', dtypes=[torch.float32]),
                   DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),
+                  # This was wrongly being skipped before and needs investigation.
+                  # See https://github.com/pytorch/pytorch/issues/80247
+                  DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format",
+                               device_type='cuda', dtypes=[torch.float64]),
               ),
               decorators=(
                   DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'),
@@ -1108,6 +1106,9 @@ module_db: List[ModuleInfo] = [
                   # Failure on ROCM for float32 issue #70125
                   DecorateInfo(skipCUDAIfRocm, 'TestModule', 'test_memory_format', dtypes=[torch.float32]),
                   DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),
+                  # This was wrongly being skipped before and needs investigation.
+                  # See https://github.com/pytorch/pytorch/issues/80247
+                  DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format"),
               ),
               decorators=(
                   DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'),
@@ -1136,6 +1137,11 @@ module_db: List[ModuleInfo] = [
                   # Failure on ROCM for float32 issue #70125
                   DecorateInfo(skipCUDAIfRocm, 'TestModule', 'test_memory_format', dtypes=[torch.float32]),
                   DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),
+                  # This was wrongly being skipped before and needs investigation.
+                  # See https://github.com/pytorch/pytorch/issues/80247
+                  DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format", device_type='cpu'),
+                  DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format", device_type='cuda',
+                               dtypes=[torch.float64]),
               ),
               decorators=(
                   DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'),
@@ -1150,6 +1156,9 @@ module_db: List[ModuleInfo] = [
                   # Failure on ROCM for float32 issue #70125
                   DecorateInfo(skipCUDAIfRocm, 'TestModule', 'test_memory_format', dtypes=[torch.float32]),
                   DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),
+                  # This was wrongly being skipped before and needs investigation.
+                  # See https://github.com/pytorch/pytorch/issues/80247
+                  DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format"),
               ),
               decorators=(
                   DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'),
@@ -1196,6 +1205,10 @@ module_db: List[ModuleInfo] = [
                   # Lazy modules don't currently play well with ModuleInfo tests on the meta device.
                   # See https://github.com/pytorch/pytorch/issues/70505 for more info.
                   DecorateInfo(skipMeta),
+                  # This was wrongly being skipped before and needs investigation.
+                  # See https://github.com/pytorch/pytorch/issues/80247
+                  DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format",
+                               device_type='cuda', dtypes=[torch.float64]),
               ),
               decorators=(
                   DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'),
@@ -1213,6 +1226,9 @@ module_db: List[ModuleInfo] = [
                   # See https://github.com/pytorch/pytorch/issues/70505 for more info.
                   DecorateInfo(skipMeta),
                   DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),
+                  # This was wrongly being skipped before and needs investigation.
+                  # See https://github.com/pytorch/pytorch/issues/80247
+                  DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format"),
               ),
               decorators=(
                   DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'),
@@ -1247,6 +1263,11 @@ module_db: List[ModuleInfo] = [
                   # See https://github.com/pytorch/pytorch/issues/70505 for more info.
                   DecorateInfo(skipMeta),
                   DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),
+                  # This was wrongly being skipped before and needs investigation.
+                  # See https://github.com/pytorch/pytorch/issues/80247
+                  DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format", device_type='cpu'),
+                  DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format", device_type='cuda',
+                               dtypes=[torch.float64]),
               ),
               decorators=(
                   DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'),
@@ -1264,6 +1285,9 @@ module_db: List[ModuleInfo] = [
                   # See https://github.com/pytorch/pytorch/issues/70505 for more info.
                   DecorateInfo(skipMeta),
                   DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),
+                  # This was wrongly being skipped before and needs investigation.
+                  # See https://github.com/pytorch/pytorch/issues/80247
+                  DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format"),
               ),
               decorators=(
                   DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'),
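
To make the consolidated skip handling concrete, the sketch below mirrors the pattern the patch adopts: skips are stored as ordinary decorator entries, and a single `get_decorators` pass decides which of them apply for a given test class, test name, device, and dtype. `MiniDecorateInfo` and `MiniModuleInfo` are hypothetical stand-ins, not the real `torch.testing._internal` classes; the real `get_decorators` also seeds its result with `set_single_threaded_if_parallel_tbb`, which is omitted here.

```python
# Illustrative sketch only -- not part of the patch.
import unittest


class MiniDecorateInfo:
    """Associates decorators with a (test class, test name, device, dtype) filter."""

    def __init__(self, decorators, cls_name=None, test_name=None, *,
                 device_type=None, dtypes=None):
        self.decorators = list(decorators) if isinstance(decorators, (list, tuple)) else [decorators]
        self.cls_name = cls_name
        self.test_name = test_name
        self.device_type = device_type
        self.dtypes = dtypes

    def is_active(self, cls_name, test_name, device_type, dtype):
        # Each filter field is a wildcard when left as None.
        return ((self.cls_name is None or self.cls_name == cls_name) and
                (self.test_name is None or self.test_name == test_name) and
                (self.device_type is None or self.device_type == device_type) and
                (self.dtypes is None or dtype in self.dtypes))


class MiniModuleInfo:
    def __init__(self, module_cls, *, skips=(), decorators=None):
        self.module_cls = module_cls
        # Same idea as the patch: skips are just decorators, stored in one tuple.
        self.decorators = (*(decorators if decorators else []), *(skips if skips else []))

    def get_decorators(self, test_class, test_name, device, dtype):
        result = []
        for decorator in self.decorators:
            if isinstance(decorator, MiniDecorateInfo):
                # Conditional entries only contribute when their filter matches.
                if decorator.is_active(test_class, test_name, device, dtype):
                    result.extend(decorator.decorators)
            else:
                # Plain decorators always apply.
                result.append(decorator)
        return result


# Example: an expected-failure entry fires only for the matching device/dtype.
info = MiniModuleInfo(object, skips=(
    MiniDecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format",
                     device_type="cuda", dtypes=["float64"]),
))
assert info.get_decorators("TestModule", "test_memory_format", "cuda", "float64") == [unittest.expectedFailure]
assert info.get_decorators("TestModule", "test_memory_format", "cpu", "float64") == []
```

The point of the consolidation is that skips and expected failures now go through exactly the same `is_active` check as every other decorator, so a device- or dtype-restricted skip can no longer be applied unconditionally, which was the logic error behind #80247.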