mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
Merge branch 'tensorflow:master' into ica574-doc-contrib
This commit is contained in:
commit
b154e2c6fb
10
.bazelrc
10
.bazelrc
|
|
@ -194,6 +194,9 @@ build:macos --apple_platform_type=macos
|
|||
# gRPC on MacOS requires this #define
|
||||
build:macos --copt=-DGRPC_BAZEL_BUILD
|
||||
|
||||
# Avoid hitting command line argument limit
|
||||
build:macos --features=archive_param_file
|
||||
|
||||
# Settings for MacOS on ARM CPUs.
|
||||
build:macos_arm64 --cpu=darwin_arm64
|
||||
build:macos_arm64 --macos_minimum_os=11.0
|
||||
|
|
@ -345,6 +348,7 @@ build:windows --host_copt=/D_USE_MATH_DEFINES
|
|||
# Windows has a relatively short command line limit, which TF has begun to hit.
|
||||
# See https://docs.bazel.build/versions/main/windows.html
|
||||
build:windows --features=compiler_param_file
|
||||
build:windows --features=archive_param_file
|
||||
|
||||
# Speed Windows compile times. Available in VS 16.4 (we are on 16.11). See
|
||||
# https://groups.google.com/a/tensorflow.org/d/topic/build/SsW98Eo7l3o/discussion
|
||||
|
|
@ -446,7 +450,6 @@ build:rbe --bes_backend=buildeventservice.googleapis.com
|
|||
build:rbe --bes_results_url="https://source.cloud.google.com/results/invocations"
|
||||
build:rbe --bes_timeout=600s
|
||||
build:rbe --define=EXECUTOR=remote
|
||||
build:rbe --flaky_test_attempts=3
|
||||
build:rbe --jobs=800
|
||||
build:rbe --remote_executor=grpcs://remotebuildexecution.googleapis.com
|
||||
build:rbe --remote_timeout=3600
|
||||
|
|
@ -627,7 +630,6 @@ try-import %workspace%/.bazelrc.user
|
|||
|
||||
# Here are bazelrc configs for release builds
|
||||
build:release_base --config=v2
|
||||
test:release_base --flaky_test_attempts=3
|
||||
test:release_base --test_size_filters=small,medium
|
||||
|
||||
build:release_cpu_linux --config=release_base
|
||||
|
|
@ -691,10 +693,10 @@ build:ubsan --linkopt -fsanitize=undefined
|
|||
build:ubsan --linkopt -lubsan
|
||||
|
||||
# Disable TFRT integration for now unless --config=tfrt is specified.
|
||||
build --deleted_packages=tensorflow/compiler/mlir/tfrt,tensorflow/compiler/mlir/tfrt/benchmarks,tensorflow/compiler/mlir/tfrt/jit/python_binding,tensorflow/compiler/mlir/tfrt/jit/transforms,tensorflow/compiler/mlir/tfrt/python_tests,tensorflow/compiler/mlir/tfrt/tests,tensorflow/compiler/mlir/tfrt/tests/ir,tensorflow/compiler/mlir/tfrt/tests/analysis,tensorflow/compiler/mlir/tfrt/tests/jit,tensorflow/compiler/mlir/tfrt/tests/lhlo_to_tfrt,tensorflow/compiler/mlir/tfrt/tests/lhlo_to_jitrt,tensorflow/compiler/mlir/tfrt/tests/tf_to_corert,tensorflow/compiler/mlir/tfrt/tests/tf_to_tfrt_data,tensorflow/compiler/mlir/tfrt/tests/saved_model,tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu,tensorflow/core/runtime_fallback,tensorflow/core/runtime_fallback/conversion,tensorflow/core/runtime_fallback/kernel,tensorflow/core/runtime_fallback/opdefs,tensorflow/core/runtime_fallback/runtime,tensorflow/core/runtime_fallback/util,tensorflow/core/tfrt/eager,tensorflow/core/tfrt/eager/backends/cpu,tensorflow/core/tfrt/eager/backends/gpu,tensorflow/core/tfrt/eager/core_runtime,tensorflow/core/tfrt/eager/cpp_tests/core_runtime,tensorflow/core/tfrt/gpu,tensorflow/core/tfrt/run_handler_thread_pool,tensorflow/core/tfrt/runtime,tensorflow/core/tfrt/saved_model,tensorflow/core/tfrt/graph_executor,tensorflow/core/tfrt/saved_model/tests,tensorflow/core/tfrt/tpu,tensorflow/core/tfrt/utils,tensorflow/core/tfrt/utils/debug
|
||||
build --deleted_packages=tensorflow/compiler/mlir/tfrt,tensorflow/compiler/mlir/tfrt/benchmarks,tensorflow/compiler/mlir/tfrt/ir,tensorflow/compiler/mlir/tfrt/ir/mlrt,tensorflow/compiler/mlir/tfrt/jit/python_binding,tensorflow/compiler/mlir/tfrt/jit/transforms,tensorflow/compiler/mlir/tfrt/python_tests,tensorflow/compiler/mlir/tfrt/tests,tensorflow/compiler/mlir/tfrt/tests/mlrt,tensorflow/compiler/mlir/tfrt/tests/ir,tensorflow/compiler/mlir/tfrt/tests/analysis,tensorflow/compiler/mlir/tfrt/tests/jit,tensorflow/compiler/mlir/tfrt/tests/lhlo_to_tfrt,tensorflow/compiler/mlir/tfrt/tests/lhlo_to_jitrt,tensorflow/compiler/mlir/tfrt/tests/tf_to_corert,tensorflow/compiler/mlir/tfrt/tests/tf_to_tfrt_data,tensorflow/compiler/mlir/tfrt/tests/saved_model,tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu,tensorflow/compiler/mlir/tfrt/transforms/mlrt,tensorflow/core/runtime_fallback,tensorflow/core/runtime_fallback/conversion,tensorflow/core/runtime_fallback/kernel,tensorflow/core/runtime_fallback/opdefs,tensorflow/core/runtime_fallback/runtime,tensorflow/core/runtime_fallback/util,tensorflow/core/tfrt/mlrt,tensorflow/core/tfrt/mlrt/attribute,tensorflow/core/tfrt/mlrt/kernel,tensorflow/core/tfrt/mlrt/bytecode,tensorflow/core/tfrt/mlrt/interpreter,tensorflow/compiler/mlir/tfrt/translate/mlrt,tensorflow/compiler/mlir/tfrt/translate/mlrt/testdata,tensorflow/core/tfrt/gpu,tensorflow/core/tfrt/run_handler_thread_pool,tensorflow/core/tfrt/runtime,tensorflow/core/tfrt/saved_model,tensorflow/core/tfrt/graph_executor,tensorflow/core/tfrt/saved_model/tests,tensorflow/core/tfrt/tpu,tensorflow/core/tfrt/utils,tensorflow/core/tfrt/utils/debug,tensorflow/core/tfrt/saved_model/python,tensorflow/core/tfrt/graph_executor/python
|
||||
# TODO(b/240450920): We are in the process of migrating JitRt backend to XLA
|
||||
# and while we are doing this we can't keep it buildable/testable in OSS.
|
||||
build:tfrt --deleted_packages=tensorflow/compiler/mlir/tfrt,tensorflow/compiler/mlir/tfrt/benchmarks,tensorflow/compiler/mlir/tfrt/jit/python_binding,tensorflow/compiler/mlir/tfrt/jit/transforms,tensorflow/compiler/mlir/tfrt/python_tests,tensorflow/compiler/mlir/tfrt/tests,tensorflow/compiler/mlir/tfrt/tests/ir,tensorflow/compiler/mlir/tfrt/tests/analysis,tensorflow/compiler/mlir/tfrt/tests/jit,tensorflow/compiler/mlir/tfrt/tests/lhlo_to_tfrt,tensorflow/compiler/mlir/tfrt/tests/lhlo_to_jitrt,tensorflow/compiler/mlir/tfrt/tests/tf_to_corert,tensorflow/compiler/mlir/tfrt/tests/tf_to_tfrt_data,tensorflow/compiler/mlir/tfrt/tests/saved_model,tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu,tensorflow/core/runtime_fallback,tensorflow/core/runtime_fallback/conversion,tensorflow/core/runtime_fallback/kernel,tensorflow/core/runtime_fallback/opdefs,tensorflow/core/runtime_fallback/runtime,tensorflow/core/runtime_fallback/util,tensorflow/core/tfrt/eager,tensorflow/core/tfrt/eager/backends/cpu,tensorflow/core/tfrt/eager/backends/gpu,tensorflow/core/tfrt/eager/core_runtime,tensorflow/core/tfrt/eager/cpp_tests/core_runtime,tensorflow/core/tfrt/gpu,tensorflow/core/tfrt/run_handler_thread_pool,tensorflow/core/tfrt/runtime,tensorflow/core/tfrt/saved_model,tensorflow/core/tfrt/graph_executor,tensorflow/core/tfrt/saved_model/tests,tensorflow/core/tfrt/tpu,tensorflow/core/tfrt/utils,tensorflow/core/tfrt/utils/debug
|
||||
build:tfrt --deleted_packages=tensorflow/compiler/mlir/tfrt,tensorflow/compiler/mlir/tfrt/benchmarks,tensorflow/compiler/mlir/tfrt/ir,tensorflow/compiler/mlir/tfrt/ir/mlrt,tensorflow/compiler/mlir/tfrt/jit/python_binding,tensorflow/compiler/mlir/tfrt/jit/transforms,tensorflow/compiler/mlir/tfrt/python_tests,tensorflow/compiler/mlir/tfrt/tests,tensorflow/compiler/mlir/tfrt/tests/mlrt,tensorflow/compiler/mlir/tfrt/tests/ir,tensorflow/compiler/mlir/tfrt/tests/analysis,tensorflow/compiler/mlir/tfrt/tests/jit,tensorflow/compiler/mlir/tfrt/tests/lhlo_to_tfrt,tensorflow/compiler/mlir/tfrt/tests/lhlo_to_jitrt,tensorflow/compiler/mlir/tfrt/tests/tf_to_corert,tensorflow/compiler/mlir/tfrt/tests/tf_to_tfrt_data,tensorflow/compiler/mlir/tfrt/tests/saved_model,tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu,tensorflow/compiler/mlir/tfrt/transforms/mlrt,tensorflow/core/runtime_fallback,tensorflow/core/runtime_fallback/conversion,tensorflow/core/runtime_fallback/kernel,tensorflow/core/runtime_fallback/opdefs,tensorflow/core/runtime_fallback/runtime,tensorflow/core/runtime_fallback/util,tensorflow/core/tfrt/mlrt,tensorflow/core/tfrt/mlrt/attribute,tensorflow/core/tfrt/mlrt/kernel,tensorflow/core/tfrt/mlrt/bytecode,tensorflow/core/tfrt/mlrt/interpreter,tensorflow/compiler/mlir/tfrt/translate/mlrt,tensorflow/compiler/mlir/tfrt/translate/mlrt/testdata,tensorflow/core/tfrt/gpu,tensorflow/core/tfrt/run_handler_thread_pool,tensorflow/core/tfrt/runtime,tensorflow/core/tfrt/saved_model,tensorflow/core/tfrt/graph_executor,tensorflow/core/tfrt/saved_model/tests,tensorflow/core/tfrt/tpu,tensorflow/core/tfrt/utils,tensorflow/core/tfrt/utils/debug,tensorflow/core/tfrt/saved_model/python,tensorflow/core/tfrt/graph_executor/python
|
||||
|
||||
# TF Fuzztest config
|
||||
try-import fuzztest.bazelrc
|
||||
|
|
|
|||
|
|
@ -1,2 +1,2 @@
|
|||
5.3.0
|
||||
6.1.0
|
||||
# NOTE: Update Bazel version in tensorflow/tools/ci_build/release/common.sh.oss
|
||||
2
.github/bot_config.yml
vendored
2
.github/bot_config.yml
vendored
|
|
@ -15,7 +15,7 @@
|
|||
|
||||
# A list of assignees
|
||||
assignees:
|
||||
- synandi
|
||||
- sushreebarsa
|
||||
- SuryanarayanaY
|
||||
- tilakrayal
|
||||
# A list of assignees for compiler folder
|
||||
|
|
|
|||
1
.github/workflows/arm-cd.yml
vendored
1
.github/workflows/arm-cd.yml
vendored
|
|
@ -28,6 +28,7 @@ jobs:
|
|||
runs-on: [self-hosted, linux, ARM64]
|
||||
continue-on-error: ${{ matrix.experimental }}
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
pyver: ['3.8', '3.9', '3.10']
|
||||
experimental: [false]
|
||||
|
|
|
|||
3
.github/workflows/arm-ci-extended.yml
vendored
3
.github/workflows/arm-ci-extended.yml
vendored
|
|
@ -27,8 +27,9 @@ jobs:
|
|||
if: github.repository == 'tensorflow/tensorflow' # Don't do this in forks
|
||||
runs-on: [self-hosted, linux, ARM64]
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
pyver: ['3.10']
|
||||
pyver: ['3.8', '3.9', '3.10', '3.11']
|
||||
steps:
|
||||
- name: Stop old running containers (if any)
|
||||
shell: bash
|
||||
|
|
|
|||
12
.github/workflows/update-rbe.yml
vendored
12
.github/workflows/update-rbe.yml
vendored
|
|
@ -92,6 +92,18 @@ jobs:
|
|||
map sigbuild-r2.13-clang-python3.9 2.13-python3.9
|
||||
map sigbuild-r2.13-clang-python3.10 2.13-python3.10
|
||||
map sigbuild-r2.13-clang-python3.11 2.13-python3.11
|
||||
# TF 2.14
|
||||
map sigbuild-r2.14 2.14-python3.9
|
||||
map sigbuild-r2.14-python3.8 2.14-python3.8
|
||||
map sigbuild-r2.14-python3.9 2.14-python3.9
|
||||
map sigbuild-r2.14-python3.10 2.14-python3.10
|
||||
map sigbuild-r2.14-python3.11 2.14-python3.11
|
||||
# TF 2.14 + Clang (containers are the same, but env vars in configs.bzl are different)
|
||||
map sigbuild-r2.14-clang 2.14-python3.9
|
||||
map sigbuild-r2.14-clang-python3.8 2.14-python3.8
|
||||
map sigbuild-r2.14-clang-python3.9 2.14-python3.9
|
||||
map sigbuild-r2.14-clang-python3.10 2.14-python3.10
|
||||
map sigbuild-r2.14-clang-python3.11 2.14-python3.11
|
||||
- name: Create Pull Request with changes
|
||||
uses: peter-evans/create-pull-request@2b011faafdcbc9ceb11414d64d0573f37c774b04 # v4.2.3
|
||||
with:
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@
|
|||
<img src="https://www.tensorflow.org/images/tf_logo_horizontal.png">
|
||||
</div>
|
||||
|
||||
[](https://badge.fury.io/py/tensorflow)
|
||||
[](https://badge.fury.io/py/tensorflow)
|
||||
[](https://badge.fury.io/py/tensorflow)
|
||||
[](https://doi.org/10.5281/zenodo.4724125)
|
||||
[](https://bestpractices.coreinfrastructure.org/projects/1486)
|
||||
|
|
@ -11,6 +11,8 @@
|
|||
[](https://bugs.chromium.org/p/oss-fuzz/issues/list?sort=-opened&can=1&q=proj:tensorflow-py)
|
||||
[](https://ossrank.com/p/44)
|
||||
[](CODE_OF_CONDUCT.md)
|
||||
[](https://tensorflow.github.io/build#TF%20Official%20Continuous)
|
||||
[](https://tensorflow.github.io/build#TF%20Official%20Nightly)
|
||||
|
||||
**`Documentation`** |
|
||||
------------------- |
|
||||
|
|
|
|||
35
RELEASE.md
35
RELEASE.md
|
|
@ -16,6 +16,11 @@
|
|||
2.13 may be used when it is necessary to determine if a value is
|
||||
specifically a symbolic tensor.
|
||||
|
||||
* `tf.compat.v1.Session`
|
||||
* `tf.compat.v1.Session.partial_run` and
|
||||
`tf.compat.v1.Session.partial_run_setup` will be deprecated in the
|
||||
next release.
|
||||
|
||||
# Known Caveats
|
||||
|
||||
* <CAVEATS REGARDING THE RELEASE (BUT NOT BREAKING CHANGES).>
|
||||
|
|
@ -26,6 +31,15 @@
|
|||
|
||||
* <INSERT MAJOR FEATURE HERE, USING MARKDOWN SYNTAX>
|
||||
* <IF RELEASE CONTAINS MULTIPLE FEATURES FROM SAME AREA, GROUP THEM TOGETHER>
|
||||
* `tf.keras`
|
||||
* `Model.compile` now support `steps_per_execution='auto'` as a parameter,
|
||||
allowing automatic tuning of steps per execution during `Model.fit`,
|
||||
`Model.predict`, and `Model.evaluate` for a significant performance boost.
|
||||
|
||||
* Enable JIT-compiled i64-indexed kernels on GPU for large tensors with more
|
||||
than 2**32 elements.
|
||||
* Unary GPU kernels: Abs, Atanh, Acos, Acosh, Asin, Asinh, Atan, Cos,
|
||||
Cosh, Sin, Sinh, Tan, Tanh.
|
||||
|
||||
# Bug Fixes and Other Changes
|
||||
* `tf.lite`
|
||||
|
|
@ -34,6 +48,22 @@
|
|||
* <IF A CHANGE CLOSES A GITHUB ISSUE, IT SHOULD BE DOCUMENTED HERE>
|
||||
* <NOTES SHOULD BE GROUPED PER AREA>
|
||||
|
||||
* `tf.config.experimental.enable_tensor_float_32_execution`
|
||||
* Disabling TensorFloat-32 execution now causes TPUs to use float32
|
||||
precision for float32 matmuls and other ops. TPUs have always used
|
||||
bfloat16 precision for certain ops, like matmul, when such ops had float32
|
||||
inputs. Now, disabling TensorFloat-32 by calling
|
||||
`tf.config.experimental.enable_tensor_float_32_execution(False)` will
|
||||
cause TPUs to use float32 precision for such ops instead of bfloat16.
|
||||
|
||||
* `tf.experimental.dtensor`
|
||||
* API changes for Relayout. Added a new API, `dtensor.relayout_like`, for
|
||||
relayouting a tensor according to the layout of another tensor.
|
||||
* Added `dtensor.get_default_mesh`, for retrieving the current default
|
||||
mesh under the dtensor context.
|
||||
|
||||
* TensorFlow Debugger (tfdbg) CLI: ncurses-based CLI for tfdbg v1 was removed.
|
||||
|
||||
# Thanks to our Contributors
|
||||
|
||||
This release contains contributions from many people at Google, as well as:
|
||||
|
|
@ -185,6 +215,9 @@ This release contains contributions from many people at Google, as well as:
|
|||
`dataset = dataset.shuffle(dataset.cardinality())`. This will load the
|
||||
full dataset into memory so that it can be shuffled, so make sure to
|
||||
only use this with datasets of filenames or other small datasets.
|
||||
* Added a new `tf.data.experimental.pad_to_cardinality` transformation
|
||||
which pads a dataset with zero elements up to a specified cardinality.
|
||||
This is useful for avoiding partial batches while not dropping any data.
|
||||
|
||||
* `tf.math`
|
||||
|
||||
|
|
@ -243,6 +276,8 @@ This release contains contributions from many people at Google, as well as:
|
|||
|
||||
* `tf.lite`:
|
||||
* Add UINT32 support to tfl.pack
|
||||
* Add INT64 support to tfl.range
|
||||
* Add UINT32 support to tfl.concatenation
|
||||
|
||||
## Thanks to our Contributors
|
||||
|
||||
|
|
|
|||
17
ci/README.md
Normal file
17
ci/README.md
Normal file
|
|
@ -0,0 +1,17 @@
|
|||
# TensorFlow continuous integration
|
||||
|
||||
> **Warning** This folder is still under construction. It is part of an ongoing
|
||||
> effort to improve the structure of CI and build related files within the
|
||||
> TensorFlow repo. This warning will be removed when the contents of this
|
||||
> directory are stable and appropriate documentation around its usage is in
|
||||
> place.
|
||||
|
||||
Maintainer: TensorFlow DevInfra
|
||||
|
||||
********************************************************************************
|
||||
|
||||
The CI folder contains the configuration files and scripts used to build, test,
|
||||
and deploy TensorFlow. This folder is typically used by continuous integration
|
||||
(CI) tools to build and test TensorFlow whenever there is a change to the
|
||||
code. This folder is broken into subfolders that represent the level of support
|
||||
and ownership of the files contained within.
|
||||
17
ci/devinfra/README.md
Normal file
17
ci/devinfra/README.md
Normal file
|
|
@ -0,0 +1,17 @@
|
|||
# DevInfra CI Directory
|
||||
|
||||
> **Warning** This folder is still under construction. It is part of an ongoing
|
||||
> effort to improve the structure of CI and build related files within the
|
||||
> TensorFlow repo. This warning will be removed when the contents of this
|
||||
> directory are stable and appropriate documentation around its usage is in
|
||||
> place.
|
||||
|
||||
Maintainer: TensorFlow DevInfra
|
||||
|
||||
Issue Reporting: File an issue against this repo and tag
|
||||
[@devinfra](https://github.com/orgs/tensorflow/teams/devinfra)
|
||||
|
||||
********************************************************************************
|
||||
|
||||
A directory for build and CI related scripts and jobs managed by the TensorFlow
|
||||
DevInfra team but not part of the official build, test, or release process.
|
||||
17
ci/official/README.md
Normal file
17
ci/official/README.md
Normal file
|
|
@ -0,0 +1,17 @@
|
|||
# Official CI Directory
|
||||
|
||||
> **Warning** This folder is still under construction. It is part of an ongoing
|
||||
> effort to improve the structure of CI and build related files within the
|
||||
> TensorFlow repo. This warning will be removed when the contents of this
|
||||
> directory are stable and appropriate documentation around its usage is in
|
||||
> place.
|
||||
|
||||
Maintainer: TensorFlow and TensorFlow DevInfra
|
||||
|
||||
Issue Reporting: File an issue against this repo and tag
|
||||
[@devinfra](https://github.com/orgs/tensorflow/teams/devinfra)
|
||||
|
||||
********************************************************************************
|
||||
|
||||
A directory for build and CI related scripts and jobs that are used and
|
||||
monitored as part of the official TensorFlow build, test, and release process.
|
||||
|
|
@ -964,7 +964,6 @@ def set_other_cuda_vars(environ_cp):
|
|||
|
||||
def system_specific_test_config(environ_cp):
|
||||
"""Add default build and test flags required for TF tests to bazelrc."""
|
||||
write_to_bazelrc('test --flaky_test_attempts=3')
|
||||
write_to_bazelrc('test --test_size_filters=small,medium')
|
||||
|
||||
# Each instance of --test_tag_filters or --build_tag_filters overrides all
|
||||
|
|
|
|||
|
|
@ -109,6 +109,7 @@ PACKAGE_STATIC_DEPS = [
|
|||
"@local_execution_config_platform//:__subpackages__",
|
||||
"@mkl_dnn_acl_compatible//:__subpackages__",
|
||||
"@mkl_dnn_v1//:__subpackages__",
|
||||
"@ml_dtypes//:__subpackages__",
|
||||
"@nccl_archive//:__subpackages__",
|
||||
"@nvtx_archive//:__subpackages__",
|
||||
"@org_sqlite//:__subpackages__",
|
||||
|
|
@ -1036,7 +1037,13 @@ package_group(
|
|||
],
|
||||
)
|
||||
|
||||
package_group(name = "ndarray_tensor_allow_list")
|
||||
package_group(
|
||||
name = "ndarray_tensor_allow_list",
|
||||
packages = [
|
||||
"//third_party/py/courier/...",
|
||||
"//third_party/py/tensorfn/...",
|
||||
],
|
||||
)
|
||||
|
||||
# Packages that use private types symbols, until they are exported.
|
||||
# TODO(b/154650521) Remove.
|
||||
|
|
|
|||
|
|
@ -129,6 +129,7 @@ cc_library(
|
|||
# TODO: Only include tf_tstring_hdrs. Don't expose the implementation of TF_TString to API
|
||||
# users.
|
||||
":tf_tstring",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
],
|
||||
)
|
||||
|
||||
|
|
@ -171,6 +172,7 @@ tf_cuda_library(
|
|||
":tf_buffer_internal",
|
||||
":tf_status_internal",
|
||||
":tf_tensor_internal",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
],
|
||||
)
|
||||
|
||||
|
|
@ -238,6 +240,7 @@ tf_cuda_library(
|
|||
":tf_status_internal",
|
||||
":tf_tensor_internal",
|
||||
":tf_tstring",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/platform:tstring",
|
||||
"//tensorflow/tsl/c:tsl_status",
|
||||
] + select({
|
||||
|
|
@ -881,7 +884,7 @@ tf_cc_test(
|
|||
|
||||
tf_cuda_cc_test(
|
||||
name = "c_api_test",
|
||||
size = "small",
|
||||
size = "medium",
|
||||
srcs = ["c_api_test.cc"],
|
||||
data = [
|
||||
":test_op1.so",
|
||||
|
|
@ -968,7 +971,7 @@ tf_cc_test(
|
|||
|
||||
tf_cc_test(
|
||||
name = "c_api_function_test",
|
||||
size = "small",
|
||||
size = "medium",
|
||||
srcs = ["c_api_function_test.cc"],
|
||||
deps = [
|
||||
":c_api",
|
||||
|
|
@ -985,7 +988,7 @@ tf_cc_test(
|
|||
|
||||
tf_cc_test(
|
||||
name = "while_loop_test",
|
||||
size = "small",
|
||||
size = "medium",
|
||||
srcs = ["while_loop_test.cc"],
|
||||
deps = [
|
||||
":c_api",
|
||||
|
|
@ -1013,7 +1016,7 @@ tf_kernel_library(
|
|||
|
||||
tf_cuda_cc_test(
|
||||
name = "env_test",
|
||||
size = "small",
|
||||
size = "medium",
|
||||
srcs = ["env_test.cc"],
|
||||
linkopts = select({
|
||||
"//tensorflow:macos": ["-headerpad_max_install_names"],
|
||||
|
|
@ -1032,7 +1035,7 @@ tf_cuda_cc_test(
|
|||
|
||||
tf_cuda_cc_test(
|
||||
name = "kernels_test",
|
||||
size = "small",
|
||||
size = "medium",
|
||||
srcs = ["kernels_test.cc"],
|
||||
linkopts = select({
|
||||
"//tensorflow:macos": ["-headerpad_max_install_names"],
|
||||
|
|
@ -1059,7 +1062,7 @@ tf_cuda_cc_test(
|
|||
|
||||
tf_cc_test(
|
||||
name = "ops_test",
|
||||
size = "small",
|
||||
size = "medium",
|
||||
srcs = ["ops_test.cc"],
|
||||
linkopts = select({
|
||||
"//conditions:default": [],
|
||||
|
|
|
|||
|
|
@ -142,7 +142,7 @@ struct TF_ImportGraphDefOptions {
|
|||
|
||||
// Backing memory for TensorId fields in opts.
|
||||
// TODO(skyewm): it'd be better if ImportGraphDefOptions owned this.
|
||||
std::list<tensorflow::string> tensor_id_data;
|
||||
std::vector<tensorflow::string> tensor_id_data;
|
||||
};
|
||||
|
||||
struct TF_ImportGraphDefResults {
|
||||
|
|
@ -152,7 +152,7 @@ struct TF_ImportGraphDefResults {
|
|||
std::vector<int> missing_unused_key_indexes;
|
||||
|
||||
// Backing memory for missing_unused_key_names values.
|
||||
std::list<tensorflow::string> missing_unused_key_names_data;
|
||||
std::vector<tensorflow::string> missing_unused_key_names_data;
|
||||
};
|
||||
|
||||
struct TF_DeviceList {
|
||||
|
|
|
|||
|
|
@ -26,7 +26,12 @@ limitations under the License.
|
|||
#define TF_CAPI_EXPORT __declspec(dllimport)
|
||||
#endif // TF_COMPILE_LIBRARY
|
||||
#else
|
||||
#ifdef TF_CAPI_WEAK
|
||||
#define TF_CAPI_EXPORT \
|
||||
__attribute__((visibility("default"))) __attribute((weak))
|
||||
#else
|
||||
#define TF_CAPI_EXPORT __attribute__((visibility("default")))
|
||||
#endif // TF_CAPI_WEAK
|
||||
#endif // _WIN32
|
||||
#endif // SWIG
|
||||
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ load(
|
|||
"tf_cuda_cc_test",
|
||||
"tf_cuda_library",
|
||||
)
|
||||
load("//tensorflow:tensorflow.default.bzl", "cc_header_only_library", "filegroup", "internal_tfrt_deps")
|
||||
load("//tensorflow:tensorflow.default.bzl", "cc_header_only_library", "filegroup")
|
||||
load(
|
||||
"//tensorflow/core/platform:build_config_root.bzl",
|
||||
"tf_cuda_tests_tags",
|
||||
|
|
@ -95,7 +95,7 @@ tf_cuda_library(
|
|||
"//tensorflow/core/distributed_runtime/rpc/eager:grpc_eager_client",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/strings:str_format",
|
||||
] + internal_tfrt_deps(),
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
|
|
@ -636,7 +636,7 @@ tf_cuda_library(
|
|||
|
||||
tf_cuda_cc_test(
|
||||
name = "c_api_test",
|
||||
size = "small",
|
||||
size = "medium",
|
||||
srcs = [
|
||||
"c_api_debug_test.cc",
|
||||
"c_api_test.cc",
|
||||
|
|
@ -653,7 +653,6 @@ tf_cuda_cc_test(
|
|||
":c_api_test_util",
|
||||
":tfe_op_internal",
|
||||
":tfe_tensorhandle_internal",
|
||||
"@com_google_absl//absl/strings",
|
||||
"//tensorflow/c:c_test_util",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
|
|
@ -663,10 +662,7 @@ tf_cuda_cc_test(
|
|||
"//tensorflow/core/common_runtime/eager:eager_operation",
|
||||
"//tensorflow/core/common_runtime/eager:tensor_handle",
|
||||
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
|
||||
# copybara:uncomment_begin
|
||||
# "//tensorflow/core/tfrt/eager:c_api_tfrt",
|
||||
# "@tf_runtime//backends/cpu:tf_ops_alwayslink",
|
||||
# copybara:uncomment_end
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
||||
|
|
@ -693,7 +689,7 @@ tf_cuda_library(
|
|||
|
||||
tf_cuda_cc_test(
|
||||
name = "c_api_remote_test",
|
||||
size = "small",
|
||||
size = "medium",
|
||||
srcs = [
|
||||
"c_api_remote_test.cc",
|
||||
],
|
||||
|
|
@ -725,7 +721,7 @@ tf_cuda_cc_test(
|
|||
|
||||
tf_cuda_cc_test(
|
||||
name = "c_api_remote_function_test",
|
||||
size = "small",
|
||||
size = "medium",
|
||||
srcs = [
|
||||
"c_api_remote_function_test.cc",
|
||||
],
|
||||
|
|
@ -776,7 +772,7 @@ tf_cuda_cc_test(
|
|||
|
||||
tf_cuda_cc_test(
|
||||
name = "c_api_cluster_test",
|
||||
size = "small",
|
||||
size = "medium",
|
||||
srcs = [
|
||||
"c_api_cluster_test.cc",
|
||||
],
|
||||
|
|
@ -1014,7 +1010,7 @@ cc_library(
|
|||
|
||||
tf_cc_test(
|
||||
name = "custom_device_test",
|
||||
size = "small",
|
||||
size = "medium",
|
||||
srcs = [
|
||||
"custom_device_test.cc",
|
||||
],
|
||||
|
|
|
|||
|
|
@ -64,13 +64,6 @@ limitations under the License.
|
|||
#include "tensorflow/core/protobuf/error_codes.pb.h"
|
||||
#include "tensorflow/core/public/version.h"
|
||||
|
||||
// "tensorflow/core/platform/platform.h" must be included first before using
|
||||
// PLATFORM_GOOGLE, IS_MOBILE_PLATFORM, etc.
|
||||
#if defined(PLATFORM_GOOGLE) && !defined(LIBTPU_ON_GCE) && \
|
||||
!defined(PLATFORM_FUCHSIA)
|
||||
#include "tensorflow/core/tfrt/eager/c_api_tfrt.h"
|
||||
#endif // PLATFORM_GOOGLE && !LIBTPU_ON_GCE && !PLATFORM_FUCHSIA
|
||||
|
||||
#if !defined(IS_MOBILE_PLATFORM)
|
||||
#include "tensorflow/core/common_runtime/eager/context_distributed_manager.h"
|
||||
#endif // !IS_MOBILE_PLATFORM
|
||||
|
|
@ -117,18 +110,8 @@ void TFE_DeleteContextOptions(TFE_ContextOptions* options) { delete options; }
|
|||
|
||||
TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
|
||||
if (opts->use_tfrt) {
|
||||
#if defined(PLATFORM_GOOGLE) && !defined(LIBTPU_ON_GCE) && \
|
||||
!defined(PLATFORM_FUCHSIA)
|
||||
tfrt::tf::ContextInterface* tfrt_context = new tfrt::tf::ContextInterface(
|
||||
opts->session_options.options,
|
||||
static_cast<tensorflow::ContextDevicePlacementPolicy>(
|
||||
opts->device_placement_policy),
|
||||
opts->async);
|
||||
return tensorflow::wrap(tfrt_context);
|
||||
#else
|
||||
status->status = tensorflow::errors::Unimplemented("TFRT is not supported");
|
||||
return nullptr;
|
||||
#endif // PLATFORM_GOOGLE && !LIBTPU_ON_GCE && !PLATFORM_FUCHSIA
|
||||
}
|
||||
std::vector<std::unique_ptr<tensorflow::Device>> devices;
|
||||
status->status = tensorflow::DeviceFactory::AddDevices(
|
||||
|
|
|
|||
|
|
@ -434,7 +434,7 @@ class FunctionErrorInjectionPass : public tensorflow::FunctionOptimizationPass {
|
|||
tensorflow::Status Run(const std::string& function_name,
|
||||
const tensorflow::DeviceSet& device_set,
|
||||
const tensorflow::ConfigProto& config_proto,
|
||||
absl::string_view xla_compile_device_type,
|
||||
const FunctionOptions& function_options,
|
||||
std::unique_ptr<tensorflow::Graph>* graph,
|
||||
tensorflow::FunctionLibraryDefinition* flib_def,
|
||||
std::vector<std::string>* control_ret_node_names,
|
||||
|
|
|
|||
|
|
@ -47,10 +47,6 @@ limitations under the License.
|
|||
#include "tensorflow/core/protobuf/rewriter_config.pb.h"
|
||||
#include "tensorflow/core/protobuf/tensorflow_server.pb.h"
|
||||
|
||||
#ifdef PLATFORM_GOOGLE
|
||||
#include "tensorflow/core/tfrt/eager/c_api_tfrt.h"
|
||||
#endif
|
||||
|
||||
using tensorflow::string;
|
||||
|
||||
namespace {
|
||||
|
|
@ -1262,16 +1258,6 @@ TEST(CAPI, RunAddFunctionWithGrappler) {
|
|||
RunAddFunction(/*use_tfrt=*/false, /*enable_grappler=*/true);
|
||||
}
|
||||
|
||||
#ifdef PLATFORM_GOOGLE
|
||||
TEST(CAPI, RunAddFunction_TFRT) {
|
||||
RunAddFunction(/*use_tfrt=*/true, /*enable_grappler=*/false);
|
||||
}
|
||||
|
||||
TEST(CAPI, RunAddFunctionWithGrappler_TFRT) {
|
||||
RunAddFunction(/*use_tfrt=*/true, /*enable_grappler=*/true);
|
||||
}
|
||||
#endif
|
||||
|
||||
void BM_ExecuteFunction(::testing::benchmark::State& state) {
|
||||
const int async = state.range(0);
|
||||
state.SetLabel(async ? "ExecuteFunctionAsync" : "ExecuteFunction");
|
||||
|
|
@ -1802,23 +1788,9 @@ void TestOpAddAttrs(bool use_tfrt) {
|
|||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
tensorflow::AttrValueMap attr_values;
|
||||
if (use_tfrt) {
|
||||
#ifdef PLATFORM_GOOGLE
|
||||
auto* op = tensorflow::down_cast<tfrt::tf::OperationInterface*>(
|
||||
tensorflow::unwrap(copy_op));
|
||||
auto* tfrt_op_attrs =
|
||||
tensorflow::down_cast<const tfrt::tf::OpAttrsInterface*>(
|
||||
op->GetOpAttrs());
|
||||
tensorflow::DataType result;
|
||||
tfrt_op_attrs->GetType("dtype", &result);
|
||||
EXPECT_EQ(tensorflow::DT_FLOAT, result);
|
||||
tfrt_op_attrs->GetFallbackAttrs()->FillAttrValueMap(&attr_values);
|
||||
#endif
|
||||
} else {
|
||||
tensorflow::EagerOperation* op =
|
||||
tensorflow::OperationFromInterface(tensorflow::unwrap(copy_op));
|
||||
op->Attrs().FillAttrValueMap(&attr_values);
|
||||
}
|
||||
tensorflow::EagerOperation* op =
|
||||
tensorflow::OperationFromInterface(tensorflow::unwrap(copy_op));
|
||||
op->Attrs().FillAttrValueMap(&attr_values);
|
||||
EXPECT_EQ(tensorflow::DT_FLOAT, attr_values.find("dtype")->second.type());
|
||||
|
||||
TF_DeleteStatus(status);
|
||||
|
|
@ -1829,11 +1801,6 @@ void TestOpAddAttrs(bool use_tfrt) {
|
|||
|
||||
TEST(CAPI, TestTFE_OpAddAttrs) { TestOpAddAttrs(/*use_tfrt=*/false); }
|
||||
|
||||
#ifdef PLATFORM_GOOGLE
|
||||
TEST(CAPI, TestTFE_OpAddAttrs_TFRT) { TestOpAddAttrs(/*use_tfrt=*/true); }
|
||||
|
||||
#endif
|
||||
|
||||
TEST(CAPI, TestTFE_OpAttrsSerialize) {
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
|
|
|
|||
|
|
@ -1006,18 +1006,10 @@ TEST_P(UnifiedCAPI, TF_ExecutionContextGetTFEContextFromFunctionContextRaises) {
|
|||
|
||||
// The above tests are run for a combination of:
|
||||
// - graphdef and MLIR tracing engine
|
||||
// - Using TFRT as an execution runtime (true == enable TFRT)
|
||||
#ifdef PLATFORM_GOOGLE
|
||||
INSTANTIATE_TEST_SUITE_P(Tracing, UnifiedCAPI,
|
||||
::testing::Combine(::testing::Values("graphdef",
|
||||
"mlir"),
|
||||
::testing::Values(true, false)));
|
||||
#else
|
||||
INSTANTIATE_TEST_SUITE_P(Tracing, UnifiedCAPI,
|
||||
::testing::Combine(::testing::Values("graphdef",
|
||||
"mlir"),
|
||||
::testing::Values(false)));
|
||||
#endif
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
|
|
|||
|
|
@ -188,18 +188,10 @@ TEST_P(UnifiedAPI, TestPartialShapeTracing) {
|
|||
ASSERT_EQ(-1, shape.dim_size(1));
|
||||
}
|
||||
|
||||
#ifdef PLATFORM_GOOGLE
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
UnifiedCppAPI, UnifiedAPI,
|
||||
::testing::Combine(::testing::Values("graphdef", "mlir"),
|
||||
/*tfrt*/ ::testing::Values(true, false),
|
||||
/*use_function*/ ::testing::Values(true, false)));
|
||||
#else
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
UnifiedCppAPI, UnifiedAPI,
|
||||
::testing::Combine(::testing::Values("graphdef", "mlir"),
|
||||
/*tfrt*/ ::testing::Values(false),
|
||||
/*use_function*/ ::testing::Values(true, false)));
|
||||
#endif
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
|
|
|||
|
|
@ -125,19 +125,12 @@ TEST_P(CustomGradientTest, ExpWithPassThroughGrad) {
|
|||
result_tensor = nullptr;
|
||||
}
|
||||
|
||||
#ifdef PLATFORM_GOOGLE
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
CustomGradientTest, CustomGradientTest,
|
||||
::testing::Combine(::testing::Values("graphdef", "mlir"),
|
||||
/*tfrt*/ ::testing::Values(true, false),
|
||||
/*executing_eagerly*/ ::testing::Values(true, false)));
|
||||
#else
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
CustomGradientTest, CustomGradientTest,
|
||||
::testing::Combine(::testing::Values("graphdef", "mlir"),
|
||||
/*tfrt*/ ::testing::Values(false),
|
||||
/*executing_eagerly*/ ::testing::Values(true, false)));
|
||||
#endif
|
||||
|
||||
} // namespace
|
||||
} // namespace internal
|
||||
} // namespace gradients
|
||||
|
|
|
|||
|
|
@ -24,6 +24,7 @@ cc_library(
|
|||
"//tensorflow/compiler/xla/pjrt:pjrt_c_api_client",
|
||||
"//tensorflow/compiler/xla/pjrt:pjrt_client",
|
||||
"//tensorflow/compiler/xla/pjrt/c:pjrt_c_api_hdrs",
|
||||
"//tensorflow/compiler/xla/stream_executor/tpu:tpu_initializer_helper",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core/common_runtime/next_pluggable_device:plugin_resource",
|
||||
"//tensorflow/core/platform:status",
|
||||
|
|
@ -32,6 +33,8 @@ cc_library(
|
|||
"//tensorflow/tsl/distributed_runtime/coordination:coordination_service_agent",
|
||||
"//tensorflow/tsl/platform:errors",
|
||||
"//tensorflow/tsl/platform:statusor",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -21,6 +21,9 @@ limitations under the License.
|
|||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/strings/ascii.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "tensorflow/c/kernels_experimental.h"
|
||||
#include "tensorflow/c/tf_status_helper.h"
|
||||
#include "tensorflow/c/tf_status_internal.h"
|
||||
|
|
@ -30,6 +33,7 @@ limitations under the License.
|
|||
#include "tensorflow/compiler/jit/variable_info_util.h"
|
||||
#include "tensorflow/compiler/xla/pjrt/pjrt_c_api_client.h"
|
||||
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
|
||||
#include "tensorflow/compiler/xla/stream_executor/tpu/tpu_initializer_helper.h" // NOLINT(unused-includes): required for tensorflow::tpu::FindAndLoadTpuLibrary
|
||||
#include "tensorflow/core/common_runtime/next_pluggable_device/plugin_resource.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
|
|
@ -110,9 +114,9 @@ TF_VariableInfo* TF_CreateVariableInfoFromContext(TF_OpKernelContext* ctx,
|
|||
const tensorflow::Tensor& arg_tensor = cc_ctx->input(index);
|
||||
tsl::Status cc_status;
|
||||
if (arg_tensor.dtype() != tensorflow::DT_RESOURCE) {
|
||||
cc_status = tsl::errors::InvalidArgument(
|
||||
"Trying to obtain resource handle from Input[", index,
|
||||
"], which is not type DT_RESOURCE.");
|
||||
cc_status = absl::InvalidArgumentError(
|
||||
absl::StrCat("Trying to obtain resource handle from Input[", index,
|
||||
"], which is not type DT_RESOURCE."));
|
||||
tsl::Set_TF_Status_from_Status(status, cc_status);
|
||||
return nullptr;
|
||||
}
|
||||
|
|
@ -140,12 +144,12 @@ void TF_AllocateTempForVariableInfo(TF_OpKernelContext* ctx,
|
|||
auto* cc_ctx = reinterpret_cast<tensorflow::OpKernelContext*>(ctx);
|
||||
tsl::Status cc_status;
|
||||
if (var_info == nullptr) {
|
||||
cc_status = tsl::errors::InvalidArgument("TF_VariableInfo is NULL.");
|
||||
cc_status = absl::InvalidArgumentError("TF_VariableInfo is NULL.");
|
||||
tsl::Set_TF_Status_from_Status(status, cc_status);
|
||||
return;
|
||||
}
|
||||
if (var_info->var_info.var() == nullptr) {
|
||||
cc_status = tsl::errors::InvalidArgument(
|
||||
cc_status = absl::InvalidArgumentError(
|
||||
"VariableInfo does not track a resource variable.");
|
||||
tsl::Set_TF_Status_from_Status(status, cc_status);
|
||||
return;
|
||||
|
|
@ -161,12 +165,12 @@ TF_Tensor* TF_GetTensorFromVariableInfo(TF_VariableInfo* var_info,
|
|||
TF_Status* status) {
|
||||
tsl::Status cc_status;
|
||||
if (var_info == nullptr) {
|
||||
cc_status = tsl::errors::InvalidArgument("TF_VariableInfo is NULL.");
|
||||
cc_status = absl::InvalidArgumentError("TF_VariableInfo is NULL.");
|
||||
tsl::Set_TF_Status_from_Status(status, cc_status);
|
||||
return nullptr;
|
||||
}
|
||||
if (var_info->var_info.var() == nullptr) {
|
||||
cc_status = tsl::errors::InvalidArgument(
|
||||
cc_status = absl::InvalidArgumentError(
|
||||
"VariableInfo does not track a resource variable.");
|
||||
tsl::Set_TF_Status_from_Status(status, cc_status);
|
||||
return nullptr;
|
||||
|
|
@ -239,6 +243,18 @@ void TF_CoordinationServiceDeleteKeyValue(const char* key,
|
|||
void TF_CreateAndSetPjRtCApiClient(const char* device_type, TF_Status* status,
|
||||
PJRT_NamedValue* create_options,
|
||||
int num_options) {
|
||||
// TODO(b/262050449): use a common plugin discovery mechanism, rather than
|
||||
// having TPU-specific code here.
|
||||
#if !defined(PLATFORM_GOOGLE) || defined(LIBTPU_STATIC)
|
||||
if (absl::AsciiStrToLower(device_type) == "tpu") {
|
||||
// TODO(b/261484192): handle device specific initialization.
|
||||
tsl::Status tpu_status = tensorflow::tpu::FindAndLoadTpuLibrary();
|
||||
if (!tpu_status.ok()) {
|
||||
tensorflow::Set_TF_Status_from_Status(status, tpu_status);
|
||||
return;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
tsl::StatusOr<std::unique_ptr<xla::PjRtClient>> pjrt_client =
|
||||
xla::GetCApiClient(device_type, pjrt::ConvertFromPjRtNamedValueList(
|
||||
create_options, num_options));
|
||||
|
|
@ -263,8 +279,9 @@ PJRT_Client* TF_GetPjRtCClient(const char* device_type, TF_Status* status) {
|
|||
tensorflow::down_cast<xla::PjRtCApiClient*>(*pjrt_client);
|
||||
if (pjrt_c_api_client == nullptr) {
|
||||
tensorflow::Set_TF_Status_from_Status(
|
||||
status, tsl::errors::Internal("PjRtClient for ", device_type,
|
||||
" is not type PjRtCApiClient"));
|
||||
status,
|
||||
absl::InternalError(absl::StrCat("PjRtClient for ", device_type,
|
||||
" is not type PjRtCApiClient")));
|
||||
return nullptr;
|
||||
}
|
||||
TF_SetStatus(status, TF_OK, "");
|
||||
|
|
@ -282,8 +299,7 @@ PJRT_Buffer* TF_GetPjRtCBuffer(TF_Tensor* c_tensor, TF_Status* status) {
|
|||
tensorflow::AsyncValueTensor::FromTensor(&tensor);
|
||||
if (av_tensor == nullptr || av_tensor->GetBuffer() == nullptr) {
|
||||
tensorflow::Set_TF_Status_from_Status(
|
||||
status,
|
||||
tsl::errors::Internal("Input tensor does not have PjRtBuffer."));
|
||||
status, absl::InternalError("Input tensor does not have PjRtBuffer."));
|
||||
return nullptr;
|
||||
}
|
||||
auto* c_api_buffer =
|
||||
|
|
@ -291,7 +307,7 @@ PJRT_Buffer* TF_GetPjRtCBuffer(TF_Tensor* c_tensor, TF_Status* status) {
|
|||
if (c_api_buffer == nullptr) {
|
||||
tensorflow::Set_TF_Status_from_Status(
|
||||
status,
|
||||
tsl::errors::Internal(
|
||||
absl::InternalError(
|
||||
"The PjRtBuffer in the tensor is not type PjRtCApiBuffer."));
|
||||
return nullptr;
|
||||
}
|
||||
|
|
@ -317,8 +333,9 @@ void TF_CreatePjRtBuffer(TF_Tensor* c_tensor, PJRT_Buffer* c_buffer,
|
|||
tensorflow::down_cast<xla::PjRtCApiClient*>(*pjrt_client);
|
||||
if (pjrt_c_api_client == nullptr) {
|
||||
tensorflow::Set_TF_Status_from_Status(
|
||||
status, tsl::errors::Internal("PjRtClient for ", device_type,
|
||||
" is not type PjRtCApiClient"));
|
||||
status,
|
||||
absl::InternalError(absl::StrCat("PjRtClient for ", device_type,
|
||||
" is not type PjRtCApiClient")));
|
||||
return;
|
||||
}
|
||||
tensorflow::AsyncValueTensor* av_tensor =
|
||||
|
|
@ -326,7 +343,7 @@ void TF_CreatePjRtBuffer(TF_Tensor* c_tensor, PJRT_Buffer* c_buffer,
|
|||
if (av_tensor == nullptr) {
|
||||
tensorflow::Set_TF_Status_from_Status(
|
||||
status,
|
||||
tsl::errors::Internal(
|
||||
absl::InternalError(
|
||||
"The tensor to set PjRtBuffer is not an AsyncValueTensor."));
|
||||
return;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -28,6 +28,8 @@ cc_library(
|
|||
"//tensorflow/cc/saved_model:constants",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
||||
|
|
@ -99,6 +101,8 @@ cc_library(
|
|||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/lib/llvm_rtti",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/types:span",
|
||||
],
|
||||
)
|
||||
|
|
|
|||
|
|
@ -17,6 +17,8 @@ limitations under the License.
|
|||
|
||||
#include <string>
|
||||
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "tensorflow/c/eager/immediate_execution_context.h"
|
||||
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
|
||||
#include "tensorflow/c/tensor_interface.h"
|
||||
|
|
@ -37,8 +39,8 @@ Status Asset::Create(ImmediateExecutionContext* ctx,
|
|||
io::JoinPath(saved_model_dir, kSavedModelAssetsDirectory, asset_filename);
|
||||
AbstractTensorPtr tensor(ctx->CreateStringScalar(abs_path));
|
||||
if (tensor.get() == nullptr) {
|
||||
return errors::Internal(
|
||||
"Failed to create scalar string tensor for Asset at path ", abs_path);
|
||||
return absl::InternalError(absl::StrCat(
|
||||
"Failed to create scalar string tensor for Asset at path ", abs_path));
|
||||
}
|
||||
|
||||
ImmediateTensorHandlePtr handle(ctx->CreateLocalHandle(tensor.get()));
|
||||
|
|
|
|||
|
|
@ -20,6 +20,8 @@ limitations under the License.
|
|||
#include <string>
|
||||
#include <utility>
|
||||
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "absl/types/span.h"
|
||||
#include "tensorflow/c/eager/abstract_tensor_handle.h"
|
||||
#include "tensorflow/c/eager/immediate_execution_context.h"
|
||||
|
|
@ -60,15 +62,15 @@ Status AssertAllCreateResourceFunctionsHaveNoCaptures(
|
|||
const TFConcreteFunctionRevivalState* create_resource_fn =
|
||||
resource.create_resource;
|
||||
if (create_resource_fn == nullptr) {
|
||||
return errors::FailedPrecondition(
|
||||
"Resource at node ", node_id,
|
||||
" did not have a create_resource() function");
|
||||
return absl::FailedPreconditionError(
|
||||
absl::StrCat("Resource at node ", node_id,
|
||||
" did not have a create_resource() function"));
|
||||
}
|
||||
const SavedConcreteFunction* saved_create_resource_fn =
|
||||
create_resource_fn->saved_concrete_func;
|
||||
if (!saved_create_resource_fn->bound_inputs().empty()) {
|
||||
// TODO(b/124045874): Support loading resource functions via a top sort
|
||||
return errors::Unimplemented(
|
||||
return absl::UnimplementedError(
|
||||
"Create Resource functions with captures are currently unsupported.");
|
||||
}
|
||||
}
|
||||
|
|
@ -86,9 +88,9 @@ Status TensorHandleFromNode(int node_id, const SavedObjectGraph& obj_graph,
|
|||
case SavedObject::kVariable: {
|
||||
const auto& variables_iter = objects.variables.find(node_id);
|
||||
if (variables_iter == objects.variables.end()) {
|
||||
return errors::FailedPrecondition(
|
||||
return absl::FailedPreconditionError(absl::StrCat(
|
||||
"Tried to convert node id ", node_id,
|
||||
" of type variable to tensor but the variable wasn't initialized");
|
||||
" of type variable to tensor but the variable wasn't initialized"));
|
||||
}
|
||||
*handle = variables_iter->second->handle();
|
||||
return Status();
|
||||
|
|
@ -96,9 +98,10 @@ Status TensorHandleFromNode(int node_id, const SavedObjectGraph& obj_graph,
|
|||
case SavedObject::kConstant: {
|
||||
const auto& constants_iter = objects.constants.find(node_id);
|
||||
if (constants_iter == objects.constants.end()) {
|
||||
return errors::FailedPrecondition("Tried to convert node id ", node_id,
|
||||
" of type constant to tensor but the "
|
||||
"constant wasn't initialized");
|
||||
return absl::FailedPreconditionError(
|
||||
absl::StrCat("Tried to convert node id ", node_id,
|
||||
" of type constant to tensor but the "
|
||||
"constant wasn't initialized"));
|
||||
}
|
||||
*handle = constants_iter->second->handle();
|
||||
return Status();
|
||||
|
|
@ -106,9 +109,9 @@ Status TensorHandleFromNode(int node_id, const SavedObjectGraph& obj_graph,
|
|||
case SavedObject::kAsset: {
|
||||
const auto& assets_iter = objects.assets.find(node_id);
|
||||
if (assets_iter == objects.assets.end()) {
|
||||
return errors::FailedPrecondition(
|
||||
return absl::FailedPreconditionError(absl::StrCat(
|
||||
"Tried to convert node id ", node_id,
|
||||
" of type asset to tensor but the asset wasn't initialized");
|
||||
" of type asset to tensor but the asset wasn't initialized"));
|
||||
}
|
||||
*handle = assets_iter->second->handle();
|
||||
return Status();
|
||||
|
|
@ -116,24 +119,24 @@ Status TensorHandleFromNode(int node_id, const SavedObjectGraph& obj_graph,
|
|||
case SavedObject::kResource: {
|
||||
const auto& resource_iter = objects.restored_resources.find(node_id);
|
||||
if (resource_iter == objects.restored_resources.end()) {
|
||||
return errors::FailedPrecondition(
|
||||
return absl::FailedPreconditionError(absl::StrCat(
|
||||
"Tried to convert node id ", node_id,
|
||||
" of type Resource to tensor but the Resource wasn't initialized");
|
||||
" of type Resource to tensor but the Resource wasn't initialized"));
|
||||
}
|
||||
const RestoredResourceRevivalState& resource = resource_iter->second;
|
||||
if (resource.resource_handle == nullptr) {
|
||||
return errors::FailedPrecondition(
|
||||
return absl::FailedPreconditionError(absl::StrCat(
|
||||
"Resource with node id ", node_id,
|
||||
" should have its resource_handle created, but was nullptr.");
|
||||
" should have its resource_handle created, but was nullptr."));
|
||||
}
|
||||
*handle = resource.resource_handle.get();
|
||||
return Status();
|
||||
}
|
||||
default: {
|
||||
return errors::FailedPrecondition(
|
||||
return absl::FailedPreconditionError(absl::StrCat(
|
||||
"Only objects of type variable, constant, asset, and resources have "
|
||||
"capturable tensorhandles. Encountered object of kind ",
|
||||
node.kind_case(), " at node id: ", node_id);
|
||||
node.kind_case(), " at node id: ", node_id));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -167,35 +170,35 @@ Status SignatureDefArgsFromInputs(
|
|||
// (args, kwargs), where args is an empty tuple, and kwargs is a dictionary of
|
||||
// string keys to TensorSpecs.
|
||||
if (!canonicalized_input_signature.has_tuple_value()) {
|
||||
return errors::FailedPrecondition(
|
||||
return absl::FailedPreconditionError(absl::StrCat(
|
||||
"SignatureDefFunction's canonicalized_input_signature should be "
|
||||
"of form tuple(tuple(), dict()), but was instead: \n",
|
||||
canonicalized_input_signature.DebugString());
|
||||
canonicalized_input_signature.DebugString()));
|
||||
}
|
||||
|
||||
const TupleValue& args_kwargs_tuple =
|
||||
canonicalized_input_signature.tuple_value();
|
||||
if (args_kwargs_tuple.values_size() != 2) {
|
||||
return errors::FailedPrecondition(
|
||||
return absl::FailedPreconditionError(absl::StrCat(
|
||||
"SignatureDefFunction's canonicalized_input_signature should be "
|
||||
"a tuple of two elements (args, kwargs), but was instead: \n",
|
||||
args_kwargs_tuple.DebugString());
|
||||
args_kwargs_tuple.DebugString()));
|
||||
}
|
||||
|
||||
const StructuredValue& args = args_kwargs_tuple.values(0);
|
||||
if (!args.has_tuple_value() || !args.tuple_value().values().empty()) {
|
||||
return errors::FailedPrecondition(
|
||||
return absl::FailedPreconditionError(absl::StrCat(
|
||||
"SignatureDefFunction's canonicalized_input_signature's args"
|
||||
"should be an empty tuple, but instead got: \n",
|
||||
args.DebugString());
|
||||
args.DebugString()));
|
||||
}
|
||||
|
||||
const StructuredValue& kwargs = args_kwargs_tuple.values(1);
|
||||
if (!kwargs.has_dict_value()) {
|
||||
return errors::FailedPrecondition(
|
||||
return absl::FailedPreconditionError(absl::StrCat(
|
||||
"SignatureDefFunction's canonicalized_input_signature's kwargs"
|
||||
"should be a dictionary, but instead got: \n",
|
||||
kwargs.DebugString());
|
||||
kwargs.DebugString()));
|
||||
}
|
||||
|
||||
const DictValue& kwargs_dict = kwargs.dict_value();
|
||||
|
|
@ -206,10 +209,10 @@ Status SignatureDefArgsFromInputs(
|
|||
const std::string& key = key_value.first;
|
||||
const StructuredValue& value = key_value.second;
|
||||
if (!value.has_tensor_spec_value()) {
|
||||
return errors::FailedPrecondition(
|
||||
return absl::FailedPreconditionError(absl::StrCat(
|
||||
"SignatureDefFunction's canonicalized_input_signature's kwargs"
|
||||
"dictionary contained a non-tensorspec value for key-value pair: \n",
|
||||
"Key: ", key, "Value: \n", value.DebugString());
|
||||
"Key: ", key, "Value: \n", value.DebugString()));
|
||||
}
|
||||
result[key] = &value.tensor_spec_value();
|
||||
}
|
||||
|
|
@ -226,10 +229,10 @@ Status SignatureDefArgsFromInputs(
|
|||
Status SignatureDefReturnsFromOutputs(const StructuredValue& output_signature,
|
||||
std::vector<SignatureDefParam>* out) {
|
||||
if (!output_signature.has_dict_value()) {
|
||||
return errors::FailedPrecondition(
|
||||
return absl::FailedPreconditionError(absl::StrCat(
|
||||
"SignatureDefFunction's output_signature must be a dictionary, but "
|
||||
"instead got: ",
|
||||
output_signature.DebugString());
|
||||
output_signature.DebugString()));
|
||||
}
|
||||
|
||||
const DictValue& output_dict = output_signature.dict_value();
|
||||
|
|
@ -240,10 +243,10 @@ Status SignatureDefReturnsFromOutputs(const StructuredValue& output_signature,
|
|||
const std::string& key = key_value.first;
|
||||
const StructuredValue& value = key_value.second;
|
||||
if (!value.has_tensor_spec_value()) {
|
||||
return errors::FailedPrecondition(
|
||||
return absl::FailedPreconditionError(absl::StrCat(
|
||||
"SignatureDefFunction's output_signature dictionary contained a "
|
||||
"non-tensorspec value for key-value pair: \n",
|
||||
"Key: ", key, "Value: \n", value.DebugString());
|
||||
"Key: ", key, "Value: \n", value.DebugString()));
|
||||
}
|
||||
result[key] = &value.tensor_spec_value();
|
||||
}
|
||||
|
|
@ -337,7 +340,7 @@ Status InitializeCreateResourceFunctions(ImmediateExecutionContext* ctx,
|
|||
create_resource_fn->saved_concrete_func;
|
||||
if (!saved_create_resource_fn->bound_inputs().empty()) {
|
||||
// TODO(b/124045874): Load resource functions via a topological sort
|
||||
return errors::Unimplemented(
|
||||
return absl::UnimplementedError(
|
||||
"Create Resource functions with captures are currently unsupported.");
|
||||
}
|
||||
std::unique_ptr<TFConcreteFunction> out;
|
||||
|
|
@ -401,9 +404,9 @@ Status CreateAllResourceHandles(ImmediateExecutionContext* ctx,
|
|||
const TFConcreteFunction* create_resource_fn =
|
||||
revived->concrete_functions.Find(create_resource_fn_node);
|
||||
if (create_resource_fn == nullptr) {
|
||||
return errors::FailedPrecondition(
|
||||
"ConcreteFunction at node ", create_resource_fn_node,
|
||||
" should have been initialized prior to being called.");
|
||||
return absl::FailedPreconditionError(
|
||||
absl::StrCat("ConcreteFunction at node ", create_resource_fn_node,
|
||||
" should have been initialized prior to being called."));
|
||||
}
|
||||
ImmediateOpPtr function_op;
|
||||
TF_RETURN_IF_ERROR(create_resource_fn->MakeCallOp({}, &function_op));
|
||||
|
|
@ -416,7 +419,7 @@ Status CreateAllResourceHandles(ImmediateExecutionContext* ctx,
|
|||
AbstractTensorHandlePtr owned_resource_handle(resource_handle);
|
||||
if (!tensorflow::isa<ImmediateExecutionTensorHandle>(
|
||||
owned_resource_handle.get())) {
|
||||
return errors::Internal("Unexpected tensor handle kind.");
|
||||
return absl::InternalError("Unexpected tensor handle kind.");
|
||||
}
|
||||
ImmediateTensorHandlePtr result(
|
||||
reinterpret_cast<ImmediateExecutionTensorHandle*>(
|
||||
|
|
@ -443,9 +446,9 @@ Status BuildResources(ImmediateExecutionContext* ctx,
|
|||
create_resource = revived->concrete_functions.Find(
|
||||
resource_revival_state.create_resource->node_id);
|
||||
if (create_resource == nullptr) {
|
||||
return errors::FailedPrecondition(
|
||||
return absl::FailedPreconditionError(absl::StrCat(
|
||||
"'create_resource' function with node id ",
|
||||
resource_revival_state.create_resource->node_id, " not found");
|
||||
resource_revival_state.create_resource->node_id, " not found"));
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -454,9 +457,9 @@ Status BuildResources(ImmediateExecutionContext* ctx,
|
|||
initialize = revived->concrete_functions.Find(
|
||||
resource_revival_state.initialize->node_id);
|
||||
if (initialize == nullptr) {
|
||||
return errors::FailedPrecondition(
|
||||
return absl::FailedPreconditionError(absl::StrCat(
|
||||
"'initialize' function with node id ",
|
||||
resource_revival_state.initialize->node_id, " not found");
|
||||
resource_revival_state.initialize->node_id, " not found"));
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -465,15 +468,16 @@ Status BuildResources(ImmediateExecutionContext* ctx,
|
|||
destroy_resource = revived->concrete_functions.Find(
|
||||
resource_revival_state.destroy_resource->node_id);
|
||||
if (destroy_resource == nullptr) {
|
||||
return errors::FailedPrecondition(
|
||||
return absl::FailedPreconditionError(absl::StrCat(
|
||||
"'destroy_resource' function with node id ",
|
||||
resource_revival_state.destroy_resource->node_id, " not found");
|
||||
resource_revival_state.destroy_resource->node_id, " not found"));
|
||||
}
|
||||
}
|
||||
|
||||
if (resource_revival_state.resource_handle == nullptr) {
|
||||
return errors::FailedPrecondition("Resource at node id ", node_id,
|
||||
" does not have a resource handle.");
|
||||
return absl::FailedPreconditionError(
|
||||
absl::StrCat("Resource at node id ", node_id,
|
||||
" does not have a resource handle."));
|
||||
}
|
||||
|
||||
revived->restored_resources.emplace(
|
||||
|
|
|
|||
|
|
@ -344,7 +344,7 @@ cc_library(
|
|||
|
||||
tf_cc_test(
|
||||
name = "saved_model_api_test",
|
||||
size = "small",
|
||||
size = "medium",
|
||||
srcs = [
|
||||
"saved_model_api_test.cc",
|
||||
],
|
||||
|
|
|
|||
|
|
@ -13,15 +13,15 @@ py_strict_binary(
|
|||
srcs = ["gen_saved_models.py"],
|
||||
python_version = "PY3",
|
||||
deps = [
|
||||
"//tensorflow/python:constant_op",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:resource_variable_ops",
|
||||
"//tensorflow/python:tensor_spec",
|
||||
"//tensorflow/python:variables",
|
||||
"//tensorflow/python:while_loop",
|
||||
"//tensorflow/python/compat:v2_compat",
|
||||
"//tensorflow/python/eager:def_function",
|
||||
"//tensorflow/python/framework:constant_op",
|
||||
"//tensorflow/python/framework:dtypes",
|
||||
"//tensorflow/python/framework:tensor_spec",
|
||||
"//tensorflow/python/module",
|
||||
"//tensorflow/python/ops:resource_variable_ops",
|
||||
"//tensorflow/python/ops:variables",
|
||||
"//tensorflow/python/ops:while_loop",
|
||||
"//tensorflow/python/saved_model",
|
||||
"@absl_py//absl:app",
|
||||
],
|
||||
|
|
|
|||
|
|
@ -189,6 +189,8 @@ cc_library_with_android_deps(
|
|||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core:ops",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
# TODO(aselle): describe this package.
|
||||
#include "third_party/absl/strings/str_cat.h"
|
||||
#TODO(aselle) : describe this package.
|
||||
|
||||
load(
|
||||
"//tensorflow/core/platform:rules_cc.bzl",
|
||||
|
|
@ -42,6 +43,8 @@ cc_library(
|
|||
"//tensorflow/core/platform:status",
|
||||
"//tensorflow/core/platform:statusor",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
||||
|
|
@ -84,13 +87,13 @@ py_strict_binary(
|
|||
srcs = ["tests/generate_testdata.py"],
|
||||
python_version = "PY3",
|
||||
deps = [
|
||||
"//tensorflow/python:constant_op",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:tensor_spec",
|
||||
"//tensorflow/python:variables",
|
||||
"//tensorflow/python/compat:v2_compat",
|
||||
"//tensorflow/python/eager:def_function",
|
||||
"//tensorflow/python/framework:constant_op",
|
||||
"//tensorflow/python/framework:dtypes",
|
||||
"//tensorflow/python/framework:tensor_spec",
|
||||
"//tensorflow/python/module",
|
||||
"//tensorflow/python/ops:variables",
|
||||
"//tensorflow/python/saved_model",
|
||||
"@absl_py//absl:app",
|
||||
"@absl_py//absl/flags",
|
||||
|
|
@ -180,13 +183,11 @@ tf_cc_test(
|
|||
size = "medium",
|
||||
srcs = [
|
||||
"tests/runtime_test_core.cc",
|
||||
"tests/runtime_test_tfrt.cc",
|
||||
],
|
||||
deps = [
|
||||
":runtime_test",
|
||||
"//tensorflow/cc/experimental/libtf/runtime",
|
||||
"//tensorflow/cc/experimental/libtf/runtime/core",
|
||||
"//tensorflow/cc/experimental/libtf/runtime/tfrt",
|
||||
],
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -29,6 +29,8 @@ limitations under the License.
|
|||
#include <string>
|
||||
#include <utility>
|
||||
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
|
||||
#include "tensorflow/cc/experimental/libtf/value.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
|
|
@ -172,7 +174,7 @@ class Object : public Handle {
|
|||
}
|
||||
}
|
||||
}
|
||||
return tensorflow::errors::NotFound("Key not in dictionary.");
|
||||
return absl::NotFoundError("Key not in dictionary.");
|
||||
}
|
||||
|
||||
/// Sets `key` attribute with the underlying value of `h`.
|
||||
|
|
@ -202,7 +204,7 @@ class Dictionary final : public Handle {
|
|||
tensorflow::StatusOr<T> Get(const Handle& key) {
|
||||
auto it = value_.dict().find(key.value_);
|
||||
if (it != value_.dict().end()) return Cast<T>(Handle(it->second));
|
||||
return tensorflow::errors::NotFound("Key not in dictionary.");
|
||||
return absl::NotFoundError("Key not in dictionary.");
|
||||
}
|
||||
/// Sets `key` with value `value`.
|
||||
void Set(const String& key, Handle value) {
|
||||
|
|
@ -282,7 +284,7 @@ tensorflow::Status Tensor::GetValue(absl::Span<T> data) const {
|
|||
{
|
||||
const auto abstract_t = value_.tensor().get();
|
||||
if (!tensorflow::ImmediateExecutionTensorHandle::classof(abstract_t)) {
|
||||
return tensorflow::errors::InvalidArgument(
|
||||
return absl::InvalidArgumentError(
|
||||
"Attempting to get value of non eager tensor.");
|
||||
}
|
||||
auto imm_t =
|
||||
|
|
@ -315,7 +317,7 @@ class Tuple : public Handle {
|
|||
template <class T>
|
||||
tensorflow::StatusOr<T> Get(size_t i) {
|
||||
if (i >= value_.tuple().size())
|
||||
return tensorflow::errors::InvalidArgument("Out of bounds index.");
|
||||
return absl::InvalidArgumentError("Out of bounds index.");
|
||||
return Cast<T>(Handle(value_.tuple()[i]));
|
||||
}
|
||||
|
||||
|
|
@ -348,7 +350,7 @@ class List final : public Handle {
|
|||
template <class T>
|
||||
tensorflow::StatusOr<T> Get(size_t i) {
|
||||
if (i >= size()) {
|
||||
return tensorflow::errors::InvalidArgument("Out of bounds index.");
|
||||
return absl::InvalidArgumentError("Out of bounds index.");
|
||||
}
|
||||
return Cast<T>(Handle(value_.list()[i]));
|
||||
}
|
||||
|
|
@ -356,7 +358,7 @@ class List final : public Handle {
|
|||
/// Sets value `h` at index `i`.
|
||||
tensorflow::Status Set(size_t i, Handle h) {
|
||||
if (i >= size()) {
|
||||
return tensorflow::errors::InvalidArgument("Out of bounds index.");
|
||||
return absl::InvalidArgumentError("Out of bounds index.");
|
||||
}
|
||||
value_.list()[i] = std::move(h.value_);
|
||||
return ::tensorflow::OkStatus();
|
||||
|
|
@ -533,7 +535,7 @@ tensorflow::StatusOr<T> Cast(Handle handle) {
|
|||
if (handle.value_.type() == TypeToTaggedType<T>() ||
|
||||
std::is_same<T, Handle>::value)
|
||||
return T((std::move(handle.value_)));
|
||||
return tensorflow::errors::InvalidArgument("Incompatible cast.");
|
||||
return absl::InvalidArgumentError("Incompatible cast.");
|
||||
}
|
||||
|
||||
// Converters for C++ primitives like float and int to handles. Allows callable
|
||||
|
|
@ -656,10 +658,10 @@ class UneraseCallHelper<Fn, TReturn, TSignatureArg, TSignatureRest...> {
|
|||
Handle h(std::move(args_in.tuple()[argument_index]));
|
||||
tensorflow::StatusOr<TSignatureArg> x = Cast<TSignatureArg>(std::move(h));
|
||||
if (!x.ok())
|
||||
return tensorflow::errors::InvalidArgument(
|
||||
std::string("Function ") + name + " Arg " +
|
||||
std::to_string(argument_index) +
|
||||
" cannot be cast to desired signature type ");
|
||||
return absl::InvalidArgumentError(
|
||||
absl::StrCat(std::string("Function ") + name + " Arg " +
|
||||
std::to_string(argument_index) +
|
||||
" cannot be cast to desired signature type "));
|
||||
return UneraseCallHelper<Fn, TReturn, TSignatureRest...>::template Call(
|
||||
name, fn, argument_index + 1, args_in, args..., *x);
|
||||
}
|
||||
|
|
@ -683,9 +685,9 @@ class CallableWrapper {
|
|||
TaggedValue kwargs) {
|
||||
constexpr size_t argument_count = sizeof...(TFuncArgs);
|
||||
if (argument_count != args.tuple().size())
|
||||
return tensorflow::errors::InvalidArgument(
|
||||
std::string("Function ") + name_ + " expected " +
|
||||
std::to_string(argument_count) + " args.");
|
||||
return absl::InvalidArgumentError(
|
||||
absl::StrCat(std::string("Function ") + name_ + " expected " +
|
||||
std::to_string(argument_count) + " args."));
|
||||
return UneraseCallHelper<Fn, TReturn, TFuncArgs...>::Call(name_, functor_,
|
||||
0, args);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,26 +0,0 @@
|
|||
package(
|
||||
# copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
|
||||
default_visibility = [
|
||||
"//tensorflow/cc/experimental/libtf:__subpackages__",
|
||||
],
|
||||
licenses = ["notice"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tfrt",
|
||||
srcs = [
|
||||
"tfrt.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"tfrt.h",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/c:tf_status_helper",
|
||||
"//tensorflow/c:tf_status_internal",
|
||||
"//tensorflow/c/eager:c_api",
|
||||
"//tensorflow/c/eager:c_api_experimental",
|
||||
"//tensorflow/c/eager:tfe_context_internal",
|
||||
"//tensorflow/cc/experimental/libtf",
|
||||
"//tensorflow/cc/experimental/libtf/runtime",
|
||||
],
|
||||
)
|
||||
|
|
@ -1,45 +0,0 @@
|
|||
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/cc/experimental/libtf/runtime/tfrt/tfrt.h"
|
||||
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||
#include "tensorflow/c/eager/tfe_context_internal.h"
|
||||
#include "tensorflow/c/tf_status_helper.h"
|
||||
#include "tensorflow/c/tf_status_internal.h"
|
||||
#include "tensorflow/cc/experimental/libtf/value.h"
|
||||
|
||||
namespace tf {
|
||||
namespace libtf {
|
||||
namespace runtime {
|
||||
namespace tfrt {
|
||||
|
||||
runtime::Runtime Runtime() {
|
||||
TFE_Context* ctx;
|
||||
TFE_ContextOptions* ctx_options = TFE_NewContextOptions();
|
||||
TFE_ContextOptionsSetTfrt(ctx_options, true);
|
||||
TFE_ContextOptionsSetDevicePlacementPolicy(ctx_options,
|
||||
TFE_DEVICE_PLACEMENT_WARN);
|
||||
TF_Status* status = TF_NewStatus();
|
||||
ctx = TFE_NewContext(ctx_options, status);
|
||||
TF_DeleteStatus(status);
|
||||
TFE_DeleteContextOptions(ctx_options);
|
||||
return runtime::Runtime(tensorflow::unwrap(ctx));
|
||||
}
|
||||
|
||||
} // namespace tfrt
|
||||
} // namespace runtime
|
||||
} // namespace libtf
|
||||
} // namespace tf
|
||||
|
|
@ -288,7 +288,7 @@ TEST_P(FunctionTest, IncorrectDtypeInOutputSignatureFails) {
|
|||
INSTANTIATE_TEST_SUITE_P(TF2CAPI, FunctionTest,
|
||||
::testing::Combine(::testing::Values("graphdef",
|
||||
"mlir"),
|
||||
::testing::Values(false, true)));
|
||||
::testing::Values(false)));
|
||||
|
||||
} // namespace libtf
|
||||
} // namespace tf
|
||||
|
|
|
|||
|
|
@ -123,7 +123,7 @@ TEST_P(UnifiedCAPI, SimpleCreationFunctions) {
|
|||
INSTANTIATE_TEST_SUITE_P(Tracing, UnifiedCAPI,
|
||||
::testing::Combine(::testing::Values("graphdef",
|
||||
"mlir"),
|
||||
::testing::Values(true, false)));
|
||||
::testing::Values(false)));
|
||||
|
||||
} // namespace libtf
|
||||
} // namespace tf
|
||||
|
|
|
|||
|
|
@ -114,7 +114,7 @@ TEST_P(VariableTest, CreateAssignReadDestroy) {
|
|||
INSTANTIATE_TEST_SUITE_P(TF2CAPI, VariableTest,
|
||||
::testing::Combine(::testing::Values("graphdef",
|
||||
"mlir"),
|
||||
::testing::Values(false, true)));
|
||||
::testing::Values(false)));
|
||||
|
||||
} // namespace libtf
|
||||
} // namespace tf
|
||||
|
|
|
|||
|
|
@ -21,6 +21,8 @@ limitations under the License.
|
|||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/tensor.pb.h"
|
||||
#include "tensorflow/core/graph/graph.h"
|
||||
|
|
@ -156,9 +158,9 @@ class Input {
|
|||
typedef typename RealType<T>::type RealT;
|
||||
Tensor t(DataTypeToEnum<RealT>::v(), shape);
|
||||
if (t.NumElements() != static_cast<int64_t>(v.size())) {
|
||||
status = errors::InvalidArgument(
|
||||
status = absl::InvalidArgumentError(absl::StrCat(
|
||||
"Cannot construct a tensor with ", t.NumElements(),
|
||||
" from an initializer list with ", v.size(), " elements");
|
||||
" from an initializer list with ", v.size(), " elements"));
|
||||
return;
|
||||
}
|
||||
std::copy_n(v.begin(), v.size(), t.flat<RealT>().data());
|
||||
|
|
|
|||
|
|
@ -466,7 +466,7 @@ TEST_F(CWiseUnaryGradTest, Asin_Complex) {
|
|||
};
|
||||
// TODO(kbsriram)
|
||||
// Enable test when the asin kernel supports complex numbers
|
||||
if (false) {
|
||||
if (/* DISABLES CODE */ (false)) {
|
||||
TestCWiseGrad<complex64, complex64>(ASIN, x_fn);
|
||||
}
|
||||
}
|
||||
|
|
@ -482,7 +482,7 @@ TEST_F(CWiseUnaryGradTest, Acos_Complex) {
|
|||
};
|
||||
// TODO(kbsriram)
|
||||
// Add test when the acos kernel supports complex numbers
|
||||
if (false) {
|
||||
if (/* DISABLES CODE */ (false)) {
|
||||
TestCWiseGrad<complex64, complex64>(ACOS, x_fn);
|
||||
}
|
||||
}
|
||||
|
|
@ -510,7 +510,7 @@ TEST_F(CWiseUnaryGradTest, Atan_Complex) {
|
|||
};
|
||||
// TODO(kbsriram)
|
||||
// Add test when the atan kernel supports complex numbers
|
||||
if (false) {
|
||||
if (/* DISABLES CODE */ (false)) {
|
||||
TestCWiseGrad<complex64, complex64>(ATAN, x_fn);
|
||||
}
|
||||
}
|
||||
|
|
@ -561,7 +561,7 @@ TEST_F(CWiseUnaryGradTest, Lgamma_Complex) {
|
|||
};
|
||||
// TODO(kbsriram)
|
||||
// Add test when the lgamma kernel supports complex numbers
|
||||
if (false) {
|
||||
if (/* DISABLES CODE */ (false)) {
|
||||
TestCWiseGrad<complex64, complex64>(LGAMMA, x_fn);
|
||||
}
|
||||
}
|
||||
|
|
@ -579,7 +579,7 @@ TEST_F(CWiseUnaryGradTest, Erf_Complex) {
|
|||
};
|
||||
// TODO(kbsriram)
|
||||
// Add test when the erf kernel supports complex numbers
|
||||
if (false) {
|
||||
if (/* DISABLES CODE */ (false)) {
|
||||
TestCWiseGrad<complex64, complex64>(ERF, x_fn);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
# Description:
|
||||
#include "third_party/absl/strings/str_cat.h"
|
||||
#Description:
|
||||
# TensorFlow SavedModel.
|
||||
|
||||
load("//tensorflow:tensorflow.default.bzl", "filegroup")
|
||||
|
|
@ -6,6 +7,7 @@ load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")
|
|||
load(
|
||||
"//tensorflow:tensorflow.bzl",
|
||||
"if_android",
|
||||
"if_google",
|
||||
"if_mobile",
|
||||
"if_not_mobile",
|
||||
"tf_cc_test",
|
||||
|
|
@ -28,6 +30,8 @@ package(
|
|||
|
||||
exports_files([
|
||||
"loader.h",
|
||||
"testdata/chunked_saved_model/chunked_model/saved_model.cpb",
|
||||
"testdata/chunked_saved_model/chunked_model/saved_model.pbtxt",
|
||||
])
|
||||
|
||||
cc_library(
|
||||
|
|
@ -63,7 +67,9 @@ cc_library(
|
|||
":metrics",
|
||||
":util",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
] + if_not_mobile([
|
||||
] + if_google([
|
||||
"//tensorflow/tools/proto_splitter:merge",
|
||||
]) + if_not_mobile([
|
||||
# TODO(b/111634734): :lib and :protos_all contain dependencies that
|
||||
# cannot be built on mobile platforms. Instead, include the appropriate
|
||||
# tf_lib depending on the build platform.
|
||||
|
|
@ -87,9 +93,8 @@ tf_cc_test(
|
|||
":tag_constants",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core:testlib",
|
||||
"//tensorflow/core/platform:resource_loader",
|
||||
"@com_google_googletest//:gtest_main",
|
||||
],
|
||||
)
|
||||
|
||||
|
|
@ -131,6 +136,8 @@ cc_library(
|
|||
":fingerprinting",
|
||||
":loader_util",
|
||||
":reader",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/strings",
|
||||
] + if_not_mobile([
|
||||
":metrics",
|
||||
":util",
|
||||
|
|
@ -252,15 +259,15 @@ py_binary(
|
|||
python_version = "PY3",
|
||||
srcs_version = "PY3",
|
||||
deps = [
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:lookup_ops",
|
||||
"//tensorflow/python:tensor_spec",
|
||||
"//tensorflow/python:variables",
|
||||
"//tensorflow/python/compat:v2_compat",
|
||||
"//tensorflow/python/eager:def_function",
|
||||
"//tensorflow/python/framework:dtypes",
|
||||
"//tensorflow/python/framework:ops",
|
||||
"//tensorflow/python/framework:tensor_spec",
|
||||
"//tensorflow/python/module",
|
||||
"//tensorflow/python/ops:lookup_ops",
|
||||
"//tensorflow/python/ops:variables",
|
||||
"//tensorflow/python/platform:client_testlib",
|
||||
"//tensorflow/python/saved_model",
|
||||
"//tensorflow/python/saved_model:save_options",
|
||||
"//tensorflow/python/trackable:asset",
|
||||
|
|
@ -268,6 +275,31 @@ py_binary(
|
|||
],
|
||||
)
|
||||
|
||||
# copybara:uncomment_begin(google-only)
|
||||
#
|
||||
# py_binary(
|
||||
# name = "testdata/generate_chunked_models",
|
||||
# srcs = ["testdata/generate_chunked_models.py"],
|
||||
# python_version = "PY3",
|
||||
# srcs_version = "PY3",
|
||||
# deps = [
|
||||
# "//tensorflow/python/compat:v2_compat",
|
||||
# "//tensorflow/python/eager:def_function",
|
||||
# "//tensorflow/python/framework:constant_op",
|
||||
# "//tensorflow/python/module",
|
||||
# "//tensorflow/python/platform:client_testlib",
|
||||
# "//tensorflow/python/saved_model:loader",
|
||||
# "//tensorflow/python/saved_model:save",
|
||||
# "//tensorflow/python/saved_model:save_options",
|
||||
# "//tensorflow/python/util:compat",
|
||||
# "//tensorflow/tools/proto_splitter:constants",
|
||||
# "//tensorflow/tools/proto_splitter/python:saved_model",
|
||||
# "@absl_py//absl:app",
|
||||
# ],
|
||||
# )
|
||||
#
|
||||
# copybara:uncomment_end
|
||||
|
||||
# TODO(b/32673259): add a test to continuously validate these files.
|
||||
filegroup(
|
||||
name = "saved_model_test_files",
|
||||
|
|
@ -284,6 +316,7 @@ filegroup(
|
|||
"testdata/fuzz_generated/**",
|
||||
"testdata/SimpleV1Model/**",
|
||||
"testdata/OptimizerSlotVariableModule/**",
|
||||
"testdata/chunked_saved_model/**",
|
||||
]),
|
||||
)
|
||||
|
||||
|
|
@ -369,7 +402,7 @@ tf_cc_test(
|
|||
":metrics",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"@com_google_googletest//:gtest_main",
|
||||
"@com_google_googletest//:gtest",
|
||||
"@jsoncpp_git//:jsoncpp",
|
||||
],
|
||||
)
|
||||
|
|
|
|||
|
|
@ -39,60 +39,6 @@ using strings::StrCat;
|
|||
// `tensorflow::SavedModelV2Bundle::Load` API label.
|
||||
constexpr char kCCLoadBundleV2Label[] = "cc_load_bundle_v2";
|
||||
|
||||
Status ReadSavedModelProto(const string& export_dir,
|
||||
SavedModel* saved_model_proto) {
|
||||
LOG(INFO) << "Reading SavedModel from: " << export_dir;
|
||||
|
||||
const string saved_model_pb_path =
|
||||
io::JoinPath(export_dir, kSavedModelFilenamePb);
|
||||
Status found_pb = Env::Default()->FileExists(saved_model_pb_path);
|
||||
if (found_pb.ok()) {
|
||||
Status result =
|
||||
ReadBinaryProto(Env::Default(), saved_model_pb_path, saved_model_proto);
|
||||
if (result.ok()) {
|
||||
metrics::SavedModelReadCount(
|
||||
saved_model::GetWriteVersion(*saved_model_proto))
|
||||
.IncrementBy(1);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
const string saved_model_pbtxt_path =
|
||||
io::JoinPath(export_dir, kSavedModelFilenamePbTxt);
|
||||
Status found_pbtxt = Env::Default()->FileExists(saved_model_pbtxt_path);
|
||||
if (found_pbtxt.ok()) {
|
||||
Status result = ReadTextProto(Env::Default(), saved_model_pbtxt_path,
|
||||
saved_model_proto);
|
||||
if (result.ok()) {
|
||||
metrics::SavedModelReadCount(
|
||||
saved_model::GetWriteVersion(*saved_model_proto))
|
||||
.IncrementBy(1);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
Status err;
|
||||
if (found_pb.code() == found_pbtxt.code()) {
|
||||
err = Status(found_pb.code(),
|
||||
StrCat(found_pb.message(), "\n", found_pbtxt.message()));
|
||||
} else if (found_pb.code() == NOT_FOUND) {
|
||||
err = found_pbtxt;
|
||||
} else if (found_pbtxt.code() == NOT_FOUND) {
|
||||
err = found_pb;
|
||||
} else {
|
||||
// found_pb and found_pbtxt both errored, w/ different codes, neither being
|
||||
// NOT_FOUND.
|
||||
err = Status(
|
||||
absl::StatusCode::kInternal,
|
||||
StrCat("Different errors encountered while looking for saved_model.pb "
|
||||
"and saved_model.pbtxt in the export directory path \"",
|
||||
export_dir, "\": \n", found_pb.ToString(), "\n",
|
||||
found_pbtxt.ToString()));
|
||||
}
|
||||
|
||||
return err;
|
||||
}
|
||||
|
||||
Status ReadCheckpointObjectGraph(BundleReader* bundle_reader,
|
||||
TrackableObjectGraph* object_graph) {
|
||||
Tensor object_graph_tensor;
|
||||
|
|
@ -123,7 +69,7 @@ Status SavedModelV2Bundle::Load(const std::string& export_dir,
|
|||
SavedModelV2Bundle* const bundle) {
|
||||
metrics::SavedModelReadApi(kCCLoadBundleV2Label).IncrementBy(1);
|
||||
SavedModel saved_model_proto;
|
||||
TF_RETURN_IF_ERROR(ReadSavedModelProto(export_dir, &saved_model_proto));
|
||||
TF_RETURN_IF_ERROR(ReadSavedModel(export_dir, &saved_model_proto));
|
||||
metrics::SavedModelReadPath().Set(export_dir);
|
||||
|
||||
// Load MetaGraphDef.
|
||||
|
|
|
|||
|
|
@ -19,56 +19,63 @@ limitations under the License.
|
|||
namespace tensorflow {
|
||||
|
||||
// SavedModel assets directory.
|
||||
constexpr char kSavedModelAssetsDirectory[] = "assets";
|
||||
inline constexpr char kSavedModelAssetsDirectory[] = "assets";
|
||||
|
||||
// SavedModel assets.extra directory.
|
||||
constexpr char kSavedModelAssetsExtraDirectory[] = "assets.extra";
|
||||
inline constexpr char kSavedModelAssetsExtraDirectory[] = "assets.extra";
|
||||
|
||||
// SavedModel assets key for graph collection-def.
|
||||
constexpr char kSavedModelAssetsKey[] = "saved_model_assets";
|
||||
inline constexpr char kSavedModelAssetsKey[] = "saved_model_assets";
|
||||
|
||||
/// SavedModel legacy init op collection key. Used in v1 SavedModels.
|
||||
constexpr char kSavedModelLegacyInitOpKey[] = "legacy_init_op";
|
||||
inline constexpr char kSavedModelLegacyInitOpKey[] = "legacy_init_op";
|
||||
|
||||
/// SavedModel main op collection key. Used in v1 SavedModels.
|
||||
constexpr char kSavedModelMainOpKey[] = "saved_model_main_op";
|
||||
inline constexpr char kSavedModelMainOpKey[] = "saved_model_main_op";
|
||||
|
||||
// CollectionDef key for the SavedModel train op.
|
||||
// Not exported while export_all_saved_models is experimental.
|
||||
constexpr char kSavedModelTrainOpKey[] = "saved_model_train_op";
|
||||
inline constexpr char kSavedModelTrainOpKey[] = "saved_model_train_op";
|
||||
|
||||
// Schema version for SavedModel.
|
||||
constexpr int kSavedModelSchemaVersion = 1;
|
||||
inline constexpr int kSavedModelSchemaVersion = 1;
|
||||
|
||||
// SavedModel proto filename prefix.
|
||||
inline constexpr char kSavedModelFilenamePrefix[] = "saved_model";
|
||||
// SavedModel proto filename.
|
||||
constexpr char kSavedModelFilenamePb[] = "saved_model.pb";
|
||||
inline constexpr char kSavedModelFilenamePb[] = "saved_model.pb";
|
||||
|
||||
// SavedModel chunked proto filename.
|
||||
inline constexpr char kSavedModelFilenameCpb[] = "saved_model.cpb";
|
||||
|
||||
// SavedModel text format proto filename.
|
||||
constexpr char kSavedModelFilenamePbTxt[] = "saved_model.pbtxt";
|
||||
inline constexpr char kSavedModelFilenamePbTxt[] = "saved_model.pbtxt";
|
||||
|
||||
// Subdirectory where debugging related files are written.
|
||||
constexpr char kSavedModelDebugDirectory[] = "debug";
|
||||
inline constexpr char kSavedModelDebugDirectory[] = "debug";
|
||||
|
||||
// File name for GraphDebugInfo protocol buffer which corresponds to the
|
||||
// SavedModel.
|
||||
constexpr char kSavedModelDebugInfoFilenamePb[] = "saved_model_debug_info.pb";
|
||||
inline constexpr char kSavedModelDebugInfoFilenamePb[] =
|
||||
"saved_model_debug_info.pb";
|
||||
|
||||
// Directory in which to save the SavedModel variables.
|
||||
constexpr char kSavedModelVariablesDirectory[] = "variables";
|
||||
inline constexpr char kSavedModelVariablesDirectory[] = "variables";
|
||||
|
||||
// SavedModel variables filename.
|
||||
constexpr char kSavedModelVariablesFilename[] = "variables";
|
||||
inline constexpr char kSavedModelVariablesFilename[] = "variables";
|
||||
|
||||
// SavedModel SignatureDef keys for the initialization and train ops. Used in
|
||||
// V2 SavedModels.
|
||||
constexpr char kSavedModelInitOpSignatureKey[] = "__saved_model_init_op";
|
||||
constexpr char kSavedModelTrainOpSignatureKey[] = "__saved_model_train_op";
|
||||
inline constexpr char kSavedModelInitOpSignatureKey[] = "__saved_model_init_op";
|
||||
inline constexpr char kSavedModelTrainOpSignatureKey[] =
|
||||
"__saved_model_train_op";
|
||||
|
||||
// Key in the TensorBundle for the object graph proto.
|
||||
constexpr char kObjectGraphProtoKey[] = "_CHECKPOINTABLE_OBJECT_GRAPH";
|
||||
inline constexpr char kObjectGraphProtoKey[] = "_CHECKPOINTABLE_OBJECT_GRAPH";
|
||||
|
||||
// Filename for the FingerprintDef protocol buffer.
|
||||
constexpr char kFingerprintFilenamePb[] = "fingerprint.pb";
|
||||
inline constexpr char kFingerprintFilenamePb[] = "fingerprint.pb";
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
|
|
|
|||
|
|
@ -120,23 +120,29 @@ uint64 HashCheckpointIndexFile(absl::string_view model_dir) {
|
|||
|
||||
StatusOr<FingerprintDef> CreateFingerprintDef(const SavedModel& saved_model,
|
||||
absl::string_view export_dir) {
|
||||
SavedModel copy = saved_model;
|
||||
return CreateFingerprintDef(©, export_dir);
|
||||
}
|
||||
|
||||
StatusOr<FingerprintDef> CreateFingerprintDef(SavedModel* saved_model,
|
||||
absl::string_view export_dir) {
|
||||
// Create a copy of `metagraph` which will be used and mutated for fingerprint
|
||||
// computation.
|
||||
MetaGraphDef metagraph_copy = saved_model.meta_graphs(0);
|
||||
FingerprintDef fingerprint_def;
|
||||
MetaGraphDef* metagraph = saved_model->mutable_meta_graphs(0);
|
||||
// Set fingerprint field #1.
|
||||
fingerprint_def.set_saved_model_checksum(HashSavedModel(saved_model));
|
||||
fingerprint_def.set_saved_model_checksum(HashSavedModel(*saved_model));
|
||||
// Set fingerprint field #2.
|
||||
graph_regularization::SimpleDelete(*metagraph_copy.mutable_graph_def());
|
||||
graph_regularization::SimpleDelete(*metagraph->mutable_graph_def());
|
||||
fingerprint_def.set_graph_def_program_hash(
|
||||
graph_regularization::ComputeHash(metagraph_copy.graph_def()));
|
||||
graph_regularization::ComputeHash(metagraph->graph_def()));
|
||||
// Set fingerprint field #3.
|
||||
fingerprint_def.set_signature_def_hash(
|
||||
RegularizeAndHashSignatureDefs(metagraph_copy.signature_def()));
|
||||
RegularizeAndHashSignatureDefs(metagraph->signature_def()));
|
||||
// Set fingerprint field #4.
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
StatusOr<uint64> object_graph_hash,
|
||||
RegularizeAndHashSavedObjectGraph(metagraph_copy.object_graph_def()));
|
||||
RegularizeAndHashSavedObjectGraph(metagraph->object_graph_def()));
|
||||
fingerprint_def.set_saved_object_graph_hash(object_graph_hash.value());
|
||||
// Set fingerprint field #5.
|
||||
fingerprint_def.set_checkpoint_hash(HashCheckpointIndexFile(export_dir));
|
||||
|
|
|
|||
|
|
@ -31,6 +31,12 @@ namespace tensorflow::saved_model::fingerprinting {
|
|||
StatusOr<FingerprintDef> CreateFingerprintDef(const SavedModel& saved_model,
|
||||
absl::string_view export_dir);
|
||||
|
||||
// Creates a FingerprintDef proto from a SavedModel and the checkpoint meta file
|
||||
// (.index) in `export_dir`. The passed in `saved_model` is mutated and should
|
||||
// not be used afterwards.
|
||||
StatusOr<FingerprintDef> CreateFingerprintDef(SavedModel* saved_model,
|
||||
absl::string_view export_dir);
|
||||
|
||||
// Loads the `fingerprint.pb` from `export_dir`, returns an error if there is
|
||||
// none.
|
||||
StatusOr<FingerprintDef> ReadSavedModelFingerprint(
|
||||
|
|
|
|||
|
|
@ -18,6 +18,8 @@ limitations under the License.
|
|||
#include <string>
|
||||
#include <unordered_set>
|
||||
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "tensorflow/cc/saved_model/constants.h"
|
||||
#include "tensorflow/cc/saved_model/fingerprinting.h"
|
||||
#include "tensorflow/cc/saved_model/loader_util.h"
|
||||
|
|
@ -90,16 +92,16 @@ static Status ValidateNode(const NodeDef& node) {
|
|||
if (node_value.has_tensor()) {
|
||||
const PartialTensorShape node_shape(node_value.tensor().tensor_shape());
|
||||
if (node_shape.num_elements() < 0) {
|
||||
return errors::FailedPrecondition(
|
||||
return absl::FailedPreconditionError(absl::StrCat(
|
||||
"Saved model contains node \"", node.name(), "\" (op \"", node.op(),
|
||||
"\") which initializes from a tensor with ",
|
||||
node_shape.num_elements(), " elements");
|
||||
node_shape.num_elements(), " elements"));
|
||||
}
|
||||
}
|
||||
} else if (node.op() == "Const") {
|
||||
return errors::FailedPrecondition(
|
||||
return absl::FailedPreconditionError(absl::StrCat(
|
||||
"Saved model contains node \"", node.name(),
|
||||
"\" which is a constant tensor but no value has been provided");
|
||||
"\" which is a constant tensor but no value has been provided"));
|
||||
}
|
||||
return OkStatus();
|
||||
}
|
||||
|
|
@ -108,9 +110,9 @@ static Status ValidateFunctionNotRecursive(const FunctionDef& function) {
|
|||
const auto& function_name = function.signature().name();
|
||||
for (const auto& node : function.node_def()) {
|
||||
if (node.op() == function_name) {
|
||||
return errors::FailedPrecondition(
|
||||
return absl::FailedPreconditionError(absl::StrCat(
|
||||
"Function ", function_name,
|
||||
" is self recursive and TensorFlow does not support this scenario.");
|
||||
" is self recursive and TensorFlow does not support this scenario."));
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -340,17 +342,17 @@ class LiteSessionWrapper : public Session {
|
|||
: wrapped_(std::move(wrapped)) {}
|
||||
|
||||
Status Create(const GraphDef& graph) override {
|
||||
return errors::Unimplemented("Session::Create()");
|
||||
return absl::UnimplementedError("Session::Create()");
|
||||
}
|
||||
Status Create(GraphDef&& graph) override {
|
||||
return errors::Unimplemented("Session::Create()");
|
||||
return absl::UnimplementedError("Session::Create()");
|
||||
}
|
||||
|
||||
Status Extend(const GraphDef& graph) override {
|
||||
return errors::Unimplemented("Session::Extend()");
|
||||
return absl::UnimplementedError("Session::Extend()");
|
||||
}
|
||||
Status Extend(GraphDef&& graph) override {
|
||||
return errors::Unimplemented("Session::Extend()");
|
||||
return absl::UnimplementedError("Session::Extend()");
|
||||
}
|
||||
|
||||
Status Run(const std::vector<std::pair<string, Tensor>>& inputs,
|
||||
|
|
@ -362,16 +364,16 @@ class LiteSessionWrapper : public Session {
|
|||
}
|
||||
|
||||
Status Create(const RunOptions& run_options, const GraphDef& graph) override {
|
||||
return errors::Unimplemented("Session::Create()");
|
||||
return absl::UnimplementedError("Session::Create()");
|
||||
}
|
||||
Status Extend(const RunOptions& run_options, const GraphDef& graph) override {
|
||||
return errors::Unimplemented("Session::Extend()");
|
||||
return absl::UnimplementedError("Session::Extend()");
|
||||
}
|
||||
Status Create(const RunOptions& run_options, GraphDef&& graph) override {
|
||||
return errors::Unimplemented("Session::Create()");
|
||||
return absl::UnimplementedError("Session::Create()");
|
||||
}
|
||||
Status Extend(const RunOptions& run_options, GraphDef&& graph) override {
|
||||
return errors::Unimplemented("Session::Extend()");
|
||||
return absl::UnimplementedError("Session::Extend()");
|
||||
}
|
||||
Status Close(const RunOptions& run_options) override {
|
||||
return wrapped_->Close(run_options);
|
||||
|
|
@ -390,14 +392,14 @@ class LiteSessionWrapper : public Session {
|
|||
const std::vector<string>& output_names,
|
||||
const std::vector<string>& target_nodes,
|
||||
string* handle) override {
|
||||
return errors::Unimplemented("Session::PRunSetup()");
|
||||
return absl::UnimplementedError("Session::PRunSetup()");
|
||||
}
|
||||
|
||||
Status PRun(const string& handle,
|
||||
const std::vector<std::pair<string, Tensor>>& inputs,
|
||||
const std::vector<string>& output_names,
|
||||
std::vector<Tensor>* outputs) override {
|
||||
return errors::Unimplemented("Session::PRun()");
|
||||
return absl::UnimplementedError("Session::PRun()");
|
||||
}
|
||||
|
||||
Status ListDevices(std::vector<DeviceAttributes>* response) override {
|
||||
|
|
|
|||
|
|
@ -15,7 +15,10 @@ limitations under the License.
|
|||
|
||||
#include "tensorflow/cc/saved_model/reader.h"
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <unordered_set>
|
||||
#include <utility>
|
||||
|
||||
#include "absl/memory/memory.h"
|
||||
#include "tensorflow/cc/saved_model/constants.h"
|
||||
|
|
@ -31,60 +34,17 @@ limitations under the License.
|
|||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
#include "tensorflow/core/platform/file_system_helper.h"
|
||||
#include "tensorflow/core/platform/path.h"
|
||||
#include "tensorflow/core/platform/statusor.h"
|
||||
#include "tensorflow/core/protobuf/saved_model.pb.h"
|
||||
#include "tensorflow/core/util/tensor_bundle/byte_swap_tensor.h"
|
||||
// Placeholder for protosplitter merger include.
|
||||
|
||||
#define IS_OSS true
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
// Reads the SavedModel proto from saved_model.pb in `export_dir`.
|
||||
// Returns a failure status when the SavedModel file does not exist.
|
||||
Status ReadSavedModel(absl::string_view export_dir,
|
||||
SavedModel* saved_model_proto) {
|
||||
LOG(INFO) << "Reading SavedModel from: " << export_dir;
|
||||
|
||||
const std::string saved_model_pb_path =
|
||||
io::JoinPath(export_dir, kSavedModelFilenamePb);
|
||||
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
bool saved_model_pb_exists,
|
||||
internal::FileExists(Env::Default(), saved_model_pb_path));
|
||||
if (saved_model_pb_exists) {
|
||||
Status result =
|
||||
ReadBinaryProto(Env::Default(), saved_model_pb_path, saved_model_proto);
|
||||
if (result.ok()) {
|
||||
metrics::SavedModelReadCount(
|
||||
saved_model::GetWriteVersion(*saved_model_proto))
|
||||
.IncrementBy(1);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
const std::string saved_model_pbtxt_path =
|
||||
io::JoinPath(export_dir, kSavedModelFilenamePbTxt);
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
bool saved_model_pbtxt_exists,
|
||||
internal::FileExists(Env::Default(), saved_model_pbtxt_path));
|
||||
if (saved_model_pbtxt_exists) {
|
||||
Status result = ReadTextProto(Env::Default(), saved_model_pbtxt_path,
|
||||
saved_model_proto);
|
||||
if (result.ok()) {
|
||||
metrics::SavedModelReadCount(
|
||||
saved_model::GetWriteVersion(*saved_model_proto))
|
||||
.IncrementBy(1);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
return Status(
|
||||
absl::StatusCode::kNotFound,
|
||||
strings::StrCat("Could not find SavedModel .pb or .pbtxt at supplied "
|
||||
"export directory path: ",
|
||||
export_dir,
|
||||
". Check that "
|
||||
"the directory exists and that you have the right "
|
||||
"permissions for accessing it."));
|
||||
}
|
||||
|
||||
Status FindMetaGraphDef(const std::unordered_set<string>& tags,
|
||||
SavedModel* saved_model_proto,
|
||||
MetaGraphDef* meta_graph_def) {
|
||||
|
|
@ -116,6 +76,61 @@ Status FindMetaGraphDef(const std::unordered_set<string>& tags,
|
|||
}
|
||||
} // namespace
|
||||
|
||||
// Reads the SavedModel proto from saved_model.pb in `export_dir`.
|
||||
// Returns a failure status when the SavedModel file does not exist.
|
||||
Status ReadSavedModel(absl::string_view export_dir,
|
||||
SavedModel* saved_model_proto) {
|
||||
LOG(INFO) << "Reading SavedModel from: " << export_dir;
|
||||
|
||||
if (IS_OSS) {
|
||||
const std::string saved_model_pb_path =
|
||||
io::JoinPath(export_dir, kSavedModelFilenamePb);
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
bool saved_model_pb_exists,
|
||||
internal::FileExists(Env::Default(), saved_model_pb_path));
|
||||
if (saved_model_pb_exists) {
|
||||
Status result = ReadBinaryProto(Env::Default(), saved_model_pb_path,
|
||||
saved_model_proto);
|
||||
if (result.ok()) {
|
||||
metrics::SavedModelReadCount(
|
||||
saved_model::GetWriteVersion(*saved_model_proto))
|
||||
.IncrementBy(1);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
}
|
||||
|
||||
const std::string saved_model_pbtxt_path =
|
||||
io::JoinPath(export_dir, kSavedModelFilenamePbTxt);
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
bool saved_model_pbtxt_exists,
|
||||
internal::FileExists(Env::Default(), saved_model_pbtxt_path));
|
||||
if (saved_model_pbtxt_exists) {
|
||||
Status result = ReadTextProto(Env::Default(), saved_model_pbtxt_path,
|
||||
saved_model_proto);
|
||||
if (result.ok()) {
|
||||
metrics::SavedModelReadCount(
|
||||
saved_model::GetWriteVersion(*saved_model_proto))
|
||||
.IncrementBy(1);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
if (!IS_OSS) {
|
||||
// Only use Merger outside of OSS.
|
||||
// Placeholder for protosplitter merger call.
|
||||
}
|
||||
|
||||
return Status(
|
||||
absl::StatusCode::kNotFound,
|
||||
strings::StrCat("Could not find SavedModel .pb or .pbtxt at supplied "
|
||||
"export directory path: ",
|
||||
export_dir,
|
||||
". Check that "
|
||||
"the directory exists and that you have the right "
|
||||
"permissions for accessing it."));
|
||||
}
|
||||
|
||||
Status ReadMetaGraphDefFromSavedModel(const string& export_dir,
|
||||
const std::unordered_set<string>& tags,
|
||||
MetaGraphDef* const meta_graph_def) {
|
||||
|
|
@ -140,8 +155,7 @@ Status ReadSavedModelDebugInfoIfPresent(
|
|||
GraphDebugInfo debug_info;
|
||||
TF_RETURN_IF_ERROR(
|
||||
ReadBinaryProto(Env::Default(), debug_info_pb_path, &debug_info));
|
||||
*debug_info_proto =
|
||||
absl::make_unique<GraphDebugInfo>(std::move(debug_info));
|
||||
*debug_info_proto = std::make_unique<GraphDebugInfo>(std::move(debug_info));
|
||||
}
|
||||
return OkStatus();
|
||||
}
|
||||
|
|
|
|||
|
|
@ -18,21 +18,26 @@ limitations under the License.
|
|||
#ifndef TENSORFLOW_CC_SAVED_MODEL_READER_H_
|
||||
#define TENSORFLOW_CC_SAVED_MODEL_READER_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <unordered_set>
|
||||
|
||||
#include "tensorflow/core/framework/graph_debug_info.pb.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/protobuf/meta_graph.pb.h"
|
||||
#include "tensorflow/core/protobuf/saved_model.pb.h"
|
||||
|
||||
namespace tensorflow {
|
||||
Status ReadSavedModel(absl::string_view export_dir,
|
||||
SavedModel* saved_model_proto);
|
||||
|
||||
// Reads the SavedModel proto from saved_model.pb(txt) in the given directory,
|
||||
// finds the MetaGraphDef that matches the given set of tags and writes it to
|
||||
// the `meta_graph_def` parameter. Returns a failure status when the SavedModel
|
||||
// file does not exist or no MetaGraphDef matches the tags.
|
||||
Status ReadMetaGraphDefFromSavedModel(const string& export_dir,
|
||||
const std::unordered_set<string>& tags,
|
||||
MetaGraphDef* const meta_graph_def);
|
||||
MetaGraphDef* meta_graph_def);
|
||||
|
||||
// Store debug info from the SavedModel export dir.
|
||||
Status ReadSavedModelDebugInfoIfPresent(
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ limitations under the License.
|
|||
|
||||
#include "tensorflow/cc/saved_model/reader.h"
|
||||
|
||||
#include <gmock/gmock.h>
|
||||
#include "tensorflow/cc/saved_model/constants.h"
|
||||
#include "tensorflow/cc/saved_model/metrics.h"
|
||||
#include "tensorflow/cc/saved_model/tag_constants.h"
|
||||
|
|
@ -24,7 +25,6 @@ limitations under the License.
|
|||
#include "tensorflow/core/lib/strings/str_util.h"
|
||||
#include "tensorflow/core/platform/path.h"
|
||||
#include "tensorflow/core/platform/resource_loader.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
|
@ -39,6 +39,16 @@ string TestDataSharded() {
|
|||
"half_plus_two", "00000123");
|
||||
}
|
||||
|
||||
string ChunkedSavedModel() {
|
||||
return io::JoinPath("tensorflow", "cc", "saved_model", "testdata",
|
||||
"chunked_saved_model", "chunked_model");
|
||||
}
|
||||
|
||||
string NonChunkedSavedModel() {
|
||||
return io::JoinPath("tensorflow", "cc", "saved_model", "testdata",
|
||||
"chunked_saved_model", "non_chunked_model");
|
||||
}
|
||||
|
||||
class ReaderTest : public ::testing::Test {
|
||||
protected:
|
||||
ReaderTest() {}
|
||||
|
|
@ -88,15 +98,6 @@ TEST_F(ReaderTest, NoTagMatchMultiple) {
|
|||
<< st.message();
|
||||
}
|
||||
|
||||
TEST_F(ReaderTest, PbtxtFormat) {
|
||||
MetaGraphDef meta_graph_def;
|
||||
|
||||
const string export_dir = GetDataDependencyFilepath(TestDataPbTxt());
|
||||
TF_ASSERT_OK(ReadMetaGraphDefFromSavedModel(export_dir, {kSavedModelTagServe},
|
||||
&meta_graph_def));
|
||||
CheckMetaGraphDef(meta_graph_def);
|
||||
}
|
||||
|
||||
TEST_F(ReaderTest, InvalidExportPath) {
|
||||
MetaGraphDef meta_graph_def;
|
||||
|
||||
|
|
@ -136,5 +137,7 @@ TEST_F(ReaderTest, MetricsUpdatedSuccessfulRead) {
|
|||
EXPECT_EQ(metrics::SavedModelReadCount("1").value(), read_count_v1 + 1);
|
||||
}
|
||||
|
||||
// Placeholder for protosplitter merger merge test.
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
|
|
|||
BIN
tensorflow/cc/saved_model/testdata/chunked_saved_model/chunked_model/saved_model.cpb
vendored
Normal file
BIN
tensorflow/cc/saved_model/testdata/chunked_saved_model/chunked_model/saved_model.cpb
vendored
Normal file
Binary file not shown.
2063
tensorflow/cc/saved_model/testdata/chunked_saved_model/chunked_model/saved_model.pbtxt
vendored
Normal file
2063
tensorflow/cc/saved_model/testdata/chunked_saved_model/chunked_model/saved_model.pbtxt
vendored
Normal file
File diff suppressed because one or more lines are too long
1
tensorflow/cc/saved_model/testdata/chunked_saved_model/non_chunked_model/fingerprint.pb
vendored
Normal file
1
tensorflow/cc/saved_model/testdata/chunked_saved_model/non_chunked_model/fingerprint.pb
vendored
Normal file
|
|
@ -0,0 +1 @@
|
|||
2(ЕћзмЅ…вА Њ‚¦юЎћхујЋў¶вђЪвЋЯЕ«ѓПѕњоУА°ійов®Щ
|
||||
BIN
tensorflow/cc/saved_model/testdata/chunked_saved_model/non_chunked_model/saved_model.pb
vendored
Normal file
BIN
tensorflow/cc/saved_model/testdata/chunked_saved_model/non_chunked_model/saved_model.pb
vendored
Normal file
Binary file not shown.
Binary file not shown.
BIN
tensorflow/cc/saved_model/testdata/chunked_saved_model/non_chunked_model/variables/variables.index
vendored
Normal file
BIN
tensorflow/cc/saved_model/testdata/chunked_saved_model/non_chunked_model/variables/variables.index
vendored
Normal file
Binary file not shown.
76
tensorflow/cc/saved_model/testdata/generate_chunked_models.py
vendored
Normal file
76
tensorflow/cc/saved_model/testdata/generate_chunked_models.py
vendored
Normal file
|
|
@ -0,0 +1,76 @@
|
|||
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
"""Generates GraphDef test data for Merger.
|
||||
|
||||
Constructs chunked protos test data containing GraphDefs with lots of nodes and
|
||||
large nodes for Merger::Read and Merger::Merge.
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
import os
|
||||
|
||||
from absl import app
|
||||
from absl import flags
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.compat import v2_compat
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.lib.io import file_io
|
||||
from tensorflow.python.module import module
|
||||
from tensorflow.python.saved_model import loader_impl
|
||||
from tensorflow.python.saved_model import save
|
||||
from tensorflow.python.saved_model import save_options
|
||||
from tensorflow.python.util import compat
|
||||
from tensorflow.tools.proto_splitter import constants
|
||||
from tensorflow.tools.proto_splitter.python import saved_model as proto_splitter
|
||||
|
||||
SPLITTER_TESTDATA_PATH = flags.DEFINE_string(
|
||||
"path", None, help="Path to testdata directory.")
|
||||
|
||||
|
||||
def generate_non_chunked_model(non_chunked_dir: str):
|
||||
root = module.Module()
|
||||
root.c = constant_op.constant(np.random.random_sample([150, 150]))
|
||||
constants.debug_set_max_size(80000)
|
||||
root.get_c = def_function.function(lambda: root.c)
|
||||
signatures = root.get_c.get_concrete_function()
|
||||
save.save(root, non_chunked_dir, signatures=signatures,
|
||||
options=save_options.SaveOptions(experimental_image_format=False))
|
||||
|
||||
|
||||
def generate_chunked_model(non_chunked_dir: str, chunked_dir: str):
|
||||
saved_model = loader_impl.parse_saved_model(non_chunked_dir)
|
||||
prefix = file_io.join(compat.as_str(chunked_dir), "saved_model")
|
||||
file_io.write_string_to_file(f"{prefix}.pbtxt", str(saved_model))
|
||||
proto_splitter.SavedModelSplitter(saved_model).write(prefix)
|
||||
|
||||
|
||||
def main(argv: Sequence[str]) -> None:
|
||||
if len(argv) > 1:
|
||||
raise app.UsageError("Too many command-line arguments.")
|
||||
|
||||
main_dir = os.path.join(SPLITTER_TESTDATA_PATH.value, "chunked_saved_model")
|
||||
non_chunked_dir = os.path.join(main_dir, "non_chunked_model")
|
||||
generate_non_chunked_model(non_chunked_dir)
|
||||
chunked_dir = os.path.join(main_dir, "chunked_model")
|
||||
generate_chunked_model(non_chunked_dir, chunked_dir)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
v2_compat.enable_v2_behavior()
|
||||
app.run(main)
|
||||
|
|
@ -1,4 +1,5 @@
|
|||
# Description:
|
||||
#include "third_party/absl/strings/str_cat.h"
|
||||
#Description:
|
||||
# TensorFlow cc tools.
|
||||
|
||||
load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")
|
||||
|
|
@ -22,6 +23,8 @@ cc_library(
|
|||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -18,6 +18,8 @@ limitations under the License.
|
|||
#include <iostream>
|
||||
#include <queue>
|
||||
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "tensorflow/core/framework/attr_value.pb.h"
|
||||
#include "tensorflow/core/framework/function.pb.h"
|
||||
#include "tensorflow/core/framework/graph.pb.h"
|
||||
|
|
@ -193,7 +195,7 @@ StatusOr<string> GetVarHandleName(
|
|||
if (node->op() == "VarHandleOp") {
|
||||
return node->name();
|
||||
}
|
||||
return errors::NotFound("No VarHandleOp ancestor found");
|
||||
return absl::NotFoundError("No VarHandleOp ancestor found");
|
||||
}
|
||||
|
||||
// Looks up the variable handle that provides input to node with node_name,
|
||||
|
|
@ -209,7 +211,7 @@ StatusOr<string> GetHandleNameIfNeedsToFreeze(
|
|||
if (var_handle_name.ok() && variable_node_names.count(*var_handle_name)) {
|
||||
return var_handle_name;
|
||||
}
|
||||
return errors::NotFound("No VarHandleOp ancestor found");
|
||||
return absl::NotFoundError("No VarHandleOp ancestor found");
|
||||
}
|
||||
|
||||
// Freezes the subgraph of all nodes needed by `outputs`.
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ limitations under the License.
|
|||
|
||||
#include "tensorflow/compiler/aot/codegen.h"
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
|
@ -24,6 +25,7 @@ limitations under the License.
|
|||
#include "absl/strings/str_join.h"
|
||||
#include "absl/strings/str_replace.h"
|
||||
#include "absl/strings/str_split.h"
|
||||
#include "absl/strings/substitute.h"
|
||||
#include "absl/types/span.h"
|
||||
#include "tensorflow/compiler/aot/embedded_protocol_buffers.h"
|
||||
#include "tensorflow/compiler/tf2xla/tf2xla.pb.h"
|
||||
|
|
@ -312,6 +314,74 @@ Status GenVariableMethods(const tf2xla::Config& config,
|
|||
return OkStatus();
|
||||
}
|
||||
|
||||
// Generate shape infos for args (inputs).
|
||||
Status GenArgShapeInfos(const xla::ProgramShapeProto& ps, string* infos) {
|
||||
for (int i = 0; i < ps.parameters_size(); ++i) {
|
||||
const xla::ShapeProto& shape = ps.parameters(i);
|
||||
if (shape.element_type() == xla::TUPLE) {
|
||||
// ShapeInfo cannot represent tuple args.
|
||||
return absl::InternalError(
|
||||
absl::StrCat("parameter ", i,
|
||||
": codegen requires XLA parameters to "
|
||||
"be non-tuples."));
|
||||
}
|
||||
// Please some compilers (e.g. MSVC) by avoiding the initialization of an
|
||||
// array of unknown size an empty initializer. Use "-1" for this; note that
|
||||
// this value is never used (the size attribute is set to 0 in ShapeInfo).
|
||||
*infos += absl::Substitute(R"( static constexpr int32_t kArg$0Shapes[] = {
|
||||
$1
|
||||
};
|
||||
)",
|
||||
i,
|
||||
shape.dimensions_size() > 0
|
||||
? absl::StrJoin(shape.dimensions(), ", ")
|
||||
: "-1");
|
||||
}
|
||||
*infos += R"( static const ShapeInfo* ArgShapeInfos() {
|
||||
static constexpr ShapeInfo kArgShapeInfoTable[kNumArgs] = {
|
||||
)";
|
||||
for (int i = 0; i < ps.parameters_size(); ++i) {
|
||||
const xla::ShapeProto& shape = ps.parameters(i);
|
||||
*infos +=
|
||||
absl::Substitute("{ kArg$0Shapes, $1 },\n", i, shape.dimensions_size());
|
||||
}
|
||||
*infos += R"( };
|
||||
return kArgShapeInfoTable;
|
||||
})";
|
||||
return OkStatus();
|
||||
}
|
||||
|
||||
// Generate shape infos for results.
|
||||
Status GenResultShapeInfos(const xla::ProgramShapeProto& ps, string* infos) {
|
||||
if (ps.result().element_type() != xla::TUPLE) {
|
||||
return absl::InternalError("codegen requires the XLA result to be a tuple");
|
||||
}
|
||||
for (int i = 0; i < ps.result().tuple_shapes_size(); ++i) {
|
||||
const xla::ShapeProto& shape = ps.result().tuple_shapes(i);
|
||||
// See above comment about the use here of "-1".
|
||||
*infos += absl::Substitute(
|
||||
R"( static constexpr int32_t kResult$0Shapes[] = {
|
||||
$1
|
||||
};
|
||||
)",
|
||||
i,
|
||||
shape.dimensions_size() > 0 ? absl::StrJoin(shape.dimensions(), ", ")
|
||||
: "-1");
|
||||
}
|
||||
*infos += R"( static const ShapeInfo* ResultShapeInfos() {
|
||||
static constexpr ShapeInfo kResultShapeInfoTable[kNumResults] = {
|
||||
)";
|
||||
for (int i = 0; i < ps.result().tuple_shapes_size(); ++i) {
|
||||
const xla::ShapeProto& shape = ps.result().tuple_shapes(i);
|
||||
*infos += absl::Substitute("{ kResult$0Shapes, $1 },\n", i,
|
||||
shape.dimensions_size());
|
||||
}
|
||||
*infos += R"( };
|
||||
return kResultShapeInfoTable;
|
||||
})";
|
||||
return OkStatus();
|
||||
}
|
||||
|
||||
// Generates code implementing {Arg,Result}Names(), where T is one of
|
||||
// tf2xla::{Feed,Fetch,Variable}. Each feed or fetch name results in a C-style
|
||||
// string literal in the array, with nullptr terminating the array.
|
||||
|
|
@ -377,17 +447,27 @@ std::vector<string> BufferInfosToCppExpression(
|
|||
std::transform(buffer_infos.begin(), buffer_infos.end(),
|
||||
std::back_inserter(buffer_infos_as_strings),
|
||||
[](const BufferInfo& buffer_info) {
|
||||
std::pair<uint64, uint64> encoded = buffer_info.Encode();
|
||||
string encoded_second_as_str =
|
||||
encoded.second == ~0ULL
|
||||
? "~0ULL"
|
||||
: absl::StrCat(encoded.second, "ULL");
|
||||
xla::cpu_function_runtime::EncodedBufferInfo encoded =
|
||||
buffer_info.Encode();
|
||||
auto param_to_str = [](uint32_t param) -> std::string {
|
||||
return param == ~0U ? "~0U" : absl::StrCat(param, "U");
|
||||
};
|
||||
return absl::StrCat(
|
||||
"::xla::cpu_function_runtime::BufferInfo({",
|
||||
encoded.first, "ULL, ", encoded_second_as_str, "})");
|
||||
"::xla::cpu_function_runtime::BufferInfo(",
|
||||
encoded.packed_kind_and_size, "ULL, ",
|
||||
param_to_str(encoded.entry_param_number), ", ",
|
||||
param_to_str(encoded.result_param_number), ")");
|
||||
});
|
||||
return buffer_infos_as_strings;
|
||||
}
|
||||
|
||||
Status CheckEqual(size_t a, size_t b, absl::string_view error_msg) {
|
||||
if (a != b) {
|
||||
return absl::InternalError(
|
||||
absl::StrCat(error_msg, ". Expected ", a, ", got ", b, "."));
|
||||
}
|
||||
return OkStatus();
|
||||
}
|
||||
} // namespace
|
||||
|
||||
Status GenerateHeader(const CodegenOpts& opts, const tf2xla::Config& config,
|
||||
|
|
@ -400,6 +480,8 @@ Status GenerateHeader(const CodegenOpts& opts, const tf2xla::Config& config,
|
|||
compile_result.aot->buffer_infos();
|
||||
const std::vector<int32> arg_index_table =
|
||||
::xla::cpu::CreateArgIndexTableFromBufferInfos(buffer_infos);
|
||||
const std::vector<int32> result_index_table =
|
||||
::xla::cpu::CreateResultIndexTableFromBufferInfos(buffer_infos);
|
||||
std::vector<string> buffer_infos_as_strings =
|
||||
BufferInfosToCppExpression(buffer_infos);
|
||||
const int64_t buffer_infos_size = buffer_infos.size();
|
||||
|
|
@ -419,6 +501,15 @@ Status GenerateHeader(const CodegenOpts& opts, const tf2xla::Config& config,
|
|||
TF_RETURN_IF_ERROR(GenArgMethods(config, ps, compile_result, &methods_arg));
|
||||
TF_RETURN_IF_ERROR(GenResultMethods(config, ps, &methods_result));
|
||||
TF_RETURN_IF_ERROR(GenVariableMethods(config, ps, &methods_variable));
|
||||
string arg_shape_infos, result_shape_infos;
|
||||
TF_RETURN_IF_ERROR(GenArgShapeInfos(ps, &arg_shape_infos));
|
||||
TF_RETURN_IF_ERROR(
|
||||
CheckEqual(ps.parameters_size(), arg_index_table.size(),
|
||||
"Arg number mismatch, proto vs. arg_index_table"));
|
||||
TF_RETURN_IF_ERROR(GenResultShapeInfos(ps, &result_shape_infos));
|
||||
TF_RETURN_IF_ERROR(
|
||||
CheckEqual(ps.result().tuple_shapes_size(), result_index_table.size(),
|
||||
"Result number mismatch, proto vs. result_index_table"));
|
||||
const size_t arg_bytes_aligned =
|
||||
xla::cpu_function_runtime::AlignedBufferBytes(
|
||||
buffer_infos_for_args.data(), buffer_infos_for_args.size(),
|
||||
|
|
@ -544,6 +635,8 @@ class {{CLASS}} final : public tensorflow::XlaCompiledCpuFunction {
|
|||
// Number of input arguments for the compiled computation.
|
||||
static constexpr size_t kNumArgs = {{ARG_NUM}};
|
||||
|
||||
static constexpr size_t kNumResults = {{RESULT_NUM}};
|
||||
|
||||
// Number of variables for the compiled computation.
|
||||
static constexpr size_t kNumVariables = {{VARIABLE_NUM}};
|
||||
|
||||
|
|
@ -560,16 +653,21 @@ class {{CLASS}} final : public tensorflow::XlaCompiledCpuFunction {
|
|||
set_static_data_raw_function(data, {{ENTRY}});
|
||||
set_static_data_buffer_infos(data, BufferInfos());
|
||||
set_static_data_num_buffers(data, kNumBuffers);
|
||||
set_static_data_result_index_table(data, ResultIndexToBufferIndex());
|
||||
set_static_data_num_results(data, kNumResults);
|
||||
set_static_data_arg_index_table(data, ArgIndexToBufferIndex());
|
||||
set_static_data_num_args(data, kNumArgs);
|
||||
set_static_data_num_variables(data, kNumVariables);
|
||||
set_static_data_result_index(data, kResultIndex);
|
||||
set_static_data_arg_shape_infos(data, ArgShapeInfos());
|
||||
set_static_data_result_shape_infos(data, ResultShapeInfos());
|
||||
set_static_data_arg_names(data, StaticArgNames());
|
||||
set_static_data_variable_names(data, StaticVariableNames());
|
||||
set_static_data_result_names(data, StaticResultNames());
|
||||
set_static_data_program_shape(data, StaticProgramShape());
|
||||
set_static_data_hlo_profile_printer_data(
|
||||
data, StaticHloProfilePrinterData());
|
||||
set_static_data_use_xla_runtime(data, {{USE_XLA_RUNTIME}});
|
||||
{{ASSIGN_PROFILE_COUNTERS_SIZE}}
|
||||
return data;
|
||||
}();
|
||||
|
|
@ -589,7 +687,7 @@ class {{CLASS}} final : public tensorflow::XlaCompiledCpuFunction {
|
|||
//
|
||||
// void set_argN_data(void* data)
|
||||
// Sets the buffer of type T for positional argument N. May be called in
|
||||
// any AllocMode. Must be called before Run to have an affect. Must be
|
||||
// any AllocMode. Must be called before Run to have an effect. Must be
|
||||
// called in AllocMode::RESULTS_PROFILES_AND_TEMPS_ONLY for each positional
|
||||
// argument, to set the argument buffers.
|
||||
//
|
||||
|
|
@ -655,6 +753,13 @@ class {{CLASS}} final : public tensorflow::XlaCompiledCpuFunction {
|
|||
return kBufferInfos;
|
||||
}
|
||||
|
||||
static const ::tensorflow::int32* ResultIndexToBufferIndex() {
|
||||
static constexpr ::tensorflow::int32 kResultIndexToBufferIndex[kNumResults] = {
|
||||
{{RESULT_INDEX_TABLE}}
|
||||
};
|
||||
return kResultIndexToBufferIndex;
|
||||
}
|
||||
|
||||
static const ::tensorflow::int32* ArgIndexToBufferIndex() {
|
||||
static constexpr ::tensorflow::int32 kArgIndexToBufferIndex[kNumArgs] = {
|
||||
{{ARG_INDEX_TABLE}}
|
||||
|
|
@ -665,6 +770,12 @@ class {{CLASS}} final : public tensorflow::XlaCompiledCpuFunction {
|
|||
// The 0-based index of the result tuple in the temporary buffers.
|
||||
static constexpr size_t kResultIndex = {{RESULT_INDEX}};
|
||||
|
||||
// Shapes of the input arguments.
|
||||
{{ARG_SHAPE_INFOS}};
|
||||
|
||||
// Shapes of the results.
|
||||
{{RESULT_SHAPE_INFOS}};
|
||||
|
||||
// Array of names of each positional argument, terminated by nullptr.
|
||||
static const char** StaticArgNames() {{ARG_NAMES_CODE}}
|
||||
|
||||
|
|
@ -699,13 +810,18 @@ class {{CLASS}} final : public tensorflow::XlaCompiledCpuFunction {
|
|||
{"{{ARG_BYTES_TOTAL}}", absl::StrCat(arg_bytes_total)},
|
||||
{"{{ARG_NAMES_CODE}}", arg_names_code},
|
||||
{"{{ARG_NUM}}", absl::StrCat(arg_index_table.size())},
|
||||
{"{{ARG_SHAPE_INFOS}}", arg_shape_infos},
|
||||
{"{{VARIABLE_NUM}}", absl::StrCat(config.variable_size())},
|
||||
{"{{ARG_INDEX_TABLE}}", absl::StrJoin(arg_index_table, ", ")},
|
||||
{"{{RESULT_NUM}}", absl::StrCat(result_index_table.size())},
|
||||
{"{{RESULT_INDEX_TABLE}}", absl::StrJoin(result_index_table, ", ")},
|
||||
|
||||
{"{{ASSIGN_PROFILE_COUNTERS_SIZE}}", assign_profile_counters_size},
|
||||
{"{{CLASS}}", opts.class_name},
|
||||
{"{{DECLS_FROM_OBJ_FILE}}",
|
||||
absl::StrJoin(metadata_result.header_variable_decls, "\n")},
|
||||
{"{{ENTRY}}", compile_result.entry_point},
|
||||
{"{{USE_XLA_RUNTIME}}", opts.use_xla_runtime ? "true" : "false"},
|
||||
{"{{HLO_PROFILE_PRINTER_DATA_SHIM_EXPRESSION}}",
|
||||
metadata_result.hlo_profile_printer_data_access_shim},
|
||||
{"{{INCLUDE_XLA_DATA_PROTO}}", include_xla_data_proto},
|
||||
|
|
@ -722,6 +838,7 @@ class {{CLASS}} final : public tensorflow::XlaCompiledCpuFunction {
|
|||
{"{{VARIABLE_NAMES_CODE}}", variable_names_code},
|
||||
{"{{RESULT_INDEX}}", absl::StrCat(result_index)},
|
||||
{"{{RESULT_NAMES_CODE}}", result_names_code},
|
||||
{"{{RESULT_SHAPE_INFOS}}", result_shape_infos},
|
||||
{"{{TEMP_BYTES_ALIGNED}}", absl::StrCat(temp_bytes_aligned)},
|
||||
{"{{TEMP_BYTES_TOTAL}}", absl::StrCat(temp_bytes_total)},
|
||||
{"{{NUM_BUFFERS}}", absl::StrCat(buffer_infos.size())},
|
||||
|
|
@ -749,7 +866,7 @@ Status GenerateMetadata(const CodegenOpts& opts,
|
|||
|
||||
if (opts.gen_program_shape) {
|
||||
program_shape =
|
||||
absl::make_unique<xla::ProgramShapeProto>(compile_result.program_shape);
|
||||
std::make_unique<xla::ProgramShapeProto>(compile_result.program_shape);
|
||||
|
||||
// The parameter names are currently meaningless, and redundant with the
|
||||
// rest of our metadata, so clear them out to avoid confusion and save
|
||||
|
|
|
|||
|
|
@ -48,6 +48,9 @@ struct CodegenOpts {
|
|||
// If true, emit a serialized HloProfilePrinterData protobuf that can be used
|
||||
// to pretty print HLO profile counters.
|
||||
bool gen_hlo_profile_printer_data = false;
|
||||
|
||||
// If true, sets this executable as an XLA Runtime one.
|
||||
bool use_xla_runtime = false;
|
||||
};
|
||||
|
||||
// Describes a generated metadata object file.
|
||||
|
|
|
|||
|
|
@ -215,18 +215,22 @@ TEST(CodegenTest, Golden) {
|
|||
CompileResult compile_result;
|
||||
compile_result.aot.reset(new xla::cpu::CpuAotCompilationResult(
|
||||
{},
|
||||
{BufferInfo::MakeTempBuffer(1),
|
||||
BufferInfo::MakeEntryParameter(/*size=*/8, /*param_number=*/0),
|
||||
{BufferInfo::MakeTempBuffer(3 * 8),
|
||||
BufferInfo::MakeEntryParameter(/*size=*/8, /*entry_param_number=*/0),
|
||||
BufferInfo::MakeTempBuffer(1),
|
||||
BufferInfo::MakeEntryParameter(/*size=*/96, /*param_number=*/1),
|
||||
BufferInfo::MakeEntryParameter(/*size=*/96, /*entry_param_number=*/1),
|
||||
BufferInfo::MakeTempBuffer(1),
|
||||
BufferInfo::MakeEntryParameter(/*size=*/96, /*param_number=*/2),
|
||||
BufferInfo::MakeEntryParameter(/*size=*/96, /*entry_param_number=*/2),
|
||||
BufferInfo::MakeTempBuffer(1),
|
||||
BufferInfo::MakeEntryParameter(/*size=*/96, /*param_number=*/3),
|
||||
BufferInfo::MakeTempBuffer(1),
|
||||
BufferInfo::MakeEntryParameter(/*size=*/96, /*param_number=*/4),
|
||||
BufferInfo::MakeTempBuffer(1), BufferInfo::MakeTempBuffer(120)},
|
||||
11, {}));
|
||||
BufferInfo::MakeEntryParameter(/*size=*/96, /*entry_param_number=*/3),
|
||||
BufferInfo::MakeResultParameter(/*size=*/5 * 6 * 4,
|
||||
/*result_param_number=*/0),
|
||||
BufferInfo::MakeEntryParameter(/*size=*/96, /*entry_param_number=*/4),
|
||||
BufferInfo::MakeResultParameter(/*size=*/1 * 4,
|
||||
/*result_param_number=*/1),
|
||||
BufferInfo::MakeResultParameter(/*size=*/5 * 4,
|
||||
/*result_param_number=*/2)},
|
||||
0, {}));
|
||||
compile_result.program_shape =
|
||||
xla::ShapeUtil::MakeProgramShape(
|
||||
{
|
||||
|
|
|
|||
|
|
@ -58,13 +58,15 @@ namespace bar {
|
|||
// Memory stats:
|
||||
// arg bytes total: 392
|
||||
// arg bytes aligned: 576
|
||||
// temp bytes total: 126
|
||||
// temp bytes total: 171
|
||||
// temp bytes aligned: 512
|
||||
class MyClass final : public tensorflow::XlaCompiledCpuFunction {
|
||||
public:
|
||||
// Number of input arguments for the compiled computation.
|
||||
static constexpr size_t kNumArgs = 5;
|
||||
|
||||
static constexpr size_t kNumResults = 3;
|
||||
|
||||
// Number of variables for the compiled computation.
|
||||
static constexpr size_t kNumVariables = 3;
|
||||
|
||||
|
|
@ -81,16 +83,21 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction {
|
|||
set_static_data_raw_function(data, entry_point);
|
||||
set_static_data_buffer_infos(data, BufferInfos());
|
||||
set_static_data_num_buffers(data, kNumBuffers);
|
||||
set_static_data_result_index_table(data, ResultIndexToBufferIndex());
|
||||
set_static_data_num_results(data, kNumResults);
|
||||
set_static_data_arg_index_table(data, ArgIndexToBufferIndex());
|
||||
set_static_data_num_args(data, kNumArgs);
|
||||
set_static_data_num_variables(data, kNumVariables);
|
||||
set_static_data_result_index(data, kResultIndex);
|
||||
set_static_data_arg_shape_infos(data, ArgShapeInfos());
|
||||
set_static_data_result_shape_infos(data, ResultShapeInfos());
|
||||
set_static_data_arg_names(data, StaticArgNames());
|
||||
set_static_data_variable_names(data, StaticVariableNames());
|
||||
set_static_data_result_names(data, StaticResultNames());
|
||||
set_static_data_program_shape(data, StaticProgramShape());
|
||||
set_static_data_hlo_profile_printer_data(
|
||||
data, StaticHloProfilePrinterData());
|
||||
set_static_data_use_xla_runtime(data, false);
|
||||
|
||||
return data;
|
||||
}();
|
||||
|
|
@ -110,7 +117,7 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction {
|
|||
//
|
||||
// void set_argN_data(void* data)
|
||||
// Sets the buffer of type T for positional argument N. May be called in
|
||||
// any AllocMode. Must be called before Run to have an affect. Must be
|
||||
// any AllocMode. Must be called before Run to have an effect. Must be
|
||||
// called in AllocMode::RESULTS_PROFILES_AND_TEMPS_ONLY for each positional
|
||||
// argument, to set the argument buffers.
|
||||
//
|
||||
|
|
@ -354,22 +361,29 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction {
|
|||
static const ::xla::cpu_function_runtime::BufferInfo* BufferInfos() {
|
||||
static const ::xla::cpu_function_runtime::BufferInfo
|
||||
kBufferInfos[kNumBuffers] = {
|
||||
::xla::cpu_function_runtime::BufferInfo({5ULL, ~0ULL}),
|
||||
::xla::cpu_function_runtime::BufferInfo({34ULL, 0ULL}),
|
||||
::xla::cpu_function_runtime::BufferInfo({5ULL, ~0ULL}),
|
||||
::xla::cpu_function_runtime::BufferInfo({386ULL, 1ULL}),
|
||||
::xla::cpu_function_runtime::BufferInfo({5ULL, ~0ULL}),
|
||||
::xla::cpu_function_runtime::BufferInfo({386ULL, 2ULL}),
|
||||
::xla::cpu_function_runtime::BufferInfo({5ULL, ~0ULL}),
|
||||
::xla::cpu_function_runtime::BufferInfo({386ULL, 3ULL}),
|
||||
::xla::cpu_function_runtime::BufferInfo({5ULL, ~0ULL}),
|
||||
::xla::cpu_function_runtime::BufferInfo({386ULL, 4ULL}),
|
||||
::xla::cpu_function_runtime::BufferInfo({5ULL, ~0ULL}),
|
||||
::xla::cpu_function_runtime::BufferInfo({481ULL, ~0ULL})
|
||||
::xla::cpu_function_runtime::BufferInfo(97ULL, ~0U, ~0U),
|
||||
::xla::cpu_function_runtime::BufferInfo(34ULL, 0U, ~0U),
|
||||
::xla::cpu_function_runtime::BufferInfo(5ULL, ~0U, ~0U),
|
||||
::xla::cpu_function_runtime::BufferInfo(386ULL, 1U, ~0U),
|
||||
::xla::cpu_function_runtime::BufferInfo(5ULL, ~0U, ~0U),
|
||||
::xla::cpu_function_runtime::BufferInfo(386ULL, 2U, ~0U),
|
||||
::xla::cpu_function_runtime::BufferInfo(5ULL, ~0U, ~0U),
|
||||
::xla::cpu_function_runtime::BufferInfo(386ULL, 3U, ~0U),
|
||||
::xla::cpu_function_runtime::BufferInfo(481ULL, ~0U, 0U),
|
||||
::xla::cpu_function_runtime::BufferInfo(386ULL, 4U, ~0U),
|
||||
::xla::cpu_function_runtime::BufferInfo(17ULL, ~0U, 1U),
|
||||
::xla::cpu_function_runtime::BufferInfo(81ULL, ~0U, 2U)
|
||||
};
|
||||
return kBufferInfos;
|
||||
}
|
||||
|
||||
static const ::tensorflow::int32* ResultIndexToBufferIndex() {
|
||||
static constexpr ::tensorflow::int32 kResultIndexToBufferIndex[kNumResults] = {
|
||||
8, 10, 11
|
||||
};
|
||||
return kResultIndexToBufferIndex;
|
||||
}
|
||||
|
||||
static const ::tensorflow::int32* ArgIndexToBufferIndex() {
|
||||
static constexpr ::tensorflow::int32 kArgIndexToBufferIndex[kNumArgs] = {
|
||||
1, 3, 5, 7, 9
|
||||
|
|
@ -378,7 +392,53 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction {
|
|||
}
|
||||
|
||||
// The 0-based index of the result tuple in the temporary buffers.
|
||||
static constexpr size_t kResultIndex = 11;
|
||||
static constexpr size_t kResultIndex = 0;
|
||||
|
||||
// Shapes of the input arguments.
|
||||
static constexpr int32_t kArg0Shapes[] = {
|
||||
1, 2
|
||||
};
|
||||
static constexpr int32_t kArg1Shapes[] = {
|
||||
3, 4
|
||||
};
|
||||
static constexpr int32_t kArg2Shapes[] = {
|
||||
1
|
||||
};
|
||||
static constexpr int32_t kArg3Shapes[] = {
|
||||
1
|
||||
};
|
||||
static constexpr int32_t kArg4Shapes[] = {
|
||||
5
|
||||
};
|
||||
static const ShapeInfo* ArgShapeInfos() {
|
||||
static constexpr ShapeInfo kArgShapeInfoTable[kNumArgs] = {
|
||||
{ kArg0Shapes, 2 },
|
||||
{ kArg1Shapes, 2 },
|
||||
{ kArg2Shapes, 1 },
|
||||
{ kArg3Shapes, 1 },
|
||||
{ kArg4Shapes, 1 },
|
||||
};
|
||||
return kArgShapeInfoTable;
|
||||
};
|
||||
|
||||
// Shapes of the results.
|
||||
static constexpr int32_t kResult0Shapes[] = {
|
||||
5, 6
|
||||
};
|
||||
static constexpr int32_t kResult1Shapes[] = {
|
||||
1
|
||||
};
|
||||
static constexpr int32_t kResult2Shapes[] = {
|
||||
5
|
||||
};
|
||||
static const ShapeInfo* ResultShapeInfos() {
|
||||
static constexpr ShapeInfo kResultShapeInfoTable[kNumResults] = {
|
||||
{ kResult0Shapes, 2 },
|
||||
{ kResult1Shapes, 1 },
|
||||
{ kResult2Shapes, 1 },
|
||||
};
|
||||
return kResultShapeInfoTable;
|
||||
};
|
||||
|
||||
// Array of names of each positional argument, terminated by nullptr.
|
||||
static const char** StaticArgNames() {
|
||||
|
|
|
|||
|
|
@ -273,6 +273,15 @@ Status Main(const MainFlags& flags) {
|
|||
codegen_opts.gen_name_to_index = flags.gen_name_to_index;
|
||||
codegen_opts.gen_program_shape = flags.gen_program_shape;
|
||||
codegen_opts.target_triple = flags.target_triple;
|
||||
// Set the XLA Runtime bit if this is an HloLowering.
|
||||
if (!flags.mlir_components.empty() && flags.mlir_components != "None") {
|
||||
for (auto component : absl::StrSplit(flags.mlir_components, ',')) {
|
||||
if (component == "HloLowering") {
|
||||
codegen_opts.use_xla_runtime = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (flags.cpp_class.empty()) {
|
||||
return errors::InvalidArgument("Must specify --cpp_class");
|
||||
}
|
||||
|
|
|
|||
|
|
@ -69,18 +69,18 @@ py_binary(
|
|||
srcs_version = "PY3",
|
||||
deps = [
|
||||
"//tensorflow/core:protos_all_py",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:client",
|
||||
"//tensorflow/python:cond",
|
||||
"//tensorflow/python:control_flow_assert",
|
||||
"//tensorflow/python:control_flow_ops",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:nn_ops",
|
||||
"//tensorflow/python:session",
|
||||
"//tensorflow/python:training",
|
||||
"//tensorflow/python:variable_v1",
|
||||
"//tensorflow/python:variables",
|
||||
"//tensorflow/python/client",
|
||||
"//tensorflow/python/client:session",
|
||||
"//tensorflow/python/framework:for_generated_wrappers",
|
||||
"//tensorflow/python/ops:array_ops",
|
||||
"//tensorflow/python/ops:cond",
|
||||
"//tensorflow/python/ops:control_flow_assert",
|
||||
"//tensorflow/python/ops:control_flow_ops",
|
||||
"//tensorflow/python/ops:math_ops",
|
||||
"//tensorflow/python/ops:nn_ops",
|
||||
"//tensorflow/python/ops:variable_v1",
|
||||
"//tensorflow/python/ops:variables",
|
||||
"//tensorflow/python/training",
|
||||
"@absl_py//absl:app",
|
||||
"@six_archive//:six",
|
||||
],
|
||||
|
|
@ -437,6 +437,7 @@ tf_cc_test(
|
|||
tags = [
|
||||
"manual",
|
||||
"no_mac", # TODO(b/228273415)
|
||||
"not_run:arm",
|
||||
],
|
||||
deps = [
|
||||
":test_graph_tfadd",
|
||||
|
|
@ -510,6 +511,7 @@ tf_cc_test(
|
|||
tags = [
|
||||
"manual",
|
||||
"no_mac", # TODO(b/228273415)
|
||||
"not_run:arm",
|
||||
],
|
||||
deps = [
|
||||
":test_graph_tfadd_mlir_bridge",
|
||||
|
|
|
|||
|
|
@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
|
|||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <algorithm>
|
||||
#include <vector>
|
||||
#define EIGEN_USE_THREADS
|
||||
#define EIGEN_USE_CUSTOM_THREAD_POOL
|
||||
|
||||
|
|
|
|||
|
|
@ -319,6 +319,8 @@ def _tf_library(
|
|||
] or []) + (include_standard_runtime_deps and [
|
||||
# TODO(cwhipkey): only depend on kernel code that the model actually
|
||||
# needed.
|
||||
"//tensorflow/compiler/xla/service/cpu/runtime:convolution_ffi",
|
||||
"//tensorflow/compiler/xla/service/cpu/runtime:rng_ffi",
|
||||
"//tensorflow/compiler/xla/service/cpu:runtime_conv2d",
|
||||
"//tensorflow/compiler/xla/service/cpu:runtime_custom_call_status",
|
||||
"//tensorflow/compiler/xla/service/cpu:runtime_key_value_sort",
|
||||
|
|
@ -329,6 +331,7 @@ def _tf_library(
|
|||
"//third_party/eigen3",
|
||||
] or []) + (
|
||||
mlir_components.count("HloLowering") > 0 and [
|
||||
"//tensorflow/compiler/xla/runtime:aot_ffi_c_symbols",
|
||||
"//tensorflow/compiler/xla/service/cpu:runtime_mlir_utils",
|
||||
] or []
|
||||
) + (
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")
|
||||
load("//tensorflow:tensorflow.bzl", "if_libtpu", "if_with_tpu_support", "tf_cc_test", "tf_copts", "tf_cuda_cc_test")
|
||||
load("//tensorflow:tensorflow.bzl", "if_libtpu", "if_with_tpu_support", "tf_cc_test", "tf_copts", "tf_cuda_cc_test", "tf_cuda_only_cc_test")
|
||||
load("//tensorflow/compiler/xla/stream_executor:build_defs.bzl", "if_cuda_or_rocm")
|
||||
load("//tensorflow:tensorflow.default.bzl", "cc_header_only_library", "filegroup", "tf_custom_op_py_library", "tf_jit_compilation_passes_extra_deps")
|
||||
load("//tensorflow/core/platform:build_config.bzl", "tf_additional_all_protos", "tf_proto_library")
|
||||
|
|
@ -352,8 +352,10 @@ cc_library(
|
|||
"//tensorflow/compiler/xla:executable_run_options",
|
||||
"//tensorflow/compiler/xla/hlo/ir:hlo",
|
||||
"//tensorflow/compiler/xla/pjrt:pjrt_client",
|
||||
"//tensorflow/compiler/xla/pjrt:tf_pjrt_client",
|
||||
"//tensorflow/compiler/xla/service:executable",
|
||||
"//tensorflow/core/tfrt/common:create_pjrt_client_util",
|
||||
"//tensorflow/core/tfrt/common:global_state",
|
||||
"//tensorflow/core/tfrt/common:pjrt_util",
|
||||
"//tensorflow/core/tpu:tpu_defs",
|
||||
"@com_google_absl//absl/log",
|
||||
|
|
@ -595,7 +597,8 @@ tf_cc_test(
|
|||
"//tensorflow/core/framework:fake_input",
|
||||
"//tensorflow/core/kernels:identity_op",
|
||||
"//tensorflow/core/kernels:ops_testutil",
|
||||
"@com_google_googletest//:gtest_main",
|
||||
"//tensorflow/core/tpu:tpu_defs",
|
||||
"@com_google_googletest//:gtest",
|
||||
],
|
||||
)
|
||||
|
||||
|
|
@ -782,6 +785,7 @@ tf_cc_test(
|
|||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core:testlib",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/status",
|
||||
],
|
||||
)
|
||||
|
||||
|
|
@ -1509,6 +1513,70 @@ cc_library(
|
|||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "xla_host_recv_device_context",
|
||||
srcs = [
|
||||
"xla_host_recv_device_context.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"xla_host_recv_device_context.h",
|
||||
],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//tensorflow/compiler/tf2xla:common",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla/stream_executor",
|
||||
"//tensorflow/compiler/xla/stream_executor:device_memory",
|
||||
"//tensorflow/core:framework",
|
||||
"@tf_runtime//:async_value",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "xla_host_send_device_context",
|
||||
srcs = [
|
||||
"xla_host_send_device_context.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"xla_host_send_device_context.h",
|
||||
],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//tensorflow/compiler/tf2xla:common",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla/stream_executor",
|
||||
"//tensorflow/compiler/xla/stream_executor:device_memory",
|
||||
"//tensorflow/core:framework",
|
||||
"@tf_runtime//:async_value",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cuda_only_cc_test(
|
||||
name = "xla_host_send_recv_device_context_test",
|
||||
srcs = ["xla_host_send_recv_device_context_test.cc"],
|
||||
tags = tf_cuda_tests_tags() + [
|
||||
"config-cuda-only",
|
||||
"no_oss", # Temporarily disable OSS.
|
||||
],
|
||||
deps = [
|
||||
":flags",
|
||||
":xla_device",
|
||||
":xla_gpu_device",
|
||||
":xla_host_recv_device_context",
|
||||
":xla_host_send_device_context",
|
||||
"//tensorflow/compiler/tf2xla:common",
|
||||
"//tensorflow/compiler/tf2xla:xla_op_registry",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla/stream_executor",
|
||||
"//tensorflow/compiler/xla/stream_executor:device_memory",
|
||||
"//tensorflow/compiler/xla/stream_executor:multi_platform_manager",
|
||||
"//tensorflow/core:framework_internal",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core/framework:tensor_testutil",
|
||||
"@com_google_googletest//:gtest_main",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "device_compilation_cluster_signature_test",
|
||||
srcs = [
|
||||
|
|
@ -1527,6 +1595,9 @@ tf_cc_test(
|
|||
tf_cc_test(
|
||||
name = "device_compilation_profiler_test",
|
||||
srcs = ["device_compilation_profiler_test.cc"],
|
||||
tags = [
|
||||
"nomsan", # TODO(b/284492454)
|
||||
],
|
||||
deps = [
|
||||
":device_compilation_profiler",
|
||||
":xla_activity_proto_cc",
|
||||
|
|
@ -1641,6 +1712,7 @@ tf_cuda_cc_test(
|
|||
tags = tf_cuda_tests_tags(),
|
||||
deps = [
|
||||
":flags",
|
||||
":pjrt_device_compiler_client",
|
||||
":test_util",
|
||||
":xla_device_no_jit_rewrite_registration",
|
||||
":xla_gpu_device",
|
||||
|
|
@ -1649,6 +1721,7 @@ tf_cuda_cc_test(
|
|||
"//tensorflow/compiler/tf2xla:xla_op_registry",
|
||||
"//tensorflow/compiler/xla:xla_data_proto_cc",
|
||||
"//tensorflow/compiler/xla/client:client_library",
|
||||
"//tensorflow/compiler/xla/pjrt:pjrt_client",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:framework_types_hdr",
|
||||
"//tensorflow/core/tpu:tpu_defs",
|
||||
|
|
|
|||
|
|
@ -19,9 +19,9 @@ limitations under the License.
|
|||
#include <memory>
|
||||
#include <numeric>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "absl/container/flat_hash_set.h"
|
||||
#include "absl/strings/match.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
|
|
@ -188,7 +188,7 @@ class Encapsulator {
|
|||
|
||||
// Adds the function call node to graph_out.
|
||||
Status AddFunctionCallNode(
|
||||
const std::unordered_map<const Node*, Node*>& node_images,
|
||||
const absl::flat_hash_map<const Node*, Node*>& node_images,
|
||||
Graph* graph_out);
|
||||
|
||||
// Returns the Node that the inputs and outputs of the function should be
|
||||
|
|
@ -206,7 +206,7 @@ class Encapsulator {
|
|||
// and adds the edge within the subgraph from the _Arg node to the image of
|
||||
// the dst node.
|
||||
Status RecordArg(const Edge* edge,
|
||||
const std::unordered_map<const Node*, Node*>& node_images,
|
||||
const absl::flat_hash_map<const Node*, Node*>& node_images,
|
||||
std::vector<std::pair<const Node*, Node*>>* src_arg_pairs);
|
||||
|
||||
// Records the src of the given edge as a control result of the graph.
|
||||
|
|
@ -214,14 +214,14 @@ class Encapsulator {
|
|||
// the function signature.
|
||||
Status RecordControlResult(
|
||||
const Edge* edge,
|
||||
const std::unordered_map<const Node*, Node*>& node_images);
|
||||
const absl::flat_hash_map<const Node*, Node*>& node_images);
|
||||
|
||||
// Creates a _Retval node for the src node of edge, and add it to results_,
|
||||
// if none exists yet. If a new _Retval node is created, also adds the edge
|
||||
// within the subgraph from the src to the _Retval node.
|
||||
Status RecordResult(
|
||||
const Edge* edge,
|
||||
const std::unordered_map<const Node*, Node*>& node_images);
|
||||
const absl::flat_hash_map<const Node*, Node*>& node_images);
|
||||
|
||||
// Creates the sequencer node if it doesn't exist, adding it to graph_out.
|
||||
Status MakeSequencingNode(const string& subgraph_name, Graph* graph_out);
|
||||
|
|
@ -260,14 +260,14 @@ class Encapsulator {
|
|||
// (consumer node/slot) tensors in the input graph to _Arg numbers in
|
||||
// the subgraph. The source map is one-to-one, whereas the dest map may be
|
||||
// many-to-one.
|
||||
std::unordered_map<OutputTensor, int, OutputTensor::Hash> args_by_src_;
|
||||
std::unordered_map<InputTensor, int, InputTensor::Hash> args_by_dst_;
|
||||
absl::flat_hash_map<OutputTensor, int, OutputTensor::Hash> args_by_src_;
|
||||
absl::flat_hash_map<InputTensor, int, InputTensor::Hash> args_by_dst_;
|
||||
|
||||
// The arguments to the subgraph, in order.
|
||||
std::vector<Node*> args_;
|
||||
|
||||
// Map from source tensor in the input graph to result #.
|
||||
std::unordered_map<OutputTensor, int, OutputTensor::Hash> results_;
|
||||
absl::flat_hash_map<OutputTensor, int, OutputTensor::Hash> results_;
|
||||
|
||||
// Set of node names that are the source of a control output of the
|
||||
// subgraph. We store strings here so that we can tolerate nodes being
|
||||
|
|
@ -285,19 +285,20 @@ class Encapsulator {
|
|||
// Copies edges local to a subgraph. Adds _Arg and _Retval nodes to
|
||||
// subgraphs for data edges that cross subgraph boundaries.
|
||||
Status CopySubgraphEdges(
|
||||
const std::unordered_map<const Node*, Node*>& node_images,
|
||||
const absl::flat_hash_map<const Node*, Node*>& node_images,
|
||||
std::vector<std::pair<const Node*, Node*>>* src_arg_pairs);
|
||||
|
||||
// Copies all marked nodes to a subgraph. Does nothing for unmarked nodes.
|
||||
Status CopySubgraphNodes(std::unordered_map<const Node*, Node*>* node_images);
|
||||
Status CopySubgraphNodes(
|
||||
absl::flat_hash_map<const Node*, Node*>* node_images);
|
||||
|
||||
// Copies all nodes that aren't in a compiled subgraph to the output graph.
|
||||
Status CopyNodesToOutputGraph(
|
||||
Graph* graph_out, std::unordered_map<const Node*, Node*>* node_images);
|
||||
Graph* graph_out, absl::flat_hash_map<const Node*, Node*>* node_images);
|
||||
|
||||
// Adds function call nodes for each compiled subgraph.
|
||||
Status AddFunctionCallNodes(
|
||||
const std::unordered_map<const Node*, Node*>& node_images,
|
||||
const absl::flat_hash_map<const Node*, Node*>& node_images,
|
||||
Graph* graph_out);
|
||||
|
||||
// Finds the image of an edge source in the output graph. If the edge crosses
|
||||
|
|
@ -305,7 +306,7 @@ class Encapsulator {
|
|||
// in the output graph.
|
||||
Status FindOutputImageOfEdgeSrc(
|
||||
const string& src_func_id, const string& dst_func_id,
|
||||
const std::unordered_map<const Node*, Node*>& node_images,
|
||||
const absl::flat_hash_map<const Node*, Node*>& node_images,
|
||||
const Node* original_src_node, Node** src_image);
|
||||
|
||||
// Finds an edge source slot in the output graph. If the edge crosses a
|
||||
|
|
@ -320,7 +321,7 @@ class Encapsulator {
|
|||
// a node in the output graph.
|
||||
Status FindOutputImageOfEdgeDst(
|
||||
const string& src_func_id, const string& dst_func_id,
|
||||
const std::unordered_map<const Node*, Node*>& node_images,
|
||||
const absl::flat_hash_map<const Node*, Node*>& node_images,
|
||||
const Node* original_dst_node, Node** dst_image);
|
||||
|
||||
// Finds an edge destination slot in the output graph. If the edge crosses a
|
||||
|
|
@ -334,14 +335,14 @@ class Encapsulator {
|
|||
// within the output graph, or crosses into or out of a compiled subgraph.
|
||||
Status CopyEdgeToOutputGraph(
|
||||
const Edge* edge, const string& src_func_id, const string& dst_func_id,
|
||||
const std::unordered_map<const Node*, Node*>& node_images,
|
||||
const absl::flat_hash_map<const Node*, Node*>& node_images,
|
||||
Graph* graph_out,
|
||||
std::unordered_set<std::pair<OutputTensor, InputTensor>,
|
||||
OutputInputTensorPairHasher>* edges_added);
|
||||
absl::flat_hash_set<std::pair<OutputTensor, InputTensor>,
|
||||
OutputInputTensorPairHasher>* edges_added);
|
||||
|
||||
// Adds all edges to the output graph.
|
||||
Status AddEdgesToOutputGraph(
|
||||
const std::unordered_map<const Node*, Node*>& node_images,
|
||||
const absl::flat_hash_map<const Node*, Node*>& node_images,
|
||||
Graph* graph_out);
|
||||
|
||||
// Makes a copy of graph containing only nodes that are ancestors of at least
|
||||
|
|
@ -351,13 +352,13 @@ class Encapsulator {
|
|||
Status MakePrunedGraphCopyAndInline(
|
||||
const Graph& graph, const std::vector<Node*>& sink_nodes,
|
||||
std::unique_ptr<Graph>* pruned_graph,
|
||||
std::unordered_map<const Node*, Node*>* node_images,
|
||||
absl::flat_hash_map<const Node*, Node*>* node_images,
|
||||
FunctionLibraryDefinition* library);
|
||||
|
||||
const string group_attribute_;
|
||||
const Graph* graph_in_;
|
||||
|
||||
std::unordered_map<string, Subgraph> subgraphs_;
|
||||
absl::flat_hash_map<string, Subgraph> subgraphs_;
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(Encapsulator);
|
||||
};
|
||||
|
|
@ -369,9 +370,9 @@ namespace {
|
|||
// including clusters that are not present in the ancestors map. has_successors
|
||||
// is the set of clusters that are ancestors of some other cluster.
|
||||
void TopologicalClusterSort(
|
||||
const std::unordered_set<string>& clusters,
|
||||
const std::unordered_set<string>& has_successors,
|
||||
const std::unordered_map<string, std::unordered_set<string>>& ancestors,
|
||||
const absl::flat_hash_set<string>& clusters,
|
||||
const absl::flat_hash_set<string>& has_successors,
|
||||
const absl::flat_hash_map<string, absl::flat_hash_set<string>>& ancestors,
|
||||
std::vector<string>* sorted) {
|
||||
// The nodes are placed in 'sorted' in topological order.
|
||||
sorted->clear();
|
||||
|
|
@ -447,11 +448,12 @@ Node* Encapsulator::Subgraph::MakeNodeImage(const Graph* graph_in, Node* node) {
|
|||
Graph* Encapsulator::Subgraph::GetGraph() const { return graph_.get(); }
|
||||
|
||||
Status Encapsulator::Subgraph::RecordArg(
|
||||
const Edge* edge, const std::unordered_map<const Node*, Node*>& node_images,
|
||||
const Edge* edge,
|
||||
const absl::flat_hash_map<const Node*, Node*>& node_images,
|
||||
std::vector<std::pair<const Node*, Node*>>* src_arg_pairs) {
|
||||
Node* src_node = edge->src();
|
||||
int src_slot = edge->src_output();
|
||||
std::unordered_map<OutputTensor, int, OutputTensor::Hash>::iterator iter;
|
||||
absl::flat_hash_map<OutputTensor, int, OutputTensor::Hash>::iterator iter;
|
||||
bool inserted;
|
||||
std::tie(iter, inserted) = args_by_src_.emplace(
|
||||
OutputTensor(src_node, src_slot), args_by_src_.size());
|
||||
|
|
@ -481,7 +483,7 @@ Status Encapsulator::Subgraph::RecordArg(
|
|||
|
||||
Status Encapsulator::Subgraph::RecordControlResult(
|
||||
const Edge* edge,
|
||||
const std::unordered_map<const Node*, Node*>& node_images) {
|
||||
const absl::flat_hash_map<const Node*, Node*>& node_images) {
|
||||
Node* src_node = edge->src();
|
||||
Node* src_image = node_images.at(src_node);
|
||||
control_output_nodes_.insert(src_image->name());
|
||||
|
|
@ -490,11 +492,11 @@ Status Encapsulator::Subgraph::RecordControlResult(
|
|||
|
||||
Status Encapsulator::Subgraph::RecordResult(
|
||||
const Edge* edge,
|
||||
const std::unordered_map<const Node*, Node*>& node_images) {
|
||||
const absl::flat_hash_map<const Node*, Node*>& node_images) {
|
||||
Node* src_node = edge->src();
|
||||
Node* src_image = node_images.at(src_node);
|
||||
int src_slot = edge->src_output();
|
||||
std::unordered_map<OutputTensor, int, OutputTensor::Hash>::iterator iter;
|
||||
absl::flat_hash_map<OutputTensor, int, OutputTensor::Hash>::iterator iter;
|
||||
bool inserted;
|
||||
std::tie(iter, inserted) =
|
||||
results_.emplace(OutputTensor(src_node, src_slot), results_.size());
|
||||
|
|
@ -592,7 +594,7 @@ Status Encapsulator::Subgraph::BuildFunctionDef(
|
|||
FunctionDef fdef;
|
||||
auto lookup = [this](const Node* node) -> std::optional<string> {
|
||||
if (control_output_nodes_.contains(node->name())) {
|
||||
return absl::make_optional(node->name());
|
||||
return std::make_optional(node->name());
|
||||
}
|
||||
return std::nullopt;
|
||||
};
|
||||
|
|
@ -637,7 +639,7 @@ Status Encapsulator::Subgraph::ReplaceFunctionDef(
|
|||
}
|
||||
|
||||
Status Encapsulator::Subgraph::AddFunctionCallNode(
|
||||
const std::unordered_map<const Node*, Node*>& node_images,
|
||||
const absl::flat_hash_map<const Node*, Node*>& node_images,
|
||||
Graph* graph_out) {
|
||||
TF_ASSIGN_OR_RETURN(call_node_, graph_out->AddNode(call_node_def_));
|
||||
|
||||
|
|
@ -663,7 +665,7 @@ Status Encapsulator::GetFunctionNameAttr(Node const* node, string* attr) const {
|
|||
bool IsInSubgraph(const string& func_id) { return !func_id.empty(); }
|
||||
|
||||
Status Encapsulator::CopySubgraphNodes(
|
||||
std::unordered_map<const Node*, Node*>* node_images) {
|
||||
absl::flat_hash_map<const Node*, Node*>* node_images) {
|
||||
for (Node* node : graph_in_->op_nodes()) {
|
||||
string func_id;
|
||||
TF_RETURN_IF_ERROR(GetFunctionNameAttr(node, &func_id));
|
||||
|
|
@ -678,7 +680,7 @@ Status Encapsulator::CopySubgraphNodes(
|
|||
}
|
||||
|
||||
Status Encapsulator::CopySubgraphEdges(
|
||||
const std::unordered_map<const Node*, Node*>& node_images,
|
||||
const absl::flat_hash_map<const Node*, Node*>& node_images,
|
||||
std::vector<std::pair<const Node*, Node*>>* src_arg_pairs) {
|
||||
for (const Edge* edge : graph_in_->edges()) {
|
||||
string src_func_id;
|
||||
|
|
@ -752,7 +754,7 @@ Status Encapsulator::SplitIntoSubgraphs(FunctionLibraryDefinition* library) {
|
|||
Status s;
|
||||
|
||||
// Map from input graph nodes to subgraph nodes.
|
||||
std::unordered_map<const Node*, Node*> node_images;
|
||||
absl::flat_hash_map<const Node*, Node*> node_images;
|
||||
|
||||
// Each entry of src_arg_pairs is a pair whose first element is a node in the
|
||||
// original graph that has an output edge in the subgraph, and whose second
|
||||
|
|
@ -794,7 +796,7 @@ Status Encapsulator::BuildFunctionDefs(
|
|||
}
|
||||
|
||||
Status Encapsulator::CopyNodesToOutputGraph(
|
||||
Graph* graph_out, std::unordered_map<const Node*, Node*>* node_images) {
|
||||
Graph* graph_out, absl::flat_hash_map<const Node*, Node*>* node_images) {
|
||||
for (Node* node : graph_in_->op_nodes()) {
|
||||
string func_id;
|
||||
TF_RETURN_IF_ERROR(GetFunctionNameAttr(node, &func_id));
|
||||
|
|
@ -811,7 +813,7 @@ Status Encapsulator::CopyNodesToOutputGraph(
|
|||
}
|
||||
|
||||
Status Encapsulator::AddFunctionCallNodes(
|
||||
const std::unordered_map<const Node*, Node*>& node_images,
|
||||
const absl::flat_hash_map<const Node*, Node*>& node_images,
|
||||
Graph* graph_out) {
|
||||
for (auto& subgraph_entry : subgraphs_) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
|
|
@ -822,7 +824,7 @@ Status Encapsulator::AddFunctionCallNodes(
|
|||
|
||||
Status Encapsulator::FindOutputImageOfEdgeSrc(
|
||||
const string& src_func_id, const string& dst_func_id,
|
||||
const std::unordered_map<const Node*, Node*>& node_images,
|
||||
const absl::flat_hash_map<const Node*, Node*>& node_images,
|
||||
const Node* original_src_node, Node** src_image) {
|
||||
if (IsInSubgraph(src_func_id)) {
|
||||
// The edge is from a subgraph to a regular node in the output graph so
|
||||
|
|
@ -853,7 +855,7 @@ int Encapsulator::FindOutputSlotOfEdgeSrc(const string& src_func_id,
|
|||
|
||||
Status Encapsulator::FindOutputImageOfEdgeDst(
|
||||
const string& src_func_id, const string& dst_func_id,
|
||||
const std::unordered_map<const Node*, Node*>& node_images,
|
||||
const absl::flat_hash_map<const Node*, Node*>& node_images,
|
||||
const Node* original_dst_node, Node** dst_image) {
|
||||
if (IsInSubgraph(dst_func_id)) {
|
||||
// The edge is to a subgraph from a regular node in the output graph so
|
||||
|
|
@ -884,9 +886,10 @@ int Encapsulator::FindOutputSlotOfEdgeDst(const string& src_func_id,
|
|||
|
||||
Status Encapsulator::CopyEdgeToOutputGraph(
|
||||
const Edge* edge, const string& src_func_id, const string& dst_func_id,
|
||||
const std::unordered_map<const Node*, Node*>& node_images, Graph* graph_out,
|
||||
std::unordered_set<std::pair<OutputTensor, InputTensor>,
|
||||
OutputInputTensorPairHasher>* edges_added) {
|
||||
const absl::flat_hash_map<const Node*, Node*>& node_images,
|
||||
Graph* graph_out,
|
||||
absl::flat_hash_set<std::pair<OutputTensor, InputTensor>,
|
||||
OutputInputTensorPairHasher>* edges_added) {
|
||||
Node* src_image;
|
||||
TF_RETURN_IF_ERROR(FindOutputImageOfEdgeSrc(
|
||||
src_func_id, dst_func_id, node_images, edge->src(), &src_image));
|
||||
|
|
@ -924,13 +927,13 @@ Status Encapsulator::CopyEdgeToOutputGraph(
|
|||
}
|
||||
|
||||
Status Encapsulator::AddEdgesToOutputGraph(
|
||||
const std::unordered_map<const Node*, Node*>& node_images,
|
||||
const absl::flat_hash_map<const Node*, Node*>& node_images,
|
||||
Graph* graph_out) {
|
||||
// Set of edges already added to the output graph, represented as (src, dst)
|
||||
// pairs. We use the set to deduplicate edges; multiple edges in the input
|
||||
// graph may map to one edge in the output graph.
|
||||
std::unordered_set<std::pair<OutputTensor, InputTensor>,
|
||||
OutputInputTensorPairHasher>
|
||||
absl::flat_hash_set<std::pair<OutputTensor, InputTensor>,
|
||||
OutputInputTensorPairHasher>
|
||||
edges_added;
|
||||
|
||||
for (const Edge* edge : graph_in_->edges()) {
|
||||
|
|
@ -1010,7 +1013,7 @@ Node* AddDummyShapedNode(const Node* src_node, int src_port,
|
|||
Status Encapsulator::MakePrunedGraphCopyAndInline(
|
||||
const Graph& graph, const std::vector<Node*>& sink_nodes,
|
||||
std::unique_ptr<Graph>* pruned_graph,
|
||||
std::unordered_map<const Node*, Node*>* node_images,
|
||||
absl::flat_hash_map<const Node*, Node*>* node_images,
|
||||
FunctionLibraryDefinition* library) {
|
||||
// First copy all ancestor nodes of sink_nodes into a new graph.
|
||||
pruned_graph->reset(new Graph(library));
|
||||
|
|
@ -1070,7 +1073,7 @@ Status Encapsulator::MakePrunedGraphCopyAndInline(
|
|||
Status Encapsulator::BuildOutputGraph(Graph* graph_out,
|
||||
FunctionLibraryDefinition* library) {
|
||||
// Map from nodes in the input graph to nodes in the output graph.
|
||||
std::unordered_map<const Node*, Node*> node_images;
|
||||
absl::flat_hash_map<const Node*, Node*> node_images;
|
||||
|
||||
TF_RETURN_IF_ERROR(CopyNodesToOutputGraph(graph_out, &node_images));
|
||||
TF_RETURN_IF_ERROR(AddFunctionCallNodes(node_images, graph_out));
|
||||
|
|
|
|||
|
|
@ -34,6 +34,7 @@ BuildXlaOpsPassFlags* build_ops_flags;
|
|||
MarkForCompilationPassFlags* mark_for_compilation_flags;
|
||||
XlaDeviceFlags* device_flags;
|
||||
XlaOpsCommonFlags* ops_flags;
|
||||
XlaCallModuleFlags* call_module_flags;
|
||||
MlirCommonFlags* mlir_flags;
|
||||
JitRtFlags* jitrt_flags;
|
||||
std::vector<Flag>* jitrt_flag_list;
|
||||
|
|
@ -76,6 +77,13 @@ bool SetterForXlaAutoJitFlag(const string& value) {
|
|||
return true;
|
||||
}
|
||||
|
||||
bool SetterForXlaCallModuleDisabledChecks(const string& value) {
|
||||
auto directives = absl::StrSplit(value, ',', absl::SkipEmpty());
|
||||
call_module_flags->disabled_checks.insert(directives.begin(),
|
||||
directives.end());
|
||||
return true;
|
||||
}
|
||||
|
||||
void AppendMarkForCompilationPassFlagsInternal(std::vector<Flag>* flag_list) {
|
||||
std::vector<Flag> new_flags = {
|
||||
Flag("tf_xla_auto_jit", SetterForXlaAutoJitFlag, "0",
|
||||
|
|
@ -184,6 +192,7 @@ void AllocateAndParseFlags() {
|
|||
build_ops_flags->tf_xla_check_cluster_input_numerics = false;
|
||||
build_ops_flags->tf_xla_check_cluster_output_numerics = false;
|
||||
build_ops_flags->tf_xla_disable_constant_folding = false;
|
||||
build_ops_flags->tf_xla_disable_full_embedding_pipelining = false;
|
||||
|
||||
mark_for_compilation_flags = new MarkForCompilationPassFlags;
|
||||
mark_for_compilation_flags->xla_auto_jit_flag.optimization_level_single_gpu =
|
||||
|
|
@ -213,9 +222,12 @@ void AllocateAndParseFlags() {
|
|||
ops_flags = new XlaOpsCommonFlags;
|
||||
ops_flags->tf_xla_always_defer_compilation = false;
|
||||
ops_flags->tf_xla_async_compilation = false;
|
||||
ops_flags->tf_xla_use_device_api.enabled_for_xla_launch_ = false;
|
||||
ops_flags->tf_xla_use_device_api.enabled_for_compile_on_demand_ = false;
|
||||
ops_flags->tf_xla_use_device_api.enabled_for_xla_launch_ = true;
|
||||
ops_flags->tf_xla_use_device_api.enabled_for_compile_on_demand_ = true;
|
||||
ops_flags->tf_xla_use_device_api.enabled_for_compile_and_run_ = false;
|
||||
ops_flags->tf_xla_use_device_api.enabled_for_all_ = false;
|
||||
|
||||
call_module_flags = new XlaCallModuleFlags;
|
||||
// The `enable_mlir_bridge` flag allows the user to explicitly request that
|
||||
// their program is (or isn't) compiled using the MLIR-based TF-to-XLA bridge.
|
||||
//
|
||||
|
|
@ -251,6 +263,10 @@ void AllocateAndParseFlags() {
|
|||
&build_ops_flags->tf_xla_disable_constant_folding,
|
||||
"If true then disables constant folding on TF graph before XLA "
|
||||
"compilation."),
|
||||
Flag("tf_xla_disable_full_embedding_pipelining",
|
||||
&build_ops_flags->tf_xla_disable_full_embedding_pipelining,
|
||||
"If true then disables full embedding pipelining and instead use "
|
||||
"strict SparseCore / TensorCore sequencing."),
|
||||
|
||||
Flag("tf_xla_compile_on_demand", &device_flags->tf_xla_compile_on_demand,
|
||||
"Switch a device into 'on-demand' mode, where instead of "
|
||||
|
|
@ -277,6 +293,22 @@ void AllocateAndParseFlags() {
|
|||
&ops_flags->tf_xla_use_device_api.enabled_for_compile_on_demand_,
|
||||
"If true, uses Device API (PjRt) for compiling and executing ops "
|
||||
"one by one in 'on-demand' mode. Defaults to false."),
|
||||
Flag("tf_xla_use_device_api_for_auto_jit",
|
||||
&ops_flags->tf_xla_use_device_api.enabled_for_compile_and_run_,
|
||||
"If true, uses Device API (PjRt) for compilation and execution "
|
||||
"when auto-clustering is enabled. Defaults to false."),
|
||||
Flag("tf_xla_use_device_api",
|
||||
&ops_flags->tf_xla_use_device_api.enabled_for_all_,
|
||||
"If true, uses Device API (PjRt) for compilation and execution "
|
||||
"of ops one-by-one in 'on-demand' mode, for functions marked for "
|
||||
"JIT compilation, or when auto-clustering is enabled. Defaults to "
|
||||
"false."),
|
||||
|
||||
Flag("tf_xla_call_module_disabled_checks",
|
||||
SetterForXlaCallModuleDisabledChecks, "",
|
||||
"A comma-sepated list of directives specifying the safety checks "
|
||||
"to be skipped when compiling XlaCallModuleOp. See the op "
|
||||
"documentation for the recognized values."),
|
||||
|
||||
Flag("tf_mlir_enable_mlir_bridge", &enable_mlir_bridge,
|
||||
"Enables experimental MLIR-Based TensorFlow Compiler Bridge.",
|
||||
|
|
@ -365,6 +397,11 @@ XlaOpsCommonFlags* GetXlaOpsCommonFlags() {
|
|||
return ops_flags;
|
||||
}
|
||||
|
||||
XlaCallModuleFlags* GetXlaCallModuleFlags() {
|
||||
absl::call_once(flags_init, &AllocateAndParseFlags);
|
||||
return call_module_flags;
|
||||
}
|
||||
|
||||
MlirCommonFlags* GetMlirCommonFlags() {
|
||||
absl::call_once(flags_init, &AllocateAndParseFlags);
|
||||
return mlir_flags;
|
||||
|
|
|
|||
|
|
@ -137,8 +137,9 @@ struct XlaOpsCommonFlags {
|
|||
}
|
||||
|
||||
bool IsEnabledInXlaLaunchForDevice(const DeviceType& device_type) const {
|
||||
return enabled_for_xla_launch_ &&
|
||||
xla_launch_allowed_devices_.contains(device_type.type_string());
|
||||
return enabled_for_all_ ||
|
||||
(enabled_for_xla_launch_ &&
|
||||
xla_launch_allowed_devices_.contains(device_type.type_string()));
|
||||
}
|
||||
|
||||
// Allow using Device API (PjRt) for `device_type` in the XlaCompileOnDemand
|
||||
|
|
@ -152,9 +153,26 @@ struct XlaOpsCommonFlags {
|
|||
|
||||
bool IsEnabledInXlaCompileOnDemandForDevice(
|
||||
const DeviceType& device_type) const {
|
||||
return enabled_for_compile_on_demand_ &&
|
||||
xla_compile_on_demand_allowed_devices_.contains(
|
||||
device_type.type_string());
|
||||
return enabled_for_all_ ||
|
||||
(enabled_for_compile_on_demand_ &&
|
||||
xla_compile_on_demand_allowed_devices_.contains(
|
||||
device_type.type_string()));
|
||||
}
|
||||
|
||||
// Allow using Device API (PjRt) for `device_type` in the XlaCompile and
|
||||
// XlaRun ops. Please note that `enabled_for_compile_and_run_` needs to be
|
||||
// true in addition to the `device_type` being allowed in order to use the
|
||||
// Device API for single device compilation and execution in the XlaCompile
|
||||
// and XlaRun ops.
|
||||
void AllowForDeviceInXlaCompileAndRun(const DeviceType& device_type) {
|
||||
xla_compile_and_run_allowed_devices_.insert(device_type.type_string());
|
||||
}
|
||||
|
||||
bool IsEnabledInXlaCompileAndRunForDevice(
|
||||
const DeviceType& device_type) const {
|
||||
return enabled_for_all_ || (enabled_for_compile_and_run_ &&
|
||||
xla_compile_and_run_allowed_devices_.contains(
|
||||
device_type.type_string()));
|
||||
}
|
||||
|
||||
// If true, uses Device API (PjRt) for single device compilation and
|
||||
|
|
@ -166,6 +184,16 @@ struct XlaOpsCommonFlags {
|
|||
// one in "on-demand" mode. Defaults to false.
|
||||
bool enabled_for_compile_on_demand_;
|
||||
|
||||
// If true, uses Device API (PjRt) for compilation and execution when
|
||||
// auto-clustering is enabled. Defaults to false.
|
||||
bool enabled_for_compile_and_run_;
|
||||
|
||||
// If true, uses Device API (PjRt) for compilation and execution everywhere
|
||||
// i.e. for functions marked for JIT compilation, for ops in "on-demand"
|
||||
// mode and autoclustering, no matter whether other flags are enabled or
|
||||
// not, and whether devices have been allowed or not. Defaults to false.
|
||||
bool enabled_for_all_;
|
||||
|
||||
private:
|
||||
// Devices for which using Device API (PjRt) is allowed in the XlaLaunch op.
|
||||
// This can only be modified programmatically.
|
||||
|
|
@ -173,9 +201,18 @@ struct XlaOpsCommonFlags {
|
|||
// Devices for which using Device API (PjRt) is allowed in the
|
||||
// XlaCompileOnDemand op. This can only be modified programmatically.
|
||||
absl::flat_hash_set<std::string> xla_compile_on_demand_allowed_devices_;
|
||||
// Devices for which using Device API (PjRt) is allowed in the
|
||||
// XlaCompile and XlaRun ops. This can only be modified programmatically.
|
||||
absl::flat_hash_set<std::string> xla_compile_and_run_allowed_devices_;
|
||||
} tf_xla_use_device_api;
|
||||
};
|
||||
|
||||
// Flags for the XlaCallModule kernel.
|
||||
struct XlaCallModuleFlags {
|
||||
// Used by XlaCallModuleOp to specify safety checks to disable.
|
||||
absl::flat_hash_set<std::string> disabled_checks;
|
||||
};
|
||||
|
||||
// Flags for the build_xla_ops pass.
|
||||
struct BuildXlaOpsPassFlags {
|
||||
// Enables lazy compilation for TF/XLA (only when auto-clustering) if true.
|
||||
|
|
@ -197,6 +234,10 @@ struct BuildXlaOpsPassFlags {
|
|||
// Disables all constant folding. The primary use for this is for testing to
|
||||
// guarantee that tests are run on XLA and not on TF's CPU implementation.
|
||||
bool tf_xla_disable_constant_folding;
|
||||
|
||||
// Disables full embedding pipelining when true. Instead, strict SparseCore
|
||||
// TensorCore sequencing will be used.
|
||||
bool tf_xla_disable_full_embedding_pipelining;
|
||||
};
|
||||
|
||||
// Flags for common MLIR configurations.
|
||||
|
|
@ -235,6 +276,7 @@ MarkForCompilationPassFlags* GetMarkForCompilationPassFlags();
|
|||
BuildXlaOpsPassFlags* GetBuildXlaOpsPassFlags();
|
||||
XlaDeviceFlags* GetXlaDeviceFlags();
|
||||
XlaOpsCommonFlags* GetXlaOpsCommonFlags();
|
||||
XlaCallModuleFlags* GetXlaCallModuleFlags();
|
||||
|
||||
MlirCommonFlags* GetMlirCommonFlags();
|
||||
|
||||
|
|
|
|||
|
|
@ -23,6 +23,8 @@ XLA_OPS_DEPS = [
|
|||
"//tensorflow/compiler/jit:variable_info_util",
|
||||
"//tensorflow/compiler/jit:xla_device_no_jit_rewrite_registration",
|
||||
"//tensorflow/compiler/jit:xla_cluster_util",
|
||||
"//tensorflow/compiler/jit:xla_host_recv_device_context",
|
||||
"//tensorflow/compiler/jit:xla_host_send_device_context",
|
||||
"//tensorflow/compiler/jit:xla_launch_util",
|
||||
"//tensorflow/compiler/tf2xla:common",
|
||||
"//tensorflow/compiler/tf2xla:tf2xla_util",
|
||||
|
|
@ -59,6 +61,8 @@ cc_library(
|
|||
"//tensorflow/compiler/jit:xla_compile_util",
|
||||
"//tensorflow/compiler/xla/pjrt:pjrt_client",
|
||||
"//tensorflow/core/platform:refcount",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -20,11 +20,15 @@ limitations under the License.
|
|||
#include <memory>
|
||||
#include <optional>
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include <string_view>
|
||||
#include <tuple>
|
||||
#include <utility>
|
||||
#include <variant>
|
||||
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "absl/types/optional.h"
|
||||
#include "tensorflow/compiler/jit/device_compilation_profiler.h"
|
||||
#include "tensorflow/compiler/jit/device_compiler.h"
|
||||
|
|
@ -35,6 +39,8 @@ limitations under the License.
|
|||
#include "tensorflow/compiler/jit/xla_activity_listener.h"
|
||||
#include "tensorflow/compiler/jit/xla_compile_util.h"
|
||||
#include "tensorflow/compiler/jit/xla_compiler_options_util.h"
|
||||
#include "tensorflow/compiler/jit/xla_host_recv_device_context.h"
|
||||
#include "tensorflow/compiler/jit/xla_host_send_device_context.h"
|
||||
#include "tensorflow/compiler/jit/xla_launch_util.h"
|
||||
#include "tensorflow/compiler/jit/xla_platform_info.h"
|
||||
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
|
||||
|
|
@ -92,10 +98,11 @@ auto* xla_launch_counter = monitoring::Counter<1>::New(
|
|||
// the initial values for the resource variables (and cannot snapshot them again
|
||||
// during execution) because otherwise we risk observing a different snapshot
|
||||
// with shapes different from what we compiled for.
|
||||
class XlaExecutableClosure {
|
||||
template <typename ExecutableType, typename ClientType>
|
||||
class ExecutableClosure {
|
||||
public:
|
||||
explicit XlaExecutableClosure(
|
||||
xla::LocalClient* client, xla::LocalExecutable* executable,
|
||||
explicit ExecutableClosure(
|
||||
ClientType* client, ExecutableType* executable,
|
||||
const XlaCompiler::CompilationResult* compilation_result,
|
||||
ResourceVarsSnapshot resource_var_snapshots, int num_constant_args)
|
||||
: client_(client),
|
||||
|
|
@ -104,11 +111,11 @@ class XlaExecutableClosure {
|
|||
resource_var_snapshots_(std::move(resource_var_snapshots)),
|
||||
num_constant_args_(num_constant_args) {}
|
||||
|
||||
XlaExecutableClosure(XlaExecutableClosure&&) = default;
|
||||
XlaExecutableClosure& operator=(XlaExecutableClosure&&) = default;
|
||||
ExecutableClosure(ExecutableClosure&&) = default;
|
||||
ExecutableClosure& operator=(ExecutableClosure&&) = default;
|
||||
|
||||
xla::LocalClient* client() const { return client_; }
|
||||
xla::LocalExecutable* executable() const { return executable_; }
|
||||
ClientType* client() const { return client_; }
|
||||
ExecutableType* executable() const { return executable_; }
|
||||
const XlaCompiler::CompilationResult* compilation_result() const {
|
||||
return compilation_result_;
|
||||
}
|
||||
|
|
@ -118,24 +125,25 @@ class XlaExecutableClosure {
|
|||
int num_constant_args() const { return num_constant_args_; }
|
||||
|
||||
private:
|
||||
xla::LocalClient* client_;
|
||||
xla::LocalExecutable* executable_;
|
||||
ClientType* client_;
|
||||
ExecutableType* executable_;
|
||||
const XlaCompiler::CompilationResult* compilation_result_;
|
||||
ResourceVarsSnapshot resource_var_snapshots_;
|
||||
int num_constant_args_;
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(XlaExecutableClosure);
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(ExecutableClosure);
|
||||
};
|
||||
|
||||
// This maintains a mapping from a globally unique ID to XlaExecutableClosure
|
||||
// This maintains a mapping from a globally unique ID to ExecutableClosure
|
||||
// instances.
|
||||
class XlaExecutableClosureStore {
|
||||
template <typename ExecutableType, typename ClientType>
|
||||
class ExecutableClosureStore {
|
||||
public:
|
||||
XlaExecutableClosureStore() : key_counter_(0) {}
|
||||
ExecutableClosureStore() : key_counter_(0) {}
|
||||
|
||||
using KeyT = string;
|
||||
|
||||
KeyT Produce(XlaExecutableClosure result) {
|
||||
KeyT Produce(ExecutableClosure<ExecutableType, ClientType> result) {
|
||||
mutex_lock l(mutex_);
|
||||
KeyT key = absl::StrCat(key_counter_++);
|
||||
bool insert_successful = closures_.emplace(key, std::move(result)).second;
|
||||
|
|
@ -144,29 +152,38 @@ class XlaExecutableClosureStore {
|
|||
return key;
|
||||
}
|
||||
|
||||
XlaExecutableClosure Consume(const KeyT& key) {
|
||||
ExecutableClosure<ExecutableType, ClientType> Consume(const KeyT& key) {
|
||||
mutex_lock l(mutex_);
|
||||
auto it = closures_.find(key);
|
||||
DCHECK(it != closures_.end());
|
||||
XlaExecutableClosure value = std::move(it->second);
|
||||
ExecutableClosure<ExecutableType, ClientType> value = std::move(it->second);
|
||||
closures_.erase(it);
|
||||
return value;
|
||||
}
|
||||
|
||||
static XlaExecutableClosureStore* Global() {
|
||||
static XlaExecutableClosureStore* instance = new XlaExecutableClosureStore;
|
||||
static ExecutableClosureStore* Global() {
|
||||
static ExecutableClosureStore* instance = new ExecutableClosureStore;
|
||||
return instance;
|
||||
}
|
||||
|
||||
private:
|
||||
mutex mutex_;
|
||||
int64_t key_counter_ TF_GUARDED_BY(mutex_);
|
||||
absl::flat_hash_map<KeyT, XlaExecutableClosure> closures_
|
||||
TF_GUARDED_BY(mutex_);
|
||||
absl::flat_hash_map<KeyT, ExecutableClosure<ExecutableType, ClientType>>
|
||||
closures_ TF_GUARDED_BY(mutex_);
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(XlaExecutableClosureStore);
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(ExecutableClosureStore);
|
||||
};
|
||||
|
||||
using XlaExecutableClosure =
|
||||
ExecutableClosure<xla::LocalExecutable, xla::LocalClient>;
|
||||
using XlaExecutableClosureStore =
|
||||
ExecutableClosureStore<xla::LocalExecutable, xla::LocalClient>;
|
||||
using PjRtExecutableClosure =
|
||||
ExecutableClosure<xla::PjRtLoadedExecutable, xla::PjRtClient>;
|
||||
using PjRtExecutableClosureStore =
|
||||
ExecutableClosureStore<xla::PjRtLoadedExecutable, xla::PjRtClient>;
|
||||
|
||||
se::Stream* GetStream(OpKernelContext* ctx) {
|
||||
return ctx->op_device_context() ? ctx->op_device_context()->stream()
|
||||
: nullptr;
|
||||
|
|
@ -185,6 +202,111 @@ XlaComputationLaunchContext GetLaunchContext(
|
|||
return launch_context;
|
||||
}
|
||||
|
||||
Status GetTaskName(const std::string_view device_name, std::string* task_name) {
|
||||
string ignored;
|
||||
if (!DeviceNameUtils::SplitDeviceName(device_name, task_name, &ignored)) {
|
||||
return errors::InvalidArgument("Unable to parse device name: ",
|
||||
device_name);
|
||||
}
|
||||
|
||||
return OkStatus();
|
||||
}
|
||||
|
||||
// Provide SendDeviceMemoryFunction for XLA host callbacks. This callback
|
||||
// handles transferring from device to host.
|
||||
xla::SendDeviceMemoryFunction GetSendDeviceMemoryFunction(
|
||||
OpKernelContext* ctx) {
|
||||
return
|
||||
[ctx](int64_t channel_id, se::Stream* stream, const xla::Shape& shape,
|
||||
const se::DeviceMemoryBase& device_memory_base,
|
||||
const absl::flat_hash_map<std::string, std::string>& frontend_attrs)
|
||||
-> StatusOr<tsl::AsyncValueRef<se::Event>> {
|
||||
auto iter = frontend_attrs.find("_xla_host_transfer_rendezvous");
|
||||
|
||||
// Generate the Rendezvous key.
|
||||
const std::string& rendezvous_key_base = iter->second;
|
||||
const std::string& src_device = ctx->device()->name();
|
||||
|
||||
std::string task_prefix;
|
||||
TF_RETURN_IF_ERROR(GetTaskName(src_device, &task_prefix));
|
||||
const std::string dst_device =
|
||||
absl::StrCat(task_prefix, "/device:CPU:0");
|
||||
const std::string& rendezvous_key =
|
||||
Rendezvous::CreateKey(src_device, /*src_incarnation=*/1, dst_device,
|
||||
rendezvous_key_base, FrameAndIter(0, 0));
|
||||
VLOG(2) << "Rendezvous Key for receiving at host: " << rendezvous_key;
|
||||
|
||||
RendezvousInterface::ParsedKey parsed_key;
|
||||
TF_RETURN_IF_ERROR(Rendezvous::ParseKey(rendezvous_key, &parsed_key));
|
||||
|
||||
tsl::AsyncValueRef<se::Event> done_event =
|
||||
tsl::MakeConstructedAsyncValueRef<se::Event>(stream->parent());
|
||||
if (!done_event->Init()) {
|
||||
return errors::Internal(
|
||||
"Failed to initialize done event (channel_id=%d)", channel_id);
|
||||
}
|
||||
|
||||
Rendezvous::Args args;
|
||||
// Rendezvous::Args owns the device context pointer.
|
||||
args.device_context = new XlaHostRecvDeviceContext(
|
||||
stream, device_memory_base, shape, done_event);
|
||||
|
||||
Tensor host_tensor;
|
||||
TF_RETURN_IF_ERROR(
|
||||
ctx->rendezvous()->Send(parsed_key, args, host_tensor, false));
|
||||
|
||||
return std::move(done_event);
|
||||
};
|
||||
}
|
||||
|
||||
// Provide RecvDeviceMemoryFunction for XLA host callbacks. This callback
|
||||
// handles transferring from host to device.
|
||||
xla::RecvDeviceMemoryFunction GetRecvDeviceMemoryFunction(
|
||||
OpKernelContext* ctx) {
|
||||
return
|
||||
[ctx](int64_t channel_id, se::Stream* stream, const xla::Shape& shape,
|
||||
se::DeviceMemoryBase* device_memory_base,
|
||||
const absl::flat_hash_map<std::string, std::string>& frontend_attrs)
|
||||
-> StatusOr<tsl::AsyncValueRef<se::Event>> {
|
||||
auto iter = frontend_attrs.find("_xla_host_transfer_rendezvous");
|
||||
|
||||
// Generate the Rendezvous key.
|
||||
const std::string& rendezvous_key_base = iter->second;
|
||||
const std::string& dst_device = ctx->device()->name();
|
||||
|
||||
std::string task_prefix;
|
||||
TF_RETURN_IF_ERROR(GetTaskName(dst_device, &task_prefix));
|
||||
const std::string src_device =
|
||||
absl::StrCat(task_prefix, "/device:CPU:0");
|
||||
const std::string& rendezvous_key =
|
||||
Rendezvous::CreateKey(src_device, /*src_incarnation=*/1, dst_device,
|
||||
rendezvous_key_base, FrameAndIter(0, 0));
|
||||
VLOG(2) << "Rendezvous Key for sending from host: " << rendezvous_key;
|
||||
|
||||
RendezvousInterface::ParsedKey parsed_key;
|
||||
TF_RETURN_IF_ERROR(Rendezvous::ParseKey(rendezvous_key, &parsed_key));
|
||||
|
||||
tsl::AsyncValueRef<se::Event> done_event =
|
||||
tsl::MakeConstructedAsyncValueRef<se::Event>(stream->parent());
|
||||
if (!done_event->Init()) {
|
||||
return errors::Internal(
|
||||
"Failed to initialize done event (channel_id=%d)", channel_id);
|
||||
}
|
||||
|
||||
Rendezvous::Args args;
|
||||
// Rendezvous::Args owns the device context pointer.
|
||||
args.device_context = new XlaHostSendDeviceContext(
|
||||
stream, device_memory_base, shape, done_event);
|
||||
|
||||
Tensor device_tensor;
|
||||
bool is_dead;
|
||||
TF_RETURN_IF_ERROR(ctx->rendezvous()->Recv(
|
||||
parsed_key, args, &device_tensor, /*is_dead=*/&is_dead));
|
||||
|
||||
return std::move(done_event);
|
||||
};
|
||||
}
|
||||
|
||||
StatusOr<xla::ExecutionOutput> RunExecutable(
|
||||
const XlaPlatformInfo& platform_info,
|
||||
const XlaComputationLaunchContext& launch_context,
|
||||
|
|
@ -200,6 +322,15 @@ StatusOr<xla::ExecutionOutput> RunExecutable(
|
|||
run_options.set_allocator(allocator);
|
||||
run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device());
|
||||
run_options.set_rng_seed(GetXLARandomSeed());
|
||||
|
||||
// Host callbacks used for HLO send/recv.
|
||||
xla::SendDeviceMemoryFunction send_function =
|
||||
GetSendDeviceMemoryFunction(ctx);
|
||||
run_options.set_send_device_memory_function(&send_function);
|
||||
xla::RecvDeviceMemoryFunction recv_function =
|
||||
GetRecvDeviceMemoryFunction(ctx);
|
||||
run_options.set_recv_device_memory_function(&recv_function);
|
||||
|
||||
StatusOr<xla::ExecutionOutput> execution_output;
|
||||
bool run_synchronous =
|
||||
!stream || platform_info.platform_id() == se::host::kHostPlatformId;
|
||||
|
|
@ -263,7 +394,7 @@ Status CompileToLocalExecutable(
|
|||
// in the ResourceMgr.
|
||||
ResourceMgr* rm = ctx->resource_manager();
|
||||
if (!rm) {
|
||||
return errors::Internal("No resource manager.");
|
||||
return absl::InternalError("No resource manager.");
|
||||
}
|
||||
|
||||
XlaDeviceCompiler* xla_device_compiler;
|
||||
|
|
@ -312,23 +443,13 @@ Status CompileToPjRtLoadedExecutable(
|
|||
// in the ResourceMgr.
|
||||
ResourceMgr* rm = ctx.resource_manager();
|
||||
if (!rm) {
|
||||
return errors::Internal("No resource manager.");
|
||||
return absl::InternalError("No resource manager.");
|
||||
}
|
||||
|
||||
PjRtDeviceCompiler* pjrt_device_compiler;
|
||||
TF_RETURN_IF_ERROR(rm->LookupOrCreate<PjRtDeviceCompiler>(
|
||||
rm->default_container(), "pjrt_device_compiler", &pjrt_device_compiler,
|
||||
[&](PjRtDeviceCompiler** pjrt_device_compiler) {
|
||||
return BuildPjRtDeviceCompiler(platform_info, ctx.function_library(),
|
||||
pjrt_device_compiler);
|
||||
}));
|
||||
DeviceCompilationProfiler* profiler;
|
||||
TF_RETURN_IF_ERROR(rm->LookupOrCreate<DeviceCompilationProfiler>(
|
||||
rm->default_container(), "pjrt_device_compilation_profiler", &profiler,
|
||||
[](DeviceCompilationProfiler** profiler) {
|
||||
*profiler = new DeviceCompilationProfiler();
|
||||
return OkStatus();
|
||||
}));
|
||||
TF_RETURN_IF_ERROR(GetOrCreatePjRtDeviceCompilerAndProfiler(
|
||||
platform_info, ctx.function_library(), &pjrt_device_compiler, &profiler));
|
||||
// Hold the reference to the PJRT device compiler and profiler during
|
||||
// evaluation. (We could probably free them sooner because the ResourceMgr
|
||||
// will retain references, but this is more obviously correct.)
|
||||
|
|
@ -337,8 +458,9 @@ Status CompileToPjRtLoadedExecutable(
|
|||
|
||||
*client = pjrt_device_compiler->client();
|
||||
|
||||
XlaCompiler::Options options = GenerateCompilerOptionsForPjRt(
|
||||
*ctx.function_library(), ctx.device(), platform_info);
|
||||
XlaCompiler::Options options =
|
||||
GenerateCompilerOptionsForPjRt(*ctx.function_library(), ctx.device(),
|
||||
platform_info, pjrt_device_compiler);
|
||||
|
||||
XlaCompiler::CompileOptions compile_options =
|
||||
GenerateCompileOptions(has_ref_vars, may_alias_resource_update);
|
||||
|
|
@ -474,19 +596,23 @@ void XlaLocalLaunchBase::ComputeAsync(OpKernelContext* ctx, DoneCallback done) {
|
|||
compilation_result, done, inputs,
|
||||
resources = resources_]() {
|
||||
auto platform_info = XlaPlatformInfoFromDevice(ctx->device());
|
||||
std::vector<VariableInfo> variable_infos;
|
||||
OP_REQUIRES_OK_ASYNC(
|
||||
ctx,
|
||||
GetUpdatedVariables(ctx, inputs, resources, *compilation_result,
|
||||
&variable_infos),
|
||||
done);
|
||||
OP_REQUIRES_OK_ASYNC(ctx, LockVariables(absl::MakeSpan(variable_infos)),
|
||||
done);
|
||||
OP_REQUIRES_OK_ASYNC(
|
||||
ctx,
|
||||
RunPjRtExecutable(*pjrt_client, inputs, variable_infos,
|
||||
*compilation_result, pjrt_executable, ctx),
|
||||
done);
|
||||
// Separate scope so that VariableInfo locks are released before done() is
|
||||
// called.
|
||||
{
|
||||
std::vector<VariableInfo> variable_infos;
|
||||
OP_REQUIRES_OK_ASYNC(
|
||||
ctx,
|
||||
GetUpdatedVariables(ctx, inputs, resources, *compilation_result,
|
||||
&variable_infos),
|
||||
done);
|
||||
OP_REQUIRES_OK_ASYNC(ctx, LockVariables(absl::MakeSpan(variable_infos)),
|
||||
done);
|
||||
OP_REQUIRES_OK_ASYNC(
|
||||
ctx,
|
||||
RunPjRtExecutable(*pjrt_client, inputs, variable_infos,
|
||||
*compilation_result, pjrt_executable, ctx),
|
||||
done);
|
||||
}
|
||||
VLOG(2) << "Done executing with PJRT.";
|
||||
done();
|
||||
};
|
||||
|
|
@ -505,65 +631,69 @@ void XlaLocalLaunchBase::ComputeAsync(OpKernelContext* ctx, DoneCallback done) {
|
|||
// Continuation of the execution, may be run in a different thread.
|
||||
auto run_xla_cluster = [ctx, client, executable, compilation_result, done,
|
||||
inputs, resources = resources_]() {
|
||||
auto platform_info = XlaPlatformInfoFromDevice(ctx->device());
|
||||
std::vector<VariableInfo> variable_infos;
|
||||
OP_REQUIRES_OK_ASYNC(
|
||||
ctx,
|
||||
GetUpdatedVariables(ctx, inputs, resources, *compilation_result,
|
||||
&variable_infos),
|
||||
done);
|
||||
OP_REQUIRES_OK_ASYNC(ctx, LockVariables(absl::MakeSpan(variable_infos)),
|
||||
done);
|
||||
std::map<int, const Tensor*> resource_var_ptrs;
|
||||
for (int i = 0; i < resources.size(); i++) {
|
||||
resource_var_ptrs[resources[i]] = variable_infos[i].var()->tensor();
|
||||
}
|
||||
|
||||
std::shared_ptr<se::DeviceMemoryAllocator> allocator =
|
||||
GetAllocator(ctx->device(), GetStream(ctx), platform_info);
|
||||
XlaComputationLaunchContext launch_context =
|
||||
GetLaunchContext(platform_info, ctx, client, allocator.get());
|
||||
|
||||
const xla::HloInputOutputAliasConfig& input_output_alias =
|
||||
executable->executable()->module().input_output_alias_config();
|
||||
StatusOr<std::vector<xla::ExecutionInput>> execution_inputs =
|
||||
launch_context.PopulateInputs(
|
||||
ctx, compilation_result, resource_var_ptrs,
|
||||
/*missing_ctx_input_prefix=*/0, input_output_alias);
|
||||
OP_REQUIRES_OK_ASYNC(ctx, execution_inputs.status(), done);
|
||||
|
||||
xla::gpu::GpuExecutableRunOptions gpu_options;
|
||||
xla::DeviceAssignment device_assignment;
|
||||
xla::ExecutableRunOptions run_options;
|
||||
if (compilation_result->collective_info.has_value()) {
|
||||
// Separate scope so that VariableInfo locks are released before done is
|
||||
// called.
|
||||
{
|
||||
auto platform_info = XlaPlatformInfoFromDevice(ctx->device());
|
||||
std::vector<VariableInfo> variable_infos;
|
||||
OP_REQUIRES_OK_ASYNC(
|
||||
ctx,
|
||||
ResolveDeviceAssignment(ctx, *compilation_result->collective_info,
|
||||
run_options, device_assignment, gpu_options),
|
||||
GetUpdatedVariables(ctx, inputs, resources, *compilation_result,
|
||||
&variable_infos),
|
||||
done);
|
||||
OP_REQUIRES_OK_ASYNC(ctx, LockVariables(absl::MakeSpan(variable_infos)),
|
||||
done);
|
||||
std::map<int, const Tensor*> resource_var_ptrs;
|
||||
for (int i = 0; i < resources.size(); i++) {
|
||||
resource_var_ptrs[resources[i]] = variable_infos[i].var()->tensor();
|
||||
}
|
||||
|
||||
std::shared_ptr<se::DeviceMemoryAllocator> allocator =
|
||||
GetAllocator(ctx->device(), GetStream(ctx), platform_info);
|
||||
XlaComputationLaunchContext launch_context =
|
||||
GetLaunchContext(platform_info, ctx, client, allocator.get());
|
||||
|
||||
const xla::HloInputOutputAliasConfig& input_output_alias =
|
||||
executable->executable()->module().input_output_alias_config();
|
||||
StatusOr<std::vector<xla::ExecutionInput>> execution_inputs =
|
||||
launch_context.PopulateInputs(
|
||||
ctx, compilation_result, resource_var_ptrs,
|
||||
/*missing_ctx_input_prefix=*/0, input_output_alias);
|
||||
OP_REQUIRES_OK_ASYNC(ctx, execution_inputs.status(), done);
|
||||
|
||||
xla::gpu::GpuExecutableRunOptions gpu_options;
|
||||
xla::DeviceAssignment device_assignment;
|
||||
xla::ExecutableRunOptions run_options;
|
||||
if (compilation_result->collective_info.has_value()) {
|
||||
OP_REQUIRES_OK_ASYNC(ctx,
|
||||
ResolveDeviceAssignment(
|
||||
ctx, *compilation_result->collective_info,
|
||||
run_options, device_assignment, gpu_options),
|
||||
done);
|
||||
}
|
||||
|
||||
// Hardcode run id to always be zero: TF distributed strategy
|
||||
// differentiates between subsequent runs using dependency edges. This
|
||||
// is safe, as only TF dist-strat can produce distributed ops, and we
|
||||
// can rely on TF dist-strat invariants.
|
||||
xla::RunId run_id(0);
|
||||
run_options.set_run_id(run_id);
|
||||
|
||||
StatusOr<xla::ExecutionOutput> execution_output = RunExecutable(
|
||||
platform_info, launch_context, std::move(*execution_inputs),
|
||||
run_options, executable, ctx, allocator.get());
|
||||
OP_REQUIRES_ASYNC(ctx, execution_output.ok(), execution_output.status(),
|
||||
done);
|
||||
|
||||
OP_REQUIRES_OK_ASYNC(
|
||||
ctx,
|
||||
launch_context.PopulateOutputs(
|
||||
ctx, compilation_result, execution_output->ConsumeResult(),
|
||||
/*missing_ctx_input_prefix=*/0, absl::MakeSpan(variable_infos),
|
||||
input_output_alias, resource_var_ptrs),
|
||||
done);
|
||||
VLOG(1) << "Done";
|
||||
}
|
||||
|
||||
// Hardcode run id to always be zero: TF distributed strategy
|
||||
// differentiates between subsequent runs using dependency edges. This
|
||||
// is safe, as only TF dist-strat can produce distributed ops, and we
|
||||
// can rely on TF dist-strat invariants.
|
||||
xla::RunId run_id(0);
|
||||
run_options.set_run_id(run_id);
|
||||
|
||||
StatusOr<xla::ExecutionOutput> execution_output = RunExecutable(
|
||||
platform_info, launch_context, std::move(*execution_inputs),
|
||||
run_options, executable, ctx, allocator.get());
|
||||
OP_REQUIRES_ASYNC(ctx, execution_output.ok(), execution_output.status(),
|
||||
done);
|
||||
|
||||
OP_REQUIRES_OK_ASYNC(
|
||||
ctx,
|
||||
launch_context.PopulateOutputs(
|
||||
ctx, compilation_result, execution_output->ConsumeResult(),
|
||||
/*missing_ctx_input_prefix=*/0, absl::MakeSpan(variable_infos),
|
||||
input_output_alias, resource_var_ptrs),
|
||||
done);
|
||||
VLOG(1) << "Done";
|
||||
done();
|
||||
};
|
||||
|
||||
|
|
@ -658,9 +788,11 @@ XlaCompileOp::XlaCompileOp(OpKernelConstruction* ctx)
|
|||
void XlaCompileOp::Compute(OpKernelContext* ctx) {
|
||||
VLOG(3) << "XlaCompileOp " << def().name()
|
||||
<< (must_compile_ ? "(must-compile)" : "");
|
||||
xla::LocalClient* client;
|
||||
const XlaCompiler::CompilationResult* kernel;
|
||||
xla::LocalExecutable* executable;
|
||||
const XlaCompiler::CompilationResult* kernel = nullptr;
|
||||
xla::LocalClient* client = nullptr;
|
||||
xla::LocalExecutable* executable = nullptr;
|
||||
xla::PjRtClient* pjrt_client = nullptr;
|
||||
xla::PjRtLoadedExecutable* pjrt_executable = nullptr;
|
||||
ResourceVarsSnapshot variables_snapshot;
|
||||
|
||||
std::vector<const Tensor*> inputs = InputsFromContext(ctx);
|
||||
|
|
@ -678,6 +810,11 @@ void XlaCompileOp::Compute(OpKernelContext* ctx) {
|
|||
: DeviceCompileMode::kLazy;
|
||||
}();
|
||||
|
||||
bool use_pjrt =
|
||||
GetXlaOpsCommonFlags()
|
||||
->tf_xla_use_device_api.IsEnabledInXlaCompileAndRunForDevice(
|
||||
platform_info_.device_type());
|
||||
|
||||
if (GetXlaOpsCommonFlags()->tf_xla_always_defer_compilation ||
|
||||
cannot_compile_cluster) {
|
||||
executable = nullptr;
|
||||
|
|
@ -691,22 +828,33 @@ void XlaCompileOp::Compute(OpKernelContext* ctx) {
|
|||
|
||||
// Do not alias resource updates as locking variables in XlaCompile and
|
||||
// unlocking them in XlaRun may lead to deadlocks.
|
||||
const Status status = CompileToLocalExecutable(
|
||||
ctx, function_, has_ref_vars_, platform_info_, args, compile_mode,
|
||||
/*may_alias_resource_update=*/false, &client, &kernel, &executable);
|
||||
Status status;
|
||||
if (use_pjrt) {
|
||||
VLOG(2) << "Using PJRT for compilation. Function name: "
|
||||
<< function_.name();
|
||||
status = CompileToPjRtLoadedExecutable(
|
||||
*ctx, platform_info_, function_, args, compile_mode, has_ref_vars_,
|
||||
/*may_alias_resource_update=*/false, &kernel, &pjrt_client,
|
||||
&pjrt_executable);
|
||||
} else {
|
||||
status = CompileToLocalExecutable(
|
||||
ctx, function_, has_ref_vars_, platform_info_, args, compile_mode,
|
||||
/*may_alias_resource_update=*/false, &client, &kernel, &executable);
|
||||
}
|
||||
if (compile_mode != DeviceCompileMode::kLazy ||
|
||||
status.code() != error::UNIMPLEMENTED) {
|
||||
OP_REQUIRES_OK(ctx, status);
|
||||
}
|
||||
|
||||
if (status.code() == error::UNIMPLEMENTED) {
|
||||
LOG(WARNING) << "Compilation failed:" << status.ToString()
|
||||
LOG(WARNING) << "Compilation failed:" << status
|
||||
<< ". Falling back to TF function call.";
|
||||
|
||||
BroadcastOptimizationRemark(
|
||||
XlaOptimizationRemark::UNIMPLEMENTED_OPERATION, status.ToString())
|
||||
.IgnoreError();
|
||||
executable = nullptr;
|
||||
pjrt_executable = nullptr;
|
||||
mutex_lock guard(cannot_compile_cluster_mu_);
|
||||
cannot_compile_cluster_ = true;
|
||||
}
|
||||
|
|
@ -718,28 +866,36 @@ void XlaCompileOp::Compute(OpKernelContext* ctx) {
|
|||
Allocator* cpu_allocator = ctx->device()->GetAllocator(host_alloc_attrs);
|
||||
|
||||
// Async compilation returns nullptr executable without an error.
|
||||
if (!executable) {
|
||||
if (!executable && !pjrt_executable) {
|
||||
DCHECK(!must_compile_);
|
||||
Tensor compilation_key(cpu_allocator, DT_STRING, TensorShape({}));
|
||||
|
||||
Tensor compilation_successful(cpu_allocator, DT_BOOL, TensorShape({}));
|
||||
compilation_successful.scalar<bool>()() = false;
|
||||
ctx->set_output(0, Tensor(cpu_allocator, DT_STRING, TensorShape({})));
|
||||
ctx->set_output(0, compilation_key);
|
||||
ctx->set_output(1, compilation_successful);
|
||||
return;
|
||||
}
|
||||
|
||||
// Each execution of an XlaCompile op creates a new XlaExecutableClosure, even
|
||||
// Each execution of an XlaCompile op creates a new ExecutableClosure, even
|
||||
// if it didn't have to compile the cluster because of a compilation-cache
|
||||
// hit. This is because we at least need new snapshots of the resource
|
||||
// variables.
|
||||
XlaExecutableClosureStore::KeyT key =
|
||||
XlaExecutableClosureStore::Global()->Produce(XlaExecutableClosure(
|
||||
client, executable, kernel, std::move(variables_snapshot),
|
||||
constants_.size()));
|
||||
|
||||
Tensor compilation_key(cpu_allocator, DT_STRING, TensorShape({}));
|
||||
compilation_key.flat<tstring>()(0) = key;
|
||||
if (use_pjrt) {
|
||||
PjRtExecutableClosureStore::KeyT key =
|
||||
PjRtExecutableClosureStore::Global()->Produce(PjRtExecutableClosure(
|
||||
pjrt_client, pjrt_executable, kernel, std::move(variables_snapshot),
|
||||
constants_.size()));
|
||||
compilation_key.flat<tstring>()(0) = key;
|
||||
VLOG(2) << "Compiled with PJRT. compilation_key: " << key;
|
||||
} else {
|
||||
XlaExecutableClosureStore::KeyT key =
|
||||
XlaExecutableClosureStore::Global()->Produce(XlaExecutableClosure(
|
||||
client, executable, kernel, std::move(variables_snapshot),
|
||||
constants_.size()));
|
||||
compilation_key.flat<tstring>()(0) = key;
|
||||
VLOG(2) << "Compiled with XLA. compilation_key: " << key;
|
||||
}
|
||||
|
||||
Tensor compilation_successful(cpu_allocator, DT_BOOL, TensorShape({}));
|
||||
compilation_successful.flat<bool>()(0) = true;
|
||||
|
|
|
|||
|
|
@ -24,5 +24,5 @@ py_library(
|
|||
name = "xla_ops_grad",
|
||||
srcs = ["xla_ops_grad.py"],
|
||||
srcs_version = "PY3",
|
||||
deps = ["//tensorflow/python:framework_ops"],
|
||||
deps = ["//tensorflow/python/framework:ops"],
|
||||
)
|
||||
|
|
|
|||
|
|
@ -77,8 +77,8 @@ cc_library(
|
|||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core:ops",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core/platform:path",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@ limitations under the License.
|
|||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "tensorflow/compiler/jit/xla_activity_listener.h"
|
||||
#include "tensorflow/core/graph/graph_def_builder.h"
|
||||
|
|
@ -51,7 +52,7 @@ class JitCompilationListener : public XlaActivityListener {
|
|||
bool expect_persistent_cache_use) {
|
||||
for (const auto& activity : activity_history_) {
|
||||
if (activity.used_persistent_cache() != expect_persistent_cache_use) {
|
||||
return errors::FailedPrecondition("Unexpected listener history.");
|
||||
return absl::FailedPreconditionError("Unexpected listener history.");
|
||||
}
|
||||
}
|
||||
return OkStatus();
|
||||
|
|
|
|||
|
|
@ -23,14 +23,15 @@ Status TfGraphToHloCompiler::Compile(const XlaCompiler::CompileOptions& options,
|
|||
const NameAttrList& function,
|
||||
absl::Span<const XlaArgument> args,
|
||||
XlaCompilationResult* result) {
|
||||
return xla_compiler_.CompileFunction(options, function, args, result);
|
||||
return ADD_SOURCE_LOCATION(
|
||||
xla_compiler_.CompileFunction(options, function, args, result));
|
||||
}
|
||||
|
||||
Status TfGraphToHloCompiler::CompileSingleOp(
|
||||
const XlaCompiler::CompileOptions& options, const OpKernelContext* ctx,
|
||||
absl::Span<const XlaArgument> args, XlaCompilationResult* result) {
|
||||
return xla_compiler_.CompileSingleOp(
|
||||
options, XlaCompiler::SingleOpCompileArgument(*ctx), args, result);
|
||||
return ADD_SOURCE_LOCATION(xla_compiler_.CompileSingleOp(
|
||||
options, XlaCompiler::SingleOpCompileArgument(*ctx), args, result));
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@ limitations under the License.
|
|||
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
|
|
@ -43,6 +44,7 @@ limitations under the License.
|
|||
#include "tensorflow/compiler/xla/executable_run_options.h"
|
||||
#include "tensorflow/compiler/xla/hlo/ir/hlo_input_output_alias_config.h"
|
||||
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
|
||||
#include "tensorflow/compiler/xla/pjrt/tf_pjrt_client.h"
|
||||
#include "tensorflow/compiler/xla/service/executable.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h"
|
||||
#include "tensorflow/core/framework/function.h"
|
||||
|
|
@ -53,6 +55,7 @@ limitations under the License.
|
|||
#include "tensorflow/core/lib/core/refcount.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/platform/statusor.h"
|
||||
#include "tensorflow/core/tfrt/common/pjrt_util.h"
|
||||
#include "tensorflow/tsl/platform/errors.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
|
@ -169,28 +172,12 @@ Status XlaCompileOnDemandOp::Compile(
|
|||
DeviceCompilationProfiler** profiler,
|
||||
const XlaCompiler::CompilationResult** result,
|
||||
xla::PjRtLoadedExecutable** executable) {
|
||||
// We store information about the JIT-compiled XLA computation
|
||||
// in the ResourceMgr.
|
||||
ResourceMgr* rm = ctx->resource_manager();
|
||||
if (!rm) {
|
||||
return errors::Internal("No resource manager.");
|
||||
}
|
||||
TF_RETURN_IF_ERROR(GetOrCreatePjRtDeviceCompilerAndProfiler(
|
||||
platform_info_, ctx->function_library(), pjrt_device_compiler, profiler));
|
||||
|
||||
TF_RETURN_IF_ERROR(rm->LookupOrCreate<PjRtDeviceCompiler>(
|
||||
rm->default_container(), "pjrt_device_compiler", pjrt_device_compiler,
|
||||
[&](PjRtDeviceCompiler** pjrt_device_compiler) {
|
||||
return BuildPjRtDeviceCompiler(platform_info_, ctx->function_library(),
|
||||
pjrt_device_compiler);
|
||||
}));
|
||||
TF_RETURN_IF_ERROR(rm->LookupOrCreate<DeviceCompilationProfiler>(
|
||||
rm->default_container(), "pjrt_device_compilation_profiler", profiler,
|
||||
[](DeviceCompilationProfiler** profiler) {
|
||||
*profiler = new DeviceCompilationProfiler();
|
||||
return OkStatus();
|
||||
}));
|
||||
|
||||
XlaCompiler::Options options = GenerateCompilerOptionsForPjRt(
|
||||
*(ctx->function_library()), ctx->device(), platform_info_);
|
||||
XlaCompiler::Options options =
|
||||
GenerateCompilerOptionsForPjRt(*(ctx->function_library()), ctx->device(),
|
||||
platform_info_, *pjrt_device_compiler);
|
||||
// No detailed logging for on demand op.
|
||||
options.detailed_logging = false;
|
||||
XlaCompiler::CompileOptions compile_options = GetCompileOptions(true);
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ limitations under the License.
|
|||
#include "tensorflow/compiler/jit/xla_compile_util.h"
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/compiler/jit/flags.h"
|
||||
|
|
@ -24,6 +25,11 @@ limitations under the License.
|
|||
#include "tensorflow/core/util/determinism.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
constexpr const char* kPjRtDeviceCompilerResourceName = "pjrt_device_compiler";
|
||||
constexpr const char* kPjRtDeviceCompilationProfilerResourceName =
|
||||
"pjrt_device_compilation_profiler";
|
||||
} // namespace
|
||||
|
||||
StatusOr<std::unique_ptr<Graph>> CreateSingleOpGraph(
|
||||
const NodeDef& node_def, absl::Span<const XlaArgument> args,
|
||||
|
|
@ -69,7 +75,18 @@ StatusOr<std::unique_ptr<Graph>> CreateSingleOpGraph(
|
|||
bool UsePjRtForSingleDeviceCompilation(const DeviceType& device_type) {
|
||||
const auto& rollout_config = GetXlaOpsCommonFlags()->tf_xla_use_device_api;
|
||||
return rollout_config.IsEnabledInXlaLaunchForDevice(device_type) ||
|
||||
rollout_config.IsEnabledInXlaCompileOnDemandForDevice(device_type);
|
||||
rollout_config.IsEnabledInXlaCompileOnDemandForDevice(device_type) ||
|
||||
rollout_config.IsEnabledInXlaCompileAndRunForDevice(device_type);
|
||||
}
|
||||
|
||||
std::string GetPjRtDeviceCompilerResourceName(const DeviceType& device_type) {
|
||||
return absl::StrCat(kPjRtDeviceCompilerResourceName, "_",
|
||||
device_type.type_string());
|
||||
}
|
||||
|
||||
std::string GetPjRtDeviceCompilationProfilerResourceName(
|
||||
const DeviceType& device_type) {
|
||||
return absl::StrCat(kPjRtDeviceCompilationProfilerResourceName, "_",
|
||||
device_type.type_string());
|
||||
}
|
||||
} // namespace tensorflow
|
||||
|
|
|
|||
|
|
@ -17,6 +17,7 @@ limitations under the License.
|
|||
#define TENSORFLOW_COMPILER_JIT_XLA_COMPILE_UTIL_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include "tensorflow/compiler/tf2xla/xla_argument.h"
|
||||
#include "tensorflow/core/graph/graph.h"
|
||||
|
|
@ -47,6 +48,14 @@ StatusOr<std::unique_ptr<Graph>> CreateSingleOpGraph(
|
|||
// Checks if single device compilation and execution with PJRT is enabled for
|
||||
// `device_type` in either the XlaLaunch op or the XlaCompileOnDemand op.
|
||||
bool UsePjRtForSingleDeviceCompilation(const DeviceType& device_type);
|
||||
|
||||
// Gets the resource name of the PjRt DeviceCompiler for `device_type`.
|
||||
std::string GetPjRtDeviceCompilerResourceName(const DeviceType& device_type);
|
||||
|
||||
// Gets the resource name of the DeviceCompilationProfiler for `device_type`
|
||||
// when PjRt is used for compilation and execution.
|
||||
std::string GetPjRtDeviceCompilationProfilerResourceName(
|
||||
const DeviceType& device_type);
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_JIT_XLA_COMPILE_UTIL_H_
|
||||
|
|
|
|||
|
|
@ -22,6 +22,7 @@ limitations under the License.
|
|||
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
|
||||
#include "tensorflow/core/framework/fake_input.h"
|
||||
#include "tensorflow/core/kernels/ops_testutil.h"
|
||||
#include "tensorflow/core/tpu/tpu_defs.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
|
@ -118,5 +119,31 @@ TEST(XlaCompileUtilTest, PjRtXlaCompileOnDemandFlagTest) {
|
|||
EXPECT_FALSE(UsePjRtForSingleDeviceCompilation(DeviceType(DEVICE_CPU)));
|
||||
}
|
||||
|
||||
TEST(XlaCompileUtilTest, PjRtDeviceCompilerResourceName) {
|
||||
EXPECT_EQ(GetPjRtDeviceCompilerResourceName(DeviceType(DEVICE_TPU)),
|
||||
"pjrt_device_compiler_TPU");
|
||||
EXPECT_EQ(GetPjRtDeviceCompilerResourceName(DeviceType(DEVICE_TPU_NODE)),
|
||||
"pjrt_device_compiler_TPU");
|
||||
EXPECT_EQ(GetPjRtDeviceCompilerResourceName(DeviceType(DEVICE_CPU)),
|
||||
"pjrt_device_compiler_CPU");
|
||||
EXPECT_EQ(GetPjRtDeviceCompilerResourceName(DeviceType(DEVICE_GPU)),
|
||||
"pjrt_device_compiler_GPU");
|
||||
}
|
||||
|
||||
TEST(XlaCompileUtilTest, PjRtDeviceCompilationProfilerResourceName) {
|
||||
EXPECT_EQ(
|
||||
GetPjRtDeviceCompilationProfilerResourceName(DeviceType(DEVICE_TPU)),
|
||||
"pjrt_device_compilation_profiler_TPU");
|
||||
EXPECT_EQ(
|
||||
GetPjRtDeviceCompilationProfilerResourceName(DeviceType(DEVICE_TPU_NODE)),
|
||||
"pjrt_device_compilation_profiler_TPU");
|
||||
EXPECT_EQ(
|
||||
GetPjRtDeviceCompilationProfilerResourceName(DeviceType(DEVICE_CPU)),
|
||||
"pjrt_device_compilation_profiler_CPU");
|
||||
EXPECT_EQ(
|
||||
GetPjRtDeviceCompilationProfilerResourceName(DeviceType(DEVICE_GPU)),
|
||||
"pjrt_device_compilation_profiler_GPU");
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
|
|
|||
|
|
@ -15,10 +15,14 @@ limitations under the License.
|
|||
|
||||
#include "tensorflow/compiler/jit/xla_compiler_options_util.h"
|
||||
|
||||
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
using XlaDeviceCompiler =
|
||||
DeviceCompiler<xla::LocalExecutable, xla::LocalClient>;
|
||||
using PjRtDeviceCompiler =
|
||||
DeviceCompiler<xla::PjRtLoadedExecutable, xla::PjRtClient>;
|
||||
|
||||
inline void LogOptions(const XlaCompiler::Options& options) {
|
||||
VLOG(2) << "XlaCompiler::Options[device_type=" << options.device_type
|
||||
|
|
@ -81,7 +85,8 @@ XlaCompiler::Options GenerateCompilerOptionsForTfrtTpu(
|
|||
|
||||
XlaCompiler::Options GenerateCompilerOptionsForPjRt(
|
||||
const FunctionLibraryRuntime& function_library,
|
||||
const DeviceBase* device_base, const XlaPlatformInfo& platform_info) {
|
||||
const DeviceBase* device_base, const XlaPlatformInfo& platform_info,
|
||||
const PjRtDeviceCompiler* pjrt_device_compiler) {
|
||||
XlaCompiler::Options options;
|
||||
options.device_ordinal = device_base->parsed_name().id;
|
||||
options.flib_def = function_library.GetFunctionLibraryDefinition();
|
||||
|
|
@ -96,8 +101,9 @@ XlaCompiler::Options GenerateCompilerOptionsForPjRt(
|
|||
options.device_type = metadata->jit_device_type();
|
||||
options.shape_determination_fns =
|
||||
metadata->default_shape_determination_fns();
|
||||
} else if (pjrt_device_compiler != nullptr) {
|
||||
options.device_type = pjrt_device_compiler->device_type();
|
||||
}
|
||||
// TODO(b/255826209): Set options for non-XLA devices once PjRt supports them.
|
||||
// TODO(b/255826209): Confirm below options are correctly set after testing.
|
||||
options.allow_cpu_custom_calls = false;
|
||||
options.alias_passthrough_params = false;
|
||||
|
|
|
|||
|
|
@ -20,6 +20,7 @@ limitations under the License.
|
|||
#include "tensorflow/compiler/jit/xla_platform_info.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
|
||||
#include "tensorflow/compiler/xla/client/local_client.h"
|
||||
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
|
|
@ -39,9 +40,12 @@ XlaCompiler::Options GenerateCompilerOptionsForTfrtTpu(
|
|||
|
||||
// Returns created options for XLA compiler when PjRt (Device API) is used for
|
||||
// compilation and execution.
|
||||
// TODO(b/255826209): Remove default arg once PjRtCompileOnDemand op is deleted.
|
||||
XlaCompiler::Options GenerateCompilerOptionsForPjRt(
|
||||
const FunctionLibraryRuntime& function_library,
|
||||
const DeviceBase* device_base, const XlaPlatformInfo& platform_info);
|
||||
const DeviceBase* device_base, const XlaPlatformInfo& platform_info,
|
||||
const DeviceCompiler<xla::PjRtLoadedExecutable, xla::PjRtClient>*
|
||||
pjrt_device_compiler = nullptr);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
|
|
|
|||
|
|
@ -23,12 +23,14 @@ limitations under the License.
|
|||
|
||||
#include <gtest/gtest.h>
|
||||
#include "tensorflow/compiler/jit/flags.h"
|
||||
#include "tensorflow/compiler/jit/pjrt_device_compiler_client.h"
|
||||
#include "tensorflow/compiler/jit/test_util.h"
|
||||
#include "tensorflow/compiler/jit/xla_platform_info.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||
#include "tensorflow/compiler/xla/client/client_library.h"
|
||||
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
#include "tensorflow/core/framework/device.h"
|
||||
#include "tensorflow/core/framework/device_base.h"
|
||||
|
|
@ -42,18 +44,31 @@ using XlaDeviceCompiler =
|
|||
DeviceCompiler<xla::LocalExecutable, xla::LocalClient>;
|
||||
using XlaDeviceExecutablePersistor =
|
||||
DeviceExecutablePersistor<xla::LocalExecutable, xla::LocalClient>;
|
||||
using PjRtDeviceCompiler =
|
||||
DeviceCompiler<xla::PjRtLoadedExecutable, xla::PjRtClient>;
|
||||
using PjRtDeviceExecutablePersistor =
|
||||
DeviceExecutablePersistor<xla::PjRtLoadedExecutable, xla::PjRtClient>;
|
||||
|
||||
XlaDeviceCompiler* CreateXlaDeviceCompiler(
|
||||
const XlaDeviceExecutablePersistor::Config& persistor_config,
|
||||
DeviceType device_type, xla::LocalClient* local_client) {
|
||||
XlaDeviceCompiler* CreateXlaDeviceCompiler(DeviceType device_type,
|
||||
xla::LocalClient* local_client) {
|
||||
auto persistor = std::make_unique<XlaDeviceExecutablePersistor>(
|
||||
std::move(persistor_config), device_type);
|
||||
XlaDeviceExecutablePersistor::Config(), device_type);
|
||||
auto compiler_client =
|
||||
std::make_unique<XlaDeviceCompilerClient>(local_client);
|
||||
return new XlaDeviceCompiler(std::move(persistor),
|
||||
std::move(compiler_client));
|
||||
}
|
||||
|
||||
PjRtDeviceCompiler* CreatePjRtDeviceCompiler(DeviceType device_type,
|
||||
xla::PjRtClient* pjrt_client) {
|
||||
auto persistor = std::make_unique<PjRtDeviceExecutablePersistor>(
|
||||
PjRtDeviceExecutablePersistor::Config(), device_type);
|
||||
auto compiler_client =
|
||||
std::make_unique<PjRtDeviceCompilerClient>(pjrt_client);
|
||||
return new PjRtDeviceCompiler(std::move(persistor),
|
||||
std::move(compiler_client));
|
||||
}
|
||||
|
||||
std::vector<XlaShapeLayoutHelpers::ShapeDeterminationFns>
|
||||
GetShapeDeterminationFns() {
|
||||
XlaHelpers::ShapeRepresentationFn shape_representation_fn =
|
||||
|
|
@ -160,6 +175,45 @@ TEST_F(XlaCompilerOptionsTest, PjRtOptionsPjRtBaseDevice) {
|
|||
tensorflow::XlaLayoutPreference::kTpuPreferLinearLayout);
|
||||
}
|
||||
|
||||
TEST_F(XlaCompilerOptionsTest, PjRtOptionsNonXlaDevice) {
|
||||
device_setup_.AddDevicesAndSetUp({DEVICE_CPU});
|
||||
Device* device = device_setup_.GetDevice(DEVICE_CPU);
|
||||
DeviceType compilation_device_type = DeviceType(DEVICE_CPU_XLA_JIT);
|
||||
|
||||
XlaPlatformInfo platform_info(compilation_device_type,
|
||||
/*platform_id=*/nullptr,
|
||||
/*xla_device_metadata=*/nullptr,
|
||||
/*pjrt_device_metadata=*/nullptr,
|
||||
/*device_allocator=*/nullptr);
|
||||
|
||||
auto pjrt_device_compiler =
|
||||
CreatePjRtDeviceCompiler(compilation_device_type, nullptr);
|
||||
core::ScopedUnref pjrt_device_compiler_ref(pjrt_device_compiler);
|
||||
|
||||
XlaCompiler::Options options = GenerateCompilerOptionsForPjRt(
|
||||
*device_setup_.flr(), device, platform_info, pjrt_device_compiler);
|
||||
|
||||
EXPECT_EQ(options.device_type, compilation_device_type);
|
||||
EXPECT_EQ(options.device_ordinal, 0);
|
||||
EXPECT_NE(options.flib_def, nullptr);
|
||||
EXPECT_EQ(options.graph_def_version, TF_GRAPH_DEF_VERSION);
|
||||
EXPECT_FALSE(options.allow_cpu_custom_calls);
|
||||
EXPECT_FALSE(options.alias_passthrough_params);
|
||||
EXPECT_FALSE(options.detailed_logging);
|
||||
// Check whether options have default shape determination functions set.
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
auto shape, options.shape_determination_fns.shape_representation_fn(
|
||||
TensorShape(), DT_FLOAT, false,
|
||||
tensorflow::XlaLayoutPreference::kNoPreference));
|
||||
xla::ShapeProto shape_proto;
|
||||
shape_proto.set_element_type(xla::PrimitiveType::F32);
|
||||
shape_proto.mutable_layout();
|
||||
EXPECT_EQ(shape, xla::Shape(shape_proto));
|
||||
EXPECT_EQ(options.shape_determination_fns.layout_preference_fn(
|
||||
TensorShape(), DT_FLOAT, std::nullopt),
|
||||
tensorflow::XlaLayoutPreference::kNoPreference);
|
||||
}
|
||||
|
||||
TEST_F(XlaCompilerOptionsTest, XlaOptions) {
|
||||
device_setup_.AddDevicesAndSetUp({DEVICE_XLA_GPU});
|
||||
Device* device = device_setup_.GetDevice(DEVICE_XLA_GPU);
|
||||
|
|
@ -168,8 +222,8 @@ TEST_F(XlaCompilerOptionsTest, XlaOptions) {
|
|||
DeviceType device_type = DeviceType(DEVICE_XLA_GPU);
|
||||
DeviceType compilation_device_type = DeviceType(DEVICE_GPU_XLA_JIT);
|
||||
|
||||
auto xla_device_compiler = CreateXlaDeviceCompiler(
|
||||
XlaDeviceExecutablePersistor::Config(), compilation_device_type, client);
|
||||
auto xla_device_compiler =
|
||||
CreateXlaDeviceCompiler(compilation_device_type, client);
|
||||
core::ScopedUnref xla_device_compiler_ref(xla_device_compiler);
|
||||
|
||||
se::Platform::Id platform_id = se::host::kHostPlatformId;
|
||||
|
|
@ -208,8 +262,8 @@ TEST_F(XlaCompilerOptionsTest, XlaOptionsHasRefVarsNoXlaDeviceMetadata) {
|
|||
DeviceType device_type = DeviceType(DEVICE_CPU);
|
||||
DeviceType compilation_device_type = DeviceType(DEVICE_CPU_XLA_JIT);
|
||||
|
||||
auto xla_device_compiler = CreateXlaDeviceCompiler(
|
||||
XlaDeviceExecutablePersistor::Config(), compilation_device_type, client);
|
||||
auto xla_device_compiler =
|
||||
CreateXlaDeviceCompiler(compilation_device_type, client);
|
||||
core::ScopedUnref xla_device_compiler_ref(xla_device_compiler);
|
||||
|
||||
se::Platform::Id platform_id = se::host::kHostPlatformId;
|
||||
|
|
@ -249,8 +303,8 @@ TEST_F(XlaCompilerOptionsTest, TfRtTpuOptions) {
|
|||
xla::LocalClient* client = xla::ClientLibrary::LocalClientOrDie();
|
||||
DeviceType compilation_device_type = DeviceType(DEVICE_TPU_XLA_JIT);
|
||||
|
||||
auto xla_device_compiler = CreateXlaDeviceCompiler(
|
||||
XlaDeviceExecutablePersistor::Config(), compilation_device_type, client);
|
||||
auto xla_device_compiler =
|
||||
CreateXlaDeviceCompiler(compilation_device_type, client);
|
||||
core::ScopedUnref xla_device_compiler_ref(xla_device_compiler);
|
||||
|
||||
XlaCompiler::Options options = GenerateCompilerOptionsForTfrtTpu(
|
||||
|
|
|
|||
49
tensorflow/compiler/jit/xla_host_recv_device_context.cc
Normal file
49
tensorflow/compiler/jit/xla_host_recv_device_context.cc
Normal file
|
|
@ -0,0 +1,49 @@
|
|||
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/compiler/jit/xla_host_recv_device_context.h"
|
||||
|
||||
#include "tensorflow/compiler/tf2xla/shape_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/type_util.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
void XlaHostRecvDeviceContext::CopyDeviceTensorToCPU(
|
||||
const Tensor* device_tensor, StringPiece tensor_name, Device* device,
|
||||
Tensor* cpu_tensor, StatusCallback done) {
|
||||
DataType dtype = EncodePrimitiveTypeAsDataType(shape_.element_type()).value();
|
||||
TensorShape tensor_shape;
|
||||
Status status = XLAShapeToTensorShape(shape_, &tensor_shape);
|
||||
if (!status.ok()) {
|
||||
done(status);
|
||||
return;
|
||||
}
|
||||
|
||||
*cpu_tensor = Tensor(dtype, tensor_shape);
|
||||
|
||||
stream_->ThenMemcpy(cpu_tensor->data(), device_memory_base_,
|
||||
device_memory_base_.size());
|
||||
stream_->ThenRecordEvent(&done_event_.get());
|
||||
if (auto st = stream_->BlockHostUntilDone(); !st.ok()) {
|
||||
done_event_.SetError(absl::InternalError(absl::StrFormat(
|
||||
"failed to synchronize send operation with a stream: %s",
|
||||
st.ToString())));
|
||||
return;
|
||||
}
|
||||
|
||||
done_event_.SetStateConcrete();
|
||||
done(OkStatus());
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
92
tensorflow/compiler/jit/xla_host_recv_device_context.h
Normal file
92
tensorflow/compiler/jit/xla_host_recv_device_context.h
Normal file
|
|
@ -0,0 +1,92 @@
|
|||
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_COMPILER_JIT_XLA_HOST_RECV_DEVICE_CONTEXT_H_
|
||||
#define TENSORFLOW_COMPILER_JIT_XLA_HOST_RECV_DEVICE_CONTEXT_H_
|
||||
|
||||
#include "tensorflow/compiler/xla/shape.h"
|
||||
#include "tensorflow/compiler/xla/stream_executor/device_memory.h"
|
||||
#include "tensorflow/compiler/xla/stream_executor/stream.h"
|
||||
#include "tensorflow/core/framework/device_base.h"
|
||||
#include "tfrt/concurrency/async_value_ref.h" // from @tf_runtime
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// XlaHostRecvDeviceContext is a DeviceContext that is intended to be
|
||||
// used to transfer from device->host using Rendezvous. It transfers the
|
||||
// content of `device_memory_base` with `shape` using `stream`. Only
|
||||
// `CopyDeviceTensorToCPU` method is implemented. The `done_event` is marked as
|
||||
// Concrete once transfer is completed.
|
||||
//
|
||||
// Example usage:
|
||||
//
|
||||
// Device device;
|
||||
// stream_executor::Stream stream(executor);
|
||||
// Tensor device_tensor(device_allocator, DT_FLOAT, TensorShape({2, 2}));
|
||||
// se::DeviceMemoryBase gpu_dst{device_tensor.data(), 4 * sizeof(float)};
|
||||
// xla::Shape shape(xla::F32, {2, 2}, {}, {})
|
||||
// tsl::AsyncValueRef<se::Event> done_event =
|
||||
// tsl::MakeConstructedAsyncValueRef<se::Event>(stream.parent());
|
||||
// done_event->Init();
|
||||
// Tensor dest_cpu_tensor;
|
||||
//
|
||||
// XlaHostRecvDeviceContext device_context(&stream, gpu_dst,
|
||||
// shape, done_event);
|
||||
// device_context.CopyDeviceTensorToCPUSync(
|
||||
// &device_tensor, "", &device, &dest_cpu_tensor);
|
||||
|
||||
class XlaHostRecvDeviceContext : public DeviceContext {
|
||||
public:
|
||||
XlaHostRecvDeviceContext(se::Stream* stream,
|
||||
const se::DeviceMemoryBase& device_memory_base,
|
||||
const xla::Shape& shape,
|
||||
tsl::AsyncValueRef<se::Event>& done_event)
|
||||
: stream_(stream),
|
||||
device_memory_base_(device_memory_base),
|
||||
shape_(shape),
|
||||
done_event_(done_event) {}
|
||||
|
||||
void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device,
|
||||
Tensor* device_tensor, StatusCallback done,
|
||||
bool sync_dst_compute) const override {
|
||||
done(errors::Internal("host->device copy not implemented."));
|
||||
}
|
||||
|
||||
// Copies `device_memory_base_` with `shape_` into `cpu_tensor`.
|
||||
// `device_tensor` is unused.
|
||||
void CopyDeviceTensorToCPU(const Tensor* device_tensor,
|
||||
StringPiece tensor_name, Device* device,
|
||||
Tensor* cpu_tensor, StatusCallback done) override;
|
||||
|
||||
void CopyTensorInSameDevice(const Tensor* input_tensor, Device* device,
|
||||
Tensor* output_tensor,
|
||||
StatusCallback done) const override {
|
||||
done(errors::Internal("device->device copy not implemented."));
|
||||
}
|
||||
|
||||
private:
|
||||
se::Stream* stream_; // Not owned.
|
||||
// This is copied rather than a reference or pointer since its lifetime
|
||||
// is not guaranteed to outlast the original object. Object slicing is
|
||||
// not an issue here since only DeviceMemoryBase methods/members are used.
|
||||
const se::DeviceMemoryBase device_memory_base_;
|
||||
const xla::Shape shape_;
|
||||
tsl::AsyncValueRef<se::Event> done_event_;
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(XlaHostRecvDeviceContext);
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_JIT_XLA_HOST_RECV_DEVICE_CONTEXT_H_
|
||||
39
tensorflow/compiler/jit/xla_host_send_device_context.cc
Normal file
39
tensorflow/compiler/jit/xla_host_send_device_context.cc
Normal file
|
|
@ -0,0 +1,39 @@
|
|||
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/compiler/jit/xla_host_send_device_context.h"
|
||||
|
||||
#include "tensorflow/compiler/tf2xla/shape_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/type_util.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
void XlaHostSendDeviceContext::CopyCPUTensorToDevice(
|
||||
const Tensor* cpu_tensor, Device* device, Tensor* device_tensor,
|
||||
StatusCallback done, bool sync_dst_compute) const {
|
||||
stream_->ThenMemcpy(device_memory_base_, cpu_tensor->data(),
|
||||
device_memory_base_->size());
|
||||
stream_->ThenRecordEvent(&done_event_.get());
|
||||
if (auto st = stream_->BlockHostUntilDone(); !st.ok()) {
|
||||
done_event_.SetError(absl::InternalError(absl::StrFormat(
|
||||
"failed to synchronize send operation with a stream: %s",
|
||||
st.ToString())));
|
||||
return;
|
||||
}
|
||||
|
||||
done_event_.SetStateConcrete();
|
||||
done(OkStatus());
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
89
tensorflow/compiler/jit/xla_host_send_device_context.h
Normal file
89
tensorflow/compiler/jit/xla_host_send_device_context.h
Normal file
|
|
@ -0,0 +1,89 @@
|
|||
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_COMPILER_JIT_XLA_HOST_SEND_DEVICE_CONTEXT_H_
|
||||
#define TENSORFLOW_COMPILER_JIT_XLA_HOST_SEND_DEVICE_CONTEXT_H_
|
||||
|
||||
#include "tensorflow/compiler/xla/shape.h"
|
||||
#include "tensorflow/compiler/xla/stream_executor/device_memory.h"
|
||||
#include "tensorflow/compiler/xla/stream_executor/stream.h"
|
||||
#include "tensorflow/core/framework/device_base.h"
|
||||
#include "tfrt/concurrency/async_value_ref.h" // from @tf_runtime
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// XlaHostSendDeviceContext is a DeviceContext that is intended to be
|
||||
// used to transfer from host->device using Rendezvous. It transfers the
|
||||
// content of `device_memory_base` with `shape` using `stream`. Only
|
||||
// `CopyCPUTensorToDevice` method is implemented. The `done_event` is marked as
|
||||
// Concrete once transfer is completed.
|
||||
//
|
||||
// Example usage:
|
||||
//
|
||||
// Device device;
|
||||
// stream_executor::Stream stream(executor);
|
||||
// Tensor cpu_tensor(host_allocator, DT_FLOAT, TensorShape({2, 2}));
|
||||
// Tensor device_tensor(device_allocator, DT_FLOAT, TensorShape({2, 2}));
|
||||
// se::DeviceMemoryBase gpu_dst{device_tensor.data(), 4 * sizeof(float)};
|
||||
// xla::Shape shape(xla::F32, {2, 2}, {}, {})
|
||||
// tsl::AsyncValueRef<se::Event> done_event =
|
||||
// tsl::MakeConstructedAsyncValueRef<se::Event>(stream.parent());
|
||||
// done_event->Init();
|
||||
//
|
||||
// XlaHostSendDeviceContext device_context(&stream, &gpu_dst,
|
||||
// shape, done_event);
|
||||
// device_context.CopyCPUTensorToDeviceSync(
|
||||
// &cpu_tensor, &device, &device_tensor);
|
||||
|
||||
class XlaHostSendDeviceContext : public DeviceContext {
|
||||
public:
|
||||
XlaHostSendDeviceContext(se::Stream* stream,
|
||||
se::DeviceMemoryBase* device_memory_base,
|
||||
const xla::Shape& shape,
|
||||
tsl::AsyncValueRef<se::Event>& done_event)
|
||||
: stream_(stream),
|
||||
device_memory_base_(device_memory_base),
|
||||
shape_(shape),
|
||||
done_event_(done_event) {}
|
||||
|
||||
// Copies 'cpu_tensor' to `device_memory_base_` with `shape_`.
|
||||
// `device_tensor` is unused.
|
||||
void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device,
|
||||
Tensor* device_tensor, StatusCallback done,
|
||||
bool sync_dst_compute) const override;
|
||||
|
||||
void CopyDeviceTensorToCPU(const Tensor* device_tensor,
|
||||
StringPiece tensor_name, Device* device,
|
||||
Tensor* cpu_tensor, StatusCallback done) override {
|
||||
done(errors::Internal("host->device copy not implemented."));
|
||||
}
|
||||
|
||||
void CopyTensorInSameDevice(const Tensor* input_tensor, Device* device,
|
||||
Tensor* output_tensor,
|
||||
StatusCallback done) const override {
|
||||
done(errors::Internal("device->device copy not implemented."));
|
||||
}
|
||||
|
||||
private:
|
||||
se::Stream* stream_; // Not owned.
|
||||
se::DeviceMemoryBase* device_memory_base_; // Not owned.
|
||||
const xla::Shape shape_;
|
||||
tsl::AsyncValueRef<se::Event> done_event_;
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(XlaHostSendDeviceContext);
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_JIT_XLA_HOST_SEND_DEVICE_CONTEXT_H_
|
||||
|
|
@ -0,0 +1,171 @@
|
|||
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include <memory>
|
||||
#include <string_view>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include "tensorflow/compiler/jit/xla_host_recv_device_context.h"
|
||||
#include "tensorflow/compiler/jit/xla_host_send_device_context.h"
|
||||
#include "tensorflow/compiler/tf2xla/shape_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||
#include "tensorflow/compiler/xla/stream_executor/device_memory.h"
|
||||
#include "tensorflow/compiler/xla/stream_executor/multi_platform_manager.h"
|
||||
#include "tensorflow/compiler/xla/stream_executor/stream.h"
|
||||
#include "tensorflow/compiler/xla/stream_executor/stream_executor.h"
|
||||
#include "tensorflow/core/framework/tensor_testutil.h"
|
||||
#include "tensorflow/tsl/lib/core/status_test_util.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
class XlaHostSendRecvDeviceContextTest : public ::testing::Test {
|
||||
public:
|
||||
void SetDevice(const string& device_type) {
|
||||
auto device_factory = DeviceFactory::GetFactory(device_type);
|
||||
SessionOptions options;
|
||||
std::vector<std::unique_ptr<Device>> devices;
|
||||
Status s = device_factory->CreateDevices(
|
||||
options, "/job:worker/replica:0/task:0", &devices);
|
||||
device_ = std::move(devices[0]);
|
||||
|
||||
AllocatorAttributes host_alloc_attr;
|
||||
host_alloc_attr.set_on_host(true);
|
||||
host_allocator_ = device_->GetAllocator(host_alloc_attr);
|
||||
|
||||
AllocatorAttributes device_alloc_attr;
|
||||
device_alloc_attr.set_on_host(false);
|
||||
device_allocator_ = device_->GetAllocator(device_alloc_attr);
|
||||
}
|
||||
|
||||
protected:
|
||||
std::unique_ptr<Device> device_;
|
||||
Allocator* host_allocator_;
|
||||
Allocator* device_allocator_;
|
||||
};
|
||||
|
||||
TEST_F(XlaHostSendRecvDeviceContextTest, CopyDeviceTensorToCPU) {
|
||||
SetDevice("GPU");
|
||||
Tensor origin_cpu_tensor(host_allocator_, DT_FLOAT, TensorShape({2, 2}));
|
||||
test::FillValues<float>(&origin_cpu_tensor, {1.2, 2.3, 3.4, 4.5});
|
||||
Tensor device_tensor(device_allocator_, DT_FLOAT, TensorShape({2, 2}));
|
||||
Tensor dest_cpu_tensor(host_allocator_, DT_FLOAT, TensorShape({2, 2}));
|
||||
|
||||
stream_executor::Platform* platform =
|
||||
stream_executor::MultiPlatformManager::PlatformWithName("CUDA").value();
|
||||
stream_executor::StreamExecutor* executor =
|
||||
platform->ExecutorForDevice(0).value();
|
||||
stream_executor::Stream stream(executor);
|
||||
stream.Init();
|
||||
ASSERT_TRUE(stream.ok());
|
||||
|
||||
se::DeviceMemoryBase gpu_dst{device_tensor.data(), 4 * sizeof(float)};
|
||||
xla::Shape shape;
|
||||
TF_ASSERT_OK(TensorShapeToXLAShape(DT_FLOAT, TensorShape({2, 2}), &shape));
|
||||
|
||||
// Copy the cpu_tensor to the GPU first before trying to copy it back.
|
||||
stream.ThenMemcpy(&gpu_dst, origin_cpu_tensor.data(), gpu_dst.size());
|
||||
TF_ASSERT_OK(stream.BlockHostUntilDone());
|
||||
|
||||
tsl::AsyncValueRef<se::Event> done_event =
|
||||
tsl::MakeConstructedAsyncValueRef<se::Event>(stream.parent());
|
||||
done_event->Init();
|
||||
XlaHostRecvDeviceContext* device_context =
|
||||
new XlaHostRecvDeviceContext(&stream, gpu_dst, shape, done_event);
|
||||
TF_ASSERT_OK(device_context->CopyDeviceTensorToCPUSync(
|
||||
&device_tensor, "", device_.get(), &dest_cpu_tensor));
|
||||
|
||||
tensorflow::test::ExpectClose(origin_cpu_tensor, dest_cpu_tensor);
|
||||
device_context->Unref();
|
||||
}
|
||||
|
||||
TEST_F(XlaHostSendRecvDeviceContextTest, CopyCPUTensorToDevice) {
|
||||
SetDevice("GPU");
|
||||
Tensor origin_cpu_tensor(host_allocator_, DT_FLOAT, TensorShape({2, 2}));
|
||||
test::FillValues<float>(&origin_cpu_tensor, {1.2, 2.3, 3.4, 4.5});
|
||||
Tensor device_tensor(device_allocator_, DT_FLOAT, TensorShape({2, 2}));
|
||||
Tensor dest_cpu_tensor(host_allocator_, DT_FLOAT, TensorShape({2, 2}));
|
||||
|
||||
stream_executor::Platform* platform =
|
||||
stream_executor::MultiPlatformManager::PlatformWithName("CUDA").value();
|
||||
stream_executor::StreamExecutor* executor =
|
||||
platform->ExecutorForDevice(0).value();
|
||||
stream_executor::Stream stream(executor);
|
||||
stream.Init();
|
||||
ASSERT_TRUE(stream.ok());
|
||||
|
||||
se::DeviceMemoryBase gpu_dst{device_tensor.data(), 4 * sizeof(float)};
|
||||
xla::Shape shape;
|
||||
TF_ASSERT_OK(TensorShapeToXLAShape(DT_FLOAT, TensorShape({2, 2}), &shape));
|
||||
|
||||
tsl::AsyncValueRef<se::Event> done_event =
|
||||
tsl::MakeConstructedAsyncValueRef<se::Event>(stream.parent());
|
||||
done_event->Init();
|
||||
XlaHostSendDeviceContext* device_context =
|
||||
new XlaHostSendDeviceContext(&stream, &gpu_dst, shape, done_event);
|
||||
TF_ASSERT_OK(device_context->CopyCPUTensorToDeviceSync(
|
||||
&origin_cpu_tensor, device_.get(), &device_tensor));
|
||||
|
||||
// Copy the GPU tensor back to CPU to check that copy worked.
|
||||
stream.ThenMemcpy(dest_cpu_tensor.data(), gpu_dst, gpu_dst.size());
|
||||
TF_ASSERT_OK(stream.BlockHostUntilDone());
|
||||
|
||||
tensorflow::test::ExpectClose(origin_cpu_tensor, dest_cpu_tensor);
|
||||
device_context->Unref();
|
||||
}
|
||||
|
||||
TEST_F(XlaHostSendRecvDeviceContextTest, RoundTrip) {
|
||||
SetDevice("GPU");
|
||||
Tensor origin_cpu_tensor(host_allocator_, DT_FLOAT, TensorShape({2, 2}));
|
||||
test::FillValues<float>(&origin_cpu_tensor, {1.2, 2.3, 3.4, 4.5});
|
||||
Tensor device_tensor(device_allocator_, DT_FLOAT, TensorShape({2, 2}));
|
||||
Tensor dest_cpu_tensor(host_allocator_, DT_FLOAT, TensorShape({2, 2}));
|
||||
|
||||
stream_executor::Platform* platform =
|
||||
stream_executor::MultiPlatformManager::PlatformWithName("CUDA").value();
|
||||
stream_executor::StreamExecutor* executor =
|
||||
platform->ExecutorForDevice(0).value();
|
||||
stream_executor::Stream stream(executor);
|
||||
stream.Init();
|
||||
ASSERT_TRUE(stream.ok());
|
||||
|
||||
se::DeviceMemoryBase gpu_dst{device_tensor.data(), 4 * sizeof(float)};
|
||||
xla::Shape shape;
|
||||
TF_ASSERT_OK(TensorShapeToXLAShape(DT_FLOAT, TensorShape({2, 2}), &shape));
|
||||
|
||||
tsl::AsyncValueRef<se::Event> send_done_event =
|
||||
tsl::MakeConstructedAsyncValueRef<se::Event>(stream.parent());
|
||||
send_done_event->Init();
|
||||
XlaHostSendDeviceContext* send_device_context =
|
||||
new XlaHostSendDeviceContext(&stream, &gpu_dst, shape, send_done_event);
|
||||
TF_ASSERT_OK(send_device_context->CopyCPUTensorToDeviceSync(
|
||||
&origin_cpu_tensor, device_.get(), &device_tensor));
|
||||
|
||||
tsl::AsyncValueRef<se::Event> recv_done_event =
|
||||
tsl::MakeConstructedAsyncValueRef<se::Event>(stream.parent());
|
||||
recv_done_event->Init();
|
||||
XlaHostRecvDeviceContext* recv_device_context =
|
||||
new XlaHostRecvDeviceContext(&stream, gpu_dst, shape, recv_done_event);
|
||||
TF_ASSERT_OK(recv_device_context->CopyDeviceTensorToCPUSync(
|
||||
&device_tensor, "", device_.get(), &dest_cpu_tensor));
|
||||
|
||||
tensorflow::test::ExpectClose(origin_cpu_tensor, dest_cpu_tensor);
|
||||
send_device_context->Unref();
|
||||
recv_device_context->Unref();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
|
@ -16,6 +16,7 @@ limitations under the License.
|
|||
#include "tensorflow/compiler/jit/xla_kernel_creator.h"
|
||||
|
||||
#include "absl/memory/memory.h"
|
||||
#include "absl/status/status.h"
|
||||
#include "tensorflow/core/common_runtime/device_factory.h"
|
||||
#include "tensorflow/core/common_runtime/function.h"
|
||||
#include "tensorflow/core/framework/function_testlib.h"
|
||||
|
|
@ -136,7 +137,7 @@ TEST_F(XlaKernelCreatorTest, FailsIfXlaCompileAttrNotSet) {
|
|||
input: 'b'
|
||||
)proto"),
|
||||
&kernel_);
|
||||
EXPECT_TRUE(errors::IsInternal(status)) << status.ToString();
|
||||
EXPECT_TRUE(absl::IsInternal(status)) << status;
|
||||
}
|
||||
|
||||
TEST_F(XlaKernelCreatorTest, FailsIfXlaCompileAttrIsSetToFalse) {
|
||||
|
|
@ -153,7 +154,7 @@ TEST_F(XlaKernelCreatorTest, FailsIfXlaCompileAttrIsSetToFalse) {
|
|||
input: 'b'
|
||||
)proto"),
|
||||
&kernel_);
|
||||
EXPECT_TRUE(errors::IsInternal(status)) << status.ToString();
|
||||
EXPECT_TRUE(absl::IsInternal(status)) << status;
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
|
|
|||
|
|
@ -18,17 +18,21 @@ limitations under the License.
|
|||
#include <memory>
|
||||
#include <optional>
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/compiler/jit/device_executable_persistor.h"
|
||||
#include "tensorflow/compiler/jit/flags.h"
|
||||
#include "tensorflow/compiler/jit/pjrt_device_compiler_client.h"
|
||||
#include "tensorflow/compiler/jit/xla_compile_util.h"
|
||||
#include "tensorflow/compiler/jit/xla_device_compiler_client.h"
|
||||
#include "tensorflow/compiler/xla/client/client_library.h"
|
||||
#include "tensorflow/compiler/xla/client/local_client.h"
|
||||
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
#include "tensorflow/core/tfrt/common/create_pjrt_client_util.h"
|
||||
#include "tensorflow/core/tfrt/common/global_state.h"
|
||||
#include "tensorflow/core/tfrt/common/pjrt_util.h"
|
||||
#include "tensorflow/core/tpu/tpu_defs.h"
|
||||
|
||||
|
|
@ -45,19 +49,23 @@ using PjRtDeviceExecutablePersistor =
|
|||
|
||||
XlaDeviceCompiler* CreateXlaDeviceCompiler(
|
||||
const XlaDeviceExecutablePersistor::Config& persistor_config,
|
||||
DeviceType device_type, xla::LocalClient* local_client) {
|
||||
DeviceType compilation_device_type, xla::LocalClient* local_client) {
|
||||
return new XlaDeviceCompiler(
|
||||
std::make_unique<XlaDeviceExecutablePersistor>(
|
||||
std::move(persistor_config), device_type),
|
||||
std::move(persistor_config), compilation_device_type),
|
||||
std::make_unique<XlaDeviceCompilerClient>(local_client));
|
||||
}
|
||||
|
||||
PjRtDeviceCompiler* CreatePjRtDeviceCompiler(
|
||||
const PjRtDeviceExecutablePersistor::Config& persistor_config,
|
||||
DeviceType device_type, xla::PjRtClient* pjrt_client) {
|
||||
PjRtDeviceCompiler* CreatePjRtDeviceCompiler(DeviceType compilation_device_type,
|
||||
xla::PjRtClient* pjrt_client) {
|
||||
PjRtDeviceExecutablePersistor::Config persistor_config(
|
||||
GetMarkForCompilationPassFlags()->tf_xla_persistent_cache_directory,
|
||||
GetMarkForCompilationPassFlags()->tf_xla_disable_strict_signature_checks,
|
||||
GetMarkForCompilationPassFlags()->tf_xla_persistent_cache_prefix);
|
||||
|
||||
return new PjRtDeviceCompiler(
|
||||
std::make_unique<PjRtDeviceExecutablePersistor>(
|
||||
std::move(persistor_config), device_type),
|
||||
std::move(persistor_config), compilation_device_type),
|
||||
std::make_unique<PjRtDeviceCompilerClient>(pjrt_client));
|
||||
}
|
||||
|
||||
|
|
@ -73,6 +81,60 @@ StatusOr<std::optional<std::set<int>>> GetAllowedGpus(
|
|||
|
||||
return gpu_ids;
|
||||
}
|
||||
|
||||
Status GetCompilationDeviceTypeAndPjRtClient(
|
||||
const XlaPlatformInfo& platform_info, FunctionLibraryRuntime* flr,
|
||||
DeviceType* compilation_device_type, xla::PjRtClient** pjrt_client) {
|
||||
DeviceType device_type = platform_info.device_type();
|
||||
|
||||
if (platform_info.xla_device_metadata()) {
|
||||
VLOG(2) << "Building PjRtDeviceCompiler using "
|
||||
"platform_info.xla_device_metadata().";
|
||||
|
||||
*compilation_device_type =
|
||||
platform_info.xla_device_metadata()->jit_device_type();
|
||||
TF_ASSIGN_OR_RETURN(*pjrt_client, GetOrCreatePjRtClient(device_type));
|
||||
return OkStatus();
|
||||
}
|
||||
|
||||
if (platform_info.pjrt_device_metadata()) {
|
||||
VLOG(2) << "Building PjRtDeviceCompiler using "
|
||||
"platform_info.pjrt_device_metadata().";
|
||||
|
||||
*compilation_device_type =
|
||||
platform_info.pjrt_device_metadata()->jit_device_type();
|
||||
TF_ASSIGN_OR_RETURN(*pjrt_client, GetOrCreatePjRtClient(device_type));
|
||||
return OkStatus();
|
||||
}
|
||||
|
||||
// TFRT-TPU is used if device_type is `DEVICE_TPU` and platform_info does not
|
||||
// have `xla_device_metadata`.
|
||||
if (device_type == DEVICE_TPU) {
|
||||
*compilation_device_type = DeviceType(DEVICE_TPU_XLA_JIT);
|
||||
TF_ASSIGN_OR_RETURN(*pjrt_client, GetOrCreatePjRtClient(device_type));
|
||||
return OkStatus();
|
||||
}
|
||||
|
||||
VLOG(2) << "platform_info.xla_device_metadata not found and "
|
||||
"platform_info.device_type() != DEVICE_TPU. Building "
|
||||
"PjRtDeviceCompiler for non-XLA device.";
|
||||
|
||||
const XlaOpRegistry::DeviceRegistration* registration;
|
||||
if (!XlaOpRegistry::GetCompilationDevice(device_type.type(), ®istration)) {
|
||||
return errors::InvalidArgument("No JIT device registered for ",
|
||||
device_type.type());
|
||||
}
|
||||
*compilation_device_type = DeviceType(registration->compilation_device_name);
|
||||
|
||||
TF_ASSIGN_OR_RETURN(auto allowed_gpus, GetAllowedGpus(flr));
|
||||
// TODO(b/255826209): Set platform, intra op parallelism threads if required
|
||||
// and when supported by GetOrCreatePjRtClient().
|
||||
// The `allowed_gpus` argument is used only if the `device_type` is GPU.
|
||||
TF_ASSIGN_OR_RETURN(*pjrt_client,
|
||||
GetOrCreatePjRtClient(device_type, allowed_gpus));
|
||||
|
||||
return OkStatus();
|
||||
}
|
||||
} // namespace
|
||||
|
||||
xla::StatusOr<std::optional<std::set<int>>> ParseVisibleDeviceList(
|
||||
|
|
@ -175,71 +237,45 @@ Status BuildXlaDeviceCompiler(DeviceBase* device, FunctionLibraryRuntime* flr,
|
|||
return OkStatus();
|
||||
}
|
||||
|
||||
Status BuildPjRtDeviceCompiler(const XlaPlatformInfo& platform_info,
|
||||
FunctionLibraryRuntime* flr,
|
||||
PjRtDeviceCompiler** pjrt_device_compiler) {
|
||||
PjRtDeviceExecutablePersistor::Config persistor_config(
|
||||
GetMarkForCompilationPassFlags()->tf_xla_persistent_cache_directory,
|
||||
GetMarkForCompilationPassFlags()->tf_xla_disable_strict_signature_checks,
|
||||
GetMarkForCompilationPassFlags()->tf_xla_persistent_cache_prefix);
|
||||
Status GetOrCreatePjRtDeviceCompilerAndProfiler(
|
||||
const XlaPlatformInfo& platform_info, FunctionLibraryRuntime* flr,
|
||||
PjRtDeviceCompiler** pjrt_device_compiler,
|
||||
DeviceCompilationProfiler** profiler) {
|
||||
// We store information about the JIT-compiled XLA computation
|
||||
// in the ResourceMgr.
|
||||
ResourceMgr* rm = tfrt_global::GetTFGlobalResourceMgr();
|
||||
|
||||
DeviceType device_type = platform_info.device_type();
|
||||
const auto& device_type = platform_info.device_type();
|
||||
const std::string& compiler_name =
|
||||
GetPjRtDeviceCompilerResourceName(device_type);
|
||||
|
||||
if (platform_info.xla_device_metadata()) {
|
||||
VLOG(2) << "Building PjRtDeviceCompiler using "
|
||||
"platform_info.xla_device_metadata().";
|
||||
// Lookup the DeviceCompiler, create one if not found.
|
||||
Status s = rm->Lookup<PjRtDeviceCompiler>(
|
||||
rm->default_container(), compiler_name, pjrt_device_compiler);
|
||||
if (!s.ok()) {
|
||||
DeviceType compilation_device_type("");
|
||||
xla::PjRtClient* pjrt_client = nullptr;
|
||||
TF_RETURN_IF_ERROR(GetCompilationDeviceTypeAndPjRtClient(
|
||||
platform_info, flr, &compilation_device_type, &pjrt_client));
|
||||
|
||||
DeviceType compilation_device_type =
|
||||
platform_info.xla_device_metadata()->jit_device_type();
|
||||
TF_ASSIGN_OR_RETURN(auto pjrt_client, GetOrCreatePjRtClient(device_type));
|
||||
|
||||
*pjrt_device_compiler = CreatePjRtDeviceCompiler(
|
||||
persistor_config, compilation_device_type, pjrt_client);
|
||||
return OkStatus();
|
||||
}
|
||||
if (platform_info.pjrt_device_metadata()) {
|
||||
VLOG(2) << "Building PjRtDeviceCompiler using "
|
||||
"platform_info.pjrt_device_metadata().";
|
||||
|
||||
DeviceType compilation_device_type =
|
||||
platform_info.pjrt_device_metadata()->jit_device_type();
|
||||
TF_ASSIGN_OR_RETURN(auto pjrt_client, GetOrCreatePjRtClient(device_type));
|
||||
|
||||
*pjrt_device_compiler = CreatePjRtDeviceCompiler(
|
||||
persistor_config, compilation_device_type, pjrt_client);
|
||||
return OkStatus();
|
||||
TF_RETURN_IF_ERROR(rm->LookupOrCreate<PjRtDeviceCompiler>(
|
||||
rm->default_container(), compiler_name, pjrt_device_compiler,
|
||||
[&](PjRtDeviceCompiler** pjrt_device_compiler) {
|
||||
*pjrt_device_compiler =
|
||||
CreatePjRtDeviceCompiler(compilation_device_type, pjrt_client);
|
||||
return OkStatus();
|
||||
}));
|
||||
}
|
||||
|
||||
// TFRT-TPU is used if device_type is `DEVICE_TPU` and platform_info does not
|
||||
// have `xla_device_metadata`.
|
||||
if (device_type == DEVICE_TPU) {
|
||||
TF_ASSIGN_OR_RETURN(auto pjrt_client, GetOrCreatePjRtClient(device_type));
|
||||
*pjrt_device_compiler = CreatePjRtDeviceCompiler(
|
||||
persistor_config, DeviceType(DEVICE_TPU_XLA_JIT), pjrt_client);
|
||||
return OkStatus();
|
||||
}
|
||||
const std::string& profiler_name =
|
||||
GetPjRtDeviceCompilationProfilerResourceName(device_type);
|
||||
TF_RETURN_IF_ERROR(rm->LookupOrCreate<DeviceCompilationProfiler>(
|
||||
rm->default_container(), profiler_name, profiler,
|
||||
[](DeviceCompilationProfiler** profiler) {
|
||||
*profiler = new DeviceCompilationProfiler();
|
||||
return OkStatus();
|
||||
}));
|
||||
|
||||
VLOG(2) << "platform_info.xla_device_metadata not found and "
|
||||
"platform_info.device_type() != DEVICE_TPU. Building "
|
||||
"PjRtDeviceCompiler for non-XLA device.";
|
||||
|
||||
const XlaOpRegistry::DeviceRegistration* registration;
|
||||
if (!XlaOpRegistry::GetCompilationDevice(device_type.type(), ®istration)) {
|
||||
return errors::InvalidArgument("No JIT device registered for ",
|
||||
device_type.type());
|
||||
}
|
||||
auto compilation_device_type =
|
||||
DeviceType(registration->compilation_device_name);
|
||||
|
||||
TF_ASSIGN_OR_RETURN(auto allowed_gpus, GetAllowedGpus(flr));
|
||||
// TODO(b/255826209): Set platform, intra op parallelism threads if required
|
||||
// and when supported by GetOrCreatePjRtClient().
|
||||
// The `allowed_gpus` argument is used only if the `device_type` is GPU.
|
||||
TF_ASSIGN_OR_RETURN(auto pjrt_client,
|
||||
GetOrCreatePjRtClient(device_type, allowed_gpus));
|
||||
|
||||
*pjrt_device_compiler = CreatePjRtDeviceCompiler(
|
||||
persistor_config, compilation_device_type, pjrt_client);
|
||||
return OkStatus();
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ limitations under the License.
|
|||
|
||||
#include <memory>
|
||||
#include <optional>
|
||||
#include <string>
|
||||
|
||||
#include "tensorflow/compiler/jit/device_compiler.h"
|
||||
#include "tensorflow/compiler/jit/pjrt_base_device.h"
|
||||
|
|
@ -113,17 +114,21 @@ Status BuildXlaDeviceCompiler(
|
|||
DeviceCompiler<xla::LocalExecutable, xla::LocalClient>**
|
||||
xla_device_compiler);
|
||||
|
||||
// Builds a DeviceCompiler that uses xla::PjRtClient using an appropriate
|
||||
// Fetches a DeviceCompiler from the tfrt_global resource manager (or creates
|
||||
// one there if not found) that uses xla::PjRtClient using an appropriate
|
||||
// PjRtClient for `platform_info.device_type()` and sets *pjrt_device_compiler
|
||||
// to point to it. Uses flags from `MarkForCompilationPassFlags` for configuring
|
||||
// the persistor used in the DeviceCompiler. Please note that non-XLA devices
|
||||
// aren't supported yet. This is because:
|
||||
// to point to it. Also fetches/creates a DeviceCompilationProfiler from/in the
|
||||
// tfrt_global resource manager for `platform_info.device_type()` and sets
|
||||
// *profiler to point to it. Uses flags from `MarkForCompilationPassFlags` for
|
||||
// configuring the persistor used in the DeviceCompiler. Please note that
|
||||
// non-XLA devices aren't supported yet. This is because:
|
||||
// 1. PjRtClient doesn't support data transfer for non-XLA devices yet
|
||||
// 2. Fetching the PjRtClient for non-XLA devices is also not supported yet
|
||||
Status BuildPjRtDeviceCompiler(
|
||||
Status GetOrCreatePjRtDeviceCompilerAndProfiler(
|
||||
const XlaPlatformInfo& platform_info, FunctionLibraryRuntime* flr,
|
||||
DeviceCompiler<xla::PjRtLoadedExecutable, xla::PjRtClient>**
|
||||
pjrt_device_compiler);
|
||||
pjrt_device_compiler,
|
||||
DeviceCompilationProfiler** profiler);
|
||||
|
||||
// Returns information about the platform from kernel context.
|
||||
XlaPlatformInfo XlaPlatformInfoFromDevice(DeviceBase* device);
|
||||
|
|
|
|||
|
|
@ -81,7 +81,7 @@ TEST_F(XlaPlatformInfoTest, BuildXlaDeviceCompilerNonXlaDevice) {
|
|||
EXPECT_TRUE(xla_device_compiler->client() != nullptr);
|
||||
}
|
||||
|
||||
TEST_F(XlaPlatformInfoTest, BuildPjRtDeviceCompilerTestXlaDevice) {
|
||||
TEST_F(XlaPlatformInfoTest, GetOrCreatePjRtDeviceCompilerAndProfilerXlaDevice) {
|
||||
DeviceType device_type = DeviceType(DEVICE_XLA_GPU);
|
||||
device_setup_.AddDevicesAndSetUp({device_type.type()});
|
||||
|
||||
|
|
@ -91,23 +91,27 @@ TEST_F(XlaPlatformInfoTest, BuildPjRtDeviceCompilerTestXlaDevice) {
|
|||
XlaPlatformInfo platform_info = XlaPlatformInfoFromDevice(device);
|
||||
|
||||
PjRtDeviceCompiler* pjrt_device_compiler = nullptr;
|
||||
TF_EXPECT_OK(BuildPjRtDeviceCompiler(platform_info, device_setup_.flr(),
|
||||
&pjrt_device_compiler));
|
||||
DeviceCompilationProfiler* profiler = nullptr;
|
||||
TF_EXPECT_OK(GetOrCreatePjRtDeviceCompilerAndProfiler(
|
||||
platform_info, device_setup_.flr(), &pjrt_device_compiler, &profiler));
|
||||
core::ScopedUnref pjrt_device_compiler_ref(pjrt_device_compiler);
|
||||
core::ScopedUnref profiler_ref(profiler);
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto pjrt_client, GetOrCreatePjRtClient(device_type));
|
||||
EXPECT_EQ(pjrt_device_compiler->device_type(), metadata->jit_device_type());
|
||||
EXPECT_EQ(pjrt_device_compiler->client(), pjrt_client);
|
||||
}
|
||||
|
||||
TEST_F(XlaPlatformInfoTest, BuildPjRtDeviceCompilerTestGpuDevice) {
|
||||
TEST_F(XlaPlatformInfoTest, GetOrCreatePjRtDeviceCompilerAndProfilerGpuDevice) {
|
||||
device_setup_.AddDevicesAndSetUp({DEVICE_GPU});
|
||||
Device* device = device_setup_.GetDevice(DEVICE_GPU);
|
||||
XlaPlatformInfo platform_info = XlaPlatformInfoFromDevice(device);
|
||||
PjRtDeviceCompiler* pjrt_device_compiler = nullptr;
|
||||
TF_EXPECT_OK(BuildPjRtDeviceCompiler(platform_info, device_setup_.flr(),
|
||||
&pjrt_device_compiler));
|
||||
DeviceCompilationProfiler* profiler = nullptr;
|
||||
TF_EXPECT_OK(GetOrCreatePjRtDeviceCompilerAndProfiler(
|
||||
platform_info, device_setup_.flr(), &pjrt_device_compiler, &profiler));
|
||||
core::ScopedUnref pjrt_device_compiler_ref(pjrt_device_compiler);
|
||||
core::ScopedUnref profiler_ref(profiler);
|
||||
}
|
||||
#endif
|
||||
|
||||
|
|
@ -138,7 +142,7 @@ TEST_F(XlaPlatformInfoTest, BuildXlaDeviceCompilerTpuDevice) {
|
|||
|
||||
// TODO(b/255826209): Look into using an actual TPU device for the unit test,
|
||||
// and move this out of OSS.
|
||||
TEST_F(XlaPlatformInfoTest, BuildPjRtDeviceCompilerTpuDevice) {
|
||||
TEST_F(XlaPlatformInfoTest, GetOrCreatePjRtDeviceCompilerAndProfilerTpuDevice) {
|
||||
DeviceType device_type = DeviceType(DEVICE_TPU);
|
||||
DeviceType compilation_device_type = DeviceType(DEVICE_TPU_XLA_JIT);
|
||||
// Use a CPU PjRtClient instead of a TPU one just for testing whether
|
||||
|
|
@ -158,9 +162,11 @@ TEST_F(XlaPlatformInfoTest, BuildPjRtDeviceCompilerTpuDevice) {
|
|||
/*device_allocator=*/nullptr);
|
||||
|
||||
PjRtDeviceCompiler* pjrt_device_compiler = nullptr;
|
||||
TF_EXPECT_OK(
|
||||
BuildPjRtDeviceCompiler(platform_info, nullptr, &pjrt_device_compiler));
|
||||
DeviceCompilationProfiler* profiler = nullptr;
|
||||
TF_EXPECT_OK(GetOrCreatePjRtDeviceCompilerAndProfiler(
|
||||
platform_info, nullptr, &pjrt_device_compiler, &profiler));
|
||||
core::ScopedUnref pjrt_device_compiler_ref(pjrt_device_compiler);
|
||||
core::ScopedUnref profiler_ref(profiler);
|
||||
|
||||
EXPECT_EQ(pjrt_device_compiler->device_type(), compilation_device_type);
|
||||
EXPECT_EQ(pjrt_device_compiler->client(), pjrt_client);
|
||||
|
|
|
|||
|
|
@ -272,6 +272,10 @@ compiling to XLA.
|
|||
### `-tf-embedding-pipelining`: Rewrite graph for embedding pipelining
|
||||
For architectures that support accelerated embedding lookups, this pass will
|
||||
rewrite the graph to use pipelining for better device utilization.
|
||||
### `-tf-embedding-sequencing`: Rewrite graph for sequential execution of embeddings
|
||||
This is a strictly sequential and formally correct fallback option for the
|
||||
embedding pipelining pass intended for debugging during pipelining
|
||||
development.
|
||||
### `-tf-executor-break-up-islands`: Transform from TF control dialect to TF executor dialect.
|
||||
### `-tf-executor-check-control-dependencies`: Checks control dependencies
|
||||
This pass analyzes control dependencies between islands and warns about
|
||||
|
|
@ -1993,4 +1997,21 @@ This pass will transform it into
|
|||
### `-tf-verify-for-export`: Verify module is suitable for export back to TF Graph
|
||||
Verifies whether all functions in module are of single tf_executor.graph and
|
||||
each tf_executor.island in tf_executor.graph only has a single op.
|
||||
### `-tf-xla-call-module-deserialization`: Deserializes StableHLO functions embedded in `tf.XlaCallModule` to top level module
|
||||
This pass deserializes the StableHLO bytecodes embedded in tf.XlaCallModule,
|
||||
then outlines the functions in the deserialized StableHLO module to the top
|
||||
level MLIR module, with function renamings to avoid naming conflicts.
|
||||
|
||||
After the outlining, it updates tf.XlaCallModule's module attribute to be
|
||||
empty, adds an `_entry_function` attribute referring to the entry function.
|
||||
It also adds a `_from_xla_call_module: true` attribute to each lifted
|
||||
StableHLO function.
|
||||
### `-tf-xla-call-module-serialization`: Serializes StableHLO functions from top-level module into `tf.XlaCallModule`'s `module` attribute
|
||||
This pass collects StableHLO functions referenced from `tf.XlaCallModule`'s
|
||||
`_entry_function` attribute into a module, serializes the module into MLIR
|
||||
bytecode, and embed the bytecode to `tf.XlaCallModule`'s `module` attribute.
|
||||
|
||||
After serialization, this pass removes the `_entry_function` attribute from
|
||||
`tf.XlaCallModule`, and removes all the serialized stablehlo functions
|
||||
from the top-level module.
|
||||
### `-tfe-legalize-tfg`: Legalize from TFG to the TFE dialect
|
||||
|
|
|
|||
|
|
@ -64,6 +64,7 @@ def _run_lit_test(name, data, size, tags, driver, features, exec_properties):
|
|||
)
|
||||
|
||||
def glob_lit_tests(
|
||||
name = None,
|
||||
exclude = [],
|
||||
test_file_exts = _default_test_file_exts,
|
||||
default_size = _default_size,
|
||||
|
|
@ -78,6 +79,7 @@ def glob_lit_tests(
|
|||
"""Creates all plausible Lit tests (and their inputs) under this directory.
|
||||
|
||||
Args:
|
||||
name: str, name of the test_suite rule to generate for running all tests.
|
||||
exclude: [str], paths to exclude (for tests and inputs).
|
||||
test_file_exts: [str], extensions for files that are tests.
|
||||
default_size: str, the test size for targets not in "size_override".
|
||||
|
|
@ -103,7 +105,10 @@ def glob_lit_tests(
|
|||
|
||||
# Run tests individually such that errors can be attributed to a specific
|
||||
# failure.
|
||||
all_tests = []
|
||||
for curr_test in tests:
|
||||
all_tests.append(curr_test + ".test")
|
||||
|
||||
# Instantiate this test with updated parameters.
|
||||
_run_lit_test(
|
||||
name = curr_test + ".test",
|
||||
|
|
@ -114,3 +119,11 @@ def glob_lit_tests(
|
|||
features = features,
|
||||
exec_properties = exec_properties,
|
||||
)
|
||||
|
||||
# TODO: remove this check after making it a required param.
|
||||
if name:
|
||||
native.test_suite(
|
||||
name = name,
|
||||
tests = all_tests,
|
||||
tags = ["manual"],
|
||||
)
|
||||
|
|
|
|||
|
|
@ -342,8 +342,6 @@ cc_library(
|
|||
"@llvm-project//mlir:InferTypeOpInterface",
|
||||
"@llvm-project//mlir:LoopLikeInterface",
|
||||
"@llvm-project//mlir:Pass",
|
||||
"@llvm-project//mlir:QuantOps",
|
||||
"@llvm-project//mlir:SideEffectInterfaces",
|
||||
"@llvm-project//mlir:Support",
|
||||
"@llvm-project//mlir:TransformUtils",
|
||||
],
|
||||
|
|
@ -1235,6 +1233,8 @@ cc_library(
|
|||
"//tensorflow/compiler/mlir/lite/quantization:quantization_config",
|
||||
"//tensorflow/compiler/mlir/lite/quantization:quantization_passes",
|
||||
"//tensorflow/compiler/mlir/lite/quantization/tensorflow:tf_quantization_passes",
|
||||
"//tensorflow/compiler/mlir/lite/stablehlo:legalize_tf_xla_call_module_to_stablehlo_pass",
|
||||
"//tensorflow/compiler/mlir/lite/stablehlo:rename_entrypoint_to_main",
|
||||
"//tensorflow/compiler/mlir/tensorflow",
|
||||
"//tensorflow/compiler/mlir/tensorflow:tensorflow_ops",
|
||||
"//tensorflow/compiler/mlir/tensorflow:tensorflow_passes",
|
||||
|
|
|
|||
|
|
@ -83,7 +83,9 @@ struct PassConfig {
|
|||
// Whether to run the `GuaranteeAllFuncsOneUsePass` to ensure each function
|
||||
// has a single use.
|
||||
bool guarantee_all_funcs_one_use;
|
||||
// Whether to enable the hlo to tf conversion.
|
||||
// Whether to enable the hlo/stablehlo to tf conversion. This also supports
|
||||
// the case where a saved model contains both TF module and serialized
|
||||
// StableHLO module.
|
||||
bool enable_hlo_to_tf_conversion;
|
||||
// Whether to enable to use DynamicUpdateSlice op.
|
||||
bool enable_dynamic_update_slice;
|
||||
|
|
|
|||
|
|
@ -15,6 +15,9 @@ limitations under the License.
|
|||
|
||||
#include "tensorflow/compiler/mlir/lite/emit_error_reporter.h"
|
||||
|
||||
#include <cstdio>
|
||||
#include <vector>
|
||||
|
||||
namespace tflite {
|
||||
|
||||
int EmitErrorReporter::Report(const char* format, va_list args) {
|
||||
|
|
|
|||
|
|
@ -1,5 +1,7 @@
|
|||
load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_cloud")
|
||||
|
||||
# copybara:uncomment package(default_applicable_licenses = ["//tensorflow:license"])
|
||||
|
||||
cc_library(
|
||||
name = "outline_operations",
|
||||
srcs = ["outline_operations.cc"],
|
||||
|
|
|
|||
|
|
@ -1,5 +1,7 @@
|
|||
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
|
||||
|
||||
# copybara:uncomment package(default_applicable_licenses = ["//tensorflow:license"])
|
||||
|
||||
cc_library(
|
||||
name = "rematerializer",
|
||||
srcs = ["rematerializer.cc"],
|
||||
|
|
|
|||
|
|
@ -232,6 +232,7 @@ cc_library(
|
|||
"transforms/get_alternative_subgraph.cc",
|
||||
"transforms/pick_subgraphs.cc",
|
||||
"transforms/raise_target_subgraphs.cc",
|
||||
"transforms/tac_filter.cc",
|
||||
"transforms/target_annotation.cc",
|
||||
],
|
||||
hdrs = [
|
||||
|
|
@ -243,6 +244,7 @@ cc_library(
|
|||
":common",
|
||||
":cost_model",
|
||||
":device_transform",
|
||||
":tac_filter_cc_proto",
|
||||
":tac_importer_exporter",
|
||||
"//tensorflow/compiler/mlir/lite:tensorflow_lite",
|
||||
"//tensorflow/compiler/mlir/lite:tf_tfl_passes",
|
||||
|
|
@ -253,6 +255,7 @@ cc_library(
|
|||
"//tensorflow/compiler/mlir/tensorflow:tensorflow_analysis",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_protobuf//:protobuf_headers",
|
||||
"@llvm-project//llvm:Support",
|
||||
"@llvm-project//mlir:ArithDialect",
|
||||
"@llvm-project//mlir:FuncDialect",
|
||||
|
|
@ -380,3 +383,15 @@ py_library(
|
|||
"//tensorflow/compiler/mlir/lite/experimental/tac/py_wrapper:_pywrap_tac_wrapper",
|
||||
],
|
||||
)
|
||||
|
||||
proto_library(
|
||||
name = "tac_filter_proto",
|
||||
srcs = ["tac_filter.proto"],
|
||||
compatible_with = get_compatible_with_cloud(),
|
||||
)
|
||||
|
||||
cc_proto_library(
|
||||
name = "tac_filter_cc_proto",
|
||||
compatible_with = get_compatible_with_cloud(),
|
||||
deps = [":tac_filter_proto"],
|
||||
)
|
||||
|
|
|
|||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user