mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Fix _check_no_differentiable_outputs for forward ad (#91391)
This `is_forward_ad` isn't propagated, which leads to this line creating a
slow-gradcheck failure on master:
```
if not is_forward_ad and any(o.is_complex() for o in outputs):
raise ValueError("Expected output to be non-complex. get_numerical_jacobian no "
"longer supports functions that return complex outputs.")
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91391
Approved by: https://github.com/albanD
This commit is contained in:
parent
a061f139dc
commit
bb24185ff4
|
|
@ -725,10 +725,11 @@ def _check_outputs(outputs) -> None:
|
|||
'Please call to_dense() on the output of fn for gradcheck.')
|
||||
|
||||
|
||||
def _check_no_differentiable_outputs(func, inputs, func_out, eps) -> bool:
|
||||
def _check_no_differentiable_outputs(func, inputs, func_out, eps, *, is_forward_ad) -> bool:
|
||||
# When there are no differentiable outputs, numerical gradient for a function is
|
||||
# expected to be zero.
|
||||
jacobians_all_inputs_outputs = _get_numerical_jacobian(func, inputs, func_out, eps=eps)
|
||||
jacobians_all_inputs_outputs = _get_numerical_jacobian(func, inputs, func_out,
|
||||
eps=eps, is_forward_ad=is_forward_ad)
|
||||
for jacobians_all_outputs_and_fixed_input in jacobians_all_inputs_outputs:
|
||||
for jacobian in jacobians_all_outputs_and_fixed_input:
|
||||
if torch.ne(jacobian, 0).sum() > 0:
|
||||
|
|
@ -1143,7 +1144,8 @@ def _slow_gradcheck(func, func_out, tupled_inputs, outputs, eps, rtol, atol, che
|
|||
nondet_tol, *, use_forward_ad=False, complex_indices=None, test_imag=False):
|
||||
func_out = _as_tuple(func_out)
|
||||
if not outputs:
|
||||
return _check_no_differentiable_outputs(func, tupled_inputs, func_out, eps)
|
||||
return _check_no_differentiable_outputs(func, tupled_inputs, func_out,
|
||||
eps=eps, is_forward_ad=use_forward_ad)
|
||||
|
||||
numerical = _transpose(_get_numerical_jacobian(func, tupled_inputs, func_out, eps=eps, is_forward_ad=use_forward_ad))
|
||||
# Note: [numerical vs analytical output length]
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user