Swap the order of NanTensorHook and custom hooks

to ensure that when training encounters NaNs in the loss function, user-supplied hooks such as tf_debug.LocalCLIDebugHook can still be used to debug the root cause of the numeric issues.

PiperOrigin-RevId: 158310249
commit 38249d6be2
parent 599727c654
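For context, a minimal sketch (not part of this commit) of the scenario the commit message describes: a user-supplied debugging hook is passed to Estimator.train() through the hooks argument, and with this change it is added to the training hooks ahead of NanTensorHook, so it still observes the step on which the loss becomes NaN. The names my_model_fn and my_input_fn below are hypothetical placeholders, not names from this change.

# Sketch only: my_model_fn and my_input_fn are hypothetical placeholders for a
# real model_fn / input_fn; tf_debug.LocalCLIDebugHook is the user-supplied
# hook mentioned in the commit message.
import tensorflow as tf
from tensorflow.python import debug as tf_debug

est = tf.estimator.Estimator(model_fn=my_model_fn)
debug_hook = tf_debug.LocalCLIDebugHook()
# The hook now precedes NanTensorHook in all_hooks, so its after_run() still
# fires on the run call whose loss turns NaN.
est.train(input_fn=my_input_fn, steps=100, hooks=[debug_hook])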
@@ -957,6 +957,7 @@ class BaseEstimator(
       self._check_inputs(features, labels)
       model_fn_ops = self._get_train_ops(features, labels)
       ops.add_to_collection(ops.GraphKeys.LOSSES, model_fn_ops.loss)
+      all_hooks.extend(hooks)
       all_hooks.extend([
           basic_session_run_hooks.NanTensorHook(model_fn_ops.loss),
           basic_session_run_hooks.LoggingTensorHook(
@@ -966,7 +967,6 @@ class BaseEstimator(
               },
               every_n_iter=100)
       ])
-      all_hooks.extend(hooks)

       scaffold = model_fn_ops.scaffold or monitored_session.Scaffold()
       if not (scaffold.saver or ops.get_collection(ops.GraphKeys.SAVERS)):
@@ -589,6 +589,7 @@ class Estimator(object):
       estimator_spec = self._call_model_fn(features, labels,
                                            model_fn_lib.ModeKeys.TRAIN)
       ops.add_to_collection(ops.GraphKeys.LOSSES, estimator_spec.loss)
+      all_hooks.extend(hooks)
       all_hooks.extend([
           training.NanTensorHook(estimator_spec.loss),
           training.LoggingTensorHook(
@@ -598,7 +599,6 @@ class Estimator(object):
               },
               every_n_iter=100)
       ])
-      all_hooks.extend(hooks)
       all_hooks.extend(estimator_spec.training_hooks)

       if not (estimator_spec.scaffold.saver or
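Why the ordering matters: MonitoredSession invokes each hook's after_run() in list order, and NanTensorHook raises NanLossDuringTrainingError from its own after_run(), so any hook listed after it never sees the failing step. Below is a minimal, standalone sketch (not part of this commit, assuming the public TF 1.x tf.train API) of the behavior that the new EstimatorHookOrderingTest further down verifies through the Estimator API.

# Standalone illustration (assumes the TF 1.x tf.train API); not from the diff.
import numpy as np
import tensorflow as tf


class CountingHook(tf.train.SessionRunHook):
  """Counts after_run() calls, like AfterRunCountingHook in the new test."""

  def __init__(self):
    self.after_run_count = 0

  def after_run(self, run_context, run_values):
    del run_context, run_values
    self.after_run_count += 1


loss = tf.constant(np.nan, dtype=tf.float32)  # a loss that is already NaN
counting_hook = CountingHook()
# User hook listed before NanTensorHook, mirroring the new all_hooks order.
hooks = [counting_hook, tf.train.NanTensorHook(loss)]

try:
  with tf.train.MonitoredSession(hooks=hooks) as sess:
    sess.run(loss)
except tf.train.NanLossDuringTrainingError:
  pass  # raised by NanTensorHook.after_run() on the NaN step

print(counting_hook.after_run_count)  # 1: the user hook still observed the step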
@@ -55,6 +55,7 @@ from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.saved_model import loader
 from tensorflow.python.saved_model import tag_constants
 from tensorflow.python.summary.writer import writer_cache
+from tensorflow.python.training import basic_session_run_hooks
 from tensorflow.python.training import checkpoint_state_pb2
 from tensorflow.python.training import saver
 from tensorflow.python.training import saver_test_utils
@@ -1520,6 +1521,47 @@ class EstimatorExportTest(test.TestCase):
     est.export_savedmodel(tempfile.mkdtemp(), serving_input_receiver_fn)


+class EstimatorHookOrderingTest(test.TestCase):
+
+  def testCustomHooksAreCalledBeforeNanTensorHook(self):
+
+    def nan_making_model_fn(mode, features, labels):
+      """A graph that generates NaN's for testing."""
+      del features, labels
+
+      global_step = variables.Variable(
+          0, dtype=dtypes.int64, name='global_step')
+      inc_global_step = state_ops.assign_add(global_step, 1)
+      nan_const = constant_op.constant(np.nan, dtype=dtypes.float32)
+      loss = control_flow_ops.cond(
+          inc_global_step > 1, lambda: nan_const, lambda: 1.0)
+
+      return model_fn_lib.EstimatorSpec(
+          mode=mode,
+          predictions=global_step.read_value(),
+          loss=loss,
+          train_op=inc_global_step)
+
+    def empty_input_fn():
+      return dict(), None
+
+    class AfterRunCountingHook(session_run_hook.SessionRunHook):
+      """Hooks that counts the number of times after_run() is called."""
+
+      def __init__(self):
+        self.after_run_count = 0
+
+      def after_run(self, run_context, run_values):
+        del run_context, run_values
+        self.after_run_count += 1
+
+    test_hook = AfterRunCountingHook()
+    est = estimator.Estimator(model_fn=nan_making_model_fn)
+    with self.assertRaises(basic_session_run_hooks.NanLossDuringTrainingError):
+      est.train(input_fn=empty_input_fn, steps=2, hooks=[test_hook])
+    self.assertEqual(2, test_hook.after_run_count)
+
+
 class EstimatorIntegrationTest(test.TestCase):

   def test_complete_flow_with_a_simple_linear_model(self):