mirror of https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Summary: Action following https://github.com/pytorch/pytorch/issues/66232

This change does require some context: there were several suggestions regarding what to do about this group of tests, which are core and crucial to all of PyTorch and too broad to be owned by one team.

1. Add a "module: core" label and put people behind it. This idea sounds appealing unless you are one of the people backing the label. From talking to albanD among others, putting all these core tests on the shoulders of a few people or one team isn't fair, and I have not yet found anyone willing to take on this job.
2. Take advantage of the fact that we already have a triaging oncall that takes turns triaging issues: leave these tests essentially unlabeled and let the oncall triage them. Since these tests are crucial to PyTorch, we'll add the "high priority" label to mark them as different from other unowned tests (see https://github.com/pytorch/pytorch/issues/67552).
3. I could still create an unbacked "module: core" label and attribute these tests to it, but I don't like the idea of creating a facade that the tests are "triaged" to a label when no one is actually taking a look.

Now, we could potentially break these tests down into smaller files so that each piece could be owned by a team, but (1) I don't know whether this is currently feasible, and (2) this approach does not prevent that from happening in the future.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/67553
Reviewed By: albanD
Differential Revision: D32025004
Pulled By: janeyx99
fbshipit-source-id: 1fb1aa4c27e305695ab6e80ae3d02f90519939c0
377 lines
18 KiB
Python
# Owner(s): ["module: nn"]

from itertools import product
from inspect import signature, isgenerator
from copy import deepcopy
import tempfile

import torch
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_modules import module_db, modules
from torch.testing._internal.common_utils import (
    TestCase, run_tests, freeze_rng_state, mock_wrapper, get_tensors_from, gradcheck, gradgradcheck)
from unittest.mock import patch


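# Note: `module_db` (imported from common_modules) is the list of ModuleInfo entries, and the
# `@modules(...)` decorator parametrizes each test below over those entries and over the
# device/dtype combinations each entry supports, analogous to the OpInfo-based operator tests.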
class TestModule(TestCase):
    _do_cuda_memory_leak_check = True
    _do_cuda_non_default_stream = True
    precision = 1e-5
    rel_tol = 1e-5

    @modules(module_db)
    def test_forward(self, device, dtype, module_info):
        module_cls = module_info.module_cls
        module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
                                                       requires_grad=False)
        for module_input in module_inputs:
            if module_input.forward_input is None:
                continue

            with freeze_rng_state():
                # === Instantiate the module. ===
                args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
                m = module_cls(*args, **kwargs)
                m.to(device).to(dtype)

                # === Do forward pass. ===
                args, kwargs = module_input.forward_input.args, module_input.forward_input.kwargs
                outputs = m(*args, **kwargs)

                # === Compare outputs to a reference if one is specified. ===
                # TODO: Handle precision
                reference_fn = module_input.reference_fn
                if reference_fn is not None:
                    ref_outputs = reference_fn(m, *args, **kwargs)
                    self.assertEqual(outputs, ref_outputs)

    # Tests passing factory kwargs (e.g. device / dtype) during module instantiation.
    # They should be applied to any created parameters and buffers.
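    # For example, for a module that accepts factory kwargs (such as torch.nn.Linear):
    #   torch.nn.Linear(3, 4, device=device, dtype=dtype)
    # should create `weight` and `bias` directly on `device` with `dtype`, rather than creating
    # them on the defaults and requiring a separate `.to(...)` afterwards.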
    @modules(module_db)
    def test_factory_kwargs(self, device, dtype, module_info):
        module_cls = module_info.module_cls
        module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
                                                       requires_grad=False)
        for module_input in module_inputs:
            args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs

            # Check if this module creates parameters or registers buffers.
            # The mock magic here passes through to the real Parameter / register_buffer
            # logic and is only used to check call inputs.
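            # (Assumption based on its usage here: `mock_wrapper` from common_utils wraps a
            # callable so that calls still pass through to the real function while being
            # recorded on a `.mock` attribute, which is what `call_args_list` below inspects.)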
            module_creates_params_or_buffers = False
            parameter_new = mock_wrapper(torch.nn.Parameter.__new__)
            with patch.object(torch.nn.Parameter, '__new__', parameter_new):
                register_buffer = mock_wrapper(torch.nn.Module.register_buffer)
                with patch.object(torch.nn.Module, 'register_buffer', register_buffer):
                    m = module_cls(*args, **kwargs)

                    # Check if a parameter or buffer was created with a tensor not passed to the constructor.
                    constructor_tensors = get_tensors_from(args, kwargs)
                    for mock in [parameter_new.mock, register_buffer.mock]:
                        for call_args, call_kwargs in mock.call_args_list:
                            call_tensors = get_tensors_from(call_args, call_kwargs)
                            if len(call_tensors) > 0 and not constructor_tensors.intersection(call_tensors):
                                module_creates_params_or_buffers = True
                                break

            if not module_creates_params_or_buffers:
                continue

            # Instantiate module with the factory kwargs.
            kwargs.update({
                'device': device,
                'dtype': dtype,
            })

            if issubclass(module_info.module_cls, torch.nn.modules.lazy.LazyModuleMixin):
                # Ensure device and dtype are passed to all UninitializedParameters and UninitializedBuffers.
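                # Lazy modules (e.g. torch.nn.LazyLinear) hold UninitializedParameter /
                # UninitializedBuffer placeholders until the first forward pass, so the factory
                # kwargs can only be checked on those placeholder constructors here, not on the
                # final materialized parameters.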
                uninit_param_new = mock_wrapper(torch.nn.UninitializedParameter.__new__)
                with patch.object(torch.nn.UninitializedParameter, '__new__', uninit_param_new):
                    uninit_buffer_new = mock_wrapper(torch.nn.UninitializedBuffer.__new__)
                    with patch.object(torch.nn.UninitializedBuffer, '__new__', uninit_buffer_new):
                        m = module_cls(*args, **kwargs)
                        uninit_param_new.mock.assert_has_calls(
                            [mock.call(device=device, dtype=dtype) for _ in uninit_param_new.mock.mock_calls])
                        uninit_buffer_new.mock.assert_has_calls(
                            [mock.call(device=device, dtype=dtype) for _ in uninit_buffer_new.mock.mock_calls])
            else:
                # Check device placement and dtype for created parameters and buffers.
                # Only verify floating point dtypes since that's what the kwarg applies to.
                m = module_cls(*args, **kwargs)
                for name, param in m.named_parameters():
                    self.assertEqual(
                        str(param.device), device,
                        f'Parameter {name} is on {param.device.type} instead of the expected device {device}')
                    if param.dtype.is_floating_point:
                        self.assertEqual(
                            param.dtype, dtype,
                            f'Parameter {name} is of dtype {param.dtype} instead of the expected dtype {dtype}')
                for name, buffer in m.named_buffers():
                    self.assertEqual(
                        str(buffer.device), device,
                        f'Buffer {name} is on {buffer.device.type} instead of the expected device {device}')
                    if buffer.dtype.is_floating_point:
                        self.assertEqual(
                            buffer.dtype, dtype,
                            f'Buffer {name} is of dtype {buffer.dtype} instead of the expected dtype {dtype}')

    @modules(module_db)
    def test_repr(self, device, dtype, module_info):
        # Test that the module can be represented with repr and str without errors.
        module_cls = module_info.module_cls
        module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
                                                       requires_grad=False)
        for module_input in module_inputs:
            args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
            m = module_cls(*args, **kwargs)

            # Check that these methods do not raise errors.
            m.__repr__()
            str(m)

    @modules(module_db)
    def test_pickle(self, device, dtype, module_info):
        # Test that the module can be pickled and unpickled.
        module_cls = module_info.module_cls
        module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
                                                       requires_grad=False)
        for module_input in module_inputs:
            if module_input.forward_input is None:
                continue

            args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs

            with freeze_rng_state():
                # === Instantiate the module. ===
                args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
                m = module_cls(*args, **kwargs)
                m.to(device).to(dtype)

                # === Do forward pass. ===
                args, kwargs = module_input.forward_input.args, module_input.forward_input.kwargs
                output = m(*args, **kwargs)

                # === Check unpickled module gives the same output. ===
                with tempfile.TemporaryFile() as f:
                    torch.save(m, f)
                    f.seek(0)
                    m_copy = torch.load(f)
                    output_from_copy = m_copy(*args, **kwargs)
                    self.assertEqual(output, output_from_copy)

    @modules([module_info for module_info in module_db
              if 'inplace' in signature(module_info.module_cls).parameters])
    def test_check_inplace(self, device, dtype, module_info):
        # Check that the inplace variant of the module gives the same result as the
        # out-of-place variant.
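        # (Only modules whose constructor takes an `inplace` flag are selected above, e.g.
        # torch.nn.ReLU: ReLU(inplace=True) writes the result into its input tensor, while
        # ReLU(inplace=False) returns a new tensor; both must produce the same values and the
        # same gradients.)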
        module_cls = module_info.module_cls
        module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
                                                       requires_grad=True)
        for module_input in module_inputs:
            if module_input.forward_input is None:
                continue

            # === Instantiate the module. ===
            args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
            m_op = module_cls(*args, **kwargs, inplace=False)
            m_op.to(device).to(dtype)
            m_inplace = module_cls(*args, **kwargs, inplace=True)
            m_inplace.to(device).to(dtype)

            # === Inplace modules only support inplace operations on the first argument ===
            input_args, input_kwargs = module_input.forward_input.args, module_input.forward_input.kwargs

            # === Do not allow the first input to be in input_kwargs ===
            forward_sig = signature(m_op).parameters
            self.assertGreaterEqual(len(forward_sig), 1)
            first_param_name = next(iter(forward_sig.items()))[0]
            self.assertNotIn(first_param_name, input_kwargs)

            # === Out of place operation does not write to original tensor ===
            self.assertGreaterEqual(len(input_args), 1)
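            # (`_version` is the tensor's internal version counter, which autograd bumps every
            # time the tensor is mutated in place; an unchanged counter after the out-of-place
            # forward means the input was not written to.)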
            input_version = input_args[0]._version
            with freeze_rng_state():
                output_op = m_op(*input_args, **input_kwargs)
            self.assertEqual(input_args[0]._version, input_version)

            # === Check that the inplace operation gives the same result ===
            input_arg_copy = deepcopy(input_args)
            input_arg_clone = tuple(i.clone() for i in input_arg_copy)
            with freeze_rng_state():
                output_ip = m_inplace(*input_arg_clone, **input_kwargs)
            self.assertNotEqual(input_arg_clone[0]._version, input_version)
            self.assertEqual(output_op, output_ip)

            # === Check that the gradients are the same ===
            grad = output_op.data.clone().normal_()
            output_op.backward(grad)
            output_ip.backward(grad)
            self.assertEqual(input_args[0].grad, input_arg_copy[0].grad)

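    # Helper that recursively applies `func` to every Tensor / Parameter found inside
    # (possibly nested) tuples, lists, dicts, and generators, rebuilding the same container
    # structure around the results; non-tensor leaves fall through and become None.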
    def _traverse_obj(self, obj, func):
        if isinstance(obj, (tuple, list)):
            return type(obj)(self._traverse_obj(o, func) for o in obj)
        elif isgenerator(obj):
            return tuple(self._traverse_obj(o, func) for o in obj)
        elif isinstance(obj, dict):
            return {name: self._traverse_obj(o, func) for name, o in obj.items()}
        elif isinstance(obj, (torch.Tensor, torch.nn.Parameter)):
            return func(obj)

    def _retain_grad(self, obj):
        # Gradients need to be retained to check for grad. This is useful when
        # non-leaf tensors are present in the graph.
        def inner_retain_grad(obj):
            if obj.requires_grad:
                obj.retain_grad()
        self._traverse_obj(obj, inner_retain_grad)

    def _get_grads(self, obj):
        def inner_get_grad(obj):
            if obj.requires_grad:
                return obj.grad
        return self._traverse_obj(obj, inner_get_grad)

    def _zero_grad(self, obj):
        def inner_zero_grad(obj):
            if obj.grad is not None:
                obj.grad = None
        self._traverse_obj(obj, inner_zero_grad)

    @modules(module_db)
    def test_non_contiguous_tensors(self, device, dtype, module_info):
        # Check modules work with non-contiguous tensors.
        module_cls = module_info.module_cls
        module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
                                                       requires_grad=True)

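        # `_make_non_contiguous` (below) repeats each element twice along the last dim and then
        # slices every other element, which reproduces the original values but with a stride of
        # 2, i.e. a tensor that is value-equal to the input yet non-contiguous in memory.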
        def _make_non_contiguous(obj):
            def inner_make_non_contiguous(obj):
                # Scalar tensors can not be made non-contiguous.
                if not isinstance(obj, torch.Tensor) or obj.dim() == 0:
                    return obj

                out = torch.repeat_interleave(obj, 2, dim=-1)
                out = out[..., ::2].detach()
                out.requires_grad = obj.requires_grad
                return out
            return self._traverse_obj(obj, inner_make_non_contiguous)

        def _can_be_noncontiguous(obj):
            if isinstance(obj, (tuple, list)):
                return any(_can_be_noncontiguous(o) for o in obj)
            elif isinstance(obj, dict):
                return any(_can_be_noncontiguous(o) for o in obj.values())
            # Scalar tensors can not be non-contiguous.
            if not isinstance(obj, torch.Tensor) or obj.dim() == 0:
                return False
            return True

        for module_input in module_inputs:
            if module_input.forward_input is None:
                continue

            input_args, input_kwargs = module_input.forward_input.args, module_input.forward_input.kwargs
            if not (_can_be_noncontiguous(input_args) or _can_be_noncontiguous(input_kwargs)):
                continue

            # === Instantiate the module. ===
            args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
            m = module_cls(*args, **kwargs)
            m.to(device).to(dtype)

            self._retain_grad((input_args, input_kwargs))

            # === Forward with default input ===
            with freeze_rng_state():
                default_output = m(*input_args, **input_kwargs)
                grad_output = default_output.clone().detach_().normal_()
                default_output.backward(grad_output, retain_graph=True)

            default_input_args_grad, default_input_kwargs_grad = deepcopy(self._get_grads((input_args, input_kwargs)))
            default_param_grad = deepcopy([p.grad for p in m.parameters()])

            # === Construct non-contiguous tensors ===
            nc_input_args, nc_input_kwargs = _make_non_contiguous((input_args, input_kwargs))
            nc_grad_output = _make_non_contiguous(grad_output)

            # === Compare results with non-contiguous and contiguous tensors ===
            inputs = [(input_args, input_kwargs), (nc_input_args, nc_input_kwargs)]
            grads = [grad_output, nc_grad_output]

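            # (`product(inputs, grads)` covers all four combinations: contiguous / non-contiguous
            # inputs crossed with contiguous / non-contiguous grad outputs; every combination
            # must match the contiguous baseline computed above.)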
            for (in_args, in_kwargs), g_out in product(inputs, grads):
                g_out_copy = deepcopy(g_out)
                self._zero_grad((in_args, in_kwargs))
                self._zero_grad(m.parameters())

                with freeze_rng_state():
                    out = m(*in_args, **in_kwargs)
                    out.backward(g_out_copy, retain_graph=True)

                input_args_grad, input_kwargs_grad = self._get_grads((in_args, in_kwargs))
                self.assertEqual(out, default_output)
                self.assertEqual(input_args_grad, default_input_args_grad, atol=1e-4, rtol=0)
                self.assertEqual(input_kwargs_grad, default_input_kwargs_grad, atol=1e-4, rtol=0)

                param_grad = [p.grad for p in m.parameters()]
                self.assertEqual(param_grad, default_param_grad)

    def _test_gradients_helper(self, device, dtype, module_info, check):
        # Check gradients.
        module_cls = module_info.module_cls
        module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
                                                       requires_grad=True)

        for module_input in module_inputs:
            if module_input.forward_input is None:
                continue

            # === Instantiate the module. ===
            args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
            m = module_cls(*args, **kwargs)
            m.to(device).to(dtype)

            params = tuple(m.parameters())

            # === Perform gradient check on the input_args ===
            input_args, input_kwargs = module_input.forward_input.args, module_input.forward_input.kwargs

            other_kwargs = {}
            kwarg_tensors = []
            for name, obj in input_kwargs.items():
                if isinstance(obj, torch.Tensor):
                    kwarg_tensors.append((name, obj))
                else:
                    other_kwargs[name] = obj

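            # gradcheck / gradgradcheck expect a flat tuple of tensors to differentiate with
            # respect to, so positional args, module parameters, and tensor-valued kwargs are
            # concatenated here and re-packed into args/kwargs inside `fn_to_gradcheck` below.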
            grad_input = input_args + params + tuple(obj for (_, obj) in kwarg_tensors)

            def fn_to_gradcheck(*input_and_params):
                new_input_args = input_and_params[:len(input_args)]
                kwarg_args = input_and_params[-len(kwarg_tensors):]
                new_kwargs = {name: obj for (name, _), obj in zip(kwarg_tensors, kwarg_args)}

                with freeze_rng_state():
                    return m(*new_input_args, **new_kwargs, **other_kwargs)

            self.assertTrue(check(fn_to_gradcheck, grad_input))

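    # Gradient checks run in double precision only: gradcheck compares analytical gradients
    # against finite-difference estimates, which are too noisy in float32 to give a reliable
    # signal.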
    @modules(module_db, allowed_dtypes=[torch.double])
    def test_grad(self, device, dtype, module_info):
        self._test_gradients_helper(device, dtype, module_info, gradcheck)

    @modules([m for m in module_db if m.supports_gradgrad],
             allowed_dtypes=[torch.double])
    def test_gradgrad(self, device, dtype, module_info):
        self._test_gradients_helper(device, dtype, module_info, gradgradcheck)


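# Generates device-specific test classes (e.g. TestModuleCPU, TestModuleCUDA) from TestModule
# and registers them in this module's globals so the test runner can discover them.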
instantiate_device_type_tests(TestModule, globals())

if __name__ == '__main__':
    run_tests()