mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 00:19:58 +01:00
Test asserts should have expected first.
PiperOrigin-RevId: 163409348
This commit is contained in:
parent
d5cc143e27
commit
905abb1f9d
|
|
@ -260,10 +260,10 @@ class ExponentialMovingAverageTest(test.TestCase):
|
|||
ema.average_name(tensor2) + "/biased",
|
||||
ema.average_name(tensor2) + "/local_step"
|
||||
]
|
||||
self.assertEqual(sorted(vars_to_restore.keys()), sorted(expected_names))
|
||||
self.assertEqual(ema.average_name(v0), ema.average(v0).op.name)
|
||||
self.assertEqual(ema.average_name(v1), ema.average(v1).op.name)
|
||||
self.assertEqual(ema.average_name(tensor2), ema.average(tensor2).op.name)
|
||||
self.assertEqual(sorted(expected_names), sorted(vars_to_restore.keys()))
|
||||
self.assertEqual(ema.average(v0).op.name, ema.average_name(v0))
|
||||
self.assertEqual(ema.average(v1).op.name, ema.average_name(v1))
|
||||
self.assertEqual(ema.average(tensor2).op.name, ema.average_name(tensor2))
|
||||
|
||||
def testAverageVariablesNames(self):
|
||||
self.averageVariablesNamesHelper(zero_debias=True)
|
||||
|
|
@ -307,11 +307,11 @@ class ExponentialMovingAverageTest(test.TestCase):
|
|||
sc + ema.average_name(tensor2) + "/local_step"
|
||||
]
|
||||
|
||||
self.assertEqual(sorted(vars_to_restore.keys()), sorted(expected_names))
|
||||
self.assertEqual(ema.average_name(v0), ema.average(v0).op.name)
|
||||
self.assertEqual(ema.average_name(v1), ema.average(v1).op.name)
|
||||
self.assertEqual(sorted(expected_names), sorted(vars_to_restore.keys()))
|
||||
self.assertEqual(ema.average(v0).op.name, ema.average_name(v0))
|
||||
self.assertEqual(ema.average(v1).op.name, ema.average_name(v1))
|
||||
self.assertEqual(
|
||||
ema.average_name(tensor2), ema.average(tensor2).op.name)
|
||||
ema.average(tensor2).op.name, ema.average_name(tensor2))
|
||||
|
||||
def testAverageVariablesNamesRespectScope(self):
|
||||
self.averageVariablesNamesRespectScopeHelper(zero_debias=True)
|
||||
|
|
@ -343,9 +343,9 @@ class ExponentialMovingAverageTest(test.TestCase):
|
|||
v2.op.name
|
||||
]))
|
||||
ema.apply([v0, v1, tensor2])
|
||||
self.assertEqual(ema.average_name(v0), ema.average(v0).op.name)
|
||||
self.assertEqual(ema.average_name(v1), ema.average(v1).op.name)
|
||||
self.assertEqual(ema.average_name(tensor2), ema.average(tensor2).op.name)
|
||||
self.assertEqual(ema.average(v0).op.name, ema.average_name(v0))
|
||||
self.assertEqual(ema.average(v1).op.name, ema.average_name(v1))
|
||||
self.assertEqual(ema.average(tensor2).op.name, ema.average_name(tensor2))
|
||||
|
||||
def testAverageVariablesDeviceAssignment(self):
|
||||
with ops.device("/job:dev_v0"):
|
||||
|
|
|
|||
|
|
@ -36,10 +36,10 @@ class SlotCreatorTest(test.TestCase):
|
|||
|
||||
variables.global_variables_initializer().run()
|
||||
|
||||
self.assertEqual(slot.op.name, "var/slot")
|
||||
self.assertEqual(slot.get_shape().as_list(), [2])
|
||||
self.assertEqual(slot.dtype.base_dtype, dtypes.float32)
|
||||
self.assertAllEqual(slot.eval(), [1.0, 2.5])
|
||||
self.assertEqual("var/slot", slot.op.name)
|
||||
self.assertEqual([2], slot.get_shape().as_list())
|
||||
self.assertEqual(dtypes.float32, slot.dtype.base_dtype)
|
||||
self.assertAllEqual([1.0, 2.5], slot.eval())
|
||||
|
||||
def testCreateSlotFromTensor(self):
|
||||
with self.test_session():
|
||||
|
|
@ -48,10 +48,10 @@ class SlotCreatorTest(test.TestCase):
|
|||
|
||||
variables.global_variables_initializer().run()
|
||||
|
||||
self.assertEqual(slot.op.name, "const/slot")
|
||||
self.assertEqual(slot.get_shape().as_list(), [2])
|
||||
self.assertEqual(slot.dtype.base_dtype, dtypes.float32)
|
||||
self.assertAllEqual(slot.eval(), [2.0, 5.0])
|
||||
self.assertEqual("const/slot", slot.op.name)
|
||||
self.assertEqual([2], slot.get_shape().as_list())
|
||||
self.assertEqual(dtypes.float32, slot.dtype.base_dtype)
|
||||
self.assertAllEqual([2.0, 5.0], slot.eval())
|
||||
|
||||
def testCreateZerosSlotFromVariable(self):
|
||||
with self.test_session():
|
||||
|
|
@ -62,10 +62,10 @@ class SlotCreatorTest(test.TestCase):
|
|||
|
||||
variables.global_variables_initializer().run()
|
||||
|
||||
self.assertEqual(slot.op.name, "var/slot")
|
||||
self.assertEqual(slot.get_shape().as_list(), [2])
|
||||
self.assertEqual(slot.dtype.base_dtype, dtypes.float64)
|
||||
self.assertAllEqual(slot.eval(), [0.0, 0.0])
|
||||
self.assertEqual("var/slot", slot.op.name)
|
||||
self.assertEqual([2], slot.get_shape().as_list())
|
||||
self.assertEqual(dtypes.float64, slot.dtype.base_dtype)
|
||||
self.assertAllEqual([0.0, 0.0], slot.eval())
|
||||
|
||||
def testCreateZerosSlotFromTensor(self):
|
||||
with self.test_session():
|
||||
|
|
@ -75,10 +75,10 @@ class SlotCreatorTest(test.TestCase):
|
|||
|
||||
variables.global_variables_initializer().run()
|
||||
|
||||
self.assertEqual(slot.op.name, "const/slot")
|
||||
self.assertEqual(slot.get_shape().as_list(), [2])
|
||||
self.assertEqual(slot.dtype.base_dtype, dtypes.float32)
|
||||
self.assertAllEqual(slot.eval(), [0.0, 0.0])
|
||||
self.assertEqual("const/slot", slot.op.name)
|
||||
self.assertEqual([2], slot.get_shape().as_list())
|
||||
self.assertEqual(dtypes.float32, slot.dtype.base_dtype)
|
||||
self.assertAllEqual([0.0, 0.0], slot.eval())
|
||||
|
||||
def testCreateSlotFromVariableRespectsScope(self):
|
||||
# See discussion on #2740.
|
||||
|
|
@ -86,7 +86,7 @@ class SlotCreatorTest(test.TestCase):
|
|||
with variable_scope.variable_scope("scope"):
|
||||
v = variables.Variable([1.0, 2.5], name="var")
|
||||
slot = slot_creator.create_slot(v, v.initialized_value(), name="slot")
|
||||
self.assertEqual(slot.op.name, "scope/scope/var/slot")
|
||||
self.assertEqual("scope/scope/var/slot", slot.op.name)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user