Test asserts should have expected first.

PiperOrigin-RevId: 163409348
This commit is contained in:
A. Unique TensorFlower 2017-07-27 16:54:22 -07:00 committed by TensorFlower Gardener
parent d5cc143e27
commit 905abb1f9d
2 changed files with 28 additions and 28 deletions

View File

@ -260,10 +260,10 @@ class ExponentialMovingAverageTest(test.TestCase):
ema.average_name(tensor2) + "/biased",
ema.average_name(tensor2) + "/local_step"
]
self.assertEqual(sorted(vars_to_restore.keys()), sorted(expected_names))
self.assertEqual(ema.average_name(v0), ema.average(v0).op.name)
self.assertEqual(ema.average_name(v1), ema.average(v1).op.name)
self.assertEqual(ema.average_name(tensor2), ema.average(tensor2).op.name)
self.assertEqual(sorted(expected_names), sorted(vars_to_restore.keys()))
self.assertEqual(ema.average(v0).op.name, ema.average_name(v0))
self.assertEqual(ema.average(v1).op.name, ema.average_name(v1))
self.assertEqual(ema.average(tensor2).op.name, ema.average_name(tensor2))
def testAverageVariablesNames(self):
self.averageVariablesNamesHelper(zero_debias=True)
@ -307,11 +307,11 @@ class ExponentialMovingAverageTest(test.TestCase):
sc + ema.average_name(tensor2) + "/local_step"
]
self.assertEqual(sorted(vars_to_restore.keys()), sorted(expected_names))
self.assertEqual(ema.average_name(v0), ema.average(v0).op.name)
self.assertEqual(ema.average_name(v1), ema.average(v1).op.name)
self.assertEqual(sorted(expected_names), sorted(vars_to_restore.keys()))
self.assertEqual(ema.average(v0).op.name, ema.average_name(v0))
self.assertEqual(ema.average(v1).op.name, ema.average_name(v1))
self.assertEqual(
ema.average_name(tensor2), ema.average(tensor2).op.name)
ema.average(tensor2).op.name, ema.average_name(tensor2))
def testAverageVariablesNamesRespectScope(self):
self.averageVariablesNamesRespectScopeHelper(zero_debias=True)
@ -343,9 +343,9 @@ class ExponentialMovingAverageTest(test.TestCase):
v2.op.name
]))
ema.apply([v0, v1, tensor2])
self.assertEqual(ema.average_name(v0), ema.average(v0).op.name)
self.assertEqual(ema.average_name(v1), ema.average(v1).op.name)
self.assertEqual(ema.average_name(tensor2), ema.average(tensor2).op.name)
self.assertEqual(ema.average(v0).op.name, ema.average_name(v0))
self.assertEqual(ema.average(v1).op.name, ema.average_name(v1))
self.assertEqual(ema.average(tensor2).op.name, ema.average_name(tensor2))
def testAverageVariablesDeviceAssignment(self):
with ops.device("/job:dev_v0"):

View File

@ -36,10 +36,10 @@ class SlotCreatorTest(test.TestCase):
variables.global_variables_initializer().run()
self.assertEqual(slot.op.name, "var/slot")
self.assertEqual(slot.get_shape().as_list(), [2])
self.assertEqual(slot.dtype.base_dtype, dtypes.float32)
self.assertAllEqual(slot.eval(), [1.0, 2.5])
self.assertEqual("var/slot", slot.op.name)
self.assertEqual([2], slot.get_shape().as_list())
self.assertEqual(dtypes.float32, slot.dtype.base_dtype)
self.assertAllEqual([1.0, 2.5], slot.eval())
def testCreateSlotFromTensor(self):
with self.test_session():
@ -48,10 +48,10 @@ class SlotCreatorTest(test.TestCase):
variables.global_variables_initializer().run()
self.assertEqual(slot.op.name, "const/slot")
self.assertEqual(slot.get_shape().as_list(), [2])
self.assertEqual(slot.dtype.base_dtype, dtypes.float32)
self.assertAllEqual(slot.eval(), [2.0, 5.0])
self.assertEqual("const/slot", slot.op.name)
self.assertEqual([2], slot.get_shape().as_list())
self.assertEqual(dtypes.float32, slot.dtype.base_dtype)
self.assertAllEqual([2.0, 5.0], slot.eval())
def testCreateZerosSlotFromVariable(self):
with self.test_session():
@ -62,10 +62,10 @@ class SlotCreatorTest(test.TestCase):
variables.global_variables_initializer().run()
self.assertEqual(slot.op.name, "var/slot")
self.assertEqual(slot.get_shape().as_list(), [2])
self.assertEqual(slot.dtype.base_dtype, dtypes.float64)
self.assertAllEqual(slot.eval(), [0.0, 0.0])
self.assertEqual("var/slot", slot.op.name)
self.assertEqual([2], slot.get_shape().as_list())
self.assertEqual(dtypes.float64, slot.dtype.base_dtype)
self.assertAllEqual([0.0, 0.0], slot.eval())
def testCreateZerosSlotFromTensor(self):
with self.test_session():
@ -75,10 +75,10 @@ class SlotCreatorTest(test.TestCase):
variables.global_variables_initializer().run()
self.assertEqual(slot.op.name, "const/slot")
self.assertEqual(slot.get_shape().as_list(), [2])
self.assertEqual(slot.dtype.base_dtype, dtypes.float32)
self.assertAllEqual(slot.eval(), [0.0, 0.0])
self.assertEqual("const/slot", slot.op.name)
self.assertEqual([2], slot.get_shape().as_list())
self.assertEqual(dtypes.float32, slot.dtype.base_dtype)
self.assertAllEqual([0.0, 0.0], slot.eval())
def testCreateSlotFromVariableRespectsScope(self):
# See discussion on #2740.
@ -86,7 +86,7 @@ class SlotCreatorTest(test.TestCase):
with variable_scope.variable_scope("scope"):
v = variables.Variable([1.0, 2.5], name="var")
slot = slot_creator.create_slot(v, v.initialized_value(), name="slot")
self.assertEqual(slot.op.name, "scope/scope/var/slot")
self.assertEqual("scope/scope/var/slot", slot.op.name)
if __name__ == "__main__":