[tf.data] Convert dataset arguments to tensors as early as possible.

This change raises a `TypeError` earlier if (for example) the `batch_size`
argument to `Dataset.batch()` has the incorrect type.

PiperOrigin-RevId: 173126678
This commit is contained in:
Derek Murray 2017-10-23 09:34:30 -07:00 committed by TensorFlower Gardener
parent 4f7503a876
commit fc56349b7f

View File

@ -1057,21 +1057,21 @@ class RangeDataset(Dataset):
def _parse_args(self, *args): def _parse_args(self, *args):
if len(args) == 1: if len(args) == 1:
self._start = self._build_tensor(0, "start") self._start = self._build_tensor(0, "start")
self._stop = args[0] self._stop = self._build_tensor(args[0], "stop")
self._step = self._build_tensor(1, "step") self._step = self._build_tensor(1, "step")
elif len(args) == 2: elif len(args) == 2:
self._start = args[0] self._start = self._build_tensor(args[0], "start")
self._stop = args[1] self._stop = self._build_tensor(args[1], "stop")
self._step = self._build_tensor(1, "step") self._step = self._build_tensor(1, "step")
elif len(args) == 3: elif len(args) == 3:
self._start = args[0] self._start = self._build_tensor(args[0], "start")
self._stop = args[1] self._stop = self._build_tensor(args[1], "stop")
self._step = args[2] self._step = self._build_tensor(args[2], "step")
else: else:
raise ValueError("Invalid arguments to RangeDataset: %s" % str(args)) raise ValueError("Invalid arguments to RangeDataset: %s" % str(args))
def _build_tensor(self, int64_value, name): def _build_tensor(self, int64_value, name):
return constant_op.constant(int64_value, dtype=dtypes.int64, name=name) return ops.convert_to_tensor(int64_value, dtype=dtypes.int64, name=name)
def _as_variant_tensor(self): def _as_variant_tensor(self):
return gen_dataset_ops.range_dataset( return gen_dataset_ops.range_dataset(
@ -1217,7 +1217,8 @@ class BatchDataset(Dataset):
"""See `Dataset.batch()` for details.""" """See `Dataset.batch()` for details."""
super(BatchDataset, self).__init__() super(BatchDataset, self).__init__()
self._input_dataset = input_dataset self._input_dataset = input_dataset
self._batch_size = batch_size self._batch_size = ops.convert_to_tensor(batch_size, dtype=dtypes.int64,
name="batch_size")
def _as_variant_tensor(self): def _as_variant_tensor(self):
return gen_dataset_ops.batch_dataset( return gen_dataset_ops.batch_dataset(
@ -1285,7 +1286,8 @@ class PaddedBatchDataset(Dataset):
"""See `Dataset.batch()` for details.""" """See `Dataset.batch()` for details."""
super(PaddedBatchDataset, self).__init__() super(PaddedBatchDataset, self).__init__()
self._input_dataset = input_dataset self._input_dataset = input_dataset
self._batch_size = batch_size self._batch_size = ops.convert_to_tensor(batch_size, dtype=dtypes.int64,
name="batch_size")
padding_values = (padding_values if padding_values is not None else padding_values = (padding_values if padding_values is not None else
self._default_padding(input_dataset)) self._default_padding(input_dataset))
self._padded_shapes = nest.map_structure_up_to( self._padded_shapes = nest.map_structure_up_to(
@ -1509,8 +1511,10 @@ class InterleaveDataset(Dataset):
self._map_func = tf_map_func self._map_func = tf_map_func
self._map_func.add_to_graph(ops.get_default_graph()) self._map_func.add_to_graph(ops.get_default_graph())
self._cycle_length = ops.convert_to_tensor(cycle_length, dtype=dtypes.int64) self._cycle_length = ops.convert_to_tensor(cycle_length, dtype=dtypes.int64,
self._block_length = ops.convert_to_tensor(block_length, dtype=dtypes.int64) name="cycle_length")
self._block_length = ops.convert_to_tensor(block_length, dtype=dtypes.int64,
name="block_length")
def _as_variant_tensor(self): def _as_variant_tensor(self):
return gen_dataset_ops.interleave_dataset( return gen_dataset_ops.interleave_dataset(
@ -1587,7 +1591,8 @@ class PrefetchDataset(Dataset):
"""See `Dataset.prefetch()` for details.""" """See `Dataset.prefetch()` for details."""
super(PrefetchDataset, self).__init__() super(PrefetchDataset, self).__init__()
self._input_dataset = input_dataset self._input_dataset = input_dataset
self._buffer_size = ops.convert_to_tensor(buffer_size, dtype=dtypes.int64) self._buffer_size = ops.convert_to_tensor(buffer_size, dtype=dtypes.int64,
name="buffer_size")
def _as_variant_tensor(self): def _as_variant_tensor(self):
return gen_dataset_ops.prefetch_dataset( return gen_dataset_ops.prefetch_dataset(