mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-07 00:20:20 +01:00
Add 'streaming_curve_points' metric which returns curve [ROC, PR] approximation at specified number of points.
PiperOrigin-RevId: 157851535
This commit is contained in:
parent
0f2db73916
commit
cd6c02985e
|
|
@ -23,6 +23,7 @@ See the @{$python/contrib.metrics} guide.
|
|||
@@streaming_precision
|
||||
@@streaming_precision_at_thresholds
|
||||
@@streaming_auc
|
||||
@@streaming_curve_points
|
||||
@@streaming_recall_at_k
|
||||
@@streaming_mean_absolute_error
|
||||
@@streaming_mean_iou
|
||||
|
|
@ -76,6 +77,7 @@ from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_accuracy
|
|||
from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_auc
|
||||
from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_concat
|
||||
from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_covariance
|
||||
from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_curve_points
|
||||
from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_false_negatives
|
||||
from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_false_negatives_at_thresholds
|
||||
from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_false_positives
|
||||
|
|
|
|||
|
|
@ -733,6 +733,102 @@ def streaming_true_negatives_at_thresholds(
|
|||
return values['tn'], update_ops['tn']
|
||||
|
||||
|
||||
def streaming_curve_points(labels=None,
                           predictions=None,
                           weights=None,
                           num_thresholds=200,
                           metrics_collections=None,
                           updates_collections=None,
                           curve='ROC',
                           name=None):
  """Computes curve (ROC or PR) values for a prespecified number of points.

  The `streaming_curve_points` function creates four local variables,
  `true_positives`, `true_negatives`, `false_positives` and `false_negatives`
  that are used to compute the curve values. To discretize the curve, a linearly
  spaced set of thresholds is used, and one point of the curve is computed per
  threshold:

    * `curve='ROC'`: `(false positive rate, recall)`.
    * `curve='PR'`: `(recall, precision)`.

  Points are ordered by increasing threshold; the first threshold lies just
  below 0 and the last just above 1.

  For best results, `predictions` should be distributed approximately uniformly
  in the range [0, 1] and not peaked around 0 or 1.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    labels: A `Tensor` whose shape matches `predictions`. Will be cast to
      `bool`.
    predictions: A floating point `Tensor` of arbitrary shape and whose values
      are in the range `[0, 1]`.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `labels` dimension).
    num_thresholds: The number of thresholds to use when discretizing the
      curve. The two end thresholds are always present, so any value below 2
      still yields two thresholds.
    metrics_collections: An optional list of collections that `points` should
      be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    curve: Specifies the name of the curve to be computed, 'ROC' [default] or
      'PR' for the Precision-Recall-curve.
    name: An optional variable_scope name.

  Returns:
    points: A `Tensor` with shape [num_thresholds, 2] (for
      `num_thresholds >= 2`) that contains points of the curve.
    update_op: An operation that increments the `true_positives`,
      `true_negatives`, `false_positives` and `false_negatives` variables.

  Raises:
    ValueError: If `curve` is neither 'ROC' nor 'PR', if `predictions` and
      `labels` have mismatched shapes, or if `weights` is not `None` and its
      shape doesn't match `predictions`, or if either `metrics_collections` or
      `updates_collections` are not a list or tuple.
  """
  # Reject an unknown curve before any scope is entered, so that a bad argument
  # fails fast. `(curve,)` keeps the message well-formed even when `curve` is
  # itself a tuple.
  if curve not in ('ROC', 'PR'):
    raise ValueError('curve must be either ROC or PR, %s unknown' % (curve,))

  with variable_scope.variable_scope(name, 'curve_points',
                                     (labels, predictions, weights)):
    kepsilon = 1e-7  # to account for floating point imprecisions
    # Interior thresholds are evenly spaced in (0, 1); the two end thresholds
    # are nudged just outside [0, 1] by `kepsilon`.
    thresholds = [(i + 1) * 1.0 / (num_thresholds - 1)
                  for i in range(num_thresholds - 2)]
    thresholds = [0.0 - kepsilon] + thresholds + [1.0 + kepsilon]

    values, update_ops = _streaming_confusion_matrix_at_thresholds(
        labels=labels,
        predictions=predictions,
        thresholds=thresholds,
        weights=weights)

    # Add epsilons to avoid dividing by 0.
    epsilon = 1.0e-6

    def compute_points(tp, fn, tn, fp):
      """Computes the (x, y) curve coordinates from the confusion counts."""
      # With epsilon in both numerator and denominator, recall and precision
      # evaluate to 1 (not 0) when their denominators' counts are all zero.
      rec = math_ops.div(tp + epsilon, tp + fn + epsilon)
      if curve == 'ROC':
        # No epsilon in the numerator: the rate is 0 when there are no
        # negatives at all.
        fp_rate = math_ops.div(fp, fp + tn + epsilon)
        return fp_rate, rec
      else:  # curve == 'PR'.
        prec = math_ops.div(tp + epsilon, tp + fp + epsilon)
        return rec, prec

    xs, ys = compute_points(values['tp'], values['fn'], values['tn'],
                            values['fp'])
    # Pair up the per-threshold coordinates: one row per threshold.
    points = array_ops.stack([xs, ys], axis=1)
    update_op = control_flow_ops.group(*update_ops.values())

    if metrics_collections:
      ops.add_to_collections(metrics_collections, points)

    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)

    return points, update_op
|
||||
|
||||
|
||||
def streaming_auc(predictions, labels, weights=None, num_thresholds=200,
|
||||
metrics_collections=None, updates_collections=None,
|
||||
curve='ROC', name=None):
|
||||
|
|
@ -2372,6 +2468,7 @@ __all__ = [
|
|||
'sparse_recall_at_top_k',
|
||||
'streaming_accuracy',
|
||||
'streaming_auc',
|
||||
'streaming_curve_points',
|
||||
'streaming_false_negatives',
|
||||
'streaming_false_negatives_at_thresholds',
|
||||
'streaming_false_positives',
|
||||
|
|
|
|||
|
|
@ -1327,6 +1327,99 @@ class StreamingRecallTest(test.TestCase):
|
|||
self.assertEqual(0, recall.eval())
|
||||
|
||||
|
||||
class StreamingCurvePointsTest(test.TestCase):
  """Tests for `metric_ops.streaming_curve_points`."""

  def setUp(self):
    # Seed numpy and start from an empty default graph before every test.
    np.random.seed(1)
    ops.reset_default_graph()

  def testVars(self):
    """Checks the local variables created by the metric against four names."""
    all_ones_predictions = array_ops.ones((10, 1))
    all_ones_labels = array_ops.ones((10, 1))
    metric_ops.streaming_curve_points(
        predictions=all_ones_predictions, labels=all_ones_labels)

    expected_variable_names = (
        'curve_points/true_positives:0',
        'curve_points/false_negatives:0',
        'curve_points/false_positives:0',
        'curve_points/true_negatives:0',
    )
    _assert_local_variables(self, expected_variable_names)

  def testMetricsCollection(self):
    """The `points` tensor is the only entry of the requested collection."""
    collection = '__metrics__'
    points, _ = metric_ops.streaming_curve_points(
        labels=array_ops.ones((10, 1)),
        predictions=array_ops.ones((10, 1)),
        metrics_collections=[collection])

    collected = ops.get_collection(collection)
    self.assertListEqual(collected, [points])

  def testUpdatesCollection(self):
    """The `update_op` is the only entry of the requested collection."""
    collection = '__updates__'
    _, update_op = metric_ops.streaming_curve_points(
        labels=array_ops.ones((10, 1)),
        predictions=array_ops.ones((10, 1)),
        updates_collections=[collection])

    collected = ops.get_collection(collection)
    self.assertListEqual(collected, [update_op])

  def _testValueTensorIsIdempotent(self, curve):
    """Checks `points` is (nearly) unchanged by a second identical update."""
    shape = (10, 3)
    # Predictions are drawn before labels so the seeded stream is consumed in
    # a fixed order.
    random_predictions = np.random.uniform(size=shape)
    predictions = constant_op.constant(
        random_predictions, dtype=dtypes_lib.float32)
    random_labels = np.random.uniform(high=2, size=shape)
    labels = constant_op.constant(random_labels, dtype=dtypes_lib.float32)

    points, update_op = metric_ops.streaming_curve_points(
        labels, predictions=predictions, curve=curve)

    with self.test_session() as sess:
      sess.run(variables.local_variables_initializer())

      # First pass over the data.
      sess.run(update_op)
      points_after_first_update = points.eval()

      # Feeding the same data again must not move the curve.
      sess.run(update_op)
      points_after_second_update = points.eval()
      self.assertAllClose(points_after_first_update,
                          points_after_second_update)

  def testValueTensorIsIdempotentROC(self):
    self._testValueTensorIsIdempotent(curve='ROC')

  def testValueTensorIsIdempotentPR(self):
    self._testValueTensorIsIdempotent(curve='PR')

  def _testCase(self, labels, predictions, curve, expected_points):
    """Runs one update with three thresholds and compares the curve points."""
    with self.test_session() as sess:
      predictions_tensor = constant_op.constant(
          predictions, dtype=dtypes_lib.float32)
      labels_tensor = constant_op.constant(labels, dtype=dtypes_lib.float32)
      points, update_op = metric_ops.streaming_curve_points(
          labels=labels_tensor,
          predictions=predictions_tensor,
          num_thresholds=3,
          curve=curve)

      sess.run(variables.local_variables_initializer())
      sess.run(update_op)

      actual_points = points.eval()
      self.assertAllClose(expected_points, actual_points)

  def testEdgeCasesROC(self):
    # Each row is (labels, predictions, expected points) for one example.
    single_example_cases = (
        ([[1]], [[1]], [[0, 1], [0, 1], [0, 0]]),
        ([[0]], [[0]], [[1, 1], [0, 1], [0, 1]]),
        ([[0]], [[1]], [[1, 1], [1, 1], [0, 1]]),
        ([[1]], [[0]], [[0, 1], [0, 0], [0, 0]]),
    )
    for labels, predictions, expected_points in single_example_cases:
      self._testCase(labels, predictions, 'ROC', expected_points)

  def testManyValuesROC(self):
    labels = [[1.0, 0.0, 0.0, 1.0, 1.0, 1.0]]
    predictions = [[0.2, 0.3, 0.4, 0.6, 0.7, 0.8]]
    expected_points = [[1.0, 1.0], [0.0, 0.75], [0.0, 0.0]]
    self._testCase(labels, predictions, 'ROC', expected_points)

  def testEdgeCasesPR(self):
    # Each row is (labels, predictions, expected points) for one example.
    single_example_cases = (
        ([[1]], [[1]], [[1, 1], [1, 1], [0, 1]]),
        ([[0]], [[0]], [[1, 0], [1, 1], [1, 1]]),
        ([[0]], [[1]], [[1, 0], [1, 0], [1, 1]]),
        ([[1]], [[0]], [[1, 1], [0, 1], [0, 1]]),
    )
    for labels, predictions, expected_points in single_example_cases:
      self._testCase(labels, predictions, 'PR', expected_points)

  def testManyValuesPR(self):
    labels = [[1.0, 0.0, 0.0, 1.0, 1.0, 1.0]]
    predictions = [[0.2, 0.3, 0.4, 0.6, 0.7, 0.8]]
    expected_points = [[1.0, 4.0 / 6.0], [0.75, 1.0], [0.0, 1.0]]
    self._testCase(labels, predictions, 'PR', expected_points)
|
||||
|
||||
|
||||
class StreamingAUCTest(test.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user