Summary-writing support for Evaluators.

PiperOrigin-RevId: 173971621

parent: 72be26dc82
commit: 73fdaf0b56
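The change in a nutshell: Evaluator.all_metric_results() gains an optional summary_logdir argument, and evaluate_on_dataset() accepts a summary_logdir keyword, so metric results can also be written as TensorBoard summaries. A minimal eager-mode sketch of the new API, reconstructed from the test helpers in the diff below (track_metric is the Evaluator's existing metric-registration hook, not shown in this diff; the tfe.enable_eager_execution() call is an assumption for this 2017-era TF 1.x API):

import tempfile

import tensorflow.contrib.eager as tfe  # assumed import alias for this era
from tensorflow.contrib.eager.python import evaluator
from tensorflow.contrib.eager.python import metrics
from tensorflow.python.training import training_util

tfe.enable_eager_execution()


class IdentityModel(object):
  """Mirrors the test helper: evaluation data is used as-is."""

  def eval_data(self, d):
    return d


class SimpleEvaluator(evaluator.Evaluator):
  """Tracks a single Mean metric, as in evaluator_test.py below."""

  def __init__(self, model):
    super(SimpleEvaluator, self).__init__(model)
    self.mean = self.track_metric(metrics.Mean("mean"))

  def call(self, eval_data):
    self.mean(self.model.eval_data(eval_data))


e = SimpleEvaluator(IdentityModel())
e(3.0)                # each call updates the tracked metrics
e([5.0, 7.0, 9.0])
training_util.get_or_create_global_step()  # summaries are keyed by step
logdir = tempfile.mkdtemp()
results = e.all_metric_results(logdir)  # new: also writes summaries to logdir
print(results["mean"])  # mean of 3, 5, 7, 9 -> 6.0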
tensorflow/contrib/eager/python/BUILD

@@ -164,6 +164,7 @@ py_test(
     deps = [
         ":metrics",
         "//tensorflow/contrib/summary:summary_ops",
+        "//tensorflow/contrib/summary:summary_test_util",
         "//tensorflow/core:protos_all_py",
         "//tensorflow/python:array_ops",
         "//tensorflow/python:dtypes",
@@ -185,6 +186,7 @@ py_library(
     deps = [
         ":datasets",
         ":metrics",
+        "//tensorflow/contrib/summary:summary_ops",
         "//tensorflow/python:control_flow_ops",
         "//tensorflow/python:errors",
         "//tensorflow/python:framework_ops",
@@ -201,6 +203,10 @@ py_test(
     deps = [
         ":evaluator",
         ":metrics",
+        "//tensorflow/contrib/summary:summary_test_util",
+        "//tensorflow/python:framework_ops",
+        "//tensorflow/python:training",
+        "//tensorflow/python:variables",
         "//tensorflow/python/data/ops:dataset_ops",
         "//tensorflow/python/eager:context",
         "//tensorflow/python/eager:test",
tensorflow/contrib/eager/python/evaluator.py

@@ -22,6 +22,7 @@ import six
 
 from tensorflow.contrib.eager.python import datasets
 from tensorflow.contrib.eager.python import metrics
+from tensorflow.contrib.summary import summary_ops
 from tensorflow.python.eager import context
 from tensorflow.python.eager import function
 from tensorflow.python.framework import errors_impl
@@ -36,7 +37,7 @@ class Evaluator(object):
   evaluator = my_model.evaluator() # or MyEvaluator(my_model)
   for example_batch in ...:
     evaluator(example_batch)
-  results = evaluator.all_metric_results(optional_summary_writer)
+  results = evaluator.all_metric_results(optional_summary_logdir)
 
   Or, if you are getting your examples from a tf.data.Dataset, you can use
   the evaluate_on_dataset() method.
@@ -94,8 +95,31 @@ class Evaluator(object):
           "eager execution is enabled.")
     return control_flow_ops.group([m.init_variables() for _, m in self.metrics])
 
-  def all_metric_results(self):  # TODO(josh11b): Add optional summary_writer.
-    """Returns dict mapping metric name -> value."""
+  def all_metric_results(self, summary_logdir=None):
+    """Computes results for all contained metrics.
+
+    Args:
+      summary_logdir: An optional string. If specified, metric results
+        will be written as summaries to this directory.
+
+    Returns:
+      A `dict` mapping string names to tensors.
+    """
+    if summary_logdir is None:
+      with summary_ops.never_record_summaries():
+        return self._all_metric_results()
+    else:
+      def f():
+        with summary_ops.create_summary_file_writer(
+            summary_logdir).as_default(), summary_ops.always_record_summaries():
+          return self._all_metric_results()
+      if context.in_eager_mode():
+        return f()
+      else:
+        return function.defun(f)()
+
+  def _all_metric_results(self):
+    """Implementation of `all_metric_results` in the summary context."""
     results = {}
     for name, metric in six.iteritems(self._metrics):
       results[name] = metric.result()
@@ -110,7 +134,9 @@ class Evaluator(object):
     Args:
       dataset: Dataset object with the input data to evaluate on.
       *args:
-      **kwargs: Optional additional arguments to __call__().
+      **kwargs: Optional additional arguments to __call__(), except
+        `summary_logdir`: if specified, metrics will be written as summaries
+        to this directory.
 
     Returns:
       @compatibility(eager)
@@ -131,17 +157,17 @@ class Evaluator(object):
     ```
     @end_compatibility
     """
-    # TODO(josh11b): Add optional summary_writer.
+    summary_logdir = kwargs.pop("summary_logdir", None)
     if context.in_graph_mode():
       call_op = self.__call__(dataset.make_one_shot_iterator().get_next(),
                               *args, **kwargs)
       init_op = self.init_variables()
-      results_op = self.all_metric_results()
+      results_op = self.all_metric_results(summary_logdir)
       return (init_op, call_op, results_op)
     # Eager case
     for example in datasets.Iterator(dataset):
       self.__call__(example, *args, **kwargs)
-    return self.all_metric_results()
+    return self.all_metric_results(summary_logdir)
 
   @staticmethod
   def run_evaluation(init_op, call_op, results_op, sess=None):
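In graph mode, evaluate_on_dataset() pops summary_logdir from **kwargs and threads it through to all_metric_results(), which builds the summary-writing ops; run_evaluation() (unchanged by this commit) then drives them. A sketch of that path, mirroring testWriteSummariesGraph below and reusing the SimpleEvaluator/IdentityModel helpers sketched earlier (run in a fresh process, since eager execution, once enabled, cannot be turned off):

import tempfile

import tensorflow as tf
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.ops import variables
from tensorflow.python.training import training_util

with tf.Graph().as_default():
  e = SimpleEvaluator(IdentityModel())
  ds = dataset_ops.Dataset.from_tensor_slices([3.0, 5.0, 7.0, 9.0])
  training_util.get_or_create_global_step()
  logdir = tempfile.mkdtemp()
  # Returns ops instead of values in graph mode.
  init_op, call_op, results_op = e.evaluate_on_dataset(
      ds, summary_logdir=logdir)
  with tf.Session() as sess:
    sess.run(variables.global_variables_initializer())
    # Runs init_op, then call_op until the dataset is exhausted, then
    # evaluates results_op (writing the summaries as a side effect).
    results = e.run_evaluation(init_op, call_op, results_op, sess=sess)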
tensorflow/contrib/eager/python/evaluator_test.py

@@ -18,11 +18,18 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import tempfile
+
 from tensorflow.contrib.eager.python import evaluator
+
 from tensorflow.contrib.eager.python import metrics
+from tensorflow.contrib.summary import summary_test_util
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.eager import context
 from tensorflow.python.eager import test
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import variables
+from tensorflow.python.training import training_util
 
 
 class IdentityModel(object):
@@ -71,6 +78,19 @@ class EvaluatorTest(test.TestCase):
     self.assertEqual(set(["mean"]), set(results.keys()))
     self.assertEqual(6.0, results["mean"].numpy())
 
+  def testWriteSummaries(self):
+    e = SimpleEvaluator(IdentityModel())
+    e(3.0)
+    e([5.0, 7.0, 9.0])
+    training_util.get_or_create_global_step()
+    logdir = tempfile.mkdtemp()
+
+    e.all_metric_results(logdir)
+
+    events = summary_test_util.events_from_file(logdir)
+    self.assertEqual(len(events), 2)
+    self.assertEqual(events[1].summary.value[0].simple_value, 6.0)
+
   def testComposition(self):
     e = DelegatingEvaluator(PrefixLModel())
     e({"inner": 2.0, "outer": 100.0})
@@ -97,7 +117,7 @@ class EvaluatorTest(test.TestCase):
     self.assertEqual(6.0, results["mean"].numpy())
 
   def testDatasetGraph(self):
-    with context.graph_mode(), self.test_session():
+    with context.graph_mode(), ops.Graph().as_default(), self.test_session():
       e = SimpleEvaluator(IdentityModel())
       ds = dataset_ops.Dataset.from_tensor_slices([3.0, 5.0, 7.0, 9.0])
       init_op, call_op, results_op = e.evaluate_on_dataset(ds)
@@ -105,6 +125,21 @@ class EvaluatorTest(test.TestCase):
     self.assertEqual(set(["mean"]), set(results.keys()))
     self.assertEqual(6.0, results["mean"])
 
+  def testWriteSummariesGraph(self):
+    with context.graph_mode(), ops.Graph().as_default(), self.test_session():
+      e = SimpleEvaluator(IdentityModel())
+      ds = dataset_ops.Dataset.from_tensor_slices([3.0, 5.0, 7.0, 9.0])
+      training_util.get_or_create_global_step()
+      logdir = tempfile.mkdtemp()
+      init_op, call_op, results_op = e.evaluate_on_dataset(
+          ds, summary_logdir=logdir)
+      variables.global_variables_initializer().run()
+      e.run_evaluation(init_op, call_op, results_op)
+
+    events = summary_test_util.events_from_file(logdir)
+    self.assertEqual(len(events), 2)
+    self.assertEqual(events[1].summary.value[0].simple_value, 6.0)
+
   def testModelProperty(self):
     m = IdentityModel()
     e = SimpleEvaluator(m)
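A note on the assertions in the new tests: an events file always begins with a file_version record, so both tests expect len(events) == 2 with events[1] carrying the metric summary, and the expected 6.0 is just the mean of the evaluated values, (3.0 + 5.0 + 7.0 + 9.0) / 4 = 6.0.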
tensorflow/contrib/eager/python/metrics_test.py

@@ -18,18 +18,15 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-import os
 import tempfile
 
 from tensorflow.contrib.eager.python import metrics
 from tensorflow.contrib.summary import summary_ops
-from tensorflow.core.util import event_pb2
+from tensorflow.contrib.summary import summary_test_util
 from tensorflow.python.eager import context
 from tensorflow.python.eager import test
 from tensorflow.python.framework import dtypes
-from tensorflow.python.lib.io import tf_record
 from tensorflow.python.ops import array_ops
-from tensorflow.python.platform import gfile
 from tensorflow.python.training import training_util
 
 
@@ -63,15 +60,9 @@ class MetricsTest(test.TestCase):
             name="t0").as_default(), summary_ops.always_record_summaries():
       m.result()  # As a side-effect will write summaries.
 
-    self.assertTrue(gfile.Exists(logdir))
-    files = gfile.ListDirectory(logdir)
-    self.assertEqual(len(files), 1)
-    records = list(
-        tf_record.tf_record_iterator(os.path.join(logdir, files[0])))
-    self.assertEqual(len(records), 2)
-    event = event_pb2.Event()
-    event.ParseFromString(records[1])
-    self.assertEqual(event.summary.value[0].simple_value, 37.0)
+    events = summary_test_util.events_from_file(logdir)
+    self.assertEqual(len(events), 2)
+    self.assertEqual(events[1].summary.value[0].simple_value, 37.0)
 
   def testWeightedMean(self):
     m = metrics.Mean()
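summary_test_util.events_from_file itself is not part of this diff; judging from the hand-rolled event parsing it replaces above, a plausible stand-in (an assumption, not the actual helper) looks like:

import os

from tensorflow.core.util import event_pb2
from tensorflow.python.lib.io import tf_record
from tensorflow.python.platform import gfile


def events_from_file(logdir):
  """Returns all tf.Event protos from the single events file under logdir."""
  files = gfile.ListDirectory(logdir)
  assert len(files) == 1  # one summary writer -> one events file
  records = tf_record.tf_record_iterator(os.path.join(logdir, files[0]))
  events = []
  for record in records:
    event = event_pb2.Event()
    event.ParseFromString(record)
    events.append(event)
  return events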
tensorflow/contrib/summary/summary.py

@@ -29,6 +29,7 @@ from tensorflow.contrib.summary.summary_ops import all_summary_ops
 from tensorflow.contrib.summary.summary_ops import always_record_summaries
 from tensorflow.contrib.summary.summary_ops import audio
 from tensorflow.contrib.summary.summary_ops import create_summary_file_writer
+from tensorflow.contrib.summary.summary_ops import eval_dir
 from tensorflow.contrib.summary.summary_ops import generic
 from tensorflow.contrib.summary.summary_ops import histogram
 from tensorflow.contrib.summary.summary_ops import image
tensorflow/contrib/summary/summary_ops.py

@@ -19,6 +19,8 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import os
+
 from tensorflow.contrib.summary import gen_summary_ops
 from tensorflow.python.eager import context
 from tensorflow.python.framework import constant_op
@@ -272,3 +274,8 @@ def audio(name, tensor, sample_rate, max_outputs, family=None):
         name=scope)
 
   return summary_writer_function(name, tensor, function, family=family)
+
+
+def eval_dir(model_dir, name=None):
+  """Construct a logdir for an eval summary writer."""
+  return os.path.join(model_dir, "eval" if not name else "eval_" + name)
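The new eval_dir() helper simply composes a conventional eval logdir under a model directory, for example:

eval_dir("/tmp/my_model")             # -> "/tmp/my_model/eval"
eval_dir("/tmp/my_model", "flowers")  # -> "/tmp/my_model/eval_flowers"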