Mirror of https://github.com/zebrajr/tensorflow.git, synced 2025-12-06 12:20:11 +01:00.
Remove Estimator from TensorFlow.

The change was announced in the TensorFlow 2.14 and 2.15 release notes: https://github.com/tensorflow/tensorflow/releases

PiperOrigin-RevId: 600623585

This commit is contained in:
parent 29362bf8e6
commit aa35dc2761
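
The hunks below remove the Estimator smoke test, the tf-estimator-nightly pins, the lazy-loaded tf.estimator API hooks, and the Estimator-only tests and docs. As an illustrative aside (not part of the commit), a minimal Python sketch of the same presence probe that the removed smoke test performs, assuming TensorFlow is importable in the current environment; after this change the probe is expected to report that the bundled API is gone:

```python
# Illustrative probe, not part of this commit: report whether the bundled
# tf.estimator API is still wired into the installed TensorFlow build.
import tensorflow as tf

has_estimator = hasattr(tf, "estimator") and "estimator" in getattr(
    tf.estimator, "__name__", "")
print("bundled tf.estimator present:", has_estimator)
```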
@@ -71,9 +71,3 @@ teardown_file() {
    source /tf/venv/bin/activate
    python3 -c 'import sys; import tensorflow as tf; sys.exit(0 if "keras" in tf.keras.__name__ else 1)'
}

# Is this still useful?
@test "TensorFlow has Estimator" {
    source /tf/venv/bin/activate
    python3 -c 'import sys; import tensorflow as tf; sys.exit(0 if "_v2.estimator" in tf.estimator.__name__ else 1)'
}
@@ -53,7 +53,6 @@ compile_pip_requirements_3_9(
        "--allow-unsafe",
        "-P keras-nightly",
        "-P tb-nightly",
        "-P tf-estimator-nightly",
    ],
    requirements_in = "requirements.in",
    requirements_txt = "requirements_lock_3_9.txt",
@@ -65,7 +64,6 @@ compile_pip_requirements_3_10(
        "--allow-unsafe",
        "-P keras-nightly",
        "-P tb-nightly",
        "-P tf-estimator-nightly",
    ],
    requirements_in = "requirements.in",
    requirements_txt = "requirements_lock_3_10.txt",
@@ -77,7 +75,6 @@ compile_pip_requirements_3_11(
        "--allow-unsafe",
        "-P keras-nightly",
        "-P tb-nightly",
        "-P tf-estimator-nightly",
    ],
    requirements_in = "requirements.in",
    requirements_txt = "requirements_lock_3_11.txt",
@@ -89,7 +86,6 @@ compile_pip_requirements_3_12(
        "--allow-unsafe",
        "-P keras-nightly",
        "-P tb-nightly",
        "-P tf-estimator-nightly",
    ],
    requirements_in = "requirements.in",
    requirements_txt = "requirements_lock_3_12.txt",
@@ -108,7 +108,6 @@ unless indicated otherwise.
        "--allow-unsafe",
        "-P keras-nightly",
        "-P tb-nightly",
        "-P tf-estimator-nightly",
    ],
    requirements_in = "requirements.in",
    requirements_txt = "requirements_lock_3_11.txt",
@@ -14,13 +14,12 @@ termcolor == 2.3.0
wrapt == 1.14.1
tblib == 2.0.0

# Install tensorboard and estimator and keras
# Install tensorboard, and keras
# Note that here we want the latest version that matches TF major.minor version
# Note that we must use nightly here as these are used in nightly jobs
# For release jobs, we will pin these on the release branch
keras-nightly ~= 3.0.0.dev
tb-nightly ~= 2.15.0.a
tf-estimator-nightly ~= 2.15.0.dev

# Test dependencies
grpcio >= 1.24.3, < 2.0
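
The pins above feed the nightly requirement lock files regenerated in the next hunks. As an illustrative aside (not part of the commit), a small sketch that reports which of those nightly pins are installed in the current environment, using only the standard library:

```python
# Illustrative check, not part of this commit: report the installed versions
# of the nightly pins listed in requirements.in above.
from importlib import metadata

for pkg in ("keras-nightly", "tb-nightly", "tf-estimator-nightly"):
    try:
        print(f"{pkg}: {metadata.version(pkg)}")
    except metadata.PackageNotFoundError:
        print(f"{pkg}: not installed")
```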
@@ -542,9 +542,6 @@ termcolor==2.3.0 \
    --hash=sha256:3afb05607b89aed0ffe25202399ee0867ad4d3cb4180d98aaf8eefa6a5f7d475 \
    --hash=sha256:b5b08f68937f138fe92f6c089b99f1e2da0ae56c52b78bf7075fd95420fd9a5a
    # via -r requirements.in
tf-estimator-nightly==2.15.0.dev2023101608 \
    --hash=sha256:fc045b32fb1a607da93799b3da0642527195a716cac424367f3c5f4edc2ec21e
    # via -r requirements.in
typing-extensions==4.8.0 \
    --hash=sha256:8f92fc8806f9a6b641eaa5318da32b44d401efaac0f6678c9bc448ba3605faa0 \
    --hash=sha256:df8e4339e9cb77357558cbdbceca33c303714cf861d1eef15e1070055ae8b7ef
@@ -542,9 +542,6 @@ termcolor==2.3.0 \
    --hash=sha256:3afb05607b89aed0ffe25202399ee0867ad4d3cb4180d98aaf8eefa6a5f7d475 \
    --hash=sha256:b5b08f68937f138fe92f6c089b99f1e2da0ae56c52b78bf7075fd95420fd9a5a
    # via -r requirements.in
tf-estimator-nightly==2.15.0.dev2023101608 \
    --hash=sha256:fc045b32fb1a607da93799b3da0642527195a716cac424367f3c5f4edc2ec21e
    # via -r requirements.in
typing-extensions==4.8.0 \
    --hash=sha256:8f92fc8806f9a6b641eaa5318da32b44d401efaac0f6678c9bc448ba3605faa0 \
    --hash=sha256:df8e4339e9cb77357558cbdbceca33c303714cf861d1eef15e1070055ae8b7ef
@@ -546,9 +546,6 @@ termcolor==2.3.0 \
    --hash=sha256:3afb05607b89aed0ffe25202399ee0867ad4d3cb4180d98aaf8eefa6a5f7d475 \
    --hash=sha256:b5b08f68937f138fe92f6c089b99f1e2da0ae56c52b78bf7075fd95420fd9a5a
    # via -r requirements.in
tf-estimator-nightly==2.15.0.dev2023101608 \
    --hash=sha256:fc045b32fb1a607da93799b3da0642527195a716cac424367f3c5f4edc2ec21e
    # via -r requirements.in
typing-extensions==4.8.0 \
    --hash=sha256:8f92fc8806f9a6b641eaa5318da32b44d401efaac0f6678c9bc448ba3605faa0 \
    --hash=sha256:df8e4339e9cb77357558cbdbceca33c303714cf861d1eef15e1070055ae8b7ef
@@ -546,9 +546,6 @@ termcolor==2.3.0 \
    --hash=sha256:3afb05607b89aed0ffe25202399ee0867ad4d3cb4180d98aaf8eefa6a5f7d475 \
    --hash=sha256:b5b08f68937f138fe92f6c089b99f1e2da0ae56c52b78bf7075fd95420fd9a5a
    # via -r requirements.in
tf-estimator-nightly==2.15.0.dev2023101608 \
    --hash=sha256:fc045b32fb1a607da93799b3da0642527195a716cac424367f3c5f4edc2ec21e
    # via -r requirements.in
typing-extensions==4.8.0 \
    --hash=sha256:8f92fc8806f9a6b641eaa5318da32b44d401efaac0f6678c9bc448ba3605faa0 \
    --hash=sha256:df8e4339e9cb77357558cbdbceca33c303714cf861d1eef15e1070055ae8b7ef
@@ -2,7 +2,6 @@
# TensorFlow is a computational framework, primarily for use in machine
# learning applications.

load("@bazel_skylib//:bzl_library.bzl", "bzl_library")
load("@bazel_skylib//lib:selects.bzl", "selects")
load("@bazel_skylib//rules:common_settings.bzl", "bool_flag", "bool_setting")
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
@@ -39,6 +38,7 @@ load(
    "//third_party/mkl:build_defs.bzl",
    "if_mkl_ml",
)
load("@bazel_skylib//:bzl_library.bzl", "bzl_library")

# copybara:uncomment_begin
# # buildifier: disable=out-of-order-load
@@ -1710,7 +1710,6 @@ py_library(
        "//tensorflow/lite/python:analyzer",
        "//tensorflow/lite/python:lite",
        "//tensorflow/lite/python/authoring",
        "//tensorflow/python/estimator:estimator_py",
    ],
)
@@ -33,10 +33,8 @@ import inspect as _inspect
import os as _os
import site as _site
import sys as _sys
import typing as _typing

from tensorflow.python.tools import module_util as _module_util
from tensorflow.python.util.lazy_loader import LazyLoader as _LazyLoader
from tensorflow.python.util.lazy_loader import KerasLazyLoader as _KerasLazyLoader

# Make sure code inside the TensorFlow codebase can use tf2.enabled() at import.
@@ -67,14 +65,6 @@ if (_os.getenv("TF_USE_MODULAR_FILESYSTEM", "0") == "true" or
    _os.getenv("TF_USE_MODULAR_FILESYSTEM", "0") == "1"):
  import tensorflow_io_gcs_filesystem as _tensorflow_io_gcs_filesystem

# Lazy-load estimator.
_estimator_module = "tensorflow_estimator.python.estimator.api._v2.estimator"
estimator = _LazyLoader("estimator", globals(), _estimator_module)
_module_dir = _module_util.get_parent_dir_for_name(_estimator_module)
if _module_dir:
  _current_module.__path__ = [_module_dir] + _current_module.__path__
setattr(_current_module, "estimator", estimator)

# Lazy-load Keras v2/3.
_tf_uses_legacy_keras = (
    _os.environ.get("TF_USE_LEGACY_KERAS", None) in ("true", "True", "1"))
@@ -169,18 +159,11 @@ except (ImportError, AttributeError):

del importlib

# Explicitly import lazy-loaded modules to support autocompletion.
if _typing.TYPE_CHECKING:
  from tensorflow_estimator.python.estimator.api._v2 import estimator as estimator

# pylint: enable=undefined-variable

# Delete modules that should be hidden from dir().
# Don't fail if these modules are not available.
# For e.g. this file will be originally placed under tensorflow/_api/v1 which
# does not have "python", "core" directories. Then, it will be copied
# to tensorflow/ which does have these two directories.
# pylint: disable=undefined-variable
try:
  del python
except NameError:
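
The removed block above wires `tf.estimator` in through TensorFlow's `_LazyLoader`, which defers the actual `tensorflow_estimator` import until the attribute is first touched. A simplified, self-contained sketch of that lazy-loading pattern (this is not TensorFlow's implementation; the standard-library `json` module stands in for the estimator API purely for illustration):

```python
# Simplified sketch of the lazy-loading pattern removed above; the same idea
# in miniature, not TensorFlow's _LazyLoader.
import importlib
import types


class LazyModule(types.ModuleType):
    """Proxy that imports the real module on first attribute access."""

    def __init__(self, local_name, parent_globals, module_name):
        super().__init__(module_name)
        self._local_name = local_name
        self._parent_globals = parent_globals
        self._module_name = module_name

    def _load(self):
        module = importlib.import_module(self._module_name)
        # Swap the proxy out of the parent namespace so later lookups hit the
        # real module directly.
        self._parent_globals[self._local_name] = module
        return module

    def __getattr__(self, item):
        return getattr(self._load(), item)


# Usage sketch: `json` stands in for the estimator API module.
json_lazy = LazyModule("json_lazy", globals(), "json")
print(json_lazy.dumps({"lazy": True}))
```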
@@ -20,7 +20,6 @@ import inspect as _inspect
import os as _os
import site as _site
import sys as _sys
import typing as _typing

# pylint: disable=g-bad-import-order,protected-access,g-import-not-at-top
from tensorflow.python import pywrap_tensorflow  # pylint: disable=unused-import
@@ -70,14 +69,6 @@ if (_os.getenv("TF_USE_MODULAR_FILESYSTEM", "0") == "true" or
    _os.getenv("TF_USE_MODULAR_FILESYSTEM", "0") == "1"):
  import tensorflow_io_gcs_filesystem as _tensorflow_io_gcs_filesystem

# Lazy-load estimator.
_estimator_module = "tensorflow_estimator.python.estimator.api._v1.estimator"
estimator = _LazyLoader("estimator", globals(), _estimator_module)
_module_dir = _module_util.get_parent_dir_for_name(_estimator_module)
if _module_dir:
  _current_module.__path__ = [_module_dir] + _current_module.__path__
setattr(_current_module, "estimator", estimator)

# Lazy-load Keras v1.
_tf_uses_legacy_keras = (
    _os.environ.get("TF_USE_LEGACY_KERAS", None) in ("true", "True", "1"))
@@ -190,17 +181,12 @@ if _os.getenv("TF_PLUGGABLE_DEVICE_LIBRARY_PATH", ""):
      _os.getenv("TF_PLUGGABLE_DEVICE_LIBRARY_PATH")
  )

# Explicitly import lazy-loaded modules to support autocompletion.
if _typing.TYPE_CHECKING:
  from tensorflow_estimator.python.estimator.api._v1 import estimator as estimator

# Delete modules that should be hidden from dir().
# Don't fail if these modules are not available.
# For e.g. this file will be originally placed under tensorflow/_api/v1 which
# does not have "python", "core" directories. Then, it will be copied
# to tensorflow/ which does have these two directories.

# pylint: disable=undefined-variable
try:
  del python
except NameError:
@@ -18,10 +18,8 @@

import os as _os
import sys as _sys
import typing as _typing

from tensorflow.python.tools import module_util as _module_util
from tensorflow.python.util.lazy_loader import LazyLoader as _LazyLoader
from tensorflow.python.util.lazy_loader import KerasLazyLoader as _KerasLazyLoader

# API IMPORTS PLACEHOLDER
@@ -31,14 +29,6 @@ from tensorflow.python.util.lazy_loader import KerasLazyLoader as _KerasLazyLoad
# Hook external TensorFlow modules.
_current_module = _sys.modules[__name__]

# Lazy-load estimator.
_estimator_module = "tensorflow_estimator.python.estimator.api._v2.estimator"
estimator = _LazyLoader("estimator", globals(), _estimator_module)
_module_dir = _module_util.get_parent_dir_for_name(_estimator_module)
if _module_dir:
  _current_module.__path__ = [_module_dir] + _current_module.__path__
setattr(_current_module, "estimator", estimator)

# Lazy load Keras v2
_tf_uses_legacy_keras = (
    _os.environ.get("TF_USE_LEGACY_KERAS", None) in ("true", "True", "1"))
@@ -72,9 +62,3 @@ setattr(_current_module, "losses", _losses)
setattr(_current_module, "metrics", _metrics)
setattr(_current_module, "optimizers", _optimizers)
setattr(_current_module, "initializers", _initializers)

# Explicitly import lazy-loaded modules to support autocompletion.
# pylint: disable=g-import-not-at-top
if _typing.TYPE_CHECKING:
  from tensorflow_estimator.python.estimator.api._v2 import estimator as estimator
# pylint: enable=g-import-not-at-top
@@ -18,10 +18,8 @@

import os as _os
import sys as _sys
import typing as _typing

from tensorflow.python.tools import module_util as _module_util
from tensorflow.python.util.lazy_loader import LazyLoader as _LazyLoader
from tensorflow.python.util.lazy_loader import KerasLazyLoader as _KerasLazyLoader

# API IMPORTS PLACEHOLDER
@@ -31,14 +29,6 @@ from tensorflow.python.util.lazy_loader import KerasLazyLoader as _KerasLazyLoad
# Hook external TensorFlow modules.
_current_module = _sys.modules[__name__]

# Lazy-load estimator.
_estimator_module = "tensorflow_estimator.python.estimator.api._v1.estimator"
estimator = _LazyLoader("estimator", globals(), _estimator_module)
_module_dir = _module_util.get_parent_dir_for_name(_estimator_module)
if _module_dir:
  _current_module.__path__ = [_module_dir] + _current_module.__path__
setattr(_current_module, "estimator", estimator)

# Lazy load Keras v1
_tf_uses_legacy_keras = (
    _os.environ.get("TF_USE_LEGACY_KERAS", None) in ("true", "True", "1"))
@@ -81,9 +71,3 @@ else:
  _module_dir = _module_util.get_parent_dir_for_name(
      "keras.api._v1.keras.__internal__.legacy.rnn_cell")
  _current_module.nn.__path__ = [_module_dir] + _current_module.nn.__path__

# Explicitly import lazy-loaded modules to support autocompletion.
# pylint: disable=g-import-not-at-top
if _typing.TYPE_CHECKING:
  from tensorflow_estimator.python.estimator.api._v1 import estimator as estimator
# pylint: enable=g-import-not-at-top
@@ -33,7 +33,7 @@ FlatBuffers is an order of magnitude smaller than protocol buffers.

The converter supports the following input formats:

*   [SavedModels](https://www.tensorflow.org/guide/saved_model#using_savedmodel_with_estimators)
*   [SavedModels](https://www.tensorflow.org/guide/saved_model)
*   `tf.keras` H5 models.
*   Frozen `GraphDef` models generated using
    [freeze_graph.py](https://www.tensorflow.org/code/tensorflow/python/tools/freeze_graph.py).
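
Since the Estimator-specific SavedModel guide link is replaced above by the general SavedModel guide, here is a minimal conversion sketch for the SavedModel path (illustrative only; the path is hypothetical, and the entry point shown is the standard `tf.lite.TFLiteConverter` API):

```python
# Minimal sketch: convert an existing SavedModel (hypothetical path) into a
# TensorFlow Lite FlatBuffer, matching the first supported input format above.
import tensorflow as tf

saved_model_dir = "/tmp/my_saved_model"  # hypothetical, must already exist
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
tflite_model = converter.convert()

with open("/tmp/model.tflite", "wb") as f:
    f.write(tflite_model)
```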
@@ -9,7 +9,6 @@ package(
    default_visibility = [
        "//tensorflow:__subpackages__",
        "//tensorflow:internal",
        "//tensorflow_estimator:__subpackages__",
        "//third_party/py/tensorflow_federated:__subpackages__",
        "//third_party/tflite_micro:__subpackages__",
    ],
@@ -2082,8 +2082,6 @@ class TFLiteConverterV2(TFLiteFrozenGraphConverterV2):
        saved_model_dir, signature_key=signature_key, tag_set=tags
    )

    # Ensures any graphs created in Eager mode are able to run. This is required
    # in order to create a tf.estimator.Exporter that exports a TFLite model.
    if tags is None:
      tags = set([_tag_constants.SERVING])
@@ -17,8 +17,8 @@ Usage information is given in these documents:
Once an application developer has a trained TensorFlow model, the TensorFlow
Lite Converter will accept
that model and generate a TensorFlow Lite
[FlatBuffer](https://google.github.io/flatbuffers/) file. The converter currently supports
[SavedModels](https://www.tensorflow.org/guide/saved_model#using_savedmodel_with_estimators),
[FlatBuffer](https://google.github.io/flatbuffers/) file. The converter
currently supports [SavedModels](https://www.tensorflow.org/guide/saved_model),
frozen graphs (models generated via
[freeze_graph.py](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/tools/freeze_graph.py)),
and `tf.Keras` model files. The TensorFlow Lite FlatBuffer file can be shipped
@@ -1,5 +1,3 @@
# Example Estimator model

load("//tensorflow:strict.default.bzl", "py_strict_binary", "py_strict_library")

package(
@@ -79,7 +79,6 @@ py_strict_library(
    ],
    deps = [
        ":no_contrib",
        "//tensorflow/python/estimator:estimator_py",
        "//tensorflow/python/ops:gradient_checker_v2",
        "//tensorflow/python/ops:stateful_random_ops",
        "//tensorflow/python/ops/structured:structured_ops",
@@ -98,13 +97,11 @@ py_strict_library(
    visibility = [
        "//tensorflow:__pkg__",
        "//tensorflow:internal",
        "//tensorflow/python/estimator:__subpackages__",
        "//tensorflow/python/keras:__subpackages__",
        "//tensorflow/python/tools:__pkg__",
        "//tensorflow/python/tools/api/generator:__pkg__",
        "//tensorflow/tools/api/tests:__pkg__",
        "//tensorflow/tools/compatibility/update:__pkg__",
        "//tensorflow_estimator:__subpackages__",
        "//third_party/py/tensorflow_privacy:__subpackages__",  # TODO(b/163395075): remove when fixed
    ],
    deps = [
@@ -121,7 +118,6 @@ py_strict_library(
    srcs_version = "PY3",
    visibility = [
        "//tensorflow:__pkg__",
        "//tensorflow/python/estimator:__subpackages__",
        "//tensorflow/python/keras:__subpackages__",
        "//tensorflow/python/tools:__pkg__",
        "//tensorflow/python/tools/api/generator:__pkg__",
@@ -157,7 +153,6 @@ py_strict_library(
        "//tensorflow/python/distribute",
        "//tensorflow/python/distribute:combinations",  # For tf.__internal__ API.
        "//tensorflow/python/distribute:distribute_config",
        "//tensorflow/python/distribute:estimator_training",
        "//tensorflow/python/distribute:strategy_combinations",  # For tf.__internal__,
        "//tensorflow/python/distribute/experimental/rpc:rpc_ops",
        "//tensorflow/python/dlpack",
@@ -171,62 +171,3 @@ cuda_py_strict_test(
        "@absl_py//absl/testing:parameterized",
    ],
)

cuda_py_strict_test(
    name = "quantization_mnist_test",
    srcs = ["//tensorflow/python/compiler/tensorrt/test:quantization_mnist_test_srcs"],
    data = [
        "//tensorflow/python/compiler/tensorrt/test:quantization_mnist_test_data",
    ],
    python_version = "PY3",
    tags = [
        "no_cuda_on_cpu_tap",
        "no_oss",  # TODO(b/125290478): allow running in at least some OSS configurations.
        "no_pip",
        "no_rocm",
        "no_windows",
        "nomac",
        "notap",  # TODO(b/290051231)
        "requires-net:external",
    ],
    xla_enable_strict_auto_jit = False,
    deps = [
        ":tf_trt_integration_test_base",
        ":trt_convert_py",
        "//tensorflow/compiler/tf2tensorrt:_pywrap_py_utils",
        "//tensorflow/core:protos_all_py",
        "//tensorflow/python/client:session",
        "//tensorflow/python/data/ops:dataset_ops",
        "//tensorflow/python/estimator",
        "//tensorflow/python/estimator:model_fn",
        "//tensorflow/python/estimator:run_config",
        "//tensorflow/python/framework:convert_to_constants",
        "//tensorflow/python/framework:dtypes",
        "//tensorflow/python/framework:importer",
        "//tensorflow/python/framework:ops",
        "//tensorflow/python/framework:test_lib",
        "//tensorflow/python/keras:metrics",
        "//tensorflow/python/layers",
        "//tensorflow/python/ops:array_ops",
        "//tensorflow/python/ops:array_ops_gen",
        "//tensorflow/python/ops:init_ops",
        "//tensorflow/python/ops:math_ops",
        "//tensorflow/python/ops:metrics",
        "//tensorflow/python/ops:nn",
        "//tensorflow/python/ops:variable_scope",
        "//tensorflow/python/ops/losses",
        "//tensorflow/python/platform:client_testlib",
        "//tensorflow/python/platform:tf_logging",
        "//tensorflow/python/saved_model:builder",
        "//tensorflow/python/saved_model:load",
        "//tensorflow/python/saved_model:signature_constants",
        "//tensorflow/python/saved_model:signature_def_utils",
        "//tensorflow/python/saved_model:tag_constants",
        "//tensorflow/python/saved_model:utils",
        "//tensorflow/python/summary:summary_py",
        "//tensorflow/python/training:adam",
        "//tensorflow/python/training:checkpoint_management",
        "//tensorflow/python/training:saver",
        "//tensorflow/python/training:training_util",
    ],
)
@@ -73,22 +73,6 @@ filegroup(
    visibility = ["//tensorflow/python/compiler/tensorrt:__pkg__"],
)

filegroup(
    name = "quantization_mnist_test_srcs",
    srcs = ["quantization_mnist_test.py"],
    visibility = ["//tensorflow/python/compiler/tensorrt:__pkg__"],
)

filegroup(
    name = "quantization_mnist_test_data",
    srcs = [
        "testdata/mnist/checkpoint",
        "testdata/mnist/model.ckpt-46900.data-00000-of-00001",
        "testdata/mnist/model.ckpt-46900.index",
    ],
    visibility = ["//tensorflow/python/compiler/tensorrt:__pkg__"],
)

base_tags = [
    "no_cuda_on_cpu_tap",
    "no_rocm",
@ -1,411 +0,0 @@
|
|||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Script to test TF-TRT INT8 conversion without calibration on Mnist model."""
|
||||
|
||||
import os.path
|
||||
import tempfile
|
||||
import tensorflow_datasets as tfds
|
||||
from tensorflow.compiler.tf2tensorrt._pywrap_py_utils import is_tensorrt_enabled
|
||||
from tensorflow.core.protobuf import config_pb2
|
||||
from tensorflow.python.client import session
|
||||
from tensorflow.python.compiler.tensorrt import trt_convert
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.estimator.estimator import Estimator
|
||||
from tensorflow.python.estimator.model_fn import EstimatorSpec
|
||||
from tensorflow.python.estimator.model_fn import ModeKeys
|
||||
from tensorflow.python.estimator.run_config import RunConfig
|
||||
from tensorflow.python.framework import convert_to_constants
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import importer
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.keras.metrics import Accuracy
|
||||
from tensorflow.python.layers import layers
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import gen_array_ops
|
||||
from tensorflow.python.ops import init_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import metrics
|
||||
from tensorflow.python.ops import nn
|
||||
from tensorflow.python.ops import variable_scope
|
||||
from tensorflow.python.ops.losses import losses
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.saved_model import builder
|
||||
from tensorflow.python.saved_model import signature_constants
|
||||
from tensorflow.python.saved_model import tag_constants
|
||||
from tensorflow.python.saved_model import utils as saved_model_utils
|
||||
from tensorflow.python.saved_model import signature_def_utils
|
||||
from tensorflow.python.saved_model.load import load as saved_model_load
|
||||
from tensorflow.python.summary import summary
|
||||
from tensorflow.python.training import saver
|
||||
from tensorflow.python.training.adam import AdamOptimizer
|
||||
from tensorflow.python.training.checkpoint_management import latest_checkpoint
|
||||
from tensorflow.python.training.training_util import get_global_step
|
||||
|
||||
INPUT_NODE_NAME = 'input'
|
||||
OUTPUT_NODE_NAME = 'output'
|
||||
MNIST_TEST_DIR_PATH = 'python/compiler/tensorrt/test/testdata/mnist'
|
||||
|
||||
|
||||
def _PreprocessFn(entry):
|
||||
"""Normalizes the pixel values to lay within the [-1, 1] range.
|
||||
|
||||
The same normalization shall be used during training and inference.
|
||||
"""
|
||||
x, y = entry['image'], entry['label']
|
||||
x = math_ops.cast(x, dtypes.float32)
|
||||
x = 2.0 * (x / 255.0) - 1.0
|
||||
y = math_ops.cast(y, dtypes.int32)
|
||||
return x, y
|
||||
|
||||
|
||||
def _GetDataSet(batch_size):
|
||||
dataset = tfds.load('mnist', split='test')
|
||||
dataset = dataset.map(
|
||||
map_func=_PreprocessFn, num_parallel_calls=8).batch(batch_size=batch_size)
|
||||
dataset = dataset.repeat(count=1)
|
||||
return dataset
|
||||
|
||||
|
||||
class QuantizationAwareTrainingMNISTTest(test_util.TensorFlowTestCase):
|
||||
"""Testing usage of quantization ranges inserted in graph."""
|
||||
|
||||
def _BuildGraph(self, x):
|
||||
|
||||
def _Quantize(x, r):
|
||||
x = gen_array_ops.quantize_and_dequantize_v2(x, -r, r)
|
||||
return x
|
||||
|
||||
def _DenseLayer(x, num_inputs, num_outputs, quantization_range, name):
|
||||
"""Defines a dense layer with quantized outputs.
|
||||
|
||||
Args:
|
||||
x: input to the dense layer
|
||||
num_inputs: number of input columns of x
|
||||
num_outputs: number of output columns
|
||||
quantization_range: the min/max range for quantization
|
||||
name: name of the variable scope
|
||||
|
||||
Returns:
|
||||
The output of the layer.
|
||||
"""
|
||||
with variable_scope.variable_scope(name):
|
||||
kernel = variable_scope.get_variable(
|
||||
'kernel',
|
||||
shape=[num_inputs, num_outputs],
|
||||
dtype=dtypes.float32,
|
||||
initializer=init_ops.GlorotUniform())
|
||||
bias = variable_scope.get_variable(
|
||||
'bias',
|
||||
shape=[num_outputs],
|
||||
dtype=dtypes.float32,
|
||||
initializer=init_ops.Zeros())
|
||||
x = math_ops.matmul(x, kernel)
|
||||
x = _Quantize(x, quantization_range)
|
||||
x = nn.bias_add(x, bias)
|
||||
x = _Quantize(x, quantization_range)
|
||||
return x
|
||||
|
||||
x = _Quantize(x, 1)
|
||||
# Conv + Bias + Relu6
|
||||
x = layers.conv2d(x, filters=32, kernel_size=3, use_bias=True)
|
||||
x = nn.relu6(x)
|
||||
# Conv + Bias + Relu6
|
||||
x = layers.conv2d(x, filters=64, kernel_size=3, use_bias=True)
|
||||
x = nn.relu6(x)
|
||||
# Reduce
|
||||
x = math_ops.reduce_mean(x, [1, 2])
|
||||
x = _Quantize(x, 6)
|
||||
# FC1
|
||||
x = _DenseLayer(x, 64, 512, 6, name='dense')
|
||||
x = nn.relu6(x)
|
||||
# FC2
|
||||
x = _DenseLayer(x, 512, 10, 25, name='dense_1')
|
||||
x = array_ops.identity(x, name=OUTPUT_NODE_NAME)
|
||||
return x
|
||||
|
||||
def _LoadWeights(self, model_dir, sess):
|
||||
mnist_saver = saver.Saver()
|
||||
checkpoint_file = latest_checkpoint(model_dir)
|
||||
if checkpoint_file is None:
|
||||
raise ValueError('latest_checkpoint returned None. check if' +
|
||||
'model_dir={} is the right directory'.format(model_dir))
|
||||
mnist_saver.restore(sess, checkpoint_file)
|
||||
|
||||
def _GetGraphDef(self, use_trt, max_batch_size, model_dir):
|
||||
"""Gets the frozen mnist GraphDef.
|
||||
|
||||
Args:
|
||||
use_trt: whether use TF-TRT to convert the graph.
|
||||
max_batch_size: the max batch size to apply during TF-TRT conversion.
|
||||
model_dir: the model directory to load the checkpoints.
|
||||
|
||||
Returns:
|
||||
The frozen mnist GraphDef.
|
||||
"""
|
||||
graph = ops.Graph()
|
||||
with self.session(graph=graph) as sess:
|
||||
with graph.device('/GPU:0'):
|
||||
x = array_ops.placeholder(
|
||||
shape=(None, 28, 28, 1), dtype=dtypes.float32, name=INPUT_NODE_NAME)
|
||||
self._BuildGraph(x)
|
||||
self._LoadWeights(model_dir, sess)
|
||||
# Freeze
|
||||
graph_def = convert_to_constants.convert_variables_to_constants(
|
||||
sess, sess.graph_def, output_node_names=[OUTPUT_NODE_NAME])
|
||||
# Convert with TF-TRT
|
||||
if use_trt:
|
||||
logging.info('Number of nodes before TF-TRT conversion: %d',
|
||||
len(graph_def.node))
|
||||
converter = trt_convert.TrtGraphConverter(
|
||||
input_graph_def=graph_def,
|
||||
nodes_denylist=[OUTPUT_NODE_NAME],
|
||||
max_batch_size=max_batch_size,
|
||||
precision_mode='INT8',
|
||||
max_workspace_size_bytes=(
|
||||
trt_convert.DEFAULT_TRT_MAX_WORKSPACE_SIZE_BYTES),
|
||||
minimum_segment_size=2,
|
||||
use_calibration=False)
|
||||
graph_def = converter.convert()
|
||||
logging.info('Number of nodes after TF-TRT conversion: %d',
|
||||
len(graph_def.node))
|
||||
num_engines = len(
|
||||
[1 for n in graph_def.node if str(n.op) == 'TRTEngineOp'])
|
||||
self.assertEqual(1, num_engines)
|
||||
return graph_def
|
||||
|
||||
def _Run(self, is_training, use_trt, batch_size, num_epochs, model_dir):
|
||||
"""Trains or evaluates the model.
|
||||
|
||||
Args:
|
||||
is_training: whether to train or evaluate the model. In training mode,
|
||||
quantization will be simulated where the quantize_and_dequantize_v2 are
|
||||
placed.
|
||||
use_trt: if true, use TRT INT8 mode for evaluation, which will perform
|
||||
real quantization. Otherwise use native TensorFlow which will perform
|
||||
simulated quantization. Ignored if is_training is True.
|
||||
batch_size: batch size.
|
||||
num_epochs: how many epochs to train. Ignored if is_training is False.
|
||||
model_dir: where to save or load checkpoint.
|
||||
|
||||
Returns:
|
||||
The Estimator evaluation result.
|
||||
"""
|
||||
|
||||
def _EvalInputFn():
|
||||
dataset = _GetDataSet(batch_size)
|
||||
iterator = dataset_ops.make_one_shot_iterator(dataset)
|
||||
features, labels = iterator.get_next()
|
||||
return features, labels
|
||||
|
||||
def _TrainInputFn():
|
||||
dataset = tfds.load('mnist', split='train')
|
||||
dataset = dataset.shuffle(60000)
|
||||
dataset = dataset.map(
|
||||
map_func=_PreprocessFn,
|
||||
num_parallel_calls=8).batch(batch_size=batch_size)
|
||||
dataset = dataset.repeat(count=num_epochs)
|
||||
iterator = dataset_ops.make_one_shot_iterator(dataset)
|
||||
features, labels = iterator.get_next()
|
||||
return features, labels
|
||||
|
||||
def _ModelFn(features, labels, mode):
|
||||
if is_training:
|
||||
logits_out = self._BuildGraph(features)
|
||||
else:
|
||||
graph_def = self._GetGraphDef(use_trt, batch_size, model_dir)
|
||||
logits_out = importer.import_graph_def(
|
||||
graph_def,
|
||||
input_map={INPUT_NODE_NAME: features},
|
||||
return_elements=[OUTPUT_NODE_NAME + ':0'],
|
||||
name='')[0]
|
||||
|
||||
loss = losses.sparse_softmax_cross_entropy(
|
||||
labels=labels, logits=logits_out)
|
||||
summary.scalar('loss', loss)
|
||||
|
||||
classes_out = math_ops.argmax(logits_out, axis=1, name='classes_out')
|
||||
accuracy = metrics.accuracy(
|
||||
labels=labels, predictions=classes_out, name='acc_op')
|
||||
summary.scalar('accuracy', accuracy[1])
|
||||
|
||||
if mode == ModeKeys.EVAL:
|
||||
return EstimatorSpec(
|
||||
mode, loss=loss, eval_metric_ops={'accuracy': accuracy})
|
||||
if mode == ModeKeys.TRAIN:
|
||||
optimizer = AdamOptimizer(learning_rate=1e-2)
|
||||
train_op = optimizer.minimize(loss, global_step=get_global_step())
|
||||
return EstimatorSpec(mode, loss=loss, train_op=train_op)
|
||||
|
||||
config_proto = config_pb2.ConfigProto()
|
||||
config_proto.gpu_options.allow_growth = True
|
||||
estimator = Estimator(
|
||||
model_fn=_ModelFn,
|
||||
model_dir=model_dir if is_training else None,
|
||||
config=RunConfig(session_config=config_proto))
|
||||
|
||||
if is_training:
|
||||
estimator.train(_TrainInputFn)
|
||||
results = estimator.evaluate(_EvalInputFn)
|
||||
logging.info('accuracy: %s', str(results['accuracy']))
|
||||
return results
|
||||
|
||||
# To generate the checkpoint, set a different model_dir and call self._Run()
|
||||
# by setting is_training=True and num_epochs=1000, e.g.:
|
||||
# model_dir = '/tmp/quantization_mnist'
|
||||
# self._Run(
|
||||
# is_training=True,
|
||||
# use_trt=False,
|
||||
# batch_size=128,
|
||||
# num_epochs=100,
|
||||
# model_dir=model_dir)
|
||||
def testEval(self):
|
||||
|
||||
model_dir = test.test_src_dir_path(MNIST_TEST_DIR_PATH)
|
||||
|
||||
accuracy_tf_native = self._Run(
|
||||
is_training=False,
|
||||
use_trt=False,
|
||||
batch_size=128,
|
||||
num_epochs=None,
|
||||
model_dir=model_dir)['accuracy']
|
||||
logging.info('accuracy_tf_native: %f', accuracy_tf_native)
|
||||
self.assertAllClose(0.9662, accuracy_tf_native, rtol=3e-3, atol=3e-3)
|
||||
|
||||
accuracy_tf_trt = self._Run(
|
||||
is_training=False,
|
||||
use_trt=True,
|
||||
batch_size=128,
|
||||
num_epochs=None,
|
||||
model_dir=model_dir)['accuracy']
|
||||
logging.info('accuracy_tf_trt: %f', accuracy_tf_trt)
|
||||
self.assertAllClose(0.9675, accuracy_tf_trt, rtol=1e-3, atol=1e-3)
|
||||
|
||||
|
||||
class MNISTTestV2(QuantizationAwareTrainingMNISTTest):
|
||||
|
||||
def _SaveModel(self, model_dir, output_dir):
|
||||
saved_model_builder = builder.SavedModelBuilder(output_dir)
|
||||
graph = ops.Graph()
|
||||
with session.Session(graph=graph) as sess:
|
||||
with graph.device('/GPU:0'):
|
||||
x = array_ops.placeholder(
|
||||
shape=(None, 28, 28, 1), dtype=dtypes.float32, name=INPUT_NODE_NAME)
|
||||
self._BuildGraph(x)
|
||||
self._LoadWeights(model_dir, sess)
|
||||
input_tensor = graph.get_tensor_by_name(INPUT_NODE_NAME + ':0')
|
||||
output = graph.get_tensor_by_name(OUTPUT_NODE_NAME + ':0')
|
||||
signature_def = signature_def_utils.build_signature_def(
|
||||
inputs={'input': saved_model_utils.build_tensor_info(input_tensor)},
|
||||
outputs={'output': saved_model_utils.build_tensor_info(output)},
|
||||
method_name=signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY)
|
||||
saved_model_builder.add_meta_graph_and_variables(
|
||||
sess, [tag_constants.SERVING],
|
||||
signature_def_map={
|
||||
signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
|
||||
signature_def
|
||||
})
|
||||
saved_model_builder.save()
|
||||
|
||||
def _GetFunc(self, use_trt, model_dir, use_dynamic_shape):
|
||||
"""Gets the mnist function.
|
||||
|
||||
Args:
|
||||
use_trt: whether use TF-TRT to convert the graph.
|
||||
model_dir: the model directory to load the checkpoints.
|
||||
use_dynamic_shape: whether to run the TF-TRT conversion in dynamic shape
|
||||
mode.
|
||||
|
||||
Returns:
|
||||
The mnist model function.
|
||||
"""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
saved_model_dir = os.path.join(tmpdir, 'mnist')
|
||||
self._SaveModel(model_dir, saved_model_dir)
|
||||
|
||||
if use_trt:
|
||||
conv_params = trt_convert.TrtConversionParams(
|
||||
precision_mode='FP16',
|
||||
minimum_segment_size=2,
|
||||
max_workspace_size_bytes=(
|
||||
trt_convert.DEFAULT_TRT_MAX_WORKSPACE_SIZE_BYTES),
|
||||
maximum_cached_engines=1)
|
||||
converter = trt_convert.TrtGraphConverterV2(
|
||||
input_saved_model_dir=saved_model_dir,
|
||||
use_dynamic_shape=use_dynamic_shape,
|
||||
dynamic_shape_profile_strategy='ImplicitBatchModeCompatible',
|
||||
**conv_params._asdict())
|
||||
converter.convert()
|
||||
try:
|
||||
line_length = max(160, os.get_terminal_size().columns)
|
||||
except OSError:
|
||||
line_length = 160
|
||||
converter.summary(line_length=line_length, detailed=True)
|
||||
func = converter._converted_func
|
||||
else:
|
||||
saved_model_loaded = saved_model_load(
|
||||
saved_model_dir, tags=[tag_constants.SERVING])
|
||||
func = saved_model_loaded.signatures[
|
||||
signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
|
||||
return func
|
||||
|
||||
def _Run(self, use_trt, batch_size, model_dir, use_dynamic_shape=False):
|
||||
"""Evaluates the model.
|
||||
|
||||
Args:
|
||||
use_trt: if true, use TRT INT8 mode for evaluation, which will perform
|
||||
real quantization. Otherwise use native TensorFlow which will perform
|
||||
simulated quantization. Ignored if is_training is True.
|
||||
batch_size: batch size.
|
||||
model_dir: where to save or load checkpoint.
|
||||
use_dynamic_shape: if true, then TF-TRT dynamic shape mode is enabled,
|
||||
otherwise disabled. Ignored if use_trt is false.
|
||||
|
||||
Returns:
|
||||
The Estimator evaluation result.
|
||||
"""
|
||||
func = self._GetFunc(use_trt, model_dir, use_dynamic_shape)
|
||||
ds = _GetDataSet(batch_size)
|
||||
|
||||
m = Accuracy()
|
||||
for example in ds:
|
||||
image, label = example[0], example[1]
|
||||
pred = func(image)
|
||||
m.update_state(math_ops.argmax(pred['output'], axis=1), label)
|
||||
|
||||
return m.result().numpy()
|
||||
|
||||
def testEval(self):
|
||||
model_dir = test.test_src_dir_path(MNIST_TEST_DIR_PATH)
|
||||
|
||||
accuracy_tf_trt = self._Run(
|
||||
use_trt=True,
|
||||
batch_size=128,
|
||||
use_dynamic_shape=False,
|
||||
model_dir=model_dir)
|
||||
logging.info('accuracy_tf_trt: %f', accuracy_tf_trt)
|
||||
self.assertAllClose(0.9675, accuracy_tf_trt, rtol=1e-3, atol=1e-3)
|
||||
|
||||
accuracy_tf_trt = self._Run(
|
||||
use_trt=True,
|
||||
batch_size=128,
|
||||
use_dynamic_shape=True,
|
||||
model_dir=model_dir)
|
||||
logging.info('accuracy_tf_trt: %f', accuracy_tf_trt)
|
||||
self.assertAllClose(0.9675, accuracy_tf_trt, rtol=1e-3, atol=1e-3)
|
||||
|
||||
if __name__ == '__main__' and is_tensorrt_enabled():
|
||||
test.main()
|
||||
|
|
@@ -87,7 +87,6 @@ cuda_py_strict_test(
    deps = [
        ":xla",
        "//tensorflow/python/eager:def_function",
        "//tensorflow/python/estimator:model_fn",
        "//tensorflow/python/framework:constant_op",
        "//tensorflow/python/framework:ops",
        "//tensorflow/python/framework:test_lib",
@@ -99,7 +98,7 @@ cuda_py_strict_test(
        "//tensorflow/python/ops:variable_scope",
        "//tensorflow/python/ops:while_loop",
        "//tensorflow/python/platform:client_testlib",
        "//tensorflow/python/summary:__init__",
        "//tensorflow/python/summary:summary_py",
        "//tensorflow/python/tpu:tpu_feed",
        "@absl_py//absl/testing:parameterized",
    ],
@@ -16,10 +16,8 @@

from absl.testing import parameterized

from tensorflow.python import summary
from tensorflow.python.compiler.xla import xla
from tensorflow.python.eager import def_function
from tensorflow.python.estimator import model_fn as model_fn_lib
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
@@ -31,11 +29,10 @@ from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import while_loop
from tensorflow.python.platform import test
from tensorflow.python.summary import summary
from tensorflow.python.tpu import tpu_feed


_TRAIN = model_fn_lib.ModeKeys.TRAIN
_EVAL = model_fn_lib.ModeKeys.EVAL
_EXPECTED_LOSS = 1
_EXPECTED_FEATURE = 2
_EXPECTED_LABEL = 3
@@ -572,21 +572,6 @@ class _CapturedObject(object):
    return self._object


def _get_scaffold(captured_scaffold_fn):
  """Retrieves the Scaffold from `captured_scaffold_fn`."""
  scaffold_fn = captured_scaffold_fn.get()

  if not scaffold_fn:
    return None

  scaffold = scaffold_fn()
  if scaffold is None:
    raise ValueError(
        'TPUEstimatorSpec.scaffold_fn returns None, which is not allowed')

  return scaffold


def check_function_argument_count(func, input_arity, infeed_queue):
  """Validate the number of input arguments to an XLA function.
@@ -25,7 +25,6 @@ See [Importing Data](https://tensorflow.org/guide/datasets) for an overview.
@@AutoShardPolicy
@@AutotuneAlgorithm
@@AutotuneOptions
@@CheckpointInputPipelineHook
@@Counter
@@CsvDataset
@@DatasetInitializer
@@ -120,7 +119,6 @@ from tensorflow.python.data.experimental.ops.interleave_ops import parallel_inte
from tensorflow.python.data.experimental.ops.interleave_ops import sample_from_datasets
from tensorflow.python.data.experimental.ops.io import load
from tensorflow.python.data.experimental.ops.io import save
from tensorflow.python.data.experimental.ops.iterator_ops import CheckpointInputPipelineHook
from tensorflow.python.data.experimental.ops.iterator_ops import make_saveable_from_iterator
from tensorflow.python.data.experimental.ops.lookup_ops import DatasetInitializer
from tensorflow.python.data.experimental.ops.lookup_ops import index_table_from_dataset
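
Unrelated to the Estimator removal, but part of the same import block above, is the experimental `save`/`load` pair for persisting a dataset to disk. A minimal sketch, assuming a TensorFlow version where `tf.data.experimental.save` and `tf.data.experimental.load` are available (newer releases expose the same functionality as `tf.data.Dataset.save`/`load`):

```python
# Minimal sketch of the tf.data save/load pair imported above.
import tensorflow as tf

path = "/tmp/range_dataset"  # hypothetical output directory
ds = tf.data.Dataset.range(5)
tf.data.experimental.save(ds, path)

restored = tf.data.experimental.load(path)
print(list(restored.as_numpy_iterator()))  # [0, 1, 2, 3, 4]
```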
@@ -85,30 +85,6 @@ tf_py_strict_test(
    ],
)

tf_py_strict_test(
    name = "checkpoint_input_pipeline_hook_test",
    size = "medium",
    srcs = ["checkpoint_input_pipeline_hook_test.py"],
    tags = ["no_windows"],  # b/287333711
    deps = [
        "//tensorflow/python/checkpoint:checkpoint_management",
        "//tensorflow/python/data/experimental/ops:iterator_ops",
        "//tensorflow/python/data/kernel_tests:test_base",
        "//tensorflow/python/data/ops:dataset_ops",
        "//tensorflow/python/estimator:estimator_py",
        "//tensorflow/python/framework:combinations",
        "//tensorflow/python/framework:constant_op",
        "//tensorflow/python/framework:dtypes",
        "//tensorflow/python/framework:ops",
        "//tensorflow/python/ops:control_flow_ops",
        "//tensorflow/python/ops:variable_v1",
        "//tensorflow/python/platform:client_testlib",
        "//tensorflow/python/training:saver",
        "//tensorflow/python/training:training_util",
        "@absl_py//absl/testing:parameterized",
    ],
)

tf_py_strict_test(
    name = "compression_ops_test",
    size = "small",
@ -1,127 +0,0 @@
|
|||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for experimental iterator_ops."""
|
||||
|
||||
from absl.testing import parameterized
|
||||
from tensorflow.python.checkpoint import checkpoint_management
|
||||
from tensorflow.python.data.experimental.ops import iterator_ops
|
||||
from tensorflow.python.data.kernel_tests import test_base
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.estimator import estimator_lib
|
||||
from tensorflow.python.framework import combinations
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import variable_v1
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.training import saver as saver_lib
|
||||
from tensorflow.python.training import training_util
|
||||
|
||||
|
||||
# TODO(b/123904664)
|
||||
class CheckpointInputPipelineHookTest(test.TestCase, parameterized.TestCase):
|
||||
|
||||
@staticmethod
|
||||
def _model_fn(features, labels, mode, config):
|
||||
del labels
|
||||
del mode
|
||||
del config
|
||||
global_step = training_util.get_or_create_global_step()
|
||||
update_global_step_op = global_step.assign_add(1)
|
||||
latest_feature = variable_v1.VariableV1(
|
||||
0, name='latest_feature', dtype=dtypes.int64)
|
||||
store_latest_feature_op = latest_feature.assign(features)
|
||||
ops.add_to_collection('my_vars', global_step)
|
||||
ops.add_to_collection('my_vars', latest_feature)
|
||||
return estimator_lib.EstimatorSpec(
|
||||
mode='train',
|
||||
train_op=control_flow_ops.group(
|
||||
[update_global_step_op, store_latest_feature_op]),
|
||||
loss=constant_op.constant(2.0))
|
||||
|
||||
def _read_vars(self, model_dir):
|
||||
"""Returns (global_step, latest_feature)."""
|
||||
with ops.Graph().as_default() as g:
|
||||
ckpt_path = checkpoint_management.latest_checkpoint(model_dir)
|
||||
meta_filename = ckpt_path + '.meta'
|
||||
saver_lib.import_meta_graph(meta_filename)
|
||||
saver = saver_lib.Saver()
|
||||
with self.session(graph=g) as sess:
|
||||
saver.restore(sess, ckpt_path)
|
||||
return sess.run(ops.get_collection('my_vars'))
|
||||
|
||||
def _build_iterator_saver_hook(self, est):
|
||||
return iterator_ops.CheckpointInputPipelineHook(est)
|
||||
|
||||
@combinations.generate(test_base.v1_only_combinations())
|
||||
def testReturnDatasetFromInputFn(self):
|
||||
|
||||
def _input_fn():
|
||||
return dataset_ops.Dataset.range(10)
|
||||
|
||||
est = estimator_lib.Estimator(model_fn=self._model_fn)
|
||||
|
||||
est.train(_input_fn, steps=2, hooks=[self._build_iterator_saver_hook(est)])
|
||||
self.assertSequenceEqual(self._read_vars(est.model_dir), (2, 1))
|
||||
est.train(_input_fn, steps=2, hooks=[self._build_iterator_saver_hook(est)])
|
||||
self.assertSequenceEqual(self._read_vars(est.model_dir), (4, 3))
|
||||
|
||||
@combinations.generate(test_base.v1_only_combinations())
|
||||
def testBuildIteratorInInputFn(self):
|
||||
|
||||
def _input_fn():
|
||||
ds = dataset_ops.Dataset.range(10)
|
||||
iterator = ds.make_one_shot_iterator()
|
||||
return iterator.get_next()
|
||||
|
||||
est = estimator_lib.Estimator(model_fn=self._model_fn)
|
||||
|
||||
est.train(_input_fn, steps=2, hooks=[self._build_iterator_saver_hook(est)])
|
||||
self.assertSequenceEqual(self._read_vars(est.model_dir), (2, 1))
|
||||
est.train(_input_fn, steps=2, hooks=[self._build_iterator_saver_hook(est)])
|
||||
self.assertSequenceEqual(self._read_vars(est.model_dir), (4, 3))
|
||||
|
||||
@combinations.generate(test_base.v1_only_combinations())
|
||||
def testDoNotRestore(self):
|
||||
|
||||
def _input_fn():
|
||||
return dataset_ops.Dataset.range(10)
|
||||
|
||||
est = estimator_lib.Estimator(model_fn=self._model_fn)
|
||||
|
||||
est.train(_input_fn, steps=2, hooks=[self._build_iterator_saver_hook(est)])
|
||||
self.assertSequenceEqual(self._read_vars(est.model_dir), (2, 1))
|
||||
est.train(_input_fn, steps=2, hooks=[self._build_iterator_saver_hook(est)])
|
||||
self.assertSequenceEqual(self._read_vars(est.model_dir), (4, 3))
|
||||
# Hook not provided, input pipeline was not restored.
|
||||
est.train(_input_fn, steps=2)
|
||||
self.assertSequenceEqual(self._read_vars(est.model_dir), (6, 1))
|
||||
|
||||
@combinations.generate(test_base.v1_only_combinations())
|
||||
def testRaiseErrorIfNoIterator(self):
|
||||
|
||||
def _input_fn():
|
||||
return constant_op.constant(1, dtype=dtypes.int64)
|
||||
|
||||
est = estimator_lib.Estimator(model_fn=self._model_fn)
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
est.train(
|
||||
_input_fn, steps=2, hooks=[self._build_iterator_saver_hook(est)])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
||||
|
|
@@ -223,13 +223,8 @@ py_strict_library(
    ],
    srcs_version = "PY3",
    deps = [
        "//tensorflow/python/checkpoint:checkpoint_management",
        "//tensorflow/python/data/ops:iterator_ops",
        "//tensorflow/python/data/ops:options",
        "//tensorflow/python/framework:ops",
        "//tensorflow/python/training:basic_session_run_hooks",
        "//tensorflow/python/training:saver",
        "//tensorflow/python/training:session_run_hook",
        "//tensorflow/python/util:deprecation",
        "//tensorflow/python/util:tf_export",
    ],
@@ -89,44 +89,6 @@ def get_single_element(dataset):
        signatures={'serving_default': preprocessing_model.serving_fn})
  ```

  # Estimator

  In the case of estimators, you generally need to define a `serving_input_fn`
  which would require the features to be processed by the model while
  inferencing.

  ```python
  def serving_input_fn():

    raw_feature_spec = ...  # Spec for the raw_features
    input_fn = tf.estimator.export.build_parsing_serving_input_receiver_fn(
        raw_feature_spec, default_batch_size=None)
    serving_input_receiver = input_fn()
    raw_features = serving_input_receiver.features

    def preprocessing_fn(raw_feature):
      # ... the raw_feature is preprocessed as per the use-case
      return feature

    dataset = (tf.data.Dataset.from_tensor_slices(raw_features)
               .map(preprocessing_fn, num_parallel_calls=BATCH_SIZE)
               .batch(BATCH_SIZE))

    processed_features = tf.data.experimental.get_single_element(dataset)

    # Please note that the value of `BATCH_SIZE` should be equal to
    # the size of the leading dimension of `raw_features`. This ensures
    # that `dataset` has only one element, which is a pre-requisite for
    # using `tf.data.experimental.get_single_element(dataset)`.

    return tf.estimator.export.ServingInputReceiver(
        processed_features, serving_input_receiver.receiver_tensors)

  estimator = ...  # A pre-built or custom estimator
  estimator.export_saved_model(your_exported_model_dir, serving_input_fn)
  ```

  Args:
    dataset: A `tf.data.Dataset` object containing a single element.
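
The Estimator-flavoured serving example above is dropped; the behaviour of `get_single_element` itself is unchanged. A minimal sketch of the remaining, non-Estimator usage (illustrative only):

```python
# Minimal sketch: reduce a single-element dataset to its one element.
import tensorflow as tf

dataset = tf.data.Dataset.from_tensors([1, 2, 3])
element = tf.data.experimental.get_single_element(dataset)
print(element.numpy())  # [1 2 3]
```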
@@ -14,13 +14,8 @@
# ==============================================================================
"""Iterator ops."""

from tensorflow.python.checkpoint import checkpoint_management
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.data.ops import options as options_lib
from tensorflow.python.framework import ops
from tensorflow.python.training import basic_session_run_hooks
from tensorflow.python.training import saver as saver_lib
from tensorflow.python.training import session_run_hook
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export

@ -100,222 +95,3 @@ def make_saveable_from_iterator(iterator, external_state_policy=None):
|
|||
iterator._iterator_resource, # pylint: disable=protected-access
|
||||
iterator._iterator_resource.name, # pylint: disable=protected-access
|
||||
external_state_policy=policy_enum)
|
||||
|
||||
|
||||
@tf_export("data.experimental.CheckpointInputPipelineHook")
|
||||
class CheckpointInputPipelineHook(session_run_hook.SessionRunHook):
|
||||
"""Checkpoints input pipeline state every N steps or seconds.
|
||||
|
||||
This hook saves the state of the iterators in the `Graph` so that when
|
||||
training is resumed the input pipeline continues from where it left off.
|
||||
This could potentially avoid overfitting in certain pipelines where the
|
||||
number of training steps per eval are small compared to the dataset
|
||||
size or if the training pipeline is pre-empted.
|
||||
|
||||
Differences from `CheckpointSaverHook`:
|
||||
1. Saves only the input pipelines in the "iterators" collection and not the
|
||||
global variables or other saveable objects.
|
||||
2. Does not write the `GraphDef` and `MetaGraphDef` to the summary.
|
||||
|
||||
Example of checkpointing the training pipeline:
|
||||
|
||||
```python
|
||||
est = tf.estimator.Estimator(model_fn)
|
||||
while True:
|
||||
est.train(
|
||||
train_input_fn,
|
||||
hooks=[tf.data.experimental.CheckpointInputPipelineHook(est)],
|
||||
steps=train_steps_per_eval)
|
||||
# Note: We do not pass the hook here.
|
||||
metrics = est.evaluate(eval_input_fn)
|
||||
if should_stop_the_training(metrics):
|
||||
break
|
||||
```
|
||||
|
||||
This hook should be used if the input pipeline state needs to be saved
|
||||
separate from the model checkpoint. Doing so may be useful for a few reasons:
|
||||
1. The input pipeline checkpoint may be large, if there are large shuffle
|
||||
or prefetch buffers for instance, and may bloat the checkpoint size.
|
||||
2. If the input pipeline is shared between training and validation, restoring
|
||||
the checkpoint during validation may override the validation input
|
||||
pipeline.
|
||||
|
||||
For saving the input pipeline checkpoint alongside the model weights use
|
||||
`tf.data.experimental.make_saveable_from_iterator` directly to create a
|
||||
`SaveableObject` and add to the `SAVEABLE_OBJECTS` collection. Note, however,
|
||||
that you will need to be careful not to restore the training iterator during
|
||||
eval. You can do that by not adding the iterator to the SAVEABLE_OBJECTS
|
||||
collector when building the eval graph.
|
||||
"""
|
||||
|
||||
def __init__(self, estimator, external_state_policy=None):
|
||||
"""Initializes a `CheckpointInputPipelineHook`.
|
||||
|
||||
If the input pipeline depends on external state (e.g. seeds for
|
||||
RandomUniform) beyond the input pipeline, this hook would be unable to
|
||||
serialize and deserialize that state. If its acceptable to ignore that state
|
||||
change the external_state_policy argument to 'warn' or 'ignore'. For e.g.
|
||||
|
||||
```python
|
||||
est = tf.estimator.Estimator(model_fn)
|
||||
while True:
|
||||
est.train(
|
||||
train_input_fn,
|
||||
hooks=[tf.data.experimental.CheckpointInputPipelineHook(
|
||||
est, external_state_policy='warn')],
|
||||
steps=train_steps_per_eval)
|
||||
# Note: We do not pass the hook here.
|
||||
metrics = est.evaluate(eval_input_fn)
|
||||
if should_stop_the_training(metrics):
|
||||
break
|
||||
```
|
||||
|
||||
Args:
|
||||
estimator: Estimator.
|
||||
external_state_policy: A string that identifies how to handle input
|
||||
pipelines that depend on external state. Possible values are
|
||||
'ignore': The external state is silently ignored.
|
||||
'warn': The external state is ignored, logging a warning.
|
||||
'fail': The operation fails upon encountering external state.
|
||||
By default we set it to 'fail'.
|
||||
|
||||
Raises:
|
||||
ValueError: One of `save_steps` or `save_secs` should be set.
|
||||
ValueError: At most one of saver or scaffold should be set.
|
||||
ValueError: If `external_state_policy` is not one of 'warn', 'ignore' or
|
||||
'fail'.
|
||||
"""
|
||||
if external_state_policy is None:
|
||||
external_state_policy = "fail"
|
||||
self._external_state_policy = _convert_external_state_policy_to_enum(
|
||||
external_state_policy)
|
||||
# `checkpoint_basename` is "input.ckpt" for non-distributed pipelines or
|
||||
# of the form "input_<task_type>_<task_id>.ckpt" for distributed pipelines.
|
||||
# Note: The default `checkpoint_basename` used by `CheckpointSaverHook` is
|
||||
# "model.ckpt". We intentionally choose the input pipeline checkpoint prefix
|
||||
# to be different to avoid conflicts with the model checkpoint.
|
||||
|
||||
# pylint: disable=protected-access
|
||||
checkpoint_prefix = "input"
|
||||
if estimator._config.num_worker_replicas > 1:
|
||||
# Distributed setting.
|
||||
suffix = "_{}_{}".format(estimator._config.task_type,
|
||||
estimator._config.task_id)
|
||||
checkpoint_prefix += suffix
|
||||
# pylint: enable=protected-access
|
||||
|
||||
# We use a composition paradigm instead of inheriting from
|
||||
# `CheckpointSaverHook` because `Estimator` does an `isinstance` check
|
||||
# to check whether a `CheckpointSaverHook` is already present in the list
|
||||
# of hooks and if not, adds one. Inheriting from `CheckpointSaverHook`
|
||||
# would thwart this behavior. This hook checkpoints *only the iterators*
|
||||
# and not the graph variables.
|
||||
self._checkpoint_saver_hook = basic_session_run_hooks.CheckpointSaverHook(
|
||||
estimator.model_dir,
|
||||
save_secs=estimator._config.save_checkpoints_secs, # pylint: disable=protected-access
|
||||
save_steps=estimator._config.save_checkpoints_steps, # pylint: disable=protected-access
|
||||
checkpoint_basename=checkpoint_prefix + ".ckpt")
|
||||
|
||||
# Name for the protocol buffer file that will contain the list of most
|
||||
# recent checkpoints stored as a `CheckpointState` protocol buffer.
|
||||
# This file, kept in the same directory as the checkpoint files, is
|
||||
# automatically managed by the `Saver` to keep track of recent checkpoints.
|
||||
# The default name used by the `Saver` for this file is "checkpoint". Here
|
||||
# we use the name "checkpoint_<checkpoint_prefix>" so that in case the
|
||||
# `checkpoint_dir` is the same as the model checkpoint directory, there are
|
||||
# no conflicts during restore.
|
||||
self._latest_filename = "checkpoint_" + checkpoint_prefix
|
||||
|
||||
def begin(self):
|
||||
# Build a Saver that saves all iterators in the `GLOBAL_ITERATORS`
|
||||
# collection if no `Saver` or `Scaffold` is provided.
|
||||
# pylint: disable=protected-access
|
||||
if (self._checkpoint_saver_hook._saver is None and
|
||||
self._checkpoint_saver_hook._scaffold is None):
|
||||
iterators = ops.get_collection(iterator_ops.GLOBAL_ITERATORS)
|
||||
saveables = [
|
||||
iterator_ops._IteratorSaveable(
|
||||
i, i.name, external_state_policy=self._external_state_policy)
|
||||
for i in iterators
|
||||
]
|
||||
self._checkpoint_saver_hook._saver = _CustomSaver(
|
||||
saveables, self._latest_filename, sharded=True)
|
||||
# pylint: enable=protected-access
|
||||
self._checkpoint_saver_hook.begin()
|
||||
|
||||
def after_create_session(self, session, coord):
|
||||
# If a new session was created, we set _first_run to True so that we can
|
||||
# restore if needed.
|
||||
self._first_run = True
|
||||
|
||||
def _restore_or_save_initial_ckpt(self, session):
|
||||
# Ideally this should be run in after_create_session but is not for the
|
||||
# following reason:
|
||||
# Currently there is no way of enforcing an order of running the
|
||||
# `SessionRunHooks`. Hence it is possible that the `_DatasetInitializerHook`
|
||||
# is run *after* this hook. That is troublesome because
|
||||
# 1. If a checkpoint exists and this hook restores it, the initializer hook
|
||||
# will override it.
|
||||
# 2. If no checkpoint exists, this hook will try to save an uninitialized
|
||||
# iterator which will result in an exception.
|
||||
#
|
||||
# As a temporary fix we enter the following implicit contract between this
|
||||
# hook and the _DatasetInitializerHook.
|
||||
# 1. The _DatasetInitializerHook initializes the iterator in the call to
|
||||
# after_create_session.
|
||||
# 2. This hook saves the iterator on the first call to `before_run()`, which
|
||||
# is guaranteed to happen after `after_create_session()` of all hooks
|
||||
# have been run.
|
||||
|
||||
# Check if there is an existing checkpoint. If so, restore from it.
|
||||
# pylint: disable=protected-access
|
||||
latest_checkpoint_path = checkpoint_management.latest_checkpoint(
|
||||
self._checkpoint_saver_hook._checkpoint_dir,
|
||||
latest_filename=self._latest_filename)
|
||||
if latest_checkpoint_path:
|
||||
self._checkpoint_saver_hook._get_saver().restore(session,
|
||||
latest_checkpoint_path)
|
||||
else:
|
||||
# The checkpoint saved here is the state at step "global_step".
|
||||
# Note: We do not save the GraphDef or MetaGraphDef here.
|
||||
global_step = session.run(self._checkpoint_saver_hook._global_step_tensor)
|
||||
self._checkpoint_saver_hook._save(session, global_step)
|
||||
self._checkpoint_saver_hook._timer.update_last_triggered_step(global_step)
|
||||
# pylint: enable=protected-access
|
||||
|
||||
def before_run(self, run_context):
|
||||
if self._first_run:
|
||||
self._restore_or_save_initial_ckpt(run_context.session)
|
||||
self._first_run = False
|
||||
return self._checkpoint_saver_hook.before_run(run_context)
|
||||
|
||||
def after_run(self, run_context, run_values):
|
||||
self._checkpoint_saver_hook.after_run(run_context, run_values)
|
||||
|
||||
def end(self, session):
|
||||
self._checkpoint_saver_hook.end(session)
|
||||
|
||||
|
||||
class _CustomSaver(saver_lib.Saver):
|
||||
"""`Saver` with a different default `latest_filename`.
|
||||
|
||||
This is used in the `CheckpointInputPipelineHook` to avoid conflicts with
|
||||
the model ckpt saved by the `CheckpointSaverHook`.
|
||||
"""
|
||||
|
||||
def __init__(self, var_list, latest_filename, sharded=False):
|
||||
super(_CustomSaver, self).__init__(var_list, sharded=sharded)
|
||||
self._latest_filename = latest_filename
|
||||
|
||||
def save(self,
|
||||
sess,
|
||||
save_path,
|
||||
global_step=None,
|
||||
latest_filename=None,
|
||||
meta_graph_suffix="meta",
|
||||
write_meta_graph=True,
|
||||
write_state=True,
|
||||
strip_default_attrs=False):
|
||||
return super(_CustomSaver, self).save(
|
||||
sess, save_path, global_step, latest_filename or self._latest_filename,
|
||||
meta_graph_suffix, write_meta_graph, write_state, strip_default_attrs)
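
The hook and custom saver above checkpoint iterator state in graph mode. For comparison only, here is a minimal, hedged TF2-style sketch of checkpointing input-pipeline state directly with `tf.train.Checkpoint`; the directory path and the toy `range` dataset are placeholders, not taken from this code:

```python
import tensorflow as tf

dataset = tf.data.Dataset.range(10)
iterator = iter(dataset)

# tf.train.Checkpoint can track the iterator itself and serialize its position.
ckpt = tf.train.Checkpoint(iterator=iterator)
manager = tf.train.CheckpointManager(ckpt, "/tmp/input_pipeline_ckpt", max_to_keep=1)

print(next(iterator).numpy())  # 0
print(next(iterator).numpy())  # 1
manager.save()                 # snapshot taken after two elements were consumed

print(next(iterator).numpy())  # 2
ckpt.restore(manager.latest_checkpoint)
print(next(iterator).numpy())  # 2 again: the iterator resumed from the snapshot
```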
@ -421,9 +421,7 @@ def make_csv_dataset_v2(
|
|||
index.
|
||||
label_name: An optional string corresponding to the label column. If
|
||||
provided, the data for this column is returned as a separate `Tensor` from
|
||||
the features dictionary, so that the dataset complies with the format
|
||||
expected by a `tf.Estimator.train` or `tf.Estimator.evaluate` input
|
||||
function.
|
||||
the features dictionary.
|
||||
select_columns: An optional list of integer indices or string column
|
||||
names, that specifies a subset of columns of CSV data to select. If
|
||||
column names are provided, these must correspond to names provided in
|
||||
|
|
|
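
To make the `label_name` behavior described in this hunk concrete, here is a small, hedged sketch of `tf.data.experimental.make_csv_dataset`; the file name, column names, and batch size are invented for illustration:

```python
import tensorflow as tf

# Assumes a CSV file "train.csv" with columns: age, income, target.
dataset = tf.data.experimental.make_csv_dataset(
    "train.csv",
    batch_size=32,
    label_name="target",  # returned as a separate tensor, not in the features dict
    num_epochs=1)

for features, label in dataset.take(1):
  print(sorted(features.keys()))  # ['age', 'income']
  print(label.shape)              # (32,)
```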
|||
|
|
@ -402,7 +402,7 @@ class DatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||
dataset = dataset_ops.Dataset.range(10)
|
||||
with ops.Graph().as_default():
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
"make sure that the dataset is created in "
|
||||
"Make sure that the dataset is created in "
|
||||
"the same graph as the iterator"):
|
||||
_ = dataset_ops.make_one_shot_iterator(dataset)
|
||||
|
||||
|
|
@ -412,7 +412,7 @@ class DatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||
dataset = dataset_ops.Dataset.range(10)
|
||||
with ops.Graph().as_default():
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
"make sure that the dataset is created in "
|
||||
"Make sure that the dataset is created in "
|
||||
"the same graph as the iterator"):
|
||||
_ = dataset_ops.make_initializable_iterator(dataset)
@ -2867,44 +2867,6 @@ name=None))
|
|||
)
|
||||
```
|
||||
|
||||
#### Estimator

In the case of estimators, you generally need to define a `serving_input_fn`
which requires the features to be processed by the model during inference.

```python
def serving_input_fn():

  raw_feature_spec = ...  # Spec for the raw_features
  input_fn = tf.estimator.export.build_parsing_serving_input_receiver_fn(
      raw_feature_spec, default_batch_size=None)
  serving_input_receiver = input_fn()
  raw_features = serving_input_receiver.features

  def preprocessing_fn(raw_feature):
    # ... the raw_feature is preprocessed as per the use-case
    return feature

  dataset = (tf.data.Dataset.from_tensor_slices(raw_features)
             .map(preprocessing_fn, num_parallel_calls=BATCH_SIZE)
             .batch(BATCH_SIZE))

  processed_features = dataset.get_single_element()

  # Please note that the value of `BATCH_SIZE` should be equal to
  # the size of the leading dimension of `raw_features`. This ensures
  # that `dataset` has only one element, which is a pre-requisite for
  # using `dataset.get_single_element()`.

  return tf.estimator.export.ServingInputReceiver(
      processed_features, serving_input_receiver.receiver_tensors)

estimator = ...  # A pre-built or custom estimator
estimator.export_saved_model(your_exported_model_dir, serving_input_fn)
```

Args:
  name: (Optional.) A name for the tf.data operation.
@ -4266,10 +4228,8 @@ def _ensure_same_dataset_graph(dataset):
|
|||
raise ValueError(
|
||||
f"The graph {current_graph} of the iterator is different from the "
|
||||
f"graph {ds_graph} the dataset: {ds._variant_tensor} was created in. "
|
||||
f"If you are using the Estimator API, make sure that no part of the "
|
||||
f"dataset returned by the `input_fn` function is defined outside the "
|
||||
f"`input_fn` function. Otherwise, make sure that the dataset is "
|
||||
f"created in the same graph as the iterator.")
|
||||
f"Make sure that the dataset is created in the same graph as the "
|
||||
f"iterator.")
|
||||
for input_ds in ds._inputs():
|
||||
if input_ds not in visited:
|
||||
bfs_q.put(input_ds)
|
||||
|
|
|
|||
|
|
@ -29,20 +29,6 @@ py_strict_binary(
|
|||
],
|
||||
)
|
||||
|
||||
py_strict_binary(
|
||||
name = "debug_tflearn_iris",
|
||||
srcs = ["debug_tflearn_iris.py"],
|
||||
python_version = "PY3",
|
||||
srcs_version = "PY3",
|
||||
deps = [
|
||||
"//tensorflow:tensorflow_py",
|
||||
# copybara:uncomment_begin(google-only)
|
||||
# "//third_party/py/tensorflow:tensorflow_compat_v1_estimator", # build_cleaner:keep
|
||||
# copybara:uncomment_end
|
||||
"//tensorflow/python/debug:debug_py",
|
||||
],
|
||||
)
|
||||
|
||||
py_strict_binary(
|
||||
name = "debug_keras",
|
||||
srcs = ["debug_keras.py"],
|
||||
|
|
@ -119,19 +105,6 @@ sh_test(
|
|||
],
|
||||
)
|
||||
|
||||
sh_test(
|
||||
name = "examples_v1_debug_tflearn_iris_test",
|
||||
srcs = ["examples_v1_debug_tflearn_iris_test.sh"],
|
||||
data = [
|
||||
":debug_tflearn_iris",
|
||||
],
|
||||
tags = [
|
||||
"no_windows",
|
||||
"noasan",
|
||||
"v1only",
|
||||
],
|
||||
)
|
||||
|
||||
sh_test(
|
||||
name = "examples_v1_offline_analyzer_test",
|
||||
srcs = ["examples_v1_offline_analyzer_test.sh"],
|
||||
|
|
|
|||
|
|
@ -1,145 +0,0 @@
|
|||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Debug the tf-learn iris example, based on the tf-learn tutorial."""
|
||||
import argparse
|
||||
import sys
|
||||
import tempfile
|
||||
|
||||
import tensorflow
|
||||
|
||||
from tensorflow.python import debug as tf_debug
|
||||
|
||||
tf = tensorflow.compat.v1
|
||||
|
||||
_IRIS_INPUT_DIM = 4
|
||||
|
||||
|
||||
def main(_):
|
||||
# Generate some fake Iris data.
|
||||
# It is okay for this example because this example is about how to use the
|
||||
# debugger, not how to use machine learning to solve the Iris classification
|
||||
# problem.
|
||||
def training_input_fn():
|
||||
return ({
|
||||
"features": tf.random_normal([128, 4])
|
||||
}, tf.random_uniform([128], minval=0, maxval=3, dtype=tf.int32))
|
||||
|
||||
def test_input_fn():
|
||||
return ({
|
||||
"features": tf.random_normal([32, 4])
|
||||
}, tf.random_uniform([32], minval=0, maxval=3, dtype=tf.int32))
|
||||
|
||||
feature_columns = [tf.feature_column.numeric_column("features", shape=(4,))]
|
||||
|
||||
# Build 3 layer DNN with 10, 20, 10 units respectively.
|
||||
model_dir = FLAGS.model_dir or tempfile.mkdtemp(prefix="debug_tflearn_iris_")
|
||||
|
||||
classifier = tf.estimator.DNNClassifier(
|
||||
feature_columns=feature_columns,
|
||||
hidden_units=[10, 20, 10],
|
||||
n_classes=3,
|
||||
model_dir=model_dir)
|
||||
|
||||
if FLAGS.debug and FLAGS.tensorboard_debug_address:
|
||||
raise ValueError(
|
||||
"The --debug and --tensorboard_debug_address flags are mutually "
|
||||
"exclusive.")
|
||||
hooks = []
|
||||
if FLAGS.debug:
|
||||
if FLAGS.use_random_config_path:
|
||||
_, config_file_path = tempfile.mkstemp(".tfdbg_config")
|
||||
else:
|
||||
config_file_path = None
|
||||
hooks.append(
|
||||
tf_debug.LocalCLIDebugHook(
|
||||
ui_type=FLAGS.ui_type,
|
||||
dump_root=FLAGS.dump_root,
|
||||
config_file_path=config_file_path))
|
||||
elif FLAGS.tensorboard_debug_address:
|
||||
hooks.append(tf_debug.TensorBoardDebugHook(FLAGS.tensorboard_debug_address))
|
||||
|
||||
# Train model, using tfdbg hook.
|
||||
classifier.train(training_input_fn, steps=FLAGS.train_steps, hooks=hooks)
|
||||
|
||||
# Evaluate accuracy, using tfdbg hook.
|
||||
accuracy_score = classifier.evaluate(
|
||||
test_input_fn, steps=FLAGS.eval_steps, hooks=hooks)["accuracy"]
|
||||
|
||||
print("After training %d steps, Accuracy = %f" %
|
||||
(FLAGS.train_steps, accuracy_score))
|
||||
|
||||
# Make predictions, using tfdbg hook.
|
||||
predict_results = classifier.predict(test_input_fn, hooks=hooks)
|
||||
print("A prediction result: %s" % next(predict_results))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.register("type", "bool", lambda v: v.lower() == "true")
|
||||
parser.add_argument(
|
||||
"--data_dir",
|
||||
type=str,
|
||||
default="/tmp/iris_data",
|
||||
help="Directory to save the training and test data in.")
|
||||
parser.add_argument(
|
||||
"--model_dir",
|
||||
type=str,
|
||||
default="",
|
||||
help="Directory to save the trained model in.")
|
||||
parser.add_argument(
|
||||
"--train_steps",
|
||||
type=int,
|
||||
default=10,
|
||||
help="Number of steps to run training for.")
|
||||
parser.add_argument(
|
||||
"--eval_steps",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of steps to run evaluation foir.")
|
||||
parser.add_argument(
|
||||
"--ui_type",
|
||||
type=str,
|
||||
default="readline",
|
||||
help="Command-line user interface type (only readline is supported)")
|
||||
parser.add_argument(
|
||||
"--debug",
|
||||
type="bool",
|
||||
nargs="?",
|
||||
const=True,
|
||||
default=False,
|
||||
help="Use debugger to track down bad values during training. "
|
||||
"Mutually exclusive with the --tensorboard_debug_address flag.")
|
||||
parser.add_argument(
|
||||
"--dump_root",
|
||||
type=str,
|
||||
default="",
|
||||
help="Optional custom root directory for temporary debug dump data")
|
||||
parser.add_argument(
|
||||
"--use_random_config_path",
|
||||
type="bool",
|
||||
nargs="?",
|
||||
const=True,
|
||||
default=False,
|
||||
help="""If set, set config file path to a random file in the temporary
|
||||
directory.""")
|
||||
parser.add_argument(
|
||||
"--tensorboard_debug_address",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Connect to the TensorBoard Debugger Plugin backend specified by "
|
||||
"the gRPC address (e.g., localhost:1234). Mutually exclusive with the "
|
||||
"--debug flag.")
|
||||
FLAGS, unparsed = parser.parse_known_args()
|
||||
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
@ -1,70 +0,0 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
#
|
||||
# Bash unit tests for TensorFlow Debugger (tfdbg) Python examples that do not
|
||||
# involve downloading data.
|
||||
#
|
||||
# Command-line flags:
|
||||
# --virtualenv: (optional) If set, will test the examples and binaries
|
||||
# against pip install of TensorFlow in a virtualenv.
|
||||
|
||||
set -e
|
||||
|
||||
# Filter out LOG(INFO)
|
||||
export TF_CPP_MIN_LOG_LEVEL=1
|
||||
|
||||
IS_VIRTUALENV=0
|
||||
PYTHON_BIN_PATH=""
|
||||
while true; do
|
||||
if [[ -z "$1" ]]; then
|
||||
break
|
||||
elif [[ "$1" == "--virtualenv" ]]; then
|
||||
IS_VIRTUALENV=1
|
||||
PYTHON_BIN_PATH=$(which python)
|
||||
echo
|
||||
echo "IS_VIRTUALENV = ${IS_VIRTUALENV}"
|
||||
echo "PYTHON_BIN_PATH = ${PYTHON_BIN_PATH}"
|
||||
echo "Will test tfdbg debug_tflearn_iris against virtualenv pip install."
|
||||
echo
|
||||
fi
|
||||
shift 1
|
||||
done
|
||||
|
||||
if [[ -z "${PYTHON_BIN_PATH}" ]]; then
|
||||
DEBUG_TFLEARN_IRIS_BIN="$TEST_SRCDIR/org_tensorflow/tensorflow/python/debug/examples/v1/debug_tflearn_iris"
|
||||
else
|
||||
DEBUG_TFLEARN_IRIS_BIN="${PYTHON_BIN_PATH} -m tensorflow.python.debug.examples.v1.debug_tflearn_iris"
|
||||
fi
|
||||
|
||||
# Test the custom dump_root option.
|
||||
CUSTOM_DUMP_ROOT=$(mktemp -d)
|
||||
mkdir -p ${CUSTOM_DUMP_ROOT}
|
||||
|
||||
# Override the default ui_type=curses to allow the test to pass in a tty-less
|
||||
# test environment.
|
||||
cat << EOF | ${DEBUG_TFLEARN_IRIS_BIN} --debug --train_steps=2 --dump_root="${CUSTOM_DUMP_ROOT}" --ui_type=readline --use_random_config_path
|
||||
run -p
|
||||
run -f has_inf_or_nan
|
||||
EOF
|
||||
|
||||
# Verify that the dump root has been cleaned up on exit.
|
||||
if [[ -d "${CUSTOM_DUMP_ROOT}" ]]; then
|
||||
echo "ERROR: dump root at ${CUSTOM_DUMP_ROOT} failed to be cleaned up." 1>&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo
|
||||
echo "SUCCESS: tfdbg debug_tflearn_iris test PASSED"
@ -26,10 +26,7 @@ from tensorflow.python.training import session_run_hook
|
|||
class LocalCLIDebugHook(session_run_hook.SessionRunHook):
|
||||
"""Command-line-interface debugger hook.
|
||||
|
||||
Can be used as a hook for `tf.compat.v1.train.MonitoredSession`s and
|
||||
`tf.estimator.Estimator`s. Provides a substitute for
|
||||
`tfdbg.LocalCLIDebugWrapperSession` in cases where the session is not directly
|
||||
available.
|
||||
Can be used as a hook for `tf.compat.v1.train.MonitoredSession`.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
|
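
Since the docstrings in this hunk now mention only the `MonitoredSession` use case, here is a minimal, hedged sketch of attaching `LocalCLIDebugHook` to a `tf.compat.v1.train.MonitoredTrainingSession`; the toy variable and the `readline` UI choice are illustrative assumptions, not part of this change:

```python
import tensorflow as tf
from tensorflow.python import debug as tf_debug

tf.compat.v1.disable_eager_execution()

# A trivial v1 graph: a counter variable and an op that increments it.
x = tf.compat.v1.get_variable("x", initializer=1.0)
train_op = tf.compat.v1.assign_add(x, 1.0)

hooks = [tf_debug.LocalCLIDebugHook(ui_type="readline")]
with tf.compat.v1.train.MonitoredTrainingSession(hooks=hooks) as sess:
  sess.run(train_op)  # the hook pauses in the tfdbg CLI around this run
```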
|
@ -147,8 +144,7 @@ class LocalCLIDebugHook(session_run_hook.SessionRunHook):
|
|||
class DumpingDebugHook(session_run_hook.SessionRunHook):
|
||||
"""A debugger hook that dumps debug data to filesystem.
|
||||
|
||||
Can be used as a hook for `tf.compat.v1.train.MonitoredSession`s and
|
||||
`tf.estimator.Estimator`s.
|
||||
Can be used as a hook for `tf.compat.v1.train.MonitoredSession`.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
|
|
@ -222,8 +218,7 @@ class GrpcDebugHook(session_run_hook.SessionRunHook):
|
|||
When the arguments of debug_utils.watch_graph change, strongly consider
|
||||
changing arguments here too so that features are available to tflearn users.
|
||||
|
||||
Can be used as a hook for `tf.compat.v1.train.MonitoredSession`s and
|
||||
`tf.estimator.Estimator`s.
|
||||
Can be used as a hook for `tf.compat.v1.train.MonitoredSession`.
|
||||
"""
|
||||
|
||||
def __init__(self,
@ -827,23 +827,6 @@ tpu_py_strict_test(
|
|||
],
|
||||
)
|
||||
|
||||
# Used only by estimator.
|
||||
py_strict_library(
|
||||
name = "estimator_training",
|
||||
srcs = [
|
||||
"estimator_training.py",
|
||||
],
|
||||
srcs_version = "PY3",
|
||||
deps = [
|
||||
":distribute_coordinator",
|
||||
":distribute_coordinator_context",
|
||||
":multi_worker_util",
|
||||
"//tensorflow/python/platform:tf_logging",
|
||||
"//tensorflow/python/training:server_lib",
|
||||
"@six_archive//:six",
|
||||
],
|
||||
)
|
||||
|
||||
py_strict_library(
|
||||
name = "reduce_util",
|
||||
srcs = ["reduce_util.py"],
|
||||
|
|
@ -1756,14 +1739,9 @@ distribute_py_strict_test(
|
|||
":distribute_utils",
|
||||
":strategy_combinations",
|
||||
":values",
|
||||
"//tensorflow/python/eager:context",
|
||||
"//tensorflow/python/eager:test",
|
||||
"//tensorflow/python/framework:constant_op",
|
||||
"//tensorflow/python/framework:ops",
|
||||
"//tensorflow/python/ops:array_ops",
|
||||
"//tensorflow/python/ops:variable_scope",
|
||||
"//tensorflow/python/ops:variable_v1",
|
||||
"//tensorflow/python/saved_model/model_utils:mode_keys",
|
||||
"@absl_py//absl/testing:parameterized",
|
||||
"@pypi_wrapt//:pkg",
|
||||
],
|
||||
|
|
@ -2258,7 +2236,6 @@ cuda_py_strict_test(
|
|||
"//tensorflow/python/distribute/v1:input_lib",
|
||||
"//tensorflow/python/eager:backprop",
|
||||
"//tensorflow/python/eager:context",
|
||||
"//tensorflow/python/estimator:run_config",
|
||||
"//tensorflow/python/framework:constant_op",
|
||||
"//tensorflow/python/framework:device",
|
||||
"//tensorflow/python/framework:dtypes",
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ tf.distribute.Strategy is a TensorFlow API to distribute training across
|
|||
multiple GPUs, multiple machines or TPUs. Using this API, users can distribute
|
||||
their existing models and training code with minimal code changes.
|
||||
|
||||
It can be used with TensorFlow's high level APIs, tf.keras and tf.estimator,
|
||||
It can be used with TensorFlow's high level APIs, like tf.keras,
|
||||
with just a couple of lines of code change. It does so by changing the
|
||||
underlying components of TensorFlow to become strategy-aware.
|
||||
This includes variables, layers, models, optimizers, metrics, summaries,
|
||||
|
|
@ -22,8 +22,6 @@ and checkpoints.
|
|||
|
||||
[Multiworker Training With Keras Tutorial](https://www.tensorflow.org/tutorials/distribute/multi_worker_with_keras)
|
||||
|
||||
[Multiworker Training With Estimator Tutorial](https://www.tensorflow.org/tutorials/distribute/multi_worker_with_estimator)
|
||||
|
||||
[Save and Load with Distribution Strategy](https://www.tensorflow.org/tutorials/distribute/save_and_load)
|
||||
|
||||
## Simple Examples
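
A minimal sketch in the spirit of the simple examples referenced here, assuming a toy Keras model and random data (neither is taken from this repository); wrapping model construction and `compile` in the strategy scope is the "couple of lines" change described above:

```python
import numpy as np
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

# Variables created inside the scope are mirrored across the available devices.
with strategy.scope():
  model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(10,))])
  model.compile(optimizer="sgd", loss="mse")

x = np.random.random((64, 10)).astype("float32")
y = np.random.random((64, 1)).astype("float32")
model.fit(x, y, epochs=1, batch_size=16)
```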
@ -4,7 +4,7 @@
|
|||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
|
|
@ -35,9 +35,8 @@ v1.x](https://github.com/tensorflow/docs/blob/master/site/en/r1/guide/distribute
|
|||
Tutorials](https://www.tensorflow.org/tutorials/distribute/)
|
||||
|
||||
The tutorials cover how to use `tf.distribute.Strategy` to do distributed
|
||||
training with native Keras APIs, custom training loops,
|
||||
and Estimator APIs. They also cover how to save/load model when using
|
||||
`tf.distribute.Strategy`.
|
||||
training with native Keras APIs, and custom training loops.
|
||||
They also cover how to save/load a model when using `tf.distribute.Strategy`.
|
||||
|
||||
*Glossary*
|
||||
|
||||
|
|
|
|||
|
|
@ -1001,10 +1001,7 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
|
|||
# to limit the backward incompatibility.
|
||||
if hasattr(self, "_check_health_thread"):
|
||||
raise ValueError(
|
||||
"MultiWorkerMirroredStrategy cannot be deep copied in eager mode. "
|
||||
"If you're using Estimator and see this error message, call "
|
||||
"tf.compat.v1.disable_eager_execution() at the beginning of your "
|
||||
"program")
|
||||
"MultiWorkerMirroredStrategy cannot be deep copied in eager mode.")
|
||||
# Otherwise, do a regular deepcopy.
|
||||
cls = self.__class__
|
||||
result = cls.__new__(cls)
|
||||
|
|
|
|||
|
|
@ -755,7 +755,6 @@ class ExperimentalCompatibilityTest(test.TestCase):
|
|||
_CollectiveAllReduceStrategyExperimental)
|
||||
|
||||
def testName(self):
|
||||
# Estimator checks the __name__ to special case MultiWorkerMirroredStrategy.
|
||||
self.assertEqual(CollectiveAllReduceStrategy.__name__,
|
||||
'CollectiveAllReduceStrategy')
|
||||
self.assertEqual(_CollectiveAllReduceStrategyExperimental.__name__,
|
||||
|
|
|
|||
|
|
@ -36,9 +36,8 @@ v1.x](https://github.com/tensorflow/docs/blob/master/site/en/r1/guide/distribute
|
|||
Tutorials](https://www.tensorflow.org/tutorials/distribute/)
|
||||
|
||||
The tutorials cover how to use `tf.distribute.Strategy` to do distributed
|
||||
training with native Keras APIs, custom training loops,
|
||||
and Estimator APIs. They also cover how to save/load model when using
|
||||
`tf.distribute.Strategy`.
|
||||
training with native Keras APIs, and custom training loops.
|
||||
They also cover how to save/load a model when using `tf.distribute.Strategy`.
|
||||
|
||||
*Glossary*
|
||||
|
||||
|
|
@ -1100,10 +1099,6 @@ class StrategyBase(object):
|
|||
* To use it with Keras `compile`/`fit`,
|
||||
[please
|
||||
read](https://www.tensorflow.org/guide/distributed_training#using_tfdistributestrategy_with_keras).
|
||||
* You may pass descendant of `tf.distribute.Strategy` to
|
||||
`tf.estimator.RunConfig` to specify how a `tf.estimator.Estimator`
|
||||
should distribute its computation. See
|
||||
[guide](https://www.tensorflow.org/guide/distributed_training#using_tfdistributestrategy_with_estimator_limited_support).
|
||||
* Otherwise, use `tf.distribute.Strategy.scope` to specify that a
|
||||
strategy should be used when building and executing your model.
|
||||
(This puts you in the "cross-replica context" for this strategy, which
|
||||
|
|
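
A hedged sketch of the `scope`/cross-replica pattern described in the bullets above, using an explicit step function rather than Keras `fit`; the layer, shapes, and dataset are invented for illustration:

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

with strategy.scope():
  dense = tf.keras.layers.Dense(1)
  dense.build((None, 4))  # create the strategy-aware weights up front

@tf.function
def step(x):
  def replica_fn(x):
    # Runs once per replica, in "replica context".
    return tf.reduce_sum(dense(x))
  per_replica = strategy.run(replica_fn, args=(x,))
  # Back in "cross-replica context": combine the per-replica results.
  return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica, axis=None)

dist_data = strategy.experimental_distribute_dataset(
    tf.data.Dataset.from_tensor_slices(tf.ones([8, 4])).batch(4))
for batch in dist_data:
  print(step(batch).numpy())
```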
@ -1181,9 +1176,6 @@ class StrategyBase(object):
|
|||
def __init__(self, extended):
|
||||
self._extended = extended
|
||||
|
||||
# Flag that is used to indicate whether distribution strategy is used with
|
||||
# Estimator. This is required for backward compatibility of loss scaling
|
||||
# when using v1 optimizer with estimator.
|
||||
self._scale_loss_for_estimator = False
|
||||
|
||||
if not hasattr(extended, "_retrace_functions_for_each_device"):
|
||||
|
|
|
|||
|
|
@ -40,9 +40,6 @@ from tensorflow.python.util.tf_export import tf_export
|
|||
def get_loss_reduction():
|
||||
"""`tf.distribute.ReduceOp` corresponding to the last loss reduction.
|
||||
|
||||
This is used to decide whether loss should be scaled in optimizer (used only
|
||||
for estimator + v1 optimizer use case).
|
||||
|
||||
Returns:
|
||||
`tf.distribute.ReduceOp` corresponding to the last loss reduction for
|
||||
estimator and v1 optimizer use case. `tf.distribute.ReduceOp.SUM` otherwise.
|
||||
|
|
|
|||
|
|
@ -24,14 +24,9 @@ from tensorflow.python.distribute import combinations
|
|||
from tensorflow.python.distribute import distribute_utils
|
||||
from tensorflow.python.distribute import strategy_combinations
|
||||
from tensorflow.python.distribute import values
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import test
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import variable_scope
|
||||
from tensorflow.python.ops import variable_v1
|
||||
from tensorflow.python.saved_model.model_utils import mode_keys
|
||||
|
||||
|
||||
def _nested_value(d):
|
||||
|
|
@ -188,52 +183,6 @@ class RegroupAndSelectDeviceTest(test.TestCase, parameterized.TestCase):
|
|||
self.assertEqual(_nested_value("1"),
|
||||
distribute_utils.select_replica(0, result))
|
||||
|
||||
def testNamedTuple(self):
|
||||
|
||||
# We include toy implementations of Scaffold and EstimatorSpec to
|
||||
# avoid a dependency on Estimator here.
|
||||
|
||||
class Scaffold(object):
|
||||
pass
|
||||
|
||||
class EstimatorSpec(collections.namedtuple(
|
||||
"EstimatorSpec", ["mode", "loss", "train_op", "scaffold"])):
|
||||
|
||||
def __new__(cls, mode, loss, train_op, scaffold=None):
|
||||
return super(EstimatorSpec, cls).__new__(
|
||||
cls, mode=mode, loss=loss, train_op=train_op,
|
||||
scaffold=scaffold or Scaffold())
|
||||
|
||||
with context.graph_mode(), ops.Graph().as_default():
|
||||
created_estimator_specs = []
|
||||
|
||||
for device_id in range(3):
|
||||
spec = EstimatorSpec(
|
||||
mode=mode_keys.EstimatorModeKeys.TRAIN,
|
||||
loss=constant_op.constant(device_id / 2),
|
||||
train_op=array_ops.identity(constant_op.constant(device_id)))
|
||||
created_estimator_specs.append(spec)
|
||||
|
||||
merged_estimator_spec = distribute_utils.regroup(created_estimator_specs)
|
||||
|
||||
self.assertIsInstance(merged_estimator_spec, EstimatorSpec)
|
||||
self.assertEqual(mode_keys.EstimatorModeKeys.TRAIN,
|
||||
merged_estimator_spec.mode)
|
||||
for device_id in range(3):
|
||||
self.assertEqual(created_estimator_specs[device_id].loss,
|
||||
merged_estimator_spec.loss.values[device_id])
|
||||
self.assertEqual(created_estimator_specs[device_id].train_op,
|
||||
merged_estimator_spec.train_op.values[device_id])
|
||||
# Scaffold is populated by `EstimatorSpec.__new__`.
|
||||
self.assertEqual(created_estimator_specs[device_id].scaffold,
|
||||
merged_estimator_spec.scaffold.values[device_id])
|
||||
self.assertIsInstance(created_estimator_specs[device_id].scaffold,
|
||||
Scaffold)
|
||||
# Also test that we can undo the merge using select_replica()
|
||||
self.assertEqual(created_estimator_specs[device_id],
|
||||
distribute_utils.select_replica(
|
||||
device_id, merged_estimator_spec))
|
||||
|
||||
def testWrappedNamedTuple(self):
|
||||
Point = collections.namedtuple("Point", ["x", "y"])
|
||||
point1 = Point(x=0, y=2)
|
||||
|
|
|
|||
|
|
@ -1,387 +0,0 @@
|
|||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Training utilities for Estimator to use Distribute Coordinator."""
|
||||
|
||||
import copy
|
||||
|
||||
import six
|
||||
|
||||
from tensorflow.python.distribute import distribute_coordinator as dc
|
||||
from tensorflow.python.distribute import distribute_coordinator_context as dc_context
|
||||
from tensorflow.python.distribute import multi_worker_util
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.training import server_lib
|
||||
|
||||
# pylint: disable=protected-access
|
||||
CHIEF = dc._TaskType.CHIEF
|
||||
EVALUATOR = dc._TaskType.EVALUATOR
|
||||
PS = dc._TaskType.PS
|
||||
WORKER = dc._TaskType.WORKER
|
||||
|
||||
# pylint: enable=protected-access
|
||||
|
||||
|
||||
def _count_ps(cluster_spec):
|
||||
"""Counts the number of parameter servers in cluster_spec."""
|
||||
if not cluster_spec:
|
||||
raise RuntimeError(
|
||||
'Internal error: `_count_ps` does not expect empty cluster_spec.')
|
||||
|
||||
return len(cluster_spec.as_dict().get(PS, []))
|
||||
|
||||
|
||||
def _count_worker(cluster_spec, chief_task_type):
|
||||
"""Counts the number of workers (including chief) in cluster_spec."""
|
||||
if not cluster_spec:
|
||||
raise RuntimeError(
|
||||
'Internal error: `_count_worker` does not expect empty cluster_spec.')
|
||||
|
||||
return (len(cluster_spec.as_dict().get(WORKER, [])) + len(
|
||||
cluster_spec.as_dict().get(chief_task_type, [])))
|
||||
|
||||
|
||||
def _get_global_id(cluster_spec, task_type, task_id, chief_task_type):
|
||||
"""Returns the global id of the given task type in a cluster."""
|
||||
if not task_type:
|
||||
return 0
|
||||
|
||||
# Sort task names in cluster by "chief"/"master", "evaluator", "worker"
|
||||
# and "ps". More details can be found at the documentation of
|
||||
# `tf.estimator.RunConfig.global_id_in_cluster`.
|
||||
task_type_ordered_list = []
|
||||
if chief_task_type in cluster_spec.jobs:
|
||||
task_type_ordered_list = [chief_task_type]
|
||||
task_type_ordered_list.extend([
|
||||
t for t in sorted(cluster_spec.jobs) if t != chief_task_type and t != PS
|
||||
])
|
||||
if PS in cluster_spec.jobs:
|
||||
task_type_ordered_list.append(PS)
|
||||
|
||||
# Find the right global_id for current task.
|
||||
next_global_id = 0
|
||||
for t in task_type_ordered_list:
|
||||
if t == task_type:
|
||||
return next_global_id + task_id
|
||||
# `cluster_spec.job_tasks` returns all task addresses of type `t`.
|
||||
next_global_id += len(cluster_spec.job_tasks(t))
|
||||
|
||||
# It is unexpected that it passes through all task_types in
|
||||
# `task_type_ordered_list`.
|
||||
raise RuntimeError('Internal Error: `task_type` ({}) is not in '
|
||||
'cluster_spec ({}).'.format(task_type, cluster_spec))
|
||||
|
||||
|
||||
def _init_run_config_from_worker_context(config, worker_context):
|
||||
"""Initializes run config from distribute coordinator's worker context."""
|
||||
|
||||
# pylint: disable=protected-access
|
||||
config._service = None
|
||||
config._cluster_spec = worker_context.cluster_spec
|
||||
config._task_type = worker_context.task_type
|
||||
config._task_id = worker_context.task_id
|
||||
config._evaluation_master = worker_context.master_target
|
||||
config._master = worker_context.master_target
|
||||
config._is_chief = worker_context.is_chief
|
||||
|
||||
if config._cluster_spec:
|
||||
# Distributed mode.
|
||||
if config._task_type != EVALUATOR:
|
||||
|
||||
config._num_ps_replicas = _count_ps(config._cluster_spec)
|
||||
config._num_worker_replicas = _count_worker(
|
||||
config._cluster_spec, chief_task_type=CHIEF)
|
||||
config._global_id_in_cluster = _get_global_id(
|
||||
config._cluster_spec,
|
||||
config._task_type,
|
||||
config._task_id,
|
||||
chief_task_type=CHIEF)
|
||||
else:
|
||||
# Evaluator task should not be aware of the other tasks.
|
||||
config._cluster_spec = server_lib.ClusterSpec({})
|
||||
config._num_ps_replicas = 0
|
||||
config._num_worker_replicas = 0
|
||||
config._global_id_in_cluster = None # undefined
|
||||
else:
|
||||
# Local mode.
|
||||
config._global_id_in_cluster = 0
|
||||
config._num_ps_replicas = 0
|
||||
config._num_worker_replicas = 1
|
||||
|
||||
|
||||
def init_run_config(config, tf_config):
|
||||
"""Initializes RunConfig for distribution strategies."""
|
||||
# pylint: disable=protected-access
|
||||
if (config._experimental_distribute and
|
||||
config._experimental_distribute.train_distribute):
|
||||
if config._train_distribute:
|
||||
raise ValueError('Either `train_distribute` or'
|
||||
'`experimental_distribute.train_distribute` can be set.')
|
||||
config._train_distribute = config._experimental_distribute.train_distribute
|
||||
|
||||
if (config._experimental_distribute and
|
||||
config._experimental_distribute.eval_distribute):
|
||||
if config._eval_distribute:
|
||||
raise ValueError('Either `eval_distribute` or'
|
||||
'`experimental_distribute.eval_distribute` can be set.')
|
||||
config._eval_distribute = config._experimental_distribute.eval_distribute
|
||||
|
||||
cluster_spec = server_lib.ClusterSpec(tf_config.get('cluster', {}))
|
||||
config._init_distributed_setting_from_environment_var({})
|
||||
|
||||
# Use distribute coordinator with STANDALONE_CLIENT mode if
|
||||
# `experimental_distribute.remote_cluster` is set.
|
||||
if (config._train_distribute and config._experimental_distribute and
|
||||
config._experimental_distribute.remote_cluster):
|
||||
if cluster_spec:
|
||||
raise ValueError('Cannot set both "cluster_spec" of TF_CONFIG and '
|
||||
'`experimental_distribute.remote_cluster`')
|
||||
config._distribute_coordinator_mode = dc.CoordinatorMode.STANDALONE_CLIENT
|
||||
config._cluster_spec = config._experimental_distribute.remote_cluster
|
||||
logging.info('RunConfig initialized for Distribute Coordinator with '
|
||||
'STANDALONE_CLIENT mode')
|
||||
return
|
||||
|
||||
# Don't use distribute coordinator if it is local training or cluster has a
|
||||
# MASTER job or `train_distribute` is not specified.
|
||||
if (not cluster_spec or 'master' in cluster_spec.jobs or
|
||||
not config._train_distribute):
|
||||
config._distribute_coordinator_mode = None
|
||||
config._init_distributed_setting_from_environment_var(tf_config)
|
||||
config._maybe_overwrite_session_config_for_distributed_training()
|
||||
logging.info('Not using Distribute Coordinator.')
|
||||
return
|
||||
|
||||
# Use distribute coordinator with INDEPENDENT_WORKER mode otherwise.
|
||||
assert tf_config
|
||||
|
||||
# Set the cluster_spec only since the distributed setting will come from
|
||||
# distribute coordinator.
|
||||
config._cluster_spec = cluster_spec
|
||||
config._distribute_coordinator_mode = dc.CoordinatorMode.INDEPENDENT_WORKER
|
||||
logging.info('RunConfig initialized for Distribute Coordinator with '
|
||||
'INDEPENDENT_WORKER mode')
|
||||
|
||||
|
||||
def should_run_distribute_coordinator(config):
|
||||
"""Checks the config to see whether to run distribute coordinator."""
|
||||
# pylint: disable=protected-access
|
||||
if (not hasattr(config, '_distribute_coordinator_mode') or
|
||||
config._distribute_coordinator_mode is None):
|
||||
logging.info('Not using Distribute Coordinator.')
|
||||
return False
|
||||
if (not isinstance(config._distribute_coordinator_mode, six.string_types) or
|
||||
config._distribute_coordinator_mode not in [
|
||||
dc.CoordinatorMode.STANDALONE_CLIENT,
|
||||
dc.CoordinatorMode.INDEPENDENT_WORKER
|
||||
]):
|
||||
logging.warning('Unexpected distribute_coordinator_mode: %r',
|
||||
config._distribute_coordinator_mode)
|
||||
return False
|
||||
if not config.cluster_spec:
|
||||
logging.warning('Running `train_and_evaluate` locally, ignoring '
|
||||
'`experimental_distribute_coordinator_mode`.')
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def train_and_evaluate(estimator, train_spec, eval_spec, executor_cls):
|
||||
"""Run distribute coordinator for Estimator's `train_and_evaluate`.
|
||||
|
||||
Args:
|
||||
estimator: An `Estimator` instance to train and evaluate.
|
||||
train_spec: A `TrainSpec` instance to specify the training specification.
|
||||
eval_spec: A `EvalSpec` instance to specify the evaluation and export
|
||||
specification.
|
||||
executor_cls: the evaluation executor class of Estimator.
|
||||
|
||||
Raises:
|
||||
ValueError: if `distribute_coordinator_mode` is None in RunConfig.
|
||||
"""
|
||||
run_config = estimator.config
|
||||
if not run_config._distribute_coordinator_mode: # pylint: disable=protected-access
|
||||
raise ValueError(
|
||||
'Distribute coordinator mode is not specified in `RunConfig`.')
|
||||
|
||||
def _worker_fn(strategy):
|
||||
"""Function for worker task."""
|
||||
local_estimator = copy.deepcopy(estimator)
|
||||
# pylint: disable=protected-access
|
||||
local_estimator._config._train_distribute = strategy
|
||||
context = dc_context.get_current_worker_context()
|
||||
_init_run_config_from_worker_context(local_estimator._config, context)
|
||||
logging.info('Updated config: %s', str(vars(local_estimator._config)))
|
||||
local_estimator._train_distribution = strategy
|
||||
# pylint: enable=protected-access
|
||||
|
||||
# In the standalone client, we don't need to run hooks on all threads
|
||||
# because logging hooks on all threads may be too much on the screen; also
|
||||
# a tensor passed to one hook can only be fetched with the graph where the
# tensor is defined. Other hooks such as checkpointing hooks will be added by
|
||||
# MonitoredTrainingSession.
|
||||
# TODO(yuefengz): Is there a hook that does need to run on all threads in
|
||||
# standalone client mode?
|
||||
if (run_config._distribute_coordinator_mode == # pylint: disable=protected-access
|
||||
dc.CoordinatorMode.INDEPENDENT_WORKER or context.is_chief):
|
||||
hooks = list(train_spec.hooks)
|
||||
else:
|
||||
hooks = []
|
||||
|
||||
# Prevent estimator.train from calling distribute coordinator again. This
|
||||
# function calls estimator.train which will use distribute coordinator path
|
||||
# again if `_distribute_coordinator_mode` is set.
|
||||
local_estimator._config._distribute_coordinator_mode = None # pylint: disable=protected-access
|
||||
local_estimator.train(
|
||||
input_fn=train_spec.input_fn,
|
||||
max_steps=train_spec.max_steps,
|
||||
hooks=hooks)
|
||||
|
||||
def _eval_fn(strategy):
|
||||
"""Function for evaluator task."""
|
||||
local_estimator = copy.deepcopy(estimator)
|
||||
# pylint: disable=protected-access
|
||||
local_estimator._config._eval_distribute = strategy
|
||||
_init_run_config_from_worker_context(
|
||||
local_estimator._config, dc_context.get_current_worker_context())
|
||||
logging.info('Updated config: %s', str(vars(local_estimator._config)))
|
||||
local_estimator._eval_distribution = strategy
|
||||
|
||||
# Prevent estimator.evaluate from calling distribute coordinator again. This
|
||||
# function calls estimator.evaluate which will use distribute coordinator
|
||||
# path again if `_distribute_coordinator_mode` is set.
|
||||
local_estimator._config._distribute_coordinator_mode = None # pylint: disable=protected-access
|
||||
|
||||
executor = executor_cls(local_estimator, train_spec, eval_spec)
|
||||
executor._start_continuous_evaluation()
|
||||
# pylint: enable=protected-access
|
||||
|
||||
# pylint: disable=protected-access
|
||||
if (run_config._distribute_coordinator_mode ==
|
||||
dc.CoordinatorMode.STANDALONE_CLIENT):
|
||||
cluster_spec = run_config.cluster_spec
|
||||
assert cluster_spec
|
||||
else:
|
||||
# The cluster_spec comes from TF_CONFIG environment variable if it is
|
||||
# INDEPENDENT_WORKER mode.
|
||||
cluster_spec = None
|
||||
|
||||
dc.run_distribute_coordinator(
|
||||
_worker_fn,
|
||||
run_config.train_distribute,
|
||||
_eval_fn,
|
||||
run_config.eval_distribute,
|
||||
mode=run_config._distribute_coordinator_mode,
|
||||
cluster_spec=cluster_spec,
|
||||
session_config=run_config.session_config)
|
||||
|
||||
|
||||
# TODO(yuefengz): maybe merge the following two functions?
|
||||
# pylint: disable=protected-access
|
||||
def estimator_train(estimator, train_distributed_fn, hooks):
|
||||
"""Run distribute coordinator for Estimator's `train` method."""
|
||||
assert estimator._config._distribute_coordinator_mode
|
||||
run_config = estimator._config
|
||||
assert estimator._config.cluster_spec
|
||||
cluster_spec = multi_worker_util.normalize_cluster_spec(
|
||||
estimator._config.cluster_spec)
|
||||
assert estimator._config._train_distribute
|
||||
|
||||
if 'evaluator' in cluster_spec.jobs:
|
||||
raise ValueError("'evaluator' job is not supported if you don't use "
|
||||
'`train_and_evaluate`')
|
||||
|
||||
if (estimator._config._distribute_coordinator_mode != # pylint: disable=protected-access
|
||||
dc.CoordinatorMode.STANDALONE_CLIENT):
|
||||
raise ValueError('Only `STANDALONE_CLIENT` mode is supported when you call '
|
||||
'`estimator.train`')
|
||||
|
||||
if estimator._config._train_distribute.extended.experimental_between_graph:
|
||||
# TODO(yuefengz): remove this limitation once we figure out how to merge
|
||||
# return values from `_worker_fn`s.
|
||||
raise ValueError('`Estimator.train` API is not supported for %s with '
|
||||
'`STANDALONE_CLIENT` mode.' %
|
||||
estimator._config._train_distribute.__class__.__name__)
|
||||
|
||||
def _worker_fn(strategy):
|
||||
"""Function for worker task."""
|
||||
local_estimator = copy.deepcopy(estimator)
|
||||
local_estimator._config._train_distribute = strategy
|
||||
context = dc_context.get_current_worker_context()
|
||||
_init_run_config_from_worker_context(local_estimator._config, context)
|
||||
logging.info('Updated config: %s', str(vars(local_estimator._config)))
|
||||
local_estimator._train_distribution = strategy
|
||||
|
||||
if context.is_chief:
|
||||
chief_hooks = hooks
|
||||
else:
|
||||
chief_hooks = []
|
||||
train_distributed_fn(local_estimator, strategy, chief_hooks)
|
||||
return local_estimator
|
||||
|
||||
return dc.run_distribute_coordinator(
|
||||
_worker_fn,
|
||||
estimator._config.train_distribute,
|
||||
mode=run_config._distribute_coordinator_mode,
|
||||
cluster_spec=cluster_spec,
|
||||
session_config=run_config.session_config)
|
||||
|
||||
|
||||
def estimator_evaluate(estimator, evaluate_distributed_fn, hooks):
|
||||
"""Run distribute coordinator for Estimator's `evaluate` method."""
|
||||
assert estimator._config._distribute_coordinator_mode
|
||||
run_config = estimator._config
|
||||
assert estimator._config.cluster_spec
|
||||
cluster_spec = multi_worker_util.normalize_cluster_spec(
|
||||
estimator._config.cluster_spec)
|
||||
assert estimator._config._eval_distribute
|
||||
|
||||
if 'evaluator' in cluster_spec.jobs:
|
||||
raise ValueError("'evaluator' job is not supported if you don't use "
|
||||
'`train_and_evaluate`')
|
||||
|
||||
if (estimator._config._distribute_coordinator_mode !=
|
||||
dc.CoordinatorMode.STANDALONE_CLIENT):
|
||||
raise ValueError('Only `STANDALONE_CLIENT` mode is supported when you call '
|
||||
'`Estimator.evaluate`')
|
||||
|
||||
if estimator._config._eval_distribute.extended.experimental_between_graph:
|
||||
# TODO(yuefengz): remove this limitation once we figure out how to merge
|
||||
# return values from `_worker_fn`s.
|
||||
raise ValueError('`Estimator.evaluate` API is not supported for %s with '
|
||||
'`STANDALONE_CLIENT` mode.' %
|
||||
estimator._config._eval_distribute.__class__.__name__)
|
||||
|
||||
def _worker_fn(strategy):
|
||||
"""Function for evaluation."""
|
||||
local_estimator = copy.deepcopy(estimator)
|
||||
local_estimator._config._eval_distribute = strategy
|
||||
context = dc_context.get_current_worker_context()
|
||||
_init_run_config_from_worker_context(local_estimator._config, context)
|
||||
logging.info('Updated config: %s', str(vars(local_estimator._config)))
|
||||
local_estimator._eval_distribution = strategy
|
||||
|
||||
if context.is_chief:
|
||||
chief_hooks = hooks
|
||||
else:
|
||||
chief_hooks = []
|
||||
return evaluate_distributed_fn(local_estimator, strategy, chief_hooks)
|
||||
|
||||
return dc.run_distribute_coordinator(
|
||||
_worker_fn,
|
||||
estimator._config.eval_distribute,
|
||||
mode=run_config._distribute_coordinator_mode,
|
||||
cluster_spec=cluster_spec,
|
||||
session_config=run_config.session_config)
|
||||
|
||||
# pylint: enable=protected-access
|
||||
|
|
@ -60,7 +60,7 @@ def _validate_cluster_spec(cluster_spec,
|
|||
2) whether there is such a task type as `task_type` in the `cluster_spec`. The
|
||||
only exception is `evaluator`. In other words, it is still a valid
|
||||
configuration when `task_type` is `evaluator` but it doesn't appear in
|
||||
`cluster_spec`. This is to be compatible with `TF_CONFIG` in Estimator.
|
||||
`cluster_spec`.
|
||||
3) whether there is at most one "chief" job.
|
||||
4) whether there is at most one "evaluator" job.
|
||||
5) whether the `task_id` is smaller than the number of tasks for that
|
||||
|
|
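
To ground the validation rules listed above, here is a hedged example of the kind of `TF_CONFIG`-style cluster spec they check; the host addresses and task counts are invented:

```python
import json
import os

# A hypothetical cluster: one chief, two workers, one ps. An "evaluator" task
# may reference this cluster without appearing in the "cluster" dict (rule 2).
os.environ["TF_CONFIG"] = json.dumps({
    "cluster": {
        "chief": ["host0:2222"],                 # at most one chief (rule 3)
        "worker": ["host1:2222", "host2:2222"],
        "ps": ["host3:2222"],
    },
    # task_id must be smaller than the number of tasks of that type (rule 5).
    "task": {"type": "worker", "index": 1},
})
```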
|
|||
|
|
@ -35,7 +35,6 @@ from tensorflow.python.distribute.cluster_resolver import cluster_resolver as cl
|
|||
from tensorflow.python.distribute.v1 import input_lib as input_lib_v1
|
||||
from tensorflow.python.eager import backprop
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.estimator import run_config
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import device as tf_device
|
||||
from tensorflow.python.framework import dtypes
|
||||
|
|
@ -54,9 +53,9 @@ from tensorflow.python.ops import variables
|
|||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.training import training_util
|
||||
|
||||
CHIEF = run_config.TaskType.CHIEF
|
||||
WORKER = run_config.TaskType.WORKER
|
||||
PS = run_config.TaskType.PS
|
||||
CHIEF = 'chief'
|
||||
WORKER = 'worker'
|
||||
PS = 'ps'
|
||||
|
||||
|
||||
def _get_replica_id_integer():
|
||||
|
|
|
|||
|
|
@ -766,7 +766,7 @@ class TPUStrategyV1(distribute_lib.StrategyV1):
|
|||
host. Note that this can have side-effects on performance, hooks,
|
||||
metrics, summaries etc.
|
||||
This parameter is only used when Distribution Strategy is used with
|
||||
estimator or keras.
|
||||
Keras.
|
||||
device_assignment: Optional `tf.tpu.experimental.DeviceAssignment` to
|
||||
specify the placement of replicas on the TPU cluster. Currently only
|
||||
supports the usecase of using a single core within a TPU cluster.
|
||||
|
|
|
|||
|
|
@ -1,388 +0,0 @@
|
|||
load("//tensorflow:py.default.bzl", "py_library")
|
||||
|
||||
package(
|
||||
# copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
|
||||
default_visibility = ["//tensorflow:internal"],
|
||||
licenses = ["notice"],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "estimator_py",
|
||||
srcs = [
|
||||
"estimator_lib.py",
|
||||
],
|
||||
srcs_version = "PY3",
|
||||
visibility = [
|
||||
"//tensorflow:__pkg__",
|
||||
"//tensorflow:internal",
|
||||
],
|
||||
deps = [
|
||||
":baseline",
|
||||
":dnn",
|
||||
":dnn_linear_combined",
|
||||
":estimator",
|
||||
":expect_tensorflow_estimator_installed",
|
||||
":export",
|
||||
":exporter",
|
||||
":inputs",
|
||||
":keras",
|
||||
":linear",
|
||||
":model_fn",
|
||||
":parsing_utils",
|
||||
":run_config",
|
||||
":training",
|
||||
"//tensorflow:tensorflow_py_no_contrib",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "exporter",
|
||||
srcs = ["exporter.py"],
|
||||
srcs_version = "PY3",
|
||||
deps = [
|
||||
":expect_tensorflow_estimator_installed",
|
||||
":gc",
|
||||
":metric_keys",
|
||||
":util",
|
||||
"//tensorflow:tensorflow_py_no_contrib",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "gc",
|
||||
srcs = ["gc.py"],
|
||||
srcs_version = "PY3",
|
||||
deps = [
|
||||
":expect_tensorflow_estimator_installed",
|
||||
"//tensorflow:tensorflow_py_no_contrib",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "model_fn",
|
||||
srcs = ["model_fn.py"],
|
||||
srcs_version = "PY3",
|
||||
deps = [
|
||||
":expect_tensorflow_estimator_installed",
|
||||
":export_output",
|
||||
"//tensorflow:tensorflow_py_no_contrib",
|
||||
"@six_archive//:six",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "training",
|
||||
srcs = ["training.py"],
|
||||
srcs_version = "PY3",
|
||||
deps = [
|
||||
":estimator",
|
||||
":expect_tensorflow_estimator_installed",
|
||||
":exporter",
|
||||
":run_config",
|
||||
"//tensorflow:tensorflow_py_no_contrib",
|
||||
"@six_archive//:six",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "run_config",
|
||||
srcs = ["run_config.py"],
|
||||
srcs_version = "PY3",
|
||||
deps = [
|
||||
":expect_tensorflow_estimator_installed",
|
||||
"//tensorflow:tensorflow_py_no_contrib",
|
||||
"@six_archive//:six",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "baseline",
|
||||
srcs = ["canned/baseline.py"],
|
||||
srcs_version = "PY3",
|
||||
deps = [
|
||||
":estimator",
|
||||
":expect_tensorflow_estimator_installed",
|
||||
":head",
|
||||
":model_fn",
|
||||
":optimizers",
|
||||
"//tensorflow:tensorflow_py_no_contrib",
|
||||
"@six_archive//:six",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "dnn",
|
||||
srcs = ["canned/dnn.py"],
|
||||
srcs_version = "PY3",
|
||||
deps = [
|
||||
":estimator",
|
||||
":expect_tensorflow_estimator_installed",
|
||||
":head",
|
||||
":model_fn",
|
||||
":optimizers",
|
||||
"//tensorflow:tensorflow_py_no_contrib",
|
||||
"@six_archive//:six",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "dnn_testing_utils",
|
||||
testonly = 1,
|
||||
srcs = ["canned/dnn_testing_utils.py"],
|
||||
srcs_version = "PY3",
|
||||
deps = [
|
||||
":estimator",
|
||||
":expect_tensorflow_estimator_installed",
|
||||
":head",
|
||||
":metric_keys",
|
||||
":model_fn",
|
||||
":numpy_io",
|
||||
":prediction_keys",
|
||||
"//tensorflow:tensorflow_py_no_contrib",
|
||||
"//third_party/py/numpy",
|
||||
"@absl_py//absl/testing:parameterized",
|
||||
"@six_archive//:six",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "dnn_linear_combined",
|
||||
srcs = ["canned/dnn_linear_combined.py"],
|
||||
srcs_version = "PY3",
|
||||
deps = [
|
||||
":dnn",
|
||||
":estimator",
|
||||
":expect_tensorflow_estimator_installed",
|
||||
":head",
|
||||
":linear",
|
||||
":model_fn",
|
||||
":optimizers",
|
||||
"//tensorflow:tensorflow_py_no_contrib",
|
||||
"@six_archive//:six",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "util",
|
||||
srcs = [
|
||||
"util.py",
|
||||
],
|
||||
srcs_version = "PY3",
|
||||
deps = [
|
||||
":expect_tensorflow_estimator_installed",
|
||||
"//tensorflow:tensorflow_py_no_contrib",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "estimator",
|
||||
srcs = [
|
||||
"estimator.py",
|
||||
],
|
||||
srcs_version = "PY3",
|
||||
deps = [
|
||||
":expect_tensorflow_estimator_installed",
|
||||
":export_export",
|
||||
":model_fn",
|
||||
":run_config",
|
||||
":util",
|
||||
"//tensorflow:tensorflow_py_no_contrib",
|
||||
"//third_party/py/numpy",
|
||||
"@six_archive//:six",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "parsing_utils",
|
||||
srcs = [
|
||||
"canned/parsing_utils.py",
|
||||
],
|
||||
srcs_version = "PY3",
|
||||
deps = [
|
||||
":expect_tensorflow_estimator_installed",
|
||||
"//tensorflow:tensorflow_py_no_contrib",
|
||||
"@six_archive//:six",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "export_output",
|
||||
srcs = ["export/export_output.py"],
|
||||
srcs_version = "PY3",
|
||||
deps = [
|
||||
":expect_tensorflow_estimator_installed",
|
||||
"//tensorflow:tensorflow_py_no_contrib",
|
||||
"@six_archive//:six",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "export",
|
||||
srcs = [
|
||||
"export/export_lib.py",
|
||||
],
|
||||
srcs_version = "PY3",
|
||||
deps = [
|
||||
":expect_tensorflow_estimator_installed",
|
||||
":export_export",
|
||||
":export_output",
|
||||
"//tensorflow:tensorflow_py_no_contrib",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "export_export",
|
||||
srcs = [
|
||||
"export/export.py",
|
||||
],
|
||||
srcs_version = "PY3",
|
||||
deps = [
|
||||
":expect_tensorflow_estimator_installed",
|
||||
":util",
|
||||
"//tensorflow:tensorflow_py_no_contrib",
|
||||
"@six_archive//:six",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "head",
|
||||
srcs = ["canned/head.py"],
|
||||
srcs_version = "PY3",
|
||||
deps = [
|
||||
":expect_tensorflow_estimator_installed",
|
||||
":export_output",
|
||||
":metric_keys",
|
||||
":model_fn",
|
||||
":prediction_keys",
|
||||
"//tensorflow:tensorflow_py_no_contrib",
|
||||
"@six_archive//:six",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "inputs",
|
||||
srcs = ["inputs/inputs.py"],
|
||||
srcs_version = "PY3",
|
||||
deps = [
|
||||
":expect_tensorflow_estimator_installed",
|
||||
":numpy_io",
|
||||
":pandas_io",
|
||||
"//tensorflow:tensorflow_py_no_contrib",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "linear",
|
||||
srcs = ["canned/linear.py"],
|
||||
srcs_version = "PY3",
|
||||
deps = [
|
||||
":estimator",
|
||||
":expect_tensorflow_estimator_installed",
|
||||
":head",
|
||||
":optimizers",
|
||||
"//tensorflow:tensorflow_py_no_contrib",
|
||||
"@six_archive//:six",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "linear_testing_utils",
|
||||
testonly = 1,
|
||||
srcs = ["canned/linear_testing_utils.py"],
|
||||
srcs_version = "PY3",
|
||||
deps = [
|
||||
":estimator",
|
||||
":expect_tensorflow_estimator_installed",
|
||||
":export_export",
|
||||
":linear",
|
||||
":metric_keys",
|
||||
":numpy_io",
|
||||
":pandas_io",
|
||||
":run_config",
|
||||
"//tensorflow:tensorflow_py_no_contrib",
|
||||
"@six_archive//:six",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "metric_keys",
|
||||
srcs = ["canned/metric_keys.py"],
|
||||
srcs_version = "PY3",
|
||||
deps = [
|
||||
":expect_tensorflow_estimator_installed",
|
||||
":model_fn",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "numpy_io",
|
||||
srcs = ["inputs/numpy_io.py"],
|
||||
srcs_version = "PY3",
|
||||
deps = [
|
||||
":expect_tensorflow_estimator_installed",
|
||||
":inputs_queues",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "optimizers",
|
||||
srcs = ["canned/optimizers.py"],
|
||||
srcs_version = "PY3",
|
||||
deps = [
|
||||
":expect_tensorflow_estimator_installed",
|
||||
"//tensorflow:tensorflow_py_no_contrib",
|
||||
"@six_archive//:six",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "pandas_io",
|
||||
srcs = ["inputs/pandas_io.py"],
|
||||
srcs_version = "PY3",
|
||||
deps = [
|
||||
":expect_tensorflow_estimator_installed",
|
||||
":inputs_queues",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "prediction_keys",
|
||||
srcs = ["canned/prediction_keys.py"],
|
||||
srcs_version = "PY3",
|
||||
deps = [
|
||||
":expect_tensorflow_estimator_installed",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "inputs_queues",
|
||||
srcs = [
|
||||
"inputs/queues/feeding_functions.py",
|
||||
"inputs/queues/feeding_queue_runner.py",
|
||||
],
|
||||
srcs_version = "PY3",
|
||||
deps = [
|
||||
":expect_tensorflow_estimator_installed",
|
||||
"//tensorflow:tensorflow_py_no_contrib",
|
||||
"@six_archive//:six",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "keras",
|
||||
srcs = ["keras.py"],
|
||||
srcs_version = "PY3",
|
||||
deps = [
|
||||
":estimator",
|
||||
":expect_tensorflow_estimator_installed",
|
||||
":export_export",
|
||||
":model_fn",
|
||||
":run_config",
|
||||
"//tensorflow:tensorflow_py_no_contrib",
|
||||
],
|
||||
)
|
||||
|
||||
alias(
|
||||
name = "expect_tensorflow_estimator_installed",
|
||||
actual = "@pypi_tf_estimator_nightly//:pkg",
|
||||
)
|
||||
|
|
@ -1,28 +0,0 @@
|
|||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""baseline python module.
|
||||
|
||||
Importing from tensorflow.python.estimator is unsupported
|
||||
and will soon break!
|
||||
"""
|
||||
# pylint: disable=unused-import,g-bad-import-order,g-import-not-at-top,wildcard-import
|
||||
|
||||
from tensorflow_estimator.python.estimator.canned import baseline
|
||||
|
||||
# Include attrs that start with single underscore.
|
||||
_HAS_DYNAMIC_ATTRIBUTES = True
|
||||
baseline.__all__ = [s for s in dir(baseline) if not s.startswith('__')]
|
||||
|
||||
from tensorflow_estimator.python.estimator.canned.baseline import *
|
||||
|
|
@ -1,28 +0,0 @@
|
|||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""dnn python module.
|
||||
|
||||
Importing from tensorflow.python.estimator is unsupported
|
||||
and will soon break!
|
||||
"""
|
||||
# pylint: disable=unused-import,g-bad-import-order,g-import-not-at-top,wildcard-import
|
||||
|
||||
from tensorflow_estimator.python.estimator.canned import dnn
|
||||
|
||||
# Include attrs that start with single underscore.
|
||||
_HAS_DYNAMIC_ATTRIBUTES = True
|
||||
dnn.__all__ = [s for s in dir(dnn) if not s.startswith('__')]
|
||||
|
||||
from tensorflow_estimator.python.estimator.canned.dnn import *
|
||||
|
|
@ -1,30 +0,0 @@
|
|||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""dnn_linear_combined python module.
|
||||
|
||||
Importing from tensorflow.python.estimator is unsupported
|
||||
and will soon break!
|
||||
"""
|
||||
# pylint: disable=unused-import,g-bad-import-order,g-import-not-at-top,wildcard-import
|
||||
|
||||
from tensorflow_estimator.python.estimator.canned import dnn_linear_combined
|
||||
|
||||
# Include attrs that start with single underscore.
|
||||
_HAS_DYNAMIC_ATTRIBUTES = True
|
||||
dnn_linear_combined.__all__ = [
|
||||
s for s in dir(dnn_linear_combined) if not s.startswith('__')
|
||||
]
|
||||
|
||||
from tensorflow_estimator.python.estimator.canned.dnn_linear_combined import *
|
||||
|
|
@ -1,30 +0,0 @@
|
|||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""dnn_testing_utils python module.
|
||||
|
||||
Importing from tensorflow.python.estimator is unsupported
|
||||
and will soon break!
|
||||
"""
|
||||
# pylint: disable=unused-import,g-bad-import-order,g-import-not-at-top,wildcard-import
|
||||
|
||||
from tensorflow_estimator.python.estimator.canned import dnn_testing_utils
|
||||
|
||||
# Include attrs that start with single underscore.
|
||||
_HAS_DYNAMIC_ATTRIBUTES = True
|
||||
dnn_testing_utils.__all__ = [
|
||||
s for s in dir(dnn_testing_utils) if not s.startswith('__')
|
||||
]
|
||||
|
||||
from tensorflow_estimator.python.estimator.canned.dnn_testing_utils import *
|
||||
|
|
@ -1,28 +0,0 @@
|
|||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""head python module.
|
||||
|
||||
Importing from tensorflow.python.estimator is unsupported
|
||||
and will soon break!
|
||||
"""
|
||||
# pylint: disable=unused-import,g-bad-import-order,g-import-not-at-top,wildcard-import
|
||||
|
||||
from tensorflow_estimator.python.estimator.canned import head
|
||||
|
||||
# Include attrs that start with single underscore.
|
||||
_HAS_DYNAMIC_ATTRIBUTES = True
|
||||
head.__all__ = [s for s in dir(head) if not s.startswith('__')]
|
||||
|
||||
from tensorflow_estimator.python.estimator.canned.head import *
|
||||
|
|
@ -1,28 +0,0 @@
|
|||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""linear python module.
|
||||
|
||||
Importing from tensorflow.python.estimator is unsupported
|
||||
and will soon break!
|
||||
"""
|
||||
# pylint: disable=unused-import,g-bad-import-order,g-import-not-at-top,wildcard-import
|
||||
|
||||
from tensorflow_estimator.python.estimator.canned import linear
|
||||
|
||||
# Include attrs that start with single underscore.
|
||||
_HAS_DYNAMIC_ATTRIBUTES = True
|
||||
linear.__all__ = [s for s in dir(linear) if not s.startswith('__')]
|
||||
|
||||
from tensorflow_estimator.python.estimator.canned.linear import *
|
||||
|
|
@ -1,30 +0,0 @@
|
|||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""linear_testing_utils python module.
|
||||
|
||||
Importing from tensorflow.python.estimator is unsupported
|
||||
and will soon break!
|
||||
"""
|
||||
# pylint: disable=unused-import,g-bad-import-order,g-import-not-at-top,wildcard-import
|
||||
|
||||
from tensorflow_estimator.python.estimator.canned import linear_testing_utils
|
||||
|
||||
# Include attrs that start with single underscore.
|
||||
_HAS_DYNAMIC_ATTRIBUTES = True
|
||||
linear_testing_utils.__all__ = [
|
||||
s for s in dir(linear_testing_utils) if not s.startswith('__')
|
||||
]
|
||||
|
||||
from tensorflow_estimator.python.estimator.canned.linear_testing_utils import *
|
||||
|
|
@ -1,28 +0,0 @@
|
|||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""metric_keys python module.
|
||||
|
||||
Importing from tensorflow.python.estimator is unsupported
|
||||
and will soon break!
|
||||
"""
|
||||
# pylint: disable=unused-import,g-bad-import-order,g-import-not-at-top,wildcard-import
|
||||
|
||||
from tensorflow_estimator.python.estimator.canned import metric_keys
|
||||
|
||||
# Include attrs that start with single underscore.
|
||||
_HAS_DYNAMIC_ATTRIBUTES = True
|
||||
metric_keys.__all__ = [s for s in dir(metric_keys) if not s.startswith('__')]
|
||||
|
||||
from tensorflow_estimator.python.estimator.canned.metric_keys import *
|
||||
|
|
@ -1,28 +0,0 @@
|
|||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""optimizers python module.
|
||||
|
||||
Importing from tensorflow.python.estimator is unsupported
|
||||
and will soon break!
|
||||
"""
|
||||
# pylint: disable=unused-import,g-bad-import-order,g-import-not-at-top,wildcard-import
|
||||
|
||||
from tensorflow_estimator.python.estimator.canned import optimizers
|
||||
|
||||
# Include attrs that start with single underscore.
|
||||
_HAS_DYNAMIC_ATTRIBUTES = True
|
||||
optimizers.__all__ = [s for s in dir(optimizers) if not s.startswith('__')]
|
||||
|
||||
from tensorflow_estimator.python.estimator.canned.optimizers import *
|
||||
|
|
@ -1,30 +0,0 @@
|
|||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""parsing_utils python module.
|
||||
|
||||
Importing from tensorflow.python.estimator is unsupported
|
||||
and will soon break!
|
||||
"""
|
||||
# pylint: disable=unused-import,g-bad-import-order,g-import-not-at-top,wildcard-import
|
||||
|
||||
from tensorflow_estimator.python.estimator.canned import parsing_utils
|
||||
|
||||
# Include attrs that start with single underscore.
|
||||
_HAS_DYNAMIC_ATTRIBUTES = True
|
||||
parsing_utils.__all__ = [
|
||||
s for s in dir(parsing_utils) if not s.startswith('__')
|
||||
]
|
||||
|
||||
from tensorflow_estimator.python.estimator.canned.parsing_utils import *
|
||||
|
|
@ -1,30 +0,0 @@
|
|||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""prediction_keys python module.
|
||||
|
||||
Importing from tensorflow.python.estimator is unsupported
|
||||
and will soon break!
|
||||
"""
|
||||
# pylint: disable=unused-import,g-bad-import-order,g-import-not-at-top,wildcard-import
|
||||
|
||||
from tensorflow_estimator.python.estimator.canned import prediction_keys
|
||||
|
||||
# Include attrs that start with single underscore.
|
||||
_HAS_DYNAMIC_ATTRIBUTES = True
|
||||
prediction_keys.__all__ = [
|
||||
s for s in dir(prediction_keys) if not s.startswith('__')
|
||||
]
|
||||
|
||||
from tensorflow_estimator.python.estimator.canned.prediction_keys import *
|
||||
|
|
@ -1,28 +0,0 @@
|
|||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""estimator python module.
|
||||
|
||||
Importing from tensorflow.python.estimator is unsupported
|
||||
and will soon break!
|
||||
"""
|
||||
# pylint: disable=unused-import,g-bad-import-order,g-import-not-at-top,wildcard-import
|
||||
|
||||
from tensorflow_estimator.python.estimator import estimator
|
||||
|
||||
# Include attrs that start with single underscore.
|
||||
_HAS_DYNAMIC_ATTRIBUTES = True
|
||||
estimator.__all__ = [s for s in dir(estimator) if not s.startswith('__')]
|
||||
|
||||
from tensorflow_estimator.python.estimator.estimator import *
|
||||
|
|
@ -1,30 +0,0 @@
|
|||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""estimator_lib python module.
|
||||
|
||||
Importing from tensorflow.python.estimator is unsupported
|
||||
and will soon break!
|
||||
"""
|
||||
# pylint: disable=unused-import,g-bad-import-order,g-import-not-at-top,wildcard-import
|
||||
|
||||
from tensorflow_estimator.python.estimator import estimator_lib
|
||||
|
||||
# Include attrs that start with single underscore.
|
||||
_HAS_DYNAMIC_ATTRIBUTES = True
|
||||
estimator_lib.__all__ = [
|
||||
s for s in dir(estimator_lib) if not s.startswith('__')
|
||||
]
|
||||
|
||||
from tensorflow_estimator.python.estimator.estimator_lib import *
|
||||
|
|
@ -1,28 +0,0 @@
|
|||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""export python module.
|
||||
|
||||
Importing from tensorflow.python.estimator is unsupported
|
||||
and will soon break!
|
||||
"""
|
||||
# pylint: disable=unused-import,g-bad-import-order,g-import-not-at-top,wildcard-import
|
||||
|
||||
from tensorflow_estimator.python.estimator.export import export
|
||||
|
||||
# Include attrs that start with single underscore.
|
||||
_HAS_DYNAMIC_ATTRIBUTES = True
|
||||
export.__all__ = [s for s in dir(export) if not s.startswith('__')]
|
||||
|
||||
from tensorflow_estimator.python.estimator.export.export import *
|
||||
|
|
@ -1,28 +0,0 @@
|
|||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""export_lib python module.
|
||||
|
||||
Importing from tensorflow.python.estimator is unsupported
|
||||
and will soon break!
|
||||
"""
|
||||
# pylint: disable=unused-import,g-bad-import-order,g-import-not-at-top,wildcard-import
|
||||
|
||||
from tensorflow_estimator.python.estimator.export import export_lib
|
||||
|
||||
# Include attrs that start with single underscore.
|
||||
_HAS_DYNAMIC_ATTRIBUTES = True
|
||||
export_lib.__all__ = [s for s in dir(export_lib) if not s.startswith('__')]
|
||||
|
||||
from tensorflow_estimator.python.estimator.export.export_lib import *
|
||||
|
|
@ -1,30 +0,0 @@
|
|||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""export_output python module.
|
||||
|
||||
Importing from tensorflow.python.estimator is unsupported
|
||||
and will soon break!
|
||||
"""
|
||||
# pylint: disable=unused-import,g-bad-import-order,g-import-not-at-top,wildcard-import
|
||||
|
||||
from tensorflow_estimator.python.estimator.export import export_output
|
||||
|
||||
# Include attrs that start with single underscore.
|
||||
_HAS_DYNAMIC_ATTRIBUTES = True
|
||||
export_output.__all__ = [
|
||||
s for s in dir(export_output) if not s.startswith('__')
|
||||
]
|
||||
|
||||
from tensorflow_estimator.python.estimator.export.export_output import *
|
||||
|
|
@ -1,28 +0,0 @@
|
|||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""exporter python module.
|
||||
|
||||
Importing from tensorflow.python.estimator is unsupported
|
||||
and will soon break!
|
||||
"""
|
||||
# pylint: disable=unused-import,g-bad-import-order,g-import-not-at-top,wildcard-import
|
||||
|
||||
from tensorflow_estimator.python.estimator import exporter
|
||||
|
||||
# Include attrs that start with single underscore.
|
||||
_HAS_DYNAMIC_ATTRIBUTES = True
|
||||
exporter.__all__ = [s for s in dir(exporter) if not s.startswith('__')]
|
||||
|
||||
from tensorflow_estimator.python.estimator.exporter import *
|
||||
|
|
@ -1,28 +0,0 @@
|
|||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""gc python module.
|
||||
|
||||
Importing from tensorflow.python.estimator is unsupported
|
||||
and will soon break!
|
||||
"""
|
||||
# pylint: disable=unused-import,g-bad-import-order,g-import-not-at-top,wildcard-import
|
||||
|
||||
from tensorflow_estimator.python.estimator import gc
|
||||
|
||||
# Include attrs that start with single underscore.
|
||||
_HAS_DYNAMIC_ATTRIBUTES = True
|
||||
gc.__all__ = [s for s in dir(gc) if not s.startswith('__')]
|
||||
|
||||
from tensorflow_estimator.python.estimator.gc import *
|
||||
|
|
@ -1,28 +0,0 @@
|
|||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""inputs python module.
|
||||
|
||||
Importing from tensorflow.python.estimator is unsupported
|
||||
and will soon break!
|
||||
"""
|
||||
# pylint: disable=unused-import,g-bad-import-order,g-import-not-at-top,wildcard-import
|
||||
|
||||
from tensorflow_estimator.python.estimator.inputs import inputs
|
||||
|
||||
# Include attrs that start with single underscore.
|
||||
_HAS_DYNAMIC_ATTRIBUTES = True
|
||||
inputs.__all__ = [s for s in dir(inputs) if not s.startswith('__')]
|
||||
|
||||
from tensorflow_estimator.python.estimator.inputs.inputs import *
|
||||
|
|
@ -1,28 +0,0 @@
|
|||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""numpy_io python module.
|
||||
|
||||
Importing from tensorflow.python.estimator is unsupported
|
||||
and will soon break!
|
||||
"""
|
||||
# pylint: disable=unused-import,g-bad-import-order,g-import-not-at-top,wildcard-import
|
||||
|
||||
from tensorflow_estimator.python.estimator.inputs import numpy_io
|
||||
|
||||
# Include attrs that start with single underscore.
|
||||
_HAS_DYNAMIC_ATTRIBUTES = True
|
||||
numpy_io.__all__ = [s for s in dir(numpy_io) if not s.startswith('__')]
|
||||
|
||||
from tensorflow_estimator.python.estimator.inputs.numpy_io import *
|
||||
|
|
@ -1,28 +0,0 @@
|
|||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""pandas_io python module.
|
||||
|
||||
Importing from tensorflow.python.estimator is unsupported
|
||||
and will soon break!
|
||||
"""
|
||||
# pylint: disable=unused-import,g-bad-import-order,g-import-not-at-top,wildcard-import
|
||||
|
||||
from tensorflow_estimator.python.estimator.inputs import pandas_io
|
||||
|
||||
# Include attrs that start with single underscore.
|
||||
_HAS_DYNAMIC_ATTRIBUTES = True
|
||||
pandas_io.__all__ = [s for s in dir(pandas_io) if not s.startswith('__')]
|
||||
|
||||
from tensorflow_estimator.python.estimator.inputs.pandas_io import *
|
||||
|
|
@ -1,28 +0,0 @@
|
|||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""queues python module.
|
||||
|
||||
Importing from tensorflow.python.estimator is unsupported
|
||||
and will soon break!
|
||||
"""
|
||||
# pylint: disable=unused-import,g-bad-import-order,g-import-not-at-top,wildcard-import
|
||||
|
||||
from tensorflow_estimator.python.estimator.inputs import queues
|
||||
|
||||
# Include attrs that start with single underscore.
|
||||
_HAS_DYNAMIC_ATTRIBUTES = True
|
||||
queues.__all__ = [s for s in dir(queues) if not s.startswith('__')]
|
||||
|
||||
from tensorflow_estimator.python.estimator.inputs.queues import *
|
||||
|
|
@ -1,30 +0,0 @@
|
|||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""feeding_functions python module.
|
||||
|
||||
Importing from tensorflow.python.estimator is unsupported
|
||||
and will soon break!
|
||||
"""
|
||||
# pylint: disable=unused-import,g-bad-import-order,g-import-not-at-top,wildcard-import
|
||||
|
||||
from tensorflow_estimator.python.estimator.inputs.queues import feeding_functions
|
||||
|
||||
# Include attrs that start with single underscore.
|
||||
_HAS_DYNAMIC_ATTRIBUTES = True
|
||||
feeding_functions.__all__ = [
|
||||
s for s in dir(feeding_functions) if not s.startswith('__')
|
||||
]
|
||||
|
||||
from tensorflow_estimator.python.estimator.inputs.queues.feeding_functions import *
|
||||
|
|
@ -1,30 +0,0 @@
|
|||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""feeding_queue_runner python module.
|
||||
|
||||
Importing from tensorflow.python.estimator is unsupported
|
||||
and will soon break!
|
||||
"""
|
||||
# pylint: disable=unused-import,g-bad-import-order,g-import-not-at-top,wildcard-import
|
||||
|
||||
from tensorflow_estimator.python.estimator.inputs.queues import feeding_queue_runner
|
||||
|
||||
# Include attrs that start with single underscore.
|
||||
_HAS_DYNAMIC_ATTRIBUTES = True
|
||||
feeding_queue_runner.__all__ = [
|
||||
s for s in dir(feeding_queue_runner) if not s.startswith('__')
|
||||
]
|
||||
|
||||
from tensorflow_estimator.python.estimator.inputs.queues.feeding_queue_runner import *
|
||||
|
|
@ -1,28 +0,0 @@
|
|||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""keras python module.
|
||||
|
||||
Importing from tensorflow.python.estimator is unsupported
|
||||
and will soon break!
|
||||
"""
|
||||
# pylint: disable=unused-import,g-bad-import-order,g-import-not-at-top,wildcard-import
|
||||
|
||||
from tensorflow_estimator.python.estimator import keras_lib
|
||||
|
||||
# Include attrs that start with single underscore.
|
||||
_HAS_DYNAMIC_ATTRIBUTES = True
|
||||
keras_lib.__all__ = [s for s in dir(keras_lib) if not s.startswith('__')]
|
||||
|
||||
from tensorflow_estimator.python.estimator.keras_lib import *
|
||||
|
|
@ -1,28 +0,0 @@
|
|||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""model_fn python module.
|
||||
|
||||
Importing from tensorflow.python.estimator is unsupported
|
||||
and will soon break!
|
||||
"""
|
||||
# pylint: disable=unused-import,g-bad-import-order,g-import-not-at-top,wildcard-import
|
||||
|
||||
from tensorflow_estimator.python.estimator import model_fn
|
||||
|
||||
# Include attrs that start with single underscore.
|
||||
_HAS_DYNAMIC_ATTRIBUTES = True
|
||||
model_fn.__all__ = [s for s in dir(model_fn) if not s.startswith('__')]
|
||||
|
||||
from tensorflow_estimator.python.estimator.model_fn import *
|
||||
|
|
@ -1,28 +0,0 @@
|
|||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""run_config python module.
|
||||
|
||||
Importing from tensorflow.python.estimator is unsupported
|
||||
and will soon break!
|
||||
"""
|
||||
# pylint: disable=unused-import,g-bad-import-order,g-import-not-at-top,wildcard-import
|
||||
|
||||
from tensorflow_estimator.python.estimator import run_config
|
||||
|
||||
# Include attrs that start with single underscore.
|
||||
_HAS_DYNAMIC_ATTRIBUTES = True
|
||||
run_config.__all__ = [s for s in dir(run_config) if not s.startswith('__')]
|
||||
|
||||
from tensorflow_estimator.python.estimator.run_config import *
|
||||
|
|
@ -1,28 +0,0 @@
|
|||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""training python module.
|
||||
|
||||
Importing from tensorflow.python.estimator is unsupported
|
||||
and will soon break!
|
||||
"""
|
||||
# pylint: disable=unused-import,g-bad-import-order,g-import-not-at-top,wildcard-import
|
||||
|
||||
from tensorflow_estimator.python.estimator import training
|
||||
|
||||
# Include attrs that start with single underscore.
|
||||
_HAS_DYNAMIC_ATTRIBUTES = True
|
||||
training.__all__ = [s for s in dir(training) if not s.startswith('__')]
|
||||
|
||||
from tensorflow_estimator.python.estimator.training import *
|
||||
|
|
@ -1,28 +0,0 @@
|
|||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""util python module.
|
||||
|
||||
Importing from tensorflow.python.estimator is unsupported
|
||||
and will soon break!
|
||||
"""
|
||||
# pylint: disable=unused-import,g-bad-import-order,g-import-not-at-top,wildcard-import
|
||||
|
||||
from tensorflow_estimator.python.estimator import util
|
||||
|
||||
# Include attrs that start with single underscore.
|
||||
_HAS_DYNAMIC_ATTRIBUTES = True
|
||||
util.__all__ = [s for s in dir(util) if not s.startswith('__')]
|
||||
|
||||
from tensorflow_estimator.python.estimator.util import *
|
||||
|
|
@ -14,111 +14,9 @@
|
|||
# ==============================================================================
|
||||
"""This API defines FeatureColumn abstraction.
|
||||
|
||||
FeatureColumns provide a high level abstraction for ingesting and representing
|
||||
features. FeatureColumns are also the primary way of encoding features for
|
||||
canned `tf.estimator.Estimator`s.
|
||||
|
||||
When using FeatureColumns with `Estimators`, the type of feature column you
|
||||
should choose depends on (1) the feature type and (2) the model type.
|
||||
|
||||
1. Feature type:
|
||||
|
||||
* Continuous features can be represented by `numeric_column`.
|
||||
* Categorical features can be represented by any `categorical_column_with_*`
|
||||
column:
|
||||
- `categorical_column_with_vocabulary_list`
|
||||
- `categorical_column_with_vocabulary_file`
|
||||
- `categorical_column_with_hash_bucket`
|
||||
- `categorical_column_with_identity`
|
||||
- `weighted_categorical_column`
|
||||
|
||||
2. Model type:
|
||||
|
||||
* Deep neural network models (`DNNClassifier`, `DNNRegressor`).
|
||||
|
||||
Continuous features can be directly fed into deep neural network models.
|
||||
|
||||
age_column = numeric_column("age")
|
||||
|
||||
To feed sparse features into DNN models, wrap the column with
|
||||
`embedding_column` or `indicator_column`. `indicator_column` is recommended
|
||||
for features with only a few possible values. For features with many
|
||||
possible values, to reduce the size of your model, `embedding_column` is
|
||||
recommended.
|
||||
|
||||
embedded_dept_column = embedding_column(
|
||||
categorical_column_with_vocabulary_list(
|
||||
"department", ["math", "philosophy", ...]), dimension=10)
|
||||
|
||||
* Wide (aka linear) models (`LinearClassifier`, `LinearRegressor`).
|
||||
|
||||
Sparse features can be fed directly into linear models. They behave like an
|
||||
indicator column but with an efficient implementation.
|
||||
|
||||
dept_column = categorical_column_with_vocabulary_list("department",
|
||||
["math", "philosophy", "english"])
|
||||
|
||||
It is recommended that continuous features be bucketized before being
|
||||
fed into linear models.
|
||||
|
||||
bucketized_age_column = bucketized_column(
|
||||
source_column=age_column,
|
||||
boundaries=[18, 25, 30, 35, 40, 45, 50, 55, 60, 65])
|
||||
|
||||
Sparse features can be crossed (also known as conjuncted or combined) in
|
||||
order to form non-linearities, and then fed into linear models.
|
||||
|
||||
cross_dept_age_column = crossed_column(
|
||||
columns=["department", bucketized_age_column],
|
||||
hash_bucket_size=1000)
|
||||
|
||||
Example of building canned `Estimator`s using FeatureColumns:
|
||||
|
||||
```python
|
||||
# Define features and transformations
|
||||
deep_feature_columns = [age_column, embedded_dept_column]
|
||||
wide_feature_columns = [dept_column, bucketized_age_column,
|
||||
cross_dept_age_column]
|
||||
|
||||
# Build deep model
|
||||
estimator = DNNClassifier(
|
||||
feature_columns=deep_feature_columns,
|
||||
hidden_units=[500, 250, 50])
|
||||
estimator.train(...)
|
||||
|
||||
# Or build a wide model
|
||||
estimator = LinearClassifier(
|
||||
feature_columns=wide_feature_columns)
|
||||
estimator.train(...)
|
||||
|
||||
# Or build a wide and deep model!
|
||||
estimator = DNNLinearCombinedClassifier(
|
||||
linear_feature_columns=wide_feature_columns,
|
||||
dnn_feature_columns=deep_feature_columns,
|
||||
dnn_hidden_units=[500, 250, 50])
|
||||
estimator.train(...)
|
||||
```
|
||||
|
||||
|
||||
FeatureColumns can also be transformed into a generic input layer for
|
||||
custom models using `input_layer`.
|
||||
|
||||
Example of building model using FeatureColumns, this can be used in a
|
||||
`model_fn` which is given to the {tf.estimator.Estimator}:
|
||||
|
||||
```python
|
||||
# Building model via layers
|
||||
|
||||
deep_feature_columns = [age_column, embedded_dept_column]
|
||||
columns_to_tensor = parse_feature_columns_from_examples(
|
||||
serialized=my_data,
|
||||
feature_columns=deep_feature_columns)
|
||||
first_layer = input_layer(
|
||||
features=columns_to_tensor,
|
||||
feature_columns=deep_feature_columns)
|
||||
second_layer = fully_connected(first_layer, ...)
|
||||
```
|
||||
|
||||
NOTE: Functions prefixed with "_" indicate experimental or private parts of
|
||||
the API subject to change, and should not be relied upon!
|
||||
|
||||
|
|
@ -850,35 +748,6 @@ def _embedding_column(categorical_column,
|
|||
`categorical_column_*` function. Here is an example of using
|
||||
`embedding_column` with `DNNClassifier`:
|
||||
|
||||
```python
|
||||
video_id = categorical_column_with_identity(
|
||||
key='video_id', num_buckets=1000000, default_value=0)
|
||||
columns = [embedding_column(video_id, 9),...]
|
||||
|
||||
estimator = tf.estimator.DNNClassifier(feature_columns=columns, ...)
|
||||
|
||||
label_column = ...
|
||||
def input_fn():
|
||||
features = tf.io.parse_example(
|
||||
..., features=make_parse_example_spec(columns + [label_column]))
|
||||
labels = features.pop(label_column.name)
|
||||
return features, labels
|
||||
|
||||
estimator.train(input_fn=input_fn, steps=100)
|
||||
```
|
||||
|
||||
Here is an example using `embedding_column` with model_fn:
|
||||
|
||||
```python
|
||||
def model_fn(features, ...):
|
||||
video_id = categorical_column_with_identity(
|
||||
key='video_id', num_buckets=1000000, default_value=0)
|
||||
columns = [embedding_column(video_id, 9),...]
|
||||
dense_tensor = input_layer(features, columns)
|
||||
# Form DNN layers, calculate loss, and return EstimatorSpec.
|
||||
...
|
||||
```
|
||||
|
||||
Args:
|
||||
categorical_column: A `_CategoricalColumn` created by a
|
||||
`categorical_column_with_*` function. This column produces the sparse IDs
|
||||
|
|
|
|||
|
|
@ -15,110 +15,11 @@
|
|||
"""This API defines FeatureColumn abstraction.
|
||||
|
||||
FeatureColumns provide a high level abstraction for ingesting and representing
|
||||
features. FeatureColumns are also the primary way of encoding features for
|
||||
canned `tf.estimator.Estimator`s.
|
||||
|
||||
When using FeatureColumns with `Estimators`, the type of feature column you
|
||||
should choose depends on (1) the feature type and (2) the model type.
|
||||
|
||||
1. Feature type:
|
||||
|
||||
* Continuous features can be represented by `numeric_column`.
|
||||
* Categorical features can be represented by any `categorical_column_with_*`
|
||||
column:
|
||||
- `categorical_column_with_vocabulary_list`
|
||||
- `categorical_column_with_vocabulary_file`
|
||||
- `categorical_column_with_hash_bucket`
|
||||
- `categorical_column_with_identity`
|
||||
- `weighted_categorical_column`
|
||||
|
||||
2. Model type:
|
||||
|
||||
* Deep neural network models (`DNNClassifier`, `DNNRegressor`).
|
||||
|
||||
Continuous features can be directly fed into deep neural network models.
|
||||
|
||||
age_column = numeric_column("age")
|
||||
|
||||
To feed sparse features into DNN models, wrap the column with
|
||||
`embedding_column` or `indicator_column`. `indicator_column` is recommended
|
||||
for features with only a few possible values. For features with many
|
||||
possible values, to reduce the size of your model, `embedding_column` is
|
||||
recommended.
|
||||
|
||||
embedded_dept_column = embedding_column(
|
||||
categorical_column_with_vocabulary_list(
|
||||
"department", ["math", "philosophy", ...]), dimension=10)
|
||||
|
||||
* Wide (aka linear) models (`LinearClassifier`, `LinearRegressor`).
|
||||
|
||||
Sparse features can be fed directly into linear models. They behave like an
|
||||
indicator column but with an efficient implementation.
|
||||
|
||||
dept_column = categorical_column_with_vocabulary_list("department",
|
||||
["math", "philosophy", "english"])
|
||||
|
||||
It is recommended that continuous features be bucketized before being
|
||||
fed into linear models.
|
||||
|
||||
bucketized_age_column = bucketized_column(
|
||||
source_column=age_column,
|
||||
boundaries=[18, 25, 30, 35, 40, 45, 50, 55, 60, 65])
|
||||
|
||||
Sparse features can be crossed (also known as conjuncted or combined) in
|
||||
order to form non-linearities, and then fed into linear models.
|
||||
|
||||
cross_dept_age_column = crossed_column(
|
||||
columns=["department", bucketized_age_column],
|
||||
hash_bucket_size=1000)
|
||||
|
||||
Example of building canned `Estimator`s using FeatureColumns:
|
||||
|
||||
```python
|
||||
# Define features and transformations
|
||||
deep_feature_columns = [age_column, embedded_dept_column]
|
||||
wide_feature_columns = [dept_column, bucketized_age_column,
|
||||
cross_dept_age_column]
|
||||
|
||||
# Build deep model
|
||||
estimator = DNNClassifier(
|
||||
feature_columns=deep_feature_columns,
|
||||
hidden_units=[500, 250, 50])
|
||||
estimator.train(...)
|
||||
|
||||
# Or build a wide model
|
||||
estimator = LinearClassifier(
|
||||
feature_columns=wide_feature_columns)
|
||||
estimator.train(...)
|
||||
|
||||
# Or build a wide and deep model!
|
||||
estimator = DNNLinearCombinedClassifier(
|
||||
linear_feature_columns=wide_feature_columns,
|
||||
dnn_feature_columns=deep_feature_columns,
|
||||
dnn_hidden_units=[500, 250, 50])
|
||||
estimator.train(...)
|
||||
```
|
||||
|
||||
features.
|
||||
|
||||
FeatureColumns can also be transformed into a generic input layer for
|
||||
custom models using `input_layer`.
|
||||
|
||||
Example of building model using FeatureColumns, this can be used in a
|
||||
`model_fn` which is given to the {tf.estimator.Estimator}:
|
||||
|
||||
```python
|
||||
# Building model via layers
|
||||
|
||||
deep_feature_columns = [age_column, embedded_dept_column]
|
||||
columns_to_tensor = parse_feature_columns_from_examples(
|
||||
serialized=my_data,
|
||||
feature_columns=deep_feature_columns)
|
||||
first_layer = input_layer(
|
||||
features=columns_to_tensor,
|
||||
feature_columns=deep_feature_columns)
|
||||
second_layer = fully_connected(first_layer, ...)
|
||||
```
|
||||
|
||||
NOTE: Functions prefixed with "_" indicate experimental or private parts of
|
||||
the API subject to change, and should not be relied upon!
|
||||
"""
|
||||
|
|
@ -548,39 +449,6 @@ def embedding_column(categorical_column,
|
|||
Use this when your inputs are sparse, but you want to convert them to a dense
|
||||
representation (e.g., to feed to a DNN).
|
||||
|
||||
Inputs must be a `CategoricalColumn` created by any of the
|
||||
`categorical_column_*` function. Here is an example of using
|
||||
`embedding_column` with `DNNClassifier`:
|
||||
|
||||
```python
|
||||
video_id = categorical_column_with_identity(
|
||||
key='video_id', num_buckets=1000000, default_value=0)
|
||||
columns = [embedding_column(video_id, 9),...]
|
||||
|
||||
estimator = tf.estimator.DNNClassifier(feature_columns=columns, ...)
|
||||
|
||||
label_column = ...
|
||||
def input_fn():
|
||||
features = tf.io.parse_example(
|
||||
..., features=make_parse_example_spec(columns + [label_column]))
|
||||
labels = features.pop(label_column.name)
|
||||
return features, labels
|
||||
|
||||
estimator.train(input_fn=input_fn, steps=100)
|
||||
```
|
||||
|
||||
Here is an example using `embedding_column` with model_fn:
|
||||
|
||||
```python
|
||||
def model_fn(features, ...):
|
||||
video_id = categorical_column_with_identity(
|
||||
key='video_id', num_buckets=1000000, default_value=0)
|
||||
columns = [embedding_column(video_id, 9),...]
|
||||
dense_tensor = input_layer(features, columns)
|
||||
# Form DNN layers, calculate loss, and return EstimatorSpec.
|
||||
...
|
||||
```
|
||||
|
||||
Args:
|
||||
categorical_column: A `CategoricalColumn` created by a
|
||||
`categorical_column_with_*` function. This column produces the sparse IDs
|
||||
|
|
@ -675,43 +543,6 @@ def shared_embedding_columns(categorical_columns,
|
|||
categorical_column_with_vocabulary_file with the same vocabulary_file. Some or
|
||||
all columns could also be weighted_categorical_column.
|
||||
|
||||
Here is an example embedding of two features for a DNNClassifier model:
|
||||
|
||||
```python
|
||||
watched_video_id = categorical_column_with_vocabulary_file(
|
||||
'watched_video_id', video_vocabulary_file, video_vocabulary_size)
|
||||
impression_video_id = categorical_column_with_vocabulary_file(
|
||||
'impression_video_id', video_vocabulary_file, video_vocabulary_size)
|
||||
columns = shared_embedding_columns(
|
||||
[watched_video_id, impression_video_id], dimension=10)
|
||||
|
||||
estimator = tf.estimator.DNNClassifier(feature_columns=columns, ...)
|
||||
|
||||
label_column = ...
|
||||
def input_fn():
|
||||
features = tf.io.parse_example(
|
||||
..., features=make_parse_example_spec(columns + [label_column]))
|
||||
labels = features.pop(label_column.name)
|
||||
return features, labels
|
||||
|
||||
estimator.train(input_fn=input_fn, steps=100)
|
||||
```
|
||||
|
||||
Here is an example using `shared_embedding_columns` with model_fn:
|
||||
|
||||
```python
|
||||
def model_fn(features, ...):
|
||||
watched_video_id = categorical_column_with_vocabulary_file(
|
||||
'watched_video_id', video_vocabulary_file, video_vocabulary_size)
|
||||
impression_video_id = categorical_column_with_vocabulary_file(
|
||||
'impression_video_id', video_vocabulary_file, video_vocabulary_size)
|
||||
columns = shared_embedding_columns(
|
||||
[watched_video_id, impression_video_id], dimension=10)
|
||||
dense_tensor = input_layer(features, columns)
|
||||
# Form DNN layers, calculate loss, and return EstimatorSpec.
|
||||
...
|
||||
```
|
||||
|
||||
Args:
|
||||
categorical_columns: List of categorical columns created by a
|
||||
`categorical_column_with_*` function. These columns produce the sparse IDs
|
||||
|
|
@@ -871,43 +702,6 @@ def shared_embedding_columns_v2(categorical_columns,
categorical_column_with_vocabulary_file with the same vocabulary_file. Some or
all columns could also be weighted_categorical_column.

Here is an example embedding of two features for a DNNClassifier model:

```python
watched_video_id = categorical_column_with_vocabulary_file(
    'watched_video_id', video_vocabulary_file, video_vocabulary_size)
impression_video_id = categorical_column_with_vocabulary_file(
    'impression_video_id', video_vocabulary_file, video_vocabulary_size)
columns = shared_embedding_columns(
    [watched_video_id, impression_video_id], dimension=10)

estimator = tf.estimator.DNNClassifier(feature_columns=columns, ...)

label_column = ...
def input_fn():
  features = tf.io.parse_example(
      ..., features=make_parse_example_spec(columns + [label_column]))
  labels = features.pop(label_column.name)
  return features, labels

estimator.train(input_fn=input_fn, steps=100)
```

Here is an example using `shared_embedding_columns` with model_fn:

```python
def model_fn(features, ...):
  watched_video_id = categorical_column_with_vocabulary_file(
      'watched_video_id', video_vocabulary_file, video_vocabulary_size)
  impression_video_id = categorical_column_with_vocabulary_file(
      'impression_video_id', video_vocabulary_file, video_vocabulary_size)
  columns = shared_embedding_columns(
      [watched_video_id, impression_video_id], dimension=10)
  dense_tensor = input_layer(features, columns)
  # Form DNN layers, calculate loss, and return EstimatorSpec.
  ...
```

Args:
  categorical_columns: List of categorical columns created by a
    `categorical_column_with_*` function. These columns produce the sparse IDs
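The feature-column docstrings above still show the removed `tf.estimator.DNNClassifier` workflow. For reference, a rough sketch of a Keras-native equivalent using shared lookup and embedding layers follows; the vocabulary, feature names, and layer sizes are illustrative assumptions, not part of this change:

```python
# Hypothetical sketch of a Keras-native replacement for the deprecated
# shared-embedding / DNNClassifier flow. Vocabulary and feature names are
# illustrative assumptions.
import tensorflow as tf

vocab = ["v1", "v2", "v3"]  # stand-in for the video vocabulary file contents
lookup = tf.keras.layers.StringLookup(vocabulary=vocab)
embedding = tf.keras.layers.Embedding(input_dim=lookup.vocabulary_size(),
                                      output_dim=10)

watched = tf.keras.Input(shape=(), dtype=tf.string, name="watched_video_id")
impression = tf.keras.Input(shape=(), dtype=tf.string, name="impression_video_id")

# Both features reuse the same lookup and embedding layers, mirroring the
# weight sharing that `shared_embedding_columns` provided.
dense = tf.keras.layers.Concatenate()(
    [embedding(lookup(watched)), embedding(lookup(impression))])
hidden = tf.keras.layers.Dense(64, activation="relu")(dense)
output = tf.keras.layers.Dense(1, activation="sigmoid")(hidden)

model = tf.keras.Model([watched, impression], output)
model.compile(optimizer="adam", loss="binary_crossentropy")
```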
@@ -2088,7 +2088,6 @@ pytype_strict_library(
    srcs_version = "PY3",
    visibility = visibility + [
        "//tensorflow:internal",
        "//tensorflow_estimator/python/estimator:__subpackages__",
        "//tensorflow_model_optimization:__subpackages__",
        "//third_party/cloud_tpu/convergence_tools:__subpackages__",
        "//third_party/py/neural_structured_learning:__subpackages__",
@@ -609,8 +609,6 @@ class OpCallbacksTest(test_util.TensorFlowTestCase):
    greater_op_outputs = instrument.graph_internal_ndarrays[_GREATER_OP]
    self.assertEqual(len(greater_op_outputs), 1)
    self.assertAllClose(greater_op_outputs[0], False)
    # This was needed for backwards compatibility with TF2 Estimators which
    # rely on variable names.
    prefix = b"cond/" if context.executing_eagerly() else b""
    pow_op_outputs = instrument.graph_internal_ndarrays[b"%spow" % prefix]
    self.assertEqual(len(pow_op_outputs), 1)
@@ -187,9 +187,7 @@ def _as_graph_element(obj):
  return None


# Deprecated - do not use.
# This API to avoid breaking estimator and tensorflow-mesh which depend on this
# internal API. The stub should be safe to use after TF 2.3 is released.
# Deprecated - legacy purposes only.
def is_dense_tensor_like(t) -> bool:
  return isinstance(t, core_tf_types.Tensor)
@@ -2036,13 +2034,10 @@ class Graph(pywrap_tf_session.PyGraph):
    # actual outside graph).
    self._graph_key = "graph-key-%d/" % (uid(),)
    # A string with the last reduction method passed to
    # losses.compute_weighted_loss(), or None. This is required only for
    # backward compatibility with Estimator and optimizer V1 use cases.
    # losses.compute_weighted_loss(), or None.
    # Backward compatibility with optimizer V1 use cases.
    self._last_loss_reduction = None
    # Flag that is used to indicate whether loss has been scaled by optimizer.
    # If this flag has been set, then estimator uses it to scale losss back
    # before reporting. This is required only for backward compatibility with
    # Estimator and optimizer V1 use cases.
    # Required only for backward compatibility with optimizer V1 use cases.
    self._is_loss_scaled_by_optimizer = False
    self._container = ""
@@ -194,8 +194,6 @@ def create_keras_history(tensors):
# (Only via Savedmodels). It may also change the semantics of whether
# generated random numbers are generated once and re-used, or recomputed
# each time.
# Note: This path triggers for TPUEstimators / xla compiled graphs regardless
# of this setting.
_UNSAFE_GRAPH_OP_LAYER_CREATION = False
@@ -838,13 +838,10 @@ class Layer(base_layer.Layer):
      raise ValueError(
          'Your Layer or Model is in an invalid state. '
          'This can happen for the following cases:\n '
          '1. You might be interleaving estimator/non-estimator models or '
          'interleaving models/layers made in tf.compat.v1.Graph.as_default() '
          'with models/layers created outside of it. '
          'Converting a model to an estimator (via model_to_estimator) '
          'invalidates all models/layers made before the conversion (even '
          'if they were not the model converted to an estimator). '
          'Similarly, making a layer or a model inside a '
          '1. You might be interleaving models/layers made in '
          'tf.compat.v1.Graph.as_default() with models/layers created '
          'outside of it.\n'
          'Making a layer or a model inside a '
          'a tf.compat.v1.Graph invalidates all layers/models you previously '
          'made outside of the graph.\n'
          '2. You might be using a custom keras layer implementation with '
@@ -361,7 +361,7 @@ class Model(training_lib.Model):
                  parameter_server_strategy.ParameterServerStrategyV1):
      raise NotImplementedError(
          '`tf.compat.v1.distribute.experimental.ParameterServerStrategy` '
          'currently only works with the tf.Estimator API')
          'currently only works with the deprecated tf.Estimator API')

    if isinstance(self._distribution_strategy,
                  parameter_server_strategy_v2.ParameterServerStrategyV2):
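The error above steers users away from the V1 strategy. For reference, a hedged sketch of the supported TF2 path, `tf.distribute.experimental.ParameterServerStrategy` driven through `Model.fit`, follows; the cluster resolver, toy model, and dataset function are illustrative assumptions and require a real TF_CONFIG-based cluster to run:

```python
# Sketch of the TF2 parameter-server training path (assumes a TF_CONFIG-based
# cluster with chief, workers, and parameter servers already running).
import tensorflow as tf

cluster_resolver = tf.distribute.cluster_resolver.TFConfigClusterResolver()
strategy = tf.distribute.experimental.ParameterServerStrategy(cluster_resolver)

with strategy.scope():
    # Illustrative toy model; variables are placed on the parameter servers.
    model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
    model.compile(optimizer="sgd", loss="mse")

def dataset_fn(input_context):
    # Each worker builds its own dataset from this function.
    ds = tf.data.Dataset.from_tensor_slices(([[1.0]] * 64, [[2.0]] * 64))
    return ds.repeat().batch(8)

# Model.fit coordinates the workers internally in TF2.
model.fit(tf.keras.utils.experimental.DatasetCreator(dataset_fn),
          epochs=1, steps_per_epoch=10)
```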
@@ -1513,9 +1513,7 @@ class Model(training_lib.Model):
    """Compiles the model loss and weighted metric sub-graphs.

    This may be used to set graph tensors as sample weights (instead of creating
    placeholders). This functionality is necessary for
    `tf.keras.estimator.model_to_estimator`, which calls Keras models in a v1
    graph, and creates iterator tensors for inputs, targets, and sample weights.
    placeholders).

    Args:
      sample_weights: List of tensors to use as the sample weights. Must be the
@@ -147,11 +147,10 @@ class Regularizer(object):
  training and executing models, exporting to and from SavedModels, or saving
  and loading weight checkpoints.

  Registration is required for Keras `model_to_estimator`, saving and
  loading models to HDF5 formats, Keras model cloning, some visualization
  utilities, and exporting models to and from JSON. If using this functionality,
  you must make sure any python process running your model has also defined
  and registered your custom regularizer.
  Registration is required for saving and loading models to HDF5 formats,
  Keras model cloning, some visualization utilities, and exporting models to and
  from JSON. If using this functionality, you must make sure any python process
  running your model has also defined and registered your custom regularizer.

  `tf.keras.utils.register_keras_serializable` is only available in TF 2.1 and
  beyond. In earlier versions of TensorFlow you must pass your custom
@@ -171,9 +170,9 @@ class Regularizer(object):
      capable of instantiating the same regularizer from the config
      dictionary.

    This method is used by Keras `model_to_estimator`, saving and
    loading models to HDF5 formats, Keras model cloning, some visualization
    utilities, and exporting models to and from JSON.
    This method is used by saving and loading models to HDF5 formats,
    Keras model cloning, some visualization utilities,
    and exporting models to and from JSON.

    Args:
      config: A Python dictionary, typically the output of get_config.
@@ -194,9 +193,9 @@ class Regularizer(object):
    This method is optional if you are just training and executing models,
    exporting to and from SavedModels, or using weight checkpoints.

    This method is required for Keras `model_to_estimator`, saving and
    loading models to HDF5 formats, Keras model cloning, some visualization
    utilities, and exporting models to and from JSON.
    This method is required for saving and loading models to HDF5 formats,
    Keras model cloning, some visualization utilities,
    and exporting models to and from JSON.

    Returns:
      Python dictionary.
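For reference, a minimal sketch of the registration and `get_config`/`from_config` round-trip the `Regularizer` docstrings describe; the `ScaledL2` class name, package name, and values are hypothetical:

```python
# Hypothetical custom regularizer showing registration plus config round-trip.
import tensorflow as tf

@tf.keras.utils.register_keras_serializable(package="my_pkg")
class ScaledL2(tf.keras.regularizers.Regularizer):

  def __init__(self, l2=0.01):
    self.l2 = l2

  def __call__(self, w):
    return self.l2 * tf.reduce_sum(tf.square(w))

  def get_config(self):
    # Must be enough for the default from_config (cls(**config)) to rebuild it.
    return {"l2": self.l2}

layer = tf.keras.layers.Dense(4, kernel_regularizer=ScaledL2(0.02))
config = layer.get_config()
# Deserialization only works because ScaledL2 was registered above.
restored = tf.keras.layers.Dense.from_config(config)
```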
@@ -189,8 +189,7 @@ def _save_v1_format(model, path, custom_objects, as_text, input_signature):
  # one save is needed once the weights can be copied from the model to clone.
  checkpoint_path = _export_model_variables(model, path)

  # Export each mode. Use ModeKeys enums defined for `Estimator` to ensure that
  # Keras models and `Estimator`s are exported with the same format.
  # Export each mode.
  # Every time a mode is exported, the code checks to see if new variables have
  # been created (e.g. optimizer slot variables). If that is the case, the
  # checkpoint is re-saved to include the new variables.
@@ -234,7 +233,7 @@ def _export_mode(
  """Exports a model, and optionally saves new vars from the clone model.

  Args:
    mode: A `tf.estimator.ModeKeys` string.
    mode: A `KerasModeKeys` string.
    has_saved_vars: A `boolean` indicating whether the SavedModel has already
      exported variables.
    builder: A `SavedModelBuilder` object.
@@ -271,8 +270,7 @@ def _export_mode(

  # Make sure that iterations variable is added to the global step collection,
  # to ensure that, when the SavedModel graph is loaded, the iterations
  # variable is returned by `tf.compat.v1.train.get_global_step()`. This is
  # required for compatibility with the SavedModelEstimator.
  # variable is returned by `tf.compat.v1.train.get_global_step()`.
  if compile_clone:
    g.add_to_collection(ops.GraphKeys.GLOBAL_STEP, clone.optimizer.iterations)
@@ -34,7 +34,7 @@ from tensorflow.python.util import nest
def extract_model_metrics(model):
  """Convert metrics from a Keras model `compile` API to dictionary.

  This is used for converting Keras models to Estimators and SavedModels.
  This is used for converting Keras models to SavedModels.

  Args:
    model: A `tf.keras.Model` object.
@@ -44,7 +44,6 @@ def extract_model_metrics(model):
    the model does not contain any metrics.
  """
  if getattr(model, '_compile_metrics', None):
    # TODO(psv/kathywu): use this implementation in model to estimator flow.
    # We are not using model.metrics here because we want to exclude the metrics
    # added using `add_metric` API.
    return {m.name: m for m in model._compile_metric_functions}  # pylint: disable=protected-access
@@ -13,7 +13,7 @@
# limitations under the License.
# ==============================================================================
# LINT.IfChange
"""Utils for saving a Keras Model or Estimator to the SavedModel format."""
"""Utils for saving a Keras Model to the SavedModel format."""
# pylint: disable=wildcard-import
from tensorflow.python.keras.saving.utils_v1.export_output import *
from tensorflow.python.keras.saving.utils_v1.export_utils import build_all_signature_defs
@@ -270,7 +270,7 @@ def export_outputs_for_mode(
      metric_value must be a Tensor, and update_op must be a Tensor or Op

  Returns:
    Dictionary mapping the a key to an `tf.estimator.export.ExportOutput` object
    Dictionary mapping the key to an `ExportOutput` object.
    The key is the expected SignatureDef key for the mode.

  Raises:
@@ -279,8 +279,6 @@ def export_outputs_for_mode(
  if mode not in SIGNATURE_KEY_MAP:
    raise ValueError(
        'Export output type not found for mode: {}. Expected one of: {}.\n'
        'One likely error is that V1 Estimator Modekeys were somehow passed to '
        'this function. Please ensure that you are using the new ModeKeys.'
        .format(mode, SIGNATURE_KEY_MAP.keys()))
  signature_key = SIGNATURE_KEY_MAP[mode]
  if mode_keys.is_predict(mode):
@@ -13,7 +13,7 @@
# limitations under the License.
# ==============================================================================
# LINT.IfChange
"""Utils for managing different mode strings used by Keras and Estimator models.
"""Utils for managing different mode strings used by Keras models.
"""

import collections
@@ -34,44 +34,24 @@ class KerasModeKeys:
  PREDICT = 'predict'


# TODO(kathywu): Remove copy in Estimator after nightlies
class EstimatorModeKeys:
  """Standard names for Estimator model modes.

  The following standard keys are defined:

  * `TRAIN`: training/fitting mode.
  * `EVAL`: testing/evaluation mode.
  * `PREDICT`: predication/inference mode.
  """

  TRAIN = 'train'
  EVAL = 'eval'
  PREDICT = 'infer'


def is_predict(mode):
  return mode in [KerasModeKeys.PREDICT, EstimatorModeKeys.PREDICT]
  return mode == KerasModeKeys.PREDICT


def is_eval(mode):
  return mode in [KerasModeKeys.TEST, EstimatorModeKeys.EVAL]
  return mode == KerasModeKeys.TEST


def is_train(mode):
  return mode in [KerasModeKeys.TRAIN, EstimatorModeKeys.TRAIN]
  return mode == KerasModeKeys.TRAIN


class ModeKeyMap(collections.abc.Mapping):
  """Map using ModeKeys as keys.

  This class creates an immutable mapping from modes to values. For example,
  SavedModel export of Keras and Estimator models use this to map modes to their
  SavedModel export of Keras models use this to map modes to their
  corresponding MetaGraph tags/SignatureDef keys.

  Since this class uses modes, rather than strings, as keys, both "predict"
  (Keras's PREDICT ModeKey) and "infer" (Estimator's PREDICT ModeKey) map to the
  same value.
  """

  def __init__(self, **kwargs):
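A minimal illustration of the simplified mode helpers once `EstimatorModeKeys` is gone; the module path is an assumption inferred from the imports shown earlier in this diff:

```python
# Assumed module path; KerasModeKeys defines TRAIN/TEST/PREDICT per the hunk above.
from tensorflow.python.keras.saving.utils_v1 import mode_keys

assert mode_keys.is_train(mode_keys.KerasModeKeys.TRAIN)
assert mode_keys.is_eval(mode_keys.KerasModeKeys.TEST)
# Estimator's 'infer' string no longer maps to the predict mode.
assert not mode_keys.is_predict('infer')
```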
@@ -874,8 +874,6 @@ class ControlFlowTest(test.TestCase, parameterized.TestCase):
      return tf_cond.cond(
          pred, lambda: true_fn(inputs), lambda: false_fn(inputs))

    # This was needed for backwards compatibility with TF2 Estimators which
    # rely on variable names.
    prefix = "cond/" if context.executing_eagerly() else ""

    with self.assertRaisesRegex(
@@ -907,8 +905,6 @@ class ControlFlowTest(test.TestCase, parameterized.TestCase):
          lambda: br4_fn(inputs)
      ])

    # This was needed for backwards compatibility with TF2 Estimators which
    # rely on variable names.
    prefix = "switch_case/indexed_case/" if context.executing_eagerly() else ""
    with self.assertRaisesRegex(
        ValueError, "Tensor %sbr1_identity:0 in branch 1 is "
@@ -578,10 +578,7 @@ tf_gen_op_strict_wrapper_private_py(

tf_gen_op_strict_wrapper_private_py(
    name = "sdca_ops_gen",
    visibility = [
        "//tensorflow/python:__pkg__",
        "//tensorflow_estimator/python/estimator/canned/linear_optimizer:__pkg__",
    ],
    visibility = ["//tensorflow/python:__pkg__"],
)

tf_gen_op_strict_wrapper_private_py(
@@ -47,7 +47,6 @@ RANDOM_INIT = 'random'
KMEANS_PLUS_PLUS_INIT = 'kmeans_plus_plus'
KMC2_INIT = 'kmc2'

# The name of the variable holding the cluster centers. Used by the Estimator.
CLUSTERS_VAR_NAME = 'clusters'
@@ -406,15 +406,15 @@ def set_output_all_intermediates(state):  # pylint: disable=invalid-name
  """Whether to output all intermediates from functional control flow ops.

  The "default" behavior to is to output all intermediates when using v2 control
  flow inside Keras models in graph mode (possibly inside Estimators). This is
  needed to support taking gradients of v2 control flow. In graph mode, Keras
  can sometimes freeze the forward graph before the gradient computation which
  does not work for v2 control flow since it requires updating the forward ops
  to output the needed intermediates. We work around this by proactively
  outputting the needed intermediates when building the forward pass itself.
  Ideally any such extra tensors should be pruned out at runtime. However, if
  for any reason this doesn't work for you or if you have an inference-only
  model you can turn this behavior off using
  flow inside Keras models in graph mode. This is needed to support taking
  gradients of v2 control flow. In graph mode, Keras can sometimes freeze the
  forward graph before the gradient computation which does not work for v2
  control flow since it requires updating the forward ops to output the needed
  intermediates. We work around this by proactively outputting the needed
  intermediates when building the forward pass itself. Ideally any such extra
  tensors should be pruned out at runtime. However, if for any reason this
  doesn't work for you or if you have an inference-only model you can turn this
  behavior off using
  `tf.compat.v1.experimental.output_all_intermediates(False)`.

  If with the default behavior you are still seeing errors of the form
|
|||
func = EagerFunc(func, Tout, is_grad_func)
|
||||
|
||||
# Tying the registered function's lifetime with the current default graph is
|
||||
# not reliable. For example, Estimator-based binaries may switch graphs in
|
||||
# between model training end evaluation, via saved_model. Those binaries work
|
||||
# because the original function is global, and break once the registered
|
||||
# not reliable. For example, a binary may switch graphs in between model
|
||||
# training end evaluation, via saved_model. Those binaries work because the
|
||||
# original function is global, and break once the registered
|
||||
# function is an anonymous lambda, like the one produced by do_not_convert.
|
||||
# To avoid breaking those cases, we attach the wrapper to the original
|
||||
# function so that their lifetime is connected.
|
||||
|
|
|
|||
|
|
@ -142,8 +142,6 @@ def while_loop(cond,
|
|||
return math_ops.logical_and(
|
||||
loop_counter < maximum_iterations_arg, pred)
|
||||
|
||||
# NOTE(skyewm): we set collections to the outer graph's collections for
|
||||
# compatibility with TPUEstimator.
|
||||
cond_graph = func_graph_module.func_graph_from_py_func(
|
||||
cond_name,
|
||||
wrapped_cond,
|
||||
|
|
|
|||