Report a nicer error message when differentiating a function
that returns None in eager

PiperOrigin-RevId: 173914883

parent 85f8d92408
commit e8ac0b48f4
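
For context, here is a minimal sketch of the failure mode this commit targets. It uses the same internal eager modules as the new test below (`backprop`, `constant_op`); the function name `fn` is illustrative, and a TensorFlow build of this era with eager execution enabled is assumed:

    from tensorflow.python.eager import backprop
    from tensorflow.python.framework import constant_op

    def fn(x, y):
      x * y  # computes a product but never returns it, so fn returns None

    loss_grads_fn = backprop.implicit_val_and_grad(fn)
    # Before this commit the mistake surfaced as an opaque error inside the
    # gradient machinery; with this change the call below raises:
    #   ValueError: Cannot differentiate a function that returns None;
    #   did you forget to return a value from fn?
    loss_grads_fn(constant_op.constant(1), constant_op.constant(2))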

tensorflow/python/eager/backprop.py

@@ -332,6 +332,9 @@ def implicit_val_and_grad(f):
     A function which, when called, returns a tuple pair.
     Its first element is the value to which the function evaluates.
     Its second element is a list of (gradient, variable) pairs.
+
+  Raises:
+    ValueError: if `f` returns None.
   """
   # TODO(cais): Remove calls to tf.constant() once the gradients functions
   # accept lists and np.ndarrays.

@@ -341,6 +344,10 @@ def implicit_val_and_grad(f):
     tape.push_new_tape()
     try:
       end_node = f(*args)
+      if end_node is None:
+        raise ValueError("Cannot differentiate a function that returns None; "
+                         "did you forget to return a value from {}?".format(
+                             f.__name__))
       variables = tape.top_tape_watched_variables()
     finally:
       popped_tape = tape.pop_tape()
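
A hedged sketch of the contract the docstring above describes: when `f` does return a value, the wrapped function yields that value together with (gradient, variable) pairs. The names `w` and `loss` are illustrative, `ResourceVariable` is taken directly from the internal module, and eager execution is assumed to be enabled:

    from tensorflow.python.eager import backprop
    from tensorflow.python.framework import constant_op
    from tensorflow.python.ops import resource_variable_ops

    w = resource_variable_ops.ResourceVariable(3.0)

    def loss(x):
      return w * x  # returns a value, so differentiation succeeds

    value, grads_and_vars = backprop.implicit_val_and_grad(loss)(
        constant_op.constant(2.0))
    # value == 6.0; grads_and_vars == [(gradient_for_w, w)]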

@@ -630,6 +637,8 @@ def make_vjp(f, params=None):
     # result is 9.0
     vjp()  # the vjp function returns 6.0
 
+  Raises:
+    ValueError: if `f` returns None.
   """
 
   def decorated(*args, **kwds):

@@ -649,6 +658,10 @@ def make_vjp(f, params=None):
         sources.append(args[i])
         tape.watch(args[i])
       result = f(*args)
+      if result is None:
+        raise ValueError("Cannot differentiate a function that returns None; "
+                         "did you forget to return a value from {}?".format(
+                             f.__name__))
       flat_result = nest.flatten(result)
       flat_result = [gen_array_ops.identity(x) for x in flat_result]
       result = nest.pack_sequence_as(result, flat_result)
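
A hedged sketch of `make_vjp` in the spirit of the docstring example above (f(x) = x * x at x = 3, so the forward result is 9.0 and the vector-Jacobian product is 6.0); eager execution is assumed, and the imports mirror the earlier sketches:

    from tensorflow.python.eager import backprop
    from tensorflow.python.framework import constant_op

    def f(x):
      return x * x

    result, vjp = backprop.make_vjp(f)(constant_op.constant(3.0))
    # result == 9.0, as in the docstring example
    # vjp() == [6.0], i.e. dy/dx at x == 3; a function returning None would
    # now raise the ValueError added above instead of failing obscurely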

tensorflow/python/eager/backprop_test.py

@@ -588,5 +588,26 @@ class BackpropTest(test.TestCase):
     self.assertAllEqual(backprop.gradients_function(my_identity)(1.0)[0], 2.0)
 
+  def testDifferentiatingFunctionThatReturnsNone(self):
+
+    def fn(x, y):
+      result = x*y  # pylint: disable=unused-variable
+
+    x = constant_op.constant(1)
+    y = constant_op.constant(2)
+
+    loss_grads_fn = backprop.implicit_val_and_grad(fn)
+    with self.assertRaisesRegexp(
+        ValueError, 'Cannot differentiate a function that returns None; '
+        'did you forget to return a value from fn?'):
+      loss_grads_fn(x, y)
+
+    val_and_grads_fn = backprop.val_and_grad_function(fn)
+    with self.assertRaisesRegexp(
+        ValueError, 'Cannot differentiate a function that returns None; '
+        'did you forget to return a value from fn?'):
+      val_and_grads_fn(x, y)
+
 
 if __name__ == '__main__':
   test.main()