Support SymbolicGradient for functions with non-trainable arguments.

The non-trainable arguments end up with None as their incoming out_grad, which is not a valid input to SymbolicGradient (inputs have to be convertible to Tensor, and None isn't).
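
As a concrete illustration, here is a condensed sketch of the scenario (it mirrors the test added in this change; TF 1.x graph mode, using the same internal modules the test imports):

    import numpy as np

    from tensorflow.python.client import session
    from tensorflow.python.framework import constant_op
    from tensorflow.python.framework import dtypes
    from tensorflow.python.framework import function
    from tensorflow.python.framework import ops
    from tensorflow.python.ops import array_ops
    from tensorflow.python.ops import gradients_impl

    # Foo mixes a non-trainable int32 argument with a trainable float32 one.
    @function.Defun(dtypes.int32, dtypes.float32)
    def Foo(t, x):
      return x[t]

    g = ops.Graph()
    with g.as_default():
      inp = array_ops.placeholder(dtypes.float32)
      t = constant_op.constant(0, dtypes.int32)
      out = Foo(t, inp)
      # Before this change the call below failed: the int32 argument's incoming
      # out_grad is None, which cannot be converted to a Tensor input for the
      # SymbolicGradient op.
      dinp, = gradients_impl.gradients(out, [inp])

    with session.Session(graph=g) as sess:
      x = np.zeros((2,), dtype=np.float32)
      print(sess.run(dinp, {inp: x}))  # d(x[0])/dx == [1., 0.]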

PiperOrigin-RevId: 173901727
RJ Ryan 2017-10-30 09:11:59 -07:00 committed by TensorFlower Gardener
parent 494672475b
commit 4723f8f6ed
2 changed files with 26 additions and 6 deletions


@@ -864,6 +864,24 @@ class FunctionTest(test.TestCase):
         [result])
     self.assertEqual(len(f.signature.input_arg), 3)
 
+  def testGradientWithIntegerFunctionArgument(self):
+    @function.Defun(dtypes.int32, dtypes.float32)
+    def Foo(t, x):
+      return x[t]
+
+    g = ops.Graph()
+    with g.as_default():
+      inp = array_ops.placeholder(dtypes.float32)
+      t = constant_op.constant(0, dtypes.int32)
+      out = Foo(t, inp)
+      dinp, = gradients_impl.gradients(out, [inp])
+
+    x = np.zeros((2,)).astype(np.float32)
+    with session.Session(graph=g) as sess:
+      self.assertAllClose(
+          np.array([1.0, 0.0]).astype(np.float32),
+          sess.run(dinp, {inp: x}))
+
 
 @test_util.with_c_api
 class FunctionsFromProtos(test.TestCase):


@@ -582,8 +582,10 @@ def gradients(ys,
             # therefore dC/doutput[i] is 0.
             for i, out_grad in enumerate(out_grads):
               if (not isinstance(out_grad, ops.Tensor) and
-                  not out_grad) and _IsTrainable(op.outputs[i]):
-                # Only floating-point outputs get a zero gradient. Gradient
+                  not out_grad) and ((not grad_fn and is_func_call) or
+                                     _IsTrainable(op.outputs[i])):
+                # Only trainable outputs or outputs for a function call that
+                # will use SymbolicGradient get a zero gradient. Gradient
                 # functions should ignore the gradient for other outputs.
                 # TODO(apassos) gradients of resource handles might be an
                 # issue here because of zeros.
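
For reference, the _IsTrainable helper used in the condition above is a dtype check in the same file; at around this revision it looks approximately like the sketch below (an approximation, not a verbatim quote):

    def _IsTrainable(tensor):
      dtype = dtypes.as_dtype(tensor.dtype)
      return dtype.base_dtype in (dtypes.float16, dtypes.float32, dtypes.float64,
                                  dtypes.complex64, dtypes.complex128)

An int32 output therefore never counts as trainable, so before this change a function-call op with an integer output kept None in out_grads; the added "(not grad_fn and is_func_call)" clause makes such outputs receive an explicit zero instead, which SymbolicGradient can accept as a Tensor input.
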
@@ -670,15 +672,15 @@ def _UpdatePendingAndEnqueueReady(grads, op, queue, pending_count, loop_state):
         grad_state.pending_exits_count -= 1
         if grad_state.pending_exits_count == 0:
           # We now have all the exits so process them.
-          has_real_grad = False
+          has_not_none_grad = False
           for y in grad_state.deferred_exits:
             if _HasAnyNotNoneGrads(grads, y.op):
-              has_real_grad = True
+              has_not_none_grad = True
               queue.append(y.op)
             else:
               grad_state.unused_exits.append(y)
-          if has_real_grad:
-            # For an unused exit, if it has floating-point outputs, backprop
+          if has_not_none_grad:
+            # For an unused exit, if it has trainable outputs, backprop
             # a zero gradient. Otherwise, just ignore it.
             for y in grad_state.unused_exits:
               if _IsTrainable(y):