mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-07 12:20:24 +01:00
Added get variable utils to tf.estimator.Estimator.
PiperOrigin-RevId: 171052121
This commit is contained in:
parent
3cf41b2edd
commit
ad69076ebd
|
|
@ -204,6 +204,34 @@ class Estimator(object):
|
|||
|
||||
return public_model_fn
|
||||
|
||||
# TODO(ispir): support a list of names
|
||||
def get_variable_value(self, name):
|
||||
"""Returns value of the variable given by name.
|
||||
|
||||
Args:
|
||||
name: string or a list of string, name of the tensor.
|
||||
|
||||
Returns:
|
||||
Numpy array - value of the tensor.
|
||||
|
||||
Raises:
|
||||
ValueError: If the Estimator has not produced a checkpoint yet.
|
||||
"""
|
||||
_check_checkpoint_available(self.model_dir)
|
||||
return training.load_variable(self.model_dir, name)
|
||||
|
||||
def get_variable_names(self):
|
||||
"""Returns list of all variable names in this model.
|
||||
|
||||
Returns:
|
||||
List of names.
|
||||
|
||||
Raises:
|
||||
ValueError: If the Estimator has not produced a checkpoint yet.
|
||||
"""
|
||||
_check_checkpoint_available(self.model_dir)
|
||||
return [name for name, _ in training.list_variables(self.model_dir)]
|
||||
|
||||
def latest_checkpoint(self):
|
||||
"""Finds the filename of latest saved checkpoint file in `model_dir`.
|
||||
|
||||
|
|
@ -818,6 +846,13 @@ class Estimator(object):
|
|||
return eval_results
|
||||
|
||||
|
||||
def _check_checkpoint_available(model_dir):
|
||||
latest_path = saver.latest_checkpoint(model_dir)
|
||||
if not latest_path:
|
||||
raise ValueError(
|
||||
'Could not find trained model in model_dir: {}.'.format(model_dir))
|
||||
|
||||
|
||||
def _check_hooks_type(hooks):
|
||||
"""Returns hooks if all are SessionRunHook, raises TypeError otherwise."""
|
||||
hooks = list(hooks or [])
|
||||
|
|
|
|||
|
|
@ -862,6 +862,43 @@ class _StepCounterHook(session_run_hook.SessionRunHook):
|
|||
return self._steps
|
||||
|
||||
|
||||
class EstimatorGetVariablesTest(test.TestCase):
|
||||
|
||||
def test_model_should_be_trained(self):
|
||||
|
||||
def _model_fn(features, labels, mode):
|
||||
_, _ = features, labels
|
||||
variables.Variable(1., name='one')
|
||||
return model_fn_lib.EstimatorSpec(
|
||||
mode=mode,
|
||||
loss=constant_op.constant(0.),
|
||||
train_op=state_ops.assign_add(training.get_global_step(), 1))
|
||||
|
||||
est = estimator.Estimator(model_fn=_model_fn)
|
||||
with self.assertRaisesRegexp(ValueError, 'not find trained model'):
|
||||
est.get_variable_names()
|
||||
with self.assertRaisesRegexp(ValueError, 'not find trained model'):
|
||||
est.get_variable_value('one')
|
||||
|
||||
def test_get_variable_utils(self):
|
||||
|
||||
def _model_fn(features, labels, mode):
|
||||
_, _ = features, labels
|
||||
variables.Variable(1., name='one')
|
||||
variables.Variable(3., name='three')
|
||||
return model_fn_lib.EstimatorSpec(
|
||||
mode=mode,
|
||||
loss=constant_op.constant(0.),
|
||||
train_op=state_ops.assign_add(training.get_global_step(), 1))
|
||||
|
||||
est = estimator.Estimator(model_fn=_model_fn)
|
||||
est.train(input_fn=dummy_input_fn, steps=1)
|
||||
self.assertEqual(
|
||||
set(['one', 'three', 'global_step']), set(est.get_variable_names()))
|
||||
self.assertEqual(1., est.get_variable_value('one'))
|
||||
self.assertEqual(3., est.get_variable_value('three'))
|
||||
|
||||
|
||||
class EstimatorEvaluateTest(test.TestCase):
|
||||
|
||||
def test_input_fn_args(self):
|
||||
|
|
|
|||
|
|
@ -31,6 +31,14 @@ tf_class {
|
|||
name: "export_savedmodel"
|
||||
argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "get_variable_names"
|
||||
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "get_variable_value"
|
||||
argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "latest_checkpoint"
|
||||
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
|
||||
|
|
|
|||
|
|
@ -31,6 +31,14 @@ tf_class {
|
|||
name: "export_savedmodel"
|
||||
argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "get_variable_names"
|
||||
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "get_variable_value"
|
||||
argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "latest_checkpoint"
|
||||
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
|
||||
|
|
|
|||
|
|
@ -31,6 +31,14 @@ tf_class {
|
|||
name: "export_savedmodel"
|
||||
argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "get_variable_names"
|
||||
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "get_variable_value"
|
||||
argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "latest_checkpoint"
|
||||
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
|
||||
|
|
|
|||
|
|
@ -31,6 +31,14 @@ tf_class {
|
|||
name: "export_savedmodel"
|
||||
argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "get_variable_names"
|
||||
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "get_variable_value"
|
||||
argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "latest_checkpoint"
|
||||
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
|
||||
|
|
|
|||
|
|
@ -30,6 +30,14 @@ tf_class {
|
|||
name: "export_savedmodel"
|
||||
argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "get_variable_names"
|
||||
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "get_variable_value"
|
||||
argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "latest_checkpoint"
|
||||
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
|
||||
|
|
|
|||
|
|
@ -31,6 +31,14 @@ tf_class {
|
|||
name: "export_savedmodel"
|
||||
argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "get_variable_names"
|
||||
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "get_variable_value"
|
||||
argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "latest_checkpoint"
|
||||
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
|
||||
|
|
|
|||
|
|
@ -31,6 +31,14 @@ tf_class {
|
|||
name: "export_savedmodel"
|
||||
argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "get_variable_names"
|
||||
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "get_variable_value"
|
||||
argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "latest_checkpoint"
|
||||
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user