import tempfile import torch from torch.testing._internal.common_device_type import instantiate_device_type_tests from torch.testing._internal.common_modules import module_db, modules from torch.testing._internal.common_utils import ( TestCase, run_tests, freeze_rng_state, mock_wrapper, get_tensors_from) from unittest.mock import patch class TestModule(TestCase): _do_cuda_memory_leak_check = True _do_cuda_non_default_stream = True precision = 1e-5 rel_tol = 1e-5 @modules(module_db) def test_forward(self, device, dtype, module_info): module_cls = module_info.module_cls module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype, requires_grad=False) for module_input in module_inputs: if module_input.forward_input is None: continue with freeze_rng_state(): # === Instantiate the module. === args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs m = module_cls(*args, **kwargs) m.to(device).to(dtype) # === Do forward pass. === args, kwargs = module_input.forward_input.args, module_input.forward_input.kwargs outputs = m(*args, **kwargs) # === Compare outputs to a reference if one is specified. === # TODO: Handle precision reference_fn = module_input.reference_fn if reference_fn is not None: ref_outputs = reference_fn(m, *args, **kwargs) self.assertEqual(outputs, ref_outputs) # Tests passing factory kwargs (e.g. device / dtype) during module instantiation. # They should be applied to any created parameters and buffers. @modules(module_db) def test_factory_kwargs(self, device, dtype, module_info): module_cls = module_info.module_cls module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype, requires_grad=False) for module_input in module_inputs: args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs # Check if this module creates parameters or registers buffers. # The mock magic here passes through to the real Parameter / register_buffer # logic and is only used to check call inputs. module_creates_params_or_buffers = False parameter_new = mock_wrapper(torch.nn.Parameter.__new__) with patch.object(torch.nn.Parameter, '__new__', parameter_new): register_buffer = mock_wrapper(torch.nn.Module.register_buffer) with patch.object(torch.nn.Module, 'register_buffer', register_buffer): m = module_cls(*args, **kwargs) # Check if a parameter or buffer was created with a tensor not passed to the constructor. constructor_tensors = get_tensors_from(args, kwargs) for mock in [parameter_new.mock, register_buffer.mock]: for call_args, call_kwargs in mock.call_args_list: call_tensors = get_tensors_from(call_args, call_kwargs) if len(call_tensors) > 0 and not constructor_tensors.intersection(call_tensors): module_creates_params_or_buffers = True break if not module_creates_params_or_buffers: continue # Instantiate module with the factory kwargs. kwargs.update({ 'device': device, 'dtype': dtype, }) if issubclass(module_info.module_cls, torch.nn.modules.lazy.LazyModuleMixin): # Ensure device and dtype are passed to all UninitializedParameters and UninitializedBuffers. uninit_param_new = mock_wrapper(torch.nn.UninitializedParameter.__new__) with patch.object(torch.nn.UninitializedParameter, '__new__', uninit_param_new): uninit_buffer_new = mock_wrapper(torch.nn.UninitializedBuffer.__new__) with patch.object(torch.nn.UninitializedBuffer, '__new__', uninit_buffer_new): m = module_cls(*args, **kwargs) uninit_param_new.mock.assert_has_calls( [mock.call(device=device, dtype=dtype) for _ in uninit_param_new.mock.mock_calls]) uninit_buffer_new.mock.assert_has_calls( [mock.call(device=device, dtype=dtype) for _ in uninit_buffer_new.mock.mock_calls]) else: # Check device placement and dtype for created parameters and buffers. # Only verify floating point dtypes since that's what the kwarg applies to. m = module_cls(*args, **kwargs) for name, param in m.named_parameters(): self.assertEqual( str(param.device), device, f'Parameter {name} is on {param.device.type} instead of the expected device {device}') if param.dtype.is_floating_point: self.assertEqual( param.dtype, dtype, f'Parameter {name} is of dtype {param.dtype} instead of the expected dtype {dtype}') for name, buffer in m.named_buffers(): self.assertEqual( str(buffer.device), device, f'Buffer {name} is on {buffer.device.type} instead of the expected device {device}') if buffer.dtype.is_floating_point: self.assertEqual( buffer.dtype, dtype, f'Buffer {name} is of dtype {buffer.dtype} instead of the expected dtype {dtype}') @modules(module_db) def test_pickle(self, device, dtype, module_info): # Test that module can be pickled and unpickled. module_cls = module_info.module_cls module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype, requires_grad=False) for module_input in module_inputs: if module_input.forward_input is None: continue args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs with freeze_rng_state(): # === Instantiate the module. === args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs m = module_cls(*args, **kwargs) m.to(device).to(dtype) # === Do forward pass. === args, kwargs = module_input.forward_input.args, module_input.forward_input.kwargs output = m(*args, **kwargs) # === Check unpickled module gives the same output. === with tempfile.TemporaryFile() as f: torch.save(m, f) f.seek(0) m_copy = torch.load(f) output_from_copy = m_copy(*args, **kwargs) self.assertEqual(output, output_from_copy) instantiate_device_type_tests(TestModule, globals()) if __name__ == '__main__': run_tests()