Swap the order of NanTensorHook and custom hooks

to ensure that when the training encounteres NaN's in the loss function, user-supplied hooks such as tf_debug.LocalCLIDebugHook can still be used to debug the root cause of the numeric issues.

PiperOrigin-RevId: 158310249
This commit is contained in:
Shanqing Cai 2017-06-07 13:37:02 -07:00 committed by TensorFlower Gardener
parent 599727c654
commit 38249d6be2
3 changed files with 44 additions and 2 deletions

View File

@ -957,6 +957,7 @@ class BaseEstimator(
self._check_inputs(features, labels)
model_fn_ops = self._get_train_ops(features, labels)
ops.add_to_collection(ops.GraphKeys.LOSSES, model_fn_ops.loss)
all_hooks.extend(hooks)
all_hooks.extend([
basic_session_run_hooks.NanTensorHook(model_fn_ops.loss),
basic_session_run_hooks.LoggingTensorHook(
@ -966,7 +967,6 @@ class BaseEstimator(
},
every_n_iter=100)
])
all_hooks.extend(hooks)
scaffold = model_fn_ops.scaffold or monitored_session.Scaffold()
if not (scaffold.saver or ops.get_collection(ops.GraphKeys.SAVERS)):

View File

@ -589,6 +589,7 @@ class Estimator(object):
estimator_spec = self._call_model_fn(features, labels,
model_fn_lib.ModeKeys.TRAIN)
ops.add_to_collection(ops.GraphKeys.LOSSES, estimator_spec.loss)
all_hooks.extend(hooks)
all_hooks.extend([
training.NanTensorHook(estimator_spec.loss),
training.LoggingTensorHook(
@ -598,7 +599,6 @@ class Estimator(object):
},
every_n_iter=100)
])
all_hooks.extend(hooks)
all_hooks.extend(estimator_spec.training_hooks)
if not (estimator_spec.scaffold.saver or

View File

@ -55,6 +55,7 @@ from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.saved_model import loader
from tensorflow.python.saved_model import tag_constants
from tensorflow.python.summary.writer import writer_cache
from tensorflow.python.training import basic_session_run_hooks
from tensorflow.python.training import checkpoint_state_pb2
from tensorflow.python.training import saver
from tensorflow.python.training import saver_test_utils
@ -1520,6 +1521,47 @@ class EstimatorExportTest(test.TestCase):
est.export_savedmodel(tempfile.mkdtemp(), serving_input_receiver_fn)
class EstimatorHookOrderingTest(test.TestCase):
def testCustomHooksAreCalledBeforeNanTensorHook(self):
def nan_making_model_fn(mode, features, labels):
"""A graph that generates NaN's for testing."""
del features, labels
global_step = variables.Variable(
0, dtype=dtypes.int64, name='global_step')
inc_global_step = state_ops.assign_add(global_step, 1)
nan_const = constant_op.constant(np.nan, dtype=dtypes.float32)
loss = control_flow_ops.cond(
inc_global_step > 1, lambda: nan_const, lambda: 1.0)
return model_fn_lib.EstimatorSpec(
mode=mode,
predictions=global_step.read_value(),
loss=loss,
train_op=inc_global_step)
def empty_input_fn():
return dict(), None
class AfterRunCountingHook(session_run_hook.SessionRunHook):
"""Hooks that counts the number of times after_run() is called."""
def __init__(self):
self.after_run_count = 0
def after_run(self, run_context, run_values):
del run_context, run_values
self.after_run_count += 1
test_hook = AfterRunCountingHook()
est = estimator.Estimator(model_fn=nan_making_model_fn)
with self.assertRaises(basic_session_run_hooks.NanLossDuringTrainingError):
est.train(input_fn=empty_input_fn, steps=2, hooks=[test_hook])
self.assertEqual(2, test_hook.after_run_count)
class EstimatorIntegrationTest(test.TestCase):
def test_complete_flow_with_a_simple_linear_model(self):