# Owner(s): ["module: unknown"] from functools import partial, wraps from itertools import chain import torch from torch.testing._internal.common_utils import \ (TestCase, is_iterable_of_tensors, run_tests, gradcheck, gradgradcheck) from torch.testing._internal.common_methods_invocations import op_db from torch.testing._internal.common_device_type import \ (instantiate_device_type_tests, ops, OpDTypes) # TODO: fixme https://github.com/pytorch/pytorch/issues/68972 torch.set_default_dtype(torch.float32) # gradcheck requires double precision _gradcheck_ops = partial(ops, dtypes=OpDTypes.supported, allowed_dtypes=[torch.double, torch.cdouble]) class TestGradients(TestCase): exact_dtype = True # Copies inputs to inplace operations to avoid inplace modifications # to leaves requiring gradient def _get_safe_inplace(self, inplace_variant): @wraps(inplace_variant) def _fn(t, *args, **kwargs): return inplace_variant(t.clone(), *args, **kwargs) return _fn def _check_helper(self, device, dtype, op, variant, check, *, check_forward_ad=False, check_backward_ad=True, check_batched_grad=None, check_batched_forward_grad=False): assert check in ('gradcheck', 'bwgrad_bwgrad', 'fwgrad_bwgrad') # NB: check_backward_ad does not affect gradgradcheck (always True) if variant is None: self.skipTest("Skipped! Variant not implemented.") if not op.supports_dtype(dtype, torch.device(device).type): self.skipTest(f"Skipped! {op.name} does not support dtype {str(dtype)}") def is_inplace(variant): if hasattr(variant, "__wrapped__"): return variant.__wrapped__ is op.get_inplace() return variant is op.get_inplace() include_conjugated_inputs = op.test_conjugated_samples and dtype.is_complex samples = op.sample_inputs(device, dtype, requires_grad=True, include_conjugated_inputs=include_conjugated_inputs) for sample in samples: if sample.broadcasts_input and is_inplace(variant): continue # Gradcheck expects tensors as its input, but autograd actually supports tensorlists # and tensors passed as kwargs. The following creates a function that accepts just # the tensors that require grad as varargs, and then recomposes them back into the # original input. 
            # Creates gradcheck inputs by identifying tensors requiring grad
            all_args = None
            if is_iterable_of_tensors(sample.input):
                all_args = chain(sample.input, sample.args, sample.kwargs.values())
            else:
                all_args = tuple(chain((sample.input,), sample.args, sample.kwargs.values()))
            gradcheck_args = tuple(x for x in all_args if (isinstance(x, torch.Tensor) and x.requires_grad))

            def _input_recomposition_helper(inputs, inp, input_idx):
                if is_iterable_of_tensors(inp):
                    tensor_list = []
                    for x in inp:
                        if isinstance(x, torch.Tensor) and x.requires_grad:
                            tensor_list.append(inputs[input_idx])
                            input_idx = input_idx + 1
                        else:
                            tensor_list.append(x)
                    return tensor_list, input_idx
                elif isinstance(inp, torch.Tensor) and inp.requires_grad:
                    return inputs[input_idx], input_idx + 1
                else:
                    return inp, input_idx

            def fn(*inputs):
                # Puts inputs back into sample properly
                positional_args = []
                input_idx = 0
                inp, input_idx = _input_recomposition_helper(inputs, sample.input, input_idx)
                positional_args.append(inp)

                for x in sample.args:
                    inp, input_idx = _input_recomposition_helper(inputs, x, input_idx)
                    positional_args.append(inp)

                # Recreates kwargs
                kwargs = {}
                for k, v in sample.kwargs.items():
                    inp, input_idx = _input_recomposition_helper(inputs, v, input_idx)
                    kwargs[k] = inp

                output = op.gradcheck_wrapper(variant, *positional_args, **kwargs)
                if sample.output_process_fn_grad is not None:
                    return sample.output_process_fn_grad(output)
                return output

            if check == 'gradcheck':
                if check_batched_grad is None:
                    check_batched_grad = op.check_batched_grad
                self.assertTrue(gradcheck(fn, gradcheck_args,
                                          check_batched_grad=check_batched_grad,
                                          check_grad_dtypes=True,
                                          nondet_tol=op.gradcheck_nondet_tol,
                                          fast_mode=op.gradcheck_fast_mode,
                                          check_forward_ad=check_forward_ad,
                                          check_backward_ad=check_backward_ad,
                                          check_undefined_grad=True,
                                          check_batched_forward_grad=check_batched_forward_grad))
            elif check in ('bwgrad_bwgrad', 'fwgrad_bwgrad'):  # gradgrad check
                self.assertFalse(check_forward_ad, msg="Cannot run forward AD check for gradgradcheck")
                for gen_non_contig_grad_outputs in (False, True):
                    kwargs = {
                        "gen_non_contig_grad_outputs": gen_non_contig_grad_outputs,
                        "check_batched_grad": op.check_batched_gradgrad,
                        "check_grad_dtypes": True,
                        "nondet_tol": op.gradcheck_nondet_tol,
                        "fast_mode": op.gradcheck_fast_mode
                    }
                    if check == "fwgrad_bwgrad":
                        kwargs["check_fwd_over_rev"] = True
                        kwargs["check_rev_over_rev"] = False
                        kwargs["check_batched_grad"] = False
                        kwargs["check_undefined_grad"] = False

                    self.assertTrue(gradgradcheck(fn, gradcheck_args, **kwargs))
            else:
                self.assertTrue(False, msg="Unknown check requested!")

    def _grad_test_helper(self, device, dtype, op, variant, *, check_forward_ad=False, check_backward_ad=True,
                          check_batched_grad=None, check_batched_forward_grad=False):
        return self._check_helper(device, dtype, op, variant, 'gradcheck', check_forward_ad=check_forward_ad,
                                  check_backward_ad=check_backward_ad, check_batched_grad=check_batched_grad,
                                  check_batched_forward_grad=check_batched_forward_grad)

    def _skip_helper(self, op, device, dtype):
        if dtype not in op.supported_backward_dtypes(torch.device(device).type):
            self.skipTest("Skipped! Op doesn't support autograd for this dtype.")
        if not op.supports_autograd and not op.supports_forward_ad:
            self.skipTest("Skipped! autograd not supported.")

    # Tests that gradients are computed correctly
    @_gradcheck_ops(op_db)
    def test_fn_grad(self, device, dtype, op):
        # This is verified by test_dtypes in test_ops.py
        if dtype not in op.supported_backward_dtypes(torch.device(device).type):
            self.skipTest("Skipped! Dtype is not in supported backward dtypes!")
        else:
            self._grad_test_helper(device, dtype, op, op.get_op())
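    # Usage note (the file name below is an assumption; adjust to this module's actual name):
    # instantiate_device_type_tests at the bottom of this file generates per-device classes
    # (e.g. TestGradientsCPU) whose test names embed the op, device, and dtype, so a single
    # case can typically be run directly, for example:
    #   python test_ops_gradients.py TestGradientsCPU.test_fn_grad_add_cpu_float64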
    # Method grad (and gradgrad, see below) tests are disabled since they're
    # costly and redundant with function grad (and gradgrad) tests
    # @_gradcheck_ops(op_db)
    # def test_method_grad(self, device, dtype, op):
    #     self._skip_helper(op, device, dtype)
    #     self._grad_test_helper(device, dtype, op, op.get_method())

    @_gradcheck_ops(op_db)
    def test_inplace_grad(self, device, dtype, op):
        self._skip_helper(op, device, dtype)
        if not op.inplace_variant:
            self.skipTest("Op has no inplace variant!")

        # Verifies an operation doesn't support inplace autograd if it claims not to
        if not op.supports_inplace_autograd:
            inplace = self._get_safe_inplace(op.get_inplace())
            for sample in op.sample_inputs(device, dtype, requires_grad=True):
                if sample.broadcasts_input:
                    continue
                with self.assertRaises(Exception):
                    result = inplace(sample)
                    result.sum().backward()
        else:
            self._grad_test_helper(device, dtype, op, self._get_safe_inplace(op.get_inplace()))

    # Test that gradients of gradients are computed correctly
    @_gradcheck_ops(op_db)
    def test_fn_gradgrad(self, device, dtype, op):
        self._skip_helper(op, device, dtype)
        if not op.supports_gradgrad:
            self.skipTest("Op claims it doesn't support gradgrad. This is not verified.")
        else:
            self._check_helper(device, dtype, op, op.get_op(), 'bwgrad_bwgrad')

    # Test that forward-over-reverse gradgrad is computed correctly
    @_gradcheck_ops(op_db)
    def test_fn_fwgrad_bwgrad(self, device, dtype, op):
        self._skip_helper(op, device, dtype)

        if op.supports_fwgrad_bwgrad:
            self._check_helper(device, dtype, op, op.get_op(), "fwgrad_bwgrad")
        else:
            err_msg = r"Trying to use forward AD with .* that does not support it"
            hint_msg = ("Running forward-over-backward gradgrad for an OP that does not support it did not "
                        "raise any error. If your op supports forward AD, you should set supports_fwgrad_bwgrad=True.")
            with self.assertRaisesRegex(NotImplementedError, err_msg, msg=hint_msg):
                self._check_helper(device, dtype, op, op.get_op(), "fwgrad_bwgrad")

    # Test that attempting gradgrad on ops that don't support it raises the expected error
    @_gradcheck_ops(op_db)
    def test_fn_fail_gradgrad(self, device, dtype, op):
        self._skip_helper(op, device, dtype)
        if op.supports_gradgrad:
            self.skipTest("Skipped! Operation does support gradgrad")

        err_msg = r"derivative for .* is not implemented"
        with self.assertRaisesRegex(RuntimeError, err_msg):
            self._check_helper(device, dtype, op, op.get_op(), 'bwgrad_bwgrad')

    # Method gradgrad (and grad, see above) tests are disabled since they're
    # costly and redundant with function gradgrad (and grad) tests
    # @_gradcheck_ops(op_db)
    # def test_method_gradgrad(self, device, dtype, op):
    #     self._skip_helper(op, device, dtype)
    #     self._gradgrad_test_helper(device, dtype, op, op.get_method())

    @_gradcheck_ops(op_db)
    def test_inplace_gradgrad(self, device, dtype, op):
        self._skip_helper(op, device, dtype)
        if not op.inplace_variant or not op.supports_inplace_autograd:
            self.skipTest("Skipped! Operation does not support inplace autograd.")
        self._check_helper(device, dtype, op, self._get_safe_inplace(op.get_inplace()), "bwgrad_bwgrad")
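    # The helper below drives the forward-mode AD tests: it selects the batched forward-grad
    # check from the OpInfo (check_inplace_batched_forward_grad for inplace variants,
    # check_batched_forward_grad otherwise) and, when an op does not claim forward AD support,
    # asserts that attempting forward AD raises NotImplementedError.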
    def _forward_grad_helper(self, device, dtype, op, variant, is_inplace):
        # TODO: clean up how attributes are passed to gradcheck from OpInfos
        def call_grad_test_helper():
            check_batched_forward_grad = ((op.check_batched_forward_grad and not is_inplace) or
                                          (op.check_inplace_batched_forward_grad and is_inplace))
            self._grad_test_helper(device, dtype, op, variant, check_forward_ad=True, check_backward_ad=False,
                                   check_batched_grad=False, check_batched_forward_grad=check_batched_forward_grad)
        if op.supports_forward_ad:
            call_grad_test_helper()
        else:
            err_msg = r"Trying to use forward AD with .* that does not support it"
            hint_msg = ("Running forward AD for an OP that does not support it did not "
                        "raise any error. If your op supports forward AD, you should set supports_forward_ad=True")
            with self.assertRaisesRegex(NotImplementedError, err_msg, msg=hint_msg):
                call_grad_test_helper()

    @_gradcheck_ops(op_db)
    def test_forward_mode_AD(self, device, dtype, op):
        self._skip_helper(op, device, dtype)

        self._forward_grad_helper(device, dtype, op, op.get_op(), is_inplace=False)

    @_gradcheck_ops(op_db)
    def test_inplace_forward_mode_AD(self, device, dtype, op):
        self._skip_helper(op, device, dtype)

        if not op.inplace_variant or not op.supports_inplace_autograd:
            self.skipTest("Skipped! Operation does not support inplace autograd.")

        self._forward_grad_helper(device, dtype, op, self._get_safe_inplace(op.get_inplace()), is_inplace=True)


instantiate_device_type_tests(TestGradients, globals())

if __name__ == '__main__':
    run_tests()