mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-07 12:20:24 +01:00
Adds eager execution compatibility note in Estimators.
Raises a RuntimeError in Estimator base class. PiperOrigin-RevId: 173744765
This commit is contained in:
parent
9f4b12bb55
commit
c16797ec36
|
|
@ -259,6 +259,10 @@ class DNNClassifier(estimator.Estimator):
|
|||
whose `value` is a `Tensor`.
|
||||
|
||||
Loss is calculated by using softmax cross entropy.
|
||||
|
||||
@compatibility(eager)
|
||||
Estimators are not compatible with eager execution.
|
||||
@end_compatibility
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
|
|
@ -392,6 +396,10 @@ class DNNRegressor(estimator.Estimator):
|
|||
whose `value` is a `Tensor`.
|
||||
|
||||
Loss is calculated by using mean squared error.
|
||||
|
||||
@compatibility(eager)
|
||||
Estimators are not compatible with eager execution.
|
||||
@end_compatibility
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
|
|
|
|||
|
|
@ -278,6 +278,10 @@ class DNNLinearCombinedClassifier(estimator.Estimator):
|
|||
whose `value` is a `Tensor`.
|
||||
|
||||
Loss is calculated by using softmax cross entropy.
|
||||
|
||||
@compatibility(eager)
|
||||
Estimators are not compatible with eager execution.
|
||||
@end_compatibility
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
|
|
@ -438,6 +442,10 @@ class DNNLinearCombinedRegressor(estimator.Estimator):
|
|||
whose `value` is a `Tensor`.
|
||||
|
||||
Loss is calculated by using mean squared error.
|
||||
|
||||
@compatibility(eager)
|
||||
Estimators are not compatible with eager execution.
|
||||
@end_compatibility
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
|
|
|
|||
|
|
@ -184,6 +184,10 @@ class LinearClassifier(estimator.Estimator):
|
|||
whose `value` is a `Tensor`.
|
||||
|
||||
Loss is calculated by using softmax cross entropy.
|
||||
|
||||
@compatibility(eager)
|
||||
Estimators are not compatible with eager execution.
|
||||
@end_compatibility
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
|
|
@ -300,6 +304,10 @@ class LinearRegressor(estimator.Estimator):
|
|||
key=column.name, value=a `Tensor`
|
||||
|
||||
Loss is calculated by using mean squared error.
|
||||
|
||||
@compatibility(eager)
|
||||
Estimators are not compatible with eager execution.
|
||||
@end_compatibility
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
|
|
|
|||
|
|
@ -29,6 +29,7 @@ import six
|
|||
from tensorflow.core.framework import summary_pb2
|
||||
from tensorflow.core.protobuf import config_pb2
|
||||
from tensorflow.python.client import session as tf_session
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.estimator import model_fn as model_fn_lib
|
||||
from tensorflow.python.estimator import run_config
|
||||
from tensorflow.python.estimator import util
|
||||
|
|
@ -87,6 +88,10 @@ class Estimator(object):
|
|||
None of `Estimator`'s methods can be overridden in subclasses (its
|
||||
constructor enforces this). Subclasses should use `model_fn` to configure
|
||||
the base class, and may add methods implementing specialized functionality.
|
||||
|
||||
@compatibility(eager)
|
||||
Estimators are not compatible with eager execution.
|
||||
@end_compatibility
|
||||
"""
|
||||
|
||||
def __init__(self, model_fn, model_dir=None, config=None, params=None):
|
||||
|
|
@ -129,10 +134,15 @@ class Estimator(object):
|
|||
Keys are names of parameters, values are basic python types.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If eager execution is enabled.
|
||||
ValueError: parameters of `model_fn` don't match `params`.
|
||||
ValueError: if this is called via a subclass and if that class overrides
|
||||
a member of `Estimator`.
|
||||
"""
|
||||
if context.in_eager_mode():
|
||||
raise RuntimeError(
|
||||
'Estimators are not supported when eager execution is enabled.')
|
||||
|
||||
Estimator._assert_members_are_not_overridden(self)
|
||||
|
||||
if config is None:
|
||||
|
|
@ -1016,4 +1026,3 @@ def _has_dataset_or_queue_runner(maybe_tensor):
|
|||
|
||||
# Now, check queue.
|
||||
return ops.get_default_graph().get_collection(ops.GraphKeys.QUEUE_RUNNERS)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user