from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_modules import module_db, modules
from torch.testing._internal.common_utils import TestCase, run_tests, freeze_rng_state


class TestModule(TestCase):
    """Device-generic template for tests over the entries of ``module_db``.

    The class is handed to ``instantiate_device_type_tests`` at the bottom of
    this file rather than being run directly.
    """

    # Harness flags read by TestCase; by their names they enable CUDA
    # memory-leak checking and running on a non-default CUDA stream.
    _do_cuda_memory_leak_check = True
    _do_cuda_non_default_stream = True
    # Comparison tolerances (presumably consulted by assertEqual -- confirm).
    precision = 1e-5
    rel_tol = 1e-5

    @modules(module_db)
    def test_forward(self, device, dtype, module_info):
        """Check each sample's forward pass against its reference, if it has one.

        Samples come from ``module_info.module_inputs_func`` (built with
        ``requires_grad=False``); samples whose ``forward_input`` is None are
        skipped. For every remaining sample a fresh module is constructed from
        the sample's constructor arguments, moved to ``device`` and then to
        ``dtype``, and called with the sample's forward arguments. When the
        sample supplies a ``reference_fn``, that function is called with the
        module and the same forward arguments, and its result must equal the
        module's output.
        """
        cls = module_info.module_cls
        samples = module_info.module_inputs_func(
            module_info, device=device, dtype=dtype, requires_grad=False,
        )

        # Lazily drop samples that have nothing to feed into forward().
        runnable = (s for s in samples if s.forward_input is not None)
        for sample in runnable:
            # Construction, forward and the reference check all run inside
            # freeze_rng_state() (by its name, the RNG state is saved on entry
            # and restored on exit -- confirm in common_utils).
            with freeze_rng_state():
                ctor = sample.constructor_input
                ctor_args, ctor_kwargs = ctor.args, ctor.kwargs
                module = cls(*ctor_args, **ctor_kwargs)
                module.to(device).to(dtype)

                fwd = sample.forward_input
                fwd_args, fwd_kwargs = fwd.args, fwd.kwargs
                output = module(*fwd_args, **fwd_kwargs)

                # TODO: Handle precision
                ref_fn = sample.reference_fn
                if ref_fn is not None:
                    expected = ref_fn(module, *fwd_args, **fwd_kwargs)
                    self.assertEqual(output, expected)


instantiate_device_type_tests(TestModule, globals())

if __name__ == '__main__':
    run_tests()