mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
#tf-data Support Python-style zipping in tf.data.Dataset.zip
The existing behavior is unchanged to preserve backward compatibility; this change only adds a new calling convention to the `tf.data.Dataset.zip` Python op, which now also accepts datasets as separate positional arguments (a non-nested structure) without requiring them to be wrapped in a tuple. PiperOrigin-RevId: 515389129
This commit is contained in:
parent
253052b195
commit
fe1cc8d2c4
|
|
@ -85,6 +85,11 @@
|
||||||
`tf.nn.safe_embedding_lookup_sparse`, which enables a simplified and
|
`tf.nn.safe_embedding_lookup_sparse`, which enables a simplified and
|
||||||
typically faster lookup procedure.
|
typically faster lookup procedure.
|
||||||
|
|
||||||
|
* `tf.data`
|
||||||
|
|
||||||
|
* `tf.data.Dataset.zip` now supports Python-style zipping, i.e.
|
||||||
|
`Dataset.zip(a, b, c)`.
|
||||||
|
|
||||||
## Bug Fixes and Other Changes
|
## Bug Fixes and Other Changes
|
||||||
|
|
||||||
* <SIMILAR TO ABOVE SECTION, BUT FOR OTHER IMPORTANT CHANGES / BUG FIXES>
|
* <SIMILAR TO ABOVE SECTION, BUT FOR OTHER IMPORTANT CHANGES / BUG FIXES>
|
||||||
|
|
|
||||||
|
|
@ -134,15 +134,53 @@ class ZipTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||||
dataset = dataset_ops.Dataset.zip((x, y), name="zip")
|
dataset = dataset_ops.Dataset.zip((x, y), name="zip")
|
||||||
self.assertDatasetProduces(dataset, [(4, 2)])
|
self.assertDatasetProduces(dataset, [(4, 2)])
|
||||||
|
|
||||||
|
@combinations.generate(
|
||||||
|
combinations.times(test_base.default_test_combinations())
|
||||||
|
)
|
||||||
|
def testZipWithArgsAndDataset(self):
|
||||||
|
with self.assertRaisesRegex(
|
||||||
|
TypeError, r"Both `\*args` and `datasets` cannot be set."
|
||||||
|
):
|
||||||
|
dataset_ops.Dataset.zip(
|
||||||
|
dataset_ops.Dataset.range(1, 4),
|
||||||
|
dataset_ops.Dataset.range(4, 7),
|
||||||
|
datasets=(
|
||||||
|
dataset_ops.Dataset.range(1, 4),
|
||||||
|
dataset_ops.Dataset.range(4, 7),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
class ZipCheckpointTest(checkpoint_test_base.CheckpointTestBase,
|
@combinations.generate(
|
||||||
parameterized.TestCase):
|
combinations.times(test_base.default_test_combinations())
|
||||||
|
)
|
||||||
|
def testZipBasicWithNoInput(self):
|
||||||
|
with self.assertRaisesRegex(
|
||||||
|
TypeError, r"Must pass at least one dataset to `zip`."
|
||||||
|
):
|
||||||
|
dataset_ops.Dataset.zip()
|
||||||
|
|
||||||
|
@combinations.generate(
|
||||||
|
combinations.times(test_base.default_test_combinations())
|
||||||
|
)
|
||||||
|
def InvalidZipInputList(self):
|
||||||
|
with self.assertRaisesRegex(
|
||||||
|
TypeError,
|
||||||
|
r"Invalid input to `zip`. Inputs are expected to be (nested)"
|
||||||
|
r" structures of `tf.data.Dataset` objects. Python `list` is"
|
||||||
|
r" not supported and you should use `tuple` instead.",
|
||||||
|
):
|
||||||
|
dataset_ops.Dataset.zip([1, 2, 3], [4, 5, 6])
|
||||||
|
|
||||||
|
|
||||||
|
class ZipCheckpointTest(
|
||||||
|
checkpoint_test_base.CheckpointTestBase, parameterized.TestCase
|
||||||
|
):
|
||||||
|
|
||||||
def _build_dataset(self, arr, options=None):
|
def _build_dataset(self, arr, options=None):
|
||||||
components = [
|
components = [
|
||||||
np.tile(np.array([[1], [2], [3], [4]]), 20),
|
np.tile(np.array([[1], [2], [3], [4]]), 20),
|
||||||
np.tile(np.array([[12], [13], [14], [15]]), 22),
|
np.tile(np.array([[12], [13], [14], [15]]), 22),
|
||||||
np.array(arr)
|
np.array(arr),
|
||||||
]
|
]
|
||||||
datasets = [
|
datasets = [
|
||||||
dataset_ops.Dataset.from_tensor_slices(component)
|
dataset_ops.Dataset.from_tensor_slices(component)
|
||||||
|
|
@ -200,6 +238,19 @@ class ZipRandomAccessTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||||
|
|
||||||
@combinations.generate(
|
@combinations.generate(
|
||||||
combinations.times(test_base.default_test_combinations()))
|
combinations.times(test_base.default_test_combinations()))
|
||||||
|
def testZipBasicWithoutTuple(self):
|
||||||
|
dataset = dataset_ops.Dataset.zip(
|
||||||
|
dataset_ops.Dataset.range(1, 4), dataset_ops.Dataset.range(4, 7)
|
||||||
|
)
|
||||||
|
expected_dataset = [(1, 4), (2, 5), (3, 6)]
|
||||||
|
for i in range(3):
|
||||||
|
self.assertEqual(
|
||||||
|
self.evaluate(random_access.at(dataset, index=i)), expected_dataset[i]
|
||||||
|
)
|
||||||
|
|
||||||
|
@combinations.generate(
|
||||||
|
combinations.times(test_base.default_test_combinations())
|
||||||
|
)
|
||||||
def testZipEqual(self):
|
def testZipEqual(self):
|
||||||
components = [
|
components = [
|
||||||
np.tile(np.array([[1], [2], [3], [4]]), 20),
|
np.tile(np.array([[1], [2], [3], [4]]), 20),
|
||||||
|
|
@ -246,6 +297,28 @@ class ZipRandomAccessTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
self.evaluate(random_access.at(dataset, index=4))
|
self.evaluate(random_access.at(dataset, index=4))
|
||||||
|
|
||||||
|
@combinations.generate(test_base.default_test_combinations())
|
||||||
|
def testNestedWithoutTuple(self):
|
||||||
|
components = [
|
||||||
|
np.tile(np.array([[1], [2], [3], [4]]), 20),
|
||||||
|
np.tile(np.array([[12], [13], [14], [15]]), 22),
|
||||||
|
np.array([37.0, 38.0, 39.0, 40.0]),
|
||||||
|
]
|
||||||
|
datasets = [
|
||||||
|
dataset_ops.Dataset.from_tensor_slices(component)
|
||||||
|
for component in components
|
||||||
|
]
|
||||||
|
dataset = dataset_ops.Dataset.zip(datasets[0], (datasets[1], datasets[2]))
|
||||||
|
for i in range(4):
|
||||||
|
result1, (result2, result3) = self.evaluate(
|
||||||
|
random_access.at(dataset, index=i)
|
||||||
|
)
|
||||||
|
self.assertAllEqual(components[0][i], result1)
|
||||||
|
self.assertAllEqual(components[1][i], result2)
|
||||||
|
self.assertAllEqual(components[2][i], result3)
|
||||||
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
|
self.evaluate(random_access.at(dataset, index=4))
|
||||||
|
|
||||||
@combinations.generate(test_base.default_test_combinations())
|
@combinations.generate(test_base.default_test_combinations())
|
||||||
def testNamedTuple(self):
|
def testNamedTuple(self):
|
||||||
Foo = collections.namedtuple("Foo", ["x", "y"])
|
Foo = collections.namedtuple("Foo", ["x", "y"])
|
||||||
|
|
|
||||||
|
|
@ -1017,7 +1017,7 @@ class DatasetV2(
|
||||||
# pylint: enable=g-import-not-at-top,protected-access
|
# pylint: enable=g-import-not-at-top,protected-access
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def zip(datasets, name=None):
|
def zip(*args, datasets=None, name=None):
|
||||||
"""Creates a `Dataset` by zipping together the given datasets.
|
"""Creates a `Dataset` by zipping together the given datasets.
|
||||||
|
|
||||||
This method has similar semantics to the built-in `zip()` function
|
This method has similar semantics to the built-in `zip()` function
|
||||||
|
|
@ -1026,14 +1026,14 @@ class DatasetV2(
|
||||||
nesting mechanisms are documented
|
nesting mechanisms are documented
|
||||||
[here] (https://www.tensorflow.org/guide/data#dataset_structure).
|
[here] (https://www.tensorflow.org/guide/data#dataset_structure).
|
||||||
|
|
||||||
>>> # The nested structure of the `datasets` argument determines the
|
>>> # The datasets or nested structure of datasets `*args` argument
|
||||||
>>> # structure of elements in the resulting dataset.
|
>>> # determines the structure of elements in the resulting dataset.
|
||||||
>>> a = tf.data.Dataset.range(1, 4) # ==> [ 1, 2, 3 ]
|
>>> a = tf.data.Dataset.range(1, 4) # ==> [ 1, 2, 3 ]
|
||||||
>>> b = tf.data.Dataset.range(4, 7) # ==> [ 4, 5, 6 ]
|
>>> b = tf.data.Dataset.range(4, 7) # ==> [ 4, 5, 6 ]
|
||||||
>>> ds = tf.data.Dataset.zip((a, b))
|
>>> ds = tf.data.Dataset.zip(a, b)
|
||||||
>>> list(ds.as_numpy_iterator())
|
>>> list(ds.as_numpy_iterator())
|
||||||
[(1, 4), (2, 5), (3, 6)]
|
[(1, 4), (2, 5), (3, 6)]
|
||||||
>>> ds = tf.data.Dataset.zip((b, a))
|
>>> ds = tf.data.Dataset.zip(b, a)
|
||||||
>>> list(ds.as_numpy_iterator())
|
>>> list(ds.as_numpy_iterator())
|
||||||
[(4, 1), (5, 2), (6, 3)]
|
[(4, 1), (5, 2), (6, 3)]
|
||||||
>>>
|
>>>
|
||||||
|
|
@ -1041,7 +1041,7 @@ class DatasetV2(
|
||||||
>>> c = tf.data.Dataset.range(7, 13).batch(2) # ==> [ [7, 8],
|
>>> c = tf.data.Dataset.range(7, 13).batch(2) # ==> [ [7, 8],
|
||||||
... # [9, 10],
|
... # [9, 10],
|
||||||
... # [11, 12] ]
|
... # [11, 12] ]
|
||||||
>>> ds = tf.data.Dataset.zip((a, b, c))
|
>>> ds = tf.data.Dataset.zip(a, b, c)
|
||||||
>>> for element in ds.as_numpy_iterator():
|
>>> for element in ds.as_numpy_iterator():
|
||||||
... print(element)
|
... print(element)
|
||||||
(1, 4, array([7, 8]))
|
(1, 4, array([7, 8]))
|
||||||
|
|
@ -1051,12 +1051,16 @@ class DatasetV2(
|
||||||
>>> # The number of elements in the resulting dataset is the same as
|
>>> # The number of elements in the resulting dataset is the same as
|
||||||
>>> # the size of the smallest dataset in `datasets`.
|
>>> # the size of the smallest dataset in `datasets`.
|
||||||
>>> d = tf.data.Dataset.range(13, 15) # ==> [ 13, 14 ]
|
>>> d = tf.data.Dataset.range(13, 15) # ==> [ 13, 14 ]
|
||||||
>>> ds = tf.data.Dataset.zip((a, d))
|
>>> ds = tf.data.Dataset.zip(a, d)
|
||||||
>>> list(ds.as_numpy_iterator())
|
>>> list(ds.as_numpy_iterator())
|
||||||
[(1, 13), (2, 14)]
|
[(1, 13), (2, 14)]
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
datasets: A (nested) structure of datasets.
|
*args: Datasets or nested structures of datasets to zip together. This
|
||||||
|
can't be set if `datasets` is set.
|
||||||
|
datasets: A (nested) structure of datasets. This can't be set if `*args`
|
||||||
|
is set. Note that this exists only for backwards compatibility and it is
|
||||||
|
preferred to use *args.
|
||||||
name: (Optional.) A name for the tf.data operation.
|
name: (Optional.) A name for the tf.data operation.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
|
|
@ -1066,8 +1070,16 @@ class DatasetV2(
|
||||||
# dataset_ops).
|
# dataset_ops).
|
||||||
# pylint: disable=g-import-not-at-top,protected-access
|
# pylint: disable=g-import-not-at-top,protected-access
|
||||||
from tensorflow.python.data.ops import zip_op
|
from tensorflow.python.data.ops import zip_op
|
||||||
|
|
||||||
|
if not args and datasets is None:
|
||||||
|
raise TypeError("Must pass at least one dataset to `zip`.")
|
||||||
|
if args and datasets is not None:
|
||||||
|
raise TypeError("Both `*args` and `datasets` cannot be set.")
|
||||||
|
if len(args) == 1:
|
||||||
|
datasets = args[0]
|
||||||
|
elif len(args) > 1:
|
||||||
|
datasets = args
|
||||||
return zip_op._zip(datasets, name)
|
return zip_op._zip(datasets, name)
|
||||||
# pylint: enable=g-import-not-at-top,protected-access
|
|
||||||
|
|
||||||
def concatenate(self, dataset, name=None):
|
def concatenate(self, dataset, name=None):
|
||||||
"""Creates a `Dataset` by concatenating the given dataset with this dataset.
|
"""Creates a `Dataset` by concatenating the given dataset with this dataset.
|
||||||
|
|
@ -3937,8 +3949,8 @@ class DatasetV1(DatasetV2, data_types.DatasetV1):
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@functools.wraps(DatasetV2.zip)
|
@functools.wraps(DatasetV2.zip)
|
||||||
def zip(datasets, name=None):
|
def zip(*args, datasets=None, name=None):
|
||||||
return DatasetV1Adapter(DatasetV2.zip(datasets, name=name))
|
return DatasetV1Adapter(DatasetV2.zip(*args, datasets=datasets, name=name))
|
||||||
|
|
||||||
@functools.wraps(DatasetV2.concatenate)
|
@functools.wraps(DatasetV2.concatenate)
|
||||||
def concatenate(self, dataset, name=None):
|
def concatenate(self, dataset, name=None):
|
||||||
|
|
|
||||||
|
|
@ -32,22 +32,26 @@ class _ZipDataset(dataset_ops.DatasetV2):
|
||||||
for ds in nest.flatten(datasets):
|
for ds in nest.flatten(datasets):
|
||||||
if not isinstance(ds, data_types.DatasetV2):
|
if not isinstance(ds, data_types.DatasetV2):
|
||||||
if isinstance(ds, list):
|
if isinstance(ds, list):
|
||||||
raise TypeError("Invalid `datasets`. `datasets` is expected to be a "
|
raise TypeError(
|
||||||
"(nested) structure of `tf.data.Dataset` objects. "
|
"Invalid input to `zip`. Inputs are expected to be (nested)"
|
||||||
"Python `list` is not supported and you should use "
|
" structures of `tf.data.Dataset` objects. Python `list` is"
|
||||||
"`tuple` instead.")
|
" not supported and you should use `tuple` instead."
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise TypeError(f"Invalid `datasets`. `datasets` is expected to be a "
|
raise TypeError(
|
||||||
f"(nested) structure of `tf.data.Dataset` objects "
|
"Invalid input to `zip`. Inputs are expected to be (nested)"
|
||||||
f"but encountered object of type {type(ds)}.")
|
" structures of `tf.data.Dataset` objects but"
|
||||||
|
f" encountered object of type {type(ds)}."
|
||||||
|
)
|
||||||
self._datasets = datasets
|
self._datasets = datasets
|
||||||
self._structure = nest.pack_sequence_as(
|
self._structure = nest.pack_sequence_as(
|
||||||
self._datasets,
|
self._datasets, [ds.element_spec for ds in nest.flatten(self._datasets)]
|
||||||
[ds.element_spec for ds in nest.flatten(self._datasets)])
|
)
|
||||||
self._name = name
|
self._name = name
|
||||||
variant_tensor = gen_dataset_ops.zip_dataset(
|
variant_tensor = gen_dataset_ops.zip_dataset(
|
||||||
[ds._variant_tensor for ds in nest.flatten(self._datasets)],
|
[ds._variant_tensor for ds in nest.flatten(self._datasets)],
|
||||||
**self._common_args)
|
**self._common_args,
|
||||||
|
)
|
||||||
super().__init__(variant_tensor)
|
super().__init__(variant_tensor)
|
||||||
|
|
||||||
def _inputs(self):
|
def _inputs(self):
|
||||||
|
|
|
||||||
|
|
@ -229,6 +229,6 @@ tf_class {
|
||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "zip"
|
name: "zip"
|
||||||
argspec: "args=[\'datasets\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
argspec: "args=[], varargs=args, keywords=None, defaults=None"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -231,6 +231,6 @@ tf_class {
|
||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "zip"
|
name: "zip"
|
||||||
argspec: "args=[\'datasets\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
argspec: "args=[], varargs=args, keywords=None, defaults=None"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -231,6 +231,6 @@ tf_class {
|
||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "zip"
|
name: "zip"
|
||||||
argspec: "args=[\'datasets\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
argspec: "args=[], varargs=args, keywords=None, defaults=None"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -231,6 +231,6 @@ tf_class {
|
||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "zip"
|
name: "zip"
|
||||||
argspec: "args=[\'datasets\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
argspec: "args=[], varargs=args, keywords=None, defaults=None"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -231,6 +231,6 @@ tf_class {
|
||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "zip"
|
name: "zip"
|
||||||
argspec: "args=[\'datasets\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
argspec: "args=[], varargs=args, keywords=None, defaults=None"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -231,6 +231,6 @@ tf_class {
|
||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "zip"
|
name: "zip"
|
||||||
argspec: "args=[\'datasets\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
argspec: "args=[], varargs=args, keywords=None, defaults=None"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -231,6 +231,6 @@ tf_class {
|
||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "zip"
|
name: "zip"
|
||||||
argspec: "args=[\'datasets\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
argspec: "args=[], varargs=args, keywords=None, defaults=None"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -196,6 +196,6 @@ tf_class {
|
||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "zip"
|
name: "zip"
|
||||||
argspec: "args=[\'datasets\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
argspec: "args=[], varargs=args, keywords=None, defaults=None"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -198,6 +198,6 @@ tf_class {
|
||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "zip"
|
name: "zip"
|
||||||
argspec: "args=[\'datasets\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
argspec: "args=[], varargs=args, keywords=None, defaults=None"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -197,6 +197,6 @@ tf_class {
|
||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "zip"
|
name: "zip"
|
||||||
argspec: "args=[\'datasets\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
argspec: "args=[], varargs=args, keywords=None, defaults=None"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -198,6 +198,6 @@ tf_class {
|
||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "zip"
|
name: "zip"
|
||||||
argspec: "args=[\'datasets\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
argspec: "args=[], varargs=args, keywords=None, defaults=None"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -198,6 +198,6 @@ tf_class {
|
||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "zip"
|
name: "zip"
|
||||||
argspec: "args=[\'datasets\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
argspec: "args=[], varargs=args, keywords=None, defaults=None"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -199,6 +199,6 @@ tf_class {
|
||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "zip"
|
name: "zip"
|
||||||
argspec: "args=[\'datasets\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
argspec: "args=[], varargs=args, keywords=None, defaults=None"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -198,6 +198,6 @@ tf_class {
|
||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "zip"
|
name: "zip"
|
||||||
argspec: "args=[\'datasets\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
argspec: "args=[], varargs=args, keywords=None, defaults=None"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -199,6 +199,6 @@ tf_class {
|
||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "zip"
|
name: "zip"
|
||||||
argspec: "args=[\'datasets\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
argspec: "args=[], varargs=args, keywords=None, defaults=None"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user