Mirror of https://github.com/zebrajr/tensorflow.git, synced 2025-12-07 00:20:20 +01:00
Support SymbolicGradient for functions with non-trainable arguments.
The non-trainable arguments end up with None as their incoming out_grad, which is not a valid input to SymbolicGradient (inputs have to be convertible to Tensor, and None isn't).

PiperOrigin-RevId: 173901727
This commit is contained in:
parent 494672475b
commit 4723f8f6ed
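For context, here is the scenario end-to-end as a standalone script. It is a sketch that simply restates the test added in the diff below outside the test harness; the import paths follow TensorFlow's internal 1.x layout and the printed value mirrors the test's expected result.

import numpy as np

from tensorflow.python.client import session
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradients_impl


# A Defun with a non-trainable int32 argument `t` and a trainable float32
# argument `x`; only `x` can carry a gradient.
@function.Defun(dtypes.int32, dtypes.float32)
def Foo(t, x):
  return x[t]


g = ops.Graph()
with g.as_default():
  inp = array_ops.placeholder(dtypes.float32)
  t = constant_op.constant(0, dtypes.int32)
  out = Foo(t, inp)
  # Backprop through the function call is emitted as a SymbolicGradient node;
  # this commit makes that work when some arguments are non-trainable.
  dinp, = gradients_impl.gradients(out, [inp])

x = np.zeros((2,), dtype=np.float32)
with session.Session(graph=g) as sess:
  print(sess.run(dinp, {inp: x}))  # expected: [1. 0.]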
@@ -864,6 +864,24 @@ class FunctionTest(test.TestCase):
           [result])
     self.assertEqual(len(f.signature.input_arg), 3)
 
+  def testGradientWithIntegerFunctionArgument(self):
+    @function.Defun(dtypes.int32, dtypes.float32)
+    def Foo(t, x):
+      return x[t]
+
+    g = ops.Graph()
+    with g.as_default():
+      inp = array_ops.placeholder(dtypes.float32)
+      t = constant_op.constant(0, dtypes.int32)
+      out = Foo(t, inp)
+      dinp, = gradients_impl.gradients(out, [inp])
+
+    x = np.zeros((2,)).astype(np.float32)
+    with session.Session(graph=g) as sess:
+      self.assertAllClose(
+          np.array([1.0, 0.0]).astype(np.float32),
+          sess.run(dinp, {inp: x}))
+
 
 @test_util.with_c_api
 class FunctionsFromProtos(test.TestCase):
@@ -582,8 +582,10 @@ def gradients(ys,
             # therefore dC/doutput[i] is 0.
             for i, out_grad in enumerate(out_grads):
               if (not isinstance(out_grad, ops.Tensor) and
-                  not out_grad) and _IsTrainable(op.outputs[i]):
-                # Only floating-point outputs get a zero gradient. Gradient
+                  not out_grad) and ((not grad_fn and is_func_call) or
+                                     _IsTrainable(op.outputs[i])):
+                # Only trainable outputs or outputs for a function call that
+                # will use SymbolicGradient get a zero gradient. Gradient
                 # functions should ignore the gradient for other outputs.
                 # TODO(apassos) gradients of resource handles might be an
                 # issue here because of zeros.
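In words: a missing (None) aggregated gradient for output i is now replaced with an explicit zero tensor not only when the output is trainable, but also whenever the op is a function call with no Python gradient function, because such calls fall back to SymbolicGradient and its inputs must all be real tensors. Below is a simplified, self-contained sketch of that decision (plain Python with hypothetical names, not TensorFlow code; the real check also treats empty per-output gradient lists as missing).

def should_zero_fill(out_grad, output_is_trainable, is_func_call,
                     has_python_grad_fn):
  """Return True if a missing output gradient must be replaced with zeros."""
  missing = out_grad is None
  # Function calls without a Python gradient function are differentiated via
  # SymbolicGradient, which needs one real gradient tensor per output.
  uses_symbolic_gradient = is_func_call and not has_python_grad_fn
  return missing and (uses_symbolic_gradient or output_is_trainable)


# A non-trainable output of a function call now gets a zero gradient ...
assert should_zero_fill(None, output_is_trainable=False,
                        is_func_call=True, has_python_grad_fn=False)
# ... while a non-trainable output of an ordinary op still does not.
assert not should_zero_fill(None, output_is_trainable=False,
                            is_func_call=False, has_python_grad_fn=False)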
@@ -670,15 +672,15 @@ def _UpdatePendingAndEnqueueReady(grads, op, queue, pending_count, loop_state):
         grad_state.pending_exits_count -= 1
         if grad_state.pending_exits_count == 0:
           # We now have all the exits so process them.
-          has_real_grad = False
+          has_not_none_grad = False
           for y in grad_state.deferred_exits:
             if _HasAnyNotNoneGrads(grads, y.op):
-              has_real_grad = True
+              has_not_none_grad = True
               queue.append(y.op)
             else:
               grad_state.unused_exits.append(y)
-          if has_real_grad:
-            # For an unused exit, if it has floating-point outputs, backprop
+          if has_not_none_grad:
+            # For an unused exit, if it has trainable outputs, backprop
             # a zero gradient. Otherwise, just ignore it.
             for y in grad_state.unused_exits:
               if _IsTrainable(y):