mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 00:19:58 +01:00
#tf-data Support Python-style zipping in tf.data.Dataset.zip
The current behavior is unchanged to avoid affecting backward compatibility but only added a new behavior to `tf.data.Dataset.zip` Python op for accepting a non-nested structure of datasets without doing the tuple. PiperOrigin-RevId: 515389129
This commit is contained in:
parent
253052b195
commit
fe1cc8d2c4
|
|
@ -85,6 +85,11 @@
|
|||
`tf.nn.safe_embedding_lookup_sparse`, which enables a simplified and
|
||||
typically faster lookup procedure.
|
||||
|
||||
* `tf.data`
|
||||
|
||||
* `tf.data.Dataset.zip` now supports Python-style zipping, i.e.
|
||||
`Dataset.zip(a, b, c)`.
|
||||
|
||||
## Bug Fixes and Other Changes
|
||||
|
||||
* <SIMILAR TO ABOVE SECTION, BUT FOR OTHER IMPORTANT CHANGES / BUG FIXES>
|
||||
|
|
|
|||
|
|
@ -134,15 +134,53 @@ class ZipTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||
dataset = dataset_ops.Dataset.zip((x, y), name="zip")
|
||||
self.assertDatasetProduces(dataset, [(4, 2)])
|
||||
|
||||
@combinations.generate(
|
||||
combinations.times(test_base.default_test_combinations())
|
||||
)
|
||||
def testZipWithArgsAndDataset(self):
|
||||
with self.assertRaisesRegex(
|
||||
TypeError, r"Both `\*args` and `datasets` cannot be set."
|
||||
):
|
||||
dataset_ops.Dataset.zip(
|
||||
dataset_ops.Dataset.range(1, 4),
|
||||
dataset_ops.Dataset.range(4, 7),
|
||||
datasets=(
|
||||
dataset_ops.Dataset.range(1, 4),
|
||||
dataset_ops.Dataset.range(4, 7),
|
||||
),
|
||||
)
|
||||
|
||||
class ZipCheckpointTest(checkpoint_test_base.CheckpointTestBase,
|
||||
parameterized.TestCase):
|
||||
@combinations.generate(
|
||||
combinations.times(test_base.default_test_combinations())
|
||||
)
|
||||
def testZipBasicWithNoInput(self):
|
||||
with self.assertRaisesRegex(
|
||||
TypeError, r"Must pass at least one dataset to `zip`."
|
||||
):
|
||||
dataset_ops.Dataset.zip()
|
||||
|
||||
@combinations.generate(
|
||||
combinations.times(test_base.default_test_combinations())
|
||||
)
|
||||
def InvalidZipInputList(self):
|
||||
with self.assertRaisesRegex(
|
||||
TypeError,
|
||||
r"Invalid input to `zip`. Inputs are expected to be (nested)"
|
||||
r" structures of `tf.data.Dataset` objects. Python `list` is"
|
||||
r" not supported and you should use `tuple` instead.",
|
||||
):
|
||||
dataset_ops.Dataset.zip([1, 2, 3], [4, 5, 6])
|
||||
|
||||
|
||||
class ZipCheckpointTest(
|
||||
checkpoint_test_base.CheckpointTestBase, parameterized.TestCase
|
||||
):
|
||||
|
||||
def _build_dataset(self, arr, options=None):
|
||||
components = [
|
||||
np.tile(np.array([[1], [2], [3], [4]]), 20),
|
||||
np.tile(np.array([[12], [13], [14], [15]]), 22),
|
||||
np.array(arr)
|
||||
np.array(arr),
|
||||
]
|
||||
datasets = [
|
||||
dataset_ops.Dataset.from_tensor_slices(component)
|
||||
|
|
@ -200,6 +238,19 @@ class ZipRandomAccessTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||
|
||||
@combinations.generate(
|
||||
combinations.times(test_base.default_test_combinations()))
|
||||
def testZipBasicWithoutTuple(self):
|
||||
dataset = dataset_ops.Dataset.zip(
|
||||
dataset_ops.Dataset.range(1, 4), dataset_ops.Dataset.range(4, 7)
|
||||
)
|
||||
expected_dataset = [(1, 4), (2, 5), (3, 6)]
|
||||
for i in range(3):
|
||||
self.assertEqual(
|
||||
self.evaluate(random_access.at(dataset, index=i)), expected_dataset[i]
|
||||
)
|
||||
|
||||
@combinations.generate(
|
||||
combinations.times(test_base.default_test_combinations())
|
||||
)
|
||||
def testZipEqual(self):
|
||||
components = [
|
||||
np.tile(np.array([[1], [2], [3], [4]]), 20),
|
||||
|
|
@ -246,6 +297,28 @@ class ZipRandomAccessTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||
with self.assertRaises(errors.OutOfRangeError):
|
||||
self.evaluate(random_access.at(dataset, index=4))
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testNestedWithoutTuple(self):
|
||||
components = [
|
||||
np.tile(np.array([[1], [2], [3], [4]]), 20),
|
||||
np.tile(np.array([[12], [13], [14], [15]]), 22),
|
||||
np.array([37.0, 38.0, 39.0, 40.0]),
|
||||
]
|
||||
datasets = [
|
||||
dataset_ops.Dataset.from_tensor_slices(component)
|
||||
for component in components
|
||||
]
|
||||
dataset = dataset_ops.Dataset.zip(datasets[0], (datasets[1], datasets[2]))
|
||||
for i in range(4):
|
||||
result1, (result2, result3) = self.evaluate(
|
||||
random_access.at(dataset, index=i)
|
||||
)
|
||||
self.assertAllEqual(components[0][i], result1)
|
||||
self.assertAllEqual(components[1][i], result2)
|
||||
self.assertAllEqual(components[2][i], result3)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
self.evaluate(random_access.at(dataset, index=4))
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testNamedTuple(self):
|
||||
Foo = collections.namedtuple("Foo", ["x", "y"])
|
||||
|
|
|
|||
|
|
@ -1017,7 +1017,7 @@ class DatasetV2(
|
|||
# pylint: enable=g-import-not-at-top,protected-access
|
||||
|
||||
@staticmethod
|
||||
def zip(datasets, name=None):
|
||||
def zip(*args, datasets=None, name=None):
|
||||
"""Creates a `Dataset` by zipping together the given datasets.
|
||||
|
||||
This method has similar semantics to the built-in `zip()` function
|
||||
|
|
@ -1026,14 +1026,14 @@ class DatasetV2(
|
|||
nesting mechanisms are documented
|
||||
[here] (https://www.tensorflow.org/guide/data#dataset_structure).
|
||||
|
||||
>>> # The nested structure of the `datasets` argument determines the
|
||||
>>> # structure of elements in the resulting dataset.
|
||||
>>> # The datasets or nested structure of datasets `*args` argument
|
||||
>>> # determines the structure of elements in the resulting dataset.
|
||||
>>> a = tf.data.Dataset.range(1, 4) # ==> [ 1, 2, 3 ]
|
||||
>>> b = tf.data.Dataset.range(4, 7) # ==> [ 4, 5, 6 ]
|
||||
>>> ds = tf.data.Dataset.zip((a, b))
|
||||
>>> ds = tf.data.Dataset.zip(a, b)
|
||||
>>> list(ds.as_numpy_iterator())
|
||||
[(1, 4), (2, 5), (3, 6)]
|
||||
>>> ds = tf.data.Dataset.zip((b, a))
|
||||
>>> ds = tf.data.Dataset.zip(b, a)
|
||||
>>> list(ds.as_numpy_iterator())
|
||||
[(4, 1), (5, 2), (6, 3)]
|
||||
>>>
|
||||
|
|
@ -1041,7 +1041,7 @@ class DatasetV2(
|
|||
>>> c = tf.data.Dataset.range(7, 13).batch(2) # ==> [ [7, 8],
|
||||
... # [9, 10],
|
||||
... # [11, 12] ]
|
||||
>>> ds = tf.data.Dataset.zip((a, b, c))
|
||||
>>> ds = tf.data.Dataset.zip(a, b, c)
|
||||
>>> for element in ds.as_numpy_iterator():
|
||||
... print(element)
|
||||
(1, 4, array([7, 8]))
|
||||
|
|
@ -1051,12 +1051,16 @@ class DatasetV2(
|
|||
>>> # The number of elements in the resulting dataset is the same as
|
||||
>>> # the size of the smallest dataset in `datasets`.
|
||||
>>> d = tf.data.Dataset.range(13, 15) # ==> [ 13, 14 ]
|
||||
>>> ds = tf.data.Dataset.zip((a, d))
|
||||
>>> ds = tf.data.Dataset.zip(a, d)
|
||||
>>> list(ds.as_numpy_iterator())
|
||||
[(1, 13), (2, 14)]
|
||||
|
||||
Args:
|
||||
datasets: A (nested) structure of datasets.
|
||||
*args: Datasets or nested structures of datasets to zip together. This
|
||||
can't be set if `datasets` is set.
|
||||
datasets: A (nested) structure of datasets. This can't be set if `*args`
|
||||
is set. Note that this exists only for backwards compatibility and it is
|
||||
preferred to use *args.
|
||||
name: (Optional.) A name for the tf.data operation.
|
||||
|
||||
Returns:
|
||||
|
|
@ -1066,8 +1070,16 @@ class DatasetV2(
|
|||
# dataset_ops).
|
||||
# pylint: disable=g-import-not-at-top,protected-access
|
||||
from tensorflow.python.data.ops import zip_op
|
||||
|
||||
if not args and datasets is None:
|
||||
raise TypeError("Must pass at least one dataset to `zip`.")
|
||||
if args and datasets is not None:
|
||||
raise TypeError("Both `*args` and `datasets` cannot be set.")
|
||||
if len(args) == 1:
|
||||
datasets = args[0]
|
||||
elif len(args) > 1:
|
||||
datasets = args
|
||||
return zip_op._zip(datasets, name)
|
||||
# pylint: enable=g-import-not-at-top,protected-access
|
||||
|
||||
def concatenate(self, dataset, name=None):
|
||||
"""Creates a `Dataset` by concatenating the given dataset with this dataset.
|
||||
|
|
@ -3937,8 +3949,8 @@ class DatasetV1(DatasetV2, data_types.DatasetV1):
|
|||
|
||||
@staticmethod
|
||||
@functools.wraps(DatasetV2.zip)
|
||||
def zip(datasets, name=None):
|
||||
return DatasetV1Adapter(DatasetV2.zip(datasets, name=name))
|
||||
def zip(*args, datasets=None, name=None):
|
||||
return DatasetV1Adapter(DatasetV2.zip(*args, datasets=datasets, name=name))
|
||||
|
||||
@functools.wraps(DatasetV2.concatenate)
|
||||
def concatenate(self, dataset, name=None):
|
||||
|
|
|
|||
|
|
@ -32,22 +32,26 @@ class _ZipDataset(dataset_ops.DatasetV2):
|
|||
for ds in nest.flatten(datasets):
|
||||
if not isinstance(ds, data_types.DatasetV2):
|
||||
if isinstance(ds, list):
|
||||
raise TypeError("Invalid `datasets`. `datasets` is expected to be a "
|
||||
"(nested) structure of `tf.data.Dataset` objects. "
|
||||
"Python `list` is not supported and you should use "
|
||||
"`tuple` instead.")
|
||||
raise TypeError(
|
||||
"Invalid input to `zip`. Inputs are expected to be (nested)"
|
||||
" structures of `tf.data.Dataset` objects. Python `list` is"
|
||||
" not supported and you should use `tuple` instead."
|
||||
)
|
||||
else:
|
||||
raise TypeError(f"Invalid `datasets`. `datasets` is expected to be a "
|
||||
f"(nested) structure of `tf.data.Dataset` objects "
|
||||
f"but encountered object of type {type(ds)}.")
|
||||
raise TypeError(
|
||||
"Invalid input to `zip`. Inputs are expected to be (nested)"
|
||||
" structures of `tf.data.Dataset` objects but"
|
||||
f" encountered object of type {type(ds)}."
|
||||
)
|
||||
self._datasets = datasets
|
||||
self._structure = nest.pack_sequence_as(
|
||||
self._datasets,
|
||||
[ds.element_spec for ds in nest.flatten(self._datasets)])
|
||||
self._datasets, [ds.element_spec for ds in nest.flatten(self._datasets)]
|
||||
)
|
||||
self._name = name
|
||||
variant_tensor = gen_dataset_ops.zip_dataset(
|
||||
[ds._variant_tensor for ds in nest.flatten(self._datasets)],
|
||||
**self._common_args)
|
||||
**self._common_args,
|
||||
)
|
||||
super().__init__(variant_tensor)
|
||||
|
||||
def _inputs(self):
|
||||
|
|
|
|||
|
|
@ -229,6 +229,6 @@ tf_class {
|
|||
}
|
||||
member_method {
|
||||
name: "zip"
|
||||
argspec: "args=[\'datasets\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
argspec: "args=[], varargs=args, keywords=None, defaults=None"
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -231,6 +231,6 @@ tf_class {
|
|||
}
|
||||
member_method {
|
||||
name: "zip"
|
||||
argspec: "args=[\'datasets\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
argspec: "args=[], varargs=args, keywords=None, defaults=None"
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -231,6 +231,6 @@ tf_class {
|
|||
}
|
||||
member_method {
|
||||
name: "zip"
|
||||
argspec: "args=[\'datasets\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
argspec: "args=[], varargs=args, keywords=None, defaults=None"
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -231,6 +231,6 @@ tf_class {
|
|||
}
|
||||
member_method {
|
||||
name: "zip"
|
||||
argspec: "args=[\'datasets\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
argspec: "args=[], varargs=args, keywords=None, defaults=None"
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -231,6 +231,6 @@ tf_class {
|
|||
}
|
||||
member_method {
|
||||
name: "zip"
|
||||
argspec: "args=[\'datasets\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
argspec: "args=[], varargs=args, keywords=None, defaults=None"
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -231,6 +231,6 @@ tf_class {
|
|||
}
|
||||
member_method {
|
||||
name: "zip"
|
||||
argspec: "args=[\'datasets\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
argspec: "args=[], varargs=args, keywords=None, defaults=None"
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -231,6 +231,6 @@ tf_class {
|
|||
}
|
||||
member_method {
|
||||
name: "zip"
|
||||
argspec: "args=[\'datasets\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
argspec: "args=[], varargs=args, keywords=None, defaults=None"
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -196,6 +196,6 @@ tf_class {
|
|||
}
|
||||
member_method {
|
||||
name: "zip"
|
||||
argspec: "args=[\'datasets\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
argspec: "args=[], varargs=args, keywords=None, defaults=None"
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -198,6 +198,6 @@ tf_class {
|
|||
}
|
||||
member_method {
|
||||
name: "zip"
|
||||
argspec: "args=[\'datasets\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
argspec: "args=[], varargs=args, keywords=None, defaults=None"
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -197,6 +197,6 @@ tf_class {
|
|||
}
|
||||
member_method {
|
||||
name: "zip"
|
||||
argspec: "args=[\'datasets\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
argspec: "args=[], varargs=args, keywords=None, defaults=None"
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -198,6 +198,6 @@ tf_class {
|
|||
}
|
||||
member_method {
|
||||
name: "zip"
|
||||
argspec: "args=[\'datasets\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
argspec: "args=[], varargs=args, keywords=None, defaults=None"
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -198,6 +198,6 @@ tf_class {
|
|||
}
|
||||
member_method {
|
||||
name: "zip"
|
||||
argspec: "args=[\'datasets\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
argspec: "args=[], varargs=args, keywords=None, defaults=None"
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -199,6 +199,6 @@ tf_class {
|
|||
}
|
||||
member_method {
|
||||
name: "zip"
|
||||
argspec: "args=[\'datasets\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
argspec: "args=[], varargs=args, keywords=None, defaults=None"
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -198,6 +198,6 @@ tf_class {
|
|||
}
|
||||
member_method {
|
||||
name: "zip"
|
||||
argspec: "args=[\'datasets\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
argspec: "args=[], varargs=args, keywords=None, defaults=None"
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -199,6 +199,6 @@ tf_class {
|
|||
}
|
||||
member_method {
|
||||
name: "zip"
|
||||
argspec: "args=[\'datasets\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
argspec: "args=[], varargs=args, keywords=None, defaults=None"
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user