Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-07 12:21:27 +01:00)
Gradcheck forward AD respects requires grad but run with requires_grad=False (#72309)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72309
Fixes: https://github.com/pytorch/pytorch/issues/72113

Test Plan: Imported from OSS

Reviewed By: albanD

Differential Revision: D33991570

Pulled By: soulitzer

fbshipit-source-id: 610de162e9848d2d3b12e0fb039860fd9dee844f
This commit is contained in: parent 5fb1eb1b19, commit a7ecb13610
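For context: with this change, gradcheck's forward-AD paths detach the inputs before calling the function (so the function sees requires_grad=False), while requires_grad on the tensors passed to gradcheck still decides which inputs get a tangent. A minimal sketch of the user-facing call (the lambda is a made-up example function):

import torch
from torch.autograd import gradcheck

x = torch.rand(2, dtype=torch.double, requires_grad=True)   # gets a tangent
y = torch.rand(2, dtype=torch.double, requires_grad=False)  # treated as non-differentiable

# Forward-AD-only gradcheck; inside the call, the function receives detached inputs.
assert gradcheck(lambda a, b: (a * b).sum(), (x, y),
                 check_forward_ad=True, check_backward_ad=False,
                 check_batched_grad=False, check_undefined_grad=False)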
@@ -792,6 +792,11 @@ void _linalg_check_errors(
   }
 }
 
+bool _requires_fw_or_bw_grad(const Tensor& input) {
+  return ((at::GradMode::is_enabled() && input.requires_grad())
+    || input._fw_grad(/*level */ 0).defined());
+}
+
 // Below of the definitions of the functions operating on a batch that are going to be dispatched
 // in the main helper functions for the linear algebra operations
 
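The new helper treats a tensor as "needs gradients" if either backward mode could use it (requires_grad under an enabled GradMode) or a forward-mode tangent is already attached. A rough Python equivalent, using only public APIs (the helper name is made up for illustration) and meant to run inside fwAD.dual_level():

import torch
import torch.autograd.forward_ad as fwAD

def requires_fw_or_bw_grad(t: torch.Tensor) -> bool:  # illustrative name
    needs_backward = torch.is_grad_enabled() and t.requires_grad
    # unpack_dual returns tangent=None for tensors that carry no forward grad
    needs_forward = fwAD.unpack_dual(t).tangent is not None
    return needs_backward or needs_forward

with fwAD.dual_level():
    x = torch.randn(3, 3)                               # requires_grad is False
    dual = fwAD.make_dual(x, torch.ones_like(x))        # but it carries a tangent
    print(dual.requires_grad, requires_fw_or_bw_grad(dual))  # False True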
@@ -2382,7 +2387,7 @@ std::tuple<Tensor&, Tensor&> linalg_eigh_out(const Tensor& input, c10::string_vi
 Tensor linalg_eigvalsh(const Tensor& input, c10::string_view uplo) {
   // if input requires grad we must compute the eigenvectors to make this function differentiable
   // the eigenvectors are not exposed to the user
-  if (at::GradMode::is_enabled() && input.requires_grad()) {
+  if (_requires_fw_or_bw_grad(input)) {
     Tensor values;
     std::tie(values, std::ignore) = at::linalg_eigh(input, uplo);
     return values;
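With the guard above, a forward-mode tangent alone is enough to route linalg_eigvalsh through linalg_eigh, which computes the eigenvectors needed for the derivative. A sketch, assuming your build supports forward-mode AD for torch.linalg.eigvalsh:

import torch
import torch.autograd.forward_ad as fwAD

A = torch.randn(4, 4, dtype=torch.double)
A = A + A.mT                       # symmetric input for eigvalsh
T = torch.randn(4, 4, dtype=torch.double)
T = T + T.mT                       # symmetric perturbation direction

with fwAD.dual_level():
    A_dual = fwAD.make_dual(A, T)              # no requires_grad anywhere
    w = torch.linalg.eigvalsh(A_dual)
    print(fwAD.unpack_dual(w).tangent)         # directional derivative of the eigenvalues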
@@ -2878,7 +2883,7 @@ Tensor& linalg_eigvals_out(const Tensor& input, Tensor& values) {
 Tensor linalg_eigvals(const Tensor& input) {
   // if input requires grad we must compute the eigenvectors to make this function differentiable
   // the eigenvectors are not exposed to the user
-  if (at::GradMode::is_enabled() && input.requires_grad()) {
+  if (_requires_fw_or_bw_grad(input)) {
     return std::get<0>(at::linalg_eig(input));
   }
 
@@ -3063,10 +3068,7 @@ Tensor& linalg_svdvals_out(const Tensor& A, Tensor & S) {
 }
 
 Tensor linalg_svdvals(const Tensor& A) {
-  const bool A_requires_grad = (at::GradMode::is_enabled() && A.requires_grad())
-    || A._fw_grad(/*level */ 0).defined()
-    || isTensorSubclassLike(A);
-  return std::get<1>(at::_linalg_svd(A, /*full_matrices=*/false, /*comptue_uv=*/A_requires_grad));
+  return std::get<1>(at::_linalg_svd(A, /*full_matrices=*/false, /*comptue_uv=*/_requires_fw_or_bw_grad(A)));
 }
 
 std::tuple<Tensor&, Tensor&, Tensor&> svd_out(const Tensor& self, bool some, bool compute_uv, Tensor& U, Tensor& S, Tensor& V) {
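The same idea applies to the singular values: compute_uv must be true whenever either kind of gradient is needed, because U and V enter the derivative formula even though only S is returned. A sketch, assuming forward-mode AD support for torch.linalg.svdvals in your build:

import torch
import torch.autograd.forward_ad as fwAD

A = torch.randn(5, 3, dtype=torch.double)
T = torch.randn(5, 3, dtype=torch.double)

with fwAD.dual_level():
    S = torch.linalg.svdvals(fwAD.make_dual(A, T))
    print(fwAD.unpack_dual(S).tangent)   # d(singular values) along direction T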
@@ -798,7 +798,7 @@ embedding_bag(const Tensor &weight, const Tensor &indices,
     padding_idx = maybe_wrap_dim(padding_idx, weight.size(0));
   }
   std::tuple<Tensor, Tensor, Tensor, Tensor> out;
-  if (!weight.requires_grad()) {
+  if (!weight.requires_grad() && !weight._fw_grad(/*level=*/0).defined()) {
     out = at::_embedding_bag_forward_only(
         weight, indices.contiguous(), offsets.contiguous(), scale_grad_by_freq,
         mode, sparse, per_sample_weights, include_last_offset, padding_idx);
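The forward-only embedding_bag kernel skips bookkeeping needed for gradients, so it must also be avoided when the weight carries a forward-mode tangent (requires_grad can be False in that case). A hedged sketch of such a call; it assumes embedding_bag has a forward-mode derivative registered in your build, otherwise PyTorch raises an error rather than propagating the tangent:

import torch
import torch.nn.functional as F
import torch.autograd.forward_ad as fwAD

weight = torch.randn(10, 3, dtype=torch.double)
tangent = torch.randn(10, 3, dtype=torch.double)
indices = torch.tensor([1, 2, 4, 5, 4, 3])
offsets = torch.tensor([0, 3])

with fwAD.dual_level():
    w_dual = fwAD.make_dual(weight, tangent)     # weight.requires_grad is False
    out = F.embedding_bag(indices, w_dual, offsets, mode="sum")
    print(fwAD.unpack_dual(out).tangent)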
@@ -102,6 +102,7 @@ Tensor max_pool1d(
         self, kernel_size, stride, padding, dilation, ceil_mode);
   }
   if ((self.requires_grad() && at::GradMode::is_enabled()) ||
+      self._fw_grad(/*level */ 0).defined() ||
       !self.device().is_cpu()) {
     // Needs indices for grad and with_indices defines CUDA dispatch
     return std::get<0>(at::max_pool1d_with_indices(
@@ -252,7 +252,7 @@ static Tensor _mkldnn_pooling(
   // for inference, don't need the indices, set aprop_kind to prop_kind::forward_inference
   // can reduce the memory use.
   if (ideep::algorithm::pooling_max == algo
-      && !(input.requires_grad() && at::GradMode::is_enabled())) {
+      && !((input.requires_grad() && at::GradMode::is_enabled()) || input._fw_grad(/*level */ 0).defined())) {
     aprop_kind = ideep::prop_kind::forward_inference;
   }
 
@@ -4064,6 +4064,82 @@ class TestAutograd(TestCase):
             with self.assertRaisesRegex(RuntimeError, err_msg):
                 gradcheck(bad_fn, (x, y), check_forward_ad=True, fast_mode=fast_mode)
 
+    def test_gradcheck_forward_ad_runs_with_no_requires_grad(self):
+        # Currently requires_grad is used as a easy way for gradcheck to know
+        # which inputs of the function are meant to be differentiable
+        # This test checks that when the inputs are passed to the function they should not have
+        # requires_grad=True even though they may have requires_grad=True when passed
+        # to gradcheck
+        class UserFn(Function):
+            @staticmethod
+            def forward(ctx, x, y):
+                if fwAD._current_level >= 0:
+                    self.assertFalse(x.requires_grad)
+                    self.assertFalse(y.requires_grad)
+                return x.clone(), y.clone()
+
+            @staticmethod
+            def jvp(ctx, x_t, y_t):
+                return x_t, y_t
+
+        x = torch.rand(2, dtype=torch.double, requires_grad=True)
+        y = torch.rand(2, dtype=torch.double, requires_grad=True)
+
+        gradcheck(UserFn.apply, (x, y), check_forward_ad=True, check_undefined_grad=False, check_backward_ad=False,
+                  check_batched_grad=False, check_batched_forward_grad=False)
+
+        gradcheck(UserFn.apply, (x, y), check_forward_ad=True, check_undefined_grad=True, check_backward_ad=False,
+                  check_batched_grad=False, check_batched_forward_grad=False)
+
+        gradcheck(UserFn.apply, (x, y), check_forward_ad=True, check_undefined_grad=True, check_backward_ad=False,
+                  check_batched_grad=False, check_batched_forward_grad=True)
+
+        x = torch.rand(2, dtype=torch.double, requires_grad=True)
+        y = torch.rand(2, dtype=torch.double, requires_grad=False)
+        gradcheck(UserFn.apply, (x, y), check_forward_ad=True, check_undefined_grad=True, check_backward_ad=False,
+                  check_batched_grad=False, check_batched_forward_grad=True)
+
+    def test_gradcheck_forward_ad_respects_requires_grad(self):
+        # Currently requires_grad is used as a easy way for gradcheck to know
+        # which inputs of the function are meant to be differentiable
+        jvp_count = [0]
+
+        class UserFn(Function):
+            @staticmethod
+            def forward(ctx, x, y):
+                return x.clone(), y.clone()
+
+            @staticmethod
+            def jvp(ctx, x_t, y_t):
+                jvp_count[0] += 1
+                return x_t, y_t
+
+        x = torch.rand(2, dtype=torch.double, requires_grad=True)
+        y = torch.rand(2, dtype=torch.double, requires_grad=True)
+        gradcheck(UserFn.apply, (x, y), check_forward_ad=True, check_undefined_grad=False, check_backward_ad=False,
+                  check_batched_grad=False, check_batched_forward_grad=False)
+        self.assertEqual(jvp_count[0], 2)  # (2) once per input
+        jvp_count = [0]
+
+        gradcheck(UserFn.apply, (x, y), check_forward_ad=True, check_undefined_grad=True, check_backward_ad=False,
+                  check_batched_grad=False, check_batched_forward_grad=False)
+        self.assertEqual(jvp_count[0], 6)  # (+4): (once with normal ZT (+1), once with efficient ZT (+1)) for each input (x2)
+        jvp_count = [0]
+
+        gradcheck(UserFn.apply, (x, y), check_forward_ad=True, check_undefined_grad=True, check_backward_ad=False,
+                  check_batched_grad=False, check_batched_forward_grad=True)
+        self.assertEqual(jvp_count[0], 12)  # (+6): (compute batch of 2 with vmap (+1), with a loop (+2)) for each input (x2)
+        jvp_count = [0]
+
+        # Repeat the previous test except we mark one input with requires_grad=False
+        # NB: _test_undefined_forward_mode is only (+1), when function has single differentiable input, not (+2)!
+        # Otherwise, other counts are halved.
+        x = torch.rand(2, dtype=torch.double, requires_grad=True)
+        y = torch.rand(2, dtype=torch.double, requires_grad=False)
+        gradcheck(UserFn.apply, (x, y), check_forward_ad=True, check_undefined_grad=True, check_backward_ad=False,
+                  check_batched_grad=False, check_batched_forward_grad=True)
+        self.assertEqual(jvp_count[0], 5)  # 1 + 1 + 3
+
     def test_gradcheck_check_forward_or_backward_only(self):
         """Depending on settings for check_forward_ad and check_backward_ad, the
         correct codepaths should be reached (or not reached)
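The tests above rely on custom autograd Functions that implement the forward-mode rule via a jvp staticmethod. For reference, a minimal standalone sketch of that pattern outside of gradcheck (the Scale class is made up for illustration):

import torch
from torch.autograd import Function
import torch.autograd.forward_ad as fwAD

class Scale(Function):
    @staticmethod
    def forward(ctx, x):
        return 2 * x

    @staticmethod
    def jvp(ctx, x_t):
        # tangent rule for y = 2 * x
        return 2 * x_t

x = torch.randn(3, dtype=torch.double)
with fwAD.dual_level():
    out = Scale.apply(fwAD.make_dual(x, torch.ones_like(x)))
    print(fwAD.unpack_dual(out).tangent)   # a tensor of twos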
@@ -851,6 +851,7 @@ class TestGradCheckOverride(TestCase):
            'new_zeros',
            'numel',
            'requires_grad',
+           'requires_grad_',
            'retain_grad',
            'size',
            'stride',

@@ -867,6 +868,7 @@ class TestGradCheckOverride(TestCase):
            torch.Tensor.numel,
            torch.Tensor.retain_grad,
            torch.Tensor.stride,
+           torch.Tensor.requires_grad_,
            torch.autograd.grad,
            torch.add,
        }
@@ -329,7 +329,7 @@ def _get_analytical_jacobian_forward_ad(fn, inputs, outputs, *, check_grad_dtype
         if inp.layout == torch._mkldnn:  # type: ignore[attr-defined]
             raise ValueError("MKLDNN inputs are not support for forward AD gradcheck.")
 
-        inp = fwAD.make_dual(inp, torch.zeros_like(inp))
+        inp = fwAD.make_dual(inp.detach(), torch.zeros_like(inp))
         # If inp is a differentiable view, the dual might not be the tangent given to
         # make_dual, so read it explicitly from the dual tensor
         fw_grads.append(fwAD.unpack_dual(inp)[1])
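The detach() here is what lets the checked function see inputs with requires_grad=False: forward-mode tangents live on the dual tensor, not on the backward graph. A small sketch of the same pattern:

import torch
import torch.autograd.forward_ad as fwAD

x = torch.randn(3, requires_grad=True)       # as passed by the user to gradcheck
with fwAD.dual_level():
    dual = fwAD.make_dual(x.detach(), torch.zeros_like(x))
    print(dual.requires_grad)                 # False: backward tracking dropped
    print(fwAD.unpack_dual(dual).tangent)     # zeros: the tangent is still attached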
@@ -760,10 +760,14 @@ def _test_batched_grad_forward_ad(func, inputs) -> bool:
     assert isinstance(inputs, tuple)
 
     for input_idx, current_input in enumerate(inputs):
         if not (is_tensor_like(current_input) and current_input.requires_grad):
             continue
 
         def jvp(tangent: torch.Tensor):
             with fwAD.dual_level():
-                dual = fwAD.make_dual(current_input, tangent)
-                inputs_with_dual = tuple(dual if idx == input_idx else inp for idx, inp in enumerate(inputs))
+                dual = fwAD.make_dual(current_input.detach(), tangent)
+                inputs_with_dual = tuple(dual if idx == input_idx else (inp.detach() if is_tensor_like(inp) else inp)
+                                         for idx, inp in enumerate(inputs))
                 dual_outputs = _as_tuple(func(*inputs_with_dual))
                 ret = []
                 for dual_output in dual_outputs:
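Each call to the inner jvp closure above detaches every input and attaches the tangent only to the input under test. A self-contained sketch of that per-input JVP computation (the function f and shapes are made up):

import torch
import torch.autograd.forward_ad as fwAD

def f(a, b):
    return (a * b).sin()

a = torch.randn(3, dtype=torch.double, requires_grad=True)
b = torch.randn(3, dtype=torch.double)

def jvp_wrt_a(tangent):
    with fwAD.dual_level():
        dual_a = fwAD.make_dual(a.detach(), tangent)   # only `a` gets a tangent
        out = f(dual_a, b.detach())                    # all inputs are detached
        return fwAD.unpack_dual(out).tangent

print(jvp_wrt_a(torch.ones_like(a)))   # equals cos(a * b) * b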
@@ -888,7 +892,7 @@ def _test_undefined_forward_mode(func, outputs, inputs):
         if inp.layout == torch._mkldnn:  # type: ignore[attr-defined]
             raise ValueError("MKLDNN inputs are not support for forward AD gradcheck.")
 
-        inp = fwAD.make_dual(inp, torch.zeros_like(inp))
+        inp = fwAD.make_dual(inp.detach(), torch.zeros_like(inp))
         # If inp is a differentiable view, the dual might not be the tangent given to
         # make_dual, so read it explicitly from the dual tensor
         fw_grads.append(fwAD.unpack_dual(inp)[1])
@@ -904,12 +908,12 @@ def _test_undefined_forward_mode(func, outputs, inputs):
             dual_inp_obj = dual_inputs[idx]
 
             # case 1 (Materialized Zero Tensor Tangent)
-            dual_inputs[idx] = fwAD.make_dual(inp, torch.zeros_like(inp))
+            dual_inputs[idx] = fwAD.make_dual(inp.detach(), torch.zeros_like(inp))
             raw_outputs = _as_tuple(func(*dual_inputs))
             dual_outputs1 = filter(_is_float_or_complex_tensor, raw_outputs)
 
             # case 2 (Efficient Zero Tensor Tangent since we don't make a dual object and pass a regular tensor)
-            dual_inputs[idx] = inp
+            dual_inputs[idx] = inp.detach()
             raw_outputs = _as_tuple(func(*dual_inputs))
             dual_outputs2 = filter(_is_float_or_complex_tensor, raw_outputs)
 
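The two cases exercise the same mathematical input, a zero tangent, in two representations: an explicit zeros_like tangent versus a plain (non-dual) tensor that forward AD treats as carrying no tangent. A sketch showing both give the same JVP (fn is a made-up example):

import torch
import torch.autograd.forward_ad as fwAD

def fn(a, b):
    return a * b

a = torch.randn(3, dtype=torch.double)
b = torch.randn(3, dtype=torch.double)
t = torch.randn(3, dtype=torch.double)

with fwAD.dual_level():
    b_dual = fwAD.make_dual(b, t)
    out1 = fn(fwAD.make_dual(a, torch.zeros_like(a)), b_dual)  # materialized zero tangent
    out2 = fn(a, b_dual)                                       # implicit zero tangent
    assert torch.allclose(fwAD.unpack_dual(out1).tangent,
                          fwAD.unpack_dual(out2).tangent)      # both equal a * t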
@@ -1532,13 +1536,18 @@ def gradgradcheck(
 
     num_outputs = len(tupled_grad_outputs)
 
+    # NB: We need to save the requires_grad information about the inputs here because gradcheck detaches inputs
+    # before running forward mode AD
+    diff_input_args_indices = set(i for i, x in enumerate(tupled_inputs) if is_tensor_like(x) and x.requires_grad)
+    diff_grad_output_indices = set(i for i, x in enumerate(tupled_grad_outputs) if x.requires_grad)
+
     def new_func(*args):
-        input_args = args[:-num_outputs]
-        grad_outputs = args[-num_outputs:]
+        # Restore the requires_grad information
+        input_args = tuple(x.requires_grad_() if i in diff_input_args_indices else x for i, x in enumerate(args[:-num_outputs]))
         outputs = _differentiable_outputs(func(*input_args))
-        input_args = tuple(x for x in input_args
-                           if is_tensor_like(x) and x.requires_grad)
-        grad_inputs = torch.autograd.grad(outputs, input_args, grad_outputs, create_graph=True,
+        grad_outputs = tuple(x.requires_grad_() if i in diff_grad_output_indices else x for i, x in enumerate(args[-num_outputs:]))
+        diff_input_args = tuple(x for i, x in enumerate(input_args) if i in diff_input_args_indices)
+        grad_inputs = torch.autograd.grad(outputs, diff_input_args, grad_outputs, create_graph=True,
                                           allow_unused=True)
         grad_inputs = tuple(g for g in grad_inputs if g is not None)
         return grad_inputs
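The wrapper records which positional inputs and grad_outputs were differentiable before gradcheck detaches them, then re-marks exactly those tensors with requires_grad_() so torch.autograd.grad still sees them as differentiable. The save-and-restore pattern in isolation:

import torch

xs = (torch.rand(2, dtype=torch.double, requires_grad=True),
      torch.rand(2, dtype=torch.double))
diff_idx = {i for i, t in enumerate(xs) if t.requires_grad}    # remember: {0}

detached = tuple(t.detach() for t in xs)                        # what gradcheck passes in
restored = tuple(t.requires_grad_() if i in diff_idx else t     # restore the flag in place
                 for i, t in enumerate(detached))
assert [t.requires_grad for t in restored] == [True, False]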