Switching contrib.summaries API to be context-manager-centric
PiperOrigin-RevId: 173129793
parent 03b02ffc9e
commit 4ec6f2b07c
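In short, the change makes always_record_summaries(), never_record_summaries() and record_summaries_every_n_global_steps() context managers, and has create_summary_file_writer() return a SummaryWriter whose as_default() scopes the writer, instead of mutating ambient state as a side effect. A minimal before/after sketch, assuming eager execution and that the module is importable as tensorflow.contrib.summary.summary_ops (the import path is not shown in this diff):

# Before this commit, the calls mutated global state as a side effect:
#   summary_ops.create_summary_file_writer(logdir, max_queue=0, name='t0')
#   summary_ops.always_record_summaries()
#   summary_ops.scalar('scalar', 2.0)
#
# After this commit, the same thing is expressed with scoped context managers.
import tempfile

from tensorflow.contrib.summary import summary_ops  # assumed import path
from tensorflow.python.training import training_util

training_util.get_or_create_global_step()
logdir = tempfile.mkdtemp()

with summary_ops.create_summary_file_writer(
    logdir, max_queue=0,
    name='t0').as_default(), summary_ops.always_record_summaries():
  summary_ops.scalar('scalar', 2.0)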
@@ -27,6 +27,7 @@ from tensorflow.python.framework import ops
 from tensorflow.python.layers import utils
 from tensorflow.python.ops import summary_op_util
 from tensorflow.python.training import training_util
+from tensorflow.python.util import tf_contextlib
 
 # Name for a collection which is expected to have at most a single boolean
 # Tensor. If this tensor is True the summary ops will record summaries.
@@ -46,22 +47,50 @@ def should_record_summaries():
 
 
 # TODO(apassos) consider how to handle local step here.
+@tf_contextlib.contextmanager
 def record_summaries_every_n_global_steps(n):
   """Sets the should_record_summaries Tensor to true if global_step % n == 0."""
   collection_ref = ops.get_collection_ref(_SHOULD_RECORD_SUMMARIES_NAME)
+  old = collection_ref[:]
   collection_ref[:] = [training_util.get_global_step() % n == 0]
+  yield
+  collection_ref[:] = old
 
 
+@tf_contextlib.contextmanager
 def always_record_summaries():
   """Sets the should_record_summaries Tensor to always true."""
   collection_ref = ops.get_collection_ref(_SHOULD_RECORD_SUMMARIES_NAME)
+  old = collection_ref[:]
   collection_ref[:] = [True]
+  yield
+  collection_ref[:] = old
 
 
+@tf_contextlib.contextmanager
 def never_record_summaries():
   """Sets the should_record_summaries Tensor to always false."""
   collection_ref = ops.get_collection_ref(_SHOULD_RECORD_SUMMARIES_NAME)
+  old = collection_ref[:]
   collection_ref[:] = [False]
+  yield
+  collection_ref[:] = old
+
+
+class SummaryWriter(object):
+
+  def __init__(self, resource):
+    self._resource = resource
+
+  def set_as_default(self):
+    context.context().summary_writer_resource = self._resource
+
+  @tf_contextlib.contextmanager
+  def as_default(self):
+    old = context.context().summary_writer_resource
+    context.context().summary_writer_resource = self._resource
+    yield
+    context.context().summary_writer_resource = old
 
 
 def create_summary_file_writer(logdir,
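The context managers introduced above all share one shape: save the current contents of the collection (or the context's default writer), overwrite it, yield to the caller's block, and restore the saved value on exit. A self-contained plain-Python sketch of that pattern, with illustrative names that are not part of the change:

import contextlib

# Stand-in for the graph collection the real code mutates; purely illustrative.
_should_record = [False]


@contextlib.contextmanager
def always_record():
  """Force recording on inside the block, restoring the old value on exit."""
  old = _should_record[:]      # save
  _should_record[:] = [True]   # set
  yield                        # run the caller's block
  _should_record[:] = old      # restore


print(_should_record[0])    # False
with always_record():
  print(_should_record[0])  # True
print(_should_record[0])    # False

As in the diff, the saved value is restored on normal exit; wrapping the yield in try/finally would also restore it when the block raises.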
@@ -77,9 +106,11 @@ def create_summary_file_writer(logdir,
   if filename_suffix is None:
     filename_suffix = constant_op.constant("")
   resource = gen_summary_ops.summary_writer(shared_name=name)
+  # TODO(apassos) ensure the initialization op runs when in graph mode; consider
+  # calling session.run here.
   gen_summary_ops.create_summary_file_writer(resource, logdir, max_queue,
                                              flush_secs, filename_suffix)
-  context.context().summary_writer_resource = resource
+  return SummaryWriter(resource)
 
 
 def _nothing():
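With this hunk, create_summary_file_writer() no longer installs the writer into the eager context as a side effect; it returns a SummaryWriter, and the caller picks between the permanent set_as_default() and the scoped as_default(). A small sketch of the two options, assuming eager execution and the same import path as the tests below; the name 'w0' and the logdir are illustrative:

import tempfile

from tensorflow.contrib.summary import summary_ops  # assumed import path

logdir = tempfile.mkdtemp()
writer = summary_ops.create_summary_file_writer(logdir, max_queue=0, name='w0')

# Option 1: install as the ambient default writer from here on.
writer.set_as_default()

# Option 2: install only for a scoped block; the previous default writer is
# restored when the block exits.
with writer.as_default():
  pass  # summary ops emitted here would target this writer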
@@ -41,60 +41,65 @@ class TargetTest(test_util.TensorFlowTestCase):
 
   def testShouldRecordSummary(self):
     self.assertFalse(summary_ops.should_record_summaries())
-    summary_ops.always_record_summaries()
-    self.assertTrue(summary_ops.should_record_summaries())
+    with summary_ops.always_record_summaries():
+      self.assertTrue(summary_ops.should_record_summaries())
 
   def testSummaryOps(self):
     training_util.get_or_create_global_step()
     logdir = tempfile.mkdtemp()
-    summary_ops.create_summary_file_writer(logdir, max_queue=0, name='t0')
-    summary_ops.always_record_summaries()
-    summary_ops.generic('tensor', 1, '')
-    summary_ops.scalar('scalar', 2.0)
-    summary_ops.histogram('histogram', [1.0])
-    summary_ops.image('image', [[[[1.0]]]])
-    summary_ops.audio('audio', [[1.0]], 1.0, 1)
-    # The working condition of the ops is tested in the C++ test so we just
-    # test here that we're calling them correctly.
-    self.assertTrue(gfile.Exists(logdir))
+    with summary_ops.create_summary_file_writer(
+        logdir, max_queue=0,
+        name='t0').as_default(), summary_ops.always_record_summaries():
+      summary_ops.generic('tensor', 1, '')
+      summary_ops.scalar('scalar', 2.0)
+      summary_ops.histogram('histogram', [1.0])
+      summary_ops.image('image', [[[[1.0]]]])
+      summary_ops.audio('audio', [[1.0]], 1.0, 1)
+      # The working condition of the ops is tested in the C++ test so we just
+      # test here that we're calling them correctly.
+      self.assertTrue(gfile.Exists(logdir))
 
   def testDefunSummarys(self):
     training_util.get_or_create_global_step()
     logdir = tempfile.mkdtemp()
-    summary_ops.create_summary_file_writer(logdir, max_queue=0, name='t1')
-    summary_ops.always_record_summaries()
+    with summary_ops.create_summary_file_writer(
+        logdir, max_queue=0,
+        name='t1').as_default(), summary_ops.always_record_summaries():
 
-    @function.defun
-    def write():
-      summary_ops.scalar('scalar', 2.0)
+      @function.defun
+      def write():
+        summary_ops.scalar('scalar', 2.0)
 
-    write()
+      write()
 
     self.assertTrue(gfile.Exists(logdir))
     files = gfile.ListDirectory(logdir)
     self.assertEqual(len(files), 1)
-    records = list(tf_record.tf_record_iterator(os.path.join(logdir, files[0])))
+    records = list(
+        tf_record.tf_record_iterator(os.path.join(logdir, files[0])))
     self.assertEqual(len(records), 2)
     event = event_pb2.Event()
     event.ParseFromString(records[1])
     self.assertEqual(event.summary.value[0].simple_value, 2.0)
 
   def testSummaryName(self):
     training_util.get_or_create_global_step()
     logdir = tempfile.mkdtemp()
-    summary_ops.create_summary_file_writer(logdir, max_queue=0, name='t2')
-    summary_ops.always_record_summaries()
+    with summary_ops.create_summary_file_writer(
+        logdir, max_queue=0,
+        name='t2').as_default(), summary_ops.always_record_summaries():
 
-    summary_ops.scalar('scalar', 2.0)
+      summary_ops.scalar('scalar', 2.0)
 
     self.assertTrue(gfile.Exists(logdir))
     files = gfile.ListDirectory(logdir)
     self.assertEqual(len(files), 1)
-    records = list(tf_record.tf_record_iterator(os.path.join(logdir, files[0])))
+    records = list(
+        tf_record.tf_record_iterator(os.path.join(logdir, files[0])))
     self.assertEqual(len(records), 2)
     event = event_pb2.Event()
     event.ParseFromString(records[1])
     self.assertEqual(event.summary.value[0].tag, 'scalar')
 
 
 if __name__ == '__main__':
@@ -58,6 +58,7 @@ class _EagerContext(threading.local):
     self.mode = _default_mode
     self.scope_name = ""
     self.recording_summaries = False
+    self.summary_writer_resource = None
     self.scalar_cache = {}
 
@@ -86,7 +87,6 @@ class Context(object):
     self._eager_context = _EagerContext()
     self._context_handle = None
     self._context_devices = None
-    self._summary_writer_resource = None
     self._post_execution_callbacks = []
     self._config = config
     self._seed = None
@@ -213,12 +213,12 @@ class Context(object):
   @property
   def summary_writer_resource(self):
     """Returns summary writer resource."""
-    return self._summary_writer_resource
+    return self._eager_context.summary_writer_resource
 
   @summary_writer_resource.setter
   def summary_writer_resource(self, resource):
     """Sets summary writer resource."""
-    self._summary_writer_resource = resource
+    self._eager_context.summary_writer_resource = resource
 
   @property
   def device_name(self):
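The three eager-context hunks move summary_writer_resource off the Context object and onto _EagerContext, which subclasses threading.local, so each thread sees its own default writer. A plain-Python sketch of that behavior, with illustrative names:

import threading


class _PerThreadState(threading.local):
  """Illustrative stand-in for _EagerContext: attributes are per-thread."""

  def __init__(self):
    self.summary_writer_resource = None


state = _PerThreadState()


def worker():
  # Visible only in this thread; __init__ runs again here, starting from None.
  state.summary_writer_resource = 'writer-for-worker-thread'
  print('worker sees:', state.summary_writer_resource)


t = threading.Thread(target=worker)
t.start()
t.join()

# The main thread still sees its own value, untouched by the worker.
print('main sees:', state.summary_writer_resource)  # None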