Added get variable utils to tf.estimator.Estimator.

PiperOrigin-RevId: 171052121
This commit is contained in:
Mustafa Ispir 2017-10-04 13:14:04 -07:00 committed by TensorFlower Gardener
parent 3cf41b2edd
commit ad69076ebd
9 changed files with 128 additions and 0 deletions

View File

@ -204,6 +204,34 @@ class Estimator(object):
return public_model_fn
# TODO(ispir): support a list of names
def get_variable_value(self, name):
"""Returns value of the variable given by name.
Args:
name: string or a list of string, name of the tensor.
Returns:
Numpy array - value of the tensor.
Raises:
ValueError: If the Estimator has not produced a checkpoint yet.
"""
_check_checkpoint_available(self.model_dir)
return training.load_variable(self.model_dir, name)
def get_variable_names(self):
"""Returns list of all variable names in this model.
Returns:
List of names.
Raises:
ValueError: If the Estimator has not produced a checkpoint yet.
"""
_check_checkpoint_available(self.model_dir)
return [name for name, _ in training.list_variables(self.model_dir)]
def latest_checkpoint(self):
"""Finds the filename of latest saved checkpoint file in `model_dir`.
@ -818,6 +846,13 @@ class Estimator(object):
return eval_results
def _check_checkpoint_available(model_dir):
latest_path = saver.latest_checkpoint(model_dir)
if not latest_path:
raise ValueError(
'Could not find trained model in model_dir: {}.'.format(model_dir))
def _check_hooks_type(hooks):
"""Returns hooks if all are SessionRunHook, raises TypeError otherwise."""
hooks = list(hooks or [])

View File

@ -862,6 +862,43 @@ class _StepCounterHook(session_run_hook.SessionRunHook):
return self._steps
class EstimatorGetVariablesTest(test.TestCase):
def test_model_should_be_trained(self):
def _model_fn(features, labels, mode):
_, _ = features, labels
variables.Variable(1., name='one')
return model_fn_lib.EstimatorSpec(
mode=mode,
loss=constant_op.constant(0.),
train_op=state_ops.assign_add(training.get_global_step(), 1))
est = estimator.Estimator(model_fn=_model_fn)
with self.assertRaisesRegexp(ValueError, 'not find trained model'):
est.get_variable_names()
with self.assertRaisesRegexp(ValueError, 'not find trained model'):
est.get_variable_value('one')
def test_get_variable_utils(self):
def _model_fn(features, labels, mode):
_, _ = features, labels
variables.Variable(1., name='one')
variables.Variable(3., name='three')
return model_fn_lib.EstimatorSpec(
mode=mode,
loss=constant_op.constant(0.),
train_op=state_ops.assign_add(training.get_global_step(), 1))
est = estimator.Estimator(model_fn=_model_fn)
est.train(input_fn=dummy_input_fn, steps=1)
self.assertEqual(
set(['one', 'three', 'global_step']), set(est.get_variable_names()))
self.assertEqual(1., est.get_variable_value('one'))
self.assertEqual(3., est.get_variable_value('three'))
class EstimatorEvaluateTest(test.TestCase):
def test_input_fn_args(self):

View File

@ -31,6 +31,14 @@ tf_class {
name: "export_savedmodel"
argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
}
member_method {
name: "get_variable_names"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_variable_value"
argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "latest_checkpoint"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"

View File

@ -31,6 +31,14 @@ tf_class {
name: "export_savedmodel"
argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
}
member_method {
name: "get_variable_names"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_variable_value"
argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "latest_checkpoint"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"

View File

@ -31,6 +31,14 @@ tf_class {
name: "export_savedmodel"
argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
}
member_method {
name: "get_variable_names"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_variable_value"
argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "latest_checkpoint"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"

View File

@ -31,6 +31,14 @@ tf_class {
name: "export_savedmodel"
argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
}
member_method {
name: "get_variable_names"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_variable_value"
argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "latest_checkpoint"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"

View File

@ -30,6 +30,14 @@ tf_class {
name: "export_savedmodel"
argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
}
member_method {
name: "get_variable_names"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_variable_value"
argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "latest_checkpoint"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"

View File

@ -31,6 +31,14 @@ tf_class {
name: "export_savedmodel"
argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
}
member_method {
name: "get_variable_names"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_variable_value"
argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "latest_checkpoint"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"

View File

@ -31,6 +31,14 @@ tf_class {
name: "export_savedmodel"
argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
}
member_method {
name: "get_variable_names"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_variable_value"
argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "latest_checkpoint"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"