mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
Report a nicer error message when differentiating a function
that returns None in eager PiperOrigin-RevId: 173914883
This commit is contained in:
parent
85f8d92408
commit
e8ac0b48f4
|
|
@ -332,6 +332,9 @@ def implicit_val_and_grad(f):
|
|||
A function which, when called, returns a tuple pair.
|
||||
Its first element is the value to which the function evaluates.
|
||||
Its second element is list of (gradient, variable) pairs.
|
||||
|
||||
Raises:
|
||||
ValueError: if `f` returns None.
|
||||
"""
|
||||
# TODO(cais): Remove calls to tf.constant() once the gradients functions
|
||||
# accept lists and np.ndarrays.
|
||||
|
|
@ -341,6 +344,10 @@ def implicit_val_and_grad(f):
|
|||
tape.push_new_tape()
|
||||
try:
|
||||
end_node = f(*args)
|
||||
if end_node is None:
|
||||
raise ValueError("Cannot differentiate a function that returns None; "
|
||||
"did you forget to return a value from {}?".format(
|
||||
f.__name__))
|
||||
variables = tape.top_tape_watched_variables()
|
||||
finally:
|
||||
popped_tape = tape.pop_tape()
|
||||
|
|
@ -630,6 +637,8 @@ def make_vjp(f, params=None):
|
|||
# result is 9.0
|
||||
vjp() # the vjp function rturns 6.0
|
||||
|
||||
Raises:
|
||||
ValueError: if `f` returns None.
|
||||
"""
|
||||
|
||||
def decorated(*args, **kwds):
|
||||
|
|
@ -649,6 +658,10 @@ def make_vjp(f, params=None):
|
|||
sources.append(args[i])
|
||||
tape.watch(args[i])
|
||||
result = f(*args)
|
||||
if result is None:
|
||||
raise ValueError("Cannot differentiate a function that returns None; "
|
||||
"did you forget to return a value from {}?".format(
|
||||
f.__name__))
|
||||
flat_result = nest.flatten(result)
|
||||
flat_result = [gen_array_ops.identity(x) for x in flat_result]
|
||||
result = nest.pack_sequence_as(result, flat_result)
|
||||
|
|
|
|||
|
|
@ -588,5 +588,26 @@ class BackpropTest(test.TestCase):
|
|||
|
||||
self.assertAllEqual(backprop.gradients_function(my_identity)(1.0)[0], 2.0)
|
||||
|
||||
def testDifferentiatingFunctionThatReturnsNone(self):
|
||||
|
||||
def fn(x, y):
|
||||
result = x*y # pylint: disable=unused-variable
|
||||
|
||||
x = constant_op.constant(1)
|
||||
y = constant_op.constant(2)
|
||||
|
||||
loss_grads_fn = backprop.implicit_val_and_grad(fn)
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError, 'Cannot differentiate a function that returns None; '
|
||||
'did you forget to return a value from fn?'):
|
||||
loss_grads_fn(x, y)
|
||||
|
||||
val_and_grads_fn = backprop.val_and_grad_function(fn)
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError, 'Cannot differentiate a function that returns None; '
|
||||
'did you forget to return a value from fn?'):
|
||||
val_and_grads_fn(x, y)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user