#tf-data Support Python-style zipping in tf.data.Dataset.zip

The current behavior is unchanged to preserve backward compatibility; this change only adds a new calling convention to the `tf.data.Dataset.zip` Python op, which now also accepts datasets (or nested structures of datasets) as positional arguments, without wrapping them in a tuple.
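For illustration, a minimal sketch of the two calling conventions (mirroring the docstring examples in the diff below; the tuple form is the pre-existing behavior, the positional form is what this change adds):

```python
import tensorflow as tf

a = tf.data.Dataset.range(1, 4)  # ==> [1, 2, 3]
b = tf.data.Dataset.range(4, 7)  # ==> [4, 5, 6]

# Pre-existing tuple form, unchanged by this commit.
ds_tuple = tf.data.Dataset.zip((a, b))

# New Python-style form: datasets are passed as positional arguments.
ds_args = tf.data.Dataset.zip(a, b)

# Both produce the same elements.
print(list(ds_args.as_numpy_iterator()))  # [(1, 4), (2, 5), (3, 6)]
```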

PiperOrigin-RevId: 515389129
A. Unique TensorFlower 2023-03-09 11:11:27 -08:00 committed by TensorFlower Gardener
parent 253052b195
commit fe1cc8d2c4
19 changed files with 133 additions and 39 deletions


@@ -85,6 +85,11 @@
     `tf.nn.safe_embedding_lookup_sparse`, which enables a simplified and
     typically faster lookup procedure.
+*   `tf.data`
+    *   `tf.data.Dataset.zip` now supports Python-style zipping, i.e.
+        `Dataset.zip(a, b, c)`.
 
 ## Bug Fixes and Other Changes
 
 *   <SIMILAR TO ABOVE SECTION, BUT FOR OTHER IMPORTANT CHANGES / BUG FIXES>


@@ -134,15 +134,53 @@ class ZipTest(test_base.DatasetTestBase, parameterized.TestCase):
     dataset = dataset_ops.Dataset.zip((x, y), name="zip")
     self.assertDatasetProduces(dataset, [(4, 2)])
 
+  @combinations.generate(
+      combinations.times(test_base.default_test_combinations())
+  )
+  def testZipWithArgsAndDataset(self):
+    with self.assertRaisesRegex(
+        TypeError, r"Both `\*args` and `datasets` cannot be set."
+    ):
+      dataset_ops.Dataset.zip(
+          dataset_ops.Dataset.range(1, 4),
+          dataset_ops.Dataset.range(4, 7),
+          datasets=(
+              dataset_ops.Dataset.range(1, 4),
+              dataset_ops.Dataset.range(4, 7),
+          ),
+      )
+
+  @combinations.generate(
+      combinations.times(test_base.default_test_combinations())
+  )
+  def testZipBasicWithNoInput(self):
+    with self.assertRaisesRegex(
+        TypeError, r"Must pass at least one dataset to `zip`."
+    ):
+      dataset_ops.Dataset.zip()
+
+  @combinations.generate(
+      combinations.times(test_base.default_test_combinations())
+  )
+  def InvalidZipInputList(self):
+    with self.assertRaisesRegex(
+        TypeError,
+        r"Invalid input to `zip`. Inputs are expected to be (nested)"
+        r" structures of `tf.data.Dataset` objects. Python `list` is"
+        r" not supported and you should use `tuple` instead.",
+    ):
+      dataset_ops.Dataset.zip([1, 2, 3], [4, 5, 6])
+
 
-class ZipCheckpointTest(checkpoint_test_base.CheckpointTestBase,
-                        parameterized.TestCase):
+class ZipCheckpointTest(
+    checkpoint_test_base.CheckpointTestBase, parameterized.TestCase
+):
 
   def _build_dataset(self, arr, options=None):
     components = [
         np.tile(np.array([[1], [2], [3], [4]]), 20),
         np.tile(np.array([[12], [13], [14], [15]]), 22),
-        np.array(arr)
+        np.array(arr),
     ]
     datasets = [
         dataset_ops.Dataset.from_tensor_slices(component)
@@ -200,6 +238,19 @@ class ZipRandomAccessTest(test_base.DatasetTestBase, parameterized.TestCase):
 
   @combinations.generate(
       combinations.times(test_base.default_test_combinations()))
+  def testZipBasicWithoutTuple(self):
+    dataset = dataset_ops.Dataset.zip(
+        dataset_ops.Dataset.range(1, 4), dataset_ops.Dataset.range(4, 7)
+    )
+    expected_dataset = [(1, 4), (2, 5), (3, 6)]
+    for i in range(3):
+      self.assertEqual(
+          self.evaluate(random_access.at(dataset, index=i)), expected_dataset[i]
+      )
+
+  @combinations.generate(
+      combinations.times(test_base.default_test_combinations())
+  )
   def testZipEqual(self):
     components = [
         np.tile(np.array([[1], [2], [3], [4]]), 20),
@@ -246,6 +297,28 @@ class ZipRandomAccessTest(test_base.DatasetTestBase, parameterized.TestCase):
     with self.assertRaises(errors.OutOfRangeError):
       self.evaluate(random_access.at(dataset, index=4))
 
+  @combinations.generate(test_base.default_test_combinations())
+  def testNestedWithoutTuple(self):
+    components = [
+        np.tile(np.array([[1], [2], [3], [4]]), 20),
+        np.tile(np.array([[12], [13], [14], [15]]), 22),
+        np.array([37.0, 38.0, 39.0, 40.0]),
+    ]
+    datasets = [
+        dataset_ops.Dataset.from_tensor_slices(component)
+        for component in components
+    ]
+    dataset = dataset_ops.Dataset.zip(datasets[0], (datasets[1], datasets[2]))
+    for i in range(4):
+      result1, (result2, result3) = self.evaluate(
+          random_access.at(dataset, index=i)
+      )
+      self.assertAllEqual(components[0][i], result1)
+      self.assertAllEqual(components[1][i], result2)
+      self.assertAllEqual(components[2][i], result3)
+    with self.assertRaises(errors.OutOfRangeError):
+      self.evaluate(random_access.at(dataset, index=4))
+
   @combinations.generate(test_base.default_test_combinations())
   def testNamedTuple(self):
     Foo = collections.namedtuple("Foo", ["x", "y"])


@@ -1017,7 +1017,7 @@ class DatasetV2(
   # pylint: enable=g-import-not-at-top,protected-access
 
   @staticmethod
-  def zip(datasets, name=None):
+  def zip(*args, datasets=None, name=None):
     """Creates a `Dataset` by zipping together the given datasets.
 
     This method has similar semantics to the built-in `zip()` function
@@ -1026,14 +1026,14 @@ class DatasetV2(
     nesting mechanisms are documented
     [here] (https://www.tensorflow.org/guide/data#dataset_structure).
 
-    >>> # The nested structure of the `datasets` argument determines the
-    >>> # structure of elements in the resulting dataset.
+    >>> # The datasets or nested structure of datasets `*args` argument
+    >>> # determines the structure of elements in the resulting dataset.
     >>> a = tf.data.Dataset.range(1, 4)  # ==> [ 1, 2, 3 ]
    >>> b = tf.data.Dataset.range(4, 7)  # ==> [ 4, 5, 6 ]
-    >>> ds = tf.data.Dataset.zip((a, b))
+    >>> ds = tf.data.Dataset.zip(a, b)
     >>> list(ds.as_numpy_iterator())
     [(1, 4), (2, 5), (3, 6)]
-    >>> ds = tf.data.Dataset.zip((b, a))
+    >>> ds = tf.data.Dataset.zip(b, a)
     >>> list(ds.as_numpy_iterator())
     [(4, 1), (5, 2), (6, 3)]
     >>>
@@ -1041,7 +1041,7 @@ class DatasetV2(
     >>> c = tf.data.Dataset.range(7, 13).batch(2)  # ==> [ [7, 8],
     ...                                             #       [9, 10],
     ...                                             #       [11, 12] ]
-    >>> ds = tf.data.Dataset.zip((a, b, c))
+    >>> ds = tf.data.Dataset.zip(a, b, c)
     >>> for element in ds.as_numpy_iterator():
     ...   print(element)
     (1, 4, array([7, 8]))
@@ -1051,12 +1051,16 @@ class DatasetV2(
     >>> # The number of elements in the resulting dataset is the same as
     >>> # the size of the smallest dataset in `datasets`.
     >>> d = tf.data.Dataset.range(13, 15)  # ==> [ 13, 14 ]
-    >>> ds = tf.data.Dataset.zip((a, d))
+    >>> ds = tf.data.Dataset.zip(a, d)
     >>> list(ds.as_numpy_iterator())
     [(1, 13), (2, 14)]
 
     Args:
-      datasets: A (nested) structure of datasets.
+      *args: Datasets or nested structures of datasets to zip together. This
+        can't be set if `datasets` is set.
+      datasets: A (nested) structure of datasets. This can't be set if `*args`
+        is set. Note that this exists only for backwards compatibility and it is
+        preferred to use *args.
       name: (Optional.) A name for the tf.data operation.
 
     Returns:
@@ -1066,8 +1070,16 @@ class DatasetV2(
     # dataset_ops).
     # pylint: disable=g-import-not-at-top,protected-access
     from tensorflow.python.data.ops import zip_op
+    if not args and datasets is None:
+      raise TypeError("Must pass at least one dataset to `zip`.")
+    if args and datasets is not None:
+      raise TypeError("Both `*args` and `datasets` cannot be set.")
+    if len(args) == 1:
+      datasets = args[0]
+    elif len(args) > 1:
+      datasets = args
     return zip_op._zip(datasets, name)
     # pylint: enable=g-import-not-at-top,protected-access
 
   def concatenate(self, dataset, name=None):
     """Creates a `Dataset` by concatenating the given dataset with this dataset.
@@ -3937,8 +3949,8 @@ class DatasetV1(DatasetV2, data_types.DatasetV1):
 
   @staticmethod
   @functools.wraps(DatasetV2.zip)
-  def zip(datasets, name=None):
-    return DatasetV1Adapter(DatasetV2.zip(datasets, name=name))
+  def zip(*args, datasets=None, name=None):
+    return DatasetV1Adapter(DatasetV2.zip(*args, datasets=datasets, name=name))
 
   @functools.wraps(DatasetV2.concatenate)
   def concatenate(self, dataset, name=None):
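The dispatch between the two forms happens up front in `DatasetV2.zip`, as shown in the hunk above. A short sketch of how the different call forms resolve; the `datasets=` keyword remains only for backwards compatibility:

```python
import tensorflow as tf

a = tf.data.Dataset.range(1, 4)
b = tf.data.Dataset.range(4, 7)

# Keyword form, kept for backwards compatibility.
ds1 = tf.data.Dataset.zip(datasets=(a, b))

# A single positional argument is treated as the (possibly nested) structure.
ds2 = tf.data.Dataset.zip((a, b))

# Multiple positional arguments are packed into a tuple internally.
ds3 = tf.data.Dataset.zip(a, b)

# Mixing the two forms, or passing nothing at all, raises a TypeError.
try:
  tf.data.Dataset.zip(a, b, datasets=(a, b))
except TypeError as e:
  print(e)  # Both `*args` and `datasets` cannot be set.
```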


@@ -32,22 +32,26 @@ class _ZipDataset(dataset_ops.DatasetV2):
     for ds in nest.flatten(datasets):
       if not isinstance(ds, data_types.DatasetV2):
         if isinstance(ds, list):
-          raise TypeError("Invalid `datasets`. `datasets` is expected to be a "
-                          "(nested) structure of `tf.data.Dataset` objects. "
-                          "Python `list` is not supported and you should use "
-                          "`tuple` instead.")
+          raise TypeError(
+              "Invalid input to `zip`. Inputs are expected to be (nested)"
+              " structures of `tf.data.Dataset` objects. Python `list` is"
+              " not supported and you should use `tuple` instead."
+          )
         else:
-          raise TypeError(f"Invalid `datasets`. `datasets` is expected to be a "
-                          f"(nested) structure of `tf.data.Dataset` objects "
-                          f"but encountered object of type {type(ds)}.")
+          raise TypeError(
+              "Invalid input to `zip`. Inputs are expected to be (nested)"
+              " structures of `tf.data.Dataset` objects but"
+              f" encountered object of type {type(ds)}."
+          )
     self._datasets = datasets
     self._structure = nest.pack_sequence_as(
-        self._datasets,
-        [ds.element_spec for ds in nest.flatten(self._datasets)])
+        self._datasets, [ds.element_spec for ds in nest.flatten(self._datasets)]
+    )
     self._name = name
     variant_tensor = gen_dataset_ops.zip_dataset(
         [ds._variant_tensor for ds in nest.flatten(self._datasets)],
-        **self._common_args)
+        **self._common_args,
+    )
     super().__init__(variant_tensor)
 
   def _inputs(self):
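The error messages are reworded because inputs can now arrive either positionally or via the `datasets` keyword, so they no longer refer only to the `datasets` argument. As before, Python lists are rejected as a container of datasets; a small sketch of the behavior the new message describes (the comment paraphrases the message from the hunk above):

```python
import tensorflow as tf

a = tf.data.Dataset.range(1, 4)
b = tf.data.Dataset.range(4, 7)

try:
  # Lists are treated as a single component in tf.data structures,
  # so they are not accepted as a container of datasets here.
  tf.data.Dataset.zip([a, b])
except TypeError as e:
  print(e)
  # Invalid input to `zip`. Inputs are expected to be (nested) structures
  # of `tf.data.Dataset` objects. Python `list` is not supported and you
  # should use `tuple` instead.
```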

@@ -229,6 +229,6 @@ tf_class {
   }
   member_method {
     name: "zip"
-    argspec: "args=[\'datasets\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+    argspec: "args=[], varargs=args, keywords=None, defaults=None"
   }
 }

@@ -231,6 +231,6 @@ tf_class {
   }
   member_method {
     name: "zip"
-    argspec: "args=[\'datasets\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+    argspec: "args=[], varargs=args, keywords=None, defaults=None"
   }
 }

@@ -231,6 +231,6 @@ tf_class {
   }
   member_method {
     name: "zip"
-    argspec: "args=[\'datasets\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+    argspec: "args=[], varargs=args, keywords=None, defaults=None"
   }
 }

@@ -231,6 +231,6 @@ tf_class {
   }
   member_method {
     name: "zip"
-    argspec: "args=[\'datasets\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+    argspec: "args=[], varargs=args, keywords=None, defaults=None"
   }
 }

@@ -231,6 +231,6 @@ tf_class {
   }
   member_method {
     name: "zip"
-    argspec: "args=[\'datasets\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+    argspec: "args=[], varargs=args, keywords=None, defaults=None"
   }
 }

@@ -231,6 +231,6 @@ tf_class {
   }
   member_method {
     name: "zip"
-    argspec: "args=[\'datasets\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+    argspec: "args=[], varargs=args, keywords=None, defaults=None"
   }
 }

@@ -231,6 +231,6 @@ tf_class {
   }
   member_method {
     name: "zip"
-    argspec: "args=[\'datasets\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+    argspec: "args=[], varargs=args, keywords=None, defaults=None"
   }
 }

@@ -196,6 +196,6 @@ tf_class {
   }
   member_method {
     name: "zip"
-    argspec: "args=[\'datasets\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+    argspec: "args=[], varargs=args, keywords=None, defaults=None"
   }
 }

@@ -198,6 +198,6 @@ tf_class {
   }
   member_method {
     name: "zip"
-    argspec: "args=[\'datasets\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+    argspec: "args=[], varargs=args, keywords=None, defaults=None"
   }
 }

@@ -197,6 +197,6 @@ tf_class {
   }
   member_method {
     name: "zip"
-    argspec: "args=[\'datasets\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+    argspec: "args=[], varargs=args, keywords=None, defaults=None"
   }
 }

@@ -198,6 +198,6 @@ tf_class {
   }
   member_method {
     name: "zip"
-    argspec: "args=[\'datasets\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+    argspec: "args=[], varargs=args, keywords=None, defaults=None"
  }
 }

@@ -198,6 +198,6 @@ tf_class {
   }
   member_method {
     name: "zip"
-    argspec: "args=[\'datasets\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+    argspec: "args=[], varargs=args, keywords=None, defaults=None"
   }
 }

@@ -199,6 +199,6 @@ tf_class {
   }
   member_method {
     name: "zip"
-    argspec: "args=[\'datasets\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+    argspec: "args=[], varargs=args, keywords=None, defaults=None"
   }
 }

@@ -198,6 +198,6 @@ tf_class {
   }
   member_method {
     name: "zip"
-    argspec: "args=[\'datasets\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+    argspec: "args=[], varargs=args, keywords=None, defaults=None"
   }
 }

@@ -199,6 +199,6 @@ tf_class {
   }
   member_method {
     name: "zip"
-    argspec: "args=[\'datasets\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+    argspec: "args=[], varargs=args, keywords=None, defaults=None"
   }
 }