mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-07 00:20:20 +01:00
Compare base_dtype instead of dtype in piecewise_constant (#10280)
* Compare base_dtype instead of dtype in piecewise_constant Compare base_dtype instead of dtype in piecewise_constant. Fix #10086 * add unit test * Small lint fix and comment
This commit is contained in:
parent
7c46214abb
commit
aff4d124b2
|
|
@ -138,17 +138,17 @@ def piecewise_constant(x, boundaries, values, name=None):
|
|||
# comparisons, for example if floats are converted to integers.
|
||||
boundaries = ops.convert_n_to_tensor(boundaries)
|
||||
for b in boundaries:
|
||||
if b.dtype != x.dtype:
|
||||
if b.dtype.base_dtype != x.dtype.base_dtype:
|
||||
raise ValueError(
|
||||
"Boundaries (%s) must have the same dtype as x (%s)." % (
|
||||
b.dtype, x.dtype))
|
||||
b.dtype.base_dtype, x.dtype.base_dtype))
|
||||
# TODO(rdipietro): Ensure that boundaries' elements are strictly increasing.
|
||||
values = ops.convert_n_to_tensor(values)
|
||||
for v in values[1:]:
|
||||
if v.dtype != values[0].dtype:
|
||||
if v.dtype.base_dtype != values[0].dtype.base_dtype:
|
||||
raise ValueError(
|
||||
"Values must have elements all with the same dtype (%s vs %s)." % (
|
||||
values[0].dtype, v.dtype))
|
||||
values[0].dtype.base_dtype, v.dtype.base_dtype))
|
||||
|
||||
pred_fn_pairs = {}
|
||||
pred_fn_pairs[x <= boundaries[0]] = lambda: values[0]
|
||||
|
|
|
|||
|
|
@ -113,6 +113,11 @@ class LRDecayTest(test_util.TensorFlowTestCase):
|
|||
with self.assertRaises(ValueError):
|
||||
learning_rate_decay.piecewise_constant(x, boundaries, values)
|
||||
|
||||
# Test that ref types are valid.
|
||||
x_ref = x.op.outputs[0] # float32_ref tensor should be accepted
|
||||
boundaries, values = [1.0, 2.0], [1, 2, 3]
|
||||
learning_rate_decay.piecewise_constant(x_ref, boundaries, values)
|
||||
|
||||
|
||||
class LinearDecayTest(test_util.TensorFlowTestCase):
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user