mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-07 00:20:20 +01:00
Add tests with different delta to huber_loss.
PiperOrigin-RevId: 158191361
This commit is contained in:
parent
a4e7b7add4
commit
51acad09c1
|
|
@ -845,6 +845,25 @@ class HuberLossTest(test.TestCase):
|
|||
expected_loss = (quadratic + linear) / 2.
|
||||
self.assertAllClose(loss.eval(), expected_loss, atol=1e-5)
|
||||
|
||||
def testAllQuadraticDelta(self):
|
||||
with self.test_session():
|
||||
delta = 0.5
|
||||
predictions = constant_op.constant([1.5, -1.4, -0.5, 0.0])
|
||||
labels = constant_op.constant([1.0, -1.0, 0.0, 0.5])
|
||||
expected = 0.5 * np.array([0.5**2, 0.4**2, 0.5**2, 0.5**2]).mean()
|
||||
loss = losses.huber_loss(labels, predictions, delta=delta)
|
||||
self.assertAllClose(expected, loss.eval(), atol=1e-5)
|
||||
|
||||
def testAllLinearDelta(self):
|
||||
delta = 0.5
|
||||
predictions = constant_op.constant([1.5, -1.4, -1.0, 0.0])
|
||||
labels = constant_op.constant([0.0, 1.0, 0.0, 1.5])
|
||||
expected = delta * np.array([1.5, 2.4, 1.0, 1.5]).mean()
|
||||
expected -= 0.5 * delta**2
|
||||
loss = losses.huber_loss(labels, predictions, delta=delta)
|
||||
with self.test_session():
|
||||
self.assertAllClose(expected, loss.eval(), atol=1e-5)
|
||||
|
||||
|
||||
class MeanSquaredErrorTest(test.TestCase):
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user