Adds eager execution compatibility note in Estimators.

Raises a RuntimeError in Estimator base class.

PiperOrigin-RevId: 173744765
This commit is contained in:
A. Unique TensorFlower 2017-10-27 19:26:27 -07:00 committed by TensorFlower Gardener
parent 9f4b12bb55
commit c16797ec36
4 changed files with 34 additions and 1 deletions

View File

@ -259,6 +259,10 @@ class DNNClassifier(estimator.Estimator):
whose `value` is a `Tensor`.
Loss is calculated by using softmax cross entropy.
@compatibility(eager)
Estimators are not compatible with eager execution.
@end_compatibility
"""
def __init__(self,
@ -392,6 +396,10 @@ class DNNRegressor(estimator.Estimator):
whose `value` is a `Tensor`.
Loss is calculated by using mean squared error.
@compatibility(eager)
Estimators are not compatible with eager execution.
@end_compatibility
"""
def __init__(self,

View File

@ -278,6 +278,10 @@ class DNNLinearCombinedClassifier(estimator.Estimator):
whose `value` is a `Tensor`.
Loss is calculated by using softmax cross entropy.
@compatibility(eager)
Estimators are not compatible with eager execution.
@end_compatibility
"""
def __init__(self,
@ -438,6 +442,10 @@ class DNNLinearCombinedRegressor(estimator.Estimator):
whose `value` is a `Tensor`.
Loss is calculated by using mean squared error.
@compatibility(eager)
Estimators are not compatible with eager execution.
@end_compatibility
"""
def __init__(self,

View File

@ -184,6 +184,10 @@ class LinearClassifier(estimator.Estimator):
whose `value` is a `Tensor`.
Loss is calculated by using softmax cross entropy.
@compatibility(eager)
Estimators are not compatible with eager execution.
@end_compatibility
"""
def __init__(self,
@ -300,6 +304,10 @@ class LinearRegressor(estimator.Estimator):
key=column.name, value=a `Tensor`
Loss is calculated by using mean squared error.
@compatibility(eager)
Estimators are not compatible with eager execution.
@end_compatibility
"""
def __init__(self,

View File

@ -29,6 +29,7 @@ import six
from tensorflow.core.framework import summary_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session as tf_session
from tensorflow.python.eager import context
from tensorflow.python.estimator import model_fn as model_fn_lib
from tensorflow.python.estimator import run_config
from tensorflow.python.estimator import util
@ -87,6 +88,10 @@ class Estimator(object):
None of `Estimator`'s methods can be overridden in subclasses (its
constructor enforces this). Subclasses should use `model_fn` to configure
the base class, and may add methods implementing specialized functionality.
@compatibility(eager)
Estimators are not compatible with eager execution.
@end_compatibility
"""
def __init__(self, model_fn, model_dir=None, config=None, params=None):
@ -129,10 +134,15 @@ class Estimator(object):
Keys are names of parameters, values are basic python types.
Raises:
RuntimeError: If eager execution is enabled.
ValueError: parameters of `model_fn` don't match `params`.
ValueError: if this is called via a subclass and if that class overrides
a member of `Estimator`.
"""
if context.in_eager_mode():
raise RuntimeError(
'Estimators are not supported when eager execution is enabled.')
Estimator._assert_members_are_not_overridden(self)
if config is None:
@ -1016,4 +1026,3 @@ def _has_dataset_or_queue_runner(maybe_tensor):
# Now, check queue.
return ops.get_default_graph().get_collection(ops.GraphKeys.QUEUE_RUNNERS)