Mirror of https://github.com/zebrajr/tensorflow.git, synced 2025-12-07 00:20:20 +01:00
Support SymbolicGradient for functions with non-trainable arguments.
The non-trainable arguments end up with None as their incoming out_grad, which is not a valid input to SymbolicGradient (inputs have to be convertible to Tensor, and None isn't).

PiperOrigin-RevId: 173901727
This commit is contained in:
parent 494672475b
commit 4723f8f6ed
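As a rough illustration of the failure mode described above (a minimal sketch in the TF 1.x Python API; it mirrors the test added by this commit and is not itself part of the diff):

  from tensorflow.python.framework import constant_op
  from tensorflow.python.framework import dtypes
  from tensorflow.python.framework import function
  from tensorflow.python.framework import ops
  from tensorflow.python.ops import array_ops
  from tensorflow.python.ops import gradients_impl

  # A function with a non-trainable (int32) argument.
  @function.Defun(dtypes.int32, dtypes.float32)
  def Foo(t, x):
    return x[t]

  g = ops.Graph()
  with g.as_default():
    inp = array_ops.placeholder(dtypes.float32)
    t = constant_op.constant(0, dtypes.int32)
    out = Foo(t, inp)
    # Before this change, differentiating through the Foo call could fail:
    # the non-trainable argument ended up with None as its incoming
    # out_grad, and None cannot be converted to a Tensor input for the
    # SymbolicGradient op.  With this change, a zero tensor is substituted.
    dinp, = gradients_impl.gradients(out, [inp])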
@@ -864,6 +864,24 @@ class FunctionTest(test.TestCase):
         [result])
     self.assertEqual(len(f.signature.input_arg), 3)
 
+  def testGradientWithIntegerFunctionArgument(self):
+    @function.Defun(dtypes.int32, dtypes.float32)
+    def Foo(t, x):
+      return x[t]
+
+    g = ops.Graph()
+    with g.as_default():
+      inp = array_ops.placeholder(dtypes.float32)
+      t = constant_op.constant(0, dtypes.int32)
+      out = Foo(t, inp)
+      dinp, = gradients_impl.gradients(out, [inp])
+
+    x = np.zeros((2,)).astype(np.float32)
+    with session.Session(graph=g) as sess:
+      self.assertAllClose(
+          np.array([1.0, 0.0]).astype(np.float32),
+          sess.run(dinp, {inp: x}))
+
 
 @test_util.with_c_api
 class FunctionsFromProtos(test.TestCase):
@@ -582,8 +582,10 @@ def gradients(ys,
             # therefore dC/doutput[i] is 0.
             for i, out_grad in enumerate(out_grads):
               if (not isinstance(out_grad, ops.Tensor) and
-                  not out_grad) and _IsTrainable(op.outputs[i]):
-                # Only floating-point outputs get a zero gradient. Gradient
+                  not out_grad) and ((not grad_fn and is_func_call) or
+                                     _IsTrainable(op.outputs[i])):
+                # Only trainable outputs or outputs for a function call that
+                # will use SymbolicGradient get a zero gradient. Gradient
                 # functions should ignore the gradient for other outputs.
                 # TODO(apassos) gradients of resource handles might be an
                 # issue here because of zeros.
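Read on its own, the new condition amounts to the predicate sketched below (an illustrative paraphrase, not code from this commit; _needs_zero_grad is a hypothetical name, while ops and _IsTrainable are the surrounding module's own symbols):

  def _needs_zero_grad(out_grad, output, grad_fn, is_func_call):
    # The aggregated gradient for this output is "missing" when it is None
    # (or an empty list) rather than a Tensor.
    missing = not isinstance(out_grad, ops.Tensor) and not out_grad
    # Zeros are substituted for missing gradients of trainable
    # (floating-point) outputs, and now also for every output of a function
    # call op with no Python grad_fn: such ops are differentiated with
    # SymbolicGradient, whose inputs must all be convertible to Tensor, so
    # None is not allowed.
    return missing and ((not grad_fn and is_func_call) or
                        _IsTrainable(output))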
@@ -670,15 +672,15 @@ def _UpdatePendingAndEnqueueReady(grads, op, queue, pending_count, loop_state):
       grad_state.pending_exits_count -= 1
       if grad_state.pending_exits_count == 0:
         # We now have all the exits so process them.
-        has_real_grad = False
+        has_not_none_grad = False
         for y in grad_state.deferred_exits:
           if _HasAnyNotNoneGrads(grads, y.op):
-            has_real_grad = True
+            has_not_none_grad = True
             queue.append(y.op)
           else:
             grad_state.unused_exits.append(y)
-        if has_real_grad:
-          # For an unused exit, if it has floating-point outputs, backprop
+        if has_not_none_grad:
+          # For an unused exit, if it has trainable outputs, backprop
           # a zero gradient. Otherwise, just ignore it.
           for y in grad_state.unused_exits:
             if _IsTrainable(y):