mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
Export TrackableView to the public API.
PiperOrigin-RevId: 463620038
This commit is contained in:
parent
0330ae4096
commit
f1656e7642
|
|
@ -132,6 +132,11 @@
|
|||
difference range from 8 to 100 times depending on the size of k. When
|
||||
running on CPU and GPU, a non-optimized XLA kernel is used.
|
||||
|
||||
* `tf.train`:
|
||||
|
||||
* Added `tf.train.TrackableView` which allows users to inspect the
|
||||
TensorFlow Trackable object (e.g. `tf.Module`, Keras Layers and models).
|
||||
|
||||
* `tf.vectorized_map`:
|
||||
|
||||
* Added an optional parameter: `warn`. This parameter controls whether or
|
||||
|
|
|
|||
|
|
@ -195,10 +195,10 @@ py_library(
|
|||
srcs_version = "PY3",
|
||||
tags = ["no_pip"],
|
||||
deps = [
|
||||
":util",
|
||||
"//tensorflow/python:util",
|
||||
"//tensorflow/python/trackable:base",
|
||||
"//tensorflow/python/trackable:converter",
|
||||
"//tensorflow/python/util:tf_export",
|
||||
],
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -19,10 +19,38 @@ import weakref
|
|||
from tensorflow.python.trackable import base
|
||||
from tensorflow.python.trackable import converter
|
||||
from tensorflow.python.util import object_identity
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
|
||||
|
||||
@tf_export("train.TrackableView", v1=[])
|
||||
class TrackableView(object):
|
||||
"""Gathers and serializes a trackable view."""
|
||||
"""Gathers and serializes a trackable view.
|
||||
|
||||
Example usage:
|
||||
|
||||
>>> class SimpleModule(tf.Module):
|
||||
... def __init__(self, name=None):
|
||||
... super().__init__(name=name)
|
||||
... self.a_var = tf.Variable(5.0)
|
||||
... self.b_var = tf.Variable(4.0)
|
||||
... self.vars = [tf.Variable(1.0), tf.Variable(2.0)]
|
||||
|
||||
>>> root = SimpleModule(name="root")
|
||||
>>> root.leaf = SimpleModule(name="leaf")
|
||||
>>> trackable_view = tf.train.TrackableView(root)
|
||||
|
||||
Pass root to tf.train.TrackableView.children() to get the dictionary of all
|
||||
children directly linked to root by name.
|
||||
>>> trackable_view_children = trackable_view.children(root)
|
||||
>>> for item in trackable_view_children.items():
|
||||
... print(item)
|
||||
('a_var', <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=5.0>)
|
||||
('b_var', <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=4.0>)
|
||||
('vars', ListWrapper([<tf.Variable 'Variable:0' shape=() dtype=float32,
|
||||
numpy=1.0>, <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=2.0>]))
|
||||
('leaf', ...)
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, root):
|
||||
"""Configure the trackable view.
|
||||
|
|
@ -38,7 +66,8 @@ class TrackableView(object):
|
|||
self._root_ref = (root if isinstance(root, weakref.ref)
|
||||
else weakref.ref(root))
|
||||
|
||||
def children(self, obj, save_type=base.SaveType.CHECKPOINT, **kwargs):
|
||||
@classmethod
|
||||
def children(cls, obj, save_type=base.SaveType.CHECKPOINT, **kwargs):
|
||||
"""Returns all child trackables attached to obj.
|
||||
|
||||
Args:
|
||||
|
|
|
|||
|
|
@ -26,8 +26,7 @@ class TrackableViewTest(test.TestCase):
|
|||
leaf = base.Trackable()
|
||||
root._track_trackable(leaf, name="leaf")
|
||||
(current_name,
|
||||
current_dependency), = trackable_view.TrackableView(object).children(
|
||||
root, object).items()
|
||||
current_dependency), = trackable_view.TrackableView.children(root).items()
|
||||
self.assertIs(leaf, current_dependency)
|
||||
self.assertEqual("leaf", current_name)
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,21 @@
|
|||
path: "tensorflow.train.TrackableView"
|
||||
tf_class {
|
||||
is_instance: "<class \'tensorflow.python.checkpoint.trackable_view.TrackableView\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member {
|
||||
name: "root"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'root\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "children"
|
||||
argspec: "args=[\'cls\', \'obj\', \'save_type\'], varargs=None, keywords=kwargs, defaults=[\'SaveType.CHECKPOINT\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "descendants"
|
||||
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
}
|
||||
|
|
@ -72,6 +72,10 @@ tf_module {
|
|||
name: "ServerDef"
|
||||
mtype: "<class \'google.protobuf.pyext.cpp_message.GeneratedProtocolMessageType\'>"
|
||||
}
|
||||
member {
|
||||
name: "TrackableView"
|
||||
mtype: "<type \'type\'>"
|
||||
}
|
||||
member {
|
||||
name: "experimental"
|
||||
mtype: "<type \'module\'>"
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user