Add 'streaming_curve_points' metric which returns curve [ROC, PR] approximation at specified number of points.

PiperOrigin-RevId: 157851535
This commit is contained in:
A. Unique TensorFlower 2017-06-02 11:08:48 -07:00 committed by TensorFlower Gardener
parent 0f2db73916
commit cd6c02985e
3 changed files with 192 additions and 0 deletions

View File

@ -23,6 +23,7 @@ See the @{$python/contrib.metrics} guide.
@@streaming_precision
@@streaming_precision_at_thresholds
@@streaming_auc
@@streaming_curve_points
@@streaming_recall_at_k
@@streaming_mean_absolute_error
@@streaming_mean_iou
@ -76,6 +77,7 @@ from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_accuracy
from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_auc
from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_concat
from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_covariance
from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_curve_points
from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_false_negatives
from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_false_negatives_at_thresholds
from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_false_positives

View File

@ -733,6 +733,102 @@ def streaming_true_negatives_at_thresholds(
return values['tn'], update_ops['tn']
def streaming_curve_points(labels=None,
predictions=None,
weights=None,
num_thresholds=200,
metrics_collections=None,
updates_collections=None,
curve='ROC',
name=None):
"""Computes curve (ROC or PR) values for a prespecified number of points.
The `streaming_curve_points` function creates four local variables,
`true_positives`, `true_negatives`, `false_positives` and `false_negatives`
that are used to compute the curve values. To discretize the curve, a linearly
spaced set of thresholds is used to compute pairs of recall and precision
values.
For best results, `predictions` should be distributed approximately uniformly
in the range [0, 1] and not peaked around 0 or 1.
For estimation of the metric over a stream of data, the function creates an
`update_op` operation that updates these variables.
If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
Args:
labels: A `Tensor` whose shape matches `predictions`. Will be cast to
`bool`.
predictions: A floating point `Tensor` of arbitrary shape and whose values
are in the range `[0, 1]`.
weights: Optional `Tensor` whose rank is either 0, or the same rank as
`labels`, and must be broadcastable to `labels` (i.e., all dimensions must
be either `1`, or the same as the corresponding `labels` dimension).
num_thresholds: The number of thresholds to use when discretizing the roc
curve.
metrics_collections: An optional list of collections that `auc` should be
added to.
updates_collections: An optional list of collections that `update_op` should
be added to.
curve: Specifies the name of the curve to be computed, 'ROC' [default] or
'PR' for the Precision-Recall-curve.
name: An optional variable_scope name.
Returns:
points: A `Tensor` with shape [num_thresholds, 2] that contains points of
the curve.
update_op: An operation that increments the `true_positives`,
`true_negatives`, `false_positives` and `false_negatives` variables.
Raises:
ValueError: If `predictions` and `labels` have mismatched shapes, or if
`weights` is not `None` and its shape doesn't match `predictions`, or if
either `metrics_collections` or `updates_collections` are not a list or
tuple.
"""
with variable_scope.variable_scope(name, 'curve_points', (labels, predictions,
weights)):
if curve != 'ROC' and curve != 'PR':
raise ValueError('curve must be either ROC or PR, %s unknown' % (curve))
kepsilon = 1e-7 # to account for floating point imprecisions
thresholds = [(i + 1) * 1.0 / (num_thresholds - 1)
for i in range(num_thresholds - 2)]
thresholds = [0.0 - kepsilon] + thresholds + [1.0 + kepsilon]
values, update_ops = _streaming_confusion_matrix_at_thresholds(
labels=labels,
predictions=predictions,
thresholds=thresholds,
weights=weights)
# Add epsilons to avoid dividing by 0.
epsilon = 1.0e-6
def compute_points(tp, fn, tn, fp):
"""Computes the roc-auc or pr-auc based on confusion counts."""
rec = math_ops.div(tp + epsilon, tp + fn + epsilon)
if curve == 'ROC':
fp_rate = math_ops.div(fp, fp + tn + epsilon)
return fp_rate, rec
else: # curve == 'PR'.
prec = math_ops.div(tp + epsilon, tp + fp + epsilon)
return rec, prec
xs, ys = compute_points(values['tp'], values['fn'], values['tn'],
values['fp'])
points = array_ops.stack([xs, ys], axis=1)
update_op = control_flow_ops.group(*update_ops.values())
if metrics_collections:
ops.add_to_collections(metrics_collections, points)
if updates_collections:
ops.add_to_collections(updates_collections, update_op)
return points, update_op
def streaming_auc(predictions, labels, weights=None, num_thresholds=200,
metrics_collections=None, updates_collections=None,
curve='ROC', name=None):
@ -2372,6 +2468,7 @@ __all__ = [
'sparse_recall_at_top_k',
'streaming_accuracy',
'streaming_auc',
'streaming_curve_points',
'streaming_false_negatives',
'streaming_false_negatives_at_thresholds',
'streaming_false_positives',

View File

@ -1327,6 +1327,99 @@ class StreamingRecallTest(test.TestCase):
self.assertEqual(0, recall.eval())
class StreamingCurvePointsTest(test.TestCase):
def setUp(self):
np.random.seed(1)
ops.reset_default_graph()
def testVars(self):
metric_ops.streaming_curve_points(
predictions=array_ops.ones((10, 1)), labels=array_ops.ones((10, 1)))
_assert_local_variables(
self,
('curve_points/true_positives:0', 'curve_points/false_negatives:0',
'curve_points/false_positives:0', 'curve_points/true_negatives:0'))
def testMetricsCollection(self):
my_collection_name = '__metrics__'
points, _ = metric_ops.streaming_curve_points(
labels=array_ops.ones((10, 1)),
predictions=array_ops.ones((10, 1)),
metrics_collections=[my_collection_name])
self.assertListEqual(ops.get_collection(my_collection_name), [points])
def testUpdatesCollection(self):
my_collection_name = '__updates__'
_, update_op = metric_ops.streaming_curve_points(
labels=array_ops.ones((10, 1)),
predictions=array_ops.ones((10, 1)),
updates_collections=[my_collection_name])
self.assertListEqual(ops.get_collection(my_collection_name), [update_op])
def _testValueTensorIsIdempotent(self, curve):
predictions = constant_op.constant(
np.random.uniform(size=(10, 3)), dtype=dtypes_lib.float32)
labels = constant_op.constant(
np.random.uniform(high=2, size=(10, 3)), dtype=dtypes_lib.float32)
points, update_op = metric_ops.streaming_curve_points(
labels, predictions=predictions, curve=curve)
with self.test_session() as sess:
sess.run(variables.local_variables_initializer())
sess.run(update_op)
initial_points = points.eval()
sess.run(update_op)
self.assertAllClose(initial_points, points.eval())
def testValueTensorIsIdempotentROC(self):
self._testValueTensorIsIdempotent(curve='ROC')
def testValueTensorIsIdempotentPR(self):
self._testValueTensorIsIdempotent(curve='PR')
def _testCase(self, labels, predictions, curve, expected_points):
with self.test_session() as sess:
predictions_tensor = constant_op.constant(
predictions, dtype=dtypes_lib.float32)
labels_tensor = constant_op.constant(labels, dtype=dtypes_lib.float32)
points, update_op = metric_ops.streaming_curve_points(
labels=labels_tensor,
predictions=predictions_tensor,
num_thresholds=3,
curve=curve)
sess.run(variables.local_variables_initializer())
sess.run(update_op)
self.assertAllClose(expected_points, points.eval())
def testEdgeCasesROC(self):
self._testCase([[1]], [[1]], 'ROC', [[0, 1], [0, 1], [0, 0]])
self._testCase([[0]], [[0]], 'ROC', [[1, 1], [0, 1], [0, 1]])
self._testCase([[0]], [[1]], 'ROC', [[1, 1], [1, 1], [0, 1]])
self._testCase([[1]], [[0]], 'ROC', [[0, 1], [0, 0], [0, 0]])
def testManyValuesROC(self):
self._testCase([[1.0, 0.0, 0.0, 1.0, 1.0, 1.0]],
[[0.2, 0.3, 0.4, 0.6, 0.7, 0.8]], 'ROC',
[[1.0, 1.0], [0.0, 0.75], [0.0, 0.0]])
def testEdgeCasesPR(self):
self._testCase([[1]], [[1]], 'PR', [[1, 1], [1, 1], [0, 1]])
self._testCase([[0]], [[0]], 'PR', [[1, 0], [1, 1], [1, 1]])
self._testCase([[0]], [[1]], 'PR', [[1, 0], [1, 0], [1, 1]])
self._testCase([[1]], [[0]], 'PR', [[1, 1], [0, 1], [0, 1]])
def testManyValuesPR(self):
self._testCase([[1.0, 0.0, 0.0, 1.0, 1.0, 1.0]],
[[0.2, 0.3, 0.4, 0.6, 0.7, 0.8]], 'PR',
[[1.0, 4.0 / 6.0], [0.75, 1.0], [0.0, 1.0]])
class StreamingAUCTest(test.TestCase):
def setUp(self):