diff --git a/RELEASE.md b/RELEASE.md index c716a7f36de..876ef314f1a 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -141,6 +141,10 @@ This release contains contributions from many people at Google, as well as: file is a protobuf containing the "fingerprint" of the SavedModel. See the [RFC](https://github.com/tensorflow/community/pull/415) for more details regarding its design and properties. + +* `tf.data`: + * Graduated experimental APIs: + * [`tf.data.Dataset.ragged_batch`](https://www.tensorflow.org/api_docs/python/tf/data/Dataset/#ragged_batch), which batches elements of `tf.data.Dataset`s into `tf.RaggedTensor`s. ## Bug Fixes and Other Changes diff --git a/tensorflow/python/data/experimental/kernel_tests/BUILD b/tensorflow/python/data/experimental/kernel_tests/BUILD index a254ae26f5f..a86a682ece1 100644 --- a/tensorflow/python/data/experimental/kernel_tests/BUILD +++ b/tensorflow/python/data/experimental/kernel_tests/BUILD @@ -147,22 +147,6 @@ tf_py_test( ], ) -tf_py_test( - name = "dense_to_ragged_batch_test", - size = "small", - srcs = ["dense_to_ragged_batch_test.py"], - shard_count = 4, - deps = [ - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:dtypes", - "//tensorflow/python:errors", - "//tensorflow/python/data/kernel_tests:test_base", - "//tensorflow/python/data/ops:dataset_ops", - "//third_party/py/numpy", - ], -) - tf_py_test( name = "dense_to_sparse_batch_test", size = "medium", diff --git a/tensorflow/python/data/experimental/ops/batching.py b/tensorflow/python/data/experimental/ops/batching.py index da58ff1f8dc..1bc965d88fb 100644 --- a/tensorflow/python/data/experimental/ops/batching.py +++ b/tensorflow/python/data/experimental/ops/batching.py @@ -13,7 +13,6 @@ # limitations under the License. # ============================================================================== """Batching dataset transformations.""" -from tensorflow.python.data.ops import batch_op from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import structured_function from tensorflow.python.data.util import convert @@ -22,15 +21,14 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import tensor_shape -from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import tensor_util from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops -from tensorflow.python.ops.ragged import ragged_tensor from tensorflow.python.util import deprecation from tensorflow.python.util.tf_export import tf_export @tf_export("data.experimental.dense_to_ragged_batch") +@deprecation.deprecated(None, "Use `tf.data.Dataset.ragged_batch` instead.") def dense_to_ragged_batch(batch_size, drop_remainder=False, row_splits_dtype=dtypes.int64): @@ -87,11 +85,8 @@ def dense_to_ragged_batch(batch_size, Returns: Dataset: A `Dataset`. """ - def _apply_fn(dataset): - ragged_dataset = _DenseToRaggedDataset(dataset, row_splits_dtype) - return batch_op.BatchDataset( - ragged_dataset, batch_size=batch_size, drop_remainder=drop_remainder) + return dataset.ragged_batch(batch_size, drop_remainder, row_splits_dtype) return _apply_fn @@ -381,78 +376,3 @@ class _MapAndBatchDataset(dataset_ops.UnaryDataset): @property def element_spec(self): return self._element_spec - - -class _DenseToRaggedDataset(dataset_ops.UnaryDataset): - """A `Dataset` that encodes dense inputs as ragged (w/ ragged_rank=0). - - In particular: - - * Any tf.Tensor elements with rank>0 are encoded as ragged tensors with - ragged_rank=0. This allows tensors with varying shape to be batched - together. - * Any other elements are left as-is. - """ - - def __init__(self, input_dataset, row_splits_dtype): - """Constructs a new _DenseToRaggedDataset. - - Args: - input_dataset: The dataset whose tf.Tensor elements should be made ragged. - row_splits_dtype: The dtype that should be used for the `row_splits` of - any new ragged tensors. Existing `tf.RaggedTensor` elements do *not* - have their row_splits dtype changed. - """ - # Replace each TensorSpec in the input dataset's structure with a - # corresponding RaggedTensorSpec. - def to_ragged_spec(spec): - """Returns the new spec based on RaggedTensors.""" - if (not isinstance(spec, tensor_spec.TensorSpec) or - spec.shape.rank is None or - spec.shape.is_fully_defined()): - return spec - else: - ragged_rank = max([ - axis for (axis, size) in enumerate(spec.shape.as_list()) - if size is None - ]) - return ragged_tensor.RaggedTensorSpec( - shape=spec.shape, - dtype=spec.dtype, - ragged_rank=ragged_rank, - row_splits_dtype=row_splits_dtype) - - self._structure = nest.map_structure(to_ragged_spec, - input_dataset.element_spec) - - # Replace each tf.Tensor value in the input dataset with a variant-encoded - # RaggedTensor. Since we're updating the corresponding structure to be - # a RaggedTensorSpec, this variant-encoded tensor will be decoded with - # RaggedTensorSpec._from_tensor_list. - def to_ragged_variant(value): - """Re-encode Tensors as RaggedTensors.""" - if (not isinstance(value, ops.Tensor) or - value.shape.rank is None or - value.shape.is_fully_defined()): - return value - else: - spec = to_ragged_spec(tensor_spec.TensorSpec.from_tensor(value)) - if spec._ragged_rank > 0: # pylint: disable=protected-access - value = ragged_tensor.RaggedTensor.from_tensor( - value, ragged_rank=spec._ragged_rank) # pylint: disable=protected-access - return spec._to_tensor_list(value)[0] # pylint: disable=protected-access - - # Tuples are automatically unpacked by `dataset.map` so we repack them. - if structured_function._should_unpack(input_dataset.element_spec): # pylint: disable=protected-access - map_fn = lambda *value: nest.map_structure(to_ragged_variant, value) - else: - map_fn = lambda value: nest.map_structure(to_ragged_variant, value) - - self._mapped_dataset = input_dataset.map(map_fn) - - variant = self._mapped_dataset._variant_tensor # pylint: disable=protected-access - super(_DenseToRaggedDataset, self).__init__(input_dataset, variant) - - @property - def element_spec(self): - return self._structure diff --git a/tensorflow/python/data/kernel_tests/BUILD b/tensorflow/python/data/kernel_tests/BUILD index 607bcd5680c..33dac535c12 100644 --- a/tensorflow/python/data/kernel_tests/BUILD +++ b/tensorflow/python/data/kernel_tests/BUILD @@ -754,6 +754,22 @@ tf_py_test( ], ) +tf_py_test( + name = "ragged_batch_test", + size = "small", + srcs = ["ragged_batch_test.py"], + shard_count = 4, + deps = [ + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:dtypes", + "//tensorflow/python:errors", + "//tensorflow/python/data/kernel_tests:test_base", + "//tensorflow/python/data/ops:dataset_ops", + "//third_party/py/numpy", + ], +) + tf_py_test( name = "random_test", size = "small", diff --git a/tensorflow/python/data/experimental/kernel_tests/dense_to_ragged_batch_test.py b/tensorflow/python/data/kernel_tests/ragged_batch_test.py similarity index 95% rename from tensorflow/python/data/experimental/kernel_tests/dense_to_ragged_batch_test.py rename to tensorflow/python/data/kernel_tests/ragged_batch_test.py index d2ada96b086..7791cc0da55 100644 --- a/tensorflow/python/data/experimental/kernel_tests/dense_to_ragged_batch_test.py +++ b/tensorflow/python/data/kernel_tests/ragged_batch_test.py @@ -12,11 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for `tf.data.experimental.dense_to_ragged_batch`.""" +"""Tests for `tf.data.Dataset.ragged_batch`.""" from absl.testing import parameterized import numpy as np -from tensorflow.python.data.experimental.ops import batching from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.util import nest @@ -118,8 +117,7 @@ class RaggedBatchTest(test_base.DatasetTestBase, parameterized.TestCase): for _ in range(nrows)] # Batch the dataset, and check that batches match slices from `rows`. - batched_dataset = dataset.apply( - batching.dense_to_ragged_batch(batch_size, drop_remainder)) + batched_dataset = dataset.ragged_batch(batch_size, drop_remainder) get_next = self.getNext(batched_dataset) for start_row in range(0, nrows, batch_size): end_row = start_row + batch_size @@ -155,7 +153,7 @@ class RaggedBatchTest(test_base.DatasetTestBase, parameterized.TestCase): dataset = dataset_ops.Dataset.from_tensor_slices(np.arange(nrows)) dataset = dataset.map(make_structure) - dataset = dataset.apply(batching.dense_to_ragged_batch(batch_size)) + dataset = dataset.ragged_batch(batch_size) get_next = self.getNext(dataset) for i in range(0, nrows, batch_size): diff --git a/tensorflow/python/data/ops/BUILD b/tensorflow/python/data/ops/BUILD index d135909cc27..c3980e0b727 100644 --- a/tensorflow/python/data/ops/BUILD +++ b/tensorflow/python/data/ops/BUILD @@ -53,6 +53,7 @@ py_library( ":iterator_ops", ":load_ops", ":options", + ":ragged_batch_op", ":rebatch_op", ":save_ops", ":structured_function", @@ -137,6 +138,23 @@ py_library( ], ) +py_library( + name = "ragged_batch_op", + srcs = ["ragged_batch_op.py"], + srcs_version = "PY3", + deps = [ + "//tensorflow/python:dtypes", + "//tensorflow/python:experimental_dataset_ops_gen", + "//tensorflow/python:framework_ops", + "//tensorflow/python:tensor_shape", + "//tensorflow/python:tensor_util", + "//tensorflow/python:util", + "//tensorflow/python/data/util:convert", + "//tensorflow/python/data/util:nest", + "//tensorflow/python/data/util:structure", + ], +) + py_library( name = "save_ops", srcs = ["save_op.py"], diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py index d81d0027871..fcc742010c7 100644 --- a/tensorflow/python/data/ops/dataset_ops.py +++ b/tensorflow/python/data/ops/dataset_ops.py @@ -2126,6 +2126,67 @@ class DatasetV2( drop_remainder, name=name) + def ragged_batch(self, + batch_size, + drop_remainder=False, + row_splits_dtype=dtypes.int64, + name=None): + """Combines consecutive elements of this dataset into `tf.RaggedTensor`s. + + Like `tf.data.Dataset.batch`, the components of the resulting element will + have an additional outer dimension, which will be `batch_size` (or + `N % batch_size` for the last element if `batch_size` does not divide the + number of input elements `N` evenly and `drop_remainder` is `False`). If + your program depends on the batches having the same outer dimension, you + should set the `drop_remainder` argument to `True` to prevent the smaller + batch from being produced. + + Unlike `tf.data.Dataset.batch`, the input elements to be batched may have + different shapes: + + * If an input element is a `tf.Tensor` whose static `tf.TensorShape` is + fully defined, then it is batched as normal. + * If an input element is a `tf.Tensor` whose static `tf.TensorShape` + contains one or more axes with unknown size (i.e., `shape[i]=None`), then + the output will contain a `tf.RaggedTensor` that is ragged up to any of such + dimensions. + * If an input element is a `tf.RaggedTensor` or any other type, then it is + batched as normal. + + Example: + + >>> dataset = tf.data.Dataset.range(6) + >>> dataset = dataset.map(lambda x: tf.range(x)) + >>> dataset.element_spec.shape + TensorShape([None]) + >>> dataset = dataset.ragged_batch(2) + >>> for batch in dataset: + ... print(batch) + + + + + Args: + batch_size: A `tf.int64` scalar `tf.Tensor`, representing the number of + consecutive elements of this dataset to combine in a single batch. + drop_remainder: (Optional.) A `tf.bool` scalar `tf.Tensor`, representing + whether the last batch should be dropped in the case it has fewer than + `batch_size` elements; the default behavior is not to drop the smaller + batch. + row_splits_dtype: The dtype that should be used for the `row_splits` of + any new ragged tensors. Existing `tf.RaggedTensor` elements do not have + their row_splits dtype changed. + name: (Optional.) A string indicating a name for the `tf.data` operation. + + Returns: + A new `Dataset` with the transformation applied as described above. + """ + # Loaded lazily due to a circular dependency (dataset_ops -> + # ragged_batch_op -> dataset_ops). + from tensorflow.python.data.ops import ragged_batch_op # pylint: disable=g-import-not-at-top + return ragged_batch_op.ragged_batch(self, batch_size, drop_remainder, + row_splits_dtype, name) + def map(self, map_func, num_parallel_calls=None, diff --git a/tensorflow/python/data/ops/ragged_batch_op.py b/tensorflow/python/data/ops/ragged_batch_op.py new file mode 100644 index 00000000000..1868925873d --- /dev/null +++ b/tensorflow/python/data/ops/ragged_batch_op.py @@ -0,0 +1,109 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""The implementation of `tf.data.Dataset.ragged_batch`.""" +from tensorflow.python.data.ops import batch_op +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.ops import structured_function +from tensorflow.python.data.util import nest +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_spec +from tensorflow.python.ops.ragged import ragged_tensor + + +def ragged_batch(self, + batch_size, + drop_remainder=False, + row_splits_dtype=dtypes.int64, + name=None): + ragged_dataset = _DenseToRaggedDataset(self, row_splits_dtype, name) + return batch_op.BatchDataset( + ragged_dataset, batch_size=batch_size, drop_remainder=drop_remainder) + + +class _DenseToRaggedDataset(dataset_ops.UnaryDataset): + """A `Dataset` that encodes dense inputs as ragged (w/ ragged_rank=0). + + In particular: + + * Any tf.Tensor elements with rank>0 are encoded as ragged tensors with + ragged_rank=0. This allows tensors with varying shape to be batched + together. + * Any other elements are left as-is. + """ + + def __init__(self, input_dataset, row_splits_dtype, name=None): + """Constructs a new _DenseToRaggedDataset. + + Args: + input_dataset: The dataset whose tf.Tensor elements should be made ragged. + row_splits_dtype: The dtype that should be used for the `row_splits` of + any new ragged tensors. Existing `tf.RaggedTensor` elements do *not* + have their row_splits dtype changed. + name: (Optional.) A string indicating a name for the `tf.data` operation. + """ + # Replace each TensorSpec in the input dataset's structure with a + # corresponding RaggedTensorSpec. + def to_ragged_spec(spec): + """Returns the new spec based on RaggedTensors.""" + if (not isinstance(spec, tensor_spec.TensorSpec) or + spec.shape.rank is None or + spec.shape.is_fully_defined()): + return spec + else: + ragged_rank = max([ + axis for (axis, size) in enumerate(spec.shape.as_list()) + if size is None + ]) + return ragged_tensor.RaggedTensorSpec( + shape=spec.shape, + dtype=spec.dtype, + ragged_rank=ragged_rank, + row_splits_dtype=row_splits_dtype) + + self._structure = nest.map_structure(to_ragged_spec, + input_dataset.element_spec) + + # Replace each tf.Tensor value in the input dataset with a variant-encoded + # RaggedTensor. Since we're updating the corresponding structure to be + # a RaggedTensorSpec, this variant-encoded tensor will be decoded with + # RaggedTensorSpec._from_tensor_list. + def to_ragged_variant(value): + """Re-encode Tensors as RaggedTensors.""" + if (not isinstance(value, ops.Tensor) or + value.shape.rank is None or + value.shape.is_fully_defined()): + return value + else: + spec = to_ragged_spec(tensor_spec.TensorSpec.from_tensor(value)) + if spec._ragged_rank > 0: # pylint: disable=protected-access + value = ragged_tensor.RaggedTensor.from_tensor( + value, ragged_rank=spec._ragged_rank) # pylint: disable=protected-access + return spec._to_tensor_list(value)[0] # pylint: disable=protected-access + + # Tuples are automatically unpacked by `dataset.map` so we repack them. + if structured_function._should_unpack(input_dataset.element_spec): # pylint: disable=protected-access + map_fn = lambda *value: nest.map_structure(to_ragged_variant, value) + else: + map_fn = lambda value: nest.map_structure(to_ragged_variant, value) + + self._mapped_dataset = input_dataset.map(map_fn) + self._name = name + variant = self._mapped_dataset._variant_tensor # pylint: disable=protected-access + super(_DenseToRaggedDataset, self).__init__(input_dataset, variant) + + @property + def element_spec(self): + return self._structure diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.-dataset.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.-dataset.pbtxt index 09793fd81ce..cd18d209660 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.data.-dataset.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.data.-dataset.pbtxt @@ -143,6 +143,10 @@ tf_class { name: "prefetch" argspec: "args=[\'self\', \'buffer_size\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } + member_method { + name: "ragged_batch" + argspec: "args=[\'self\', \'batch_size\', \'drop_remainder\', \'row_splits_dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \"\", \'None\'], " + } member_method { name: "random" argspec: "args=[\'seed\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.-fixed-length-record-dataset.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.-fixed-length-record-dataset.pbtxt index 50c35bcbcdd..51740951cec 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.data.-fixed-length-record-dataset.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.data.-fixed-length-record-dataset.pbtxt @@ -145,6 +145,10 @@ tf_class { name: "prefetch" argspec: "args=[\'self\', \'buffer_size\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } + member_method { + name: "ragged_batch" + argspec: "args=[\'self\', \'batch_size\', \'drop_remainder\', \'row_splits_dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \"\", \'None\'], " + } member_method { name: "random" argspec: "args=[\'seed\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.-t-f-record-dataset.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.-t-f-record-dataset.pbtxt index 7c4b9fb2562..319ffa68ffb 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.data.-t-f-record-dataset.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.data.-t-f-record-dataset.pbtxt @@ -145,6 +145,10 @@ tf_class { name: "prefetch" argspec: "args=[\'self\', \'buffer_size\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } + member_method { + name: "ragged_batch" + argspec: "args=[\'self\', \'batch_size\', \'drop_remainder\', \'row_splits_dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \"\", \'None\'], " + } member_method { name: "random" argspec: "args=[\'seed\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.-text-line-dataset.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.-text-line-dataset.pbtxt index a5966884153..d303899d9f7 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.data.-text-line-dataset.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.data.-text-line-dataset.pbtxt @@ -145,6 +145,10 @@ tf_class { name: "prefetch" argspec: "args=[\'self\', \'buffer_size\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } + member_method { + name: "ragged_batch" + argspec: "args=[\'self\', \'batch_size\', \'drop_remainder\', \'row_splits_dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \"\", \'None\'], " + } member_method { name: "random" argspec: "args=[\'seed\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-csv-dataset.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-csv-dataset.pbtxt index a5817d5d09d..b4adfd7c94c 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-csv-dataset.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-csv-dataset.pbtxt @@ -145,6 +145,10 @@ tf_class { name: "prefetch" argspec: "args=[\'self\', \'buffer_size\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } + member_method { + name: "ragged_batch" + argspec: "args=[\'self\', \'batch_size\', \'drop_remainder\', \'row_splits_dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \"\", \'None\'], " + } member_method { name: "random" argspec: "args=[\'seed\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-random-dataset.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-random-dataset.pbtxt index bd4abcbb0a1..b347b1bcdda 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-random-dataset.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-random-dataset.pbtxt @@ -145,6 +145,10 @@ tf_class { name: "prefetch" argspec: "args=[\'self\', \'buffer_size\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } + member_method { + name: "ragged_batch" + argspec: "args=[\'self\', \'batch_size\', \'drop_remainder\', \'row_splits_dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \"\", \'None\'], " + } member_method { name: "random" argspec: "args=[\'seed\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-sql-dataset.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-sql-dataset.pbtxt index e6883494d8b..318f7746579 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-sql-dataset.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-sql-dataset.pbtxt @@ -145,6 +145,10 @@ tf_class { name: "prefetch" argspec: "args=[\'self\', \'buffer_size\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } + member_method { + name: "ragged_batch" + argspec: "args=[\'self\', \'batch_size\', \'drop_remainder\', \'row_splits_dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \"\", \'None\'], " + } member_method { name: "random" argspec: "args=[\'seed\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.-dataset.pbtxt index 6f93210ff22..109736eeca2 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.data.-dataset.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.data.-dataset.pbtxt @@ -110,6 +110,10 @@ tf_class { name: "prefetch" argspec: "args=[\'self\', \'buffer_size\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } + member_method { + name: "ragged_batch" + argspec: "args=[\'self\', \'batch_size\', \'drop_remainder\', \'row_splits_dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \"\", \'None\'], " + } member_method { name: "random" argspec: "args=[\'seed\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.-fixed-length-record-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.-fixed-length-record-dataset.pbtxt index 5bd5aa8647e..4bf9ff2ea16 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.data.-fixed-length-record-dataset.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.data.-fixed-length-record-dataset.pbtxt @@ -112,6 +112,10 @@ tf_class { name: "prefetch" argspec: "args=[\'self\', \'buffer_size\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } + member_method { + name: "ragged_batch" + argspec: "args=[\'self\', \'batch_size\', \'drop_remainder\', \'row_splits_dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \"\", \'None\'], " + } member_method { name: "random" argspec: "args=[\'seed\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.-t-f-record-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.-t-f-record-dataset.pbtxt index 4822e276b94..f29a2a38568 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.data.-t-f-record-dataset.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.data.-t-f-record-dataset.pbtxt @@ -111,6 +111,10 @@ tf_class { name: "prefetch" argspec: "args=[\'self\', \'buffer_size\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } + member_method { + name: "ragged_batch" + argspec: "args=[\'self\', \'batch_size\', \'drop_remainder\', \'row_splits_dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \"\", \'None\'], " + } member_method { name: "random" argspec: "args=[\'seed\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.-text-line-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.-text-line-dataset.pbtxt index 1f2b6f44764..02ab7ca30df 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.data.-text-line-dataset.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.data.-text-line-dataset.pbtxt @@ -112,6 +112,10 @@ tf_class { name: "prefetch" argspec: "args=[\'self\', \'buffer_size\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } + member_method { + name: "ragged_batch" + argspec: "args=[\'self\', \'batch_size\', \'drop_remainder\', \'row_splits_dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \"\", \'None\'], " + } member_method { name: "random" argspec: "args=[\'seed\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-csv-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-csv-dataset.pbtxt index 397beba8f08..635e6eb27ab 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-csv-dataset.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-csv-dataset.pbtxt @@ -112,6 +112,10 @@ tf_class { name: "prefetch" argspec: "args=[\'self\', \'buffer_size\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } + member_method { + name: "ragged_batch" + argspec: "args=[\'self\', \'batch_size\', \'drop_remainder\', \'row_splits_dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \"\", \'None\'], " + } member_method { name: "random" argspec: "args=[\'seed\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-random-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-random-dataset.pbtxt index 52911a8e1b7..c05ff5c583b 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-random-dataset.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-random-dataset.pbtxt @@ -113,6 +113,10 @@ tf_class { name: "prefetch" argspec: "args=[\'self\', \'buffer_size\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } + member_method { + name: "ragged_batch" + argspec: "args=[\'self\', \'batch_size\', \'drop_remainder\', \'row_splits_dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \"\", \'None\'], " + } member_method { name: "random" argspec: "args=[\'seed\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-sql-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-sql-dataset.pbtxt index 5724e284469..fd853758921 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-sql-dataset.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-sql-dataset.pbtxt @@ -112,6 +112,10 @@ tf_class { name: "prefetch" argspec: "args=[\'self\', \'buffer_size\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } + member_method { + name: "ragged_batch" + argspec: "args=[\'self\', \'batch_size\', \'drop_remainder\', \'row_splits_dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \"\", \'None\'], " + } member_method { name: "random" argspec: "args=[\'seed\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " diff --git a/tensorflow/tools/api/golden/v2/tensorflow.experimental.dtensor.-d-tensor-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.experimental.dtensor.-d-tensor-dataset.pbtxt index cffcddb10b6..41af206fd97 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.experimental.dtensor.-d-tensor-dataset.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.experimental.dtensor.-d-tensor-dataset.pbtxt @@ -113,6 +113,10 @@ tf_class { name: "prefetch" argspec: "args=[\'self\', \'buffer_size\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } + member_method { + name: "ragged_batch" + argspec: "args=[\'self\', \'batch_size\', \'drop_remainder\', \'row_splits_dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \"\", \'None\'], " + } member_method { name: "random" argspec: "args=[\'seed\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "