Graduate tf.data.experimental.dense_to_sparse_batch to tf.data.Dataset.sparse_batch.

PiperOrigin-RevId: 485168714
Matt Callanan 2022-10-31 15:52:25 -07:00 committed by TensorFlower Gardener
parent 4fea1e92bb
commit 29e6955939
22 changed files with 313 additions and 1 deletion

View File

@@ -164,6 +164,7 @@ This release contains contributions from many people at Google, as well as:
* `tf.data`:
    * Graduated experimental APIs:
        * [`tf.data.Dataset.ragged_batch`](https://www.tensorflow.org/api_docs/python/tf/data/Dataset/#ragged_batch), which batches elements of `tf.data.Dataset`s into `tf.RaggedTensor`s.
        * [`tf.data.Dataset.sparse_batch`](https://www.tensorflow.org/api_docs/python/tf/data/Dataset/#sparse_batch), which batches elements of `tf.data.Dataset`s into `tf.sparse.SparseTensor`s.
## Bug Fixes and Other Changes
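As orientation for the release note above, a hedged usage sketch of the graduated API (output values assume eager TF 2.x with this change applied):

```python
# Minimal sketch: variable-length dense rows are batched into SparseTensors.
import tensorflow as tf

ds = tf.data.Dataset.from_tensor_slices([3, 2, 4]).map(
    lambda n: tf.fill([n], n))  # elements: [3 3 3], [2 2], [4 4 4 4]
for st in ds.sparse_batch(batch_size=2, row_shape=[6]):
  print(st.indices.numpy().tolist(), st.values.numpy().tolist(),
        st.dense_shape.numpy().tolist())
# Batch 1: [[0, 0], [0, 1], [0, 2], [1, 0], [1, 1]]  [3, 3, 3, 2, 2]  [2, 6]
# Batch 2: [[0, 0], [0, 1], [0, 2], [0, 3]]          [4, 4, 4, 4]     [1, 6]
```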

View File

@@ -136,7 +136,7 @@ def dense_to_sparse_batch(batch_size, row_shape):
  """

  def _apply_fn(dataset):
-    return _DenseToSparseBatchDataset(dataset, batch_size, row_shape)
+    return dataset.sparse_batch(batch_size, row_shape)

  return _apply_fn
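With this shim, the experimental transformation simply forwards to the graduated method; a hedged equivalence sketch:

```python
# Both pipelines build the same dataset after this change; the `apply(...)`
# form now delegates to `Dataset.sparse_batch`.
import tensorflow as tf

ds = tf.data.Dataset.from_tensor_slices([1, 2, 3]).map(
    lambda n: tf.fill([n], n))
old_style = ds.apply(
    tf.data.experimental.dense_to_sparse_batch(batch_size=2, row_shape=[4]))
new_style = ds.sparse_batch(batch_size=2, row_shape=[4])
assert old_style.element_spec == new_style.element_spec
```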

View File

@@ -770,6 +770,22 @@ tf_py_test(
    ],
)

tf_py_test(
    name = "sparse_batch_test",
    size = "medium",
    srcs = ["sparse_batch_test.py"],
    deps = [
        "//tensorflow/python:array_ops",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:dtypes",
        "//tensorflow/python:errors",
        "//tensorflow/python/data/kernel_tests:checkpoint_test_base",
        "//tensorflow/python/data/kernel_tests:test_base",
        "//tensorflow/python/data/ops:dataset_ops",
        "//third_party/py/numpy",
    ],
)

tf_py_test(
    name = "random_test",
    size = "small",

View File

@@ -0,0 +1,123 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for `tf.data.Dataset.sparse_batch`."""
from absl.testing import parameterized
import numpy as np
from tensorflow.python.data.kernel_tests import checkpoint_test_base
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import combinations
from tensorflow.python.framework import errors
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
class DenseToSparseBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
@combinations.generate(test_base.default_test_combinations())
def testBasic(self):
components = np.random.randint(12, size=(100,)).astype(np.int32)
dataset = dataset_ops.Dataset.from_tensor_slices(components).map(
lambda x: array_ops.fill([x], x)).sparse_batch(4, [12])
get_next = self.getNext(dataset)
for start in range(0, len(components), 4):
results = self.evaluate(get_next())
self.assertAllEqual([[i, j]
for i, c in enumerate(components[start:start + 4])
for j in range(c)], results.indices)
self.assertAllEqual(
[c for c in components[start:start + 4] for _ in range(c)],
results.values)
self.assertAllEqual([min(4,
len(components) - start), 12],
results.dense_shape)
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(get_next())
@combinations.generate(test_base.default_test_combinations())
def testWithUnknownShape(self):
components = np.random.randint(5, size=(40,)).astype(np.int32)
dataset = dataset_ops.Dataset.from_tensor_slices(components).map(
lambda x: array_ops.fill([x, x], x)).sparse_batch(4, [5, None])
get_next = self.getNext(dataset)
for start in range(0, len(components), 4):
results = self.evaluate(get_next())
self.assertAllEqual([[i, j, z]
for i, c in enumerate(components[start:start + 4])
for j in range(c)
for z in range(c)], results.indices)
self.assertAllEqual([
c for c in components[start:start + 4] for _ in range(c)
for _ in range(c)
], results.values)
self.assertAllEqual([
min(4,
len(components) - start), 5,
np.max(components[start:start + 4])
], results.dense_shape)
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(get_next())
@combinations.generate(test_base.default_test_combinations())
def testWithInvalidShape(self):
input_tensor = array_ops.constant([[1]])
with self.assertRaisesRegex(ValueError, "Dimension -2 must be >= 0"):
dataset_ops.Dataset.from_tensors(input_tensor).sparse_batch(4, [-2])
@combinations.generate(test_base.default_test_combinations())
def testShapeErrors(self):
def dataset_fn(input_tensor):
return dataset_ops.Dataset.from_tensors(input_tensor).sparse_batch(
4, [12])
# Initialize with an input tensor of incompatible rank.
get_next = self.getNext(dataset_fn([[1]]))
with self.assertRaisesRegex(errors.InvalidArgumentError,
"incompatible with the row shape"):
self.evaluate(get_next())
# Initialize with an input tensor that is larger than `row_shape`.
get_next = self.getNext(dataset_fn(np.int32(range(13))))
with self.assertRaisesRegex(errors.DataLossError,
"larger than the row shape"):
self.evaluate(get_next())
class DenseToSparseBatchCheckpointTest(checkpoint_test_base.CheckpointTestBase,
parameterized.TestCase):
def _build_dataset(self, components):
return dataset_ops.Dataset.from_tensor_slices(components).map(
lambda x: array_ops.fill([x], x)).sparse_batch(4, [12])
@combinations.generate(
combinations.times(test_base.default_test_combinations(),
checkpoint_test_base.default_test_combinations()))
def test(self, verify_fn):
components = np.random.randint(5, size=(40,)).astype(np.int32)
num_outputs = len(components) // 4
verify_fn(self, lambda: self._build_dataset(components), num_outputs)
if __name__ == "__main__":
test.main()

View File

@@ -62,6 +62,7 @@ py_library(
        ":ragged_batch_op",
        ":rebatch_op",
        ":save_ops",
        ":sparse_batch_op",
        ":structured_function",
        ":zip_op",
        "//tensorflow/python:constant_op",
@@ -162,6 +163,14 @@ py_library(
    ],
)

py_library(
    name = "sparse_batch_op",
    srcs = ["sparse_batch_op.py"],
    srcs_version = "PY3",
    deps = [
    ],
)

py_library(
    name = "save_ops",
    srcs = ["save_op.py"],

View File

@@ -2180,6 +2180,53 @@ class DatasetV2(
    return ragged_batch_op.ragged_batch(self, batch_size, drop_remainder,
                                        row_splits_dtype, name)

  def sparse_batch(self, batch_size, row_shape, name=None):
    """Combines consecutive elements into `tf.sparse.SparseTensor`s.

    Like `Dataset.padded_batch()`, this transformation combines multiple
    consecutive elements of the dataset, which might have different
    shapes, into a single element. The resulting element has three
    components (`indices`, `values`, and `dense_shape`), which
    comprise a `tf.sparse.SparseTensor` that represents the same data. The
    `row_shape` represents the dense shape of each row in the
    resulting `tf.sparse.SparseTensor`, to which the effective batch size is
    prepended. For example:

    ```python
    # NOTE: The following examples use `{ ... }` to represent the
    # contents of a dataset.
    a = { ['a', 'b', 'c'], ['a', 'b'], ['a', 'b', 'c', 'd'] }

    a.sparse_batch(batch_size=2, row_shape=[6]) ==
    {
        ([[0, 0], [0, 1], [0, 2], [1, 0], [1, 1]],  # indices
         ['a', 'b', 'c', 'a', 'b'],                 # values
         [2, 6]),                                   # dense_shape
        ([[0, 0], [0, 1], [0, 2], [0, 3]],
         ['a', 'b', 'c', 'd'],
         [1, 6])
    }
    ```

    Args:
      batch_size: A `tf.int64` scalar `tf.Tensor`, representing the number of
        consecutive elements of this dataset to combine in a single batch.
      row_shape: A `tf.TensorShape` or `tf.int64` vector tensor-like object
        representing the equivalent dense shape of a row in the resulting
        `tf.sparse.SparseTensor`. Each element of this dataset must have the
        same rank as `row_shape`, and must have size less than or equal to
        `row_shape` in each dimension.
      name: (Optional.) A string indicating a name for the `tf.data` operation.

    Returns:
      A new `Dataset` with the transformation applied as described above.
    """
    # Loaded lazily due to a circular dependency (dataset_ops ->
    # sparse_batch_op -> dataset_ops).
    from tensorflow.python.data.ops import sparse_batch_op  # pylint: disable=g-import-not-at-top
    return sparse_batch_op.sparse_batch(self, batch_size, row_shape, name)

  def map(self,
          map_func,
          num_parallel_calls=None,
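The `row_shape` constraint documented in Args is enforced when a batch is produced, mirroring `testShapeErrors` above; a hedged sketch of the failure mode:

```python
# Sketch, assuming TF 2.x eager execution: an element exceeding `row_shape`
# in any dimension surfaces as a DataLossError at iteration time.
import tensorflow as tf

ds = tf.data.Dataset.from_tensors(tf.range(13)).sparse_batch(4, [12])
try:
  next(iter(ds))
except tf.errors.DataLossError as e:
  print(e)  # message says the element is larger than the row shape
```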

View File

@@ -0,0 +1,56 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The implementation of `tf.data.Dataset.sparse_batch`."""
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import convert
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
def sparse_batch(self, batch_size, row_shape, name=None):
return _DenseToSparseBatchDataset(self, batch_size, row_shape, name)
class _DenseToSparseBatchDataset(dataset_ops.UnaryDataset):
"""A `Dataset` that batches ragged dense elements into `tf.sparse.SparseTensor`s."""
def __init__(self, input_dataset, batch_size, row_shape, name=None):
"""See `Dataset.dense_to_sparse_batch()` for more details."""
if not isinstance(
dataset_ops.get_legacy_output_types(input_dataset), dtypes.DType):
raise TypeError("`dense_to_sparse_batch` requires an input dataset whose "
"elements have a single component, but the given dataset "
"has the following component types: "
f"{dataset_ops.get_legacy_output_types(input_dataset)}.")
self._input_dataset = input_dataset
self._batch_size = batch_size
self._row_shape = row_shape
self._element_spec = sparse_tensor.SparseTensorSpec(
tensor_shape.TensorShape([None]).concatenate(self._row_shape),
dataset_ops.get_legacy_output_types(input_dataset))
self._name = name
variant_tensor = ged_ops.dense_to_sparse_batch_dataset(
self._input_dataset._variant_tensor, # pylint: disable=protected-access
self._batch_size,
row_shape=convert.partial_shape_to_tensor(self._row_shape),
**self._flat_structure)
super(_DenseToSparseBatchDataset, self).__init__(input_dataset,
variant_tensor)
@property
def element_spec(self):
return self._element_spec
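For reference, a hedged sketch of the resulting `element_spec`: the spec built in `__init__` above prepends an unknown batch dimension to `row_shape`:

```python
# Sketch, assuming TF 2.x: the SparseTensorSpec's dense shape is [None]
# concatenated with `row_shape`; the dtype comes from the input elements.
import tensorflow as tf

ds = tf.data.Dataset.range(1, 4).map(lambda n: tf.fill([n], n))
batched = ds.sparse_batch(batch_size=2, row_shape=[5])
print(batched.element_spec)
# SparseTensorSpec(TensorShape([None, 5]), tf.int64)
```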

View File

@@ -199,6 +199,10 @@ tf_class {
    name: "snapshot"
    argspec: "args=[\'self\', \'path\', \'compression\', \'reader_func\', \'shard_func\', \'name\'], varargs=None, keywords=None, defaults=[\'AUTO\', \'None\', \'None\', \'None\'], "
  }
  member_method {
    name: "sparse_batch"
    argspec: "args=[\'self\', \'batch_size\', \'row_shape\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "take"
    argspec: "args=[\'self\', \'count\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "

View File

@@ -201,6 +201,10 @@ tf_class {
    name: "snapshot"
    argspec: "args=[\'self\', \'path\', \'compression\', \'reader_func\', \'shard_func\', \'name\'], varargs=None, keywords=None, defaults=[\'AUTO\', \'None\', \'None\', \'None\'], "
  }
  member_method {
    name: "sparse_batch"
    argspec: "args=[\'self\', \'batch_size\', \'row_shape\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "take"
    argspec: "args=[\'self\', \'count\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "

View File

@@ -201,6 +201,10 @@ tf_class {
    name: "snapshot"
    argspec: "args=[\'self\', \'path\', \'compression\', \'reader_func\', \'shard_func\', \'name\'], varargs=None, keywords=None, defaults=[\'AUTO\', \'None\', \'None\', \'None\'], "
  }
  member_method {
    name: "sparse_batch"
    argspec: "args=[\'self\', \'batch_size\', \'row_shape\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "take"
    argspec: "args=[\'self\', \'count\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "

View File

@@ -201,6 +201,10 @@ tf_class {
    name: "snapshot"
    argspec: "args=[\'self\', \'path\', \'compression\', \'reader_func\', \'shard_func\', \'name\'], varargs=None, keywords=None, defaults=[\'AUTO\', \'None\', \'None\', \'None\'], "
  }
  member_method {
    name: "sparse_batch"
    argspec: "args=[\'self\', \'batch_size\', \'row_shape\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "take"
    argspec: "args=[\'self\', \'count\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "

View File

@@ -201,6 +201,10 @@ tf_class {
    name: "snapshot"
    argspec: "args=[\'self\', \'path\', \'compression\', \'reader_func\', \'shard_func\', \'name\'], varargs=None, keywords=None, defaults=[\'AUTO\', \'None\', \'None\', \'None\'], "
  }
  member_method {
    name: "sparse_batch"
    argspec: "args=[\'self\', \'batch_size\', \'row_shape\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "take"
    argspec: "args=[\'self\', \'count\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "

View File

@@ -201,6 +201,10 @@ tf_class {
    name: "snapshot"
    argspec: "args=[\'self\', \'path\', \'compression\', \'reader_func\', \'shard_func\', \'name\'], varargs=None, keywords=None, defaults=[\'AUTO\', \'None\', \'None\', \'None\'], "
  }
  member_method {
    name: "sparse_batch"
    argspec: "args=[\'self\', \'batch_size\', \'row_shape\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "take"
    argspec: "args=[\'self\', \'count\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "

View File

@@ -201,6 +201,10 @@ tf_class {
    name: "snapshot"
    argspec: "args=[\'self\', \'path\', \'compression\', \'reader_func\', \'shard_func\', \'name\'], varargs=None, keywords=None, defaults=[\'AUTO\', \'None\', \'None\', \'None\'], "
  }
  member_method {
    name: "sparse_batch"
    argspec: "args=[\'self\', \'batch_size\', \'row_shape\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "take"
    argspec: "args=[\'self\', \'count\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "

View File

@@ -166,6 +166,10 @@ tf_class {
    name: "snapshot"
    argspec: "args=[\'self\', \'path\', \'compression\', \'reader_func\', \'shard_func\', \'name\'], varargs=None, keywords=None, defaults=[\'AUTO\', \'None\', \'None\', \'None\'], "
  }
  member_method {
    name: "sparse_batch"
    argspec: "args=[\'self\', \'batch_size\', \'row_shape\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "take"
    argspec: "args=[\'self\', \'count\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "

View File

@@ -168,6 +168,10 @@ tf_class {
    name: "snapshot"
    argspec: "args=[\'self\', \'path\', \'compression\', \'reader_func\', \'shard_func\', \'name\'], varargs=None, keywords=None, defaults=[\'AUTO\', \'None\', \'None\', \'None\'], "
  }
  member_method {
    name: "sparse_batch"
    argspec: "args=[\'self\', \'batch_size\', \'row_shape\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "take"
    argspec: "args=[\'self\', \'count\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "

View File

@@ -167,6 +167,10 @@ tf_class {
    name: "snapshot"
    argspec: "args=[\'self\', \'path\', \'compression\', \'reader_func\', \'shard_func\', \'name\'], varargs=None, keywords=None, defaults=[\'AUTO\', \'None\', \'None\', \'None\'], "
  }
  member_method {
    name: "sparse_batch"
    argspec: "args=[\'self\', \'batch_size\', \'row_shape\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "take"
    argspec: "args=[\'self\', \'count\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "

View File

@@ -168,6 +168,10 @@ tf_class {
    name: "snapshot"
    argspec: "args=[\'self\', \'path\', \'compression\', \'reader_func\', \'shard_func\', \'name\'], varargs=None, keywords=None, defaults=[\'AUTO\', \'None\', \'None\', \'None\'], "
  }
  member_method {
    name: "sparse_batch"
    argspec: "args=[\'self\', \'batch_size\', \'row_shape\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "take"
    argspec: "args=[\'self\', \'count\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "

View File

@@ -168,6 +168,10 @@ tf_class {
    name: "snapshot"
    argspec: "args=[\'self\', \'path\', \'compression\', \'reader_func\', \'shard_func\', \'name\'], varargs=None, keywords=None, defaults=[\'AUTO\', \'None\', \'None\', \'None\'], "
  }
  member_method {
    name: "sparse_batch"
    argspec: "args=[\'self\', \'batch_size\', \'row_shape\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "take"
    argspec: "args=[\'self\', \'count\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "

View File

@@ -169,6 +169,10 @@ tf_class {
    name: "snapshot"
    argspec: "args=[\'self\', \'path\', \'compression\', \'reader_func\', \'shard_func\', \'name\'], varargs=None, keywords=None, defaults=[\'AUTO\', \'None\', \'None\', \'None\'], "
  }
  member_method {
    name: "sparse_batch"
    argspec: "args=[\'self\', \'batch_size\', \'row_shape\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "take"
    argspec: "args=[\'self\', \'count\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "

View File

@@ -168,6 +168,10 @@ tf_class {
    name: "snapshot"
    argspec: "args=[\'self\', \'path\', \'compression\', \'reader_func\', \'shard_func\', \'name\'], varargs=None, keywords=None, defaults=[\'AUTO\', \'None\', \'None\', \'None\'], "
  }
  member_method {
    name: "sparse_batch"
    argspec: "args=[\'self\', \'batch_size\', \'row_shape\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "take"
    argspec: "args=[\'self\', \'count\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "

View File

@@ -169,6 +169,10 @@ tf_class {
    name: "snapshot"
    argspec: "args=[\'self\', \'path\', \'compression\', \'reader_func\', \'shard_func\', \'name\'], varargs=None, keywords=None, defaults=[\'AUTO\', \'None\', \'None\', \'None\'], "
  }
  member_method {
    name: "sparse_batch"
    argspec: "args=[\'self\', \'batch_size\', \'row_shape\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "take"
    argspec: "args=[\'self\', \'count\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "