Adapt TensorFlowTestCase.setUp() to new reset_default_graph() semantics

Avoid calling reset_default_graph() directly to prevent exceptions in
cases where test methods error out from within nested graph contexts,
which can leave _default_graph_stack non-empty in certain Python
versions.
This commit is contained in:
Shanqing Cai 2017-07-06 23:25:44 -04:00
parent 1e037850f1
commit 01383b946e

View File

@ -266,7 +266,12 @@ class TensorFlowTestCase(googletest.TestCase):
self._ClearCachedSession()
random.seed(random_seed.DEFAULT_GRAPH_SEED)
np.random.seed(random_seed.DEFAULT_GRAPH_SEED)
ops.reset_default_graph()
# Note: we avoid calling ops.reset_default_graph() here due to the fact that
# under certain Python versions, test methods that error out from within
# nested graph contexts may leave ops._default_graph_stack non-empty,
# which would cause ops.reset_default_graph() to throw an exception if it
# were used in the following line.
ops._default_graph_stack.reset() # pylint: disable=protected-access
ops.get_default_graph().seed = random_seed.DEFAULT_GRAPH_SEED
def tearDown(self):