Fix _check_no_differentiable_outputs for forward ad (#91391)

The `is_forward_ad` flag isn't propagated by `_check_no_differentiable_outputs`, so the check below raises and creates a
slow-gradcheck failure on master:
```
    if not is_forward_ad and any(o.is_complex() for o in outputs):
        raise ValueError("Expected output to be non-complex. get_numerical_jacobian no "
                         "longer supports functions that return complex outputs.")
```
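
For reference, here is a rough, untested sketch of the kind of call that exercises this path: a slow-mode, forward-AD-only gradcheck of a function whose output is complex but not differentiable. The function `fn`, its constant output, and the exact flag combination are illustrative assumptions rather than the failing CI test.
```
import torch
from torch.autograd import gradcheck

def fn(x):
    # Hypothetical example: the output is complex and does not require grad,
    # so slow gradcheck finds no differentiable outputs and takes the
    # _check_no_differentiable_outputs path.
    return torch.zeros(3, dtype=torch.complex128)

x = torch.randn(3, dtype=torch.double, requires_grad=True)

# Forward-AD-only, slow-mode gradcheck. Without this fix, the dropped
# is_forward_ad flag makes the complex-output guard raise a ValueError;
# with the fix, the flag is forwarded and the zero-jacobian check runs instead.
gradcheck(
    fn,
    (x,),
    fast_mode=False,
    check_forward_ad=True,
    check_backward_ad=False,
    check_undefined_grad=False,
)
```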

Pull Request resolved: https://github.com/pytorch/pytorch/pull/91391
Approved by: https://github.com/albanD
Peter Bell authored on 2022-12-26 18:07:49 +00:00, committed by PyTorch MergeBot
parent a061f139dc
commit bb24185ff4

```
@@ -725,10 +725,11 @@ def _check_outputs(outputs) -> None:
                          'Please call to_dense() on the output of fn for gradcheck.')
 
 
-def _check_no_differentiable_outputs(func, inputs, func_out, eps) -> bool:
+def _check_no_differentiable_outputs(func, inputs, func_out, eps, *, is_forward_ad) -> bool:
     # When there are no differentiable outputs, numerical gradient for a function is
     # expected to be zero.
-    jacobians_all_inputs_outputs = _get_numerical_jacobian(func, inputs, func_out, eps=eps)
+    jacobians_all_inputs_outputs = _get_numerical_jacobian(func, inputs, func_out,
+                                                           eps=eps, is_forward_ad=is_forward_ad)
     for jacobians_all_outputs_and_fixed_input in jacobians_all_inputs_outputs:
         for jacobian in jacobians_all_outputs_and_fixed_input:
             if torch.ne(jacobian, 0).sum() > 0:
```
```
@@ -1143,7 +1144,8 @@ def _slow_gradcheck(func, func_out, tupled_inputs, outputs, eps, rtol, atol, che
                     nondet_tol, *, use_forward_ad=False, complex_indices=None, test_imag=False):
     func_out = _as_tuple(func_out)
     if not outputs:
-        return _check_no_differentiable_outputs(func, tupled_inputs, func_out, eps)
+        return _check_no_differentiable_outputs(func, tupled_inputs, func_out,
+                                                eps=eps, is_forward_ad=use_forward_ad)
     numerical = _transpose(_get_numerical_jacobian(func, tupled_inputs, func_out, eps=eps, is_forward_ad=use_forward_ad))
     # Note: [numerical vs analytical output length]
```