mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
Graduate tf.data.experimental.Counter to tf.data.Dataset.counter.
PiperOrigin-RevId: 470721700
This commit is contained in:
parent
5d31f9f48e
commit
6a5c94b82e
|
|
@ -90,6 +90,7 @@ This release contains contributions from many people at Google, as well as:
|
|||
* Added a new field, `inject_prefetch`, to `tf.data.experimental.OptimizationOptions`. If it is set to `True`, tf.data will now automatically add a `prefetch` transformation to datasets that end in synchronous transformations. This enables data generation to be overlapped with data consumption. This may cause a small increase in memory usage due to buffering. To enable this behavior, set `inject_prefetch=True` in `tf.data.experimental.OptimizationOptions`.
|
||||
* Added a new value to `tf.data.Options.autotune.autotune_algorithm`: STAGE_BASED. If the autotune algorithm is set to STAGE_BASED, then it runs a new algorithm that can get the same performance with lower CPU/memory usage.
|
||||
* Added [`tf.data.experimental.from_list`](https://www.tensorflow.org/api_docs/python/tf/data/experimental/from_list), a new API for creating `Dataset`s from lists of elements.
|
||||
* Graduated `tf.data.experimental.Counter` to [`tf.data.Dataset.counter`](https://www.tensorflow.org/api_docs/python/tf/data/Dataset/#counter).
|
||||
|
||||
* `tf.distribute`:
|
||||
|
||||
|
|
|
|||
|
|
@ -14,13 +14,15 @@
|
|||
# ==============================================================================
|
||||
"""The Counter Dataset."""
|
||||
from tensorflow.python import tf2
|
||||
from tensorflow.python.data.ops import counter_op
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.util import deprecation
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
|
||||
|
||||
@tf_export("data.experimental.Counter", v1=[])
|
||||
@deprecation.deprecated(None, "Use `tf.data.Dataset.counter(...)` instead.")
|
||||
def CounterV2(start=0, step=1, dtype=dtypes.int64):
|
||||
"""Creates a `Dataset` that counts from `start` in steps of size `step`.
|
||||
|
||||
|
|
@ -54,14 +56,11 @@ def CounterV2(start=0, step=1, dtype=dtypes.int64):
|
|||
Returns:
|
||||
A `Dataset` of scalar `dtype` elements.
|
||||
"""
|
||||
with ops.name_scope("counter"):
|
||||
start = ops.convert_to_tensor(start, dtype=dtype, name="start")
|
||||
step = ops.convert_to_tensor(step, dtype=dtype, name="step")
|
||||
return dataset_ops.Dataset.from_tensors(0).repeat(None).scan(
|
||||
start, lambda state, _: (state + step, state))
|
||||
return counter_op.counter(start, step, dtype)
|
||||
|
||||
|
||||
@tf_export(v1=["data.experimental.Counter"])
|
||||
@deprecation.deprecated(None, "Use `tf.data.Dataset.counter(...)` instead.")
|
||||
def CounterV1(start=0, step=1, dtype=dtypes.int64):
|
||||
return dataset_ops.DatasetV1Adapter(CounterV2(start, step, dtype))
|
||||
|
||||
|
|
|
|||
|
|
@ -165,6 +165,17 @@ tf_py_test(
|
|||
],
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
name = "counter_test",
|
||||
size = "small",
|
||||
srcs = ["counter_test.py"],
|
||||
deps = [
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python/data/kernel_tests:test_base",
|
||||
],
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
name = "dataset_spec_test",
|
||||
size = "small",
|
||||
|
|
|
|||
44
tensorflow/python/data/kernel_tests/counter_test.py
Normal file
44
tensorflow/python/data/kernel_tests/counter_test.py
Normal file
|
|
@ -0,0 +1,44 @@
|
|||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for `tf.data.Dataset.counter`."""
|
||||
from absl.testing import parameterized
|
||||
|
||||
from tensorflow.python.data.kernel_tests import test_base
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.framework import combinations
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class CounterTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
|
||||
@combinations.generate(
|
||||
combinations.times(
|
||||
test_base.default_test_combinations(),
|
||||
combinations.combine(start=3, step=4, expected_output=[[3, 7, 11]]) +
|
||||
combinations.combine(start=0, step=-1, expected_output=[[0, -1, -2]]))
|
||||
)
|
||||
def testCounter(self, start, step, expected_output):
|
||||
dataset = dataset_ops.Dataset.counter(start, step)
|
||||
self.assertEqual(
|
||||
[], dataset_ops.get_legacy_output_shapes(dataset).as_list())
|
||||
self.assertEqual(dtypes.int64, dataset_ops.get_legacy_output_types(dataset))
|
||||
get_next = self.getNext(dataset)
|
||||
for expected in expected_output:
|
||||
self.assertEqual(expected, self.evaluate(get_next()))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
|
@ -16,11 +16,17 @@ py_library(
|
|||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "counter_op",
|
||||
srcs = ["counter_op.py"],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "dataset_ops",
|
||||
srcs = ["dataset_ops.py"],
|
||||
srcs_version = "PY3",
|
||||
deps = [
|
||||
":counter_op",
|
||||
":from_tensor_slices_op",
|
||||
":iterator_ops",
|
||||
":load_ops",
|
||||
|
|
|
|||
26
tensorflow/python/data/ops/counter_op.py
Normal file
26
tensorflow/python/data/ops/counter_op.py
Normal file
|
|
@ -0,0 +1,26 @@
|
|||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""The implementation of `tf.data.Dataset.counter`."""
|
||||
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.framework import ops
|
||||
|
||||
|
||||
def counter(start, step, dtype, name=None):
|
||||
with ops.name_scope("counter"):
|
||||
start = ops.convert_to_tensor(start, dtype=dtype, name="start")
|
||||
step = ops.convert_to_tensor(step, dtype=dtype, name="step")
|
||||
return (dataset_ops.Dataset.from_tensors(0, name=name).repeat(None).scan(
|
||||
start, lambda state, _: (state + step, state)))
|
||||
|
|
@ -1288,6 +1288,46 @@ class DatasetV2(
|
|||
"""
|
||||
return ConcatenateDataset(self, dataset, name=name)
|
||||
|
||||
@staticmethod
|
||||
def counter(start=0, step=1, dtype=dtypes.int64, name=None):
|
||||
"""Creates a `Dataset` that counts from `start` in steps of size `step`.
|
||||
|
||||
Unlike `tf.data.Dataset.range`, which stops at some ending number,
|
||||
`tf.data.Dataset.counter` produces elements indefinitely.
|
||||
|
||||
>>> dataset = tf.data.experimental.Counter().take(5)
|
||||
>>> list(dataset.as_numpy_iterator())
|
||||
[0, 1, 2, 3, 4]
|
||||
>>> dataset.element_spec
|
||||
TensorSpec(shape=(), dtype=tf.int64, name=None)
|
||||
>>> dataset = tf.data.experimental.Counter(dtype=tf.int32)
|
||||
>>> dataset.element_spec
|
||||
TensorSpec(shape=(), dtype=tf.int32, name=None)
|
||||
>>> dataset = tf.data.experimental.Counter(start=2).take(5)
|
||||
>>> list(dataset.as_numpy_iterator())
|
||||
[2, 3, 4, 5, 6]
|
||||
>>> dataset = tf.data.experimental.Counter(start=2, step=5).take(5)
|
||||
>>> list(dataset.as_numpy_iterator())
|
||||
[2, 7, 12, 17, 22]
|
||||
>>> dataset = tf.data.experimental.Counter(start=10, step=-1).take(5)
|
||||
>>> list(dataset.as_numpy_iterator())
|
||||
[10, 9, 8, 7, 6]
|
||||
|
||||
Args:
|
||||
start: (Optional.) The starting value for the counter. Defaults to 0.
|
||||
step: (Optional.) The step size for the counter. Defaults to 1.
|
||||
dtype: (Optional.) The data type for counter elements. Defaults to
|
||||
`tf.int64`.
|
||||
name: (Optional.) A name for the tf.data operation.
|
||||
|
||||
Returns:
|
||||
A `Dataset` of scalar `dtype` elements.
|
||||
"""
|
||||
# Loaded lazily due to a circular dependency (dataset_ops -> counter_op
|
||||
# -> dataset_ops).
|
||||
from tensorflow.python.data.ops import counter_op # pylint: disable=g-import-not-at-top
|
||||
return counter_op.counter(start, step, dtype, name=name)
|
||||
|
||||
def prefetch(self, buffer_size, name=None):
|
||||
"""Creates a `Dataset` that prefetches elements from this dataset.
|
||||
|
||||
|
|
|
|||
|
|
@ -55,6 +55,10 @@ tf_class {
|
|||
name: "concatenate"
|
||||
argspec: "args=[\'self\', \'dataset\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "counter"
|
||||
argspec: "args=[\'start\', \'step\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'1\', \"<dtype: \'int64\'>\", \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "enumerate"
|
||||
argspec: "args=[\'self\', \'start\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'None\'], "
|
||||
|
|
|
|||
|
|
@ -57,6 +57,10 @@ tf_class {
|
|||
name: "concatenate"
|
||||
argspec: "args=[\'self\', \'dataset\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "counter"
|
||||
argspec: "args=[\'start\', \'step\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'1\', \"<dtype: \'int64\'>\", \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "enumerate"
|
||||
argspec: "args=[\'self\', \'start\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'None\'], "
|
||||
|
|
|
|||
|
|
@ -57,6 +57,10 @@ tf_class {
|
|||
name: "concatenate"
|
||||
argspec: "args=[\'self\', \'dataset\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "counter"
|
||||
argspec: "args=[\'start\', \'step\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'1\', \"<dtype: \'int64\'>\", \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "enumerate"
|
||||
argspec: "args=[\'self\', \'start\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'None\'], "
|
||||
|
|
|
|||
|
|
@ -57,6 +57,10 @@ tf_class {
|
|||
name: "concatenate"
|
||||
argspec: "args=[\'self\', \'dataset\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "counter"
|
||||
argspec: "args=[\'start\', \'step\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'1\', \"<dtype: \'int64\'>\", \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "enumerate"
|
||||
argspec: "args=[\'self\', \'start\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'None\'], "
|
||||
|
|
|
|||
|
|
@ -57,6 +57,10 @@ tf_class {
|
|||
name: "concatenate"
|
||||
argspec: "args=[\'self\', \'dataset\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "counter"
|
||||
argspec: "args=[\'start\', \'step\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'1\', \"<dtype: \'int64\'>\", \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "enumerate"
|
||||
argspec: "args=[\'self\', \'start\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'None\'], "
|
||||
|
|
|
|||
|
|
@ -57,6 +57,10 @@ tf_class {
|
|||
name: "concatenate"
|
||||
argspec: "args=[\'self\', \'dataset\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "counter"
|
||||
argspec: "args=[\'start\', \'step\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'1\', \"<dtype: \'int64\'>\", \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "enumerate"
|
||||
argspec: "args=[\'self\', \'start\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'None\'], "
|
||||
|
|
|
|||
|
|
@ -57,6 +57,10 @@ tf_class {
|
|||
name: "concatenate"
|
||||
argspec: "args=[\'self\', \'dataset\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "counter"
|
||||
argspec: "args=[\'start\', \'step\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'1\', \"<dtype: \'int64\'>\", \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "enumerate"
|
||||
argspec: "args=[\'self\', \'start\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'None\'], "
|
||||
|
|
|
|||
|
|
@ -42,6 +42,10 @@ tf_class {
|
|||
name: "concatenate"
|
||||
argspec: "args=[\'self\', \'dataset\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "counter"
|
||||
argspec: "args=[\'start\', \'step\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'1\', \"<dtype: \'int64\'>\", \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "enumerate"
|
||||
argspec: "args=[\'self\', \'start\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'None\'], "
|
||||
|
|
|
|||
|
|
@ -44,6 +44,10 @@ tf_class {
|
|||
name: "concatenate"
|
||||
argspec: "args=[\'self\', \'dataset\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "counter"
|
||||
argspec: "args=[\'start\', \'step\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'1\', \"<dtype: \'int64\'>\", \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "enumerate"
|
||||
argspec: "args=[\'self\', \'start\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'None\'], "
|
||||
|
|
|
|||
|
|
@ -43,6 +43,10 @@ tf_class {
|
|||
name: "concatenate"
|
||||
argspec: "args=[\'self\', \'dataset\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "counter"
|
||||
argspec: "args=[\'start\', \'step\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'1\', \"<dtype: \'int64\'>\", \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "enumerate"
|
||||
argspec: "args=[\'self\', \'start\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'None\'], "
|
||||
|
|
|
|||
|
|
@ -44,6 +44,10 @@ tf_class {
|
|||
name: "concatenate"
|
||||
argspec: "args=[\'self\', \'dataset\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "counter"
|
||||
argspec: "args=[\'start\', \'step\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'1\', \"<dtype: \'int64\'>\", \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "enumerate"
|
||||
argspec: "args=[\'self\', \'start\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'None\'], "
|
||||
|
|
|
|||
|
|
@ -44,6 +44,10 @@ tf_class {
|
|||
name: "concatenate"
|
||||
argspec: "args=[\'self\', \'dataset\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "counter"
|
||||
argspec: "args=[\'start\', \'step\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'1\', \"<dtype: \'int64\'>\", \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "enumerate"
|
||||
argspec: "args=[\'self\', \'start\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'None\'], "
|
||||
|
|
|
|||
|
|
@ -45,6 +45,10 @@ tf_class {
|
|||
name: "concatenate"
|
||||
argspec: "args=[\'self\', \'dataset\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "counter"
|
||||
argspec: "args=[\'start\', \'step\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'1\', \"<dtype: \'int64\'>\", \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "enumerate"
|
||||
argspec: "args=[\'self\', \'start\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'None\'], "
|
||||
|
|
|
|||
|
|
@ -44,6 +44,10 @@ tf_class {
|
|||
name: "concatenate"
|
||||
argspec: "args=[\'self\', \'dataset\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "counter"
|
||||
argspec: "args=[\'start\', \'step\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'1\', \"<dtype: \'int64\'>\", \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "enumerate"
|
||||
argspec: "args=[\'self\', \'start\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'None\'], "
|
||||
|
|
|
|||
|
|
@ -45,6 +45,10 @@ tf_class {
|
|||
name: "concatenate"
|
||||
argspec: "args=[\'self\', \'dataset\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "counter"
|
||||
argspec: "args=[\'start\', \'step\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'1\', \"<dtype: \'int64\'>\", \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "enumerate"
|
||||
argspec: "args=[\'self\', \'start\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'None\'], "
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user