mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
Introduce TraceType for Iterator
PiperOrigin-RevId: 403194058 Change-Id: Ieda3c876f2e18bc19e8cecc23e1725ec1c146d48
This commit is contained in:
parent
23260248e9
commit
483ac0b5bb
|
|
@ -34,6 +34,7 @@ from tensorflow.python.framework import type_spec
|
|||
from tensorflow.python.ops import gen_dataset_ops
|
||||
from tensorflow.python.training.saver import BaseSaverBuilder
|
||||
from tensorflow.python.training.tracking import base as trackable
|
||||
from tensorflow.python.types import trace
|
||||
from tensorflow.python.util import _pywrap_utils
|
||||
from tensorflow.python.util import deprecation
|
||||
from tensorflow.python.util import lazy_loader
|
||||
|
|
@ -672,7 +673,30 @@ class IteratorBase(collections_abc.Iterator, trackable.Trackable,
|
|||
raise NotImplementedError("Iterator.get_next_as_optional()")
|
||||
|
||||
|
||||
class OwnedIterator(IteratorBase):
|
||||
# TODO(b/202447704): Merge into IteratorSpec.
|
||||
class IteratorType(trace.TraceType):
|
||||
"""Represents Iterators (and specs) for function tracing purposes."""
|
||||
|
||||
def __init__(self, spec, local_id):
|
||||
self._components = (spec, local_id)
|
||||
|
||||
def is_subtype_of(self, other):
|
||||
# TODO(b/202429845): Implement for subtyping.
|
||||
return self == other
|
||||
|
||||
def most_specific_common_supertype(self, others):
|
||||
# TODO(b/202430155) Implement for shape relaxation.
|
||||
return None
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return hash(self._components)
|
||||
|
||||
def __eq__(self, other) -> bool:
|
||||
return isinstance(
|
||||
other, IteratorType) and self._components == other._components
|
||||
|
||||
|
||||
class OwnedIterator(IteratorBase, trace.SupportsTracingType):
|
||||
"""An iterator producing tf.Tensor objects from a tf.data.Dataset.
|
||||
|
||||
The iterator resource created through `OwnedIterator` is owned by the Python
|
||||
|
|
@ -876,9 +900,14 @@ class OwnedIterator(IteratorBase):
|
|||
|
||||
return {"ITERATOR": _saveable_factory}
|
||||
|
||||
def __tf_tracing_type__(self, tracing_context):
|
||||
return IteratorType(
|
||||
self._type_spec,
|
||||
tracing_context.get_local_id(self._iterator_resource._id)) # pylint:disable=protected-access
|
||||
|
||||
|
||||
@tf_export("data.IteratorSpec", v1=[])
|
||||
class IteratorSpec(type_spec.TypeSpec):
|
||||
class IteratorSpec(type_spec.TypeSpec, trace.SupportsTracingType):
|
||||
"""Type specification for `tf.data.Iterator`.
|
||||
|
||||
For instance, `tf.data.IteratorSpec` can be used to define a tf.function that
|
||||
|
|
@ -931,6 +960,11 @@ class IteratorSpec(type_spec.TypeSpec):
|
|||
def from_value(value):
|
||||
return IteratorSpec(value.element_spec) # pylint: disable=protected-access
|
||||
|
||||
def __tf_tracing_type__(self, tracing_context):
|
||||
# TODO(b/202772221): Validate and enforce this assumption of uniqueness per
|
||||
# spec instance.
|
||||
return IteratorType(self, tracing_context.get_local_id(id(self)))
|
||||
|
||||
|
||||
# TODO(b/71645805): Expose trackable stateful objects from dataset.
|
||||
class _IteratorSaveable(BaseSaverBuilder.SaveableObject):
|
||||
|
|
|
|||
|
|
@ -541,10 +541,13 @@ cuda_py_test(
|
|||
":function_trace_type",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:nn_ops",
|
||||
"//tensorflow/python:resource_variable_ops",
|
||||
"//tensorflow/python:variables",
|
||||
"//tensorflow/python/data/ops:dataset_ops",
|
||||
"//tensorflow/python/framework:combinations",
|
||||
"//tensorflow/python/framework:dtypes",
|
||||
"//tensorflow/python/framework:ops",
|
||||
"//tensorflow/python/framework:tensor_spec",
|
||||
"//tensorflow/python/ops/ragged:ragged_tensor",
|
||||
"//tensorflow/python/platform:client_testlib",
|
||||
],
|
||||
)
|
||||
|
|
|
|||
|
|
@ -15,10 +15,14 @@
|
|||
"""Tests for function_trace_type."""
|
||||
|
||||
import timeit
|
||||
from absl.testing import parameterized
|
||||
|
||||
|
||||
from tensorflow.python import keras
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.eager import function
|
||||
from tensorflow.python.eager import function_trace_type
|
||||
from tensorflow.python.framework import combinations
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import tensor_spec
|
||||
from tensorflow.python.ops import array_ops
|
||||
|
|
@ -29,8 +33,24 @@ from tensorflow.python.ops.ragged import ragged_tensor
|
|||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class CacheKeyGenerationTest(test.TestCase):
|
||||
class CacheKeyGenerationTest(test.TestCase, parameterized.TestCase):
|
||||
|
||||
@combinations.generate(combinations.combine(mode=['eager']))
|
||||
def testIteratorAliasing(self):
|
||||
it1 = iter(dataset_ops.DatasetV2.from_tensor_slices([1, 2, 3]))
|
||||
it2 = iter(dataset_ops.DatasetV2.from_tensor_slices([1, 2, 3]))
|
||||
|
||||
self.assertEqual(
|
||||
function_trace_type.get_arg_spec((it1, it1), False, False, True),
|
||||
function_trace_type.get_arg_spec((it2, it2), False, False, True))
|
||||
self.assertEqual(
|
||||
function_trace_type.get_arg_spec((it1, it2), False, False, True),
|
||||
function_trace_type.get_arg_spec((it2, it1), False, False, True))
|
||||
self.assertNotEqual(
|
||||
function_trace_type.get_arg_spec((it1, it1), False, False, True),
|
||||
function_trace_type.get_arg_spec((it1, it2), False, False, True))
|
||||
|
||||
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
|
||||
def testCompositeAndSpec(self):
|
||||
composite_tensor = ragged_tensor.RaggedTensor.from_row_splits(
|
||||
values=[1, 2, 3], row_splits=[0, 2, 3])
|
||||
|
|
@ -40,6 +60,7 @@ class CacheKeyGenerationTest(test.TestCase):
|
|||
function_trace_type.get_arg_spec(composite_tensor, False, False, True),
|
||||
function_trace_type.get_arg_spec(spec, False, False, True))
|
||||
|
||||
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
|
||||
def testVariableAliasing(self):
|
||||
v1 = resource_variable_ops.ResourceVariable([1])
|
||||
v2 = resource_variable_ops.ResourceVariable([1])
|
||||
|
|
@ -59,6 +80,7 @@ class CacheKeyGenerationTest(test.TestCase):
|
|||
self.assertEqual(all_unique, all_unique_again)
|
||||
self.assertEqual(all_same, all_same_again)
|
||||
|
||||
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
|
||||
def testTensorEquality(self):
|
||||
context = function_trace_type.SignatureContext()
|
||||
tensor_a = array_ops.zeros([11, 3, 5],
|
||||
|
|
@ -75,6 +97,7 @@ class CacheKeyGenerationTest(test.TestCase):
|
|||
self.assertNotEqual(tensor_b, tensor_c)
|
||||
self.assertEqual(tensor_a, tensor_d)
|
||||
|
||||
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
|
||||
def testTensorAndSpecEquality(self):
|
||||
context = function_trace_type.SignatureContext()
|
||||
tensor = array_ops.zeros([11, 3, 5],
|
||||
|
|
@ -87,6 +110,7 @@ class CacheKeyGenerationTest(test.TestCase):
|
|||
self.assertEqual(tensor, spec)
|
||||
self.assertNotEqual(tensor, spec_with_name)
|
||||
|
||||
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
|
||||
def testTupleEquality(self):
|
||||
trace_a = function_trace_type.get_arg_spec((1, 2, 3, 4), False, False, True)
|
||||
trace_b = function_trace_type.get_arg_spec((1, 2, 2, 4), False, False, True)
|
||||
|
|
@ -98,6 +122,7 @@ class CacheKeyGenerationTest(test.TestCase):
|
|||
self.assertNotEqual(trace_b, trace_c)
|
||||
self.assertEqual(trace_a, trace_d)
|
||||
|
||||
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
|
||||
def testListEquality(self):
|
||||
trace_a = function_trace_type.get_arg_spec([1, 2, 3, 4], False, False, True)
|
||||
trace_b = function_trace_type.get_arg_spec([1, 2, 2, 4], False, False, True)
|
||||
|
|
@ -109,6 +134,7 @@ class CacheKeyGenerationTest(test.TestCase):
|
|||
self.assertNotEqual(trace_b, trace_c)
|
||||
self.assertEqual(trace_a, trace_d)
|
||||
|
||||
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
|
||||
def testDictEquality(self):
|
||||
trace_a = function_trace_type.get_arg_spec({1: 2, 3: 4}, False, False, True)
|
||||
trace_b = function_trace_type.get_arg_spec({1: 2, 3: 2}, False, False, True)
|
||||
|
|
@ -120,6 +146,7 @@ class CacheKeyGenerationTest(test.TestCase):
|
|||
self.assertNotEqual(trace_b, trace_c)
|
||||
self.assertEqual(trace_a, trace_d)
|
||||
|
||||
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
|
||||
def testComplexStruct(self):
|
||||
struct = {(1, 2, 3): {(1, 2): {12: 2}}, (3, 2, 3): (2, {2: 3})}
|
||||
trace_a = function_trace_type.get_arg_spec(struct, False, False, True)
|
||||
|
|
|
|||
|
|
@ -77,6 +77,7 @@ class SupportsTracingType(Protocol):
|
|||
classes according to the behaviour specified by their TraceType.
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def __tf_tracing_type__(self, context: TracingContext) -> TraceType:
|
||||
pass
|
||||
raise NotImplementedError(
|
||||
"Class inheriting SupportsTracingType must implement __tf_tracing_type__"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -2,7 +2,8 @@ path: "tensorflow.data.IteratorSpec"
|
|||
tf_class {
|
||||
is_instance: "<class \'tensorflow.python.data.ops.iterator_ops.IteratorSpec\'>"
|
||||
is_instance: "<class \'tensorflow.python.framework.type_spec.TypeSpec\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
is_instance: "<class \'tensorflow.python.types.trace.SupportsTracingType\'>"
|
||||
is_instance: "<class \'typing.Protocol\'>"
|
||||
member {
|
||||
name: "value_type"
|
||||
mtype: "<type \'property\'>"
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@ tf_module {
|
|||
}
|
||||
member {
|
||||
name: "IteratorSpec"
|
||||
mtype: "<type \'type\'>"
|
||||
mtype: "<class \'typing._ProtocolMeta\'>"
|
||||
}
|
||||
member {
|
||||
name: "Options"
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user