Compare base_dtype instead of dtype in piecewise_constant (#10280)

* Compare base_dtype instead of dtype in piecewise_constant

Compare base_dtype instead of dtype in piecewise_constant. Fix #10086

* add unit test

* Small lint fix and comment
This commit is contained in:
Yuxin Wu 2017-06-07 19:07:44 -07:00 committed by Jonathan Hseu
parent 7c46214abb
commit aff4d124b2
2 changed files with 9 additions and 4 deletions

View File

@ -138,17 +138,17 @@ def piecewise_constant(x, boundaries, values, name=None):
# comparisons, for example if floats are converted to integers.
boundaries = ops.convert_n_to_tensor(boundaries)
for b in boundaries:
if b.dtype != x.dtype:
if b.dtype.base_dtype != x.dtype.base_dtype:
raise ValueError(
"Boundaries (%s) must have the same dtype as x (%s)." % (
b.dtype, x.dtype))
b.dtype.base_dtype, x.dtype.base_dtype))
# TODO(rdipietro): Ensure that boundaries' elements are strictly increasing.
values = ops.convert_n_to_tensor(values)
for v in values[1:]:
if v.dtype != values[0].dtype:
if v.dtype.base_dtype != values[0].dtype.base_dtype:
raise ValueError(
"Values must have elements all with the same dtype (%s vs %s)." % (
values[0].dtype, v.dtype))
values[0].dtype.base_dtype, v.dtype.base_dtype))
pred_fn_pairs = {}
pred_fn_pairs[x <= boundaries[0]] = lambda: values[0]

View File

@ -113,6 +113,11 @@ class LRDecayTest(test_util.TensorFlowTestCase):
with self.assertRaises(ValueError):
learning_rate_decay.piecewise_constant(x, boundaries, values)
# Test that ref types are valid.
x_ref = x.op.outputs[0] # float32_ref tensor should be accepted
boundaries, values = [1.0, 2.0], [1, 2, 3]
learning_rate_decay.piecewise_constant(x_ref, boundaries, values)
class LinearDecayTest(test_util.TensorFlowTestCase):