Merge branch 'tensorflow:master' into ica574-doc-contrib

Isaac Cilia Attard 2023-06-11 16:34:30 +02:00 committed by GitHub
commit b154e2c6fb
3290 changed files with 149761 additions and 58850 deletions

View File

@ -194,6 +194,9 @@ build:macos --apple_platform_type=macos
# gRPC on MacOS requires this #define
build:macos --copt=-DGRPC_BAZEL_BUILD
# Avoid hitting command line argument limit
build:macos --features=archive_param_file
# Settings for MacOS on ARM CPUs.
build:macos_arm64 --cpu=darwin_arm64
build:macos_arm64 --macos_minimum_os=11.0
@ -345,6 +348,7 @@ build:windows --host_copt=/D_USE_MATH_DEFINES
# Windows has a relatively short command line limit, which TF has begun to hit.
# See https://docs.bazel.build/versions/main/windows.html
build:windows --features=compiler_param_file
build:windows --features=archive_param_file
# Speed Windows compile times. Available in VS 16.4 (we are on 16.11). See
# https://groups.google.com/a/tensorflow.org/d/topic/build/SsW98Eo7l3o/discussion
@ -446,7 +450,6 @@ build:rbe --bes_backend=buildeventservice.googleapis.com
build:rbe --bes_results_url="https://source.cloud.google.com/results/invocations"
build:rbe --bes_timeout=600s
build:rbe --define=EXECUTOR=remote
build:rbe --flaky_test_attempts=3
build:rbe --jobs=800
build:rbe --remote_executor=grpcs://remotebuildexecution.googleapis.com
build:rbe --remote_timeout=3600
@ -627,7 +630,6 @@ try-import %workspace%/.bazelrc.user
# Here are bazelrc configs for release builds
build:release_base --config=v2
test:release_base --flaky_test_attempts=3
test:release_base --test_size_filters=small,medium
build:release_cpu_linux --config=release_base
@ -691,10 +693,10 @@ build:ubsan --linkopt -fsanitize=undefined
build:ubsan --linkopt -lubsan
# Disable TFRT integration for now unless --config=tfrt is specified.
build --deleted_packages=tensorflow/compiler/mlir/tfrt,tensorflow/compiler/mlir/tfrt/benchmarks,tensorflow/compiler/mlir/tfrt/jit/python_binding,tensorflow/compiler/mlir/tfrt/jit/transforms,tensorflow/compiler/mlir/tfrt/python_tests,tensorflow/compiler/mlir/tfrt/tests,tensorflow/compiler/mlir/tfrt/tests/ir,tensorflow/compiler/mlir/tfrt/tests/analysis,tensorflow/compiler/mlir/tfrt/tests/jit,tensorflow/compiler/mlir/tfrt/tests/lhlo_to_tfrt,tensorflow/compiler/mlir/tfrt/tests/lhlo_to_jitrt,tensorflow/compiler/mlir/tfrt/tests/tf_to_corert,tensorflow/compiler/mlir/tfrt/tests/tf_to_tfrt_data,tensorflow/compiler/mlir/tfrt/tests/saved_model,tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu,tensorflow/core/runtime_fallback,tensorflow/core/runtime_fallback/conversion,tensorflow/core/runtime_fallback/kernel,tensorflow/core/runtime_fallback/opdefs,tensorflow/core/runtime_fallback/runtime,tensorflow/core/runtime_fallback/util,tensorflow/core/tfrt/eager,tensorflow/core/tfrt/eager/backends/cpu,tensorflow/core/tfrt/eager/backends/gpu,tensorflow/core/tfrt/eager/core_runtime,tensorflow/core/tfrt/eager/cpp_tests/core_runtime,tensorflow/core/tfrt/gpu,tensorflow/core/tfrt/run_handler_thread_pool,tensorflow/core/tfrt/runtime,tensorflow/core/tfrt/saved_model,tensorflow/core/tfrt/graph_executor,tensorflow/core/tfrt/saved_model/tests,tensorflow/core/tfrt/tpu,tensorflow/core/tfrt/utils,tensorflow/core/tfrt/utils/debug
build --deleted_packages=tensorflow/compiler/mlir/tfrt,tensorflow/compiler/mlir/tfrt/benchmarks,tensorflow/compiler/mlir/tfrt/ir,tensorflow/compiler/mlir/tfrt/ir/mlrt,tensorflow/compiler/mlir/tfrt/jit/python_binding,tensorflow/compiler/mlir/tfrt/jit/transforms,tensorflow/compiler/mlir/tfrt/python_tests,tensorflow/compiler/mlir/tfrt/tests,tensorflow/compiler/mlir/tfrt/tests/mlrt,tensorflow/compiler/mlir/tfrt/tests/ir,tensorflow/compiler/mlir/tfrt/tests/analysis,tensorflow/compiler/mlir/tfrt/tests/jit,tensorflow/compiler/mlir/tfrt/tests/lhlo_to_tfrt,tensorflow/compiler/mlir/tfrt/tests/lhlo_to_jitrt,tensorflow/compiler/mlir/tfrt/tests/tf_to_corert,tensorflow/compiler/mlir/tfrt/tests/tf_to_tfrt_data,tensorflow/compiler/mlir/tfrt/tests/saved_model,tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu,tensorflow/compiler/mlir/tfrt/transforms/mlrt,tensorflow/core/runtime_fallback,tensorflow/core/runtime_fallback/conversion,tensorflow/core/runtime_fallback/kernel,tensorflow/core/runtime_fallback/opdefs,tensorflow/core/runtime_fallback/runtime,tensorflow/core/runtime_fallback/util,tensorflow/core/tfrt/mlrt,tensorflow/core/tfrt/mlrt/attribute,tensorflow/core/tfrt/mlrt/kernel,tensorflow/core/tfrt/mlrt/bytecode,tensorflow/core/tfrt/mlrt/interpreter,tensorflow/compiler/mlir/tfrt/translate/mlrt,tensorflow/compiler/mlir/tfrt/translate/mlrt/testdata,tensorflow/core/tfrt/gpu,tensorflow/core/tfrt/run_handler_thread_pool,tensorflow/core/tfrt/runtime,tensorflow/core/tfrt/saved_model,tensorflow/core/tfrt/graph_executor,tensorflow/core/tfrt/saved_model/tests,tensorflow/core/tfrt/tpu,tensorflow/core/tfrt/utils,tensorflow/core/tfrt/utils/debug,tensorflow/core/tfrt/saved_model/python,tensorflow/core/tfrt/graph_executor/python
# TODO(b/240450920): We are in the process of migrating JitRt backend to XLA
# and while we are doing this we can't keep it buildable/testable in OSS.
build:tfrt --deleted_packages=tensorflow/compiler/mlir/tfrt,tensorflow/compiler/mlir/tfrt/benchmarks,tensorflow/compiler/mlir/tfrt/jit/python_binding,tensorflow/compiler/mlir/tfrt/jit/transforms,tensorflow/compiler/mlir/tfrt/python_tests,tensorflow/compiler/mlir/tfrt/tests,tensorflow/compiler/mlir/tfrt/tests/ir,tensorflow/compiler/mlir/tfrt/tests/analysis,tensorflow/compiler/mlir/tfrt/tests/jit,tensorflow/compiler/mlir/tfrt/tests/lhlo_to_tfrt,tensorflow/compiler/mlir/tfrt/tests/lhlo_to_jitrt,tensorflow/compiler/mlir/tfrt/tests/tf_to_corert,tensorflow/compiler/mlir/tfrt/tests/tf_to_tfrt_data,tensorflow/compiler/mlir/tfrt/tests/saved_model,tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu,tensorflow/core/runtime_fallback,tensorflow/core/runtime_fallback/conversion,tensorflow/core/runtime_fallback/kernel,tensorflow/core/runtime_fallback/opdefs,tensorflow/core/runtime_fallback/runtime,tensorflow/core/runtime_fallback/util,tensorflow/core/tfrt/eager,tensorflow/core/tfrt/eager/backends/cpu,tensorflow/core/tfrt/eager/backends/gpu,tensorflow/core/tfrt/eager/core_runtime,tensorflow/core/tfrt/eager/cpp_tests/core_runtime,tensorflow/core/tfrt/gpu,tensorflow/core/tfrt/run_handler_thread_pool,tensorflow/core/tfrt/runtime,tensorflow/core/tfrt/saved_model,tensorflow/core/tfrt/graph_executor,tensorflow/core/tfrt/saved_model/tests,tensorflow/core/tfrt/tpu,tensorflow/core/tfrt/utils,tensorflow/core/tfrt/utils/debug
build:tfrt --deleted_packages=tensorflow/compiler/mlir/tfrt,tensorflow/compiler/mlir/tfrt/benchmarks,tensorflow/compiler/mlir/tfrt/ir,tensorflow/compiler/mlir/tfrt/ir/mlrt,tensorflow/compiler/mlir/tfrt/jit/python_binding,tensorflow/compiler/mlir/tfrt/jit/transforms,tensorflow/compiler/mlir/tfrt/python_tests,tensorflow/compiler/mlir/tfrt/tests,tensorflow/compiler/mlir/tfrt/tests/mlrt,tensorflow/compiler/mlir/tfrt/tests/ir,tensorflow/compiler/mlir/tfrt/tests/analysis,tensorflow/compiler/mlir/tfrt/tests/jit,tensorflow/compiler/mlir/tfrt/tests/lhlo_to_tfrt,tensorflow/compiler/mlir/tfrt/tests/lhlo_to_jitrt,tensorflow/compiler/mlir/tfrt/tests/tf_to_corert,tensorflow/compiler/mlir/tfrt/tests/tf_to_tfrt_data,tensorflow/compiler/mlir/tfrt/tests/saved_model,tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu,tensorflow/compiler/mlir/tfrt/transforms/mlrt,tensorflow/core/runtime_fallback,tensorflow/core/runtime_fallback/conversion,tensorflow/core/runtime_fallback/kernel,tensorflow/core/runtime_fallback/opdefs,tensorflow/core/runtime_fallback/runtime,tensorflow/core/runtime_fallback/util,tensorflow/core/tfrt/mlrt,tensorflow/core/tfrt/mlrt/attribute,tensorflow/core/tfrt/mlrt/kernel,tensorflow/core/tfrt/mlrt/bytecode,tensorflow/core/tfrt/mlrt/interpreter,tensorflow/compiler/mlir/tfrt/translate/mlrt,tensorflow/compiler/mlir/tfrt/translate/mlrt/testdata,tensorflow/core/tfrt/gpu,tensorflow/core/tfrt/run_handler_thread_pool,tensorflow/core/tfrt/runtime,tensorflow/core/tfrt/saved_model,tensorflow/core/tfrt/graph_executor,tensorflow/core/tfrt/saved_model/tests,tensorflow/core/tfrt/tpu,tensorflow/core/tfrt/utils,tensorflow/core/tfrt/utils/debug,tensorflow/core/tfrt/saved_model/python,tensorflow/core/tfrt/graph_executor/python
# TF Fuzztest config
try-import fuzztest.bazelrc

View File

@ -1,2 +1,2 @@
5.3.0
6.1.0
# NOTE: Update Bazel version in tensorflow/tools/ci_build/release/common.sh.oss

View File

@ -15,7 +15,7 @@
# A list of assignees
assignees:
- synandi
- sushreebarsa
- SuryanarayanaY
- tilakrayal
# A list of assignees for compiler folder

View File

@ -28,6 +28,7 @@ jobs:
runs-on: [self-hosted, linux, ARM64]
continue-on-error: ${{ matrix.experimental }}
strategy:
fail-fast: false
matrix:
pyver: ['3.8', '3.9', '3.10']
experimental: [false]

View File

@ -27,8 +27,9 @@ jobs:
if: github.repository == 'tensorflow/tensorflow' # Don't do this in forks
runs-on: [self-hosted, linux, ARM64]
strategy:
fail-fast: false
matrix:
pyver: ['3.10']
pyver: ['3.8', '3.9', '3.10', '3.11']
steps:
- name: Stop old running containers (if any)
shell: bash

View File

@ -92,6 +92,18 @@ jobs:
map sigbuild-r2.13-clang-python3.9 2.13-python3.9
map sigbuild-r2.13-clang-python3.10 2.13-python3.10
map sigbuild-r2.13-clang-python3.11 2.13-python3.11
# TF 2.14
map sigbuild-r2.14 2.14-python3.9
map sigbuild-r2.14-python3.8 2.14-python3.8
map sigbuild-r2.14-python3.9 2.14-python3.9
map sigbuild-r2.14-python3.10 2.14-python3.10
map sigbuild-r2.14-python3.11 2.14-python3.11
# TF 2.14 + Clang (containers are the same, but env vars in configs.bzl are different)
map sigbuild-r2.14-clang 2.14-python3.9
map sigbuild-r2.14-clang-python3.8 2.14-python3.8
map sigbuild-r2.14-clang-python3.9 2.14-python3.9
map sigbuild-r2.14-clang-python3.10 2.14-python3.10
map sigbuild-r2.14-clang-python3.11 2.14-python3.11
- name: Create Pull Request with changes
uses: peter-evans/create-pull-request@2b011faafdcbc9ceb11414d64d0573f37c774b04 # v4.2.3
with:

View File

@ -2,7 +2,7 @@
<img src="https://www.tensorflow.org/images/tf_logo_horizontal.png">
</div>
[![Python](https://img.shields.io/pypi/pyversions/tensorflow.svg?style=plastic)](https://badge.fury.io/py/tensorflow)
[![Python](https://img.shields.io/pypi/pyversions/tensorflow.svg)](https://badge.fury.io/py/tensorflow)
[![PyPI](https://badge.fury.io/py/tensorflow.svg)](https://badge.fury.io/py/tensorflow)
[![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.4724125.svg)](https://doi.org/10.5281/zenodo.4724125)
[![CII Best Practices](https://bestpractices.coreinfrastructure.org/projects/1486/badge)](https://bestpractices.coreinfrastructure.org/projects/1486)
@ -11,6 +11,8 @@
[![Fuzzing Status](https://oss-fuzz-build-logs.storage.googleapis.com/badges/tensorflow-py.svg)](https://bugs.chromium.org/p/oss-fuzz/issues/list?sort=-opened&can=1&q=proj:tensorflow-py)
[![OSSRank](https://shields.io/endpoint?url=https://ossrank.com/shield/44)](https://ossrank.com/p/44)
[![Contributor Covenant](https://img.shields.io/badge/Contributor%20Covenant-v1.4%20adopted-ff69b4.svg)](CODE_OF_CONDUCT.md)
[![TF Official Continuous](https://tensorflow.github.io/build/TF%20Official%20Continuous.svg)](https://tensorflow.github.io/build#TF%20Official%20Continuous)
[![TF Official Nightly](https://tensorflow.github.io/build/TF%20Official%20Nightly.svg)](https://tensorflow.github.io/build#TF%20Official%20Nightly)
**`Documentation`** |
------------------- |

View File

@ -16,6 +16,11 @@
2.13 may be used when it is necessary to determine if a value is
specifically a symbolic tensor.
* `tf.compat.v1.Session`
* `tf.compat.v1.Session.partial_run` and
`tf.compat.v1.Session.partial_run_setup` will be deprecated in the
next release.
# Known Caveats
* <CAVEATS REGARDING THE RELEASE (BUT NOT BREAKING CHANGES).>
@ -26,6 +31,15 @@
* <INSERT MAJOR FEATURE HERE, USING MARKDOWN SYNTAX>
* <IF RELEASE CONTAINS MULTIPLE FEATURES FROM SAME AREA, GROUP THEM TOGETHER>
* `tf.keras`
* `Model.compile` now supports `steps_per_execution='auto'` as a parameter,
allowing automatic tuning of steps per execution during `Model.fit`,
`Model.predict`, and `Model.evaluate` for a significant performance boost
(see the sketch after this list).
* Enable JIT-compiled i64-indexed kernels on GPU for large tensors with more
than 2**32 elements.
* Unary GPU kernels: Abs, Atanh, Acos, Acosh, Asin, Asinh, Atan, Cos,
Cosh, Sin, Sinh, Tan, Tanh.
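A minimal, hedged sketch of the `steps_per_execution='auto'` option mentioned in the `tf.keras` item above; the model and data below are purely illustrative:

```python
import tensorflow as tf

# Illustrative toy model and data.
model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
x = tf.random.normal([1024, 8])
y = tf.random.normal([1024, 1])

# 'auto' asks Keras to tune how many steps run per tf.function call.
model.compile(optimizer="adam", loss="mse", steps_per_execution="auto")
model.fit(x, y, epochs=2, batch_size=32)
```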
# Bug Fixes and Other Changes
* `tf.lite`
@ -34,6 +48,22 @@
* <IF A CHANGE CLOSES A GITHUB ISSUE, IT SHOULD BE DOCUMENTED HERE>
* <NOTES SHOULD BE GROUPED PER AREA>
* `tf.config.experimental.enable_tensor_float_32_execution`
* Disabling TensorFloat-32 execution now causes TPUs to use float32
precision for float32 matmuls and other ops. TPUs have always used
bfloat16 precision for certain ops, such as matmul, when those ops had
float32 inputs; calling
`tf.config.experimental.enable_tensor_float_32_execution(False)` now
makes TPUs use float32 precision for such ops instead of bfloat16
(see the sketch after this list).
* `tf.experimental.dtensor`
* API changes for Relayout. Added a new API, `dtensor.relayout_like`, for
relaying out a tensor to match the layout of another tensor.
* Added `dtensor.get_default_mesh` for retrieving the current default
mesh under the DTensor context (see the sketch after this list).
* TensorFlow Debugger (tfdbg) CLI: ncurses-based CLI for tfdbg v1 was removed.
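A hedged sketch of disabling TensorFloat-32 as described above; on TPU this now forces float32 precision for float32-input ops such as matmul:

```python
import tensorflow as tf

# Globally disable TensorFloat-32 execution. On TPU, float32-input ops such
# as matmul now run in float32 precision instead of bfloat16.
tf.config.experimental.enable_tensor_float_32_execution(False)

a = tf.random.normal([128, 128])
b = tf.random.normal([128, 128])
c = tf.matmul(a, b)  # computed in full float32 precision
```

And a hedged sketch of the new DTensor APIs; the single-device mesh is illustrative so the snippet runs anywhere, and exact signatures may differ between versions:

```python
import tensorflow as tf
from tensorflow.experimental import dtensor

# Illustrative single-device mesh; real programs would span many devices.
mesh = dtensor.create_mesh([("batch", 1)], devices=["CPU:0"])

with dtensor.default_mesh(mesh):
    print(dtensor.get_default_mesh())  # the mesh installed by the context above

layout = dtensor.Layout.replicated(mesh, rank=2)
a = dtensor.copy_to_mesh(tf.zeros([2, 2]), layout)
b = dtensor.copy_to_mesh(tf.ones([2, 2]), layout)

# relayout_like re-lays-out `b` to match the layout of `a`.
b_like_a = dtensor.relayout_like(b, a)
```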
# Thanks to our Contributors
This release contains contributions from many people at Google, as well as:
@ -185,6 +215,9 @@ This release contains contributions from many people at Google, as well as:
`dataset = dataset.shuffle(dataset.cardinality())`. This will load the
full dataset into memory so that it can be shuffled, so make sure to
only use this with datasets of filenames or other small datasets.
* Added a new `tf.data.experimental.pad_to_cardinality` transformation
which pads a dataset with zero elements up to a specified cardinality.
This is useful for avoiding partial batches while not dropping any data
(see the sketch below).
* `tf.math`
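A hedged sketch of the `pad_to_cardinality` transformation described above; the dict-structured elements follow the public docs, and details may vary by version:

```python
import tensorflow as tf

# Elements are dicts; pad_to_cardinality zero-pads up to the requested
# cardinality and adds a boolean 'valid' feature marking real elements.
ds = tf.data.Dataset.from_tensor_slices({"a": [1, 2]})
ds = ds.apply(tf.data.experimental.pad_to_cardinality(3))
print(list(ds.as_numpy_iterator()))
# e.g. [{'a': 1, 'valid': True}, {'a': 2, 'valid': True}, {'a': 0, 'valid': False}]
```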
@ -243,6 +276,8 @@ This release contains contributions from many people at Google, as well as:
* `tf.lite`:
* Add UINT32 support to tfl.pack
* Add INT64 support to tfl.range
* Add UINT32 support to tfl.concatenation
## Thanks to our Contributors

17
ci/README.md Normal file
View File

@ -0,0 +1,17 @@
# TensorFlow continuous integration
> **Warning** This folder is still under construction. It is part of an ongoing
> effort to improve the structure of CI and build related files within the
> TensorFlow repo. This warning will be removed when the contents of this
> directory are stable and appropriate documentation around its usage is in
> place.
Maintainer: TensorFlow DevInfra
********************************************************************************
The CI folder contains the configuration files and scripts used to build, test,
and deploy TensorFlow. This folder is typically used by continuous integration
(CI) tools to build and test TensorFlow whenever there is a change to the
code. This folder is broken into subfolders that represent the level of support
and ownership of the files contained within.

17
ci/devinfra/README.md Normal file
View File

@ -0,0 +1,17 @@
# DevInfra CI Directory
> **Warning** This folder is still under construction. It is part of an ongoing
> effort to improve the structure of CI and build related files within the
> TensorFlow repo. This warning will be removed when the contents of this
> directory are stable and appropriate documentation around its usage is in
> place.
Maintainer: TensorFlow DevInfra
Issue Reporting: File an issue against this repo and tag
[@devinfra](https://github.com/orgs/tensorflow/teams/devinfra)
********************************************************************************
A directory for build and CI related scripts and jobs managed by the TensorFlow
DevInfra team but not part of the official build, test, or release process.

17
ci/official/README.md Normal file
View File

@ -0,0 +1,17 @@
# Official CI Directory
> **Warning** This folder is still under construction. It is part of an ongoing
> effort to improve the structure of CI and build related files within the
> TensorFlow repo. This warning will be removed when the contents of this
> directory are stable and appropriate documentation around its usage is in
> place.
Maintainer: TensorFlow and TensorFlow DevInfra
Issue Reporting: File an issue against this repo and tag
[@devinfra](https://github.com/orgs/tensorflow/teams/devinfra)
********************************************************************************
A directory for build and CI related scripts and jobs that are used and
monitored as part of the official TensorFlow build, test, and release process.

View File

@ -964,7 +964,6 @@ def set_other_cuda_vars(environ_cp):
def system_specific_test_config(environ_cp):
"""Add default build and test flags required for TF tests to bazelrc."""
write_to_bazelrc('test --flaky_test_attempts=3')
write_to_bazelrc('test --test_size_filters=small,medium')
# Each instance of --test_tag_filters or --build_tag_filters overrides all

View File

@ -109,6 +109,7 @@ PACKAGE_STATIC_DEPS = [
"@local_execution_config_platform//:__subpackages__",
"@mkl_dnn_acl_compatible//:__subpackages__",
"@mkl_dnn_v1//:__subpackages__",
"@ml_dtypes//:__subpackages__",
"@nccl_archive//:__subpackages__",
"@nvtx_archive//:__subpackages__",
"@org_sqlite//:__subpackages__",
@ -1036,7 +1037,13 @@ package_group(
],
)
package_group(name = "ndarray_tensor_allow_list")
package_group(
name = "ndarray_tensor_allow_list",
packages = [
"//third_party/py/courier/...",
"//third_party/py/tensorfn/...",
],
)
# Packages that use private types symbols, until they are exported.
# TODO(b/154650521) Remove.

View File

@ -129,6 +129,7 @@ cc_library(
# TODO: Only include tf_tstring_hdrs. Don't expose the implementation of TF_TString to API
# users.
":tf_tstring",
"//tensorflow/core:protos_all_cc",
],
)
@ -171,6 +172,7 @@ tf_cuda_library(
":tf_buffer_internal",
":tf_status_internal",
":tf_tensor_internal",
"//tensorflow/core:protos_all_cc",
],
)
@ -238,6 +240,7 @@ tf_cuda_library(
":tf_status_internal",
":tf_tensor_internal",
":tf_tstring",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/platform:tstring",
"//tensorflow/tsl/c:tsl_status",
] + select({
@ -881,7 +884,7 @@ tf_cc_test(
tf_cuda_cc_test(
name = "c_api_test",
size = "small",
size = "medium",
srcs = ["c_api_test.cc"],
data = [
":test_op1.so",
@ -968,7 +971,7 @@ tf_cc_test(
tf_cc_test(
name = "c_api_function_test",
size = "small",
size = "medium",
srcs = ["c_api_function_test.cc"],
deps = [
":c_api",
@ -985,7 +988,7 @@ tf_cc_test(
tf_cc_test(
name = "while_loop_test",
size = "small",
size = "medium",
srcs = ["while_loop_test.cc"],
deps = [
":c_api",
@ -1013,7 +1016,7 @@ tf_kernel_library(
tf_cuda_cc_test(
name = "env_test",
size = "small",
size = "medium",
srcs = ["env_test.cc"],
linkopts = select({
"//tensorflow:macos": ["-headerpad_max_install_names"],
@ -1032,7 +1035,7 @@ tf_cuda_cc_test(
tf_cuda_cc_test(
name = "kernels_test",
size = "small",
size = "medium",
srcs = ["kernels_test.cc"],
linkopts = select({
"//tensorflow:macos": ["-headerpad_max_install_names"],
@ -1059,7 +1062,7 @@ tf_cuda_cc_test(
tf_cc_test(
name = "ops_test",
size = "small",
size = "medium",
srcs = ["ops_test.cc"],
linkopts = select({
"//conditions:default": [],

View File

@ -142,7 +142,7 @@ struct TF_ImportGraphDefOptions {
// Backing memory for TensorId fields in opts.
// TODO(skyewm): it'd be better if ImportGraphDefOptions owned this.
std::list<tensorflow::string> tensor_id_data;
std::vector<tensorflow::string> tensor_id_data;
};
struct TF_ImportGraphDefResults {
@ -152,7 +152,7 @@ struct TF_ImportGraphDefResults {
std::vector<int> missing_unused_key_indexes;
// Backing memory for missing_unused_key_names values.
std::list<tensorflow::string> missing_unused_key_names_data;
std::vector<tensorflow::string> missing_unused_key_names_data;
};
struct TF_DeviceList {

View File

@ -26,7 +26,12 @@ limitations under the License.
#define TF_CAPI_EXPORT __declspec(dllimport)
#endif // TF_COMPILE_LIBRARY
#else
#ifdef TF_CAPI_WEAK
#define TF_CAPI_EXPORT \
__attribute__((visibility("default"))) __attribute((weak))
#else
#define TF_CAPI_EXPORT __attribute__((visibility("default")))
#endif // TF_CAPI_WEAK
#endif // _WIN32
#endif // SWIG

View File

@ -8,7 +8,7 @@ load(
"tf_cuda_cc_test",
"tf_cuda_library",
)
load("//tensorflow:tensorflow.default.bzl", "cc_header_only_library", "filegroup", "internal_tfrt_deps")
load("//tensorflow:tensorflow.default.bzl", "cc_header_only_library", "filegroup")
load(
"//tensorflow/core/platform:build_config_root.bzl",
"tf_cuda_tests_tags",
@ -95,7 +95,7 @@ tf_cuda_library(
"//tensorflow/core/distributed_runtime/rpc/eager:grpc_eager_client",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings:str_format",
] + internal_tfrt_deps(),
],
alwayslink = 1,
)
@ -636,7 +636,7 @@ tf_cuda_library(
tf_cuda_cc_test(
name = "c_api_test",
size = "small",
size = "medium",
srcs = [
"c_api_debug_test.cc",
"c_api_test.cc",
@ -653,7 +653,6 @@ tf_cuda_cc_test(
":c_api_test_util",
":tfe_op_internal",
":tfe_tensorhandle_internal",
"@com_google_absl//absl/strings",
"//tensorflow/c:c_test_util",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
@ -663,10 +662,7 @@ tf_cuda_cc_test(
"//tensorflow/core/common_runtime/eager:eager_operation",
"//tensorflow/core/common_runtime/eager:tensor_handle",
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
# copybara:uncomment_begin
# "//tensorflow/core/tfrt/eager:c_api_tfrt",
# "@tf_runtime//backends/cpu:tf_ops_alwayslink",
# copybara:uncomment_end
"@com_google_absl//absl/strings",
],
)
@ -693,7 +689,7 @@ tf_cuda_library(
tf_cuda_cc_test(
name = "c_api_remote_test",
size = "small",
size = "medium",
srcs = [
"c_api_remote_test.cc",
],
@ -725,7 +721,7 @@ tf_cuda_cc_test(
tf_cuda_cc_test(
name = "c_api_remote_function_test",
size = "small",
size = "medium",
srcs = [
"c_api_remote_function_test.cc",
],
@ -776,7 +772,7 @@ tf_cuda_cc_test(
tf_cuda_cc_test(
name = "c_api_cluster_test",
size = "small",
size = "medium",
srcs = [
"c_api_cluster_test.cc",
],
@ -1014,7 +1010,7 @@ cc_library(
tf_cc_test(
name = "custom_device_test",
size = "small",
size = "medium",
srcs = [
"custom_device_test.cc",
],

View File

@ -64,13 +64,6 @@ limitations under the License.
#include "tensorflow/core/protobuf/error_codes.pb.h"
#include "tensorflow/core/public/version.h"
// "tensorflow/core/platform/platform.h" must be included first before using
// PLATFORM_GOOGLE, IS_MOBILE_PLATFORM, etc.
#if defined(PLATFORM_GOOGLE) && !defined(LIBTPU_ON_GCE) && \
!defined(PLATFORM_FUCHSIA)
#include "tensorflow/core/tfrt/eager/c_api_tfrt.h"
#endif // PLATFORM_GOOGLE && !LIBTPU_ON_GCE && !PLATFORM_FUCHSIA
#if !defined(IS_MOBILE_PLATFORM)
#include "tensorflow/core/common_runtime/eager/context_distributed_manager.h"
#endif // !IS_MOBILE_PLATFORM
@ -117,18 +110,8 @@ void TFE_DeleteContextOptions(TFE_ContextOptions* options) { delete options; }
TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
if (opts->use_tfrt) {
#if defined(PLATFORM_GOOGLE) && !defined(LIBTPU_ON_GCE) && \
!defined(PLATFORM_FUCHSIA)
tfrt::tf::ContextInterface* tfrt_context = new tfrt::tf::ContextInterface(
opts->session_options.options,
static_cast<tensorflow::ContextDevicePlacementPolicy>(
opts->device_placement_policy),
opts->async);
return tensorflow::wrap(tfrt_context);
#else
status->status = tensorflow::errors::Unimplemented("TFRT is not supported");
return nullptr;
#endif // PLATFORM_GOOGLE && !LIBTPU_ON_GCE && !PLATFORM_FUCHSIA
}
std::vector<std::unique_ptr<tensorflow::Device>> devices;
status->status = tensorflow::DeviceFactory::AddDevices(

View File

@ -434,7 +434,7 @@ class FunctionErrorInjectionPass : public tensorflow::FunctionOptimizationPass {
tensorflow::Status Run(const std::string& function_name,
const tensorflow::DeviceSet& device_set,
const tensorflow::ConfigProto& config_proto,
absl::string_view xla_compile_device_type,
const FunctionOptions& function_options,
std::unique_ptr<tensorflow::Graph>* graph,
tensorflow::FunctionLibraryDefinition* flib_def,
std::vector<std::string>* control_ret_node_names,

View File

@ -47,10 +47,6 @@ limitations under the License.
#include "tensorflow/core/protobuf/rewriter_config.pb.h"
#include "tensorflow/core/protobuf/tensorflow_server.pb.h"
#ifdef PLATFORM_GOOGLE
#include "tensorflow/core/tfrt/eager/c_api_tfrt.h"
#endif
using tensorflow::string;
namespace {
@ -1262,16 +1258,6 @@ TEST(CAPI, RunAddFunctionWithGrappler) {
RunAddFunction(/*use_tfrt=*/false, /*enable_grappler=*/true);
}
#ifdef PLATFORM_GOOGLE
TEST(CAPI, RunAddFunction_TFRT) {
RunAddFunction(/*use_tfrt=*/true, /*enable_grappler=*/false);
}
TEST(CAPI, RunAddFunctionWithGrappler_TFRT) {
RunAddFunction(/*use_tfrt=*/true, /*enable_grappler=*/true);
}
#endif
void BM_ExecuteFunction(::testing::benchmark::State& state) {
const int async = state.range(0);
state.SetLabel(async ? "ExecuteFunctionAsync" : "ExecuteFunction");
@ -1802,23 +1788,9 @@ void TestOpAddAttrs(bool use_tfrt) {
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
tensorflow::AttrValueMap attr_values;
if (use_tfrt) {
#ifdef PLATFORM_GOOGLE
auto* op = tensorflow::down_cast<tfrt::tf::OperationInterface*>(
tensorflow::unwrap(copy_op));
auto* tfrt_op_attrs =
tensorflow::down_cast<const tfrt::tf::OpAttrsInterface*>(
op->GetOpAttrs());
tensorflow::DataType result;
tfrt_op_attrs->GetType("dtype", &result);
EXPECT_EQ(tensorflow::DT_FLOAT, result);
tfrt_op_attrs->GetFallbackAttrs()->FillAttrValueMap(&attr_values);
#endif
} else {
tensorflow::EagerOperation* op =
tensorflow::OperationFromInterface(tensorflow::unwrap(copy_op));
op->Attrs().FillAttrValueMap(&attr_values);
}
tensorflow::EagerOperation* op =
tensorflow::OperationFromInterface(tensorflow::unwrap(copy_op));
op->Attrs().FillAttrValueMap(&attr_values);
EXPECT_EQ(tensorflow::DT_FLOAT, attr_values.find("dtype")->second.type());
TF_DeleteStatus(status);
@ -1829,11 +1801,6 @@ void TestOpAddAttrs(bool use_tfrt) {
TEST(CAPI, TestTFE_OpAddAttrs) { TestOpAddAttrs(/*use_tfrt=*/false); }
#ifdef PLATFORM_GOOGLE
TEST(CAPI, TestTFE_OpAddAttrs_TFRT) { TestOpAddAttrs(/*use_tfrt=*/true); }
#endif
TEST(CAPI, TestTFE_OpAttrsSerialize) {
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();

View File

@ -1006,18 +1006,10 @@ TEST_P(UnifiedCAPI, TF_ExecutionContextGetTFEContextFromFunctionContextRaises) {
// The above tests are run for a combination of:
// - graphdef and MLIR tracing engine
// - Using TFRT as an execution runtime (true == enable TFRT)
#ifdef PLATFORM_GOOGLE
INSTANTIATE_TEST_SUITE_P(Tracing, UnifiedCAPI,
::testing::Combine(::testing::Values("graphdef",
"mlir"),
::testing::Values(true, false)));
#else
INSTANTIATE_TEST_SUITE_P(Tracing, UnifiedCAPI,
::testing::Combine(::testing::Values("graphdef",
"mlir"),
::testing::Values(false)));
#endif
} // namespace
} // namespace tensorflow

View File

@ -188,18 +188,10 @@ TEST_P(UnifiedAPI, TestPartialShapeTracing) {
ASSERT_EQ(-1, shape.dim_size(1));
}
#ifdef PLATFORM_GOOGLE
INSTANTIATE_TEST_SUITE_P(
UnifiedCppAPI, UnifiedAPI,
::testing::Combine(::testing::Values("graphdef", "mlir"),
/*tfrt*/ ::testing::Values(true, false),
/*use_function*/ ::testing::Values(true, false)));
#else
INSTANTIATE_TEST_SUITE_P(
UnifiedCppAPI, UnifiedAPI,
::testing::Combine(::testing::Values("graphdef", "mlir"),
/*tfrt*/ ::testing::Values(false),
/*use_function*/ ::testing::Values(true, false)));
#endif
} // namespace
} // namespace tensorflow

View File

@ -125,19 +125,12 @@ TEST_P(CustomGradientTest, ExpWithPassThroughGrad) {
result_tensor = nullptr;
}
#ifdef PLATFORM_GOOGLE
INSTANTIATE_TEST_SUITE_P(
CustomGradientTest, CustomGradientTest,
::testing::Combine(::testing::Values("graphdef", "mlir"),
/*tfrt*/ ::testing::Values(true, false),
/*executing_eagerly*/ ::testing::Values(true, false)));
#else
INSTANTIATE_TEST_SUITE_P(
CustomGradientTest, CustomGradientTest,
::testing::Combine(::testing::Values("graphdef", "mlir"),
/*tfrt*/ ::testing::Values(false),
/*executing_eagerly*/ ::testing::Values(true, false)));
#endif
} // namespace
} // namespace internal
} // namespace gradients

View File

@ -24,6 +24,7 @@ cc_library(
"//tensorflow/compiler/xla/pjrt:pjrt_c_api_client",
"//tensorflow/compiler/xla/pjrt:pjrt_client",
"//tensorflow/compiler/xla/pjrt/c:pjrt_c_api_hdrs",
"//tensorflow/compiler/xla/stream_executor/tpu:tpu_initializer_helper",
"//tensorflow/core:framework",
"//tensorflow/core/common_runtime/next_pluggable_device:plugin_resource",
"//tensorflow/core/platform:status",
@ -32,6 +33,8 @@ cc_library(
"//tensorflow/tsl/distributed_runtime/coordination:coordination_service_agent",
"//tensorflow/tsl/platform:errors",
"//tensorflow/tsl/platform:statusor",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
],
)

View File

@ -21,6 +21,9 @@ limitations under the License.
#include <utility>
#include <vector>
#include "absl/status/status.h"
#include "absl/strings/ascii.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/c/kernels_experimental.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/c/tf_status_internal.h"
@ -30,6 +33,7 @@ limitations under the License.
#include "tensorflow/compiler/jit/variable_info_util.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_c_api_client.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/compiler/xla/stream_executor/tpu/tpu_initializer_helper.h" // NOLINT(unused-includes): required for tensorflow::tpu::FindAndLoadTpuLibrary
#include "tensorflow/core/common_runtime/next_pluggable_device/plugin_resource.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/platform/status.h"
@ -110,9 +114,9 @@ TF_VariableInfo* TF_CreateVariableInfoFromContext(TF_OpKernelContext* ctx,
const tensorflow::Tensor& arg_tensor = cc_ctx->input(index);
tsl::Status cc_status;
if (arg_tensor.dtype() != tensorflow::DT_RESOURCE) {
cc_status = tsl::errors::InvalidArgument(
"Trying to obtain resource handle from Input[", index,
"], which is not type DT_RESOURCE.");
cc_status = absl::InvalidArgumentError(
absl::StrCat("Trying to obtain resource handle from Input[", index,
"], which is not type DT_RESOURCE."));
tsl::Set_TF_Status_from_Status(status, cc_status);
return nullptr;
}
@ -140,12 +144,12 @@ void TF_AllocateTempForVariableInfo(TF_OpKernelContext* ctx,
auto* cc_ctx = reinterpret_cast<tensorflow::OpKernelContext*>(ctx);
tsl::Status cc_status;
if (var_info == nullptr) {
cc_status = tsl::errors::InvalidArgument("TF_VariableInfo is NULL.");
cc_status = absl::InvalidArgumentError("TF_VariableInfo is NULL.");
tsl::Set_TF_Status_from_Status(status, cc_status);
return;
}
if (var_info->var_info.var() == nullptr) {
cc_status = tsl::errors::InvalidArgument(
cc_status = absl::InvalidArgumentError(
"VariableInfo does not track a resource variable.");
tsl::Set_TF_Status_from_Status(status, cc_status);
return;
@ -161,12 +165,12 @@ TF_Tensor* TF_GetTensorFromVariableInfo(TF_VariableInfo* var_info,
TF_Status* status) {
tsl::Status cc_status;
if (var_info == nullptr) {
cc_status = tsl::errors::InvalidArgument("TF_VariableInfo is NULL.");
cc_status = absl::InvalidArgumentError("TF_VariableInfo is NULL.");
tsl::Set_TF_Status_from_Status(status, cc_status);
return nullptr;
}
if (var_info->var_info.var() == nullptr) {
cc_status = tsl::errors::InvalidArgument(
cc_status = absl::InvalidArgumentError(
"VariableInfo does not track a resource variable.");
tsl::Set_TF_Status_from_Status(status, cc_status);
return nullptr;
@ -239,6 +243,18 @@ void TF_CoordinationServiceDeleteKeyValue(const char* key,
void TF_CreateAndSetPjRtCApiClient(const char* device_type, TF_Status* status,
PJRT_NamedValue* create_options,
int num_options) {
// TODO(b/262050449): use a common plugin discovery mechanism, rather than
// having TPU-specific code here.
#if !defined(PLATFORM_GOOGLE) || defined(LIBTPU_STATIC)
if (absl::AsciiStrToLower(device_type) == "tpu") {
// TODO(b/261484192): handle device specific initialization.
tsl::Status tpu_status = tensorflow::tpu::FindAndLoadTpuLibrary();
if (!tpu_status.ok()) {
tensorflow::Set_TF_Status_from_Status(status, tpu_status);
return;
}
}
#endif
tsl::StatusOr<std::unique_ptr<xla::PjRtClient>> pjrt_client =
xla::GetCApiClient(device_type, pjrt::ConvertFromPjRtNamedValueList(
create_options, num_options));
@ -263,8 +279,9 @@ PJRT_Client* TF_GetPjRtCClient(const char* device_type, TF_Status* status) {
tensorflow::down_cast<xla::PjRtCApiClient*>(*pjrt_client);
if (pjrt_c_api_client == nullptr) {
tensorflow::Set_TF_Status_from_Status(
status, tsl::errors::Internal("PjRtClient for ", device_type,
" is not type PjRtCApiClient"));
status,
absl::InternalError(absl::StrCat("PjRtClient for ", device_type,
" is not type PjRtCApiClient")));
return nullptr;
}
TF_SetStatus(status, TF_OK, "");
@ -282,8 +299,7 @@ PJRT_Buffer* TF_GetPjRtCBuffer(TF_Tensor* c_tensor, TF_Status* status) {
tensorflow::AsyncValueTensor::FromTensor(&tensor);
if (av_tensor == nullptr || av_tensor->GetBuffer() == nullptr) {
tensorflow::Set_TF_Status_from_Status(
status,
tsl::errors::Internal("Input tensor does not have PjRtBuffer."));
status, absl::InternalError("Input tensor does not have PjRtBuffer."));
return nullptr;
}
auto* c_api_buffer =
@ -291,7 +307,7 @@ PJRT_Buffer* TF_GetPjRtCBuffer(TF_Tensor* c_tensor, TF_Status* status) {
if (c_api_buffer == nullptr) {
tensorflow::Set_TF_Status_from_Status(
status,
tsl::errors::Internal(
absl::InternalError(
"The PjRtBuffer in the tensor is not type PjRtCApiBuffer."));
return nullptr;
}
@ -317,8 +333,9 @@ void TF_CreatePjRtBuffer(TF_Tensor* c_tensor, PJRT_Buffer* c_buffer,
tensorflow::down_cast<xla::PjRtCApiClient*>(*pjrt_client);
if (pjrt_c_api_client == nullptr) {
tensorflow::Set_TF_Status_from_Status(
status, tsl::errors::Internal("PjRtClient for ", device_type,
" is not type PjRtCApiClient"));
status,
absl::InternalError(absl::StrCat("PjRtClient for ", device_type,
" is not type PjRtCApiClient")));
return;
}
tensorflow::AsyncValueTensor* av_tensor =
@ -326,7 +343,7 @@ void TF_CreatePjRtBuffer(TF_Tensor* c_tensor, PJRT_Buffer* c_buffer,
if (av_tensor == nullptr) {
tensorflow::Set_TF_Status_from_Status(
status,
tsl::errors::Internal(
absl::InternalError(
"The tensor to set PjRtBuffer is not an AsyncValueTensor."));
return;
}

View File

@ -28,6 +28,8 @@ cc_library(
"//tensorflow/cc/saved_model:constants",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
],
)
@ -99,6 +101,8 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/lib/llvm_rtti",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
],
)

View File

@ -17,6 +17,8 @@ limitations under the License.
#include <string>
#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/c/eager/immediate_execution_context.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/tensor_interface.h"
@ -37,8 +39,8 @@ Status Asset::Create(ImmediateExecutionContext* ctx,
io::JoinPath(saved_model_dir, kSavedModelAssetsDirectory, asset_filename);
AbstractTensorPtr tensor(ctx->CreateStringScalar(abs_path));
if (tensor.get() == nullptr) {
return errors::Internal(
"Failed to create scalar string tensor for Asset at path ", abs_path);
return absl::InternalError(absl::StrCat(
"Failed to create scalar string tensor for Asset at path ", abs_path));
}
ImmediateTensorHandlePtr handle(ctx->CreateLocalHandle(tensor.get()));

View File

@ -20,6 +20,8 @@ limitations under the License.
#include <string>
#include <utility>
#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "absl/types/span.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/immediate_execution_context.h"
@ -60,15 +62,15 @@ Status AssertAllCreateResourceFunctionsHaveNoCaptures(
const TFConcreteFunctionRevivalState* create_resource_fn =
resource.create_resource;
if (create_resource_fn == nullptr) {
return errors::FailedPrecondition(
"Resource at node ", node_id,
" did not have a create_resource() function");
return absl::FailedPreconditionError(
absl::StrCat("Resource at node ", node_id,
" did not have a create_resource() function"));
}
const SavedConcreteFunction* saved_create_resource_fn =
create_resource_fn->saved_concrete_func;
if (!saved_create_resource_fn->bound_inputs().empty()) {
// TODO(b/124045874): Support loading resource functions via a top sort
return errors::Unimplemented(
return absl::UnimplementedError(
"Create Resource functions with captures are currently unsupported.");
}
}
@ -86,9 +88,9 @@ Status TensorHandleFromNode(int node_id, const SavedObjectGraph& obj_graph,
case SavedObject::kVariable: {
const auto& variables_iter = objects.variables.find(node_id);
if (variables_iter == objects.variables.end()) {
return errors::FailedPrecondition(
return absl::FailedPreconditionError(absl::StrCat(
"Tried to convert node id ", node_id,
" of type variable to tensor but the variable wasn't initialized");
" of type variable to tensor but the variable wasn't initialized"));
}
*handle = variables_iter->second->handle();
return Status();
@ -96,9 +98,10 @@ Status TensorHandleFromNode(int node_id, const SavedObjectGraph& obj_graph,
case SavedObject::kConstant: {
const auto& constants_iter = objects.constants.find(node_id);
if (constants_iter == objects.constants.end()) {
return errors::FailedPrecondition("Tried to convert node id ", node_id,
" of type constant to tensor but the "
"constant wasn't initialized");
return absl::FailedPreconditionError(
absl::StrCat("Tried to convert node id ", node_id,
" of type constant to tensor but the "
"constant wasn't initialized"));
}
*handle = constants_iter->second->handle();
return Status();
@ -106,9 +109,9 @@ Status TensorHandleFromNode(int node_id, const SavedObjectGraph& obj_graph,
case SavedObject::kAsset: {
const auto& assets_iter = objects.assets.find(node_id);
if (assets_iter == objects.assets.end()) {
return errors::FailedPrecondition(
return absl::FailedPreconditionError(absl::StrCat(
"Tried to convert node id ", node_id,
" of type asset to tensor but the asset wasn't initialized");
" of type asset to tensor but the asset wasn't initialized"));
}
*handle = assets_iter->second->handle();
return Status();
@ -116,24 +119,24 @@ Status TensorHandleFromNode(int node_id, const SavedObjectGraph& obj_graph,
case SavedObject::kResource: {
const auto& resource_iter = objects.restored_resources.find(node_id);
if (resource_iter == objects.restored_resources.end()) {
return errors::FailedPrecondition(
return absl::FailedPreconditionError(absl::StrCat(
"Tried to convert node id ", node_id,
" of type Resource to tensor but the Resource wasn't initialized");
" of type Resource to tensor but the Resource wasn't initialized"));
}
const RestoredResourceRevivalState& resource = resource_iter->second;
if (resource.resource_handle == nullptr) {
return errors::FailedPrecondition(
return absl::FailedPreconditionError(absl::StrCat(
"Resource with node id ", node_id,
" should have its resource_handle created, but was nullptr.");
" should have its resource_handle created, but was nullptr."));
}
*handle = resource.resource_handle.get();
return Status();
}
default: {
return errors::FailedPrecondition(
return absl::FailedPreconditionError(absl::StrCat(
"Only objects of type variable, constant, asset, and resources have "
"capturable tensorhandles. Encountered object of kind ",
node.kind_case(), " at node id: ", node_id);
node.kind_case(), " at node id: ", node_id));
}
}
}
@ -167,35 +170,35 @@ Status SignatureDefArgsFromInputs(
// (args, kwargs), where args is an empty tuple, and kwargs is a dictionary of
// string keys to TensorSpecs.
if (!canonicalized_input_signature.has_tuple_value()) {
return errors::FailedPrecondition(
return absl::FailedPreconditionError(absl::StrCat(
"SignatureDefFunction's canonicalized_input_signature should be "
"of form tuple(tuple(), dict()), but was instead: \n",
canonicalized_input_signature.DebugString());
canonicalized_input_signature.DebugString()));
}
const TupleValue& args_kwargs_tuple =
canonicalized_input_signature.tuple_value();
if (args_kwargs_tuple.values_size() != 2) {
return errors::FailedPrecondition(
return absl::FailedPreconditionError(absl::StrCat(
"SignatureDefFunction's canonicalized_input_signature should be "
"a tuple of two elements (args, kwargs), but was instead: \n",
args_kwargs_tuple.DebugString());
args_kwargs_tuple.DebugString()));
}
const StructuredValue& args = args_kwargs_tuple.values(0);
if (!args.has_tuple_value() || !args.tuple_value().values().empty()) {
return errors::FailedPrecondition(
return absl::FailedPreconditionError(absl::StrCat(
"SignatureDefFunction's canonicalized_input_signature's args"
"should be an empty tuple, but instead got: \n",
args.DebugString());
args.DebugString()));
}
const StructuredValue& kwargs = args_kwargs_tuple.values(1);
if (!kwargs.has_dict_value()) {
return errors::FailedPrecondition(
return absl::FailedPreconditionError(absl::StrCat(
"SignatureDefFunction's canonicalized_input_signature's kwargs"
"should be a dictionary, but instead got: \n",
kwargs.DebugString());
kwargs.DebugString()));
}
const DictValue& kwargs_dict = kwargs.dict_value();
@ -206,10 +209,10 @@ Status SignatureDefArgsFromInputs(
const std::string& key = key_value.first;
const StructuredValue& value = key_value.second;
if (!value.has_tensor_spec_value()) {
return errors::FailedPrecondition(
return absl::FailedPreconditionError(absl::StrCat(
"SignatureDefFunction's canonicalized_input_signature's kwargs"
"dictionary contained a non-tensorspec value for key-value pair: \n",
"Key: ", key, "Value: \n", value.DebugString());
"Key: ", key, "Value: \n", value.DebugString()));
}
result[key] = &value.tensor_spec_value();
}
@ -226,10 +229,10 @@ Status SignatureDefArgsFromInputs(
Status SignatureDefReturnsFromOutputs(const StructuredValue& output_signature,
std::vector<SignatureDefParam>* out) {
if (!output_signature.has_dict_value()) {
return errors::FailedPrecondition(
return absl::FailedPreconditionError(absl::StrCat(
"SignatureDefFunction's output_signature must be a dictionary, but "
"instead got: ",
output_signature.DebugString());
output_signature.DebugString()));
}
const DictValue& output_dict = output_signature.dict_value();
@ -240,10 +243,10 @@ Status SignatureDefReturnsFromOutputs(const StructuredValue& output_signature,
const std::string& key = key_value.first;
const StructuredValue& value = key_value.second;
if (!value.has_tensor_spec_value()) {
return errors::FailedPrecondition(
return absl::FailedPreconditionError(absl::StrCat(
"SignatureDefFunction's output_signature dictionary contained a "
"non-tensorspec value for key-value pair: \n",
"Key: ", key, "Value: \n", value.DebugString());
"Key: ", key, "Value: \n", value.DebugString()));
}
result[key] = &value.tensor_spec_value();
}
@ -337,7 +340,7 @@ Status InitializeCreateResourceFunctions(ImmediateExecutionContext* ctx,
create_resource_fn->saved_concrete_func;
if (!saved_create_resource_fn->bound_inputs().empty()) {
// TODO(b/124045874): Load resource functions via a topological sort
return errors::Unimplemented(
return absl::UnimplementedError(
"Create Resource functions with captures are currently unsupported.");
}
std::unique_ptr<TFConcreteFunction> out;
@ -401,9 +404,9 @@ Status CreateAllResourceHandles(ImmediateExecutionContext* ctx,
const TFConcreteFunction* create_resource_fn =
revived->concrete_functions.Find(create_resource_fn_node);
if (create_resource_fn == nullptr) {
return errors::FailedPrecondition(
"ConcreteFunction at node ", create_resource_fn_node,
" should have been initialized prior to being called.");
return absl::FailedPreconditionError(
absl::StrCat("ConcreteFunction at node ", create_resource_fn_node,
" should have been initialized prior to being called."));
}
ImmediateOpPtr function_op;
TF_RETURN_IF_ERROR(create_resource_fn->MakeCallOp({}, &function_op));
@ -416,7 +419,7 @@ Status CreateAllResourceHandles(ImmediateExecutionContext* ctx,
AbstractTensorHandlePtr owned_resource_handle(resource_handle);
if (!tensorflow::isa<ImmediateExecutionTensorHandle>(
owned_resource_handle.get())) {
return errors::Internal("Unexpected tensor handle kind.");
return absl::InternalError("Unexpected tensor handle kind.");
}
ImmediateTensorHandlePtr result(
reinterpret_cast<ImmediateExecutionTensorHandle*>(
@ -443,9 +446,9 @@ Status BuildResources(ImmediateExecutionContext* ctx,
create_resource = revived->concrete_functions.Find(
resource_revival_state.create_resource->node_id);
if (create_resource == nullptr) {
return errors::FailedPrecondition(
return absl::FailedPreconditionError(absl::StrCat(
"'create_resource' function with node id ",
resource_revival_state.create_resource->node_id, " not found");
resource_revival_state.create_resource->node_id, " not found"));
}
}
@ -454,9 +457,9 @@ Status BuildResources(ImmediateExecutionContext* ctx,
initialize = revived->concrete_functions.Find(
resource_revival_state.initialize->node_id);
if (initialize == nullptr) {
return errors::FailedPrecondition(
return absl::FailedPreconditionError(absl::StrCat(
"'initialize' function with node id ",
resource_revival_state.initialize->node_id, " not found");
resource_revival_state.initialize->node_id, " not found"));
}
}
@ -465,15 +468,16 @@ Status BuildResources(ImmediateExecutionContext* ctx,
destroy_resource = revived->concrete_functions.Find(
resource_revival_state.destroy_resource->node_id);
if (destroy_resource == nullptr) {
return errors::FailedPrecondition(
return absl::FailedPreconditionError(absl::StrCat(
"'destroy_resource' function with node id ",
resource_revival_state.destroy_resource->node_id, " not found");
resource_revival_state.destroy_resource->node_id, " not found"));
}
}
if (resource_revival_state.resource_handle == nullptr) {
return errors::FailedPrecondition("Resource at node id ", node_id,
" does not have a resource handle.");
return absl::FailedPreconditionError(
absl::StrCat("Resource at node id ", node_id,
" does not have a resource handle."));
}
revived->restored_resources.emplace(

View File

@ -344,7 +344,7 @@ cc_library(
tf_cc_test(
name = "saved_model_api_test",
size = "small",
size = "medium",
srcs = [
"saved_model_api_test.cc",
],

View File

@ -13,15 +13,15 @@ py_strict_binary(
srcs = ["gen_saved_models.py"],
python_version = "PY3",
deps = [
"//tensorflow/python:constant_op",
"//tensorflow/python:dtypes",
"//tensorflow/python:resource_variable_ops",
"//tensorflow/python:tensor_spec",
"//tensorflow/python:variables",
"//tensorflow/python:while_loop",
"//tensorflow/python/compat:v2_compat",
"//tensorflow/python/eager:def_function",
"//tensorflow/python/framework:constant_op",
"//tensorflow/python/framework:dtypes",
"//tensorflow/python/framework:tensor_spec",
"//tensorflow/python/module",
"//tensorflow/python/ops:resource_variable_ops",
"//tensorflow/python/ops:variables",
"//tensorflow/python/ops:while_loop",
"//tensorflow/python/saved_model",
"@absl_py//absl:app",
],

View File

@ -189,6 +189,8 @@ cc_library_with_android_deps(
"//tensorflow/core:lib_internal",
"//tensorflow/core:ops",
"//tensorflow/core:protos_all_cc",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
],
)

View File

@ -1,4 +1,5 @@
# TODO(aselle): describe this package.
#include "third_party/absl/strings/str_cat.h"
#TODO(aselle) : describe this package.
load(
"//tensorflow/core/platform:rules_cc.bzl",
@ -42,6 +43,8 @@ cc_library(
"//tensorflow/core/platform:status",
"//tensorflow/core/platform:statusor",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
],
)
@ -84,13 +87,13 @@ py_strict_binary(
srcs = ["tests/generate_testdata.py"],
python_version = "PY3",
deps = [
"//tensorflow/python:constant_op",
"//tensorflow/python:dtypes",
"//tensorflow/python:tensor_spec",
"//tensorflow/python:variables",
"//tensorflow/python/compat:v2_compat",
"//tensorflow/python/eager:def_function",
"//tensorflow/python/framework:constant_op",
"//tensorflow/python/framework:dtypes",
"//tensorflow/python/framework:tensor_spec",
"//tensorflow/python/module",
"//tensorflow/python/ops:variables",
"//tensorflow/python/saved_model",
"@absl_py//absl:app",
"@absl_py//absl/flags",
@ -180,13 +183,11 @@ tf_cc_test(
size = "medium",
srcs = [
"tests/runtime_test_core.cc",
"tests/runtime_test_tfrt.cc",
],
deps = [
":runtime_test",
"//tensorflow/cc/experimental/libtf/runtime",
"//tensorflow/cc/experimental/libtf/runtime/core",
"//tensorflow/cc/experimental/libtf/runtime/tfrt",
],
)

View File

@ -29,6 +29,8 @@ limitations under the License.
#include <string>
#include <utility>
#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/cc/experimental/libtf/value.h"
#include "tensorflow/core/platform/errors.h"
@ -172,7 +174,7 @@ class Object : public Handle {
}
}
}
return tensorflow::errors::NotFound("Key not in dictionary.");
return absl::NotFoundError("Key not in dictionary.");
}
/// Sets `key` attribute with the underlying value of `h`.
@ -202,7 +204,7 @@ class Dictionary final : public Handle {
tensorflow::StatusOr<T> Get(const Handle& key) {
auto it = value_.dict().find(key.value_);
if (it != value_.dict().end()) return Cast<T>(Handle(it->second));
return tensorflow::errors::NotFound("Key not in dictionary.");
return absl::NotFoundError("Key not in dictionary.");
}
/// Sets `key` with value `value`.
void Set(const String& key, Handle value) {
@ -282,7 +284,7 @@ tensorflow::Status Tensor::GetValue(absl::Span<T> data) const {
{
const auto abstract_t = value_.tensor().get();
if (!tensorflow::ImmediateExecutionTensorHandle::classof(abstract_t)) {
return tensorflow::errors::InvalidArgument(
return absl::InvalidArgumentError(
"Attempting to get value of non eager tensor.");
}
auto imm_t =
@ -315,7 +317,7 @@ class Tuple : public Handle {
template <class T>
tensorflow::StatusOr<T> Get(size_t i) {
if (i >= value_.tuple().size())
return tensorflow::errors::InvalidArgument("Out of bounds index.");
return absl::InvalidArgumentError("Out of bounds index.");
return Cast<T>(Handle(value_.tuple()[i]));
}
@ -348,7 +350,7 @@ class List final : public Handle {
template <class T>
tensorflow::StatusOr<T> Get(size_t i) {
if (i >= size()) {
return tensorflow::errors::InvalidArgument("Out of bounds index.");
return absl::InvalidArgumentError("Out of bounds index.");
}
return Cast<T>(Handle(value_.list()[i]));
}
@ -356,7 +358,7 @@ class List final : public Handle {
/// Sets value `h` at index `i`.
tensorflow::Status Set(size_t i, Handle h) {
if (i >= size()) {
return tensorflow::errors::InvalidArgument("Out of bounds index.");
return absl::InvalidArgumentError("Out of bounds index.");
}
value_.list()[i] = std::move(h.value_);
return ::tensorflow::OkStatus();
@ -533,7 +535,7 @@ tensorflow::StatusOr<T> Cast(Handle handle) {
if (handle.value_.type() == TypeToTaggedType<T>() ||
std::is_same<T, Handle>::value)
return T((std::move(handle.value_)));
return tensorflow::errors::InvalidArgument("Incompatible cast.");
return absl::InvalidArgumentError("Incompatible cast.");
}
// Converters for C++ primitives like float and int to handles. Allows callable
@ -656,10 +658,10 @@ class UneraseCallHelper<Fn, TReturn, TSignatureArg, TSignatureRest...> {
Handle h(std::move(args_in.tuple()[argument_index]));
tensorflow::StatusOr<TSignatureArg> x = Cast<TSignatureArg>(std::move(h));
if (!x.ok())
return tensorflow::errors::InvalidArgument(
std::string("Function ") + name + " Arg " +
std::to_string(argument_index) +
" cannot be cast to desired signature type ");
return absl::InvalidArgumentError(
absl::StrCat(std::string("Function ") + name + " Arg " +
std::to_string(argument_index) +
" cannot be cast to desired signature type "));
return UneraseCallHelper<Fn, TReturn, TSignatureRest...>::template Call(
name, fn, argument_index + 1, args_in, args..., *x);
}
@ -683,9 +685,9 @@ class CallableWrapper {
TaggedValue kwargs) {
constexpr size_t argument_count = sizeof...(TFuncArgs);
if (argument_count != args.tuple().size())
return tensorflow::errors::InvalidArgument(
std::string("Function ") + name_ + " expected " +
std::to_string(argument_count) + " args.");
return absl::InvalidArgumentError(
absl::StrCat(std::string("Function ") + name_ + " expected " +
std::to_string(argument_count) + " args."));
return UneraseCallHelper<Fn, TReturn, TFuncArgs...>::Call(name_, functor_,
0, args);
}

View File

@ -1,26 +0,0 @@
package(
# copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
default_visibility = [
"//tensorflow/cc/experimental/libtf:__subpackages__",
],
licenses = ["notice"],
)
cc_library(
name = "tfrt",
srcs = [
"tfrt.cc",
],
hdrs = [
"tfrt.h",
],
deps = [
"//tensorflow/c:tf_status_helper",
"//tensorflow/c:tf_status_internal",
"//tensorflow/c/eager:c_api",
"//tensorflow/c/eager:c_api_experimental",
"//tensorflow/c/eager:tfe_context_internal",
"//tensorflow/cc/experimental/libtf",
"//tensorflow/cc/experimental/libtf/runtime",
],
)

View File

@ -1,45 +0,0 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/cc/experimental/libtf/runtime/tfrt/tfrt.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/tfe_context_internal.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/c/tf_status_internal.h"
#include "tensorflow/cc/experimental/libtf/value.h"
namespace tf {
namespace libtf {
namespace runtime {
namespace tfrt {
runtime::Runtime Runtime() {
TFE_Context* ctx;
TFE_ContextOptions* ctx_options = TFE_NewContextOptions();
TFE_ContextOptionsSetTfrt(ctx_options, true);
TFE_ContextOptionsSetDevicePlacementPolicy(ctx_options,
TFE_DEVICE_PLACEMENT_WARN);
TF_Status* status = TF_NewStatus();
ctx = TFE_NewContext(ctx_options, status);
TF_DeleteStatus(status);
TFE_DeleteContextOptions(ctx_options);
return runtime::Runtime(tensorflow::unwrap(ctx));
}
} // namespace tfrt
} // namespace runtime
} // namespace libtf
} // namespace tf

View File

@ -288,7 +288,7 @@ TEST_P(FunctionTest, IncorrectDtypeInOutputSignatureFails) {
INSTANTIATE_TEST_SUITE_P(TF2CAPI, FunctionTest,
::testing::Combine(::testing::Values("graphdef",
"mlir"),
::testing::Values(false, true)));
::testing::Values(false)));
} // namespace libtf
} // namespace tf

View File

@ -123,7 +123,7 @@ TEST_P(UnifiedCAPI, SimpleCreationFunctions) {
INSTANTIATE_TEST_SUITE_P(Tracing, UnifiedCAPI,
::testing::Combine(::testing::Values("graphdef",
"mlir"),
::testing::Values(true, false)));
::testing::Values(false)));
} // namespace libtf
} // namespace tf

View File

@ -114,7 +114,7 @@ TEST_P(VariableTest, CreateAssignReadDestroy) {
INSTANTIATE_TEST_SUITE_P(TF2CAPI, VariableTest,
::testing::Combine(::testing::Values("graphdef",
"mlir"),
::testing::Values(false, true)));
::testing::Values(false)));
} // namespace libtf
} // namespace tf

View File

@ -21,6 +21,8 @@ limitations under the License.
#include <utility>
#include <vector>
#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/graph/graph.h"
@ -156,9 +158,9 @@ class Input {
typedef typename RealType<T>::type RealT;
Tensor t(DataTypeToEnum<RealT>::v(), shape);
if (t.NumElements() != static_cast<int64_t>(v.size())) {
status = errors::InvalidArgument(
status = absl::InvalidArgumentError(absl::StrCat(
"Cannot construct a tensor with ", t.NumElements(),
" from an initializer list with ", v.size(), " elements");
" from an initializer list with ", v.size(), " elements"));
return;
}
std::copy_n(v.begin(), v.size(), t.flat<RealT>().data());

View File

@ -466,7 +466,7 @@ TEST_F(CWiseUnaryGradTest, Asin_Complex) {
};
// TODO(kbsriram)
// Enable test when the asin kernel supports complex numbers
if (false) {
if (/* DISABLES CODE */ (false)) {
TestCWiseGrad<complex64, complex64>(ASIN, x_fn);
}
}
@ -482,7 +482,7 @@ TEST_F(CWiseUnaryGradTest, Acos_Complex) {
};
// TODO(kbsriram)
// Add test when the acos kernel supports complex numbers
if (false) {
if (/* DISABLES CODE */ (false)) {
TestCWiseGrad<complex64, complex64>(ACOS, x_fn);
}
}
@ -510,7 +510,7 @@ TEST_F(CWiseUnaryGradTest, Atan_Complex) {
};
// TODO(kbsriram)
// Add test when the atan kernel supports complex numbers
if (false) {
if (/* DISABLES CODE */ (false)) {
TestCWiseGrad<complex64, complex64>(ATAN, x_fn);
}
}
@ -561,7 +561,7 @@ TEST_F(CWiseUnaryGradTest, Lgamma_Complex) {
};
// TODO(kbsriram)
// Add test when the lgamma kernel supports complex numbers
if (false) {
if (/* DISABLES CODE */ (false)) {
TestCWiseGrad<complex64, complex64>(LGAMMA, x_fn);
}
}
@ -579,7 +579,7 @@ TEST_F(CWiseUnaryGradTest, Erf_Complex) {
};
// TODO(kbsriram)
// Add test when the erf kernel supports complex numbers
if (false) {
if (/* DISABLES CODE */ (false)) {
TestCWiseGrad<complex64, complex64>(ERF, x_fn);
}
}

View File

@ -1,4 +1,5 @@
# Description:
#include "third_party/absl/strings/str_cat.h"
#Description:
# TensorFlow SavedModel.
load("//tensorflow:tensorflow.default.bzl", "filegroup")
@ -6,6 +7,7 @@ load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")
load(
"//tensorflow:tensorflow.bzl",
"if_android",
"if_google",
"if_mobile",
"if_not_mobile",
"tf_cc_test",
@ -28,6 +30,8 @@ package(
exports_files([
"loader.h",
"testdata/chunked_saved_model/chunked_model/saved_model.cpb",
"testdata/chunked_saved_model/chunked_model/saved_model.pbtxt",
])
cc_library(
@ -63,7 +67,9 @@ cc_library(
":metrics",
":util",
"//tensorflow/core:protos_all_cc",
] + if_not_mobile([
] + if_google([
"//tensorflow/tools/proto_splitter:merge",
]) + if_not_mobile([
# TODO(b/111634734): :lib and :protos_all contain dependencies that
# cannot be built on mobile platforms. Instead, include the appropriate
# tf_lib depending on the build platform.
@ -87,9 +93,8 @@ tf_cc_test(
":tag_constants",
"//tensorflow/core:lib",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
"//tensorflow/core/platform:resource_loader",
"@com_google_googletest//:gtest_main",
],
)
@ -131,6 +136,8 @@ cc_library(
":fingerprinting",
":loader_util",
":reader",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
] + if_not_mobile([
":metrics",
":util",
@ -252,15 +259,15 @@ py_binary(
python_version = "PY3",
srcs_version = "PY3",
deps = [
"//tensorflow/python:client_testlib",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
"//tensorflow/python:lookup_ops",
"//tensorflow/python:tensor_spec",
"//tensorflow/python:variables",
"//tensorflow/python/compat:v2_compat",
"//tensorflow/python/eager:def_function",
"//tensorflow/python/framework:dtypes",
"//tensorflow/python/framework:ops",
"//tensorflow/python/framework:tensor_spec",
"//tensorflow/python/module",
"//tensorflow/python/ops:lookup_ops",
"//tensorflow/python/ops:variables",
"//tensorflow/python/platform:client_testlib",
"//tensorflow/python/saved_model",
"//tensorflow/python/saved_model:save_options",
"//tensorflow/python/trackable:asset",
@ -268,6 +275,31 @@ py_binary(
],
)
# copybara:uncomment_begin(google-only)
#
# py_binary(
# name = "testdata/generate_chunked_models",
# srcs = ["testdata/generate_chunked_models.py"],
# python_version = "PY3",
# srcs_version = "PY3",
# deps = [
# "//tensorflow/python/compat:v2_compat",
# "//tensorflow/python/eager:def_function",
# "//tensorflow/python/framework:constant_op",
# "//tensorflow/python/module",
# "//tensorflow/python/platform:client_testlib",
# "//tensorflow/python/saved_model:loader",
# "//tensorflow/python/saved_model:save",
# "//tensorflow/python/saved_model:save_options",
# "//tensorflow/python/util:compat",
# "//tensorflow/tools/proto_splitter:constants",
# "//tensorflow/tools/proto_splitter/python:saved_model",
# "@absl_py//absl:app",
# ],
# )
#
# copybara:uncomment_end
# TODO(b/32673259): add a test to continuously validate these files.
filegroup(
name = "saved_model_test_files",
@ -284,6 +316,7 @@ filegroup(
"testdata/fuzz_generated/**",
"testdata/SimpleV1Model/**",
"testdata/OptimizerSlotVariableModule/**",
"testdata/chunked_saved_model/**",
]),
)
@ -369,7 +402,7 @@ tf_cc_test(
":metrics",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"@com_google_googletest//:gtest_main",
"@com_google_googletest//:gtest",
"@jsoncpp_git//:jsoncpp",
],
)

View File

@ -39,60 +39,6 @@ using strings::StrCat;
// `tensorflow::SavedModelV2Bundle::Load` API label.
constexpr char kCCLoadBundleV2Label[] = "cc_load_bundle_v2";
Status ReadSavedModelProto(const string& export_dir,
SavedModel* saved_model_proto) {
LOG(INFO) << "Reading SavedModel from: " << export_dir;
const string saved_model_pb_path =
io::JoinPath(export_dir, kSavedModelFilenamePb);
Status found_pb = Env::Default()->FileExists(saved_model_pb_path);
if (found_pb.ok()) {
Status result =
ReadBinaryProto(Env::Default(), saved_model_pb_path, saved_model_proto);
if (result.ok()) {
metrics::SavedModelReadCount(
saved_model::GetWriteVersion(*saved_model_proto))
.IncrementBy(1);
}
return result;
}
const string saved_model_pbtxt_path =
io::JoinPath(export_dir, kSavedModelFilenamePbTxt);
Status found_pbtxt = Env::Default()->FileExists(saved_model_pbtxt_path);
if (found_pbtxt.ok()) {
Status result = ReadTextProto(Env::Default(), saved_model_pbtxt_path,
saved_model_proto);
if (result.ok()) {
metrics::SavedModelReadCount(
saved_model::GetWriteVersion(*saved_model_proto))
.IncrementBy(1);
}
return result;
}
Status err;
if (found_pb.code() == found_pbtxt.code()) {
err = Status(found_pb.code(),
StrCat(found_pb.message(), "\n", found_pbtxt.message()));
} else if (found_pb.code() == NOT_FOUND) {
err = found_pbtxt;
} else if (found_pbtxt.code() == NOT_FOUND) {
err = found_pb;
} else {
// found_pb and found_pbtxt both errored, w/ different codes, neither being
// NOT_FOUND.
err = Status(
absl::StatusCode::kInternal,
StrCat("Different errors encountered while looking for saved_model.pb "
"and saved_model.pbtxt in the export directory path \"",
export_dir, "\": \n", found_pb.ToString(), "\n",
found_pbtxt.ToString()));
}
return err;
}
Status ReadCheckpointObjectGraph(BundleReader* bundle_reader,
TrackableObjectGraph* object_graph) {
Tensor object_graph_tensor;
@ -123,7 +69,7 @@ Status SavedModelV2Bundle::Load(const std::string& export_dir,
SavedModelV2Bundle* const bundle) {
metrics::SavedModelReadApi(kCCLoadBundleV2Label).IncrementBy(1);
SavedModel saved_model_proto;
TF_RETURN_IF_ERROR(ReadSavedModelProto(export_dir, &saved_model_proto));
TF_RETURN_IF_ERROR(ReadSavedModel(export_dir, &saved_model_proto));
metrics::SavedModelReadPath().Set(export_dir);
// Load MetaGraphDef.

View File

@ -19,56 +19,63 @@ limitations under the License.
namespace tensorflow {
// SavedModel assets directory.
constexpr char kSavedModelAssetsDirectory[] = "assets";
inline constexpr char kSavedModelAssetsDirectory[] = "assets";
// SavedModel assets.extra directory.
constexpr char kSavedModelAssetsExtraDirectory[] = "assets.extra";
inline constexpr char kSavedModelAssetsExtraDirectory[] = "assets.extra";
// SavedModel assets key for graph collection-def.
constexpr char kSavedModelAssetsKey[] = "saved_model_assets";
inline constexpr char kSavedModelAssetsKey[] = "saved_model_assets";
/// SavedModel legacy init op collection key. Used in v1 SavedModels.
constexpr char kSavedModelLegacyInitOpKey[] = "legacy_init_op";
inline constexpr char kSavedModelLegacyInitOpKey[] = "legacy_init_op";
/// SavedModel main op collection key. Used in v1 SavedModels.
constexpr char kSavedModelMainOpKey[] = "saved_model_main_op";
inline constexpr char kSavedModelMainOpKey[] = "saved_model_main_op";
// CollectionDef key for the SavedModel train op.
// Not exported while export_all_saved_models is experimental.
constexpr char kSavedModelTrainOpKey[] = "saved_model_train_op";
inline constexpr char kSavedModelTrainOpKey[] = "saved_model_train_op";
// Schema version for SavedModel.
constexpr int kSavedModelSchemaVersion = 1;
inline constexpr int kSavedModelSchemaVersion = 1;
// SavedModel proto filename prefix.
inline constexpr char kSavedModelFilenamePrefix[] = "saved_model";
// SavedModel proto filename.
constexpr char kSavedModelFilenamePb[] = "saved_model.pb";
inline constexpr char kSavedModelFilenamePb[] = "saved_model.pb";
// SavedModel chunked proto filename.
inline constexpr char kSavedModelFilenameCpb[] = "saved_model.cpb";
// SavedModel text format proto filename.
constexpr char kSavedModelFilenamePbTxt[] = "saved_model.pbtxt";
inline constexpr char kSavedModelFilenamePbTxt[] = "saved_model.pbtxt";
// Subdirectory where debugging related files are written.
constexpr char kSavedModelDebugDirectory[] = "debug";
inline constexpr char kSavedModelDebugDirectory[] = "debug";
// File name for GraphDebugInfo protocol buffer which corresponds to the
// SavedModel.
constexpr char kSavedModelDebugInfoFilenamePb[] = "saved_model_debug_info.pb";
inline constexpr char kSavedModelDebugInfoFilenamePb[] =
"saved_model_debug_info.pb";
// Directory in which to save the SavedModel variables.
constexpr char kSavedModelVariablesDirectory[] = "variables";
inline constexpr char kSavedModelVariablesDirectory[] = "variables";
// SavedModel variables filename.
constexpr char kSavedModelVariablesFilename[] = "variables";
inline constexpr char kSavedModelVariablesFilename[] = "variables";
// SavedModel SignatureDef keys for the initialization and train ops. Used in
// V2 SavedModels.
constexpr char kSavedModelInitOpSignatureKey[] = "__saved_model_init_op";
constexpr char kSavedModelTrainOpSignatureKey[] = "__saved_model_train_op";
inline constexpr char kSavedModelInitOpSignatureKey[] = "__saved_model_init_op";
inline constexpr char kSavedModelTrainOpSignatureKey[] =
"__saved_model_train_op";
// Key in the TensorBundle for the object graph proto.
constexpr char kObjectGraphProtoKey[] = "_CHECKPOINTABLE_OBJECT_GRAPH";
inline constexpr char kObjectGraphProtoKey[] = "_CHECKPOINTABLE_OBJECT_GRAPH";
// Filename for the FingerprintDef protocol buffer.
constexpr char kFingerprintFilenamePb[] = "fingerprint.pb";
inline constexpr char kFingerprintFilenamePb[] = "fingerprint.pb";
} // namespace tensorflow
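The substantive change in this header is `constexpr` becoming `inline constexpr` (plus the new `saved_model` prefix and `.cpb` filename constants). A short sketch of what `inline` buys here, using an illustrative constant name rather than one from this file:

// sketch_constants.h -- hypothetical header, not part of TensorFlow.

// Pre-C++17 style: a const namespace-scope variable has internal linkage,
// so every translation unit that includes the header gets its own copy.
// constexpr char kExampleFilename[] = "example.pb";

// C++17 inline variable: external linkage with a single definition that is
// shared (and has the same address) across all translation units.
inline constexpr char kExampleFilename[] = "example.pb";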

View File

@ -120,23 +120,29 @@ uint64 HashCheckpointIndexFile(absl::string_view model_dir) {
StatusOr<FingerprintDef> CreateFingerprintDef(const SavedModel& saved_model,
absl::string_view export_dir) {
SavedModel copy = saved_model;
return CreateFingerprintDef(&copy, export_dir);
}
StatusOr<FingerprintDef> CreateFingerprintDef(SavedModel* saved_model,
absl::string_view export_dir) {
// Create a copy of `metagraph` which will be used and mutated for fingerprint
// computation.
MetaGraphDef metagraph_copy = saved_model.meta_graphs(0);
FingerprintDef fingerprint_def;
MetaGraphDef* metagraph = saved_model->mutable_meta_graphs(0);
// Set fingerprint field #1.
fingerprint_def.set_saved_model_checksum(HashSavedModel(saved_model));
fingerprint_def.set_saved_model_checksum(HashSavedModel(*saved_model));
// Set fingerprint field #2.
graph_regularization::SimpleDelete(*metagraph_copy.mutable_graph_def());
graph_regularization::SimpleDelete(*metagraph->mutable_graph_def());
fingerprint_def.set_graph_def_program_hash(
graph_regularization::ComputeHash(metagraph_copy.graph_def()));
graph_regularization::ComputeHash(metagraph->graph_def()));
// Set fingerprint field #3.
fingerprint_def.set_signature_def_hash(
RegularizeAndHashSignatureDefs(metagraph_copy.signature_def()));
RegularizeAndHashSignatureDefs(metagraph->signature_def()));
// Set fingerprint field #4.
TF_ASSIGN_OR_RETURN(
StatusOr<uint64> object_graph_hash,
RegularizeAndHashSavedObjectGraph(metagraph_copy.object_graph_def()));
RegularizeAndHashSavedObjectGraph(metagraph->object_graph_def()));
fingerprint_def.set_saved_object_graph_hash(object_graph_hash.value());
// Set fingerprint field #5.
fingerprint_def.set_checkpoint_hash(HashCheckpointIndexFile(export_dir));
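The body above is the new overload that takes `SavedModel*` and regularizes the proto in place; the existing const-reference overload now just copies and delegates to it. A hedged caller-side sketch (the error-handling macros and the `export_dir` variable are illustrative, not part of this hunk):

SavedModel saved_model_proto;
TF_RETURN_IF_ERROR(ReadSavedModel(export_dir, &saved_model_proto));
TF_ASSIGN_OR_RETURN(FingerprintDef fingerprint_def,
                    CreateFingerprintDef(&saved_model_proto, export_dir));
// The graph_def inside `saved_model_proto` was mutated for hashing;
// treat the proto as consumed from here on.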

View File

@ -31,6 +31,12 @@ namespace tensorflow::saved_model::fingerprinting {
StatusOr<FingerprintDef> CreateFingerprintDef(const SavedModel& saved_model,
absl::string_view export_dir);
// Creates a FingerprintDef proto from a SavedModel and the checkpoint meta file
// (.index) in `export_dir`. The passed in `saved_model` is mutated and should
// not be used afterwards.
StatusOr<FingerprintDef> CreateFingerprintDef(SavedModel* saved_model,
absl::string_view export_dir);
// Loads the `fingerprint.pb` from `export_dir`, returns an error if there is
// none.
StatusOr<FingerprintDef> ReadSavedModelFingerprint(

View File

@ -18,6 +18,8 @@ limitations under the License.
#include <string>
#include <unordered_set>
#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/cc/saved_model/constants.h"
#include "tensorflow/cc/saved_model/fingerprinting.h"
#include "tensorflow/cc/saved_model/loader_util.h"
@ -90,16 +92,16 @@ static Status ValidateNode(const NodeDef& node) {
if (node_value.has_tensor()) {
const PartialTensorShape node_shape(node_value.tensor().tensor_shape());
if (node_shape.num_elements() < 0) {
return errors::FailedPrecondition(
return absl::FailedPreconditionError(absl::StrCat(
"Saved model contains node \"", node.name(), "\" (op \"", node.op(),
"\") which initializes from a tensor with ",
node_shape.num_elements(), " elements");
node_shape.num_elements(), " elements"));
}
}
} else if (node.op() == "Const") {
return errors::FailedPrecondition(
return absl::FailedPreconditionError(absl::StrCat(
"Saved model contains node \"", node.name(),
"\" which is a constant tensor but no value has been provided");
"\" which is a constant tensor but no value has been provided"));
}
return OkStatus();
}
@ -108,9 +110,9 @@ static Status ValidateFunctionNotRecursive(const FunctionDef& function) {
const auto& function_name = function.signature().name();
for (const auto& node : function.node_def()) {
if (node.op() == function_name) {
return errors::FailedPrecondition(
return absl::FailedPreconditionError(absl::StrCat(
"Function ", function_name,
" is self recursive and TensorFlow does not support this scenario.");
" is self recursive and TensorFlow does not support this scenario."));
}
}
@ -340,17 +342,17 @@ class LiteSessionWrapper : public Session {
: wrapped_(std::move(wrapped)) {}
Status Create(const GraphDef& graph) override {
return errors::Unimplemented("Session::Create()");
return absl::UnimplementedError("Session::Create()");
}
Status Create(GraphDef&& graph) override {
return errors::Unimplemented("Session::Create()");
return absl::UnimplementedError("Session::Create()");
}
Status Extend(const GraphDef& graph) override {
return errors::Unimplemented("Session::Extend()");
return absl::UnimplementedError("Session::Extend()");
}
Status Extend(GraphDef&& graph) override {
return errors::Unimplemented("Session::Extend()");
return absl::UnimplementedError("Session::Extend()");
}
Status Run(const std::vector<std::pair<string, Tensor>>& inputs,
@ -362,16 +364,16 @@ class LiteSessionWrapper : public Session {
}
Status Create(const RunOptions& run_options, const GraphDef& graph) override {
return errors::Unimplemented("Session::Create()");
return absl::UnimplementedError("Session::Create()");
}
Status Extend(const RunOptions& run_options, const GraphDef& graph) override {
return errors::Unimplemented("Session::Extend()");
return absl::UnimplementedError("Session::Extend()");
}
Status Create(const RunOptions& run_options, GraphDef&& graph) override {
return errors::Unimplemented("Session::Create()");
return absl::UnimplementedError("Session::Create()");
}
Status Extend(const RunOptions& run_options, GraphDef&& graph) override {
return errors::Unimplemented("Session::Extend()");
return absl::UnimplementedError("Session::Extend()");
}
Status Close(const RunOptions& run_options) override {
return wrapped_->Close(run_options);
@ -390,14 +392,14 @@ class LiteSessionWrapper : public Session {
const std::vector<string>& output_names,
const std::vector<string>& target_nodes,
string* handle) override {
return errors::Unimplemented("Session::PRunSetup()");
return absl::UnimplementedError("Session::PRunSetup()");
}
Status PRun(const string& handle,
const std::vector<std::pair<string, Tensor>>& inputs,
const std::vector<string>& output_names,
std::vector<Tensor>* outputs) override {
return errors::Unimplemented("Session::PRun()");
return absl::UnimplementedError("Session::PRun()");
}
Status ListDevices(std::vector<DeviceAttributes>* response) override {

View File

@ -15,7 +15,10 @@ limitations under the License.
#include "tensorflow/cc/saved_model/reader.h"
#include <memory>
#include <string>
#include <unordered_set>
#include <utility>
#include "absl/memory/memory.h"
#include "tensorflow/cc/saved_model/constants.h"
@ -31,60 +34,17 @@ limitations under the License.
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/file_system_helper.h"
#include "tensorflow/core/platform/path.h"
#include "tensorflow/core/platform/statusor.h"
#include "tensorflow/core/protobuf/saved_model.pb.h"
#include "tensorflow/core/util/tensor_bundle/byte_swap_tensor.h"
// Placeholder for protosplitter merger include.
#define IS_OSS true
namespace tensorflow {
namespace {
// Reads the SavedModel proto from saved_model.pb in `export_dir`.
// Returns a failure status when the SavedModel file does not exist.
Status ReadSavedModel(absl::string_view export_dir,
SavedModel* saved_model_proto) {
LOG(INFO) << "Reading SavedModel from: " << export_dir;
const std::string saved_model_pb_path =
io::JoinPath(export_dir, kSavedModelFilenamePb);
TF_ASSIGN_OR_RETURN(
bool saved_model_pb_exists,
internal::FileExists(Env::Default(), saved_model_pb_path));
if (saved_model_pb_exists) {
Status result =
ReadBinaryProto(Env::Default(), saved_model_pb_path, saved_model_proto);
if (result.ok()) {
metrics::SavedModelReadCount(
saved_model::GetWriteVersion(*saved_model_proto))
.IncrementBy(1);
}
return result;
}
const std::string saved_model_pbtxt_path =
io::JoinPath(export_dir, kSavedModelFilenamePbTxt);
TF_ASSIGN_OR_RETURN(
bool saved_model_pbtxt_exists,
internal::FileExists(Env::Default(), saved_model_pbtxt_path));
if (saved_model_pbtxt_exists) {
Status result = ReadTextProto(Env::Default(), saved_model_pbtxt_path,
saved_model_proto);
if (result.ok()) {
metrics::SavedModelReadCount(
saved_model::GetWriteVersion(*saved_model_proto))
.IncrementBy(1);
}
return result;
}
return Status(
absl::StatusCode::kNotFound,
strings::StrCat("Could not find SavedModel .pb or .pbtxt at supplied "
"export directory path: ",
export_dir,
". Check that "
"the directory exists and that you have the right "
"permissions for accessing it."));
}
Status FindMetaGraphDef(const std::unordered_set<string>& tags,
SavedModel* saved_model_proto,
MetaGraphDef* meta_graph_def) {
@ -116,6 +76,61 @@ Status FindMetaGraphDef(const std::unordered_set<string>& tags,
}
} // namespace
// Reads the SavedModel proto from saved_model.pb in `export_dir`.
// Returns a failure status when the SavedModel file does not exist.
Status ReadSavedModel(absl::string_view export_dir,
SavedModel* saved_model_proto) {
LOG(INFO) << "Reading SavedModel from: " << export_dir;
if (IS_OSS) {
const std::string saved_model_pb_path =
io::JoinPath(export_dir, kSavedModelFilenamePb);
TF_ASSIGN_OR_RETURN(
bool saved_model_pb_exists,
internal::FileExists(Env::Default(), saved_model_pb_path));
if (saved_model_pb_exists) {
Status result = ReadBinaryProto(Env::Default(), saved_model_pb_path,
saved_model_proto);
if (result.ok()) {
metrics::SavedModelReadCount(
saved_model::GetWriteVersion(*saved_model_proto))
.IncrementBy(1);
}
return result;
}
}
const std::string saved_model_pbtxt_path =
io::JoinPath(export_dir, kSavedModelFilenamePbTxt);
TF_ASSIGN_OR_RETURN(
bool saved_model_pbtxt_exists,
internal::FileExists(Env::Default(), saved_model_pbtxt_path));
if (saved_model_pbtxt_exists) {
Status result = ReadTextProto(Env::Default(), saved_model_pbtxt_path,
saved_model_proto);
if (result.ok()) {
metrics::SavedModelReadCount(
saved_model::GetWriteVersion(*saved_model_proto))
.IncrementBy(1);
}
return result;
}
if (!IS_OSS) {
// Only use Merger outside of OSS.
// Placeholder for protosplitter merger call.
}
return Status(
absl::StatusCode::kNotFound,
strings::StrCat("Could not find SavedModel .pb or .pbtxt at supplied "
"export directory path: ",
export_dir,
". Check that "
"the directory exists and that you have the right "
"permissions for accessing it."));
}
Status ReadMetaGraphDefFromSavedModel(const string& export_dir,
const std::unordered_set<string>& tags,
MetaGraphDef* const meta_graph_def) {
@ -140,8 +155,7 @@ Status ReadSavedModelDebugInfoIfPresent(
GraphDebugInfo debug_info;
TF_RETURN_IF_ERROR(
ReadBinaryProto(Env::Default(), debug_info_pb_path, &debug_info));
*debug_info_proto =
absl::make_unique<GraphDebugInfo>(std::move(debug_info));
*debug_info_proto = std::make_unique<GraphDebugInfo>(std::move(debug_info));
}
return OkStatus();
}
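`ReadSavedModel` is promoted out of the anonymous namespace here (and declared in reader.h below) so that bundle_v2.cc and other callers can share one implementation, with the `.pb` lookup wrapped in an `IS_OSS` check and a placeholder left for the proto-splitter merger call. A hedged usage sketch; the directory path is a placeholder:

SavedModel saved_model_proto;
TF_RETURN_IF_ERROR(ReadSavedModel("/tmp/export_dir", &saved_model_proto));

MetaGraphDef meta_graph_def;
TF_RETURN_IF_ERROR(ReadMetaGraphDefFromSavedModel(
    "/tmp/export_dir", {kSavedModelTagServe}, &meta_graph_def));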

View File

@ -18,21 +18,26 @@ limitations under the License.
#ifndef TENSORFLOW_CC_SAVED_MODEL_READER_H_
#define TENSORFLOW_CC_SAVED_MODEL_READER_H_
#include <memory>
#include <string>
#include <unordered_set>
#include "tensorflow/core/framework/graph_debug_info.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/protobuf/meta_graph.pb.h"
#include "tensorflow/core/protobuf/saved_model.pb.h"
namespace tensorflow {
Status ReadSavedModel(absl::string_view export_dir,
SavedModel* saved_model_proto);
// Reads the SavedModel proto from saved_model.pb(txt) in the given directory,
// finds the MetaGraphDef that matches the given set of tags and writes it to
// the `meta_graph_def` parameter. Returns a failure status when the SavedModel
// file does not exist or no MetaGraphDef matches the tags.
Status ReadMetaGraphDefFromSavedModel(const string& export_dir,
const std::unordered_set<string>& tags,
MetaGraphDef* const meta_graph_def);
MetaGraphDef* meta_graph_def);
// Store debug info from the SavedModel export dir.
Status ReadSavedModelDebugInfoIfPresent(

View File

@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/cc/saved_model/reader.h"
#include <gmock/gmock.h>
#include "tensorflow/cc/saved_model/constants.h"
#include "tensorflow/cc/saved_model/metrics.h"
#include "tensorflow/cc/saved_model/tag_constants.h"
@ -24,7 +25,6 @@ limitations under the License.
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/path.h"
#include "tensorflow/core/platform/resource_loader.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace {
@ -39,6 +39,16 @@ string TestDataSharded() {
"half_plus_two", "00000123");
}
string ChunkedSavedModel() {
return io::JoinPath("tensorflow", "cc", "saved_model", "testdata",
"chunked_saved_model", "chunked_model");
}
string NonChunkedSavedModel() {
return io::JoinPath("tensorflow", "cc", "saved_model", "testdata",
"chunked_saved_model", "non_chunked_model");
}
class ReaderTest : public ::testing::Test {
protected:
ReaderTest() {}
@ -88,15 +98,6 @@ TEST_F(ReaderTest, NoTagMatchMultiple) {
<< st.message();
}
TEST_F(ReaderTest, PbtxtFormat) {
MetaGraphDef meta_graph_def;
const string export_dir = GetDataDependencyFilepath(TestDataPbTxt());
TF_ASSERT_OK(ReadMetaGraphDefFromSavedModel(export_dir, {kSavedModelTagServe},
&meta_graph_def));
CheckMetaGraphDef(meta_graph_def);
}
TEST_F(ReaderTest, InvalidExportPath) {
MetaGraphDef meta_graph_def;
@ -136,5 +137,7 @@ TEST_F(ReaderTest, MetricsUpdatedSuccessfulRead) {
EXPECT_EQ(metrics::SavedModelReadCount("1").value(), read_count_v1 + 1);
}
// Placeholder for protosplitter merger merge test.
} // namespace
} // namespace tensorflow

File diff suppressed because one or more lines are too long

View File

@ -0,0 +1 @@
[binary test data file; content not representable as text]

View File

@ -0,0 +1,76 @@
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Generates GraphDef test data for Merger.
Constructs chunked proto test data containing GraphDefs with many nodes and
large nodes for Merger::Read and Merger::Merge.
"""
from collections.abc import Sequence
import os
from absl import app
from absl import flags
import numpy as np
from tensorflow.python.compat import v2_compat
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.lib.io import file_io
from tensorflow.python.module import module
from tensorflow.python.saved_model import loader_impl
from tensorflow.python.saved_model import save
from tensorflow.python.saved_model import save_options
from tensorflow.python.util import compat
from tensorflow.tools.proto_splitter import constants
from tensorflow.tools.proto_splitter.python import saved_model as proto_splitter
SPLITTER_TESTDATA_PATH = flags.DEFINE_string(
"path", None, help="Path to testdata directory.")
def generate_non_chunked_model(non_chunked_dir: str):
root = module.Module()
root.c = constant_op.constant(np.random.random_sample([150, 150]))
constants.debug_set_max_size(80000)
root.get_c = def_function.function(lambda: root.c)
signatures = root.get_c.get_concrete_function()
save.save(root, non_chunked_dir, signatures=signatures,
options=save_options.SaveOptions(experimental_image_format=False))
def generate_chunked_model(non_chunked_dir: str, chunked_dir: str):
saved_model = loader_impl.parse_saved_model(non_chunked_dir)
prefix = file_io.join(compat.as_str(chunked_dir), "saved_model")
file_io.write_string_to_file(f"{prefix}.pbtxt", str(saved_model))
proto_splitter.SavedModelSplitter(saved_model).write(prefix)
def main(argv: Sequence[str]) -> None:
if len(argv) > 1:
raise app.UsageError("Too many command-line arguments.")
main_dir = os.path.join(SPLITTER_TESTDATA_PATH.value, "chunked_saved_model")
non_chunked_dir = os.path.join(main_dir, "non_chunked_model")
generate_non_chunked_model(non_chunked_dir)
chunked_dir = os.path.join(main_dir, "chunked_model")
generate_chunked_model(non_chunked_dir, chunked_dir)
if __name__ == "__main__":
v2_compat.enable_v2_behavior()
app.run(main)

View File

@ -1,4 +1,5 @@
# Description:
#include "third_party/absl/strings/str_cat.h"
#Description:
# TensorFlow cc tools.
load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")
@ -22,6 +23,8 @@ cc_library(
"//tensorflow/core:core_cpu",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
],
)

View File

@ -18,6 +18,8 @@ limitations under the License.
#include <iostream>
#include <queue>
#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/graph.pb.h"
@ -193,7 +195,7 @@ StatusOr<string> GetVarHandleName(
if (node->op() == "VarHandleOp") {
return node->name();
}
return errors::NotFound("No VarHandleOp ancestor found");
return absl::NotFoundError("No VarHandleOp ancestor found");
}
// Looks up the variable handle that provides input to node with node_name,
@ -209,7 +211,7 @@ StatusOr<string> GetHandleNameIfNeedsToFreeze(
if (var_handle_name.ok() && variable_node_names.count(*var_handle_name)) {
return var_handle_name;
}
return errors::NotFound("No VarHandleOp ancestor found");
return absl::NotFoundError("No VarHandleOp ancestor found");
}
// Freezes the subgraph of all nodes needed by `outputs`.

View File

@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/aot/codegen.h"
#include <memory>
#include <string>
#include <utility>
#include <vector>
@ -24,6 +25,7 @@ limitations under the License.
#include "absl/strings/str_join.h"
#include "absl/strings/str_replace.h"
#include "absl/strings/str_split.h"
#include "absl/strings/substitute.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/aot/embedded_protocol_buffers.h"
#include "tensorflow/compiler/tf2xla/tf2xla.pb.h"
@ -312,6 +314,74 @@ Status GenVariableMethods(const tf2xla::Config& config,
return OkStatus();
}
// Generate shape infos for args (inputs).
Status GenArgShapeInfos(const xla::ProgramShapeProto& ps, string* infos) {
for (int i = 0; i < ps.parameters_size(); ++i) {
const xla::ShapeProto& shape = ps.parameters(i);
if (shape.element_type() == xla::TUPLE) {
// ShapeInfo cannot represent tuple args.
return absl::InternalError(
absl::StrCat("parameter ", i,
": codegen requires XLA parameters to "
"be non-tuples."));
}
// Please some compilers (e.g. MSVC) by avoiding the initialization of an
// array of unknown size with an empty initializer. Use "-1" for this; note that
// this value is never used (the size attribute is set to 0 in ShapeInfo).
*infos += absl::Substitute(R"( static constexpr int32_t kArg$0Shapes[] = {
$1
};
)",
i,
shape.dimensions_size() > 0
? absl::StrJoin(shape.dimensions(), ", ")
: "-1");
}
*infos += R"( static const ShapeInfo* ArgShapeInfos() {
static constexpr ShapeInfo kArgShapeInfoTable[kNumArgs] = {
)";
for (int i = 0; i < ps.parameters_size(); ++i) {
const xla::ShapeProto& shape = ps.parameters(i);
*infos +=
absl::Substitute("{ kArg$0Shapes, $1 },\n", i, shape.dimensions_size());
}
*infos += R"( };
return kArgShapeInfoTable;
})";
return OkStatus();
}
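// Hedged illustration, not output produced by this commit's tests: for a
// scalar (rank-0) parameter the loop above would emit roughly
//
//   static constexpr int32_t kArg0Shapes[] = {
//     -1
//   };
//
// and the corresponding ArgShapeInfos() entry would be { kArg0Shapes, 0 },
// so the "-1" sentinel only placates the compiler and is never read.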
// Generate shape infos for results.
Status GenResultShapeInfos(const xla::ProgramShapeProto& ps, string* infos) {
if (ps.result().element_type() != xla::TUPLE) {
return absl::InternalError("codegen requires the XLA result to be a tuple");
}
for (int i = 0; i < ps.result().tuple_shapes_size(); ++i) {
const xla::ShapeProto& shape = ps.result().tuple_shapes(i);
// See above comment about the use here of "-1".
*infos += absl::Substitute(
R"( static constexpr int32_t kResult$0Shapes[] = {
$1
};
)",
i,
shape.dimensions_size() > 0 ? absl::StrJoin(shape.dimensions(), ", ")
: "-1");
}
*infos += R"( static const ShapeInfo* ResultShapeInfos() {
static constexpr ShapeInfo kResultShapeInfoTable[kNumResults] = {
)";
for (int i = 0; i < ps.result().tuple_shapes_size(); ++i) {
const xla::ShapeProto& shape = ps.result().tuple_shapes(i);
*infos += absl::Substitute("{ kResult$0Shapes, $1 },\n", i,
shape.dimensions_size());
}
*infos += R"( };
return kResultShapeInfoTable;
})";
return OkStatus();
}
// Generates code implementing {Arg,Result}Names(), where T is one of
// tf2xla::{Feed,Fetch,Variable}. Each feed or fetch name results in a C-style
// string literal in the array, with nullptr terminating the array.
@ -377,17 +447,27 @@ std::vector<string> BufferInfosToCppExpression(
std::transform(buffer_infos.begin(), buffer_infos.end(),
std::back_inserter(buffer_infos_as_strings),
[](const BufferInfo& buffer_info) {
std::pair<uint64, uint64> encoded = buffer_info.Encode();
string encoded_second_as_str =
encoded.second == ~0ULL
? "~0ULL"
: absl::StrCat(encoded.second, "ULL");
xla::cpu_function_runtime::EncodedBufferInfo encoded =
buffer_info.Encode();
auto param_to_str = [](uint32_t param) -> std::string {
return param == ~0U ? "~0U" : absl::StrCat(param, "U");
};
return absl::StrCat(
"::xla::cpu_function_runtime::BufferInfo({",
encoded.first, "ULL, ", encoded_second_as_str, "})");
"::xla::cpu_function_runtime::BufferInfo(",
encoded.packed_kind_and_size, "ULL, ",
param_to_str(encoded.entry_param_number), ", ",
param_to_str(encoded.result_param_number), ")");
});
return buffer_infos_as_strings;
}
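// Hedged reading of the encoding change above: each generated BufferInfo
// expression now spells out the packed kind/size word plus two optional
// parameter numbers, with ~0U meaning "not applicable". For example
// (values taken from the updated golden file below, shown here only as
// illustration):
//   ::xla::cpu_function_runtime::BufferInfo(34ULL, 0U, ~0U)   // entry parameter 0
//   ::xla::cpu_function_runtime::BufferInfo(481ULL, ~0U, 0U)  // result parameter 0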
Status CheckEqual(size_t a, size_t b, absl::string_view error_msg) {
if (a != b) {
return absl::InternalError(
absl::StrCat(error_msg, ". Expected ", a, ", got ", b, "."));
}
return OkStatus();
}
} // namespace
Status GenerateHeader(const CodegenOpts& opts, const tf2xla::Config& config,
@ -400,6 +480,8 @@ Status GenerateHeader(const CodegenOpts& opts, const tf2xla::Config& config,
compile_result.aot->buffer_infos();
const std::vector<int32> arg_index_table =
::xla::cpu::CreateArgIndexTableFromBufferInfos(buffer_infos);
const std::vector<int32> result_index_table =
::xla::cpu::CreateResultIndexTableFromBufferInfos(buffer_infos);
std::vector<string> buffer_infos_as_strings =
BufferInfosToCppExpression(buffer_infos);
const int64_t buffer_infos_size = buffer_infos.size();
@ -419,6 +501,15 @@ Status GenerateHeader(const CodegenOpts& opts, const tf2xla::Config& config,
TF_RETURN_IF_ERROR(GenArgMethods(config, ps, compile_result, &methods_arg));
TF_RETURN_IF_ERROR(GenResultMethods(config, ps, &methods_result));
TF_RETURN_IF_ERROR(GenVariableMethods(config, ps, &methods_variable));
string arg_shape_infos, result_shape_infos;
TF_RETURN_IF_ERROR(GenArgShapeInfos(ps, &arg_shape_infos));
TF_RETURN_IF_ERROR(
CheckEqual(ps.parameters_size(), arg_index_table.size(),
"Arg number mismatch, proto vs. arg_index_table"));
TF_RETURN_IF_ERROR(GenResultShapeInfos(ps, &result_shape_infos));
TF_RETURN_IF_ERROR(
CheckEqual(ps.result().tuple_shapes_size(), result_index_table.size(),
"Result number mismatch, proto vs. result_index_table"));
const size_t arg_bytes_aligned =
xla::cpu_function_runtime::AlignedBufferBytes(
buffer_infos_for_args.data(), buffer_infos_for_args.size(),
@ -544,6 +635,8 @@ class {{CLASS}} final : public tensorflow::XlaCompiledCpuFunction {
// Number of input arguments for the compiled computation.
static constexpr size_t kNumArgs = {{ARG_NUM}};
static constexpr size_t kNumResults = {{RESULT_NUM}};
// Number of variables for the compiled computation.
static constexpr size_t kNumVariables = {{VARIABLE_NUM}};
@ -560,16 +653,21 @@ class {{CLASS}} final : public tensorflow::XlaCompiledCpuFunction {
set_static_data_raw_function(data, {{ENTRY}});
set_static_data_buffer_infos(data, BufferInfos());
set_static_data_num_buffers(data, kNumBuffers);
set_static_data_result_index_table(data, ResultIndexToBufferIndex());
set_static_data_num_results(data, kNumResults);
set_static_data_arg_index_table(data, ArgIndexToBufferIndex());
set_static_data_num_args(data, kNumArgs);
set_static_data_num_variables(data, kNumVariables);
set_static_data_result_index(data, kResultIndex);
set_static_data_arg_shape_infos(data, ArgShapeInfos());
set_static_data_result_shape_infos(data, ResultShapeInfos());
set_static_data_arg_names(data, StaticArgNames());
set_static_data_variable_names(data, StaticVariableNames());
set_static_data_result_names(data, StaticResultNames());
set_static_data_program_shape(data, StaticProgramShape());
set_static_data_hlo_profile_printer_data(
data, StaticHloProfilePrinterData());
set_static_data_use_xla_runtime(data, {{USE_XLA_RUNTIME}});
{{ASSIGN_PROFILE_COUNTERS_SIZE}}
return data;
}();
@ -589,7 +687,7 @@ class {{CLASS}} final : public tensorflow::XlaCompiledCpuFunction {
//
// void set_argN_data(void* data)
// Sets the buffer of type T for positional argument N. May be called in
// any AllocMode. Must be called before Run to have an affect. Must be
// any AllocMode. Must be called before Run to have an effect. Must be
// called in AllocMode::RESULTS_PROFILES_AND_TEMPS_ONLY for each positional
// argument, to set the argument buffers.
//
@ -655,6 +753,13 @@ class {{CLASS}} final : public tensorflow::XlaCompiledCpuFunction {
return kBufferInfos;
}
static const ::tensorflow::int32* ResultIndexToBufferIndex() {
static constexpr ::tensorflow::int32 kResultIndexToBufferIndex[kNumResults] = {
{{RESULT_INDEX_TABLE}}
};
return kResultIndexToBufferIndex;
}
static const ::tensorflow::int32* ArgIndexToBufferIndex() {
static constexpr ::tensorflow::int32 kArgIndexToBufferIndex[kNumArgs] = {
{{ARG_INDEX_TABLE}}
@ -665,6 +770,12 @@ class {{CLASS}} final : public tensorflow::XlaCompiledCpuFunction {
// The 0-based index of the result tuple in the temporary buffers.
static constexpr size_t kResultIndex = {{RESULT_INDEX}};
// Shapes of the input arguments.
{{ARG_SHAPE_INFOS}};
// Shapes of the results.
{{RESULT_SHAPE_INFOS}};
// Array of names of each positional argument, terminated by nullptr.
static const char** StaticArgNames() {{ARG_NAMES_CODE}}
@ -699,13 +810,18 @@ class {{CLASS}} final : public tensorflow::XlaCompiledCpuFunction {
{"{{ARG_BYTES_TOTAL}}", absl::StrCat(arg_bytes_total)},
{"{{ARG_NAMES_CODE}}", arg_names_code},
{"{{ARG_NUM}}", absl::StrCat(arg_index_table.size())},
{"{{ARG_SHAPE_INFOS}}", arg_shape_infos},
{"{{VARIABLE_NUM}}", absl::StrCat(config.variable_size())},
{"{{ARG_INDEX_TABLE}}", absl::StrJoin(arg_index_table, ", ")},
{"{{RESULT_NUM}}", absl::StrCat(result_index_table.size())},
{"{{RESULT_INDEX_TABLE}}", absl::StrJoin(result_index_table, ", ")},
{"{{ASSIGN_PROFILE_COUNTERS_SIZE}}", assign_profile_counters_size},
{"{{CLASS}}", opts.class_name},
{"{{DECLS_FROM_OBJ_FILE}}",
absl::StrJoin(metadata_result.header_variable_decls, "\n")},
{"{{ENTRY}}", compile_result.entry_point},
{"{{USE_XLA_RUNTIME}}", opts.use_xla_runtime ? "true" : "false"},
{"{{HLO_PROFILE_PRINTER_DATA_SHIM_EXPRESSION}}",
metadata_result.hlo_profile_printer_data_access_shim},
{"{{INCLUDE_XLA_DATA_PROTO}}", include_xla_data_proto},
@ -722,6 +838,7 @@ class {{CLASS}} final : public tensorflow::XlaCompiledCpuFunction {
{"{{VARIABLE_NAMES_CODE}}", variable_names_code},
{"{{RESULT_INDEX}}", absl::StrCat(result_index)},
{"{{RESULT_NAMES_CODE}}", result_names_code},
{"{{RESULT_SHAPE_INFOS}}", result_shape_infos},
{"{{TEMP_BYTES_ALIGNED}}", absl::StrCat(temp_bytes_aligned)},
{"{{TEMP_BYTES_TOTAL}}", absl::StrCat(temp_bytes_total)},
{"{{NUM_BUFFERS}}", absl::StrCat(buffer_infos.size())},
@ -749,7 +866,7 @@ Status GenerateMetadata(const CodegenOpts& opts,
if (opts.gen_program_shape) {
program_shape =
absl::make_unique<xla::ProgramShapeProto>(compile_result.program_shape);
std::make_unique<xla::ProgramShapeProto>(compile_result.program_shape);
// The parameter names are currently meaningless, and redundant with the
// rest of our metadata, so clear them out to avoid confusion and save

View File

@ -48,6 +48,9 @@ struct CodegenOpts {
// If true, emit a serialized HloProfilePrinterData protobuf that can be used
// to pretty print HLO profile counters.
bool gen_hlo_profile_printer_data = false;
// If true, sets this executable as an XLA Runtime one.
bool use_xla_runtime = false;
};
// Describes a generated metadata object file.

View File

@ -215,18 +215,22 @@ TEST(CodegenTest, Golden) {
CompileResult compile_result;
compile_result.aot.reset(new xla::cpu::CpuAotCompilationResult(
{},
{BufferInfo::MakeTempBuffer(1),
BufferInfo::MakeEntryParameter(/*size=*/8, /*param_number=*/0),
{BufferInfo::MakeTempBuffer(3 * 8),
BufferInfo::MakeEntryParameter(/*size=*/8, /*entry_param_number=*/0),
BufferInfo::MakeTempBuffer(1),
BufferInfo::MakeEntryParameter(/*size=*/96, /*param_number=*/1),
BufferInfo::MakeEntryParameter(/*size=*/96, /*entry_param_number=*/1),
BufferInfo::MakeTempBuffer(1),
BufferInfo::MakeEntryParameter(/*size=*/96, /*param_number=*/2),
BufferInfo::MakeEntryParameter(/*size=*/96, /*entry_param_number=*/2),
BufferInfo::MakeTempBuffer(1),
BufferInfo::MakeEntryParameter(/*size=*/96, /*param_number=*/3),
BufferInfo::MakeTempBuffer(1),
BufferInfo::MakeEntryParameter(/*size=*/96, /*param_number=*/4),
BufferInfo::MakeTempBuffer(1), BufferInfo::MakeTempBuffer(120)},
11, {}));
BufferInfo::MakeEntryParameter(/*size=*/96, /*entry_param_number=*/3),
BufferInfo::MakeResultParameter(/*size=*/5 * 6 * 4,
/*result_param_number=*/0),
BufferInfo::MakeEntryParameter(/*size=*/96, /*entry_param_number=*/4),
BufferInfo::MakeResultParameter(/*size=*/1 * 4,
/*result_param_number=*/1),
BufferInfo::MakeResultParameter(/*size=*/5 * 4,
/*result_param_number=*/2)},
0, {}));
compile_result.program_shape =
xla::ShapeUtil::MakeProgramShape(
{

View File

@ -58,13 +58,15 @@ namespace bar {
// Memory stats:
// arg bytes total: 392
// arg bytes aligned: 576
// temp bytes total: 126
// temp bytes total: 171
// temp bytes aligned: 512
class MyClass final : public tensorflow::XlaCompiledCpuFunction {
public:
// Number of input arguments for the compiled computation.
static constexpr size_t kNumArgs = 5;
static constexpr size_t kNumResults = 3;
// Number of variables for the compiled computation.
static constexpr size_t kNumVariables = 3;
@ -81,16 +83,21 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction {
set_static_data_raw_function(data, entry_point);
set_static_data_buffer_infos(data, BufferInfos());
set_static_data_num_buffers(data, kNumBuffers);
set_static_data_result_index_table(data, ResultIndexToBufferIndex());
set_static_data_num_results(data, kNumResults);
set_static_data_arg_index_table(data, ArgIndexToBufferIndex());
set_static_data_num_args(data, kNumArgs);
set_static_data_num_variables(data, kNumVariables);
set_static_data_result_index(data, kResultIndex);
set_static_data_arg_shape_infos(data, ArgShapeInfos());
set_static_data_result_shape_infos(data, ResultShapeInfos());
set_static_data_arg_names(data, StaticArgNames());
set_static_data_variable_names(data, StaticVariableNames());
set_static_data_result_names(data, StaticResultNames());
set_static_data_program_shape(data, StaticProgramShape());
set_static_data_hlo_profile_printer_data(
data, StaticHloProfilePrinterData());
set_static_data_use_xla_runtime(data, false);
return data;
}();
@ -110,7 +117,7 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction {
//
// void set_argN_data(void* data)
// Sets the buffer of type T for positional argument N. May be called in
// any AllocMode. Must be called before Run to have an affect. Must be
// any AllocMode. Must be called before Run to have an effect. Must be
// called in AllocMode::RESULTS_PROFILES_AND_TEMPS_ONLY for each positional
// argument, to set the argument buffers.
//
@ -354,22 +361,29 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction {
static const ::xla::cpu_function_runtime::BufferInfo* BufferInfos() {
static const ::xla::cpu_function_runtime::BufferInfo
kBufferInfos[kNumBuffers] = {
::xla::cpu_function_runtime::BufferInfo({5ULL, ~0ULL}),
::xla::cpu_function_runtime::BufferInfo({34ULL, 0ULL}),
::xla::cpu_function_runtime::BufferInfo({5ULL, ~0ULL}),
::xla::cpu_function_runtime::BufferInfo({386ULL, 1ULL}),
::xla::cpu_function_runtime::BufferInfo({5ULL, ~0ULL}),
::xla::cpu_function_runtime::BufferInfo({386ULL, 2ULL}),
::xla::cpu_function_runtime::BufferInfo({5ULL, ~0ULL}),
::xla::cpu_function_runtime::BufferInfo({386ULL, 3ULL}),
::xla::cpu_function_runtime::BufferInfo({5ULL, ~0ULL}),
::xla::cpu_function_runtime::BufferInfo({386ULL, 4ULL}),
::xla::cpu_function_runtime::BufferInfo({5ULL, ~0ULL}),
::xla::cpu_function_runtime::BufferInfo({481ULL, ~0ULL})
::xla::cpu_function_runtime::BufferInfo(97ULL, ~0U, ~0U),
::xla::cpu_function_runtime::BufferInfo(34ULL, 0U, ~0U),
::xla::cpu_function_runtime::BufferInfo(5ULL, ~0U, ~0U),
::xla::cpu_function_runtime::BufferInfo(386ULL, 1U, ~0U),
::xla::cpu_function_runtime::BufferInfo(5ULL, ~0U, ~0U),
::xla::cpu_function_runtime::BufferInfo(386ULL, 2U, ~0U),
::xla::cpu_function_runtime::BufferInfo(5ULL, ~0U, ~0U),
::xla::cpu_function_runtime::BufferInfo(386ULL, 3U, ~0U),
::xla::cpu_function_runtime::BufferInfo(481ULL, ~0U, 0U),
::xla::cpu_function_runtime::BufferInfo(386ULL, 4U, ~0U),
::xla::cpu_function_runtime::BufferInfo(17ULL, ~0U, 1U),
::xla::cpu_function_runtime::BufferInfo(81ULL, ~0U, 2U)
};
return kBufferInfos;
}
static const ::tensorflow::int32* ResultIndexToBufferIndex() {
static constexpr ::tensorflow::int32 kResultIndexToBufferIndex[kNumResults] = {
8, 10, 11
};
return kResultIndexToBufferIndex;
}
static const ::tensorflow::int32* ArgIndexToBufferIndex() {
static constexpr ::tensorflow::int32 kArgIndexToBufferIndex[kNumArgs] = {
1, 3, 5, 7, 9
@ -378,7 +392,53 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction {
}
// The 0-based index of the result tuple in the temporary buffers.
static constexpr size_t kResultIndex = 11;
static constexpr size_t kResultIndex = 0;
// Shapes of the input arguments.
static constexpr int32_t kArg0Shapes[] = {
1, 2
};
static constexpr int32_t kArg1Shapes[] = {
3, 4
};
static constexpr int32_t kArg2Shapes[] = {
1
};
static constexpr int32_t kArg3Shapes[] = {
1
};
static constexpr int32_t kArg4Shapes[] = {
5
};
static const ShapeInfo* ArgShapeInfos() {
static constexpr ShapeInfo kArgShapeInfoTable[kNumArgs] = {
{ kArg0Shapes, 2 },
{ kArg1Shapes, 2 },
{ kArg2Shapes, 1 },
{ kArg3Shapes, 1 },
{ kArg4Shapes, 1 },
};
return kArgShapeInfoTable;
};
// Shapes of the results.
static constexpr int32_t kResult0Shapes[] = {
5, 6
};
static constexpr int32_t kResult1Shapes[] = {
1
};
static constexpr int32_t kResult2Shapes[] = {
5
};
static const ShapeInfo* ResultShapeInfos() {
static constexpr ShapeInfo kResultShapeInfoTable[kNumResults] = {
{ kResult0Shapes, 2 },
{ kResult1Shapes, 1 },
{ kResult2Shapes, 1 },
};
return kResultShapeInfoTable;
};
// Array of names of each positional argument, terminated by nullptr.
static const char** StaticArgNames() {

View File

@ -273,6 +273,15 @@ Status Main(const MainFlags& flags) {
codegen_opts.gen_name_to_index = flags.gen_name_to_index;
codegen_opts.gen_program_shape = flags.gen_program_shape;
codegen_opts.target_triple = flags.target_triple;
// Set the XLA Runtime bit if this is an HloLowering.
if (!flags.mlir_components.empty() && flags.mlir_components != "None") {
for (auto component : absl::StrSplit(flags.mlir_components, ',')) {
if (component == "HloLowering") {
codegen_opts.use_xla_runtime = true;
}
}
}
if (flags.cpp_class.empty()) {
return errors::InvalidArgument("Must specify --cpp_class");
}

View File

@ -69,18 +69,18 @@ py_binary(
srcs_version = "PY3",
deps = [
"//tensorflow/core:protos_all_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:client",
"//tensorflow/python:cond",
"//tensorflow/python:control_flow_assert",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:math_ops",
"//tensorflow/python:nn_ops",
"//tensorflow/python:session",
"//tensorflow/python:training",
"//tensorflow/python:variable_v1",
"//tensorflow/python:variables",
"//tensorflow/python/client",
"//tensorflow/python/client:session",
"//tensorflow/python/framework:for_generated_wrappers",
"//tensorflow/python/ops:array_ops",
"//tensorflow/python/ops:cond",
"//tensorflow/python/ops:control_flow_assert",
"//tensorflow/python/ops:control_flow_ops",
"//tensorflow/python/ops:math_ops",
"//tensorflow/python/ops:nn_ops",
"//tensorflow/python/ops:variable_v1",
"//tensorflow/python/ops:variables",
"//tensorflow/python/training",
"@absl_py//absl:app",
"@six_archive//:six",
],
@ -437,6 +437,7 @@ tf_cc_test(
tags = [
"manual",
"no_mac", # TODO(b/228273415)
"not_run:arm",
],
deps = [
":test_graph_tfadd",
@ -510,6 +511,7 @@ tf_cc_test(
tags = [
"manual",
"no_mac", # TODO(b/228273415)
"not_run:arm",
],
deps = [
":test_graph_tfadd_mlir_bridge",

View File

@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <algorithm>
#include <vector>
#define EIGEN_USE_THREADS
#define EIGEN_USE_CUSTOM_THREAD_POOL

View File

@ -319,6 +319,8 @@ def _tf_library(
] or []) + (include_standard_runtime_deps and [
# TODO(cwhipkey): only depend on kernel code that the model actually
# needed.
"//tensorflow/compiler/xla/service/cpu/runtime:convolution_ffi",
"//tensorflow/compiler/xla/service/cpu/runtime:rng_ffi",
"//tensorflow/compiler/xla/service/cpu:runtime_conv2d",
"//tensorflow/compiler/xla/service/cpu:runtime_custom_call_status",
"//tensorflow/compiler/xla/service/cpu:runtime_key_value_sort",
@ -329,6 +331,7 @@ def _tf_library(
"//third_party/eigen3",
] or []) + (
mlir_components.count("HloLowering") > 0 and [
"//tensorflow/compiler/xla/runtime:aot_ffi_c_symbols",
"//tensorflow/compiler/xla/service/cpu:runtime_mlir_utils",
] or []
) + (

View File

@ -1,5 +1,5 @@
load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")
load("//tensorflow:tensorflow.bzl", "if_libtpu", "if_with_tpu_support", "tf_cc_test", "tf_copts", "tf_cuda_cc_test")
load("//tensorflow:tensorflow.bzl", "if_libtpu", "if_with_tpu_support", "tf_cc_test", "tf_copts", "tf_cuda_cc_test", "tf_cuda_only_cc_test")
load("//tensorflow/compiler/xla/stream_executor:build_defs.bzl", "if_cuda_or_rocm")
load("//tensorflow:tensorflow.default.bzl", "cc_header_only_library", "filegroup", "tf_custom_op_py_library", "tf_jit_compilation_passes_extra_deps")
load("//tensorflow/core/platform:build_config.bzl", "tf_additional_all_protos", "tf_proto_library")
@ -352,8 +352,10 @@ cc_library(
"//tensorflow/compiler/xla:executable_run_options",
"//tensorflow/compiler/xla/hlo/ir:hlo",
"//tensorflow/compiler/xla/pjrt:pjrt_client",
"//tensorflow/compiler/xla/pjrt:tf_pjrt_client",
"//tensorflow/compiler/xla/service:executable",
"//tensorflow/core/tfrt/common:create_pjrt_client_util",
"//tensorflow/core/tfrt/common:global_state",
"//tensorflow/core/tfrt/common:pjrt_util",
"//tensorflow/core/tpu:tpu_defs",
"@com_google_absl//absl/log",
@ -595,7 +597,8 @@ tf_cc_test(
"//tensorflow/core/framework:fake_input",
"//tensorflow/core/kernels:identity_op",
"//tensorflow/core/kernels:ops_testutil",
"@com_google_googletest//:gtest_main",
"//tensorflow/core/tpu:tpu_defs",
"@com_google_googletest//:gtest",
],
)
@ -782,6 +785,7 @@ tf_cc_test(
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/status",
],
)
@ -1509,6 +1513,70 @@ cc_library(
],
)
cc_library(
name = "xla_host_recv_device_context",
srcs = [
"xla_host_recv_device_context.cc",
],
hdrs = [
"xla_host_recv_device_context.h",
],
visibility = ["//visibility:public"],
deps = [
"//tensorflow/compiler/tf2xla:common",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla/stream_executor",
"//tensorflow/compiler/xla/stream_executor:device_memory",
"//tensorflow/core:framework",
"@tf_runtime//:async_value",
],
)
cc_library(
name = "xla_host_send_device_context",
srcs = [
"xla_host_send_device_context.cc",
],
hdrs = [
"xla_host_send_device_context.h",
],
visibility = ["//visibility:public"],
deps = [
"//tensorflow/compiler/tf2xla:common",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla/stream_executor",
"//tensorflow/compiler/xla/stream_executor:device_memory",
"//tensorflow/core:framework",
"@tf_runtime//:async_value",
],
)
tf_cuda_only_cc_test(
name = "xla_host_send_recv_device_context_test",
srcs = ["xla_host_send_recv_device_context_test.cc"],
tags = tf_cuda_tests_tags() + [
"config-cuda-only",
"no_oss", # Temporarily disable OSS.
],
deps = [
":flags",
":xla_device",
":xla_gpu_device",
":xla_host_recv_device_context",
":xla_host_send_device_context",
"//tensorflow/compiler/tf2xla:common",
"//tensorflow/compiler/tf2xla:xla_op_registry",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla/stream_executor",
"//tensorflow/compiler/xla/stream_executor:device_memory",
"//tensorflow/compiler/xla/stream_executor:multi_platform_manager",
"//tensorflow/core:framework_internal",
"//tensorflow/core:test",
"//tensorflow/core/framework:tensor_testutil",
"@com_google_googletest//:gtest_main",
],
)
tf_cc_test(
name = "device_compilation_cluster_signature_test",
srcs = [
@ -1527,6 +1595,9 @@ tf_cc_test(
tf_cc_test(
name = "device_compilation_profiler_test",
srcs = ["device_compilation_profiler_test.cc"],
tags = [
"nomsan", # TODO(b/284492454)
],
deps = [
":device_compilation_profiler",
":xla_activity_proto_cc",
@ -1641,6 +1712,7 @@ tf_cuda_cc_test(
tags = tf_cuda_tests_tags(),
deps = [
":flags",
":pjrt_device_compiler_client",
":test_util",
":xla_device_no_jit_rewrite_registration",
":xla_gpu_device",
@ -1649,6 +1721,7 @@ tf_cuda_cc_test(
"//tensorflow/compiler/tf2xla:xla_op_registry",
"//tensorflow/compiler/xla:xla_data_proto_cc",
"//tensorflow/compiler/xla/client:client_library",
"//tensorflow/compiler/xla/pjrt:pjrt_client",
"//tensorflow/core:framework",
"//tensorflow/core:framework_types_hdr",
"//tensorflow/core/tpu:tpu_defs",

View File

@ -19,9 +19,9 @@ limitations under the License.
#include <memory>
#include <numeric>
#include <string>
#include <unordered_map>
#include <vector>
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
@ -188,7 +188,7 @@ class Encapsulator {
// Adds the function call node to graph_out.
Status AddFunctionCallNode(
const std::unordered_map<const Node*, Node*>& node_images,
const absl::flat_hash_map<const Node*, Node*>& node_images,
Graph* graph_out);
// Returns the Node that the inputs and outputs of the function should be
@ -206,7 +206,7 @@ class Encapsulator {
// and adds the edge within the subgraph from the _Arg node to the image of
// the dst node.
Status RecordArg(const Edge* edge,
const std::unordered_map<const Node*, Node*>& node_images,
const absl::flat_hash_map<const Node*, Node*>& node_images,
std::vector<std::pair<const Node*, Node*>>* src_arg_pairs);
// Records the src of the given edge as a control result of the graph.
@ -214,14 +214,14 @@ class Encapsulator {
// the function signature.
Status RecordControlResult(
const Edge* edge,
const std::unordered_map<const Node*, Node*>& node_images);
const absl::flat_hash_map<const Node*, Node*>& node_images);
// Creates a _Retval node for the src node of edge, and add it to results_,
// if none exists yet. If a new _Retval node is created, also adds the edge
// within the subgraph from the src to the _Retval node.
Status RecordResult(
const Edge* edge,
const std::unordered_map<const Node*, Node*>& node_images);
const absl::flat_hash_map<const Node*, Node*>& node_images);
// Creates the sequencer node if it doesn't exist, adding it to graph_out.
Status MakeSequencingNode(const string& subgraph_name, Graph* graph_out);
@ -260,14 +260,14 @@ class Encapsulator {
// (consumer node/slot) tensors in the input graph to _Arg numbers in
// the subgraph. The source map is one-to-one, whereas the dest map may be
// many-to-one.
std::unordered_map<OutputTensor, int, OutputTensor::Hash> args_by_src_;
std::unordered_map<InputTensor, int, InputTensor::Hash> args_by_dst_;
absl::flat_hash_map<OutputTensor, int, OutputTensor::Hash> args_by_src_;
absl::flat_hash_map<InputTensor, int, InputTensor::Hash> args_by_dst_;
// The arguments to the subgraph, in order.
std::vector<Node*> args_;
// Map from source tensor in the input graph to result #.
std::unordered_map<OutputTensor, int, OutputTensor::Hash> results_;
absl::flat_hash_map<OutputTensor, int, OutputTensor::Hash> results_;
// Set of node names that are the source of a control output of the
// subgraph. We store strings here so that we can tolerate nodes being
@ -285,19 +285,20 @@ class Encapsulator {
// Copies edges local to a subgraph. Adds _Arg and _Retval nodes to
// subgraphs for data edges that cross subgraph boundaries.
Status CopySubgraphEdges(
const std::unordered_map<const Node*, Node*>& node_images,
const absl::flat_hash_map<const Node*, Node*>& node_images,
std::vector<std::pair<const Node*, Node*>>* src_arg_pairs);
// Copies all marked nodes to a subgraph. Does nothing for unmarked nodes.
Status CopySubgraphNodes(std::unordered_map<const Node*, Node*>* node_images);
Status CopySubgraphNodes(
absl::flat_hash_map<const Node*, Node*>* node_images);
// Copies all nodes that aren't in a compiled subgraph to the output graph.
Status CopyNodesToOutputGraph(
Graph* graph_out, std::unordered_map<const Node*, Node*>* node_images);
Graph* graph_out, absl::flat_hash_map<const Node*, Node*>* node_images);
// Adds function call nodes for each compiled subgraph.
Status AddFunctionCallNodes(
const std::unordered_map<const Node*, Node*>& node_images,
const absl::flat_hash_map<const Node*, Node*>& node_images,
Graph* graph_out);
// Finds the image of an edge source in the output graph. If the edge crosses
@ -305,7 +306,7 @@ class Encapsulator {
// in the output graph.
Status FindOutputImageOfEdgeSrc(
const string& src_func_id, const string& dst_func_id,
const std::unordered_map<const Node*, Node*>& node_images,
const absl::flat_hash_map<const Node*, Node*>& node_images,
const Node* original_src_node, Node** src_image);
// Finds an edge source slot in the output graph. If the edge crosses a
@ -320,7 +321,7 @@ class Encapsulator {
// a node in the output graph.
Status FindOutputImageOfEdgeDst(
const string& src_func_id, const string& dst_func_id,
const std::unordered_map<const Node*, Node*>& node_images,
const absl::flat_hash_map<const Node*, Node*>& node_images,
const Node* original_dst_node, Node** dst_image);
// Finds an edge destination slot in the output graph. If the edge crosses a
@ -334,14 +335,14 @@ class Encapsulator {
// within the output graph, or crosses into or out of a compiled subgraph.
Status CopyEdgeToOutputGraph(
const Edge* edge, const string& src_func_id, const string& dst_func_id,
const std::unordered_map<const Node*, Node*>& node_images,
const absl::flat_hash_map<const Node*, Node*>& node_images,
Graph* graph_out,
std::unordered_set<std::pair<OutputTensor, InputTensor>,
OutputInputTensorPairHasher>* edges_added);
absl::flat_hash_set<std::pair<OutputTensor, InputTensor>,
OutputInputTensorPairHasher>* edges_added);
// Adds all edges to the output graph.
Status AddEdgesToOutputGraph(
const std::unordered_map<const Node*, Node*>& node_images,
const absl::flat_hash_map<const Node*, Node*>& node_images,
Graph* graph_out);
// Makes a copy of graph containing only nodes that are ancestors of at least
@ -351,13 +352,13 @@ class Encapsulator {
Status MakePrunedGraphCopyAndInline(
const Graph& graph, const std::vector<Node*>& sink_nodes,
std::unique_ptr<Graph>* pruned_graph,
std::unordered_map<const Node*, Node*>* node_images,
absl::flat_hash_map<const Node*, Node*>* node_images,
FunctionLibraryDefinition* library);
const string group_attribute_;
const Graph* graph_in_;
std::unordered_map<string, Subgraph> subgraphs_;
absl::flat_hash_map<string, Subgraph> subgraphs_;
TF_DISALLOW_COPY_AND_ASSIGN(Encapsulator);
};
@ -369,9 +370,9 @@ namespace {
// including clusters that are not present in the ancestors map. has_successors
// is the set of clusters that are ancestors of some other cluster.
void TopologicalClusterSort(
const std::unordered_set<string>& clusters,
const std::unordered_set<string>& has_successors,
const std::unordered_map<string, std::unordered_set<string>>& ancestors,
const absl::flat_hash_set<string>& clusters,
const absl::flat_hash_set<string>& has_successors,
const absl::flat_hash_map<string, absl::flat_hash_set<string>>& ancestors,
std::vector<string>* sorted) {
// The nodes are placed in 'sorted' in topological order.
sorted->clear();
@ -447,11 +448,12 @@ Node* Encapsulator::Subgraph::MakeNodeImage(const Graph* graph_in, Node* node) {
Graph* Encapsulator::Subgraph::GetGraph() const { return graph_.get(); }
Status Encapsulator::Subgraph::RecordArg(
const Edge* edge, const std::unordered_map<const Node*, Node*>& node_images,
const Edge* edge,
const absl::flat_hash_map<const Node*, Node*>& node_images,
std::vector<std::pair<const Node*, Node*>>* src_arg_pairs) {
Node* src_node = edge->src();
int src_slot = edge->src_output();
std::unordered_map<OutputTensor, int, OutputTensor::Hash>::iterator iter;
absl::flat_hash_map<OutputTensor, int, OutputTensor::Hash>::iterator iter;
bool inserted;
std::tie(iter, inserted) = args_by_src_.emplace(
OutputTensor(src_node, src_slot), args_by_src_.size());
@ -481,7 +483,7 @@ Status Encapsulator::Subgraph::RecordArg(
Status Encapsulator::Subgraph::RecordControlResult(
const Edge* edge,
const std::unordered_map<const Node*, Node*>& node_images) {
const absl::flat_hash_map<const Node*, Node*>& node_images) {
Node* src_node = edge->src();
Node* src_image = node_images.at(src_node);
control_output_nodes_.insert(src_image->name());
@ -490,11 +492,11 @@ Status Encapsulator::Subgraph::RecordControlResult(
Status Encapsulator::Subgraph::RecordResult(
const Edge* edge,
const std::unordered_map<const Node*, Node*>& node_images) {
const absl::flat_hash_map<const Node*, Node*>& node_images) {
Node* src_node = edge->src();
Node* src_image = node_images.at(src_node);
int src_slot = edge->src_output();
std::unordered_map<OutputTensor, int, OutputTensor::Hash>::iterator iter;
absl::flat_hash_map<OutputTensor, int, OutputTensor::Hash>::iterator iter;
bool inserted;
std::tie(iter, inserted) =
results_.emplace(OutputTensor(src_node, src_slot), results_.size());
@ -592,7 +594,7 @@ Status Encapsulator::Subgraph::BuildFunctionDef(
FunctionDef fdef;
auto lookup = [this](const Node* node) -> std::optional<string> {
if (control_output_nodes_.contains(node->name())) {
return absl::make_optional(node->name());
return std::make_optional(node->name());
}
return std::nullopt;
};
@ -637,7 +639,7 @@ Status Encapsulator::Subgraph::ReplaceFunctionDef(
}
Status Encapsulator::Subgraph::AddFunctionCallNode(
const std::unordered_map<const Node*, Node*>& node_images,
const absl::flat_hash_map<const Node*, Node*>& node_images,
Graph* graph_out) {
TF_ASSIGN_OR_RETURN(call_node_, graph_out->AddNode(call_node_def_));
@ -663,7 +665,7 @@ Status Encapsulator::GetFunctionNameAttr(Node const* node, string* attr) const {
bool IsInSubgraph(const string& func_id) { return !func_id.empty(); }
Status Encapsulator::CopySubgraphNodes(
std::unordered_map<const Node*, Node*>* node_images) {
absl::flat_hash_map<const Node*, Node*>* node_images) {
for (Node* node : graph_in_->op_nodes()) {
string func_id;
TF_RETURN_IF_ERROR(GetFunctionNameAttr(node, &func_id));
@ -678,7 +680,7 @@ Status Encapsulator::CopySubgraphNodes(
}
Status Encapsulator::CopySubgraphEdges(
const std::unordered_map<const Node*, Node*>& node_images,
const absl::flat_hash_map<const Node*, Node*>& node_images,
std::vector<std::pair<const Node*, Node*>>* src_arg_pairs) {
for (const Edge* edge : graph_in_->edges()) {
string src_func_id;
@ -752,7 +754,7 @@ Status Encapsulator::SplitIntoSubgraphs(FunctionLibraryDefinition* library) {
Status s;
// Map from input graph nodes to subgraph nodes.
std::unordered_map<const Node*, Node*> node_images;
absl::flat_hash_map<const Node*, Node*> node_images;
// Each entry of src_arg_pairs is a pair whose first element is a node in the
// original graph that has an output edge in the subgraph, and whose second
@ -794,7 +796,7 @@ Status Encapsulator::BuildFunctionDefs(
}
Status Encapsulator::CopyNodesToOutputGraph(
Graph* graph_out, std::unordered_map<const Node*, Node*>* node_images) {
Graph* graph_out, absl::flat_hash_map<const Node*, Node*>* node_images) {
for (Node* node : graph_in_->op_nodes()) {
string func_id;
TF_RETURN_IF_ERROR(GetFunctionNameAttr(node, &func_id));
@ -811,7 +813,7 @@ Status Encapsulator::CopyNodesToOutputGraph(
}
Status Encapsulator::AddFunctionCallNodes(
const std::unordered_map<const Node*, Node*>& node_images,
const absl::flat_hash_map<const Node*, Node*>& node_images,
Graph* graph_out) {
for (auto& subgraph_entry : subgraphs_) {
TF_RETURN_IF_ERROR(
@ -822,7 +824,7 @@ Status Encapsulator::AddFunctionCallNodes(
Status Encapsulator::FindOutputImageOfEdgeSrc(
const string& src_func_id, const string& dst_func_id,
const std::unordered_map<const Node*, Node*>& node_images,
const absl::flat_hash_map<const Node*, Node*>& node_images,
const Node* original_src_node, Node** src_image) {
if (IsInSubgraph(src_func_id)) {
// The edge is from a subgraph to a regular node in the output graph so
@ -853,7 +855,7 @@ int Encapsulator::FindOutputSlotOfEdgeSrc(const string& src_func_id,
Status Encapsulator::FindOutputImageOfEdgeDst(
const string& src_func_id, const string& dst_func_id,
const std::unordered_map<const Node*, Node*>& node_images,
const absl::flat_hash_map<const Node*, Node*>& node_images,
const Node* original_dst_node, Node** dst_image) {
if (IsInSubgraph(dst_func_id)) {
// The edge is to a subgraph from a regular node in the output graph so
@ -884,9 +886,10 @@ int Encapsulator::FindOutputSlotOfEdgeDst(const string& src_func_id,
Status Encapsulator::CopyEdgeToOutputGraph(
const Edge* edge, const string& src_func_id, const string& dst_func_id,
const std::unordered_map<const Node*, Node*>& node_images, Graph* graph_out,
std::unordered_set<std::pair<OutputTensor, InputTensor>,
OutputInputTensorPairHasher>* edges_added) {
const absl::flat_hash_map<const Node*, Node*>& node_images,
Graph* graph_out,
absl::flat_hash_set<std::pair<OutputTensor, InputTensor>,
OutputInputTensorPairHasher>* edges_added) {
Node* src_image;
TF_RETURN_IF_ERROR(FindOutputImageOfEdgeSrc(
src_func_id, dst_func_id, node_images, edge->src(), &src_image));
@ -924,13 +927,13 @@ Status Encapsulator::CopyEdgeToOutputGraph(
}
Status Encapsulator::AddEdgesToOutputGraph(
const std::unordered_map<const Node*, Node*>& node_images,
const absl::flat_hash_map<const Node*, Node*>& node_images,
Graph* graph_out) {
// Set of edges already added to the output graph, represented as (src, dst)
// pairs. We use the set to deduplicate edges; multiple edges in the input
// graph may map to one edge in the output graph.
std::unordered_set<std::pair<OutputTensor, InputTensor>,
OutputInputTensorPairHasher>
absl::flat_hash_set<std::pair<OutputTensor, InputTensor>,
OutputInputTensorPairHasher>
edges_added;
for (const Edge* edge : graph_in_->edges()) {
@ -1010,7 +1013,7 @@ Node* AddDummyShapedNode(const Node* src_node, int src_port,
Status Encapsulator::MakePrunedGraphCopyAndInline(
const Graph& graph, const std::vector<Node*>& sink_nodes,
std::unique_ptr<Graph>* pruned_graph,
std::unordered_map<const Node*, Node*>* node_images,
absl::flat_hash_map<const Node*, Node*>* node_images,
FunctionLibraryDefinition* library) {
// First copy all ancestor nodes of sink_nodes into a new graph.
pruned_graph->reset(new Graph(library));
@ -1070,7 +1073,7 @@ Status Encapsulator::MakePrunedGraphCopyAndInline(
Status Encapsulator::BuildOutputGraph(Graph* graph_out,
FunctionLibraryDefinition* library) {
// Map from nodes in the input graph to nodes in the output graph.
std::unordered_map<const Node*, Node*> node_images;
absl::flat_hash_map<const Node*, Node*> node_images;
TF_RETURN_IF_ERROR(CopyNodesToOutputGraph(graph_out, &node_images));
TF_RETURN_IF_ERROR(AddFunctionCallNodes(node_images, graph_out));

View File

@ -34,6 +34,7 @@ BuildXlaOpsPassFlags* build_ops_flags;
MarkForCompilationPassFlags* mark_for_compilation_flags;
XlaDeviceFlags* device_flags;
XlaOpsCommonFlags* ops_flags;
XlaCallModuleFlags* call_module_flags;
MlirCommonFlags* mlir_flags;
JitRtFlags* jitrt_flags;
std::vector<Flag>* jitrt_flag_list;
@ -76,6 +77,13 @@ bool SetterForXlaAutoJitFlag(const string& value) {
return true;
}
bool SetterForXlaCallModuleDisabledChecks(const string& value) {
auto directives = absl::StrSplit(value, ',', absl::SkipEmpty());
call_module_flags->disabled_checks.insert(directives.begin(),
directives.end());
return true;
}
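// For example, passing
//   --tf_xla_call_module_disabled_checks=check_a,check_b
// (directive names here are placeholders) splits the value on ',' and inserts
// {"check_a", "check_b"} into call_module_flags->disabled_checks.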
void AppendMarkForCompilationPassFlagsInternal(std::vector<Flag>* flag_list) {
std::vector<Flag> new_flags = {
Flag("tf_xla_auto_jit", SetterForXlaAutoJitFlag, "0",
@ -184,6 +192,7 @@ void AllocateAndParseFlags() {
build_ops_flags->tf_xla_check_cluster_input_numerics = false;
build_ops_flags->tf_xla_check_cluster_output_numerics = false;
build_ops_flags->tf_xla_disable_constant_folding = false;
build_ops_flags->tf_xla_disable_full_embedding_pipelining = false;
mark_for_compilation_flags = new MarkForCompilationPassFlags;
mark_for_compilation_flags->xla_auto_jit_flag.optimization_level_single_gpu =
@ -213,9 +222,12 @@ void AllocateAndParseFlags() {
ops_flags = new XlaOpsCommonFlags;
ops_flags->tf_xla_always_defer_compilation = false;
ops_flags->tf_xla_async_compilation = false;
ops_flags->tf_xla_use_device_api.enabled_for_xla_launch_ = false;
ops_flags->tf_xla_use_device_api.enabled_for_compile_on_demand_ = false;
ops_flags->tf_xla_use_device_api.enabled_for_xla_launch_ = true;
ops_flags->tf_xla_use_device_api.enabled_for_compile_on_demand_ = true;
ops_flags->tf_xla_use_device_api.enabled_for_compile_and_run_ = false;
ops_flags->tf_xla_use_device_api.enabled_for_all_ = false;
call_module_flags = new XlaCallModuleFlags;
// The `enable_mlir_bridge` flag allows the user to explicitly request that
// their program is (or isn't) compiled using the MLIR-based TF-to-XLA bridge.
//
@ -251,6 +263,10 @@ void AllocateAndParseFlags() {
&build_ops_flags->tf_xla_disable_constant_folding,
"If true then disables constant folding on TF graph before XLA "
"compilation."),
Flag("tf_xla_disable_full_embedding_pipelining",
&build_ops_flags->tf_xla_disable_full_embedding_pipelining,
"If true then disables full embedding pipelining and instead use "
"strict SparseCore / TensorCore sequencing."),
Flag("tf_xla_compile_on_demand", &device_flags->tf_xla_compile_on_demand,
"Switch a device into 'on-demand' mode, where instead of "
@ -277,6 +293,22 @@ void AllocateAndParseFlags() {
&ops_flags->tf_xla_use_device_api.enabled_for_compile_on_demand_,
"If true, uses Device API (PjRt) for compiling and executing ops "
"one by one in 'on-demand' mode. Defaults to false."),
Flag("tf_xla_use_device_api_for_auto_jit",
&ops_flags->tf_xla_use_device_api.enabled_for_compile_and_run_,
"If true, uses Device API (PjRt) for compilation and execution "
"when auto-clustering is enabled. Defaults to false."),
Flag("tf_xla_use_device_api",
&ops_flags->tf_xla_use_device_api.enabled_for_all_,
"If true, uses Device API (PjRt) for compilation and execution "
"of ops one-by-one in 'on-demand' mode, for functions marked for "
"JIT compilation, or when auto-clustering is enabled. Defaults to "
"false."),
Flag("tf_xla_call_module_disabled_checks",
SetterForXlaCallModuleDisabledChecks, "",
"A comma-sepated list of directives specifying the safety checks "
"to be skipped when compiling XlaCallModuleOp. See the op "
"documentation for the recognized values."),
Flag("tf_mlir_enable_mlir_bridge", &enable_mlir_bridge,
"Enables experimental MLIR-Based TensorFlow Compiler Bridge.",
@ -365,6 +397,11 @@ XlaOpsCommonFlags* GetXlaOpsCommonFlags() {
return ops_flags;
}
XlaCallModuleFlags* GetXlaCallModuleFlags() {
absl::call_once(flags_init, &AllocateAndParseFlags);
return call_module_flags;
}
MlirCommonFlags* GetMlirCommonFlags() {
absl::call_once(flags_init, &AllocateAndParseFlags);
return mlir_flags;

View File

@ -137,8 +137,9 @@ struct XlaOpsCommonFlags {
}
bool IsEnabledInXlaLaunchForDevice(const DeviceType& device_type) const {
return enabled_for_xla_launch_ &&
xla_launch_allowed_devices_.contains(device_type.type_string());
return enabled_for_all_ ||
(enabled_for_xla_launch_ &&
xla_launch_allowed_devices_.contains(device_type.type_string()));
}
// Allow using Device API (PjRt) for `device_type` in the XlaCompileOnDemand
@ -152,9 +153,26 @@ struct XlaOpsCommonFlags {
bool IsEnabledInXlaCompileOnDemandForDevice(
const DeviceType& device_type) const {
return enabled_for_compile_on_demand_ &&
xla_compile_on_demand_allowed_devices_.contains(
device_type.type_string());
return enabled_for_all_ ||
(enabled_for_compile_on_demand_ &&
xla_compile_on_demand_allowed_devices_.contains(
device_type.type_string()));
}
// Allow using Device API (PjRt) for `device_type` in the XlaCompile and
// XlaRun ops. Please note that `enabled_for_compile_and_run_` needs to be
// true in addition to the `device_type` being allowed in order to use the
// Device API for single device compilation and execution in the XlaCompile
// and XlaRun ops.
void AllowForDeviceInXlaCompileAndRun(const DeviceType& device_type) {
xla_compile_and_run_allowed_devices_.insert(device_type.type_string());
}
bool IsEnabledInXlaCompileAndRunForDevice(
const DeviceType& device_type) const {
return enabled_for_all_ || (enabled_for_compile_and_run_ &&
xla_compile_and_run_allowed_devices_.contains(
device_type.type_string()));
}
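// Illustrative sketch, not part of this change: each of the checks above
// short-circuits on `enabled_for_all_`, so setting --tf_xla_use_device_api
// turns on PjRt even for a device that was never added to the allow-lists,
// e.g. (assumed call site):
//   const auto& rollout = GetXlaOpsCommonFlags()->tf_xla_use_device_api;
//   // True whenever enabled_for_all_ is set, regardless of whether
//   // AllowForDeviceInXlaCompileAndRun() was ever called for GPU.
//   rollout.IsEnabledInXlaCompileAndRunForDevice(DeviceType(DEVICE_GPU));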
// If true, uses Device API (PjRt) for single device compilation and
@ -166,6 +184,16 @@ struct XlaOpsCommonFlags {
// one in "on-demand" mode. Defaults to false.
bool enabled_for_compile_on_demand_;
// If true, uses Device API (PjRt) for compilation and execution when
// auto-clustering is enabled. Defaults to false.
bool enabled_for_compile_and_run_;
// If true, uses Device API (PjRt) for compilation and execution everywhere
// i.e. for functions marked for JIT compilation, for ops in "on-demand"
// mode and autoclustering, no matter whether other flags are enabled or
// not, and whether devices have been allowed or not. Defaults to false.
bool enabled_for_all_;
private:
// Devices for which using Device API (PjRt) is allowed in the XlaLaunch op.
// This can only be modified programmatically.
@ -173,9 +201,18 @@ struct XlaOpsCommonFlags {
// Devices for which using Device API (PjRt) is allowed in the
// XlaCompileOnDemand op. This can only be modified programmatically.
absl::flat_hash_set<std::string> xla_compile_on_demand_allowed_devices_;
// Devices for which using Device API (PjRt) is allowed in the
// XlaCompile and XlaRun ops. This can only be modified programmatically.
absl::flat_hash_set<std::string> xla_compile_and_run_allowed_devices_;
} tf_xla_use_device_api;
};
// Flags for the XlaCallModule kernel.
struct XlaCallModuleFlags {
// Used by XlaCallModuleOp to specify safety checks to disable.
absl::flat_hash_set<std::string> disabled_checks;
};
// Flags for the build_xla_ops pass.
struct BuildXlaOpsPassFlags {
// Enables lazy compilation for TF/XLA (only when auto-clustering) if true.
@ -197,6 +234,10 @@ struct BuildXlaOpsPassFlags {
// Disables all constant folding. The primary use for this is for testing to
// guarantee that tests are run on XLA and not on TF's CPU implementation.
bool tf_xla_disable_constant_folding;
// Disables full embedding pipelining when true. Instead, strict SparseCore /
// TensorCore sequencing will be used.
bool tf_xla_disable_full_embedding_pipelining;
};
// Flags for common MLIR configurations.
@ -235,6 +276,7 @@ MarkForCompilationPassFlags* GetMarkForCompilationPassFlags();
BuildXlaOpsPassFlags* GetBuildXlaOpsPassFlags();
XlaDeviceFlags* GetXlaDeviceFlags();
XlaOpsCommonFlags* GetXlaOpsCommonFlags();
XlaCallModuleFlags* GetXlaCallModuleFlags();
MlirCommonFlags* GetMlirCommonFlags();

View File

@ -23,6 +23,8 @@ XLA_OPS_DEPS = [
"//tensorflow/compiler/jit:variable_info_util",
"//tensorflow/compiler/jit:xla_device_no_jit_rewrite_registration",
"//tensorflow/compiler/jit:xla_cluster_util",
"//tensorflow/compiler/jit:xla_host_recv_device_context",
"//tensorflow/compiler/jit:xla_host_send_device_context",
"//tensorflow/compiler/jit:xla_launch_util",
"//tensorflow/compiler/tf2xla:common",
"//tensorflow/compiler/tf2xla:tf2xla_util",
@ -59,6 +61,8 @@ cc_library(
"//tensorflow/compiler/jit:xla_compile_util",
"//tensorflow/compiler/xla/pjrt:pjrt_client",
"//tensorflow/core/platform:refcount",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
],
alwayslink = 1,
)

View File

@ -20,11 +20,15 @@ limitations under the License.
#include <memory>
#include <optional>
#include <set>
#include <string>
#include <string_view>
#include <tuple>
#include <utility>
#include <variant>
#include "absl/container/flat_hash_map.h"
#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/jit/device_compilation_profiler.h"
#include "tensorflow/compiler/jit/device_compiler.h"
@ -35,6 +39,8 @@ limitations under the License.
#include "tensorflow/compiler/jit/xla_activity_listener.h"
#include "tensorflow/compiler/jit/xla_compile_util.h"
#include "tensorflow/compiler/jit/xla_compiler_options_util.h"
#include "tensorflow/compiler/jit/xla_host_recv_device_context.h"
#include "tensorflow/compiler/jit/xla_host_send_device_context.h"
#include "tensorflow/compiler/jit/xla_launch_util.h"
#include "tensorflow/compiler/jit/xla_platform_info.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
@ -92,10 +98,11 @@ auto* xla_launch_counter = monitoring::Counter<1>::New(
// the initial values for the resource variables (and cannot snapshot them again
// during execution) because otherwise we risk observing a different snapshot
// with shapes different from what we compiled for.
class XlaExecutableClosure {
template <typename ExecutableType, typename ClientType>
class ExecutableClosure {
public:
explicit XlaExecutableClosure(
xla::LocalClient* client, xla::LocalExecutable* executable,
explicit ExecutableClosure(
ClientType* client, ExecutableType* executable,
const XlaCompiler::CompilationResult* compilation_result,
ResourceVarsSnapshot resource_var_snapshots, int num_constant_args)
: client_(client),
@ -104,11 +111,11 @@ class XlaExecutableClosure {
resource_var_snapshots_(std::move(resource_var_snapshots)),
num_constant_args_(num_constant_args) {}
XlaExecutableClosure(XlaExecutableClosure&&) = default;
XlaExecutableClosure& operator=(XlaExecutableClosure&&) = default;
ExecutableClosure(ExecutableClosure&&) = default;
ExecutableClosure& operator=(ExecutableClosure&&) = default;
xla::LocalClient* client() const { return client_; }
xla::LocalExecutable* executable() const { return executable_; }
ClientType* client() const { return client_; }
ExecutableType* executable() const { return executable_; }
const XlaCompiler::CompilationResult* compilation_result() const {
return compilation_result_;
}
@ -118,24 +125,25 @@ class XlaExecutableClosure {
int num_constant_args() const { return num_constant_args_; }
private:
xla::LocalClient* client_;
xla::LocalExecutable* executable_;
ClientType* client_;
ExecutableType* executable_;
const XlaCompiler::CompilationResult* compilation_result_;
ResourceVarsSnapshot resource_var_snapshots_;
int num_constant_args_;
TF_DISALLOW_COPY_AND_ASSIGN(XlaExecutableClosure);
TF_DISALLOW_COPY_AND_ASSIGN(ExecutableClosure);
};
// This maintains a mapping from a globally unique ID to XlaExecutableClosure
// This maintains a mapping from a globally unique ID to ExecutableClosure
// instances.
class XlaExecutableClosureStore {
template <typename ExecutableType, typename ClientType>
class ExecutableClosureStore {
public:
XlaExecutableClosureStore() : key_counter_(0) {}
ExecutableClosureStore() : key_counter_(0) {}
using KeyT = string;
KeyT Produce(XlaExecutableClosure result) {
KeyT Produce(ExecutableClosure<ExecutableType, ClientType> result) {
mutex_lock l(mutex_);
KeyT key = absl::StrCat(key_counter_++);
bool insert_successful = closures_.emplace(key, std::move(result)).second;
@ -144,29 +152,38 @@ class XlaExecutableClosureStore {
return key;
}
XlaExecutableClosure Consume(const KeyT& key) {
ExecutableClosure<ExecutableType, ClientType> Consume(const KeyT& key) {
mutex_lock l(mutex_);
auto it = closures_.find(key);
DCHECK(it != closures_.end());
XlaExecutableClosure value = std::move(it->second);
ExecutableClosure<ExecutableType, ClientType> value = std::move(it->second);
closures_.erase(it);
return value;
}
static XlaExecutableClosureStore* Global() {
static XlaExecutableClosureStore* instance = new XlaExecutableClosureStore;
static ExecutableClosureStore* Global() {
static ExecutableClosureStore* instance = new ExecutableClosureStore;
return instance;
}
private:
mutex mutex_;
int64_t key_counter_ TF_GUARDED_BY(mutex_);
absl::flat_hash_map<KeyT, XlaExecutableClosure> closures_
TF_GUARDED_BY(mutex_);
absl::flat_hash_map<KeyT, ExecutableClosure<ExecutableType, ClientType>>
closures_ TF_GUARDED_BY(mutex_);
TF_DISALLOW_COPY_AND_ASSIGN(XlaExecutableClosureStore);
TF_DISALLOW_COPY_AND_ASSIGN(ExecutableClosureStore);
};
using XlaExecutableClosure =
ExecutableClosure<xla::LocalExecutable, xla::LocalClient>;
using XlaExecutableClosureStore =
ExecutableClosureStore<xla::LocalExecutable, xla::LocalClient>;
using PjRtExecutableClosure =
ExecutableClosure<xla::PjRtLoadedExecutable, xla::PjRtClient>;
using PjRtExecutableClosureStore =
ExecutableClosureStore<xla::PjRtLoadedExecutable, xla::PjRtClient>;
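// Rough usage sketch (assumed, mirroring the XlaCompile/XlaRun pairing later
// in this file): the compile kernel hands a closure to the store and emits its
// key as a tensor, and the matching run kernel consumes it.
//   PjRtExecutableClosureStore::KeyT key =
//       PjRtExecutableClosureStore::Global()->Produce(PjRtExecutableClosure(
//           pjrt_client, pjrt_executable, kernel,
//           std::move(variables_snapshot), constants_.size()));
//   ...
//   PjRtExecutableClosure closure =
//       PjRtExecutableClosureStore::Global()->Consume(key);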
se::Stream* GetStream(OpKernelContext* ctx) {
return ctx->op_device_context() ? ctx->op_device_context()->stream()
: nullptr;
@ -185,6 +202,111 @@ XlaComputationLaunchContext GetLaunchContext(
return launch_context;
}
Status GetTaskName(const std::string_view device_name, std::string* task_name) {
string ignored;
if (!DeviceNameUtils::SplitDeviceName(device_name, task_name, &ignored)) {
return errors::InvalidArgument("Unable to parse device name: ",
device_name);
}
return OkStatus();
}
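// For example (assuming a fully specified device name):
//   GetTaskName("/job:worker/replica:0/task:0/device:GPU:0", &task_name)
// returns OkStatus() and sets task_name to "/job:worker/replica:0/task:0",
// discarding the trailing "device:GPU:0" component.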
// Provide SendDeviceMemoryFunction for XLA host callbacks. This callback
// handles transferring from device to host.
xla::SendDeviceMemoryFunction GetSendDeviceMemoryFunction(
OpKernelContext* ctx) {
return
[ctx](int64_t channel_id, se::Stream* stream, const xla::Shape& shape,
const se::DeviceMemoryBase& device_memory_base,
const absl::flat_hash_map<std::string, std::string>& frontend_attrs)
-> StatusOr<tsl::AsyncValueRef<se::Event>> {
auto iter = frontend_attrs.find("_xla_host_transfer_rendezvous");
// Generate the Rendezvous key.
const std::string& rendezvous_key_base = iter->second;
const std::string& src_device = ctx->device()->name();
std::string task_prefix;
TF_RETURN_IF_ERROR(GetTaskName(src_device, &task_prefix));
const std::string dst_device =
absl::StrCat(task_prefix, "/device:CPU:0");
const std::string& rendezvous_key =
Rendezvous::CreateKey(src_device, /*src_incarnation=*/1, dst_device,
rendezvous_key_base, FrameAndIter(0, 0));
VLOG(2) << "Rendezvous Key for receiving at host: " << rendezvous_key;
RendezvousInterface::ParsedKey parsed_key;
TF_RETURN_IF_ERROR(Rendezvous::ParseKey(rendezvous_key, &parsed_key));
tsl::AsyncValueRef<se::Event> done_event =
tsl::MakeConstructedAsyncValueRef<se::Event>(stream->parent());
if (!done_event->Init()) {
return errors::Internal(
"Failed to initialize done event (channel_id=%d)", channel_id);
}
Rendezvous::Args args;
// Rendezvous::Args owns the device context pointer.
args.device_context = new XlaHostRecvDeviceContext(
stream, device_memory_base, shape, done_event);
Tensor host_tensor;
TF_RETURN_IF_ERROR(
ctx->rendezvous()->Send(parsed_key, args, host_tensor, false));
return std::move(done_event);
};
}
// Provide RecvDeviceMemoryFunction for XLA host callbacks. This callback
// handles transferring from host to device.
xla::RecvDeviceMemoryFunction GetRecvDeviceMemoryFunction(
OpKernelContext* ctx) {
return
[ctx](int64_t channel_id, se::Stream* stream, const xla::Shape& shape,
se::DeviceMemoryBase* device_memory_base,
const absl::flat_hash_map<std::string, std::string>& frontend_attrs)
-> StatusOr<tsl::AsyncValueRef<se::Event>> {
auto iter = frontend_attrs.find("_xla_host_transfer_rendezvous");
// Generate the Rendezvous key.
const std::string& rendezvous_key_base = iter->second;
const std::string& dst_device = ctx->device()->name();
std::string task_prefix;
TF_RETURN_IF_ERROR(GetTaskName(dst_device, &task_prefix));
const std::string src_device =
absl::StrCat(task_prefix, "/device:CPU:0");
const std::string& rendezvous_key =
Rendezvous::CreateKey(src_device, /*src_incarnation=*/1, dst_device,
rendezvous_key_base, FrameAndIter(0, 0));
VLOG(2) << "Rendezvous Key for sending from host: " << rendezvous_key;
RendezvousInterface::ParsedKey parsed_key;
TF_RETURN_IF_ERROR(Rendezvous::ParseKey(rendezvous_key, &parsed_key));
tsl::AsyncValueRef<se::Event> done_event =
tsl::MakeConstructedAsyncValueRef<se::Event>(stream->parent());
if (!done_event->Init()) {
return errors::Internal(
"Failed to initialize done event (channel_id=%d)", channel_id);
}
Rendezvous::Args args;
// Rendezvous::Args owns the device context pointer.
args.device_context = new XlaHostSendDeviceContext(
stream, device_memory_base, shape, done_event);
Tensor device_tensor;
bool is_dead;
TF_RETURN_IF_ERROR(ctx->rendezvous()->Recv(
parsed_key, args, &device_tensor, /*is_dead=*/&is_dead));
return std::move(done_event);
};
}
StatusOr<xla::ExecutionOutput> RunExecutable(
const XlaPlatformInfo& platform_info,
const XlaComputationLaunchContext& launch_context,
@ -200,6 +322,15 @@ StatusOr<xla::ExecutionOutput> RunExecutable(
run_options.set_allocator(allocator);
run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device());
run_options.set_rng_seed(GetXLARandomSeed());
// Host callbacks used for HLO send/recv.
xla::SendDeviceMemoryFunction send_function =
GetSendDeviceMemoryFunction(ctx);
run_options.set_send_device_memory_function(&send_function);
xla::RecvDeviceMemoryFunction recv_function =
GetRecvDeviceMemoryFunction(ctx);
run_options.set_recv_device_memory_function(&recv_function);
StatusOr<xla::ExecutionOutput> execution_output;
bool run_synchronous =
!stream || platform_info.platform_id() == se::host::kHostPlatformId;
@ -263,7 +394,7 @@ Status CompileToLocalExecutable(
// in the ResourceMgr.
ResourceMgr* rm = ctx->resource_manager();
if (!rm) {
return errors::Internal("No resource manager.");
return absl::InternalError("No resource manager.");
}
XlaDeviceCompiler* xla_device_compiler;
@ -312,23 +443,13 @@ Status CompileToPjRtLoadedExecutable(
// in the ResourceMgr.
ResourceMgr* rm = ctx.resource_manager();
if (!rm) {
return errors::Internal("No resource manager.");
return absl::InternalError("No resource manager.");
}
PjRtDeviceCompiler* pjrt_device_compiler;
TF_RETURN_IF_ERROR(rm->LookupOrCreate<PjRtDeviceCompiler>(
rm->default_container(), "pjrt_device_compiler", &pjrt_device_compiler,
[&](PjRtDeviceCompiler** pjrt_device_compiler) {
return BuildPjRtDeviceCompiler(platform_info, ctx.function_library(),
pjrt_device_compiler);
}));
DeviceCompilationProfiler* profiler;
TF_RETURN_IF_ERROR(rm->LookupOrCreate<DeviceCompilationProfiler>(
rm->default_container(), "pjrt_device_compilation_profiler", &profiler,
[](DeviceCompilationProfiler** profiler) {
*profiler = new DeviceCompilationProfiler();
return OkStatus();
}));
TF_RETURN_IF_ERROR(GetOrCreatePjRtDeviceCompilerAndProfiler(
platform_info, ctx.function_library(), &pjrt_device_compiler, &profiler));
// Hold the reference to the PJRT device compiler and profiler during
// evaluation. (We could probably free them sooner because the ResourceMgr
// will retain references, but this is more obviously correct.)
@ -337,8 +458,9 @@ Status CompileToPjRtLoadedExecutable(
*client = pjrt_device_compiler->client();
XlaCompiler::Options options = GenerateCompilerOptionsForPjRt(
*ctx.function_library(), ctx.device(), platform_info);
XlaCompiler::Options options =
GenerateCompilerOptionsForPjRt(*ctx.function_library(), ctx.device(),
platform_info, pjrt_device_compiler);
XlaCompiler::CompileOptions compile_options =
GenerateCompileOptions(has_ref_vars, may_alias_resource_update);
@ -474,19 +596,23 @@ void XlaLocalLaunchBase::ComputeAsync(OpKernelContext* ctx, DoneCallback done) {
compilation_result, done, inputs,
resources = resources_]() {
auto platform_info = XlaPlatformInfoFromDevice(ctx->device());
std::vector<VariableInfo> variable_infos;
OP_REQUIRES_OK_ASYNC(
ctx,
GetUpdatedVariables(ctx, inputs, resources, *compilation_result,
&variable_infos),
done);
OP_REQUIRES_OK_ASYNC(ctx, LockVariables(absl::MakeSpan(variable_infos)),
done);
OP_REQUIRES_OK_ASYNC(
ctx,
RunPjRtExecutable(*pjrt_client, inputs, variable_infos,
*compilation_result, pjrt_executable, ctx),
done);
// Separate scope so that VariableInfo locks are released before done() is
// called.
{
std::vector<VariableInfo> variable_infos;
OP_REQUIRES_OK_ASYNC(
ctx,
GetUpdatedVariables(ctx, inputs, resources, *compilation_result,
&variable_infos),
done);
OP_REQUIRES_OK_ASYNC(ctx, LockVariables(absl::MakeSpan(variable_infos)),
done);
OP_REQUIRES_OK_ASYNC(
ctx,
RunPjRtExecutable(*pjrt_client, inputs, variable_infos,
*compilation_result, pjrt_executable, ctx),
done);
}
VLOG(2) << "Done executing with PJRT.";
done();
};
@ -505,65 +631,69 @@ void XlaLocalLaunchBase::ComputeAsync(OpKernelContext* ctx, DoneCallback done) {
// Continuation of the execution, may be run in a different thread.
auto run_xla_cluster = [ctx, client, executable, compilation_result, done,
inputs, resources = resources_]() {
auto platform_info = XlaPlatformInfoFromDevice(ctx->device());
std::vector<VariableInfo> variable_infos;
OP_REQUIRES_OK_ASYNC(
ctx,
GetUpdatedVariables(ctx, inputs, resources, *compilation_result,
&variable_infos),
done);
OP_REQUIRES_OK_ASYNC(ctx, LockVariables(absl::MakeSpan(variable_infos)),
done);
std::map<int, const Tensor*> resource_var_ptrs;
for (int i = 0; i < resources.size(); i++) {
resource_var_ptrs[resources[i]] = variable_infos[i].var()->tensor();
}
std::shared_ptr<se::DeviceMemoryAllocator> allocator =
GetAllocator(ctx->device(), GetStream(ctx), platform_info);
XlaComputationLaunchContext launch_context =
GetLaunchContext(platform_info, ctx, client, allocator.get());
const xla::HloInputOutputAliasConfig& input_output_alias =
executable->executable()->module().input_output_alias_config();
StatusOr<std::vector<xla::ExecutionInput>> execution_inputs =
launch_context.PopulateInputs(
ctx, compilation_result, resource_var_ptrs,
/*missing_ctx_input_prefix=*/0, input_output_alias);
OP_REQUIRES_OK_ASYNC(ctx, execution_inputs.status(), done);
xla::gpu::GpuExecutableRunOptions gpu_options;
xla::DeviceAssignment device_assignment;
xla::ExecutableRunOptions run_options;
if (compilation_result->collective_info.has_value()) {
// Separate scope so that VariableInfo locks are released before done is
// called.
{
auto platform_info = XlaPlatformInfoFromDevice(ctx->device());
std::vector<VariableInfo> variable_infos;
OP_REQUIRES_OK_ASYNC(
ctx,
ResolveDeviceAssignment(ctx, *compilation_result->collective_info,
run_options, device_assignment, gpu_options),
GetUpdatedVariables(ctx, inputs, resources, *compilation_result,
&variable_infos),
done);
OP_REQUIRES_OK_ASYNC(ctx, LockVariables(absl::MakeSpan(variable_infos)),
done);
std::map<int, const Tensor*> resource_var_ptrs;
for (int i = 0; i < resources.size(); i++) {
resource_var_ptrs[resources[i]] = variable_infos[i].var()->tensor();
}
std::shared_ptr<se::DeviceMemoryAllocator> allocator =
GetAllocator(ctx->device(), GetStream(ctx), platform_info);
XlaComputationLaunchContext launch_context =
GetLaunchContext(platform_info, ctx, client, allocator.get());
const xla::HloInputOutputAliasConfig& input_output_alias =
executable->executable()->module().input_output_alias_config();
StatusOr<std::vector<xla::ExecutionInput>> execution_inputs =
launch_context.PopulateInputs(
ctx, compilation_result, resource_var_ptrs,
/*missing_ctx_input_prefix=*/0, input_output_alias);
OP_REQUIRES_OK_ASYNC(ctx, execution_inputs.status(), done);
xla::gpu::GpuExecutableRunOptions gpu_options;
xla::DeviceAssignment device_assignment;
xla::ExecutableRunOptions run_options;
if (compilation_result->collective_info.has_value()) {
OP_REQUIRES_OK_ASYNC(ctx,
ResolveDeviceAssignment(
ctx, *compilation_result->collective_info,
run_options, device_assignment, gpu_options),
done);
}
// Hardcode run id to always be zero: TF distributed strategy
// differentiates between subsequent runs using dependency edges. This
// is safe, as only TF dist-strat can produce distributed ops, and we
// can rely on TF dist-strat invariants.
xla::RunId run_id(0);
run_options.set_run_id(run_id);
StatusOr<xla::ExecutionOutput> execution_output = RunExecutable(
platform_info, launch_context, std::move(*execution_inputs),
run_options, executable, ctx, allocator.get());
OP_REQUIRES_ASYNC(ctx, execution_output.ok(), execution_output.status(),
done);
OP_REQUIRES_OK_ASYNC(
ctx,
launch_context.PopulateOutputs(
ctx, compilation_result, execution_output->ConsumeResult(),
/*missing_ctx_input_prefix=*/0, absl::MakeSpan(variable_infos),
input_output_alias, resource_var_ptrs),
done);
VLOG(1) << "Done";
}
// Hardcode run id to always be zero: TF distributed strategy
// differentiates between subsequent runs using dependency edges. This
// is safe, as only TF dist-strat can produce distributed ops, and we
// can rely on TF dist-strat invariants.
xla::RunId run_id(0);
run_options.set_run_id(run_id);
StatusOr<xla::ExecutionOutput> execution_output = RunExecutable(
platform_info, launch_context, std::move(*execution_inputs),
run_options, executable, ctx, allocator.get());
OP_REQUIRES_ASYNC(ctx, execution_output.ok(), execution_output.status(),
done);
OP_REQUIRES_OK_ASYNC(
ctx,
launch_context.PopulateOutputs(
ctx, compilation_result, execution_output->ConsumeResult(),
/*missing_ctx_input_prefix=*/0, absl::MakeSpan(variable_infos),
input_output_alias, resource_var_ptrs),
done);
VLOG(1) << "Done";
done();
};
@ -658,9 +788,11 @@ XlaCompileOp::XlaCompileOp(OpKernelConstruction* ctx)
void XlaCompileOp::Compute(OpKernelContext* ctx) {
VLOG(3) << "XlaCompileOp " << def().name()
<< (must_compile_ ? "(must-compile)" : "");
xla::LocalClient* client;
const XlaCompiler::CompilationResult* kernel;
xla::LocalExecutable* executable;
const XlaCompiler::CompilationResult* kernel = nullptr;
xla::LocalClient* client = nullptr;
xla::LocalExecutable* executable = nullptr;
xla::PjRtClient* pjrt_client = nullptr;
xla::PjRtLoadedExecutable* pjrt_executable = nullptr;
ResourceVarsSnapshot variables_snapshot;
std::vector<const Tensor*> inputs = InputsFromContext(ctx);
@ -678,6 +810,11 @@ void XlaCompileOp::Compute(OpKernelContext* ctx) {
: DeviceCompileMode::kLazy;
}();
bool use_pjrt =
GetXlaOpsCommonFlags()
->tf_xla_use_device_api.IsEnabledInXlaCompileAndRunForDevice(
platform_info_.device_type());
if (GetXlaOpsCommonFlags()->tf_xla_always_defer_compilation ||
cannot_compile_cluster) {
executable = nullptr;
@ -691,22 +828,33 @@ void XlaCompileOp::Compute(OpKernelContext* ctx) {
// Do not alias resource updates as locking variables in XlaCompile and
// unlocking them in XlaRun may lead to deadlocks.
const Status status = CompileToLocalExecutable(
ctx, function_, has_ref_vars_, platform_info_, args, compile_mode,
/*may_alias_resource_update=*/false, &client, &kernel, &executable);
Status status;
if (use_pjrt) {
VLOG(2) << "Using PJRT for compilation. Function name: "
<< function_.name();
status = CompileToPjRtLoadedExecutable(
*ctx, platform_info_, function_, args, compile_mode, has_ref_vars_,
/*may_alias_resource_update=*/false, &kernel, &pjrt_client,
&pjrt_executable);
} else {
status = CompileToLocalExecutable(
ctx, function_, has_ref_vars_, platform_info_, args, compile_mode,
/*may_alias_resource_update=*/false, &client, &kernel, &executable);
}
if (compile_mode != DeviceCompileMode::kLazy ||
status.code() != error::UNIMPLEMENTED) {
OP_REQUIRES_OK(ctx, status);
}
if (status.code() == error::UNIMPLEMENTED) {
LOG(WARNING) << "Compilation failed:" << status.ToString()
LOG(WARNING) << "Compilation failed:" << status
<< ". Falling back to TF function call.";
BroadcastOptimizationRemark(
XlaOptimizationRemark::UNIMPLEMENTED_OPERATION, status.ToString())
.IgnoreError();
executable = nullptr;
pjrt_executable = nullptr;
mutex_lock guard(cannot_compile_cluster_mu_);
cannot_compile_cluster_ = true;
}
@ -718,28 +866,36 @@ void XlaCompileOp::Compute(OpKernelContext* ctx) {
Allocator* cpu_allocator = ctx->device()->GetAllocator(host_alloc_attrs);
// Async compilation returns nullptr executable without an error.
if (!executable) {
if (!executable && !pjrt_executable) {
DCHECK(!must_compile_);
Tensor compilation_key(cpu_allocator, DT_STRING, TensorShape({}));
Tensor compilation_successful(cpu_allocator, DT_BOOL, TensorShape({}));
compilation_successful.scalar<bool>()() = false;
ctx->set_output(0, Tensor(cpu_allocator, DT_STRING, TensorShape({})));
ctx->set_output(0, compilation_key);
ctx->set_output(1, compilation_successful);
return;
}
// Each execution of an XlaCompile op creates a new XlaExecutableClosure, even
// Each execution of an XlaCompile op creates a new ExecutableClosure, even
// if it didn't have to compile the cluster because of a compilation-cache
// hit. This is because we at least need new snapshots of the resource
// variables.
XlaExecutableClosureStore::KeyT key =
XlaExecutableClosureStore::Global()->Produce(XlaExecutableClosure(
client, executable, kernel, std::move(variables_snapshot),
constants_.size()));
Tensor compilation_key(cpu_allocator, DT_STRING, TensorShape({}));
compilation_key.flat<tstring>()(0) = key;
if (use_pjrt) {
PjRtExecutableClosureStore::KeyT key =
PjRtExecutableClosureStore::Global()->Produce(PjRtExecutableClosure(
pjrt_client, pjrt_executable, kernel, std::move(variables_snapshot),
constants_.size()));
compilation_key.flat<tstring>()(0) = key;
VLOG(2) << "Compiled with PJRT. compilation_key: " << key;
} else {
XlaExecutableClosureStore::KeyT key =
XlaExecutableClosureStore::Global()->Produce(XlaExecutableClosure(
client, executable, kernel, std::move(variables_snapshot),
constants_.size()));
compilation_key.flat<tstring>()(0) = key;
VLOG(2) << "Compiled with XLA. compilation_key: " << key;
}
Tensor compilation_successful(cpu_allocator, DT_BOOL, TensorShape({}));
compilation_successful.flat<bool>()(0) = true;

View File

@ -24,5 +24,5 @@ py_library(
name = "xla_ops_grad",
srcs = ["xla_ops_grad.py"],
srcs_version = "PY3",
deps = ["//tensorflow/python:framework_ops"],
deps = ["//tensorflow/python/framework:ops"],
)

View File

@ -77,8 +77,8 @@ cc_library(
"//tensorflow/core:lib_internal",
"//tensorflow/core:ops",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/platform:path",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
],
)

View File

@ -19,6 +19,7 @@ limitations under the License.
#include <utility>
#include <vector>
#include "absl/status/status.h"
#include "absl/strings/string_view.h"
#include "tensorflow/compiler/jit/xla_activity_listener.h"
#include "tensorflow/core/graph/graph_def_builder.h"
@ -51,7 +52,7 @@ class JitCompilationListener : public XlaActivityListener {
bool expect_persistent_cache_use) {
for (const auto& activity : activity_history_) {
if (activity.used_persistent_cache() != expect_persistent_cache_use) {
return errors::FailedPrecondition("Unexpected listener history.");
return absl::FailedPreconditionError("Unexpected listener history.");
}
}
return OkStatus();

View File

@ -23,14 +23,15 @@ Status TfGraphToHloCompiler::Compile(const XlaCompiler::CompileOptions& options,
const NameAttrList& function,
absl::Span<const XlaArgument> args,
XlaCompilationResult* result) {
return xla_compiler_.CompileFunction(options, function, args, result);
return ADD_SOURCE_LOCATION(
xla_compiler_.CompileFunction(options, function, args, result));
}
Status TfGraphToHloCompiler::CompileSingleOp(
const XlaCompiler::CompileOptions& options, const OpKernelContext* ctx,
absl::Span<const XlaArgument> args, XlaCompilationResult* result) {
return xla_compiler_.CompileSingleOp(
options, XlaCompiler::SingleOpCompileArgument(*ctx), args, result);
return ADD_SOURCE_LOCATION(xla_compiler_.CompileSingleOp(
options, XlaCompiler::SingleOpCompileArgument(*ctx), args, result));
}
} // namespace tensorflow

View File

@ -19,6 +19,7 @@ limitations under the License.
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>
@ -43,6 +44,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/executable_run_options.h"
#include "tensorflow/compiler/xla/hlo/ir/hlo_input_output_alias_config.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/compiler/xla/pjrt/tf_pjrt_client.h"
#include "tensorflow/compiler/xla/service/executable.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h"
#include "tensorflow/core/framework/function.h"
@ -53,6 +55,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/statusor.h"
#include "tensorflow/core/tfrt/common/pjrt_util.h"
#include "tensorflow/tsl/platform/errors.h"
namespace tensorflow {
@ -169,28 +172,12 @@ Status XlaCompileOnDemandOp::Compile(
DeviceCompilationProfiler** profiler,
const XlaCompiler::CompilationResult** result,
xla::PjRtLoadedExecutable** executable) {
// We store information about the JIT-compiled XLA computation
// in the ResourceMgr.
ResourceMgr* rm = ctx->resource_manager();
if (!rm) {
return errors::Internal("No resource manager.");
}
TF_RETURN_IF_ERROR(GetOrCreatePjRtDeviceCompilerAndProfiler(
platform_info_, ctx->function_library(), pjrt_device_compiler, profiler));
TF_RETURN_IF_ERROR(rm->LookupOrCreate<PjRtDeviceCompiler>(
rm->default_container(), "pjrt_device_compiler", pjrt_device_compiler,
[&](PjRtDeviceCompiler** pjrt_device_compiler) {
return BuildPjRtDeviceCompiler(platform_info_, ctx->function_library(),
pjrt_device_compiler);
}));
TF_RETURN_IF_ERROR(rm->LookupOrCreate<DeviceCompilationProfiler>(
rm->default_container(), "pjrt_device_compilation_profiler", profiler,
[](DeviceCompilationProfiler** profiler) {
*profiler = new DeviceCompilationProfiler();
return OkStatus();
}));
XlaCompiler::Options options = GenerateCompilerOptionsForPjRt(
*(ctx->function_library()), ctx->device(), platform_info_);
XlaCompiler::Options options =
GenerateCompilerOptionsForPjRt(*(ctx->function_library()), ctx->device(),
platform_info_, *pjrt_device_compiler);
// No detailed logging for on demand op.
options.detailed_logging = false;
XlaCompiler::CompileOptions compile_options = GetCompileOptions(true);

View File

@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/jit/xla_compile_util.h"
#include <memory>
#include <string>
#include <vector>
#include "tensorflow/compiler/jit/flags.h"
@ -24,6 +25,11 @@ limitations under the License.
#include "tensorflow/core/util/determinism.h"
namespace tensorflow {
namespace {
constexpr const char* kPjRtDeviceCompilerResourceName = "pjrt_device_compiler";
constexpr const char* kPjRtDeviceCompilationProfilerResourceName =
"pjrt_device_compilation_profiler";
} // namespace
StatusOr<std::unique_ptr<Graph>> CreateSingleOpGraph(
const NodeDef& node_def, absl::Span<const XlaArgument> args,
@ -69,7 +75,18 @@ StatusOr<std::unique_ptr<Graph>> CreateSingleOpGraph(
bool UsePjRtForSingleDeviceCompilation(const DeviceType& device_type) {
const auto& rollout_config = GetXlaOpsCommonFlags()->tf_xla_use_device_api;
return rollout_config.IsEnabledInXlaLaunchForDevice(device_type) ||
rollout_config.IsEnabledInXlaCompileOnDemandForDevice(device_type);
rollout_config.IsEnabledInXlaCompileOnDemandForDevice(device_type) ||
rollout_config.IsEnabledInXlaCompileAndRunForDevice(device_type);
}
std::string GetPjRtDeviceCompilerResourceName(const DeviceType& device_type) {
return absl::StrCat(kPjRtDeviceCompilerResourceName, "_",
device_type.type_string());
}
std::string GetPjRtDeviceCompilationProfilerResourceName(
const DeviceType& device_type) {
return absl::StrCat(kPjRtDeviceCompilationProfilerResourceName, "_",
device_type.type_string());
}
} // namespace tensorflow

View File

@ -17,6 +17,7 @@ limitations under the License.
#define TENSORFLOW_COMPILER_JIT_XLA_COMPILE_UTIL_H_
#include <memory>
#include <string>
#include "tensorflow/compiler/tf2xla/xla_argument.h"
#include "tensorflow/core/graph/graph.h"
@ -47,6 +48,14 @@ StatusOr<std::unique_ptr<Graph>> CreateSingleOpGraph(
// Checks if single device compilation and execution with PJRT is enabled for
// `device_type` in either the XlaLaunch op or the XlaCompileOnDemand op.
bool UsePjRtForSingleDeviceCompilation(const DeviceType& device_type);
// Gets the resource name of the PjRt DeviceCompiler for `device_type`.
std::string GetPjRtDeviceCompilerResourceName(const DeviceType& device_type);
// Gets the resource name of the DeviceCompilationProfiler for `device_type`
// when PjRt is used for compilation and execution.
std::string GetPjRtDeviceCompilationProfilerResourceName(
const DeviceType& device_type);
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_JIT_XLA_COMPILE_UTIL_H_

View File

@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/core/framework/fake_input.h"
#include "tensorflow/core/kernels/ops_testutil.h"
#include "tensorflow/core/tpu/tpu_defs.h"
namespace tensorflow {
namespace {
@ -118,5 +119,31 @@ TEST(XlaCompileUtilTest, PjRtXlaCompileOnDemandFlagTest) {
EXPECT_FALSE(UsePjRtForSingleDeviceCompilation(DeviceType(DEVICE_CPU)));
}
TEST(XlaCompileUtilTest, PjRtDeviceCompilerResourceName) {
EXPECT_EQ(GetPjRtDeviceCompilerResourceName(DeviceType(DEVICE_TPU)),
"pjrt_device_compiler_TPU");
EXPECT_EQ(GetPjRtDeviceCompilerResourceName(DeviceType(DEVICE_TPU_NODE)),
"pjrt_device_compiler_TPU");
EXPECT_EQ(GetPjRtDeviceCompilerResourceName(DeviceType(DEVICE_CPU)),
"pjrt_device_compiler_CPU");
EXPECT_EQ(GetPjRtDeviceCompilerResourceName(DeviceType(DEVICE_GPU)),
"pjrt_device_compiler_GPU");
}
TEST(XlaCompileUtilTest, PjRtDeviceCompilationProfilerResourceName) {
EXPECT_EQ(
GetPjRtDeviceCompilationProfilerResourceName(DeviceType(DEVICE_TPU)),
"pjrt_device_compilation_profiler_TPU");
EXPECT_EQ(
GetPjRtDeviceCompilationProfilerResourceName(DeviceType(DEVICE_TPU_NODE)),
"pjrt_device_compilation_profiler_TPU");
EXPECT_EQ(
GetPjRtDeviceCompilationProfilerResourceName(DeviceType(DEVICE_CPU)),
"pjrt_device_compilation_profiler_CPU");
EXPECT_EQ(
GetPjRtDeviceCompilationProfilerResourceName(DeviceType(DEVICE_GPU)),
"pjrt_device_compilation_profiler_GPU");
}
} // namespace
} // namespace tensorflow

View File

@ -15,10 +15,14 @@ limitations under the License.
#include "tensorflow/compiler/jit/xla_compiler_options_util.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
namespace tensorflow {
namespace {
using XlaDeviceCompiler =
DeviceCompiler<xla::LocalExecutable, xla::LocalClient>;
using PjRtDeviceCompiler =
DeviceCompiler<xla::PjRtLoadedExecutable, xla::PjRtClient>;
inline void LogOptions(const XlaCompiler::Options& options) {
VLOG(2) << "XlaCompiler::Options[device_type=" << options.device_type
@ -81,7 +85,8 @@ XlaCompiler::Options GenerateCompilerOptionsForTfrtTpu(
XlaCompiler::Options GenerateCompilerOptionsForPjRt(
const FunctionLibraryRuntime& function_library,
const DeviceBase* device_base, const XlaPlatformInfo& platform_info) {
const DeviceBase* device_base, const XlaPlatformInfo& platform_info,
const PjRtDeviceCompiler* pjrt_device_compiler) {
XlaCompiler::Options options;
options.device_ordinal = device_base->parsed_name().id;
options.flib_def = function_library.GetFunctionLibraryDefinition();
@ -96,8 +101,9 @@ XlaCompiler::Options GenerateCompilerOptionsForPjRt(
options.device_type = metadata->jit_device_type();
options.shape_determination_fns =
metadata->default_shape_determination_fns();
} else if (pjrt_device_compiler != nullptr) {
options.device_type = pjrt_device_compiler->device_type();
}
// TODO(b/255826209): Set options for non-XLA devices once PjRt supports them.
// TODO(b/255826209): Confirm below options are correctly set after testing.
options.allow_cpu_custom_calls = false;
options.alias_passthrough_params = false;

View File

@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/compiler/jit/xla_platform_info.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
namespace tensorflow {
@ -39,9 +40,12 @@ XlaCompiler::Options GenerateCompilerOptionsForTfrtTpu(
// Returns created options for XLA compiler when PjRt (Device API) is used for
// compilation and execution.
// TODO(b/255826209): Remove default arg once PjRtCompileOnDemand op is deleted.
XlaCompiler::Options GenerateCompilerOptionsForPjRt(
const FunctionLibraryRuntime& function_library,
const DeviceBase* device_base, const XlaPlatformInfo& platform_info);
const DeviceBase* device_base, const XlaPlatformInfo& platform_info,
const DeviceCompiler<xla::PjRtLoadedExecutable, xla::PjRtClient>*
pjrt_device_compiler = nullptr);
} // namespace tensorflow

View File

@ -23,12 +23,14 @@ limitations under the License.
#include <gtest/gtest.h>
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/jit/pjrt_device_compiler_client.h"
#include "tensorflow/compiler/jit/test_util.h"
#include "tensorflow/compiler/jit/xla_platform_info.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/device.h"
#include "tensorflow/core/framework/device_base.h"
@ -42,18 +44,31 @@ using XlaDeviceCompiler =
DeviceCompiler<xla::LocalExecutable, xla::LocalClient>;
using XlaDeviceExecutablePersistor =
DeviceExecutablePersistor<xla::LocalExecutable, xla::LocalClient>;
using PjRtDeviceCompiler =
DeviceCompiler<xla::PjRtLoadedExecutable, xla::PjRtClient>;
using PjRtDeviceExecutablePersistor =
DeviceExecutablePersistor<xla::PjRtLoadedExecutable, xla::PjRtClient>;
XlaDeviceCompiler* CreateXlaDeviceCompiler(
const XlaDeviceExecutablePersistor::Config& persistor_config,
DeviceType device_type, xla::LocalClient* local_client) {
XlaDeviceCompiler* CreateXlaDeviceCompiler(DeviceType device_type,
xla::LocalClient* local_client) {
auto persistor = std::make_unique<XlaDeviceExecutablePersistor>(
std::move(persistor_config), device_type);
XlaDeviceExecutablePersistor::Config(), device_type);
auto compiler_client =
std::make_unique<XlaDeviceCompilerClient>(local_client);
return new XlaDeviceCompiler(std::move(persistor),
std::move(compiler_client));
}
PjRtDeviceCompiler* CreatePjRtDeviceCompiler(DeviceType device_type,
xla::PjRtClient* pjrt_client) {
auto persistor = std::make_unique<PjRtDeviceExecutablePersistor>(
PjRtDeviceExecutablePersistor::Config(), device_type);
auto compiler_client =
std::make_unique<PjRtDeviceCompilerClient>(pjrt_client);
return new PjRtDeviceCompiler(std::move(persistor),
std::move(compiler_client));
}
std::vector<XlaShapeLayoutHelpers::ShapeDeterminationFns>
GetShapeDeterminationFns() {
XlaHelpers::ShapeRepresentationFn shape_representation_fn =
@ -160,6 +175,45 @@ TEST_F(XlaCompilerOptionsTest, PjRtOptionsPjRtBaseDevice) {
tensorflow::XlaLayoutPreference::kTpuPreferLinearLayout);
}
TEST_F(XlaCompilerOptionsTest, PjRtOptionsNonXlaDevice) {
device_setup_.AddDevicesAndSetUp({DEVICE_CPU});
Device* device = device_setup_.GetDevice(DEVICE_CPU);
DeviceType compilation_device_type = DeviceType(DEVICE_CPU_XLA_JIT);
XlaPlatformInfo platform_info(compilation_device_type,
/*platform_id=*/nullptr,
/*xla_device_metadata=*/nullptr,
/*pjrt_device_metadata=*/nullptr,
/*device_allocator=*/nullptr);
auto pjrt_device_compiler =
CreatePjRtDeviceCompiler(compilation_device_type, nullptr);
core::ScopedUnref pjrt_device_compiler_ref(pjrt_device_compiler);
XlaCompiler::Options options = GenerateCompilerOptionsForPjRt(
*device_setup_.flr(), device, platform_info, pjrt_device_compiler);
EXPECT_EQ(options.device_type, compilation_device_type);
EXPECT_EQ(options.device_ordinal, 0);
EXPECT_NE(options.flib_def, nullptr);
EXPECT_EQ(options.graph_def_version, TF_GRAPH_DEF_VERSION);
EXPECT_FALSE(options.allow_cpu_custom_calls);
EXPECT_FALSE(options.alias_passthrough_params);
EXPECT_FALSE(options.detailed_logging);
// Check whether options have default shape determination functions set.
TF_ASSERT_OK_AND_ASSIGN(
auto shape, options.shape_determination_fns.shape_representation_fn(
TensorShape(), DT_FLOAT, false,
tensorflow::XlaLayoutPreference::kNoPreference));
xla::ShapeProto shape_proto;
shape_proto.set_element_type(xla::PrimitiveType::F32);
shape_proto.mutable_layout();
EXPECT_EQ(shape, xla::Shape(shape_proto));
EXPECT_EQ(options.shape_determination_fns.layout_preference_fn(
TensorShape(), DT_FLOAT, std::nullopt),
tensorflow::XlaLayoutPreference::kNoPreference);
}
TEST_F(XlaCompilerOptionsTest, XlaOptions) {
device_setup_.AddDevicesAndSetUp({DEVICE_XLA_GPU});
Device* device = device_setup_.GetDevice(DEVICE_XLA_GPU);
@ -168,8 +222,8 @@ TEST_F(XlaCompilerOptionsTest, XlaOptions) {
DeviceType device_type = DeviceType(DEVICE_XLA_GPU);
DeviceType compilation_device_type = DeviceType(DEVICE_GPU_XLA_JIT);
auto xla_device_compiler = CreateXlaDeviceCompiler(
XlaDeviceExecutablePersistor::Config(), compilation_device_type, client);
auto xla_device_compiler =
CreateXlaDeviceCompiler(compilation_device_type, client);
core::ScopedUnref xla_device_compiler_ref(xla_device_compiler);
se::Platform::Id platform_id = se::host::kHostPlatformId;
@ -208,8 +262,8 @@ TEST_F(XlaCompilerOptionsTest, XlaOptionsHasRefVarsNoXlaDeviceMetadata) {
DeviceType device_type = DeviceType(DEVICE_CPU);
DeviceType compilation_device_type = DeviceType(DEVICE_CPU_XLA_JIT);
auto xla_device_compiler = CreateXlaDeviceCompiler(
XlaDeviceExecutablePersistor::Config(), compilation_device_type, client);
auto xla_device_compiler =
CreateXlaDeviceCompiler(compilation_device_type, client);
core::ScopedUnref xla_device_compiler_ref(xla_device_compiler);
se::Platform::Id platform_id = se::host::kHostPlatformId;
@ -249,8 +303,8 @@ TEST_F(XlaCompilerOptionsTest, TfRtTpuOptions) {
xla::LocalClient* client = xla::ClientLibrary::LocalClientOrDie();
DeviceType compilation_device_type = DeviceType(DEVICE_TPU_XLA_JIT);
auto xla_device_compiler = CreateXlaDeviceCompiler(
XlaDeviceExecutablePersistor::Config(), compilation_device_type, client);
auto xla_device_compiler =
CreateXlaDeviceCompiler(compilation_device_type, client);
core::ScopedUnref xla_device_compiler_ref(xla_device_compiler);
XlaCompiler::Options options = GenerateCompilerOptionsForTfrtTpu(

View File

@ -0,0 +1,49 @@
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/xla_host_recv_device_context.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
namespace tensorflow {
void XlaHostRecvDeviceContext::CopyDeviceTensorToCPU(
const Tensor* device_tensor, StringPiece tensor_name, Device* device,
Tensor* cpu_tensor, StatusCallback done) {
DataType dtype = EncodePrimitiveTypeAsDataType(shape_.element_type()).value();
TensorShape tensor_shape;
Status status = XLAShapeToTensorShape(shape_, &tensor_shape);
if (!status.ok()) {
done(status);
return;
}
*cpu_tensor = Tensor(dtype, tensor_shape);
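// Enqueue the device->host copy on `stream_` and record `done_event_` so
// waiters can observe completion; block until the stream drains before
// invoking `done`.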
stream_->ThenMemcpy(cpu_tensor->data(), device_memory_base_,
device_memory_base_.size());
stream_->ThenRecordEvent(&done_event_.get());
if (auto st = stream_->BlockHostUntilDone(); !st.ok()) {
done_event_.SetError(absl::InternalError(absl::StrFormat(
"failed to synchronize send operation with a stream: %s",
st.ToString())));
return;
}
done_event_.SetStateConcrete();
done(OkStatus());
}
} // namespace tensorflow

View File

@ -0,0 +1,92 @@
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_JIT_XLA_HOST_RECV_DEVICE_CONTEXT_H_
#define TENSORFLOW_COMPILER_JIT_XLA_HOST_RECV_DEVICE_CONTEXT_H_
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/stream_executor/device_memory.h"
#include "tensorflow/compiler/xla/stream_executor/stream.h"
#include "tensorflow/core/framework/device_base.h"
#include "tfrt/concurrency/async_value_ref.h" // from @tf_runtime
namespace tensorflow {
// XlaHostRecvDeviceContext is a DeviceContext that is intended to be
// used to transfer data from device to host using Rendezvous. It transfers the
// content of `device_memory_base` with `shape` using `stream`. Only the
// `CopyDeviceTensorToCPU` method is implemented. The `done_event` is marked as
// Concrete once the transfer is completed.
//
// Example usage:
//
// Device device;
// stream_executor::Stream stream(executor);
// Tensor device_tensor(device_allocator, DT_FLOAT, TensorShape({2, 2}));
// se::DeviceMemoryBase gpu_dst{device_tensor.data(), 4 * sizeof(float)};
// xla::Shape shape(xla::F32, {2, 2}, {}, {})
// tsl::AsyncValueRef<se::Event> done_event =
// tsl::MakeConstructedAsyncValueRef<se::Event>(stream.parent());
// done_event->Init();
// Tensor dest_cpu_tensor;
//
// XlaHostRecvDeviceContext device_context(&stream, gpu_dst,
// shape, done_event);
// device_context.CopyDeviceTensorToCPUSync(
// &device_tensor, "", &device, &dest_cpu_tensor);
class XlaHostRecvDeviceContext : public DeviceContext {
public:
XlaHostRecvDeviceContext(se::Stream* stream,
const se::DeviceMemoryBase& device_memory_base,
const xla::Shape& shape,
tsl::AsyncValueRef<se::Event>& done_event)
: stream_(stream),
device_memory_base_(device_memory_base),
shape_(shape),
done_event_(done_event) {}
void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device,
Tensor* device_tensor, StatusCallback done,
bool sync_dst_compute) const override {
done(errors::Internal("host->device copy not implemented."));
}
// Copies `device_memory_base_` with `shape_` into `cpu_tensor`.
// `device_tensor` is unused.
void CopyDeviceTensorToCPU(const Tensor* device_tensor,
StringPiece tensor_name, Device* device,
Tensor* cpu_tensor, StatusCallback done) override;
void CopyTensorInSameDevice(const Tensor* input_tensor, Device* device,
Tensor* output_tensor,
StatusCallback done) const override {
done(errors::Internal("device->device copy not implemented."));
}
private:
se::Stream* stream_; // Not owned.
// This is copied rather than a reference or pointer since its lifetime
// is not guaranteed to outlast the original object. Object slicing is
// not an issue here since only DeviceMemoryBase methods/members are used.
const se::DeviceMemoryBase device_memory_base_;
const xla::Shape shape_;
tsl::AsyncValueRef<se::Event> done_event_;
TF_DISALLOW_COPY_AND_ASSIGN(XlaHostRecvDeviceContext);
};
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_JIT_XLA_HOST_RECV_DEVICE_CONTEXT_H_

View File

@ -0,0 +1,39 @@
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/xla_host_send_device_context.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
namespace tensorflow {
void XlaHostSendDeviceContext::CopyCPUTensorToDevice(
const Tensor* cpu_tensor, Device* device, Tensor* device_tensor,
StatusCallback done, bool sync_dst_compute) const {
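// Enqueue the host->device copy on `stream_` and record `done_event_` so
// waiters can observe completion; block until the stream drains before
// invoking `done`.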
stream_->ThenMemcpy(device_memory_base_, cpu_tensor->data(),
device_memory_base_->size());
stream_->ThenRecordEvent(&done_event_.get());
if (auto st = stream_->BlockHostUntilDone(); !st.ok()) {
done_event_.SetError(absl::InternalError(absl::StrFormat(
"failed to synchronize send operation with a stream: %s",
st.ToString())));
return;
}
done_event_.SetStateConcrete();
done(OkStatus());
}
} // namespace tensorflow

View File

@ -0,0 +1,89 @@
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_JIT_XLA_HOST_SEND_DEVICE_CONTEXT_H_
#define TENSORFLOW_COMPILER_JIT_XLA_HOST_SEND_DEVICE_CONTEXT_H_
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/stream_executor/device_memory.h"
#include "tensorflow/compiler/xla/stream_executor/stream.h"
#include "tensorflow/core/framework/device_base.h"
#include "tfrt/concurrency/async_value_ref.h" // from @tf_runtime
namespace tensorflow {
// XlaHostSendDeviceContext is a DeviceContext that is intended to be
// used to transfer data from host to device using Rendezvous. It transfers the
// content of `device_memory_base` with `shape` using `stream`. Only the
// `CopyCPUTensorToDevice` method is implemented. The `done_event` is marked as
// Concrete once the transfer is completed.
//
// Example usage:
//
// Device device;
// stream_executor::Stream stream(executor);
// Tensor cpu_tensor(host_allocator, DT_FLOAT, TensorShape({2, 2}));
// Tensor device_tensor(device_allocator, DT_FLOAT, TensorShape({2, 2}));
// se::DeviceMemoryBase gpu_dst{device_tensor.data(), 4 * sizeof(float)};
// xla::Shape shape(xla::F32, {2, 2}, {}, {})
// tsl::AsyncValueRef<se::Event> done_event =
// tsl::MakeConstructedAsyncValueRef<se::Event>(stream.parent());
// done_event->Init();
//
// XlaHostSendDeviceContext device_context(&stream, &gpu_dst,
// shape, done_event);
// device_context.CopyCPUTensorToDeviceSync(
// &cpu_tensor, &device, &device_tensor);
class XlaHostSendDeviceContext : public DeviceContext {
public:
XlaHostSendDeviceContext(se::Stream* stream,
se::DeviceMemoryBase* device_memory_base,
const xla::Shape& shape,
tsl::AsyncValueRef<se::Event>& done_event)
: stream_(stream),
device_memory_base_(device_memory_base),
shape_(shape),
done_event_(done_event) {}
// Copies 'cpu_tensor' to `device_memory_base_` with `shape_`.
// `device_tensor` is unused.
void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device,
Tensor* device_tensor, StatusCallback done,
bool sync_dst_compute) const override;
void CopyDeviceTensorToCPU(const Tensor* device_tensor,
StringPiece tensor_name, Device* device,
Tensor* cpu_tensor, StatusCallback done) override {
done(errors::Internal("host->device copy not implemented."));
}
void CopyTensorInSameDevice(const Tensor* input_tensor, Device* device,
Tensor* output_tensor,
StatusCallback done) const override {
done(errors::Internal("device->device copy not implemented."));
}
private:
se::Stream* stream_; // Not owned.
se::DeviceMemoryBase* device_memory_base_; // Not owned.
const xla::Shape shape_;
tsl::AsyncValueRef<se::Event> done_event_;
TF_DISALLOW_COPY_AND_ASSIGN(XlaHostSendDeviceContext);
};
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_JIT_XLA_HOST_SEND_DEVICE_CONTEXT_H_

View File

@ -0,0 +1,171 @@
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <memory>
#include <string_view>
#include <utility>
#include <vector>
#include <gtest/gtest.h>
#include "tensorflow/compiler/jit/xla_host_recv_device_context.h"
#include "tensorflow/compiler/jit/xla_host_send_device_context.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/stream_executor/device_memory.h"
#include "tensorflow/compiler/xla/stream_executor/multi_platform_manager.h"
#include "tensorflow/compiler/xla/stream_executor/stream.h"
#include "tensorflow/compiler/xla/stream_executor/stream_executor.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/tsl/lib/core/status_test_util.h"
namespace tensorflow {
namespace {
class XlaHostSendRecvDeviceContextTest : public ::testing::Test {
public:
void SetDevice(const string& device_type) {
auto device_factory = DeviceFactory::GetFactory(device_type);
SessionOptions options;
std::vector<std::unique_ptr<Device>> devices;
Status s = device_factory->CreateDevices(
options, "/job:worker/replica:0/task:0", &devices);
TF_ASSERT_OK(s);
device_ = std::move(devices[0]);
AllocatorAttributes host_alloc_attr;
host_alloc_attr.set_on_host(true);
host_allocator_ = device_->GetAllocator(host_alloc_attr);
AllocatorAttributes device_alloc_attr;
device_alloc_attr.set_on_host(false);
device_allocator_ = device_->GetAllocator(device_alloc_attr);
}
protected:
std::unique_ptr<Device> device_;
Allocator* host_allocator_;
Allocator* device_allocator_;
};
TEST_F(XlaHostSendRecvDeviceContextTest, CopyDeviceTensorToCPU) {
SetDevice("GPU");
Tensor origin_cpu_tensor(host_allocator_, DT_FLOAT, TensorShape({2, 2}));
test::FillValues<float>(&origin_cpu_tensor, {1.2, 2.3, 3.4, 4.5});
Tensor device_tensor(device_allocator_, DT_FLOAT, TensorShape({2, 2}));
Tensor dest_cpu_tensor(host_allocator_, DT_FLOAT, TensorShape({2, 2}));
stream_executor::Platform* platform =
stream_executor::MultiPlatformManager::PlatformWithName("CUDA").value();
stream_executor::StreamExecutor* executor =
platform->ExecutorForDevice(0).value();
stream_executor::Stream stream(executor);
stream.Init();
ASSERT_TRUE(stream.ok());
se::DeviceMemoryBase gpu_dst{device_tensor.data(), 4 * sizeof(float)};
xla::Shape shape;
TF_ASSERT_OK(TensorShapeToXLAShape(DT_FLOAT, TensorShape({2, 2}), &shape));
// Copy the cpu_tensor to the GPU first before trying to copy it back.
stream.ThenMemcpy(&gpu_dst, origin_cpu_tensor.data(), gpu_dst.size());
TF_ASSERT_OK(stream.BlockHostUntilDone());
tsl::AsyncValueRef<se::Event> done_event =
tsl::MakeConstructedAsyncValueRef<se::Event>(stream.parent());
done_event->Init();
XlaHostRecvDeviceContext* device_context =
new XlaHostRecvDeviceContext(&stream, gpu_dst, shape, done_event);
TF_ASSERT_OK(device_context->CopyDeviceTensorToCPUSync(
&device_tensor, "", device_.get(), &dest_cpu_tensor));
tensorflow::test::ExpectClose(origin_cpu_tensor, dest_cpu_tensor);
device_context->Unref();
}
TEST_F(XlaHostSendRecvDeviceContextTest, CopyCPUTensorToDevice) {
SetDevice("GPU");
Tensor origin_cpu_tensor(host_allocator_, DT_FLOAT, TensorShape({2, 2}));
test::FillValues<float>(&origin_cpu_tensor, {1.2, 2.3, 3.4, 4.5});
Tensor device_tensor(device_allocator_, DT_FLOAT, TensorShape({2, 2}));
Tensor dest_cpu_tensor(host_allocator_, DT_FLOAT, TensorShape({2, 2}));
stream_executor::Platform* platform =
stream_executor::MultiPlatformManager::PlatformWithName("CUDA").value();
stream_executor::StreamExecutor* executor =
platform->ExecutorForDevice(0).value();
stream_executor::Stream stream(executor);
stream.Init();
ASSERT_TRUE(stream.ok());
se::DeviceMemoryBase gpu_dst{device_tensor.data(), 4 * sizeof(float)};
xla::Shape shape;
TF_ASSERT_OK(TensorShapeToXLAShape(DT_FLOAT, TensorShape({2, 2}), &shape));
tsl::AsyncValueRef<se::Event> done_event =
tsl::MakeConstructedAsyncValueRef<se::Event>(stream.parent());
done_event->Init();
XlaHostSendDeviceContext* device_context =
new XlaHostSendDeviceContext(&stream, &gpu_dst, shape, done_event);
TF_ASSERT_OK(device_context->CopyCPUTensorToDeviceSync(
&origin_cpu_tensor, device_.get(), &device_tensor));
// Copy the GPU tensor back to CPU to check that copy worked.
stream.ThenMemcpy(dest_cpu_tensor.data(), gpu_dst, gpu_dst.size());
TF_ASSERT_OK(stream.BlockHostUntilDone());
tensorflow::test::ExpectClose(origin_cpu_tensor, dest_cpu_tensor);
device_context->Unref();
}
TEST_F(XlaHostSendRecvDeviceContextTest, RoundTrip) {
SetDevice("GPU");
Tensor origin_cpu_tensor(host_allocator_, DT_FLOAT, TensorShape({2, 2}));
test::FillValues<float>(&origin_cpu_tensor, {1.2, 2.3, 3.4, 4.5});
Tensor device_tensor(device_allocator_, DT_FLOAT, TensorShape({2, 2}));
Tensor dest_cpu_tensor(host_allocator_, DT_FLOAT, TensorShape({2, 2}));
stream_executor::Platform* platform =
stream_executor::MultiPlatformManager::PlatformWithName("CUDA").value();
stream_executor::StreamExecutor* executor =
platform->ExecutorForDevice(0).value();
stream_executor::Stream stream(executor);
stream.Init();
ASSERT_TRUE(stream.ok());
se::DeviceMemoryBase gpu_dst{device_tensor.data(), 4 * sizeof(float)};
xla::Shape shape;
TF_ASSERT_OK(TensorShapeToXLAShape(DT_FLOAT, TensorShape({2, 2}), &shape));
tsl::AsyncValueRef<se::Event> send_done_event =
tsl::MakeConstructedAsyncValueRef<se::Event>(stream.parent());
send_done_event->Init();
XlaHostSendDeviceContext* send_device_context =
new XlaHostSendDeviceContext(&stream, &gpu_dst, shape, send_done_event);
TF_ASSERT_OK(send_device_context->CopyCPUTensorToDeviceSync(
&origin_cpu_tensor, device_.get(), &device_tensor));
tsl::AsyncValueRef<se::Event> recv_done_event =
tsl::MakeConstructedAsyncValueRef<se::Event>(stream.parent());
recv_done_event->Init();
XlaHostRecvDeviceContext* recv_device_context =
new XlaHostRecvDeviceContext(&stream, gpu_dst, shape, recv_done_event);
TF_ASSERT_OK(recv_device_context->CopyDeviceTensorToCPUSync(
&device_tensor, "", device_.get(), &dest_cpu_tensor));
tensorflow::test::ExpectClose(origin_cpu_tensor, dest_cpu_tensor);
send_device_context->Unref();
recv_device_context->Unref();
}
} // namespace
} // namespace tensorflow

View File

@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/jit/xla_kernel_creator.h"
#include "absl/memory/memory.h"
#include "absl/status/status.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/function_testlib.h"
@ -136,7 +137,7 @@ TEST_F(XlaKernelCreatorTest, FailsIfXlaCompileAttrNotSet) {
input: 'b'
)proto"),
&kernel_);
EXPECT_TRUE(errors::IsInternal(status)) << status.ToString();
EXPECT_TRUE(absl::IsInternal(status)) << status;
}
TEST_F(XlaKernelCreatorTest, FailsIfXlaCompileAttrIsSetToFalse) {
@ -153,7 +154,7 @@ TEST_F(XlaKernelCreatorTest, FailsIfXlaCompileAttrIsSetToFalse) {
input: 'b'
)proto"),
&kernel_);
EXPECT_TRUE(errors::IsInternal(status)) << status.ToString();
EXPECT_TRUE(absl::IsInternal(status)) << status;
}
} // namespace tensorflow

View File

@ -18,17 +18,21 @@ limitations under the License.
#include <memory>
#include <optional>
#include <set>
#include <string>
#include <utility>
#include <vector>
#include "tensorflow/compiler/jit/device_executable_persistor.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/jit/pjrt_device_compiler_client.h"
#include "tensorflow/compiler/jit/xla_compile_util.h"
#include "tensorflow/compiler/jit/xla_device_compiler_client.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/tfrt/common/create_pjrt_client_util.h"
#include "tensorflow/core/tfrt/common/global_state.h"
#include "tensorflow/core/tfrt/common/pjrt_util.h"
#include "tensorflow/core/tpu/tpu_defs.h"
@ -45,19 +49,23 @@ using PjRtDeviceExecutablePersistor =
XlaDeviceCompiler* CreateXlaDeviceCompiler(
const XlaDeviceExecutablePersistor::Config& persistor_config,
DeviceType device_type, xla::LocalClient* local_client) {
DeviceType compilation_device_type, xla::LocalClient* local_client) {
return new XlaDeviceCompiler(
std::make_unique<XlaDeviceExecutablePersistor>(
std::move(persistor_config), device_type),
std::move(persistor_config), compilation_device_type),
std::make_unique<XlaDeviceCompilerClient>(local_client));
}
PjRtDeviceCompiler* CreatePjRtDeviceCompiler(
const PjRtDeviceExecutablePersistor::Config& persistor_config,
DeviceType device_type, xla::PjRtClient* pjrt_client) {
PjRtDeviceCompiler* CreatePjRtDeviceCompiler(DeviceType compilation_device_type,
xla::PjRtClient* pjrt_client) {
PjRtDeviceExecutablePersistor::Config persistor_config(
GetMarkForCompilationPassFlags()->tf_xla_persistent_cache_directory,
GetMarkForCompilationPassFlags()->tf_xla_disable_strict_signature_checks,
GetMarkForCompilationPassFlags()->tf_xla_persistent_cache_prefix);
return new PjRtDeviceCompiler(
std::make_unique<PjRtDeviceExecutablePersistor>(
std::move(persistor_config), device_type),
std::move(persistor_config), compilation_device_type),
std::make_unique<PjRtDeviceCompilerClient>(pjrt_client));
}
@ -73,6 +81,60 @@ StatusOr<std::optional<std::set<int>>> GetAllowedGpus(
return gpu_ids;
}
Status GetCompilationDeviceTypeAndPjRtClient(
const XlaPlatformInfo& platform_info, FunctionLibraryRuntime* flr,
DeviceType* compilation_device_type, xla::PjRtClient** pjrt_client) {
DeviceType device_type = platform_info.device_type();
if (platform_info.xla_device_metadata()) {
VLOG(2) << "Building PjRtDeviceCompiler using "
"platform_info.xla_device_metadata().";
*compilation_device_type =
platform_info.xla_device_metadata()->jit_device_type();
TF_ASSIGN_OR_RETURN(*pjrt_client, GetOrCreatePjRtClient(device_type));
return OkStatus();
}
if (platform_info.pjrt_device_metadata()) {
VLOG(2) << "Building PjRtDeviceCompiler using "
"platform_info.pjrt_device_metadata().";
*compilation_device_type =
platform_info.pjrt_device_metadata()->jit_device_type();
TF_ASSIGN_OR_RETURN(*pjrt_client, GetOrCreatePjRtClient(device_type));
return OkStatus();
}
// TFRT-TPU is used if device_type is `DEVICE_TPU` and platform_info does not
// have `xla_device_metadata`.
if (device_type == DEVICE_TPU) {
*compilation_device_type = DeviceType(DEVICE_TPU_XLA_JIT);
TF_ASSIGN_OR_RETURN(*pjrt_client, GetOrCreatePjRtClient(device_type));
return OkStatus();
}
VLOG(2) << "platform_info.xla_device_metadata not found and "
"platform_info.device_type() != DEVICE_TPU. Building "
"PjRtDeviceCompiler for non-XLA device.";
const XlaOpRegistry::DeviceRegistration* registration;
if (!XlaOpRegistry::GetCompilationDevice(device_type.type(), &registration)) {
return errors::InvalidArgument("No JIT device registered for ",
device_type.type());
}
*compilation_device_type = DeviceType(registration->compilation_device_name);
TF_ASSIGN_OR_RETURN(auto allowed_gpus, GetAllowedGpus(flr));
// TODO(b/255826209): Set platform, intra op parallelism threads if required
// and when supported by GetOrCreatePjRtClient().
// The `allowed_gpus` argument is used only if the `device_type` is GPU.
TF_ASSIGN_OR_RETURN(*pjrt_client,
GetOrCreatePjRtClient(device_type, allowed_gpus));
return OkStatus();
}
} // namespace
xla::StatusOr<std::optional<std::set<int>>> ParseVisibleDeviceList(
@ -175,71 +237,45 @@ Status BuildXlaDeviceCompiler(DeviceBase* device, FunctionLibraryRuntime* flr,
return OkStatus();
}
Status BuildPjRtDeviceCompiler(const XlaPlatformInfo& platform_info,
FunctionLibraryRuntime* flr,
PjRtDeviceCompiler** pjrt_device_compiler) {
PjRtDeviceExecutablePersistor::Config persistor_config(
GetMarkForCompilationPassFlags()->tf_xla_persistent_cache_directory,
GetMarkForCompilationPassFlags()->tf_xla_disable_strict_signature_checks,
GetMarkForCompilationPassFlags()->tf_xla_persistent_cache_prefix);
Status GetOrCreatePjRtDeviceCompilerAndProfiler(
const XlaPlatformInfo& platform_info, FunctionLibraryRuntime* flr,
PjRtDeviceCompiler** pjrt_device_compiler,
DeviceCompilationProfiler** profiler) {
// We store information about the JIT-compiled XLA computation
// in the ResourceMgr.
ResourceMgr* rm = tfrt_global::GetTFGlobalResourceMgr();
DeviceType device_type = platform_info.device_type();
const auto& device_type = platform_info.device_type();
const std::string& compiler_name =
GetPjRtDeviceCompilerResourceName(device_type);
if (platform_info.xla_device_metadata()) {
VLOG(2) << "Building PjRtDeviceCompiler using "
"platform_info.xla_device_metadata().";
// Lookup the DeviceCompiler, create one if not found.
Status s = rm->Lookup<PjRtDeviceCompiler>(
rm->default_container(), compiler_name, pjrt_device_compiler);
if (!s.ok()) {
DeviceType compilation_device_type("");
xla::PjRtClient* pjrt_client = nullptr;
TF_RETURN_IF_ERROR(GetCompilationDeviceTypeAndPjRtClient(
platform_info, flr, &compilation_device_type, &pjrt_client));
DeviceType compilation_device_type =
platform_info.xla_device_metadata()->jit_device_type();
TF_ASSIGN_OR_RETURN(auto pjrt_client, GetOrCreatePjRtClient(device_type));
*pjrt_device_compiler = CreatePjRtDeviceCompiler(
persistor_config, compilation_device_type, pjrt_client);
return OkStatus();
}
if (platform_info.pjrt_device_metadata()) {
VLOG(2) << "Building PjRtDeviceCompiler using "
"platform_info.pjrt_device_metadata().";
DeviceType compilation_device_type =
platform_info.pjrt_device_metadata()->jit_device_type();
TF_ASSIGN_OR_RETURN(auto pjrt_client, GetOrCreatePjRtClient(device_type));
*pjrt_device_compiler = CreatePjRtDeviceCompiler(
persistor_config, compilation_device_type, pjrt_client);
return OkStatus();
TF_RETURN_IF_ERROR(rm->LookupOrCreate<PjRtDeviceCompiler>(
rm->default_container(), compiler_name, pjrt_device_compiler,
[&](PjRtDeviceCompiler** pjrt_device_compiler) {
*pjrt_device_compiler =
CreatePjRtDeviceCompiler(compilation_device_type, pjrt_client);
return OkStatus();
}));
}
// TFRT-TPU is used if device_type is `DEVICE_TPU` and platform_info does not
// have `xla_device_metadata`.
if (device_type == DEVICE_TPU) {
TF_ASSIGN_OR_RETURN(auto pjrt_client, GetOrCreatePjRtClient(device_type));
*pjrt_device_compiler = CreatePjRtDeviceCompiler(
persistor_config, DeviceType(DEVICE_TPU_XLA_JIT), pjrt_client);
return OkStatus();
}
const std::string& profiler_name =
GetPjRtDeviceCompilationProfilerResourceName(device_type);
TF_RETURN_IF_ERROR(rm->LookupOrCreate<DeviceCompilationProfiler>(
rm->default_container(), profiler_name, profiler,
[](DeviceCompilationProfiler** profiler) {
*profiler = new DeviceCompilationProfiler();
return OkStatus();
}));
VLOG(2) << "platform_info.xla_device_metadata not found and "
"platform_info.device_type() != DEVICE_TPU. Building "
"PjRtDeviceCompiler for non-XLA device.";
const XlaOpRegistry::DeviceRegistration* registration;
if (!XlaOpRegistry::GetCompilationDevice(device_type.type(), &registration)) {
return errors::InvalidArgument("No JIT device registered for ",
device_type.type());
}
auto compilation_device_type =
DeviceType(registration->compilation_device_name);
TF_ASSIGN_OR_RETURN(auto allowed_gpus, GetAllowedGpus(flr));
// TODO(b/255826209): Set platform, intra op parallelism threads if required
// and when supported by GetOrCreatePjRtClient().
// The `allowed_gpus` argument is used only if the `device_type` is GPU.
TF_ASSIGN_OR_RETURN(auto pjrt_client,
GetOrCreatePjRtClient(device_type, allowed_gpus));
*pjrt_device_compiler = CreatePjRtDeviceCompiler(
persistor_config, compilation_device_type, pjrt_client);
return OkStatus();
}

View File

@ -18,6 +18,7 @@ limitations under the License.
#include <memory>
#include <optional>
#include <string>
#include "tensorflow/compiler/jit/device_compiler.h"
#include "tensorflow/compiler/jit/pjrt_base_device.h"
@ -113,17 +114,21 @@ Status BuildXlaDeviceCompiler(
DeviceCompiler<xla::LocalExecutable, xla::LocalClient>**
xla_device_compiler);
// Builds a DeviceCompiler that uses xla::PjRtClient using an appropriate
// Fetches a DeviceCompiler from the tfrt_global resource manager (or creates
// one there if not found) that uses an appropriate xla::PjRtClient
// for `platform_info.device_type()` and sets *pjrt_device_compiler
// to point to it. Uses flags from `MarkForCompilationPassFlags` for configuring
// the persistor used in the DeviceCompiler. Please note that non-XLA devices
// aren't supported yet. This is because:
// to point to it. Also fetches/creates a DeviceCompilationProfiler from/in the
// tfrt_global resource manager for `platform_info.device_type()` and sets
// *profiler to point to it. Uses flags from `MarkForCompilationPassFlags` for
// configuring the persistor used in the DeviceCompiler. Please note that
// non-XLA devices aren't supported yet. This is because:
// 1. PjRtClient doesn't support data transfer for non-XLA devices yet
// 2. Fetching the PjRtClient for non-XLA devices is also not supported yet
Status BuildPjRtDeviceCompiler(
Status GetOrCreatePjRtDeviceCompilerAndProfiler(
const XlaPlatformInfo& platform_info, FunctionLibraryRuntime* flr,
DeviceCompiler<xla::PjRtLoadedExecutable, xla::PjRtClient>**
pjrt_device_compiler);
pjrt_device_compiler,
DeviceCompilationProfiler** profiler);
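// A minimal caller sketch (illustrative only; it mirrors the usage in the
// accompanying unit tests and assumes a `Device* device` and a
// `FunctionLibraryRuntime* flr` are available):
//
//   XlaPlatformInfo platform_info = XlaPlatformInfoFromDevice(device);
//   DeviceCompiler<xla::PjRtLoadedExecutable, xla::PjRtClient>*
//       pjrt_device_compiler = nullptr;
//   DeviceCompilationProfiler* profiler = nullptr;
//   TF_RETURN_IF_ERROR(GetOrCreatePjRtDeviceCompilerAndProfiler(
//       platform_info, flr, &pjrt_device_compiler, &profiler));
//   // Both objects are ref-counted resources; release the references when done.
//   core::ScopedUnref compiler_ref(pjrt_device_compiler);
//   core::ScopedUnref profiler_ref(profiler);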
// Returns information about the platform from kernel context.
XlaPlatformInfo XlaPlatformInfoFromDevice(DeviceBase* device);

View File

@ -81,7 +81,7 @@ TEST_F(XlaPlatformInfoTest, BuildXlaDeviceCompilerNonXlaDevice) {
EXPECT_TRUE(xla_device_compiler->client() != nullptr);
}
TEST_F(XlaPlatformInfoTest, BuildPjRtDeviceCompilerTestXlaDevice) {
TEST_F(XlaPlatformInfoTest, GetOrCreatePjRtDeviceCompilerAndProfilerXlaDevice) {
DeviceType device_type = DeviceType(DEVICE_XLA_GPU);
device_setup_.AddDevicesAndSetUp({device_type.type()});
@ -91,23 +91,27 @@ TEST_F(XlaPlatformInfoTest, BuildPjRtDeviceCompilerTestXlaDevice) {
XlaPlatformInfo platform_info = XlaPlatformInfoFromDevice(device);
PjRtDeviceCompiler* pjrt_device_compiler = nullptr;
TF_EXPECT_OK(BuildPjRtDeviceCompiler(platform_info, device_setup_.flr(),
&pjrt_device_compiler));
DeviceCompilationProfiler* profiler = nullptr;
TF_EXPECT_OK(GetOrCreatePjRtDeviceCompilerAndProfiler(
platform_info, device_setup_.flr(), &pjrt_device_compiler, &profiler));
core::ScopedUnref pjrt_device_compiler_ref(pjrt_device_compiler);
core::ScopedUnref profiler_ref(profiler);
TF_ASSERT_OK_AND_ASSIGN(auto pjrt_client, GetOrCreatePjRtClient(device_type));
EXPECT_EQ(pjrt_device_compiler->device_type(), metadata->jit_device_type());
EXPECT_EQ(pjrt_device_compiler->client(), pjrt_client);
}
TEST_F(XlaPlatformInfoTest, BuildPjRtDeviceCompilerTestGpuDevice) {
TEST_F(XlaPlatformInfoTest, GetOrCreatePjRtDeviceCompilerAndProfilerGpuDevice) {
device_setup_.AddDevicesAndSetUp({DEVICE_GPU});
Device* device = device_setup_.GetDevice(DEVICE_GPU);
XlaPlatformInfo platform_info = XlaPlatformInfoFromDevice(device);
PjRtDeviceCompiler* pjrt_device_compiler = nullptr;
TF_EXPECT_OK(BuildPjRtDeviceCompiler(platform_info, device_setup_.flr(),
&pjrt_device_compiler));
DeviceCompilationProfiler* profiler = nullptr;
TF_EXPECT_OK(GetOrCreatePjRtDeviceCompilerAndProfiler(
platform_info, device_setup_.flr(), &pjrt_device_compiler, &profiler));
core::ScopedUnref pjrt_device_compiler_ref(pjrt_device_compiler);
core::ScopedUnref profiler_ref(profiler);
}
#endif
@ -138,7 +142,7 @@ TEST_F(XlaPlatformInfoTest, BuildXlaDeviceCompilerTpuDevice) {
// TODO(b/255826209): Look into using an actual TPU device for the unit test,
// and move this out of OSS.
TEST_F(XlaPlatformInfoTest, BuildPjRtDeviceCompilerTpuDevice) {
TEST_F(XlaPlatformInfoTest, GetOrCreatePjRtDeviceCompilerAndProfilerTpuDevice) {
DeviceType device_type = DeviceType(DEVICE_TPU);
DeviceType compilation_device_type = DeviceType(DEVICE_TPU_XLA_JIT);
// Use a CPU PjRtClient instead of a TPU one just for testing whether
@ -158,9 +162,11 @@ TEST_F(XlaPlatformInfoTest, BuildPjRtDeviceCompilerTpuDevice) {
/*device_allocator=*/nullptr);
PjRtDeviceCompiler* pjrt_device_compiler = nullptr;
TF_EXPECT_OK(
BuildPjRtDeviceCompiler(platform_info, nullptr, &pjrt_device_compiler));
DeviceCompilationProfiler* profiler = nullptr;
TF_EXPECT_OK(GetOrCreatePjRtDeviceCompilerAndProfiler(
platform_info, nullptr, &pjrt_device_compiler, &profiler));
core::ScopedUnref pjrt_device_compiler_ref(pjrt_device_compiler);
core::ScopedUnref profiler_ref(profiler);
EXPECT_EQ(pjrt_device_compiler->device_type(), compilation_device_type);
EXPECT_EQ(pjrt_device_compiler->client(), pjrt_client);

View File

@ -272,6 +272,10 @@ compiling to XLA.
### `-tf-embedding-pipelining`: Rewrite graph for embedding pipelining
For architectures that support accelerated embedding lookups, this pass will
rewrite the graph to use pipelining for better device utilization.
### `-tf-embedding-sequencing`: Rewrite graph for sequential execution of embeddings
This is a strictly sequential and formally correct fallback option for the
embedding pipelining pass, intended for debugging during pipelining
development.
### `-tf-executor-break-up-islands`: Transform from TF control dialect to TF executor dialect.
### `-tf-executor-check-control-dependencies`: Checks control dependencies
This pass analyzes control dependencies between islands and warns about
@ -1993,4 +1997,21 @@ This pass will transform it into
### `-tf-verify-for-export`: Verify module is suitable for export back to TF Graph
Verifies that all functions in the module consist of a single tf_executor.graph and
that each tf_executor.island in the tf_executor.graph contains only a single op.
### `-tf-xla-call-module-deserialization`: Deserializes StableHLO functions embedded in `tf.XlaCallModule` to top level module
This pass deserializes the StableHLO bytecode embedded in tf.XlaCallModule,
then outlines the functions in the deserialized StableHLO module into the
top-level MLIR module, renaming functions as needed to avoid naming conflicts.
After the outlining, it updates tf.XlaCallModule's `module` attribute to be
empty and adds an `_entry_function` attribute referring to the entry function.
It also adds a `_from_xla_call_module: true` attribute to each lifted
StableHLO function.
### `-tf-xla-call-module-serialization`: Serializes StableHLO functions from top-level module into `tf.XlaCallModule`'s `module` attribute
This pass collects the StableHLO functions referenced from `tf.XlaCallModule`'s
`_entry_function` attribute into a module, serializes that module into MLIR
bytecode, and embeds the bytecode in `tf.XlaCallModule`'s `module` attribute.
After serialization, this pass removes the `_entry_function` attribute from
`tf.XlaCallModule` and removes all the serialized StableHLO functions
from the top-level module.
### `-tfe-legalize-tfg`: Legalize from TFG to the TFE dialect

View File

@ -64,6 +64,7 @@ def _run_lit_test(name, data, size, tags, driver, features, exec_properties):
)
def glob_lit_tests(
name = None,
exclude = [],
test_file_exts = _default_test_file_exts,
default_size = _default_size,
@ -78,6 +79,7 @@ def glob_lit_tests(
"""Creates all plausible Lit tests (and their inputs) under this directory.
Args:
name: str, name of the test_suite rule to generate for running all tests.
exclude: [str], paths to exclude (for tests and inputs).
test_file_exts: [str], extensions for files that are tests.
default_size: str, the test size for targets not in "size_override".
@ -103,7 +105,10 @@ def glob_lit_tests(
# Run tests individually such that errors can be attributed to a specific
# failure.
all_tests = []
for curr_test in tests:
all_tests.append(curr_test + ".test")
# Instantiate this test with updated parameters.
_run_lit_test(
name = curr_test + ".test",
@ -114,3 +119,11 @@ def glob_lit_tests(
features = features,
exec_properties = exec_properties,
)
# TODO: remove this check after making it a required param.
if name:
native.test_suite(
name = name,
tests = all_tests,
tags = ["manual"],
)

View File

@ -342,8 +342,6 @@ cc_library(
"@llvm-project//mlir:InferTypeOpInterface",
"@llvm-project//mlir:LoopLikeInterface",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:QuantOps",
"@llvm-project//mlir:SideEffectInterfaces",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TransformUtils",
],
@ -1235,6 +1233,8 @@ cc_library(
"//tensorflow/compiler/mlir/lite/quantization:quantization_config",
"//tensorflow/compiler/mlir/lite/quantization:quantization_passes",
"//tensorflow/compiler/mlir/lite/quantization/tensorflow:tf_quantization_passes",
"//tensorflow/compiler/mlir/lite/stablehlo:legalize_tf_xla_call_module_to_stablehlo_pass",
"//tensorflow/compiler/mlir/lite/stablehlo:rename_entrypoint_to_main",
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_ops",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_passes",

View File

@ -83,7 +83,9 @@ struct PassConfig {
// Whether to run the `GuaranteeAllFuncsOneUsePass` to ensure each function
// has a single use.
bool guarantee_all_funcs_one_use;
// Whether to enable the hlo to tf conversion.
// Whether to enable the hlo/stablehlo to tf conversion. This also supports
// the case where a saved model contains both a TF module and a serialized
// StableHLO module.
bool enable_hlo_to_tf_conversion;
// Whether to enable to use DynamicUpdateSlice op.
bool enable_dynamic_update_slice;

View File

@ -15,6 +15,9 @@ limitations under the License.
#include "tensorflow/compiler/mlir/lite/emit_error_reporter.h"
#include <cstdio>
#include <vector>
namespace tflite {
int EmitErrorReporter::Report(const char* format, va_list args) {

View File

@ -1,5 +1,7 @@
load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_cloud")
# copybara:uncomment package(default_applicable_licenses = ["//tensorflow:license"])
cc_library(
name = "outline_operations",
srcs = ["outline_operations.cc"],

View File

@ -1,5 +1,7 @@
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
# copybara:uncomment package(default_applicable_licenses = ["//tensorflow:license"])
cc_library(
name = "rematerializer",
srcs = ["rematerializer.cc"],

View File

@ -232,6 +232,7 @@ cc_library(
"transforms/get_alternative_subgraph.cc",
"transforms/pick_subgraphs.cc",
"transforms/raise_target_subgraphs.cc",
"transforms/tac_filter.cc",
"transforms/target_annotation.cc",
],
hdrs = [
@ -243,6 +244,7 @@ cc_library(
":common",
":cost_model",
":device_transform",
":tac_filter_cc_proto",
":tac_importer_exporter",
"//tensorflow/compiler/mlir/lite:tensorflow_lite",
"//tensorflow/compiler/mlir/lite:tf_tfl_passes",
@ -253,6 +255,7 @@ cc_library(
"//tensorflow/compiler/mlir/tensorflow:tensorflow_analysis",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
"@com_google_protobuf//:protobuf_headers",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:FuncDialect",
@ -380,3 +383,15 @@ py_library(
"//tensorflow/compiler/mlir/lite/experimental/tac/py_wrapper:_pywrap_tac_wrapper",
],
)
proto_library(
name = "tac_filter_proto",
srcs = ["tac_filter.proto"],
compatible_with = get_compatible_with_cloud(),
)
cc_proto_library(
name = "tac_filter_cc_proto",
compatible_with = get_compatible_with_cloud(),
deps = [":tac_filter_proto"],
)

Some files were not shown because too many files have changed in this diff.