Report a nicer error message when differentiating a function

that returns None in eager

PiperOrigin-RevId: 173914883
This commit is contained in:
Akshay Agrawal 2017-10-30 10:44:06 -07:00 committed by TensorFlower Gardener
parent 85f8d92408
commit e8ac0b48f4
2 changed files with 34 additions and 0 deletions

View File

@ -332,6 +332,9 @@ def implicit_val_and_grad(f):
A function which, when called, returns a tuple pair.
Its first element is the value to which the function evaluates.
Its second element is list of (gradient, variable) pairs.
Raises:
ValueError: if `f` returns None.
"""
# TODO(cais): Remove calls to tf.constant() once the gradients functions
# accept lists and np.ndarrays.
@ -341,6 +344,10 @@ def implicit_val_and_grad(f):
tape.push_new_tape()
try:
end_node = f(*args)
if end_node is None:
raise ValueError("Cannot differentiate a function that returns None; "
"did you forget to return a value from {}?".format(
f.__name__))
variables = tape.top_tape_watched_variables()
finally:
popped_tape = tape.pop_tape()
@ -630,6 +637,8 @@ def make_vjp(f, params=None):
# result is 9.0
vjp() # the vjp function rturns 6.0
Raises:
ValueError: if `f` returns None.
"""
def decorated(*args, **kwds):
@ -649,6 +658,10 @@ def make_vjp(f, params=None):
sources.append(args[i])
tape.watch(args[i])
result = f(*args)
if result is None:
raise ValueError("Cannot differentiate a function that returns None; "
"did you forget to return a value from {}?".format(
f.__name__))
flat_result = nest.flatten(result)
flat_result = [gen_array_ops.identity(x) for x in flat_result]
result = nest.pack_sequence_as(result, flat_result)

View File

@ -588,5 +588,26 @@ class BackpropTest(test.TestCase):
self.assertAllEqual(backprop.gradients_function(my_identity)(1.0)[0], 2.0)
def testDifferentiatingFunctionThatReturnsNone(self):
def fn(x, y):
result = x*y # pylint: disable=unused-variable
x = constant_op.constant(1)
y = constant_op.constant(2)
loss_grads_fn = backprop.implicit_val_and_grad(fn)
with self.assertRaisesRegexp(
ValueError, 'Cannot differentiate a function that returns None; '
'did you forget to return a value from fn?'):
loss_grads_fn(x, y)
val_and_grads_fn = backprop.val_and_grad_function(fn)
with self.assertRaisesRegexp(
ValueError, 'Cannot differentiate a function that returns None; '
'did you forget to return a value from fn?'):
val_and_grads_fn(x, y)
if __name__ == '__main__':
test.main()