Fix the expand_dim for label and weight for classifier heads.

PiperOrigin-RevId: 157524909

parent 346021ab4a
commit bbeaa1307c
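Both files below route labels and weights through a _maybe_expand_dim helper whose definition is not part of the hunks shown here. As a rough sketch of what such a helper does, assuming only the public TF 1.x API (the committed implementation may differ in details such as name scoping):

    import tensorflow as tf  # TF 1.x, matching the estimator code below

    def _maybe_expand_dim(tensor):
      """Sketch: append a size-1 dimension to a statically rank-1 tensor."""
      # (batch_size,) -> (batch_size, 1), so labels/weights line up with the
      # rank-2 logits; higher-rank tensors pass through unchanged.
      if tensor.shape.ndims == 1:
        return tf.expand_dims(tensor, -1)
      return tensor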
@@ -417,7 +417,7 @@ class _MultiClassHeadWithSoftmaxCrossEntropyLoss(_Head):
           })

     # Eval.
-    label_ids = self._label_ids(_check_labels(labels, 1))
+    label_ids = self._label_ids(_check_labels(_maybe_expand_dim(labels), 1))

     unweighted_loss = losses.sparse_softmax_cross_entropy(
         labels=label_ids, logits=logits, reduction=losses.Reduction.NONE)
@@ -426,7 +426,7 @@ class _MultiClassHeadWithSoftmaxCrossEntropyLoss(_Head):
     weights = (
         1. if (self._weight_feature_key is None) else
         features[self._weight_feature_key])
-    weights = math_ops.to_float(weights, name='weights')
+    weights = _maybe_expand_dim(math_ops.to_float(weights, name='weights'))
     training_loss = losses.compute_weighted_loss(
         unweighted_loss, weights=weights, reduction=losses.Reduction.SUM)
     if mode == model_fn.ModeKeys.EVAL:
@@ -577,13 +577,14 @@ class _BinaryLogisticHeadWithSigmoidCrossEntropyLoss(_Head):
             classes=string_ops.as_string(classes, name='str_classes'))})

     # Eval.
-    labels = _check_labels(math_ops.to_float(labels), self.logits_dimension)
+    labels = _check_labels(_maybe_expand_dim(math_ops.to_float(labels)),
+                           self.logits_dimension)
     unweighted_loss = nn.sigmoid_cross_entropy_with_logits(
         labels=labels, logits=logits, name='loss')
     weights = (
         1. if (self._weight_feature_key is None) else
         features[self._weight_feature_key])
-    weights = math_ops.to_float(weights, name='weights')
+    weights = _maybe_expand_dim(math_ops.to_float(weights, name='weights'))
     training_loss = losses.compute_weighted_loss(
         unweighted_loss, weights=weights, reduction=losses.Reduction.SUM)
     if mode == model_fn.ModeKeys.EVAL:
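Why the expansion matters: when a rank-1 weights (or labels) vector meets the rank-2 per-example loss, broadcasting produces a full (batch, batch) matrix instead of a per-example product, silently inflating the summed loss. A numpy illustration of that hazard, using the numbers from the binary-head test below:

    import numpy as np

    # Per-example sigmoid cross-entropy losses, shape (3, 1) like the logits.
    losses = np.array([[0.], [41.], [44.]])
    weights_rank_1 = np.array([1., .1, 1.5])     # shape (3,)

    # (3, 1) * (3,) broadcasts to (3, 3): every loss meets every weight.
    print((losses * weights_rank_1).shape)       # (3, 3)
    print((losses * weights_rank_1).sum())       # ~221.0, not the intended 70.1

    # Expanding weights to (3, 1) restores the per-example pairing.
    weights = weights_rank_1[:, np.newaxis]      # shape (3, 1)
    print((losses * weights).sum())              # ~70.1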
@@ -531,6 +531,61 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase):
           metric_keys.MetricKeys.LOSS_MEAN: expected_loss / 2,
       }, summary_str, tol)

+  def test_train_with_one_dim_label_and_weights(self):
+    n_classes = 3
+    head = head_lib._multi_class_head_with_softmax_cross_entropy_loss(
+        n_classes, weight_feature_key='label_weights')
+
+    logits = np.array(((10, 0, 0), (0, 10, 0), (0, 0, 10),), dtype=np.float32)
+    labels_rank_1 = np.array((1, 2, 2,), dtype=np.int64)
+    weights_rank_1 = np.array((1., 2., 3.,), dtype=np.float64)
+
+    self.assertEqual((3,), labels_rank_1.shape)
+    self.assertEqual((3,), weights_rank_1.shape)
+
+    expected_train_result = 'my_train_op'
+    # loss = sum(cross_entropy(labels, logits) * [1, 2, 3])
+    #      = sum([10, 10, 0] * [1, 2, 3]) = 30
+    expected_loss = 30.
+
+    def _train_op_fn(loss):
+      return string_ops.string_join(
+          [constant_op.constant(expected_train_result),
+           string_ops.as_string(loss, precision=2)])
+
+    spec = head.create_estimator_spec(
+        features={
+            'x': np.array(((42,),), dtype=np.float32),
+            'label_weights': weights_rank_1,
+        },
+        mode=model_fn.ModeKeys.TRAIN,
+        logits=logits,
+        labels=labels_rank_1,
+        train_op_fn=_train_op_fn)
+
+    self.assertIsNotNone(spec.loss)
+    self.assertEqual({}, spec.eval_metric_ops)
+    self.assertIsNotNone(spec.train_op)
+    self.assertIsNone(spec.export_outputs)
+    _assert_no_hooks(self, spec)
+
+    # Assert predictions, loss, train_op, and summaries.
+    tol = 1e-2
+    with self.test_session() as sess:
+      _initialize_variables(self, spec.scaffold)
+      self.assertIsNotNone(spec.scaffold.summary_op)
+      loss, train_result, summary_str = sess.run((spec.loss, spec.train_op,
+                                                  spec.scaffold.summary_op))
+      self.assertAllClose(expected_loss, loss, rtol=tol, atol=tol)
+      self.assertEqual(
+          six.b('{0:s}{1:.2f}'.format(expected_train_result, expected_loss)),
+          train_result)
+      _assert_simple_summaries(self, {
+          metric_keys.MetricKeys.LOSS: expected_loss,
+          metric_keys.MetricKeys.LOSS_MEAN: (
+              expected_loss / np.sum(weights_rank_1)),
+      }, summary_str, tol)
+
   def test_train_with_vocabulary(self):
     n_classes = 3
     head = head_lib._multi_class_head_with_softmax_cross_entropy_loss(
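As a sanity check on the comment's arithmetic (not part of the commit), the expected loss of 30 reproduces in plain numpy:

    import numpy as np

    def softmax_xent(logits, label):
      # -log(softmax(logits)[label]), computed with the max-shift for stability.
      shifted = logits - logits.max()
      return np.log(np.exp(shifted).sum()) - shifted[label]

    logits = np.array([[10., 0., 0.], [0., 10., 0.], [0., 0., 10.]])
    labels = [1, 2, 2]
    weights = [1., 2., 3.]

    per_example = [softmax_xent(row, y) for row, y in zip(logits, labels)]
    print(np.round(per_example, 2))                          # ~[10., 10., 0.]
    print(sum(w * l for w, l in zip(weights, per_example)))  # ~30.0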
@@ -989,6 +1044,57 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase):
       self.assertAllClose(
           expected_metrics, {k: value_ops[k].eval() for k in value_ops})

+  def test_train_with_one_dim_labels_and_weights(self):
+    """3 examples, 1 batch."""
+    head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss(
+        weight_feature_key='label_weights')
+
+    # Create estimator spec.
+    logits = np.array(((45,), (-41,), (44,)), dtype=np.float32)
+    labels_rank_1 = np.array((1., 1., 0.,))
+    weights_rank_1 = np.array((1., .1, 1.5,), dtype=np.float64)
+    self.assertEqual((3,), labels_rank_1.shape)
+    self.assertEqual((3,), weights_rank_1.shape)
+
+    expected_train_result = b'my_train_op'
+    # losses = label_weights * cross_entropy(labels, logits)
+    #        = (1*0, .1*41, 1.5*44) = (0, 4.1, 66)
+    # loss = sum(losses) = 0 + 4.1 + 66 = 70.1
+    expected_loss = 70.1
+    def _train_op_fn(loss):
+      with ops.control_dependencies((check_ops.assert_equal(
+          math_ops.to_float(expected_loss), math_ops.to_float(loss),
+          name='assert_loss'),)):
+        return constant_op.constant(expected_train_result)
+    spec = head.create_estimator_spec(
+        features={
+            'x': np.array(((42.,), (43.,), (44.,)), dtype=np.float32),
+            'label_weights': weights_rank_1,
+        },
+        mode=model_fn.ModeKeys.TRAIN,
+        logits=logits,
+        labels=labels_rank_1,
+        train_op_fn=_train_op_fn)
+
+    # Assert spec contains expected tensors.
+    self.assertIsNotNone(spec.loss)
+    self.assertIsNotNone(spec.train_op)
+
+    # Assert predictions, loss, and metrics.
+    with self.test_session() as sess:
+      _initialize_variables(self, spec.scaffold)
+      self.assertIsNotNone(spec.scaffold.summary_op)
+      loss, train_result, summary_str = sess.run((
+          spec.loss, spec.train_op, spec.scaffold.summary_op))
+      self.assertAllClose(expected_loss, loss)
+      self.assertEqual(expected_train_result, train_result)
+      _assert_simple_summaries(self, {
+          metric_keys.MetricKeys.LOSS: expected_loss,
+          # loss_mean = loss/sum(label_weights) = 70.1/(1 + .1 + 1.5)
+          #           = 70.1/2.6 = 26.9615384615
+          metric_keys.MetricKeys.LOSS_MEAN: 26.9615384615,
+      }, summary_str)
+
   def test_weighted_multi_example_train(self):
     """3 examples, 1 batch."""
     head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss(
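The weighted-loss arithmetic in the comment above can be double-checked outside TensorFlow (a sanity sketch, not part of the commit), using the numerically stable sigmoid cross-entropy form:

    import numpy as np

    def sigmoid_xent(logit, label):
      # Stable form of -label*log(sigmoid(x)) - (1 - label)*log(1 - sigmoid(x)).
      return max(logit, 0.) - logit * label + np.log1p(np.exp(-abs(logit)))

    logits = [45., -41., 44.]
    labels = [1., 1., 0.]
    weights = [1., .1, 1.5]

    losses = [w * sigmoid_xent(x, z) for w, x, z in zip(weights, logits, labels)]
    print(np.round(losses, 4))                    # ~[0., 4.1, 66.]
    print(round(sum(losses), 4))                  # ~70.1
    print(round(sum(losses) / sum(weights), 10))  # ~26.9615384615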