Graduate tf.data.experimental.Counter to tf.data.Dataset.counter.

PiperOrigin-RevId: 470721700
This commit is contained in:
Matt Callanan 2022-08-29 08:11:57 -07:00 committed by TensorFlower Gardener
parent 5d31f9f48e
commit 6a5c94b82e
22 changed files with 193 additions and 6 deletions

View File

@ -90,6 +90,7 @@ This release contains contributions from many people at Google, as well as:
* Added a new field, `inject_prefetch`, to `tf.data.experimental.OptimizationOptions`. If it is set to `True`, tf.data will now automatically add a `prefetch` transformation to datasets that end in synchronous transformations. This enables data generation to be overlapped with data consumption. This may cause a small increase in memory usage due to buffering. To enable this behavior, set `inject_prefetch=True` in `tf.data.experimental.OptimizationOptions`.
* Added a new value to `tf.data.Options.autotune.autotune_algorithm`: STAGE_BASED. If the autotune algorithm is set to STAGE_BASED, then it runs a new algorithm that can get the same performance with lower CPU/memory usage.
* Added [`tf.data.experimental.from_list`](https://www.tensorflow.org/api_docs/python/tf/data/experimental/from_list), a new API for creating `Dataset`s from lists of elements.
* Graduated `tf.data.experimental.Counter` to [`tf.data.Dataset.counter`](https://www.tensorflow.org/api_docs/python/tf/data/Dataset/#counter).
* `tf.distribute`:

View File

@ -14,13 +14,15 @@
# ==============================================================================
"""The Counter Dataset."""
from tensorflow.python import tf2
from tensorflow.python.data.ops import counter_op
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export
@tf_export("data.experimental.Counter", v1=[])
@deprecation.deprecated(None, "Use `tf.data.Dataset.counter(...)` instead.")
def CounterV2(start=0, step=1, dtype=dtypes.int64):
"""Creates a `Dataset` that counts from `start` in steps of size `step`.
@ -54,14 +56,11 @@ def CounterV2(start=0, step=1, dtype=dtypes.int64):
Returns:
A `Dataset` of scalar `dtype` elements.
"""
with ops.name_scope("counter"):
start = ops.convert_to_tensor(start, dtype=dtype, name="start")
step = ops.convert_to_tensor(step, dtype=dtype, name="step")
return dataset_ops.Dataset.from_tensors(0).repeat(None).scan(
start, lambda state, _: (state + step, state))
return counter_op.counter(start, step, dtype)
@tf_export(v1=["data.experimental.Counter"])
@deprecation.deprecated(None, "Use `tf.data.Dataset.counter(...)` instead.")
def CounterV1(start=0, step=1, dtype=dtypes.int64):
return dataset_ops.DatasetV1Adapter(CounterV2(start, step, dtype))

View File

@ -165,6 +165,17 @@ tf_py_test(
],
)
tf_py_test(
name = "counter_test",
size = "small",
srcs = ["counter_test.py"],
deps = [
"//tensorflow/python:client_testlib",
"//tensorflow/python:dtypes",
"//tensorflow/python/data/kernel_tests:test_base",
],
)
tf_py_test(
name = "dataset_spec_test",
size = "small",

View File

@ -0,0 +1,44 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for `tf.data.Dataset.counter`."""
from absl.testing import parameterized
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import combinations
from tensorflow.python.framework import dtypes
from tensorflow.python.platform import test
class CounterTest(test_base.DatasetTestBase, parameterized.TestCase):
@combinations.generate(
combinations.times(
test_base.default_test_combinations(),
combinations.combine(start=3, step=4, expected_output=[[3, 7, 11]]) +
combinations.combine(start=0, step=-1, expected_output=[[0, -1, -2]]))
)
def testCounter(self, start, step, expected_output):
dataset = dataset_ops.Dataset.counter(start, step)
self.assertEqual(
[], dataset_ops.get_legacy_output_shapes(dataset).as_list())
self.assertEqual(dtypes.int64, dataset_ops.get_legacy_output_types(dataset))
get_next = self.getNext(dataset)
for expected in expected_output:
self.assertEqual(expected, self.evaluate(get_next()))
if __name__ == "__main__":
test.main()

View File

@ -16,11 +16,17 @@ py_library(
],
)
py_library(
name = "counter_op",
srcs = ["counter_op.py"],
)
py_library(
name = "dataset_ops",
srcs = ["dataset_ops.py"],
srcs_version = "PY3",
deps = [
":counter_op",
":from_tensor_slices_op",
":iterator_ops",
":load_ops",

View File

@ -0,0 +1,26 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The implementation of `tf.data.Dataset.counter`."""
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import ops
def counter(start, step, dtype, name=None):
with ops.name_scope("counter"):
start = ops.convert_to_tensor(start, dtype=dtype, name="start")
step = ops.convert_to_tensor(step, dtype=dtype, name="step")
return (dataset_ops.Dataset.from_tensors(0, name=name).repeat(None).scan(
start, lambda state, _: (state + step, state)))

View File

@ -1288,6 +1288,46 @@ class DatasetV2(
"""
return ConcatenateDataset(self, dataset, name=name)
@staticmethod
def counter(start=0, step=1, dtype=dtypes.int64, name=None):
"""Creates a `Dataset` that counts from `start` in steps of size `step`.
Unlike `tf.data.Dataset.range`, which stops at some ending number,
`tf.data.Dataset.counter` produces elements indefinitely.
>>> dataset = tf.data.experimental.Counter().take(5)
>>> list(dataset.as_numpy_iterator())
[0, 1, 2, 3, 4]
>>> dataset.element_spec
TensorSpec(shape=(), dtype=tf.int64, name=None)
>>> dataset = tf.data.experimental.Counter(dtype=tf.int32)
>>> dataset.element_spec
TensorSpec(shape=(), dtype=tf.int32, name=None)
>>> dataset = tf.data.experimental.Counter(start=2).take(5)
>>> list(dataset.as_numpy_iterator())
[2, 3, 4, 5, 6]
>>> dataset = tf.data.experimental.Counter(start=2, step=5).take(5)
>>> list(dataset.as_numpy_iterator())
[2, 7, 12, 17, 22]
>>> dataset = tf.data.experimental.Counter(start=10, step=-1).take(5)
>>> list(dataset.as_numpy_iterator())
[10, 9, 8, 7, 6]
Args:
start: (Optional.) The starting value for the counter. Defaults to 0.
step: (Optional.) The step size for the counter. Defaults to 1.
dtype: (Optional.) The data type for counter elements. Defaults to
`tf.int64`.
name: (Optional.) A name for the tf.data operation.
Returns:
A `Dataset` of scalar `dtype` elements.
"""
# Loaded lazily due to a circular dependency (dataset_ops -> counter_op
# -> dataset_ops).
from tensorflow.python.data.ops import counter_op # pylint: disable=g-import-not-at-top
return counter_op.counter(start, step, dtype, name=name)
def prefetch(self, buffer_size, name=None):
"""Creates a `Dataset` that prefetches elements from this dataset.

View File

@ -55,6 +55,10 @@ tf_class {
name: "concatenate"
argspec: "args=[\'self\', \'dataset\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "counter"
argspec: "args=[\'start\', \'step\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'1\', \"<dtype: \'int64\'>\", \'None\'], "
}
member_method {
name: "enumerate"
argspec: "args=[\'self\', \'start\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'None\'], "

View File

@ -57,6 +57,10 @@ tf_class {
name: "concatenate"
argspec: "args=[\'self\', \'dataset\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "counter"
argspec: "args=[\'start\', \'step\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'1\', \"<dtype: \'int64\'>\", \'None\'], "
}
member_method {
name: "enumerate"
argspec: "args=[\'self\', \'start\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'None\'], "

View File

@ -57,6 +57,10 @@ tf_class {
name: "concatenate"
argspec: "args=[\'self\', \'dataset\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "counter"
argspec: "args=[\'start\', \'step\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'1\', \"<dtype: \'int64\'>\", \'None\'], "
}
member_method {
name: "enumerate"
argspec: "args=[\'self\', \'start\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'None\'], "

View File

@ -57,6 +57,10 @@ tf_class {
name: "concatenate"
argspec: "args=[\'self\', \'dataset\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "counter"
argspec: "args=[\'start\', \'step\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'1\', \"<dtype: \'int64\'>\", \'None\'], "
}
member_method {
name: "enumerate"
argspec: "args=[\'self\', \'start\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'None\'], "

View File

@ -57,6 +57,10 @@ tf_class {
name: "concatenate"
argspec: "args=[\'self\', \'dataset\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "counter"
argspec: "args=[\'start\', \'step\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'1\', \"<dtype: \'int64\'>\", \'None\'], "
}
member_method {
name: "enumerate"
argspec: "args=[\'self\', \'start\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'None\'], "

View File

@ -57,6 +57,10 @@ tf_class {
name: "concatenate"
argspec: "args=[\'self\', \'dataset\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "counter"
argspec: "args=[\'start\', \'step\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'1\', \"<dtype: \'int64\'>\", \'None\'], "
}
member_method {
name: "enumerate"
argspec: "args=[\'self\', \'start\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'None\'], "

View File

@ -57,6 +57,10 @@ tf_class {
name: "concatenate"
argspec: "args=[\'self\', \'dataset\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "counter"
argspec: "args=[\'start\', \'step\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'1\', \"<dtype: \'int64\'>\", \'None\'], "
}
member_method {
name: "enumerate"
argspec: "args=[\'self\', \'start\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'None\'], "

View File

@ -42,6 +42,10 @@ tf_class {
name: "concatenate"
argspec: "args=[\'self\', \'dataset\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "counter"
argspec: "args=[\'start\', \'step\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'1\', \"<dtype: \'int64\'>\", \'None\'], "
}
member_method {
name: "enumerate"
argspec: "args=[\'self\', \'start\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'None\'], "

View File

@ -44,6 +44,10 @@ tf_class {
name: "concatenate"
argspec: "args=[\'self\', \'dataset\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "counter"
argspec: "args=[\'start\', \'step\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'1\', \"<dtype: \'int64\'>\", \'None\'], "
}
member_method {
name: "enumerate"
argspec: "args=[\'self\', \'start\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'None\'], "

View File

@ -43,6 +43,10 @@ tf_class {
name: "concatenate"
argspec: "args=[\'self\', \'dataset\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "counter"
argspec: "args=[\'start\', \'step\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'1\', \"<dtype: \'int64\'>\", \'None\'], "
}
member_method {
name: "enumerate"
argspec: "args=[\'self\', \'start\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'None\'], "

View File

@ -44,6 +44,10 @@ tf_class {
name: "concatenate"
argspec: "args=[\'self\', \'dataset\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "counter"
argspec: "args=[\'start\', \'step\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'1\', \"<dtype: \'int64\'>\", \'None\'], "
}
member_method {
name: "enumerate"
argspec: "args=[\'self\', \'start\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'None\'], "

View File

@ -44,6 +44,10 @@ tf_class {
name: "concatenate"
argspec: "args=[\'self\', \'dataset\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "counter"
argspec: "args=[\'start\', \'step\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'1\', \"<dtype: \'int64\'>\", \'None\'], "
}
member_method {
name: "enumerate"
argspec: "args=[\'self\', \'start\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'None\'], "

View File

@ -45,6 +45,10 @@ tf_class {
name: "concatenate"
argspec: "args=[\'self\', \'dataset\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "counter"
argspec: "args=[\'start\', \'step\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'1\', \"<dtype: \'int64\'>\", \'None\'], "
}
member_method {
name: "enumerate"
argspec: "args=[\'self\', \'start\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'None\'], "

View File

@ -44,6 +44,10 @@ tf_class {
name: "concatenate"
argspec: "args=[\'self\', \'dataset\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "counter"
argspec: "args=[\'start\', \'step\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'1\', \"<dtype: \'int64\'>\", \'None\'], "
}
member_method {
name: "enumerate"
argspec: "args=[\'self\', \'start\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'None\'], "

View File

@ -45,6 +45,10 @@ tf_class {
name: "concatenate"
argspec: "args=[\'self\', \'dataset\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "counter"
argspec: "args=[\'start\', \'step\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'1\', \"<dtype: \'int64\'>\", \'None\'], "
}
member_method {
name: "enumerate"
argspec: "args=[\'self\', \'start\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'None\'], "