Report a nicer error message when differentiating a function

that returns None in eager

PiperOrigin-RevId: 173914883
This commit is contained in:
Akshay Agrawal 2017-10-30 10:44:06 -07:00 committed by TensorFlower Gardener
parent 85f8d92408
commit e8ac0b48f4
2 changed files with 34 additions and 0 deletions

View File

@ -332,6 +332,9 @@ def implicit_val_and_grad(f):
A function which, when called, returns a tuple pair. A function which, when called, returns a tuple pair.
Its first element is the value to which the function evaluates. Its first element is the value to which the function evaluates.
Its second element is list of (gradient, variable) pairs. Its second element is list of (gradient, variable) pairs.
Raises:
ValueError: if `f` returns None.
""" """
# TODO(cais): Remove calls to tf.constant() once the gradients functions # TODO(cais): Remove calls to tf.constant() once the gradients functions
# accept lists and np.ndarrays. # accept lists and np.ndarrays.
@ -341,6 +344,10 @@ def implicit_val_and_grad(f):
tape.push_new_tape() tape.push_new_tape()
try: try:
end_node = f(*args) end_node = f(*args)
if end_node is None:
raise ValueError("Cannot differentiate a function that returns None; "
"did you forget to return a value from {}?".format(
f.__name__))
variables = tape.top_tape_watched_variables() variables = tape.top_tape_watched_variables()
finally: finally:
popped_tape = tape.pop_tape() popped_tape = tape.pop_tape()
@ -630,6 +637,8 @@ def make_vjp(f, params=None):
# result is 9.0 # result is 9.0
vjp() # the vjp function rturns 6.0 vjp() # the vjp function rturns 6.0
Raises:
ValueError: if `f` returns None.
""" """
def decorated(*args, **kwds): def decorated(*args, **kwds):
@ -649,6 +658,10 @@ def make_vjp(f, params=None):
sources.append(args[i]) sources.append(args[i])
tape.watch(args[i]) tape.watch(args[i])
result = f(*args) result = f(*args)
if result is None:
raise ValueError("Cannot differentiate a function that returns None; "
"did you forget to return a value from {}?".format(
f.__name__))
flat_result = nest.flatten(result) flat_result = nest.flatten(result)
flat_result = [gen_array_ops.identity(x) for x in flat_result] flat_result = [gen_array_ops.identity(x) for x in flat_result]
result = nest.pack_sequence_as(result, flat_result) result = nest.pack_sequence_as(result, flat_result)

View File

@ -588,5 +588,26 @@ class BackpropTest(test.TestCase):
self.assertAllEqual(backprop.gradients_function(my_identity)(1.0)[0], 2.0) self.assertAllEqual(backprop.gradients_function(my_identity)(1.0)[0], 2.0)
def testDifferentiatingFunctionThatReturnsNone(self):
def fn(x, y):
result = x*y # pylint: disable=unused-variable
x = constant_op.constant(1)
y = constant_op.constant(2)
loss_grads_fn = backprop.implicit_val_and_grad(fn)
with self.assertRaisesRegexp(
ValueError, 'Cannot differentiate a function that returns None; '
'did you forget to return a value from fn?'):
loss_grads_fn(x, y)
val_and_grads_fn = backprop.val_and_grad_function(fn)
with self.assertRaisesRegexp(
ValueError, 'Cannot differentiate a function that returns None; '
'did you forget to return a value from fn?'):
val_and_grads_fn(x, y)
if __name__ == '__main__': if __name__ == '__main__':
test.main() test.main()