Introduce TraceType for Iterator

PiperOrigin-RevId: 403194058
Change-Id: Ieda3c876f2e18bc19e8cecc23e1725ec1c146d48
This commit is contained in:
Faizan Muhammad 2021-10-14 14:55:48 -07:00 committed by TensorFlower Gardener
parent 23260248e9
commit 483ac0b5bb
6 changed files with 74 additions and 8 deletions

View File

@ -34,6 +34,7 @@ from tensorflow.python.framework import type_spec
from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.training.saver import BaseSaverBuilder
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.types import trace
from tensorflow.python.util import _pywrap_utils
from tensorflow.python.util import deprecation
from tensorflow.python.util import lazy_loader
@ -672,7 +673,30 @@ class IteratorBase(collections_abc.Iterator, trackable.Trackable,
raise NotImplementedError("Iterator.get_next_as_optional()")
class OwnedIterator(IteratorBase):
# TODO(b/202447704): Merge into IteratorSpec.
class IteratorType(trace.TraceType):
"""Represents Iterators (and specs) for function tracing purposes."""
def __init__(self, spec, local_id):
self._components = (spec, local_id)
def is_subtype_of(self, other):
# TODO(b/202429845): Implement for subtyping.
return self == other
def most_specific_common_supertype(self, others):
# TODO(b/202430155) Implement for shape relaxation.
return None
def __hash__(self) -> int:
return hash(self._components)
def __eq__(self, other) -> bool:
return isinstance(
other, IteratorType) and self._components == other._components
class OwnedIterator(IteratorBase, trace.SupportsTracingType):
"""An iterator producing tf.Tensor objects from a tf.data.Dataset.
The iterator resource created through `OwnedIterator` is owned by the Python
@ -876,9 +900,14 @@ class OwnedIterator(IteratorBase):
return {"ITERATOR": _saveable_factory}
def __tf_tracing_type__(self, tracing_context):
return IteratorType(
self._type_spec,
tracing_context.get_local_id(self._iterator_resource._id)) # pylint:disable=protected-access
@tf_export("data.IteratorSpec", v1=[])
class IteratorSpec(type_spec.TypeSpec):
class IteratorSpec(type_spec.TypeSpec, trace.SupportsTracingType):
"""Type specification for `tf.data.Iterator`.
For instance, `tf.data.IteratorSpec` can be used to define a tf.function that
@ -931,6 +960,11 @@ class IteratorSpec(type_spec.TypeSpec):
def from_value(value):
return IteratorSpec(value.element_spec) # pylint: disable=protected-access
def __tf_tracing_type__(self, tracing_context):
# TODO(b/202772221): Validate and enforce this assumption of uniqueness per
# spec instance.
return IteratorType(self, tracing_context.get_local_id(id(self)))
# TODO(b/71645805): Expose trackable stateful objects from dataset.
class _IteratorSaveable(BaseSaverBuilder.SaveableObject):

View File

@ -541,10 +541,13 @@ cuda_py_test(
":function_trace_type",
"//tensorflow/python:array_ops",
"//tensorflow/python:nn_ops",
"//tensorflow/python:resource_variable_ops",
"//tensorflow/python:variables",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/framework:combinations",
"//tensorflow/python/framework:dtypes",
"//tensorflow/python/framework:ops",
"//tensorflow/python/framework:tensor_spec",
"//tensorflow/python/ops/ragged:ragged_tensor",
"//tensorflow/python/platform:client_testlib",
],
)

View File

@ -15,10 +15,14 @@
"""Tests for function_trace_type."""
import timeit
from absl.testing import parameterized
from tensorflow.python import keras
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.eager import function
from tensorflow.python.eager import function_trace_type
from tensorflow.python.framework import combinations
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_spec
from tensorflow.python.ops import array_ops
@ -29,8 +33,24 @@ from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.platform import test
class CacheKeyGenerationTest(test.TestCase):
class CacheKeyGenerationTest(test.TestCase, parameterized.TestCase):
@combinations.generate(combinations.combine(mode=['eager']))
def testIteratorAliasing(self):
it1 = iter(dataset_ops.DatasetV2.from_tensor_slices([1, 2, 3]))
it2 = iter(dataset_ops.DatasetV2.from_tensor_slices([1, 2, 3]))
self.assertEqual(
function_trace_type.get_arg_spec((it1, it1), False, False, True),
function_trace_type.get_arg_spec((it2, it2), False, False, True))
self.assertEqual(
function_trace_type.get_arg_spec((it1, it2), False, False, True),
function_trace_type.get_arg_spec((it2, it1), False, False, True))
self.assertNotEqual(
function_trace_type.get_arg_spec((it1, it1), False, False, True),
function_trace_type.get_arg_spec((it1, it2), False, False, True))
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
def testCompositeAndSpec(self):
composite_tensor = ragged_tensor.RaggedTensor.from_row_splits(
values=[1, 2, 3], row_splits=[0, 2, 3])
@ -40,6 +60,7 @@ class CacheKeyGenerationTest(test.TestCase):
function_trace_type.get_arg_spec(composite_tensor, False, False, True),
function_trace_type.get_arg_spec(spec, False, False, True))
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
def testVariableAliasing(self):
v1 = resource_variable_ops.ResourceVariable([1])
v2 = resource_variable_ops.ResourceVariable([1])
@ -59,6 +80,7 @@ class CacheKeyGenerationTest(test.TestCase):
self.assertEqual(all_unique, all_unique_again)
self.assertEqual(all_same, all_same_again)
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
def testTensorEquality(self):
context = function_trace_type.SignatureContext()
tensor_a = array_ops.zeros([11, 3, 5],
@ -75,6 +97,7 @@ class CacheKeyGenerationTest(test.TestCase):
self.assertNotEqual(tensor_b, tensor_c)
self.assertEqual(tensor_a, tensor_d)
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
def testTensorAndSpecEquality(self):
context = function_trace_type.SignatureContext()
tensor = array_ops.zeros([11, 3, 5],
@ -87,6 +110,7 @@ class CacheKeyGenerationTest(test.TestCase):
self.assertEqual(tensor, spec)
self.assertNotEqual(tensor, spec_with_name)
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
def testTupleEquality(self):
trace_a = function_trace_type.get_arg_spec((1, 2, 3, 4), False, False, True)
trace_b = function_trace_type.get_arg_spec((1, 2, 2, 4), False, False, True)
@ -98,6 +122,7 @@ class CacheKeyGenerationTest(test.TestCase):
self.assertNotEqual(trace_b, trace_c)
self.assertEqual(trace_a, trace_d)
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
def testListEquality(self):
trace_a = function_trace_type.get_arg_spec([1, 2, 3, 4], False, False, True)
trace_b = function_trace_type.get_arg_spec([1, 2, 2, 4], False, False, True)
@ -109,6 +134,7 @@ class CacheKeyGenerationTest(test.TestCase):
self.assertNotEqual(trace_b, trace_c)
self.assertEqual(trace_a, trace_d)
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
def testDictEquality(self):
trace_a = function_trace_type.get_arg_spec({1: 2, 3: 4}, False, False, True)
trace_b = function_trace_type.get_arg_spec({1: 2, 3: 2}, False, False, True)
@ -120,6 +146,7 @@ class CacheKeyGenerationTest(test.TestCase):
self.assertNotEqual(trace_b, trace_c)
self.assertEqual(trace_a, trace_d)
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
def testComplexStruct(self):
struct = {(1, 2, 3): {(1, 2): {12: 2}}, (3, 2, 3): (2, {2: 3})}
trace_a = function_trace_type.get_arg_spec(struct, False, False, True)

View File

@ -77,6 +77,7 @@ class SupportsTracingType(Protocol):
classes according to the behaviour specified by their TraceType.
"""
@abc.abstractmethod
def __tf_tracing_type__(self, context: TracingContext) -> TraceType:
pass
raise NotImplementedError(
"Class inheriting SupportsTracingType must implement __tf_tracing_type__"
)

View File

@ -2,7 +2,8 @@ path: "tensorflow.data.IteratorSpec"
tf_class {
is_instance: "<class \'tensorflow.python.data.ops.iterator_ops.IteratorSpec\'>"
is_instance: "<class \'tensorflow.python.framework.type_spec.TypeSpec\'>"
is_instance: "<type \'object\'>"
is_instance: "<class \'tensorflow.python.types.trace.SupportsTracingType\'>"
is_instance: "<class \'typing.Protocol\'>"
member {
name: "value_type"
mtype: "<type \'property\'>"

View File

@ -26,7 +26,7 @@ tf_module {
}
member {
name: "IteratorSpec"
mtype: "<type \'type\'>"
mtype: "<class \'typing._ProtocolMeta\'>"
}
member {
name: "Options"