mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-07 00:20:20 +01:00
Summary-writing support for Evaluators.
PiperOrigin-RevId: 173971621
This commit is contained in:
parent
72be26dc82
commit
73fdaf0b56
|
|
@ -164,6 +164,7 @@ py_test(
|
|||
deps = [
|
||||
":metrics",
|
||||
"//tensorflow/contrib/summary:summary_ops",
|
||||
"//tensorflow/contrib/summary:summary_test_util",
|
||||
"//tensorflow/core:protos_all_py",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:dtypes",
|
||||
|
|
@ -185,6 +186,7 @@ py_library(
|
|||
deps = [
|
||||
":datasets",
|
||||
":metrics",
|
||||
"//tensorflow/contrib/summary:summary_ops",
|
||||
"//tensorflow/python:control_flow_ops",
|
||||
"//tensorflow/python:errors",
|
||||
"//tensorflow/python:framework_ops",
|
||||
|
|
@ -201,6 +203,10 @@ py_test(
|
|||
deps = [
|
||||
":evaluator",
|
||||
":metrics",
|
||||
"//tensorflow/contrib/summary:summary_test_util",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:training",
|
||||
"//tensorflow/python:variables",
|
||||
"//tensorflow/python/data/ops:dataset_ops",
|
||||
"//tensorflow/python/eager:context",
|
||||
"//tensorflow/python/eager:test",
|
||||
|
|
|
|||
|
|
@ -22,6 +22,7 @@ import six
|
|||
|
||||
from tensorflow.contrib.eager.python import datasets
|
||||
from tensorflow.contrib.eager.python import metrics
|
||||
from tensorflow.contrib.summary import summary_ops
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import function
|
||||
from tensorflow.python.framework import errors_impl
|
||||
|
|
@ -36,7 +37,7 @@ class Evaluator(object):
|
|||
evaluator = my_model.evaluator() # or MyEvaluator(my_model)
|
||||
for example_batch in ...:
|
||||
evaluator(example_batch)
|
||||
results = evaluator.all_metric_results(optional_summary_writer)
|
||||
results = evaluator.all_metric_results(optional_summary_logdir)
|
||||
|
||||
Or, if you are getting your examples from a tf.data.Dataset, you can use
|
||||
the evaluate_on_dataset() method.
|
||||
|
|
@ -94,8 +95,31 @@ class Evaluator(object):
|
|||
"eager execution is enabled.")
|
||||
return control_flow_ops.group([m.init_variables() for _, m in self.metrics])
|
||||
|
||||
def all_metric_results(self): # TODO(josh11b): Add optional summary_writer.
|
||||
"""Returns dict mapping metric name -> value."""
|
||||
def all_metric_results(self, summary_logdir=None):
|
||||
"""Computes results for all contained metrics.
|
||||
|
||||
Args:
|
||||
summary_logdir: An optional string. If specified, metric results
|
||||
will be written as summaries to this directory.
|
||||
|
||||
Returns:
|
||||
A `dict` mapping string names to tensors.
|
||||
"""
|
||||
if summary_logdir is None:
|
||||
with summary_ops.never_record_summaries():
|
||||
return self._all_metric_results()
|
||||
else:
|
||||
def f():
|
||||
with summary_ops.create_summary_file_writer(
|
||||
summary_logdir).as_default(), summary_ops.always_record_summaries():
|
||||
return self._all_metric_results()
|
||||
if context.in_eager_mode():
|
||||
return f()
|
||||
else:
|
||||
return function.defun(f)()
|
||||
|
||||
def _all_metric_results(self):
|
||||
"""Implementation of `all_metric_results` in the summary context."""
|
||||
results = {}
|
||||
for name, metric in six.iteritems(self._metrics):
|
||||
results[name] = metric.result()
|
||||
|
|
@ -110,7 +134,9 @@ class Evaluator(object):
|
|||
Args:
|
||||
dataset: Dataset object with the input data to evaluate on.
|
||||
*args:
|
||||
**kwargs: Optional additional arguments to __call__().
|
||||
**kwargs: Optional additional arguments to __call__(), except
|
||||
`summary_logdir`: if specified, metrics will be written as summaries
|
||||
to this directory.
|
||||
|
||||
Returns:
|
||||
@compatibility(eager)
|
||||
|
|
@ -131,17 +157,17 @@ class Evaluator(object):
|
|||
```
|
||||
@end_compatibility
|
||||
"""
|
||||
# TODO(josh11b): Add optional summary_writer.
|
||||
summary_logdir = kwargs.pop("summary_logdir", None)
|
||||
if context.in_graph_mode():
|
||||
call_op = self.__call__(dataset.make_one_shot_iterator().get_next(),
|
||||
*args, **kwargs)
|
||||
init_op = self.init_variables()
|
||||
results_op = self.all_metric_results()
|
||||
results_op = self.all_metric_results(summary_logdir)
|
||||
return (init_op, call_op, results_op)
|
||||
# Eager case
|
||||
for example in datasets.Iterator(dataset):
|
||||
self.__call__(example, *args, **kwargs)
|
||||
return self.all_metric_results()
|
||||
return self.all_metric_results(summary_logdir)
|
||||
|
||||
@staticmethod
|
||||
def run_evaluation(init_op, call_op, results_op, sess=None):
|
||||
|
|
|
|||
|
|
@ -18,11 +18,18 @@ from __future__ import absolute_import
|
|||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import tempfile
|
||||
|
||||
from tensorflow.contrib.eager.python import evaluator
|
||||
|
||||
from tensorflow.contrib.eager.python import metrics
|
||||
from tensorflow.contrib.summary import summary_test_util
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import test
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.training import training_util
|
||||
|
||||
|
||||
class IdentityModel(object):
|
||||
|
|
@ -71,6 +78,19 @@ class EvaluatorTest(test.TestCase):
|
|||
self.assertEqual(set(["mean"]), set(results.keys()))
|
||||
self.assertEqual(6.0, results["mean"].numpy())
|
||||
|
||||
def testWriteSummaries(self):
|
||||
e = SimpleEvaluator(IdentityModel())
|
||||
e(3.0)
|
||||
e([5.0, 7.0, 9.0])
|
||||
training_util.get_or_create_global_step()
|
||||
logdir = tempfile.mkdtemp()
|
||||
|
||||
e.all_metric_results(logdir)
|
||||
|
||||
events = summary_test_util.events_from_file(logdir)
|
||||
self.assertEqual(len(events), 2)
|
||||
self.assertEqual(events[1].summary.value[0].simple_value, 6.0)
|
||||
|
||||
def testComposition(self):
|
||||
e = DelegatingEvaluator(PrefixLModel())
|
||||
e({"inner": 2.0, "outer": 100.0})
|
||||
|
|
@ -97,7 +117,7 @@ class EvaluatorTest(test.TestCase):
|
|||
self.assertEqual(6.0, results["mean"].numpy())
|
||||
|
||||
def testDatasetGraph(self):
|
||||
with context.graph_mode(), self.test_session():
|
||||
with context.graph_mode(), ops.Graph().as_default(), self.test_session():
|
||||
e = SimpleEvaluator(IdentityModel())
|
||||
ds = dataset_ops.Dataset.from_tensor_slices([3.0, 5.0, 7.0, 9.0])
|
||||
init_op, call_op, results_op = e.evaluate_on_dataset(ds)
|
||||
|
|
@ -105,6 +125,21 @@ class EvaluatorTest(test.TestCase):
|
|||
self.assertEqual(set(["mean"]), set(results.keys()))
|
||||
self.assertEqual(6.0, results["mean"])
|
||||
|
||||
def testWriteSummariesGraph(self):
|
||||
with context.graph_mode(), ops.Graph().as_default(), self.test_session():
|
||||
e = SimpleEvaluator(IdentityModel())
|
||||
ds = dataset_ops.Dataset.from_tensor_slices([3.0, 5.0, 7.0, 9.0])
|
||||
training_util.get_or_create_global_step()
|
||||
logdir = tempfile.mkdtemp()
|
||||
init_op, call_op, results_op = e.evaluate_on_dataset(
|
||||
ds, summary_logdir=logdir)
|
||||
variables.global_variables_initializer().run()
|
||||
e.run_evaluation(init_op, call_op, results_op)
|
||||
|
||||
events = summary_test_util.events_from_file(logdir)
|
||||
self.assertEqual(len(events), 2)
|
||||
self.assertEqual(events[1].summary.value[0].simple_value, 6.0)
|
||||
|
||||
def testModelProperty(self):
|
||||
m = IdentityModel()
|
||||
e = SimpleEvaluator(m)
|
||||
|
|
|
|||
|
|
@ -18,18 +18,15 @@ from __future__ import absolute_import
|
|||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
from tensorflow.contrib.eager.python import metrics
|
||||
from tensorflow.contrib.summary import summary_ops
|
||||
from tensorflow.core.util import event_pb2
|
||||
from tensorflow.contrib.summary import summary_test_util
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import test
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.lib.io import tf_record
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.platform import gfile
|
||||
from tensorflow.python.training import training_util
|
||||
|
||||
|
||||
|
|
@ -63,15 +60,9 @@ class MetricsTest(test.TestCase):
|
|||
name="t0").as_default(), summary_ops.always_record_summaries():
|
||||
m.result() # As a side-effect will write summaries.
|
||||
|
||||
self.assertTrue(gfile.Exists(logdir))
|
||||
files = gfile.ListDirectory(logdir)
|
||||
self.assertEqual(len(files), 1)
|
||||
records = list(
|
||||
tf_record.tf_record_iterator(os.path.join(logdir, files[0])))
|
||||
self.assertEqual(len(records), 2)
|
||||
event = event_pb2.Event()
|
||||
event.ParseFromString(records[1])
|
||||
self.assertEqual(event.summary.value[0].simple_value, 37.0)
|
||||
events = summary_test_util.events_from_file(logdir)
|
||||
self.assertEqual(len(events), 2)
|
||||
self.assertEqual(events[1].summary.value[0].simple_value, 37.0)
|
||||
|
||||
def testWeightedMean(self):
|
||||
m = metrics.Mean()
|
||||
|
|
|
|||
|
|
@ -29,6 +29,7 @@ from tensorflow.contrib.summary.summary_ops import all_summary_ops
|
|||
from tensorflow.contrib.summary.summary_ops import always_record_summaries
|
||||
from tensorflow.contrib.summary.summary_ops import audio
|
||||
from tensorflow.contrib.summary.summary_ops import create_summary_file_writer
|
||||
from tensorflow.contrib.summary.summary_ops import eval_dir
|
||||
from tensorflow.contrib.summary.summary_ops import generic
|
||||
from tensorflow.contrib.summary.summary_ops import histogram
|
||||
from tensorflow.contrib.summary.summary_ops import image
|
||||
|
|
|
|||
|
|
@ -19,6 +19,8 @@ from __future__ import absolute_import
|
|||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
|
||||
from tensorflow.contrib.summary import gen_summary_ops
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import constant_op
|
||||
|
|
@ -272,3 +274,8 @@ def audio(name, tensor, sample_rate, max_outputs, family=None):
|
|||
name=scope)
|
||||
|
||||
return summary_writer_function(name, tensor, function, family=family)
|
||||
|
||||
|
||||
def eval_dir(model_dir, name=None):
|
||||
"""Construct a logdir for an eval summary writer."""
|
||||
return os.path.join(model_dir, "eval" if not name else "eval_" + name)
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user