Export TrackableView to the public API.

PiperOrigin-RevId: 463620038
This commit is contained in:
Yaning Liang 2022-07-27 10:24:18 -07:00 committed by TensorFlower Gardener
parent 0330ae4096
commit f1656e7642
6 changed files with 63 additions and 5 deletions

View File

@ -132,6 +132,11 @@
difference range from 8 to 100 times depending on the size of k. When difference range from 8 to 100 times depending on the size of k. When
running on CPU and GPU, a non-optimized XLA kernel is used. running on CPU and GPU, a non-optimized XLA kernel is used.
* `tf.train`:
* Added `tf.train.TrackableView` which allows users to inspect the
TensorFlow Trackable object (e.g. `tf.Module`, Keras Layers and models).
* `tf.vectorized_map`: * `tf.vectorized_map`:
* Added an optional parameter: `warn`. This parameter controls whether or * Added an optional parameter: `warn`. This parameter controls whether or

View File

@ -195,10 +195,10 @@ py_library(
srcs_version = "PY3", srcs_version = "PY3",
tags = ["no_pip"], tags = ["no_pip"],
deps = [ deps = [
":util",
"//tensorflow/python:util", "//tensorflow/python:util",
"//tensorflow/python/trackable:base", "//tensorflow/python/trackable:base",
"//tensorflow/python/trackable:converter", "//tensorflow/python/trackable:converter",
"//tensorflow/python/util:tf_export",
], ],
) )

View File

@ -19,10 +19,38 @@ import weakref
from tensorflow.python.trackable import base from tensorflow.python.trackable import base
from tensorflow.python.trackable import converter from tensorflow.python.trackable import converter
from tensorflow.python.util import object_identity from tensorflow.python.util import object_identity
from tensorflow.python.util.tf_export import tf_export
@tf_export("train.TrackableView", v1=[])
class TrackableView(object): class TrackableView(object):
"""Gathers and serializes a trackable view.""" """Gathers and serializes a trackable view.
Example usage:
>>> class SimpleModule(tf.Module):
... def __init__(self, name=None):
... super().__init__(name=name)
... self.a_var = tf.Variable(5.0)
... self.b_var = tf.Variable(4.0)
... self.vars = [tf.Variable(1.0), tf.Variable(2.0)]
>>> root = SimpleModule(name="root")
>>> root.leaf = SimpleModule(name="leaf")
>>> trackable_view = tf.train.TrackableView(root)
Pass root to tf.train.TrackableView.children() to get the dictionary of all
children directly linked to root by name.
>>> trackable_view_children = trackable_view.children(root)
>>> for item in trackable_view_children.items():
... print(item)
('a_var', <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=5.0>)
('b_var', <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=4.0>)
('vars', ListWrapper([<tf.Variable 'Variable:0' shape=() dtype=float32,
numpy=1.0>, <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=2.0>]))
('leaf', ...)
"""
def __init__(self, root): def __init__(self, root):
"""Configure the trackable view. """Configure the trackable view.
@ -38,7 +66,8 @@ class TrackableView(object):
self._root_ref = (root if isinstance(root, weakref.ref) self._root_ref = (root if isinstance(root, weakref.ref)
else weakref.ref(root)) else weakref.ref(root))
def children(self, obj, save_type=base.SaveType.CHECKPOINT, **kwargs): @classmethod
def children(cls, obj, save_type=base.SaveType.CHECKPOINT, **kwargs):
"""Returns all child trackables attached to obj. """Returns all child trackables attached to obj.
Args: Args:

View File

@ -26,8 +26,7 @@ class TrackableViewTest(test.TestCase):
leaf = base.Trackable() leaf = base.Trackable()
root._track_trackable(leaf, name="leaf") root._track_trackable(leaf, name="leaf")
(current_name, (current_name,
current_dependency), = trackable_view.TrackableView(object).children( current_dependency), = trackable_view.TrackableView.children(root).items()
root, object).items()
self.assertIs(leaf, current_dependency) self.assertIs(leaf, current_dependency)
self.assertEqual("leaf", current_name) self.assertEqual("leaf", current_name)

View File

@ -0,0 +1,21 @@
path: "tensorflow.train.TrackableView"
tf_class {
is_instance: "<class \'tensorflow.python.checkpoint.trackable_view.TrackableView\'>"
is_instance: "<type \'object\'>"
member {
name: "root"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'root\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "children"
argspec: "args=[\'cls\', \'obj\', \'save_type\'], varargs=None, keywords=kwargs, defaults=[\'SaveType.CHECKPOINT\'], "
}
member_method {
name: "descendants"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
}

View File

@ -72,6 +72,10 @@ tf_module {
name: "ServerDef" name: "ServerDef"
mtype: "<class \'google.protobuf.pyext.cpp_message.GeneratedProtocolMessageType\'>" mtype: "<class \'google.protobuf.pyext.cpp_message.GeneratedProtocolMessageType\'>"
} }
member {
name: "TrackableView"
mtype: "<type \'type\'>"
}
member { member {
name: "experimental" name: "experimental"
mtype: "<type \'module\'>" mtype: "<type \'module\'>"