mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-07 00:20:20 +01:00
Exposes recall_at_top_k under tf.metrics.
PiperOrigin-RevId: 174189641
This commit is contained in:
parent
18bf5b2d91
commit
c40d541733
|
|
@ -2143,7 +2143,7 @@ def sparse_recall_at_top_k(labels,
|
|||
default_name = _at_k_name('recall', class_id=class_id)
|
||||
with ops.name_scope(name, default_name,
|
||||
(top_k_predictions, labels, weights)) as name_scope:
|
||||
return metrics_impl._sparse_recall_at_top_k( # pylint: disable=protected-access
|
||||
return metrics_impl.recall_at_top_k(
|
||||
labels=labels,
|
||||
predictions_idx=top_k_predictions,
|
||||
class_id=class_id,
|
||||
|
|
|
|||
|
|
@ -2304,10 +2304,43 @@ def _test_recall_at_k(predictions,
|
|||
test_case.assertEqual(expected, metric.eval())
|
||||
|
||||
|
||||
def _test_recall_at_top_k(
|
||||
predictions_idx,
|
||||
labels,
|
||||
expected,
|
||||
k=None,
|
||||
class_id=None,
|
||||
weights=None,
|
||||
test_case=None):
|
||||
with ops.Graph().as_default() as g, test_case.test_session(g):
|
||||
if weights is not None:
|
||||
weights = constant_op.constant(weights, dtypes_lib.float32)
|
||||
metric, update = metrics.recall_at_top_k(
|
||||
predictions_idx=constant_op.constant(predictions_idx, dtypes_lib.int32),
|
||||
labels=labels,
|
||||
k=k,
|
||||
class_id=class_id,
|
||||
weights=weights)
|
||||
|
||||
# Fails without initialized vars.
|
||||
test_case.assertRaises(errors_impl.OpError, metric.eval)
|
||||
test_case.assertRaises(errors_impl.OpError, update.eval)
|
||||
variables.variables_initializer(variables.local_variables()).run()
|
||||
|
||||
# Run per-step op and assert expected values.
|
||||
if math.isnan(expected):
|
||||
_assert_nan(test_case, update.eval())
|
||||
_assert_nan(test_case, metric.eval())
|
||||
else:
|
||||
test_case.assertEqual(expected, update.eval())
|
||||
test_case.assertEqual(expected, metric.eval())
|
||||
|
||||
|
||||
class SingleLabelRecallAtKTest(test.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self._predictions = ((0.1, 0.3, 0.2, 0.4), (0.1, 0.2, 0.3, 0.4))
|
||||
self._predictions_idx = [[3], [3]]
|
||||
indicator_labels = ((0, 0, 0, 1), (0, 0, 1, 0))
|
||||
class_labels = (3, 2)
|
||||
# Sparse vs dense, and 1d vs 2d labels should all be handled the same.
|
||||
|
|
@ -2318,6 +2351,8 @@ class SingleLabelRecallAtKTest(test.TestCase):
|
|||
[[class_id] for class_id in class_labels], dtype=np.int64))
|
||||
self._test_recall_at_k = functools.partial(
|
||||
_test_recall_at_k, test_case=self)
|
||||
self._test_recall_at_top_k = functools.partial(
|
||||
_test_recall_at_top_k, test_case=self)
|
||||
|
||||
def test_at_k1_nan(self):
|
||||
# Classes 0,1 have 0 labels, 0 predictions, classes -1 and 4 are out of
|
||||
|
|
@ -2326,120 +2361,100 @@ class SingleLabelRecallAtKTest(test.TestCase):
|
|||
for class_id in (-1, 0, 1, 4):
|
||||
self._test_recall_at_k(
|
||||
self._predictions, labels, k=1, expected=NAN, class_id=class_id)
|
||||
self._test_recall_at_top_k(
|
||||
self._predictions_idx, labels, k=1, expected=NAN, class_id=class_id)
|
||||
|
||||
def test_at_k1_no_predictions(self):
|
||||
for labels in self._labels:
|
||||
# Class 2: 0 predictions.
|
||||
self._test_recall_at_k(
|
||||
self._predictions, labels, k=1, expected=0.0, class_id=2)
|
||||
self._test_recall_at_top_k(
|
||||
self._predictions_idx, labels, k=1, expected=0.0, class_id=2)
|
||||
|
||||
def test_one_label_at_k1(self):
|
||||
for labels in self._labels:
|
||||
# Class 3: 1 label, 2 predictions, 1 correct.
|
||||
self._test_recall_at_k(
|
||||
self._predictions, labels, k=1, expected=1.0 / 1, class_id=3)
|
||||
self._test_recall_at_top_k(
|
||||
self._predictions_idx, labels, k=1, expected=1.0 / 1, class_id=3)
|
||||
|
||||
# All classes: 2 labels, 2 predictions, 1 correct.
|
||||
self._test_recall_at_k(self._predictions, labels, k=1, expected=1.0 / 2)
|
||||
self._test_recall_at_top_k(
|
||||
self._predictions_idx, labels, k=1, expected=1.0 / 2)
|
||||
|
||||
def test_one_label_at_k1_weighted(self):
|
||||
def test_one_label_at_k1_weighted_class_id3(self):
|
||||
predictions = self._predictions
|
||||
predictions_idx = self._predictions_idx
|
||||
for labels in self._labels:
|
||||
# Class 3: 1 label, 2 predictions, 1 correct.
|
||||
self._test_recall_at_k(
|
||||
predictions, labels, k=1, expected=NAN, class_id=3, weights=(0.0,))
|
||||
self._test_recall_at_top_k(
|
||||
predictions_idx, labels, k=1, expected=NAN, class_id=3,
|
||||
weights=(0.0,))
|
||||
self._test_recall_at_k(
|
||||
predictions,
|
||||
labels,
|
||||
k=1,
|
||||
expected=1.0 / 1,
|
||||
class_id=3,
|
||||
predictions, labels, k=1, expected=1.0 / 1, class_id=3,
|
||||
weights=(1.0,))
|
||||
self._test_recall_at_top_k(
|
||||
predictions_idx, labels, k=1, expected=1.0 / 1, class_id=3,
|
||||
weights=(1.0,))
|
||||
self._test_recall_at_k(
|
||||
predictions,
|
||||
labels,
|
||||
k=1,
|
||||
expected=1.0 / 1,
|
||||
class_id=3,
|
||||
predictions, labels, k=1, expected=1.0 / 1, class_id=3,
|
||||
weights=(2.0,))
|
||||
self._test_recall_at_top_k(
|
||||
predictions_idx, labels, k=1, expected=1.0 / 1, class_id=3,
|
||||
weights=(2.0,))
|
||||
self._test_recall_at_k(
|
||||
predictions,
|
||||
labels,
|
||||
k=1,
|
||||
expected=NAN,
|
||||
class_id=3,
|
||||
weights=(0.0, 0.0))
|
||||
self._test_recall_at_k(
|
||||
predictions,
|
||||
labels,
|
||||
k=1,
|
||||
expected=NAN,
|
||||
class_id=3,
|
||||
predictions, labels, k=1, expected=NAN, class_id=3,
|
||||
weights=(0.0, 1.0))
|
||||
self._test_recall_at_top_k(
|
||||
predictions_idx, labels, k=1, expected=NAN, class_id=3,
|
||||
weights=(0.0, 1.0))
|
||||
self._test_recall_at_k(
|
||||
predictions,
|
||||
labels,
|
||||
k=1,
|
||||
expected=1.0 / 1,
|
||||
class_id=3,
|
||||
predictions, labels, k=1, expected=1.0 / 1, class_id=3,
|
||||
weights=(1.0, 0.0))
|
||||
self._test_recall_at_top_k(
|
||||
predictions_idx, labels, k=1, expected=1.0 / 1, class_id=3,
|
||||
weights=(1.0, 0.0))
|
||||
self._test_recall_at_k(
|
||||
predictions,
|
||||
labels,
|
||||
k=1,
|
||||
expected=1.0 / 1,
|
||||
class_id=3,
|
||||
weights=(1.0, 1.0))
|
||||
self._test_recall_at_k(
|
||||
predictions,
|
||||
labels,
|
||||
k=1,
|
||||
expected=2.0 / 2,
|
||||
class_id=3,
|
||||
predictions, labels, k=1, expected=2.0 / 2, class_id=3,
|
||||
weights=(2.0, 3.0))
|
||||
self._test_recall_at_top_k(
|
||||
predictions_idx, labels, k=1, expected=2.0 / 2, class_id=3,
|
||||
weights=(2.0, 3.0))
|
||||
self._test_recall_at_k(
|
||||
predictions,
|
||||
labels,
|
||||
k=1,
|
||||
expected=3.0 / 3,
|
||||
class_id=3,
|
||||
weights=(3.0, 2.0))
|
||||
self._test_recall_at_k(
|
||||
predictions,
|
||||
labels,
|
||||
k=1,
|
||||
expected=0.3 / 0.3,
|
||||
class_id=3,
|
||||
weights=(0.3, 0.6))
|
||||
self._test_recall_at_k(
|
||||
predictions,
|
||||
labels,
|
||||
k=1,
|
||||
expected=0.6 / 0.6,
|
||||
class_id=3,
|
||||
weights=(0.6, 0.3))
|
||||
|
||||
def test_one_label_at_k1_weighted(self):
|
||||
predictions = self._predictions
|
||||
predictions_idx = self._predictions_idx
|
||||
for labels in self._labels:
|
||||
# All classes: 2 labels, 2 predictions, 1 correct.
|
||||
self._test_recall_at_k(
|
||||
predictions, labels, k=1, expected=NAN, weights=(0.0,))
|
||||
self._test_recall_at_top_k(
|
||||
predictions_idx, labels, k=1, expected=NAN, weights=(0.0,))
|
||||
self._test_recall_at_k(
|
||||
predictions, labels, k=1, expected=1.0 / 2, weights=(1.0,))
|
||||
self._test_recall_at_top_k(
|
||||
predictions_idx, labels, k=1, expected=1.0 / 2, weights=(1.0,))
|
||||
self._test_recall_at_k(
|
||||
predictions, labels, k=1, expected=1.0 / 2, weights=(2.0,))
|
||||
self._test_recall_at_top_k(
|
||||
predictions_idx, labels, k=1, expected=1.0 / 2, weights=(2.0,))
|
||||
self._test_recall_at_k(
|
||||
predictions, labels, k=1, expected=1.0 / 1, weights=(1.0, 0.0))
|
||||
self._test_recall_at_top_k(
|
||||
predictions_idx, labels, k=1, expected=1.0 / 1, weights=(1.0, 0.0))
|
||||
self._test_recall_at_k(
|
||||
predictions, labels, k=1, expected=0.0 / 1, weights=(0.0, 1.0))
|
||||
self._test_recall_at_k(
|
||||
predictions, labels, k=1, expected=1.0 / 2, weights=(1.0, 1.0))
|
||||
self._test_recall_at_top_k(
|
||||
predictions_idx, labels, k=1, expected=0.0 / 1, weights=(0.0, 1.0))
|
||||
self._test_recall_at_k(
|
||||
predictions, labels, k=1, expected=2.0 / 5, weights=(2.0, 3.0))
|
||||
self._test_recall_at_k(
|
||||
predictions, labels, k=1, expected=3.0 / 5, weights=(3.0, 2.0))
|
||||
self._test_recall_at_k(
|
||||
predictions, labels, k=1, expected=0.3 / 0.9, weights=(0.3, 0.6))
|
||||
self._test_recall_at_k(
|
||||
predictions, labels, k=1, expected=0.6 / 0.9, weights=(0.6, 0.3))
|
||||
self._test_recall_at_top_k(
|
||||
predictions_idx, labels, k=1, expected=2.0 / 5, weights=(2.0, 3.0))
|
||||
|
||||
|
||||
class MultiLabel2dRecallAtKTest(test.TestCase):
|
||||
|
|
@ -2447,6 +2462,7 @@ class MultiLabel2dRecallAtKTest(test.TestCase):
|
|||
def setUp(self):
|
||||
self._predictions = ((0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9),
|
||||
(0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6))
|
||||
self._predictions_idx = ((9, 4, 6, 2, 0), (5, 7, 2, 9, 6))
|
||||
indicator_labels = ((0, 0, 1, 0, 0, 0, 0, 1, 1, 0),
|
||||
(0, 1, 1, 0, 0, 1, 0, 0, 0, 0))
|
||||
class_labels = ((2, 7, 8), (1, 2, 5))
|
||||
|
|
@ -2456,6 +2472,8 @@ class MultiLabel2dRecallAtKTest(test.TestCase):
|
|||
class_labels, dtype=np.int64))
|
||||
self._test_recall_at_k = functools.partial(
|
||||
_test_recall_at_k, test_case=self)
|
||||
self._test_recall_at_top_k = functools.partial(
|
||||
_test_recall_at_top_k, test_case=self)
|
||||
|
||||
def test_at_k5_nan(self):
|
||||
for labels in self._labels:
|
||||
|
|
@ -2463,29 +2481,41 @@ class MultiLabel2dRecallAtKTest(test.TestCase):
|
|||
for class_id in (0, 3, 4, 6, 9, 10):
|
||||
self._test_recall_at_k(
|
||||
self._predictions, labels, k=5, expected=NAN, class_id=class_id)
|
||||
self._test_recall_at_top_k(
|
||||
self._predictions_idx, labels, k=5, expected=NAN, class_id=class_id)
|
||||
|
||||
def test_at_k5_no_predictions(self):
|
||||
for labels in self._labels:
|
||||
# Class 8: 1 label, no predictions.
|
||||
self._test_recall_at_k(
|
||||
self._predictions, labels, k=5, expected=0.0 / 1, class_id=8)
|
||||
self._test_recall_at_top_k(
|
||||
self._predictions_idx, labels, k=5, expected=0.0 / 1, class_id=8)
|
||||
|
||||
def test_at_k5(self):
|
||||
for labels in self._labels:
|
||||
# Class 2: 2 labels, both correct.
|
||||
self._test_recall_at_k(
|
||||
self._predictions, labels, k=5, expected=2.0 / 2, class_id=2)
|
||||
self._test_recall_at_top_k(
|
||||
self._predictions_idx, labels, k=5, expected=2.0 / 2, class_id=2)
|
||||
|
||||
# Class 5: 1 label, incorrect.
|
||||
self._test_recall_at_k(
|
||||
self._predictions, labels, k=5, expected=1.0 / 1, class_id=5)
|
||||
self._test_recall_at_top_k(
|
||||
self._predictions_idx, labels, k=5, expected=1.0 / 1, class_id=5)
|
||||
|
||||
# Class 7: 1 label, incorrect.
|
||||
self._test_recall_at_k(
|
||||
self._predictions, labels, k=5, expected=0.0 / 1, class_id=7)
|
||||
self._test_recall_at_top_k(
|
||||
self._predictions_idx, labels, k=5, expected=0.0 / 1, class_id=7)
|
||||
|
||||
# All classes: 6 labels, 3 correct.
|
||||
self._test_recall_at_k(self._predictions, labels, k=5, expected=3.0 / 6)
|
||||
self._test_recall_at_top_k(
|
||||
self._predictions_idx, labels, k=5, expected=3.0 / 6)
|
||||
|
||||
def test_at_k5_some_out_of_range(self):
|
||||
"""Tests that labels outside the [0, n_classes) count in denominator."""
|
||||
|
|
@ -2499,17 +2529,25 @@ class MultiLabel2dRecallAtKTest(test.TestCase):
|
|||
# Class 2: 2 labels, both correct.
|
||||
self._test_recall_at_k(
|
||||
self._predictions, labels, k=5, expected=2.0 / 2, class_id=2)
|
||||
self._test_recall_at_top_k(
|
||||
self._predictions_idx, labels, k=5, expected=2.0 / 2, class_id=2)
|
||||
|
||||
# Class 5: 1 label, incorrect.
|
||||
self._test_recall_at_k(
|
||||
self._predictions, labels, k=5, expected=1.0 / 1, class_id=5)
|
||||
self._test_recall_at_top_k(
|
||||
self._predictions_idx, labels, k=5, expected=1.0 / 1, class_id=5)
|
||||
|
||||
# Class 7: 1 label, incorrect.
|
||||
self._test_recall_at_k(
|
||||
self._predictions, labels, k=5, expected=0.0 / 1, class_id=7)
|
||||
self._test_recall_at_top_k(
|
||||
self._predictions_idx, labels, k=5, expected=0.0 / 1, class_id=7)
|
||||
|
||||
# All classes: 8 labels, 3 correct.
|
||||
self._test_recall_at_k(self._predictions, labels, k=5, expected=3.0 / 8)
|
||||
self._test_recall_at_top_k(
|
||||
self._predictions_idx, labels, k=5, expected=3.0 / 8)
|
||||
|
||||
|
||||
class MultiLabel3dRecallAtKTest(test.TestCase):
|
||||
|
|
@ -2519,6 +2557,8 @@ class MultiLabel3dRecallAtKTest(test.TestCase):
|
|||
(0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6)),
|
||||
((0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6),
|
||||
(0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9)))
|
||||
self._predictions_idx = (((9, 4, 6, 2, 0), (5, 7, 2, 9, 6)),
|
||||
((5, 7, 2, 9, 6), (9, 4, 6, 2, 0)))
|
||||
# Note: We don't test dense labels here, since examples have different
|
||||
# numbers of labels.
|
||||
self._labels = _binary_3d_label_to_sparse_value(((
|
||||
|
|
@ -2526,114 +2566,128 @@ class MultiLabel3dRecallAtKTest(test.TestCase):
|
|||
(0, 1, 1, 0, 0, 1, 0, 1, 0, 0), (0, 0, 1, 0, 0, 0, 0, 0, 1, 0))))
|
||||
self._test_recall_at_k = functools.partial(
|
||||
_test_recall_at_k, test_case=self)
|
||||
self._test_recall_at_top_k = functools.partial(
|
||||
_test_recall_at_top_k, test_case=self)
|
||||
|
||||
def test_3d_nan(self):
|
||||
# Classes 0,3,4,6,9 have 0 labels, class 10 is out of range.
|
||||
for class_id in (0, 3, 4, 6, 9, 10):
|
||||
self._test_recall_at_k(
|
||||
self._predictions, self._labels, k=5, expected=NAN, class_id=class_id)
|
||||
self._test_recall_at_top_k(
|
||||
self._predictions_idx, self._labels, k=5, expected=NAN,
|
||||
class_id=class_id)
|
||||
|
||||
def test_3d_no_predictions(self):
|
||||
# Classes 1,8 have 0 predictions, >=1 label.
|
||||
for class_id in (1, 8):
|
||||
self._test_recall_at_k(
|
||||
self._predictions, self._labels, k=5, expected=0.0, class_id=class_id)
|
||||
self._test_recall_at_top_k(
|
||||
self._predictions_idx, self._labels, k=5, expected=0.0,
|
||||
class_id=class_id)
|
||||
|
||||
def test_3d(self):
|
||||
# Class 2: 4 labels, all correct.
|
||||
self._test_recall_at_k(
|
||||
self._predictions, self._labels, k=5, expected=4.0 / 4, class_id=2)
|
||||
self._test_recall_at_top_k(
|
||||
self._predictions_idx, self._labels, k=5, expected=4.0 / 4,
|
||||
class_id=2)
|
||||
|
||||
# Class 5: 2 labels, both correct.
|
||||
self._test_recall_at_k(
|
||||
self._predictions, self._labels, k=5, expected=2.0 / 2, class_id=5)
|
||||
self._test_recall_at_top_k(
|
||||
self._predictions_idx, self._labels, k=5, expected=2.0 / 2,
|
||||
class_id=5)
|
||||
|
||||
# Class 7: 2 labels, 1 incorrect.
|
||||
self._test_recall_at_k(
|
||||
self._predictions, self._labels, k=5, expected=1.0 / 2, class_id=7)
|
||||
self._test_recall_at_top_k(
|
||||
self._predictions_idx, self._labels, k=5, expected=1.0 / 2,
|
||||
class_id=7)
|
||||
|
||||
# All classes: 12 labels, 7 correct.
|
||||
self._test_recall_at_k(
|
||||
self._predictions, self._labels, k=5, expected=7.0 / 12)
|
||||
self._test_recall_at_top_k(
|
||||
self._predictions_idx, self._labels, k=5, expected=7.0 / 12)
|
||||
|
||||
def test_3d_ignore_all(self):
|
||||
for class_id in xrange(10):
|
||||
self._test_recall_at_k(
|
||||
self._predictions,
|
||||
self._labels,
|
||||
k=5,
|
||||
expected=NAN,
|
||||
class_id=class_id,
|
||||
self._predictions, self._labels, k=5, expected=NAN, class_id=class_id,
|
||||
weights=[[0], [0]])
|
||||
self._test_recall_at_top_k(
|
||||
self._predictions_idx, self._labels, k=5, expected=NAN,
|
||||
class_id=class_id, weights=[[0], [0]])
|
||||
self._test_recall_at_k(
|
||||
self._predictions,
|
||||
self._labels,
|
||||
k=5,
|
||||
expected=NAN,
|
||||
class_id=class_id,
|
||||
self._predictions, self._labels, k=5, expected=NAN, class_id=class_id,
|
||||
weights=[[0, 0], [0, 0]])
|
||||
self._test_recall_at_top_k(
|
||||
self._predictions_idx, self._labels, k=5, expected=NAN,
|
||||
class_id=class_id, weights=[[0, 0], [0, 0]])
|
||||
self._test_recall_at_k(
|
||||
self._predictions, self._labels, k=5, expected=NAN, weights=[[0], [0]])
|
||||
self._test_recall_at_top_k(
|
||||
self._predictions_idx, self._labels, k=5, expected=NAN,
|
||||
weights=[[0], [0]])
|
||||
self._test_recall_at_k(
|
||||
self._predictions,
|
||||
self._labels,
|
||||
k=5,
|
||||
expected=NAN,
|
||||
self._predictions, self._labels, k=5, expected=NAN,
|
||||
weights=[[0, 0], [0, 0]])
|
||||
self._test_recall_at_top_k(
|
||||
self._predictions_idx, self._labels, k=5, expected=NAN,
|
||||
weights=[[0, 0], [0, 0]])
|
||||
|
||||
def test_3d_ignore_some(self):
|
||||
# Class 2: 2 labels, both correct.
|
||||
self._test_recall_at_k(
|
||||
self._predictions,
|
||||
self._labels,
|
||||
k=5,
|
||||
expected=2.0 / 2.0,
|
||||
class_id=2,
|
||||
self._predictions, self._labels, k=5, expected=2.0 / 2.0, class_id=2,
|
||||
weights=[[1], [0]])
|
||||
self._test_recall_at_top_k(
|
||||
self._predictions_idx, self._labels, k=5, expected=2.0 / 2.0,
|
||||
class_id=2, weights=[[1], [0]])
|
||||
|
||||
# Class 2: 2 labels, both correct.
|
||||
self._test_recall_at_k(
|
||||
self._predictions,
|
||||
self._labels,
|
||||
k=5,
|
||||
expected=2.0 / 2.0,
|
||||
class_id=2,
|
||||
self._predictions, self._labels, k=5, expected=2.0 / 2.0, class_id=2,
|
||||
weights=[[0], [1]])
|
||||
self._test_recall_at_top_k(
|
||||
self._predictions_idx, self._labels, k=5, expected=2.0 / 2.0,
|
||||
class_id=2, weights=[[0], [1]])
|
||||
|
||||
# Class 7: 1 label, correct.
|
||||
self._test_recall_at_k(
|
||||
self._predictions,
|
||||
self._labels,
|
||||
k=5,
|
||||
expected=1.0 / 1.0,
|
||||
class_id=7,
|
||||
self._predictions, self._labels, k=5, expected=1.0 / 1.0, class_id=7,
|
||||
weights=[[0], [1]])
|
||||
self._test_recall_at_top_k(
|
||||
self._predictions_idx, self._labels, k=5, expected=1.0 / 1.0,
|
||||
class_id=7, weights=[[0], [1]])
|
||||
|
||||
# Class 7: 1 label, incorrect.
|
||||
self._test_recall_at_k(
|
||||
self._predictions,
|
||||
self._labels,
|
||||
k=5,
|
||||
expected=0.0 / 1.0,
|
||||
class_id=7,
|
||||
self._predictions, self._labels, k=5, expected=0.0 / 1.0, class_id=7,
|
||||
weights=[[1], [0]])
|
||||
self._test_recall_at_top_k(
|
||||
self._predictions_idx, self._labels, k=5, expected=0.0 / 1.0,
|
||||
class_id=7, weights=[[1], [0]])
|
||||
|
||||
# Class 7: 2 labels, 1 correct.
|
||||
self._test_recall_at_k(
|
||||
self._predictions,
|
||||
self._labels,
|
||||
k=5,
|
||||
expected=1.0 / 2.0,
|
||||
class_id=7,
|
||||
self._predictions, self._labels, k=5, expected=1.0 / 2.0, class_id=7,
|
||||
weights=[[1, 0], [1, 0]])
|
||||
self._test_recall_at_top_k(
|
||||
self._predictions_idx, self._labels, k=5, expected=1.0 / 2.0,
|
||||
class_id=7, weights=[[1, 0], [1, 0]])
|
||||
|
||||
# Class 7: No labels.
|
||||
self._test_recall_at_k(
|
||||
self._predictions,
|
||||
self._labels,
|
||||
k=5,
|
||||
expected=NAN,
|
||||
class_id=7,
|
||||
self._predictions, self._labels, k=5, expected=NAN, class_id=7,
|
||||
weights=[[0, 1], [0, 1]])
|
||||
self._test_recall_at_top_k(
|
||||
self._predictions_idx, self._labels, k=5, expected=NAN, class_id=7,
|
||||
weights=[[0, 1], [0, 1]])
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -34,6 +34,7 @@
|
|||
@@precision_at_thresholds
|
||||
@@recall
|
||||
@@recall_at_k
|
||||
@@recall_at_top_k
|
||||
@@recall_at_thresholds
|
||||
@@root_mean_squared_error
|
||||
@@sensitivity_at_specificity
|
||||
|
|
|
|||
|
|
@ -2246,10 +2246,8 @@ def recall_at_k(labels,
|
|||
with ops.name_scope(
|
||||
name, _at_k_name('recall', k, class_id=class_id),
|
||||
(predictions, labels, weights)) as scope:
|
||||
labels = _maybe_expand_labels(labels, predictions)
|
||||
|
||||
_, top_k_idx = nn.top_k(predictions, k)
|
||||
return _sparse_recall_at_top_k(
|
||||
return recall_at_top_k(
|
||||
labels=labels,
|
||||
predictions_idx=top_k_idx,
|
||||
k=k,
|
||||
|
|
@ -2260,14 +2258,14 @@ def recall_at_k(labels,
|
|||
name=scope)
|
||||
|
||||
|
||||
def _sparse_recall_at_top_k(labels,
|
||||
predictions_idx,
|
||||
k=None,
|
||||
class_id=None,
|
||||
weights=None,
|
||||
metrics_collections=None,
|
||||
updates_collections=None,
|
||||
name=None):
|
||||
def recall_at_top_k(labels,
|
||||
predictions_idx,
|
||||
k=None,
|
||||
class_id=None,
|
||||
weights=None,
|
||||
metrics_collections=None,
|
||||
updates_collections=None,
|
||||
name=None):
|
||||
"""Computes recall@k of top-k predictions with respect to sparse labels.
|
||||
|
||||
Differs from `recall_at_k` in that predictions must be in the form of top `k`
|
||||
|
|
@ -2287,7 +2285,7 @@ def _sparse_recall_at_top_k(labels,
|
|||
Commonly, N=1 and predictions has shape [batch size, k]. The final
|
||||
dimension contains the top `k` predicted class indices. [D1, ... DN] must
|
||||
match `labels`.
|
||||
k: Integer, k for @k metric.
|
||||
k: Integer, k for @k metric. Only used for the default op name.
|
||||
class_id: Integer class ID for which we want binary metrics. This should be
|
||||
in range [0, num_classes), where num_classes is the last dimension of
|
||||
`predictions`. If class_id is outside this range, the method returns NAN.
|
||||
|
|
@ -2316,6 +2314,7 @@ def _sparse_recall_at_top_k(labels,
|
|||
with ops.name_scope(name,
|
||||
_at_k_name('recall', k, class_id=class_id),
|
||||
(predictions_idx, labels, weights)) as scope:
|
||||
labels = _maybe_expand_labels(labels, predictions_idx)
|
||||
top_k_idx = math_ops.to_int64(predictions_idx)
|
||||
tp, tp_update = _streaming_sparse_true_positive_at_k(
|
||||
predictions_idx=top_k_idx, labels=labels, k=k, class_id=class_id,
|
||||
|
|
|
|||
|
|
@ -84,6 +84,10 @@ tf_module {
|
|||
name: "recall_at_thresholds"
|
||||
argspec: "args=[\'labels\', \'predictions\', \'thresholds\', \'weights\', \'metrics_collections\', \'updates_collections\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "recall_at_top_k"
|
||||
argspec: "args=[\'labels\', \'predictions_idx\', \'k\', \'class_id\', \'weights\', \'metrics_collections\', \'updates_collections\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "root_mean_squared_error"
|
||||
argspec: "args=[\'labels\', \'predictions\', \'weights\', \'metrics_collections\', \'updates_collections\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user