Make TPU symbols more easily accessible from contrib.

PiperOrigin-RevId: 165753322
This commit is contained in:
A. Unique TensorFlower 2017-08-18 14:30:07 -07:00 committed by TensorFlower Gardener
parent cdc08afbb2
commit a0544b0b8e
4 changed files with 44 additions and 82 deletions

View File

@ -77,9 +77,7 @@ py_library(
"//tensorflow/contrib/text:text_py",
"//tensorflow/contrib/tfprof",
"//tensorflow/contrib/timeseries",
"//tensorflow/contrib/tpu:tpu_estimator",
"//tensorflow/contrib/tpu:tpu_helper_library",
"//tensorflow/contrib/tpu:tpu_py",
"//tensorflow/contrib/tpu",
"//tensorflow/contrib/training:training_py",
"//tensorflow/contrib/util:util_py",
],

View File

@ -38,9 +38,7 @@ py_library(
srcs_version = "PY2AND3",
deps = [
":tpu",
":tpu_feed",
":tpu_py",
":training_loop",
"//tensorflow/contrib/learn",
"//tensorflow/python:array_ops",
"//tensorflow/python:control_flow_ops",
@ -123,89 +121,32 @@ tf_custom_op_py_library(
],
)
py_library(
name = "tpu_helper_library",
srcs_version = "PY2AND3",
deps = [
":tpu",
":tpu_feed",
":tpu_function",
":tpu_py",
":tpu_sharding",
":training_loop",
],
)
py_library(
name = "tpu_function",
srcs = ["python/tpu/tpu_function.py"],
srcs_version = "PY2AND3",
deps = [
":tpu_feed",
"//tensorflow/python:util",
],
)
py_library(
name = "tpu",
srcs = [
"python/tpu/__init__.py",
"python/tpu/tpu.py",
],
srcs_version = "PY2AND3",
deps = [
":profiler",
":tpu_function",
":tpu_py",
":training_loop",
"//tensorflow/core:protos_all_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:framework",
"//tensorflow/python:framework_ops",
"//tensorflow/python:variable_scope",
],
)
py_library(
name = "tpu_sharding",
srcs = ["python/tpu/tpu_sharding.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/python:framework",
"//tensorflow/python:tensor_shape",
],
)
py_library(
name = "tpu_feed",
srcs = ["python/tpu/tpu_feed.py"],
srcs_version = "PY2AND3",
deps = [
":tpu_py",
":tpu_sharding",
"//tensorflow/python:array_ops",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework",
"//tensorflow/python:framework_ops",
"//tensorflow/python:tensor_shape",
],
)
py_library(
name = "training_loop",
srcs = [
"python/tpu/tpu_feed.py",
"python/tpu/tpu_function.py",
"python/tpu/tpu_optimizer.py",
"python/tpu/tpu_sharding.py",
"python/tpu/training_loop.py",
],
srcs_version = "PY2AND3",
deps = [
":tpu_function",
":profiler",
":tpu_py",
"//tensorflow/core:protos_all_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework",
"//tensorflow/python:framework_ops",
"//tensorflow/python:tensor_shape",
"//tensorflow/python:training",
"//tensorflow/python:util",
"//tensorflow/python:variable_scope",
"//tensorflow/python/ops/losses",
],
)
@ -214,7 +155,7 @@ tf_py_test(
size = "small",
srcs = ["python/tpu/tpu_sharding_test.py"],
additional_deps = [
":tpu_sharding",
":tpu",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework",
],
@ -225,8 +166,7 @@ tf_py_test(
size = "small",
srcs = ["python/tpu/tpu_infeed_test.py"],
additional_deps = [
":tpu_feed",
":tpu_sharding",
":tpu",
"//tensorflow/python:framework",
"//tensorflow/python:framework_test_lib",
],
@ -237,7 +177,7 @@ tf_py_test(
size = "small",
srcs = ["python/tpu/tpu_function_test.py"],
additional_deps = [
":tpu_function",
":tpu",
"//tensorflow/python:framework",
"//tensorflow/python:framework_test_lib",
],

View File

@ -13,7 +13,30 @@
# limitations under the License.
# =============================================================================
"""Ops related to Tensor Processing Units."""
"""Ops related to Tensor Processing Units.
@@cross_replica_sum
@@infeed_dequeue
@@infeed_dequeue_tuple
@@outfeed_enqueue
@@outfeed_enqueue_tuple
@@initialize_system
@@shutdown_system
@@core
@@outside_all_rewrites
@@replicate
@@shard
@@batch_parallel
@@rewrite
@@CrossShardOptimizer
@@InfeedQueue
@@while_loop
@@repeat
"""
from __future__ import absolute_import
from __future__ import division
@ -22,7 +45,10 @@ from __future__ import print_function
# pylint: disable=wildcard-import,unused-import
from tensorflow.contrib.tpu.python import profiler
from tensorflow.contrib.tpu.python.ops.tpu_ops import *
from tensorflow.contrib.tpu.python.tpu import *
from tensorflow.contrib.tpu.python.tpu.tpu import *
from tensorflow.contrib.tpu.python.tpu.tpu_feed import *
from tensorflow.contrib.tpu.python.tpu.tpu_optimizer import *
from tensorflow.contrib.tpu.python.tpu.training_loop import *
# pylint: enable=wildcard-import,unused-import
from tensorflow.python.util.all_util import remove_undocumented

View File

@ -166,9 +166,7 @@ sh_binary(
"//tensorflow/contrib/tensor_forest:init_py",
"//tensorflow/contrib/tensor_forest/hybrid:hybrid_pip",
"//tensorflow/contrib/timeseries:timeseries_pip",
"//tensorflow/contrib/tpu:tpu_estimator",
"//tensorflow/contrib/tpu:tpu_helper_library",
"//tensorflow/contrib/tpu:tpu_py",
"//tensorflow/contrib/tpu",
"//tensorflow/examples/tutorials/mnist:package",
"//tensorflow/python:distributed_framework_test_lib",
"//tensorflow/python:meta_graph_testdata",