Make learning_rate_decay.piecewise_constant work in Eager mode.

PiperOrigin-RevId: 173967531
This commit is contained in:
A. Unique TensorFlower 2017-10-30 16:25:14 -07:00 committed by TensorFlower Gardener
parent 0e6abfcdaf
commit 293ba20be1
2 changed files with 59 additions and 58 deletions

View File

@ -27,6 +27,7 @@ from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
def exponential_decay(learning_rate, global_step, decay_steps, decay_rate,
staircase=False, name=None):
"""Applies exponential decay to the learning rate.
@ -164,13 +165,13 @@ def piecewise_constant(x, boundaries, values, name=None):
raise ValueError(
"Values must have elements all with the same dtype (%s vs %s)." % (
values[0].dtype.base_dtype, v.dtype.base_dtype))
pred_fn_pairs = {}
pred_fn_pairs[x <= boundaries[0]] = lambda: values[0]
pred_fn_pairs[x > boundaries[-1]] = lambda: values[-1]
pred_fn_pairs = []
pred_fn_pairs.append((x <= boundaries[0], lambda: values[0]))
pred_fn_pairs.append((x > boundaries[-1], lambda: values[-1]))
for low, high, v in zip(boundaries[:-1], boundaries[1:], values[1:-1]):
# Need to bind v here; can do this with lambda v=v: ...
pred = (x > low) & (x <= high)
pred_fn_pairs[pred] = lambda v=v: v
pred_fn_pairs.append((pred, lambda v=v: v))
# The default isn't needed here because our conditions are mutually
# exclusive and exhaustive, but tf.case requires it.

View File

@ -20,6 +20,7 @@ from __future__ import print_function
import math
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util
from tensorflow.python.ops import gen_state_ops
@ -43,7 +44,7 @@ class LRDecayTest(test_util.TensorFlowTestCase):
def testStaircase(self):
with self.test_session():
step = gen_state_ops._variable(shape=[], dtype=dtypes.int32,
name="step", container="", shared_name="")
name="step", container="", shared_name="")
assign_100 = state_ops.assign(step, 100)
assign_1 = state_ops.assign(step, 1)
assign_2 = state_ops.assign(step, 2)
@ -78,65 +79,63 @@ class LRDecayTest(test_util.TensorFlowTestCase):
expected = .1 * 0.96 ** (100 // 3)
self.assertAllClose(decayed_lr.eval(), expected, 1e-6)
@test_util.run_in_graph_and_eager_modes()
def testPiecewiseConstant(self):
with self.test_session():
x = variables.Variable(-999)
assign_100 = x.assign(100)
assign_105 = x.assign(105)
assign_110 = x.assign(110)
assign_120 = x.assign(120)
assign_999 = x.assign(999)
pc = learning_rate_decay.piecewise_constant(x, [100, 110, 120],
[1.0, 0.1, 0.01, 0.001])
x = resource_variable_ops.ResourceVariable(-999)
def pc():
return learning_rate_decay.piecewise_constant(x, [100, 110, 120],
[1.0, 0.1, 0.01, 0.001])
variables.global_variables_initializer().run()
self.assertAllClose(pc.eval(), 1.0, 1e-6)
assign_100.op.run()
self.assertAllClose(pc.eval(), 1.0, 1e-6)
assign_105.op.run()
self.assertAllClose(pc.eval(), 0.1, 1e-6)
assign_110.op.run()
self.assertAllClose(pc.eval(), 0.1, 1e-6)
assign_120.op.run()
self.assertAllClose(pc.eval(), 0.01, 1e-6)
assign_999.op.run()
self.assertAllClose(pc.eval(), 0.001, 1e-6)
self.evaluate(variables.global_variables_initializer())
self.assertAllClose(self.evaluate(pc()), 1.0, 1e-6)
self.evaluate(x.assign(100))
self.assertAllClose(self.evaluate(pc()), 1.0, 1e-6)
self.evaluate(x.assign(105))
self.assertAllClose(self.evaluate(pc()), 0.1, 1e-6)
self.evaluate(x.assign(110))
self.assertAllClose(self.evaluate(pc()), 0.1, 1e-6)
self.evaluate(x.assign(120))
self.assertAllClose(self.evaluate(pc()), 0.01, 1e-6)
self.evaluate(x.assign(999))
self.assertAllClose(self.evaluate(pc()), 0.001, 1e-6)
@test_util.run_in_graph_and_eager_modes()
def testPiecewiseConstantEdgeCases(self):
with self.test_session():
x_int = variables.Variable(0, dtype=variables.dtypes.int32)
boundaries, values = [-1.0, 1.0], [1, 2, 3]
with self.assertRaises(ValueError):
learning_rate_decay.piecewise_constant(x_int, boundaries, values)
x = variables.Variable(0.0)
boundaries, values = [-1.0, 1.0], [1.0, 2, 3]
with self.assertRaises(ValueError):
learning_rate_decay.piecewise_constant(x, boundaries, values)
x_int = resource_variable_ops.ResourceVariable(
0, dtype=variables.dtypes.int32)
boundaries, values = [-1.0, 1.0], [1, 2, 3]
with self.assertRaises(ValueError):
learning_rate_decay.piecewise_constant(x_int, boundaries, values)
x = resource_variable_ops.ResourceVariable(0.0)
boundaries, values = [-1.0, 1.0], [1.0, 2, 3]
with self.assertRaises(ValueError):
learning_rate_decay.piecewise_constant(x, boundaries, values)
# Test that ref types are valid.
# Test that ref types are valid.
if context.in_graph_mode():
x = variables.Variable(0.0)
x_ref = x.op.outputs[0] # float32_ref tensor should be accepted
boundaries, values = [1.0, 2.0], [1, 2, 3]
learning_rate_decay.piecewise_constant(x_ref, boundaries, values)
# Test casting boundaries from int32 to int64.
x_int64 = variables.Variable(0, dtype=variables.dtypes.int64)
assign_1 = x_int64.assign(1)
assign_2 = x_int64.assign(2)
assign_3 = x_int64.assign(3)
assign_4 = x_int64.assign(4)
boundaries, values = [1, 2, 3], [0.4, 0.5, 0.6, 0.7]
pc = learning_rate_decay.piecewise_constant(x_int64, boundaries, values)
# Test casting boundaries from int32 to int64.
x_int64 = resource_variable_ops.ResourceVariable(
0, dtype=variables.dtypes.int64)
boundaries, values = [1, 2, 3], [0.4, 0.5, 0.6, 0.7]
def pc():
return learning_rate_decay.piecewise_constant(x_int64, boundaries, values)
variables.global_variables_initializer().run()
self.assertAllClose(pc.eval(), 0.4, 1e-6)
assign_1.op.run()
self.assertAllClose(pc.eval(), 0.4, 1e-6)
assign_2.op.run()
self.assertAllClose(pc.eval(), 0.5, 1e-6)
assign_3.op.run()
self.assertAllClose(pc.eval(), 0.6, 1e-6)
assign_4.op.run()
self.assertAllClose(pc.eval(), 0.7, 1e-6)
self.evaluate(variables.global_variables_initializer())
self.assertAllClose(self.evaluate(pc()), 0.4, 1e-6)
self.evaluate(x_int64.assign(1))
self.assertAllClose(self.evaluate(pc()), 0.4, 1e-6)
self.evaluate(x_int64.assign(2))
self.assertAllClose(self.evaluate(pc()), 0.5, 1e-6)
self.evaluate(x_int64.assign(3))
self.assertAllClose(self.evaluate(pc()), 0.6, 1e-6)
self.evaluate(x_int64.assign(4))
self.assertAllClose(self.evaluate(pc()), 0.7, 1e-6)
class LinearDecayTest(test_util.TensorFlowTestCase):
@ -245,6 +244,7 @@ class SqrtDecayTest(test_util.TensorFlowTestCase):
expected = (lr - end_lr) * 0.25 ** power + end_lr
self.assertAllClose(decayed_lr.eval(), expected, 1e-6)
class PolynomialDecayTest(test_util.TensorFlowTestCase):
def testBeginWithCycle(self):
@ -265,7 +265,7 @@ class ExponentialDecayTest(test_util.TensorFlowTestCase):
k = 10
decay_rate = 0.96
step = gen_state_ops._variable(shape=[], dtype=dtypes.int32,
name="step", container="", shared_name="")
name="step", container="", shared_name="")
assign_step = state_ops.assign(step, 0)
increment_step = state_ops.assign_add(step, 1)
decayed_lr = learning_rate_decay.natural_exp_decay(initial_lr, step,
@ -282,7 +282,7 @@ class ExponentialDecayTest(test_util.TensorFlowTestCase):
k = 10
decay_rate = 0.96
step = gen_state_ops._variable(shape=[], dtype=dtypes.int32,
name="step", container="", shared_name="")
name="step", container="", shared_name="")
assign_step = state_ops.assign(step, 0)
increment_step = state_ops.assign_add(step, 1)
decayed_lr = learning_rate_decay.natural_exp_decay(initial_lr,
@ -305,7 +305,7 @@ class InverseDecayTest(test_util.TensorFlowTestCase):
k = 10
decay_rate = 0.96
step = gen_state_ops._variable(shape=[], dtype=dtypes.int32,
name="step", container="", shared_name="")
name="step", container="", shared_name="")
assign_step = state_ops.assign(step, 0)
increment_step = state_ops.assign_add(step, 1)
decayed_lr = learning_rate_decay.inverse_time_decay(initial_lr,
@ -324,7 +324,7 @@ class InverseDecayTest(test_util.TensorFlowTestCase):
k = 10
decay_rate = 0.96
step = gen_state_ops._variable(shape=[], dtype=dtypes.int32,
name="step", container="", shared_name="")
name="step", container="", shared_name="")
assign_step = state_ops.assign(step, 0)
increment_step = state_ops.assign_add(step, 1)
decayed_lr = learning_rate_decay.inverse_time_decay(initial_lr,