diff --git a/.bazelrc b/.bazelrc index 391fc927c27..f137258fa26 100644 --- a/.bazelrc +++ b/.bazelrc @@ -194,6 +194,9 @@ build:macos --apple_platform_type=macos # gRPC on MacOS requires this #define build:macos --copt=-DGRPC_BAZEL_BUILD +# Avoid hitting command line argument limit +build:macos --features=archive_param_file + # Settings for MacOS on ARM CPUs. build:macos_arm64 --cpu=darwin_arm64 build:macos_arm64 --macos_minimum_os=11.0 @@ -345,6 +348,7 @@ build:windows --host_copt=/D_USE_MATH_DEFINES # Windows has a relatively short command line limit, which TF has begun to hit. # See https://docs.bazel.build/versions/main/windows.html build:windows --features=compiler_param_file +build:windows --features=archive_param_file # Speed Windows compile times. Available in VS 16.4 (we are on 16.11). See # https://groups.google.com/a/tensorflow.org/d/topic/build/SsW98Eo7l3o/discussion @@ -446,7 +450,6 @@ build:rbe --bes_backend=buildeventservice.googleapis.com build:rbe --bes_results_url="https://source.cloud.google.com/results/invocations" build:rbe --bes_timeout=600s build:rbe --define=EXECUTOR=remote -build:rbe --flaky_test_attempts=3 build:rbe --jobs=800 build:rbe --remote_executor=grpcs://remotebuildexecution.googleapis.com build:rbe --remote_timeout=3600 @@ -627,7 +630,6 @@ try-import %workspace%/.bazelrc.user # Here are bazelrc configs for release builds build:release_base --config=v2 -test:release_base --flaky_test_attempts=3 test:release_base --test_size_filters=small,medium build:release_cpu_linux --config=release_base @@ -691,10 +693,10 @@ build:ubsan --linkopt -fsanitize=undefined build:ubsan --linkopt -lubsan # Disable TFRT integration for now unless --config=tfrt is specified. 
-build --deleted_packages=tensorflow/compiler/mlir/tfrt,tensorflow/compiler/mlir/tfrt/benchmarks,tensorflow/compiler/mlir/tfrt/jit/python_binding,tensorflow/compiler/mlir/tfrt/jit/transforms,tensorflow/compiler/mlir/tfrt/python_tests,tensorflow/compiler/mlir/tfrt/tests,tensorflow/compiler/mlir/tfrt/tests/ir,tensorflow/compiler/mlir/tfrt/tests/analysis,tensorflow/compiler/mlir/tfrt/tests/jit,tensorflow/compiler/mlir/tfrt/tests/lhlo_to_tfrt,tensorflow/compiler/mlir/tfrt/tests/lhlo_to_jitrt,tensorflow/compiler/mlir/tfrt/tests/tf_to_corert,tensorflow/compiler/mlir/tfrt/tests/tf_to_tfrt_data,tensorflow/compiler/mlir/tfrt/tests/saved_model,tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu,tensorflow/core/runtime_fallback,tensorflow/core/runtime_fallback/conversion,tensorflow/core/runtime_fallback/kernel,tensorflow/core/runtime_fallback/opdefs,tensorflow/core/runtime_fallback/runtime,tensorflow/core/runtime_fallback/util,tensorflow/core/tfrt/eager,tensorflow/core/tfrt/eager/backends/cpu,tensorflow/core/tfrt/eager/backends/gpu,tensorflow/core/tfrt/eager/core_runtime,tensorflow/core/tfrt/eager/cpp_tests/core_runtime,tensorflow/core/tfrt/gpu,tensorflow/core/tfrt/run_handler_thread_pool,tensorflow/core/tfrt/runtime,tensorflow/core/tfrt/saved_model,tensorflow/core/tfrt/graph_executor,tensorflow/core/tfrt/saved_model/tests,tensorflow/core/tfrt/tpu,tensorflow/core/tfrt/utils,tensorflow/core/tfrt/utils/debug +build 
--deleted_packages=tensorflow/compiler/mlir/tfrt,tensorflow/compiler/mlir/tfrt/benchmarks,tensorflow/compiler/mlir/tfrt/ir,tensorflow/compiler/mlir/tfrt/ir/mlrt,tensorflow/compiler/mlir/tfrt/jit/python_binding,tensorflow/compiler/mlir/tfrt/jit/transforms,tensorflow/compiler/mlir/tfrt/python_tests,tensorflow/compiler/mlir/tfrt/tests,tensorflow/compiler/mlir/tfrt/tests/mlrt,tensorflow/compiler/mlir/tfrt/tests/ir,tensorflow/compiler/mlir/tfrt/tests/analysis,tensorflow/compiler/mlir/tfrt/tests/jit,tensorflow/compiler/mlir/tfrt/tests/lhlo_to_tfrt,tensorflow/compiler/mlir/tfrt/tests/lhlo_to_jitrt,tensorflow/compiler/mlir/tfrt/tests/tf_to_corert,tensorflow/compiler/mlir/tfrt/tests/tf_to_tfrt_data,tensorflow/compiler/mlir/tfrt/tests/saved_model,tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu,tensorflow/compiler/mlir/tfrt/transforms/mlrt,tensorflow/core/runtime_fallback,tensorflow/core/runtime_fallback/conversion,tensorflow/core/runtime_fallback/kernel,tensorflow/core/runtime_fallback/opdefs,tensorflow/core/runtime_fallback/runtime,tensorflow/core/runtime_fallback/util,tensorflow/core/tfrt/mlrt,tensorflow/core/tfrt/mlrt/attribute,tensorflow/core/tfrt/mlrt/kernel,tensorflow/core/tfrt/mlrt/bytecode,tensorflow/core/tfrt/mlrt/interpreter,tensorflow/compiler/mlir/tfrt/translate/mlrt,tensorflow/compiler/mlir/tfrt/translate/mlrt/testdata,tensorflow/core/tfrt/gpu,tensorflow/core/tfrt/run_handler_thread_pool,tensorflow/core/tfrt/runtime,tensorflow/core/tfrt/saved_model,tensorflow/core/tfrt/graph_executor,tensorflow/core/tfrt/saved_model/tests,tensorflow/core/tfrt/tpu,tensorflow/core/tfrt/utils,tensorflow/core/tfrt/utils/debug,tensorflow/core/tfrt/saved_model/python,tensorflow/core/tfrt/graph_executor/python # TODO(b/240450920): We are in the process of migrating JitRt backend to XLA # and while we are doing this we can't keep it buildable/testable in OSS. 
-build:tfrt --deleted_packages=tensorflow/compiler/mlir/tfrt,tensorflow/compiler/mlir/tfrt/benchmarks,tensorflow/compiler/mlir/tfrt/jit/python_binding,tensorflow/compiler/mlir/tfrt/jit/transforms,tensorflow/compiler/mlir/tfrt/python_tests,tensorflow/compiler/mlir/tfrt/tests,tensorflow/compiler/mlir/tfrt/tests/ir,tensorflow/compiler/mlir/tfrt/tests/analysis,tensorflow/compiler/mlir/tfrt/tests/jit,tensorflow/compiler/mlir/tfrt/tests/lhlo_to_tfrt,tensorflow/compiler/mlir/tfrt/tests/lhlo_to_jitrt,tensorflow/compiler/mlir/tfrt/tests/tf_to_corert,tensorflow/compiler/mlir/tfrt/tests/tf_to_tfrt_data,tensorflow/compiler/mlir/tfrt/tests/saved_model,tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu,tensorflow/core/runtime_fallback,tensorflow/core/runtime_fallback/conversion,tensorflow/core/runtime_fallback/kernel,tensorflow/core/runtime_fallback/opdefs,tensorflow/core/runtime_fallback/runtime,tensorflow/core/runtime_fallback/util,tensorflow/core/tfrt/eager,tensorflow/core/tfrt/eager/backends/cpu,tensorflow/core/tfrt/eager/backends/gpu,tensorflow/core/tfrt/eager/core_runtime,tensorflow/core/tfrt/eager/cpp_tests/core_runtime,tensorflow/core/tfrt/gpu,tensorflow/core/tfrt/run_handler_thread_pool,tensorflow/core/tfrt/runtime,tensorflow/core/tfrt/saved_model,tensorflow/core/tfrt/graph_executor,tensorflow/core/tfrt/saved_model/tests,tensorflow/core/tfrt/tpu,tensorflow/core/tfrt/utils,tensorflow/core/tfrt/utils/debug +build:tfrt 
--deleted_packages=tensorflow/compiler/mlir/tfrt,tensorflow/compiler/mlir/tfrt/benchmarks,tensorflow/compiler/mlir/tfrt/ir,tensorflow/compiler/mlir/tfrt/ir/mlrt,tensorflow/compiler/mlir/tfrt/jit/python_binding,tensorflow/compiler/mlir/tfrt/jit/transforms,tensorflow/compiler/mlir/tfrt/python_tests,tensorflow/compiler/mlir/tfrt/tests,tensorflow/compiler/mlir/tfrt/tests/mlrt,tensorflow/compiler/mlir/tfrt/tests/ir,tensorflow/compiler/mlir/tfrt/tests/analysis,tensorflow/compiler/mlir/tfrt/tests/jit,tensorflow/compiler/mlir/tfrt/tests/lhlo_to_tfrt,tensorflow/compiler/mlir/tfrt/tests/lhlo_to_jitrt,tensorflow/compiler/mlir/tfrt/tests/tf_to_corert,tensorflow/compiler/mlir/tfrt/tests/tf_to_tfrt_data,tensorflow/compiler/mlir/tfrt/tests/saved_model,tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu,tensorflow/compiler/mlir/tfrt/transforms/mlrt,tensorflow/core/runtime_fallback,tensorflow/core/runtime_fallback/conversion,tensorflow/core/runtime_fallback/kernel,tensorflow/core/runtime_fallback/opdefs,tensorflow/core/runtime_fallback/runtime,tensorflow/core/runtime_fallback/util,tensorflow/core/tfrt/mlrt,tensorflow/core/tfrt/mlrt/attribute,tensorflow/core/tfrt/mlrt/kernel,tensorflow/core/tfrt/mlrt/bytecode,tensorflow/core/tfrt/mlrt/interpreter,tensorflow/compiler/mlir/tfrt/translate/mlrt,tensorflow/compiler/mlir/tfrt/translate/mlrt/testdata,tensorflow/core/tfrt/gpu,tensorflow/core/tfrt/run_handler_thread_pool,tensorflow/core/tfrt/runtime,tensorflow/core/tfrt/saved_model,tensorflow/core/tfrt/graph_executor,tensorflow/core/tfrt/saved_model/tests,tensorflow/core/tfrt/tpu,tensorflow/core/tfrt/utils,tensorflow/core/tfrt/utils/debug,tensorflow/core/tfrt/saved_model/python,tensorflow/core/tfrt/graph_executor/python # TF Fuzztest config try-import fuzztest.bazelrc diff --git a/.bazelversion b/.bazelversion index f53152b50eb..b536fbc5061 100644 --- a/.bazelversion +++ b/.bazelversion @@ -1,2 +1,2 @@ -5.3.0 +6.1.0 # NOTE: Update Bazel version in 
tensorflow/tools/ci_build/release/common.sh.oss \ No newline at end of file diff --git a/.github/bot_config.yml b/.github/bot_config.yml index b5cf2a5a6c2..b90b4f52c56 100644 --- a/.github/bot_config.yml +++ b/.github/bot_config.yml @@ -15,7 +15,7 @@ # A list of assignees assignees: - - synandi + - sushreebarsa - SuryanarayanaY - tilakrayal # A list of assignees for compiler folder diff --git a/.github/workflows/arm-cd.yml b/.github/workflows/arm-cd.yml index b601b0054c7..a191c65a98f 100644 --- a/.github/workflows/arm-cd.yml +++ b/.github/workflows/arm-cd.yml @@ -28,6 +28,7 @@ jobs: runs-on: [self-hosted, linux, ARM64] continue-on-error: ${{ matrix.experimental }} strategy: + fail-fast: false matrix: pyver: ['3.8', '3.9', '3.10'] experimental: [false] diff --git a/.github/workflows/arm-ci-extended.yml b/.github/workflows/arm-ci-extended.yml index 7e32dafabe9..faba79089b8 100644 --- a/.github/workflows/arm-ci-extended.yml +++ b/.github/workflows/arm-ci-extended.yml @@ -27,8 +27,9 @@ jobs: if: github.repository == 'tensorflow/tensorflow' # Don't do this in forks runs-on: [self-hosted, linux, ARM64] strategy: + fail-fast: false matrix: - pyver: ['3.10'] + pyver: ['3.8', '3.9', '3.10', '3.11'] steps: - name: Stop old running containers (if any) shell: bash diff --git a/.github/workflows/update-rbe.yml b/.github/workflows/update-rbe.yml index d32d7affd64..965e3515b84 100644 --- a/.github/workflows/update-rbe.yml +++ b/.github/workflows/update-rbe.yml @@ -92,6 +92,18 @@ jobs: map sigbuild-r2.13-clang-python3.9 2.13-python3.9 map sigbuild-r2.13-clang-python3.10 2.13-python3.10 map sigbuild-r2.13-clang-python3.11 2.13-python3.11 + # TF 2.14 + map sigbuild-r2.14 2.14-python3.9 + map sigbuild-r2.14-python3.8 2.14-python3.8 + map sigbuild-r2.14-python3.9 2.14-python3.9 + map sigbuild-r2.14-python3.10 2.14-python3.10 + map sigbuild-r2.14-python3.11 2.14-python3.11 + # TF 2.14 + Clang (containers are the same, but env vars in configs.bzl are different) + map 
sigbuild-r2.14-clang 2.14-python3.9 + map sigbuild-r2.14-clang-python3.8 2.14-python3.8 + map sigbuild-r2.14-clang-python3.9 2.14-python3.9 + map sigbuild-r2.14-clang-python3.10 2.14-python3.10 + map sigbuild-r2.14-clang-python3.11 2.14-python3.11 - name: Create Pull Request with changes uses: peter-evans/create-pull-request@2b011faafdcbc9ceb11414d64d0573f37c774b04 # v4.2.3 with: diff --git a/README.md b/README.md index fa7a6c45733..d0feb038bc0 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ -[![Python](https://img.shields.io/pypi/pyversions/tensorflow.svg?style=plastic)](https://badge.fury.io/py/tensorflow) +[![Python](https://img.shields.io/pypi/pyversions/tensorflow.svg)](https://badge.fury.io/py/tensorflow) [![PyPI](https://badge.fury.io/py/tensorflow.svg)](https://badge.fury.io/py/tensorflow) [![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.4724125.svg)](https://doi.org/10.5281/zenodo.4724125) [![CII Best Practices](https://bestpractices.coreinfrastructure.org/projects/1486/badge)](https://bestpractices.coreinfrastructure.org/projects/1486) @@ -11,6 +11,8 @@ [![Fuzzing Status](https://oss-fuzz-build-logs.storage.googleapis.com/badges/tensorflow-py.svg)](https://bugs.chromium.org/p/oss-fuzz/issues/list?sort=-opened&can=1&q=proj:tensorflow-py) [![OSSRank](https://shields.io/endpoint?url=https://ossrank.com/shield/44)](https://ossrank.com/p/44) [![Contributor Covenant](https://img.shields.io/badge/Contributor%20Covenant-v1.4%20adopted-ff69b4.svg)](CODE_OF_CONDUCT.md) +[![TF Official Continuous](https://tensorflow.github.io/build/TF%20Official%20Continuous.svg)](https://tensorflow.github.io/build#TF%20Official%20Continuous) +[![TF Official Nightly](https://tensorflow.github.io/build/TF%20Official%20Nightly.svg)](https://tensorflow.github.io/build#TF%20Official%20Nightly) **`Documentation`** | ------------------- | diff --git a/RELEASE.md b/RELEASE.md index 87ebf46e557..c404a6183ae 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -16,6 +16,11 @@ 2.13 may be 
used when it is necessary to determine if a value is specifically a symbolic tensor. +* `tf.compat.v1.Session` + * `tf.compat.v1.Session.partial_run` and + `tf.compat.v1.Session.partial_run_setup` will be deprecated in the + next release. + # Known Caveats * @@ -26,6 +31,15 @@ * * +* `tf.keras` + * `Model.compile` now support `steps_per_execution='auto'` as a parameter, + allowing automatic tuning of steps per execution during `Model.fit`, + `Model.predict`, and `Model.evaluate` for a significant performance boost. + +* Enable JIT-compiled i64-indexed kernels on GPU for large tensors with more + than 2**32 elements. + * Unary GPU kernels: Abs, Atanh, Acos, Acosh, Asin, Asinh, Atan, Cos, + Cosh, Sin, Sinh, Tan, Tanh. # Bug Fixes and Other Changes * `tf.lite` @@ -34,6 +48,22 @@ * * +* `tf.config.experimental.enable_tensor_float_32_execution` + * Disabling TensorFloat-32 execution now causes TPUs to use float32 + precision for float32 matmuls and other ops. TPUs have always used + bfloat16 precision for certain ops, like matmul, when such ops had float32 + inputs. Now, disabling TensorFloat-32 by calling + `tf.config.experimental.enable_tensor_float_32_execution(False)` will + cause TPUs to use float32 precision for such ops instead of bfloat16. + +* `tf.experimental.dtensor` + * API changes for Relayout. Added a new API, `dtensor.relayout_like`, for + relayouting a tensor according to the layout of another tensor. + * Added `dtensor.get_default_mesh`, for retrieving the current default + mesh under the dtensor context. + +* TensorFlow Debugger (tfdbg) CLI: ncurses-based CLI for tfdbg v1 was removed. + # Thanks to our Contributors This release contains contributions from many people at Google, as well as: @@ -185,6 +215,9 @@ This release contains contributions from many people at Google, as well as: `dataset = dataset.shuffle(dataset.cardinality())`. 
This will load the full dataset into memory so that it can be shuffled, so make sure to only use this with datasets of filenames or other small datasets. + * Added a new `tf.data.experimental.pad_to_cardinality` transformation + which pads a dataset with zero elements up to a specified cardinality. + This is useful for avoiding partial batches while not dropping any data. * `tf.math` @@ -243,6 +276,8 @@ This release contains contributions from many people at Google, as well as: * `tf.lite`: * Add UINT32 support to tfl.pack + * Add INT64 support to tfl.range + * Add UINT32 support to tfl.concatenation ## Thanks to our Contributors diff --git a/ci/README.md b/ci/README.md new file mode 100644 index 00000000000..1dc705f8e35 --- /dev/null +++ b/ci/README.md @@ -0,0 +1,17 @@ +# TensorFlow continuous integration + +> **Warning** This folder is still under construction. It is part of an ongoing +> effort to improve the structure of CI and build related files within the +> TensorFlow repo. This warning will be removed when the contents of this +> directory are stable and appropriate documentation around its usage is in +> place. + +Maintainer: TensorFlow DevInfra + +******************************************************************************** + +The CI folder contains the configuration files and scripts used to build, test, +and deploy TensorFlow. This folder is typically used by continuous integration +(CI) tools to build and test TensorFlow whenever there is a change to the +code. This folder is broken into subfolders that represent the level of support +and ownership of the files contained within. diff --git a/ci/devinfra/README.md b/ci/devinfra/README.md new file mode 100644 index 00000000000..c31d50b87a6 --- /dev/null +++ b/ci/devinfra/README.md @@ -0,0 +1,17 @@ +# DevInfra CI Directory + +> **Warning** This folder is still under construction. It is part of an ongoing +> effort to improve the structure of CI and build related files within the +> TensorFlow repo. 
This warning will be removed when the contents of this +> directory are stable and appropriate documentation around its usage is in +> place. + +Maintainer: TensorFlow DevInfra + +Issue Reporting: File an issue against this repo and tag +[@devinfra](https://github.com/orgs/tensorflow/teams/devinfra) + +******************************************************************************** + +A directory for build and CI related scripts and jobs managed by the TensorFlow +DevInfra team but not part of the official build, test, or release process. diff --git a/ci/official/README.md b/ci/official/README.md new file mode 100644 index 00000000000..2bd578c0160 --- /dev/null +++ b/ci/official/README.md @@ -0,0 +1,17 @@ +# Official CI Directory + +> **Warning** This folder is still under construction. It is part of an ongoing +> effort to improve the structure of CI and build related files within the +> TensorFlow repo. This warning will be removed when the contents of this +> directory are stable and appropriate documentation around its usage is in +> place. + +Maintainer: TensorFlow and TensorFlow DevInfra + +Issue Reporting: File an issue against this repo and tag +[@devinfra](https://github.com/orgs/tensorflow/teams/devinfra) + +******************************************************************************** + +A directory for build and CI related scripts and jobs that are used and +monitored as part of the official TensorFlow build, test, and release process. 
diff --git a/configure.py b/configure.py index 73e124fb356..47b566a9c0f 100644 --- a/configure.py +++ b/configure.py @@ -964,7 +964,6 @@ def set_other_cuda_vars(environ_cp): def system_specific_test_config(environ_cp): """Add default build and test flags required for TF tests to bazelrc.""" - write_to_bazelrc('test --flaky_test_attempts=3') write_to_bazelrc('test --test_size_filters=small,medium') # Each instance of --test_tag_filters or --build_tag_filters overrides all diff --git a/tensorflow/BUILD b/tensorflow/BUILD index fce465ff1f2..a014c90df67 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -109,6 +109,7 @@ PACKAGE_STATIC_DEPS = [ "@local_execution_config_platform//:__subpackages__", "@mkl_dnn_acl_compatible//:__subpackages__", "@mkl_dnn_v1//:__subpackages__", + "@ml_dtypes//:__subpackages__", "@nccl_archive//:__subpackages__", "@nvtx_archive//:__subpackages__", "@org_sqlite//:__subpackages__", @@ -1036,7 +1037,13 @@ package_group( ], ) -package_group(name = "ndarray_tensor_allow_list") +package_group( + name = "ndarray_tensor_allow_list", + packages = [ + "//third_party/py/courier/...", + "//third_party/py/tensorfn/...", + ], +) # Packages that use private types symbols, until they are exported. # TODO(b/154650521) Remove. diff --git a/tensorflow/c/BUILD b/tensorflow/c/BUILD index 0e70244453f..f52e342da94 100644 --- a/tensorflow/c/BUILD +++ b/tensorflow/c/BUILD @@ -129,6 +129,7 @@ cc_library( # TODO: Only include tf_tstring_hdrs. Don't expose the implementation of TF_TString to API # users. 
":tf_tstring", + "//tensorflow/core:protos_all_cc", ], ) @@ -171,6 +172,7 @@ tf_cuda_library( ":tf_buffer_internal", ":tf_status_internal", ":tf_tensor_internal", + "//tensorflow/core:protos_all_cc", ], ) @@ -238,6 +240,7 @@ tf_cuda_library( ":tf_status_internal", ":tf_tensor_internal", ":tf_tstring", + "//tensorflow/core:protos_all_cc", "//tensorflow/core/platform:tstring", "//tensorflow/tsl/c:tsl_status", ] + select({ @@ -881,7 +884,7 @@ tf_cc_test( tf_cuda_cc_test( name = "c_api_test", - size = "small", + size = "medium", srcs = ["c_api_test.cc"], data = [ ":test_op1.so", @@ -968,7 +971,7 @@ tf_cc_test( tf_cc_test( name = "c_api_function_test", - size = "small", + size = "medium", srcs = ["c_api_function_test.cc"], deps = [ ":c_api", @@ -985,7 +988,7 @@ tf_cc_test( tf_cc_test( name = "while_loop_test", - size = "small", + size = "medium", srcs = ["while_loop_test.cc"], deps = [ ":c_api", @@ -1013,7 +1016,7 @@ tf_kernel_library( tf_cuda_cc_test( name = "env_test", - size = "small", + size = "medium", srcs = ["env_test.cc"], linkopts = select({ "//tensorflow:macos": ["-headerpad_max_install_names"], @@ -1032,7 +1035,7 @@ tf_cuda_cc_test( tf_cuda_cc_test( name = "kernels_test", - size = "small", + size = "medium", srcs = ["kernels_test.cc"], linkopts = select({ "//tensorflow:macos": ["-headerpad_max_install_names"], @@ -1059,7 +1062,7 @@ tf_cuda_cc_test( tf_cc_test( name = "ops_test", - size = "small", + size = "medium", srcs = ["ops_test.cc"], linkopts = select({ "//conditions:default": [], diff --git a/tensorflow/c/c_api_internal.h b/tensorflow/c/c_api_internal.h index 92f63553ee1..15d279b61ac 100644 --- a/tensorflow/c/c_api_internal.h +++ b/tensorflow/c/c_api_internal.h @@ -142,7 +142,7 @@ struct TF_ImportGraphDefOptions { // Backing memory for TensorId fields in opts. // TODO(skyewm): it'd be better if ImportGraphDefOptions owned this. 
- std::list tensor_id_data; + std::vector tensor_id_data; }; struct TF_ImportGraphDefResults { @@ -152,7 +152,7 @@ struct TF_ImportGraphDefResults { std::vector missing_unused_key_indexes; // Backing memory for missing_unused_key_names values. - std::list missing_unused_key_names_data; + std::vector missing_unused_key_names_data; }; struct TF_DeviceList { diff --git a/tensorflow/c/c_api_macros.h b/tensorflow/c/c_api_macros.h index e0c91a0d549..d73546aed16 100644 --- a/tensorflow/c/c_api_macros.h +++ b/tensorflow/c/c_api_macros.h @@ -26,7 +26,12 @@ limitations under the License. #define TF_CAPI_EXPORT __declspec(dllimport) #endif // TF_COMPILE_LIBRARY #else +#ifdef TF_CAPI_WEAK +#define TF_CAPI_EXPORT \ + __attribute__((visibility("default"))) __attribute((weak)) +#else #define TF_CAPI_EXPORT __attribute__((visibility("default"))) +#endif // TF_CAPI_WEAK #endif // _WIN32 #endif // SWIG diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD index dd61bd26bc1..748d49565f6 100644 --- a/tensorflow/c/eager/BUILD +++ b/tensorflow/c/eager/BUILD @@ -8,7 +8,7 @@ load( "tf_cuda_cc_test", "tf_cuda_library", ) -load("//tensorflow:tensorflow.default.bzl", "cc_header_only_library", "filegroup", "internal_tfrt_deps") +load("//tensorflow:tensorflow.default.bzl", "cc_header_only_library", "filegroup") load( "//tensorflow/core/platform:build_config_root.bzl", "tf_cuda_tests_tags", @@ -95,7 +95,7 @@ tf_cuda_library( "//tensorflow/core/distributed_runtime/rpc/eager:grpc_eager_client", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings:str_format", - ] + internal_tfrt_deps(), + ], alwayslink = 1, ) @@ -636,7 +636,7 @@ tf_cuda_library( tf_cuda_cc_test( name = "c_api_test", - size = "small", + size = "medium", srcs = [ "c_api_debug_test.cc", "c_api_test.cc", @@ -653,7 +653,6 @@ tf_cuda_cc_test( ":c_api_test_util", ":tfe_op_internal", ":tfe_tensorhandle_internal", - "@com_google_absl//absl/strings", "//tensorflow/c:c_test_util", "//tensorflow/core:lib", 
"//tensorflow/core:lib_internal", @@ -663,10 +662,7 @@ tf_cuda_cc_test( "//tensorflow/core/common_runtime/eager:eager_operation", "//tensorflow/core/common_runtime/eager:tensor_handle", "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib", - # copybara:uncomment_begin - # "//tensorflow/core/tfrt/eager:c_api_tfrt", - # "@tf_runtime//backends/cpu:tf_ops_alwayslink", - # copybara:uncomment_end + "@com_google_absl//absl/strings", ], ) @@ -693,7 +689,7 @@ tf_cuda_library( tf_cuda_cc_test( name = "c_api_remote_test", - size = "small", + size = "medium", srcs = [ "c_api_remote_test.cc", ], @@ -725,7 +721,7 @@ tf_cuda_cc_test( tf_cuda_cc_test( name = "c_api_remote_function_test", - size = "small", + size = "medium", srcs = [ "c_api_remote_function_test.cc", ], @@ -776,7 +772,7 @@ tf_cuda_cc_test( tf_cuda_cc_test( name = "c_api_cluster_test", - size = "small", + size = "medium", srcs = [ "c_api_cluster_test.cc", ], @@ -1014,7 +1010,7 @@ cc_library( tf_cc_test( name = "custom_device_test", - size = "small", + size = "medium", srcs = [ "custom_device_test.cc", ], diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index 8503485f63c..41ced14455e 100644 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -64,13 +64,6 @@ limitations under the License. #include "tensorflow/core/protobuf/error_codes.pb.h" #include "tensorflow/core/public/version.h" -// "tensorflow/core/platform/platform.h" must be included first before using -// PLATFORM_GOOGLE, IS_MOBILE_PLATFORM, etc. 
-#if defined(PLATFORM_GOOGLE) && !defined(LIBTPU_ON_GCE) && \ - !defined(PLATFORM_FUCHSIA) -#include "tensorflow/core/tfrt/eager/c_api_tfrt.h" -#endif // PLATFORM_GOOGLE && !LIBTPU_ON_GCE && !PLATFORM_FUCHSIA - #if !defined(IS_MOBILE_PLATFORM) #include "tensorflow/core/common_runtime/eager/context_distributed_manager.h" #endif // !IS_MOBILE_PLATFORM @@ -117,18 +110,8 @@ void TFE_DeleteContextOptions(TFE_ContextOptions* options) { delete options; } TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) { if (opts->use_tfrt) { -#if defined(PLATFORM_GOOGLE) && !defined(LIBTPU_ON_GCE) && \ - !defined(PLATFORM_FUCHSIA) - tfrt::tf::ContextInterface* tfrt_context = new tfrt::tf::ContextInterface( - opts->session_options.options, - static_cast( - opts->device_placement_policy), - opts->async); - return tensorflow::wrap(tfrt_context); -#else status->status = tensorflow::errors::Unimplemented("TFRT is not supported"); return nullptr; -#endif // PLATFORM_GOOGLE && !LIBTPU_ON_GCE && !PLATFORM_FUCHSIA } std::vector> devices; status->status = tensorflow::DeviceFactory::AddDevices( diff --git a/tensorflow/c/eager/c_api_distributed_test.cc b/tensorflow/c/eager/c_api_distributed_test.cc index e35bc962525..13b688889a4 100644 --- a/tensorflow/c/eager/c_api_distributed_test.cc +++ b/tensorflow/c/eager/c_api_distributed_test.cc @@ -434,7 +434,7 @@ class FunctionErrorInjectionPass : public tensorflow::FunctionOptimizationPass { tensorflow::Status Run(const std::string& function_name, const tensorflow::DeviceSet& device_set, const tensorflow::ConfigProto& config_proto, - absl::string_view xla_compile_device_type, + const FunctionOptions& function_options, std::unique_ptr* graph, tensorflow::FunctionLibraryDefinition* flib_def, std::vector* control_ret_node_names, diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc index 0f8c97ce7ba..254648d9e09 100644 --- a/tensorflow/c/eager/c_api_test.cc +++ b/tensorflow/c/eager/c_api_test.cc @@ 
-47,10 +47,6 @@ limitations under the License. #include "tensorflow/core/protobuf/rewriter_config.pb.h" #include "tensorflow/core/protobuf/tensorflow_server.pb.h" -#ifdef PLATFORM_GOOGLE -#include "tensorflow/core/tfrt/eager/c_api_tfrt.h" -#endif - using tensorflow::string; namespace { @@ -1262,16 +1258,6 @@ TEST(CAPI, RunAddFunctionWithGrappler) { RunAddFunction(/*use_tfrt=*/false, /*enable_grappler=*/true); } -#ifdef PLATFORM_GOOGLE -TEST(CAPI, RunAddFunction_TFRT) { - RunAddFunction(/*use_tfrt=*/true, /*enable_grappler=*/false); -} - -TEST(CAPI, RunAddFunctionWithGrappler_TFRT) { - RunAddFunction(/*use_tfrt=*/true, /*enable_grappler=*/true); -} -#endif - void BM_ExecuteFunction(::testing::benchmark::State& state) { const int async = state.range(0); state.SetLabel(async ? "ExecuteFunctionAsync" : "ExecuteFunction"); @@ -1802,23 +1788,9 @@ void TestOpAddAttrs(bool use_tfrt) { CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); tensorflow::AttrValueMap attr_values; - if (use_tfrt) { -#ifdef PLATFORM_GOOGLE - auto* op = tensorflow::down_cast( - tensorflow::unwrap(copy_op)); - auto* tfrt_op_attrs = - tensorflow::down_cast( - op->GetOpAttrs()); - tensorflow::DataType result; - tfrt_op_attrs->GetType("dtype", &result); - EXPECT_EQ(tensorflow::DT_FLOAT, result); - tfrt_op_attrs->GetFallbackAttrs()->FillAttrValueMap(&attr_values); -#endif - } else { - tensorflow::EagerOperation* op = - tensorflow::OperationFromInterface(tensorflow::unwrap(copy_op)); - op->Attrs().FillAttrValueMap(&attr_values); - } + tensorflow::EagerOperation* op = + tensorflow::OperationFromInterface(tensorflow::unwrap(copy_op)); + op->Attrs().FillAttrValueMap(&attr_values); EXPECT_EQ(tensorflow::DT_FLOAT, attr_values.find("dtype")->second.type()); TF_DeleteStatus(status); @@ -1829,11 +1801,6 @@ void TestOpAddAttrs(bool use_tfrt) { TEST(CAPI, TestTFE_OpAddAttrs) { TestOpAddAttrs(/*use_tfrt=*/false); } -#ifdef PLATFORM_GOOGLE -TEST(CAPI, TestTFE_OpAddAttrs_TFRT) { 
TestOpAddAttrs(/*use_tfrt=*/true); } - -#endif - TEST(CAPI, TestTFE_OpAttrsSerialize) { TF_Status* status = TF_NewStatus(); TFE_ContextOptions* opts = TFE_NewContextOptions(); diff --git a/tensorflow/c/eager/c_api_unified_experimental_test.cc b/tensorflow/c/eager/c_api_unified_experimental_test.cc index edaf3d8e579..e866ec0ca78 100644 --- a/tensorflow/c/eager/c_api_unified_experimental_test.cc +++ b/tensorflow/c/eager/c_api_unified_experimental_test.cc @@ -1006,18 +1006,10 @@ TEST_P(UnifiedCAPI, TF_ExecutionContextGetTFEContextFromFunctionContextRaises) { // The above tests are run for a combination of: // - graphdef and MLIR tracing engine -// - Using TFRT as an execution runtime (true == enable TFRT) -#ifdef PLATFORM_GOOGLE -INSTANTIATE_TEST_SUITE_P(Tracing, UnifiedCAPI, - ::testing::Combine(::testing::Values("graphdef", - "mlir"), - ::testing::Values(true, false))); -#else INSTANTIATE_TEST_SUITE_P(Tracing, UnifiedCAPI, ::testing::Combine(::testing::Values("graphdef", "mlir"), ::testing::Values(false))); -#endif } // namespace } // namespace tensorflow diff --git a/tensorflow/c/eager/unified_api_test.cc b/tensorflow/c/eager/unified_api_test.cc index 27e42be5bcc..fce632b2210 100644 --- a/tensorflow/c/eager/unified_api_test.cc +++ b/tensorflow/c/eager/unified_api_test.cc @@ -188,18 +188,10 @@ TEST_P(UnifiedAPI, TestPartialShapeTracing) { ASSERT_EQ(-1, shape.dim_size(1)); } -#ifdef PLATFORM_GOOGLE -INSTANTIATE_TEST_SUITE_P( - UnifiedCppAPI, UnifiedAPI, - ::testing::Combine(::testing::Values("graphdef", "mlir"), - /*tfrt*/ ::testing::Values(true, false), - /*use_function*/ ::testing::Values(true, false))); -#else INSTANTIATE_TEST_SUITE_P( UnifiedCppAPI, UnifiedAPI, ::testing::Combine(::testing::Values("graphdef", "mlir"), /*tfrt*/ ::testing::Values(false), /*use_function*/ ::testing::Values(true, false))); -#endif } // namespace } // namespace tensorflow diff --git a/tensorflow/c/experimental/gradients/custom_gradient_test.cc 
b/tensorflow/c/experimental/gradients/custom_gradient_test.cc index cce9a051a74..02066362892 100644 --- a/tensorflow/c/experimental/gradients/custom_gradient_test.cc +++ b/tensorflow/c/experimental/gradients/custom_gradient_test.cc @@ -125,19 +125,12 @@ TEST_P(CustomGradientTest, ExpWithPassThroughGrad) { result_tensor = nullptr; } -#ifdef PLATFORM_GOOGLE -INSTANTIATE_TEST_SUITE_P( - CustomGradientTest, CustomGradientTest, - ::testing::Combine(::testing::Values("graphdef", "mlir"), - /*tfrt*/ ::testing::Values(true, false), - /*executing_eagerly*/ ::testing::Values(true, false))); -#else INSTANTIATE_TEST_SUITE_P( CustomGradientTest, CustomGradientTest, ::testing::Combine(::testing::Values("graphdef", "mlir"), /*tfrt*/ ::testing::Values(false), /*executing_eagerly*/ ::testing::Values(true, false))); -#endif + } // namespace } // namespace internal } // namespace gradients diff --git a/tensorflow/c/experimental/next_pluggable_device/BUILD b/tensorflow/c/experimental/next_pluggable_device/BUILD index eda00deb59c..ef81acf75a5 100644 --- a/tensorflow/c/experimental/next_pluggable_device/BUILD +++ b/tensorflow/c/experimental/next_pluggable_device/BUILD @@ -24,6 +24,7 @@ cc_library( "//tensorflow/compiler/xla/pjrt:pjrt_c_api_client", "//tensorflow/compiler/xla/pjrt:pjrt_client", "//tensorflow/compiler/xla/pjrt/c:pjrt_c_api_hdrs", + "//tensorflow/compiler/xla/stream_executor/tpu:tpu_initializer_helper", "//tensorflow/core:framework", "//tensorflow/core/common_runtime/next_pluggable_device:plugin_resource", "//tensorflow/core/platform:status", @@ -32,6 +33,8 @@ cc_library( "//tensorflow/tsl/distributed_runtime/coordination:coordination_service_agent", "//tensorflow/tsl/platform:errors", "//tensorflow/tsl/platform:statusor", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", ], ) diff --git a/tensorflow/c/experimental/next_pluggable_device/c_api.cc b/tensorflow/c/experimental/next_pluggable_device/c_api.cc index caa49be2d3f..dda6f5bcc26 100644 --- 
a/tensorflow/c/experimental/next_pluggable_device/c_api.cc +++ b/tensorflow/c/experimental/next_pluggable_device/c_api.cc @@ -21,6 +21,9 @@ limitations under the License. #include #include +#include "absl/status/status.h" +#include "absl/strings/ascii.h" +#include "absl/strings/str_cat.h" #include "tensorflow/c/kernels_experimental.h" #include "tensorflow/c/tf_status_helper.h" #include "tensorflow/c/tf_status_internal.h" @@ -30,6 +33,7 @@ limitations under the License. #include "tensorflow/compiler/jit/variable_info_util.h" #include "tensorflow/compiler/xla/pjrt/pjrt_c_api_client.h" #include "tensorflow/compiler/xla/pjrt/pjrt_client.h" +#include "tensorflow/compiler/xla/stream_executor/tpu/tpu_initializer_helper.h" // NOLINT(unused-includes): required for tensorflow::tpu::FindAndLoadTpuLibrary #include "tensorflow/core/common_runtime/next_pluggable_device/plugin_resource.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/platform/status.h" @@ -110,9 +114,9 @@ TF_VariableInfo* TF_CreateVariableInfoFromContext(TF_OpKernelContext* ctx, const tensorflow::Tensor& arg_tensor = cc_ctx->input(index); tsl::Status cc_status; if (arg_tensor.dtype() != tensorflow::DT_RESOURCE) { - cc_status = tsl::errors::InvalidArgument( - "Trying to obtain resource handle from Input[", index, - "], which is not type DT_RESOURCE."); + cc_status = absl::InvalidArgumentError( + absl::StrCat("Trying to obtain resource handle from Input[", index, + "], which is not type DT_RESOURCE.")); tsl::Set_TF_Status_from_Status(status, cc_status); return nullptr; } @@ -140,12 +144,12 @@ void TF_AllocateTempForVariableInfo(TF_OpKernelContext* ctx, auto* cc_ctx = reinterpret_cast(ctx); tsl::Status cc_status; if (var_info == nullptr) { - cc_status = tsl::errors::InvalidArgument("TF_VariableInfo is NULL."); + cc_status = absl::InvalidArgumentError("TF_VariableInfo is NULL."); tsl::Set_TF_Status_from_Status(status, cc_status); return; } if (var_info->var_info.var() == nullptr) { - 
cc_status = tsl::errors::InvalidArgument( + cc_status = absl::InvalidArgumentError( "VariableInfo does not track a resource variable."); tsl::Set_TF_Status_from_Status(status, cc_status); return; @@ -161,12 +165,12 @@ TF_Tensor* TF_GetTensorFromVariableInfo(TF_VariableInfo* var_info, TF_Status* status) { tsl::Status cc_status; if (var_info == nullptr) { - cc_status = tsl::errors::InvalidArgument("TF_VariableInfo is NULL."); + cc_status = absl::InvalidArgumentError("TF_VariableInfo is NULL."); tsl::Set_TF_Status_from_Status(status, cc_status); return nullptr; } if (var_info->var_info.var() == nullptr) { - cc_status = tsl::errors::InvalidArgument( + cc_status = absl::InvalidArgumentError( "VariableInfo does not track a resource variable."); tsl::Set_TF_Status_from_Status(status, cc_status); return nullptr; @@ -239,6 +243,18 @@ void TF_CoordinationServiceDeleteKeyValue(const char* key, void TF_CreateAndSetPjRtCApiClient(const char* device_type, TF_Status* status, PJRT_NamedValue* create_options, int num_options) { + // TODO(b/262050449): use a common plugin discovery mechanism, rather than + // having TPU-specific code here. +#if !defined(PLATFORM_GOOGLE) || defined(LIBTPU_STATIC) + if (absl::AsciiStrToLower(device_type) == "tpu") { + // TODO(b/261484192): handle device specific initialization. 
+ tsl::Status tpu_status = tensorflow::tpu::FindAndLoadTpuLibrary(); + if (!tpu_status.ok()) { + tensorflow::Set_TF_Status_from_Status(status, tpu_status); + return; + } + } +#endif tsl::StatusOr> pjrt_client = xla::GetCApiClient(device_type, pjrt::ConvertFromPjRtNamedValueList( create_options, num_options)); @@ -263,8 +279,9 @@ PJRT_Client* TF_GetPjRtCClient(const char* device_type, TF_Status* status) { tensorflow::down_cast(*pjrt_client); if (pjrt_c_api_client == nullptr) { tensorflow::Set_TF_Status_from_Status( - status, tsl::errors::Internal("PjRtClient for ", device_type, - " is not type PjRtCApiClient")); + status, + absl::InternalError(absl::StrCat("PjRtClient for ", device_type, + " is not type PjRtCApiClient"))); return nullptr; } TF_SetStatus(status, TF_OK, ""); @@ -282,8 +299,7 @@ PJRT_Buffer* TF_GetPjRtCBuffer(TF_Tensor* c_tensor, TF_Status* status) { tensorflow::AsyncValueTensor::FromTensor(&tensor); if (av_tensor == nullptr || av_tensor->GetBuffer() == nullptr) { tensorflow::Set_TF_Status_from_Status( - status, - tsl::errors::Internal("Input tensor does not have PjRtBuffer.")); + status, absl::InternalError("Input tensor does not have PjRtBuffer.")); return nullptr; } auto* c_api_buffer = @@ -291,7 +307,7 @@ PJRT_Buffer* TF_GetPjRtCBuffer(TF_Tensor* c_tensor, TF_Status* status) { if (c_api_buffer == nullptr) { tensorflow::Set_TF_Status_from_Status( status, - tsl::errors::Internal( + absl::InternalError( "The PjRtBuffer in the tensor is not type PjRtCApiBuffer.")); return nullptr; } @@ -317,8 +333,9 @@ void TF_CreatePjRtBuffer(TF_Tensor* c_tensor, PJRT_Buffer* c_buffer, tensorflow::down_cast(*pjrt_client); if (pjrt_c_api_client == nullptr) { tensorflow::Set_TF_Status_from_Status( - status, tsl::errors::Internal("PjRtClient for ", device_type, - " is not type PjRtCApiClient")); + status, + absl::InternalError(absl::StrCat("PjRtClient for ", device_type, + " is not type PjRtCApiClient"))); return; } tensorflow::AsyncValueTensor* av_tensor = @@ -326,7 
+343,7 @@ void TF_CreatePjRtBuffer(TF_Tensor* c_tensor, PJRT_Buffer* c_buffer, if (av_tensor == nullptr) { tensorflow::Set_TF_Status_from_Status( status, - tsl::errors::Internal( + absl::InternalError( "The tensor to set PjRtBuffer is not an AsyncValueTensor.")); return; } diff --git a/tensorflow/c/experimental/saved_model/core/revived_types/BUILD b/tensorflow/c/experimental/saved_model/core/revived_types/BUILD index ab7de9bae06..6c4c32a31db 100644 --- a/tensorflow/c/experimental/saved_model/core/revived_types/BUILD +++ b/tensorflow/c/experimental/saved_model/core/revived_types/BUILD @@ -28,6 +28,8 @@ cc_library( "//tensorflow/cc/saved_model:constants", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", ], ) @@ -99,6 +101,8 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "//tensorflow/core/lib/llvm_rtti", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", ], ) diff --git a/tensorflow/c/experimental/saved_model/core/revived_types/asset.cc b/tensorflow/c/experimental/saved_model/core/revived_types/asset.cc index 5cc14d615f5..2f32e6c76b0 100644 --- a/tensorflow/c/experimental/saved_model/core/revived_types/asset.cc +++ b/tensorflow/c/experimental/saved_model/core/revived_types/asset.cc @@ -17,6 +17,8 @@ limitations under the License. 
#include +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" #include "tensorflow/c/eager/immediate_execution_context.h" #include "tensorflow/c/eager/immediate_execution_tensor_handle.h" #include "tensorflow/c/tensor_interface.h" @@ -37,8 +39,8 @@ Status Asset::Create(ImmediateExecutionContext* ctx, io::JoinPath(saved_model_dir, kSavedModelAssetsDirectory, asset_filename); AbstractTensorPtr tensor(ctx->CreateStringScalar(abs_path)); if (tensor.get() == nullptr) { - return errors::Internal( - "Failed to create scalar string tensor for Asset at path ", abs_path); + return absl::InternalError(absl::StrCat( + "Failed to create scalar string tensor for Asset at path ", abs_path)); } ImmediateTensorHandlePtr handle(ctx->CreateLocalHandle(tensor.get())); diff --git a/tensorflow/c/experimental/saved_model/core/revived_types/partially_revived_objects.cc b/tensorflow/c/experimental/saved_model/core/revived_types/partially_revived_objects.cc index b9344238b79..fe78a84a649 100644 --- a/tensorflow/c/experimental/saved_model/core/revived_types/partially_revived_objects.cc +++ b/tensorflow/c/experimental/saved_model/core/revived_types/partially_revived_objects.cc @@ -20,6 +20,8 @@ limitations under the License. 
#include #include +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" #include "absl/types/span.h" #include "tensorflow/c/eager/abstract_tensor_handle.h" #include "tensorflow/c/eager/immediate_execution_context.h" @@ -60,15 +62,15 @@ Status AssertAllCreateResourceFunctionsHaveNoCaptures( const TFConcreteFunctionRevivalState* create_resource_fn = resource.create_resource; if (create_resource_fn == nullptr) { - return errors::FailedPrecondition( - "Resource at node ", node_id, - " did not have a create_resource() function"); + return absl::FailedPreconditionError( + absl::StrCat("Resource at node ", node_id, + " did not have a create_resource() function")); } const SavedConcreteFunction* saved_create_resource_fn = create_resource_fn->saved_concrete_func; if (!saved_create_resource_fn->bound_inputs().empty()) { // TODO(b/124045874): Support loading resource functions via a top sort - return errors::Unimplemented( + return absl::UnimplementedError( "Create Resource functions with captures are currently unsupported."); } } @@ -86,9 +88,9 @@ Status TensorHandleFromNode(int node_id, const SavedObjectGraph& obj_graph, case SavedObject::kVariable: { const auto& variables_iter = objects.variables.find(node_id); if (variables_iter == objects.variables.end()) { - return errors::FailedPrecondition( + return absl::FailedPreconditionError(absl::StrCat( "Tried to convert node id ", node_id, - " of type variable to tensor but the variable wasn't initialized"); + " of type variable to tensor but the variable wasn't initialized")); } *handle = variables_iter->second->handle(); return Status(); @@ -96,9 +98,10 @@ Status TensorHandleFromNode(int node_id, const SavedObjectGraph& obj_graph, case SavedObject::kConstant: { const auto& constants_iter = objects.constants.find(node_id); if (constants_iter == objects.constants.end()) { - return errors::FailedPrecondition("Tried to convert node id ", node_id, - " of type constant to tensor but the " - "constant wasn't 
initialized"); + return absl::FailedPreconditionError( + absl::StrCat("Tried to convert node id ", node_id, + " of type constant to tensor but the " + "constant wasn't initialized")); } *handle = constants_iter->second->handle(); return Status(); @@ -106,9 +109,9 @@ Status TensorHandleFromNode(int node_id, const SavedObjectGraph& obj_graph, case SavedObject::kAsset: { const auto& assets_iter = objects.assets.find(node_id); if (assets_iter == objects.assets.end()) { - return errors::FailedPrecondition( + return absl::FailedPreconditionError(absl::StrCat( "Tried to convert node id ", node_id, - " of type asset to tensor but the asset wasn't initialized"); + " of type asset to tensor but the asset wasn't initialized")); } *handle = assets_iter->second->handle(); return Status(); @@ -116,24 +119,24 @@ Status TensorHandleFromNode(int node_id, const SavedObjectGraph& obj_graph, case SavedObject::kResource: { const auto& resource_iter = objects.restored_resources.find(node_id); if (resource_iter == objects.restored_resources.end()) { - return errors::FailedPrecondition( + return absl::FailedPreconditionError(absl::StrCat( "Tried to convert node id ", node_id, - " of type Resource to tensor but the Resource wasn't initialized"); + " of type Resource to tensor but the Resource wasn't initialized")); } const RestoredResourceRevivalState& resource = resource_iter->second; if (resource.resource_handle == nullptr) { - return errors::FailedPrecondition( + return absl::FailedPreconditionError(absl::StrCat( "Resource with node id ", node_id, - " should have its resource_handle created, but was nullptr."); + " should have its resource_handle created, but was nullptr.")); } *handle = resource.resource_handle.get(); return Status(); } default: { - return errors::FailedPrecondition( + return absl::FailedPreconditionError(absl::StrCat( "Only objects of type variable, constant, asset, and resources have " "capturable tensorhandles. 
Encountered object of kind ", - node.kind_case(), " at node id: ", node_id); + node.kind_case(), " at node id: ", node_id)); } } } @@ -167,35 +170,35 @@ Status SignatureDefArgsFromInputs( // (args, kwargs), where args is an empty tuple, and kwargs is a dictionary of // string keys to TensorSpecs. if (!canonicalized_input_signature.has_tuple_value()) { - return errors::FailedPrecondition( + return absl::FailedPreconditionError(absl::StrCat( "SignatureDefFunction's canonicalized_input_signature should be " "of form tuple(tuple(), dict()), but was instead: \n", - canonicalized_input_signature.DebugString()); + canonicalized_input_signature.DebugString())); } const TupleValue& args_kwargs_tuple = canonicalized_input_signature.tuple_value(); if (args_kwargs_tuple.values_size() != 2) { - return errors::FailedPrecondition( + return absl::FailedPreconditionError(absl::StrCat( "SignatureDefFunction's canonicalized_input_signature should be " "a tuple of two elements (args, kwargs), but was instead: \n", - args_kwargs_tuple.DebugString()); + args_kwargs_tuple.DebugString())); } const StructuredValue& args = args_kwargs_tuple.values(0); if (!args.has_tuple_value() || !args.tuple_value().values().empty()) { - return errors::FailedPrecondition( + return absl::FailedPreconditionError(absl::StrCat( "SignatureDefFunction's canonicalized_input_signature's args" "should be an empty tuple, but instead got: \n", - args.DebugString()); + args.DebugString())); } const StructuredValue& kwargs = args_kwargs_tuple.values(1); if (!kwargs.has_dict_value()) { - return errors::FailedPrecondition( + return absl::FailedPreconditionError(absl::StrCat( "SignatureDefFunction's canonicalized_input_signature's kwargs" "should be a dictionary, but instead got: \n", - kwargs.DebugString()); + kwargs.DebugString())); } const DictValue& kwargs_dict = kwargs.dict_value(); @@ -206,10 +209,10 @@ Status SignatureDefArgsFromInputs( const std::string& key = key_value.first; const StructuredValue& value = 
key_value.second; if (!value.has_tensor_spec_value()) { - return errors::FailedPrecondition( + return absl::FailedPreconditionError(absl::StrCat( "SignatureDefFunction's canonicalized_input_signature's kwargs" "dictionary contained a non-tensorspec value for key-value pair: \n", - "Key: ", key, "Value: \n", value.DebugString()); + "Key: ", key, "Value: \n", value.DebugString())); } result[key] = &value.tensor_spec_value(); } @@ -226,10 +229,10 @@ Status SignatureDefArgsFromInputs( Status SignatureDefReturnsFromOutputs(const StructuredValue& output_signature, std::vector* out) { if (!output_signature.has_dict_value()) { - return errors::FailedPrecondition( + return absl::FailedPreconditionError(absl::StrCat( "SignatureDefFunction's output_signature must be a dictionary, but " "instead got: ", - output_signature.DebugString()); + output_signature.DebugString())); } const DictValue& output_dict = output_signature.dict_value(); @@ -240,10 +243,10 @@ Status SignatureDefReturnsFromOutputs(const StructuredValue& output_signature, const std::string& key = key_value.first; const StructuredValue& value = key_value.second; if (!value.has_tensor_spec_value()) { - return errors::FailedPrecondition( + return absl::FailedPreconditionError(absl::StrCat( "SignatureDefFunction's output_signature dictionary contained a " "non-tensorspec value for key-value pair: \n", - "Key: ", key, "Value: \n", value.DebugString()); + "Key: ", key, "Value: \n", value.DebugString())); } result[key] = &value.tensor_spec_value(); } @@ -337,7 +340,7 @@ Status InitializeCreateResourceFunctions(ImmediateExecutionContext* ctx, create_resource_fn->saved_concrete_func; if (!saved_create_resource_fn->bound_inputs().empty()) { // TODO(b/124045874): Load resource functions via a topological sort - return errors::Unimplemented( + return absl::UnimplementedError( "Create Resource functions with captures are currently unsupported."); } std::unique_ptr out; @@ -401,9 +404,9 @@ Status 
CreateAllResourceHandles(ImmediateExecutionContext* ctx, const TFConcreteFunction* create_resource_fn = revived->concrete_functions.Find(create_resource_fn_node); if (create_resource_fn == nullptr) { - return errors::FailedPrecondition( - "ConcreteFunction at node ", create_resource_fn_node, - " should have been initialized prior to being called."); + return absl::FailedPreconditionError( + absl::StrCat("ConcreteFunction at node ", create_resource_fn_node, + " should have been initialized prior to being called.")); } ImmediateOpPtr function_op; TF_RETURN_IF_ERROR(create_resource_fn->MakeCallOp({}, &function_op)); @@ -416,7 +419,7 @@ Status CreateAllResourceHandles(ImmediateExecutionContext* ctx, AbstractTensorHandlePtr owned_resource_handle(resource_handle); if (!tensorflow::isa( owned_resource_handle.get())) { - return errors::Internal("Unexpected tensor handle kind."); + return absl::InternalError("Unexpected tensor handle kind."); } ImmediateTensorHandlePtr result( reinterpret_cast( @@ -443,9 +446,9 @@ Status BuildResources(ImmediateExecutionContext* ctx, create_resource = revived->concrete_functions.Find( resource_revival_state.create_resource->node_id); if (create_resource == nullptr) { - return errors::FailedPrecondition( + return absl::FailedPreconditionError(absl::StrCat( "'create_resource' function with node id ", - resource_revival_state.create_resource->node_id, " not found"); + resource_revival_state.create_resource->node_id, " not found")); } } @@ -454,9 +457,9 @@ Status BuildResources(ImmediateExecutionContext* ctx, initialize = revived->concrete_functions.Find( resource_revival_state.initialize->node_id); if (initialize == nullptr) { - return errors::FailedPrecondition( + return absl::FailedPreconditionError(absl::StrCat( "'initialize' function with node id ", - resource_revival_state.initialize->node_id, " not found"); + resource_revival_state.initialize->node_id, " not found")); } } @@ -465,15 +468,16 @@ Status 
BuildResources(ImmediateExecutionContext* ctx, destroy_resource = revived->concrete_functions.Find( resource_revival_state.destroy_resource->node_id); if (destroy_resource == nullptr) { - return errors::FailedPrecondition( + return absl::FailedPreconditionError(absl::StrCat( "'destroy_resource' function with node id ", - resource_revival_state.destroy_resource->node_id, " not found"); + resource_revival_state.destroy_resource->node_id, " not found")); } } if (resource_revival_state.resource_handle == nullptr) { - return errors::FailedPrecondition("Resource at node id ", node_id, - " does not have a resource handle."); + return absl::FailedPreconditionError( + absl::StrCat("Resource at node id ", node_id, + " does not have a resource handle.")); } revived->restored_resources.emplace( diff --git a/tensorflow/c/experimental/saved_model/internal/BUILD b/tensorflow/c/experimental/saved_model/internal/BUILD index d6dc1f202b0..36e5cb52d2e 100644 --- a/tensorflow/c/experimental/saved_model/internal/BUILD +++ b/tensorflow/c/experimental/saved_model/internal/BUILD @@ -344,7 +344,7 @@ cc_library( tf_cc_test( name = "saved_model_api_test", - size = "small", + size = "medium", srcs = [ "saved_model_api_test.cc", ], diff --git a/tensorflow/c/experimental/saved_model/internal/testdata/BUILD b/tensorflow/c/experimental/saved_model/internal/testdata/BUILD index ee8fda30e46..a10cfd03e3d 100644 --- a/tensorflow/c/experimental/saved_model/internal/testdata/BUILD +++ b/tensorflow/c/experimental/saved_model/internal/testdata/BUILD @@ -13,15 +13,15 @@ py_strict_binary( srcs = ["gen_saved_models.py"], python_version = "PY3", deps = [ - "//tensorflow/python:constant_op", - "//tensorflow/python:dtypes", - "//tensorflow/python:resource_variable_ops", - "//tensorflow/python:tensor_spec", - "//tensorflow/python:variables", - "//tensorflow/python:while_loop", "//tensorflow/python/compat:v2_compat", "//tensorflow/python/eager:def_function", + "//tensorflow/python/framework:constant_op", + 
"//tensorflow/python/framework:dtypes", + "//tensorflow/python/framework:tensor_spec", "//tensorflow/python/module", + "//tensorflow/python/ops:resource_variable_ops", + "//tensorflow/python/ops:variables", + "//tensorflow/python/ops:while_loop", "//tensorflow/python/saved_model", "@absl_py//absl:app", ], diff --git a/tensorflow/cc/BUILD b/tensorflow/cc/BUILD index c1019c705ba..82d1f2a7e4f 100644 --- a/tensorflow/cc/BUILD +++ b/tensorflow/cc/BUILD @@ -189,6 +189,8 @@ cc_library_with_android_deps( "//tensorflow/core:lib_internal", "//tensorflow/core:ops", "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", ], ) diff --git a/tensorflow/cc/experimental/libtf/BUILD b/tensorflow/cc/experimental/libtf/BUILD index e281672de9e..31e15972668 100644 --- a/tensorflow/cc/experimental/libtf/BUILD +++ b/tensorflow/cc/experimental/libtf/BUILD @@ -1,4 +1,5 @@ -# TODO(aselle): describe this package. +#include "third_party/absl/strings/str_cat.h" +#TODO(aselle) : describe this package. 
load( "//tensorflow/core/platform:rules_cc.bzl", @@ -42,6 +43,8 @@ cc_library( "//tensorflow/core/platform:status", "//tensorflow/core/platform:statusor", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", ], ) @@ -84,13 +87,13 @@ py_strict_binary( srcs = ["tests/generate_testdata.py"], python_version = "PY3", deps = [ - "//tensorflow/python:constant_op", - "//tensorflow/python:dtypes", - "//tensorflow/python:tensor_spec", - "//tensorflow/python:variables", "//tensorflow/python/compat:v2_compat", "//tensorflow/python/eager:def_function", + "//tensorflow/python/framework:constant_op", + "//tensorflow/python/framework:dtypes", + "//tensorflow/python/framework:tensor_spec", "//tensorflow/python/module", + "//tensorflow/python/ops:variables", "//tensorflow/python/saved_model", "@absl_py//absl:app", "@absl_py//absl/flags", @@ -180,13 +183,11 @@ tf_cc_test( size = "medium", srcs = [ "tests/runtime_test_core.cc", - "tests/runtime_test_tfrt.cc", ], deps = [ ":runtime_test", "//tensorflow/cc/experimental/libtf/runtime", "//tensorflow/cc/experimental/libtf/runtime/core", - "//tensorflow/cc/experimental/libtf/runtime/tfrt", ], ) diff --git a/tensorflow/cc/experimental/libtf/object.h b/tensorflow/cc/experimental/libtf/object.h index 72d05aaf430..4e15a508e39 100644 --- a/tensorflow/cc/experimental/libtf/object.h +++ b/tensorflow/cc/experimental/libtf/object.h @@ -29,6 +29,8 @@ limitations under the License. #include #include +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" #include "tensorflow/c/eager/immediate_execution_tensor_handle.h" #include "tensorflow/cc/experimental/libtf/value.h" #include "tensorflow/core/platform/errors.h" @@ -172,7 +174,7 @@ class Object : public Handle { } } } - return tensorflow::errors::NotFound("Key not in dictionary."); + return absl::NotFoundError("Key not in dictionary."); } /// Sets `key` attribute with the underlying value of `h`. 
@@ -202,7 +204,7 @@ class Dictionary final : public Handle { tensorflow::StatusOr Get(const Handle& key) { auto it = value_.dict().find(key.value_); if (it != value_.dict().end()) return Cast(Handle(it->second)); - return tensorflow::errors::NotFound("Key not in dictionary."); + return absl::NotFoundError("Key not in dictionary."); } /// Sets `key` with value `value`. void Set(const String& key, Handle value) { @@ -282,7 +284,7 @@ tensorflow::Status Tensor::GetValue(absl::Span data) const { { const auto abstract_t = value_.tensor().get(); if (!tensorflow::ImmediateExecutionTensorHandle::classof(abstract_t)) { - return tensorflow::errors::InvalidArgument( + return absl::InvalidArgumentError( "Attempting to get value of non eager tensor."); } auto imm_t = @@ -315,7 +317,7 @@ class Tuple : public Handle { template tensorflow::StatusOr Get(size_t i) { if (i >= value_.tuple().size()) - return tensorflow::errors::InvalidArgument("Out of bounds index."); + return absl::InvalidArgumentError("Out of bounds index."); return Cast(Handle(value_.tuple()[i])); } @@ -348,7 +350,7 @@ class List final : public Handle { template tensorflow::StatusOr Get(size_t i) { if (i >= size()) { - return tensorflow::errors::InvalidArgument("Out of bounds index."); + return absl::InvalidArgumentError("Out of bounds index."); } return Cast(Handle(value_.list()[i])); } @@ -356,7 +358,7 @@ class List final : public Handle { /// Sets value `h` at index `i`. 
tensorflow::Status Set(size_t i, Handle h) { if (i >= size()) { - return tensorflow::errors::InvalidArgument("Out of bounds index."); + return absl::InvalidArgumentError("Out of bounds index."); } value_.list()[i] = std::move(h.value_); return ::tensorflow::OkStatus(); @@ -533,7 +535,7 @@ tensorflow::StatusOr Cast(Handle handle) { if (handle.value_.type() == TypeToTaggedType() || std::is_same::value) return T((std::move(handle.value_))); - return tensorflow::errors::InvalidArgument("Incompatible cast."); + return absl::InvalidArgumentError("Incompatible cast."); } // Converters for C++ primitives like float and int to handles. Allows callable @@ -656,10 +658,10 @@ class UneraseCallHelper { Handle h(std::move(args_in.tuple()[argument_index])); tensorflow::StatusOr x = Cast(std::move(h)); if (!x.ok()) - return tensorflow::errors::InvalidArgument( - std::string("Function ") + name + " Arg " + - std::to_string(argument_index) + - " cannot be cast to desired signature type "); + return absl::InvalidArgumentError( + absl::StrCat(std::string("Function ") + name + " Arg " + + std::to_string(argument_index) + + " cannot be cast to desired signature type ")); return UneraseCallHelper::template Call( name, fn, argument_index + 1, args_in, args..., *x); } @@ -683,9 +685,9 @@ class CallableWrapper { TaggedValue kwargs) { constexpr size_t argument_count = sizeof...(TFuncArgs); if (argument_count != args.tuple().size()) - return tensorflow::errors::InvalidArgument( - std::string("Function ") + name_ + " expected " + - std::to_string(argument_count) + " args."); + return absl::InvalidArgumentError( + absl::StrCat(std::string("Function ") + name_ + " expected " + + std::to_string(argument_count) + " args.")); return UneraseCallHelper::Call(name_, functor_, 0, args); } diff --git a/tensorflow/cc/experimental/libtf/runtime/tfrt/BUILD b/tensorflow/cc/experimental/libtf/runtime/tfrt/BUILD deleted file mode 100644 index 586ef6b9523..00000000000 --- 
a/tensorflow/cc/experimental/libtf/runtime/tfrt/BUILD +++ /dev/null @@ -1,26 +0,0 @@ -package( - # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = [ - "//tensorflow/cc/experimental/libtf:__subpackages__", - ], - licenses = ["notice"], -) - -cc_library( - name = "tfrt", - srcs = [ - "tfrt.cc", - ], - hdrs = [ - "tfrt.h", - ], - deps = [ - "//tensorflow/c:tf_status_helper", - "//tensorflow/c:tf_status_internal", - "//tensorflow/c/eager:c_api", - "//tensorflow/c/eager:c_api_experimental", - "//tensorflow/c/eager:tfe_context_internal", - "//tensorflow/cc/experimental/libtf", - "//tensorflow/cc/experimental/libtf/runtime", - ], -) diff --git a/tensorflow/cc/experimental/libtf/runtime/tfrt/tfrt.cc b/tensorflow/cc/experimental/libtf/runtime/tfrt/tfrt.cc deleted file mode 100644 index b50344fb0ed..00000000000 --- a/tensorflow/cc/experimental/libtf/runtime/tfrt/tfrt.cc +++ /dev/null @@ -1,45 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ -#include "tensorflow/cc/experimental/libtf/runtime/tfrt/tfrt.h" - -#include "tensorflow/c/eager/c_api.h" -#include "tensorflow/c/eager/c_api_experimental.h" -#include "tensorflow/c/eager/tfe_context_internal.h" -#include "tensorflow/c/tf_status_helper.h" -#include "tensorflow/c/tf_status_internal.h" -#include "tensorflow/cc/experimental/libtf/value.h" - -namespace tf { -namespace libtf { -namespace runtime { -namespace tfrt { - -runtime::Runtime Runtime() { - TFE_Context* ctx; - TFE_ContextOptions* ctx_options = TFE_NewContextOptions(); - TFE_ContextOptionsSetTfrt(ctx_options, true); - TFE_ContextOptionsSetDevicePlacementPolicy(ctx_options, - TFE_DEVICE_PLACEMENT_WARN); - TF_Status* status = TF_NewStatus(); - ctx = TFE_NewContext(ctx_options, status); - TF_DeleteStatus(status); - TFE_DeleteContextOptions(ctx_options); - return runtime::Runtime(tensorflow::unwrap(ctx)); -} - -} // namespace tfrt -} // namespace runtime -} // namespace libtf -} // namespace tf diff --git a/tensorflow/cc/experimental/libtf/tests/function_test.cc b/tensorflow/cc/experimental/libtf/tests/function_test.cc index a9b4061f1a0..fa1f21389df 100644 --- a/tensorflow/cc/experimental/libtf/tests/function_test.cc +++ b/tensorflow/cc/experimental/libtf/tests/function_test.cc @@ -288,7 +288,7 @@ TEST_P(FunctionTest, IncorrectDtypeInOutputSignatureFails) { INSTANTIATE_TEST_SUITE_P(TF2CAPI, FunctionTest, ::testing::Combine(::testing::Values("graphdef", "mlir"), - ::testing::Values(false, true))); + ::testing::Values(false))); } // namespace libtf } // namespace tf diff --git a/tensorflow/cc/experimental/libtf/tests/tensor_test.cc b/tensorflow/cc/experimental/libtf/tests/tensor_test.cc index 3f4708f0f0d..85243dd4287 100644 --- a/tensorflow/cc/experimental/libtf/tests/tensor_test.cc +++ b/tensorflow/cc/experimental/libtf/tests/tensor_test.cc @@ -123,7 +123,7 @@ TEST_P(UnifiedCAPI, SimpleCreationFunctions) { 
INSTANTIATE_TEST_SUITE_P(Tracing, UnifiedCAPI, ::testing::Combine(::testing::Values("graphdef", "mlir"), - ::testing::Values(true, false))); + ::testing::Values(false))); } // namespace libtf } // namespace tf diff --git a/tensorflow/cc/experimental/libtf/tests/variable_test.cc b/tensorflow/cc/experimental/libtf/tests/variable_test.cc index 8e7aca22bdc..1e37ed9cb2b 100644 --- a/tensorflow/cc/experimental/libtf/tests/variable_test.cc +++ b/tensorflow/cc/experimental/libtf/tests/variable_test.cc @@ -114,7 +114,7 @@ TEST_P(VariableTest, CreateAssignReadDestroy) { INSTANTIATE_TEST_SUITE_P(TF2CAPI, VariableTest, ::testing::Combine(::testing::Values("graphdef", "mlir"), - ::testing::Values(false, true))); + ::testing::Values(false))); } // namespace libtf } // namespace tf diff --git a/tensorflow/cc/framework/ops.h b/tensorflow/cc/framework/ops.h index d19b895654b..ab8b387ab56 100644 --- a/tensorflow/cc/framework/ops.h +++ b/tensorflow/cc/framework/ops.h @@ -21,6 +21,8 @@ limitations under the License. 
#include #include +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/graph/graph.h" @@ -156,9 +158,9 @@ class Input { typedef typename RealType::type RealT; Tensor t(DataTypeToEnum::v(), shape); if (t.NumElements() != static_cast(v.size())) { - status = errors::InvalidArgument( + status = absl::InvalidArgumentError(absl::StrCat( "Cannot construct a tensor with ", t.NumElements(), - " from an initializer list with ", v.size(), " elements"); + " from an initializer list with ", v.size(), " elements")); return; } std::copy_n(v.begin(), v.size(), t.flat().data()); diff --git a/tensorflow/cc/gradients/math_grad_test.cc b/tensorflow/cc/gradients/math_grad_test.cc index 59ea373bd6d..b3d77f29b06 100644 --- a/tensorflow/cc/gradients/math_grad_test.cc +++ b/tensorflow/cc/gradients/math_grad_test.cc @@ -466,7 +466,7 @@ TEST_F(CWiseUnaryGradTest, Asin_Complex) { }; // TODO(kbsriram) // Enable test when the asin kernel supports complex numbers - if (false) { + if (/* DISABLES CODE */ (false)) { TestCWiseGrad(ASIN, x_fn); } } @@ -482,7 +482,7 @@ TEST_F(CWiseUnaryGradTest, Acos_Complex) { }; // TODO(kbsriram) // Add test when the acos kernel supports complex numbers - if (false) { + if (/* DISABLES CODE */ (false)) { TestCWiseGrad(ACOS, x_fn); } } @@ -510,7 +510,7 @@ TEST_F(CWiseUnaryGradTest, Atan_Complex) { }; // TODO(kbsriram) // Add test when the atan kernel supports complex numbers - if (false) { + if (/* DISABLES CODE */ (false)) { TestCWiseGrad(ATAN, x_fn); } } @@ -561,7 +561,7 @@ TEST_F(CWiseUnaryGradTest, Lgamma_Complex) { }; // TODO(kbsriram) // Add test when the lgamma kernel supports complex numbers - if (false) { + if (/* DISABLES CODE */ (false)) { TestCWiseGrad(LGAMMA, x_fn); } } @@ -579,7 +579,7 @@ TEST_F(CWiseUnaryGradTest, Erf_Complex) { }; // TODO(kbsriram) // Add test when the erf kernel supports complex numbers - if (false) 
{ + if (/* DISABLES CODE */ (false)) { TestCWiseGrad(ERF, x_fn); } } diff --git a/tensorflow/cc/saved_model/BUILD b/tensorflow/cc/saved_model/BUILD index d52db030b1b..b9764d72c7e 100644 --- a/tensorflow/cc/saved_model/BUILD +++ b/tensorflow/cc/saved_model/BUILD @@ -1,4 +1,5 @@ -# Description: +#include "third_party/absl/strings/str_cat.h" +#Description: # TensorFlow SavedModel. load("//tensorflow:tensorflow.default.bzl", "filegroup") @@ -6,6 +7,7 @@ load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load( "//tensorflow:tensorflow.bzl", "if_android", + "if_google", "if_mobile", "if_not_mobile", "tf_cc_test", @@ -28,6 +30,8 @@ package( exports_files([ "loader.h", + "testdata/chunked_saved_model/chunked_model/saved_model.cpb", + "testdata/chunked_saved_model/chunked_model/saved_model.pbtxt", ]) cc_library( @@ -63,7 +67,9 @@ cc_library( ":metrics", ":util", "//tensorflow/core:protos_all_cc", - ] + if_not_mobile([ + ] + if_google([ + "//tensorflow/tools/proto_splitter:merge", + ]) + if_not_mobile([ # TODO(b/111634734): :lib and :protos_all contain dependencies that # cannot be built on mobile platforms. Instead, include the appropriate # tf_lib depending on the build platform. 
@@ -87,9 +93,8 @@ tf_cc_test( ":tag_constants", "//tensorflow/core:lib", "//tensorflow/core:test", - "//tensorflow/core:test_main", - "//tensorflow/core:testlib", "//tensorflow/core/platform:resource_loader", + "@com_google_googletest//:gtest_main", ], ) @@ -131,6 +136,8 @@ cc_library( ":fingerprinting", ":loader_util", ":reader", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", ] + if_not_mobile([ ":metrics", ":util", @@ -252,15 +259,15 @@ py_binary( python_version = "PY3", srcs_version = "PY3", deps = [ - "//tensorflow/python:client_testlib", - "//tensorflow/python:dtypes", - "//tensorflow/python:framework_ops", - "//tensorflow/python:lookup_ops", - "//tensorflow/python:tensor_spec", - "//tensorflow/python:variables", "//tensorflow/python/compat:v2_compat", "//tensorflow/python/eager:def_function", + "//tensorflow/python/framework:dtypes", + "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor_spec", "//tensorflow/python/module", + "//tensorflow/python/ops:lookup_ops", + "//tensorflow/python/ops:variables", + "//tensorflow/python/platform:client_testlib", "//tensorflow/python/saved_model", "//tensorflow/python/saved_model:save_options", "//tensorflow/python/trackable:asset", @@ -268,6 +275,31 @@ py_binary( ], ) +# copybara:uncomment_begin(google-only) +# +# py_binary( +# name = "testdata/generate_chunked_models", +# srcs = ["testdata/generate_chunked_models.py"], +# python_version = "PY3", +# srcs_version = "PY3", +# deps = [ +# "//tensorflow/python/compat:v2_compat", +# "//tensorflow/python/eager:def_function", +# "//tensorflow/python/framework:constant_op", +# "//tensorflow/python/module", +# "//tensorflow/python/platform:client_testlib", +# "//tensorflow/python/saved_model:loader", +# "//tensorflow/python/saved_model:save", +# "//tensorflow/python/saved_model:save_options", +# "//tensorflow/python/util:compat", +# "//tensorflow/tools/proto_splitter:constants", +# 
"//tensorflow/tools/proto_splitter/python:saved_model", +# "@absl_py//absl:app", +# ], +# ) +# +# copybara:uncomment_end + # TODO(b/32673259): add a test to continuously validate these files. filegroup( name = "saved_model_test_files", @@ -284,6 +316,7 @@ filegroup( "testdata/fuzz_generated/**", "testdata/SimpleV1Model/**", "testdata/OptimizerSlotVariableModule/**", + "testdata/chunked_saved_model/**", ]), ) @@ -369,7 +402,7 @@ tf_cc_test( ":metrics", "//tensorflow/core:test", "//tensorflow/core:test_main", - "@com_google_googletest//:gtest_main", + "@com_google_googletest//:gtest", "@jsoncpp_git//:jsoncpp", ], ) diff --git a/tensorflow/cc/saved_model/bundle_v2.cc b/tensorflow/cc/saved_model/bundle_v2.cc index 21692edbf40..85af07ce0d5 100644 --- a/tensorflow/cc/saved_model/bundle_v2.cc +++ b/tensorflow/cc/saved_model/bundle_v2.cc @@ -39,60 +39,6 @@ using strings::StrCat; // `tensorflow::SavedModelV2Bundle::Load` API label. constexpr char kCCLoadBundleV2Label[] = "cc_load_bundle_v2"; -Status ReadSavedModelProto(const string& export_dir, - SavedModel* saved_model_proto) { - LOG(INFO) << "Reading SavedModel from: " << export_dir; - - const string saved_model_pb_path = - io::JoinPath(export_dir, kSavedModelFilenamePb); - Status found_pb = Env::Default()->FileExists(saved_model_pb_path); - if (found_pb.ok()) { - Status result = - ReadBinaryProto(Env::Default(), saved_model_pb_path, saved_model_proto); - if (result.ok()) { - metrics::SavedModelReadCount( - saved_model::GetWriteVersion(*saved_model_proto)) - .IncrementBy(1); - } - return result; - } - - const string saved_model_pbtxt_path = - io::JoinPath(export_dir, kSavedModelFilenamePbTxt); - Status found_pbtxt = Env::Default()->FileExists(saved_model_pbtxt_path); - if (found_pbtxt.ok()) { - Status result = ReadTextProto(Env::Default(), saved_model_pbtxt_path, - saved_model_proto); - if (result.ok()) { - metrics::SavedModelReadCount( - saved_model::GetWriteVersion(*saved_model_proto)) - .IncrementBy(1); - } - return 
result; - } - - Status err; - if (found_pb.code() == found_pbtxt.code()) { - err = Status(found_pb.code(), - StrCat(found_pb.message(), "\n", found_pbtxt.message())); - } else if (found_pb.code() == NOT_FOUND) { - err = found_pbtxt; - } else if (found_pbtxt.code() == NOT_FOUND) { - err = found_pb; - } else { - // found_pb and found_pbtxt both errored, w/ different codes, neither being - // NOT_FOUND. - err = Status( - absl::StatusCode::kInternal, - StrCat("Different errors encountered while looking for saved_model.pb " - "and saved_model.pbtxt in the export directory path \"", - export_dir, "\": \n", found_pb.ToString(), "\n", - found_pbtxt.ToString())); - } - - return err; -} - Status ReadCheckpointObjectGraph(BundleReader* bundle_reader, TrackableObjectGraph* object_graph) { Tensor object_graph_tensor; @@ -123,7 +69,7 @@ Status SavedModelV2Bundle::Load(const std::string& export_dir, SavedModelV2Bundle* const bundle) { metrics::SavedModelReadApi(kCCLoadBundleV2Label).IncrementBy(1); SavedModel saved_model_proto; - TF_RETURN_IF_ERROR(ReadSavedModelProto(export_dir, &saved_model_proto)); + TF_RETURN_IF_ERROR(ReadSavedModel(export_dir, &saved_model_proto)); metrics::SavedModelReadPath().Set(export_dir); // Load MetaGraphDef. diff --git a/tensorflow/cc/saved_model/constants.h b/tensorflow/cc/saved_model/constants.h index b571a643113..e8a267e3280 100644 --- a/tensorflow/cc/saved_model/constants.h +++ b/tensorflow/cc/saved_model/constants.h @@ -19,56 +19,63 @@ limitations under the License. namespace tensorflow { // SavedModel assets directory. -constexpr char kSavedModelAssetsDirectory[] = "assets"; +inline constexpr char kSavedModelAssetsDirectory[] = "assets"; // SavedModel assets.extra directory. -constexpr char kSavedModelAssetsExtraDirectory[] = "assets.extra"; +inline constexpr char kSavedModelAssetsExtraDirectory[] = "assets.extra"; // SavedModel assets key for graph collection-def. 
-constexpr char kSavedModelAssetsKey[] = "saved_model_assets"; +inline constexpr char kSavedModelAssetsKey[] = "saved_model_assets"; /// SavedModel legacy init op collection key. Used in v1 SavedModels. -constexpr char kSavedModelLegacyInitOpKey[] = "legacy_init_op"; +inline constexpr char kSavedModelLegacyInitOpKey[] = "legacy_init_op"; /// SavedModel main op collection key. Used in v1 SavedModels. -constexpr char kSavedModelMainOpKey[] = "saved_model_main_op"; +inline constexpr char kSavedModelMainOpKey[] = "saved_model_main_op"; // CollectionDef key for the SavedModel train op. // Not exported while export_all_saved_models is experimental. -constexpr char kSavedModelTrainOpKey[] = "saved_model_train_op"; +inline constexpr char kSavedModelTrainOpKey[] = "saved_model_train_op"; // Schema version for SavedModel. -constexpr int kSavedModelSchemaVersion = 1; +inline constexpr int kSavedModelSchemaVersion = 1; +// SavedModel proto filename prefix. +inline constexpr char kSavedModelFilenamePrefix[] = "saved_model"; // SavedModel proto filename. -constexpr char kSavedModelFilenamePb[] = "saved_model.pb"; +inline constexpr char kSavedModelFilenamePb[] = "saved_model.pb"; + +// SavedModel chunked proto filename. +inline constexpr char kSavedModelFilenameCpb[] = "saved_model.cpb"; // SavedModel text format proto filename. -constexpr char kSavedModelFilenamePbTxt[] = "saved_model.pbtxt"; +inline constexpr char kSavedModelFilenamePbTxt[] = "saved_model.pbtxt"; // Subdirectory where debugging related files are written. -constexpr char kSavedModelDebugDirectory[] = "debug"; +inline constexpr char kSavedModelDebugDirectory[] = "debug"; // File name for GraphDebugInfo protocol buffer which corresponds to the // SavedModel. -constexpr char kSavedModelDebugInfoFilenamePb[] = "saved_model_debug_info.pb"; +inline constexpr char kSavedModelDebugInfoFilenamePb[] = + "saved_model_debug_info.pb"; // Directory in which to save the SavedModel variables. 
-constexpr char kSavedModelVariablesDirectory[] = "variables"; +inline constexpr char kSavedModelVariablesDirectory[] = "variables"; // SavedModel variables filename. -constexpr char kSavedModelVariablesFilename[] = "variables"; +inline constexpr char kSavedModelVariablesFilename[] = "variables"; // SavedModel SignatureDef keys for the initialization and train ops. Used in // V2 SavedModels. -constexpr char kSavedModelInitOpSignatureKey[] = "__saved_model_init_op"; -constexpr char kSavedModelTrainOpSignatureKey[] = "__saved_model_train_op"; +inline constexpr char kSavedModelInitOpSignatureKey[] = "__saved_model_init_op"; +inline constexpr char kSavedModelTrainOpSignatureKey[] = + "__saved_model_train_op"; // Key in the TensorBundle for the object graph proto. -constexpr char kObjectGraphProtoKey[] = "_CHECKPOINTABLE_OBJECT_GRAPH"; +inline constexpr char kObjectGraphProtoKey[] = "_CHECKPOINTABLE_OBJECT_GRAPH"; // Filename for the FingerprintDef protocol buffer. -constexpr char kFingerprintFilenamePb[] = "fingerprint.pb"; +inline constexpr char kFingerprintFilenamePb[] = "fingerprint.pb"; } // namespace tensorflow diff --git a/tensorflow/cc/saved_model/fingerprinting.cc b/tensorflow/cc/saved_model/fingerprinting.cc index 389b28bf278..d4f05c472b9 100644 --- a/tensorflow/cc/saved_model/fingerprinting.cc +++ b/tensorflow/cc/saved_model/fingerprinting.cc @@ -120,23 +120,29 @@ uint64 HashCheckpointIndexFile(absl::string_view model_dir) { StatusOr CreateFingerprintDef(const SavedModel& saved_model, absl::string_view export_dir) { + SavedModel copy = saved_model; + return CreateFingerprintDef(©, export_dir); +} + +StatusOr CreateFingerprintDef(SavedModel* saved_model, + absl::string_view export_dir) { // Create a copy of `metagraph` which will be used and mutated for fingerprint // computation. 
- MetaGraphDef metagraph_copy = saved_model.meta_graphs(0); FingerprintDef fingerprint_def; + MetaGraphDef* metagraph = saved_model->mutable_meta_graphs(0); // Set fingerprint field #1. - fingerprint_def.set_saved_model_checksum(HashSavedModel(saved_model)); + fingerprint_def.set_saved_model_checksum(HashSavedModel(*saved_model)); // Set fingerprint field #2. - graph_regularization::SimpleDelete(*metagraph_copy.mutable_graph_def()); + graph_regularization::SimpleDelete(*metagraph->mutable_graph_def()); fingerprint_def.set_graph_def_program_hash( - graph_regularization::ComputeHash(metagraph_copy.graph_def())); + graph_regularization::ComputeHash(metagraph->graph_def())); // Set fingerprint field #3. fingerprint_def.set_signature_def_hash( - RegularizeAndHashSignatureDefs(metagraph_copy.signature_def())); + RegularizeAndHashSignatureDefs(metagraph->signature_def())); // Set fingerprint field #4. TF_ASSIGN_OR_RETURN( StatusOr object_graph_hash, - RegularizeAndHashSavedObjectGraph(metagraph_copy.object_graph_def())); + RegularizeAndHashSavedObjectGraph(metagraph->object_graph_def())); fingerprint_def.set_saved_object_graph_hash(object_graph_hash.value()); // Set fingerprint field #5. fingerprint_def.set_checkpoint_hash(HashCheckpointIndexFile(export_dir)); diff --git a/tensorflow/cc/saved_model/fingerprinting.h b/tensorflow/cc/saved_model/fingerprinting.h index 1d0a830b4d2..cabca831076 100644 --- a/tensorflow/cc/saved_model/fingerprinting.h +++ b/tensorflow/cc/saved_model/fingerprinting.h @@ -31,6 +31,12 @@ namespace tensorflow::saved_model::fingerprinting { StatusOr CreateFingerprintDef(const SavedModel& saved_model, absl::string_view export_dir); +// Creates a FingerprintDef proto from a SavedModel and the checkpoint meta file +// (.index) in `export_dir`. The passed in `saved_model` is mutated and should +// not be used afterwards. 
+StatusOr CreateFingerprintDef(SavedModel* saved_model, + absl::string_view export_dir); + // Loads the `fingerprint.pb` from `export_dir`, returns an error if there is // none. StatusOr ReadSavedModelFingerprint( diff --git a/tensorflow/cc/saved_model/loader.cc b/tensorflow/cc/saved_model/loader.cc index b9544bc7555..c0a816120cb 100644 --- a/tensorflow/cc/saved_model/loader.cc +++ b/tensorflow/cc/saved_model/loader.cc @@ -18,6 +18,8 @@ limitations under the License. #include #include +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" #include "tensorflow/cc/saved_model/constants.h" #include "tensorflow/cc/saved_model/fingerprinting.h" #include "tensorflow/cc/saved_model/loader_util.h" @@ -90,16 +92,16 @@ static Status ValidateNode(const NodeDef& node) { if (node_value.has_tensor()) { const PartialTensorShape node_shape(node_value.tensor().tensor_shape()); if (node_shape.num_elements() < 0) { - return errors::FailedPrecondition( + return absl::FailedPreconditionError(absl::StrCat( "Saved model contains node \"", node.name(), "\" (op \"", node.op(), "\") which initializes from a tensor with ", - node_shape.num_elements(), " elements"); + node_shape.num_elements(), " elements")); } } } else if (node.op() == "Const") { - return errors::FailedPrecondition( + return absl::FailedPreconditionError(absl::StrCat( "Saved model contains node \"", node.name(), - "\" which is a constant tensor but no value has been provided"); + "\" which is a constant tensor but no value has been provided")); } return OkStatus(); } @@ -108,9 +110,9 @@ static Status ValidateFunctionNotRecursive(const FunctionDef& function) { const auto& function_name = function.signature().name(); for (const auto& node : function.node_def()) { if (node.op() == function_name) { - return errors::FailedPrecondition( + return absl::FailedPreconditionError(absl::StrCat( "Function ", function_name, - " is self recursive and TensorFlow does not support this scenario."); + " is self recursive and 
TensorFlow does not support this scenario.")); } } @@ -340,17 +342,17 @@ class LiteSessionWrapper : public Session { : wrapped_(std::move(wrapped)) {} Status Create(const GraphDef& graph) override { - return errors::Unimplemented("Session::Create()"); + return absl::UnimplementedError("Session::Create()"); } Status Create(GraphDef&& graph) override { - return errors::Unimplemented("Session::Create()"); + return absl::UnimplementedError("Session::Create()"); } Status Extend(const GraphDef& graph) override { - return errors::Unimplemented("Session::Extend()"); + return absl::UnimplementedError("Session::Extend()"); } Status Extend(GraphDef&& graph) override { - return errors::Unimplemented("Session::Extend()"); + return absl::UnimplementedError("Session::Extend()"); } Status Run(const std::vector>& inputs, @@ -362,16 +364,16 @@ class LiteSessionWrapper : public Session { } Status Create(const RunOptions& run_options, const GraphDef& graph) override { - return errors::Unimplemented("Session::Create()"); + return absl::UnimplementedError("Session::Create()"); } Status Extend(const RunOptions& run_options, const GraphDef& graph) override { - return errors::Unimplemented("Session::Extend()"); + return absl::UnimplementedError("Session::Extend()"); } Status Create(const RunOptions& run_options, GraphDef&& graph) override { - return errors::Unimplemented("Session::Create()"); + return absl::UnimplementedError("Session::Create()"); } Status Extend(const RunOptions& run_options, GraphDef&& graph) override { - return errors::Unimplemented("Session::Extend()"); + return absl::UnimplementedError("Session::Extend()"); } Status Close(const RunOptions& run_options) override { return wrapped_->Close(run_options); @@ -390,14 +392,14 @@ class LiteSessionWrapper : public Session { const std::vector& output_names, const std::vector& target_nodes, string* handle) override { - return errors::Unimplemented("Session::PRunSetup()"); + return absl::UnimplementedError("Session::PRunSetup()"); 
} Status PRun(const string& handle, const std::vector>& inputs, const std::vector& output_names, std::vector* outputs) override { - return errors::Unimplemented("Session::PRun()"); + return absl::UnimplementedError("Session::PRun()"); } Status ListDevices(std::vector* response) override { diff --git a/tensorflow/cc/saved_model/reader.cc b/tensorflow/cc/saved_model/reader.cc index 40ba3e4a4e4..365874881dd 100644 --- a/tensorflow/cc/saved_model/reader.cc +++ b/tensorflow/cc/saved_model/reader.cc @@ -15,7 +15,10 @@ limitations under the License. #include "tensorflow/cc/saved_model/reader.h" +#include +#include #include +#include #include "absl/memory/memory.h" #include "tensorflow/cc/saved_model/constants.h" @@ -31,60 +34,17 @@ limitations under the License. #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/file_system_helper.h" +#include "tensorflow/core/platform/path.h" #include "tensorflow/core/platform/statusor.h" #include "tensorflow/core/protobuf/saved_model.pb.h" #include "tensorflow/core/util/tensor_bundle/byte_swap_tensor.h" +// Placeholder for protosplitter merger include. + +#define IS_OSS true namespace tensorflow { namespace { -// Reads the SavedModel proto from saved_model.pb in `export_dir`. -// Returns a failure status when the SavedModel file does not exist. 
-Status ReadSavedModel(absl::string_view export_dir, - SavedModel* saved_model_proto) { - LOG(INFO) << "Reading SavedModel from: " << export_dir; - - const std::string saved_model_pb_path = - io::JoinPath(export_dir, kSavedModelFilenamePb); - - TF_ASSIGN_OR_RETURN( - bool saved_model_pb_exists, - internal::FileExists(Env::Default(), saved_model_pb_path)); - if (saved_model_pb_exists) { - Status result = - ReadBinaryProto(Env::Default(), saved_model_pb_path, saved_model_proto); - if (result.ok()) { - metrics::SavedModelReadCount( - saved_model::GetWriteVersion(*saved_model_proto)) - .IncrementBy(1); - } - return result; - } - const std::string saved_model_pbtxt_path = - io::JoinPath(export_dir, kSavedModelFilenamePbTxt); - TF_ASSIGN_OR_RETURN( - bool saved_model_pbtxt_exists, - internal::FileExists(Env::Default(), saved_model_pbtxt_path)); - if (saved_model_pbtxt_exists) { - Status result = ReadTextProto(Env::Default(), saved_model_pbtxt_path, - saved_model_proto); - if (result.ok()) { - metrics::SavedModelReadCount( - saved_model::GetWriteVersion(*saved_model_proto)) - .IncrementBy(1); - } - return result; - } - return Status( - absl::StatusCode::kNotFound, - strings::StrCat("Could not find SavedModel .pb or .pbtxt at supplied " - "export directory path: ", - export_dir, - ". Check that " - "the directory exists and that you have the right " - "permissions for accessing it.")); -} - Status FindMetaGraphDef(const std::unordered_set& tags, SavedModel* saved_model_proto, MetaGraphDef* meta_graph_def) { @@ -116,6 +76,61 @@ Status FindMetaGraphDef(const std::unordered_set& tags, } } // namespace +// Reads the SavedModel proto from saved_model.pb in `export_dir`. +// Returns a failure status when the SavedModel file does not exist. 
+Status ReadSavedModel(absl::string_view export_dir, + SavedModel* saved_model_proto) { + LOG(INFO) << "Reading SavedModel from: " << export_dir; + + if (IS_OSS) { + const std::string saved_model_pb_path = + io::JoinPath(export_dir, kSavedModelFilenamePb); + TF_ASSIGN_OR_RETURN( + bool saved_model_pb_exists, + internal::FileExists(Env::Default(), saved_model_pb_path)); + if (saved_model_pb_exists) { + Status result = ReadBinaryProto(Env::Default(), saved_model_pb_path, + saved_model_proto); + if (result.ok()) { + metrics::SavedModelReadCount( + saved_model::GetWriteVersion(*saved_model_proto)) + .IncrementBy(1); + } + return result; + } + } + + const std::string saved_model_pbtxt_path = + io::JoinPath(export_dir, kSavedModelFilenamePbTxt); + TF_ASSIGN_OR_RETURN( + bool saved_model_pbtxt_exists, + internal::FileExists(Env::Default(), saved_model_pbtxt_path)); + if (saved_model_pbtxt_exists) { + Status result = ReadTextProto(Env::Default(), saved_model_pbtxt_path, + saved_model_proto); + if (result.ok()) { + metrics::SavedModelReadCount( + saved_model::GetWriteVersion(*saved_model_proto)) + .IncrementBy(1); + } + return result; + } + + if (!IS_OSS) { + // Only use Merger outside of OSS. + // Placeholder for protosplitter merger call. + } + + return Status( + absl::StatusCode::kNotFound, + strings::StrCat("Could not find SavedModel .pb or .pbtxt at supplied " + "export directory path: ", + export_dir, + ". 
Check that " + "the directory exists and that you have the right " + "permissions for accessing it.")); +} + Status ReadMetaGraphDefFromSavedModel(const string& export_dir, const std::unordered_set& tags, MetaGraphDef* const meta_graph_def) { @@ -140,8 +155,7 @@ Status ReadSavedModelDebugInfoIfPresent( GraphDebugInfo debug_info; TF_RETURN_IF_ERROR( ReadBinaryProto(Env::Default(), debug_info_pb_path, &debug_info)); - *debug_info_proto = - absl::make_unique(std::move(debug_info)); + *debug_info_proto = std::make_unique(std::move(debug_info)); } return OkStatus(); } diff --git a/tensorflow/cc/saved_model/reader.h b/tensorflow/cc/saved_model/reader.h index f51fbeb557f..e82bd449c59 100644 --- a/tensorflow/cc/saved_model/reader.h +++ b/tensorflow/cc/saved_model/reader.h @@ -18,21 +18,26 @@ limitations under the License. #ifndef TENSORFLOW_CC_SAVED_MODEL_READER_H_ #define TENSORFLOW_CC_SAVED_MODEL_READER_H_ +#include #include #include #include "tensorflow/core/framework/graph_debug_info.pb.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/protobuf/meta_graph.pb.h" +#include "tensorflow/core/protobuf/saved_model.pb.h" namespace tensorflow { +Status ReadSavedModel(absl::string_view export_dir, + SavedModel* saved_model_proto); + // Reads the SavedModel proto from saved_model.pb(txt) in the given directory, // finds the MetaGraphDef that matches the given set of tags and writes it to // the `meta_graph_def` parameter. Returns a failure status when the SavedModel // file does not exist or no MetaGraphDef matches the tags. Status ReadMetaGraphDefFromSavedModel(const string& export_dir, const std::unordered_set& tags, - MetaGraphDef* const meta_graph_def); + MetaGraphDef* meta_graph_def); // Store debug info from the SavedModel export dir. 
Status ReadSavedModelDebugInfoIfPresent( diff --git a/tensorflow/cc/saved_model/reader_test.cc b/tensorflow/cc/saved_model/reader_test.cc index 4b8b5cde20d..7e00186b3ad 100644 --- a/tensorflow/cc/saved_model/reader_test.cc +++ b/tensorflow/cc/saved_model/reader_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/cc/saved_model/reader.h" +#include #include "tensorflow/cc/saved_model/constants.h" #include "tensorflow/cc/saved_model/metrics.h" #include "tensorflow/cc/saved_model/tag_constants.h" @@ -24,7 +25,6 @@ limitations under the License. #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/path.h" #include "tensorflow/core/platform/resource_loader.h" -#include "tensorflow/core/platform/test.h" namespace tensorflow { namespace { @@ -39,6 +39,16 @@ string TestDataSharded() { "half_plus_two", "00000123"); } +string ChunkedSavedModel() { + return io::JoinPath("tensorflow", "cc", "saved_model", "testdata", + "chunked_saved_model", "chunked_model"); +} + +string NonChunkedSavedModel() { + return io::JoinPath("tensorflow", "cc", "saved_model", "testdata", + "chunked_saved_model", "non_chunked_model"); +} + class ReaderTest : public ::testing::Test { protected: ReaderTest() {} @@ -88,15 +98,6 @@ TEST_F(ReaderTest, NoTagMatchMultiple) { << st.message(); } -TEST_F(ReaderTest, PbtxtFormat) { - MetaGraphDef meta_graph_def; - - const string export_dir = GetDataDependencyFilepath(TestDataPbTxt()); - TF_ASSERT_OK(ReadMetaGraphDefFromSavedModel(export_dir, {kSavedModelTagServe}, - &meta_graph_def)); - CheckMetaGraphDef(meta_graph_def); -} - TEST_F(ReaderTest, InvalidExportPath) { MetaGraphDef meta_graph_def; @@ -136,5 +137,7 @@ TEST_F(ReaderTest, MetricsUpdatedSuccessfulRead) { EXPECT_EQ(metrics::SavedModelReadCount("1").value(), read_count_v1 + 1); } +// Placeholder for protosplitter merger merge test. 
+ } // namespace } // namespace tensorflow diff --git a/tensorflow/cc/saved_model/testdata/chunked_saved_model/chunked_model/saved_model.cpb b/tensorflow/cc/saved_model/testdata/chunked_saved_model/chunked_model/saved_model.cpb new file mode 100644 index 00000000000..d9f76e3b4a9 Binary files /dev/null and b/tensorflow/cc/saved_model/testdata/chunked_saved_model/chunked_model/saved_model.cpb differ diff --git a/tensorflow/cc/saved_model/testdata/chunked_saved_model/chunked_model/saved_model.pbtxt b/tensorflow/cc/saved_model/testdata/chunked_saved_model/chunked_model/saved_model.pbtxt new file mode 100644 index 00000000000..4a37bd88fb4 --- /dev/null +++ b/tensorflow/cc/saved_model/testdata/chunked_saved_model/chunked_model/saved_model.pbtxt @@ -0,0 +1,2063 @@ +saved_model_schema_version: 1 +meta_graphs { + meta_info_def { + stripped_op_list { + op { + name: "Const" + output_arg { + name: "output" + type_attr: "dtype" + } + attr { + name: "value" + type: "tensor" + } + attr { + name: "dtype" + type: "type" + } + } + op { + name: "Identity" + input_arg { + name: "input" + type_attr: "T" + } + output_arg { + name: "output" + type_attr: "T" + } + attr { + name: "T" + type: "type" + } + } + op { + name: "MergeV2Checkpoints" + input_arg { + name: "checkpoint_prefixes" + type: DT_STRING + } + input_arg { + name: "destination_prefix" + type: DT_STRING + } + attr { + name: "delete_old_dirs" + type: "bool" + default_value { + b: true + } + } + attr { + name: "allow_missing_files" + type: "bool" + default_value { + b: false + } + } + is_stateful: true + } + op { + name: "NoOp" + } + op { + name: "Pack" + input_arg { + name: "values" + type_attr: "T" + number_attr: "N" + } + output_arg { + name: "output" + type_attr: "T" + } + attr { + name: "N" + type: "int" + has_minimum: true + minimum: 1 + } + attr { + name: "T" + type: "type" + } + attr { + name: "axis" + type: "int" + default_value { + i: 0 + } + } + } + op { + name: "PartitionedCall" + input_arg { + name: "args" + 
type_list_attr: "Tin" + } + output_arg { + name: "output" + type_list_attr: "Tout" + } + attr { + name: "Tin" + type: "list(type)" + has_minimum: true + } + attr { + name: "Tout" + type: "list(type)" + has_minimum: true + } + attr { + name: "f" + type: "func" + } + attr { + name: "config" + type: "string" + default_value { + s: "" + } + } + attr { + name: "config_proto" + type: "string" + default_value { + s: "" + } + } + attr { + name: "executor_type" + type: "string" + default_value { + s: "" + } + } + } + op { + name: "Placeholder" + output_arg { + name: "output" + type_attr: "dtype" + } + attr { + name: "dtype" + type: "type" + } + attr { + name: "shape" + type: "shape" + default_value { + shape { + unknown_rank: true + } + } + } + } + op { + name: "RestoreV2" + input_arg { + name: "prefix" + type: DT_STRING + } + input_arg { + name: "tensor_names" + type: DT_STRING + } + input_arg { + name: "shape_and_slices" + type: DT_STRING + } + output_arg { + name: "tensors" + type_list_attr: "dtypes" + } + attr { + name: "dtypes" + type: "list(type)" + has_minimum: true + minimum: 1 + } + is_stateful: true + } + op { + name: "SaveV2" + input_arg { + name: "prefix" + type: DT_STRING + } + input_arg { + name: "tensor_names" + type: DT_STRING + } + input_arg { + name: "shape_and_slices" + type: DT_STRING + } + input_arg { + name: "tensors" + type_list_attr: "dtypes" + } + attr { + name: "dtypes" + type: "list(type)" + has_minimum: true + minimum: 1 + } + is_stateful: true + } + op { + name: "Select" + input_arg { + name: "condition" + type: DT_BOOL + } + input_arg { + name: "t" + type_attr: "T" + } + input_arg { + name: "e" + type_attr: "T" + } + output_arg { + name: "output" + type_attr: "T" + } + attr { + name: "T" + type: "type" + } + } + op { + name: "ShardedFilename" + input_arg { + name: "basename" + type: DT_STRING + } + input_arg { + name: "shard" + type: DT_INT32 + } + input_arg { + name: "num_shards" + type: DT_INT32 + } + output_arg { + name: "filename" + type: 
DT_STRING + } + } + op { + name: "StatefulPartitionedCall" + input_arg { + name: "args" + type_list_attr: "Tin" + } + output_arg { + name: "output" + type_list_attr: "Tout" + } + attr { + name: "Tin" + type: "list(type)" + has_minimum: true + } + attr { + name: "Tout" + type: "list(type)" + has_minimum: true + } + attr { + name: "f" + type: "func" + } + attr { + name: "config" + type: "string" + default_value { + s: "" + } + } + attr { + name: "config_proto" + type: "string" + default_value { + s: "" + } + } + attr { + name: "executor_type" + type: "string" + default_value { + s: "" + } + } + is_stateful: true + is_distributed_communication: true + } + op { + name: "StaticRegexFullMatch" + input_arg { + name: "input" + type: DT_STRING + } + output_arg { + name: "output" + type: DT_BOOL + } + attr { + name: "pattern" + type: "string" + } + } + op { + name: "StringJoin" + input_arg { + name: "inputs" + type: DT_STRING + number_attr: "N" + } + output_arg { + name: "output" + type: DT_STRING + } + attr { + name: "N" + type: "int" + has_minimum: true + } + attr { + name: "separator" + type: "string" + default_value { + s: "" + } + } + } + } + tags: "serve" + tensorflow_version: "2.14.0" + tensorflow_git_version: "unknown" + stripped_default_attrs: true + } + graph_def { + node { + name: "Const" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 150 + } + dim { + size: 150 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_DOUBLE + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_DOUBLE + tensor_shape { + dim { + size: 150 + } + dim { + size: 150 + } + } + tensor_content: 
"`\250S-\322\210\346?\021\261\230\257\371\257\340?\301\244\352\345g\221\346?\006\231\202C\203#\342?\241\001AY7\240\355?\031\275\343\257\231\272\340?X6B\007s\006\306?\364\216\221\020\241\'\303?\022\273\025\211)\010\336?\330[\324\264\201\231\323?\357\010\212\342\230\037\347?\240\272\301Fe\033\344?\3029J\332\352.\343?\2648\250\3644\034\317?\206\341\225\036\302\331\330?\200t\364\331Z\327\214?\364\341\007\n\201\002\350?b2\254e\243\006\327?\240ra\\\004#\255?\016\247\200\0006\362\334?z\203\247\034\253d\354?\320@pj\240\337\325?JM\251<\340\370\334?JQ{\352=w\330?\214\317D\r\267\314\336?\020\343!\237\026\300\336?\350\333\235\231r6\305?3l\223\335\343\323\347?\300z\344\275j\300\251?\201\345\321\227P\\\344?\031\030$\225\031\302\356?t\351d4\266D\315?\362\306m\265\254\037\354?\206\361\305F\324\002\343?\337\224\201\217\252\211\341?\302T\205x\271\205\345?!D7\244w\311\352?>U\244\250\333\260\330?@\2650\377 \360\271?\200#{\327H\032\206?\224\341\252\022\016\010\317?\024x\360l\215k\356?\300\263\234tb\230\206?Hbo\003\220<\320?`X\350\222H\202\256?\355\337\272\376Re\355?d\222bo}\365\340?\010\340*H\3636\352?\214;\2522\221\\\311?\305+\264\333\236\326\354?\225\257\272\270\242\026\342?\257\341V\245j\325\351?\300>\220\350\235\341\272?\226\\\n6\231_\335?h\323\323\200\220\366\344?\310P\000\351\353\255\330?i\233a\344\352\356\356?\'\365\371\206\330l\340?Ne\357\327\3001\323?\240\216\327_u8\260?x\224\220U\314\240\320?\256\177\267\346\234\n\334?+\302\250>\251\247\340?\214\352\273\013\263\274\322?\002\3351R\334v\347?J\343e\r\342\232\325?\230(\024\2345p\327?\362q\251\300;\346\345?Q\033\263\"y\320\346?\270\r\327\200i\347\302?\352\252\376i\252s\340?\252&\177,\253\r\332?*R]]\016\032\344?!\206{\017\226\026\353?\212S!\333!\247\325?\373U%\354\246\351\347?\230s\00607\212\306?2\355\376\323\343\315\325?\330%\252\212\332\313\333?\370\032.\244\267o\344?\224\212\324\371\264\300\331?\307\363\241\366\322\237\340?\250\275\000S9\373\301?\2449\254\237\025\323\316?\266\360mQ\264\016\327?Tof\271\267l\354?\341\234\036\301\31
1\314\347?\320\352FQ\tv\264?\005\250\300b\376\203\342?\000\026\366\351\034\372\335?\253\224\337 \370\266\343?\003\342nG\263\274\341?\214\233\200\032\264\331\351?o\'\253#2{\346?N\252\240p\323\236\343?\324\366\311\364\007\367\301?Z\234\361W6/\335?\"\020\002\307\260\357\344?\2562a\306\203s\323?PEG\003L\213\256?8\035\275$\253n\276?\310\305\315\306\n=\310?8\267\034\341w\315\314?p0\363\373x\205\276?Me\253\226\332\252\345?\3168\310M)\367\325?\036DE\220\344Z\326?\240\321G3\256x\250?r\233 \\$\374\323?ZN\367,\334d\357?l\212b\024\242\350\333?\366\213\206\005M4\342?a\254\273\355E?\357?\240u\026a>\374\271?\025\023\354G\305I\344?\177\340\332\272\237\230\346?2JJ\010>\264\350?d\032\276\263\350k\324?z\224\346\303\244j\327?8\205\237\244V\376\355?\246\371\363b#\r\321?`I\342\023\\4\307?\340\220\\c\307[\347?\004m\312\301\370\214\331?R\372E\331\034\206\336?\226c\2617K\360\351?\270\312\214\272\026\274\267?a\276\370az\007\353?\345n\364\321$[\345?b\366c\333\274\325\335?\021\n\232\0045O\342?`vl%\376\356\274?Y\254\363\014\256`\354?k\2025\t\266/\341?l\211\245\325\014{\344?&\350\033W|%\347?\306\260qD\274!\327?\0104c\000\334h\344?!\010\322L\364\001\356?\301\357:\023\302\304\351?\241ip1/\256\351?4H2j\352\321\326?FVN\224\335\273\324?oc\240=V\236\350?U\212h\212Wf\357?\350v\356M\325u\341?\010]\344\016\200Y\310?\370\355\271\371\217m\332?(\231\312\2514\270\263?\333w\367\036\333\224\342?0\375M\240\326\301\355?4\364\311\317\342\313\331?\200L\321\356O\356\237?\202J\032#Uv\345?6\206~j\010\357\345?@\343|\003\217,\261?\020\362\314o\n\226\346?t\030\024A{\013\303?\n%\230NQ\034\356?\222^I\033tN\345?\356\037\364T\177\363\323?`\322\277)\207\250\242?\214\325\233Wp\022\303?\303\360\213\241\311\200\351?\336\025\020\265\'\372\322?\374\225\257\'\273_\326?\314\006\010v\223\037\302?4\340\213\366\255\271\326?j<\335\005\005|\327?`J2\016\336\357\235?H\t\274\323\023\032\331?\000DVCL\033\315?\213\256\240\007\004\n\347?V=I\277\303\033\335?\2170\n\256\200K\355?@\231\257,\036\021\335?\370\345[7j#\306?\374\215\r\031iE\321?\374\
374\357\235\263\'\321?\221\203\003\023\306?\354? 7\250\016g\013\274?h\250=\321d/\356?\031~\374V\016\233\340?t<\247~\313F\346?X\031\353\2008\252\353?P\364\033X\232\360\303?+((\241\023\360\341?\363\305\217\013&\235\342?\372\333\242\005~\025\345?\rO\277fe\211\350?\267\353\313B\303\315\351?\273\333K\251Y`\357?\332\324\375\367\310\034\345?\000\347\030\311\251\255\245?\235\231\376\010\'l\341?\317W\231\207q\366\342?uN7\336\351\017\344?$z\334\326I\013\337?~\361D\250\216\361\341?;\230O\371j\301\353?X\336\017\037\340\244\277?H\032^\021!\241\310?\266 xM\036 \337?\327\276\023\377\240\207\345?Wp\031\300\344=\341?\226\335\037\320E\003\340?b\331O\3301\313\325?0\373\\\377}\304\303?~lI\201\231)\334?\370x\n\320\261\266\325?C\353\342qB\260\356?\240\251\223\315\n\371\352?\222\362+\304\273V\326?\340f\322\334\370\357\234?\000\274\n\245`\324\\?\214}Yz\247\363\314?\202\2645P\352\327\354?\244Vv\321\372X\356?\230\312\003\032\372\376\321?\274\352\332\014\325s\306?8\263\370\322\016X\347?\215\240\3432.b\351?\202\277\361=\005#\340?\2721j4S\362\330?\207u\237\035\315\006\350?\231\374\231\0075=\355?\310\257fb\222`\320?NwJ~fB\332?x\255^\247\242\030\347?T\310 \341{\021\335?x`\335\031\207\274\321?\374\3254\240\355\342\336?T\203\247\024<\336\340?\240\212c\264\247\360\221?\022\316\352\340n\227\351?:\336|\377 
\274\350?\013\230\2577\261\350\345?$>\375\255y\225\340?\267\017aLp\203\341?\234=\311\332\317\273\330?X\204\254\326\217\r\336?\320\257\337\325\366\335\256?\\\363\264\231\313\"\331?b\2620\026\361\315\352?-4\203\'\241k\347?nYg\366\006^\350?X5/X\224W\350?P(\010\363\004a\341?\220Q\'\240\010\332\273?\360R\232k\2229\336?\200\273\243Bg\321\267?\350\316FP\347(\332?\372\367\220\307\"\354\351?l\300\240\311\237\223\306?@\315c\264\007\250\246?\014\026\243\007B\314\336?P%\037\221B\357\343?c\034u\210\206e\357?\360\377\334}^\311\267?\"\311\032\305\330\364\351?HQ\253\213\251\225\351?\016\250\251\335\245\253\342?\017\2149\267\241\314\345?\272`\265k\316\013\320?\347\260OP\317\330\342?HvQs\302h\274?:Df\226\262\312\327?\236\235\004\260hE\332?\340\324\356A\245d\244?\300\245\270\270\217\234\302?/\227\3576(;\354?\005\014?\213!\315\342?$R\365\322\254\034\300??\020\255\354\264\307\343?\310S\200\207\221\376\313?\324\205iE\262\026\314?8\246\007\254\313-\325?<\001\274\331\0025\352?^\003\207\374JR\355?\304\034\242\n\355>\336?\371\2423\024\"\273\342?f\321\356\335\262\355\334?8\315\212\344\351W\345?\236b0\322\225\240\340?\034]\313\261\000\025\350?\340\370\255V(\265\241?\256\035\245\006\237*\347?p\250x\334\214 \264?U7\252p\310Z\343?\031\275j\321M\261\341?\264\362\262\217\324\311\323?\207)X?\007\363\345?dv\370m\207\356\353?pd\337\002[\203\325?_,MN\032a\356?$\244\227j\360\223\323?\355a\335\006\267\303\353?\010\361\351B\231\r\310?\340\017\321\177\2117\304?4n\342\345!\276\322?\336J\257y\224\201\350?\017\026\365OO\207\345?\250N7\313\222(\356?\250\325QMK\372\315?}\365\373\331\036\305\353?\266veT\202\365\352?\252\316\n\266\234~\332?\340T\001\027+\324\327?\234\375\240%\233\220\330?\014!f\373Qd\357?\0009k\372\374$\243?\244\t\022T\t\231\333?\351:@\362F\371\341?\002\035\305\007\223\364\320?\330\222\324\277\\\325\324?J8\002\200\305\333\346? 
\212tJ\346&\340?\250\260\014&\025\'\314?\300\277\"\031\355\300\254?+\016\220\236*c\347?\361\261\023\360\254\361\354?fe\331\203\025\032\326?hL\344\312\331\344\267?\323Yt\273\365\350\346?\254!{\352g\224\343?\316\375\236\222u\035\337?\217\314F\373Tt\353?\350\242o\317\204^\346?\314\353\357\262\276\331\340?\254\276{\3219!\315?\305\272@\370\376\337\347?0\251~\324\314\010\254?\000n\335\n\274C\334?\350\360\373\352\377\224\276?%\236\320\2759\201\351?\210\306AWq\232\306?L\302\301\256\345\350\334?`\332\305K\274.\242?\202dU\307\320\274\327?D\331\322\026\365U\314?%\177\237\005\331X\340?\321\360\357\353\243\344\353?P\026\314\013&\327\252?d\304\364$\024\255\346?\255\263\331\024^&\350?\025siS\000\023\353?$\203\213\222\206\374\311?kIg\266\326\240\357?c\031\022\354\322\372\346?H\344 9\0078\314?s\304\330\211\270\240\343?\026\250\305\232\316\333\322?\254\327\266VY\370\310?\267\375)v\210\207\357?\214\222\001\231m`\314?6p\302g\354\341\342?\360\265|+\305\215\300?\";\340\372\354<\342?\016\307\006\363\372\331\327?\360\315d\003\375\324\257?~\364\001\3347O\343?<_\n\316\264\262\301?\3524\253\322\255\234\342?6\020x \216V\342?\277\260I\252+\204\351?p\036\263f\360I\326?\352\263P$K\373\330?\207F\224\000\324\340\352?\3676O]\313\301\342?\202;\372\000[\327\332?L<\307\2433+\321?8$\023\"?n\327?\250\314\340\226\363I\322?\210c\t\264\024\244\344?2\375\233B\"\224\357?4T\201^)\000\302?dfu\317tZ\311?N3\300\2167I\321?\244\251B\341\361\305\304?\200\330\" \350E\357?\203\003\325\2725K\357?c\022\227?\036\316\353?L\231+\001\223i\354?\350\035\247\323\\\236\355?\236m\335af\334\352?\363\334L\2472\026\340?\370\354\223\317\354\244\312?\266\033\301D\327\206\323?4d)\260xw\325?|S\013\000\036g\340?Q\021K?\356~\340?\360\317\311\022\240C\357?\272\221D\332\222\340\332?\226\366\301\277\352\332\336?E\230T\277\320^\345?\274_\031\026\215\322\317?fy\016\3300\204\337?[\322c\214\334<\344?jb\017\357\021^\324?\274\227e\002\322\227\347?\370\005]\242^\t\345?\375\366\037\2630~\346?\206D\026\241\035\034\356?\270 
\257\"\202-\317?/\341\276\306H\374\342?\260\034\020\317\r\"\337?\230\231]\021\245\223\271?C\272E|Z\264\352?0\302\000\276\245\242\340?\223\362\225W3\271\355?\344%\257\354$\037\333?\330\246\371jM\254\306?\200\257\320\333J\354\351?\300\254\311\206^5\227?\364y7\301\316.\307?\314\232W\tF\021\307?T+P&\326\250\352?\360G\3128\226\342\247?\262\214\007\016\017\335\340?|\270hH+{\330?K)\373\331\235B\340?0\313\313\0248M\262?\343\322\240m\203\251\355?x\307\333)\022s\322?\317o\t>\376<\341?\255P|\025Q\374\345?\'\320\037\230\250!\357?\356\343\002gV\366\344?d\021\340\001\007\332\334?\020=\221\330\336\335\352?`\2538B\251\324\225?\360\177\017:K\246\335?\202\316\366\265\004\245\347?\326\230\352\244Z\242\335?T\270\0176~\345\307?\220\3209\311\206\007\273?\204\"\r.\275\375\343?\023\352\003;\376}\340?\024u\266\223M\343\344?@h\3242\365\340\333?\251\n\240V\202\002\343?\261\334\'[w\247\340?Pn\331;\347a\262?\235y%\367\251\331\344?\014c \003\035\364\334?x\307\007\n\274\014\262?\374\023\300\362\226\323\350?\320\225\355\177$\333\333?})z\257\370\364\352?\333F\356\357\200\024\342?\244$\027\272\376o\342?\360\327\231\365\375\272\354?\000\266hU\246\324\341?\200N\036[\010\202\347?\227\025\303\353\223t\357?`\220\237\311\237R\300?~y\375\006\354\n\333?8\2353\'\360\375\324?`\004c\363\354\274\325?@~=66\202\332?\251A\274\251\214Y\342?J0\247c\201\225\341?\340\'t\315\354\217\264?@9\010\316\201/\216?%\374\224\3037\233\344?<\341\037\013{\232\354?\320\017\363\373~H\270?\311[\313\322}-\351?s\246\005;\325k\350?\210\251\252\r\334`\335?z\371\203\327\315H\354?\320\330$K\272U\336?\022\224\301R\236\336\341?\360\022\320on\025\345?\360\201c\244/\303\350?\264\247\327P\313\225\336?\237T\266\r\006\245\352?\272\"\262\256(D\355?\325\254\365\033\243\n\354?\024\307p/jV\316?V\332\210\312\212\343\324?\327\366N~\212\237\357?\200\025\252\263\371\036\207?\224wVd9\013\325?\250g\356\221f!\263?\300|\343\330\220\334\331?\362\037)\215\373\227\335?\372\360\000\212\302\244\346?\274~g\223\257\324\332?\024s\255WE\250\346?`b>xO\263\262?eR\007\3
67\273\233\341?/\330\326\361\312\037\347?\315h\316I\213\275\340?\300\373[gEk\305?2?I\021\262f\347?\250\377(\347\311\332\314?\300Q?o\240\315\274?\323~1\254-p\343?X\277&\tuL\271?\025\327\376\034j\217\347?\010\224\341\343\244U\264?\310n\233~\336\337\322?\337\343F\202\023\303\345?`[\225\010!\371\304?\211s\254\013\024G\345?\024\222\006\314\340\020\300?;\244\255\356\270\324\347?\020#\353\360\371\301\246?\\\000\013\240(\022\330?\276\331i\214\312\017\344?\345\346\000W\313i\353?\360MD\247\005\003\320?\030\201i\3634\344\313?\254B\031k\023\310\304?\030\316\030\215\204\335\267?#\233EM\230\344\354?bi\363u$\240\333?<\373\377D}G\310?\\%\003]\026\270\357?\224\005o67?\342?\342\020s\247k\"\325?\374_\324\343}q\313?%\236\234:\301\300\350?\000\010<=\000j\350?\230\3253\000\301\353\301?\244\367D\214\232\022\337?\364\ne\374\030\331\346?\213\342\225\372\n^\346?\274\260%\351V\010\340?x\220\353Q\347\330\304?\203v\200\266r\374\345?\242\375%8J\273\334?\007\245-U\362|\352?H\262\236f\2065\350?F\300\017\362\255\245\356?\212r\334\320C!\324?x\311\034D\037\362\303?/\207CE?x\342?,L\301d\275:\335?\342\335\330\031I\021\335?\237\342\001P\204\032\344?\010\241dYX~\327?\320\366\272\355\236z\304?\000\3667B6l\\?q\271\331\354c\301\341?\006<\226\337\254\342\354?\200\006k\025\2207\263?\001\361#\000\252i\346?\265\025\3546\367\017\340?\266\001\201\330\315\257\337?P\010\374!\n\372\327?@\267\3463\313x\211?wg/\336\363&\340?\346P\275zh\267\332?\264\230M[0\234\323?\365\244\303O\261W\343?\222\203\235\016\274\271\341?\277<\014k&\000\340?\230p\267J\315T\264?\230\364u\225\256\271\342?\314;\031\002\034\234\331?\240\373\236\217\372\230\333?\252\317A\217r\000\331?\000C\037\247\354r\325?\007V\013\374y/\345?\014\006\377\236j\353\331?\365\027\033\317\031\216\344?T\300\204$D\244\305?\000\343\237\272#\376\250?(\323N\277>O\345?A\275\036\343\037\375\346?|0\007m\020&\340?\214\277\336\001\247\313\314?\250\272\301\302\327\261\270?\372\365\037\034x$\346?B\260\370\212\346u\331?\270TLY\316\223\331?\270\341\300`\324P\304?\2721>\233D\037\34
1?\274\344+\257\204D\337?\225\005\036\205\034\037\356?\201]\221\237?\016\351?s\205\267\201\314\330\354?.@Re\253\r\327?\360T&\025\315\327\317?fd\021\242\360I\352?\200\362W\261\366\355\342?H\267\305$\313?\356?\347\206\241`\275\010\344?\340\314\177\344\017\230\235?\304\016T\252\207\031\330?\362\213\227$\331W\356?:\017C\205MT\323?\373\205\004\232X\366\350?@E\231\002\302O\357?z\3457\266\260\214\337?\205\210\037\265\210M\343?D\372\2760b6\343?r\362W\267k\333\341?>*\236an\244\353?\177\372fS\361R\344?N.\322\026\026t\321?\321\246\032\340\235\261\343?\000\373.\206\0312\261?\254\200k68[\331?L\256\225\036\006\t\347?%\023-QQ\252\356?z\021\3365\355\316\336?\221\260\225F<*\343?\366\035\300\344]G\356?\323\023t\021?\341\344?? \323\310\023\203\341?\200\216\331n8{\230?\010\334X\315\255]\300?qImhT\205\355?\302V\016\265el\323?\373\356\350\323\323\032\350?JeL{\205s\340?\006\027v\271\353\245\345?\377g\252<\032\376\341?R\213U\327\r\337\346?\240w\255\256r.\305?`\221X\270\240\\\264?0\3065.\365&\300?;\245`\266\025\203\341?\177|0\211\273\357\347?\0271TE\272\326\353?\250\226\n\036j\270\301?(\332\2610\033\352\300?\350\027BQ\215J\350?\235\t`\0307\225\353?\213\355\241E\"C\342?F\247\301\212\202\217\335?tXE*?P\314?\025\276?k\216\n\354?\311%\272\034\003\373\353?w\213\337=\276|\351?p\273\034k\227Z\304?\337=\023\267\230\241\353?`\376\004P\030\226\345?\3204\233\r\0258\306?T\343\204.)`\356?\326\355,\376\027P\333?{+\336\252\332\346\347?\300\270;\225\327\306\252?\000\023\275W+\r\247?h?\303\332\341\252\343?\324\214J\264x\032\344?\000\246P\354\312\024\205?\2729\326\220w\001\321?\230\t\367\354\355\325\304?h\310\234\233\tc\310?\355\003\254\223L\217\353?\314\"{\235\365\377\355?]\324\307\351^\205\355?\242_\224\023iu\344?X\316 \026\320\301\347?EVW\036E0\354?E\373>\366\211\357\344?\210a\272\353\350,\341?|\020\247\217\024\347\315?\275\210\244\226\226t\343?p\351\242\266\246\377\246?\032`\262\302\214C\352?\300\020\331\t\304\371\333?@\222 
\370\204n\302?\240\346LQH}\245?\300\016eNYX\331?`Q;\201\031\272\223?\020XX\364\2661\254?\000\214\'*\"\352x?\024\343Z\336\352+\303?\240es*\026\316\253?rX1]\365\340\353?\320\247,\030\237\374\240?\252s*\303~\312\336? \255\216(\274O\317?L\335\267\373\375\301\310?\215\207i\225\024\020\346?\314\322\361\322\364\375\356?\244\222\2564>\320\323?\240\374\005\307\265\343\267?5;\216\206j\272\352?\230\\\323g\362G\323?P\226\367-:\317\271?\265\317\305\025\240\226\340?$\326\336]\311\252\326?\320\376\275&$i\303?\n[\231\316\374\372\357?\000\026\030GW\377\230?0\200\237\315~\326\276?\032\215\243\215^n\353?\370\225*\025\310\307\342?_\177\324v\371\325\346?2\364\371RB\231\320?|\344\331\266D\264\344?\310]w\301\270|\340?\321\000l\262\233\271\341?\033\260\023\246F\001\345?i\233s\233\3000\347?lw\364jMB\300?\316~\277\027}D\351?;yTV(S\354?\000~\t\036@\276\334?\233\237Muh;\351?\021\213\2437F\257\347?R\375,p\266\247\351?\356\354\277\354w\032\322?\234\306Ru\0009\301?}!\252\244 \021\354?0B\322.B\265\300?|/2\203\340\370\330?P\335N18\227\253?8\353\242\262\203\t\324?\350v\n\334\264\204\337?rTY=\303\300\345?\300\365*\277\272\365\266?\022\271\373\307Sn\346?\376\354]]\226\326\320?\320\207iU\234\311\345?\303\017,\321\236\336\343?(\303mb\022;\271?\3325\362\017K\236\332?\016Z\376I:n\336?\330$\266\307\232K\324?\214\0374\357\330M\330?\220\026kq\221\261\322?\301\241K\2017g\350?~\211IL\2463\320?\004\344\030\030\036\365\333?.\325n\311@h\354?\017\217\264~;\r\357?\216;\224\206\240\235\357?\330\212\302c\323y\341?\303\271u\345\371\323\347?4\262|\336L7\337?@\345\030\332\375\376\241?\224R\255\333+\360\311?\362\320\201\305\235\214\327?\010Q~\326\255\331\307? 
x\250|I{\277?\341g\376DIo\345?\300\205\353E\027T\332?X\352\357V)\002\330?\034\250\351uk\001\336?$\263]$wR\300?\250&\2147W\210\356?\250\017\030\362\311E\346?j\n!\253\201\230\342?}9:\375g\361\356?#\\\342\203\3503\341?H5\371\002;\254\263?(\306\234\205\346\373\312?\305u\"\247\314\222\351?\360M\232;;\321\274?\000\3373x\356\035\271?\354\3724;\300=\331?\222\347Q\023\236\027\343?\250\"|\327A\306\337?\360:\354l\253\013\312?\n\001?\332\332\326\356?!%\257\274N\361\346?|\360*\271\304\033\330?\317\006J\256F\003\353?Ak72|\267\345?\2416\335V\216:\347?\204o/\337V\305\341?\344\320(Z\332\310\341?2\252y\370\000\260\357?H\220\317\332\375\013\303?I>\261J\302\004\341?z\235\215\200w\t\322?\301\253\232\271\331\347\357?\002\243\271\356Ku\352?{\304\356\266\307I\350?\010Is\316\263_\305?p,\264\321)\023\332?VyBm&g\320?\210\311\216g \304\321?r\265\343\037\360\357\332?F\026\271!\345;\350?g\251\214\276Z\206\344?\3424[\241\265.\355?\253nG\266\276\316\350?4\243;\266\310\275\335?x\310\262\364^\001\353?8<\tU\361m\305?\2626t\342\216-\353?,\347\223,\273\034\314?\376\262\200J:\347\351?\366\325D\371Y\310\356?dE\324\330e6\333?\2260\024!\350{\356?!:\265\361%\266\347?H\'\247\271\266\365\332?@\212\274A\341-\320?\013\336\314k7\373\344?\301\006]\240\307\301\353?\304V\214o\306\352\316?\314\315\253\237\270c\345?\000\357\215\376\r&e?\242g\331\334|L\320?\342\356\227a\006C\350?\224\374\305Z;8\343?\352\222_\354jX\355?%~\"\270\036\363\346?t\202\206F\353\327\343? \234\323o\006;\324?\355\'5s\372.\357?}\021W\004V\271\342?\202,\224s\216\220\351?\266\215\001\3700\245\320?\277\237\316\273\302\204\346?\332\rv^i\260\337?\302y|m\230!\320?Pb\331Z\273&\311?\330,lg\324\024\276?\007\n_\03780\354?\354\312$\036\022\311\344?\301O\321\312\0105\355?\340d\200<\375o\271?\2258(\324\300\033\351??\250G\247\2400\354?p\251|S\267\267\300?\306MRFCF\346?\211O\027\373\371t\355?D\321g\234!X\327?\200M\023(\276\277\324?P\251\207\021\375\225\241?\000\026i{y\r\307? 
\024_\276L\246\357?dG\212\302\022\350\335?\033V\303\010\277\216\347?\244\254\007\272CP\343?=2\350\245}\022\347?NW\331\257\020p\356?\275V\245\214\355\273\350?u\324(\270\376\362\340?$pe)\321\270\341?p\354D\252\210.\332?\345\305{\354\213K\350?\367J\354\340\031\367\343?c\353\033\331\346\335\357?\3723d\216\030\314\353?h,\207\321\226\004\275?\322\036 /\201\304\356?0\266C\006\243\352\264?\255\331\223\331\233r\353?`S~u\323-\306?md\250\221\277\276\353?\320\026\357\316\232\376\354?\200\267n\236\271ru?\304\277\360\034qj\306?\214\362\264\234\337\256\350?\310\342\266d\316\236\331?\3747\345\253\231(\337?#\032d\034\374\275\343?\367RK\366<\030\357?t\316F\262?o\353?\024\334Z<\203\035\341?X~\230y\303\036\303?P\024\266\276\304:\274?\300h\353Y\332h\332?~\302Xun\237\320?\341\251\250\323\350\371\355?\316\003\250C\267\361\326?\200\036\306\272k\322\256?\224\264\3105)\177\322?uP\334\002UR\353?\362\014\341hp\222\336?\307\330\260;\203\341\340?\327Y\314q\3402\351?\000\013e\306\331\362\202?\024\023\236\025q\307\325?\222\251!\246\273\255\320?\036Z9\2015Q\347?F=\254CMX\344?\330\253bO=+\357?8\302\206\206\204u\313?\213\233\365C\006\276\352?y5\331U\377\326\341?\363\356\nb\322\244\353?\315\200\334r\013\200\342?\002\351\326\217\210\353\356?P,E-\010a\275?\232\001S#\255\210\332?5\347s\035\272\331\347?4\345\341\000\300\214\310?\364=E+{\300\337?`\225\203f\376\206\244?\336?\300\32610<%\346?\210\367\3647J\273\357?\004\275\351\250\2753\335?\202W\275\2459\326\321?\370\010\364\272_\245\312?\363\315%wz\033\341?\363\001\241\373\252\233\350?\240Q7\353&\303\313?\263>\304\035\272L\356?ld\275v\307\216\352?.9yA31\325?\263K\251\034\335\246\343?\272\314\347\311Pc\333?\210\363\276t\240}\327?:\327yG\304V\356?\346\303l\024\340\303\325?@\231\002gP\321\263?\263H\230\242\305\205\345?V\263\256\227\323\'\321?\212\220\370cj\345\327?\204\030W\003\352\\\310?\000G\000\020\022\232\222?\356\002Z\2137n\351?\242+\333\216>j\351?\024\025\177\234w\221\341?n\362\027\240\240\350\340?^?JX\275\305\324?\000\233\312\363/\371u?x\006\215\323\200
\341\352?\260\227\222\204\304\030\303?P\036\223~t\006\254?\014\306C\320\315\367\336?\350\006)es^\261?\2428\034\255\204\311$\325*\324?d\350~\204e\311\316? \311\265\035\262#\254?(^\331,\334k\354?$\223\364\233B\232\311?P\360\305\r\261@\251?$p\267r\311\022\343? \244!\341?`\253?o\324\355^w\276\352?\036\347\n\227\250\254\352?$\216\345\343\336*\303?\004r\203Ge\364\335?X\272-Q\013\000\266?:\2528\356!Q\354?wL\351\022\330U\344?\2306\r\214<>\323?\323\356 y,\271\352?]\346@9g\324\343?\306\272Q\217\266.\321?\220\246_\"Q\035\331?\320(\274\251k\304\317?\267Y\222r\311\263\345?\347\352\262\332\341\353\345?\314\226\276\376\354\244\312?Dh\322C\230F\306?E\321O\234/\022\353?6Z\344:F\221\352?(\207\034\013\350N\276?\327\320W\031\305\177\352?\332/\334nu\220\351?\000N\253\300X\356\311?H\322\025%\304\254\265?\204\234\264\234\224G\320?\030<\212hE\373\337?\371\224[k\010\256\352?\350I\327\014\232F\342?\"\243\227%\017\252\355?@\335}4\013\373\273?C\372jH\030\332\346?\317\220\364n\350H\340?tW\360\277;\n\324?\260\0318\343|\"\330?\000\231\231O2\014u?\000\2178\343_\347\247?\300\262\000g\260\314\341?\324wU\005KE\305?\240\250\252\377\363\272\344?G\323\030\013\210n\356?\205\036\244\353w\230\340?\222\273\354KOA\320?@\340\2460,Z\243?6|\315\023\322\244\337?\002CPEzm\322?\016\206[~ 
I\357?<|*\030=2\312?u\270.P:\267\346?\277\370\007\252\321K\351?8\036l`\201\017\344?\000\235\307|\215N\273?V\206-#6\002\320?\t\006\2153\314\225\344?\240\363)\231\376,\311?\002\243,\250\035\306\326?\335\310\005\305gx\353?q\261\034d\257\221\342?o\344\321\257j\300\343?W\342\370e\315\241\357?\n\235\013V:Q\327?]\273\205D\256\321\347?\002\005A\314\234\005\324?\273\315\271\300]\223\343?h\3233\315\270@\325?\376|@\016?Y\335?P\351\253\036\335;\346?\220\323\022\354\365X\305?=)\303d\220%\342?\267m\036a@\030\340?\372\013Lv\3123\331?b\"\253\335\323%\355?~\330\266h\220b\353?\226\',\"\254\273\322?\306\037\323\267\000\021\357?4\205P\254\002\026\357?@m\313\275\031\347\204?A\313\275\212\374\217\344?n\250\262\034g.\345?\002\361v\202\013%\337?[\2525\274\204\244\357?\250\216P\255\204\256\274?\311\325\302\232\337\246\357?\376\027\311.\247|\345?lgUlh\233\331?\300-q\241\210\305\237?\2164e\002\335$\331?\215\271\360C\231\203\357?\360\026\222\237\003\235\326?n\024\225~\374\340\356?\247X\000W\013T\357?;\361\205M<\223\345?E,\217\307{\010\347?h\272\311\2414\226\342?\'6\022\010\371\364\357?%\220)\003\313\224\354?`\272\223\034\201\206\302?k\231h\324\301\274\340?T\277E\250M\346\300?\230\332\245\250\246\314\345?\034\213\307\005\275\220\343?\034\037[\021\271\003\341?\030\010\356\377.s\312?\232B\325\217\027\005\354?\214\\K\317\343\216\305?\347a>\247f\033\342?S\346\236\234\027T\343?v\000\204\262\270\\\345?\t\201\331W@\217\351?;\005\250\251T\001\340?\346\367#f\320=\333?\310\007\215\026\025[\274?\362}(\337\307\235\357?\212\031\353\007\220\271\333?\316\014@n\261\344\356?\200\343v\016\266\243\227?\023\203`|\217\314\345?A\225\234\'\247\374\341?\200\364\272T\324a\340?\350\264\303\370\252\226\324?\235\023B\266\\\261\340?C\252\334\254\236S\347?\365\203<\375\222\314\343?\034\002\334\376\277\026\350?\337\013\246\217\335\263\357?\214\236\2225\363 
\341?\025\216\300\361>\254\356?\036\003\237u~\354\323?8\031\210\222\262\374\345?M\252\027\341q9\354?co\211.\360\347\351?\230\266\r\307s\322\356?\224\377\325n\363P\353?\300\002\222\2063\245\306?\207HE\004\010\341\355?|\331\335\237\370u\337?\250\242U\254\026e\265?0\302\363\030\221\225\306?@\177\026F\312\036\206?\365\274r\327}1\344?\003\262\010\255\221%\357?\024\372\037\306\014\276\343?\010\002\371h\300p\347?\323\233,]t\000\352?\265)\306`\226\267\344?\010\372\037fvf\322?p,\"\210\222$\242?\204\023\364\206\223\261\350?\311]y+V\213\347?A\034\246\277\335$\342?\203\001Is\214\205\357?\242\366_\303\220y\350?\010i\'e\320\353\315?\340I\317m\346\264\332?o\\_\255\000\n\352?\265\355I\t\3627\345?\273\3505X(\372\357?\306\"!\212\227\241\335?X\021\316\266\016\216\325?\210\0070\226\241S\306?]6\2116\233\300\340?\300\352$\333\373Y\273?\017\007\026g\356\004\355?P\317Qv\3319\250?\375\352\030[\357\310\355?\007$\005\333\312\316\343?\237~\315\321vN\340?\331F\232O\025\240\345?\010\360\252_\254\014\353?0\033\275\262!(\323?X\207\031\303\300g\356?\275\242\261\rJ\260\346?\343\025\222=\304\312\343?\234\260\231\215\362w\317?\200\207\'\026\222\300\203?@\277\377/\220\200\252?\306R\207\237k\036\327?\205\332VN\005\337\353? )\017?M\243\233?_Z&\265\r\371\340?\364-]j\316\307\337?L?\204\010[\304\330?RF\254\346;\366\324?a\000Pb.k\350?.\374\341*Q\r\355?\357\236\306\305\344\333\347?Na\231\363\037\372\345?\236R\345\214\314\257\322?\241mU3\005\037\346?\275.\360\n\"\317\353?\254J\371\037P\311\336?207mV\325\323?-\377\246#\342(\350?6\345*\317\027+\346?R\337\243\334\033>\343?\304\n\017\2520\317\340?[\227\206\034\365\225\347?l\344\306\361\'\n\314?\344\330\323\253}7\310?\214D\027-`i\347?\002\302\256\304\375\376\352?\010\367 E=\241\313?i\326t\330\242\001\346?\206z\306\205\2658\332?\261\tJ\272\203\376\341?\200\362z\260!o\216? 
\177\255\277L\302\346?\236\354\373\'\311\277\327?Sm*\313\225H\343?.\311D}f\032\350?X\034N\377\264o\276?\320\277S\360\335w\262?\362\313z\317S\254\344?\001\010\247\324\367\243\351?8\021\266\353\301\206\327?\327\234P\265\\\353\355?L\364\022(\334\304\331?\224\014\302{\250\346\305?\001\214\236\351\355\032\353?X\251s^Z\305\303?d)Ml$\247\321?j\033\243\':\373\350?\363tOk^\222\357?\370\274y\210\357\307\357?(\210P\3600`\265?B\324>&\327\022\350?>7#\212\201\325\326?fd\000V\366\273\356?i:\331%A\311\342?\004\2752\273\014\367\333?\230\255\203T\262Y\261?k\034N\345]\003\346?\002H,\240U\200\344?V\343\r\261\\\215\341?\026?\222\300\362\361\340?\341\277\217VgK\345?\365\005\372*=\317\355?h\350#\350k8\275?\247[\025\247f\020\352?\254\354\303\371\373\021\300?\233,\307\230\271\243\351?\004R\247\230\265\376\323?\324\2376o\341\365\331?\036\261\244y\277\235\321?`\\k\337_]\266?@vO\231It\352?\215\036A\376\364o\352?L\207m\240\017T\310?\214+\034\371\271\006\317?0\337\023\326s\236\354?P\351\353\320a\327\257?\002\206\262K\334\005\351?D\260\302\230\250b\352?*y\013\333U\211\331?\276\323\001F\336\330\344?\240\373%\364\222\201\255?Z\r+\251\343(\357?\n<\274io\313\320?x17\014\027\377\327?\270F\350;\350\233\307?\"\306~\245\307\345\351?\212\33127PN\342?H\343O\243\2265\331?/\304\000\330M\020\340?+j\033\274Y\017\354?\204\201\347\006~h\342?`\361\3442\251\t\331?\340\312\362+\367\241\254?>\005DL\036\311\350?\024}X\205\\\326\355?\300\020\322iD#\236?\370u\001?\357\005\265?\217\262\360\2004\312\344?H\'\234\266\300\336\351?p\006F\302-\340\322?\362J\361\205\017\317\323?|h\325Kb\377\300?\325\245v\035\032{\344?\336\007f\006_s\322?\310Z\277\005o\010\305?\376H8\241\370\267\354?/\344u\322\354\013\342?\320\030\202\255M\257\305?\335/3\312\002\315\356?\271\2743\254up\341?\264\246\350\362NU\320?\317!N\344\371|\344?,Y\263\275d\322\335?\337\354\311\201\256\311\350?\224\335\026y\022y\321?\016,.\342\020\376\330?P\274\226*\250\325\267?}\330\363\013\353\312\355?\356\006(\t\2060\344?\231I\222\243\264\320\352?\245\274\337\200L\004\352
?\240\215\300\032\372X\320?(\374-\262\267a\276?\330x\357AL\231\325?\014\316F7\242\312\325?\200\\\024*\n\226\326?\263`\254\014\206\203\344?5\177$\355$\004\354?\274+\3559\035\357\340?\356\307\216\205\260\335\342?S\362\363\263\2356\345?\241r:\264\277r\345?Dcp\206\251k\313?\253\375m\000\212\251\342?\374h\201(\214\t\357?]\333\234;\332\204\350?\3760\3126\240\326\330?\307\326\314\321\256\016\357?\361\204<\026]\216\344?i]\'\327\177b\351?\241\260\3603\214O\347?\354\362\001\270R\374\315?Js2B\260\006\344?\376\3220tTT\340?\312\t\341\202\340\271\353?\266\372\205\231\274G\353?\214a\230\371\313\267\313?6\237\243o\244D\346?\256\275l\267G&\351?{\302.\220t\230\340?\334B\320\202^\226\345?\366\276z\203\271\301\350?7\344\261\332n\342\344?\002G\357O\253\316\320?t\343\010\262\030\031\352?\030+\323\314\212\301\320?\014\364\342\212\212\320\305?\244\352~!o\233\347?t\202GyI\357\347?\216\036\331?\275;#\300%\235\353?\030\243\243?]\001\343?\246~\"\270\352\233\350?[\214Y\231\235w\342?\250$\342\304\224k\263?8\261\333\240\310\306\307?\302\301\177\271\260\202\355?T\002\004\220W|\327?\340z\204\226&\313\301?\254\213A\360\036\251\322?0\325f\226\244\350\351?\270S\262\023\007\265\314?\364\306L?R\232\332?\346\264\003\263\242\323\334?\014R\267\256C\000\345?\033\230~\016U\000\357?\271N\226\301\250\214\355?h\356s:\241\254\260?=\2062\270o\233\356? %vb\203\336\251?\335\2718\3145_\352?X\322\246=\236P\332?\374\211\361|)i\357?\307 \263\264c)\346?\331\001U1\345z\343?\236\264~nl\302\351?\267\316\260\354\215Z\342?Hqe\037i\352\357?p\226\276VT\240\260?(\3506\366H\003\266? 
B\355\330e\245\225?\036\245\200\024ff\353?A;N\2709\265\345?\000\217\315\221\356\273\243?\224\264\201\365\241F\312?\251\006\357\341\215\367\353?&\231\353w\225\246\324?W\025\020\230L\335\346?\242p\013pX\217\334?\200\014=?{Y\210?@\240zy\263_\257?|\025\210\326n\357\343?\335/\242s\023@\352?\272\215\231\177*\001\321?\335\245c-\262h\344?\007|\003\272\210\223\342?\371\r\255\001\334\024\354?\274\212J\277>E\323?\374\362\027\357:0\312?\222\343\206\233\236W\353?Z\251\013@\030\366\325?\352Er\323\345_\325?\352\215|\303\001\364\345?\3557\220\262\237\377\355?l\256\226z\235\363\354?\273&6\036\341\177\357?\360i\334\303\246\266\333?\310\004\2545\307\311\356?\'\301R\343\223\332\345?\244\373\304\034\2005\331?\000=\331|\234\331j?\344#\306Z\2533\356?\276\307\233en\023\340?\\\231Z\201Z\334\316?\321O\2278\311\336\342?\235\201\224x\261\257\346?\000\332\0242\332(\255?\215\234\277\000\243\267\357?\0141q\310\203\244\323?\375\177\337h\276\273\354?\260/\317p\n\265\267?L\200\241\271\306r\322?\2753LG\367@\357?\220\364!d\345k\337? 
r\363\033\t5\311?\236g\006\323\254\"\352?x\225\013\354\231\025\357?\200\225&\nPg\300?`/\267j\343\212\353?\230A\232u\330\036\351?(\177\010\243\2246\337?h|=\302\252\256\330?\022%H\356zU\352?8JcE\232b\342?\300\2468\302)\013\322?\337\337C\t\226\317\344?B\213\271\354\3364\357?\240\311g\230C\365\222?\3463\206\310\277.\326?\270z\231\016\345\216\265?\261]\216\323)\330\345?x\263\210\225\207*\342?\375iO\257\237\313\351?\031\340\233C;V\355?\366\227\355\244\266\212\346?\2205\221\361\274\374\312?\372\202WC\351V\354?d\334C\266\333\340\311?2#\215\007\323K\357?\022J\253\211\251\t\320?s\244\261g\330\225\351?\204\033\214\020 `\321?F\t\367\324y\307\321?\326\337\205c>\037\340?{\373\200\342\230\020\357?e,C?+\314\341?\000\240kQQ%\214?\260\240\033\333\006\241\332?D\354\264&:?\324?\232\374\312\215u\264\320?&\262\265\027\336L\357?\320\333\237\\_\363\316?\304\231\345\r\261\337\342?\212\307l\367\017\354\355?\000Y\001.\357\230\275?\026\255\373BK\262\334?\360h\260\363\352*\311?\370\366~D\2038\302?F\246q|n\273\353?\364\240\2609\252\223\317?\201\232\345\030\322\253\351?\004\245\241/n\272\323?p\035\002\245\221\033\310?\274\342\377\241\017:\306?2\024\034D( \356?P\252\315|:\224\261?\227\323g\274cM\354?\000\263nTr\303\330?\362L\027X0)\340?\270K6\357\253s\352?)E\207\036X\207\346?\234\321\212QV\346\321?\321`\037\352U8\352?H(\334\204\324\322\264?}:\312rWI\344?\230\005\366s\2448\343?\3410\320sK-\346?uY\024$\216\242\342?\242\203\316\206\317\017\345?x\241D1\360 
\333?\220\373u[\"\240\300?\210^\241\356\246;\313?\002\2003P\203J\326?bzL&?]\347?\334\251[)\330/\340?\007:\345\261\t6\340?-\353\321\302\303\005\342?\264\301l\343\303\003\332?D\334@e6\244\313?\373>O\240\006\362\346?f\231\3543\267\020\351?\030\330\314\336\304N\322?>\271\365\322\213\361\353?\034\262&>U\212\334?G\320\252B;\220\356?\254\t\233H\351L\307?\356\271S\330\236\354\345?\213\021x\267\206\313\340?\200K\375\272\205\"\202?\320\036\222?$\031\343?\030?\345\332\201\243\354?\024\000r[\236\023\306?\345\361\350\314t\267\357?\310\212\306\303B\266\335?\230\274\235v%\202\276?>\211\366\241\362d\355?\332\nQ\213Q\351\333?\230|\216+o^\343?\306*\270\216\214\035\345?\200Zo\250v\017\217?\300\325\273<\346!\213?`\372\036\321Ho\277?\t\340\266:>\342\351?@\351z\247x~\351?k\376\214p\002\303\343?\032\256\306\312\271v\333?HH\340\2270 \314?\222\205\302\367\262\027\351?\237Kg\242o=\355?0\246\205\2245\356\302?\357u\310x(\234\343?\3710\232\202\363\304\344?\240\317:lO>\241?B\313\030\014mQ\330?\316\222\321\242U\016\322?.\017\333?\245T\357?9S\007\215%\330\350?\034\357\205\215F\305\317?\242\251[*\217\020\352?hi\211\0134\322\332?:`L\326\201S\356?\270\324\266\254\345\336\335?H,\234\3552D\317?\350\343p\003\365\343\331?L\366!\022\233\362\334?yMJ\344\255\333\346?\355\033\215\265\263\027\355?L\361\002126\315?\362\243E\363\311\030\356?1\2420\357\214\353\351?\376l\202:\227\340\344?0\344\373\334\214\200\326?\216\206\265;:\232\351?\207T\217p\n\320\347?\224\030\366R\343\216\300? 
\273|=\341\337\266?t\204\320\3502\375\303?\200Z\233O{\037\323?\212c\335O\010\260\353?{\374\324\341\216\031\355?8D\027+\026\016\352?K\321\352\257\007\235\352?\203\276\313bLU\344?Xdi^p\301\323?\276b\205\345#\025\320?\324\203\026x\024\365\345?.O\264\003\032t\353?\\\034\007\273M\013\345?\335\310\010oU\274\344?\244\247R~\266\232\301?yw\016\236\np\356?\205\275\2169\321\030\341?\210#\317k\347J\267?\2521!z`\362\347?(:\346?3X\330?\333\362\001\261D\017\347?\344\004\342\375\207\016\353?r\331\264G\336\363\353?\236\333H\177\244Z\330?\264\256\275n\2057\345?\310\232\"R\346w\357?\314\225\005\014s\316\347?`\270l\233P\036\346?\252\251\006\355\357\253\354?\200\n:\225\270+\237?jL\033\300\252\337\321?\n\314l7$\232\344?\017\320\256\220\370\234\344?\247e\237\246Vk\351?j\227\330\217Z\363\355?\3358\275\237\036\261\345?,\240\220J\266\002\340?\220\303\026`+\253\273?\274\3579\033\251\233\341?\220\2044Y]\372\310?\302%\252NI\275\350?\220\327R!\225\307\323?$BR\234GZ\322?\365(\324\253\027\214\345?|\031u\273\361e\324?f\371`&\255\034\341?\367~\226\204d\\\356?\2000,\367!\200\221?\254\001\003\256}3\341?K\3330<\2317\343?(\020H\367 \377\276?,\221\236\363\251\017\315?\020\034\343\352g\307\250?\250\235\004\021\2715\275?\220\0361\016>\340\253?$\323\3431\321\005\334?\000r\'X\253UY?r\354P\025eE\333?\0304\223\252\206\351\345?\2358\252\226\203\201\352?.\212\232\323b^\352?\200\247\250\t\334S\261?\317\377\034\224\"\021\352?\370\313=\232\266a\356?\260D\002\306s!\271?7\204\213t\306\274\351?\272\2212\003\265\353\345?\230\216\236\222h\006\324?H\356\033 VM\323?\300\252\361\023\000o\244?\204\034}Q\232\256\344?\351\243\262\031\272c\350?n\034\234\206\315\353\330?\340\256\322\032\3351\354?\300\227\314\330\016e\263?\014v\300\0034-\351?ST\367o&\370\351?`F\234\214\356t\333? 
n\340\207t\357\241?\016\315\266\231\327L\336?\265\354u\226\346\335\344?\277)\300m\260\334\346?\307\272,\03592\352?\334\315\346\361+\345\331?\3405\266\202\315\005\301?4\272@\334[a\303?\200v\343x\027\236\250?e\206\"\354ic\353?\244nM\306\037o\305?\366#(\317\216f\344?\344\001\327\276\350\273\315?\\\250\217\320>B\306?j\0311\207\254\270\334?\374\262\330l\032\317\320?\253\000\036\332S\264\343?\316\272\177{\313l\341?]h~\365\370\335\351?P\327\307\335\036\335\337?\023|\356\352\302\256\347?iB\n\330-|\346?\335\203\304=9\217\353? \r\203$\212\261\224?\331\355L_L\250\341?4\031\365\327C\316\330?\000\233\204&\362X\256?@\302/\323Ue\205?]\335]\341\237\264\354?\370\205\270\230<\275\335?H\335\264\036`\261\312?\342\335\n\tr\211\353?\356@Z\353.\363\341?\300\272\034\256\2145\331?pv\306#\324\032\320?8\364i\020\037:\354?\230\201\260k\036\010\267?.e\323B\036\326\326?c\301\353\036UI\353?x\211\334]\253\007\261?\240W\206\222wj\323?g\351\377{R\321\351?\204\341REtj\316?\377\226\377\003}%\346?\360\212\006;\035\374\263?\365n/\316;Q\355?/\216\036\240\3473\341?\267\003\357\262\330\037\353?\224j\351\244NN\350?{\32116\003\373\350?0D\031\317-\302\322?N\315Z\342<\331\325?\016x\031\217\346\013\322?\240\212@gSC\354?\200\323\275Z\266\265\251?Y\177\241dSC\347?|\340,\242\215\313\352?\370\366\266\236\033_\333?+\033um\363d\341?\036\033\253q\277t\324?\330\3669\305W\207\326?\206\340\217\336z~\321?\343wd\367\330\030\345?\017\215\0231\347\014\356?\035N@\261uC\346?z\256\266\021\016{\323?\224q\250(\211\206\322?oO\200\347Z\222\356?H\230\256\352\224\320\300?r\234>\222\346\321\350?\014\313\266\222:[\356?\006YN\034\235\230\327?\030*`F@\327\306?3\246\371\307}\361\351?\312yn\017\267\332\324?\221\215 \370F\270\350?/U\030\006\r\322\342?P2\002\215\246\212\346?P\213n\211\034\351\313?\240d\247\311\\\277\351?\230\3362\002U\233\271?<\336\307\333\023\014\351?\006Y\227\346\016\325\355? 
B\035\t\210\275\223?\030\035V\253z\302\356?\005\004\3258\rS\354?\236\000;\217:\016\331?\240?\037\200\272+\355?N,\355\336\331/\354?\316!\016\366?\n\350?\360\007\224\277\331_\301?$\020\314\246\026\373\336?\232\323\331$\n\246\334?\215\213\353\334\307I\342?F\260u\376B\003\345?\300\250\312V\220\246\341?\000\243\202\375\024\335\314?\264q\275\331f\250\317?#\245:\212}R\357?\226\200\003\036\213\302\341?\231\203\220\232A\245\347?\346\203Hx\273\353\324?\010\r\200v\273\221\275?@?\\cJ\277\335?WM\377h\364\315\346?HL\016\272\020\206\275?\t\316\035\321\363\n\352?IO\036\375\252\274\353?\354\367\031m\013\027\313?\017\357\224BE\271\356?\316\264\037X\212\004\347?\311I\r)Z\261\343?\266\0145\275\235\325\321?\245*\313\344\273H\343?\246\t\325F\301\007\350?(\316\254p?L\317?\276\3658\261\340\377\336?$\265\311F2\243\352?\025\332\231\005\256\323\345?+\036\226\244F\235\341?\347\360\204H_]\354?\265\033\314\235\033\303\343?7~M\343,c\351?*\005G\311ef\356?\342{j\341\335\003\355?`\025\246\022\270\357\255?\323\303\202l\262H\344?\375Wd\310~n\357?h\345S\267\374\315\275?.\032`{\251z\346?r\354=\t\223\036\334?|\323m\243\355E\324?X)\213\037\272\263\260?\363Y\037\0310\356\345?|\311\031y|\315\330?\352nE]\010\373\357?\232\177d\3629\356\323?\335\265\n@/\376\356?C\356\224\375\203\366\351?\352\r\300\002wx\356?\342\377\250t\261\363\320?\352saF\325\201\351?\323z\264\225S\233\344?T\370\210\222A\354\340?\331GRQ\355P\357?L!\020/\250w\354?\322\275\364\222\320]\323?jJ\3751U\337\341?\325\260U\253\316O\346?p\210\332\036](\270? 
\024St=\201\264?|\251\272,zY\357?\260/\rK\213\330\355?\3642k\253\272\335\314?\340\2718\363\304\034\351?H\265\200V\301\305\310?@c\000!s\260\350?$g@\030\315}\343?:[Avj#\340?\260\362\241\357\346\307\352?\363\310\374\217\342d\340?*m\367\204M\272\354?\213+g\225SS\355?\207y\257\267i\277\354?p4F\'\225)\272?\030\3733\375\274\306\327?\210\275%\037\036\234\351?D\202\232\340\374\262\324?$\r1\367&\236\352?\341\277\002\373\\\374\350?>\347w\362|h\355?~\\\224\276\364\023\357?\274.\336=\026<\332?\342\020\325\263\230,\321?1\350\372\337\371\303\357?\311\257W\326\367\365\346?V-\006\004\r\345\354?\266\266\317%\306b\345?\000N7\347Q\322\210?\022 \274\351\355g\337?`\212\372\302\367\007\270?\234\305\354\2229H\345?\270\336\013=\256\324\355?J\262\225,\320:\336?\317\275\262\203\227\016\346?\310\233\"\274\312\321\326?\300\254\330\021`\332\351?2\360\251\307\236\207\323?@\260\256w\251\352\215?\325\276\375\242\013\266\345?T\326_\371U\232\315?\007*&)z\275\357?\364\021\375\036Q\247\313?\013kfW\312\356\346?0\213\244J\266\275\330?\247\276dA\022j\342?Hl 
}\360\371\322?h\275U,\213\201\317?\300\241\364\003\271K\343?$\001\361\200\354\353\303?0\204w\250\321\037\351?\021\t\300\333\227`\347?\224w\t\243mu\347?h\305M\005\237\366\263?\263L*\225|H\357?\341.\277\231\271/\347?4,\363p\352\006\301?\034\300\373\227x\230\340?p!\305\322f\205\261?\213\235\276\"\233\026\343?l\343\354\373\321p\341?\214\257>s\243\004\321?\340}\315\253\321\007\256?0\305\2272W{\265?D\221\232\215\365\023\307?\224\247\366\342A\243\333?\241\316h\313\355u\354?\373x\201\037\224\004\350?.\001f\345c\022\331?\245;\247\306u\r\344?\262\275V,$@\357?P\265[\276T\236\347?\"\007\261k(<\355?\251\224/\366\362}\350?\256n\261\277-\270\357?\302\223X\246\345l\335?\020!\2049<\233\351?\022\300\307\313\342]\357?\362\303^\027\364\025\351?Apf\207\200\014\345?\242\356b\023\2277\341?\037Y\263\250\254k\346?\264H_=k\025\335?`$\301\2104\035\351?k\320\342\302\215a\343?\010h\032\3532\035\261?m\226u\364x\333\340?T\013\007\254`\364\337?\302\304\224\016\376\254\345?P\225\231\010Y/\331?\346=\334\020;\234\341?\023x\370A\335\301\354?\025\361\306\322i\206\343?\272Ke\032Y\007\354?\360d\024xfs\341?\210\247WP\236\246\276?\302\307\2653\325\304\355?\310\346Y\305\277\024\354?\225v\222\017\323\037\343?\266\347XQg\304\330?JD\232\310>\247\355?\376\245\224\335\363V\350?\224\322\333\006\326\025\313?Em\026\303\373^\350?\236\327H^\324.\323?\200c#\211\0328\242?\264\374\255\006\345\032\314?\361[\337\235\033B\346?\264\346\257b\376\216\334?\252O\327\301\303\236\350?\366,\263\010W\265\340?]s\r\030,7\357?\233y\265\345\033\263\341?\346\361\316\344\334I\347?\027\025\253\306\305\204\340?x\263\rX\375\177\350?\260!\222r\234\n\240?\354\177yAX$\327?Z\023\356\263x\303\323?_\225#\272\020\370\354?]2\252b?\200\343?\217\036\331o9\236\340?\025~\304\375\330/\353?S\016\306\332\206d\355?\336\033\005\2065\346\347?\354\200p\331|\035\341?L0\337B\016\226\307?\261\362/\356L\273\352?\240\350\374C\334\244\232?`\345\027\'\361$\314?\252\344m\005\332\371\337?\365z\213{\226\021\341?\260\241\357\310\277\257\325?,?S\225:\033\332?\370\311*\00
3\030X\271?\242B\270\264\314(\340?\272J \233\305\364\342?\000*\243\217\206cb?\206.\3165\n\344\331?\255\377\305c0\004\346?\211\232\346)\003\213\357?QiZ]\310\025\356?\204I\221\252\357\327\322?\242\202A\215\367\237\354?\322\247\234*E\333\333?pb\240 \220\322\341?\330\362l\364BV\332?\204%\214\036\016\367\337?\000B]\317)N\232?\340\263M#\345u\261?\224\257h7Ii\335?|\r\"\261O\367\320?\001\336\307\017\3230\350?d(\200\371\306\322\315?T\023N`\221\364\314?\202\307\037l\212\002\345?\224\030\326Z\343U\350?`\243q\226\254x\347?,\275\237I52\356?\000\235\034p\253i\240?g\267\374g,1\340?e\357?\\u\214\251\010\024\350?\242\343|at\335\354?\211\303v\37332\345?\213\240\274Z0\320\350?\206?t,`U\337?\031>\207\017\2661\350?\264\027En\235\345\304?\214\371Gv}\274\301?\326*\021\242\036\352\326?\333\"\026#\2421\354?\200\312\333r\306xr?T\250xi\226K\351?,Th\252N\220\331?TF\3654D\343\322?\215[\370\207q\210\356?^\263}k\255;\346?\n\261O M0\346?\300\373be\332\343\340?[\340\334\272\264\306\354?\350\335\231\016:\243\275?\261\211\214;ka\353?L\013\033\243sp\337?:\360Z\273\354\210\341?\200%\276\264\306\022w?\200M\032\351[\'\334? 
\333\017Rg\266\336?\342\333\231-\241&\334?d\017\2412l\241\342?H\226\212.\225\345\351?\200l\027\336B\356\266?v\365:\372\240\332\336?\2200ND\227\374\335?\345\227\260\352k\333\346?\262I0\340;\327\337?h3`A\'\017\341?\364\345\203\357E\254\341?\000\351\330\266j\272\276?J\2043\323\264\356\327?)\252\262g&\222\342?\275M\372q\356E\355?5\022\312@\260\274\346?\274\216\0031\037\013\300?\370\022\3035\313\212\315?\326\230\266n\2236\320?\346\241E\221!D\343?\240\325\327\"\343D\326?@6\360\306*$\340?bi\272\024\263\233\355?\"\026\344P?\340\357?|\363\215\374\004E\323?\214\253\304\235\370\270\343?\227\243\211\t7\245\352?Ds\n\212\033\230\343?vf}J\201g\344?D\321Q\002>\024\355?w\370\242\033\313\260\352?\353M\r\240\273R\344?\260\273,\030\255r\353?\n,\215\0037[\325?D\314\350X\r\231\346?\272g\335b\234J\354?(%\301wU\365\346?\274\277\374/\364\035\313?\276P_S8\t\330?\035\205\347\347\326\004\346?\240K\301\250\306\306\266?\230\030\371\262gQ\345?\000\2468\221\3367\352?\014N\235\030\311;\307?@\355A\257\306\340\334?\210\240\237\352\310\351\305?x\316\272IM3\326?\024\014}\250zz\333?\354\r\023\n\234P\350?\232\353ZZoF\340?\3261\223\n\255\210\351? 
\371Z\215\304)\302?$MQw\237m\312?\t\033\350mv\230\352?\010q\t\275\322\323\331?~\306!\277d=\322?\202\221`\221/?\352?\210o\255WJ\277\327?U\224&\235\322\212\352?\010\324\034/\342@\276?\034\026\021\177\304:\303?\026\007\204\363\226\001\345?s\330\033\360J\004\352?E\322\375\235\277\320\357?\200\271\326o>\375\267?\2445q/\332\247\312?u;\233\232\275\233\357?\242Yk\254\336.\326?\246A\245!\024@\356?b)U\311\237:\354?\224\204R\223\3557\326?\200e\3146`6s?S\025\207|\343\330\344?\214I\333Vg\361\326?\004\353\364\220\301\335\343?\342\252\010\371k/\332?\277\347\317\347U\240\355?\243\362p\271\010\006\355?L\235N\257h*\346?\203\245\243\037HH\355?\007\225\247\212X\327\357?\352\236\350\207]\257\347?\210z\246I\373\363\336?p\310\371\036\321\313\341?mk\327[\226\232\340?\374\353\005\3207\366\320?\360\317sV\247\372\273?6\341$/\230\032\354?\271Oa\362C\272\347?\350\227(\373\363\367\355?0\366<\370C,\345?\312=\017^\306\304\350?\354\245Z\334\303\254\356?\256;f4%\334\333?\220,\222~\274\026\270?\030\020\377\216\223\376\354?\356\310\215\254m7\347?\234\2603\254u\206\346?\022\257<9w\314\321?44@\031\033\352\301?\020\206{\322\307\310\267?\373\010\327&Y\025\342?$\237\2748M\246\337?\000\210\033\376\335\333{?\276B1yT\250\332?\034V\360\016\240\026\345?\344\270\265{FG\327?\304\366q\312\261\314\347?\0267\350\014\216\263\353?W\2556oq\372\344?\342\375\230q\244\037\340?\315?\020\312s\213\343?\346\315\014`%J\331?XC2\345\003j\350?\212\242\023\305\2428\354?gX\302\324\345\367\342?\337\002\327\023\256\246\341?\374\003!\n\353\221\330?\212_Q\315\033\353\357?\0204\241\352\007\275\346?[\205(\314\264\272\344?\240\024sQ\3463\336?\206\376C\300\302|\344?|\313\005\020\313\320\330?\377\365\025\243K\t\343?g9\350J\352\310\344?\330G\214\213h\252\331?\254\351!\351--\344?\322\376G\005Z\314\345?G\307G1\374I\346?PU\210\'\232\222\260?\000\267\024h\345\n\245?\250\310\325\231\306\255\341?\020\242`\017Ks\314?\032\311;\352y\235\332?Zx\020ll5\322?m\264\357\"\201\r\347?J\272{\304A\260\331?P\035\037Rm\225\245?\037L\267d\035\367\355?\014m=\216\3
56\300\352?\361\312\3426\362\346\355?\326\2654\205\001\320\357?\017\344\356\332\322\234\347?\370\333\233\313\244\337\351?\334m\343\260;\341\355?\032\364\266\363\332Z\355?\264[Vh\310\306\305?\320\2662\267*\230\330?\205x\275\231f\245\347?\254v\'\205\007\316\340? GNQ\017\210\226?\001a\266\340\266q\344?p\267\240|< \255?U\247\367`\025;\344?\256\311\325q\226f\355?Sk\001\004\300}\350?\310+:\253\361\237\301?\004\356Q\365O\230\312?f\003\210\032tG\325?W\273\2614\034\024\356?\274H\254\0216\360\300?z\352\252%F^\352?h\266\020\201O\005\301?\376\235\314\001\3373\327?m\246\300s\233{\357?\344\373\344\317\215Q\300?n\002|\273\275\370\331?\220\321n\220\\\265\277?K[Q:\341\235\341?&Q4\005\361\234\347?DaWJ\256\305\305?\275M\331\333\216\317\344?\334\325\365\377\014\000\331?\330\376d;\246+\356?\033(\\f\262\022\345?H@\245A\220\324\343??mBa\013\314\350?\216\355(\227\240\027\354?@2Xu7\013\343?>\333\207\320\200\270\355?\340\362\200\0334J\353?\347\241\226g\004\364\342?j\253\246\234@\035\354?V/s\314\374\252\345?\206\264`e\212\035\323?c\365\331o\352\264\345?\212\026)w\216-\353?\010\200)!\341\361\345?zf\220\313\220\021\320?P0n\361\371\020\265?;Tb&J\300\352?\202?L\005/x\344?uK1\310B\020\345?@3o\332C>\240?\037\252\254\232\n\227\356?\370Oym\306\271\335?\034\314\351\315y9\310?\000\212;\026O\026|?,E\221\212&\205\314?<\217\233XfO\337?\222y\362\307\0052\352?\354\020\201G\2028\323?\330\230\337\334-j\315?\271\033\277E\365\223\344?RA?\203\007\223\325?\240\261\245K\217\261\232?\231\253\r\010|\031\354?R\252\027\272\370s\326?d\356WwT\236\305?\334M\363\000\010f\347?F-\375\304\355\032\351?J\376\225\320\204\364\336?\305z\342!\240\007\351?\250\274\313Tu\366\315?\000J\037\262\nC\233?E\177\251\210\'\302\346?\314\304\255\273\305\240\310?{E\234@\202\325\345?\360\310l\000r-\326?\360M\354\227r2\243?\020 
\026\037\246(\310?\025\314\227U_/\355?61\267\032\307\274\337?^\025\257\003\316m\354?D\212Y\006\274\341\306?\250\333\335\271Y\227\326?\300\013\267\320\332\276\306?B\214+SQ\003\325?\344\336\227\336\253\233\327?\230\340\331\037x\032\331?2x\320\024\356\374\350?4\235r\276\021\027\323?v\244\377;\223J\356?\231\023F&\236z\346?\314pWW\025\021\316?\"\243\310\351Q\311\341?\320S\201R_\025\343?,\306|t\025\324\340?M<%3\236\301\340?T\2320\366\"\243\325?\200/iG\266q\331?Y\373?\31405\341?\242\204\036\377\020\327\321?\332\017h\357\205\333\341?y\036\301hYC\344?\236w\3431?z\333?\224\207E\026\244\324\345?\200\014\203\325c\n\310?\276\235\245\262\024b\347?\027p4\353\177w\347?K\202\020*\014\301\356?X\344\037\024\"\344\327?0yxj}\347\256?\222c\3116\352\263\323?\026\217\375z\n:\324?\320\262\027\356t\327\312?<\205f\214\210\322\300?\344\242\261y\350\010\307?\230\031\033\007\234\350\311?\222\365\200e\257\313\321?)\267\025zf)\341?\252\250\203h\376\034\342?\376`w\354\\\373\320?\"\302J)\021\210\332?\207X\304V5Q\353?\200\351~J\224\223\305?\240\316So\317\304\271?^/H\343\217\273\326?\226U\215*E\010\346?\332\272\350`\026C\341?\3743\224\216;u\305?|\360\272\214{\272\343?\256\303\035[\021N\325?w\340\370*\2762\341?(\354;\265\312\270\315?\354\032\346\026\006\341\311?\003P*}\026 \354?\016:\217\255\315\211\355?t\033\006v\206\330\301?v\'\354$`\006\335?\246\241=\207\022\020\324?\244\267\332o?e\340?\000\332B\363~\215\315?l\271^\275\237S\315?\205+\2066\230\374\344?\356\016&\030\373\306\352?\220\373\335h\251\321\274?\226[\363\037i\241\357?]$n\023\272\r\346?1\027N\273\226\253\344?\234}U\350t\024\306?\342 \024%\212\237\343?*\252\210\265\001\231\332?|j\360\234\245t\320?`\277G(\036(\275?\2070e\030\363|\341?\322\021\327\006#\373\331?c 
\373\306}O\354?\300\261NG\r\350\307?\344\350\221\2474\030\347?3\203M\226<\261\347?D\240\025Z\256\375\302?0\342C\025e\340\240?\002\256\275\355\003\003\350?C:.\262\220\232\342?)\034\0345NH\352?\324\203\201\027\005\256\321?\364\000;\021\214&\314?~\272\254\004\ri\324?\007?g\017\351\362\343?\313\310\250\231\342\307\341?\177DOG\212\204\355?\014\3548\240l\236\311?\320\251\255\345\274\334\305?|a\265Q\361\362\346?\274\345xC\256\233\303?P\3610\350\244\374\266?\200\330\234qmz\315?R\250\341[\240\340\352?`x\227HD\021\300?\267\303>\375\322\300\357?\036^\365~#r\324?\"\201\263<\016.\324?\300\001\245?\1778\316?\226W\3227\321\275\333?\000\020\334\\\235\300\262? \177,-\274\331\273?\202`b\3063k\330?\000\225\340\263=\024\263?\272(\324\031\204\037\356?\235\313j\352\270[\353?\370\306\307\204n\231\317?\344\353J-T\364\316?(\223\322\247s\335\340?\354s\177O8\234\346?\000%+\3316-\226?,(X\315\201\263\353?tz\345\253oY\357?@\235^7\213\t\346?\277\022\337v2\246\344?\010<\352n\256\342\312?t\337o&#@\353?\241\245f\225R\310\354?x\222\346\375\3265\340?\360\354I]\324\375\305?|S}\322\032\276\331?\261\275ayM\275\346?\343\221\345\260\256\364\344?\236\376)\004_\346\341?\304\017\204|v\252\346?*\265^_\235&\351?\000\243\205\230\367\027u?\360\277!\2347\023\256?\200\363\273NE2\333?\343\200\366$u\231\347?.\211\314\360\001s\337?\275\277\177EH+\351?8R\025\245\302\352\327?\350\353\361;\207/\327?\014\313]x\255\224\321?\347\220\365\300\0308\345?R\025\010\006\3078\331?\266\210\035\225\303;\343?@UJ\201\363\340\232?\220$\235}?\326\244?\\\2241\242\354\330\315?\033+\n7N\305\344?\227\322\362E7\020\350?\310\330>\234\2623\320?\246\037\240A\234\274\354?\222\2328V\261\017\325?\300;\270\345=\177\340?J4\375\367\346\017\354?\341\331\326\336]\234\341?\300\213\275;Sy\224?8L:>\304\322\314?Ta\274\342\313\305\330?\301\310\350\2078\'\352?\373\357oT\372/\350?\3209\220\346\374\263\301?\236\350~\225\330A\342?\370*\225\335\360\241\351?\000\223\324\230\265~\341?\224h\364*\302-\337?\316TiW\2709\336?\354\213Y\"O\361\355?\230\273s\250\367\030\33
0?\"a\264`\226\030\326?h\307\017\352\372.\323?V\301\240\034\021\020\343?\255U\316k\363\323\354?\315\257\270m\267\031\346?\016\2077pa\006\331?\3551\237\032t\276\357?\303K\260\265(\247\352?\"\274X\272\311\r\357?\034\034t^A\315\303?\310\247\257\257s\257\335?$A\267\303&\316\331?Z\246u]\216\202\351?\244\030\004B8\346\353?\027\325\024I\027\322\347?\266O\315\\\214D\345?{\317\252\017\264\206\355?\361\017\272\031\010\333\343?\262\021\234\373\274J\326?\226\344\2320\361T\335?\226\206F\222F\210\336?\217Ph\345|\"\353?\320\264\272\270\333`\301?\000\"gW\"js?$\021\323\004z\275\324?\314\354\350 `\335\335?Y.\367\311s)\353?C\235L\272u\304\356?\017\342`\341\334p\342?X\r\305%\337L\321?j\370\257\001s.\351?\\\266\2623\0203\325?\216\'s\311\253]\352?6\033}\'d\022\344?f+\205&\370\214\327?\212\241\242\r!\031\323?\300\257\370G\001d\230?(\341\000\023\202\205\316?\350\337\037\265\035\364\273?\307u\366\344\025\005\354?j\\\255M\210\254\327?\002\362\034\301\321\031\321?Xbi\375@\335\353?x\224\035\245\207\370\322?2\202R\347\366\321\337?\024\325\277\310\350)\311?\360w\3429\207\263\345?\r\300.a\034!\353?s\247\321\035\314\014\355?W\315\304\001I\365\347?\364\351\264\241\002\362\311? 
)\326\357\334t\312?\200\377\342W\245\236\257?\316~\271\314\266\343\341?\272\\p\372\000\354\337?v\302\271\311c\341\357?\254\014\250\244\321\030\336?\204\347\003\206\363\000\311?\000w<\302\302\246\343?\030G<\035 t\327?\264\004\371Q\307\010\321?Gxu~w\013\357?d\253\000\035\016\030\335?\2645k4\263\305\314?\237\231\343\204\342\212\355?\205xz\243\342j\340?L\275\2651\215\330\325?\340\242p\035\n$\271?D\362\370\262e\031\344?=\332|\276V\252\355?\300\303b\340H\036\233?\004\333\351\312\302\241\340?H\004\3564\\A\321?\002\307\361\211\235X\331?G?q\374fe\352?\332Q\344B}\014\323?&:\364\367M\010\347?\030\n\\\'\214\"\353?\261\262S\020[\311\345?\204Q\\:A\326\330?\261w\351h;#\356?Qq\335\300\324K\352?\253\207S\353g\265\341?h\323k\250\367 \317?)<$\0348y\342?\214\377d\336[_\323?\204\006X6\362\263\320?\360\341\324\026\236\365\322?\301D\024\370\304`\344?\344\341\310%\241l\301?\006\257\242\361Gm\342?\3307\235\211\330\312\333?`\367\007\377DQ\335?\320\034\204\020\346\334\350?\244\204\2309\223g\333?\210&\177\004\311<\345?\022\007\310\313\267\300\346?\240\025\034Os\203\234?\320\032m\250\243f\341?\314\334\035\321\213)\343?\277u\355,\3232\344?@F\343\245[.\252?\344\335\316\267\3725\351?\220i\274Pe\214\257?@\340&\207wa\306?\005T4m\2228\345?\316S\202\025\323}\353?\360\217\026\212+]\337?(z\365\274!Y\321?Y\276\302\255\270\277\342?n\322\2503\003\356\347?`E\355\212A|\277?x&\207\005\245\265\341?\224\277\224\270\363}\344?:I\216\202\310\242\336?\332\020Ah\256\260\343?\317\r\'\360\257\020\355?\300\302\322]Q\'\236?\227\244\026T\277\371\344?\237o\273\202~\022\347?\2401i\331R\035\223?\037\260d\371jT\345?\212I\360g\232\235\322?~pX\212?\353\337?\330s=\020Jl\276?\256$\021\274\031\221\333?x\270\001\016\224\251\300?r\303\326w\2777\354?W\200\211\361\023\262\346?\264\316\374]\365z\355?\372\\\022\251\234=\335?\274}\014e\330\373\313?\000\230\232\225#\017\257?\000\033\233bt\345\200?\200\215\356\2054<\344?\210H(\352\030\351\334?\216\276\337\272\245\266\342?u\212!*\361`\356?\030\375\243~JO\344?i\271\375j2~\341?\r\347\205@=\3
36\345?l<\345?\346\"\335?\300:\371\375w\r\350?\265\212uPJg\342?\014cr7\266R\312?\346\310\374\345c\211\337?\307_\301\344\266\017\350?\354\317\331\3475U\350?@b\370\003\224u\337?\000\303\225\022\215\351\332?\224\237]\227UR\322?\360\272\276E\361\016\341?\374\361\025?\271\271\305?W\3027\260v\013\353?\362\223\206!\265\322\337?\270J\005\246\007\037\323?\271\300\017\204\242p\351?\306\237\327\200;&\334?\360\244\352\324\317S\250?>\355\260\205+f\355?`{i\231\277\033\306?x\037\246|Yk\327?\206\211V7vZ\327?PJ\254\335\022\211\345?\327\260\335\202\216\360\350?\017\352\021\340\235\202\341?\014\302\233<|\377\332?M\300\327=D/\355?\220\305\366\323,a\244? \205H\262$\232\324?,~\376\247\317I\331?8\347sT&\204\357?6\003\207~\357H\335?\346+5\304\200$\322?\337\371\336\375T\321\340?2\307\231j\310\220\333?|q\021\203\337\025\315?me\005\334\356\232\347?\320E\320\330\332\325\317?Tz\232c\256x\313?\320\020\323\332,\207\321?#\231\320\226\335\316\356?\244\\\340\304I\201\300?\220\001\373f\213\273\327?\264\236(\361\313?\332?\004@|\230\220&\345?\242yi\243H\220\347?h\256\373du&\306?\370f\316\2649=\322?lC\240e\317\036\356?\300[a\235)Q\252?\374V7,\014\262\316?S\232\341\210\270\026\346?\271W\025\374!\026\341?\022l\001iin\342?.\306\263\346\020\250\354?T\216\010\230\025\254\306?B\262\013\021)\270\334?\000\017\333\347\273\023\345?\344\276\264f\225\037\320?\252+b\266T@\323?W\217\230\247\345\222\356?\300\202\3422U0\244?#\311\236(\225\232\343?\2309\331\317\224\320\314?\320\375\256\236\036x\327?D\320\006\247\364\037\331?`\027\215\020\216b\264?d\230\375\235\241\257\346?\353\344\256\336\320\364\347?\000\234O\0053\225T?\332\007\372no\377\327?\320\266}\237\241\316\341?0\340;\\\370S\352?^b\222\3062\262\343?\007L\334\230\341\316\356?\333\370\222\225\240\301\344?x\313\010+\343,\326?G 
\303F\0105\353?\362\223\005`\231\361\337?{z\237!U\351\350?\014n}N\2078\353?*[&\010\205t\334?0\324\271\256b2\321?s\333\266|\025\233\341?\224\360\343\256T\353\325?\352J\221\344\243h\333?\346E\347~0\177\350?\276\310o\354\254k\354?\266\213g\245H\304\343?\302\245\3176\036\r\330?\274\375\272\355\312H\347?\2340\374Z\'\205\314?t\025\017\325\220l\332?\324\373\255\377\215^\326?\320\342e\270\314)\265?W\244\372N\203\r\352?\252d\277-\345.\331?FIE}kP\323?H\207h\306\220\003\320?D\333zS7\210\345?r\001\332\336@\373\323?Qa\244]\035\034\356?\030\214\331\246\220=\341?l\255\300\020:\"\313?A\251\377+Hv\342?\024\365\333\005\243\243\332?\327\351\214\314\350\333\350?\264\004x3\227\376\322?\353g-\341H\373\356?8\013\367\003a\344\332?\270\361\003z57\275?\316\'!=\272\201\320?B\227r\360\221\247\356?\000w\221Je\226\332?`~v\025\275v\302?\372\3221|\211p\347?ze\017)\364P\353?\2748\003\004\375\261\323?\210_9\307XW\312?7\330\362\321\032R\356?h\026\tP\n\247\312?\251\265;\365\353\302\346?\240\017\340\366\231\332\273?\377\376y:\023\354\354?\340\"\345;]3\343?\276\223H\303[\274\334?\026\256R\020y:\320?\271\347(h!L\340?\334{P\376\240I\356?\262(\004\212\343\034\356?\220]o@P\253\243?l\005\307\215\334\377\333?\251B\324]\233\036\352?\034\336?H\275\335\322?\250G!\2769\023\356?\353\375\256\252\360\366\345?\020\204}K0\320\346?\014>\300\237\247\277\303?\304\330\242&]y\354?^s\336\233\022\023\344?\202\3719\204\305\313\347?\026g\036\352PX\334?\037\362\244\347\2326\340?\034\346Iq\233\260\356?\310LJt\270\014\322?\313s\364\361\300\233\356?\034^\034b-m\312?\"\370)\321}\023\321?\030\021\025\266;\302\323?\262.\0030\035\362\352?t(q\312\236r\310?\300\'\244+\350\303\270?:]\204^\374%\350?\274\223\340e*-\321?\373\240b\245\252\373\350?\310\302\017\353\035\203\307?\010\337\332\345\302\211\260?\317\0375\247\354\273\346?\300O\272Y\177\013\207?F,\364\023l|\353?\333K*5(T\347?\006\033\300\223Ky\320?\366\307.\350\302\252\345?T$\335\272wP\305?\364\031\007\325\2560\306?\024\207}\272\216\266\350?hI\214\r|\322\300?\\\"\367\334@\273\346?X\25
3\246\337\345,\347?\310\325d\357Pz\331?X\227\255t\353V\314?\300s\213\212j*\323?<\361\230\310u\007\324?\241sU+\231\377\350?\020B\254\214\270\327\320?\320\247h\303}\016\316?\010QeM\233\214\343?`\205\345eJM\232?\317\275I$h\022\357?\010\033\265\342\230n\317?\317z\345+i?\347?v\317H\376\217\376\341?\300\257\327\322\317\320\355?P\333\3675\251\313\273?\304\254+q\336\010\310?\312V\336\276\3173\337?=\337\3657\313\246\341?\306\372\332dR\215\321?j4\200\264\216\t\343?^\020Zj\2071\340?\270\330\010H\177\001\321?T\276\313\312\3127\313?\373\220\330\236\376\316\356?v\362\231\356q\214\357?=\357\002\357\3322\351?\202p\tvm\315\343?`\246\331\237_4\316?\024\215\nO\254\215\327?\203\351\311\247\226\207\341?,&\364p\211\257\340?@\025\273\313\302\005\357?s\273\313a\331\202\341?^\316\343\260\226y\327?Ui\324\273\363B\352?\000\024\237\2573\302e?\020\025\254\256p\355\255?4K\210\357(I\343?h\3603\017?\325\334?\374a\207 \232\002\346?J\032\332\255A\356\340?\210\2721x\260 \326?\251\314i4<;\357?6Sa]3\\\326?\261W5iu\031\346?\0041\'|@U\336?G\314\301mk!\351?\206z\255mz}\323?\257\r\025c\"n\355?\262U\n\3020&\324?\214\354\346\237\0309\340?\254c6\037]\206\315?\334\316\243\372\310\225\357?\356\214\333\315\274\003\321?0\231L\245\024@\240?\300\236\314o\312\010\300?\201n\"oZD\344?@\346,<$\215\234?]\217[\337\342\264\341?$XZ\246\300l\341?\245n\025a\305\234\355?p\370\330)q\352\301?Rs~v\004\370\346?\006\323\023b\016\201\322?Q\022\305\375\020\220\355?\030+&\300\002\354\264?\321.\020\213\204\315\354?\346\017\202\215\325j\321?\374\336\350d\026s\323?\262\246\005W\211_\347?\'! \344\270r\342?\324ei\3700k\353?hl\036c. 
\337?\320\317\310\215@\216\323?H\210;\227\253\221\331?\266\232\226\370H\334\357?j\336\240$,\340\325?`\213\361\256\341a\327?\210\010\220\373q2\260?\237\336X\301\303\333\345?\300<\365\375Y\352\261?2/\226o\337\324\350?=\3464\223\333\334\351?Mx\206\0244\025\347?/\313\310\3441\357\343?9\231,K1\345\345?\250E\244\332-[\350?\302Te&\021M\346?\230\273\324\372\270\361\354?\334&\354\001J_\351?\204O\376+\n\245\325?B\202e `\307\330?\252\314t!\224\265\346?\344?\350\220\177\312\036O\327?\335o\326?\374G\305\263v\336\355?\357J\203\314]O\352?\020\031\361\320\373\346\346?\371\305\322\237\301\232\350?\335d;7$K\354?\"\362\024R\013[\331?\010\0018,\274\025\331?\266\335\360\324\256\025\341?6\327\252\277\205/\324?\344\365m\262-\004\333?\246\024\251\225,\337\334?\232[t\363=\013\324?\330\217\3211\016a\345?\217\300Ut\316\336\340?K\327AC\351\007\341?\310\315\014jd\376\327?\322;\202D\037\245\337?\000P\373\371\251\362\317?\344mI\207\354\217\326?/\252\030\037*\274\346?\210\001|]\323\351\300?\236\207ve\204\331\331?\020\304\rDe&\257?lGF\360~\374\341?\342\001\231\254\222\304\320?\240\006\362>\243\"\356?\002w[oM\274\335?\366\323\2376\"U\341?\177\302\017\353\031S\344?\032\020\032\302\3006\340?Rx\341\014\250G\354??kZt\307\026\342?j]X\351h\350\335?\362S;T!u\355?\307\n\r\360\245\277\350?\373\204$7\366\020\343?\260\005\376b\323\220\346?\020\316\232\\\356\267\326?\361\376T|T\233\346?\2142\340|H!\303?\342\273d\225\211d\346?\364\006\207\353\373\267\313?H\230P\315;\243\324?\362S|\23204\340?\034\001s\021\236\325\310?\204\2069w8$\345?x\316\307\333\355\033\343?/\376;\211\306\372\346?\034\214M:\264u\332?\360\214\245\265\356\246\241?\324v3\324\313x\341?\276\315\374\312\351P\355?\254\315\242b\232\036\333?\200\002{\357\301\237\305?P\336ed7\274\256?\3520\374\311\216\345\356?P\027_\372\352\254\240?|dD\210\365u\307?@\237\206\\\2438\257?DiD\335j\227\342?\347\354\027\310\020e\354?7\006\237\321\256\211\340?->\325\374\273\252\347?,\333|\226\366\324\304?\272\001\370\306MV\322?\360P\322\330\010s\331?r)\374\346\223\177\354?A\21
2C\212\267\002\344?j\237\221\220l\243\350?\2764\027/\243\271\342?\024\227>\214\307{\317?\310\001w\215@\253\312?t\316XA\332\267\316?\306\213\334P-\314\326?0\006h\263!H\250?\223P\251\252\271\373\347?P\231\021P\264Z\337?sj\035\200\024\304\350?\340\227u\005\237\346\274?\343\0024\03798\352?\211Q_\356\364\205\354?P\205]\237\274\313\340?9B\344\262W\365\346?\272\003\360\267U_\337?0h\024\"\022r\346?\330\317tz\217\r\307?\256\350\352\221.\210\332?mCU\231\324\305\354?\334Tx%s\324\320? \352\202v\230\'\322?\330zbj>\010\314?\244-t\366/?\332?$\034\246\021\243\320\320?h\203\014b\0043\315?\356.\346\364[\232\322?\317\345\334\373\255\375\341?\000}Dl\242(\254?\336\326\273\353\245F\352?\271Z\254\004\242\356\351?\030\023_\377\205G\261?\000\300\302$ZZq?\270\302\331[\207M\305?\205 \245\030\224\031\354?@\275\377\020\270\002\334?\271\313\330{\030N\345?P5\255Ln\237\272?$\316a\203\311\005\344?,\032u\347\002\346\340?`\022H\371Z=\275?h\212\261\354\305E\273?\370\014J]\\\332\341?\360\365\331\006\253\366\250?f\341d\3015(\336?\306\250\t#43\353?\036\327\341\223y*\330?\340\247\375\275\222g\350?\350\360\326a\314\252\336?\350\371[\365\364I\343?B\\\277\264\001\362\356?\270\255/\266\3218\303?\016\363\322 u\337\356?0\320\007\261\247\371\256?\326A\204Y;\347\342?\227\215\272\222\233\r\341? 
\335!\355\305\254\303?|L\233\'N\212\354?\300\261\335\240W#\243?J\230\247Q\360B\342?\244j\325\301\202\332\304?n\023L6\030\267\332?\264O\366f\270\275\300?\310\306<\264\311<\355?D}\273%\373\305\344?\253\320\250\312\377\352\355?\025\235\2053d\226\340?\244\013\370\233\347w\332?\006F&a\235\263\345?\002\204`\331A/\327?\'_\320\337>\371\350?\230\342x\335Ze\335?\354c\207l&\377\326?\200\345\272e\271\271\246?\332\013\262\323\353\256\330?hb\016\346\333\211\337?\314-\317\222\211\220\303?\010\027\300J\006\343\334?\300\370@\256\231\205\265?1\267\320\257N\327\350?\270\372\245\3217\234\300?n\3573\312Sp\351?\240\222nM\266\304\334?\257\007\256\316Ru\345?W\035\343 \345\325\343?1\225G|\364\250\342?!\023\355\262\334/\346?\306s\352\273\266 \341?\353M\224\354jq\344?\210t\244\nP\004\302?\317\232\233m\332\006\350?$\022\001Y\250\000\330?\367\212\307\026\001l\342?\310c\217 <^\352?(\314\251o\013W\276?$\224\321/\275\250\311?H\314x\014U\266\353?2z^\363\032u\333?\305\327!y+\334\355?\337\300\256\030\3359\353?6\266\n\010\233\006\325?4\332\335\005{\345\356?\210i\033bZ\244\305?|\341\374*1\373\303?~\337i21D\331?h\033\356\277\304\340\264? 
\322\241YI5\240?\030<\361\373\321L\353?\001D\025|Q\331\350?\317\001`T\353\316\341?\224]/\351\221\\\341?H\216\017NI\341\325?\014\232.\242\037_\354?zj\237\006+\345\332?\360b\r\245s\352\265?Hn\300A\222b\357?I\177c3+\271\344?\235\000\250\020\244\320\354?\346\026\005\241\232\026\335?\352\222/Gr\005\346?\241\t\272\242Sg\342?\306w\237\341-A\336?9\262\302\\=[\355?\354\272\375\222z\243\310?\272\313\370S$\000\322?R\310p\224\337\255\327?\300\272\207\326\365A\275?\372-<\361\362\355\325?\216\013\005D\326\352\330?\260N\014\353v\355\312?$\257B\"\027\371\353?\256\014\010<\210z\354?8\271\371\013\016\370\271?x\001\303\241\315\353\327?\250\016\277\320\212O\300?#V\314X}\305\353?0\237\204\270\207J\335?8\0224sYY\305?\270\2753\241t\022\276?\014q\3746o)\336?L\211\021\356\221Y\330?R\024\262=\341\240\326?(C\321O\224\234\277?\215\345\241\310\3079\340?lY\241\354z\212\302?$}\373\272\010#\321?)\237\324\302\356\273\341?\270\024\010\365I\213\333?\254\211\201v\226\330\336?\266 \231\335A\317\324?S\'z\004\315\266\344?8\255\331q\342.\323?\030}\245\256l\325\341?@W 5\021\000\334?h\242\330=\257\364\345?p\205\321\t\276\331\255?\010\353\204t\364\254\272?\345\023\256\035\014\210\346?\000\372\315\303\356\205\300?\t\317Y\360\031E\347?f\t\206ii\254\350?\320I\202\341\341\026\267?\010\270\3575a\245\356?\246\364\232\343\245}\345?Vz\374\302\001\031\336?\360\367ah\374s\266?(\025\t9\246\347\341?\210\343\344\330\224\254\267?\274yX\367\301T\325?\340\212A}\304\207\265?\027\3378\354G\355\347?h\305\013\350T\302\351?P@Y:\374z\325?\266\215\256qQ\224\323?\262&\314\341C\335\342?\226\311\332\267\324d\357?\261z\250HU\222\343?\261*X\304\253\240\340?D\347(=\340\355\300?\230\247\312d\357\313\266?G\256)t1I\346?\'\353\261\263\214\327\342?\311\001\033\370(\274\356?\334\256\260\257MR\333?\024\247`\322zS\346?.z\343\260\342\244\344?\205\177\256\327e\264\340?\032/\260\237\304p\326?\240^N\233\202\265\334?\244\331\021\233\201\262\345? 
v\037=\024/\355?\330K\260\035\235\232\325?\233\0332\317\365\376\346?\307\017\240d\237\300\346?\257\207\260\245\210f\356?\036{!\254\244\t\347?S\036d&k\327\353?\000\017-@\245\242\317?n4\354\207Z \327?\000v\030P6\237\351?\336UJ\017h5\347?\230B\215\272\350\240\337?\362\270\333 X{\341?\014\357m\334\024g\337?p\264\306\211\327\323\271?\000\340*\276Z\267$?\240\360\300\006l\030\326?h\216A\021\233\315\321?(\341\361\026\325\220\315?\214(\216\241\317\000\351?\030%\255\311\217\213\301?x,\035q\356S\333?\220\037\303Z\024\324\242?\374\346\243\n\274\r\324?\023k\226\313\025\003\347?\377\"b\255<\243\345?#P*}\233\017\347?\362P\"7-\204\334?\355\300FR\032\305\354?\267\365\216<\335G\346?\004\007\224E\250i\301?\260\325]_\232\272\242?\332\"\266\230&\265\344?1\344\004&\324\250\355?\314a\312Q\227\325\312?\336M\033\356)A\330?\005\233\r\"\300\332\345?$\2510\306\344\306\300? \204\004_f\314\353?\242\326w\265\232\376\351?\200(\324\356\244\257\353?K;\365:\245\352\352?\370@\207\025\2753\323?\204\265D\336\016\024\337?2\322h\207\010\365\346?\002`VOCf\325?\277\201UM\001\234\343?\330\267\006vw\007\327?\212(\307\n\235F\346?\"6\006N\277\002\336?\310\277\025~\035\317\277?K\362c\227\264*\357?\276\352\260 
\325G\342?D\332M\004\335\211\313?\360\334H\217jE\340?&Z\3362\032\256\341?\244\231m\014L\266\343?`\027\231\215\207\250\237?\223\030%C\255\224\354?\356W\010+P\275\353?\361`\243n\246\377\356?l\003W\321\3429\330?\364\251\233G\2324\316?\245\\BFc\221\341?\220u\224\226\026\235\301?\245T\"Y\243\265\350?\360\200\257\010K\206\265?\230j\356\227\257\316\325?\237\030>\311Jm\356?\270K\377\253\376\225\331?j&\3064O\323\346?\313\306?\332\210V\353?\273\324=C\203\214\350?\354V\327+\262\001\355?\0004{\014\276\314C?\002\236\266\270\374\r\354?\332\323\036\020z2\333?e\202\2406\361e\345?\"8jh\356\360\333?\260M\377\220U\315\353?\'M\212\252F\310\340?\304`\256K\257Z\351?\372\364\013zN0\353?\340\363\r:\266\355\236?\000\206\372\267\276\227a?\347\'\'2\332/\354?P\245{\312:\224\306?\320bl\304A@\247?\352\252W\205\"\305\327?\233yv\346zL\341?S\020|\260oF\341?\202\305\013\316%A\354?\3403\000^vE\244?\026\204\303\177)\303\343?PMj\243~\240\276?\200\000K\313\363\207\254?42O`\303\240\343?\200\234jfz\212s?\322\3202&\2202\357?\344.\254;\252\245\325?\260n<\270X\312\316?\330\216\345\027\210]\267?\305\270\324I\336B\341?P\302\320\315\300\022\340?R\024k (\206\342?\254\243\331\302+l\324?\214:7\345&\257\330?v\210\253\342\225\031\322?p\356Sg\326\347\272?\260\014\210\2240\232\274?\000\266\337\304\347V\255?\244@ho\360\262\331?F\310\353\262p/\357?\370\003\363\313jv\304?\272\r\031Sz\020\335?\246\334t6(4\325?\272\216\240R\243\022\340?\200\347\344\0253\343\302?\234\264V\363\347\240\355?\000j\031\217l\256\204?\000\032{\306R\242\257?\310\020\277\302\230]\272?\200!\212\3530y\274?\360\212T\332\014n\304?\206\033\356\014\311Y\357?\345\202(\026G\243\351?x\306\200\016\"*\301?\356\275}H\355\356\353?\036%5\\\367$\324?\031\224gO\262\363\354?\036\215Q\243in\347?>\330\303\032K\356\330?\274\362\352\335x\010\356?(\356\217\001\017\016\310?j 
\002\246\366\364\327?\256Il\267-G\347?\030\271\325\344\n\273\337?\222]\254\"N\347\327?\214k\270\220\352\250\357?l2\017A\0053\336?\232\311YP4*\352?x0\317\324^\255\313?\336\t%j-{\350?\010\230B\260\203\n\314?\234\331x0-\316\356?8{&.t\024\336?O\262\266\023\342\354\342?^YPJ\310\214\326?\204\023.\352\006\014\317?\004S\232\256\005]\305?\323\240\327\240\200x\350?\206\230\201\343\260Z\357?hD\322.\234\310\322?\352\305A\325B\207\351?\276f\003\377Wb\344?\000\241=\360f\315\340?\370\355\267BlK\277?\350\226\004\007\367a\270?\272\304\371\010$\257\322?\014\035\312n\314\262\315?\257\223;\256[\276\350?\206`\316\332\247\346\337?xK\234\2006\215\320?\020\2359\211\003\205\276?\214v\202\336\331W\321?\234\366H\246\024\367\310?\317\3079v\304[\356?\250V\220Y\342v\355?\344\037\007\215\017\035\355?\254\272\257\323p{\321?n\233\'R\371M\327?pV\332\306\272\374\257?\365q.L\374\245\344?\207l\313\343\036\371\352?K\303\215\350p\350\341?\354\303\210]\246D\333?\346\206\270\254\036B\355?\211\224|!6@\343?\374\263u\217\240\317\313?\235\\O\262\226\311\356?`\212\224&\203\220\334?4\005\224\n\251\222\350?n\001O\313p\021\321?\210ca\310\022c\314?dx\347\026\266^\327?\200\203\022Q\221\277\237?\204\214^\305\333v\304?\024|W\035W\330\345?\340\000c\274\275s\317?\310\372\215\217q4\353?\215\337\312\222y\303\343?~\010\260[Tj\342?\305\357\rZej\351?\340\3043\222\232\374\260?\322?\376\333e\233\343?\374\240\026](]\322?_\374c;\352N\352?j\2749=\365\224\356?0\244,\277\035\341\336?\250\322\353>F\266\344?\340\000\"\341\250\237\327?!\323\242\373\234\242\354?\224\'\217\340\010}\301? 
\222\232\231\021g\350?\215\031\275\364\310F\355?\334]\304^\360\000\300?\330\004\252\016l\013\303?\234\005\352\302\267\210\347?\024\032\035\254\240\366\353?\304\230{.x0\353?\304\350\320\0279\376\306?0\213\2178\263&\333?\344\266\010\372\362\364\352?X\221k\311\375/\344?#\2445\304\007\004\340?\341I3\027\277\313\350?\200\3478\367wqq?Ib\222\346Nw\346?\370VKL\352<\302?\000=\212\247\200c~?\360 !\233|\332\302?\213\262$^\310W\345?h\214\352\207\316z\261?k\200\001o\277\334\355?a!\334L\267E\353?\273\346\2665\200\254\354?9CR\245\271t\342?R6\246\253\r\370\343?P7\350\241\2710\270?=y\304\327\2667\356?l\025kf#\304\330?Y\367\312\016\306\346\356?t\233\341\325\r\331\336?\214=\006\343\016*\313?\213{\352\303\025\347\352?\342a\031H\346\257\336?\206\001\014\341\374\330\355?\340\017\3708\267n\344?[ \237\314_q\345?&\277\362\250\372\207\332?\263V:\274\374`\351?\226LYo?\207\343?\227\026\250\272\271\n\357?\210\203\347\327 \237\335?(\303\216\251n\317\317?E\277\350\263\014U\351?\336\020\021\366\315\226\323?g\303\317w\007\261\355?\030\264\310\020[\367\272?D\2358\346\270/\305?\014\227\233QxK\310?\366\257\235\257\355\201\352?\362\2429\035Z;\326?\364;j\034\240}\310?\365\n\n\237\251\254\350?\231\211r\302H\363\345?|D\022Yd\247\313?\366;\003\225N\021\353?\342\274g\277\242q\322?\302\2327\013\237\273\324?\237\227\334\275\374\225\347?\235j:F_8\342?~\177<\030\333\265\344?F3\027\355|\233\355?\360\372\261\252\014\002\325?\363\274W$\033A\342?\204\225\004Ed*\321?\340\025\315S\200\365\252?xB]H\317>\330?\0009|\202\305\031}?\370\257\362\347S\201\335?R\320p\331\224<\345?8\244:h\221\206\304?l\032y\203\336S\314?\320\0236!m\224\273?e\346\364\304~\036\340?\010.u\311k\321\274?\"P(Y8U\351?n\345Z\234\004x\346?\272\025~\302\027W\337?\350\343\363\341\277{\266?\025$\341\220\2035\352?\036\252H\303\343\177\327?K\002\'\301&\303\352?\342\354\241\177`\013\322?\266~\251:}f\355?2d\265j)f\325?r\251lH\025\n\323?\020%\005\254Z\232\324?\231%\014TLW\352?\326\327{\363\323\377\341?\022\263NQ\343 
\336?+\010\324\306\270\037\355?\\\365\033\245n\266\320?\364\343\253\201\022\277\342?\330Z\272\310\370\256\262?,1M\277zS\324?\200\340\341\300\345\002\264?i\214\310\213\367\374\354?d\342\010\243D6\343?\320\037\014\214.*\243?q\337w}\033=\346?0E\'\026&\277\337?\000\204\315\227\236\004\314?n\313\034\344\347/\321?T\207}Ds\036\323?v\336P1\236\324\335?^\343\232N#\354\355?\354!\252\035\326c\333?2\256x^\307\002\345?( <\336\355\322\300?x\246\203\302\227c\307?\022\206\007\242\243\247\355?gd\211\264\336\010\343?\0344\322O\317\246\332?\237\031\345\347\226\326\354?\233z\177\247u\277\357?\3627&1\010\360\345?\227\223\241\"\355Y\340?\000\371w\022\200V\245?\213\220\340}\035$\356?p\036\211\250\214\227\257?\330\326s\354\316K\303?\315u~\270\267\365\346?\350\322\215\223\213\223\300?>\014\207\263\326\257\343?OB\n)\022\272\342?\354!\316P\305\261\306?\325\2669\006\304]\355?\351+&\372\372W\355?h\347k;\"Y\344?\202\350iZ\177u\344?\000G\013\000\377\222\273?\360\004\004\002\213s\350?\372K\236\271\330l\346?\240\243\336c\177\223\335?m\003V6\227;\353?\022\225\250m\307\034\326?\207b\236(T\222\354?\322\226\376\244\216\343\324?\\R\301GT\'\320?\232&\2727\221\002\321?\234\025mA\355\014\352?\232\355\003\345\211\232\340?\204\255\005\323+l\333?\220\317\001\365\247\237\316?\004<\220Pf\212\310?Z\323\223;\020/\330?\014^\327\251\211\010\337?\210\247\027\233\315\006\354?:V\027\032\263\362\331?0\025\010\244\027\361\246?\002\346\241\360\027\343\345?rGD~\014;\320?vf\270\347\315\035\325?\246\241\361\032.:\344?\246\207\240\303\306\334\356?HJ\nlY\271\312?\306,u\253\314\031\332? 
\243\352\366\t\236\321?\314n\0179cs\306?\264\020\227Z.\334\305?\240u\277#\217\350\273?\036d\336\"r\264\351?\333\000\033\025\243\314\352?\275\023q\037V\267\341?\034\023~w\342\"\354?\354\217\002\241` \354?\n\373\2377\375\240\336?\276vdX\247H\331?L_D\215\"\237\311?\260c3\254\312\324\352?\364_\246\251\310\343\303?\303\3643B\033\311\354?H\336\362\260\016\333\264?9z\017\213\352`\345?\034\357W\346\373-\333?\030~\376\t\017\265\321?$;\r\324(~\357?\231\205dF)\232\342?k\365N\004\025\332\341?\201\250\264\326\007\'\343?#Q\034\270\224\304\351?n\250II\247\254\335?F\2272\266\270I\325?\200\231~p\371Q\317?*0\230\311Gk\335?\037h\335A\333\"\354?\310\024R\026\3748\335?\2439\024\r\226a\341?8\341O\225\244%\337?\331[\023\203\3032\353?0X\354\351+\211\335?\220\002\\\322\223\333\317?$\350[X\263\022\303?j\034o\227\367s\321?\334\300I\231\207-\356?\000zZr\260\233n?\003\220t\037-y\353?\274\311pt\353\033\314?\0007\345MN\341\250?(U\263U\240W\301?\202\311\032\200\n\252\325?\223\366\314\3169\007\340?`\367\241\216\232\203\223?04\357|\224B\321?+9\035\204\332W\340?\r\275\214\022X\327\353?\024\020\353a\205\311\311?\246\310\325@\325_\354?^\344\220\316.\323\343?!t*\320\315#\345?\372J\375z\232Z\356?\270\353\341\270\020f\265?\274Ie~W\315\320?\005\337P\\\033\032\350?\251\250\322\007c\372\343?\224f\3606\237\255\303?\240($t\372=\273?;\235x\240D\345\340?\240\347\022Y1>\255?\344A\\\220nS\302?;\027c\017W\224\343?t\315*\000\312\304\330?\016\020m\210\343\357\337?`\206\263Lf\031\321?\204\227\253x\202\332\323?\301\245)i\225\327\354?\0138\244-\352s\347?\376\366`\265rF\324?\300\035c!\213\037\216?\002\030\322\216<}\343?h\327\364\363\023N\276?\245\257\213\302\337B\346?\302\223 \374\007\241\326?\000\263\013!\326z\262?\314\311\250t\023\306\332?@2w\326\"\311\303?v\331\004\036\250^\345?9\347\343\225\251e\342?>\206\331\230M\212\335?m\373\032\314\263)\357? 
\243\001\343\350\353\220?\010\301\275\231\316\201\322?\200\276\335\367\366\343\271?\322\025v\037\360@\346?8r\266<\224\367\345?V\'\242\366\225\322\325?St\373i!\004\343?v\"\266\344{|\336?\022\363\250\374\020\242\334?\366\206\372\032\342\321\354?\366Ycg\307\275\357?@k@\305(\215\316?\026N\330\222\274\274\357?\277_\260\307\tC\340?\277\256\266\2563\315\344?\210\254xw:\224\330?\3315DQ\237\347\341?Mxj[\032I\345?|\321\344\224D\340\304?\351\352\021e}\037\350?\314\237\232\177a\334\322?\310\004\241\325\237t\260?\032\027S\035@\000\357?\324\362u\276\242\334\343?\001\r\257\2638S\352?\006Q\355\360\233g\322?\2558\341\311\213\217\356?\223\312\344\205W8\347?\007\376\001\332\2257\356?\014|\3275\301\275\324?\276\325\207\016\206\311\326?r\225]$\233\240\334?\\\210\3334\270\345\321?\312(\323\272\363\341\344?e(]\3670B\351?\000\204#\304\232\240[?q\343)8\230z\351?\000\200G7?\320[?\244c\n\231\323\312\354?\010A_\026n\310\326?\245\225.\267\307^\357?P\207\320\005\314\252\275?\374\007rs\353\304\350?\220/\rdd\374\240?Xf\216^\242\264\310?`\366\252]\034\201\337?. 
\272P\272\315\344?\\$\351\333\313W\304?\000\325\203\332\213G\272?{:\257\303\377\003\345?~V\006\247\215\r\354?\256vLAF\206\344?z\303u%\263\354\334?\256|\001\\\234\255\323?\000\020\210\233^\211|?\244\314\226Py\224\321?0\265]T\341\241\257?w\301\037\246\300R\346?\210|\276\347\203\246\266?%\025\004\223\317\005\351?``\010h\'w\221?\260\301M\206b\030\325?\230\246\254\217i\240\354?\021\202\264\241\313b\353?\224\300\221\203\253\205\350?!UH1m\336\351?\260T\200\303P7\356?R^q\315f\265\320?\206[\026%\204\370\340?\000\2067F\377\036n?\344\0073G-\216\317?\260/\203q\370\315\344?^\027\224\025\250\215\325?,\375\017\317\225\332\326?\264\237\177\216zC\312?4\256W{\334J\307?\033W\270.\n1\355?|=t\326\243P\326?\n\014\250g\267\377\335?\377\315\034\tp\204\352?\344\"\025\210\304f\320?\240,0\303p\001\230?\346$\310\233c\237\324?EK\327\311\3451\346?\036\365\266(Q\200\344?\344\212\345\324\001n\322?\303\245<\203\363u\347?y\366\363\270\232\250\354?\\]\332\0045\014\305?F\213\n\321C\333\331?V\350\n\275\003\221\340?\224\227\233\025\372\231\312?\217`2\326\014B\343?^rOE>\204\327?\322\034Od5\347\354? P\203\362\0100\323?\367\325\205\036&Y\343?\272 \301\257\004k\334?\261\372\377\213\212\326\344?\263\331\205\210G\216\341?t\335\217\321\216\006\350??A\2779\343\250\346?\017D\230\310\332\271\343?\255\0046p\373o\346?\246=:\322J \321?\003\367\360F\'\220\350?\226\242\227}\213\"\324?\2150\374\251]!\350?\354I\306\370+[\312?\215/\362\nl\256\346?\306a\352\263\201\331\343?\036J.\212\247\240\336?!\3675\363( \354?8[5\341z\345\337?\3346\216|\014\340\351?\332vW\331\355\274\352?V\201\215\365\216\264\350?^\376\272\360-\002\351?M\312\226[\036\351\355?\206B\177\224~Q\326?\200\230\3670l\362\215?\260\210\205\306\223R\345?\223\222\371\226\\\006\342?\305\244b\343\266\274\347? 
~\333\273\3625\225?\317\257?\257~y\344?L\325\336\026\023?\304?8P\356a@\332\322?-h\271S\\\375\352?~(\355\220Y\342\331?[uZJ\262\227\346?f\013r%dt\337?\204W\'\\\314z\332?@\375*\354\032h\221?*\375\253\375\305y\336?\225\357`\231\300;\346?r\3338\004\014\272\345?\320\277\r\331P\036\326?\304\204\360\\\014\014\320?X\204\236\266\230\350\352?\266\337q\225\307\351\332?b \355\253f\261\353?x\367\374\361\340\037\311?\222Z\216\323\377\245\335?_J\320\346:\217\341?L\n\201\323lH\342?\371\206\247iH\002\341?\242\264\017\234\210%\324?~M\214om\324\323?H&F1\001\254\335?\350\310q\227z\010\274?\202=Uu\230\201\347?\370\325 |a{\306?\277\262\236\257\212c\345?m\006\230\304~\222\342?:&\031B\242\304\322?\301\330\365V\336S\341?\256(T+\243\211\320?\000D\030p\214\310\306?\332\273mF\321x\336?\240\376\342\200\004o\227?\370\317\210\304\240!\270?\230\302e\211e;\310?\243\232\231\2126\221\340?\000\'\3010wY\241?8y(\370\226\357\337?p 2\360\033\352\303?;\274Z/*\311\357?\323\006\201\2756\317\344?\003wB0\256\342\346?\240\247\2766\256\227\340?Fe<*\315\016\335?`\257\224\2104|\267?\275P\307\310\031\376\344?p\264\335q\311\031\273?~8\241\275=q\322?\300\303\340\213\034\346\330?p\016\237\'7\234\277?\340D\005\200\222\363\310?LB\307\370t\337\311?\200\014\304\230\307\261w?\253^\301\336\022\362\342?\210t\n\302E\007\350?\344Ho#\223\376\317?\274C\375q\031\355\334?@\227\337Ah\025\265?\030\320%\034\027\032\264?\301x\0356\223_\357?o\337Bw\230\222\351?\026!\320\300\371\267\355?\3202\226>\234\375\317?\330\000\261\246\364\300\303?[Q\"\312\376\002\346?(\300\002\340\315\002\337?m\223\343p\244n\355?\316\214\371\252\351\177\345?\334*\356\264\317\262\312?`.c\025\231E\274?\2107\310\021M\035\272?>.\2248`\323\351?\203\264\224V\361\200\357?|v\367h\020\252\354?\376\024\026\302^@\342?p{\2511\"\010\265?\350\210\023\316\215\"\320?\252\312\230\264\360\225\351?X,R{\262\367\354?\010/\214\365^/\305?tOhg\254\207\341?6\300\240\251\014\227\331??q\341k\002:\356?~\370\2445P7\327?b\014\034\347-\256\354?\007/\005\242\333\365\343?\263\271G\364P\227\351?m
v,\267\370=\315\323?V\263\237\206\251\212\331?\242\017\371\365lg\350?\370\024\255$\007\241\261?\016\031\314O\2055\336?0\362\010\277Rq\254?\036\342\262\'\266\343\347?Py/\2448\357\345?d\370]\314\303\037\306?\234\220U)\245\220\336?\240\300\337S_4\220?\210\272.\250\241{\303?\360\035s+\376\027\356?Fq\274<\315\313\343?\267\330\013\353~\255\346?\027\366\020\235\215\030\344?\n\372\252\334x\326\322?\013\244\035\352\r\022\355?x\273\034b\325\257\335?!\371\r\254.\240\343?M\226@\243(L\346?L\177q\245\220\026\327?\313t9i\201\006\352?8\334.\020|D\300?\240\263\tu\327\361\334?\304\212\206\027\275\010\311?\340z\312VVb\220?)\262Y-\017\273\345?\347\021\014;\320\212\354?,\327\311\263\324Z\325?\242\257\204$\240\206\352?\034~\000b\365\220\310?$\370\251\n\347?\343\033#{\"|\344?\244I\006B\374\201\342?\336\257\3149\025\273\353?\227\031\361\215!\026\356?\331B\021\246\n\260\351?\340j|\244\376K\274?Pk7H\244^\335?6\372Q\035\320\332\324?\"\035\021\260n\321\322?\300H$>\213,\205?\261\204\355XE\315\347?\273\023\273j\203\326\344?\243w\342\032\232\252\340?\267\353\232\255_\350\351?h\002\001\223\323K\342?\240M\265\335C\266\350?\016\327?T\367\221\350?\036\306\311E\347\377\346?\342\332\215\262\245y\332?td\235\320\222\307\310? 
p\323$\276o\224?\204\025\212s\307\246\331?\214\247M\377`\277\343?D\327k\234\202n\345?F`\344n\362\027\351?\016\002\254\"\322\316\345?\374\007]\005\326\263\306?`K\277\350\253\363\270?\337\036\261\354m\037\355?\210\237=^({\357?Z#/\315\326\004\343?\252\366\326b\010\375\335?>H^lh\332\323?\210\274\323?L\262\343?\300\271U\002 \004\304?Xh3\200Q\331\345?*\010I\366(\352\326?[\334\252\3418\025\342?\240F:\356A\261\227?hb\240\370\315\232\312?2\306:\026\303\017\343?\"\'\241\316OO\341?<\263j\246\201\r\357?\274XE\025\007\233\324?\270i\225P3\372\300?)a\301.]\r\344?&5.\311n\311\342?\252]\020\336~Q\343?s\314=\221\271\\\353?D\373\234\325h\035\340?\3360\300\347\005l\345?\2724\315\237\364\177\343?E\215}\365x\007\341?d\214#\263\2220\302?\314kj\352\311\234\350?,\310\227\214}\003\355?`\237\236\032\320\345\235?\000\337\t\226W\256p?kN\314b\317\336\342?\302Y\235\033\307\206\330?\352\354\310L\347\227\343?\346\243\031\265Bo\351?,`V~\013s\334?F.+N\226\330\346?0\3213\030\305!\277?\200\022\363\030\031\367\252?\321\321\330x\2533\355?\320B\323\302g\363\247?\350\215!\274\013a\306?|\nk\266C\t\352?\361Y:\217\360\360\351? 
\231#\26290\266?\320/\007\270\366F\323?\260^e\357\240\376\301?>\006\250\363\342\014\347?&\271\266l\200g\350?\374Z\027\274K$\345?=\363\277\235\024\372\342?\352[9Y@\201\344?\242\351\324\237$\311\343?\003\364\230&R\375\356?`\326\rlk\271\330?Q\026\367\307\204O\343?\240@f\030T@\275?(\217Se\022\202\353?\374\223\037\364\257?\332?\230u/0O\302\327?m\325\274\003$\214\355?\244\266p\315Cm\323?\252\370\260\000\337\356\340?\370\004\261\366\362=\341?v\n_Il\353\346?\002\206\032X,\343\325?\361?\306 ,\216\343?X\242\345\236y\304\357?\300hhd\273\345\307?\332U\202n*\033\325?\242\363\326\204o;\342?4\222J\255\324F\320?S\336\254\221\356j\357?n.#\003\340,\356?\005\216,L3\030\357?x}p\n\363\n\340?\030j\3117) \355?\220\224V\316\245\n\241?\350\302\t\277\216\340\356?\266\356\336+\253y\346?q\276?\306V\322\353?\t\347~\352\201{\340?\016?\223^\346O\333?88m\204\336\371\262?\316W\3166\\\026\340?\014\001\177\355G.\302?*D\334&*\037\345?\364\200\226\272e]\307?\374f7\3270_\316?\"\356\226\204\315\225\350?\222\265\202\217\3709\351?\230\250\2079O6\323?\336\221\364\026\302\023\340?\013\2525CA\266\353?\236t/\261\211\303\326?V&\350kCI\354?x\344\246\023\355\311\352?\227\267\t*\020P\347?\2246\355\024\265\316\350?\006\026\t\235\302f\354?2`\231\366\203B\342?\220\240\373~\261\343\354?\364\305FT\021\243\324?_9\241U{\267\343?AQ\177k-\204\352?:\273\270!/\026\345?\010\024\343\254ro\356?\340\223\375\000\275\364\222?\2549iF\007s\305?,N\363\356\034\027\342?T%\017\212\207\232\354?\272\177UtKi\322?\264\266c\223\\\373\305?v\365\324=\261\271\343?\351\017\252\316&\222\344?\254\251\r\237\323\377\343?m\357a }9\356?~\311y\202\373_\345?\246S^Z\240\347\347? 
\003\2031M\264\271?\350\230\037\3513\032\270?\250\004eq\r\230\307?\237\305S$\314M\341?\206+\360z/\216\337?\260\360\300-\206G\344?B\305\201C\362\316\341?\220$d\202v\000\265?.\366=;\263~\337?\007\022\203\005\372\351\353?H\216l\360LF\262?$\271\t?\371V\310?v\"\030\305\226\204\325?\300\225\272H\224\274\226?<\261\322\177\231\023\336?\227z5\237)r\340?0\030\215x\263q\343?\000\25532\362\230e?=J\305\320\236\302\352?v\206\365\331\"F\321?\355\334\n\006\373\027\346?\233\301\354\305\202\017\341?\244\333\213\203\226\032\317?\\\251{x\345\373\310?N\034B&\017\275\340?\337ym\333\261Q\346?\026\032\266C\0008\334?9\203Bx|\315\347?@P}q\356\347\210?\004\350^\307\362z\314?\250\275\330d5@\346?\t>\206\342\324\346\340?D|\024\373\315p\336?h_\rK\340o\340?\302d>F\213\226\353?,\026\007E\345.\320?\352?\271\245\234\261\337?\016\335\222\2436\320\353?\265S&I\353\325\351?\n\256\202\377\253z\350?\252\302\n0 \200\324?\003\204si\235\004\353?\214s\266\250\237\221\317?\344\261\177\373\347B\331?v\202\"\003i\001\323?\224\340\250\372\237\205\315?\r\014|\037\367\203\344?\032\207\365.\010\031\324?\251\034\207_\315e\352?\256\346q\345\1777\326?\305Y\351R\320e\341?\324\353\357\220`\235\331?\216$!\237\270\273\352?\035\327N\023\343P\346?D\333\025}\317\357\341?\361\377\035l{q\354?\264\017K\206\244\007\303?\3462,\305\304\325\333?\325t\254\020\250U\357?[\010\\\312\022\010\354?w\327-\277\243^\354?`\242[\004F\305\334?\340\304\000(\370R\276?N\006\375$\030\277\324?%\026\271(Q[\344?\034x\273?\204g\311?\021\032\235\355\320\222\355?HjT\273\333\326\273?@\243\306\236\334Y\346?$\037O\204\376\252\307?~:\340\027\241\243\352?\000T\"\337\324\341O?\000\035p\361\230\351i?\020\227\246G@\304\337?(\377n\236\361\226\344?\321\224:\261L\347\357?H\314\335\2226a\271?\264\032~\236G\236\310?\334\233\3003(\326\345?\273\251\203\272\022 
\357?\240K5\370\\\355\307?l\307\020q\211\223\315?\345\261%\275N\255\340?\037BB\214,\322\345?\026\3425\346\310\224\352?\374\215f\t<\251\305?_\035\2043\364\273\355?\005\363Y\210=~\351?$\t\327\311$\256\304?\250\255,\220\205q\274?U}+$\334\022\350?=\215#\231\t\201\351?\0346\254\251\345\232\333?\3409ip\300\325\251?X\244\236\250\235H\323?$c|\344\022\004\356?\317\360\245W\201\257\345?\254b\250\2322\267\317?\r\362\014\026:\332\357?P\350)\366\204@\331?\004\026\250\007/\210\331?3\245\273h\337a\342?<`\273&HW\357?\316R\266\264\0034\350?\345h\'\014\200\025\355?\362\252wpf\335\320?\001\363z\205\003\022\353?\240\004\315\025p\003\267?`\320Y\227D/\250?\264R\332\2478\r\327?\320&\343\010w\361\313?\000\037\312\351\261\342\317?\010\341\344\244\357\177\317?\222\2618>\225\010\351?\215\035\222\r+\307\340?Ab\27507\200\350?l\371\246\347\242J\315?H\253\013\347T\366\337?e\267\n\215D(\346?P_\033\022\312\230\355?\350\006\364w\344\203\333?x\214\375\263\266\374\351?\242~\023\205\322\317\335?\3061\240b(\210\350?\372K%\327\275\234\327?<\034\367$\234\317\316?n\232$H\353P\352?S\344\025d]t\340?\2009\256/\376\361\310?{&\365Guz\340?\034\2376E\260o\340?C\3012\314e\332\356?k#%c\330\276\341?k\271`^_3\357?\256\255\260\\\241\242\347?*\226\311\232P\336\332?\\\352\274A\342\360\346?\004\2642\177\362;\357?\200\362\334\377[\307\223?T\301D\275c\004\317?\024m\254\323m\022\347?\304\264\231\222Tp\315?\260W\245\210\373\316\343?*\030\371\177\031U\323?\330\235p\307a\265\344?\245RvA\355*\350?\203\313G\377\225\251\341?\242\204\30494k\351?\240{\3076\330D\337?\250\274\226r\255\035\351?\0002\007lA\372\211?\350z\341\270\016m\270?\001\\L\010Q\376\353?\342\337\334`?y\322?`\250+J\243\222\342?\024\326\265\263y\336\330?9]\027\014\026(\342?0+?\333\347\344\243?2\'\357C\363\177\354?H\374\2657\202y\335? 
\037\373\236\322w\316?\250\025H\377}\036\336?P\240\332\020\035\210\244?\226\270\256yn\300\356?\373\007\353\222\312\223\342?\256g\017li\301\332?\202\254Kh\034\274\335?fox\037p^\345?\310f\220f\216\201\351?\350\315R\365\363x\332?T\265\"6\220\265\300?\241\303\2043\227Y\351?8^2\264\317P\262?\220\317\353\210U\327\333?\201\021\323\300 T\343?K\332\353\301\304\256\355?P%,I\270\244\306?\223\371\226\206[}\344?\274VkP9\213\314?m_\371k~\253\343?T\255\005\3506\022\327?\023i\215B\361\021\341?\223g\207\305\340\216\347?\304\252f\323\266-\346?\034G\236&N\305\321?\2039\374Q\211\372\355?\354]\344\240h&\350?\306\360=\371!m\323?m2 \222i\211\351?\272\375\326\240R\224\322?\360\2404y\027\210\314?\240\002.V\267W\264?\230\347\261\347\233\327\340?\364\031\316P\303\247\313?\250\347A\373U=\304?hK\360\202Y\247\345?\274\\\3175\n8\331?\273\236\210\327W\316\340?\224d\337\006P!\335?lMI\223\375\242\344?a7\300\233\203\233\356?\270\226\372~\327\304\301?\004\251b\253\036\252\324?\336r\307\206\205J\321?@\265\016R~\221\323?\001\332\247=\217M\341?z\274\273\347\253\325\343?J7\276)F\261\323?\020\335\377:\313\236\250?\340\360O\264\235\343\246?q\342\374\335Y2\347?r\227\226\r;\013\323?|\333\302\272\253\026\343?h\272\342\263g\002\353?\206b\305y\003\016\344?\376\367\376\234{&\355?\020\021B\272\240R\303?\360&\236\367\370\322\312?\200]\350\033Z\305\313?\260{_\320\016\352\275?3\312\353l\345\035\341?\\$\374[\237\244\303?d(r\272\227\333\357?\370\257\245\253\224\354\343?\031\263EGV\013\343?\365\370}\335.{\357?_\203d\302t\033\353?\016\020\255\350\213\033\330?\200~[f\025?s?\"\354\300\247Lp\324?\000!{\tQ?\325?p\301O\204\227\366\253?24\256\026h\213\350?l$?s\313\325\330?1\273 
\221\323\375\340?x\316\367z\346\005\301?\0102\nP\257`\332?\002\312\233\361q\254\345?\343\3575R\354\354\342?Sqc\325\307k\344?\260\332>\267\205\203\325?C\n7\325BK\353?\270R)_\021-\267?\214\010\362\214\002]\330?~\253\013\347\244>\357?x\023\030\036\"\321\277?\316E\317\020y\371\350?js/\226\242B\341?<\341l\\G\311\324?y\266\270Tv\200\341?4$\013}\221X\302?\2368)~\226\340\330?h\310\375\225sW\345?x}\213\252dd\263?\000e\204\317K<\316?ye\355X\005%\350?1\032\214\211x|\356?\002c\203\n]\232\323?+\277=\251\233\200\346?\363\260\017\3212L\344?d\275\330\033\312=\354?T\243\232,\331\346\343?\017\177?&P\025\346?\2030RN\202L\353?Kf\254\366y\322\352?\264(\322\245e\345\346?+\333t/K\256\354?7X\352L\266h\352?@\023\3013\226H\207?\354s\327\005\274M\314?*\322\001o\243\231\330?Z\375\377}\3703\322? }:UPz\253?\301\217U$\343\321\353?3\027\332\365\364\177\346?s;\034\347\006\232\345?W\002\270\024\202\335\346?D\2007u\362\010\351?\312\004W].X\332?l\226\207\344\375\271\320?l#\024\024?\'\331?\364tv\353\316\034\356?\000\237\365\273\'\340\351?&\230r\365\206\245\343?HBw\215\032\345\335? 
Qd\331e\216\250?h7\261\030\363\371\307?d\364\257\207i\320\323?\210(\310B\361/\330?\001}\004\235E\r\341?\200M\273\361\327\270\202?tv\372>!:\354?\244\351\244\357\032\224\304?\273\207lhv\344\342?@ g\333vW\316?\310sO}\263}\327?P\334\020\2414\010\345?\253lc\001B\024\356?4<\256[\207\326\307?\370=C\240\374\225\313?]A\302}w+\352?}\266\037\326\324\303\347?J\214\r\313\360\342\356?\202\314d8J\343\345?\2204\317Y\315\344\316?\243`\034\304\313\252\352?\031u,O\351\272\343?\310u\004Ki0\264?\014\304\'\210\261\266\325?\235t^\263\257\023\353?@\315\236\005\305\035\233?\245p\000\332!!\351?\353\325\201x\n\'\355?\260*\240\230zE\275?\003`\262Y\253\336\344?\220\276\200\330\216\316\323?e\031\254\216\335\251\356?\210\005\343R\347\r\325?jP\216\022\371\242\352?\246\303\244\010\177;\327?\220\022\n\272\323\264\241?\230\367\2538\177\321\304?\337$\001\013\217\005\357?\200c\237\352\213\302\305?\036\273Iz\251j\335?H\'\314\322\020B\266?J\3446*\252\007\336?i9\3756>:\353?p;O\t\177\314\266?\013m\372\246\022\032\354?\nA\246\216{V\356?m\205\300Y\316\221\353?\003/\341\214\330\035\357?p\261\373\t\037\243\273?\372\353\250\034\216Q\354?\036\324\243\277t]\322?.\005\336\363<\313\333?@\274\305\232\177\213\272?,QZ\374\032\353\347?@\343\355*\031A\323?@F\254W/b\303?\271\342^F\232\307\344?:\322V\256\"\023\336?Xk\013\254\'\264\300?J\310\365\245X\360\343?\036\020\260\004\275\201\330?x*.v\202G\317?6`\266S\016\371\346?P\0348\356\346w\306?\325\355\240-\r\370\352?8y\217\245L\325\353?A8\301p+\346\351?;\207>c\324\363\355?\373\332Hv\021R\353?\001o\374\022\005e\351?\234J\006\275?\251\323?\333\303X\313\246\320\351?I\345\242\222K\334\347?\312\220\3761\331\363\335?H\001b@\330\301\260?\315\251\217\250\002\355\355?0\245t;k\372\322?d\354P\025\212`\345?X\206\311\250@5\306?P\361\214\034\266\026\305?\233U\\|\226\265\340?8\347^\216U,\266?6\027Q\224\220\177\352?\220\220\204\007\to\276?\300\330\270\325\257\250\234?00{\225\301\364\251?/\254\314\345\337\264\341?\262f\352\371\337\351\321?\260\324!\005~U\307?\000\3008>\242\363Y?&\264\002H\031
e\322?X\213\243\211\376[\333?\300+\322\356\336\246\233?\006\323\277\0300\247\357?\034\2679B)\035\344?I\246\260\3422\261\342?\233cu\315*\350\340?n\307\332&{#\322?\374\317W\310\t\253\303?\330\250/\005\236\225\271?P\324\345\205Q\243\276?\224\336\010JN\177\327?p\036+,\344\311\315?\362\227\203`T\312\340?\014\353\031\226;\031\322?\350\222\032\201h6\270?,\334#\206\365Y\314?\270z\375\202=B\330?`=\021\246\335\005\262?\206=\311\014\334u\346?|\2144\020\n\354\305?\200>e\236&\013\212?r\222\263K\346\210\333?\264\203\356O\356\211\351?`\302\030\266\231\204\351?\000*\n)\212fU?\010\233\351\226ZV\267?\270\013b\356M\021\326?vr\320m,\254\323?FV\224;p\316\327?@\245\261\365\375\242\257?\2062\324Q\350*\327?\034\255\317K\016\005\354?\332\377SX\224\354\353?\260\300G\276bM\251?\204C\t\355\017\326\314?\275\320\223\367\336\312\343?5\233i\347\233\035\354?\375\230\360y\306\201\357?\00087\225\211B\262?4\335\030\3462M\320?h\020\001iu\231\340?b\252\030\353N\375\327?\320\205\037d\254W\353?6\307\353)1\241\332?4\246G\275\366\001\317?p\001epg\332\276?u\210\337\021*\271\347?\021\366\025\016O.\352?\233\257\367[\370y\353?\214+\256\305\336W\334?faS\307\231\236\350?\000\320\037^-C\256?1\"\357\245\236\251\351?\314\223\370\372Z\004\346?\361\220\220\207\246s\343?\220Lb=\032a\302?\020\006\377T\251\020\332?]\322j\320\225S\351?\2449\017\330I\300\342?L\276\360W\2006\336?x\220 
\256\337\"\306?T\034dN!\275\346?R{\022R\023\277\341?!-\251\302\304\r\341?0\323\254\246vV\341?b\241\276g\336w\355?J\013\242\267\212W\335?G\226\307\177\273\216\355?\320\370,\357|j\332?\222\010\362^\253\326\320?\3600m?\200e\277?\244\346+\320M\346\345?\030\227\2419<]\316?\310eDD^\365\344?\214\000?0K\252\351?\013\272|\346\247s\352?\200-2\276\004\352\237?\010\274\213\244@e\344?h\357~\373Q@\355?\322A!\212>C\351?>\231\276\355\336\367\332?\210\256\235t\r\214\274?l\232\241\204\336f\314?\0001\221\230\257/\336?\210\026\207\251\240\313\351?\260\262\026\211\374s\325?\245\\\215D\247\003\357?\304\364\022e\341\313\320?\354;E\022i\332\316?\224&\370\226\034A\346?ym\224\251Y\347\340?`\365R\336\2278\343?\034\222q`{+\330?L!#\236m\n\322?0\220\216%\245\036\345?\200\037\266\342IU\240?$\226\224\274J\231\343?\331\004\375%\237\374\354?5\247S\035\020\000\343?Y\343\3312\275\262\357?\261G\004\343\242\366\357?Z=a\234\263\224\340?&\326\343fa\360\354?\244\263\021\241W\317\327?\210\000S}\313)\347?\370\024\343z\235\273\272?\345\243\246\352\263\002\353?\333\201\213\317e\213\346?\022\3311\240\017\311\326?4\236\226\356\243\376\327?Dj\233\247OW\357?\355\304\216\361[\010\341?\251\260.i\262\260\351?@\000:\304\\\315\346?l\205.\"\3444\353?\304;\030\216\016\270\341?\326\363\206B\266\014\322? 
\365w\031\322V\262?\273\322^\314\315S\340?\245T9\237\231\255\343?t\3528+\303\303\344?\374~\314o\313\236\347?\344\313Jysm\350?\221\237\000\347\032\246\353?0:F\356\363\263\302?\360jo\214\276\247\306?\226\314K\367\361\335\351?~\263\247Qf\370\345?B\322\211\335\237\303\354?L\263C\r;T\340?d0E\210\016\326\331?\016]\027I3\215\324?\252.\024\027\300\357\334?\000`\201\232\036\316s?\336\226oK&\333\327?X[\243 \262\365\354?(o\304\212c\254\304?\210\\\232\273\006\177\333?&\017\024\263f\333\327?@#K\321Ld\306?\300}lQ\364;\222?c\001i\211\235S\354?\002\222+1\r\312\337?\177\253\356P\031\264\352?\004{n\306\002\242\303?\340\204\3379\201\274\230?.U\271\215z,\343?\224gN\237\224\307\350?PI\273\220\236k\241?\000%\222z_D\307?\300\346FB\250\007\214?\364\361\347\260\346d\310?\340\263\031\272Ha\252?\360V\272\231\377\344\343?^\226-0p\202\356?\247&\224\322\005\002\356?\240\037\030\205$\260\230?[M\236\003\000\231\342?\256\301\213\306\240\231\333?\256\177z\214f\267\350?\260\034Y7-\257\352?2H\205\201\221E\354?\020^B\352\346b\273?pt\n\350\243?\260?6\333i,%\271\350?U\363\367\222\r1\347?\372j\023\tp\207\322?\200R,\373\016\361u?\031\374T\267\t\244\351?\006Q\226h\013\314\340?\264\272\342\3331c\354?:zs\006\203}\337?\034\377E\263\010\373\332?\203\355%\200\230X\357?PW\034\377\367l\342?\300\344\316\272\240\350\345?\240f(\200_4\237?\332;\242A\261\026\352?00\215s\376-\316?\031\237\317*C\352\350?\224\233\340\tjQ\357?d\305\273\nd\366\332?\204\303\221\264\032w\301?T#\004H\353\304\330?\336\333\177\263\215\341\353?\007rw\r\300u\354?\000\313\377\322\301t\226?\252\261\244\036\333\215\355?\021\234ulr\340\354?\2541*a\252\207\326?zVj\237\344T\335?\367\212\333\310+\025\342?\210+\253\250\262\236\357?C\264\233\3013\375\353?\255@\240\224\214\013\351?8\354\341\315M\020\276?\276\224\001R\032/\340?\334\370\325\363\036\036\321?H\270\360\247\232\004\262?\252\351\243\023F\376\325?wH\001\355\274\025\357?vQ\330\303Q\225\333?\020\302\367\333_\210\271?,\354.\335LQ\335?\302\241\312B\333\227\331?\323\244jQ\356L\345?\206\003\210W\332\355\
346?(\216k\354F\360\331?(h\353\241\266R\356?uo\342\327\340!\352?\240\020l\025\343b\344?\000\305v\360\311\347\276?\366\266*Z\017\"\346?1\320\304\245\237\322\351?\355\327\316\316G\026\351?n\353m\207W\005\342?\017@:\237\345\216\356?Swz\230\374]\340?\302,p\322ac\323?P\376<\230M>\265?\363\034\245\026a\246\353?\250!\037\035*\371\350?\250`h\233\227c\277?E\004\n\306\342\256\356?\300\331\322\355\030\226\330?\n`\211^ZY\333?5\330\253\267<\330\352?\357*&\352\201\351\356?\230\262P\027?\344\323?x\354\2464\014\270\263?\250\305Sn\202A\344?\330\230\350\3347\342\335?\201\260H\220\310M\342?\225\'jD\nc\353?\007\376\262\350Z*\343?b\027\214\177Vc\334?\030\r\010\306#\237\306?\207\302k\264\346z\351?&\032b\235\237\221\355?\260X\374r\344\031\353?,\211\243\345\252\315\305?T_\375\375V\377\337?l\377\367z\211a\323?\336\267\013\3141\340\321?\320M\027f\303\234\253?\375\351{\345\243\350\353?\372\006\004E\004\r\355?R\237.6\215t\327?V\224\'\207\026j\333?\2600\242\001\305r\312?\366\\G0\221\256\344?q\256\t\001\250\266\345?\212*a\346\246\376\326?\351\002\330\177M\326\344?f\242$}\016v\323?\253\322\334\311#\307\356?\010\"o`4\236\357?\3253\352+\233\027\356?p\335\337v\2460\267?\177\235\n\206\335*\350?\031\003\307\266P\307\344?L)\243\302\001\305\311?\010\003l\237\231F\267?p\320H\007\363\253\334?\362\037\377\266R\361\357?m\345\307\201CN\353?\213vn\370\233\021\347?^k*\214p\"\342? 
\007\375\217\331,\241?@\242\355\345[\215\277?^\201\327\036\251\206\336?\260k\017\376\200\254\343?H\346^)\025\376\266?x]\326S\330\017\306?\277\177\227\313\273T\354?\252=F\237\317\271\347?\330k\242\342\356\241\313?Q\371\372\2367S\347?\342\353\323<\253\006\346?\3100i\330\010\350\267?\217f\010\321;7\341?\002\302\353`?Z\351?\220h\326\262\376D\344?!\226A\r\322\337\347?h\332\024\017\340\247\261?R4\030 \020\037\351?\312v\327\227E/\357?D\376\203\303D\n\313?$\247\255\210$\211\330?\305\010\270\265\013\002\356?\265\002\002\2650\232\341?\006\340\225\362y\202\347?(\314\335\316\300p\265?\007\376\253\211w\017\341?\341&H\177!\227\343?\310\020\363Y\265\000\316?\212\032\234\377L\342\354?\235\032\215\234q\313\346?[\301\255\363\251\207\346?\262\312s\226O\204\354?\370\364\023\304\241 \315?\010\025\340\367\267\276\272?\275c\212\217\231\321\342?5\003\034K|\206\357?\227\251gb\200y\347?88\354\261#9\277?w|\364\321\233/\356?H8.\376\303\307\340?\360\365c\263\000\214\265?x\337\303\215J\241\275?G\234SIC\375\356?YH\021\010wX\351?0>\376!\2572\263?(\331\n\340\374.\302?x\302\225q~\233\276?H\004\263\202>\337\330?93+\343I\013\352?\004\325H\030*\314\335?\233\006f\241\230\000\355?Kl\307!\3601\356?~\r\342\231\033\361\331?\337\277\327\354A\262\350?\354\177=\236y\270\342?\342\231\223&2\243\340?U\201\010(\333\357\352?\230Uk\3009\334\327?\251\301\274\245P\267\346?r\274\236\366t\202\327?O8\271PK\233\352?\r\246^\312\260\321\340?\304f\3138w\230\336?\034\224\024\200\300\322\340?_\2019~\255F\346?\000\234\tT]\356\225?$p\256\253C\371\357?\266\032s\310j\203\321?\336LTo^\023\324?\254C\316\300\216\233\331?Ht\031;\211W\344?n\212\323\311\347\377\343?\"\240\033_h\242\334?\374\253b\"\346~\342?\330O\002d\274M\305?\3027\010\371\3156\352?y\352\304y|H\354?P\177\030\364\010\324\350?\371g\245.\331\267\341?\352p\364Y\261\370\334?\367\336\273\253\256\332\341?\327\352\2508\300r\346?8\036p\357\372\264\272?\264\021\270\005\032w\335?j\332\211\024\366\225\345?\243\306\227\361\021|\341?\314\303\215m 
\362\304?*\177P\355\273\010\343?.N\352*\000b\321?6F\225\250\354G\343?\251\270\300r]S\343?\370\246\324f-\277\273?\226\307\0264\307\245\351?n|\321\250`\240\337?\310\326$O${\324?\240\313;\235\256\350\343?\262\226\273[m%\355?\\t-\030\335\346\343?\366d\034\367\0376\355?\032H)\242\036\013\342?\000+\240d^\013\327?\017\360\343BQ\314\340?Vq\'E\231\263\345?P6|\345>?\270?|\242\033\222\237\366\307?7\"\027\256\260U\355?~\257j,\242\273\354?,H\236Y\232f\352?\274\357\247\356L4\345?\364\206V\311\222\366\332?\315\010\033\352N\365\351?\325RG\217\274\250\354?\203\362mU\311\323\341?\376\367\261\336`\021\350?J\020\302\343\325l\356?\021\177S\231tB\340?\010\370\t_\306\t\326?b(-_\236\001\326?}\246\030\032;}\346?@eD\177\365\224\326?\204\233\234j\235\024\353?\262w\037\235\360\004\320?\000C\236P\010\016\331?\264\343\344\006$\216\324?\240\373\031\3645^\234?u\246E\272St\340?\354\221#\000\210n\354?\274\303\013H\036\240\324?\240\355\303\264\231N\305?~}\261\366i\255\325?dN\220\277\343\235\315?\300\023\211\244\266\242\346?\036\317\354\360\\\225\357?\361\320Si&$\356?\235\324\366L\204\335\342?\024LR\325XK\353?\224n\005\211\335\360\333?gl\301W]=\357?\\\257~$\253c\353?\377\317\367&\317\354\350?\255\327\240\262\260\357\346?\240\312\024\376\371g\251?L4\330\365^\003\321?\320\200\2268\022S\244?\250\356v\221)\210\303?`^\3029\370]\316?K\2314sd\223\351?\013\036\336\347\313\343\351?J\234\233\273\223\332\322?\364\222a\034\224Y\325?\216\023\342\313\0260\340?\020\336\213\231j\'\313?\375\201\301\305 
L\356?\036\"A\024\002\360\324?D\032\255\231\350\372\302?\010\034\203\014\302\224\344?\000:\317R5\342\324?\332\232\266\303\314f\356?H\367\307/\030\203\307?=\241=\313\373\256\340?\270\274\304W?.\323?\200\302\2243\\\364\356?\267F\016^\274-\345?\331\224T\367@#\354?\210\032\247\366\341\005\302?\220@V\320\332<\246?J\322\n\'R\333\347?\234\247\211t\244@\341?\024\307\262\264y\341\342?\262\234\253\001\232\232\357?\rw;A\022P\343?\233-\000\330\0223\350?\207C\206\311\000\345\351?\311\014\021\373\020\320\353?~\263\270\007\311(\337?m.\257\224\037\027\341?\343oN1rU\356?Z?\254\223\337N\345?J\304\301\300K\004\333?|\336\326\335\331q\336?p\001\003\252!B\254?\010\031:1\315M\277?\342\254T$\235\272\322?`\204\250\177mk\222?g\232}l\201J\346?\252 s\305\356\002\355?k\233\227W\220\223\345?\244\335\235\214\370\370\356?!=\177\026\013\023\352?\304\017\271A\351H\303?\347r\335\242Y\372\357?\317\006!;\3531\350?L\033\002 c\252\345?\336\nb\255\275\225\320?#z\032/\340\276\356?t<7T=\254\353?)\2236\037Y\006\356?\213\305)\200\313y\340?\260\006\227,\376\006\344? 
\t\370\327\217\277\273?\240\225\217\265\230=\264?\300Q>\261\n\031\275?\243\245\213I\022\230\342?\205q\255i\014\276\356?\340\265\323\344 \254\350?\004\367\263\210x\035\312?0\260\020\n\334\343\323?\326\034\225\323\302\023\342?\354\343G\362\264\327\323?<[\267 9\370\300?\014\236\217\355\260\207\357?\376 \333\220d\243\347?\342[\t\264\376\241\340?I\ruY\036G\355?x\027\r\\\357\367\351?\376\274\366\037\320-\341?\244\002\037\341=\000\352?\250\254\275\\N~\335?\240H\030w&\311\330?\000\207\276 \366v\316?\320\3656\2638t\246?t?\025\301L\312\343?\200V\'\360\350Pq?\263,3\375\202\367\355?\004\203\211\233\232~\347?5\030\364V\364A\341?@>\037.k\305\341?\350\236\033\317\341\322\353?\232\316\260\305i2\352?\354\243\337\2421\204\357?\332\264\013\271\350\031\321?\202\307\215\025\270\230\350?\322?\377\350,3\320?$6*\355K\300\315?\250*(#\234*\276?\240\254\307\202\361d\355?\246F\032\335\301\035\327?\010B\305\0107u\342?dL\257t\030\277\331?0 \333\337\002L\357?;\026}\237\222>\345?\363\256dq\362u\343?B\226\216\235\260\017\351?\000D\016\024\202\004\337?\020\266\262\333\340\200\252?4\267\214v\215\305\305?\260\337\303\032\343\336\243? \261\013\016E\332\357?P(^\314z8\252?A\022\250\352\3648\357?\354v\263R}9\354?\010\007X4\207@\341?\3529\026 E?\327?`\311N\310\035\021\257?\355\3509[\273p\352?\347\372\231x\235\226\353?\362\234N@\304L\343?\252\354\250\271\364\200\330?F\273\\\020\251\210\324?\200\347\262\224\353*\260?\347\t\354K\351m\345?\3606\000q\370\235\315?nf\246\227=}\352?\220\trv\324\000\311?%\243\0317i\262\355?\330\007;\377R\353\335?4\217\367\226cN\334?t\001\035\334\260I\347?\\\215\302\200\262\262\337?K4\343\311r2\353?\032\n/|\366i\335?\r;x\355\023\366\353?H\341\225:\211M\301?\271=L\205\304z\344?\250\n\036nK\365\267?\200\003k\372\306\254\263?\362\207\315i+%\344?\213+\014\253\020&\347?\212\235\251A\017\267\354? 
\362\375\274\213\347\260?f\320\276\245\216\247\345?\230\035\013\271\316e\271?\303\014\331\223;d\356?\224l(\232\240\323\335?<\210\375\372H\r\352?\210(\331\324\0014\343?Iw\3015\005B\340?\205\264\243\265<<\351?\024\356cb\020\255\340?\204\213=\025\'C\307?\251\261J\257x\020\353?X\375\362\222\024\005\333?^3F\350\355\353\356?\016{\334\032\270.\337?:\266:\266}\000\321?p+\227s\312\255\273?\333@\313\n\345\322\357?\216\260\344\014\260\267\321?V\216\200\375\233\353\345?>\035\0322\341\007\341?\310.\305\241_d\326?\262!\221\255\350\'\356?l\0262G|\363\357?\014\3141\251\010A\332?\350@=\207\366A\302?\260\363\330s\007Z\333?\267p\356\266\370a\346?\005\032\311\372U\231\344?\000\003\337p\245=\254?h\031\217!\261\235\324?\377\241\350\276\302i\351?\323Dta\320\251\350?\222\3701\312\016\252\336?^\217\357\016l\322\347?\237|o\361\335\200\357?\210!\202,\302H\355?\362\227u\025|\305\350?\320\237\324^\353.\342?\320\024/k\262\202\260?\351*W\321]\253\345?\252\313\023\204\370\315\350?+\340\305\256\221\'\354?\245\007\207.U\220\342?x\246\010!\321(\353?\036\235F-\273\301\353?@\2101\235/\304\210?\316\020pu&\000\334?4\344\244|\242F\333?\234\227\352&\037\237\314?LP\356\302\213\331\330?b\333\\\341N\313\325?$c\212kF\301\302?\024f\246y\347\330\351?\210Y\201P\252\321\267?\367\001\325\033\035\374\341?`\217\000\036\235\333\342?\254iG\246\343\367\324?\303]\223\252\256Z\347?\364\212\037\310\374\374\306?\337\201\244\200\341a\355?`,\220!\206\263\276?\\c\307cM\307\314?\t\'\267)I\343\355?,\375\014X\360\225\326?\024\326c\377\327\202\333?P\"\261*\352k\243?l\214/\314\232\036\317?4B\262Q\0056\317?\014\261\037\363@\005\330?\340\n\264qLm\320?\200\275\256\366\313\314\341?\350|U\311\326\316\354?\016\236\336FA6\335?\210@O\030Uv\266?\255:3s\230a\350?0*\007\017n\357\303?\023\303\3651\256\006\354?\310\322T\264l\337\261?\\\301=\030<\317\347?\313o\322X\200_\340?z\034\355\005\304\307\343??\337\3731\201\325\342?\264\231\037a\032\255\320?\n\002\241H\276\331\355?\236\257\366\264\244\231\323?\260\365\320m:i\267?\301\272\337\024 
f\354?\372\251hET\304\357?\260\265\000<\210\335\275?Zu\306\247\016,\323?t\016\213\232^\302\311?\212\360\343\335\250\354\332?\200\025F.B\336\314?\360\353\037\326p\230\251?jM\370\340`\025\355?\007\2431\231\345b\344?\220k\210!P\364\304?Y\002\224\226&\264\356?\027\326#>\362\320\340?\326\261\200\301\207\233\341?p\347c\201\247:\243?\000S \365\000q\336?\026\005X&\221@\325?\000 \336\337\271\221\234?\256\346\035$\367\237\350?\260\025f.>\366\344?`d\223>\216\323\330?~\220\302\231\265\'\345?D\020\241x\273\003\334?<\377\014\326\3473\335?\320\224\222C\261\031\326?\320\276\347A\210\210\337?\230\360}\"I$\357?\244\000\321{6\350\352?`w\220x\350\224\265?\346\225S\363\255\021\350?\364g\221\'\264\344\312?\377\231R\032\0005\353?\240\254\277Kw{\331?\304\375\0033R\223\320?\330#GbZ\217\346?pSC\221g<\253?\372\000\037\r\326?\346\266\026\320B\233\357?\370ZS\345v\344\274?\260Q\240m\217+\353?|u|\235\355\367\303?\212\275 \211\374s\355?$2\211d\t\035\346?\370\232E1\305z\307?|A\2208!\243\320?\000\274\307~\367\203\300?8{\005\002\270\033\334?\340\264x\257\205\274\304?\2048\275\275s[\335?\031,\242qy%\343?\240\364\257\036\312\204\274?\300\376\r\206\256\024\217?\3063\254cL\306\335?u\347\250\356R\352\347?\343\340\260\332D\345\341?`J\333aR2\235?\274\243\244\276\031,\340?\340\343}\271\311\352\260?\352\230\223\263\'\377\344?\nI\023\246O\310\357?}\211>\366\220w\342?0Y\030\332\301m\335?\324\203\257\327\220\343\315?H\362>E\035c\264?\315\241\r\"\300\036\357?\215\251\343\312\255G\356?\244\201pP\360b\337?}\346L\350J\336\352?49\010P\007\221\311?\325\021\267\364\222\253\353?\200\343\250{Io\227??\3106n\311\320\345?\364d\207\331\3314\327?\354\3045g\303\375\326?\264\353\034\242\306\010\334?\023\213z\230\251m\342?\374f3\264B\033\345?\354X\356\353\234\237\352?\200Ia\022(K\244?G\372)\240\010\367\355?(K\205o\305\200\336?\263\264\231\255\333\207\340?\236w\263A\310R\353?\200W\233\217\003\311\235\265?\230iI\221\310\203\273?\025\245v\252\023}\354?\370\371kow\375\312?\354\2775\316;\345\356?j\324\310D\256\227\344?R~\020\314\014$
\323?\334e\267\033\034\345\333?P\264\363%\231\003\270?\322\351@M<\260\320? \266\341\244\206W\330?r$*9\251g\335?\030\342\371s\247\305\356?`\024\005-.P\316?\013AT\031\345-\354?\364JG\355\217\230\343?&\222,Z\317p\347?\300\226:\316\304\271\205?2\037\033\013B\023\354?\022\367\356\377\227O\341??S\311\335Px\356?\026\002\204\242\210\375\323?:2t\347k7\343?@\303X\215QA\357?8\357\202a\255I\357?\0247\374~\322\333\334?P\256\213\031\334o\342?|>\375\nY%\310?\332\333\204\365\034\260\331?\355W\374\024\027\322\342?Y\026h:\250\326\342?\244\262\204\273\361_\321?\"\274\2359&m\321?\310\300;?\0272\325?\020\000uL\365\270\303?\216\252;\n\322w\331?\315\257\372\343\305\302\354?\274\324\322\222\007X\350?lq\2728\3240\351?nd$\025\345\374\320?\271F\212b|Z\347?\324$r&\232\303\344?\216\346~r\354\312\342?\370\263\331\255\265?\200?V\"\025<\271?\\\317\037\205\034\226\343?\024\216\024p\310{\314?d:\202O\333\253\341?\006\2021\235M5\323?\356lt\353h}\325?q\262--\001\354\353?\300Z\341\320R \330?\263x?rc\035\347?\300o\202\214N_\266?\232\236\377\260\233\317\330?\207\276\320\216\275\210\343?m\322\232\225\362\272\347?z\244\377\234b\037\352?\332\203\305_}\237\351?\316\037\277H\270R\332?\300\324\nz\272n\242?~8\013\334wn\352?4\021v,\330\336\346?\312(j>SK\350?A\260F\377\216\325\350?\007J=kdG\345?[\320\206\372\230\270\344?P\013\311m\277J\270?\350A&\n\032{\354?\320\334\365\213\216^\313?]-\230\177/,\352?(\303\362@\211\\\320?\257\261\376\371\331\230\350?`\236\216\230c%\303?d\344k\210\331\017\323?(\372UX\304\005\277?\222)z\345E\231\322?\025\330z\006\016\201\354?^`.\252\224\327\337?8\270\232\025\026\306\270?6W\354\213{\347\347?\246\333\226\306\303\347\333?\256\367\214\323>j\322?\204\352\326\252Y\335\353?d\321\253\205\235e\335?\336\212\251\352V1\334?}E\306\255\223\365\342?\344\260z\t|\357\340?R\326\323\241R\'\340?\250\217\321\331\217\210\272?\2741\024\2042\207\357?\306\305HV\312\006\337?\220\214\326\005\323O\263?\370\225:\374\302\203\355?\273\3659\320\207E\345?\014j\354?>\271[y\206M\322?H\016\3242\230{\313?P\214>q\002\350
\321?\217\250\305\226N\247\357?\200\233\313\214)>\211?x\254\252\211!(\345?\225\\\033\346\215x\345?\306\n\"\362~\036\322?\027\302h:\000\216\354?T\031\246\315W\337\304?\340\\\367\351B<\225?\333\027)x\025-\346?\010\036{\2702h\342?\202\235?\017\272\256\340?\207\237\365\241\362\303\357?\006\212\034\315-\376\326?\210\202\327\257\341\254\332?x\225\265\013\024\276\350?)K\344\235ch\351?\352\2642\001\236\333\356?\241\226\013\262\\\036\357?.\033d\020\204\360\323?\272=\237$j\222\332?\314\tcC\214Y\345?\310\315]6\245\325\345?\364#\007\t\236\212\337?\360\311<\230C\374\351?@\377\340kG\013\334?*\201p\321\202+\331?~\r\022\213e\307\351?\352\244\270\324\301\275\347?h|\227\336\253\"\351?\3135\010\364\204\354\342?g\273\010\'Sr\357?\230\322\240\221\020\236\356?\311R\227\263\376\366\354?\311\232\351\013\333\277\344?\240\010H\313D\303\274?<\305T\323\014\312\331?\234hkQP\264\322?\234\366\342\354\277c\322?\304\032@\372#\342\351?m\003/\312\322z\357?\240\026\336)\343\360\310?\230\200d\342\224\321\346?1\346\t\271\322P\343?\246\264\350\230\250\032\327?(\300\253\311\2715\305?\322\346\025\004b\272\323?\306\3602\023\232;\353?\240R\023\351\'y\344?,9\214U\307\244\303?\303\204\351;\265\223\343?\350Iz\026.\247\325?\270\016E\010-%\303?9x\203e\206Z\343?`\312\234\313\037\375\330?\0345k\374\217\330\336?\322\342_R($\342?\257}\026h\306\224\352?>\334N7\356z\321?&\306Q\003:\356\337?XIGkl\255\315?\302\241\210\351!\257\321? \247\242N=}\231?\350\371\264z\335\215\344?*^\321\214ol\355?H\236\320\006\354\376\350?1S\tg\035\361\350?\177:\253r\343\320\354?\200\2444h\347.\347?\3466\214a\351\352\337? @h4Z\257\272?H\3322\213\367U\324?\247\250\330aY\222\347?\304|\275}\025\241\312?\360\035\322\351\027\205\327?e\003\342\215W\251\354?\372\200\026\317\263[\330?\372&:\302\206\221\344?\317D\307u\001\241\343? l\322\317x\366\324?\360>6\215\272\033\242?\306\375\000\200j\274\353?\006)T\200\346\241\332?\352\265}\272\225\301\353? 
eX\231\373=\257?\014\302\253\331\236\221\355?\220v\343\351\360\341\337?\263v\313M\311Q\356?\244\225U\271\"\335\353?\200\035H\231.n\326?)\270\352\024\300x\346?\353\245\032\342\245}\350? \270\362H\362&\346?k|H\334\362\231\357?\220\261\273c\361 \307?I\362\241N\316\206\353?\240\302\272\300\340\216\246?SU\337\331\010\371\356?1G\315\312!\271\341?\253P2\221\270\333\340?_WG\023\205h\342?\216\317\030\027\301\347\326?\257\020,\242I\242\342?\362\367#;\332)\327?^ |\241\027\376\333?\320H\226\177=h\335?V\022\303\371\020\341\337?\320|>\220\307\324\320?\017*\231=2a\341?\036\312\263\250\255\377\347?\324;\3614\321\272\354?K=]\036\251~\345?$}\020=\237\270\322?\005U\225\326\205\317\351?\200\351=\033\200E\341?\006\302\335\334\361\315\357?[\034\315=,\245\353?\014\']4\324\275\316?\237\225\275|]r\356?L`\265\204\005\314\307?@\216\032\005qK\202?&\270z\0328\213\354?.v\\\331\234<\347?\324\003<\354\023_\315?x[]\313q\307\265?\374\274\242l1\377\324?CS\014\242c\264\347?\275\020\001\335g\222\340?4]\234\277?\210\351?0SU\333z\263\313?\263\'\323)gf\346?dJ5\n\207\334\333?FSN\373\\\005\333?\343(`\216Zu\345?\220V\007\372\376_\273?<\237\223\331\337\005\314?\000\000[\214\3327\334?#\344\361\300\367\033\353?T\304)8~\\\350?\004\272\374\013=E\345?\032n\202\364\036\033\350?(\277\177h\2475\271?\250\257\345\311\210\332\310?\242`\r\035<\004\324?\246\245\212\0307\025\344?\200\332|\027\337\035\261?\223\355j\034\3530\343?0\261\244\376\262\345\310?/\376,\235\350\236\353?\200\352\t/N\255\245?w\244\204\253\311{\356?\030\206\370D\025\301\261?L\031\034\276qN\305?*\034$\362\255\251\343?7,\242\241b\372\352?S\236]\335\336\"\357?\312z\243\202N\217\324?4\271\316XDM\353?\240_p\\zK\250?\345\\\235\240\353W\353?\364m\340\373\002W\351?D\245FC\026\227\350?\377\330tE)\005\340?\037\2046Cn.\344?t@&\372.\362\314?!\221\236|\207\324\341?\013j\003x\035\024\342?\000\3218\355\337S\317?\350PX#\321\252\300?\344J\207[\325\323\302?H\360\216\203\005\257\311?\262\351@\333\3220\345?I\342\207\325\014\247\343?^\246\272\242\t\027\323?\0107\202\010\254
\004\313?\0000u\317c\000\222?\330\373\361\222\263\267\260?\004\313\366\212SV\352?^B(R\326\007\353?\006\334!\025\270\215\334?\272\344qt[\'\356?\242bU\300uU\356?\272\355\360\337\344\237\344?\017\307\226\032\247\356\341?@\240\224x1\352\232?]6\305vO\003\341?4\215\255\272\214\007\342?\350x&\274}b\315?(fI\312\255\273\334?\352\014\252:\275\300\352?\371\322Tm\253\366\345?\220\215:\tF\370\356?\360\005\n\234\273\334\252?@q%^k\022\334?k{e\206z\335\343?\320\013M\020}\244\342?0A\0022\200H\334?\235\202\023\024\362\354\344?\346\247B\345l\217\341?\020\240\233\325\000 \323?b#\216S\022\264\331?\335\013\\B\257\267\352?\302\022\344\332\307\212\322?D\340M\033\327\'\341?K;\261\234(V\346?\010e\232\235\244=\263?\232\345\241sY\350\333?\332`Wu\2005\340?\000C\367\256Bz\321?\230`X\206\036\233\310?\014\313taS\001\315?\005\37252\315)\350?\344^}o\027I\333?~h\357U\271\310\341?\210\370\267\370\302i\335?\244}\215D\006\316\317?rq|u\244B\347?\010\264\316\023T[\330?\000@\013\2250\306\030?C\027\375\222#\343\351?\212g\355c\013&\324?\314>h&\275Y\310?>o\334\247\030\255\332?\017pO.\215^\343?\366\030n\023:.\357?`\035\177\330O\354\222?\334F\212\277R\367\327?\337\\\244\233.,\347?\254\3123\274Vw\305?`\370\353\351\373\252\346?^n\031i\274\312\356?Pa\"fXs\317?k\033\370\t\213\335\340?\370(\236\235\007\374\262?z\217\2531\310\261\326?X\002\277\\\013\013\277?\220G\0103M\316\252?\371.\231\327 
z\350?}\020\340\"\366\360\355?\314AN\205\330\006\303?\025\001`]\337\225\354?o\221\310\355B\250\347?\320\336\371\247D\222\254?eiL\274Eo\355?h\277>\370\271v\306?\254\022@\311\321\365\340?p\002OEU\210\302?D\327\323\226\255\367\304?\304%t\265\263>\317?\304h5/\026\277\307?D\305\003\313\033\335\306?\344~4x\251e\350?\304\303\202bi+\317?\325\212\254\214x\350\355?*#\242\271\027C\347?8\222_\031h\312\312?\226\335\3629\217\026\345?D\270Nq\016\252\334?U\272RR\315\005\350?\214-\342Y\341\267\357?\\pY\2568\'\312?\036\344l\032\332\322\324?\005X\31052\007\350?\241\207u\225\224\023\344?\036\330,G}M\357?t\007\211\251n\372\303?\310AoJz>\324?\023\213\327\214\014%\342?\343\362\016\035E\260\347?\264yk\322\004\305\315?x\315\002\301\230\027\301?\007\223b\330\302z\344?\240\362G\226\372]\227?\207\000\262\250\033\327\345?\031\036\202\273\241t\353?@Hn\245\244\001\321?%\320\210\222C\267\351?\301\020\330\034\260\337\340?p\017V\212\347j\246?\277\344\371\245\267\362\341?0\360Yp0]\300?df\275\241T~\301?\352\314\364\223l8\337?v\370\302\274\275\217\327?\254\345\223\354,\253\314?\304\033J.\324\210\322?H\361\340\276T\260\350?\036\"\3009Tf\341?b\334\261\311\215}\354?\030\030\361\371\355+\265?\342\310\237\305\260\335\344?\260\013\315PM\005\351?\374~d\311\330\250\335?-\222\035\312\0335\343?\035\000\303;\224\016\352?N,\342\273\245|\332?\350\310\264AK\217\326?]\317k8]\357\355?\206\216\211\216\rE\333?\371\373\227\372\212n\342?\t\203V\220nW\351?\222Gim\305\344\356?p\333N\225\027!\243?l\037\023\356\335\216\344?\253\243\263H\224\226\345?\211\374\271\2311M\347?n\271\376\235,_\356?\340\237\231\007\272\202\265?2\366\"0\232E\332?2\035\225\207\341\342\353?}\322\210\371\273\347\355?\254\305\2770V*\354?\300\342\201\254\360\323\342?\341k31\264}\356?/\023e\313\360$\347?\347\256m\225\316\326\346?t\275=&\003\315\305?\255\206\371\033\200\224\346?!\201+\221_o\340?\255\352\023\356\3206\354?\202^>$\220\353\325?-\366\305\016\372{\352?\350\022\003\031\372\004\335?o}n\014Ag\352??\2757\207S\234\341?\256\002\n\271\t=\320?`=q\263d*\26
7?\274\330\374\320\3440\306?\370\037i\204&\222\320?\264\314\367\024\347\030\305?`|\367\333\223%\264?\200\351\251\2644\227\237?\326 4\017\223`\351?\0355\344`O-\351?/Sk\'\374g\346?8\014\322\2528\273\345?f\013R{\242-\327?\372w\007\250\036;\333?\355w}\255\337\356\344?\370\237\003\273\206T\272?\200#_\346~\273\356??\031\036\212\344\325\356? \311\024\0245\373\333?\200\2415\233\377\311\267?\340\250\355h2\360\237?\274\333\325C\205\250\331?\034\257;\307\034\336\335?\252\224\320\302_\016\351?\310>t|\353\003\304?\225\270$\200\022\315\347?*\n\260<\310p\353?\374`+\307dy\323?\326\026\001\035e\036\340?x\342\241zD\242\323?_\233\275\301\360#\354?\343;\275((\333\344?Vh\261\360s\314\353?\340z2uO@\263?\223\014\032/\n\247\357?\254\032\270\351sg\344?\224\340\353\227!\226\341?(\265J\2354\207\311?\010p\233!\215<\322?xl.))\231\274?k\370\241\374\033\031\353?\2745\0207\352]EF\325?\000v\264\025X\\\341? \356\024\221\2320\265?\201\332^\235\244\354\351?4\032g\227QB\314?\263\037f\274\347\r\343?\006\265\355.\265>\335?D\342r2:k\336?\200\360\373\033j?\256?\034\352\006\322=p\325?\210m\215\222\024M\303?$\243\243S\001\351\347?\204XZ\227\262g\311?P\202\231pI\325\306?\033\230\347!\346\271\353?H\367\252\335\367O\357?\300]7\000Zb\357?F\327\351\000,p\330?\375\265\215Q\033\333\345?N\233\3709\355\003\332?4\347\032w]\312\341?\037\261U\0050G\354?\\\342\200I\273,\302?\232\006\271\226\010\303\356?\303^\341E\231n\356?\276\004\202\ng\003\326?F\376\027q\354[\333?}\232\276\007\310\222\351?\312!\303\003\253\334\341?I\347\277V\005\377\346?\262\253\273T\331\343\323?\256)T\213\027e\353?\346\315\265\021*\"\357?\340\223\346w\242\346\310?\315*\214\254&\033\356?6\201\255\004\021@\341?\'\214[\366\334\342\356?<\306\006\"\373\215\306?\026\240\266:(.\336?r\177\363\332%\363\357?\014\315\024R\234p\331?\200B\317Uf\270\230?\027v\342\007\2168\341?\266z\261\257m~\346?\024\337\010U\337\210\311?\000\204\023\262\313\375\334?\000\321OH\225\355\322?\\f;\256\330l\334?[\310A\330`t\344?\334;\327d9\306\304?\244\230\3411]\314\306?\234\351U>aw\325
?\356dXs\355\364\342?\360j!\'\274C\241?\247e\244\356\026\355\347?;\257\250\300Z\366\355?\270\"\322\226(\233\323?\232\354\337\200P\305\342?\ra\202\221\213\215\344?\006a\3305\"S\357?\200v\234\036\3620\254?\3541tW\257\321\321?\026\271\010\363\374\322\334?\024x\037U)\374\306?w\314\264&\335\026\356?r\311\3438:h\327?$\364\225\237.,\330?K8\244\0306\243\347?\256|z\256\360\226\333?8\217o2\353\\\276?S\214e\204;\177\357?\362\374,\241\235\322\323?\004\200\205\273(\305\340?6\215\023a\033J\336?\027\025\302\253HN\350?\034\314<\242M6\356?``c\367\326\257\317?}uy\377\201@\354?\320M\366\237\203!\313?R\210i\230nL\337?\253\224\014?s\\\345?A\210\356wo=\356?\346\216\371\375!_\355?\236\014\010\274\244#\354?\221\264\363\006^\362\351?\030\350\2121\235\005\320?l\035\365i\216i\317?\324\207\246\000\214\021\347?\316HW\363\361k\343?\035\2454\214\260\014\357?\313\255\350\306W\203\355?\234\261\340{u\n\325?a)7\240-h\350?\330\252\324\253\232\322\310?\243X\"r\224\302\340?&\377\013\317V\204\344?\273\373e\3720\026\352?Q\252\003\2677C\342?ok\234\315\277_\340?\340\326\361\303\030\243\353?\304\"\024\'\241\001\350?\3467N^\252\337\332?\300-\006%/P\330?YbM\321\000 
\345?`\201\271\274:j\327?\354\316\207\327*!\326?\230\263\237\004\002\373\314?=\207@m\374\315\347?x\033\226\377\034\312\346?\314\370\372Vu\353\357?\010L\354G\034\205\333?\222\322\'\253\343\204\350?\236mU\307\235k\333?&G\330\304H\334\342?\306\027;\234\215\263\342?6\002\271gCe\325?5\032\333h&\345\355?@nD4~L\314?\361\250\215\375\t%\347?0\377\006\332\322\276\272?@L\361\227\322\317\304?nH\004\010\227\021\350?\374\016\257\2160V\356?B\333|\026\274-\351?\\\231\302T\016\000\327?b\304\356\225\374\025\324?\037D\236\326uV\354?P\033-\272\353\024\270?(\005A>B\234\311?\3701\316h\207\272\264?f\356\344\275B\350\321?\200x\235C\255\373\336?\267\336K~\037#\342?\200\031h*C\262\202?\310Aeb\266C\266?\232\030\0348>K\326?x\022t_\306>\302?#?\214\374\034`\342?\362\376\325\312}\300\336?\006\217S\216g[\353?SR\276\257\235\311\342?~~\312\330\330P\337?p\235\331q\n\010\327?\0304\250I\314\030\330?^\t\272\233jo\347?\312\"4k\366\230\341?h\301,\276S\313\305?Ps\330\336\305\313\347?a\032o3N>\342?\033\331nt\354\002\356?\274W\205\351\264l\312?I\375\020Y%#\341?\231#\230\314\t!\345?\362\364w\000nL\333?(/\270l\352\216\312?\0000Y\372\010F6?\212U\240lWK\344?\306D\222\357\3073\344?DG\337\'\375\332\314?\343\307\021\014\275\374\350?\330Z\254\\wX\334?`CO \263o\254?\224uU\373\273\177\340?\224>\266\2160 
\307?z\2417\307\272\017\353?\247w\235\217\014\223\340?\200U\216)Y\271\340?\354\\\310_p\216\324?\"\026w\0169w\353?\023\n\217\253\240k\353?\350\322\233-\r\321\307?`\336Y$\026\261\337?\254\222\216\312O\332\347?\336\206\356\tm`\337?\212\305{s\201W\337?5H\361\243\374\370\351?H\337\252p\256S\316?>k.\"\225\334\354?\020?\226\266g@\311?\0005Uz\326\301\247?\032\225\266\370\0247\342?\000\270\230\252\237\326h?pZ\374\267O\007\277?@hT\314v\207\342?,J\231\216\372>\333?\320Fr\336\t\245\303?\206\314\024gKf\355?\000\324\002Ryi\274?\374\362\262}\243:\302?H;\rz(\212\326?\352\225(\306nP\351?\256\025\350mk\376\352?\330\246\371\221\024\322\315?\020\214\005c\020\304\345?\022\022\033EH0\350?\200\356yk\237`\272?\264c\007\025\n\252\311?P\023\325\334\330\301\347?d\307u\016\243\263\315?\244\364\t\014K\244\305?t\306\377\331\265\313\350?dM\343\277\030\020\317?o\261\251p>\324\341?\260kld\025\300\317?W\000*R-\375\354?\332\340\336\037!\320\331?xt\310\363\240\265\311?\250~\230\260\330\204\305?8\325\376\236\303\020\350?\212\347;7\216N\334?\221\000\202\031\317\302\355?h}\004\276\246\233\330?\300\014\331\0079\256\264?\277@6Yk\236\350?\331VP(\310\244\341?\276M\'\017\361s\325?WK\212A\246\\\346?\277B\204\240\"\272\357?\207\2201\341\253?\020\272\243\345\344\373\353?\364\211c\352\036\022\331?\332\333\303\332o/\320?\226\005\003\022\317\351\355?T\241q\213\217\235\342?\200E\016\016\007G\202?\356\371\247@\347%\337?\032v$\264\344\262\320?\003*\265\235\232D\342?\332\306)\003*\023\340?$\2300\214_A\330?\344\024{({\252\316?W\351\337\307W\367\340?\3226,W\\\204\352?\373\342b\371\205\271\344?\200\267\276\034I\003\224?\354\345Uy\242\254\311?\032\335+\377k%\340?\003|\224\231\243\333\356?\001\ro\352\350\366\341?\220\225&\316u\232\345?\005\005Bn\331`\342?B\274P\367F\200\340?\307!c\222\225\262\354?\003\260\235`\212\260\341?P\271\016\265\343\014\274?\343\006;^\344\215\345?\244\343d6\257J\320?\320F\311\362i\022\260?\341*\335\340\3260\355?\277\271RS\215-\353?C=\013\375sm\347?\254)\226\032\360\355\353?p\201\211+ 
\330\324?(\305\233\302{+\351?\2007\374\256\364\255\267?\350\252>\005\360\306\267?\312\363\310\203\311m\326?\237\274\200\300p\327\340?\244\235\035\252\260r\350?\250\314\023\356\265\225\327?x1M\244H:\346?\347\000\333\353\336 \353?\332I \333w\312\336? \330\351\250\334\350\322?F\333@\264\361\037\352?\270\202Dv\266\320\312?\272M\343N\233\322\340?\230n\357K\375;\273?\210J\357\\\007\214\307?\205.y\331\233\255\343?k;\217\252\351\210\346?[\317\255Y=\\\351?l\316\311U\276z\332?\326\205\365\376\244\355\352?\321\262\204\']\r\343?d(&\026$\000\320?\340\022\330\272`\327\235?\312\320\205y\317\217\326?98\322Y2\344\341?:\220\274\013\274\304\341?\246\307\036\177N\311\332?`\264\266\305\013\252\331?\242 Q\034K\217\351?us\342\273[\204\345?7H\217\244\252r\355?<\004\214 &C\336?\320YX\2629\222\263?\340(\323N\326\232\220?\010\017~\230\004[\322?\010\237\036\177\352\203\302?\364\313@\322\025-\316?\364\326$\213\353\220\316?\307\215\366\006Po\357?\002\233\360BM\263\327?\201\355o\310\364\252\357?r\200\314\367\231\271\357? L\250sf\177\322?\200\024\322\304e\336\227?\204!x\372\010\240\307? 
\253\214\217\352\202\274?S\274\207R\361\366\355?\000w\035w3\332\325?X{M\300\372\247\305?q\263\376\207\364\023\351?x\222\010\232\025\260\335?\342\264\325P3.\320?7\243\216\334\230\037\353?\306\377\314&\255\310\322?T\262\253\242\234^\352?\232\334Z$\305a\351?u\2535\276\325\263\342?X\210{]\340\312\350?\322\320\2016|\366\326?P\221\354\236\2718\311?\373\264y2\n\237\343?V\026\033\360\273\263\325?\360\320\212\002\227\217\271?\034\372B\032|_\337?\002\257{\002\004\341\353?\204\\FlK\356\300?T\013\377\264.\337\315?\371\244\0319\333\347\345?p]\220\230\315\177\356?\240\221\356?\001\341\267?\252\3754~\374\220\335?V4w\341\267\304\327?\036\276\211\002\323\311\357?e\261\344m\266k\347?\324\255\2326Q\006\334?v\307\347\262w\250\325?\260\020\023\315j\237\303?`H\270T\315\315\240?C\355\007\006\300\325\351?^\371r>\324\004\333?8\2002\r\177\037\264?\024mX\213\025\t\337?\237kl\t\371\277\344?\374\017\021ug-\303?\370%\277X\330&\343?\215\231\325\020\302\214\340?\300\232\264S\225\177\306?\230\017C\3064\030\353?\334\007\335\031 \000\354?\366\267\237-d\306\326?\250d#Z\305\301\334?\200\252\2278L\033\251?Fq\306 \232\234\347?\270\227\242\275\206\310\276??\260\204OT-\350?\322\352\222\342\303\275\337?\357\321\307A\241g\347? 
~\327\333\2234\265?P\300\177\352\262\276\330?\303n\032\211\252\317\346?\230\230\310w.\265\310?\240\0234\033\021\344\337?\327\034rl\241\345\352?c9\006k\020.\343?\347j\373\302\214 \357?\262\250\217*\353 \345?\2251\256\306\347\362\342?e\271\257\027\002\254\356?$vcCoy\352?\267\312\324\306\325l\347?X\354\364\303\\\316\326?\354\242\266\177\241\233\320?h\332\242\230\243\265\260?\014]\251\344B\340\337?\364\343\355\3036;\307?D\230z\206\204?\320?\200\347\366G\203\240p?\232%,9\203f\321?\000\232b\352\323\332\270?\310\375\217y\335@\266?0\310\231f\301\340\245?\010\332s\007\357p\350?\344\346$\373\3539\320?\373R\335\360k\226\355?Ku\274\361\rY\341?\3440+\017\212\230\352?>C\034\254\347\356\342?TV\332\241\007\241\307?\240_fg\026%\356?\276\253\254I\014\236\323?ms\007N\267\314\352?,\227T\323\320\202\333?dn\266\014\214[\336?`\313\327\3166\224\261?0\302\005tQ\251\300?\036C\350\310{T\356?\360\244}\036\260\207\246?\202\215Qa^\226\326?\0003\327GV%\261?@\025\332O\315\364\327?\2044\311^\r\353\306?Us\246yfX\351?\304\227\255\205\327\232\315?\274\353\237@\246\323\351?\016\t\261\2245\207\330?\224\226\366\341\370\357\300?\210\3039\361\004D\315?\374\021\216!\205J\302?\300G}\227No\242?\346\017\276\260{\370\351?\206\203\241\310J\321\341?h\356P\241\177\330\260?\220(\316\r<\336\330?]L!\231%^\354?\351\326\3709\337\313\357?0\324\2162\351\347\242?\013*v\025\034\363\346?\016\023\345\022N\000\352??`\377\216\363\342\352?)\003\356\276\241@\343?\010\222KF\374@\321?\230x*Ju\202\304?\266\221;.\243n\354?\2334Oy;j\340?\371\326j`\325\230\344?z\002.\305M\310\341?\240\035.]\036X\236?\305k\262\2404\004\353?\210`\227\237?Z\302?)\313\345(o\020\340?\303\'n;\032T\354?\362\003\317\016\345-\350?\0148GPA\352\357?\225\237Y\023\000\371\353?ry{\322\251#\321?\204\003\270\360\304\035\354?X\347\330.\274\205\303?X{\274\360\315\312\355?\267\273/\240\346\355\342?\204\201\357+>e\333?\274\014\336\367\353\342\313?2:r\224.\255\325?b\005\341\002Kn\321?\026I\225g\303\002\326?\347\342\330\336$\226\345?\246Sd\343#\260\353?@h6i\225\026\317?\35
6\306\344V\320\201\335?\220M\014\270\333\361\246?D\260\020}\2121\343?\254\252\027o\321\030\317?\360\312%*\272H\330?,\177\341\022\376\213\311?32\313\356(\222\347?Z\364\365_a_\320?\270\025\300\323\216\344\315?2\315\033\363\354\376\345?\207\345\020\006\342(\352?\265]\211\344\267\262\347?@\353\030\227@\'\235?\022\246l\353\370^\350?\340\362Q\312y\360\232?\r\340GA\306,\345?\222\317\316\007ru\320?(\2525\260(j\260?\310\343\357\037\003\006\306?4\020\013\366\277\327\312?z\230]uVU\351?%\001\217{L\344\357?h\251a\245\247b\346?&\225\310\014f\035\332?fE!\321\344Z\343?\\`\006si)\351?J\225\213\314g^\346?\021\221\306)\366I\351?nM!{/J\341?\2205\374ur\245\241?\"\317S\301`\267\320?\007\033z\263\353\013\353?~\014\216\316\3609\342?vLT\212\226\303\340?\212m\344\303?\350\324?\210\272\010\311\205u\314?]\303h\314\216l\357?\3477r\277\032\251\347?\207\333\305\312\304\022\357?\204\016c\202\200\356\335?tn\351\361\324\017\341?\351\362\'\335W%\347?H]\377\273\235=\273?\300\033|\021\252\360\215?\360\337u\273~B\356?\220\220w\257\217p\247?\2144\016\254\201\034\323?\213\311\202\357\232P\354?\240\201_\371\223P\237?\034\353\356\322\330\271\342?\200T\364L\212\322\312?\367\n\t\033>\360\356?i\364$\307\276w\351?#\035C\342\260\277\352?4\241[u\253d\331?\370VzL\3018\265?\220W\021\206g\034\276?\345\004\245\033\305\014\356?b4\376\320Zp\352?$\227\237\352\327\332\317?\227\226\216|\351\275\356?\357Y#\374pQ\351?U5\310\024\357+\342?\350E\252\265\233\365\276?\300\316\365}m\275\321?t\020p.\355\323\342?pb\"n5\277\327?\340\202F\026p\r\347?j6\314t\300\370\331?f\225\002\235N\375\321?x\354s\0010x\313?\230\203\310\220>\330\317?\377UU(\3613\350?>\362\223\037\343\222\352?bE\235\271\366c\322?^\253@L-v\351?\000!7\240K\312e?\372_k&rA\342?\\\031-\213\266\373\335?\370\341z\257\251\251\337?\270\031\023?\3440\327?\236\0331\262\207@\357?\210\322\237\300o\357\277?\235\243\343C\266\002\356?\222\217\020]\224@\356?\000\374\346*\t\355\342? 
]\004\251W\311\255?\316N:D4\333\324?\002Iw\037\0106\355?\226\243\025=\257\303\321?\013\233\017wd\241\343?c\027\312\346\022\321\357?W@l\014\363\312\353?:~5\220\270\363\336?x6o\264H-\346?x\221\314\357\'\276\277?\204};\316\035M\307?\374b\300<\255\337\323?\332%nl\203\205\353?\231~\370\027{Y\354?\214F\211\264\000u\316?l\023\252\005\367\314\300?\347-\247\022\337\304\343?\251>\t\212\026\215\340?\270\014\321a!\355\306?\235\271\310:\376\233\341?\350\363\273MS\223\312?\366M\010Y)\246\332?;Q\331\353\367\347\351?=Q!%\377\234\346?\007\300>\341Q\254\345?\220\000\266\310E\325\313?U\2230\001;\354\346?\274jm\256&\264\353?Bk\217q/\030\337?\027R\225Z\"q\352?\276\357\030O\326\244\347?u\"/\262;\300\346?A\213\254}\177X\340?\240\373\031w\233\322\300?5\344\231\216\220&\352?\r\032\266f\376\223\345?\020\317\007E_\224\335?\360\307\230\247\230\232\337? An\330e\353\263?l\322:\360a\337\317?Z\212T\205!\333\353?-\242\341|^N\353?(\213>\017\225#\311?\000\225i-\255\277\312?\264\353P\377\031>\334?\377\327&\226\216\217\352?\257\260N\260\035a\344?P\223\001T\027\366\342?\252\311(\370\261\216\334?\356\314\355\3761y\346?\315\331\323(k\217\344?\350\022\032\266\304\245\324?\255(\323\024\265\244\351?\006s\346\235\377\210\345?\316\244Z\307\221\252\351?\372\243\344\364\306*\332?(+|\225B*\315?3\314\350CPW\352?\020)\222F\246q\266?\350$\3450e\250\264?\375\320\342%\204\350\352?\360~n\326\000X\311?`jA(d[\346?\200|\025\211s3\224?\177\\\226\363(\344\346?\374\210G\373\362\212\355?|\345\235\362e`\347?\324\257\277EK7\305?@\236\341\252\261N\247?\224\366\375\0034s\333?\221\241\2529\000\277\352?\274\022\230\356\335\271\325?\270\232\350\207x\324\327?l\325\373\247\256\316\332?_l&\227\3625\343?\262 
\365S\0015\334?\256\366\267\334\233\366\352?V\356Z\273\367\341\343?\340?\340\253\333\363\033W\265?\003\033\253\224\201\324\346?\360mR\216\214\257\337?F\177S\r\212\277\340?\027i\343k\007\347\356?\341F\200\252\201W\351?:\035\027\353\3132\351?\211\217\314\r\362+\347?B\r\212\202\356@\331?\232\216db\241\031\334?\313\236\307\230\337\353\345?\234\t\370j\0142\331? \305@\306S\346\247?\256\201\"J\337\007\355?8[80\222\270\301?&\344o\213-+\353?\375\\;\037{\324\353?c\303\2312\250p\355?\254\243a\213\223\227\334?9\201\266J\256G\345?\3624[\371\307\256\340?\222\322\326\201\273[\355?\3506\222\337\003Q\352?\021pd\310\266\025\343?\006\014]\3053/\341?\037\267\255\301p8\342?\364\265@\264U\035\312?\335i\361\'\031\377\354?\r\010|\363\'\350\346?\210p\351?\010Q\343?\000\267\302\324\300\037\246?\372\265m\257N\277\342?\370!}\234U\236\324?\344?\215X\352,\356?\247c\204\312x\206\351?@\211\\_\001\344\205?\263\010\331\220Gq\341?\'\334\234\230\211\347\356?h/7\006w\250\336?)\000\2507i\310\347?\216\013P\324k\326\351?-j\340\275\232T\345?\000\231:\307\373\262\342?\250z#B\272f\303?`A\034\265\256\024\343?\000\374k4O\327\322?`\2561\251\236\337\336?@\342=\310\325\241\333?\302\010\024W[\250\325?k\237T\026\237@\355?\276UN\300(\n\324?\235\203Y\374^\326\354?\036%B\257N\211\322?@\263*@\377e\231?H;\211\336&\232\347?\340\014\372\301X\327\307?X\207S-\371S\344?^\300\305\034\337\005\330?\033\266\3727E1\340? 
\352?\356\"\253\256?\302*\210J?\310\331?p3a\325\234\306\265?!\271\314\216\206\236\350?8H\250\377[2\326?\376\332\253\366\036\261\344?\"\324\005\010\245\323\356?\237rq\323\232,\347?\002\374\237x\337/\326?\013d\343#b2\344?41sz74\304?\314\'\255j\310\243\317?\250\310\303W\335\272\352?\340\262\0309hO\256?\365({\361\372-\352?f\331a\2079~\330?\243\207Tu\216)\350?\216\202#\323\337\202\323?\210m\032p\007\323\267?\344&\352\224\201\301\334?5\016\212\266D\203\343?b\276\3223\350j\343?}\037\273;D\215\350?xn[\251\366w\266?\344\377J\313\272\211\343?h\251b\013O\236\321?\320\033\220\031\230\373\242?\320\263N\363=\351\271?\246\250\000\177\267\211\344?\313Io\300\007\337\351?\303z\272\252jQ\357? \342?\334\266\202\313?\t.\367\3633\205\343?\360A\374\353\213\\\344?\004#\331t\225\236\356?l\316\210\246GN\302?*\001\210\214p\024\332?N9a\254\2726\336?\0263\336\272\344$\332?\020\220\'\237%}\320?\3708D\327/\031\332?\273\300u6\0056\355?\224\277\361\203\347\227\327?\377\031\214\t\320d\353?t\247\360\346\215\335\307?\224-\020u2\255\350?D\020\005?\010)\321?\370\335\335\261\334\214\303?\'\037\362\330\260\010\353?6\302\365\203z@\335?\320\276\276\037\002,\346?v\034=\3137A\333?\314F\325\375\225\200\323?Hk@\230\300\242\343?\325a\202\323\221/\354?\267\\q\377\372\371\342?_[\202\201H`\343?\361%\274\356\225\240\342?\341\235b\277\372\354\357?\"q\311\310+\252\320?mI\212!\344?Q\212\350\231J`\356?\300\'r\343@\n\325?4<\276\223\020\222\324?\214k\3501\007\370\341?\264\270\0305\027\335\351?\372X\245N^\\\351?l\350\354Tz\t\337?&\215\270\270s\327\325?@\352c\334\241\350\317?<\261\303\263\"O\311?\325\300\204\225\3533\344?\013`\301Z\"\324\351?\2334|\363\031\010\352?\245\260\363\027\301\276\351?\241\216\247\243\303\004\340?\240\323\t\3664a\307?\214\206*\210\353%\300?@\360Z\034.\354\230?PCs\326\r\245\253?%8\327\003?\002\354?\361\003\336k\313\n\346?(,}\037Y}\337?\357\263\226\2200\344\355?\240\277o 
\242\017\254?\230\366U/\004T\341?R\035\022\342$\352\347?\336inm\265\007\345?\260V\312\2774\377\340?\215J\035\'n2\356?\257x[be\205\344?n\220\373\235)x\346?oU\301\264CD\350?\250\231?\t\315\267\335?\224>\tQ6\361\323?\232#$\222YE\347?P\352\222t\305\366\323?XQ\003\022f\220\262?\354=\002\002\\*\345?\230\017!:\202\342\314?F\313D\346\242\261\334?\331\243\250\014X\226\345?`M^y\360\320\247?\322\035.4\235\266\331?\030\277\351\373\'\242\270?\210W\215e\350\314\312?)_\353\336\243h\342?\210Z&8\371)\273?\310P\030\314\253[\302?\334\375Z\210E\252\300?\360\234\374\324\344?\250?\316\253o\002|N\321?\210-\225\261\331\335\340?\216V\"N\014G\343?\204\226\316\277\304\231\313?X\250l`W\333\300?=\355N#v\242\347?\360\306\334}}|\260?@D\334UM\251\313?\024\3219\261\353\207\345?\216\363\324\274\210{\353?L@&\231?#\332?\3056/$\010q\351?\244\237t\333\r>\351?\300\250&\325\376]\277?>z=i\2649\350?\013\265\257\377\205\225\345?\217V\232\204W-\350?\376/\247rrh\346?\240\347vC\227\013\300?\tRq\022\007y\344?Zh\274\252\267X\333?H\223\253\244\014\256\314?\000p\033\262l\325\354?\245\2205\253v\367\352?\264\352\314M\330\300\305?\257r\346\026\337\032\355?\306\216\203>\2712\350?D\242\032A\017k\303?@\245\355\234\220W\321?\2566\016\003\343m\330?\346\206o\321\027\343\356?\350~\025\367\254\215\336?\331?Y\375\375\372\355?\320L7\037\003\355\323?\002=\024G\316\025\354? 
\211\030\'\242\350\253?\237\365\230\246\300\205\341?\322\037\261|\272z\323?P\252\0041\357m\255?8\344\330\022\354\255\342?\360\377bf\0050\334?\344V\230jB\356\320?\321C\310*/\306\354?\202\340\236\312:P\325?\014\027\237?\336\302as@\205\357?CgKV@w\346?\222\356\331Y\301\006\346?\031\265B\345\206\340\340?\000&$\243\336s{?@\361\2010\270\304\300?\021\366e\031I\275\342?\255\250\264k_\177\344?\001\251Us,\307\352?P\rB77K\326?\240\301\204on\242\240?`e\272l\222\202\314?\217\351:\332\t\327\355?N/f\237\311\312\330?\0342\322\366o\353\327?\0228\362\247\3425\346?<\321\362\3661\233\311?D\205\375\\N\026\344?X\366|\352[i\314?\200\023\t\373\244\204\204?[\275\330\366\317N\344?\017\302\036C7J\355?\200\263\245j\211\376\255?V\010\326\342\320*\356?}\250Y\311%,\343?\273\366\303\n\001\302\355?P\023titn\303?\240T\245\364E\206\242?|\244\005qy\212\322?\032\007l,g\301\352?k\271~)\232k\350?\020\300v.\262?\266?\000s$F\371\246\340?\243\315\t)D\377\357?7\272\352\231@\272\344?\010\370\315\013c\367\311?T\253\"\250\033J\323?\340\307e>Y\321\231?\313\241\231\352\331\235\352?p#\361yR\236\324?4\325\032\r\036P\353?@\002\0163\246\352\326?\300s\264\177\264\343\322?\003y\362bu\321\347?H\014\212\352!\347\272?\034\352\246\004\206\343\331?\241\200\233\351\023\356\345?\323\203\304\020\255\330\346?\025\251\367\014\031\377\351?H\227G\027\242\240\274?H\315\244\2504\352\301?W\227\035\006\371\257\353?>\270n\246\":\324?x\335\367%\234-\271?\340\325\364\346\007\025\254?Z\236\226\002\270\250\356?\230+\212\243S\271\270?\256\275\325!vT\351?\"K\303M\004\251\324?\332=\'\000\031\217\327?\321\t\264\344\216+\354?E(\316\002gx\356?A\264\351\346\200\247\351?0\025\347\322\020\246\315?\216_$$y\315\345?\302\252\n\321N\t\351?\307\227X\234\024e\347?s{9\002\300\243\346?\254\267f\237(\032\306?0{\022\272\374\217\255?\370\232\212\324\217\t\264?\250 
\275\253\251\224\347?n\"~l\033\343\354?\224nmi0c\335?<\244M\321\304H\327?\274\371\037\371\375\343\352?\230\035\2167\255\341\337?O\200NZ\2478\353?\220\346\325\202\341,\265?\360A\027-\243\344\270?\236\275\333N\311\343\324?\260\363}8~&\273?0C\243\366\3017\345?\216Y\351\026_\276\334?\302\0316\251\260\303\320?\002\277\263\330JW\340?\207=\221V\376\016\347?E\217\"\242\0313\347?\3274\320\357jL\353?\270\250\013z7E\351?\330\212\000\374\302G\260?\370*\363\271vm\335?\013\323\332*\2300\346?$\304\213!jg\335?#\017\005#\313\346\352?\004\254\343\006\357\023\313?\250\254\016\362\362\003\356?X\217\240H{\263\273?0\033\345\322\365\201\302?\302\021\367\276\250\343\336?\300\326\025\005\311D\343?U\306v\340\343\230\350?\244\365]\002\202B\316?(C\366p\363H\274?N\003a\300Z\211\320?P\313(\341^\262\243?\214\330\314\244&$\300?\302\333\207\372}\200\357?\353\251X\227\251\213\343?3\225#\004\034\245\340?\230\226\331=\023\331\303?\'=\263sT\331\341?\275^\325\243!]\357?\354\317\211\306\345\225\312?\347L\250\021\221\021\344?MStw*#\346?D\223\234\006\025T\313?\025\266\2112V\311\346?\250\205@\326\313w\261?v\232\tY\340\260\320??\374\245\306>\271\341?\315\216\306\233\304\215\351?Fs\357i\257\366\347?\000xX,3\006F?\375\253\203]?^\350?\200Ck\346^)\257?S\250;\306\367\307\353?)\324\031\031\021@\352?-*\314&\271\007\356?AOW\025\014\227\356?P\0235\022Ed\271?\014\226\363\374A\273\325?\2016CCy\014\357?B\324^V\276\025\354?@w\230\222\325\226\210?\374\030\001\0024\027\354?\324\2579\206\321\263\337?n\373i\017P\275\322?\350J\220\010s<\277?\017\377\357\313t\262\351?\220\372\233\027\223s\355?3\253\341et\035\340?+\307\344\256\364\371\341?\034\276\3103`\262\351?\301\222\016!Zn\354?\304\rM\331\257\257\340?\364\316\232\277\312q\301?\r\336\265\320Lk\344?\254\346\334o\374-\300?d\265\333\344\307\243\351?x\016\306}\030\273\354?\244\0269RA\'\302?K\371bl\317\203\355?0\013\023[\305\306\346?\252\332\232\250M=\340?\374}\307\247\242\224\343?\250\230\254\217\237\026\320?\002\274\200u\256\375\331?dF\\\331\246\203\344?@>\224)\274\020\202?\330
YZ\025\201t\341?\306\315\227\335\207f\340?8\233\"\250\037{\353?\350@?\364\330\344\326?\240\261\307\365.Q\226?j\237Y\t!}\356?^\350\261`\211z\351?\226x(\024\317\374\337?\222\311\372\273I\016\347?\226)\303\013\n9\346?\334\260Z}\305\037\307?\324\206}\277\232B\341?\243_%onF\352?\2106Q\036z\251\332?\351lk;\033\232\341?\310FA\023\201L\313?f_\347\311\007\204\323?\245\3456\205\303\340\347?n\0228\317\213\006\331?D\343\361\253\206\343\356?\334\010\277\026^\r\316?N\277\364\353\003\302\327?\320\241\262\223\374;\353?\233\315\341\354\321\007\341?\016\301\237\235h\320\350?\"\2660\033Q\321\334?\356s\207\013[\023\327?\327\221\376\370\351\376\347?`\373\342wZ\333\226?P\003\267\264O\310\320?\266e;r&\005\346?\230\261\034\325\231s\331?\212\313H\r\223R\335?\231\261h\352\375\017\345?n\355\2115\340<\347?\354\032&%\031\224\320?\323h\266|\005\227\346?H}\027\002D\240\342?\240\004+\227\315\315\315?@c!\336\0355\356?\031C\315w}+\352?*\201\3707\'\237\346?\350\243\202\216x\217\260?\246\027\n\325\275\201\334?1K\264a\273(\344?\320\370)\330\242\361\305?0\'a\330sG\322?H\255\370\345V\251\325?\266\215\031\346\222m\327?\206\t/\255iV\334?\241GJS\371\276\355?\221\325\320\036\334\025\350?@`\006\300z \244?\214\343Bl\007\350\321?\\m\032\2475&\356??:\253\244.b\340?w\274\305\275\217\212\357?u\373\326D\021\227\347?\350\272\365\010\306Z\306?\010\361\001\276\200\356\345?F\206%\321S\013\325?\230\363l\203d\t\357?\237+\261\306\266o\346?\306\255\242\377mJ\332?\256A94\275r\357?tN\371\234B\262\354?<\221\207\315\243\332\354?\324\323\2317\2325\357?\214\3044\021+ 
\323?\310\177\217\2010\244\322?\324x\232\256.\237\351?WF\351\007m\340\341?J,\302\264\301D\343?`\033\227\022\264\347\350?>B\035\214\261d\347?\340\203C?\252t\263?\304}#\222\223\240\346?n\030\302\371\364\326\327?#+F\235K\376\340?\230\340)\001w\020\306?3^w\001\3472\345?\200\013F\235T\243\345?\230\315\327~|\225\331?\323w\320\276/h\345?8\215\313\035\304i\275?\220\325\244\027@\243\260?\215\021\325\177\200\322\344?B<\212\214\004\246\331?\350\026\\\306\017}\333?\024\274^\033\252,\335?,\203<\334\364Q\312?\210\2431\007\005\205\307?R\2678\020\270\201\337?\260\354m9\274\303\324?\327\302\304\264\007\036\342?Zh\371+\0313\352?U\210\205\372hE\352?\024A\316\367\211\242\346?HG6VU\010\303?];\022\344_\005\355?$\336\301\364L\214\334?\316\226\035n~\035\356?I\235+\367H{\355?h\266\260\212\277E\326?\000\026\222\206\313\362\217?\300\033\330\177\037\216\272?\360`\352Q;\214\326?\346\304\353zo\344\357?k\2014\025\020h\351?\320\346ULO\037\331?G\021Z\363w\200\344?\214v\207lg\221\323?0\327\346z\322k\267?l\243\205.\220\000\337?\264\021%\312\377\301\321?\3240\242\030\315\342\307?\374\251\311\341]+\321?EC5\224\342\255\345?\317MQ\032`7\344?\376\3742\274\002?\355?\024\034x\314\351\252\312?\231vQ\306\033\326\357?h\365\334\342.\357\275?\366\321i;\231\250\345?`\335\036\030l\003\242?\000\002eI\352MP?\222Pc\223\242\204\341?0\370`\320p!\272?\004\366\333D4\"\304?\000Z\247\n\004\200\306?\270A\363N\223Y\267?\256\312\330\325\223Q\325?\320\371\262?)\030\272?r\343\372\236s\002\341?rf%\261\\7\347?\340\\\245\004\034\211\352?D\334c\3744s\343?\311\216\223\202\315 \346?\252DA\226\220\231\335?i\2323S\204\373\343?0\265\251\230tc\313?\230\323\260I`D\323? \241\332\342?@dh\217\025n\337?\254\370\370\034\270\211\301?\234\260\253\343\366\373\337?\245\013>\352}w\357?\356]\204\313~\214\334?j\306o2D]\343?\341\205 *G3\347?f\3269:\243<\320? 
S\256A\236\267\223?\300\240\207\320\002\330\247?\231)a\n\351\240\341?\370\320\344?\271\031\305?v\261\034?\254b\326?th\332}b\303\350?\311\027iD\367\366\356?\377\n-\212\240\244\351?\004\271+\355q\325\303?bKh\227\246m\351?<\033\234\245\236\354\342?|X$\275\241\271\313?\370\375\376F\330q\267?\313\245!\262\357k\352?\253\364\325\036\216\272\353?\213e\307h\373\n\347?\035\261Q\026\376!\346?\004\361\210U#\300\344?\246\032\216\244\275;\344?\2221\031\340G\337\323?\210z\014(\362\013\325?\001\225\257Ek+\341?+\341\275\374~7\357?\202\273=E\233\032\353?\354\213?IgA\325?>w\337\232\347\367\322?\326-\3238<\335\323?l^>\357\226\236\305?\230\003\363\000\354\264\301?J\301\330w\227-\325?\216>R\352\336P\325?\210\245\330\325K\'\264?\350\3141\345O\373\347?`\2647|\375\311\220?\230\267S\241\267H\272?Y\"\"\020\016\363\346?\305?\337X5\353\207\343\356?\200\277\320<\177;\222?\254\260/d\033\323\336?\330\237\243\242\371J\321?\366;\270\270\t\217\331?\253\326\2361Y\245\353?\340VbJO\204\301?\305S\234#\366\"\350?\376\032\277\007p\262\331?\324\224Q\332\340\322\307?\032\027\356.\215\013\331?\253\331\361e\253.\344?\324\017h\367\321\227\307?\252c<\303\370n\357?\360\263X%\312-\252?\014\267\\\241\235H\357?Y\277\n\r\0140\344?\260u%\304\010T\314?\377o\303\017\331\372\345?\\\317\031\242nP\353?m\220\200\367\346X\350?\271\366\207\372ht\345?8\261\227\335\312\241\337?o\223}\351\035r\340?6w\201\2572\351\323?\344u\364\306\2449\317?N\033);\311\204\357?s\024J\311\271\330\354? 
\315\336\307Z\000\271?t\207!\202\036^\347?2\221\027\261\240\002\352?W\024yOT\230\354?\270\030\334\264$\366\350?#!\267\363\037s\346?\2641\311%n\001\323?}\374=b\r\\\356?68,A\223\302\323?\312\301\025\224\027?\337?i\311%\251\243&\342?\235\301\202\364\032\235\351?\240\242\343\246\250!\220?\330\246\371\037\235$\276?H.\235\n\016\264\340?KC\220\210\301\217\341?\3241+\340\366?\324?48\254\341\227\036\323?RA\025\222\347\217\324?\353\304\371\345\236Z\344?Z\206J*\367\372\350?\234\250Q\n\271\231\332?4_\226\013\356\310\355?\302^\362\317f\312\331?\200\2031\206z\t\265?9\302\332Z\346w\356?\322\031\030\346\244 \341?\340*\200z\207\357\335?F\\\3578\201\315\342?H[\224\243\270\350\276?\210)b#{\304\335?,Lq!\256\251\310?\376\242\312\033\240\020\347?\236\317\245gA\205\336?\0045\272\206\373\216\332?\030\311S\345\340W\320?p\311\034\340D\024\301?\226\251\036\25555\330?\030\226\271\351+\006\275?\003\252\205\215+\322\340?\204\242\362g\214+\345?\377\203\272\016\217Y\344?\346\213=]\317\304\342?\\\313@\026A\365\315?\324n\266\272\322`\335?5\334\244d=6\351?\215J6\205QQ\343?(\267\231\017\255\'\313?P\227\262J\236\316\352?\220\322\334_0\373\264?6\242\341\346\030\375\350?\277\017`\251\360\037\357?\\\350\255\001\274.\353?nY\220\002\3145\333?d\'\255S\0279\327?\335c\267\326!2\340?\274\036D+\202\207\300?\246\210m\227 \247\345?\240/\304\305\224^\240?\254\271\031k\004\034\306?I^\204\307\327\224\356?-\241\222QM\214\340?\274R\220\272\324.\352?\000\262l\t\254%\317?\260\236y\253\177~\252?|\302h2i\337\336?\210\202 
\261\037_\277?\035\233\352\016\225Z\345?\260,R\341\001>\263?P\'c\027\270\361\334?\27211\215\277\374\341?%\276\037U\334\035\356?@\321\200#\343\327\324?\352\271X\363\241\275\346?\224\037\314\304\345\010\322?\255\376Ck\222\376\341?\350\260a\033\352n\340?\305\363;E\362\332\340?,\365^!\337\347\356?\002z\332\010\020Z\323?\310*\n\375\342>\324?\254\370f\213>\333\301?\024\344-\'\357\222\357?\231\016g2/F\346?a\310\260]l`\350?\371\2125\214\344\227\347?\000u\273\224\217N\311?\364\223\230wm=\316?8[E\303RY\276?\312\023\332\303\031\305\334?\357tu[\261\023\346?@\025\216\\\rF\221?|\023\037\001\237\365\302?nn\236+\252\306\334?\026A \376\315\337\347?\277\020\223?5\'\355?\310t\316\354\336\237\301?\366\230\345\004\324\336\344?T\\;i\304\217\310?\000\021\321\014\347\345\350?L0\27407\276\346?\341\363DZ\322\256\342?\301R\235\241\023o\347?\344W9\224\001\023\314?\2213Z8E}\340?<\274\235t\017\247\300?d\264\200&i\354\323?g\334,\264\376\240\341?\200\324\247\3067\205\260?HD\332D\245\215\313?\316L\310\025%\177\331?\022\362\310\277\216\361\354?\270G\317\250\307H\357?0\002!\241\211\314\333?`\304\320_T:\303?`2(Y\3659\223?`\320C\260k\177\252?0\205\017\017\247\261\331?\376,\214I\0314\323?\2422\221\365\362\206\354?\362\030\351\177+\202\330?\230J,\301\\\317\354?\346aZF\t\177\326?p}\177\021\320\234\242?\332\323\211\226\210\r\353?\020t\312\0361\341\355?\272\211\211\340\327e\320?\336l-h\2448\350?\303\232\215\245\335\232\344?a\347\267\214\000\320\354?\214\352\004h\003\343\340?\346\n\263\034\300\353\342?\306\340\3129NU\325?\024\317\204\235Z\332\343?\352\005\227G\264d\327?\331@\231I\357)\356?za\352\265\206\013\344?t\251\245\303~\367\352?`l\224ja\340\234?\200aqv\335\367\272?\250\330\214K\005\007\311?)vy\200\317\007\356?\210\037=\344\000\264\350?\275\263\274mP\300\347?\233c\274(\356\314\342?\270u\354O8j\276?\020\352\316-`\024\262?R\273\003^Zh\326?\200&\221\321\326\032\213?$\301\211\033\rk\357?\037\307\020\315\376\210\353?\252Z\177\270\344\204\335?r)-\034a\314\336?\301\235.\306\016\013\342?\032\211g\254\021\005\34
0?\005\273\212O\303\024\352?8Gu\202\216\211\322?\233/\327L.\230\343?\300\314F\303w4\247?\377h:\323\206\023\354?\032\343\034\251\266\212\321?\275\270\203\032[\213\355?@\003\231\246&\377\220?\256\224\334\017\035\263\344?\274^\210\373Ws\326?>\347\302\321L\266\325?X\214\215n\323\007\342?\345\300\263[\332\260\341?\200\\u\225\215S\265?nC\351\231\216\016\333?\014n\2753\361\337\300?\345\304\r\314\323\203\346?\350\025\256\202]d\306?@\255}\221\230\013\326?8\2667<3P\265?\020\355h\"n(\307?\366\242\373\245\221\225\352?\344\365k>S\202\347?\340Z\005x:i\272?\250o\254\372\243\226\275?)9$$\356\337\352?\350\310\300v\347\316\331?B\353\352\2060\256\343?j\354\024\245\004-\336?\030\205\321\226\256\216\322? \277\3142\214\337\241?8\273\275\033B\202\263?\200\214$\212\276\236p?\030\323\027\261\301)\303?\367\031\262\275,\300\345?|-\t\347\310\223\301?r>\366\224R<\353?\274\026\330\344\352}\354?\034\246(\246.\225\354?p\005\273\006\342#\316?S`\207\243S\321\357?\n\377L\322w\306\340?\321\030{\003\265!\355?\003\024\014\270\322\003\342?\207{\310\270t*\340?\244\274\261\276C\220\313?\024\026F\027\220\263\344?|d\036\223\310\255\313?\220-\010\317\351\020\346?\316\213\330\320\010t\330?(|@\270\330a\265?Kl$\304\306\342\350?\322\271\325[\251Y\340?\344?\364&\204,\302?\004J\361\3568\253\315?\325e\374\373\003l\352?\250\022\256\265\020\201\317?oz\262I\360\030\343?Pk\245\334c\231\246?\020\373\252l\3526\310?\207\274\305s\243\005\345?\202Y\336\022-3\327?H_~R\250\020\321?B\231r\234\232\016\351?+\r\026;\363\217\353?\260\336\201e\350\363\241?\254\263_\251`\322\350?`i\r\242\"m\305?Sk\312)}o\345?\360+\263\026c\016\311?q\273S\201U\t\352?\246\374\214\356Y\357\332?\366(\375\317\356\234\334?\344\334\034\032^\357\317?\024M\311\324\t\331\336?2\232\315\024>_\344?%\035,+V\210\341? 
\3434W\324\266\343?Py\373\'\177\337\356?\220t9\301g\345\354?R-\035\277\343\327\327?\034\234\223\tD\264\343?\247\321B\017\362\200\356?\240Vv\0165\211\324?r\250RQ\334\'\335?j}\025\254\327\034\330?v\303\032k\307Q\332?\340\210\033\347\035\036\314?\340\3537\356\246$\225?d*\205\204\033\007\316?\261XJ\212\221a\357?P\007\342\200\006\332\314?\305UHe\367Y\347?F0\240VW\264\350?\260\251\303\251Xz\255?\364\002\246mA}\300?\361&\313 U\265\354?\220\265\344\324\362\337\324?\250\335\333j\246\272\336?V\330=2\014W\337?@\013.\3603\360\336?\235\375\360\037\374\372\344?\204\200;\237\250\206\346?\n\255p\370G\222\352?pM\347i\2457\347?\\\246\232\313\214\211\344?\020\006\375!|3\332?\370\301\355_\344y\274?\361f\244\350i\351\347?\3107M\274s+\321?tq\006\202x1\306?\3670\334\215RS\341?|\251\225\035\321\277\311?\202\037\306\360 \023\356?\207\364\032\251\271?\344?\364\373e\336f\037\306?\200\nj\177\010\024\254?\300~\232\023p\t\310?\327\352\016nOG\350?\353\271\267\030\330\303\355?\374\300\377\005E\017\344?\030_\250\265\211\201\332?\333Q\242L\211\357\355?\360et\023\277{\255?\006.\256\306k^\335?\274fM\250Lr\343?\264d\325\2335\016\314?\217:\266\302EV\342?\345c41<,\356?\201T\014\363\0341\343?\360F9\023\266\343\245?\200.\262\3558w\306?\006Q\001\2453v\356?[1\367\023\215w\352?\340\207\341:\324\316\223?\220J\014\314\216\214\247?\334\335\352\324\376\276\324?G\211AVij\356?YW\232\311\005\256\344?P\025W\205\233\361\343?(\324\306U\334\243\325?\364\230@\031\035\313\336?\354\016\310\352\351\036\336?\200\210\323\236\325\247\246?\230%Y\032\336\216\352?\016\252\244\010dd\356?\243\306%\013\017\250\355? 
L&\033\241\321\262?wn!E,\236\341?\274\312)\336\266\320\354?\3373)!t\331\340?P*)\037\303!\241?\007\364`^?\265\354?\\\347g\320\245\005\353?\332\315\336aA\327\327?\220\223K\364\307\317\354?\026\333\t-\252\265\337?\211\347%T\305\373\350?\257-\205{\320\332\344?\303QR\327\324\222\346?\020by\313\341q\265?L\335D\274\2060\353?\360\232\021\321\017\234\255?\245\302d\237\016\216\353?\343\377O\261\357\242\351?a\303\265u$\375\342?qC\017\276\375\361\346?\342\031\023\212v\010\322?d\251\275:\241\021\346?\222&o\321^\346\353?\373?\251;9\240\352?/\303\356\017\331O\354?\343\021\272,\216j\346?\004\"\336Mb[\305?I*^\377&\032\345?\316\376\224\026\245\264\345?\204\341hC\300\252\321?\356P%\031\324;\343?\362\213\372\314\2202\350?^;\264cN>\351?\330\222\365M\245\240\331?Z\322\257x5*\346?\035|f]\336\231\353?l\032\261^\037\261\300?w>!5\261\204\352?\236S\335\177%O\322?!\223:\322\037\363\356?s\006\230\272*5\341?\031\036YD\363\263\354?\354\365\025\231\307\217\307?\342\342f\273N\033\340?:\n\346\317\364\211\326?\030q\225\241-\201\311?n\200\003\315:\000\341?\205w\311\342\205\311\354?\033\225Uz\\\332\352?\360\026\322\005\356\t\277?\"\335\n$$\372\334?H\261>\301\277\236\302?\223\336\002\300\305,\353?#8r\246\005\340\350?\216\356z\255\376\205\350?\356\264\240$O\334\332?\267-Cy$\326\353?B\'P\3076\026\344?\270\227\035\237\023\327\277?\331D6\230\253\330\342?\264\211F\354\377W\301?\000\245\t\325\333h\342?\342d\226lho\336?\360\177a\367)\257\317?n@\2743\020\013\350?&\031W\333\277\224\351?\302y\3410\264\244\321?\370r\312\261\312<\325?\340\220\220=\307\243\306?\315\237\3030cM\346?Y\3155\'P\342\354?\222l\277\220\332&\325?\000\004\302\2614\032\256?fa\000T\241\316\342?\270\252b\354w=\274?\320:\274<\360\237\340?6Y.\263=e\340?\374X\352\n\201\255\327?\013\020\253\026\274C\351?\002\255\031\036\3212\334?\214]\313\275\367\233\345?\335\367\236\341\345a\343?\027\031\261\035UR\352?\346\222)\336\016\246\344?@\322^w\'\307\203?\"t\315r9z\332?S\t\300AJ\332\342?\365\025p\027\027\200\346?\\/lx\262\'\355?\004[P\334\272v\335?\003O\020M
L\307\347?\030\030;\235\355\311\357?\372\315\252\006W7\324?u\241|\344\010y\353?\036\026\371S\362\326\334?\nP\265\251G\260\322?\364\361hy\325\233\305? \270\350\270\220u\303?\370<\253E4_\267?\354\361\205\035{[\347?J\255\177M9 \342?d\035]\021\006H\335?\310\007q)\220\233\306?W\256el\332j\357?RL-\3618\267\321?\230\354\311H\005\003\330?2\315\235.2\010\341?\274\024\332\344\236\r\317?\204\354\236\236.u\323?\320\270g\024V\366\273?rx2\025\307V\356?\224PN\034\370\237\335?\366\020WM\213\201\351?\364)\321\274\272\006\326?\226\235o\224\366<\351?H\335\227\013\005s\276?\006A#+\005\315\340?\200\230GI\374\207\261?@\200zbY\022\236?\\\303\233\351\325\014\306?\204\266\"n\024\316\336?\253\261\312\r`\207\356?\030\257^\367N\220\327?\334\177Y\212\005M\340?,\245\177o\314N\357?\023.MR\013\016\356?\336\013\036\237\322\035\352?\021\260\302qs;\351?\333\017\224\334\303\306\350?tZ\316zs\177\344?\230\3616\024\216M\351?\260IL\003\327\177\302?,@\363\344\036\253\303?c\320\n\243ce\357?\305\251B\363N\254\352?\362\260\272\000\265\207\332?\000\356i\376\2709}?\002\262k\2700\177\330?\340av\013)\025\304?\202\021d\376\207u\331?\256\246\220\233E\272\357?\3009K;\263\377\252?\336\234\264N\221t\357?\300`\203\372Ux\326?\030\276\216\370\037}\275?\200r\277\373[\364\307?\222\'\312\026\266\005\320?\252\341\361u\2102\333?0Si\003\347\233\333?n\224\356f+\306\355?\304c\001\232`8\320?0\251I\207\031<\342?\020W\362\301\237\001\251?x\377l\343\210t\341?\\gh\300\020\212\335?Pi\031\233\002\t\253?\0204\254N\266<\315?\000t\006\313H\307\334?\254%\212P\n\361\310?l\035*\243\262Q\345?\001\312\303v\032\325\345?\320\374\023\220\261\202\274?\233\210\277\323\310\345\351?\030]\204\210\302\257\333?1\336\304\266\366`\342?Pb\014BuZ\314?\214\021\367\300N\334\325?\204\364\177\221\316\334\317?\010^\204\177u]\330?\016\017\276\021f\360\336?F\363\344\320\217I\332?\272+`\337\210M\340?i\316:\017\274\232\354?f\334\354\230\320\340\347?l]\371L\303l\342?\210\332E\375\273\252\307?\312*\252\276\310\246\321?\357\302\376\222\336/\343?p[\323-\277\205\313?dG\2
63\021\325\230\320?`\000\326;\214\366\252?${\235WH\253\355?(~&\263%\377\310?\301\236\356\362+\246\342?\356\355\3625\"n\346?\007\216Ci<\304\350?\370\202\365N\022\223\331?\232\277D\001\262\303\347?\237\367K\351vg\350?P\351\363e\247,\340?T_v\337e\325\330?\020\303\216\343z\227\326?\307\211n\320\224\036\342?8\232\252\363\341\374\273?\355l\235@KP\345?\212\216V\204\r\341\344?T\311\271\373\354+\327?\320S\205G8\007\353?0\310\366\205\356\211\303?[\373\213\\\221k\355?$\274\237@p\300\343?\353\352O\250|\200\357?\325\025B\350\'\300\353?\350v\037\256\014\'\323?K\003\201>\027\375\353?\206\201u\254F\223\334?\320\256\22457\337\346?\246\231\210\013\315\356\327?}\371P\254S\355\350?\021\363\213T\273P\344?N\330+\260K\004\325?q\374\346P\377\224\347?\2745\324\353Kz\347?8\204F\302\033\205\316?\010*\032w\313T\352?\210\213+6\351\360\275?\334r\341\023q\014\321?\364\264\242\037\333\355\353?#\332\277\250Y\010\342?\370\340\224n1\364\353?Y\264\344y\0303\343?\330\031Mw\200\n\350?X\216c\024\244\273\340?\006\266\216G\337M\333?>`\340TgQ\327?\250hk\202\362\336\267?\\}v7\235\327\322?\230\347\366_E\345\325?0\373pl\356\005\257?{\334\317\253\242\262\352?\316\275\211\314\220\303\347?7`O\305\354\002\341?\310\377*k\306\375\317?\371\340\204K\263\333\340?@\250{\262[\177\210?\222Wk\010s\025\341?\370\220\371\232\013\326\273?\212c\032\350D3\321?1\253l\264\r\224\357?\232\\\303\245\330%\320?\1774\256\037l)\347?\272\2777\306\037\031\357?\334[h\257\377~\313?``\206Pg\010\243?1\354\346\346wU\340?L[@\272\323V\324?\004\371\031\246\314\304\307?\010(\311\355\0311\321?\020\324$\320\271\016\246?2\030\201a\262\241\351?\250\241\232\037T\031\317?0\373e\317\\)\244?\210F\231BH;\340? \315\301\333j\346\247? 
\302\246\354\307\306\355?\300\030Z\330RH\220?\r{\200\236?\220\343?2\370\034\276z*\353?\374U\032\323*Q\332?0m\375\206\214\242\323?&\215S\367\240\306\327?\007\030\324N9\350\354?@j\327U\302\247\316?\274\242\352\266\312\342\310?=\227,j\t\324\357?\tq\363V\313C\351?\025r\317\266\035b\344?Im\033\037\202w\340?\206y\330\322\205\275\344?r^\3545O\302\326?\327\203h\223\267y\350?\354\316z\226\351\263\355?\300\373(\002\234G\260?\027\352h\255\216u\355?\350\315$\204\244\"\310?\306Y\327(|\252\324?\330;\303r\244\204\343?|\335gw^\014\333?x\313\314\360\223\265\260?\000*\337\"\206v\223?Z\247@\0166T\330?\322q|\235\272\274\357?\246\360/\264{\004\350?J\'\202>\303\335\333?@\350\202_\233\025\250?\366<\273{\247\t\353?\237c\313\r\010r\351?\237`\373J\346\252\345?\275]\350\024\3605\353?\000\322\235\276\352\340g?\010\226\374\316Z}\357?8\236]\n\367Z\355?\253#\022\245dT\354?\331\275\263\224/\241\356?\330\306\311\220\027\215\330?\271\351\324\004\2278\346?M\254\221E\314\366\345?|\204\013\333bD\352?6I\357\360\301\331\337?\350Nl\277\312m\333? 
\335\307\335\n1\315?\374\341\301\247mC\307?\"K\014C\224(\324?\242\344v\002\261\356\330?\2744^\227\003\262\306?\370\022\t\304g\026\264?\014ZB\266\372R\327?7b\217$K;\353?pm\037~7[\273?\020\333I1\237(\260?\030o\312\036\t{\274?`%\267\303&\234\226?\032\241-\027\032%\351?\336\334\035\266S\273\342?\366\334\353h\006\213\350?\370\374Z\361\264\360\260?%m\207\317\277\017\352?\364\001L\236\026\352\316?\223\365\371\232\326b\347?.~T\213N\371\340?\220C\310\027\007,\251?\240\'X\\_\353\274?0$0\307\030\232\331?\370\0260\276\273\337\326?>\217\204z\247r\336?`J\216\330F\027\255?\340\353yd\rq\357?\364\343\365\233v\346\345?\000q\303w\266-\341?\360z\261*\3132\345?\252dX\2522\206\342?\331\333w\"h\013\356?\205\002%\036\373\374\346?\240\310\311\224{\035\353?\360\223\365\331\013\271\317?\250\315%\224l\013\262?A\274\003q\014X\354?\234\303fji\350\345?\np\266\272\3604\357?\340U\3601J}\230?u\262\224e5\251\353?^\305/\226n\211\336?\205\355\345e\252\321\341?@\340\235A\022\033\263?Ha|\3215\250\332?\001,\t\256}\177\357?%\372DB)V\350?8Q\274:\342\237\304?\000\364\246\031`\ra?\240\"i\typ\262?\340\207\207\340\377\237\274?\\K4\025\270\214\302?q\374\254\2112\370\342?\260:v^\372\327\330?`\r\347\203\351\203\340?Z\375\364@M\324\340?\"\376,\032\277\007\346?\333b\255_i\311\344?\226\\\346\332\242I\321?8\303QV5\361\263?[$Ri\322\320\355?\030m\001i,\217\300?$\037f\222.\004\306?\236%\302~\215]\351?\345\2615\031\271G\345?\272\355\361%\244\333\320?\370\325Z\235\256\310\337?\300\026i=i\243\240? 
\361Mu\"\250\322?\377\331\'\275o\255\341?\003\263\'\177\235\346\356?/\227\324\226\370>\355?\274\032\365\027\014\317\305?\000\236\353=\3543\266?l\323\351\230\301\232\343?\212\203o\0361u\320?\314\000\255\363\332y\307?\220\010\352\004\033\263\267?.\006\n\016\332M\343?\270\346@\360\342\276\310?Mg\230)\233\020\344?!\302\344\035\202\033\340?\214\257LTN\225\316?\340\324_yOd\262?\000\305Rb9!\313?\242~\035\2432v\337?\306>\245\021\305\230\352?\340\2007K_\025\316?j~!\0235\240\342?\\\321\031\361\320P\314?\242;g#\"\005\351?\332:Hv\357x\331?\263\225\3223U\267\340?\330u\005@\243\032\337?\202q\337\037\344\010\353?\320\026\273?\264$\306?`,\305\305\177\203\315?@\\\t\306\000\253\217?\347/\"\340O\255\346?\350\320\376\365\034\221\331?F}e=\321\300\355?\020\347\232\022]\316\265? \212\265JD\032\242?9R\254S G\355? @;\3337\004\347?\200\323\370\345\255\325\275?\314\277\r\267\257\007\352?\330\013f\344\214\001\322?\316\357\257\340\217\310\335?\"\304\014\216\246,\354?\24676\236\317\317\354?:\237\016\244y\214\336?\r9\362\210\276\200\340?\326\210F\341{\343\331?\212`:\013A\371\326?\035S\003\364\314\023\343?\340\010u\363\225\336\277?\306\234\356\362\275%\336?md\013}{U\342?\034\323\225\313,\241\352?\326R=v\\\263\341?G_\204R\272\223\350?\361,\320\257[\204\352?\376\013\241K\226\277\322?\270Nf\260\032\203\310?b\320MU=\324\350?\2730v\025\200\357\351?w\302\0173I4\352?\302\033>\343\365,\320?\232\245\311\242+v\321?\232\221\"|\243\246\340?W\274\200a\336\003\343?|\341\014Y\244\263\327?%\270\342HQm\354?\260L\353\372\304\354\246?\275\302\266H\332[\343?\242N\030Y\027o\354?\324 C\036\231K\355?\377\270QX\340\362\351?\316\336\370V\236\311\333?\300\n\334\034\263\236\223?j\273\266\336\341\316\330?8\200?\355*\353\350?\260\263\226jC\006\246?x\237\307\351Y\302\266?%\t\206Jp!\356?B\363dy\024\336\346?\032[\265\246>V\335?,\332]\270\0108\353?\004A\007\255T\320\331?\343>\343\221uv\356?\224\024%\367c\311\355?#\177 
A\261\004\355?D\244\214t\2651\311?\334\315x\354`\315\323?rj0,V\t\353?7\363\356\232\244\022\342?y\225\273\027+u\342?\240`\036\277p~\253?x\210\367\301;\231\337?\232v8\020\031\363\324?\350U\023a%\020\325?\347\014\353\345\232\262\347?\220N\342b\264p\347?4\213\226C_\020\345?9\0172\\\252D\347?\270e\277\264\320\371\267?\300\341\311\002;\277\263?\315\245\325j*\340\351?\260=\211\334\006\034\356?\200\320\356\272\205(\347?\230!\027_\253D\315?\303\256\033g]$\345?x\241\002GL\244\347?}\345\360\202\213\356\357?\250%1!\206\315\321?\353%a\3616^\340?\377+7\242\352\014\345?\274\253\006^T\202\315?\340\232\224\370\361\233\231?Hm\302\016%\301\321?]\224%}\335f\353?\310y\235}\205\r\332?\307\324\033\315\263\271\356?P\217\247\376&E\355?\252{+Z\255\254\333?@\033\265\367\371v\351?\2408S\276.\326\275?\310\270\262\205\345\027\340?\230\357y\016\036\311\305?\350\275S\253sz\273?\030\000\302\024\2207\272?\230\277#b{+\342??\242\335\342\004\004\354?6X|\3576\243\333?\374^\206p\365D\333?]\270J\206\020\263\352?l\037eW:N\330?\263\347~t\210\377\350?@?f+\213\236\224?\000\010\217\324ap\337?:p\025\253\016\301\340?$\343\"c\020\037\311?f\337\370:\025\377\354?\204\242\257\272\023\304\304?\010-\307\210\\\036\266?\000\245y\257\317\366\216?0\347\274\005\nj\303?\340\222\321|ds\354?\224[i\274$~\331?t\035N\276\032l\321?|\350a\037\314\365\347?_S\247\275\004\t\342?\376\034\354)#]\333?w\273\263\220\245\356\345?\225\201\341\003\314\310\356?~\215\316\320\325\201\342?d\314\351\003a\232\302?\361\034i\367\226\237\346?\346\033\017\337\262\"\354?r&\360\3143\312\345?\362\261\247\240\"\032\330?)\352RG\276\356\341?:\370\313\315\270\222\355?\025\260\325\336e\005\346?C\220\364\007S\024\346?o\001\024\236w1\344?6K\213\203\374\375\336?\020p.x\266\212\335?H\225\317\"\275(\263?\005\207B@7\250\341?\020\3531\362\347\007\325?r\313!\214K\216\337?k\257\353\257\025\257\347?\303\321/\352\263h\342? 
\272G\265\273\277\241?\024\323 \354\003I\344?\000\030,z\340(V?H\0026\250A`\266?\351X\300.\223V\354?se\214\245\272\257\343?\254\335\227\032?\010\354?\350%Z\025\270l\340?\030\355J\304[\005\314?\241G\301`k\224\347?F\035@i/\222\323?1f\346\225\002\267\357?I\301\234\270\010\351\343?\2610\317\037\306\251\340?\274u=\256\377|\334?\367\332\223\370\'\231\357?F\376\370\3146\242\327?P}\352\315\310\306\240?\344ci\223\351\363\320?\274\274,8D\312\326?d\253%\231\332\364\325?\264\263q\200\301\247\300?\235\221\334\221\351x\354?G\204\037\226\315\024\345?p\202\335n\331*\332?\010\351=\226<\240\306?\364\231>H\375g\300?8\340e\237yL\324?P\301\307\3735\025\240?\262\274\035xx\030\347?\360\037\342\3551\233\352?\000\340i\247\372\025\330?L8\002\350\207A\342?\330\223>I\357?\344-\006\367\024w\307?\270\374\232v\003\323\271?|V\211\013\231\351\310?\002d\333\313k8\326?b\366\226od\255\325?\272\3133\224;\230\322?\"\\\264K5\025\352?Y\235w\223\014i\346?b\256\302\236\341{\347?\004?\017m\\\303\354?\352\2257\2248\351\321?TB\203\363\0204\316?rc\376\220>W\344?+\304\357R2t\356? 
\030\237\272\246r\316?\356\211.\301\3354\343?L2\302DQ\306\344?*G\261\303)g\322?.a;a\204\212\325?\361Ii\376\227\235\350?\"\356\207\213\"\264\355?\242\276\337\243w:\332?\352\255\263\225\223I\343??\021]\tB\177\355?\314d\213Mf\205\304?C+s\262}\364\340?x;\024P\261\202\357?\026\334\332\000~\211\352?\374\"\204/\357\177\317?\304\306\036\357\002P\321?B\350`?\340.\332?&\0239\324\314\342\344?\004V\223pE\237\312?\374W\227\275\010\317\325?wf\324\243\373b\357?LX\364\360\363\225\342?]\343\0251\257W\346?\330\253T72\351\324?J\2558\315\205Q\341?\301I\265\223B\037\340?L\202\250\030\275\013\305?\\\220\320\300\323\274\317?\020\220yD\233\001\331?\2544\201b\223\334\302?\270wS\003\233h\356?\376\352.\025B\320\346??Fd\305\271\372\355?\365K\336jw&\341?@\231\342\344\310\001\311?\n\244\201o@\023\344?$>\327P7\024\334?\270$}I\277\375\300?\252\256\3004\234\325\344?\030\266\226\231h\215\260?\340\216\223h\361,\356?\002\201\025\022D\014\322?\352\\\240U\'\322\325?\260\367\245\325\223\244\261?\274\336\230\367}\362\355?\214\260\271\022Yl\357?\037\354\020\212\266\223\344?\344w\364\272\'\263\315?:\367\027\204&\227\336?48\325(O\275\322?\342\373\343yZf\324?ZPME\250{\342?\350\337\"\373\212x\270?\215\217\203\311\224v\347?h\350\24735M\350?\260\227u\025\254R\276?@:\236\r.\260\250?\234\370|\301\205\221\307?\302\314\206\322]\370\353?\206\367\262\241\262\010\334?\315\242E\324E\223\347?\000Z|8\026\237a?z{\213\030\005L\330?=N\330\307j2\354?V\023\t^\244\333\344?Pq*\356K\235\312?\021\216\245\214\214\314\340?B\035_\311\013\027\327?\200I\331\241\370u\205?\264\255\2100\347\020\342?\007\264.\036u\352\345?gV\310GVd\345?\331a\255d\026\214\346?\320\327\020\216R\346\327?\333gnG\242\374\352?@w\341A\224J\316?\000\025\365\365e\244\232?\262\\\014Z\321\021\335?B[K:\374u\346?\036^7\375\322\342?X\300\361\221&\246\305?x\206\030\377\366\315\312?\355\206\277\344\006\302\356?4#\360U\224C\352?n\317RE;\212\343?#\"4\372^\357\355?\214\207\t\345\241\r\317?q\032\004\314\367q\352?\374V\000w\367\364\333?\004\327\226\365\0062\310?B9{\201\2009\35
1?\001\005X\005;\034\353?\002\035\367\256u>\354?4/\237;\"\347\322?\013E\np\354\310\346?\344R\241\244\023*\353?\334B\314K\367\007\316?(\331\222\370\276\313\354?`01\332\220\211\305?.T\010\3570\345\322?x\366Y\344Wn\305?\224\340f\212\253[\340?\332\241B\333{\252\337?No\177\317\374K\326?\340\243\265\024?8\341?}\374*\3402\006\350?@\350o{\261\\\342?\244z\203\021\345\267\322?O\254\250\312\321\356\351?d\013\267\230;\262\347?(\317\205\300\026\237\264?B\347U\'\351\n\341?\004\342\217\312\177\242\332?\373\263\271\265~\370\344?&\'\024\252\331\242\345?\233\003\346\2528\342\340?@\346\324a\260\343\324?\300\326\321\236T`\236?\364Z\337^\361\242\353?\250\340\263\253\340}\335?\366\014\323\263{J\336?\300\241\313\261\036\304\333?\350\344\036~\371\241\326?bv\201\013\201~\330?2\276\202\t:\255\357?25\002\2755>\340?\n\\\3611\017!\342?\366\304\024\360\035\337\321?\254\022\301\007\233\304\344?\300\375\211=\246\223\345?\022\3735\330%M\341?\314\321\\aK\326\337? d\000\001!\243\234?\332\256\224\036\200\342\350?\374?\033\'p\241\344?(~T\330;@\271?X\"\373\272(\014\342?e\333o\232\233\371\357?\360w`@\216K\252?U\370\000C\277Q\354?[n?\332q\332\354?<\256\355\333\353S\310?@\233n\217\322P\250?\302[b\021v\271\337?\n\t\3115\024\315\342?B7\227\327\336\375\355?9\265X\214\000\034\352?\025\267v\313\302\214\352?\240\375\373\357\326\311\257?\035\\\327\273\245\316\354?\374\243D\261\235V\351?p\221\321\350z\203\246?\323;\377T\317$\352?\034AA6D\241\352?\376\355-\206N\317\342?\317g\\\256+\302\346?\225\037\335-\373\247\347?\245f.\254\324\213\346?\204\346\267\036+v\306?5+\216o\207\314?\024\343s|\353\272\342?\200g\325\232\326\227\337?\245v\334l[N\350?\335\256\215\3510\362\347?\315\311K\345\304\r\356?)0\343\351\244\212\347?\304\204\224T\005\262\343?\320\320\210!}\213\336?%T\210Qn\303\346?\355V\216\352\333\r\345?}\003\275r\255W\346?\360\270\273\277L\245\243?dS\241\304y\232\335?v\201\370-Qb\330?z\311\362\371\243h\321?\250\322\311t\212u\343?&Ga\241f\335\350?\375\322dM\230\243\356?6p\314@\006E\325?C:S\230Vx\350?\334\274\220\202g,
\332?P\347\207\214\tW\275?^\313\325W\233\002\334?\035\335x\235\365\265\351?\036\360_n\377\240\341?\240=\216%\250\021\332?6\026^\250\013\347\344?{\343\237\373\030\304\354?\352\016\331\277\343\366\341?\344e\010M\037\023\323?b\025\323w\254\204\325?\320\177O\250\217v\260?\244\3629{\177{\326?>=\313wX\367\324?J!\322\213\336L\320?,\t5\017\023\235\311?0\230!D\345!\305?L\2733o\213\250\340?\030\355\031\366\\\217\341?\031\342\205\223\177\204\343?\260\246\335\266\245\"\307?\261\377\326\336\252\202\357?\366\327\211\312\313e\325?\212\254N\320\200v\337?\374\3759\266\3029\305?\013u\342\017\262\342\346?fh\216\311\337\231\336?\234\215\211yNn\345?\010\2779\200\352a\355?\230\305:|)\307\334?\240\037\277\336s\020\273?\375\274\334\000+\337\354?\364%8`D\274\353?\260}\014\000\232{\270?L\236\224\027\002\223\346?0\000\312\177\267\301\267?\004+Uu\020\016\304?\202P:i\251\305\342?\036\232\212\006l\346\353?\020M\r\"\330\016\301?`\321\247\301\000P\240?@7\177\216T\271\326?\n\202Y{/J\325?\277\352\024\177\211\357\355?^~c<\366\215\327?\200\261b\305.\230\227?\310\357a\310=o\276?\2441md4v\316?\230k\343\365@\361\354? 
\021\261\232\337)\253?\236\005\025\372\314\332\333?\240)7\025\031\367\323?\366\330K\360\200\322\321?\276\3730\023\232t\353?\020\320\2324^\035\307?2\000\343\177\223\224\327?\210\271\337Q\t\314\343?\345\256\211\221\356s\356?c\237~\022\264\017\344?[\243\240P\305C\345?r(\005\245\305\247\353?x\251\351\322y\337\321?\340Z\237\373\366z\227?\000\206\014\242\236\375v?\252\303\373\307\264\215\355?\273\031\314\221\327\350\352?\332(qT\220{\327?v9\371\361\316\002\341?\327\244h=\305\006\341?\034\n\216\003\365\237\314?FC8\300\333Q\356?\310\357\227]\300\010\310?\300\006V\263/\255\341\356?h`YR1\177\354?\000?\035\225\332\000o?@X\3112\257\337\245?D\263\253:\"w\354?\020]\013\251\330\333\254?\320\007\274\252\275W\272?\342)6\237\330`\340?C\334\310\031\242\200\353?\'\305H\214\016/\355?B\0001\245\216^\357?\253o\273\324\t\223\342?\220r\307gR(\340?\247\275Ej\310\336\346?\324<}\356D\004\313?<\375P\231\341\272\305?B\257\227\023\360\327\335?\232\244\027\347P2\340?\274%\227d\325,\355?N\231\340\254\276t\346?hR\347\217-P\310? 
\247\r\311\367\232\316?\310<\352\300\310\037\300?\340V\232gKs\317?d\343A\007\316)\341?$\001_g\2758\354?\370\307;\016\243\n\317?\250\237\356\237\373g\336?|\247\307\373\234\335\313?v\342>>\252\275\340?3\355\035\335\313\004\350?,\221&\365\266\276\321?\000t\370\372\030$\223?\220\200\264K{\254\261?\244\322\306\210$\263\337?\312>\'\230\250\206\323?\3700\216{\207\031\332?\3577\353W\253\334\353?Z\217\333\003\327B\333?\240\246\260\272\225c\262?\325QY\204\"\352\340?$\321\331\352\223g\340?\300-#S\345C\220?\332\234\035kt\274\335?\374\257\004\333k\361\352?\021\030\217\257\351\332\341?\251\332\\n\374\211\352?\354\356Q24\213\347?\0211\215\202E\376\342?\300gu\276\220\037\203?\244mt.\331_\314?pkT\235\227e\273?\030o\344\276\233\316\337?F\200\245\017\t@\333?\243\232>\014\275\235\350?\326\004\2003q\240\340?\200\020\003[\242\361\337?~\263\366\322.\024\353?r\245\221*\206r\347?M\352Q\374\255\352\352?\375\222\2628\310p\344?\336\315\001V\305%\352?\3021O\033\261l\355?L\310@\037 p\327?\n\307\234n\236w\356?4\376\350\2601\002\317? 
(?\304\230J\244?\206\242\306\026\200m\357?\310\\\\,(\341\357?((u\262j\014\305?,\3546\014\353`\322?zlz\253\235\217\333?\314p \317\302\303\330?T3\354N\300\027\320?\266;\315y\3373\342?\2014\213\341<\222\343?\000d\346T\235Q\202?\205\277\002YNH\343?\220\n\350\262\355\306\306?Z\267y\254,\256\336?\340RDI\344\234\274?bE\2039\367x\352?\204\362Q\243x\266\303?\254\206\273\007\250+\343?^\277\027\302 \327\344?\302\023;\221\035\300\323?0\204e\354\322\355\264?\2676\314\007\300\263\347?\220^0\217\036\301\274?A\303\370\315\342\321\352?\010%\003\016>l\340?HC\025[\214\367\261?\206\352\301vQK\351?|K:y\242\363\324?\014\216\350q\335+\326?u\364v\300\213:\340?\032`P\357\272\313\341?@]\371\316\255\007\336?\361Q\014\217\0022\343?g$\225\r\363#\357?\2052\326b9\t\341?\310\t\367\202\257K\333?t?\304c\213\323\321?p\222\322\022\024\333\241?\207\276\302\304\024-\354?}\\\2237#\350\341?\316\223\314\206H\367\343?\177!\321\327\202\326\347?^,\306\322\215?\320?(\262\355\004#,\274?\264\030\373\310`\243\340?\347\307\\\267\004\274\345?\002\014\356\224\3002\343?\316\241\262r\\\240\341?\003G\207U\244\264\347?h3.\301\212\022\265?$\235#\304\003\373\340?\250\351\236?&\016\276?\374\317\230P.m\311?\336.\3577\251X\344?@\302|\3063S\335?\340\230\255@\247?\325?\331\327\265\336\r\241\350?\366\t\273\213\341\340\327?\272h\364\3758\251\351?J\211Ij.\216\341?\212\371@H\027,\324?\240vR\t1r\306?\252\025\312\305\031\364\353?l\262\372\305\277\344\323?\347\241Q\021\215\205\357?0\027\016\246\327[\323?\316Waf&\215\342?P\273\301\374\243\354\263?\026\207:\214\024\332\335?z\273\257\017\\\264\351?zZjhoo\325?\240t\346\030\315\232\327?\302\307cC\034\236\321?\n\277M\003\037\233\340?Jk\242b\334\271\331?\035\366\362VKx\355?\206\r\235\226\220\222\346?\224|\223\222\215\355\303?\272\347V7D\013\331?\270\334\323\251\321r\332?*\374/\363K0\346?\206F\017\344\310\000\343?\350\331\rJ\206r\337?\3661\203\234\203\214\351?\221N\322\220JW\347?\222C\215\256\220f\356?0,-\010D\303\253?\016\377\177\252\\\316\332?\232z\317\314\002/\330?\200\265{%\374\323\351?\
2346n\"^\370\333?\362\322\274DJ\303\357?\261\233\006\330\006\377\355?\206x\221\326\037\363\332?\342+\300+VF\324?\220?\255,D\250\276?\374\213\373IJ\025\323?P8G!\206\333\271?\220:)\332S\370\313?\344\300Du\371\025\303?H\377\356\376\013Y\323?0j\216\032\3737\340?\306\357z{\210\241\330?\260n#\376%\t\261?\340\205l\232x\274\266?\036}\237\244\314\305\334?/H;TU\322\343?B\371^=oM\344?\030\324\300\222\3069\263?\331\3626\310\233\314\353?\340\216\0370\237_\275?\220\323\2322\215\354\276?\230\305\271\224|\362\261?6Ft\310o\355\353?\330\233\315J\315\016\274?\312:\237|#\247\352?l\207\250&~\273\340?\262\224\335\326\275\003\356?\354>\321\251\343x\314?\204\350\001\226v-\314?\240\035\211\232\226\217\315?\346\230\014\245\364\342\350?\330\n\025\000\322\210\330?\257\262\010\255\275\375\344?\232e\203\270\353~\327?\271ju\370\345\302\340?&mqZt}\337?\362\223\315\262\351\325\337?\370\353f\036O:\302?@\262\357\244\006\251\342?V\307\'\373\036\251\321?\3432\260\332\231\232\345?\231\345\274\343\371\263\341?\024\370\320,\275~\307?64\353\177\251\"\331?\241\302\024\266A\030\356?-\010\342\032\023\004\346?z&n-\306r\340?n\350\000\267\007\313\334?\344\223?N\352\031\307?+\004\207\216*\350\344?\304nG\317L\320\303? 
1hq\024z\310?\310\203a\3013\211\353?\2654o\376\314\023\356?\223\031^}G\036\354?\227c\274\354\"\300\341?XnL\324L?\266?So\312\200\234\314\350?\000;{\271\2569\247?dl\020\025\247\367\355?9\340\236\222\326\301\347?\2345\246\261S\317\306?\360\346\260\034=o\336?\020k\303\n\343\000\355?ft=\217\3674\323?\340\362\204\025\023L\270?ae&\230\222\350\350?Hy\313\024\275\224\274?\271\337OR\245\357\356?\324\017\215\302\010\212\354?T\0058\226\261\\\344?\025e\354\251bf\352?\027\247|\2040\331\353?alTf\330\231\355?\274\355r\244\256\265\302?3\270\023a~V\351?\2303\316\335\376\257\270?\306\r\251\\f\213\355?\256\1778b<6\342?\376h\235-\307\333\343?8l\017#\025A\327?VjZ\377\033Y\345?\007\255J\201\333K\341?j\002^+6\257\325?\3368[\234\370\371\340?\264\261\300&\370h\346?\2201z\210\253l\304?\353\"\206\367\327V\352?`\016$\255\254b\313?l\205\201\016z\346\344?H\327\205\301\020*\304?\354\335D{Y\265\353?so\375mA:\343?\254\306\211\256\201\265\311?\310\013\347\343\347R\265?\240\215\201lkd\243?\274dN\211\342\351\347?\266\207\335\323\271}\355?t]\242\333_Q\351?\202\005\374\346\023e\324?\001\247\343\t\241\205\357?x*\206`\225\321\263?\024D\233\030y\211\340?t\237CUWC\322?\235\343LD\311h\355?\232\314\241\'\373\016\346?\310[T$0\n\346?9\356\3549\266\355\350?PF,\275\333\275\256?\326WJv\307\026\352?0y1o\320\317\330?xv\2106\334\301\346?\224\0031;>:\322?\004\230 rqc\300?T\320\316>\265j\347?\347\201\377B=d\356?\236m\200\237\272\202\357?&\200(\215\353\224\354?\342<\257u\031\271\347?2?\332\205\260]\323?\2605zC|\274\355?;e\307\003\305\351\345?<\023x\3518\r\336?\323+5\340\014\274\352?$\230\343\tm\362\326?\010\371\3169uz\337?\250 w\220\221I\347?:_3\005L\255\320?\335V\031S\276\225\344?$\266\370^\275\372\332?X\333\213Q\034\274\263?\220\356\343\345I\353\346?s\371\3348\330\217\343?K4\306\375\200\277\354?8D\273\224\276\326\351?\026W.\336C\311\325?\326\224p 
j=\343?st\214Q\342\303\354?\274\033.\343\373\203\354?\246\363&\005\210\203\346?\251\342w\275\203\304\344?\274\320\250\017\235\235\326?z!f\200\244\205\325?\022\250\177B\352\352\356?8\013\021\215O\222\262?\374[\232\236\014l\330?\340\3546@\021\177\307?\010\004\370\320\323\271\333?b\035n` \274\342?.?M\347u\324\356?\364\345\237\027\004\347\324?\"_\214\276\303\027\356?\\\215Y\375\266\n\332?\300\265;y:\323\276?1\216\"\327\255\026\345?\"\335\3132\215\271\355?b\016\252i\270(\345?!]\273\344\241\244\346?\331\215\276\2336\243\356?pQ\327\276\035\356\265?\320\244\306{-\312\312?|\265\334\360=c\305?\370\241\023\273o\233\345?\0205\257\234\267\211\342?\364G\213/\376\365\315?\200\247=xw\337\266?.l\352\341\\j\333?Y\310\261\267\254I\357?\302\221}\035\327y\345?\202d\035\276#\206\340?\260u8\021\220\357\344?\222\360\271\256\377\231\330?\206\317\311\236\260x\324?hs\035H8\232\304?\000\244\002/n\010\271?X\346\347\234\223^\265?d\035d\3648?\326?\000hm\300\342\263D?\014\262&\231\364W\313?\214G:\227\332y\324?l\234\311\020\331S\334?o\017XK\264Z\355?%\023\335v \335\354?\014\314\364ke\201\353?\222Y\246\333\202+\323?\214;\321\322S\307\347?\034\327\234~\374}\341?\340\363\254\344A\371\340?x,I\3616\251\304?\270\273--\273\364\354?\010Vcy\240v\276?+k\250d\332\367\342?\337\004\251\222-+\353?\200\306\001\263Kh\274?~\366\234E{\372\336?tO\336y3R\351?yu\351Z\016?\352?\320\274C\201{S\255?2\020\210\223\016K\355?S\003\354\256~\213\345?\000\321\324u\275,\350?\317\261Q\365\221\360\343?\221\246\014\334\345\272\351?\321\2500\203q\306\352?)8\031\371`\366\356?\350\023K\305]\031\271?h&B\227\033m\310?\020\2600 \273\256\265?\250\336\016\376\005!\352?v\230\224$\216Y\326?@\367Z\262\335\346\264?L\206e\364 #\345?\036\244\351W&\313\356?lY)\221,\356\311?\310G\346\313\262-\355?\250\201\375Q\030F\266? 
B\350P\2132\332?`f\2073\2411\253?d~MW\010\324\327?\310\375^\316\225\030\351?\331k\312-\332\210\355?\242~\035\037>\224\337?\301b\201\262\031\036\345?\n\332\2333\345\027\355?\360\014\236\225\023\332\242?\344\245\337\'#\'\317?\004\t\177\033\244\266\325?`\323z\353.W\245?\033\227d%wC\341?\372\t\023\367\222\252\330?0\277\022l\322[\267?\020\214\246\337[\250\312?\340\037b\204\333\275\233?\231\035\235g\213\230\351?\356\324L\211\006\232\343?P\332j\t)\251\354?Y;\212\374z!\351?\220ow\365NA\336?}7\271=yK\341?%\347\317!fM\342?\013>\333\217\213{\346?\204\372;[S\310\323?\356\240\003}\231O\354?L\201=\177\337\256\301?psU|\340E\275?\021\267\352\035\235C\351?\001\334\345\016\013W\343?\334\036{Ix~\310?\305?w\024\224\256\347?\344\355\034\213\271`\332?\324\343\335\005\310$\345?\234?\237\315\354?\322?\242U\333}Ta\354?Hy\310\222\000\336\266?\310\275\345^\355\252\262?\034KV\254%\321\306?\240\333\277\013Z\311\306?\260=\024fq\315\332?>\271(\257\303_\355?`\372&\322\377p\337?,\021\034j\\\266\330?\210\372\246\261K\334\314?B3\235ZO8\346?H/B\321c\252\332?\264\001\003&#$\300? \310]\203G\016\252?\311\353\350\023\260\210\355?\343\240\206\310? 
\344?\306{\275\215\224\025\331?*\251V\373\231\340\332?\361\271B\275\206#\353?z\240\333E\354L\322?\036QH\346\210\347\333?h\321\325fs&\353?\3560yif\275\350?\325\003\263!{}\356?\370w\225|T\231\326?\360\232&\r\014\307\303?C\361W\360\260\002\357?\020~t\210\204\010\241?\226\220\000\t\337\336\320?\003\204\264\375\233W\351?r\207\300M\223\246\345?\013\313\005Fr!\344?Js>B%\177\333?\337\214>M\'g\355?\370\\*E\313\267\351?T\t\242\200\313\306\350?Jh\275\370\365$\346?~\357fnyO\337?\236;\013X\211\007\320?\2729\363\307\347>\324?\340Is\301\027~\305?\360\245d\372t\256\247?\\\200C\t\244\020\313?^`\246\026i\276\320?\304\2064\020\373:\325?@\rF\335w\313\270?\274\255~8\330\211\333?>\313=\323\372\317\336?\024\246\357\353\311$\300?\262\024r?\n\t\343?P8r\250\221\305\273?^\037\301\240\377\252\334?\234 \336\306j\214\302?\020-\022\343a\350\261?\210P\251p\"\226\323?\210\346\222Z\025Y\344?H\\\220C8\214\276?P\360\314v\036\254\263?SUu\332/\271\342?\200\3441\003\2342\217?\306\211Dm\303w\342?\304#\221\331\221\257\326?\366\230\361\004}\375\353?\240\233\005\271\n>\321?\000Zl\240{\004j?\226?%O\030\225\332?|p\206ZF\272\333?P$\251~\376\206\334?\272\375\323\233\217\323\330?\263Car\250\273\354?\024\347\277\354\035>\316?H7z\226\020J\351?\324\306\324\245\\;\316?/\006)\272\343\307\344?\301\251md\364d\350?h9\361\0069\332\311?\333\260\350\306.\242\352?\006\203,\246\037Q\352?\326dos\222/\351?\315\245UTX,\345?\331!\271\227{\350\351?hR0\373\234\326\302?d\214G\037\275q\327?\352Mu\362\000\214\343?I\353\23354\"\350?\360\3602\3128\'\354?\264,\337\216\001j\341?\366\203\354Z7\275\325?\272\ne\034\226\221\345?\362\365|h\037E\326?V\024\256\201\356\374\354?\004C*I\236H\315?\030\030\234);2\271?`\343\271\007v\233\224?3\035.\367b1\357?\300N#$\'\266\272?\360&\357\001\036\324\353?^\353\177L\273\217\357?\232\250rf\246]\334?)\343\2542\340\342\342?\354\275$`\225`\317?>1\265\370E\201\354?l\334\246z\365M\354?\270,\207S\320\001\335?\234Nr\245\207@\331?\254\225Z\217\251+\341?>\037\246\227R\024\321?\360V\322\245\200\343\337?\331\203\207|\
367\037\351? \037\034:@w\310?0n\264\363\301\244\334?<\254\007|V\036\344?\364\266\332\254k\212\325?;4\001\211G\214\343?\342\361\360T\273I\343?/\367nP0U\353?\214\353\336\017\334\242\333?}\022)$\357\361\340?\237\274\323\206\212\224\343?-\030L\374\240J\344?\000\343\311\240\3176\331?\032w\305\255o\034\331?`A\007\311\323\001\271?\267C\314\241\235\336\351?\214\311B\177\375\234\336?\226\332\374\375\274\010\322?f\204\207\275\246\377\334?\202%\243\270\020\325\324?p\250\307\233\024\211\302?\222\345`\261\033\305\335?8\364\347\035\276e\337?\216\215\217j\371\257\356?6W\245\247\347\301\322?\210{\252s\025R\300?\366\347\234\330oO\347?\225\377)\216\377\215\340?J\037\265T\'\222\326?By\202\234\302\237\325?\200\221\247\323](\221?\231\277\357\342n{\351?$K\346-w\032\327?\006,wMm\376\357?`\214\202*? \242?l\207\357\'\356\331\324?\016\3304y%\213\343?\374[\336\362\331\377\324?\220\003\373|\330\345\260?[\334\315\315\311-\345?\335\263\360\024\351\267\346?(Xm\211\202\372\263?\362<\373\236\'\307\336?W\356*\350X`\354?4\007\332\216\005\360\302?8\324\375\277gI\315?\276M9v\362p\337?\222Z\377\217\030\340\323?\375\255\334\304\375\030\341?\220\223\211\375\034\322\271?\032P\311$\316\200\322?\316\3556\207\322\307\351?f\022\226:\226\247\336?P\360\257>t\347\333?d\n\234\004\347\376\306?Z\025\t+\300\036\344?{\352\353\3159\303\344?@j\352\210\236:\265?\343\212\230\213\003;\342?p\240\334wS\304\310?n\207\351C\nw\352?\030b\252%!\217\270?\232\025m\303;U\344?\033\234\217\327\251\005\357?\360G\346{\302H\352?\220*\t~-T\254?\256\334S\337\305\342\337?\2003E\177\212\036\266?\274\300\375\207yr\333?\334\3458\275\244^\354?.\035\215\232K\324\346?\334\035\217\341rE\325?\247\031\337D\200l\344?\240L\325\301\203U\356?\373\241\004\005\221\317\347?H\\\010\325n\034\302?O\334\037\374\024$\346?`\247\350\265\335\204\341?\354\026uI\037-\353?\016\356m\\-\030\337?\300\221v\217\376\244\245?\206\345\"\200\263\255\356?\3168 \235H\245\322?\346 
x\261u\236\324?\230\022\323\272\267k\313?/\241`\260cb\350?N\022}\004\021\203\343?\270\256k\020\272k\327?\364\243\343\267\026\303\316?\232L\22101\245\332?A\362\030\017 \247\342?p<\341i3B\246?t\301\332}\215(\300?\000.\304E2-\341?\260\3242*IG\252?hV7\257O\336\322?5S\363g\n7\346?\365\035\313\035\372\323\352?8\n?\355\250\314\264?\227\245wI\210\257\344?\\\376r\223\014\242\300?\207+\205\332}\311\355?\244\n\316<~h\343?\234\005C\365\346\003\302?xh\004\024\021\264\356?\232\221E\235\377\256\355?\355\037\324h\375\244\343?A\334{\243\177\327\347?mg\253\034\206[\354?\200\027\333\315X\"\270?\275\016\317*b\017\347?\3248\337-\004\027\331?X(\345I<\274\275?.\317\315\331\'3\324?c\030Jl\013a\353?\240\007\305\001\224\332\263?l\275\033\374\321\301\310?\200\301\025\031G\243\204?\251k\000\375\237O\357?@\276\3141U\311\232?\002\014\'\201c\261\346?\362Y(\361\374x\343?\200\3247\340\304d\252?i\37385\177\357\353?C\2734\374X\214\343?TA\301\357\246o\302?\010|\201\\(\253\261?\367\300\362\231>l\342?\265_\\2\022t\347?`\277\233\200x\347\321?\262)\265+\356\206\350?V#\351Y[p\347?\033\331N\317\232.\357?~ms\242$\000\325?\\\000\264\352\307\213\350?\230\240\222U\022\027\265?\256?\003L+\343\356?`oRn\303\216\357?K\203ax\033\376\353?R\370\233\230\217i\324?,R\352\203\311\366\306?\036\221 
e\362\242\357?\030\302\353\346ZK\327?\337R\302\265\027\246\354?\375uB}=\205\351?\260\027Y\235d\022\310?\235\'\342q\225\217\344?\326h\256\241\000\324\324?xJ\222\310!\322\330?\322\210\245\255\217\310\343?6M{\014\271\354\322?{!(#\377\325\354?\310\241\035\346?\202\301?\354\212{\340\005\333\323?Q\317\357\365\242?\354?\242s\307\306\213\306\347?\206iu\220F\313\333?P&\376s\014\310\351?p:o\263\215)\314?\252\372\3252\254\026\331?\206\325\362\236\345\253\330?p\341Y\244o\356\245?{\203\351\272\014a\340?DM\026\374v\007\300?\340\350\\k8\253\353?\323*A\026\312*\356?\220dT\202\363\336\243?\014)Z\t\037\\\344?\340\343\262\250g\254\315?\340O\024\224\337<\325?R\376A\241\343]\335?I*g\303\3177\353?\004\033\013n\210\n\330?\260#\364\331\264\364\320?\360\302\177\330\224H\247?\204\255\277w\007\345\300? \003\213\235?6\302?\201\350N\3237\352\350?\204\006\223T\206\263\316?\252B<\310{\367\350?\216P\215\236\361\316\357?k\221Gf\363}\353?\264B\016T\213-\353?\024u\331V\327n\312?\306\207\365\311\034Q\357?\306~\263\355\222\217\351?pQ\214<\245\023\337?\270\213*\244\364!\346?\255x\237\036\374\304\344?8\360OI\372\225\377\344?\370\226=\376)\265\357?\210\350\2738\364\345\262?d\023\302\225\233x\321?\207\310\205\317\201e\352?\325x\216`NY\343?/\355:\211+\026\341?\330\344\030\315\007/\333?\300\374\262\325Wc\256?\264\244a\322\031<\315?\276\370\342O\214\010\353?\326*\363\241\273\346\332?N>\220\332\023\374\323?\300\311\261\213o\025\246?\334D\0217O/\300?9\3151\317\207!\343?{\315\350A\3716\341?\200\331k\013\324\001\262?\334\267\322\031^\312\343?D\224Iu\013\331\326?\332i]!\324\332\347?\350\3671\371\353q\320?\014\224\353f\033\213\334?\027\350\001\367F\212\347?t\321\234\337\350\327\321?86\274=\3230\314?\302\373c2\374\305\337?\200\3313\265\223I\257?\334\3605q\211\226\353?5\333\334\201\204\350\341?\242\263i\232\275\235\345?\240\031\376\216\317y\357?\266\364k~-\253\327?6f\237\\\236%\335?\346\323)t 
\306\331?\306\333]\201J\365\346?t/\342c\227\001\317?\230\355\201+\351\202\261?\205`\244\255\351\\\354?\320\223\311\241e\312\276?/T\014\010ZB\346?\224G&w\301\206\313?\232I\324[2\305\344?\010@]2?\035\354?\3269<^ZO\344?m\351Bk\332\323\345?bY\232F\3658\350?\350\312\035\375\354\356\272?\240\335kR\211\304\313?\r\315s\241\275\342\350?\304\037\262{\357]\303?\3561{\203P\235\333?Z\261/\266B\326\346?\t\342\341\017\252\021\342?\0373h:%g\353?\035\337\335h>\207\340?8IqV\273i\307?\"\327:\207\n\340\353?\320\007E aA\314?<-jl\231h\340?\351\2406\204Oi\341?m\311!\021#\303\346?\310\205\002S\241\305\357?\304\022R\253Ae\307?pm\365u\245\307\263?\240\315o\330\306\323\315?\324\013\372\326I\234\351? gqx\336\210\263?\300g\253\377{\032\251?}\255\327i\342\361\346?\346\030T\266\000]\334?`\014\362F\303\225\276?\344X\274\32767\312?\370\262QRF\205\340?\210\234\361\242\360.\334?)\213\257\3510\213\343?D\000c\324\2757\302?g{\2406\244\313\354?\323\244\331\314\216\016\350?\2063T\257\317\034\353?*(#%=\364\335?IO\002Z\017x\340?\"\203A\331\nc\346?\225\374\322\370\374s\352?\230\032B]\250\272\300?\310\342\273X \267\313?\255\354]\025\216j\352?\220z=\030=r\347?y\366*\036h\360\341?\300\013\252\310\017X\212?Y\230\325\361\026\313\347?NP\020\250*(\326?\356\303)\273\203\200\346?\000\020\0000@\272F?\241\n\311\375\361\247\027\350?\312]0c\241\316\335?\2771{pw\342\344?>\220\333\330\266\256\342?\020\231\211~*\035\267?\364\365*\242eb\330?yH\245\016\007c\347?}\222\034\227\356\210\355? \356\360\021M\325\327?\200+K\212\024z\304?\324k\334\350%V\334?\224g\325\347\233=\301?\302\305\366\"J\020\331?{>u\276\272 \345?J1\300\006\346%\325?\360\330i\025\252d\340?\263\204\356\265\276\254\346?\200Z\372p\322\213q?\374\262*B\035!\310?(\344n,\311\226\330?8\305\"\352\261Q\262??\226\026\032fe\352?\370\254\247w\261\033\326?h\261z\236n\252\311?\257}a\260-\367\354?@\021\010}\270\277\336?\000\207@\375\207\254\244?D\275\250s\017x\325?\036\202\213\343\000\242\354?`Pi\t\201F\262? 
\236eA\032g\313?\t\036\367\377b\353\352?\260~\002y@&\347?\364\273\231Q\215@\330?z\347\260)nd\327?\367\257\222\373\263\231\355?\233$\224\225\311\341\352? \264\232Ji\323\342?\330\215m![\246\337?\235\2219\377\350\245\345?c\212\313\324\360\317\342?\2402\316\231\367o\234?X,\224\222+\313\275? w\367\2201\022\306?\252\030\023\200\244\340\322?`\'4\315[\305\354?\220\227\372@\255\250\356?4\234\345\2104\t\321?\330i\273Nrz\306?@\240\361C\225.\243?V\314\251\274\275?\350?\334y\372S\005\254\344?\010\306R;\237\355\311?\370\271\032\254\263g\307?D\200\021\256\3619\325?I\344q\253\037\030\342?B\350\211\226sc\355?*\261\rD\276y\335?czp\311\371p\352?D\331\265\271\034\374\351?}Y\367\333\303r\344?+\036\2365D\307\352? \2025\022\'\347\250?\276\010\3469\205S\325?\000g\nq(>\210?\332\343\205\310\375\262\325?jE\273\367\006\353\345?\320<~Z\030\337\313?\016u\022\341\377\274\325?\323\024q~\261\347\355?\231\2169\204fi\343?t\234\210\215\n\320\322?\207\357\236\317\223C\351?\346\305\302H\0279\341?L\204\343\357\254\310\315?ps\323\006\271\013\313?\220\325\033\362\022P\332?]n\273\004\262\357\351?\376*\357\221\334;\331?\001z\004\340j\002\352?1\254e\200\014B\357?z\032\214X\371\357\333?h\'\307\013\003\341\263??\023?\017\361v\353?\371\332\'m\377\304\343?\232\333\251}A\010\345?\344g\261\233\210\266\304?\3444\262\205\210b\326?J\327\370\347xa\356?\320u\256\3738a\344?\n\270\2333\212\234\346?\312sj\336\310\211\320?\362\236\3148*\030\350?FsEV\014=\357?\300\320\332K\253\356\211?J\2723L\'\317\324?\350\225T?\006j\350?\tr)\236z\006\352?\3206\021\302o\366\244?4~ft\314\017\336?\240\200#\344\320(\343?k\254\365\271\253\230\342?\201\204\252K\007\247\343?\300\314\313\201\213[\270?.\242\256%6\353\344?\035\267R2X\\\357?\023+B\361\256\320\343?8R\232\010U\345\352?+\';\260{\362\350?\266E\375\t\3656\342?\020\227S\337;M\345? 
\2569\201\2434\310?D\276\217\013\202&\344?\000x>\356\261\243\242?\027\344U8E\312\352?4\022)N\243\004\300?\226\371\321\320\245\234\345?\256+w\002CO\336?\336YJ\275La\332?d\371\245\317@!\316?X\225,\225\323\237\326?\267\330\212>\023+\355?@\253\211\253\370\275\251?\032\357\200}&U\321?\372d\304L\261\200\332?I\032\244\212R\370\343?\3067\330\360\210k\342?\000(\203\0068/\305?\275\221J\212\376j\356?(d\236b\337h\357?\200\347\342\242f\330q?\340\344\002*v\275\241?\340\361(\366%\306\323?\200\271\006\247\nI\204?R\241\204\246\rg\325?\244\263\2465\367s\320?\330K\270\035\250\305\301?jc\233\003\205/\333?GwS4R\n\350?\337\300}M6q\342?:p\225)L5\350?`>\337\340\261\010\253?\360\343\304?\221\357\326?\334\035R:\027C\347?N\320\370\316\351T\326?\266c\217|)\263\323?J\217\214\213\366\231\357?\n]\2467\312\024\321?\276^\345\320x\274\330?\366\342O4\313\031\353?\240\310\272\204\330{\230?0k\375\364}\375\266?\344.\000\032\267Y\340?T\346\004K\363\221\351?:\341\245\347\255\035\331?Q]\250\355 \007\345?i\342\225\376\331\356\356?\320\267\221j\321\"\275?\007\324\313\343\212A\352?\360\2668\037\245=\311?p\373\216!D\337\333?\336\257 
\252\355\201\335?\200\346\014h\t\353\341?\305yS\301.\231\351?T\350\241\372\177\365\331?\020\032\037nq\224\253?`\242\020\\\037o\341?\213\177P\277\205\240\341?\203l\345\n\\\007\352?\320\320h\3028\r\243?\331\276K\033c\344\341?\207+a\265\305\221\347?\234\204\373}\234B\327?\214YW\300T\001\356?\240\327\010\306g~\304?\317\332\307\310\254p\351?\241\302\200-~\020\350?\3406\331\376\2251\353?R\"\236\241H\355\344?\001\373\000[\357[\342?o\025\344\026\261\340\342?v\035\345*0.\355?QGm\355\310z\351?\032\351\200/\017\251\344?\202\004\005&v\002\350?\374.OZ\324\301\353?\027yc\214j\231\352?\254\225\245\261\034\233\311?\202\353\032n*T\354?\">\037\363\267\272\351?h9\251Ye\366\336?`\2656>\004\330\357?\200gX\224u7\245?\\1F\034^\364\312?\310\010\352\356\255\237\355?\300\266\002\005\003\213\240?\000F\204\366\274\"\354?\370A\365\265\351S\341?\336]\210\370$\241\323?\\\t\350=\366u\312?\n\221\271\367\365\350\324?\311\242\341#\253,\357?.\207^Ee\246\351?\314\000\342\274\002c\317?\374\335\2737\021\352\356?\004v%\211bK\351?3\016\325\346\354-\345?\024!r\353\000w\344?MN\325\314\331-\352?\307\212~\022\375\251\356?8\317-\375\037f\327?fw~\274\003\177\343?\362\\y\336\024!\325?\204JJ\356\350v\355?\277\315\344\372_`\351?\326yOw\027Z\330?\036\3422\344\340\326\331?\034\2105\003\255\364\342?\260\n\326\3422\354\332?zN\277\230\320\300\330?\024\207P:\260\201\341?8gt\362\233~\351?\230\371p\304-e\310?<\343g\330\366S\346?\255<\"\240\177\341\345?\232>$zf\016\352?\264\270G\201t\221\352?x\034 
\272\374\006\322?\335\377(\326\352\306\355?\274\r\346;;j\306?\264R\210\360\000S\321?\202c\343\361\261\t\347?\001\323+\336\014\332\345?\330\371\271s\351)\273?\324\351\325\242\230\243\313?\371\310,\'\361]\350?\270[,\306\326\360\330?\250\336\\\335F\000\265?\030Q\354-ib\332?\2003B\301V\214\310?\340\r\177\361\036U\226?\000\203\351\317\360\321\214?\364%|\323O\337\336?\315wJGf\360\347?\316\363\234k1[\335?\307W\006\265\234[\340?\342\340n\256\235~\324?F\320>\330\264\"\332?\236\262C\316\351t\351?g\036\223\376pc\355?\345\304\350\310\016\245\351?\330+\232H\350\033\356?~=l\n\364Y\357?\014#0\'\305\022\331?\211\305>\263\033/\345?vp\272\256.\263\336?1\364\034\233\274\034\354?@+msk\265\275?\370F\365F{\372\312?Ll\212\375q:\320?Y3H\313\352\301\352?$~\365*\340\021\350?\276\330\025p\231v\343?|\013\2411\373\216\331?\004\220I`\221$\322?\030\0179i\003$\347?\330\257#\211\\M\326?\227<\211\013H\014\354?d~K\013`\240\323?\004\002\022K\214\001\311?\300\246\377\202\353g\337?\316\267\331\224*\276\324?0\323\t\374n\310\245?\217\267\304pt#\351?\022`\356_`\364\342?\030\307ai\324\353\331?\276\223\027\325[\010\323?t\214\001\253n\353\337?\235pIk\267d\341?+\037\'\307\274\250\347?|\'\325\257(\363\352?\243>\332\302\023g\345?;,\034N\362\232\354?LKX\240WR\312?XA\353\036\025\310\331?^\345\270\216u\022\350?\010k\352\231\234K\274?RI\346\310\030\225\337?\230\261HP\201\261\312?\240K\346\332(\020\230?\377\205\275\271\000l\356?\220\210\261\004\207\232\245?\3700\255\035\226\361\300?\2561_\376]R\331?\255y\253\267\317E\354?\374\234l\226t\233\325?+\241\331b\211\022\351?6\177\315\235\277x\344?\224\257\206=R\214\340?\271\342?\332\216\'\350?\014\302R[)\371\307?\304\256\204\332E=\300?x\344\025\354a\344\273?\217\377y\357L\260\353?\365\034\275@\220#\356?\234\334`,\301\301\302?\032\344\371\243\274\215\342?Z.>\334\243\203\337?\322\025\014\347\324x\351?eHpt*\022\353?\304\335\266\251\3759\317?\365\225v){\266\340?a\002\027D&&\350?\312\020~\350\241r\320?\250lLa\035\201\315?\304\257\3347T\205\312?\274\343\231.\035\200\307?\343\326\3
42\360 \351\351?\220\002\351\226\204I\275?\316\346\263\304\304\302\351?\203y+CA\251\341?\210\243\304;\270\234\304?R\337\275\'\032j\352?x\351@\027|\006\261?z\371g\270\243\n\342?t\357\303\230\353\037\321? N\243\017\223\270\313?E\267\335\225v\000\346?dF\365L^\264\341?nU\214\021>2\320?B;\375\330\317\001\341?O\225V\266-\236\353?\373K-\336\374\316\346?b\225\355\352\010K\323?p@R\201d\255\356?\" \233]\272\273\346?\320uF\3135\213\350?\327\235\354(>V\340?5\277(\356\026\t\345?\nr\234#\332\261\347?\003\230l\013\242\327\353?X\017!\337)\376\313?^\267r\245\356\313\357?\200\266224\026\317?Fpf\247\215\006\350?\200\360\203\255S;\246?;X\241,(\205\355?\335\031\261\274\032x\353?\366!\265@\325\245\353?*\253\362\314\2141\327?\340N\263l>x\307?\020(p\300\303\317\355?\264\371\261\245\312A\347?,\201\270\203!C\326?\32786i#B\357?\226\324\377\013d\034\321?\241(\251\235q\260\347?\326\027p\274\354\257\351?d\373[{\007\301\357?\354\204\n\216\\\202\336?\256[\r\365\314\373\327?\272\024\202\244Um\346?\262\177On\022s\341?\250\363u%\215\325\344?\310 /\325*\231\351?\242\327\355\232\354\203\322?\334\3231@\244\"\312?*\376$Y)\212\337?Td\364\000O\210\350?\224_\323\241\247$\322?\205\264\350T\204u\342?\200\201\256\252\267\212\273?\221\314K\270\267\276\344?\202jYM\316H\330?v\342l\372~\241\334?N\235\317\264 
\342\355?\364_\261\312K}\313?\221\345\014\356\242F\351?\254}\335;Og\333?p{\337\373\365X\312?(\2228\013\206\236\336?\370\277\212\345\201\244\336?\374o\224F\304&\313?\214\375\3700\347\222\316?\000+x\3736\237t?H#\237E\021)\306?\323\335\376\323\200\017\357?)\255Q\232.M\350?f\007\200\267\257\037\343?\343\035\210\311a:\356?\324\356G\241\037M\345?\026\002v\274\363\325\331?0\272\346N8\327\246?\233F\320\356\374\311\341?`zo\266Z\364\310?\023=\215\317\'\323\354?\320\330\372\270\2272\350?x\204v\262\213b\320?\326Z\277\226bj\340?\000\274\343\036\207^\243?E\340\241\034\'\335\356?ZFuH\326(\327?\221\304\244\2412\326\346?\360\227\214\210Kg\355?\200\'#\324\266(\257?\214\366\252\267\227\264\304?k!\333|J\323\347?zs4\324-\334\354?\000i\315\001\245\020\336?\307\362\360V\036\334\351?\363\264\250\324\320\367\344?\246\315\033\370q\357\356?.\234\212\345\033r\357?\274\340\003\374\035\276\325?a[\314\332\361\211\353?>4\010\346\334\022\333?F18\333\274\332\350?h7^{\177R\357?\240:\030\220h\306\335?d\341\000fy\'\300?\311\235\304\014\325e\342?r\215\010\031^\n\337?\000h*\2463\3256?R\377\212)\350\010\352?\262WLSY\326\322?r\316`|\247\370\321?j\233\250+}\231\346?#71\035E\216\357?\300C\t\260\004\340\307?\370\372\351\375\035\256\327?\215\235Z\035\004\t\354?\224\020\246\321\253\353\340?\034.K\003\341\372\324?\206\231V\327\375\330\324?^\227\212\315@Q\322?\344\310\322A\032\353\331?=\3426\267!X\357?\346\342\t\263f\307\344?\340),\276\211\322\274?I<\202\306\325\372\345?8\3332?>E\303?X^p.\005\304\266?\030.\370\347\026\241\333?\nxU;\2755\330?\276\373\252\212P$\323?vR\257\221\301\021\341?\002g\3451z\341\355?*R\236&\0169\322?wv\201\350\\\224\345?v;\213\257k\345\326?b\340\363j]\234\337?\300\3428\270]\000\312?\332j-S`n\325?\200\036\367\365;/\311?\206\253\006\026\377\356\323?\231\013\201\244\026A\342?.\356\235\327D\215\321?V\367\0231\327\307\340?P\345\251\014pq\316?H<\324f\341\201\305?\272\355j\204\272\347\333?\245\331(\037\211\022\352?_\261\265v\241\311\344?\374\366\326)\037\255\351?\214\n\266VI\321\355?v\334\022{V6\3
53?k&VI2^\344?\'\006\306xm/\356?\\\3206\353\201\300\330?\334\314\020\035>\314\352?\310\200\364\375\245V\343?H\206\217{\006N\322?\256\330\261\326^\245\322?X\3716\227\226\333\274?\263\">\241W\307\346?\264ee\257jV\312?v\214\002\275A\366\341?\'#\327{\004\370\356?\220\203\003\331\206z\333?\364\010\007\204S\213\326? \220\004\224\026\366\330?\202\rG\"\300\373\342?\314\371\362\310&/\344?Qt\333*\244\311\343?\305\307\224F\227.\341?\000,!\007\035W\276?\310\373\3301\335\232\321?\000\017\247\025k\035\276?\004\374\2648\013\320\317?x\373qd\304\204\334?\006\004Fz\003\233\323?\300\230v \020\317\311?*\033\323\333\022\207\323?\273\332z\014,7\352?\n\330\274\211\t\256\344?\257\217\323S;\333?(\224\034/\264=\343?\035}Ln)Z\354?\234\'\233\013\013\242\357?\024<\246\220#\273\336?\230\327\013\031$\277\357?%\315\225\256\362\376\356?\377S\036\245\306\224\355?\3729<4\323L\357?(\214y\347\261\313\305?\257\037\246\203\'\334\350?\247\355hD\3339\344?\245+\2174\375C\347?\300c\375\010\201\264\275?H\022\243\014B\274\260?\363\007\024\207~\306\344?\335d9RjH\341?\257[h\267\202_\356?\227\343\315\260\367\264\353?\370\\}\265\204|\271?\341\226\356UDY\356?`\306\204\231]\237\225?\262\364,\346(!\353?P\373k\354^-\242?\253s\007\034t\202\347?d*\276\363o\234\330?\034|\312sd\270\310?\352\306\362k\003\307\344?\363^B:\004\277\345?+\241\013\021\027n\340?\034);\347\356k\335?.b*\307-}\341?Y\026a\361\037\317\344?H\220\t\356\366D\273?\344\024Y\3570\351\344?\022x\2035\261k\341?\224\340]\325?dW\203\263\337\007\346?c\324\221xmM\343?S\027\253Sn\"\356?\030\315n\375\226x\274?\323\302\272\325\223\324\355?,\261:\216\245M\352?.\247\024\362\277\363\335?dSKj\030\226\345?\264@\250\356kX\343?\330\353\233r\203I\265?\315\223\005\306q\\\352?\357\212\205\2264\331\353?\247\206e\263\024h\351?\206W\212j5\327\341?|\034\023\356\361\243\345?\002\356\310\377tH\357?\323\255 X\343\353\352?\300rQv3l\234?\351\247PP>?\342? 
\031\323wq\r\264?\000\253M\347\010Wr?0*\355\022\262b\356?\210\020\271\263\357\222\312?aq@f\214.\344?\316\336>\032\225G\324?\000\307\232\273\277\352\212?\261h\266\330\275\361\354?@ky\313F\344\321?x\235X\330x\244\322?\215\036M\223\340\370\345?V\027\022\305\252D\340?\260\341+\325\217\232\310?\264\377\270\2716\310\344?\275\301Z\256W\317\353?\340.\256\014\311N\241?T\001me\331\236\355?B\366a,\237\202\347?hk\321\336i\350\304?\000qz\237\333\007\327?\361\310\341\014N\211\342?V\335\245\000\361I\353?\0107@w\341%\335?\005{\373\221X\371\351?\263\204\034\324\346\177\345?j\231U\316\254h\347?\201,\033\013\324\"\347?t\036\016\nL(\336?\240\247E\202\237e\246?\320`a-H\270\312?Pb\367\225\260\365\341?\224K\2421\370\003\356?\214\344I\307\305\333\356?\366z\\\254\311\220\342?\014\r\241\\\347\272\334?\320\225\037T\003\236\340?\276\320c\350d\254\327?p\343\001\335\000/\242?\264\3777\177\350\312\321?\002\000\030[\254\212\347?\270\\\305.\2035\323?\250\251\366:w\370\304?\236\317\240L\326]\350?\204j\201\202\254\004\305?\250k\357\326L\364\306?p=\236n\024\311\333?\315\177\026.\345\341\344?7\237\312\345\227e\353?\n\325\315\353\177\227\331?\260\275\213\367\316[\274?&Q%\373\245#\342?\230m\t%\266w\261?\356\243Z\324\322\027\350?\221\253\340]\237Z\354?\356\332w\304Y\020\346?b\021\016\262\363\027\343?\340\014\352\241C\362\267?0\356hcP\313\326?\230\230\270\317\260\334\347?\323\304GGr2\345?\027d\253\023\016\246\353?`V\364S\007f\322?\335\005\320\342\203\232\344?\210\251\\\213\314:\337?\327\272\340U\005\241\345?\340\210J`\273\307\314?\010@s\200sh\270?\200\352\016\372\345\306\263?\246\3605\235\003\253\325?\000s\355\226\370\352\203?\275\024\032x\220\352\344?\310\251\347\013[\216\325?\300x]W\377\312\336?\235K\261\316S\375\342?\020\205f\254\320\026\324?\304\rZ\255\356\360\340?{\252\255&\030@\343?^\327\376\370\036\204\355?2\215hV\234\332\353?b\301\257y\313`\353?n3!\343\253\277\321?4\005\301\267d\216\322?Y\216s\371\027z\355?.\223\363\244\355\202\354?\241\235J\260\263:\354?\000\364Py\342\006n?\346\257\214\t?\304\341?
\342\331~\1778\214\350?\236\241I\nY\252\334?{\001\2063Q:\355?\020|A7\031\231\333?%\277|Vdz\341?\332\270\354\202p\035\326?EXj\213\211p\350?\207\350\263n\322\230\354?\351\373[X\016\217\356?\312ux\311\366\343\342?\240j\376dk\256\220?X\364\033L\361%\310?\216\244\275\215\225\364\354?H\217\210\020\235\022\315?\340}\271\256\252\374\267?@\263\252\373\322d\202?\026\351\021>\324_\326?\026\374\204\217\301\357\321?\272#\277\2261\002\333?\260\301!\350\276Z\352?A\334u\363\004\241\351?\277\347\226(]\361\346?\212\273\375\225\004\274\330? \231\232\230e\024\354?H\205@\301\302\303\275?\341\2608C\0250\341?\257\304\3133\335\215\342?\024\373M\221\001\217\334?_\211M\373s\255\357?00\272\202\334\224\322?(*\005A\363b\276?\032\273\224\336\374:\357?\254E}\251\245Z\321?\302\222\270\336\325\351\355?^\242\367\260eC\325?\244v\335\nU\007\334?\334\272\231\272G7\332?\252\3534\\\2404\351?\177\317\312\023\023\251\350?P\236\0376\311\365\314?X\001\343\014\233L\261? \363\235\270~\014\272?F{D\265\241\321\350?\270 \236:\373\017\324?r\330\360\227\264,\346?\274\270\242\365\240\305\343?2\027P\3343\001\326?\220\301\345*\006\003\316?\2321M\270^y\341?\001\324r\273h]\342?)\237\206]\326\361\353?b\217\3064-\352\332?P\023\007\2471\213\341?\367\247s\352\n\222\342?\366M\204\3520\250\353?\350\364\233\001\210\317\274?\200\342O\356\001\001\225??\032\346\007\226#\352?\357\251\313\313\343\225\343?#5\177\2418&\340?\205\002a\326\214\236\350?\320m\325\355\351F\313?\020&8\036G\317\277?\256\275V\177\332\271\337?\341\362!\006v\261\357?\300M\331\201^\353\254?\373\002#\331\240\332\340?\300,\261%\335\251\331?\300\031i6\346\324\324?\226\335\\\020tn\353?\354$\253\'\256,\352?t\'\340\210\374\247\350?\310\034\335\253\337\265\313?\rv\201\035\356\305\340?\253\274s\304\323\232\353?^\210]\244\262\261\333?\010ua\321\221o\313? 
&\200\322\244:\273?\320\363\013\240\303\201\257?TN\366\020\224\351\305?\214`\256\027\332\034\340?\020|\356\033\"\277\305?\014wG\311\323\314\346?\026\227\025\360\316b\324?>\216\360\277\006\236\325?\264\376\237\221\253\337\300?0\3637\253\353\301\323?\304\261aFWj\306?\320\'\355%\004c\276?\200Q-\025\330\r\223?\374\236or\207\260\306?\270B<3\315\027\266?D\227\344\343\261Y\342?\352.\365\t5\331\331?\264W/\271Z\t\320?\306\326\355\025\360:\341?\256\312dE\270\017\347?\020V\'\035\230\035\303?\000Am\356p\006\201?\210\304}H\023\267\311?\300\372r^t`\211?\231\363\215\307\302\213\353?\346\343Z\034pm\326?\3006\260\027\204S\355?\223\006\313\237~\330\341?D\265\263\333E\001\357?lY\007?\002x\327?(r\236\247\013y\351?D\346_ho\036\310?\215]qMFY\351?t\330\303\r3g\317?\320\343m\250\243\255\276?\271\\7g!\014\340?>\355\207J\030!\347?@\214\331>I}\256?\370\037q\3559j\333?\236\244\240\\\302\234\341?\227d\304Q\316y\347?\375\235\215\362Y\033\340?\301\254\322/\231\274\347?\030U\354jk\275\341?\005=\343jCa\355?\020\353%\331\300\024\243?\253\231\345\007]\257\343?\007n\220B\273\347?d\014=W\231\342\350?<\301\\\340\311\320\320?\334\231\247\240\344;\320?\240L\361\245`u\341?\367u\n\317\370\263\346?\3453\271\035S}\351?\014\375i\t\013\022\356?\372\300-\245Dg\335?\326\343\307Z\344\217\330?\002X\254\362\240\261\344?\304\337\257p\022Y\334?l 
\023\216\340\373\336?\340\025\305\265\345y\356?\034#\252\324\017\247\352?\264\323\262{G\215\307?pv\0359\001\315\306?\225\3428z\204c\355?^\341\343\017S(\344?\261\343\027\262\327\236\342?3\t\322%:+\344?d\035\232\232rD\347?\034\366\021,dY\315?%\356\244\340\013\340?8v\000JQ\231\305?66\326:\035\276\353?_\246\207\253\377\371\347?\216\352\213\275\216\247\346?Y%N\326\305\332\343?\206/\362W\'\370\327?\272\024\"w\302\035\350?\035\345UF:\256\355?\330\033\230t\254X\356?0K\365\007\"\206\322?\252\"e&\353a\331?QZ/\2703\273\350?\236\300\014\313\016\234\346?\244\001\332\341\023\346\321?\366\016\014\204\364\230\331?t\014\216\207m\320\324?aOD\0273\354\353?\006N*]\303\237\336?\002W\014L|r\323?\036)\204M\345\020\336?`\\\263\213M7\260?\340f\352U\273k\224?\216\235AA\013\"\332?\335\354\037\013\0349\343?0\253\000Wp3\264?\000\260\335~\037\306\257?\214\370\305O\200g\316?\2404\351\022d)\267?\332\326a\006Ox\322?\224;\025\263\267\316\323?\316\276\343`b\342\341?\220\273\032Lx\273\330?\030\236u\362\300\261\327?&\233 \266\375A\341?/\274\374\206>\337\356?\321\020\353\366\2550\350?\301\322 
\253\037\200\353?|\270\203\332\341\210\353?\266\310\264\247\026p\330?\\h4w\315l\302?\262m\211$\020\300\342?E\303\227\356?\360l@\202\212\\\353?\375\257e\264X.\355?e\2052\3655X\351?R\251\224t\335M\332?\210\336\013\022\022z\356?Y9`\372\001\230\345?h\364\367\217\004\362\311?\354\330\216\353m\276\327?\334s\274\367\300p\313?y\272\376\231R+\356?\324N\376W\224\255\315?\000(oT\323\315v?\032\023\377\016\006=\354?\220\264\224\327\352\255\262?\357\244\212\226\313\372\353?\3170\306\350,h\353?&\267\242\215\327\357\357?L\315\354\352F\300\356?4\271\325\377\014S\313?@\233\252=G\200\324?\212w-\345\021\377\342?\200\236\350K\360Bv?1({\347\253C\351?\225/\354\226\202\325\346?:H[\0310z\346?\036@\"\\\322\301\350?\202/\266\256\272\255\325?\372(\337km\220\335?\026\253\314\225\321\341\320?8\n\310\252\244\302\330?\255\035\330\223\371g\342?f\357BkT\266\322?\240\224\233\304}\350\244?\340\273!\375\035@\255?\256\356\300Ny\035\321?W\366c$\357\270\350?6I#)\354\214\352?\312\327\007\272\325\315\323?\345\'\250\351\216\254\342?Z\021S@\265\346\325?\344@\316S\034&\344?\317\200\266\210mG\353?\020\202\300\337u3\320?\210\020\360s\215>\324?+\277I\332Ig\354?\2534\301\2733\300\342?\334G\177A\347\301\330?\247o\371\272MY\341?\250\021\0364p\265\353?7tS\331\2523\346?\nj\214\007\332\257\346?\211y\364\026f\022\350?\221\220\234\"h\202\341?\320\242\372\0018\274\340?R\351Ch\245\242\354?\324\203E\344\331\373\326?0\217\201n\346\357\275?L\257\247)\224\310\303?\000\027\026\345\276\003g?L\017\014\2421*\341?\320\362\363^~f\240?\377p\3656\2163\345?\360\025\361\270A\026\332?:\027\313\334\0162\356?>\316\302os\036\326?3VZd\307\345\346?\236\244\nP\212P\326?\320(\000\226&m\332?\254$\203+\311>\305?|wC.fV\317?\200\332\250>k8\245?\\e\025G\003\212\315? 
$\021$9\356\355?\3644*X\255P\337?tP#\217\036`\357?\236td\360\351?\320B\352\350\237\340\273?\266\354\202\336\307[\344?\200\363|\205\350\331\233?v\231U[Z\001\347?\320C\340kf#\240?\264\235v6r^\311?\314{\301\013\363\350\302?\320\222h\200\231\257\320?=\000\351k8f\346?\212\227s}r\323\346?\354k\202\027\225_\342?@s\327(\031\010\341?C@\020a\256\010\341?`:\260]\033\233\346?:\235\314\014\254\327\331?\314\tWA\310Q\327?N\024[\222\007\255\347?Y\202>\\\306\204\347?p\3701\351\t\300\356?8\331\336\216\201\226\303?\272\3013_\2524\351?\234k\320\243:g\341?*\027\342\236\265\257\341?\310\350)@.>\304?h\362\206Tr\377\310?p$\2670\376\324\246?\216E\274:\327p\336?\030\3327\324h\235\323?\203\363n\347\231\332\340?&7r\021\314\357\332?\256\025\314\022\025\207\327?\216\235G\277>\204\343?RC\230}v\321\354?@E`\311?\034\245?9N\352\240\236g\355?\232\035j\302\203*\353?\256\010D\3205\301\322?\364p\256\350\372\273\313?\366/\271VrT\346?{\0060\266m~\347?\322\233\370Xx\377\346?\200\371\007Hy>\332?\030\263[\360\037\004\260?\310\234\216\374\264\252\304?\010\345\235\'\205>\276?\376\346\361\027\271t\320?8\334\230>\212\363\317?\377\236\347?\020\360\245,\343\260\357?R.#e\0142\347?`\363\3024\326 \310?\'U\256\177P]\350?\366o\354\013}\203\341?O\016\377Z;9\356?|\212\241\347\230\323\352?Y\035\221;\374\247\340? 
\302\305\346ol\334?\333\372\033Qn\372\346?\372C!\267\003\t\342?\020YGQ\262\265\272?h\217\266\240\254\273\350?\220\232\231u@+\347?\032\200=_\035L\351?\026\310,&\261\203\351?\250e\215q\312\226\322?]\001&\241\014<\355?r\261sw\367\032\325?\324i\215\265\257\031\332?v\005\007\275$^\333?\371 \r4U\363\353?\024\362X\326\272\022\301?\374v\035\\\345;\357?\265\216C\321n\374\343?\321B\230;\r\305\344?\330\016=d\222\376\261?\310Z<\225_\265\323?6X<\005\225\253\336?Z@\320\240\330\344\353?W\236[{E0\357?\027l\266\347y\305\343?\204\343\365\303\337T\332?\307\030\316[\201\211\356?\300m\321\255F\322\317?\207lV\365\"\345\351?\210/\001h^\353\311?\314gy\341(\255\306?x3\"\254&F\311?\023\037\233\023\204\207\352?\300*\221\\\027\020\304?8\266O\364;d\307?\213\235\362\317\264\346\342?\240\364C\205\025a\327?\322\020\216\203\227\023\341?E\270j\362\331\312\342?$j|\247\227G\353?\210\\\213\341_\351\346?\361\222\324\032\216e\352?\010\'8\371\t\340\336?\004m\037\345d\217\310?\363\306\274\020Z4\344?\354\330MM:\207\343?\355\355\215x\253z\352?\210\032\177\307=\000\274?\"\223\0022s:\337?\357\364\377\331%\234\351?\031e\313P\310T\354?\336\000\315sV\335\326?cZ\'w\233f\344?\370\240\2211\025?\322?\240\223\273\275C\343\332?\365u\311\025\335t\345?\335\014\001L\227\306\344?8\370\222\342~#\326?<\202\0101}\023\357?\364nF\264\230s\315?\'v\277\207S\327\352?@H\307\276\220q\305?\210\001\322u_\024\313? \034;T=\233\345?\314\347\312\320\017\253\307?\300G@\365\241\370\217?f\304/6\017)\337?8-\353\322\240~\345?\350\360\240\353\205\351\341?\360K\366\002;F\255?\214\t\310\200j\372\306?X_)\300Mc\261?y\r\273|\3621\355?\200\ty\276\313\036\254?P\t\273\n\222<\314?\232p\325O\3003\355?\332\276[\037&\235\321?\313e4\255b\353\352?\334\307\326\210\034\204\346? 
\002\'\244\257\222\312?h\365\366\342\274\'\316?\013\223h\236t\206\343?\016FgOG\257\322?\003\037\327\367\255\035\354?\326\2045M\342\246\344?\333\315~T\220\276\352?\204W7\351\371\r\345?\360P\023RC\366\326?`$\237m}6\257?\000\237\220\307\312\304c?_\256\241\344CY\357?]\367\2241\305t\341?O\365\323\311\276\246\342?\360oHX\334\356\334?\204\226%\341.\036\317?\037\213\271\203\331P\353?\304\260\205rG\227\306?\247\315\020\337P2\357?@\301\245\345l\000\301?\310\221.\372\355T\352?GZ\360]\343\354\357?\267K\374\377\274\203\353?\234H\r\323[\t\342? \312M\262\256\034\231?\306\0244lLN\351?\250\242|\322\320>\300?\326\240\323\360mj\330?\\\202*\203*_\315?\311\001\277\0210\212\341?\347&~\3360C\342?\204og\351\017\025\320?\322\311\356\246\270(\337?\234-\316>\t\261\343?\236\"3\202\353\351\343?(G\010x\323\323\336?*\303)j%`\334?!!\220\254\233F\351?\300\333\333r\226\226\335?\031\322\362\362\263S\352?I\227M\251v\021\343?:\210\255\204\376y\341?\242\243[\214\374<\352?\323\272/\200pR\342?\350\223b\032\027\036\331?8ND\377\005\354\320?`n\375\337\341\363\346?\000\010\374?\260c\223?PDl\223\211\336\301?~g9\025wJ\341?\210Y\326\325\273\221\315?\017\271\301L\242H\355?$\247\375~\261U\304?\037\273\307v\177p\341?t\030I\006\235\236\347?\030\234\265\345\0255\301?\254\n\241\246\275\372\324?h&>\211\307\302\352?\204!wJ&\025\336?\310O\312J9\352?\217\306;\333\355\233\344?,\211\223T\242\233\353?\252\355\312\016_\010\323?w6\220\t\027\263\343?y\250\360\245\205\246\344?2\002\342\307\212\037\336?\310\223e\026\344\374\334?\034\351\364\'\305l\311?p\0228\226bj\313?\3364`\221\030\224\353?\201)\244\237\034p\346?P\021h|\272\035\332?\000\362\267p\217[\\?w\262>\255\231\231\357?\316\245>@\026t\324?\342\273\337\037r\273\342?\240~>\\\306s\237?h\370\223\322\261\014\326?\304\351\306\337D(\320?`\325\000\271!\251\251?`\np\353\354\211\316?\034\200\221\037\310b\306?\021E+\204\227u\353?\354\203\177\323>\023\345?z\247\304PQ\020\346?F\210c(\267o\333?\276aM\0166\347\355?dt\031\343\367\331\350?\274\375\331\367\363\213\310?R\362\243Wx\301\357?\3
42\320\275\341\037Y\341?\034\306\250\300\317b\322?\200E\364\252\233W\343?\232\313\2117\276\313\356?\354Ys\025\261\211\310?\244R\261]\'\256\357? .\317\271\231%\221?9\375\235V\241\203\354?B \016\201\360\275\347?@\217>\r\246\240\346?\002\223O\251\316G\353?\240\223\326\326\374\376\251?\014\240\240!\3334\302?\032\207E=\\\374\351?\353\\q\341-\221\356?r\231:\242}\376\357?\002\341\337^O\316\337?\200\364\036\335\257\376\263?.R\371z\036Z\321?j\347\341\223\357\356\356?t\227\365!\334L\302?\321\2436\343\307v\346?r \354`\223\201\352?\216\305\221\354\220\225\345?\000a\342Lk\202m?\272\267\"\217$,\347?\016\230\322\362\334\352\340?xiN\316\033;\275?\272\207\260\223;\265\350?\207\236Z|\301\027\350?2\204;\257\356H\356?\274`\351`s\340\324?(D\246\301\324\217\261?\240\206Xg\374\017\272?\200\320k\250\316V\337?\312\264\023\251%\206\325?\010\353\204\177G&\327?\270!\226,??\342?\326\2304\321\317\333\342?\232\234\375\317\307\354\322?\325\327\253M\341u\344?\000\302\375\235\213\351\324?pP\310\301\216y\256?\253\017\240\312\305\251\347?:U\273\366\005\310\343?\310\277\234eI\214\315?\001\035\276\257T\324\356?0%\244\031\305\203\316?\352\347Qu$\274\347?\210\354\311nwP\274?B\t\2114\007l\357?,m\000\231\271\202\307?\260)d\236E\305\356?S\375\0063\346M\352?a\212\343G\300\"\344?@\324h\247\354V\344?;z\026i\325\223\342?\300\252z3\267*\322?V3\226\270\324\004\332?\030\005\203I\336\370\320?\270v~{\234z\261?\340JmV\036\305\300?4d\225\361\301R\331?_\306\220\213:\372\347?\260\267\363\331\314\241\302?\020\373LE\322H\317?R\3753XE\255\333?N\353s\235\233\006\332?\377\337ZLm\350\351?\356\275_\215s\352\327?\257\2713\262\003\272\356?;\"\026L\237(\343?\310\036\366\323YK\357?\225\331\246\364\257\341\340?\354\262\201r\303\311\326?0SA\302\302\331\357?_\204\005l\263\010\357?\254\352\362\306\374R\305?\265\000nT\254\215\340?X\311H\004\212\264\355?I\375\257\021\202\373\342?l\"!\354V\036\327?\000\342\220K\017(\235?w\020N\221\253\005\344?\000\301\001\025\206@\334? 
\203\334<\241\337\332?@\334\255\363\312\030\232?\000@\243S\370E\267?\363/:?-\030\354?\227\370\323\246\254V\347?\202p\320R\336}\320?kn\304Oh\257\346?l\026\331\253\336\323\331?\200\250Z\037\230\010\275?2u\313\346\3251\352?\236\265\007\252G\237\333?\026\026\356\275\223\330\345?t\241\242\237\025\275\343?\002\327\330VV~\336?Wef8\374\003\357?\240\270Pu\241\320\353?\354\302\026!\035\261\302?\360\203\337\014\000\222\255?V\327\205\372\001>\343?\020\256\255\020\254\241\313?\262\215o\214\245\223\346?\034\\\326\324,\213\326?\370uo\347\001g\266?\370\2473\247\374\363\302?\360\330\212\0143\247\315?\\\244\274`\273\305\356?\352\362\236\331 O\333?\356E\270\366\363X\330?F\320\236o\237\212\357?\026\007\323\030\240\272\335?\350(\032\277\037>\264?\022\346\313\211O\360\333?w\340\027\371\220W\347?8s\254\276s\206\305?\270\014\317\231\346\232\312?\263\014\332\307sS\350?\375\317\364\274}\204\353?\266\343E:\341\324\357?\261\326\243\346\214/\342?\226\300\217\310\000\235\331?\210+\217&k\225\265?\"d7R\215\377\356?H0\326\206\315\235\332?\020\343\264\335d\337\243?\324\000\033\202#Z\334?\230\345y\374RE\312?\246\261\363\366\265\271\333?\3166\021\024\211\236\321?\010z\221\017\026\222\277?~\265\312\317\347\276\342?[n\\M\212\024\346?\222\237\344lV\021\323?\353ys\252\220u\344?\264`\310\213[\344\347?\242\227\\l\321S\344?\020&\371\310a\357\250?\324\343_;\245a\331?\n\006\370>H\247\353?P2=W\310\033\304?#8\335$\351C\356?I\205\250\0366\266\353?xY\352\260\352n\335?\324+t%d\370\316?\370\272\244\262\321\232\303?\\\022\331\302\302\035\356?\206\213|\311\210\251\334?A\272E\203|+\352?h/\250\224\207V\353?\321ex\207\236\002\356?@\214\214\351\325\241\252?\264p\331~N\344\331?\374\037\225\213\005<\336?\224\037v\270\303\177\302?xk4|c\273\314?h\003\026\250W\\\313?\230\\\022\037\371\003\310?\230\347\274@n8\355?\330\3063T\376\016\272?\010\371u\3616\022\347?\010RW\347\252\240\346?\035\213t\361(!\341?2\307\222\024\354K\341?\016\243\206i\025\245\346?\247_\002^=a\355?\374\033\336!\362\332\315?\306\201\n\334\000\334\343?\363\260\2
117Hq\351?\243`\342\377\013\316\343?\256\352IW\365\t\322?\253\006\234\327\353\315\342?\334\310P4L\330\302?\206C\351\220\363?\341?\300:Sm9A\214?\026qXE\030\037\323?E\016\243D\216\226\351?Z&\253\266\225\037\344?\325\311\034\317z\316\350?C\372\024\014(\360\343?\324\274C\204\265\330\353?\270\213\244\236\277I\317?\315A\332XNI\352?\340\365j\333\311\251\232?U\024\333\031\246:\356?\346s\367\021!\303\346?\200b\022`\256\270\332?\3760n\351\356\204\353?:\340C\276\324\312\327?`$T\255\177\020\232?B\000J\234\256X\322?\212\362\317\362~\310\325?$7)\221\221X\317?\227\277\333\251\310Z\355?\274X\3768\347>\313?\362:\037\340\262\306\340?\324\010&\032\231~\334?\010\230X{\020&\315?\274I\007\017\320\002\352?\201\230r-K\357\340?\nL\210d\3373\331?\354\n\311\032\361t\322?,\211b\232\225\366\357?@}\3443\214\016\354?\350\270\377a\377\030\333?\270g\345\004\014\300\316?\364\307\242{\377p\330?\250\226\320f\302\313\333?\274\344>\"&\373\351?\254K\323\213\364\212\323?\000\322\307\330\342\222\344?\350 Js\016W\353?dx\260;\036\010\303?\334\260\242\\\'l\301?.\\i\226v\004\321?\270\'>\324\335\346\340?\003\031\034\2629\302\354?\372\353\2607W\307\355?\320\020\223!4\274\314?8Ot!\031\"\273?\330a\314\026:\215\351?\002_\321\'\203A\342?p\376\260\0003\331\252?\017\003\247CO\264\354?\232\266\177A\304\177\350?\200\315\214\020e\351r?\317jJrO\354\352?\256\351.\341M\252\356?\ro\025\2510\030\343?\206\177\210\022.J\347?hqQ\'\275\035\342?\310\206nV\362\334\262?8\2072_2g\315?\246\004\350%n\346\351?\350\257\222\244\023m\263?`\022I \377\002\341?pk\262\t%_\342?\344\332\216\257$\315\311?@\204\350r\334\225\231?\317\376\3654\374Q\352??\354\346\216H\340\353?\220\353G\216\217\326\340?\346-\014\360\345\370\345?\\.\206\212\342\201\316?\330\206[x7}\357?\220\322\255\320\305\032\331?\324\303YA\265\t\322?P\351r\337\300\220\343?\334~6V\301\200\300?\234w\272px\016\344?@d\362S2%\234?\014\274O\343m\357\316?\316-\03234o\347?\346\255\320#\312H\340?\323\260\000*_\007\347?L\231\252\263\255\275\337?\021\034v\230\247\t\340?B\254M\177\034\347\336? 
\223\310\341lm\320?J\365\351\023e\022\334?\014{\'\216\033:\341?\336v\233\363\003X\325?\020\234\262\346\210o\301?\260\005\204\"(\324\267?W\301E\017\006?\354?|_1\207\300\276\302?\276\327\002\027Pk\335?\315G\344\246<\210\342?F \266=\016<\332?,\274cL\220\323\340?\027\266\242\323\213\022\346?\334\275\024G\255\377\345?\200\240N\307)d\226?\302a\302\307-7\353?\204,\326l\2271\350?\302R\021\312\316k\345?\266n\255\363\267c\347?\307-\253qGX\350?\340\260<\300\311\260\313?\330\254\356\312\302{\347?\276\353\032Q\023\352\341?\310\327c\372\024Z\302?`&KO\243M\252?\351\022>\\\321\321\352?\334\205PI:\261\323?\016%\317\262*k\355?Z7\024\346\216\263\325?\340>\370g;I\321?[X\003\022\306^\350?\032--x\205\300\341?Ql\266B\230\333\340?\206\245\351\243\244\026\334?\264\360\263\225\223\351\303?&\3671\316\365,\327?\344P\245_\352\275\315?o\271\n\206\n2\354?\327\3722\361L9\340?\216F#\333\337\360\322?y8F:ud\351?`\030\305\216\335\007\274?T\335\026U\350\274\315?\307\266 K\262\271\340?\252j\002b\325u\337?\370\312B\215-\363\311?\230\275W\336X\343\354?\000:l\212\341\370a?\225_w\014{\032\357?\363Y\326\346zo\356?\340\031\261\002\031\246\250?\250\262\327\317A5\347?\\\374\375\266C\317\314?\014\264\340f\300\304\305?\010]\363(\236\014\311?T*\354AV\010\342?`a\277\"i\r\316?\364\332\266+#\347\327?\262\234\303\014L\217\341?\226\354_~I[\352?\342\022\343\017\021\334\323?\376P\013:\013\242\332?\005\315\303b\335\205\352?n\032@\236\007\353\354?<\230\337em\200\314?z\233\nO\355m\351?\200\260l\230\021u\312?<|\3546\225\'\315?d\t9\326\311,\302?,\230AM\305\004\356?$\247e\037\265\255\304?\376#\317\005\246\240\330?\013/!G\356\225\352?\340\020\365\"m\333\315?d:sxP\017\352?\260f\t0\0052\317?\230mHC\231\004\265?S8Uf_\366\356?9\314f\227\236\037\345?\200-)\313\252\n\231?\250\224g\277v_\261?\230b\350\037\232q\263?\200s\223x\"\230r?\200\\eV\202\211\236?>Z\324\310\246\024\335?\3160\n\355\251L\327?\024\246x#n~\345?\212G\270\2463_\336?ZP\006\215\263M\332?\365\020\316~\342f\341? 
]\365K\3467\320?\267t\237\375\243\344\342?(P\001\213?\365\307?f\\\244\250\337\035\352?\006\020\213\r\211\361\321?@\346s8\024\213\210?\335\327o*=x\340?\366\202w\316g\272\341?p~E\257\236`\341?\274Sh\350\032@\337?h\313\204\335W{\301?P=o=\325\321\347?\"r\316\024\230,\343? $\212\273\3039\313?C\2643\232.(\351?\263\335%\t\350[\357?\325\216\350eK\362\342?\210\336}\222\030(\272?\2200+\256 \223\250?#\344\261C5K\350?\356\214d\362g\347\354?`\253\347\307tJ\355?\261`\324?\201\037\350?\2043n\030\300\003\341?H\301\250\006\312\222\344?\317\242\000\324#)\357?8e\312\331\276\362\321?\257(\000\205\220\312\355?@\363\342\300s\251\247?M\265j\373\032\232\351?d\032\021+J\356\322?p\326\027\240\315\227\324?\200%\016\025>\362\324?\037\331\'\370V\035\342?\244E;\212\250r\353?\200\331\233kEW\266?\222?/{6\277\333?\n\234m\207Q\212\327?*\232\2539\027\345\355?\224L\346!`\345\350?\307\233\261\027\273u\340?\254`!\234\034*\302?\272\370_X\212}\343?\021\003SU}\251\340?\240\216\245\201\311<\256?\373\010[/\374\345\342?\373\3559\230\201\371\357?\243F\030\203\306\275\342?\256\022\334\346\317x\332?\343\3531\327\343$\352?T\315w\177\261y\347?\220y\363\"\233I\320?\202\014X)d\231\344?\240\374En\210\264\237?\343\313\316\347v\270\354?!\354\003\030at\344?d3\214\030\016\274\305?\n\217\327\315\333\010\352?\374\240`\300JX\304?\323\342U\324\215\274\342?w\010\364\307\340\244\345?\307\274\277\310\272\275\347?\252N\256\312.\035\356?\026\230\253\366\356\323\322?]\034\351\205w\214\346?\236\307\326!\037\360\351?\272\325\376\220\234C\326?X\346a\227\n}\337?\236\372\211\374\204\342\326?\300?\260>\271,\353?\314\301x\313\202\004\300?\034\351\223\224\323\274\307?\3362;r\260z\346?\221j\002| \204\353?\000\200`\317T@\337?\340P\037\021\204\314\230?\032A\365z\001}\356?\355\3045n?I\351?0\t\341\032\000!\267?\304\337\035\022M5\347?\274\210\230\005\262\352\316?d\367\026\005O\013\303?%\014R\220\337P\346?Pn\260\036\n\224\331?\360\367:\177 
[\335?\\Yf\037\354\262\312?\244-{\242\025\202\336?\274\2536&\346\002\323?\222\242M\300\302\226\344?\264\262\247\230\260\033\345?n\212\236\301K\235\337?$\236\335\001\022 \301?`\330\272\035\022\372\247?J\2065\001|F\327?\000\245\00330|\335?\253\311r&-\005\350?\302\246\300\323\225\264\347?Z\376\343%0\265\353?S\217\302\231\007\263\342?\25154\000C\327\351?<\353\220`\376?\331?,\335\252\nQ\261\316? \362\304\004\314\207\263?\365U\264\325C4\343?h\224\211\177\250\274\273?\320|\023\316_\335\305?\232\317\007\245\342\207\345?\0241\314\027Y\205\341?4\204a\327\000\272\336?Tw\240\314\271\252\311?\000J\255\363\217\347\311?\336\271y\276\202\263\337?\212\367o\250\2579\342?+\317\262\304f\232\353?\243l\246\3762\343\340?\273\240\242/\347\016\344?0\\\331s\357\304\302?\023{0\226q\376\345?\220!\256\0017\002\274?uY\305\276@t\343?1\3378\024\306\276\345?\020\333\263)\211I\320?\344\257\360\034\025B\302?\244\256Fw\325\370\312?\214fF\334\016>\330?x\251\316\355\2406\337??P\316\234eF\356?.\301\275$\264\251\322?\030\241\226\276<\276\355?\212\n\021@\3327\354?\000\215\307\0338\334\212?Kx\2025/E\342?\314)\327H\222\004\307?A]\001\255\337R\346?\271\315\'\274\201+\357?\340z\370\211\236\214\267?%\240\333\261w\341\357?\220\363\242\206\241\243\342?q\374]\302\310\264\351?\021\035\275\207\311W\345?\004\325\237\362\277\350\355?7\213\203\224f9\346?\346\260\245\212\205\304\353?\3200\363x\235|\262?V}=fn\343\324?\334\341\316\024\216\264\316?\306iT\016\340\273\336?\360ES\263\220\374\273?Bqv\377C\"\350? 
\302\276G\316~\243?\006(\233:\202\232\322?\326\265\320\211\003\225\353?\010\307\333\253\032\330\347?\354\270\353+\354\243\310?8o>\245\351B\312?\323\033c\017\366\374\356?\330\254\302\037\207\006\271?\240\032\255\025\300\217\265?\310A\247`\211\340\301?\361\210/\377\033\201\343?0\217\263&n\353\252?\330\001\024\005@\020\355?\t\213\303\006%\027\352?\252,E\0316\034\350?`O\223\321\005T\233?\224\250\351\233x6\326?\372\261\006\363\250\363\344?\352\3310\376\020a\351?D\014gI\030\177\331?`\214\022\211\212\303\332?\342\252b\225\207\373\322?&\237\356\255\'\037\341?NPw\223){\326?\330T\331\010N*\350?\020\255q\2264\212\340?f\020G}\333\213\355?\200\350\366\331\273e\304?\002^\226\336\217\034\333?\300\216\267\231\262\226\320?n\343o\340f,\333?\35754\360B\000\343?S\3741\360\342\316\350?\0264@Q\231\234\322?\256\322\354?J\323\345?\314\034\223\211`\314\344?\230\234\016\n\236 \356?\354\"\233j4&\324?\014\002\243\200k\264\356?{%\004\317\365[\356?\010q\224\234\327L\321?@pJ\330\307\231\303?5\335^\265\222\375\354?\264c\205\251\332\200\310?\244gN\212M:\354?T\354\006\227\233\037\353?\230\311\307\364 \335\335?\254\212\343x\322\315\306?Vz\004\234\262a\337?\236T\352\367\243 
\327?\314`\023\2259i\347?\360\014ob\006\251\253?\"s\001\266\2534\331?\364\274\315/,>\340?\027\016\037\267\325\242\340?)\332\302H\311\255\345?\276\021\351\246\317h\354?\354\335_v#7\314?\274\375/\326E\265\334?\\n\335\312\373\353\317?(\363\201\035\307x\261?\214\000\374:o\370\334?2?\342\030F\377\324?\324w\331\245\254\374\305?\250\240\331T\324\241\332?\236f\252\257\237\340\330?\345~+\341\220\260\356?\010)m\3240\276\267?\330\314\322\371\211\314\333?\326j\347R\030<\327?w\332\214o\0234\354?\005\225\310\374\334\016\342?\005\345B>\r\306\356?\366\335\240\372T/\327?h\206\341\223\213\236\315?\217[\273D\254j\354?\237\262\344\226\370|\355?~Y\320\3155\265\347?\034\005\266u\375\354\317?\250a\271\220,\"\321?\004\275i\275C\202\304?P\']\307\003@\241?\224#GLw\'\343?\242\220\252\034\311a\344?\305\004\352\351\254e\346?\2403,\352\304d\256?\266P\376)\342\177\356?\300\235\227\000\365\321\235?\360NB\270\354\023\242?\036b\n\010\220t\324?4\245\205Kn4\300?!\026tD\263q\347?\360pS\303\003\260\337?q\3408}\330\\\355?\314\332Zt\377(\336?\022\354|]r\315\330?\2169\230#:z\350?\004\276\226\307\014\023\313?x\"#h&\222\320?\316\2633p\230t\334?\244J\233g\362\355\322?\324b\206\267FY\337?\244*\023\260\267L\341?\006\246m\326\2064\342?K\260\257[&T\343?\300\306\366<\267\224\300?\210\"\215gD)\326?\263~P\016\031`\356?\3202\250\277\213>\271?J\';\251\366\315\321?*_7,\324\037\337?\370H\266W\245\257\315?\260i\037\240\034\357\246?j\235\033\210\355\221\330?tG\377\300\352Z\327?s\361g\027\210\257\344?|\335\032\273\373\317\333?\300S\254Ibg\235?\262W\357\017\266\276\321?\266\200\030V\265\306\322?G\002m\366\316\350\355?P*\375|\030\367\321?\264/\020\347\327*\342?\246\261\264\365*\265\327?PbS\034t@\252?\310\215\367\263\377M\304?<\313\235\331\247\006\312?\362\307h\235\363P\347?\210\334\230\013\373:\345?\206\270\233EQ\216\347?@\220\206\254f\310\303?\324\2675TG\373\333?\205\240\344\230\246\322\344?\3001\350P\310\244\322?z\223\240\235\264\301\331?$:\333\000U\277\352?\214\331\305\243\275\014\322?\365K\305]\0058\347?\345\245\354\271V
\204\346?q:=Z\220\375\356?8\032\352\367\371*\305?\305\204c\270A\247\350?\263\342\267\374\017_\356?\212\261\255\351\356\010\334?l\214|\244\023\027\302?`\365\027C\326\310\320?JR\371+\330\274\346?\246o\211\241\002\351\341?j\263l}w\354\321?\025\r\347\0304o\345?\002\263I\237\031n\330?\324\272\336w$T\341?\240\224\314\366\tQ\236?.9\341jP\200\356?@\265\350b\334\354\267?A\263Z\305\002\353\344?8\\\235.\302L\342?\316\010k\"&)\342?6\367!.,\240\342?\356\010R\215G\331\322?N\257?\256\033\232\334?8[\241\303\033\221\355?\002n\224\255\361\371\324?`*\364\343\003\343\337?y-\304V\367\254\345?\200\351\326iZ\031\257?\202f\3619\332\263\325?d\021\276P\375,\327?\006l\220\317-\250\353?\305\231\273\214\204!\341?\310\303\335:~^\321?\277\2008\310\332\357\344?L\240\331.\253\002\316?Rj\036 O\245\347?\204\327H\253\364\"\330?\\?T\330R\324\312?\235\347y\303\346K\354?4\341\334\347\332V\320?\232C}\213\232\177\327?~>\372(nP\357?\0344\177\371>\371\324?OS,6\000\"\345?\200\017:\006\277\216\305?\376\014N\271\027A\325?\251\252P\236\373\005\353?\270\317\233\264\360k\265?,\250\201\203*\370\321?\336z\311\203\224:\334?\340q\372Y\305\230\244?\002/\336j\'[\347?\312\215\352\225\204\342\352?\354\305}\3600\226\326?\306o\354\350\020 
\332?\004\335\003\366MT\310?\\m\375\236\352\326\347?\260\211cH\312\264\261?Q\303\027xP\373\344?\022\260}f\nG\340?\347\244\301\360\204\352\346?\301T\034\014\214v\356?\000\320\024\302TYf?\314\361\371\264\362\304\353?\233\265\367\234\371\312\347?\025PR\325}D\342?_\341\007@\354\031\353?\320y\017\254e\274\303?\007\311\250kf\005\345?\357\2779\370\212}\344?x\256\201\363)\271\270?\330\3410\265n\216\351?\t\253\231_\345\226\352?o\204\360\257w\347\355?\n\030do\2330\327?\326_\\U\006\360\325?\256\0059|ut\356?b\033G#k\002\334?\n\001!:\346\252\323?\315\355\341\034>\t\356?\356\261\376\362\016R\351?\016\341\320\002\273\323\352?\\+L8\240\210\341?\204\222Gi!\247\301?\336\327;F6\234\343?\n\364\330dgX\357?\202\270\245\263\320\267\322?\344\334H\006\215h\344?fqD\034[_\320?\270\037\275\016t\372\264?\340ol\022\322\370\272?@\315~U\0001\326?!v!\305@w\356?\374\367\021\361>;\310?$w\257\376\324\201\327?\341\362\352\177\351r\356?mXG\025R\303\341?d!\306By\030\311?`\017\r\310\264\022\314?\245vB\024\327\000\342?\356\365Y\007(\370\321?\352\344C\363\362\364\342?\312\3714\271\233u\356?q\320\0356Yw\344?\266\031\033\336C\331\340?o{\025\215\207\017\347?\340\351C\264R\324\267?\000\006\"\214\037\030X?\210uv[_\262\315?\367QC\036#$\352?r\315\000\033w\210\353?K\241j\221~y\343?\000\365\261\327\253\231\276?PK\340$}U\265?d\267\374:0\204\354?%D#17\255\351?\274|\332\332i\365\300?V\030^$\0130\342?F-\t\312\237\276\346?\036\356\211\r!\232\330?\346\352%\341[\374\326?\256\270\352\326\354\017\334?\374t;\027\3117\317?ra\236\351^\001\345?\2222\004\340\342E\324?\324[WP)\366\322?\364NM\340\034`\325?\230g\337\0252\021\313?\377\024}\244\321\311\340?\264\246\024\3214\362\330?F\2638\2429\371\333?s\337\r\\\005\272\344?\300\314\356\335`\\\320?\306\302\273}\242\321\352?\035pf\3019\211\347?\300%\237\264\210\024\217?l\313\342~\325\207\343?\256C\333\006_!\353?\000\243\376\220R\305\313?\246\003>rq\350\336?\370\321\365\027\314Z\301?\367\024$\305\020F\343?\342\237\341\350\370~\327?\217P\3756\030\252\357?`UAsM]\305?\272^\250\007\277\037\3
20?u\302}\307\364o\357?\034\230\311\211c,\351?\230\016$\2325\263\323?\034\202p@\267\226\357?\320\322s\304q7\313?\314\334s\261`\340\337?f\362\217\344A,\345?\006F*\200/\"\331?8\251A\r\023\327\320?x\343\273fz\262\302? \013\306\364\305Z\224?\372R\364\352\276\230\330?Gs\361\2516\270\356?\223[\357Obt\354?\352x\000Go1\324?0\tN\215f\327\266?V\200\213s/\250\346?\240\263M\035\226:\274?N{\030\225\\\346\352?\236\213\3269\326\001\335?\304\331\322fj#\342?0\355^\314p`\257?4\021\361\253\306\321\303?\230wk\033\227\226\270?80\250\214\226\301\345?b\256\222\003\3030\325?@7\217~\320\205\355?_`\3424\357z\352?\004e\226\323F\224\334?\200!s\221s\226w?\363\\\266\234\021\236\340?\273{v\363!\323\350?\310X\212\000>\004\321?\334P\267;s\200\300?IK\034\2509\323\351?\326v|\365\225\256\347?\270\036\000\305\000\375\337?\006\345I\347E\365\330?t\344\375L\221\325\336?t\244\240S\013\332\343?$\360\r\372\251N\341?\340(:\303\243\363\302?@a\024\246\203\r\206?\"$\021u\232^\322?\214\024\317\301\353%\356?p\305\334\337IL\317?\000\'=\020\031t\331?\223\302L\375]j\351?`\271\221\326O\033\263?|\304&i-\223\357?\306/\245+\253\305\352?U\r\021\343\001\222\357?\314\321\225\366\204\326\347?\272\241y\365q\263\324?H\265L\352\220\314\335?X\202<\013\200\001\336?(\004\241\030]\255\346?\032@\375\273\240\024\330?\274jj5\3120\310?\326\214h\232\003\337\330?\240\207a~\264\253\353?.\021k\010\307\231\353?\206%(\020\3106\336?\032\300\323\200C\325\325?\220g\301\275\"\352\323?\340\255*$\225h\352?\314\243=SMT\326?\204\345a\3176\276\320?\375i\"a\303\313\345?\224\347\237\312\3533\314?0)\201\312\023\215\273?@\254\333\362\252Z\356?\306u5\315\237\t\344?g?\013o\233x\347?F\236{\353\344L\326?`\302n\204\226\371\326?h\326t?\371p\356?\2307a\307\365|\274?LF\020\337o\210\347?\300mS\036;$\324?\324\266\323\330\315\231\347?!\307\365X\020 \346?7d/\352z6\347?\254\206\252\236\261\234\354?\336}F1l\240\324?p\322+\n\260Y\304? 
\316\003(\224\204\345?\245\232\006o\311\371\351?\024\013\363\326\263\350\305?\022\323\251B`\322\357?@C\022\341\320q\272?\371\362\314\300\340\013\345?\034\2464\321\212\272\317?\262\327\277H\204\026\333?\221\317\267&p\270\346?\327\022\341:\204\274\345?\325\314\301\350\226N\344?\010\335u\377\222\314\311?\336q\201l\036\010\327?\030\315z\215\032\331\341?f\256\337P\321k\347?\013\224z1:u\341?\000\307_\035\222\312\340?\304%\203\336\262\207\301?2\372%\241N\375\332?\013\034\024\357\204\275\356?0\027\342\3214\222\347?(n\330\375ns\273?!\003\234\323\021x\346?uO\244\207U\302\355?@\302\330\024\377[\312?x\021x\225\342\304\274?8~\240\271\217\034\350?[\226S&\211P\342?^GI\361\257m\330?;\360Q\013$\031\347?\316\354\366 \372*\337?<\203~\336\267:\332?H\333\354\264,C\316?\205\354]zCi\377\340?x\010x\272m\005\263?e\337Fq{(\355?.\346:\341I@\323?\010\277\235\0365\264\262?sjO\210y\331\346?W\337\311\216\013\375\356?8\366\013D8\351\262?\024ds\2179\035\343?\314\237\322;|\320\307?\374HleJ\375\327?\263&\340\032f8\355?\360\022\221\302\017\233\340?\355\311\263\277\262?\354?\364\347\351\320+%\322?\005l8?k\344\346?\037=\2207\367\026\346?E\327\225\343\326\205\346?\nY\004$\034\345\357?\020\r\272\3621\304\334?i\213\001o\311(\342?\004\356\250\221\235t\305?T\373\331\273&\341\332?`p\324\205\256\256\254?\256\257\rY\006+\350?\222\224\005\213G\334?@\226Ou\242\377\337?\024\034rX/\037\301?\250\006\010\363$\210\350?H\343m~\200\320\334?\300\221+\367\021\277\340?\023\020\304\r\251;\357?X\320\304|d\354\274?\370\030y(\321\253\325?\376F\360\205\307\337\353?x\252\312\307\275\257\353?\201\r6&;\361\341?\250\312\026\021\345\363\333?\227U\310\313\245\327\354?\256\361\246\337 
K\347?T&\360\221\364\270\337?\300\203\326\016{L\207?\276\327\346\377\010V\355?\306\nd\006I\026\323?\322\345\264\255\340\250\330?\312P\226}\020\236\336?Q\356\n\220\357\366\340?U\3118H\016u\350?\212\3160z}\251\351?&R3\212\026\336\325?\344\300\250\250\273\301\316?\332\334y\306\270\341\320?\240\213d;\317\014\332?|\223YK|\005\321?\237\277(YB\374\357?\242\250\243T\020\273\320?\3443\323\241\026\300\326?5ue\352\366\376\340?\020m~\263)\215\307?@e=~\225(\242?5\226\307\365\203B\345?\216*\367\004\177\354\352?\236\332G+\266\322\355?(0\315\202`p\325?\362\311\007\243\021_\337?F\212\367\201\207\357\352?\360~\315d\266\014\314?t\031W9k\025\325?M\366\3775\347\263\340?X\356D\222\324\034\351?R\033\237\025\242\353\341?(\323\022/\364I\300?U\3670\353\270\362\346?\224\223{\3032\347\346?F\325\222=K\345\321?\222\204\315\345?(\353\217\363\353$\306?\rG\300\240A[\343?\271\371y\032\276\217\342?\324\372\376\236\355G\346?\2448K\022_\367\322?\024\031\212X\245\354\323?X\300\211\313\'\205\310?h\020;[\334=\272?\210\364`X\241\014\263? q\221>I%\323?\370t\035\324\370\205\340?e\254\364\013H,\351?\373\020L!C\031\342?Nm\304\202\270\250\326?\224=^=r\235\304?\236\327k \255.\337?!\325\321\316\0064\345?h#\321\362:}\321?:\001\220\315\262/\345?s\035\201\353\006I\355?\330\003\233^\032l\321?\203\304\2276\000\327\351? 
.v\t\270\020\331?\000:\202\225*a\311?\006\232\245t\017I\325?\310\275\370*qc\300?\266\"\212\r\035\024\341?`W[\366D\265\273?\220n\364\243\363\314\304?j\360NY\311^\350?\263|\361\203y\263\345?\004\267\241\272gd\353?\010\341\301c\234p\332?\0200\322}\006\334\336?\250y\227\005\372(\350?\332iC\335=\333\347?VS\"\340\357\004\330?\336\321\255\264b`\351?\262n\334L\227B\354?\360\321\272\020\032\324\334?\264.\000\362\277W\343?\372\274m\346HM\323?:-6\341\016\335\325?*\265\022\340y\273\351?\354\2673\322AQ\311?P\256\213!\024\376\315?JS\212\024b\270\356?Y<\222D\263\034\353?\322(\344?Y\245\337?L!\243\232\321U\312?v9(\223\366\204\345?b\013D\311\237\372\354?\316\363\233\206\244W\340?H\364\025>\267\241\324?\\\216V\335\244,\322?\343.KE\033\373\351?\240\215_XWc\347?R\320T\207\320l\334?@F9\022\306\351\252?,\224%\035\026\311\307?8F!\330\347F\277?6X\372\353\237\013\331?w,35\314v\346?\031n\260\3606<\343?\334\360\033\001\035%\320?\037#\200\237\206\373\352?\366\272\241\247\334l\327?:B\245 \314*\320?d\327\007\352\303\371\317?{l\257:5\325\341?\0166s3\307\325\325?W\262\022%\010,\357?\220\354\357\375>>\333?t\346g\346\223\302\346?;\n\221\202Y<\342?T\363\211\264\216x\325?\013h\302\220TE\347?X\262|\2124\355\340?\330\345\324\026\257\277\260?\010W\031\344_t\275?{\034(\034\030(\355?x\303^\353L\370\347?\030ViK\333F\327?\340\263\211\311\346\247\241?7\004\314\013\"y\344?\210\201\321U\245/\347?\327\376`\315\225\202\342?\360\314\224\274\357\021\272?\365\351\213\035\213\233\354?\'\365\2757\003\365\347?\320\376\353\244\036\267\326?jR\274\375\"\251\321?o\210-\313\345\034\350?\034\202\267\266\204\371\326?\010\267t\2053\200\265?`\320\r\341\005\t\263?C\003#\017\333\271\346?F\006\252\034\257Q\332?\253%\263\'!;\346?G\277R;L\221\356?\370YZ\354\177\355\270?\352\332\271\020M\247\321?A?\252\253}\252\350?\274D 
{\347W\342?X|\231\241\345\326\333?\247\310!\331\365Z\341?\200\252NH6\357\217?H3\367\245\004\330\266?[\\[\307\355\'\352?(\353\360\203\316=\275?d\376\206\013\236\331\332?n\216\020\330u\250\342?Z\246uS\244\027\340?\355\221{$|\223\353?k?IBu\205\350?\352\367]\232\355\333\327?\016|\260\251\350\336\347?X\013%\274A*\264?\224H??\314\260\332?\274o\211\301\313\354\320?\232L\2374\035\341\333?\240Q\302\210!2\235?\372\302\265\304s.\335?\2600\270\240T\225\305?\320:\335\365\372\016\263?d\244N\253\210\207\310?P\306\207$\244\312\337?\270e\250}\226\210\354? c\307\n,:\301?\212\364\314~>\266\344?\327\253\035\227)\322\353?!L)\016\177\307\350?\356F!h\207b\325?\002\360g%\332\365\353?\240\254\224t\324\230\266?\000\\\3068\352\347K?\253\234R\016q\307\343?t\310\244\233\247\307\334?XS\276\254\320O\302?\200\300\353K\2569\353?\020\021)\206i\206\253?\2507\\wd\303\320?6\241\030o\364\236\325?\034/&gZ\354\336?;kW\216\322N\353?\224C\265;\312L\323?\n\243\333\2722\\\353?\200\274]\200\255\347\202?\344\277\025F\020b\352?\"\334\024DY\267\346?\204\354fya/\303?\360\245\242\3640\225\324?\245&\2155\'O\347?Z\341\r\0348\330\336?^GT\003\033\274\337?v\3763\301Ql\334?t\316\231h7\234\302?\300gJ\210\376p\336?\200:M\332\026\020\227?\302\242\027\214\246\010\323?\2604\325\244L\371\346?L\003\260\327(\302\314?\232\311\357\013gR\320?\342\035\337\316V\305\331?\230|\233\326\0027\350?I=\026\2324\010\352?H\360\210;\247\222\274?\200\213\245\211\262\375\233?\212\035w\207HA\335?\302u\301!W\257\325?~\310Q\002`=\321?\224r??\024/\330?\231\375\305\212h\321\350?\230]E\0327\023\355?PX\261Keb\240?f\006r\361\353\254\335?\200\265\222<\n\t\355?\244\374\301\010\306\010\343?\316\365r\274(\016\322?\0206\025\340\323\013\275?\376\037\327e\312\357\322?Z\302S\376A{\332?\030\025\275\270p\342\317?\202A|c\0173\357?\232\247\333\350^\247\326?&\275\225\343\300\003\347?\301\246\260D\365\342\341?&\202BB\261\373\323?|RX\203bD\316?Q\363\340\302s\233\357?@\006\021\235|R\272?\264\325+\315\204\336\316?0t\024\202\2042\252?M\252\307%`s\350?\214\335q}\2605\321?\
367\377\231{.\251\353?\310i\232\"\377\217\341?\365e\224\214Z\254\351?\344\315\264HP\312\356?\000.{Gm\014s?\024%;\345\327z\316?\024\331\026/67\354?\355\206\245i\r`\343?\306\233\350\016\tK\337?\330\023O\362\371\030\304?@.\335\333\353\002\347?\362\264,=\312?\335?\326\276\221\0015$\343?\370\363\333\347`D\321?2\000XkP\300\331?\002Ty\305\021\314\345?\n]7f\334X\341?^\r3\267\332\262\346?(,\335\274\344\227\344?\336\010\263\223\265\203\337?2\312YR\244\016\324?mY\253@\264d\356?D\253\356K\343\345\344?\223\'/2;\326\345?\254\000\214U?\273\302?\004\333\215\006\364|\346?\036\337\326\367X\336\335?\014\244\221\2612\221\342?\000C\204\373\002~\221?\002\375\337\006\374\236\336?\036\342\272\330E>\356?D\204\251q\004\217\306?/\342\025\263\254\022\351?\304h\373\365X\221\352?\342\t\365\315\334\017\347?\274DQc\026\223\352?\313A\337\264\337P\356?\356\353;\202\177\220\353?[\232\262\315{\272\354?\274B}\355\026\224\334?\302C\374\003>\234\342?\321\203]\230F\315\345?\251\212\367\255\221\250\340?MoxoR0\346?\337\354\247\220\320\201\341?\352\027\226\000\216D\355?\320\375Y\306\271i\244?B%\362\272-\251\343?x\367Zj\030\235\353?\2263\352h\331\342\323?\360\025Z\2757P\240?\314kK\304\245\342\351?\'\264H\277 \234\340?Y\330VD\031i\350?\177\364\327\365\361\034\342?\246\006\004\232\215\273\345??cd\377WY\341?T%,\022Y5\337?!\032\265\000\257n\356?0\336\326\370a\020\243?pH?\227dh\323?\270\331\320u-\323\315?\222\316K\222W\\\357?O\023\336Z\\\315\355?\257\315\221]\246\224\353?\363\324X6\207\232\353?\327,\177 
\351\036\340?\200\021\210\363V*v?\274\233\356\r\310\316\343?\245jT\035Y\333\354?\032I\\w\254\230\342?-j\311?x\275\345?\3477qu<\221\356?\354_\3171+\231\306?\016-\371\305\230\355\335?7.J#\330\344\351?w\224q\023\357\035\351?jwc\255\350\217\345?x\035\221\212|\363\332?\353\362\350\213\257j\344?B\013\322\357\025.\350?(R\377)\277?\301?$\257\347\350\000U\345?8q\331\006zs\354?\214B\005BS-\343?O\235\276\302c*\354?\250\005\234\270\023\240\345?P\340N\013\377\210\314?V\223[\364\005$\320?r\377\031)Al\330?\374w\230\027\352\244\336?\306.\231\325\267o\354?`~\213~\301{\232?\316\225\000$\206\353\334?h~G;r\277\303?\254\357K\2148\366\344?\034+\357\213T\300\305?@\007!\3730\264\313?\031\333\314f\365\210\354?\364\272\237\2444\276\320?{\305\304eG\332\353?R#\336?\230\277\357?\010x]\332\350\031\277?d\231`Pa\371\357?\266l\265\325uO\342?\350\2503\356iw\303?\270\244Y\n\004\227\317?\324D\364:\226\255\322?kp\327]h\341\357?\322\274\252X\007#\324?\231+\330_z{\353?\\H\004\3159A\313?\304\301\371Xls\357?\207\244\257}\230\305\354?\312l\314\311\352\244\341?\375\242\260\r\327\257\341?\312\245\3453a\330\354?\230\226\247q\025g\315?\342(;\214C\263\350?\365\314\235t\036\236\353?\226Xx*gw\350?\372\3414tf%\340?\323\r\223\324\005x\355?0:\204\003\337\243\317?P\200\'\260hU\274?\234\t46\210\233\352?\226\261tk\221\206\331?\316\341!O\211\017\354?P\201\010\304\211U\273?\014\005C\\\377\000\340?`\267F\350\236\200\307?\266\337\337\206\336\203\325?\214\014\370J\256\033\305?\024@\\D\375\002\341?\226\343\367g\272q\335?8\225\371\372\037\204\350?o\217Y\343\253\014\353?\320\245\362l\"\306\301?x\030\267\376i\t\307?\014\362\376\375\332@\315?\270z\357\314x{\330?\242?[l:9\341?l\332@\326\361\377\307?\210\375J\312\252#\347?\n\311\3646nq\345?\364\345\251\220f\226\352?\270\312\377Fw3\276?\344\354\233\324n\206\306?PQ\200B\026\342\315?\010\251\266o\315k\316?\204A\31076\212\301?NYV\342\226\273\343?X9r/\227f\272?\003\237\265\246\030\277\344?\236+\023\016\321\314\334?\360d\034%\345\263\325?\247\370*\271\233u\342?\374o\023H\244T\317?\263a_\
342\375\273\340?N\210\255\222\000\364\356?\252Z%wu$\321?\342\2538\327\340\200\351?(\t\244l=b\344?\244\247\371\267\310\230\327?~\271\376y\236\215\335?\000O\242Zo\242\251?\313*me\346\024\340? (\302\205\207|\240?k\257\357\365\261/\345?\212\237\361BF\370\345?\260k\213\373\211\004\337?\314\022d\246\200\030\347?&\205[_\363\250\345?\265\354\243\307o\232\344?\340\007\030?\3344\225?t0\230\032R^\346?\010\202\241\220\345\262\347?\"\'\177\344\020\035\354?\334\251 \334X\214\306?\334\310\336\206Z]\321?S\366\345\340k~\343?\340\002X\010o\257\263?p\216\271\014y\270\333?)\354WJ\257\277\341?\362\211/\303\326U\320? ]v\000\005\237\304?>Q2\006\016\017\342?\200\377{\nw\362\277?\230\251\341\312\227z\317?\366\2403\302\237z\350?{\300\201)/N\343?>\202\202\220\201\252\325?\356\334\376T.\373\345?|\327\330\024\251\342\336?\212\340\310\315=\227\322?\250\001\213\327\362\212\332?\002\215\337\326\216\327\340?\270\371\277p\234\270\357?\246\345\350\n\2441\322?\334;\305/\241\023\340?\036bd\361\030\245\353?\360nH\200z\315\344?\270\244p@0P\277?H\333\331\234\212\274\357?}\353\214\355!\346\357?\314>\330\324\3776\330?\305G\250_/\335\340?n\353\226a\211W\333?\210*\213\036\205\370\343?p\212B}\355\350\346?\2055\277\374\3269\353?\224$\260\007\030\000\321?\257\3240\314\235\336\340?\000{\364\\\031K\301?\310;\246\231hq\350?\314\373\256\374\2039\347?\361/\023\366\332\374\350?\014W\021\214_\220\314?\370]\0067@\352\272?\260\264\210\356\353v\342?\304\037+\210Bd\303?\350\n\223]9k\340?\324\210$O*\241\322?\000\255\246\243\323\311\211?\360\327\036\372\223\267\324?\240\266@W\254t\251?<\2271\242\204\366\322?20\334\020E!\336?\2070|?\301c\353?\345\353m\006\'\321\354?8O\275\340\221\346\272?\000\342\007]\313m\345?`C\373\265\200h\227? 
\231P \342q\321?\022\031l\222\037\307\353?\214\177\332\343\240 \311?\266\037\216\233\370\304\326?\022\324\031B\303z\327?\354\302\343R\357?\243D\335\354\303\245\355?&%\234\232\204\276\341?0\035\r\242m\356\274?\004\256\253\025aQ\352?(\307\021\205q+\316?>\251\024>\250q\350?\340\023\010o\261>\233?\332e\351Q-\001\346?\000\t\235\014\310\241\215?\2369\231\307\311\337\334?\027a\0277\236\235\346?\3256\345\260$\037\356?D\360T$\007\233\311?\330\t\004\006\031\301\261?\022\377\2162\360x\352?\320:\355\017\361\323\311?\324\r\266\210\r\222\344?\242GQ%\244\035\336?\225\237\370\256\025?\343?4\370\017\205\324\265\303?\273\276\037E/\001\355?~v/\033\202@\320?\310=#\363\274\321\260?P\204\214\301cZ\323?\255s\201;C\r\340?\206!*\226U\372\343?Dq\035\221\361\214\337?h\325\331]\240i\355?\366\215\367Qpu\350?\314KF\340\272=\321?\'\300\343m17\342?\000rg\034\313;\235?\\\036\316\n\337\201\345?\340\\\203(\030\230\277?`\"3H\372\260\270?\312\211\363|\312\t\336?\213\344aZJH\357?\204c\351\362@\021\342?j\211~\266\000\002\322?x\323\026\326\253\336\322?\3419\342{W\277\341?A>\337\263\365\250\353?\t\255\261D\335\307\347?\255\337S\272TN\353?`\346\010A\236,\332?nV\'I0\016\353?\377S\322\300\034\322\346?\264\304\371\304\344\316\353?M=\330@\t\"\355?\326p\254\215\312t\350?\203d\346\230:\330\352?\202\212\254*\252\222\353?\265\227\373\374\017\260\350?&cX\014;3\334?\300\356\371\010\"V\330?\210\211%\275\357p\330?\303\351\311\227!\344\351?\354\224\300^\333\221\300? 
\311pE%}\310?Tm\326\262\013\340\301?\254\274G\026C@\326?<\nK\302x\211\331?0W\363\256c\032\254?\204\020\000\013\350\342\336?\365\005o\316\207\247\343?\0203h)\002\265\304?\310\335\032\344\224\311\341?\241\275>\345\233A\355?$\327\215\254\030\314\356?@\021Z* \307\264?x\201\372\342W\347\340?\350X\212k\3324\321?\024MA4:c\300?\264NwQ8\025\316?\333\034\024\337\203\330\352?\276\252\316\360\177e\333?\256:\206\036\310\321\353?\256n\277\203\254\223\340?\220\315a\342\301\254\332?b\001\351\212\226\241\324?\340l\342\034\307\334\331?htf\300.\026\263?\334&\247\021X\315\350?\274)jf\314\344\306?u\031\035\250\262n\342?\225\355\275\007\245\363\341?\245\363\230\tG\271\355?\360\265\324\344u\326\355?\254:u\"\377\201\344?%\026\326\232\024v\355?\224\251X7?C\341?(7\362\214\004\376\273?dG\317\320\315\256\303?\215\311>\013\222\014\356?\n\333\312\352#\255\354?\337O\2728\353@\354?<\030T\373vT\336?\346\307\001\273\022\022\336?\210\224\'\256WA\304?s\271W\300\254\257\347?\350\376\2265\2251\310?\220\304\026\2102\304\254?$\231\366+\304r\312?\265z\230*\373)\353?V\221O\014\263\266\325?\200ew\373\340\360\245?\315\220A9\202\307\353?\240\367\2425\252~\245?\320\204XL\007\336\345?\316Y\r=\266\254\347?\271\006\375L\177\201\347?}5\334\301i\244\347?\3373LI\010[\350?3\242,\330o\326\347?\236\021\306\271\001\005\335?\242k9\351\231U\335?\260\335\223\245\232\235\304?f\"\345\020HM\326?\234\303\035\206}\003\310?\360J\324\364\236\004\247?\222\360\376\001\230 
\321?\324B\t\010b\312\341?\270wf\017\311\004\333?%\2521\364\270`\347?\204$P\374\267/\311?\370t1Z\363\023\263?\002\006\"\006\316\351\350?^2\2042\272\177\322?R\306\220n9\237\330?\245\004\366\234p\301\355?PU\341\355\323\213\312?\004h\322p\020\323\326?\000\014k\0213\310\324?:^\360#\235\013\330?u?\234\210m\005\353?\307\266\361\036\364\001\344?(\233rj!f\316?\252\033\247\262\325\202\345?\344J\222x\250\254\346?\345\235\344\001\352\343\356?\250\024\376\261Z-\340?\036\200r\273c\210\342?\270\'3\231\365x\273?\000\306E1\r\247\352?\300\304\032\265\305\203\337?\001\377L\2551e\357?\330d\231\037+3\306?\\\370\013\371\321\025\313?PRxez\013\260?\351\356{\363\3050\343?b2T\007\217\307\356?Dc\355I\260r\310?\372\026\2013x4\331?\300\300ED\247\356\253?\\\007\364J\022\n\311?u\255n\266\2310\352?\216\017>\366g]\333?\350\034\317\321\321\026\301?\030\016\254&\200t\342?\334\315E5\335\271\334?f\363@FM_\323?\010\353\302O+\261\275?\255\215Z|)\306\346?<\211lrE\206\336?\3676I3\033@\343?\252u\233N\325\214\333?\256V\317\307\016\341\352?\017y\340F\332\024\347?\355\235?\311\326\\\353?P>\213^\217\353\314?0\340\353T\273\225\353?\244\312\212I\300\\\343?\377\272\366b\375-\352?7\241YC\037q\344?\360\014\226\353%\t\251?\226D\274\036x\246\341?5\024{\237@\202\346?P-\341\005\262\327\345?\264\360/EG\033\310?\220\352!\3410\270\274?,\222\214\004\275\353\331?t`\241l\271\252\335?8 
\2726\260l\334?\330\300\271\322,\354\344?Xf\351h\326*\336?\304\022f\253\342\306\350?\020\215\227\035\235\321\336?\351\331\006\235\023(\354?+\237\221\320M\000\343?\250\300\266J\356\254\337?\250\007\235\177\331(\272?\264\016\211\307\027I\312?r\243\233@\255/\321?K\007\320\315J\322\347?\237\235[!\022B\350?\362\343*(\240\376\354?\014Udz\321?\315?*\316\316\213\215\207\340?P\016\003\346tU\333?y!\246\256g\214\347?r\371\224\305!T\341?De\230\207\'e\301?\230a\245\no\377\357?_}.\236WJ\344?f\366h\200\2376\337?\217\337%\013\261f\356?\014\360\212\274&\222\353?\326\017\n\245%y\332?$\346\034\022\376Y\302?\300T\336\363\351C\247?a\313\035\036\025#\355?\211\014\371:\344j\347?\210\257\2254\0042\324?\372\217\234\177x/\346?\334#\261\271-l\330?\200\014#\267k\256\332?8\270\'\234\371\025\320?N\331q\252Mi\351?T\342&Yk\n\306?Q\027\245\n\233\264\353?\254\\\200\333\342<\300?\231\014\300\274\355+\350?\200\303\355U6&v?\222\227K\311\014\314\341?\262\260\246\016Sg\345?\030A\206\320\205Q\303?\340c\312\375\341^\300?\200\217\350\231\"l\346?\032\364\300Zm\362\321?\000\271U\300\340\007s?\204\231\203\277Va\313?\304G\2324\307\362\357?\004d\276\212\256\306\315?\350d:\021\033e\357?N\277\255\2261C\322?T\221Y\360\257\025\314?\372q\311\352\252\336\345?j\"O{9\230\343?\240\225\035a\341\253\253?\035\247\336\232(\364\340?\300$\277|?%\256?o\0332\234\006U\350?\204=\246L\250\332\304?\245\2609\347\342\304\343?\264,_y\016~\320?\224\n\241[\001\227\306?\310\331\246*\2340\273?\230\031\037]\037\314\356?d\314h\253\335\265\326?J\206,\212\216\377\326?X\025\347\272F\020\274?\010\201\313W\007e\356?\345\201\236\226\266U\350?\210\206\224\230\201\273\317?\201\237\302\207\363J\342?%\037\262\013{\362\355?\225\206\323AD\367\347?\014\370\331@\347\220\355?\007\375z8\260\'\345?\2101\214L5\"\325?\034_.\355\343\307\354? 
d\234\275\2547\307?\305\322\312-\016`\345?\3015\314\203\241\016\342?7\350Ke\237\263\346?\340\317\372U{\316\342?[\360\253T\214\006\347?\362b\363\350\215\324\352?\260\331?\t\323\345\342?\000\026v\345\\\263^?W\372\372\236\375\250\352?F}\0272\306I\335?T\335\244\021]\347\317?\240D\227\312v\266\226?\322\343\311\324\331\024\326?w,\261]\227\016\345?\000MP\177\220\r\357?\000\2272\343\020_\356?6\306\233\023\231\230\346?\201\257Q\271r\231\342?\225*y.\203\227\353?~\263X\037Dn\325?\223-\213\354\240>\347?%,uj\373\321\357?\310\271\003\224\223\005\324?O\335\261\003\330\351\352?Pe\263\330\335\320\353?#{\276a\365\236\343?\000\205\217i{jx?\020\310\361J[E\247?\02513\262\'h\355?\016\357\n\245v\234\336?\302\264\2660\302\315\344?\224\344\010\340\310R\340?\230\021\264\232\217\007\354?@T\220\265\311O\337?r>C\243ln\327?z]\007\322;\266\342?y\245?\343\2610\354?\247S\033\344\364=\340?\272]\3374\031\357\330?s\300?/\271\350\3719w\356?\203\200\372\354j.\353?;%\316\236\346W\357?\256$\016\022\374\356\350?T>\331\276\325\026\346?\n\007\367P\370\354\336?\020\354\265B\375\005\322?\303\371\202\231\361\366\356?/\241%\227\333\325\345?0C\236\244\204$\274?\236\246`PV\373\327?C\3605\351\214`\346?\304\007\314\354\310V\342?\356^\265\036\301|\355?(\245\017\025\324\323\262?\272\221\335z\002\333\347?\274-iBq\014\353?\270b\336\235\252\370\344?H\001\243*\320\004\323?Z\3511I\315C\322?4W;\347\336T\337?\222\303\005\007\313 \347?|\234i\240Y\267\356?\\\271\034\331\375\262\342?\264\000\237\362\263\253\347?\264\200h\331>t\333?;\266\213\244\274\315\352?\022\244\001\346\356|\326?n\215\000\260\356!\343?\376\307C\241o\254\326?\006i\336\225w\017\320?\364\226)}\217\300\302?r\210t\271?\023\320?(\270\331\260\231&\340?\330\205\204\345\3217\300?4\207\365w\026\237\343?\340\277\2322:W\325?>\304!e\264\224\322?\224\223\367\267\026\316\327?\004|\244\303\251\030\316?*N|+i\371\335?P\217\r\2312&\262?\034\2639\324\030\243\347? 
\314\251&\033g\310?\004~\221\222\327@\350?\310C\235\266\365\376\266?*t3\000\307f\336?\215\357\027aD\333\353?\216E\007{\235\230\340?8\333\234\264\316C\267?6\355\007\236\203\030\322? ?\370Q\340\233\324?:m\005E\326\014\334?\375\266D-\247:\357?\030\365Y\006\010p\300?\0364\206\330)\025\356?R^J\025\"\237\352?\266!\251\020\"\274\324?\230s\314l\272\300\331?\024\225\312\2668J\333?\223\330\0043\021\242\351?s\263\247\021\351\020\350?>\320\244\377\344\376\322?3\377\232\314h\016\357?\321C\251\241x\212\350?P}\202P\203:\315?F\026{\230\337\345\353?>[\213\013\'8\357?\316\226P\267<\007\352?8F\020\327\213{\317?\250i0\277\224\333\264?\336\261\270\350\375\345\321?X\363\207\320\014\013\350?\253MKW|\310\340?`SQ\276Y\316\305?\343\237^E9\243\340?\024\322\237\035%a\335?0\002\217\276\367\354\254?\270\2749\317\373\236\265?\223P`\314\234\320\353?,u\006\004\026F\327?\350&:\313\237n\305?\211q\360)\236\333\344?\240#\022\303fp\311?w\312b\222\2572\347?\370\276\276\272\t\021\302?\324\276\270E\030\376\343?\362\005\354r=I\350?\222\2741\222\315\253\351?\340\364\014s\274\250\265?\340\362\354k\033\221\302?(:\201\272gl\300?\361\262\335)\324\354\355?\030\364\314\204\336\325\323?)\030\266\207d\212\354?Y*\316\373\375\253\357?\244\232\351\231\327E\307?~\323\373\177S\202\321?\300\226K,\374\334\357?S\246\226\"\342P\341?\376\212\334x\324\207\356?i2\337\322;\334\350?\234lluD\356\325?\200\n\342C\365w\246?\003\251\303_+%\350?\210\370\203\323\252k\330?`e\247,ik\257?\265\342\347\261O_\344?H\007\222\007\263\316\323?\335\257\035LQQ\347?\004z\'8\360\033\324?\020[\307Wl3\317?\364\301b\352bt\315?\314\336;\030hZ\355?\220\234\032\376;\002\327?\3447\301\220\002\244\350?V\316\324\327pE\347?^4\371\300\001~\335?\3109\016\235S\274\331?\260\016\024r\213\264\306?4\210\301\366;h\340?lh\220\2259\273\310?\\\022\345\314\230\373\345?\330\004\234\357\236l\275?n\336\327\231X\267\336?\005{\3113*@\356?\324\007vql\352\304?\252b\237\034\301A\334?#:\227\266\":\345?\256\312$r\251\364\345?\260\306\324&I\235\312?\304\315\014\333K\265\352?\260\240
\331\335\005N\254?\270I-\362X&\331?\275\257\023\303^\036\353?\035a\r\365S\364\353?\250\202\272\314\207c\335?D\376n4+6\324?\350G\0217z\375\264?F(\344^a\277\327?q0\255+z\205\351?\314\371\357\262\037\364\335?\200\007\341\222\227\327\222?\376Z\233l\020\273\352?\324\006\245\\\240\005\317?\367u\014\030_l\342?\257\324\314\341\377\312\341?\r\207c2\305\326\357?\203\203\024\030\271\372\355?Hyk5\333\017\267?erf\221fv\354?_30P\306\225\347?\360\231o\235\t\337\332?\234\337\354\255V\006\347?H\354k\315\375\227\265?\034x\226r\005\346\337?r\356\252\240\263\313\354?\374\321\014\230-\275\327?pp\260\365\024\370\341?\000\'\377\230\232/n?\034R\207\024LD\333?\356c\227\360e\022\357?G\361]\201*e\354?\000\'\031\220\347\373\272?XI\323X\004\001\306?\331\261X\032 \362\355?\007\320\365\3052Q\340?nB\303\256jz\344?\246\034)\225k\242\346?\030\340u\306\037\000\322?0\003\326\316\005\260\315?\375gP\303k\243\351?V\375\267\254\256\307\352?\000u\345K\255\355\270?4\202\216\307\365\036\306?\220\356\321\312\3768\301?\201\310~Y\257V\355?\340}`\275\206\177\247?\334\225\316<_\234\346?Hk\270H\224\030\311?\014BY\302I\302\315?\022\274\314\366Y\250\334?\233=\005V\232i\356?\257\307\027w\261\\\356?\030\032^\\\261<\277?D,\321I\333\305\303?\216\016\235\354C\306\331?\256lfz%\205\331?a\315\230\274/\033\346?[4\222\231Oy\343?H8\324\017\254\330\311?\260\257\371L\325\265\313?^U\325!\356e\354?\362NK0\035\001\346?\250\377>Xz\200\300?\000\023goE\264\260?\026\357h\265\013q\353?X<\306\376L\370\306?i\331eI)\036\342?\307\262C\304\006_\341?\200n\260\024\263\233\344?\030~\200\021\212\303\326?\002HX\362k\365\350?\322@\334\004/\366\345?U/\314\t\205\344\357?0@<|\207\243\344?P\265\016\\^\324\322?\340l\366\351\323a\226?\316\372SP\272\000\343?N\257\373\304\255\316\325?\352\251-\006Q\200\351?_\235\362\306!\317\340?\331\374\275\035/\210\356?<\202=\034\007p\303?,\036\027\354}\275\322?@3\241\014\221\334\275?\343\240{F\355\256\350?\260\231(W\031\320\321?Wo\'l\353\021\357?\224P\256\267t\324\301?\202\207\222%\036\001\336?\337\371\224\202\356\206\
344?\242\324~M\314\355\325?\264u@\260\326\253\332?P\264\337\305\242\325\242?ZFN\377\205\013\356?\030\221X,\341\005\264?\305\350\360D\312F\342?\177\342P\262\213n\345?\364\002\031\033\200\242\352?\220:\271 g\202\333?#x\027\232\367\242\340?\276\233\352\037\343\302\356? KV\305>w\336?\250C\000\301do\271?\255\202\306[\302\261\340?%\374[dc*\356?\324\327\312\'\037\301\355?p\231\367\324A\314\322?\316\216\271%I\027\351?V$\276\374pW\337?(\275\342\364\230\365\316?3N\335L\016\347\352?z\342\365\363\353\323\336?D\352\032`\016\000\314?\350F\331\276\266|\306?\361\205Y\217\333\331\345?\200ypH*\344\253?\334\031\276/\243\304\312?\223\277\355\212\002\006\350?\000\250\263\337\246A\335?R\n~\316m#\322?PJ\250\227\0007\324?\030\006@\347\246h\320?\022z7\206\314\343\337?\r\374M\327,\224\355?e*\236\341\311\217\354?\270W\366\363\'J\271?&\340\2523\206\245\354?\344~\211\313\014x\306?}{h\313\030\271\350?&\363\344\226\3604\350?\232L7\377G\353\356??\324S\304\035\205\345?! z\333\202\345\347?Dx\332\016W\321\325?\234\264V\004\216N\304?X^l\253^\207\277?p\234\311\305.w\241?\202\312qw\334\237\345?\244\234 \337\n\257\336?\376\311\222\356!d\337?\234Q\036E\221\244\306?,j\373\203\247\261\312?\300it\2006\333\256?0\251\'W\313\300\302?\200\027\025\033\200\370\214?<]{|]\267\325?\371z\227\236\246\017\356?\360e\310\351\to\272?=\325c\347\261\360\340?\214\3258\244\337Q\324?\0334\244/\010\203\354?\301\211W\320\275\010\353?\240\301\010\035Q^\233?\320g\265\360G\022\320?\3140C|\200n\326?\274147\037L\311?\323%\227\"8\240\354?0\267\276\365\375\324\354?\205\367\242m\001\320\350?\212\346W}\000+\324?>\373\021\240\316\231\352?c1\276dQU\351?\300\213k\240B=\337?x\212\222)\234\314\357?\307\202bA\344\203\346?\252d\222 
\303\243\343?@b\241\010\326\033\316?\303=#\233W\350\344?#\222\314\275\037\t\352?\374\211F\311&x\314?\006\363\205w\262\216\327?\001^\247\232=\016\354?\315\365OB_(\347?\311\177)\202\377\366\346?.\236\311D\010*\353?\304\026\223\304\007\032\326?\210\nI\361w\236\273?\330\307\264\347\330\217\344?R5\360?\032\203\331?r\307\202^\347\026\333?\274#\'\222\243\036\343?\\\266lG\262\271\333?\261Y\273\017\346q\345?\340\264\315J\td\326?\357\257\357\313\246\246\357?\340\010\003\013\212\246\262?\302\261;xR{\342?\260\311\213\323x\214\350?lM\002\\{\007\326?d\034\360\354\312\353\336?#{\267&\274]\343?\213\247\207/r\026\353?\330W^\311\2365\346?R\271Y\016\267\272\354?\224\023H-\234\013\350?P\305[\357kY\255?\255\216j\177\007g\357?\366\025\006\350\230\300\324?\206\021\373Vf2\355?\2670\254\203H\264\353?\024\351\201\202\301\026\333?\200#a\224D\302\200?\200\243\267\276\033\204\271?\370\234\300P\312k\313?\375\207\276M\212\231\341?;T\300)@\206\353?>]\303\247\260o\331?$\361\263W\001\013\323?\"]7\013l\246\350?\350\367\354Z\340@\302?\017\367\306\007]\352\344?\334\316\343\376.:\336?\260\034>iYa\320?U\236\n\017\215\212\347?X0w\246K{\272?8I:\235k\323\345?\245\007\373\226\205w\341?\\iT\361|U\313?.\366&\243\013\353\351?D\r7\236\nQ\334?\313\205,m\342\236\342?iP\361\253\212=\344?^\030\311\022\330\376\340?\030^X\362\244\233\337?\036\256\364en\224\321?\270s\254\t\010\025\324?\005\245\rPbq\356?\347\325)\010\227\260\350?(\235]9\353\344\356?\360o\217QX\361\277?xUz\370Y\303\264?\300\220I\0036:\201?<6ch\277M\311?\'\267\241*\341\325\356?\"1\326\221\201\351\335?\014\277\374\233|\237\316?gH7\251\022@\345?v\007\362\267u\200\353?\236\266\322+\304\225\332?\254\031\267\222w\305\314?\367Y\353<>)\354?\266\002\204hd\013\326?\202;\006t\302D\341?\276o\177\003Ms\352?\020\205%\3668T\305?\270\276\375\351T+\301?\246\343\353\204]\022\350?\n\332\370\200\261\020\337?-\344\263\372\350v\340?\320n\326\352\356p\321?\266\024k\204w\316\352?\337\333\324\013\016\225\344?4\310\342\332OF\321?L\301\035P\\\361\310?0\311P_\360i\256?l\365\205\006
\030\303\304?\370\225LL\213\373\272?`+\027^\004A\240?\330\"\265\025\316\254\344?\240\346q\340kp\345?\021s\314;\260\004\350?\344\343/\026:P\346?\234\204\206\310J,\324?+\2400G+\370\354?\341\013\022\235\346X\342?O\201\036\267\373~\355?ba\377|\010\372\352?\344\235\033y\214\246\342?\372\354Y\037\340\005\342?\272\265\313\261\020q\341?x^\235[\266\023\270?\225\356(f\037V\352?\271\370\3516\374J\347?\343\211\267\234\371\323\350?08\260A5Y\303?@3\217\020\355\250\310?5\235\004\204\202U\340?\000\256+\305\277\256\300?\340\234\266\204S\302\270?\321\340\211\251+S\356?)?W]\355\354\342?\312\212\346;\245\355\320?r>\276\201)\366\356?R.!\242 \002\343?o\033\350\261)\346\353?\251\376\201\230I,\344?\212.\000\343\231\371\344?v2\230\244\275\247\322?\351\205V\307Th\347?\306\004Cg\315G\322?\355\227k1\323S\347?H2\246\366[<\356?`g\023\263\233\355\233?X\375\317\"\343\226\353?\020\256\222(\3356\262?\314\017\333\333\243G\331?\030\263\007\355\023\263\340?\221\210;`%\312\355?\303\301\214\177\376\313\355?>?N\360\020\264\333?\010K\006\340\272\340\325?\n\016\205\t\374\235\330?\340\334\335\242L\304\303?\000\327;\346\234C\231?\230\243\365\030\233\000\327?D\027\214\224}\010\307?8\321\330\315\020\352\317?\013X\313\231\240\027\344?dq[\276\267\017\332?\260c\025dzU\315?\200X\336ml\302\261?\373\014\363i\023E\346?\341\'\276\200\214\022\346?\250\210\360\2130-\312?\300\t\277u\340\340\231?\312\034\256\337p\314\357?v\356\375\002\360.\325?\330Z\231\344^\016\337?\270\260I\274\244\321\301?\3401\025i\177\266\305?\214\324\234\301,\035\315?\233\326\024A\251]\344?@[\212\213\022\255\233?\332O\311\000\374\263\353?\'\204(\037+e\342?\350\312S-\033x\302?P\365|T\240\250\311?pF&/\332}\264?\242\242\232\200\327\254\330?\036\243IL\345\357\320?G.ux\304\277\351?9\0169\270\030\252\343?X\302\035\010\261\033\276?\366\344\241\022\333\346\354?\210\263\226<&\372\272?\317/\035\001\001\215\353?\27692w\205F\343?\262#w\023\262\316\347?YN\035\014\323R\355?\306G\240j\225\005\356?p\351{\034\301\375\242?\2209\267\301s\"\257? 
\213s\301\222n\221?p\274\314\330\275k\353?<\311\231(\276)\342?8\350\375jv\261\300?\265\327c2\221\255\350?\210\374\315\014\226m\301?\334\331s\246\004\364\343?\024\230\300\363\236\236\350?<\025$i\245m\354?\356\335\333S\275\244\346?5]\002j\310\003\356?\364\'\013[\303\306\346?\246\341\021\225\367\212\322?\330\315\265k \262\272?\314\356\276\365\251\275\325?\224kZ{\246b\355?\212/tV\333\n\330?\204\202\212\\4\237\307?\202\301\202u\234\262\332?\020p\341\251!\217\336?\212\010\0368\354\376\336?\365\270\344\331C$\355?\256\250j\263+R\324?bM\030\215\206y\357?`K\"I\314o\352?\342\340\2719\332\025\355?0\310=C\t\210\336?\234a\301\270[\263\303?9\204\230\222\341)\357?\030\224\360\321\375\207\324?(\365]\201p\257\312?pB\332\312\343d\247?\356T\034z\272\345\321?\234\304\217\303gH\332?\360HI-`\362\240?\030S;E\013Q\326?4\355\333&\340s\307?\024\277\247\274\263A\342?@!\205\036%<\357?\336\362\201\235yH\345?\265\252;p3\216\351?\3607(e8\354\256?\000\323T\2770\211e?@\270\222f\027\024\244?\356\034CC\320\316\330?\302\202\223vo\215\355?,mh\326\206J\331?\013|1\300\001W\347?\225!\312\037\365p\346?\333\005\200\024\265q\346?\335T\036;}\204\340?\344\236\307\r\242d\311?\032F\221\366\327\006\337?\016\272\020n\315\354\326?@\333j\013\337\261\275?\212\265\265,\214\n\342?\310\014\001g\324\211\323?\340\321\267N\256\032\312?\004\030\254^d\177\357?\003\333\373_4\326\357?\345I`\r3\351\352?\360\007\024\016~\363\313?\024\331x\312P\262\340?\200\305\034PR\007v?P\340\340\302\320?\335?\370\006\237^\214\256\355?\330\220q\244\304d\326?\220\003\336\022\252G\275?\352d\342\2041B\326?\005_)\217\267\204\353?r\272\270K#\226\324?\275Q\0264\017\022\351?FM\232!\345k\355?\214\202o\205\212f\325?\000mX\211\246\303\261? \361\325\014GD\342?\270\025<\003S\304\346?p\363\236rY\r\300?\"*A\020\t\200\332?6\2254YV\276\327?\340\243\365\257\250\216\331?\252\206\251\236\330\236\341?\310_{+\024\361\344?\363\326\263\014\360\226\343? 
K~\351\374#\231?\322\344\\\3770J\352?\210w\306\243e4\264?\260\236TT.\202\256?\004\200d\270\373H\335?JEN.\025\237\341?i\2243\342\\\200\351?\200\324 \216$|\274?\326n\327a+\372\346?\200\351\262|\023ts?Z\"\030\032\3057\323?`U(* \\\337?\343\373%\215V\336\345?\220,\223\237g\217\263?\n\227j&\261P\356?\330&\314\202Z&\347?N}\3323i\354?\306\2625J\346\311\351?\331H|#p\346\341?\204\377Ld\335\000\352?\027[L~\312\202\354?\010z:\242\237\220\264?\373\371\001!/&\341?0\332y\322W)\342?\030P\344\273p\253\324?0\237_\202]t\323?\364\353(\216\331 \311?/\362\313\343\210\345\355?>\250`\035\306\010\327?\237\007\332\312\365\336\353?\t\3731\240\224\034\351?3\343,m\007*\353?@Lb\314\233b\260?\302F\247\031L\246\330?\\\332\362\217\035\366\340?\222\007&`\377\374\352?p\276\"y\335\307\260?\372nw\001\275\345\343?\023\3145i\223 \343?`\177\r\325m\025\306?\215\265\220\220h \355?\360\023\341k\004\324\257?\234U\002@\366\204\330?\350j\257\033\010\264\347?\254+)1\014}\351?\230m\237\367\225\270\345?X\341\360=4\t\307?\324\374A\237\004\214\334?\020X\354\177\200h\313?\023\202\210\345F\245\347?i\303:?\362\t\352?\373_Q\342<\253\344?\237\371d!\330\007\357?\014\234\257~\177{\302?\226\267f;\373B\321?\342\305\214\256\002\350\344?\200:3N\033\243\341?\002)\255\310\016\214\334?\223\0259\313\325\324\341?H\024\353\373\333\263\351?\340\230\213D\3246\337?\247\036\301\375q\234\354?\254[[>\362\030\312?\216\250\014\275\206\221\351?\340\232K\302\337\330\253?\240\364$\302\222\222\344?\350u\374LO\256\346?\347\032\027\204\275\030\345?\2203\335\272%\016\253?@q\022&\020-\351?\340\231r\226\212\025\235?\"\217(n\327B\355?\312\267\264Z!\241\335?\260\356\361 \371\314\241?\350\243XX\031\230\343?\266\332_i|j\356?\324B*;\353\341\356?\324\375N\224!\202\343?x\346z\315\014\360\344?4v\370\361\3425\316?\016\206\242[!\r\351?9\353It\241\315\347?\034\246\002\224\214\024\322?\350\204g\351\036\353\275?;\271\355\'\325|\357?CY 
\327\031\337\356?\310\235\031\276KM\261?}K\203S\331\252\341?\356\255\216\023s1\356?\216\366\032\r\257&\350?\0323\263\265\353\323\325?\324L\000\177\367>\311?\214]D>sJ\302?\0223\330\265\305\311\344?\266&{\324\252\201\335?%\005\026\263\'\373\343?\312\302\222\235\274S\331?\023KY,\317\261\341?sa+\207)w\353?\254\361\370BL\213\352?E4\346D:_\356?E\027\264\027s\274\345?6\231\254Tf\032\352?8\307\341\014\252\004\336?\203\320\334\334\017\272\345?\220\260\342 \210\371\314?-\331\240\314\r\020\347?\266P\017\364gs\353?\030Zn\260\3363\300?\317r\367\210a\001\340?V3\035?w\\\353?\014O\207\005%\331\342?\300C\341\353\276\304\303?^[\370Z^?\321? *\213|\333\334\261?\300\327\235Y\321E\353?\201!\345\021\036\343\347?\227\020\210\014\n\341\345?\244\220\016\242\375\027\314?\204\312j\215?\217\302?\37018T\033\357\307?\247\0323q@\205\356?\260\022\002;bE\334?\200\276\344\334\246\253\304?\314|N\323\206F\342?\222\237T\3075`\357?dv\307\006\201\255\344?8\305\251\3411\254\327?\303\237\350\277?@\357?\217C\202\2312\256\356?\240?)\336\177\222\232?|)_\373F\315\337?\2175^\303\342r\352?~\246\017a\204\246\334?@\273\313\202@|\355?\020\207\262\3706K\321?\004\2377\211\202\353\346?):\033e4\207\352?\330\230i\030\260\007\344?\302W\242\003(O\323?\030S1\334l\037\335?f\360F\354.p\333?\244\200<\241\350\374\345?\004\316L\273\031\326\320?\360z\354??\224\326?L\207\265\234n\350\312?H\211$\234\341_\314?9\270<\267\374\262\350?p}]AFY\317?n\001U&}\244\327?\276\014\'[U&\353?d\213\201*o\237\312?\000(\321\224\352?\213\2778\314vc\357?\342\027\205;\177/\356?p\217\345bK\343\351?\342\022\225\006\236\037\325?\232p\367\0219D\324?@\3340\007\224\037\336?\360\371\023#A\377\312?\222\227~\363D\262\327?@\255\320\243\314<\230?\244\222\013\032\241\022\346?\330b\361\321\177W\340?\006\305v\206\033\264\355?\301\226&\366\361\020\355?\267\"d\341\303\337\351? 
\312\004\360\rO\264?1]2\n\200)\355?\222:\322\375\020\261\337?H\221\237\261\2737\263?\025)i\324\001!\354?\256\271\253Rk\375\334?\323?\026\260\2261\346?\243\263\226Z\037u\354?0SO\2726\277\252?J\017\214\320\361Q\350?\354K\232\370u\200\332?\230\265\340\264=\330\273?\220\262\273\232\311\374\241?\000\004a J\211\312?\371l\026\2677\273\340?\331\244\n\303\321\000\356?\226\353\257\334\017A\357?Z\367\207&\330\037\320?\350@\t\'h5\336?\2349\377\010O\334\345?\231\000~(\310\214\353?@\211\354\367\346\372\315?R\364P\273\\\014\351?\340\206x\234\321h\352?S\251\355!\307\366\354?\246\217\010\273j\270\342?\364\254\250\271\\X\343?\340t!{0_\340?\214M\366\271\305\301\351?8.\261\r\307\200\321?\226+U\257O>\345?D\202\000W\036d\343?CH\370\010x\271\343?\324V\247\230\225-\310?}^\235\253\024\373\357?\314=`^\355Y\340?\274Za\"\020\001\317?\032p\303\3053\001\330?\"\201*\003E(\330?\"|\326\rhf\322?\242E~o\031O\321?\313\251\311\271\365\030\347?\356\276\307\375\255\275\346?x2\212\271\272\276\314?\034\2672C\307 \321?vT3\372\325{\321?\350m\217S\223b\270?\337\235\036(\2228\355?\300\355\004\327\322\336\333?T\303\030\222\263\"\321?\310\247\241\020^E\276?,6\020\253\247,\301?\331f\0144\363@\341?\222\337\026\0108\225\323?&-\024u:\301\321?$EKdj\273\342?\2604\242\2234\307\304?\200&\030#\216\237\311?p\216g\233/\373\353?\3319\216\014\257\272\347?\211M\276\231\316&\355?\034\230\346}\202\\\341?\001\270\017\241\2441\356?\274/\311g\354t\312?\3344\210\346\315F\336?\226c\331\317\335&\324?\004s\033\232\030\372\346?\274\250\364\005M}\346?\030d\263\246@\236\357?(t\247\372J\326\342?\0345\220\313\212?\343?\236\251\014QFa\346?^\224\272\277\327\375\342?\255\301\345\252Pc\351?\213\365\345\201\237\022\346?$N\232wI\330\350?\240w\323\241\236=\322?LOU\020\337v\343?e\005(\307\301\341\347?\250P\275\037\032&\337?\330\3121\203=\305\343?\211\230\306\353\343\226\342?\300|\314\215,\210\316?\330\027\001\010O\237\313?\340Z\025:\\\262\317?x\2358g\013\232\337?\222\022 
13\355?Va=BH2\327?\262\320\363vN\373\330?*\357z3\206!\337?\210?x\343\376\372\312?\000\242\022\246!\037Y?\206\304\201q#O\334?\321\2041xt\\\354?p\360\0071(\252\323? L\333\222\2363\245?h=\347\274UU\324?\004xz\306\342<\344?\335\037^Z\222\367\353?\272G\260\216n\346\344?\246\202\246a\361y\323?\020;Y\351\212\221\306?(D\227\362\216\335\311?\2403\022\213\223\005\223?6\276\247;v9\352?\314&Y]\357`\305?\360F\245\261\034E\242?x\201\034\226\372\310\260?\252\000\t\2039\314\340?\300\333\225\216\246\225\253?\255\277\343\001\371\371\353?\3605[\211\2542\240?\315t~\017\324\275\351?\177\356\332Gq\304\347?U\001Q\215\226\002\357?zO\201\333\240\376\342?\000\343!&\257\237v?\213HU\221\210\253\342?\232N\'\202\371\334\332?X\306\211=\221\005\343?\220>$\016a\236\252?\273h\217\030{\243\342?L\364\260\024\034\307\323?\364*\200\200\340\213\312?_\261\216E\362\251\355?p\3130\370y \304?S\014\273Xo\346\352?\332\234N\033Ho\352?\327@rV:\211\353?\200Z\027)<\207\241?@\365uY\262\016\222?\000\355+\211c\336\307?\014\361\377-O \305?0\312s(\r\344\331?:M\370V<\006\336?J\272\221\326\242\266\343?\004\200TC^\306\352?\3706j\224\241\010\314?7{\237\244\333\252\347?\232:l\270k}\330?\034\036\031;X\362\303?\374i\334i\317c\300?8C7j\265X\330?\334\260\346\312\327\227\316?\214\032\2228\033\334\346?\320\010O\272\246N\317?`\376\357(\260\227\252?||\333\332\241\006\355?\371&y\257\352b\347?\312o\364a6\243\334?\344\316?\252\224[\335?x\321\216\275\337\213\274?\250\312\251\243\014\345\353?\024\207\235\'a\257\304?\2602\353\0315\202\276?\275\217\324\247\343R\351?\206\231\000xi\271\323?I!\225\327\236\350\341?\000\'A\310`\020\337? 
s\307\377W\322\301?\316\264\016\364+\2276\231\340\330?\001\347?\272\211t\351?\034NT\372\023z\341?2\341\304M\260\016\332?V\033R/D\206\344?\301\233\305\226\347\315\346?\236\303\340\026#\337\356?G\010D\363B\030\347?\\Z\307\"\340\221\322?\024l[\014\212d\305?\330\225`\372u\244\347?\211\010\002\310\\\302\342?\037\272\325j\261\270\352?>n\245\325\221\267\332?\267\224\221)!Y\350?\340\217\007!\016;\305?\016]b3\237\022\331?\264\\\226\"\266\273\301?\242cb\264\2769\333?\350\205\270\327\264N\326?\310\300\033\"GU\307?\270n\235\2737\014\351?\254\017\270.\345B\301?\002\360\036.\320.\327?%\241J\t\023&\353?s]S\324\"3\351?\033\251SS\370\365\343?k\310]\256\245b\355?T\362\211\332X\334\311?\310\314\244\205\350d\275?\356V\276\366\267>\347?s\234\367\354h\261\344?\200\'\3034U\013\227?\225\r\006p\235_\357?\300\340@\312\010U\261?\250\027\264\222\"\350\263?\272\233\nO\246\376\353?\334\325\177\276\231\310\347?Mc\374\023l\001\355?\303\"hdZ\024\346?\244X\355?\037\254\351?\\\254\223\270\374l\307?)\357\357\247\021x\351?w\030\302\022\353\363\355?\211\023\366\273\345\025\346?s\230wM\033\362\346?x\356g\030\"j\343?Z\244\333\370\232\000\320?dO\350\022\037V\304?\264\273\332\"\364?\304?\362\240\321s\364\362\337?f\257\335)3\322\355?\242di\226di\332?4\205%S\236\214\337?C\372j\332 M\344?\010\312\205\364\006\375\337?\353\362\367\206\354x\345?\254X\022~\265-\343?\241\226k\224<\005\346?\031\303)\273\232\005\343?\370y\0047cq\331?\344x\232Y\222\250\350?\201N\321\013\026\223\341?f/wVx\207\355?Xl\265\nI\372\332?\240\006i\037\216J\230?\330\257\235t3:\315?C\021\322\220XD\352?\353F\363\222]\r\356?Px\225\027,e\267?\263\313\270k\326\313\355?\327\310\214\016\240p\347? 
\370c\223\373-\264?k\023\306\253@\033\357?\006v\2277t\347\340?\220\245\002P\010\207\257?0\032\020\325\245/\323?V\274d\006kx\351?_e\274\027\226F\352?g|\362\253\346h\356?s\354\311\224\370\342\351?\024\340v\305\267\232\356?\022D}\204\317\267\335?RT?5\213\335\353?\225\336\326\231X1\343?\314;\242\222\314\236\340?c\036\264\232\212\303\340?L\242\222\356\364\307\347?\325%X\267e9\346?\200\321\035\362\300\253\345?\314\245\021M\334\232\317?L3+qV\364\355?ZA*\327c\247\344?\364?\007\273F\263\346?\"\026\252\317gm\350?\361.\315\236\230\357\347?\260~\302\207G\233\276?k\3264\220\033\356\356?\357\212*\346&(\352?Lo\316\262\322\344\333?\000\'Ol\236\201t?p\236?\227Lt\243? ]\272\331\340\361\305?\242\\\236\001Q\344\334?(6\r\316\372\256\334?\300\334\376\307R\035\236?\266\324\350\004\021`\331?|&5\203\372\365\310?\372J;k\250\257\344?\222.\200\352n\321\330?\214\004>F\334/\323?\362\213\216\223\333>\355?\316\322)\021d\027\351?\262\267k%\340\036\343?Z[\023\246M\003\357?\366\351+\361:\205\343?\024=R@\343\214\302?\314*V\033_\030\335?\307=D\344>\237\343?D\351\275\177\007x\307?=\314S\334\276!\353?\266\265\300>\002a\342?\264.\200\256&\360\355?\336\264\275\350\365\261\337?X\363\256> \016\334?nL\311u\020\003\336?\200E\357\237`s\217?\325\206\256\372,O\347?P\340ri95\264?\230\217\016K\323\374\302?D\215\026|\331\361\356?\022!\306\013\264>\347?\274\354 \246\3129\347?\242f\352\336\274O\346?\277\003\377D\336S\345?0\003\035KZ\353\322?h-L\271H\'\274?\322\005f\224t4\352?\334\266\031\214\366\327\322?\310\"\255\360\260\031\352?\037?\306\371[\223\355?$0\322\370\210\236\302?8@\320\356\237\013\305?\351\234\277\344\034\016\350?\033\3411\022\220x\354?\376q\355>\356a\351?X3\224\345\322\355\351?\200\353 
uhm\324?(\271\354\303\320a\326?t\355\334\227,\301\342?d.\344n\256g\332?\307\017i\3005`\342?d\211\017\243]\016\322?\241S\230\253A)\356?\332l\377\244\n2\324?\202\274\207E\014s\340?\027D\305\024z\021\356?\203`\247\n\322\262\347?\\\313\207\r\363\201\337?\310\224\265\322\002\022\276?0y\306\362\204\356\270?\300H\263\336\314c\341?\360\004\203q\240_\317?\355\207\206`\271\210\344?M\3768I@\003\353?\"w\2160@\241\353?\300v\271\275\364\022\242?(#O\001\031\252\340?\270C\r\325\330\220\331?\357Co\363\207\215\345?\366\266\277\024\210\335\353?.X\033\264i\314\320?T\224\263\177\3634\315?$\272\375\214\235S\340?\024\327\325o\370a\304? C\213m\037\017\230?\034\r\312SA\317\332?\237\030\337\270\342\010\350?\240\362[\274h\021\235?\350=\"\206\003h\276?4\326\300\005\2471\356?\240X\204>p\301\246?\024\326\356\230i3\334?\267\352\245! _\340?*\360?\"\\\n\346?\250\n\356\3266\005\337?\320 \264\324\037\345\334?\277(\205k~\330\352?\235\345\036\234\234\000\353?`\364\322-\250\314\303?\214~\026\314kS\351?\264\215\267\\\013`\350?P\336cU\037@\325?J\177V\254\030\037\352?\200\242K\234\205\236\242?X_\332xg\356\335?\344\337\303\325\313\330\331?\314-88\330\314\312?]0\'W\'1\356?w\266Y\246h\201\341?\370\305\3753\026@\347?\204\302q4\343P\343?\254\303\033\272`\321\324?&f\261c\244\004\321?3\360\037\214\362^\340?g\312\2764\030\350\343?\3518\334\251\363C\341?\376\320\2002\352#\340?\210\211\232\213\331\351\322?\200\241\371}in\220?wN\222\353\234z\355?p\205/^qY\277?\374\245\016\326\357\313\300?\374\332+6gr\322?\334\363/\370\320w\315?\"\272\274\014v\216\334?\010\230\241g\276\313\320?xY\221\326\311J\346? 
\202\017!\227\227\263?\364}Cc\300\276\336?\334l\2327\2617\324?\013>\212\363v]\343?&\r\017j\301\275\344?\354Ev\014\331\337\302?k\225\221\242 \t\347?\230`\371\351\'\316\263?W\346\302\232\325\233\341?\360\235\272x\256\300\260?\224\274\006\303w0\331?4\237\216E\252\003\343?`E\353\230}\177\346?z\034\207\335\270\273\343?yeL\334[{\341?=KwWh_\357?8\007\204#C/\343?d\244\037m`\261\321?p\014\257I\202\005\356?h\341T\\0$\326?\224\316K=\177\303\312?:\014\035\356\000T\357?\264\240\361\323\340\201\324?\0179\312\222O\344\351?\336)\307]\235a\333?\220=\"\350\233\224\253?c\257\236\363C1\350?\022\374;\3201\346\351?\344C\361\263\270s\342?\374\275\333\340\327]\307?\0340\243g1\035\323?\360d\360\267\373\317\327?\321\236\266\374\353\376\343? \372\321\202Hq\267?\275\003(?*\220\356?L\r9\217&\215\330?\313G\220\233\320\316\346?&u_|YK\352?P\004b\331\325\243\342?\242\322\336\r\360\250\332?\010\261\026\222D\031\273?_#\226\205!0\354?\220\211\257\367m\026\312?Ow\225\253w\216\342?\335k\3301XI\354?~\2040O&\227\346?\260M\327\356\027T\251?\370j\nX\310\235\265?\336\317m\370:\217\326?U/\216\007\302\354\344?\334Gq\361\337\270\312?\3142\247\351\314v\354?\025\321\030\270\221\022\353?V\344\3402\321X\347?0\336N4\234z\313?\022ZB\241q \325?\373\035\347\\\244\224\342?\000X5\351\353N\333?<\355\300\205 8\350?\216\014\200\023\230\033\324?dY\307\273\030\010\303?%\231\n\271k\351\346?x\220\327R\313\221\346?\2163f\321\341z\351?\332 \303\306\375\332\333?\374U$\270\357\234\333?\341\234\2669\276i\343?\317\003\243$\323\013\355?\033\217\226\3707\225\341?1;\214\0164\220\355?\370\217\210\352f\376\307?\300\374s\207E.\205?\212\317\0146\334d\344?2\3414\rC-\320?nX\256l\013\310\326?Q7\245\031\256\224\341?\220\0236\300\310Y\343?h\300oD\016\204\267?i\210Ux\224E\352?7\307u\252H~\343?\032\323\033)\210+\333?\230\177jy\305\345\333?\347\255h\206Q\033\344?\342`;\313;\247\321?\022\354:j\243\243\354?\234\027\"\307\0310\325? 
\r\355\034@\332\326?\000#\352y\257\333\264?\000\336\327L\357\253y?\366\030\002\326\253\265\353?\300\030\277\260(\020\243?s\022:\031\025\336\353?\243\336\331(9t\344?\330\014r\216k\270\357?\355G\240\327\371\376\351?\333%\354L\003:\355?\313\357\222\003\265\352\346?\000\353r\203+\301\250?\371\305\253E:\244\353?\260\233\026\306\327h\324?\3504Nm9\310\320?\t*\014\316^\272\357?\232SU\201@\t\352?\356\361\364\273K\343\353?(\213\211Cx\321\266?a@lYh\316\357?\000\010\266\010\273{\335?\372\340&\271\316k\325?\000]\232z\177\353\325?\364\204\033^\246\217\340?\345\211\370\252?\013\346?\350L\360>\270\t\344?\255\355\264\342\327\362\344?e!\003\366\341\222\340?\203\026\307\376\030(\357?v\tD\240\303\327\334?G\026\246\315\035_\342?\351:\363\341H\212\356?\335W\243\033\376\247\357?\357\241\346\250K\347\351?\004\207@~}g\340?8\363W6\352\314\321?,X\2062\335\363\312?\201\262\263.i\251\352?7T\273\241\255\217\357?\034\004\307\254\344\225\301?q\000\3514cP\356?\000\310\215\363\002\n\303?\002\245\326\220\016]\325?\216\374\260\234\225\016\352?\000\024\234&\263\237\342?\337\036\016a?\025\351?\270\036\363\343P\020\327?\351\206\035\322\322O\347?(\007\373\333\3235\350?\365\222C\n\230X\342? 
\245i\310\217q\331?|\371\313\341\335\264\354?\260\322T~\323\217\241?\004\270s\225\310\350\305?\023@\315\244 \237\356?J_TJ\306\221\330?Z+\351\361u\005\322?\270\214\356\340\377\206\305?\324\321\341\321\014\327\325?\220\246\224\313\352\252\240?;^Y\351\356\272\343?\r\317\303\322e\351\354?)^;2\213f\357?\207\230\350;\325\034\355?\206\266\034F[V\353?\372}\264I]\223\325?\261\270\355~\300\247\355?\210\2532\314\302\\\341?\362\320Y3`\256\336?\010\236\220\t\221V\310?\000\351\204?\247\275o?\362Z\250\217\030\364\350?PP\224\216\r\274\252?\275\003\270g\0066\343?\332\315c\317\235\364\342?p\026\322\272@\257\356?\000\006\240\241\321\257Q?\222s#\020\324\373\345?\364a\334\214\253\034\330?v*\231\270|8\343?P\274* \371l\345?2(\025\r\221\245\356?\355\325\343\245]Y\341?0B\037N\345U\350?\246\253\207\274FT\343?\264\033cm\232\215\303?%\245\333\224\300f\354?h;V\013{\222\300?\205h\023\324\004\332\347?\317V\220\3130\226\353?\360\016w\225i\377\310?\000u\371U0\360\201?1\267g\026BQ\352?\215Z\275\231\255]\343?=vAv\223\224\340?\317\361\377\305\250\305\351?\367\035\021\025\264\203\347?\0356\277\3331\232\353?T\327 \034\232\262\312?.h\r\270\034\202\352?\212\'\260/\004\372\321?\352\332\374\324\260\371\355?\251\272\022_\031|\355?\346\366:\222\014\214\334?f\310P\200u+\334? 
\360\013o\252\376\264?\024\212r4\201\261\347?\365\241-X\323(\347?\220N\360n\033V\320?H\244\363{\333\212\260?@\345_6*\270\324?\001r\337!\002\231\340?\320\225E\031\370+\301?\240\231\222J\010O\316?DB\251jJT\306?\336\303\213\314\022\035\350?\246\324\232}b-\341?\300\002F\240<-\334?\270\254\214\227\033\226\311?\323qE}\346\005\353?\034\226,\211k\377\352?\233\261\352\000\004\210\350?2xIR\334\037\352?\376\306\232\237\307j\334?\340[\227\264\273)\240?|~\277\016I\027\303?\240#\t\276\364\020\351?0>\215\225\322k\256?@g\231%\367\336\205?*\322STU\331\336?f%p\362U\246\343?$3}O\223*\312?\344|\271%U\266\350?l!\371\020#`\337?x\315L\242>\036\303?I\216\211\036\363k\340?i10\273|Y\350?\212\"wqC\245\333?p\202\303?\240\337]\345\220\267\223?@b:\267\222k\273?\212@\007\265\216\014\356?\360\277owi\360\347?Y\246\357\006E\220\343?\0032r\203\200\377\356?\240tC\000J4\314?6\323\331\211\022m\346?\234\312\366\317h\312\304?\n2\302\357\236Q\330?\032w\250\365\322Y\352? \003Q,\312\005\307?\020\260\t\002X1\333?\226lz*\242\017\332?`\273\204ni^\254?H\367\"Z\314\231\311?\240\244\344{\003#\226?\2619\251\201c\260\343?\302\255\312g\375\265\336?\364\275\231a\302\346\355?-\362\366C\254\247\355?UV\3620:\345\356?m!\n\177H\223\347?z\363\250\004f\246\332?dPv\344qU\335?\010R4\214\327v\327?HTow|\351\353?\032\357\270YZ\271\347?\317\354G\207C\036\356?(\t\302: 
9\265?\036\025\366\224Y;\335?J\315+\373\341\r\343?\3046\316d\377H\354?\375\252\314\320\237C\356?\321\327\010\014V\213\340?\250z\377\307\024\021\341?\216\255E\227X\362\325?`z\016\224\374\217\324?\331s\260\230\235\356\357?\3667\274M\323\316\325?\222\333\n+\272L\322?2\371:y\316\r\352?b\346\314kT\315\356?\266M\212\323\304\242\330?\322#\010[\357\253\331?^\230|B\017\026\323?\014\343\023\000a\316\305?#j\0029e\303\352?\213\256\204\336\000W\345?\020\\\374\376\357\267\316?\2047Tb\241\237\354?\300FW9\347(\213?\300\036\261+B\343\223?t\324tc\367M\323?8\276\244\366\343\260\276?\260/p\376\315\271\342?\304\001\326\007\320\'\301?Y?\372^\004\210\352?\334\313\014\206\014\001\336?\000\260.P\221\022@?\322zB\237x\202\344?1i\355\307\265\245\340?`N\201\323\303\345\273?i\367a\033\330\223\347?\340\322\2334[\326\354?\367/\247%\221\323\342?~\t\266\3543\305\326?\241C\022 U\355\355? $\027\360\206D\241?@\321P!\327X\253?\201\332(\2246\304\350?p\026\377\006\017\342\255?\241\264\215\037\274C\344?\010\035}\321\265e\357?\"\224%\2044<\344?B\300Sj\001b\325?O(\212\374M\200\340?6?\252(\244\245\346?},\372\246\367\310\347?FB\310M\030\341\353?\370Z\tX\264\337\354?\236<\304\256\330\270\342?\002\265\246=\037\004\325?\260\214\227e\217Q\267?\205\n\250\r]/\340?P\24086\032\231\350?\024r\205\272\0173\315?\000.\303\206\264\253\257?v\203Z\n\366\001\320?\344+\307\351D\314\310?4\362\260\352\227:\353?\354\014\254\333u\347\321?\200^k\300\350\020y?\007\\\316\325\275\005\354?oZ5/\237\213\357?\025\262u\274\305\245\344?0\n=Y\321\274\355?@\270\023\240\325Q\266?\222jv\206B\035\342?\021\225\020\003\007\363\355?\260jj\341#,\304?\214t\035\t\222\366\300?\020\361\332\252\306\247\252?`\035\t\252O\243\221?\311\021\306\252@.\356?\320\235\272\317\304\351\345?zd\361\261\215\\\330?\264cP}\026\257\322?W`\326\302\233\021\353?\343@K#fl\352?\t\345!\357Q\212\356?\234\024\206{\272\n\322?F\262E6H\003\320?\202\2774N\243p\336?\326\020Z\312;J\327?\250h\022\345\224\234\351?\340\003\345\346\256\370\265?\342\230\001\032\231\307\342?*\222o\354y\264\33
1?\376\243\254M\243{\347?(\256F \231)\270?\023x\3366\370H\352?\226\347\350\003;\366\356?@\342\372\347\372\333\276?\005\371\364\355d\261\344?\333\361\032M\r\033\353?\022\307\327|\207\233\344?\031\201\023\341\032.\344?\037qx\232\017#\347?$\323\310\266\240\234\304?Q\322\177\'\022\337\357?U\234\024A\035;\344?\000\254\251\246\003~\213?\220\237\362T``\334?k\341\024\035\377\376\352?07+\212\010\020\355?\241c\013\333\214\362\341?\000 M6\250\244k?\333\351\376\nX\252\354?\216\004z\342\370.\326?\177\037v\231\355\324\340?XW\236:\013\332\307?0\036\254a\033F\305?P\205\315\212z,\353?\240\264\355\035\224\261\251?\2102\200\342l9\350?\330!\306\354\025K\275?\360A\242\230|\342\272?\360\235\275\267\236\n\250?a\337\230A\246\002\357?,\2436\301\201\211\335?)Jt8E\000\351?\240\005#&\033\001\323?XCq\266\325`\316?(\356\3261/\241\350?\210T\241\227\372\034\306?/0_\305p\322\345?\246\352\305J\243\272\320?\024~\177\211\254>\303?O\0044\356\203\315\346?\356\337\030u#}\352?\324,\203\027\221\234\341?|7\177\357\235\326\323?\026\331\254\355\350\233\355?\341F\215<\347\304\357?\200\316*f\367S\350?b-C\256\364\215\342?\312H\024_\207K\334?F\3505\315\331\241\342?R8}\214_^\323?\036\320\007M-}\337?KB\200k\237c\346?\033sS@\0339\343?\274\353\336\n\310\325\302?\024\010(\217\276\246\353?\270\266\034\232\337\301\352?\0149B^\271z\316?\020k\020\312x\266\313?\322\336\310D\204s\323?b\211d$\374\346\326?z\312_\342I\265\320?\300Hl\342L\253\356?\314>e\t\233\322\311?>\365@q\376\024\341?\210\010\211<\272E\266?\014\312\342\000\2676\325?\243\334\021\225\005A\346?\240f\345D\216\315\320?\304\367\332x\260\341\310?\246\203\0313\004\210\336?0\tq\313[V\357?7\303\345\3454k\344?\275\246e\233S1\353?\241#\340\301{\230\342?\333[\372\\Og\356?\265+`\343\226\231\356?\221\243\025\016\177\032\351?,LJ~\200\371\306?\330X/\227\377.\323?\234):\254{K\341?\244\273\374\017\324Z\357?\033=\324\201Ep\354?\250\332\2745\306\007\352?\002\037\365\205Y\226\341?\335}\332\376hq\354?\222\344\273\021\263q\340?\006\257\207\317)\266\340?b\327\303\221\234\236\356?\21
6\232\314\331\017\365\342?\340g\341\301Y\254\327?\266\261(\255\354\307\334?P\313\237\217[.\322?\240% \244\021\016\220?\266\021\023\372\372V\340?\360\317\327\376\024\031\305?\003O\356\276\270\367\354?\030\235\226\346\260b\313?(\212S\306\300_\311?\256\205\276\325(\353\346?\303)\342\026?\177\342?\210M<\333:\007\325?D\274\367\302\001!\301?%%E\320\'\245\352?k?R\347G\346\356?\354\311\203G\245\241\325?~\004\311\300t\351\347?\315\313\374\366\331\325\341?\257\0213\226.\274\346?8K\341nc\227\353?q\366\206\034\377\321\355?\205\000\340\037\310\330\351?\370bs\207hU\267? \313\217\177\263\376\310?pM2\374\247C\305?\351\233\206\020\306\366\350?\000\331\3309\255\361s?P_<\004\344Z\241?H\226Q\200\364\\\321?\251$\363,3)\356?\236\224/\000\312\347\336?\'\321\213g\227\226\340?\321\013\351<\262\n\350?qI\020o\321\247\345?\302upEG\346\333?\016\023t$D\264\352?\260p\004O?0\313?\002\311\323N\212\264\325?\020\352\324GYR\271?\261\206\237\253;\355\340?\3601\264\2771\350\333?B6\033\225%\033\333?4r\272:\335X\306?|:\325\000F\310\340?\371\252]7-\027\343?O\376\2105\007\256\355?@F\210\347e\246\262?\247\341Z\372\264j\343?`%\317\271_\241\255?1\346IT_B\357?\234\014\302\313V\020\342?\023X#\341\324\365\347?\230\374\303\372\276p\356?\330\232d\230V\312\311?\214\037*Rh\177\312?\310XdM\332\233\325?\250\026\206.\216\024\322?\345q\325\304%\032\355?@n<%\226v\217?\314\344\203Z%n\347?\2004\345\017$\032\247?\030\367.\337\277\262\307?\177X93V6\350?\2474\263M\212\314\345?\372G\360QC0\320?\352\201J\230w\351\350?\263D\355RC\026\341?\002H\245-\010\355\340?\030\237\363\037\220\371\275?\240\250|\241\227u\230?.E\360\252V\361\340?\266\003\300\262\230\014\320?E\215\247\367M\350\357?&+\n\351f\363\355?,\235\r\002\032\316\314?w\020\375J2\276\355?w\322\2366\234\376\344?<\233\013\017\254\021\300?\214\356Q\376b\013\335?\'~#\007\315\326\354?\355\314\273+\367\370\344?4\361O\027\234\233\353?\265\267\260\034\247\361\352?\370\304\016\313c7\261?\226\333\247\313y\353\330?\205\022\203\264\030\222\354?\177\306\2626<\036\354?xy\021\235\376\345\2
74?\021&\226\315\002\t\353?\320\252\000\305\035\252\306?\376t\331\252\236q\357?PN\236\211\330\322\276?>\304\251~\004!\320?\244\2620[Pk\324?:a\014,\220\361\341?\316_\222\024:\377\334?\026\323<\243Z\273\351?\232\315\010\325!I\341?o\023,\035\033U\347?BJ\1771\357\227\333?\242\304\013q\377\210\320?vu^E\216\300\320?#\232e\\\005\003\352?\200=$\276\016\246z?\374\266\254\005\302\013\336?\317aA\003\332\336\343?\300C\3573\022z\336?\274\232@\270T\t\350?\tKT\210=I\354?P\241m_\036\363\336?D\365\375Jh\324\352?\035\301!\356\350\265\346?\273\311\030\2340 \345?\242\244.%\352\201\342?}\222R\377\263}\356?\267\243G\252\205)\341?\317~@\351\330-\354?\2737\250s\327g\353?\200q\336\221\351\\u?P<\230\234\253\342\243?\300\357\n]\200\264\330?JQ\265H\275\'\357?O\t\034\241\275E\356?\025\030\354\265\341\243\343?\214\350\263\033\356n\332?\313;\014\263\370\025\344?hg\307V]f\346?\023\034+\260D\334\355?\264wL\233\227\370\303?\230\371\237\246p<\343?\t\017\241\375\210F\341?k\371\213G\341\346\356?\264\250\352\005i\232\327?\212\\\rj\211I\342?\265\203-\221{\370\346?\320\034^\034\245\013\304?8{>\247z\204\303?\334 \3122\260\377\316?\240\344\273\364\000\222\345?@\270\203\360+\030\324?\270\215]\027\203V\330?\200u\346b\246\204\323?+@\242]\031\351\344?\206.S\307\210.\347?\360;\300\373\241\035\345?\310k\264\001\275\001\277?PW\352\304m\227\311?\356\363\275\211\320\000\342?\353\2754\337\177\363\356?\000\320\377+\331\255s?\\\251\014\236\252\005\335?:\202M\215\372\'\343?\330\035\274\213\360c\353?\320\347e1(\273\255?\240\210\3051\324\313\317? 
uL\237\357P\270?\350\330.\302\004#\315?\251\261\006<\220\353\342?\002\221`\200\373P\324?\354\217SC\375\246\333?JDbO\352q\320?\350=\021\254\237W\274?\243e+\362$\033\350?L\006\005\232\367<\333?\352\337]\tzC\333?l\016\204h\347d\304?\"\370\207jX?\343?\264u?{U\270\321?\314\265\010\3146@\305?)\315\320\231\2214\347?\255^\256kFE\343?\273o\021\247;u\344?\310ix\370\033\337\270?\300D\345\224\321Z\335?P\357\020s\273<\266?&k\2007~\233\321?M\316\223#\232J\343?X\273/\302\221#\344?\254e\002\177\215h\305?\277\316I{\030\236\352?\312~~\317\363t\353?\020-\351@^\316\327?\345<\302\373\207\005\343?\014\334(\240\365m\311?\030\237\204\025\236\273\331?\025\262\303\221\211\260\357?\214\224\240\300f\326\323?D\251\227S\203\317\314?,\337}\211I\014\354?\337>\242\212\306%\345?>\013P\331\316\272\321?*S\003\311\250S\327?O\375\364\210{t\347?(\365;7\345\341\267?0\326\233\361\247u\240?\302A\350,\006\320\336?-\313j\004\005N\346?L\267\036;\235\204\334?\310\2237\260\202\232\306?\224mC\335Z\323\305?3\010\322\363o8\347?\n\232\275\230%\311\354?!\002\352\327\353\330\346?\355\305\375\023! 
\357?\370r=)\326\246\303?\032\341b\001\351J\347?\342HRK\377\330\341?\366(\001\"\315\272\345?\270|=\341\211\254\267?\262%\240^KZ\345?\320\037\231\216\347\363\242?\200?\003\340T\361\321?w \3756\376\314\347?t\342\327\337\231\251\310?\020;]f\343\264\336?f\036~\333\344\237\322?\246\277$\234Hm\324?\370#Kh\370<\324?\332\030\003\272\225\351\347?\351J\350A\256\265\342?\210\242\370o\036\335\337?\276\332\254WT\247\331?\014\356.\231\r8\304?~\210\252\260\246\353\356?\320\377\330\373\021R\256?@w*=\244;\220?\312c\024\262sg\346?aN_\310~n\353?\000d\266.\202\274z?\004\203R=K)\346?z*\300\033\004\352\331?X\036\200\265\374-\313?\300\027\001\264\370\237\343?*\205.\272M\300\323?\2178\270w\314\350\350?\034fE ?\240\341?]U\252\212\243\014\340?\230\345}\022\020\276\336?(iw=\013\317\262?[\013\34743\007\347?\024.\267\014m.\312?\244\346\215g.\032\354?\004\3653\272\342\310\315?\322`\263\323\212\006\346?h\216\216\261l\004\317?\235\332\022O\207b\355?\320v/H\tJ\341?2\230\342\356:\274\357?([\000\\\224\\\272?0zU\014$\364\273?\"\327\007\013$\245\320?\207)\260\215\006\352\341?\316L\356Y\034|\322?\340Ty~c\216\250?}\320\343\207\r\311\357?\324\211J\307\345\363\323?HT\270q\375\021\344?\271\177\217iG7\350?\003l\326@\324\333\354?il9\237\023\370\351?\346PQ\305\203\001\352?\236xL+\277\370\341?\310@\356\210\017\202\332?0T\254\303\350\201\265?\234K\375\235_\317\340?\017\234\256]\177g\355?L\034\351e\345\003\354?\364\034\254\327\221\365\344?\247c\3218P)\356?6\2138\345M,\333?\020\322\215\003r\t\343?\004\220O)-\\\355?V\035*@N\277\322?N\254\360\033\357\351\327?\252#[q\003\023\356?\'\356\372~\331u\342?\227u\330$a\027\343?x\240\002\324;C\301?\310\001L\230i\276\332?\010\226\213W\312\004\310?aF\r\205\230\301\344?\226\223i\tZ1\335?\336vt\335J#\352?\364{G\026\223v\344?\357!\254\276\3666\357?_b\350\372\210\021\344?\356~\252z\253\276\323?)\245\201\022\"K\356?\317\342?\250\365\324\372\350\240\306?;\\z\350\364\250\340?\235\233\036\030\037\341\344?\007\022\033\202\306\265\345?\0002m\222)X\344?F\016\213\303.\335\345?.&?J\200\215\
325?\244\255w\035!`\343?T\035\210\262$\212\320?\300\277\364\251S3\260?\331E`>\324\323\350?\026MqA\210\377\335?\025m\3147-\340\346?\204p\241Mzy\355?\347\363.\214\2019\357?\006\001\200\222\201]\324?\347Y\250\266\201J\356?\000\304\374*\217.Q?\006\304E\004%\213\343?\312\r\233#\365&\327?\374WRX\304\372\310?\204n\354_\036\215\305?8\216\266\326t\027\303?G\3252\201\325\221\342?K\032\022&\356\276\351?T\353l\345\334\245\317?Vxt#\224\325\342?\365Y\267\020B\327\350?\240qs\n~n\314?2\325k\341\205[\347?\000\261\356}\362q\256?\370g\362\324\210\310\260?\027u\215\3702H\347?~|\334\231\244\311\321?\2016\255\025\271\245\354?\344\2363\255\233\232\315?\022il\217\306\377\347?T\006\toD\212\317?\201\324\354\362\202\324\340?\202\323\237GEc\350?\002P8\307\177~\352?\2276\266\255\376\247\343?B\204;a\021\301\354?\274\r\033\000\324*\347?\254\336\335\001 {\312?L\375\016\323\250s\312?\340G\365\352\371\261\304?\0300_\323\221\\\313?Lh\304pO\366\322?\300\"CI\240\353\276?\265\256\250\034\253\254\351?\253K?9\221\322\350?@\273\276\276\304\340\350?\032\340@\267\214\006\351?\030\274\244\\\257\372\337?\242\246\272a\363\316\351?Y\243\377>\333\346\350?\034?\236\344k;\307?^9\373\300uP\350?\032C\247\340e\220\334?T\266\250\276\002M\331?\234\256Kr\350\017\340?\230\000\374G\371\357\343?p\357s\013\260\n\355?\037{\035\205_\272\346?\177\267[05\202\355?\200\001!H\204\020\355?\240s&\254\272\301\227?6;i\321\354\233\354?\260\326\341\003H\023\262? 
\237o\2044\250\325?\010]c\242Y\010\307?k\323\376\334K\302\353?,\234\201\337\032]\323?H\257\304\265t\332\340?*\315\324:\345P\331?\233W\271iE\016\350?\330\304\271`\t\275\335?\366-\332nXf\345?p/p\240\037\003\304?y\342\314O_:\340?1\367\252S\303Z\343?\245^C\357\255\250\344?\304\344C\373\336\343\347?V\2221\033\235c\337?\300\005\252J\262\350\343?\230F\305\".s\272?<\312ui\201\'\330?p\341\027\014\210\243\342?\023\330i\002HB\351?H\3009\320M\325\302?\000\302S\252^F\226?\244\205&\301|E\324?\034\t\343\333\211=\337?\032\375Xn\374\017\336?\300\370c\200f,\352?\300\213\354>\336\220\262?T\313\331\214\304\320\353?\274\352 F\021k\347?}\353\r~\244\335\357?\300\367\252\371\014\362\277?@\2343H\303\270\300?\317\377jJ\231\022\344?.\023:\300qP\357?\310\326\214k\327\351\271?\206\223\302\330\350\024\351?\206\033c\324U\254\347?:\323\351~A~\331?\000`\205\352>\372\300?b\275\347L\337s\342?\264Z\3270\213\002\337?\372\342\037C\236\335\336?\360\303k]d2\332?\220\254\240\252M(\322?z\006\324bb\362\322?\265\277\277\213\213\000\345?\327`\206UXB\351?\337\331\276\341.\230\347?\\\374\270Q\016/\320?<\352\300x\014&\305?\336\t\026\231\377\217\325?h\267T\212\265\345\313?\004\213#\342-\337\317?\3521x\224U\353\344?\334/\345\262\312\205\304?\300p\262\242}<\230?\270\366i5\202\260\340?$lV\306\344\265\342?F\337\252\251V\255\333?\032\251\375\3625\375\322?\223\022/\233%\321\343?\365\265\023\336\264~\340?\202\224\030\236\205s\320?\330J\370\373\022h\335??\276\344b\273\260\343?a\3550g\000\321\357?0\250V15t\244? 
\355\201\007DO\332?\177\007F\261\036\213\344?\210\001x\226|\032\346?\016\273\027\234\253\025\344?\n\016f\341\235\025\344?P}\343\351\t\t\331?\362\276\376\341\302\235\327?\350\ni\257R\357\304?\024\276\233.\035n\310?\314\373\261|\014\351\344?\240p\266\224\001 \271?\357\032\305\"r\210\342?\017\352,\367M\317\352?\255Qv\227\277\265\350?\316~\336\020\375\242\322?\244\2456\177u%\330?\260\344;\031,(\305?h\270\234\226\322-\320?\262\000yJ\332\227\323?\000\243\"\024\222\343\344?6F\202\037\250\373\321?J\201\355\244\266\031\356?\004\274\312A\350\322\322?\314z\245S\020\262\317?\350\3449\2070N\330?\260\343y\034\'\256\357?Cw\254\242+\257\343?\212\371\245\334\216\027\352?\206\335\001\232\353T\325?\342\n\004\261\315n\343?\2765u\207\313\245\326?\342\213YV\306`\347?\334\032=\014f\310\335?\312\316\005@\241\203\323?\016\322\216\372\264\302\355?^-\230)E\303\345?\316\017\344J\025\201\342?\326\322\303\342\"3\337?\213\3705\0076&\346?@\243\232\307\344\200\261?8\313\032\243wQ\346? ]K\231\226\205\235?\304\330q\376H\330\344?Au/I\211\242\354?$\257\323\350_&\312?\325\227\273@\261*\341?\3203}\375\236\241\266?q\264J\341\373z\354?|\234\367Y\250!\336?t\300\016Z\342_\326?\\\301\266\362\237\261\346?\300Dv;\232\334\273?\020\372\262\026\005C\322?C\343\255\004B\230\354?\020\353\273Y\223\370\352?l\274&\013/\235\330?\373\336\303ZP\325\352?%6\364\256\202\217\354?\300\225\231\204\003\215\231?\220\227\003\223\002\372\245?\366,x\210\314J\325?\020\225-{>\035\335?@\3258\003\023\313\324?\343\351\032\210\263`\354?q\347\306j\233\234\352?z\323\222x\276\271\323?\024A6\212\374\020\336?\300\273K\301U#\256?x?ji\327\235\314?\024\317\030fc\230\311?Z\234\363f\315m\333?\000Y\240\304\tb\257?\010\267\035\371\211\311\331?\251\374\306Q\036o\346?\367J9\033y\327\350?\013#\275\203%e\341?!L\3722\177t\357?I\017+\022\277\t\350?\360@\225\373\027\254\311?\3109Q\314\026\r\345? 
\301\362\324\336D\355?\230\002\267e6\307\325?\\\342x\235\303\224\326?\274\316d\373O\005\340?l\032=\354#\221\347?2c\352\247\037\005\337?\230\027\031\317x\206\325?\367\241\241g\246\203\356?09d\177/Y\343?\255M*o\315X\345?658\"\321\231\332?\364&\014\223\323\316\327?\276K\240\022\326z\331?,\211\211\2152\357\333?\324\261\313\225\010\260\351?m\005\235 \265\243\350?\000,HX\217W\316?$\320\356Gd\025\326?\260\337\024\242\017\326\353?\240\035L\003\322<\250?D\026\322\201\2136\351?^\320:\255\365\307\324?\325\307P\206D4\353?\210\237H\250\357:\270?\200l\231\006\353\364\220?~\337\201f\253~\340?\340\247\346Y\261\350\341?5\306\266\233\250X\353?\017\337?\227\265\254\353?\016z.\021YM\323?\364z\205S\273y\322?\230\233J\re\240\271?\262\t\327\0212\221\320?\357\267\210\246E\"\357?\305q\320\245\270\322\354?>\376\303\205\305\325\354?\342n\251\226\313\002\343?\320o\315\367\204\225\327?\037\205\"\002dP\352?\032\010Z>\370r\345?\213\005\213\336\020\211\353?\332\310^\2768\233\336?\272*G\206\377\'\335?}[\003\2552\310\352?g\363\007b\252\356\353?H\324\227\301RJ\352?=&\327\370}\374\355?L\363\016^\232:\331?|\335\0074\237\006\353?\300k\313\337\370\031\322?zs\263\333\007\256\326?\363\027\370\353\014+\345?\374}n\255)g\330?\226\227\0319\223\317\331?\370\333\221\300\366\350\317?\276~_%\371Y\347?p\177\274\333\345z\356?\3249\036\356\244\255\310?\214Z\010\374HD\335?\270\230\315\004Q\257\332?\265\247d\317\327h\345?\352TV\023J\354\322?L\225Q\357Z\267\347?A\006\255GJh\352?Q\367\266\344\357\273\347?;Q<\227\250}\351?R\025P\031\010r\356?`P\023k\316\243\226?1\030H\231\226\002\354?\257\204\374L\230\n\351?\2304`\364B\242\270?\032\341\267\370\207\366\322?\374\345\306@]I\331?\341~\311\262\323\214\340?\000,\262x[\275P?\240\216x6\237+\251?\225\252g`\311\336\356?)\237\263\322\257\205\341?WK\356\013\014\036\353?\356M\277U\374\262\324?\203\2202\264Fk\350?*\002\325\263\337\331\345?\374\221\250E\314x\333?\344Cf\202\212\373\314?\022D\200W\220\305\353?\340|\330FIW\307?\3008\237\005\206\376\324?\300\001#\315\200\031\224?\237L\037\3
36L\002\355?|\343}Q\253X\314?\200W\t\354\202\362\231?\220\037\026\370\r\021\323?X\013\005\313\027\241\350?\372\356\371\351\301\253\356?\250\301YF\374\256\274?S\365\366\305(\236\350?J\354?\230h\302\337? \232\231D\030h\257?\005-O\337\013\205\353?h\276,\026~\025\300?\204\206\322|\255\344\313??\213\031:07\344?@+A\3634>\221?\350\277(\346x\306\331?\214m\230\363\277\214\354?\000\303\232|\303\"\357?q\221\303C\323\313\345?\000\333\250\222\024Z\273?D\321FCO\375\337?\355\023\037ufo\345?xU\333\347\351\034\314?\244Q\261\261\224j\344?\304\n\310\362\301\225\305?\276\254\3318\3463\334?\216\347\373\247:\327\333?\211g\313\300\347{\351?p\331\351\316\2175\263?\222\336m\037\327\214\335?9\240\262\030\243L\354?\'4Pt\316\021\341?P2b\251\223\027\334?\375_\263@\317\350\352?\000+\001:\252$\350?\230\376\333z\311p\306?\353\224k^\033|\352?\214$\213\273^z\357?@\323?\326\301\"\337?\rr\211\360\273q\350?\024\317H\345\374\242\322?\3327\362r%\352\327?\263(oa!#\345?\307i\205\327g8\340?\200\272\222\n\215\311w?M9\326\360\276[\340?\320\210\013}\226p\324?\037\003\032\267\252\212\345?J\324\253\017L\251\337?x\322\226G\014i\263?\003K\023\276\250]\342?\304?\320i\327\240\324?|?*\317\216\242\315?\366\021i\n\277\\\321?\324\251\031M\313\027\303?\345\035\221\256\363\027\353?\307\251\362}\203 \347?PIG\232X\245\313?\363\177\223)f\026\356? 
\213\316;\3339\260?(\022\264\231(\014\350?\340\341(mn\265\251?\326M\315Zzy\322?\26491L\330\311\353?\272\004t\222\272\333\342?\354\214\210\302{x\313?\352\242\000\232Vo\333?`\222\006\246\273d\222?\360\027\235Hw\200\330?\300\311\230\375\334\206\231?[Z\221\026\315\307\340?\370\262\256\\\335\032\316?q\306\224\306+\330\346?\330\312\243\t\021\'\356?\273\363\311 \217\226\341?\226A\273e\210~\346?)Q\214.\201\200\357?\004\201\260\355\323\355\304?\024\310\355{|\336\335?`\313nK\266!\251?QgJ\266\232\013\346?\373\355u\355w\254\341?1\253LNK\213\356?C\213\276\270\271f\351?0\246\275\263\371\201\355?\000\277\030\263\220k\263?\014:u\"\227\257\330?\360\241\264\374M6\334?\214\334\036\010@3\301?)\363%\302\201Y\357?m\233\305\016\237\320?\220t\3607\236\024\260?\370\214\\\372\222\340\324?\244{\0064Q\032\320?6r\001\220\213\035\326?\200\'S\2639j\264?@\351\273\254\331\355\334?\026!e)\301\331\354?\342\211\300\271%+\323?\362/\327\213c\320\322?\314\260z\212\320.\326?\020\252N\"\373\344\251?d9&\322\246\346\304?:\255\217\026\200 \335?p\303\372\251\331\306\267?\022u2\324D\205\322?-\265\356N\252\201\345?\370\016\273\363\276<\312?\334\246\255\214\374\002\313?\022\271\300\315C\377\350?\320\027\177\216\242\035\345?\200.+\221\326\206\212?`\367$\267b\007\353?\243\261\002X\370Y\354?@\004t\254\360|\244?>\332>\352\321\224\337?L\245m\356\000\000\317?V\250\037\253\332\333\321?`\324\211\374\235\362\230?\370,\001\316\264\001\270?\033P\363\313\333\n\354?v*\347\344 F\320?\212\346\r\027\201\224\346?\226\213WGF\024\352?\2121@\376\205\326\330?K_\271\3747\263\341?@!\263\262 
\355\202?\356\271\307\310\270Z\355?\000^\222\024E\276\312?\306\311uD\317\206\320?\376!\003\3070~\322?4\314Gd\313a\353?I\260\005#\013<\344?\344\033\334_\221\340\310?Q\241\372[fQ\343?4\331\341\301\332\254\353?A\370\251{a\367\353?\256\313Iv\211\220\351?0\325\270D\304\200\306?\320\235C<\276\307\326?h\332\356]e\343\307?H\241G\312\200\001\332?v\246\373\362\324\330\343?\202\263\024D\343\007\350?\3446zT\315\214\351?\370\002\033e\342\201\346?\226jK\303\210N\351?\240F\006\372\315\212\256?\n\304\322\363\331D\340?n\251ucr\240\321?\360\352\370\204\272\262\347?\320\352\337\023A\016\347?D\342\013h\343\276\303?\320U\303\261*\273\262?\002=\371D\226\344\326?\340\214\3633.\226\266?d-F\225\356\320\346?\220v\242t\365\200\254?\007\335\346^{g\354?\326\266\362xFR\326?\374\tcr\373\362\345?\340/\254\215\301\360\305?\320\2337\306u\270\333?\233\005QC\032 \353?\277l\330\225\034\376\340?V\315\3373\302q\322?\320KqB\\H\324?\t\215\252\347\323\t\346?$\255-\257\336/\326?\234\276\277\305?L\362\347\370_\247\322?\252r\372\203\tW\356?\300ey&\353\234\353?+\036\327.M\300\353?\350\036\302G\213\310\260? y\214\016{\277\270?\332\032/\201\241\003\320?C\022\265\007N\273\347?\026IrX\227\356\325?t\031)?\031*\352?\234\327\267\360\217\207\315?l|\005\257\236G\343?|\243\302\237\234|\302?\202\331\024c\376\311\330?\027G\365\300\204V\350?P\022a\377\025\221\266?x\204\326I\334\307\305?Xb\177\262\263\325\345?(\034\223XI\n\342?\000\005\356\325\342w\323?\244\034\211\221bR\347?\250!|\272\341\324\275?\216+\235x\\\204\325?\262\266T\233\323i\356?\316\005R\270\2516\332?`\356\r\tn%\260?\335mb\320o\363\357?p\235#t\276q\306? U\320\003\344[\332?~O\266\340\341\026\325? \0219\346\311\003\242?\3041\r\201\207\314\342?\300\321\234:\222\341\311?\264\3607uH\277\331?\374\270.\374\200\223\330?\3748\005t\373\235\314?\256\212;\021\262\035\340?\237\\j\253\0174\354?P\345\035\031\376#\356?\340\341\030\317\335B\257?\210\373!\022\377]\337? 
\203.\371{\343\304?\303\226\023\325\n\315\341?\n2\360\371\352\204\323?\242\037\'\333\362\001\323?Y\246_\326\025h\357?\200\334E%\330\227\201?\020x\024|\221\344\246?\364\021\373\2161\237\331?\3072F7\333\002\357?\274\206\255\343\350\302\317?\034\363\250H\360\270\336?\312k\253\037\005?\325?\210\315\200Y\237\220\266?n\362\273\361X\334\320?w\360\267%\372*\340?\3206\033\035C\t\314?9C\225r\216\230\343?\217\314k\340OW\343?\352Hn\211;\242\346?\313\332\246\253\260(\347?\355\227\266\344\237t\355?\007\"\341\330d&\351?h\331\3565{\375\321?\t\353W`\200\366\352?\356\306\336\236\222\316\353?\203\332\371le\203\355?\3014\270\327b%\356?O\360,\250$\233\354? \201\256\2203\252\355?t\377\357\004\tp\317?\212\207\033k\317\227\327?\266r\226\311\265\362\320?\320o\375\317,\315\267?U\t\236r\016\r\350?\234\\\364k\204\026\310?\374Y\305\212i\005\331?tO&GqK\354?@\203sV\310\264\342?\017\273\366e\224 \354?\350Z\210\272\364\232\310?b\007\225\363\r\335\342?\362\370O\321} \325?\354\231wB\324\246\343?\\\327u\204\\\364\353?R\337 
x$\022\343?P\271\2622\265\257\351?S2\315\272SV\340?\324\206\230\341\276\030\314?\262\210f\2638\332\354?\252\364\301\330*w\345?\220\247b\335\351\032\246?\272\211\270\205\255H\347?\310\243G\2215<\274?\244\360\362\353\2117\314?\344dX\311\237\020\352?p\303e\323\344|\315?H\t\314P\026\315\274?<[\200Z/\315\341?\3307\223\2477\221\275?4\276\267\3128?\304?\367\324\023\3514\000\351?b\035H9h\223\320?@\000\036\244K\005\331?\365\033\301\032,R\357?\265\335O?\234\330\351?\326\326\337\037\276\313\346?\270=_Y\335\336\341?\322\013%#]\342\351?0\356+\257+\344\307?\260\314:m\236\341\325?\210D\347j%>\260?\300M;#<\341\232?\001|aL\234\r\347?&\\\227\026\263\020\352?\340[\240L\222\277\257?\376\254\013\263#\275\345?\020H\254s\374\247\255?\370\253\373,\244\224\307?\000\365\026V\313l\206?\334\'\217\230H\266\313?\0233\312I%\345?,\020\303\202\201\016\346?\2448\221\242\300\267\330?\300\333\325Z6\221\351?`\026\265\207\014\251\355?I48(\003\210\351?y\254\362Sd\270\354?Xk\300\202\376\355\277?\010;\262\0163\316\355?\205*B\372\232\343\354?H\020\343XN\010\303?\220\020r\202\000\352\267?\351\\\237Q\373\212\350?\333\376C]V;\345?V_\202\363\377\330\324?@\032\356\335\304\357\254?`\255\330{:\252\304?\206\321@Wd\010\337?\3560\314\3509K\344?\225\177^\312\233\375\351?\347Mf\0146\367\344?y\342\327\352\265\r\357?X\031\327i\355\312\334?8\025\372|\232X\263?\321\374M%\200\274\350?\'\034\rw\020w\347?N\313\3303\370\355\333?\003\236\366\031\314\216\353?\236+\311\247\306L\351?-\017Q\273I\005\357?\342\277\233\002Y\314\357?\277\324xV)8\351?\366Y|KZ\213\322?\317\020\374l\336\003\341?n\356\372\014\300\224\342?M\210_\317S\352\354?\376*m\362\314\355\354?\352\257a<\250\221\333?\021\037V=\266\357\356?\000g/[5d\335? 
\212\313\256\246\357\276?\020O\301\316gs\301?\220k4/\363\211\355?\032\001|u\260F\341?\254\264C\211!a\333?\220\261I\010C\201\340?\"C\266\244\021q\354?\243\277/\355\022\266\340?\006\220:\264\3779\355?\330\3754\263\204\314\320?\324K\265mE\344\307?m\020\346\"~\325\346?\030\362\262\202q\302\336?\244Z\341@\320\341\354?\323\201\013`\302i\342?J\250Vc\263F\337?\242\347\2139d\330\346?@af\353?T\253?:\014\203S\3474\326?X\200\204\255B\033\266?\340\366\020\206(E\303?\000\335Q~\003\275p?d\014\352\324\215\343\316?\270\006\305\250\242\333\354?\204@1\267\221\325\346?W\367\025Y\273N\341?G\307r\341`\206\342?`x\264\244\tB\242?\272\220\206r\362\246\335?P\237\302\220\336\211\264?\211\334\001\034\363\034\341?pF\225\016\004\355\244?\326\377\374\223\',\335?\222m\326\374\016-\357?P\362\237\316\207\343\311?\260QU\263\307$\274?9\224/K\'\331\357?\214P\034ZLl\353?\324\'$\340\340)\334?\360\211\3775\022v\313?h@t\325\314\237\325?\000L\006\320#\177n?w3p\026w\022\341?\210\363<\'y\243\310?\024\221\177\345\375\255\323?\324D\327\360\022\355\311?@\237{\374r\370\242?\036v\376\260\201Y\336?\020e!^\267\223\262?6\243\3200\242j\322? 
\212F\373\321\031\256?R\261\007-+<\335?<\360\320u$\273\356?\374\300LP\263J\355?\242\306\207E\300\274\325?\320\214\311\257\037\365\243?h\016\224\222:\323\354?\356\205\3309\306.\327?Xo\"*\te\336?|p.ht\306\304?6\n\"N\320D\322?\023\251\031\004\242l\353?\340>\2345<\325\246?\000\030%!\356\014\331?\365L\210Gl\340\355?\250/\013$\263\013\271?\024,\3741)\000\323?0\367N\026`\333\317?s1\\M\334\005\345?.`\"\256-\257\345?\035\255\027\345D\023\350?\240\325\251s\021r\347?`&u\224\373\260\265?Z\347\257{t\016\335?\362\325\200\031\322l\320?\3004\300\364\351\234\201?N\362|\311N\013\332?Y6\354\240\373W\344?\240i\373ed\340\334?\031\265\363}\306\204\342?\000\027f`\232\213\233?\326\275\321`?\214\331?\346\020Ndo\240\341?\200\324\351T1\233\352?7^p\0141\300\343?S\243\260\3635\356\351?>\211\022\241\240e\355?\240\373;\364\276_\254?\250\321\217*\206\370\326?\202\350o\2352\232\344?V\034\355\317\034$\343?\020\364\212\264K\311\243?#\272\335P\240\257\340?\n\332\236\323\2239\324?>Y\214\260\024\370\347?r\372\241\316\234\341\330?\236e\220\233A\033\354?\327\356L#)\232\357?\360\271FI\257A\253?\000^SP\236\036y?\031\253\017u\325\323\341?\".\031\337\036T\342?\364\317\241\201\341\331\347? 
\331\363E\017,\327?]MV\201\275W\344?r\335{Pm\377\354?\255\263d\222\360\362\346?B\266\323\212L4\322?\210\335\250\001tg\265?\2002\336\035\253\206r?[\367\303\335|\346\344?\306\320*N$`\330?Q\030\233\300l\231\352?X\024h\215a`\324?\327\236\224\035\030\351\346?x&\272\220e\033\355?\357\271\323\304\231\211\355?\347c\004`(n\347?\266\347\207\210\014\276\324?\206\3473\224,$\343?\360\311)l\236x\265?\230\035\275E\353O\274?l\251\237J\320\245\343?\243z\030\013\272\327\357?,x2\022\202>\335?\340\212z!\213%\245?\266\332P\034\372\350\336?lA3\205\322\024\326?,\352\215\307\265\023\321?\210J\'\3556\313\327?\252\325gt\267\371\354?\022;\211\346\307%\323?V\316\222\211\036\246\325?p!\246\277=\353\270?V\030\036\265oU\335?\200\300z\252/\rz?A\242\021l )\356?\300\373U\355\207\037\273?\205\320\250\265\t+\355?\350\021b\022\236\202\356?\026\262\002/\"\003\352?\004g[\346\247\311\333?\000v\212\356\317\237p?\032\214\325\326Y\005\355?\350\354\235u\227W\305?\324M\005\356w\224\311?Rw\266\323\260\310\343?\237q2bl\246\355?J\025\262D \3220\037\346?\212\014\364\006]\345\357?W\363\206b\237}\343?\244\226}{\200g\320?\320U\241\3264\212\356?\206\357\010I\342N\342?\025\014/\242\2763\351?\310\177\212\230\344\250\351?\320S{\n\375\347\276?;-\377,<\272\346?\260\332\306K\210\333\276?\362a\247\rm8\340?\356\314\315\227\263\377\353?_)L\204\261\'\357?\370\022\035\315\205\224\330?@\253\242Q\354_\344?VV\\@\032\031\353?\005\0043q}\371\345?\342(\327v&\256\345?\017\251\201\314\0357\353?2\037\3101\267\022\346?\320\277\345:\312A\304?\320\254\240\347\276_\241?\302\216(X\005@\321?\005\014\002.\t\244\356?\302\227\266\006\223\356\350?%]\374\301\273\351\340?<\366x\274\021\253\330?2\225 
Q(E\344?\325\327\353\275l\021\354?\026ch\366\211F\333?\314\244\252\370\246\316\330?\311\301\005:\314=\341?\262v\245?\352B\331?\260\316\266\367\r\252\264?\350N\352^g\271\317?\206\\\264\340\362f\353?\220\027\243\320\375@\301?\202~\360\316>G\342?&\276p\224\024~\325?\023\255Q\024\355o\354?\224d\2305\221\334\307?`Aq\375\264\031\224?\236\313\030\337\260\355\350?\n\220\243q\221x\327?`\301\2232\003\026\273?\304\230\343\330v\"\330?\260\234\210\337\275j\325?K>\241\337F\265\355?\005\200j\217\370u\353?\330c!\276N|\342?zo\235\"\213#\354?\306\207QJ\036\311\355?@\322\265\211\320\255\241?\320CL\313\244\364\324?\204\315\n\314Y\361\341?.\371q\323\245D\356?\350[Bx\256\272\301?\031\r\362\204\t\034\355?\325\204\355\256\232\004\345?\032\247i\035\300,\352?\3332\274\020oW\355?\374\0335d\356\267\344?\312\205\360\213;R\351?\260}v\363#o\300?\354:\035\2455\364\336?\346>@l]\262\346?(\261\213\273\033\'\276?O\020!K\357\r\351?\"\331\234-I|\354?\341\010\252;\222\265\354?\360\000%Q@e\343?*\034r\365oO\321?{\321\371h\242E\343?\350\021\361\246\314o\352?&\210c\370\237\360\321?\340V\254\344G\355\305?t\322}\2323\030\313?\220!hD\307\367\314?FY\350\330\367\300\341?xF\t\023\006 \313?(z%QHb\322?\000h\372b\364\225\311?\021\3714\265,\307\341?!NT\226\242-\340?d3\216v6_\337?Q\347\263\224o\372\345?,\366\223\023h\215\347?\r\224\237\372}\206\342?(\303\370o;\271\300?\221\261\315\335\211\223\355?*0\237q\"\013\343?\200!\323a~9\275?\273\346\322`\365\306\341?\236Y\026_3P\356?\335\344\351\3509\031\347?\020\023\374!8\177\320?\020y\203un,\267?\020\203,\327S\027\251?p\201\016i\341\375\335?@\237\324\200\351?\220?(\275jc1\271\324?g\357#\250\037(\345?V~Eq\325\024\324?\034YU\210\372\207\304?pO\025\027\332d\340?\214\356\"\263w\217\344?4%i\332\004(\355?Z\230\207\235LK\345?\254\002\355M\300R\327?`\256O,\235\304\265?\300\233K\n\307\323\313?\240C\035\265\"[\272?\246\340\206e\006\225\342?B)d\243\335\224\334?\231E\373;\021u\341?\330\2731\354\300\007\274?\376\t\013\'\260)\355?\316\350uW\277\007\343? 
\276e\216\231%\330?\366\232\262:PF\342?\334\235\033Iw\360\356?\251\201O\356\214\003\343?;\311Q\233\305\233\350?\206s(#\374\'\327?\234\201\273_\223\263\322?\373\177V\005=\337\350?\204\302\324\034\355\324\347?\036\245J\313#\340\337?@Z\366WS\330\237?\324\365g\253\037=\304?EF\n\203\246C\345?\323\306\224\016F\365\354?\017\263\343R\324\231\341?\000\345\327\371\343\307\257?}+\004t\277.\340?\321\251s\304LV\351?PT\200\326\\U\321?\264\354\222Z\234\222\341?\022r\n\270\305\363\323?\215\000\014n\225\212\354?\2318\246P\033\303\350?\014\225V\371\034\\\305?\240z.=\\\272\274?\344XJ\261\230x\332?k\300yT*0\352?\232\023\207\3538F\351?\210\2108\267f8\304?\330\001\022\212\224x\332?\217\204\252\026\260\367\344?\266\326\263\230\035\344\325?\364+\277\021\022\316\352?.\254\351h\362Q\345?\244:\320_\216:\324?\241O\274\034\024|\341?f9\037w\322\311\342?h\224\243\230\236\346\331?(\000\314\275\211)\353?\332\255N\001\215\356\330?\365k][\3058\347?!\377;\303\223\263\350?\034\332]\214\315\227\307?\236D\3665\202#\350?\220\304\335\271}\232\342?\347zh\266\2400\354?l\342\236\360#\240\344?:\234\331i\360\271\321?\024E\200\330 
b\303?PdNV\0141\347?\202\310\261\316\234\005\322?\243\376B>!>\345?\324\334\253.\036\331\320?\366\253\337C\033\026\327?\3008?\362\247K\233?IS\245\200\341\346\340?\310r\343\311`Q\262?\n\272\360fB\257\327?c1\222\374\177\346\347?\262\376[C\354\022\337?@\016WW[\347\223?\336\373\227\275\033Z\322?6\233\345\2744\312\342?0l\241T\237/\316?\334E~\346G\300\307?\215\356b\005\341\327\354?\234\321\030]\351\242\304?\236H\014q=\231\357?\250B\376x\335F\261?h\213\305\257\023\310\322?\321\330C\036Qr\343?\270Px\243z}\267?\334x\177@\251\270\331?\005\376F\331?\\\357?\t=\344\216W\355\356?\352cg\273\020_\345?f\377\360\360L\376\335?.\346\007\034\3361\323?\202\225%\006(n\326?\340\211\223\202\374p\310?\210eS\037@\357\271?\264\232\273\244\216!\327?\004=\tb9\376\333?\216fr2\253\"\346?\264\347\314\217\355\227\352?\314\256q*Sh\351?0<\305\010(\036\322?\226\330\313r\220b\357?c$\355\322\177\"\355?P\342UOx\214\272?\034\361b3_\005\335?\035\224mp\361\234\341?\236\213,M\203\242\332?s\370e\366\t\312\354?I\350r\221\355\362\347?\020e8\273Ca\323?T\036b\354\343^\303?\334\235\363%\375\253\347?Dy,&\301\277\350?\345\327\2124\261\347\341?\260\340/mB\"\352?\374\323\021/\n`\327?20|\032ga\356?8/\001<\200\351\333?T\003\256K\324\246\310?\025\252\202tl3\355?t\235L\325l\305\307?\200\217\243{\217G\267?\344W\037\033\346n\320?&\250\303\341\314\227\342?\022\237\237\010\034u\353?\322_\212tm3\340?\217\267R\005\263\345\345?\200e(\025\252&\266?y\364\227\245,\022\345?\260ke%\346\322\317?|\277\016C\253\035\336?\300:W\201\223l\330?\214\2031\235,\204\353?2\214\314\026\240Z\335?Pv\273\373\371\323\357?\223\270ud\304g\343?\205\270\343>\276<\353?\354\275\r\247#\000\357?\355\263J\215\210\257\347?\250fN\272\200\343\276?\300\313\245n\320\245\344?\345\362\347\212w\353\354?\265.\034\366\311\211\353?H\311Kd\215>\277?\234)\211\261\365\362\312?D\004\201]L\225\310?b\375\306\tI\321\340?Vbx\252\201\213\324?\204\230w>\264\346\324?\376\030R,w\344\326?\022\255\374\220\320\010\330?\004`si\206\353\301?\213\214\363\343N\237\351?\232\320A\226v\'\352?\n\
2736\273\254\343\325?\022\016\013\t\226Q\337?\350t\223+G\226\333?T\262\215\271@\324\320?\340\317d\336\240\264\260?\3334\260C\237\252\350?\022\256{g\376\014\324?)g\366\252\313\360\341?t~<=\316k\342?o\377*\255\3501\345?8\034=\254^X\326?\325\271\251;`\367\344?\030|q\214Z\205\311?\212{\302\376\233~\347?\205\004\334\336\212\251\t\274\330?@W\316\336\377\221\273?H4y\304\tM\305?hC\36661\014\350?\205hP\000\275\317\346?\243g\343G(\032\355?\310\311Xd(\235\273?\344\362\265\000\206\347\314?\032\0073\261\270\204\325?a\t\262a\304\256\341?T\335*5\037\'\340?Jon\252\210\352\320?\300\310V\206k8\330?\340|\265\340V\234\302?\340\253\230=\263\327\265?\350@\272\211LB\313? {\237\357\321\370\332?\224f\217\231\247/\303?\200%\035\3137=\267?\354 \244qEA\314?\310\014\034\226d\362\347?\026\376\256]!\222\343?>\351\0251n\001\335?o\312\303]\025\231\347?j\256\267\'\020\r\346?\240\257\336\0228\273\311?)\206\304\262\232\251\343?`j\016&vx\244?\316\314\346\255\223\232\340?\221\304\360\315\237\204\340?\330\347\037V`E\356?\337\307~\215\350~\345?@\324y4}G\247?hca\270i\326\345?\006\034\3474\257\346\335?\274\317\347\374W\340\305?\010|\362\026\267u\325?PO\013\343\315X\323?X!\023\367\001\365\343?\361\0350\007\241Y\357?\250}yk\376\320\270?\313401\014\321\354?\227\010Nz\203n\343?\277\005\220\000\265\024\356?x;\256\217\373D\311?\364\216\000\305\021\017\336?\330\177\324\220LM\322?\270\375\036\213C&\332?`+\254H\2464\355?\264s\036M\r\341\342?\216x\334m\347\250\347?\310\377\364\357\330\363\310?\337\001i]\221/\353?\244t\323\332\220\356?\370\263|e\210\005\262?F\332\031\034\222\275\354?A|\344T2{\351?E9\341\300`\255\355?\210ZF\014\307\344\317?\336\361\332u\354T\332? 
\362P\223\302\300\333?_\002\251\273\313e\354?%\013\300\205Mm\344?p\340\037\326\037\367\267?\226\367\212\215P\321\323?|\220\351\033#\263\333?Xu\001\332\037\307\357?\306d\020y\354W\322?b\2734\364y&\346?\020\305\225F\332*\265?2\211\335a\306\341\355?\034\276&\312\233\331\327?\340\264\271-l\211\266?\246\033s|\364P\323?\222$\0031s6\345?\314\\#\253-\006\355?\235\267\312[\332\033\346?\204,\000\236-\013\350?0~\030\276\300B\245?M^\325\215A\300\341?\252\016\315\t\263r\326?]P\357\n\347J\351?\310\327\013?S#\301?\305S\246\023\255\273\355?t \032\366h\254\321?\256\361]\300\256\324\324?-\006\2629\336\234\343?0\0060&;\310\343?\367m\225\373\212\277\341?\240\"\337\370\265$\225?\013K\277\343\027\230\355?\310\310\"\003\"\005\263?D\216\213y_\243\324?\3001\225ru\331\226?,\363\350\351\324)\311?\315p\025\345au\351?X\365\276\246\272 \345?<\245)9\241=\301?\006\277\225\0068\242\341?\000\323\211\206\002\374j?\342h\t\301\307\377\320?\226/\030\225\350\203\342?\355\n\224\3262M\345?\360\010f\t\314\001\332?\270\230\302\364;\374\276?\350\311\366\034\303v\271?nwa*f\341\331?\300\270B\267\246*\260?\342{p!\266\014\357?`\370\367by\000\337?\212A^\3202\r\333?\024\001\302\021\006\352\307?d\204?)\247\004\305?\031\346a\2308Y\356?\\u\241?T#\341\370\205Y\316?/\354\324\324\233\347\342?\332\303R#\313\372\334?\260\361>JI\375\354?\230a\222\"Gc\306?\271\010\000\211\316\211\341?L\207l\202\350~\313?\305Yf\211V\330\355?\020\333a\365\210e\341?\306<\225\246\376\341\352?\261\033\337\263\314x\344?{x\233{\022\375\352? 
!?\326\336/\266?\304\234\026\223y\312\310?\313\211]\316_e\350?z\315*\266\032J\345?\216\021\364\261\0069\333?\030\020\013\333\200@\342?v|\332ao\377\321?\320\203\017U\372`\343?*\330[R@\207\337?\374\327\304r\025\223\347?\271\351n\330=i\357?\362\365\342[\355\232\332?\274\272\355\306<\035\306?A\330\262\007D\226\357?:\313!\001\222\024\353?\226\272\302\242\336\203\334?\225Z\375\275\004\277\350?t\330re\371\266\350?\034,\271\224\000\341\311?d\212\325\322{\r\331?h\260\374\245N\262\357?\036o$\226\007>\321?$\006$\004\216\t\305?9l>mzh\355?\236\326\0250\220\306\337?\327\232\342\232\304\212\347?\002\022Z0\346\277\357?\323\332\036\307C\002\350?\372\230*\336HJ\342?\036e\252\3427\230\342?h\312j\301\265\343\261?xMa\374%w\311?\3303\361\242\211x\354?\226\207N\234\203R\322?\202d\223\025m\261\326?HsM\343\241\177\351?\302v\256\362I\246\327?\220U\230\312\206=\354?\247\3703e\241\236\356?\"\241\327N\354\347\352?\000\255\3556\325\002\253?\014\006T`\034\344\304?p\017\215^\016\020\354?iZ\241\005N\366\345?\344`1 \324,\313?\310p\027\016\224\305\336?Xmzr\305A\351?\362\330\376c\335:\323?\000\237\230\360<^\264?X\327\270\256T\312\351?`\244\341\005\363g\355?\372\344\362C\t\330\\\315?z\n\207\232G\357\357?5\n\315\262+\330\354? \200\027\266}~\226?qr\336\201\033I\357?\207#Zs\357\215\350?\310\234r:p1\277?\300\014\rz\250\276\236?\257FW\331\310\227\357?(\206z\027-\356\262?\036\"\233\350\\\036\352?\250Nu_\353\016\263?\354\024N\304\252S\345?\254\336\372\275\253\266\336?\300\r\354\222J\327\357? 
\204n|\000W\236?\374x C\257\306\310?\326,\241\310W\353\355?\'\254*J=*\344?\250\347\243\204\206\032\345?\362\220\347\377\220\037\327?\'<\363\223\315\221\345?\340\255\365\264\005&\344?\342#\376i\241\210\344?\244\254 \265mK\335?\014@\027\250\332\366\325?\020F\200\214\324\351\336?\216\013\277k\310\022\344?\356\262X.\257V\330?t\373^\277\314%\301?|\313\353\205\2352\323?S\032\226\355\305\327\342?\320\252A\357B\357\242?29\327\236fp\332?\240jV~\235!\330?:x\200\227\202\372\321?\362\023J\213\304\257\325?If#d%\357\352?J\237\330O[\373\346?\242\350\277|\034e\324?B\345\"x\332\'\330?\000\363\010\212\340\232\340?\230\355$w\232\344\263?\177s\330\031\312H\352?\362\220~\245\271\265\323?\327\'9\332k\000\355?\344\257ILfQ\336?H{e\035&\005\337?JQ\375\026\234\345\353?\267?\316M\007\317\355?\034\310\354\0338-\346?\222x\331}\333\230\352?\230H*\262Xa\271?\337\232\310\273\247\236\343?\316\017Q\3216\320\327?\214\306\362\003B\246\350?j\306\030\264U\024\320?T\036\244\n_N\312?\311#woA\322\344?\312b\322o\026\344\333?\370z\234\357\254$\357?\373\213m\221\021\276\343?\264\327A\361\356\330\352?\300jd\301\322\356\252?\220\016*\020m\026\304?\027D\313\324##\343?\330kn\220\201\325\346?:\245F_\342\002\326?\332\242L\225\371\t\327?dq\335\'\305\320\314?\230\034\304\334\232\351\304?\000\201\001\274\0109\226?\005i\336)#\334\341?\300r\021)b\314\323?\333U%l?Y\344?\275\245\224\360?\367\357?\354+S\220\263\360\350?\024\277\221\273#\364\342?\004=i\315h\200\350?@\305\025\243\210\355\321?\016\375\331%n\270\324?\365m\270,Tl\342?-\004y\265\301\014\354?\256\311*\324\013\264\342?@@L~c\225\223?\346\335\354\372\212s\330?\260{y\010\244\216\341?@\344\t\337hx\274?\211\301\224\r\341\372\352?\370\305hX(\375\271?E;2\244\001\242\344?8@\220\2463\215\304?\013X\364\027n\307\356?\323\247\251\300\261\003\355?p\334\032(s\260\266?\344\244{\322\226\205\353?\371\211\371\360\345\026\343?\270\201\313\240\021\330\270?n\347J\260X\311\330?\374g&\332lT\310?r)k\203\220\340\344?\216\306\340\326\026\272\354?6G\206\3653\007\320?\\0\334\226\264\212\305?
\270`\256\2264W\357?x\013\264\236^\351\260?\270\010.\360E\361\306?\244D\026\273\256\243\344?\n\207\231W\345\214\354?Q\353~\251\203\366\346?&\365\344\242\216\007\350?]\312h\244\227|\344?0\320Z\016\370\207\277?\240\256\312\334\001\365\234?\343\305\334\004\014\341\344?H\367\302\372NK\276?\236\243\200\020(\317\341?,,\207\363\376w\353?\310\235\036\247\302\t\320?\004\342\3308t\254\334?\200\242\223AJ\376\274?\002\376\027r\245n\325?\'p\253J\315\211\350?\370\202\241`Gn\277?\353\004_x\034\037\356?U\305\3300\027)\351?R\323]J8A\335?\236\0109\3611\313\332?\226\237\220\037\376E\324?\336\310\361\255\r\270\324?\010\004IW\345C\354?p\256\177\347\340\007\247?X\271d\306\330\223\355?\234\200\365t\242\227\350?H\330\251\250\240\001\264?\320Q\246\315\323\016\254?b~|\247pP\321?@5(\212^\230\223?\304\312\235K1w\352?\250\240@\201.\034\325?X\230\036\227\230\022\305?\200\275\311:<\203\205?\207\264\343\266\246s\342?&\230F\211\263\355\334?T\367i\023\276\265\303?[\335P\3531:\347?\220\033\275\037.\371\315?\035d\317sx\303\353?\0017[\247\255Y\342?HU/\336z.\346?\344\224\003F/\274\301? %+\276\023\036\251?P\020&\211\341\221\253?\302\036\277]B\220\340?~\234w7\035!\355?\024\326\211\207\007\320\347?\032\376N\333;s\324?\250}*I\351W\307?RA!\232\311\326\355?V\020\005\223\003\270\323?}|\346\263\303\347\324?\362jy\031\374]\352?$\252C\274\217\024\335?l\304\1770\253Q\353?y\336\356\301WJ\353?\342\215\227End\330?\246\037\312\026P\323\320?\360\247\374\025\332\347\273?@n\177\017ze\345?\'\351\"\255\326d\342?/\251z\304\355\375\340?\211\325\240\3159\334\347?v\311\320\330a\201\341?P\247\240\267\352c\303?\250\201\017\263\227N\324?\254v\364#`\375\303?\322m\344\327\230\024\355?v\376BE\306\261\332?T\272\013\211\245[\330?\251{\335\252\r\305\352?\373\253\333d\013\340\354?`\202\313\306\267z\274?&F\304_\334\200\335?\252\014`\210\373t\354?\220\244^\313I\351\244? 
\243\"T\020\355\241?p\272\016\226\027\017\240?1~\377\003-P\351?\362\373\035za\227\331?\035\360}\354 \027\342?N\031\337h\030\030\347?\365\372\213y\205&\357?\272\025\377c\255\224\325?Z\025\216\346\354\r\344?\240\336\026\242\300x\222?\360\306\344A\352\\\272?e\256>}g\370\346?t\316D\nR\005\316?\000\257\276NFy\225?\002\347\033,\001-\322?\260V)\263HW\245?8\273\315nT\222\312?\330{\215O\313\007\351?\240\235\346?\200\324Z\361,\200\223?^\353g\217\261\211\321?\000\204\244\000\271\256\266?\322\261\233\214\262@\336?|\303\272*\275\204\330?6\3631\027\035N\357?\242\325\\\204\366@\356?,\274\356\276\014H\300?\244\214\315\342G\216\356?\021W\245\022\037\317\343?\324\263\326\203\247\343\322?\240Y\021;7)\331?\200!{s!]\340?<\205\367x\307\363\302?\227p\247\336\346b\345?\\\317\255\014\316\r\333?\252>VnH\206\333?@L\301\033ss\315?\312a\341\230S\375\355?\216\030\214\255\035\355\320?\352\266\346\252\354?\251O1\\\3118\353?J\320\032\241@e\330?_\017)\352\375\304\353? \354\202}\312\247\354?\266G\227\200\020\037\334?\230\004\243\253w\362\356?\370\320-aW\"\346?<\210\323\036\215\177\300?\030,\231g\264\213\315?\026\030y\345\246.\340?\010\226>\274\0133\274?g\263BQ\303g\356?O(\207\310y\322\344?\n\250\232jr\354\320?@.\277\317\006\202\242?\205\305J_\232\200\343?k\205,\376Lw\352?\342\204WbW6\325?\000\252u\032\r\356_?8\250\340a\357\311\276?@|\321c@\244\217?\260\256)&\315\237\337?\330\316\253$\371\224\264?\t\204\261B\016\273\341?:0/\002\343\305\346?4\317\221\334\265\002\300?\3701\325\377\022\001\304?\214\313\302\276Z\245\344?\006]\034y\375\224\351?\233@$\237\300\274\341?@4D\251\320\376\347?\316\323,\310\263\356\333?2\333c~\274\264\330?\034\021\350\031\260\316\301?\240s\255$\'\n\252?\355u\030B\035\235\340?\\\233\377\220\353\357\326?\344o\276\321\246\353\315?\230\370\027\246\244\\\354?f^\\q\003\255\334?\231\305P\350P\236\353?\006\331G\225o)\320?\306\214\275\223\014L\345?\363y\256\361h?\347?{S\240\247\305\034\355?Hd]\020\371\023\277?\230d\346g\361Y\310?h|\246\351r\243\277?\230MS\315\243H\305?.\216V\233\215\361\33
1?\323\276\335F\355\232\356?\350\232\241\360=8\336?\262\371\032We\201\355?`\347\030><\320\352?\333\031\302\0348m\355?\036\n.\032\320\237\322?3\353\200U\000\210\351?@,\266\301H\363\252?\210\2757\252/\204\334?w\245\305\330\256/\350?c\350@\312qh\341?\323[\232\306\016\255\344?&x\374\263\361\'\322?\234N6z]]\344?x\217\237dJ\220\323?n\275\002\311\341\345\323?b\373%\321\233s\320?\026\213tM\020H\334? =\302\271$\r\317?\226\030,\372\270.\344?1\247\204\316\340E\347?\364\250\036\035\320\357\315?\334\312\267cf\257\337?\234\202M1\341\364\341?\232\177YqC,\341?\177\311\177rBT\353?d\233\221\343\206=\336?\310\335C\211\035\316\335?\300\275rM\312\360\323?~J\016\204\003\362\344?\000\016\200l\371\371\337?\026\304r\275\344\016\341?\354H\251\315\220\262\335?\222\217\017\342aw\321?d*IH\264;\333?\322\0316\266\035\302\345?&\301w\342\274\352\350?\265\316\341wSq\357?pN\271\377D\227\341?\332c1\313\220\027\335?Z@\265@a\233\344?\033\245\260Y\002G\352?BM\216\262B\000\347?\252\300\307\223\020\311\325?\260\346\230\377\237\247\253?>\322\307[.\367\343?@\335\333\275\312\233\271?\010\224\251\210\222C\335?\270S\314\002\375\315\261?\254Q\235\000aC\303?\244-\233\2344\231\356?P\"\013Ae\r\334?\252\t2\232\2432\344?iH\254\256V\034\344?\014\325\215*\215\003\313?l!\255r\313\323\345?\210\364-x\346\363\346? 
\242\026\316$|\271?\222\204\035Hm\031\330?p1\\\351\036g\342?\356\241\017w\301\222\323?\240\201s(D\354\314?\302\261\341\246\036L\323?\344\343s\311\031\365\356?\330\202|\300\252\374\322?\301\342\213\034\344\303\356?\000\013\201[\336<\353?>@\213\313\026\205\337?=C\330\246\231\201\342?\347\335\3434;\361\347?\351|\215~\365\201\354?\246\002\271\3741\026\354?\300\216\230\035\337,\254?\215\017%\001\245\310\342?\214\330\375\t\005\036\317?\371\302H]\367D\355?`U\321\'y>\300?\034\350o\211\001\245\317?F\213\226\016\334$\327?\301\023Q=^\014\354?\3009\364\267\024+\201?\3249k\360zd\344?\350\030lgl:\324?\036\217\356\203\216\027\351?\021\246\355\302\242\371\344?\003\322\233\357x\343\347?\321\347\236o\017\267\346?\230)R\201\3518\335?T\317O.0A\317?\301\"\277\366\255\311\346?\361\343-\315\334\023\350?\010l\017F\364\263\262?2\357\335\003:s\325?\206JO\336\222\017\346?\200\276up\026\355\205?\211\262%\240\346B\343?\326]\224\t\204\275\346?\220\314F\031z\327\331?\240w\002\302K\271\232?\363\370e\r\324\314\351?`A\025+\245\032\351?\\\245\200e/\336\324?^\304\225\260\315r\342?^\356\275\375i\336\335?\026\333{\220\375\300\331?\206t\232hg\001\324?ey\352\3108\305\347?\352\300\323\212\364\336\341?&\306\311\334\336\002\331?p\250\352\341\244\231\304?\366\351Ud\226\222\331?\304\324B\371\266\363\335?Yw\343\032Pp\355?@\373+\221\017\366\313?z\243\313\276\001\215\353?\004\314\2318\203\350\352?\342n\271j\216\017\322?\343\242.\2759r\341?\204\220\322\364k9\337?\211\255\200Fl\333\356?/\362\221\003\366\314\346?\360T\2075xE\330?Xa\003\235:\034\310?\225\251 
\344\255\300\346?\372\22306\340\365\344?`\277\215\004\362G\333?&\022a\225\233\303\353?b\353\252}5\223\342?w\251\236\225)\204\350?6\210\260\343\357\231\354?\330eaww\r\271?t\235\202\357\304l\341?\224r\002\343c\207\346?\320\371\ru\275)\345?^>\263\213\312w\327?\200\240\371\310@\254\240?\240\376hn\272\334\271?\310\000.n\236\343\331?\374\023\212\243\036J\320?\327\014\360;\253\321\356?\226f\000\035\276\237\344?\256P\377c\343\327\334?VN]b9\021\341?\227\303(\031\342\271\354?\014S\340\322O\337\344?@\355\275\217\233\271\215?\\\000-\243\035I\315?\234\226\007\r\006\206\347?\340\336%\345\240X\344?D\226$\376\303\004\333?\272\036\322\010\272E\352?|\237\2243B1\347?2\3037gJ\360\331?0(\326\311\0028\254?\032\357\327\233L9\356?\260\254\210q\204|\274?Y\355;\347D:\350?\270-\255\001\311\023\353?\200VzWd\215\246?\346\214\244\344\330\220\353?@\351\321\030\007.\315?\304x\255j\006\026\323?\000\304\036\202\205\025\350?\324\357C>\356b\343?)f\"\361\030\241\347?\332\336\242G2\225\352?X\312\344\003!\t\356?:e\025\371T\342\337?\356m\256\030?8\335?\352\232E\276\301\323\340?\274\262\365Jz3\327?h\304S\321yg\342?\317~\307]\353\360\342?\334\250`B\345\300\314?\360u\337c\241m\337?\000=\027\361`i`?\t\\\0318\360\334\352?\323\035\305\024\352\004\352?\370\263\205\344\220=\337?\200n\t{y\240\255?\2778\272GHa\357?\217\\\322\355\2759\340?\335F\234\346\3703\354?\224=\310\342d\315\353?\020\277\325\360D\340\301?\333\313neI\002\340?%>\202\234\202\020\353?\302\356z\355wk\322?X\3522K:\336\355?J\026\323>\272c\330?\360^(4\302x\335?j\217\336\263L\035\347?\274`ZPb\000\351?\222M\320\254bE\346?\314\216\261\204&\253\303??%5\373\000/\353?\202\204\326\333\273\375\343?\355{\324\216t\350\356?Y\nU\252\025\310\350?p2\236\310\3169\333?4e\363b\241#\356?\256nM\304\264\324\333?\027\244\300\277\037_\357?Y@\363\003\003)\355?\007\331\375\306M\370\353?\214X\005|\315W\307?\241\240>*\324\204\350?\027ct\365k\320\344?d\266\200\301\203\321\330?\372\323b\347!k\354?\320\276\244m!\352\307?T{\307W\317\214\340?0\31136\375\351\324?\230D\036\260\213\317
\266?\354|\313_\214\t\317?(\303]\271h<\265?\325\037\022\225G-\344?v\347\225>z\r\336?b1\326\217\272<\355?P{=\257\267m\350?\320\242\216\224\006\314\345?\220\243\206OP\026\242?8d\013\367pW\325?p;\242\"\0020\254?\362\363L\337\247}\337?\340\326\235\321{!\244?<_\212\014\271\305\315?\374\325\332\224\225\326\316?\2109\213L\246Z\307?\340\352p\372\255\300\224\350?\221\303\376\235YW\346?o\335\352^\026\020\350?\202`[S\246\312\356?\020P\035\360\363o\260?\274\3131\372(@\341?\342\303\360\006C\240\352?(\245T\373o\376\266?\264\\\177\321{\314\352?\316l\3115s\355\336?\200c\023[\375f\266?\244M\014\322\211!\330?m\375s\216\010\177\353?U\233\255/,\266\352?\307xu\333\006\017\357?<\266\024?\021\235\305?@}aj\026\230\204?Hl7YE>\354?\016\016+Tn\033\342?\3446\205;>k\345?\010\300v\205\tB\260?\300LQz\2758\350?\353\345dB\216\363\346?B\272\004\320\307\036\334?l\222\356\333\224\200\316?\345\215\331\302}`\347?\352u\010\333\302}\346?\236\362_\360\366\341\334?\336\323\365\267\323\317\354?\232\'q\365\355n\351?\000\271Dk7\300n?\377\331\365\031 &\355?Da\373\357\224q\320?\303\']H\3535\340?\220DI3b\274\313?sNI\037\367H\344?\252\261E\302}G\344?\354\303\371\203\350\367\337?\363\277\027\023Qt\351?\360\371\220\306\351\261\317?%\351\"\310\320_\347?\002\\\3355\010\003\346?\307\341\231!\273\025\343?\235\331\256\335\245\354\355?\014o\371Y\325\357\334?@\340F\237\020\353\215?\030i\345m\222\367\266?\300tdN\337\360\274?\267\263\364u\222\310\343?\257=\004E}\373\357?\257\334N\265\370u\356? \r2\324D0\312?.$C\253\017\346\343?>\246\210\216\325\370\352?)\033|\307\276\222\350?*\220CX7E\322?+\203\213H!\301\345?\035\334\001\177I\237\344?\311JG\204\377\311\352?\232S\271\323\220C\341?\371\250Uu:\207\357?X\376p+\234\367\310?Ha\332e\005\310\276?\354\227\255S$\361\356?N\344\023\347\300\336\325?\214#\364W\3157\323?\343\335\352\207\0362\357? 
\251\304\326.\363\234?\360#VuP\306\323?\304\347\022GS\002\312?^\031\275zA\014\327?\237L\303\024\347\014\353?\301e[\213\262\324\351?,T\354\366\346\276\336?\216\n\222\221\266l\347?\322KH\232\033\300\343?\254!\303%\241\333\353?\340\025:q\r\217\351?H-9\271\331h\270?\2005^9\360\325\243?\356A{\'\0056\335?x\236\031c\0161\260?<<.Kc\266\320?z\354\235Df\304\327?\220\031\331\267\321\360\326?\224(\242y\270\234\353?S\374y\307!n\347?\035\021\272\206\352,\340?\307R\312\223\327\254\340?\300\265\\\225t\277\335?\350}\330\337\347j\355??\204\360\232\263\364\341?\230\270\027\213\365o\334?\210.\362o |\266??\177\305IAw\344?\032\312$B\364\307\352?D\026d\255\234\340\351?\354\371;\033\3363\321?\206W\264g]\331\325?\030\200}D\3112\310?S?\013G\2748\350?Pa\264\364\311\036\320?\303U;\344\244\342\356?\272\206\3767>K\324?\031\r\372\016\215=\355?D5\245q?\034\326?n\035P\r\3517\352?P>\033,\300,\245?8\201\000\304l#\345?|\341?\3305\235\347?\305?\317vc@\353?\300\223\356\333\377\021\224?0\267D\212\211\302\240?|A\033\036\022g\357?.}\373\235d&\325?N\300\375\277\\\245\335?\214T\223\371\031\267\350?\270\347\210`\"\025\323?i}\214\023n\261\356?,(?\2258c\303?\254W\354~\254\266\317?\335\032\021\020\262v\357?j\0172\315m\330\350?\240\333\244\300}\263\260?1Z\213#\225y\341?\027\301\001\356?:\341?B\244\357\333)\356\324?\2026\266\211\024U\343?ZXjy]V\325?@X\365\006\347\345\311?\032\3731\003\023\034\345?\030i\004\034\314O\304?\332\267\037\313\364\236\324?rDH>\345\300\333?F\322\005\273\2604\326?a\307\316\017\017\032\354?\220\337\3375s\\\332?6FV4\213Z\334?\232\020\240\237\002\223\352?\352\234\3506\317f\346?[\r\331?F\017\364\016\223\204\332?\270w\274K\023+\311?v\033\252\315I\367\356?\265\310\004\000-\372\356?R\034i\204\333[\331?p<\212K\262\367\267?p\220\262\372\242\312\307?\334S&n\244\333\321?\3351\0371Ox\354?\207\200\302\377,I\344?\220\205\020\222\026P\323?\010\305M\244\320<\340?\022\274;&o\370\327?\001\211\371\216n5\357?\340\307\243\242\025w\355?\310aBY\242\313\312?\240\224\236o\364\365\313?X\362^\360\307\203\312?\266\257
\267 I\332\353?\200\216\027?r=\341?\262\360m-c\223\350?\200\303%9\3348\307? \025\234\031M\353\342?v\300\253\245\240\366\332?\300\024\2062\036\305\250?\210\033.\r\356\\\316?\256P\034\200I\005\347?N\2425\353\275R\336?\343\315\205\257\023)\357?,9v\311N\020\313?\276\370\376\026\333\377\346?\037e\t\025\351\232\357?\267\310\240\177k\342\347?Fh\207C\033\220\354?\230\272\261\376y\013\306?\373\260\314\037\212\002\347?\352\226\027\244|\210\347?\246\274\0067V\243\320?\230\347\0202\333f\326?\002o\227\261\341\002\345?\320\023\242\212J\342\330?Bg\327\231\327\256\352?\336v\213M\'%\333?\332\302\225\341T\226\340?s\346Y\023#\006\353?\260\320\302=\225\252\241?p\307F[P\356\336?\240\265%\260\231\325\251?]t\234e\343m\353?\300o\377\035\367\364\301?\236k\243f\347d\357?|\347K\0163\304\347?\351\250\301f\3240\355?U\016\213]i\257\345?`_0~Ck\262?\242s\2058\201\314\345?\3542\"d\232\215\342?0\262\'\362#\362\321?\230\251&:G1\323?\216\341\237\242E\031\346?(\330\242\207\341\325\305?\324+w\335\326\\\336?*s\302\323\002\307\353?\343a\356\202Gx\347?\334\037k\324\366\375\304?\207\303E\217|I\341?\022\014\026\033\252\260\332?\360\333ED)\306\270?G\r\020\204\372\200\357?\206\257\207\0211p\345?B\237\032\330Fw\345?\200\250\025\276\361%\334? 
\205`\031|R\320?0\323\252\256\217\025\272?\235EM\232\244\203\356?\000\240\200\221\210\200W?%\313Dpa\215\340?\245V*M\344\367\354?d\352\306\000\336\373\312?P\\\353At;\340?\"}\243\371\232=\346?$=\311\307sq\334?\016cJ\314K\235\357?y\177v\261%\313\343?\373\336\305\372\223|\356?\"3\256\034MD\336?B\351\366\202)\257\357?\020GK\324&\021\276?\250\355<\313\2239\307?(\235\350y\347>\301?\010SGG\223\306\334?XK-\024\234\276\344?v(\203\263\311I\334?\310;\301\314\303\226\322?\210c\344\302\356L\333?\205\n@:\216G\345?\034\274\230\003\025\302\354?\327 q\214@\006\350?\315\311\255\246\3342\344?\030\364\324\363W\231\355?\304\225\352\236v\270\305?pj\006J\360\005\246?V0p\242(\'\324?\032]k\360\211p\323?>\001\373\242\201\357\325?\255R\257\317\377\260\340?\272s\272\326\353H\344?\244\213\365\253q\242\340?\304%\336_\017\303\315?\320R\034\362\376=\250?e\035\001M\244j\343?\250\017\007\017\347\326\310?\362\307\303\005\231m\347?p\016dA\035\367\266?\364\375_FT|\340?Tgq;?\"\334?\270zV1\220\177\356?\214\026\245X\214\216\320?\250\026\227\215 l\317?f\312\006\345O\342\343?\220\306\032\245r\221\351?\240\207\267\004{|\353?EbJ\367\022\232\352?c\223\346\317\037\'\343?\270\356\240\352\221e\315?\372\367\366\356~\227\324?\341\240\253 
\277@\343?\362$B;\006\361\322?H\300\263Y$\200\323?\n\324<\301\006!\325?T\367a\005mD\306?\233\351\031\327\237\276\354?\000\216\000\377\341bs?\nv\250V\246\216\344?$\323\273\377\027m\334?\270\275\317:\315H\333?\004\342\002|?\256\305?\004P%\337\251S\302?y{+\030-\253\351?\256ax\321j\005\342?\340y\367\n\351:\333?\034?\245\006\221!\332?\370\331}q&\007\325?\240vW\316\214g\304?l\032\323\340\315[\304?[+\342\343k\226\356?\200\373\303X\345\010\261?(\361t\312\242S\353?(\236\376\033g\363\336?U\270\324IY\375\343?X#6\260\323\325\270?xy\350\302k9\307?\276\036wi\212\342\347?\214\233O\004\373;\343?f\301\255B\243\311\321?\034\3333\240\027\364\326?\212d\203\367v\006\345?\352\2416c\346\234\340?T\036\016\207\325&\353?(\216\023\016L\031\320?\362\330\001\016\213\236\326?\250\255\350\036c\304\305?\007\225\271\313\373 \355?u\203\326@E\035\343?|\027\343\324t\n\322?hB\265\205\253\226\302?\214\364\032xN\257\340?f\310\320;\255i\330?\3305c4>\026\346?H \330\242\223)\262?\330 m\351%\000\264?fww\352\225p\344?\347\334\336\376\314\221\347?\030\371\265\001\0349\306?G\357\260\351\374\003\353?\312cI\326\241\311\341?PU\321\264\235w\271?\340\215\002\030\301\336\274?\306:N\307G\233\326?\366\330\276/R\324\333?\226\2377YZ\260\347?P\"rf\t\277\267?8\021\364\004\350\341\340?0_T\322hH\251?\\\345\207J\006\271\324?j,PH\024r\342?.\214\371\033\300n\356?J\335\030\351\2741\351?\344\212\314\376j\377\355?\'\003\013\330\035z\343?\324;\356\202\366\240\341?\006P\314\t\210\344\350?\202\340\220m\260\241\346?%\241%{&\243\340?0\206\305- +\317? 
\260 ~\313Y\253?\236l\223\272\276\017\331?\237\007\262\312\235K\344?\254\260N\233\213\030\300?\343Z;\023\353\201\355?\324#\275\001\377\361\321?^S\034\020\032\272\343?\260\001\204\301\004U\242?\342{\236\035o6\320?\000j%\375\257\310`?\264\032\003\333\017x\344?8Q>\352\366m\327?0KD\273\023 \252?\000\014DE\265\251\274?\030\013\320\210?m\307?\330j\300\307O\243\342?h\306B\321\363d\301?\334\223,&\022\277\357?\260)).wV\352?%m\372\001-\235\347?\353\006\303\323\321C\345?\250\352\373|#A\355?\360r\345\307\356\247\262?\250J\271\350\246,\353?x4u\246\0162\357?\255\022\364b\311\n\355?J\233\371\217\316o\345?\235fZ\003\270\323\353?\365GVR\335\223\354?w0\320\204+T\356?xE~\330\312 \305?[\345#<\2407\353?\270\204\002\225/\230\345?\036\215\317r\320\013\342?U\262\367B\200\210\342?EUF0\031\322\353?\335\225\350`\252\270\356?\256\267>\360\317x\321?@\236\251\273_W\326?\222\210p\262i4\325?\013\353\265\226@\003\340?\340\262\260\350\363\222\272?\250\272XO\0322\343?\035\240\320W\021\236\353?p\036\335\2329\320\356?\200\374\033<^\004\223?\354\360\256*\204\213\330?\"\223\250\271W\312\330?\374\350Cu\377\356\327?\212[hv\313\375\322?\3450\213\366\321\312\357?\376\320\315<\254.\340?\320\332M\030\2357\322?\216Q\304h\234\360\346?\240}p\3433\316\314?%J\262\202\251\363\344?^\232\251\326\372m\322?Z\341\345^w\221\333?P\235e\007\nn\242?\344#\206\333\243\227\345?\213\317\177\366\302\356\357?4\236\305\022\207\276\311?\320\352Vi\346\215\242?\270?\236\002b/\313?\031\272\202\231\366\333\355?\223\346{m\342\261\354?\245\021\322nn\225\345?\260\3557\006\0008\356?\332\362\363\")\320\344?\236O\037\361:Q\350?\260\257\321\276!\260\262?@\020B\216\351c\302?P\233c^k\212\251?h\270t\254\324V\317?\272\\\342\257|\375\321?\351\'W,\323\204\340?B\253\005\305\236\023\340?PC|\264!\242\252?\271\262\336\247]~\351?\036b\350\262\244\353\320?\332\254\244P\315w\345?\260a\032\n.\234\256?H\236tv\340+\305?Fx\t\213\031\233\352?\2637Y\241\303\214\352?fhT!.\360\325?\177C3\345\003\277\350?\253z8\024\226\301\355?\325$\234\002[\274\342?\211\274\005\00
5\014\354\351?\033/\210\014\300\355\340?i\004\024\206\265\314\355?V\351\335\335Y\000\325?\363=\212\030/\246\350?$\323;\230q]\343?\203\345i\345\220k\346?\332\256\220eF\365\344?\232\312\361\215\372\205\321?b\374\301\235\242\256\333?~\010J\206\006\361\336?\272\354\353\260\311:\340?r<\242B\261\260\356?L\2257\250\223\005\305?\254\322\261\200\244\233\323?\014\307\321\365\014\370\306?\346\316\025\337-\246\353?9E*\"{\304\345?\222,\033I\006\254\320?\r)\225\327FS\350?\035\006M\332:\032\355?\036\223)IV\342\346?\r\232|XA`\350?7\177\212E\307\177\347?\242M\020!\'\220\345?\242\004wu\303\027\345?\037\3232Q,X\355?4\020[g\315[\354?%\251\276\353m\241\343?\242t\227\030\357^\321?\234Tp\374\263\233\351?\010z\216S\320\203\345?\310F\367tG\370\265?\265\020\376\005\n\244\347?\226\017\022\275\032[\330?H\373m\010\332\010\310?`W\263\353\014\001\336?\021\177a\271\024\226\352?\300A\177*\\u\345?RX\240\275Fg\326?\217B\330\003\240\257\357?\202$\261e9\213\320?*\362\034k\374\256\330?\300m\330\004\034\227\225?\020\204A\032\023P\316?\016\331\304\016\3161\335?\034D\020T\335\240\337?\356\312\347E\237\\\343?t\327#K\017l\345?\t\224\035\215}\231\345?@\274\217\377>\377\303?0Sa\014,W\253?\270\302\301N\305\205\354?\210\205f\243\014\343\304?`T>\227\206\247\252?\0207?\333*\231\325?\014\344\302\265\302\304\352?\350:\214\230\330\211\322?\212$\255>W\017\353?>\025z\240\t1\346?\036\234\023\314\273;\331?i\271@\203\234\275\342?\214\364\357Y`M\330?f=,\224a,\346?\020\255\272\020y#\267?\220R\316\340,\326\273?\226\307\"\273p[\351?Z\250\312\'M\004\327?p\311\014\241n\241\250?\'\211\337e\005\377\346?\305Fe\244eq\344?* \256$\333<\347?W\2156=\3050\352?\246F\262\234\003\256\324?h\362\n\351\030\031\341?`\346jwOn\313?\270\257\274\366\220\310\263?\214\361\037\343H*\300?v\'\234\300jX\332?\000\304u!\332\033[?\220\'\353\024\035\377\324?\206\377\010\273>f\322?P\333\322F-\310\276?h2\254b=8\274?\365\345;\365\233\037\341?\376\306\361\tgS\327?\016\224-\263nz\333?\304\035\274\013k\274\337?\334\246:\213\360 
\305?\300\345\360\256Q\246\325?\261\322\034\377}\324\345?\350\362J\370W\022\340?\240v:\030AB\314?\261\3553\3017)\344?P\274\226\302\260\334\321?Pkd\014o\"\262?\240;\014sq\022\261?\216+\231\360\214[\321? \276\2562\300O\246?(\014\201\252\367\354\336?\204\261\276\201 \'\352?\311\270\312&N[\351?k,\207}\033F\345?\203\3645\2576\362\345?\341ET\276\220\036\343?\224r\346\322\241\\\337?h8\320R/\204\344?\360a\373H@\205\301?\304-\r\360\320\205\327?9\327<.\205\250\346?\240j~\270\201\236\342?\360\n\257\200\257\027\352?a\305\234y\361\277\355?p\221\324\353\274\304\240?\225\264W_+\003\347?\200C\333\254\0340\272?\323Q\307\342\352(\343?\216\000\230<)\342\323?\236>\335\213\007\204\346?\340\006N\362\231\320\260?(.)\013\364\215\333?\200;U\251f\272\273?`\201\202\353\333t\310?\020(\374\361\350\266\326?\\=\204\3335\234\344?\212\206j\320\030%\351?\024\351k\300\322H\356?\000\264N\271 <\326?v\315\247\354q\352\350? g\372\353Lx\222?\366\211\347IE\031\334?\300\246N\3745\255\307?\272\321%)\354\201\355?<\263\221#\256L\322?\304i*\000B\037\310?\262\037\253\217\027\307\357?ty\240x}\002\351?\210\334:\335\364\034\326?\034\216B\321\346\371\323?\250\347\375\"\032\254\335?\375?\205\027\373\300\342?\220g$\"\243\212\276?\245\204\006m\"\244\340?h&\r{\222\221\353?\253\273\356P\331\364\355?\010\004\001\363\263\026\332?\371\026\026jn\367\345?\000\315@>\365H\271?H\302\323\373/\313\357?\200\352\206|8\330\341?\216+0\253z\244\345?\000$\021`\2633\355?--\204\\7\225\350?\000\346@\'\020.\347?\0301p\004\227\244\327?\032\365\315R\265\316\326?\317\363\265\261\350\263\357?\250dk>\216\255\356?]\373\n\r m\356?\030\261 
\230H&\262?\245\317\\x\257\205\340?@0\2447\320q\330?\025\220\237\373\210c\344?H\364\3228\260#\271?\326w\213\305\271\305\342?0\306\r\323\222\223\253?\370\004\n\325\377B\307?\251%\007\247\260\275\357?\004\r\245\273\236y\331?\366\222\021\031\375;\334?|\314\345\313\335\201\323?h\320zg\272\273\313?X\030\276G\224\262\352?\237\177E\356\230\230\353?H,\272\266\211\357\323?\300\006\214\247?\017\342?\354R`\374\315\007\300?(3\010]\277\003\311?,\204\321\320\004X\300?j\376p\367ip\352?8xD/\274\027\327? \036\317\007\374\270\234?^1\262\212\244\207\333?@\256E\035\322t\237?\314G\360z\230\257\304?\256\222$\273\354/\354?f\365^\001\310\033\331?\370K\265$\t\006\341?r\226h\t\314\315\331?\220\245\253\357\244\257\315?`TQ\374\303\'\227?\330/6\244\224S\340?<\357E\252\0338\335?\274\024\353\221f6\336?\334\345\241-[N\314?OEe4\353\002\342?\206\"\304q/\267\356?<|\260\210\213\017\356?P\\\316\001\311\265\320?\340\t\360\313\235\304\302?\277cYN{l\350?Z\302\020\253[\342\357?\314\266w\265\262y\313?\337}2\346@\272\355?l\337\261HLL\312?\3400\312Z\307\367\223?\244\241\373\312\343m\352?\306\216\213\223\031\302\334?\324\037j\036\020\250\331?\256\014\211Yx~\334?L\0273VD[\315?\000\324+\004\267,\333?^\336\215I\177x\341?\260[\361\241\212\\\263? 
\270\222\245\341\320\356?cB\242\212e\334\343?\014B\3709\367\250\316?\253W \226\036U\343?\315H4\232\rB\344?\024b\307N>\032\356?\211\2230eSu\340?\233#\031/\016K\350?\000\222\311\241N\n\264?\207\350\031\342\r\246\340?\361l{j&3\341?\000\243\317.]\262\335?\266\362By/-\327?\010\214\222&;\355\310?\364\226\352\366&k\321?;\314`+(m\357?m\t\356~\332j\344?\010\223mP\0273\331?@g\253\347)\321\305?h\210=0\247V\340?w>\336~\233o\340?\374j\014\017y\371\356?\035jD\302\2743\353?\034\026\267BP\300\331?\360f3\253{\233\302?\376\347\314RN\302\320?\024\3349ju\276\341?\201\202\013\336\262\200\357?\310\3347-B\362\331?q\263\215\301\372\250\343?\303\256\323}[s\346?f\231 \321n\350\345?h\207G\202\351\023\322?:c\266\331\255k\346?\273\346\351|A\220\344?%51\202\'\221\340?\267\210\331\243\324\302\343?\\<\372.\365\346\340?\232\353K\311^o\345?\313\324\007\375\373\340\353?\303-\353\321\317*\347?\306\370\026\033%\331\336?t\221S\215iA\316?\250&$\236\006\215\317?;\251B\213Y6\357?4\016\025@\357\344\331?\177nY*\243\211\346?\333\274?\332yd\346?\275\227HYt+\342?\214\237\224\264 $\323?X]\2663\001\320\302?\020\324{8!\215\271?6\017?\207\t\241\343?\326\020\234\242\232\271\352?\340F\337\304cY\257?\2008\220a\302\202\211?xE\376cr\346\341?\322S\261E[\273\344?\204\364\326\371\350\334\346?\025\3054\376_\002\357?vk\001\036\230\225\320?~p\377\207Nc\342?\360h\200\320\2450\330?\341L\227\317|\310\351?9\365Z}M\366\351?\320P\rf\237 
\325?\266\205q\205\341\333\322?Ur\017\223\265C\355?N\177f\027K\022\340?\\O\336\266\262Y\323?\363\325H|H4\355?\237\275\013\222\304G\341?:\021D\336\3207\341?P(lyM5\262?\\\211f\241\333\330\356?\370-_\\U\266\327?\"\237K-Ax\325?$%\304\013\205>\333?)c)q\350\273\352?>7\331\313Y5\351?`\222\245o\311\322\322?,\034\306F\026\243\322?\307\n\r0\366B\340?\214\223o\275`\301\325?\250\022\000\'\n\214\310?\020\360-Pd\266\251?\00445\205<4\316?T\025\004\317G\010\353?>\352\252?\300\313\336?\262\251\222)\210G\320?XJ\360\376\253\005\327?\345\375\366\272\217\312\343?\317\24292v\340\352?z^\225F\034?\341?\030\261\233C\307:\357?\330%\222\341\340E\344?~\245\027\033`l\325?r2\206\200\221\374\345?l\033F\274\364K\336?\354\013\210\346\334\273\335?r\"\362\302\037\200\350?\030\010\000\233\354\024\302?\262\031v\016\317\010\332?P\345K_\340\276\302?@\375\032\261\250\252\273?\316\241\215\327\247\362\350?\330V\353P7]\272?\345\346\266\352\307y\355?\034\222\003\315\005\030\350?$d\231\014!X\346?$+\327m\244\342\342?R\274\337Q\035\201\346?tY}\334T\335\314?\340\217\367\223\254\373\312?pF9\036\214\271\245?P\321n\305\314u\265?\303\234\026!\275\007\341?\024\201\367\265\373\244\346?D\t\002W\020\213\324?\031\375L\337Q\361\346?\225b\267\312\362\317\342?\362\177(\316\337\357\322?:F\222\234\227\251\320?\027\205\004 \276\366\351?\340x\361r7\'\242?\376\310\002}\360G\350?\352\007\243\341\027\367\345?\374\351\360\311\320G\307?\353\305B\257\324\023\344?\242?e!l\252\320?\223H5\272\205\025\353?\342\365\025S\325\373\342?\340\245!=ag\274?\000\246-\002\246\223\\?\364\253\3625\356\017\335?8WDO\357\204\327?\034\3311\225G\006\327?\210:a\312\345\351\261? \377\320\262\360\024\230? 
m\270\177aJ\313?\224\330\361\340Ll\302?\2522\251{\327\256\345?,\373\210\317\261^\345?\333\356\313[\0058\356?\244\361fK$\240\303?\362\365t\240b\361\333?\347\\\320\211\216>\350?B\276\033\010v$\325?\001\306\257\254\022\007\356?hO\014\010\036\234\306?\362p\364\370\330\331\337?\022\261s\177&\217\336?\014\376\300\206\226\355\352?\200\264\220\275\366\300y?\2113\235\017\267\361\354?\331\314\036\014\336r\342?\360\353x:\261\327\276?0\206\025\322!\360\340?\335)\356\207\"\334\353?\\\235p\027\340>\350?\020\231\2550Zd\350?\252\264|z\002\214\341?\276\245\325\276\372\017\325?LJc3y+\302?\322\227/\344\253\255\347?\030\300@h\243\253\337?\222\007(\324c\254\332?\300\306l\342\361\371\303?\265\251\361\217\260}\344?\350\357$\t\231\254\317?\372\317\237\006w\374\327?&R\277\244~\327\320?\227\014\206\342pF\347?\3303\201q\334a\344?\213\331\250B\242O\356?\216\017\3744\033\306\327?\006\272\255\373\255\024\340?\300\221\357\202\017\306\204?$\321\344\373\"\267\312?\230\251\3478\332u\347?p\274(\327\353 \321?\314\373\305\177\024\027\327?\032\271b\224\254;\345?p?\314i\220\234\343?\316E\227\313\231?\341?\344\024\222\246\242\341\344?\000I\246s\2747\216?\242\250\371\213\361\376\345?\211&o\316\267=\346?\014\343vJ?:\326? 
\213\214kN\000\344?\212\031\177cl)\320?b\322\206R\277\037\327?Bv\007\242\215\t\343?a\022q\225\234^\342?\324\234\346\323\026R\333?\324\276:\372\307F\351?r-\230G^8\344?MMa\362#>\340?\322=>\333\375\304\321?\210\206\304&\372\320\330?tI\252\232\032R\337?\363\016K?\216K\356?\003\306J\026\264\302\342?y\254\005\213\301\260\345?\216P\227\211\266\302\352?\255+\025\324\210\034\353?\200Q\236\277\214\361\322?\206\274\343>l\214\330?\002VDY\333N\341?0*\204\031S\232\351?2rr\035q\227\347?+\276\031\317;\215\356?\222O\265\320\014\345\320?\227#\267n\024k\356?5\331l\177\0318\352?\266`\367D\3361\345?\360\216\265\351\206\374\326?P\370kL\344\251\272?\200\245\302\261\324xs?^\335\202\271m\007\324?\206r;\300\036\321\322?h\025Ma\323$\357?\221\376l\270@}\343?\366\222E\024\213[\356?\200M\2155\202\246\234?~\324U\354Zr\333?\264\204\n\336\037\033\355?\235W0+\304\245\357?\353\365\223\202\033\362\355?|S\327a;\271\323?^\024+>\006\257\357?\214\326\321h2B\314?\300\021\253\375\\^\301?\260\002y\r\325\225\240?,~\237\320\007\214\352?\013\326\"\267\0217\351?l\363\000\356\341\264\345?R\300q\223mC\331?1\204\201\2505\032\343?\002n\234|\300\021\331?l\331}\330\253<\347?O)\322=\342\033\342?\315A\324\315B(\350?\206J\361*\257L\342?\024\t\214D\235c\336?$\253\355\223{\'\324?~\n\033\375\246|\345?\003\320?\006\177@\354?Xe\353\215\236\217\354?,\367\301\221\277\312\321?\365\266\310W&\007\346?~\221\000\326\276\320\334?\310\357\003\306x\330\314?\351\332\006\334\030\376\354?\"Q\341Z|`\326?Z\310`\303u=\333?v\314K\027\350L\351?\262\340v\274\201\354\323?O\373C#\232\342\353?\236\213\221\365\234)\346?\320\240A\214s&\354?\312\200M\361\336D\351?@i2^\336\024\344?\2445[1\023n\324?\241\324\206D\2209\351?(h\nm\331\337\341?J\270y\336\031\223\324?\376\257\3714B\326\322?\027\331\236\001\242>\345?\261)\256\270d\266\342?@\212|/\344f\204?\250Z\233\316\276p\307?P\236E\251\026\026\346?@v\266\271K\360\272?\\\n\334\026\311i\353?\304\022\225w\365\002\314?\\e\234.\217\370\302?\264UNQ\005\\\325?\270\014sM\244\370\322?nh\262(\315\210\320?5\233\r\354
\036c\355?\276\330$^\364\016\331?\266A\262\254AR\335?\2361\023\312\361\235\353?UB\013\302\351\375\344? \027\260\316\236\314\223?\354ea\007HP\311?@\"\264\004#L\253?\242\373\360X\t>\355?\240\013q\032\'L\354?\245gJ\357\345\213\355?V\021\300\360\346\321\326?\265\254\301\336\323d\350?\025\266f\033\310 \342?c/\035\036C\376\342?\340\200is\274\273\270?[g4\227 (\354?\216n\277\205\035i\343?P\024\264\224\202j\314?\305i\305NM\001\351?f\211\200ol\"\351?JC\334\254\352\247\351?:[\034\374\325\237\345?\030T\200N\013\361\267?8\032$t\tj\316?\236+\"E\2734\351?@-\220\316\350\334\267?\005\201\267\222\345Z\356?\227,\202\223\r\353\356?\3165\200\357:\034\336?@M\253\363<\253\236?$\033a\270|\200\346?\31468\021\365G\301?D\244\211\234\323+\356?\177\036\033\245\232\200\353?Vr:\337\033\021\333?\332\033\234\014\254\027\350?\376\356J\320\004B\324?\334.L\022\263T\325?\270\373\223-1\273\275?(9\263\223\0333\305?\370i\312\276\346\216\321?\352\311X\260\246{\334?lf a\326i\304?\353\246\007\033\"\221\343?0\341\304\001\"E\327?BEx\350n\263\325?Z\036g\365k\355\327?\366\244n\240o\257\334?\234\021wP\345U\312?\350?LV\215\246\271?\333H\030\346\'&\356?\364\365J\302\304\034\316?\224\377~S\242\332\306?\235[\304\231d\036\355?\037S\001\255=\235\343?\270R8\034\275\367\275?\360\301Q\376:[\330?\274\246\237u\370\271\300?%]\036\334\261\263\345?h\363\341U#\213\300?\010\243\255\266\223\344\355?\214\260\'\023\r\004\301?\030Z\237(\322\255\351?\270\216\304o\017\025\263?\t\235\006y\022\203\350?jM\2368&\361\326?\214\223\035\024]9\345?\341d\330\265\276\024\353?\027$\210\314\233\244\355?\006j\3240\326\277\333?\004Y\337\212\251\362\325?\356\272Y\315\363\263\341?\243*\034\223\250\313\356?d\322\237\to\004\325?\222/\201H\362\376\344?\010\226\340>\'\336\331?\250\326\026\000#h\301?\370\204k\322&\314\343?|waB\205\233\322?F\006F7\275\227\350?\236\323\324\005\265\316\357?K\231\001<\340X\350?(\306\204F\246\005\353?\2445s\324^^\306?\200\377\204\343\346w\310?T-\362T\013,\314?\270\3600\205\327d\301?\264q\300\304\032;\345?0mW\263\277\220\260?\00
4\366\017\014\343\217\322?y\214\2033\355\r\353?T\024\035\327b\021\340?\354\361\354\031\220\237\333? \334\036\343\343\272\355?4g\"D\3047\333?\314U\025\006\016\363\345?o\246\020\255\265!\342?\336\361\314/\300R\332?\023\216\362\002O\211\356?6\242\266K\266<\351?\230_Z\315\323\361\316?U\205&8\267A\351?\346\301\027g#h\326?\2428%\204\331\351\356?\232\227\371\301>\354\337?\240&]\021\276w\235?\354\320w\274\213\237\353?\230\t\020d^\003\352?lE\n#()\337?L\027\r_\356\300\354?\214.\2348g\001\306?\202e\034\013\271\327\325?\033\355\227\014|\373\342??\243\356\252\347\364\356?\226\3073\300\217.\353?|E\330\350\005\205\352?@\032\266\325\036\270\357?\373\332\372Co(\342?\n\025\340\202\315\017\340?>\315\320\343\002\262\342?\200+@q\ty\312?&<\357j\320~\351?\304r&k\025\360\336?`\314\333\207\333\270\331?L-\247 \263\361\355?\262\240\013Ap\211\337?\371@\210&\253\244\354?p\210\346Q\255!\312?b\275\255\332#B\342?\254\2378\252\330B\200\330?,\245\215<\n5\302?\000S\232\240\206\217\234?\3202\220:w\363\354?\276N\236\014\206\023\320?\302\235f\303\304\250\324?\340\273\272\314\3169\320?\020T\037\337\376\375\310?\210\345\233\305I\260\352?\005HFC\264\213\352?\262r\345\234\277;\321?-k&A\027\332\346?\244 L4\365\250\342?u\223\3160\356\004\355?\243\033\272vBr\341?\230F\0068\203\027\326?\200#[\330\233\307\333?\260\233\241\355\374\344\256?\000\224\t\247\271\345F?\000\026\231\200d\231\237?\340\254\236\246\262\023\262?@\213\216\255Q\024\346?\354\004\347/t6\316?\206\263\310%\305\323\356?|(\237\200K\004\342?J\325%\2745Z\323?\352\272m\277\203\t\330?\n\315\332\3045\271\342?\213G[,X\327\342?%t\354B\005\257\356?\244\243+\340S\236\350?\2534)\233\260\035\351?\365G\"\030\202D\342?\200>\310\014g-\230?\340\206\334\353\251\261\231?H\t\000\262\307\271\260?P\370\000@\033\277\301?A\214J0B\271\347?\024Wa\245\027F\346?\200k\350\212\240?\204?\324\312\343,A~\301?\016 
q\320\2708\355?\224\2510N\200\254\326?\014\302\306\376E\004\311?W\2642o3\216\344?\200\207|\3054\302\226?\262\221\031\314\002-\354?\030\232\333\237\026\312\351?R\307m\002\2259\320?\200\251\374s=\203\323?\242\347]\036\354E\335?6t\256\235P\256\332?\000|\203A\304\340j?\246\330-\024\364\376\340?\342r\231\307\346\326\350?\262\333\223\351\374\361\353?\212\374\342LF^\337?_\325\267jH\275\351?\274\225f8\211\230\313?\321\330\013~\177[\343?\310\2020\211\033\244\261?K\263\365:\214\307\340?\340\'VOT\\\334?F\312!\200\351 \323?\034\352}\020\026\360\321?\313\300\177\375\336>\342?z\305\210\3712\242\334?P\361\"\334\304\036\267?\3244\372\321.\374\313?\214\230\376\017{\022\300?\3269\2071\334\336\342?ed\215\277\362\336\344?`\310\206F\240\342\235?EQI\'[\177\347?@\021\246\362\235\350\277?\000\021\225\3565\016h?h\354N\017\311q\312?\016\000N\032\241\034\350?4\020*|\322\264\320?\302\267:R\244N\326?\236J\023HK\377\341?\210}\256\257\227Z\305?\360Q\367\177&\235\307?\200I\177h$\320\214?e\014\267\303\220\035\350?|&\377\332\253\337\344?^\314\017\331e\265\354?\320G\201\327\n\363\245?l\200\020\027K\346\307?st(\0060J\354?\200Q\243\332\231o\355?\312\213\263\317\223l\335?R\3378L\347\263\355?\036i\305\334\310~\355?\222\274\273Wy\301\346?\210\217\303b\275\324\304?\365\347\020\371=\306\355?\0201\262\340\356\353\246?\201\3300\350\001\313\340?\316\350dap\024\340?\273\270\037X\263}\342?\030R\"\354u\n\336?\262\365\331\223I\215\341?\234@\226 
bf\334?\342\362\031\367\"\232\352?P[\236D&Z\354?\347Cy\340\0347\344?\250\3774/m\335\263?\234\322\232\224\237\321\312?Z\\\034J\372\240\322?\034\306S\217y\346\347?-\251\260\364*\353\352?\273j\304^\214\351\353?\211Lng\277%\344?dH\355}\371&\320?\300\330\321z\363\236\316?\252mK\2049\365\351?`\226\277U\016M\273?5\375WX\301\357\340?\360\030\337*\214\003\255?\3127\202\356\357\212\334?A\361@\360\032\262\344?\301\312\037x\016\323\356?\200\324\317+\342\271\265?-\221\203\306\030\266\353?\324\035]N7&\312?\206Y\316\273\250\247\323?\275n\254\236X\213\350?8\263\326%z\356\266?\353\273\276\374X\026\344?{\000\257-\366\276\352?\337-\257\240\026.\352?\030\177z\273jT\266?\270\0133\256}v\306?:\010\225?\t\206\347?\242\206\260\272\220&\333?\020\234\250\256\016\227\311?\235.\333\003\254\263\342?\330]\\\177\263K\307?\230\003J\247\340\343\355?\007O\033\226\261P\356?\005\314\346=!*\343?,St\002\344\374\301?<\321@?M\347\356?\010\364]O\311C\340?\210Q\353\344\361\300\301?!\017\'\276\303\345\352?\242\016\217=6\272\327?\022\0373A\022\235\345?6O\253\017bj\321?V]\2574\226A\322?\257\027\211o}\336\340?\376\014\n{i$\330?\220r\001\253LJ\307?`\203\247r\237W\267?\216\005\253\366u\327\353?\320\023\204\300O\305\267?\324\347{?$M\337?V;E\256\025N\350?\210\014\300\330\361{\307?\355J\336fmm\344?\201\332V\034c\352\356?\306)\3256\265K\342?\010M\377b\340\372\300?\220\264vH\300\003\341?\\\367Y+A\n\315?J{\232u^\034\322?\375\247g\337\031M\347?x+\337\274;\240\302?\202\202l\3517\225\337?\307i3\361\320\314\355?0\265\2227\264\275\306?\320k\323xNg\340?\241\377 \334\276\335\352?\264\017\274<\266\000\301?`\314\321\001\177|\240?\250\227\225\207\024\r\264?\204 
2\272\200V\326?h\364.{\272\027\263?x%K0s\303\266?8$\024\357\264\307\320?\376~\247%\351Q\330?\352:F\263k\262\323?`ok\262\227\316\264?\2579\305\2036\217\343?\034\226\205\303\205c\314?Z+\351U5\016\342?\032\013mUR\302\321?#\226\200\020\365\335\355?Yq\026\331\310V\342?\254\215\026d\'\330\355?\340\300\234\210\250\312\235?\213\202\315\010.\276\353?\212\002\351\363\025\327\324?2.\212\026\375v\335?\254\254\270\364\241:\357?\000\317y\020\033\010\317?E\301vz\360\223\351?^2J\240\203\325\346?g\007\266\207\367]\341?\214\263G\252E?\304?\2022i\274%Z\326?`O\031zy\022\244?\202\331\216\356m^\335?T\337\261\017\305~\324?\030AO\377\255\244\334?\300\303\237s \350\276?\316k\250\275<\224\321?\232\033|c\312\321\346?\265\351\231]\014H\340?\010\250\320\257\361\366\334?\264\0226\032\216k\316?\3109\335\2642W\303?\357p\373[u*\344?\350P\341~<\277\313?\000\005\317\333*\341\225?\332\224&_\322\321\343?\232\376t\345\374\234\326?\232\006\352\21093\335?\353\177\215\206k\342\346?\3600\005\221\302\302\304?B\344\240u\257\367\355?\314*\275()\233\324?\020\024\2560_k\347?Y;\370\263\362\355\351?\034]\271\252Ek\315?\320\322$2\250;\322?\3201\2477\211;\242?\365\016\363\245\362e\344?\311\352`r/I\345?\005WB\340\004Y\356?\370\000\223\354\322\364\305?\322\001\210\nD`\321?lo\222\\=\266\322?\374\372\216\024\343\305\322?(~\247\214\223U\326?\rA\321\222\315;\352?\325Z\235\305\223*\345?t\260G\316\372%\305?H[=\2449b\344? 
\341\276w4E\240?\226\251oP\376\242\347?\234;\035\333\277\314\341?\215\034b\255\301G\341?:\325\371\340GQ\344?\340\274\035\2558\275\226?\005\275\242\376!\243\354?v%\345\315N)\322?f\232\264niY\330?x\207\222\224\313\231\334?\362j\261\257\025\243\342?e^\214\332\217\266\353?\300\210\006O\027<\343?\202#\203vV\314\336?\260\262\256\274h\205\262?D6\305\032\345\346\344?\000\006yS\302\200l?\252Q\231K\260\211\341?\036\2679\257\336s\334?\260y\242\314\003X\342?\344_\304\005\201Y\303?\202\206\230y\212X\354?G\324*6\310\214\347?Q\276\343-Z\033\352?\300\273\001\022\205\224\210?\030\210 Ud\000\277?\310\223Hl\035x\326?h-\264\224eU\346?p\301\363\273\266\203\333?4\r\037\277s\354\353?\200\223.\003\031\341\351?\025\\\221\n08\350?q\242\3409\343;\342?\271!\366\367\235$\347?\315\227\027I\347\375\341?\316N\030\245`\016\342?\004\000\334\331\376\234\355?\363\"\273\212\232\000\353?\320u\003T\302\356\336?\243\213tM\034n\357?\2405\304p\215\260\260?\201\036\334.\346(\350?\235\223\302\333\346\\\351?x\336\226\256z\234\320?,\240\024\006\006O\314?\367!\254KL\210\350?\010\351\323\227x\026\302?\303\304\361<\240H\352?6\\\323GH2\337?h\227&u:r\272?%\301\346}\230#\354?P3\243\235\251F\321?\355\"\355/\240\342\342?\217\255\373\233\245\366\341?\240X\241\357\032\344\224?\005\373\230yi\257\356?`,\177\210P\320\325?\340\337~\014\252l\347?j\201\307\255\262}\337?O\222i%\267`\352?\200\200\234\361S\305\224?\315\201\304n\362\331\355?D\240\251\023\214`\321?\363\034\335\360\310\025\346?\026\217\346C\361\010\356?D7\263\371!k\302?L\177\267\371\2755\325?\346\014\303C&\242\353?-\273A\343)\306\345?\364\020\221\002\216\002\321?\346\323\223\224\364X\347?\327,\025j\224\202\344?EG\214-\376\334\346?\ti\350\252\273\001\353?4\032J\217S\201\357?\342\244\363To\"\347?;\021\263\312\312\315\356?\244\004\327H\257\031\345?\340\214\002\246\352\324\317?-\215\037^8\000\356?\327&\216\020?\226\352?\351\002.\031Zx\341?\377\345\3641\020\270\354?\217\3230\001+K\346?`\3467\266e\211\222?\240y\017\022\315\370\327?\010J\022\001o\374\353?P\034\264\307G7\32
7?\037\334\340w\3029\347?\320\334\223\211\363\031\355?\310\177\007\030\007\351\301?\350Yt)*\027\266?5\242\310\370\250\254\341?\272\212j5f\023\327?\300\227\305\321\231\314\240?\215\325\361{\004\t\355?\264\200I\230\n\266\351?\340f\207\200\027\\\353?`\236\276\215\022\320\325?\254v\033H\212\320\321? H\316&1\\\267?9\314\220\201C\246\357?\3407vX]\222\232?\304G\020\241e#\344?\302\227&\237\261f\353?PU0\375\007\343\246?la\237YU\020\356?3\201\337N\200\320\356?\206\224\306gy\334\335?\020R\334\313\233\006\321?\327\006x8X*\347?\375\372\315\205\235N\340?\330\260\006\357\356f\353?sT\265\367\332\010\340?\202\366s\"Z=\326?\322AyH4M\337?\313J\374\271\256\316\346?\374\200az\303\207\322?\231\355n\365\000\303\352?x\\\226\003C\200\310?\353f\005\254u2\341?\006\1772;V\262\341?\214X\221\370\r\010\310?o\254\002\227@\326\343?\257\344`f\021\303\340?b]0\030\331,\352?J\236I\303\376\306\347?\332\006\241\311.\024\330?2\'\317B\221\315\350?\3347X\253\371\300\321?\350\237C\n\305\260\260?8\346,p\230Q\266?\222\007\200)1\276\320?h\356\222ol-\334?`\224u\313M\265\237?X\337P2\227\263\346?\010G\344\222\2037\261?\220\265\357\333X\241\355?\260\266W\343]c\276?\232z\333\236Qi\341?\354\211\005\274\020\321\315?t\211]\006\366e\305?\032\335s\376\322\275\350?\n\336e7X\364\330?\006\245e[i+\323?\364*\312Kuz\350?g\226\035h\216a\345?\330\362\200\302\340\277\347?\200\234/\366\362\016\210?\010\247Jd\034\031\265?\260\036\357\252{Q\260?\262\312\017^\222\236\323?\200{ 
\311\326\255\255?\204\374<\177\242\326\355?\000\2142\224|\350B?n\241\323\312p\313\351?\220\034\334\311W\343\344?v^\3571\301\230\323?\270\217\261\300z\245\357?\244g\306J;\336\326?\252(\225H\221\023\333?yF\352\334\227\255\347?\002\367\026\313w\203\345?\312\\\331\020\207\255\351?Z\352\351\265\247\223\356?K\202\327(^6\355?\305}\256\237\215\366\347?\267X[~\265\014\357?\310\3634\373!>\306?^mt!\034\246\322?4x\240I\202V\342?d\210\274=\360\230\350?\000\315L-\013+h?\313\367\025.7\271\342?\200F\031e\325\020x?n\375\004o\367\346\334?X9\365P\347\243\274?\376\215\305\222q2\334?\211\021\202\'7\276\346?\274$\250\301\2037\330?!^V\215\031Q\342?Q\350\237wL\253\353?F\205EDD\004\344?$\017\231\r\212(\327?\361\334#\240o\006\352?\374dG\217J7\316?^zP\031Y\355\356?w\177\0144:\366\351?\034\240Z$\372=\332?0\2149C\267\032\312?Xx\325\305Cd\343?\014NR\337\362\035\347?\242=,\201\303\216\326?F\364<\017\014\265\352?u\231\314\3206u\351?\260*\211\276\357\370\203\366\356?d&\276k\251\373\331?\025\223R`\322\375\354?n\027V\006\216\034\353?\3018\354MH\345\354?_\025\027\223\344\321\353?\352:\'\004s\240\326?\210\025\376m\010\t\332?D*\273\246\337\202\352?\200\322\2023\363\232\300?\"\266\364\366WP\333?\257t\230\367\302[\354?Q\315Z0\213*\346?\024\240\334\r\001Q\321?\272\252\224\352\277\021\341?\3006\3438\257\273\255?\352\247{\375\\|\344?x\303\257\266\212J\354?.\022\376%j\270\322?\364\313\233\310\3235\315?H\371\017\232%\213\340?\300\356!lb\337\340?my\221\312\251X\343?\024B%\356I\257\322?t\243\271\036\341\003\316?\263\342oB\374\244\355?\362B\362Kw \331? \035\276\0357p\251?ia\256\246\324\217\356?\260\367\230\365MX\256?\010#H\255\334\245\343?$\303\312\326\032\006\357?k\006\003\362\034/\355?\000P\265I\356\237\347?\356\034\017\276\347\307\344?\220\306#\227\342\032\306?4N\325\304\025\352\300?F\273k\372\033r\337?Z\235\205\013T\357?q\373\252\351\377\360\341?\"\275\206@w\264\325?\376{|x[\353\337? 
\032\371\250\353\233\257?\037\232\221\224\246\327\341?n?\363\371\257!\353?:\346\250\252\025\'\344?4\3439\177\243*\341?\3247P\2531\004\324?\230\247B\035\250\362\342?P\0223[dG\351?\021\340\300\202L\313\356?d\241O\332l\232\322?R\020)C\351\327\341?\030\266ci\230g\274?\000zL\316bp\305?\377fO \033\361\347?\366\227E\301\tp\342?DoO\222\360L\346?\034\002I\203\320\007\300?\270,\377&}q\274?\346\266\247\251%D\353?\000\300\nj\375i\022?R\214\277\267\246\353\321?\356\257rD\250t\340?\235\220\270c\216_\343?\000.\273\357\337\340\355?\224\010\215r\301c\337?cq\316\217\221j\344?\374\322\367p\245\004\346?\032\220\251YZ\304\332?@i$\277\302P\347?`b\t\300\262\326\240?\216&\032\261\226\332\336?\255\207\265\272\211!\344?\000 \032\334\026u7?H0\034\305SQ\304?\206\242\023\364\307I\341?z\236\342\246R\336\320?\037\r\344\2051N\353?\272L|\255\315\242\353?\376\0169r\314\303\356?\362;W)$\005\353?JCC\035\214\233\326?\230@\327eI(\335?\000\2105\235\241\032O?\000;\367\230\001\221\352?\204af~h\373\347?\374\240\265\332\261\257\332?\244\245O\370\034\312\335?\350\236\027@t?\317?0\374=\335\327*\262?\032\230p\000\215A\325?\346Lz\021\361\374\356?\300n\231\333M\234\224?\346\324(\017\276\246\357?\213/\312\306\363\327\344?PB\376\376\362S\323?\340\367\317\257\220a\313? \261\\]r\344\353?a\267DS\340v\351?\243X\017kK \347?/\337G\032\010<\347?\354\256\317\303<-\300?\342i\235\362\243\315\327?\356\325\327fK\310\320?\330\377\245\336\363\335\344?\364\264)\203\351\253\351?\030p\257j,\024\301?@yW\342\361\235\326?\n\275\302\365\217\270\355?\314\213\201\244\215!\327?\252&\211\213\303N\322?#x\263Tqx\345?Dfy\035\021\336\326?\3165\311\215V\365\323?\024\027\232W\207t\303?\032\017b\0260\231\340?`\356t\330\\\177\356?\206\264_c\312;\326?h\010G\256\310\321\324?\'A\013\304\321\333\355?l\243o_~\001\326?\364\265D\313\335.\335? 
\315\264.\331I\260?\320\007\224\246\241{\357?h\212\310\315\244\035\302?P\334\256\254\014M\322?P#\013\366~J\304?\370\253\275ey\335\324?l;I\241\332\340\356?\366\010K\260W\026\352?\322\215\251\002\275\210\337?t\036\244xX\253\305?XEO\217\000\034\274?\326\317\355\203i\004\321?$^\035\226\260\241\343?\310\001\246\237\276t\304?\335vP\347\265f\355?\023\007\254g\331i\345?\200\252\257\225](\342?9\276i\340\023\312\340?\212\306H\244A\232\320?\256\220\314:\330\204\334?\233@\351\365\275\304\351?n<\214~]\236\353?\240\352>\363\021\326\327?\036qR\006\260Q\322?I\327,\030\314P\353?;\221\326n\005\253\352?D8\350\251\rN\302?\300\005\241M\326\221\352?\205\232\003d\221\207\351?^E\026\340_\263\340?H\221\233\313qD\267?\370G\224/\215\240\325?\000\265\342\204\307\224\236?\000\235\261h\223\254\233?\010\353\230\014\340\023\266?\017\371\314\273)\024\341?\346\302\331\\\261s\345?\031V\245=T\256\356?\t.\330\004g<\354?\266\030\233\206\017\225\326?q\225\341\310\254\254\343?_\354G\272\034u\351?0\003%\357\220R\304?{\221\351d\253\017\355?V\2451k\316\027\353?\nM\261\345\316\271\324?\035\343\000\355\3700\346?\217\347\216;\207\266\352?_\t\347\250\024&\340?\347\211K\226`Q\353?C\376\372?M\357\347?\016\3506\217\245|\322?\260\032\347\251f\322\331?\322L*\351\253\203\335?\264\314^z\333\365\341?(v01\260\023\265?\312\262\247\rPp\336?\374\344\371\001)\016\303?g\263\331g?\037\346?b%\215\003\254%\323?*\215\256}\201;\350?\217\002\010\365p_\344?\330\350\324\213J[\304?|\337\024\371\226K\315?\314O\371T\360&\335?\024\236T\3408@\312?\264\277I\344\354\271\315?\342\352fPV9\356?d;6\035\221V\350?\237\203\370]Zt\353?\240?R\213\355\271\267?x5\354a\365\313\353?I\306\031Ud\213\351?\256\205\241`\354!\337?MM\300\221\177y\351?\024Y[\014\327\347\347?\340`\301\342|\210\252?\304^.#\266_\330?\\\330\304~v\014\322?\204p\343\361\010 \346?\337\n\"Hg\366\346? 
o\267\"qX\251?\373\270\317\323,\316\340?\330\023_\"\261\274\273?\363\345\024-\336\001\347?\022\240C\016p\362\357?L#\\u0]\341?\277\231*\244b8\340?/\227\014\213\302\341\354?,\3701]\255\362\315?\260\2615\323PA\346?4\035\240\256/\210\300?\311\243wG\"%\344?\262\001\360\263\377\220\344?\010/4\322\377\267\305?\251\037T\241)?\355?\027\241klD\034\357?\335\216\226\326\t\250\343?\304\000\235\ri\222\336?\325Zc\300in\344?\216\017rt\270\377\344?\252m\002\267\311\021\320?<\243\234\215\301k\301?e\301\277\230\312\214\355?\376\030\255Q\336\230\323?\267\220\027\216*\226\340?\343\225y\022h\"\351?\002#\366w\234.\335?n \t8\345\374\344?H\367e\307\342k\354?\034\231g\221\362O\315?\346\221\206p\312*\357?t\'\005|\371\216\332?J\275\361\200h\206\351?+\243\014\\\201\307\353?\024\201WX\362/\327?xa\213\007\362,\356?\361hx$\2146\347?\230O{\007b\362\331?9\374K\025\212\253\353?t\001\324\027\227\331\350?\320\252\255p\221\377\315?\362u\332T,l\342?26\364\375\227l\336?\200o\356C{P\302?\350\265\006\317k.\346?\274Ws\002\237r\317?\360s\016\"|\245\277?\200\326z\327\217i\237?\350\021\250\003~\022\261?\250\314\250\017\302%\330?\274\254W\025@v\347?Q\177$\236\376I\340?h\271\235\256\317j\330?^\"u\215\327I\325?I\016\313\226\"\323\341?\342\254\347\370\307\212\353?\370\214\203{\340\310\336?\274\226\021\337\250\260\336?\222\216\033\302g{\336?\026\344\227\317\004\332\331?\'H\303\375\261\274\351?\204\204\t\202\210o\314?\\\330\237ko\003\300?\252E\361\005\372n\331?\000\330\304+\2069\177?\021\364\300\017\032\222\351?\365!/\n\032k\351?\304\200\316\247\257\260\322?\316J\301\201F\024\330?\033#\251\360\327/\350?p\200\2724Q\303\246?\007[\311\332\'9\341?H[\330\255\215\317\260?N6\305 
\2330\342?\020rVx\302\"\244?`c/\246c\214\276?\360\301\204U\007{\242?\320\036\013\237\013\332\243?@\260\346\355p\301\266?\221V\263\347:\235\343?J\355\r\325\034\\\341?\301+]\311\013j\345?\327<\214\210D\214\355?\030\272=\246\246\030\340?fW\365\017\304\225\346?\2345a\250jA\343?\300\334X\"\3120\253?\2003\307T\370N\351?b\225\264\265\237\250\330?F\271\261qx\227\321?\225%\362\026#<\353?\354\'\332\236U\331\304?R\303\264\363\305\244\351?w\2174+|\200\350?\326\207\031O\311\216\354?\200Pn\311\362\005\254?=)\355k\261\243\357?\250\353\005\317\020\352\264?\217x/\370\235\364\357?\260\014\204kzw\246?\":\370\224\026j\343?\245\355R\3352\225\343?B\306\271W\255\372\325?\335|\373\tZ#\357?\344.\207v\300\337\301?\360\202\020\021=7\254?|{4@;@\312?n\31366\365#\345?0W\317%\320\337\257?\027\231\254y=\277\354?2\241\234nZz\332?\236\222u\361\013J\350?\337\306\"\003k\216\351?\030*e\202*\024\354?\000\261p\215\377\201\307?5&Z{\220\245\352?\300\320\350\002\007t\211?\306\331wG\224\237\345?\000\231\261\372G\'\270?u\263\357b\214\201\343?^\213\020Q\321\367\327?\032q\227\366\246\017\351?\360}\017\344\024\t\354?\356AV\237.\244\343?Q\261\032\212\177I\355?\210D\021\032\371\240\312?\224\256\256Rs\202\341?R\310\'\272\315\t\321?\317\210dI\274P\343?\255\212q\037\230s\357?5\010\0273M\021\341?\306\355\206\252\253.\357?\220\003\344\177\222y\261?\264\356u\023]I\323?\350\336e\355\232\213\342?^\237\257\316\370^\332?H\343\301\264\365\233\342?\200\252\217-g\243\250?nr\334\361\317\205\355?4\337H\217\252\227\301?H\330\322\300}\006\262?$\235\333\267\211\324\315?\314&|b\276\014\307?\220\013\334\274\313\337\300?U\2773\022\240\343\346?\236f\022\201.\367\353?\320^\2075\350Y\306?\000\013|\014\257\255\350?\300}\354\345\301f\260?-\361>F\032\362\347?\216\304\320\027\361\332\336?g\355\275\024\312\253\351?\000i\262HF\321\216?\346\377\231@j\335\332?\020G\037\177^t\250?\340%j\007zt\303?\207\345-(\231\204\346?,\001\347Gv\376\330?\0172\362M\267m\344?@\262I\344\364\261\223?\030hV\360n\200\333?(\271\271m\324\350\316?\3162\372U\272\301\330?
\014\237d\356vW\336?\327h\022\303\210\371\346?|\313\006\231%\276\322?\330/\2728Ip\327?\347\373k\217!\230\340?0\313\014R\340\346\273?\346\n\374d6!\340?\260}|\340\302\027\270?`\360\341\202#\365\344?\007\375Q\272\307\032\352?\n\211&\037\214\204\332?dk;\221\345\032\321?\354\326\235W%1\337?)8\370\264 \177\351?\374\233\345\320\256\224\353?A\363z\277n\227\354?\262\227\305\204\265\017\350?v\3703Sv\022\350?`\335\221\253\023\203\262?U\331]\245\234\261\350?\3541*,\247\247\335?\330\261b\376Mq\343?{\336%\237Y\000\354?\300q$\225\006r\312?X\274\312\352&\265\320?\200\230\027\351\223\251\256?x{/\006\230\207\305?\220J\252\232iy\325?0\330\210w\213\217\323? \275\016\3362(\321??f\225$M\341\354?\0331\332~\277M\351?\340\366\273\271\301$\223?@\017\231P+\362\303?\250\021\n\310\253\321\331?\024(\230\206\022M\335?\226o\032\202;T\340?\036cI\n>/\332?JK\271\257\310\014\350?\000\221\343+\'\271\250?\320\230\272\304z\247\334?\363\3172\002bk\354?\340\023\346\274\356\256\241?#E\220\341\021\021\355?K\247\363FA\362\357?\244\014\324AC\021\341?\222\271p\245\207\304\332?@\310X\377p\373\264?\201\274\361I%\340\346?o\304\257\234\\\370\344?\250U\326Yu\216\311?^\211\275\246\261\033\350?\326\373\217\306qh\347?\344\356\254\034\243\332\331?0\231d\320\256\313\320?8QU\356\340\343\314?\234B\326\273d~\351?\200l\rh\336\271\202?\024\367\246j\274\202\355?\373_\361\263\311\204\357?7X\365\233\312\016\342?\330\347\004`\301\242\333?^\035M\202x\241\321?\"kw8\303\215\342?\232\303\313\332(\006\350?.\024\332?g(5\024V<\346?\240tp\367>1\341?g[\035#\241\207\343?\273\021\370\344:I\355?\007\212\261\002\\\323\342?eF\367\205\326\234\341?\360\336\327\236\263Y\354?\036\2408$\345\275\344?ge*ki\213\357?\210\242\305=\207\240\307?\307\341\213\222\251\200\354?kr\226\0374\311\344?\225d\351\226\373\201\355?N\013\361s\"G\333?X\374)\364\211X\273?t\026\310\363N\243\350?\004v&J\267\365\320?\361\314=\255\020\244\353?\327\331JT`\216\347?\000\314\307\231\277z\332?\240\205}\254\360\005\245?\240C)s%h\311?TB\332eT\343\356?8\025\311\'H\210\264?\226\212W\
243\201q\346?\234\264`\033`\277\313?e\007\311\305\234\001\341?\016\000\240\0146\371\335?\020\305\213\206\352>\244?BB$\230<\004\356?\204\271\2740\270^\304?\014\025\272k\301\337\330?\010:\3152\3714\261?J,0mc\032\350?7&\307C\212\242\342?Fm\026\377\037a\327?o\366?\005\246\332\342?\215o\275#\270\'\344?0\3231UA\031\354?d\332\222\363\024\241\347?0\004\206 \217\234\306?\2472\260&\264\305\340?\304\002\370\351\031 \356?\207\360\1774U\364\346?\033\201p\222\310\312\340?\302\340|v\272A\334?*\272\261\204\023k\321?\007~\034\255\317&\354?\014\335\315\275\0337\325?\214\344b\372p\017\311?\016+\301\360\230\325\340?\364\254\225\3777\220\315?\274bA\020\2577\311?\340\275(\232\244p\251?\247\2679\317\020\024\355?\320\2662\346\340\224\272?\340\203\342\317\347\227\333?lo\273\375n\006\335?D\253s,#\207\351?@\034\340n\223\276\274?\231\213U\354\301\345\341?\234\201\373\244\276e\357?\3161\0034\370\352\356?\336m5\221\025\234\337?\346\324\233\367\245\202\343?\340%Kb\363\244\263?\003\321\301\020e:\355?\264Q\337\210\372\346\335?\010\333\247\351G+\341?\330\037\014\244\233U\270?\257qo,;\256\347?\212\312\272\201\307\t\321?\246\014a\203\373(\343?0&=\334W\353\251?\303@%\345\301\360\356?1[i\320d\370\347?\030\373d\032\034s\261?`OXL#{\273?N\021p\010\266\333\327?\240\270U\177\254\343\304?\236>\231r\2261\334?\200\354axo#\217?P\247/\026>[\332?\366B{\273\344\237\321?,\330\031\251\305 
\312?\002\373\010\316\227\242\326?\272\r1t\177\322\326?\250Q\366m\231/\343?X\300\007\006\246\276\267?\300\200\252\366\362\223\202?\300\200f@Vo\223?\3308:SN\004\315?\000Q\367b\337\244m?\3703\245>\014`\324?\272\203\340B\260\023\322?\275u\016\310\020\302\355?\036f\354\376\273\270\347?\353}j0*s\345?\324\261\3239\367I\343?\206\036\242\371I\221\320?\256\373\266\230\203\377\344?\300\277\365B\007\"\236?\252\216\225\001fI\350?\320\317\207\203\371#\322?\265\2572n\\\307\343?\000\217\034\325\355<\357?\332\004Q\202P(\330?L\357\031Z\334m\345?\304a\315\223[!\352?-\364/\027<0\352?\322>\343\334P]\357?\340@qR8q\227?u\3566\251\307_\347?\300\244\253\3468\204\217?\251\312\320\371\342;\344?\200,\363\306\347\341\252?Z3M\"\001\001\330?\270X\212.]1\321?\221hzr\021\323\351?\300O\006tK\024\230?\372[ _\001\207\347?\274\254gp\244d\331?\325\010\320W\366\026\353?&?Z\343\217\023\340?V>\341\323R\311\343?\226V\362\032\316\372\326?\360\257\333\371\205\235\343?\004\001-q\352\030\342?Jn\244\370\032!\354?\355@{\211\350\264\345?z/\324\321f\202\331?\022kx)\013\036\347?X.\314\325h\004\275? 
\210\366\363\222\030\254?\344\n\246\255\302\024\303?\376\375f\242\217`\354?\006\271JB\224\223\321?#\266\322\210!\036\351?0\247%<\345\350\262?\345\350\206\346\'l\353?\2172<\257\327\261\346?VP\307S\317\372\354?@t\233u\306|\273?\256\016}o\010O\347?H\307\305)\340\335\262?\324\362\245\327\"\005\352?@\316\245\013?V\345?\357\361\247\274\346\224\340?\000\316v\030,\022r?\3627\360N\254`\343?<\032\315\026\031\267\333?\036X\364\035iF\334?\330\350\356>\370\237\301?t]\242\246\017\037\307?n\261\325\243s\n\321?\340q(d\276/\337?\204W\345\007\357}\352?s\235\"\302b%\353?\300\204\224\037\244!\317?\rMS$\371f\344?X\275\323\303\356Z\351?j\025\023\254\223\235\342?n\241@\203Bh\327?a\243`\020\352\274\340?\013\313O\276\372\"\346?|5\025\200\263\214\314?\"~81No\352?\254^\277\305\3076\340?\204US\252z/\326?\270\314\235\262\305\247\306?\024\344\373\t_Q\333?\034\345\253qH\235\341?\312\257\220\021\rJ\326?0,\211\276\002\367\242?X\323\203d`\315\333?\014\035<\0376\260\347? \t\214M\021D\223?\252\372\343yg\230\335?hPx\364\276\335\301?V\211\322\261\252\270\354?|I?\374P\036\303?\030I{\207j\317\343?\203\377W\216k%\352?k\013\210\361\203\337\352?\262\227%\266\262\021\331?\362\346\\\323o\337\341?\rr\352\262t-\341?\252\273\010\rCk\333?<@\001%^A\306?\253\225\233\242\346\347\350?F*\203P\006\357\355?\246\237\321\237v\026\324?8z\2253\365e\321?ta\246\263\214\206\342?\036\037\365)\214\307\323?]\234\251]z\377\354?\322\232\305\225\225;\326?\3265\003\035=\335\336?\315\256\330\232\323F\346?Wi\207\351\240\247\354?[\262a\'\257\253\355?M\016\362\030#c\345?/-\342\263\275\241\355?_\356\2111\'k\343?$\233\177\r\024\263\341?\232oY\t\267\031\341?\302#4\343cY\344?\220\024[\220dN\341?\"\3778\213\200*\330?\201\002L\324\260X\347?v\351f\325O\'\350?\000\317\371]\331\202j?\330\341\032\243K\376\303?#\023c^\024\234\340?4\217\204_\343*\310?k\230XDU\377\352?\320H\252\302\242\333\345?hBC\364\022\373\341?\n\336\tS\266\244\346?\254\022\2310xF\347?\301a\0036\320\274\341?H\260^\001^~\314?d\2577>\302e\304?\222$\013D*\341\332?\235\301A\253-\236\35
1?\274h7\t\306L\315?q^zH8H\354?\000\252|\227\007#{?\210$\006\357\234\315\275?\020/\265\037\223\351\325?\327\234.\313\030\260\345?m\213\362l\204o\340?\242\244\241^\207P\354?\016`\3055?\343\341?b\335RE\264\314\352?i\027S\231\265D\342?\024\rcB\200\364\354?\224\202\212!\010\027\316?\340\357\007\276L\326\272?h\227l\003\351\365\303?Qu\242\032`T\353?X\223\263\334o\216\355?Q{\010m\216\013\342?U\244+\320Q\373\342?\020T\331a\321I\326?&\331\t\345\250J\357?\004\201\350\320\260\003\306?\356U\371&5A\352?\210\216\350W\206\203\312?\322\034\211\316h\371\342?8\317\320\326\201O\321?rY\336TL!\333?\267et\300\344\006\343?\0001\312\267\3177\262?\200>/\250X\265\177?\302\376it\010\220\326?\262#\375\220^\257\343?\300SK[\312\014\302?\224,\277 wh\337?1\235\223DQ\247\346?\356\026\212\373e\261\336?\314\200Q\005\031\006\350?\324\211\344*\271\016\326?\2200\241\366\225]\335?\260X\217\356\370F\342?J\337\017\224YM\357?\017\036\276\377\004\226\352?\323C\375\261\274\326\345?\230\272\227\211\363\001\333?\341z\237}m`\355?x\266\034^\246\003\305?\342\313\307\306F#\343?\331\302\222\372\244\346\351?\261\031\370\324\235k\347?\316;\347w\251\323\353?\216R\213\030Cl\351?4}\221_\203\347\312?e>w\3464I\355?\272[<,]\356\350?.\232Z\2466\342\332?\303S\201\360\372\231\340?\022\244\216\326h\033\323? 
r\206o\332\360\343?\202\"\242\364T6\324?L\356\324\017\0100\352?0 \240\324\313q\313?G\275\026\374\377E\350?h\203\260-O\207\300?\037;t7\256\307\342?.\247{\332\361\224\344?\3107\241L\365\337\275?\320\357\267?\336\277\262?:\370\334\267`\360\352?2n3\013\376d\350?\"\335n\353\360R\334?\2749\304\313gT\303?n\\\211\324O\340\356?\350G\243\010\250n\274?E\207\340Z\030\275\356?Y\300\031\332\265b\355?|\357\177[\032j\353?\020<\345fn)\265?\233\014\306#\312\242\341?&\363\"\331|\314\320?\231-6\002\370\330\351?\374\254\215-Z\372\313?\252\367\350\206\320k\335?\327\224\014|\365\255\345?&u\037q\301\t\340?]\201\001=\307Y\344?`\217\316\374\274\275\330?\210C\320\303a\375\261?\332\2023\273h\335\341?Z\246x\357R\322\324?\007\212\204,\275\316\356?\212L\375\0028:\352?\007gM\274}\027\342?\246,\222\220\235\372\323?\200o\363\205\242\232\321?DGJ~\203\340\307?\313\253\3725\344I\343?\223x\r\253\352(\353?x\271\341\205\2178\274?D\031+\225\253\340\327?rvgK\310\007\342?\361\361\227\271\241\376\350?\006\333\220\354\007F\322?\035\223\264\247kq\343?\216\013\325\272\003\354\342?\3550\216\332\274O\343?\355\023\306\355\252\343\342?\345\264$YN\360\343?\261TC\212p%\353?@\362g!^\265\302?\206\366\027\264\211r\332?\230c\\Wh\273\271?.\262>\244\216s\357?\026y\304\022\377\\\334?\220\225\016\210\037\352\352?\000\200\355QO\255\244?\300\312K\262\tH\240?\010\002\331\215\342\372\304?\340N\313/;P\233?G\350\242\320\364\\\353?%@\244c\251b\351?\204V\236\2479\271\342?\351\310U\254\263\332\347?\362\326\0363\333H\353?\000\211\206\025\244J\330?\302\360q\223\276&\341?\330\323\347\000J\203\270?9>\275(4\033\343?\321\313)\001\270\361\355?\217s\233\211\300\360\347?78\340\235\272\341\346?A\303\004\276\372@\357?\277)\361\034\256\237\342?p;\207\367f\034\267?\036\257,\370\361\220\323?\366\270\250\013a\027\331? 
\3656L\327C\304?\277\212J\032\235+\344?\230x\220Y\361_\325?\205\2222K\273\330\341?\206\016\265\214\345\305\346?9nC\215\331[\347?\320&\303\262\300\244\243?\237&\325\265\362\240\341?\340\306\360\322\220o\256?\001\321\355\217\002]\353?U\204t\026\023\233\356?>\014p\224\344&\355?\030q\'\223d\375\335?P\271\361\363\213\353?\277Bh\354\245j\344?J\243\034\364\n\322\333?7h\343\244\317\275\340?\330;\2302\351\304\306?@\312\326\362\271s\324?\230%\334Di\207\262?Xc%z*\217\314?p=\275\030XS\250?_\257./FV\345?\265\264R\234F?\341?6\250\237\223\032\242\345?\340 \224V\343\306\317?\236\315\'\016\302f\324?j{\305u\252\276\332? \025\353GS\240\262?\356Y\275\36677\332?\374N\001F\207\327\317?\220;l\2312\261\241?@\243S\265#\321\273?\257\245\267\203\264\374\342?a\347\360-\327c\357?\002\214\254g b\336?\211R\003\003\236\204\347?|\003$(\202,\312?\210\270qj\275\352\307??T~\213N\254\353?\267\020M\271\357-\350?/\307\021\356\231\024\346?\030\376\355\001\226\274\275?j/\345S\277\270\333?\306\036\004\300\212\241\331?\320\213*\351*{\265?\370\253\260}\310\364\301?_\307\266\220\366%\345?@.\014Vt&\343?\020\033\307\305\227$\243?\301w\300\034\3767\357?j\005\305\264w\215\332?\306\303dqN\265\324?T\216[>\307\007\307?\036\026\031\3350\372\336?\316\373\177\200k\243\327?\312\243|\033\200M\334?`I;\003\271\373\305?Z\345\2316:\204\332?VE\365\314B\377\357?t\225\342\324/u\317?\217\317\304\005\350R\355?XF\222\256\342z\331?\000\027h/\236\005f?NS8L\277\231\331?t:\341@#\255\331?{O\255\234t\010\343?\247R\301\327\226\274\343?\026\360\370\336\256\342\330?\002\007\264\2562\034\322?P\3466\271\343]\312?\326`\352-\320e\340?\340K\344R\345l\271?\311\377\310\321\335\034\344?\353HG\276\337\327\350?\032E9\245n\020\354?\246\274\033\023\007\\\357?\226\035k\035\252\213\351?a\320\302C}\026\340?\234ek\001\3018\322?A\221P\341,k\352?\210X\031\025\317\374\346?L\357L\316\376\203\313? 
\035\314n\247\320\306?\310V\254n\242\361\270?\000AP\355J\345q?\270\225q\326\3314\332?v\303D\242`\250\344?x]Z\371n\206\343?W\372\032\267\353\345\352?$\353B\021\200\"\352?\260\204\364\032\321\177\242?\260\322l\325y\272\326?\001\311 \267X\t\351?pH)e\010\261\334?\240\243\r;\034\377\251?\341\324f\256<\214\350?\213\007\r\367\270_\347?\232\212\300\177\354(\320?f\272M\274\347\207\342?)y\2773\232\373\351?\243:\336T\020t\342?\244W0,\276\247\311?\325V\273.\216\202\342?~\375\002\233\364\200\347?b\240\316.\2358\321?\322\021\352\342Vj\340?\3523\3412\333_\355?\210\nW\264\004X\357?\360\237\330\242\351E\315?]\342{r\341B\347?\352\007r\367\223\363\341?\200\275+[\246\232\231?h#c\031\362<\266?\352C\262\216\030\272\330?!q%\035\275\020\356?\360\244\363\363\254\247\272?\364\0013\310\216\257\304?\326\006A|\325\257\357?>\213=\325\037\354\326?\250N\306:\312\315\345?\260\311\205\303R\003\322?\010\365e6\t\230\326?pd\343<\227\322\264?\250\331V\221&\355\345?5\365F\245\363N\353?0\372\267\373\022\265\324?v\010\275\017\n\326\357? \242<\233{\367\323?YV ~#L\340?\250\371D\304\177+\266?\331\253&1v$\356?\270Q\"L\013\330\302?\317\035t\260\006\337\351?@\262\310{\261\326\212?\233-\260\341Q\372\356?@\252\242\037\215\247\242?\340@\340a\327\315\272?\000 i\263\227\3508?D/g\216\354}\310?\371\370PKD]\357?\202\237/\2605\202\346?\030\021\n\337>\306\300?Ps\350\252P\270\331?\345\204\252Y)\'\350?\200\351\033\263\363\270\273?\300n\226\016\321\240\333?\000\373)\270\230\343|?P\227{7\314H\264?\350\340q\341\256\353\261?\230\002?Y\370\340\312?\330\310)\253\266\024\312?\246\336\237\205@R\354?\224\024\244\313\372\270\345?;\217\004\314\001\323\351?2\212\325\310\032\037\327?\260\370\362\363\356\370\326?.T\276\200\320\333\347? 
9\023\202\301\216\353?,S\250\037?+\325?&\244\326\360\276\r\321?D\004\350\332\003J\320?\260\324\264y\013Q\326?Ruay+\033\332?\347\361?X-\333\342?\200F\256(;\263\303?\3007\013\373\352\313\242?\354+\t\354u\365\332?>=Hv\307\345\340?R\344\312(Zc\352?\374LI\360\r\r\310?Lp\214\202a\001\351?h\032\260\013z\243\270?\226\355\323\2129x\341?p/G\360\357d\313?mDT\023\017\247\341?\260=d\364\3472\244?\344L\037*\301\221\334?\200-\367\263t\343~?\363\256\374\205\004\364\356?6\'\240$7Z\357?\270\211\236\313p\245\340?\246\263\242\377\033\340\331?\270\327\357\317\2026\313?$u\331/ l\347?k\r\304\200\"\032\341?\262\346\333\315\267\n\354?\2440s@\235\014\335?\007\273\351\336?\342\353?\323\223\344\234\037\241\343?\226\203\205\223\002\213\324?~\210\002:\331H\354?\220\025\256a\360\272\351?\313\013-\350\307\020\345?\234\312h\311\035F\311?\006\030\030\014\204\230\320?\2247\237<\241\023\302?@6\223\306\350\224\326?\354\021\232\353\362:\320?ZI_\r\255\323\334?\312@\202#A \357?\010f\344!F \334?aM\205>\334\234\355?\276\016\301\337\361&\347?\200\351n\247\372\314\352?ahh\355\356\323\354?X-\266o\221\250\333?\240\332\036\333\317m\224?F\210\277@A\001\325?\204\234W\272\254\200\347?\336\244o\315\327\030\327?P\246\337\271\353\036\277?\364,\215w\214\221\341?~)}\320\334E\332?\240\000)\010\347\n\226?H\250\3631IJ\272?\366\037\246`\330\337\326?T)\240\303\260\345\302?`s\270\022\324<\237?\005l\264W\270&\353?B\366\214a\342G\340?\274\363[|t\276\316?\277\0102\026\323y\347?\274\333\205t\034\372\333?\300\264\374\216\300\360\241?8\200\253\3626j\264?\032\371d\360\246\237\336?\310\216_z\256m\351?\343E8R`[\354?\276so\240?\037\320?\200\017*\325\325\355\335?v\034{\024\016\245\321?L\007\024\276\273\264\340?\006\0376\264\276\366\330?X\177\222\"]w\332?\'\345\t\304\314]\352?\266j\274\245\327\306\346?\324\334i\001\210/\350?\206\347\177\254&\347\320?@W\307\322=\314\247?\300h\010\2364\324\336?\371\335\314\301\355\270\343?l\324\375\247\325\273\321?\024\264OZ\257S\304?d\310L\317\003]\312?mt\263G\t\273\344?\270\325\223\222&\260\300?\234\375\
320XH\024\301?\333\022\203\372\343N\352?\362i\017\034\035/\350?L\267r\\\226l\300?\004\r:\273\366-\316?G\267 \252\363\024\343?\3707\242\224\316\361\344?\310ba\322\353\326\316?`\341Ga\245\234\270?\200\032U\037*xu?X\245,t7z\353?\2411u\37133\354?\260\304s\257\324\376\326?\214\363\213\310\324\242\355?\224 \346\372\243\233\340?\267\333:\336\021\307\354?1r\277\"\n\260\343?\256\223\204c\272+\320?\254\214\023l\373\231\302?\000\373\377=\026(\266?/\333\'2d\237\354?\337\303\3629\231\202\346?\256\3066\367\361\241\343?\250\364\251m\205H\314?m\025}\214\037G\350?\274\226Y\243eY\301?fP\271\203\206\252\343?\312\217h\331\317\353\322?B\277\300\370\377\234\327?KcA\302\235\017\352?b\013T\324\310*\342?\334E\235\343UJ\2450p\333?L\265D\356\376;\340?\306\362\2657\0347\327?\301Fx\000!\365\347?@\030u\377f\374\226?LD\254\335c,\357?X\306\177\211[\007\310?\260r\361j\311A\300?\024\326}Y\271\025\356?*\031\\\240\336M\327?\376\356\260\013t\235\323?\232\027\235\256\017L\353?\327\2063\307\211+\357?\357\367\270\023s\304\346?\360\336\250\366\n\237\324?\014\225\363f%\220\354?\367\006 \212\003h\340?@\262t\303\3213\222?T\'j\263\365/\323?\370\314B\355\337{\267?Z\036=F}\231\327?\2148\200\306\\\010\316?\024\004.\234\205f\335?\305\220\032\\\033\350\340?\200E\216z\rQ\311?\346]\270;`\206\322?`\t\242Lb\246\344?j\364\212&\263\304\326?\361\347\273\262\030\330\340?0\213+\037P\034\311?\004\324\352\315\007r\326?\254\313\206g\356\371\313?\020T\234\021\225\206\261?\216\270\005]\267 \353?\373Z\317\202\031\033\345?\351\241\317\006\320\253\342?e\247\224\351k\254\343?\216\204n\034.v\320?\334\364\250\344\201\205\336? \330\036\227\003b\347?tEb\272\020\036\335?q3(\212\302\264\355?d\357\332\252\265C\317?w\243\245;\343{\352?p\365\004\246T\252\256?\350\351\010(J/\264?\300\206!\265?\217\235?\001\337\344\307I\341\353? 
#\274Wj\363\337?s\316\204$\022\374\342?\024\2572[c\327\355?\200v\246\360\001\225\221?P\245\032\313\210f\352?\226-)\351K\032\342?\364z\325\304\335\220\340?x\201\304(\345|\300?F\202\244,\241j\336?Pv\360 \'\204\327?\364V\307X\354\036\327?\260\301\217\007L\331\260?\356bj\221\330\365\331?>\332\271w\343l\323?\201\214?Y\223\225\354?8\255f\327\247\377\335?\343\3752\245ZH\357?\224\177@\010I\236\317?\261\260O\263Qu\345?9gVY,\364\355?\257\350\014;.\377\354?\302\313\244\3200f\322?\000\032P}\"l\304?v\340j\302\376\224\330?\321\214\374W\363\005\343?\246\210\366R\375\311\330?\300\000\315\210\233\375\323?\372\0136\342.\374\345?;\275\n\235\004N\356? \265\207Im\304\255?d\035\257P\032\261\317?^\343\251C\241\301\341?\nH\232\263\346\314\336?\027G\t\230=w\343?~w\377r\212\244\350?:\307&\314\2544\331?\250\242\254k\300\006\260?Jz\270\025\356\'\352?\344Y\247\211N\232\334?\314\304*UU\333\335?(=1\204\342\247\274?g%\345\"\013}\342?\220\251\321\364`_\270?\000\216.\362CB\225?\2778\203f\230\356\345?\300\002O\231t\340\272?r\003\345\037\241?\353?8\002\347\005y\243\304?\014\224am\240\377\320?\361\345H\004\035\303\352?\300w\027\360R\026\345?X\223\351l\214p\302?\020\006\301s7h\335?\032\275\335\305o\244\347?\257=A{\355\020\345?\337t\203\"\014\215\354?\230S\342\363\020\214\356?@\250\023><\033\344?\\\233\245\2062O\325?hl\277\350\362 
\265?8\361\000Z\2639\306?3,\364\017I$\345?\007\003+\2019\026\357?V\021\353\346\037\030\354?e\306.3\017B\341?Txd\026\267`\316?K9TP\356\270\351?.\374\312\356\321\344\353?\\\226\222\330\311\024\350?\310zHH\003\245\332?_\222\361\006\333\t\356?\372]j\250\352\256\346?-x\004:\027\032\341?~\362\271G\335\324\354?Q2/k\311\212\352?\001\244\271rb\361\350?\220]\274zC\231\317?\254\326V.\301%\331?\014\265\023\332\212\203\351?\203\025N\213\223\253\354?\234\206\361\033\312m\303?\24400\322u\223\310?_?\253\035\227/\345?\207\260\372\321\027N\351?0\30356\255a\247?P\256\317d\302\324\325?-\360\343\356\267I\357?\2024e.\316\216\322?\242\250\271\304\236R\347?\031\253\207Y\030\030\342?\343\3557\337\001\304\354?\240\275\376\037\0327\352?\260\'\374\336\357\331\260?\212\376\372~\246\330\356?B,\"\317\247\245\346?\000\201\342\354\365Sf?9\006^]V\254\353?\241\301\212\313:\371\344?\336\357\3409\250Q\335?\355M\262:\353M\340?@z<~\335\355\251?\224E\216\322\220\212\320?5\367j\357\352j\344?5\251\0103\007\336\340?\200\001_\254\362\321\232?dP\342VH\347\344?]\261\202\006\357\334\345?\346\240\336I\022G\321?r+sm\232\340\354?\rl\357\325\312\244\352?\247P\2368\203\250\343?\247\017\257klt\357?\343\367\337\035o\t\346?\000q\246\r\344\271a?pe\342\336?\346\251?2\302\334\301\312\310\352?\334\330\017uPq\304?a5\277\264\215B\353?\275\227\370\303r\030\356?\241\035/D\201\036\347?\2270\327\026\020n\340?\232\216\361A\017\343\324? 
\006y\315#\021\256?Qe\n7\007\"\353?Z/\033P\310\020\343?\214\311P\\\337\377\354?h\000\222\366]\221\346?qR\376[;\340\344?\234\217M\006\203\371\357?\200K\300?p4\224?\312\307\367\013q|\344?(ha8\005\027\264?\005N?\006G\213\344?@\177N\236\354x\276?\006\022\177\263?\026\355?@\027\335\206d\377\347?\214p_\004\216\037\357?4\332p\230\303\203\340?\2255\355J\023\004\354?\374]V&U\253\307?\272\324\022\035\013\353\347?\304\277\210\206_\277\311?\273\217z\224\231$\351?\236AZH\212\310\342?\272\307\350\247!:\324?@xB-\227\265\350?Ez\202R\001\003\356?T\034\260\033\2258\305?\000\035\255\310\221g\264?\230\346\027\344\366\351\341?\320D\033BX\203\321?`f\340\023\245\201\342? \027\234>r;\247?\tI\033]\302+\350?\3403\274\242\315\312\224?\026|8\010t\321\337?i\314cP\346\245\352?\222\301\371\025\007f\326?\261\nQqN\212\346?\374\207\255\217\216Q\335?\000\344\222$\2054\205?t\367\354\265\263\273\345?\234\234\022\013\004\177\327?\272\224q(\372\242\332?\222\334\333A\252;\351?\237A\310]\246f\346?3\010H\234\261p\341?\020oYK\212\224\336?\200\323\302{\337\341\302?\232a\352p7\331\336?\266\314a\\\234\037\322?\031\210@\210\216\t\344?\220\216\325\235\007\302\310?|\277\214\342i\207\327?\001\356\321M\372\217\345?\"UzZ2\014\320?\2326\311-\0028\351?\340\221[1\305\356\225?\331\204\247\370DS\353?\254\214[z\2325\357?\200\236\031\036\225-\213?\213\241\334\277\240\315\342?\321\0069\036\\\334\352?b\254~\310\214\254\336?\220\330\364\237Y\210\306?\364G\210\237\317\372\326?r\361\207\372\244\220\354?\334\346A04\250\317?A\343\361\332GS\346?\374\033\032\313q 
\324?\241\025\2739,}\353?\351\302)\341\343}\355?\227\206\20104\200\345?pA&\014;A\254?p\261\246\034;\345\315?\030\006\\z\227\360\354?zwqg\224\214\346?\002\372@\342\374\273\321?\256\003\235\302\360\201\336?H\261\267\256,\350\327?~\260\344\271\214\214\347?\270\346\325g\331j\302?\366\334V,V\331\337?\3605\335X/!\255?bB=\323ku\330?\275\277\304\344e9\351?\214\326\3635\374\223\310?\210\300\002A\332\371\331?\371\023\265x\247T\357?\300K-\214w\265\302?0^\233\251\023\363\302?\354\240\223\352\247\021\343?`|\262ml\317\324?\212j\345\020n\273\344?<\361\3717C&\301?\314m\253|.3\305?\250\007y\343\241\036\342?\354\316\371pm`\300?\363\200\037]~\241\354?1\265;\262\300\312\342?\000\316\0175\254\363\241?\031\0021\037\363:\343?r\235\037\302\365\036\333?J\250\330\256\320r\357?\340l+\031\314u\261?\322(u\3600p\357?\266:\312\324QZ\340?`\267\204\252\274c\223?\222\245\325\341]\234\326?b\016\221J\323e\330?\244\207\262\210\360\027\302?`y\301V\200\272\313?\234\273\'~xz\340?X&%\260O\032\325?\217\305\252dN\"\343?%X\356u\355\200\357?\311&\352\226\375\250\356?\224\253\310q3Q\320?\304]\222\022\201a\347?\214\222:w\270-\330?\234P\"\315`\006\337?K\362\317\027\371\324\342?c\263uWG\322\350?~\nY(\030\335\335?\260\027\232\220\374\323\274?\344\277\014z\310\232\327?\2704\377vE\037\262?hO\224=;\035\300?~Q\353\242O\002\332?p9\001b\222\205\320?\"\037\361M\255S\333?5K\031\377\361:\355?\353|<\242M7\340?\263\235\374\013~\016\343?Z\340^\341\230B\350?\340\261\021\212\304\270\343?\023v\226p\374\001\350?\332!\014\273\351L\335?\217\032/\025%\002\342?\330\345\013\230\336\355\263?\266/\215)\351\231\325?\177F\355\256M\005\341?\372\2070hU>\321?\017\355?Xb\244\347?\322\034\242\203Xp\322?\354d\261\363\020\312\357?\000\016\t\345\344{\246?\302\271\320\376\261\266\337?/\036\327L\343\371\341?4\"Y\246G9\312?\250\000n\013\271#\331?$\322\244\251\033&\322?\252\310\r\310\350\030\322?\362P_\322\2072\341?\336\351\0228@x\321?\240\356\306\014Zh\320?\303\342\217\352\352\207\341?@\0075\252\233\237\271?\367\311\366Y\222\033\345?(\0342>v\017\330?
*_\277\001i\204\332?h\036\244\376\367d\335?\256\254\331x:\000\350?\3149\032\317wI\354?\005\037sC\224\025\353?\266\344\270\330\014\010\321?(\025\344\030\370T\315?\334\356\352\323\006\267\316?\370\"\013My\213\265?\322\004,\005\302E\353?\212\207\220\233b\364\334?}M^\235\035\207\351?\314\252\312_\250\217\316?\262 \264\337\241\235\343?D\355J4\251^\343?p?#\275|\034\250?L\257\025\301\354\314\342?\256\330~\340@D\334?\303\353iD\334\212\356?Ru\007\251t\246\336?4\2216Zc\224\320?xm\264\224\202j\323?B#1\314\353<\353?<9\365\300\207\032\334?\234I\362\2544w\302?F\370\275\325\202\377\334?r.\203^\262c\350?\300b\314\352Z7\217?\020N\227\347\252\016\343?`\030\352\344\324?\300?i\t\311\214\246\324\353?\310\353\335\202(\\\352?\204xD\205\364\362\336?[\315H\002#M\355?~l\211[\305\222\337?\272m\"\347\\\265\330?hw!\333\225?\320?-\331\2724{\253\352?\326\210X\365\213\275\333?&\226]\214\241\227\340?9q\247m\275\023\341?\244q\325\255\276\305\334?\"\202R\201\037\233\322?\340\245GB\257\027\270?X\177\254\302\235Z\303?\366\236\\\235(\206\357?6\351\377\342\323S\336??\352\233\313\226\033\344?\210$\263\tKi\320?\243 
\206\320\303\212\342?\200\217\224\250L\266\252?4K\247\330#\277\333?\370O\303\351i[\337?\300A\246\023\324w\313?\275\354X\274\210\353\345??\343\020\320\344`\346?R\'\322\326\255\024\334?\340\332\375\031\010\264\326?\347\023;\260\213\032\351?\234\262\223\361_\231\341?R\304\343\3037\347\334?\020\303\321m\006W\252?\300p\235s\231N\277?\335C\237\206\333\244\351?\250\017\253\272\210\254\345?\257\307\224\306\306c\340?.\266\004\204\245\004\342?5\244+~\026]\341?\230\210\203\237\\\343\342?\374\3470\264\2327\305?\210\005\242A)\210\341?\355>{B\204\324\355?\272\2325\217\274\010\323?\252L\216\303sZ\340?:t6\347\037\312\351?H\315\367//\300\320?/.\'\201t~\354?\016\r>\246X\310\331?h\266W\256-\277\267?\214\275\244\237\361|\334?\342\356[Wz\'\347?\352\016\302G\224\234\325?bq\027\304\262\264\347?\210\037VF\016T\353?av\314\246\224\303\342?\"\023\026\254;\374\332?\325C\241\034Xx\343?\247\261?S\010\324\357?P\304\265\247\250\007\321?$\357_\\\201\021\347?\230zO\371}T\313?\320\r\325qcq\243?\370\210\234\362\263\277\343?\230$V\374\266\235\320?\305\262\237(:T\357?G\370\322\225\302\334\346?\364WZ\262h\326\330?0\nC\322\325K\356?Al\303z\272\016\353?\0249\230\224\017P\331?\"\362\377\235\347\221\357?K.\n\\|*\354?\267w{\236]\277\345?\200;\272\302\270\242\241?\274\263\362\274\033v\320?H\203\342\350\267F\330?\372f\276\366\025y\352?u#\001m\"\001\353?\342I\343\365r\227\344?~\033\317\007\272*\352?.&A\262\231a\352?\270\303\237\302t\257\345?\242\313u\3202\\\341?D<\353\267\345\364\317?\376{\372\361o\354\324?\022d\327\n\263\343\332?_\277\243\263S\035\353?\270h\'xs`\351?\300\026\013\365\307N\252?\240R\351\202\240\272\325?\320)\221kZ\205\343?\214\207\362G\331[\334?;\271\356\r\205\022\350?\220C\035\261\326c\271?\220\373\031a\325\273\304?V\031\354w[\357\355?\304\030c#zr\336?\000UO\304H\030\341?H\302\206\t\345U\346?\273\006\357\020d\255\346?\034\327V\357\241\306\343?\320\324\330\256\216\260\347?:\262]x\332\331\321?\210\002x\247\365\025\322?\274\221\375\301\245\240\311?\271\370\223\335 
\236\355?\273\341Qn\010\352\347?\344\366\001\035\250\275\355?\3309V\376\031\362\260?.L\005\000Z\271\340?\\\232\330\010M<\317?\031\376\205UW\370\346?\000Y\014\370\356\232\204?b\211\222S&\374\343?\327\010xs\343\n\343?\202J/H\330\310\341?\367\255\362\010\313\037\351?9\332Q#\224(\357?~\266\020\252F\224\335?2\335n\026\021H\357?\370\316\'\235\373S\304?B\304\035$\371\352\335?\260QRG\2245\254?\200U\362\261\221\372\201?\377\275t\031\302J\353?\354V\321\005\017C\354?\326\r\303\336\275\362\341?\263\201\375#\246V\340?\215\023K\036\016\350\345?\271t\217a\356%\355?\343>\327v\263p\355?%,\205[?\345\350?>M\210<\237\251\341?B\222U\376\036v\345?<\275m\230\0008\324?\200\351\321haM\233?\274\356\351rVv\340?\030\257f\332\265\330\267?\034\345\177\034\3765\303?\371\273\010\r0\272\341?\270qn\266\036H\311?n\336\005\250!\002\327?\230\020y\242\371\013\332?P\033f#Q\024\321?4\016\241a\302\"\320?\2607\271\241}%\326?\260^\321D?\234\257?\330\0010\014\242;\356?e~e\267\245>\354?\212/^\025\220/\345?\ny\013\276\217\214\347?\032\n\373g\330\250\321?H\377\202\016\025\204\323?\030\256\374\321U\370\322?\202\3205v\240\036\331?\370\026\351K\231f\314?Pr\006\030\304T\303?\022N/\315c\354\321?\332\275|\247Q\350\330?h\352\225\252\364V\267?>X\260u\313Y\355?\250\036\267\357\3450\351?\'4\210\007\257S\354?p\333\260\342\3646\326?t(\334o\375\371\306?\037,d\0131S\354?N\203\r\257\224\205\340?\204\265\177\255\2248\343?\231\251\017-V\321\356?\204\375\207\363+\033\341?\270\326w\366\234\270\345?\364g=\340\235F\327?\270\351\035E\t\271\320?2U3\210M\372\332?\220M\247\264\307\324\351?\367\333[\001\225#\341?\233\030\236<\267W\354?d\0321>\274;\314?\250\t\213\350\343\033\337?\226\247\014\243C/\327?@OM\330\277\353\205?\230\3702\366^\322\344?\020\352\262\253\260\223\302?\232\\\3403qJ\354?\370\375}\002\243\330\352?\036\0324\332\177+\326?\206\037\256z\2754\353?\342\274D\004k^\350?\230Q\351\354\361I\344?\275\006\031\236W\206\344?t\365\313;\246,\304?(\264\320n|\367\306?\306c\362\354LJ\350?\210c\353\"\247\370\261?`\300R\312\327\361\273?\031\
361\355f\303Y\340?\001%\336\363\272N\355?\350Y35/\256\337?\321\177\377x\335/\345?\216p!\2760>\326?~\336\n\343V\355\347?\213\211$\267>\270\343?\320\003`\261\037\202\355?p\357!;\242\262\313?X\035\270D\307b\312? \205ls\367\316\305?) \tx\275p\342?OnN;\026\261\355?\200\354\330\236\244\266t?\314\316E-\351W\301?oc\231\224%\264\352?\022|\356\t\332\307\341?\367\373\\$\302\242\342?\344\"j\311g\373\336?\257NW\314\315`\344?\301,|\031\317\242\345?\306\020qo\003\313\323?9\366\010\335\200w\343?bu\202?\217,\357?\206\210\266h2$\330?\260\'4!$\216\247?\374\036\245\200W\035\334?\327\267\361>Z\267\350?h1\'\006%J\264?\300\326\241\334*{\300?\347n\010]D\217\355?PZ\001\332o\225\351?\320^\361\272\232F\346?\036B\035\241\'\375\345?\260b\200\253\246\255\335?\315F#\032F\364\346?\346.\327!\243)\323?\275\305oh\036}\340?\240 [\036\235\037\250?\260\373\034\352~\014\357?\177\025\007i\305\001\355?\000}D\235\364A\237?\300\261u\035t\276\210?p*\035~\224}\272?,\201F\033\262\034\325?\336^e\237\0161\323?&\376]o6s\356?Id\204\017^\311\341?\210\251\204\200\341\240\337?\220n\201\364P\355\273?\254\233\307\023\207\006\332?\n\225D%\006\014\330?4\376A\023?6\310?\302:\343m\366\225\347?H\334\025-[\240\301?\222\321\344\217!\n\327?\244k\347w{=\337?\343\270Rfe5\341?\3112\314\352k\302\341?0\264*\237\275\274\320?P3\213Qn\237\311?\200\314H\017\246~\257?\276y\363\n\232l\356?\342F\205\265)~\342?\356\332[\373\223\231\325?\354l+I\260\323\356?`\324\243\265\177r\330?0\365\371\245\353\204\241?e\374\316\373\356k\356?d3\310jNd\321?4\261\253\312zU\336?6\207H\232\226@\345?<\213O\221_{\326?\263B\317\021\354\036\345?\nu\213A8K\322?B\354\342\010\231\003\343?\214\034\224\264\016\217\330?\240\205\333\"K\346\234?\337$#\341[\330\342?\034\317\313z\221\200\312?\355ms^\337\004\353?\204\351\014\031\207\365\304?:\356\003Z\021\303\342?\377\000V\333;\343\347?\370yNa\374P\323?\304\013\324-\353 
\311?\273z\235\016\224k\346?\252-\216\275e`\320?b\334\220\022&\325\325?\270]\025\365\264d\331?\nf@d\337c\344?\200I\374\203\245\347\252?`\360+\025m1\333?\253\320\373m\356?\220\222YQ\234\223\327?8\206(\036\223\353\320?NG\304C\346|\327?\226\004q\255\3724\330?\224j\030\2032\344\321?\013\324\225\033\260\002\347?2\3571W\314\t\326?\372\272.\230`\\\335? \031.\235\026\241\221?^\201\356\323!\332\352?h\305\343\2133\341\277?X\362h\003\345/\277?\302\036Y\201\022d\320?6\254\000-\351\240\330?1\321\355\363\247\005\357? 3\371\006\314>\344?\255\343^\247\223:\341?\350\366\232\t\307^\350?\000\243\340\354\354\031g?n\"2>\265\367\332?b\346\302\324\377\204\351?\276Si\351\177\250\326?2<\003pRH\326?s\211\303\315\367\213\354?\020M[\230\246\313\340?\350\337\306\332\347\315\341?\214\217`\233\025G\306?)\365=x\2027\342?\375iB(a\322\340?\246\306\327\346\227\267\326?VX\273f\274\216\337?p\345\300T\014P\344?\240\0077\267=\010\333??\2215\310\"g\357?\032\326\3745z\242\334?H\347\343\320\036\345\262?hco\030+\351\306?0\036\255\375\320\031\251?\006\337K:\371\023\332?\242\376p<\256v\325?\246X\021\020\267\217\354?fT\244\225S\204\355?\000\230\252\215\026.\312?\316\262\200\251\013x\320?\032\324l\234Z\002\353?U\311\331\033\2176\340?!\206&T8]\357?\246\205\000\322\304x\344?2\017\347\312\372\242\322?X\334\312S8\335\345?V\214\031\230\033\316\331?\340\274\362\213\213\021\251?\354\313\027n\232L\306?X\336{\222d\276\311?\267\331\257U=\246\340?*\031\371\317\0011\332?\310~\216\212e\205\260?V\326\330\207\034\211\333?\366\033\304\367io\335?\236\264m`\014\007\332?\266]_Gc\350\353?L\373sg\244\230\303?\372\225\307\220\305v\332?\014\364m\017j\313\301?H\313\304Cp\343\326?\321\237d[\215\t\350?\255|\276:\202\204\354?\262V[\246\354\352\332?\371\037~\317u\037\343?x\331\244^e\217\277?\2668\3250\347W\342?<\241#u/\314\354?\333\266\361%y\345\345?>4(\202\364\010\347?\360\223\274\254 
\020\271?\354\374\023\216\211h\302?@{\233\264\001\242\341?\325\020\3014\243h\345?\243\025\225\316x\266\341?B\272\312:\224/\320?\274f(\013p\370\317?=]>m\247&\347?\232gwz\367\252\337?\210\225\30133\306\346?\26294\317\347w\347?$\302\331+\030\036\335?\222\213T\312\375\"\330?\301\214\251ns\202\346?\324\207q\221\237\034\332?c\272\020\253>\244\357?\226=G\361\221>\342?\316&\t5SO\356?fh\321\205\333R\334?\200\271Q8\026\242\223?l7\263\2559\200\322?bUp\336Kz\345?\306t\214M\322\036\332?\200(\230\306\251\346\321?\344\3239-\343u\357?\3003\362\341\217)\217?\203\231q\346A\021\356?\350jK\374\307\370\343?FD\240\000L\213\323?jW\271\300\372j\341?l\353\370&\304\347\327?\271}\035x\304\235\343?\253-\001\201\221\372\352?n\035\223\023\261\270\351?\347\360\323\275H\276\345?\372\022;%$\362\343?\031\2466-$\204\346?n:\021\256yl\336?\004j\266\365\301\205\335?\034\3367\215\373\251\300?w \006\327\220?\342?.\377B\365L1\330?\004\336\031T\230@\347?F\370\244\336\216%\353?s\366k\020\317\216\345?\270\325\253\365\3450\341?\214Uv\222\305\267\345?\201\2779M\312K\354?\265\2636\347P\000\343?\000P\334\326\030\252\212?\376=\371u\211\235\335?xP\316\331\201\341\306?9\263\376\017-\002\343?\220\267\002\370\333\016\316?\350\313\250K\006\352\327?\334\017\177\341\004\315\302?\227\244\323\366\353+\343?\327.bD_\232\357?P\353\235\250\201t\277?\243t\243=j2\357?J\205\301\273w\252\333? 
\235i\2427\027\266?d\312\366\240\030\024\302?\351\211`\316\310S\340?\220\205\021\300\221\253\264?\236\316o\321\365B\332?\351e%$\201\250\342?(\007\333\021\236\251\267?\227\362\247M\247[\350?\324\362\215\361\365\331?\313\036\255DV\022\351?\014\000z\2762\367\326?I\027\240u\252\363\342?\221\344c\303\205\344\343?\314\004n\365\026\307\355?\234\020>\300\314\216\351?(\311\256\344i\017\347?\326\201\210E\0062\327?\227\226\020\\\323u\352?O\334\025\005d\322\352?\020y\311\320#g\330?\267\220d\226!\314\357?\360\260\252\313J\020\335?\340s@\357 \277\336?\030\341.\304D\312\306?\377\241k\326+\254\342?}\r\227\244e\332\340?\224$\373\013#\357\345?\006\326/\326p\031\354?\200\022G\270\"\344\234?\030\210ZDn`\313?\232\'\n\201\001\232\326?\306\201\034\237\255\216\321?\320\3468A\305\026\334?9K\032j\347\033\347?\350(\314\312V\213\340?`P1eA\344\275?\3310\177\2605\225\351?\020-\370\264b?\354?\322Y\032\354\010\270\331?j\312G\226f\364\331?\340\275\006?\017\204\332?\206\344Sj8i\342?x\260\335W\314K\321?f\365\033\232\254\315\330?\300M\327ZF!\341?\030\035&v\336\206\277?\360\274\350\354\261\306\347?\334\\\254\330\004\260\344?\213u\302%KB\346?\332Bq\300U\246\337?X6\032\353;l\351?\340N\214\203\301\301\352?\215\211\222+\020\206\341?\357\227\016H\017\304\355?\\:\2372\006\336\332?\336\231\361D\023\032\333?dB,\221\201=\326?\275\375IZ\274\247\350?t1\310V\212\332\306?\021F._ y\357?e\204\263\372$6\344?\357tE\221\223\033\350?\214\335\032\034\326\247\342?\373-\025\013{\203\351?p\0164a\250g\326?\017O?Aq\320\352?\300\247e\333\026P\262?j\277D\345\000B\325?l\n\206\302\3236\341?\002\253\334{(\222\331?\010f!4\007\341\330?\001\277\264\301\r\217\353?|\006m\006\333_\343?}R\234\347\017j\344?\030\357\177\275f\331\300?D\212\362\334\033\377\341?\007l\256r_\037\346?-\270\177 
/\001\340?.\002\253\001\365a\350?r\2251;\213\026\354?\344\336\340u\333v\327?\210\320\361\237\236p\322?:\315\224\314\217J\322?\241\360y\334{}\342?j]\275\251\253\031\340?\260\201\227\341\344\001\241?\304\301+\210~\241\320?\340\212E<\021\207\231?\243\273|\316\336\203\353?\016\000\r\014\322\006\355?^\260\331?,G\360l9\235\345?s\335#\302\210\211\354?\200\n\323\\\\\370\235?\210\033K\337\350\243\313?\206\361He\370\236\354?\206o\232i\255\231\321?\215\326X\tY\031\353?P\314-\366\201(\305?DK\022U\200\230\351?\231\327o/7\254\345?\331e\253@a\"\342?\360`\274{\240\240\351?\344\240\236\201\032B\303?|\003q\364\0330\304?\006\320x\336\257\006\321?P\240\203\276\3065\323?\235\375\275\014\261<\350?\330}\002f\374\301\267?`\273\3008\241\276\356?\344\232/\323,\214\343?2\005\\6L\330\330?\264\302\2019\336\222\323?\017\274\325\276\323\275\357?H\337\273\235/\246\342?\320\236/s\2624\302?`M\204\323z\027\347?O\325\304\250%<\347?\231m\016\001\276\035\346?\254\230\224\310\317\177\322?\027\023~\315y\367\350?X\333=\356\374\017\321?\202\345r\006\307}\357?\372\303\245vd\343\331?\322\025\331\252\013\274\323?\220\224{]\374\023\276?\034\031\272\375\256\324\334?\214\327b]\347\211\321?\322\n\340\367A\317\352?\356\r\323n\325\344\326? 
=PQU\216\220?aFv1 |\354?\330C\023\000G\336\327?\006\210zs\010(\354?d\346!\237\207m\311?\0240\316\272c\307\315?\2633\231b\025E\340?\304d\364L\246\244\342?\214\332\217\001\037\232\314?\234YG\335{\270\355?\020\312\367\031$.\242?\222\261\"\207\rj\354?]\221`\034(\\\343?,6haG\027\336?\354J\027\327\3707\310?\010\032\250-\347\006\346?l!IBFP\327?\265+&\251F)\350?\010M\263\243\264\035\261?\331\377\354\223{W\351?\000drn\204\213\230?*\232Jr-\232\320?\322\273\007\326\001\334\332?4\027J\034mW\320?\360\245A\225u\327\261?h\001\365\021\024\003\347??\374\205\340\003\242\342?\340\002\210u\266\337\327?\272\250t\352\234\343\337?7c\236\373[b\356?O\036\367Uo\261\346?\000\335i\005\"\035\211?\212\340\322\232\305l\335?\340Z\256 }\007\300?\260\230dD\215\333\261?\030\313\250\251.L\330?\366\020\010\030\234 \346?\3667\010\244\263\234\352?\250]\260\026\035\350\337?\026\367/7df\322?K\274\336-;\352\353?\247\353L\364Y}\351?\200\305\301r\217\205\212?\356\2117$\315\307\342?\000E\377\250\016J\243?\236\225\321^\236\276\333?\024\267\003\300\332\311\306?x\204\340\272\223e\336?\360T\027\177\264\330?F\330\255\251\253)\340?\256*\354Rxp\340?\314\272\362\313F\017\323?\220R\031\263\256\031\314?@\223\305\264\235\371\207?R\333\303\234\003\026\330?\2075#\324\372\273\357?B\274\210\022ob\323?\006\316e\232\317\304\346?\032\216\250\020\371P\357?bl\237\177\366\023\325?\004\252\013c\246F\306?D#\312\363r\365\342?\032\3547\004\n\027\333?4Kyf\227\337\347?\356)\271\021b0\323?r\270pF|\246\357?\026\316\352\375lN\346?\3304p\230(a\305?\\\342QTZ\215\330?\314\367\302D\337\317\305? 
\010\002\316gx\324?<0S\000\242M\305?\rK\016\347\271.\355?\035\r\261\263\227\310\356?\201\352\234H\002;\350?\0246\346$1\332\313?T\225\265\252\031i\311?}a\276#\246\321\345?0\325\300\242Z\030\335?\330\2360\\\247\312\265?^\\\257Sr\275\332?\222|&\247\310\316\352?^\032\363cS\347\343?\344|K\273\241\213\340?\254\305[\232\377\350\337?\231e\332(\347m\342?\212[\325\214\320i\333?\304\356\311\326/\252\311?\340\263\216+\253\355\316?\370\3531Q3\217\315?\213\331\227\030\276I\352?\013k{\240\340\250\354?\270^{\253\243\003\301?\001\246H`/L\357?x\210\354\344\002\345\313?\036`\271\365\'\330\353?\240\2255@gf\234?\2622\026\341E\230\331?\312\032\021D\356#\345?T\376\027\341\3337\315?\3749*n%U\302?D\373\330\027\270\205\356?ZH\264\005q\351\357?P\013\326\3709\354\357?{\221r \345i\356?8C\237k\017\353\353?\225\025\010o\246q\352?V\274\362\342\014/\334?\270\234\035\245\230j\330?\336\365G~\341C\326?\340\016\215]l1\224?4\277\265W\214P\331?\354B\277,\205\177\334?\300\222V\276\035\236\232?\325p\213bR\033\342?\216s\367\250\272\n\343?W\246\271e\\\272\346?`\234\016\257 \246\343?\354\216E^C\226\323?\224\273V\204\3728\307?\250\376\205\364:K\276?4\224:\236\206\035\302?\")\367\315\323;\332?\351v\327\264\250\035\344?04\316\201\230\270\303?$\0009@\257\317\300? 
\212\245W\227\206\270?SY\030j\335\021\344?\020\310\374\266s#\354?`\305d\003\231\267\270?\020\021\230\363\341k\347?\222\216@]0\n\324?\241?\313{\222\r\355?\340\347VN6\007\241?\273\371t\216\274_\340?_\222c\303\214\230\354?\010\007]\332\377\244\315?\200\211\257b)\177\347?\30411\375X*\352?\343\002O\032O\t\344?\270\372\027k\377\351\303?\234\003\353E\352W\326?!\304\341\210\326C\351?~\177#\017A\354\351?\270\204W\307\027{\326?\340W*\203\207\302\304?\372\274\253x\327E\323?\026\365U~{^\343?\313y\243^b/\351?hy\274\262e\340\356?o\352\364\304\247w\344?\270S\202)+ \346?(\206\227\005\302L\310?\210\262\371\216\345o\315?\200\323\300\363\265\'\240?\370\244rG\306\251\327?<\026\372\202r\355?\220,\313\000\267R\241?~\362\305\370\2124\331?\205\340\305y]4\343?@Q}\227\365\007\313?\270\004\035\255\270\'\317?\345\023\206%f\347\344?+\023\017\362\232\317\351?Z\244\372\343a,\353?\242\305+\276*z\334?Pf\"\037\302t\344?v\254\2459\251\303\324?\376\273A\n\340\234\353?\270\224?\254F\311\352?:\333\265\312\357\334\347?\270\023\234\343%\342\313?\245\004\373c\215l\357?\230\004\016I]\342\303?0\337\024\372\263#\330?N0\202>r\245\327?Z\2361\252\217(\322?\354\262c#\365\231\316?\322\033U\263\020`\337?\260\376?\273\337V\356?\260\330sd 
l\335?1\363\333\027\372\365\355?\234Wm\021\333\364\336?\215\222\010\032\024#\343?\363TK\342td\342?\353TZQ\373<\354?\346\230\014\213D\317\341?,\251\213\357\253\340\353?Xz0\225\266\224\354?\237\225-q&\236\346?\213I\367>MY\355?t\341\236\262\240\232\333?\020\004\3626!\235\306?\006Fx\264\355\360\344?\220;\307_7\274\271?\276\257\201E0x\327?\022\254\024\034\371\260\336?\256\302\310M\253\010\352?$D\251\235\317\211\300?\354DX\207\314\000\304?\234QM\304\341*\314?[\365\272\t\265C\356?s\222\006r^\304\341?r\263\363C\326\001\354?\334D\314\213E\256\317?<\253C\017+\267\316?\200tT\024\262\231\341?\233\261\257Z\323\231\353?\324\\\014\336\350c\341?\244\323\243j\340,\303?U\353\024.\237{\356?\332\220@\215d\347\342?3\030\n\356Z\244\340?\363\036>\346?\203\252A)\303\010\355?d\242\257\264\203\251\353?\006RNq\317\332\355?|\326N0\374\251\356?&\"\240\303\305\205\335?\356\352\313\202\370+\324?\240fE:\377\317\336?\310!\236T\355\354\311?\346\016\314R\341f\347?\000f\223\301\230\207\325?\370\246@-\273a\301?\375\017\037\"\373b\347?Y\250(T\031\223\345?C\035\335~E\354\356?\364\310\256\023\244t\334?\035\006EnjM\346?\314\313o\024\276q\342?\333\345B\310\262\237\356?\336\"n\013\311\211\330?\222\221\320\323\212\262\351?B\230)9\321i\352?#\223\305\347\002\303\352?\260\210\336\342O \251?\264\366\036\354O]\327?\306\303\364|\014\\\327?L\345\374\376Dh\330?\022\272,\021\300\030\344?\333\207\021\350E\241\357?l\022Q\351\001\263\353?\256\r\224*\305\253\357?\245\361\376\252\2364\351?\\%\375\277jz\355?X\261#9G}\271?\\\337\225\324\361\374\352?\332\351\275\360\272X\356?\372\354X<\314\213\334?0\r\361\233\377\263\256?\263w\237\177\302\246\353?\336\251\336\214:U\346?\234\373:\013b\320\310?j`AH\221?\326?\340S\307v\347\320\234?\220\204\200\033;\014\272?c\272A\262]\014\347?\212S\241\331\244\024\326?:4\350z>M\345?v\234q\206V\007\347?8\306s\036\276\021\344?\234\237\275\244\234\026\355?@\342o\361\301]\233?)t\326\343\211\251\340?\026\377 
`&\326\334?Q\204\214m\327\030\341?\361\303\321\231\233\304\342?\232\001I\005c.\337?\000\230\375`\3251\217?\275\364\344\315=\304\354?\2440\353\346\340\230\336?T%\277 \366\266\337?t\014\353\270\240\017\341?\370\331{N\363\010\262?\310\210\313\212\222=\271?\200\0348\203\230\016\272?jv\276\024\350\215\341?\350:\332\226\004\232\312?^,\177\206\036\376\332?5\247E\376~\365\346?\025\037\341\2045G\346?\202$\331\014{%\351?\370\2775\243-\331\324?\026\204n\306\315S\351?\324\300YZ\250\243\353?4M\355\2641\033\345?\200\370a\362=\244\276?\334J\233\221\021\'\305?\t\272;\204\004\312\343?\300\3701\263\371\243\331?\336\212\353\370\337\260\321?\312E[\320E\312\341?0\314\356\302\2419\311?\004\365\260\204>\024\317?Z\256t\177K\260\336?\250\030\357\342\352\254\272?\300\220\0322\026w\320?`\no0\r\236\265?\237\313X\0315\204\352?\262\342\250\374,+\341?)\026\233JTS\350?88\326\346I\214\311?\3461C;\310\353\347?\347\003\0338\357V\350?B\262\353\344q\371\343?\246\251}\365\337\277\337?{Z\013\265\336\237\355?K\026\026\010\306;\350?|Z\377\321H\261\356?\302^U\027|\303\327?&\261\214\031?\324\347?\247\365\214\017E\t\341?~\316\202\313o.\356?j~V\220b<\354?L.\025k\372\177\325?\232MdS\246\355\352?\301\360\371;\217\026\346?\300\003\252t\265\341\357?\370\303f\226d\245\320?\262\"#\317\320\251\343?;\001\003b9y\355?\300\242\272\027\262\203\316?Z\000e\217=\357\327?\3775\330\266\241\270\340?\322\n\343\367#+\342?\273\344@\000\240\374\344?\000\210\030\213\327!\253?\273\3337t\0245\344?\232\010\371\320\317\003\332?n\202\375\035\222\365\330?\262\252\276\204\033\240\337?\341p\325\347E\024\344?\214N\202?Q\206\334?\306\227\304r\237\370\350?\330\004\330YeE\265?\361\264i;i9\353?\006\253\t\232\302\200\356?\207N{k\241?\340?\336\226E\377\215c\342?\205\034\355?\266a\340?h,\272\240\303\344\352?\300\025)\234w]\274?\235\254\003\013\032J\345?\246\227vR\235\330\344?^-\246\2434\262\333?xi$_\212O\321?Ms\341\200t\036\350?\375\277Y\251\317c\340?\250\256\003\"!\274\354?\261\302yW\367\375\357?\350\035\207\234X\231\352?\340\347\206\202\212\344\3
54?\tb(r\302>\340?\335\331+D^\325\347?\300\215|\021\t\006\315?x\334\374(|:\325?\034\221+\250:\005\355?R\351\030pdO\335?\032\266\314\350\360a\352?\360o\335\351\016s\245?\323\221\372W\035\242\357?Rh\316\276\216X\345?}X\372\271$\327\350?\231\310\n\022\330\346\357?e>\241\035l\243\356?Xx\260\026`b\266??\205\270\016vq\341?:\331j\2622\214\327?\252\024\033\035\371%\327?\240o+\001\266\247\233?2\200\355\004\325*\352?\273\212\177Q0\356\350?\215\234W\3479\347\345?\272\005\321\",O\336?4\013\315}\017\323\323?\375\264\205\375\215\346\344?\373B\317\324\311,\352?\2309\211Q`?\342?\346\223d\206\350\210\331?q8+*X\300\353?\213S\363\274g\211\347?\262O1\"\026\324\357?\020L?\\s\214\337?$\276\227\340_\001\341?\33087\212\216\240\265?e\264\361^\335\372\354?d\346\270\233\010[\334?\247\222\246(\267\007\345?L\375\210l\001*\306?\234\241\363ld_\324?\030\311\036\363\332\002\317?4\367=\002>\000\331?\214?\002ts\234\336?\256R\330\211>>\352?\236\375j\200\346W\346?`I\032\'\'\346\352?\367ed\372\3404\352?(h\202\231\342N\336?\360\250\303zk\200\354?\252\027\016\277\354\337\345?x\371\256\227\340\333\276?\273v}\226\241\325\352?\302\202\303Y\304\241\353?+\335U+\312J\353?\345\330d\324\204\353\340?\2666c\211\215\016\345?Ot\231\253\345\300\351?>\341\353~\370\037\336?\315M\331\374\355,\347?]G\350\205\244\373\354?\362/\361\223\367\340\352?`\"\203\001\202\354\242?\243e\nS\346R\340?\201G\025I&\346\341?\224_Z\300\232\014\355?\245T\016\331\355\014\352?m\244*\3049\374\352?\262V\341\270\350\276\330?\307\315\373\027\364\333\344?3v\321\353\215b\350?\214\315`#\224b\326?\304+zw\032\372\354?\257\036_B\256q\351?\201]\3102\026\272\353?H_Z%\343\277\261?v[\224\326rp\335?@X\300X\254\320\302?\007\337\207\251m\033\341?\022%^A|S\327?\t2\364@ \014\344?\212>\313 
/o\323?\234sK)\352\200\320?r\201\234\320FE\330?(V\327\014\224g\276?\334+\370\377\3476\354?\330\350zi\235+\354?@fB\267\222\000\255?\0348\313\"\335\272\353?\202\325\224*\021\251\355?nq\314\201\237\231\323?\364\371\330\301F\341\350?+\205\022\220\"&\343?\2501\207\264\306\355\320?\026}\305\225\220\361\332?\340\370Q\274\371\356\277?\034\t\370\206\032|\340?\336\335\337\3440\345\323?\031\033ek7c\351?\200\260\237\003\314\322\247?>\262\315EJN\357?\334k\001\255\335\"\312?\320\350{{@\267\327?\260\201\327v\253s\322?\204^Z]C\222\305?vx6\320\004\005\322?\000\322\202\231W=\221?\232\306\370\221#\242\353?\226v\304u\006\027\340?D\017$R\014\275\307?P\370v\017\244-\317?\362\005\275\276\353m\344?\335\337\330\253\262\233\342?\032\361\261\211\322\210\324?\343\252\214t\',\357?b\320y\206\352~\324?d3~\276\316\334\342?8T\210a_\032\277?\330\336\230\263U7\274?r\016\303:\005\225\347?8#)+[\211\350?\356\235\244\223*H\356?\357E\333\251\2329\354?\372zQw\251Z\355?\221:\366^\212\335\343?\211V\327J\233?\341?\3029\323\032\252\004\336?\000\314\244\212\273\322\324?\320\340\235\335\235l\254?H\200\210\210T\017\317?p=\233d@6\250?0d\224\r\303\010\344?hr\350)\016=\347?\362a\026\215\317T\351?\340\317\211 \223\323\310?dM\317\351\"\273\304?\260b\351|m\177\330?z\022\212\360VQ\350?\307\025\203b.\016\352?\023\\\370n\334\245\341?\340\026\201\270\246\016\324?]\2223\233\323\347\346?f~\366\r\223^\324?\254\231xm\017\363\331?_^\217\253\276\235\351?`\220\3758\300\277\220?\nU\325\\\204\335\320?\344F\005\265\364%\316?\3200r(\021\267\260?\200\266\321\366\347\036\340? 
\352\020F\334\232\260?\330z\023\234\371j\273?\322M\021\242\266\006\334?e\322;;6\210\353?\220Q\211\372t\013\324?>\375\274\373\3074\342?~*\262\003g\310\347?.yx\364>\241\341?\220\311h\214\300\337\265?\306`\177S\261\206\346?\320\304g\242\325\036\324?\024\244\030j\252\266\343?o\265\277\323v\317\347?\037\036J\250h\270\345?\240\025\307\357ay\312?h4T&`u\314?^\325\336yL\335\357?\351H\033\271\310\014\353?\363\004|9h\004\340?\323\201l\211\177\223\340?\342\344\001@\3221\327?\374\221\244\250e\007\351?U\326\236J\014H\353?8\214\371FF\356\311?\\[k\274q\260\330?\250\001Vr\356\367\327?\364\\=\351\362V\307?5s\000\356\212\246\357?\000\006@)e|o?\004{\353+o\013\344? \340)$i\275\261?Lq\326\342\340?x\300Y@\016\220\343?\0204B\"\331\332\335?\032J\235\357X\227\323?$w\007\242\225\351\347?\230y\302\3748|\307?\237k\001\346\274\037\357?\270\032@\325\306;\302?\232\365\246\257@0\355?&\373B\265\313\221\350?>\\&\351\177`\345? q2\265\337\023\326?\356\323\267\211)\035\337?\200\306\232\035\360\331\221?h\370\263~\205\305\316?tTQ\037\371s\311?8\265!7\330\340\273?\334\344*\022P\310\332? \266\336}_;\302?\246\360(\037U\346\357?x3\203~D\232\332?[.\010\235\3647\347?\251L\236\216*U\357?,\257\366\2412\240\310?\374ISmn\306\351?\273Z\201\347/C\340?\000\273\330\362\216\317\333?\220\374\335\374\244\027\254?e\206\022\226i\252\342?K\215\240\271\260P\341?tKb\251+\t\304?\317\302\250F\332f\353?\010\001\334\250\207\244\267?\210\253\314\222r\243\317?\334[\'\002\234\374\345?$\346\266\014\314\003\306?\010\3663\004\026\201\332?~\242G\276E\007\322?x&\221`\276\360\347?\366<\242\352x\247\335?}X\335n\216\347\345?9\025\027)j\273\355?Ck\277\307\277\343\340?\314p\253\007K\340\323?\245\tP\203FG\341?\010\250Pp\020\256\331?\350\035\3520\363Y\342?0\261\021\311#\265\264?\241\210+\266\025\213\357?N!k\274\033\030\341?&:\227\202\330:\343?\237B1aPA\355?\302Q\305\305\014\232\341?\214\314\337@\363\304\324?\032\307`\243ft\342?\274tE;\321\344\352?\355X\261d\0145\356?\231JE\337\010\336\343? 
\357\014\375\332\027\326?\201A7\033;P\341?\374\245\336\36571\323?\331\363\267\317\323F\344?\320FK\177\304X\344?C\357\3264[\245\357?\340\275\346?\344\372\226E\025C\337?\355\277\0019vc\344?\200_\020r\003\037\265?\251t\323\371\256:\343?\034\226\233 \026J\341?\335\020\375d\266\006\345?\240\262\034\321[9\236?@%\315<\030\246\333?\255\307p\334\205J\343?j\317V\306K\251\353?\200<\232T\221\264\222?\240K\3439\273>\224?\260!\203\216Y\377\317?\257\264\026\027V\202\354?&|\241CI\327\347?\022\365L\234\375\332\333?\202\365\337\216N@\345?B\352\000S\342k\327?Bn\022\305\267\274\343?@\t\253K>\017\325?\033\214\023\253\251\226\350?\220)]h\204u\345?\025DzW\342\217\347?$\032\351\331\245t\353?0aoo8\\\277?\274\257\256\210\320_\324? \370\257\257\341\034\255?\024\001h\324c\037\353?\230\346D\006_\304\355?~\032,\035a\"\342?2\355VU\225A\327?\024\306\253#\346\207\345?\344\020\301T\321\230\337?\256\264\264J\327\000\323?\034\304\227\337[\003\304?O<\362\r%E\343?+\035\206Xe\242\340?\320\271\306\270\277X\271?\036\017 \276T\376\340?\200Tm\276\210p\351?`1:\311!$\343?p\\\n\355\007\360\313?\300f\375/Lc\345?\000\037(L\016\020\240?\312;]$\006\311\355?h*\031\275\271p\275?\002H\036\306C\016\353?\002\223\370\316\233\004\341?\221\377(*\202,\340?\375\211\344\0079$\357?\277\233\320N\033\034\345?.\315I\033T\264\356?p?\315\257\212\372\273?\342\3557\245s\307\340?\316#m8\251\344\327?y\226\273n\035\321\354?XPn\333\336\004\324?\200\037\336\267h%\277?B\253y\336\267\273\347?b\226\237\344j%\321?&\324]8tn\323?\354L\246\252yG\344?\203S\'\342\204v\357?=\227h\360>\320\344?\364-\225\271\254V\344?\334\026\005:\254(\346?\350\301F5\335?\342?&_F\257\216\247\353?\220\354\016.O^\323?$\267\'\354{\214\324?\224\220\225p\036\004\355?(\302~\212\n3\304?xP\225\213\204A\271?F%o*d\277\344?\362(\327z?\207\345?\210\024\"\3528\254\271?\304~\037C%O\340?_\373\3159\314P\346?f\310W+i\017\357?N\356\317\233\247\206\341?Qj7v\026\021\345?(c\301\036z\300\270?\212H?\220s\321\332?\345\335E+\325&\350?L\206\340:\236\177\322?\3248\356\202\347\202\354?R@EN\210h
\336?8\236\030\032\232\r\277?\210\332\244H\034\334\302?x+\250\373\021\301\313?\260]Y\250u\330\355?\364f\177\244\232\275\312?HR\002\374\215K\307?\367&?ja\354\352?\\\221\320\371\356\217\335?\352\n\210Q\201@\322?\324\267\372\007v\251\357?\304iV\317\032$\331?V\365\304\033\rK\323?\241\017\023\023o\305\350?\254F\374\302\325h\343?\205AI\210\263\311\340?\203\254W\244\026\241\346?\342\213\357\376cb\353?\244\232L\004I\005\307?\216\322g\350\363\243\330?O\010\244X\241\347\343?S\366\033\t\347\361\353?\027\224\335\244.?\340?\210\252\333A\203L\325?\226\316t\020)\226\341?\300\307\027\261`\n\321?l]%\352\3075\316?Bm\274\\\215^\323?\215\232\203\255\216\344\356?\324Q\254\235\313a\336?\272\217\334\013\347\313\354?\000=\316/\262\377\222?h\215X(<3\357?\252\301\004V\357\355\354?@\245\323L\006/\332?\310*\307E\354\031\275?\004\'\203\241\373\355\305?\232\306gy\346\251\350?\312\276O\226s\260\332?:\276\225\000\251\235\355?|\301\354X\323K\352?D\017\274px\231\346?@\tEs\026f\262?\270\007y\233\204\324\331?Og\236\346b=\344?>\323\201\320b>\352?>\376J\237\020}\341?\246\362\2458\016\020\333?\364s^\"\337\217\343?\010\364\300M\351r\323?\211t\310\3267#\351?\024B,\300\035\375\357? 
_\233\014&-\301?R\373\\\023X\321\357?\210\221\322j\222,\337?ux\220\245;\\\346?\244\353\273\3038\224\315?\343\013A\037I\263\346?m\013\236\302!\224\351?h\371\\\022\224\276\302?\030\327\2062E.\277?`\0045xg\205\311?\020Y\220g.f\251?\200]GP))\247?h\274\362q\323F\325?\020\254L5/\316\262?\035F\275{\340\274\354?\300\323\341\362\377\236\255?\330x\327\206\331\340\327?\310\327pd\345{\342?\010u\272\361m7\326?\254\273\247\245\245{\300?\300\311?@\016\247\254?h\3355\215\037\274\321?\265\307\003\277%W\343?xR\360c+\274\271?\000\212X\026\361o\222?\\\203\277,\303I\313?\340\325\t\022\361t\221?\300\027\312\330\337\020\256?\261)\356l\036\211\347?\260\021\224\216>\024\323?(\n\337k\317p\327?\304:\361jF\037\326?\260d\247l\034\375\342?y\340S\275\0069\357?f\362\323\235dP\346?\350?\341\347J\301\307?d\332(\211\033\026\357?\234I\363\341\372\220\342?~\333z\371N\334\335?\204\316^9\276\210\330?\307\0241\244#H\341?\225\343\215\007\tq\340?\346\340l\000\035n\344?t\337\366\250\2371\353?8\301\331\025\345\325\333?\306&k\266\027o\340?\251\037\345\010\242\301\356?\036\350\006Rn\361\323?\217\341\274\217\311\\\356?f\3235\376f\240\333?\225\271\t\241k\277\344?d:\271\260\324\251\301?\034\312+\203\223\357\326?n\275\004\321A\303\353?\324\305Z\241\010\242\342?\222\363\350}I\312\347?f\324\247gBm\322?\035nn\027\005\257\346?\260\255G\216\272\003\240?\224W\374\033\260\016\345?\340\024hZ>S\311?z:`S\204g\346?\354%Rh\203H\353?\"\033\365\252\213\272\335?xZ]\n\037\220\276?\270\363\360\034\331\023\347?\250\377\\\321\312\301\347?@\263w\267\311\207\200?@@\317\242\022\201\314?\316\346\263\262\213\224\320?\244\335}\255\241\255\357?\317\'#\376q\003\353?v\235\3069\377\257\333?2\265vQ\336\240\326?e\352\350\221\277Y\344?\322\330FL\260\350\346?\263\270>A\245g\351?\r/\\\000j\264\352?\204\035l\265\261e\330?\272\275\244\205\024\345\330?\200\220\261\017\323\233\244?\202SI>\310 
\324?8\3134\037\266\353\306?\260\323\002\353\241\274\261?=\244VG(|\352?Y\345V\317\271a\351?,\004\265\316\304N\351?l\177R\323\nv\317?\230\274\034z\230\304\323?\360\260\255\222\355\236\341?" + } + } + } + } + node { + name: "PartitionedCall" + op: "PartitionedCall" + input: "Const" + attr { + key: "Tin" + value { + list { + type: DT_DOUBLE + } + } + } + attr { + key: "Tout" + value { + list { + type: DT_DOUBLE + } + } + } + attr { + key: "_collective_manager_ids" + value { + list { + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 150 + } + dim { + size: 150 + } + } + } + } + } + attr { + key: "_read_only_resource_inputs" + value { + list { + } + } + } + attr { + key: "config_proto" + value { + s: "\202\001\0008\0012\002J\000\n\007\n\003CPU\020\001\n\007\n\003GPU\020\000" + } + } + attr { + key: "f" + value { + func { + name: "__inference_signature_wrapper_13" + } + } + } + } + node { + name: "NoOp" + op: "NoOp" + } + node { + name: "Const_1" + op: "Const" + device: "/device:CPU:0" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: "\n\035\n\t\022\005get_c\010\001\n\016\022\nsignatures\010\002*\000\n\017\n\013\022\007trace_0\010\003*\000\n\027\n\023\022\017serving_default\010\004*\000\n\021\n\r\022\tcapture_0\010\005*\000\n\021\n\r\022\tcapture_0\010\005*\000\n\002*\000" + } + } + } + } + node { + name: "saver_filename" + op: "Placeholder" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "shape" + value { + shape { + } + } + } + } + node { + name: "StatefulPartitionedCall" + op: "StatefulPartitionedCall" + input: "saver_filename" + input: "Const_1" + attr { + key: "Tin" + value { + list { + type: DT_STRING + type: DT_STRING + } + } + } + attr 
{ + key: "Tout" + value { + list { + type: DT_STRING + } + } + } + attr { + key: "_collective_manager_ids" + value { + list { + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "_read_only_resource_inputs" + value { + list { + } + } + } + attr { + key: "config_proto" + value { + s: "\202\001\0008\0012\002J\000\n\007\n\003CPU\020\001\n\007\n\003GPU\020\000" + } + } + attr { + key: "f" + value { + func { + name: "__inference__traced_save_40" + } + } + } + } + node { + name: "StatefulPartitionedCall_1" + op: "StatefulPartitionedCall" + input: "saver_filename" + attr { + key: "Tin" + value { + list { + type: DT_STRING + } + } + } + attr { + key: "Tout" + value { + list { + type: DT_STRING + } + } + } + attr { + key: "_collective_manager_ids" + value { + list { + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "_read_only_resource_inputs" + value { + list { + } + } + } + attr { + key: "config_proto" + value { + s: "\202\001\0008\0012\002J\000\n\007\n\003CPU\020\001\n\007\n\003GPU\020\000" + } + } + attr { + key: "f" + value { + func { + name: "__inference__traced_restore_49" + } + } + } + } + library { + function { + signature { + name: "__inference_signature_wrapper_13" + input_arg { + name: "unknown" + type: DT_DOUBLE + } + output_arg { + name: "identity" + type: DT_DOUBLE + } + } + node_def { + name: "PartitionedCall" + op: "PartitionedCall" + input: "unknown" + attr { + key: "Tin" + value { + list { + type: DT_DOUBLE + } + } + } + attr { + key: "Tout" + value { + list { + type: DT_DOUBLE + } + } + } + attr { + key: "_collective_manager_ids" + value { + list { + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 150 + } + dim { + size: 150 + } + } + } + } + } + attr { + key: "_read_only_resource_inputs" + value { + list { + } + } + } + attr { + key: "config_proto" + value { + s: 
"\202\001\0008\0012\002J\000\n\007\n\003CPU\020\001\n\007\n\003GPU\020\000" + } + } + attr { + key: "f" + value { + func { + name: "__inference__6" + } + } + } + } + node_def { + name: "Identity" + op: "Identity" + input: "PartitionedCall:output:0" + attr { + key: "T" + value { + type: DT_DOUBLE + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 150 + } + dim { + size: 150 + } + } + } + } + } + } + ret { + key: "identity" + value: "Identity:output:0" + } + attr { + key: "_construction_context" + value { + s: "kEagerRuntime" + } + } + attr { + key: "_input_shapes" + value { + list { + shape { + dim { + size: 150 + } + dim { + size: 150 + } + } + } + } + } + arg_attr { + key: 0 + value { + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 150 + } + dim { + size: 150 + } + } + } + } + } + attr { + key: "_user_specified_name" + value { + s: "9" + } + } + } + } + } + function { + signature { + name: "__inference__traced_save_40" + input_arg { + name: "file_prefix" + type: DT_STRING + } + input_arg { + name: "savev2_const_1" + type: DT_STRING + } + output_arg { + name: "identity_1" + type: DT_STRING + } + is_stateful: true + control_output: "MergeV2Checkpoints" + } + node_def { + name: "StaticRegexFullMatch" + op: "StaticRegexFullMatch" + input: "file_prefix" + device: "/device:CPU:*" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "pattern" + value { + s: "^s3://.*" + } + } + } + node_def { + name: "Const" + op: "Const" + device: "/device:CPU:*" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: ".part" + } + } + } + } + node_def { + name: "Const_1" + op: "Const" + device: "/device:CPU:*" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: 
"dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: "_temp/part" + } + } + } + } + node_def { + name: "Select" + op: "Select" + input: "StaticRegexFullMatch:output:0" + input: "Const:output:0" + input: "Const_1:output:0" + device: "/device:CPU:*" + attr { + key: "T" + value { + type: DT_STRING + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + } + node_def { + name: "StringJoin" + op: "StringJoin" + input: "file_prefix" + input: "Select:output:0" + device: "/device:CPU:*" + attr { + key: "N" + value { + i: 2 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + } + node_def { + name: "num_shards" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 1 + } + } + } + } + node_def { + name: "ShardedFilename/shard" + op: "Const" + device: "/device:CPU:0" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 0 + } + } + } + } + node_def { + name: "ShardedFilename" + op: "ShardedFilename" + input: "StringJoin:output:0" + input: "ShardedFilename/shard:output:0" + input: "num_shards:output:0" + device: "/device:CPU:0" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + } + node_def { + name: "SaveV2/tensor_names" + op: "Const" + device: "/device:CPU:0" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + dim { + size: 1 + } + } + 
string_val: "_CHECKPOINTABLE_OBJECT_GRAPH" + } + } + } + } + node_def { + name: "SaveV2/shape_and_slices" + op: "Const" + device: "/device:CPU:0" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + dim { + size: 1 + } + } + string_val: "" + } + } + } + } + node_def { + name: "SaveV2" + op: "SaveV2" + input: "ShardedFilename:filename:0" + input: "SaveV2/tensor_names:output:0" + input: "SaveV2/shape_and_slices:output:0" + input: "savev2_const_1" + device: "/device:CPU:0" + attr { + key: "_has_manual_control_dependencies" + value { + b: true + } + } + attr { + key: "_output_shapes" + value { + list { + } + } + } + attr { + key: "dtypes" + value { + list { + type: DT_STRING + } + } + } + } + node_def { + name: "MergeV2Checkpoints/checkpoint_prefixes" + op: "Pack" + input: "ShardedFilename:filename:0" + input: "^SaveV2" + device: "/device:CPU:0" + attr { + key: "N" + value { + i: 1 + } + } + attr { + key: "T" + value { + type: DT_STRING + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + } + node_def { + name: "MergeV2Checkpoints" + op: "MergeV2Checkpoints" + input: "MergeV2Checkpoints/checkpoint_prefixes:output:0" + input: "file_prefix" + device: "/device:CPU:0" + attr { + key: "_has_manual_control_dependencies" + value { + b: true + } + } + attr { + key: "_output_shapes" + value { + list { + } + } + } + } + node_def { + name: "Identity" + op: "Identity" + input: "file_prefix" + input: "^MergeV2Checkpoints" + device: "/device:CPU:0" + attr { + key: "T" + value { + type: DT_STRING + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + } + node_def { + name: "Identity_1" + op: "Identity" + input: "Identity:output:0" + input: "^NoOp" + attr { + key: "T" + value { + type: DT_STRING + } + } + attr 
{ + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + } + node_def { + name: "NoOp" + op: "NoOp" + input: "^MergeV2Checkpoints" + attr { + key: "_output_shapes" + value { + list { + } + } + } + } + ret { + key: "identity_1" + value: "Identity_1:output:0" + } + attr { + key: "_construction_context" + value { + s: "kEagerRuntime" + } + } + attr { + key: "_input_shapes" + value { + list { + shape { + } + shape { + } + } + } + } + control_ret { + key: "MergeV2Checkpoints" + value: "MergeV2Checkpoints" + } + arg_attr { + key: 0 + value { + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "_user_specified_name" + value { + s: "file_prefix" + } + } + } + } + arg_attr { + key: 1 + value { + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "_user_specified_name" + value { + s: "Const_1" + } + } + } + } + } + function { + signature { + name: "__inference__18" + input_arg { + name: "unknown" + type: DT_DOUBLE + } + output_arg { + name: "identity" + type: DT_DOUBLE + } + } + node_def { + name: "Identity" + op: "Identity" + input: "unknown" + attr { + key: "T" + value { + type: DT_DOUBLE + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 150 + } + dim { + size: 150 + } + } + } + } + } + } + ret { + key: "identity" + value: "Identity:output:0" + } + attr { + key: "_construction_context" + value { + s: "kEagerRuntime" + } + } + attr { + key: "_input_shapes" + value { + list { + shape { + dim { + size: 150 + } + dim { + size: 150 + } + } + } + } + } + arg_attr { + key: 0 + value { + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 150 + } + dim { + size: 150 + } + } + } + } + } + attr { + key: "_user_specified_name" + value { + s: "15" + } + } + } + } + } + function { + signature { + name: "__inference__traced_restore_49" + input_arg { + name: "file_prefix" + type: DT_STRING + } + output_arg { + name: "identity_1" + 
type: DT_STRING + } + is_stateful: true + } + node_def { + name: "RestoreV2/tensor_names" + op: "Const" + device: "/device:CPU:0" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + dim { + size: 1 + } + } + string_val: "_CHECKPOINTABLE_OBJECT_GRAPH" + } + } + } + } + node_def { + name: "RestoreV2/shape_and_slices" + op: "Const" + device: "/device:CPU:0" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + dim { + size: 1 + } + } + string_val: "" + } + } + } + } + node_def { + name: "RestoreV2" + op: "RestoreV2" + input: "file_prefix" + input: "RestoreV2/tensor_names:output:0" + input: "RestoreV2/shape_and_slices:output:0" + device: "/device:CPU:0" + attr { + key: "_output_shapes" + value { + list { + shape { + unknown_rank: true + } + } + } + } + attr { + key: "dtypes" + value { + list { + type: DT_STRING + } + } + } + } + node_def { + name: "NoOp" + op: "NoOp" + device: "/device:CPU:0" + attr { + key: "_has_manual_control_dependencies" + value { + b: true + } + } + attr { + key: "_output_shapes" + value { + list { + } + } + } + } + node_def { + name: "Identity" + op: "Identity" + input: "file_prefix" + input: "^NoOp" + device: "/device:CPU:0" + attr { + key: "T" + value { + type: DT_STRING + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + } + node_def { + name: "Identity_1" + op: "Identity" + input: "Identity:output:0" + attr { + key: "T" + value { + type: DT_STRING + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + } + ret { + key: "identity_1" + value: "Identity_1:output:0" + } + attr { + key: 
"_construction_context" + value { + s: "kEagerRuntime" + } + } + attr { + key: "_input_shapes" + value { + list { + shape { + } + } + } + } + arg_attr { + key: 0 + value { + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "_user_specified_name" + value { + s: "file_prefix" + } + } + } + } + } + function { + signature { + name: "__inference__6" + input_arg { + name: "unknown" + type: DT_DOUBLE + } + output_arg { + name: "identity" + type: DT_DOUBLE + } + } + node_def { + name: "Identity" + op: "Identity" + input: "unknown" + attr { + key: "T" + value { + type: DT_DOUBLE + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 150 + } + dim { + size: 150 + } + } + } + } + } + } + ret { + key: "identity" + value: "Identity:output:0" + } + attr { + key: "_construction_context" + value { + s: "kEagerRuntime" + } + } + attr { + key: "_input_shapes" + value { + list { + shape { + dim { + size: 150 + } + dim { + size: 150 + } + } + } + } + } + arg_attr { + key: 0 + value { + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 150 + } + dim { + size: 150 + } + } + } + } + } + attr { + key: "_user_specified_name" + value { + s: "3" + } + } + } + } + } + } + versions { + producer: 1520 + min_consumer: 12 + } + } + saver_def { + filename_tensor_name: "saver_filename:0" + save_tensor_name: "StatefulPartitionedCall:0" + restore_op_name: "StatefulPartitionedCall_1" + version: V2 + } + collection_def { + key: "saved_model_main_op" + value { + node_list { + value: "NoOp" + } + } + } + signature_def { + key: "__saved_model_init_op" + value { + outputs { + key: "__saved_model_init_op" + value { + name: "NoOp" + tensor_shape { + unknown_rank: true + } + } + } + } + } + signature_def { + key: "serving_default" + value { + outputs { + key: "output_0" + value { + name: "PartitionedCall:0" + dtype: DT_DOUBLE + tensor_shape { + dim { + size: 150 + } + dim { + size: 150 + } + } + } + } + 
method_name: "tensorflow/serving/predict" + } + } + object_graph_def { + nodes { + children { + node_id: 1 + local_name: "get_c" + } + children { + node_id: 2 + local_name: "signatures" + } + user_object { + identifier: "_generic_user_object" + version { + producer: 1 + min_consumer: 1 + } + } + } + nodes { + children { + node_id: 3 + local_name: "trace_0" + } + function { + concrete_functions: "__inference__18" + function_spec { + fullargspec { + named_tuple_value { + name: "FullArgSpec" + values { + key: "args" + value { + list_value { + } + } + } + values { + key: "varargs" + value { + none_value { + } + } + } + values { + key: "varkw" + value { + none_value { + } + } + } + values { + key: "defaults" + value { + none_value { + } + } + } + values { + key: "kwonlyargs" + value { + list_value { + } + } + } + values { + key: "kwonlydefaults" + value { + none_value { + } + } + } + values { + key: "annotations" + value { + dict_value { + } + } + } + } + } + input_signature { + tuple_value { + } + } + } + } + dependencies { + node_id: 3 + local_name: "trace_0" + } + } + nodes { + children { + node_id: 4 + local_name: "serving_default" + } + user_object { + identifier: "signature_map" + version { + producer: 1 + min_consumer: 1 + } + } + } + nodes { + children { + node_id: 5 + local_name: "capture_0" + } + bare_concrete_function { + concrete_function_name: "__inference__18" + function_spec { + fullargspec { + named_tuple_value { + name: "FullArgSpec" + values { + key: "args" + value { + list_value { + } + } + } + values { + key: "varargs" + value { + none_value { + } + } + } + values { + key: "varkw" + value { + none_value { + } + } + } + values { + key: "defaults" + value { + none_value { + } + } + } + values { + key: "kwonlyargs" + value { + list_value { + } + } + } + values { + key: "kwonlydefaults" + value { + none_value { + } + } + } + values { + key: "annotations" + value { + dict_value { + } + } + } + } + } + input_signature { + tuple_value { + } + } + } + } + 
dependencies { + node_id: 5 + local_name: "capture_0" + } + } + nodes { + children { + node_id: 5 + local_name: "capture_0" + } + bare_concrete_function { + concrete_function_name: "__inference_signature_wrapper_13" + function_spec { + fullargspec { + named_tuple_value { + name: "FullArgSpec" + values { + key: "args" + value { + list_value { + } + } + } + values { + key: "varargs" + value { + none_value { + } + } + } + values { + key: "varkw" + value { + none_value { + } + } + } + values { + key: "defaults" + value { + none_value { + } + } + } + values { + key: "kwonlyargs" + value { + list_value { + } + } + } + values { + key: "kwonlydefaults" + value { + none_value { + } + } + } + values { + key: "annotations" + value { + dict_value { + } + } + } + } + } + input_signature { + tuple_value { + } + } + } + } + dependencies { + node_id: 5 + local_name: "capture_0" + } + } + nodes { + constant { + operation: "Const" + } + registered_name: "tf.TrackableConstant" + } + concrete_functions { + key: "__inference__18" + value { + bound_inputs: 5 + canonicalized_input_signature { + tuple_value { + values { + tuple_value { + } + } + values { + dict_value { + } + } + } + } + output_signature { + tensor_spec_value { + name: "unknown" + shape { + dim { + size: 150 + } + dim { + size: 150 + } + } + dtype: DT_DOUBLE + } + } + } + } + concrete_functions { + key: "__inference_signature_wrapper_13" + value { + bound_inputs: 5 + canonicalized_input_signature { + tuple_value { + values { + tuple_value { + } + } + values { + dict_value { + } + } + } + } + output_signature { + dict_value { + fields { + key: "output_0" + value { + tensor_spec_value { + name: "output_0" + shape { + dim { + size: 150 + } + dim { + size: 150 + } + } + dtype: DT_DOUBLE + } + } + } + } + } + } + } + } +} diff --git a/tensorflow/cc/saved_model/testdata/chunked_saved_model/non_chunked_model/fingerprint.pb b/tensorflow/cc/saved_model/testdata/chunked_saved_model/non_chunked_model/fingerprint.pb new file mode 
100644 index 00000000000..a033e7c08e9 --- /dev/null +++ b/tensorflow/cc/saved_model/testdata/chunked_saved_model/non_chunked_model/fingerprint.pb @@ -0,0 +1 @@ +2(Åžçì½…âÀ Œ‚¦þ¡žõ󼎢¶â­ÚâŽßÅ«ƒÏ¾œîÓÀ°³éîâ®Ù \ No newline at end of file diff --git a/tensorflow/cc/saved_model/testdata/chunked_saved_model/non_chunked_model/saved_model.pb b/tensorflow/cc/saved_model/testdata/chunked_saved_model/non_chunked_model/saved_model.pb new file mode 100644 index 00000000000..46fc1c51987 Binary files /dev/null and b/tensorflow/cc/saved_model/testdata/chunked_saved_model/non_chunked_model/saved_model.pb differ diff --git a/tensorflow/cc/saved_model/testdata/chunked_saved_model/non_chunked_model/variables/variables.data-00000-of-00001 b/tensorflow/cc/saved_model/testdata/chunked_saved_model/non_chunked_model/variables/variables.data-00000-of-00001 new file mode 100644 index 00000000000..3e08df4e8f9 Binary files /dev/null and b/tensorflow/cc/saved_model/testdata/chunked_saved_model/non_chunked_model/variables/variables.data-00000-of-00001 differ diff --git a/tensorflow/cc/saved_model/testdata/chunked_saved_model/non_chunked_model/variables/variables.index b/tensorflow/cc/saved_model/testdata/chunked_saved_model/non_chunked_model/variables/variables.index new file mode 100644 index 00000000000..2b377c2506a Binary files /dev/null and b/tensorflow/cc/saved_model/testdata/chunked_saved_model/non_chunked_model/variables/variables.index differ diff --git a/tensorflow/cc/saved_model/testdata/generate_chunked_models.py b/tensorflow/cc/saved_model/testdata/generate_chunked_models.py new file mode 100644 index 00000000000..ef9f968a68b --- /dev/null +++ b/tensorflow/cc/saved_model/testdata/generate_chunked_models.py @@ -0,0 +1,76 @@ +# Copyright 2023 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Generates GraphDef test data for Merger. + +Constructs chunked protos test data containing GraphDefs with lots of nodes and +large nodes for Merger::Read and Merger::Merge. +""" + +from collections.abc import Sequence + +import os + +from absl import app +from absl import flags +import numpy as np + +from tensorflow.python.compat import v2_compat +from tensorflow.python.eager import def_function +from tensorflow.python.framework import constant_op +from tensorflow.python.lib.io import file_io +from tensorflow.python.module import module +from tensorflow.python.saved_model import loader_impl +from tensorflow.python.saved_model import save +from tensorflow.python.saved_model import save_options +from tensorflow.python.util import compat +from tensorflow.tools.proto_splitter import constants +from tensorflow.tools.proto_splitter.python import saved_model as proto_splitter + +SPLITTER_TESTDATA_PATH = flags.DEFINE_string( + "path", None, help="Path to testdata directory.") + + +def generate_non_chunked_model(non_chunked_dir: str): + root = module.Module() + root.c = constant_op.constant(np.random.random_sample([150, 150])) + constants.debug_set_max_size(80000) + root.get_c = def_function.function(lambda: root.c) + signatures = root.get_c.get_concrete_function() + save.save(root, non_chunked_dir, signatures=signatures, + options=save_options.SaveOptions(experimental_image_format=False)) + + +def generate_chunked_model(non_chunked_dir: str, chunked_dir: str): + saved_model = 
loader_impl.parse_saved_model(non_chunked_dir) + prefix = file_io.join(compat.as_str(chunked_dir), "saved_model") + file_io.write_string_to_file(f"{prefix}.pbtxt", str(saved_model)) + proto_splitter.SavedModelSplitter(saved_model).write(prefix) + + +def main(argv: Sequence[str]) -> None: + if len(argv) > 1: + raise app.UsageError("Too many command-line arguments.") + + main_dir = os.path.join(SPLITTER_TESTDATA_PATH.value, "chunked_saved_model") + non_chunked_dir = os.path.join(main_dir, "non_chunked_model") + generate_non_chunked_model(non_chunked_dir) + chunked_dir = os.path.join(main_dir, "chunked_model") + generate_chunked_model(non_chunked_dir, chunked_dir) + + +if __name__ == "__main__": + v2_compat.enable_v2_behavior() + app.run(main) diff --git a/tensorflow/cc/tools/BUILD b/tensorflow/cc/tools/BUILD index 510e7f589fd..bb5daa99742 100644 --- a/tensorflow/cc/tools/BUILD +++ b/tensorflow/cc/tools/BUILD @@ -1,4 +1,5 @@ -# Description: +#include "third_party/absl/strings/str_cat.h" +#Description: # TensorFlow cc tools. load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") @@ -22,6 +23,8 @@ cc_library( "//tensorflow/core:core_cpu", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", ], ) diff --git a/tensorflow/cc/tools/freeze_saved_model.cc b/tensorflow/cc/tools/freeze_saved_model.cc index 480c048e94f..5dcf5e64964 100644 --- a/tensorflow/cc/tools/freeze_saved_model.cc +++ b/tensorflow/cc/tools/freeze_saved_model.cc @@ -18,6 +18,8 @@ limitations under the License. 
#include #include +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/framework/graph.pb.h" @@ -193,7 +195,7 @@ StatusOr GetVarHandleName( if (node->op() == "VarHandleOp") { return node->name(); } - return errors::NotFound("No VarHandleOp ancestor found"); + return absl::NotFoundError("No VarHandleOp ancestor found"); } // Looks up the variable handle that provides input to node with node_name, @@ -209,7 +211,7 @@ StatusOr GetHandleNameIfNeedsToFreeze( if (var_handle_name.ok() && variable_node_names.count(*var_handle_name)) { return var_handle_name; } - return errors::NotFound("No VarHandleOp ancestor found"); + return absl::NotFoundError("No VarHandleOp ancestor found"); } // Freezes the subgraph of all nodes needed by `outputs`. diff --git a/tensorflow/compiler/aot/codegen.cc b/tensorflow/compiler/aot/codegen.cc index 7a8d5273b03..d8d2ea82e76 100644 --- a/tensorflow/compiler/aot/codegen.cc +++ b/tensorflow/compiler/aot/codegen.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/aot/codegen.h" +#include #include #include #include @@ -24,6 +25,7 @@ limitations under the License. #include "absl/strings/str_join.h" #include "absl/strings/str_replace.h" #include "absl/strings/str_split.h" +#include "absl/strings/substitute.h" #include "absl/types/span.h" #include "tensorflow/compiler/aot/embedded_protocol_buffers.h" #include "tensorflow/compiler/tf2xla/tf2xla.pb.h" @@ -312,6 +314,74 @@ Status GenVariableMethods(const tf2xla::Config& config, return OkStatus(); } +// Generate shape infos for args (inputs). +Status GenArgShapeInfos(const xla::ProgramShapeProto& ps, string* infos) { + for (int i = 0; i < ps.parameters_size(); ++i) { + const xla::ShapeProto& shape = ps.parameters(i); + if (shape.element_type() == xla::TUPLE) { + // ShapeInfo cannot represent tuple args. 
+ return absl::InternalError( + absl::StrCat("parameter ", i, + ": codegen requires XLA parameters to " + "be non-tuples.")); + } + // Please some compilers (e.g. MSVC) by avoiding the initialization of an + // array of unknown size an empty initializer. Use "-1" for this; note that + // this value is never used (the size attribute is set to 0 in ShapeInfo). + *infos += absl::Substitute(R"( static constexpr int32_t kArg$0Shapes[] = { +$1 + }; +)", + i, + shape.dimensions_size() > 0 + ? absl::StrJoin(shape.dimensions(), ", ") + : "-1"); + } + *infos += R"( static const ShapeInfo* ArgShapeInfos() { + static constexpr ShapeInfo kArgShapeInfoTable[kNumArgs] = { +)"; + for (int i = 0; i < ps.parameters_size(); ++i) { + const xla::ShapeProto& shape = ps.parameters(i); + *infos += + absl::Substitute("{ kArg$0Shapes, $1 },\n", i, shape.dimensions_size()); + } + *infos += R"( }; + return kArgShapeInfoTable; + })"; + return OkStatus(); +} + +// Generate shape infos for results. +Status GenResultShapeInfos(const xla::ProgramShapeProto& ps, string* infos) { + if (ps.result().element_type() != xla::TUPLE) { + return absl::InternalError("codegen requires the XLA result to be a tuple"); + } + for (int i = 0; i < ps.result().tuple_shapes_size(); ++i) { + const xla::ShapeProto& shape = ps.result().tuple_shapes(i); + // See above comment about the use here of "-1". + *infos += absl::Substitute( + R"( static constexpr int32_t kResult$0Shapes[] = { +$1 + }; +)", + i, + shape.dimensions_size() > 0 ? 
absl::StrJoin(shape.dimensions(), ", ") + : "-1"); + } + *infos += R"( static const ShapeInfo* ResultShapeInfos() { + static constexpr ShapeInfo kResultShapeInfoTable[kNumResults] = { +)"; + for (int i = 0; i < ps.result().tuple_shapes_size(); ++i) { + const xla::ShapeProto& shape = ps.result().tuple_shapes(i); + *infos += absl::Substitute("{ kResult$0Shapes, $1 },\n", i, + shape.dimensions_size()); + } + *infos += R"( }; + return kResultShapeInfoTable; + })"; + return OkStatus(); +} + // Generates code implementing {Arg,Result}Names(), where T is one of // tf2xla::{Feed,Fetch,Variable}. Each feed or fetch name results in a C-style // string literal in the array, with nullptr terminating the array. @@ -377,17 +447,27 @@ std::vector BufferInfosToCppExpression( std::transform(buffer_infos.begin(), buffer_infos.end(), std::back_inserter(buffer_infos_as_strings), [](const BufferInfo& buffer_info) { - std::pair encoded = buffer_info.Encode(); - string encoded_second_as_str = - encoded.second == ~0ULL - ? "~0ULL" - : absl::StrCat(encoded.second, "ULL"); + xla::cpu_function_runtime::EncodedBufferInfo encoded = + buffer_info.Encode(); + auto param_to_str = [](uint32_t param) -> std::string { + return param == ~0U ? "~0U" : absl::StrCat(param, "U"); + }; return absl::StrCat( - "::xla::cpu_function_runtime::BufferInfo({", - encoded.first, "ULL, ", encoded_second_as_str, "})"); + "::xla::cpu_function_runtime::BufferInfo(", + encoded.packed_kind_and_size, "ULL, ", + param_to_str(encoded.entry_param_number), ", ", + param_to_str(encoded.result_param_number), ")"); }); return buffer_infos_as_strings; } + +Status CheckEqual(size_t a, size_t b, absl::string_view error_msg) { + if (a != b) { + return absl::InternalError( + absl::StrCat(error_msg, ". 
Expected ", a, ", got ", b, ".")); + } + return OkStatus(); +} } // namespace Status GenerateHeader(const CodegenOpts& opts, const tf2xla::Config& config, @@ -400,6 +480,8 @@ Status GenerateHeader(const CodegenOpts& opts, const tf2xla::Config& config, compile_result.aot->buffer_infos(); const std::vector arg_index_table = ::xla::cpu::CreateArgIndexTableFromBufferInfos(buffer_infos); + const std::vector result_index_table = + ::xla::cpu::CreateResultIndexTableFromBufferInfos(buffer_infos); std::vector buffer_infos_as_strings = BufferInfosToCppExpression(buffer_infos); const int64_t buffer_infos_size = buffer_infos.size(); @@ -419,6 +501,15 @@ Status GenerateHeader(const CodegenOpts& opts, const tf2xla::Config& config, TF_RETURN_IF_ERROR(GenArgMethods(config, ps, compile_result, &methods_arg)); TF_RETURN_IF_ERROR(GenResultMethods(config, ps, &methods_result)); TF_RETURN_IF_ERROR(GenVariableMethods(config, ps, &methods_variable)); + string arg_shape_infos, result_shape_infos; + TF_RETURN_IF_ERROR(GenArgShapeInfos(ps, &arg_shape_infos)); + TF_RETURN_IF_ERROR( + CheckEqual(ps.parameters_size(), arg_index_table.size(), + "Arg number mismatch, proto vs. arg_index_table")); + TF_RETURN_IF_ERROR(GenResultShapeInfos(ps, &result_shape_infos)); + TF_RETURN_IF_ERROR( + CheckEqual(ps.result().tuple_shapes_size(), result_index_table.size(), + "Result number mismatch, proto vs. result_index_table")); const size_t arg_bytes_aligned = xla::cpu_function_runtime::AlignedBufferBytes( buffer_infos_for_args.data(), buffer_infos_for_args.size(), @@ -544,6 +635,8 @@ class {{CLASS}} final : public tensorflow::XlaCompiledCpuFunction { // Number of input arguments for the compiled computation. static constexpr size_t kNumArgs = {{ARG_NUM}}; + static constexpr size_t kNumResults = {{RESULT_NUM}}; + // Number of variables for the compiled computation. 
static constexpr size_t kNumVariables = {{VARIABLE_NUM}}; @@ -560,16 +653,21 @@ class {{CLASS}} final : public tensorflow::XlaCompiledCpuFunction { set_static_data_raw_function(data, {{ENTRY}}); set_static_data_buffer_infos(data, BufferInfos()); set_static_data_num_buffers(data, kNumBuffers); + set_static_data_result_index_table(data, ResultIndexToBufferIndex()); + set_static_data_num_results(data, kNumResults); set_static_data_arg_index_table(data, ArgIndexToBufferIndex()); set_static_data_num_args(data, kNumArgs); set_static_data_num_variables(data, kNumVariables); set_static_data_result_index(data, kResultIndex); + set_static_data_arg_shape_infos(data, ArgShapeInfos()); + set_static_data_result_shape_infos(data, ResultShapeInfos()); set_static_data_arg_names(data, StaticArgNames()); set_static_data_variable_names(data, StaticVariableNames()); set_static_data_result_names(data, StaticResultNames()); set_static_data_program_shape(data, StaticProgramShape()); set_static_data_hlo_profile_printer_data( data, StaticHloProfilePrinterData()); + set_static_data_use_xla_runtime(data, {{USE_XLA_RUNTIME}}); {{ASSIGN_PROFILE_COUNTERS_SIZE}} return data; }(); @@ -589,7 +687,7 @@ class {{CLASS}} final : public tensorflow::XlaCompiledCpuFunction { // // void set_argN_data(void* data) // Sets the buffer of type T for positional argument N. May be called in - // any AllocMode. Must be called before Run to have an affect. Must be + // any AllocMode. Must be called before Run to have an effect. Must be // called in AllocMode::RESULTS_PROFILES_AND_TEMPS_ONLY for each positional // argument, to set the argument buffers. 
// @@ -655,6 +753,13 @@ class {{CLASS}} final : public tensorflow::XlaCompiledCpuFunction { return kBufferInfos; } + static const ::tensorflow::int32* ResultIndexToBufferIndex() { + static constexpr ::tensorflow::int32 kResultIndexToBufferIndex[kNumResults] = { +{{RESULT_INDEX_TABLE}} + }; + return kResultIndexToBufferIndex; + } + static const ::tensorflow::int32* ArgIndexToBufferIndex() { static constexpr ::tensorflow::int32 kArgIndexToBufferIndex[kNumArgs] = { {{ARG_INDEX_TABLE}} @@ -665,6 +770,12 @@ class {{CLASS}} final : public tensorflow::XlaCompiledCpuFunction { // The 0-based index of the result tuple in the temporary buffers. static constexpr size_t kResultIndex = {{RESULT_INDEX}}; + // Shapes of the input arguments. +{{ARG_SHAPE_INFOS}}; + + // Shapes of the results. +{{RESULT_SHAPE_INFOS}}; + // Array of names of each positional argument, terminated by nullptr. static const char** StaticArgNames() {{ARG_NAMES_CODE}} @@ -699,13 +810,18 @@ class {{CLASS}} final : public tensorflow::XlaCompiledCpuFunction { {"{{ARG_BYTES_TOTAL}}", absl::StrCat(arg_bytes_total)}, {"{{ARG_NAMES_CODE}}", arg_names_code}, {"{{ARG_NUM}}", absl::StrCat(arg_index_table.size())}, + {"{{ARG_SHAPE_INFOS}}", arg_shape_infos}, {"{{VARIABLE_NUM}}", absl::StrCat(config.variable_size())}, {"{{ARG_INDEX_TABLE}}", absl::StrJoin(arg_index_table, ", ")}, + {"{{RESULT_NUM}}", absl::StrCat(result_index_table.size())}, + {"{{RESULT_INDEX_TABLE}}", absl::StrJoin(result_index_table, ", ")}, + {"{{ASSIGN_PROFILE_COUNTERS_SIZE}}", assign_profile_counters_size}, {"{{CLASS}}", opts.class_name}, {"{{DECLS_FROM_OBJ_FILE}}", absl::StrJoin(metadata_result.header_variable_decls, "\n")}, {"{{ENTRY}}", compile_result.entry_point}, + {"{{USE_XLA_RUNTIME}}", opts.use_xla_runtime ? 
"true" : "false"}, {"{{HLO_PROFILE_PRINTER_DATA_SHIM_EXPRESSION}}", metadata_result.hlo_profile_printer_data_access_shim}, {"{{INCLUDE_XLA_DATA_PROTO}}", include_xla_data_proto}, @@ -722,6 +838,7 @@ class {{CLASS}} final : public tensorflow::XlaCompiledCpuFunction { {"{{VARIABLE_NAMES_CODE}}", variable_names_code}, {"{{RESULT_INDEX}}", absl::StrCat(result_index)}, {"{{RESULT_NAMES_CODE}}", result_names_code}, + {"{{RESULT_SHAPE_INFOS}}", result_shape_infos}, {"{{TEMP_BYTES_ALIGNED}}", absl::StrCat(temp_bytes_aligned)}, {"{{TEMP_BYTES_TOTAL}}", absl::StrCat(temp_bytes_total)}, {"{{NUM_BUFFERS}}", absl::StrCat(buffer_infos.size())}, @@ -749,7 +866,7 @@ Status GenerateMetadata(const CodegenOpts& opts, if (opts.gen_program_shape) { program_shape = - absl::make_unique(compile_result.program_shape); + std::make_unique(compile_result.program_shape); // The parameter names are currently meaningless, and redundant with the // rest of our metadata, so clear them out to avoid confusion and save diff --git a/tensorflow/compiler/aot/codegen.h b/tensorflow/compiler/aot/codegen.h index 9485e86b10e..a0caceaf4c6 100644 --- a/tensorflow/compiler/aot/codegen.h +++ b/tensorflow/compiler/aot/codegen.h @@ -48,6 +48,9 @@ struct CodegenOpts { // If true, emit a serialized HloProfilePrinterData protobuf that can be used // to pretty print HLO profile counters. bool gen_hlo_profile_printer_data = false; + + // If true, sets this executable as an XLA Runtime one. + bool use_xla_runtime = false; }; // Describes a generated metadata object file. 
diff --git a/tensorflow/compiler/aot/codegen_test.cc b/tensorflow/compiler/aot/codegen_test.cc index 18e3182e686..dc02f88e6a9 100644 --- a/tensorflow/compiler/aot/codegen_test.cc +++ b/tensorflow/compiler/aot/codegen_test.cc @@ -215,18 +215,22 @@ TEST(CodegenTest, Golden) { CompileResult compile_result; compile_result.aot.reset(new xla::cpu::CpuAotCompilationResult( {}, - {BufferInfo::MakeTempBuffer(1), - BufferInfo::MakeEntryParameter(/*size=*/8, /*param_number=*/0), + {BufferInfo::MakeTempBuffer(3 * 8), + BufferInfo::MakeEntryParameter(/*size=*/8, /*entry_param_number=*/0), BufferInfo::MakeTempBuffer(1), - BufferInfo::MakeEntryParameter(/*size=*/96, /*param_number=*/1), + BufferInfo::MakeEntryParameter(/*size=*/96, /*entry_param_number=*/1), BufferInfo::MakeTempBuffer(1), - BufferInfo::MakeEntryParameter(/*size=*/96, /*param_number=*/2), + BufferInfo::MakeEntryParameter(/*size=*/96, /*entry_param_number=*/2), BufferInfo::MakeTempBuffer(1), - BufferInfo::MakeEntryParameter(/*size=*/96, /*param_number=*/3), - BufferInfo::MakeTempBuffer(1), - BufferInfo::MakeEntryParameter(/*size=*/96, /*param_number=*/4), - BufferInfo::MakeTempBuffer(1), BufferInfo::MakeTempBuffer(120)}, - 11, {})); + BufferInfo::MakeEntryParameter(/*size=*/96, /*entry_param_number=*/3), + BufferInfo::MakeResultParameter(/*size=*/5 * 6 * 4, + /*result_param_number=*/0), + BufferInfo::MakeEntryParameter(/*size=*/96, /*entry_param_number=*/4), + BufferInfo::MakeResultParameter(/*size=*/1 * 4, + /*result_param_number=*/1), + BufferInfo::MakeResultParameter(/*size=*/5 * 4, + /*result_param_number=*/2)}, + 0, {})); compile_result.program_shape = xla::ShapeUtil::MakeProgramShape( { diff --git a/tensorflow/compiler/aot/codegen_test_h.golden b/tensorflow/compiler/aot/codegen_test_h.golden index b4af9ef633d..88aefb744ad 100644 --- a/tensorflow/compiler/aot/codegen_test_h.golden +++ b/tensorflow/compiler/aot/codegen_test_h.golden @@ -58,13 +58,15 @@ namespace bar { // Memory stats: // arg bytes total: 392 // 
arg bytes aligned: 576 -// temp bytes total: 126 +// temp bytes total: 171 // temp bytes aligned: 512 class MyClass final : public tensorflow::XlaCompiledCpuFunction { public: // Number of input arguments for the compiled computation. static constexpr size_t kNumArgs = 5; + static constexpr size_t kNumResults = 3; + // Number of variables for the compiled computation. static constexpr size_t kNumVariables = 3; @@ -81,16 +83,21 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction { set_static_data_raw_function(data, entry_point); set_static_data_buffer_infos(data, BufferInfos()); set_static_data_num_buffers(data, kNumBuffers); + set_static_data_result_index_table(data, ResultIndexToBufferIndex()); + set_static_data_num_results(data, kNumResults); set_static_data_arg_index_table(data, ArgIndexToBufferIndex()); set_static_data_num_args(data, kNumArgs); set_static_data_num_variables(data, kNumVariables); set_static_data_result_index(data, kResultIndex); + set_static_data_arg_shape_infos(data, ArgShapeInfos()); + set_static_data_result_shape_infos(data, ResultShapeInfos()); set_static_data_arg_names(data, StaticArgNames()); set_static_data_variable_names(data, StaticVariableNames()); set_static_data_result_names(data, StaticResultNames()); set_static_data_program_shape(data, StaticProgramShape()); set_static_data_hlo_profile_printer_data( data, StaticHloProfilePrinterData()); + set_static_data_use_xla_runtime(data, false); return data; }(); @@ -110,7 +117,7 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction { // // void set_argN_data(void* data) // Sets the buffer of type T for positional argument N. May be called in - // any AllocMode. Must be called before Run to have an affect. Must be + // any AllocMode. Must be called before Run to have an effect. Must be // called in AllocMode::RESULTS_PROFILES_AND_TEMPS_ONLY for each positional // argument, to set the argument buffers. 
// @@ -354,22 +361,29 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction { static const ::xla::cpu_function_runtime::BufferInfo* BufferInfos() { static const ::xla::cpu_function_runtime::BufferInfo kBufferInfos[kNumBuffers] = { -::xla::cpu_function_runtime::BufferInfo({5ULL, ~0ULL}), -::xla::cpu_function_runtime::BufferInfo({34ULL, 0ULL}), -::xla::cpu_function_runtime::BufferInfo({5ULL, ~0ULL}), -::xla::cpu_function_runtime::BufferInfo({386ULL, 1ULL}), -::xla::cpu_function_runtime::BufferInfo({5ULL, ~0ULL}), -::xla::cpu_function_runtime::BufferInfo({386ULL, 2ULL}), -::xla::cpu_function_runtime::BufferInfo({5ULL, ~0ULL}), -::xla::cpu_function_runtime::BufferInfo({386ULL, 3ULL}), -::xla::cpu_function_runtime::BufferInfo({5ULL, ~0ULL}), -::xla::cpu_function_runtime::BufferInfo({386ULL, 4ULL}), -::xla::cpu_function_runtime::BufferInfo({5ULL, ~0ULL}), -::xla::cpu_function_runtime::BufferInfo({481ULL, ~0ULL}) +::xla::cpu_function_runtime::BufferInfo(97ULL, ~0U, ~0U), +::xla::cpu_function_runtime::BufferInfo(34ULL, 0U, ~0U), +::xla::cpu_function_runtime::BufferInfo(5ULL, ~0U, ~0U), +::xla::cpu_function_runtime::BufferInfo(386ULL, 1U, ~0U), +::xla::cpu_function_runtime::BufferInfo(5ULL, ~0U, ~0U), +::xla::cpu_function_runtime::BufferInfo(386ULL, 2U, ~0U), +::xla::cpu_function_runtime::BufferInfo(5ULL, ~0U, ~0U), +::xla::cpu_function_runtime::BufferInfo(386ULL, 3U, ~0U), +::xla::cpu_function_runtime::BufferInfo(481ULL, ~0U, 0U), +::xla::cpu_function_runtime::BufferInfo(386ULL, 4U, ~0U), +::xla::cpu_function_runtime::BufferInfo(17ULL, ~0U, 1U), +::xla::cpu_function_runtime::BufferInfo(81ULL, ~0U, 2U) }; return kBufferInfos; } + static const ::tensorflow::int32* ResultIndexToBufferIndex() { + static constexpr ::tensorflow::int32 kResultIndexToBufferIndex[kNumResults] = { +8, 10, 11 + }; + return kResultIndexToBufferIndex; + } + static const ::tensorflow::int32* ArgIndexToBufferIndex() { static constexpr ::tensorflow::int32 kArgIndexToBufferIndex[kNumArgs] = { 
1, 3, 5, 7, 9 @@ -378,7 +392,53 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction { } // The 0-based index of the result tuple in the temporary buffers. - static constexpr size_t kResultIndex = 11; + static constexpr size_t kResultIndex = 0; + + // Shapes of the input arguments. + static constexpr int32_t kArg0Shapes[] = { +1, 2 + }; + static constexpr int32_t kArg1Shapes[] = { +3, 4 + }; + static constexpr int32_t kArg2Shapes[] = { +1 + }; + static constexpr int32_t kArg3Shapes[] = { +1 + }; + static constexpr int32_t kArg4Shapes[] = { +5 + }; + static const ShapeInfo* ArgShapeInfos() { + static constexpr ShapeInfo kArgShapeInfoTable[kNumArgs] = { +{ kArg0Shapes, 2 }, +{ kArg1Shapes, 2 }, +{ kArg2Shapes, 1 }, +{ kArg3Shapes, 1 }, +{ kArg4Shapes, 1 }, + }; + return kArgShapeInfoTable; + }; + + // Shapes of the results. + static constexpr int32_t kResult0Shapes[] = { +5, 6 + }; + static constexpr int32_t kResult1Shapes[] = { +1 + }; + static constexpr int32_t kResult2Shapes[] = { +5 + }; + static const ShapeInfo* ResultShapeInfos() { + static constexpr ShapeInfo kResultShapeInfoTable[kNumResults] = { +{ kResult0Shapes, 2 }, +{ kResult1Shapes, 1 }, +{ kResult2Shapes, 1 }, + }; + return kResultShapeInfoTable; + }; // Array of names of each positional argument, terminated by nullptr. static const char** StaticArgNames() { diff --git a/tensorflow/compiler/aot/compile.cc b/tensorflow/compiler/aot/compile.cc index fd3bf0bb7e9..290a6bb4ab4 100644 --- a/tensorflow/compiler/aot/compile.cc +++ b/tensorflow/compiler/aot/compile.cc @@ -273,6 +273,15 @@ Status Main(const MainFlags& flags) { codegen_opts.gen_name_to_index = flags.gen_name_to_index; codegen_opts.gen_program_shape = flags.gen_program_shape; codegen_opts.target_triple = flags.target_triple; + // Set the XLA Runtime bit if this is an HloLowering. 
+ if (!flags.mlir_components.empty() && flags.mlir_components != "None") { + for (auto component : absl::StrSplit(flags.mlir_components, ',')) { + if (component == "HloLowering") { + codegen_opts.use_xla_runtime = true; + } + } + } + if (flags.cpp_class.empty()) { return errors::InvalidArgument("Must specify --cpp_class"); } diff --git a/tensorflow/compiler/aot/tests/BUILD b/tensorflow/compiler/aot/tests/BUILD index f04aa37c887..191188b674d 100644 --- a/tensorflow/compiler/aot/tests/BUILD +++ b/tensorflow/compiler/aot/tests/BUILD @@ -69,18 +69,18 @@ py_binary( srcs_version = "PY3", deps = [ "//tensorflow/core:protos_all_py", - "//tensorflow/python:array_ops", - "//tensorflow/python:client", - "//tensorflow/python:cond", - "//tensorflow/python:control_flow_assert", - "//tensorflow/python:control_flow_ops", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:math_ops", - "//tensorflow/python:nn_ops", - "//tensorflow/python:session", - "//tensorflow/python:training", - "//tensorflow/python:variable_v1", - "//tensorflow/python:variables", + "//tensorflow/python/client", + "//tensorflow/python/client:session", + "//tensorflow/python/framework:for_generated_wrappers", + "//tensorflow/python/ops:array_ops", + "//tensorflow/python/ops:cond", + "//tensorflow/python/ops:control_flow_assert", + "//tensorflow/python/ops:control_flow_ops", + "//tensorflow/python/ops:math_ops", + "//tensorflow/python/ops:nn_ops", + "//tensorflow/python/ops:variable_v1", + "//tensorflow/python/ops:variables", + "//tensorflow/python/training", "@absl_py//absl:app", "@six_archive//:six", ], @@ -437,6 +437,7 @@ tf_cc_test( tags = [ "manual", "no_mac", # TODO(b/228273415) + "not_run:arm", ], deps = [ ":test_graph_tfadd", @@ -510,6 +511,7 @@ tf_cc_test( tags = [ "manual", "no_mac", # TODO(b/228273415) + "not_run:arm", ], deps = [ ":test_graph_tfadd_mlir_bridge", diff --git a/tensorflow/compiler/aot/tests/tfcompile_test.cc b/tensorflow/compiler/aot/tests/tfcompile_test.cc 
index 872ce4160c3..64138e47c98 100644 --- a/tensorflow/compiler/aot/tests/tfcompile_test.cc +++ b/tensorflow/compiler/aot/tests/tfcompile_test.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include +#include #define EIGEN_USE_THREADS #define EIGEN_USE_CUSTOM_THREAD_POOL diff --git a/tensorflow/compiler/aot/tfcompile.bzl b/tensorflow/compiler/aot/tfcompile.bzl index c1f8fdc089a..c965760785a 100644 --- a/tensorflow/compiler/aot/tfcompile.bzl +++ b/tensorflow/compiler/aot/tfcompile.bzl @@ -319,6 +319,8 @@ def _tf_library( ] or []) + (include_standard_runtime_deps and [ # TODO(cwhipkey): only depend on kernel code that the model actually # needed. + "//tensorflow/compiler/xla/service/cpu/runtime:convolution_ffi", + "//tensorflow/compiler/xla/service/cpu/runtime:rng_ffi", "//tensorflow/compiler/xla/service/cpu:runtime_conv2d", "//tensorflow/compiler/xla/service/cpu:runtime_custom_call_status", "//tensorflow/compiler/xla/service/cpu:runtime_key_value_sort", @@ -329,6 +331,7 @@ def _tf_library( "//third_party/eigen3", ] or []) + ( mlir_components.count("HloLowering") > 0 and [ + "//tensorflow/compiler/xla/runtime:aot_ffi_c_symbols", "//tensorflow/compiler/xla/service/cpu:runtime_mlir_utils", ] or [] ) + ( diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index b3fd29ff259..9bc3348b38a 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -1,5 +1,5 @@ load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") -load("//tensorflow:tensorflow.bzl", "if_libtpu", "if_with_tpu_support", "tf_cc_test", "tf_copts", "tf_cuda_cc_test") +load("//tensorflow:tensorflow.bzl", "if_libtpu", "if_with_tpu_support", "tf_cc_test", "tf_copts", "tf_cuda_cc_test", "tf_cuda_only_cc_test") load("//tensorflow/compiler/xla/stream_executor:build_defs.bzl", "if_cuda_or_rocm") 
load("//tensorflow:tensorflow.default.bzl", "cc_header_only_library", "filegroup", "tf_custom_op_py_library", "tf_jit_compilation_passes_extra_deps") load("//tensorflow/core/platform:build_config.bzl", "tf_additional_all_protos", "tf_proto_library") @@ -352,8 +352,10 @@ cc_library( "//tensorflow/compiler/xla:executable_run_options", "//tensorflow/compiler/xla/hlo/ir:hlo", "//tensorflow/compiler/xla/pjrt:pjrt_client", + "//tensorflow/compiler/xla/pjrt:tf_pjrt_client", "//tensorflow/compiler/xla/service:executable", "//tensorflow/core/tfrt/common:create_pjrt_client_util", + "//tensorflow/core/tfrt/common:global_state", "//tensorflow/core/tfrt/common:pjrt_util", "//tensorflow/core/tpu:tpu_defs", "@com_google_absl//absl/log", @@ -595,7 +597,8 @@ tf_cc_test( "//tensorflow/core/framework:fake_input", "//tensorflow/core/kernels:identity_op", "//tensorflow/core/kernels:ops_testutil", - "@com_google_googletest//:gtest_main", + "//tensorflow/core/tpu:tpu_defs", + "@com_google_googletest//:gtest", ], ) @@ -782,6 +785,7 @@ tf_cc_test( "//tensorflow/core:test_main", "//tensorflow/core:testlib", "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", ], ) @@ -1509,6 +1513,70 @@ cc_library( ], ) +cc_library( + name = "xla_host_recv_device_context", + srcs = [ + "xla_host_recv_device_context.cc", + ], + hdrs = [ + "xla_host_recv_device_context.h", + ], + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/compiler/tf2xla:common", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla/stream_executor", + "//tensorflow/compiler/xla/stream_executor:device_memory", + "//tensorflow/core:framework", + "@tf_runtime//:async_value", + ], +) + +cc_library( + name = "xla_host_send_device_context", + srcs = [ + "xla_host_send_device_context.cc", + ], + hdrs = [ + "xla_host_send_device_context.h", + ], + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/compiler/tf2xla:common", + "//tensorflow/compiler/xla:shape_util", + 
"//tensorflow/compiler/xla/stream_executor", + "//tensorflow/compiler/xla/stream_executor:device_memory", + "//tensorflow/core:framework", + "@tf_runtime//:async_value", + ], +) + +tf_cuda_only_cc_test( + name = "xla_host_send_recv_device_context_test", + srcs = ["xla_host_send_recv_device_context_test.cc"], + tags = tf_cuda_tests_tags() + [ + "config-cuda-only", + "no_oss", # Temporarily disable OSS. + ], + deps = [ + ":flags", + ":xla_device", + ":xla_gpu_device", + ":xla_host_recv_device_context", + ":xla_host_send_device_context", + "//tensorflow/compiler/tf2xla:common", + "//tensorflow/compiler/tf2xla:xla_op_registry", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla/stream_executor", + "//tensorflow/compiler/xla/stream_executor:device_memory", + "//tensorflow/compiler/xla/stream_executor:multi_platform_manager", + "//tensorflow/core:framework_internal", + "//tensorflow/core:test", + "//tensorflow/core/framework:tensor_testutil", + "@com_google_googletest//:gtest_main", + ], +) + tf_cc_test( name = "device_compilation_cluster_signature_test", srcs = [ @@ -1527,6 +1595,9 @@ tf_cc_test( tf_cc_test( name = "device_compilation_profiler_test", srcs = ["device_compilation_profiler_test.cc"], + tags = [ + "nomsan", # TODO(b/284492454) + ], deps = [ ":device_compilation_profiler", ":xla_activity_proto_cc", @@ -1641,6 +1712,7 @@ tf_cuda_cc_test( tags = tf_cuda_tests_tags(), deps = [ ":flags", + ":pjrt_device_compiler_client", ":test_util", ":xla_device_no_jit_rewrite_registration", ":xla_gpu_device", @@ -1649,6 +1721,7 @@ tf_cuda_cc_test( "//tensorflow/compiler/tf2xla:xla_op_registry", "//tensorflow/compiler/xla:xla_data_proto_cc", "//tensorflow/compiler/xla/client:client_library", + "//tensorflow/compiler/xla/pjrt:pjrt_client", "//tensorflow/core:framework", "//tensorflow/core:framework_types_hdr", "//tensorflow/core/tpu:tpu_defs", diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc 
b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc index f0fcd17ba23..e426c9d40d9 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc @@ -19,9 +19,9 @@ limitations under the License. #include #include #include -#include #include +#include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/strings/match.h" #include "absl/strings/str_cat.h" @@ -188,7 +188,7 @@ class Encapsulator { // Adds the function call node to graph_out. Status AddFunctionCallNode( - const std::unordered_map& node_images, + const absl::flat_hash_map& node_images, Graph* graph_out); // Returns the Node that the inputs and outputs of the function should be @@ -206,7 +206,7 @@ class Encapsulator { // and adds the edge within the subgraph from the _Arg node to the image of // the dst node. Status RecordArg(const Edge* edge, - const std::unordered_map& node_images, + const absl::flat_hash_map& node_images, std::vector>* src_arg_pairs); // Records the src of the given edge as a control result of the graph. @@ -214,14 +214,14 @@ class Encapsulator { // the function signature. Status RecordControlResult( const Edge* edge, - const std::unordered_map& node_images); + const absl::flat_hash_map& node_images); // Creates a _Retval node for the src node of edge, and add it to results_, // if none exists yet. If a new _Retval node is created, also adds the edge // within the subgraph from the src to the _Retval node. Status RecordResult( const Edge* edge, - const std::unordered_map& node_images); + const absl::flat_hash_map& node_images); // Creates the sequencer node if it doesn't exist, adding it to graph_out. Status MakeSequencingNode(const string& subgraph_name, Graph* graph_out); @@ -260,14 +260,14 @@ class Encapsulator { // (consumer node/slot) tensors in the input graph to _Arg numbers in // the subgraph. The source map is one-to-one, whereas the dest map may be // many-to-one. 
- std::unordered_map args_by_src_; - std::unordered_map args_by_dst_; + absl::flat_hash_map args_by_src_; + absl::flat_hash_map args_by_dst_; // The arguments to the subgraph, in order. std::vector args_; // Map from source tensor in the input graph to result #. - std::unordered_map results_; + absl::flat_hash_map results_; // Set of node names that are the source of a control output of the // subgraph. We store strings here so that we can tolerate nodes being @@ -285,19 +285,20 @@ class Encapsulator { // Copies edges local to a subgraph. Adds _Arg and _Retval nodes to // subgraphs for data edges that cross subgraph boundaries. Status CopySubgraphEdges( - const std::unordered_map& node_images, + const absl::flat_hash_map& node_images, std::vector>* src_arg_pairs); // Copies all marked nodes to a subgraph. Does nothing for unmarked nodes. - Status CopySubgraphNodes(std::unordered_map* node_images); + Status CopySubgraphNodes( + absl::flat_hash_map* node_images); // Copies all nodes that aren't in a compiled subgraph to the output graph. Status CopyNodesToOutputGraph( - Graph* graph_out, std::unordered_map* node_images); + Graph* graph_out, absl::flat_hash_map* node_images); // Adds function call nodes for each compiled subgraph. Status AddFunctionCallNodes( - const std::unordered_map& node_images, + const absl::flat_hash_map& node_images, Graph* graph_out); // Finds the image of an edge source in the output graph. If the edge crosses @@ -305,7 +306,7 @@ class Encapsulator { // in the output graph. Status FindOutputImageOfEdgeSrc( const string& src_func_id, const string& dst_func_id, - const std::unordered_map& node_images, + const absl::flat_hash_map& node_images, const Node* original_src_node, Node** src_image); // Finds an edge source slot in the output graph. If the edge crosses a @@ -320,7 +321,7 @@ class Encapsulator { // a node in the output graph. 
Status FindOutputImageOfEdgeDst( const string& src_func_id, const string& dst_func_id, - const std::unordered_map& node_images, + const absl::flat_hash_map& node_images, const Node* original_dst_node, Node** dst_image); // Finds an edge destination slot in the output graph. If the edge crosses a @@ -334,14 +335,14 @@ class Encapsulator { // within the output graph, or crosses into or out of a compiled subgraph. Status CopyEdgeToOutputGraph( const Edge* edge, const string& src_func_id, const string& dst_func_id, - const std::unordered_map& node_images, + const absl::flat_hash_map& node_images, Graph* graph_out, - std::unordered_set, - OutputInputTensorPairHasher>* edges_added); + absl::flat_hash_set, + OutputInputTensorPairHasher>* edges_added); // Adds all edges to the output graph. Status AddEdgesToOutputGraph( - const std::unordered_map& node_images, + const absl::flat_hash_map& node_images, Graph* graph_out); // Makes a copy of graph containing only nodes that are ancestors of at least @@ -351,13 +352,13 @@ class Encapsulator { Status MakePrunedGraphCopyAndInline( const Graph& graph, const std::vector& sink_nodes, std::unique_ptr* pruned_graph, - std::unordered_map* node_images, + absl::flat_hash_map* node_images, FunctionLibraryDefinition* library); const string group_attribute_; const Graph* graph_in_; - std::unordered_map subgraphs_; + absl::flat_hash_map subgraphs_; TF_DISALLOW_COPY_AND_ASSIGN(Encapsulator); }; @@ -369,9 +370,9 @@ namespace { // including clusters that are not present in the ancestors map. has_successors // is the set of clusters that are ancestors of some other cluster. void TopologicalClusterSort( - const std::unordered_set& clusters, - const std::unordered_set& has_successors, - const std::unordered_map>& ancestors, + const absl::flat_hash_set& clusters, + const absl::flat_hash_set& has_successors, + const absl::flat_hash_map>& ancestors, std::vector* sorted) { // The nodes are placed in 'sorted' in topological order. 
sorted->clear(); @@ -447,11 +448,12 @@ Node* Encapsulator::Subgraph::MakeNodeImage(const Graph* graph_in, Node* node) { Graph* Encapsulator::Subgraph::GetGraph() const { return graph_.get(); } Status Encapsulator::Subgraph::RecordArg( - const Edge* edge, const std::unordered_map& node_images, + const Edge* edge, + const absl::flat_hash_map& node_images, std::vector>* src_arg_pairs) { Node* src_node = edge->src(); int src_slot = edge->src_output(); - std::unordered_map::iterator iter; + absl::flat_hash_map::iterator iter; bool inserted; std::tie(iter, inserted) = args_by_src_.emplace( OutputTensor(src_node, src_slot), args_by_src_.size()); @@ -481,7 +483,7 @@ Status Encapsulator::Subgraph::RecordArg( Status Encapsulator::Subgraph::RecordControlResult( const Edge* edge, - const std::unordered_map& node_images) { + const absl::flat_hash_map& node_images) { Node* src_node = edge->src(); Node* src_image = node_images.at(src_node); control_output_nodes_.insert(src_image->name()); @@ -490,11 +492,11 @@ Status Encapsulator::Subgraph::RecordControlResult( Status Encapsulator::Subgraph::RecordResult( const Edge* edge, - const std::unordered_map& node_images) { + const absl::flat_hash_map& node_images) { Node* src_node = edge->src(); Node* src_image = node_images.at(src_node); int src_slot = edge->src_output(); - std::unordered_map::iterator iter; + absl::flat_hash_map::iterator iter; bool inserted; std::tie(iter, inserted) = results_.emplace(OutputTensor(src_node, src_slot), results_.size()); @@ -592,7 +594,7 @@ Status Encapsulator::Subgraph::BuildFunctionDef( FunctionDef fdef; auto lookup = [this](const Node* node) -> std::optional { if (control_output_nodes_.contains(node->name())) { - return absl::make_optional(node->name()); + return std::make_optional(node->name()); } return std::nullopt; }; @@ -637,7 +639,7 @@ Status Encapsulator::Subgraph::ReplaceFunctionDef( } Status Encapsulator::Subgraph::AddFunctionCallNode( - const std::unordered_map& node_images, + const 
absl::flat_hash_map& node_images, Graph* graph_out) { TF_ASSIGN_OR_RETURN(call_node_, graph_out->AddNode(call_node_def_)); @@ -663,7 +665,7 @@ Status Encapsulator::GetFunctionNameAttr(Node const* node, string* attr) const { bool IsInSubgraph(const string& func_id) { return !func_id.empty(); } Status Encapsulator::CopySubgraphNodes( - std::unordered_map* node_images) { + absl::flat_hash_map* node_images) { for (Node* node : graph_in_->op_nodes()) { string func_id; TF_RETURN_IF_ERROR(GetFunctionNameAttr(node, &func_id)); @@ -678,7 +680,7 @@ Status Encapsulator::CopySubgraphNodes( } Status Encapsulator::CopySubgraphEdges( - const std::unordered_map& node_images, + const absl::flat_hash_map& node_images, std::vector>* src_arg_pairs) { for (const Edge* edge : graph_in_->edges()) { string src_func_id; @@ -752,7 +754,7 @@ Status Encapsulator::SplitIntoSubgraphs(FunctionLibraryDefinition* library) { Status s; // Map from input graph nodes to subgraph nodes. - std::unordered_map node_images; + absl::flat_hash_map node_images; // Each entry of src_arg_pairs is a pair whose first element is a node in the // original graph that has an output edge in the subgraph, and whose second @@ -794,7 +796,7 @@ Status Encapsulator::BuildFunctionDefs( } Status Encapsulator::CopyNodesToOutputGraph( - Graph* graph_out, std::unordered_map* node_images) { + Graph* graph_out, absl::flat_hash_map* node_images) { for (Node* node : graph_in_->op_nodes()) { string func_id; TF_RETURN_IF_ERROR(GetFunctionNameAttr(node, &func_id)); @@ -811,7 +813,7 @@ Status Encapsulator::CopyNodesToOutputGraph( } Status Encapsulator::AddFunctionCallNodes( - const std::unordered_map& node_images, + const absl::flat_hash_map& node_images, Graph* graph_out) { for (auto& subgraph_entry : subgraphs_) { TF_RETURN_IF_ERROR( @@ -822,7 +824,7 @@ Status Encapsulator::AddFunctionCallNodes( Status Encapsulator::FindOutputImageOfEdgeSrc( const string& src_func_id, const string& dst_func_id, - const std::unordered_map& 
node_images, + const absl::flat_hash_map& node_images, const Node* original_src_node, Node** src_image) { if (IsInSubgraph(src_func_id)) { // The edge is from a subgraph to a regular node in the output graph so @@ -853,7 +855,7 @@ int Encapsulator::FindOutputSlotOfEdgeSrc(const string& src_func_id, Status Encapsulator::FindOutputImageOfEdgeDst( const string& src_func_id, const string& dst_func_id, - const std::unordered_map& node_images, + const absl::flat_hash_map& node_images, const Node* original_dst_node, Node** dst_image) { if (IsInSubgraph(dst_func_id)) { // The edge is to a subgraph from a regular node in the output graph so @@ -884,9 +886,10 @@ int Encapsulator::FindOutputSlotOfEdgeDst(const string& src_func_id, Status Encapsulator::CopyEdgeToOutputGraph( const Edge* edge, const string& src_func_id, const string& dst_func_id, - const std::unordered_map& node_images, Graph* graph_out, - std::unordered_set, - OutputInputTensorPairHasher>* edges_added) { + const absl::flat_hash_map& node_images, + Graph* graph_out, + absl::flat_hash_set, + OutputInputTensorPairHasher>* edges_added) { Node* src_image; TF_RETURN_IF_ERROR(FindOutputImageOfEdgeSrc( src_func_id, dst_func_id, node_images, edge->src(), &src_image)); @@ -924,13 +927,13 @@ Status Encapsulator::CopyEdgeToOutputGraph( } Status Encapsulator::AddEdgesToOutputGraph( - const std::unordered_map& node_images, + const absl::flat_hash_map& node_images, Graph* graph_out) { // Set of edges already added to the output graph, represented as (src, dst) // pairs. We use the set to deduplicate edges; multiple edges in the input // graph may map to one edge in the output graph. 
- std::unordered_set, - OutputInputTensorPairHasher> + absl::flat_hash_set, + OutputInputTensorPairHasher> edges_added; for (const Edge* edge : graph_in_->edges()) { @@ -1010,7 +1013,7 @@ Node* AddDummyShapedNode(const Node* src_node, int src_port, Status Encapsulator::MakePrunedGraphCopyAndInline( const Graph& graph, const std::vector& sink_nodes, std::unique_ptr* pruned_graph, - std::unordered_map* node_images, + absl::flat_hash_map* node_images, FunctionLibraryDefinition* library) { // First copy all ancestor nodes of sink_nodes into a new graph. pruned_graph->reset(new Graph(library)); @@ -1070,7 +1073,7 @@ Status Encapsulator::MakePrunedGraphCopyAndInline( Status Encapsulator::BuildOutputGraph(Graph* graph_out, FunctionLibraryDefinition* library) { // Map from nodes in the input graph to nodes in the output graph. - std::unordered_map node_images; + absl::flat_hash_map node_images; TF_RETURN_IF_ERROR(CopyNodesToOutputGraph(graph_out, &node_images)); TF_RETURN_IF_ERROR(AddFunctionCallNodes(node_images, graph_out)); diff --git a/tensorflow/compiler/jit/flags.cc b/tensorflow/compiler/jit/flags.cc index 6e8fceaf47d..54c79d77ca8 100644 --- a/tensorflow/compiler/jit/flags.cc +++ b/tensorflow/compiler/jit/flags.cc @@ -34,6 +34,7 @@ BuildXlaOpsPassFlags* build_ops_flags; MarkForCompilationPassFlags* mark_for_compilation_flags; XlaDeviceFlags* device_flags; XlaOpsCommonFlags* ops_flags; +XlaCallModuleFlags* call_module_flags; MlirCommonFlags* mlir_flags; JitRtFlags* jitrt_flags; std::vector* jitrt_flag_list; @@ -76,6 +77,13 @@ bool SetterForXlaAutoJitFlag(const string& value) { return true; } +bool SetterForXlaCallModuleDisabledChecks(const string& value) { + auto directives = absl::StrSplit(value, ',', absl::SkipEmpty()); + call_module_flags->disabled_checks.insert(directives.begin(), + directives.end()); + return true; +} + void AppendMarkForCompilationPassFlagsInternal(std::vector* flag_list) { std::vector new_flags = { Flag("tf_xla_auto_jit", 
SetterForXlaAutoJitFlag, "0", @@ -184,6 +192,7 @@ void AllocateAndParseFlags() { build_ops_flags->tf_xla_check_cluster_input_numerics = false; build_ops_flags->tf_xla_check_cluster_output_numerics = false; build_ops_flags->tf_xla_disable_constant_folding = false; + build_ops_flags->tf_xla_disable_full_embedding_pipelining = false; mark_for_compilation_flags = new MarkForCompilationPassFlags; mark_for_compilation_flags->xla_auto_jit_flag.optimization_level_single_gpu = @@ -213,9 +222,12 @@ void AllocateAndParseFlags() { ops_flags = new XlaOpsCommonFlags; ops_flags->tf_xla_always_defer_compilation = false; ops_flags->tf_xla_async_compilation = false; - ops_flags->tf_xla_use_device_api.enabled_for_xla_launch_ = false; - ops_flags->tf_xla_use_device_api.enabled_for_compile_on_demand_ = false; + ops_flags->tf_xla_use_device_api.enabled_for_xla_launch_ = true; + ops_flags->tf_xla_use_device_api.enabled_for_compile_on_demand_ = true; + ops_flags->tf_xla_use_device_api.enabled_for_compile_and_run_ = false; + ops_flags->tf_xla_use_device_api.enabled_for_all_ = false; + call_module_flags = new XlaCallModuleFlags; // The `enable_mlir_bridge` flag allows the user to explicitly request that // their program is (or isn't) compiled using the MLIR-based TF-to-XLA bridge. 
// @@ -251,6 +263,10 @@ void AllocateAndParseFlags() { &build_ops_flags->tf_xla_disable_constant_folding, "If true then disables constant folding on TF graph before XLA " "compilation."), + Flag("tf_xla_disable_full_embedding_pipelining", + &build_ops_flags->tf_xla_disable_full_embedding_pipelining, + "If true then disables full embedding pipelining and instead use " + "strict SparseCore / TensorCore sequencing."), Flag("tf_xla_compile_on_demand", &device_flags->tf_xla_compile_on_demand, "Switch a device into 'on-demand' mode, where instead of " @@ -277,6 +293,22 @@ void AllocateAndParseFlags() { &ops_flags->tf_xla_use_device_api.enabled_for_compile_on_demand_, "If true, uses Device API (PjRt) for compiling and executing ops " "one by one in 'on-demand' mode. Defaults to false."), + Flag("tf_xla_use_device_api_for_auto_jit", + &ops_flags->tf_xla_use_device_api.enabled_for_compile_and_run_, + "If true, uses Device API (PjRt) for compilation and execution " + "when auto-clustering is enabled. Defaults to false."), + Flag("tf_xla_use_device_api", + &ops_flags->tf_xla_use_device_api.enabled_for_all_, + "If true, uses Device API (PjRt) for compilation and execution " + "of ops one-by-one in 'on-demand' mode, for functions marked for " + "JIT compilation, or when auto-clustering is enabled. Defaults to " + "false."), + + Flag("tf_xla_call_module_disabled_checks", + SetterForXlaCallModuleDisabledChecks, "", + "A comma-sepated list of directives specifying the safety checks " + "to be skipped when compiling XlaCallModuleOp. 
See the op " + "documentation for the recognized values."), Flag("tf_mlir_enable_mlir_bridge", &enable_mlir_bridge, "Enables experimental MLIR-Based TensorFlow Compiler Bridge.", @@ -365,6 +397,11 @@ XlaOpsCommonFlags* GetXlaOpsCommonFlags() { return ops_flags; } +XlaCallModuleFlags* GetXlaCallModuleFlags() { + absl::call_once(flags_init, &AllocateAndParseFlags); + return call_module_flags; +} + MlirCommonFlags* GetMlirCommonFlags() { absl::call_once(flags_init, &AllocateAndParseFlags); return mlir_flags; diff --git a/tensorflow/compiler/jit/flags.h b/tensorflow/compiler/jit/flags.h index 9f151b89eb7..042b3688fba 100644 --- a/tensorflow/compiler/jit/flags.h +++ b/tensorflow/compiler/jit/flags.h @@ -137,8 +137,9 @@ struct XlaOpsCommonFlags { } bool IsEnabledInXlaLaunchForDevice(const DeviceType& device_type) const { - return enabled_for_xla_launch_ && - xla_launch_allowed_devices_.contains(device_type.type_string()); + return enabled_for_all_ || + (enabled_for_xla_launch_ && + xla_launch_allowed_devices_.contains(device_type.type_string())); } // Allow using Device API (PjRt) for `device_type` in the XlaCompileOnDemand @@ -152,9 +153,26 @@ struct XlaOpsCommonFlags { bool IsEnabledInXlaCompileOnDemandForDevice( const DeviceType& device_type) const { - return enabled_for_compile_on_demand_ && - xla_compile_on_demand_allowed_devices_.contains( - device_type.type_string()); + return enabled_for_all_ || + (enabled_for_compile_on_demand_ && + xla_compile_on_demand_allowed_devices_.contains( + device_type.type_string())); + } + + // Allow using Device API (PjRt) for `device_type` in the XlaCompile and + // XlaRun ops. Please note that `enabled_for_compile_and_run_` needs to be + // true in addition to the `device_type` being allowed in order to use the + // Device API for single device compilation and execution in the XlaCompile + // and XlaRun ops. 
+ void AllowForDeviceInXlaCompileAndRun(const DeviceType& device_type) { + xla_compile_and_run_allowed_devices_.insert(device_type.type_string()); + } + + bool IsEnabledInXlaCompileAndRunForDevice( + const DeviceType& device_type) const { + return enabled_for_all_ || (enabled_for_compile_and_run_ && + xla_compile_and_run_allowed_devices_.contains( + device_type.type_string())); } // If true, uses Device API (PjRt) for single device compilation and @@ -166,6 +184,16 @@ struct XlaOpsCommonFlags { // one in "on-demand" mode. Defaults to false. bool enabled_for_compile_on_demand_; + // If true, uses Device API (PjRt) for compilation and execution when + // auto-clustering is enabled. Defaults to false. + bool enabled_for_compile_and_run_; + + // If true, uses Device API (PjRt) for compilation and execution everywhere + // i.e. for functions marked for JIT compilation, for ops in "on-demand" + // mode and autoclustering, no matter whether other flags are enabled or + // not, and whether devices have been allowed or not. Defaults to false. + bool enabled_for_all_; + private: // Devices for which using Device API (PjRt) is allowed in the XlaLaunch op. // This can only be modified programmatically. @@ -173,9 +201,18 @@ struct XlaOpsCommonFlags { // Devices for which using Device API (PjRt) is allowed in the // XlaCompileOnDemand op. This can only be modified programmatically. absl::flat_hash_set xla_compile_on_demand_allowed_devices_; + // Devices for which using Device API (PjRt) is allowed in the + // XlaCompile and XlaRun ops. This can only be modified programmatically. + absl::flat_hash_set xla_compile_and_run_allowed_devices_; } tf_xla_use_device_api; }; +// Flags for the XlaCallModule kernel. +struct XlaCallModuleFlags { + // Used by XlaCallModuleOp to specify safety checks to disable. + absl::flat_hash_set disabled_checks; +}; + // Flags for the build_xla_ops pass. 
struct BuildXlaOpsPassFlags { // Enables lazy compilation for TF/XLA (only when auto-clustering) if true. @@ -197,6 +234,10 @@ struct BuildXlaOpsPassFlags { // Disables all constant folding. The primary use for this is for testing to // guarantee that tests are run on XLA and not on TF's CPU implementation. bool tf_xla_disable_constant_folding; + + // Disables full embedding pipelining when true. Instead, strict SparseCore + // TensorCore sequencing will be used. + bool tf_xla_disable_full_embedding_pipelining; }; // Flags for common MLIR configurations. @@ -235,6 +276,7 @@ MarkForCompilationPassFlags* GetMarkForCompilationPassFlags(); BuildXlaOpsPassFlags* GetBuildXlaOpsPassFlags(); XlaDeviceFlags* GetXlaDeviceFlags(); XlaOpsCommonFlags* GetXlaOpsCommonFlags(); +XlaCallModuleFlags* GetXlaCallModuleFlags(); MlirCommonFlags* GetMlirCommonFlags(); diff --git a/tensorflow/compiler/jit/kernels/BUILD b/tensorflow/compiler/jit/kernels/BUILD index 8046207ed54..d651933a5d2 100644 --- a/tensorflow/compiler/jit/kernels/BUILD +++ b/tensorflow/compiler/jit/kernels/BUILD @@ -23,6 +23,8 @@ XLA_OPS_DEPS = [ "//tensorflow/compiler/jit:variable_info_util", "//tensorflow/compiler/jit:xla_device_no_jit_rewrite_registration", "//tensorflow/compiler/jit:xla_cluster_util", + "//tensorflow/compiler/jit:xla_host_recv_device_context", + "//tensorflow/compiler/jit:xla_host_send_device_context", "//tensorflow/compiler/jit:xla_launch_util", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:tf2xla_util", @@ -59,6 +61,8 @@ cc_library( "//tensorflow/compiler/jit:xla_compile_util", "//tensorflow/compiler/xla/pjrt:pjrt_client", "//tensorflow/core/platform:refcount", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", ], alwayslink = 1, ) diff --git a/tensorflow/compiler/jit/kernels/xla_ops.cc b/tensorflow/compiler/jit/kernels/xla_ops.cc index 913cca35be3..ff134d49c50 100644 --- a/tensorflow/compiler/jit/kernels/xla_ops.cc +++ 
b/tensorflow/compiler/jit/kernels/xla_ops.cc @@ -20,11 +20,15 @@ limitations under the License. #include #include #include +#include +#include #include #include #include #include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" #include "absl/types/optional.h" #include "tensorflow/compiler/jit/device_compilation_profiler.h" #include "tensorflow/compiler/jit/device_compiler.h" @@ -35,6 +39,8 @@ limitations under the License. #include "tensorflow/compiler/jit/xla_activity_listener.h" #include "tensorflow/compiler/jit/xla_compile_util.h" #include "tensorflow/compiler/jit/xla_compiler_options_util.h" +#include "tensorflow/compiler/jit/xla_host_recv_device_context.h" +#include "tensorflow/compiler/jit/xla_host_send_device_context.h" #include "tensorflow/compiler/jit/xla_launch_util.h" #include "tensorflow/compiler/jit/xla_platform_info.h" #include "tensorflow/compiler/tf2xla/tf2xla_util.h" @@ -92,10 +98,11 @@ auto* xla_launch_counter = monitoring::Counter<1>::New( // the initial values for the resource variables (and cannot snapshot them again // during execution) because otherwise we risk observing a different snapshot // with shapes different from what we compiled for. 
-class XlaExecutableClosure { +template +class ExecutableClosure { public: - explicit XlaExecutableClosure( - xla::LocalClient* client, xla::LocalExecutable* executable, + explicit ExecutableClosure( + ClientType* client, ExecutableType* executable, const XlaCompiler::CompilationResult* compilation_result, ResourceVarsSnapshot resource_var_snapshots, int num_constant_args) : client_(client), @@ -104,11 +111,11 @@ class XlaExecutableClosure { resource_var_snapshots_(std::move(resource_var_snapshots)), num_constant_args_(num_constant_args) {} - XlaExecutableClosure(XlaExecutableClosure&&) = default; - XlaExecutableClosure& operator=(XlaExecutableClosure&&) = default; + ExecutableClosure(ExecutableClosure&&) = default; + ExecutableClosure& operator=(ExecutableClosure&&) = default; - xla::LocalClient* client() const { return client_; } - xla::LocalExecutable* executable() const { return executable_; } + ClientType* client() const { return client_; } + ExecutableType* executable() const { return executable_; } const XlaCompiler::CompilationResult* compilation_result() const { return compilation_result_; } @@ -118,24 +125,25 @@ class XlaExecutableClosure { int num_constant_args() const { return num_constant_args_; } private: - xla::LocalClient* client_; - xla::LocalExecutable* executable_; + ClientType* client_; + ExecutableType* executable_; const XlaCompiler::CompilationResult* compilation_result_; ResourceVarsSnapshot resource_var_snapshots_; int num_constant_args_; - TF_DISALLOW_COPY_AND_ASSIGN(XlaExecutableClosure); + TF_DISALLOW_COPY_AND_ASSIGN(ExecutableClosure); }; -// This maintains a mapping from a globally unique ID to XlaExecutableClosure +// This maintains a mapping from a globally unique ID to ExecutableClosure // instances. 
-class XlaExecutableClosureStore { +template +class ExecutableClosureStore { public: - XlaExecutableClosureStore() : key_counter_(0) {} + ExecutableClosureStore() : key_counter_(0) {} using KeyT = string; - KeyT Produce(XlaExecutableClosure result) { + KeyT Produce(ExecutableClosure result) { mutex_lock l(mutex_); KeyT key = absl::StrCat(key_counter_++); bool insert_successful = closures_.emplace(key, std::move(result)).second; @@ -144,29 +152,38 @@ class XlaExecutableClosureStore { return key; } - XlaExecutableClosure Consume(const KeyT& key) { + ExecutableClosure Consume(const KeyT& key) { mutex_lock l(mutex_); auto it = closures_.find(key); DCHECK(it != closures_.end()); - XlaExecutableClosure value = std::move(it->second); + ExecutableClosure value = std::move(it->second); closures_.erase(it); return value; } - static XlaExecutableClosureStore* Global() { - static XlaExecutableClosureStore* instance = new XlaExecutableClosureStore; + static ExecutableClosureStore* Global() { + static ExecutableClosureStore* instance = new ExecutableClosureStore; return instance; } private: mutex mutex_; int64_t key_counter_ TF_GUARDED_BY(mutex_); - absl::flat_hash_map closures_ - TF_GUARDED_BY(mutex_); + absl::flat_hash_map> + closures_ TF_GUARDED_BY(mutex_); - TF_DISALLOW_COPY_AND_ASSIGN(XlaExecutableClosureStore); + TF_DISALLOW_COPY_AND_ASSIGN(ExecutableClosureStore); }; +using XlaExecutableClosure = + ExecutableClosure; +using XlaExecutableClosureStore = + ExecutableClosureStore; +using PjRtExecutableClosure = + ExecutableClosure; +using PjRtExecutableClosureStore = + ExecutableClosureStore; + se::Stream* GetStream(OpKernelContext* ctx) { return ctx->op_device_context() ? 
ctx->op_device_context()->stream() : nullptr; @@ -185,6 +202,111 @@ XlaComputationLaunchContext GetLaunchContext( return launch_context; } +Status GetTaskName(const std::string_view device_name, std::string* task_name) { + string ignored; + if (!DeviceNameUtils::SplitDeviceName(device_name, task_name, &ignored)) { + return errors::InvalidArgument("Unable to parse device name: ", + device_name); + } + + return OkStatus(); +} + +// Provide SendDeviceMemoryFunction for XLA host callbacks. This callback +// handles transferring from device to host. +xla::SendDeviceMemoryFunction GetSendDeviceMemoryFunction( + OpKernelContext* ctx) { + return + [ctx](int64_t channel_id, se::Stream* stream, const xla::Shape& shape, + const se::DeviceMemoryBase& device_memory_base, + const absl::flat_hash_map& frontend_attrs) + -> StatusOr> { + auto iter = frontend_attrs.find("_xla_host_transfer_rendezvous"); + + // Generate the Rendezvous key. + const std::string& rendezvous_key_base = iter->second; + const std::string& src_device = ctx->device()->name(); + + std::string task_prefix; + TF_RETURN_IF_ERROR(GetTaskName(src_device, &task_prefix)); + const std::string dst_device = + absl::StrCat(task_prefix, "/device:CPU:0"); + const std::string& rendezvous_key = + Rendezvous::CreateKey(src_device, /*src_incarnation=*/1, dst_device, + rendezvous_key_base, FrameAndIter(0, 0)); + VLOG(2) << "Rendezvous Key for receiving at host: " << rendezvous_key; + + RendezvousInterface::ParsedKey parsed_key; + TF_RETURN_IF_ERROR(Rendezvous::ParseKey(rendezvous_key, &parsed_key)); + + tsl::AsyncValueRef done_event = + tsl::MakeConstructedAsyncValueRef(stream->parent()); + if (!done_event->Init()) { + return errors::Internal( + "Failed to initialize done event (channel_id=%d)", channel_id); + } + + Rendezvous::Args args; + // Rendezvous::Args owns the device context pointer. 
+ args.device_context = new XlaHostRecvDeviceContext( + stream, device_memory_base, shape, done_event); + + Tensor host_tensor; + TF_RETURN_IF_ERROR( + ctx->rendezvous()->Send(parsed_key, args, host_tensor, false)); + + return std::move(done_event); + }; +} + +// Provide RecvDeviceMemoryFunction for XLA host callbacks. This callback +// handles transferring from host to device. +xla::RecvDeviceMemoryFunction GetRecvDeviceMemoryFunction( + OpKernelContext* ctx) { + return + [ctx](int64_t channel_id, se::Stream* stream, const xla::Shape& shape, + se::DeviceMemoryBase* device_memory_base, + const absl::flat_hash_map& frontend_attrs) + -> StatusOr> { + auto iter = frontend_attrs.find("_xla_host_transfer_rendezvous"); + + // Generate the Rendezvous key. + const std::string& rendezvous_key_base = iter->second; + const std::string& dst_device = ctx->device()->name(); + + std::string task_prefix; + TF_RETURN_IF_ERROR(GetTaskName(dst_device, &task_prefix)); + const std::string src_device = + absl::StrCat(task_prefix, "/device:CPU:0"); + const std::string& rendezvous_key = + Rendezvous::CreateKey(src_device, /*src_incarnation=*/1, dst_device, + rendezvous_key_base, FrameAndIter(0, 0)); + VLOG(2) << "Rendezvous Key for sending from host: " << rendezvous_key; + + RendezvousInterface::ParsedKey parsed_key; + TF_RETURN_IF_ERROR(Rendezvous::ParseKey(rendezvous_key, &parsed_key)); + + tsl::AsyncValueRef done_event = + tsl::MakeConstructedAsyncValueRef(stream->parent()); + if (!done_event->Init()) { + return errors::Internal( + "Failed to initialize done event (channel_id=%d)", channel_id); + } + + Rendezvous::Args args; + // Rendezvous::Args owns the device context pointer. 
+ args.device_context = new XlaHostSendDeviceContext( + stream, device_memory_base, shape, done_event); + + Tensor device_tensor; + bool is_dead; + TF_RETURN_IF_ERROR(ctx->rendezvous()->Recv( + parsed_key, args, &device_tensor, /*is_dead=*/&is_dead)); + + return std::move(done_event); + }; +} + StatusOr RunExecutable( const XlaPlatformInfo& platform_info, const XlaComputationLaunchContext& launch_context, @@ -200,6 +322,15 @@ StatusOr RunExecutable( run_options.set_allocator(allocator); run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device()); run_options.set_rng_seed(GetXLARandomSeed()); + + // Host callbacks used for HLO send/recv. + xla::SendDeviceMemoryFunction send_function = + GetSendDeviceMemoryFunction(ctx); + run_options.set_send_device_memory_function(&send_function); + xla::RecvDeviceMemoryFunction recv_function = + GetRecvDeviceMemoryFunction(ctx); + run_options.set_recv_device_memory_function(&recv_function); + StatusOr execution_output; bool run_synchronous = !stream || platform_info.platform_id() == se::host::kHostPlatformId; @@ -263,7 +394,7 @@ Status CompileToLocalExecutable( // in the ResourceMgr. ResourceMgr* rm = ctx->resource_manager(); if (!rm) { - return errors::Internal("No resource manager."); + return absl::InternalError("No resource manager."); } XlaDeviceCompiler* xla_device_compiler; @@ -312,23 +443,13 @@ Status CompileToPjRtLoadedExecutable( // in the ResourceMgr. 
ResourceMgr* rm = ctx.resource_manager(); if (!rm) { - return errors::Internal("No resource manager."); + return absl::InternalError("No resource manager."); } PjRtDeviceCompiler* pjrt_device_compiler; - TF_RETURN_IF_ERROR(rm->LookupOrCreate( - rm->default_container(), "pjrt_device_compiler", &pjrt_device_compiler, - [&](PjRtDeviceCompiler** pjrt_device_compiler) { - return BuildPjRtDeviceCompiler(platform_info, ctx.function_library(), - pjrt_device_compiler); - })); DeviceCompilationProfiler* profiler; - TF_RETURN_IF_ERROR(rm->LookupOrCreate( - rm->default_container(), "pjrt_device_compilation_profiler", &profiler, - [](DeviceCompilationProfiler** profiler) { - *profiler = new DeviceCompilationProfiler(); - return OkStatus(); - })); + TF_RETURN_IF_ERROR(GetOrCreatePjRtDeviceCompilerAndProfiler( + platform_info, ctx.function_library(), &pjrt_device_compiler, &profiler)); // Hold the reference to the PJRT device compiler and profiler during // evaluation. (We could probably free them sooner because the ResourceMgr // will retain references, but this is more obviously correct.) 
@@ -337,8 +458,9 @@ Status CompileToPjRtLoadedExecutable( *client = pjrt_device_compiler->client(); - XlaCompiler::Options options = GenerateCompilerOptionsForPjRt( - *ctx.function_library(), ctx.device(), platform_info); + XlaCompiler::Options options = + GenerateCompilerOptionsForPjRt(*ctx.function_library(), ctx.device(), + platform_info, pjrt_device_compiler); XlaCompiler::CompileOptions compile_options = GenerateCompileOptions(has_ref_vars, may_alias_resource_update); @@ -474,19 +596,23 @@ void XlaLocalLaunchBase::ComputeAsync(OpKernelContext* ctx, DoneCallback done) { compilation_result, done, inputs, resources = resources_]() { auto platform_info = XlaPlatformInfoFromDevice(ctx->device()); - std::vector variable_infos; - OP_REQUIRES_OK_ASYNC( - ctx, - GetUpdatedVariables(ctx, inputs, resources, *compilation_result, - &variable_infos), - done); - OP_REQUIRES_OK_ASYNC(ctx, LockVariables(absl::MakeSpan(variable_infos)), - done); - OP_REQUIRES_OK_ASYNC( - ctx, - RunPjRtExecutable(*pjrt_client, inputs, variable_infos, - *compilation_result, pjrt_executable, ctx), - done); + // Separate scope so that VariableInfo locks are released before done() is + // called. + { + std::vector variable_infos; + OP_REQUIRES_OK_ASYNC( + ctx, + GetUpdatedVariables(ctx, inputs, resources, *compilation_result, + &variable_infos), + done); + OP_REQUIRES_OK_ASYNC(ctx, LockVariables(absl::MakeSpan(variable_infos)), + done); + OP_REQUIRES_OK_ASYNC( + ctx, + RunPjRtExecutable(*pjrt_client, inputs, variable_infos, + *compilation_result, pjrt_executable, ctx), + done); + } VLOG(2) << "Done executing with PJRT."; done(); }; @@ -505,65 +631,69 @@ void XlaLocalLaunchBase::ComputeAsync(OpKernelContext* ctx, DoneCallback done) { // Continuation of the execution, may be run in a different thread. 
auto run_xla_cluster = [ctx, client, executable, compilation_result, done, inputs, resources = resources_]() { - auto platform_info = XlaPlatformInfoFromDevice(ctx->device()); - std::vector variable_infos; - OP_REQUIRES_OK_ASYNC( - ctx, - GetUpdatedVariables(ctx, inputs, resources, *compilation_result, - &variable_infos), - done); - OP_REQUIRES_OK_ASYNC(ctx, LockVariables(absl::MakeSpan(variable_infos)), - done); - std::map resource_var_ptrs; - for (int i = 0; i < resources.size(); i++) { - resource_var_ptrs[resources[i]] = variable_infos[i].var()->tensor(); - } - - std::shared_ptr allocator = - GetAllocator(ctx->device(), GetStream(ctx), platform_info); - XlaComputationLaunchContext launch_context = - GetLaunchContext(platform_info, ctx, client, allocator.get()); - - const xla::HloInputOutputAliasConfig& input_output_alias = - executable->executable()->module().input_output_alias_config(); - StatusOr> execution_inputs = - launch_context.PopulateInputs( - ctx, compilation_result, resource_var_ptrs, - /*missing_ctx_input_prefix=*/0, input_output_alias); - OP_REQUIRES_OK_ASYNC(ctx, execution_inputs.status(), done); - - xla::gpu::GpuExecutableRunOptions gpu_options; - xla::DeviceAssignment device_assignment; - xla::ExecutableRunOptions run_options; - if (compilation_result->collective_info.has_value()) { + // Separate scope so that VariableInfo locks are released before done is + // called. 
+ { + auto platform_info = XlaPlatformInfoFromDevice(ctx->device()); + std::vector variable_infos; OP_REQUIRES_OK_ASYNC( ctx, - ResolveDeviceAssignment(ctx, *compilation_result->collective_info, - run_options, device_assignment, gpu_options), + GetUpdatedVariables(ctx, inputs, resources, *compilation_result, + &variable_infos), done); + OP_REQUIRES_OK_ASYNC(ctx, LockVariables(absl::MakeSpan(variable_infos)), + done); + std::map resource_var_ptrs; + for (int i = 0; i < resources.size(); i++) { + resource_var_ptrs[resources[i]] = variable_infos[i].var()->tensor(); + } + + std::shared_ptr allocator = + GetAllocator(ctx->device(), GetStream(ctx), platform_info); + XlaComputationLaunchContext launch_context = + GetLaunchContext(platform_info, ctx, client, allocator.get()); + + const xla::HloInputOutputAliasConfig& input_output_alias = + executable->executable()->module().input_output_alias_config(); + StatusOr> execution_inputs = + launch_context.PopulateInputs( + ctx, compilation_result, resource_var_ptrs, + /*missing_ctx_input_prefix=*/0, input_output_alias); + OP_REQUIRES_OK_ASYNC(ctx, execution_inputs.status(), done); + + xla::gpu::GpuExecutableRunOptions gpu_options; + xla::DeviceAssignment device_assignment; + xla::ExecutableRunOptions run_options; + if (compilation_result->collective_info.has_value()) { + OP_REQUIRES_OK_ASYNC(ctx, + ResolveDeviceAssignment( + ctx, *compilation_result->collective_info, + run_options, device_assignment, gpu_options), + done); + } + + // Hardcode run id to always be zero: TF distributed strategy + // differentiates between subsequent runs using dependency edges. This + // is safe, as only TF dist-strat can produce distributed ops, and we + // can rely on TF dist-strat invariants. 
+ xla::RunId run_id(0); + run_options.set_run_id(run_id); + + StatusOr execution_output = RunExecutable( + platform_info, launch_context, std::move(*execution_inputs), + run_options, executable, ctx, allocator.get()); + OP_REQUIRES_ASYNC(ctx, execution_output.ok(), execution_output.status(), + done); + + OP_REQUIRES_OK_ASYNC( + ctx, + launch_context.PopulateOutputs( + ctx, compilation_result, execution_output->ConsumeResult(), + /*missing_ctx_input_prefix=*/0, absl::MakeSpan(variable_infos), + input_output_alias, resource_var_ptrs), + done); + VLOG(1) << "Done"; } - - // Hardcode run id to always be zero: TF distributed strategy - // differentiates between subsequent runs using dependency edges. This - // is safe, as only TF dist-strat can produce distributed ops, and we - // can rely on TF dist-strat invariants. - xla::RunId run_id(0); - run_options.set_run_id(run_id); - - StatusOr execution_output = RunExecutable( - platform_info, launch_context, std::move(*execution_inputs), - run_options, executable, ctx, allocator.get()); - OP_REQUIRES_ASYNC(ctx, execution_output.ok(), execution_output.status(), - done); - - OP_REQUIRES_OK_ASYNC( - ctx, - launch_context.PopulateOutputs( - ctx, compilation_result, execution_output->ConsumeResult(), - /*missing_ctx_input_prefix=*/0, absl::MakeSpan(variable_infos), - input_output_alias, resource_var_ptrs), - done); - VLOG(1) << "Done"; done(); }; @@ -658,9 +788,11 @@ XlaCompileOp::XlaCompileOp(OpKernelConstruction* ctx) void XlaCompileOp::Compute(OpKernelContext* ctx) { VLOG(3) << "XlaCompileOp " << def().name() << (must_compile_ ? 
"(must-compile)" : ""); - xla::LocalClient* client; - const XlaCompiler::CompilationResult* kernel; - xla::LocalExecutable* executable; + const XlaCompiler::CompilationResult* kernel = nullptr; + xla::LocalClient* client = nullptr; + xla::LocalExecutable* executable = nullptr; + xla::PjRtClient* pjrt_client = nullptr; + xla::PjRtLoadedExecutable* pjrt_executable = nullptr; ResourceVarsSnapshot variables_snapshot; std::vector inputs = InputsFromContext(ctx); @@ -678,6 +810,11 @@ void XlaCompileOp::Compute(OpKernelContext* ctx) { : DeviceCompileMode::kLazy; }(); + bool use_pjrt = + GetXlaOpsCommonFlags() + ->tf_xla_use_device_api.IsEnabledInXlaCompileAndRunForDevice( + platform_info_.device_type()); + if (GetXlaOpsCommonFlags()->tf_xla_always_defer_compilation || cannot_compile_cluster) { executable = nullptr; @@ -691,22 +828,33 @@ void XlaCompileOp::Compute(OpKernelContext* ctx) { // Do not alias resource updates as locking variables in XlaCompile and // unlocking them in XlaRun may lead to deadlocks. - const Status status = CompileToLocalExecutable( - ctx, function_, has_ref_vars_, platform_info_, args, compile_mode, - /*may_alias_resource_update=*/false, &client, &kernel, &executable); + Status status; + if (use_pjrt) { + VLOG(2) << "Using PJRT for compilation. Function name: " + << function_.name(); + status = CompileToPjRtLoadedExecutable( + *ctx, platform_info_, function_, args, compile_mode, has_ref_vars_, + /*may_alias_resource_update=*/false, &kernel, &pjrt_client, + &pjrt_executable); + } else { + status = CompileToLocalExecutable( + ctx, function_, has_ref_vars_, platform_info_, args, compile_mode, + /*may_alias_resource_update=*/false, &client, &kernel, &executable); + } if (compile_mode != DeviceCompileMode::kLazy || status.code() != error::UNIMPLEMENTED) { OP_REQUIRES_OK(ctx, status); } if (status.code() == error::UNIMPLEMENTED) { - LOG(WARNING) << "Compilation failed:" << status.ToString() + LOG(WARNING) << "Compilation failed:" << status << ". 
Falling back to TF function call."; BroadcastOptimizationRemark( XlaOptimizationRemark::UNIMPLEMENTED_OPERATION, status.ToString()) .IgnoreError(); executable = nullptr; + pjrt_executable = nullptr; mutex_lock guard(cannot_compile_cluster_mu_); cannot_compile_cluster_ = true; } @@ -718,28 +866,36 @@ void XlaCompileOp::Compute(OpKernelContext* ctx) { Allocator* cpu_allocator = ctx->device()->GetAllocator(host_alloc_attrs); // Async compilation returns nullptr executable without an error. - if (!executable) { + if (!executable && !pjrt_executable) { DCHECK(!must_compile_); Tensor compilation_key(cpu_allocator, DT_STRING, TensorShape({})); - Tensor compilation_successful(cpu_allocator, DT_BOOL, TensorShape({})); compilation_successful.scalar()() = false; - ctx->set_output(0, Tensor(cpu_allocator, DT_STRING, TensorShape({}))); + ctx->set_output(0, compilation_key); ctx->set_output(1, compilation_successful); return; } - // Each execution of an XlaCompile op creates a new XlaExecutableClosure, even + // Each execution of an XlaCompile op creates a new ExecutableClosure, even // if it didn't have to compile the cluster because of a compilation-cache // hit. This is because we at least need new snapshots of the resource // variables. - XlaExecutableClosureStore::KeyT key = - XlaExecutableClosureStore::Global()->Produce(XlaExecutableClosure( - client, executable, kernel, std::move(variables_snapshot), - constants_.size())); - Tensor compilation_key(cpu_allocator, DT_STRING, TensorShape({})); - compilation_key.flat()(0) = key; + if (use_pjrt) { + PjRtExecutableClosureStore::KeyT key = + PjRtExecutableClosureStore::Global()->Produce(PjRtExecutableClosure( + pjrt_client, pjrt_executable, kernel, std::move(variables_snapshot), + constants_.size())); + compilation_key.flat()(0) = key; + VLOG(2) << "Compiled with PJRT. 
compilation_key: " << key; + } else { + XlaExecutableClosureStore::KeyT key = + XlaExecutableClosureStore::Global()->Produce(XlaExecutableClosure( + client, executable, kernel, std::move(variables_snapshot), + constants_.size())); + compilation_key.flat()(0) = key; + VLOG(2) << "Compiled with XLA. compilation_key: " << key; + } Tensor compilation_successful(cpu_allocator, DT_BOOL, TensorShape({})); compilation_successful.flat()(0) = true; diff --git a/tensorflow/compiler/jit/ops/BUILD b/tensorflow/compiler/jit/ops/BUILD index e70b5c2525d..1059a263d57 100644 --- a/tensorflow/compiler/jit/ops/BUILD +++ b/tensorflow/compiler/jit/ops/BUILD @@ -24,5 +24,5 @@ py_library( name = "xla_ops_grad", srcs = ["xla_ops_grad.py"], srcs_version = "PY3", - deps = ["//tensorflow/python:framework_ops"], + deps = ["//tensorflow/python/framework:ops"], ) diff --git a/tensorflow/compiler/jit/tests/BUILD b/tensorflow/compiler/jit/tests/BUILD index 17e47dd9a81..c74ea677fcd 100644 --- a/tensorflow/compiler/jit/tests/BUILD +++ b/tensorflow/compiler/jit/tests/BUILD @@ -77,8 +77,8 @@ cc_library( "//tensorflow/core:lib_internal", "//tensorflow/core:ops", "//tensorflow/core:test", - "//tensorflow/core:test_main", "//tensorflow/core/platform:path", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", ], ) diff --git a/tensorflow/compiler/jit/tests/device_compiler_test_helper.h b/tensorflow/compiler/jit/tests/device_compiler_test_helper.h index edb6be6a0ff..e8ae70928d1 100644 --- a/tensorflow/compiler/jit/tests/device_compiler_test_helper.h +++ b/tensorflow/compiler/jit/tests/device_compiler_test_helper.h @@ -19,6 +19,7 @@ limitations under the License. 
#include #include +#include "absl/status/status.h" #include "absl/strings/string_view.h" #include "tensorflow/compiler/jit/xla_activity_listener.h" #include "tensorflow/core/graph/graph_def_builder.h" @@ -51,7 +52,7 @@ class JitCompilationListener : public XlaActivityListener { bool expect_persistent_cache_use) { for (const auto& activity : activity_history_) { if (activity.used_persistent_cache() != expect_persistent_cache_use) { - return errors::FailedPrecondition("Unexpected listener history."); + return absl::FailedPreconditionError("Unexpected listener history."); } } return OkStatus(); diff --git a/tensorflow/compiler/jit/tf_graph_to_hlo_compiler.cc b/tensorflow/compiler/jit/tf_graph_to_hlo_compiler.cc index 3f96a0f2aa9..052ed6b6f38 100644 --- a/tensorflow/compiler/jit/tf_graph_to_hlo_compiler.cc +++ b/tensorflow/compiler/jit/tf_graph_to_hlo_compiler.cc @@ -23,14 +23,15 @@ Status TfGraphToHloCompiler::Compile(const XlaCompiler::CompileOptions& options, const NameAttrList& function, absl::Span args, XlaCompilationResult* result) { - return xla_compiler_.CompileFunction(options, function, args, result); + return ADD_SOURCE_LOCATION( + xla_compiler_.CompileFunction(options, function, args, result)); } Status TfGraphToHloCompiler::CompileSingleOp( const XlaCompiler::CompileOptions& options, const OpKernelContext* ctx, absl::Span args, XlaCompilationResult* result) { - return xla_compiler_.CompileSingleOp( - options, XlaCompiler::SingleOpCompileArgument(*ctx), args, result); + return ADD_SOURCE_LOCATION(xla_compiler_.CompileSingleOp( + options, XlaCompiler::SingleOpCompileArgument(*ctx), args, result)); } } // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc index f6bdaf4e0bc..010ce8bd7c2 100644 --- a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc +++ b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc @@ -19,6 +19,7 @@ limitations under the License. 
#include #include +#include #include #include @@ -43,6 +44,7 @@ limitations under the License. #include "tensorflow/compiler/xla/executable_run_options.h" #include "tensorflow/compiler/xla/hlo/ir/hlo_input_output_alias_config.h" #include "tensorflow/compiler/xla/pjrt/pjrt_client.h" +#include "tensorflow/compiler/xla/pjrt/tf_pjrt_client.h" #include "tensorflow/compiler/xla/service/executable.h" #include "tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h" #include "tensorflow/core/framework/function.h" @@ -53,6 +55,7 @@ limitations under the License. #include "tensorflow/core/lib/core/refcount.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/statusor.h" +#include "tensorflow/core/tfrt/common/pjrt_util.h" #include "tensorflow/tsl/platform/errors.h" namespace tensorflow { @@ -169,28 +172,12 @@ Status XlaCompileOnDemandOp::Compile( DeviceCompilationProfiler** profiler, const XlaCompiler::CompilationResult** result, xla::PjRtLoadedExecutable** executable) { - // We store information about the JIT-compiled XLA computation - // in the ResourceMgr. 
- ResourceMgr* rm = ctx->resource_manager(); - if (!rm) { - return errors::Internal("No resource manager."); - } + TF_RETURN_IF_ERROR(GetOrCreatePjRtDeviceCompilerAndProfiler( + platform_info_, ctx->function_library(), pjrt_device_compiler, profiler)); - TF_RETURN_IF_ERROR(rm->LookupOrCreate( - rm->default_container(), "pjrt_device_compiler", pjrt_device_compiler, - [&](PjRtDeviceCompiler** pjrt_device_compiler) { - return BuildPjRtDeviceCompiler(platform_info_, ctx->function_library(), - pjrt_device_compiler); - })); - TF_RETURN_IF_ERROR(rm->LookupOrCreate( - rm->default_container(), "pjrt_device_compilation_profiler", profiler, - [](DeviceCompilationProfiler** profiler) { - *profiler = new DeviceCompilationProfiler(); - return OkStatus(); - })); - - XlaCompiler::Options options = GenerateCompilerOptionsForPjRt( - *(ctx->function_library()), ctx->device(), platform_info_); + XlaCompiler::Options options = + GenerateCompilerOptionsForPjRt(*(ctx->function_library()), ctx->device(), + platform_info_, *pjrt_device_compiler); // No detailed logging for on demand op. options.detailed_logging = false; XlaCompiler::CompileOptions compile_options = GetCompileOptions(true); diff --git a/tensorflow/compiler/jit/xla_compile_util.cc b/tensorflow/compiler/jit/xla_compile_util.cc index e5256a8b2c9..6a3e43f4a94 100644 --- a/tensorflow/compiler/jit/xla_compile_util.cc +++ b/tensorflow/compiler/jit/xla_compile_util.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/jit/xla_compile_util.h" #include +#include #include #include "tensorflow/compiler/jit/flags.h" @@ -24,6 +25,11 @@ limitations under the License. 
#include "tensorflow/core/util/determinism.h" namespace tensorflow { +namespace { +constexpr const char* kPjRtDeviceCompilerResourceName = "pjrt_device_compiler"; +constexpr const char* kPjRtDeviceCompilationProfilerResourceName = + "pjrt_device_compilation_profiler"; +} // namespace StatusOr> CreateSingleOpGraph( const NodeDef& node_def, absl::Span args, @@ -69,7 +75,18 @@ StatusOr> CreateSingleOpGraph( bool UsePjRtForSingleDeviceCompilation(const DeviceType& device_type) { const auto& rollout_config = GetXlaOpsCommonFlags()->tf_xla_use_device_api; return rollout_config.IsEnabledInXlaLaunchForDevice(device_type) || - rollout_config.IsEnabledInXlaCompileOnDemandForDevice(device_type); + rollout_config.IsEnabledInXlaCompileOnDemandForDevice(device_type) || + rollout_config.IsEnabledInXlaCompileAndRunForDevice(device_type); } +std::string GetPjRtDeviceCompilerResourceName(const DeviceType& device_type) { + return absl::StrCat(kPjRtDeviceCompilerResourceName, "_", + device_type.type_string()); +} + +std::string GetPjRtDeviceCompilationProfilerResourceName( + const DeviceType& device_type) { + return absl::StrCat(kPjRtDeviceCompilationProfilerResourceName, "_", + device_type.type_string()); +} } // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_compile_util.h b/tensorflow/compiler/jit/xla_compile_util.h index 345c55a86e5..d555738d4c3 100644 --- a/tensorflow/compiler/jit/xla_compile_util.h +++ b/tensorflow/compiler/jit/xla_compile_util.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_JIT_XLA_COMPILE_UTIL_H_ #include +#include #include "tensorflow/compiler/tf2xla/xla_argument.h" #include "tensorflow/core/graph/graph.h" @@ -47,6 +48,14 @@ StatusOr> CreateSingleOpGraph( // Checks if single device compilation and execution with PJRT is enabled for // `device_type` in either the XlaLaunch op or the XlaCompileOnDemand op. 
bool UsePjRtForSingleDeviceCompilation(const DeviceType& device_type); + +// Gets the resource name of the PjRt DeviceCompiler for `device_type`. +std::string GetPjRtDeviceCompilerResourceName(const DeviceType& device_type); + +// Gets the resource name of the DeviceCompilationProfiler for `device_type` +// when PjRt is used for compilation and execution. +std::string GetPjRtDeviceCompilationProfilerResourceName( + const DeviceType& device_type); } // namespace tensorflow #endif // TENSORFLOW_COMPILER_JIT_XLA_COMPILE_UTIL_H_ diff --git a/tensorflow/compiler/jit/xla_compile_util_test.cc b/tensorflow/compiler/jit/xla_compile_util_test.cc index 9fc706fb649..7e55498ec42 100644 --- a/tensorflow/compiler/jit/xla_compile_util_test.cc +++ b/tensorflow/compiler/jit/xla_compile_util_test.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/core/framework/fake_input.h" #include "tensorflow/core/kernels/ops_testutil.h" +#include "tensorflow/core/tpu/tpu_defs.h" namespace tensorflow { namespace { @@ -118,5 +119,31 @@ TEST(XlaCompileUtilTest, PjRtXlaCompileOnDemandFlagTest) { EXPECT_FALSE(UsePjRtForSingleDeviceCompilation(DeviceType(DEVICE_CPU))); } +TEST(XlaCompileUtilTest, PjRtDeviceCompilerResourceName) { + EXPECT_EQ(GetPjRtDeviceCompilerResourceName(DeviceType(DEVICE_TPU)), + "pjrt_device_compiler_TPU"); + EXPECT_EQ(GetPjRtDeviceCompilerResourceName(DeviceType(DEVICE_TPU_NODE)), + "pjrt_device_compiler_TPU"); + EXPECT_EQ(GetPjRtDeviceCompilerResourceName(DeviceType(DEVICE_CPU)), + "pjrt_device_compiler_CPU"); + EXPECT_EQ(GetPjRtDeviceCompilerResourceName(DeviceType(DEVICE_GPU)), + "pjrt_device_compiler_GPU"); +} + +TEST(XlaCompileUtilTest, PjRtDeviceCompilationProfilerResourceName) { + EXPECT_EQ( + GetPjRtDeviceCompilationProfilerResourceName(DeviceType(DEVICE_TPU)), + "pjrt_device_compilation_profiler_TPU"); + EXPECT_EQ( + GetPjRtDeviceCompilationProfilerResourceName(DeviceType(DEVICE_TPU_NODE)), + 
"pjrt_device_compilation_profiler_TPU"); + EXPECT_EQ( + GetPjRtDeviceCompilationProfilerResourceName(DeviceType(DEVICE_CPU)), + "pjrt_device_compilation_profiler_CPU"); + EXPECT_EQ( + GetPjRtDeviceCompilationProfilerResourceName(DeviceType(DEVICE_GPU)), + "pjrt_device_compilation_profiler_GPU"); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_compiler_options_util.cc b/tensorflow/compiler/jit/xla_compiler_options_util.cc index 8580bcfbeef..1ba962380d8 100644 --- a/tensorflow/compiler/jit/xla_compiler_options_util.cc +++ b/tensorflow/compiler/jit/xla_compiler_options_util.cc @@ -15,10 +15,14 @@ limitations under the License. #include "tensorflow/compiler/jit/xla_compiler_options_util.h" +#include "tensorflow/compiler/xla/pjrt/pjrt_client.h" + namespace tensorflow { namespace { using XlaDeviceCompiler = DeviceCompiler; +using PjRtDeviceCompiler = + DeviceCompiler; inline void LogOptions(const XlaCompiler::Options& options) { VLOG(2) << "XlaCompiler::Options[device_type=" << options.device_type @@ -81,7 +85,8 @@ XlaCompiler::Options GenerateCompilerOptionsForTfrtTpu( XlaCompiler::Options GenerateCompilerOptionsForPjRt( const FunctionLibraryRuntime& function_library, - const DeviceBase* device_base, const XlaPlatformInfo& platform_info) { + const DeviceBase* device_base, const XlaPlatformInfo& platform_info, + const PjRtDeviceCompiler* pjrt_device_compiler) { XlaCompiler::Options options; options.device_ordinal = device_base->parsed_name().id; options.flib_def = function_library.GetFunctionLibraryDefinition(); @@ -96,8 +101,9 @@ XlaCompiler::Options GenerateCompilerOptionsForPjRt( options.device_type = metadata->jit_device_type(); options.shape_determination_fns = metadata->default_shape_determination_fns(); + } else if (pjrt_device_compiler != nullptr) { + options.device_type = pjrt_device_compiler->device_type(); } - // TODO(b/255826209): Set options for non-XLA devices once PjRt supports them. 
// TODO(b/255826209): Confirm below options are correctly set after testing. options.allow_cpu_custom_calls = false; options.alias_passthrough_params = false; diff --git a/tensorflow/compiler/jit/xla_compiler_options_util.h b/tensorflow/compiler/jit/xla_compiler_options_util.h index 1be63a6dc8b..1c70b91c8f5 100644 --- a/tensorflow/compiler/jit/xla_compiler_options_util.h +++ b/tensorflow/compiler/jit/xla_compiler_options_util.h @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/compiler/jit/xla_platform_info.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/pjrt/pjrt_client.h" namespace tensorflow { @@ -39,9 +40,12 @@ XlaCompiler::Options GenerateCompilerOptionsForTfrtTpu( // Returns created options for XLA compiler when PjRt (Device API) is used for // compilation and execution. +// TODO(b/255826209): Remove default arg once PjRtCompileOnDemand op is deleted. XlaCompiler::Options GenerateCompilerOptionsForPjRt( const FunctionLibraryRuntime& function_library, - const DeviceBase* device_base, const XlaPlatformInfo& platform_info); + const DeviceBase* device_base, const XlaPlatformInfo& platform_info, + const DeviceCompiler* + pjrt_device_compiler = nullptr); } // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_compiler_options_util_test.cc b/tensorflow/compiler/jit/xla_compiler_options_util_test.cc index 2a4742567e4..1ab03bc7444 100644 --- a/tensorflow/compiler/jit/xla_compiler_options_util_test.cc +++ b/tensorflow/compiler/jit/xla_compiler_options_util_test.cc @@ -23,12 +23,14 @@ limitations under the License. 
#include #include "tensorflow/compiler/jit/flags.h" +#include "tensorflow/compiler/jit/pjrt_device_compiler_client.h" #include "tensorflow/compiler/jit/test_util.h" #include "tensorflow/compiler/jit/xla_platform_info.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/client_library.h" +#include "tensorflow/compiler/xla/pjrt/pjrt_client.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/device.h" #include "tensorflow/core/framework/device_base.h" @@ -42,18 +44,31 @@ using XlaDeviceCompiler = DeviceCompiler; using XlaDeviceExecutablePersistor = DeviceExecutablePersistor; +using PjRtDeviceCompiler = + DeviceCompiler; +using PjRtDeviceExecutablePersistor = + DeviceExecutablePersistor; -XlaDeviceCompiler* CreateXlaDeviceCompiler( - const XlaDeviceExecutablePersistor::Config& persistor_config, - DeviceType device_type, xla::LocalClient* local_client) { +XlaDeviceCompiler* CreateXlaDeviceCompiler(DeviceType device_type, + xla::LocalClient* local_client) { auto persistor = std::make_unique( - std::move(persistor_config), device_type); + XlaDeviceExecutablePersistor::Config(), device_type); auto compiler_client = std::make_unique(local_client); return new XlaDeviceCompiler(std::move(persistor), std::move(compiler_client)); } +PjRtDeviceCompiler* CreatePjRtDeviceCompiler(DeviceType device_type, + xla::PjRtClient* pjrt_client) { + auto persistor = std::make_unique( + PjRtDeviceExecutablePersistor::Config(), device_type); + auto compiler_client = + std::make_unique(pjrt_client); + return new PjRtDeviceCompiler(std::move(persistor), + std::move(compiler_client)); +} + std::vector GetShapeDeterminationFns() { XlaHelpers::ShapeRepresentationFn shape_representation_fn = @@ -160,6 +175,45 @@ TEST_F(XlaCompilerOptionsTest, PjRtOptionsPjRtBaseDevice) { 
tensorflow::XlaLayoutPreference::kTpuPreferLinearLayout); } +TEST_F(XlaCompilerOptionsTest, PjRtOptionsNonXlaDevice) { + device_setup_.AddDevicesAndSetUp({DEVICE_CPU}); + Device* device = device_setup_.GetDevice(DEVICE_CPU); + DeviceType compilation_device_type = DeviceType(DEVICE_CPU_XLA_JIT); + + XlaPlatformInfo platform_info(compilation_device_type, + /*platform_id=*/nullptr, + /*xla_device_metadata=*/nullptr, + /*pjrt_device_metadata=*/nullptr, + /*device_allocator=*/nullptr); + + auto pjrt_device_compiler = + CreatePjRtDeviceCompiler(compilation_device_type, nullptr); + core::ScopedUnref pjrt_device_compiler_ref(pjrt_device_compiler); + + XlaCompiler::Options options = GenerateCompilerOptionsForPjRt( + *device_setup_.flr(), device, platform_info, pjrt_device_compiler); + + EXPECT_EQ(options.device_type, compilation_device_type); + EXPECT_EQ(options.device_ordinal, 0); + EXPECT_NE(options.flib_def, nullptr); + EXPECT_EQ(options.graph_def_version, TF_GRAPH_DEF_VERSION); + EXPECT_FALSE(options.allow_cpu_custom_calls); + EXPECT_FALSE(options.alias_passthrough_params); + EXPECT_FALSE(options.detailed_logging); + // Check whether options have default shape determination functions set. 
+ TF_ASSERT_OK_AND_ASSIGN( + auto shape, options.shape_determination_fns.shape_representation_fn( + TensorShape(), DT_FLOAT, false, + tensorflow::XlaLayoutPreference::kNoPreference)); + xla::ShapeProto shape_proto; + shape_proto.set_element_type(xla::PrimitiveType::F32); + shape_proto.mutable_layout(); + EXPECT_EQ(shape, xla::Shape(shape_proto)); + EXPECT_EQ(options.shape_determination_fns.layout_preference_fn( + TensorShape(), DT_FLOAT, std::nullopt), + tensorflow::XlaLayoutPreference::kNoPreference); +} + TEST_F(XlaCompilerOptionsTest, XlaOptions) { device_setup_.AddDevicesAndSetUp({DEVICE_XLA_GPU}); Device* device = device_setup_.GetDevice(DEVICE_XLA_GPU); @@ -168,8 +222,8 @@ TEST_F(XlaCompilerOptionsTest, XlaOptions) { DeviceType device_type = DeviceType(DEVICE_XLA_GPU); DeviceType compilation_device_type = DeviceType(DEVICE_GPU_XLA_JIT); - auto xla_device_compiler = CreateXlaDeviceCompiler( - XlaDeviceExecutablePersistor::Config(), compilation_device_type, client); + auto xla_device_compiler = + CreateXlaDeviceCompiler(compilation_device_type, client); core::ScopedUnref xla_device_compiler_ref(xla_device_compiler); se::Platform::Id platform_id = se::host::kHostPlatformId; @@ -208,8 +262,8 @@ TEST_F(XlaCompilerOptionsTest, XlaOptionsHasRefVarsNoXlaDeviceMetadata) { DeviceType device_type = DeviceType(DEVICE_CPU); DeviceType compilation_device_type = DeviceType(DEVICE_CPU_XLA_JIT); - auto xla_device_compiler = CreateXlaDeviceCompiler( - XlaDeviceExecutablePersistor::Config(), compilation_device_type, client); + auto xla_device_compiler = + CreateXlaDeviceCompiler(compilation_device_type, client); core::ScopedUnref xla_device_compiler_ref(xla_device_compiler); se::Platform::Id platform_id = se::host::kHostPlatformId; @@ -249,8 +303,8 @@ TEST_F(XlaCompilerOptionsTest, TfRtTpuOptions) { xla::LocalClient* client = xla::ClientLibrary::LocalClientOrDie(); DeviceType compilation_device_type = DeviceType(DEVICE_TPU_XLA_JIT); - auto xla_device_compiler = 
CreateXlaDeviceCompiler( - XlaDeviceExecutablePersistor::Config(), compilation_device_type, client); + auto xla_device_compiler = + CreateXlaDeviceCompiler(compilation_device_type, client); core::ScopedUnref xla_device_compiler_ref(xla_device_compiler); XlaCompiler::Options options = GenerateCompilerOptionsForTfrtTpu( diff --git a/tensorflow/compiler/jit/xla_host_recv_device_context.cc b/tensorflow/compiler/jit/xla_host_recv_device_context.cc new file mode 100644 index 00000000000..b634ac88739 --- /dev/null +++ b/tensorflow/compiler/jit/xla_host_recv_device_context.cc @@ -0,0 +1,49 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/compiler/jit/xla_host_recv_device_context.h" + +#include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/type_util.h" + +namespace tensorflow { + +void XlaHostRecvDeviceContext::CopyDeviceTensorToCPU( + const Tensor* device_tensor, StringPiece tensor_name, Device* device, + Tensor* cpu_tensor, StatusCallback done) { + DataType dtype = EncodePrimitiveTypeAsDataType(shape_.element_type()).value(); + TensorShape tensor_shape; + Status status = XLAShapeToTensorShape(shape_, &tensor_shape); + if (!status.ok()) { + done(status); + return; + } + + *cpu_tensor = Tensor(dtype, tensor_shape); + + stream_->ThenMemcpy(cpu_tensor->data(), device_memory_base_, + device_memory_base_.size()); + stream_->ThenRecordEvent(&done_event_.get()); + if (auto st = stream_->BlockHostUntilDone(); !st.ok()) { + done_event_.SetError(absl::InternalError(absl::StrFormat( + "failed to synchronize send operation with a stream: %s", + st.ToString()))); + return; + } + + done_event_.SetStateConcrete(); + done(OkStatus()); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_host_recv_device_context.h b/tensorflow/compiler/jit/xla_host_recv_device_context.h new file mode 100644 index 00000000000..e2c5d1767d1 --- /dev/null +++ b/tensorflow/compiler/jit/xla_host_recv_device_context.h @@ -0,0 +1,92 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_JIT_XLA_HOST_RECV_DEVICE_CONTEXT_H_ +#define TENSORFLOW_COMPILER_JIT_XLA_HOST_RECV_DEVICE_CONTEXT_H_ + +#include "tensorflow/compiler/xla/shape.h" +#include "tensorflow/compiler/xla/stream_executor/device_memory.h" +#include "tensorflow/compiler/xla/stream_executor/stream.h" +#include "tensorflow/core/framework/device_base.h" +#include "tfrt/concurrency/async_value_ref.h" // from @tf_runtime + +namespace tensorflow { + +// XlaHostRecvDeviceContext is a DeviceContext that is intended to be +// used to transfer from device->host using Rendezvous. It transfers the +// content of `device_memory_base` with `shape` using `stream`. Only +// `CopyDeviceTensorToCPU` method is implemented. The `done_event` is marked as +// Concrete once transfer is completed. +// +// Example usage: +// +// Device device; +// stream_executor::Stream stream(executor); +// Tensor device_tensor(device_allocator, DT_FLOAT, TensorShape({2, 2})); +// se::DeviceMemoryBase gpu_dst{device_tensor.data(), 4 * sizeof(float)}; +// xla::Shape shape(xla::F32, {2, 2}, {}, {}) +// tsl::AsyncValueRef done_event = +// tsl::MakeConstructedAsyncValueRef(stream.parent()); +// done_event->Init(); +// Tensor dest_cpu_tensor; +// +// XlaHostRecvDeviceContext device_context(&stream, gpu_dst, +// shape, done_event); +// device_context.CopyDeviceTensorToCPUSync( +// &device_tensor, "", &device, &dest_cpu_tensor); + +class XlaHostRecvDeviceContext : public DeviceContext { + public: + XlaHostRecvDeviceContext(se::Stream* stream, + const se::DeviceMemoryBase& device_memory_base, + const xla::Shape& shape, + tsl::AsyncValueRef& done_event) + : stream_(stream), + device_memory_base_(device_memory_base), + shape_(shape), + done_event_(done_event) {} + + void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* 
device, + Tensor* device_tensor, StatusCallback done, + bool sync_dst_compute) const override { + done(errors::Internal("host->device copy not implemented.")); + } + + // Copies `device_memory_base_` with `shape_` into `cpu_tensor`. + // `device_tensor` is unused. + void CopyDeviceTensorToCPU(const Tensor* device_tensor, + StringPiece tensor_name, Device* device, + Tensor* cpu_tensor, StatusCallback done) override; + + void CopyTensorInSameDevice(const Tensor* input_tensor, Device* device, + Tensor* output_tensor, + StatusCallback done) const override { + done(errors::Internal("device->device copy not implemented.")); + } + + private: + se::Stream* stream_; // Not owned. + // This is copied rather than a reference or pointer since its lifetime + // is not guaranteed to outlast the original object. Object slicing is + // not an issue here since only DeviceMemoryBase methods/members are used. + const se::DeviceMemoryBase device_memory_base_; + const xla::Shape shape_; + tsl::AsyncValueRef done_event_; + + TF_DISALLOW_COPY_AND_ASSIGN(XlaHostRecvDeviceContext); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_XLA_HOST_RECV_DEVICE_CONTEXT_H_ diff --git a/tensorflow/compiler/jit/xla_host_send_device_context.cc b/tensorflow/compiler/jit/xla_host_send_device_context.cc new file mode 100644 index 00000000000..1c30ef022a8 --- /dev/null +++ b/tensorflow/compiler/jit/xla_host_send_device_context.cc @@ -0,0 +1,39 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/compiler/jit/xla_host_send_device_context.h" + +#include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/type_util.h" + +namespace tensorflow { + +void XlaHostSendDeviceContext::CopyCPUTensorToDevice( + const Tensor* cpu_tensor, Device* device, Tensor* device_tensor, + StatusCallback done, bool sync_dst_compute) const { + stream_->ThenMemcpy(device_memory_base_, cpu_tensor->data(), + device_memory_base_->size()); + stream_->ThenRecordEvent(&done_event_.get()); + if (auto st = stream_->BlockHostUntilDone(); !st.ok()) { + done_event_.SetError(absl::InternalError(absl::StrFormat( + "failed to synchronize send operation with a stream: %s", + st.ToString()))); + return; + } + + done_event_.SetStateConcrete(); + done(OkStatus()); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_host_send_device_context.h b/tensorflow/compiler/jit/xla_host_send_device_context.h new file mode 100644 index 00000000000..ce292fa61d1 --- /dev/null +++ b/tensorflow/compiler/jit/xla_host_send_device_context.h @@ -0,0 +1,89 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_JIT_XLA_HOST_SEND_DEVICE_CONTEXT_H_ +#define TENSORFLOW_COMPILER_JIT_XLA_HOST_SEND_DEVICE_CONTEXT_H_ + +#include "tensorflow/compiler/xla/shape.h" +#include "tensorflow/compiler/xla/stream_executor/device_memory.h" +#include "tensorflow/compiler/xla/stream_executor/stream.h" +#include "tensorflow/core/framework/device_base.h" +#include "tfrt/concurrency/async_value_ref.h" // from @tf_runtime + +namespace tensorflow { + +// XlaHostSendDeviceContext is a DeviceContext that is intended to be +// used to transfer from host->device using Rendezvous. It transfers the +// content of `device_memory_base` with `shape` using `stream`. Only +// `CopyCPUTensorToDevice` method is implemented. The `done_event` is marked as +// Concrete once transfer is completed. +// +// Example usage: +// +// Device device; +// stream_executor::Stream stream(executor); +// Tensor cpu_tensor(host_allocator, DT_FLOAT, TensorShape({2, 2})); +// Tensor device_tensor(device_allocator, DT_FLOAT, TensorShape({2, 2})); +// se::DeviceMemoryBase gpu_dst{device_tensor.data(), 4 * sizeof(float)}; +// xla::Shape shape(xla::F32, {2, 2}, {}, {}) +// tsl::AsyncValueRef done_event = +// tsl::MakeConstructedAsyncValueRef(stream.parent()); +// done_event->Init(); +// +// XlaHostSendDeviceContext device_context(&stream, &gpu_dst, +// shape, done_event); +// device_context.CopyCPUTensorToDeviceSync( +// &cpu_tensor, &device, &device_tensor); + +class XlaHostSendDeviceContext : public DeviceContext { + public: + XlaHostSendDeviceContext(se::Stream* stream, + se::DeviceMemoryBase* device_memory_base, + const xla::Shape& shape, + tsl::AsyncValueRef& done_event) + : stream_(stream), + device_memory_base_(device_memory_base), + shape_(shape), + done_event_(done_event) {} + + // Copies 'cpu_tensor' to `device_memory_base_` with `shape_`. + // `device_tensor` is unused. 
+ void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device, + Tensor* device_tensor, StatusCallback done, + bool sync_dst_compute) const override; + + void CopyDeviceTensorToCPU(const Tensor* device_tensor, + StringPiece tensor_name, Device* device, + Tensor* cpu_tensor, StatusCallback done) override { + done(errors::Internal("host->device copy not implemented.")); + } + + void CopyTensorInSameDevice(const Tensor* input_tensor, Device* device, + Tensor* output_tensor, + StatusCallback done) const override { + done(errors::Internal("device->device copy not implemented.")); + } + + private: + se::Stream* stream_; // Not owned. + se::DeviceMemoryBase* device_memory_base_; // Not owned. + const xla::Shape shape_; + tsl::AsyncValueRef done_event_; + + TF_DISALLOW_COPY_AND_ASSIGN(XlaHostSendDeviceContext); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_XLA_HOST_SEND_DEVICE_CONTEXT_H_ diff --git a/tensorflow/compiler/jit/xla_host_send_recv_device_context_test.cc b/tensorflow/compiler/jit/xla_host_send_recv_device_context_test.cc new file mode 100644 index 00000000000..90d7b3b7b8f --- /dev/null +++ b/tensorflow/compiler/jit/xla_host_send_recv_device_context_test.cc @@ -0,0 +1,171 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include +#include +#include +#include + +#include +#include "tensorflow/compiler/jit/xla_host_recv_device_context.h" +#include "tensorflow/compiler/jit/xla_host_send_device_context.h" +#include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/stream_executor/device_memory.h" +#include "tensorflow/compiler/xla/stream_executor/multi_platform_manager.h" +#include "tensorflow/compiler/xla/stream_executor/stream.h" +#include "tensorflow/compiler/xla/stream_executor/stream_executor.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/tsl/lib/core/status_test_util.h" + +namespace tensorflow { +namespace { + +class XlaHostSendRecvDeviceContextTest : public ::testing::Test { + public: + void SetDevice(const string& device_type) { + auto device_factory = DeviceFactory::GetFactory(device_type); + SessionOptions options; + std::vector> devices; + Status s = device_factory->CreateDevices( + options, "/job:worker/replica:0/task:0", &devices); + device_ = std::move(devices[0]); + + AllocatorAttributes host_alloc_attr; + host_alloc_attr.set_on_host(true); + host_allocator_ = device_->GetAllocator(host_alloc_attr); + + AllocatorAttributes device_alloc_attr; + device_alloc_attr.set_on_host(false); + device_allocator_ = device_->GetAllocator(device_alloc_attr); + } + + protected: + std::unique_ptr device_; + Allocator* host_allocator_; + Allocator* device_allocator_; +}; + +TEST_F(XlaHostSendRecvDeviceContextTest, CopyDeviceTensorToCPU) { + SetDevice("GPU"); + Tensor origin_cpu_tensor(host_allocator_, DT_FLOAT, TensorShape({2, 2})); + test::FillValues(&origin_cpu_tensor, {1.2, 2.3, 3.4, 4.5}); + Tensor device_tensor(device_allocator_, DT_FLOAT, TensorShape({2, 2})); + Tensor dest_cpu_tensor(host_allocator_, DT_FLOAT, TensorShape({2, 2})); + + stream_executor::Platform* platform = + 
stream_executor::MultiPlatformManager::PlatformWithName("CUDA").value(); + stream_executor::StreamExecutor* executor = + platform->ExecutorForDevice(0).value(); + stream_executor::Stream stream(executor); + stream.Init(); + ASSERT_TRUE(stream.ok()); + + se::DeviceMemoryBase gpu_dst{device_tensor.data(), 4 * sizeof(float)}; + xla::Shape shape; + TF_ASSERT_OK(TensorShapeToXLAShape(DT_FLOAT, TensorShape({2, 2}), &shape)); + + // Copy the cpu_tensor to the GPU first before trying to copy it back. + stream.ThenMemcpy(&gpu_dst, origin_cpu_tensor.data(), gpu_dst.size()); + TF_ASSERT_OK(stream.BlockHostUntilDone()); + + tsl::AsyncValueRef done_event = + tsl::MakeConstructedAsyncValueRef(stream.parent()); + done_event->Init(); + XlaHostRecvDeviceContext* device_context = + new XlaHostRecvDeviceContext(&stream, gpu_dst, shape, done_event); + TF_ASSERT_OK(device_context->CopyDeviceTensorToCPUSync( + &device_tensor, "", device_.get(), &dest_cpu_tensor)); + + tensorflow::test::ExpectClose(origin_cpu_tensor, dest_cpu_tensor); + device_context->Unref(); +} + +TEST_F(XlaHostSendRecvDeviceContextTest, CopyCPUTensorToDevice) { + SetDevice("GPU"); + Tensor origin_cpu_tensor(host_allocator_, DT_FLOAT, TensorShape({2, 2})); + test::FillValues(&origin_cpu_tensor, {1.2, 2.3, 3.4, 4.5}); + Tensor device_tensor(device_allocator_, DT_FLOAT, TensorShape({2, 2})); + Tensor dest_cpu_tensor(host_allocator_, DT_FLOAT, TensorShape({2, 2})); + + stream_executor::Platform* platform = + stream_executor::MultiPlatformManager::PlatformWithName("CUDA").value(); + stream_executor::StreamExecutor* executor = + platform->ExecutorForDevice(0).value(); + stream_executor::Stream stream(executor); + stream.Init(); + ASSERT_TRUE(stream.ok()); + + se::DeviceMemoryBase gpu_dst{device_tensor.data(), 4 * sizeof(float)}; + xla::Shape shape; + TF_ASSERT_OK(TensorShapeToXLAShape(DT_FLOAT, TensorShape({2, 2}), &shape)); + + tsl::AsyncValueRef done_event = + tsl::MakeConstructedAsyncValueRef(stream.parent()); + 
done_event->Init(); + XlaHostSendDeviceContext* device_context = + new XlaHostSendDeviceContext(&stream, &gpu_dst, shape, done_event); + TF_ASSERT_OK(device_context->CopyCPUTensorToDeviceSync( + &origin_cpu_tensor, device_.get(), &device_tensor)); + + // Copy the GPU tensor back to CPU to check that copy worked. + stream.ThenMemcpy(dest_cpu_tensor.data(), gpu_dst, gpu_dst.size()); + TF_ASSERT_OK(stream.BlockHostUntilDone()); + + tensorflow::test::ExpectClose(origin_cpu_tensor, dest_cpu_tensor); + device_context->Unref(); +} + +TEST_F(XlaHostSendRecvDeviceContextTest, RoundTrip) { + SetDevice("GPU"); + Tensor origin_cpu_tensor(host_allocator_, DT_FLOAT, TensorShape({2, 2})); + test::FillValues(&origin_cpu_tensor, {1.2, 2.3, 3.4, 4.5}); + Tensor device_tensor(device_allocator_, DT_FLOAT, TensorShape({2, 2})); + Tensor dest_cpu_tensor(host_allocator_, DT_FLOAT, TensorShape({2, 2})); + + stream_executor::Platform* platform = + stream_executor::MultiPlatformManager::PlatformWithName("CUDA").value(); + stream_executor::StreamExecutor* executor = + platform->ExecutorForDevice(0).value(); + stream_executor::Stream stream(executor); + stream.Init(); + ASSERT_TRUE(stream.ok()); + + se::DeviceMemoryBase gpu_dst{device_tensor.data(), 4 * sizeof(float)}; + xla::Shape shape; + TF_ASSERT_OK(TensorShapeToXLAShape(DT_FLOAT, TensorShape({2, 2}), &shape)); + + tsl::AsyncValueRef send_done_event = + tsl::MakeConstructedAsyncValueRef(stream.parent()); + send_done_event->Init(); + XlaHostSendDeviceContext* send_device_context = + new XlaHostSendDeviceContext(&stream, &gpu_dst, shape, send_done_event); + TF_ASSERT_OK(send_device_context->CopyCPUTensorToDeviceSync( + &origin_cpu_tensor, device_.get(), &device_tensor)); + + tsl::AsyncValueRef recv_done_event = + tsl::MakeConstructedAsyncValueRef(stream.parent()); + recv_done_event->Init(); + XlaHostRecvDeviceContext* recv_device_context = + new XlaHostRecvDeviceContext(&stream, gpu_dst, shape, recv_done_event); + 
TF_ASSERT_OK(recv_device_context->CopyDeviceTensorToCPUSync( + &device_tensor, "", device_.get(), &dest_cpu_tensor)); + + tensorflow::test::ExpectClose(origin_cpu_tensor, dest_cpu_tensor); + send_device_context->Unref(); + recv_device_context->Unref(); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_kernel_creator_test.cc b/tensorflow/compiler/jit/xla_kernel_creator_test.cc index e683caa1aac..b66e4270d3a 100644 --- a/tensorflow/compiler/jit/xla_kernel_creator_test.cc +++ b/tensorflow/compiler/jit/xla_kernel_creator_test.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/jit/xla_kernel_creator.h" #include "absl/memory/memory.h" +#include "absl/status/status.h" #include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/framework/function_testlib.h" @@ -136,7 +137,7 @@ TEST_F(XlaKernelCreatorTest, FailsIfXlaCompileAttrNotSet) { input: 'b' )proto"), &kernel_); - EXPECT_TRUE(errors::IsInternal(status)) << status.ToString(); + EXPECT_TRUE(absl::IsInternal(status)) << status; } TEST_F(XlaKernelCreatorTest, FailsIfXlaCompileAttrIsSetToFalse) { @@ -153,7 +154,7 @@ TEST_F(XlaKernelCreatorTest, FailsIfXlaCompileAttrIsSetToFalse) { input: 'b' )proto"), &kernel_); - EXPECT_TRUE(errors::IsInternal(status)) << status.ToString(); + EXPECT_TRUE(absl::IsInternal(status)) << status; } } // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_platform_info.cc b/tensorflow/compiler/jit/xla_platform_info.cc index b311faa13ab..6d6c5e5b492 100644 --- a/tensorflow/compiler/jit/xla_platform_info.cc +++ b/tensorflow/compiler/jit/xla_platform_info.cc @@ -18,17 +18,21 @@ limitations under the License. 
#include #include #include +#include #include #include #include "tensorflow/compiler/jit/device_executable_persistor.h" #include "tensorflow/compiler/jit/flags.h" #include "tensorflow/compiler/jit/pjrt_device_compiler_client.h" +#include "tensorflow/compiler/jit/xla_compile_util.h" #include "tensorflow/compiler/jit/xla_device_compiler_client.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/pjrt/pjrt_client.h" +#include "tensorflow/core/platform/status.h" #include "tensorflow/core/tfrt/common/create_pjrt_client_util.h" +#include "tensorflow/core/tfrt/common/global_state.h" #include "tensorflow/core/tfrt/common/pjrt_util.h" #include "tensorflow/core/tpu/tpu_defs.h" @@ -45,19 +49,23 @@ using PjRtDeviceExecutablePersistor = XlaDeviceCompiler* CreateXlaDeviceCompiler( const XlaDeviceExecutablePersistor::Config& persistor_config, - DeviceType device_type, xla::LocalClient* local_client) { + DeviceType compilation_device_type, xla::LocalClient* local_client) { return new XlaDeviceCompiler( std::make_unique( - std::move(persistor_config), device_type), + std::move(persistor_config), compilation_device_type), std::make_unique(local_client)); } -PjRtDeviceCompiler* CreatePjRtDeviceCompiler( - const PjRtDeviceExecutablePersistor::Config& persistor_config, - DeviceType device_type, xla::PjRtClient* pjrt_client) { +PjRtDeviceCompiler* CreatePjRtDeviceCompiler(DeviceType compilation_device_type, + xla::PjRtClient* pjrt_client) { + PjRtDeviceExecutablePersistor::Config persistor_config( + GetMarkForCompilationPassFlags()->tf_xla_persistent_cache_directory, + GetMarkForCompilationPassFlags()->tf_xla_disable_strict_signature_checks, + GetMarkForCompilationPassFlags()->tf_xla_persistent_cache_prefix); + return new PjRtDeviceCompiler( std::make_unique( - std::move(persistor_config), device_type), + std::move(persistor_config), compilation_device_type), 
std::make_unique(pjrt_client)); } @@ -73,6 +81,60 @@ StatusOr>> GetAllowedGpus( return gpu_ids; } + +Status GetCompilationDeviceTypeAndPjRtClient( + const XlaPlatformInfo& platform_info, FunctionLibraryRuntime* flr, + DeviceType* compilation_device_type, xla::PjRtClient** pjrt_client) { + DeviceType device_type = platform_info.device_type(); + + if (platform_info.xla_device_metadata()) { + VLOG(2) << "Building PjRtDeviceCompiler using " + "platform_info.xla_device_metadata()."; + + *compilation_device_type = + platform_info.xla_device_metadata()->jit_device_type(); + TF_ASSIGN_OR_RETURN(*pjrt_client, GetOrCreatePjRtClient(device_type)); + return OkStatus(); + } + + if (platform_info.pjrt_device_metadata()) { + VLOG(2) << "Building PjRtDeviceCompiler using " + "platform_info.pjrt_device_metadata()."; + + *compilation_device_type = + platform_info.pjrt_device_metadata()->jit_device_type(); + TF_ASSIGN_OR_RETURN(*pjrt_client, GetOrCreatePjRtClient(device_type)); + return OkStatus(); + } + + // TFRT-TPU is used if device_type is `DEVICE_TPU` and platform_info does not + // have `xla_device_metadata`. + if (device_type == DEVICE_TPU) { + *compilation_device_type = DeviceType(DEVICE_TPU_XLA_JIT); + TF_ASSIGN_OR_RETURN(*pjrt_client, GetOrCreatePjRtClient(device_type)); + return OkStatus(); + } + + VLOG(2) << "platform_info.xla_device_metadata not found and " + "platform_info.device_type() != DEVICE_TPU. Building " + "PjRtDeviceCompiler for non-XLA device."; + + const XlaOpRegistry::DeviceRegistration* registration; + if (!XlaOpRegistry::GetCompilationDevice(device_type.type(), ®istration)) { + return errors::InvalidArgument("No JIT device registered for ", + device_type.type()); + } + *compilation_device_type = DeviceType(registration->compilation_device_name); + + TF_ASSIGN_OR_RETURN(auto allowed_gpus, GetAllowedGpus(flr)); + // TODO(b/255826209): Set platform, intra op parallelism threads if required + // and when supported by GetOrCreatePjRtClient(). 
+ // The `allowed_gpus` argument is used only if the `device_type` is GPU. + TF_ASSIGN_OR_RETURN(*pjrt_client, + GetOrCreatePjRtClient(device_type, allowed_gpus)); + + return OkStatus(); +} } // namespace xla::StatusOr>> ParseVisibleDeviceList( @@ -175,71 +237,45 @@ Status BuildXlaDeviceCompiler(DeviceBase* device, FunctionLibraryRuntime* flr, return OkStatus(); } -Status BuildPjRtDeviceCompiler(const XlaPlatformInfo& platform_info, - FunctionLibraryRuntime* flr, - PjRtDeviceCompiler** pjrt_device_compiler) { - PjRtDeviceExecutablePersistor::Config persistor_config( - GetMarkForCompilationPassFlags()->tf_xla_persistent_cache_directory, - GetMarkForCompilationPassFlags()->tf_xla_disable_strict_signature_checks, - GetMarkForCompilationPassFlags()->tf_xla_persistent_cache_prefix); +Status GetOrCreatePjRtDeviceCompilerAndProfiler( + const XlaPlatformInfo& platform_info, FunctionLibraryRuntime* flr, + PjRtDeviceCompiler** pjrt_device_compiler, + DeviceCompilationProfiler** profiler) { + // We store information about the JIT-compiled XLA computation + // in the ResourceMgr. + ResourceMgr* rm = tfrt_global::GetTFGlobalResourceMgr(); - DeviceType device_type = platform_info.device_type(); + const auto& device_type = platform_info.device_type(); + const std::string& compiler_name = + GetPjRtDeviceCompilerResourceName(device_type); - if (platform_info.xla_device_metadata()) { - VLOG(2) << "Building PjRtDeviceCompiler using " - "platform_info.xla_device_metadata()."; + // Lookup the DeviceCompiler, create one if not found. 
+ Status s = rm->Lookup( + rm->default_container(), compiler_name, pjrt_device_compiler); + if (!s.ok()) { + DeviceType compilation_device_type(""); + xla::PjRtClient* pjrt_client = nullptr; + TF_RETURN_IF_ERROR(GetCompilationDeviceTypeAndPjRtClient( + platform_info, flr, &compilation_device_type, &pjrt_client)); - DeviceType compilation_device_type = - platform_info.xla_device_metadata()->jit_device_type(); - TF_ASSIGN_OR_RETURN(auto pjrt_client, GetOrCreatePjRtClient(device_type)); - - *pjrt_device_compiler = CreatePjRtDeviceCompiler( - persistor_config, compilation_device_type, pjrt_client); - return OkStatus(); - } - if (platform_info.pjrt_device_metadata()) { - VLOG(2) << "Building PjRtDeviceCompiler using " - "platform_info.pjrt_device_metadata()."; - - DeviceType compilation_device_type = - platform_info.pjrt_device_metadata()->jit_device_type(); - TF_ASSIGN_OR_RETURN(auto pjrt_client, GetOrCreatePjRtClient(device_type)); - - *pjrt_device_compiler = CreatePjRtDeviceCompiler( - persistor_config, compilation_device_type, pjrt_client); - return OkStatus(); + TF_RETURN_IF_ERROR(rm->LookupOrCreate( + rm->default_container(), compiler_name, pjrt_device_compiler, + [&](PjRtDeviceCompiler** pjrt_device_compiler) { + *pjrt_device_compiler = + CreatePjRtDeviceCompiler(compilation_device_type, pjrt_client); + return OkStatus(); + })); } - // TFRT-TPU is used if device_type is `DEVICE_TPU` and platform_info does not - // have `xla_device_metadata`. 
- if (device_type == DEVICE_TPU) { - TF_ASSIGN_OR_RETURN(auto pjrt_client, GetOrCreatePjRtClient(device_type)); - *pjrt_device_compiler = CreatePjRtDeviceCompiler( - persistor_config, DeviceType(DEVICE_TPU_XLA_JIT), pjrt_client); - return OkStatus(); - } + const std::string& profiler_name = + GetPjRtDeviceCompilationProfilerResourceName(device_type); + TF_RETURN_IF_ERROR(rm->LookupOrCreate( + rm->default_container(), profiler_name, profiler, + [](DeviceCompilationProfiler** profiler) { + *profiler = new DeviceCompilationProfiler(); + return OkStatus(); + })); - VLOG(2) << "platform_info.xla_device_metadata not found and " - "platform_info.device_type() != DEVICE_TPU. Building " - "PjRtDeviceCompiler for non-XLA device."; - - const XlaOpRegistry::DeviceRegistration* registration; - if (!XlaOpRegistry::GetCompilationDevice(device_type.type(), ®istration)) { - return errors::InvalidArgument("No JIT device registered for ", - device_type.type()); - } - auto compilation_device_type = - DeviceType(registration->compilation_device_name); - - TF_ASSIGN_OR_RETURN(auto allowed_gpus, GetAllowedGpus(flr)); - // TODO(b/255826209): Set platform, intra op parallelism threads if required - // and when supported by GetOrCreatePjRtClient(). - // The `allowed_gpus` argument is used only if the `device_type` is GPU. - TF_ASSIGN_OR_RETURN(auto pjrt_client, - GetOrCreatePjRtClient(device_type, allowed_gpus)); - - *pjrt_device_compiler = CreatePjRtDeviceCompiler( - persistor_config, compilation_device_type, pjrt_client); return OkStatus(); } diff --git a/tensorflow/compiler/jit/xla_platform_info.h b/tensorflow/compiler/jit/xla_platform_info.h index 725a876904d..4a8bc27a045 100644 --- a/tensorflow/compiler/jit/xla_platform_info.h +++ b/tensorflow/compiler/jit/xla_platform_info.h @@ -18,6 +18,7 @@ limitations under the License. 
#include #include +#include #include "tensorflow/compiler/jit/device_compiler.h" #include "tensorflow/compiler/jit/pjrt_base_device.h" @@ -113,17 +114,21 @@ Status BuildXlaDeviceCompiler( DeviceCompiler** xla_device_compiler); -// Builds a DeviceCompiler that uses xla::PjRtClient using an appropriate +// Fetches a DeviceCompiler from the tfrt_global resource manager (or creates +// one there if not found) that uses xla::PjRtClient using an appropriate // PjRtClient for `platform_info.device_type()` and sets *pjrt_device_compiler -// to point to it. Uses flags from `MarkForCompilationPassFlags` for configuring -// the persistor used in the DeviceCompiler. Please note that non-XLA devices -// aren't supported yet. This is because: +// to point to it. Also fetches/creates a DeviceCompilationProfiler from/in the +// tfrt_global resource manager for `platform_info.device_type()` and sets +// *profiler to point to it. Uses flags from `MarkForCompilationPassFlags` for +// configuring the persistor used in the DeviceCompiler. Please note that +// non-XLA devices aren't supported yet. This is because: // 1. PjRtClient doesn't support data transfer for non-XLA devices yet // 2. Fetching the PjRtClient for non-XLA devices is also not supported yet -Status BuildPjRtDeviceCompiler( +Status GetOrCreatePjRtDeviceCompilerAndProfiler( const XlaPlatformInfo& platform_info, FunctionLibraryRuntime* flr, DeviceCompiler** - pjrt_device_compiler); + pjrt_device_compiler, + DeviceCompilationProfiler** profiler); // Returns information about the platform from kernel context. 
XlaPlatformInfo XlaPlatformInfoFromDevice(DeviceBase* device); diff --git a/tensorflow/compiler/jit/xla_platform_info_test.cc b/tensorflow/compiler/jit/xla_platform_info_test.cc index 0dedbb39bb9..e12a9366c04 100644 --- a/tensorflow/compiler/jit/xla_platform_info_test.cc +++ b/tensorflow/compiler/jit/xla_platform_info_test.cc @@ -81,7 +81,7 @@ TEST_F(XlaPlatformInfoTest, BuildXlaDeviceCompilerNonXlaDevice) { EXPECT_TRUE(xla_device_compiler->client() != nullptr); } -TEST_F(XlaPlatformInfoTest, BuildPjRtDeviceCompilerTestXlaDevice) { +TEST_F(XlaPlatformInfoTest, GetOrCreatePjRtDeviceCompilerAndProfilerXlaDevice) { DeviceType device_type = DeviceType(DEVICE_XLA_GPU); device_setup_.AddDevicesAndSetUp({device_type.type()}); @@ -91,23 +91,27 @@ TEST_F(XlaPlatformInfoTest, BuildPjRtDeviceCompilerTestXlaDevice) { XlaPlatformInfo platform_info = XlaPlatformInfoFromDevice(device); PjRtDeviceCompiler* pjrt_device_compiler = nullptr; - TF_EXPECT_OK(BuildPjRtDeviceCompiler(platform_info, device_setup_.flr(), - &pjrt_device_compiler)); + DeviceCompilationProfiler* profiler = nullptr; + TF_EXPECT_OK(GetOrCreatePjRtDeviceCompilerAndProfiler( + platform_info, device_setup_.flr(), &pjrt_device_compiler, &profiler)); core::ScopedUnref pjrt_device_compiler_ref(pjrt_device_compiler); + core::ScopedUnref profiler_ref(profiler); TF_ASSERT_OK_AND_ASSIGN(auto pjrt_client, GetOrCreatePjRtClient(device_type)); EXPECT_EQ(pjrt_device_compiler->device_type(), metadata->jit_device_type()); EXPECT_EQ(pjrt_device_compiler->client(), pjrt_client); } -TEST_F(XlaPlatformInfoTest, BuildPjRtDeviceCompilerTestGpuDevice) { +TEST_F(XlaPlatformInfoTest, GetOrCreatePjRtDeviceCompilerAndProfilerGpuDevice) { device_setup_.AddDevicesAndSetUp({DEVICE_GPU}); Device* device = device_setup_.GetDevice(DEVICE_GPU); XlaPlatformInfo platform_info = XlaPlatformInfoFromDevice(device); PjRtDeviceCompiler* pjrt_device_compiler = nullptr; - TF_EXPECT_OK(BuildPjRtDeviceCompiler(platform_info, device_setup_.flr(), - 
&pjrt_device_compiler)); + DeviceCompilationProfiler* profiler = nullptr; + TF_EXPECT_OK(GetOrCreatePjRtDeviceCompilerAndProfiler( + platform_info, device_setup_.flr(), &pjrt_device_compiler, &profiler)); core::ScopedUnref pjrt_device_compiler_ref(pjrt_device_compiler); + core::ScopedUnref profiler_ref(profiler); } #endif @@ -138,7 +142,7 @@ TEST_F(XlaPlatformInfoTest, BuildXlaDeviceCompilerTpuDevice) { // TODO(b/255826209): Look into using an actual TPU device for the unit test, // and move this out of OSS. -TEST_F(XlaPlatformInfoTest, BuildPjRtDeviceCompilerTpuDevice) { +TEST_F(XlaPlatformInfoTest, GetOrCreatePjRtDeviceCompilerAndProfilerTpuDevice) { DeviceType device_type = DeviceType(DEVICE_TPU); DeviceType compilation_device_type = DeviceType(DEVICE_TPU_XLA_JIT); // Use a CPU PjRtClient instead of a TPU one just for testing whether @@ -158,9 +162,11 @@ TEST_F(XlaPlatformInfoTest, BuildPjRtDeviceCompilerTpuDevice) { /*device_allocator=*/nullptr); PjRtDeviceCompiler* pjrt_device_compiler = nullptr; - TF_EXPECT_OK( - BuildPjRtDeviceCompiler(platform_info, nullptr, &pjrt_device_compiler)); + DeviceCompilationProfiler* profiler = nullptr; + TF_EXPECT_OK(GetOrCreatePjRtDeviceCompilerAndProfiler( + platform_info, nullptr, &pjrt_device_compiler, &profiler)); core::ScopedUnref pjrt_device_compiler_ref(pjrt_device_compiler); + core::ScopedUnref profiler_ref(profiler); EXPECT_EQ(pjrt_device_compiler->device_type(), compilation_device_type); EXPECT_EQ(pjrt_device_compiler->client(), pjrt_client); diff --git a/tensorflow/compiler/mlir/g3doc/_includes/tf_passes.md b/tensorflow/compiler/mlir/g3doc/_includes/tf_passes.md index 9405aa417df..342b5e0d23b 100644 --- a/tensorflow/compiler/mlir/g3doc/_includes/tf_passes.md +++ b/tensorflow/compiler/mlir/g3doc/_includes/tf_passes.md @@ -272,6 +272,10 @@ compiling to XLA. 
### `-tf-embedding-pipelining`: Rewrite graph for embedding pipelining For architectures that support accelerated embedding lookups, this pass will rewrite the graph to use pipelining for better device utilization. +### `-tf-embedding-sequencing`: Rewrite graph for sequential execution of embeddings +This is a strictly sequential and formally correct fallback option for the +embedding pipelining pass intended for debugging during pipelining +development. ### `-tf-executor-break-up-islands`: Transform from TF control dialect to TF executor dialect. ### `-tf-executor-check-control-dependencies`: Checks control dependencies This pass analyzes control dependencies between islands and warns about @@ -1993,4 +1997,21 @@ This pass will transform it into ### `-tf-verify-for-export`: Verify module is suitable for export back to TF Graph Verifies whether all functions in module are of single tf_executor.graph and each tf_executor.island in tf_executor.graph only has a single op. +### `-tf-xla-call-module-deserialization`: Deserializes StableHLO functions embedded in `tf.XlaCallModule` to top level module +This pass deserializes the StableHLO bytecodes embedded in tf.XlaCallModule, +then outlines the functions in the deserialized StableHLO module to the top +level MLIR module, with function renamings to avoid naming conflicts. + +After the outlining, it updates tf.XlaCallModule's module attribute to be +empty, adds an `_entry_function` attribute referring to the entry function. +It also adds a `_from_xla_call_module: true` attribute to each lifted +StableHLO function. +### `-tf-xla-call-module-serialization`: Serializes StableHLO functions from top-level module into `tf.XlaCallModule`'s `module` attribute +This pass collects StableHLO functions referenced from `tf.XlaCallModule`'s +`_entry_function` attribute into a module, serializes the module into MLIR +bytecode, and embed the bytecode to `tf.XlaCallModule`'s `module` attribute. 
+ +After serialization, this pass removes the `_entry_function` attribute from +`tf.XlaCallModule`, and removes all the serialized stablehlo functions +from the top-level module. ### `-tfe-legalize-tfg`: Legalize from TFG to the TFE dialect diff --git a/tensorflow/compiler/mlir/glob_lit_test.bzl b/tensorflow/compiler/mlir/glob_lit_test.bzl index 35491ed3d55..f65c86b727b 100644 --- a/tensorflow/compiler/mlir/glob_lit_test.bzl +++ b/tensorflow/compiler/mlir/glob_lit_test.bzl @@ -64,6 +64,7 @@ def _run_lit_test(name, data, size, tags, driver, features, exec_properties): ) def glob_lit_tests( + name = None, exclude = [], test_file_exts = _default_test_file_exts, default_size = _default_size, @@ -78,6 +79,7 @@ def glob_lit_tests( """Creates all plausible Lit tests (and their inputs) under this directory. Args: + name: str, name of the test_suite rule to generate for running all tests. exclude: [str], paths to exclude (for tests and inputs). test_file_exts: [str], extensions for files that are tests. default_size: str, the test size for targets not in "size_override". @@ -103,7 +105,10 @@ def glob_lit_tests( # Run tests individually such that errors can be attributed to a specific # failure. + all_tests = [] for curr_test in tests: + all_tests.append(curr_test + ".test") + # Instantiate this test with updated parameters. _run_lit_test( name = curr_test + ".test", @@ -114,3 +119,11 @@ def glob_lit_tests( features = features, exec_properties = exec_properties, ) + + # TODO: remove this check after making it a required param. 
+ if name: + native.test_suite( + name = name, + tests = all_tests, + tags = ["manual"], + ) diff --git a/tensorflow/compiler/mlir/lite/BUILD b/tensorflow/compiler/mlir/lite/BUILD index 27c2706622a..5bd1e41c068 100644 --- a/tensorflow/compiler/mlir/lite/BUILD +++ b/tensorflow/compiler/mlir/lite/BUILD @@ -342,8 +342,6 @@ cc_library( "@llvm-project//mlir:InferTypeOpInterface", "@llvm-project//mlir:LoopLikeInterface", "@llvm-project//mlir:Pass", - "@llvm-project//mlir:QuantOps", - "@llvm-project//mlir:SideEffectInterfaces", "@llvm-project//mlir:Support", "@llvm-project//mlir:TransformUtils", ], @@ -1235,6 +1233,8 @@ cc_library( "//tensorflow/compiler/mlir/lite/quantization:quantization_config", "//tensorflow/compiler/mlir/lite/quantization:quantization_passes", "//tensorflow/compiler/mlir/lite/quantization/tensorflow:tf_quantization_passes", + "//tensorflow/compiler/mlir/lite/stablehlo:legalize_tf_xla_call_module_to_stablehlo_pass", + "//tensorflow/compiler/mlir/lite/stablehlo:rename_entrypoint_to_main", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops", "//tensorflow/compiler/mlir/tensorflow:tensorflow_passes", diff --git a/tensorflow/compiler/mlir/lite/common/tfl_pass_config.h b/tensorflow/compiler/mlir/lite/common/tfl_pass_config.h index 1ea913ae5c4..642a3349528 100644 --- a/tensorflow/compiler/mlir/lite/common/tfl_pass_config.h +++ b/tensorflow/compiler/mlir/lite/common/tfl_pass_config.h @@ -83,7 +83,9 @@ struct PassConfig { // Whether to run the `GuaranteeAllFuncsOneUsePass` to ensure each function // has a single use. bool guarantee_all_funcs_one_use; - // Whether to enable the hlo to tf conversion. + // Whether to enable the hlo/stablehlo to tf conversion. This also supports + // the case where a saved model contains both TF module and serialized + // StableHLO module. bool enable_hlo_to_tf_conversion; // Whether to enable to use DynamicUpdateSlice op. 
bool enable_dynamic_update_slice; diff --git a/tensorflow/compiler/mlir/lite/emit_error_reporter.cc b/tensorflow/compiler/mlir/lite/emit_error_reporter.cc index d280bec85f5..f9c4760326b 100644 --- a/tensorflow/compiler/mlir/lite/emit_error_reporter.cc +++ b/tensorflow/compiler/mlir/lite/emit_error_reporter.cc @@ -15,6 +15,9 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/emit_error_reporter.h" +#include +#include + namespace tflite { int EmitErrorReporter::Report(const char* format, va_list args) { diff --git a/tensorflow/compiler/mlir/lite/experimental/common/BUILD b/tensorflow/compiler/mlir/lite/experimental/common/BUILD index 4b7a41ab347..02fab009fda 100644 --- a/tensorflow/compiler/mlir/lite/experimental/common/BUILD +++ b/tensorflow/compiler/mlir/lite/experimental/common/BUILD @@ -1,5 +1,7 @@ load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_cloud") +# copybara:uncomment package(default_applicable_licenses = ["//tensorflow:license"]) + cc_library( name = "outline_operations", srcs = ["outline_operations.cc"], diff --git a/tensorflow/compiler/mlir/lite/experimental/remat/BUILD b/tensorflow/compiler/mlir/lite/experimental/remat/BUILD index b9b69bf852d..f0d059ca919 100644 --- a/tensorflow/compiler/mlir/lite/experimental/remat/BUILD +++ b/tensorflow/compiler/mlir/lite/experimental/remat/BUILD @@ -1,5 +1,7 @@ load("//tensorflow:tensorflow.bzl", "tf_cc_test") +# copybara:uncomment package(default_applicable_licenses = ["//tensorflow:license"]) + cc_library( name = "rematerializer", srcs = ["rematerializer.cc"], diff --git a/tensorflow/compiler/mlir/lite/experimental/tac/BUILD b/tensorflow/compiler/mlir/lite/experimental/tac/BUILD index 0903cda3e43..c5c4c422bf8 100644 --- a/tensorflow/compiler/mlir/lite/experimental/tac/BUILD +++ b/tensorflow/compiler/mlir/lite/experimental/tac/BUILD @@ -232,6 +232,7 @@ cc_library( "transforms/get_alternative_subgraph.cc", "transforms/pick_subgraphs.cc", "transforms/raise_target_subgraphs.cc", 
+ "transforms/tac_filter.cc", "transforms/target_annotation.cc", ], hdrs = [ @@ -243,6 +244,7 @@ cc_library( ":common", ":cost_model", ":device_transform", + ":tac_filter_cc_proto", ":tac_importer_exporter", "//tensorflow/compiler/mlir/lite:tensorflow_lite", "//tensorflow/compiler/mlir/lite:tf_tfl_passes", @@ -253,6 +255,7 @@ cc_library( "//tensorflow/compiler/mlir/tensorflow:tensorflow_analysis", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", + "@com_google_protobuf//:protobuf_headers", "@llvm-project//llvm:Support", "@llvm-project//mlir:ArithDialect", "@llvm-project//mlir:FuncDialect", @@ -380,3 +383,15 @@ py_library( "//tensorflow/compiler/mlir/lite/experimental/tac/py_wrapper:_pywrap_tac_wrapper", ], ) + +proto_library( + name = "tac_filter_proto", + srcs = ["tac_filter.proto"], + compatible_with = get_compatible_with_cloud(), +) + +cc_proto_library( + name = "tac_filter_cc_proto", + compatible_with = get_compatible_with_cloud(), + deps = [":tac_filter_proto"], +) diff --git a/tensorflow/compiler/mlir/lite/experimental/tac/common/targets.h b/tensorflow/compiler/mlir/lite/experimental/tac/common/targets.h index 2bd79ddf848..2f2992871a1 100644 --- a/tensorflow/compiler/mlir/lite/experimental/tac/common/targets.h +++ b/tensorflow/compiler/mlir/lite/experimental/tac/common/targets.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_TAC_COMMON_TARGETS_H_ #include +#include #include #include #include @@ -36,6 +37,9 @@ constexpr char kDevice[] = "tac.device"; // Inference type. constexpr char kInferenceType[] = "tac.inference_type"; +// Inference type. +constexpr char kSkipTargetAnnotation[] = "tac.skip_target_annotation"; + // TODO(renjieliu): Add more inference types. 
enum InferenceType { UNKNOWN = 0, diff --git a/tensorflow/compiler/mlir/lite/experimental/tac/examples/example_hardware.cc b/tensorflow/compiler/mlir/lite/experimental/tac/examples/example_hardware.cc index e300f0686ac..6a580db3185 100644 --- a/tensorflow/compiler/mlir/lite/experimental/tac/examples/example_hardware.cc +++ b/tensorflow/compiler/mlir/lite/experimental/tac/examples/example_hardware.cc @@ -15,6 +15,8 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/experimental/tac/examples/example_hardware.h" +#include + #include "tensorflow/compiler/mlir/lite/experimental/tac/transforms/device_transform_patterns.h" #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" diff --git a/tensorflow/compiler/mlir/lite/experimental/tac/execution_metadata_exporter_test.cc b/tensorflow/compiler/mlir/lite/experimental/tac/execution_metadata_exporter_test.cc index 6d17c7f6ff6..932e047f7a4 100644 --- a/tensorflow/compiler/mlir/lite/experimental/tac/execution_metadata_exporter_test.cc +++ b/tensorflow/compiler/mlir/lite/experimental/tac/execution_metadata_exporter_test.cc @@ -15,6 +15,7 @@ #include "tensorflow/compiler/mlir/lite/experimental/tac/execution_metadata_exporter.h" #include +#include #include #include diff --git a/tensorflow/compiler/mlir/lite/experimental/tac/hardwares/cpu_hardware.cc b/tensorflow/compiler/mlir/lite/experimental/tac/hardwares/cpu_hardware.cc index a4f09f98bc7..09876d9373f 100644 --- a/tensorflow/compiler/mlir/lite/experimental/tac/hardwares/cpu_hardware.cc +++ b/tensorflow/compiler/mlir/lite/experimental/tac/hardwares/cpu_hardware.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ +#include + #include "tensorflow/compiler/mlir/lite/experimental/tac/common/targets.h" #include "tensorflow/compiler/mlir/lite/experimental/tac/common/utils.h" #include "tensorflow/compiler/mlir/lite/experimental/tac/hardwares/target_hardware.h" diff --git a/tensorflow/compiler/mlir/lite/experimental/tac/hardwares/gpu_hardware.cc b/tensorflow/compiler/mlir/lite/experimental/tac/hardwares/gpu_hardware.cc index ed3785f6898..6af9f6211d3 100644 --- a/tensorflow/compiler/mlir/lite/experimental/tac/hardwares/gpu_hardware.cc +++ b/tensorflow/compiler/mlir/lite/experimental/tac/hardwares/gpu_hardware.cc @@ -15,6 +15,8 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/experimental/tac/hardwares/gpu_hardware.h" +#include + #include "tensorflow/compiler/mlir/lite/experimental/tac/common/targets.h" #include "tensorflow/compiler/mlir/lite/experimental/tac/common/utils.h" #include "tensorflow/compiler/mlir/lite/experimental/tac/transforms/device_transform_patterns.h" diff --git a/tensorflow/compiler/mlir/lite/experimental/tac/hardwares/nnapi_hardware.cc b/tensorflow/compiler/mlir/lite/experimental/tac/hardwares/nnapi_hardware.cc index ad09d8d2762..ab2de0b75d2 100644 --- a/tensorflow/compiler/mlir/lite/experimental/tac/hardwares/nnapi_hardware.cc +++ b/tensorflow/compiler/mlir/lite/experimental/tac/hardwares/nnapi_hardware.cc @@ -15,6 +15,8 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/lite/experimental/tac/hardwares/nnapi_hardware.h" +#include + #include "tensorflow/compiler/mlir/lite/experimental/tac/common/targets.h" #include "tensorflow/compiler/mlir/lite/experimental/tac/common/utils.h" #include "tensorflow/compiler/mlir/lite/experimental/tac/transforms/device_transform_patterns.h" diff --git a/tensorflow/compiler/mlir/lite/experimental/tac/hardwares/target_hardware.cc b/tensorflow/compiler/mlir/lite/experimental/tac/hardwares/target_hardware.cc index f1b883fda58..62874d7cc51 100644 --- a/tensorflow/compiler/mlir/lite/experimental/tac/hardwares/target_hardware.cc +++ b/tensorflow/compiler/mlir/lite/experimental/tac/hardwares/target_hardware.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include #include +#include #include "llvm/ADT/DenseMap.h" #include "llvm/Support/raw_ostream.h" @@ -33,22 +34,16 @@ namespace mlir { namespace TFL { namespace tac { namespace { -struct RegisteredTargetHardware { - // TODO(b/177376459): Remove this constructor. 
- RegisteredTargetHardware(const std::string& name, - const std::string& description, mlir::TypeID type_id, - std::unique_ptr target_hardware) - : unique_name(GetCanonicalHardwareName(name)), - description(description), - type_id(type_id), - target_hardware(std::move(target_hardware)) {} +struct RegisteredTargetHardware { RegisteredTargetHardware( const std::string& name, const std::string& description, mlir::TypeID type_id, std::function()> target_hardware_factory) : unique_name(GetCanonicalHardwareName(name)), description(description), + type_id(type_id), + target_hardware(target_hardware_factory()), target_hardware_factory(target_hardware_factory) {} std::string unique_name; @@ -185,22 +180,6 @@ std::function()> GetTargetHardwareFactory( namespace internal { -void RegisterTargetHardware( - const std::string& unique_name, const std::string& description, - mlir::TypeID type_id, - std::function()> target_hardware_factory) { - auto* registered_hardwares = GetRegisteredHardwares(); - for (const auto& hardware : *registered_hardwares) { - if (hardware.unique_name == unique_name) { - llvm::errs() << "Ignoring duplicate hardware. Hardware " << unique_name - << " already registered\n"; - return; - } - } - registered_hardwares->push_back(RegisteredTargetHardware( - unique_name, description, type_id, target_hardware_factory())); -} - void RegisterTargetHardwareFactory( const std::string& unique_name, const std::string& description, mlir::TypeID type_id, diff --git a/tensorflow/compiler/mlir/lite/experimental/tac/hardwares/target_hardware.h b/tensorflow/compiler/mlir/lite/experimental/tac/hardwares/target_hardware.h index 38286ed3cfe..9a1e21dcc19 100644 --- a/tensorflow/compiler/mlir/lite/experimental/tac/hardwares/target_hardware.h +++ b/tensorflow/compiler/mlir/lite/experimental/tac/hardwares/target_hardware.h @@ -44,7 +44,7 @@ constexpr static float kCrossHardwareTransferFixedCost = 10.f; // for registering the operation. 
class TargetHardwareOperation { public: - virtual ~TargetHardwareOperation() {} + virtual ~TargetHardwareOperation() = default; virtual double GetOpCost(mlir::Operation* op) const = 0; @@ -64,7 +64,7 @@ class TargetHardwareOperation { // }; class TargetHardware { public: - virtual ~TargetHardware() {} + virtual ~TargetHardware() = default; // Initializes all TargetHardwareOperation registered for this hardware. // Users overriding this function, should call the base class method to @@ -111,20 +111,6 @@ std::function()> GetTargetHardwareFactory( const std::string& hardware_name); namespace internal { -// DEPRECATED: Do not use, prefer using RegisterTargetHardwareFactory instead. -void RegisterTargetHardware( - const std::string& unique_name, const std::string& description, - mlir::TypeID type_id, - std::function()> target_hardware_factory); - -// DEPRECATED: Do not use, prefer using RegisterTargetHardwareFactory instead. -template -void RegisterTargetHardware( - const std::string& description, - std::function()> target_hardware_factory) { - RegisterTargetHardware(T::kId, description, mlir::TypeID::get(), - target_hardware_factory); -} void RegisterTargetHardwareFactory( const std::string& unique_name, const std::string& description, @@ -158,9 +144,6 @@ struct TargetHardwareRegistration { TargetHardwareRegistration(const std::string& description, std::function()> target_hardware_factory) { - // TODO(b/177376459): remove this. - internal::RegisterTargetHardware(description, - target_hardware_factory); internal::RegisterTargetHardwareFactory(description, target_hardware_factory); } @@ -185,7 +168,7 @@ struct TargetHardwareOpRegistration { //======== util functions ========== // Process user specified device specs, will always add CPU if it's not there. -// specified_deivce_specs: ',' separated, like "GPU,DSP,CPU". +// specified_device_specs: ',' separated, like "GPU,DSP,CPU". // device_specs: processed device specs enum. 
bool ProcessTargetDevices(llvm::ArrayRef specified_device_specs, std::vector* device_specs); diff --git a/tensorflow/compiler/mlir/lite/experimental/tac/py_wrapper/BUILD b/tensorflow/compiler/mlir/lite/experimental/tac/py_wrapper/BUILD index d363334fb5f..57ee70321ee 100644 --- a/tensorflow/compiler/mlir/lite/experimental/tac/py_wrapper/BUILD +++ b/tensorflow/compiler/mlir/lite/experimental/tac/py_wrapper/BUILD @@ -103,7 +103,7 @@ pybind_extension( ] + if_onednn_v3(["@onednn_v3//:__subpackages__"]), deps = [ ":tac_wrapper_lib", - "//tensorflow/python:pybind11_lib", + "//tensorflow/python/lib/core:pybind11_lib", "//third_party/python_runtime:headers", "@pybind11", ], diff --git a/tensorflow/compiler/mlir/lite/experimental/tac/py_wrapper/tac_wrapper.cc b/tensorflow/compiler/mlir/lite/experimental/tac/py_wrapper/tac_wrapper.cc index b9b4d3d465b..0ae1e3db3fe 100644 --- a/tensorflow/compiler/mlir/lite/experimental/tac/py_wrapper/tac_wrapper.cc +++ b/tensorflow/compiler/mlir/lite/experimental/tac/py_wrapper/tac_wrapper.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include +#include #include "absl/status/status.h" #include "tensorflow/compiler/mlir/lite/experimental/tac/common/targets.h" diff --git a/tensorflow/compiler/mlir/lite/experimental/tac/py_wrapper/tac_wrapper_pybind11.cc b/tensorflow/compiler/mlir/lite/experimental/tac/py_wrapper/tac_wrapper_pybind11.cc index 5d0366515cf..18616733118 100644 --- a/tensorflow/compiler/mlir/lite/experimental/tac/py_wrapper/tac_wrapper_pybind11.cc +++ b/tensorflow/compiler/mlir/lite/experimental/tac/py_wrapper/tac_wrapper_pybind11.cc @@ -16,6 +16,7 @@ limitations under the License. 
#include #include +#include #include "pybind11/pybind11.h" // from @pybind11 #include "tensorflow/compiler/mlir/lite/experimental/tac/py_wrapper/tac_wrapper.h" diff --git a/tensorflow/compiler/mlir/lite/experimental/tac/tac_filter.proto b/tensorflow/compiler/mlir/lite/experimental/tac/tac_filter.proto new file mode 100644 index 00000000000..d26e0996dbe --- /dev/null +++ b/tensorflow/compiler/mlir/lite/experimental/tac/tac_filter.proto @@ -0,0 +1,49 @@ +syntax = "proto3"; + +package third_party.tensorflow.compiler.mlir.lite.experimental.tac; + +// A list of filters for TAC users to run ops/functions on ML hardwares. The +// intuition is that, for ops/functions that can be run on ML hardware (e.g. +// EdgeTPU) and TFLite CPU, TAC users give a hint that they're more performant +// to run on TFLite CPU. These filters give the TAC users freedom to specify the +// parts that they want to use other hardware to accelerate. +message TacFilters { + // A list of filters/rules to specify the parts that user wants to run on + // other hardware. + repeated TacFilter tac_filters = 1; +} + +// A filter can be used for an op or function. +message TacFilter { + oneof filter { + OpFilter op_filter = 1; + FunctionFilter function_filter = 2; + } +} + +// Function filter is to include/exclude a function in the target annotation +// pass in the TAC tool pipeline. +message FunctionFilter { + // Function filter types that are supported. If one function is matched for + // two rules with conflict, INCLUDE_TARGET_ANNOTATION has higher priority. + enum FunctionFilterType { + // To skip this function in the target annotation pass. This means all ops + // in this function run on TFLite CPU. + SKIP_TARGET_ANNOTATION = 0; + // To include this function in the target annotation pass. This has higher + // priority than `SKIP_TARGET_ANNOTATION`. + INCLUDE_TARGET_ANNOTATION = 1; + } + // This name corresponds to the TFLite subgraph name in the flatbuffer. 
+ // `function_name_pattern` supports regex matching. + string function_name_pattern = 1; + FunctionFilterType filter_type = 2; +} + +// Op filter is to filter out ops that user wants to run. Ops with this filter +// run on TFLite CPU. +message OpFilter { + // This name corresponds to the mlir::Location of the tensor. + // `op_name_pattern` supports regex matching. + string op_name_pattern = 1; +} diff --git a/tensorflow/compiler/mlir/lite/experimental/tac/tac_importer_exporter.h b/tensorflow/compiler/mlir/lite/experimental/tac/tac_importer_exporter.h index 3711b8874a3..a40a3b94b52 100644 --- a/tensorflow/compiler/mlir/lite/experimental/tac/tac_importer_exporter.h +++ b/tensorflow/compiler/mlir/lite/experimental/tac/tac_importer_exporter.h @@ -28,7 +28,7 @@ namespace tac { // See TacModule in how to register it with the module and use it. class TacImporter { public: - virtual ~TacImporter() {} + virtual ~TacImporter() = default; // Imports and returns the Module for the imported program. virtual absl::StatusOr> Import() = 0; @@ -40,7 +40,7 @@ class TacImporter { // See TacModule in how to register it with the module and use it. class TacExporter { public: - virtual ~TacExporter() {} + virtual ~TacExporter() = default; // Imports and returns the Module for the imported program. virtual absl::Status Export(mlir::ModuleOp module) = 0; diff --git a/tensorflow/compiler/mlir/lite/experimental/tac/tac_module.cc b/tensorflow/compiler/mlir/lite/experimental/tac/tac_module.cc index 805b7802517..8313bf2c10e 100644 --- a/tensorflow/compiler/mlir/lite/experimental/tac/tac_module.cc +++ b/tensorflow/compiler/mlir/lite/experimental/tac/tac_module.cc @@ -16,6 +16,7 @@ limitations under the License. 
#include #include +#include #include "absl/status/status.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/lite/experimental/tac/tac_module.h b/tensorflow/compiler/mlir/lite/experimental/tac/tac_module.h index 883b1ba84e2..7733a9bda80 100644 --- a/tensorflow/compiler/mlir/lite/experimental/tac/tac_module.h +++ b/tensorflow/compiler/mlir/lite/experimental/tac/tac_module.h @@ -15,6 +15,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_TAC_TAC_MODULE_H_ #define TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_TAC_TAC_MODULE_H_ +#include #include #include #include @@ -57,7 +58,7 @@ class TacModule { bool legalize_to_tflite_ops = false; }; - virtual ~TacModule() {} + virtual ~TacModule() = default; explicit TacModule(const Options& options) : options_(options) {} diff --git a/tensorflow/compiler/mlir/lite/experimental/tac/tests/BUILD b/tensorflow/compiler/mlir/lite/experimental/tac/tests/BUILD index 1ae5f737d37..f3a0574e882 100644 --- a/tensorflow/compiler/mlir/lite/experimental/tac/tests/BUILD +++ b/tensorflow/compiler/mlir/lite/experimental/tac/tests/BUILD @@ -7,6 +7,7 @@ package( ) glob_lit_tests( + name = "all_tests", data = [":test_utilities"], driver = "@llvm-project//mlir:run_lit.sh", test_file_exts = ["mlir"], diff --git a/tensorflow/compiler/mlir/lite/experimental/tac/tests/e2e/BUILD b/tensorflow/compiler/mlir/lite/experimental/tac/tests/e2e/BUILD index 8fef794a866..58beccdb043 100644 --- a/tensorflow/compiler/mlir/lite/experimental/tac/tests/e2e/BUILD +++ b/tensorflow/compiler/mlir/lite/experimental/tac/tests/e2e/BUILD @@ -10,6 +10,7 @@ package( ) glob_lit_tests( + name = "all_tests", data = [":test_utilities"], driver = "@llvm-project//mlir:run_lit.sh", test_file_exts = [ diff --git a/tensorflow/compiler/mlir/lite/experimental/tac/tests/raise-target-subgraphs.mlir b/tensorflow/compiler/mlir/lite/experimental/tac/tests/raise-target-subgraphs.mlir index 
18b9e0fd605..3018221fdac 100644 --- a/tensorflow/compiler/mlir/lite/experimental/tac/tests/raise-target-subgraphs.mlir +++ b/tensorflow/compiler/mlir/lite/experimental/tac/tests/raise-target-subgraphs.mlir @@ -314,7 +314,7 @@ func.func @cond_false_72730(%arg0: tensor, %arg1: tensor : tensor<1xi32> %cst_3 = arith.constant dense<0> : tensor<1xi32> %0 = "tfl.shape"(%arg2) {tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor) -> tensor<2xi32> - %1 = "tfl.strided_slice"(%0, %cst_3, %cst_2, %cst_2) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 1 : i32, tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %1 = "tfl.strided_slice"(%0, %cst_3, %cst_2, %cst_2) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 1 : i32, offset = false, tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor %2 = "tfl.custom"(%cst_1, %1) {custom_code = "FlexTensorListReserve", custom_option = #tfl, tac.device = "CPU", tac.inference_type = "FLOAT"} : (tensor<1xi32>, tensor) -> tensor>> %3 = "tfl.custom"(%cst_1, %1) {custom_code = "FlexTensorListReserve", custom_option = #tfl, tac.device = "CPU", tac.inference_type = "FLOAT"} : (tensor<1xi32>, tensor) -> tensor>> %4:8 = "tfl.while"(%cst_0, %cst_0, %arg5, %arg6, %2, %2, %3, %3) ({ @@ -337,23 +337,23 @@ func.func @cond_false_72730(%arg0: tensor, %arg1: tensor : tensor<1xi32> %cst_14 = arith.constant dense<0> : tensor<1xi32> %9 = "tfl.shape"(%arg1) {tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor) -> tensor<2xi32> - %10 = "tfl.strided_slice"(%9, %cst_14, %cst_13, %cst_13) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 1 : i32, tac.device = "DARWINN", tac.inference_type = "FLOAT"} : 
(tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %10 = "tfl.strided_slice"(%9, %cst_14, %cst_13, %cst_13) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 1 : i32, offset = false, tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor %11 = "tfl.range"(%cst_12, %10, %cst_11) {tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor, tensor, tensor) -> tensor %12 = "tfl.pack"(%10, %cst_11) {axis = 0 : i32, tac.device = "DARWINN", tac.inference_type = "FLOAT", values_count = 2 : i32} : (tensor, tensor) -> tensor<2xi32> - %13 = "tfl.strided_slice"(%9, %cst_13, %cst_10, %cst_13) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 1 : i32, tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %13 = "tfl.strided_slice"(%9, %cst_13, %cst_10, %cst_13) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 1 : i32, offset = false, tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor %14 = tfl.mul(%11, %13) {fused_activation_function = "NONE", tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor, tensor) -> tensor %15 = "tfl.reshape"(%14, %12) {tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor, tensor<2xi32>) -> tensor - %16 = "tfl.strided_slice"(%9, %cst_14, %cst_10, %cst_13) {begin_mask = 1 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 0 : i32, tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> + %16 = "tfl.strided_slice"(%9, %cst_14, %cst_10, %cst_13) {begin_mask = 1 : i32, ellipsis_mask = 0 : 
i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 0 : i32, offset = false, tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> %17 = "tfl.reduce_prod"(%16, %cst_14) {keep_dims = true, tac.device = "CPU", tac.inference_type = "FLOAT"} : (tensor<2xi32>, tensor<1xi32>) -> tensor<1xi32> %18 = "tfl.reshape"(%arg1, %17) {tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor, tensor<1xi32>) -> tensor %19 = "tfl.shape"(%arg0) {tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor) -> tensor<2xi32> - %20 = "tfl.strided_slice"(%19, %cst_14, %cst_13, %cst_13) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 1 : i32, tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %20 = "tfl.strided_slice"(%19, %cst_14, %cst_13, %cst_13) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 1 : i32, offset = false, tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor %21 = "tfl.range"(%cst_12, %20, %cst_11) {tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor, tensor, tensor) -> tensor %22 = "tfl.pack"(%20, %cst_11) {axis = 0 : i32, tac.device = "DARWINN", tac.inference_type = "FLOAT", values_count = 2 : i32} : (tensor, tensor) -> tensor<2xi32> - %23 = "tfl.strided_slice"(%19, %cst_13, %cst_10, %cst_13) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 1 : i32, tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %23 = "tfl.strided_slice"(%19, %cst_13, %cst_10, %cst_13) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : 
i32, shrink_axis_mask = 1 : i32, offset = false, tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor %24 = tfl.mul(%21, %23) {fused_activation_function = "NONE", tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor, tensor) -> tensor %25 = "tfl.reshape"(%24, %22) {tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor, tensor<2xi32>) -> tensor - %26 = "tfl.strided_slice"(%19, %cst_14, %cst_10, %cst_13) {begin_mask = 1 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 0 : i32, tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> + %26 = "tfl.strided_slice"(%19, %cst_14, %cst_10, %cst_13) {begin_mask = 1 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 0 : i32, offset = false, tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> %27 = "tfl.reduce_prod"(%26, %cst_14) {keep_dims = true, tac.device = "CPU", tac.inference_type = "FLOAT"} : (tensor<2xi32>, tensor<1xi32>) -> tensor<1xi32> %28 = "tfl.reshape"(%arg0, %27) {tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor, tensor<1xi32>) -> tensor %29 = tfl.add %arg8, %cst_11 {fused_activation_function = "NONE", tac.device = "DARWINN", tac.inference_type = "FLOAT"} : tensor @@ -406,7 +406,7 @@ func.func @cond_false_72730(%arg0: tensor, %arg1: tensor, %arg1: tensor<1xi32>, %arg2: tensor<1xi32>) -> tensor attributes {tac.device = "DARWINN", tac.inference_type = "FLOAT", tac.interface_name = "func_0"} { // CHECK: %0 = "tfl.shape"(%arg0) {tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor) -> tensor<2xi32> -// CHECK: %1 = "tfl.strided_slice"(%0, %arg1, %arg2, %arg2) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, 
shrink_axis_mask = 1 : i32, tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor +// CHECK: %1 = "tfl.strided_slice"(%0, %arg1, %arg2, %arg2) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, offset = false, shrink_axis_mask = 1 : i32, tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor // CHECK: return %1 : tensor // CHECK: } // CHECK: func.func private @func_1_CPU_FLOAT(%arg0: tensor<1xi32>, %arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor, %arg5: tensor, %arg6: tensor, %arg7: tensor, %arg8: tensor<2xi32>) -> (tensor, tensor) attributes {tac.device = "CPU", tac.inference_type = "FLOAT", tac.interface_name = "func_1"} { @@ -458,25 +458,25 @@ func.func @cond_false_72730(%arg0: tensor, %arg1: tensor, %arg1: tensor<1xi32>, %arg2: tensor<1xi32>, %arg3: tensor, %arg4: tensor, %arg5: tensor<1xi32>) -> (tensor, tensor<2xi32>) attributes {tac.device = "DARWINN", tac.inference_type = "FLOAT", tac.interface_name = "func_3"} { // CHECK: %0 = "tfl.shape"(%arg0) {tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor) -> tensor<2xi32> -// CHECK: %1 = "tfl.strided_slice"(%0, %arg1, %arg2, %arg2) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 1 : i32, tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor +// CHECK: %1 = "tfl.strided_slice"(%0, %arg1, %arg2, %arg2) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, offset = false, shrink_axis_mask = 1 : i32, tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor // CHECK: %2 = "tfl.range"(%arg3, %1, %arg4) {tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor, 
tensor, tensor) -> tensor // CHECK: %3 = "tfl.pack"(%1, %arg4) {axis = 0 : i32, tac.device = "DARWINN", tac.inference_type = "FLOAT", values_count = 2 : i32} : (tensor, tensor) -> tensor<2xi32> -// CHECK: %4 = "tfl.strided_slice"(%0, %arg2, %arg5, %arg2) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 1 : i32, tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor +// CHECK: %4 = "tfl.strided_slice"(%0, %arg2, %arg5, %arg2) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, offset = false, shrink_axis_mask = 1 : i32, tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor // CHECK: %5 = tfl.mul(%2, %4) {fused_activation_function = "NONE", tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor, tensor) -> tensor // CHECK: %6 = "tfl.reshape"(%5, %3) {tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor, tensor<2xi32>) -> tensor -// CHECK: %7 = "tfl.strided_slice"(%0, %arg1, %arg5, %arg2) {begin_mask = 1 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 0 : i32, tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> +// CHECK: %7 = "tfl.strided_slice"(%0, %arg1, %arg5, %arg2) {begin_mask = 1 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, offset = false, shrink_axis_mask = 0 : i32, tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> // CHECK: return %6, %7 : tensor, tensor<2xi32> // CHECK: } // CHECK: func.func private @func_4_DARWINN_FLOAT(%arg0: tensor, %arg1: tensor<1xi32>, %arg2: tensor, %arg3: tensor<1xi32>, %arg4: tensor<1xi32>, %arg5: tensor, %arg6: tensor, %arg7: 
tensor<1xi32>) -> (tensor, tensor, tensor<2xi32>) attributes {tac.device = "DARWINN", tac.inference_type = "FLOAT", tac.interface_name = "func_4"} { // CHECK: %0 = "tfl.reshape"(%arg0, %arg1) {tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor, tensor<1xi32>) -> tensor // CHECK: %1 = "tfl.shape"(%arg2) {tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor) -> tensor<2xi32> -// CHECK: %2 = "tfl.strided_slice"(%1, %arg3, %arg4, %arg4) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 1 : i32, tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor +// CHECK: %2 = "tfl.strided_slice"(%1, %arg3, %arg4, %arg4) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, offset = false, shrink_axis_mask = 1 : i32, tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor // CHECK: %3 = "tfl.range"(%arg5, %2, %arg6) {tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor, tensor, tensor) -> tensor // CHECK: %4 = "tfl.pack"(%2, %arg6) {axis = 0 : i32, tac.device = "DARWINN", tac.inference_type = "FLOAT", values_count = 2 : i32} : (tensor, tensor) -> tensor<2xi32> -// CHECK: %5 = "tfl.strided_slice"(%1, %arg4, %arg7, %arg4) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 1 : i32, tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor +// CHECK: %5 = "tfl.strided_slice"(%1, %arg4, %arg7, %arg4) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, offset = false, shrink_axis_mask = 1 : i32, tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor // CHECK: %6 = tfl.mul(%3, 
%5) {fused_activation_function = "NONE", tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor, tensor) -> tensor // CHECK: %7 = "tfl.reshape"(%6, %4) {tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor, tensor<2xi32>) -> tensor -// CHECK: %8 = "tfl.strided_slice"(%1, %arg3, %arg7, %arg4) {begin_mask = 1 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 0 : i32, tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> +// CHECK: %8 = "tfl.strided_slice"(%1, %arg3, %arg7, %arg4) {begin_mask = 1 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, offset = false, shrink_axis_mask = 0 : i32, tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> // CHECK: return %0, %7, %8 : tensor, tensor, tensor<2xi32> // CHECK: } // CHECK: func.func private @func_5_DARWINN_FLOAT(%arg0: tensor, %arg1: tensor<1xi32>, %arg2: tensor, %arg3: tensor, %arg4: tensor, %arg5: tensor, %arg6: tensor<1xi32>, %arg7: tensor, %arg8: tensor, %arg9: tensor, %arg10: tensor, %arg11: tensor, %arg12: tensor, %arg13: tensor<5xi32>, %arg14: tensor, %arg15: tensor<5xi32>, %arg16: tensor, %arg17: tensor) -> (tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor) attributes {tac.device = "DARWINN", tac.inference_type = "FLOAT", tac.interface_name = "func_5"} { diff --git a/tensorflow/compiler/mlir/lite/experimental/tac/tests/tac-filter.mlir b/tensorflow/compiler/mlir/lite/experimental/tac/tests/tac-filter.mlir new file mode 100644 index 00000000000..9b6d68c49f5 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/experimental/tac/tests/tac-filter.mlir @@ -0,0 +1,64 @@ +// RUN: tac-opt-all-backends -tfl-tac-filter='use-test-setting=true' %s -split-input-file -verify-diagnostics | FileCheck %s + +// expected-remark@below {{Tac filter (0): 
filter type: function filter SKIP_TARGET_ANNOTATION, filter_pattern: "^testFunction"}} +// expected-remark@below {{Tac filter (1): filter type: function filter INCLUDE_TARGET_ANNOTATION, filter_pattern: "testFunctionInclude"}} +// expected-remark@below {{Tac filter (1) specified but not applied to any op}} +// expected-remark@below {{Tac filter (2): filter type: op filter, filter_pattern: "^test_op"}} +// expected-remark@below {{Tac filter (2) specified but not applied to any op}} +module { + // CHECK-LABEL: testFunctionSkiped + // expected-remark@+1 {{filtered by tac filter (0)}} + func.func @testFunctionSkiped(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) { + // CHECK: tfl.add + // CHECK-SAME: tac.skip_target_annotation + %0 = "tfl.add"(%arg0, %arg1) {fused_activation_function = "RELU6"} : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> + // CHECK: tfl.add + // CHECK-SAME: tac.skip_target_annotation + %1 = "tfl.add"(%arg0, %0) {fused_activation_function = "RELU"} : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> + // CHECK: tfl.relu + // CHECK-SAME: tac.skip_target_annotation + %2 = "tfl.relu"(%arg0) : (tensor<1xf32>) -> tensor<1xf32> + func.return + } +} + +// ----- + +// expected-remark@below {{Tac filter (0): filter type: function filter SKIP_TARGET_ANNOTATION, filter_pattern: "^testFunction"}} +// expected-remark@below {{Tac filter (1): filter type: function filter INCLUDE_TARGET_ANNOTATION, filter_pattern: "testFunctionInclude"}} +// expected-remark@below {{Tac filter (2): filter type: op filter, filter_pattern: "^test_op"}} +// expected-remark@below {{Tac filter (2) specified but not applied to any op}} +module { + // CHECK-LABEL: testFunctionInclude + // CHECK-NOT: tac.skip_target_annotation + // expected-remark@+2 {{filtered by tac filter (0)}} + // expected-remark@+1 {{filtered by tac filter (1)}} + func.func @testFunctionInclude(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) { + %0 = "tfl.add"(%arg0, %arg1) {fused_activation_function = "RELU6"} : 
(tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> + func.return + } +} + +// ----- + +// expected-remark@below {{Tac filter (0): filter type: function filter SKIP_TARGET_ANNOTATION, filter_pattern: "^testFunction"}} +// expected-remark@below {{Tac filter (0) specified but not applied to any op}} +// expected-remark@below {{Tac filter (1): filter type: function filter INCLUDE_TARGET_ANNOTATION, filter_pattern: "testFunctionInclude"}} +// expected-remark@below {{Tac filter (1) specified but not applied to any op}} +// expected-remark@below {{Tac filter (2): filter type: op filter, filter_pattern: "^test_op"}} +module { + // CHECK-LABEL: testOpFilter + // expected-remark@+1 {{all ops filtered by tac filter (2): "tfl.add", "tfl.relu"}} + func.func @testOpFilter(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) { + // CHECK: tfl.add + // CHECK-SAME: tac.skip_target_annotation + %0 = "tfl.add"(%arg0, %arg1) {fused_activation_function = "RELU6"} : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> loc("test_op_0") + // CHECK: tfl.add + // CHECK-NOT: tac.skip_target_annotation + %1 = "tfl.add"(%arg0, %0) {fused_activation_function = "RELU"} : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> loc("non_test_op") + // CHECK: tfl.relu + // CHECK-SAME: tac.skip_target_annotation + %2 = "tfl.relu"(%arg0) : (tensor<1xf32>) -> tensor<1xf32> loc("test_op_1") + func.return + } +} diff --git a/tensorflow/compiler/mlir/lite/experimental/tac/tests/target-annotation.mlir b/tensorflow/compiler/mlir/lite/experimental/tac/tests/target-annotation.mlir index 22faae6016c..8197ca323c4 100644 --- a/tensorflow/compiler/mlir/lite/experimental/tac/tests/target-annotation.mlir +++ b/tensorflow/compiler/mlir/lite/experimental/tac/tests/target-annotation.mlir @@ -80,3 +80,12 @@ func.func @annotateInferenceType(%arg0: tensor<1x1x384x!quant.uniform>, tensor<1x384x1x!quant.uniform>) -> tensor<1x384x384x!quant.uniform> func.return %1 : tensor<1x384x384x!quant.uniform> } + +// ----- + +// CHECK-LABEL: 
testSkipAnnotation +func.func @testSkipAnnotation(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<16x3x3x3xf32>, %arg2: tensor<16xf32>) -> tensor<256x30x30x16xf32> { + // CHECK-NOT: tac.device + %0 = "tfl.conv_2d"(%arg0, %arg1, %arg2) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 1 : i32, stride_w = 1 : i32, tac.skip_target_annotation } : (tensor<256x32x32x3xf32>, tensor<16x3x3x3xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32> + func.return %0 : tensor<256x30x30x16xf32> +} diff --git a/tensorflow/compiler/mlir/lite/experimental/tac/tflite_import_export.cc b/tensorflow/compiler/mlir/lite/experimental/tac/tflite_import_export.cc index e579b60869b..bf3481b79a4 100644 --- a/tensorflow/compiler/mlir/lite/experimental/tac/tflite_import_export.cc +++ b/tensorflow/compiler/mlir/lite/experimental/tac/tflite_import_export.cc @@ -14,6 +14,8 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/mlir/lite/experimental/tac/tflite_import_export.h" +#include +#include #include #include "absl/status/status.h" diff --git a/tensorflow/compiler/mlir/lite/experimental/tac/tflite_import_export.h b/tensorflow/compiler/mlir/lite/experimental/tac/tflite_import_export.h index 2dcba4ab868..ed59787f946 100644 --- a/tensorflow/compiler/mlir/lite/experimental/tac/tflite_import_export.h +++ b/tensorflow/compiler/mlir/lite/experimental/tac/tflite_import_export.h @@ -15,6 +15,7 @@ limitations under the License. 
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_TAC_TFLITE_IMPORT_EXPORT_H_ #define TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_TAC_TFLITE_IMPORT_EXPORT_H_ +#include #include #include diff --git a/tensorflow/compiler/mlir/lite/experimental/tac/transforms/device_transform_patterns.cc b/tensorflow/compiler/mlir/lite/experimental/tac/transforms/device_transform_patterns.cc index 330d3096b51..9760bad9998 100644 --- a/tensorflow/compiler/mlir/lite/experimental/tac/transforms/device_transform_patterns.cc +++ b/tensorflow/compiler/mlir/lite/experimental/tac/transforms/device_transform_patterns.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/experimental/tac/transforms/device_transform_patterns.h" +#include #include #include "llvm/ADT/ArrayRef.h" diff --git a/tensorflow/compiler/mlir/lite/experimental/tac/transforms/get_alternative_subgraph.cc b/tensorflow/compiler/mlir/lite/experimental/tac/transforms/get_alternative_subgraph.cc index 42852425741..7ccf26d3bac 100644 --- a/tensorflow/compiler/mlir/lite/experimental/tac/transforms/get_alternative_subgraph.cc +++ b/tensorflow/compiler/mlir/lite/experimental/tac/transforms/get_alternative_subgraph.cc @@ -16,6 +16,7 @@ limitations under the License. #include #include #include +#include #include "absl/strings/str_cat.h" #include "llvm/ADT/ArrayRef.h" diff --git a/tensorflow/compiler/mlir/lite/experimental/tac/transforms/passes.h b/tensorflow/compiler/mlir/lite/experimental/tac/transforms/passes.h index f738b2e7a60..a16b0f772c0 100644 --- a/tensorflow/compiler/mlir/lite/experimental/tac/transforms/passes.h +++ b/tensorflow/compiler/mlir/lite/experimental/tac/transforms/passes.h @@ -22,6 +22,7 @@ limitations under the License. 
#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/experimental/tac/tac_filter.pb.h" namespace mlir { namespace TFL { @@ -64,6 +65,11 @@ std::unique_ptr> CreateGetOpCostPass(); std::unique_ptr> CreateFoldConstantsToSubgraphPass( bool fold_all_constants); +// Create an instance of TacFilterPass. +std::unique_ptr> CreateTacFilterPass( + ::third_party::tensorflow::compiler::mlir::lite::experimental::tac:: + TacFilters* tac_filters); + } // namespace tac } // namespace TFL } // namespace mlir diff --git a/tensorflow/compiler/mlir/lite/experimental/tac/transforms/tac_filter.cc b/tensorflow/compiler/mlir/lite/experimental/tac/transforms/tac_filter.cc new file mode 100644 index 00000000000..a2f7441cc17 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/experimental/tac/transforms/tac_filter.cc @@ -0,0 +1,259 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include +#include + +#include "google/protobuf/text_format.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/ADT/Twine.h" +#include "llvm/Support/FormatVariadic.h" +#include "llvm/Support/Regex.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/Diagnostics.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/experimental/tac/common/targets.h" +#include "tensorflow/compiler/mlir/lite/experimental/tac/tac_filter.pb.h" +#include "tensorflow/compiler/mlir/lite/experimental/tac/transforms/passes.h" + +namespace mlir { +namespace TFL { +namespace tac { +namespace { + +using ::third_party::tensorflow::compiler::mlir::lite::experimental::tac:: + FunctionFilter; +using ::third_party::tensorflow::compiler::mlir::lite::experimental::tac:: + TacFilter; +using ::third_party::tensorflow::compiler::mlir::lite::experimental::tac:: + TacFilters; + +class TacFilterPass + : public PassWrapper> { + public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TacFilterPass) + + TacFilterPass() = default; + TacFilterPass(const TacFilterPass& other) { + this->tac_filters_ = other.tac_filters_; + } + explicit TacFilterPass(TacFilters* tac_filters) { + tac_filters_ = tac_filters; + } + + private: + TacFilters* tac_filters_ = nullptr; + + llvm::StringRef getArgument() const final { return "tfl-tac-filter"; } + llvm::StringRef getDescription() const final { + return "This pass marks the ops to skip target annotation by inserting " + "`tac.skip_target_annotation` attribute to them based on user " + 
"provided config."; + } + + Option use_test_setting_{ + *this, "use-test-setting", + llvm::cl::desc( + "Whether to use the test config for the tac filter protobuf."), + llvm::cl::init(false)}; + + void runOnOperation() override; +}; + +void ApplyFunctionTacFilter(func::FuncOp func, + FunctionFilter::FunctionFilterType type, + OpBuilder& builder) { + for (Operation& op : func.front()) { + if (type == FunctionFilter::SKIP_TARGET_ANNOTATION) { + op.setAttr(kSkipTargetAnnotation, builder.getUnitAttr()); + } else if (type == FunctionFilter::INCLUDE_TARGET_ANNOTATION) { + op.removeAttr(kSkipTargetAnnotation); + } + } +} + +void ApplyTacFilter(ModuleOp module, const TacFilter& tac_filter, + SmallVector& filtered_ops, OpBuilder& builder) { + if (tac_filter.has_function_filter()) { + llvm::Regex func_regex( + tac_filter.function_filter().function_name_pattern()); + for (auto func : module.getOps()) { + if (!func_regex.match(func.getName())) { + continue; + } + + ApplyFunctionTacFilter(func, tac_filter.function_filter().filter_type(), + builder); + filtered_ops.push_back(func); + } + return; + } + + llvm::Regex op_regex(tac_filter.op_filter().op_name_pattern()); + module.walk([&](Operation* op) { + auto named_loc = op->getLoc().dyn_cast(); + if (!named_loc) { + return; + } + if (!op_regex.match(named_loc.getName())) { + return; + } + + op->setAttr(kSkipTargetAnnotation, builder.getUnitAttr()); + filtered_ops.push_back(op); + }); +} + +// A custom string for tac filter. 
+std::string TacFilterToString(const TacFilter& tac_filter) { + std::string tac_filter_type_str; + std::string tac_filter_name_pattern_str; + if (tac_filter.has_function_filter()) { + tac_filter_type_str = (llvm::Twine("function filter ") + + FunctionFilter::FunctionFilterType_Name( + tac_filter.function_filter().filter_type())) + .str(); + tac_filter_name_pattern_str = + tac_filter.function_filter().function_name_pattern(); + } else { + tac_filter_type_str = "op filter"; + tac_filter_name_pattern_str = tac_filter.op_filter().op_name_pattern(); + } + return (llvm::Twine("filter type: ") + tac_filter_type_str + + ", filter_pattern: \"" + tac_filter_name_pattern_str + "\"") + .str(); +} + +void PrintTacFilterResult(Location module_loc, const TacFilter& tac_filter, + int count, + const SmallVector& filtered_ops) { + emitRemark(module_loc) << llvm::formatv("Tac filter ({0}): {1}", count, + TacFilterToString(tac_filter)); + if (filtered_ops.empty()) { + emitRemark(module_loc) << llvm::formatv( + "Tac filter ({0}) specified but not applied to any op", count); + return; + } + + if (tac_filter.has_function_filter()) { + for (Operation* op : filtered_ops) { + auto func = cast(op); + func.emitRemark() << llvm::formatv("filtered by tac filter ({0})", count); + } + return; + } + + DenseMap> func_to_filtered_ops_map; + for (Operation* op : filtered_ops) { + auto func = op->getParentOfType(); + func_to_filtered_ops_map[func].push_back(op); + } + for (auto& [func, ops] : func_to_filtered_ops_map) { + std::string interleaved_op_name; + llvm::raw_string_ostream os(interleaved_op_name); + llvm::interleaveComma( + ops, os, [&](Operation* op) { os << "\"" << op->getName() << "\""; }); + os.flush(); + func.emitRemark() << llvm::formatv( + "all ops filtered by tac filter ({0}): {1}", count, + interleaved_op_name); + } +} + +void TacFilterPass::runOnOperation() { + TacFilters test_tac_filters; + if (use_test_setting_) { + // Sets up the test config used in the mlir LIT test. 
+ google::protobuf::TextFormat::ParseFromString(R"( + tac_filters { + function_filter { + function_name_pattern: "^testFunction" + } + } + tac_filters { + function_filter { + function_name_pattern: "testFunctionInclude" + filter_type: INCLUDE_TARGET_ANNOTATION + } + } + tac_filters { + op_filter { + op_name_pattern: "^test_op" + } + } + )", + &test_tac_filters); + tac_filters_ = &test_tac_filters; + } + + if (!tac_filters_) { + return; + } + + ModuleOp module = getOperation(); + OpBuilder builder(module); + std::sort(tac_filters_->mutable_tac_filters()->pointer_begin(), + tac_filters_->mutable_tac_filters()->pointer_end(), + [](const TacFilter* a, const TacFilter* b) { + const bool a_is_function_filter = a->has_function_filter(); + const bool b_is_function_filter = b->has_function_filter(); + if (a_is_function_filter != b_is_function_filter) { + // Function filter is applied before op filter. + return a_is_function_filter > b_is_function_filter; + } + + if (!a_is_function_filter && !b_is_function_filter) { + // The order of 2 op filters doesn't matter. + return false; + } + + const bool a_is_function_exclude = + (a->function_filter().filter_type() == + FunctionFilter::SKIP_TARGET_ANNOTATION); + const bool b_is_function_exclude = + (b->function_filter().filter_type() == + FunctionFilter::SKIP_TARGET_ANNOTATION); + // Function exclude filter is applied before function include + // filter.
+ return a_is_function_exclude > b_is_function_exclude; + }); + + for (const auto& tac_filter : llvm::enumerate(tac_filters_->tac_filters())) { + SmallVector filtered_ops; + ApplyTacFilter(module, tac_filter.value(), filtered_ops, builder); + PrintTacFilterResult(module.getLoc(), tac_filter.value(), + tac_filter.index(), filtered_ops); + } +} + +} // namespace + +std::unique_ptr> CreateTacFilterPass( + TacFilters* tac_filters) { + return std::make_unique(tac_filters); +} + +static PassRegistration pass; + +} // namespace tac +} // namespace TFL +} // namespace mlir diff --git a/tensorflow/compiler/mlir/lite/experimental/tac/transforms/tac_pass.h b/tensorflow/compiler/mlir/lite/experimental/tac/transforms/tac_pass.h index 392a2713e95..6e61dbe99bb 100644 --- a/tensorflow/compiler/mlir/lite/experimental/tac/transforms/tac_pass.h +++ b/tensorflow/compiler/mlir/lite/experimental/tac/transforms/tac_pass.h @@ -15,6 +15,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_TAC_TRANSFORMS_TAC_PASS_H_ #define TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_TAC_TRANSFORMS_TAC_PASS_H_ +#include #include #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project @@ -39,7 +40,7 @@ class TacPass : public OperationPass { : OperationPass::OperationPass(mlir::TypeID::get()), module_(module) {} - ~TacPass() override {} + ~TacPass() override = default; const TargetHardware* GetTargetHardware( const std::string& hardware_name) const { @@ -62,7 +63,7 @@ class TacFunctionPass : public TacPass { public: using TacPass::TacPass; - ~TacFunctionPass() override {} + ~TacFunctionPass() override = default; mlir::func::FuncOp getFunction() { return getOperation(); } diff --git a/tensorflow/compiler/mlir/lite/experimental/tac/transforms/target_annotation.cc b/tensorflow/compiler/mlir/lite/experimental/tac/transforms/target_annotation.cc index 2dddad4e9a8..6d1bf7ab934 100644 --- a/tensorflow/compiler/mlir/lite/experimental/tac/transforms/target_annotation.cc +++ 
b/tensorflow/compiler/mlir/lite/experimental/tac/transforms/target_annotation.cc @@ -92,6 +92,9 @@ void SetAnnotation(Operation* op, std::string attribute, std::string annotation, void TargetAnnotationPass::SetTargetAnnotation( Operation* op, llvm::ArrayRef device_specs, OpBuilder* builder) { + if (op->hasAttr(kSkipTargetAnnotation)) { + return; + } const InferenceType inference_type = GetInferenceType(op); const std::string inference_type_str = GetInferenceString(inference_type); SetAnnotation(op, kInferenceType, inference_type_str, builder); diff --git a/tensorflow/compiler/mlir/lite/experimental/tac/utils/utils.cc b/tensorflow/compiler/mlir/lite/experimental/tac/utils/utils.cc index 20c81962e5a..aef77e208d2 100644 --- a/tensorflow/compiler/mlir/lite/experimental/tac/utils/utils.cc +++ b/tensorflow/compiler/mlir/lite/experimental/tac/utils/utils.cc @@ -15,8 +15,10 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/experimental/tac/utils/utils.h" +#include #include #include +#include #include "absl/status/status.h" #include "absl/status/statusor.h" diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_export.cc b/tensorflow/compiler/mlir/lite/flatbuffer_export.cc index 1e4475bd4b3..9bb1c172116 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_export.cc +++ b/tensorflow/compiler/mlir/lite/flatbuffer_export.cc @@ -649,7 +649,7 @@ class Translator { // to a function's body or while op. Modifies *region by calling // ExtractControlEdges. std::optional> BuildSubGraph( - const std::string& name, Region* region, const int index); + const std::string& name, Region* region, int index); // Modifies *block by unwrapping all ControlNodeOps. The DAG of the control // dependencies is returned as a vector of its edges, with node indices into @@ -674,8 +674,7 @@ class Translator { // 'items' is a map from tensor name in signatureDef to tensor name in // the subgraph, specified by the 'subgraph_index' argument. 
std::vector> GetList( - const int subgraph_index, - const std::map& items); + int subgraph_index, const std::map& items); // Uses the tf.entry_function attribute (if set) to initialize the op to name // mapping. diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_import.cc b/tensorflow/compiler/mlir/lite/flatbuffer_import.cc index 487b3edd60a..143b67acf96 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_import.cc +++ b/tensorflow/compiler/mlir/lite/flatbuffer_import.cc @@ -20,7 +20,9 @@ limitations under the License. #include #include #include +#include #include +#include #include #include #include @@ -973,12 +975,19 @@ StatusOr> GetTensorIndices( return indices; } +// Given a list of tensor indices, returns true if any of the tensors have +// non-empty name strings. +bool HasNonEmptyNames(const tflite::SubGraphT& subgraph, + ArrayRef indices) { + return llvm::any_of( + indices, [&](int i) { return !subgraph.tensors.at(i)->name.empty(); }); +} + // Given a list of tensor indices, returns a string of concatenated tensor names // wrapped in a NamedAttribute. 
-template mlir::NamedAttribute BuildTFEntryFunctionAttribute( - const tflite::SubGraphT& subgraph, Builder* builder, const std::string name, - const ContainerType indices) { + const tflite::SubGraphT& subgraph, Builder* builder, + const std::string& name, ArrayRef indices) { auto tensor_names = llvm::map_range( indices, [&](int i) { return subgraph.tensors.at(i)->name; }); return builder->getNamedAttr( @@ -1351,15 +1360,17 @@ StatusOr ConvertSubgraph( // Set tf.entry_function attribute if (is_entry_point) { llvm::SmallVector attributes; - if (!func_inputs.empty()) { + if (HasNonEmptyNames(subgraph, func_inputs)) { attributes.push_back(BuildTFEntryFunctionAttribute( subgraph, &builder, "inputs", func_inputs)); } - if (!func_outputs.empty()) { + if (HasNonEmptyNames(subgraph, func_outputs)) { attributes.push_back(BuildTFEntryFunctionAttribute( subgraph, &builder, "outputs", func_outputs)); } - func->setAttr("tf.entry_function", builder.getDictionaryAttr(attributes)); + if (!attributes.empty()) { + func->setAttr("tf.entry_function", builder.getDictionaryAttr(attributes)); + } } else { func.setPrivate(); } diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_import.h b/tensorflow/compiler/mlir/lite/flatbuffer_import.h index 8707be2894e..76edd13afd4 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_import.h +++ b/tensorflow/compiler/mlir/lite/flatbuffer_import.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_IMPORT_H_ #include +#include #include "absl/strings/string_view.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_operator.cc b/tensorflow/compiler/mlir/lite/flatbuffer_operator.cc index 35475091aa8..2f1779b97d0 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_operator.cc +++ b/tensorflow/compiler/mlir/lite/flatbuffer_operator.cc @@ -15,6 +15,7 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/lite/flatbuffer_operator.h" +#include #include #include diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_translate.cc b/tensorflow/compiler/mlir/lite/flatbuffer_translate.cc index 88e3029188b..62cb3447313 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_translate.cc +++ b/tensorflow/compiler/mlir/lite/flatbuffer_translate.cc @@ -13,6 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include +#include + #include "llvm/Support/CommandLine.h" #include "llvm/Support/FormatVariadic.h" #include "llvm/Support/MemoryBuffer.h" diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc index f23dfd96e88..58820f0edee 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc +++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc @@ -16,9 +16,11 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include +#include #include #include #include +#include #include #include #include @@ -3330,33 +3332,36 @@ OpFoldResult StridedSliceOp::fold(FoldAdaptor) { namespace { -// Computes the permutation of a constant `input_tensor` according to `perm`. // The function recursively traverses the dimensions of the output tensor in -// a row-major order and writes the value in the output tensor into -// `new_values`. 
-void ComputePermutation(mlir::detail::ElementsAttrRange< - mlir::detail::ElementsAttrIterator> - input_tensor_values, - ArrayRef perm, ArrayRef output_shape, - const int num_dimensions, const int output_axis, - std::vector* input_indices, - std::vector* new_values) { - // Refer to the implementation of `Transpose` function in - // tensorflow/lite/kernels/internal/reference/reference_ops.h - assert(output_axis < num_dimensions); - const int input_axis = perm[output_axis]; - for (int i = 0; i < output_shape[output_axis]; ++i) { +// a row-major order and writes the value of the output tensor into +// `output_element_addr`. +// TODO(@lukeboyer) make element byte size a template param. +void ComputePermutation(ArrayRef perms, ArrayRef output_shape, + const char* raw_input, const int element_byte_size, + const int64_t current_axis, char*& output_element_addr, + MutableArrayRef current_input_index, + ShapedType input_shape_type) { + const int64_t input_axis = perms[current_axis]; + const bool is_last_axis = current_axis == output_shape.size() - 1; + for (int i = 0; i < output_shape[current_axis]; ++i) { // Update the input indices on `input_axis`. - input_indices->at(input_axis) = i; + current_input_index[input_axis] = i; // Write the value from `input_tensor` if it is the last axis or // recurse into the next axis. - const bool is_last_axis = output_axis == num_dimensions - 1; if (is_last_axis) { - new_values->push_back(input_tensor_values[*input_indices]); + int64_t input_flat_index = ElementsAttr::getFlattenedIndex( + input_shape_type, current_input_index); + // Address of input element to write raw data. + const char* input_element_addr = + raw_input + (input_flat_index * element_byte_size); + std::memcpy(output_element_addr, input_element_addr, element_byte_size); + // Increment the next output address to write to by bytes equal to + // width of constituent elements.
+ output_element_addr += element_byte_size; } else { - ComputePermutation(input_tensor_values, perm, output_shape, - num_dimensions, output_axis + 1, input_indices, - new_values); + ComputePermutation(perms, output_shape, raw_input, element_byte_size, + current_axis + 1, output_element_addr, + current_input_index, input_shape_type); } } } @@ -3365,8 +3370,8 @@ void ComputePermutation(mlir::detail::ElementsAttrRange< OpFoldResult TransposeOp::fold(FoldAdaptor adaptor) { auto operands = adaptor.getOperands(); - assert(operands.size() == 2); - auto input_tensor = operands[0].dyn_cast_or_null(); + + auto input_tensor = operands[0].dyn_cast_or_null(); auto perm_tensor = operands[1].dyn_cast_or_null(); if (!input_tensor || !perm_tensor) return nullptr; @@ -3375,33 +3380,56 @@ OpFoldResult TransposeOp::fold(FoldAdaptor adaptor) { if (!getType().cast().getElementType().isSignlessIntOrFloat()) return nullptr; - assert(perm_tensor.getShapedType().getRank() == 1); - const int num_dimensions = input_tensor.getShapedType().getRank(); - assert(perm_tensor.getShapedType().getNumElements() == num_dimensions); - - ArrayRef input_shape = input_tensor.getShapedType().getShape(); - auto output_type = getType().cast(); - - SmallVector perm; - SmallVector output_shape; + // TODO(b/280099953) This algorithm only works for fixed width element types. + // This is the usual case, but consider falling back to old approach + // if transposing string tensors becomes needed while folding. + if (!input_tensor.getElementType().isIntOrIndexOrFloat()) return nullptr; + SmallVector perms; + SmallVector output_shape; + ArrayRef input_shape = input_tensor.getType().getShape(); + auto attr_iter = perm_tensor.getValues(); + const int num_dimensions = input_tensor.getType().getRank(); for (int i = 0; i < num_dimensions; ++i) { - perm.push_back(perm_tensor.getValues()[i].getInt()); - output_shape.push_back(input_shape[perm[i]]); - - // Check that the derived output shape matches the static shape. 
- assert(!output_type.hasStaticShape() || - output_type.getShape()[i] == output_shape[i]); + perms.push_back(attr_iter[i].getInt()); + output_shape.push_back(input_shape[perms[i]]); } - std::vector new_values; - new_values.reserve(input_tensor.getShapedType().getNumElements()); - std::vector input_indices(num_dimensions); - auto input_tensor_values = input_tensor.getValues(); - ComputePermutation(input_tensor_values, perm, output_shape, num_dimensions, - /*output_axis=*/0, &input_indices, &new_values); - auto result_type = tensorflow::GetTypeFromTFTensorShape( - output_shape, output_type.getElementType()); - return DenseElementsAttr::get(result_type, new_values); + // If the input tensor values are splat, then it has exactly one value. + // It is sufficient then to just reshape the input data. + if (input_tensor.isSplat()) { + return input_tensor.reshape(input_tensor.getType().cloneWith( + output_shape, input_tensor.getElementType())); + } + + // MLIR implementation pads elements < 8 bits to 8 bits and pads non byte + // aligned to the nearest byte. So this is allowed. + const char* raw_input = input_tensor.getRawData().data(); + const int element_byte_size = + input_tensor.getElementType().getIntOrFloatBitWidth() / 8; + + // Hold current ND index in input tensor when computing + // permutation. + llvm::OwningArrayRef current_input_index( + input_tensor.getType().getRank()); + + // Allocate raw data and retrieve address of the first char in its raw + // buffer. + llvm::OwningArrayRef raw_output_arr(input_tensor.getRawData()); + char* raw_output = (char*)raw_output_arr.data(); + + // Compute the result and write to `raw_output`. 
+ ComputePermutation(perms, output_shape, raw_input, element_byte_size, + /*current_axis=*/0, raw_output, current_input_index, + input_tensor.getType()); + + bool detected_splat = false; + const bool valid_output_buffer = DenseElementsAttr::isValidRawBuffer( + input_tensor.getType(), raw_output_arr, detected_splat); + if (!valid_output_buffer || detected_splat) return nullptr; + + auto result_type = + RankedTensorType::get(output_shape, input_tensor.getElementType()); + return DenseElementsAttr::getFromRawBuffer(result_type, raw_output_arr); } mlir::LogicalResult TransposeOp::verify() { diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td index 8266fc605c0..01c77e4f21c 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td +++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td @@ -803,14 +803,14 @@ def TFL_ConcatenationOp : TFL_Op<"concatenation", let arguments = ( ins TFL_VariadicTensorOf< - [F32, I64, I32, I16, I8, QI8, QUI8, UI8, I1]>:$values, + [F32, I64, I32, I16, I8, QI8, QUI8, UI8, UI32, I1]>:$values, I32Attr:$axis, TFL_AFAttr:$fused_activation_function ); let results = (outs TFL_TensorOf< - [F32, I64, I32, I16, I8, QI8, QUI8, UI8, I1]>:$output + [F32, I64, I32, I16, I8, QI8, QUI8, UI8, UI32, I1]>:$output ); let hasOptions = 1; @@ -3063,11 +3063,11 @@ def TFL_RangeOp: TFL_Op<"range", [ }]; let arguments = (ins - TFL_TensorOf<[I32, F32]>:$start, - TFL_TensorOf<[I32, F32]>:$limit, - TFL_TensorOf<[I32, F32]>:$delta); + TFL_TensorOf<[I32, F32, I64]>:$start, + TFL_TensorOf<[I32, F32, I64]>:$limit, + TFL_TensorOf<[I32, F32, I64]>:$delta); - let results = (outs TFL_TensorOf<[I32, F32]>:$result); + let results = (outs TFL_TensorOf<[I32, F32, I64]>:$result); let hasFolder = 1; } @@ -3873,7 +3873,8 @@ def TFL_StridedSliceOp: TFL_Op<"strided_slice", [ I32Attr:$end_mask, I32Attr:$ellipsis_mask, I32Attr:$new_axis_mask, - I32Attr:$shrink_axis_mask + I32Attr:$shrink_axis_mask, + BoolAttr:$offset ); let results = (outs diff 
--git a/tensorflow/compiler/mlir/lite/metrics/error_collector_inst.h b/tensorflow/compiler/mlir/lite/metrics/error_collector_inst.h index 738ea1ecd2d..322ec2e852d 100644 --- a/tensorflow/compiler/mlir/lite/metrics/error_collector_inst.h +++ b/tensorflow/compiler/mlir/lite/metrics/error_collector_inst.h @@ -15,7 +15,9 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_LITE_METRICS_ERROR_COLLECTOR_INST_H_ #define TENSORFLOW_COMPILER_MLIR_LITE_METRICS_ERROR_COLLECTOR_INST_H_ +#include #include +#include #include #include "mlir/IR/Location.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/lite/metrics/error_collector_inst_test.cc b/tensorflow/compiler/mlir/lite/metrics/error_collector_inst_test.cc index 7874f6c5f3c..75a0c3eb3bb 100644 --- a/tensorflow/compiler/mlir/lite/metrics/error_collector_inst_test.cc +++ b/tensorflow/compiler/mlir/lite/metrics/error_collector_inst_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/metrics/error_collector_inst.h" #include +#include #include #include #include @@ -52,7 +53,7 @@ class MockSuccessPass public: MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MockSuccessPass) - explicit MockSuccessPass() {} + explicit MockSuccessPass() = default; private: void runOnOperation() override { @@ -73,7 +74,7 @@ class MockFailurePass public: MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MockFailurePass) - explicit MockFailurePass() {} + explicit MockFailurePass() = default; private: void runOnOperation() override { diff --git a/tensorflow/compiler/mlir/lite/mlir_tflite_runner.cc b/tensorflow/compiler/mlir/lite/mlir_tflite_runner.cc index bef202887a8..4f97fb56f86 100644 --- a/tensorflow/compiler/mlir/lite/mlir_tflite_runner.cc +++ b/tensorflow/compiler/mlir/lite/mlir_tflite_runner.cc @@ -21,6 +21,7 @@ limitations under the License. 
#include #include #include +#include #include #include diff --git a/tensorflow/compiler/mlir/lite/python/flatbuffer_to_mlir.cc b/tensorflow/compiler/mlir/lite/python/flatbuffer_to_mlir.cc index b71d6c1dbd2..02743c9c65f 100644 --- a/tensorflow/compiler/mlir/lite/python/flatbuffer_to_mlir.cc +++ b/tensorflow/compiler/mlir/lite/python/flatbuffer_to_mlir.cc @@ -13,8 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include #include +#include #include "llvm/Support/CommandLine.h" #include "llvm/Support/FormatVariadic.h" diff --git a/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.cc b/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.cc index 85c87fd66ad..344c558ba3e 100644 --- a/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.cc +++ b/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include #include +#include #include "llvm/Support/ToolOutputFile.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/lite/python/jax_to_tfl_flatbuffer.cc b/tensorflow/compiler/mlir/lite/python/jax_to_tfl_flatbuffer.cc index 1b0f22c7cd1..473f63812bd 100644 --- a/tensorflow/compiler/mlir/lite/python/jax_to_tfl_flatbuffer.cc +++ b/tensorflow/compiler/mlir/lite/python/jax_to_tfl_flatbuffer.cc @@ -18,6 +18,7 @@ limitations under the License. 
#include #include #include +#include #include "absl/strings/str_join.h" #include "absl/types/span.h" diff --git a/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc b/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc index 74c09b3e9e6..acbfa08e770 100644 --- a/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc +++ b/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc @@ -17,7 +17,9 @@ limitations under the License. #include #include #include +#include #include +#include #include "absl/types/span.h" #include "llvm/ADT/StringSet.h" diff --git a/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc b/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc index 5cfbc0c937a..fb5efba769a 100644 --- a/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc +++ b/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include #include +#include #include "llvm/Support/ToolOutputFile.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.h b/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.h index d5a98612b2d..85d4ddffaa2 100644 --- a/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.h +++ b/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.h @@ -20,6 +20,7 @@ limitations under the License. 
#include #include #include +#include #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.cc b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.cc index 29216f3be16..7581b5c78cf 100644 --- a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.cc +++ b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.h" #include +#include #include "absl/container/flat_hash_set.h" #include "absl/strings/string_view.h" diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model_test.cc b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model_test.cc index 353e023d3fb..798e011dec2 100644 --- a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model_test.cc +++ b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model_test.cc @@ -17,9 +17,13 @@ limitations under the License. #include #include #include +#include +#include #include #include #include +#include +#include #include #include diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_weights.h b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_weights.h index d7cb5ab1fe6..6c94e4c2d10 100644 --- a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_weights.h +++ b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_weights.h @@ -17,6 +17,7 @@ limitations under the License. 
#include #include +#include #include #include diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_weights_test.cc b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_weights_test.cc index 9ae572b1c7a..36855cdb744 100644 --- a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_weights_test.cc +++ b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_weights_test.cc @@ -15,8 +15,10 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/quantization/lite/quantize_weights.h" #include +#include #include #include +#include #include #include "llvm/ADT/Twine.h" diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/tfl_quantizer.cc b/tensorflow/compiler/mlir/lite/quantization/lite/tfl_quantizer.cc index 5bd1b71e631..fe5ca2ca8f1 100644 --- a/tensorflow/compiler/mlir/lite/quantization/lite/tfl_quantizer.cc +++ b/tensorflow/compiler/mlir/lite/quantization/lite/tfl_quantizer.cc @@ -15,6 +15,7 @@ limitations under the License. #include #include +#include #include "absl/strings/string_view.h" #include "llvm/Support/CommandLine.h" diff --git a/tensorflow/compiler/mlir/lite/quantization/tensorflow/fallback_to_flex_ops.cc b/tensorflow/compiler/mlir/lite/quantization/tensorflow/fallback_to_flex_ops.cc index d1fc9318116..759893401e6 100644 --- a/tensorflow/compiler/mlir/lite/quantization/tensorflow/fallback_to_flex_ops.cc +++ b/tensorflow/compiler/mlir/lite/quantization/tensorflow/fallback_to_flex_ops.cc @@ -171,7 +171,7 @@ inline TFL::ConstBytesAttr CustomOptionForFlexOp(OpBuilder *builder, class FallbackToFlexOps : public PassWrapper> { public: - FallbackToFlexOps() {} + FallbackToFlexOps() = default; explicit FallbackToFlexOps(const std::string &mode) { mode_ = mode; } FallbackToFlexOps(const FallbackToFlexOps &other) { mode_ = other.mode_; } diff --git a/tensorflow/compiler/mlir/lite/quantization/tensorflow/tests/BUILD b/tensorflow/compiler/mlir/lite/quantization/tensorflow/tests/BUILD index 03332e19f6a..796676e1d28 
100644 --- a/tensorflow/compiler/mlir/lite/quantization/tensorflow/tests/BUILD +++ b/tensorflow/compiler/mlir/lite/quantization/tensorflow/tests/BUILD @@ -7,6 +7,7 @@ package( ) glob_lit_tests( + name = "all_tests", data = [":test_utilities"], driver = "@llvm-project//mlir:run_lit.sh", test_file_exts = ["mlir"], diff --git a/tensorflow/compiler/mlir/lite/quantization/tensorflow/tf_to_quant.cc b/tensorflow/compiler/mlir/lite/quantization/tensorflow/tf_to_quant.cc index 6afc81e8ce9..9dfe5166033 100644 --- a/tensorflow/compiler/mlir/lite/quantization/tensorflow/tf_to_quant.cc +++ b/tensorflow/compiler/mlir/lite/quantization/tensorflow/tf_to_quant.cc @@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include #include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/lite/quantization/tests/BUILD b/tensorflow/compiler/mlir/lite/quantization/tests/BUILD index 03332e19f6a..796676e1d28 100644 --- a/tensorflow/compiler/mlir/lite/quantization/tests/BUILD +++ b/tensorflow/compiler/mlir/lite/quantization/tests/BUILD @@ -7,6 +7,7 @@ package( ) glob_lit_tests( + name = "all_tests", data = [":test_utilities"], driver = "@llvm-project//mlir:run_lit.sh", test_file_exts = ["mlir"], diff --git a/tensorflow/compiler/mlir/lite/quantization/tools/op_quant_spec_getters_gen.cc b/tensorflow/compiler/mlir/lite/quantization/tools/op_quant_spec_getters_gen.cc index 04d0b7675bb..ad4112a05ad 100644 --- a/tensorflow/compiler/mlir/lite/quantization/tools/op_quant_spec_getters_gen.cc +++ b/tensorflow/compiler/mlir/lite/quantization/tools/op_quant_spec_getters_gen.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ +#include + #include "llvm/ADT/StringRef.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/InitLLVM.h" diff --git a/tensorflow/compiler/mlir/lite/quantization/tools/tflite_op_coverage_spec_getters_gen.cc b/tensorflow/compiler/mlir/lite/quantization/tools/tflite_op_coverage_spec_getters_gen.cc index f73f162929f..7cbcb108729 100644 --- a/tensorflow/compiler/mlir/lite/quantization/tools/tflite_op_coverage_spec_getters_gen.cc +++ b/tensorflow/compiler/mlir/lite/quantization/tools/tflite_op_coverage_spec_getters_gen.cc @@ -14,7 +14,9 @@ limitations under the License. ==============================================================================*/ #include +#include #include +#include #include "absl/strings/match.h" #include "absl/strings/str_replace.h" diff --git a/tensorflow/compiler/mlir/lite/stablehlo/odml_to_stablehlo.cc b/tensorflow/compiler/mlir/lite/stablehlo/odml_to_stablehlo.cc index 525d73c1b79..04e9a070af8 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/odml_to_stablehlo.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/odml_to_stablehlo.cc @@ -18,6 +18,7 @@ limitations under the License. 
#include #include #include +#include #include #include #include diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/BUILD b/tensorflow/compiler/mlir/lite/stablehlo/tests/BUILD index 7002dd57dda..dd691a25be1 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/tests/BUILD +++ b/tensorflow/compiler/mlir/lite/stablehlo/tests/BUILD @@ -10,6 +10,7 @@ package( ) glob_lit_tests( + name = "all_tests", data = [":test_utilities"], driver = "//tensorflow/compiler/mlir/lite/stablehlo:run_lit.sh", size_override = { diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_tf_xla_call_module_to_stablehlo_pass.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_tf_xla_call_module_to_stablehlo_pass.cc index b7277ae0415..2a7950cf581 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_tf_xla_call_module_to_stablehlo_pass.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_tf_xla_call_module_to_stablehlo_pass.cc @@ -44,6 +44,28 @@ namespace odml { static constexpr std::string_view kStablehloModuleDefaultEntryFuncName = "main"; static constexpr std::string_view kStablehloFuncNamePrefix = "XlaCallModule"; +static constexpr char kShardingAttr[] = "mhlo.sharding"; +static constexpr char kShardingName[] = "Sharding"; + +class RemoveCustomCallWithSharding + : public mlir::OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult matchAndRewrite( + mlir::stablehlo::CustomCallOp op, + PatternRewriter &rewriter) const override { + // Removes the custom call with sharding op if the operand type is the + // same as the result type. 
+ if (op->hasAttr(kShardingAttr) && op.getCallTargetName() == kShardingName && + op.getNumOperands() == 1 && op.getNumResults() == 1 && + op.getOperands().front().getType() == + op.getResults().front().getType()) { + rewriter.replaceOp(op, op.getOperands()); + return mlir::success(); + } + return mlir::failure(); + } +}; class ConvertTFXlaCallModuleOp : public mlir::OpRewritePattern { @@ -90,10 +112,12 @@ class ConvertTFXlaCallModuleOp stablehlo_module_op.get().getOps()) { mlir::func::FuncOp cloned_func_op = func_op.clone(); if (cloned_func_op.getSymName().contains( - kStablehloModuleDefaultEntryFuncName)) { + kStablehloModuleDefaultEntryFuncName) && + cloned_func_op.getSymVisibility() == "public") { main_fn = cloned_func_op; - main_fn.setSymVisibility(stablehlo_builder.getStringAttr("private")); } + cloned_func_op.setSymVisibility( + stablehlo_builder.getStringAttr("private")); parent_module_symbol_table.insert(cloned_func_op); } @@ -159,6 +183,7 @@ class TFXlaCallModuleOpToStablehloPass ModuleOp module_op = getOperation(); RewritePatternSet patterns(&getContext()); patterns.add(&getContext(), module_op); + patterns.add(&getContext()); if (failed(applyPatternsAndFoldGreedily(module_op, std::move(patterns)))) { return signalPassFailure(); } diff --git a/tensorflow/compiler/mlir/lite/tests/BUILD b/tensorflow/compiler/mlir/lite/tests/BUILD index f2cfa39ab54..72efe28296c 100644 --- a/tensorflow/compiler/mlir/lite/tests/BUILD +++ b/tensorflow/compiler/mlir/lite/tests/BUILD @@ -7,6 +7,7 @@ package( ) glob_lit_tests( + name = "all_tests", data = [":test_utilities"], driver = "@llvm-project//mlir:run_lit.sh", exclude = ["load-quantization-recipe.mlir"], diff --git a/tensorflow/compiler/mlir/lite/tests/const-fold.mlir b/tensorflow/compiler/mlir/lite/tests/const-fold.mlir index 5dcaada7e56..ba9c1e58565 100644 --- a/tensorflow/compiler/mlir/lite/tests/const-fold.mlir +++ b/tensorflow/compiler/mlir/lite/tests/const-fold.mlir @@ -450,6 +450,7 @@ func.func 
@transpose_no_fold(%arg0 : tensor<2xi32>) -> tensor<2x2xi32> { func.return %0 : tensor<2x2xi32> } + // CHECK-LABEL: @transpose_1d // Basic 1D identity func.func @transpose_1d() -> tensor<3xi32> { @@ -484,6 +485,17 @@ func.func @transpose_2d() -> tensor<2x2xi32> { func.return %0 : tensor<2x2xi32> } +// CHECK-LABEL: @transpose_2d_splat +func.func @transpose_2d_splat() -> tensor<3x2xi32> { + %cst = arith.constant dense<0> : tensor<2x3xi32> + %cst_perm = arith.constant dense<[1, 0]> : tensor<2xi32> + + // CHECK: %[[CST:.*]] = arith.constant dense<0> : tensor<3x2xi32> + // CHECK: return %[[CST]] + %0 = "tfl.transpose"(%cst, %cst_perm) : (tensor<2x3xi32>, tensor<2xi32>) -> tensor<3x2xi32> + func.return %0 : tensor<3x2xi32> +} + // CHECK-LABEL: @transpose_2d_identity func.func @transpose_2d_identity() -> tensor<2x2xi32> { %cst = arith.constant dense<[[0, 1], [2, 3]]> : tensor<2x2xi32> @@ -837,7 +849,7 @@ func.func @ConstFoldStridedSlice(%arg0 : tensor<15600xf32>) -> tensor<15600xf32> %0 = "tfl.pseudo_const"() {value = dense<15600> : tensor<1xi32>} : () -> tensor<1xi32> %1 = "tfl.pseudo_const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32> %2 = "tfl.pseudo_const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> - %3 = "tfl.strided_slice"(%arg0, %1, %0, %2) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 0 : i32} : (tensor<15600xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<15600xf32> + %3 = "tfl.strided_slice"(%arg0, %1, %0, %2) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 0 : i32, offset = false} : (tensor<15600xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<15600xf32> func.return %3 : tensor<15600xf32> // CHECK: return %arg0 } @@ -846,7 +858,7 @@ func.func @ConstFoldStridedSliceMultiDims(%arg0 : tensor<10x10x10xf32>) -> tenso %0 = "tfl.pseudo_const"() {value = dense<[10, 10, 10]> : 
tensor<3xi32>} : () -> tensor<3xi32> %1 = "tfl.pseudo_const"() {value = dense<0> : tensor<3xi32>} : () -> tensor<3xi32> %2 = "tfl.pseudo_const"() {value = dense<1> : tensor<3xi32>} : () -> tensor<3xi32> - %3 = "tfl.strided_slice"(%arg0, %1, %0, %2) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 0 : i32} : (tensor<10x10x10xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<10x10x10xf32> + %3 = "tfl.strided_slice"(%arg0, %1, %0, %2) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 0 : i32, offset = false} : (tensor<10x10x10xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<10x10x10xf32> func.return %3 : tensor<10x10x10xf32> // CHECK: return %arg0 } @@ -855,7 +867,7 @@ func.func @NotFoldStridedSlice(%arg0 : tensor<10x10x10xf32>) -> tensor<9x9x9xf32 %0 = "tfl.pseudo_const"() {value = dense<[9, 9, 9]> : tensor<3xi32>} : () -> tensor<3xi32> %1 = "tfl.pseudo_const"() {value = dense<0> : tensor<3xi32>} : () -> tensor<3xi32> %2 = "tfl.pseudo_const"() {value = dense<1> : tensor<3xi32>} : () -> tensor<3xi32> - %3 = "tfl.strided_slice"(%arg0, %1, %0, %2) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 0 : i32} : (tensor<10x10x10xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<9x9x9xf32> + %3 = "tfl.strided_slice"(%arg0, %1, %0, %2) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 0 : i32, offset = false} : (tensor<10x10x10xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<9x9x9xf32> func.return %3 : tensor<9x9x9xf32> // CHECK: %[[STRIDED_SLICE:.*]] = "tfl.strided_slice" // CHECK: return %[[STRIDED_SLICE]] diff --git a/tensorflow/compiler/mlir/lite/tests/debuginfo/BUILD b/tensorflow/compiler/mlir/lite/tests/debuginfo/BUILD index 9a0b427f294..bb7412a10f9 100644 --- 
a/tensorflow/compiler/mlir/lite/tests/debuginfo/BUILD +++ b/tensorflow/compiler/mlir/lite/tests/debuginfo/BUILD @@ -6,6 +6,7 @@ load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") licenses(["notice"]) glob_lit_tests( + name = "all_tests", data = [ ":debug_info_files", ":test_utilities", diff --git a/tensorflow/compiler/mlir/lite/tests/end2end/BUILD b/tensorflow/compiler/mlir/lite/tests/end2end/BUILD index b162606d135..b0e8270e4dc 100644 --- a/tensorflow/compiler/mlir/lite/tests/end2end/BUILD +++ b/tensorflow/compiler/mlir/lite/tests/end2end/BUILD @@ -6,6 +6,7 @@ load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") licenses(["notice"]) glob_lit_tests( + name = "all_tests", data = [ ":quant_stats_files", ":test_utilities", diff --git a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/BUILD b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/BUILD index f7dbeaf48af..e1687b22816 100644 --- a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/BUILD +++ b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/BUILD @@ -7,6 +7,7 @@ load("//tensorflow:tensorflow.bzl", "tf_native_cc_binary") licenses(["notice"]) glob_lit_tests( + name = "all_tests", data = [ ":extra_files", ":test_utilities", diff --git a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/empty_input_output_names.json b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/empty_input_output_names.json new file mode 100644 index 00000000000..87c809fa7cc --- /dev/null +++ b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/empty_input_output_names.json @@ -0,0 +1,81 @@ +// RUN: json_to_flatbuffer %p/test_schema.fbs %s | flatbuffer_translate --tflite-flatbuffer-to-mlir -o - | FileCheck %s + +// If input and output tensors don't have names, there shouldn't be an +// `tf.entry_function` attribute created. 
+// CHECK-NOT: tf.entry_function + +{ + "version": 3, + "operator_codes": [ + { + "builtin_code": "CONV_2D" + } + ], + "subgraphs": [ + { + "tensors": [ + { + "shape": [ + 256, + 32, + 32, + 3 + ], + "quantization": { + } + }, + { + "shape": [ + 16, + 3, + 3, + 3 + ], + "quantization": { + } + }, + { + "shape": [ + 0 + ], + }, + { + "shape": [ + 256, + 32, + 32, + 16 + ], + "quantization": { + } + } + ], + "inputs": [ + 0, + 1 + ], + "outputs": [ + 3 + ], + "operators": [ + { + "inputs": [ + 0, + 1, + -1 + ], + "outputs": [ + 3 + ], + "builtin_options_type": "Conv2DOptions", + "builtin_options": { + "stride_w": 1, + "stride_h": 1 + } + } + ], + "name": "main" + } + ], + "description": "MLIR Converted." +} diff --git a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/importer_test_min_max.cc b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/importer_test_min_max.cc index ba0fd474a3a..8fc5a0cb051 100644 --- a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/importer_test_min_max.cc +++ b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/importer_test_min_max.cc @@ -16,6 +16,7 @@ limitations under the License. 
#include #include #include +#include #include "absl/strings/string_view.h" #include "llvm/Support/CommandLine.h" diff --git a/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir index 4f58b7af868..15a0f4b160d 100644 --- a/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir +++ b/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir @@ -1307,84 +1307,84 @@ func.func @resize_with_bilinear_with_half_pixel_centers(%arg0: tensor<1x100x100x } func.func @strided_slice(%arg0: tensor<12x2x2x5xf32>, %arg1: tensor<1xi32>, %arg2: tensor<1xi32>, %arg3: tensor<1xi32>) -> tensor<1x2x2x5xf32> { - %0 = "tf.StridedSlice"(%arg0, %arg1, %arg2, %arg3) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<12x2x2x5xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x2x2x5xf32> + %0 = "tf.StridedSlice"(%arg0, %arg1, %arg2, %arg3) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64, offset = false} : (tensor<12x2x2x5xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x2x2x5xf32> func.return %0 : tensor<1x2x2x5xf32> // CHECK-LABEL: strided_slice - // CHECK: "tfl.strided_slice"(%arg0, %arg1, %arg2, %arg3) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 0 : i32} : (tensor<12x2x2x5xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x2x2x5xf32> + // CHECK: "tfl.strided_slice"(%arg0, %arg1, %arg2, %arg3) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, offset = false, shrink_axis_mask = 0 : i32} : (tensor<12x2x2x5xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x2x2x5xf32> } func.func @strided_slice_with_constant_attributes(%arg0: tensor<10x10x10xf32>, %arg1: tensor<1xi32>, %arg2: tensor<1xi32>, %arg3: tensor<1xi32>) -> tensor<10x10xf32> { %cst = 
arith.constant dense<-1> : tensor<1xi32> %cst_1 = arith.constant dense<0> : tensor<1xi32> %cst_2 = arith.constant dense<1> : tensor<1xi32> - %0 = "tf.StridedSlice"(%arg0, %cst, %cst_1, %cst_2) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<10x10x10xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<10x10xf32> + %0 = "tf.StridedSlice"(%arg0, %cst, %cst_1, %cst_2) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64, offset = false} : (tensor<10x10x10xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<10x10xf32> func.return %0 : tensor<10x10xf32> // CHECK-LABEL: strided_slice_with_constant_attributes // CHECK-DAG: [[BEGIN:%cst.*]] = arith.constant dense<-1> : tensor<1xi32> // CHECK-DAG: [[END:%cst.*]] = arith.constant dense<0> : tensor<1xi32> // CHECK-DAG: [[STRIDES:%cst.*]] = arith.constant dense<1> : tensor<1xi32> - // CHECK-NEXT: "tfl.strided_slice"(%arg0, [[BEGIN]], [[END]], [[STRIDES]]) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 1 : i32} : (tensor<10x10x10xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<10x10xf32> + // CHECK-NEXT: "tfl.strided_slice"(%arg0, [[BEGIN]], [[END]], [[STRIDES]]) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, offset = false, shrink_axis_mask = 1 : i32} : (tensor<10x10x10xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<10x10xf32> } func.func @strided_slice_with_string(%arg0: tensor<12x2x2x5x!tf_type.string>, %arg1: tensor<1xi32>, %arg2: tensor<1xi32>, %arg3: tensor<1xi32>) -> tensor<1x2x2x5x!tf_type.string> { - %0 = "tf.StridedSlice"(%arg0, %arg1, %arg2, %arg3) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<12x2x2x5x!tf_type.string>, tensor<1xi32>, 
tensor<1xi32>, tensor<1xi32>) -> tensor<1x2x2x5x!tf_type.string> + %0 = "tf.StridedSlice"(%arg0, %arg1, %arg2, %arg3) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64, offset = false} : (tensor<12x2x2x5x!tf_type.string>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x2x2x5x!tf_type.string> func.return %0 : tensor<1x2x2x5x!tf_type.string> // CHECK-LABEL: strided_slice_with_string - // CHECK: "tfl.strided_slice"(%arg0, %arg1, %arg2, %arg3) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 0 : i32} : (tensor<12x2x2x5x!tf_type.string>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x2x2x5x!tf_type.string> + // CHECK: "tfl.strided_slice"(%arg0, %arg1, %arg2, %arg3) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, offset = false, shrink_axis_mask = 0 : i32} : (tensor<12x2x2x5x!tf_type.string>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x2x2x5x!tf_type.string> } func.func @strided_slice_with_unranked_input_and_i64_parameters(%arg0: tensor<*xf32>, %arg1: tensor<1xi64>, %arg2: tensor<1xi64>, %arg3: tensor<1xi64>) -> tensor<*xf32> { - %0 = "tf.StridedSlice"(%arg0, %arg1, %arg2, %arg3) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<*xf32>, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor<*xf32> + %0 = "tf.StridedSlice"(%arg0, %arg1, %arg2, %arg3) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64, offset = false} : (tensor<*xf32>, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor<*xf32> func.return %0 : tensor<*xf32> // CHECK-LABEL: strided_slice_with_unranked_input_and_i64_parameters // CHECK-DAG: [[BEGIN:%.*]] = "tfl.cast"(%arg1) : (tensor<1xi64>) -> tensor<1xi32> // CHECK-DAG: [[END:%.*]] = 
"tfl.cast"(%arg2) : (tensor<1xi64>) -> tensor<1xi32> // CHECK-DAG: [[STRIDES:%.*]] = "tfl.cast"(%arg3) : (tensor<1xi64>) -> tensor<1xi32> - // CHECK-NEXT: "tfl.strided_slice"(%arg0, [[BEGIN]], [[END]], [[STRIDES]]) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 0 : i32} : (tensor<*xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<*xf32> + // CHECK-NEXT: "tfl.strided_slice"(%arg0, [[BEGIN]], [[END]], [[STRIDES]]) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, offset = false, shrink_axis_mask = 0 : i32} : (tensor<*xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<*xf32> } func.func @strided_slice_with_i64_parameters(%arg0: tensor<12x2x2x5xf32>, %arg1: tensor<1xi64>, %arg2: tensor<1xi64>, %arg3: tensor<1xi64>) -> tensor<1x2x2x5xf32> { - %0 = "tf.StridedSlice"(%arg0, %arg1, %arg2, %arg3) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<12x2x2x5xf32>, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor<1x2x2x5xf32> + %0 = "tf.StridedSlice"(%arg0, %arg1, %arg2, %arg3) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64, offset = false} : (tensor<12x2x2x5xf32>, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor<1x2x2x5xf32> func.return %0 : tensor<1x2x2x5xf32> // CHECK-LABEL: strided_slice_with_i64_parameters // CHECK-DAG: [[BEGIN:%.*]] = "tfl.cast"(%arg1) : (tensor<1xi64>) -> tensor<1xi32> // CHECK-DAG: [[END:%.*]] = "tfl.cast"(%arg2) : (tensor<1xi64>) -> tensor<1xi32> // CHECK-DAG: [[STRIDES:%.*]] = "tfl.cast"(%arg3) : (tensor<1xi64>) -> tensor<1xi32> - // CHECK-NEXT: "tfl.strided_slice"(%arg0, [[BEGIN]], [[END]], [[STRIDES]]) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 0 : i32} : (tensor<12x2x2x5xf32>, 
tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x2x2x5xf32> + // CHECK-NEXT: "tfl.strided_slice"(%arg0, [[BEGIN]], [[END]], [[STRIDES]]) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, offset = false, shrink_axis_mask = 0 : i32} : (tensor<12x2x2x5xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x2x2x5xf32> } func.func @strided_slice_with_i64_constant_attributes(%arg0: tensor<10x10x10xf32>) -> tensor<10x10xf32> { %cst = arith.constant dense<-1> : tensor<1xi64> %cst_1 = arith.constant dense<0> : tensor<1xi64> %cst_2 = arith.constant dense<1> : tensor<1xi64> - %0 = "tf.StridedSlice"(%arg0, %cst, %cst_1, %cst_2) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<10x10x10xf32>, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor<10x10xf32> + %0 = "tf.StridedSlice"(%arg0, %cst, %cst_1, %cst_2) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64, offset = false} : (tensor<10x10x10xf32>, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor<10x10xf32> func.return %0 : tensor<10x10xf32> // CHECK-LABEL: strided_slice_with_i64_constant_attributes // CHECK-DAG: [[BEGIN:%cst.*]] = arith.constant dense<-1> : tensor<1xi32> // CHECK-DAG: [[END:%cst.*]] = arith.constant dense<0> : tensor<1xi32> // CHECK-DAG: [[STRIDES:%cst.*]] = arith.constant dense<1> : tensor<1xi32> - // CHECK-NEXT: "tfl.strided_slice"(%arg0, [[BEGIN]], [[END]], [[STRIDES]]) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 1 : i32} : (tensor<10x10x10xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<10x10xf32> + // CHECK-NEXT: "tfl.strided_slice"(%arg0, [[BEGIN]], [[END]], [[STRIDES]]) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, offset = false, shrink_axis_mask = 1 : i32} : 
(tensor<10x10x10xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<10x10xf32> } func.func @strided_slice_non_zero_ellipsis_mask(%arg0: tensor<12x2x2x5xf32>, %arg1: tensor<1xi32>, %arg2: tensor<1xi32>, %arg3: tensor<1xi32>) -> tensor<1x2x2x5xf32> { - %0 = "tf.StridedSlice"(%arg0, %arg1, %arg2, %arg3) {begin_mask = 0 : i64, ellipsis_mask = 1 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<12x2x2x5xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x2x2x5xf32> + %0 = "tf.StridedSlice"(%arg0, %arg1, %arg2, %arg3) {begin_mask = 0 : i64, ellipsis_mask = 1 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64, offset = false} : (tensor<12x2x2x5xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x2x2x5xf32> func.return %0 : tensor<1x2x2x5xf32> // CHECK-LABEL: strided_slice_non_zero_ellipsis_mask - // CHECK: %0 = "tfl.strided_slice"(%arg0, %arg1, %arg2, %arg3) {begin_mask = 0 : i32, ellipsis_mask = 1 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 0 : i32} : (tensor<12x2x2x5xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x2x2x5xf32> + // CHECK: %0 = "tfl.strided_slice"(%arg0, %arg1, %arg2, %arg3) {begin_mask = 0 : i32, ellipsis_mask = 1 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, offset = false, shrink_axis_mask = 0 : i32} : (tensor<12x2x2x5xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x2x2x5xf32> } func.func @strided_slice_non_zero_new_axis_mask(%arg0: tensor<12x2x2x5xf32>, %arg1: tensor<1xi32>, %arg2: tensor<1xi32>, %arg3: tensor<1xi32>) -> tensor<1x2x2x5xf32> { - %0 = "tf.StridedSlice"(%arg0, %arg1, %arg2, %arg3) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 2 : i64, shrink_axis_mask = 0 : i64} : (tensor<12x2x2x5xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x2x2x5xf32> + %0 = "tf.StridedSlice"(%arg0, %arg1, %arg2, %arg3) {begin_mask = 0 : i64, ellipsis_mask 
= 0 : i64, end_mask = 0 : i64, new_axis_mask = 2 : i64, shrink_axis_mask = 0 : i64, offset = false} : (tensor<12x2x2x5xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x2x2x5xf32> func.return %0 : tensor<1x2x2x5xf32> // CHECK-LABEL: strided_slice_non_zero_new_axis_mask - // CHECK: "tfl.strided_slice"(%arg0, %arg1, %arg2, %arg3) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 2 : i32, shrink_axis_mask = 0 : i32} : (tensor<12x2x2x5xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x2x2x5xf32> + // CHECK: "tfl.strided_slice"(%arg0, %arg1, %arg2, %arg3) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 2 : i32, offset = false, shrink_axis_mask = 0 : i32} : (tensor<12x2x2x5xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x2x2x5xf32> } func.func @strided_slice_big_dims(%arg0: tensor<5x6x7xf32>, %arg1: tensor<3xi32>, %arg2: tensor<3xi32>, %arg3: tensor<3xi32>) -> tensor<1x1x5x6x7xf32> { - %0 = "tf.StridedSlice"(%arg0, %arg1, %arg2, %arg3) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 7 : i64, shrink_axis_mask = 0 : i64} : (tensor<5x6x7xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<1x1x5x6x7xf32> + %0 = "tf.StridedSlice"(%arg0, %arg1, %arg2, %arg3) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 7 : i64, shrink_axis_mask = 0 : i64, offset = false} : (tensor<5x6x7xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<1x1x5x6x7xf32> func.return %0 : tensor<1x1x5x6x7xf32> // CHECK-LABEL: strided_slice_big_dims - // CHECK: %0 = "tf.StridedSlice"(%arg0, %arg1, %arg2, %arg3) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 7 : i64, shrink_axis_mask = 0 : i64} : (tensor<5x6x7xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<1x1x5x6x7xf32> + // CHECK: %0 = "tf.StridedSlice"(%arg0, %arg1, %arg2, %arg3) {begin_mask = 0 : i64, 
ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 7 : i64, offset = false, shrink_axis_mask = 0 : i64} : (tensor<5x6x7xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<1x1x5x6x7xf32> } func.func @slice1Tensor(%arg0: tensor<2x3x5xf32>, %arg1: tensor<3xi32>, %arg2: tensor<3xi32>) -> tensor { @@ -2627,3 +2627,31 @@ func.func @batchmatmul2fullyconnected(%arg0: tensor<4x128x2xf32>) -> (tensor<4x1 // CHECK: return %2 : tensor<4x128x1xf32> } +func.func @approx_top_k_with_max_k_last_reduction_dimension(%arg0: tensor<1x4xf32>) -> (tensor<1x4xf32>, tensor<1x4xi32>) { + %values, %indices = "tf.ApproxTopK"(%arg0) {aggregate_to_topk = true, is_max_k = true, k = 4 : i64, recall_target = 8.500000e-01 : f32, reduction_dimension = 1 : i64, reduction_input_size_override = -1 : i64} : (tensor<1x4xf32>) -> (tensor<1x4xf32>, tensor<1x4xi32>) + func.return %values, %indices: tensor<1x4xf32>, tensor<1x4xi32> + + // CHECK-LABEL: approx_top_k_with_max_k_last_reduction_dimension + // CHECK-DAG: %cst = arith.constant dense<4> : tensor + // CHECK: %values, %indices = "tfl.topk_v2"(%arg0, %cst) : (tensor<1x4xf32>, tensor) -> (tensor<1x4xf32>, tensor<1x4xi32>) + // CHECK: return %values, %indices : tensor<1x4xf32>, tensor<1x4xi32> +} + +func.func @approx_top_k_with_min_k(%arg0: tensor<1x4xf32>) -> (tensor<1x4xf32>, tensor<1x4xi32>) { + %values, %indices = "tf.ApproxTopK"(%arg0) {aggregate_to_topk = true, is_max_k = false, k = 4 : i64, recall_target = 8.500000e-01 : f32, reduction_dimension = 1 : i64, reduction_input_size_override = -1 : i64} : (tensor<1x4xf32>) -> (tensor<1x4xf32>, tensor<1x4xi32>) + func.return %values, %indices: tensor<1x4xf32>, tensor<1x4xi32> + + // CHECK-LABEL: approx_top_k_with_min_k + // CHECK: %values, %indices = "tf.ApproxTopK"(%arg0) {aggregate_to_topk = true, is_max_k = false, k = 4 : i64, recall_target = 8.500000e-01 : f32, reduction_dimension = 1 : i64, reduction_input_size_override = -1 : i64} : (tensor<1x4xf32>) -> (tensor<1x4xf32>, 
tensor<1x4xi32>) + // CHECK: return %values, %indices : tensor<1x4xf32>, tensor<1x4xi32> +} + +func.func @approx_top_k_reduction_dimension_not_last_dim(%arg0: tensor<1x4xf32>) -> (tensor<1x4xf32>, tensor<1x4xi32>) { + %values, %indices = "tf.ApproxTopK"(%arg0) {aggregate_to_topk = true, is_max_k = true, k = 4 : i64, recall_target = 8.500000e-01 : f32, reduction_dimension = 0 : i64, reduction_input_size_override = -1 : i64} : (tensor<1x4xf32>) -> (tensor<1x4xf32>, tensor<1x4xi32>) + func.return %values, %indices: tensor<1x4xf32>, tensor<1x4xi32> + + // CHECK-LABEL: approx_top_k_reduction_dimension_not_last_dim + // CHECK: %values, %indices = "tf.ApproxTopK"(%arg0) {aggregate_to_topk = true, is_max_k = true, k = 4 : i64, recall_target = 8.500000e-01 : f32, reduction_dimension = 0 : i64, reduction_input_size_override = -1 : i64} : (tensor<1x4xf32>) -> (tensor<1x4xf32>, tensor<1x4xi32>) + // CHECK: return %values, %indices : tensor<1x4xf32>, tensor<1x4xi32> +} + diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2exec/BUILD b/tensorflow/compiler/mlir/lite/tests/mlir2exec/BUILD index 930e0f20b05..8bb228b8520 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2exec/BUILD +++ b/tensorflow/compiler/mlir/lite/tests/mlir2exec/BUILD @@ -14,6 +14,7 @@ load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") licenses(["notice"]) glob_lit_tests( + name = "all_tests", data = [":test_utilities"], driver = "@llvm-project//mlir:run_lit.sh", test_file_exts = ["mlir"], diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/BUILD b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/BUILD index 7e748ffe18d..3d4e40f9119 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/BUILD +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/BUILD @@ -6,6 +6,7 @@ load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") licenses(["notice"]) glob_lit_tests( + name = "all_tests", data = [":test_utilities"], driver = 
"@llvm-project//mlir:run_lit.sh", test_file_exts = ["mlir"], diff --git a/tensorflow/compiler/mlir/lite/tests/ops.mlir b/tensorflow/compiler/mlir/lite/tests/ops.mlir index 628e523c488..51fc212a2a7 100644 --- a/tensorflow/compiler/mlir/lite/tests/ops.mlir +++ b/tensorflow/compiler/mlir/lite/tests/ops.mlir @@ -1684,32 +1684,32 @@ func.func @testResizeBilinearInvalidOutputType(%arg0 : tensor<1x100x100x3xf32>, // CHECK-LABEL: testStridedSlice func.func @testStridedSlice(%arg0: tensor<12x2x2x5xf32>, %arg1: tensor<1xi32>, %arg2: tensor<1xi32>, %arg3: tensor<1xi32>) -> tensor<1x2x2x5xf32> { - // CHECK: "tfl.strided_slice"(%arg0, %arg1, %arg2, %arg3) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 0 : i32} : (tensor<12x2x2x5xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x2x2x5xf32> - %0 = "tfl.strided_slice"(%arg0, %arg1, %arg2, %arg3) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 0 : i32} : (tensor<12x2x2x5xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x2x2x5xf32> + // CHECK: "tfl.strided_slice"(%arg0, %arg1, %arg2, %arg3) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, offset = false, shrink_axis_mask = 0 : i32} : (tensor<12x2x2x5xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x2x2x5xf32> + %0 = "tfl.strided_slice"(%arg0, %arg1, %arg2, %arg3) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 0 : i32, offset = false} : (tensor<12x2x2x5xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x2x2x5xf32> func.return %0 : tensor<1x2x2x5xf32> } // CHECK-LABEL: testStridedSliceWithQI8 func.func @testStridedSliceWithQI8(%arg0: tensor<12x2x2x5x!quant.uniform>, %arg1: tensor<1xi32>, %arg2: tensor<1xi32>, %arg3: tensor<1xi32>) -> tensor<1x2x2x5x!quant.uniform> { - %0 = "tfl.strided_slice"(%arg0, 
%arg1, %arg2, %arg3) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 0 : i32} : (tensor<12x2x2x5x!quant.uniform>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x2x2x5x!quant.uniform> + %0 = "tfl.strided_slice"(%arg0, %arg1, %arg2, %arg3) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 0 : i32, offset = false} : (tensor<12x2x2x5x!quant.uniform>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x2x2x5x!quant.uniform> func.return %0 : tensor<1x2x2x5x!quant.uniform> } // CHECK-LABEL: testStridedSliceWithQUI8 func.func @testStridedSliceWithQUI8(%arg0: tensor<12x2x2x5x!quant.uniform>, %arg1: tensor<1xi32>, %arg2: tensor<1xi32>, %arg3: tensor<1xi32>) -> tensor<1x2x2x5x!quant.uniform> { - %0 = "tfl.strided_slice"(%arg0, %arg1, %arg2, %arg3) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 0 : i32} : (tensor<12x2x2x5x!quant.uniform>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x2x2x5x!quant.uniform> + %0 = "tfl.strided_slice"(%arg0, %arg1, %arg2, %arg3) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 0 : i32, offset = false} : (tensor<12x2x2x5x!quant.uniform>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x2x2x5x!quant.uniform> func.return %0 : tensor<1x2x2x5x!quant.uniform> } // CHECK-LABEL: testStridedSliceTFType func.func @testStridedSliceTFType(%arg0: tensor<12x2x2x5xui8>, %arg1: tensor<1xi32>, %arg2: tensor<1xi32>, %arg3: tensor<1xi32>) -> tensor<1x2x2x5x!tf_type.quint8> { - %0 = "tfl.strided_slice"(%arg0, %arg1, %arg2, %arg3) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 0 : i32} : (tensor<12x2x2x5xui8>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x2x2x5x!tf_type.quint8> + %0 = 
"tfl.strided_slice"(%arg0, %arg1, %arg2, %arg3) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 0 : i32, offset = false} : (tensor<12x2x2x5xui8>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x2x2x5x!tf_type.quint8> func.return %0 : tensor<1x2x2x5x!tf_type.quint8> } // CHECK-LABEL: testStridedSliceWithString func.func @testStridedSliceWithString(%arg0: tensor<12x2x2x5x!tf_type.string>, %arg1: tensor<1xi32>, %arg2: tensor<1xi32>, %arg3: tensor<1xi32>) -> tensor<1x2x2x5x!tf_type.string> { - %0 = "tfl.strided_slice"(%arg0, %arg1, %arg2, %arg3) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 0 : i32} : (tensor<12x2x2x5x!tf_type.string>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x2x2x5x!tf_type.string> + %0 = "tfl.strided_slice"(%arg0, %arg1, %arg2, %arg3) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 0 : i32, offset = false} : (tensor<12x2x2x5x!tf_type.string>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x2x2x5x!tf_type.string> func.return %0 : tensor<1x2x2x5x!tf_type.string> } @@ -1717,7 +1717,7 @@ func.func @testStridedSliceWithString(%arg0: tensor<12x2x2x5x!tf_type.string>, % func.func @testStridedSliceWithInvalidOutputType(%arg0: tensor<12x2x2x5xf32>, %arg1: tensor<1xi32>, %arg2: tensor<1xi32>, %arg3: tensor<1xi32>) -> tensor<1x2x2x5xi32> { // expected-error @+1 {{op failed to verify that input and output must have same element type}} - %0 = "tfl.strided_slice"(%arg0, %arg1, %arg2, %arg3) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 0 : i32} : (tensor<12x2x2x5xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x2x2x5xi32> + %0 = "tfl.strided_slice"(%arg0, %arg1, %arg2, %arg3) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : 
i32, shrink_axis_mask = 0 : i32, offset = false} : (tensor<12x2x2x5xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x2x2x5xi32> func.return %0 : tensor<1x2x2x5xi32> } diff --git a/tensorflow/compiler/mlir/lite/tests/optimize.mlir b/tensorflow/compiler/mlir/lite/tests/optimize.mlir index 8d57178a47f..2515e209396 100644 --- a/tensorflow/compiler/mlir/lite/tests/optimize.mlir +++ b/tensorflow/compiler/mlir/lite/tests/optimize.mlir @@ -13,7 +13,7 @@ func.func @fusedConv2dRelu(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<16x3x3x %0 = "tfl.conv_2d"(%arg0, %arg1, %arg2) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<256x32x32x3xf32>, tensor<16x3x3x3xf32>, tensor<16xf32>) -> tensor<256x32x32x16xf32> %1 = "tfl.relu"(%0) : (tensor<256x32x32x16xf32>) -> tensor<256x32x32x16xf32> func.return %1 : tensor<256x32x32x16xf32> - + // CHECK: %0 = "tfl.conv_2d"(%arg0, %arg1, %arg2) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "RELU", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<256x32x32x3xf32>, tensor<16x3x3x3xf32>, tensor<16xf32>) -> tensor<256x32x32x16xf32> // CHECK: return %0 } @@ -568,8 +568,8 @@ func.func @FuseFullyConnectedAddWithScalarRhs(%arg0: tensor<40x37xf32>, %arg1: t // CHECK: return %[[fc]] } -// CHECK-LABEL: @FuseFullyConnectedAddWithUnfusableRhs -func.func @FuseFullyConnectedAddWithUnfusableRhs(%arg0: tensor<4x37xf32>, %arg1: tensor<4x37xf32>) -> tensor<4x4xf32> { +// CHECK-LABEL: @FuseFullyConnectedAddNoBiasWithUnfusableRhs +func.func @FuseFullyConnectedAddNoBiasWithUnfusableRhs(%arg0: tensor<4x37xf32>, %arg1: tensor<4x37xf32>) -> tensor<4x4xf32> { %cst = "tfl.no_value"() {value} : () -> none %cst2 = arith.constant dense<[[2.0, 2.1, 2.2, 2.3], [2.0, 2.1, 2.2, 2.3], [2.0, 2.1, 2.2, 2.3], [2.0, 2.1, 2.2, 2.3]]> : tensor<4x4xf32> @@ -585,6 +585,23 @@ func.func 
@FuseFullyConnectedAddWithUnfusableRhs(%arg0: tensor<4x37xf32>, %arg1: // CHECK: return %[[add_result]] } +// CHECK-LABEL: @FuseFullyConnectedAddWithUnfusableRhs +func.func @FuseFullyConnectedAddWithUnfusableRhs(%arg0: tensor<4x37xf32>, %arg1: tensor<4x37xf32>) -> tensor<4x4xf32> { + %cst = arith.constant dense<[2.0, 2.1, 2.2, 2.3]> : tensor<4xf32> + %cst2 = arith.constant dense<[[2.0, 2.1, 2.2, 2.3], [2.0, 2.1, 2.2, 2.3], [2.0, 2.1, 2.2, 2.3], [2.0, 2.1, 2.2, 2.3]]> : tensor<4x4xf32> + + %0 = "tfl.fully_connected" (%arg0, %arg1, %cst) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<4x37xf32>, tensor<4x37xf32>, tensor<4xf32>) -> (tensor<4x4xf32>) + %1 = "tfl.add"(%0, %cst2) {fused_activation_function = "NONE"} : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> + + func.return %1 : tensor<4x4xf32> + + // CHECK-DAG: %[[bias:.*]] = arith.constant dense<{{.*}}> : tensor<4xf32> + // CHECK-DAG: %[[filter:.*]] = arith.constant dense<{{.*}}> : tensor<4x4xf32> + // CHECK: %[[fc_result:.*]] = "tfl.fully_connected"(%arg0, %arg1, %[[bias]]) + // CHECK: %[[add_result:.*]] = tfl.add %[[fc_result]], %[[filter]] + // CHECK: return %[[add_result]] +} + // CHECK-LABEL: @FuseFullyConnectedReshapeAddConst // FOLD-LABEL: @FuseFullyConnectedReshapeAddConst func.func @FuseFullyConnectedReshapeAddConst(%arg0: tensor<40x37xf32>, %arg1: tensor<40x37xf32>) -> tensor<40x40xf32> { @@ -643,6 +660,46 @@ func.func @RetainRedundantReshapeUseInNonBinaryOp(%arg0: tensor<128xf32>, %arg1: // CHECK: return %1, %2 } +// CHECK-LABEL: @FuseTransposeReshapeTranspose +func.func @FuseTransposeReshapeTranspose(%arg0: tensor<1x16x256xf32>) -> tensor<16x256xf32> { + %cst_10 = arith.constant dense<[0, 2, 1]> : tensor<3xi32> + %cst_3 = arith.constant dense<[256, 16]> : tensor<2xi32> + %cst_6 = arith.constant dense<[1, 0]> : tensor<2xi32> + %2057 = "tfl.transpose"(%arg0, %cst_10) : (tensor<1x16x256xf32>, tensor<3xi32>) -> tensor<1x256x16xf32> + %2058 = 
"tfl.reshape"(%2057, %cst_3) : (tensor<1x256x16xf32>, tensor<2xi32>) -> tensor<256x16xf32> + %2059 = "tfl.transpose"(%2058, %cst_6) : (tensor<256x16xf32>, tensor<2xi32>) -> tensor<16x256xf32> + return %2059: tensor<16x256xf32> + // CHECK-DAG: %cst = arith.constant dense<[16, 256]> : tensor<2xi32> + // CHECK: %0 = "tfl.reshape"(%arg0, %cst) : (tensor<1x16x256xf32>, tensor<2xi32>) -> tensor<16x256xf32> + // CHECK: return %0 +} + +// CHECK-LABEL: @FoldDoubleTranspose +func.func @FoldDoubleTranspose(%arg0: tensor<1x4x1440x256xf32>) -> tensor<1x1440x256x4xf32> { + %cst_12 = arith.constant dense<[0, 1, 3, 2]> : tensor<4xi32> + %cst_18 = arith.constant dense<[0, 2, 1, 3]> : tensor<4xi32> + %2112 = "tfl.transpose"(%arg0, %cst_18) : (tensor<1x4x1440x256xf32>, tensor<4xi32>) -> tensor<1x1440x4x256xf32> + %2114 = "tfl.transpose"(%2112, %cst_12) : (tensor<1x1440x4x256xf32>, tensor<4xi32>) -> tensor<1x1440x256x4xf32> + return %2114 : tensor<1x1440x256x4xf32> + // CHECK-DAG: %cst = arith.constant dense<[0, 2, 3, 1]> : tensor<4xi32> + // CHECK: %0 = "tfl.transpose"(%arg0, %cst) : (tensor<1x4x1440x256xf32>, tensor<4xi32>) -> tensor<1x1440x256x4xf32> + // CHECK: return %0 +} + +// CHECK-LABEL: @FoldMultpleTranspose +func.func @FoldMultpleTranspose(%arg0: tensor<1x4x1440x256xf32>) -> tensor<1x256x4x1440xf32> { + %cst_11 = arith.constant dense<[0, 2, 3, 1]> : tensor<4xi32> + %cst_12 = arith.constant dense<[0, 1, 3, 2]> : tensor<4xi32> + %cst_18 = arith.constant dense<[0, 2, 1, 3]> : tensor<4xi32> + %2112 = "tfl.transpose"(%arg0, %cst_11) : (tensor<1x4x1440x256xf32>, tensor<4xi32>) -> tensor<1x1440x256x4xf32> + %2113 = "tfl.transpose"(%2112, %cst_18) : (tensor<1x1440x256x4xf32>, tensor<4xi32>) -> tensor<1x256x1440x4xf32> + %2114 = "tfl.transpose"(%2113, %cst_12) : (tensor<1x256x1440x4xf32>, tensor<4xi32>) -> tensor<1x256x4x1440xf32> + return %2114 : tensor<1x256x4x1440xf32> + // CHECK-DAG: %cst = arith.constant dense<[0, 3, 1, 2]> : tensor<4xi32> + // CHECK: %0 = 
"tfl.transpose"(%arg0, %cst) : (tensor<1x4x1440x256xf32>, tensor<4xi32>) -> tensor<1x256x4x1440xf32> + // CHECK: return %0 +} + // CHECK-LABEL: @FuseFullyConnectedReshapeAddConstWithOptionalAttribute // FOLD-LABEL: @FuseFullyConnectedReshapeAddConstWithOptionalAttribute func.func @FuseFullyConnectedReshapeAddConstWithOptionalAttribute(%arg0: tensor<40x37xf32>, %arg1: tensor<40x37xf32>) -> tensor<40x40xf32> { @@ -2613,6 +2670,63 @@ func.func @noReplaceReshapeEqualWithOneHotBadIndex(%arg: tensor<2xi32>) -> tenso // CHECK: %[[RES:.*]] = "tfl.equal"(%[[TMP]], %[[CST2]]) : (tensor<2x1xi32>, tensor<3xi32>) -> tensor<2x3xi1> } +// CHECK-LABEL: ReplaceReshapeEqualOneHotDynamicBatch +func.func @ReplaceReshapeEqualOneHotDynamicBatch(%arg0: tensor) -> (tensor) { + %cst = arith.constant dense<-1> : tensor + %cst_0 = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]> : tensor<10xi32> + %0 = "tfl.expand_dims"(%arg0, %cst) : (tensor, tensor) -> tensor + %1 = "tfl.equal"(%0, %cst_0) : (tensor, tensor<10xi32>) -> tensor + %2 = "tfl.cast"(%1) : (tensor) -> tensor + func.return %2 : tensor + + // CHECK-DAG: %[[CST:.*]] = arith.constant dense<-1> : tensor<1xi32> + // CHECK-DAG: %[[CST_0:.*]] = arith.constant dense<10> : tensor + // CHECK-DAG: %[[CST_1:.*]] = arith.constant dense<1.000000e+00> : tensor + // CHECK-DAG: %[[CST_2:.*]] = arith.constant dense<0.000000e+00> : tensor + // CHECK-DAG: %[[CST_3:.*]] = arith.constant dense<-1> : tensor + // CHECK: %[[EXPAND_DIMS:.*]] = "tfl.expand_dims"(%arg0, %[[CST_3]]) : (tensor, tensor) -> tensor + // CHECK: %[[RESHAPE:.*]] = "tfl.reshape"(%0, %[[CST]]) : (tensor, tensor<1xi32>) -> tensor + // CHECK: %[[ONE_HOT:.*]] = "tfl.one_hot"(%1, %[[CST_0]], %[[CST_1]], %[[CST_2]]) {axis = -1 : i32} : (tensor, tensor, tensor, tensor) -> tensor + // CHECK-NEXT: return %[[ONE_HOT]] +} + +// CHECK-LABEL: noReplaceReshapeEqualWithOneHotDynamicNonBatch +func.func @noReplaceReshapeEqualWithOneHotDynamicNonBatch(%arg0: tensor<1x?xi32>) -> tensor<1x?x10xf32> 
{ + %cst = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]> : tensor<10xi32> + %1 = "tfl.equal"(%arg0, %cst) : (tensor<1x?xi32>, tensor<10xi32>) -> tensor<1x?x10xi1> + %2 = "tfl.cast"(%1) : (tensor<1x?x10xi1>) -> tensor<1x?x10xf32> + func.return %2 : tensor<1x?x10xf32> + + // CHECK-DAG: %[[CST:.*]] = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]> : tensor<10xi32> + // CHECK: %[[EQUAL:.*]] = "tfl.equal"(%arg0, %[[CST]]) : (tensor<1x?xi32>, tensor<10xi32>) -> tensor<1x?x10xi1> + // CHECK: %[[CAST:.*]] = "tfl.cast"(%[[EQUAL]]) : (tensor<1x?x10xi1>) -> tensor<1x?x10xf32> + // CHECK-NEXT: return %[[CAST]] +} + +// CHECK-LABEL: noReplaceReshapeEqualWithOneHotUnranked +func.func @noReplaceReshapeEqualWithOneHotUnranked(%arg0: tensor<*xi1>) -> tensor<*xi1> { + %cst = arith.constant dense : tensor + %1 = "tfl.equal"(%arg0, %cst) : (tensor<*xi1>, tensor) -> tensor<*xi1> + func.return %1 : tensor<*xi1> + + // CHECK-DAG: %[[CST:.*]] = arith.constant dense : tensor + // CHECK: %[[EQUAL:.*]] = "tfl.equal"(%arg0, %cst) : (tensor<*xi1>, tensor) -> tensor<*xi1> + // CHECK-NEXT: return %[[EQUAL]] +} + +// CHECK-LABEL: noReplaceReshapeEqualWithOneHotDynamicNonBatchRank1 +func.func @noReplaceReshapeEqualWithOneHotDynamicNonBatchRank1(%arg0: tensor) -> tensor { + %cst = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]> : tensor<10xi32> + %1 = "tfl.equal"(%arg0, %cst) : (tensor, tensor<10xi32>) -> tensor + %2 = "tfl.cast"(%1) : (tensor) -> tensor + func.return %2 : tensor + + // CHECK-DAG: %[[CST:.*]] = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]> : tensor<10xi32> + // CHECK: %[[EQUAL:.*]] = "tfl.equal"(%arg0, %[[CST]]) : (tensor, tensor<10xi32>) -> tensor + // CHECK: %[[CAST:.*]] = "tfl.cast"(%[[EQUAL]]) : (tensor) -> tensor + // CHECK-NEXT: return %[[CAST]] +} + // CHECK-LABEL: fuseOneHotCast func.func @fuseOneHotCast(%arg: tensor<2xi32>) -> (tensor<2x3xf32>, tensor<2x3xf32>) { %depth = arith.constant dense<3> : tensor @@ -3118,3 +3232,160 @@ func.func 
@DontEliminateExtraSelect(%arg0: tensor<4x2xf32>, %arg1: tensor<4x2xi1 // CHECK-NEXT: %[[SELECT_1:.*]] = "tfl.select_v2" // CHECK-NEXT: return %[[SELECT_1]] } + +// CHECK-LABEL: func @fuseReluToMin1_StaticShapeWithBroadcastedCst_Float1 +func.func @fuseReluToMin1_StaticShapeWithBroadcastedCst_Float1(%arg0: tensor<2x2xf32>) -> (tensor<2x2xf32>) { + %cst0 = arith.constant dense<0.0> : tensor + %0 = "tfl.maximum"(%arg0, %cst0) : (tensor<2x2xf32>, tensor) -> tensor<2x2xf32> + %cst1 = arith.constant dense<1.0> : tensor + %1 = "tfl.minimum"(%0, %cst1) : (tensor<2x2xf32>, tensor) -> tensor<2x2xf32> + + func.return %1 : tensor<2x2xf32> + // CHECK-NOT: "tfl.relu" + // CHECK-NOT: "tfl.minimum" + // CHECK-NOT: "tfl.pseudo_const" + // CHECK: "tfl.relu_0_to_1"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32> +} + +// CHECK-LABEL: func @fuseReluToMin1_StaticShapeWithBroadcastedCst_Float2 +func.func @fuseReluToMin1_StaticShapeWithBroadcastedCst_Float2(%arg0: tensor<2x2xf32>) -> (tensor<2x2xf32>) { + %cst0 = arith.constant dense<1.0> : tensor + %0 = "tfl.minimum"(%arg0, %cst0) : (tensor<2x2xf32>, tensor) -> tensor<2x2xf32> + %cst1 = arith.constant dense<0.0> : tensor + %1 = "tfl.maximum"(%0, %cst1) : (tensor<2x2xf32>, tensor) -> tensor<2x2xf32> + + func.return %1 : tensor<2x2xf32> + // CHECK-NOT: "tfl.relu" + // CHECK-NOT: "tfl.minimum" + // CHECK-NOT: "tfl.pseudo_const" + // CHECK: "tfl.relu_0_to_1"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32> +} + +// CHECK-LABEL: func @fuseReluToMin1_StaticShapeWithSameShapeCst_Float +func.func @fuseReluToMin1_StaticShapeWithSameShapeCst_Float2(%arg0: tensor<2x2xf32>) -> (tensor<2x2xf32>) { + %cst0 = arith.constant dense<1.0> : tensor<2x2xf32> + %0 = "tfl.minimum"(%arg0, %cst0) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> + %cst1 = arith.constant dense<0.0> : tensor<2x2xf32> + %1 = "tfl.maximum"(%0, %cst1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> + + func.return %1 : tensor<2x2xf32> + // CHECK-NOT: "tfl.relu" + // 
CHECK-NOT: "tfl.minimum" + // CHECK-NOT: "tfl.pseudo_const" + // CHECK: "tfl.relu_0_to_1"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32> +} + + + +// CHECK-LABEL: func @fuseAddAndStridedSlice +func.func @fuseAddAndStridedSlice(%arg0: tensor<4xi32>, %arg1: tensor<1xi32>) -> tensor<4xi32> { + // CHECK: %cst = arith.constant dense<1> : tensor<1xi32> + // CHECK: %0 = "tfl.pseudo_const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> + // CHECK: %1 = "tfl.strided_slice"(%arg0, %arg1, %cst, %0) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, offset = true, shrink_axis_mask = 0 : i32} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xi32> + + %cst_0 = "tfl.pseudo_const"() {value = dense<1> : tensor} : () -> tensor + %cst_1 = "tfl.pseudo_const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> + %0 = "tfl.add"(%arg1, %cst_0) {fused_activation_function = "NONE"} : (tensor<1xi32>, tensor) -> tensor<1xi32> + %1 = "tfl.strided_slice"(%arg0, %arg1, %0, %cst_1) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 0 : i32, offset = false} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xi32> + func.return %1 : tensor<4xi32> +} + +// CHECK-LABEL: func @fuseSubAndStridedSlice +func.func @fuseSubAndStridedSlice(%arg0: tensor<4xi32>, %arg1: tensor<1xi32>) -> tensor<4xi32> { + // CHECK: %cst = arith.constant dense<1> : tensor<1xi32> + // CHECK: %0 = "tfl.pseudo_const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> + // CHECK: %1 = "tfl.strided_slice"(%arg0, %arg1, %cst, %0) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, offset = true, shrink_axis_mask = 0 : i32} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xi32> + + %cst_0 = "tfl.pseudo_const"() {value = dense<1> : tensor} : () -> tensor + %cst_1 = "tfl.pseudo_const"() {value = 
dense<1> : tensor<1xi32>} : () -> tensor<1xi32> + %0 = "tfl.sub"(%arg1, %cst_0) {fused_activation_function = "NONE"} : (tensor<1xi32>, tensor) -> tensor<1xi32> + %1 = "tfl.strided_slice"(%arg0, %arg1, %0, %cst_1) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 0 : i32, offset = false} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xi32> + func.return %1 : tensor<4xi32> +} + +// CHECK-LABEL: func @dontFuseAddAndStridedSliceNonConstantStride +func.func @dontFuseAddAndStridedSliceNonConstantStrides(%arg0: tensor<4xi32>, %arg1: tensor<1xi32>, %arg2: tensor<1xi32>) -> tensor<4xi32> { + // CHECK-DAG: %0 = "tfl.pseudo_const"() {value = dense<1> : tensor} : () -> tensor + // CHECK: %1 = tfl.add(%arg1, %0) {fused_activation_function = "NONE"} : (tensor<1xi32>, tensor) -> tensor<1xi32> + // CHECK: %2 = "tfl.strided_slice"(%arg0, %arg1, %1, %arg2) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, offset = false, shrink_axis_mask = 0 : i32} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xi32> + + %cst = "tfl.pseudo_const"() {value = dense<1> : tensor} : () -> tensor + %0 = "tfl.add"(%arg1, %cst) {fused_activation_function = "NONE"} : (tensor<1xi32>, tensor) -> tensor<1xi32> + %1 = "tfl.strided_slice"(%arg0, %arg1, %0, %arg2) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 0 : i32, offset = false} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xi32> + func.return %1 : tensor<4xi32> +} + +// CHECK-LABEL: func @dontFuseAddAndStridedSliceOffset +func.func @dontFuseAddAndStridedSliceOffset(%arg0: tensor<4xi32>, %arg1: tensor<1xi32>, %arg2: tensor<1xi32>, %arg3: tensor<1xi32>) -> tensor<4xi32> { + // CHECK-DAG: %0 = "tfl.pseudo_const"() {value = dense<1> : tensor} : () -> tensor + // CHECK: %1 = tfl.add(%arg2, %0) 
{fused_activation_function = "NONE"} : (tensor<1xi32>, tensor) -> tensor<1xi32> + // CHECK: %2 = "tfl.strided_slice"(%arg0, %arg1, %1, %arg3) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, offset = false, shrink_axis_mask = 0 : i32} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xi32> + + %cst = "tfl.pseudo_const"() {value = dense<1> : tensor} : () -> tensor + %0 = "tfl.add"(%arg2, %cst) {fused_activation_function = "NONE"} : (tensor<1xi32>, tensor) -> tensor<1xi32> + %1 = "tfl.strided_slice"(%arg0, %arg1, %0, %arg3) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 0 : i32, offset = false} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xi32> + func.return %1 : tensor<4xi32> +} + +// CHECK-LABEL: func @dontFuseAddAndStridedSliceNonConstantOffset +func.func @dontFuseAddAndStridedSliceNonConstantOffset(%arg0: tensor<4xi32>, %arg1: tensor<1xi32>, %arg2: tensor<1xi32>) -> tensor<4xi32> { + // CHECK: %0 = tfl.add %arg1, %arg1 {fused_activation_function = "NONE"} : tensor<1xi32> + // CHECK: "tfl.strided_slice"(%arg0, %arg1, %0, %arg2) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, offset = false, shrink_axis_mask = 0 : i32} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xi32> + + %0 = "tfl.add"(%arg1, %arg1) {fused_activation_function = "NONE"} : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> + %1 = "tfl.strided_slice"(%arg0, %arg1, %0, %arg2) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 0 : i32, offset = false} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xi32> + func.return %1 : tensor<4xi32> +} + +// CHECK-LABEL: func @dontFuseAddAndStridedSliceBeginMask +func.func @dontFuseAddAndStridedSliceBeginMask(%arg0: tensor<4xi32>, %arg1: 
tensor<1xi32>) -> tensor<4xi32> { + // CHECK-DAG: %0 = "tfl.pseudo_const"() {value = dense<1> : tensor} : () -> tensor + // CHECK-DAG: %1 = "tfl.pseudo_const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> + // CHECK: %2 = tfl.add(%arg1, %0) {fused_activation_function = "NONE"} : (tensor<1xi32>, tensor) -> tensor<1xi32> + // CHECK: %3 = "tfl.strided_slice"(%arg0, %arg1, %2, %1) {begin_mask = 1 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, offset = false, shrink_axis_mask = 0 : i32} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xi32> + + %cst_0 = "tfl.pseudo_const"() {value = dense<1> : tensor} : () -> tensor + %cst_1 = "tfl.pseudo_const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> + %0 = "tfl.add"(%arg1, %cst_0) {fused_activation_function = "NONE"} : (tensor<1xi32>, tensor) -> tensor<1xi32> + %1 = "tfl.strided_slice"(%arg0, %arg1, %0, %cst_1) {begin_mask = 1 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 0 : i32, offset = false} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xi32> + func.return %1 : tensor<4xi32> +} + +// CHECK-LABEL: func @dontFuseAddAndStridedSliceEndMask +func.func @dontFuseAddAndStridedSliceEndMask(%arg0: tensor<4xi32>, %arg1: tensor<1xi32>) -> tensor<4xi32> { + // CHECK-DAG: %0 = "tfl.pseudo_const"() {value = dense<1> : tensor} : () -> tensor + // CHECK-DAG: %1 = "tfl.pseudo_const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> + // CHECK: %2 = tfl.add(%arg1, %0) {fused_activation_function = "NONE"} : (tensor<1xi32>, tensor) -> tensor<1xi32> + // CHECK: %3 = "tfl.strided_slice"(%arg0, %arg1, %2, %1) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 1 : i32, new_axis_mask = 0 : i32, offset = false, shrink_axis_mask = 0 : i32} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xi32> + + %cst_0 = "tfl.pseudo_const"() {value = dense<1> : tensor} 
: () -> tensor + %cst_1 = "tfl.pseudo_const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> + %0 = "tfl.add"(%arg1, %cst_0) {fused_activation_function = "NONE"} : (tensor<1xi32>, tensor) -> tensor<1xi32> + %1 = "tfl.strided_slice"(%arg0, %arg1, %0, %cst_1) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 1 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 0 : i32, offset = false} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xi32> + func.return %1 : tensor<4xi32> +} + +// CHECK-LABEL: func @dontFuseAddAndStridedSliceEllipsisMask +func.func @dontFuseAddAndStridedSliceEllipsisMask(%arg0: tensor<4xi32>, %arg1: tensor<1xi32>) -> tensor<4xi32> { + // CHECK-DAG: %0 = "tfl.pseudo_const"() {value = dense<1> : tensor} : () -> tensor + // CHECK-DAG: %1 = "tfl.pseudo_const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> + // CHECK: %2 = tfl.add(%arg1, %0) {fused_activation_function = "NONE"} : (tensor<1xi32>, tensor) -> tensor<1xi32> + // CHECK: %3 = "tfl.strided_slice"(%arg0, %arg1, %2, %1) {begin_mask = 0 : i32, ellipsis_mask = 1 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, offset = false, shrink_axis_mask = 0 : i32} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xi32> + + %cst_0 = "tfl.pseudo_const"() {value = dense<1> : tensor} : () -> tensor + %cst_1 = "tfl.pseudo_const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> + %0 = "tfl.add"(%arg1, %cst_0) {fused_activation_function = "NONE"} : (tensor<1xi32>, tensor) -> tensor<1xi32> + %1 = "tfl.strided_slice"(%arg0, %arg1, %0, %cst_1) {begin_mask = 0 : i32, ellipsis_mask = 1 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 0 : i32, offset = false} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xi32> + func.return %1 : tensor<4xi32> +} + +// CHECK-LABEL: func @fuseSigmoid +func.func @fuseSigmoid(%arg0: tensor<10xf32>) -> tensor<10xf32> { + // CHECK: "tfl.logistic" + %cst = 
arith.constant dense<1.000000e+00> : tensor<10xf32> + %0 = "tfl.neg"(%arg0) : (tensor<10xf32>) -> tensor<10xf32> + %1 = "tfl.exp"(%0) : (tensor<10xf32>) -> tensor<10xf32> + %2 = tfl.add %1, %cst {fused_activation_function = "NONE"} : tensor<10xf32> + %3 = tfl.div %cst, %2 {fused_activation_function = "NONE"} : tensor<10xf32> + return %3 : tensor<10xf32> +} diff --git a/tensorflow/compiler/mlir/lite/tests/prepare-quantize-dynamic-range.mlir b/tensorflow/compiler/mlir/lite/tests/prepare-quantize-dynamic-range.mlir index b549b564515..01ed79e5a63 100644 --- a/tensorflow/compiler/mlir/lite/tests/prepare-quantize-dynamic-range.mlir +++ b/tensorflow/compiler/mlir/lite/tests/prepare-quantize-dynamic-range.mlir @@ -3,6 +3,8 @@ // RUN: tf-opt %s -tfl-prepare-quantize-dynamic-range="enable-float16-quantization" | FileCheck --check-prefix=Float16 %s // RUN: tf-opt %s -tfl-prepare-quantize-dynamic-range="enable-custom-op-quantization=CustomTestOp=1-3,CustomTestOp3=3" | FileCheck --check-prefix=CustomOp %s // RUN: tf-opt %s -tfl-prepare-quantize-dynamic-range="min-elements-for-weights=4000 enable-custom-op-quantization=CustomTestOp=1-3,CustomTestOp3=3" | FileCheck --check-prefix=MinElement %s +// RUN: tf-opt %s -tfl-prepare-quantize-dynamic-range="min-elements-for-weights=19" | FileCheck --check-prefix=LSTMOpQuantized %s +// RUN: tf-opt %s -tfl-prepare-quantize-dynamic-range="min-elements-for-weights=21" | FileCheck --check-prefix=LSTMOpNotQuantized %s // CHECK-LABEL: QuantizeConv2D // PerTensor-LABEL: QuantizeConv2D @@ -409,3 +411,41 @@ func.func @LargeFloat16Constants(%arg0: tensor<1x224x224x3xf32>) -> tensor<1x112 // Float16-DAG: %[[w:.*]] = arith.constant dense<6.550400e+04> : tensor<64x3x3x3xf16> // Float16-DAG: %[[b:.*]] = arith.constant dense<-6.550400e+04> : tensor<64xf16> } + +// LSTMOpQuantized-LABEL: LSTMOpNotPartiallyQuantized +// LSTMOpNotQuantized-LABEL: LSTMOpNotPartiallyQuantized +func.func @LSTMOpNotPartiallyQuantized(%arg0: tensor<1x28x28xf32>) -> 
tensor<1x28x20xf32> { + %cst_2 = "tfl.no_value"() {value = unit} : () -> none + %cst_3 = arith.constant dense<1.0> : tensor<20x20xf32> + %cst_7 = arith.constant dense<1.0> : tensor<20xf32> + %recurrent_input = arith.constant dense<1.0> : tensor<1x20xf32> + %recurrent_stats = "quantfork.stats"(%recurrent_input) {layerStats = dense<[-2.0, 1.0]> : tensor<2xf32>} : (tensor<1x20xf32>) -> tensor<1x20xf32> + %cell_input = arith.constant dense<1.0> : tensor<1x20xf32> + %cell_stats = "quantfork.stats"(%cell_input) {layerStats = dense<[-2.73090601, 7.94872093]> : tensor<2xf32>} : (tensor<1x20xf32>) -> tensor<1x20xf32> + %0 = "tfl.unidirectional_sequence_lstm"(%arg0, + %cst_3, %cst_3, %cst_3, %cst_3, + %cst_3, %cst_3, %cst_3, %cst_3, + %cst_7, %cst_7, %cst_7, + %cst_7, %cst_7, %cst_7, %cst_7, + %cst_3, %cst_2, + %recurrent_stats, %cell_stats, + %cst_2, %cst_2, %cst_2, %cst_2) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", proj_clip = 0.000000e+00 : f32, time_major = false} + : ( tensor<1x28x28xf32>, + tensor<20x20xf32>, tensor<20x20xf32>, tensor<20x20xf32>, tensor<20x20xf32>, + tensor<20x20xf32>, tensor<20x20xf32>, tensor<20x20xf32>, tensor<20x20xf32>, + tensor<20xf32>, tensor<20xf32>, tensor<20xf32>, + tensor<20xf32>, tensor<20xf32>, tensor<20xf32>, tensor<20xf32>, + tensor<20x20xf32>, none, + tensor<1x20xf32>, tensor<1x20xf32>, + none, none, none, none) -> tensor<1x28x20xf32> + %1 = "quantfork.stats"(%0) {layerStats = dense<[-1.0, 2.0]> : tensor<2xf32>} : (tensor<1x28x20xf32>) -> tensor<1x28x20xf32> + func.return %1 : tensor<1x28x20xf32> + +// LSTMOpQuantized-DAG: %[[dq1:.*]] = "tfl.dequantize"({{.*}}) : (tensor<20x20x!quant.uniform:f32, 0.0078740157480314959>>) -> tensor<20x20xf32> +// LSTMOpQuantized-DAG: %[[dq3:.*]] = "tfl.dequantize"({{.*}}) : (tensor<20x!quant.uniform:f32, 0.0078740157480314959>>) -> tensor<20xf32> +// LSTMOpQuantized: %[[lstm:.*]] = "tfl.unidirectional_sequence_lstm"(%arg0, %[[dq1]], %[[dq1]], %[[dq1]], %[[dq1]], %[[dq1]], 
%[[dq1]], %[[dq1]], %[[dq1]], %[[dq3]], %[[dq3]], %[[dq3]], %cst_0, %cst_0, %cst_0, %cst_0, %[[dq1]], %0, %cst_1, %cst_1, %0, %0, %0, %0) + +// LSTMOpNotQuantized-DAG: %[[cst_1:.*]] = arith.constant dense<1.000000e+00> : tensor<20x20xf32> +// LSTMOpNotQuantized-DAG: %[[cst_3:.*]] = arith.constant dense<1.000000e+00> : tensor<20xf32> +// LSTMOpNotQuantized: %[[lstm:.*]] = "tfl.unidirectional_sequence_lstm"(%arg0, %[[cst_1]], %[[cst_1]], %[[cst_1]], %[[cst_1]], %[[cst_1]], %[[cst_1]], %[[cst_1]], %[[cst_1]], %[[cst_3]], %[[cst_3]], %[[cst_3]], %cst_0, %cst_0, %cst_0, %cst_0, %[[cst_1]], %0, %cst_1, %cst_1, %0, %0, %0, %0) +} diff --git a/tensorflow/compiler/mlir/lite/tests/prepare-quantize.mlir b/tensorflow/compiler/mlir/lite/tests/prepare-quantize.mlir index ea12951a97d..2a4b2af88f5 100644 --- a/tensorflow/compiler/mlir/lite/tests/prepare-quantize.mlir +++ b/tensorflow/compiler/mlir/lite/tests/prepare-quantize.mlir @@ -302,7 +302,7 @@ func.func @QuantizeSlice(tensor<2x3x5x!quant.uniform>, tensor<3xi32 func.func @QuantizeStridedSlice(tensor<12x2x2x5x!quant.uniform>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x2x2x5xf32> { ^bb0(%arg0: tensor<12x2x2x5x!quant.uniform>, %arg1: tensor<1xi32>, %arg2: tensor<1xi32>, %arg3: tensor<1xi32>): %0 = "tfl.dequantize"(%arg0) : (tensor<12x2x2x5x!quant.uniform>) -> tensor<12x2x2x5xf32> - %1 = "tfl.strided_slice"(%0, %arg1, %arg2, %arg3) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 0 : i32} : (tensor<12x2x2x5xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x2x2x5xf32> + %1 = "tfl.strided_slice"(%0, %arg1, %arg2, %arg3) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 0 : i32, offset = false} : (tensor<12x2x2x5xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x2x2x5xf32> func.return %1 : tensor<1x2x2x5xf32> // CHECK: %0 = "tfl.dequantize"(%arg0) diff --git 
a/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir b/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir index a668475a9e2..4f3914265b4 100644 --- a/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir +++ b/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir @@ -206,6 +206,19 @@ func.func @sharding(%arg0: tensor<10x10xi32>) -> (tensor<10x10xi32>) { // CHECK-NOT: %3 = "tf.XlaSharding"(%1) {_XlaSharding = "\08\03\1A\02\01\01\22\01\00", device = "", sharding = "\08\03\1A\02\01\01\22\01\00", unspecified_dims = []} : (tensor<10x10xi32>) -> tensor<10x10xi32> } +func.func @preventGradient(%arg0: tensor<10x10xi32>) -> (tensor<10x10xi32>) { + %0 = "tf.MatMul"(%arg0, %arg0) {device = "", transpose_a = false, transpose_b = false} : (tensor<10x10xi32>, tensor<10x10xi32>) -> tensor<10x10xi32> + %1 = "tf.MatMul"(%arg0, %arg0) {device = "", transpose_a = false, transpose_b = false} : (tensor<10x10xi32>, tensor<10x10xi32>) -> tensor<10x10xi32> + %2 = "tf.PreventGradient"(%0) : (tensor<10x10xi32>) -> tensor<10x10xi32> + %3 = "tf.PreventGradient"(%1) : (tensor<10x10xi32>) -> tensor<10x10xi32> + %4 = "tf.AddV2"(%2, %3) {device = ""} : (tensor<10x10xi32>, tensor<10x10xi32>) -> tensor<10x10xi32> + func.return %4 : tensor<10x10xi32> + +// CHECK-LABEL: preventGradient +// CHECK-NOT: %2 = "tf.PreventGradient"(%0) : (tensor<10x10xi32>) -> tensor<10x10xi32> +// CHECK-NOT: %3 = "tf.PreventGradient"(%1) : (tensor<10x10xi32>) -> tensor<10x10xi32> +} + func.func @matmulNoTransposeAOrB(%arg0: tensor<1x1280xf32>, %arg1: tensor<1280x1000xf32>) -> tensor<1x1000xf32> { %166 = "tf.MatMul"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", _output_shapes = ["tfshape$dim { size = 1} dim { size = 1000}"], device = "", name = "matmul", transpose_a = false, transpose_b = false} : (tensor<1x1280xf32>, tensor<1280x1000xf32>) -> tensor<1x1000xf32> func.return %166 : tensor<1x1000xf32> @@ -737,4 +750,12 @@ func.func @UnsupportedGroupConv_DynamicDimAtInputDimThree(%arg0: tensor, %fill: tensor) -> (tensor) { + %0 = 
"tf.Fill"(%shape, %fill) : (tensor, tensor) -> (tensor<*xf32>) + %1 = "tf.Shape"(%0) : (tensor<*xf32>) -> (tensor) + func.return %1 : tensor + + // CHECK-LABEL: RedundantShapeOp + // CHECK-NOT: "tf.Shape" +} } diff --git a/tensorflow/compiler/mlir/lite/tests/quantize-variables.mlir b/tensorflow/compiler/mlir/lite/tests/quantize-variables.mlir index 4431444c1ba..58dfed58a69 100644 --- a/tensorflow/compiler/mlir/lite/tests/quantize-variables.mlir +++ b/tensorflow/compiler/mlir/lite/tests/quantize-variables.mlir @@ -83,7 +83,7 @@ func.func @QuantizeReadAssign(%arg0: tensor<1x32x1x3xf32>) -> (tensor<1x34x1x3xf %4 = "tfl.concatenation"(%3, %1) {axis = 1 : i32, fused_activation_function = "NONE"} : (tensor<1x2x1x3xf32>, tensor<1x32x1x3xf32>) -> tensor<1x34x1x3xf32> %5 = "tfl.quantize"(%4) {qtype = tensor<1x34x1x3x!quant.uniform>, volatile} : (tensor<1x34x1x3xf32>) -> tensor<1x34x1x3x!quant.uniform> %6 = "tfl.dequantize"(%5) : (tensor<1x34x1x3x!quant.uniform>) -> tensor<1x34x1x3xf32> - %7 = "tfl.strided_slice"(%6, %cst_1, %cst_0, %cst) {begin_mask = 13 : i32, ellipsis_mask = 0 : i32, end_mask = 15 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 0 : i32} : (tensor<1x34x1x3xf32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) -> tensor<1x2x1x3xf32> + %7 = "tfl.strided_slice"(%6, %cst_1, %cst_0, %cst) {begin_mask = 13 : i32, ellipsis_mask = 0 : i32, end_mask = 15 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 0 : i32, offset = false} : (tensor<1x34x1x3xf32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) -> tensor<1x2x1x3xf32> %8 = "tfl.quantize"(%7) {qtype = tensor<1x2x1x3x!quant.uniform>, volatile} : (tensor<1x2x1x3xf32>) -> tensor<1x2x1x3x!quant.uniform> %9 = "tfl.dequantize"(%8) : (tensor<1x2x1x3x!quant.uniform>) -> tensor<1x2x1x3xf32> "tfl.assign_variable"(%2, %9) : (tensor, tensor<1x2x1x3xf32>) -> () @@ -100,7 +100,7 @@ func.func @QuantizeReadAssign(%arg0: tensor<1x32x1x3xf32>) -> (tensor<1x34x1x3xf // CHECK-NEXT: %[[cc:.*]] = "tfl.concatenation"(%[[dq2]], %[[dq1]]) 
{axis = 1 : i32, fused_activation_function = "NONE"} : (tensor<1x2x1x3xf32>, tensor<1x32x1x3xf32>) -> tensor<1x34x1x3xf32> // CHECK-NEXT: %[[q2:.*]] = "tfl.quantize"(%[[cc]]) {qtype = tensor<1x34x1x3x!quant.uniform>, volatile} : (tensor<1x34x1x3xf32>) -> tensor<1x34x1x3x!quant.uniform> // CHECK-NEXT: %[[dq3:.*]] = "tfl.dequantize"(%[[q2]]) : (tensor<1x34x1x3x!quant.uniform>) -> tensor<1x34x1x3xf32> -// CHECK-NEXT: %[[ss:.*]] = "tfl.strided_slice"(%[[dq3]], %[[cst_1]], %[[cst_0]], %[[cst]]) {begin_mask = 13 : i32, ellipsis_mask = 0 : i32, end_mask = 15 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 0 : i32} : (tensor<1x34x1x3xf32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) -> tensor<1x2x1x3xf32> +// CHECK-NEXT: %[[ss:.*]] = "tfl.strided_slice"(%[[dq3]], %[[cst_1]], %[[cst_0]], %[[cst]]) {begin_mask = 13 : i32, ellipsis_mask = 0 : i32, end_mask = 15 : i32, new_axis_mask = 0 : i32, offset = false, shrink_axis_mask = 0 : i32} : (tensor<1x34x1x3xf32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) -> tensor<1x2x1x3xf32> // CHECK-NEXT: %[[q3:.*]] = "tfl.quantize"(%[[ss]]) {qtype = tensor<1x2x1x3x!quant.uniform>, volatile} : (tensor<1x2x1x3xf32>) -> tensor<1x2x1x3x!quant.uniform> // CHECK-NEXT: "tfl.assign_variable"(%[[vh]], %[[q3]]) : (tensor<*x!tf_type.resource>>>, tensor<1x2x1x3x!quant.uniform>) -> () // CHECK-NEXT: return %[[dq3]] : tensor<1x34x1x3xf32> @@ -127,7 +127,7 @@ func.func @QuantizeConvVariable(%arg0: tensor<1x3x1x1xf32>) -> (tensor<1x3x1x1xf %11 = "tfl.concatenation"(%7, %10) {axis = 1 : i32, fused_activation_function = "NONE"} : (tensor<1x3x1x1xf32>, tensor<1x3x1x1xf32>) -> tensor<1x6x1x1xf32> %12 = "tfl.quantize"(%11) {qtype = tensor<1x6x1x1x!quant.uniform>, volatile} : (tensor<1x6x1x1xf32>) -> tensor<1x6x1x1x!quant.uniform> %13 = "tfl.dequantize"(%12) : (tensor<1x6x1x1x!quant.uniform>) -> tensor<1x6x1x1xf32> - %14 = "tfl.strided_slice"(%13, %cst_1, %cst_0, %cst) {begin_mask = 15 : i32, ellipsis_mask = 0 : i32, end_mask = 13 : i32, new_axis_mask = 0 
: i32, shrink_axis_mask = 0 : i32} : (tensor<1x6x1x1xf32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) -> tensor<1x3x1x1xf32> + %14 = "tfl.strided_slice"(%13, %cst_1, %cst_0, %cst) {begin_mask = 15 : i32, ellipsis_mask = 0 : i32, end_mask = 13 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 0 : i32, offset = false} : (tensor<1x6x1x1xf32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) -> tensor<1x3x1x1xf32> %15 = "tfl.quantize"(%14) {qtype = tensor<1x3x1x1x!quant.uniform>, volatile} : (tensor<1x3x1x1xf32>) -> tensor<1x3x1x1x!quant.uniform> %16 = "tfl.dequantize"(%15) : (tensor<1x3x1x1x!quant.uniform>) -> tensor<1x3x1x1xf32> "tfl.assign_variable"(%6, %16) : (tensor, tensor<1x3x1x1xf32>) -> () @@ -157,7 +157,7 @@ func.func @QuantizeTwoVariable(%arg0: tensor<1x2x3xf32>) -> (tensor<1x2x3xf32>) %41 = "quantfork.stats"(%40) {layerStats = dense<[0.0, 1.0]> : tensor<2xf32>} : (tensor<1x2x3xf32>) -> tensor<1x2x3xf32> %42 = "tfl.concatenation"(%41, %0) {axis = 1 : i32, fused_activation_function = "NONE"} : (tensor<1x2x3xf32>, tensor<1x2x3xf32>) -> tensor<1x4x3xf32> %43 = "quantfork.stats"(%42) {layerStats = dense<[-1.0, 1.0]> : tensor<2xf32>} : (tensor<1x4x3xf32>) -> tensor<1x4x3xf32> - %44 = "tfl.strided_slice"(%43, %1, %2, %3) {begin_mask = 7 : i32, ellipsis_mask = 0 : i32, end_mask = 5 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 0 : i32} : (tensor<1x4x3xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<1x2x3xf32> + %44 = "tfl.strided_slice"(%43, %1, %2, %3) {begin_mask = 7 : i32, ellipsis_mask = 0 : i32, end_mask = 5 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 0 : i32, offset = false} : (tensor<1x4x3xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<1x2x3xf32> %45 = "quantfork.stats"(%44) {layerStats = dense<[0.0, 1.0]> : tensor<2xf32>} : (tensor<1x2x3xf32>) -> tensor<1x2x3xf32> "tfl.assign_variable"(%4, %45) : (tensor, tensor<1x2x3xf32>) -> () @@ -165,7 +165,7 @@ func.func @QuantizeTwoVariable(%arg0: tensor<1x2x3xf32>) -> 
(tensor<1x2x3xf32>) %51 = "quantfork.stats"(%50) {layerStats = dense<[0.0, 1.0]> : tensor<2xf32>} : (tensor<1x2x3xf32>) -> tensor<1x2x3xf32> %52 = "tfl.concatenation"(%51, %0) {axis = 1 : i32, fused_activation_function = "NONE"} : (tensor<1x2x3xf32>, tensor<1x2x3xf32>) -> tensor<1x4x3xf32> %53 = "quantfork.stats"(%52) {layerStats = dense<[0.0, 1.0]> : tensor<2xf32>} : (tensor<1x4x3xf32>) -> tensor<1x4x3xf32> - %54 = "tfl.strided_slice"(%53, %1, %2, %3) {begin_mask = 7 : i32, ellipsis_mask = 0 : i32, end_mask = 5 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 0 : i32} : (tensor<1x4x3xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<1x2x3xf32> + %54 = "tfl.strided_slice"(%53, %1, %2, %3) {begin_mask = 7 : i32, ellipsis_mask = 0 : i32, end_mask = 5 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 0 : i32, offset = false} : (tensor<1x4x3xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<1x2x3xf32> %55 = "quantfork.stats"(%54) {layerStats = dense<[0.0, 1.0]> : tensor<2xf32>} : (tensor<1x2x3xf32>) -> tensor<1x2x3xf32> "tfl.assign_variable"(%5, %55) : (tensor, tensor<1x2x3xf32>) -> () diff --git a/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc b/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc index 84323be0555..8a7625d672b 100644 --- a/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc +++ b/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc @@ -15,7 +15,9 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/tf_tfl_passes.h" +#include #include +#include #include "llvm/ADT/SmallVector.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project @@ -27,6 +29,8 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h" #include "tensorflow/compiler/mlir/lite/quantization/quantization_passes.h" #include "tensorflow/compiler/mlir/lite/quantization/tensorflow/passes.h" +#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_tf_xla_call_module_to_stablehlo_pass.h" +#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/rename_entrypoint_to_main.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" #include "tensorflow/compiler/mlir/lite/utils/fake_quant_utils.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.h" @@ -122,6 +126,10 @@ void AddDynamicRangeQuantizationPasses( void AddConvertHloToTfPass(std::string entry_function_name, mlir::OpPassManager* pass_manager) { + pass_manager->addPass(mlir::odml::CreateRenameEntrypointToMainPass()); + pass_manager->addPass( + mlir::odml::CreateLegalizeTFXlaCallModuleToStablehloPass()); + pass_manager->addPass(mlir::mhlo::createStablehloLegalizeToHloPass()); // Legalize jax random to tflite custom op. // The CreateLegalizeJaxRandom Pass has to stay at because we need to replace // the random function body before being inlined. @@ -150,6 +158,14 @@ void AddConvertHloToTfPass(std::string entry_function_name, pass_manager->addNestedPass( mlir::TF::CreateLegalizeHloToTfPass()); + // folds tf.BroadcastTo ops with subsequent ops if they have built in + // broadcasting support. This needs to be run immediately after HLO->TF + // legalization; otherwise other passes like `ConvertTFBroadcastTo` will + // constant fold the newly generated TF broadcast ops and materialize the + // weights. + pass_manager->addNestedPass( + mlir::TF::CreateBroadcastFoldPass()); + // Canonicalization after TF legalization. 
pass_manager->addNestedPass( mlir::createCanonicalizerPass()); diff --git a/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc b/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc index 2d25ef59a14..8b4057bc625 100644 --- a/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc +++ b/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc @@ -15,7 +15,10 @@ limitations under the License. #include #include +#include #include +#include +#include #include "absl/strings/str_split.h" #include "llvm/ADT/STLExtras.h" diff --git a/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc b/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc index 86dbe9c513e..51fd8dbc23e 100644 --- a/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc +++ b/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc @@ -165,24 +165,27 @@ StatusOr> LoadFromGraphdefOrMlirSource( auto extra_opdefs_status = RegisterExtraTfOpDefs(extra_tf_opdefs); if (!extra_opdefs_status.ok()) return extra_opdefs_status; + ::tensorflow::GraphdefToMlirOptions graphdef_conversion_options{ + std::string(debug_info_file), + /*xla_compile_device_type=*/"", + /*prune_unused_nodes=*/specs.prune_unused_nodes, + /*convert_legacy_fed_inputs=*/true, + /*graph_as_function=*/false, + specs.upgrade_legacy, + /*enable_shape_inference=*/false, + /*unconditionally_use_set_output_shapes=*/true, + /*enable_soft_placement=*/false}; + if (use_splatted_constant) { return tensorflow::GraphdefToSplattedMlirTranslateFunction( - file->getBuffer(), debug_info_file, /*xla_compile_device_type=*/"", - input_arrays, input_dtypes, input_shapes, output_arrays, - control_output_arrays, specs.prune_unused_nodes, - /*convert_legacy_fed_inputs=*/true, - /*graph_as_function=*/false, specs.upgrade_legacy, - /*enable_shape_inference=*/false, - /*unconditionally_use_set_output_shapes=*/true, context); + file->getBuffer(), input_arrays, input_dtypes, input_shapes, + output_arrays, control_output_arrays, graphdef_conversion_options, + context); } return 
tensorflow::GraphdefToMlirTranslateFunction( - file->getBuffer(), debug_info_file, /*xla_compile_device_type=*/"", - input_arrays, input_dtypes, input_shapes, output_arrays, - control_output_arrays, specs.prune_unused_nodes, - /*convert_legacy_fed_inputs=*/true, - /*graph_as_function=*/false, specs.upgrade_legacy, - /*enable_shape_inference=*/false, - /*unconditionally_use_set_output_shapes=*/true, context); + file->getBuffer(), input_arrays, input_dtypes, input_shapes, + output_arrays, control_output_arrays, graphdef_conversion_options, + context); } // Applying post-training dynamic range quantization from the old TOCO quantizer @@ -321,8 +324,7 @@ Status ConvertTFExecutorToTFLOrFlatbuffer( mlir::PassManager pass_manager(module.getContext()); mlir::registerPassManagerCLOptions(); if (mlir::failed(mlir::applyPassManagerCLOptions(pass_manager))) { - return tensorflow::FromAbslStatus( - absl::UnknownError("failed to apply MLIR pass manager CL options")); + return absl::UnknownError("failed to apply MLIR pass manager CL options"); } pass_manager.addInstrumentation( std::make_unique( diff --git a/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h b/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h index 173f16ab488..95c9817f560 100644 --- a/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h +++ b/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h @@ -16,9 +16,11 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_LITE_TF_TO_TFL_FLATBUFFER_H_ #define TENSORFLOW_COMPILER_MLIR_LITE_TF_TO_TFL_FLATBUFFER_H_ +#include #include #include #include +#include #include "absl/types/span.h" #include "llvm/Support/SourceMgr.h" @@ -53,7 +55,7 @@ tsl::StatusOr> LoadFromGraphdefOrMlirSource( // Load Saved model (either v1 or v2) into MLIR. // 'saved_model_bundle' will be initialized if V1 model was loaded. 
tsl::StatusOr> ImportSavedModel( - const std::string& input_filename, const int saved_model_version, + const std::string& input_filename, int saved_model_version, const std::unordered_set& tags, absl::Span extra_tf_opdefs, absl::Span exported_names, const GraphImportConfig& specs, diff --git a/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td b/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td index a020a4be43a..34373268527 100644 --- a/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td +++ b/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td @@ -50,6 +50,10 @@ def ConvertToQuantTypeFromAttrs : NativeCodeCall< def convertIntAttrTo32Bit : NativeCodeCall< "$_builder.getI32IntegerAttr($0.cast().getInt())">; +// Builds a constant bool attribute. +class GetBoolAttr : + NativeCodeCall<"$_builder.getBoolAttr(" # value #")">; + // Converts an integer attribute $0 to 64-bit with builder. def convertIntAttrTo64Bit : NativeCodeCall< "$_builder.getI64IntegerAttr($0.cast().getInt())">; @@ -69,6 +73,10 @@ def CreateTFCastToInt32Op : NativeCodeCall< def CreateNoneValue : NativeCodeCall< "$_builder.create($0.getLoc(), $_builder.getUnitAttr())">; +// Creates an int32 constant op from an integer attribute $0. +def CreateInt32ConstOpFromIntAttr + : NativeCodeCall<"$_builder.create($_loc, DenseElementsAttr::get(RankedTensorType::get({}, $_builder.getI32Type()), {static_cast($0.cast().getInt())}))">; + //===----------------------------------------------------------------------===// // Nullary ops patterns. 
//===----------------------------------------------------------------------===// @@ -373,6 +381,16 @@ def LegalizeSum : Pat<(TF_SumOp $arg, $axes, BoolAttr:$arg2), def LegalizeTopKV2 : Pat<(TF_TopKV2Op $input, $k, $ignored_sorted), (TFL_TopKV2Op $input, $k)>; +def ReductionDimensionIsLastDim : Constraint().getInt() == " + "$1.getType().cast().getRank() - 1 || $0.cast().getInt() == -1)">>; + +// Legalizes TF_ApproxTopKOp to TFL_TopKV2Op with the following constraints: +// 1. It computes max k +// 2. The reduction dimension is the last dim of the input. +def LegalizeApproxTopK : Pat<(TF_ApproxTopKOp $input, $k, $reduction_dimenstion, $ignored_recall_target, /*is_max_k*/ConstBoolAttrTrue, $ignored_reduction_input_size_override, $ignored_aggregate_to_topk), + (TFL_TopKV2Op $input, (CreateInt32ConstOpFromIntAttr $k)), + [(ReductionDimensionIsLastDim $reduction_dimenstion, $input)]>; + def LegalizeMin : Pat< (TF_MinOp $arg0, $axes, BoolAttr:$arg2), (TFL_ReduceMinOp $arg0, (CreateTFCastToInt32Op $axes), $arg2)>; @@ -534,7 +552,8 @@ def LegalizeStridedSlice : Pat< (CreateTFCastToInt32Op $strides), (convertIntAttrTo32Bit $begin_mask), (convertIntAttrTo32Bit $end_mask), (convertIntAttrTo32Bit $ellipsis_mask), (convertIntAttrTo32Bit $new_axis_mask), - (convertIntAttrTo32Bit $shrink_axis_mask))>; + (convertIntAttrTo32Bit $shrink_axis_mask), + (GetBoolAttr))>; def LegalizeRfft2d : Pat< (TF_RFFT2DOp $input, $fft_length), diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize.cc b/tensorflow/compiler/mlir/lite/transforms/optimize.cc index 9b74c6bf606..7b31bcbc1a1 100644 --- a/tensorflow/compiler/mlir/lite/transforms/optimize.cc +++ b/tensorflow/compiler/mlir/lite/transforms/optimize.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
@@ -19,6 +19,7 @@ limitations under the License. #include #include #include +#include #include #include #include @@ -236,8 +237,8 @@ bool CanFuseConvOrDepthwiseConv(Attribute filter, Attribute val, return false; } -// Retuns true if we can eliminate the GatherNdOp or ScatterNdOp. When the value -// of `indices` are from 0 to n-1, the output tensor are identical to the +// Returns true if we can eliminate the GatherNdOp or ScatterNdOp. When the +// value of `indices` are from 0 to n-1, the output tensor are identical to the // `params`. bool CanOptimizeIdentityGatherNdOrScatterNdOp(Value params, DenseIntElementsAttr indices, @@ -344,15 +345,25 @@ TypeAttr RescaleQtype(Type input, Attribute factor) { // Returns shape of a ranked tensor. // Precondition: output_val's is ranked tensor. -DenseElementsAttr GetShape(Value output_val) { +// Returns a truncated shape when `truncate` is set to true. +DenseElementsAttr GetShape(Value output_val, bool truncate = false) { auto output_type = output_val.getType().cast(); SmallVector shape; shape.reserve(output_type.getRank()); - for (int64_t dim : output_type.getShape()) { + + bool needs_truncation = true; + for (size_t dim_idx = 0; dim_idx < output_type.getRank(); ++dim_idx) { + int64_t dim = output_type.getShape()[dim_idx]; + if (truncate && needs_truncation && dim == 1) { + continue; + } else if (needs_truncation && dim != 1) { + needs_truncation = false; + } shape.push_back(ShapedType::isDynamic(dim) ? 
-1 : static_cast(dim)); } + return mlir::DenseElementsAttr::get( RankedTensorType::get( {static_cast(shape.size())}, @@ -360,6 +371,34 @@ DenseElementsAttr GetShape(Value output_val) { llvm::ArrayRef(shape)); } +// Utility function to map final permutation to initial permutation +// initial -> permutation1 -> permutation2 -> final +DenseElementsAttr RemapPermutation(Value permutation1, Value permutation2) { + SmallVector initial_permutation; + DenseElementsAttr perm1_const; + DenseElementsAttr perm2_const; + + SmallVector new_permutation; + if (matchPattern(permutation1, m_Constant(&perm1_const)) && + matchPattern(permutation2, m_Constant(&perm2_const))) { + for (int32_t idx = 0; idx < perm1_const.getNumElements(); ++idx) { + initial_permutation.push_back(idx); + } + for (auto perm : perm2_const.getValues()) { + new_permutation.push_back( + initial_permutation[perm1_const + .getValues()[perm.getSExtValue()] + .getSExtValue()]); + } + } + + return mlir::DenseElementsAttr::get( + RankedTensorType::get( + {static_cast(new_permutation.size())}, + mlir::IntegerType::get(permutation1.getContext(), 32)), + llvm::ArrayRef(new_permutation)); +} + // Returns `true` if reducing `axes` in `input` with `keep_dims=true` results in // the specified `shape` and `false` otherwise. static bool ShapeMatchesReduceWithKeepAxes(Value input, @@ -480,8 +519,12 @@ Value ReshapeValueDroppingLastDim(OpBuilder &builder, Value value) { // so we could cast safely here. 
auto type = value.getType().cast(); SmallVector new_shape; - for (int64_t dim : type.getShape().drop_back()) { - new_shape.push_back(dim); + if (type.hasStaticShape()) { + for (int64_t dim : type.getShape().drop_back()) { + new_shape.push_back(dim); + } + } else { + new_shape.push_back(-1); } return builder.create( value.getLoc(), value, @@ -635,6 +678,78 @@ TypedAttr ConvertSingleElementAttrToFloatAttr(Attribute attr) { #include "tensorflow/compiler/mlir/lite/transforms/generated_optimize.inc" +struct FuseAddAndStridedSlice : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TFL::StridedSliceOp strided_slice_op, + PatternRewriter &rewriter) const override { + // Match Add + mlir::TFL::AddOp add_op = + dyn_cast_or_null(strided_slice_op.getEnd().getDefiningOp()); + mlir::TFL::SubOp sub_op = + dyn_cast_or_null(strided_slice_op.getEnd().getDefiningOp()); + if (!(add_op || sub_op)) { + return failure(); + } + + // Check that add rhs is constant. + DenseElementsAttr added_value; + Value constant_val = add_op ? add_op.getRhs() : sub_op.getRhs(); + if (!matchPattern(constant_val, m_Constant(&added_value))) return failure(); + + // Check the add op is applied to begin. + mlir::TypedValue<::mlir::TensorType> begin_tensor = + strided_slice_op.getBegin(); + mlir::TypedValue<::mlir::TensorType> add_source_tensor = + add_op ? add_op.getLhs() : sub_op.getLhs(); + if (begin_tensor != add_source_tensor) { + return failure(); + } + + // Check that strides are constant + DenseElementsAttr strides_value; + Value strides_val = strided_slice_op.getStrides(); + if (!matchPattern(strides_val, m_Constant(&strides_value))) + return failure(); + + mlir::TensorType constant_val_type = + constant_val.getType().cast(); + // If it's not 1D or 0D (which can be broadcasted to 1D), reject the + // matching. 
+ if (constant_val_type.getRank() > 1) { + return failure(); + } + + mlir::RankedTensorType end_type = + strided_slice_op.getEnd().getType().dyn_cast(); + // begin, end and strides are Rank 1 tensors with one element per dimension + // of input. + int64_t num_dims = end_type.getShape()[0]; + DenseElementsAttr new_added_value = + added_value.reshape(RankedTensorType::get( + {num_dims}, + added_value.getType().cast().getElementType())); + ::mlir::arith::ConstantOp new_end = rewriter.create( + strided_slice_op.getEnd().getLoc(), new_added_value); + + if (strided_slice_op.getBeginMask() != 0) return failure(); + if (strided_slice_op.getEndMask() != 0) return failure(); + if (strided_slice_op.getEllipsisMask() != 0) return failure(); + mlir::TFL::StridedSliceOp new_strided_slice_op = + rewriter.create( + strided_slice_op.getLoc(), strided_slice_op.getOutput().getType(), + strided_slice_op.getInput(), strided_slice_op.getBegin(), new_end, + strided_slice_op.getStrides(), strided_slice_op.getBeginMask(), + strided_slice_op.getEndMask(), strided_slice_op.getEllipsisMask(), + strided_slice_op.getNewAxisMask(), + strided_slice_op.getShrinkAxisMask(), + /*offset=*/true); + rewriter.replaceOp(strided_slice_op, new_strided_slice_op.getOutput()); + + return success(); + } +}; + // Fuse Add with proceeding FullyConnected. // TODO(b/136285429): Move to tablegen when variadic is supported struct FuseFullyConnectedAndAdd : public OpRewritePattern { @@ -655,6 +770,9 @@ struct FuseFullyConnectedAndAdd : public OpRewritePattern { // Check if the constant RHS is either 0D (scalar), or a 1D with // `{num_channels}` shape. auto constant_val_type = constant_val.getType().cast(); + if (constant_val_type.getRank() > 1) { + return failure(); + } // In TFLite FullyConnect definition, bias must be a 1D tensor where // the number of elements is equal to the number of channels. 
@@ -716,7 +834,7 @@ struct FuseFullyConnectedAndAdd : public OpRewritePattern { .getOutput(); } else { // If the RHS is neither a scalar constant nor a 1d constant, look - // if there is opportunity to reduce the dimentionality and allow + // if there is opportunity to reduce the dimensionality and allow // implicit broadcasting auto new_added_value = added_value.reshape(RankedTensorType::get( @@ -1803,7 +1921,7 @@ void OptimizePass::runOnOperation() { FuseBinaryOpToFollowingFullyConnected, FuseConv2DAndMulWithQDQs, FuseDepthwiseConv2DAndMulWithQDQs, ConvertTrivialTransposeOpToReshapeOp, RemoveReshapeAfterFullyConnected, RemoveReshapeBeforeFullyConnected, - FuseUnpackAndConcatToReshape, OptimizeTopK>(ctx); + FuseUnpackAndConcatToReshape, OptimizeTopK, FuseAddAndStridedSlice>(ctx); if (!this->disable_fuse_mul_and_fc_) { phase_2_patterns.add(ctx); } diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td b/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td index 216ac15c034..01357c332d5 100644 --- a/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td +++ b/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -679,16 +679,56 @@ foreach ValueOp = [TFL_CeilOp, TFL_ExpOp, TFL_FloorOp, TFL_NegOp, // if called without a ranked tensor it will fail. def GetShape: NativeCodeCall<"GetShape($0)">; +// Returns truncated shape of a ranked-tensor. +// Truncated, here, means eliminating any contiguous 1s in the lower +// dimensions of the tensor +def GetTruncatedShape: NativeCodeCall<"GetShape($0, true)">; + // Returns True if the operand type is RankedTensorType and valid. 
def HasValidRankedTensor : Constraint() && " "$0.getType().cast().getNumDynamicDims() <= 1">>; +// Check if the truncated shape of the lhs is equal to the shape of rhs +def IsTruncatedShapeEqualTo : Constraint>; + def ConvertSqueezeToReshape : Pat< (TFL_SqueezeOp:$squeeze_op $input, $squeeze_dims), (TFL_ReshapeOp $input, (Arith_ConstantOp (GetShape $squeeze_op))), [(HasValidRankedTensor $squeeze_op)]>; +// Pattern to perform the following optimization +// transpose [1xAxB] -> [1xBxA] +// | +// reshape [1xBxA] -> [BxA] +// | +// transpose [BxA] -> [AxB] +// || +// reshape [1xAxB] -> [AxB] +def ConvertTrasposeReshapeTransposeToReshape : Pat< + (TFL_TransposeOp:$second_transpose + (TFL_ReshapeOp:$middle_reshape + (TFL_TransposeOp:$first_transpose $input, $permutation2), + $shape), + $permutation1), + (TFL_ReshapeOp $input, (Arith_ConstantOp (GetTruncatedShape $input))), + [(IsTruncatedShapeEqualTo $first_transpose, $middle_reshape), + (IsTruncatedShapeEqualTo $input, $second_transpose)]>; + +// Function to map final permutation to initial permutation +// initial -> permutation1 -> permutation2 -> final +def RemapPermutation: NativeCodeCall<"RemapPermutation($0, $1)">; + +// Pattern to fuse redundant transpose op +def FoldDoubleTranspose : Pat< + (TFL_TransposeOp + (TFL_TransposeOp:$transpose_out1 $input, (Arith_ConstantOp:$permutation1 $p1)), + (Arith_ConstantOp:$permutation2 $p2)), + (TFL_TransposeOp $input, + (Arith_ConstantOp (RemapPermutation $permutation1, $permutation2))), + [(HasOneUse $transpose_out1)]>; + // Convert expand_dims to reshape if possible. def ConvertExpandDimsToReshape : Pat< (TFL_ExpandDimsOp:$expand_dims_op $input, $dim), @@ -714,6 +754,19 @@ def MinimumOfReluAnd6ToRelu6 : (TFL_Relu6Op $x), [(IsConstantValueOf<6> $y)]>; +// For both relu1 and relu_0_to_1, the min/max operators commute, +// so there are two possible orderings we need to rewrite. +// Concretely, `m < n -> max(m, min(n, x)) = min(m, max(m, x))`. 
+// Proof: +// case (x <= m) +// max(m, min(n, x)) = max(m, m) = m and +// min(n, max(m, x)) = min(n, m) = m +// case (m < x < n) +// max(m, min(n, x)) = max(m, x) = x and +// min(n, max(m, x)) = min(n, x) = x +// case (n <= x) +// max(m, min(n, x)) = max(m, n) = n and +// min(n, max(m, x)) = min(n, x) = n def MatchRelu1Pattern1 : Pat< (TFL_MinimumOp (TFL_MaximumOp $input, (Arith_ConstantOp $NegOne)), (Arith_ConstantOp $One)), @@ -726,6 +779,18 @@ def MatchRelu1Pattern2 : Pat< (TFL_Relu1Op $input), [(FloatValueEquals<"-1"> $NegOne), (FloatValueEquals<"1"> $One)]>; +def MatchRelu0To1Pattern1: Pat< + (TFL_MinimumOp (TFL_MaximumOp $x, (Arith_ConstantOp $max_cst)), + (Arith_ConstantOp $min_cst)), + (TFL_Relu0To1Op $x), + [(FloatValueEquals<"0"> $max_cst), (FloatValueEquals<"1"> $min_cst)]>; + +def MatchRelu0To1Pattern2: Pat< + (TFL_MaximumOp (TFL_MinimumOp $x, (Arith_ConstantOp $min_cst)), + (Arith_ConstantOp $max_cst)), + (TFL_Relu0To1Op $x), + [(FloatValueEquals<"0"> $max_cst), (FloatValueEquals<"1"> $min_cst)]>; + def MatchLeakyRelu : Pat< (TFL_MaximumOp (TFL_MulOp:$mul_out $x, @@ -1027,9 +1092,18 @@ def ReshapeValueDroppingLastDim : NativeCodeCall< def IsOneHotIndexAttribute : Constraint>; -// Checks if the shape has shape with last dimension equals 1. +// Checks if the shape has static shape with last dimension equals 1. def IsLastDimensionEqualOne : Constraint>; +// As above but if shape is not static and rank 2 with last dim 1. 
+def IsLastDimensionEqualOneOrDynamicBatchDimRank2 : Constraint< + CPred<"IsLastDimensionEqualOne($0) || " + "(!$0.getType().cast().hasStaticShape() && " + " $0.getType().cast().hasRank() && " + " $0.getType().cast().getRank() == 2 && " + " !$0.getType().cast().getShape().empty() && " + " $0.getType().cast().getShape()[1] == 1)">>; + // Replace // Equal(X, indices) // With @@ -1044,7 +1118,7 @@ def ReshapeEqualOpToOneHotOp : Pat< (Arith_ConstantOp ConstantAttr, "true">), (Arith_ConstantOp ConstantAttr, "false">), ConstantAttr), - [(IsLastDimensionEqualOne $x), + [(IsLastDimensionEqualOneOrDynamicBatchDimRank2 $x), (HasRankAtLeast<2> $x), (IsOneHotIndexAttribute $series)]>; @@ -1258,3 +1332,18 @@ def SimplifyDoubleSelectFCZerosRHS : Pat< (AllValuesAreZero $zeros_1), (AllValuesAreZero $zeros_2) ]>; + +def FuseSigmoid : Pat< + (TFL_DivOp + (Arith_ConstantOp F32ElementsAttr:$ones), + (TFL_AddOp:$add_out + (TFL_ExpOp:$exp_out + (TFL_NegOp:$neg_out $arg)), + (Arith_ConstantOp F32ElementsAttr:$ones_1), TFL_AF_None), TFL_AF_None), + (TFL_LogisticOp $arg), + [(FloatValueEquals<"1"> $ones_1), + (FloatValueEquals<"1"> $ones), + (HasOneUse $neg_out), + (HasOneUse $add_out), + (HasOneUse $exp_out), + ]>; diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_patterns.td b/tensorflow/compiler/mlir/lite/transforms/prepare_patterns.td index 9064d6c7f50..47fc9df2ba5 100644 --- a/tensorflow/compiler/mlir/lite/transforms/prepare_patterns.td +++ b/tensorflow/compiler/mlir/lite/transforms/prepare_patterns.td @@ -76,6 +76,7 @@ def ConvertPlaceholderWithDefault : Pat<(TF_PlaceholderWithDefaultOp $arg), (TF_ //===----------------------------------------------------------------------===// // Op removal patterns. 
//===----------------------------------------------------------------------===// +def RemovePreventGradient : Pat<(TF_PreventGradientOp $op, $msg), (replaceWithValue $op)>; def RemoveXlaSharding : Pat<(TF_XlaShardingOp $a, $b, $c), (replaceWithValue $a)>; def RemoveIdentityN : Pat<(TF_IdentityNOp $arg), (replaceWithValue $arg)>; @@ -196,3 +197,10 @@ def LowerUInt32AddV2 : Pat< (CreateTFCastOpI32 $rhs, /*truncate=*/ConstBoolAttrFalse)), /*truncate=*/ConstBoolAttrFalse), [(TensorOf<[TF_Uint32]> $lhs), (TensorOf<[TF_Uint32]> $rhs)]>; + +//===----------------------------------------------------------------------===// +// Fill op patterns. +//===----------------------------------------------------------------------===// + +def RemoveRedundantShapeOp : + Pat<(TF_ShapeOp (TF_FillOp $shape, $fill)), (replaceWithValue $shape)>; diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_quantize_dynamic_range.cc b/tensorflow/compiler/mlir/lite/transforms/prepare_quantize_dynamic_range.cc index a19c29a666f..951748b3127 100644 --- a/tensorflow/compiler/mlir/lite/transforms/prepare_quantize_dynamic_range.cc +++ b/tensorflow/compiler/mlir/lite/transforms/prepare_quantize_dynamic_range.cc @@ -84,6 +84,10 @@ class PrepareDynamicRangeQuantizePass void runOnOperation() override; private: + // Keeps track of ops whose inputs cannot be quantized due to not meeting the + // minimum_elements_for_weights threshold. Prevents emitting duplicate + // warnings for the same op, once deemed ineligible for quantization. 
+ llvm::SetVector visited_nonquantizable_ops_; quant::QuantizationSpecs quant_specs_; }; @@ -95,8 +99,10 @@ class PrepareDynamicRangeQuantizableOp : public OpRewritePattern { public: explicit PrepareDynamicRangeQuantizableOp( - MLIRContext* context, const quant::QuantizationSpecs& quant_specs) + MLIRContext* context, const quant::QuantizationSpecs& quant_specs, + llvm::SetVector* const visited_nonquantizable_ops) : OpRewritePattern(context), + visited_nonquantizable_ops_(visited_nonquantizable_ops), quant_specs_(quant_specs) {} LogicalResult matchAndRewrite(arith::ConstantOp op, @@ -129,6 +135,8 @@ class PrepareDynamicRangeQuantizableOp } private: + llvm::SetVector* const visited_nonquantizable_ops_; + // Check if the operand_index is included in the quantizable_indices. bool isQuantizableIndex(const int operand_index, const std::vector& quantizable_indices) const { @@ -142,6 +150,10 @@ class PrepareDynamicRangeQuantizableOp // specification for checking the support. For custom ops, it checks the // provided map. bool hasInt8QuantizableOperandAt(Operation* op, int operand_index) const { + if (visited_nonquantizable_ops_->contains(op)) { + return false; + } + if (auto custom_op = llvm::dyn_cast_or_null(op)) { std::string op_name = custom_op.getCustomCode().str(); auto custom_map_iter = quant_specs_.custom_map.find(op_name); @@ -152,7 +164,53 @@ class PrepareDynamicRangeQuantizableOp llvm::dyn_cast(op)) { const auto& quantizable_indices = quantizable_op.GetQuantizableOperandIndices(); - return isQuantizableIndex(operand_index, quantizable_indices); + + if (!isQuantizableIndex(operand_index, quantizable_indices)) { + return false; + } + + // Special case handling for UnidirectionalSequenceLSTMOp, which doesn't + // support partial quantization of its inputs. + // Below, we check all of the input constants for the + // UnidirectionalSequenceLSTMOp to see if any of them would not be + // quantized due to not meeting the minimum_elements_for_weights + // threshold. 
Should we find any, we don't quantize any of the ops. + if (!llvm::dyn_cast(op)) { + return true; + } + + for (int qi : quantizable_indices) { + auto const_op = llvm::dyn_cast_or_null( + op->getOperand(qi).getDefiningOp()); + if (!const_op) { + continue; + } + + DenseFPElementsAttr attr; + if (!matchPattern(const_op->getResult(0), m_Constant(&attr))) { + continue; + } + + if (attr.dyn_cast().size() >= + quant_specs_.minimum_elements_for_weights) { + continue; + } + + visited_nonquantizable_ops_->insert(op); + op->emitWarning( + "Skipped quantization for UnidirectionalSequenceLSTMOp. Partial " + "quantization of inputs for UnidirectionalSequenceLSTMOp is not " + "supported. The operand ") + << const_op->getName().getStringRef().str() << " at index " << qi + << " was not quantized because it has " + << attr.dyn_cast().size() + << " elements which is fewer than the " + "`minimum_elements_for_weights` threshold of " + << quant_specs_.minimum_elements_for_weights; + return false; + } + + return true; } return false; } @@ -427,7 +485,8 @@ void PrepareDynamicRangeQuantizePass::runOnOperation() { removeAllStatsOp(func); RewritePatternSet patterns(&getContext()); - patterns.add(ctx, quant_specs_); + patterns.add(ctx, quant_specs_, + &visited_nonquantizable_ops_); (void)applyPatternsAndFoldGreedily(func, std::move(patterns)); ConvertMlirQuantOpsToTFLQuantOps(func); diff --git a/tensorflow/compiler/mlir/lite/utils/constant_utils.cc b/tensorflow/compiler/mlir/lite/utils/constant_utils.cc index 9f2301d4803..86d0509ceb7 100644 --- a/tensorflow/compiler/mlir/lite/utils/constant_utils.cc +++ b/tensorflow/compiler/mlir/lite/utils/constant_utils.cc @@ -16,6 +16,7 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/lite/utils/constant_utils.h" #include +#include #include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h" diff --git a/tensorflow/compiler/mlir/lite/utils/fake_quant_utils.cc b/tensorflow/compiler/mlir/lite/utils/fake_quant_utils.cc index f6b257128dd..28c6106dcb7 100644 --- a/tensorflow/compiler/mlir/lite/utils/fake_quant_utils.cc +++ b/tensorflow/compiler/mlir/lite/utils/fake_quant_utils.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/utils/fake_quant_utils.h" #include +#include #include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project #include "mlir/IR/OperationSupport.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/lite/utils/fake_quant_utils.h b/tensorflow/compiler/mlir/lite/utils/fake_quant_utils.h index 093e53c0869..77b047f68c6 100644 --- a/tensorflow/compiler/mlir/lite/utils/fake_quant_utils.h +++ b/tensorflow/compiler/mlir/lite/utils/fake_quant_utils.h @@ -19,6 +19,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_MLIR_LITE_UTILS_FAKE_QUANT_UTILS_H_ #include +#include #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/lite/utils/lstm_utils.cc b/tensorflow/compiler/mlir/lite/utils/lstm_utils.cc index f1f006e93a2..18320bba3c9 100644 --- a/tensorflow/compiler/mlir/lite/utils/lstm_utils.cc +++ b/tensorflow/compiler/mlir/lite/utils/lstm_utils.cc @@ -17,6 +17,7 @@ limitations under the License. 
#include #include +#include #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/SmallVector.h" diff --git a/tensorflow/compiler/mlir/lite/utils/lstm_utils.h b/tensorflow/compiler/mlir/lite/utils/lstm_utils.h index 7421fe2faa8..9e01b5dbf75 100644 --- a/tensorflow/compiler/mlir/lite/utils/lstm_utils.h +++ b/tensorflow/compiler/mlir/lite/utils/lstm_utils.h @@ -58,7 +58,7 @@ class ConvertLSTMCellSimpleToFusedLSTM { delete; ConvertLSTMCellSimpleToFusedLSTM& operator=( const ConvertLSTMCellSimpleToFusedLSTM&) = delete; - virtual ~ConvertLSTMCellSimpleToFusedLSTM() {} + virtual ~ConvertLSTMCellSimpleToFusedLSTM() = default; virtual llvm::StringRef GetCompositeOpName() { return kLstmCellSimple; } @@ -184,7 +184,7 @@ class ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM const ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM&) = delete; ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM& operator=( const ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM&) = delete; - ~ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM() override {} + ~ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM() override = default; llvm::StringRef GetCompositeOpName() override { return kLayerNormalizedLstmCellSimple; diff --git a/tensorflow/compiler/mlir/lite/utils/stateful_ops_utils.h b/tensorflow/compiler/mlir/lite/utils/stateful_ops_utils.h index e5d7fd1a639..b78f7c86e45 100644 --- a/tensorflow/compiler/mlir/lite/utils/stateful_ops_utils.h +++ b/tensorflow/compiler/mlir/lite/utils/stateful_ops_utils.h @@ -16,6 +16,8 @@ limitations under the License. 
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_UTILS_STATEFUL_OPS_UTILS_H_ #define TENSORFLOW_COMPILER_MLIR_LITE_UTILS_STATEFUL_OPS_UTILS_H_ +#include + #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project namespace mlir { diff --git a/tensorflow/compiler/mlir/lite/utils/tftext_utils.cc b/tensorflow/compiler/mlir/lite/utils/tftext_utils.cc index a542267a14a..7ce9c56086e 100644 --- a/tensorflow/compiler/mlir/lite/utils/tftext_utils.cc +++ b/tensorflow/compiler/mlir/lite/utils/tftext_utils.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include +#include #include "flatbuffers/flexbuffers.h" // from @flatbuffers #include "llvm/ADT/ArrayRef.h" diff --git a/tensorflow/compiler/mlir/lite/utils/validators.h b/tensorflow/compiler/mlir/lite/utils/validators.h index e16feb92652..6ff354fb23b 100644 --- a/tensorflow/compiler/mlir/lite/utils/validators.h +++ b/tensorflow/compiler/mlir/lite/utils/validators.h @@ -51,7 +51,7 @@ bool TFIntListIs1XY1(Operation *op, StringRef name, IntegerAttr *x, IntegerAttr *y); // Returns true if the attribute is an integer list of the form [1, X, Y, 1], -bool TFIntListIs1XY1(const Attribute attr); +bool TFIntListIs1XY1(Attribute attr); // Returns true if the given `op` // * has an attribute with the given `name`, @@ -62,7 +62,7 @@ bool TFIntListIs1XYZ1(Operation *op, StringRef name, IntegerAttr *x, // Returns true if every element of the attribute is 1. All elements of `attr` // must be `IntegerAttr`. -bool TFIntListIsAllOnes(const Attribute attr); +bool TFIntListIsAllOnes(Attribute attr); // Returns true iff the given value is a float32 tensor. // is "DT_FLOAT". diff --git a/tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc b/tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc index dc69f3d64bb..c1753dd34fb 100644 --- a/tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc +++ b/tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc @@ -17,6 +17,8 @@ limitations under the License. 
#include #include +#include +#include #include "tensorflow/compiler/mlir/tf2xla/mlir_bridge_rollout_policy.h" #include "absl/container/flat_hash_set.h" @@ -137,7 +139,8 @@ static void RegisterDialects(mlir::DialectRegistry& registry) { Status MlirFunctionOptimizationPass::Run( const std::string& function_name, const DeviceSet& device_set, - const ConfigProto& config_proto, absl::string_view xla_compile_device_type, + const ConfigProto& config_proto, + const FunctionOptimizationPass::FunctionOptions& function_options, std::unique_ptr* graph, FunctionLibraryDefinition* flib_def, std::vector* control_ret_node_names, bool* control_rets_updated) { @@ -208,7 +211,9 @@ Status MlirFunctionOptimizationPass::Run( // the shape inference pass is run early in the pass pipeline, shape inference // during import is not necessary. import_config.enable_shape_inference = false; - import_config.xla_compile_device_type = xla_compile_device_type; + import_config.xla_compile_device_type = + function_options.xla_compile_device_type; + import_config.enable_soft_placement = function_options.allow_soft_placement; static const char* kTfMlirCategory = "TfMlir"; tensorflow::metrics::ScopedCounter<2> timings( diff --git a/tensorflow/compiler/mlir/mlir_graph_optimization_pass.h b/tensorflow/compiler/mlir/mlir_graph_optimization_pass.h index d3a8420af94..059147d4ea9 100644 --- a/tensorflow/compiler/mlir/mlir_graph_optimization_pass.h +++ b/tensorflow/compiler/mlir/mlir_graph_optimization_pass.h @@ -17,7 +17,11 @@ limitations under the License. #define TENSORFLOW_COMPILER_MLIR_MLIR_GRAPH_OPTIMIZATION_PASS_H_ #include +#include +#include #include +#include +#include #include "tensorflow/compiler/mlir/tf2xla/mlir_bridge_rollout_policy.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project @@ -120,7 +124,7 @@ class MlirFunctionOptimizationPass : public FunctionOptimizationPass { // Executes all of the underlying registered MlirOptimizationPasses. 
Status Run(const std::string& function_name, const DeviceSet& device_set, const ConfigProto& config_proto, - absl::string_view xla_compile_device_type, + const FunctionOptimizationPass::FunctionOptions& function_options, std::unique_ptr* graph, FunctionLibraryDefinition* flib_def, std::vector* control_ret_node_names, bool* control_rets_updated) override; diff --git a/tensorflow/compiler/mlir/mlir_graph_optimization_pass_test.cc b/tensorflow/compiler/mlir/mlir_graph_optimization_pass_test.cc index 4e7d1449946..95d669ff9bf 100644 --- a/tensorflow/compiler/mlir/mlir_graph_optimization_pass_test.cc +++ b/tensorflow/compiler/mlir/mlir_graph_optimization_pass_test.cc @@ -40,45 +40,43 @@ constexpr char kFailure[] = "kFailure"; class MockMlirOptimizationPass : public MlirOptimizationPass { public: - // MOCK_METHOD does not work on Windows build, using MOCK_CONST_METHODX - // instead. - MOCK_CONST_METHOD0(name, llvm::StringRef()); - MOCK_CONST_METHOD4(GetPassState, - MlirOptimizationPassState( - const DeviceSet* device_set, - const ConfigProto& config_proto, const Graph& graph, - const FunctionLibraryDefinition& function_library)); - MOCK_METHOD5(Run, Status(const std::string& function_name, - const ConfigProto& config_proto, - mlir::ModuleOp module, const Graph& graph, - const FunctionLibraryDefinition& function_library)); + MOCK_METHOD(llvm::StringRef, name, (), (const, override)); + MOCK_METHOD(MlirOptimizationPassState, GetPassState, + (const DeviceSet* device_set, const ConfigProto& config_proto, + const Graph& graph, + const FunctionLibraryDefinition& function_library), + (const, override)); + MOCK_METHOD(Status, Run, + (const std::string& function_name, + const ConfigProto& config_proto, mlir::ModuleOp module, + const Graph& graph, + const FunctionLibraryDefinition& function_library), + (override)); }; class MockMlirV1CompatOptimizationPass : public MlirV1CompatOptimizationPass { public: - // MOCK_METHOD does not work on Windows build, using MOCK_CONST_METHODX - // 
instead. - MOCK_CONST_METHOD0(name, llvm::StringRef()); - MOCK_CONST_METHOD4(GetPassState, - MlirOptimizationPassState( - const DeviceSet* device_set, - const ConfigProto& config_proto, const Graph& graph, - const FunctionLibraryDefinition& function_library)); - MOCK_METHOD2(Run, Status(const GraphOptimizationPassOptions& options, - mlir::ModuleOp module)); + MOCK_METHOD(llvm::StringRef, name, (), (const, override)); + MOCK_METHOD(MlirOptimizationPassState, GetPassState, + (const DeviceSet* device_set, const ConfigProto& config_proto, + const Graph& graph, + const FunctionLibraryDefinition& function_library), + (const, override)); + MOCK_METHOD(Status, Run, + (const GraphOptimizationPassOptions& options, + mlir::ModuleOp module), + (override)); }; class ModifyMlirModulePass : public MlirOptimizationPass { public: explicit ModifyMlirModulePass(Status run_status) : run_status_(run_status) {} - // MOCK_METHOD does not work on Windows build, using MOCK_CONST_METHODX - // instead. - MOCK_CONST_METHOD0(name, llvm::StringRef()); - MOCK_CONST_METHOD4(GetPassState, - MlirOptimizationPassState( - const DeviceSet* device_set, - const ConfigProto& config_proto, const Graph& graph, - const FunctionLibraryDefinition& function_library)); + MOCK_METHOD(llvm::StringRef, name, (), (const, override)); + MOCK_METHOD(MlirOptimizationPassState, GetPassState, + (const DeviceSet* device_set, const ConfigProto& config_proto, + const Graph& graph, + const FunctionLibraryDefinition& function_library), + (const, override)); // Just modify MLIR module so that we can check whether original TF graph // has changed or not. 
@@ -187,12 +185,12 @@ class MlirGraphOptimizationPassTest : public Test { } ConfigProto config_proto_; + FunctionOptimizationPass::FunctionOptions function_options_; MlirFunctionOptimizationPass function_optimization_pass_; DeviceSet device_set_; std::unique_ptr graph_; std::unique_ptr flib_; std::vector control_ret_node_names_; - std::string xla_compile_device_type_; bool control_rets_updated_{false}; monitoring::testing::CellReader mlir_function_pass_fallback_count_ = monitoring::testing::CellReader( @@ -219,11 +217,11 @@ TEST_F(MlirGraphOptimizationPassTest, OptimizationPassFailsNoFallback) { GraphDef original_graph_def; graph_->ToGraphDef(&original_graph_def); - EXPECT_EQ(function_optimization_pass_.Run( - "test_func", device_set_, config_proto_, - xla_compile_device_type_, &graph_, flib_.get(), - &control_ret_node_names_, &control_rets_updated_), - Status(absl::StatusCode::kAborted, "aborted")); + EXPECT_EQ( + function_optimization_pass_.Run( + "test_func", device_set_, config_proto_, function_options_, &graph_, + flib_.get(), &control_ret_node_names_, &control_rets_updated_), + Status(absl::StatusCode::kAborted, "aborted")); verifyGraph(original_graph_def); verifyCounters(); } @@ -246,11 +244,11 @@ TEST_F(MlirGraphOptimizationPassTest, OptimizationPassFailsDisabledFallback) { AddModuleModificationPass(MlirOptimizationPassState::FallbackEnabled, Status(absl::StatusCode::kAborted, "aborted")); - EXPECT_EQ(function_optimization_pass_.Run( - "test_func", device_set_, config_proto_, - xla_compile_device_type_, &graph_, flib_.get(), - &control_ret_node_names_, &control_rets_updated_), - OkStatus()); + EXPECT_EQ( + function_optimization_pass_.Run( + "test_func", device_set_, config_proto_, function_options_, &graph_, + flib_.get(), &control_ret_node_names_, &control_rets_updated_), + OkStatus()); verifyGraph(original_graph_def); verifyCounters(); } @@ -263,11 +261,11 @@ TEST_F(MlirGraphOptimizationPassTest, OptimizationPassDoesNotFailFallback) { 
AddModuleModificationPass(MlirOptimizationPassState::FallbackEnabled, OkStatus()); - EXPECT_EQ(function_optimization_pass_.Run( - "test_func", device_set_, config_proto_, - xla_compile_device_type_, &graph_, flib_.get(), - &control_ret_node_names_, &control_rets_updated_), - OkStatus()); + EXPECT_EQ( + function_optimization_pass_.Run( + "test_func", device_set_, config_proto_, function_options_, &graph_, + flib_.get(), &control_ret_node_names_, &control_rets_updated_), + OkStatus()); verifyGraph(original_graph_def, true); verifyCounters(); @@ -281,11 +279,11 @@ TEST_F(MlirGraphOptimizationPassTest, GraphDoesntConvertUpdatesCounter) { AddModuleModificationPass(MlirOptimizationPassState::FallbackEnabled, OkStatus()); - EXPECT_EQ(function_optimization_pass_.Run( - "test_func", device_set_, config_proto_, - xla_compile_device_type_, &graph_, flib_.get(), - &control_ret_node_names_, &control_rets_updated_), - OkStatus()); + EXPECT_EQ( + function_optimization_pass_.Run( + "test_func", device_set_, config_proto_, function_options_, &graph_, + flib_.get(), &control_ret_node_names_, &control_rets_updated_), + OkStatus()); EXPECT_EQ(mlir_function_pass_graph_conversion_count_.Read(kOk), 0); EXPECT_EQ(mlir_function_pass_graph_conversion_count_.Read(kInvalidArgument), diff --git a/tensorflow/compiler/mlir/op_or_arg_name_mapper.cc b/tensorflow/compiler/mlir/op_or_arg_name_mapper.cc index cbd03639c02..f5912553f10 100644 --- a/tensorflow/compiler/mlir/op_or_arg_name_mapper.cc +++ b/tensorflow/compiler/mlir/op_or_arg_name_mapper.cc @@ -40,7 +40,7 @@ static inline llvm::StringRef StringViewToRef(absl::string_view view) { namespace tensorflow { -OpOrArgNameMapper::~OpOrArgNameMapper() {} +OpOrArgNameMapper::~OpOrArgNameMapper() = default; llvm::StringRef OpOrArgNameMapper::GetUniqueName(llvm::StringRef prefix, int hash_value) { diff --git a/tensorflow/compiler/mlir/python/mlir.cc b/tensorflow/compiler/mlir/python/mlir.cc index 7cc1d25355e..b709ede8956 100644 --- 
a/tensorflow/compiler/mlir/python/mlir.cc +++ b/tensorflow/compiler/mlir/python/mlir.cc @@ -15,9 +15,12 @@ limitations under the License. #include "tensorflow/compiler/mlir/python/mlir.h" +#include #include #include +#include #include +#include #include "absl/algorithm/container.h" #include "absl/container/flat_hash_set.h" @@ -388,6 +391,8 @@ void ExperimentalWriteBytecode(const std::string& filename, } mlir::FallbackAsmResourceMap fallback_resource_map; mlir::BytecodeWriterConfig writer_config(fallback_resource_map); + // TODO(jpienaar): Make this an option to the call. + writer_config.setDesiredBytecodeVersion(1); std::string error; std::unique_ptr outputFile = mlir::openOutputFile(filename, &error); @@ -446,6 +451,8 @@ void ExperimentalTFLiteToTosaBytecode( } mlir::FallbackAsmResourceMap fallback_resource_map; mlir::BytecodeWriterConfig writer_config(fallback_resource_map); + // TODO(jpienaar): Make this an option to the call. + writer_config.setDesiredBytecodeVersion(1); std::string error; std::unique_ptr outputFile = mlir::openOutputFile(tosa_bytecode_file, &error); diff --git a/tensorflow/compiler/mlir/python/mlir_wrapper/BUILD b/tensorflow/compiler/mlir/python/mlir_wrapper/BUILD index 807d0f497df..3e184602595 100644 --- a/tensorflow/compiler/mlir/python/mlir_wrapper/BUILD +++ b/tensorflow/compiler/mlir/python/mlir_wrapper/BUILD @@ -20,8 +20,8 @@ tf_python_pybind_extension( deps = [ "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", - "//tensorflow/python:pybind11_lib", - "//tensorflow/python:pybind11_status", + "//tensorflow/python/lib/core:pybind11_lib", + "//tensorflow/python/lib/core:pybind11_status", "@llvm-project//llvm:FileCheckLib", "@llvm-project//llvm:Support", "@llvm-project//mlir:FuncDialect", @@ -36,8 +36,8 @@ tf_python_pybind_extension( srcs = ["filecheck_wrapper.cc"], visibility = ["//visibility:public"], deps = [ - "//tensorflow/python:pybind11_lib", - "//tensorflow/python:pybind11_status", + 
"//tensorflow/python/lib/core:pybind11_lib", + "//tensorflow/python/lib/core:pybind11_status", "@llvm-project//llvm:FileCheckLib", "@llvm-project//llvm:Support", "@pybind11", diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/BUILD b/tensorflow/compiler/mlir/quantization/stablehlo/BUILD index f85f8f13882..50e4037fa1c 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/BUILD +++ b/tensorflow/compiler/mlir/quantization/stablehlo/BUILD @@ -15,6 +15,15 @@ package_group( ] + internal_visibility_allowlist(), ) +package( + # copybara:uncomment default_applicable_licenses = ["@stablehlo//:license"], + default_visibility = [ + ":internal_visibility_allowlist_package", + "//tensorflow:__pkg__", + ], + licenses = ["notice"], +) + # TODO(b/264218457): Add quantize and post_quantize passes. cc_library( name = "passes", @@ -26,19 +35,18 @@ cc_library( ], compatible_with = get_compatible_with_cloud(), deps = [ + ":fill_quantization_options", ":quantization_options_proto_cc", ":stablehlo_passes_inc_gen", - "//tensorflow/compiler/mlir/lite/quantization:quantization_config", "//tensorflow/compiler/mlir/lite/quantization/ir:QuantOps", "//tensorflow/core/platform:path", "//third_party/eigen3", - "@com_google_absl//absl/container:flat_hash_set", "@llvm-project//llvm:Support", - "@llvm-project//mlir:ArithDialect", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", "@llvm-project//mlir:QuantOps", + "@llvm-project//mlir:Rewrite", "@llvm-project//mlir:Support", "@llvm-project//mlir:TransformUtils", "@stablehlo//:stablehlo_ops", @@ -59,11 +67,14 @@ cc_library( compatible_with = get_compatible_with_cloud(), visibility = [":internal_visibility_allowlist_package"], deps = [ + ":fill_quantization_options", ":passes", ":quantization_options_proto_cc", "//tensorflow/compiler/mlir/quantization/tensorflow:quantization_options_proto_cc", "//tensorflow/compiler/mlir/tensorflow:tf_dialect_passes", "//tensorflow/core/platform:path", + 
"@com_google_absl//absl/container:flat_hash_set", + "@llvm-project//llvm:Support", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:Pass", ], @@ -87,6 +98,17 @@ gentbl_cc_library( ], ) +cc_library( + name = "fill_quantization_options", + srcs = ["utils/fill_quantization_options.cc"], + hdrs = ["utils/fill_quantization_options.h"], + compatible_with = get_compatible_with_cloud(), + deps = [ + ":quantization_options_proto_cc", + "@llvm-project//llvm:Support", + ], +) + tf_proto_library( name = "quantization_options_proto", srcs = ["quantization_options.proto"], diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.h b/tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.h index 788a00f349c..acd3657484e 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.h +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.h @@ -31,7 +31,8 @@ namespace stablehlo { // Creates a pass that quantizes weight component of StableHLO graph. std::unique_ptr> CreateQuantizeWeightPass( - ::stablehlo::quantization::QuantizationOptions quantization_options); + ::stablehlo::quantization::QuantizationComponentSpec + quantization_component_spec); } // namespace stablehlo } // namespace mlir diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize_weight.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize_weight.cc index 9d5d0cc8e91..d4480dbf170 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize_weight.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize_weight.cc @@ -14,29 +14,31 @@ limitations under the License. 
==============================================================================*/ #include -#include #include -#include #include #include #include "third_party/eigen3/Eigen/Core" -#include "llvm/Support/CommandLine.h" +#include "llvm/ADT/SetVector.h" #include "llvm/Support/Debug.h" -#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project -#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project -#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/Dialect.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Rewrite/FrozenRewritePatternSet.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Support/TypeID.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project #include "stablehlo/dialect/StablehloOps.h" // from @stablehlo #include "tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_options.pb.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/utils/fill_quantization_options.h" // NOLINTNEXTLINE //===----------------------------------------------------------------------===// @@ -50,6 +52,7 @@ namespace { #include "tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.h.inc" using QuantizationUnits = llvm::SetVector>; +using ::stablehlo::quantization::QuantizationComponentSpec; // 
Min/Max values used for creating ConstantOp. constexpr float kMaxFloat16Value = 65504.f; @@ -61,7 +64,8 @@ class QuantizeWeightPass MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(QuantizeWeightPass) explicit QuantizeWeightPass( - ::stablehlo::quantization::QuantizationOptions quantization_options) {} + QuantizationComponentSpec quantization_component_spec) + : quantization_component_spec_(quantization_component_spec) {} StringRef getArgument() const final { // This is the argument used to refer to the pass in @@ -75,13 +79,17 @@ class QuantizeWeightPass private: void runOnOperation() override; + QuantizationComponentSpec quantization_component_spec_; }; // Collects quantizable target ops, then insert Q-DQ quantization patterns. class QuantizeWeight : public OpRewritePattern { public: - explicit QuantizeWeight(MLIRContext* context) - : OpRewritePattern(context) {} + explicit QuantizeWeight( + MLIRContext* context, + const QuantizationComponentSpec& quantization_component_spec) + : OpRewritePattern(context), + quantization_component_spec_(quantization_component_spec) {} LogicalResult matchAndRewrite(ConstantOp op, PatternRewriter& rewriter) const override { @@ -104,6 +112,7 @@ class QuantizeWeight : public OpRewritePattern { } private: + const QuantizationComponentSpec quantization_component_spec_; // Marks users that are applicable for quantization where the criteria for // determining quantizable ops differs by the inference type. QuantizationUnits GetQuantizableOps(ConstantOp op) const { @@ -125,7 +134,6 @@ class QuantizeWeight : public OpRewritePattern { // Returns whether quantization is applied to filtered users. bool QuantizeOps(PatternRewriter& rewriter, ConstantOp op, const QuantizationUnits& quantizable_ops) const { - // TODO(b/212514817): refactor mode checking to improve code quality. for (const std::pair& quant_op : quantizable_ops) { // For f16 quantization, quantize all constant ops as float16. 
QuantizeOpAsFloat16(rewriter, op, quant_op); @@ -222,9 +230,9 @@ class QuantizeWeight : public OpRewritePattern { void QuantizeWeightPass::runOnOperation() { func::FuncOp func = getOperation(); MLIRContext* ctx = func.getContext(); - RewritePatternSet patterns(ctx); - patterns.add(ctx); + + patterns.add(ctx, quantization_component_spec_); FrozenRewritePatternSet frozen_patterns(std::move(patterns)); @@ -237,8 +245,8 @@ void QuantizeWeightPass::runOnOperation() { // Creates an instance of the StableHLO dialect Quantize Weight pass. std::unique_ptr> CreateQuantizeWeightPass( - ::stablehlo::quantization::QuantizationOptions quantization_options) { - return std::make_unique(quantization_options); + QuantizationComponentSpec quantization_component_spec) { + return std::make_unique(quantization_component_spec); } } // namespace stablehlo } // namespace mlir diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/quantize_passes.cc b/tensorflow/compiler/mlir/quantization/stablehlo/quantize_passes.cc index 05290bcb126..31bb012e372 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/quantize_passes.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/quantize_passes.cc @@ -14,9 +14,13 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/compiler/mlir/quantization/stablehlo/quantize_passes.h" +#include "absl/container/flat_hash_set.h" +#include "llvm/Support/Debug.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project #include "tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_options.pb.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/utils/fill_quantization_options.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" namespace stablehlo { @@ -24,8 +28,30 @@ namespace quantization { void AddQuantizationPasses(mlir::PassManager& pass_manager, const QuantizationOptions& quantization_options) { + QuantizationOptions quantization_options_ = quantization_options; + if (quantization_options.quantization_method() + .has_preset_quantization_method()) { + quantization_options_ = + mlir::stablehlo::FillPresetQuantizationOptions(quantization_options); + } + + // TODO(b/276999414): Add activation and bias quantization component as + // respective quantization passes are created. 
+ QuantizationComponentSpec weight_component; + for (const auto& component : quantization_options_.quantization_method() + .custom_quantization_method() + .quantization_component_spec()) { + switch (component.quantization_component()) { + case QuantizationComponentSpec::COMPONENT_WEIGHT: + weight_component = component; + break; + default: + break; + } + } + pass_manager.addNestedPass( - mlir::stablehlo::CreateQuantizeWeightPass(quantization_options)); + mlir::stablehlo::CreateQuantizeWeightPass(weight_component)); } } // namespace quantization diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/BUILD b/tensorflow/compiler/mlir/quantization/stablehlo/tests/BUILD index 00c76a029e9..4b657b51762 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/tests/BUILD +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/BUILD @@ -1,5 +1,6 @@ load("//tensorflow:tensorflow.default.bzl", "filegroup") load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") +load("//tensorflow:tensorflow.bzl", "tf_cc_test") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -7,6 +8,7 @@ package( ) glob_lit_tests( + name = "all_tests", data = [":test_utilities"], driver = "//tensorflow/compiler/mlir/quantization/stablehlo:run_lit.sh", size_override = { @@ -27,3 +29,15 @@ filegroup( # TODO(b/254144841): Add tests in this directory with the proper stablehlo-opt. 
], ) + +tf_cc_test( + name = "fill_quantization_options_test", + srcs = ["fill_quantization_options_test.cc"], + deps = [ + "//tensorflow/compiler/mlir/quantization/stablehlo:fill_quantization_options", + "//tensorflow/compiler/mlir/quantization/stablehlo:quantization_options_proto_cc", + "//tensorflow/tsl/platform:protobuf", + "@com_google_googletest//:gtest_main", + "@llvm-project//llvm:Support", + ], +) diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/fill_quantization_options_test.cc b/tensorflow/compiler/mlir/quantization/stablehlo/tests/fill_quantization_options_test.cc new file mode 100644 index 00000000000..55ef992934b --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/fill_quantization_options_test.cc @@ -0,0 +1,110 @@ +/* Copyright 2023 The StableHLO Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/mlir/quantization/stablehlo/utils/fill_quantization_options.h" + +#include +#include + +#include +#include +#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_options.pb.h" +#include "tensorflow/tsl/platform/protobuf.h" + +namespace mlir::stablehlo { +namespace { + +using ::stablehlo::quantization::PresetQuantizationMethod; +using ::stablehlo::quantization::QuantizationComponentSpec; +using ::stablehlo::quantization::QuantizationOptions; + +// Simple implementation of ::testing::EqualsProto equivalent until open source +// b/135192747 is fixed. Originally from type_to_shape_test.cc. +class ProtoStringMatcher { + public: + explicit ProtoStringMatcher(const tsl::protobuf::Message& expected) + : expected_(expected.SerializeAsString()) {} + + template + bool MatchAndExplain(const Message& p, testing::MatchResultListener*) const { + return p.SerializeAsString() == expected_; + } + + void DescribeTo(::std::ostream* os) const { *os << expected_; } + void DescribeNegationTo(::std::ostream* os) const { + *os << "not equal to expected message: " << expected_; + } + + private: + const std::string expected_; +}; + +inline ::testing::PolymorphicMatcher EqualsProto( + const tsl::protobuf::Message& x) { + return ::testing::MakePolymorphicMatcher(ProtoStringMatcher(x)); +} + +void FillPresetQuantizationOptionsTestHelper( + const PresetQuantizationMethod::PresetMethod preset_quantization_options, + const QuantizationComponentSpec expected_activation_component, + const QuantizationComponentSpec expected_weight_component, + const QuantizationComponentSpec expected_bias_component) { + QuantizationOptions quantization_options; + quantization_options.mutable_quantization_method() + ->mutable_preset_quantization_method() + ->set_preset_method(preset_quantization_options); + QuantizationOptions filled_quantization_options = + 
FillPresetQuantizationOptions(quantization_options); + for (QuantizationComponentSpec component : + filled_quantization_options.quantization_method() + .custom_quantization_method() + .quantization_component_spec()) { + switch (component.quantization_component()) { + case (QuantizationComponentSpec::COMPONENT_ACTIVATION): + EXPECT_THAT(component, EqualsProto(expected_activation_component)); + break; + case (QuantizationComponentSpec::COMPONENT_WEIGHT): + EXPECT_THAT(component, EqualsProto(expected_weight_component)); + break; + case (QuantizationComponentSpec::COMPONENT_BIAS): + EXPECT_THAT(component, EqualsProto(expected_bias_component)); + break; + default: + break; + } + } +} + +TEST(FillQuantizationOptionsTest, PresetFloat16) { + QuantizationComponentSpec activation_component, weight_component, + bias_component; + weight_component.set_quantization_component( + QuantizationComponentSpec::COMPONENT_WEIGHT); + weight_component.set_bit_width(QuantizationComponentSpec::BIT_WIDTH_16); + weight_component.set_bit_type(QuantizationComponentSpec::BIT_TYPE_FLOAT); + bias_component.set_quantization_component( + QuantizationComponentSpec::COMPONENT_BIAS); + bias_component.set_bit_width(QuantizationComponentSpec::BIT_WIDTH_16); + bias_component.set_bit_type(QuantizationComponentSpec::BIT_TYPE_FLOAT); + + FillPresetQuantizationOptionsTestHelper( + /*preset_quantization_options=*/PresetQuantizationMethod::FLOAT16, + /*expected_activation_component=*/activation_component, + /*expected_weight_component*/ weight_component, + /*expected_bias_component*/ bias_component); +} + +} // namespace +} // namespace mlir::stablehlo diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/utils/fill_quantization_options.cc b/tensorflow/compiler/mlir/quantization/stablehlo/utils/fill_quantization_options.cc new file mode 100644 index 00000000000..bff29736476 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/utils/fill_quantization_options.cc @@ -0,0 +1,71 @@ +/* 
Copyright 2023 The StableHLO Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "llvm/Support/Debug.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_options.pb.h" + +namespace mlir { +namespace stablehlo { + +using ::stablehlo::quantization::CustomQuantizationMethod; +using ::stablehlo::quantization::PresetQuantizationMethod; +using ::stablehlo::quantization::QuantizationComponentSpec; + +// Returns QuantizationOptions filled with detailed specs when user specifies +// an optional preset method name. The preset methods are defined in +// quantization_options.proto. This function will only be executed if a user +// gives a preset method, not a custom method. 
+::stablehlo::quantization::QuantizationOptions FillPresetQuantizationOptions( +    ::stablehlo::quantization::QuantizationOptions quantization_options_) { +  CustomQuantizationMethod custom_method = +      quantization_options_.quantization_method().custom_quantization_method(); +  QuantizationComponentSpec *weight_component, *bias_component; +  auto preset_method = quantization_options_.quantization_method() +                           .preset_quantization_method() +                           .preset_method(); +  if (!preset_method) return quantization_options_; +  switch (preset_method) { +    case PresetQuantizationMethod::FLOAT16: +      weight_component = custom_method.add_quantization_component_spec(); +      weight_component->set_quantization_component( +          QuantizationComponentSpec::COMPONENT_WEIGHT); +      weight_component->set_bit_width(QuantizationComponentSpec::BIT_WIDTH_16); +      weight_component->set_bit_type(QuantizationComponentSpec::BIT_TYPE_FLOAT); +      bias_component = custom_method.add_quantization_component_spec(); +      bias_component->set_quantization_component( +          QuantizationComponentSpec::COMPONENT_BIAS); +      bias_component->set_bit_width(QuantizationComponentSpec::BIT_WIDTH_16); +      bias_component->set_bit_type(QuantizationComponentSpec::BIT_TYPE_FLOAT); +      break; +    // Note: This is weight-only quantization by default, but with the legacy +    // flag "--force_dynamic_range_in_kernel", a DRQ behavior will be forced +    // in the kernel.
+ case PresetQuantizationMethod::WEIGHT_ONLY: + weight_component = custom_method.add_quantization_component_spec(); + weight_component->set_quantization_component( + QuantizationComponentSpec::COMPONENT_WEIGHT); + weight_component->set_bit_width(QuantizationComponentSpec::BIT_WIDTH_8); + weight_component->set_bit_type(QuantizationComponentSpec::BIT_TYPE_INT); + break; + default: + break; + } + *quantization_options_.mutable_quantization_method() + ->mutable_custom_quantization_method() = custom_method; + return quantization_options_; +} + +} // namespace stablehlo +} // namespace mlir diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/utils/fill_quantization_options.h b/tensorflow/compiler/mlir/quantization/stablehlo/utils/fill_quantization_options.h new file mode 100644 index 00000000000..782920826c6 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/utils/fill_quantization_options.h @@ -0,0 +1,30 @@ +/* Copyright 2023 The StableHLO Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_UTILS_FILL_QUANTIZATION_OPTIONS_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_UTILS_FILL_QUANTIZATION_OPTIONS_H_ + +#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_options.pb.h" + +namespace mlir { +namespace stablehlo { + +::stablehlo::quantization::QuantizationOptions FillPresetQuantizationOptions( + ::stablehlo::quantization::QuantizationOptions quantization_options); + +} // namespace stablehlo +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_UTILS_FILL_QUANTIZATION_OPTIONS_H_ diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/BUILD b/tensorflow/compiler/mlir/quantization/tensorflow/BUILD index 2d42d137f9b..6e6a6c8077f 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/BUILD +++ b/tensorflow/compiler/mlir/quantization/tensorflow/BUILD @@ -22,17 +22,6 @@ package( licenses = ["notice"], ) -cc_library( - name = "constants", - hdrs = [ - "constants.h", - ], - compatible_with = get_compatible_with_cloud(), - deps = [ - "@com_google_absl//absl/strings", - ], -) - py_binary( name = "gen_quantized_function_library", srcs = ["gen_quantized_function_library.py"], @@ -78,6 +67,23 @@ cc_library( ], ) +cc_library( + name = "manipulate_model_attr", + srcs = [ + "passes/manipulate_model_attr.cc", + ], + hdrs = [ + "passes/manipulate_model_attr.h", + ], + compatible_with = get_compatible_with_cloud(), + deps = [ + "//tensorflow/compiler/mlir/tensorflow", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + ], +) + cc_library( name = "remove_identity_op_pattern", srcs = [ @@ -375,11 +381,13 @@ cc_library( "passes/insert_restore_op.cc", "passes/insert_save_op.cc", "passes/issue_ids_of_custom_aggregation_ops.cc", + "passes/lift_hashtable_ops_as_args.cc", "passes/lift_quantizable_spots_as_functions.cc", 
"passes/lift_quantizable_spots_as_functions.inc", "passes/lift_quantizable_spots_as_functions_drq.cc", "passes/lift_quantizable_spots_as_functions_drq.inc", "passes/mark_functions_noinline.cc", + "passes/merge_duplicate_resource_ops.cc", "passes/merge_initializer_function_ops_to_main.cc", "passes/merge_save_function_ops_to_main.cc", "passes/optimize.cc", @@ -408,7 +416,7 @@ cc_library( ], compatible_with = get_compatible_with_cloud(), deps = [ - ":constants", + ":manipulate_model_attr", ":pass_utils", ":quantization_options_proto_cc", ":remove_identity_op_pattern", diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/BUILD b/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/BUILD index 831bf9980ca..1f1ae5a13de 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/BUILD +++ b/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/BUILD @@ -93,7 +93,7 @@ tf_py_test( ":gen_custom_aggregator_op_wrapper", "//tensorflow:tensorflow_py", "//tensorflow/compiler/mlir/quantization/tensorflow/python:pywrap_quantize_model", - "//tensorflow/python:client_testlib", "//tensorflow/python:pywrap_tensorflow", + "//tensorflow/python/platform:client_testlib", ], ) diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibrator_singleton.cc b/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibrator_singleton.cc index 3f153106f8a..23e7ee54f13 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibrator_singleton.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibrator_singleton.cc @@ -15,7 +15,9 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibrator_singleton.h" #include +#include #include +#include namespace tensorflow { namespace calibrator { diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibrator_singleton.h b/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibrator_singleton.h index 7c0f830505a..c87fed2569c 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibrator_singleton.h +++ b/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibrator_singleton.h @@ -16,6 +16,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_CALIBRATOR_CALIBRATOR_SINGLETON_H_ #include +#include #include #include diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibrator_singleton_test.cc b/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibrator_singleton_test.cc index 189c0b319f6..cd8a473fa90 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibrator_singleton_test.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibrator_singleton_test.cc @@ -14,6 +14,9 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibrator_singleton.h" +#include +#include + #include "tensorflow/core/platform/test.h" namespace tensorflow { diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/custom_aggregator_op.cc b/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/custom_aggregator_op.cc index a911b495e20..f640b1aa3d7 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/custom_aggregator_op.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/custom_aggregator_op.cc @@ -12,6 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include +#include + #include "tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibrator_singleton.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_kernel.h" diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/cc/save_variables.cc b/tensorflow/compiler/mlir/quantization/tensorflow/cc/save_variables.cc index 4ad3a370efc..1e442583b1d 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/cc/save_variables.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/cc/save_variables.cc @@ -70,12 +70,12 @@ absl::StatusOr AddTensorToBundleWriter( if (const tsl::Status status = mlir::tfg::ConvertToTensor( /*attr=*/const_op.getValue(), /*output_tensor=*/&const_tensor); !status.ok()) { - return tsl::ToAbslStatus(status); + return status; } if (!bundle_writer.Add(/*key=*/var_handle_op.getSharedName(), const_tensor) .ok()) { - return tsl::ToAbslStatus(bundle_writer.status()); + return bundle_writer.status(); } return var_handle_op.getSharedName().str(); @@ -97,7 +97,7 @@ absl::StatusOr> SaveVariablesToCheckpoint( BundleWriter bundle_writer(Env::Default(), prefix); if (!bundle_writer.status().ok()) { - return tsl::ToAbslStatus(bundle_writer.status()); + return bundle_writer.status(); } std::vector saved_variable_shared_names; @@ -122,7 +122,7 @@ absl::StatusOr> SaveVariablesToCheckpoint( } if (!bundle_writer.Finish().ok()) { - return tsl::ToAbslStatus(bundle_writer.status()); + return bundle_writer.status(); } return saved_variable_shared_names; diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/cc/save_variables_test.cc b/tensorflow/compiler/mlir/quantization/tensorflow/cc/save_variables_test.cc index 8967b64b877..fefff2345f6 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/cc/save_variables_test.cc +++ 
b/tensorflow/compiler/mlir/quantization/tensorflow/cc/save_variables_test.cc @@ -114,8 +114,7 @@ TEST_F(SaveVariablesToCheckpointTest, VariableSavedToCheckpoint) { BundleReader bundle_reader(env_, *checkpoint_prefix); Tensor loaded_tensor{}; - EXPECT_TRUE( - tsl::ToAbslStatus(bundle_reader.Lookup("var_0", &loaded_tensor)).ok()); + EXPECT_TRUE(bundle_reader.Lookup("var_0", &loaded_tensor).ok()); ExpectEqual(loaded_tensor, AsTensor({1.0, 2.0})); } @@ -161,13 +160,11 @@ TEST_F(SaveVariablesToCheckpointTest, MultipleVariablesSavedToCheckpoint) { BundleReader bundle_reader(env_, *checkpoint_prefix); Tensor loaded_var_0{}; - EXPECT_TRUE( - tsl::ToAbslStatus(bundle_reader.Lookup("var_0", &loaded_var_0)).ok()); + EXPECT_TRUE(bundle_reader.Lookup("var_0", &loaded_var_0).ok()); ExpectEqual(loaded_var_0, AsTensor({1.0, 2.0})); Tensor loaded_var_1{}; - EXPECT_TRUE( - tsl::ToAbslStatus(bundle_reader.Lookup("var_1", &loaded_var_1)).ok()); + EXPECT_TRUE(bundle_reader.Lookup("var_1", &loaded_var_1).ok()); ExpectEqual(loaded_var_1, AsTensor({3, 4, 5, 6})); } diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/debugging/BUILD b/tensorflow/compiler/mlir/quantization/tensorflow/debugging/BUILD index 448ac05842f..879ccc88de0 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/debugging/BUILD +++ b/tensorflow/compiler/mlir/quantization/tensorflow/debugging/BUILD @@ -2,6 +2,7 @@ load("//tensorflow:tensorflow.bzl", "tf_cc_test") load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_cloud") package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = [ "//tensorflow/compiler/mlir/quantization:__subpackages__", ], diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/debugging/mlir_dump.cc b/tensorflow/compiler/mlir/quantization/tensorflow/debugging/mlir_dump.cc index fe1b205ce1d..c5157fed64c 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/debugging/mlir_dump.cc +++ 
b/tensorflow/compiler/mlir/quantization/tensorflow/debugging/mlir_dump.cc @@ -77,7 +77,7 @@ absl::StatusOr> CreateMlirDumpFile( auto *env = tsl::Env::Default(); const tsl::Status status = env->RecursivelyCreateDir(*dump_dir); if (!status.ok()) { - return tsl::ToAbslStatus(status); + return status; } std::error_code ec{}; // NOLINT: Required to create llvm::raw_fd_ostream diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/debugging/mlir_dump.h b/tensorflow/compiler/mlir/quantization/tensorflow/debugging/mlir_dump.h index db13cd19f08..803cd39a0a5 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/debugging/mlir_dump.h +++ b/tensorflow/compiler/mlir/quantization/tensorflow/debugging/mlir_dump.h @@ -35,7 +35,7 @@ void EnableIrPrinting(llvm::raw_ostream &out_stream, mlir::PassManager &pm); // level < 1 or TF_QUANT_MLIR_DUMP_PREFIX is not set or set to an empty string. // The returned ostream instance should live until the pass run is complete. absl::StatusOr> MaybeEnableIrPrinting( - mlir::PassManager &pm, const absl::string_view name); + mlir::PassManager &pm, absl::string_view name); } // namespace quantization } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/convert_custom_aggregation_op_to_quant_stats.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/convert_custom_aggregation_op_to_quant_stats.cc index a3cb59b241c..c3759ff75c6 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/convert_custom_aggregation_op_to_quant_stats.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/convert_custom_aggregation_op_to_quant_stats.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ #include +#include #include #include #include diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/insert_custom_aggregation_ops.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/insert_custom_aggregation_ops.cc index 2970cafffc0..9671f1b17eb 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/insert_custom_aggregation_ops.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/insert_custom_aggregation_ops.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include +#include #include #include #include diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/insert_main_function.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/insert_main_function.cc index 0d60c9a2020..0e6ce592ea0 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/insert_main_function.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/insert_main_function.cc @@ -96,7 +96,7 @@ bool ShouldIncludeInMainFunction(func::FuncOp func_op) { void SetFunctionPrivate(func::FuncOp func) { func.setVisibility(SymbolTable::Visibility::Private); - // The `tf_saved_model` attributes can only be appied to public functions. + // The `tf_saved_model` attributes can only be applied to public functions. for (auto& attr : func->getAttrs()) { StringRef attr_name = attr.getName().getValue(); if (attr_name.startswith("tf_saved_model.")) { @@ -136,7 +136,7 @@ struct OutputInfo { }; // Makes input/output names across entry functions unique if necessary. If a -// dupliated name is found, this function will add signature prefix for all the +// duplicated name is found, this function will add signature prefix for all the // input/output names. 
void GetUniqueInputOutputNodeNames(ModuleOp module_op, std::vector& input_name_vec, diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/issue_ids_of_custom_aggregation_ops.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/issue_ids_of_custom_aggregation_ops.cc index 513076b51fb..0d1302c99f8 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/issue_ids_of_custom_aggregation_ops.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/issue_ids_of_custom_aggregation_ops.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include +#include #include #include #include diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/lift_hashtable_ops_as_args.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/lift_hashtable_ops_as_args.cc new file mode 100644 index 00000000000..175bf572074 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/lift_hashtable_ops_as_args.cc @@ -0,0 +1,210 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include + +#include "absl/strings/str_cat.h" +#include "llvm/ADT/StringRef.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/IRMapping.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/constants.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/manipulate_model_attr.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h" +#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h" + +namespace mlir { +namespace quant { +namespace { + +constexpr StringRef kSharedNameAttr = "shared_name"; + +class LiftHashTableOpsAsArgsPass + : public PassWrapper> { + public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LiftHashTableOpsAsArgsPass) + explicit LiftHashTableOpsAsArgsPass() = default; + + StringRef getArgument() const final { + // This is the argument used to refer to the pass in + // the textual format (on the commandline for example). + return "quant-lift-hashtable-ops-as-args"; + } + StringRef getDescription() const final { + return "Lifts HashTable ops as function arguments."; + } + + void runOnOperation() override; +}; + +// Checks if the given op is a Hashtable op. +bool IsHashTableOp(Operation* op) { + return llvm::isa(op); +} + +// Checks if the function is the main or initializer function. 
+bool IsMainOrInitializerFunction(ModuleOp module, func::FuncOp func) { + if (func.getSymName().equals(tensorflow::kImportModelDefaultGraphFuncName) || + func.getSymName().equals(kTfQuantSaveFuncName)) { + return true; + } + + for (func::FuncOp init_func : + tf_saved_model::GetInitializerFunctions(module)) { + if (func.getSymName().equals(init_func.getSymName())) { + return true; + } + } + return false; +} + +// Checks if the function is only used by supported ops. Returns false when the +// function has no uses. Currently, only PartitionedCall is supported. +// TODO(b/284222309): Support lifting for functions called by control flow. +bool UsedBySupportedOps(ModuleOp module, func::FuncOp func) { + auto function_uses = + SymbolTable::getSymbolUses(func, &module.getBodyRegion()); + if (!function_uses.has_value()) return false; + for (auto& function_use : function_uses.value()) { + if (!llvm::isa( + function_use.getUser())) { + return false; + } + } + return true; +} + +// Returns the `shared_name` attribute value if exists. If not, returns an +// empty string. +StringRef GetSharedName(Operation* op) { + if (!op->hasAttrOfType(kSharedNameAttr)) return ""; + return op->getAttrOfType(kSharedNameAttr).getValue(); +} + +// Checks if the HashTable is initialized. This function assumes that the +// HashTable is initialized if it appears in the initializer since it can't +// check the actual value. +bool IsResourceInitialized(ModuleOp module_op, Operation* hash_table) { + StringRef shared_name = GetSharedName(hash_table); + if (shared_name.empty()) return false; + + for (func::FuncOp init_func_op : + tf_saved_model::GetInitializerFunctions(module_op)) { + for (Operation& op : init_func_op.getBody().getOps()) { + StringRef other_shared_name = GetSharedName(&op); + if (IsHashTableOp(&op) && other_shared_name.equals(shared_name)) { + return true; + } + } + } + return false; +} + +// Lifts HashTable ops in the target function as function arguments and returns +// the lifted ops. 
These ops will then be added to the caller function and +// passed to the target function. +LogicalResult LiftHashTableOpsToArguments(ModuleOp module_op, + func::FuncOp target_func) { + if (!llvm::hasSingleElement(target_func)) return success(); + if (!UsedBySupportedOps(module_op, target_func)) return success(); + if (IsMainOrInitializerFunction(module_op, target_func)) return success(); + + llvm::StringMap shared_name_to_arg_idx; + llvm::SmallDenseMap lifted_op_to_arg_idx; + Block& block = target_func.front(); + auto func_type = target_func.getFunctionType(); + + for (Operation& op : block.without_terminator()) { + StringRef shared_name = GetSharedName(&op); + if (shared_name.empty() || !IsHashTableOp(&op)) continue; + if (!IsResourceInitialized(module_op, &op)) continue; + + auto it = + shared_name_to_arg_idx.insert({shared_name, block.getNumArguments()}); + if (it.second) { + auto resource_type = op.getResult(0).getType(); + op.getResult(0).replaceAllUsesWith( + block.addArgument(resource_type, op.getLoc())); + AddEntryFunctionInput( + absl::StrCat("hash_table_", it.first->getValue(), ":0"), target_func); + // Avoid deleting the op here, clone it to the caller function first. + lifted_op_to_arg_idx.insert({&op, it.first->getValue()}); + } else { + op.getResult(0).replaceAllUsesWith( + block.getArgument(it.first->getValue())); + op.erase(); + } + } + if (lifted_op_to_arg_idx.empty()) return success(); + + // Update the function signature as well as its uses. + target_func.setType(FunctionType::get(target_func.getContext(), + block.getArgumentTypes(), + func_type.getResults())); + + IRMapping mapping; + OpBuilder builder(module_op); + OpBuilder::InsertionGuard g(builder); + // The function has been checked to have at least one use. 
+ auto function_uses = + SymbolTable::getSymbolUses(target_func, &module_op.getBodyRegion()); + for (auto& function_use : function_uses.value()) { + auto call_op = function_use.getUser(); + auto caller_func = call_op->getParentOfType(); + if (!caller_func) return failure(); + + builder.setInsertionPoint(call_op); + for (auto [lifted_op, arg_idx] : lifted_op_to_arg_idx) { + auto new_op = builder.clone(*lifted_op, mapping); + call_op->insertOperands(arg_idx, new_op->getResult(0)); + } + + // Try to lift recursively until the main function. + if (failed(LiftHashTableOpsToArguments(module_op, caller_func))) { + return failure(); + } + } + + // Erase the lifted operations explicitly. + for (auto [lifted_op, arg_idx] : lifted_op_to_arg_idx) { + lifted_op->erase(); + } + + return success(); +} + +void LiftHashTableOpsAsArgsPass::runOnOperation() { + auto module_op = getOperation(); + + for (auto func_op : module_op.getOps()) { + if (failed(LiftHashTableOpsToArguments(module_op, func_op))) { + signalPassFailure(); + return; + } + } +} + +static PassRegistration pass; + +} // namespace + +std::unique_ptr> CreateLiftHashTableOpsAsArgsPass() { + return std::make_unique(); +} + +} // namespace quant +} // namespace mlir diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/manipulate_model_attr.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/manipulate_model_attr.cc new file mode 100644 index 00000000000..06784f8dba5 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/manipulate_model_attr.cc @@ -0,0 +1,60 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/manipulate_model_attr.h" + +#include +#include + +#include "llvm/ADT/StringExtras.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project + +namespace mlir { +namespace quant { + +constexpr StringRef kTfEntryFunctionAttr = "tf.entry_function"; + +void AddEntryFunctionInput(StringRef input_name, func::FuncOp func_op) { + auto entry_func_attr = + func_op->getAttrOfType(kTfEntryFunctionAttr); + if (!entry_func_attr) return; + + auto entry_func_attrs = SmallVector(entry_func_attr.begin(), + entry_func_attr.end()); + + MLIRContext* ctx = func_op.getContext(); + for (auto& named_attr : entry_func_attrs) { + if (named_attr.getName() != "inputs") continue; + + // Splits the "inputs" field to retrieve individual input names. Ignores + // empty strings. 
+ SmallVector inputs_attrs{}; + cast(named_attr.getValue()) + .strref() + .split(inputs_attrs, /*Separator=*/',', /*MaxSplit=*/-1, + /*KeepEmpty=*/false); + + inputs_attrs.emplace_back(input_name); + + const std::string new_inputs_attr_str = + llvm::join(std::move(inputs_attrs), /*Separator=*/","); + + named_attr.setValue(StringAttr::get(ctx, new_inputs_attr_str)); + } + + func_op->setAttr(kTfEntryFunctionAttr, + DictionaryAttr::get(ctx, entry_func_attrs)); +} +} // namespace quant +} // namespace mlir diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/manipulate_model_attr.h b/tensorflow/compiler/mlir/quantization/tensorflow/passes/manipulate_model_attr.h new file mode 100644 index 00000000000..d42ad360034 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/manipulate_model_attr.h @@ -0,0 +1,32 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_PASSES_MANIPULATE_MODEL_ATTR_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_PASSES_MANIPULATE_MODEL_ATTR_H_ + +#include "llvm/ADT/StringRef.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project + +namespace mlir { +namespace quant { + +// Adds a new input name to the `inputs` field of the `tf.entry_function` +// attribute if the attribute exist in the given function. 
Otherwise, no +// attribute is modified. +void AddEntryFunctionInput(StringRef input_name, func::FuncOp func_op); + +} // namespace quant +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_PASSES_MANIPULATE_MODEL_ATTR_H_ diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/merge_duplicate_resource_ops.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/merge_duplicate_resource_ops.cc new file mode 100644 index 00000000000..be179db7306 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/merge_duplicate_resource_ops.cc @@ -0,0 +1,139 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include +#include + +#include "llvm/ADT/StringRef.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/passes.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" + +namespace mlir { +namespace quant { +namespace { + +using ::mlir::tf_executor::GraphOp; +using ::mlir::tf_executor::IslandOp; + +constexpr StringRef kSharedNameAttr = "shared_name"; + +class MergeDuplicateResourceOpsPass + : public PassWrapper> { + public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MergeDuplicateResourceOpsPass) + + StringRef getArgument() const final { + return "quant-merge-duplicate-resource-ops"; + } + + StringRef getDescription() const final { + return "Merge resource ops that have the same shared name."; + } + + void runOnOperation() override; +}; + +// Checks if the island op contains a resource op like Variable or Hashtable +// and returns that resource op. Otherwise, returns null. +Operation* GetResourceOp(Operation* op) { + // Check if the island has only one block thats contain two ops, including + // one resource op and one Yield op. + auto island_op = llvm::dyn_cast_or_null(op); + if (!island_op || !island_op.getBody().hasOneBlock()) return nullptr; + auto& island_block = island_op.getBody().front(); + if (++island_block.begin() != --island_block.end()) return nullptr; + + Operation* resource_op = &island_block.front(); + if (llvm::isa(resource_op)) { + return resource_op; + } + return nullptr; +} + +// Returns the `shared_name` attribute value if exists. If not, returns an +// empty string. +StringRef GetSharedName(Operation* op) { + if (!op->hasAttrOfType(kSharedNameAttr)) return ""; + return op->getAttrOfType(kSharedNameAttr).getValue(); +} + +// Gets the GraphOp from the function op. Returns an empty op iff it doesn't +// exist. 
+// TODO(b/284222084): Move executor dialect utilities to a new library. +GraphOp GetGraphOpFromFuncOp(func::FuncOp func_op) { + if (func_op->getNumRegions() == 0 || func_op.getBody().empty()) return {}; + + auto graph_op_range = func_op.front().without_terminator(); + if (llvm::hasSingleElement(graph_op_range)) { + // The pass runs on a valid tf_executor dialect, so the op should be the + // GraphOp. + return cast(graph_op_range.begin()); + } + + return {}; +} + +void MergeDuplicateResourceOpsPass::runOnOperation() { + func::FuncOp func_op = getOperation(); + GraphOp graph_op = GetGraphOpFromFuncOp(func_op); + if (!graph_op) return; + + llvm::StringMap shared_name_to_resource; + llvm::SmallVector ops_to_remove; + for (Operation& op : graph_op.GetBody().without_terminator()) { + Operation* resource_op = GetResourceOp(&op); + if (!resource_op) continue; + StringRef shared_name = GetSharedName(resource_op); + if (shared_name.empty()) continue; + + if (!shared_name_to_resource.contains(shared_name)) { + shared_name_to_resource[shared_name] = resource_op; + continue; + } + + auto existing_resource = shared_name_to_resource[shared_name]; + if (resource_op->getName().getStringRef() != + existing_resource->getName().getStringRef() || + resource_op->getResult(0).getType() != + existing_resource->getResult(0).getType()) { + resource_op->emitOpError( + "This op has the same `shared_name` but different type with another " + "resource op in the function"); + signalPassFailure(); + return; + } + op.replaceAllUsesWith(existing_resource->getParentOp()->getResults()); + ops_to_remove.push_back(&op); + } + + // Remove op after the loop to avoid crash. 
+ for (Operation* op : ops_to_remove) { + op->erase(); + } +} + +static PassRegistration pass{}; + +} // namespace + +std::unique_ptr> +CreateMergeDuplicateResourceOpsPass() { + return std::make_unique(); +} + +} // namespace quant +} // namespace mlir diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/merge_initializer_function_ops_to_main.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/merge_initializer_function_ops_to_main.cc index 221d3318730..6e94beb6b0a 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/merge_initializer_function_ops_to_main.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/merge_initializer_function_ops_to_main.cc @@ -35,6 +35,7 @@ limitations under the License. #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/manipulate_model_attr.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/passes/passes.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" @@ -205,42 +206,6 @@ FailureOr> GetInitFuncOps( return init_func_ops; } -// If `main_func_op` has the `tf.entry_function` attribute, adds a new input -// name to the `inputs` field of the attribute. Otherwise, no attribute is -// modified. -void MaybeAddEntryFunctionInput(const StringRef input_name, - func::FuncOp main_func_op) { - auto entry_func_attr = - main_func_op->getAttrOfType("tf.entry_function"); - if (!entry_func_attr) return; - - auto entry_func_attrs = SmallVector(entry_func_attr.begin(), - entry_func_attr.end()); - - MLIRContext* ctx = main_func_op.getContext(); - for (auto& named_attr : entry_func_attrs) { - if (named_attr.getName() != "inputs") continue; - - // Splits the "inputs" field to retrieve individual input names. Ignores - // empty strings. 
- SmallVector inputs_attrs{}; - cast(named_attr.getValue()) - .strref() - .split(inputs_attrs, /*Separator=*/',', /*MaxSplit=*/-1, - /*KeepEmpty=*/false); - - inputs_attrs.emplace_back(input_name); - - const std::string new_inputs_attr_str = - llvm::join(std::move(inputs_attrs), /*Separator=*/","); - - named_attr.setValue(StringAttr::get(ctx, new_inputs_attr_str)); - } - - main_func_op->setAttr("tf.entry_function", - DictionaryAttr::get(ctx, entry_func_attrs)); -} - // Creates new arguments to the main function that corresponds to the source // function's arguments. Returns the `IRMapping` that contains the // relationship. @@ -265,7 +230,7 @@ IRMapping CloneSrcFuncArgumentsToMainFunc(func::FuncOp src_func_op, const std::string new_input_name = absl::StrCat(GetInitializerType(src_func_op), "_", src_arg_idx, ":0"); - MaybeAddEntryFunctionInput(new_input_name, main_func_op); + AddEntryFunctionInput(new_input_name, main_func_op); // During cloning, let it know that the source function's argument // corresponds to the main function's newly created argument when cloning diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/merge_save_function_ops_to_main.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/merge_save_function_ops_to_main.cc index e5037fe4962..caef5c034f4 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/merge_save_function_ops_to_main.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/merge_save_function_ops_to_main.cc @@ -20,6 +20,7 @@ limitations under the License. 
#include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/quantization/tensorflow/passes/constants.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/manipulate_model_attr.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/passes/passes.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h" @@ -35,8 +36,6 @@ using ::mlir::tf_executor::IslandOp; using ::mlir::tf_saved_model::kTfSavedModelIndexPathAttr; using ::tensorflow::kImportModelDefaultGraphFuncName; -constexpr StringRef kTfEntryFunctionAttr = "tf.entry_function"; - class MergeSaveFunctionOpsToMainPass : public PassWrapper> { @@ -130,28 +129,7 @@ BlockArgument CreateFilePrefixArg(func::FuncOp main_func_op) { // Append the "__tf_file_prefix:0" to the "tf.entry_function" attribute's // item keyed by "inputs". - auto entry_function_attr = - main_func_op->getAttrOfType(kTfEntryFunctionAttr); - - SmallVector new_entry_function_attr_items; - for (NamedAttribute entry_function_attr_item : entry_function_attr) { - if (entry_function_attr_item.getName() == "inputs") { - auto inputs_attr = entry_function_attr_item.getValue().cast(); - const auto new_inputs_value_attr = Twine(inputs_attr.getValue()) - .concat(kTfFilePrefix) - .concat(":0") - .str(); - new_entry_function_attr_items.emplace_back( - builder.getNamedAttr(builder.getStringAttr("inputs"), - builder.getStringAttr(new_inputs_value_attr))); - } else { - new_entry_function_attr_items.emplace_back(entry_function_attr_item); - } - } - - main_func_op->setAttr( - /*name=*/kTfEntryFunctionAttr, - /*value=*/builder.getDictionaryAttr(new_entry_function_attr_items)); + AddEntryFunctionInput(Twine(kTfFilePrefix).concat(":0").str(), main_func_op); return new_file_prefix_arg; } diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/passes.h 
b/tensorflow/compiler/mlir/quantization/tensorflow/passes/passes.h index 5406e4d2ed8..99edd6fc0ea 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/passes.h +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/passes.h @@ -109,7 +109,10 @@ std::unique_ptr> CreatePrepareQuantizeDRQPass( // Creates an instance of the PreprocessOp pass, which will perform op // preprocessing to allow multi-axis quantization, prior to quantization. std::unique_ptr> CreatePreprocessOpPass( - const QuantizationSpecs& quant_specs, OpSet op_set); + OpSet op_set, + tensorflow::quantization::QuantizationMethod::ExperimentalMethod + quantization_method, + bool enable_per_channel_quantization); // Creates an instance of the PostQuantize pass, which will remove unnecessary // ops from the final quantized graph. @@ -210,6 +213,16 @@ std::unique_ptr> CreateConvertTpuModelToCpuPass(); // model quantization. std::unique_ptr> CreateCastBf16OpsToF32Pass(); +// Creates a pass that lifts HashTable ops as function arguments. In the graph +// execution mode, resource ops with the same `shared_name` attribute point to +// the same underlying resource. This is not true in the eager execution mode. +// Lifting resource ops as arguments will help unifying them across functions. +std::unique_ptr> CreateLiftHashTableOpsAsArgsPass(); + +// Creates a pass that merges duplicate resource ops in each function. Two +// resource ops are considered duplicated if they have the same `shared_name`. 
+std::unique_ptr> +CreateMergeDuplicateResourceOpsPass(); } // namespace quant } // namespace mlir diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_lifting.td b/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_lifting.td index 6f6e6d89da6..4a95afd8873 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_lifting.td +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_lifting.td @@ -41,10 +41,15 @@ def ConvertArithConstToTfConst : Pat< (TF_ConstOp $value), [(AnyStaticShapeTensor $res)]>; -// Converts CheckNumerics op to Identity -def ConvertCheckNumerics : Pat< +// Remove CheckNumerics op +def RemoveCheckNumerics : Pat< (TF_CheckNumericsOp $arg, $msg), - (TF_IdentityOp $arg)>; + (replaceWithValue $arg)>; + +// Remove StopGradient op +def RemoveStopGradient : Pat< + (TF_StopGradientOp $arg), + (replaceWithValue $arg)>; // Only handles the case where batch_dimension is empty. def IsXlaGatherWithoutBatch : diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/preprocess_op.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/preprocess_op.cc index 18aa58fe60c..4ec9e5361ff 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/preprocess_op.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/preprocess_op.cc @@ -26,6 +26,7 @@ limitations under the License. 
#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h" #include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h" @@ -42,6 +43,8 @@ namespace quant { namespace { +using QuantMethod = + tensorflow::quantization::QuantizationMethod::ExperimentalMethod; using QuantizationUnit = std::pair; using QuantizationUnits = llvm::SetVector; @@ -57,19 +60,20 @@ class PreprocessOpPass public: MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PreprocessOpPass) - // Constructor used by the PassRegistration and enforce int8 quantization. - // This is only used by test. - explicit PreprocessOpPass() : op_set_(OpSet::UNIFORM_QUANTIZED) { - quant_specs_.inference_type = tensorflow::DT_QINT8; - } + explicit PreprocessOpPass() = default; // Constructor used by manually creating the pass. 
- explicit PreprocessOpPass(const QuantizationSpecs& quant_specs, OpSet op_set) - : quant_specs_(quant_specs), op_set_(op_set) {} + explicit PreprocessOpPass(OpSet op_set, const QuantMethod quantization_method, + bool enable_per_channel_quantization) { + op_set_ = op_set; + quantization_method_ = quantization_method; + enable_per_channel_quantization_ = enable_per_channel_quantization; + } PreprocessOpPass(const PreprocessOpPass& other) { - quant_specs_ = other.quant_specs_; op_set_ = other.op_set_; + quantization_method_ = other.quantization_method_; + enable_per_channel_quantization_ = other.enable_per_channel_quantization_; } StringRef getArgument() const final { @@ -85,15 +89,103 @@ class PreprocessOpPass void runOnOperation() override; private: - QuantizationSpecs quant_specs_; - OpSet op_set_; + Option op_set_{ + *this, "target-opset", llvm::cl::init(OpSet::UNIFORM_QUANTIZED), + llvm::cl::desc("Choose target opset."), + llvm::cl::values( + clEnumValN(OpSet::TF, "TF", + "Uses TF ops that mimic quantization behavior"), + clEnumValN(OpSet::XLA, "XLA", "Uses TF XLA ops"), + clEnumValN(OpSet::UNIFORM_QUANTIZED, "UNIFORM_QUANTIZED", + "Uses TF Uniform Quantized ops"))}; + + Option quantization_method_{ + *this, "quantization-method", + llvm::cl::init( + tensorflow::quantization::QuantizationMethod::STATIC_RANGE), + llvm::cl::desc("Choose quantization method."), + llvm::cl::values( + clEnumValN(tensorflow::quantization::QuantizationMethod::STATIC_RANGE, + "ptq", "Post-training static-range quantization"), + clEnumValN( + tensorflow::quantization::QuantizationMethod::DYNAMIC_RANGE, + "drq", "Post-training dynamic-range quantizaiton"), + clEnumValN(tensorflow::quantization::QuantizationMethod::WEIGHT_ONLY, + "weight_only", "Post-training weight-only quantizaiton"))}; + + Option enable_per_channel_quantization_{ + *this, "enable-per-channel-quantization", llvm::cl::init(false), + llvm::cl::desc("Whether enable per-channel quantized weights.")}; }; // Apply constant 
transformations for the op_set. class PreprocessConstantOp : public OpRewritePattern { public: - explicit PreprocessConstantOp(MLIRContext* context, OpSet op_set) - : OpRewritePattern(context), op_set_(op_set) {} + explicit PreprocessConstantOp(MLIRContext* context, OpSet op_set, + QuantMethod quantization_method, + bool enable_per_channel_quantization) + : OpRewritePattern(context), + op_set_(op_set), + quantization_method_(quantization_method), + enable_per_channel_quantization_(enable_per_channel_quantization) {} + + LogicalResult addReshapeOpToDepthwiseWeight(TF::PartitionedCallOp op, + PatternRewriter& rewriter, + StringRef function_name) const { + std::unique_ptr spec = GetTFOpQuantSpec(op); + const absl::flat_hash_set operands = spec->quantizable_operands; + + if (operands.size() != 1) return failure(); + int weight_operand_idx = *operands.begin(); + + Operation* weight_op = op.getOperand(weight_operand_idx).getDefiningOp(); + DenseFPElementsAttr attr; + if (!matchPattern(weight_op->getResult(0), m_Constant(&attr))) { + return failure(); + } + + // Get new shape. + llvm::ArrayRef cur_shape = attr.getType().getShape(); + int cur_rank = cur_shape.size(); + if (cur_rank != 4 || cur_shape[2] == 1) return failure(); + TensorType new_shape = RankedTensorType::get( + {cur_shape[0], cur_shape[1], 1, cur_shape[2] * cur_shape[3]}, + attr.getElementType()); + + // Inserts a reshape op. + auto shape_spec_type = + RankedTensorType::get({cur_rank}, rewriter.getIntegerType(64)); + auto new_shape_const_attr = + DenseElementsAttr::get(shape_spec_type, new_shape.getShape()); + rewriter.setInsertionPointAfter(weight_op); + auto new_shape_const = rewriter.create( + weight_op->getLoc(), shape_spec_type, new_shape_const_attr); + auto reshape_op = rewriter.create( + weight_op->getLoc(), new_shape, weight_op->getResult(0), + new_shape_const); + op->setOperand(weight_operand_idx, reshape_op); + + // Create a new function with preprocessed types. 
+ ModuleOp module = op->getParentOfType(); + SymbolTable symbol_table(module); + func::FuncOp float_func = + dyn_cast(symbol_table.lookup(function_name)); + OperandRange func_args = op.getArgs(); + func::FuncOp new_float_func = float_func.clone(); + + SmallVector new_float_func_args{func_args.begin(), func_args.end()}; + new_float_func_args[weight_operand_idx] = reshape_op; + new_float_func.getArgument(weight_operand_idx).setType(new_shape); + new_float_func.setType(FunctionType::get( + getContext(), TypeRange{ValueRange{new_float_func_args}}, + new_float_func.getResultTypes())); + symbol_table.insert(new_float_func); + + op->setAttr("f", SymbolRefAttr::get(rewriter.getContext(), + new_float_func.getName())); + + return success(); + } LogicalResult matchAndRewrite(TF::PartitionedCallOp op, PatternRewriter& rewriter) const override { @@ -101,13 +193,12 @@ class PreprocessConstantOp : public OpRewritePattern { // Non-quantizable op if (!op->hasAttr(kQuantTraitAttrName)) return failure(); StringRef function_name = f_attr.getValue(); + // TODO(b/228928859): Improve the getter function to match attributes rather + // than function name. if (!function_name.startswith("composite_")) { return failure(); } - std::unique_ptr spec = GetTFOpQuantSpec(op); - const absl::flat_hash_set operands = spec->quantizable_operands; - if (function_name.contains("depthwise_conv2d")) { // Uniform Quantized op requires weights of tf.DepthwiseConv2dNative to // be transformed from [H,W,C,M] to [H,W,1,CxM] where @@ -115,57 +206,11 @@ class PreprocessConstantOp : public OpRewritePattern { // inserted between the constant op and the function op so that the // constant is safely transformed for the multi-use cases as well. Note // that bias doesn't need transformation as its shape is already in [CxM]. 
- if (operands.size() != 1) return failure(); - int weight_operand_idx = *operands.begin(); - Operation* weight_op = op.getOperand(weight_operand_idx).getDefiningOp(); - - if (op_set_ == OpSet::UNIFORM_QUANTIZED) { - DenseFPElementsAttr attr; - if (!matchPattern(weight_op->getResult(0), m_Constant(&attr))) { - return failure(); - } - - // Get new shape. - llvm::ArrayRef cur_shape = attr.getType().getShape(); - int cur_rank = cur_shape.size(); - if (cur_rank != 4 || cur_shape[2] == 1) return failure(); - TensorType new_shape = RankedTensorType::get( - {cur_shape[0], cur_shape[1], 1, cur_shape[2] * cur_shape[3]}, - attr.getElementType()); - - // Inserts a reshape op. - auto shape_spec_type = - RankedTensorType::get({cur_rank}, rewriter.getIntegerType(64)); - auto new_shape_const_attr = - DenseElementsAttr::get(shape_spec_type, new_shape.getShape()); - rewriter.setInsertionPointAfter(weight_op); - auto new_shape_const = rewriter.create( - weight_op->getLoc(), shape_spec_type, new_shape_const_attr); - auto reshape_op = rewriter.create( - weight_op->getLoc(), new_shape, weight_op->getResult(0), - new_shape_const); - op->setOperand(weight_operand_idx, reshape_op); - - // Create a new function with preprocessed types. 
- ModuleOp module = op->getParentOfType(); - SymbolTable symbol_table(module); - func::FuncOp float_func = - dyn_cast(symbol_table.lookup(function_name)); - OperandRange func_args = op.getArgs(); - func::FuncOp new_float_func = float_func.clone(); - - SmallVector new_float_func_args{func_args.begin(), - func_args.end()}; - new_float_func_args[weight_operand_idx] = reshape_op; - new_float_func.getArgument(weight_operand_idx).setType(new_shape); - new_float_func.setType(FunctionType::get( - getContext(), TypeRange{ValueRange{new_float_func_args}}, - new_float_func.getResultTypes())); - symbol_table.insert(new_float_func); - - op->setAttr("f", SymbolRefAttr::get(rewriter.getContext(), - new_float_func.getName())); - return success(); + if (op_set_ == OpSet::UNIFORM_QUANTIZED || + (op_set_ == OpSet::XLA && enable_per_channel_quantization_ && + quantization_method_ == + tensorflow::quantization::QuantizationMethod::WEIGHT_ONLY)) { + return addReshapeOpToDepthwiseWeight(op, rewriter, function_name); } } return failure(); @@ -173,6 +218,8 @@ class PreprocessConstantOp : public OpRewritePattern { private: const OpSet op_set_; + const QuantMethod quantization_method_; + const bool enable_per_channel_quantization_; }; #include "tensorflow/compiler/mlir/quantization/tensorflow/passes/preprocess_op.inc" @@ -183,7 +230,8 @@ void PreprocessOpPass::runOnOperation() { ModuleOp module_op = getOperation(); populateWithGenerated(patterns); - patterns.add(ctx, op_set_); + patterns.add(ctx, op_set_, quantization_method_, + enable_per_channel_quantization_); FrozenRewritePatternSet frozen_patterns(std::move(patterns)); for (auto func : module_op.getOps()) { @@ -199,8 +247,10 @@ void PreprocessOpPass::runOnOperation() { // Creates an instance of the TensorFlow dialect PreprocessOp // pass. 
std::unique_ptr> CreatePreprocessOpPass( - const QuantizationSpecs& quant_specs, const OpSet op_set) { - return std::make_unique(quant_specs, op_set); + const OpSet op_set, QuantMethod quantization_method, + const bool enable_per_channel_quantization) { + return std::make_unique(op_set, quantization_method, + enable_per_channel_quantization); } static PassRegistration pass; diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/quantize_composite_functions.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/quantize_composite_functions.cc index 6c374141025..9269461a80c 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/quantize_composite_functions.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/quantize_composite_functions.cc @@ -968,6 +968,79 @@ class QuantizeConstPattern OpSet target_opset_; }; +// To calculate per-channel scale and offset, weight of depthwise was reshaped +// to [H, W, 1, InxMul]. After scale and offset has been calculated, this +// pattern gets called and restores the weight of depthwise back +// into [H, W, In, Mul] +class RestoreWeightShapePattern + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + private: + LogicalResult addReshapeOpToDepthwiseWeight(TF::PartitionedCallOp op, + PatternRewriter& rewriter) const { + int weight_operand_idx = 1; + Operation* weight_op = op.getOperand(weight_operand_idx).getDefiningOp(); + + auto weight_type = weight_op->getResult(0).getType().dyn_cast(); + auto input_type = op.getOperand(0).getType().dyn_cast(); + + llvm::ArrayRef weight_shape = weight_type.getShape(); + llvm::ArrayRef input_shape = input_type.getShape(); + + // If weight_shape[2] != 1, it means weight shape was already restored. + if (weight_shape[2] != 1) return failure(); + + // Weight was reshaped into [H, W, 1, InxMul]. + // Since we know in_channels from input_shape, we can derive multiplier. 
+ int64_t in_channels = input_shape[3]; + // If in_channels is 1, there is no need to restore weight shape. + if (in_channels == 1) return failure(); + int64_t multiplier = weight_shape[3] / in_channels; + + TensorType new_shape = RankedTensorType::get( + {weight_shape[0], weight_shape[1], in_channels, multiplier}, + weight_type.getElementType()); + + int cur_rank = weight_type.getRank(); + + // Inserts a reshape op. + auto shape_spec_type = + RankedTensorType::get({cur_rank}, rewriter.getIntegerType(64)); + auto new_shape_const_attr = + DenseElementsAttr::get(shape_spec_type, new_shape.getShape()); + rewriter.setInsertionPointAfter(weight_op); + auto new_shape_const = rewriter.create( + weight_op->getLoc(), shape_spec_type, new_shape_const_attr); + auto reshape_op = rewriter.create( + weight_op->getLoc(), new_shape, weight_op->getResult(0), + new_shape_const); + op->setOperand(weight_operand_idx, reshape_op); + + return success(); + } + + LogicalResult matchAndRewrite(TF::PartitionedCallOp call_op, + PatternRewriter& rewriter) const override { + const auto f_attr = call_op.getFAttr().dyn_cast(); + StringRef function_name = f_attr.getValue(); + // TODO(b/228928859): Improve the getter function to match attributes rather + // than function name. + // If enable_legacy_weight_only is enabled, QuantizeFunctionsPattern + // does not get called and function remains as composite + if (!function_name.startswith("quantized_") && + !function_name.startswith("composite_")) { + return failure(); + } + + if (function_name.contains("depthwise_conv2d")) { + return addReshapeOpToDepthwiseWeight(call_op, rewriter); + } + + return failure(); + } +}; + // Prints a summary about the quantization results. 
class QuantizationSummary { public: @@ -1133,10 +1206,12 @@ void QuantizeCompositeFunctionsPass::runOnOperation() { pm.enableVerifier(false); QuantizationSpecs quant_specs; - pm.addPass(CreatePreprocessOpPass(quant_specs, target_opset_)); - quant_specs.inference_type = tensorflow::DT_QINT8; quant_specs.disable_per_channel = !enable_per_channel_quantization_; + + pm.addPass(CreatePreprocessOpPass(target_opset_, quantization_method_, + enable_per_channel_quantization_)); + // Apply activation-weight quantization. if (quantization_method_ == tensorflow::quantization::QuantizationMethod::STATIC_RANGE) { @@ -1180,6 +1255,13 @@ void QuantizeCompositeFunctionsPass::runOnOperation() { patterns_2.add( ctx, target_opset_); patterns_2.add(ctx, target_opset_); + + if (target_opset_ == OpSet::XLA && enable_per_channel_quantization_ && + quantization_method_ == + tensorflow::quantization::QuantizationMethod::WEIGHT_ONLY) { + patterns_2.add(ctx); + } + if (failed(applyPatternsAndFoldGreedily(module, std::move(patterns_2))) || failed(verify(module))) { signalPassFailure(); diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/BUILD b/tensorflow/compiler/mlir/quantization/tensorflow/python/BUILD index 3b5a9d55f5f..538be57c88c 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/python/BUILD +++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/BUILD @@ -137,12 +137,12 @@ pytype_strict_library( ], deps = [ "//tensorflow/core:protos_all_py", - "//tensorflow/python:framework", - "//tensorflow/python:framework_ops", - "//tensorflow/python:variables", "//tensorflow/python/client:session", + "//tensorflow/python/framework", "//tensorflow/python/framework:dtypes", + "//tensorflow/python/framework:ops", "//tensorflow/python/lib/io:lib", + "//tensorflow/python/ops:variables", "//tensorflow/python/saved_model:builder", "//tensorflow/python/saved_model:constants", "//tensorflow/python/saved_model:loader", @@ -167,11 +167,11 @@ pytype_strict_library( 
"//tensorflow/compiler/mlir/quantization/tensorflow:exported_model_proto_py", "//tensorflow/compiler/mlir/quantization/tensorflow:quantization_options_proto_py", "//tensorflow/core:protos_all_py", - "//tensorflow/python:framework_ops", "//tensorflow/python:pywrap_tensorflow", "//tensorflow/python/client:session", "//tensorflow/python/eager:context", "//tensorflow/python/eager:wrap_function", + "//tensorflow/python/framework:ops", "//tensorflow/python/framework:tensor_conversion", "//tensorflow/python/lib/io:lib", "//tensorflow/python/platform:tf_logging", @@ -198,8 +198,8 @@ tf_py_test( "//tensorflow:tensorflow_py", "//tensorflow/compiler/mlir/quantization/tensorflow:quantization_options_proto_py", "//tensorflow/core:protos_all_py", - "//tensorflow/python:client_testlib", "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/platform:client_testlib", "//tensorflow/python/platform:tf_logging", "//tensorflow/python/saved_model:tag_constants", "//third_party/py/numpy", @@ -216,16 +216,6 @@ pytype_library( ":representative_dataset", "//tensorflow:tensorflow_py", "//tensorflow/core:protos_all_py", - "//tensorflow/python:array_ops", - "//tensorflow/python:array_ops_stack", - "//tensorflow/python:client_testlib", - "//tensorflow/python:io_ops", - "//tensorflow/python:lookup_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:nn_ops", - "//tensorflow/python:random_ops", - "//tensorflow/python:string_ops", - "//tensorflow/python:variables", "//tensorflow/python/client:session", "//tensorflow/python/eager:def_function", "//tensorflow/python/framework:dtypes", @@ -233,7 +223,17 @@ pytype_library( "//tensorflow/python/framework:tensor_spec", "//tensorflow/python/lib/io:lib", "//tensorflow/python/module", + "//tensorflow/python/ops:array_ops", + "//tensorflow/python/ops:array_ops_stack", + "//tensorflow/python/ops:io_ops", + "//tensorflow/python/ops:lookup_ops", + "//tensorflow/python/ops:math_ops", + "//tensorflow/python/ops:nn_ops", + 
"//tensorflow/python/ops:random_ops", + "//tensorflow/python/ops:string_ops", + "//tensorflow/python/ops:variables", "//tensorflow/python/ops/ragged:ragged_string_ops", + "//tensorflow/python/platform:client_testlib", "//tensorflow/python/saved_model:builder", "//tensorflow/python/saved_model:save", "//tensorflow/python/saved_model:signature_def_utils", @@ -253,7 +253,7 @@ tf_py_test( deps = [ ":quantize_model", "//tensorflow:tensorflow_py", - "//tensorflow/python:client_testlib", + "//tensorflow/python/platform:client_testlib", "//tensorflow/python/saved_model:tag_constants", "@absl_py//absl/testing:parameterized", ], diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/integration_test/quantize_model_test.py b/tensorflow/compiler/mlir/quantization/tensorflow/python/integration_test/quantize_model_test.py index 892dfde7c9a..a023c8e9148 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/python/integration_test/quantize_model_test.py +++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/integration_test/quantize_model_test.py @@ -198,6 +198,7 @@ class QuantizationOptionsTest(quantize_model_test_base.QuantizedModelTest): """ class SimpleModel(module.Module): + def __init__(self): self.filters = np.random.uniform(low=-1.0, high=1.0, size=(4, 3)).astype( 'f4' @@ -288,7 +289,9 @@ class QuantizationOptionsTest(quantize_model_test_base.QuantizedModelTest): self._input_saved_model_path, quantization_options=options ) - def test_per_channel_for_non_uniform_opset_raises_value_error(self): + def test_drq_per_channel_for_non_uniform_opset_raises_value_error( + self, + ): model = self.SimpleModel() saved_model_save.save(model, self._input_saved_model_path) @@ -385,6 +388,7 @@ class QuantizationOptionsTest(quantize_model_test_base.QuantizedModelTest): quantization_method=quant_opts_pb2.QuantizationMethod( experimental_method=_ExperimentalMethod.STATIC_RANGE ), + op_set=quant_opts_pb2.TF, force_graph_mode_calibration=True, ) @@ -916,8 +920,14 @@ 
class StaticRangeQuantizationTest(quantize_model_test_base.QuantizedModelTest): equation, shape_unknown, has_bias and not shape_unknown ) ) - model = self._create_einsum_model_with_fake_quant( - equation, y_shape, x_signature, y_signature, bias_shape, activation_fn + model = self._create_einsum_model( + equation, + y_shape, + x_signature, + y_signature, + bias_shape, + activation_fn, + is_qat_model=True, ) x = array_ops.constant( np.random.uniform(size=x_shape), dtype=dtypes.float32 @@ -1027,8 +1037,14 @@ class StaticRangeQuantizationTest(quantize_model_test_base.QuantizedModelTest): equation, shape_unknown, has_bias and not shape_unknown ) ) - model = self._create_einsum_model_with_fake_quant( - equation, y_shape, x_signature, y_signature, bias_shape, activation_fn + model = self._create_einsum_model( + equation, + y_shape, + x_signature, + y_signature, + bias_shape, + activation_fn, + is_qat_model=True, ) x = array_ops.constant( @@ -1098,13 +1114,14 @@ class StaticRangeQuantizationTest(quantize_model_test_base.QuantizedModelTest): self._prepare_sample_einsum_datashapes(equation) ) - model = self._create_einsum_model_with_fake_quant( + model = self._create_einsum_model( equation, y_shape, x_signature, y_signature, bias_shape=None, activation_fn=None, + is_qat_model=True, ) if use_kernel: @@ -1180,8 +1197,6 @@ class StaticRangeQuantizationTest(quantize_model_test_base.QuantizedModelTest): self._output_saved_model_path, self._input_saved_model_path, 0.5 ) - # TODO(b/244276332): Allow table initialization in TF2 eager mode. 
- @test_util.deprecated_graph_mode_only def test_qat_vocab_table_lookup_model(self): tags = {tag_constants.SERVING} signature_def_key = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY @@ -1204,7 +1219,8 @@ class StaticRangeQuantizationTest(quantize_model_test_base.QuantizedModelTest): quantization_options = quant_opts_pb2.QuantizationOptions( quantization_method=quant_opts_pb2.QuantizationMethod( experimental_method=_ExperimentalMethod.STATIC_RANGE - ) + ), + op_set=quant_opts_pb2.TF, ) signature_def_keys = [signature_def_key] @@ -1253,8 +1269,6 @@ class StaticRangeQuantizationTest(quantize_model_test_base.QuantizedModelTest): self.assertAllClose(lookup_val, [1.0, 2.0, 0.0]) - # TODO(b/244276332): Allow table initialization in TF2 eager mode. - @test_util.deprecated_graph_mode_only def test_qat_file_init_hash_table_lookup_model_tf1(self): tags = {tag_constants.SERVING} signature_def_key = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY @@ -1277,7 +1291,8 @@ class StaticRangeQuantizationTest(quantize_model_test_base.QuantizedModelTest): quantization_options = quant_opts_pb2.QuantizationOptions( quantization_method=quant_opts_pb2.QuantizationMethod( experimental_method=_ExperimentalMethod.STATIC_RANGE - ) + ), + op_set=quant_opts_pb2.TF, ) signature_def_keys = [signature_def_key] @@ -1390,7 +1405,8 @@ class StaticRangeQuantizationTest(quantize_model_test_base.QuantizedModelTest): quantization_options = quant_opts_pb2.QuantizationOptions( quantization_method=quant_opts_pb2.QuantizationMethod( experimental_method=_ExperimentalMethod.STATIC_RANGE - ) + ), + op_set=quant_opts_pb2.TF, ) converted_model = quantize_model.quantize( @@ -2044,7 +2060,8 @@ class StaticRangeQuantizationTest(quantize_model_test_base.QuantizedModelTest): quantization_options = quant_opts_pb2.QuantizationOptions( quantization_method=quant_opts_pb2.QuantizationMethod( experimental_method=_ExperimentalMethod.STATIC_RANGE - ) + ), + op_set=quant_opts_pb2.TF, ) converted_model = 
quantize_model.quantize( @@ -2135,7 +2152,6 @@ class StaticRangeQuantizationTest(quantize_model_test_base.QuantizedModelTest): def test_matmul_with_reshape_and_bias_ptq_model( self, input_shape, filter_shape, bias_size, activation_fn, use_biasadd ): - model = self._create_matmul_model( input_shape, filter_shape, @@ -2175,9 +2191,7 @@ class StaticRangeQuantizationTest(quantize_model_test_base.QuantizedModelTest): ) input_data = ops.convert_to_tensor( - rng.uniform(low=0.0, high=1.0, size=input_shape).astype( - np.float32 - ) + rng.uniform(low=0.0, high=1.0, size=input_shape).astype(np.float32) ) expected_outputs = model.matmul(input_data) @@ -2188,31 +2202,38 @@ class StaticRangeQuantizationTest(quantize_model_test_base.QuantizedModelTest): self.assertAllClose(expected_outputs, got_outputs, atol=0.05) @parameterized.parameters( - ('abc,cde->abde', (2, 2, 64), (64, 3, 3), (3, 3), quant_opts_pb2.XLA), - ('abc,dce->abde', (2, 2, 64), (3, 64, 3), (3, 3), quant_opts_pb2.XLA), + ('abc,cde->abde', quant_opts_pb2.XLA), + ('abc,dce->abde', quant_opts_pb2.XLA), ) def test_einsum_ptq_model( self, equation: str, - input_shape: Sequence[int], - weight_shape: Sequence[int], - bias_shape: Sequence[int], target_opset: quant_opts_pb2.OpSet, ): + _, y_shape, bias_shape, x_signature, y_signature = ( + self._prepare_sample_einsum_datashapes(equation, use_bias=True) + ) + model = self._create_einsum_model( - self._input_saved_model_path, equation, - input_shape, - weight_shape, + y_shape, + x_signature, + y_signature, bias_shape, activation_fn=nn_ops.relu, ) + signatures = { + 'serving_default': model.einsum_with_kernel.get_concrete_function(), + } + + saved_model_save.save(model, self._input_saved_model_path, signatures) + def data_gen() -> repr_dataset.RepresentativeDataset: - for _ in range(200): + for _ in range(4): yield { - 'input_tensor': ops.convert_to_tensor( - np.random.uniform(low=0.0, high=1.0, size=input_shape).astype( + 'x': ops.convert_to_tensor( + 
np.random.uniform(low=0.0, high=1.0, size=x_signature).astype( 'f4' ) ), @@ -2223,7 +2244,8 @@ class StaticRangeQuantizationTest(quantize_model_test_base.QuantizedModelTest): quantization_options = quant_opts_pb2.QuantizationOptions( quantization_method=quant_opts_pb2.QuantizationMethod( experimental_method=_ExperimentalMethod.STATIC_RANGE - ) + ), + op_set=quant_opts_pb2.TF, ) converted_model = quantize_model.quantize( @@ -2246,13 +2268,13 @@ class StaticRangeQuantizationTest(quantize_model_test_base.QuantizedModelTest): self.assertTrue(self._contains_quantized_function_call(output_graphdef)) input_data = ops.convert_to_tensor( - np.random.uniform(low=0.0, high=1.0, size=input_shape).astype('f4') + np.random.uniform(low=0.0, high=1.0, size=x_signature).astype('f4') ) - expected_outputs = model.einsum(input_data) + expected_outputs = model.einsum_with_kernel(input_data) got_outputs = converted_model.signatures['serving_default']( - input_tensor=ops.convert_to_tensor(input_data) + x=ops.convert_to_tensor(input_data) ) - self.assertAllClose(expected_outputs, got_outputs, atol=0.0608) + self.assertAllClose(expected_outputs, got_outputs, atol=0.097) # Check the converted model in the target opset. quantization_options = quant_opts_pb2.QuantizationOptions( @@ -2283,10 +2305,10 @@ class StaticRangeQuantizationTest(quantize_model_test_base.QuantizedModelTest): self.assertTrue(self._contains_op(output_graphdef, 'XlaDotV2')) new_outputs = converted_model.signatures['serving_default']( - input_tensor=ops.convert_to_tensor(input_data) + x=ops.convert_to_tensor(input_data) ) # The difference between TF and target path is expected to be small. 
- self.assertAllClose(new_outputs, got_outputs, atol=0.0666) + self.assertAllClose(new_outputs, got_outputs, atol=0.097) self.assertAllClose(new_outputs, expected_outputs, atol=0.057) @test_util.run_in_graph_and_eager_modes @@ -2363,8 +2385,8 @@ class StaticRangeQuantizationTest(quantize_model_test_base.QuantizedModelTest): _, y_shape, _, x_signature, y_signature = ( self._prepare_sample_einsum_datashapes('ab,bc->ac') ) - model = self._create_einsum_model_with_fake_quant( - 'ab,bc->ac', y_shape, x_signature, y_signature + model = self._create_einsum_model( + 'ab,bc->ac', y_shape, x_signature, y_signature, is_qat_model=True ) signatures = { @@ -2420,7 +2442,6 @@ class StaticRangeQuantizationTest(quantize_model_test_base.QuantizedModelTest): func.node_def, op_name='XlaDotV2', attr_name='', attr_val=None ) - @test_util.deprecated_graph_mode_only def test_matmul_ptq_model_with_unfreeze_constants(self): # Uses large weight to exceed the constant size threshold of 64KiB # (specified by `kDefaultConstantSizeThresholdInBytes`) for unfreezing. 
@@ -2440,6 +2461,7 @@ class StaticRangeQuantizationTest(quantize_model_test_base.QuantizedModelTest): quantization_method=quant_opts_pb2.QuantizationMethod( experimental_method=_ExperimentalMethod.STATIC_RANGE ), + op_set=quant_opts_pb2.TF, freeze_all_variables=quant_opts_pb2.FreezeAllVariables(enabled=False), ) @@ -2507,7 +2529,8 @@ class StaticRangeQuantizationTest(quantize_model_test_base.QuantizedModelTest): quantization_options = quant_opts_pb2.QuantizationOptions( quantization_method=quant_opts_pb2.QuantizationMethod( experimental_method=_ExperimentalMethod.STATIC_RANGE - ) + ), + op_set=quant_opts_pb2.TF, ) data_gen = self._create_data_generator( @@ -2549,7 +2572,8 @@ class StaticRangeQuantizationTest(quantize_model_test_base.QuantizedModelTest): quantization_options = quant_opts_pb2.QuantizationOptions( quantization_method=quant_opts_pb2.QuantizationMethod( experimental_method=_ExperimentalMethod.STATIC_RANGE - ) + ), + op_set=quant_opts_pb2.TF, ) tags = {tag_constants.SERVING} @@ -2590,7 +2614,8 @@ class StaticRangeQuantizationTest(quantize_model_test_base.QuantizedModelTest): quantization_options = quant_opts_pb2.QuantizationOptions( quantization_method=quant_opts_pb2.QuantizationMethod( experimental_method=_ExperimentalMethod.STATIC_RANGE - ) + ), + op_set=quant_opts_pb2.TF, ) tags = {tag_constants.SERVING} @@ -2634,7 +2659,8 @@ class StaticRangeQuantizationTest(quantize_model_test_base.QuantizedModelTest): quantization_options = quant_opts_pb2.QuantizationOptions( quantization_method=quant_opts_pb2.QuantizationMethod( experimental_method=_ExperimentalMethod.STATIC_RANGE - ) + ), + op_set=quant_opts_pb2.TF, ) tags = {tag_constants.SERVING} @@ -2676,7 +2702,8 @@ class StaticRangeQuantizationTest(quantize_model_test_base.QuantizedModelTest): quantization_options = quant_opts_pb2.QuantizationOptions( quantization_method=quant_opts_pb2.QuantizationMethod( experimental_method=_ExperimentalMethod.STATIC_RANGE - ) + ), + op_set=quant_opts_pb2.TF, ) tags = 
{tag_constants.SERVING} signature_def_keys = [signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY] @@ -2726,6 +2753,53 @@ class StaticRangeQuantizationTest(quantize_model_test_base.QuantizedModelTest): output_graphdef = output_loader.get_meta_graph_def_from_tags(tags).graph_def self.assertTrue(self._contains_quantized_function_call(output_graphdef)) + def test_model_ptq_preserving_assets_extra(self): + self._create_matmul_model( + input_shape=(1, 1024), + weight_shape=(1024, 3), + saved_model_path=self._input_saved_model_path, + ) + asset_filename = 'assets.extra/tf_serving_warmup_requests' + file_io.create_dir_v2( + os.path.join(self._input_saved_model_path, 'assets.extra') + ) + file_io.write_string_to_file( + filename=os.path.join(self._input_saved_model_path, asset_filename), + file_content='Test content', + ) + + quantization_options = quant_opts_pb2.QuantizationOptions( + quantization_method=quant_opts_pb2.QuantizationMethod( + experimental_method=_ExperimentalMethod.STATIC_RANGE + ), + op_set=quant_opts_pb2.TF, + ) + tags = {tag_constants.SERVING} + + # Use plain python lists as representative samples. + representative_dataset = [ + { + 'input_tensor': [[i * 0.1 for i in range(1024)]], + } + for _ in range(4) + ] + + converted_model = quantize_model.quantize( + self._input_saved_model_path, + ['serving_default'], + tags, + self._output_saved_model_path, + quantization_options=quantization_options, + representative_dataset=representative_dataset, + ) + self.assertIsNotNone(converted_model) + # Check if the assets.extra file exists in the output model. + self.assertTrue( + file_io.file_exists_v2( + os.path.join(self._output_saved_model_path, asset_filename) + ) + ) + # tf.data.Dataset is as an Iterable (thus can be used as representative # dataset) only in TF2 (eager mode). 
@test_util.run_v2_only @@ -2739,7 +2813,8 @@ class StaticRangeQuantizationTest(quantize_model_test_base.QuantizedModelTest): quantization_options = quant_opts_pb2.QuantizationOptions( quantization_method=quant_opts_pb2.QuantizationMethod( experimental_method=_ExperimentalMethod.STATIC_RANGE - ) + ), + op_set=quant_opts_pb2.TF, ) tags = {tag_constants.SERVING} @@ -2897,7 +2972,8 @@ class StaticRangeQuantizationTest(quantize_model_test_base.QuantizedModelTest): quantization_options = quant_opts_pb2.QuantizationOptions( quantization_method=quant_opts_pb2.QuantizationMethod( experimental_method=_ExperimentalMethod.STATIC_RANGE - ) + ), + op_set=quant_opts_pb2.TF, ) with self.assertLogs(level='WARN') as warning_logs: @@ -2964,7 +3040,8 @@ class StaticRangeQuantizationTest(quantize_model_test_base.QuantizedModelTest): quantization_options = quant_opts_pb2.QuantizationOptions( quantization_method=quant_opts_pb2.QuantizationMethod( experimental_method=_ExperimentalMethod.STATIC_RANGE - ) + ), + op_set=quant_opts_pb2.TF, ) def data_gen_sig1() -> repr_dataset.RepresentativeDataset: @@ -3074,7 +3151,8 @@ class StaticRangeQuantizationTest(quantize_model_test_base.QuantizedModelTest): quantization_options = quant_opts_pb2.QuantizationOptions( quantization_method=quant_opts_pb2.QuantizationMethod( experimental_method=_ExperimentalMethod.STATIC_RANGE - ) + ), + op_set=quant_opts_pb2.TF, ) data_gen = self._create_data_generator( @@ -3129,7 +3207,8 @@ class StaticRangeQuantizationTest(quantize_model_test_base.QuantizedModelTest): quantization_options = quant_opts_pb2.QuantizationOptions( quantization_method=quant_opts_pb2.QuantizationMethod( experimental_method=_ExperimentalMethod.STATIC_RANGE - ) + ), + op_set=quant_opts_pb2.TF, ) data_gen = self._create_data_generator( @@ -3160,7 +3239,6 @@ class StaticRangeQuantizationTest(quantize_model_test_base.QuantizedModelTest): output_graphdef = output_loader.get_meta_graph_def_from_tags(tags).graph_def 
self.assertTrue(self._contains_quantized_function_call(output_graphdef)) - @test_util.deprecated_graph_mode_only def test_ptq_model_with_variable_tf1_saved_model_unfreeze_constants(self): signature_key = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY tags = {tag_constants.SERVING} @@ -3184,6 +3262,7 @@ class StaticRangeQuantizationTest(quantize_model_test_base.QuantizedModelTest): quantization_method=quant_opts_pb2.QuantizationMethod( experimental_method=_ExperimentalMethod.STATIC_RANGE ), + op_set=quant_opts_pb2.TF, freeze_all_variables=quant_opts_pb2.FreezeAllVariables(enabled=False), ) @@ -3257,7 +3336,8 @@ class StaticRangeQuantizationTest(quantize_model_test_base.QuantizedModelTest): quantization_options = quant_opts_pb2.QuantizationOptions( quantization_method=quant_opts_pb2.QuantizationMethod( experimental_method=_ExperimentalMethod.STATIC_RANGE - ) + ), + op_set=quant_opts_pb2.TF, ) data_gen = self._create_data_generator( @@ -3316,7 +3396,8 @@ class StaticRangeQuantizationTest(quantize_model_test_base.QuantizedModelTest): quantization_options = quant_opts_pb2.QuantizationOptions( quantization_method=quant_opts_pb2.QuantizationMethod( experimental_method=_ExperimentalMethod.STATIC_RANGE - ) + ), + op_set=quant_opts_pb2.TF, ) def data_gen_sig1() -> repr_dataset.RepresentativeDataset: @@ -3429,7 +3510,8 @@ class StaticRangeQuantizationTest(quantize_model_test_base.QuantizedModelTest): quantization_options = quant_opts_pb2.QuantizationOptions( quantization_method=quant_opts_pb2.QuantizationMethod( experimental_method=_ExperimentalMethod.STATIC_RANGE - ) + ), + op_set=quant_opts_pb2.TF, ) data_gen = self._create_data_generator( @@ -3496,8 +3578,6 @@ class StaticRangeQuantizationTest(quantize_model_test_base.QuantizedModelTest): representative_dataset=data_gen, ) - # TODO(b/244276332): Allow table initialization in TF2 eager mode. 
- @test_util.deprecated_graph_mode_only def test_ptq_vocab_table_lookup_model(self): tags = {tag_constants.SERVING} signature_def_key = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY @@ -3520,7 +3600,8 @@ class StaticRangeQuantizationTest(quantize_model_test_base.QuantizedModelTest): quantization_options = quant_opts_pb2.QuantizationOptions( quantization_method=quant_opts_pb2.QuantizationMethod( experimental_method=_ExperimentalMethod.STATIC_RANGE - ) + ), + op_set=quant_opts_pb2.TF, ) signature_def_keys = [signature_def_key] @@ -3569,7 +3650,6 @@ class StaticRangeQuantizationTest(quantize_model_test_base.QuantizedModelTest): self.assertAllClose(lookup_val, [1.0, 2.0, 0.0]) - @test_util.deprecated_graph_mode_only def test_ptq_file_init_hash_table_lookup_model(self): tags = {tag_constants.SERVING} signature_def_key = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY @@ -3592,7 +3672,8 @@ class StaticRangeQuantizationTest(quantize_model_test_base.QuantizedModelTest): quantization_options = quant_opts_pb2.QuantizationOptions( quantization_method=quant_opts_pb2.QuantizationMethod( experimental_method=_ExperimentalMethod.STATIC_RANGE - ) + ), + op_set=quant_opts_pb2.TF, ) signature_def_keys = [signature_def_key] @@ -3871,7 +3952,8 @@ class StaticRangeQuantizationTest(quantize_model_test_base.QuantizedModelTest): quantization_options = quant_opts_pb2.QuantizationOptions( quantization_method=quant_opts_pb2.QuantizationMethod( experimental_method=_ExperimentalMethod.STATIC_RANGE - ) + ), + op_set=quant_opts_pb2.TF, ) data_gen = self._create_data_generator( @@ -3917,6 +3999,78 @@ class DynamicRangeQuantizationTest(quantize_model_test_base.QuantizedModelTest): eager mode (default in TF2) to ensure support for when TF2 is disabled. 
""" + @parameterized.parameters( + (True, quant_opts_pb2.XLA), + (False, quant_opts_pb2.XLA), + (True, quant_opts_pb2.UNIFORM_QUANTIZED), + (False, quant_opts_pb2.UNIFORM_QUANTIZED), + ) + @test_util.run_in_graph_and_eager_modes + def test_einsum_model( + self, + constant_y_operand: bool, + target_opset: quant_opts_pb2.OpSet, + ): + equation = 'abc,cde->abde' + _, y_shape, bias_shape, x_signature, y_signature = ( + self._prepare_sample_einsum_datashapes(equation, use_bias=True) + ) + + model = self._create_einsum_model( + equation, + y_shape, + x_signature, + y_signature, + bias_shape, + activation_fn=nn_ops.relu, + ) + + if constant_y_operand: + signatures = { + 'serving_default': model.einsum_with_kernel.get_concrete_function(), + } + else: + signatures = { + 'serving_default': ( + model.einsum_without_kernel.get_concrete_function() + ), + } + + saved_model_save.save(model, self._input_saved_model_path, signatures) + + tags = {tag_constants.SERVING} + quantization_options = quant_opts_pb2.QuantizationOptions( + quantization_method=quant_opts_pb2.QuantizationMethod( + experimental_method=_ExperimentalMethod.DYNAMIC_RANGE + ), + op_set=target_opset, + ) + + converted_model = quantize_model.quantize( + self._input_saved_model_path, + ['serving_default'], + tags, + self._output_saved_model_path, + quantization_options, + ) + self.assertIsNotNone(converted_model) + self.assertCountEqual( + converted_model.signatures._signatures.keys(), {'serving_default'} + ) + + output_loader = saved_model_loader.SavedModelLoader( + self._output_saved_model_path + ) + output_graphdef = output_loader.get_meta_graph_def_from_tags(tags).graph_def + + # TODO(b/286489783): Support Einsum + if target_opset == quant_opts_pb2.UNIFORM_QUANTIZED: + self.assertFalse(self._contains_op(output_graphdef, 'XlaDotV2')) + self.assertTrue(self._contains_op(output_graphdef, 'BatchMatMulV2')) + else: + self.assertFalse(self._contains_op(output_graphdef, 'XlaDotV2')) + 
self.assertTrue(self._contains_op(output_graphdef, 'Einsum')) + @parameterized.named_parameters( ('to_tf_per_tensor', quant_opts_pb2.TF, False), ('to_xla_per_tensor', quant_opts_pb2.XLA, False), @@ -4555,7 +4709,8 @@ class DynamicRangeQuantizationTest(quantize_model_test_base.QuantizedModelTest): quantization_options = quant_opts_pb2.QuantizationOptions( quantization_method=quant_opts_pb2.QuantizationMethod( experimental_method=_ExperimentalMethod.DYNAMIC_RANGE - ) + ), + op_set=quant_opts_pb2.TF, ) converted_model = quantize_model.quantize( @@ -4578,8 +4733,6 @@ class DynamicRangeQuantizationTest(quantize_model_test_base.QuantizedModelTest): output_graphdef = output_loader.get_meta_graph_def_from_tags(tags).graph_def self.assertTrue(self._contains_quantized_function_call(output_graphdef)) - # TODO(b/244276332): Allow table initialization in TF2 eager mode. - @test_util.deprecated_graph_mode_only def test_table_initialized_when_model_has_table_tf1(self): tags = {tag_constants.SERVING} signature_def_key = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY @@ -4637,7 +4790,6 @@ class DynamicRangeQuantizationTest(quantize_model_test_base.QuantizedModelTest): self.assertAllClose(lookup_val, [1.0, 2.0, 0.0]) - @test_util.deprecated_graph_mode_only def test_file_init_hash_table_lookup_model(self): tags = {tag_constants.SERVING} signature_def_key = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY @@ -4700,6 +4852,65 @@ class WeightOnlyQuantizationTest(quantize_model_test_base.QuantizedModelTest): (default in TF2) to ensure support for when TF2 is disabled. """ + @test_util.run_in_graph_and_eager_modes + def test_einsum_model( + self, + ): + equation = 'abc,cde->abde' + _, y_shape, bias_shape, x_signature, y_signature = ( + self._prepare_sample_einsum_datashapes(equation, use_bias=True) + ) + + model = self._create_einsum_model( + equation, + y_shape, + x_signature, + y_signature, + bias_shape, + activation_fn=nn_ops.relu, + ) + + # Use constant y operand. 
+ signatures = { + 'serving_default': model.einsum_with_kernel.get_concrete_function(), + } + + saved_model_save.save(model, self._input_saved_model_path, signatures) + + tags = {tag_constants.SERVING} + quantization_options = quant_opts_pb2.QuantizationOptions( + quantization_method=quant_opts_pb2.QuantizationMethod( + experimental_method=_ExperimentalMethod.WEIGHT_ONLY + ), + op_set=quant_opts_pb2.XLA, + ) + + converted_model = quantize_model.quantize( + self._input_saved_model_path, + ['serving_default'], + tags, + self._output_saved_model_path, + quantization_options, + ) + self.assertIsNotNone(converted_model) + self.assertCountEqual( + converted_model.signatures._signatures.keys(), {'serving_default'} + ) + + output_loader = saved_model_loader.SavedModelLoader( + self._output_saved_model_path + ) + output_graphdef = output_loader.get_meta_graph_def_from_tags(tags).graph_def + + # TODO(b/286489783): Support Einsum for Weight only quantization + # Due to other meta data, the compression is not exactly 1/4. 
+ self.assertFalse(self._contains_op(output_graphdef, 'XlaDotV2')) + self.assertSizeRatioLessThan( + self._output_saved_model_path, + self._input_saved_model_path, + threshold=0.5, + ) + @parameterized.named_parameters( # TODO(b/269421880): Enable legacy weight-only scheme with the uniform # quantized opset @@ -4756,17 +4967,22 @@ class WeightOnlyQuantizationTest(quantize_model_test_base.QuantizedModelTest): @parameterized.named_parameters( # TODO(b/269421880): Enable legacy weight-only scheme with the uniform # quantized opset - ('to_xla_per_tensor', quant_opts_pb2.XLA, False), + ('to_xla_per_tensor', quant_opts_pb2.XLA, False, False), + ('to_xla_per_channel', quant_opts_pb2.XLA, True, False), + ('to_xla_per_channel_legacy', quant_opts_pb2.XLA, True, True), ) @test_util.run_in_graph_and_eager_modes def test_conv_model( self, target_opset: quant_opts_pb2.OpSet, enable_per_channel_quantization: bool, + enable_legacy_weight_only: bool, ): + input_shape = (1, 3, 4, 512) + filter_shape = (2, 3, 512, 2) model = self._create_conv2d_model( - input_shape=(1, 3, 4, 512), - filter_shape=(2, 3, 512, 2), + input_shape=input_shape, + filter_shape=filter_shape, has_bias=False, has_batch_norm=False, activation_fn=nn_ops.relu6, @@ -4781,6 +4997,7 @@ class WeightOnlyQuantizationTest(quantize_model_test_base.QuantizedModelTest): ), op_set=target_opset, enable_per_channel_quantization=enable_per_channel_quantization, + enable_legacy_weight_only=enable_legacy_weight_only, ) converted_model = quantize_model.quantize( @@ -4801,30 +5018,68 @@ class WeightOnlyQuantizationTest(quantize_model_test_base.QuantizedModelTest): ) output_graphdef = output_loader.get_meta_graph_def_from_tags(tags).graph_def + if not enable_legacy_weight_only: + self.assertTrue(self._contains_op(output_graphdef, 'XlaConvV2')) + # Due to other meta data, the compression is not exactly 1/4. 
- self.assertTrue(self._contains_op(output_graphdef, 'XlaConvV2')) self.assertSizeRatioLessThan( self._output_saved_model_path, self._input_saved_model_path, threshold=0.3, ) + if enable_per_channel_quantization: + per_channel_size_attr = attr_value_pb2.AttrValue( + list=attr_value_pb2.AttrValue.ListValue( + shape=[ + tensor_shape_pb2.TensorShapeProto( + dim=[ + tensor_shape_pb2.TensorShapeProto.Dim( + size=filter_shape[-1] + ) + ] + ) + ] + ) + ) + self.assertTrue( + self._contains_op( + output_graphdef, 'Const', '_output_shapes', per_channel_size_attr + ) + ) + + input_tensor = array_ops.constant( + np.random.uniform(low=0, high=0.1, size=input_shape), + dtype=dtypes.float32, + ) + original_output = model.conv(input_tensor) + quantized_output = converted_model.signatures['serving_default']( + input_tensor + ) + + threshold = 0.015 if enable_per_channel_quantization else 0.02 + self.assertAllClose(original_output, quantized_output, atol=threshold) + @parameterized.named_parameters( # TODO(b/269421880): Enable legacy weight-only scheme with the uniform # quantized opset - ('to_xla_per_tensor', quant_opts_pb2.XLA, False), + ('to_xla_per_tensor', quant_opts_pb2.XLA, False, False), + ('to_xla_per_channel', quant_opts_pb2.XLA, True, False), + ('to_xla_per_channel_legacy', quant_opts_pb2.XLA, True, True), ) @test_util.run_in_graph_and_eager_modes def test_depthwise_conv2d_model( self, target_opset: quant_opts_pb2.OpSet, enable_per_channel_quantization: bool, + enable_legacy_weight_only: bool, ): + input_shape = (1, 3, 4, 512) filter_shape = (2, 3, 512, 2) strides = (1, 2, 2, 1) model = self._create_depthwise_conv2d_model( - input_shape=(1, 3, 4, 512), filter_shape=filter_shape, strides=strides + input_shape=input_shape, filter_shape=filter_shape, strides=strides ) saved_model_save.save(model, self._input_saved_model_path) @@ -4837,6 +5092,7 @@ class WeightOnlyQuantizationTest(quantize_model_test_base.QuantizedModelTest): ), op_set=target_opset, 
enable_per_channel_quantization=enable_per_channel_quantization, + enable_legacy_weight_only=enable_legacy_weight_only, ) converted_model = quantize_model.quantize( @@ -4858,13 +5114,48 @@ class WeightOnlyQuantizationTest(quantize_model_test_base.QuantizedModelTest): output_graphdef = output_loader.get_meta_graph_def_from_tags(tags).graph_def # Due to other meta data, the compression is not exactly 1/4. - self.assertTrue(self._contains_op(output_graphdef, 'XlaConvV2')) + if not enable_legacy_weight_only: + self.assertTrue(self._contains_op(output_graphdef, 'XlaConvV2')) + + size_threshold = 0.5 if enable_per_channel_quantization else 0.3 self.assertSizeRatioLessThan( self._output_saved_model_path, self._input_saved_model_path, - threshold=0.3, + threshold=size_threshold, ) + if enable_per_channel_quantization: + per_channel_size_attr = attr_value_pb2.AttrValue( + list=attr_value_pb2.AttrValue.ListValue( + shape=[ + tensor_shape_pb2.TensorShapeProto( + dim=[ + tensor_shape_pb2.TensorShapeProto.Dim( + size=filter_shape[2] * filter_shape[3] + ), + ] + ) + ] + ) + ) + self.assertTrue( + self._contains_op( + output_graphdef, 'Const', '_output_shapes', per_channel_size_attr + ) + ) + + input_tensor = array_ops.constant( + np.random.uniform(low=-0.1, high=0.1, size=input_shape), + dtype=dtypes.float32, + ) + original_output = model.depthwise_conv(input_tensor) + quantized_output = converted_model.signatures['serving_default']( + input_tensor + ) + + threshold = 0.68 if enable_per_channel_quantization else 1.3 + self.assertAllClose(original_output, quantized_output, atol=threshold) + @parameterized.named_parameters( ('to_tf_use_constant', quant_opts_pb2.TF, False), ('to_xla_use_constant', quant_opts_pb2.XLA, False), diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/integration_test/quantize_model_test_base.py b/tensorflow/compiler/mlir/quantization/tensorflow/python/integration_test/quantize_model_test_base.py index f2593d336f7..d7dc023e1dc 100644 --- 
a/tensorflow/compiler/mlir/quantization/tensorflow/python/integration_test/quantize_model_test_base.py +++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/integration_test/quantize_model_test_base.py @@ -1114,11 +1114,20 @@ class QuantizedModelTest(test.TestCase, parameterized.TestCase): """A simple model with a single depthwise conv2d, bias and relu.""" def __init__(self): - self.filters = np.random.uniform( - low=-10, high=10, size=filter_shape - ).astype('f4') - self.out_channel_size = filter_shape[2] * filter_shape[3] + + # This ensures filters will have different value range per out channel + self.filters = np.stack( + [ + np.random.uniform( + low=-(i + 1), high=(i + 1), size=filter_shape[:-2] + ).astype('f4') + for i in range(self.out_channel_size) + ], + axis=-1, + ) + self.filters = self.filters.reshape(filter_shape) + self.bias = np.random.uniform( low=0, high=10, size=(self.out_channel_size) ).astype('f4') @@ -1178,11 +1187,19 @@ class QuantizedModelTest(test.TestCase, parameterized.TestCase): """A simple model with a single conv2d, bias and relu.""" def __init__(self): - self.filters = np.random.uniform( - low=-10, high=10, size=filter_shape - ).astype('f4') - self.out_channel_size = filter_shape[-1] + + # This ensures filters will have different value range per out channel + self.filters = np.stack( + [ + np.random.uniform( + low=-(i + 1), high=(i + 1), size=filter_shape[:-1] + ).astype('f4') + for i in range(self.out_channel_size) + ], + axis=-1, + ) + self.bias = np.random.uniform( low=0, high=10, size=(self.out_channel_size) ).astype('f4') @@ -1313,7 +1330,7 @@ class QuantizedModelTest(test.TestCase, parameterized.TestCase): # Verify that when bias_size is not None, has_bias should be True. 
# And if bias_size is None, has_bias should be False using XNOR - assert (not ((bias_size is not None) ^ has_bias)) + assert not ((bias_size is not None) ^ has_bias) # Verify that bias size is correct if bias_size: @@ -1332,82 +1349,6 @@ class QuantizedModelTest(test.TestCase, parameterized.TestCase): ) return model - def _create_einsum_model( - self, - saved_model_path: str, - equation: str, - input_shape: Sequence[int], - weight_shape: Sequence[int], - bias_shape: Optional[Sequence[int]] = None, - activation_fn: Optional[ops.Operation] = None, - ) -> module.Module: - class EinsumModel(module.Module): - """A simple model with a single einsum. - - Bias and activation function are optional. - """ - - def __init__( - self, - equation: str, - weight_shape: Sequence[int], - bias_shape: Optional[Sequence[int]] = None, - activation_fn: Optional[ops.Operation] = None, - ) -> None: - """Initializes a EinsumModel. - - Args: - equation: a string describing the contraction. - weight_shape: Shape of the weight tensor. - bias_shape: Shape of the bias. This is not always 1D so Einsum ops - usually use Add op instead of BiasAdd. - activation_fn: The activation function to be used. No activation - function if None. - """ - self.equation = equation - self.activation_fn = activation_fn - self.weight = np.random.uniform(low=-1.0, high=1.0, size=weight_shape) - self.bias = ( - np.random.uniform(low=-1.0, high=1.0, size=bias_shape) - if bias_shape is not None - else None - ) - - @def_function.function - def einsum(self, input_tensor: core.Tensor) -> Mapping[str, core.Tensor]: - """Evaluates the Einstein summation convention. - - Depending on self.has_bias and self.activation_fn, it may add a bias - term or go through the activaction function. - - Args: - input_tensor: Input tensor to einsum with the weight. - - Returns: - A map of: output key -> output result. 
- """ - out = tensorflow.einsum(self.equation, input_tensor, self.weight) - - if self.bias is not None: - out = out + self.bias - - if self.activation_fn is not None: - out = self.activation_fn(out) - - return {'output': out} - - model = EinsumModel(equation, weight_shape, bias_shape, activation_fn) - saved_model_save.save( - model, - saved_model_path, - signatures=model.einsum.get_concrete_function( - tensor_spec.TensorSpec( - shape=input_shape, dtype=dtypes.float32, name='input_tensor' - ) - ), - ) - return model - # Prepares sample einsum input data shapes. # This function returns: # 1. Shape for input 1 @@ -1435,7 +1376,7 @@ class QuantizedModelTest(test.TestCase, parameterized.TestCase): out_labels = equation[arrow_pos + 1 :] # 2. Create sample shapes. - label_to_size = {'a': 2, 'b': 3, 'c': 4, 'd': 5, 'e': 6} + label_to_size = {'a': 4, 'b': 32, 'c': 64, 'd': 128, 'e': 8} x_shape = [label_to_size.get(x_label) for x_label in x_labels] y_shape = [label_to_size.get(y_label) for y_label in y_labels] bias_shape = None @@ -1460,7 +1401,7 @@ class QuantizedModelTest(test.TestCase, parameterized.TestCase): ] return x_shape, y_shape, bias_shape, x_signature, y_signature - def _create_einsum_model_with_fake_quant( + def _create_einsum_model( self, equation: str, y_shape: Sequence[int], @@ -1468,9 +1409,10 @@ class QuantizedModelTest(test.TestCase, parameterized.TestCase): y_signature: Sequence[Optional[int]], bias_shape: Optional[Sequence[int]] = None, activation_fn: Optional[ops.Operation] = None, + is_qat_model: bool = False, ) -> module.Module: class EinsumModel(module.Module): - """Einsum class with fakequants.""" + """Einsum class.""" def __init__(self): self._bias = None @@ -1509,33 +1451,35 @@ class QuantizedModelTest(test.TestCase, parameterized.TestCase): return self._einsum(x, y) def _einsum(self, x, y): - x = array_ops.fake_quant_with_min_max_vars( - x, - min=ops.convert_to_tensor(self._min[0]), - max=ops.convert_to_tensor(self._max[0]), - num_bits=8, - 
narrow_range=False, - ) - y = array_ops.fake_quant_with_min_max_vars( - y, - min=ops.convert_to_tensor(self._min[1]), - max=ops.convert_to_tensor(self._max[1]), - num_bits=8, - narrow_range=False, - ) + if is_qat_model: + x = array_ops.fake_quant_with_min_max_vars( + x, + min=ops.convert_to_tensor(self._min[0]), + max=ops.convert_to_tensor(self._max[0]), + num_bits=8, + narrow_range=False, + ) + y = array_ops.fake_quant_with_min_max_vars( + y, + min=ops.convert_to_tensor(self._min[1]), + max=ops.convert_to_tensor(self._max[1]), + num_bits=8, + narrow_range=False, + ) out = tensorflow.einsum(equation, x, y) if self._bias is not None: out = nn_ops.bias_add(out, self._bias) if activation_fn is not None: out = activation_fn(out) - out = array_ops.fake_quant_with_min_max_vars( - out, - min=ops.convert_to_tensor(self._min[2]), - max=ops.convert_to_tensor(self._max[2]), - num_bits=8, - narrow_range=False, - ) + if is_qat_model: + out = array_ops.fake_quant_with_min_max_vars( + out, + min=ops.convert_to_tensor(self._min[2]), + max=ops.convert_to_tensor(self._max[2]), + num_bits=8, + narrow_range=False, + ) return {'output': out} return EinsumModel() diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.cc b/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.cc index 0a3f2e95c36..d9a5aaf4e31 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.cc @@ -119,11 +119,14 @@ void AddExportPasses(const bool duplicate_shape_determining_constants, } pm.addPass(mlir::quant::CreateInsertMainFunctionPass()); + pm.addPass(mlir::quant::CreateLiftHashTableOpsAsArgsPass()); pm.addNestedPass( mlir::CreateFunctionalToExecutorDialectConversionPass()); pm.addPass(mlir::CreateBreakUpIslandsPass()); pm.addPass(mlir::quant::CreateMergeInitializerFunctionOpsToMainPass()); pm.addPass(mlir::quant::CreateMergeSaveFunctionOpsToMainPass()); + 
pm.addNestedPass( + mlir::quant::CreateMergeDuplicateResourceOpsPass()); // Used to clean up the "tf._noinliner" attribute that is previously used to // prevent certain functions from being inlined (see @@ -384,7 +387,7 @@ absl::Status UnfreezeConstantsAndSaveVariables( !create_dir_status.ok()) { LOG(ERROR) << "Failed to create checkpoint directory at: " << checkpoint_dir; - return tsl::ToAbslStatus(create_dir_status); + return create_dir_status; } TF_ASSIGN_OR_RETURN(const auto _, diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.h b/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.h index f17f20df4b6..2344c108016 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.h +++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.h @@ -47,7 +47,7 @@ absl::StatusOr QuantizeQatModel( absl::string_view saved_model_path, const std::vector& signature_keys, const std::unordered_set& tags, - const QuantizationOptions& quant_opts, + const QuantizationOptions& quantization_options, const absl::flat_hash_map& function_aliases); // Apply post-training dynamic range quantization to the model. 
@@ -55,20 +55,20 @@ absl::StatusOr QuantizePtqDynamicRange( absl::string_view saved_model_path, const std::vector& signature_keys, const std::unordered_set& tags, - const QuantizationOptions& quant_opts); + const QuantizationOptions& quantization_options); absl::StatusOr QuantizePtqModelPreCalibration( absl::string_view saved_model_path, - const std::vector& exported_names, + const std::vector& signature_keys, const std::unordered_set& tags, - const QuantizationOptions& quant_opts, + const QuantizationOptions& quantization_options, const absl::flat_hash_map& function_aliases); absl::StatusOr QuantizePtqModelPostCalibration( absl::string_view saved_model_path, const std::vector& signature_keys, const std::unordered_set& tags, - const QuantizationOptions& quant_opts, + const QuantizationOptions& quantization_options, const absl::flat_hash_map& function_aliases); } // namespace quantization diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.py b/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.py index 758f99b62a0..53dd34cfc30 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.py +++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.py @@ -17,8 +17,8 @@ import collections.abc import tempfile from typing import Callable, Collection, Dict, Mapping, Optional, Sequence import uuid -from absl import logging +from absl import logging import numpy as np from tensorflow.compiler.mlir.quantization.tensorflow import exported_model_pb2 @@ -51,11 +51,12 @@ _ExperimentalMethod = quant_opts_pb2.QuantizationMethod.ExperimentalMethod _SignatureDefMap = Mapping[str, meta_graph_pb2.SignatureDef] # Default minimum number of elements in the weights for them to be quantized -# during dynamic range quantization (DRQ). +# during dynamic range quantization (DRQ) and weight-only quantization. 
_DYNAMIC_RANGE_DEFAULT_MIN_NUM_ELEMENTS_FOR_WEIGHTS = 1024 # Name of the saved model assets directory. _ASSETS_DIR = 'assets' +_ASSETS_EXTRA_DIR = 'assets.extra' def _is_qat_saved_model(saved_model_path: str): @@ -530,26 +531,27 @@ def _copy_assets(src_path: str, dst_path: str) -> None: src_path: Source saved model directory. dst_path: Destination saved model directory. This directory must exist. """ - src_assets_path = file_io.join(src_path, _ASSETS_DIR) - if not file_io.file_exists_v2(src_assets_path): - # Do nothing if the source assets path does not exist. - return + for assets_dir_name in [_ASSETS_DIR, _ASSETS_EXTRA_DIR]: + src_assets_path = file_io.join(src_path, assets_dir_name) + if not file_io.file_exists_v2(src_assets_path): + # Do nothing if the source assets path does not exist. + continue - dst_assets_path = file_io.join(dst_path, _ASSETS_DIR) - file_io.create_dir_v2(dst_assets_path) + dst_assets_path = file_io.join(dst_path, assets_dir_name) + file_io.create_dir_v2(dst_assets_path) - for curr_dir, _, files in file_io.walk_v2(src_assets_path): - for asset_file_name in files: - src_asset_file = file_io.join(curr_dir, asset_file_name) + for curr_dir, _, files in file_io.walk_v2(src_assets_path): + for asset_file_name in files: + src_asset_file = file_io.join(curr_dir, asset_file_name) - # Construct the destination assets file path. - curr_dst_dir = curr_dir.replace(src_assets_path, dst_assets_path) - dst_asset_file = file_io.join(curr_dst_dir, asset_file_name) + # Construct the destination assets file path. 
+ curr_dst_dir = curr_dir.replace(src_assets_path, dst_assets_path) + dst_asset_file = file_io.join(curr_dst_dir, asset_file_name) - file_io.copy_v2(src_asset_file, dst_asset_file) - logging.info( - 'Copied asset file: %s -> %s', src_asset_file, dst_asset_file - ) + file_io.copy_v2(src_asset_file, dst_asset_file) + logging.info( + 'Copied asset file: %s -> %s', src_asset_file, dst_asset_file + ) def _run_static_range_qat( @@ -1017,17 +1019,20 @@ def _populate_quantization_options_default_values( quantization_options: An instance of QuantizationOptions. """ if quantization_options.op_set == quant_opts_pb2.OpSet.OP_SET_UNSPECIFIED: - quantization_options.op_set = quant_opts_pb2.OpSet.TF + quantization_options.op_set = quant_opts_pb2.OpSet.XLA if not quantization_options.HasField('freeze_all_variables'): quantization_options.freeze_all_variables.enabled = True - if quantization_options.enable_per_channel_quantization and ( - quantization_options.op_set != quant_opts_pb2.OpSet.UNIFORM_QUANTIZED + # TODO(b/281595329): Implement static range quantization per-channel support + if quantization_options.enable_per_channel_quantization and not ( + quantization_options.op_set == quant_opts_pb2.OpSet.UNIFORM_QUANTIZED + or quantization_options.quantization_method.experimental_method + == _ExperimentalMethod.WEIGHT_ONLY ): raise ValueError( 'Currently, per-channel quantization is supported for Uniform ' - 'Quantized opset only.' + 'Quantized opset and Weight-only.' ) if ( diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.proto b/tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.proto index 6c5c520bffc..701f01a1da2 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.proto +++ b/tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.proto @@ -141,7 +141,7 @@ message QuantizationOptions { // units that are not specified in unit-wise configurations. 
QuantizationMethod quantization_method = 1; - OpSet op_set = 2; // If not specified, it defaults to `TF`. + OpSet op_set = 2; // If not specified, it defaults to `XLA`. QuantizationPrecision quantization_precision = 3; diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/quantize_preprocess.cc b/tensorflow/compiler/mlir/quantization/tensorflow/quantize_preprocess.cc index 6c085b85b2d..abbb663462b 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/quantize_preprocess.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/quantize_preprocess.cc @@ -57,7 +57,7 @@ absl::Status RunPassesOnModuleOp(const absl::string_view mlir_dump_file_name, } if (failed(pass_manager.run(module_op))) { - return tsl::ToAbslStatus(statusHandler.ConsumeStatus()); + return statusHandler.ConsumeStatus(); } return absl::OkStatus(); @@ -106,7 +106,7 @@ absl::Status PreprocessAndFreezeGraph( if (session.has_value() && failed(mlir::tf_saved_model::FreezeVariables( module_op, session.value()))) { - return tsl::ToAbslStatus(statusHandler.ConsumeStatus()); + return statusHandler.ConsumeStatus(); } return RunPassesOnModuleOp( diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/BUILD b/tensorflow/compiler/mlir/quantization/tensorflow/tests/BUILD index 2587a70d9cf..6484c365f92 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/tests/BUILD +++ b/tensorflow/compiler/mlir/quantization/tensorflow/tests/BUILD @@ -21,6 +21,7 @@ filegroup( ) glob_lit_tests( + name = "all_tests", data = [":test_utilities"], driver = "@llvm-project//mlir:run_lit.sh", size_override = { diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/lift_hashtable_ops_as_args.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/tests/lift_hashtable_ops_as_args.mlir new file mode 100644 index 00000000000..20f37c578f2 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/tests/lift_hashtable_ops_as_args.mlir @@ -0,0 +1,109 @@ +// RUN: tf-quant-opt %s 
-split-input-file -quant-lift-hashtable-ops-as-args | FileCheck %s +module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 1506 : i32}, tf_saved_model.semantics} { + "tf_saved_model.session_initializer"() {initializers = [@init_all_tables]} : () -> () + func.func @init_all_tables() attributes {tf_saved_model.exported_names = ["__tf_saved_model_session_initializer_init_all_tables"], tf_saved_model.initializer_type = "init_op"} { + %cst = "tf.Const"() {value = dense<["hello", "model", "quantization"]> : tensor<3x!tf_type.string>} : () -> tensor<3x!tf_type.string> + %cst_0 = "tf.Const"() {value = dense<[0, 1, 2]> : tensor<3xi64>} : () -> tensor<3xi64> + %0 = "tf.HashTableV2"() {container = "", device = "", key_dtype = !tf_type.string, shared_name = "hash_table_ce3dfbfc-7367-4d62-9d48-d13bf8125391", use_node_name_sharing = false, value_dtype = i64} : () -> tensor + "tf.LookupTableImportV2"(%0, %cst, %cst_0) {_has_manual_control_dependencies = true, device = ""} : (tensor, tensor<3x!tf_type.string>, tensor<3xi64>) -> () + return + } + +// Check that HashTable op in the initilizer is not lifted. 
+// CHECK: func.func @init_all_tables() +// CHECK: %[[OUT_0:.*]] = "tf.HashTableV2"() +// CHECK: "tf.LookupTableImportV2"(%[[OUT_0]] + func.func private @serving_default(%arg0: tensor ) -> (tensor<*xi64>) attributes {tf.entry_function = {control_outputs = "", inputs = "input_vocabs:0", outputs = "FakeQuantWithMinMaxArgs_2:0"}} { + %cst = "tf.Const"() {value = dense<-1> : tensor} : () -> tensor + %cst_0 = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32> + %cst_1 = "tf.Const"() {value = dense<1.000000e+00> : tensor} : () -> tensor + %cst_2 = "tf.Const"() {value = dense<0.00235294132> : tensor} : () -> tensor + %cst_3 = "tf.Const"() {value = dense<-1> : tensor} : () -> tensor + %cst_4 = "tf.Const"() {value = dense<0.00117647066> : tensor} : () -> tensor + %cst_5 = "tf.Const"() {value = dense<-43> : tensor} : () -> tensor + %cst_6 = "tf.Const"() {value = dense<0.00156862743> : tensor} : () -> tensor + %0 = "tf.HashTableV2"() {container = "", device = "", key_dtype = !tf_type.string, shared_name = "hash_table_ce3dfbfc-7367-4d62-9d48-d13bf8125391", use_node_name_sharing = false, value_dtype = i64} : () -> tensor + %1 = "tf.LookupTableSizeV2"(%0) {device = ""} : (tensor) -> tensor + %2 = "tf.Shape"(%arg0) {device = ""} : (tensor) -> tensor<1xi32> + %3 = "tf.StringToHashBucketFast"(%arg0) {device = "", num_buckets = 5 : i64} : (tensor) -> tensor + %4 = "tf.AddV2"(%3, %1) {device = ""} : (tensor, tensor) -> tensor + %5 = "tf.LookupTableFindV2"(%0, %arg0, %cst) {device = ""} : (tensor, tensor, tensor) -> tensor<*xi64> + return %5 : tensor<*xi64> + } + +// Check that HashTable op is lifted. 
+// CHECK: func.func private @serving_default +// CHECK-SAME: (%arg0: tensor, %arg1: tensor) -> tensor<*xi64> +// CHECK-SAME: tf.entry_function = {control_outputs = "", inputs = "input_vocabs:0,hash_table_1:0", outputs = "FakeQuantWithMinMaxArgs_2:0"} +// CHECK: "tf.LookupTableSizeV2"(%arg1) +// CHECK: "tf.LookupTableFindV2"(%arg1 + func.func @main(%arg0: tensor {tf_saved_model.index_path = ["input_vocabs:0"]} ) -> (tensor<*xi64> {tf_saved_model.index_path = ["FakeQuantWithMinMaxArgs_2:0"]}) attributes {tf.entry_function = {inputs = "input_vocabs:0", outputs = "FakeQuantWithMinMaxArgs_2:0"}, tf_saved_model.exported_names = ["main"]} { + %0 = "tf.PartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @serving_default} : (tensor) -> (tensor<*xi64>) + %1 = "tf.Identity"(%0) : (tensor<*xi64>) -> tensor<*xi64> + return %1 : tensor<*xi64> + } + +// Check that the caller is updated. +// CHECK: func.func @main +// CHECK: %[[OUT_1:.*]] = "tf.HashTableV2"() +// CHECK: %[[OUT_2:.*]] = "tf.PartitionedCall"(%arg0, %[[OUT_1]]) +} +// ----- +// Test nested function case. 
+module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 1506 : i32}, tf_saved_model.semantics} { + "tf_saved_model.session_initializer"() {initializers = [@init_all_tables]} : () -> () + func.func @init_all_tables() attributes {tf_saved_model.exported_names = ["__tf_saved_model_session_initializer_init_all_tables"], tf_saved_model.initializer_type = "init_op"} { + %cst = "tf.Const"() {value = dense<["hello", "model", "quantization"]> : tensor<3x!tf_type.string>} : () -> tensor<3x!tf_type.string> + %cst_0 = "tf.Const"() {value = dense<[0, 1, 2]> : tensor<3xi64>} : () -> tensor<3xi64> + %0 = "tf.HashTableV2"() {container = "", device = "", key_dtype = !tf_type.string, shared_name = "hash_table_ce3dfbfc-7367-4d62-9d48-d13bf8125391", use_node_name_sharing = false, value_dtype = i64} : () -> tensor + "tf.LookupTableImportV2"(%0, %cst, %cst_0) {_has_manual_control_dependencies = true, device = ""} : (tensor, tensor<3x!tf_type.string>, tensor<3xi64>) -> () + return + } + +// Check that HashTable op in the initilizer is not lifted. +// CHECK: func.func @init_all_tables() +// CHECK: %[[OUT_0:.*]] = "tf.HashTableV2"() +// CHECK: "tf.LookupTableImportV2"(%[[OUT_0]] + func.func private @serving_default(%arg0: tensor ) -> (tensor<*xi64>) attributes {tf.entry_function = {control_outputs = "", inputs = "input_vocabs:0", outputs = "FakeQuantWithMinMaxArgs_2:0"}} { + %0 = "tf.PartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @serving_default1} : (tensor) -> (tensor<*xi64>) + %1 = "tf.Identity"(%0) : (tensor<*xi64>) -> tensor<*xi64> + return %1 : tensor<*xi64> + } +// Check that HashTable op is passed through. 
+// CHECK: func.func private @serving_default +// CHECK-SAME: (%arg0: tensor, %arg1: tensor) -> tensor<*xi64> +// CHECK-SAME: tf.entry_function = {control_outputs = "", inputs = "input_vocabs:0,hash_table_1:0", outputs = "FakeQuantWithMinMaxArgs_2:0"} +// CHECK: "tf.PartitionedCall"(%arg0, %arg1) + func.func private @serving_default1(%arg0: tensor ) -> (tensor<*xi64>) { + %cst = "tf.Const"() {value = dense<-1> : tensor} : () -> tensor + %cst_0 = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32> + %cst_1 = "tf.Const"() {value = dense<1.000000e+00> : tensor} : () -> tensor + %cst_2 = "tf.Const"() {value = dense<0.00235294132> : tensor} : () -> tensor + %cst_3 = "tf.Const"() {value = dense<-1> : tensor} : () -> tensor + %cst_4 = "tf.Const"() {value = dense<0.00117647066> : tensor} : () -> tensor + %cst_5 = "tf.Const"() {value = dense<-43> : tensor} : () -> tensor + %cst_6 = "tf.Const"() {value = dense<0.00156862743> : tensor} : () -> tensor + %0 = "tf.HashTableV2"() {container = "", device = "", key_dtype = !tf_type.string, shared_name = "hash_table_ce3dfbfc-7367-4d62-9d48-d13bf8125391", use_node_name_sharing = false, value_dtype = i64} : () -> tensor + %1 = "tf.LookupTableSizeV2"(%0) {device = ""} : (tensor) -> tensor + %2 = "tf.Shape"(%arg0) {device = ""} : (tensor) -> tensor<1xi32> + %3 = "tf.StringToHashBucketFast"(%arg0) {device = "", num_buckets = 5 : i64} : (tensor) -> tensor + %4 = "tf.AddV2"(%3, %1) {device = ""} : (tensor, tensor) -> tensor + %5 = "tf.LookupTableFindV2"(%0, %arg0, %cst) {device = ""} : (tensor, tensor, tensor) -> tensor<*xi64> + return %5 : tensor<*xi64> + } + +// Check that HashTable op is lifted. 
+// CHECK: func.func private @serving_default1 +// CHECK-SAME: (%arg0: tensor, %arg1: tensor) -> tensor<*xi64> +// CHECK: "tf.LookupTableSizeV2"(%arg1) +// CHECK: "tf.LookupTableFindV2"(%arg1 + func.func @main(%arg0: tensor {tf_saved_model.index_path = ["input_vocabs:0"]} ) -> (tensor<*xi64> {tf_saved_model.index_path = ["FakeQuantWithMinMaxArgs_2:0"]}) attributes {tf.entry_function = {inputs = "input_vocabs:0", outputs = "FakeQuantWithMinMaxArgs_2:0"}, tf_saved_model.exported_names = ["main"]} { + %0 = "tf.PartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @serving_default} : (tensor) -> (tensor<*xi64>) + %1 = "tf.Identity"(%0) : (tensor<*xi64>) -> tensor<*xi64> + return %1 : tensor<*xi64> + } +// Check that the caller is updated. +// CHECK: func.func @main +// CHECK: %[[OUT_1:.*]] = "tf.HashTableV2"() +// CHECK: %[[OUT_2:.*]] = "tf.PartitionedCall"(%arg0, %[[OUT_1]]) +} diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/merge_duplicate_resource_ops.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/tests/merge_duplicate_resource_ops.mlir new file mode 100644 index 00000000000..c3099ea9418 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/tests/merge_duplicate_resource_ops.mlir @@ -0,0 +1,108 @@ +// RUN: tf-quant-opt %s -split-input-file -quant-merge-duplicate-resource-ops | FileCheck %s + +func.func @merge_duplicate_variable(%arg0: tensor<1x20xf32>, %arg1: tensor) -> (tensor<20x4096xf32>) { + %0 = tf_executor.graph { + %outputs_5, %control_6 = tf_executor.island wraps "tf.Const"() {value = dense<""> : tensor<1x!tf_type.string>} : () -> tensor<1x!tf_type.string> + %outputs_7, %control_8 = tf_executor.island wraps "tf.Const"() {value = dense<"MatMul/b_0"> : tensor<1x!tf_type.string>} : () -> tensor<1x!tf_type.string> + %outputs_9, %control_10 = tf_executor.island wraps "tf.VarHandleOp"() {container = "", shared_name = "MatMul/b_0"} : () -> tensor>> + %outputs_11, %control_12 = tf_executor.island 
wraps "tf.RestoreV2"(%arg1, %outputs_7, %outputs_5) : (tensor, tensor<1x!tf_type.string>, tensor<1x!tf_type.string>) -> tensor<20x4096xf32> + %control_13 = tf_executor.island(%control_12) wraps "tf.AssignVariableOp"(%outputs_9, %outputs_11) {validate_shape = false} : (tensor>>, tensor<20x4096xf32>) -> () + %control_14 = tf_executor.island(%control_13) wraps "tf.NoOp"() : () -> () + %outputs_15, %control_16 = tf_executor.island wraps "tf.VarHandleOp"() {container = "", shared_name = "MatMul/b_0"} : () -> tensor>> + %outputs_17, %control_18 = tf_executor.island wraps "tf.ReadVariableOp"(%outputs_15) : (tensor>>) -> tensor<20x4096xf32> + %outputs_19, %control_20 = tf_executor.island wraps "tf.Const"() {value = dense<"MatMul/b_0"> : tensor<1x!tf_type.string>} : () -> tensor<1x!tf_type.string> + %outputs_21, %control_22 = tf_executor.island wraps "tf.Const"() {value = dense<""> : tensor<1x!tf_type.string>} : () -> tensor<1x!tf_type.string> + %control_23 = tf_executor.island(%control_18) wraps "tf.SaveV2"(%arg1, %outputs_19, %outputs_21, %outputs_17) : (tensor, tensor<1x!tf_type.string>, tensor<1x!tf_type.string>, tensor<20x4096xf32>) -> () + %outputs_24, %control_25 = tf_executor.island(%control_23) wraps "tf.Identity"(%arg1) : (tensor) -> tensor + tf_executor.fetch %outputs_17, %control_14, %control_25 : tensor<20x4096xf32>, !tf_executor.control, !tf_executor.control + } + return %0 : tensor<20x4096xf32> +} +// CHECK-LABEL: @merge_duplicate_variable +// CHECK: %[[OUT_0:.*]], %[[CTL_0:.*]] = tf_executor.island wraps "tf.VarHandleOp"() +// CHECK: %[[OUT_1:.*]], %[[CTL_1:.*]] = tf_executor.island wraps "tf.RestoreV2" +// CHECK: %[[CTL_2:.*]] = tf_executor.island(%[[CTL_1]]) wraps "tf.AssignVariableOp"(%[[OUT_0]], %[[OUT_1]]) + +// Check that ReadVariableOp now use the same variable op. 
+// CHECK: %[[OUT_3:.*]], %[[CTL_3:.*]] = tf_executor.island wraps "tf.ReadVariableOp"(%[[OUT_0]]) + +// ----- + +func.func @variables_with_different_shared_names(%arg0: tensor<1x20xf32>, %arg1: tensor) -> (tensor<20x4096xf32>) { + %0 = tf_executor.graph { + %outputs_5, %control_6 = tf_executor.island wraps "tf.Const"() {value = dense<""> : tensor<1x!tf_type.string>} : () -> tensor<1x!tf_type.string> + %outputs_7, %control_8 = tf_executor.island wraps "tf.Const"() {value = dense<"MatMul/b_0"> : tensor<1x!tf_type.string>} : () -> tensor<1x!tf_type.string> + %outputs_9, %control_10 = tf_executor.island wraps "tf.VarHandleOp"() {container = "", shared_name = "MatMul/b_0"} : () -> tensor>> + %outputs_11, %control_12 = tf_executor.island wraps "tf.RestoreV2"(%arg1, %outputs_7, %outputs_5) : (tensor, tensor<1x!tf_type.string>, tensor<1x!tf_type.string>) -> tensor<20x4096xf32> + %control_13 = tf_executor.island(%control_12) wraps "tf.AssignVariableOp"(%outputs_9, %outputs_11) {validate_shape = false} : (tensor>>, tensor<20x4096xf32>) -> () + %control_14 = tf_executor.island(%control_13) wraps "tf.NoOp"() : () -> () + %outputs_15, %control_16 = tf_executor.island wraps "tf.VarHandleOp"() {container = "", shared_name = "MatMul/b_1"} : () -> tensor>> + %outputs_17, %control_18 = tf_executor.island wraps "tf.ReadVariableOp"(%outputs_15) : (tensor>>) -> tensor<20x4096xf32> + %outputs_19, %control_20 = tf_executor.island wraps "tf.Const"() {value = dense<"MatMul/b_0"> : tensor<1x!tf_type.string>} : () -> tensor<1x!tf_type.string> + %outputs_21, %control_22 = tf_executor.island wraps "tf.Const"() {value = dense<""> : tensor<1x!tf_type.string>} : () -> tensor<1x!tf_type.string> + %control_23 = tf_executor.island(%control_18) wraps "tf.SaveV2"(%arg1, %outputs_19, %outputs_21, %outputs_17) : (tensor, tensor<1x!tf_type.string>, tensor<1x!tf_type.string>, tensor<20x4096xf32>) -> () + %outputs_24, %control_25 = tf_executor.island(%control_23) wraps "tf.Identity"(%arg1) : (tensor) -> 
tensor + tf_executor.fetch %outputs_17, %control_14, %control_25 : tensor<20x4096xf32>, !tf_executor.control, !tf_executor.control + } + return %0 : tensor<20x4096xf32> +} +// CHECK-LABEL: @variables_with_different_shared_names +// CHECK: %[[OUT_0:.*]], %[[CTL_0:.*]] = tf_executor.island wraps "tf.VarHandleOp"() +// CHECK-SAME: shared_name = "MatMul/b_0" +// CHECK: %[[OUT_1:.*]], %[[CTL_1:.*]] = tf_executor.island wraps "tf.RestoreV2" +// CHECK: %[[CTL_2:.*]] = tf_executor.island(%[[CTL_1]]) wraps "tf.AssignVariableOp"(%[[OUT_0]], %[[OUT_1]]) + +// Check that the second variable is not removed since they have different +// `shared_name` attribute. +// CHECK: %[[OUT_3:.*]], %[[CTL_3:.*]] = tf_executor.island wraps "tf.VarHandleOp"() +// CHECK-SAME: shared_name = "MatMul/b_1" +// CHECK: %[[OUT_4:.*]], %[[CTL_4:.*]] = tf_executor.island wraps "tf.ReadVariableOp"(%[[OUT_3]]) + +// ----- + +// Test two resource ops have the same shared_name but different types. +// expected-error @+1 {{This op has the same `shared_name` but different type with another}} +func.func @same_shared_name_but_different_types(%arg0: tensor<1x20xf32>, %arg1: tensor) -> (tensor<20x4096xf32>) { + %0 = tf_executor.graph { + %outputs_5, %control_6 = tf_executor.island wraps "tf.Const"() {value = dense<""> : tensor<1x!tf_type.string>} : () -> tensor<1x!tf_type.string> + %outputs_7, %control_8 = tf_executor.island wraps "tf.Const"() {value = dense<"MatMul/b_0"> : tensor<1x!tf_type.string>} : () -> tensor<1x!tf_type.string> + %outputs_9, %control_10 = tf_executor.island wraps "tf.VarHandleOp"() {container = "", shared_name = "MatMul/b_0"} : () -> tensor>> + %outputs_11, %control_12 = tf_executor.island wraps "tf.RestoreV2"(%arg1, %outputs_7, %outputs_5) : (tensor, tensor<1x!tf_type.string>, tensor<1x!tf_type.string>) -> tensor<20x4096xf32> + %control_13 = tf_executor.island(%control_12) wraps "tf.AssignVariableOp"(%outputs_9, %outputs_11) {validate_shape = false} : (tensor>>, tensor<20x4096xf32>) -> () 
+ %control_14 = tf_executor.island(%control_13) wraps "tf.NoOp"() : () -> () + %outputs_15, %control_16 = tf_executor.island wraps "tf.VarHandleOp"() {container = "", shared_name = "MatMul/b_0"} : () -> tensor>> + %outputs_17, %control_18 = tf_executor.island wraps "tf.ReadVariableOp"(%outputs_15) : (tensor>>) -> tensor<20x4096xf32> + %outputs_19, %control_20 = tf_executor.island wraps "tf.Const"() {value = dense<"MatMul/b_0"> : tensor<1x!tf_type.string>} : () -> tensor<1x!tf_type.string> + %outputs_21, %control_22 = tf_executor.island wraps "tf.Const"() {value = dense<""> : tensor<1x!tf_type.string>} : () -> tensor<1x!tf_type.string> + %control_23 = tf_executor.island(%control_18) wraps "tf.SaveV2"(%arg1, %outputs_19, %outputs_21, %outputs_17) : (tensor, tensor<1x!tf_type.string>, tensor<1x!tf_type.string>, tensor<20x4096xf32>) -> () + %outputs_24, %control_25 = tf_executor.island(%control_23) wraps "tf.Identity"(%arg1) : (tensor) -> tensor + tf_executor.fetch %outputs_17, %control_14, %control_25 : tensor<20x4096xf32>, !tf_executor.control, !tf_executor.control + } + return %0 : tensor<20x4096xf32> +} + +// ----- + +func.func @merge_hashtable_ops(%arg0: tensor) -> (tensor) { + %0 = tf_executor.graph { + %outputs, %control = tf_executor.island wraps "tf.HashTableV2"() {container = "", device = "", key_dtype = !tf_type.string, shared_name = "hash_table_ce3dfbfc-7367-4d62-9d48-d13bf8125391", use_node_name_sharing = false, value_dtype = i64} : () -> tensor + %outputs_0, %control_1 = tf_executor.island wraps "tf.LookupTableSizeV2"(%outputs) {device = ""} : (tensor) -> tensor + %outputs_2, %control_3 = tf_executor.island wraps "tf.Const"() {value = dense<1.000000e+00> : tensor} : () -> tensor + %outputs_4, %control_5 = tf_executor.island wraps "tf.Identity"(%outputs_0) : (tensor) -> tensor + %control_8 = tf_executor.island(%control_3, %control_5) wraps "tf.NoOp"() : () -> () + %outputs_9, %control_10 = tf_executor.island wraps "tf.Const"() {value = dense<["hello", 
"model", "quantization"]> : tensor<3x!tf_type.string>} : () -> tensor<3x!tf_type.string> + %outputs_11, %control_12 = tf_executor.island wraps "tf.Const"() {value = dense<[0, 1, 2]> : tensor<3xi64>} : () -> tensor<3xi64> + %outputs_13, %control_14 = tf_executor.island wraps "tf.HashTableV2"() {container = "", device = "", key_dtype = !tf_type.string, shared_name = "hash_table_ce3dfbfc-7367-4d62-9d48-d13bf8125391", use_node_name_sharing = false, value_dtype = i64} : () -> tensor + %control_15 = tf_executor.island wraps "tf.LookupTableImportV2"(%outputs_13, %outputs_9, %outputs_11) {_has_manual_control_dependencies = true, device = ""} : (tensor, tensor<3x!tf_type.string>, tensor<3xi64>) -> () + %control_16 = tf_executor.island(%control_15) wraps "tf.NoOp"() : () -> () + tf_executor.fetch %outputs_4, %control_8, %control_16 : tensor, !tf_executor.control, !tf_executor.control + } + return %0 : tensor +} + +// CHECK-LABEL: @merge_hashtable_ops +// CHECK: %[[OUT_0:.*]], %[[CTL_0:.*]] = tf_executor.island wraps "tf.HashTableV2"() +// CHECK: %[[OUT_1:.*]], %[[CTL_1:.*]] = tf_executor.island wraps "tf.LookupTableSizeV2"(%[[OUT_0]]) + +// Check that LookupTableImportV2 is using the same HashTableV2 with LookupTableSizeV2. 
+// CHECK: %[[CTL_2:.*]] = tf_executor.island wraps "tf.LookupTableImportV2"(%[[OUT_0]] diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/prepare_lifting.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/tests/prepare_lifting.mlir index ee97c375ba9..6eae6df1323 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/tests/prepare_lifting.mlir +++ b/tensorflow/compiler/mlir/quantization/tensorflow/tests/prepare_lifting.mlir @@ -371,10 +371,20 @@ func.func @xla_gather_known_output_shape(%arg0: tensor<5xi32>, %arg1: tensor<1xi // ----- -func.func @replace_checknumerics_to_identity(%arg0: tensor<*xf32>) -> tensor<*xf32> { +func.func @remove_check_numerics_op(%arg0: tensor<*xf32>) -> tensor<*xf32> { %0 = "tf.CheckNumerics"(%arg0) {device = "", message = "transformer"} : (tensor<*xf32>) -> tensor<*xf32> func.return %0 : tensor<*xf32> } -// CHECK: func @replace_checknumerics_to_identity -// CHECK: %[[out:.*]] = "tf.Identity"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> \ No newline at end of file +// CHECK: func @remove_check_numerics_op +// CHECK: return %arg0 : tensor<*xf32> + +// ----- + +func.func @remove_stop_gradient_op(%arg0: tensor<*xf32>) -> tensor<*xf32> { + %0 = "tf.StopGradient"(%arg0) {device = ""} : (tensor<*xf32>) -> tensor<*xf32> + func.return %0 : tensor<*xf32> +} + +// CHECK: func @remove_stop_gradient_op +// CHECK: return %arg0 : tensor<*xf32> diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/preprocess_op_weight_only.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/tests/preprocess_op_weight_only.mlir new file mode 100644 index 00000000000..4f36784e67a --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/tests/preprocess_op_weight_only.mlir @@ -0,0 +1,55 @@ +// RUN: tf-quant-opt %s -split-input-file -quant-preprocess-op='target-opset=XLA quantization-method=weight_only enable-per-channel-quantization=false' | FileCheck --check-prefix PerTensor %s +// RUN: tf-quant-opt %s 
-split-input-file -quant-preprocess-op='target-opset=XLA quantization-method=weight_only enable-per-channel-quantization=true' | FileCheck --check-prefix PerChannel %s + +module { + // For XLA weight-only per-channel depthwise convolution, tensor shape should have + // transformed from [H,W,C,M] to [H,W,1,CxM], + func.func @depthwise_conv(%arg0: tensor<1x3x4x3xf32>) -> (tensor<*xf32>) { + %cst_0 = "tf.Const"() {value = dense<0.000000e+00> : tensor<6xf32>} : () -> tensor<6xf32> + %cst_1 = "tf.Const"() {value = dense<[[[[3.0, 2.0], [1.0, 0.0],[3.0, 2.0]],[[3.0, 2.0], [1.0, 0.0],[3.0, 2.0]],[[3.0, 2.0], [1.0, 0.0],[3.0, 2.0]]],[[[3.0, 2.0], [1.0, 0.0],[3.0, 2.0]],[[3.0, 2.0], [1.0, 0.0],[3.0, 2.0]],[[3.0, 2.0], [1.0, 0.0],[3.0, 2.0]]]]> : tensor<2x3x3x2xf32>} : () -> tensor<2x3x3x2xf32> + %0 = "tf.PartitionedCall"(%arg0, %cst_1) {_tfl_quant_trait = "fully_quantizable", config = "", config_proto = "", executor_type = "", f = @composite_depthwise_conv2d_fn} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<*xf32> + %1 = "tf.BiasAdd"(%0, %cst_0) {data_format = "NHWC", device = ""} : (tensor<*xf32>, tensor<6xf32>) -> tensor<*xf32> + func.return %1: tensor<*xf32> + } + func.func private @composite_depthwise_conv2d_fn(%arg0: tensor<1x3x4x3xf32>, %arg1: tensor<2x3x3x2xf32>) -> tensor<*xf32> attributes {tf_quant.composite_function} { + %0 = "tf.DepthwiseConv2dNative"(%arg0, %arg1) { + attr_map = "0:strides,1:padding,2:explicit_paddings,3:dilations", data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 2, 1] + } : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<*xf32> + return %0 : tensor<*xf32> + } + +// PerTensor-LABEL: func @depthwise_conv +// PerTensor-DAG: %[[CONST_0:.*]] = arith.constant dense<0.000000e+00> : tensor<6xf32> +// PerTensor: %[[CONST_1:.*]] = arith.constant dense +// PerTensor-NOT: tensor<2x3x1x6xf32> +// PerTensor-SAME: tensor<2x3x3x2xf32> +// PerTensor: %[[PARTITIONEDCALL_0:.*]] = 
"tf.PartitionedCall"(%arg0, %[[CONST_1:.*]]) {_tfl_quant_trait = "fully_quantizable", config = "", config_proto = "", executor_type = "", f = @composite_depthwise_conv2d_fn} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<*xf32> +// PerTensor: %[[BIAS_0:.*]] = "tf.BiasAdd"(%[[PARTITIONEDCALL_0]], %[[CONST_0:.*]]) {data_format = "NHWC", device = ""} : (tensor<*xf32>, tensor<6xf32>) -> tensor<*xf32> +// PerTensor: return %[[BIAS_0:.*]] : tensor<*xf32> + +// PerTensor-LABEL: func private @composite_depthwise_conv2d_fn( +// PerTensor-SAME: %arg0: tensor<1x3x4x3xf32>, +// PerTensor-SAME: %arg1: tensor<2x3x3x2xf32>) +// PerTensor: %0 = "tf.DepthwiseConv2dNative"(%arg0, %arg1) {attr_map = "0:strides,1:padding,2:explicit_paddings,3:dilations", data_format = "NHWC", device = "", +// PerTensor: return %0 : tensor<*xf32> + +// PerChannel-LABEL: func @depthwise_conv +// PerChannel-DAG: %[[CONST_0:.*]] = arith.constant dense<0.000000e+00> : tensor<6xf32> +// PerChannel: %[[CONST_1:.*]] = arith.constant dense +// PerChannel-NOT: tensor<2x3x3x2xf32> +// PerChannel-SAME: tensor<2x3x1x6xf32> +// PerChannel: %[[PARTITIONEDCALL_0:.*]] = "tf.PartitionedCall"(%arg0, %[[CONST_1:.*]]) {_tfl_quant_trait = "fully_quantizable", config = "", config_proto = "", executor_type = "", f = @composite_depthwise_conv2d_fn_0} : (tensor<1x3x4x3xf32>, tensor<2x3x1x6xf32>) -> tensor<*xf32> +// PerChannel: %[[BIAS_0:.*]] = "tf.BiasAdd"(%[[PARTITIONEDCALL_0]], %[[CONST_0:.*]]) {data_format = "NHWC", device = ""} : (tensor<*xf32>, tensor<6xf32>) -> tensor<*xf32> +// PerChannel: return %[[BIAS_0:.*]] : tensor<*xf32> + +// PerChannel-LABEL: func private @composite_depthwise_conv2d_fn( +// PerChannel-SAME: %arg0: tensor<1x3x4x3xf32>, +// PerChannel-SAME: %arg1: tensor<2x3x3x2xf32>) + +// PerChannel-LABEL: func private @composite_depthwise_conv2d_fn_0( +// PerChannel-SAME: %arg0: tensor<1x3x4x3xf32>, +// PerChannel-SAME: %arg1: tensor<2x3x1x6xf32>) +// PerChannel: %0 = "tf.DepthwiseConv2dNative"(%arg0, 
%arg1) {attr_map = "0:strides,1:padding,2:explicit_paddings,3:dilations", data_format = "NHWC", device = "", +// PerChannel: return %0 : tensor<*xf32> +} + diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/quantize_composite_functions_weight_only.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/tests/quantize_composite_functions_weight_only.mlir index 6d2fec56737..8c0786178ee 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/tests/quantize_composite_functions_weight_only.mlir +++ b/tensorflow/compiler/mlir/quantization/tensorflow/tests/quantize_composite_functions_weight_only.mlir @@ -1,4 +1,5 @@ -// RUN: tf-quant-opt %s -split-input-file -quant-insert-quantized-functions='quantization-method=weight_only target-opset=XLA' -quant-quantize-composite-functions='quantization-method=weight_only target-opset=XLA' -symbol-dce | FileCheck %s +// RUN: tf-quant-opt %s -split-input-file -quant-insert-quantized-functions='quantization-method=weight_only target-opset=XLA' -quant-quantize-composite-functions='quantization-method=weight_only target-opset=XLA' -symbol-dce | FileCheck --check-prefix=PerTensor %s +// RUN: tf-quant-opt %s -split-input-file -quant-insert-quantized-functions='quantization-method=weight_only target-opset=XLA' -quant-quantize-composite-functions='quantization-method=weight_only target-opset=XLA enable-per-channel-quantization=true' -symbol-dce | FileCheck --check-prefix=PerChannel %s module { // TODO(b/260020937): Support transpose_a, transpose_b for matmul. 
@@ -13,13 +14,21 @@ module { } } -// CHECK-LABEL: func @matmul -// CHECK-DAG: %[[q_w:.*]] = "tf.Const"() {value = dense<0> : tensor<12x2xi8>} : () -> tensor<12x2xi8> -// CHECK-DAG: %[[scale:.*]] = "tf.Const"() {value = dense<3.93700805E-9> : tensor} : () -> tensor -// CHECK-DAG: %[[zp:.*]] = "tf.Const"() {value = dense<0> : tensor} : () -> tensor -// CHECK: %[[out:.*]] = "tf.PartitionedCall"(%arg0, %[[q_w]], %[[scale]], %[[zp]]) {config = "", config_proto = "", executor_type = "", -// CHECK-SAME: f = @quantized_matmul_fn_0} : (tensor<2x12xf32>, tensor<12x2xi8>, tensor, tensor) -> tensor<*xf32> -// CHECK: return %[[out]] +// PerTensor-LABEL: func @matmul +// PerTensor-DAG: %[[q_w:.*]] = "tf.Const"() {value = dense<0> : tensor<12x2xi8>} : () -> tensor<12x2xi8> +// PerTensor-DAG: %[[scale:.*]] = "tf.Const"() {value = dense<3.93700805E-9> : tensor} : () -> tensor +// PerTensor-DAG: %[[zp:.*]] = "tf.Const"() {value = dense<0> : tensor} : () -> tensor +// PerTensor: %[[out:.*]] = "tf.PartitionedCall"(%arg0, %[[q_w]], %[[scale]], %[[zp]]) {config = "", config_proto = "", executor_type = "", +// PerTensor-SAME: f = @quantized_matmul_fn_0} : (tensor<2x12xf32>, tensor<12x2xi8>, tensor, tensor) -> tensor<*xf32> +// PerTensor: return %[[out]] + +// PerChannel-LABEL: func @matmul +// PerChannel-DAG: %[[q_w:.*]] = "tf.Const"() {value = dense<0> : tensor<12x2xi8>} : () -> tensor<12x2xi8> +// PerChannel-DAG: %[[scale:.*]] = "tf.Const"() {value = dense<3.93700805E-9> : tensor} : () -> tensor +// PerChannel-DAG: %[[zp:.*]] = "tf.Const"() {value = dense<0> : tensor} : () -> tensor +// PerChannel: %[[out:.*]] = "tf.PartitionedCall"(%arg0, %[[q_w]], %[[scale]], %[[zp]]) {config = "", config_proto = "", executor_type = "", +// PerChannel-SAME: f = @quantized_matmul_fn_0} : (tensor<2x12xf32>, tensor<12x2xi8>, tensor, tensor) -> tensor<*xf32> +// PerChannel: return %[[out]] // ----- @@ -41,15 +50,25 @@ module { return %conv : tensor<*xf32> } -// CHECK-LABEL: func @conv -// CHECK-DAG: 
%[[q_w:.*]] = "tf.Const"() -// CHECK-DAG: %[[scale:.*]] = "tf.Const"() -// CHECK-DAG: %[[zp:.*]] = "tf.Const"() -// CHECK: %[[out_1:.*]] = "tf.PartitionedCall"(%arg0, %[[q_w]], %[[scale]], %[[zp]]) {config = "", config_proto = "", executor_type = "", -// CHECK-SAME: f = @quantized_conv2d_fn_1} : (tensor<1x2x2x3xf32>, tensor<2x3x3x2xi8>, tensor, tensor) -> tensor<*xf32> -// CHECK: %[[out_2:.*]] = "tf.PartitionedCall"(%arg0, %[[q_w]], %[[scale]], %[[zp]]) {config = "", config_proto = "", executor_type = "", -// CHECK-SAME: f = @quantized_conv2d_fn_0} : (tensor<1x2x2x3xf32>, tensor<2x3x3x2xi8>, tensor, tensor) -> tensor<*xf32> -// CHECK: return %[[out_1]], %[[out_2]] +// PerTensor-LABEL: func @conv +// PerTensor-DAG: %[[q_w:.*]] = "tf.Const"() {value = dense<{{[0-9]+}}> : tensor<2x3x3x2xi8>} : () -> tensor<2x3x3x2xi8> +// PerTensor-DAG: %[[scale:.*]] = "tf.Const"() {value = dense<{{[0-9\.Ee\+\-]+}}> : tensor} : () -> tensor +// PerTensor-DAG: %[[zp:.*]] = "tf.Const"() {value = dense<{{[0-9]+}}> : tensor} : () -> tensor +// PerTensor: %[[out_1:.*]] = "tf.PartitionedCall"(%arg0, %[[q_w]], %[[scale]], %[[zp]]) {config = "", config_proto = "", executor_type = "", +// PerTensor-SAME: f = @quantized_conv2d_fn_1} : (tensor<1x2x2x3xf32>, tensor<2x3x3x2xi8>, tensor, tensor) -> tensor<*xf32> +// PerTensor: %[[out_2:.*]] = "tf.PartitionedCall"(%arg0, %[[q_w]], %[[scale]], %[[zp]]) {config = "", config_proto = "", executor_type = "", +// PerTensor-SAME: f = @quantized_conv2d_fn_0} : (tensor<1x2x2x3xf32>, tensor<2x3x3x2xi8>, tensor, tensor) -> tensor<*xf32> +// PerTensor: return %[[out_1]], %[[out_2]] + +// PerChannel-LABEL: func @conv +// PerChannel-DAG: %[[q_w:.*]] = "tf.Const"() {value = dense<{{[0-9]+}}> : tensor<2x3x3x2xi8>} : () -> tensor<2x3x3x2xi8> +// PerChannel-DAG: %[[scale:.*]] = "tf.Const"() {value = dense<{{[0-9\.Ee\+\-]+}}> : tensor<2xf32>} : () -> tensor<2xf32> +// PerChannel-DAG: %[[zp:.*]] = "tf.Const"() {value = dense<{{[0-9]+}}> : tensor<2xi32>} : () -> 
tensor<2xi32> +// PerChannel: %[[out_1:.*]] = "tf.PartitionedCall"(%arg0, %[[q_w]], %[[scale]], %[[zp]]) {config = "", config_proto = "", executor_type = "", +// PerChannel-SAME: f = @quantized_conv2d_fn_1} : (tensor<1x2x2x3xf32>, tensor<2x3x3x2xi8>, tensor<2xf32>, tensor<2xi32>) -> tensor<*xf32> +// PerChannel: %[[out_2:.*]] = "tf.PartitionedCall"(%arg0, %[[q_w]], %[[scale]], %[[zp]]) {config = "", config_proto = "", executor_type = "", +// PerChannel-SAME: f = @quantized_conv2d_fn_0} : (tensor<1x2x2x3xf32>, tensor<2x3x3x2xi8>, tensor<2xf32>, tensor<2xi32>) -> tensor<*xf32> +// PerChannel: return %[[out_1]], %[[out_2]] } @@ -78,16 +97,31 @@ module { return %0 : tensor<*xf32> } -// CHECK-LABEL: func @depthwise_conv -// CHECK-DAG: %[[q_w1:.*]] = "tf.Const"() {value = dense<127> : tensor<2x3x3x1xi8>} -// CHECK-DAG: %[[q_w2:.*]] = "tf.Const"() {value = dense<127> : tensor<2x3x3x2xi8>} : () -> tensor<2x3x3x2xi8> -// CHECK-DAG: %[[scale:.*]] = "tf.Const"() {value = dense<0.0236220472> : tensor} : () -> tensor -// CHECK-DAG: %[[zp:.*]] = "tf.Const"() {value = dense<0> : tensor} : () -> tensor -// CHECK-DAG: %[[bias:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<3xf32>} -// CHECK: %[[out_1:.*]] = "tf.PartitionedCall"(%arg0, %[[q_w1]], %[[scale]], %[[zp]]) {config = "", config_proto = "", executor_type = "", -// CHECK-SAME: f = @quantized_depthwise_conv2d_fn_1} : (tensor<1x3x4x3xf32>, tensor<2x3x3x1xi8>, tensor, tensor) -> tensor<*xf32> -// CHECK: %[[out_1_add:.*]] = "tf.BiasAdd"(%[[out_1]], %[[bias]]) -// CHECK: %[[out_2:.*]] = "tf.PartitionedCall"(%arg0, %[[q_w2]], %[[scale]], %[[zp]]) {config = "", config_proto = "", executor_type = "", -// CHECK-SAME: f = @quantized_depthwise_conv2d_fn_0} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xi8>, tensor, tensor) -> tensor<*xf32> -// CHECK: return %[[out_1_add]], %[[out_2]] +// PerTensor-LABEL: func @depthwise_conv +// PerTensor-DAG: %[[q_w1:.*]] = "tf.Const"() {value = dense<127> : tensor<2x3x3x1xi8>} +// PerTensor-DAG: 
%[[q_w2:.*]] = "tf.Const"() {value = dense<127> : tensor<2x3x3x2xi8>} : () -> tensor<2x3x3x2xi8> +// PerTensor-DAG: %[[scale:.*]] = "tf.Const"() {value = dense<0.0236220472> : tensor} : () -> tensor +// PerTensor-DAG: %[[zp:.*]] = "tf.Const"() {value = dense<0> : tensor} : () -> tensor +// PerTensor-DAG: %[[bias:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<3xf32>} +// PerTensor: %[[out_1:.*]] = "tf.PartitionedCall"(%arg0, %[[q_w1]], %[[scale]], %[[zp]]) {config = "", config_proto = "", executor_type = "", +// PerTensor-SAME: f = @quantized_depthwise_conv2d_fn_1} : (tensor<1x3x4x3xf32>, tensor<2x3x3x1xi8>, tensor, tensor) -> tensor<*xf32> +// PerTensor: %[[out_1_add:.*]] = "tf.BiasAdd"(%[[out_1]], %[[bias]]) +// PerTensor: %[[out_2:.*]] = "tf.PartitionedCall"(%arg0, %[[q_w2]], %[[scale]], %[[zp]]) {config = "", config_proto = "", executor_type = "", +// PerTensor-SAME: f = @quantized_depthwise_conv2d_fn_0} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xi8>, tensor, tensor) -> tensor<*xf32> +// PerTensor: return %[[out_1_add]], %[[out_2]] + +// PerChannel-LABEL: func @depthwise_conv +// PerChannel-DAG: %[[bias1:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<3xf32>} : () -> tensor<3xf32> +// PerChannel-DAG: %[[q_w1:.*]] = "tf.Const"() {value = dense<{{[0-9]+}}> : tensor<2x3x3x1xi8>} : () -> tensor<2x3x3x1xi8> +// PerChannel-DAG: %[[q_w2:.*]] = "tf.Const"() {value = dense<{{[0-9]+}}> : tensor<2x3x3x2xi8>} : () -> tensor<2x3x3x2xi8> +// PerChannel-DAG: %[[scale1:.*]] = "tf.Const"() {value = dense<{{[0-9\.Ee\+\-]+}}> : tensor<3xf32>} : () -> tensor<3xf32> +// PerChannel-DAG: %[[scale2:.*]] = "tf.Const"() {value = dense<{{[0-9\.Ee\+\-]+}}> : tensor<6xf32>} : () -> tensor<6xf32> +// PerChannel-DAG: %[[zp1:.*]] = "tf.Const"() {value = dense<{{[0-9]+}}> : tensor<3xi32>} : () -> tensor<3xi32> +// PerChannel-DAG: %[[zp2:.*]] = "tf.Const"() {value = dense<{{[0-9]+}}> : tensor<6xi32>} : () -> tensor<6xi32> +// PerChannel: %[[out_1:.*]] = 
"tf.PartitionedCall"(%arg0, %[[q_w1]], %[[scale1]], %[[zp1]]) {config = "", config_proto = "", executor_type = "", +// PerChannel-SAME: f = @quantized_depthwise_conv2d_fn_1} : (tensor<1x3x4x3xf32>, tensor<2x3x3x1xi8>, tensor<3xf32>, tensor<3xi32>) -> tensor<*xf32> +// PerChannel: %[[out_1_add:.*]] = "tf.BiasAdd"(%[[out_1]], %[[bias1]]) +// PerChannel: %[[out_2:.*]] = "tf.PartitionedCall"(%arg0, %[[q_w2]], %[[scale2]], %[[zp2]]) {config = "", config_proto = "", executor_type = "", +// PerChannel-SAME: f = @quantized_depthwise_conv2d_fn_0} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xi8>, tensor<6xf32>, tensor<6xi32>) -> tensor<*xf32> +// PerChannel: return %[[out_1_add]], %[[out_2]] } diff --git a/tensorflow/compiler/mlir/stablehlo/BUILD b/tensorflow/compiler/mlir/stablehlo/BUILD new file mode 100644 index 00000000000..162572e5d7c --- /dev/null +++ b/tensorflow/compiler/mlir/stablehlo/BUILD @@ -0,0 +1,56 @@ +load("//tensorflow:pytype.default.bzl", "pytype_library") +load("//tensorflow/tsl:tsl.default.bzl", "tsl_pybind_extension") + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = [ + ":friends", + ], + licenses = ["notice"], +) + +package_group( + name = "friends", + packages = [ + "//tensorflow/compiler/tests/...", + ], +) + +tsl_pybind_extension( + name = "stablehlo_extension", + srcs = [ + "stablehlo.cc", + "@stablehlo//:stablehlo/integrations/python/PortableApi.cpp", + ], + hdrs = [ + "@stablehlo//:stablehlo/integrations/python/PortableApi.h", + ], + copts = [ + "-fexceptions", + "-frtti", + ], + features = ["-use_header_modules"], + deps = [ + "@pybind11", + "@stablehlo//:stablehlo_portable_api", + ], +) + +pytype_library( + name = "stablehlo", + srcs = ["stablehlo.py"], + srcs_version = "PY3", + visibility = ["//visibility:public"], + deps = [ + ":stablehlo_extension", + ], +) + +py_test( + name = "stablehlo_test", + srcs = ["stablehlo_test.py"], + python_version = "PY3", + deps = [ + ":stablehlo", + ], 
+) diff --git a/tensorflow/compiler/xla/mlir_hlo/tosa/transforms/passes.h b/tensorflow/compiler/mlir/stablehlo/stablehlo.cc similarity index 53% rename from tensorflow/compiler/xla/mlir_hlo/tosa/transforms/passes.h rename to tensorflow/compiler/mlir/stablehlo/stablehlo.cc index acd63a76c25..0a256ff67c9 100644 --- a/tensorflow/compiler/xla/mlir_hlo/tosa/transforms/passes.h +++ b/tensorflow/compiler/mlir/stablehlo/stablehlo.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,25 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef MLIR_HLO_TOSA_TRANSFORMS_PASSES_H -#define MLIR_HLO_TOSA_TRANSFORMS_PASSES_H - -#include - -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Pass/Pass.h" +#include "pybind11/pybind11.h" // from @pybind11 +#include "stablehlo/integrations/python/PortableApi.h" // from @stablehlo namespace mlir { -namespace tosa { +namespace stablehlo { -std::unique_ptr> createLegalizeMhloPass(); -std::unique_ptr> createPrepareMhloPass(); +PYBIND11_MODULE(stablehlo_extension, m) { mlir::stablehlo::AddPortableApi(m); } -#define GEN_PASS_REGISTRATION -#define GEN_PASS_DECL_TOSALEGALIZEMHLOPASS -#include "passes.h.inc" - -} // namespace tosa +} // namespace stablehlo } // namespace mlir - -#endif // MLIR_HLO_TOSA_TRANSFORMS_PASSES_H diff --git a/tensorflow/python/training/tracking/python_state.py b/tensorflow/compiler/mlir/stablehlo/stablehlo.py similarity index 58% rename from tensorflow/python/training/tracking/python_state.py rename to tensorflow/compiler/mlir/stablehlo/stablehlo.py index 39e6e28addc..64c3f1b7be3 100644 --- a/tensorflow/python/training/tracking/python_state.py +++ 
b/tensorflow/compiler/mlir/stablehlo/stablehlo.py @@ -1,5 +1,4 @@ -"""Utilities for including Python state in TensorFlow checkpoints.""" -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# Copyright 2023 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,12 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== +"""StableHLO Portable Python APIs. +This setup only exports the the StableHLO Portable C++ APIs, which have +signatures that do not rely on MLIR classes. -# TODO(kathywu): Delete this file after all imports have been moved to the path -# below. -from tensorflow.python.trackable import python_state -from tensorflow.python.util import deprecation +Exporting all of MLIR Python bindings to TF OSS has high maintenance +implications, especially given the frequency that TF updates the revision of +LLVM used. +""" -__getattr__ = deprecation.deprecate_moved_module( - __name__, python_state, "2.11") +# pylint: disable=wildcard-import +from .stablehlo_extension import * diff --git a/tensorflow/compiler/mlir/stablehlo/stablehlo_test.py b/tensorflow/compiler/mlir/stablehlo/stablehlo_test.py new file mode 100644 index 00000000000..f6a1d1a75bb --- /dev/null +++ b/tensorflow/compiler/mlir/stablehlo/stablehlo_test.py @@ -0,0 +1,40 @@ +# Copyright 2023 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Smoke test of functions in StableHLO Portable APIs.""" + +from tensorflow.compiler.mlir.stablehlo import stablehlo + + +def smoketest(): + """Test StableHLO Portable APIs.""" + assert isinstance(stablehlo.get_api_version(), int) + assembly = """ + module @jit_f_jax.0 { + func.func public @main(%arg0: tensor) -> tensor { + %0 = stablehlo.constant dense<1> : tensor + %1 = "stablehlo.compare"(%arg0, %0) {compare_type = #stablehlo, comparison_direction = #stablehlo} : (tensor, tensor) -> tensor + return %1 : tensor + } + } + """ + target = stablehlo.get_current_version() + artifact = stablehlo.serialize_portable_artifact(assembly, target) + deserialized = stablehlo.deserialize_portable_artifact(artifact) + rountrip = stablehlo.serialize_portable_artifact(deserialized, target) + assert artifact == rountrip + + +if __name__ == "__main__": + smoketest() diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD index d49bca20c10..43206931918 100644 --- a/tensorflow/compiler/mlir/tensorflow/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/BUILD @@ -474,6 +474,7 @@ cc_library( ] + ["ir/tf_" + target["name"] + ".h.inc" for target in tf_ops_category_list], deps = [ ":attribute_utils", + ":convert_type", ":dynamic_shape_utils", ":rewrite_util", ":tensorflow_attributes", @@ -929,6 +930,19 @@ cc_library( ], ) +cc_library( + name = "string_util", + srcs = ["utils/string_util.cc"], + hdrs = ["utils/string_util.h"], + deps = [ + 
"@com_google_absl//absl/strings", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + ], +) + cc_library( name = "fake_session", srcs = ["utils/fake_session.cc"], @@ -1048,6 +1062,7 @@ cc_library( ":tensorflow_ops", ":tensorflow_passes", ":tensorflow_types", + ":tf_saved_model_asset_sinking_pass", "//tensorflow/core:core_cpu", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", @@ -1209,6 +1224,7 @@ cc_library( "transforms/drop_while_shape_invariant.cc", "transforms/einsum.cc", "transforms/embedding_pipelining.cc", + "transforms/embedding_sequencing.cc", "transforms/executor_island_coarsening.cc", "transforms/executor_tpuv1_inline_tpu_island.cc", "transforms/executor_tpuv1_island_coarsening.cc", @@ -1291,9 +1307,12 @@ cc_library( "transforms/tpu_variable_runtime_reformatting.cc", "transforms/update_control_dependencies.cc", "transforms/verify_suitable_for_graph_export_pass.cc", + "transforms/xla_call_module_deserialization.cc", + "transforms/xla_call_module_serialization.cc", "transforms/xla_cluster_formation.cc", "transforms/xla_inline_device_ops.cc", "transforms/xla_rewrite.cc", + "transforms/xla_validate_inputs.cc", "translate/breakup-islands.cc", "translate/split_into_island_per_op_pass.cc", "translate/tf_executor_to_functional.cc", @@ -1301,7 +1320,6 @@ cc_library( ], hdrs = [ "transforms/bridge.h", - "transforms/call_graph_util.h", "transforms/cluster_ops_by_policy.h", "transforms/collection_ops_util.h", "transforms/einsum.h", @@ -1318,6 +1336,7 @@ cc_library( deps = [ ":attribute_utils", ":bridge_logger", + ":call_graph_util", ":cluster_util", ":convert_tensor", ":convert_type", @@ -1333,6 +1352,8 @@ cc_library( ":parallel_execute_util", ":serialize_mlir_module_utils", ":shape_inference_pass", + ":stablehlo_custom_call_utils", + ":string_util", ":tensorflow", ":tensorflow_analysis", ":tensorflow_ops", @@ -1340,6 +1361,7 @@ cc_library( ":tensorflow_side_effects", 
":tensorflow_types", ":tf_data_optimization", + ":tf_device_pass_inc_gen", ":tf_legalize_hlo", ":tf_ops_layout_helper", ":tf_pass_inc_gen", @@ -1353,6 +1375,8 @@ cc_library( ":unroll_batch_matmul_pass", ":verification_utils", ":verify_suitable_for_graph_export", + ":visitor", + ":xla_call_module_attrs", ":xla_sharding_util", "//tensorflow/compiler/jit:flags_headers", "//tensorflow/compiler/mlir:op_or_arg_name_mapper", @@ -1360,11 +1384,13 @@ cc_library( "//tensorflow/compiler/mlir/tf2xla/transforms:xla_legalize_tf", "//tensorflow/compiler/mlir/tf2xla/transforms:xla_legalize_tf_with_tf2xla", "//tensorflow/compiler/tf2xla:side_effect_util", + "//tensorflow/compiler/tf2xla/kernels:xla_call_module_loader", "//tensorflow/compiler/xla:window_util", "//tensorflow/compiler/xla:xla_data_proto_cc", "//tensorflow/compiler/xla:xla_proto_cc", "//tensorflow/compiler/xla/client:sharding_builder", "//tensorflow/compiler/xla/mlir_hlo", + "//tensorflow/core:core_cpu_base", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", @@ -1378,6 +1404,7 @@ cc_library( "@com_google_absl//absl/container:btree", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", @@ -1395,12 +1422,38 @@ cc_library( "@llvm-project//mlir:InferTypeOpInterface", "@llvm-project//mlir:Parser", "@llvm-project//mlir:Pass", + "@llvm-project//mlir:QuantOps", "@llvm-project//mlir:Rewrite", "@llvm-project//mlir:SCFDialect", "@llvm-project//mlir:Support", "@llvm-project//mlir:TensorDialect", "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", + "@stablehlo//:chlo_ops", + "@stablehlo//:stablehlo_ops", + "@stablehlo//:stablehlo_portable_api", + "@stablehlo//:stablehlo_serialization", + "@stablehlo//:vhlo_ops", + ], +) + +cc_library( + name = "xla_call_module_attrs", + srcs = [], + 
hdrs = ["utils/xla_call_module_attrs.h"], + deps = ["@llvm-project//llvm:Support"], +) + +cc_library( + name = "stablehlo_custom_call_utils", + srcs = ["utils/stablehlo_custom_call.cc"], + hdrs = ["utils/stablehlo_custom_call.h"], + deps = [ + ":xla_call_module_attrs", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + "@stablehlo//:stablehlo_ops", ], ) @@ -1919,6 +1972,7 @@ cc_library( "@com_google_absl//absl/strings", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", ], ) @@ -1981,11 +2035,15 @@ cc_library( ], visibility = ["//visibility:public"], deps = [ - "//tensorflow/c:tf_status", - "//tensorflow/c/eager:c_api", - "//tensorflow/core:lib", - "//tensorflow/core/platform:logging", - "//tensorflow/core/protobuf:for_core_protos_cc", + ":convert_tensor", + ":export_tf_dialect_op", + ":tensorflow", + ":tensorflow_traits", + "//tensorflow/core/tfrt/fallback:fallback_state", + "//tensorflow/core/tfrt/fallback:op_kernel_runner", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", ], ) @@ -2000,9 +2058,9 @@ cc_library( ], visibility = ["//visibility:public"], deps = [ + ":constant_fold_utils", ":convert_tensor", ":export_graphdef", - ":export_tf_dialect_op", ":tensorflow", ":tensorflow_traits", ":tensorflow_types", @@ -2012,8 +2070,6 @@ cc_library( "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core/ops", - "//tensorflow/core/tfrt/fallback:fallback_state", - "//tensorflow/core/tfrt/fallback:op_kernel_runner", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", "@llvm-project//mlir:Support", @@ -2346,7 +2402,7 @@ tf_cc_test( "//tensorflow/core:test_main", "//tensorflow/core/platform:errors", "//tensorflow/tsl/platform:statusor", - "@com_google_googletest//:gtest_main", + "@com_google_googletest//:gtest", "@llvm-project//mlir:FuncDialect", ], ) @@ -2456,7 +2512,7 @@ tf_cc_test( "//tensorflow/core:test", 
"//tensorflow/core:test_main", "//tensorflow/core/platform:test", - "@com_google_googletest//:gtest_main", + "@com_google_googletest//:gtest", "@llvm-project//llvm:Support", "@llvm-project//mlir:AllPassesAndDialects", "@llvm-project//mlir:FuncDialect", @@ -2542,6 +2598,41 @@ cc_library( ], ) +cc_library( + name = "call_graph_util", + srcs = [ + "utils/call_graph_util.cc", + ], + hdrs = [ + "utils/call_graph_util.h", + ], + deps = [ + ":tensorflow", + "@com_google_absl//absl/strings", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + ], +) + +tf_cc_test( + name = "call_graph_util_test", + size = "small", + srcs = ["utils/call_graph_util_test.cc"], + deps = [ + ":attribute_utils", + ":call_graph_util", + ":tensorflow", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "@com_google_googletest//:gtest_main", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Parser", + ], +) + cc_library( name = "xla_sharding_util", srcs = [ @@ -2732,6 +2823,35 @@ cc_library( ], ) +cc_library( + name = "visitor", + srcs = ["utils/visitor.cc"], + hdrs = ["utils/visitor.h"], + deps = [ + "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + ], +) + +cc_library( + name = "tf_saved_model_asset_sinking_pass", + srcs = ["transforms/tf_saved_model_asset_sinking_pass.cc"], + hdrs = ["transforms/tf_saved_model_asset_sinking_pass.h"], + deps = [ + ":tensorflow", + ":tensorflow_types", + ":tf_savedmodel_pass_inc_gen", + "//tensorflow/tsl/platform:path", + "@com_google_absl//absl/strings", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + ], +) + build_test( name = "tensorflow_build_test", targets = [ diff --git a/tensorflow/compiler/mlir/tensorflow/analysis/resource_value_typed_analyzer.h 
b/tensorflow/compiler/mlir/tensorflow/analysis/resource_value_typed_analyzer.h index 89580d1edd7..9817b290c4c 100644 --- a/tensorflow/compiler/mlir/tensorflow/analysis/resource_value_typed_analyzer.h +++ b/tensorflow/compiler/mlir/tensorflow/analysis/resource_value_typed_analyzer.h @@ -16,6 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_ANALYSIS_RESOURCE_VALUE_TYPED_ANALYZER_H_ #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_ANALYSIS_RESOURCE_VALUE_TYPED_ANALYZER_H_ +#include + #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/StringRef.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.cc b/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.cc index 14ae242525a..43db7e91a56 100644 --- a/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.cc +++ b/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.cc @@ -16,8 +16,12 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h" #include +#include +#include #include #include +#include +#include #include "absl/container/node_hash_map.h" #include "llvm/ADT/DenseMap.h" @@ -30,6 +34,7 @@ limitations under the License. 
#include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Block.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project @@ -255,6 +260,14 @@ class OpSideEffectCollector { for (Region& region : op->getRegions()) { AddRegionSideEffectsForOp(region, op); } + } else if (auto xla_call_module_op = dyn_cast(op)) { + for (auto func_symbol : xla_call_module_op.getFunctionList().getAsRange< + mlir::FlatSymbolRefAttr>()) { + if (auto func = symbol_table_collection_.lookupNearestSymbolFrom< + mlir::func::FuncOp>(xla_call_module_op, func_symbol)) { + AddRegionSideEffectsForOp(func.getBody(), op); + } + } } else { // Now handle all other ops. auto& side_effects_by_resource_id = op_side_effect_map_[op]; diff --git a/tensorflow/compiler/mlir/tensorflow/c/c_api_unified_experimental_mlir.cc b/tensorflow/compiler/mlir/tensorflow/c/c_api_unified_experimental_mlir.cc index 6ed83b65428..05321522d50 100644 --- a/tensorflow/compiler/mlir/tensorflow/c/c_api_unified_experimental_mlir.cc +++ b/tensorflow/compiler/mlir/tensorflow/c/c_api_unified_experimental_mlir.cc @@ -16,6 +16,8 @@ limitations under the License. 
#include #include #include +#include +#include #include "absl/strings/str_cat.h" #include "llvm/ADT/StringRef.h" @@ -165,7 +167,7 @@ class MlirAbstractOp : public TracingOperation { Status SetAttrType(const char* attr_name, tensorflow::DataType dtype) override; Status SetAttrShape(const char* attr_name, const int64_t* dims, - const int num_dims) override; + int num_dims) override; Status SetAttrFunction(const char* attr_name, const AbstractOperation* value) override; Status SetAttrFunctionName(const char* attr_name, const char* value, @@ -189,7 +191,7 @@ class MlirAbstractOp : public TracingOperation { const char* attr_name, absl::Span values) override; - Status SetOpName(const char* const op_name) override; + Status SetOpName(const char* op_name) override; MLIRContext* GetContext() { return context_; } @@ -543,7 +545,7 @@ Status MlirFunction::GetFunctionDef(tensorflow::FunctionDef** f) { TF_RETURN_IF_ERROR(diag_handler.ConsumeStatus()); tensorflow::GraphExportConfig configs; - fdef_.reset(new tensorflow::FunctionDef()); + fdef_ = std::make_unique(); TF_RETURN_IF_ERROR( ConvertMlirFunctionToFunctionLibraryDef(func_, configs, fdef_.get())); *f = fdef_.get(); diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_arith_ops_folder.h b/tensorflow/compiler/mlir/tensorflow/ir/tf_arith_ops_folder.h index bc46b0c04ec..aa0f84eb122 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_arith_ops_folder.h +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_arith_ops_folder.h @@ -16,6 +16,8 @@ limitations under the License. 
#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_ARITH_OPS_FOLDER_H_ #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_ARITH_OPS_FOLDER_H_ +#include + #include "llvm/ADT/StringRef.h" #include "mlir/Dialect/Traits.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h b/tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h index d02b2b20e55..cad01806953 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h @@ -19,6 +19,9 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_DIALECT_H_ #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_DIALECT_H_ +#include +#include + #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/Dialect.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td index 14ff8f37ae8..f063732db29 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td @@ -2509,15 +2509,16 @@ This is typically used by gradient computations for a concat operation. 
let arguments = (ins Arg:$concat_dim, - Arg, [{The `N` int32 vectors representing shape of tensors being concatenated.}]>:$shape + Arg, [{The `N` int32 or int64 vectors representing shape of tensors being concatenated.}]>:$shape ); let results = (outs - Res, [{The `N` int32 vectors representing the starting offset -of input tensors within the concatenated output.}]>:$offset + Res, [{The `N` vectors representing the starting offset +of input tensors within the concatenated output with type matching `shape`.}]>:$offset ); TF_DerivedOperandSizeAttr N = TF_DerivedOperandSizeAttr<1>; + TF_DerivedOperandTypeAttr shape_type = TF_DerivedOperandTypeAttr<1>; let hasVerifier = 1; @@ -4621,6 +4622,36 @@ This operation creates a tensor of `shape` and `dtype`. let hasFolder = 1; } +def TF_EncodePngOp : TF_Op<"EncodePng", [Pure]> { + let summary = "PNG-encode an image."; + + let description = [{ +`image` is a 3-D uint8 or uint16 Tensor of shape `[height, width, channels]` +where `channels` is: + +* 1: for grayscale. +* 2: for grayscale + alpha. +* 3: for RGB. +* 4: for RGBA. + +The ZLIB compression level, `compression`, can be -1 for the PNG-encoder +default or a value from 0 to 9. 9 is the highest compression level, generating +the smallest output, but is slower. + }]; + + let arguments = (ins + Arg, [{3-D with shape `[height, width, channels]`.}]>:$image, + + DefaultValuedOptionalAttr:$compression + ); + + let results = (outs + Res:$contents + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + def TF_EnqueueTPUEmbeddingArbitraryTensorBatchOp : TF_Op<"EnqueueTPUEmbeddingArbitraryTensorBatch", [DeclareOpInterfaceMethods, SameVariadicOperandSize, TF_TPUEmbeddingWriteEffect]> { let summary = [{ Eases the porting of code that uses tf.nn.embedding_lookup_sparse(). @@ -11187,6 +11218,8 @@ underlying graph, and executes each of the partitioned subgraphs as a function. // Returns the callee of this operation. 
CallInterfaceCallable getCallableForCallee() { return getFAttr(); } + // Sets the callee from the callable. + void setCalleeFromCallable(CallInterfaceCallable callee); // returns the callee of this operation. func::FuncOp func() { @@ -21152,7 +21185,7 @@ for binary operators. }]; } -def TF_XlaCallModuleOp : TF_Op<"XlaCallModule", [Pure]> { +def TF_XlaCallModuleOp : TF_Op<"XlaCallModule", [DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "Invokes a StableHLO module."; let description = [{ @@ -21171,7 +21204,9 @@ platform argument (see `platforms`) nor the dimension arguments (see TF_ShapeAttrArray:$Sout, DefaultValuedOptionalAttr:$dim_args_spec, DefaultValuedOptionalAttr:$platforms, - DefaultValuedOptionalAttr:$function_list + DefaultValuedOptionalAttr:$function_list, + DefaultValuedOptionalAttr:$has_token_input_output, + DefaultValuedOptionalAttr:$disabled_checks ); let results = (outs @@ -22633,6 +22668,39 @@ expected to create these operators. }]; } +def TF__XlaCompileOp : TF_Op<"_XlaCompile", [AttrSizedOperandSegments]> { + let summary = "XLA Compile Op. For use by the XLA JIT only."; + + let description = [{ +Compiles a TensorFlow function into an XLA LocalExecutable and returns a key +that _XlaRun can use to look up the LocalExecutable and execute it. + }]; + + let arguments = (ins + Variadic:$constants, + Variadic:$args, + Variadic:$resources, + + BoolAttr:$must_compile, + SymbolRefAttr:$function + ); + + let results = (outs + Res:$key, + Res:$compilation_successful + ); + + TF_DerivedOperandSizeAttr Nresources = TF_DerivedOperandSizeAttr<2>; + TF_DerivedOperandTypeListAttr Targs = TF_DerivedOperandTypeListAttr<1>; + TF_DerivedOperandTypeListAttr Tconstants = TF_DerivedOperandTypeListAttr<0>; +} + def TF__XlaHostComputeMlirOp : TF_Op<"_XlaHostComputeMlir", [TF_RecvSideEffect, TF_SendSideEffect, TF_XlaHostComputeSideEffect]> { let summary = [{ A pseudo-op to represent host-side computation in an XLA program. 
@@ -22703,6 +22771,27 @@ execution the transfer corresponds to.}]>:$dynamic_key, TF_DerivedResultTypeListAttr Toutputs = TF_DerivedResultTypeListAttr<0>; } +def TF__XlaRunOp : TF_Op<"_XlaRun", []> { + let summary = "XLA Run Op. For use by the XLA JIT only."; + + let description = [{ +Executes a TensorFlow function previously compiled into a LocalExecutable by an +_XlaCompile op. + }]; + + let arguments = (ins + Variadic:$args, + TF_StrTensor:$key + ); + + let results = (outs + Variadic:$results + ); + + TF_DerivedOperandTypeListAttr Targs = TF_DerivedOperandTypeListAttr<0>; + TF_DerivedResultTypeListAttr Tresults = TF_DerivedResultTypeListAttr<0>; +} + def TF__XlaSendFromHostOp : TF_Op<"_XlaSendFromHost", [DeclareOpInterfaceMethods, TF_SendSideEffect]> { let summary = "A placeholder op to send values to a running XLA computation."; diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td index c9a890778f6..d40089d2948 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td @@ -409,6 +409,8 @@ def TF_LegacyCallOp : TF_Op<"LegacyCall", // Returns the callee of this operation. CallInterfaceCallable getCallableForCallee() { return getFAttr(); } + // Sets the callee from the callable + void setCalleeFromCallable(::mlir::CallInterfaceCallable callee); // Returns the resolved callee function of this operation. // Prefer passing in SymbolTableCollection to reduce lookup costs by @@ -570,6 +572,8 @@ underlying graph, and executes each of the partitioned subgraphs as a function. // Returns the callee of this operation. CallInterfaceCallable getCallableForCallee() { return getFAttr(); } + // Sets the callee from the callable + void setCalleeFromCallable(::mlir::CallInterfaceCallable callee); // Returns the resolved callee function of this operation. 
// Prefer passing in SymbolTableCollection to reduce lookup costs by @@ -1009,6 +1013,8 @@ def TF_TPUPartitionedCallOp : TF_Op<"TPUPartitionedCall", // Returns the callee of this operation. CallInterfaceCallable getCallableForCallee() { return getFAttr(); } + // Sets the callee from the callable. + void setCalleeFromCallable(CallInterfaceCallable callee); // Returns the resolved callee function of this operation. // Prefer passing in SymbolTableCollection to reduce lookup costs by diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc index dfa46846aa1..7f066b3f327 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc @@ -16,6 +16,8 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.h" #include +#include +#include #include #include #include @@ -3214,6 +3216,16 @@ LogicalResult LegacyCallOp::verifySymbolUses( return success(); } +void LegacyCallOp::setCalleeFromCallable(mlir::CallInterfaceCallable callee) { + // Direct call. + if (SymbolRefAttr fAttr = getFAttr()) { + SymbolRefAttr calleeAttr = callee.get(); + return setFAttr(cast(calleeAttr)); + } + // Indirect call, callee Value is the first operand. + return setOperand(0, callee.get()); +} + //===----------------------------------------------------------------------===// // LogOp //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_layout_helper.h b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_layout_helper.h index 77f87e0f960..29dae2715a6 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_layout_helper.h +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_layout_helper.h @@ -16,6 +16,9 @@ limitations under the License. 
#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_OPS_LAYOUT_HELPER_H_ #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_OPS_LAYOUT_HELPER_H_ +#include +#include + #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc index 36b9d6c6e20..62a047bd441 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc @@ -17,8 +17,10 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.h" #include +#include #include #include +#include #include #include #include @@ -61,6 +63,7 @@ limitations under the License. #include "mlir/IR/TypeUtilities.h" // from @llvm-project #include "mlir/IR/Types.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Interfaces/CallInterfaces.h" // from @llvm-project #include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project #include "mlir/Parser/Parser.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project @@ -79,6 +82,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/rewrite_util.h" #include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h" #include "tensorflow/compiler/mlir/tensorflow/utils/dynamic_shape_utils.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/util/tensor_format.h" @@ -563,6 +567,31 @@ LogicalResult TPUPartitionedCallOp::verifySymbolUses( return VerifyPartitionedCall(*this, symbolTable); } +template +static void SetPartitionCalleeFromCallable(CallOpClass op, + mlir::CallInterfaceCallable callee) { + // Direct call. 
+ if (SymbolRefAttr fAttr = op.getFAttr()) { + SymbolRefAttr calleeAttr = callee.get<SymbolRefAttr>(); + return op.setFAttr(cast<FlatSymbolRefAttr>(calleeAttr)); + } + // Indirect call, callee Value is the first operand. + return op.setOperand(0, callee.get<Value>()); +} + +void PartitionedCallOp::setCalleeFromCallable( + mlir::CallInterfaceCallable callee) { + return SetPartitionCalleeFromCallable(*this, callee); +} +void StatefulPartitionedCallOp::setCalleeFromCallable( + CallInterfaceCallable callee) { + return SetPartitionCalleeFromCallable(*this, callee); +} +void TPUPartitionedCallOp::setCalleeFromCallable( + mlir::CallInterfaceCallable callee) { + return SetPartitionCalleeFromCallable(*this, callee); +} + //===----------------------------------------------------------------------===// // PowOp //===----------------------------------------------------------------------===// @@ -1057,7 +1086,7 @@ static Type InferSelectV2OpType(Value condition, Value e, Value t) { if (!cond_ranked_ty || !broadcasted_ranked_ty) return unranked_ty; // Explicitly get broadcasted output type as element types of condition may - // not be same as the broadcated type's element type. + // not be same as the broadcasted type's element type. SmallVector result_shape; if (!OpTrait::util::getBroadcastedShape(cond_ranked_ty.getShape(), broadcasted_ranked_ty.getShape(), @@ -2829,6 +2858,27 @@ OpFoldResult FoldCancellableTranspose(TransposeOp op) { auto transpose = dyn_cast_or_null(op.getX().getDefiningOp()); if (!transpose) return {}; + // If the transpose ops are on different devices, we don't fold them. + if (transpose->getBlock() != op->getBlock()) { + tensorflow::DataType dtype; + auto status = tensorflow::ConvertToDataType( + op.getX().getType().cast<ShapedType>().getElementType(), &dtype); + if (status.ok()) { + // We can only leave the transpose op on host if its dtype is supported on + // host.
+ if (dtype == tensorflow::DT_UINT64 || dtype == tensorflow::DT_INT64 || + dtype == tensorflow::DT_UINT32 || dtype == tensorflow::DT_INT32 || + dtype == tensorflow::DT_UINT16 || dtype == tensorflow::DT_INT16 || + dtype == tensorflow::DT_UINT8 || dtype == tensorflow::DT_INT8 || + dtype == tensorflow::DT_HALF || dtype == tensorflow::DT_BFLOAT16 || + dtype == tensorflow::DT_FLOAT || dtype == tensorflow::DT_DOUBLE || + dtype == tensorflow::DT_COMPLEX64 || + dtype == tensorflow::DT_COMPLEX128 || dtype == tensorflow::DT_BOOL) { + return {}; + } + } + } + // Permutations defined by constant operations. DenseIntElementsAttr perm0; DenseIntElementsAttr perm1; @@ -2933,6 +2983,39 @@ void FusedBatchNormOp::getCanonicalizationPatterns(RewritePatternSet &results, results.add(context); } +//===----------------------------------------------------------------------===// +// XlaCallModuleOp +//===----------------------------------------------------------------------===// + +void XlaCallModuleOp::getEffects( + SmallVectorImpl> + &effects) { + if (!getFunctionList().empty()) { + // The StableHLO module embedded in XlaCallModule contains + // `stablehlo.custom_call` calling TF host callback functions. + // `stablehlo.custom_call` will be lowered to `stablehlo.send` and + // `stablehlo.recv`. 
+ effects.emplace_back(MemoryEffects::Write::get(), + ResourceEffects::Send::get()); + effects.emplace_back(MemoryEffects::Write::get(), + ResourceEffects::Recv::get()); + effects.emplace_back(MemoryEffects::Write::get(), + ResourceEffects::XlaHostCompute::get()); + } +} + +LogicalResult XlaCallModuleOp::verifySymbolUses( + SymbolTableCollection &symbolTable) { + for (auto f : getFunctionList()) { + auto func = symbolTable.lookupNearestSymbolFrom( + getOperation(), f.cast()); + if (!func) { + return emitOpError() << "refers to an undefined function: " << f; + } + } + return success(); +} + //===----------------------------------------------------------------------===// // XlaLaunchOp //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc index 7bf2b3ca1f1..b295461d533 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc @@ -15,6 +15,8 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h" +#include + #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/DenseSet.h" #include "llvm/ADT/STLExtras.h" diff --git a/tensorflow/compiler/mlir/tensorflow/tests/BUILD b/tensorflow/compiler/mlir/tensorflow/tests/BUILD index 6cc7344b083..4c2e9dc642c 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/tests/BUILD @@ -7,6 +7,7 @@ package( ) glob_lit_tests( + name = "all_tests", data = [":test_utilities"], driver = "@llvm-project//mlir:run_lit.sh", size_override = { diff --git a/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir b/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir index 49ba1afe393..7bca5e649b2 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir @@ -906,6 +906,26 @@ func.func @cancellableTranspose(%arg0: tensor<1x4x4x8xf32>) -> tensor<1x4x4x8xf3 // CHECK: return %arg0 } +// CHECK-LABEL: @nonCancellableTransposeCrossRegion +func.func @nonCancellableTransposeCrossRegion(%arg0: tensor<1x4x4x8xf32>) -> tensor<1x4x4x8xf32> { + %0 = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32> + %1 = "tf.Const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi32>} : () -> tensor<4xi32> + %2 = "tf.Transpose"(%arg0, %0) : (tensor<1x4x4x8xf32>, tensor<4xi32>) -> tensor<1x8x4x4xf32> + + %result = "tf_device.launch"() ({ + %3 = "tf.Transpose"(%2, %1) : (tensor<1x8x4x4xf32>, tensor<4xi32>) -> tensor<1x4x4x8xf32> + tf_device.return %3: tensor<1x4x4x8xf32> + }) {device = "device"} : () -> tensor<1x4x4x8xf32> + + func.return %result : tensor<1x4x4x8xf32> + + // CHECK-DAG: %[[CONST1:.*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} + // CHECK-DAG: %[[CONST2:.*]] = "tf.Const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi32>} + // CHECK: %[[TRANS1:.*]] = "tf.Transpose"(%arg0, %[[CONST1]]) : (tensor<1x4x4x8xf32>, 
tensor<4xi32>) -> tensor<1x8x4x4xf32> + // CHECK: %[[TRANS2:.*]] = "tf.Transpose"(%[[TRANS1]], %[[CONST2]]) : (tensor<1x8x4x4xf32>, tensor<4xi32>) -> tensor<1x4x4x8xf32> + // CHECK: return %[[TRANS2]] +} + // CHECK-LABEL: @cancellableTransposeConst func.func @cancellableTransposeConst(%arg0: tensor<1x4x4x8xf32>) -> tensor<1x4x4x8xf32> { %0 = arith.constant dense<[0, 3, 1, 2]> : tensor<4xi32> diff --git a/tensorflow/compiler/mlir/tensorflow/tests/compile_mlir_util/BUILD b/tensorflow/compiler/mlir/tensorflow/tests/compile_mlir_util/BUILD index 9132abf2fe5..20ca45e8264 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/compile_mlir_util/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/tests/compile_mlir_util/BUILD @@ -6,6 +6,7 @@ load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") licenses(["notice"]) glob_lit_tests( + name = "all_tests", data = [":test_utilities"], driver = "@llvm-project//mlir:run_lit.sh", test_file_exts = [ diff --git a/tensorflow/compiler/mlir/tensorflow/tests/embedding_pipelining.mlir b/tensorflow/compiler/mlir/tensorflow/tests/embedding_pipelining.mlir index f6bd3d4d586..408342b0ebd 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/embedding_pipelining.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/embedding_pipelining.mlir @@ -8,13 +8,19 @@ module { return } func.func private @while_body(%arg0: tensor) -> (tensor) { - // Verify that everything is extracted into one of the four functions. + // Verify the overall pipelining control flow and supporting functions. // The order of these functions is also significant. 
- // CHECK: {{.*StatefulPartitionedCall.* f = @_func_non_tpu.*}} - // CHECK-NEXT: {{.*StatefulPartitionedCall.* f = @_func_sc_forward.*}} - // CHECK-NEXT: {{.*StatefulPartitionedCall.* f = @_func_core_tpu.*}} - // CHECK-NEXT: {{.*StatefulPartitionedCall.* f = @_func_sc_backward.*}} - // CHECK-NEXT: return + // CHECK: {{.*StatefulPartitionedCall.* f = @while_cond.*}} + // CHECK: {{.*StatefulPartitionedCall.* f = @non_tpu.*}} + // CHECK: {{.*StatefulPartitionedCall.* f = @start_step_0.*}} + // CHECK: {{.*StatefulPartitionedCall.* f = @while_cond.*}} + // CHECK: {{.*StatefulPartitionedCall.* f = @non_tpu.*}} + // CHECK: {{.*StatefulPartitionedCall.* f = @start_step_1.*}} + // CHECK: {{.*StatefulPartitionedCall.* f = @while_cond.*}} + // CHECK: {{.*tf.While.* body = @new_while_body.* cond = @new_while_cond.*}} + // CHECK: {{.*StatefulPartitionedCall.* f = @finish_step_nm2.*}} + // CHECK: {{.*StatefulPartitionedCall.* f = @finish_step_nm1.*}} + // CHECK: return // metadata ops "tf.TPUReplicateMetadata"() {_has_manual_control_dependencies = true, _replication_info = "repl_info", num_replicas = 2 : i64} : () -> () %comp_res = "tf.TPUCompilationResult"() {_tpu_compilation_status = "repl_info"} : () -> tensor @@ -37,39 +43,20 @@ module { %0 = "tf.Less"(%arg0, %arg0) : (tensor, tensor) -> tensor return %0 : tensor } - // Generated functions + // Generated functions for control flow ops (if, while, switch) + // non_tpu should have to TPU ops - just identity and return (in this test). 
- // CHECK: func.func private @_func_non_tpu + // CHECK: func.func private @non_tpu // CHECK-NEXT: tf.Identity // CHECK-NEXT: return - // sc_forward should have TPU ops including replicated outputs but not inputs - // CHECK: func.func private @_func_sc_forward - // CHECK-NOT: TPUReplicatedInput - // CHECK-DAG: TPUReplicateMetadata - // CHECK-DAG: TPUCompilationResult - // CHECK-DAG: TPUReplicatedOutput - // CHECK: return - - // core_tput should have TPU ops including both replicated inputs and outputs - // CHECK: func.func private @_func_core_tpu - // CHECK-DAG: TPUReplicatedInput - // CHECK-DAG: TPUReplicateMetadata - // CHECK-DAG: TPUCompilationResult - // CHECK-DAG: TPUReplicatedOutput - // CHECK: return - - // sc_backward should have TPU ops including replicted inputs but not outputs - // CHECK: func.func private @_func_sc_backward - // CHECK-NOT: TPUReplicatedOutput - // CHECK-DAG: TPUReplicateMetadata - // CHECK-DAG: TPUCompilationResult - // CHECK-DAG: TPUReplicatedInput - // CHECK: return + // Since there is a backward pass, finish_step_nm2 should be non-empty. + // CHECK: func.func private @finish_step_nm2 + // CHECK-NEXT: tf.TPUReplicateMetadata } // ----- -// This test verifies that the extraction works correctly for evaluation-only models. +// This test verifies that the pipelining works correctly for evaluation-only models. module { func.func @main() { %cst = "tf.Const"() {value = dense<2> : tensor} : () -> tensor @@ -77,9 +64,19 @@ module { return } func.func private @while_body(%arg0: tensor) -> (tensor) { - // CHECK: {{.*StatefulPartitionedCall.* f = @_func_non_tpu.*}} - // CHECK-NEXT: {{.*StatefulPartitionedCall.* f = @_func_sc_forward.*}} - // CHECK-NEXT: {{.*StatefulPartitionedCall.* f = @_func_core_tpu.*}} + // The pipelining control flow and supporting functions stay the same as the training version above. + // The order of these functions is also significant. 
+ // CHECK: {{.*StatefulPartitionedCall.* f = @while_cond.*}} + // CHECK: {{.*StatefulPartitionedCall.* f = @non_tpu.*}} + // CHECK: {{.*StatefulPartitionedCall.* f = @start_step_0.*}} + // CHECK: {{.*StatefulPartitionedCall.* f = @while_cond.*}} + // CHECK: {{.*StatefulPartitionedCall.* f = @non_tpu.*}} + // CHECK: {{.*StatefulPartitionedCall.* f = @start_step_1.*}} + // CHECK: {{.*StatefulPartitionedCall.* f = @while_cond.*}} + // CHECK: {{.*tf.While.* body = @new_while_body.* cond = @new_while_cond.*}} + // CHECK: {{.*StatefulPartitionedCall.* f = @finish_step_nm2.*}} + // CHECK: {{.*StatefulPartitionedCall.* f = @finish_step_nm1.*}} + // CHECK: return // metadata ops "tf.TPUReplicateMetadata"() {_has_manual_control_dependencies = true, _replication_info = "repl_info", num_replicas = 1 : i64} : () -> () %1 = "tf.TPUCompilationResult"() {_tpu_compilation_status = "repl_info"} : () -> tensor @@ -99,8 +96,8 @@ module { %0 = "tf.Less"(%arg0, %arg0) : (tensor, tensor) -> tensor return %0 : tensor } - // Only verify sc_backward. The previous test case verifies everything else. - // CHECK: func.func private @_func_sc_backward + // There's no backward pass so finish_step_nm2 should be empty + // CHECK: func.func private @finish_step_nm2 // CHECK-NEXT: return } @@ -147,43 +144,6 @@ module { } } -// ----- -// A test verifying TPUReplicatedOutput in the input graph doesn't trigger -// any additional TPUReplicatedInput or TPUReplicatedOutput ops. 
-module { - func.func @main() { - %cst_1 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor - %cst_2 = "tf.Const"() {value = dense<2> : tensor} : () -> tensor - %0:2 = "tf.While"(%cst_1, %cst_2) {body = @while_body, cond = @while_cond, is_stateless = false} : (tensor, tensor) -> (tensor, tensor) - return - } - func.func private @while_body(%arg0: tensor, %arg1: tensor) -> (tensor, tensor) { - // CHECK: {{.*StatefulPartitionedCall.* f = @_func_non_tpu.*}} - // CHECK-NEXT: {{.*StatefulPartitionedCall.* f = @_func_sc_forward.*}} - // CHECK-NEXT: {{.*StatefulPartitionedCall.* f = @_func_core_tpu.*}} - // metadata ops - "tf.TPUReplicateMetadata"() {_has_manual_control_dependencies = true, _replication_info = "repl_info", num_replicas = 2 : i64} : () -> () - %1 = "tf.TPUCompilationResult"() {_tpu_compilation_status = "repl_info"} : () -> tensor - %2 = "tf.Const"() {_embedding_pipelining = "forward", _replication_info = "repl_info", value = dense<3> : tensor} : () -> tensor - %3:2 = "tf.TPUReplicatedOutput"(%2) {device = ""} : (tensor) -> (tensor, tensor) - - // core_tpu ops: - %res_t = "tf.Const"() {_replication_info = "repl_info", value = dense<4> : tensor} : () -> tensor - - // non_tpu_ops - %res_n = "tf.Const"() {value = dense<5> : tensor} : () -> tensor - - return %res_n, %3#1 : tensor, tensor - } - func.func private @while_cond(%arg0: tensor, %arg1: tensor) -> tensor { - %0 = "tf.Less"(%arg1, %arg0) : (tensor, tensor) -> tensor - return %0 : tensor - } - // CHECK-DAG: TPUReplicatedOutput - // CHECK-NOT: TPUReplicatedoutput - // CHECK-NOT: TPUReplicatedInput -} - // ----- // Verify error for backward pass with no forward pass. module { @@ -317,3 +277,207 @@ module { return %0 : tensor } } + +// ----- +// Verify one while body function per while loop op. 
+module { + func.func @main() { + %cst_main = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + %0 = "tf.While"(%cst_main) {body = @while_body, cond = @while_cond, is_stateless = false} : (tensor) -> (tensor) + // expected-error @+1 {{'tf.While' op multiple users of function.}} + %1 = "tf.While"(%cst_main) {body = @while_body, cond = @while_cond, is_stateless = false} : (tensor) -> (tensor) + return + } + func.func private @while_body(%arg0: tensor) -> (tensor) { + "tf.TPUReplicateMetadata"() {_has_manual_control_dependencies = true, _replication_info = "repl_info", num_replicas = 2 : i64} : () -> () + %comp_res = "tf.TPUCompilationResult"() {_tpu_compilation_status = "repl_info"} : () -> tensor + + // forward_ops + %res_f = "tf.Const"() {_embedding_pipelining = "forward", _replication_info = "repl_info", value = dense<2> : tensor} : () -> tensor + + // core_tpu ops: + %res_t = "tf.Identity"(%res_f) {_replication_info = "repl_info"} : (tensor) -> tensor + + // backward_ops + %res_b = "tf.Identity"(%res_t) {_embedding_pipelining = "backward", _replication_info = "repl_info"} : (tensor) -> tensor + + // non_tpu_ops + %res_n = "tf.Identity"(%arg0) : (tensor) -> tensor + + return %res_n : tensor + } + func.func private @while_cond(%arg0: tensor) -> tensor { + %0 = "tf.Less"(%arg0, %arg0) : (tensor, tensor) -> tensor + return %0 : tensor + } +} + +// ----- +// Verify that the function to be pipelined is a while loop body function. 
+module { + func.func @main() { + %cst_main = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + return + } + // expected-error @+1 {{'func.func' op unable to find while body user.}} + func.func private @while_body(%arg0: tensor) -> (tensor) { + "tf.TPUReplicateMetadata"() {_has_manual_control_dependencies = true, _replication_info = "repl_info", num_replicas = 2 : i64} : () -> () + %comp_res = "tf.TPUCompilationResult"() {_tpu_compilation_status = "repl_info"} : () -> tensor + + // forward_ops + %res_f = "tf.Const"() {_embedding_pipelining = "forward", _replication_info = "repl_info", value = dense<2> : tensor} : () -> tensor + + // core_tpu ops: + %res_t = "tf.Identity"(%res_f) {_replication_info = "repl_info"} : (tensor) -> tensor + + // backward_ops + %res_b = "tf.Identity"(%res_t) {_embedding_pipelining = "backward", _replication_info = "repl_info"} : (tensor) -> tensor + + // non_tpu_ops + %res_n = "tf.Identity"(%arg0) : (tensor) -> tensor + + return %res_n : tensor + } + func.func private @while_cond(%arg0: tensor) -> tensor { + %0 = "tf.Less"(%arg0, %arg0) : (tensor, tensor) -> tensor + return %0 : tensor + } +} + +// ----- +// This test verifies that TPUReplicatedInputOps for resource variable args are packed. 
+module { + func.func @main(%arg0: tensor<*x!tf_type.resource>> {tf._composite_device = "/job:tpu_host_worker/replica:0/task:0/device:COMPOSITE:0", tf._user_specified_name = "rsrc", tf.device = "/job:tpu_host_worker/replica:0/task:0/device:COMPOSITE:0"}) { + %cst_main = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + %0, %1 = "tf.While"(%cst_main, %arg0) {body = @while_body, cond = @while_cond, is_stateless = false} : (tensor, tensor<*x!tf_type.resource>>) -> (tensor, tensor<*x!tf_type.resource>>) + return + } + + func.func private @while_body(%arg0: tensor, %arg1: tensor<*x!tf_type.resource>> {tf._composite_device = "/job:tpu_host_worker/replica:0/task:0/device:COMPOSITE:0", tf._user_specified_name = "rsrc", tf.device = "/job:tpu_host_worker/replica:0/task:0/device:COMPOSITE:0"}) -> (tensor, tensor<*x!tf_type.resource>>) { + // metadata ops + "tf.TPUReplicateMetadata"() {_has_manual_control_dependencies = true, _replication_info = "repl_info", num_replicas = 2 : i64} : () -> () + %comp_res = "tf.TPUCompilationResult"() {_tpu_compilation_status = "repl_info"} : () -> tensor + + // expected-error @+1 {{'tf.TPUReplicatedInput' op unexpected variable input, not packed}} + %37 = "tf.TPUReplicatedInput"(%arg1) {device = "", index = -1 : i64, is_mirrored_variable = true, is_packed = false} : (tensor<*x!tf_type.resource>>) -> tensor<*x!tf_type.resource>> + + // forward_ops + %res_f = "tf.Const"() {_embedding_pipelining = "forward", _replication_info = "repl_info", value = dense<2> : tensor} : () -> tensor + + // core_tpu ops: + %res_t = "tf.Identity"(%res_f) {_replication_info = "repl_info"} : (tensor) -> tensor + + // backward_ops + %res_b = "tf.Identity"(%res_t) {_embedding_pipelining = "backward", _replication_info = "repl_info"} : (tensor) -> tensor + + // non_tpu_ops + %res_n = "tf.Identity"(%arg0) : (tensor) -> tensor + + %cst_12 = "tf.Const"() {_replication_info = "repl_info", _xla_compile_device_type = "TPU", device = "", value = dense<1> : tensor} : () 
-> tensor + "tf.AssignAddVariableOp"(%37, %cst_12) {_has_manual_control_dependencies = true, _replication_info = "while/cluster_while_body_451", _xla_compile_device_type = "TPU", device = ""} : (tensor<*x!tf_type.resource>>, tensor) -> () + + return %res_n, %arg1 : tensor, tensor<*x!tf_type.resource>> + } + func.func private @while_cond(%arg0: tensor, %arg1: tensor<*x!tf_type.resource>> {tf._composite_device = "/job:tpu_host_worker/replica:0/task:0/device:COMPOSITE:0", tf._user_specified_name = "rsrc", tf.device = "/job:tpu_host_worker/replica:0/task:0/device:COMPOSITE:0"}) -> tensor { + %0 = "tf.Less"(%arg0, %arg0) : (tensor, tensor) -> tensor + return %0 : tensor + } +} + +// ----- +// This test verifies that duplicate TPUReplicatedInput ops for a resource variable arg is an error. +module { + func.func @main(%arg0: tensor<*x!tf_type.resource>> {tf._composite_device = "/job:tpu_host_worker/replica:0/task:0/device:COMPOSITE:0", tf._user_specified_name = "rsrc", tf.device = "/job:tpu_host_worker/replica:0/task:0/device:COMPOSITE:0"}) { + %cst_main = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + %0, %1 = "tf.While"(%cst_main, %arg0) {body = @while_body, cond = @while_cond, is_stateless = false} : (tensor, tensor<*x!tf_type.resource>>) -> (tensor, tensor<*x!tf_type.resource>>) + return + } + + func.func private @while_body(%arg0: tensor, %arg1: tensor<*x!tf_type.resource>> {tf._composite_device = "/job:tpu_host_worker/replica:0/task:0/device:COMPOSITE:0", tf._user_specified_name = "rsrc", tf.device = "/job:tpu_host_worker/replica:0/task:0/device:COMPOSITE:0"}) -> (tensor, tensor<*x!tf_type.resource>>) { + // metadata ops + "tf.TPUReplicateMetadata"() {_has_manual_control_dependencies = true, _replication_info = "repl_info", num_replicas = 2 : i64} : () -> () + %comp_res = "tf.TPUCompilationResult"() {_tpu_compilation_status = "repl_info"} : () -> tensor + + // expected-error @+1 {{'tf.TPUReplicatedInput' op unexpected multiple TPUReplicatedInputOp for 
single argument}} + %37 = "tf.TPUReplicatedInput"(%arg1) {device = "", index = -1 : i64, is_mirrored_variable = true, is_packed = true} : (tensor<*x!tf_type.resource>>) -> tensor<*x!tf_type.resource>> + %38 = "tf.TPUReplicatedInput"(%arg1) {device = "", index = -1 : i64, is_mirrored_variable = true, is_packed = true} : (tensor<*x!tf_type.resource>>) -> tensor<*x!tf_type.resource>> + + // forward_ops + %res_f = "tf.Const"() {_embedding_pipelining = "forward", _replication_info = "repl_info", value = dense<2> : tensor} : () -> tensor + + // core_tpu ops: + %res_t = "tf.Identity"(%res_f) {_replication_info = "repl_info"} : (tensor) -> tensor + + // backward_ops + %res_b = "tf.Identity"(%res_t) {_embedding_pipelining = "backward", _replication_info = "repl_info"} : (tensor) -> tensor + + // non_tpu_ops + %res_n = "tf.Identity"(%arg0) : (tensor) -> tensor + + %cst_12 = "tf.Const"() {_replication_info = "repl_info", _xla_compile_device_type = "TPU", device = "", value = dense<1> : tensor} : () -> tensor + "tf.AssignAddVariableOp"(%37, %cst_12) {_has_manual_control_dependencies = true, _replication_info = "while/cluster_while_body_451", _xla_compile_device_type = "TPU", device = ""} : (tensor<*x!tf_type.resource>>, tensor) -> () + + return %res_n, %arg1 : tensor, tensor<*x!tf_type.resource>> + } + func.func private @while_cond(%arg0: tensor, %arg1: tensor<*x!tf_type.resource>> {tf._composite_device = "/job:tpu_host_worker/replica:0/task:0/device:COMPOSITE:0", tf._user_specified_name = "rsrc", tf.device = "/job:tpu_host_worker/replica:0/task:0/device:COMPOSITE:0"}) -> tensor { + %0 = "tf.Less"(%arg0, %arg0) : (tensor, tensor) -> tensor + return %0 : tensor + } +} + +// ----- +// This test verifies the EliminateResourceLoops workaround. 
+module { + func.func @main(%arg0: tensor<*x!tf_type.resource>> {tf._composite_device = "/job:tpu_host_worker/replica:0/task:0/device:COMPOSITE:0", tf._user_specified_name = "rsrc", tf.device = "/job:tpu_host_worker/replica:0/task:0/device:COMPOSITE:0"}) { + %cst_main = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + %0, %1 = "tf.While"(%cst_main, %arg0) {body = @while_body, cond = @while_cond, is_stateless = false} : (tensor, tensor<*x!tf_type.resource>>) -> (tensor, tensor<*x!tf_type.resource>>) + return + } + + func.func private @while_body(%arg0: tensor, %arg1: tensor<*x!tf_type.resource>> {tf._composite_device = "/job:tpu_host_worker/replica:0/task:0/device:COMPOSITE:0", tf._user_specified_name = "rsrc", tf.device = "/job:tpu_host_worker/replica:0/task:0/device:COMPOSITE:0"}) -> (tensor, tensor<*x!tf_type.resource>>) { + // metadata ops + "tf.TPUReplicateMetadata"() {_has_manual_control_dependencies = true, _replication_info = "repl_info", num_replicas = 2 : i64} : () -> () + %comp_res = "tf.TPUCompilationResult"() {_tpu_compilation_status = "repl_info"} : () -> tensor + + %rsrc_copy = "tf.StatefulPartitionedCall"(%arg1) {f = @broken_func, config = "", config_proto = "", executor_type = ""} : (tensor<*x!tf_type.resource>>) -> (tensor<*x!tf_type.resource>>) + // We expect uses of %rsrc_copy are replaced by the input resource variable (%arg1 in this context). 
+ "tf.StatefulPartitionedCall"(%arg1) {f = @func1, config = "", config_proto = "", executor_type = ""} : (tensor<*x!tf_type.resource>>) -> () + "tf.StatefulPartitionedCall"(%rsrc_copy) {f = @func2, config = "", config_proto = "", executor_type = ""} : (tensor<*x!tf_type.resource>>) -> () + + // forward_ops + %res_f = "tf.Const"() {_embedding_pipelining = "forward", _replication_info = "repl_info", value = dense<2> : tensor} : () -> tensor + + // core_tpu ops: + %res_t = "tf.Identity"(%res_f) {_replication_info = "repl_info"} : (tensor) -> tensor + + // backward_ops + %res_b = "tf.Identity"(%res_t) {_embedding_pipelining = "backward", _replication_info = "repl_info"} : (tensor) -> tensor + + // non_tpu_ops + %res_n = "tf.Identity"(%arg0) : (tensor) -> tensor + + %cst_12 = "tf.Const"() {_replication_info = "repl_info", _xla_compile_device_type = "TPU", device = "", value = dense<1> : tensor} : () -> tensor + + return %res_n, %arg1 : tensor, tensor<*x!tf_type.resource>> + } + func.func private @while_cond(%arg0: tensor, %arg1: tensor<*x!tf_type.resource>> {tf._composite_device = "/job:tpu_host_worker/replica:0/task:0/device:COMPOSITE:0", tf._user_specified_name = "rsrc", tf.device = "/job:tpu_host_worker/replica:0/task:0/device:COMPOSITE:0"}) -> tensor { + %0 = "tf.Less"(%arg0, %arg0) : (tensor, tensor) -> tensor + return %0 : tensor + } + func.func private @broken_func(%arg0: tensor<*x!tf_type.resource>> {tf._composite_device = "/job:tpu_host_worker/replica:0/task:0/device:COMPOSITE:0", tf._user_specified_name = "rsrc", tf.device = "/job:tpu_host_worker/replica:0/task:0/device:COMPOSITE:0"}) -> (tensor<*x!tf_type.resource>>) { + %x = "tf.Identity"(%arg0) : (tensor<*x!tf_type.resource>>) -> (tensor<*x!tf_type.resource>>) + %y = "tf.Identity"(%x) : (tensor<*x!tf_type.resource>>) -> (tensor<*x!tf_type.resource>>) + return %y : tensor<*x!tf_type.resource>> + } + func.func private @func1(%arg0: tensor<*x!tf_type.resource>> {tf._composite_device = 
"/job:tpu_host_worker/replica:0/task:0/device:COMPOSITE:0", tf._user_specified_name = "rsrc", tf.device = "/job:tpu_host_worker/replica:0/task:0/device:COMPOSITE:0"}) -> () { + return + } + func.func private @func2(%arg0: tensor<*x!tf_type.resource>> {tf._composite_device = "/job:tpu_host_worker/replica:0/task:0/device:COMPOSITE:0", tf._user_specified_name = "rsrc", tf.device = "/job:tpu_host_worker/replica:0/task:0/device:COMPOSITE:0"}) -> () { + return + } + // Make sure func1 and func2 use the original resource variable and not the result of @broken_func. + // CHECK: func.func private @non_tpu + // CHECK: {{.*%0 = \"tf.StatefulPartitionedCall\"\(%arg0\).*f = @broken_func.*}} + // CHECK: {{.*StatefulPartitionedCall\"\(%arg0\).*f = @func1.*}} + // CHECK: {{.*StatefulPartitionedCall\"\(%arg0\).*f = @func2.*}} +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/embedding_sequencing.mlir b/tensorflow/compiler/mlir/tensorflow/tests/embedding_sequencing.mlir new file mode 100644 index 00000000000..0a8a3069861 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/embedding_sequencing.mlir @@ -0,0 +1,319 @@ +// RUN: tf-opt %s -split-input-file -verify-diagnostics -tf-embedding-sequencing | FILECHECK_OPTS="" FileCheck %s + +// This test verifies the handling of TPU replicated inputs and outputs as well as the extraction of the four main functions. +module { + func.func @main() { + %cst_main = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + %0 = "tf.While"(%cst_main) {body = @while_body, cond = @while_cond, is_stateless = false} : (tensor) -> (tensor) + return + } + func.func private @while_body(%arg0: tensor) -> (tensor) { + // Verify that everything is extracted into one of the four functions. + // The order of these functions is also significant. 
+ // CHECK: {{.*StatefulPartitionedCall.* f = @_func_non_tpu.*}} + // CHECK-NEXT: {{.*StatefulPartitionedCall.* f = @_func_sc_forward.*}} + // CHECK-NEXT: {{.*StatefulPartitionedCall.* f = @_func_core_tpu.*}} + // CHECK-NEXT: {{.*StatefulPartitionedCall.* f = @_func_sc_backward.*}} + // CHECK-NEXT: return + // metadata ops + "tf.TPUReplicateMetadata"() {_has_manual_control_dependencies = true, _replication_info = "repl_info", num_replicas = 2 : i64} : () -> () + %comp_res = "tf.TPUCompilationResult"() {_tpu_compilation_status = "repl_info"} : () -> tensor + + // forward_ops + %res_f = "tf.Const"() {_embedding_pipelining = "forward", _replication_info = "repl_info", value = dense<2> : tensor} : () -> tensor + + // core_tpu ops: + %res_t = "tf.Identity"(%res_f) {_replication_info = "repl_info"} : (tensor) -> tensor + + // backward_ops + %res_b = "tf.Identity"(%res_t) {_embedding_pipelining = "backward", _replication_info = "repl_info"} : (tensor) -> tensor + + // non_tpu_ops + %res_n = "tf.Identity"(%arg0) : (tensor) -> tensor + + return %res_n : tensor + } + func.func private @while_cond(%arg0: tensor) -> tensor { + %0 = "tf.Less"(%arg0, %arg0) : (tensor, tensor) -> tensor + return %0 : tensor + } + // Generated functions + // non_tpu should have no TPU ops - just identity and return (in this test).
+ // CHECK: func.func private @_func_non_tpu + // CHECK-NEXT: tf.Identity + // CHECK-NEXT: return + + // sc_forward should have TPU ops including replicated outputs but not inputs + // CHECK: func.func private @_func_sc_forward + // CHECK-NOT: TPUReplicatedInput + // CHECK-DAG: TPUReplicateMetadata + // CHECK-DAG: TPUCompilationResult + // CHECK-DAG: TPUReplicatedOutput + // CHECK: return + + // core_tpu should have TPU ops including both replicated inputs and outputs + // CHECK: func.func private @_func_core_tpu + // CHECK-DAG: TPUReplicatedInput + // CHECK-DAG: TPUReplicateMetadata + // CHECK-DAG: TPUCompilationResult + // CHECK-DAG: TPUReplicatedOutput + // CHECK: return + + // sc_backward should have TPU ops including replicated inputs but not outputs + // CHECK: func.func private @_func_sc_backward + // CHECK-NOT: TPUReplicatedOutput + // CHECK-DAG: TPUReplicateMetadata + // CHECK-DAG: TPUCompilationResult + // CHECK-DAG: TPUReplicatedInput + // CHECK: return +} + +// ----- +// This test verifies that the extraction works correctly for evaluation-only models.
+module { + func.func @main() { + %cst = "tf.Const"() {value = dense<2> : tensor} : () -> tensor + %0 = "tf.While"(%cst) {body = @while_body, cond = @while_cond, is_stateless = false} : (tensor) -> (tensor) + return + } + func.func private @while_body(%arg0: tensor) -> (tensor) { + // CHECK: {{.*StatefulPartitionedCall.* f = @_func_non_tpu.*}} + // CHECK-NEXT: {{.*StatefulPartitionedCall.* f = @_func_sc_forward.*}} + // CHECK-NEXT: {{.*StatefulPartitionedCall.* f = @_func_core_tpu.*}} + // metadata ops + "tf.TPUReplicateMetadata"() {_has_manual_control_dependencies = true, _replication_info = "repl_info", num_replicas = 1 : i64} : () -> () + %1 = "tf.TPUCompilationResult"() {_tpu_compilation_status = "repl_info"} : () -> tensor + + // forward_ops + %res_f = "tf.Identity"(%arg0) {_embedding_pipelining = "forward", _replication_info = "repl_info"} : (tensor) -> tensor + + // core_tpu ops: + %res_t = "tf.Identity"(%res_f) {_replication_info = "repl_info"} : (tensor) -> tensor + + // non_tpu_ops + %res_n = "tf.Const"() {value = dense<2> : tensor} : () -> tensor + + return %res_n : tensor + } + func.func private @while_cond(%arg0: tensor) -> tensor { + %0 = "tf.Less"(%arg0, %arg0) : (tensor, tensor) -> tensor + return %0 : tensor + } + // Only verify sc_backward. The previous test case verifies everything else. + // CHECK: func.func private @_func_sc_backward + // CHECK-NEXT: return +} + +// ----- +// A test verifying too many TPUReplicateMetadataOp ops. Same logic tests too many TPUCompilationResultOp ops. 
+module { + func.func @main(%arg0: tensor<*x!tf_type.resource>, %arg1: tensor<*x!tf_type.resource>, %arg2: tensor<*x!tf_type.resource>>) { + %cst = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + %0 = "tf.While"(%cst) {body = @while_body, cond = @while_cond, is_stateless = false} : (tensor) -> (tensor) + return + } + // expected-error @+1 {{number of tf.TPUReplicateMetadata in loop body is not 1}} + func.func private @while_body(%arg0: tensor) -> (tensor) { + // metadata ops + %embedding_pass_trigger = "tf.Const"() {_embedding_pipelining = "forward", _replication_info = "repl_info", value = dense<2> : tensor} : () -> tensor + "tf.TPUReplicateMetadata"() {_has_manual_control_dependencies = true, _replication_info = "repl_info", num_replicas = 1 : i64} : () -> () + "tf.TPUReplicateMetadata"() {_has_manual_control_dependencies = true, _replication_info = "repl_info", num_replicas = 1 : i64} : () -> () + %1 = "tf.TPUCompilationResult"() {_tpu_compilation_status = "repl_info"} : () -> tensor + return %arg0 : tensor + } + func.func private @while_cond(%arg0: tensor) -> tensor { + return %arg0 : tensor + } +} + +// ----- +// A test verifying the replication region of TPUReplicateMetadataOp ops. Same logic tests too many TPUCompilationResultOp ops. 
+module { + func.func @main(%arg0: tensor<*x!tf_type.resource>, %arg1: tensor<*x!tf_type.resource>, %arg2: tensor<*x!tf_type.resource>>) { + %cst = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + %0 = "tf.While"(%cst) {body = @while_body, cond = @while_cond, is_stateless = false} : (tensor) -> (tensor) + return + } + func.func private @while_body(%arg0: tensor) -> (tensor) { + // metadata ops + %embedding_pass_trigger = "tf.Const"() {_embedding_pipelining = "forward", _replication_info = "repl_info", value = dense<2> : tensor} : () -> tensor + "tf.TPUReplicateMetadata"() {_has_manual_control_dependencies = true, _replication_info = "repl_info", num_replicas = 1 : i64} : () -> () + // expected-error @+1 {{'tf.TPUCompilationResult' op is not part of the replication region "repl_info" vs "wrong_repl_info"}} + %1 = "tf.TPUCompilationResult"() {_tpu_compilation_status = "wrong_repl_info"} : () -> tensor + return %arg0 : tensor + } + func.func private @while_cond(%arg0: tensor) -> tensor { + return %arg0 : tensor + } +} + +// ----- +// A test verifying TPUReplicatedOutput in the input graph doesn't trigger +// any additional TPUReplicatedInput or TPUReplicatedOutput ops. 
+module { + func.func @main() { + %cst_1 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + %cst_2 = "tf.Const"() {value = dense<2> : tensor} : () -> tensor + %0:2 = "tf.While"(%cst_1, %cst_2) {body = @while_body, cond = @while_cond, is_stateless = false} : (tensor, tensor) -> (tensor, tensor) + return + } + func.func private @while_body(%arg0: tensor, %arg1: tensor) -> (tensor, tensor) { + // CHECK: {{.*StatefulPartitionedCall.* f = @_func_non_tpu.*}} + // CHECK-NEXT: {{.*StatefulPartitionedCall.* f = @_func_sc_forward.*}} + // CHECK-NEXT: {{.*StatefulPartitionedCall.* f = @_func_core_tpu.*}} + // metadata ops + "tf.TPUReplicateMetadata"() {_has_manual_control_dependencies = true, _replication_info = "repl_info", num_replicas = 2 : i64} : () -> () + %1 = "tf.TPUCompilationResult"() {_tpu_compilation_status = "repl_info"} : () -> tensor + %2 = "tf.Const"() {_embedding_pipelining = "forward", _replication_info = "repl_info", value = dense<3> : tensor} : () -> tensor + %3:2 = "tf.TPUReplicatedOutput"(%2) {device = ""} : (tensor) -> (tensor, tensor) + + // core_tpu ops: + %res_t = "tf.Const"() {_replication_info = "repl_info", value = dense<4> : tensor} : () -> tensor + + // non_tpu_ops + %res_n = "tf.Const"() {value = dense<5> : tensor} : () -> tensor + + return %res_n, %3#1 : tensor, tensor + } + func.func private @while_cond(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = "tf.Less"(%arg1, %arg0) : (tensor, tensor) -> tensor + return %0 : tensor + } + // CHECK-DAG: TPUReplicatedOutput + // CHECK-NOT: TPUReplicatedOutput + // CHECK-NOT: TPUReplicatedInput +} + +// ----- +// Verify error for backward pass with no forward pass.
+module { + func.func @main() { + %cst_main = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + %0 = "tf.While"(%cst_main) {body = @while_body, cond = @while_cond, is_stateless = false} : (tensor) -> (tensor) + return + } + func.func private @while_body(%arg0: tensor) -> (tensor) { + "tf.TPUReplicateMetadata"() {_has_manual_control_dependencies = true, _replication_info = "repl_info", num_replicas = 2 : i64} : () -> () + %comp_res = "tf.TPUCompilationResult"() {_tpu_compilation_status = "repl_info"} : () -> tensor + + // forward_ops + %res_f = "tf.Const"() {value = dense<2> : tensor} : () -> tensor + + // core_tpu ops: + %res_t = "tf.Identity"(%res_f) {_replication_info = "repl_info"} : (tensor) -> tensor + + // backward_ops + // expected-error @+1 {{'tf.Identity' op embedding backwards pass op with no forwards pass ops}} + %res_b = "tf.Identity"(%res_t) {_embedding_pipelining = "backward", _replication_info = "repl_info"} : (tensor) -> tensor + + // non_tpu_ops + %res_n = "tf.Identity"(%arg0) : (tensor) -> tensor + + return %res_n : tensor + } + func.func private @while_cond(%arg0: tensor) -> tensor { + %0 = "tf.Less"(%arg0, %arg0) : (tensor, tensor) -> tensor + return %0 : tensor + } +} + +// ----- +// Verify error for unknown _embedding_pipelining attribute value. 
+module { + func.func @main() { + %cst_main = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + %0 = "tf.While"(%cst_main) {body = @while_body, cond = @while_cond, is_stateless = false} : (tensor) -> (tensor) + return + } + func.func private @while_body(%arg0: tensor) -> (tensor) { + "tf.TPUReplicateMetadata"() {_has_manual_control_dependencies = true, _replication_info = "repl_info", num_replicas = 2 : i64} : () -> () + %comp_res = "tf.TPUCompilationResult"() {_tpu_compilation_status = "repl_info"} : () -> tensor + + // forward_ops + %res_f = "tf.Const"() {_embedding_pipelining = "forward", _replication_info = "repl_info", value = dense<2> : tensor} : () -> tensor + + // core_tpu ops: + %res_t = "tf.Identity"(%res_f) {_replication_info = "repl_info"} : (tensor) -> tensor + + // backward_ops + // expected-error @+1 {{'tf.Identity' op embedding op has unknown _embedding_pipelining attribute value garbage.}} + %res_b = "tf.Identity"(%res_t) {_embedding_pipelining = "garbage", _replication_info = "repl_info"} : (tensor) -> tensor + + // non_tpu_ops + %res_n = "tf.Identity"(%arg0) : (tensor) -> tensor + + return %res_n : tensor + } + func.func private @while_cond(%arg0: tensor) -> tensor { + %0 = "tf.Less"(%arg0, %arg0) : (tensor, tensor) -> tensor + return %0 : tensor + } +} + +// ----- +// Verify error for multiple WhileOp use of while_body function. 
+module { + func.func @main() { + %cst_main = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + %0 = "tf.While"(%cst_main) {body = @while_body, cond = @while_cond, is_stateless = false} : (tensor) -> (tensor) + // expected-error @+1 {{'tf.While' op multiple users of function.}} + %1 = "tf.While"(%cst_main) {body = @while_body, cond = @while_cond, is_stateless = false} : (tensor) -> (tensor) + return + } + func.func private @while_body(%arg0: tensor) -> (tensor) { + "tf.TPUReplicateMetadata"() {_has_manual_control_dependencies = true, _replication_info = "repl_info", num_replicas = 2 : i64} : () -> () + %comp_res = "tf.TPUCompilationResult"() {_tpu_compilation_status = "repl_info"} : () -> tensor + + // forward_ops + %res_f = "tf.Const"() {_embedding_pipelining = "forward", _replication_info = "repl_info", value = dense<2> : tensor} : () -> tensor + + // core_tpu ops: + %res_t = "tf.Identity"(%res_f) {_replication_info = "repl_info"} : (tensor) -> tensor + + // backward_ops + %res_b = "tf.Identity"(%res_t) {_embedding_pipelining = "backward", _replication_info = "repl_info"} : (tensor) -> tensor + + // non_tpu_ops + %res_n = "tf.Identity"(%arg0) : (tensor) -> tensor + + return %res_n : tensor + } + func.func private @while_cond(%arg0: tensor) -> tensor { + %0 = "tf.Less"(%arg0, %arg0) : (tensor, tensor) -> tensor + return %0 : tensor + } +} + +// ----- +// Verify error for non-WhileOp use of while_body function. 
+module { + func.func @main() { + %cst_main = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + %0 = "tf.While"(%cst_main) {body = @while_body, cond = @while_cond, is_stateless = false} : (tensor) -> (tensor) + // expected-error @+1 {{'tf.StatefulPartitionedCall' op non while use of function.}} + %38 = "tf.StatefulPartitionedCall"(%cst_main) {config = "", config_proto = "", executor_type = "", f = @while_body} : (tensor) -> tensor + return + } + func.func private @while_body(%arg0: tensor) -> (tensor) { + "tf.TPUReplicateMetadata"() {_has_manual_control_dependencies = true, _replication_info = "repl_info", num_replicas = 2 : i64} : () -> () + %comp_res = "tf.TPUCompilationResult"() {_tpu_compilation_status = "repl_info"} : () -> tensor + + // forward_ops + %res_f = "tf.Const"() {_embedding_pipelining = "forward", _replication_info = "repl_info", value = dense<2> : tensor} : () -> tensor + + // core_tpu ops: + %res_t = "tf.Identity"(%res_f) {_replication_info = "repl_info"} : (tensor) -> tensor + + // backward_ops + %res_b = "tf.Identity"(%res_t) {_embedding_pipelining = "backward", _replication_info = "repl_info"} : (tensor) -> tensor + + // non_tpu_ops + %res_n = "tf.Identity"(%arg0) : (tensor) -> tensor + + return %res_n : tensor + } + func.func private @while_cond(%arg0: tensor) -> tensor { + %0 = "tf.Less"(%arg0, %arg0) : (tensor, tensor) -> tensor + return %0 : tensor + } +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/executor_tpuv1_island_coarsening/BUILD b/tensorflow/compiler/mlir/tensorflow/tests/executor_tpuv1_island_coarsening/BUILD index 954eca9c0e2..421bbd5de79 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/executor_tpuv1_island_coarsening/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/tests/executor_tpuv1_island_coarsening/BUILD @@ -6,6 +6,7 @@ load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") licenses(["notice"]) glob_lit_tests( + name = "all_tests", data = [":test_utilities"], driver = 
"@llvm-project//mlir:run_lit.sh", test_file_exts = [ diff --git a/tensorflow/compiler/mlir/tensorflow/tests/executor_tpuv1_island_inlining/BUILD b/tensorflow/compiler/mlir/tensorflow/tests/executor_tpuv1_island_inlining/BUILD index 954eca9c0e2..421bbd5de79 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/executor_tpuv1_island_inlining/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/tests/executor_tpuv1_island_inlining/BUILD @@ -6,6 +6,7 @@ load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") licenses(["notice"]) glob_lit_tests( + name = "all_tests", data = [":test_utilities"], driver = "@llvm-project//mlir:run_lit.sh", test_file_exts = [ diff --git a/tensorflow/compiler/mlir/tensorflow/tests/executor_tpuv1_outline_island/BUILD b/tensorflow/compiler/mlir/tensorflow/tests/executor_tpuv1_outline_island/BUILD index 954eca9c0e2..421bbd5de79 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/executor_tpuv1_outline_island/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/tests/executor_tpuv1_outline_island/BUILD @@ -6,6 +6,7 @@ load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") licenses(["notice"]) glob_lit_tests( + name = "all_tests", data = [":test_utilities"], driver = "@llvm-project//mlir:run_lit.sh", test_file_exts = [ diff --git a/tensorflow/compiler/mlir/tensorflow/tests/extract_outside_compilation.mlir b/tensorflow/compiler/mlir/tensorflow/tests/extract_outside_compilation.mlir index f9a097d8fef..8657ed861c3 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/extract_outside_compilation.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/extract_outside_compilation.mlir @@ -881,7 +881,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A" // CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B" // CHECK: %[[G_OUTPUT:[0-9]*]] = "tf.G" - // CHECK: "tf._XlaHostComputeMlir"(%6) + // CHECK: "tf._XlaHostComputeMlir"(%[[G_OUTPUT]]) // CHECK-SAME: key = 
"if_predicate_channel_1" // CHECK-NEXT: tf.IfRegion"(%[[G_OUTPUT]]) // CHECK: %[[HOST_COMPUTE_OUT:[0-9]*]] = "tf._XlaHostComputeMlir"(%[[B_OUTPUT]], %[[A_OUTPUT]]) @@ -932,7 +932,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A" // CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B" // CHECK: %[[G_OUTPUT:[0-9]*]] = "tf.G" - // CHECK: "tf._XlaHostComputeMlir"(%6) + // CHECK: "tf._XlaHostComputeMlir"(%[[G_OUTPUT]]) // CHECK-SAME: key = "if_predicate_channel_0" // CHECK-NEXT: tf.IfRegion"(%[[G_OUTPUT]]) // CHECK-NEXT: "tf.Yield"() : () -> () @@ -2098,3 +2098,203 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:loc func.return %0 : tensor<2xi32> } } + +// ----- + +module attributes {tf.devices = {"/job:localhost/replica:0/task:0/device:CPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0"}, tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 1443 : i32}} { + // Tests map_outside_compilation when there is no replication. 
+ // The sharding is: + // type: OTHER + // tile_assignment_dimensions: 2 + // tile_assignment_dimensions: 1 + // tile_assignment_devices: 0 + // tile_assignment_devices: 1 + // Serialized string: + // "\08\03\1A\02\02\01\22\02\00\01" + + // CHECK-LABEL: func @map_outside_compilation_not_replicated + func.func @map_outside_compilation_not_replicated() -> () { + // CHECK: "tf_device.parallel_execute" + // CHECK: "tf_device.launch" + // CHECK: %[[PROGRAM0:.+]] = "tf._TPUCompileMlirPlaceholderProgramKey" + // CHECK: %[[RECV0:.+]] = "tf._XlaRecvAtHost"(%[[PROGRAM0]]) + // CHECK-SAME: _xla_has_host_transfer = true + // CHECK-SAME: device_ordinal = 0 + // CHECK-SAME: key = "host_compute_channel_0_args" + // CHECK: %[[B0:.+]] = "tf.OpB"(%[[RECV0]]) : (tensor<2x2xi64>) -> tensor<2x2xi64> + // CHECK: "tf._XlaSendFromHost"(%[[B0]], %[[PROGRAM0]]) + // CHECK-SAME: _xla_has_host_transfer = true + // CHECK-SAME: device_ordinal = 0 + // CHECK-SAME: key = "host_compute_channel_0_retvals" + // CHECK: }, { + // CHECK: %[[PROGRAM1:.+]] = "tf._TPUCompileMlirPlaceholderProgramKey" + // CHECK: %[[RECV1:.+]] = "tf._XlaRecvAtHost"(%[[PROGRAM1]]) + // CHECK-SAME: _xla_has_host_transfer = true + // CHECK-SAME: device_ordinal = 1 + // CHECK-SAME: key = "host_compute_channel_0_args" + // CHECK: %[[B1:.+]] = "tf.OpB"(%[[RECV1]]) : (tensor<2x2xi64>) -> tensor<2x2xi64> + // CHECK: "tf._XlaSendFromHost"(%[[B1]], %[[PROGRAM1]]) + // CHECK-SAME: _xla_has_host_transfer = true + // CHECK-SAME: device_ordinal = 1 + // CHECK-SAME: key = "host_compute_channel_0_retvals" + // CHECK: }, { + // CHECK: "tf_device.cluster" + // CHECK: %[[A:.+]] = "tf.OpA" + // CHECK: %[[A_SHARD:.+]] = "tf.XlaSpmdFullToShardShape"(%[[A]]) {dim = -1 : i64, manual_sharding = "\08\03\1A\02\02\01\22\02\00\01", unspecified_dims = []} : (tensor<2x2xi64>) -> tensor<1x2xi64> + // CHECK: %[[B:.+]] = "tf._XlaHostComputeMlir"(%[[A_SHARD]]) + // CHECK-SAME: manual_sharding = true + // CHECK-SAME: recv_key = 
"host_compute_channel_0_retvals" + // CHECK-SAME: send_key = "host_compute_channel_0_args" + // CHECK: %[[B_FULL:.+]] = "tf.XlaSpmdShardToFullShape"(%[[B]]) {dim = -1 : i64, full_shape = #tf_type.shape<2x2>, manual_sharding = "\08\03\1A\02\02\01\22\02\00\01", unspecified_dims = []} : (tensor<1x2xi64>) -> tensor<2x2xi64> + // CHECK: "tf.OpC"(%[[B_FULL]]) + "tf_device.cluster"() ({ + %0 = "tf.OpA"() {_XlaSharding = "\08\03\1A\02\02\01\22\02\00\01"} : () -> tensor<2x2xi64> + %1 = "tf.OpB"(%0) {_xla_map_outside_compilation = "0", _xla_outside_compilation = "from_launch"} : (tensor<2x2xi64>) -> tensor<2x2xi64> + "tf.OpC"(%1) : (tensor<2x2xi64>) -> () + tf_device.return + }) {_xla_compile_device_type = "TPU", computation_shape = [], device = "", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1], host_compute_core = [], num_cores_per_replica = 2 : i64, padding_map = [], topology = "\0A\04\01\01\01\02\10\01\18\02\22\08\00\00\00\00\00\00\00\01*\02\08\01", use_spmd_for_xla_partitioning = true, use_tpu = true} : () -> () + return + } +} + +// ----- + +module attributes {tf.devices = {"/job:localhost/replica:0/task:0/device:CPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:0/device:TPU:2", "/job:localhost/replica:0/task:0/device:TPU:3", "/job:localhost/replica:0/task:0/device:TPU:4", "/job:localhost/replica:0/task:0/device:TPU:5", "/job:localhost/replica:0/task:0/device:TPU:6", "/job:localhost/replica:0/task:0/device:TPU:7", "/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0"}, tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 1458 : i32}} { + // Tests map_outside_compilation when there is replication. 
+ // The sharding is: + // type: OTHER + // tile_assignment_dimensions: 2 + // tile_assignment_dimensions: 1 + // tile_assignment_devices: 0 + // tile_assignment_devices: 1 + // Serialized string: + // "\08\03\1A\02\02\01\22\02\00\01" + + // CHECK-LABEL: func @map_outside_compilation_replicated + func.func @map_outside_compilation_replicated() -> () { + // CHECK: tf_device.replicate + // CHECK: "tf_device.parallel_execute" + // CHECK: %[[PROGRAM0:.+]] = "tf._TPUCompileMlirPlaceholderProgramKey" + // CHECK: %[[DEVICE0_0:.+]] = "tf._TPUDeviceOrdinalPlaceholder" + // CHECK: %[[RECV0:.+]] = "tf._XlaRecvAtHostV2"(%[[PROGRAM0]], %[[DEVICE0_0]]) + // CHECK-SAME: _xla_has_host_transfer = true + // CHECK-SAME: key = "host_compute_channel_0_args" + // CHECK: %[[B0:.+]] = "tf.OpB"(%[[RECV0]]) : (tensor<2x2xi64>) -> tensor<2x2xi64> + // CHECK: "tf._XlaSendFromHostV2"(%[[B0]], %[[PROGRAM0]], %[[DEVICE0_0]]) + // CHECK-SAME: _xla_has_host_transfer = true + // CHECK-SAME: key = "host_compute_channel_0_retvals" + // CHECK: }, { + // CHECK: %[[PROGRAM1:.+]] = "tf._TPUCompileMlirPlaceholderProgramKey" + // CHECK: %[[DEVICE1_0:.+]] = "tf._TPUDeviceOrdinalPlaceholder" + // CHECK: %[[ONE_0:.+]] = "tf.Const" + // CHECK-SAME: value = dense<1> + // CHECK: %[[DEVICE1_1:.+]] = "tf.AddV2"(%[[DEVICE1_0]], %[[ONE_0]]) + // CHECK: %[[RECV1:.+]] = "tf._XlaRecvAtHostV2"(%[[PROGRAM1]], %[[DEVICE1_1]]) + // CHECK-SAME: _xla_has_host_transfer = true + // CHECK-SAME: key = "host_compute_channel_0_args" + // CHECK: %[[B1:.+]] = "tf.OpB"(%[[RECV1]]) : (tensor<2x2xi64>) -> tensor<2x2xi64> + // CHECK: %[[ONE_1:.+]] = "tf.Const" + // CHECK-SAME: value = dense<1> + // CHECK: %[[DEVICE1_2:.+]] = "tf.AddV2"(%[[DEVICE1_0]], %[[ONE_1]]) + // CHECK: "tf._XlaSendFromHostV2"(%[[B1]], %[[PROGRAM1]], %[[DEVICE1_2]]) + // CHECK-SAME: _xla_has_host_transfer = true + // CHECK-SAME: key = "host_compute_channel_0_retvals" + // CHECK: }, { + // CHECK: "tf_device.cluster" + // CHECK: %[[A:.+]] = "tf.OpA" + // CHECK: 
%[[A_SHARD:.+]] = "tf.XlaSpmdFullToShardShape"(%[[A]]) {dim = -1 : i64, manual_sharding = "\08\03\1A\02\02\01\22\02\00\01", unspecified_dims = []} : (tensor<2x2xi64>) -> tensor<1x2xi64> + // CHECK: %[[B:.+]] = "tf._XlaHostComputeMlir"(%[[A_SHARD]]) + // CHECK-SAME: manual_sharding = true + // CHECK-SAME: recv_key = "host_compute_channel_0_retvals" + // CHECK-SAME: send_key = "host_compute_channel_0_args" + // CHECK: %[[B_FULL:.+]] = "tf.XlaSpmdShardToFullShape"(%[[B]]) {dim = -1 : i64, full_shape = #tf_type.shape<2x2>, manual_sharding = "\08\03\1A\02\02\01\22\02\00\01", unspecified_dims = []} : (tensor<1x2xi64>) -> tensor<2x2xi64> + // CHECK: "tf.OpC"(%[[B_FULL]]) + tf_device.replicate() {n = 4 : i32} { + "tf_device.cluster"() ({ + %0 = "tf.OpA"() {_XlaSharding = "\08\03\1A\02\02\01\22\02\00\01"} : () -> tensor<2x2xi64> + %1 = "tf.OpB"(%0) {_xla_map_outside_compilation = "0", _xla_outside_compilation = "from_launch"} : (tensor<2x2xi64>) -> tensor<2x2xi64> + "tf.OpC"(%1) : (tensor<2x2xi64>) -> () + tf_device.return + }) {_xla_compile_device_type = "TPU", computation_shape = [], device = "", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1], host_compute_core = [], num_cores_per_replica = 2 : i64, padding_map = [], topology = "\0A\04\02\02\01\02\10\01\18\08\22 \00\00\00\00\00\00\00\01\01\00\00\00\01\00\00\01\00\01\00\00\00\01\00\01\01\01\00\00\01\01\00\01*\02\08\01", use_spmd_for_xla_partitioning = true, use_tpu = true} : () -> () + tf_device.return + } + return + } +} + +// ----- + +module attributes {tf.devices = {"/job:localhost/replica:0/task:0/device:CPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0"}, tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 1443 : i32}} { + + // Test that map_outside_compilation's inputs are not unranked. 
+ func.func @map_outside_compilation_must_be_ranked() -> () { + "tf_device.cluster"() ({ + %0 = "tf.OpA"() : () -> tensor<*xi64> + // expected-error @+1 {{must be ranked}} + %1 = "tf.OpB"(%0) {_xla_map_outside_compilation = "0", _xla_outside_compilation = "from_launch"} : (tensor<*xi64>) -> tensor<*xi64> + tf_device.return + }) {_xla_compile_device_type = "TPU", computation_shape = [], device = "", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1], host_compute_core = [], num_cores_per_replica = 2 : i64, padding_map = [], topology = "\0A\04\01\01\01\02\10\01\18\02\22\08\00\00\00\00\00\00\00\01*\02\08\01", use_spmd_for_xla_partitioning = true, use_tpu = true} : () -> () + return + } + + // Test that map_outside_compilation's inputs have rank >= 1. + func.func @map_outside_compilation_must_have_rank_gte_1() -> () { + "tf_device.cluster"() ({ + %0 = "tf.OpA"() : () -> tensor + // expected-error @+1 {{must have rank at least one}} + %1 = "tf.OpB"(%0) {_xla_map_outside_compilation = "0", _xla_outside_compilation = "from_launch"} : (tensor) -> tensor + tf_device.return + }) {_xla_compile_device_type = "TPU", computation_shape = [], device = "", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1], host_compute_core = [], num_cores_per_replica = 2 : i64, padding_map = [], topology = "\0A\04\01\01\01\02\10\01\18\02\22\08\00\00\00\00\00\00\00\01*\02\08\01", use_spmd_for_xla_partitioning = true, use_tpu = true} : () -> () + return + } + + // Test that map_outside_compilation's inputs shapes are divisible by num_cores_per_replica. 
+ func.func @map_outside_compilation_div_num_cores_per_replica() -> () { + "tf_device.cluster"() ({ + %0 = "tf.OpA"() : () -> tensor<3xi64> + // expected-error @+1 {{divisible by num_cores_per_replica}} + %1 = "tf.OpB"(%0) {_xla_map_outside_compilation = "0", _xla_outside_compilation = "from_launch"} : (tensor<3xi64>) -> tensor<3xi64> + tf_device.return + }) {_xla_compile_device_type = "TPU", computation_shape = [], device = "", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1], host_compute_core = [], num_cores_per_replica = 2 : i64, padding_map = [], topology = "\0A\04\01\01\01\02\10\01\18\02\22\08\00\00\00\00\00\00\00\01*\02\08\01", use_spmd_for_xla_partitioning = true, use_tpu = true} : () -> () + return + } + + // Test that map_outside_compilation's preceeding ops have an _XlaSharding attribute. + func.func @map_outside_compilation_explicit_sharding() -> () { + "tf_device.cluster"() ({ + %0 = "tf.OpA"() : () -> tensor<2xi64> + // expected-error @+1 {{should have an explicit sharding}} + %1 = "tf.OpB"(%0) {_xla_map_outside_compilation = "0", _xla_outside_compilation = "from_launch"} : (tensor<2xi64>) -> tensor<2xi64> + tf_device.return + }) {_xla_compile_device_type = "TPU", computation_shape = [], device = "", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1], host_compute_core = [], num_cores_per_replica = 2 : i64, padding_map = [], topology = "\0A\04\01\01\01\02\10\01\18\02\22\08\00\00\00\00\00\00\00\01*\02\08\01", use_spmd_for_xla_partitioning = true, use_tpu = true} : () -> () + return + } + + // Test that map_outside_compilation has at least 1 input to the + // _XlaHostComputeMlir op. In this case, %arg0 is not input to the + // generated _XlaHostComputeMlir. 
+ func.func @map_outside_compilation_preceeding_op(%arg0 : tensor<2xi64>) -> () { + "tf_device.cluster"() ({ + // expected-error @+1 {{should have at least one input}} + %1 = "tf.OpB"(%arg0) {_xla_map_outside_compilation = "0", _xla_outside_compilation = "from_launch"} : (tensor<2xi64>) -> tensor<2xi64> + tf_device.return + }) {_xla_compile_device_type = "TPU", computation_shape = [], device = "", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1], host_compute_core = [], num_cores_per_replica = 2 : i64, padding_map = [], topology = "\0A\04\01\01\01\02\10\01\18\02\22\08\00\00\00\00\00\00\00\01*\02\08\01", use_spmd_for_xla_partitioning = true, use_tpu = true} : () -> () + return + } + + // Test that map_outside_compilation inputs have the same sharding. + func.func @map_outside_compilation_same_sharding() -> () { + tf_device.replicate() {n = 4 : i32} { + "tf_device.cluster"() ({ + %0 = "tf.OpA"() {_XlaSharding = "\08\03\1A\02\02\01\22\02\00\01"} : () -> tensor<2x2xi64> + %1 = "tf.OpB"() {_XlaSharding = "\08\03\1A\02\02\01\22\02\00\02"} : () -> tensor<2x2xi64> + // expected-error @+1 {{should have the same sharding}} + %2 = "tf.OpC"(%0, %1) {_xla_map_outside_compilation = "0", _xla_outside_compilation = "from_launch"} : (tensor<2x2xi64>, tensor<2x2xi64>) -> tensor<2x2xi64> + "tf.OpD"(%2) : (tensor<2x2xi64>) -> () + tf_device.return + }) {_xla_compile_device_type = "TPU", computation_shape = [], device = "", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1], host_compute_core = [], num_cores_per_replica = 2 : i64, padding_map = [], topology = "\0A\04\02\02\01\02\10\01\18\08\22 \00\00\00\00\00\00\00\01\01\00\00\00\01\00\00\01\00\01\00\00\00\01\00\01\01\01\00\00\01\01\00\01*\02\08\01", use_spmd_for_xla_partitioning = true, use_tpu = true} : () -> () + tf_device.return + } + return + } +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/BUILD 
b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/BUILD index 71493d0f30a..186794e8891 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/BUILD @@ -6,6 +6,7 @@ load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") licenses(["notice"]) glob_lit_tests( + name = "all_tests", data = [ ":debug_info_files", ":test_utilities", diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/arg-retval-attrs.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/arg-retval-attrs.pbtxt index 2d2ad5b5083..515d74231df 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/arg-retval-attrs.pbtxt +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/arg-retval-attrs.pbtxt @@ -152,4 +152,4 @@ versions { # CHECK: func @main # CHECK-SAME: ({{%.*}}: tensor<*xf32>, {{%.*}}: tensor<*xi32> {tf._arg1_attr0 = "_arg1_attr0_value", tf._arg1_attr1 = 8.000000e+00 : f32}, {{%.*}}: tensor<*xi1>) # CHECK-SAME: -> (tensor<*xf32> {tf._ret0_attr0 = 8 : i64, tf._ret0_attr1 = false}, tensor<*xi32>, tensor<*xi1> {tf._ret2_attr0 = !tf_type.variant, tf._ret2_attr1 = #tf_type.shape<128x1024>}) -# CHECK-SAME: attributes {tf.entry_function = {control_outputs = "", inputs = "arg0,arg1,arg2", outputs = "ret0,ret1,ret2"}} +# CHECK-SAME: attributes {allow_soft_placement = false, tf.entry_function = {control_outputs = "", inputs = "arg0,arg1,arg2", outputs = "ret0,ret1,ret2"}} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/batch_use_same_function/BUILD b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/batch_use_same_function/BUILD index ab1cc6459a1..b770ab3bf89 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/batch_use_same_function/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/batch_use_same_function/BUILD @@ -6,6 +6,7 @@ load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") 
licenses(["notice"]) glob_lit_tests( + name = "all_tests", data = [ ":debug_info_files", ":test_utilities", diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-as-function.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-as-function.pbtxt index d4d5b8e3c52..eef2fbb92b2 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-as-function.pbtxt +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-as-function.pbtxt @@ -1,4 +1,4 @@ -# RUN: tf-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false %s -tf-graph-as-function -tf-xla-compile-device-type="GPU" -o - | FileCheck %s +# RUN: tf-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false %s -tf-graph-as-function -tf-xla-compile-device-type="GPU" -tf-enable-soft-placement-on-import=true -o - | FileCheck %s # Verify main graph was converted to a function, args/rets are mapped correctly, # and ops in the main graph are retained. 
In addition, check if subsequent @@ -6,6 +6,7 @@ # CHECK: func @main(%arg0: tensor<*x!tf_type.resource>, %arg1: tensor<*x!tf_type.resource>>, %arg2: tensor<*xf32>, %arg3: tensor<2x4x6x8xi32>) -> (tensor<*xf32>, tensor<*xf32>) # CHECK-SAME: _xla_compile_device_type = "GPU" +# CHECK-SAME: allow_soft_placement # CHECK-SAME: control_outputs = "" # CHECK-SAME: inputs = "args_0,args_1,args_2,args_3" # CHECK-SAME: outputs = "rets_0,rets_1" diff --git a/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir b/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir index 56b9adab296..68440b125d7 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir @@ -1,5 +1,5 @@ // NOTE: Assertions have been autogenerated by utils/generate-test-checks.py -// RUN: tf-opt -tf-legalize-hlo %s | FileCheck %s +// RUN: tf-opt -tf-legalize-hlo %s -verify-diagnostics -split-input-file | FileCheck %s // CHECK-LABEL: func @biasAdd_NHWC( // CHECK-SAME: %[[VAL_0:.*]]: tensor<1x32x10x32xi32>, @@ -1476,7 +1476,16 @@ func.func @convert_slice(%arg0: tensor<1x4672xf32>) -> tensor<1x519xf32> { func.func @reshape(%arg0: tensor<4x6xf32>) -> tensor<2x2x6xf32> { %0 = "mhlo.reshape"(%arg0) : (tensor<4x6xf32>) -> tensor<2x2x6xf32> func.return %0 : tensor<2x2x6xf32> +} +// CHECK-LABEL: func @round_nearest_even( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<2xf32>) -> tensor<2xf32> { +// CHECK: %[[VAL_1:.*]] = "tf.Round"(%[[VAL_0]]) : (tensor<2xf32>) -> tensor<2xf32> +// CHECK: return %[[VAL_1]] : tensor<2xf32> +// CHECK: } +func.func @round_nearest_even(%arg0: tensor<2xf32>) -> tensor<2xf32> { + %0 = "mhlo.round_nearest_even"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + return %0 : tensor<2xf32> } // CHECK-LABEL: func @convert_dot_2d_1d( @@ -1733,13 +1742,25 @@ func.func @no_convert_conv1d_feature_group_gt_1(%arg0: tensor<16x32x256xbf16>, % func.return %0 : tensor<16x32x128xbf16> } -// CHECK-LABEL: func.func 
@no_convert_conv1d_missing_windows_strides( -// CHECK-SAME: %[[VAL_0:.*]]: tensor<16x32x256xbf16>, -// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x256x256xbf16>) -> tensor<16x32x256xbf16> { -// CHECK: %[[VAL_2:.*]] = mhlo.convolution(%[[VAL_0]], %[[VAL_1]]) dim_numbers = [b, 0, f]x[0, i, o]->[b, 0, f], window = {pad = {{\[\[}}0, 0]], lhs_dilate = [1], rhs_dilate = [1]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64, precision_config = [#mhlo, #mhlo]} : (tensor<16x32x256xbf16>, tensor<1x256x256xbf16>) -> tensor<16x32x256xbf16> -// CHECK: return %[[VAL_2]] : tensor<16x32x256xbf16> +// CHECK-LABEL: func.func @convert_conv1d_missing_windows_strides_fallback( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<16x32x256xbf16>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x256x256xbf16>) -> tensor<16x32x256xbf16> { +// CHECK-DAG: %[[VAL_2:.*]] = arith.constant dense<[16, 32, 256, 1]> : tensor<4xi64> +// CHECK: %[[VAL_3:.*]] = "tf.Reshape"(%[[VAL_0]], %[[VAL_2]]) : (tensor<16x32x256xbf16>, tensor<4xi64>) -> tensor<16x32x256x1xbf16> +// CHECK-DAG: %[[VAL_4:.*]] = "tf.Const"() {value = dense<[0, 1, 3, 2]> : tensor<4xi64>} : () -> tensor<4xi64> +// CHECK: %[[VAL_5:.*]] = "tf.Transpose"(%[[VAL_3]], %[[VAL_4]]) : (tensor<16x32x256x1xbf16>, tensor<4xi64>) -> tensor<16x32x1x256xbf16> +// CHECK-DAG: %[[VAL_6:.*]] = arith.constant dense<[1, 256, 256, 1]> : tensor<4xi64> +// CHECK: %[[VAL_7:.*]] = "tf.Reshape"(%[[VAL_1]], %[[VAL_6]]) : (tensor<1x256x256xbf16>, tensor<4xi64>) -> tensor<1x256x256x1xbf16> +// CHECK-DAG: %[[VAL_8:.*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>} : () -> tensor<4xi64> +// CHECK: %[[VAL_9:.*]] = "tf.Transpose"(%[[VAL_7]], %[[VAL_8]]) : (tensor<1x256x256x1xbf16>, tensor<4xi64>) -> tensor<1x1x256x256xbf16> +// CHECK: %[[VAL_10:.*]] = "tf.Conv2D"(%[[VAL_5]], %[[VAL_9]]) {data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "VALID", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true} : (tensor<16x32x1x256xbf16>, 
tensor<1x1x256x256xbf16>) -> tensor<16x32x1x256xbf16> +// CHECK: %[[VAL_11:.*]] = "tf.Const"() {value = dense<[0, 1, 3, 2]> : tensor<4xi64>} : () -> tensor<4xi64> +// CHECK: %[[VAL_12:.*]] = "tf.Transpose"(%[[VAL_10]], %[[VAL_11]]) : (tensor<16x32x1x256xbf16>, tensor<4xi64>) -> tensor<16x32x256x1xbf16> +// CHECK: %[[VAL_13:.*]] = arith.constant dense<[16, 32, 256]> : tensor<3xi64> +// CHECK: %[[VAL_14:.*]] = "tf.Reshape"(%[[VAL_12]], %[[VAL_13]]) : (tensor<16x32x256x1xbf16>, tensor<3xi64>) -> tensor<16x32x256xbf16> +// CHECK: return %[[VAL_14]] : tensor<16x32x256xbf16> // CHECK: } -func.func @no_convert_conv1d_missing_windows_strides(%arg0: tensor<16x32x256xbf16>, %arg1: tensor<1x256x256xbf16>) -> tensor<16x32x256xbf16> { +func.func @convert_conv1d_missing_windows_strides_fallback(%arg0: tensor<16x32x256xbf16>, %arg1: tensor<1x256x256xbf16>) -> tensor<16x32x256xbf16> { %0 = "mhlo.convolution"(%arg0, %arg1) { batch_group_count = 1 : i64, dimension_numbers = #mhlo.conv<[b, 0, f]x[0, i, o]->[b, 0, f]>, @@ -1752,6 +1773,25 @@ func.func @no_convert_conv1d_missing_windows_strides(%arg0: tensor<16x32x256xbf1 func.return %0 : tensor<16x32x256xbf16> } +// CHECK-LABEL: func.func @convert_conv1d_missing_windows_strides_fallback_2( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x64x64x4xbf16>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<3x3x4x320xbf16>) -> tensor<1x62x62x320xbf16> { +// CHECK: %[[VAL_2:.*]] = "tf.Conv2D"(%[[VAL_0]], %[[VAL_1]]) {data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "VALID", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true} : (tensor<1x64x64x4xbf16>, tensor<3x3x4x320xbf16>) -> tensor<1x62x62x320xbf16> +// CHECK: return %[[VAL_2]] : tensor<1x62x62x320xbf16> +// CHECK: } +func.func @convert_conv1d_missing_windows_strides_fallback_2(%arg0: tensor<1x64x64x4xbf16>, %arg1: tensor<3x3x4x320xbf16>) -> tensor<1x62x62x320xbf16> { + %0 = "mhlo.convolution"(%arg0, %arg1) { + batch_group_count = 1 : i64, + dimension_numbers = #mhlo.conv<[b, 0, 
1, f]x[0, 1, i, o]->[b, 0, 1, f]>, + feature_group_count = 1 : i64, + lhs_dilation = dense<[1, 1]> : tensor<2xi64>, + padding = dense<0> : tensor<2x2xi64>, + precision_config = [#mhlo, #mhlo], + rhs_dilation = dense<[1, 1]> : tensor<2xi64> + } : (tensor<1x64x64x4xbf16>, tensor<3x3x4x320xbf16>) -> tensor<1x62x62x320xbf16> + func.return %0 : tensor<1x62x62x320xbf16> +} + // CHECK-LABEL: func @convert_conv2d( // CHECK-SAME: %[[VAL_0:.*]]: tensor<1x8x8x207xf32>, // CHECK-SAME: %[[VAL_1:.*]]: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> { @@ -2625,6 +2665,89 @@ func.func @convert_gather_offset(%arg0: tensor<1x20xi32>, %arg1: tensor<1x1xi32> func.return %0 : tensor<1x1xi32> } +// CHECK-LABEL: func @convert_gather_to_slice_batch_size_1( +// CHECK-SAME: %[[ARG_0:.*]]: tensor<1x2944xi32>, +// CHECK-SAME: %[[ARG_1:.*]]: tensor<1x2xi32>) +// CHECK-DAG: %[[CST:.*]] = "tf.Const"() {value = dense<[0, 1440]> : tensor<2xi32>} : () -> tensor<2xi32> +// CHECK-DAG: %[[CST_0:.*]] = "tf.Const"() {value = dense<0> : tensor<2xi32>} : () -> tensor<2xi32> +// CHECK: %[[VAL_0:.*]] = "tf.Maximum"(%[[ARG_1]], %[[CST_0:.*]]) : (tensor<1x2xi32>, tensor<2xi32>) -> tensor<1x2xi32> +// CHECK: %[[VAL_1:.*]] = "tf.Minimum"(%[[VAL_0]], %[[CST]]) : (tensor<1x2xi32>, tensor<2xi32>) -> tensor<1x2xi32> +// CHECK-DAG: %[[CST_1:.*]] = "tf.Const"() {value = dense<[1, 1504]> : tensor<2xi32>} : () -> tensor<2xi32> +// CHECK: %[[VAL_2:.*]] = "tf.Squeeze"(%[[VAL_1]]) {squeeze_dims = [0]} : (tensor<1x2xi32>) -> tensor<2xi32> +// CHECK: %[[VAL_3:.*]] = "tf.Slice"(%[[ARG_0]], %[[VAL_2]], %[[CST_1]]) : (tensor<1x2944xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<1x1504xi32> +// CHECK: return %[[VAL_3]] +// CHECK: } +func.func @convert_gather_to_slice_batch_size_1(%arg0: tensor<1x2944xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x1504xi32> { + %0 = "mhlo.gather"(%arg0, %arg1) { + dimension_numbers = #mhlo.gather< + offset_dims = [1], + collapsed_slice_dims = [0], + start_index_map = [0, 1], + index_vector_dim = 1, + 
>, + indices_are_sorted = true, + slice_sizes = dense<[1, 1504]> : tensor<2xi64> + } : (tensor<1x2944xi32>, tensor<1x2xi32>) -> tensor<1x1504xi32> + func.return %0 : tensor<1x1504xi32> +} + +// CHECK-LABEL: func @convert_gather_to_slice( +// CHECK-SAME: %[[ARG_0:.*]]: tensor<3x2944xi32>, +// CHECK-SAME: %[[ARG_1:.*]]: tensor<3x2xi32>) +// CHECK-DAG: %[[CST:.*]] = "tf.Const"() {value = dense<[2, 1440]> : tensor<2xi32>} : () -> tensor<2xi32> +// CHECK-DAG: %[[CST_0:.*]] = "tf.Const"() {value = dense<0> : tensor<2xi32>} : () -> tensor<2xi32> +// CHECK: %[[VAL_0:.*]] = "tf.Maximum"(%[[ARG_1]], %[[CST_0]]) : (tensor<3x2xi32>, tensor<2xi32>) -> tensor<3x2xi32> +// CHECK: %[[VAL_1:.*]] = "tf.Minimum"(%[[VAL_0]], %[[CST]]) : (tensor<3x2xi32>, tensor<2xi32>) -> tensor<3x2xi32> +// CHECK-DAG: %[[CST_1:.*]] = "tf.Const"() {value = dense<[1, 1504]> : tensor<2xi32>} : () -> tensor<2xi32> +// CHECK-DAG: %[[CST_2:.*]] = "tf.Const"() {value = dense<0> : tensor<2xi32>} : () -> tensor<2xi32> +// CHECK-DAG: %[[CST_3:.*]] = "tf.Const"() {value = dense<[1, 2]> : tensor<2xi32>} : () -> tensor<2xi32> +// CHECK: %[[VAL_2:.*]] = "tf.Slice"(%[[VAL_1]], %[[CST_2]], %[[CST_3]]) : (tensor<3x2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<1x2xi32> +// CHECK: %[[VAL_3:.*]] = "tf.Squeeze"(%[[VAL_2]]) {squeeze_dims = [0]} : (tensor<1x2xi32>) -> tensor<2xi32> +// CHECK: %[[VAL_4:.*]] = "tf.Slice"(%[[ARG_0]], %[[VAL_3]], %[[CST_1]]) : (tensor<3x2944xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<1x1504xi32> +// CHECK-DAG: %[[CST_4:.*]] = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32> +// CHECK-DAG: %[[CST_5:.*]] = "tf.Const"() {value = dense<[1, 2]> : tensor<2xi32>} : () -> tensor<2xi32> +// CHECK: %[[VAL_5:.*]] = "tf.Slice"(%[[VAL_1]], %[[CST_4]], %[[CST_5]]) : (tensor<3x2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<1x2xi32> +// CHECK: %[[VAL_6:.*]] = "tf.Squeeze"(%[[VAL_5]]) {squeeze_dims = [0]} : (tensor<1x2xi32>) -> tensor<2xi32> +// CHECK: %[[VAL_7:.*]] = 
"tf.Slice"(%[[ARG_0]], %[[VAL_6]], %[[CST_1]]) : (tensor<3x2944xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<1x1504xi32> +// CHECK-DAG: %[[CST_6:.*]] = "tf.Const"() {value = dense<[2, 0]> : tensor<2xi32>} : () -> tensor<2xi32> +// CHECK-DAG: %[[CST_7:.*]] = "tf.Const"() {value = dense<[1, 2]> : tensor<2xi32>} : () -> tensor<2xi32> +// CHECK: %[[VAL_8:.*]] = "tf.Slice"(%[[VAL_1]], %[[CST_6]], %[[CST_7]]) : (tensor<3x2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<1x2xi32> +// CHECK: %[[VAL_9:.*]] = "tf.Squeeze"(%[[VAL_8]]) {squeeze_dims = [0]} : (tensor<1x2xi32>) -> tensor<2xi32> +// CHECK: %[[VAL_10:.*]] = "tf.Slice"(%[[ARG_0]], %[[VAL_9]], %[[CST_1]]) : (tensor<3x2944xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<1x1504xi32> +// CHECK-DAG: %[[CST_8:.*]] = "tf.Const"() {value = dense<0> : tensor} : () -> tensor +// CHECK: %[[VAL_11:.*]] = "tf.ConcatV2"(%[[VAL_4]], %[[VAL_7]], %[[VAL_10]], %[[CST_8]]) : (tensor<1x1504xi32>, tensor<1x1504xi32>, tensor<1x1504xi32>, tensor) -> tensor<3x1504xi32> +// CHECK: return %[[VAL_11]] +// CHECK: } +func.func @convert_gather_to_slice(%arg0: tensor<3x2944xi32>, %arg1: tensor<3x2xi32>) -> tensor<3x1504xi32> { + %0 = "mhlo.gather"(%arg0, %arg1) { + dimension_numbers = #mhlo.gather< + offset_dims = [1], + collapsed_slice_dims = [0], + start_index_map = [0, 1], + index_vector_dim = 1, + >, + indices_are_sorted = true, + slice_sizes = dense<[1, 1504]> : tensor<2xi64> + } : (tensor<3x2944xi32>, tensor<3x2xi32>) -> tensor<3x1504xi32> + func.return %0 : tensor<3x1504xi32> +} + +// CHECK-LABEL: func @convert_gather_to_slice_dynamic_error +func.func @convert_gather_to_slice_dynamic_error(%arg0: tensor<3x?xi32>, %arg1: tensor<3x2xi32>) -> tensor<3x1504xi32> { + // expected-error @+1 {{Dynamic shaped inputs are not supported.}} + %0 = "mhlo.gather"(%arg0, %arg1) { + dimension_numbers = #mhlo.gather< + offset_dims = [1], + collapsed_slice_dims = [0], + start_index_map = [0, 1], + index_vector_dim = 1, + >, + indices_are_sorted = true, + 
slice_sizes = dense<[1, 1504]> : tensor<2xi64> + } : (tensor<3x?xi32>, tensor<3x2xi32>) -> tensor<3x1504xi32> + func.return %0 : tensor<3x1504xi32> +} + // CHECK-LABEL: func @convert_dynamic_slice( // CHECK-SAME: %[[VAL_0:.*]]: tensor<7x3xf32>, // CHECK-SAME: %[[VAL_1:.*]]: tensor, @@ -2755,6 +2878,39 @@ func.func @convert_scatter_update_to_non_trailing_operand_dimensions( func.return %0 : tensor<5x4x3x7xf32> } +// CHECK-LABEL: func @convert_scatter_update_reshape_indices_and_updates( +// CHECK-SAME: %[[ARG_0:.*]]: tensor<16x1504xf32>, +// CHECK-SAME: %[[ARG_1:.*]]: tensor<1xi32>, +// CHECK-SAME: %[[ARG_2:.*]]: tensor<16xf32>) -> tensor<16x1504xf32> { +// CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> tensor<2xi64> +// CHECK: %[[VAL_0:.*]] = "tf.Transpose"(%[[ARG_0]], %[[CST]]) : (tensor<16x1504xf32>, tensor<2xi64>) -> tensor<1504x16xf32> +// CHECK: %[[CST_0:.*]] = "tf.Const"() {value = dense<1> : tensor<2xi32>} : () -> tensor<2xi32> +// CHECK: %[[VAL_1:.*]] = "tf.Reshape"(%[[ARG_1]], %[[CST_0]]) : (tensor<1xi32>, tensor<2xi32>) -> tensor<1x1xi32> +// CHECK: %[[CST_1:.*]] = "tf.Const"() {value = dense<[1, 16]> : tensor<2xi32>} : () -> tensor<2xi32> +// CHECK: %[[VAL_2:.*]] = "tf.Reshape"(%[[ARG_2]], %[[CST_1]]) : (tensor<16xf32>, tensor<2xi32>) -> tensor<1x16xf32> +// CHECK: %[[VAL_3:.*]] = "tf.TensorScatterUpdate"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) : (tensor<1504x16xf32>, tensor<1x1xi32>, tensor<1x16xf32>) -> tensor<1504x16xf32> +// CHECK: %[[CST_2:.*]] = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> tensor<2xi64> +// CHECK: %[[VAL_4:.*]] = "tf.Transpose"(%[[VAL_3]], %[[CST_2]]) : (tensor<1504x16xf32>, tensor<2xi64>) -> tensor<16x1504xf32> +// CHECK: return %[[VAL_4]] +// CHECK: } +func.func @convert_scatter_update_reshape_indices_and_updates( + %arg0: tensor<16x1504xf32>, + %arg1: tensor<1xi32>, + %arg2: tensor<16xf32>) -> tensor<16x1504xf32> +{ + %0 = "mhlo.scatter"(%arg0, %arg1, %arg2) ({ +^bb0(%arg3: tensor, 
%arg4: tensor): + "mhlo.return"(%arg4) : (tensor) -> () +}) { + indices_are_sorted = true, + scatter_dimension_numbers = #mhlo.scatter< + update_window_dims = [0], + inserted_window_dims = [1], + scatter_dims_to_operand_dims = [1]>, + unique_indices = true} : (tensor<16x1504xf32>, tensor<1xi32>, tensor<16xf32>) -> tensor<16x1504xf32> + func.return %0 : tensor<16x1504xf32> +} + // CHECK-LABEL: func @convert_scatter_add( // CHECK-SAME: %[[VAL_0:.*]]: tensor<20x6xf32>, // CHECK-SAME: %[[VAL_1:.*]]: tensor<4x1xi32>, @@ -3264,30 +3420,9 @@ func.func @if(%arg0: tensor) -> (tensor) { // CHECK-SAME: %[[VAL_2:[a-z0-9]*]]: tensor, // CHECK-SAME: %[[VAL_3:[a-z0-9]*]]: tensor, // CHECK-SAME: %[[VAL_4:[a-z0-9]*]]: tensor) -> tensor<28x1x100xf32> { -// CHECK-DAG: %[[CST_0:[_a-z0-9]*]] = "tf.Const"() {value = dense<0> : tensor} : () -> tensor -// CHECK-DAG: %[[CST_1:[_a-z0-9]*]] = "tf.Const"() {value = dense<1> : tensor} : () -> tensor -// CHECK: %[[START_IND:[_a-z0-9]*]] = "tf.Pack"(%[[VAL_2]], %[[VAL_3]], %[[VAL_4]]) {axis = 0 : i64} : (tensor, tensor, tensor) -> tensor<3xi32> -// CHECK-DAG: %[[OP_SHAPE:[_a-z0-9]*]] = "tf.Const"() {value = dense<[28, 1, 100]> : tensor<3xi32>} : () -> tensor<3xi32> -// CHECK-DAG: %[[UP_SHAPE:[_a-z0-9]*]] = "tf.Const"() {value = dense<[1, 1, 100]> : tensor<3xi32>} : () -> tensor<3xi32> -// CHECK: %[[MAX_START:[_a-z0-9]*]] = "tf.Sub"(%[[OP_SHAPE]], %[[UP_SHAPE]]) : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi32> -// CHECK: %[[START_1:[_a-z0-9]*]] = "tf.Minimum"(%[[START_IND]], %[[MAX_START]]) : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi32> -// CHECK: %[[CLAMP_START:[_a-z0-9]*]] = "tf.Maximum"(%[[START_1]], %[[CST_0]]) : (tensor<3xi32>, tensor) -> tensor<3xi32> -// CHECK: %[[N_OP:[_a-z0-9]*]] = "tf.Const"() {value = dense<2800> : tensor} : () -> tensor -// CHECK: %[[FLAT_RANGE:[_a-z0-9]*]] = "tf.Range"(%[[CST_0]], %[[N_OP]], %[[CST_1]]) : (tensor, tensor, tensor) -> tensor<2800xi32> -// CHECK: %[[OP_SHAPE_1:[_a-z0-9]*]] = "tf.Const"() {value = 
dense<[28, 1, 100]> : tensor<3xi32>} : () -> tensor<3xi32> -// CHECK: %[[RANGE:[_a-z0-9]*]] = "tf.Reshape"(%[[FLAT_RANGE]], %[[OP_SHAPE_1]]) : (tensor<2800xi32>, tensor<3xi32>) -> tensor<28x1x100xi32> -// CHECK: %[[UP_SHAPE_1:[_a-z0-9]*]] = "tf.Const"() {value = dense<[1, 1, 100]> : tensor<3xi32>} : () -> tensor<3xi32> -// CHECK: %[[UPDATE_IDX:[_a-z0-9]*]] = "tf.Slice"(%[[RANGE]], %[[CLAMP_START]], %[[UP_SHAPE_1]]) : (tensor<28x1x100xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<1x1x100xi32> -// CHECK: %[[FLAT_UP_SHAPE:[_a-z0-9]*]] = "tf.Const"() {value = dense<[100, 1]> : tensor<2xi32>} : () -> tensor<2xi32> -// CHECK: %[[FLAT_UP_IDX:[_a-z0-9]*]] = "tf.Reshape"(%[[UPDATE_IDX]], %[[FLAT_UP_SHAPE]]) : (tensor<1x1x100xi32>, tensor<2xi32>) -> tensor<100x1xi32> -// CHECK: %[[FLAT_OP_SHAPE:[_a-z0-9]*]] = "tf.Const"() {value = dense<2800> : tensor<1xi32>} : () -> tensor<1xi32> -// CHECK: %[[FLAT_OP:[_a-z0-9]*]] = "tf.Reshape"(%[[VAL_0]], %[[FLAT_OP_SHAPE]]) : (tensor<28x1x100xf32>, tensor<1xi32>) -> tensor<2800xf32> -// CHECK: %[[FLAT_UP_SHAPE_1:[_a-z0-9]*]] = "tf.Const"() {value = dense<100> : tensor<1xi32>} : () -> tensor<1xi32> -// CHECK: %[[FLAT_UP:[_a-z0-9]*]] = "tf.Reshape"(%[[VAL_1]], %[[FLAT_UP_SHAPE_1]]) : (tensor<1x1x100xf32>, tensor<1xi32>) -> tensor<100xf32> -// CHECK: %[[FLAT_RESULT:[_a-z0-9]*]] = "tf.TensorScatterUpdate"(%[[FLAT_OP]], %[[FLAT_UP_IDX]], %[[FLAT_UP]]) : (tensor<2800xf32>, tensor<100x1xi32>, tensor<100xf32>) -> tensor<2800xf32> -// CHECK: %[[RESULT:[_a-z0-9]*]] = "tf.Reshape"(%[[FLAT_RESULT]], %[[OP_SHAPE]]) : (tensor<2800xf32>, tensor<3xi32>) -> tensor<28x1x100xf32> -// CHECK: return %[[RESULT]] : tensor<28x1x100xf32> -// CHECK: } +// CHECK: %0 = "tf.Pack"(%arg2, %arg3, %arg4) {axis = 0 : i64} : (tensor, tensor, tensor) -> tensor<3xi32> +// CHECK: %1 = "tf.XlaDynamicUpdateSlice"(%arg0, %arg1, %0) : (tensor<28x1x100xf32>, tensor<1x1x100xf32>, tensor<3xi32>) -> tensor<28x1x100xf32> +// CHECK: return %1 : tensor<28x1x100xf32> func.func 
@convert_dynamic_update_slice(%arg0: tensor<28x1x100xf32>, %arg1: tensor<1x1x100xf32>, %arg2: tensor, %arg3: tensor, %arg4: tensor) -> tensor<28x1x100xf32> { %0 = "mhlo.dynamic_update_slice"(%arg0, %arg1, %arg2, %arg3, %arg4) : (tensor<28x1x100xf32>, tensor<1x1x100xf32>, tensor, tensor, tensor) -> tensor<28x1x100xf32> func.return %0 : tensor<28x1x100xf32> @@ -3577,7 +3712,7 @@ func.func @reduce_window_trivial_window_dims(%arg0: tensor<4x12xf32>) -> tensor< // expected-error @+1 {{no reduced dimension is found.}} %1 = "mhlo.reduce_window"(%arg0, %0) ({ ^bb0(%arg1: tensor, %arg2: tensor): - %2 = mhlo.add %arg1, %arg2 : tensor + %2 = mhlo.multiply %arg1, %arg2 : tensor "mhlo.return"(%2) : (tensor) -> () }) {padding = dense<0> : tensor<2x2xi64>, window_dimensions = dense<1> : tensor<2xi64>} : (tensor<4x12xf32>, tensor) -> tensor<4x12xf32> func.return %1 : tensor<4x12xf32> @@ -3596,3 +3731,30 @@ func.func @convert_dot_quant_type(%arg0: tensor<1x256xf32>, %arg1: tensor<256x!q %0 = "mhlo.dot"(%arg0, %arg1) {precision_config = [#mhlo, #mhlo]} : (tensor<1x256xf32>, tensor<256x!quant.uniform>) -> tensor<1xf32> func.return %0 : tensor<1xf32> } + +// CHECK-LABEL: func @convert_approx_top_k_custom_call( +// CHECK-SAME: %[[ARG_0:.*]]: tensor<1x4xf32>, +// CHECK-SAME: %[[ARG_1:.*]]: tensor<1x4xi32>, +// CHECK-SAME: %[[ARG_2:.*]]: tensor, +// CHECK-SAME: %[[ARG_3:.*]]: tensor) -> (tensor<1x4xf32>, tensor<1x4xi32>) { +// CHECK: %[[VALUES:.*]], %[[INDICES:.*]] = "tf.ApproxTopK"(%[[ARG_0]]) {aggregate_to_topk = true, is_max_k = true, k = 4 : i64, recall_target = 8.500000e-01 : f32, reduction_dimension = 1 : i64, reduction_input_size_override = -1 : i64} : (tensor<1x4xf32>) -> (tensor<1x4xf32>, tensor<1x4xi32>) +// CHECK: return %[[VALUES]], %[[INDICES]] : tensor<1x4xf32>, tensor<1x4xi32> +// CHECK: } +func.func @convert_approx_top_k_custom_call(%arg0: tensor<1x4xf32>, %arg1: tensor<1x4xi32>, %arg2: tensor, %arg3: tensor) -> (tensor<1x4xf32>, tensor<1x4xi32>) { + %0:2 = 
mhlo.custom_call @ApproxTopK(%arg0, %arg1, %arg2, %arg3) { + api_version = 4 : i32, + called_computations = [@top_k_gt_f32_comparator], + backend_config = { + aggregate_to_topk = true, + is_fallback = true, + recall_target = 8.500000e-01 : f32, + reduction_dim = 1 : i64, + reduction_input_size_override = -1 : i64, + top_k = 4 : i64} + } : (tensor<1x4xf32>, tensor<1x4xi32>, tensor, tensor) -> (tensor<1x4xf32>, tensor<1x4xi32>) + func.return %0#0, %0#1 : tensor<1x4xf32>, tensor<1x4xi32> +} +func.func @top_k_gt_f32_comparator(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor) -> tensor { + %0 = mhlo.compare GT, %arg0, %arg1 : (tensor, tensor) -> tensor + return %0 : tensor +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/BUILD b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/BUILD index 432d0ab8733..4c9bd8a03e6 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/BUILD @@ -6,6 +6,7 @@ load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") licenses(["notice"]) glob_lit_tests( + name = "all_tests", data = [":test_utilities"], driver = "@llvm-project//mlir:run_lit.sh", test_file_exts = ["mlir"], diff --git a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/function-order.mlir b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/function-order.mlir index 019197a0a6c..9e2a83f4e06 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/function-order.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/function-order.mlir @@ -4,10 +4,13 @@ func.func @main() { tf_executor.graph { // CHECK: node { - // CHECK-NEXT: name: "tf.foo" - // CHECK-NEXT: op: "foo" + // CHECK-NEXT: name: "tf.PartitionedCall" + // CHECK-NEXT: op: "PartitionedCall" + // CHECK: func { + // CHECK: name: "foo" + // CHECK: } // CHECK: } - %0:2 = tf_executor.island wraps "tf.foo"() {name = "tf.foo"} : () -> tensor<*xf32> + %0 = 
tf_executor.island wraps "tf.PartitionedCall"() {Tin = [], Tout = [], config = "", config_proto = "", device = "", executor_type = "", f = @foo, name = "Call_foo"} : () -> () tf_executor.fetch } func.return @@ -65,14 +68,17 @@ func.func @bar() { // CHECK-NEXT: name: "foo" // CHECK-NEXT: } // CHECK-NEXT: node_def { -// CHECK-NEXT: name: "tf.bar" -// CHECK-NEXT: op: "bar" +// CHECK-NEXT: name: "tf.PartitionedCall" +// CHECK-NEXT: op: "PartitionedCall" +// CHECK: func { +// CHECK: name: "bar" +// CHECK: } // CHECK: } // CHECK-NEXT: } // CHECK: } func.func @foo() { tf_executor.graph { - %0:2 = tf_executor.island wraps "tf.bar"() {name = "tf.bar"} : () -> tensor<*xf32> + %0 = tf_executor.island wraps "tf.PartitionedCall"() {Tin = [], Tout = [], config = "", config_proto = "", device = "", executor_type = "", f = @bar, name = "Call_bar"} : () -> () tf_executor.fetch } func.return diff --git a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/tf-gradient-attr.mlir b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/tf-gradient-attr.mlir index c5aca980abb..4de7cb6ccaf 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/tf-gradient-attr.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/tf-gradient-attr.mlir @@ -8,10 +8,12 @@ func.func @main() { %0:2 = tf_executor.island wraps "tf.Const"() {device = "", dtype = "tfdtype$DT_FLOAT", value = dense<2.500000e-01> : tensor} : () -> tensor loc("Const") // CHECK: node { - // CHECK-NEXT: name: "foo" - // CHECK-NEXT: op: "foo" - // CHECK-NEXT: input: "Const" - %1:2 = tf_executor.island wraps "tf.foo"(%0#0) {device = ""} : (tensor) -> tensor<*xf32> loc("foo") + // CHECK-NEXT: name: "tf.PartitionedCall" + // CHECK-NEXT: op: "PartitionedCall" + // CHECK: func { + // CHECK: name: "foo" + // CHECK: } + %1:2 = tf_executor.island wraps "tf.PartitionedCall"(%0) {Tin = [], Tout = [], config = "", config_proto = "", device = "", executor_type = "", f = @foo, name = "Call_foo"} : (tensor) -> 
tensor<*xf32> tf_executor.fetch } func.return diff --git a/tensorflow/compiler/mlir/tensorflow/tests/resource_op_lifting.mlir b/tensorflow/compiler/mlir/tensorflow/tests/resource_op_lifting.mlir index f3fee9f74d4..02c144467d9 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/resource_op_lifting.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/resource_op_lifting.mlir @@ -53,7 +53,7 @@ func.func @only_resource_store() -> tensor<*xi32> { // ----- -// Tests that a resource ops with both load and store are hoisted. +// Tests that resource ops with both load and store are hoisted. // CHECK-LABEL: func @same_resource_load_and_store func.func @same_resource_load_and_store() -> tensor<*xi32> { @@ -82,7 +82,7 @@ func.func @same_resource_load_and_store() -> tensor<*xi32> { // ----- -// Tests that a resource ops with both load and store are hoisted +// Tests that resource ops with both load and store are hoisted // but input to load and output from store have mixed defined/undefined shapes. // CHECK-LABEL: func @same_resource_load_and_store_cast @@ -114,13 +114,85 @@ func.func @same_resource_load_and_store_cast() -> tensor<1xi32> { // ----- -// Tests that internal resource operations are not hoisted. +// Tests that anonymous internal resource operations are eliminated. 
-// CHECK-LABEL: func @internal_resource -func.func @internal_resource() -> tensor<*xi32> { +// CHECK-LABEL: func @anonymous_internal_resource +func.func @anonymous_internal_resource() -> tensor<*xi32> { + + // CHECK: %[[COMPUTE1_RES:[0-9]*]] = "tf.SomeComputation1"() + %0 = "tf.SomeComputation1"() : () -> (tensor<*xi32>) // CHECK: %[[CLUSTER_RES:[0-9]*]] = "tf_device.cluster" - %0 = "tf_device.cluster"() ({ + // CHECK-NOT: "tf.VarHandleOp" + // CHECK-NOT: "tf.AssignVariableOp" + // CHECK-NOT: "tf.ReadVariableOp" + // CHECK: %[[COMPUTE2_RES:[0-9]*]] = "tf.SomeComputation2"(%[[COMPUTE1_RES]]) + // CHECK-NOT: "tf.AssignVariableOp" + // CHECK-NOT: "tf.ReadVariableOp" + // CHECK: tf_device.return %[[COMPUTE2_RES]] + // CHECK: {cluster_attr = "cluster_attr"} + // CHECK-SAME: () -> tensor<*xi32> + + %1 = "tf_device.cluster"() ( { + %1 = "tf.VarHandleOp"() {shared_name = "cd2c89b7-88b7-44c8-ad83-06c2a9158347"} : () -> tensor<*x!tf_type.resource>> + "tf.AssignVariableOp"(%1, %0) {dtype = i32} : (tensor<*x!tf_type.resource>>, tensor<*xi32>) -> () + %2 = "tf.ReadVariableOp"(%1) {dtype = i32} : (tensor<*x!tf_type.resource>>) -> tensor<*xi32> + %3 = "tf.SomeComputation2"(%2) : (tensor<*xi32>) -> (tensor<*xi32>) + "tf.AssignVariableOp"(%1, %3) {dtype = i32} : (tensor<*x!tf_type.resource>>, tensor<*xi32>) -> () + %4 = "tf.ReadVariableOp"(%1) {dtype = i32} : (tensor<*x!tf_type.resource>>) -> tensor<*xi32> + tf_device.return %4 : tensor<*xi32> + }) {cluster_attr = "cluster_attr"} : () -> tensor<*xi32> + + // CHECK: return %[[CLUSTER_RES]] + return %1 : tensor<*xi32> +} + +// ----- + +// Tests that anonymous internal resource operations (including DestroyResourceOp) are eliminated. 
+ +// CHECK-LABEL: func @anonymous_internal_resource_with_destroy +func.func @anonymous_internal_resource_with_destroy() -> tensor<*xi32> { + + // CHECK: %[[COMPUTE1_RES:[0-9]*]] = "tf.SomeComputation1"() + %0 = "tf.SomeComputation1"() : () -> (tensor<*xi32>) + + // CHECK: %[[CLUSTER_RES:[0-9]*]] = "tf_device.cluster" + // CHECK-NOT: "tf.VarHandleOp" + // CHECK-NOT: "tf.AssignVariableOp" + // CHECK-NOT: "tf.ReadVariableOp" + // CHECK: %[[COMPUTE2_RES:[0-9]*]] = "tf.SomeComputation2"(%[[COMPUTE1_RES]]) + // CHECK-NOT: "tf.AssignVariableOp" + // CHECK-NOT: "tf.ReadVariableOp" + // CHECK-NOT: "tf.DestroyResourceOp" + // CHECK: tf_device.return %[[COMPUTE2_RES]] + // CHECK: {cluster_attr = "cluster_attr"} + // CHECK-SAME: () -> tensor<*xi32> + + %1 = "tf_device.cluster"() ( { + %1 = "tf.VarHandleOp"() {shared_name = "cd2c89b7-88b7-44c8-ad83-06c2a9158347"} : () -> tensor<*x!tf_type.resource>> + "tf.AssignVariableOp"(%1, %0) {dtype = i32} : (tensor<*x!tf_type.resource>>, tensor<*xi32>) -> () + %2 = "tf.ReadVariableOp"(%1) {dtype = i32} : (tensor<*x!tf_type.resource>>) -> tensor<*xi32> + %3 = "tf.SomeComputation2"(%2) : (tensor<*xi32>) -> (tensor<*xi32>) + "tf.AssignVariableOp"(%1, %3) {dtype = i32} : (tensor<*x!tf_type.resource>>, tensor<*xi32>) -> () + %4 = "tf.ReadVariableOp"(%1) {dtype = i32} : (tensor<*x!tf_type.resource>>) -> tensor<*xi32> + "tf.DestroyResourceOp"(%1) {dtype = i32} : (tensor<*x!tf_type.resource>>) -> () + tf_device.return %4 : tensor<*xi32> + }) {cluster_attr = "cluster_attr"} : () -> tensor<*xi32> + + // CHECK: return %[[CLUSTER_RES]] + return %1 : tensor<*xi32> +} + +// ----- + +// Tests that named internal resource operations are not hoisted. 
+ +// CHECK-LABEL: func @named_internal_resource +func.func @named_internal_resource() -> tensor<*xi32> { + + // CHECK: %[[CLUSTER_RES:[0-9]*]] = "tf_device.cluster" + %0 = "tf_device.cluster"() ( { // CHECK: %[[RES_HANDLE:[0-9]*]] = "tf.VarHandleOp" %1 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf_type.resource>> @@ -139,7 +211,7 @@ func.func @internal_resource() -> tensor<*xi32> { }) {cluster_attr = "cluster_attr"} : () -> tensor<*xi32> // CHECK: return %[[CLUSTER_RES]] - func.return %0 : tensor<*xi32> + return %0 : tensor<*xi32> } // ----- diff --git a/tensorflow/compiler/mlir/tensorflow/tests/side-effect-analysis-test.mlir b/tensorflow/compiler/mlir/tensorflow/tests/side-effect-analysis-test.mlir index b489dd04e73..a760759e9ed 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/side-effect-analysis-test.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/side-effect-analysis-test.mlir @@ -2628,3 +2628,39 @@ func.func @fetch_with_resource_operand( // expected-remark@above {{ID: 8}} // expected-remark@above {{Sinks: {7}}} } + +// ----- + +// Tests about we create the dependency PwStreamResults `start` and `end` +// within in XlaCallModule +func.func @_pws_program(%arg0: tensor {tf_saved_model.index_path = ["arg0"]}) -> (tensor {tf_saved_model.index_path = ["result0"]}, tensor {tf_saved_model.index_path = ["result1"]}) attributes {pws.program_id = 4722582128360897113 : i64, tf.entry_function = {}} { + // expected-remark@above {{ID: 7}} + "tf.PwStreamResults"(%arg0) {_callback_id = -2694175233261920887 : i64, _controller_address = "[2002:afb:afb::]:10004", _has_manual_control_dependencies = true, _model_name = "test", device = "/device/CPU", names = ["begin"]} : (tensor) -> () + // expected-remark@above {{ID: 0}} + // expected-remark@above {{Successors: {4}}} + %0:2 = "tf_device.cluster"() ({ + // expected-remark@above {{Predecessors: {0}}} + // expected-remark@above {{ID: 4}} + // expected-remark@above {{Successors: {5}}} + %1 
= "tf.XlaSharding"(%arg0) {_XlaSharding = "", sharding = "", unspecified_dims = []} : (tensor) -> tensor + // expected-remark@above {{ID: 1}} + %2:2 = "tf.XlaCallModule"(%1) {Sout = [#tf_type.shape<>, #tf_type.shape<>], dim_args_spec = [], function_list = [@__inference_callable_flat_tf_150], module = "ML\EFR\00__inference_callable_flat_tf_15\00", platforms = [], version = 5 : i64} : (tensor) -> (tensor, tensor) + // expected-remark@above {{ID: 2}} + tf_device.return %2#0, %2#1 : tensor, tensor + // expected-remark@above {{ID: 3}} + }) {_tpu_replicate = "cluster_0", allow_soft_placement = false, computation_shape = [], device_assignment = [], host_compute_core = [], num_cores_per_replica = 1 : i64, padding_map = [], step_marker_location = "STEP_MARK_AT_ENTRY", topology = "", tpu_compile_options_proto = "", use_spmd_for_xla_partitioning = false, use_tpu = true} : () -> (tensor, tensor) + "tf.PwStreamResults"(%arg0) {_callback_id = -2694175233261920887 : i64, _controller_address = "[2002:afb:afb::]:10004", _model_name = "test", device = "/device/CPU", names = ["end"]} : (tensor) -> () + // expected-remark@above {{Predecessors: {4}}} + // expected-remark@above {{ID: 5}} + return %0#0, %0#1 : tensor, tensor + // expected-remark@above {{ID: 6}} + // expected-remark@above {{Sinks: {5}}} +} +// expected-remark@below {{ID: 2}} +func.func private @__inference_callable_flat_tf_150(%arg0: tensor {tf._user_specified_name = "args_tf_flat_0"}, %arg1: tensor {tf._user_specified_name = "args_tf_flat_1"}) attributes {tf._XlaMustCompile = false, tf._construction_context = "kEagerRuntime", tf._input_shapes = [#tf_type.shape<>, #tf_type.shape<>], tf._original_func_name = "__inference_callable_flat_tf_15", tf.signature.is_stateful} { + "tf.PwStreamResults"(%arg0, %arg1) {_callback_id = -2694175233261920887 : i64, _controller_address = "[2002:afb:afb::]:10004", _model_name = "test", names = ["foo", "bar"]} : (tensor, tensor) -> () + // expected-remark@above {{ID: 0}} + return + // 
expected-remark@above {{ID: 1}} + // expected-remark@above {{Sinks: {0}}} +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir index 74363ecc967..c17ef278ba1 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir @@ -5134,3 +5134,11 @@ func.func @test_batch_function_with_invalid_symbol(%arg0: tensor<1x3xf32>, %arg1 "tf.BatchFunction"(%arg0, %arg1) {batch_timeout_micros = 100000 : i64, f = @undefined_function, max_batch_size = 6 : i64, max_enqueued_batches = 10 : i64, num_batch_threads = 1 : i64, operand_segment_sizes = array} : (tensor<1x3xf32>, tensor>>) -> tensor<*xf32> func.return } + +// ----- + +func.func @test_xla_call_module_with_invalid_symbol() { + // expected-error @below {{refers to an undefined function: @undefined_function}} + "tf.XlaCallModule"() {Sout = [], device = "", dim_args_spec = [], function_list = [@undefined_function], module = "", platforms = [], version = 4 : i64} : () -> () + func.return +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/BUILD b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/BUILD index 88c31e5057f..cee4ca7f782 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/BUILD @@ -62,6 +62,7 @@ test_files = glob( ] glob_lit_tests( + name = "all_tests", data = [":test_utilities"], default_size = "medium", default_tags = [ diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_asset_sinking.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_asset_sinking.mlir new file mode 100644 index 00000000000..2638aab86b8 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_asset_sinking.mlir @@ -0,0 +1,30 @@ +// RUN: tf-opt %s -split-input-file -tf-saved-model-asset-sinking='saved-model-dir=foo/bar' | FileCheck %s + +// CHECK-LABEL: 
module @asset +module @asset attributes {tf_saved_model.semantics} { + "tf_saved_model.session_initializer"() {initializers = [@init]} : () -> () + + // CHECK-NOT: "tf_saved_model.asset" + "tf_saved_model.asset"() {filename = "assets/test0.txt", sym_name = "asset0"} : () -> () + "tf_saved_model.asset"() {filename = "assets/test1.txt", sym_name = "asset1"} : () -> () + + // CHECK: func @init() + func.func @init(%arg0: tensor {tf_saved_model.bound_input = @asset0}, %arg1: tensor {tf_saved_model.bound_input = @asset1}) attributes {tf_saved_model.exported_names = ["init"]} { + // CHECK-DAG: %[[ASSET0:.*]] = "tf.Const"() {value = dense<"foo/bar/assets/test0.txt"> : tensor} + // CHECK-DAG: %[[ASSET1:.*]] = "tf.Const"() {value = dense<"foo/bar/assets/test1.txt"> : tensor} + + // CHECK: %[[VAR0:.*]] = "tf.VarHandleOp"() + %0 = "tf.VarHandleOp"() {container = "", shared_name = "var0"} : () -> tensor>> + + // CHECK: "tf.AssignVariableOp"(%[[VAR0]], %[[ASSET0]]) + "tf.AssignVariableOp"(%0, %arg0) : (tensor>>, tensor) -> () + + // CHECK: %[[VAR1:.*]] = "tf.VarHandleOp"() + %1 = "tf.VarHandleOp"() {container = "", shared_name = "var1"} : () -> tensor>> + + // CHECK: "tf.AssignVariableOp"(%[[VAR1]], %[[ASSET1]]) + "tf.AssignVariableOp"(%1, %arg1) : (tensor>>, tensor) -> () + + func.return + } +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_to_hlo_pipeline/BUILD b/tensorflow/compiler/mlir/tensorflow/tests/tf_to_hlo_pipeline/BUILD index 954eca9c0e2..421bbd5de79 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf_to_hlo_pipeline/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_to_hlo_pipeline/BUILD @@ -6,6 +6,7 @@ load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") licenses(["notice"]) glob_lit_tests( + name = "all_tests", data = [":test_utilities"], driver = "@llvm-project//mlir:run_lit.sh", test_file_exts = [ diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu_bridge_v1/BUILD 
b/tensorflow/compiler/mlir/tensorflow/tests/tpu_bridge_v1/BUILD index 954eca9c0e2..421bbd5de79 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tpu_bridge_v1/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu_bridge_v1/BUILD @@ -6,6 +6,7 @@ load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") licenses(["notice"]) glob_lit_tests( + name = "all_tests", data = [":test_utilities"], driver = "@llvm-project//mlir:run_lit.sh", test_file_exts = [ diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu_cluster_formation.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu_cluster_formation.mlir index db266ed4afe..8852458137f 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tpu_cluster_formation.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu_cluster_formation.mlir @@ -838,7 +838,7 @@ func.func @valid_compilation_cluster_no_replication_op_device() { // CHECK-NOT: device = // CHECK: return func.func @valid_compilation_cluster_no_replication_op_device() { - "tf.opA"() { _xla_compile_device_type = "TPU", device = "/device:CPU:0"} : () -> () + "tf.opA"() { _xla_compile_device_type = "TPU", device = "/device:TPU:0"} : () -> () "tf.opB"() { _xla_compile_device_type = "TPU", device = "/task:0/device:TPU:1"} : () -> () func.return } diff --git a/tensorflow/compiler/mlir/tensorflow/tests/xla_call_module_deserialization.mlir b/tensorflow/compiler/mlir/tensorflow/tests/xla_call_module_deserialization.mlir new file mode 100644 index 00000000000..be47ea6ff2c --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/xla_call_module_deserialization.mlir @@ -0,0 +1,40 @@ +// RUN: tf-opt %s -split-input-file -tf-xla-call-module-deserialization | FileCheck %s + +// Tests that `tf.XlaCallModule` with both StableHLO module and TF function +// calls can be deserialized. 
+ +// CHECK-LABEL: module +module { + // CHECK-LABEL: func private @_tf_func + func.func private @_tf_func(%arg0: tensor, %arg1: tensor<*xi32>) { + // CHECK: tf.StreamResults + + // StreamResults is a pseudo op in this test. + "tf.StreamResults"(%arg0, %arg1) : (tensor, tensor<*xi32>) -> () + func.return + } + + // CHECK-LABEL: func @main + // CHECK-SAME: %[[ARG0:.*]]: tensor<10xi32>, %[[ARG1:.*]]: tensor<10xi32> + func.func @main(%arg0: tensor<10xi32>, %arg1: tensor<10xi32>) -> tensor<10xi32> { + // CHECK: %[[RESULT:.*]] = "tf.XlaCallModule"(%[[ARG0]], %[[ARG1]]) + // CHECK-SAME: _entry_function = @_stablehlo_main_0 + // CHECK-NOT: function_list + // CHECK-SAME: module = "" + + // `module` is stablehlo bytecode for: + // func.func @main(%arg0: tensor {jax.arg_info = "x", mhlo.sharding = "{replicated}"}, %arg1: tensor<*xi32>) -> (tensor {jax.result_info = ""}) { + // stablehlo.custom_call @tf.call_tf_function(%arg0, %arg1) {api_version = 2 : i32, has_side_effect = true, tf.backend_config = {called_func = @_tf_func}} : (tensor, tensor<*xi32>) -> () + // return %arg0 : tensor + // } + %0 = "tf.XlaCallModule"(%arg0, %arg1) {Sout = [#tf_type.shape], dim_args_spec = [], function_list = [@_tf_func], module = 
"ML\EFR\07StableHLO_v0.12.0\00\01\19\05\01\05\01\03\05\03\09\07\09\0B\0D\03\8Fm\0F\01?\0B\07\0B\0B\0B\0B\0B\13\0B\0F\133\133\13\13S\0B\0B\0B\0B\0B\0B\0B\0B\0B\13\13\0B\13\13\03/\0B\0B\0B\13\1B\0B\0B\0B\0B\0B\0B\0F\13\0B\0B\0B\0B\0B\0B\0B\13\0B\0F\01\03\0F\03\0D3\07\0B\1B\17\07\02:\03\05\0F\1F\05\11\05\13\05\15\05\17\05\19\03\03\11\13\05\1B\11\01\05\17\01S\15\03\0B\05E\07S\09U\0B[\0DA\17\011\07\03\0B\05?\07]\09?\0BC\0D_\17\01'\07\17\01)\0B\03\13#a%A'c)?+e-?/?1?3g\05\1D\05\1F\05!\05#\05%\05'\05)\05+\05-\17\013\0B\03\039C\05/\17\015\1B\17\017\0B\03\01\1D1\1D3\03\05GQ\0D\05IKMO\1D5\1D7\1D9\1D;\0D\01#\09\03\03W\0D\03YA\1D=\1D?#\0B\1DA\0B\05\1DC\05\03\0D\03ik\1DE\13\0D\01\01\02\04)\03\00\FF\FF\FF\FF\FF\FF\FF\FF\05\1B3\05\11\05\03\07\03\03\11\03\03\03\03\1D\04}\05\01\11\15\0F\07\04m\03\01\09\03\11\19\17\05\03\07\0F\05\03\03\07\03\00\07\055!\05\01\03\09\07;7\03\03\03\01\05\04=\03\05\03\11\1D\1B\05\03\03\07\03\03\03\00\05\04\1F\03\01\06\03\01\05\01\00\9E\07G\1B)\11\0B!\1B\1D\05\1B\1B\03\0F%\1F/!!)#\1F\19)\1F\13\15\1D\15G\11\1F\15\11\0F\0B\11builtin\00vhlo\00module\00func_v1\00return_v1\00custom_call_v1\00call_v1\00xla_call_module_serialization.mlir\00arg_attrs\00function_type\00res_attrs\00sym_name\00sym_visibility\00mhlo.num_partitions\00api_version\00backend_config\00call_target_name\00called_computations\00has_side_effect\00operand_layouts\00output_operand_aliases\00result_layouts\00tf.backend_config\00callee\00\00_stablehlo_f\00jax.arg_info\00x\00mhlo.sharding\00{replicated}\00jax.result_info\00main\00private\00tf.call_tf_function\00called_index\00", platforms = [], version = 5 : i64} : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi32> + // CHECK: return %[[RESULT]] + func.return %0 : tensor<10xi32> + } + + // CHECK-LABEL: func private @_stablehlo_main_0 + // CHECK-SAME: (%[[ARG0:.*]]: tensor {jax.arg_info = "x", mhlo.sharding = "{replicated}"}, %[[ARG1:.*]]: tensor<*xi32>) -> (tensor {jax.result_info = ""}) attributes {_from_xla_call_module} { + // CHECK: 
stablehlo.custom_call @tf.call_tf_function(%[[ARG0]], %[[ARG1]]) {api_version = 2 : i32, has_side_effect = true, tf.backend_config = {called_func = @_tf_func}} : (tensor, tensor<*xi32>) -> () + // CHECK: return %arg0 : tensor + // CHECK: } +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/xla_call_module_round_trip.mlir b/tensorflow/compiler/mlir/tensorflow/tests/xla_call_module_round_trip.mlir new file mode 100644 index 00000000000..446a61cabc8 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/xla_call_module_round_trip.mlir @@ -0,0 +1,60 @@ +// RUN: tf-opt %s -split-input-file -tf-xla-call-module-serialization -tf-xla-call-module-deserialization | FileCheck %s + +// Tests that running xla-call-module-serialization followed by +// xla-call-module-deserialization preserves the original module. +// +// Note that function names may be different, but arguments, attributes, +// results, and function body should be the same. + +// CHECK-LABEL: module +module { + // CHECK-LABEL: func @main + // CHECK-SAME: %[[ARG0:.*]]: tensor<10xi32>, %[[ARG1:.*]]: tensor<10xi32> + func.func @main(%arg0: tensor<10xi32>, %arg1: tensor<10xi32>) -> tensor<10xi32> { + // CHECK: %[[RESULT:.*]] = "tf.XlaCallModule"(%[[ARG0]], %[[ARG1]]) + // CHECK-SAME: Sout = [#tf_type.shape] + // CHECK-SAME: _entry_function = @_stablehlo_main_0 + // CHECK-SAME: _stablehlo_module_attrs = {} + // CHECK-NOT: function_list + // CHECK-SAME: module = "" + // CHECK-SAME: platforms = [] + // CHECK-SAME: version = 5 + + %0 = "tf.XlaCallModule"(%arg0, %arg1) {Sout = [#tf_type.shape], dim_args_spec = [], _entry_function = @_stablehlo_main_0, module = "", platforms = [], version = 5 : i64} : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi32> + // CHECK: return %[[RESULT]] + func.return %0 : tensor<10xi32> + } + + // CHECK-LABEL: func private @_tf_func + func.func private @_tf_func(%arg0: tensor, %arg1: tensor<*xi32>) { + // CHECK: tf.StreamResults + + // StreamResults is a pseudo op in this test. 
+ "tf.StreamResults"(%arg0, %arg1) : (tensor, tensor<*xi32>) -> () + func.return + } + + // CHECK-LABEL: func private @_stablehlo_main_0 + // CHECK-SAME: %[[ARG0:.*]]: tensor {jax.arg_info = "x", mhlo.sharding = "{replicated}"} + // CHECK-SAME: %[[ARG1:.*]]: tensor<*xi32>) + // CHECK-SAME: (tensor {jax.result_info = ""}) + // CHECK-SAME: attributes {_from_xla_call_module} + func.func private @_stablehlo_main_0(%arg0: tensor {jax.arg_info = "x", mhlo.sharding = "{replicated}"}, %arg1: tensor<*xi32>) -> (tensor {jax.result_info = ""}) attributes {_from_xla_call_module} { + // CHECK: stablehlo.custom_call @tf.call_tf_function(%[[ARG0]], %[[ARG1]]) + // CHECK-SAME: { + // CHECK-SAME: api_version = 2 : i32, + // CHECK-SAME: has_side_effect = true, + // CHECK-SAME: tf.backend_config = {called_func = @_tf_func} + // CHECK-SAME: } + stablehlo.custom_call @tf.call_tf_function(%arg0, %arg1) {api_version = 2 : i32, has_side_effect = true, tf.backend_config = {called_func = @_tf_func}} : (tensor, tensor<*xi32>) -> () + // CHECK: call @_stablehlo__stablehlo_f_0 + %arg2 = func.call @_stablehlo_f(%arg0) : (tensor) -> (tensor) + return %arg2 : tensor + } + + // CHECK-LABEL: func private @_stablehlo__stablehlo_f_0 + // CHECK: attributes {_from_xla_call_module} + func.func private @_stablehlo_f(%arg0: tensor) -> (tensor) attributes {_from_xla_call_module} { + return %arg0 : tensor + } +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/xla_call_module_serialization.mlir b/tensorflow/compiler/mlir/tensorflow/tests/xla_call_module_serialization.mlir new file mode 100644 index 00000000000..e51433e38bc --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/xla_call_module_serialization.mlir @@ -0,0 +1,45 @@ +// RUN: tf-opt %s -split-input-file -tf-xla-call-module-serialization | FileCheck %s + +// Tests that stablehlo functions called by XlaCallModuleOp in the top-level +// module can be serialized into bytecode and embedded in XlaCallModuleOp's +// `module` attribute. 
+ +// CHECK-LABEL: module +module { + // CHECK-LABEL: func private @_tf_func + func.func private @_tf_func(%arg0: tensor, %arg1: tensor<*xi32>) { + // CHECK: tf.StreamResults + + // StreamResults is a pseudo op in this test. + "tf.StreamResults"(%arg0, %arg1) : (tensor, tensor<*xi32>) -> () + func.return + } + + // CHECK-NOT: @_stablehlo_f + func.func private @_stablehlo_f(%arg0: tensor) -> (tensor) attributes {_from_xla_call_module} { + return %arg0 : tensor + } + + // CHECK-NOT: @_stablehlo_main_0 + func.func private @_stablehlo_main_0(%arg0: tensor {jax.arg_info = "x", mhlo.sharding = "{replicated}"}, %arg1: tensor<*xi32>) -> (tensor {jax.result_info = ""}) attributes {_from_xla_call_module} { + stablehlo.custom_call @tf.call_tf_function(%arg0, %arg1) {api_version = 2 : i32, has_side_effect = true, tf.backend_config = {called_func = @_tf_func}} : (tensor, tensor<*xi32>) -> () + %arg2 = func.call @_stablehlo_f(%arg0) : (tensor) -> (tensor) + return %arg2 : tensor + } + + // CHECK-LABEL: func @main + // CHECK-SAME: %[[ARG0:.*]]: tensor<10xi32>, %[[ARG1:.*]]: tensor<10xi32> + func.func @main(%arg0: tensor<10xi32>, %arg1: tensor<10xi32>) -> tensor<10xi32> { + // CHECK: %[[RESULT:.*]] = "tf.XlaCallModule"(%[[ARG0]], %[[ARG1]]) + // CHECK-SAME: Sout = [#tf_type.shape] + // CHECK-SAME: dim_args_spec = [] + // CHECK-NOT: _entry_function + // CHECK-NOT: _stablehlo_module_attrs + // CHECK-SAME: function_list = [@_tf_func] + // CHECK-SAME: module = "ML\EFR{{.*}}" + + %0 = "tf.XlaCallModule"(%arg0, %arg1) {Sout = [#tf_type.shape], dim_args_spec = [], _entry_function = @_stablehlo_main_0, _stablehlo_module_attrs = { mhlo.num_partitions = 1 }, module = "", platforms = [], version = 5 : i64} : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi32> + // CHECK: return %[[RESULT]] + func.return %0 : tensor<10xi32> + } +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/xla_validate_iputs.mlir b/tensorflow/compiler/mlir/tensorflow/tests/xla_validate_iputs.mlir new file mode 
100644 index 00000000000..f7166ae11f4 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/xla_validate_iputs.mlir @@ -0,0 +1,11 @@ +// RUN: tf-opt %s -split-input-file -verify-diagnostics -tf-xla-validate-inputs + +// expected-error @+1 {{CPU/GPU MLIR phase 1 pipeline does not support nested calls of entry functions}} +func.func @nested_entry_functions(%arg0: tensor) -> tensor attributes {tf.entry_function = {}} { + %0 = "tf.StatefulPartitionedCall"(%arg0) {config = "", config_proto = "", device = "/device:CPU:0", executor_type = "", f = @func} : (tensor) -> (tensor) + func.return %0 : tensor +} + +func.func @func(%arg0: tensor) -> tensor attributes {tf.entry_function = {}} { + func.return %arg0 : tensor +} diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc b/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc index 74f5a458b0b..fcc70ab1b39 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include +#include #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project @@ -25,7 +26,6 @@ limitations under the License. 
#include "tensorflow/compiler/jit/flags.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" -#include "tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.h" #include "tensorflow/compiler/mlir/tensorflow/utils/data_dumper_logger_config.h" #include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h" #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" @@ -148,8 +148,18 @@ void CreateTPUBridgePipelineImpl( pm.addNestedPass( CreateTPUReorderReplicateAndPartitionedInputsPass()); pm.addNestedPass(TF::CreateDecomposeReduceDatasetPass()); - pm.addPass(TFDevice::CreateEmbeddingPipeliningPass()); + if (tensorflow::GetBuildXlaOpsPassFlags() + ->tf_xla_disable_full_embedding_pipelining) { + pm.addPass(TFDevice::CreateEmbeddingSequencingPass()); + } else { + pm.addPass(TFDevice::CreateEmbeddingPipeliningPass()); + } pm.addPass(CreateTPUClusterFormationPass()); + // CreateEmbeddingPipeliningPass may have created more functions, but + // TPUClusterCleanup and OutsideCompiledToHostLaunch need every function to be + // only called from one cluster. Here, we choose to fix the all-funcs-one-use + // invariant right before it's needed, not after it's been broken. + pm.addPass(mlir::TF::CreateGuaranteeAllFuncsOneUsePass()); // Run TPU cluster cleanup attributes so ops with no outside compiled // attribute have no host device attribute. pm.addPass(CreateTPUClusterCleanupAttributesPass()); @@ -404,6 +414,8 @@ void CreateTFXLABridgePipeline(OpPassManager &pm) { VLOG(2) << "Create TF XLA Bridge pipeline"; pm.addNestedPass( TF::CreateCanonicalizeCompileAndReplicateAttributesPass()); + // This pass expectes unified compilation markers. 
+ pm.addPass(TFDevice::CreateXlaValidateInputsPass()); const llvm::SmallVector ops_to_preserve = {}; pm.addNestedPass( tf_executor::CreateTFExecutorGraphPruningPass(ops_to_preserve)); @@ -425,6 +437,12 @@ void CreateTFXLABridgePipeline(OpPassManager &pm) { pm.addNestedPass(createCanonicalizerPass()); // Decompose resource ops. pm.addPass(TFDevice::CreateDecomposeResourceOpsInClusterPass()); + // TODO(b/267193636): Remove this flag when outside compilation + // for generic pipeline is landed. + if (tensorflow::GetMlirCommonFlags() + ->tf_mlir_enable_generic_outside_compilation) { + pm.addPass(TF::CreateTFFunctionalControlFlowToRegions()); + } // Run another shape inference pass because resource decomposition might have // created new partial types. Also, after dropping `shape_invariant` attribute // from While/WhileRegion ops within cluster would lead to more precise diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/call_graph_util.h b/tensorflow/compiler/mlir/tensorflow/transforms/call_graph_util.h deleted file mode 100644 index 6d27780316f..00000000000 --- a/tensorflow/compiler/mlir/tensorflow/transforms/call_graph_util.h +++ /dev/null @@ -1,64 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_CALL_GRAPH_UTIL_H_ -#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_CALL_GRAPH_UTIL_H_ - -#include -#include - -#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/Support/LogicalResult.h" // from @llvm-project - -namespace mlir { - -// Find the outermost ops with any of specified types starting from the tree -// rooted at `root` parameter. The results are stored in `ops`. Addtional -// filters can be specified by providing `predicate` parameter. -template -LogicalResult GetOutermostOpsOfType( - func::FuncOp root, SymbolTable &symtab, llvm::SmallVector &ops, - const std::function &predicate = {}) { - std::stack worklist; - worklist.push(root); - while (!worklist.empty()) { - func::FuncOp u = worklist.top(); - worklist.pop(); - auto result = u.walk([&](SymbolUserOpInterface op) { - if (llvm::isa(op) && (!predicate || predicate(op))) { - ops.push_back(op); - return WalkResult::advance(); - } - for (auto attr : op->getAttrs()) { - auto sym = attr.getValue().dyn_cast(); - if (!sym) continue; - auto v = symtab.lookup(sym.getRootReference()); - if (!v) { - // This is not expected to happen in practice. 
- op->emitError() << "Cannot find function " << sym.getRootReference(); - return WalkResult::interrupt(); - } - worklist.push(v); - } - return WalkResult::advance(); - }); - if (result.wasInterrupted()) return failure(); - } - return success(); -} - -} // namespace mlir - -#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_CALL_GRAPH_UTIL_H_ diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.cc b/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.cc index 481f2d868e1..84220aa346b 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.cc @@ -16,39 +16,18 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.h" #include -#include -#include -#include -#include #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/OpDefinition.h" // from @llvm-project -#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" -#include "tensorflow/compiler/mlir/tensorflow/ir/tf_traits.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" -#include "tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h" -#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h" +#include "tensorflow/compiler/mlir/tensorflow/transforms/constant_fold_utils.h" #include "tensorflow/core/platform/mutex.h" -#include "tensorflow/core/tfrt/fallback/fallback_state.h" -#include "tensorflow/core/tfrt/fallback/op_kernel_runner.h" -#include "tensorflow/tsl/util/device_name_utils.h" namespace mlir { namespace TF { -static bool IsOk(const tensorflow::Status& s) { - if (s.ok()) return true; - VLOG(2) << s.message(); - return false; -} - -#define RETURN_FAILURE_IF_ERROR(expr) \ - if (!IsOk(expr)) { \ - return mlir::failure(); \ - } - // Implements a 
TF specific policy on when constant folding is allowed. // Policy: // @@ -63,7 +42,7 @@ static bool IsOk(const tensorflow::Status& s) { // (`kResultsSizeThreshold`), or // 2. size of results is within a factor (`kSizeFactor`) of size of operands, or // TODO(b/157226221): Look into other heuristics for constant fold policy. -static bool ShouldBeFolded(Operation* inst) { +static bool IsFoldedByDefaultPolicy(Operation* inst) { bool has_unknown_shape = false; auto get_size = [&](TypeRange types) { int64_t size = 0; @@ -98,142 +77,14 @@ static bool ShouldBeFolded(Operation* inst) { (results_size <= kSizeFactor * operands_size)); } -static const tensorflow::tfrt_stub::FallbackState& GetDefaultFallbackState() { - static const auto* const fallback_state = []() { - tensorflow::SessionOptions session_options; - tensorflow::FunctionDefLibrary fdef_lib; - auto fallback_state = - tensorflow::tfrt_stub::FallbackState::CreateWithCpuDevice( - session_options, fdef_lib) - .value(); - return fallback_state.release(); - }(); - - return *fallback_state; -} - -static std::function)>* GetDefaultRunner() { - static auto* const default_runner = - new std::function)>( - [](const std::function& f) { f(); }); - return default_runner; -} - -static mlir::LogicalResult EvaluateOperation( - mlir::Operation* inst, llvm::ArrayRef operands, - llvm::SmallVectorImpl* results) { - // If any operand is nullptr returns true for a failure. - // TODO(b/120678030): remove this constraint if we find operators can be - // evaluated with some unknown operands. - if (std::any_of(operands.begin(), operands.end(), - [](mlir::Attribute operand) { return !operand; })) { - VLOG(1) << "Can't evaluate since not all operands are constant."; - return mlir::failure(); - } - - // Builds TF operation and sets all the attributes. 
- std::string node_name = "unnamed"; - if (auto attr = inst->getAttrOfType("name")) { - node_name = std::string(attr.getValue()); - } - auto node_def_or = tensorflow::ConvertTFDialectOpToNodeDef( - inst, node_name.c_str(), /*ignore_unregistered_attrs=*/true); - RETURN_FAILURE_IF_ERROR(node_def_or.status()); - const auto& node_def = node_def_or.value(); - - const auto& fallback_state = GetDefaultFallbackState(); - - // Explicitly set device to Host CPU instead of the device present in device - // attribute of the MLIR op. The assigned device might be remote, not - // available during compilation or compilation only device for on demand - // execution which may create a recursion if used for constant folding. - auto host_cpu = tensorflow::DeviceNameUtils::FullName( - /*job=*/"localhost", /*replica=*/0, /*task=*/0, /*type=*/"CPU", /*id=*/0); - - auto statusor_runner = tensorflow::tfrt_stub::OpKernelRunner::Create( - node_def->op(), node_def->name(), host_cpu, operands.size(), - [&](tensorflow::AttrValueMap* attr_value_map) { - *attr_value_map = node_def->attr(); - return tensorflow::OkStatus(); - }, - fallback_state.device_manager(), - fallback_state.process_function_library_runtime()); - RETURN_FAILURE_IF_ERROR(statusor_runner.status()); - const auto& runner = *statusor_runner; - - VLOG(1) << "Start to evaluate node: " << node_def->DebugString(); - - std::vector inputs; - - // Adds inputs to the TF operation. - for (const auto operand : operands) { - tensorflow::Tensor tensor; - RETURN_FAILURE_IF_ERROR(tensorflow::ConvertToTensor(operand, &tensor)); - inputs.push_back(std::move(tensor)); - } - - std::vector input_values; - for (auto& tensor : inputs) { - input_values.emplace_back(); - input_values.back().tensor = &tensor; - } - - tensorflow::OpKernelContext::Params params; - params.inputs = input_values; - params.device = runner.device(); - params.op_kernel = runner.op_kernel(); - // Still use original device's resource_manager. 
- params.resource_manager = runner.resource_manager(); - params.input_alloc_attrs = runner.input_alloc_attrs(); - params.output_attr_array = runner.output_alloc_attrs().data(); - // Following two parameters are used to support executing tf.data via - // fallback. - params.function_library = runner.function_library_runtime(); - params.runner = GetDefaultRunner(); - - // Executes the TF operation. - tensorflow::OpKernelContext op_kernel_context(¶ms); - runner.Run(&op_kernel_context); - RETURN_FAILURE_IF_ERROR(op_kernel_context.status()); - - // Converts the outputs to MLIR attributes. - mlir::Builder builder(inst->getContext()); - - for (int i = 0; i < op_kernel_context.num_outputs(); ++i) { - DCHECK(op_kernel_context.mutable_output(i)); - auto attr_or = tensorflow::ConvertTensor( - *op_kernel_context.mutable_output(i), &builder); - RETURN_FAILURE_IF_ERROR(attr_or.status()); - results->push_back(attr_or.value()); - } - - VLOG(1) << "Evaluate node " << node_name << " successfully!"; - - return mlir::success(); -} - LogicalResult ConstantFoldFallbackHook( Operation* inst, ArrayRef operands, SmallVectorImpl& results) { // NOLINT - // Instructions with side effects should not be constant folded to preserve - // the original semantics. Ops that have no side effect and zero results but - // could be folded should have a custom folder instead of relying on the - // TensorFlow folding hook. - if (inst->getNumResults() == 0 || - inst->hasTrait() || - inst->getNumRegions() != 0 || !isMemoryEffectFree(inst)) - return failure(); + if (!CanBeFolded(inst)) return failure(); - // If any of the result types are variants, don't try to constant fold them. - // This creates opaque variant constants which lose information and would - // require "raising" later. 
- for (auto type : inst->getResultTypes()) { - if (auto tensor_type = type.dyn_cast()) { - if (tensor_type.getElementType().isa()) { - return failure(); - } - } - } + // Determine if we should attempt to fold this operation by considering the + // size/size increase due to folding. + if (!IsFoldedByDefaultPolicy(inst)) return failure(); // If all the results are empty and has numerical element types, set results // to empty elements attribute. This is restricted to the numerical element @@ -259,15 +110,6 @@ LogicalResult ConstantFoldFallbackHook( return success(); } - // Do not execute function calls. - if (llvm::isa(inst)) { - return failure(); - } - - // Determine if we should attempt to fold this operation by considering the - // size/size increase due to folding. - if (!ShouldBeFolded(inst)) return failure(); - // Returns directly if any of the operands is not an elements attributes. if (std::any_of(operands.begin(), operands.end(), [](Attribute attr) { return !attr || !attr.isa(); @@ -284,8 +126,8 @@ LogicalResult ConstantFoldFallbackHook( // TODO(jpienaar): Avoid using global context & mutex here. static auto* mu = new tensorflow::mutex(); tensorflow::mutex_lock l(*mu); - SmallVector constants; - LogicalResult status = EvaluateOperation(inst, inputs, &constants); + SmallVector constants; + LogicalResult status = EvaluateOperation(inst, inputs, constants); results.assign(constants.begin(), constants.end()); return status; } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold_utils.cc b/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold_utils.cc index fbd5541c137..6d28fa03a98 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold_utils.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold_utils.cc @@ -15,77 +15,182 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/tensorflow/transforms/constant_fold_utils.h" +#include +#include #include +#include +#include +#include -#include "tensorflow/c/tf_status.h" -#include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/protobuf/config.pb.h" -#include "tensorflow/tsl/platform/mem.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_traits.h" +#include "tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h" +#include "tensorflow/core/tfrt/fallback/fallback_state.h" +#include "tensorflow/core/tfrt/fallback/op_kernel_runner.h" namespace mlir { namespace TF { -TFE_Context* GetContextForConstantFold() { - std::unique_ptr status( - TF_NewStatus(), TF_DeleteStatus); - std::unique_ptr opts( - TFE_NewContextOptions(), TFE_DeleteContextOptions); - // Only initialize single CPU. - tensorflow::ConfigProto config_proto; - // This is conceptually equal to what we do in python/eager/context.py but - // with all GPU/TPU devices ignored and CPU only set to 1. - (*config_proto.mutable_device_count())["CPU"] = 1; - config_proto.add_device_filters("/device:CPU:*"); - // Limit the thread pool size. Without this, TF by default creates as many - // threads as the number of CPUs (`port::MaxParallelism()`). This can be - // expensive since this TFE context persists the entire program execution. - config_proto.set_inter_op_parallelism_threads(2); - std::unique_ptr config( - TF_NewBuffer(), TF_DeleteBuffer); - DCHECK(config->data == nullptr); +using tensorflow::tfrt_stub::FallbackState; +using tensorflow::tfrt_stub::OpKernelRunner; - // Copy config_proto into config. 
- { - const size_t proto_size = config_proto.ByteSizeLong(); - void* buf = tsl::port::Malloc(proto_size); - if (buf == nullptr) { - LOG(ERROR) << "Failed to allocate memory to serialize ConfigProto " - "while creating context options for constant folding"; - return nullptr; +static bool IsOk(const tensorflow::Status& s) { + if (s.ok()) return true; + VLOG(2) << s.message(); + return false; +} + +#define RETURN_FAILURE_IF_ERROR(expr) \ + if (!IsOk(expr)) { \ + return mlir::failure(); \ + } + +bool CanBeFolded(Operation* inst) { + // Instructions with side effects should not be constant folded to preserve + // the original semantics. Ops that have no side effect and zero results but + // could be folded should have a custom folder instead of relying on the + // TensorFlow folding hook. + if (inst == nullptr || inst->getNumResults() == 0 || + inst->hasTrait() || + inst->getNumRegions() != 0 || !isMemoryEffectFree(inst)) { + return false; + } + + // If any of the result types are variants, don't try to constant fold them. + // This creates opaque variant constants which lose information and would + // require "raising" later. 
+ for (const Type type : inst->getResultTypes()) { + if (const TensorType tensor_type = type.dyn_cast()) { + if (tensor_type.getElementType().isa()) { + return false; + } } - if (!config_proto.SerializeWithCachedSizesToArray( - static_cast(buf))) { - tsl::port::Free(buf); - LOG(ERROR) << "Unable to serialize ConfigProto while creating context " - "options for constant folding"; - return nullptr; - } - config->data = buf; - config->length = proto_size; - config->data_deallocator = [](void* data, size_t length) { - tsl::port::Free(data); - }; } - TFE_ContextOptionsSetConfig(opts.get(), config->data, config->length, - status.get()); - if (TF_GetCode(status.get()) != TF_OK) { - LOG(ERROR) << "Failed to set context options for constant folding: " - << status.get(); - return nullptr; + // Operations that execute function calls shouldn't be constant folded. + if (llvm::isa(inst)) { + return false; } - // Input tensors are placed on the host CPU so use the explicit device - // policy to fail if no CPU kernels are available for the op. 
- TFE_ContextOptionsSetDevicePlacementPolicy(opts.get(), - TFE_DEVICE_PLACEMENT_EXPLICIT); - auto ctx = TFE_NewContext(opts.get(), status.get()); - if (TF_GetCode(status.get()) != TF_OK) { - LOG(ERROR) << "Failed to create context for constant folding: " - << status.get(); - return nullptr; + return true; +} + +static const FallbackState& GetDefaultFallbackState() { + static const auto* const fallback_state = []() { + tensorflow::SessionOptions session_options; + tensorflow::FunctionDefLibrary fdef_lib; + auto fallback_state = + FallbackState::CreateWithCpuDevice(session_options, fdef_lib).value(); + return fallback_state.release(); + }(); + + return *fallback_state; +} + +static std::function)>* GetDefaultRunner() { + static auto* const default_runner = + new std::function)>( + [](const std::function& f) { f(); }); + return default_runner; +} + +LogicalResult EvaluateOperation(Operation* inst, + llvm::ArrayRef operands, + llvm::SmallVector& results) { + // If any operand is nullptr returns true for a failure. + // TODO(b/120678030): remove this constraint if we find operators can be + // evaluated with some unknown operands. + if (std::any_of(operands.begin(), operands.end(), + [](Attribute operand) { return !operand; })) { + VLOG(1) << "Can't evaluate since not all operands are constant."; + return failure(); } - return ctx; + + // Builds TF operation and sets all the attributes. + std::string node_name = "unnamed"; + if (const StringAttr attr = inst->getAttrOfType("name")) { + node_name = std::string(attr.getValue()); + } + absl::StatusOr> node_def = + tensorflow::ConvertTFDialectOpToNodeDef( + inst, node_name.c_str(), /*ignore_unregistered_attrs=*/true); + RETURN_FAILURE_IF_ERROR(node_def.status()); + + const FallbackState& fallback_state = GetDefaultFallbackState(); + + // Explicitly set device to Host CPU instead of the device present in device + // attribute of the MLIR op. 
The assigned device might be remote, not + // available during compilation or compilation only device for on demand + // execution which may create a recursion if used for constant folding. + std::string host_cpu = tensorflow::DeviceNameUtils::FullName( + /*job=*/"localhost", /*replica=*/0, /*task=*/0, /*type=*/"CPU", /*id=*/0); + + absl::StatusOr runner = OpKernelRunner::Create( + node_def->get()->op(), node_def->get()->name(), host_cpu, operands.size(), + [&](tensorflow::AttrValueMap* attr_value_map) { + *attr_value_map = node_def->get()->attr(); + return tensorflow::OkStatus(); + }, + fallback_state.device_manager(), + fallback_state.process_function_library_runtime()); + RETURN_FAILURE_IF_ERROR(runner.status()); + + VLOG(1) << "Start to evaluate node: " << node_def->get()->DebugString(); + + std::vector inputs; + + // Adds inputs to the TF operation. + for (const ElementsAttr& operand : operands) { + tensorflow::Tensor tensor; + RETURN_FAILURE_IF_ERROR(tensorflow::ConvertToTensor(operand, &tensor)); + inputs.push_back(std::move(tensor)); + } + + std::vector input_values; + for (tensorflow::Tensor& tensor : inputs) { + input_values.emplace_back(); + input_values.back().tensor = &tensor; + } + + tensorflow::OpKernelContext::Params params; + params.inputs = input_values; + params.device = runner->device(); + params.op_kernel = runner->op_kernel(); + + // Still use original device's resource_manager. + params.resource_manager = runner->resource_manager(); + params.input_alloc_attrs = runner->input_alloc_attrs(); + params.output_attr_array = runner->output_alloc_attrs().data(); + + // Following two parameters are used to support executing tf.data via + // fallback. + params.function_library = runner->function_library_runtime(); + params.runner = GetDefaultRunner(); + + // Executes the TF operation. 
+ tensorflow::OpKernelContext op_kernel_context(¶ms); + runner->Run(&op_kernel_context); + RETURN_FAILURE_IF_ERROR(op_kernel_context.status()); + + // Converts the outputs to MLIR attributes. + Builder builder(inst->getContext()); + + for (int i = 0; i < op_kernel_context.num_outputs(); ++i) { + DCHECK(op_kernel_context.mutable_output(i)); + absl::StatusOr result_attr = tensorflow::ConvertTensor( + *op_kernel_context.mutable_output(i), &builder); + RETURN_FAILURE_IF_ERROR(result_attr.status()); + results.push_back(result_attr.value()); + } + + VLOG(1) << "Evaluate node " << node_name << " successfully!"; + + return success(); } } // namespace TF diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold_utils.h b/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold_utils.h index 8f28735d2a9..636dde98d2b 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold_utils.h +++ b/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold_utils.h @@ -16,12 +16,21 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_CONSTANT_FOLD_UTILS_H_ #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_CONSTANT_FOLD_UTILS_H_ -#include "tensorflow/c/eager/c_api.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project namespace mlir { namespace TF { -TFE_Context* GetContextForConstantFold(); +// Checks whether the given TF operation can be folded or not. +bool CanBeFolded(Operation* inst); + +// Evaluates the operation with given operand values. 
+LogicalResult EvaluateOperation(Operation* inst, + llvm::ArrayRef operands, + llvm::SmallVector& results); } // namespace TF } // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/embedding_pipelining.cc b/tensorflow/compiler/mlir/tensorflow/transforms/embedding_pipelining.cc index e5671bf5961..84b161a0fd7 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/embedding_pipelining.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/embedding_pipelining.cc @@ -13,53 +13,160 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// This pass implements automated pipelining for TPU embeddings defined using -// the TF2 Embedding API. This is designed for applications that have an -// embedding lookup on the SparseCore, followed by one or more dense layers on -// TensorCores, optionally followed by a backward pass (training update) with -// more ops on the SparseCore. Ops are broken up into: -// 1. SC forward pass -// 2. TC forward/backward pass -// 3. SC backward pass -// 4. non-TPU loop counter updates -// These 4 functions are then staggered so as to enable parallel execution. +/****************************************************************************** +This pass implements automated pipelining for TPU embeddings defined using +the TF2 Embedding API. This is designed for applications that have an +embedding lookup on the SparseCore, followed by one or more dense layers on +TensorCores, optionally followed by a backward pass (training update) with +more ops on the SparseCore. Ops are broken up into: + 1. SC forward pass + 2. TC forward/backward pass + 3. SC backward pass + 4. non-TPU loop counter updates +These 4 functions are then staggered so as to enable parallel execution. 
+ +In pseudocode, the algorithm is as follows: + +// Start step 0 +C_0 = cond(args_0) +N_0 = non_tup(args_0) +if (C_0) { + F_0 = forward(args_0, N_0) + T_0 = core_tpu(args_0, N_0, F_0) + // B_0 = backward() is not evaluated here. +} + +args_1 = update_args(args_0, N_0, T_0) + +// Start step 1 +C_1 = cond(args_1) +N_1 = non_tup(args_1) +if (C_1) { + F_1 = forward(args_1, N_1) + // T_1 = core_tpu() is not evaluated here. + // B_1 = backward() is not evaluated here. +} + +// Partial update of args. We expect this to be sufficient +// for evaluating cond(). +args_2a = update_args(args_1, N_1) // NO T_1 here + +// Conditional for step 2 +C_2 = cond(args_2) + +new_while_body (new_args) { // starts at i==2 + // Finish step i-2 + B_im2 = backward(args_im2, N_im2, F_im2, T_im2) + + // Advance step i-1 + T_im1 = core_tpu(args_im1, N_im1, F_im1) + + // Finish the update of args_2 + args_i = args_2b = update_args(args_2a, T_im1) + + // Start step i + N_i = non_tpu(args_i) + F_i = forward(args_i, N_i) + + // Conditional update + args_ip1 = update_args(args_i, N_i) // T_i is lagged. + C_ip1 = cond(args_ip1) + + return (...) +} +// Note: the tf.while conditional is based on Ci which is initially C2. The +// tf.while op returns the inputs unmodified if the initial conditional is +// false. Thus, the following special cases hold for N <= 2: +// N==0 | N==1 | N==2 | N==3 +// ----------------------------- +// C_nm2 == C_0 -> false | true | true | true +// C_nm1 == C_1 -> false | false | true | true + +// Finish step N-2 +if (C_nm2) { + backward(args_nm2, N_nm2, F_nm2, T_nm2) +} + +// Finish step N-1 +if (C_nm1) { + T_nm1 = core_tpu(args_nm1, N_nm1, F_nm1) + backward(args_nm1, N_nm1, F_nm1, T_nm1) +} + +// To match the original, un-pipelined while loop, we need to return the +// correct results from the pipelined version. 
Nominally, we'd like to do +// this: +// if ( NOT(C_nm2) ) { +// return args_nm2 +// } else if (NOT(C_nm1)) { +// return args_nm1 +// } else { +// return args_n +// } +// but we don't have if/else-if operators. We can convert this to a CaseOp. +// Note, if C_nm1==true and C_nm2 must also be true. +branch_index = int(C_nm2) + int(C_nm1) +selected_results = switch(branch_index) { + case 0: return args_nm2 + case 1: return args_nm1 + case 2: return args_n +} +return selected_results +******************************************************************************/ #include +#include #include #include #include #include +// #include "smartass/brain/ops/flogs_ops.h" +#include "absl/log/log.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_replace.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/MapVector.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SetVector.h" #include "llvm/Support/Casting.h" +#include "llvm/Support/raw_ostream.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/IRMapping.h" // from @llvm-project #include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/Region.h" // from @llvm-project #include "mlir/IR/SymbolTable.h" // from @llvm-project #include "mlir/IR/Types.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/IR/Visitors.h" // from @llvm-project +#include "mlir/Interfaces/CallInterfaces.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from 
@llvm-project #include "mlir/Transforms/InliningUtils.h" // from @llvm-project #include "mlir/Transforms/RegionUtils.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" #include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h" #define GEN_PASS_DEF_EMBEDDINGPIPELININGPASS #include "tensorflow/compiler/mlir/tensorflow/transforms/tf_passes.h.inc" static constexpr char kEmbeddingPipelining[] = "_embedding_pipelining"; +static constexpr char kEmbeddingPipeliningInlineAttr[] = + "_embedding_pipelining_inline"; static constexpr char kEmbeddingForward[] = "forward"; static constexpr char kEmbeddingBackward[] = "backward"; static constexpr char kDevice[] = "device"; +static constexpr char kLower[] = "_lower_using_switch_merge"; static constexpr llvm::StringRef kTpuCompilationStatus = "_tpu_compilation_status"; @@ -67,24 +174,6 @@ namespace mlir { namespace TFDevice { namespace { -struct EmbeddingPipeliningPass - : public ::impl::EmbeddingPipeliningPassBase { - void getDependentDialects(mlir::DialectRegistry& registry) const override { - registry.insert(); - } - - void runOnOperation() override; -}; - -template -std::vector GetValueTypes(const InputContainer& input) { - // Convert a list of mlir::Value's into a list of mlir::Type's - std::vector types; - types.reserve(input.size()); - for (auto val : input) types.push_back(val.getType()); - return types; -} - bool IsResourceType(Type val_type) { if (auto tensor_type = val_type.dyn_cast()) { if (tensor_type.getElementType().isa()) { @@ -94,9 +183,14 @@ bool IsResourceType(Type val_type) { return false; } -bool IsTPUOp(mlir::Operation* op) { - return op->hasAttr(TF::kReplicationInfoAttr); -} +struct EmbeddingPipeliningPass + : public ::impl::EmbeddingPipeliningPassBase { + void getDependentDialects(mlir::DialectRegistry& registry) const override { + 
registry.insert(); + } + + void runOnOperation() override; +}; StringAttr GetReplicationAttr(mlir::Operation* op) { return op->getAttrOfType(TF::kReplicationInfoAttr); @@ -108,12 +202,313 @@ StringAttr GetReplicationAttr(TF::TPUCompilationResultOp op) { return op->getAttrOfType(kTpuCompilationStatus); } +// Replaces the replication region attribute if it already exists. +void UpdateReplicationAttr(Operation* op, StringAttr attr) { + if (op->hasAttr(TF::kReplicationInfoAttr)) { + op->setAttr(TF::kReplicationInfoAttr, attr); + } +} + +// Replaces the replication region attribute if it already exists. +void UpdateReplicationAttr(TF::TPUCompilationResultOp& op, StringAttr attr) { + // Special case for getting the replication region for + // TPUCompilationResultsOp. + if (op->hasAttr(kTpuCompilationStatus)) { + op->setAttr(kTpuCompilationStatus, attr); + } +} + +// A helper class to inline TF::StatefulPartitionedCall ops +struct Inliner : public InlinerInterface { + Inliner(OpBuilder& builder, SymbolTable& symbol_table) + : InlinerInterface(builder.getContext()), + builder(builder), + symbol_table(symbol_table) {} + + bool isLegalToInline(Operation* call, Operation* callable, + bool wouldBeCloned) const override { + return true; + } + bool isLegalToInline(Region* dest, Region* src, bool wouldBeCloned, + IRMapping& valueMapping) const override { + return true; + } + bool isLegalToInline(Operation* op, Region* dest, bool wouldBeCloned, + IRMapping& valueMapping) const override { + return true; + } + + // Don't recursively analyze operations, because they can all be "inlined". 
+ bool shouldAnalyzeRecursively(Operation* op) const override { return true; } + + LogicalResult UnifyReplicationInfo(func::FuncOp func) { + auto new_repl_info = + builder.getStringAttr(func.getSymName().str() + "_repl_info"); + for (auto& op : func.getRegion().getOps()) { + if (auto compile_op = llvm::dyn_cast(op)) { + UpdateReplicationAttr(compile_op, new_repl_info); + } else { + UpdateReplicationAttr(&op, new_repl_info); + } + } + return LogicalResult::success(); + } + + // After inlining, there will likely be some instances where a + // TPUReplicatedInput feeds directly into a TPUReplicatedOutput. Find such + // pairs and remove them. + LogicalResult RemoveOutputInputPairs(func::FuncOp func) { + llvm::SetVector ops_to_erase; + // Inlining can result in multiple TPUCompilationResultOp and + // TPUReplicateMetadataOp ops. Only keep one, the first will do fine. + TF::TPUCompilationResultOp compile_op = nullptr; + for (auto op : func.getRegion().getOps()) { + if (compile_op == nullptr) { + compile_op = op; + } else { + ops_to_erase.insert(op); + } + } + // If there's no outside compilation, we can exit early because this isn't + // a TPU function. + if (compile_op == nullptr) { + return LogicalResult::success(); + } + + TF::TPUReplicateMetadataOp metadata_op = nullptr; + for (auto op : func.getRegion().getOps()) { + if (metadata_op == nullptr) + metadata_op = op; + else + ops_to_erase.insert(op); + } + if (metadata_op == nullptr) { + func->emitError( + "Expected to find TPUReplicateMetadataOps but found none."); + return LogicalResult::failure(); + } + + for (auto output_op : + func.getRegion().getOps()) { + bool outputs_are_returned = false; + TF::TPUReplicatedInputOp input_op = nullptr; + // Only visit each user of the results once. 
+ llvm::SetVector seen_users; + for (auto user : output_op->getUsers()) { + if (!seen_users.insert(user)) continue; + if (llvm::isa(user)) { + if (input_op != nullptr) { + func->emitError( + "Found multiple TPUReplicatedInput ops but only expected 1."); + return LogicalResult::failure(); + } + input_op = llvm::dyn_cast(user); + } + if (llvm::isa(user)) { + outputs_are_returned = true; + } + } + if (input_op == nullptr) continue; + + // If we found matching input ops, we can remove the TPUReplicatedInput + // ops and replace their result values with the inputs to the matching + // TPUReplicatedOutput op. + replaceAllUsesInRegionWith(input_op.getResult(), output_op.getOperand(), + func.getRegion()); + ops_to_erase.insert(input_op); + + // If the outputs aren't also returned from this function, then we can + // remove the TPUReplicatedOutput op as well. In some cases we'll + // still need these ops. + if (!outputs_are_returned) ops_to_erase.insert(output_op); + } + for (auto op : ops_to_erase) op->erase(); + + return LogicalResult::success(); + } + + LogicalResult RemoveDuplicateReplication(func::FuncOp func) { + llvm::SetVector ops_to_erase; + llvm::MapVector cache; + for (auto input_op : func.getRegion().getOps()) { + // We're only expecting a single input argument to be replicated. + if (input_op->getNumOperands() > 1) continue; + Value operand = input_op->getOperand(0); + if (!llvm::isa(operand)) continue; + BlockArgument arg = llvm::dyn_cast(operand); + + // See if we've run across this TPUReplicatedInputOp before. + if (!cache.insert({arg, input_op}).second) { + // We've seen this before. Replace this instance with the cached op. 
+ for (auto p : + llvm::zip(input_op->getResults(), cache[arg]->getResults())) { + replaceAllUsesInRegionWith(std::get<0>(p), std::get<1>(p), + func.getRegion()); + } + ops_to_erase.insert(input_op); + } + } + for (auto op : ops_to_erase) op->erase(); + return LogicalResult::success(); + } + + // Find any StatefulPartitionedCalls and inline their contents in this func. + LogicalResult InlineCallsInFunc(func::FuncOp func, + bool inline_all_funcs = false) { + llvm::SetVector ops_to_erase; + for (auto caller : + func.getRegion().getOps()) { + if (!inline_all_funcs && + !caller->hasAttr(kEmbeddingPipeliningInlineAttr)) { + continue; + } + Operation* symbol = symbol_table.lookup(caller.getF()); + if (symbol == nullptr) { + func.emitError() << "Symbol not found in SymbolTable: " + << caller.getF(); + return LogicalResult::failure(); + } + if (!llvm::isa(symbol)) { + func.emitError() << "Invalid callee: " << caller.getF(); + return LogicalResult::failure(); + } + auto callee = + llvm::dyn_cast(symbol_table.lookup(caller.getF())); + auto& src_region = callee.getRegion(); + auto result = inlineCall(*this, caller, callee, &src_region, true); + if (failed(result)) { + func.emitError("Inliner failed"); + return result; + } + ops_to_erase.insert(caller); + } + for (auto op : ops_to_erase) op->erase(); + + auto result = UnifyReplicationInfo(func); + if (failed(result)) return result; + + result = RemoveOutputInputPairs(func); + if (failed(result)) return result; + + result = RemoveDuplicateReplication(func); + if (failed(result)) return result; + + return LogicalResult::success(); + } + + private: + OpBuilder& builder; + SymbolTable& symbol_table; +}; + +LogicalResult EliminateResourceLoops(OpBuilder& builder, + SymbolTable& symbol_table, + func::FuncOp func) { + // Examine all StatefulPartitionedCall ops that have resources as return + // types. 
If the returned resource traces back to an input argument for the + // SPC, then replace uses of the returned copy with the original input. + // + // Note: This does not descend through nested SCPs. + auto ComesFromBlockArgNumber = [](Value val) -> int { + while (true) { + if (auto block_arg = llvm::dyn_cast(val)) { + return block_arg.getArgNumber(); + } + if (auto identity_op = + llvm::dyn_cast(val.getDefiningOp())) { + val = identity_op.getOperand(); + } else { + return -1; + } + } + }; + + for (auto call_op : + func.getRegion().getOps()) { + for (int i = 0; i < call_op->getNumResults(); ++i) { + if (IsResourceType(call_op->getResult(i).getType())) { + Operation* symbol = symbol_table.lookup(call_op.getF()); + if (symbol == nullptr) { + func.emitError() << "Symbol not found in SymbolTable: " + << call_op.getF(); + return LogicalResult::failure(); + } + if (!llvm::isa(symbol)) { + func.emitError() << "Invalid callee: " << call_op.getF(); + return LogicalResult::failure(); + } + auto callee = + llvm::dyn_cast(symbol_table.lookup(call_op.getF())); + func::ReturnOp return_op = *callee.getOps().begin(); + auto val = return_op.getOperand(i); + auto block_arg_number = ComesFromBlockArgNumber(val); + if (block_arg_number >= 0) { + replaceAllUsesInRegionWith(call_op->getResult(i), + call_op->getOperand(block_arg_number), + func.getRegion()); + } + } + } + } + return LogicalResult::success(); +} + +struct Callers { + TF::StatefulPartitionedCallOp forward; + TF::StatefulPartitionedCallOp core_tpu; + TF::StatefulPartitionedCallOp backward; + TF::StatefulPartitionedCallOp non_tpu; +}; + +template +std::vector GetValueTypes(const InputContainer& input) { + // Convert a list of mlir::Value's into a list of mlir::Type's + std::vector types; + types.reserve(input.size()); + for (auto val : input) types.push_back(val.getType()); + return types; +} + +bool IsTPUOp(mlir::Operation* op) { + return op->hasAttr(TF::kReplicationInfoAttr); +} + +template +void Append(Vector& a, const 
Container& b) { + a.insert(a.end(), b.begin(), b.end()); +} + +template +void Append(Vector& a, const Vector& b) { + a.insert(a.end(), b.begin(), b.end()); +} + int64_t GetNumOps(func::FuncOp func) { int64_t num_ops = 0; for (auto it = func.begin(); it != func.end(); ++it) ++num_ops; return num_ops; } +std::vector ResultsAsVector(Operation* op) { + std::vector vec; + vec.reserve(op->getNumResults()); + for (auto res : op->getResults()) vec.push_back(res); + return vec; +} + +void SetBasicBlockAttributes(OpBuilder& builder, Operation* op) { + op->setAttr(kDevice, builder.getStringAttr("")); + op->setAttr(kLower, builder.getBoolAttr(true)); +} + +std::vector ResultsAsVector(Operation* op, int begin, int num) { + int end = begin + num; + std::vector vec; + vec.reserve(end - begin); + for (int i = begin; i < end; ++i) vec.push_back(op->getResult(i)); + return vec; +} + void GatherOpsForExtraction(mlir::SetVector* operations, const mlir::SetVector& ops_to_avoid, bool predecessors, bool successors) { @@ -158,9 +553,11 @@ void GatherOpsForExtraction(mlir::SetVector* operations, } } -TF::StatefulPartitionedCallOp MakeFuncCaller( - mlir::OpBuilder& builder, const Location& loc, func::FuncOp func, - const llvm::SetVector& operands) { +TF::StatefulPartitionedCallOp MakeFuncCaller(mlir::OpBuilder& builder, + const Location& loc, + func::FuncOp func, + const ArrayRef& operands, + bool flag_for_inlining) { // Constructs a tf.StatefulPartitionedCall to the function provided in 'func' // using the operands in 'operands'. Assumes the insertion point on builder is // already set. 
@@ -168,60 +565,65 @@ TF::StatefulPartitionedCallOp MakeFuncCaller( mlir::SymbolRefAttr::get(builder.getContext(), func.getSymName()); auto result_types = func.getResultTypes(); auto caller = builder.create( - loc, result_types, operands.getArrayRef(), symbol, + loc, result_types, operands, symbol, /*config=*/builder.getStringAttr(""), /*config_proto=*/builder.getStringAttr(""), /*executor_type=*/builder.getStringAttr("")); caller.setFAttr(symbol); + + // Set an attribute that our inliner will look for when choosing which + // TF::StatefulPartitionedCallOps to inline. + if (flag_for_inlining) + caller->setAttr(kEmbeddingPipeliningInlineAttr, builder.getBoolAttr(true)); return caller; } -func::FuncOp CreateFnWithSignature(ModuleOp module, +func::FuncOp CreateFnWithSignature(ModuleOp module, SymbolTable& symbol_table, const llvm::SetVector& inputs, const llvm::SetVector& outputs, const std::string& name) { // Creates an empty func.FuncOp with a signature compatible with 'inputs' // (operands) and 'outputs' (results). 
OpBuilder builder(module); - - std::vector input_types = GetValueTypes(inputs); - std::vector output_types = GetValueTypes(outputs); + auto in_types = GetValueTypes(inputs); + auto out_types = GetValueTypes(outputs); builder.setInsertionPointToEnd(&module.getBodyRegion().back()); - func::FuncOp func_op = builder.create( - module.getLoc(), name, - builder.getFunctionType(input_types, output_types)); + auto func_op = builder.create( + module.getLoc(), name, builder.getFunctionType(in_types, out_types)); func_op.setPrivate(); - + symbol_table.insert(func_op); return func_op; } TF::StatefulPartitionedCallOp EncapsulateOpsInFunc( - OpBuilder& builder, const llvm::SetVector& ops, + OpBuilder& builder, SymbolTable& symbol_table, + const llvm::SetVector& ops, const llvm::SetVector& inputs, const llvm::SetVector& outputs, - func::FuncOp parent_func, ModuleOp module, const std::string& name) { + func::FuncOp parent_func, ModuleOp module, const std::string& name, + bool flag_for_inlining) { // Moves all of the Operations in 'ops' into a newly created func.FuncOp // function named 'name' and replaces the original ops with a call to the // newly created function using a tf.StatefulPartitionedCall. Here, // 'parent_func' is the function that holds the original set of ops. // Note, 'inputs' and 'outputs' are the predetermined set of values that // should become the operands and return values, respectively. - auto insertion_point = builder.saveInsertionPoint(); - func::FuncOp new_func = CreateFnWithSignature(module, inputs, outputs, - absl::StrCat("_func_", name)); + auto saved_insertion_point = builder.saveInsertionPoint(); + func::FuncOp new_func = + CreateFnWithSignature(module, symbol_table, inputs, outputs, name); // This preserves the order of the ops that was in the original parent - // funtion. This is critical for preserving correctness in the presence of + // function. 
This is critical for preserving correctness in the presence of // resource variables and stateful functions. std::vector topological_order; for (Operation& op : parent_func.getOps()) if (ops.contains(&op)) topological_order.push_back(&op); // Create the partitioned call - builder.restoreInsertionPoint(insertion_point); - auto caller = MakeFuncCaller(builder, module.getLoc(), new_func, inputs); + builder.restoreInsertionPoint(saved_insertion_point); + auto caller = MakeFuncCaller(builder, module.getLoc(), new_func, + inputs.getArrayRef(), flag_for_inlining); Block* block = new_func.addEntryBlock(); - for (Operation* op : topological_order) op->moveBefore(block, block->end()); // Replace the 'inputs' values with the new function's arguments. @@ -293,7 +695,7 @@ LogicalResult FindAndExcludeOp(func::FuncOp func, } LogicalResult FindOwningWhileOp(func::FuncOp body_func, ModuleOp module, - TF::WhileOp* while_op) { + TF::WhileOp& while_op) { // Given a while loop body function 'body_func', find the tf.While Op that // uses it. 
auto uses_optional = body_func.getSymbolUses(module); @@ -301,14 +703,14 @@ LogicalResult FindOwningWhileOp(func::FuncOp body_func, ModuleOp module, body_func.emitOpError() << "no use of while loop body"; return LogicalResult::failure(); } - *while_op = nullptr; + while_op = nullptr; for (auto& use : uses_optional.value()) { if (llvm::isa(use.getUser())) { - if (*while_op != nullptr) { + if (while_op != nullptr) { use.getUser()->emitOpError() << "multiple users of function."; return LogicalResult::failure(); } else { - *while_op = llvm::cast(use.getUser()); + while_op = llvm::cast(use.getUser()); } } else { use.getUser()->emitOpError() << "non while use of function."; @@ -397,15 +799,13 @@ LogicalResult FindForwardPassOps(OpBuilder& builder, if (use_in_forward && use_in_not_forward) { loop_body_func.emitOpError() << "resource input " << argument.getArgNumber() - << " is used both in the forwards and " - << "not forward passes dataset"; + << " is used both in the forwards and not forward passes dataset"; return LogicalResult::failure(); } if (is_non_variable && is_variable) { loop_body_func.emitOpError() << "resource input " << argument.getArgNumber() - << " is used both as a varible and not " - << " a variable"; + << " is used both as a variable and not a variable"; return LogicalResult::failure(); } if (is_variable && use_in_forward) @@ -461,7 +861,7 @@ LogicalResult FindForwardPassOps(OpBuilder& builder, } } - VLOG(2) << "Cloned " << cloned_inputs << " TPUReplicatedInputOps"; + VLOG(3) << "Cloned " << cloned_inputs << " TPUReplicatedInputOps"; // Add TPUReplicatedInput/TPUReplicatedOutput pairs along each edge. 
llvm::SetVector new_forward_ops; @@ -515,7 +915,7 @@ LogicalResult FindForwardPassOps(OpBuilder& builder, } } - VLOG(2) << "inserted " << new_forward_ops.size() << " TPU Input/Output ops"; + VLOG(3) << "Inserted " << new_forward_ops.size() << " TPU Input/Output ops."; forward_pass_ops.insert(new_forward_ops.begin(), new_forward_ops.end()); return LogicalResult::success(); } @@ -537,7 +937,7 @@ LogicalResult FindBackwardPassOps( GatherOpsForExtraction(&backward_pass_ops, merged_set, /*predecessors=*/false, /*successors=*/true); - VLOG(3) << "found " << backward_pass_ops.size() << " backwards pass ops"; + VLOG(3) << "Found " << backward_pass_ops.size() << " backwards pass ops."; // If any inputs are to the backward_pass_ops region are direct // TPUReplicatedInput ops, then include (if this is the only use) or @@ -719,10 +1119,12 @@ LogicalResult FindNonTPUOps(llvm::SetVector& non_tpu_ops, } LogicalResult ExtractOpsAsFunc( - OpBuilder& builder, ModuleOp module, llvm::SetVector& ops, - StringAttr replication_attr, TF::TPUReplicateMetadataOp metadata_op, + OpBuilder& builder, ModuleOp module, SymbolTable& symbol_table, + llvm::SetVector& ops, StringAttr replication_attr, + TF::TPUReplicateMetadataOp metadata_op, TF::TPUCompilationResultOp compilation_op, func::FuncOp parent_func, - const std::string& func_name, Operation** caller) { + const std::string& func_name, TF::StatefulPartitionedCallOp* caller, + bool flag_for_inlining) { // Move the given set of 'ops' into it's own function and replace them with a // call to that function ('caller'). 
if 'metadata_op' and 'compilation_op' are // non-null, also insert those (i.e., target the resulting function to the @@ -753,8 +1155,9 @@ LogicalResult ExtractOpsAsFunc( } llvm::SetVector outputs; for (auto output : results) outputs.insert(output); - auto tf_caller = EncapsulateOpsInFunc(builder, ops, inputs, outputs, - parent_func, module, func_name); + auto tf_caller = + EncapsulateOpsInFunc(builder, symbol_table, ops, inputs, outputs, + parent_func, module, func_name, flag_for_inlining); if (!ops.empty() && metadata_op != nullptr && compilation_op != nullptr) UpdateAndInsertTPUOps(tf_caller, metadata_op, compilation_op, replication_attr); @@ -762,8 +1165,464 @@ LogicalResult ExtractOpsAsFunc( return LogicalResult::success(); } +LogicalResult FindSourceTPUReplicatedOutput( + Value val, TF::TPUReplicatedOutputOp& rep_out) { + Operation* op = val.getDefiningOp(); + if (auto src = llvm::dyn_cast(op)) { + rep_out = src; + return LogicalResult::success(); + } + if (auto src = llvm::dyn_cast(op)) { + return FindSourceTPUReplicatedOutput(src->getOperand(0), rep_out); + } + op->emitOpError() << "Value did not come from a TPUReplicatedOutput op: " + << val; + return LogicalResult::failure(); +} + +int FindReturnIndex(Value val) { + const int not_found = -1; + for (auto user : val.getUsers()) { + if (auto ret_op = llvm::dyn_cast(user)) { + for (auto index = 0; index < ret_op->getNumOperands(); ++index) { + if (val == ret_op->getOperand(index)) { + return index; + } + } + } + if (auto ident_op = llvm::dyn_cast(user)) { + auto index = FindReturnIndex(ident_op->getResult(0)); + if (index != not_found) return index; + } + } + return not_found; +} + +void AddAssertion(OpBuilder& builder, Location& loc, Value cond, + const std::string& message) { + auto shape_type = + RankedTensorType::get({1}, builder.getType()); + auto msg = builder.create( + loc, DenseStringElementsAttr::get(shape_type, + llvm::ArrayRef{message})); + builder.create(loc, cond, msg.getResult()); +} + 
+LogicalResult StartStep0(OpBuilder& builder, Location& loc, + SymbolTable& symbol_table, + TF::TPUReplicateMetadataOp& metadata_op, + TF::TPUCompilationResultOp& compilation_op, + Value& cond_value, Callers& callers, + const std::vector& loop_operands_nm0, + TF::StatefulPartitionedCallOp& caller) { + const std::string name = "start_step_0"; + + AddAssertion(builder, loc, cond_value, + "Auto-pipelining requires at least two steps."); + auto insertion_point = builder.saveInsertionPoint(); + + func::FuncOp orig_parent_func = + callers.backward->getParentOfType(); + + std::vector operands = loop_operands_nm0; + + // Input types will be the same as the original loop body. + std::vector input_types = GetValueTypes(operands); + + // Determine the results types. + // Return ALL outputs, respecting the provided order of the Operations. This + // makes it straightforward for users of this function to map the return + // values. + llvm::SetVector ops; + ops.insert(callers.forward); + ops.insert(callers.core_tpu); + std::vector result_map; + result_map.reserve(callers.forward->getNumResults() + + callers.core_tpu->getNumResults()); + int result_pos = 0; + for (auto res : callers.forward->getResults()) { + bool is_output = false; + for (auto user : res.getUsers()) { + if (!ops.contains(user)) { + is_output = true; + break; + } + } + result_map.push_back(is_output ? result_pos++ : -1); + } + std::vector result_types; + Append(result_types, callers.forward->getResultTypes()); + Append(result_types, callers.core_tpu->getResultTypes()); + + // Create the function based on input and result types and values. 
+ auto func_type = + mlir::FunctionType::get(builder.getContext(), input_types, result_types); + func::FuncOp then_func = func::FuncOp::create(loc, name, func_type); + then_func.setPrivate(); + symbol_table.insert(then_func); + mlir::OpBuilder func_builder = + mlir::OpBuilder::atBlockBegin(then_func.addEntryBlock()); + + // This must match the concatenation order in 'operands' above. + IRMapping ir_map; + int pos = 0; + for (auto orig : orig_parent_func.getArguments()) + ir_map.map(orig, then_func.getArgument(pos++)); + + // Clone the specified ops into the new function. + auto new_forward = func_builder.insert(callers.forward->clone(ir_map)); + for (auto p : + llvm::zip(callers.core_tpu->getResults(), new_forward->getResults())) + ir_map.map(std::get<0>(p), std::get<1>(p)); + auto new_core_tpu = func_builder.insert(callers.core_tpu->clone(ir_map)); + + // Add the function return; + std::vector results; + Append(results, new_forward->getResults()); + Append(results, new_core_tpu->getResults()); + func_builder.create(loc, results); + + // Inline any StatefulPartitionCall Ops. 
+ auto result = Inliner(builder, symbol_table).InlineCallsInFunc(then_func); + if (failed(result)) return result; + + builder.restoreInsertionPoint(insertion_point); + caller = MakeFuncCaller(builder, loc, then_func, operands, + /*flag_for_inlining=*/false); + return LogicalResult::success(); +} + +LogicalResult StartStep1(OpBuilder& builder, Location& loc, + SymbolTable& symbol_table, + TF::TPUReplicateMetadataOp& metadata_op, + TF::TPUCompilationResultOp& compilation_op, + Value& cond_value, Callers& callers, + const std::vector& loop_operands_1, + TF::StatefulPartitionedCallOp& caller) { + const std::string name = "start_step_1"; + + AddAssertion(builder, loc, cond_value, + "Auto-pipelining requires at least two steps."); + + auto insertion_point = builder.saveInsertionPoint(); + func::FuncOp orig_parent_func = + callers.backward->getParentOfType(); + + std::vector operands = loop_operands_1; + + // Input types will be the same as the original loop body. + std::vector input_types = GetValueTypes(operands); + + // Determine the results types. + // Return ALL outputs, respecting the provided order of the Operations. This + // makes it straightforward for users of this function to map the return + // values. + auto result_types = callers.forward->getResultTypes(); + + // Create the function based on input and result types and values. + auto func_type = + mlir::FunctionType::get(builder.getContext(), input_types, result_types); + func::FuncOp then_func = func::FuncOp::create(loc, name, func_type); + then_func.setPrivate(); + symbol_table.insert(then_func); + mlir::OpBuilder func_builder = + mlir::OpBuilder::atBlockBegin(then_func.addEntryBlock()); + + // This must match the concatenation order in 'operands' above. + IRMapping ir_map; + int pos = 0; + for (auto orig : orig_parent_func.getArguments()) + ir_map.map(orig, then_func.getArgument(pos++)); + + // Clone the specified ops into the new function. 
+ auto new_forward = func_builder.insert(callers.forward->clone(ir_map)); + + // Add the function return; + func_builder.create(loc, new_forward->getResults()); + + // Inline any StatefulPartitionCall Ops. + auto result = Inliner(builder, symbol_table).InlineCallsInFunc(then_func); + if (failed(result)) return result; + + builder.restoreInsertionPoint(insertion_point); + caller = MakeFuncCaller(builder, loc, then_func, operands, + /*flag_for_inlining=*/false); + return LogicalResult::success(); +} + +LogicalResult FinishStepNm2(OpBuilder& builder, Location& loc, + SymbolTable& symbol_table, + TF::TPUReplicateMetadataOp& metadata_op, + TF::TPUCompilationResultOp& compilation_op, + Value& cond_value, Callers& callers, + const std::vector& loop_operands_nm2, + const std::vector& forward_res_nm2, + const std::vector& core_tpu_res_nm2, + TF::StatefulPartitionedCallOp& caller) { + const std::string name = "finish_step_nm2"; + + AddAssertion(builder, loc, cond_value, + "Auto-pipelining requires at least two steps."); + + auto insertion_point = builder.saveInsertionPoint(); + func::FuncOp orig_parent_func = + callers.backward->getParentOfType(); + + std::vector operands = loop_operands_nm2; + Append(operands, forward_res_nm2); + Append(operands, core_tpu_res_nm2); + + // Input types will be the same as the original loop body. + std::vector input_types = GetValueTypes(operands); + + // Determine the results types. + // Return ALL outputs, respecting the provided order of the Operations. This + // makes it straightforward for users of this function to map the return + // values. + auto result_types = callers.backward->getResultTypes(); + + // Create the function based on input and result types and values. 
+ auto func_type = + mlir::FunctionType::get(builder.getContext(), input_types, result_types); + func::FuncOp then_func = func::FuncOp::create(loc, name, func_type); + then_func.setPrivate(); + symbol_table.insert(then_func); + mlir::OpBuilder func_builder = + mlir::OpBuilder::atBlockBegin(then_func.addEntryBlock()); + + // This must match the concatenation order in 'operands' above. + IRMapping ir_map; + int pos = 0; + for (auto orig : orig_parent_func.getArguments()) + ir_map.map(orig, then_func.getArgument(pos++)); + for (auto orig : callers.forward->getResults()) + ir_map.map(orig, then_func.getArgument(pos++)); + for (auto orig : callers.core_tpu->getResults()) + ir_map.map(orig, then_func.getArgument(pos++)); + + // Clone the specified ops into the new function. + auto new_backward = func_builder.insert(callers.backward->clone(ir_map)); + + // Add the function return; + func_builder.setInsertionPointAfter(new_backward); + func_builder.create(loc, new_backward->getResults()); + + // Inline any StatefulPartitionCall Ops. 
+ auto result = Inliner(builder, symbol_table).InlineCallsInFunc(then_func); + if (failed(result)) return result; + + builder.restoreInsertionPoint(insertion_point); + caller = MakeFuncCaller(builder, loc, then_func, operands, + /*flag_for_inlining=*/false); + return LogicalResult::success(); +} + +LogicalResult FinishStepNm1(OpBuilder& builder, Location& loc, + SymbolTable& symbol_table, + TF::TPUReplicateMetadataOp& metadata_op, + TF::TPUCompilationResultOp& compilation_op, + Value& cond_value, Callers& callers, + const std::vector& loop_operands_nm1, + const std::vector& forward_res_nm1, + TF::StatefulPartitionedCallOp& caller) { + const std::string name = "finish_step_nm1"; + + AddAssertion(builder, loc, cond_value, + "Auto-pipelining requires at least two steps."); + + auto insertion_point = builder.saveInsertionPoint(); + func::FuncOp orig_parent_func = + callers.backward->getParentOfType(); + + std::vector operands = loop_operands_nm1; + Append(operands, forward_res_nm1); + + // Input types will be the same as the original loop body. + std::vector input_types = GetValueTypes(operands); + + // Determine the results types. + // Return ALL outputs, respecting the provided order of the Operations. This + // makes it straightforward for users of this function to map the return + // values. + std::vector result_types; + Append(result_types, callers.core_tpu->getResultTypes()); + Append(result_types, callers.backward->getResultTypes()); + + // Create the function based on input and result types and values. + auto func_type = + mlir::FunctionType::get(builder.getContext(), input_types, result_types); + func::FuncOp then_func = func::FuncOp::create(loc, name, func_type); + then_func.setPrivate(); + symbol_table.insert(then_func); + mlir::OpBuilder func_builder = + mlir::OpBuilder::atBlockBegin(then_func.addEntryBlock()); + + // This must match the concatenation order in 'operands' above. 
+ IRMapping ir_map; + int pos = 0; + for (auto orig : orig_parent_func.getArguments()) + ir_map.map(orig, then_func.getArgument(pos++)); + for (auto orig : callers.forward->getResults()) + ir_map.map(orig, then_func.getArgument(pos++)); + + // Clone the specified ops into the new function. + auto new_core_tpu = func_builder.insert(callers.core_tpu->clone(ir_map)); + for (auto p : + llvm::zip(callers.core_tpu->getResults(), new_core_tpu->getResults())) + ir_map.map(std::get<0>(p), std::get<1>(p)); + auto new_backward = func_builder.insert(callers.backward->clone(ir_map)); + // Add the function return; + std::vector results; + Append(results, new_core_tpu->getResults()); + Append(results, new_backward->getResults()); + func_builder.create(loc, results); + + // Inline any StatefulPartitionCall Ops. + auto result = Inliner(builder, symbol_table).InlineCallsInFunc(then_func); + if (failed(result)) return result; + + builder.restoreInsertionPoint(insertion_point); + caller = MakeFuncCaller(builder, loc, then_func, operands, + /*flag_for_inlining=*/false); + return LogicalResult::success(); +} + +LogicalResult MakeForwardOperands(Operation* forward_caller, + Operation* non_tpu_caller, + const std::vector& loop_operands, + const std::vector& non_tpu_res, + std::vector& f_operands) { + f_operands.clear(); + f_operands.reserve(forward_caller->getNumOperands()); + for (auto operand : forward_caller->getOperands()) { + if (llvm::isa(operand)) { + // Pull this from the original operands to the original while op. 
+ auto arg = llvm::cast(operand); + f_operands.push_back(loop_operands[arg.getArgNumber()]); + continue; + } + auto src = operand.getDefiningOp(); + auto res = llvm::cast(operand); + if (src == non_tpu_caller) { + f_operands.push_back(non_tpu_res[res.getResultNumber()]); + } else { + forward_caller->emitOpError() + << "Unknown op source for operand " << operand; + return LogicalResult::failure(); + } + } + return LogicalResult::success(); +} + +LogicalResult MakeCoreTPUOperands(Operation* core_tpu_caller, + Operation* non_tpu_caller, + Operation* forward_caller, + const std::vector& loop_operands, + const std::vector& non_tpu_res, + const std::vector& forward_res, + std::vector& t_operands) { + t_operands.clear(); + t_operands.reserve(core_tpu_caller->getNumOperands()); + for (auto operand : core_tpu_caller->getOperands()) { + if (llvm::isa(operand)) { + // Pull this from the original operands to the original while op. + auto arg = llvm::cast(operand); + t_operands.push_back(loop_operands[arg.getArgNumber()]); + continue; + } + auto src = operand.getDefiningOp(); + auto res = llvm::cast(operand); + if (src == non_tpu_caller) { + t_operands.push_back(non_tpu_res[res.getResultNumber()]); + } else if (src == forward_caller) { + t_operands.push_back(forward_res[res.getResultNumber()]); + } else { + core_tpu_caller->emitOpError() << "Unknown op source for operand " + << operand << ": " << src->getName(); + return LogicalResult::failure(); + } + } + return LogicalResult::success(); +} + +LogicalResult MakeBackwardOperands(Operation* forward_caller, + Operation* core_tpu_caller, + Operation* backward_caller, + const std::vector& loop_operands, + const std::vector& forward_res, + const std::vector& core_tpu_res, + std::vector& b_operands) { + b_operands.clear(); + b_operands.reserve(backward_caller->getNumOperands()); + for (auto operand : backward_caller->getOperands()) { + if (llvm::isa(operand)) { + // Pull this from the original operands to the original while op. 
+ auto arg = llvm::cast(operand); + b_operands.push_back(loop_operands[arg.getArgNumber()]); + continue; + } + auto src = operand.getDefiningOp(); + auto res = llvm::cast(operand); + if (src == forward_caller) { + b_operands.push_back(forward_res[res.getResultNumber()]); + } else if (src == core_tpu_caller) { + b_operands.push_back(core_tpu_res[res.getResultNumber()]); + } else { + // Note: we're expecting no edges from non_tpu() to backward(). + backward_caller->emitOpError() << "Unknown op source for operand " + << operand << ": " << src->getName(); + return LogicalResult::failure(); + } + } + return LogicalResult::success(); +} + +LogicalResult MakeNonTPUOperands(Operation* non_tpu_caller, + const std::vector& loop_operands, + std::vector& n_operands) { + n_operands.clear(); + n_operands.reserve(non_tpu_caller->getNumOperands()); + for (auto operand : non_tpu_caller->getOperands()) { + if (llvm::isa(operand)) { + auto arg = llvm::cast(operand); + n_operands.push_back(loop_operands[arg.getArgNumber()]); + continue; + } + // This shouldn't happen: + auto src = operand.getDefiningOp(); + non_tpu_caller->emitOpError() << "Unknown op source for operand " << operand + << ": " << src->getName(); + return LogicalResult::failure(); + } + return LogicalResult::success(); +} + +Operation* LiftNonTpuFuncCaller(mlir::OpBuilder& builder, + Operation* orig_non_tpu_caller, + const std::vector& operands) { + // Use this to clone an op and lift it outside its parent function. The + // original while body is unchanged. Example: + // Original: + // %x = tf.while(%a, %b) + // ... + // while_body: + // call(f=@sc_fw, %arg0, %arg1) + // Lifted: + // call(f=@sc_fw, %a, %b) + // %x = tf.while(%a, %b) + // ... 
+ func::FuncOp orig_parent_func = + orig_non_tpu_caller->getParentOfType(); + IRMapping ir_map; + ir_map.map(orig_parent_func.getArguments(), operands); + Operation* new_caller = builder.clone(*orig_non_tpu_caller, ir_map); + return new_caller; +} + void EmbeddingPipeliningPass::runOnOperation() { + VLOG(3) << "EmbeddingPipeliningPass::runOnOperation()"; ModuleOp module = getOperation(); + SymbolTable symbol_table(module); llvm::SetVector forward_pass_ops; llvm::SetVector backward_pass_ops; @@ -793,6 +1652,7 @@ void EmbeddingPipeliningPass::runOnOperation() { // If there are no forward pass ops, there is no SC, so we end early. if (forward_pass_ops.empty()) { if (backward_pass_ops.empty()) { + VLOG(1) << "no pipelining ops found"; return; } else { (*backward_pass_ops.begin())->emitOpError() @@ -804,9 +1664,9 @@ void EmbeddingPipeliningPass::runOnOperation() { // Ensure that all ops are in the same region, and have the same replication // info. // TODO(bfontain): Allow for multiple regions/loops in one module. - // TODO(patn): move this pass after cluster formation to remove the complexity - // with replication info and metadata, cluster checking and generalizing to - // multiple TPU clusters. + // TODO(patn): move this pass after cluster formation to remove the + // complexity with replication info and metadata, cluster checking and + // generalizing to multiple TPU clusters. Region* region = (*forward_pass_ops.begin())->getParentRegion(); StringAttr replication_attr = GetReplicationAttr(*forward_pass_ops.begin()); llvm::SmallVector checkset(forward_pass_ops.getArrayRef()); @@ -826,7 +1686,7 @@ void EmbeddingPipeliningPass::runOnOperation() { // TODO(bfontain): Check that the region here is the region // of the loop body func. // Find the FuncOp for the surrounding while loop body. 
- func::FuncOp loop_body_func = + auto loop_body_func = (*forward_pass_ops.begin())->getParentOfType(); // merged_set will keep track of which ops are to be avoided when gather ops @@ -846,12 +1706,21 @@ void EmbeddingPipeliningPass::runOnOperation() { loop_body_func, replication_attr, merged_set, compilation_op); if (failed(result)) return signalPassFailure(); - TF::WhileOp while_op = nullptr; - result = FindOwningWhileOp(loop_body_func, module, &while_op); + TF::WhileOp orig_while_op = nullptr; + result = FindOwningWhileOp(loop_body_func, module, orig_while_op); if (failed(result)) return signalPassFailure(); + Location loc = orig_while_op->getLoc(); OpBuilder builder(module); + // A special fix for models that pass resources into helper functions and + // return the same resource (after passing it through multiple identity ops). + // Some subsequent ops use the original resource and others use the returned + // version. Pipelining splits these uses across loop iterations resulting in + // terrible things. + result = EliminateResourceLoops(builder, symbol_table, loop_body_func); + if (failed(result)) return signalPassFailure(); + result = FindForwardPassOps(builder, forward_pass_ops, backward_pass_ops, merged_set, loop_body_func, num_replicas); if (failed(result)) return signalPassFailure(); @@ -873,45 +1742,440 @@ void EmbeddingPipeliningPass::runOnOperation() { if (failed(result)) return signalPassFailure(); merged_set.insert(non_tpu_ops.begin(), non_tpu_ops.end()); - VLOG(2) << "Forwards pass " << forward_pass_ops.size() + VLOG(3) << "Forwards pass " << forward_pass_ops.size() << " ops, backwards pass " << backward_pass_ops.size() << " ops, core " << core_tpu_ops.size() << " ops. 
Total = " << merged_set.size() << " of " - << GetNumOps(loop_body_func) << ".\n"; + << GetNumOps(loop_body_func); builder.setInsertionPointAfter(*non_tpu_ops.begin()); - Operation* non_tpu_caller = nullptr; + TF::StatefulPartitionedCallOp non_tpu_caller = nullptr; result = - ExtractOpsAsFunc(builder, module, non_tpu_ops, replication_attr, nullptr, - nullptr, loop_body_func, "non_tpu", &non_tpu_caller); + ExtractOpsAsFunc(builder, module, symbol_table, non_tpu_ops, + replication_attr, nullptr, nullptr, loop_body_func, + "non_tpu", &non_tpu_caller, /*flag_for_inlining=*/false); if (failed(result)) return signalPassFailure(); builder.setInsertionPointAfter(non_tpu_caller); - Operation* forward_caller = nullptr; - result = ExtractOpsAsFunc(builder, module, forward_pass_ops, replication_attr, - metadata_op, compilation_op, loop_body_func, - "sc_forward", &forward_caller); + TF::StatefulPartitionedCallOp forward_caller = nullptr; + result = ExtractOpsAsFunc(builder, module, symbol_table, forward_pass_ops, + replication_attr, metadata_op, compilation_op, + loop_body_func, "sc_forward", &forward_caller, + /*flag_for_inlining=*/true); if (failed(result)) return signalPassFailure(); // Create tpu_core function builder.setInsertionPointAfter(forward_caller); - Operation* core_tpu_caller = nullptr; - result = ExtractOpsAsFunc(builder, module, core_tpu_ops, replication_attr, - metadata_op, compilation_op, loop_body_func, - "core_tpu", &core_tpu_caller); + TF::StatefulPartitionedCallOp core_tpu_caller = nullptr; + result = ExtractOpsAsFunc(builder, module, symbol_table, core_tpu_ops, + replication_attr, metadata_op, compilation_op, + loop_body_func, "core_tpu", &core_tpu_caller, + /*flag_for_inlining=*/true); if (failed(result)) return signalPassFailure(); builder.setInsertionPointAfter(core_tpu_caller); - Operation* backwards_pass_caller = nullptr; - result = ExtractOpsAsFunc( - builder, module, backward_pass_ops, replication_attr, metadata_op, - compilation_op, loop_body_func, 
"sc_backward", &backwards_pass_caller); + TF::StatefulPartitionedCallOp backward_caller = nullptr; + result = ExtractOpsAsFunc(builder, module, symbol_table, backward_pass_ops, + replication_attr, metadata_op, compilation_op, + loop_body_func, "sc_backward", &backward_caller, + /*flag_for_inlining=*/true); if (failed(result)) return signalPassFailure(); - metadata_op->erase(); - compilation_op->erase(); -} + Callers orig_callers; + orig_callers.forward = forward_caller; + orig_callers.backward = backward_caller; + orig_callers.core_tpu = core_tpu_caller; + orig_callers.non_tpu = non_tpu_caller; + // The output of the original while op also serves as subsequent input to + // the same function so input_signature == output_signature. Figure out the + // mapping from the result of each of the four functions into the result + // vector. + auto orig_return_op = *loop_body_func.getOps().begin(); + std::map loop_arg_update_map_non_tpu; + std::map loop_arg_update_map_core_tpu; + for (int ret_pos = 0; ret_pos < orig_return_op->getNumOperands(); ++ret_pos) { + auto operand = orig_return_op->getOperand(ret_pos); + auto def_op = operand.getDefiningOp(); + auto result = operand.dyn_cast(); + if (def_op == non_tpu_caller) { + loop_arg_update_map_non_tpu[result.getResultNumber()] = ret_pos; + } else if (def_op == core_tpu_caller) { + loop_arg_update_map_core_tpu[result.getResultNumber()] = ret_pos; + } else if (def_op == forward_caller) { + loop_body_func->emitOpError( + "Unexpected loop carried variable dependency on sc_forward"); + return signalPassFailure(); + } else if (def_op == backward_caller) { + loop_body_func->emitOpError( + "Unexpected loop carried variable dependency on sc_"); + return signalPassFailure(); + } else if (llvm::isa(operand)) { + // pass + } else { + // This should never happen. 
+ loop_body_func->emitOpError("Couldn't find mapping for return value "); + return signalPassFailure(); + } + } + + const int num_f_res = forward_caller->getNumResults(); + const int num_t_res = core_tpu_caller->getNumResults(); + + // At this point, we have separated the main while body ops into four + // functions: + // 1. SC forward pass ("forward_ops") + // 2. TC forward/backward pass ("core_tput_ops") + // 3. SC backward pass ("backward_ops") + // 4. Loop counter updates ("non_tpu_ops") + // + // Next, extract the original conditional function which we'll use to + // kick off the pre-loop pipelining steps. + // are just the operands passed to the original WhileOp. + func::FuncOp orig_cond_func = orig_while_op.cond_function(); + + std::vector loop_operands_0; + const int num_orig_loop_operands = orig_while_op->getNumOperands(); + loop_operands_0.reserve(num_orig_loop_operands); + Append(loop_operands_0, orig_while_op->getOperands()); + + // Evaluate the real conditional function before the new while loop. + builder.setInsertionPoint(orig_while_op); + Operation* cond_caller_0 = + MakeFuncCaller(builder, orig_while_op->getLoc(), orig_cond_func, + loop_operands_0, /*flag_for_inlining=*/false); + Value C_0 = cond_caller_0->getResults().front(); + + // Call the non_tpu function to update the loop counters. This is still + // part of the i=0 loop iteration. + builder.setInsertionPointAfter(cond_caller_0); + Operation* non_tpu_caller_0 = + LiftNonTpuFuncCaller(builder, non_tpu_caller, loop_operands_0); + // Save the results for later reference. + auto non_tpu_res_0 = ResultsAsVector(non_tpu_caller_0); + + // Start step 0. + // Now make the sc_fw + tc_fb call in the pre-loop. We assume (and assert) + // that we'll execute at least two steps. 
+ builder.setInsertionPointAfter(non_tpu_caller_0); + TF::StatefulPartitionedCallOp start_step_0; + result = StartStep0(builder, loc, symbol_table, metadata_op, compilation_op, + C_0, orig_callers, loop_operands_0, start_step_0); + if (failed(result)) return signalPassFailure(); + + // Save the results of the forward_0 and core_tpu_0 calls by slicing them + // out of the results. + auto forward_res_0 = ResultsAsVector(start_step_0, 0, num_f_res); + auto core_tpu_res_0 = ResultsAsVector(start_step_0, num_f_res, num_t_res); + + // Update the loop operands with results of non_tpu() and core_tpu(). + std::vector loop_operands_1 = loop_operands_0; + for (auto p : loop_arg_update_map_non_tpu) + loop_operands_1[p.second] = non_tpu_res_0[p.first]; + for (auto p : loop_arg_update_map_core_tpu) + loop_operands_1[p.second] = core_tpu_res_0[p.first]; + + // The second conditional evaluation. + builder.setInsertionPointAfter(start_step_0); + Operation* cond_caller_1 = + MakeFuncCaller(builder, orig_while_op->getLoc(), orig_cond_func, + loop_operands_1, /*flag_for_inlining=*/false); + Value C_1 = cond_caller_1->getResults().front(); + + builder.setInsertionPointAfter(cond_caller_1); + Operation* non_tpu_caller_1 = + LiftNonTpuFuncCaller(builder, non_tpu_caller, loop_operands_1); + auto non_tpu_res_1 = ResultsAsVector(non_tpu_caller_1); + + // Start step 1. Again, assume. + builder.setInsertionPointAfter(non_tpu_caller_1); + TF::StatefulPartitionedCallOp start_step_1; + result = StartStep1(builder, loc, symbol_table, metadata_op, compilation_op, + C_1, orig_callers, loop_operands_1, start_step_1); + if (failed(result)) return signalPassFailure(); + + // Save the results of the forward_1 call. + auto forward_res_1 = ResultsAsVector(start_step_1); + + // Update the loop operands with any outputs from the non_tpu and core_tpu + // functions. Note, core_tpu isn't called again until the middle of the loop + // body. So, loop_operands_2 is only partially updated here. 
We'll finish + // updating this after core_tpu() is called in the new while body. + std::vector loop_operands_2 = loop_operands_1; + for (auto p : loop_arg_update_map_non_tpu) + loop_operands_2[p.second] = non_tpu_res_1[p.first]; + + // The second conditional evaluation. The assumption here is that the + // partially updated loop_operands_2 is sufficient for correct evaluation of + // the cond() function. + builder.setInsertionPointAfter(start_step_1); + Operation* cond_caller_2 = + MakeFuncCaller(builder, orig_while_op->getLoc(), orig_cond_func, + loop_operands_2, /*flag_for_inlining=*/false); + Value C_2 = cond_caller_2->getResults().front(); + + // The new while body: + // + // First, we need to construct the body and conditional functions. To do so, + // we need to create the initial operand list that we'll need. This will + // determine the type signature for the body and cond functions. + std::vector tmp_while_operands; + Append(tmp_while_operands, loop_operands_0); + Append(tmp_while_operands, loop_operands_1); + Append(tmp_while_operands, loop_operands_2); + Append(tmp_while_operands, forward_res_0); + Append(tmp_while_operands, forward_res_1); + Append(tmp_while_operands, core_tpu_res_0); + Append(tmp_while_operands, non_tpu_res_1); + Append(tmp_while_operands, {C_0, C_1, C_2}); + + // Dedupe the operands. We'll need a map to help translate. + llvm::SetVector new_while_operands; + llvm::MapVector loop_var_map; + for (auto operand : tmp_while_operands) { + if (new_while_operands.insert(operand)) { + // First time seeing this operand. Let's record the final resting place + // in the new_while_operands vector. + loop_var_map[operand] = new_while_operands.size() - 1; + } + } + // Save index mappings for canonical vectors. 
+  auto BuildUnpackIndexes =
+      [&loop_var_map](std::vector<Value>& prototype_vals) {
+        std::vector<int> indexes;
+        indexes.reserve(prototype_vals.size());
+        for (auto prototype_val : prototype_vals)
+          indexes.push_back(loop_var_map[prototype_val]);
+        return indexes;
+      };
+  auto loop_operands_indexes_im2 = BuildUnpackIndexes(loop_operands_0);
+  auto loop_operands_indexes_im1 = BuildUnpackIndexes(loop_operands_1);
+  auto loop_operands_indexes_i = BuildUnpackIndexes(loop_operands_2);
+  auto forward_res_indexes_im2 = BuildUnpackIndexes(forward_res_0);
+  auto forward_res_indexes_im1 = BuildUnpackIndexes(forward_res_1);
+  auto core_tpu_res_indexes_im2 = BuildUnpackIndexes(core_tpu_res_0);
+  auto non_tpu_res_indexes_im1 = BuildUnpackIndexes(non_tpu_res_1);
+  int C_index_im2 = loop_var_map[C_0];
+  int C_index_im1 = loop_var_map[C_1];
+  int C_index_i = loop_var_map[C_2];
+
+  // Get the operand types.
+  std::vector<Type> new_while_operand_types = GetValueTypes(new_while_operands);
+
+  // Make cond and body functions for the new while op.
+  // Create the function based on input and result types and values.
+  // Note, for a while loop body function, the operand types and result types
+  // are identical.
+ auto body_func_type = mlir::FunctionType::get( + &getContext(), new_while_operand_types, new_while_operand_types); + auto cond_func_type = mlir::FunctionType::get( + &getContext(), new_while_operand_types, orig_cond_func.getResultTypes()); + func::FuncOp cond = + func::FuncOp::create(loc, "new_while_cond", cond_func_type); + func::FuncOp body = + func::FuncOp::create(loc, "new_while_body", body_func_type); + cond.setPrivate(); + body.setPrivate(); + symbol_table.insert(cond); + symbol_table.insert(body); + OpBuilder cond_builder = OpBuilder::atBlockBegin(cond.addEntryBlock()); + OpBuilder body_builder = OpBuilder::atBlockBegin(body.addEntryBlock()); + + //**************************************************************************** + // Build the internals of the new tf.While op's conditional function. + //**************************************************************************** + // Build the cond function body. All we need is a ReturnOp that returns C_i + // which is the last argument. + cond_builder.create(loc, cond.getArgument(C_index_i)); + + //**************************************************************************** + // Build the internals of the new tf.While op's body function. + //**************************************************************************** + auto body_args = body.getArguments(); + // First, let's unpack all the body arguments. + auto UnpackArgs = [&body_args](std::vector& indexes) { + // This helper makes it easy to unpack "natural" vectors of values while + // still respecting the impact of deduping. 
+    std::vector<Value> slice;
+    int num = indexes.size();
+    slice.reserve(num);
+    for (auto i : indexes) slice.push_back(body_args[i]);
+    return slice;
+  };
+  auto loop_operands_im2 = UnpackArgs(loop_operands_indexes_im2);
+  auto loop_operands_im1 = UnpackArgs(loop_operands_indexes_im1);
+  auto loop_operands_i = UnpackArgs(loop_operands_indexes_i);
+  auto forward_res_im2 = UnpackArgs(forward_res_indexes_im2);
+  auto forward_res_im1 = UnpackArgs(forward_res_indexes_im1);
+  auto core_tpu_res_im2 = UnpackArgs(core_tpu_res_indexes_im2);
+  auto non_tpu_res_im1 = UnpackArgs(non_tpu_res_indexes_im1);
+  auto C_im1 = body_args[C_index_im1];
+  auto C_i = body_args[C_index_i];
+
+  // Now, construct the operand list for each op by unpacking values.
+
+  //
+  // Finish step i-2
+  //
+  // First, add all the inputs to sc_backward(). These all come from the block
+  // arguments, sc_forward() and core_tpu() and need to be pulled from the
+  // "i-2" (or "0") version of the inputs.
+  std::vector<Value> b_operands;
+  result = MakeBackwardOperands(forward_caller, core_tpu_caller,
+                                backward_caller, loop_operands_im2,
+                                forward_res_im2, core_tpu_res_im2, b_operands);
+  if (failed(result)) return signalPassFailure();
+  auto backward_caller_im2 = body_builder.clone(*backward_caller);
+  backward_caller_im2->setOperands(b_operands);
+
+  //
+  // Finish step i-1
+  //
+  // Second, add all the inputs to core_tpu(). These all come from the while
+  // loop operands, sc_forward() or non_tpu() and need to be pulled from the
+  // "i-1" (or "1") version of the inputs.
+ std::vector t_operands; + result = MakeCoreTPUOperands(core_tpu_caller, non_tpu_caller, forward_caller, + loop_operands_im1, non_tpu_res_im1, + forward_res_im1, t_operands); + if (failed(result)) return signalPassFailure(); + auto core_tpu_caller_im1 = body_builder.clone(*core_tpu_caller); + core_tpu_caller_im1->setOperands(t_operands); + auto core_tpu_res_im1 = ResultsAsVector(core_tpu_caller_im1); + + // Update the loop operands with results of core_tpu(). + for (auto p : loop_arg_update_map_core_tpu) + loop_operands_i[p.second] = core_tpu_res_im1[p.first]; + + // + // Start step i + // + // Third, add all the inputs to non_tpu(). These all come from the while + // loop operands and need to be pulled from the "i" (or "2") version of the + // inputs. + std::vector n_operands; + result = MakeNonTPUOperands(non_tpu_caller, loop_operands_i, n_operands); + if (failed(result)) return signalPassFailure(); + auto non_tpu_caller_i = body_builder.clone(*non_tpu_caller); + non_tpu_caller_i->setOperands(n_operands); + auto non_tpu_res_i = ResultsAsVector(non_tpu_caller_i); + + // Fourth, add all the inputs to sc_forward(). These all come from the + // while loop operands or the non_tpu() call that's in the loop body. The + // loop operands need to be pulled from the "i" (or "2") version of the + // inputs. The inputs coming from non_tpu() are from the same loop iteration + // (non_tpu_res_i). + std::vector f_operands; + result = MakeForwardOperands(forward_caller, non_tpu_caller, loop_operands_i, + non_tpu_res_i, f_operands); + if (failed(result)) return signalPassFailure(); + auto forward_caller_i = body_builder.clone(*forward_caller); + forward_caller_i->setOperands(f_operands); + auto forward_res_i = ResultsAsVector(forward_caller_i); + + // Update the loop operands with results of non_tpu(). Results for + // core_tpu() are lagged. 
+ std::vector loop_operands_ip1 = loop_operands_i; + for (auto p : loop_arg_update_map_non_tpu) + loop_operands_ip1[p.second] = non_tpu_res_i[p.first]; + + // Add the conditional evaluation for the next loop iteration. + Operation* cond_caller_ip1 = + MakeFuncCaller(body_builder, orig_while_op->getLoc(), orig_cond_func, + loop_operands_ip1, /*flag_for_inlining=*/false); + Value C_ip1 = cond_caller_ip1->getResults().front(); + + // Build the ReturnOp. This mirrors the construction of the operands with + // 'i' values incremented. + std::vector tmp_body_results; + Append(tmp_body_results, loop_operands_im1); + Append(tmp_body_results, loop_operands_i); + Append(tmp_body_results, loop_operands_ip1); + Append(tmp_body_results, forward_res_im1); + Append(tmp_body_results, forward_res_i); + Append(tmp_body_results, core_tpu_res_im1); + Append(tmp_body_results, non_tpu_res_i); + Append(tmp_body_results, {C_im1, C_i, C_ip1}); + + llvm::SetVector new_body_results; + // This should pack the same as deduping code above. + new_body_results.insert(tmp_body_results.begin(), tmp_body_results.end()); + auto new_body_return_types = GetValueTypes(new_body_results); + + body_builder.setInsertionPointAfter(cond_caller_ip1); + body_builder.create(orig_while_op->getLoc(), + new_body_results.getArrayRef()); + + // Finally, create the new tf.WhileOp. + builder.setInsertionPoint(orig_while_op); + auto new_while_op = builder.create( + orig_while_op->getLoc(), new_body_return_types, + new_while_operands.getArrayRef(), cond.getSymName(), body.getSymName(), + /*parallel_iterations=*/10, + /*is_stateless=*/false, + /*shape_invariant=*/false); + SetBasicBlockAttributes(builder, new_while_op); + + // First, let's unpack all the body arguments. 
+ auto UnpackResults = [&new_while_op](std::vector& indexes) { + int num = indexes.size(); + std::vector slice; + slice.reserve(num); + for (auto i : indexes) slice.push_back(new_while_op->getResult(i)); + return slice; + }; + auto loop_operands_nm2 = UnpackResults(loop_operands_indexes_im2); + auto loop_operands_nm1 = UnpackResults(loop_operands_indexes_im1); + auto loop_operands_n = UnpackResults(loop_operands_indexes_i); + auto forward_res_nm2 = UnpackResults(forward_res_indexes_im2); + auto forward_res_nm1 = UnpackResults(forward_res_indexes_im1); + auto core_tpu_res_nm2 = UnpackResults(core_tpu_res_indexes_im2); + auto non_tpu_res_nm1 = UnpackResults(non_tpu_res_indexes_im1); + auto C_nm2 = new_while_op->getResult(C_index_im2); + auto C_nm1 = new_while_op->getResult(C_index_im1); + + // Finish step n-2. + builder.setInsertionPointAfter(new_while_op); + TF::StatefulPartitionedCallOp finish_step_nm2; + result = FinishStepNm2(builder, loc, symbol_table, metadata_op, + compilation_op, C_nm2, orig_callers, loop_operands_nm2, + forward_res_nm2, core_tpu_res_nm2, finish_step_nm2); + if (failed(result)) return signalPassFailure(); + + // Finish step n-1. + builder.setInsertionPointAfter(finish_step_nm2); + TF::StatefulPartitionedCallOp finish_step_nm1; + result = FinishStepNm1(builder, loc, symbol_table, metadata_op, + compilation_op, C_nm1, orig_callers, loop_operands_nm1, + forward_res_nm1, finish_step_nm1); + if (failed(result)) return signalPassFailure(); + + // Save the results of the core_tpu_0 call and use it to finalize the + // loop_operands_n array. + auto core_tpu_res_nm1 = ResultsAsVector(finish_step_nm1, 0, num_t_res); + for (auto p : loop_arg_update_map_core_tpu) + loop_operands_n[p.second] = core_tpu_res_nm1[p.first]; + + // Replace the return values from the original WhileOp with the output of + // the pipelining. 
+ for (auto p : llvm::zip(orig_while_op->getResults(), loop_operands_n)) + replaceAllUsesInRegionWith(std::get<0>(p), std::get<1>(p), + *orig_while_op->getParentRegion()); + + // Inline the new while body. + result = Inliner(builder, symbol_table).InlineCallsInFunc(body, false); + if (failed(result)) return signalPassFailure(); + + // Erase original while op and temporary functions. Note, we use the non_tpu + // function in the output graph. + symbol_table.lookup(orig_callers.forward.getF())->erase(); + symbol_table.lookup(orig_callers.core_tpu.getF())->erase(); + symbol_table.lookup(orig_callers.backward.getF())->erase(); + orig_while_op.body_function().erase(); + orig_while_op.erase(); + + VLOG(3) << "EmbeddingPipeliningPass::runOnOperation done."; +} } // namespace std::unique_ptr> CreateEmbeddingPipeliningPass() { diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/embedding_sequencing.cc b/tensorflow/compiler/mlir/tensorflow/transforms/embedding_sequencing.cc new file mode 100644 index 00000000000..a83f6ac54a8 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/transforms/embedding_sequencing.cc @@ -0,0 +1,924 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +// This pass separates SparseCore, TensorCore, and non-TPU operations into +// separate functions for proper sequencing of TF2 TPU Embedding (see +// tpu_embedding_v3.py). This pass is a precursor for pipelining (see +// embedding_pipelining.cc) and DOES NOT permit parallel execution across SC and +// TC. This pass is a temporary fallback to use while developing full pipelining +// capabilities. +// +// Ops are broken up into: +// 1. SC forward pass +// 2. TC forward/backward pass +// 3. SC backward pass +// 4. non-TPU loop counter updates + +#include +#include +#include +#include +#include + +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/Support/Casting.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/Region.h" // from @llvm-project +#include "mlir/IR/SymbolTable.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/IR/Visitors.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Transforms/InliningUtils.h" // from @llvm-project +#include "mlir/Transforms/RegionUtils.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h" + +#define GEN_PASS_DEF_EMBEDDINGSEQUENCINGPASS +#include "tensorflow/compiler/mlir/tensorflow/transforms/tf_passes.h.inc" + 
+static constexpr char kEmbeddingPipelining[] = "_embedding_pipelining";
+static constexpr char kEmbeddingForward[] = "forward";
+static constexpr char kEmbeddingBackward[] = "backward";
+static constexpr char kDevice[] = "device";
+static constexpr llvm::StringRef kTpuCompilationStatus =
+    "_tpu_compilation_status";
+
+namespace mlir {
+namespace TFDevice {
+namespace {
+
+struct EmbeddingSequencingPass
+    : public ::impl::EmbeddingSequencingPassBase<EmbeddingSequencingPass> {
+  void getDependentDialects(mlir::DialectRegistry& registry) const override {
+    registry.insert();
+  }
+
+  void runOnOperation() override;
+};
+
+template <typename InputContainer>
+std::vector<Type> GetValueTypes(const InputContainer& input) {
+  // Convert a list of mlir::Value's into a list of mlir::Type's
+  std::vector<Type> types;
+  types.reserve(input.size());
+  for (auto val : input) types.push_back(val.getType());
+  return types;
+}
+
+bool IsResourceType(Type val_type) {
+  if (auto tensor_type = val_type.dyn_cast<TensorType>()) {
+    if (tensor_type.getElementType().isa<TF::ResourceType>()) {
+      return true;
+    }
+  }
+  return false;
+}
+
+bool IsTPUOp(mlir::Operation* op) {
+  return op->hasAttr(TF::kReplicationInfoAttr);
+}
+
+StringAttr GetReplicationAttr(mlir::Operation* op) {
+  return op->getAttrOfType<StringAttr>(TF::kReplicationInfoAttr);
+}
+
+StringAttr GetReplicationAttr(TF::TPUCompilationResultOp op) {
+  // Special case for getting the replication region for
+  // TPUCompilationResultsOp.
+  return op->getAttrOfType<StringAttr>(kTpuCompilationStatus);
+}
+
+int64_t GetNumOps(func::FuncOp func) {
+  int64_t num_ops = 0;
+  for (auto it = func.begin(); it != func.end(); ++it) ++num_ops;
+  return num_ops;
+}
+
+void GatherOpsForExtraction(mlir::SetVector<Operation*>* operations,
+                            const mlir::SetVector<Operation*>& ops_to_avoid,
+                            bool predecessors, bool successors) {
+  // Walk the input and output dependencies of the Ops in `operations` to form
+  // the closure of Ops needed to evaluate 'operations'. Input dependencies are
+  // walked if 'predecessors' is true and output dependencies are walked if
+  // 'successors' is true.
In either case, if a discoverd Op is in the + // 'ops_to_avoid' set, then the dependency walking is terminated. + llvm::SetVector ops_to_process(*operations); + llvm::SetVector new_ops; + + while (!ops_to_process.empty()) { + for (Operation* op : ops_to_process) { + if (predecessors) { + for (Value operand : op->getOperands()) { + // Stop at the block boundary. + if (operand.isa()) continue; + + Operation* predecessor = operand.getDefiningOp(); + if (!operations->contains(predecessor) && + !ops_to_avoid.contains(predecessor)) { + new_ops.insert(operand.getDefiningOp()); + operations->insert(operand.getDefiningOp()); + } + } + } + if (successors) { + for (mlir::Operation* successor : op->getUsers()) { + // Don't include the return op + if (llvm::isa(successor)) continue; + + if (!operations->contains(successor) && + !ops_to_avoid.contains(successor)) { + new_ops.insert(successor); + operations->insert(successor); + } + } + } + } + ops_to_process.swap(new_ops); + new_ops.clear(); + } +} + +TF::StatefulPartitionedCallOp MakeFuncCaller( + mlir::OpBuilder& builder, const Location& loc, func::FuncOp func, + const llvm::SetVector& operands) { + // Constructs a tf.StatefulPartitionedCall to the function provided in 'func' + // using the operands in 'operands'. Assumes the insertion point on builder is + // already set. + auto symbol = + mlir::SymbolRefAttr::get(builder.getContext(), func.getSymName()); + auto result_types = func.getResultTypes(); + auto caller = builder.create( + loc, result_types, operands.getArrayRef(), symbol, + /*config=*/builder.getStringAttr(""), + /*config_proto=*/builder.getStringAttr(""), + /*executor_type=*/builder.getStringAttr("")); + caller.setFAttr(symbol); + return caller; +} + +func::FuncOp CreateFnWithSignature(ModuleOp module, + const llvm::SetVector& inputs, + const llvm::SetVector& outputs, + const std::string& name) { + // Creates an empty func.FuncOp with a signature compatible with 'inputs' + // (operands) and 'outputs' (results). 
+ OpBuilder builder(module); + + std::vector input_types = GetValueTypes(inputs); + std::vector output_types = GetValueTypes(outputs); + builder.setInsertionPointToEnd(&module.getBodyRegion().back()); + func::FuncOp func_op = builder.create( + module.getLoc(), name, + builder.getFunctionType(input_types, output_types)); + func_op.setPrivate(); + + return func_op; +} + +TF::StatefulPartitionedCallOp EncapsulateOpsInFunc( + OpBuilder& builder, const llvm::SetVector& ops, + const llvm::SetVector& inputs, const llvm::SetVector& outputs, + func::FuncOp parent_func, ModuleOp module, const std::string& name) { + // Moves all of the Operations in 'ops' into a newly created func.FuncOp + // function named 'name' and replaces the original ops with a call to the + // newly created function using a tf.StatefulPartitionedCall. Here, + // 'parent_func' is the function that holds the original set of ops. + // Note, 'inputs' and 'outputs' are the predetermined set of values that + // should become the operands and return values, respectively. + auto insertion_point = builder.saveInsertionPoint(); + func::FuncOp new_func = CreateFnWithSignature(module, inputs, outputs, + absl::StrCat("_func_", name)); + + // This preserves the order of the ops that was in the original parent + // funtion. This is critical for preserving correctness in the presence of + // resource variables and stateful functions. + std::vector topological_order; + for (Operation& op : parent_func.getOps()) + if (ops.contains(&op)) topological_order.push_back(&op); + + // Create the partitioned call + builder.restoreInsertionPoint(insertion_point); + auto caller = MakeFuncCaller(builder, module.getLoc(), new_func, inputs); + + Block* block = new_func.addEntryBlock(); + + for (Operation* op : topological_order) op->moveBefore(block, block->end()); + + // Replace the 'inputs' values with the new function's arguments. 
+ for (auto p : llvm::zip(inputs, new_func.getArguments())) + replaceAllUsesInRegionWith(std::get<0>(p), std::get<1>(p), + new_func.getBody()); + + builder.setInsertionPointToEnd(block); + builder.create(parent_func.getLoc(), outputs.getArrayRef()); + + // Replace the original 'outputs' values with the result of the call to the + // new function. + for (auto p : llvm::zip(outputs, caller->getResults())) + replaceAllUsesInRegionWith(std::get<0>(p), std::get<1>(p), + parent_func.getBody()); + + return caller; +} + +void UpdateAndInsertTPUOps(TF::StatefulPartitionedCallOp caller, + TF::TPUReplicateMetadataOp metadata_op, + TF::TPUCompilationResultOp compilation_op, + StringAttr old_group) { + // Adds the TPUReplicateMetatdataOp and TPUCompilationResultOp ops to the + // function called by the provided 'caller'. + mlir::CallInterfaceCallable callable = caller.getCallableForCallee(); + mlir::SymbolRefAttr sym = callable.dyn_cast(); + auto func = llvm::dyn_cast( + mlir::SymbolTable::lookupNearestSymbolFrom(caller, sym)); + OpBuilder builder(func.getBody()); + + StringAttr new_group = builder.getStringAttr( + absl::StrCat(old_group.getValue().str(), caller.getF().str())); + + builder.insert(metadata_op.clone()); + for (Operation& op : func.getOps()) { + if (!IsTPUOp(&op)) continue; + op.setAttr(TF::kReplicationInfoAttr, new_group); + } + TF::TPUCompilationResultOp new_result = compilation_op.clone(); + new_result->setAttr(kTpuCompilationStatus, new_group); + builder.insert(new_result); +} + +template +LogicalResult FindAndExcludeOp(func::FuncOp func, + const StringAttr& replication_attr, + llvm::SetVector& merged_set, + OpType& found_op) { + // Find the TPUReplicationMetadata or TPUCompilationResult ops which will be + // cloned/inserted into each region. We add them to the merged_set so that + // they're ignored when extracting the four main functions. 
+ found_op = nullptr; + for (OpType op : func.getOps()) { + if (found_op != nullptr) { + func.emitOpError() << "number of " << found_op.getOperationName() + << " in loop body is not 1"; + return LogicalResult::failure(); + } + if (GetReplicationAttr(op) != replication_attr) { + op.emitOpError() << "is not part of the replication region " + << replication_attr << " vs " << GetReplicationAttr(op); + return LogicalResult::failure(); + } + found_op = op; + merged_set.insert(found_op); + } + return LogicalResult::success(); +} + +LogicalResult FindOwningWhileOp(func::FuncOp body_func, ModuleOp module, + TF::WhileOp* while_op) { + // Given a while loop body function 'body_func', find the tf.While Op that + // uses it. + auto uses_optional = body_func.getSymbolUses(module); + if (!uses_optional.has_value()) { + body_func.emitOpError() << "no use of while loop body"; + return LogicalResult::failure(); + } + *while_op = nullptr; + for (auto& use : uses_optional.value()) { + if (llvm::isa(use.getUser())) { + if (*while_op != nullptr) { + use.getUser()->emitOpError() << "multiple users of function."; + return LogicalResult::failure(); + } else { + *while_op = llvm::cast(use.getUser()); + } + } else { + use.getUser()->emitOpError() << "non while use of function."; + return LogicalResult::failure(); + } + } + // TODO(bfontain): If the while op is not present we could just split things + // or we wait until the compiler supports multiple regions? + if (while_op == nullptr) { + body_func.emitOpError() << "unable to find while body user."; + return LogicalResult::failure(); + } + return LogicalResult::success(); +} + +LogicalResult FindForwardPassOps(OpBuilder& builder, + llvm::SetVector& forward_pass_ops, + llvm::SetVector& backward_pass_ops, + llvm::SetVector& merged_set, + func::FuncOp loop_body_func, + const int num_replicas) { + // Find all the ops that are to be included in the 'sc_forward' function which + // will be executed on the SparseCore. 
Note, 'forward_pass_ops' is initially + // seeded with ops from the input MLIR graph that have the + // _embedding_pipelining="forward" attribute which is set by the TF2 Embedding + // API. + // + // When outputs of the forward pass function are used outside of it, we'll + // need to insert a TPUReplicatedOutput Op and include that in the + // forward_pass_ops. And if that usage is also on the TPU (either TensorCore + // or SparseCore) we'll need to insert a matching TPUReplicatedInput. We do + // this before the Ops are removed from the original function/graph so that + // function operands and return values are handled automatically. + + // First, walk the op dependencies. + GatherOpsForExtraction(&forward_pass_ops, merged_set, /*predecessors=*/true, + /*successors=*/false); + + // Locate which variable inputs are part of the forwards pass. These will + // also be used in the backwards pass. We need to create a 'private' copy + // of the TpuReplicatedInput for for the fowards pass if there are users + // outside the pass. Note that in the case of the backwards pass existing + // this will be the case. + // This means that when we have put all out sections together some resource + // inputs will have multiple TPUReplicateInput nodes, so we will need a final + // pass to merge these together into the earliest copy. + llvm::SetVector forward_variable_inputs; + + // Validate that the only resource inputs that are read by ops in + // forward_pass_ops are dataset and variable ops. + int64_t resource_count = 0; + for (auto argument : loop_body_func.getArguments()) { + // Check that all resource arguments are either fed to iterator get next + // or a TPUReplicatedInput with is_packed. 
+ + if (IsResourceType(argument.getType())) { + resource_count++; + bool is_variable = false; + bool is_non_variable = false; + bool use_in_forward = false; + bool use_in_not_forward = false; + for (auto user : argument.getUsers()) { + if (llvm::isa(user)) continue; + if (!forward_pass_ops.contains(user)) { + use_in_not_forward = true; + } else { + use_in_forward = true; + } + if (TF::TPUReplicatedInputOp input = + llvm::dyn_cast(user)) { + if (!input.getIsPacked()) { + input.emitOpError() << "unexpected variable input, not packed"; + return LogicalResult::failure(); + } + + if (is_variable) { + input.emitOpError() << "unexpected multiple TPUReplicatedInputOp " + << "for single argument"; + return LogicalResult::failure(); + } + is_variable = true; + } else { + is_non_variable = true; + } + } + if (use_in_forward && use_in_not_forward) { + loop_body_func.emitOpError() + << "resource input " << argument.getArgNumber() + << " is used both in the forwards and " + << "not forward passes dataset"; + return LogicalResult::failure(); + } + if (is_non_variable && is_variable) { + loop_body_func.emitOpError() + << "resource input " << argument.getArgNumber() + << " is used both as a varible and not " + << " a variable"; + return LogicalResult::failure(); + } + if (is_variable && use_in_forward) + forward_variable_inputs.insert(argument.getArgNumber()); + } + } + + VLOG(3) << "Found " << forward_variable_inputs.size() + << " variables used in forward pass of " << resource_count + << " total resource inputs"; + + // Clone the TPUReplicatedInputs. + int64_t cloned_inputs = 0; + for (int64_t index : forward_variable_inputs) { + Value argument = loop_body_func.getArgument(index); + // Uses of this argument should only be the return and the + // TPUReplicateInputOp. This is checked by the loop above. 
+ Operation* input_ptr = nullptr; + for (Operation* user : argument.getUsers()) { + if (llvm::isa(user)) { + input_ptr = user; + break; + } + } + TF::TPUReplicatedInputOp input = + llvm::cast(input_ptr); + + // Validate that all users of the TPUReplicatedInput are ReadVariable + // or AssignVariable ops and check if any are outside the forwards pass. + bool duplicate_needed = false; + for (Operation* next_user : input.getOutput().getUsers()) { + if (!llvm::isa(next_user) && + !llvm::isa(next_user)) { + next_user->emitOpError() + << "unexpected user of output of TPUReplicatedInputOp"; + return LogicalResult::failure(); + } + if (!forward_pass_ops.contains(next_user)) duplicate_needed = true; + } + if (!duplicate_needed) continue; + + cloned_inputs++; + builder.setInsertionPointAfter(input); + forward_pass_ops.remove(input); + + TF::TPUReplicatedInputOp private_input = input.clone(); + builder.insert(private_input); + forward_pass_ops.insert(private_input); + for (OpOperand& next_use : input.getOutput().getUses()) { + if (!forward_pass_ops.contains(next_use.getOwner())) continue; + next_use.getOwner()->setOperand(next_use.getOperandNumber(), + private_input.getOutput()); + } + } + + VLOG(2) << "Cloned " << cloned_inputs << " TPUReplicatedInputOps"; + + // Add TPUReplicatedInput/TPUReplicatedOutput pairs along each edge. + llvm::SetVector new_forward_ops; + for (Operation* op : forward_pass_ops) { + // TODO(bfontain): Should validate that all the TPU ops are in the same + // replication region. + if (!IsTPUOp(op)) continue; + for (Value result : op->getResults()) { + std::vector> out_of_region_use; + for (OpOperand& use : result.getUses()) { + auto use_owner = use.getOwner(); + // TODO(bfontain): Error check here, if the use.getOwner() is not a TPU + // then this op must be a TPUReplicatedOutputOp. 
+ if (IsTPUOp(use_owner) && !forward_pass_ops.contains(use_owner)) + out_of_region_use.push_back( + std::make_pair(use_owner, use.getOperandNumber())); + } + if (out_of_region_use.empty()) continue; + builder.setInsertionPointAfter(op); + std::vector types(num_replicas, result.getType()); + TF::TPUReplicatedOutputOp replicated_output = + builder.create(op->getLoc(), + TypeRange(types), result); + new_forward_ops.insert(replicated_output); + // TODO(bfontain): Check for other attributes. + replicated_output->setAttr(kDevice, builder.getStringAttr("")); + TF::TPUReplicatedInputOp input = builder.create( + op->getLoc(), result.getType(), replicated_output.getResults()); + input->setAttr(kDevice, builder.getStringAttr("")); + mlir::Value new_value = input.getOutput(); + + if (mlir::isa( + result.getDefiningOp())) { + TF::TPUAnnotateTensorsWithDynamicShapeOp annotate_op = + builder.create( + op->getLoc(), result.getType(), new_value, + result.getDefiningOp()->getAttrs()); + for (auto [operation, index] : out_of_region_use) { + if (!backward_pass_ops.contains(operation)) { + operation->emitOpError() + << "expect all dynamic inputs consumed by backwards pass."; + return LogicalResult::failure(); + } + } + + backward_pass_ops.insert(annotate_op); + new_value = annotate_op->getResult(0); + } + for (auto [operation, index] : out_of_region_use) + operation->setOperand(index, new_value); + } + } + + VLOG(2) << "inserted " << new_forward_ops.size() << " TPU Input/Output ops"; + forward_pass_ops.insert(new_forward_ops.begin(), new_forward_ops.end()); + return LogicalResult::success(); +} + +LogicalResult FindBackwardPassOps( + OpBuilder& builder, llvm::SetVector& backward_pass_ops, + llvm::SetVector& merged_set, const int num_replicas) { + // Find all the ops that are to be included in the 'sc_backward' function + // which will be executed on the SparseCore. 
Note, 'backward_pass_ops' is + // initially seeded with ops from the input MLIR graph that have the + // _embedding_pipelining="backward" attribute which is set by the TF2 + // Embedding API. + // + // Since we're inserting a replication boundary around the backward pass + // function, we'll also need to make sure TPUReplicatedInputOp and + // TPUReplicatedOutputOp ops are inserted as necessary. + + // First, walk the Ops dependencies. + GatherOpsForExtraction(&backward_pass_ops, merged_set, /*predecessors=*/false, + /*successors=*/true); + + VLOG(3) << "found " << backward_pass_ops.size() << " backwards pass ops"; + + // If any inputs are to the backward_pass_ops region are direct + // TPUReplicatedInput ops, then include (if this is the only use) or + // clone the op. This will be the case for all Read/Assign variable ops. + + llvm::SetVector to_clone; + llvm::SetVector to_insert; + + for (Operation* op : backward_pass_ops) { + for (OpOperand& input_value : op->getOpOperands()) { + Operation* predecessor_op = input_value.get().getDefiningOp(); + if (TF::TPUReplicatedInputOp input = + llvm::dyn_cast(predecessor_op)) { + if (to_clone.contains(input) || to_insert.contains(input)) continue; + // Check if all uses in backwards pass. 
+ bool all_in_backwards = true; + for (Operation* user : input->getUsers()) + if (!backward_pass_ops.contains(user)) all_in_backwards = false; + if (all_in_backwards) + to_insert.insert(input); + else + to_clone.insert(input); + } + } + } + backward_pass_ops.insert(to_insert.begin(), to_insert.end()); + for (TF::TPUReplicatedInputOp input : to_clone) { + builder.setInsertionPointAfter(input); + TF::TPUReplicatedInputOp private_input = input.clone(); + builder.insert(private_input); + backward_pass_ops.insert(private_input); + for (OpOperand& next_use : input.getOutput().getUses()) { + if (!backward_pass_ops.contains(next_use.getOwner())) continue; + next_use.getOwner()->setOperand(next_use.getOperandNumber(), + private_input.getOutput()); + } + } + + VLOG(2) << " cloned " << to_clone.size() << " and inserted " + << to_insert.size() << " TPUReplicatedInput ops"; + + // For all other inputs that go from TPU op to TPU op, insert the + // TPUOutput/Input pair. + + // Add TPUReplicatedInput/TPUReplicatedOutput pairs along each edge. + // TODO(bfontain): Should be merged with the above loop. + llvm::SetVector values_to_add_nodes; + + for (Operation* op : backward_pass_ops) { + // TODO(bfontain): Should validate that all the TPU ops are in the same + // replication region. + // If the op is already a replicated input, no need to to anything. + if (!IsTPUOp(op) || llvm::isa(op)) continue; + for (OpOperand& input_value : op->getOpOperands()) + // TODO(bfontain): Error check here, this line should never be false, + // since we skip the TF::TPUReplicatedInputOp case. 
+ if (IsTPUOp(input_value.get().getDefiningOp()) && + !backward_pass_ops.contains(input_value.get().getDefiningOp())) + values_to_add_nodes.insert(input_value.get()); + } + + for (Value value : values_to_add_nodes) { + builder.setInsertionPointAfter(value.getDefiningOp()); + std::vector types(num_replicas, value.getType()); + Location loc = value.getDefiningOp()->getLoc(); + TF::TPUReplicatedOutputOp output = + builder.create(loc, TypeRange(types), value); + // TODO(bfontain): Check for other attributes. + output->setAttr(kDevice, builder.getStringAttr("")); + TF::TPUReplicatedInputOp input = builder.create( + loc, value.getType(), output.getResults()); + input->setAttr(kDevice, builder.getStringAttr("")); + for (OpOperand& use : value.getUses()) + if (backward_pass_ops.contains(use.getOwner())) + use.getOwner()->setOperand(use.getOperandNumber(), input.getOutput()); + backward_pass_ops.insert(input); + } + + VLOG(2) << " inserted " << values_to_add_nodes.size() + << " TPUReplicatedInput/Output pairs"; + return LogicalResult::success(); +} + +LogicalResult FindCoreTPUOps( + llvm::SetVector& core_tpu_ops, + const llvm::SetVector& forward_pass_ops, + const llvm::SetVector& backward_pass_ops, + const llvm::SetVector& merged_set, + func::FuncOp loop_body_func) { + // Find all of the Ops that are part of the forward/backward pass but aren't + // targeting the SparseCore. Note that we need to include some non-TPU ops + // that flow out of the forward pass function. Otherwise, they would get + // absorbed into the non_tpu function which breaks the pipelining + // decomposition strategy. + // + // Find all the outputs of the forward pass that aren't fed into the backward + // pass. 
+ for (Operation* op : forward_pass_ops) { + for (Value res : op->getResults()) { + for (auto user : res.getUsers()) { + if (!forward_pass_ops.contains(user) && + !backward_pass_ops.contains(user)) { + core_tpu_ops.insert(user); + } + } + } + } + + // Gather all TPU ops marked for compilation in this while loop body that also + // are not in one of the two other sets. + for (Operation& op : loop_body_func.getOps()) { + // Find all TPU ops that don't belong to the forward or backward pass. + if (merged_set.contains(&op) || llvm::isa(op) || + !IsTPUOp(&op) || op.hasAttr(kEmbeddingPipelining)) + continue; + // TODO(bfontain): only collect those ops in a fixed TPUReplica. + core_tpu_ops.insert(&op); + } + + GatherOpsForExtraction(&core_tpu_ops, merged_set, /*predecessors=*/true, + /*successors=*/true); + + // TODO(patn): Verify that all the ops here fall between the forward pass + // and backward pass ops (i.e., not before the forward pass or after the + // backward pass). + return LogicalResult::success(); +} + +LogicalResult FindNonTPUOps(llvm::SetVector& non_tpu_ops, + const llvm::SetVector& merged_set, + func::FuncOp loop_body_func) { + // Find all of the left over Ops after the sc_forward, sc_backward and + // core_tpu ops have been identified. What's left are just the ops necessary + // for updating loop counters etc. + llvm::SetVector non_tpu_args; + for (Operation& op : loop_body_func.getOps()) { + if (merged_set.contains(&op) || llvm::isa(op) || + op.hasAttr(kEmbeddingPipelining)) + continue; + // Note, there should be no TPU ops left at this point. If this trips, + // there's likely a bug in this pass. + if (IsTPUOp(&op)) { + loop_body_func.emitOpError() + << "Unexpcted TPU op found while identifying non-TPU ops."; + return LogicalResult::failure(); + } + non_tpu_ops.insert(&op); + } + + // Validate that remainder_ops takes and returns a subset of the loop carried + // args. This will basically be our set increment fn. 
+ for (Operation* op : non_tpu_ops) + for (Value input : op->getOperands()) + if (BlockArgument arg = llvm::dyn_cast(input)) + // TODO(bfontain): Check that this is actually an argument to the loop + // body. + non_tpu_args.insert(arg.getArgNumber()); + + // All funcs have a return op so this should be safe. + func::ReturnOp return_op = *loop_body_func.getOps().begin(); + + for (OpOperand& operand : return_op->getOpOperands()) { + if (non_tpu_args.contains(operand.getOperandNumber())) { + if (BlockArgument argument = + llvm::dyn_cast(operand.get())) { + if (argument.getArgNumber() != operand.getOperandNumber()) { + return_op.emitOpError() + << "non TPU ops do not divide state into two pieces."; + return LogicalResult::failure(); + } + } else if (!non_tpu_ops.contains(operand.get().getDefiningOp())) { + return_op.emitOpError() + << "non TPU ops do not divide state into two pieces."; + return LogicalResult::failure(); + } + } + } + return LogicalResult::success(); +} + +LogicalResult ExtractOpsAsFunc( + OpBuilder& builder, ModuleOp module, llvm::SetVector& ops, + StringAttr replication_attr, TF::TPUReplicateMetadataOp metadata_op, + TF::TPUCompilationResultOp compilation_op, func::FuncOp parent_func, + const std::string& func_name, Operation** caller) { + // Move the given set of 'ops' into it's own function and replace them with a + // call to that function ('caller'). if 'metadata_op' and 'compilation_op' are + // non-null, also insert those (i.e., target the resulting function to the + // TPU). Here, 'parent_func' is the func.FuncOp that owns the ops in 'ops'. + // + // Returns in 'caller' a tf.StatefulPartitionedCallOp that calls the function + // that was extracted.. + + // Find the input edges to form the set of operands to the new function call. 
+ llvm::SetVector inputs; + for (Operation* op : ops) { + for (Value operand : op->getOperands()) { + Operation* defining_op = operand.getDefiningOp(); + if (!ops.contains(defining_op)) inputs.insert(operand); + } + } + // Find the output edges to form the set of resutls of the new function call. + llvm::SetVector results; + for (Operation* op : ops) { + for (auto result : op->getResults()) { + for (const OpOperand& operand : result.getUsers()) { + if (!ops.contains(operand.getOwner())) { + results.insert(result); + break; + } + } + } + } + llvm::SetVector outputs; + for (auto output : results) outputs.insert(output); + auto tf_caller = EncapsulateOpsInFunc(builder, ops, inputs, outputs, + parent_func, module, func_name); + if (!ops.empty() && metadata_op != nullptr && compilation_op != nullptr) + UpdateAndInsertTPUOps(tf_caller, metadata_op, compilation_op, + replication_attr); + *caller = tf_caller; + return LogicalResult::success(); +} + +void EmbeddingSequencingPass::runOnOperation() { + ModuleOp module = getOperation(); + + llvm::SetVector forward_pass_ops; + llvm::SetVector backward_pass_ops; + + // Find all ops that we know compose the embedding forward and backward pass. + // These ops are only tagged if one enables the + // `pipeline_execution_with_tensor_core` flag in the mid-level API. 
+ WalkResult walk_result = module.walk([&](Operation* op) -> WalkResult { + if (op->hasAttr(kEmbeddingPipelining)) { + const std::string region = + op->getAttrOfType(kEmbeddingPipelining).getValue().str(); + if (region == kEmbeddingForward) { + forward_pass_ops.insert(op); + } else if (region == kEmbeddingBackward) { + backward_pass_ops.insert(op); + } else { + return op->emitOpError() + << "embedding op has unknown " << kEmbeddingPipelining + << " attribute value " << region << "."; + } + op->removeAttr(kEmbeddingPipelining); + } + return WalkResult::advance(); + }); + if (walk_result.wasInterrupted()) return signalPassFailure(); + + // If there are no forward pass ops, there is no SC, so we end early. + if (forward_pass_ops.empty()) { + if (backward_pass_ops.empty()) { + return; + } else { + (*backward_pass_ops.begin())->emitOpError() + << "embedding backwards pass op with no forwards pass ops."; + return signalPassFailure(); + } + } + + // Ensure that all ops are in the same region, and have the same replication + // info. + // TODO(bfontain): Allow for multiple regions/loops in one module. + // TODO(patn): move this pass after cluster formation to remove the complexity + // with replication info and metadata, cluster checking and generalizing to + // multiple TPU clusters. 
+ Region* region = (*forward_pass_ops.begin())->getParentRegion(); + StringAttr replication_attr = GetReplicationAttr(*forward_pass_ops.begin()); + llvm::SmallVector checkset(forward_pass_ops.getArrayRef()); + checkset.append(backward_pass_ops.begin(), backward_pass_ops.end()); + for (Operation* op : checkset) { + if (op->getParentRegion() != region) { + op->emitOpError() << "embedding ops in two different regions"; + return signalPassFailure(); + } + if (GetReplicationAttr(op) != replication_attr) { + op->emitOpError() << "embedding ops with different replication info " + << replication_attr << " vs " << GetReplicationAttr(op); + return signalPassFailure(); + } + } + + // TODO(bfontain): Check that the region here is the region + // of the loop body func. + // Find the FuncOp for the surrounding while loop body. + func::FuncOp loop_body_func = + (*forward_pass_ops.begin())->getParentOfType(); + + // merged_set will keep track of which ops are to be avoided when gather ops + // for inclusion into the four extracted functions. + llvm::SetVector merged_set; + + // Find the TPUReplicationMetadata and TPUCompilationResult ops and delete + // them. These will be cloned/inserted into each region. 
+ TF::TPUReplicateMetadataOp metadata_op; + auto result = FindAndExcludeOp(loop_body_func, replication_attr, merged_set, + metadata_op); + if (failed(result)) return signalPassFailure(); + const int num_replicas = metadata_op.getNumReplicas(); + + TF::TPUCompilationResultOp compilation_op; + result = FindAndExcludeOp( + loop_body_func, replication_attr, merged_set, compilation_op); + if (failed(result)) return signalPassFailure(); + + TF::WhileOp while_op = nullptr; + result = FindOwningWhileOp(loop_body_func, module, &while_op); + if (failed(result)) return signalPassFailure(); + + OpBuilder builder(module); + + result = FindForwardPassOps(builder, forward_pass_ops, backward_pass_ops, + merged_set, loop_body_func, num_replicas); + if (failed(result)) return signalPassFailure(); + merged_set.insert(forward_pass_ops.begin(), forward_pass_ops.end()); + + result = + FindBackwardPassOps(builder, backward_pass_ops, merged_set, num_replicas); + if (failed(result)) return signalPassFailure(); + merged_set.insert(backward_pass_ops.begin(), backward_pass_ops.end()); + + llvm::SetVector core_tpu_ops; + result = FindCoreTPUOps(core_tpu_ops, forward_pass_ops, backward_pass_ops, + merged_set, loop_body_func); + if (failed(result)) return signalPassFailure(); + merged_set.insert(core_tpu_ops.begin(), core_tpu_ops.end()); + + llvm::SetVector non_tpu_ops; + result = FindNonTPUOps(non_tpu_ops, merged_set, loop_body_func); + if (failed(result)) return signalPassFailure(); + merged_set.insert(non_tpu_ops.begin(), non_tpu_ops.end()); + + VLOG(2) << "Forwards pass " << forward_pass_ops.size() + << " ops, backwards pass " << backward_pass_ops.size() + << " ops, core " << core_tpu_ops.size() + << " ops. 
Total = " << merged_set.size() << " of " + << GetNumOps(loop_body_func) << ".\n"; + + builder.setInsertionPointAfter(*non_tpu_ops.begin()); + Operation* non_tpu_caller = nullptr; + result = + ExtractOpsAsFunc(builder, module, non_tpu_ops, replication_attr, nullptr, + nullptr, loop_body_func, "non_tpu", &non_tpu_caller); + if (failed(result)) return signalPassFailure(); + + builder.setInsertionPointAfter(non_tpu_caller); + Operation* forward_caller = nullptr; + result = ExtractOpsAsFunc(builder, module, forward_pass_ops, replication_attr, + metadata_op, compilation_op, loop_body_func, + "sc_forward", &forward_caller); + if (failed(result)) return signalPassFailure(); + + // Create tpu_core function + builder.setInsertionPointAfter(forward_caller); + Operation* core_tpu_caller = nullptr; + result = ExtractOpsAsFunc(builder, module, core_tpu_ops, replication_attr, + metadata_op, compilation_op, loop_body_func, + "core_tpu", &core_tpu_caller); + if (failed(result)) return signalPassFailure(); + + builder.setInsertionPointAfter(core_tpu_caller); + Operation* backwards_pass_caller = nullptr; + result = ExtractOpsAsFunc( + builder, module, backward_pass_ops, replication_attr, metadata_op, + compilation_op, loop_body_func, "sc_backward", &backwards_pass_caller); + if (failed(result)) return signalPassFailure(); + + metadata_op->erase(); + compilation_op->erase(); +} + +} // namespace + +std::unique_ptr> CreateEmbeddingSequencingPass() { + return std::make_unique(); +} + +} // namespace TFDevice +} // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/extract_head_tail_outside_compilation.cc b/tensorflow/compiler/mlir/tensorflow/transforms/extract_head_tail_outside_compilation.cc index 58f2e62df2f..863bdc6b635 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/extract_head_tail_outside_compilation.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/extract_head_tail_outside_compilation.cc @@ -53,10 +53,15 @@ namespace TFDevice { 
namespace { +constexpr char kXlaMapOutsideCompilationAttr[] = "_xla_map_outside_compilation"; constexpr char kXlaOutsideCompilationAttr[] = "_xla_outside_compilation"; +// Return true if `op` has attributes that say it can be outside compiled by +// this pass. This pass ignores _xla_map_outside_compilation, which will only be +// handled by extract_outside_compilation pass. bool HasOutsideCompilationAttribute(Operation* op) { - return op->getAttrOfType(kXlaOutsideCompilationAttr) != nullptr; + return op->getAttrOfType(kXlaOutsideCompilationAttr) != nullptr && + !op->hasAttrOfType(kXlaMapOutsideCompilationAttr); } // Finds op that created a given value. If the value is a BlockArgument, this diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/extract_outside_compilation.cc b/tensorflow/compiler/mlir/tensorflow/transforms/extract_outside_compilation.cc index 9c3e82e88e1..8b3acdf0063 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/extract_outside_compilation.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/extract_outside_compilation.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include #include +#include #include #include #include @@ -32,10 +33,10 @@ limitations under the License. #include "llvm/Support/Casting.h" #include "llvm/Support/FormatVariadic.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/IR/IRMapping.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/IRMapping.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/OperationSupport.h" // from @llvm-project @@ -53,6 +54,7 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h" #include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h" #include "tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/string_util.h" #include "tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h" namespace mlir { @@ -62,6 +64,7 @@ namespace { constexpr char kDeviceAttr[] = "device"; constexpr char kHostFunctionAttr[] = "host_func"; +constexpr char kXlaMapOutsideCompilationAttr[] = "_xla_map_outside_compilation"; constexpr char kXlaOutsideCompilationAttr[] = "_xla_outside_compilation"; constexpr char kNoReplicationCluster[] = "__no_replication_cluster"; @@ -444,23 +447,189 @@ void GetExternalOutputs(const llvm::SmallSetVector& cluster_ops, } } -// Creates the HostCompute with `inputs` and `outputs` -// using `communication_key`. -TF::_XlaHostComputeMlirOp CreateHostCompute( - OpBuilder& builder, Location loc, - const llvm::SmallSetVector& inputs, llvm::ArrayRef outputs, - llvm::StringRef args_communication_key, - llvm::StringRef retvals_communication_key, - llvm::StringRef serialized_func_module) { +// Output `shard_type`, which is the type of each shard, given `full_type`. If +// the full shape is (num_cores_per_replica * a, b, c), then the shard shape is +// (a, b, c). `context_op` is used for error reporting, in case of errors. 
+LogicalResult GetShardShapedType(Operation* context_op, + int num_cores_per_replica, Type full_type, + Type& shard_type) { + RankedTensorType ranked_type = full_type.dyn_cast(); + if (!ranked_type) + return context_op->emitOpError() + << "A map_outside_compilation op's input and output types must be " + "ranked tensors."; + ArrayRef in_shape = ranked_type.getShape(); + if (in_shape.empty() || in_shape[0] < 0) { + return context_op->emitOpError() + << "A map_outside_compilation op's input and output shapes must " + "have rank at least one and the first dimension must be known."; + } + int64_t split_size = in_shape[0] / num_cores_per_replica; + if (in_shape[0] % num_cores_per_replica != 0) { + return context_op->emitOpError() + << "A map_outside_compilation op's input and output shapes must be " + "divisible by num_cores_per_replica=" + << num_cores_per_replica; + } + llvm::SmallVector shape; + shape.push_back(split_size); + for (int i = 1; i < in_shape.size(); ++i) { + shape.push_back(in_shape[i]); + } + shard_type = RankedTensorType::Builder(ranked_type).setShape(shape); + return success(); +} + +// Output `sharding`, which is the sharding of `val`. `context_op` is used for +// error reporting, in case of errors. +// TODO(b/255350483): Explicitly pass the sharding to map_outside_compilation, +// so it does not need to be retrieved from a Value. +LogicalResult GetShardingOfValue(Operation* context_op, Value val, + std::string& sharding) { + Operation* op = val.getDefiningOp(); + // val should always have a defining op because cluster inputs always have + // defining ops. + assert(op); + StringAttr sharding_attr = op->getAttrOfType("_XlaSharding"); + if (!sharding_attr) + return context_op->emitOpError() + << "A map_outside_compilation op's input should have an explicit " + "sharding. 
There is no _XlaSharding attribute on the input op."; + sharding = sharding_attr.str(); + return success(); +} + +// Create an `_XlaHostComputeMlir` for the map_outside_compilation case. Inputs +// are converted from split sharding to MANUAL sharding and outputs are +// converted from MANUAL sharding to split sharding. Output `full_outputs`, +// which is the outputs of the `_XlaHostComputeMlir` and add the +// `_XlaHostComputeMlir` to `host_compute_out_ops`. +LogicalResult CreateHostComputeMap( + Operation* original_op, OpBuilder& builder, Location loc, + ArrayRef inputs, ArrayRef outputs, + StringRef args_communication_key, StringRef retvals_communication_key, + StringRef serialized_func_module, int num_cores_per_replica, + SmallVector& full_outputs, + SmallVector& host_compute_out_ops) { + // Get output types. + llvm::SmallVector shard_output_types; + llvm::SmallVector full_output_types; + shard_output_types.reserve(outputs.size()); + full_output_types.reserve(outputs.size()); + for (const auto& output : outputs) { + Type shard_type; + if (failed(GetShardShapedType(original_op, num_cores_per_replica, + output.getType(), shard_type))) + return failure(); + shard_output_types.push_back(shard_type); + full_output_types.push_back(output.getType()); + } + + // There should be at least 1 input so common_split_sharding can be defined. + if (inputs.empty()) + return original_op->emitOpError() + << "map_outside_compilation should have at least one input"; + + // Convert split sharded inputs to MANUAL sharded inputs. + // common_split_sharding is the split sharding that is common to all inputs + // and outputs. 
+ std::string common_split_sharding; + llvm::SmallVector manual_inputs; + manual_inputs.reserve(inputs.size()); + for (Value in : inputs) { + Type shard_type; + if (failed(GetShardShapedType(original_op, num_cores_per_replica, + in.getType(), shard_type))) + return failure(); + std::string in_sharding; + if (failed(GetShardingOfValue(original_op, in, in_sharding))) + return failure(); + if (common_split_sharding.empty()) { + common_split_sharding = std::move(in_sharding); + } else { + if (common_split_sharding != in_sharding) + return original_op->emitOpError() + << "All inputs and outputs of map_outside_compilation should " + "have the same sharding."; + } + auto in_manual = builder.create( + loc, shard_type, in, common_split_sharding, /*dim=*/-1, + /*unspecified_dims=*/builder.getI64ArrayAttr({})); + manual_inputs.push_back(in_manual); + } + + // Create the _XlaHostComputeMlirOp + auto host_compute = builder.create( + loc, shard_output_types, manual_inputs, + /*send_key=*/builder.getStringAttr(args_communication_key), + /*recv_key=*/builder.getStringAttr(retvals_communication_key), + /*host_mlir_module=*/builder.getStringAttr(serialized_func_module), + /*manual_sharding=*/builder.getBoolAttr(true)); + host_compute_out_ops.push_back(host_compute); + + // Convert MANUAL sharded outputs to split sharded outputs. + for (auto [full_type, out] : + llvm::zip(full_output_types, host_compute.getResults())) { + RankedTensorType full_type_ranked = full_type.dyn_cast(); + if (!full_type_ranked) + return original_op->emitOpError() + << "map_outside_compilation must have ranked outputs"; + auto out_full = builder.create( + loc, full_type, out, common_split_sharding, full_type_ranked.getShape(), + /*dim=*/-1, + /*unspecified_dims=*/builder.getI64ArrayAttr({})); + host_compute_out_ops.push_back(out_full); + full_outputs.push_back(out_full); + } + + return success(); +} + +// Create the _XlaHostComputeMlir with `inputs` and `outputs` for the ordinary +// outside_compilation case. 
+// Output `full_outputs`, which is the outputs of the `_XlaHostComputeMlir` and +// add the `_XlaHostComputeMlir` to `host_compute_out_ops`. +void CreateHostComputeNotMap(OpBuilder& builder, Location loc, + ArrayRef inputs, ArrayRef outputs, + StringRef args_communication_key, + StringRef retvals_communication_key, + StringRef serialized_func_module, + SmallVector& full_outputs, + SmallVector& host_compute_out_ops) { llvm::SmallVector device_output_types; for (const auto& output : outputs) device_output_types.push_back(output.getType()); auto host_compute = builder.create( - loc, device_output_types, inputs.getArrayRef(), + loc, device_output_types, inputs, builder.getStringAttr(args_communication_key), builder.getStringAttr(retvals_communication_key), /*host_mlir_module=*/builder.getStringAttr(serialized_func_module)); - return host_compute; + host_compute_out_ops.push_back(host_compute); + for (Value v : host_compute.getResults()) full_outputs.push_back(v); +} + +// Create the _XlaHostComputeMlir with `inputs` and `outputs`. +// Output `full_outputs`, which is the outputs of the `_XlaHostComputeMlir` and +// add the `_XlaHostComputeMlir` to `host_compute_out_ops`. 
+LogicalResult CreateHostCompute( + Operation* original_op, OpBuilder& builder, Location loc, + ArrayRef inputs, ArrayRef outputs, + StringRef args_communication_key, StringRef retvals_communication_key, + StringRef serialized_func_module, bool is_map_oc, int num_cores_per_replica, + SmallVector& full_outputs, + SmallVector& host_compute_out_ops) { + if (is_map_oc) { + return CreateHostComputeMap( + original_op, builder, loc, inputs, outputs, args_communication_key, + retvals_communication_key, serialized_func_module, + num_cores_per_replica, full_outputs, host_compute_out_ops); + } else { + CreateHostComputeNotMap(builder, loc, inputs, outputs, + args_communication_key, retvals_communication_key, + serialized_func_module, full_outputs, + host_compute_out_ops); + return success(); + } } void MarkOutsideCompiled(Operation* op) { @@ -498,10 +667,10 @@ bool ShouldCloseCluster(llvm::ArrayRef outputs) { // region as insertion. // For static-shapes, Replace operand usages if op is in the same region as // insertion or if the op is outside compiled and will be moved to host later. -void ReplaceExternalOperandUsage( - const llvm::SmallSetVector& external_operands, - Operation* recv_at_host, Operation* insertion_point, - Block* original_op_block) { +void ReplaceExternalOperandUsage(ArrayRef external_operands, + Operation* recv_at_host, + Operation* insertion_point, + Block* original_op_block) { auto replace_operand_usage = [&](OpOperand& operand) { if (TF::CanBeRefined(operand.get().getType()) || HasDynamicOutputs(operand.getOwner())) { @@ -531,10 +700,9 @@ bool HasDynamicOutputs(llvm::ArrayRef outputs) { // Replaces usages of `external_outputs` which are values returned by outside // compilation with the corresponding outputs from `host_compute`. 
-void ReplaceExternalOutputUsage( - const llvm::SmallSetVector& external_outputs, - TF::_XlaHostComputeMlirOp host_compute) { - bool has_dynamic_outputs = HasDynamicOutputs(external_outputs.getArrayRef()); +void ReplaceExternalOutputUsage(ArrayRef external_outputs, + ArrayRef host_compute_outputs) { + bool has_dynamic_outputs = HasDynamicOutputs(external_outputs); auto replace_output_usage = [&](OpOperand& operand) { // Don't replace output usages if in host computation (defining op and user @@ -551,25 +719,16 @@ void ReplaceExternalOutputUsage( !HasOutsideCompilationAncestor(operand.getOwner()); } }; - for (auto result : llvm::zip(external_outputs, host_compute.getResults())) { + for (auto result : llvm::zip(external_outputs, host_compute_outputs)) { Value external_output = std::get<0>(result); external_output.replaceUsesWithIf(std::get<1>(result), replace_output_usage); } } -// Move `clustered_ops` to run on host and adds communication ops to transfer -// `external_operands` and `external_outputs` to/from device/host. Inserts -// ops at `insertion_point` and uses `compilation_key` and `device_ordinal` when -// creating comm ops. 
-void MoveOpsToHost(const llvm::SmallSetVector& clustered_ops, - const llvm::SmallSetVector& external_operands, - const llvm::SmallSetVector& external_outputs, - Operation* insertion_point, Value compilation_key, - Value device_ordinal, int default_device_ordinal, - StringAttr device_type_attr, int& communication_key_index) { - OpBuilder builder(insertion_point); - Operation& op = *clustered_ops.back(); +std::pair MakeCommunicationKeys( + ArrayRef clustered_ops, ArrayRef external_operands, + int communication_key_index, Operation& op) { std::string args_communication_key = llvm::formatv("host_compute_channel_{0}_args", (communication_key_index)) .str(); @@ -586,22 +745,22 @@ void MoveOpsToHost(const llvm::SmallSetVector& clustered_ops, llvm::formatv("if_predicate_channel_{0}", (communication_key_index)) .str(); } + return std::pair(args_communication_key, retvals_communication_key); +} - std::string serialized_func_module; - if (HasDynamicOutputs(external_outputs.getArrayRef())) { - func::FuncOp shape_op = BuildFunction( - clustered_ops.getArrayRef(), external_operands.getArrayRef(), - external_outputs.getArrayRef(), &builder); - EncapsulateFuncAndSerialize(shape_op, &serialized_func_module); - } - - builder.setInsertionPoint(&op); - auto host_compute = - CreateHostCompute(builder, op.getLoc(), external_operands, - external_outputs.getArrayRef(), args_communication_key, - retvals_communication_key, serialized_func_module); - // Insert ops on the host side computation to receive data from device. - builder.setInsertionPoint(insertion_point); +// Add ops to the host-side. These are `RecvAtHost`, `clustered_ops` moved from +// device cluster, `SendFromHost`. Add these host-side ops to `host_ops`. Return +// the `RecvAtHost` op. 
+Operation* CreateHostOps(ArrayRef clustered_ops, + ArrayRef external_operands, + ArrayRef external_outputs, + Operation* host_insertion_point, Value compilation_key, + Value device_ordinal, int default_device_ordinal, + StringAttr device_type_attr, OpBuilder& builder, + Operation& op, std::string args_communication_key, + std::string retvals_communication_key, + SmallVector& host_ops) { + builder.setInsertionPoint(host_insertion_point); llvm::SmallVector host_operand_types; for (const auto& operand : external_operands) host_operand_types.push_back(operand.getType()); @@ -609,37 +768,174 @@ void MoveOpsToHost(const llvm::SmallSetVector& clustered_ops, Operation* recv_at_host = CreateRecvAtHostOp( builder, op.getLoc(), host_operand_types, compilation_key, device_ordinal, default_device_ordinal, device_type_attr, args_communication_key); - Block* original_op_block = op.getBlock(); + + if (!external_operands.empty()) host_ops.push_back(recv_at_host); Operation* after_op = recv_at_host; for (Operation* cluster_op : clustered_ops) { cluster_op->moveAfter(after_op); cluster_op->removeAttr(StringAttr::get(op.getContext(), kDeviceAttr)); after_op = cluster_op; + host_ops.push_back(cluster_op); } if (!external_outputs.empty()) { - CreateSendFromHostOp(builder, op.getLoc(), external_outputs.getArrayRef(), - compilation_key, device_ordinal, - default_device_ordinal, device_type_attr, - retvals_communication_key); + Operation* send_from_host = CreateSendFromHostOp( + builder, op.getLoc(), external_outputs, compilation_key, device_ordinal, + default_device_ordinal, device_type_attr, retvals_communication_key); + host_ops.push_back(send_from_host); } + return recv_at_host; +} + +// Clone the first outside compiled region to one for each TPU core. This is +// used for map_outside_compilation. +// Message identification arguments to RecvAtHost and SendFromHost are changed. 
+void CloneFirstHost(ArrayRef core_to_host_insertion_point, + ArrayRef core_to_compilation_key, + ArrayRef core_to_device_ordinal, + int num_cores_per_replica, ArrayRef host0_ops, + OpBuilder& builder) { + for (int core = 1; core < num_cores_per_replica; ++core) { + IRMapping mapper; + for (Operation* op : host0_ops) { + builder.setInsertionPoint(core_to_host_insertion_point[core]); + Operation* clone = builder.clone(*op, mapper); + mapper.map(op, clone); + if (auto recv_at_host = llvm::dyn_cast(clone)) { + recv_at_host.setDeviceOrdinal(core); + clone->setOperand(0, core_to_compilation_key[core]); + } else if (auto send_from_host = + llvm::dyn_cast(clone)) { + send_from_host.setDeviceOrdinal(core); + clone->setOperand(1, core_to_compilation_key[core]); + } else if (auto recv_at_host = + llvm::dyn_cast(clone)) { + recv_at_host.setOperand(0, core_to_compilation_key[core]); + builder.setInsertionPoint(recv_at_host); + // core_ordinal = device_ordinal + core + // where device_ordinal is the base device for the replica + Value device_ordinal = core_to_device_ordinal[core]; + Value const_core = builder.create( + recv_at_host.getLoc(), builder.getI64IntegerAttr(core)); + Value core_ordinal = builder.create( + recv_at_host.getLoc(), device_ordinal.getType(), device_ordinal, + const_core); + recv_at_host.setOperand(1, core_ordinal); + } else if (auto send_from_host = + llvm::dyn_cast(clone)) { + send_from_host.setOperand(1, core_to_compilation_key[core]); + builder.setInsertionPoint(send_from_host); + // core_ordinal = device_ordinal + core + // where device_ordinal is the base device for the replica + Value device_ordinal = core_to_device_ordinal[core]; + Value const_core = builder.create( + send_from_host.getLoc(), builder.getI64IntegerAttr(core)); + Value core_ordinal = builder.create( + send_from_host.getLoc(), device_ordinal.getType(), device_ordinal, + const_core); + send_from_host.setOperand(2, core_ordinal); + } + } + } +} + +// Move `clustered_ops` to run on host 
and adds communication ops to transfer +// `external_operands` and `external_outputs` to/from device/host. Inserts +// ops at `insertion_point` and uses `compilation_key` and `device_ordinal` when +// creating comm ops. +LogicalResult MoveToHostSingleCluster( + ArrayRef clustered_ops, ArrayRef external_operands, + ArrayRef external_outputs, + ArrayRef core_to_host_insertion_point, + ArrayRef core_to_compilation_key, + ArrayRef core_to_device_ordinal, int default_device_ordinal, + StringAttr device_type_attr, bool is_map_oc, int num_cores_per_replica, + int& communication_key_index) { + OpBuilder builder(core_to_host_insertion_point[0]); + Operation& op = *clustered_ops.back(); + Block* original_op_block = op.getBlock(); + auto [args_communication_key, retvals_communication_key] = + MakeCommunicationKeys(clustered_ops, external_operands, + communication_key_index, op); + + std::string serialized_func_module; + if (HasDynamicOutputs(external_outputs)) { + func::FuncOp shape_op = BuildFunction(clustered_ops, external_operands, + external_outputs, &builder); + EncapsulateFuncAndSerialize(shape_op, &serialized_func_module); + } + + builder.setInsertionPoint(&op); + SmallVector host_compute_outputs; + SmallVector host_compute_out_ops; + if (failed(CreateHostCompute( + &op, builder, op.getLoc(), external_operands, external_outputs, + args_communication_key, retvals_communication_key, + serialized_func_module, is_map_oc, num_cores_per_replica, + host_compute_outputs, host_compute_out_ops))) + return failure(); + + // Insert ops on the host side computation to receive data from device. + // host0_ops are the ops that will make up the first host process. In the + // map_outside_compilation case, there are multiple host processes, which will + // be created by cloning. 
+ SmallVector host0_ops; + Operation* recv_at_host = CreateHostOps( + clustered_ops, external_operands, external_outputs, + core_to_host_insertion_point[0], core_to_compilation_key[0], + core_to_device_ordinal.empty() ? nullptr : core_to_device_ordinal[0], + default_device_ordinal, device_type_attr, builder, op, + args_communication_key, retvals_communication_key, host0_ops); + if (external_operands.empty()) { recv_at_host->erase(); } else { - ReplaceExternalOperandUsage(external_operands, - /*recv_at_host=*/recv_at_host, - /*insertion_point=*/insertion_point, - /*original_op_block=*/original_op_block); + ReplaceExternalOperandUsage( + external_operands, recv_at_host, + /*insertion_point=*/core_to_host_insertion_point[0], original_op_block); } - ReplaceExternalOutputUsage(external_outputs, host_compute); + ReplaceExternalOutputUsage(external_outputs, host_compute_outputs); + + // Clone the first outside compiled region to one for each TPU core. + if (is_map_oc) + CloneFirstHost(core_to_host_insertion_point, core_to_compilation_key, + core_to_device_ordinal, num_cores_per_replica, host0_ops, + builder); + + ReplaceExternalOutputUsage(external_outputs, host_compute_outputs); if (external_operands.empty() && external_outputs.empty()) { - host_compute.erase(); + for (Operation* op : host_compute_out_ops) op->erase(); } else { ++communication_key_index; } + + return success(); +} + +// Update is_map_oc the true if op has attribute _xla_map_outside_compilation +// and false otherwise. Check that this is consistent with the previous setting +// of is_map_oc. 
+LogicalResult UpdateIsMapOutsideCompilation(Operation& op, bool control_above, + std::optional& is_map_oc) { + bool op_is_map_oc = + op.hasAttrOfType(kXlaMapOutsideCompilationAttr); + if (is_map_oc) { + if (op_is_map_oc != *is_map_oc) { + return op.emitOpError() + << "Cannot mix map_outside_compilation with ordinary " + "outside_compilation in the same graph."; + } + } else { + is_map_oc = op_is_map_oc; + } + if (control_above && op_is_map_oc) { + return op.emitOpError() << "map_outside_compilation inside control flow " + "is not implemented."; + } + return success(); } // Move outside compiled ops in `src` to `insertion_point` in host @@ -649,13 +945,21 @@ void MoveOpsToHost(const llvm::SmallSetVector& clustered_ops, // `communication_key_index` which is incremented when used. Communication ops // are added only when needed and at the location need. There are checks to // ensure that duplicate communication between device and host is not added. -// When `return_value_from_host` is not nullptr, MoveOpsToHost will also update -// its value. -LogicalResult MoveOpsToHost( - tf_device::ClusterOp device_cluster, Block* src, Operation* insertion_point, - Value compilation_key, Value device_ordinal, int default_device_ordinal, +// When `return_value_from_host` is not nullptr, MoveToHostMultiCluster will +// also update its value. `control_above` means that this Block is within +// control flow, which is not currently supported with map_outside_compilation. +// `is_map_oc` tracks whether map_outside_compilation is used, for the whole +// program. Currently only map_outside_compilation-only or ordinary +// outside_compilation only is supported. 
+LogicalResult MoveToHostMultiCluster( + tf_device::ClusterOp device_cluster, Block* src, + ArrayRef core_to_host_insertion_point, + ArrayRef core_to_compilation_key, + ArrayRef core_to_device_ordinal, int default_device_ordinal, + bool control_above, std::optional& is_map_oc, int& communication_key_index, llvm::SmallVector* return_value_from_host = nullptr) { + int num_cores_per_replica = core_to_host_insertion_point.size(); // Contains all of the outside compiled operations that should be moved to the // host using a single `_XlaHostComputeMlir` op. This should only contain a // single op except in the case where some of the input/output shapes are @@ -669,6 +973,9 @@ LogicalResult MoveOpsToHost( !op.hasAttrOfType(kXlaOutsideCompilationAttr)) continue; + if (failed(UpdateIsMapOutsideCompilation(op, control_above, is_map_oc))) + return failure(); + llvm::SmallSetVector external_outputs; llvm::SmallVector host_outputs; // We want to move the clustered_ops if the op to be added has all @@ -684,10 +991,13 @@ LogicalResult MoveOpsToHost( return_value_from_host->push_back(output); } } - MoveOpsToHost(clustered_ops, external_operands, external_outputs, - insertion_point, compilation_key, device_ordinal, - default_device_ordinal, device_type_attr, - communication_key_index); + if (failed(MoveToHostSingleCluster( + clustered_ops.getArrayRef(), external_operands.getArrayRef(), + external_outputs.getArrayRef(), core_to_host_insertion_point, + core_to_compilation_key, core_to_device_ordinal, + default_device_ordinal, device_type_attr, *is_map_oc, + num_cores_per_replica, communication_key_index))) + return failure(); clustered_ops.clear(); } @@ -708,10 +1018,13 @@ LogicalResult MoveOpsToHost( } } - MoveOpsToHost(clustered_ops, external_operands, external_outputs, - insertion_point, compilation_key, device_ordinal, - default_device_ordinal, device_type_attr, - communication_key_index); + if (failed(MoveToHostSingleCluster( + clustered_ops.getArrayRef(), 
external_operands.getArrayRef(), + external_outputs.getArrayRef(), core_to_host_insertion_point, + core_to_compilation_key, core_to_device_ordinal, + default_device_ordinal, device_type_attr, *is_map_oc, + num_cores_per_replica, communication_key_index))) + return failure(); clustered_ops.clear(); } } @@ -736,27 +1049,34 @@ void GetReturnValueFromDevice( // (outside compiled) computation into two separate control flow ops with // communication between the device/host for data dependencies. Both device and // host control flow initially remain within `device_cluster` and a subsequency -// call to MoveOpsToHost moves the host side control flow to the host launch in -// tf_device.parallel_execute. Uses `compilation_key, `device_ordinal` and -// `communication_key_index` when creating communication ops. +// call to MoveToHostSingleCluster moves the host side control flow to the host +// launch in tf_device.parallel_execute. Uses `compilation_key, +// `device_ordinal` and `communication_key_index` when creating communication +// ops. 
LogicalResult DecomposeControlFlow(tf_device::ClusterOp device_cluster, - Value compilation_key, Value device_ordinal, + ArrayRef core_to_compilation_key, + ArrayRef core_to_device_ordinal, int default_device_ordinal, - int& communication_key_index) { + int& communication_key_index, + std::optional& is_map_oc) { auto result = device_cluster.GetBody().walk([&](Operation* op) { if (auto if_op = llvm::dyn_cast(op)) { if (!HasOutsideCompilationNested(op)) return WalkResult::advance(); OpBuilder builder(if_op); auto host_if = CloneEmptyIfWithPredicate(if_op, builder); - if (failed(MoveOpsToHost( + if (failed(MoveToHostMultiCluster( device_cluster, &if_op.getThenBranch().front(), - host_if.getThenBranch().front().getTerminator(), compilation_key, - device_ordinal, default_device_ordinal, communication_key_index))) + {host_if.getThenBranch().front().getTerminator()}, + core_to_compilation_key, core_to_device_ordinal, + default_device_ordinal, /*control_above=*/true, is_map_oc, + communication_key_index))) return WalkResult::interrupt(); - if (failed(MoveOpsToHost( + if (failed(MoveToHostMultiCluster( device_cluster, &if_op.getElseBranch().front(), - host_if.getElseBranch().front().getTerminator(), compilation_key, - device_ordinal, default_device_ordinal, communication_key_index))) + {host_if.getElseBranch().front().getTerminator()}, + core_to_compilation_key, core_to_device_ordinal, + default_device_ordinal, /*control_above=*/true, is_map_oc, + communication_key_index))) return WalkResult::interrupt(); // Mark op as stateful due to side-effecting communication ops. if_op->setAttr("is_stateless", builder.getBoolAttr(false)); @@ -778,24 +1098,32 @@ LogicalResult DecomposeControlFlow(tf_device::ClusterOp device_cluster, builder.setInsertionPoint(while_op.getCond().front().getTerminator()); builder.create(while_op.getLoc(), condition, condition_send_recv_key); + // device_ordinal0 is the ordinal of TPU_REPLICATED_CORE_0 and is only + // used in the replicated case. 
+ Value device_ordinal0 = nullptr; + if (!core_to_device_ordinal.empty()) + device_ordinal0 = core_to_device_ordinal[0]; builder.setInsertionPointToEnd(&cond.front()); auto recv_condition_at_host = CreateRecvAtHostOp( builder, while_op.getLoc(), TypeRange{condition.getType()}, - compilation_key, device_ordinal, default_device_ordinal, + core_to_compilation_key[0], device_ordinal0, default_device_ordinal, device_cluster->getAttrOfType(TF::kCompileDeviceTypeAttr), condition_send_recv_key); builder.create(while_op.getLoc(), recv_condition_at_host->getResults()); - if (failed(MoveOpsToHost(device_cluster, &while_op.getCond().front(), - recv_condition_at_host, compilation_key, - device_ordinal, default_device_ordinal, - communication_key_index))) + if (failed(MoveToHostMultiCluster( + device_cluster, &while_op.getCond().front(), + {recv_condition_at_host}, core_to_compilation_key, + core_to_device_ordinal, default_device_ordinal, + /*control_above=*/true, is_map_oc, communication_key_index))) return WalkResult::interrupt(); - if (failed(MoveOpsToHost( + if (failed(MoveToHostMultiCluster( device_cluster, &while_op.getBody().front(), - host_while.getBody().front().getTerminator(), compilation_key, - device_ordinal, default_device_ordinal, communication_key_index))) + {host_while.getBody().front().getTerminator()}, + core_to_compilation_key, core_to_device_ordinal, + default_device_ordinal, /*control_above=*/true, is_map_oc, + communication_key_index))) return WalkResult::interrupt(); // Mark op as stateful due to side-effecting communication ops. while_op->setAttr("is_stateless", builder.getBoolAttr(false)); @@ -859,8 +1187,8 @@ LogicalResult GetDefaultDeviceOrdinal(tf_device::ClusterOp device_cluster, // The results of parallel executes is the combination of return values from // both host and device. 
llvm::SmallVector GetParallelExecuteResultsTypes( - const llvm::SmallVector& return_value_from_host, - const llvm::SmallVector& return_value_from_device) { + ArrayRef return_value_from_host, + ArrayRef return_value_from_device) { llvm::SmallVector parallel_execute_result_types; const int num_of_outputs = return_value_from_host.size() + return_value_from_device.size(); @@ -939,7 +1267,7 @@ void RemapDeviceClusterResultsWithParallelExecuteResults( // Get the vector of results for new device cluster llvm::SmallVector GetNewDeviceResults( - const llvm::SmallVector& return_value_from_device) { + ArrayRef return_value_from_device) { llvm::SmallVector device_results; device_results.reserve(return_value_from_device.size()); for (Value old_result : return_value_from_device) @@ -949,7 +1277,7 @@ llvm::SmallVector GetNewDeviceResults( // Get the vector of types of results for new device cluster llvm::SmallVector GetNewDeviceTypes( - const llvm::SmallVector& return_value_from_device) { + ArrayRef return_value_from_device) { llvm::SmallVector device_result_types; device_result_types.reserve(return_value_from_device.size()); for (Value old_result : return_value_from_device) @@ -983,10 +1311,11 @@ void MoveTmpLaunchOpToNewLaunchOp(tf_device::LaunchOp tmp_host_launch_op, // Still, one region is for the host computation for outside compilation and // the other one is for the original Device cluster computation. 
tf_device::ParallelExecuteOp CreateFinalParallelExecuteOp( - OpBuilder& builder, int num_regions, llvm::StringRef host_device, - tf_device::ClusterOp device_cluster, tf_device::LaunchOp tmp_host_launch_op, - const llvm::SmallVector& return_value_from_host, - const llvm::SmallVector& return_value_from_device) { + OpBuilder& builder, int num_regions, ArrayRef core_to_host, + tf_device::ClusterOp device_cluster, + ArrayRef core_to_tmp_host_launch, + ArrayRef return_value_from_host, + ArrayRef return_value_from_device) { llvm::SmallVector parallel_execute_result_types = GetParallelExecuteResultsTypes(return_value_from_host, return_value_from_device); @@ -994,25 +1323,35 @@ tf_device::ParallelExecuteOp CreateFinalParallelExecuteOp( builder.setInsertionPoint(device_cluster); auto parallel_execute_op = builder.create( device_cluster.getLoc(), num_regions, parallel_execute_result_types); - Block& host_computation_block = - parallel_execute_op.GetRegionBlockWithIndex(0); - builder.setInsertionPointToEnd(&host_computation_block); + SmallVector core_to_host_launch; + for (int core = 0; core < core_to_tmp_host_launch.size(); ++core) { + Block& host_computation_block = + parallel_execute_op.GetRegionBlockWithIndex(core); + builder.setInsertionPointToEnd(&host_computation_block); - // Create a single launch op for all outside compiled ops. - llvm::SmallVector host_results; - host_results.insert(host_results.end(), return_value_from_host.begin(), - return_value_from_host.end()); - tf_device::LaunchOp host_launch_op = CreateLaunchOpForOutsideCluster( - builder, device_cluster, host_device, host_results); + // map_outside_compilation with return values from host is not implemented. 
+ // This would only be needed if head-tail-outside-compilation supports + // map_outside_compilation"; + assert(core == 0 || return_value_from_host.empty()); - // Create a return op for host computation block - builder.setInsertionPointToEnd(&host_computation_block); - builder.create(device_cluster.getLoc(), - host_launch_op->getResults()); + // Create a single launch op for all outside compiled ops. + llvm::SmallVector host_results; + host_results.insert(host_results.end(), return_value_from_host.begin(), + return_value_from_host.end()); + tf_device::LaunchOp host_launch_op = CreateLaunchOpForOutsideCluster( + builder, device_cluster, core_to_host[core], host_results); + core_to_host_launch.push_back(host_launch_op); + + // Create a return op for host computation block + builder.setInsertionPointToEnd(&host_computation_block); + builder.create(device_cluster.getLoc(), + host_launch_op->getResults()); + } // Move the launch body to last parallel_execute block. Block& parallel_execute_device_block = - parallel_execute_op.GetRegionBlockWithIndex(1); + parallel_execute_op.GetRegionBlockWithIndex( + core_to_tmp_host_launch.size()); builder.setInsertionPointToEnd(¶llel_execute_device_block); // Get the vector of results and types of results for new device cluster @@ -1042,8 +1381,13 @@ tf_device::ParallelExecuteOp CreateFinalParallelExecuteOp( MoveOldTpuClusterToNewTpuCluster(device_cluster, after_op_r); - Operation* after_op_host_cluster = host_launch_op.GetBody().getTerminator(); - MoveTmpLaunchOpToNewLaunchOp(tmp_host_launch_op, after_op_host_cluster); + // Move each host-side Launch op. 
+ for (int core = 0; core < core_to_tmp_host_launch.size(); ++core) { + Operation* after_op_host_cluster = + core_to_host_launch[core].GetBody().getTerminator(); + MoveTmpLaunchOpToNewLaunchOp(core_to_tmp_host_launch[core], + after_op_host_cluster); + } return parallel_execute_op; } @@ -1052,102 +1396,121 @@ tf_device::ParallelExecuteOp CreateFinalParallelExecuteOp( // a region for `device_cluster` computation by extracting outside compiled ops // to host computation. LogicalResult CreateParallelExecuteForOutsideCompilation( - ModuleOp module, tf_device::ClusterOp device_cluster, - llvm::StringRef host_device, + tf_device::ClusterOp device_cluster, llvm::SmallVector& ops, + std::optional& is_map_oc, ArrayRef core_to_host, bool has_tpu_device) { OpBuilder builder(device_cluster); llvm::SmallVector returns_from_host; // Create a temporary parallel_execute. This is temporary because the result - // type is not determined until after it is filled. There are two regions in - // `tmp_parallel_execute_op`. The first one is for the host computation for - // outside compilation and the second one is for the original Device cluster - // computation. - const int num_regions = 2; + // type is not determined until after it is filled. The parallel_execute has + // `num_host_regions` assigned to hosts and 1 region for the Device cluster. + // In the ordinary outside compilation case `num_host_regions` is 1 and in the + // `map_outside_compilation` case `num_host_regions == num_cores_per_replica`. 
+ const int num_host_regions = core_to_host.size(); + const int num_regions = 1 + num_host_regions; auto tmp_parallel_execute_op = builder.create( device_cluster.getLoc(), num_regions, llvm::ArrayRef{}); - Block& tmp_host_computation_block = - tmp_parallel_execute_op.GetRegionBlockWithIndex(0); - builder.setInsertionPointToEnd(&tmp_host_computation_block); + SmallVector core_to_host_insertion_point; + SmallVector core_to_tmp_launch; + SmallVector compilation_key_ops; + SmallVector core_to_compilation_key; + SmallVector core_to_device_ordinal_op; + SmallVector core_to_device_ordinal; + for (int core = 0; core < num_host_regions; ++core) { + Block& tmp_host_computation_block = + tmp_parallel_execute_op.GetRegionBlockWithIndex(core); + builder.setInsertionPointToEnd(&tmp_host_computation_block); + // Create a single tmp launch op for all outside compiled ops. + llvm::SmallVector tmp_host_results; + tf_device::LaunchOp tmp_host_launch_op = CreateLaunchOpForOutsideCluster( + builder, device_cluster, core_to_host[core], tmp_host_results); + core_to_tmp_launch.push_back(tmp_host_launch_op); + // Create a tmp return op for tmp host computation block + builder.setInsertionPointToEnd(&tmp_host_computation_block); + builder.create(device_cluster.getLoc(), + llvm::ArrayRef{}); + core_to_host_insertion_point.push_back( + tmp_host_launch_op.GetBody().getTerminator()); - // Create a single tmp launch op for all outside compiled ops. 
- llvm::SmallVector tmp_host_results; - tf_device::LaunchOp tmp_host_launch_op = CreateLaunchOpForOutsideCluster( - builder, device_cluster, host_device, tmp_host_results); + builder.setInsertionPoint(tmp_host_launch_op.GetBody().getTerminator()); - // Create a tmp return op for tmp host computation block - builder.setInsertionPointToEnd(&tmp_host_computation_block); - builder.create(device_cluster.getLoc(), - llvm::ArrayRef{}); - - builder.setInsertionPoint(tmp_host_launch_op.GetBody().getTerminator()); - - Operation* compilation_key_op = nullptr; - Value compilation_key = nullptr; - Operation* device_ordinal_op = nullptr; - - if (has_tpu_device) { - compilation_key_op = - CreateCompilationKeyPlaceholder(device_cluster.getLoc(), builder); - compilation_key = - llvm::dyn_cast( - compilation_key_op) - .getProgram(); - device_ordinal_op = builder.create( - device_cluster.getLoc(), - RankedTensorType::get({}, builder.getI64Type())); - } else { - compilation_key_op = - CreateCpuGpuComilationKeyPlaceholder(device_cluster.getLoc(), builder); - compilation_key = - llvm::dyn_cast(compilation_key_op->getResults()[0]); - device_ordinal_op = builder.create( - device_cluster.getLoc(), - DenseIntElementsAttr::get( - RankedTensorType::get({}, builder.getI64Type()), - static_cast(0))); + // Create message identification ops. 
+ Operation* compilation_key_op = nullptr; + Value compilation_key = nullptr; + Operation* device_ordinal_op = nullptr; + if (has_tpu_device) { + compilation_key_op = + CreateCompilationKeyPlaceholder(device_cluster.getLoc(), builder); + compilation_key = + llvm::dyn_cast( + compilation_key_op) + .getProgram(); + device_ordinal_op = builder.create( + device_cluster.getLoc(), + RankedTensorType::get({}, builder.getI64Type())); + } else { + compilation_key_op = CreateCpuGpuComilationKeyPlaceholder( + device_cluster.getLoc(), builder); + compilation_key = + llvm::dyn_cast(compilation_key_op->getResults()[0]); + device_ordinal_op = builder.create( + device_cluster.getLoc(), + DenseIntElementsAttr::get( + RankedTensorType::get({}, builder.getI64Type()), + static_cast(0))); + } + compilation_key_ops.push_back(compilation_key_op); + core_to_compilation_key.push_back(compilation_key); + core_to_device_ordinal_op.push_back(device_ordinal_op); + if (device_cluster->getParentOfType()) + core_to_device_ordinal.push_back( + core_to_device_ordinal_op[core]->getResults()[0]); } - Value device_ordinal = nullptr; - - if (device_cluster->getParentOfType()) { - device_ordinal = device_ordinal_op->getResults()[0]; - } + builder.setInsertionPoint(tmp_parallel_execute_op); int default_device_ordinal = 0; if (failed(GetDefaultDeviceOrdinal(device_cluster, default_device_ordinal))) { return failure(); } + // communication_key_index is part of the message identifier and is + // incremented for each _XlaHostComputeMlir. int communication_key_index = 0; + // Decompose control flow into device and host control flow when outside // compilation is included. 
- if (failed(DecomposeControlFlow(device_cluster, compilation_key, - device_ordinal, default_device_ordinal, - communication_key_index))) + if (failed(DecomposeControlFlow( + device_cluster, core_to_compilation_key, core_to_device_ordinal, + default_device_ordinal, communication_key_index, is_map_oc))) return failure(); // Move all outside compiled ops including control flow to tmp host launch. // Also set the values returned from the host when ops are moved. - if (failed(MoveOpsToHost(device_cluster, &device_cluster.GetBody(), - tmp_host_launch_op.GetBody().getTerminator(), - compilation_key, device_ordinal, - default_device_ordinal, communication_key_index, - &returns_from_host))) + if (failed(MoveToHostMultiCluster( + device_cluster, &device_cluster.GetBody(), + core_to_host_insertion_point, core_to_compilation_key, + core_to_device_ordinal, default_device_ordinal, + /*control_above=*/false, is_map_oc, communication_key_index, + &returns_from_host))) return failure(); llvm::SmallVector returns_from_device; GetReturnValueFromDevice(device_cluster, returns_from_host, returns_from_device); - if (communication_key_index == 0) compilation_key_op->erase(); - if (communication_key_index == 0 || device_ordinal == nullptr) - device_ordinal_op->erase(); + // Remove unused message identification ops. 
+ if (communication_key_index == 0) + for (auto op : compilation_key_ops) op->erase(); + if (communication_key_index == 0 || core_to_device_ordinal.empty()) + for (auto op : core_to_device_ordinal_op) op->erase(); - RemoveOutsideCompilation(tmp_host_launch_op); + for (tf_device::LaunchOp tmp_host_launch_op : core_to_tmp_launch) + RemoveOutsideCompilation(tmp_host_launch_op); tf_device::ParallelExecuteOp parallel_execute_op = - CreateFinalParallelExecuteOp(builder, num_regions, host_device, - device_cluster, tmp_host_launch_op, + CreateFinalParallelExecuteOp(builder, num_regions, core_to_host, + device_cluster, core_to_tmp_launch, returns_from_host, returns_from_device); ops.push_back(tmp_parallel_execute_op); @@ -1167,11 +1530,10 @@ LogicalResult CreateParallelExecuteForOutsideCompilation( LogicalResult CheckClusterResults(tf_device::ClusterOp cluster) { for (OpResult result : cluster.getResults()) { if (!tensorflow::TypeValidForXLA(result.getType())) { - cluster.emitError() - << "The ExtractHeadTailOutsideCompilation pass produced a Device " - "cluster with a result with a non-XLA type: " - << result.getType(); - return failure(); + return cluster.emitError() + << "The ExtractHeadTailOutsideCompilation pass produced a Device " + "cluster with a result with a non-XLA type: " + << result.getType(); } } return success(); @@ -1185,11 +1547,10 @@ LogicalResult CheckAncestorNotOutsideComp(Operation* op) { Operation* iter_op = op; while (auto* parent_op = iter_op->getParentOp()) { if (parent_op->getAttrOfType(kXlaOutsideCompilationAttr)) { - op->emitOpError() - << "An op marked for outside compilation (having attribute " - << kXlaOutsideCompilationAttr - << ") has an ancestor marked for outside compilation."; - return failure(); + return op->emitOpError() + << "An op marked for outside compilation (having attribute " + << kXlaOutsideCompilationAttr + << ") has an ancestor marked for outside compilation."; } iter_op = parent_op; } @@ -1226,15 +1587,15 @@ void 
ExtractOutsideCompilation::runOnOperation() { return signalPassFailure(); llvm::SmallVector tmp_parallel_execute_ops; + std::optional is_map_oc; module.walk([&](tf_device::ClusterOp device_cluster) { if (HasOutsideCompilationNested(device_cluster.getOperation())) { - std::string host_device; - if (failed(tensorflow::GetHostDeviceOutsideComputation( - devices, device_cluster, &host_device))) + SmallVector core_to_host; + if (failed(tensorflow::GetDeviceToHostMap(device_cluster, core_to_host))) return signalPassFailure(); if (failed(CreateParallelExecuteForOutsideCompilation( - module, device_cluster, host_device, tmp_parallel_execute_ops, + device_cluster, tmp_parallel_execute_ops, is_map_oc, core_to_host, tensorflow::HasTPUDevice(devices)))) return signalPassFailure(); } @@ -1248,8 +1609,10 @@ void ExtractOutsideCompilation::runOnOperation() { // on ops outside of tf_device.cluster don't have any meaning and can lead to // errors later on. These ops were likely lifted out of the // tf_device.cluster in an earlier pass. - module.walk( - [](Operation* op) { op->removeAttr("_xla_outside_compilation"); }); + module.walk([](Operation* op) { + op->removeAttr(kXlaOutsideCompilationAttr); + op->removeAttr(kXlaMapOutsideCompilationAttr); + }); if (failed(CheckPostconditions(module))) return signalPassFailure(); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo.cc b/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo.cc index 433308c7966..e8520cb932a 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo.cc @@ -41,6 +41,7 @@ limitations under the License. 
#include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/BuiltinAttributeInterfaces.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/IRMapping.h" // from @llvm-project #include "mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project @@ -50,6 +51,7 @@ limitations under the License. #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/IR/Region.h" // from @llvm-project +#include "mlir/IR/SymbolTable.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project @@ -183,7 +185,7 @@ struct ConvertNdConvOp { auto num_spatial_dims = conv_op.getDimensionNumbers().getInputSpatialDimensions().size(); - // TODO(b/158636600): Currently we don't support 3D Convolution. + // TODO: b/158636600 - Currently we don't support 3D Convolution. if (num_spatial_dims != SupportedSpatialDims) return false; return true; @@ -204,6 +206,20 @@ class Convert1DConvOp : public OpConversionPattern, // Check that input is a supported 1d convolution. // + // stablehlo.convolution allows ops without window strides, where default + // value 1 will be set for each spatial dimension. However, window_strides + // are needed for mhlo.convolution -> tf.Conv2D conversion. Therefore, in + // this conversion path have a fallback to set window strides if not set. 
+ if (!conv_op.getWindowStrides().has_value()) { + const int window_strides_size = + conv_op.getDimensionNumbers().getInputSpatialDimensions().size(); + std::vector window_strides_2d_array_default(window_strides_size, + 1); + DenseIntElementsAttr window_strides_2d_default = + rewriter.getI64TensorAttr(window_strides_2d_array_default); + conv_op.setWindowStridesAttr(window_strides_2d_default); + } + if (!IsSupportedConvOp(conv_op) || conv_op->getNumResults() != 1) return rewriter.notifyMatchFailure(conv_op, "unsupported conv op."); @@ -1219,62 +1235,14 @@ class ConvertDynamicUpdateSliceOp Type idx_type = start_indices_type.getElementType(); int64_t shape_dim = operand_type.getRank(); - auto operand_shape = operand_type.getShape(); - auto update_shape = update_type.getShape(); - - ImplicitLocOpBuilder builder(op.getLoc(), rewriter); - Value zero_cst = BuildIntConstOp(builder, rewriter, 0, idx_type); - Value one_cst = BuildIntConstOp(builder, rewriter, 1, idx_type); - // Clamp start indices in [0, operand_size - update_size]. llvm::SmallVector start_indices_vector; Append(start_indices_vector, op.getStartIndices()); auto shape_tensor_type = RankedTensorType::get({shape_dim}, idx_type); - Value start_indices_tensor = - builder.create(shape_tensor_type, start_indices_vector); - Value operand_shape_cst = - BuildIntArrayConstOp(builder, rewriter, operand_shape, idx_type); - Value update_shape_cst = - BuildIntArrayConstOp(builder, rewriter, update_shape, idx_type); - Value max_start_indices = - builder.create(operand_shape_cst, update_shape_cst); - Value start_indices_clip_max = - builder.create(start_indices_tensor, max_start_indices); - Value clamped_start_indices = - builder.create(start_indices_clip_max, zero_cst); - - // Do dynamic_upate_slice on flattened operand and update with the aid of - // tf.TensorScatterUpdate op. It takes in 3 parameters: flat_operand, - // indices and flat_update. The indices are computed as follows: - // 1. Construct a range (0, n_operand). 
It arranges a id number to each - // element position in operand. - // 2. Reshape the range to the shape of operand. - // 3. Compute the id numbers of update positions by choose a slice form - // clamped_start_indices to clamped_start_indices + update_size. - // 4. Flatten the update id numbers and the indices is obtained. - int64_t n_operand = operand_type.getNumElements(); - Value n_operand_cst = - BuildIntConstOp(builder, rewriter, n_operand, idx_type); - Value range_flat = - builder.create(zero_cst, n_operand_cst, one_cst); - Value range = BuildReshapeOp(builder, rewriter, range_flat, operand_shape, - idx_type, idx_type); - Value update_indices_raw = - BuildSliceOp(builder, rewriter, range, clamped_start_indices, - update_shape, idx_type, idx_type); - int64_t n_update = update_type.getNumElements(); - Type element_type = operand_type.getElementType(); - Value update_indices = BuildReshapeOp(builder, rewriter, update_indices_raw, - {n_update, 1}, idx_type, idx_type); - Value operand_flat = BuildReshapeOp(builder, rewriter, op.getOperand(), - {n_operand}, idx_type, element_type); - Value update_flat = BuildReshapeOp(builder, rewriter, op.getUpdate(), - {n_update}, idx_type, element_type); - Value flat_result = builder.create( - operand_flat, update_indices, update_flat); - - // Reshape back before return. 
- rewriter.replaceOpWithNewOp(op, operand_type, flat_result, - operand_shape_cst); + Value start_indices_tensor = rewriter.create( + op.getLoc(), shape_tensor_type, start_indices_vector); + rewriter.replaceOpWithNewOp( + op, op.getType(), op.getOperand(), op.getUpdate(), + start_indices_tensor); return success(); }; }; @@ -1547,21 +1515,26 @@ bool MatchIota(DenseIntElementsAttr dimensions, Value iota) { MatchIotaConst(dimensions, iota); } +template bool MatchTopKComparator(Region& comparator) { if (!comparator.hasOneBlock()) return false; Block& comparator_blk = comparator.front(); using OpListType = llvm::iplist; OpListType& operations = comparator_blk.getOperations(); if (operations.size() != 2) return false; - auto compare_op = dyn_cast_or_null(&operations.front()); - auto return_op = dyn_cast_or_null(&operations.back()); + auto compare_op = dyn_cast_or_null(&operations.front()); + auto return_op = dyn_cast_or_null(&operations.back()); if (!compare_op || !return_op) return false; // TODO(xuanyuanluo): Support mhlo::ComparisonDirection::LT direction. 
- if (compare_op.getComparisonDirection() != mhlo::ComparisonDirection::GT) + if (std::is_same_v && + dyn_cast_or_null(&operations.front()) + .getComparisonDirection() != mhlo::ComparisonDirection::GT) { return false; - if (compare_op.getLhs() != comparator_blk.getArgument(0) || - compare_op.getRhs() != comparator_blk.getArgument(1)) + } + if (compare_op.getOperands()[0] != comparator_blk.getArgument(0) || + compare_op.getOperands()[1] != comparator_blk.getArgument(1)) { return false; + } return return_op.getOperands().front() == compare_op.getResult(); } @@ -1612,7 +1585,8 @@ class ConvertSortToTfTopk : public OpConversionPattern { if (!MatchIota(sort_dim_attr, indices)) return rewriter.notifyMatchFailure( op, "the second operand is supposed to be obtained from IOTA"); - if (!MatchTopKComparator(op.getComparator())) + if (!MatchTopKComparator( + op.getComparator())) return rewriter.notifyMatchFailure(op, "only match for GT comparator"); ImplicitLocOpBuilder builder(op.getLoc(), rewriter); Value k_cst = BuildIntConstOp(builder, rewriter, k, rewriter.getI32Type()); @@ -1798,7 +1772,7 @@ Value ConvertDotGeneralOp(PatternRewriter& rewriter, Operation* old_op) { } // Checks if the specified region is a binary reduction function that takes 2 -// inputs, passes it to an instance of the specifiied reduction op and then +// inputs, passes it to an instance of the specified reduction op and then // returns the result. template LogicalResult MatchBinaryReduceFunction(mlir::Region& function) { @@ -1835,7 +1809,7 @@ LogicalResult MatchBinaryReduceFunction(mlir::Region& function) { } // Replace BinaryOp with a combination of TfBinaryOp and TfReduceOp if the -// init value doesn't match the expection of TfReduceOp. +// init value doesn't match the expectation of TfReduceOp. 
template LogicalResult rewriteNonMatchInitValue(mhlo::ReduceOp reduce_op, Value input, ConstOp reduction_indices, @@ -1849,7 +1823,7 @@ LogicalResult rewriteNonMatchInitValue(mhlo::ReduceOp reduce_op, Value input, return success(); } -// Cannot replace BinaryOp if the init value doesn't match the expection of +// Cannot replace BinaryOp if the init value doesn't match the expectation of // TfReduceOp and there is no corresponding TfBinaryOp. template <> LogicalResult rewriteNonMatchInitValue( @@ -2449,7 +2423,8 @@ class ConvertLoweredCumOp : public OpConversionPattern { } if (cumulative_axis == -1) { - return rewriter.notifyMatchFailure(rw, "no reduced dimension is found."); + rw.emitOpError() << "no reduced dimension is found."; + return failure(); } // For a cumulative op, padding (expressed as a list of left-padding and @@ -2993,6 +2968,10 @@ class ConvertGatherOp : public OpConversionPattern { LogicalResult matchAndRewrite( mhlo::GatherOp gather_op, OpAdaptor adaptor, ConversionPatternRewriter& rewriter) const final { + if (succeeded(ConvertGatherOpToSlice(gather_op, rewriter))) { + return success(); + } + Value operand = gather_op.getOperand(); Value start_indices = gather_op.getStartIndices(); @@ -3113,6 +3092,153 @@ class ConvertGatherOp : public OpConversionPattern { return success(); } + // Convert gather op to tf.slice and tf.concat + LogicalResult ConvertGatherOpToSlice( + mhlo::GatherOp gather_op, ConversionPatternRewriter& rewriter) const { + Value operand = gather_op.getOperand(); + Value start_indices = gather_op.getStartIndices(); + static const int rank_two = 2; + // This converts a gather op to multiple slice ops, cap the number of slice + // ops allowed. + static const int max_batch_size = 50; + + // Can only convert with static shaped gather. 
+ ShapedType operand_type = operand.getType().cast(); + ShapedType start_indices_type = start_indices.getType().cast(); + ShapedType result_type = gather_op.getResult().getType().cast(); + if (!operand_type.hasStaticShape() || + !start_indices_type.hasStaticShape() || !result_type.hasStaticShape()) { + gather_op.emitOpError() << "Dynamic shaped inputs are not supported."; + return failure(); + } + + auto start_index_map = gather_op.getDimensionNumbers().getStartIndexMap(); + auto collapsed_slice_dims = + gather_op.getDimensionNumbers().getCollapsedSliceDims(); + auto offset_dims = gather_op.getDimensionNumbers().getOffsetDims(); + auto slice_sizes = gather_op.getSliceSizes(); + llvm::SmallVector slice_sizes_vector; + slice_sizes_vector.reserve(slice_sizes.size()); + for (int64_t s : slice_sizes.getValues()) { + slice_sizes_vector.push_back(s); + } + + llvm::SmallVector batch_dims; + // Offset dims are guaranteed to be sorted. + int offset_index = 0; + for (int64_t i = 0; i < result_type.getRank(); ++i) { + if (offset_index >= offset_dims.size() || + offset_dims[offset_index] != i) { + batch_dims.push_back(i); + } else { + ++offset_index; + } + } + // Here we only support gather with one batch dim and the batch dim is 0. + if (batch_dims.size() != 1 || batch_dims[0] != 0) { + return failure(); + } + int64_t batch_dim = batch_dims[0]; + // Batch dim in operand and start indices should match. + if (operand_type.getDimSize(batch_dim) > max_batch_size || + operand_type.getRank() != rank_two || + start_indices_type.getRank() != rank_two || + operand_type.getDimSize(batch_dim) != + start_indices_type.getDimSize(batch_dim) || + slice_sizes_vector[batch_dim] != 1) { + return failure(); + } + // Here we only support the case where [0, 1] in start_indices maps to + // operand[0, 1] + for (int64_t i = 0; i < start_index_map.size(); i++) { + if (start_index_map[i] != i) { + return failure(); + } + } + // Collapsed slice dims should contain the batch dim. 
+ if (collapsed_slice_dims.size() != start_index_map.size() - 1 || + collapsed_slice_dims.size() != 1 || collapsed_slice_dims[0] != 0) { + return failure(); + } + + // Normalize start_indices so index_vector_dim == start_indices.rank() - 1. + int64_t index_vector_dim = + gather_op.getDimensionNumbers().getIndexVectorDim(); + if (failed(NormalizeIndexVector(gather_op, start_indices, + start_indices_type, index_vector_dim, + rewriter))) { + return failure(); + } + + ImplicitLocOpBuilder builder(gather_op.getLoc(), rewriter); + // Clamp the start indices to ensure it is in bounds. + auto max_start_indices = BuildIntArrayConstOp( + builder, rewriter, + llvm::SmallVector( + {operand_type.getDimSize(0) - slice_sizes_vector[0], + operand_type.getDimSize(1) - slice_sizes_vector[1]}), + start_indices_type.getElementType()); + auto min_start_indices = BuildIntArrayConstOp( + builder, rewriter, llvm::SmallVector({0, 0}), + start_indices_type.getElementType()); + auto start_indices_max_op = rewriter.create( + gather_op.getLoc(), start_indices, min_start_indices); + auto clamped_start_indices_op = rewriter.create( + gather_op.getLoc(), start_indices_max_op, max_start_indices); + + int64_t batch_size = start_indices_type.getDimSize(batch_dim); + auto slice_size = BuildIntArrayConstOp( + builder, rewriter, slice_sizes_vector, rewriter.getI32Type()); + if (batch_size == 1) { + auto squeeze_op = rewriter.create( + gather_op.getLoc(), + RankedTensorType::get({rank_two}, + start_indices_type.getElementType()), + clamped_start_indices_op, + rewriter.getI64ArrayAttr(llvm::ArrayRef({batch_dim}))); + auto slice_op = + rewriter.create(gather_op.getLoc(), gather_op.getType(), + operand, squeeze_op, slice_size); + rewriter.replaceOp(gather_op, {slice_op}); + return mlir::success(); + } + + llvm::SmallVector slices; + slices.reserve(batch_size); + for (int64_t i = 0; i < batch_size; ++i) { + auto zero = BuildIntArrayConstOp(builder, rewriter, + llvm::SmallVector({i, 0}), + 
rewriter.getI32Type()); + auto two = BuildIntArrayConstOp(builder, rewriter, + llvm::SmallVector({1, 2}), + rewriter.getI32Type()); + auto begin = rewriter.create( + gather_op.getLoc(), + RankedTensorType::get({1, 2}, start_indices_type.getElementType()), + clamped_start_indices_op, zero, two); + auto squeeze_op = rewriter.create( + gather_op.getLoc(), + RankedTensorType::get({rank_two}, + start_indices_type.getElementType()), + begin, + rewriter.getI64ArrayAttr(llvm::ArrayRef({batch_dim}))); + auto slice_op = rewriter.create( + gather_op.getLoc(), + RankedTensorType::get({1, slice_sizes_vector[1]}, + operand_type.getElementType()), + operand, squeeze_op, slice_size); + slices.push_back(slice_op); + } + auto scalar_type = RankedTensorType::get({}, rewriter.getI32Type()); + auto zero_scalar = rewriter.create( + gather_op.getLoc(), + DenseIntElementsAttr::get(scalar_type, static_cast(0))); + auto concat_op = rewriter.create( + gather_op.getLoc(), result_type, slices, zero_scalar); + rewriter.replaceOp(gather_op, {concat_op}); + return mlir::success(); + } + private: // Canonicalize the offset dims to make sure the offset dims are the trailing // dimensions of the output tensor. @@ -3332,11 +3458,41 @@ class ConvertScatterOp : public OpConversionPattern { loc, permutation_and_shape.shape, operands[0], permutation_and_shape.permutation); + Value new_indices = indices; + int64_t index_depth = + permutation_and_shape.shape.getRank() - inserted_window_dims.size(); + int64_t num_updates = indices_type.getDimSize(0); + // For TF::TensorScatterUpdateOp, `indices` must have at least 2 axes: + // `(num_updates, index_depth)`. Reshape indices and updates if necessary. 
+ if (std::is_same::value && + indices_type.getRank() == 1 && updates_type.getRank() == 1 && + index_depth == 1 && num_updates == 1) { + ImplicitLocOpBuilder builder(loc, rewriter); + auto indices_shape = BuildIntArrayConstOp( + builder, rewriter, + llvm::SmallVector({num_updates, index_depth}), + rewriter.getI32Type()); + new_indices = rewriter.create( + loc, + RankedTensorType::get({num_updates, index_depth}, + indices_type.getElementType()), + indices, indices_shape); + auto updates_shape = BuildIntArrayConstOp( + builder, rewriter, + llvm::SmallVector({num_updates, updates_type.getDimSize(0)}), + rewriter.getI32Type()); + new_updates = rewriter.create( + loc, + RankedTensorType::get({1, updates_type.getDimSize(0)}, + updates_type.getElementType()), + new_updates, updates_shape); + } + // Apply TF scatter to update the trailing dimensions of the // transposed operand. auto tf_scatter_op = rewriter.create(loc, permutation_and_shape.shape, - transposed_operand, indices, new_updates); + transposed_operand, new_indices, new_updates); // Reverse the earlier transpose. 
auto inverse_permutation = @@ -3398,6 +3554,161 @@ class ConvertPopulationCountOp } }; +class ConvertCustomCallWithApproxTopK + : public mlir::OpConversionPattern { + public: + explicit ConvertCustomCallWithApproxTopK(MLIRContext* context, + mlir::ModuleOp* module_op) + : OpConversionPattern(context), + module_op_(module_op) {} + + mlir::LogicalResult matchAndRewrite( + mhlo::CustomCallOp op, OpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const final { + if (op.getCallTargetName() != "ApproxTopK") { + return mlir::failure(); + } + auto is_supported_attr_name = [](NamedAttribute attr) { + auto name = attr.getName(); + return name == "call_target_name" || name == "backend_config" || + name == "api_version" || name == "called_computations"; + }; + for (const auto& attr : op->getAttrs()) { + if (!is_supported_attr_name(attr)) { + return op.emitOpError() + << attr.getName().getValue() + << " is not a supported attribute for ApproxTopK"; + } + } + auto backend_config = + op.getBackendConfigAttr().dyn_cast_or_null(); + if (!backend_config) { + return op.emitOpError() << "Missing backend_config attribute"; + } + + for (const auto& attr : backend_config) { + auto name = attr.getName(); + if (!(name == "top_k" || name == "reduction_dim" || + name == "recall_target" || name == "aggregate_to_topk" || + name == "reduction_input_size_override" || name == "is_fallback")) { + return op.emitOpError() + << name.getValue() << " is not a supported backend_config" + << " attribute for ApproxTopK"; + } + } + + auto check_i64_attr = + [&](const std::string& attr_name) -> mlir::LogicalResult { + if (!backend_config.contains(attr_name)) { + return op.emitOpError() + << "Missing " << attr_name << " attribute in backend_config"; + } + auto attr = backend_config.getAs(attr_name); + if (!attr || !attr.getType().isInteger(64)) { + return op.emitOpError() + << attr_name + << " attribute in backend_config must be of i64 type"; + } + return success(); + }; + auto check_f32_attr = + 
[&](const std::string& attr_name) -> mlir::LogicalResult { + if (!backend_config.contains(attr_name)) { + return op.emitOpError() + << "Missing " << attr_name << " attribute in backend_config"; + } + auto attr = backend_config.getAs(attr_name); + if (!attr || !attr.getType().isF32()) { + return op.emitOpError() + << attr_name + << " attribute in backend_config must be of f32 type"; + } + return success(); + }; + auto check_bool_attr = + [&](const std::string& attr_name) -> mlir::LogicalResult { + if (!backend_config.contains(attr_name)) { + return op.emitOpError() + << "Missing " << attr_name << " attribute in backend_config"; + } + auto attr = backend_config.getAs(attr_name); + if (!attr) { + return op.emitOpError() + << attr_name + << " attribute in backend_config must be of bool type"; + } + return success(); + }; + if (failed(check_i64_attr("top_k"))) return failure(); + if (failed(check_i64_attr("reduction_dim"))) return failure(); + if (failed(check_f32_attr("recall_target"))) return failure(); + if (failed(check_bool_attr("aggregate_to_topk"))) return failure(); + if (failed(check_i64_attr("reduction_input_size_override"))) { + return failure(); + } + bool has_is_fallback = backend_config.contains("is_fallback"); + if (has_is_fallback && !backend_config.getAs("is_fallback")) { + return op.emitOpError() + << "is_fallback attribute in backend_config must be of bool type"; + } + + auto top_k_attr = backend_config.getAs("top_k"); + auto reduction_dim_attr = + backend_config.getAs("reduction_dim"); + auto recall_target_attr = backend_config.getAs("recall_target"); + auto aggregate_to_topk_attr = + backend_config.getAs("aggregate_to_topk"); + auto reduction_input_size_override_attr = + backend_config.getAs("reduction_input_size_override"); + if (op.getInputs().size() % 2 != 0) { + return op.emitOpError() << "ApproxTopK takes an even number of operands."; + } + + auto called_computations = op.getCalledComputations(); + if (called_computations.size() != 1) { + 
return op.emitOpError() + << "ApproxTopK takes exactly 1 called_computation."; + } + mlir::func::FuncOp callee = module_op_->lookupSymbol( + op.getCalledComputations()[0].cast()); + mlir::FunctionType callee_type = callee.getFunctionType(); + SmallVector expected_callee_input_types; + auto num_inputs = op.getInputs().size() / 2; + for (unsigned i = 0; i < num_inputs; ++i) { + auto input_type = op.getOperand(i).getType().dyn_cast(); + auto scalar = RankedTensorType::get({}, input_type.getElementType()); + expected_callee_input_types.push_back(scalar); + expected_callee_input_types.push_back(scalar); + } + FunctionType expected_callee_type = mlir::FunctionType::get( + op->getContext(), expected_callee_input_types, + RankedTensorType::get({}, IntegerType::get(op->getContext(), 1))); + if (callee_type != expected_callee_type) { + return op.emitOpError() + << "called_computation type does not match the expected type. Got " + << callee_type << " expected " << expected_callee_type; + } + if (!MatchTopKComparator( + callee.getBody()) && + !MatchTopKComparator( + callee.getBody())) { + return op.emitOpError() << "only match for GT comparator"; + } + auto is_max_k = rewriter.getBoolAttr(true); + + auto approx_top_k = rewriter.create( + op.getLoc(), op->getResultTypes(), op.getInputs()[0], top_k_attr, + reduction_dim_attr, recall_target_attr, is_max_k, + reduction_input_size_override_attr, aggregate_to_topk_attr); + + rewriter.replaceOp(op, approx_top_k.getResults()); + return mlir::success(); + } + + private: + mlir::ModuleOp* module_op_; +}; + // Returns true if broadcast_dimensions obey Tensorflow convention, as in new // dimensions are added as prefix. bool IsTFStyleBroadcast(DenseIntElementsAttr broadcast_dimensions, @@ -3441,9 +3752,10 @@ arith::ConstantOp ExpandedShape(PatternRewriter& rewriter, Value input, /// Performs the lowering to XLA dialect. 
void LegalizeHloToTf::runOnOperation() { MLIRContext& context = getContext(); + mlir::ModuleOp module = getOperation()->getParentOfType(); - // Add legalization patterns to the list. RewritePatternSet patterns(&getContext()); + patterns.add(&context, &module); PopulateLegalizeHloToTfPatterns(&patterns, &context); ConversionTarget target(context); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo_patterns.td b/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo_patterns.td index 16493c30286..0261783da7c 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo_patterns.td +++ b/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo_patterns.td @@ -189,6 +189,9 @@ def : Pat<(MHLO_ReverseOp $op, $dims), (TF_ReverseV2Op $op, (TF_ConstOp $dims))> def : Pat<(MHLO_ReshapeOp:$output $input), (TF_ReshapeOp $input, (ShapeToConst $output))>; +// Both implement the Banker's rounding. +def : Pat<(MHLO_RoundNearestEvenOp $input), (TF_RoundOp $input)>; + //===----------------------------------------------------------------------===// // Ternary op patterns. //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/mark_ops_for_outside_compilation.cc b/tensorflow/compiler/mlir/tensorflow/transforms/mark_ops_for_outside_compilation.cc index 738a109dec6..ed1e896e2d4 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/mark_ops_for_outside_compilation.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/mark_ops_for_outside_compilation.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "llvm/ADT/SmallVector.h" #include "llvm/Support/FormatVariadic.h" #include "mlir/IR/OperationSupport.h" // from @llvm-project #include "mlir/IR/TypeUtilities.h" // from @llvm-project @@ -32,6 +33,7 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/string_util.h" #include "tensorflow/compiler/mlir/tf2xla/transforms/passes.h" #include "tensorflow/core/lib/monitoring/gauge.h" @@ -395,6 +397,38 @@ void UnmarkChildren(ModuleOp module) { }); } +constexpr int kTooManyOutsideCompileRegionThreshold = 32; +constexpr int kOpDetailCount = 8; + +void WarnOnExcessOutsideCompilationOps(ModuleOp module) { + // Count the number of outside compilation ops. If it exceeds the reporting + // threshold, warn the user that their model may run slowly. + llvm::SmallVector outside_compile_ops; + module->walk([&](Operation* op) { + if (op->getAttrOfType(kXlaOutsideCompilationAttr)) { + outside_compile_ops.push_back(op); + } + }); + + if (outside_compile_ops.size() > kTooManyOutsideCompileRegionThreshold) { + llvm::SmallVector op_info; + for (int i = 0; i < kOpDetailCount; ++i) { + auto& op = outside_compile_ops[i]; + op_info.push_back(tensorflow::OpAsString(*op)); + } + + LOG(WARNING) << outside_compile_ops.size() << " outside compilation " + << "regions found while processing " + << module->getName().getStringRef().str() + << ". This may result in excessively slow model execution. 
" + << "First " << op_info.size() + << " ops: " << absl::StrJoin(op_info, "\n"); + } else { + LOG(INFO) << "Found " << outside_compile_ops.size() + << " outside compilation regions."; + } +} + void MarkOpsForOutsideCompilation::runOnOperation() { auto module = getOperation(); const Dialect* tf_dialect = getContext().getLoadedDialect("tf"); @@ -446,6 +480,8 @@ void MarkOpsForOutsideCompilation::runOnOperation() { if (result.wasInterrupted()) return signalPassFailure(); UnmarkChildren(module); + + WarnOnExcessOutsideCompilationOps(module); } } // namespace diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h index 4092b90411e..a46a1204af2 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h +++ b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h @@ -323,6 +323,20 @@ std::unique_ptr> CreateSplitIntoIslandPerOpPass(); // CPU/GPU bridge. void CreateTFXLABridgePipeline(OpPassManager& pm); +//===----------------------------------------------------------------------===// +// XlaCallModule +//===----------------------------------------------------------------------===// + +// Creates a pass that deserializes functions in the StableHLO modules from +// `tf.XlaCallModule` to the top-level module. +std::unique_ptr> +CreateXlaCallModuleDeserializationPass(); + +// Creates a pass that serializes StableHLO functions referenced by +// `tf.XlaCallModule` from the top-level module to `tf.XlaCallModule`'s +// `module` attribute. +std::unique_ptr> CreateXlaCallModuleSerializationPass(); + } // namespace TF namespace tf_executor { @@ -435,7 +449,9 @@ std::unique_ptr> CreateReplicaIDToDeviceOrdinalPass(); // Creates a pass that adds pipelining to a graph that contains device -// accelerated embeddings. +// accelerated embeddings. The EmbeddingSequencingPass is a temporary fallback +// while developing full pipelining capabilities. 
+std::unique_ptr> CreateEmbeddingSequencingPass(); std::unique_ptr> CreateEmbeddingPipeliningPass(); // Creates a pass that creates `tf_executor.island` from a single @@ -491,6 +507,9 @@ std::unique_ptr> CreateXlaInlineDeviceOpsPass(); // Creates a pass that rewrites partitioned calls with `_xla_compile_device // type` with `tf.XlaLaunch` ops. std::unique_ptr> CreateXlaRewritePass(); + +// Create a pass that validates the input graph to the CPU/GPU bridge. +std::unique_ptr> CreateXlaValidateInputsPass(); } // namespace TFDevice namespace TFTPU { @@ -725,6 +744,9 @@ enum MoveTransposeDirection { kBegin, kEnd }; #define GEN_PASS_DECL_TRANSFORMEINSUMPASS #define GEN_PASS_DECL_UNROLLBATCHMATMULPASS #define GEN_PASS_DECL_VERIFYSUITABLEFOREXPORTPASS +#define GEN_PASS_DECL_XLACALLMODULEDESERIALIZATIONPASS +#define GEN_PASS_DECL_XLACALLMODULESERIALIZATIONPASS +#define GEN_PASS_DECL_XLACALLMODULECUSTOMCALLTFFUNCTIONRENAMINGPASS #include "tensorflow/compiler/mlir/tensorflow/transforms/tf_passes.h.inc" } // namespace detail using namespace detail; // NOLINT @@ -746,6 +768,7 @@ namespace TFDevice { #define GEN_PASS_DECL_XLACLUSTERFORMATIONPASS #define GEN_PASS_DECL_XLAINLINEDEVICEOPSPASS #define GEN_PASS_DECL_XLAREWRITEPASS +#define GEN_PASS_DECL_XLAVALIDATEINPUTSPASS #include "tensorflow/compiler/mlir/tensorflow/transforms/tf_device_passes.h.inc" } // namespace TFDevice diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc b/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc index a8bfb700209..e03eb9a9228 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc @@ -308,7 +308,7 @@ LogicalResult RegionResourceHoister::Analyze() { // Since all the sub-regions within this region (i.e., regions attached to // op's in this region) have themselves gone through lifting, all resource // users are expected to be operations in this region 
and not embedded - // within other sub-regions attached to op's in this region. So the check + // within other sub-regions attached to ops in this region. So the check // for whether a user is in one of the regions attached to this op is // straightforward. if (user->getParentRegion()->getParentOp() != op_) continue; @@ -1260,6 +1260,11 @@ void ResourceOpLiftingPass::runOnOperation() { }); if (walk_result.wasInterrupted()) return signalPassFailure(); + + // Clean up and canonicalize to remove dead local variables as some local + // variables might be dead after hoisting resource loads/stores. + if (failed(TF::CleanupAndCanonicalizeForResourceOpLifting(module))) + return signalPassFailure(); } #define GEN_PASS_DEF_RESOURCEOPLIFTINGFORMAINFUNCTIONPASS diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting_cleanup.cc b/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting_cleanup.cc index 99693a91b2f..2f1c675b305 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting_cleanup.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting_cleanup.cc @@ -16,6 +16,7 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting_cleanup.h" #include +#include #include "llvm/ADT/BitVector.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project @@ -52,22 +53,55 @@ void RemovePassthroughOp(Block &block) { } } +using LocalVarOp = std::variant; + +Value LocalVarOp_resource(LocalVarOp &op) { + if (auto var_handle_op = std::get_if(&op)) { + return var_handle_op->getResource(); + } else { + return std::get(op).getResource(); + } +} + +void LocalVarOp_erase(LocalVarOp &op) { + if (auto var_handle_op = std::get_if(&op)) { + var_handle_op->erase(); + } else { + std::get(op).erase(); + } +} + +std::optional IsLocalVarOp(Operation &op) { + if (TF::MlirLocalVarOp mlir_local_var_op = + llvm::dyn_cast(&op)) { + return std::make_optional(LocalVarOp(mlir_local_var_op)); + } + if (TF::VarHandleOp var_handle_op = llvm::dyn_cast(&op)) { + auto ANONYMOUS_NAME = ::tensorflow::ResourceHandle::ANONYMOUS_NAME; + if (var_handle_op.getSharedName() == ANONYMOUS_NAME) { + return std::make_optional(LocalVarOp(var_handle_op)); + } + } + return {}; +} + // Eliminate local variables that are only assigned to but never read, and thus // are dead. 
void RemoveDeadLocalVariables(Block &block) { - llvm::SmallVector local_vars; + llvm::SmallVector local_vars; for (Operation &op : block) { - if (auto local_var = llvm::dyn_cast(&op)) { - local_vars.push_back(local_var); + if (auto local_var = IsLocalVarOp(op)) { + local_vars.push_back(local_var.value()); } } for (auto local_var : local_vars) { - auto users = local_var.getResource().getUsers(); + auto users = LocalVarOp_resource(local_var).getUsers(); if (llvm::all_of(users, [](const Operation *user) { - return isa(user); + return isa(user) || + isa(user); })) { for (auto user : llvm::make_early_inc_range(users)) user->erase(); - local_var.erase(); + LocalVarOp_erase(local_var); } } } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc index 1edc7f4bb73..2cf33360e10 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc @@ -530,7 +530,7 @@ struct ValuePortHasher { }; using ValuePortResultMap = - std::unordered_map; + absl::flat_hash_map; using ComputedQueryFn = function_ref; using ValueQueryFn = function_ref; using ValuePortInputs = SmallVectorImpl; @@ -1200,14 +1200,24 @@ bool ShapeInference::InferShapeForXlaCallModule(XlaCallModuleOp op) { for (auto attr : op.getDimArgsSpec().getAsRange()) { dim_args_spec.push_back(attr.getValue().str()); } - + std::vector disabled_checks; + for (auto attr : op.getDisabledChecks().getAsRange()) { + disabled_checks.push_back(attr.getValue().str()); + } + std::vector platforms; + for (auto attr : op.getPlatforms().getAsRange()) { + platforms.push_back(attr.getValue().str()); + } // Always use the first platform. The assumption is that shape inference // results should be the same regardless of which platform is chosen. - int platform_index = op.getPlatforms().size() > 1 ? 
0 : -1; + // Very old versions of the op have an empty platforms attribute. + std::string loading_platform = + (platforms.empty() ? "CPU" : platforms.front()); auto l = tensorflow::XlaCallModuleLoader::Create( &xla_call_module_context_, op.getVersion(), op.getModule().str(), - std::move(dim_args_spec), platform_index); + std::move(dim_args_spec), std::move(disabled_checks), + std::move(platforms), std::move(loading_platform)); if (!l.ok()) { LLVM_DEBUG(llvm::dbgs() << "Parsing error in XlaCallModule: " << l.status().ToString() << "\n"); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tf_device_passes.td b/tensorflow/compiler/mlir/tensorflow/transforms/tf_device_passes.td index e307266c93f..e8d78b646cf 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tf_device_passes.td +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tf_device_passes.td @@ -390,3 +390,12 @@ def XlaRewritePass : Pass<"tf-xla-rewrite", "mlir::ModuleOp"> { let constructor = "TFDevice::CreateXlaRewritePass()"; let dependentDialects = ["tf_device::TensorFlowDeviceDialect"]; } + +def XlaValidateInputsPass : Pass<"tf-xla-validate-inputs", "ModuleOp"> { + let summary = "Validates inputs to the TF CPU/GPU bridge"; + let description = [{ + This pass checks that the IR has valid input to CPU/GPU TF/XLA bridge.
+ }]; + let constructor = "TFDevice::CreateXlaValidateInputsPass()"; + let dependentDialects = ["tf_device::TensorFlowDeviceDialect"]; +} diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tf_passes.td b/tensorflow/compiler/mlir/tensorflow/transforms/tf_passes.td index 839d9d601d9..93d2a9c708b 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tf_passes.td +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tf_passes.td @@ -402,6 +402,16 @@ def EmbeddingPipeliningPass : Pass<"tf-embedding-pipelining", "mlir::ModuleOp"> }]; } +def EmbeddingSequencingPass : Pass<"tf-embedding-sequencing", "mlir::ModuleOp"> { + let summary = "Rewrite graph for sequential execution of embeddings"; + let constructor = "TFDevice::CreateEmbeddingSequencingPass()"; + let description = [{ + This is a strictly sequential and formally correct fallback option for the + embedding pipelining pass intended for debugging during pipelining + development. + }]; +} + def ConvertReadonlyReferenceVariablesToResourceVariablesPass : Pass<"tf-readonly-references-to-resources", "mlir::func::FuncOp"> { let summary = "Convert readonly reference variables to resource variables."; @@ -2707,3 +2717,70 @@ def NameAnonymousIteratorsPass : Pass<"tf-name-anonymous-iterators", "ModuleOp"> }]; let constructor = "TF::CreateNameAnonymousIteratorsPass()"; } + +//===----------------------------------------------------------------------===// +// XlaCallModule +//===----------------------------------------------------------------------===// + +def XlaCallModuleDeserializationPass + : Pass<"tf-xla-call-module-deserialization", "ModuleOp"> { + let summary = "Deserializes StableHLO functions embedded in `tf.XlaCallModule` to top level module"; + + let description = [{ + This pass deserializes the StableHLO bytecodes embedded in tf.XlaCallModule, + then outlines the functions in the deserialized StableHLO module to the top + level MLIR module, with function renamings to avoid naming conflicts. 
+ + After the outlining, it updates tf.XlaCallModule's module attribute to be + empty, adds an `_entry_function` attribute referring to the entry function. + It also adds a `_from_xla_call_module: true` attribute to each lifted + StableHLO function. + }]; + + // These dialects are needed by stablehlo deserialization. + // + // We use tensorflow::XlaCallModuleLoader. + // tensorflow::XlaCallModuleLoader will get or load dialects: + // Func, Stablehlo, Mhlo, Chlo, and Vhlo. + // + // XlaCallModuleLoader uses mlir::stablehlo::deserializePortableArtifact, + // which runs VhloLegalizeToStablehloPass whose depends on dialects: + // Func, Stablehlo, Shape, and Quantization. + // + // If we do not register them here, an error will be + // triggered because we cannot load a dialect while in a + // multi-threaded execution context, and PassManager is + // multi-threaded. + let dependentDialects = [ + "chlo::ChloDialect", + "mhlo::MhloDialect", + "shape::ShapeDialect", + "stablehlo::StablehloDialect", + "vhlo::VhloDialect", + "quant::QuantizationDialect", + ]; + + let constructor = "TF::CreateXlaCallModuleDeserializationPass()"; +} + +def XlaCallModuleSerializationPass + : Pass<"tf-xla-call-module-serialization", "ModuleOp"> { + let summary = "Serializes StableHLO functions from top-level module into `tf.XlaCallModule`'s `module` attribute"; + + let description = [{ + This pass collects StableHLO functions referenced from `tf.XlaCallModule`'s + `_entry_function` attribute into a module, serializes the module into MLIR + bytecode, and embed the bytecode to `tf.XlaCallModule`'s `module` attribute. + + After serialization, this pass removes the `_entry_function` attribute from + `tf.XlaCallModule`, and removes all the serialized stablehlo functions + from the top-level module. 
+ }]; + + let dependentDialects = [ + "stablehlo::StablehloDialect", + "vhlo::VhloDialect", + ]; + + let constructor = "TF::CreateXlaCallModuleSerializationPass()"; +} diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_asset_sinking_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_asset_sinking_pass.cc new file mode 100644 index 00000000000..cb7d8ea0c21 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_asset_sinking_pass.cc @@ -0,0 +1,142 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_asset_sinking_pass.h" + +#include + +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/StringRef.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/SymbolTable.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" +#include "tensorflow/tsl/platform/path.h" + +namespace mlir { +namespace tf_saved_model { +namespace { + +#define GEN_PASS_DEF_ASSETSINKINGPASS +#define GEN_PASS_DECL_ASSETSINKINGPASS +#include "tensorflow/compiler/mlir/tensorflow/transforms/tf_savedmodel_passes.h.inc" + +class AssetSinkingPass : public impl::AssetSinkingPassBase { + public: + AssetSinkingPass() = default; + + explicit AssetSinkingPass(llvm::StringRef saved_model_dir) { + saved_model_dir_ = saved_model_dir.str(); + } + + void runOnOperation() override { + mlir::ModuleOp module = getOperation(); + if (!mlir::tf_saved_model::HasTfSavedModelSemantics(module)) { + return; + } + + auto init_op = mlir::tf_saved_model::GetSessionInitializerOp(module); + if (init_op == nullptr || init_op.getInitializers().empty()) { + return; + } + + mlir::SymbolTable symbol_table(module); + for (auto initializer : 
init_op.getInitializers()) { + auto func = symbol_table.lookup( + initializer.cast().getValue()); + RewriteFunction(symbol_table, func); + } + + // Clean up unused asset ops. + for (auto asset : llvm::make_early_inc_range( + module.getOps())) { + if (symbol_table.symbolKnownUseEmpty(asset, module)) { + asset.erase(); + } + } + } + + private: + // Replaces bounded-input arguments of the function with constant ops in the + // body and removes the arguments. + void RewriteFunction(const mlir::SymbolTable& symbol_table, + mlir::func::FuncOp func) { + if (func.getNumArguments() == 0) { + return; + } + + auto builder = mlir::OpBuilder::atBlockBegin(&func.front()); + + llvm::SmallDenseMap const_ops; + llvm::BitVector arg_indexes_to_remove(func.getNumArguments()); + + // Replace arguments with const ops. + for (mlir::BlockArgument argument : func.getArguments()) { + auto asset = mlir::tf_saved_model::LookupBoundInputOfType< + mlir::tf_saved_model::AssetOp>(func, argument.getArgNumber(), + symbol_table); + if (asset == nullptr) { + continue; + } + + // Create a const op for the asset if it doesn't already exist. + auto it = const_ops.find(asset.getSymName()); + if (it == const_ops.end()) { + // Asset filenames are relative to the SavedModel directory. + const std::string filename = tsl::io::JoinPath( + saved_model_dir_, absl::string_view(asset.getFilename())); + + mlir::RankedTensorType type = mlir::RankedTensorType::get( + {}, mlir::TF::StringType::get(builder.getContext())); + auto const_op = builder.create( + builder.getUnknownLoc(), + mlir::DenseStringElementsAttr::get(type, {filename})); + + it = const_ops.insert({asset.getSymName(), const_op}).first; + } + + argument.replaceAllUsesWith(it->second.getOutput()); + arg_indexes_to_remove.set(argument.getArgNumber()); + } + + // Erase function arguments with bounded input.
+ func.eraseArguments(arg_indexes_to_remove); + } +}; + +} // namespace + +std::unique_ptr> CreateAssetSinkingPass( + llvm::StringRef saved_model_dir) { + return std::make_unique(saved_model_dir); +} + +} // namespace tf_saved_model +} // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_asset_sinking_pass.h b/tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_asset_sinking_pass.h new file mode 100644 index 00000000000..a14e98e483f --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_asset_sinking_pass.h @@ -0,0 +1,35 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_TF_SAVED_MODEL_ASSET_SINKING_PASS_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_TF_SAVED_MODEL_ASSET_SINKING_PASS_H_ + +#include + +#include "llvm/ADT/StringRef.h" +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project + +namespace mlir { +namespace tf_saved_model { + +// Creates a pass that sinks SavedModel asset filenames to constants. 
+std::unique_ptr> CreateAssetSinkingPass( + llvm::StringRef saved_model_dir); + +} // namespace tf_saved_model +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_TF_SAVED_MODEL_ASSET_SINKING_PASS_H_ diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_passes.h b/tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_passes.h index 3cdabf52246..801eaaeb0ae 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_passes.h +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_passes.h @@ -21,6 +21,7 @@ limitations under the License. #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_asset_sinking_pass.h" #include "tensorflow/core/public/session.h" namespace mlir { diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tf_savedmodel_passes.td b/tensorflow/compiler/mlir/tensorflow/transforms/tf_savedmodel_passes.td index 302cdf29bc3..2e190255080 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tf_savedmodel_passes.td +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tf_savedmodel_passes.td @@ -162,3 +162,28 @@ def AddFunctionsForExportedNamesPass : Pass<"tf-saved-model-add-functions-for-ex }]; let constructor = "::mlir::tf_saved_model::CreateAddFunctionsForExportedNamesPass()"; } + +def AssetSinkingPass : Pass<"tf-saved-model-asset-sinking", "mlir::ModuleOp"> { + let summary = "Sinks SavedModel asset filenames to constants"; + + let description = [{ + This pass sinks arguments of SavedModel methods that are bounded to + `tf_saved_model.asset` into constants in the methods. After the pass, unused + asset ops are removed from the module. 
+ + This is to convert initialization methods with bound inputs into the same + methods without any arguments, so that program invocation doesn't need to + track and explicitly pass asset filenames. + + This pass accepts an option `saved-model-dir`, which specifies the directory + where SavedModel is stored. This is a required option because all asset + filenames are relative to this directory. + }]; + + let constructor = "::mlir::tf_saved_model::CreateAssetSinkingPass(\"\")"; + + let options = [ + Option<"saved_model_dir_", "saved-model-dir", "std::string", "", + "SavedModel directory, which is prepended to asset file names.">, + ]; +} diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_cluster_formation.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_cluster_formation.cc index bb2f8f26b65..bbbb92db49e 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_cluster_formation.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_cluster_formation.cc @@ -17,7 +17,9 @@ limitations under the License. #include #include #include +#include #include +#include #include #include #include @@ -36,10 +38,12 @@ limitations under the License. #include "llvm/ADT/iterator_range.h" #include "llvm/Support/Casting.h" #include "llvm/Support/FormatVariadic.h" +#include "llvm/Support/raw_ostream.h" #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/OperationSupport.h" // from @llvm-project #include "mlir/IR/Types.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project @@ -51,6 +55,7 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/string_util.h" #include "tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h" #include "tensorflow/core/util/device_name_utils.h" @@ -134,6 +139,11 @@ LogicalResult CollectMetadata(Block* block, MetadataMap* metadata_map) { return success(); } +struct OpDevice { + Operation* op; + std::string device; +}; + // Collects and clusters ops either based on `_replication_info` attribute // (replicated case) or using one single cluster (non-replicated case). Also // sets `device_type` if there is any cluster (note that the device type must be @@ -147,7 +157,7 @@ LogicalResult CollectAndGroupClusterOps(Block* block, ClusterMap* clusters, bool has_local_device_name_collisions = false; // Use ordered set here to make error message below deterministic. std::set device_types; - std::unordered_map devices; + absl::flat_hash_map devices; for (Operation& op : *block) { LogicalResult result = TF::HasValidCompilationAndReplicationAttributes(op); if (failed(result)) return result; @@ -188,10 +198,25 @@ LogicalResult CollectAndGroupClusterOps(Block* block, ClusterMap* clusters, } auto device_attr = op.getAttrOfType(kDeviceAttr); std::string device_local_name; + bool is_tpu_device = false; if (device_attr && !device_attr.str().empty()) { + tensorflow::DeviceNameUtils::ParsedName parsed; + if (!tensorflow::DeviceNameUtils::ParseFullOrLocalName(device_attr.str(), + &parsed)) { + op.emitWarning() << "Invalid device name " << device_attr.str(); + return failure(); + } + device_local_name = - tensorflow::DeviceNameUtils::LocalName(device_attr.str()); + tensorflow::DeviceNameUtils::LocalName(parsed.type, parsed.id); + is_tpu_device = parsed.type == "TPU"; } + + // Ignore non-TPU devices when clustering. 
+ if (!is_tpu_device) { + continue; + } + if (!has_replicated_compiled_op && !device_local_name.empty()) { // It is possible that a device may be same Local Name but // different fullname. Devices with same Local name are identical @@ -200,24 +225,30 @@ LogicalResult CollectAndGroupClusterOps(Block* block, ClusterMap* clusters, // information such as task, replica, job etc. An example fullname is // "/job:foo_bar/replica:1/task:2/device:GPU:3" if (devices.count(device_local_name)) { - std::string device1 = devices[device_local_name]; + std::string device1 = devices[device_local_name].device; std::string device2 = device_attr.str(); // Is either of the two devices just a substring of the other? If // not, we treat them as different devices, and we have a collision. if (device1.find(device2) == std::string::npos && device2.find(device1) == std::string::npos) { + Operation* previous_op = devices[device_local_name].op; has_local_device_name_collisions = true; - LOG(WARNING) << "found two devices with same local name " + + LOG(WARNING) << "Found two devices with same local name " << device_local_name << " but conflicting fullname: " << device1 << " and " - << device2; + << device2 << "."; + LOG(WARNING) << "Previous assignment came from op: " + << tensorflow::OpAsString(*previous_op) + << ". Current op is: " << tensorflow::OpAsString(op); } // Always keep the longer name. 
- if (devices[device_local_name].size() < device_attr.str().size()) { - devices[device_local_name] = device_attr.str(); + if (devices[device_local_name].device.size() < + device_attr.str().size()) { + devices[device_local_name] = {&op, device_attr.str()}; } } else { - devices.insert({device_local_name, device_attr.str()}); + devices.insert({device_local_name, {&op, device_attr.str()}}); } } } @@ -237,13 +268,14 @@ LogicalResult CollectAndGroupClusterOps(Block* block, ClusterMap* clusters, if (devices.size() > 1) { LOG(WARNING) << "found different devices for no replication: "; for (const auto& device_names : devices) { - LOG(WARNING) << device_names.first << ", " << device_names.second; + LOG(WARNING) << device_names.first << ", " + << device_names.second.device; } } else if (has_local_device_name_collisions) { LOG(WARNING) << "Not assigning device because of conflicting fullnames."; } else if (devices.size() == 1 && - absl::StrContains(devices.begin()->second, "TPU:")) { - device = devices.begin()->second; + absl::StrContains(devices.begin()->second.device, "TPU:")) { + device = devices.begin()->second.device; } } if (!clusters->empty()) { @@ -697,7 +729,7 @@ LogicalResult ReplicateCluster(tf_device::ClusterOp cluster, int num_replicas, } } - // Create `ordered_tpu_replicate_inputs` which constains the final ordered + // Create `ordered_tpu_replicate_inputs` which contains the final ordered // replicate inputs. All packed arguments are moved to the end of the arg // list. llvm::SmallVector ordered_tpu_replicate_inputs = diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/xla_call_module_deserialization.cc b/tensorflow/compiler/mlir/tensorflow/transforms/xla_call_module_deserialization.cc new file mode 100644 index 00000000000..8e9431f7391 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/transforms/xla_call_module_deserialization.cc @@ -0,0 +1,280 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include +#include + +#include "absl/strings/str_format.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/StringRef.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project // IWYU pragma: keep +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/OwningOpRef.h" // from @llvm-project +#include "mlir/IR/SymbolTable.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "stablehlo/dialect/ChloOps.h" // from @stablehlo // IWYU pragma: keep +#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo // IWYU pragma: keep +#include "stablehlo/dialect/VhloOps.h" // from @stablehlo // IWYU pragma: keep +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/stablehlo_custom_call.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/xla_call_module_attrs.h" +#include "tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.h" +#include "tensorflow/compiler/xla/mlir_hlo/mhlo/IR/hlo_ops.h" // IWYU pragma: keep +#include "tensorflow/tsl/platform/statusor.h" + 
+namespace mlir { +namespace TF { +namespace { + +#define GEN_PASS_DEF_XLACALLMODULEDESERIALIZATIONPASS +#include "tensorflow/compiler/mlir/tensorflow/transforms/tf_passes.h.inc" + +// `tf.backend_config` is a DictionaryAttr, JAX2TF sets the value of its +// i64 attribute `called_index` to the TF function's name. +constexpr llvm::StringRef kTfBackendConfigAttrName = "tf.backend_config"; +constexpr llvm::StringRef kCalledIndexAttrName = "called_index"; +constexpr llvm::StringRef kCalledFuncAttrName = "called_func"; + +// The function name format for the deserialized stablehlo functions: +// _stablehlo_{original function name}_{index}. +constexpr const char *kNewFuncNameFormat = "_stablehlo_%s_%d"; + +// Deserialize the StableHLO module embedded in XlaCallModuleOp's module +// attribute. +tsl::StatusOr> DeserializeStablehlo(MLIRContext *context, + XlaCallModuleOp op) { + std::vector dim_args_spec; + for (auto attr : op.getDimArgsSpec().getAsRange()) { + dim_args_spec.push_back(attr.getValue().str()); + } + std::vector disabled_checks; + for (auto attr : op.getDisabledChecks().getAsRange()) { + disabled_checks.push_back(attr.getValue().str()); + } + std::vector platforms; + for (auto attr : op.getPlatforms().getAsRange()) { + platforms.push_back(attr.getValue().str()); + } + // XlaCallModuleOp OpKernel will determine platform index when running + // TF2XLA. We don't know the device/platform type in this MLIR pass, so + // we set loading_platform to the first platform. + std::string loading_platform = + (platforms.empty() ? "CPU" : platforms.front()); + TF_ASSIGN_OR_RETURN( + auto loader, + tensorflow::XlaCallModuleLoader::Create( + context, static_cast(op.getVersion()), op.getModule().str(), + std::move(dim_args_spec), std::move(disabled_checks), + std::move(platforms), std::move(loading_platform))); + return std::move(*loader).module(); +} + +// Returns a new function name in the kNewFuncNameFormat. +// The new name is unique in the symbol table. 
+std::string NewFuncName(const SymbolTable &symbol_table, + const llvm::StringRef func_name) { + uint64_t index = 0; + std::string new_func_name; + do { + new_func_name = absl::StrFormat(kNewFuncNameFormat, func_name, index++); + } while (symbol_table.lookup(new_func_name)); + return new_func_name; +} + +// Renames functions in the stablehlo module to avoid naming conflicts with +// existing functions in the tf module. +// Sets _from_xla_call_module attribute for each stablehlo function. +// Returns the new stablehlo main function's name or error. +// +// If we directly insert stablehlo functions into tf module, MLIR will rename +// the stablehlo functions themselves in the tf module automatically to avoid +// naming conflicts. But we need to rename the function calls inside the +// stablehlo functions as well. So we first do this renaming in the stablehlo +// module itself without inserting into the tf module. +FailureOr RenameStablehloFunctions( + MLIRContext *context, SymbolTableCollection &symbol_tables, + ModuleOp tf_module, ModuleOp stablehlo_module) { + SymbolTable &tf_sym_table = symbol_tables.getSymbolTable(tf_module); + SymbolTable &stablehlo_sym_table = + symbol_tables.getSymbolTable(stablehlo_module); + Builder builder(context); + StringAttr new_main_func_name; + for (auto func : stablehlo_module.getOps()) { + auto new_func_name = + builder.getStringAttr(NewFuncName(tf_sym_table, func.getSymName())); + if (func.getSymName() == kStablehloMainFunctionName) { + new_main_func_name = new_func_name; + } + if (failed(stablehlo_sym_table.replaceAllSymbolUses(func, new_func_name, + stablehlo_module))) { + return failure(); + } + func.setName(new_func_name); + func->setAttr(kFromXlaCallModuleAttrName, builder.getUnitAttr()); + } + return new_main_func_name; +} + +// Moves functions from one module to another. +// The moved functions are set to private. 
+void MoveFunctions(SymbolTableCollection &symbol_tables, ModuleOp from, + ModuleOp to) { + SymbolTable &to_sym_table = symbol_tables.getSymbolTable(to); + for (auto func : llvm::make_early_inc_range(from.getOps())) { + func->remove(); + func.setPrivate(); + to_sym_table.insert(func); + } +} + +void CopyStablehloModuleAttrs(ModuleOp stablehlo_module, XlaCallModuleOp op) { + op->setAttr(kStablehloModuleAttrsAttrName, + stablehlo_module->getAttrDictionary()); +} + +// Symbolizes `called_index` attributes in custom call ops to `called_func`. +LogicalResult SymbolizeCustomCallCalledIndex( + ModuleOp module, llvm::ArrayRef function_list) { + WalkResult result = + module.walk([&](stablehlo::CustomCallOp op) { + if (!IsTfFuncCustomCall(op)) { + return WalkResult::advance(); + } + + auto backend_config = + op->getAttrOfType(kTfBackendConfigAttrName); + if (!backend_config) { + op->emitOpError() + << "is missing attribute '" << kTfBackendConfigAttrName << "'"; + return WalkResult::interrupt(); + } + + auto called_index_attr = backend_config.get(kCalledIndexAttrName) + .dyn_cast_or_null(); + if (!called_index_attr) { + op->emitOpError() + << "is missing attribute '" << kCalledIndexAttrName << "'"; + return WalkResult::interrupt(); + } + int called_index = called_index_attr.getInt(); + if (called_index < 0 || called_index >= function_list.size()) { + op->emitOpError() + << "references function #" << called_index + << " but enclosing XlaCallModule has a function list of size " + << function_list.size(); + return WalkResult::interrupt(); + } + + llvm::SmallVector new_config; + // Copy the attributes in the current config except `called_index`. + for (auto attr : backend_config) { + if (attr.getName() != kCalledIndexAttrName) { + new_config.push_back(attr); + } + } + + Builder builder(op.getContext()); + // Sets the `called_index` attribute to the TF function's name.
+ new_config.push_back(builder.getNamedAttr(kCalledFuncAttrName, + function_list[called_index])); + + // Sets the `tf.backend_config` attribute to the `new_config`. + op->setAttr(kTfBackendConfigAttrName, + builder.getDictionaryAttr(new_config)); + + return WalkResult::advance(); + }); + return result.wasInterrupted() ? failure() : success(); +} + +LogicalResult DeserializeXlaCallModule(MLIRContext *context, + SymbolTableCollection &symbol_tables, + ModuleOp module, XlaCallModuleOp op) { + auto deserialized = DeserializeStablehlo(context, op); + if (!deserialized.ok()) { + return op.emitOpError() + << "failed to deserialize StableHLO module from XlaCallModule: " + << deserialized.status().ToString(); + } + OwningOpRef stablehlo_module = *std::move(deserialized); + + CopyStablehloModuleAttrs(*stablehlo_module, op); + + auto main_func = RenameStablehloFunctions(context, symbol_tables, module, + stablehlo_module.get()); + if (failed(main_func)) { + return failure(); + } + + MoveFunctions(symbol_tables, *stablehlo_module, module); + + // Translate `called_index` in TF function custom calls into symbol + // references. `function_list` attribute is no longer needed after that. + SmallVector function_list( + op.getFunctionList().getAsRange()); + if (failed(SymbolizeCustomCallCalledIndex(module, function_list))) { + return failure(); + } + op.removeFunctionListAttr(); + + // Module is deserialized, we set an empty string to it instead of removing + // it because it's a required attribute. + op.setModule(""); + // Set the stablehlo main function as a symbol attribute. + // This is required because we not only need this to look up the + // stablehlo function called by XlaCallModule, but also need the symbol + // reference to prevent DCE from removing the stablehlo functions from the + // top-level module.
+ op->setAttr(kStablehloEntryFunctionAttrName, + SymbolRefAttr::get(main_func.value())); + + return success(); +} + +class XlaCallModuleDeserializationPass + : public impl::XlaCallModuleDeserializationPassBase< + XlaCallModuleDeserializationPass> { + public: + void runOnOperation() override { + ModuleOp module = getOperation(); + SymbolTableCollection symbol_tables; + WalkResult result = module.walk([&](XlaCallModuleOp op) { + if (failed(DeserializeXlaCallModule(&getContext(), symbol_tables, module, + op))) { + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + if (result.wasInterrupted()) { + return signalPassFailure(); + } + } +}; + +} // namespace + +std::unique_ptr> +CreateXlaCallModuleDeserializationPass() { + return std::make_unique(); +} + +} // namespace TF +} // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/xla_call_module_serialization.cc b/tensorflow/compiler/mlir/tensorflow/transforms/xla_call_module_serialization.cc new file mode 100644 index 00000000000..a75bf4c75d8 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/transforms/xla_call_module_serialization.cc @@ -0,0 +1,260 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include +#include + +#include "llvm/ADT/StringRef.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/OwningOpRef.h" // from @llvm-project +#include "mlir/IR/SymbolTable.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "stablehlo/api/PortableApi.h" // from @stablehlo +#include "stablehlo/dialect/Serialization.h" // from @stablehlo +#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo // IWYU pragma: keep +#include "stablehlo/dialect/VhloOps.h" // from @stablehlo // IWYU pragma: keep +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/stablehlo_custom_call.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/visitor.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/xla_call_module_attrs.h" + +namespace mlir { +namespace TF { +namespace { + +#define GEN_PASS_DEF_XLACALLMODULESERIALIZATIONPASS +#include "tensorflow/compiler/mlir/tensorflow/transforms/tf_passes.h.inc" // IWYU pragma: keep + +// `tf.backend_config` is a DictionaryAttr, JAX2TF sets the value of its +// i64 attribute `called_index` to the TF function's name. +constexpr llvm::StringRef kTfBackendConfigAttrName = "tf.backend_config"; +constexpr llvm::StringRef kCalledIndexAttrName = "called_index"; +constexpr llvm::StringRef kCalledFuncAttrName = "called_func"; + +// Converts `called_func` attributes in custom call ops back to `called_index`. 
+FailureOr DesymbolizeCustomCallCalledIndex(ModuleOp module) { + Builder builder(module.getContext()); + + SmallVector function_list; + llvm::DenseMap called_indexes; + + WalkResult result = module.walk([&](stablehlo::CustomCallOp op) { + if (!IsTfFuncCustomCall(op)) { + return WalkResult::advance(); + } + + auto backend_config = + op->getAttrOfType(kTfBackendConfigAttrName); + if (!backend_config) { + op->emitOpError() << "is missing attribute '" << kTfBackendConfigAttrName + << "'"; + return WalkResult::interrupt(); + } + auto called_func = backend_config.get(kCalledFuncAttrName) + .dyn_cast_or_null(); + if (!called_func) { + op->emitOpError() << "is missing attribute '" << kCalledFuncAttrName + << "'"; + return WalkResult::interrupt(); + } + + llvm::SmallVector new_config; + // Copy the attributes in the current config except `called_func`. + for (auto attr : backend_config) { + if (attr.getName() != kCalledFuncAttrName) { + new_config.push_back(attr); + } + } + + auto [it, inserted] = + called_indexes.insert({called_func, called_indexes.size()}); + if (inserted) { + function_list.push_back(called_func); + } + + // Set the `called_index` attribute to the TF function's name. + new_config.push_back(builder.getNamedAttr( + kCalledIndexAttrName, builder.getI64IntegerAttr(it->second))); + + // Set the `tf.backend_config` attribute to the `new_config`. + op->setAttr(kTfBackendConfigAttrName, + builder.getDictionaryAttr(new_config)); + + return WalkResult::advance(); + }); + if (result.wasInterrupted()) { + return failure(); + } + + return builder.getArrayAttr(function_list); +} + +// Creates a pruned module containing the XlaCallModule's entry function and +// other functions transitively called by the entry function. 
+FailureOr> PruneStablehloModule( + SymbolTableCollection& symbol_table, ModuleOp module, XlaCallModuleOp op) { + auto entry_func_symbol = + op->getAttrOfType(kStablehloEntryFunctionAttrName); + if (!entry_func_symbol) { + return op.emitOpError() << "does not have " + << kStablehloEntryFunctionAttrName << " attribute"; + } + auto entry_func = + symbol_table.lookupSymbolIn(module, entry_func_symbol); + if (!entry_func) { + return op.emitOpError() + << "references an unknown entry function " << entry_func_symbol; + } + + OpBuilder builder(module.getContext()); + + OwningOpRef stablehlo_module = + builder.create(op.getLoc()); + builder.setInsertionPointToEnd(stablehlo_module->getBody()); + + // Copy all referenced StableHLO functions to the new module. + WalkResult result = WalkReachableFunctions( + entry_func, + [&](func::FuncOp f) -> WalkResult { + if (!f->hasAttr(kFromXlaCallModuleAttrName)) { + return WalkResult::advance(); + } + + auto cloned = llvm::cast(builder.clone(*f)); + cloned->removeAttr(kFromXlaCallModuleAttrName); + + if (f == entry_func) { + // Entry function must be public and has symbol name "@main". + cloned.setPublic(); + cloned.setName(kStablehloMainFunctionName); + } else { + cloned.setPrivate(); + } + + return WalkResult::advance(); + }, + &symbol_table); + if (result.wasInterrupted()) { + return failure(); + } + + // Rewrite `custom_call`'s `called_func` attribute to `called_index`. + auto function_list = DesymbolizeCustomCallCalledIndex(*stablehlo_module); + if (failed(function_list)) return failure(); + op.setFunctionListAttr(*function_list); + + // Restore the deserialized stablehlo module's attributes to the reconstructed + // stablehlo module. The stablehlo module's attributes can contain important + // information such as SPMD num_replicas and num_partitions. 
+ auto original_stablehlo_module_attrs = + op->getAttrOfType(kStablehloModuleAttrsAttrName); + if (original_stablehlo_module_attrs) { + (*stablehlo_module)->setAttrs(original_stablehlo_module_attrs); + // Now, remove the attribute because later passes may not know how to handle + // it, we may encounter errors such as: + // "Unhandled attribute kind for attribute '_stablehlo_module_attrs'". + op->removeAttr(kStablehloModuleAttrsAttrName); + } + + return stablehlo_module; +} + +// Serializes the stablehlo module into bytecode. +FailureOr SerializeStablehlo(ModuleOp stablehlo_module) { + std::string bytecode; + llvm::raw_string_ostream os(bytecode); + if (mlir::failed(stablehlo::serializePortableArtifact( + stablehlo_module, stablehlo::getCurrentVersion(), os))) { + return stablehlo_module.emitError() + << "failed to serialize the pruned stablehlo module"; + } + return bytecode; +} + +// Serializes the stablehlo functions called by XlaCallModuleOp to bytecode +// and embeds the bytecode in XlaCallModuleOp's `module` attribute. +// +// The stablehlo functions include the function referred by XlaCallModuleOp's +// `_entry_function` attribute, and any stablehlo functions called transitively +// from the entry function. +LogicalResult SerializeXlaCallModule(SymbolTableCollection& symbol_table, + ModuleOp module, XlaCallModuleOp op) { + auto stablehlo_module = PruneStablehloModule(symbol_table, module, op); + if (failed(stablehlo_module)) { + return failure(); + } + + auto bytecode = SerializeStablehlo(**stablehlo_module); + if (failed(bytecode)) { + return failure(); + } + + op.setModule(*bytecode); + op->removeAttr(kStablehloEntryFunctionAttrName); + + return success(); +} + +// Removes the serialized stablehlo functions, because `XlaCallModuleOp` no +// longer has `_entry_function` attribute referencing the stablehlo main +// function, so all stablehlo functions are of no use in the top-level module. 
+// +// Walk the module to find functions with `_from_xla_call_module` attribute, +// and remove them. +void RemoveSerializedStablehloFunctions(ModuleOp module) { + module.walk([&](func::FuncOp f) { + if (f->hasAttr(kFromXlaCallModuleAttrName)) { + f->erase(); + } + }); +} + +class XlaCallModuleSerializationPass + : public impl::XlaCallModuleSerializationPassBase< + XlaCallModuleSerializationPass> { + public: + void runOnOperation() override { + mlir::ModuleOp module = getOperation(); + mlir::SymbolTableCollection symbol_table; + + mlir::WalkResult result = + module.walk([&](mlir::TF::XlaCallModuleOp xla_call_module) { + if (failed(SerializeXlaCallModule(symbol_table, module, + xla_call_module))) { + return mlir::WalkResult::interrupt(); + } + return mlir::WalkResult::advance(); + }); + if (result.wasInterrupted()) { + return signalPassFailure(); + } + + RemoveSerializedStablehloFunctions(module); + } +}; + +} // namespace + +std::unique_ptr> +CreateXlaCallModuleSerializationPass() { + return std::make_unique(); +} + +} // namespace TF +} // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/xla_cluster_formation.cc b/tensorflow/compiler/mlir/tensorflow/transforms/xla_cluster_formation.cc index f876231ab00..03e05816992 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/xla_cluster_formation.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/xla_cluster_formation.cc @@ -20,11 +20,9 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" -#include "tensorflow/compiler/mlir/tensorflow/transforms/call_graph_util.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" #include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h" - -inline constexpr absl::string_view kEntryFunctionAttr = "tf.entry_function"; +#include "tensorflow/compiler/mlir/tensorflow/utils/call_graph_util.h" namespace mlir { @@ -63,37 +61,25 @@ void EncapsulatePartitionedCall(Operation *call_op) { } void XlaClusterFormationPass::runOnOperation() { + auto has_compile_device_type = [](SymbolUserOpInterface op) { + return op->hasAttr(tensorflow::kCompileDeviceTypeAttr); + }; + ModuleOp module = getOperation(); SymbolTable symtab(module); - - llvm::SmallVector entry_funcs; - // A model may have multiple graphs, with each graph having its own entry. - // When a graph is imported to MLIR, `tf.entry_function` will be added to - // each entry function. The one exception are initializer functions, which - // have `tf_saved_model.initializer_type` instead. 
- module.walk([&](func::FuncOp func) { - if (func->hasAttr(kEntryFunctionAttr) || - func->hasAttr(tf_saved_model::kTfSavedModelInitializerTypeAttr)) { - entry_funcs.push_back(func); - } - }); - if (entry_funcs.empty()) { - LOG(WARNING) << "no entry function is found"; - } - auto predicate = [](Operation *op) { - if (op->hasAttr(tensorflow::kCompileDeviceTypeAttr)) return true; - return false; - }; - for (auto &root : entry_funcs) { - llvm::SmallVector outermost_call_ops; - if (failed(GetOutermostOpsOfType( - root, symtab, outermost_call_ops, predicate))) + llvm::SmallVector entry_funcs = GetEntryFunctions(module); + for (auto &entry_func : entry_funcs) { + llvm::SmallVector outermost_pcall_ops; + if (failed(GetFirstOpsOfType( + entry_func, symtab, /*predicate*/ has_compile_device_type, + outermost_pcall_ops))) { return signalPassFailure(); + } // Cluster outermost partitioned calls with _xla_compile_device_type // attribute. - for (auto &call_op : outermost_call_ops) { - EncapsulatePartitionedCall(call_op); + for (auto &pcall_op : outermost_pcall_ops) { + EncapsulatePartitionedCall(pcall_op); } } } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/xla_rewrite.cc b/tensorflow/compiler/mlir/tensorflow/transforms/xla_rewrite.cc index c59d6e532d0..1992f43a951 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/xla_rewrite.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/xla_rewrite.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// This transformation pass converts stateful and stateless paritioned calls +// This transformation pass converts stateful and stateless partitioned calls // with _xla_compile_device_type attribute to XLA launch ops. #include @@ -21,9 +21,9 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" -#include "tensorflow/compiler/mlir/tensorflow/transforms/call_graph_util.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" #include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/call_graph_util.h" #define DEBUG_TYPE "tf-xla-rewrite" @@ -52,7 +52,7 @@ void MoveResourceArgsToEnd(func::FuncOp callee) { removed_params.push_back(false); } } - // Remove old reousrce-type parameters. + // Remove old resource-type parameters. callee.getBody().front().eraseArguments(removed_params); // Update function type. callee.setFunctionType(FunctionType::get(callee.getContext(), @@ -98,20 +98,6 @@ void XlaRewritePass::runOnOperation() { module.walk([&](tf_device::ClusterFuncOp cluster_func_op) { RewriteCall(cluster_func_op, symtab, builder); }); - - // Verify that there are no nested XLA launch ops. - module.walk([&](TF::XlaLaunchOp xla_launch_op) { - llvm::SmallVector nested_launch_ops; - func::FuncOp root = symtab.lookup( - xla_launch_op.getFunctionAttr().getRootReference()); - if (failed(GetOutermostOpsOfType(root, symtab, - nested_launch_ops))) - return signalPassFailure(); - if (!nested_launch_ops.empty()) { - xla_launch_op.emitError() << "Nested XLA launch ops detected"; - return signalPassFailure(); - } - }); } } // namespace diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/xla_validate_inputs.cc b/tensorflow/compiler/mlir/tensorflow/transforms/xla_validate_inputs.cc new file mode 100644 index 00000000000..7891a672bdb --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/transforms/xla_validate_inputs.cc @@ -0,0 +1,102 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" +#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/call_graph_util.h" + +namespace mlir { + +namespace { + +#define GEN_PASS_DEF_XLAVALIDATEINPUTSPASS +#include "tensorflow/compiler/mlir/tensorflow/transforms/tf_device_passes.h.inc" + +// Validate input graph. 
+struct XlaValidateInputsPass + : public impl::XlaValidateInputsPassBase { + void runOnOperation() override; +}; + +LogicalResult has_nested_entry_functions( + const llvm::SmallVector &entry_funcs, SymbolTable &symtab) { + auto calls_entry_functions = [&](SymbolUserOpInterface op) { + llvm::SmallVector callees; + if (GetCallees(op, symtab, callees).failed()) { + return false; + } + for (auto &callee : callees) { + if (IsEntryFunction(callee)) { + return true; + } + } + return false; + }; + + for (auto &entry_func : entry_funcs) { + llvm::SmallVector calls; + if (GetFirstOpsOfType( + entry_func, symtab, /*predicate*/ calls_entry_functions, calls) + .failed()) { + return failure(); + } + if (!calls.empty()) { + // Some passes in MLIR GPU phase 1 pipeline uses entry functions as start + // point for tree traversal (input graphs are transformed to trees in + // GuaranteeAllFuncsOneUsePass). They will not work properly if there are + // nested calls of entry fucntions. We can add a pass after + // GuaranteeAllFuncsOneUsePass to remove "tf.entry_function" or + // "tf_saved_model.initializer_type" attribute from the callee of the + // inner calls + entry_func->emitError() + << "CPU/GPU MLIR phase 1 pipeline does not support nested calls of " + "entry functions. 
Remove tf.entry_function or " + "tf_saved_model.initializer_type from the called functions in the " + "inner calls after GuaranteeAllFuncsOneUsePass to add the support"; + return failure(); + } + } + return success(); +} + +void XlaValidateInputsPass::runOnOperation() { + ModuleOp module = getOperation(); + SymbolTable symtab(module); + llvm::SmallVector entry_funcs = GetEntryFunctions(module); + if (entry_funcs.empty()) { + LOG(WARNING) << "missing entry functions"; + } + + if (has_nested_entry_functions(entry_funcs, symtab).failed()) { + return signalPassFailure(); + } +} + +} // namespace + +namespace TFDevice { +std::unique_ptr> CreateXlaValidateInputsPass() { + return std::make_unique(); +} +} // namespace TFDevice + +} // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc b/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc index 61490f6a749..74cf8423270 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc @@ -557,7 +557,7 @@ StatusOr> Exporter::Convert( llvm::dyn_cast(inst)) { Operation& inner_op = island.GetBody().front(); auto op_name = GetTensorFlowOpName(inner_op.getName().getStringRef()); - if (op_name.ok()) { + if (llvm::isa(inner_op) && op_name.ok()) { // If it is TF Control dialect specific op, look up custom operation // in the module and first convert that, then add it to function // definition library diff --git a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc index b8ba989b33b..8b23293dadc 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc @@ -1764,7 +1764,7 @@ mlir::Location ImporterBase::GetLocation(const Node& node) { // finally to just name. 
if (auto stack_trace = node.GetStackTrace()) { DVLOG(1) << "Stack available for " << node.name(); - absl::Span frames = stack_trace->ToFrames(); + std::vector frames = stack_trace->ToUncachedFrames(); locations.reserve(frames.size()); for (const StackFrame& frame : llvm::reverse(frames)) { auto file_name = mlir::StringAttr::get(context_, frame.file_name); @@ -1773,7 +1773,6 @@ mlir::Location ImporterBase::GetLocation(const Node& node) { mlir::FileLineColLoc::get(file_name, frame.line_number, 1); locations.push_back(file_line_loc); } - stack_trace->WipeCache(); } else { DVLOG(1) << "No stack trace for " << node.name(); const auto location_it = debug_info.find(debug_info_key); @@ -2486,6 +2485,8 @@ StatusOr> GraphDefImporter::Convert( b.getNamedAttr("_xla_compile_device_type", b.getStringAttr(specs.xla_compile_device_type))); } + attrs.push_back(b.getNamedAttr("allow_soft_placement", + b.getBoolAttr(specs.enable_soft_placement))); } else { // Collects the argument and return nodes by looking up the node names // specified by the user. diff --git a/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h b/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h index 79d364bf6b2..191676999be 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h +++ b/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h @@ -93,6 +93,9 @@ struct GraphImportConfig { // If set, use the value as the device type and mark the function graph for // XLA compilation. string xla_compile_device_type; + // If true, enables moving ops to different devices or moving unsupported ops + // out of a compilation cluster. 
+ bool enable_soft_placement = false; }; struct GraphExportConfig { diff --git a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.cc b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.cc index 233d35d8c01..45bfe3e2e11 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.cc @@ -16,6 +16,8 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h" #include +#include +#include #include "absl/memory/memory.h" #include "llvm/Support/raw_ostream.h" @@ -45,16 +47,12 @@ limitations under the License. namespace tensorflow { static StatusOr> GraphdefToMlirImport( - llvm::StringRef input, absl::string_view debug_info_file, - absl::string_view xla_compile_device_type, - const std::vector& input_arrays, + llvm::StringRef input, const std::vector& input_arrays, const std::vector& input_dtypes, const std::vector>>& input_shapes, const std::vector& output_arrays, const std::vector& control_output_arrays, - bool prune_unused_nodes, bool convert_legacy_fed_inputs, - bool graph_as_function, bool upgrade_legacy, bool enable_shape_inference, - bool unconditionally_use_set_output_shapes, mlir::MLIRContext* context) { + const GraphdefToMlirOptions& import_options, mlir::MLIRContext* context) { GraphDef graphdef; TF_RETURN_IF_ERROR( tensorflow::LoadProtoFromBuffer({input.data(), input.size()}, &graphdef)); @@ -62,19 +60,21 @@ static StatusOr> GraphdefToMlirImport( TF_RETURN_IF_ERROR(ByteSwapTensorContentInGraphDef(&graphdef)); GraphDebugInfo debug_info; - if (!debug_info_file.empty()) { - TF_RETURN_IF_ERROR(LoadProtoFromFile(debug_info_file, &debug_info)); + if (!import_options.debug_info_file.empty()) { + TF_RETURN_IF_ERROR( + LoadProtoFromFile(import_options.debug_info_file, &debug_info)); } GraphImportConfig specs; - specs.prune_unused_nodes = prune_unused_nodes; - specs.convert_legacy_fed_inputs = 
convert_legacy_fed_inputs; - specs.graph_as_function = graph_as_function; - specs.upgrade_legacy = upgrade_legacy; - specs.enable_shape_inference = enable_shape_inference; + specs.prune_unused_nodes = import_options.prune_unused_nodes; + specs.convert_legacy_fed_inputs = import_options.convert_legacy_fed_inputs; + specs.graph_as_function = import_options.graph_as_function; + specs.upgrade_legacy = import_options.upgrade_legacy; + specs.enable_shape_inference = import_options.enable_shape_inference; specs.unconditionally_use_set_output_shapes = - unconditionally_use_set_output_shapes; - specs.xla_compile_device_type = xla_compile_device_type; + import_options.unconditionally_use_set_output_shapes; + specs.xla_compile_device_type = import_options.xla_compile_device_type; + specs.enable_soft_placement = import_options.enable_soft_placement; TF_RETURN_IF_ERROR(ParseInputArrayInfo(input_arrays, input_dtypes, input_shapes, &specs.inputs)); TF_RETURN_IF_ERROR(ParseOutputArrayInfo(output_arrays, &specs.outputs)); @@ -109,22 +109,15 @@ static StatusOr> GraphdefToMlirImport( } StatusOr> GraphdefToMlirTranslateFunction( - llvm::StringRef input, absl::string_view debug_info_file, - absl::string_view xla_compile_device_type, - const std::vector& input_arrays, + llvm::StringRef input, const std::vector& input_arrays, const std::vector& input_dtypes, const std::vector>>& input_shapes, const std::vector& output_arrays, const std::vector& control_output_arrays, - bool prune_unused_nodes, bool convert_legacy_fed_inputs, - bool graph_as_function, bool upgrade_legacy, bool enable_shape_inference, - bool unconditionally_use_set_output_shapes, mlir::MLIRContext* context) { + const GraphdefToMlirOptions& import_options, mlir::MLIRContext* context) { auto module_or = GraphdefToMlirImport( - input, debug_info_file, xla_compile_device_type, input_arrays, - input_dtypes, input_shapes, output_arrays, control_output_arrays, - prune_unused_nodes, convert_legacy_fed_inputs, graph_as_function, - 
upgrade_legacy, enable_shape_inference, - unconditionally_use_set_output_shapes, context); + input, input_arrays, input_dtypes, input_shapes, output_arrays, + control_output_arrays, import_options, context); if (!module_or.status().ok()) { LOG(ERROR) << "Graph import failed: " << module_or.status(); } @@ -132,13 +125,10 @@ StatusOr> GraphdefToMlirTranslateFunction( } StatusOr> GraphdefToMlirTranslateFunction( - llvm::StringRef input, absl::string_view debug_info_file, - absl::string_view xla_compile_device_type, absl::string_view input_arrays, + llvm::StringRef input, absl::string_view input_arrays, absl::string_view input_dtypes, absl::string_view input_shapes, absl::string_view output_arrays, absl::string_view control_output_arrays, - bool prune_unused_nodes, bool convert_legacy_fed_inputs, - bool graph_as_function, bool upgrade_legacy, bool enable_shape_inference, - bool unconditionally_use_set_output_shapes, mlir::MLIRContext* context) { + const GraphdefToMlirOptions& import_options, mlir::MLIRContext* context) { std::vector input_array_vector; std::vector input_dtype_vector; std::vector>> input_shapes_vector; @@ -151,11 +141,9 @@ StatusOr> GraphdefToMlirTranslateFunction( TF_RETURN_IF_ERROR( ParseNodeNames(control_output_arrays, control_output_array_vector)); return GraphdefToMlirTranslateFunction( - input, debug_info_file, xla_compile_device_type, input_array_vector, - input_dtype_vector, input_shapes_vector, output_array_vector, - control_output_array_vector, prune_unused_nodes, - convert_legacy_fed_inputs, graph_as_function, upgrade_legacy, - enable_shape_inference, unconditionally_use_set_output_shapes, context); + input, input_array_vector, input_dtype_vector, input_shapes_vector, + output_array_vector, control_output_array_vector, import_options, + context); } StatusOr> SavedModelObjectGraphToMlirImport( @@ -252,22 +240,15 @@ SavedModelSignatureDefsToMlirImportLite( StatusOr> GraphdefToSplattedMlirTranslateFunction( - llvm::StringRef input, 
absl::string_view debug_info_file, - absl::string_view xla_compile_device_type, - const std::vector& input_arrays, + llvm::StringRef input, const std::vector& input_arrays, const std::vector& input_dtypes, const std::vector>>& input_shapes, const std::vector& output_arrays, const std::vector& control_output_arrays, - bool prune_unused_nodes, bool convert_legacy_fed_inputs, - bool graph_as_function, bool upgrade_legacy, bool enable_shape_inference, - bool unconditionally_use_set_output_shapes, mlir::MLIRContext* context) { + const GraphdefToMlirOptions& import_options, mlir::MLIRContext* context) { auto module_or = GraphdefToMlirImport( - input, debug_info_file, xla_compile_device_type, input_arrays, - input_dtypes, input_shapes, output_arrays, control_output_arrays, - prune_unused_nodes, convert_legacy_fed_inputs, graph_as_function, - upgrade_legacy, enable_shape_inference, - unconditionally_use_set_output_shapes, context); + input, input_arrays, input_dtypes, input_shapes, output_arrays, + control_output_arrays, import_options, context); if (!module_or.status().ok()) { LOG(ERROR) << "Graph import failed: " << module_or.status(); return module_or.status(); @@ -306,13 +287,10 @@ GraphdefToSplattedMlirTranslateFunction( StatusOr> GraphdefToSplattedMlirTranslateFunction( - llvm::StringRef input, absl::string_view debug_info_file, - absl::string_view xla_compile_device_type, absl::string_view input_arrays, + llvm::StringRef input, absl::string_view input_arrays, absl::string_view input_dtypes, absl::string_view input_shapes, absl::string_view output_arrays, absl::string_view control_output_arrays, - bool prune_unused_nodes, bool convert_legacy_fed_inputs, - bool graph_as_function, bool upgrade_legacy, bool enable_shape_inference, - bool unconditionally_use_set_output_shapes, mlir::MLIRContext* context) { + const GraphdefToMlirOptions& import_options, mlir::MLIRContext* context) { std::vector input_array_vector; std::vector input_dtype_vector; std::vector>> 
input_shapes_vector; @@ -325,11 +303,9 @@ GraphdefToSplattedMlirTranslateFunction( TF_RETURN_IF_ERROR( ParseNodeNames(control_output_arrays, control_output_array_vector)); return GraphdefToSplattedMlirTranslateFunction( - input, debug_info_file, xla_compile_device_type, input_array_vector, - input_dtype_vector, input_shapes_vector, output_array_vector, - control_output_array_vector, prune_unused_nodes, - convert_legacy_fed_inputs, graph_as_function, upgrade_legacy, - enable_shape_inference, unconditionally_use_set_output_shapes, context); + input, input_array_vector, input_dtype_vector, input_shapes_vector, + output_array_vector, control_output_array_vector, import_options, + context); } } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h index 677c09dd027..ff53e066964 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h +++ b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h @@ -19,6 +19,7 @@ limitations under the License. #include #include #include +#include #include "absl/base/macros.h" #include "absl/strings/string_view.h" @@ -34,24 +35,30 @@ namespace tensorflow { using tsl::Status; using tsl::StatusOr; +struct GraphdefToMlirOptions { + std::string debug_info_file; + std::string xla_compile_device_type; + bool prune_unused_nodes; + bool convert_legacy_fed_inputs; + bool graph_as_function; + bool upgrade_legacy; + bool enable_shape_inference; + bool unconditionally_use_set_output_shapes; + bool enable_soft_placement; +}; + // TODO(antiagainst): Directly manipulating files in library functions is not // a good idea. We should pass in a string/stream here. // Converts a TensorFlow GraphDef contained in `input` param into a MLIR module. // Creates MLIR entities into the given MLIR `context`. 
StatusOr> GraphdefToMlirTranslateFunction( - llvm::StringRef input, absl::string_view debug_info_file, - absl::string_view xla_compile_device_type, - const std::vector& input_arrays, + llvm::StringRef input, const std::vector& input_arrays, const std::vector& input_dtypes, const std::vector>>& input_shapes, const std::vector& output_arrays, const std::vector& control_output_arrays, - bool prune_unused_nodes, bool convert_legacy_fed_inputs, - bool graph_as_function, bool upgrade_legacy, - // TODO(jpienaar): Remove these. - bool enable_shape_inference, bool unconditionally_use_set_output_shapes, - mlir::MLIRContext* context); + const GraphdefToMlirOptions& import_options, mlir::MLIRContext* context); ABSL_DEPRECATED( "Please use the other overload of this function which accepts structured " @@ -59,32 +66,21 @@ ABSL_DEPRECATED( // Converts a TensorFlow GraphDef contained in `input` param into a MLIR module. // Creates MLIR entities into the given MLIR `context`. StatusOr> GraphdefToMlirTranslateFunction( - llvm::StringRef input, absl::string_view debug_info_file, - absl::string_view xla_compile_device_type, absl::string_view input_arrays, + llvm::StringRef input, absl::string_view input_arrays, absl::string_view input_dtypes, absl::string_view input_shapes, absl::string_view output_arrays, absl::string_view control_output_arrays, - bool prune_unused_nodes, bool convert_legacy_fed_inputs, - bool graph_as_function, bool upgrade_legacy, - // TODO(jpienaar): Remove these. - bool enable_shape_inference, bool unconditionally_use_set_output_shapes, - mlir::MLIRContext* context); + const GraphdefToMlirOptions& import_options, mlir::MLIRContext* context); // Similar as the above function, but replaces all constant tensors // with randomly generated splat values. 
StatusOr> GraphdefToSplattedMlirTranslateFunction( - llvm::StringRef input, absl::string_view debug_info_file, - absl::string_view xla_compile_device_type, - const std::vector& input_arrays, + llvm::StringRef input, const std::vector& input_arrays, const std::vector& input_dtypes, const std::vector>& input_shapes, const std::vector& output_arrays, const std::vector& control_output_arrays, - bool prune_unused_nodes, bool convert_legacy_fed_inputs, - bool graph_as_function, bool upgrade_legacy, - // TODO(jpienaar): Remove these. - bool enable_shape_inference, bool unconditionally_use_set_output_shapes, - mlir::MLIRContext* context); + const GraphdefToMlirOptions& import_options, mlir::MLIRContext* context); ABSL_DEPRECATED( "Please use the other overload of this function which accepts structured " @@ -93,15 +89,10 @@ ABSL_DEPRECATED( // with randomly generated splat values. StatusOr> GraphdefToSplattedMlirTranslateFunction( - llvm::StringRef input, absl::string_view debug_info_file, - absl::string_view xla_compile_device_type, absl::string_view input_arrays, + llvm::StringRef input, absl::string_view input_arrays, absl::string_view input_dtypes, absl::string_view input_shapes, absl::string_view output_arrays, absl::string_view control_output_arrays, - bool prune_unused_nodes, bool convert_legacy_fed_inputs, - bool graph_as_function, bool upgrade_legacy, - // TODO(jpienaar): Remove these. - bool enable_shape_inference, bool unconditionally_use_set_output_shapes, - mlir::MLIRContext* context); + const GraphdefToMlirOptions& import_options, mlir::MLIRContext* context); // Converts a TensorFlow SavedModel stored in the directory with the given // `saved_model_dir` into a MLIR module. 
Creates MLIR entities into the diff --git a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_cl.cc b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_cl.cc index d739b3997c5..ac1a6fe6881 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_cl.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_cl.cc @@ -130,6 +130,12 @@ opt unconditionally_use_set_output_shapes( "(temporary)"), llvm::cl::init(false)); +// NOLINTNEXTLINE +opt enable_soft_placement( + "tf-enable-soft-placement-on-import", + llvm::cl::desc("Enable soft device placement on import."), + llvm::cl::init(false)); + // Export options. // NOLINTNEXTLINE opt export_entry_func_to_flib( diff --git a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_cl.h b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_cl.h index af50bdc185f..ebf5dc0b0a7 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_cl.h +++ b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_cl.h @@ -44,6 +44,7 @@ extern llvm::cl::opt upgrade_legacy; // TODO(jpienaar): Temporary flag, flip default and remove. extern llvm::cl::opt enable_shape_inference; extern llvm::cl::opt unconditionally_use_set_output_shapes; +extern llvm::cl::opt enable_soft_placement; // Export options. 
extern llvm::cl::opt export_entry_func_to_flib; diff --git a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_registration.cc b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_registration.cc index 6ce04664a7b..4aa10153e79 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_registration.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_registration.cc @@ -44,12 +44,16 @@ inline absl::string_view StringRefToView(llvm::StringRef ref) { static OwningOpRef GraphdefToMlirTranslateFunction( llvm::StringRef input, MLIRContext* context) { + tensorflow::GraphdefToMlirOptions options{ + debug_info_file, xla_compile_device_type, + prune_unused_nodes, convert_legacy_fed_inputs, + graph_as_function, upgrade_legacy, + enable_shape_inference, unconditionally_use_set_output_shapes, + enable_soft_placement}; + auto module_or = tensorflow::GraphdefToMlirTranslateFunction( - input, debug_info_file, xla_compile_device_type, input_arrays, - input_dtypes, input_shapes, output_arrays, control_output_arrays, - prune_unused_nodes, convert_legacy_fed_inputs, graph_as_function, - upgrade_legacy, enable_shape_inference, - unconditionally_use_set_output_shapes, context); + input, input_arrays, input_dtypes, input_shapes, output_arrays, + control_output_arrays, options, context); if (!module_or.status().ok()) return nullptr; return std::move(module_or).value(); } @@ -59,12 +63,14 @@ static TranslateToMLIRRegistration GraphdefToMlirTranslate( static OwningOpRef GraphdefToSplattedMlirTranslateFunction( llvm::StringRef input, MLIRContext* context) { + tensorflow::GraphdefToMlirOptions options{ + debug_info_file, xla_compile_device_type, + prune_unused_nodes, convert_legacy_fed_inputs, + graph_as_function, upgrade_legacy, + enable_shape_inference, unconditionally_use_set_output_shapes}; auto module_or = tensorflow::GraphdefToSplattedMlirTranslateFunction( - input, debug_info_file, xla_compile_device_type, input_arrays, - 
input_dtypes, input_shapes, output_arrays, control_output_arrays, - prune_unused_nodes, convert_legacy_fed_inputs, graph_as_function, - upgrade_legacy, enable_shape_inference, - unconditionally_use_set_output_shapes, context); + input, input_arrays, input_dtypes, input_shapes, output_arrays, + control_output_arrays, options, context); if (!module_or.status().ok()) return nullptr; return std::move(module_or).value(); } diff --git a/tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.cc b/tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.cc index ed84b747fc6..d11371e395f 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.cc @@ -16,6 +16,9 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.h" #include +#include +#include +#include #include "absl/strings/str_split.h" #include "llvm/ADT/StringRef.h" diff --git a/tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.h b/tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.h index cf972183f4e..485ac2f7293 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.h @@ -16,6 +16,9 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_BRIDGE_LOGGER_H_ #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_BRIDGE_LOGGER_H_ +#include +#include + #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/tensorflow/utils/bridge_logger_test.cc b/tensorflow/compiler/mlir/tensorflow/utils/bridge_logger_test.cc index 4a15dace1c8..b2d2d71128a 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/bridge_logger_test.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/bridge_logger_test.cc @@ -15,6 +15,8 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.h" +#include + #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/Transforms/Passes.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h" diff --git a/tensorflow/compiler/mlir/tensorflow/utils/call_graph_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/call_graph_util.cc new file mode 100644 index 00000000000..c1e9c9ad24b --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/utils/call_graph_util.cc @@ -0,0 +1,60 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "absl/strings/string_view.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h" + +inline constexpr absl::string_view kEntryFunctionAttr = "tf.entry_function"; + +namespace mlir { + +bool IsEntryFunction(func::FuncOp func) { + return func->hasAttr(kEntryFunctionAttr) || + func->hasAttr(tf_saved_model::kTfSavedModelInitializerTypeAttr); +} + +llvm::SmallVector GetEntryFunctions(ModuleOp module) { + llvm::SmallVector entry_funcs; + module.walk([&](func::FuncOp func) { + // A model may have multiple graphs, with each graph having its own entry. 
+ // When a graph is imported to MLIR, `tf.entry_function` will be added to + // each entry function. The one exception are initializer functions, which + // have `tf_saved_model.initializer_type` instead. + if (IsEntryFunction(func)) { + entry_funcs.push_back(func); + } + }); + return entry_funcs; +} + +LogicalResult GetCallees(SymbolUserOpInterface op, SymbolTable &symtab, + llvm::SmallVector &callees) { + for (auto attr : op->getAttrs()) { + auto sym = attr.getValue().dyn_cast(); + if (!sym) continue; + auto callee = symtab.lookup(sym.getRootReference()); + if (!callee) { + // This is not expected to happen in practice. + return op->emitError() + << "Cannot find function " << sym.getRootReference(); + } + callees.push_back(callee); + } + return success(); +} + +} // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/utils/call_graph_util.h b/tensorflow/compiler/mlir/tensorflow/utils/call_graph_util.h new file mode 100644 index 00000000000..8a45d6e79c8 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/utils/call_graph_util.h @@ -0,0 +1,73 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_CALL_GRAPH_UTIL_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_CALL_GRAPH_UTIL_H_ + +#include +#include + +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project + +namespace mlir { + +// Check if a function is an entry in an MLIR module. +bool IsEntryFunction(func::FuncOp func); + +// Get all the entry functions in an MLIR module. +llvm::SmallVector GetEntryFunctions(ModuleOp module); + +// Get all the functions referenced in a symber user op and save them in +// `callees`. +LogicalResult GetCallees(SymbolUserOpInterface op, SymbolTable &symtab, + llvm::SmallVector &callees); + +// Find the first op with any of the specified types on the paths rooted at the +// `root` node in a tree. Additional filters can be applied via `predicate`. The +// results are stored in `ops`. 
+template +LogicalResult GetFirstOpsOfType( + func::FuncOp root, SymbolTable &symtab, + const std::function &predicate, + llvm::SmallVector &ops) { + std::stack worklist; + worklist.push(root); + while (!worklist.empty()) { + func::FuncOp u = worklist.top(); + worklist.pop(); + auto result = u.walk([&](SymbolUserOpInterface op) { + if (llvm::isa(op) && (!predicate || predicate(op))) { + ops.push_back(op); + return WalkResult::advance(); + } + llvm::SmallVector callees; + if (GetCallees(op, symtab, callees).failed()) { + return WalkResult::interrupt(); + } + for (auto callee : callees) { + worklist.push(callee); + } + return WalkResult::advance(); + }); + if (result.wasInterrupted()) return failure(); + } + return success(); +} + +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_CALL_GRAPH_UTIL_H_ diff --git a/tensorflow/compiler/mlir/tensorflow/utils/call_graph_util_test.cc b/tensorflow/compiler/mlir/tensorflow/utils/call_graph_util_test.cc new file mode 100644 index 00000000000..54f30fbe3b4 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/utils/call_graph_util_test.cc @@ -0,0 +1,156 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/mlir/tensorflow/utils/call_graph_util.h" + +#include "llvm/Support/raw_ostream.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/DialectRegistry.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/OwningOpRef.h" // from @llvm-project +#include "mlir/Parser/Parser.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { + +TEST(CallGraphUtilTest, GetEntryFunctions) { + const char *const code = R"mlir( +func.func @entry_func_1(%arg0: tensor) -> tensor attributes {tf.entry_function = {}} { + %0 = "tf.StatefulPartitionedCall"(%arg0) {_xla_compile_device_type = "CPU", config = "", config_proto = "", device = "/device:CPU:0", executor_type = "", f = @func} : (tensor) -> (tensor) + func.return %0 : tensor +} + +func.func @entry_func_2(%arg0: tensor) -> tensor attributes {tf_saved_model.initializer_type = ""} { + %0 = "tf.StatefulPartitionedCall"(%arg0) {_xla_compile_device_type = "CPU", config = "", config_proto = "", device = "/device:CPU:0", executor_type = "", f = @func} : (tensor) -> (tensor) + func.return %0 : tensor +} + +func.func @func(%arg0: tensor) -> tensor { + func.return %arg0 : tensor +} +)mlir"; + mlir::MLIRContext context; + context.loadDialect(); + mlir::OwningOpRef module = + mlir::parseSourceString(code, &context); + ASSERT_TRUE(module); + auto entry_funcs = GetEntryFunctions(*module); + EXPECT_EQ(entry_funcs.size(), 2); + EXPECT_EQ(entry_funcs[0].getSymName(), "entry_func_1"); + 
EXPECT_EQ(entry_funcs[1].getSymName(), "entry_func_2"); +} + +TEST(CallGraphUtilTest, GetCallees) { + const char *const code = R"mlir( +func.func @entry_func(%arg0: tensor) -> tensor attributes {tf_saved_model.initializer_type = ""} { + %0 = "tf.While"(%arg0) {cond = @while_cond_func, body = @while_body_func, is_stateless = true} : (tensor) -> (tensor) + func.return %0 : tensor +} + +func.func @while_cond_func(%arg0: tensor) -> tensor { + %0 = "tf.Const"() {value = dense<0> : tensor} : () -> tensor + func.return %0 : tensor +} + +func.func @while_body_func(%arg0: tensor) -> (tensor) { + %0 = "tf.Const"() {value = dense<0> : tensor} : () -> tensor + func.return %0 : tensor +} + + +)mlir"; + mlir::MLIRContext context; + context.loadDialect(); + mlir::OwningOpRef module = + mlir::parseSourceString(code, &context); + ASSERT_TRUE(module); + mlir::SymbolTable symtab(*module); + llvm::SmallVector callees; + module->walk([&](mlir::SymbolUserOpInterface op) { + auto result = GetCallees(op, symtab, callees).succeeded(); + ASSERT_TRUE(result); + EXPECT_EQ(callees.size(), 2); + EXPECT_EQ(callees[0].getSymName(), "while_body_func"); + EXPECT_EQ(callees[1].getSymName(), "while_cond_func"); + }); +} + +TEST(CallGraphUtilTest, GetFirstOpsOfType) { + const char *const code = R"mlir( +func.func @entry_func(%arg0: tensor) -> tensor attributes {tf.entry_function = {}} { + %0 = "tf.While"(%arg0) {cond = @while_cond_func, body = @while_body_func, is_stateless = true} : (tensor) -> (tensor) + func.return %0 : tensor +} + +func.func @while_cond_func(%arg0: tensor) -> tensor { + %0 = "tf.Const"() {value = dense<0> : tensor} : () -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: func.func @while_body_func +func.func @while_body_func(%arg0: tensor) -> (tensor) { + %0 = "tf.StatefulPartitionedCall"(%arg0) {config = "", config_proto = "", device = "/device:CPU:0", executor_type = "", f = @outer_stateful_pcall_func} : (tensor) -> (tensor) + func.return %0 : tensor +} + +func.func 
@outer_stateful_pcall_func(%arg0: tensor) -> (tensor) { + %0 = "tf.StatefulPartitionedCall"(%arg0) {_xla_compile_device_type = "CPU", config = "", config_proto = "", device = "/device:CPU:0", executor_type = "", f = @inner_stateful_pcall_func} : (tensor) -> (tensor) + func.return %0 : tensor +} + +func.func @inner_stateful_pcall_func(%arg0: tensor) -> tensor { + %0 = "tf.StatefulPartitionedCall"(%arg0) {_xla_compile_device_type = "CPU", config = "", config_proto = "", device = "/device:CPU:0", executor_type = "", f = @func} : (tensor) -> (tensor) + func.return %0 : tensor +} + +func.func @func(%arg0: tensor) -> tensor { + func.return %arg0 : tensor +} +)mlir"; + auto has_compile_device_type = [](mlir::SymbolUserOpInterface op) { + return op->hasAttr(tensorflow::kCompileDeviceTypeAttr); + }; + mlir::MLIRContext context; + context.loadDialect(); + mlir::OwningOpRef module = + mlir::parseSourceString(code, &context); + ASSERT_TRUE(module); + mlir::SymbolTable symtab(*module); + llvm::SmallVector entry_funcs = + GetEntryFunctions(*module); + EXPECT_EQ(entry_funcs.size(), 1); + EXPECT_EQ(entry_funcs[0].getSymName(), "entry_func"); + llvm::SmallVector outermost_pcall_ops; + auto result = + mlir::GetFirstOpsOfType( + entry_funcs[0], symtab, has_compile_device_type, outermost_pcall_ops) + .succeeded(); + ASSERT_TRUE(result); + EXPECT_EQ(outermost_pcall_ops.size(), 1); + auto func = + llvm::dyn_cast(outermost_pcall_ops[0]->getParentOp()); + ASSERT_TRUE(func); + EXPECT_EQ(func.getSymName(), "outer_stateful_pcall_func"); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tensorflow/utils/cluster_util_test.cc b/tensorflow/compiler/mlir/tensorflow/utils/cluster_util_test.cc index f855b7f2c19..df641fb2176 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/cluster_util_test.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/cluster_util_test.cc @@ -15,6 +15,8 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/tensorflow/utils/cluster_util.h" +#include + #include #include #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.cc b/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.cc index fce8c6f8dcf..4100ce55cf3 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.cc @@ -18,6 +18,8 @@ limitations under the License. #include #include #include +#include +#include #include "absl/base/casts.h" #include "absl/container/inlined_vector.h" @@ -31,6 +33,7 @@ limitations under the License. #include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/Support/DebugStringHelper.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h" @@ -38,6 +41,7 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor.pb.h" +#include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/framework/tensor_util.h" #include "tensorflow/core/framework/types.pb.h" @@ -268,6 +272,15 @@ mlir::TF::ShapeAttr ConvertTypeToTensorShapeAttr(const mlir::Type& type) { return mlir::TF::ShapeAttr::get(type.getContext(), ArrayRef()); } +StatusOr ConvertTypeToTensorSpecProto(const mlir::Type& type) { + DataType dtype; + TF_RETURN_IF_ERROR(ConvertToDataType(type, &dtype)); + TensorSpecProto tensor_spec; + tensor_spec.set_dtype(dtype); + *tensor_spec.mutable_shape() = ConvertTypeToTensorShape(type).AsProto(); + return tensor_spec; +} + // Converts the tensor shape proto into an MLIR shape attribute. StatusOr ConvertTensorShapeProto(const TensorShapeProto& shape, mlir::MLIRContext* context) { diff --git a/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h b/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h index 9255667c647..227e4bf465f 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/framework/tensor_shape.pb.h" +#include "tensorflow/core/protobuf/struct.pb.h" namespace tensorflow { @@ -47,6 +48,10 @@ PartialTensorShape ConvertTypeToTensorShape(const mlir::Type& type); // Converts an MLIR shaped type to a TensorFlow shape attribute. mlir::TF::ShapeAttr ConvertTypeToTensorShapeAttr(const mlir::Type& type); +// Converts an MLIR shaped type to a Tensorflow tensor spec proto. 
+absl::StatusOr ConvertTypeToTensorSpecProto( + const mlir::Type& type); + // Converts a TensorFlow shape attribute to an MLIR shape attribute. StatusOr ConvertTensorShapeProto(const TensorShapeProto& shape, mlir::MLIRContext* context); diff --git a/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor_test.cc b/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor_test.cc index 373e88f7413..f5e58f28689 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor_test.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor_test.cc @@ -222,5 +222,41 @@ TEST(ConvertTensorProtoTest, NonSplatTensor) { ResultOf(IsSplat, IsFalse()))); } +TEST(ConvertTypeToTensorSpecProtoTest, UnrankedTensorType) { + mlir::MLIRContext context; + mlir::Builder b(&context); + + auto output_proto = ConvertTypeToTensorSpecProto( + mlir::UnrankedTensorType::get(b.getF32Type())); + TF_ASSERT_OK(output_proto.status()); + EXPECT_EQ(output_proto->dtype(), DT_FLOAT); + EXPECT_TRUE(output_proto->shape().unknown_rank()); +} + +TEST(ConvertTypeToTensorSpecProtoTest, RankedTensorType) { + mlir::MLIRContext context; + mlir::Builder b(&context); + + auto output_proto = ConvertTypeToTensorSpecProto( + mlir::RankedTensorType::get({1, 2, 3}, b.getF32Type())); + TF_ASSERT_OK(output_proto.status()); + EXPECT_EQ(output_proto->dtype(), DT_FLOAT); + EXPECT_EQ(output_proto->shape().dim_size(), 3); + EXPECT_EQ(output_proto->shape().dim().at(0).size(), 1); + EXPECT_EQ(output_proto->shape().dim().at(1).size(), 2); + EXPECT_EQ(output_proto->shape().dim().at(2).size(), 3); +} + +TEST(ConvertTypeToTensorSpecProtoTest, ScalarTensorType) { + mlir::MLIRContext context; + mlir::Builder b(&context); + + auto output_proto = ConvertTypeToTensorSpecProto(b.getF32Type()); + TF_ASSERT_OK(output_proto.status()); + EXPECT_EQ(output_proto->dtype(), DT_FLOAT); + EXPECT_FALSE(output_proto->shape().unknown_rank()); + EXPECT_EQ(output_proto->shape().dim_size(), 0); +} + } // namespace } // namespace tensorflow 
diff --git a/tensorflow/compiler/mlir/tensorflow/utils/convert_type.cc b/tensorflow/compiler/mlir/tensorflow/utils/convert_type.cc index 2546fa44a05..45459e31f3f 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/convert_type.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/convert_type.cc @@ -15,6 +15,8 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h" +#include + #include "absl/strings/str_cat.h" #include "llvm/Support/Casting.h" #include "mlir/IR/BuiltinTypes.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/tensorflow/utils/convert_type_test.cc b/tensorflow/compiler/mlir/tensorflow/utils/convert_type_test.cc index 7bc65919030..b844966c7ee 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/convert_type_test.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/convert_type_test.cc @@ -15,6 +15,9 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h" +#include +#include + #include "llvm/Support/raw_ostream.h" #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/tensorflow/utils/device_util_test.cc b/tensorflow/compiler/mlir/tensorflow/utils/device_util_test.cc index 92ce8886f8a..326dbbb4781 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/device_util_test.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/device_util_test.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h" #include +#include #include #include diff --git a/tensorflow/compiler/mlir/tensorflow/utils/dump_graph.cc b/tensorflow/compiler/mlir/tensorflow/utils/dump_graph.cc index fd100546555..9f3e0113339 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/dump_graph.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/dump_graph.cc @@ -18,6 +18,7 @@ limitations under the License. 
#include #include #include +#include #include "llvm/ADT/StringMap.h" #include "llvm/ADT/StringRef.h" diff --git a/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc index efcbca84872..f07af4f8b85 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include #include +#include #include "llvm/ADT/StringMap.h" #include "llvm/ADT/StringRef.h" @@ -87,7 +88,19 @@ struct WritableFileRawStream : public llvm::raw_ostream { SetUnbuffered(); } ~WritableFileRawStream() override = default; - uint64_t current_pos() const override { return 0; } + + uint64_t current_pos() const override { + int64_t position; + if (file->Tell(&position).ok()) { + return position; + } else { + // MLIR uses os.tell() to determine whether something was written by + // a subroutine or not, so it's important we have a working current_pos(). + LOG(WARNING) + << "Couldn't query file position. Stream might be malformed.\n"; + return -1; + } + } void write_impl(const char* ptr, size_t size) override { // Write the file if it is still valid. If the write fails, null out the @@ -154,7 +167,8 @@ Status CreateFileForDumping(llvm::StringRef name, if (dir == kCrashReproducerStdErr) { *os = std::make_unique(); - *filepath = "(stderr)"; + *filepath = + llvm::formatv("(stderr; requested filename: '{0}')", name).str(); return Status(); } diff --git a/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h b/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h index 6069b8ca2ad..a7760872d79 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h @@ -16,6 +16,7 @@ limitations under the License. 
#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_DUMP_MLIR_UTIL_H_ #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_DUMP_MLIR_UTIL_H_ +#include #include #include "absl/strings/string_view.h" diff --git a/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util_test.cc b/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util_test.cc index 908bf40f834..bb474b1413f 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util_test.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util_test.cc @@ -15,6 +15,8 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h" +#include + #include #include #include "llvm/Support/MemoryBuffer.h" @@ -61,7 +63,7 @@ TEST(DumpMlirModuleTest, LogInfo) { setenv("TF_DUMP_GRAPH_PREFIX", "-", 1); std::string filepath = DumpMlirOpToFile("module", module_ref.get()); - EXPECT_EQ(filepath, "(stderr)"); + EXPECT_EQ(filepath, "(stderr; requested filename: 'module')"); } TEST(DumpMlirModuleTest, Valid) { diff --git a/tensorflow/compiler/mlir/tensorflow/utils/error_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/error_util.cc index 42bdbf19d2a..6a66067920f 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/error_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/error_util.cc @@ -47,15 +47,13 @@ StatusScopedDiagnosticHandler::StatusScopedDiagnosticHandler( } Status StatusScopedDiagnosticHandler::ConsumeStatus() { - return tensorflow::FromAbslStatus( - BaseScopedDiagnosticHandler::ConsumeStatus()); + return BaseScopedDiagnosticHandler::ConsumeStatus(); } Status StatusScopedDiagnosticHandler::Combine(Status status) { - absl::Status absl_s = - BaseScopedDiagnosticHandler::Combine(tensorflow::ToAbslStatus(status)); + absl::Status absl_s = BaseScopedDiagnosticHandler::Combine(status); - return tensorflow::FromAbslStatus(absl_s); + return absl_s; } } // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/utils/eval_util.cc 
b/tensorflow/compiler/mlir/tensorflow/utils/eval_util.cc index 925c2dfc57b..260caf3494b 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/eval_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/eval_util.cc @@ -15,6 +15,9 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/utils/eval_util.h" +#include +#include + #include "absl/container/inlined_vector.h" #include "absl/strings/string_view.h" #include "llvm/ADT/ArrayRef.h" diff --git a/tensorflow/compiler/mlir/tensorflow/utils/export_utils.cc b/tensorflow/compiler/mlir/tensorflow/utils/export_utils.cc index 5f1c2735972..b51856bc478 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/export_utils.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/export_utils.cc @@ -15,6 +15,11 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/utils/export_utils.h" +#include +#include +#include +#include +#include #include #include diff --git a/tensorflow/compiler/mlir/tensorflow/utils/export_utils.h b/tensorflow/compiler/mlir/tensorflow/utils/export_utils.h index 24152cad81c..86ff64b5ed4 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/export_utils.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/export_utils.h @@ -57,7 +57,7 @@ StatusOr> GetOperationNodeDef( // "name" and "device" attributes are ignored by default. Use attrs_to_ignore to // specify any other attributes that should be ignored. Status ConvertAttributes( - const llvm::ArrayRef attrs, + llvm::ArrayRef attrs, const absl::flat_hash_set& attrs_to_ignore, bool remove_ref_type, AttrValueMap* values); diff --git a/tensorflow/compiler/mlir/tensorflow/utils/fake_session.cc b/tensorflow/compiler/mlir/tensorflow/utils/fake_session.cc index ecbc6e12fa1..7d7c2a3c074 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/fake_session.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/fake_session.cc @@ -14,6 +14,11 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/compiler/mlir/tensorflow/utils/fake_session.h" +#include +#include +#include +#include + #include "absl/strings/match.h" #include "llvm/Support/CommandLine.h" #include "tensorflow/core/common_runtime/device_mgr.h" diff --git a/tensorflow/compiler/mlir/tensorflow/utils/fake_session.h b/tensorflow/compiler/mlir/tensorflow/utils/fake_session.h index 83d499e0361..213cf4e66e1 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/fake_session.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/fake_session.h @@ -15,6 +15,11 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_FAKE_SESSION_H_ #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_FAKE_SESSION_H_ +#include +#include +#include +#include + #include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/platform/errors.h" diff --git a/tensorflow/compiler/mlir/tensorflow/utils/import_utils.cc b/tensorflow/compiler/mlir/tensorflow/utils/import_utils.cc index b47da952929..7b3312a76a6 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/import_utils.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/import_utils.cc @@ -14,6 +14,8 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/mlir/tensorflow/utils/import_utils.h" +#include + #include "llvm/Support/FileUtilities.h" #include "llvm/Support/MemoryBuffer.h" #include "llvm/Support/ToolOutputFile.h" diff --git a/tensorflow/compiler/mlir/tensorflow/utils/mangling_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/mangling_util.cc index d29343d83e4..ffd41db7f47 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/mangling_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/mangling_util.cc @@ -15,6 +15,7 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h" +#include #include #include "absl/strings/match.h" diff --git a/tensorflow/compiler/mlir/tensorflow/utils/parse_text_proto.cc b/tensorflow/compiler/mlir/tensorflow/utils/parse_text_proto.cc index 550b9c87b77..477d2948d25 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/parse_text_proto.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/parse_text_proto.cc @@ -15,6 +15,8 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/utils/parse_text_proto.h" +#include + #include "absl/strings/match.h" #include "absl/strings/string_view.h" #include "tensorflow/core/lib/core/errors.h" diff --git a/tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.cc b/tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.cc index db365a0c910..3709f88c4d7 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.cc @@ -15,6 +15,9 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.h" +#include +#include + #include "llvm/Support/raw_ostream.h" #include "mlir/IR/OperationSupport.h" // from @llvm-project #include "mlir/Parser/Parser.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/tensorflow/utils/session_utils.cc b/tensorflow/compiler/mlir/tensorflow/utils/session_utils.cc index fdb1ebc39a9..2895ebdc9c6 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/session_utils.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/session_utils.cc @@ -14,6 +14,9 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/compiler/mlir/tensorflow/utils/session_utils.h" +#include +#include + #include "absl/status/status.h" #include "llvm/ADT/SmallSet.h" #include "llvm/ADT/StringRef.h" diff --git a/tensorflow/compiler/mlir/tensorflow/utils/session_utils.h b/tensorflow/compiler/mlir/tensorflow/utils/session_utils.h index 3d009cbda37..be2d3786cb7 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/session_utils.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/session_utils.h @@ -15,6 +15,9 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_SESSION_UTILS_H_ #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_SESSION_UTILS_H_ +#include +#include + #include "absl/status/statusor.h" #include "llvm/ADT/ArrayRef.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/tensorflow/utils/shape_inference_utils.h b/tensorflow/compiler/mlir/tensorflow/utils/shape_inference_utils.h index d672c967060..040429ccf73 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/shape_inference_utils.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/shape_inference_utils.h @@ -16,6 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_SHAPE_INFERENCE_UTILS_H_ #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_SHAPE_INFERENCE_UTILS_H_ +#include + #include "tensorflow/core/ir/utils/shape_inference_utils.h" namespace mlir { diff --git a/tensorflow/compiler/mlir/tensorflow/utils/stablehlo_custom_call.cc b/tensorflow/compiler/mlir/tensorflow/utils/stablehlo_custom_call.cc new file mode 100644 index 00000000000..549b665f044 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/utils/stablehlo_custom_call.cc @@ -0,0 +1,76 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/tensorflow/utils/stablehlo_custom_call.h" + +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project + +namespace mlir { +namespace TF { +namespace { + +// jax2tf sets `stablehlo.custom_call`'s target name as `tf.call_tf_function` +// to represent calling a TF host callback function. +constexpr llvm::StringRef kTfTargetName = "tf.call_tf_function"; + +// `tf.backend_config` is a DictionaryAttr, JAX2TF sets the value of its +// string attribute `caller_name` to the TF host callback function's name. 
+constexpr llvm::StringRef kTfBackendConfigAttrName = "tf.backend_config"; +constexpr llvm::StringRef kCalledFuncAttrName = "called_func"; + +} // namespace + +bool IsTfFuncCustomCall(stablehlo::CustomCallOp op) { + return op.getCallTargetName() == kTfTargetName; +} + +DictionaryAttr GetTfBackendConfig(stablehlo::CustomCallOp op) { + return op->getAttrOfType(kTfBackendConfigAttrName); +} + +FailureOr GetTfFuncCustomCallFuncName( + stablehlo::CustomCallOp op) { + if (!IsTfFuncCustomCall(op)) { + return success(nullptr); + } + + auto config = GetTfBackendConfig(op); + if (config == nullptr) { + op.emitOpError() << "does not have dictionary attribute '" + << kTfBackendConfigAttrName << "'"; + return failure(); + } + + auto f = config.get(kCalledFuncAttrName); + if (f == nullptr) { + op.emitOpError() << "does not have attribute '" << kCalledFuncAttrName + << "' in its dictionary attribute '" + << kTfBackendConfigAttrName << "'"; + return failure(); + } + + if (auto attr = f.dyn_cast()) { + return attr; + } + + op.emitOpError() << "'s attribute '" << kCalledFuncAttrName + << "' is neither StringAttr nor FlatSymbolRefAttr"; + return failure(); +} + +} // namespace TF +} // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/utils/stablehlo_custom_call.h b/tensorflow/compiler/mlir/tensorflow/utils/stablehlo_custom_call.h new file mode 100644 index 00000000000..7bb38112f77 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/utils/stablehlo_custom_call.h @@ -0,0 +1,38 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_STABLEHLO_CUSTOM_CALL_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_STABLEHLO_CUSTOM_CALL_H_ + +#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo + +namespace mlir { +namespace TF { + +// Returns whether the custom call op represents a TF function call. +bool IsTfFuncCustomCall(stablehlo::CustomCallOp op); + +// Returns the `called_func` symbol ref attribute in the `tf.backend_config` +// dictionary attribute. +// +// If the op does not represent a TF function call, returns nullptr. +// Otherwise, if the op does not have `caller_name`, returns failure. +FailureOr GetTfFuncCustomCallFuncName( + stablehlo::CustomCallOp op); + +} // namespace TF +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_STABLEHLO_CUSTOM_CALL_H_ diff --git a/tensorflow/compiler/mlir/tensorflow/utils/string_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/string_util.cc new file mode 100644 index 00000000000..7fd832e7604 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/utils/string_util.cc @@ -0,0 +1,57 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/compiler/mlir/tensorflow/utils/string_util.h" + +#include +#include + +#include "llvm/Support/raw_ostream.h" +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project + +namespace tensorflow { + +// Return a string form of `op` including debug information. +std::string OpAsString(mlir::Operation& op) { + std::string out; + llvm::raw_string_ostream op_stream(out); + op.print(op_stream, mlir::OpPrintingFlags() + .elideLargeElementsAttrs() + .assumeVerified() + .skipRegions() + .printGenericOpForm()); + return out; +} + +std::string AttrAsString(mlir::Attribute& attr) { + std::string out; + llvm::raw_string_ostream attr_stream(out); + attr.print(attr_stream); + return out; +} + +std::ostream& operator<<(std::ostream& o, const LoggableOperation& op) { + return o << OpAsString(op.v); +} + +std::ostream& operator<<(std::ostream& o, const LoggableAttribute& attr) { + return o << AttrAsString(attr.v); +} + +std::ostream& operator<<(std::ostream& o, const LoggableStringRef& ref) { + return o << ref.v.str(); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tensorflow/utils/string_util.h b/tensorflow/compiler/mlir/tensorflow/utils/string_util.h new file mode 100644 index 00000000000..56410385c20 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/utils/string_util.h @@ -0,0 +1,60 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_STRING_UTIL_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_STRING_UTIL_H_ + +#include +#include + +#include "llvm/ADT/StringRef.h" +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project + +// Utility functions for dumping operations/attributes as strings and ostream +// bindings. + +namespace tensorflow { +std::string OpAsString(mlir::Operation& op); +std::string AttrAsString(mlir::Attribute& attr); + +// b/281863212 enable automatic without Op/AttrAsString. +// We add logging via a wrapper struct in order to respect ODS and avoid +// multiple symbol definitions if MLIR or someone else decides to add ostream +// definitions for the MLIR symbols. +struct LoggableOperation { + mlir::Operation& v; + // NOLINTNEXTLINE(google-explicit-constructor) + LoggableOperation(mlir::Operation& v) : v(v) {} +}; +std::ostream& operator<<(std::ostream& o, const LoggableOperation& op); + +struct LoggableAttribute { + mlir::Attribute& v; + // NOLINTNEXTLINE(google-explicit-constructor) + LoggableAttribute(mlir::Attribute& v) : v(v) {} +}; +std::ostream& operator<<(std::ostream& o, const LoggableAttribute& attr); + +struct LoggableStringRef { + const llvm::StringRef& v; + // NOLINTNEXTLINE(google-explicit-constructor) + LoggableStringRef(const llvm::StringRef& v) : v(v) {} +}; +std::ostream& operator<<(std::ostream& o, const LoggableStringRef& ref); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_STRING_UTIL_H_ diff --git a/tensorflow/compiler/mlir/tensorflow/utils/tf_xla_mlir_translate.cc b/tensorflow/compiler/mlir/tensorflow/utils/tf_xla_mlir_translate.cc index 69671753ca9..2853816dd87 100644 --- 
a/tensorflow/compiler/mlir/tensorflow/utils/tf_xla_mlir_translate.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/tf_xla_mlir_translate.cc @@ -13,9 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include #include #include +#include #include #include "absl/strings/str_join.h" @@ -89,7 +91,7 @@ mlir::LogicalResult PrintHloModuleText( compilation_result.computation->proto(), module_config); if (!status_or_hlo_module.ok()) { LOG(ERROR) << "Conversion to HLO module failed: " - << status_or_hlo_module.status().ToString(); + << status_or_hlo_module.status(); return mlir::failure(); } @@ -315,7 +317,7 @@ static mlir::LogicalResult MlirTfToHloTextTranslateFunctionImpl( auto args_status = ParseArgumentShapes(mlir::StringRefToView(input_shapes), arg_shapes); if (!args_status.ok()) { - LOG(ERROR) << args_status.ToString(); + LOG(ERROR) << args_status; return mlir::failure(); } @@ -334,8 +336,7 @@ static mlir::LogicalResult MlirTfToHloTextTranslateFunctionImpl( /*shape_determination_fns=*/{}, &compilation_result, custom_legalization_passes); if (!compilation_status.ok()) { - LOG(ERROR) << "TF/XLA compilation failed: " - << compilation_status.ToString(); + LOG(ERROR) << "TF/XLA compilation failed: " << compilation_status; return mlir::failure(); } @@ -351,7 +352,7 @@ static mlir::LogicalResult MlirTfGraphToHloTextTranslateFunction( mlir::StringRefToView(input_shapes), mlir::StringRefToView(input_dtypes), mlir::StringRefToView(input_types), xla_arguments); if (!args_status.ok()) { - LOG(ERROR) << args_status.ToString(); + LOG(ERROR) << args_status; return mlir::failure(); } @@ -363,8 +364,7 @@ static mlir::LogicalResult MlirTfGraphToHloTextTranslateFunction( /*shape_determination_fns=*/{}, &compilation_result, /*custom_legalization_passes=*/{}); if (!compilation_status.ok()) { - LOG(ERROR) << "TF/XLA compilation failed: " - << 
compilation_status.ToString(); + LOG(ERROR) << "TF/XLA compilation failed: " << compilation_status; return mlir::failure(); } @@ -403,7 +403,7 @@ SerializedMlirStringAttrToMlirModuleTranslate(llvm::StringRef input, auto status = DeserializeMlirModule(str_attr.getValue().str(), context, &module_ref); if (!status.ok()) { - LOG(ERROR) << status.ToString(); + LOG(ERROR) << status; return nullptr; } diff --git a/tensorflow/compiler/mlir/tensorflow/utils/tpu_cluster_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/tpu_cluster_util.cc index b07b4ad6f5a..9c82c728f5d 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/tpu_cluster_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/tpu_cluster_util.cc @@ -13,6 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include +#include +#include +#include +#include + #include "mlir/Analysis/CallGraph.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h" #include "tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h" diff --git a/tensorflow/compiler/mlir/tensorflow/utils/tpu_cluster_util.h b/tensorflow/compiler/mlir/tensorflow/utils/tpu_cluster_util.h index 69e4bc0593b..46ead1b827b 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/tpu_cluster_util.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/tpu_cluster_util.h @@ -16,6 +16,10 @@ limitations under the License. 
#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_TPU_CLUSTER_UTIL_H_ #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_TPU_CLUSTER_UTIL_H_ +#include +#include +#include + #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h" diff --git a/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.cc index 449e0532cf0..c7ce98aff86 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.cc @@ -21,6 +21,7 @@ limitations under the License. #include #include #include +#include #include "absl/strings/string_view.h" #include "llvm/ADT/ArrayRef.h" @@ -222,7 +223,7 @@ StatusOr GetFullMeshTPUExecutionDeviceAssignment( // Helper struct for keeping track of task and device for an associated TPU // device coordinate. struct TaskAndDevice { - TaskAndDevice() {} + TaskAndDevice() = default; TaskAndDevice(int task, int device) : task(task), device(device) {} int task = -1; diff --git a/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h b/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h index 77f853be582..183688cd88c 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include #include "absl/strings/string_view.h" #include "llvm/ADT/ArrayRef.h" @@ -42,7 +43,7 @@ inline constexpr absl::string_view kDeviceAssignmentAttr = "device_assignment"; // A TPU device for execution alongside its associated host CPU device. 
struct TPUDeviceAndHost { - TPUDeviceAndHost() {} + TPUDeviceAndHost() = default; TPUDeviceAndHost(llvm::StringRef device, llvm::StringRef host) : device(device), host(host) {} diff --git a/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util_test.cc b/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util_test.cc index 2f33ccd88b2..fb88bc8bc44 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util_test.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util_test.cc @@ -17,7 +17,9 @@ limitations under the License. #include #include +#include #include +#include #include "llvm/ADT/StringRef.h" #include "llvm/Support/FormatVariadic.h" diff --git a/tensorflow/compiler/mlir/tensorflow/utils/visitor.cc b/tensorflow/compiler/mlir/tensorflow/utils/visitor.cc new file mode 100644 index 00000000000..517a56de5de --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/utils/visitor.cc @@ -0,0 +1,132 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/mlir/tensorflow/utils/visitor.h" + +#include + +#include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/Casting.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/Diagnostics.h" // from @llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project +#include "mlir/IR/OwningOpRef.h" // from @llvm-project +#include "mlir/IR/SymbolTable.h" // from @llvm-project +#include "mlir/IR/Visitors.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project + +namespace mlir { +namespace TF { + +WalkResult WalkReachableFunctions( + func::FuncOp func, + llvm::function_ref callback, + SymbolTableCollection* symbol_table) { + llvm::SmallDenseSet visited; + + llvm::SmallVector stack; + stack.push_back(func); + + while (!stack.empty()) { + func::FuncOp f = stack.back(); + stack.pop_back(); + + if (!visited.insert(f).second) { + continue; + } + + WalkResult result = callback(f); + if (result.wasInterrupted()) { + return result; + } else if (result.wasSkipped()) { + continue; + } + + result = f.walk([&](Operation* op) { + const auto uses = SymbolTable::getSymbolUses(op); + if (!uses.has_value()) { + op->emitOpError() << "contains a potentially unknown symbol table"; + return WalkResult::interrupt(); + } + + for (const SymbolTable::SymbolUse& use : *uses) { + func::FuncOp called_func = + symbol_table != nullptr + ? 
symbol_table->lookupNearestSymbolFrom( + use.getUser(), use.getSymbolRef()) + : SymbolTable::lookupNearestSymbolFrom< + func::FuncOp>(use.getUser(), use.getSymbolRef()); + if (called_func == nullptr) { + op->emitOpError() + << "refers to an unknown symbol (expects a function)"; + return WalkResult::interrupt(); + } + stack.push_back(called_func); + } + + return WalkResult::advance(); + }); + if (result.wasInterrupted()) { + return result; + } + } + + return WalkResult::advance(); +} + +FailureOr> CreatePrunedModule( + ModuleOp module, llvm::ArrayRef function_names) { + SymbolTableCollection symbol_table; + OpBuilder builder(module.getContext()); + + OwningOpRef pruned = + builder.create(module->getLoc()); + (*pruned)->setAttrs(module->getAttrs()); + builder.setInsertionPointToEnd(pruned->getBody()); + + llvm::SmallDenseSet added; + for (const llvm::StringRef function_name : function_names) { + auto func = + llvm::dyn_cast_or_null(symbol_table.lookupSymbolIn( + module, builder.getStringAttr(function_name))); + if (func == nullptr) { + return module.emitError() + << "Cannot find function '" << function_name << "'"; + } + + const WalkResult result = WalkReachableFunctions( + func, + [&](func::FuncOp f) { + if (!added.insert(f).second) { + return WalkResult::skip(); + } + builder.clone(*f); + return WalkResult::advance(); + }, + &symbol_table); + if (result.wasInterrupted()) { + return failure(); + } + } + + return pruned; +} + +} // namespace TF +} // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/utils/visitor.h b/tensorflow/compiler/mlir/tensorflow/utils/visitor.h new file mode 100644 index 00000000000..6a7ada0bdb8 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/utils/visitor.h @@ -0,0 +1,50 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_VISITOR_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_VISITOR_H_ + +#include "llvm/ADT/StringRef.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/OwningOpRef.h" // from @llvm-project +#include "mlir/IR/SymbolTable.h" // from @llvm-project +#include "mlir/IR/Visitors.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project + +namespace mlir { +namespace TF { + +// Walks the function by following function call chains and calling the callback +// for each reachable function (including `func`). Each function is visited only +// once even if it's called from multiple places and/or recursively. +// +// The current implementation follows direct calls to `mlir::func::FuncOp` only +// and returns a `mlir::WalkResult::interrupt()` when it encounters a call whose +// callee cannot be resolved to `mlir::func::FuncOp`. +mlir::WalkResult WalkReachableFunctions( + mlir::func::FuncOp func, + llvm::function_ref callback, + mlir::SymbolTableCollection* symbol_table = nullptr); + +// Creates a new MLIR module that contains only the given functions and all +// reachable functions from them. 
+mlir::FailureOr> CreatePrunedModule( + mlir::ModuleOp module, llvm::ArrayRef function_names); + +} // namespace TF +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_VISITOR_H_ diff --git a/tensorflow/compiler/mlir/tensorflow/utils/xla_call_module_attrs.h b/tensorflow/compiler/mlir/tensorflow/utils/xla_call_module_attrs.h new file mode 100644 index 00000000000..5f8275b21d9 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/utils/xla_call_module_attrs.h @@ -0,0 +1,44 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_XLA_CALL_MODULE_ATTRS_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_XLA_CALL_MODULE_ATTRS_H_ + +#include "llvm/ADT/StringRef.h" + +namespace mlir { +namespace TF { + +// The main function's name in the serialized stablehlo module embedded in +// XlaCallModule's `module` attribute. +constexpr llvm::StringRef kStablehloMainFunctionName = "main"; + +// After deserializing the stablehlo functions from XlaCallModule, +// this XlaCallModule attribute refers to the deserialized stablehlo main +// function. +constexpr llvm::StringRef kStablehloEntryFunctionAttrName = "_entry_function"; + +// Every stablehlo function deserialized from XlaCallModule has this attribute. 
+constexpr llvm::StringRef kFromXlaCallModuleAttrName = "_from_xla_call_module"; + +// Name of `tf.XlaCallModule`'s dictionary attribute for keeping the +// deserialized stablehlo module's attributes. +constexpr llvm::StringRef kStablehloModuleAttrsAttrName = + "_stablehlo_module_attrs"; + +} // namespace TF +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_XLA_CALL_MODULE_ATTRS_H_ diff --git a/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.cc index e55ba55caf9..838624e0d2f 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.cc @@ -41,47 +41,6 @@ namespace { constexpr char kNumSplitAttr[] = "num_split"; -// Gets the proper tensor dimension from XLA OpSharding. -// "replicate_on_last_tile_dim" and "last_tile_dims" should be deducted from the -// real Tensor dimensions when tiled. -// For example: -// f32[8,512](sharding={devices=[1,1,2]0,1 last_tile_dims={REPLICATED}) -// also means a replicated tensor over all devices. -// -// See xla_data.proto for detailed explanations on the fields. -int GetDimsFromXLAShardingTiled(const xla::OpSharding& xla_sharding) { - return xla_sharding.tile_assignment_dimensions_size() - - (xla_sharding.replicate_on_last_tile_dim() ? 1 : 0) - - xla_sharding.last_tile_dims_size(); -} - -// A sharding with OTHER type may be REPLICATED if: -// 'replicate_on_last_tile_dim' is true OR -// 'last_tile_dims' is not empty -// AND -// other than replicated last tile dims, all other dims are not sharded. 
-bool IsOtherReplicatedSharding(const xla::OpSharding& xla_sharding) { - int max_dim = GetDimsFromXLAShardingTiled(xla_sharding); - for (int i = 0; i < max_dim; ++i) { - if (xla_sharding.tile_assignment_dimensions(i) != 1) { - return false; - } - } - return xla_sharding.type() == xla::OpSharding::OTHER && - (xla_sharding.replicate_on_last_tile_dim() || - !xla_sharding.last_tile_dims().empty()); -} - -bool IsSplitSharding(const xla::OpSharding& sharding) { - return sharding.type() == xla::OpSharding::OTHER && - !IsOtherReplicatedSharding(sharding); -} - -bool IsReplicatedSharding(const xla::OpSharding& sharding) { - return sharding.type() == xla::OpSharding::REPLICATED || - IsOtherReplicatedSharding(sharding); -} - // Creates a tf::SplitOp that splits 'src_input' into 'num_splits' ways // in 'split_dimension' dimension and returns the split values. mlir::LogicalResult CreateSplitOp(const int num_split, @@ -241,6 +200,34 @@ bool UnsupportedPartitionedShardingType(xla::OpSharding::Type sharding) { } // namespace +int GetDimsFromXLAShardingTiled(const xla::OpSharding& xla_sharding) { + return xla_sharding.tile_assignment_dimensions_size() - + (xla_sharding.replicate_on_last_tile_dim() ? 
1 : 0) - + xla_sharding.last_tile_dims_size(); +} + +bool IsOtherReplicatedSharding(const xla::OpSharding& xla_sharding) { + int max_dim = GetDimsFromXLAShardingTiled(xla_sharding); + for (int i = 0; i < max_dim; ++i) { + if (xla_sharding.tile_assignment_dimensions(i) != 1) { + return false; + } + } + return xla_sharding.type() == xla::OpSharding::OTHER && + (xla_sharding.replicate_on_last_tile_dim() || + !xla_sharding.last_tile_dims().empty()); +} + +bool IsSplitSharding(const xla::OpSharding& sharding) { + return sharding.type() == xla::OpSharding::OTHER && + !IsOtherReplicatedSharding(sharding); +} + +bool IsReplicatedSharding(const xla::OpSharding& sharding) { + return sharding.type() == xla::OpSharding::REPLICATED || + IsOtherReplicatedSharding(sharding); +} + mlir::LogicalResult ExtractInputsForLogicalDevices( const int num_cores_per_replica, mlir::tf_device::ClusterFuncOp cluster_func, mlir::OpBuilder* builder, diff --git a/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.h b/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.h index 3297b9aa5b5..715a9ce1c1a 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.h @@ -37,25 +37,24 @@ inline constexpr absl::string_view kOutputShardingAttr = // Parses "input_sharding_configuration" attribute and returns a list where i-th // element is a list of mlir::Value's which represent inputs for the TPU -// computation correponding to i-th logical device. If the attribute does not +// computation corresponding to i-th logical device. If the attribute does not // exist, the all inputs are placed on logical core 0. 
mlir::LogicalResult ExtractInputsForLogicalDevices( - const int num_cores_per_replica, - mlir::tf_device::ClusterFuncOp cluster_func, mlir::OpBuilder* builder, + int num_cores_per_replica, mlir::tf_device::ClusterFuncOp cluster_func, + mlir::OpBuilder* builder, llvm::SmallVectorImpl>* input_list); // Extracts a list of OpSharding that represent output sharding configuration of // `tf_device.cluster`. mlir::LogicalResult ParseAndValidateOutputSharding( - const int num_cores_per_replica, - mlir::tf_device::ClusterFuncOp cluster_func, + int num_cores_per_replica, mlir::tf_device::ClusterFuncOp cluster_func, mlir::SmallVector* output_sharding_list); // Retrieves output types for TPUExecute op representing execution for provided // logical device id. TPUExecute op for different logical device may have // different outputs depending on the output sharding configuration. mlir::LogicalResult GetOutputTypesForLogicalDeviceComputation( - const int core_id, llvm::ArrayRef output_sharding_config, + int core_id, llvm::ArrayRef output_sharding_config, mlir::tf_device::ClusterFuncOp cluster_func, llvm::SmallVectorImpl* output_types, llvm::SmallVectorImpl* cluster_to_core_index); @@ -80,6 +79,31 @@ mlir::LogicalResult RemapOutputsFromLogicalDevices( llvm::SmallVector, 4> GetMetadataArgumentMapping( const tpu::TPUCompileMetadataProto& metadata); +// Gets the proper tensor dimension from XLA OpSharding. +// "replicate_on_last_tile_dim" and "last_tile_dims" should be deducted from the +// real Tensor dimensions when tiled. +// For example: +// f32[8,512](sharding={devices=[1,1,2]0,1 last_tile_dims={REPLICATED}) +// also means a replicated tensor over all devices. +// +// See xla_data.proto for detailed explanations on the fields. 
+int GetDimsFromXLAShardingTiled(const xla::OpSharding& xla_sharding); + +// A sharding with OTHER type may be REPLICATED if: +// 'replicate_on_last_tile_dim' is true OR +// 'last_tile_dims' is not empty +// AND +// other than replicated last tile dims, all other dims are not sharded. +bool IsOtherReplicatedSharding(const xla::OpSharding& xla_sharding); + +// Returns whether the sharding is split sharding. i.e. A sharding with OTHER +// type but not replicated. +bool IsSplitSharding(const xla::OpSharding& sharding); + +// Returns whether the sharding is replicated. It includes sharding with +// REPLICATED type and replicated OTHER type. +bool IsReplicatedSharding(const xla::OpSharding& sharding); + } // namespace tensorflow #endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_XLA_SHARDING_UTIL_H_ diff --git a/tensorflow/compiler/mlir/tf2xla/api/v0/BUILD b/tensorflow/compiler/mlir/tf2xla/api/v0/BUILD index 18744b3032f..72bee83a841 100644 --- a/tensorflow/compiler/mlir/tf2xla/api/v0/BUILD +++ b/tensorflow/compiler/mlir/tf2xla/api/v0/BUILD @@ -27,6 +27,7 @@ cc_library( "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", "//tensorflow/compiler/mlir/tensorflow:translate_utils", "//tensorflow/compiler/mlir/tf2xla:mlir_bridge_rollout_policy", + "//tensorflow/compiler/mlir/tf2xla/internal:mlir_pass_instrumentation", "//tensorflow/compiler/mlir/tf2xla/transforms:tf_xla_passes", "//tensorflow/compiler/mlir/tf2xla/transforms:xla_legalize_targets", "//tensorflow/compiler/mlir/tf2xla/transforms:xla_legalize_tf", diff --git a/tensorflow/compiler/mlir/tf2xla/api/v0/compile_mlir_util.cc b/tensorflow/compiler/mlir/tf2xla/api/v0/compile_mlir_util.cc index 19c148214e1..108f3760f4f 100644 --- a/tensorflow/compiler/mlir/tf2xla/api/v0/compile_mlir_util.cc +++ b/tensorflow/compiler/mlir/tf2xla/api/v0/compile_mlir_util.cc @@ -56,6 +56,7 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" #include "tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.h" #include "tensorflow/compiler/mlir/tensorflow/utils/translate_utils.h" +#include "tensorflow/compiler/mlir/tf2xla/internal/mlir_pass_instrumentation.h" #include "tensorflow/compiler/mlir/tf2xla/transforms/passes.h" #include "tensorflow/compiler/mlir/tf2xla/transforms/xla_legalize_targets.h" #include "tensorflow/compiler/tf2xla/layout_util.h" @@ -518,6 +519,10 @@ Status LegalizeToHlo(mlir::ModuleOp module_op, llvm::StringRef device_type, CreateConvertMlirToXlaHloPipeline(tf2xla, device_type, enable_op_fallback, custom_legalization_passes); + auto pass_instrumentors = mlir::GetPassInstrumentors(); + for (const auto& creator : pass_instrumentors) { + tf2xla.addInstrumentation(creator()); + } if (DEBUG_DATA_DUMPER()->ShouldDump(module_name.str(), kDebugGroupMain) || VLOG_IS_ON(1)) { tensorflow::DumpMlirOpToFile( diff --git a/tensorflow/compiler/mlir/tf2xla/api/v1/BUILD b/tensorflow/compiler/mlir/tf2xla/api/v1/BUILD index a95e558f506..d9d4a963648 100644 --- a/tensorflow/compiler/mlir/tf2xla/api/v1/BUILD +++ b/tensorflow/compiler/mlir/tf2xla/api/v1/BUILD @@ -47,6 +47,7 @@ cc_library( "//tensorflow/core/tpu/kernels:tpu_util_hdrs", "//tensorflow/tsl/platform:status", "//tensorflow/tsl/platform:statusor", + "@com_google_absl//absl/log", "@com_google_absl//absl/status", "@com_google_absl//absl/types:variant", "@llvm-project//mlir:IR", @@ -71,7 +72,7 @@ tf_cc_test( "//tensorflow/core/tpu/kernels:tpu_compile_op_support", "//tensorflow/tsl/lib/monitoring:test_utils", "//tensorflow/tsl/platform:statusor", - "@com_google_googletest//:gtest_main", + "@com_google_googletest//:gtest", ], ) diff --git a/tensorflow/compiler/mlir/tf2xla/api/v1/legalize_tf.cc b/tensorflow/compiler/mlir/tf2xla/api/v1/legalize_tf.cc index f5f6818d33e..95913f9692a 100644 --- a/tensorflow/compiler/mlir/tf2xla/api/v1/legalize_tf.cc +++ 
b/tensorflow/compiler/mlir/tf2xla/api/v1/legalize_tf.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include +#include "absl/log/log.h" #include "absl/status/status.h" #include "absl/types/variant.h" #include "mlir/IR/BuiltinOps.h" // from @llvm-project @@ -266,6 +267,27 @@ tsl::StatusOr LegalizeMlirToHlo( return old_bridge_status; } + if (VLOG_IS_ON(2)) { + xla::DebugOptions debug_options; + TF_ASSIGN_OR_RETURN( + auto hlo_module_config, + xla::HloModule::CreateModuleConfigFromProto( + compilation_result.computation->proto(), debug_options)); + + TF_ASSIGN_OR_RETURN( + std::unique_ptr hlo_module, + xla::HloModule::CreateFromProto(compilation_result.computation->proto(), + hlo_module_config)); + + std::string all_computations; + for (auto computation : hlo_module->computations()) { + all_computations += computation->ToString() + "\n\n"; + } + + tensorflow::DumpRawStringToFile("legalize_tf_fallback_hlo", + all_computations); + } + if (filtered_graph) { mlir_second_phase_count->GetCell(kOldBridgeMlirFilteredSuccess) ->IncrementBy(1); diff --git a/tensorflow/compiler/mlir/tf2xla/internal/BUILD b/tensorflow/compiler/mlir/tf2xla/internal/BUILD new file mode 100644 index 00000000000..6913853f682 --- /dev/null +++ b/tensorflow/compiler/mlir/tf2xla/internal/BUILD @@ -0,0 +1,31 @@ +load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") +load("//tensorflow:tensorflow.bzl", "tf_cc_test") + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = [ + "//tensorflow/compiler/mlir/tf2xla/api/v0:__subpackages__", + ], +) + +cc_library( + name = "mlir_pass_instrumentation", + srcs = ["mlir_pass_instrumentation.cc"], + hdrs = ["mlir_pass_instrumentation.h"], + deps = [ + "//tensorflow/core/platform:logging", + "@llvm-project//mlir:Pass", + ], +) + +tf_cc_test( + name = "mlir_pass_instrumentation_test", + srcs = ["mlir_pass_instrumentation_test.cc"], + deps = [ + ":mlir_pass_instrumentation", + 
"//tensorflow/compiler/mlir/tf2xla/api/v0:compile_mlir_util_no_tf_dialect_passes", + "//tensorflow/core:test", + "//tensorflow/tsl/lib/core:status_test_util", + "@com_google_googletest//:gtest_main", + ], +) diff --git a/tensorflow/compiler/mlir/tf2xla/internal/mlir_pass_instrumentation.cc b/tensorflow/compiler/mlir/tf2xla/internal/mlir_pass_instrumentation.cc new file mode 100644 index 00000000000..f6366f47011 --- /dev/null +++ b/tensorflow/compiler/mlir/tf2xla/internal/mlir_pass_instrumentation.cc @@ -0,0 +1,66 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/mlir/tf2xla/internal/mlir_pass_instrumentation.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "tensorflow/core/platform/logging.h" + +namespace mlir { + +class MlirPassInstrumentationRegistry { + public: + static MlirPassInstrumentationRegistry& Instance() { + static MlirPassInstrumentationRegistry* r = + new MlirPassInstrumentationRegistry; + return *r; + } + std::unordered_map()>> + instrumentors_; +}; + +void RegisterPassInstrumentor( + const std::string& name, + std::function()> creator) { + MlirPassInstrumentationRegistry& r = + MlirPassInstrumentationRegistry::Instance(); + auto result = r.instrumentors_.emplace(name, creator); + if (!result.second) { + VLOG(1) << "Duplicate MLIR pass instrumentor registration"; + } +} + +std::vector()>> +GetPassInstrumentors() { + MlirPassInstrumentationRegistry& r = + MlirPassInstrumentationRegistry::Instance(); + std::vector()>> result; + result.reserve(r.instrumentors_.size()); + + std::transform(r.instrumentors_.begin(), r.instrumentors_.end(), + std::back_inserter(result), [](auto v) { return v.second; }); + + return result; +} + +} // namespace mlir diff --git a/tensorflow/compiler/mlir/tf2xla/internal/mlir_pass_instrumentation.h b/tensorflow/compiler/mlir/tf2xla/internal/mlir_pass_instrumentation.h new file mode 100644 index 00000000000..f4375dfc562 --- /dev/null +++ b/tensorflow/compiler/mlir/tf2xla/internal/mlir_pass_instrumentation.h @@ -0,0 +1,36 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TF2XLA_INTERNAL_MLIR_PASS_INSTRUMENTATION_H_ +#define TENSORFLOW_COMPILER_MLIR_TF2XLA_INTERNAL_MLIR_PASS_INSTRUMENTATION_H_ + +#include +#include +#include +#include + +#include "mlir/Pass/PassInstrumentation.h" // from @llvm-project + +namespace mlir { + +void RegisterPassInstrumentor( + const std::string& name, + std::function()> creator); +std::vector()>> +GetPassInstrumentors(); + +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TF2XLA_INTERNAL_MLIR_PASS_INSTRUMENTATION_H_ diff --git a/tensorflow/compiler/mlir/tf2xla/internal/mlir_pass_instrumentation_test.cc b/tensorflow/compiler/mlir/tf2xla/internal/mlir_pass_instrumentation_test.cc new file mode 100644 index 00000000000..b2a8dde0700 --- /dev/null +++ b/tensorflow/compiler/mlir/tf2xla/internal/mlir_pass_instrumentation_test.cc @@ -0,0 +1,109 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/mlir/tf2xla/internal/mlir_pass_instrumentation.h" + +#include +#include +#include +#include +#include +#include + +#include +#include +#include "tensorflow/compiler/mlir/tf2xla/api/v0/compile_mlir_util.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace mlir { +namespace { +static const char* kTestInstrumentationName = "test-intrumentatron"; +static const char* kTestInstrumentationSearch = "tf.Identity"; + +struct StringStream : public llvm::raw_ostream { + StringStream() { SetUnbuffered(); } + ~StringStream() override = default; + uint64_t current_pos() const override { return 0; } + + void write_impl(const char* ptr, size_t size) override { + ss.write(ptr, size); + } + std::stringstream ss; +}; + +class TestPassInstrumentation : public ::testing::Test { + public: + void SetPassThatChangedIdentity(absl::string_view pass_name) { + pass_that_changed_identity_ = pass_name; + } + absl::string_view GetPassThatChangedIdentity() { + return pass_that_changed_identity_; + } + + private: + std::string pass_that_changed_identity_; + friend class TestInstrumentor; +}; + +class TestInstrumentor : public PassInstrumentation { + public: + explicit TestInstrumentor(TestPassInstrumentation* test) : test_(test) {} + + private: + void runBeforePass(Pass* pass, Operation* op) override { + StringStream stream; + op->print(stream, mlir::OpPrintingFlags().useLocalScope()); + ops_seen_by_pass_[pass] = stream.ss.str(); + } + void runAfterPass(Pass* pass, Operation* op) override { + StringStream stream; + op->print(stream, mlir::OpPrintingFlags().useLocalScope()); + if (!absl::StrContains(stream.ss.str(), kTestInstrumentationSearch) && + absl::StrContains(ops_seen_by_pass_[pass], + kTestInstrumentationSearch)) { + test_->SetPassThatChangedIdentity(pass->getName().str()); + } + } + + private: + TestPassInstrumentation* test_; + std::unordered_map 
ops_seen_by_pass_; +}; + +TEST_F(TestPassInstrumentation, CreatedCalledAndSetsPassName) { + RegisterPassInstrumentor(kTestInstrumentationName, [&]() { + return std::make_unique(this); + }); + constexpr char legalization[] = R"( + module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 268 : i32}} { + func.func @main(%arg0: tensor>) -> tensor> { + %0 = "tf.Identity"(%arg0) : (tensor>) -> tensor> + func.return %0 : tensor> + } + })"; + SetPassThatChangedIdentity(""); + std::vector<::tensorflow::TensorShape> arg_shapes = {{1}}; + auto compilation_result = tensorflow::XlaCompilationResult(); + + TF_EXPECT_OK(tensorflow::CompileSerializedMlirToXlaHlo( + legalization, arg_shapes, /*device_type=*/"XLA_TPU_JIT", + /*use_tuple_args=*/true, /*enable_op_fallback=*/false, + /*shape_determination_fns=*/{}, &compilation_result)); + + EXPECT_FALSE(GetPassThatChangedIdentity().empty()); +} + +} // namespace +} // namespace mlir diff --git a/tensorflow/compiler/mlir/tf2xla/mlir_bridge_rollout_policy.cc b/tensorflow/compiler/mlir/tf2xla/mlir_bridge_rollout_policy.cc index 6479253dd6e..3f35813744c 100644 --- a/tensorflow/compiler/mlir/tf2xla/mlir_bridge_rollout_policy.cc +++ b/tensorflow/compiler/mlir/tf2xla/mlir_bridge_rollout_policy.cc @@ -24,7 +24,7 @@ namespace tensorflow { MlirBridgeRolloutPolicy GetMlirBridgeRolloutPolicy( const tensorflow::Graph& graph, const FunctionLibraryDefinition* function_library, - std::optional config_proto, bool is_tpu_graph, + std::optional config_proto, bool run_tpu_bridge, bool uses_uninitialized_resource_args, bool is_v1_compat, bool record_stats) { switch (GetMlirBridgeRolloutState(config_proto)) { diff --git a/tensorflow/compiler/mlir/tf2xla/mlir_bridge_rollout_policy.h b/tensorflow/compiler/mlir/tf2xla/mlir_bridge_rollout_policy.h index 9f67442205d..5c7f47a219e 100644 --- a/tensorflow/compiler/mlir/tf2xla/mlir_bridge_rollout_policy.h +++ b/tensorflow/compiler/mlir/tf2xla/mlir_bridge_rollout_policy.h @@ -53,7 
+53,7 @@ enum class MlirBridgeRolloutPolicy { MlirBridgeRolloutPolicy GetMlirBridgeRolloutPolicy( const tensorflow::Graph& graph, const FunctionLibraryDefinition* function_library, - std::optional config_proto, bool is_tpu_graph, + std::optional config_proto, bool run_tpu_bridge, bool uses_uninitialized_resource_args, bool is_v1_compat, bool record_stats); diff --git a/tensorflow/compiler/mlir/tf2xla/tests/BUILD b/tensorflow/compiler/mlir/tf2xla/tests/BUILD index a728ed58ad9..c68c485954d 100644 --- a/tensorflow/compiler/mlir/tf2xla/tests/BUILD +++ b/tensorflow/compiler/mlir/tf2xla/tests/BUILD @@ -7,6 +7,7 @@ package( ) glob_lit_tests( + name = "all_tests", data = [":test_utilities"], driver = "@llvm-project//mlir:run_lit.sh", size_override = { diff --git a/tensorflow/compiler/mlir/tf2xla/tests/convert-mhlo-quant-to-int.mlir b/tensorflow/compiler/mlir/tf2xla/tests/convert-mhlo-quant-to-int.mlir index 849c270f083..947c9f85624 100644 --- a/tensorflow/compiler/mlir/tf2xla/tests/convert-mhlo-quant-to-int.mlir +++ b/tensorflow/compiler/mlir/tf2xla/tests/convert-mhlo-quant-to-int.mlir @@ -7,21 +7,21 @@ func.func @uniform_quantize_and_dequantize(%arg0: tensor) -> tensor : tensor // CHECK-DAG: %[[QUANT_MIN:.*]] = mhlo.constant dense<-128> : tensor // CHECK-DAG: %[[QUANT_MAX:.*]] = mhlo.constant dense<127> : tensor - // CHECK: %[[VAL0:.*]] = chlo.broadcast_divide %arg0, %[[SCALES]] {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor, tensor) -> tensor - // CHECK: %[[VAL1:.*]] = chlo.broadcast_add %[[VAL0]], %[[HALF]] {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor, tensor) -> tensor + // CHECK: %[[VAL0:.*]] = chlo.broadcast_divide %arg0, %[[SCALES]] : (tensor, tensor) -> tensor + // CHECK: %[[VAL1:.*]] = chlo.broadcast_add %[[VAL0]], %[[HALF]] : (tensor, tensor) -> tensor // CHECK: %[[VAL2:.*]] = mhlo.floor %[[VAL1]] : tensor // CHECK: %[[VAL3:.*]] = mhlo.convert %[[VAL2]] : (tensor) -> tensor - // CHECK: %[[VAL4:.*]] = chlo.broadcast_add %[[VAL3]], 
%[[ZPS]] {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor, tensor) -> tensor - // CHECK: %[[VAL5:.*]] = chlo.broadcast_maximum %[[VAL4]], %[[QUANT_MIN]] {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor, tensor) -> tensor - // CHECK: %[[VAL6:.*]] = chlo.broadcast_minimum %[[VAL5]], %[[QUANT_MAX]] {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor, tensor) -> tensor + // CHECK: %[[VAL4:.*]] = chlo.broadcast_add %[[VAL3]], %[[ZPS]] : (tensor, tensor) -> tensor + // CHECK: %[[VAL5:.*]] = chlo.broadcast_maximum %[[VAL4]], %[[QUANT_MIN]] : (tensor, tensor) -> tensor + // CHECK: %[[VAL6:.*]] = chlo.broadcast_minimum %[[VAL5]], %[[QUANT_MAX]] : (tensor, tensor) -> tensor // CHECK: %[[VAL7:.*]] = mhlo.convert %[[VAL6]] : (tensor) -> tensor // CHECK-DAG: %[[SCALES_DQ:.*]] = mhlo.constant dense<1.000000e+00> : tensor // CHECK-DAG: %[[ZPS_DQ:.*]] = mhlo.constant dense<3> : tensor // CHECK: %[[VAL8:.*]] = mhlo.convert %[[VAL7]] : (tensor) -> tensor - // CHECK: %[[VAL9:.*]] = chlo.broadcast_subtract %[[VAL8]], %[[ZPS_DQ]] {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor, tensor) -> tensor + // CHECK: %[[VAL9:.*]] = chlo.broadcast_subtract %[[VAL8]], %[[ZPS_DQ]] : (tensor, tensor) -> tensor // CHECK: %[[VAL10:.*]] = mhlo.convert %[[VAL9]] : (tensor) -> tensor - // CHECK: %[[VAL11:.*]] = chlo.broadcast_multiply %[[VAL10]], %[[SCALES_DQ]] {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor, tensor) -> tensor + // CHECK: %[[VAL11:.*]] = chlo.broadcast_multiply %[[VAL10]], %[[SCALES_DQ]] : (tensor, tensor) -> tensor // CHECK: return %[[VAL11]] : tensor %0 = mhlo.uniform_quantize %arg0 : (tensor) -> tensor> %1 = mhlo.uniform_dequantize %0 : (tensor>) -> tensor @@ -33,7 +33,7 @@ func.func @uniform_quantize_and_dequantize(%arg0: tensor) -> tensor>) -> () { // CHECK: %[[QUANTIZED:.*]] = mhlo.convert %[[VAL0:.*]] : (tensor>) -> tensor> - // CHECK: %[[DEQUANTIZED:.*]] = chlo.broadcast_multiply %[[VAL1:.*]], %[[CONST_SCALE:.*]] 
{broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor>, tensor) -> tensor> + // CHECK: %[[DEQUANTIZED:.*]] = chlo.broadcast_multiply %[[VAL1:.*]], %[[CONST_SCALE:.*]] : (tensor>, tensor) -> tensor> %0 = mhlo.uniform_quantize %arg0 : (tensor>) -> tensor, #mhlo.type_extensions> %1 = mhlo.uniform_dequantize %0 : (tensor, #mhlo.type_extensions>) -> tensor> return @@ -42,10 +42,10 @@ func.func @uniform_quantize_and_dequantize_type_exensions(%arg0: tensor>) -> () { - // CHECK: %[[QUANTIZED:.*]] = mhlo.convert %[[VAL0:.*]] : (tensor>) -> tensor> - // CHECK: %[[DEQUANTIZED:.*]] = chlo.broadcast_multiply %[[VAL1:.*]], %[[CONST_SCALE:.*]] {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor>, tensor) -> tensor> - %0 = mhlo.uniform_quantize %arg0 : (tensor>) -> tensor, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ] }>> - %1 = mhlo.uniform_dequantize %0 : (tensor, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ] }>>) -> tensor> +func.func @uniform_quantize_and_dequantize_sparse_tensor_encoding(%arg0: tensor>) -> () { + // CHECK: %[[QUANTIZED:.*]] = mhlo.convert %[[VAL0:.*]] : (tensor>) -> tensor> + // CHECK: %[[DEQUANTIZED:.*]] = chlo.broadcast_multiply %[[VAL1:.*]], %[[CONST_SCALE:.*]] : (tensor>, tensor) -> tensor> + %0 = mhlo.uniform_quantize %arg0 : (tensor>) -> tensor, #sparse_tensor.encoding<{ lvlTypes = [ "compressed" ] }>> + %1 = mhlo.uniform_dequantize %0 : (tensor, #sparse_tensor.encoding<{ lvlTypes = [ "compressed" ] }>>) -> tensor> return } diff --git a/tensorflow/compiler/mlir/tf2xla/tests/hlo_xla_runtime_pipeline.mlir b/tensorflow/compiler/mlir/tf2xla/tests/hlo_xla_runtime_pipeline.mlir index 28e1c0c37a1..11e19b3b1a6 100644 --- a/tensorflow/compiler/mlir/tf2xla/tests/hlo_xla_runtime_pipeline.mlir +++ b/tensorflow/compiler/mlir/tf2xla/tests/hlo_xla_runtime_pipeline.mlir @@ -9,7 +9,7 @@ func.func @simple_add(%arg0: tensor) -> tensor { // ----- -#CSR = #sparse_tensor.encoding<{dimLevelType = [ "dense", "compressed" ]}> +#CSR = 
#sparse_tensor.encoding<{lvlTypes = [ "dense", "compressed" ]}> // CHECK-LABEL: func.func @csr_gendot( // CHECK-SAME: %[[PTR:.*0]]: memref, @@ -53,8 +53,8 @@ func.func @csr_gendot(%arg0: tensor<32x64xf64, #CSR>, // ----- -#CSR = #sparse_tensor.encoding<{ dimLevelType = ["dense", "compressed"] }> -#DCSR = #sparse_tensor.encoding<{ dimLevelType = ["compressed", "compressed"] }> +#CSR = #sparse_tensor.encoding<{ lvlTypes = ["dense", "compressed"] }> +#DCSR = #sparse_tensor.encoding<{ lvlTypes = ["compressed", "compressed"] }> // CHECK-LABEL: func.func @convert_nop( // CHECK-SAME: %[[PTR:.*0]]: memref, diff --git a/tensorflow/compiler/mlir/tf2xla/tests/hlo_xla_sparsification.mlir b/tensorflow/compiler/mlir/tf2xla/tests/hlo_xla_sparsification.mlir index 3a5d2e95e24..920c2ab744c 100644 --- a/tensorflow/compiler/mlir/tf2xla/tests/hlo_xla_sparsification.mlir +++ b/tensorflow/compiler/mlir/tf2xla/tests/hlo_xla_sparsification.mlir @@ -1,6 +1,6 @@ // RUN: tf-opt -hlo-legalize-to-linalg -hlo-xla-runtime-sparsification %s | FileCheck %s -#SparseVector = #sparse_tensor.encoding<{ dimLevelType = ["compressed"] }> +#SparseVector = #sparse_tensor.encoding<{ lvlTypes = ["compressed"] }> // CHECK-LABEL: func.func @mult_sparse_dense( // CHECK-SAME: %[[PTR:.*0]]: memref, diff --git a/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-collective.mlir b/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-collective.mlir index fefaeb7d589..bd1b01fa171 100644 --- a/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-collective.mlir +++ b/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-collective.mlir @@ -89,6 +89,15 @@ func.func @xla_all_reduce_mul(%input: tensor) -> tensor { func.return %0 : tensor } +// ----- + +func.func @xla_all_reduce_tuple(%input: tuple, tensor>) -> tuple, tensor> { + %group_assignment = "tf.Const"() { value = dense<[[0],[1]]> : tensor<2x1xi32> } : () -> tensor<2x1xi32> + // expected-error@+1 {{'tf.XlaAllReduce' op operand #0 must be tensor of bfloat16 or 16-bit float 
or 32-bit float or 32-bit integer or 32-bit unsigned integer values, but got 'tuple, tensor>'}} + %0 = "tf.XlaAllReduce"(%input, %group_assignment) {reduce_op = "Add", mode = "CrossReplica"} : (tuple, tensor>, tensor<2x1xi32>) -> tuple, tensor> + func.return %0 : tuple, tensor> +} + // ----- diff --git a/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-no-tf2xla-fallback.mlir b/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-no-tf2xla-fallback.mlir index 3ca8dc09a80..92fa37f7e44 100644 --- a/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-no-tf2xla-fallback.mlir +++ b/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-no-tf2xla-fallback.mlir @@ -4352,7 +4352,7 @@ func.func @cross_replica_sum(%input: tensor<10xf32>) -> tensor<10xf32> { // CHECK-LABEL: conv_dynamic func.func @conv_dynamic(%arg0: tensor, %arg1: tensor<3x3x3x16xf32>) -> tensor { // CHECK: "mhlo.dynamic_conv" - // CHECK-SAME: {batch_group_count = 1 : i64, dimension_numbers = #mhlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>, feature_group_count = 2 : i64, rhs_dilation = dense<[2, 3]> : tensor<2xi64>, window_strides = dense<[4, 5]> : tensor<2xi64>} : (tensor, tensor<3x3x3x16xf32>, tensor<4xi32>) -> tensor + // CHECK-SAME: {batch_group_count = 1 : i64, dimension_numbers = #mhlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>, feature_group_count = 2 : i64, precision_config = [#mhlo, #mhlo], rhs_dilation = dense<[2, 3]> : tensor<2xi64>, window_strides = dense<[4, 5]> : tensor<2xi64>} : (tensor, tensor<3x3x3x16xf32>, tensor<4xi32>) -> tensor %0 = "tf.Conv2D"(%arg0, %arg1) {data_format = "NHWC", dilations = [1, 2, 3, 1], padding = "SAME", strides = [1, 4, 5, 1]} : (tensor, tensor<3x3x3x16xf32>) -> tensor func.return %0 : tensor } diff --git a/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-with-tf2xla-hlo-importer.mlir b/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-with-tf2xla-hlo-importer.mlir index 1d6bfb6bcd7..730e3ec09c4 100644 --- 
a/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-with-tf2xla-hlo-importer.mlir +++ b/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-with-tf2xla-hlo-importer.mlir @@ -553,4 +553,157 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr %values, %indices = "tf.ApproxTopK"(%0) {aggregate_to_topk = true, device = "", is_max_k = true, k = 10 : i64, recall_target = 0.949999988 : f32, reduction_dimension = -1 : i64, reduction_input_size_override = -1 : i64} : (tensor<10x500xbf16>) -> (tensor<10x10xbf16>, tensor<10x10xi32>) return %values : tensor<10x10xbf16> } + + // CHECK-LABEL: fusedBatchNormV3_noTraining + func.func @fusedBatchNormV3_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) { + // CHECK: "mhlo.batch_norm_inference"({{.*}}, %arg1, %arg2, %arg3, %arg4) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> tensor<8x8x8x8xf32> + %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) + func.return %0#0 : tensor<8x8x8x8xf32> + } + + // CHECK-LABEL: fusedBatchNormV3_training + func.func @fusedBatchNormV3_training(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) { + // CHECK: %[[OUT:.*]], %[[MEAN:.*]], %[[VAR:.*]] = "mhlo.batch_norm_training"({{.*}}, %arg1, %arg2) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) + %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, 
%arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, exponential_avg_factor = 1.0 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) + func.return %0#0 : tensor<8x8x8x8xf32> + } + + // CHECK-LABEL: fusedBatchNormGradV3_noTraining + func.func @fusedBatchNormGradV3_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>, %arg5: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) { + // CHECK-NEXT: %[[grad:.*]] = mhlo.convert %arg0 : tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[act:.*]] = mhlo.convert %arg1 : tensor<8x8x8x8xf32> + // CHECK: %[[scr1:.*]] = mhlo.rsqrt + // CHECK: %[[bcast_arg3:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg3, {{.*}}) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32> + // CHECK: %[[sub:.*]] = mhlo.subtract %[[act]], %[[bcast_arg3]] : tensor<8x8x8x8xf32> + // CHECK: %[[mul:.*]] = mhlo.multiply %[[grad]], %[[sub]] : tensor<8x8x8x8xf32> + // CHECK: mhlo.constant dense<[0, 1, 2]> : tensor<3xi64> + // CHECK-NEXT: %[[cmul:.*]] = mhlo.convert %[[mul]] : tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[init:.*]] = mhlo.constant dense<0.000000e+00> : tensor + // CHECK-NEXT: %[[convert_init:.*]] = mhlo.convert %[[init]] : tensor + // CHECK: %[[red1:.*]] = mhlo.reduce(%[[cmul]] init: %[[convert_init]]) across dimensions = [0, 1, 2] : (tensor<8x8x8x8xf32>, tensor) -> tensor<8xf32> + // CHECK: %[[scr2:.*]] = mhlo.convert %[[red1]] : tensor<8xf32> + + // CHECK: %[[mul2:.*]] = mhlo.multiply %arg2, %[[scr1]] : tensor<8xf32> + // CHECK: %[[bcast_mul2:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[mul2]], {{.*}}) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32> + // CHECK: %[[mul3:.*]] = mhlo.multiply %[[grad]], 
%[[bcast_mul2]] : tensor<8x8x8x8xf32> + + // CHECK: %[[scale_backprop:.*]] = mhlo.multiply %[[scr1]], %[[scr2]] : tensor<8xf32> + + // CHECK: mhlo.constant dense<[0, 1, 2]> : tensor<3xi64> + // CHECK: %[[cgrad:.*]] = mhlo.convert %[[grad]] : tensor<8x8x8x8xf32> + // CHECK: %[[init2:.*]] = mhlo.constant dense<0.000000e+00> : tensor + // CHECK-NEXT: %[[convert_init2:.*]] = mhlo.convert %[[init2]] : tensor + // CHECK: %[[red2:.*]] = mhlo.reduce(%[[cgrad]] init: %[[convert_init2]]) across dimensions = [0, 1, 2] : (tensor<8x8x8x8xf32>, tensor) -> tensor<8xf32> + // CHECK: %[[offset_backprop:.*]] = mhlo.convert %[[red2]] : tensor<8xf32> + + // CHECK: %[[x_backprop:.*]] = mhlo.convert %[[mul3]] : tensor<8x8x8x8xf32> + // CHECK: return %[[x_backprop]] : tensor<8x8x8x8xf32> + + %0:5 = "tf.FusedBatchNormGradV3"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) + func.return %0#0 : tensor<8x8x8x8xf32> + } + + // CHECK-LABEL: fusedBatchNormGradV3_Training + func.func @fusedBatchNormGradV3_Training(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>, %arg5: tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<0xf32>, tensor<*xf32>) { + // CHECK-NEXT: %[[grad:.*]] = mhlo.convert %arg0 : tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[act:.*]] = mhlo.convert %arg1 : tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[grad_operand:.*]], %[[grad_scale:.*]], %[[grad_offset:.*]] = "mhlo.batch_norm_grad"(%[[act]], %arg2, %arg3, %arg4, %[[grad]]) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8x8x8x8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) + // CHECK-NEXT: 
%[[x_backprop:.*]] = mhlo.convert %[[grad_operand]] : tensor<8x8x8x8xf32> + // CHECK: return %[[x_backprop]] + // CHECK-SAME: tensor<8x8x8x8xf32> + + %0:5 = "tf.FusedBatchNormGradV3"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<0xf32>, tensor<*xf32>) + func.return %0#0, %0#3, %0#4 : tensor<8x8x8x8xf32>, tensor<0xf32>, tensor<*xf32> + } + + // CHECK-LABEL: @max_pool_grad_valid + // CHECK-SAME: %[[INPUT:.*]]: tensor<10x24x24x64xf32>, %arg1: tensor<10x12x12x64xf32>, %[[GRAD:.*]]: tensor<10x12x12x64xf32> + func.func @max_pool_grad_valid(%orig_input: tensor<10x24x24x64xf32>, %orig_output: tensor<10x12x12x64xf32>, %grad: tensor<10x12x12x64xf32>) -> tensor<10x24x24x64xf32> { + // CHECK: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor + // CHECK: %[[RESULT:.*]] = "mhlo.select_and_scatter"(%[[INPUT]], %[[GRAD]], %[[ZERO]]) ({ + // CHECK: ^bb0(%[[VALUE_A:.*]]: tensor, %[[VALUE_B:.*]]: tensor): + // CHECK: %[[SELECT_RESULT:.*]] = mhlo.compare GE, %[[VALUE_A]], %[[VALUE_B]] : (tensor, tensor) -> tensor + // CHECK: mhlo.return %[[SELECT_RESULT]] : tensor + // CHECK: }, { + // CHECK: ^bb0(%[[VALUE_A:.*]]: tensor, %[[VALUE_B:.*]]: tensor): + // CHECK: %[[SELECT_RESULT:.*]] = mhlo.add %[[VALUE_A]], %[[VALUE_B]] : tensor + // CHECK: mhlo.return %[[SELECT_RESULT]] : tensor + // CHECK: }) {padding = dense<0> : tensor<4x2xi64>, window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>, window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>} : (tensor<10x24x24x64xf32>, tensor<10x12x12x64xf32>, tensor) -> tensor<10x24x24x64xf32> + // CHECK: return %[[RESULT]] : tensor<10x24x24x64xf32> + %result = "tf.MaxPoolGrad"(%orig_input, %orig_output, %grad) { + data_format = "NHWC", + explicit_paddings = [], + ksize = [1, 2, 2, 1], + padding 
= "VALID", + strides = [1, 2, 2, 1] + } : (tensor<10x24x24x64xf32>, tensor<10x12x12x64xf32>, tensor<10x12x12x64xf32>) -> tensor<10x24x24x64xf32> + func.return %result : tensor<10x24x24x64xf32> + } + + // CHECK-LABEL: @max_pool_grad_same + func.func @max_pool_grad_same(%orig_input: tensor<2x13x25x7xf32>, %orig_output: tensor<2x4x7x7xf32>, %grad: tensor<2x4x7x7xf32>) -> tensor<2x13x25x7xf32> { + // CHECK: padding = dense<{{\[\[}}0, 0], [0, 1], [1, 1], [0, 0]]> : tensor<4x2xi64> + %result = "tf.MaxPoolGrad"(%orig_input, %orig_output, %grad) { + data_format = "NHWC", + explicit_paddings = [], + ksize = [1, 2, 3, 1], + padding = "SAME", + strides = [1, 4, 4, 1] + } : (tensor<2x13x25x7xf32>, tensor<2x4x7x7xf32>, tensor<2x4x7x7xf32>) -> tensor<2x13x25x7xf32> + func.return %result : tensor<2x13x25x7xf32> + } + + //===--------------------------------------------------------------------===// + // tf.XlaReduceScatter legalization + //===--------------------------------------------------------------------===// + // CHECK-LABEL: func @xla_reduce_scatter + func.func @xla_reduce_scatter(%arg0: tensor<128x128xf32>) -> tensor<64x128xf32> { + %cst = "tf.Const"() {value = dense<0> : tensor} : () -> tensor + %cst_0 = "tf.Const"() {value = dense<[[0, 4], [1, 5], [2, 6], [3, 7]]> : tensor<4x2xi32>} : () -> tensor<4x2xi32> + // CHECK: "mhlo.reduce_scatter"(%arg0) + // CHECK{LITERAL}: replica_groups = dense<[[0, 4], [1, 5], [2, 6], [3, 7]]> + // CHECK-SAME: scatter_dimension = 0 + // + %1 = "tf.XlaReduceScatter"(%arg0, %cst_0, %cst) {reduce_op = "Add"} : (tensor<128x128xf32>, tensor<4x2xi32>, tensor) -> tensor<64x128xf32> + func.return %1 : tensor<64x128xf32> + } + + // CHECK-LABEL: func @tf_mod + func.func @tf_mod(%arg1: tensor<2x2xf32>) -> tensor<2x2xf32> { + %cst = "tf.Const"() {value = dense<7.000000e+00> : tensor} : () -> tensor + // CHECK: "mhlo.dynamic_broadcast_in_dim" + // CHECK: mhlo.remainder + %6 = "tf.Mod"(%arg1, %cst) {_global_shape = [#tf_type.shape<4x8>], device = ""} : 
(tensor<2x2xf32>, tensor) -> tensor<2x2xf32> + return %6 : tensor<2x2xf32> + } + + // CHECK-LABEL: func @concat_v2 + func.func @concat_v2(%arg0: tensor<3x3xf32>, %arg1: tensor<3x3xf32>) -> tensor<6x3xf32> { + // CHECK: "mhlo.concatenate"({{.*}}) {dimension = 0 : i64} : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<6x3xf32> + %axis = "tf.Const"() { value = dense<0> : tensor } : () -> tensor + %1 = "tf.ConcatV2"(%arg0, %arg1, %axis) : (tensor<3x3xf32>, tensor<3x3xf32>, tensor) -> tensor<6x3xf32> + func.return %1 : tensor<6x3xf32> + } + + // CHECK-LABEL: func @xla_call_module + func.func @xla_call_module(%arg0: tensor) -> tensor<*xf32> { + // Equivalent to the following: + // + // module @jit_sin { + // func.func public @main(%arg0: tensor) -> tensor { + // %0 = mhlo.sine %arg0 : tensor + // return %0 : tensor + // } + // } + // CHECK: call @main.2 + %0 = "tf.XlaCallModule"(%arg0) {Sout = [#tf_type.shape<*>], device = "", dim_args_spec = [], function_list = [], disabled_checks = [], has_token_input_output = false, module = "ML\EFR\03MLIRxxx-trunk\00\01\17\05\01\05\01\03\05\03\07\07\t\0B\03K5\07\01\1B\07\0B\13\0B3\0B\0B\0B\0B\0F\0B\13\0B\03\1B\0F\1B\0B\0B\0B\0B\0B\0F\13\0B\0B\0B\0B\03\07\0F\17\07\02\A7\1F\05\0D\03\03\03\07\05\0F\03\0B\0B\1B\0D'\0F)\031\113\05\11\05\13\05\15\05\17\1D\15\17\05\19\17\19\EF\01\05\1B\03\03\1D\0D\05\1F!#%\1D\1D\1D\1F\1D!\1D##\03\03\03+\0D\03-/\1D%\1D'\1D)\1D+)\01\05\11\03\01\03\01\t\04A\05\01\11\01\05\07\03\01\05\03\11\01\t\05\03\05\0B\03\01\01\05\06\13\03\01\03\01\07\04\01\03\03\06\03\01\05\01\00\9A\04-\0F\0B\03!\1B\1D\05\1B\83/\1F\15\1D\15\11\13\15\11\11\0F\0B\11builtin\00vhlo\00module\00func_v1\00sine_v1\00return_v1\00sym_name\00jit_sin\00arg_attrs\00function_type\00res_attrs\00sym_visibility\00jit(sin)/jit(main)/sin\00third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py\00jax.arg_info\00x\00mhlo.sharding\00{replicated}\00jax.result_info\00\00main\00public\00", platforms = ["CPU"], version = 6 : i64} : (tensor) -> tensor<*xf32> 
+ func.return %0 : tensor<*xf32> + } + + // Verifies that the following functions are added from xla_call_module. Note this must be at the end of the file. + // CHECK: func.func private @main.2(%arg0: tensor {mhlo.sharding = "{replicated}"}) -> tensor { + // CHECK: %0 = mhlo.sine %arg0 : tensor + // CHECK: return %0 : tensor + // CHECK: } + } diff --git a/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-with-tf2xla.mlir b/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-with-tf2xla.mlir index 3e550e0366c..90ef7c88910 100644 --- a/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-with-tf2xla.mlir +++ b/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-with-tf2xla.mlir @@ -491,6 +491,21 @@ func.func @approx_topk(%arg0: tensor>> {tf return %values : tensor<10x10xbf16> } +// CHECK-LABEL: func @xla_call_module +func.func @xla_call_module(%arg0: tensor) -> tensor<*xf32> { + // Equivalent to the following: + // + // module @jit_sin { + // func.func public @main(%arg0: tensor) -> tensor { + // %0 = mhlo.sine %arg0 : tensor + // return %0 : tensor + // } + // } + // expected-remark@+1 {{UNIMPLEMENTED: MlirHloBuilder does not support op call}} + %0 = "tf.XlaCallModule"(%arg0) {Sout = [#tf_type.shape<*>], device = "", dim_args_spec = [], function_list = [], disabled_checks = [], has_token_input_output = false, module = 
"ML\EFR\03MLIRxxx-trunk\00\01\17\05\01\05\01\03\05\03\07\07\t\0B\03K5\07\01\1B\07\0B\13\0B3\0B\0B\0B\0B\0F\0B\13\0B\03\1B\0F\1B\0B\0B\0B\0B\0B\0F\13\0B\0B\0B\0B\03\07\0F\17\07\02\A7\1F\05\0D\03\03\03\07\05\0F\03\0B\0B\1B\0D'\0F)\031\113\05\11\05\13\05\15\05\17\1D\15\17\05\19\17\19\EF\01\05\1B\03\03\1D\0D\05\1F!#%\1D\1D\1D\1F\1D!\1D##\03\03\03+\0D\03-/\1D%\1D'\1D)\1D+)\01\05\11\03\01\03\01\t\04A\05\01\11\01\05\07\03\01\05\03\11\01\t\05\03\05\0B\03\01\01\05\06\13\03\01\03\01\07\04\01\03\03\06\03\01\05\01\00\9A\04-\0F\0B\03!\1B\1D\05\1B\83/\1F\15\1D\15\11\13\15\11\11\0F\0B\11builtin\00vhlo\00module\00func_v1\00sine_v1\00return_v1\00sym_name\00jit_sin\00arg_attrs\00function_type\00res_attrs\00sym_visibility\00jit(sin)/jit(main)/sin\00third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py\00jax.arg_info\00x\00mhlo.sharding\00{replicated}\00jax.result_info\00\00main\00public\00", platforms = ["CPU"], version = 6 : i64} : (tensor) -> tensor<*xf32> + func.return %0 : tensor<*xf32> +} + // TODO(hinsu): Add a test with a valid TF op for which tf2xla kernel is // available but doesn't support this instance. 
} diff --git a/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf.mlir index 19fe43f0250..3bce71fa26a 100644 --- a/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf.mlir +++ b/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf.mlir @@ -4499,7 +4499,7 @@ func.func @cross_replica_sum(%input: tensor<10xf32>) -> tensor<10xf32> { // CHECK-LABEL: conv_dynamic func.func @conv_dynamic(%arg0: tensor, %arg1: tensor<3x3x3x16xf32>) -> tensor { // CHECK: "mhlo.dynamic_conv" - // CHECK-SAME: {batch_group_count = 1 : i64, dimension_numbers = #mhlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>, feature_group_count = 2 : i64, rhs_dilation = dense<[2, 3]> : tensor<2xi64>, window_strides = dense<[4, 5]> : tensor<2xi64>} : (tensor, tensor<3x3x3x16xf32>, tensor<4xi32>) -> tensor + // CHECK-SAME: {batch_group_count = 1 : i64, dimension_numbers = #mhlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>, feature_group_count = 2 : i64, precision_config = [#mhlo, #mhlo], rhs_dilation = dense<[2, 3]> : tensor<2xi64>, window_strides = dense<[4, 5]> : tensor<2xi64>} : (tensor, tensor<3x3x3x16xf32>, tensor<4xi32>) -> tensor %0 = "tf.Conv2D"(%arg0, %arg1) {data_format = "NHWC", dilations = [1, 2, 3, 1], padding = "SAME", strides = [1, 4, 5, 1]} : (tensor, tensor<3x3x3x16xf32>) -> tensor func.return %0 : tensor } diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/BUILD b/tensorflow/compiler/mlir/tf2xla/transforms/BUILD index df4bf8fa204..181d6b582b9 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/BUILD +++ b/tensorflow/compiler/mlir/tf2xla/transforms/BUILD @@ -5,6 +5,7 @@ load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load("//tensorflow:tensorflow.bzl", "tf_cc_test") load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library") load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_cloud") +load("//tensorflow/tsl/platform:build_config_root.bzl", "if_static") package( # copybara:uncomment 
default_applicable_licenses = ["//tensorflow:license"], @@ -159,6 +160,7 @@ cc_library( "//tensorflow/core/kernels:conv_grad_shape_utils", "//tensorflow/tsl/platform:bfloat16", "//tensorflow/tsl/platform:status", + "//tensorflow/tsl/platform:tensor_float_32_hdr_lib", "@llvm-project//llvm:Support", "@llvm-project//mlir:ArithDialect", "@llvm-project//mlir:Dialect", @@ -171,7 +173,7 @@ cc_library( "@llvm-project//mlir:TensorDialect", "@llvm-project//mlir:TransformUtils", "@stablehlo//:chlo_ops", - ], + ] + if_static(["//tensorflow/tsl/platform:tensor_float_32_utils"]), ) cc_library( diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/convert_mhlo_quant_to_int.cc b/tensorflow/compiler/mlir/tf2xla/transforms/convert_mhlo_quant_to_int.cc index 382b059ebb4..0343ae3b96b 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/convert_mhlo_quant_to_int.cc +++ b/tensorflow/compiler/mlir/tf2xla/transforms/convert_mhlo_quant_to_int.cc @@ -114,7 +114,6 @@ class ConvertUniformQuantizeOp op->getLoc(), rewriter.getI32IntegerAttr(static_cast( element_type.getStorageTypeMax()))); - auto scalar_broadcast_dims = GetI64ElementsAttr({}, &rewriter); auto res_float_tensor_type_or = GetSameShapeTensorType(op, op.getOperand().getType().cast(), rewriter.getF32Type(), rewriter); @@ -123,10 +122,9 @@ class ConvertUniformQuantizeOp } Value res_float = rewriter.create( op->getLoc(), *res_float_tensor_type_or, adaptor.getOperand(), scale, - scalar_broadcast_dims); + nullptr); res_float = rewriter.create( - op->getLoc(), *res_float_tensor_type_or, res_float, half, - scalar_broadcast_dims); + op->getLoc(), *res_float_tensor_type_or, res_float, half, nullptr); res_float = rewriter.create(op->getLoc(), res_float); auto res_int32_tensor_type_or = GetSameShapeTensorType(op, res_float.getType().cast(), @@ -138,13 +136,13 @@ class ConvertUniformQuantizeOp op->getLoc(), *res_int32_tensor_type_or, res_float); res_int32 = rewriter.create( op->getLoc(), *res_int32_tensor_type_or, res_int32, zero_point, - 
scalar_broadcast_dims); + nullptr); res_int32 = rewriter.create( op->getLoc(), *res_int32_tensor_type_or, res_int32, quantization_min, - scalar_broadcast_dims); + nullptr); res_int32 = rewriter.create( op->getLoc(), *res_int32_tensor_type_or, res_int32, quantization_max, - scalar_broadcast_dims); + nullptr); auto res_final_tensor_type_or = GetSameShapeTensorType(op, res_int32.getType().cast(), rewriter.getI8Type(), rewriter); @@ -177,7 +175,6 @@ class ConvertUniformDequantizeOp static_cast(element_type.getZeroPoint()))); Value input = adaptor.getOperand(); - auto scalar_broadcast_dims = GetI64ElementsAttr({}, &rewriter); auto res_int32_tensor_type_or = GetSameShapeTensorType(op, input.getType().cast(), rewriter.getI32Type(), rewriter); @@ -188,7 +185,7 @@ class ConvertUniformDequantizeOp op->getLoc(), *res_int32_tensor_type_or, input); res_int32 = rewriter.create( op->getLoc(), *res_int32_tensor_type_or, res_int32, zero_point, - scalar_broadcast_dims); + nullptr); auto res_float_tensor_type_or = GetSameShapeTensorType(op, res_int32.getType().cast(), rewriter.getF32Type(), rewriter); @@ -198,7 +195,7 @@ class ConvertUniformDequantizeOp Value res_float = rewriter.create( op->getLoc(), *res_float_tensor_type_or, res_int32); res_float = rewriter.replaceOpWithNewOp( - op, *res_float_tensor_type_or, res_float, scale, scalar_broadcast_dims); + op, *res_float_tensor_type_or, res_float, scale, nullptr); return success(); } }; @@ -213,6 +210,8 @@ void ConvertMHLOQuantToInt::runOnOperation() { patterns.add(context); ConversionTarget target(*op->getContext()); + // An addDynamicallyLegalDialect callback that declares a given operation as + // legal only if its all operands and results are non-quantized types. 
auto is_legal = [](Operation *op) { auto is_not_quant = [](Type type) { return !getElementTypeOrSelf(type).isa(); diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf.cc index 06d6df007f2..b1bf04bb232 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf.cc +++ b/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf.cc @@ -67,6 +67,7 @@ limitations under the License. #include "tensorflow/core/util/tensor_format.h" #include "tensorflow/tsl/platform/bfloat16.h" #include "tensorflow/tsl/platform/status.h" +#include "tensorflow/tsl/platform/tensor_float_32_utils.h" namespace mlir { namespace mhlo { @@ -150,6 +151,21 @@ static IntegerAttr GetHLOAxisFromTFAxis(Attribute attr, int64_t rank, return b->getI64IntegerAttr(axis); } +// Returns a PrecisionConfig as an array attribute based on whether TF32 +// execution is enabled +static ArrayAttr GetPrecisionConfig(Builder *builder) { + mlir::mhlo::Precision precision = tsl::tensor_float_32_execution_enabled() + ? mhlo::Precision::DEFAULT + : mlir::mhlo::Precision::HIGHEST; + llvm::SmallVector attr_vec; + const int num_inputs = 2; + for (int i = 0; i < num_inputs; i++) { + attr_vec.push_back( + mlir::mhlo::PrecisionAttr::get(builder->getContext(), precision)); + } + return builder->getArrayAttr(attr_vec); +} + // If `value` is an IntegerAttr, returns the integer value for the HLO axis // corresponding to the tensorflow axis. 
In particular, the tensorflow axis can // be negative, in which case, the corresponding HLO axis is @@ -1082,6 +1098,9 @@ class ConvertConvDynamic : public OpRewritePattern { auto batch_group_count_attr = rewriter.getNamedAttr( "batch_group_count", rewriter.getI64IntegerAttr(1)); + auto precision_config_attr = rewriter.getNamedAttr( + "precision_config", GetPrecisionConfig(&rewriter)); + Value paddings_op = rewriter.create( op.getLoc(), tensorflow::GetTypeFromTFTensorShape(2 * num_spatial_dims, @@ -1105,9 +1124,9 @@ class ConvertConvDynamic : public OpRewritePattern { filter_ty.getElementType()), operands[1]); } - NamedAttribute attrs[] = {rhs_dilations_attr, window_strides_attr, + NamedAttribute attrs[] = {rhs_dilations_attr, window_strides_attr, dimension_numbers_attr, feature_group_count_attr, - batch_group_count_attr}; + batch_group_count_attr, precision_config_attr}; rewriter.replaceOpWithNewOp(op, op.getType(), operands, llvm::ArrayRef(attrs)); return success(); @@ -1246,6 +1265,9 @@ class ConvertConvOp : public OpRewritePattern { auto paddings_attr = rewriter.getNamedAttr( "padding", DenseElementsAttr::get(paddings_ty, paddings)); + auto precision_config_attr = rewriter.getNamedAttr( + "precision_config", GetPrecisionConfig(&rewriter)); + SmallVector operands(op.getOperands()); // Reshape the filter to {spatial_dims...., 1,in_channels * // channel_multiplier} @@ -1264,7 +1286,8 @@ class ConvertConvOp : public OpRewritePattern { } NamedAttribute attrs[] = {rhs_dilations_attr, window_strides_attr, dimension_numbers_attr, feature_group_count_attr, - batch_group_count_attr, paddings_attr}; + batch_group_count_attr, paddings_attr, + precision_config_attr}; rewriter.replaceOpWithNewOp(op, op.getType(), operands, llvm::ArrayRef(attrs)); return success(); @@ -3160,9 +3183,9 @@ class ConvertBatchMatMulV2Op : public OpRewritePattern { /*rhs_contracting_dimensions=*/rhs_contracting_dimensions); // TODO(silvasean): Emit shape checks for contracting dimensions. 
// (The batch dimensions are checked by the broadcasting logic) - rewriter.replaceOpWithNewOp(op, op.getType(), lhs, rhs, - dimension_numbers, - /*precision_config=*/nullptr); + rewriter.replaceOpWithNewOp( + op, op.getType(), lhs, rhs, dimension_numbers, + /*precision_config=*/GetPrecisionConfig(&rewriter)); return success(); } }; @@ -4958,7 +4981,7 @@ class ConvertConvBackpropInputOp : public OpRewritePattern { /*outputSpatialDimensions=*/spatial_dims), rewriter.getI64IntegerAttr(feature_group_count), /*batch_group_count=*/rewriter.getI64IntegerAttr(1), - /*precision_config=*/ArrayAttr()); + /*precision_config=*/GetPrecisionConfig(&rewriter)); rewriter.replaceOp(op, {result}); @@ -5165,7 +5188,7 @@ class ConvertConvBackpropFilterOp : public OpRewritePattern { /*outputSpatialDimensions=*/output_spatial_dimensions), /*feature_group_count=*/rewriter.getI64IntegerAttr(1), rewriter.getI64IntegerAttr(batch_group_count), - /*precision_config=*/ArrayAttr()); + /*precision_config=*/GetPrecisionConfig(&rewriter)); rewriter.replaceOp(op, {result}); diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_collective.cc b/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_collective.cc index 20eeb67b7d5..4f355d5255c 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_collective.cc +++ b/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_collective.cc @@ -144,6 +144,11 @@ LogicalResult ConvertAllReduce(OpBuilder& builder, int64_t channel_id, Type element_type = getElementTypeOrSelf(input.getType()); auto all_reduce = builder.create( loc, result_type, input, replica_groups, channel_handle, nullptr); + + if (all_reduce.getNumResults() != 1) { + return op->emitOpError() + << "AllReduceOp must have one result: " << *all_reduce; + } if (merge_op == "Add") { BuildReduceBody(element_type, &all_reduce.getComputation(), &builder); @@ -173,7 +178,7 @@ LogicalResult ConvertAllReduce(OpBuilder& builder, int64_t channel_id, 
GetScalarConstOfType(element_type, loc, replica_group_size, &builder); auto broadcast_dims = GetI64ElementsAttr({}, &builder); result = builder.create( - loc, all_reduce.getResult(), divisor.getResult(), broadcast_dims); + loc, all_reduce.getResult(0), divisor.getResult(), broadcast_dims); } else if (final_op != "Id") { return op->emitOpError() << "invalid final_op " << final_op << ", want one of [Id, Div]"; diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_patterns.td b/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_patterns.td index f28ea6958d3..3234e22bf6e 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_patterns.td +++ b/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_patterns.td @@ -389,11 +389,14 @@ foreach src = [TF_PreventGradientOp, TF_CheckNumericsOp] in // MatMul op patterns. //===----------------------------------------------------------------------===// +def GetPrecisionConfig: NativeCodeCall< + "GetPrecisionConfig(&$_builder)">; + def : Pat<(TF_MatMulOp $a, $b, $transpose_a, $transpose_b), (MHLO_DotOp (TF_TransposeOp $a, (TF_ConstOp (Get2DTransposePerm $transpose_a))), (TF_TransposeOp $b, (TF_ConstOp (Get2DTransposePerm $transpose_b))), - /*precision_config=*/(NullArrayAttr))>; + /*precision_config=*/(GetPrecisionConfig))>; //===----------------------------------------------------------------------===// // Lower `tf.ZerosLike` diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_with_tf2xla.cc b/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_with_tf2xla.cc index ddd3b091e23..f5c76c4fecd 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_with_tf2xla.cc +++ b/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_with_tf2xla.cc @@ -180,6 +180,7 @@ bool IsOpAllowedTf2XlaFallback(Operation* op) { TypeID::get(), TypeID::get(), TypeID::get(), + TypeID::get(), TypeID::get(), TypeID::get(), TypeID::get(), @@ -272,6 +273,7 @@ bool IsOpAllowedTf2XlaFallback(Operation* 
op) { TypeID::get(), TypeID::get(), TypeID::get(), + TypeID::get(), TypeID::get(), TypeID::get(), TypeID::get(), diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/tf2xla_rewriter.cc b/tensorflow/compiler/mlir/tf2xla/transforms/tf2xla_rewriter.cc index 4117b5ce026..c916c89fb43 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/tf2xla_rewriter.cc +++ b/tensorflow/compiler/mlir/tf2xla/transforms/tf2xla_rewriter.cc @@ -156,6 +156,13 @@ tsl::StatusOr Tf2XlaRewriter::ImportXlaComputation( return tsl::errors::InvalidArgument("Imported XLA Root is not a tuple op"); } + if (op_->getNumOperands() != + hlo_module->entry_computation()->num_parameters()) { + return tsl::errors::InvalidArgument( + "Entry computation does not have equal number of parameters to op " + "operands"); + } + ModuleOp mlir_module = op_->getParentOfType(); mlir::OpBuilder builder(op_); mlir::SymbolTable symbol_table(mlir_module); diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/tf2xla_rewriter_test.cc b/tensorflow/compiler/mlir/tf2xla/transforms/tf2xla_rewriter_test.cc index 4aeb42bd7bd..b6f1b54591b 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/tf2xla_rewriter_test.cc +++ b/tensorflow/compiler/mlir/tf2xla/transforms/tf2xla_rewriter_test.cc @@ -75,8 +75,10 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr XlaComputation GetTestXlaComputation() { XlaBuilder xla_builder("test"); - XlaOp add = xla::Add(xla::ConstantR0(&xla_builder, 1.0), - xla::ConstantR0(&xla_builder, 2.0)); + auto param = + Parameter(&xla_builder, 0, ShapeUtil::MakeScalarShape(xla::F32), "a"); + + XlaOp add = xla::Add(param, xla::ConstantR0(&xla_builder, 2.0)); std::vector tuple_values; tuple_values.push_back(add); @@ -291,7 +293,7 @@ TEST_F(Tf2XlaRewriterTest, InsertsConstantParameters) { LegalizeModule(/*use_tf2xla_hlo_importer=*/true, kModuleWithConstParam)); } -TEST_F(Tf2XlaRewriterTest, DISABLED_ImportsPrivateFunctions) { +TEST_F(Tf2XlaRewriterTest, 
ErrorsWithInvalidNumberOfParametersToArgs) { XlaBuilder builder("test_builder"); XlaComputation to_apply; { @@ -315,9 +317,9 @@ TEST_F(Tf2XlaRewriterTest, DISABLED_ImportsPrivateFunctions) { EXPECT_EQ(computation.proto().computations_size(), 2); TF_ASSERT_OK(CreateMlirModule()); - TF_ASSERT_OK_AND_ASSIGN(TupleOp root_tuple, - ImportXlaComputationIntoModule(computation)); - EXPECT_TRUE(root_tuple); + tsl::StatusOr status_or_tuple_op = + ImportXlaComputationIntoModule(computation); + EXPECT_FALSE(status_or_tuple_op.ok()); } } // namespace mhlo diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/verify_tfxla_legalization.cc b/tensorflow/compiler/mlir/tf2xla/transforms/verify_tfxla_legalization.cc index 5fe37a04160..e773f5d8b52 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/verify_tfxla_legalization.cc +++ b/tensorflow/compiler/mlir/tf2xla/transforms/verify_tfxla_legalization.cc @@ -48,7 +48,7 @@ class VerifyTFXLALegalization : public impl::VerifyTFXLALegalizationBase { public: explicit VerifyTFXLALegalization(bool legalize_chlo) { - legalize_chlo_ = legalize_chlo_; + legalize_chlo_ = legalize_chlo; } void runOnOperation() override; diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/xla_legalize_tf.cc b/tensorflow/compiler/mlir/tf2xla/transforms/xla_legalize_tf.cc index 09d5b91f05a..fe5326206a4 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/xla_legalize_tf.cc +++ b/tensorflow/compiler/mlir/tf2xla/transforms/xla_legalize_tf.cc @@ -730,6 +730,13 @@ const llvm::DenseSet &MlirPreferredOps() { // clang-format off static const llvm::DenseSet* ops = new llvm::DenseSet{ + // Ops that should always use the MLIR legalization. 
+ TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + // Ops that are legalized in the old bridge using MlirXlaOpKernel TypeID::get(), TypeID::get(), diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/xla_legalize_tf_passes.td b/tensorflow/compiler/mlir/tf2xla/transforms/xla_legalize_tf_passes.td index cfec5714798..727baf76084 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/xla_legalize_tf_passes.td +++ b/tensorflow/compiler/mlir/tf2xla/transforms/xla_legalize_tf_passes.td @@ -46,7 +46,7 @@ def LegalizeTF : Pass<"xla-legalize-tf", "ModuleOp"> { "Prioritize tf2xla fallback legalization over MLIR legalization " "patterns">, Option<"use_tf2xla_hlo_importer_", "use-tf2xla-hlo-importer", - "bool", /*default=*/"false", + "bool", /*default=*/"true", "Use the experimental HLO to MHLO importer for per-op fallback calls " " from MLIR bridge to TF2XLA." "Users should not set this flag and ideally this goes away."> diff --git a/tensorflow/compiler/mlir/tf_mlir_translate_main.cc b/tensorflow/compiler/mlir/tf_mlir_translate_main.cc index 6758aee3b77..00ae360f1e6 100644 --- a/tensorflow/compiler/mlir/tf_mlir_translate_main.cc +++ b/tensorflow/compiler/mlir/tf_mlir_translate_main.cc @@ -13,7 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ +#include #include +#include +#include #include "absl/strings/str_split.h" #include "llvm/Support/InitLLVM.h" diff --git a/tensorflow/compiler/mlir/tfr/BUILD b/tensorflow/compiler/mlir/tfr/BUILD index f9fff19986e..b1990be9b58 100644 --- a/tensorflow/compiler/mlir/tfr/BUILD +++ b/tensorflow/compiler/mlir/tfr/BUILD @@ -194,6 +194,7 @@ tf_cc_binary( ) glob_lit_tests( + name = "all_tests", data = [":test_utilities"], driver = "//tensorflow/compiler/mlir:run_lit.sh", test_file_exts = ["mlir"], @@ -328,8 +329,8 @@ tf_python_pybind_extension( deps = [ "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tfr", - "//tensorflow/python:pybind11_lib", - "//tensorflow/python:pybind11_status", + "//tensorflow/python/lib/core:pybind11_lib", + "//tensorflow/python/lib/core:pybind11_status", "@llvm-project//llvm:Support", "@llvm-project//mlir:ArithDialect", "@llvm-project//mlir:FuncDialect", @@ -387,8 +388,8 @@ tf_py_test( ":tfr_gen", "//tensorflow/compiler/mlir/python/mlir_wrapper:filecheck_wrapper", "//tensorflow/compiler/mlir/tfr/resources:test_ops", - "//tensorflow/python:array_ops", - "//tensorflow/python:math_ops", + "//tensorflow/python/ops:array_ops", + "//tensorflow/python/ops:math_ops", "//tensorflow/python/platform:client_testlib", ], ) diff --git a/tensorflow/compiler/mlir/tfr/examples/customization/BUILD b/tensorflow/compiler/mlir/tfr/examples/customization/BUILD index 748a189e25c..fe4b0ebee47 100644 --- a/tensorflow/compiler/mlir/tfr/examples/customization/BUILD +++ b/tensorflow/compiler/mlir/tfr/examples/customization/BUILD @@ -39,6 +39,6 @@ tf_py_test( deps = [ "//tensorflow:tensorflow_py", "//tensorflow/compiler/mlir/tfr:test_utils", - "//tensorflow/python:test_ops", + "//tensorflow/python/framework:test_ops", ], ) diff --git a/tensorflow/compiler/mlir/tfr/examples/mnist/BUILD b/tensorflow/compiler/mlir/tfr/examples/mnist/BUILD index 4160d864f2e..b54b5fc56fb 100644 --- 
a/tensorflow/compiler/mlir/tfr/examples/mnist/BUILD +++ b/tensorflow/compiler/mlir/tfr/examples/mnist/BUILD @@ -54,7 +54,7 @@ py_library( ":mnist_ops", ":mnist_ops_py", "//tensorflow:tensorflow_py", - "//tensorflow/python:framework", + "//tensorflow/python/framework", "@absl_py//absl/flags", ], ) @@ -80,12 +80,12 @@ distribute_py_test( xla_enable_strict_auto_jit = False, deps = [ ":mnist_train", - "//tensorflow/python:client_testlib", "//tensorflow/python:extra_py_tests_deps", - "//tensorflow/python:is_mlir_bridge_test_true", "//tensorflow/python/distribute:combinations", "//tensorflow/python/distribute:strategy_combinations", "//tensorflow/python/distribute:test_util", + "//tensorflow/python/framework:is_mlir_bridge_test_true", + "//tensorflow/python/platform:client_testlib", "@absl_py//absl/testing:parameterized", ], ) diff --git a/tensorflow/compiler/mlir/tfr/ir/tfr_ops.cc b/tensorflow/compiler/mlir/tfr/ir/tfr_ops.cc index 91a306c1fba..d30b5934691 100644 --- a/tensorflow/compiler/mlir/tfr/ir/tfr_ops.cc +++ b/tensorflow/compiler/mlir/tfr/ir/tfr_ops.cc @@ -161,6 +161,16 @@ bool TFRType::classof(Type type) { // Custom op methods //===----------------------------------------------------------------------===// +void CallOp::setCalleeFromCallable(CallInterfaceCallable callee) { + // Direct call. + if (FlatSymbolRefAttr calleeAttr = getCalleeAttr()) { + auto symRef = callee.get(); + return setCalleeAttr(cast(symRef)); + } + // Indirect call, callee Value is the first operand. 
+ return setOperand(0, callee.get()); +} + LogicalResult ConstantTensorOp::verify() { ConstantTensorOp op = *this; auto input_type = op.getArg().getType(); diff --git a/tensorflow/compiler/mlir/tfr/ir/tfr_ops.td b/tensorflow/compiler/mlir/tfr/ir/tfr_ops.td index f9ce81f680b..3746674a8ce 100644 --- a/tensorflow/compiler/mlir/tfr/ir/tfr_ops.td +++ b/tensorflow/compiler/mlir/tfr/ir/tfr_ops.td @@ -143,6 +143,8 @@ def TFR_CallOp : TFR_Op<"call", [CallOpInterface]> { // Return the callee of this operation. CallInterfaceCallable getCallableForCallee() { return getCalleeAttr(); } + // Sets the callee from the callable + void setCalleeFromCallable(CallInterfaceCallable callee); }]; let assemblyFormat = [{ diff --git a/tensorflow/compiler/mlir/tfr/ir/tfr_types.h b/tensorflow/compiler/mlir/tfr/ir/tfr_types.h index d1049e51dd9..c862f0f1b5f 100644 --- a/tensorflow/compiler/mlir/tfr/ir/tfr_types.h +++ b/tensorflow/compiler/mlir/tfr/ir/tfr_types.h @@ -16,6 +16,9 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_TFR_IR_TFR_TYPES_H_ #define TENSORFLOW_COMPILER_MLIR_TFR_IR_TFR_TYPES_H_ +#include +#include + #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/Diagnostics.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/tfr/passes/decompose.cc b/tensorflow/compiler/mlir/tfr/passes/decompose.cc index 9a76d68efd9..5d59d958d3e 100644 --- a/tensorflow/compiler/mlir/tfr/passes/decompose.cc +++ b/tensorflow/compiler/mlir/tfr/passes/decompose.cc @@ -14,8 +14,11 @@ limitations under the License. 
==============================================================================*/ #include +#include #include #include +#include +#include #include #include #include diff --git a/tensorflow/compiler/mlir/tfr/passes/passes.h b/tensorflow/compiler/mlir/tfr/passes/passes.h index 967a4c35d99..00bf11870ca 100644 --- a/tensorflow/compiler/mlir/tfr/passes/passes.h +++ b/tensorflow/compiler/mlir/tfr/passes/passes.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_TFR_PASSES_PASSES_H_ #define TENSORFLOW_COMPILER_MLIR_TFR_PASSES_PASSES_H_ +#include #include #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/tfr/passes/raise_to_tf.cc b/tensorflow/compiler/mlir/tfr/passes/raise_to_tf.cc index 2b69ba782a8..dd85565cfed 100644 --- a/tensorflow/compiler/mlir/tfr/passes/raise_to_tf.cc +++ b/tensorflow/compiler/mlir/tfr/passes/raise_to_tf.cc @@ -15,6 +15,7 @@ limitations under the License. #include #include +#include #include #include #include diff --git a/tensorflow/compiler/mlir/tfr/passes/rewrite_quantized_io.cc b/tensorflow/compiler/mlir/tfr/passes/rewrite_quantized_io.cc index c1f34402835..babfef28d33 100644 --- a/tensorflow/compiler/mlir/tfr/passes/rewrite_quantized_io.cc +++ b/tensorflow/compiler/mlir/tfr/passes/rewrite_quantized_io.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include #include "llvm/ADT/StringRef.h" diff --git a/tensorflow/compiler/mlir/tfr/python/test_utils.py b/tensorflow/compiler/mlir/tfr/python/test_utils.py index 22c61d0a5c8..09c1455eae0 100644 --- a/tensorflow/compiler/mlir/tfr/python/test_utils.py +++ b/tensorflow/compiler/mlir/tfr/python/test_utils.py @@ -13,6 +13,7 @@ # limitations under the License. 
"""Test utils for composite op definition.""" from tensorflow.python.eager import backprop +from tensorflow.python.framework import test_util from tensorflow.python.platform import test @@ -23,6 +24,8 @@ class OpsDefsTest(test.TestCase): op_kwargs=None): if op_kwargs is None: op_kwargs = kwargs + if test_util.IsMklEnabled(): + self.skipTest("Not compatible with oneDNN custom ops.") # compute with op. with backprop.GradientTape() as gt: diff --git a/tensorflow/compiler/mlir/tfr/python/tfr_wrapper.cc b/tensorflow/compiler/mlir/tfr/python/tfr_wrapper.cc index 7e580ba61e4..760ddab974c 100644 --- a/tensorflow/compiler/mlir/tfr/python/tfr_wrapper.cc +++ b/tensorflow/compiler/mlir/tfr/python/tfr_wrapper.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include + #include "llvm/Support/SourceMgr.h" #include "llvm/Support/raw_ostream.h" #include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/tfr/utils/utils.cc b/tensorflow/compiler/mlir/tfr/utils/utils.cc index 4f7a90bb972..3580b7dab7f 100644 --- a/tensorflow/compiler/mlir/tfr/utils/utils.cc +++ b/tensorflow/compiler/mlir/tfr/utils/utils.cc @@ -15,6 +15,8 @@ limitations under the License. #include "tensorflow/compiler/mlir/tfr/utils/utils.h" +#include + #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/StringRef.h" #include "llvm/ADT/StringSet.h" diff --git a/tensorflow/compiler/mlir/tfr/utils/utils.h b/tensorflow/compiler/mlir/tfr/utils/utils.h index 7e0c0208254..911015ae0be 100644 --- a/tensorflow/compiler/mlir/tfr/utils/utils.h +++ b/tensorflow/compiler/mlir/tfr/utils/utils.h @@ -16,6 +16,8 @@ limitations under the License. 
#ifndef TENSORFLOW_COMPILER_MLIR_TFR_UTILS_UTILS_H_ #define TENSORFLOW_COMPILER_MLIR_TFR_UTILS_UTILS_H_ +#include + #include "mlir/IR/Block.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/OperationSupport.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/tfrt/BUILD b/tensorflow/compiler/mlir/tfrt/BUILD index 068b7cabf22..55ca19518cb 100644 --- a/tensorflow/compiler/mlir/tfrt/BUILD +++ b/tensorflow/compiler/mlir/tfrt/BUILD @@ -22,9 +22,9 @@ package_group( packages = [ "//tensorflow/compiler/...", "//tensorflow/core/runtime_fallback/...", - "//tensorflow/core/tfrt/eager/...", "//tensorflow/core/tfrt/experimental/data/...", "//tensorflow/core/tfrt/graph_executor/...", + "//tensorflow/core/tfrt/mlrt/...", "//tensorflow/core/tfrt/saved_model/...", "//tensorflow/core/tfrt/tfrt_session/...", ] + if_google([ @@ -307,8 +307,10 @@ cc_library( "@llvm-project//llvm:Support", "@llvm-project//mlir:Parser", "@llvm-project//mlir:Pass", + "@tf_runtime//:basic_kernels_alwayslink", "@tf_runtime//:bef", "@tf_runtime//:befexecutor", + "@tf_runtime//:core_runtime_alwayslink", "@tf_runtime//:hostcontext", "@tf_runtime//:mlirtobef", "@tf_runtime//:support", @@ -425,13 +427,12 @@ cc_library( "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops", "//tensorflow/compiler/mlir/tensorflow:tensorflow_passes", "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", + "//tensorflow/compiler/mlir/tensorflow:tf_saved_model_asset_sinking_pass", "//tensorflow/compiler/mlir/tfrt/ir:tfrt_fallback_async_opdefs", "//tensorflow/compiler/mlir/tfrt/ir:tfrt_fallback_opdefs", "//tensorflow/compiler/mlir/tfrt/ir:tfrt_gpu_opdefs", "//tensorflow/core:framework", "//tensorflow/core/platform:status", - "//tensorflow/core/platform:tstring", - "//tensorflow/tsl/platform:status", "@llvm-project//llvm:Support", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", @@ -531,10 +532,9 @@ cc_library( "translate/import_model.h", ], visibility = [ - # 
copybara:uncomment "//learning/brain/experimental/tfrt/mlrt/application/tensorflow/compiler/transforms:__pkg__", # copybara:uncomment "//learning/brain/experimental/tfrt/visualization:__pkg__", "//tensorflow/compiler/mlir/tfrt/tests/saved_model:__pkg__", - "//tensorflow/core/tfrt/eager:__pkg__", + "//tensorflow/compiler/mlir/tfrt/transforms/mlrt:__pkg__", "//tensorflow/core/tfrt/graph_executor:__pkg__", "//tensorflow/core/tfrt/saved_model:__pkg__", ], @@ -587,6 +587,7 @@ cc_library( "@com_google_absl//absl/strings", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", + "@tf_runtime//:compiler_tfrt_op_interfaces", ], ) @@ -660,29 +661,35 @@ cc_library( deps = [ ":passes", ":test_cost_analysis_pass", + ":test_opkernels", ":test_tensor_array_side_effect_analysis", ":tf_jitrt_opdefs", ":tf_to_tfrt", ":tfrt_jitrt_passes", + ":transforms/gpu_passes", "//tensorflow/compiler/mlir:init_mlir", "//tensorflow/compiler/mlir:passes", "//tensorflow/compiler/mlir/lite:tensorflow_lite", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:bridge_pass_test_pipeline_registration", "//tensorflow/compiler/mlir/tensorflow:tensorflow_passes", - "//tensorflow/compiler/mlir/tfrt:transforms/gpu_passes", "//tensorflow/compiler/mlir/tfrt/ir:tfrt_fallback_async_opdefs", "//tensorflow/compiler/mlir/tfrt/ir:tfrt_fallback_opdefs", "//tensorflow/compiler/mlir/tfrt/ir:tfrt_fallback_sync_opdefs", "//tensorflow/compiler/mlir/tfrt/ir:tfrt_gpu_opdefs", + "//tensorflow/compiler/mlir/tfrt/ir/mlrt:mlrt_ops", + "//tensorflow/compiler/mlir/tfrt/ir/mlrt:tf_mlrt_ops", "//tensorflow/compiler/mlir/tfrt/jit/transforms:tf_jitrt_passes", "//tensorflow/compiler/mlir/tfrt/jit/transforms:tf_jitrt_test_passes", + "//tensorflow/compiler/mlir/tfrt/transforms/mlrt:passes", "//tensorflow/compiler/xla/mlir_hlo:gml_st", "//tensorflow/compiler/xla/mlir_hlo:gml_st_passes", "//tensorflow/core:lib", + "//tensorflow/core:tensorflow", "@llvm-project//mlir:AllPassesAndDialects", 
"@llvm-project//mlir:MlirOptLib", "@llvm-project//mlir:ShapeDialect", + "@llvm-project//mlir:Transforms", "@tf_runtime//:init_tfrt_dialects", "@tf_runtime//:print_stream_pass", "@tf_runtime//backends/jitrt:jitrt_compiler", @@ -892,3 +899,11 @@ cc_library( name = "constants", hdrs = ["constants.h"], ) + +cc_library( + name = "test_opkernels", + testonly = True, + srcs = ["test_opkernels.cc"], + deps = ["//tensorflow/core:framework"], + alwayslink = True, +) diff --git a/tensorflow/compiler/mlir/tfrt/analysis/cost_analysis.cc b/tensorflow/compiler/mlir/tfrt/analysis/cost_analysis.cc index c7d02332839..5573e7c2d46 100644 --- a/tensorflow/compiler/mlir/tfrt/analysis/cost_analysis.cc +++ b/tensorflow/compiler/mlir/tfrt/analysis/cost_analysis.cc @@ -14,12 +14,15 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/mlir/tfrt/analysis/cost_analysis.h" +#include #include +#include #include "absl/container/flat_hash_map.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tfrt/constants.h" #include "tensorflow/core/tfrt/fallback/cost_recorder.h" +#include "tfrt/compiler/opdefs/tfrt_op_interfaces.h" // from @tf_runtime namespace tensorflow { namespace tfrt_compiler { @@ -157,6 +160,12 @@ void CostAnalysis::AnalyzeBlock(mlir::Block* block) { } void CostAnalysis::EvaluateCost(mlir::Operation* op) { + if (auto cost_function = + mlir::dyn_cast(op)) { + cost_map_[op] = cost_function.cost(); + return; + } + if (!llvm::isa(op->getDialect())) { cost_map_[op] = max_arg_size_; return; @@ -178,7 +187,7 @@ void CostAnalysis::EvaluateCost(mlir::Operation* op) { const auto op_key_attr = op->getAttrOfType(kOpKeyAttrName); if (op_key_attr) { - cost_map_[op] = cost_recorder_->GetCostNanosecond(op_key_attr.getInt()); + cost_map_[op] = cost_recorder_->GetCost(op_key_attr.getInt()); return; } } diff --git a/tensorflow/compiler/mlir/tfrt/analysis/cost_analysis.h 
b/tensorflow/compiler/mlir/tfrt/analysis/cost_analysis.h index fa01b38dd64..809846619d3 100644 --- a/tensorflow/compiler/mlir/tfrt/analysis/cost_analysis.h +++ b/tensorflow/compiler/mlir/tfrt/analysis/cost_analysis.h @@ -16,6 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_TFRT_ANALYSIS_COST_ANALYSIS_H_ #define TENSORFLOW_COMPILER_MLIR_TFRT_ANALYSIS_COST_ANALYSIS_H_ +#include + #include "absl/strings/string_view.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/tfrt/benchmarks/benchmark.cc b/tensorflow/compiler/mlir/tfrt/benchmarks/benchmark.cc index 58ff8929c2a..cf3acc48906 100644 --- a/tensorflow/compiler/mlir/tfrt/benchmarks/benchmark.cc +++ b/tensorflow/compiler/mlir/tfrt/benchmarks/benchmark.cc @@ -15,7 +15,9 @@ limitations under the License. #include "tensorflow/compiler/mlir/tfrt/benchmarks/benchmark.h" +#include #include +#include #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h" #include "mlir/ExecutionEngine/CRunnerUtils.h" diff --git a/tensorflow/compiler/mlir/tfrt/benchmarks/benchmark.h b/tensorflow/compiler/mlir/tfrt/benchmarks/benchmark.h index c315d2e9917..41f3b93b121 100644 --- a/tensorflow/compiler/mlir/tfrt/benchmarks/benchmark.h +++ b/tensorflow/compiler/mlir/tfrt/benchmarks/benchmark.h @@ -16,6 +16,8 @@ limitations under the License. 
#ifndef TENSORFLOW_COMPILER_MLIR_TFRT_BENCHMARKS_BENCHMARK_H_ #define TENSORFLOW_COMPILER_MLIR_TFRT_BENCHMARKS_BENCHMARK_H_ +#include +#include #define EIGEN_USE_THREADS #include diff --git a/tensorflow/compiler/mlir/tfrt/benchmarks/benchmark_mlir_function.cc b/tensorflow/compiler/mlir/tfrt/benchmarks/benchmark_mlir_function.cc index a390d365303..2fc595caee0 100644 --- a/tensorflow/compiler/mlir/tfrt/benchmarks/benchmark_mlir_function.cc +++ b/tensorflow/compiler/mlir/tfrt/benchmarks/benchmark_mlir_function.cc @@ -55,7 +55,7 @@ static llvm::SmallVector GetInputTensors( for (const InputTensorSpec& spec : input_specs) { TensorShape shape; - CHECK(TensorShapeUtils::MakeShape(spec.dims, &shape).ok()); + CHECK_OK(TensorShapeUtils::MakeShape(spec.dims, &shape)); input_tensors.emplace_back(spec.dtype, shape); // Initialize tensors with random data. diff --git a/tensorflow/compiler/mlir/tfrt/benchmarks/cwise_op_unary_benchmark.h b/tensorflow/compiler/mlir/tfrt/benchmarks/cwise_op_unary_benchmark.h index d3f7ade5a32..5d8972ec0e2 100644 --- a/tensorflow/compiler/mlir/tfrt/benchmarks/cwise_op_unary_benchmark.h +++ b/tensorflow/compiler/mlir/tfrt/benchmarks/cwise_op_unary_benchmark.h @@ -16,6 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_TFRT_BENCHMARKS_CWISE_OP_UNARY_BENCHMARK_H_ #define TENSORFLOW_COMPILER_MLIR_TFRT_BENCHMARKS_CWISE_OP_UNARY_BENCHMARK_H_ +#include +#include #include #include diff --git a/tensorflow/compiler/mlir/tfrt/benchmarks/reduction_benchmark.cc b/tensorflow/compiler/mlir/tfrt/benchmarks/reduction_benchmark.cc index 5e8f9115360..c578b82c17a 100644 --- a/tensorflow/compiler/mlir/tfrt/benchmarks/reduction_benchmark.cc +++ b/tensorflow/compiler/mlir/tfrt/benchmarks/reduction_benchmark.cc @@ -14,6 +14,9 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/compiler/mlir/tfrt/benchmarks/reduction_benchmark.h" + +#include + #include "tensorflow/compiler/mlir/tfrt/benchmarks/benchmark.h" namespace tensorflow { diff --git a/tensorflow/compiler/mlir/tfrt/benchmarks/reduction_benchmark.h b/tensorflow/compiler/mlir/tfrt/benchmarks/reduction_benchmark.h index 22977dfd702..da619834397 100644 --- a/tensorflow/compiler/mlir/tfrt/benchmarks/reduction_benchmark.h +++ b/tensorflow/compiler/mlir/tfrt/benchmarks/reduction_benchmark.h @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include #include "tensorflow/compiler/mlir/tfrt/benchmarks/benchmark.h" #include "tensorflow/compiler/mlir/tfrt/benchmarks/benchmark_mlir_function.h" diff --git a/tensorflow/compiler/mlir/tfrt/constants.h b/tensorflow/compiler/mlir/tfrt/constants.h index dfbb9ba4898..ed6e773c52a 100644 --- a/tensorflow/compiler/mlir/tfrt/constants.h +++ b/tensorflow/compiler/mlir/tfrt/constants.h @@ -23,12 +23,6 @@ namespace tfrt_compiler { inline constexpr char kOpKeyAttrName[] = "__op_key"; } // namespace tfrt_compiler - -namespace mlrt_compiler { - -inline constexpr char kArgPassByValue[] = "mlrt.__pass_by_value"; - -} // namespace mlrt_compiler } // namespace tensorflow #endif // TENSORFLOW_COMPILER_MLIR_TFRT_CONSTANTS_H_ diff --git a/tensorflow/compiler/mlir/tfrt/ir/BUILD b/tensorflow/compiler/mlir/tfrt/ir/BUILD index f63f0c7ff07..464bef0fe8d 100644 --- a/tensorflow/compiler/mlir/tfrt/ir/BUILD +++ b/tensorflow/compiler/mlir/tfrt/ir/BUILD @@ -62,6 +62,7 @@ cc_library( visibility = [ "//tensorflow/compiler/mlir/tfrt:__subpackages__", # copybara:uncomment "//tensorflow/core/runtime_fallback:internal", + "//tensorflow/core/tfrt/mlrt/application/tensorflow/tests:__subpackages__", ], deps = [ ":tfrt_fallback_common", diff --git a/tensorflow/compiler/mlir/tfrt/ir/mlrt/BUILD b/tensorflow/compiler/mlir/tfrt/ir/mlrt/BUILD new file mode 100644 index 
00000000000..313b7ee1197 --- /dev/null +++ b/tensorflow/compiler/mlir/tfrt/ir/mlrt/BUILD @@ -0,0 +1,190 @@ +load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + licenses = ["notice"], +) + +td_library( + name = "mlrt_td_files", + srcs = [ + "mlrt_dialect.td", + "mlrt_ops.td", + ], + includes = ["."], + visibility = [ + "//tensorflow/core/tfrt/mlrt:__subpackages__", + ], + deps = [ + "@llvm-project//mlir:OpBaseTdFiles", + "@llvm-project//mlir:SideEffectInterfacesTdFiles", + ], +) + +gentbl_cc_library( + name = "mlrt_ops_inc_gen", + tbl_outs = [ + ( + ["-gen-op-decls"], + "mlrt_ops.h.inc", + ), + ( + ["-gen-op-defs"], + "mlrt_ops.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "mlrt_ops.td", + deps = [":mlrt_td_files"], +) + +cc_library( + name = "mlrt_ops", + srcs = [ + "mlrt_dialect.cc", + "mlrt_ops.cc", + ], + hdrs = [ + "mlrt_dialect.h", + "mlrt_ops.h", + ], + visibility = [ + "//tensorflow/compiler/mlir/tfrt:__subpackages__", + "//tensorflow/core/tfrt/mlrt:__subpackages__", + ], + deps = [ + ":mlrt_ops_inc_gen", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + ], +) + +td_library( + name = "tf_mlrt_td_files", + srcs = [ + "tf_mlrt_dialect.td", + "tf_mlrt_ops.td", + "tf_ops.td", + ], + includes = ["."], + visibility = [ + # copybara:uncomment "//learning/brain/experimental/tfrt:__subpackages__", + # copybara:uncomment "//learning/infra/mira/distributed:__subpackages__", + ], + deps = [ + ":mlrt_td_files", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops_td_files", + "@llvm-project//mlir:OpBaseTdFiles", + "@llvm-project//mlir:SideEffectInterfacesTdFiles", + "@tf_runtime//:compiler_td_files", + ], +) + +td_library( + name = "tf_mlrt_tpu_td_files", + srcs = [ + "tf_mlrt_tpu_ops.td", + ], + includes = ["."], + visibility = [ + "//tensorflow/core/tfrt/mlrt:__subpackages__", + ], + deps = [ + 
":mlrt_td_files", + ":tf_mlrt_td_files", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops_td_files", + "@llvm-project//mlir:OpBaseTdFiles", + "@llvm-project//mlir:SideEffectInterfacesTdFiles", + ], +) + +gentbl_cc_library( + name = "tf_mlrt_ops_inc_gen", + tbl_outs = [ + ( + ["-gen-op-decls"], + "tf_mlrt_ops.h.inc", + ), + ( + ["-gen-op-defs"], + "tf_mlrt_ops.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "tf_mlrt_ops.td", + deps = [":tf_mlrt_td_files"], +) + +gentbl_cc_library( + name = "tf_mlrt_tpu_ops_inc_gen", + tbl_outs = [ + ( + ["-gen-op-decls"], + "tf_mlrt_tpu_ops.h.inc", + ), + ( + ["-gen-op-defs"], + "tf_mlrt_tpu_ops.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "tf_mlrt_tpu_ops.td", + deps = [":tf_mlrt_tpu_td_files"], +) + +gentbl_cc_library( + name = "tf_ops_inc_gen", + tbl_outs = [ + ( + ["-gen-op-decls"], + "tf_ops.h.inc", + ), + ( + ["-gen-op-defs"], + "tf_ops.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "tf_ops.td", + deps = [":tf_mlrt_td_files"], +) + +cc_library( + name = "tf_mlrt_ops", + srcs = ["tf_mlrt_ops.cc"], + hdrs = ["tf_mlrt_ops.h"], + visibility = [ + # copybara:uncomment "//learning/brain/experimental/tfrt/mlrt/application/tensorflow/tests:__subpackages__", + # copybara:uncomment "//learning/infra/mira/distributed:__subpackages__", + "//tensorflow/compiler/mlir/tfrt:__subpackages__", + "//tensorflow/core/tfrt/mlrt:__subpackages__", + ], + deps = [ + ":mlrt_ops", + ":tf_mlrt_ops_inc_gen", + ":tf_ops_inc_gen", + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_side_effects", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Transforms", + "@tf_runtime//:compiler_tfrt_op_interfaces", + "@tf_runtime//:compiler_tfrt_traits", + ], +) + +cc_library( + name = "tf_mlrt_tpu_ops", + srcs = ["tf_mlrt_tpu_ops.cc"], + hdrs = ["tf_mlrt_tpu_ops.h"], + visibility = [ + 
"//tensorflow/compiler/mlir/tfrt/transforms/mlrt:__subpackages__", + "//tensorflow/core/tfrt/mlrt:__subpackages__", + ], + deps = [ + ":mlrt_ops", + ":tf_mlrt_ops", + ":tf_mlrt_tpu_ops_inc_gen", + "//tensorflow/compiler/mlir/tensorflow", + "@llvm-project//mlir:IR", + ], +) diff --git a/tensorflow/compiler/mlir/tfrt/ir/mlrt/mlrt_dialect.cc b/tensorflow/compiler/mlir/tfrt/ir/mlrt/mlrt_dialect.cc new file mode 100644 index 00000000000..50d4cb12142 --- /dev/null +++ b/tensorflow/compiler/mlir/tfrt/ir/mlrt/mlrt_dialect.cc @@ -0,0 +1,95 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/compiler/mlir/tfrt/ir/mlrt/mlrt_dialect.h" + +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/DialectImplementation.h" // from @llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/Region.h" // from @llvm-project +#include "mlir/IR/TypeUtilities.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Transforms/InliningUtils.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tfrt/ir/mlrt/mlrt_ops.h" + +namespace mlrt { +namespace compiler { + +namespace { + +struct MlrtInlinerInterface : public mlir::DialectInlinerInterface { + using DialectInlinerInterface::DialectInlinerInterface; + bool isLegalToInline(mlir::Operation *op, mlir::Region *dest, + bool would_be_cloned, + mlir::IRMapping &mapping) const final { + // All mlrt dialect ops can be inlined. + return true; + } +}; + +} // namespace + +MlrtDialect::MlrtDialect(mlir::MLIRContext *context) + : mlir::Dialect(/*name=*/"mlrt", context, + mlir::TypeID::get()) { + addTypes(); + addTypes(); + addTypes(); + addInterfaces(); + + addOperations< +#define GET_OP_LIST +#include "tensorflow/compiler/mlir/tfrt/ir/mlrt/mlrt_ops.cpp.inc" + >(); +} + +// Parse a type registered to this dialect. 
+mlir::Type MlrtDialect::parseType(mlir::DialectAsmParser &parser) const { + llvm::StringRef keyword; + if (parser.parseKeyword(&keyword)) return mlir::Type(); + + if (keyword == "future") return FutureType::get(getContext()); + if (keyword == "promise") return PromiseType::get(getContext()); + if (keyword == "async_handle") return AsyncHandleType::get(getContext()); + + parser.emitError(parser.getNameLoc(), "unknown type: ") << keyword; + return mlir::Type(); +} + +// Print a type registered to this dialect. +void MlrtDialect::printType(mlir::Type type, + mlir::DialectAsmPrinter &os) const { + if (type.isa()) { + os << "future"; + return; + } + + if (type.isa()) { + os << "promise"; + return; + } + + if (type.isa()) { + os << "async_handle"; + return; + } + + llvm_unreachable("unexpected mlrt type kind"); +} + +} // namespace compiler +} // namespace mlrt diff --git a/tensorflow/compiler/mlir/tfrt/ir/mlrt/mlrt_dialect.h b/tensorflow/compiler/mlir/tfrt/ir/mlrt/mlrt_dialect.h new file mode 100644 index 00000000000..0fb568b44dc --- /dev/null +++ b/tensorflow/compiler/mlir/tfrt/ir/mlrt/mlrt_dialect.h @@ -0,0 +1,59 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_TFRT_IR_MLRT_MLRT_DIALECT_H_ +#define TENSORFLOW_COMPILER_MLIR_TFRT_IR_MLRT_MLRT_DIALECT_H_ + +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/Dialect.h" // from @llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project +#include "mlir/IR/OpImplementation.h" // from @llvm-project + +namespace mlrt { +namespace compiler { + +class MlrtDialect : public mlir::Dialect { + public: + explicit MlrtDialect(mlir::MLIRContext *context); + static llvm::StringRef getDialectNamespace() { return "mlrt"; } + + mlir::Type parseType(mlir::DialectAsmParser &parser) const override; + void printType(mlir::Type type, mlir::DialectAsmPrinter &os) const override; +}; + +// The MLIR type represents a C++ mlrt::Future. +class FutureType + : public mlir::Type::TypeBase { + public: + using Base::Base; +}; + +// The MLIR type represents a C++ mlrt::Promise. +class PromiseType + : public mlir::Type::TypeBase { + public: + using Base::Base; +}; + +// The MLIR type represents a C++ mlrt::AsyncHandle. +class AsyncHandleType : public mlir::Type::TypeBase { + public: + using Base::Base; +}; + +} // namespace compiler +} // namespace mlrt + +#endif // TENSORFLOW_COMPILER_MLIR_TFRT_IR_MLRT_MLRT_DIALECT_H_ diff --git a/tensorflow/compiler/mlir/tfrt/ir/mlrt/mlrt_dialect.td b/tensorflow/compiler/mlir/tfrt/ir/mlrt/mlrt_dialect.td new file mode 100644 index 00000000000..b260dcb402f --- /dev/null +++ b/tensorflow/compiler/mlir/tfrt/ir/mlrt/mlrt_dialect.td @@ -0,0 +1,55 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifdef MLRT_DIALECT +#else +#define MLRT_DIALECT + +include "mlir/IR/OpBase.td" + +def Mlrt_Dialect : Dialect { + let name = "mlrt"; + + let description = [{ + The MLRT Dialect. + }]; + + let cppNamespace = "::mlrt::compiler"; +} + +def MlrtFutureType : DialectType()">, "!mlrt.future type">, + BuildableType<"$_builder.getType<::mlrt::compiler::FutureType>()"> { + let description = [{ + `!mlrt.future type` represents a C++ mlrt::Future. + }]; +} + +def MlrtPromiseType : DialectType()">, "!mlrt.promise type">, + BuildableType<"$_builder.getType<::mlrt::compiler::PromiseType>()"> { + let description = [{ + `!mlrt.promise type` represents a C++ mlrt::Promise. + }]; +} + +def MlrtAsyncHandleType : DialectType()">, "!mlrt.async_handle type">, + BuildableType<"$_builder.getType<::mlrt::compiler::AsyncHandleType>()"> { + let description = [{ + `!mlrt.async_handle type` represents a C++ mlrt::AsyncHandle. + }]; +} + +#endif diff --git a/tensorflow/compiler/mlir/tfrt/ir/mlrt/mlrt_ops.cc b/tensorflow/compiler/mlir/tfrt/ir/mlrt/mlrt_ops.cc new file mode 100644 index 00000000000..878b2504de2 --- /dev/null +++ b/tensorflow/compiler/mlir/tfrt/ir/mlrt/mlrt_ops.cc @@ -0,0 +1,28 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/compiler/mlir/tfrt/ir/mlrt/mlrt_ops.h" + +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/DialectImplementation.h" // from @llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project +#include "mlir/IR/TypeUtilities.h" // from @llvm-project + +//===----------------------------------------------------------------------===// +// TableGen'd op method definitions +//===----------------------------------------------------------------------===// + +#define GET_OP_CLASSES +#include "tensorflow/compiler/mlir/tfrt/ir/mlrt/mlrt_ops.cpp.inc" diff --git a/tensorflow/compiler/mlir/tfrt/ir/mlrt/mlrt_ops.h b/tensorflow/compiler/mlir/tfrt/ir/mlrt/mlrt_ops.h new file mode 100644 index 00000000000..e3922c6e0ce --- /dev/null +++ b/tensorflow/compiler/mlir/tfrt/ir/mlrt/mlrt_ops.h @@ -0,0 +1,27 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_TFRT_IR_MLRT_MLRT_OPS_H_ +#define TENSORFLOW_COMPILER_MLIR_TFRT_IR_MLRT_MLRT_OPS_H_ + +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/Dialect.h" // from @llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project +#include "mlir/IR/OpImplementation.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tfrt/ir/mlrt/mlrt_dialect.h" + +#define GET_OP_CLASSES +#include "tensorflow/compiler/mlir/tfrt/ir/mlrt/mlrt_ops.h.inc" + +#endif // TENSORFLOW_COMPILER_MLIR_TFRT_IR_MLRT_MLRT_OPS_H_ diff --git a/tensorflow/compiler/mlir/tfrt/ir/mlrt/mlrt_ops.td b/tensorflow/compiler/mlir/tfrt/ir/mlrt/mlrt_ops.td new file mode 100644 index 00000000000..24c34fb4a41 --- /dev/null +++ b/tensorflow/compiler/mlir/tfrt/ir/mlrt/mlrt_ops.td @@ -0,0 +1,240 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifdef MLRT_OPS +#else +#define MLRT_OPS + +include "mlir/Interfaces/SideEffectInterfaces.td" +include "tensorflow/compiler/mlir/tfrt/ir/mlrt/mlrt_dialect.td" + +class Mlrt_Op traits = []> : + Op { +} + +def CondOp: Mlrt_Op<"cond", []> { + let summary = "mlrt.cond op"; + + let description = [{ + Execute $a_true_fn with $args if $cond is true; otherwise, %b_false_fn is + executed. 
+ }]; + + let arguments = (ins + I1:$cond, + Variadic:$args, + SymbolRefAttr:$a_true_fn, + SymbolRefAttr:$b_false_fn + ); + + let results = (outs + Variadic:$results + ); + + let assemblyFormat = [{ + $cond $a_true_fn $b_false_fn `(` $args `)` attr-dict `:` `(` type($args) `)` `->` `(` type($results) `)` + }]; +} + +def AsyncOp: Mlrt_Op<"async", []> { + let summary = "Launches a function asynchronously."; + + let description = [{ + Launch a function asynchronously. + + $args: a list of arguments to be passed. + $callee: The function to be launched. Its return op must not have operands. + + $handle: This op returns a handle object that manages the context of the async execution. + }]; + + let arguments = (ins + Variadic:$args, + SymbolRefAttr:$callee + ); + + let results = (outs + MlrtAsyncHandleType:$handle + ); + + let assemblyFormat = "`(` $args `)` attr-dict `:` functional-type($args, $handle)"; +} + +def AwaitHandleOp: Mlrt_Op<"await_handle", []> { + let summary = "Awaits an async execution "; + + let description = [{ + Awaits an async execution. + + $handle: The handle returned by mlrt.async op. + }]; + + let arguments = (ins + MlrtAsyncHandleType:$handle + ); + + let assemblyFormat = "operands attr-dict"; +} + +def AwaitAllHandleOp: Mlrt_Op<"await_all_handle", []> { + let summary = "Awaits multiple async executions"; + + let description = [{ + Awaits multiple async execution. + + $handles: A list of handles returned by mlrt.async ops. + }]; + + let arguments = (ins + Variadic:$handles + ); + + let assemblyFormat = "operands attr-dict `:` type($handles)"; +} + +def AwaitControlOp: Mlrt_Op<"await_control", []> { + let summary = "Await a signal from a future"; + + let description = [{ + Await a signal, instead of a value, from a future. + + $future: A value of !mlrt.future type. 
+ }]; + + let arguments = (ins + MlrtFutureType:$future + ); + + let assemblyFormat = "operands attr-dict"; +} + +def AwaitAllControlOp: Mlrt_Op<"await_all_control", []> { + let summary = "Awaits multiple signals"; + + let description = [{ + Awaits multiple signals + + $futures: A list of !mlrt.futures + }]; + + let arguments = (ins + Variadic:$futures + ); + + let assemblyFormat = "operands attr-dict `:` type($futures)"; +} + +def PromiseControlOp: Mlrt_Op<"promise_control", []> { + let summary = "Set a control promise"; + + let description = [{ + Set a control promise. + + $promise: A value of !mlrt.promise type. + }]; + + let arguments = (ins + MlrtPromiseType:$promise + ); + + let assemblyFormat = "operands attr-dict"; +} + +def CaseOp : Mlrt_Op<"case"> { + let summary = "An n-way switch statement which calls a single branch function."; + let description = [{ + An n-way switch statement, implementing the following: + ``` + switch (branch_index) { + case 0: + outputs = branches[0](inputs); + break; + case 1: + outputs = branches[1](inputs); + break; + ... + case [[nbranches-1]]: + default: + outputs = branches[nbranches-1](inputs); + break; + } + ``` + Example: %res = mlrt.case %branch_idx [@branch0, @branch1] (%arg0, %arg1) : (i32, i32) -> (i32) + }]; + + let arguments = (ins I32:$branch_index, + ConfinedAttr]>:$branches, + Variadic:$branch_operands); + + let results = (outs Variadic:$branch_outputs); + let assemblyFormat = [{ + $branch_index $branches `(` $branch_operands `)` attr-dict `:` `(` type($branch_operands) `)` `->` `(` type($branch_outputs) `)` + }]; +} + +def AllocateControlFuturesOp: Mlrt_Op<"allocate_control_futures", [AttrSizedResultSegments]> { + let summary = "Allocate futures and corresponding promises"; + + let description = [{ + Allocate futures and corresponding promises. + + $num: The number of futures to be allocated. + + $promises: There are $num promises, and promises[i] shares the state with futures[i]. 
+ $futures: There are $num futures, and futures[i] shares the state with promises[i]. + }]; + + let arguments = (ins + I32Attr:$num + ); + + let results = (outs + Variadic:$promises, + Variadic:$futures + ); +} + +def WhileOp : Mlrt_Op<"while", []> { + let summary = "while operation"; + let description = [{ + cond: The boolean to control whether the first iteration should be + executed. + operands: The arguments to the first iteration. + results: The results of the last iteration. The number and types of results + excluding the last one are the same as the number and types of operands. The + last element of results is an I1 value that is false. + body_fn: The body function that takes the arguments and returns the results + that includes an I1 value to indicate whether next iteration should be executed. + + The pseudo code: + + while(cond) { + results = body_fn(operands) + cond = results#1 + } + return results + + }]; + + let arguments = (ins I1:$cond, + Variadic:$arguments, + FlatSymbolRefAttr:$body_fn); + + let results = (outs Variadic); + + let assemblyFormat = [{ + $cond $body_fn `(` $arguments `)` attr-dict `:` `(` type($arguments) `)` `->` `(` type(results) `)` + }]; +} +#endif diff --git a/tensorflow/compiler/mlir/tfrt/ir/mlrt/tf_mlrt_dialect.td b/tensorflow/compiler/mlir/tfrt/ir/mlrt/tf_mlrt_dialect.td new file mode 100644 index 00000000000..9cf997e0c3e --- /dev/null +++ b/tensorflow/compiler/mlir/tfrt/ir/mlrt/tf_mlrt_dialect.td @@ -0,0 +1,56 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifdef TF_MLRT_DIALECT +#else +#define TF_MLRT_DIALECT + +include "mlir/Interfaces/SideEffectInterfaces.td" +include "tensorflow/compiler/mlir/tfrt/ir/mlrt/mlrt_dialect.td" + +// TODO(chky,rohitju): Unify this dialect with tfrt_fallback_sync dialect after +// vrooml is using the new interpreter. +def TensorflowMlrt_Dialect : Dialect { + let name = "tf_mlrt"; + + let description = [{ + The TF MLRT Dialect. + }]; + + let cppNamespace = "::tensorflow::tf_mlrt"; +} + +class TensorflowMlrt_Op traits = []> : + Op { +} + +// This corresponds to tensorflow::Tensor. +def TFTensorType : DialectType()">, "!tf_mlrt.tensor type">, + BuildableType<"$_builder.getType<::tensorflow::tf_mlrt::TFTensorType>()"> { + let description = [{ + `!tf_mlrt.tensor type` represents a tensorflow::Tensor. + }]; +} + +// This corresponds to tensorflow::Device* . +def TFDeviceType : DialectType()">, "!tf_mlrt.device type">, + BuildableType<"$_builder.getType<::tensorflow::tf_mlrt::TFDeviceType>()"> { + let description = [{ + `!tf_mlrt.device type` represents a tensorflow::device. + }]; +} + +#endif diff --git a/tensorflow/compiler/mlir/tfrt/ir/mlrt/tf_mlrt_ops.cc b/tensorflow/compiler/mlir/tfrt/ir/mlrt/tf_mlrt_ops.cc new file mode 100644 index 00000000000..fc4cb6a93a2 --- /dev/null +++ b/tensorflow/compiler/mlir/tfrt/ir/mlrt/tf_mlrt_ops.cc @@ -0,0 +1,95 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/compiler/mlir/tfrt/ir/mlrt/tf_mlrt_ops.h" + +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/DialectImplementation.h" // from @llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project +#include "mlir/IR/TypeUtilities.h" // from @llvm-project +#include "mlir/Transforms/InliningUtils.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tfrt/ir/mlrt/mlrt_dialect.h" + +namespace tensorflow { +namespace tf_mlrt { + +namespace { + +struct TensorflowMlrtInlinerInterface : public mlir::DialectInlinerInterface { + using DialectInlinerInterface::DialectInlinerInterface; + bool isLegalToInline(mlir::Operation *op, mlir::Region *dest, + bool would_be_cloned, + mlir::IRMapping &mapping) const final { + // All tf_mlrt dialect ops can be inlined. + return true; + } + // Note that CallOp and ReturnOp are handled by func; so need to implement + // handleTerminator. 
+}; + +} // namespace + +TensorflowMlrtDialect::TensorflowMlrtDialect(mlir::MLIRContext *context) + : mlir::Dialect(/*name=*/"tf_mlrt", context, + mlir::TypeID::get()) { + addTypes(); + addInterfaces(); + + addOperations< +#define GET_OP_LIST +#include "tensorflow/compiler/mlir/tfrt/ir/mlrt/tf_mlrt_ops.cpp.inc" + >(); + + addOperations< +#define GET_OP_LIST +#include "tensorflow/compiler/mlir/tfrt/ir/mlrt/tf_ops.cpp.inc" + >(); +} + +// Parse a type registered to this dialect. +mlir::Type TensorflowMlrtDialect::parseType( + mlir::DialectAsmParser &parser) const { + llvm::StringRef keyword; + if (parser.parseKeyword(&keyword)) return mlir::Type(); + + if (keyword == "tensor") return TFTensorType::get(getContext()); + + parser.emitError(parser.getNameLoc(), "unknown type: ") << keyword; + return mlir::Type(); +} + +// Print a type registered to this dialect. +void TensorflowMlrtDialect::printType(mlir::Type type, + mlir::DialectAsmPrinter &os) const { + if (type.isa()) { + os << "tensor"; + return; + } + + llvm_unreachable("unexpected tf_mlrt type kind"); +} + +} // namespace tf_mlrt +} // namespace tensorflow + +//===----------------------------------------------------------------------===// +// TableGen'd op method definitions +//===----------------------------------------------------------------------===// + +#define GET_OP_CLASSES +#include "tensorflow/compiler/mlir/tfrt/ir/mlrt/tf_mlrt_ops.cpp.inc" +#define GET_OP_CLASSES +#include "tensorflow/compiler/mlir/tfrt/ir/mlrt/tf_ops.cpp.inc" diff --git a/tensorflow/compiler/mlir/tfrt/ir/mlrt/tf_mlrt_ops.h b/tensorflow/compiler/mlir/tfrt/ir/mlrt/tf_mlrt_ops.h new file mode 100644 index 00000000000..da91450aa19 --- /dev/null +++ b/tensorflow/compiler/mlir/tfrt/ir/mlrt/tf_mlrt_ops.h @@ -0,0 +1,61 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_TFRT_IR_MLRT_TF_MLRT_OPS_H_ +#define TENSORFLOW_COMPILER_MLIR_TFRT_IR_MLRT_TF_MLRT_OPS_H_ + +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/Dialect.h" // from @llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project +#include "mlir/IR/OpImplementation.h" // from @llvm-project +#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_side_effects.h" +#include "tfrt/compiler/opdefs/tfrt_op_interfaces.h" // from @tf_runtime +#include "tfrt/compiler/opdefs/tfrt_traits.h" // from @tf_runtime + +namespace tensorflow { +namespace tf_mlrt { + +class TensorflowMlrtDialect : public mlir::Dialect { + public: + explicit TensorflowMlrtDialect(mlir::MLIRContext *context); + static llvm::StringRef getDialectNamespace() { return "tf_mlrt"; } + + mlir::Type parseType(mlir::DialectAsmParser &parser) const override; + void printType(mlir::Type type, mlir::DialectAsmPrinter &os) const override; +}; + +// The MLIR type represents a tensorflow::Tensor. 
+class TFTensorType + : public mlir::Type::TypeBase { + public: + using Base::Base; +}; + +// The MLIR type represents a tensorflow::Device* +class TFDeviceType + : public mlir::Type::TypeBase { + public: + using Base::Base; +}; + +} // namespace tf_mlrt +} // namespace tensorflow + +#define GET_OP_CLASSES +#include "tensorflow/compiler/mlir/tfrt/ir/mlrt/tf_mlrt_ops.h.inc" +#define GET_OP_CLASSES +#include "tensorflow/compiler/mlir/tfrt/ir/mlrt/tf_ops.h.inc" + +#endif // TENSORFLOW_COMPILER_MLIR_TFRT_IR_MLRT_TF_MLRT_OPS_H_ diff --git a/tensorflow/compiler/mlir/tfrt/ir/mlrt/tf_mlrt_ops.td b/tensorflow/compiler/mlir/tfrt/ir/mlrt/tf_mlrt_ops.td new file mode 100644 index 00000000000..bbbec10187a --- /dev/null +++ b/tensorflow/compiler/mlir/tfrt/ir/mlrt/tf_mlrt_ops.td @@ -0,0 +1,378 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifdef TF_MLRT_OPS +#else +#define TF_MLRT_OPS + +include "mlir/Interfaces/SideEffectInterfaces.td" +include "tensorflow/compiler/mlir/tfrt/ir/mlrt/tf_mlrt_dialect.td" +include "tensorflow/compiler/mlir/tfrt/ir/mlrt/mlrt_dialect.td" + +def CreateOp: TensorflowMlrt_Op<"createop", []> { + let summary = "The Fallback CreateOp"; + + let description = [{ + The CreateOp creates the tensorflow::OpKernel in the fallback context. 
+ }]; + + let arguments = (ins + StrAttr:$node_def, + I32Attr:$op_key + ); + + let assemblyFormat = "attr-dict"; +} + +def ExecuteOp : TensorflowMlrt_Op<"executeop", []> { + let summary = "The Fallback ExecuteOp"; + let description = [{ + The ExecuteOp executes an operation on the specified device. + }]; + + let arguments = (ins + Variadic:$args, + StrAttr:$node_def, + I32Attr:$op_key + ); + + let results = (outs + Variadic:$results + ); + + let assemblyFormat = "`(` $args `)` attr-dict `:` functional-type($args, $results)"; +} + +def ExecuteOpWithDevice: TensorflowMlrt_Op<"executeop.device", []> { + let summary = "The Fallback ExecuteOp with custom device"; + let description = [{ + The ExecuteOp executes an operation on the specified device using a custom device. + }]; + + let arguments = (ins + TFDeviceType:$device, + Variadic:$args, + StrAttr:$node_def, + I32Attr:$op_key + ); + + let results = (outs + Variadic:$results + ); + + let assemblyFormat = "`(` $device`)` `(` $args `)` attr-dict `:` functional-type($args, $results)"; +} + +def AsyncExecuteOp : TensorflowMlrt_Op<"async_executeop", []> { + let summary = "The Fallback ExecuteOp for tensorflow::AsyncOpKernel"; + let description = [{ + The ExecuteOp executes an operation on the specified device asynchronously. + }]; + + let arguments = (ins + Variadic:$args, + StrAttr:$node_def, + I32Attr:$op_key + ); + + let results = (outs + Variadic:$results + ); + + let assemblyFormat = "`(` $args `)` attr-dict `:` functional-type($args, $results)"; +} + +def AsyncExecuteOpWithDevice : TensorflowMlrt_Op<"async_executeop.device", []> { + let summary = "The Fallback ExecuteOp for tensorflow::AsyncOpKernel"; + let description = [{ + The ExecuteOp executes an operation on the specified device asynchronously. 
+ }]; + + let arguments = (ins + TFDeviceType:$device, + Variadic:$args, + StrAttr:$node_def, + I32Attr:$op_key + ); + + let results = (outs + Variadic:$results + ); + + let assemblyFormat = "`(` $device`)` `(` $args `)` attr-dict `:` functional-type($args, $results)"; +} + +def SetResourceOp : TensorflowMlrt_Op<"set_resource", []> { + let summary = "Set a tensor in resource array"; + + let description = [{ + Set a tensor in resource array. + + arg: the tensor to be set in the resource array. + index: the index in the resource array + }]; + + let arguments = (ins + TFTensorType:$arg, + I64Attr:$index + ); + + let results = (outs); + + let assemblyFormat = "operands attr-dict"; +} + +def GetResourceOp : TensorflowMlrt_Op<"get_resource", []> { + let summary = "get a tensor in resource array"; + + let description = [{ + Get a tensor in resource array. + + indices: the indices in the resource array. + results: the tensor values for the corresponding indices. + }]; + + let arguments = (ins + I64ArrayAttr:$indices + ); + + let results = (outs + Variadic:$results + ); + + let assemblyFormat = "attr-dict `:` type($results)"; +} + +def AwaitOp: TensorflowMlrt_Op<"await", [Pure]> { + let summary = "Await a tensor from a !mlrt.future"; + + let description = [{ + Await a tensor from a !mlrt.future. + + $future: A value of type !mlrt.future. The underlying value must be a tensorflow tensor. + + $result: a tensorflow tensor. + }]; + + let arguments = (ins + MlrtFutureType:$future + ); + + let results = (outs + TFTensorType:$result + ); + + let assemblyFormat = "operands attr-dict"; +} + +def AwaitAllOp: TensorflowMlrt_Op<"await_all", [Pure]> { + let summary = "Await tensors from a list of !mlrt.future"; + + let description = [{ + Await tensors from a list of !mlrt.future. + + $futures: A list of !mlrt.future. The underlying value must be tensorflow tensors. + + $results: A list of tensorflow tensors. 
+ }]; + + let arguments = (ins + Variadic:$futures + ); + + let results = (outs + Variadic:$results + ); + + let assemblyFormat = "$futures attr-dict `:` type($results)"; +} + +def PromiseOp: TensorflowMlrt_Op<"promise", []> { + let summary = "Set a tensor in a promise"; + + let description = [{ + Set a tensor in a promise. + + $promise: A value of type !mlrt.promise. The underlying value must be a tensorflow tensor. + $tensor: A tensorflow tensor. + }]; + + let arguments = (ins + MlrtPromiseType:$promise, + TFTensorType:$tensor + ); + + let assemblyFormat = "operands attr-dict"; +} + +def PromiseFutureOp: TensorflowMlrt_Op<"promise_future", []> { + let summary = "Set a tensor future in a promise"; + + let description = [{ + Set a tensor future in a promise. + + $promise: A value of type !mlrt.promise. The underlying value must be a tensorflow tensor. + $future: A value of type !mlrt.future. Must represent a tensorflow tensor. + }]; + + let arguments = (ins + MlrtPromiseType:$promise, + MlrtFutureType:$tensor + ); + + let assemblyFormat = "operands attr-dict"; +} + + +def AllocateFuturesOp: TensorflowMlrt_Op<"allocate_futures", [AttrSizedResultSegments]> { + let summary = "Allocate futures and promises for tensorflow tensors"; + + let description = [{ + Allocate futures and promises for tensorflow tensors. + + $num_futures: The number of futures to be allocated. + + $promises: There are $num_futures promises. promises[i] shares the state with futures[i]. + $futures: There are $num_futures futures. futures[i] shares the state with promises[i]. + }]; + + let arguments = (ins + I32Attr:$num_futures + ); + + let results = (outs + Variadic:$promises, + Variadic:$futures + ); +} + +def TensorToIntOp : TensorflowMlrt_Op<"tensor_to_int32", [Pure]> { + let summary = "Cast a Tensor to int32."; + let description = [{ + Cast a Tensor to int32. 
+ + Example: + %one = tf_mlrt.tensor_to_int32 %src_tensor + }]; + + let arguments = (ins TFTensorType:$src); + let results = (outs I32:$result); + let assemblyFormat = "operands attr-dict"; +} + +def PredicateOp : TensorflowMlrt_Op<"predicate", [Pure]> { + let summary = "Converts a fallback tensor to a bool"; + + let description = [{ + Note: this kernel is used for CPU tensors. + + Converts a fallback tensor to a bool with the following rules: + + - For 0D tensors, truthiness is determined by comparing against a "zero" + value. For numerical types it is the obvious zero. For strings it is the + empty string. + + - For >0D tensors, truthiness is determined by looking at the number of + elements. If it has zero elements, then the result is false. Otherwise the + result is true. + + input: a fallback tensor representing the condition. + device: the name of the tensorflow device that is associated with the + input fallback tensor. + + output: the converted bool. + }]; + + let arguments = (ins + TFTensorType:$input + ); + + let results = (outs + I1:$output + ); + + let assemblyFormat = "$input attr-dict"; +} + +def BatchFunctionOp : TensorflowMlrt_Op<"batch_function", [Pure]> { + let summary = "Fallback ExecuteOp specialized for tf.BatchFunction."; + + let description = [{ + This kernel executes a variant tf.BatchFunction kernel that supports having + the `f` attribute as a bytecode function. + + Example: + %res = tf_mlrt.batch_function(%input, %captured_input) { + device = "/device:CPU:0", + f = @batch_function, + node_def = "..." + } : (!tf_mlrt.tensor,!tf_mlrt.tensor) -> (!tf_mlrt.tensor) + + Note that the trailing number indicates the number of results. 
+ }]; + + let arguments = (ins + Variadic:$args, + StrAttr:$device, + SymbolRefAttr:$f, + StrAttr:$node_def + ); + + let results = (outs + Variadic:$results + ); + + let assemblyFormat = "`(` $args `)` attr-dict `:` functional-type($args, $results)"; +} + +def CancelOp: TensorflowMlrt_Op<"cancel", []> { + let summary = "Handle cancellation request."; + + let description = [{ + This kernel will early terminate the program upon cancellation request (e.g. time out). + }]; +} + +def MapFnOp : TensorflowMlrt_Op<"map_fn", [AttrSizedOperandSegments, Pure]> { + let summary = "The Parallel Map for tf_mlrt dialect"; + let description = [{ + The Pmap executes body function in parallel for all ranges up to $max_iterations. + + The pseudo code: + for(int i = 0; i < $max_iterations; i++) { + body_fn(MlrtFuture($tensor_list_or_flow_in[i]), + MlrtPromise($tensor_list_or_flow_in[i+1]), + i, i, $invariant_args); + } + + return $tensor_list_or_flow_in[$max_iterations] + }]; + + let arguments = (ins + TFTensorType:$max_iterations, + Variadic:$tensor_list_or_flow_in, + Variadic:$invariant_args, + FlatSymbolRefAttr:$body_fn, + I32Attr:$num_tensor_list_or_flow_in + ); + + let results = (outs + Variadic:$result + ); + + let assemblyFormat = "`(`$max_iterations`,` $tensor_list_or_flow_in`,` $invariant_args `)` attr-dict `:` functional-type(operands, results)"; +} + + +#endif diff --git a/tensorflow/compiler/mlir/tfrt/ir/mlrt/tf_mlrt_tpu_ops.cc b/tensorflow/compiler/mlir/tfrt/ir/mlrt/tf_mlrt_tpu_ops.cc new file mode 100644 index 00000000000..94e5d52bde1 --- /dev/null +++ b/tensorflow/compiler/mlir/tfrt/ir/mlrt/tf_mlrt_tpu_ops.cc @@ -0,0 +1,46 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/compiler/mlir/tfrt/ir/mlrt/tf_mlrt_tpu_ops.h" + +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/DialectImplementation.h" // from @llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project +#include "mlir/IR/TypeUtilities.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tfrt/ir/mlrt/mlrt_dialect.h" +#include "tensorflow/compiler/mlir/tfrt/ir/mlrt/tf_mlrt_ops.h" + +namespace tensorflow { +namespace tf_mlrt_tpu { + +TensorflowMlrtTpuDialect::TensorflowMlrtTpuDialect(mlir::MLIRContext *context) + : mlir::Dialect(/*name=*/"tf_mlrt_tpu", context, + mlir::TypeID::get()) { + addOperations< +#define GET_OP_LIST +#include "tensorflow/compiler/mlir/tfrt/ir/mlrt/tf_mlrt_tpu_ops.cpp.inc" + >(); +} + +} // namespace tf_mlrt_tpu +} // namespace tensorflow + +//===----------------------------------------------------------------------===// +// TableGen'd op method definitions +//===----------------------------------------------------------------------===// + +#define GET_OP_CLASSES +#include "tensorflow/compiler/mlir/tfrt/ir/mlrt/tf_mlrt_tpu_ops.cpp.inc" diff --git a/tensorflow/compiler/mlir/tfrt/ir/mlrt/tf_mlrt_tpu_ops.h b/tensorflow/compiler/mlir/tfrt/ir/mlrt/tf_mlrt_tpu_ops.h new file mode 100644 index 00000000000..a428488da86 --- /dev/null +++ b/tensorflow/compiler/mlir/tfrt/ir/mlrt/tf_mlrt_tpu_ops.h @@ -0,0 +1,39 @@ +/* 
Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_TFRT_IR_MLRT_TF_MLRT_TPU_OPS_H_ +#define TENSORFLOW_COMPILER_MLIR_TFRT_IR_MLRT_TF_MLRT_TPU_OPS_H_ + +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/Dialect.h" // from @llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project +#include "mlir/IR/OpImplementation.h" // from @llvm-project +#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project + +namespace tensorflow { +namespace tf_mlrt_tpu { + +class TensorflowMlrtTpuDialect : public mlir::Dialect { + public: + explicit TensorflowMlrtTpuDialect(mlir::MLIRContext *context); + static llvm::StringRef getDialectNamespace() { return "tf_mlrt_tpu"; } +}; + +} // namespace tf_mlrt_tpu +} // namespace tensorflow + +#define GET_OP_CLASSES +#include "tensorflow/compiler/mlir/tfrt/ir/mlrt/tf_mlrt_tpu_ops.h.inc" + +#endif // TENSORFLOW_COMPILER_MLIR_TFRT_IR_MLRT_TF_MLRT_TPU_OPS_H_ diff --git a/tensorflow/compiler/mlir/tfrt/ir/mlrt/tf_mlrt_tpu_ops.td b/tensorflow/compiler/mlir/tfrt/ir/mlrt/tf_mlrt_tpu_ops.td new file mode 100644 index 00000000000..a207b83c7e5 --- /dev/null +++ b/tensorflow/compiler/mlir/tfrt/ir/mlrt/tf_mlrt_tpu_ops.td @@ -0,0 +1,82 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifdef TF_MLRT_TPU_OPS +#else +#define TF_MLRT_TPU_OPS + +include "mlir/Interfaces/SideEffectInterfaces.td" +include "tensorflow/compiler/mlir/tfrt/ir/mlrt/tf_mlrt_dialect.td" +include "tensorflow/compiler/mlir/tfrt/ir/mlrt/mlrt_dialect.td" + +def TensorflowMlrtTpu_Dialect : Dialect { + let name = "tf_mlrt_tpu"; + + let description = [{ + The TF MLRT TPU Dialect. + }]; + + let cppNamespace = "::tensorflow::tf_mlrt_tpu"; +} + +class TensorflowMlrtTpu_Op traits = []> : + Op { +} + +def GetTpuHostDeviceOp : TensorflowMlrtTpu_Op<"get_tpu_host_device", [Pure]> { + let summary = "get the tpu host allocator that implements tensorflow::Device"; + + let results = (outs + TFDeviceType:$device + ); + + let assemblyFormat = "attr-dict"; +} + +def CompileAndExecuteOp : TensorflowMlrtTpu_Op<"compile_and_execute"> { + let summary = "tpu compile and execute operation"; + let description = [{ + tf_mlrt_tpu.compile_and_execute compiles a mlir tpu program and executes the compiled tpu program. + + $mlir_module is a serialized MLIR module with a `main` function that contains target computation. + $metadata is a serialized TPUCompileMetadataProto describing the shapes and types of the inputs to the computation, as well as a mapping onto the TPU pod topology. + $constant_operand_indices are the indices of the inputs that are constant to the TPU program (e.g. 
weights in inference), the rest of the inputs are input tensors. + constant_operand_indices is sorted in ascending order. + $operands_with_static_shape are indices of operands that are tagged with a maximum static shape. + $producer_name is a string describing the name of the framework that added support for running this portion of the model on TPUs. + + Example: + %rendezvous_key_base, %result = tf_mlrt_tpu.compile_and_execute (%operands) constant_operand_indices = [1, 3] metadata = "metadata..." mlir_module = "mlir_module..." + }]; + let arguments = (ins + Variadic:$operands_and_static_shapes, + DenseI32ArrayAttr:$constant_operand_indices, + StrAttr:$metadata, + StrAttr:$mlir_module, + UI32Attr:$num_operands, + DenseI32ArrayAttr:$operands_with_static_shape, + StrAttr:$producer_name + ); + + let results = (outs + TFTensorType:$rendezvous_key_base, + Variadic:$results + ); + + let assemblyFormat = [{ + `(` $operands_and_static_shapes `)` attr-dict `:` functional-type($operands_and_static_shapes, results) + }]; +} + +#endif diff --git a/tensorflow/compiler/mlir/tfrt/ir/mlrt/tf_ops.td b/tensorflow/compiler/mlir/tfrt/ir/mlrt/tf_ops.td new file mode 100644 index 00000000000..7268588749d --- /dev/null +++ b/tensorflow/compiler/mlir/tfrt/ir/mlrt/tf_ops.td @@ -0,0 +1,131 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifdef MLRT_TF_OPS +#else +#define MLRT_TF_OPS + +include "tensorflow/compiler/mlir/tfrt/ir/mlrt/tf_mlrt_dialect.td" +include "tensorflow/compiler/mlir/tfrt/ir/mlrt/mlrt_dialect.td" +include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td" +include "third_party/tf_runtime/include/tfrt/compiler/opdefs/tfrt_op_interfaces.td" +include "third_party/tf_runtime/include/tfrt/compiler/opdefs/tfrt_traits.td" + +// tf_mlrt.tf_await returns a tensorflow Tensor. It is a fake op that is only +// used during parallelization and has no runtime implementation. +def TFAwaitOp: TensorflowMlrt_Op<"tf_await", [Pure, TFRT_CostFunctionInterface, TFRT_FixedCost<1>]> { + let summary = "Await a tensor from a !mlrt.future"; + + let description = [{ + Await a tensor from a !mlrt.future. + + $future: A value of type !mlrt.future. The underlying value must be a tensorflow tensor. + + $result: a tensorflow tensor. + }]; + + let arguments = (ins + MlrtFutureType:$future + ); + + let results = (outs + TF_Tensor:$result + ); +} + +// tf_mlrt.tf_promise takes a tensorflow Tensor. It is a fake op that is only +// used during parallelization and has no runtime implementation. +def TFPromiseOp: TensorflowMlrt_Op<"tf_promise", [TF_MustExecute, TFRT_CostFunctionInterface, TFRT_FixedCost<1>]> { + let summary = "Set a tensor in a promise"; + + let description = [{ + Set a tensor in a promise. + + $promise: A value of type !mlrt.promise. The underlying value will always be a tensorflow tensor. + $tensor: A tensorflow tensor. + }]; + + let arguments = (ins + MlrtPromiseType:$promise, + TF_Tensor:$tensor + ); +} + +def TFMapFnOp : TensorflowMlrt_Op<"tf_map_fn", [AttrSizedOperandSegments, Pure]> { + let summary = "The Parallel Map for tf_mlrt dialect"; + let description = [{ + The Pmap executes body function in parallel for all ranges up to $max_iterations. 
+ + The pseudo code: + for(int i = 0; i < $max_iterations; i++) { + body_fn(MlrtFuture($tensor_list_or_flow_in[i]), + MlrtPromise($tensor_list_or_flow_in[i+1]), + i, i, $invariant_args); + } + + return $tensor_list_or_flow_in[$max_iterations] + }]; + + let arguments = (ins + TF_Tensor:$max_iterations, + Variadic:$tensor_list_or_flow_in, + Variadic:$invariant_args, + FlatSymbolRefAttr:$body_fn, + I32Attr:$num_tensor_list_or_flow_in + ); + + let results = (outs + Variadic:$result + ); + + let assemblyFormat = "`(`$max_iterations`,` $tensor_list_or_flow_in`,` $invariant_args `)` attr-dict `:` functional-type(operands, results)"; +} + +def TFTPUCompileAndExecuteOp : TensorflowMlrt_Op<"tf_tpu_compile_and_execute", [TF_MustExecute]> { + let summary = "tpu compile and execute operation"; + let description = [{ + tf_mlrt_tpu.compile_and_execute compiles a mlir tpu program and executes the compiled tpu program. + + $mlir_module is a serialized MLIR module with a `main` function that contains target computation. + $metadata is a serialized TPUCompileMetadataProto describing the shapes and types of the inputs to the computation, as well as a mapping onto the TPU pod topology. + $constant_operand_indices are the indices of the inputs that are constant to the TPU program (e.g. weights in inference), the rest of the inputs are input tensors. + constant_operand_indices is sorted in ascending order. + $operands_with_static_shape are indices of operands that are tagged with a maximum static shape. + $producer_name is a string describing the name of the framework that added support for running this portion of the model on TPUs. + + Example: + %rendezvous_key_base, %result = tf_mlrt_tpu.compile_and_execute (%operands) constant_operand_indices = [1, 3] metadata = "metadata..." mlir_module = "mlir_module..." 
+ }]; + let arguments = (ins + Variadic:$operands_and_static_shapes, + DenseI32ArrayAttr:$constant_operand_indices, + StrAttr:$metadata, + StrAttr:$mlir_module, + UI32Attr:$num_operands, + DenseI32ArrayAttr:$operands_with_static_shape, + StrAttr:$producer_name + ); + + let results = (outs + TF_Tensor:$rendezvous_key_base, + Variadic:$results + ); + + let assemblyFormat = [{ + `(` $operands_and_static_shapes `)` attr-dict `:` functional-type($operands_and_static_shapes, results) + }]; +} + + +#endif diff --git a/tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback_async.cc b/tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback_async.cc index 9643c041cf6..19d29e506b3 100644 --- a/tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback_async.cc +++ b/tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback_async.cc @@ -14,6 +14,8 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback_async.h" +#include + #include "llvm/ADT/STLExtras.h" #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback_common.cc b/tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback_common.cc index 2dc4adfa084..28af77dd5a7 100644 --- a/tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback_common.cc +++ b/tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback_common.cc @@ -14,6 +14,8 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback_common.h" +#include + #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback_common.h b/tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback_common.h index 93e75309206..e78d247c038 100644 --- a/tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback_common.h +++ b/tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback_common.h @@ -15,6 +15,8 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_TFRT_IR_TFRT_FALLBACK_COMMON_H_ #define TENSORFLOW_COMPILER_MLIR_TFRT_IR_TFRT_FALLBACK_COMMON_H_ +#include + #include "llvm/ADT/STLExtras.h" #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback_sync.td b/tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback_sync.td index daf76268bc2..bba8a021921 100644 --- a/tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback_sync.td +++ b/tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback_sync.td @@ -240,4 +240,47 @@ def ConvertFallbackTensorToDhtOp : FallbackSync_Op<"convert_fallback_tensor_to_d let assemblyFormat = "operands attr-dict `:` type($dht)"; } +// TODO(rohitju): This is Ads specific, need to find an appropriate home for it. +def SetSparseMatrixResourceOp : FallbackSync_Op<"set_sparse_matrix_resource", [CoreRT_TypedAttributeTrait]> { + let summary = "Set a Sparse matrix in resource array"; + + let description = [{ + Set a sparse matrix in resource array. + + arg: the matrix to be set in the resource array. 
+ index: the index in the resource array + }]; + + let arguments = (ins + TFTensorType:$arg, + I64Attr:$index + ); + + let results = (outs); + + let assemblyFormat = "operands attr-dict"; +} + +def GetSparseMatrixResourceOp : FallbackSync_Op<"get_sparse_matrix_resource", + [CoreRT_TypedAttributeTrait]> { + let summary = "get a sparse matrix from resource array"; + + let description = [{ + Get a sparse matrix from resource array. + + indices: the indices in the resource array. + results: the tensor values for the corresponding indices. + }]; + + let arguments = (ins + I64ArrayAttr:$indices + ); + + let results = (outs + Variadic:$results + ); + + let assemblyFormat = "attr-dict `:` type($results)"; +} + #endif diff --git a/tensorflow/compiler/mlir/tfrt/jit/python_binding/tf_jitrt_executor.cc b/tensorflow/compiler/mlir/tfrt/jit/python_binding/tf_jitrt_executor.cc index 18634cdc3a1..f0d5e08ece0 100644 --- a/tensorflow/compiler/mlir/tfrt/jit/python_binding/tf_jitrt_executor.cc +++ b/tensorflow/compiler/mlir/tfrt/jit/python_binding/tf_jitrt_executor.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include #include +#include #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h" #include "mlir/ExecutionEngine/CRunnerUtils.h" diff --git a/tensorflow/compiler/mlir/tfrt/jit/python_binding/tfrt_fallback.cc b/tensorflow/compiler/mlir/tfrt/jit/python_binding/tfrt_fallback.cc index 4f4b9439619..c36cd3f498a 100644 --- a/tensorflow/compiler/mlir/tfrt/jit/python_binding/tfrt_fallback.cc +++ b/tensorflow/compiler/mlir/tfrt/jit/python_binding/tfrt_fallback.cc @@ -16,6 +16,7 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/tfrt/jit/python_binding/tfrt_fallback.h" #include +#include #include #include "llvm/ADT/SmallVector.h" diff --git a/tensorflow/compiler/mlir/tfrt/jit/tf_jitrt.cc b/tensorflow/compiler/mlir/tfrt/jit/tf_jitrt.cc index cc649528ad5..58f1501b54a 100644 --- a/tensorflow/compiler/mlir/tfrt/jit/tf_jitrt.cc +++ b/tensorflow/compiler/mlir/tfrt/jit/tf_jitrt.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/tfrt/jit/tf_jitrt.h" +#include #include #include "absl/time/time.h" diff --git a/tensorflow/compiler/mlir/tfrt/jit/tf_jitrt.h b/tensorflow/compiler/mlir/tfrt/jit/tf_jitrt.h index 5a328b7f265..92378ee7ef3 100644 --- a/tensorflow/compiler/mlir/tfrt/jit/tf_jitrt.h +++ b/tensorflow/compiler/mlir/tfrt/jit/tf_jitrt.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_TFRT_JIT_TF_JITRT_H_ #define TENSORFLOW_COMPILER_MLIR_TFRT_JIT_TF_JITRT_H_ +#include #include #include diff --git a/tensorflow/compiler/mlir/tfrt/jit/tf_jitrt_kernels.cc b/tensorflow/compiler/mlir/tfrt/jit/tf_jitrt_kernels.cc index 6fe3091fed3..ec1d4f199eb 100644 --- a/tensorflow/compiler/mlir/tfrt/jit/tf_jitrt_kernels.cc +++ b/tensorflow/compiler/mlir/tfrt/jit/tf_jitrt_kernels.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #define EIGEN_USE_THREADS #include @@ -279,7 +280,7 @@ static std::string AsTensorContent(const MemrefDesc& desc) { } // Gets the session name from the fallback request state. 
-static const std::string GetSessionName(RequestContext* req_ctx) { +static std::string GetSessionName(RequestContext* req_ctx) { auto* fallback = req_ctx->GetDataIfExists(); if (!fallback) return ""; diff --git a/tensorflow/compiler/mlir/tfrt/python_tests/BUILD b/tensorflow/compiler/mlir/tfrt/python_tests/BUILD index 15574068e61..4f7288d4c9c 100644 --- a/tensorflow/compiler/mlir/tfrt/python_tests/BUILD +++ b/tensorflow/compiler/mlir/tfrt/python_tests/BUILD @@ -15,7 +15,7 @@ py_strict_test( ], # TODO(b/201803253): TFRT pybindings not in OSS. deps = [ "//tensorflow/compiler/mlir/tfrt/jit/python_binding:tf_jitrt", - "//tensorflow/python:client_testlib", + "//tensorflow/python/platform:client_testlib", "//third_party/py/numpy", ], ) @@ -27,7 +27,7 @@ py_strict_test( tags = ["no_pip"], # TODO(b/201803253): TFRT pybindings not in OSS. deps = [ "//tensorflow/compiler/mlir/tfrt/jit/python_binding:tf_jitrt", - "//tensorflow/python:client_testlib", + "//tensorflow/python/platform:client_testlib", "//third_party/py/numpy", ], ) @@ -42,7 +42,7 @@ py_strict_test( ], deps = [ "//tensorflow/compiler/mlir/tfrt/jit/python_binding:tf_jitrt", - "//tensorflow/python:client_testlib", + "//tensorflow/python/platform:client_testlib", "//third_party/py/numpy", ], ) @@ -57,7 +57,7 @@ py_strict_test( ], deps = [ "//tensorflow/compiler/mlir/tfrt/jit/python_binding:tf_jitrt", - "//tensorflow/python:client_testlib", + "//tensorflow/python/platform:client_testlib", "//third_party/py/numpy", ], ) @@ -72,7 +72,7 @@ py_strict_test( ], # TODO(b/201803253): TFRT pybindings not in OSS. deps = [ "//tensorflow/compiler/mlir/tfrt/jit/python_binding:tf_jitrt", - "//tensorflow/python:client_testlib", + "//tensorflow/python/platform:client_testlib", "//third_party/py/numpy", ], ) @@ -87,7 +87,7 @@ py_strict_test( ], # TODO(b/201803253): TFRT pybindings not in OSS. 
deps = [ "//tensorflow/compiler/mlir/tfrt/jit/python_binding:tf_jitrt", - "//tensorflow/python:client_testlib", + "//tensorflow/python/platform:client_testlib", "//third_party/py/numpy", ], ) @@ -102,7 +102,7 @@ py_strict_test( ], # TODO(b/201803253): TFRT pybindings not in OSS. deps = [ "//tensorflow/compiler/mlir/tfrt/jit/python_binding:tf_jitrt", - "//tensorflow/python:client_testlib", + "//tensorflow/python/platform:client_testlib", "//third_party/py/numpy", ], ) @@ -114,7 +114,7 @@ py_strict_test( tags = ["no_pip"], # TODO(b/201803253): TFRT pybindings not in OSS. deps = [ "//tensorflow/compiler/mlir/tfrt/jit/python_binding:tf_jitrt", - "//tensorflow/python:client_testlib", + "//tensorflow/python/platform:client_testlib", "//third_party/py/numpy", ], ) @@ -129,7 +129,7 @@ py_strict_test( ], # TODO(b/201803253): TFRT pybindings not in OSS. deps = [ "//tensorflow/compiler/mlir/tfrt/jit/python_binding:tf_jitrt", - "//tensorflow/python:client_testlib", + "//tensorflow/python/platform:client_testlib", "//third_party/py/numpy", ], ) @@ -141,7 +141,7 @@ py_strict_test( tags = ["no_pip"], # TODO(b/201803253): TFRT pybindings not in OSS. deps = [ "//tensorflow/compiler/mlir/tfrt/jit/python_binding:tf_jitrt", - "//tensorflow/python:client_testlib", + "//tensorflow/python/platform:client_testlib", "//third_party/py/numpy", ], ) @@ -154,7 +154,7 @@ py_strict_test( deps = [ "//tensorflow:tensorflow_py", "//tensorflow/compiler/mlir/tfrt/jit/python_binding:tf_jitrt", - "//tensorflow/python:client_testlib", + "//tensorflow/python/platform:client_testlib", "//third_party/py/numpy", "@absl_py//absl/flags", "@absl_py//absl/testing:parameterized", @@ -168,10 +168,11 @@ py_strict_test( tags = [ "no_oss", "no_pip", + "not_run:arm", ], # TODO(b/201803253): TFRT pybindings not in OSS. 
deps = [ "//tensorflow/compiler/mlir/tfrt/jit/python_binding:tf_jitrt", - "//tensorflow/python:client_testlib", + "//tensorflow/python/platform:client_testlib", "//third_party/py/numpy", ], ) @@ -186,7 +187,7 @@ py_strict_test( ], # TODO(b/201803253): TFRT pybindings not in OSS. deps = [ "//tensorflow/compiler/mlir/tfrt/jit/python_binding:tf_jitrt", - "//tensorflow/python:client_testlib", + "//tensorflow/python/platform:client_testlib", "//third_party/py/numpy", ], ) @@ -201,7 +202,7 @@ py_strict_test( ], # TODO(b/201803253): TFRT pybindings not in OSS. deps = [ "//tensorflow/compiler/mlir/tfrt/jit/python_binding:tf_jitrt", - "//tensorflow/python:client_testlib", + "//tensorflow/python/platform:client_testlib", "//third_party/py/numpy", ], ) @@ -216,7 +217,7 @@ py_strict_test( ], # TODO(b/201803253): TFRT pybindings not in OSS. deps = [ "//tensorflow/compiler/mlir/tfrt/jit/python_binding:tf_jitrt", - "//tensorflow/python:client_testlib", + "//tensorflow/python/platform:client_testlib", "//third_party/py/numpy", ], ) @@ -231,7 +232,7 @@ py_strict_test( ], # TODO(b/201803253): TFRT pybindings not in OSS. deps = [ "//tensorflow/compiler/mlir/tfrt/jit/python_binding:tf_jitrt", - "//tensorflow/python:client_testlib", + "//tensorflow/python/platform:client_testlib", "//third_party/py/numpy", ], ) @@ -243,7 +244,7 @@ py_strict_test( tags = ["no_pip"], # TODO(b/201803253): TFRT pybindings not in OSS. deps = [ "//tensorflow/compiler/mlir/tfrt/jit/python_binding:tf_jitrt", - "//tensorflow/python:client_testlib", + "//tensorflow/python/platform:client_testlib", "//third_party/py/numpy", ], ) @@ -258,7 +259,7 @@ py_strict_test( ], # TODO(b/201803253): TFRT pybindings not in OSS. 
deps = [ "//tensorflow/compiler/mlir/tfrt/jit/python_binding:tf_jitrt", - "//tensorflow/python:client_testlib", + "//tensorflow/python/platform:client_testlib", "//third_party/py/numpy", ], ) @@ -271,10 +272,11 @@ py_strict_test( tags = [ "no_oss", "no_pip", # TODO(b/201803253): TFRT pybindings not in OSS. + "not_run:arm", ], deps = [ "//tensorflow/compiler/mlir/tfrt/jit/python_binding:tf_jitrt", - "//tensorflow/python:client_testlib", + "//tensorflow/python/platform:client_testlib", "//third_party/py/numpy", ], ) @@ -289,7 +291,7 @@ py_strict_test( ], # TODO(b/201803253): TFRT pybindings not in OSS. deps = [ "//tensorflow/compiler/mlir/tfrt/jit/python_binding:tf_jitrt", - "//tensorflow/python:client_testlib", + "//tensorflow/python/platform:client_testlib", "//third_party/py/numpy", ], ) @@ -301,10 +303,11 @@ py_strict_test( tags = [ "no_oss", "no_pip", + "not_run:arm", ], # TODO(b/201803253): TFRT pybindings not in OSS. deps = [ "//tensorflow/compiler/mlir/tfrt/jit/python_binding:tf_jitrt", - "//tensorflow/python:client_testlib", + "//tensorflow/python/platform:client_testlib", "//third_party/py/numpy", ], ) @@ -316,7 +319,7 @@ py_strict_test( tags = ["no_pip"], # TODO(b/201803253): TFRT pybindings not in OSS. deps = [ "//tensorflow/compiler/mlir/tfrt/jit/python_binding:tf_jitrt", - "//tensorflow/python:client_testlib", + "//tensorflow/python/platform:client_testlib", "//third_party/py/numpy", ], ) @@ -328,7 +331,7 @@ py_strict_test( tags = ["no_pip"], # TODO(b/201803253): TFRT pybindings not in OSS. 
deps = [ "//tensorflow/compiler/mlir/tfrt/jit/python_binding:tf_jitrt", - "//tensorflow/python:client_testlib", + "//tensorflow/python/platform:client_testlib", "//third_party/py/numpy", ], ) diff --git a/tensorflow/compiler/mlir/tfrt/python_tests/regression_tests/build_defs.bzl b/tensorflow/compiler/mlir/tfrt/python_tests/regression_tests/build_defs.bzl index fb183e02362..9e64cfb74f6 100644 --- a/tensorflow/compiler/mlir/tfrt/python_tests/regression_tests/build_defs.bzl +++ b/tensorflow/compiler/mlir/tfrt/python_tests/regression_tests/build_defs.bzl @@ -31,7 +31,6 @@ def _run_regression_test(name, compare_with_tensorflow, vectorize, data): "//third_party/py/numpy", "//tensorflow/compiler/mlir/tfrt/jit/python_binding:tf_jitrt", "//tensorflow/compiler/mlir/tfrt/jit/python_binding:tfrt_fallback", - "//tensorflow/python:client_testlib", "//tensorflow/python/platform:tf_logging", "//tensorflow/python/platform:client_testlib", "//tensorflow/python/platform:resource_loader", diff --git a/tensorflow/compiler/mlir/tfrt/saved_model/saved_model.cc b/tensorflow/compiler/mlir/tfrt/saved_model/saved_model.cc index 47572390666..214dd90d009 100644 --- a/tensorflow/compiler/mlir/tfrt/saved_model/saved_model.cc +++ b/tensorflow/compiler/mlir/tfrt/saved_model/saved_model.cc @@ -15,6 +15,8 @@ limitations under the License. #include "tensorflow/compiler/mlir/tfrt/saved_model/saved_model.h" +#include + #include "absl/strings/str_split.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/tfrt/saved_model/saved_model.h b/tensorflow/compiler/mlir/tfrt/saved_model/saved_model.h index fab391c753b..94b7f73fd73 100644 --- a/tensorflow/compiler/mlir/tfrt/saved_model/saved_model.h +++ b/tensorflow/compiler/mlir/tfrt/saved_model/saved_model.h @@ -18,6 +18,7 @@ limitations under the License. 
#include #include +#include #include #include "absl/container/flat_hash_map.h" diff --git a/tensorflow/compiler/mlir/tfrt/test_opkernels.cc b/tensorflow/compiler/mlir/tfrt/test_opkernels.cc new file mode 100644 index 00000000000..096a2626b75 --- /dev/null +++ b/tensorflow/compiler/mlir/tfrt/test_opkernels.cc @@ -0,0 +1,49 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" + +namespace tensorflow { +namespace tf_mlrt { + +REGISTER_OP("TestAsyncIdentity") + .Input("in: T") + .Output("out: T") + .Attr( + "T: {bfloat16, half, float, double, uint8, int8, int16, uint32, int32, " + "int64, complex64, complex128}") + .SetShapeFn(::tensorflow::shape_inference::UnchangedShape); + +class TestAsyncIdentityKernel : public AsyncOpKernel { + public: + explicit TestAsyncIdentityKernel(OpKernelConstruction* context) + : AsyncOpKernel(context) {} + + void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override { + const Tensor& in = ctx->input(0); + ctx->set_output(0, in); + done(); + } + + private: + TF_DISALLOW_COPY_AND_ASSIGN(TestAsyncIdentityKernel); +}; + +REGISTER_KERNEL_BUILDER(Name("TestAsyncIdentity").Device(DEVICE_CPU), + TestAsyncIdentityKernel); + +} // namespace tf_mlrt +} // namespace tensorflow 
diff --git a/tensorflow/compiler/mlir/tfrt/tests/BUILD b/tensorflow/compiler/mlir/tfrt/tests/BUILD index f7df07f4708..3bff3d02e5f 100644 --- a/tensorflow/compiler/mlir/tfrt/tests/BUILD +++ b/tensorflow/compiler/mlir/tfrt/tests/BUILD @@ -7,6 +7,7 @@ package( ) glob_lit_tests( + name = "all_tests", data = [":test_utilities"], driver = "//tensorflow/compiler/mlir:run_lit.sh", features = if_oss(["--path=org_tensorflow/tensorflow/compiler/mlir/tfrt"]), diff --git a/tensorflow/compiler/mlir/tfrt/tests/analysis/BUILD b/tensorflow/compiler/mlir/tfrt/tests/analysis/BUILD index 91f9e57a9b2..c9b64b7b4fb 100644 --- a/tensorflow/compiler/mlir/tfrt/tests/analysis/BUILD +++ b/tensorflow/compiler/mlir/tfrt/tests/analysis/BUILD @@ -7,6 +7,7 @@ package( ) glob_lit_tests( + name = "all_tests", data = [":test_utilities"], driver = "//tensorflow/compiler/mlir:run_lit.sh", exclude = ["testdata/**"], diff --git a/tensorflow/compiler/mlir/tfrt/tests/analysis/update_op_cost_in_tfrt_mlir_test.cc b/tensorflow/compiler/mlir/tfrt/tests/analysis/update_op_cost_in_tfrt_mlir_test.cc index 0c50bbbdea8..420a937eaae 100644 --- a/tensorflow/compiler/mlir/tfrt/tests/analysis/update_op_cost_in_tfrt_mlir_test.cc +++ b/tensorflow/compiler/mlir/tfrt/tests/analysis/update_op_cost_in_tfrt_mlir_test.cc @@ -47,7 +47,13 @@ absl::flat_hash_map GetOpCostMap(mlir::ModuleOp op) { return op_cost_map; } -TEST(CostUpdateTest, Basic) { +struct TestParams { + uint32_t normalize_ratio = 1; +}; + +class CostUpdateTest : public ::testing::TestWithParam {}; + +TEST_P(CostUpdateTest, Basic) { std::string saved_model_mlir_path = tensorflow::GetDataDependencyFilepath( "tensorflow/compiler/mlir/tfrt/tests/analysis/testdata/test.mlir"); @@ -61,15 +67,17 @@ TEST(CostUpdateTest, Basic) { ASSERT_TRUE(module); // Create a cost recorder with fake cost records. 
- auto expected_op_cost_map = GetOpCostMap(module.get()); - EXPECT_EQ(expected_op_cost_map.size(), 1); + auto fake_recorded_op_cost_map = GetOpCostMap(module.get()); + EXPECT_EQ(fake_recorded_op_cost_map.size(), 1); unsigned int seed = 23579; - for (auto& [op_key, cost] : expected_op_cost_map) { + for (auto& [op_key, cost] : fake_recorded_op_cost_map) { cost = rand_r(&seed) % 1000; } - tensorflow::tfrt_stub::CostRecorder cost_recorder; - for (const auto& [op_key, cost] : expected_op_cost_map) { - cost_recorder.RecordCostNanosecond(op_key, cost); + tensorflow::tfrt_stub::CostRecorder cost_recorder(GetParam().normalize_ratio); + absl::flat_hash_map expected_op_cost_map; + for (const auto& [op_key, cost] : fake_recorded_op_cost_map) { + cost_recorder.RecordCost(op_key, cost); + expected_op_cost_map[op_key] = cost_recorder.GetCost(op_key); } // Update the TFRT MLIR with the cost recorder. @@ -80,5 +88,8 @@ TEST(CostUpdateTest, Basic) { EXPECT_THAT(got_op_cost_map, ::testing::ContainerEq(expected_op_cost_map)); } +INSTANTIATE_TEST_SUITE_P(CostUpdateTests, CostUpdateTest, + ::testing::Values(TestParams{1}, TestParams{100})); + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tfrt/tests/ir/BUILD b/tensorflow/compiler/mlir/tfrt/tests/ir/BUILD index e4662bd66d7..8d49d08b102 100644 --- a/tensorflow/compiler/mlir/tfrt/tests/ir/BUILD +++ b/tensorflow/compiler/mlir/tfrt/tests/ir/BUILD @@ -7,6 +7,7 @@ package( ) glob_lit_tests( + name = "all_tests", data = [":test_utilities"], driver = "//tensorflow/compiler/mlir:run_lit.sh", exclude = ["testdata/**"], diff --git a/tensorflow/compiler/mlir/tfrt/tests/ir/tfrt_fallback_util_test.cc b/tensorflow/compiler/mlir/tfrt/tests/ir/tfrt_fallback_util_test.cc index 1cf8bc78a2d..bfa9c148174 100644 --- a/tensorflow/compiler/mlir/tfrt/tests/ir/tfrt_fallback_util_test.cc +++ b/tensorflow/compiler/mlir/tfrt/tests/ir/tfrt_fallback_util_test.cc @@ -15,6 +15,8 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback_util.h" #include +#include +#include #include "mlir/Parser/Parser.h" // from @llvm-project #include "tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback_async.h" diff --git a/tensorflow/compiler/mlir/tfrt/tests/jit/BUILD b/tensorflow/compiler/mlir/tfrt/tests/jit/BUILD index d4abbb3bc44..3e1391e6e45 100644 --- a/tensorflow/compiler/mlir/tfrt/tests/jit/BUILD +++ b/tensorflow/compiler/mlir/tfrt/tests/jit/BUILD @@ -8,6 +8,7 @@ package( ) glob_lit_tests( + name = "all_tests", data = [":test_utilities"], driver = "//tensorflow/compiler/mlir:run_lit.sh", features = if_oss(["--path=org_tensorflow/tensorflow/compiler/mlir/tfrt"]), diff --git a/tensorflow/compiler/mlir/tfrt/tests/lhlo_to_jitrt/BUILD b/tensorflow/compiler/mlir/tfrt/tests/lhlo_to_jitrt/BUILD index 16f00e94ca7..e23bbda91e0 100644 --- a/tensorflow/compiler/mlir/tfrt/tests/lhlo_to_jitrt/BUILD +++ b/tensorflow/compiler/mlir/tfrt/tests/lhlo_to_jitrt/BUILD @@ -9,6 +9,7 @@ package( # copybara:uncomment_begin # # glob_lit_tests( +# name = "all_tests", # data = [":test_utilities"], # driver = "//tensorflow/compiler/mlir:run_lit.sh", # features = if_oss(["--path=org_tensorflow/tensorflow/compiler/mlir/tfrt"]), diff --git a/tensorflow/compiler/xla/mlir_hlo/tosa/tests/BUILD b/tensorflow/compiler/mlir/tfrt/tests/mlrt/BUILD similarity index 54% rename from tensorflow/compiler/xla/mlir_hlo/tosa/tests/BUILD rename to tensorflow/compiler/mlir/tfrt/tests/mlrt/BUILD index a32ee149f5a..8da9c4cf2d5 100644 --- a/tensorflow/compiler/xla/mlir_hlo/tosa/tests/BUILD +++ b/tensorflow/compiler/mlir/tfrt/tests/mlrt/BUILD @@ -1,14 +1,11 @@ -load("//tensorflow/tsl:tsl.default.bzl", "filegroup") load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") -package( - # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - licenses = ["notice"], -) +# copybara:uncomment package(default_applicable_licenses = ["//tensorflow:license"]) glob_lit_tests( + name = 
"all_tests", data = [":test_utilities"], - driver = "@llvm-project//mlir:run_lit.sh", + driver = "//tensorflow/compiler/mlir:run_lit.sh", test_file_exts = ["mlir"], ) @@ -17,7 +14,8 @@ filegroup( name = "test_utilities", testonly = True, data = [ - "//tensorflow/compiler/xla/mlir_hlo/tosa:mhlo-tosa-opt", + "//tensorflow/compiler/mlir/tfrt:tf-tfrt-opt", "@llvm-project//llvm:FileCheck", + "@llvm-project//mlir:run_lit.sh", ], ) diff --git a/tensorflow/compiler/mlir/tfrt/tests/mlrt/assign_op_key.mlir b/tensorflow/compiler/mlir/tfrt/tests/mlrt/assign_op_key.mlir new file mode 100644 index 00000000000..9677f183b61 --- /dev/null +++ b/tensorflow/compiler/mlir/tfrt/tests/mlrt/assign_op_key.mlir @@ -0,0 +1,49 @@ +// RUN: tf-tfrt-opt -split-input-file -tf-mlrt-assign-op-key %s | FileCheck %s + +// CHECK-LABEL: func @main +// CHECK: tf.AddV2 +// CHECK-SAME: {__op_key = 0 : i32} + +// CHECK: tf.AddV2 +// CHECK-SAME: {__op_key = 1 : i32} + +// CHECK: tf.AddV2 +// CHECK-SAME: {__op_key = 2 : i32} + +// CHECK: tf.AddV2 +// CHECK-SAME: {__op_key = 3 : i32} + +// CHECK: tf.Sub +// CHECK-SAME: {__op_key = 4 : i32} + +// CHECK: tf.Sub +// CHECK-SAME: {__op_key = 5 : i32} + +// CHECK: tf.Sub +// CHECK-SAME: {__op_key = 6 : i32} + +// CHECK: tf.Sub +// CHECK-SAME: {__op_key = 7 : i32} + + +// CHECK: [[x:%.*]] = "tf.AddV2" +// CHECK-SAME: {__op_key = 8 : i32} + +// CHECK: return [[x]] + +func.func @main(%a: tensor, %b: tensor) -> tensor { + + %a0 = "tf.AddV2"(%a, %a) : (tensor, tensor) -> tensor + %a1 = "tf.AddV2"(%a0, %a) : (tensor, tensor) -> tensor + %a2 = "tf.AddV2"(%a1, %a) : (tensor, tensor) -> tensor + %a3 = "tf.AddV2"(%a2, %a) : (tensor, tensor) -> tensor + + %b0 = "tf.Sub"(%b, %b) : (tensor, tensor) -> tensor + %b1 = "tf.Sub"(%b0, %b) : (tensor, tensor) -> tensor + %b2 = "tf.Sub"(%b1, %b) : (tensor, tensor) -> tensor + %b3 = "tf.Sub"(%b2, %b) : (tensor, tensor) -> tensor + + %c = "tf.AddV2"(%a3, %b3) : (tensor, tensor) -> tensor + + func.return %c : tensor +} diff --git 
a/tensorflow/compiler/mlir/tfrt/tests/mlrt/fuse_mlrt_ops.mlir b/tensorflow/compiler/mlir/tfrt/tests/mlrt/fuse_mlrt_ops.mlir new file mode 100644 index 00000000000..ce750ec73ed --- /dev/null +++ b/tensorflow/compiler/mlir/tfrt/tests/mlrt/fuse_mlrt_ops.mlir @@ -0,0 +1,58 @@ +// RUN: tf-tfrt-opt -split-input-file -tf-mlrt-fuse %s | FileCheck %s + +// CHECK-LABEL: @main +// CHECK-SAME: ([[f0:%.*]]: !mlrt.future, [[f1:%.*]]: !mlrt.future, [[f2:%.*]]: !mlrt.future) +func.func @main(%f0: !mlrt.future, %f1: !mlrt.future, %f2: !mlrt.future) -> (!tf_mlrt.tensor, !tf_mlrt.tensor, !tf_mlrt.tensor) { + // CHECK-NEXT: [[t:%.*]]:3 = tf_mlrt.await_all [[f0]], [[f1]], [[f2]] + // CHECK-NOT: tf_mlrt.await + // CHECK-NEXT: return [[t]]#0, [[t]]#1, [[t]]#2 + %t0 = tf_mlrt.await %f0 + %t1 = tf_mlrt.await %f1 + %t2 = tf_mlrt.await %f2 + func.return %t0, %t1, %t2 : !tf_mlrt.tensor, !tf_mlrt.tensor, !tf_mlrt.tensor +} + +// ----- + +// CHECK-LABEL: @main +// CHECK-SAME: ([[f0:%.*]]: !mlrt.future, [[f1:%.*]]: !mlrt.future, [[f2:%.*]]: !mlrt.future) +func.func @main(%f0: !mlrt.future, %f1: !mlrt.future, %f2: !mlrt.future) -> (!tf_mlrt.tensor, !tf_mlrt.tensor) { + // CHECK-NEXT: [[t:%.*]]:2 = tf_mlrt.await_all [[f0]], [[f1]] + // CHECK-NOT: tf_mlrt.await + // CHECK-NEXT: [[t2:%.*]] = tf_mlrt.executeop([[t]]#0, [[t]]#1) + // CHECK-NEXT: [[t3:%.*]] = tf_mlrt.await [[f2]] + // CHECK-NEXT: return [[t2]], [[t3]] + %t0 = tf_mlrt.await %f0 + %t1 = tf_mlrt.await %f1 + %t2 = tf_mlrt.executeop(%t0, %t1) {node_def = "AddV2", op_key = 0 : i32} : (!tf_mlrt.tensor, !tf_mlrt.tensor) -> (!tf_mlrt.tensor) + %t3 = tf_mlrt.await %f2 + func.return %t2, %t3 : !tf_mlrt.tensor, !tf_mlrt.tensor +} + +// ----- + +// CHECK-LABEL: @main +// CHECK-SAME: ([[f0:%.*]]: !mlrt.async_handle, [[f1:%.*]]: !mlrt.async_handle, [[f2:%.*]]: !mlrt.async_handle) +func.func @main(%f0: !mlrt.async_handle, %f1: !mlrt.async_handle, %f2: !mlrt.async_handle) -> () { + // CHECK-NEXT: mlrt.await_all_handle [[f0]], [[f1]], [[f2]] + // 
CHECK-NOT: mlrt.await_handle + // CHECK-NEXT: return + mlrt.await_handle %f0 + mlrt.await_handle %f1 + mlrt.await_handle %f2 + func.return +} + +// ----- + +// CHECK-LABEL: @main +func.func @main() -> (!tf_mlrt.tensor, !tf_mlrt.tensor) { + // CHECK-NEXT: [[r:%.*]]:3 = tf_mlrt.get_resource {indices = [2, 0, 1]} + // CHECK-NEXT: [[v:%.*]] = tf_mlrt.executeop([[r]]#0, [[r]]#1) + // CHECK-NEXT: return [[v]], [[r]]#2 + %0 = tf_mlrt.get_resource {indices = [2]} : !tf_mlrt.tensor + %1 = tf_mlrt.get_resource {indices = [0]} : !tf_mlrt.tensor + %r = tf_mlrt.executeop(%0, %1) {node_def = "AddV2", op_key = 0 : i32} : (!tf_mlrt.tensor, !tf_mlrt.tensor) -> (!tf_mlrt.tensor) + %2 = tf_mlrt.get_resource {indices = [1]} : !tf_mlrt.tensor + func.return %r, %2 : !tf_mlrt.tensor, !tf_mlrt.tensor +} diff --git a/tensorflow/compiler/mlir/tfrt/tests/mlrt/inline.mlir b/tensorflow/compiler/mlir/tfrt/tests/mlrt/inline.mlir new file mode 100644 index 00000000000..de2a29c017d --- /dev/null +++ b/tensorflow/compiler/mlir/tfrt/tests/mlrt/inline.mlir @@ -0,0 +1,50 @@ +// RUN: tf-tfrt-opt -split-input-file -pass-pipeline='builtin.module(tf-to-mlrt, inline)' %s | FileCheck %s -dump-input=fail + +// Test generated tf_mlrt while body and predicate is inlined. 
+ +func.func @then(%x: tensor, %y: tensor, %z: tensor) -> tensor { + return %x: tensor +} + +func.func @else(%x: tensor, %y: tensor, %z: tensor) -> tensor { + return %y: tensor +} + +// CHECK-LABEL: func @while_cond_if +// CHECK: [[cond:%.*]] = tf_mlrt.predicate +// CHECK: [[z:%.*]] = mlrt.cond [[cond]] @then @else +// CHECK: return [[z]] +func.func @while_cond_if(%cond: tensor, %x: tensor, %y: tensor, %z: tensor) -> (tensor) { + %r = "tf.If"(%cond, %x, %y, %z) {then_branch = @then, else_branch = @else, is_stateless = true} : (tensor, tensor, tensor, tensor) -> tensor + return %r : tensor +} + +// CHECK-LABEL: func @while_body_if +func.func @while_body_if(%cond: tensor, %x: tensor, %y: tensor, %z: tensor) -> (tensor, tensor, tensor, tensor) { + %0 = "tf.Const"() {__op_key = 0: i32, device = "/device:CPU:0", value = dense<2> : tensor} : () -> tensor + %1 = "tf.Add"(%z, %0) {__op_key = 1: i32, device = "/device:CPU:0"} : (tensor, tensor) -> tensor + func.return %cond, %x, %y, %1 : tensor, tensor, tensor, tensor +} + +// CHECK-LABEL: func @while_test_if +// CHECK-SAME: -> !tf_mlrt.tensor +func.func @while_test_if(%cond: tensor, %x: tensor, %y: tensor) -> (tensor) { + // CHECK: [[CONST:%.*]] = tf_mlrt.executeop + %cst = "tf.Const"() {__op_key = 2: i32, device = "/device:CPU:0", value = dense<0> : tensor} : () -> tensor + // Predicate should be inlined. 
+ // CHECK-NEXT: tf_mlrt.predicate + // CHECK-NEXT: mlrt.cond + // CHECK-NEXT: tf_mlrt.predicate + + // CHECK-NEXT: mlrt.while + %0:4 = "tf.While"(%cond, %x, %y, %cst) { cond = @while_cond_if, body = @while_body_if, is_stateless = false, parallel_iterations = 1} : (tensor, tensor, tensor, tensor) -> (tensor, tensor, tensor, tensor) + // CHECK: return + // CHECK-SAME: !tf_mlrt.tensor + func.return %0#3 : tensor +} + +// CHECK-LABEL: func @"while_body_if/tf_mlrt_body" +// CHECK-NOT: call + +// CHECK-LABEL: func @"while_cond_if/tf_mlrt_predicate" +// CHECK-NOT: call diff --git a/tensorflow/compiler/mlir/tfrt/tests/mlrt/parallelization.mlir b/tensorflow/compiler/mlir/tfrt/tests/mlrt/parallelization.mlir new file mode 100644 index 00000000000..d4ab7c2c321 --- /dev/null +++ b/tensorflow/compiler/mlir/tfrt/tests/mlrt/parallelization.mlir @@ -0,0 +1,378 @@ +// RUN: tf-tfrt-opt -split-input-file -tf-mlrt-parallelization %s | FileCheck %s --dump-input=fail --dump-input-filter=all + +// CHECK-LABEL: func private @main_stream_{{[0-9]*}} +// CHECK-SAME: ({{%.*}}: tensor, [[PROMISE:%.*]]: !mlrt.promise) +// CHECK: tf.Sub +// CHECK: tf.Sub +// CHECK: tf.Sub +// CHECK: [[RES:%.*]] = "tf.Sub" +// CHECK: "tf_mlrt.tf_promise"([[PROMISE]], [[RES]]) +// CHECK: return + +// CHECK-LABEL: func @main +// CHECK: [[PROMISE:%.*]], [[FUTURE:%.*]] = "tf_mlrt.allocate_futures" +// CHECK: [[HANDLE:%.*]] = mlrt.async({{%.*}}, [[PROMISE]]) +// CHECK-SAME: callee = @main_stream_{{[0-9]*}} +// CHECK: tf.AddV2 +// CHECK: tf.AddV2 +// CHECK: tf.AddV2 +// CHECK: [[x:%.*]] = "tf.AddV2" +// CHECK: [[y:%.*]] = "tf_mlrt.tf_await"([[FUTURE]]) +// CHECK: [[RES:%.*]] = "tf.AddV2"([[x]], [[y]]) +// CHECK: mlrt.await_handle [[HANDLE]] +// CHECK: return [[RES]] + +func.func @main(%a: tensor, %b: tensor) -> tensor { + + %a0 = "tf.AddV2"(%a, %a) : (tensor, tensor) -> tensor + %a1 = "tf.AddV2"(%a0, %a) : (tensor, tensor) -> tensor + %a2 = "tf.AddV2"(%a1, %a) : (tensor, tensor) -> tensor + %a3 = "tf.AddV2"(%a2, %a) : 
(tensor, tensor) -> tensor + + %b0 = "tf.Sub"(%b, %b) : (tensor, tensor) -> tensor + %b1 = "tf.Sub"(%b0, %b) : (tensor, tensor) -> tensor + %b2 = "tf.Sub"(%b1, %b) : (tensor, tensor) -> tensor + %b3 = "tf.Sub"(%b2, %b) : (tensor, tensor) -> tensor + + %c = "tf.AddV2"(%a3, %b3) : (tensor, tensor) -> tensor + + func.return %c : tensor +} + +// ----- + +// Test merging child streams + +// CHECK-LABEL: func private @main_stream_{{[0-9]*}} +// CHECK-SAME: ({{%.*}}: tensor, {{%.*}}: tensor, [[PROMISE:%.*]]: !mlrt.promise) +// CHECK: tf.Sub +// CHECK: tf.Sub +// CHECK: tf.Sub +// CHECK: [[RES:%.*]] = "tf.Sub" +// CHECK: "tf_mlrt.tf_promise"([[PROMISE]], [[RES]]) +// CHECK: return + +// CHECK-LABEL: func @main +// CHECK: [[PROMISE:%.*]], [[FUTURE:%.*]] = "tf_mlrt.allocate_futures" +// CHECK: tf.AddV2 +// CHECK: tf.AddV2 +// CHECK: tf.AddV2 +// CHECK: [[VALUE:%.*]] = "tf.AddV2" +// CHECK: [[HANDLE:%.*]] = mlrt.async([[VALUE]], {{%.*}}, [[PROMISE]]) +// CHECK-SAME: callee = @main_stream_{{[0-9]*}} +// CHECK: tf.AddV2 +// CHECK: tf.AddV2 +// CHECK: tf.AddV2 +// CHECK: [[x:%.*]] = "tf.AddV2" +// CHECK: [[y:%.*]] = "tf_mlrt.tf_await"([[FUTURE]]) +// CHECK: [[RES:%.*]] = "tf.AddV2"([[x]], [[y]]) +// CHECK: mlrt.await_handle [[HANDLE]] +// CHECK: return [[RES]] + +func.func @main(%a: tensor, %b: tensor) -> tensor { + + %a0 = "tf.AddV2"(%a, %a) : (tensor, tensor) -> tensor + %a1 = "tf.AddV2"(%a0, %a) : (tensor, tensor) -> tensor + %a2 = "tf.AddV2"(%a1, %a) : (tensor, tensor) -> tensor + %a3 = "tf.AddV2"(%a2, %a) : (tensor, tensor) -> tensor + %a4 = "tf.AddV2"(%a3, %a) : (tensor, tensor) -> tensor + %a5 = "tf.AddV2"(%a4, %a) : (tensor, tensor) -> tensor + %a6 = "tf.AddV2"(%a5, %a) : (tensor, tensor) -> tensor + %a7 = "tf.AddV2"(%a6, %a) : (tensor, tensor) -> tensor + + %b0 = "tf.Sub"(%a3, %b) : (tensor, tensor) -> tensor + %b1 = "tf.Sub"(%b0, %b) : (tensor, tensor) -> tensor + %b2 = "tf.Sub"(%b1, %b) : (tensor, tensor) -> tensor + %b3 = "tf.Sub"(%b2, %b) : (tensor, tensor) -> 
tensor + + %c = "tf.AddV2"(%a7, %b3) : (tensor, tensor) -> tensor + + func.return %c : tensor +} + +// ----- + +// Test side-effecting ops + +// CHECK-LABEL: func private @main_stream_{{[0-9]*}} +// CHECK-SAME: ([[ARG:%.*]]: tensor, [[FUTURE:%.*]]: !mlrt.future, [[CONTROL_PROMISE:%.*]]: !mlrt.promise) +// CHECK: [[HANDLE:%.*]] = "tf_mlrt.tf_await"([[FUTURE]]) +// CHECK: "tf.AssignVariableOp"([[HANDLE]], [[ARG]]) +// CHECK-NEXT: mlrt.promise_control [[CONTROL_PROMISE]] + +// CHECK-LABEL: func private @main_stream_{{[0-9]*}} +// CHECK-SAME: ({{%.*}}: tensor, {{%.*}}: tensor, [[FUTURE:%.*]]: !mlrt.future, [[PROMISE:%.*]]: !mlrt.promise) +// CHECK: tf.Sub +// CHECK: tf.Sub +// CHECK: tf.Sub +// CHECK: [[V:%.*]] = "tf_mlrt.tf_await"([[FUTURE]]) +// CHECK-NEXT: [[RES:%.*]] = "tf.Sub"({{%.*}}, [[V]]) +// CHECK: "tf_mlrt.tf_promise"([[PROMISE]], [[RES]]) +// CHECK: return + +// CHECK-LABEL: func private @main_stream_{{[0-9]*}} +// CHECK-SAME: ([[CONTROL_FUTURE:%.*]]: !mlrt.future, [[PROMISE:%.*]]: !mlrt.promise, [[PROMISE_HANDLE:%.*]]: !mlrt.promise) +// CHECK: [[HANDLE:%.*]] = "tf.VarHandleOp" +// CHECK-NEXT: "tf_mlrt.tf_promise"([[PROMISE_HANDLE]], [[HANDLE]]) +// CHECK: mlrt.await_control [[CONTROL_FUTURE]] +// CHECK-NEXT: [[V:%.*]] = "tf.ReadVariableOp"([[HANDLE]]) +// CHECK: "tf_mlrt.tf_promise"([[PROMISE]], [[V]]) + +// CHECK-LABEL: func @main +// CHECK: [[PROMISE:%.*]]:3, [[FUTURE:%.*]]:3 = "tf_mlrt.allocate_futures" +// CHECK: [[CONTROL_PROMISE:%.*]], [[CONTROL_FUTURE:%.*]] = "mlrt.allocate_control_futures" +// CHECK: [[ASYNC_HANDLE_0:%.*]] = mlrt.async([[CONTROL_FUTURE]], [[PROMISE]]#0, [[PROMISE]]#1) +// CHECK-SAME: callee = @main_stream_{{[0-9]*}} +// CHECK: [[ASYNC_HANDLE_1:%.*]] = mlrt.async({{%.*}}, {{%.*}}, [[FUTURE]]#0, [[PROMISE]]#2) +// CHECK-SAME: callee = @main_stream_{{[0-9]*}} +// CHECK: tf.AddV2 +// CHECK: tf.AddV2 +// CHECK: tf.AddV2 +// CHECK: [[x:%.*]] = "tf.AddV2" +// CHECK: [[ASYNC_HANDLE_2:%.*]] = mlrt.async([[x]], [[FUTURE]]#1, 
[[CONTROL_PROMISE]]) +// CHECK-SAME: callee = @main_stream_{{[0-9]*}} +// CHECK: [[y:%.*]] = "tf_mlrt.tf_await"([[FUTURE]]#2) +// CHECK: [[RES:%.*]] = "tf.AddV2"([[x]], [[y]]) +// CHECK: mlrt.await_handle [[ASYNC_HANDLE_0]] +// CHECK-NEXT: mlrt.await_handle [[ASYNC_HANDLE_1]] +// CHECK-NEXT: mlrt.await_handle [[ASYNC_HANDLE_2]] +// CHECK-NEXT: return [[RES]] + +func.func @main(%a: tensor, %b: tensor) -> tensor { + %handle = "tf.VarHandleOp"() {container = "", shared_name = "var"} : () -> tensor>> + + %a0 = "tf.AddV2"(%a, %a) : (tensor, tensor) -> tensor + %a1 = "tf.AddV2"(%a0, %a) : (tensor, tensor) -> tensor + %a2 = "tf.AddV2"(%a1, %a) : (tensor, tensor) -> tensor + %a3 = "tf.AddV2"(%a2, %a) : (tensor, tensor) -> tensor + "tf.AssignVariableOp"(%handle, %a3) : (tensor>>, tensor) -> () + + %b0 = "tf.Sub"(%a, %b) : (tensor, tensor) -> tensor + %b1 = "tf.Sub"(%b0, %b) : (tensor, tensor) -> tensor + %b2 = "tf.Sub"(%b1, %b) : (tensor, tensor) -> tensor + %var = "tf.ReadVariableOp"(%handle) : (tensor>>) -> tensor + %b3 = "tf.Sub"(%b2, %var) : (tensor, tensor) -> tensor + + %c = "tf.AddV2"(%a3, %b3) : (tensor, tensor) -> tensor + + func.return %c : tensor +} + +// ----- + +// Test multiple promises and futures + +// CHECK-LABEL: func private @main_stream_1 +// CHECK: mlrt.await_control +// CHECK: "tf.DummySideEffecting"() {id = 4 +// CHECK: return + +// CHECK-LABEL: func private @main_stream_2 +// CHECK: mlrt.await_control +// CHECK: "tf.DummySideEffecting"() {id = 3 +// CHECK: mlrt.promise_control +// CHECK: return + +// CHECK-LABEL: func private @main_stream_3 +// CHECK: mlrt.await_control +// CHECK: "tf.DummySideEffecting"() {id = 2 +// CHECK: mlrt.promise_control +// CHECK: return + +// CHECK-LABEL: func private @main_stream_4 +// CHECK: "tf.DummySideEffecting"() {id = 1 +// CHECK: mlrt.promise_control +// CHECK: return + +// CHECK-LABEL: func @main() +// CHECK: [[PROMISES:%.*]]:3, [[FUTURES:%.*]]:3 = "mlrt.allocate_control_futures" +// CHECK: 
mlrt.async([[PROMISES]]#2) {callee = @main_stream_4 +// CHECK: mlrt.async([[FUTURES]]#2, [[PROMISES]]#1) {callee = @main_stream_3 +// CHECK: mlrt.async([[FUTURES]]#1, [[PROMISES]]#0) {callee = @main_stream_2 +// CHECK: mlrt.async([[FUTURES]]#0) {callee = @main_stream_1 +// CHECK: mlrt.await_handle +// CHECK: mlrt.await_handle +// CHECK: mlrt.await_handle +// CHECK: mlrt.await_handle + +func.func @main() { + "tf.DummySideEffecting"() {id = 1} : () -> () + "tf.DummySideEffecting"() {id = 2} : () -> () + "tf.DummySideEffecting"() {id = 3} : () -> () + "tf.DummySideEffecting"() {id = 4} : () -> () + func.return +} + +// ----- + +// Test correctness when there are both data and control promises in a stream function. + +// CHECK-LABEL: func private @main_stream_1 +// CHECK-SAME: ([[PROMISE:%.*]]: !mlrt.promise, [[CONTROL_PROMISE:%.*]]: !mlrt.promise) +// CHECK: tf.DummySideEffecting +// CHECK: "tf_mlrt.tf_promise"([[PROMISE]] +// CHECK: mlrt.promise_control [[CONTROL_PROMISE]] + +func.func @main() -> tensor { + %v = "tf.DummySideEffecting"() {id = 1} : () -> tensor + + %w = "tf.DummySideEffecting"() {id = 2} : () -> tensor + %r = "tf.AddV2"(%w, %v) : (tensor, tensor) -> tensor + func.return %r : tensor +} + +// ----- + +// Test inputs to the child streams are merged to the parent streams + +// CHECK-LABEL: func private @main_stream_1 +// CHECK-SAME: ([[INPUT0:%.*]]: tensor, [[INPUT1:%.*]]: tensor +// CHECK: tf.Sub +// CHECK: tf.Sub +// CHECK: mlrt.async({{%.*}}, [[INPUT1]] + +// CHECK-LABEL: func @main +func.func @main(%a: tensor, %b: tensor) -> tensor { + + %a0 = "tf.AddV2"(%a, %a) : (tensor, tensor) -> tensor + %a1 = "tf.AddV2"(%a0, %a) : (tensor, tensor) -> tensor + %a2 = "tf.AddV2"(%a1, %a) : (tensor, tensor) -> tensor + %a3 = "tf.AddV2"(%a2, %a) : (tensor, tensor) -> tensor + + %b0 = "tf.Sub"(%b, %b) : (tensor, tensor) -> tensor + %b1 = "tf.Sub"(%b0, %b) : (tensor, tensor) -> tensor + + %c = "tf.AddV2"(%b1, %a) : (tensor, tensor) -> tensor + + %b2 = "tf.Sub"(%b1, 
%b) : (tensor, tensor) -> tensor + %b3 = "tf.Sub"(%b2, %b) : (tensor, tensor) -> tensor + + %d = "tf.AddN"(%a3, %b3, %c) : (tensor, tensor, tensor) -> tensor + func.return %d : tensor +} + +// ----- + +// Test that constants are copied instead of using promise/await. + +// CHECK-LABEL: func private @main_stream_1 +// CHECK-SAME: ({{%.*}}: tensor, [[PROMISE:%.*]]: !mlrt.promise) +// CHECK: tf._TfrtGetResource +// CHECK: tf.Sub +// CHECK: tf.Sub +// CHECK: tf.Sub +// CHECK: [[RES:%.*]] = "tf.Sub" +// CHECK: "tf_mlrt.tf_promise"([[PROMISE]], [[RES]]) +// CHECK: return + +// CHECK-NOT: func private @main_stream + +// CHECK-LABEL: func @main +// CHECK: [[PROMISE:%.*]], [[FUTURE:%.*]] = "tf_mlrt.allocate_futures" +// CHECK-NEXT: [[HANDLE:%.*]] = mlrt.async({{%.*}}, [[PROMISE]]) +// CHECK-SAME: callee = @main_stream_1 +// CHECK: tf._TfrtGetResource +// CHECK: tf.AddV2 +// CHECK: tf.AddV2 +// CHECK: tf.AddV2 +// CHECK: [[x:%.*]] = "tf.AddV2" +// CHECK: [[y:%.*]] = "tf_mlrt.tf_await"([[FUTURE]]) +// CHECK: [[RES:%.*]] = "tf.AddV2"([[x]], [[y]]) +// CHECK: mlrt.await_handle [[HANDLE]] +// CHECK: return [[RES]] + +func.func @main(%a: tensor, %b: tensor) -> tensor { + + %c0 = "tf._TfrtGetResource"() {indices = [0], shared_name = [""], container = [""]} : () -> (tensor) + + %a0 = "tf.AddV2"(%a, %c0) : (tensor, tensor) -> tensor + %a1 = "tf.AddV2"(%a0, %c0) : (tensor, tensor) -> tensor + %a2 = "tf.AddV2"(%a1, %c0) : (tensor, tensor) -> tensor + %a3 = "tf.AddV2"(%a2, %c0) : (tensor, tensor) -> tensor + + %b0 = "tf.Sub"(%b, %c0) : (tensor, tensor) -> tensor + %b1 = "tf.Sub"(%b0, %c0) : (tensor, tensor) -> tensor + %b2 = "tf.Sub"(%b1, %c0) : (tensor, tensor) -> tensor + %b3 = "tf.Sub"(%b2, %c0) : (tensor, tensor) -> tensor + + %c = "tf.AddV2"(%a3, %b3) : (tensor, tensor) -> tensor + + func.return %c : tensor +} + +// ----- + +// Test that constants private to a stream are still handled properly when we are copying shared constants. 
+ +// CHECK-LABEL: func private @main_stream_1 +// CHECK: [[r:%.*]] = "tf._TfrtGetResource" +// CHECK-SAME: indices = [1] +// CHECK: "tf.DummySideEffecting"([[r]]) + +// CHECK-LABEL: func private @main_stream_2 +// CHECK: [[r:%.*]] = "tf._TfrtGetResource" +// CHECK-SAME: indices = [0] +// CHECK: "tf.DummySideEffecting"([[r]]) + +// CHECK-LABEL: func @main + +func.func @main(%a: tensor, %b: tensor) -> () { + + %c0 = "tf._TfrtGetResource"() {indices = [0], shared_name = [""], container = [""]} : () -> (tensor) + "tf.DummySideEffecting"(%c0) : (tensor) -> () + + %c1 = "tf._TfrtGetResource"() {indices = [1], shared_name = [""], container = [""]} : () -> (tensor) + "tf.DummySideEffecting"(%c1) : (tensor) -> () + + func.return +} + +// ----- + +// Test that streams with no args but side-effecting ops are still created properly + +// CHECK-LABEL: func private @main_stream_1() +// CHECK: [[r:%.*]] = "tf._TfrtGetResource" +// CHECK-SAME: indices = [0] +// CHECK: "tf.DummySideEffecting"([[r]]) + +// CHECK-LABEL: func @main + +func.func @main(%a: tensor, %b: tensor) -> () { + %c0 = "tf._TfrtGetResource"() {indices = [0], shared_name = [""], container = [""]} : () -> (tensor) + "tf.DummySideEffecting"(%c0) : (tensor) -> () + func.return +} + +// ----- + +// Test control deps of tf.Assert is skipped. 
+ +// CHECK-LABEL: func.func private @skip_assert_stream_3( +// CHECK-NOT: mlrt.await_control +// CHECK: tf.Assert +// CHECK-NOT: mlrt.promise_control +// CHECK: return + +// CHECK-LABEL: func.func private @skip_assert_stream_2( +// CHECK-NOT: mlrt.await_control +// CHECK: tf.Assert +// CHECK-NOT: mlrt.promise_control +// CHECK: return + +func.func @skip_assert(%key: tensor) -> (tensor, tensor) { + %error_message = "tf.Const"() {value = dense<"error"> : tensor} : () -> tensor + %default = "tf.Const"() {value = dense<-1> : tensor} : () -> tensor + %handle = "tf.HashTableV2"() {container = "", device = "/job:localhost/replica:0/task:0/device:CPU:0", key_dtype = !tf_type.string, shared_name = "hash_table", use_node_name_sharing = false, value_dtype = i64} : () -> tensor + + + %keys = "tf.Const"() {value = dense<["a", "b", "c", "d"]> : tensor<4x!tf_type.string>} : () -> tensor<4x!tf_type.string> + %values = "tf.Const"() {value = dense<[1, 2, 3, 4]> : tensor<4xi64>} : () -> tensor<4xi64> + "tf.LookupTableImportV2"(%handle, %keys, %values) {device = ""} : (tensor, tensor<4x!tf_type.string>, tensor<4xi64>) -> () + %value0 = "tf.LookupTableFindV2"(%handle, %key, %default) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : (tensor, tensor, tensor) -> tensor + %cond = "tf.Equal"(%value0, %default) {device = "/job:localhost/replica:0/task:0/device:CPU:0", incompatible_shape_error = true} : (tensor, tensor) -> tensor + "tf.Assert"(%cond, %error_message) {device = "/job:localhost/replica:0/task:0/device:CPU:0", summarize = 3 : i64} : (tensor, tensor) -> () + "tf.Assert"(%cond, %error_message) {device = "/job:localhost/replica:0/task:0/device:CPU:0", summarize = 3 : i64} : (tensor, tensor) -> () + %value1 = "tf.LookupTableFindV2"(%handle, %key, %default) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : (tensor, tensor, tensor) -> tensor + func.return %value0, %value1 : tensor, tensor +} diff --git a/tensorflow/compiler/mlir/tfrt/tests/mlrt/tf_to_mlrt.mlir 
b/tensorflow/compiler/mlir/tfrt/tests/mlrt/tf_to_mlrt.mlir new file mode 100644 index 00000000000..d5d3254901c --- /dev/null +++ b/tensorflow/compiler/mlir/tfrt/tests/mlrt/tf_to_mlrt.mlir @@ -0,0 +1,419 @@ +// RUN: tf-tfrt-opt -split-input-file -tf-to-mlrt %s | FileCheck %s + +// CHECK-LABEL: @main_stream_0 +// CHECK-SAME: ([[input0:%.*]]: !tf_mlrt.tensor, [[promise_b:%.*]]: !mlrt.promise) +func.func @main_stream_0(%input0: tensor, %promise_b: !mlrt.promise) { + %const = "tf.Const"() {__op_key = 0 : i32, value = dense<1> : tensor} : () -> tensor + // CHECK: [[a:%.*]] = tf_mlrt.executeop([[input0]], + // CHECK-SAME: AddV2 + %a = "tf.AddV2"(%input0, %const) {__op_key = 1: i32}: (tensor, tensor) -> tensor + // CHECK: [[b:%.*]] = tf_mlrt.executeop([[a]]) + // CHECK-SAME: Abs + %b = "tf.Abs"(%a) {__op_key = 2 : i32}: (tensor) -> tensor + // CHECK: tf_mlrt.promise [[promise_b]], [[b]] + "tf_mlrt.tf_promise"(%promise_b, %b) : (!mlrt.promise, tensor) -> () + // CHECK: return + return +} + +// CHECK-LABEL: @main_stream_1 +// CHECK-SAME: ([[input1:%.*]]: !tf_mlrt.tensor, [[promise_c:%.*]]: !mlrt.promise, [[promise_d:%.*]]: !mlrt.promise) +func.func @main_stream_1(%input1: tensor, %promise_c: !mlrt.promise, %promise_d: !mlrt.promise) { + %const = "tf.Const"() {__op_key = 3 : i32, value = dense<1> : tensor} : () -> tensor + // CHECK: [[c:%.*]] = tf_mlrt.executeop([[input1]], + // CHECK-SAME: Sub + %c = "tf.Sub"(%input1, %const) {__op_key = 4: i32} : (tensor, tensor) -> tensor + // CHECK: tf_mlrt.promise [[promise_c]], [[c]] + "tf_mlrt.tf_promise"(%promise_c, %c) : (!mlrt.promise, tensor) -> () + // CHECK: [[d:%.*]] = tf_mlrt.executeop([[c]]) + // CHECK-SAME: Abs + %d = "tf.Abs"(%c) {__op_key = 5: i32}: (tensor) -> tensor + // CHECK: tf_mlrt.promise [[promise_d]], [[d]] + "tf_mlrt.tf_promise"(%promise_d, %d) : (!mlrt.promise, tensor) -> () + // CHECK: return + return +} + +// CHECK-LABEL: @main +// CHECK-SAME: ([[input0:%.*]]: !tf_mlrt.tensor, [[input1:%.*]]: !tf_mlrt.tensor) 
+func.func @main(%input0: tensor, %input1: tensor) -> tensor { + // CHECK: [[promises:%.*]]:3, [[futures:%.*]]:3 = "tf_mlrt.allocate_futures" + // CHECK-SAME: num_futures = 3 + %promise_b, %promise_c, %promise_d, %future_b, %future_c, %future_d = + "tf_mlrt.allocate_futures"() + {num_futures = 3 : i32, result_segment_sizes = array} : () -> + (!mlrt.promise, !mlrt.promise, !mlrt.promise, + !mlrt.future, !mlrt.future, !mlrt.future) + + // CHECK: [[handle_0:%.*]] = mlrt.async([[input0]], [[promises]]#0) + // CHECK-SAME: callee = @main_stream_0 + %handle_0 = mlrt.async(%input0, %promise_b) + {callee = @main_stream_0} : + (tensor, !mlrt.promise) -> !mlrt.async_handle + // CHECK: [[handle_1:%.*]] = mlrt.async([[input1]], [[promises]]#1, [[promises]]#2) + // CHECK-SAME: callee = @main_stream_1 + %handle_1 = mlrt.async(%input1, %promise_c, %promise_d) + {callee = @main_stream_1} : + (tensor, !mlrt.promise, !mlrt.promise) -> !mlrt.async_handle + + %const = "tf.Const"() {__op_key = 6: i32, value = dense<2> : tensor} : () -> tensor + // CHECK: [[e:%.*]] = tf_mlrt.executeop([[input1]], + // CHECK-SAME: Mul + %e = "tf.Mul"(%input1, %const) {__op_key = 7: i32} : (tensor, tensor) -> tensor + // CHECK: [[c:%.*]] = tf_mlrt.await [[futures]]#1 + %c = "tf_mlrt.tf_await"(%future_c) : (!mlrt.future) ->tensor + // CHECK: [[f:%.*]] = tf_mlrt.executeop([[e]], [[c]]) + // CHECK-SAME: Div + %f = "tf.Div"(%e, %c) {__op_key = 8: i32}: (tensor, tensor) -> tensor + + // CHECK: [[b:%.*]] = tf_mlrt.await [[futures]]#0 + %b = "tf_mlrt.tf_await"(%future_b) : (!mlrt.future) ->tensor + // CHECK: [[d:%.*]] = tf_mlrt.await [[futures]]#2 + %d = "tf_mlrt.tf_await"(%future_d) : (!mlrt.future) ->tensor + + // CHECK: [[result:%.*]] = tf_mlrt.executeop([[b]], [[d]], [[f]]) + // CHECK-SAME: AddN + %result = "tf.AddN"(%b, %d, %f) {__op_key = 9: i32}: (tensor, tensor, tensor) -> tensor + + // CHECK: mlrt.await_handle [[handle_0]] + // CHECK: mlrt.await_handle [[handle_1]] + mlrt.await_handle %handle_0 + 
mlrt.await_handle %handle_1 + + // CHECK: return [[result]] + return %result : tensor +} + +// ----- + +// Test lowering tf.If + +func.func @then(%x: tensor, %y: tensor) -> tensor { + return %x: tensor +} + +func.func @else(%x: tensor, %y: tensor) -> tensor { + return %y: tensor +} + +// CHECK-LABEL: func @main +// CHECK-SAME: ([[cond_tensor:%.*]]: !tf_mlrt.tensor, [[x:%.*]]: !tf_mlrt.tensor, [[y:%.*]]: !tf_mlrt.tensor) +// CHECK: [[cond:%.*]] = tf_mlrt.predicate [[cond_tensor]] +// CHECK: [[z:%.*]] = mlrt.cond [[cond]] @then @else([[x]], [[y]]) +// CHECK: return [[z]] +func.func @main(%cond: tensor, %x: tensor, %y: tensor) -> tensor { + %z = "tf.If"(%cond, %x, %y) {then_branch = @then, else_branch = @else, is_stateless = true} : (tensor, tensor, tensor) -> tensor + return %z: tensor +} + +// ----- + +// Test lowering AsyncOpKernel + +// CHECK-LABEL: func @main +func.func @main(%x: tensor) -> (tensor, tensor, tensor) { + // CHECK: [[y_future:%.*]] = tf_mlrt.async_executeop + %y = "tf.TestAsyncIdentity"(%x) {__op_key = 0: i32, T = i32} : (tensor) -> tensor + // CHECK: [[z:%.*]] = tf_mlrt.executeop + %z = "tf.Identity"(%x) {__op_key = 1: i32}: (tensor) -> tensor + // CHECK: [[y:%.*]] = tf_mlrt.await [[y_future]] + // CHECK-NEXT: tf_mlrt.executeop([[y]] + %w = "tf.AddV2"(%y, %z) {__op_key = 2: i32}: (tensor, tensor) -> tensor + // CHECK-NEXT: tf_mlrt.executeop([[y]] + %u = "tf.AddV2"(%y, %z) {__op_key = 3: i32} : (tensor, tensor) -> tensor + // CHECK-NEXT: tf_mlrt.executeop([[y]] + %v = "tf.AddV2"(%y, %z) {__op_key = 4: i32}: (tensor, tensor) -> tensor + return %w, %u, %v : tensor, tensor, tensor +} + +// ----- + +// Test lowering BatchFunction op. 
+ +func.func @batched_function(%x: tensor) -> tensor { + return %x : tensor +} + +// CHECK-LABEL: func @main +func.func @main(%x: tensor<1xi32>) -> (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) { + // CHECK: [[y_future:%.*]] = tf_mlrt.batch_function + // CHECK-SAME: f = @batched_function + // CHECK-SAME: \22batch_function\22 + %y = "tf.BatchFunction"(%x) { + allowed_batch_sizes = [6], batch_timeout_micros = 100000 : i64, + batching_queue = "", container = "", device = "/device:CPU:0", + enable_large_batch_splitting = false, f = @batched_function, + max_batch_size = 6 : i64, max_enqueued_batches = 10 : i64, + num_batch_threads = 1 : i64, operand_segment_sizes = array, + shared_name = "batch_function" + } : (tensor<1xi32>) -> tensor<1xi32> + + // CHECK: [[z:%.*]] = tf_mlrt.executeop + %z = "tf.Identity"(%x) {__op_key = 0: i32} : (tensor<1xi32>) -> tensor<1xi32> + // CHECK: [[y:%.*]] = tf_mlrt.await [[y_future]] + // CHECK-NEXT: tf_mlrt.executeop([[y]] + %w = "tf.AddV2"(%y, %z) {__op_key = 1: i32}: (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> + // CHECK-NEXT: tf_mlrt.executeop([[y]] + %u = "tf.AddV2"(%y, %z) {__op_key = 2: i32}: (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> + // CHECK-NEXT: tf_mlrt.executeop([[y]] + %v = "tf.AddV2"(%y, %z) {__op_key = 3: i32}: (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> + return %w, %u, %v : tensor<1xi32>, tensor<1xi32>, tensor<1xi32> +} + +// ----- + +// Test node names are preserved. 
+ +// CHECK-LABEL: func @main +func.func @main(%x: tensor) -> tensor { + // CHECK: tf_mlrt.executeop + // CHECK-SAME: name: \22name_loc/AddV2_0\22 + %y = "tf.AddV2"(%x, %x) {__op_key = 0: i32} : (tensor, tensor) -> tensor loc("name_loc:AddV2") + // CHECK: tf_mlrt.executeop + // CHECK-SAME: name: \22fused_loc/AddV2_1\22 + %z = "tf.AddV2"(%y, %x) {__op_key = 1: i32}: (tensor, tensor) -> tensor loc(fused["fused_loc:", "AddV2"]) + // CHECK: tf_mlrt.executeop + // CHECK-SAME: name: \22AddV2_2\22 + %w = "tf.AddV2"(%z, %x) {__op_key = 2: i32}: (tensor, tensor) -> tensor + return %z : tensor +} + +// ----- + +// Test function name canonicalization + +// CHECK-LABEL: func @__inference_pruned_35 +func.func @__inference_pruned_35() -> tensor attributes {tf.entry_function = {control_outputs = "", inputs = "", outputs = "flatmapdataset__4_RetVal"}} { + %0 = "tf.Const"() {__op_key = 0: i32, device = "/device:CPU:0", value = dense<0> : tensor} : () -> tensor + %1 = "tf.Const"() {__op_key = 1: i32, device = "/device:CPU:0", value = dense<5> : tensor} : () -> tensor + %2 = "tf.Const"() {__op_key = 2: i32, device = "/device:CPU:0", value = dense<1> : tensor} : () -> tensor + %3 = "tf.RangeDataset"(%0, %1, %2) {__op_key = 3: i32, device = "/device:CPU:0", output_shapes = [#tf_type.shape<>], output_types = [i64], metadata = ""} : (tensor, tensor, tensor) -> tensor + // CHECK: tf_mlrt.executeop{{.*}}op: \22FlatMapDataset\22 + // CHECK-SAME: \22__inference_Dataset_flat_map_lambda_19\22 + %4 = "tf.FlatMapDataset"(%3) {__op_key = 4: i32, Targuments = [], device = "/device:CPU:0", f = @__inference_Dataset_flat_map_lambda_190, output_shapes = [#tf_type.shape<>], output_types = [i64], metadata = ""} : (tensor) -> tensor + func.return %4 : tensor +} +// CHECK-LABEL: __inference_Dataset_flat_map_lambda_190 +func.func private @__inference_Dataset_flat_map_lambda_190(%arg0: tensor {tf._user_specified_name = "args_0"}) -> tensor attributes {tf._original_func_name = 
"__inference_Dataset_flat_map_lambda_19", tf._tf_data_function = true, tf.signature.is_stateful} { + %0 = "tf.Const"() {__op_key = 5: i32, device = "/device:CPU:0", value = dense<0> : tensor} : () -> tensor + %1 = "tf.Const"() {__op_key = 6: i32,device = "/device:CPU:0", value = dense<1> : tensor} : () -> tensor + %2 = "tf.Const"() {__op_key = 7: i32,device = "/device:CPU:0", value = dense<5> : tensor} : () -> tensor + %3 = "tf.RangeDataset"(%0, %2, %1) {__op_key = 8: i32, device = "/device:CPU:0", output_shapes = [#tf_type.shape<>], output_types = [i64], metadata = ""} : (tensor, tensor, tensor) -> tensor + // CHECK: tf_mlrt.executeop{{.*}}op: \22MapDataset\22 + // CHECK-SAME: \22__inference_Dataset_map_lambda_16\22 + %4 = "tf.MapDataset"(%3) {__op_key = 9: i32, device = "/device:CPU:0", f = @__inference_Dataset_map_lambda_160, f._tf_data_function = true, output_shapes = [#tf_type.shape<>], output_types = [i64], preserve_cardinality = true, use_inter_op_parallelism = true, metadata = ""} : (tensor) -> tensor + %5 = "tf.Identity"(%4) {__op_key = 10: i32, device = "/device:CPU:0"} : (tensor) -> tensor + func.return %5 : tensor +} +// CHECK-LABEL: __inference_Dataset_map_lambda_160 +func.func private @__inference_Dataset_map_lambda_160(%arg0: tensor {tf._user_specified_name = "args_0"}) -> tensor attributes {tf._tf_data_function = true} { + %0 = "tf.Const"() {__op_key = 11: i32, device = "/device:CPU:0", value = dense<2> : tensor} : () -> tensor + %1 = "tf.Mul"(%arg0, %0) {__op_key = 12: i32, device = "/device:CPU:0"} : (tensor, tensor) -> tensor + %2 = "tf.Identity"(%1) {__op_key = 13: i32, device = "/device:CPU:0"} : (tensor) -> tensor + func.return %2 : tensor +} + +// ----- + +// Test while conversion + +// CHECK-LABEL: func @while_cond_lt9 +// CHECK-SAME: ([[arg0:%.*]]: !tf_mlrt.tensor) -> !tf_mlrt.tensor +func.func @while_cond_lt9(%arg0: tensor) -> tensor { + %0 = "tf.Const"() {__op_key = 0: i32, device = "/device:CPU:0", value = dense<9> : tensor} : () -> 
tensor + %1 = "tf.Less"(%arg0, %0) {__op_key = 1: i32, device = "/device:CPU:0"} : (tensor, tensor) -> tensor + func.return %1 : tensor +} + +// CHECK-LABEL: func @while_body_add2 +// CHECK-SAME: ([[arg0:%.*]]: !tf_mlrt.tensor) -> !tf_mlrt.tensor +func.func @while_body_add2(%arg0: tensor) -> tensor { + %0 = "tf.Const"() {__op_key = 2: i32, device = "/device:CPU:0", value = dense<2> : tensor} : () -> tensor + %1 = "tf.Add"(%arg0, %0) {__op_key = 3: i32, device = "/device:CPU:0"} : (tensor, tensor) -> tensor + func.return %1 : tensor +} + +// CHECK-LABEL: func @while_test() +// CHECK-SAME: -> !tf_mlrt.tensor +func.func @while_test() -> (tensor) { + // CHECK: [[CONST:%.*]] = tf_mlrt.executeop + %0 = "tf.Const"() {__op_key = 4: i32, device = "/device:CPU:0", value = dense<0> : tensor} : () -> tensor + // CHECK: [[pred_res:%.*]] = call @"while_cond_lt9/tf_mlrt_predicate"([[CONST]]) : (!tf_mlrt.tensor) -> i1 + // CHECK: [[while_res:%.*]]:2 = mlrt.while + // CHECK-SAME: @"while_body_add2/tf_mlrt_body"([[CONST]]) + // CHECK-SAME: (!tf_mlrt.tensor) -> (!tf_mlrt.tensor, i1) + %1 = "tf.While"(%0) { cond = @while_cond_lt9, body = @while_body_add2, is_stateless = false, parallel_iterations = 1} : (tensor) -> (tensor) + // CHECK: return [[while_res]]#0 : !tf_mlrt.tensor + func.return %1 : tensor +} +// CHECK: func @"while_body_add2/tf_mlrt_body"([[arg:%.*]]: !tf_mlrt.tensor) -> (!tf_mlrt.tensor, i1) +// CHECK: [[body_res:%.*]] = call @while_body_add2([[arg]]) : (!tf_mlrt.tensor) -> !tf_mlrt.tensor +// CHECK: [[pred_res:%.*]] = call @"while_cond_lt9/tf_mlrt_predicate"([[body_res]]) : (!tf_mlrt.tensor) -> i1 +// CHECK: return [[body_res]], [[pred_res]] : !tf_mlrt.tensor, i1 + +// CHECK: func @"while_cond_lt9/tf_mlrt_predicate"([[arg:%.*]]: !tf_mlrt.tensor) -> i1 +// CHECK: [[cond_res:%.*]] = call @while_cond_lt9([[arg]]) : (!tf_mlrt.tensor) -> !tf_mlrt.tensor +// CHECK: [[bool_res:%.*]] = tf_mlrt.predicate [[cond_res]] +// CHECK: return [[bool_res]] : i1 + +// CHECK-LABEL: func 
@multi_while_test +func.func @multi_while_test() -> (tensor, tensor) { + %0 = "tf.Const"() {__op_key = 5: i32, device = "/device:CPU:0", value = dense<0> : tensor} : () -> tensor + %1 = "tf.Const"() {__op_key = 6: i32, device = "/device:CPU:0", value = dense<1> : tensor} : () -> tensor + // CHECK: [[pred_0:%.*]] = call @"while_cond_lt9/tf_mlrt_predicate" + // CHECK: mlrt.while [[pred_0]] @"while_body_add2/tf_mlrt_body" + // CHECK: [[pred_1:%.*]] = call @"while_cond_lt9/tf_mlrt_predicate" + // CHECK: mlrt.while [[pred_1]] @"while_body_add2/tf_mlrt_body" + %2 = "tf.While"(%0) { cond = @while_cond_lt9, body = @while_body_add2, is_stateless = false, parallel_iterations = 1} : (tensor) -> (tensor) + %3 = "tf.While"(%1) { cond = @while_cond_lt9, body = @while_body_add2, is_stateless = false, parallel_iterations = 1} : (tensor) -> (tensor) + func.return %2, %3 : tensor, tensor +} + +// ----- + +// Test async output to function is converted + +// CHECK-LABEL: @serving_default_stream_1 +// CHECK-SAME: !mlrt.future +func.func private @serving_default_stream_1(%arg0: tensor) { + // CHECK: [[tensor:%.*]] = tf_mlrt.await + // CHECK: tf_mlrt.executeop([[tensor]]) + %0 = "tf.StringFormat"(%arg0) {__op_key = 0: i32, device = "/job:localhost/replica:0/task:0/device:CPU:0", placeholder = "{}", strtemplate = "%s", summarize = 3 : i64, template = "Outside compiled {}"} : (tensor) -> tensor + "tf.PrintV2"(%0) {__op_key = 1: i32, device = "/job:localhost/replica:0/task:0/device:CPU:0", end = "\0A", output_stream = "stderr"} : (tensor) -> () + return +} + +func.func @callee(%arg: tensor) -> (tensor) { + func.return %arg: tensor +} + +// CHECK-LABEL: @executeop_input +func.func @executeop_input(%arg0: tensor) -> (tensor) { + // CHECK: [[async_out:%.*]] = tf_mlrt.batch_function + %2 = "tf.BatchFunction"(%arg0) {device = "/device:CPU:0", allowed_batch_sizes = [64], batch_timeout_micros = 1 : i64, batching_queue = "", container = "", f = @callee, max_batch_size = 256 : i64, num_batch_threads 
= 2 : i64, operand_segment_sizes = array, shared_name = ""} : (tensor) -> tensor + // CHECK-NEXT: mlrt.async([[async_out]]) {{.*}} : (!mlrt.future) + %3 = mlrt.async(%2) {callee = @serving_default_stream_1} : (tensor) -> !mlrt.async_handle + // CHECK: mlrt.await_handle + mlrt.await_handle %3 + // CHECK: return + // CHECK-SAME: !tf_mlrt.tensor + func.return %2 : tensor +} + +// ----- + +// Support pre-assigned op_key + +// CHECK-LABEL: @main +// CHECK-SAME: ([[input0:%.*]]: !tf_mlrt.tensor, [[promise_b:%.*]]: !mlrt.promise) +func.func @main(%input0: tensor, %promise_b: !mlrt.promise) { + %const = "tf.Const"() {__op_key = 0 : i32, value = dense<1> : tensor} : () -> tensor + // CHECK: [[a:%.*]] = tf_mlrt.executeop([[input0]], + // CHECK-SAME: AddV2 + // CHECK-SAME: op_key = 1 + // CHECK-NOT: __op_key + %a = "tf.AddV2"(%input0, %const) {__op_key = 1: i32}: (tensor, tensor) -> tensor + // CHECK: [[b:%.*]] = tf_mlrt.executeop([[a]]) + // CHECK-SAME: Abs + // CHECK-SAME: op_key = 2 + // CHECK-NOT: __op_key + %b = "tf.Abs"(%a) {__op_key = 2: i32 }: (tensor) -> tensor + // CHECK: tf_mlrt.promise [[promise_b]], [[b]] + "tf_mlrt.tf_promise"(%promise_b, %b) : (!mlrt.promise, tensor) -> () + // CHECK: return + return +} + +// ----- + +// Test future as input to promise + +// CHECK-LABEL: func @main_stream_0 +func.func @main_stream_0(%x: tensor, %p: !mlrt.promise) -> () { + // CHECK: [[y_future:%.*]] = tf_mlrt.async_executeop + %y = "tf.TestAsyncIdentity"(%x) {__op_key = 0: i32, T = i32} : (tensor) -> tensor + // CHECK: tf_mlrt.promise_future + // CHECK-SAME: [[y_future]] + "tf_mlrt.tf_promise"(%p, %y): (!mlrt.promise, tensor) -> () + return +} + +// CHECK-LABEL: @main +// CHECK-SAME: ([[input0:%.*]]: !tf_mlrt.tensor) +func.func @main(%input0: tensor) -> tensor { + // CHECK: [[promises:%.*]], [[futures:%.*]] = "tf_mlrt.allocate_futures" + // CHECK-SAME: num_futures = 1 + %promise_b, %future_b = "tf_mlrt.allocate_futures"() + {num_futures = 1 : i32, result_segment_sizes = array} 
: () -> + (!mlrt.promise, !mlrt.future) + + // CHECK: [[handle_0:%.*]] = mlrt.async([[input0]], [[promises]]) + // CHECK-SAME: callee = @main_stream_0 + %handle_0 = mlrt.async(%input0, %promise_b) + {callee = @main_stream_0} : + (tensor, !mlrt.promise) -> !mlrt.async_handle + + // CHECK: [[const:%.*]] = tf_mlrt.executeop + // CHECK-SAME: Const + %const = "tf.Const"() {__op_key = 1: i32, value = dense<2> : tensor} : () -> tensor + + // CHECK: [[b:%.*]] = tf_mlrt.await [[futures]] + %b = "tf_mlrt.tf_await"(%future_b) : (!mlrt.future) ->tensor + + // CHECK: [[result:%.*]] = tf_mlrt.executeop([[b]], [[const]]) + // CHECK-SAME: AddV2 + %result = "tf.AddV2"(%b, %const) {__op_key = 2: i32}: (tensor, tensor) -> tensor + + // CHECK: mlrt.await_handle [[handle_0]] + mlrt.await_handle %handle_0 + + // CHECK: return [[result]] + return %result : tensor +} + +// ----- + +// Test lowering of tf call ops + +// CHECK-LABEL: @callee +func.func @callee(%arg0: tensor) -> (tensor) { + func.return %arg0: tensor +} + +// CHECK-LABEL: func @call_test +func.func @call_test(%arg0: tensor) -> (tensor, tensor, tensor) { + %0 = "tf.Add"(%arg0, %arg0) {__op_key = 0, device = "/device:CPU:0"} : (tensor, tensor) -> tensor + // CHECK: [[results_0:%.*]] = call @callee( + // CHECK-SAME: (!tf_mlrt.tensor) -> !tf_mlrt.tensor + %1 = "tf.StatefulPartitionedCall"(%0) {config = "", config_proto = "", executor_type = "", f = @callee} : (tensor) -> (tensor) + // CHECK-NEXT: [[results_1:%.*]] = call @callee( + // CHECK-SAME: (!tf_mlrt.tensor) -> !tf_mlrt.tensor + %2 = "tf.PartitionedCall"(%0) {config = "", config_proto = "", executor_type = "", f = @callee} : (tensor) -> (tensor) + // CHECK-NEXT: [[results_2:%.*]] = call @callee( + // CHECK-SAME: (!tf_mlrt.tensor) -> !tf_mlrt.tensor + %3 = "tf.LegacyCall"(%0) {f = @callee} : (tensor) -> (tensor) + // CHECK: [[results_0]], [[results_1]], [[results_2]] + func.return %1, %2, %3 : tensor, tensor, tensor +} + +// CHECK-LABEL: @branch0 +func.func @branch0(%arg0: 
tensor, %arg1: tensor) -> tensor { + %0 = "tf.Add" (%arg0, %arg1) {__op_key = 1, device = "/device:CPU:0"} : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: @branch1 +func.func @branch1(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = "tf.Add" (%arg0, %arg1) {__op_key = 2, device = "/device:CPU:0"} : (tensor, tensor) -> tensor + %1 = "tf.Add" (%arg0, %0) {__op_key = 3, device = "/device:CPU:0"} : (tensor, tensor) -> tensor + func.return %1 : tensor +} + +// CHECK-LABEL: func @case_test +// CHECK-SAME: ([[tf_idx:%.*]]: !tf_mlrt.tensor, [[branch_arg0:%.*]]: !tf_mlrt.tensor, [[branch_arg1:%.*]]: !tf_mlrt.tensor) +func.func @case_test(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { + // CHECK: [[idx:%.*]] = tf_mlrt.tensor_to_int32 [[tf_idx]] + // CHECK-NEXT: [[out:%.*]] = mlrt.case [[idx]] [@branch0, @branch1]([[branch_arg0]], [[branch_arg1]]) + %0 = "tf.Case"(%arg0, %arg1, %arg2) {_lower_using_switch_merge = true, branches = [@branch0, @branch1], is_stateless = true} : (tensor, tensor, tensor) -> tensor + func.return %0 : tensor +} diff --git a/tensorflow/compiler/mlir/tfrt/tests/mlrt/tpu_conversions.mlir b/tensorflow/compiler/mlir/tfrt/tests/mlrt/tpu_conversions.mlir new file mode 100644 index 00000000000..87f906bcbe1 --- /dev/null +++ b/tensorflow/compiler/mlir/tfrt/tests/mlrt/tpu_conversions.mlir @@ -0,0 +1,168 @@ +// RUN: tf-tfrt-opt --split-input-file -pass-pipeline='builtin.module(pre-parallel-tf-to-mlrt{use-tpu-host-allocator-for-inputs=true},tf-mlrt-parallelization{tfrt-cost-threshold=4},tf-to-mlrt)' %s | FileCheck %s --dump-input=fail --dump-input-filter=all + +func.func @callee(%arg0: tensor, %arg1: tensor) -> (tensor) { + func.return %arg0: tensor +} + +// CHECK-LABEL: func @batch_function +func.func @batch_function(%arg0: tensor) -> (tensor) { + // CHECK: [[batch_result_future:%.*]] = tf_mlrt.batch_function + // CHECK: [[batch_result:%.*]] = tf_mlrt.await [[batch_result_future]] + // CHECK-NEXT: 
[[rendezvous_key_base:%.*]] = tf_mlrt_tpu.compile_and_execute([[batch_result]]) + // CHECK-NEXT: return [[rendezvous_key_base]] + %0 = "tf.BatchFunction"(%arg0, %arg0) {device = "/device:CPU:0", allowed_batch_sizes = [64], batch_timeout_micros = 1 : i64, batching_queue = "", container = "", f = @callee, max_batch_size = 256 : i64, num_batch_threads = 2 : i64, operand_segment_sizes = array, shared_name = ""} : (tensor, tensor) -> tensor + %1 = "tf.TPUCompileMlirAndExecute"(%0) {metadata = "metadata", mlir_module = "mlir_module", operand_segment_sizes = array, producer_name = "producer_name"} : (tensor) -> tensor + func.return %1 : tensor +} + +// ----- + +func.func @executeop_input(%arg0: tensor) -> (tensor, tensor) { + // CHECK-NOT: tf_mlrt.executeop( + // CHECK: [[device:%.*]] = tf_mlrt_tpu.get_tpu_host_device + // CHECK: [[cast:%.*]] = tf_mlrt.executeop.device([[device]]){{.*}}op: \22Cast\22 + // CHECK: [[rendezvous_key_base:%.*]], [[result_future:%.*]] = tf_mlrt_tpu.compile_and_execute([[cast]]) + // CHECK: tf_mlrt.await [[result_future]] + %0 = "tf.Cast"(%arg0) {__op_key = 0: i32, device = "/device:CPU:0"} : (tensor) -> tensor + %1, %2 = "tf.TPUCompileMlirAndExecute"(%0) {metadata = "metadata", mlir_module = "mlir_module", operand_segment_sizes = array, producer_name = "producer_name"} : (tensor) -> (tensor, tensor) + func.return %1, %2 : tensor, tensor +} + +// ----- + +func.func @executeop_side_effecting_input(%arg0: tensor>>, %indices: tensor) -> (tensor) { + // CHECK-NOT: tf_mlrt.executeop( + // CHECK: [[device:%.*]] = tf_mlrt_tpu.get_tpu_host_device + // CHECK: [[var:%.*]] = tf_mlrt.executeop.device([[device]]){{.*}}op: \22ResourceGather\22 + // CHECK: [[rendezvous_key_base:%.*]] = tf_mlrt_tpu.compile_and_execute([[var]]) + %0 = "tf.ResourceGather"(%arg0, %indices) {__op_key = 0: i32, device = "/device:CPU:0"} : (tensor>>, tensor) -> tensor + %1 = "tf.TPUCompileMlirAndExecute"(%0) {metadata = "metadata", mlir_module = "mlir_module", operand_segment_sizes = 
array, producer_name = "producer_name"} : (tensor) -> tensor + func.return %1 : tensor +} + +// ----- + +func.func @executeop_input_same_execute_op(%arg0: tensor, %arg1: tensor<2xf32>) -> (tensor) { + // CHECK-NOT: tf_mlrt.executeop( + // CHECK: [[device:%.*]] = tf_mlrt_tpu.get_tpu_host_device + // CHECK: [[split:%.*]]:2 = tf_mlrt.executeop.device([[device]]) + // CHECK: tf_mlrt_tpu.compile_and_execute([[split]]#0, [[split]]#1) + %0, %1 = "tf.Split"(%arg0, %arg1) {__op_key = 0: i32, device = "/device:CPU:0"} : (tensor, tensor<2xf32>) -> (tensor, tensor) + %2 = "tf.TPUCompileMlirAndExecute"(%0, %1) {metadata = "metadata", mlir_module = "mlir_module", operand_segment_sizes = array, producer_name = "producer_name"} : (tensor, tensor) -> tensor + func.return %2 : tensor +} + +// ----- + +// Test that inputs are lowered correctly when they form a DAG. + +// CHECK-LABEL: executeop_dag +func.func @executeop_dag(%arg0: tensor) -> (tensor) { + // CHECK-NEXT: tf_mlrt_tpu.get_tpu_host_device + // CHECK-NEXT: tf_mlrt.executeop.device{{.*}}op: \22Cast\22 + // CHECK-NEXT: tf_mlrt_tpu.get_tpu_host_device + // CHECK-NEXT: tf_mlrt.executeop.device{{.*}}op: \22Relu\22 + // CHECK-NEXT: tf_mlrt_tpu.compile_and_execute + %0 = "tf.Cast"(%arg0) {__op_key = 0: i32, device = "/device:CPU:0"} : (tensor) -> tensor + %1 = "tf.Relu"(%0) {__op_key = 1: i32, device = "/device:CPU:0"} : (tensor) -> (tensor) + %2 = "tf.TPUCompileMlirAndExecute"(%1, %0) {metadata = "metadata", mlir_module = "mlir_module", operand_segment_sizes = array, producer_name = "producer_name"} : (tensor, tensor) -> tensor + func.return %2 : tensor +} + +// ----- + +func.func @test_fuse_dynamic_dimension_ops(%arg0: tensor<*xi32>, %arg1: tensor<*x!tf_type.resource>, %arg2: tensor<*xi32>, %arg3: tensor<*xi32>, %arg4: tensor<*xi32>, %arg5: tensor, %arg6: tensor, %arg7: tensor) -> tensor<*xi32> { + %0 = "tf.ReadVariableOp"(%arg1) {__op_key = 0: i32, device = "/CPU:0"} : (tensor<*x!tf_type.resource>) -> tensor<*xi32> + %1 = 
"tf.Shape"(%arg0) {__op_key = 1: i32, device = "/CPU:0"} : (tensor<*xi32>) -> tensor + %2 = "tf.Shape"(%0) {__op_key = 2: i32, device = "/CPU:0"} : (tensor<*xi32>) -> tensor + // CHECK: [[rendezvous_key_base:%.*]], [[result_future:%.*]] = tf_mlrt_tpu.compile_and_execute + // CHECK-SAME: constant_operand_indices = array + // CHECK-SAME: num_operands = 4 + // CHECK-SAME: operands_with_static_shape = array + %rendezvous_key_base, %results = "tf.TPUCompileMlirAndExecute"(%arg0, %2, %0, %1, %arg5, %arg6, %arg7) {operands_with_static_shape = [0 : i32, 1 : i32, 3 : i32], metadata = "metadata", mlir_module = "mlir_module", operand_segment_sizes = array, producer_name = "producer_name"} : (tensor<*xi32>, tensor, tensor<*xi32>, tensor, tensor, tensor, tensor) -> (tensor<3x!tf_type.string>, tensor<*xi32>) + func.return %results : tensor<*xi32> +} + +// ----- + +// Test async output of tf.TPUCompileMlirAndExecute to function is converted + +// CHECK-LABEL: @executeop_input_stream_1 +// CHECK-SAME: ([[future:%.*]]: !mlrt.future +// CHECK: [[tensor:%.*]] = tf_mlrt.await [[future]] +// CHECK: tf_mlrt.executeop([[tensor]]) +// CHECK-SAME: StringFormat + +// CHECK-LABEL: @executeop_input +func.func @executeop_input(%arg0: tensor) -> (tensor) { + // CHECK: tf_mlrt.executeop + %0 = "tf.Cast"(%arg0) {__op_key = 0: i32, device = "/device:CPU:0"} : (tensor) -> tensor + // CHECK: [[rendezvous_key_base:%.*]], [[result:%.*]] = tf_mlrt_tpu.compile_and_execute + %1, %2 = "tf.TPUCompileMlirAndExecute"(%0) {metadata = "metadata", mlir_module = "mlir_module", operand_segment_sizes = array, producer_name = "producer_name"} : (tensor) -> (tensor, tensor) + %3 = "tf.StringFormat"(%2) {__op_key = 1: i32, device = "/job:localhost/replica:0/task:0/device:CPU:0", placeholder = "{}", strtemplate = "%s", summarize = 3 : i64, template = "Outside compiled {}"} : (tensor) -> tensor + "tf.PrintV2"(%3) {__op_key = 2: i32, device = "/job:localhost/replica:0/task:0/device:CPU:0", end = "\0A", output_stream = 
"stderr"} : (tensor) -> () + // CHECK: [[handle:%.*]] = mlrt.async([[result]]) + // CHECK-SAME: (!mlrt.future) + // CHECK: mlrt.await_handle [[handle]] + // CHECK: return [[rendezvous_key_base]] + // CHECK-SAME: !tf_mlrt.tensor + func.return %1 : tensor +} + +// ----- + +// Test constant arguments to tf.TPUCompileMlirAndExecute are preserved during parallelization. + +// CHECK-LABEL: @preserve_constant_args( +func.func @preserve_constant_args(%arg0: tensor, %arg1: tensor<*x!tf_type.resource>, %arg2: tensor<*x!tf_type.resource>, %arg3: tensor<*x!tf_type.resource>) -> (tensor) { + // CHECK-NOT: ReadVariableOp + // CHECK: mlrt.async( + %v0 = "tf.ReadVariableOp"(%arg1) {__op_key = 0: i32, device = "/CPU:0"} : (tensor<*x!tf_type.resource>) -> tensor + %v1 = "tf.ReadVariableOp"(%arg2) {__op_key = 1: i32, device = "/CPU:0"} : (tensor<*x!tf_type.resource>) -> tensor + // CHECK: [[cast:%.*]] = tf_mlrt.executeop( + // CHECK-SAME: ReadVariableOp + %v2 = "tf.ReadVariableOp"(%arg3) {__op_key = 2: i32, device = "/CPU:0"} : (tensor<*x!tf_type.resource>) -> tensor + // CHECK: [[cast:%.*]] = tf_mlrt.executeop.device + // CHECK-SAME: Cast + %0 = "tf.Cast"(%arg0) {__op_key = 3: i32, device = "/device:CPU:0"} : (tensor) -> tensor + // CHECK: tf_mlrt_tpu.compile_and_execute({{%.*}}, [[cast]] + // CHECK-SAME: constant_operand_indices = array + %1, %2 = "tf.TPUCompileMlirAndExecute"(%0, %v1, %0, %v2, %v0, %arg0) {metadata = "metadata", mlir_module = "mlir_module", operand_segment_sizes = array, producer_name = "producer_name"} : (tensor, tensor, tensor, tensor, tensor, tensor) -> (tensor, tensor) + func.return %2 : tensor +} + +// ----- + +func.func @executeop_input_async() -> (tensor, tensor) { + // CHECK-NOT: tf_mlrt.executeop( + // CHECK: [[device:%.*]] = tf_mlrt_tpu.get_tpu_host_device + // CHECK: [[recv_future:%.*]] = tf_mlrt.async_executeop.device([[device]]){{.*}}op: \22Recv\22 + // CHECK: [[recv:%.*]] = tf_mlrt.await [[recv_future]] + // CHECK: [[rendezvous_key_base:%.*]], 
[[result_future:%.*]] = tf_mlrt_tpu.compile_and_execute([[recv]]) + // CHECK: tf_mlrt.await [[result_future]] + %0 = "tf.Recv"() {__op_key = 0: i32, device = "/device:CPU:0", tensor_name = "tensor", send_device = "/device:CPU:0", send_device_incarnation = 0, recv_device = "/device:CPU:0"} : () -> tensor + %1, %2 = "tf.TPUCompileMlirAndExecute"(%0) {metadata = "metadata", mlir_module = "mlir_module", operand_segment_sizes = array, producer_name = "producer_name"} : (tensor) -> (tensor, tensor) + func.return %1, %2 : tensor, tensor +} + +// ----- + +// Test the output from TPU op is properly awaited before its use by map_fn. +// CHECK-LABEL: @main +// CHECK-SAME: ([[input0:%.*]]: !tf_mlrt.tensor, [[input1:%.*]]: !tf_mlrt.tensor) +func.func @main(%input0: tensor, %input1: tensor, %input2: tensor>> ) -> tensor { + %0 = "tf.Cast"(%input0) {__op_key = 0: i32, device = "/device:CPU:0"} : (tensor) -> tensor + // CHECK: tf_mlrt_tpu.compile_and_execute + %1, %2 = "tf.TPUCompileMlirAndExecute"(%0) {metadata = "metadata", mlir_module = "mlir_module", operand_segment_sizes = array, producer_name = "producer_name"} : (tensor) -> (tensor, tensor) + %max_iter = "tf.Const"() {__op_key = 1, value = dense<2> : tensor} : () -> tensor + // CHECK: tf_mlrt.map_fn + %result = "tf_mlrt.tf_map_fn"(%max_iter, %input2, %2) { operand_segment_sizes = array, body_fn = @NopMapFnBody, num_tensor_list_or_flow_in = 1 : i32} : (tensor, tensor>>, tensor) -> tensor + return %result : tensor +} + +// CHECK-LABEL: @NopMapFnBody +func.func private @NopMapFnBody(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor>>) -> () { + %const = "tf.Const"() {__op_key = 2 : i32, value = dense<1> : tensor} : () -> tensor + %a = "tf.AddV2"(%arg2, %const) {__op_key = 3: i32}: (tensor, tensor) -> tensor + return +} diff --git a/tensorflow/compiler/mlir/tfrt/tests/mlrt/while_to_map_fn.mlir b/tensorflow/compiler/mlir/tfrt/tests/mlrt/while_to_map_fn.mlir new file mode 100644 index 00000000000..27c92289a5b --- 
/dev/null +++ b/tensorflow/compiler/mlir/tfrt/tests/mlrt/while_to_map_fn.mlir @@ -0,0 +1,640 @@ +// RUN: tf-tfrt-opt -split-input-file -tf-mlrt-while-to-map-fn %s | FileCheck %s + +// Test a while to map_fn conversion in which the max iteration is hard coded inside the predicate body. + +// CHECK-LABEL: map/while_cond +func.func private @"map/while_cond"(%arg0: tensor, %arg1: tensor, %arg2: tensor>>, %arg3: tensor) -> tensor { + %cst = "tf.Const"() {device = "/job:localhost/replica:0/task:0/device:CPU:0", value = dense<3> : tensor} : () -> tensor + %0 = "tf.Less"(%arg0, %cst) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : (tensor, tensor) -> tensor + %1 = "tf.Less"(%arg1, %cst) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : (tensor, tensor) -> tensor + %2 = "tf.LogicalAnd"(%0, %1) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : (tensor, tensor) -> tensor + return %2 : tensor +} + +// CHECK-LABEL: map/while_body +func.func private @"map/while_body"(%arg0: tensor, %arg1: tensor, %arg2: tensor>>, %arg3: tensor) -> (tensor, tensor, tensor>>, tensor) { + %cst = "tf.Const"() {device = "/job:localhost/replica:0/task:0/device:CPU:0", value = dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00, 7.000000e+00, 8.000000e+00, 9.000000e+00]> : tensor<9xf32>} : () -> tensor<9xf32> + %cst_0 = "tf.Const"() {device = "/job:localhost/replica:0/task:0/device:CPU:0", value = dense<0> : tensor} : () -> tensor + %cst_1 = "tf.Const"() {device = "/job:localhost/replica:0/task:0/device:CPU:0", value = dense<[0, 1, 2]> : tensor<3xi32>} : () -> tensor<3xi32> + %cst_2 = "tf.Const"() {device = "/job:localhost/replica:0/task:0/device:CPU:0", value = dense<3> : tensor<2xi32>} : () -> tensor<2xi32> + %cst_3 = "tf.Const"() {device = "/job:localhost/replica:0/task:0/device:CPU:0", value = dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00, 7.000000e+00, 8.000000e+00]> : 
tensor<9xf32>} : () -> tensor<9xf32> + %cst_4 = "tf.Const"() {device = "/job:localhost/replica:0/task:0/device:CPU:0", value = dense<1> : tensor} : () -> tensor + %0 = "tf.AddV2"(%arg0, %cst_4) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : (tensor, tensor) -> tensor + %1 = "tf.Mul"(%arg3, %cst_3) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : (tensor, tensor<9xf32>) -> tensor<9xf32> + %2 = "tf.Reshape"(%1, %cst_2) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : (tensor<9xf32>, tensor<2xi32>) -> tensor<3x3xf32> + %3 = "tf.AddV2"(%arg1, %cst_4) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : (tensor, tensor) -> tensor + %4 = "tf.GatherV2"(%cst_1, %arg1, %cst_0) {batch_dims = 0 : i64, device = "/job:localhost/replica:0/task:0/device:CPU:0"} : (tensor<3xi32>, tensor, tensor) -> tensor + %5 = "tf.Cast"(%4) {Truncate = false, device = "/job:localhost/replica:0/task:0/device:CPU:0"} : (tensor) -> tensor + %6 = "tf.Mul"(%5, %cst) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : (tensor, tensor<9xf32>) -> tensor<9xf32> + %7 = "tf.Reshape"(%6, %cst_2) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : (tensor<9xf32>, tensor<2xi32>) -> tensor<3x3xf32> + %8 = "tf.MatMul"(%2, %7) {device = "/job:localhost/replica:0/task:0/device:CPU:0", transpose_a = false, transpose_b = false} : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32> + %9 = "tf.MatrixDeterminant"(%8) {T = f32, device = "/job:localhost/replica:0/task:0/device:CPU:0"} : (tensor<3x3xf32>) -> tensor + %10 = "tf.TensorListSetItem"(%arg2, %arg1, %9) {device = "/job:localhost/replica:0/task:0/device:CPU:0", resize_if_index_out_of_bounds = false} : (tensor>>, tensor, tensor) -> tensor>> + return %0, %3, %10, %arg3 : tensor, tensor, tensor>>, tensor +} + +// CHECK-LABEL: map/while_body/MapFnBody +// CHECK-SAME: (%arg0: !mlrt.future, %arg1: !mlrt.promise, %arg2: tensor, %arg3: tensor, %arg4: tensor) +// CHECK: [[det:%.*]] = "tf.MatrixDeterminant" +// 
CHECK-NEXT: [[ta_0:%.*]] = "tf_mlrt.tf_await"(%arg0) : (!mlrt.future) -> tensor>> +// CHECK-NEXT: [[ta_1:%.*]] = "tf.TensorListSetItem"([[ta_0]], %arg3, [[det]]) { +// CHECK-NEXT: "tf_mlrt.tf_promise"(%arg1, [[ta_1]]) : (!mlrt.promise, tensor>>) -> () +// CHECK-NEXT: return + +//CHECK-LABEL: @serving_default +func.func @serving_default(%arg0: tensor {tf.device = "/job:localhost/replica:0/task:0/device:CPU:0"}) -> tensor<3xf32> attributes {tf.entry_function = {control_outputs = "", inputs = "serving_default_input:0", outputs = "PartitionedCall:0"}} { + %cst = "tf.Const"() {device = "/job:localhost/replica:0/task:0/device:CPU:0", value = dense<0> : tensor} : () -> tensor + %cst_0 = "tf.Const"() {device = "/job:localhost/replica:0/task:0/device:CPU:0", value = dense<> : tensor<0xi32>} : () -> tensor<0xi32> + %cst_1 = "tf.Const"() {device = "/job:localhost/replica:0/task:0/device:CPU:0", value = dense<-1> : tensor} : () -> tensor + %cst_2 = "tf.Const"() {device = "/job:localhost/replica:0/task:0/device:CPU:0", value = dense<3> : tensor} : () -> tensor + // CHECK: [[tensor_list:%.*]] = "tf.TensorListReserve"([[shape:%.*]], [[reserve_size:%.*]]) { + %0 = "tf.TensorListReserve"(%cst_1, %cst_2) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : (tensor, tensor) -> tensor>> + // CHECK: [[map_fn_result:%.*]] = tf_mlrt.tf_map_fn([[reserve_size]], [[tensor_list]], %arg0) + // CHECK-SAME: {body_fn = @"map/while_body/MapFnBody", num_tensor_list_or_flow_in = 1 : i32} + // CHECK-NOT: tf.While + %1:4 = "tf.While"(%cst, %cst, %0, %arg0) {_lower_using_switch_merge = true, _num_original_outputs = 6 : i64, _read_only_resource_inputs = [], _xla_propagate_compile_time_consts = true, body = @"map/while_body", cond = @"map/while_cond", device = "/job:localhost/replica:0/task:0/device:CPU:0", is_stateless = true, parallel_iterations = 4 : i64, shape_invariant} : (tensor, tensor, tensor>>, tensor) -> (tensor, tensor, tensor>>, tensor) + // CHECK-NEXT: 
"tf.TensorListStack"([[map_fn_result]], %cst_0) { + %2 = "tf.TensorListStack"(%1#2, %cst_0) {device = "/job:localhost/replica:0/task:0/device:CPU:0", num_elements = 3 : i64} : (tensor>>, tensor<0xi32>) -> tensor<3xf32> + return %2 : tensor<3xf32> +} + +// ----- + +// Test a while to map_fn conversion in which max_iterations are passed +// into the predicate function. + +// CHECK-LABEL: @"map/while_cond" +func.func private @"map/while_cond"(%arg0: tensor, %arg1: tensor, %arg2: tensor>>, %arg3: tensor, %arg4: tensor>>, %arg5: tensor, %arg6: tensor) -> tensor { + %outputs = "tf.Less"(%arg0, %arg3) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : (tensor, tensor) -> tensor + %outputs_0 = "tf.Less"(%arg1, %arg3) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : (tensor, tensor) -> tensor + %outputs_2 = "tf.LogicalAnd"(%outputs_0, %outputs) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : (tensor, tensor) -> tensor + + return %outputs_2 : tensor +} + +// CHECK-LABEL: @"map/while_body" +func.func private @"map/while_body"(%arg0: tensor, %arg1: tensor, %arg2: tensor>>, %arg3: tensor, %arg4: tensor>>, %arg5: tensor, %arg6: tensor) -> (tensor, tensor, tensor>>, tensor, tensor>>, tensor, tensor) { + %outputs = "tf.Const"() {device = "/job:localhost/replica:0/task:0/device:CPU:0", value = dense<0> : tensor} : () -> tensor + %outputs_0 = "tf.Const"() {device = "/job:localhost/replica:0/task:0/device:CPU:0", value = dense<1> : tensor} : () -> tensor + %outputs_2 = "tf.AddV2"(%arg0, %outputs_0) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : (tensor, tensor) -> tensor + %outputs_4 = "tf.ReadVariableOp"(%arg4) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : (tensor>>) -> tensor<3x1xf32> + %outputs_6 = "tf.Identity"(%outputs_2) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : (tensor) -> tensor + %outputs_8 = "tf.MatMul"(%arg5, %outputs_4) {device = "/job:localhost/replica:0/task:0/device:CPU:0", transpose_a = 
false, transpose_b = false} : (tensor, tensor<3x1xf32>) -> tensor + %outputs_10 = "tf.AddV2"(%arg1, %outputs_0) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : (tensor, tensor) -> tensor + %outputs_12 = "tf.Identity"(%outputs_10) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : (tensor) -> tensor + %outputs_14 = "tf.GatherV2"(%arg6, %arg1, %outputs) {batch_dims = 0 : i64, device = "/job:localhost/replica:0/task:0/device:CPU:0"} : (tensor, tensor, tensor) -> tensor<4xf32> + %outputs_16 = "tf.AddV2"(%outputs_8, %outputs_14) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : (tensor, tensor<4xf32>) -> tensor + %outputs_18 = "tf.TensorListSetItem"(%arg2, %arg1, %outputs_16) {device = "/job:localhost/replica:0/task:0/device:CPU:0", resize_if_index_out_of_bounds = false} : (tensor>>, tensor, tensor) -> tensor>> + return %outputs_6, %outputs_12, %outputs_18, %arg3, %arg4, %arg5, %arg6 : tensor, tensor, tensor>>, tensor, tensor>>, tensor, tensor +} + +// CHECK-LABEL: @"map/while_body/MapFnBody" +// CHECK-SAME (%arg0: !mlrt.Future, %arg1: !mlrt.Promise, %arg2: tensor, %arg3: tensor, %arg4: tensor, %arg5: tensor>>, %arg6: tensor, %arg7: tensor) +// CHECK-NEXT: [[cst_0:%.*]] = "tf.Const" +// CHECK-NEXT: [[cst_1:%.*]] = "tf.Const" +// CHECK-NEXT: [[loop_counter:%.*]] = "tf.AddV2"(%arg2, [[cst_1]]) +// CHECK-NEXT: [[weight:%.*]] = "tf.ReadVariableOp"(%arg5) +// CHECK-NEXT: [[mpy:%.*]] = "tf.MatMul"(%arg6, [[weight]]) +// CHECK-NEXT: [[element_index:%.*]] = "tf.AddV2"(%arg3, [[cst_1]]) +// CHECK-NEXT: [[bias:%.*]] = "tf.GatherV2"(%arg7, %arg3, [[cst_0]]) +// CHECK-NEXT: [[res:%.*]] = "tf.AddV2"([[mpy]], [[bias]]) +// CHECK-NEXT: [[ta_0:%.*]] = "tf_mlrt.tf_await"(%arg0) +// CHECK-NEXT: [[ta_1:%.*]] = "tf.TensorListSetItem"([[ta_0]], %arg3, [[res]]) +// CHECK-NEXT: "tf_mlrt.tf_promise"(%arg1, [[ta_1]]) +// CHECK-NEXT: return + +// CHECK-LABEL: func @main_while +func.func @main_while(%arg0: tensor, %arg1: tensor) -> tensor { + %outputs = "tf.Const"() 
{device = "/job:localhost/replica:0/task:0/device:CPU:0", value = dense<[-1, 4]> : tensor<2xi32>} : () -> tensor<2xi32> + %outputs_0 = "tf.Const"() {device = "/job:localhost/replica:0/task:0/device:CPU:0", value = dense<-1> : tensor} : () -> tensor + %outputs_2 = "tf.Const"() {device = "/job:localhost/replica:0/task:0/device:CPU:0", value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32> + %outputs_4 = "tf.Const"() {device = "/job:localhost/replica:0/task:0/device:CPU:0", value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> + %outputs_6 = "tf.Const"() {device = "/job:localhost/replica:0/task:0/device:CPU:0", value = dense<0> : tensor} : () -> tensor + %outputs_8 = "tf.Const"() {device = "/job:localhost/replica:0/task:0/device:CPU:0", value = dense<1> : tensor} : () -> tensor + // CHECK: [[elems:%.*]] = "tf.VarHandleOp" + %outputs_10 = "tf.VarHandleOp"() {_xla_inferred_shapes = [#tf_type.shape<>], allowed_devices = [], container = "", device = "/job:localhost/replica:0/task:0/device:CPU:0", shared_name = "w"} : () -> tensor>> + %outputs_12 = "tf.Shape"(%arg1) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : (tensor) -> tensor<2xi32> + // CHECK: [[max_iter:%.*]] = "tf.StridedSlice" + %outputs_14 = "tf.StridedSlice"(%outputs_12, %outputs_2, %outputs_4, %outputs_4) {begin_mask = 0 : i64, device = "/job:localhost/replica:0/task:0/device:CPU:0", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + // CHECK: [[tensor_list:%.*]] = "tf.TensorListReserve" + %outputs_16 = "tf.TensorListReserve"(%outputs_0, %outputs_14) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : (tensor, tensor) -> tensor>> + // CHECK: tf_mlrt.tf_map_fn + // CHECK-SAME: ([[max_iter]], [[tensor_list]], [[max_iter]], [[elems]], %arg0, %arg1) + // CHECK-SAME: {body_fn = @"map/while_body/MapFnBody", num_tensor_list_or_flow_in = 1 : i32} + // CHECK-NOT: tf.while + 
%outputs_18:7 = "tf.While"(%outputs_6, %outputs_6, %outputs_16, %outputs_14, %outputs_10, %arg0, %arg1) {_lower_using_switch_merge = true, _num_original_outputs = 8 : i64, _read_only_resource_inputs = [6], _xla_propagate_compile_time_consts = true, body = @"map/while_body", cond = @"map/while_cond", device = "/job:localhost/replica:0/task:0/device:CPU:0", is_stateless = false, parallel_iterations = 10 : i64, shape_invariant} : (tensor, tensor, tensor>>, tensor, tensor>>, tensor, tensor) -> (tensor, tensor, tensor>>, tensor, tensor>>, tensor, tensor) + %outputs_20 = "tf.TensorListStack"(%outputs_18#2, %outputs) {device = "/job:localhost/replica:0/task:0/device:CPU:0", num_elements = -1 : i64} : (tensor>>, tensor<2xi32>) -> tensor + return %outputs_20 : tensor +} + +// ----- + +// Test a while to map_fn conversion in which the passed in max_iterations +// is not in typical location of %arg3 and there are identify chains in function bodies. + +// CHECK-LABEL: @map_while_cond_170 +func.func private @map_while_cond_170(%arg0: tensor {tf._user_specified_name = "map/while/loop_counter"}, %arg1: tensor {tf._user_specified_name = "map/while/maximum_iterations"}, %arg2: tensor, %arg3: tensor, %arg4: tensor<*x!tf_type.variant>, %arg5: tensor<*xf32>) -> tensor<*xi1> attributes {tf._construction_context = "kEagerRuntime", tf._original_func_name = "map_while_cond_17"} { + %outputs = "tf.Const"() {device = "", value = dense<16> : tensor} : () -> tensor + %outputs_0 = "tf.Less"(%arg0, %arg1) {device = ""} : (tensor, tensor) -> tensor<*xi1> + %outputs_2 = "tf.Less"(%arg2, %outputs) {device = ""} : (tensor, tensor) -> tensor<*xi1> + %outputs_4 = "tf.LogicalAnd"(%outputs_0, %outputs_2) {device = ""} : (tensor<*xi1>, tensor<*xi1>) -> tensor<*xi1> + %outputs_6 = "tf.Identity"(%outputs_4) {device = ""} : (tensor<*xi1>) -> tensor<*xi1> + return %outputs_6 : tensor<*xi1> +} + +// Original input argument list (loop_counter, max_iterations, element_index, tensor_list, read_only_tensor_list, 
scale) +// CHECK-LABEL: @map_while_body_180 +func.func private @map_while_body_180(%arg0: tensor {tf._user_specified_name = "map/while/loop_counter"}, %arg1: tensor {tf._user_specified_name = "map/while/maximum_iterations"}, %arg2: tensor, %arg3: tensor, %arg4: tensor {tf._user_specified_name = "map/TensorArrayUnstack/TensorListFromTensor"}, %arg5: tensor {tf._user_specified_name = "input"}) -> (tensor<*xi32>, tensor<*xi32>, tensor<*xi32>, tensor<*x!tf_type.variant>, tensor, tensor) attributes {tf._construction_context = "kEagerRuntime", tf._original_func_name = "map_while_body_18"} { + %outputs = "tf.Const"() {device = "", value = dense<16> : tensor<2xi32>} : () -> tensor<2xi32> + %outputs_0 = "tf.Const"() {device = "", value = dense<16> : tensor<2xi32>} : () -> tensor<2xi32> + %outputs_2 = "tf.Const"() {device = "", value = dense<> : tensor<0xi32>} : () -> tensor<0xi32> + %outputs_4 = "tf.Const"() {device = "", value = dense<1> : tensor} : () -> tensor + %outputs_6 = "tf.Const"() {device = "", value = dense<1> : tensor} : () -> tensor + %outputs_8 = "tf.Const"() {device = "", value = dense<1> : tensor} : () -> tensor + %outputs_10 = "tf.Const"() {device = "", value = dense<256> : tensor} : () -> tensor + %outputs_12 = "tf.Const"() {device = "", value = dense<0> : tensor} : () -> tensor + %outputs_14 = "tf.Range"(%outputs_12, %outputs_10, %outputs_8) {device = ""} : (tensor, tensor, tensor) -> tensor<*xi32> + %outputs_16 = "tf.Cast"(%outputs_14) {Truncate = false, device = ""} : (tensor<*xi32>) -> tensor<*xf32> + %outputs_18 = "tf.Const"() {device = "", value = dense<1> : tensor} : () -> tensor + %outputs_20 = "tf.Const"() {device = "", value = dense<257> : tensor} : () -> tensor + %outputs_22 = "tf.Const"() {device = "", value = dense<1> : tensor} : () -> tensor + %outputs_24 = "tf.Range"(%outputs_22, %outputs_20, %outputs_18) {device = ""} : (tensor, tensor, tensor) -> tensor<*xi32> + %outputs_26 = "tf.Cast"(%outputs_24) {Truncate = false, device = ""} : 
(tensor<*xi32>) -> tensor<*xf32> + %outputs_28 = "tf.Const"() {device = "", value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32> + %outputs_30 = "tf.Transpose"(%outputs_26, %outputs_28) {device = ""} : (tensor<*xf32>, tensor<1xi32>) -> tensor<*xf32> + %outputs_32 = "tf.AddV2"(%arg0, %outputs_6) {device = ""} : (tensor, tensor) -> tensor<*xi32> + %outputs_34 = "tf.Identity"(%outputs_32) {device = ""} : (tensor<*xi32>) -> tensor<*xi32> + %outputs_36 = "tf.Identity"(%arg1) {device = ""} : (tensor) -> tensor<*xi32> + %outputs_38 = "tf.Mul"(%outputs_16, %arg5) {device = ""} : (tensor<*xf32>, tensor) -> tensor<*xf32> + %outputs_40 = "tf.Reshape"(%outputs_38, %outputs) {device = ""} : (tensor<*xf32>, tensor<2xi32>) -> tensor<*xf32> + %outputs_42 = "tf.AddV2"(%arg2, %outputs_4) {device = ""} : (tensor, tensor) -> tensor<*xi32> + %outputs_44 = "tf.Identity"(%outputs_42) {device = ""} : (tensor<*xi32>) -> tensor<*xi32> + %outputs_46 = "tf.TensorListGetItem"(%arg4, %arg2, %outputs_2) {device = ""} : (tensor, tensor, tensor<0xi32>) -> tensor<*xi32> + %outputs_48 = "tf.Cast"(%outputs_46) {Truncate = false, device = ""} : (tensor<*xi32>) -> tensor<*xf32> + %outputs_50 = "tf.Mul"(%outputs_30, %outputs_48) {device = ""} : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> + %outputs_52 = "tf.Reshape"(%outputs_50, %outputs_0) {device = ""} : (tensor<*xf32>, tensor<2xi32>) -> tensor<*xf32> + %outputs_54 = "tf.MatMul"(%outputs_40, %outputs_52) {device = "", transpose_a = false, transpose_b = false} : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> + %outputs_56 = "tf.MatrixDeterminant"(%outputs_54) {T = f32, device = ""} : (tensor<*xf32>) -> tensor<*xf32> + %outputs_58 = "tf.TensorListSetItem"(%arg3, %arg2, %outputs_56) {device = "", resize_if_index_out_of_bounds = false} : (tensor, tensor, tensor<*xf32>) -> tensor<*x!tf_type.variant> + %outputs_60 = "tf.Identity"(%outputs_58) {device = ""} : (tensor<*x!tf_type.variant>) -> tensor<*x!tf_type.variant> + return %outputs_34, 
%outputs_36, %outputs_44, %outputs_60, %arg4, %arg5 : tensor<*xi32>, tensor<*xi32>, tensor<*xi32>, tensor<*x!tf_type.variant>, tensor, tensor +} + +// Converted input argument list (loop_counter, element_index, max_iterations, tensor_list, read_only_tensor_list, scale) +// CHECK-LABEL: @"map_while_body_180/MapFnBody" +// CHECK-SAME: (%arg0: !mlrt.future, %arg1: !mlrt.promise, %arg2: tensor {tf._user_specified_name = "map/while/loop_counter"}, %arg3: tensor, %arg4: tensor {tf._user_specified_name = "map/while/maximum_iterations"}, %arg5: tensor {tf._user_specified_name = "map/TensorArrayUnstack/TensorListFromTensor"}, %arg6: tensor {tf._user_specified_name = "input"}) +// CHECK: [[res:%.*]] = "tf.MatrixDeterminant" +// CHECK-NEXT: [[ta_0:%.*]] = "tf_mlrt.tf_await"(%arg0) +// CHECK-NEXT: [[ta_1:%.*]] = "tf.TensorListSetItem"([[ta_0]], %arg3, [[res]]) +// CHECK-NEXT: "tf_mlrt.tf_promise"(%arg1, [[ta_1]]) +// CHECK-NEXT: return + + +// CHECK-LABEL: __inference_while_from_map_fn_810 +// CHECK-SAME: ([[scale:%.*]]: tensor +func.func private @__inference_while_from_map_fn_810(%arg0: tensor {tf._user_specified_name = "input"}) -> tensor<*xf32> attributes {tf._construction_context = "kEagerRuntime", tf._original_func_name = "__inference_while_from_map_fn_81"} { + // CHECK: [[element_index:%.*]] = "tf.Const" + %outputs = "tf.Const"() {device = "", value = dense<0> : tensor} : () -> tensor + %outputs_0 = "tf.Const"() {device = "", value = dense<> : tensor<0xi32>} : () -> tensor<0xi32> + %outputs_2= "tf.Const"() {device = "", value = dense<-1> : tensor} : () -> tensor + %outputs_4 = "tf.Const"() {device = "", value = dense<16> : tensor} : () -> tensor + // CHECK: tf.TensorListReserve + %outputs_6 = "tf.TensorListReserve"(%outputs_2, %outputs_4) {device = ""} : (tensor, tensor) -> tensor>> + %outputs_8 = "tf.Const"() {device = "", value = dense<> : tensor<0xi32>} : () -> tensor<0xi32> + %outputs_10 = "tf.Const"() {device = "", value = dense<-1> : tensor} : () -> tensor + 
%outputs_12 = "tf.Const"() {device = "", value = dense<16> : tensor} : () -> tensor + // CHECK: [[tensor_list:%.*]] = "tf.TensorListReserve"([[shape:%.*]], [[reserve_size:%.*]]) { + %outputs_14 = "tf.TensorListReserve"(%outputs_10, %outputs_12) {device = ""} : (tensor, tensor) -> tensor>> + // CHECK-NEXT: [[loop_counter:%.*]] = "tf.Const" + %outputs_16 = "tf.Const"() {device = "", value = dense<0> : tensor} : () -> tensor + // CHECK-NEXT: [[max_iterations:%.*]] = "tf.Const" + %outputs_18 = "tf.Const"() {device = "", value = dense<16> : tensor} : () -> tensor + %outputs_20 = "tf.Const"() {device = "", value = dense<1> : tensor} : () -> tensor + %outputs_22 = "tf.Const"() {device = "", value = dense<16> : tensor} : () -> tensor + %outputs_24 = "tf.Const"() {device = "", value = dense<0> : tensor} : () -> tensor + %outputs_26 = "tf.Range"(%outputs_24, %outputs_22, %outputs_20) {device = ""} : (tensor, tensor, tensor) -> tensor<*xi32> + // CHECK: [[read_only_tensor_list:%.*]] = "tf.TensorListFromTensor" + %outputs_28 = "tf.TensorListFromTensor"(%outputs_26, %outputs_0) {device = ""} : (tensor<*xi32>, tensor<0xi32>) -> tensor<*x!tf_type.variant> +// CHECK: [[map_fn_out:%.*]] = tf_mlrt.tf_map_fn + // CHECK-SAME: ([[reserve_size]], [[tensor_list]], [[max_iterations]], [[read_only_tensor_list]], [[scale]]) + // CHECK-SAME: {body_fn = @"map_while_body_180/MapFnBody", num_tensor_list_or_flow_in = 1 : i32} + // CHECK-NOT: tf.While + %outputs_30:6 = "tf.While"(%outputs_16, %outputs_18, %outputs, %outputs_14, %outputs_28, %arg0) {T = [i32, i32, i32, !tf_type.variant, !tf_type.variant, f32], _lower_using_switch_merge = true, _num_original_outputs = 6 : i64, _read_only_resource_inputs = [], body = @map_while_body_180, cond = @map_while_cond_170, device = "", is_stateless = true, output_shapes = [#tf_type.shape<>, #tf_type.shape<>, #tf_type.shape<>, #tf_type.shape<>, #tf_type.shape<>, #tf_type.shape], parallel_iterations = 4 : i64, shape_invariant} : (tensor, tensor, tensor, 
tensor>>, tensor<*x!tf_type.variant>, tensor) -> (tensor, tensor, tensor, tensor, tensor, tensor) + // CHECK-NEXT: "tf.TensorListStack" + // CHECK-SAME: ([[map_fn_out]], + %outputs_32 = "tf.TensorListStack"(%outputs_30#3, %outputs_8) {device = "", num_elements = 16 : i64} : (tensor, tensor<0xi32>) -> tensor<*xf32> + %outputs_34 = "tf.Identity"(%outputs_32) {device = ""} : (tensor<*xf32>) -> tensor<*xf32> + return %outputs_34 : tensor<*xf32> +} + +// ----- + +// Test a while to map_fn conversion in which tensor array is used instead of +// tensor list. + +// CHECK-LABEL: map/while/LoopCond_cond +func.func private @"map/while/LoopCond_cond"(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>, %arg2: tensor, %arg3: tensor, %arg4: tensor<2x!tf_type.resource>>, %arg5: tensor, %arg6: tensor<2x!tf_type.resource>>) -> tensor { + %outputs = "tf.Less"(%arg0, %arg3) {device = ""} : (tensor<*xi32>, tensor) -> tensor<*xi1> + %outputs_0 = "tf.Less"(%arg1, %arg3) {device = ""} : (tensor<*xi32>, tensor) -> tensor<*xi1> + %outputs_2 = "tf.LogicalAnd"(%outputs, %outputs_0) {device = ""} : (tensor<*xi1>, tensor<*xi1>) -> tensor<*xi1> + %outputs_4 = "tf.ToBool"(%outputs_2) : (tensor<*xi1>) -> tensor + return %outputs_4 : tensor +} + +// CHECK-LABEL: map/while/LoopCond_body +func.func private @"map/while/LoopCond_body"(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>, %arg2: tensor, %arg3: tensor, %arg4: tensor<2x!tf_type.resource>>, %arg5: tensor, %arg6: tensor<2x!tf_type.resource>>) -> (tensor<*xi32>, tensor<*xi32>, tensor, tensor, tensor<2x!tf_type.resource>>, tensor, tensor<2x!tf_type.resource>>) { + %outputs = "tf.Const"() {value = dense<224> : tensor<2xi32>} : () -> tensor<2xi32> + %outputs_0 = "tf.Const"() {value = dense<0> : tensor} : () -> tensor + %outputs_2 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + %outputs_4 = "tf.Identity"(%arg0) {device = ""} : (tensor<*xi32>) -> tensor<*xi32> + %outputs_6 = "tf.AddV2"(%outputs_4, %outputs_2) {device = ""} : (tensor<*xi32>, tensor) -> 
tensor<*xi32> + %outputs_8 = "tf.Identity"(%arg1) {device = ""} : (tensor<*xi32>) -> tensor<*xi32> + %outputs_10 = "tf.AddV2"(%outputs_8, %outputs_2) {device = ""} : (tensor<*xi32>, tensor) -> tensor<*xi32> + %outputs_12 = "tf.Identity"(%arg2) {device = ""} : (tensor) -> tensor + %outputs_14 = "tf.TensorArrayReadV3"(%arg4, %outputs_8, %arg5) {device = ""} : (tensor<2x!tf_type.resource>>, tensor<*xi32>, tensor) -> tensor<*x!tf_type.string> + %outputs_16 = "tf.DecodeJpeg"(%outputs_14) {acceptable_fraction = 1.000000e+00 : f32, channels = 3 : i64, dct_method = "INTEGER_FAST", device = "", fancy_upscaling = true, ratio = 1 : i64, try_recover_truncated = false} : (tensor<*x!tf_type.string>) -> tensor + %outputs_18 = "tf.ExpandDims"(%outputs_16, %outputs_0) {device = ""} : (tensor, tensor) -> tensor<1x?x?x3xui8> + %outputs_20 = "tf.ResizeBilinear"(%outputs_18, %outputs) {align_corners = false, device = "", half_pixel_centers = false} : (tensor<1x?x?x3xui8>, tensor<2xi32>) -> tensor<1x224x224x3xf32> + %outputs_22 = "tf.Squeeze"(%outputs_20) {device = "", squeeze_dims = [0]} : (tensor<1x224x224x3xf32>) -> tensor<224x224x3xf32> + %outputs_24 = "tf.Cast"(%outputs_22) {Truncate = false, device = ""} : (tensor<224x224x3xf32>) -> tensor<224x224x3xui8> + %outputs_26 = "tf.TensorArrayWriteV3"(%arg6, %outputs_8, %outputs_24, %outputs_12) {device = ""} : (tensor<2x!tf_type.resource>>, tensor<*xi32>, tensor<224x224x3xui8>, tensor) -> tensor + return %outputs_6, %outputs_10, %outputs_26, %arg3, %arg4, %arg5, %arg6: tensor<*xi32>, tensor<*xi32>, tensor, tensor, tensor<2x!tf_type.resource>>, tensor, tensor<2x!tf_type.resource>> +} + +// CHECK-LABEL: @"map/while/LoopCond_body/MapFnBody" +// CHECK-NEXT: tf.Const +// CHECK-NEXT: tf.Const +// CHECK-NEXT: tf.Const +// CHECK-NEXT: tf.AddV2 +// CHECK-NEXT: tf.AddV2 +// CHECK-NEXT: tf.TensorArrayReadV3 +// CHECK-NEXT: tf.DecodeJpeg +// CHECK-NEXT: tf.ExpandDims +// CHECK-NEXT: tf.ResizeBilinear +// CHECK-NEXT: tf.Squeeze +// CHECK-NEXT: 
tf.Cast +// CHECK-NEXT: tf_mlrt.tf_await +// CHECK-NEXT: tf.TensorArrayWriteV3 +// CHECK-NEXT: tf_mlrt.tf_promise +// CHECK-NEXT: return + +//CHECK-LABEL: map_while_test +func.func @map_while_test(%arg0: tensor) -> tensor { + %outputs = "tf.Const"() {value = dense<0> : tensor} : () -> tensor<1xi32> + %outputs_0 = "tf.Const"() {value = dense<224> : tensor<2xi32>} : () -> tensor<2xi32> + %outputs_2 = "tf.Const"() {value = dense<0> : tensor} : () -> tensor + %outputs_4 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> + %outputs_6 = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32> + %outputs_8 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + %outputs_10 = "tf.Shape"(%arg0) {device = ""} : (tensor) -> tensor<1xi32> + // CHECK: [[max_iter:%.*]] = "tf.StridedSlice" + %outputs_12 = "tf.StridedSlice"(%outputs_10, %outputs_6, %outputs_4, %outputs_4) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + // CHECK-NEXT: tf.Range + %outputs_14 = "tf.Range"(%outputs_2, %outputs_12, %outputs_8) {device = ""} : (tensor, tensor, tensor) -> tensor + // CHECK-NEXT: [[handle_1:%.*]], [[flow_in_1:%.*]] = "tf.TensorArrayV3" + %outputs_16:2 = "tf.TensorArrayV3"(%outputs_12) {clear_after_read = true, device = "", dtype = !tf_type.string, dynamic_size = false, element_shape = #tf_type.shape<*>, identical_element_shapes = true, tensor_array_name = ""} : (tensor) -> (tensor<2x!tf_type.resource>>, tensor) + // CHECK-NEXT: [[handle_2:%.*]] = "tf.TensorArrayScatterV3" + %outputs_18 = "tf.TensorArrayScatterV3"(%outputs_16#0, %outputs_14, %arg0, %outputs_16#1) {device = ""} : (tensor<2x!tf_type.resource>>, tensor, tensor, tensor) -> tensor + // CHECK-NEXT: tf.Range + %outputs_20 = "tf.Range"(%outputs_2, %outputs_12, %outputs_8) {device = ""} : (tensor, tensor, tensor) -> tensor + // 
CHECK-NEXT: [[tensor_array:%.*]], [[flow_in:%.*]] = "tf.TensorArrayV3" + %outputs_22:2 = "tf.TensorArrayV3"(%outputs_12) {clear_after_read = true, device = "", dtype = ui8, dynamic_size = false, element_shape = #tf_type.shape<*>, identical_element_shapes = true, tensor_array_name = ""} : (tensor) -> (tensor<2x!tf_type.resource>>, tensor) + // CHECK-NEXT: tf_mlrt.tf_map_fn + // CHECK-SAME: ([[max_iter]], [[flow_in]], [[max_iter]], [[handle_1]], [[handle_2]], [[tensor_array]]) + // CHECK-SAME: {body_fn = @"map/while/LoopCond_body/MapFnBody", num_tensor_list_or_flow_in = 1 : i32} + // CHECK-NOT: tf.While + %outputs_24:7 = "tf.While"(%outputs, %outputs, %outputs_22#1, %outputs_12, %outputs_16#0, %outputs_18, %outputs_22#0) {_xla_propagate_compile_time_consts = true, body = @"map/while/LoopCond_body", cond = @"map/while/LoopCond_cond", device = "", is_stateless = false, parallel_iterations = 10 : i64, shape_invariant} : (tensor<1xi32>, tensor<1xi32>, tensor, tensor, tensor<2x!tf_type.resource>>, tensor, tensor<2x!tf_type.resource>>) -> (tensor<1xi32>, tensor<1xi32>, tensor, tensor, tensor<2x!tf_type.resource>>, tensor, tensor<2x!tf_type.resource>>) + // CHECK-NEXT: tf.TensorArrayGatherV3 + %outputs_26 = "tf.TensorArrayGatherV3"(%outputs_22#0, %outputs_20, %outputs_24#2) {device = "", element_shape = #tf_type.shape<224x224x3>} : (tensor<2x!tf_type.resource>>, tensor, tensor) -> tensor + return %outputs_26 : tensor +} + +// ----- +// Test non-applicable while is NOT converted to map_fn. 
+ +// CHECK-LABEL: func @while_cond_lt9 +func.func @while_cond_lt9(%arg0: tensor) -> tensor { + %0 = "tf.Const"() {device = "/device:CPU:0", value = dense<9> : tensor} : () -> tensor + %1 = "tf.Less"(%arg0, %0) {device = "/device:CPU:0"} : (tensor, tensor) -> tensor + func.return %1 : tensor +} + +// CHECK-LABEL: func @while_body_add2 +func.func @while_body_add2(%arg0: tensor) -> tensor { + %0 = "tf.Const"() {device = "/device:CPU:0", value = dense<2> : tensor} : () -> tensor + %1 = "tf.Add"(%arg0, %0) {device = "/device:CPU:0"} : (tensor, tensor) -> tensor + func.return %1 : tensor +} + +// CHECK-LABEL: func @while_test() +func.func @while_test() -> (tensor) { + %0 = "tf.Const"() {device = "/device:CPU:0", value = dense<0> : tensor} : () -> tensor + // CHECK: tf.While + %1 = "tf.While"(%0) { cond = @while_cond_lt9, body = @while_body_add2, is_stateless = false, parallel_iterations = 1} : (tensor) -> (tensor) + func.return %1 : tensor +} + +// ----- + +// Test a case that the while body has multiple tensor lists. 
+ +// CHECK-LABEL: tf.MultiListWhileRegion_body +func.func private @tf.MultiListWhileRegion_body(%arg0: tensor, %arg1: tensor, %arg2: tensor>>, %arg3: tensor>>, %arg4: tensor) -> (tensor, tensor, tensor>>, tensor>>, tensor) { + %cst = "tf.Const"() {device = "/job:localhost/replica:0/task:0/device:CPU:0", value = dense<[[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00, 7.000000e+00], [8.000000e+00, 9.000000e+00, 1.000000e+01, 1.100000e+01, 1.200000e+01, 1.300000e+01, 1.400000e+01, 1.500000e+01]]> : tensor<2x8xf32>} : () -> tensor<2x8xf32> + %cst_0 = "tf.Const"() {device = "/job:localhost/replica:0/task:0/device:CPU:0", value = dense<[[1.600000e+01, 1.700000e+01, 1.800000e+01, 1.900000e+01, 2.000000e+01, 2.100000e+01, 2.200000e+01, 2.300000e+01], [2.400000e+01, 2.500000e+01, 2.600000e+01, 2.700000e+01, 2.800000e+01, 2.900000e+01, 3.000000e+01, 3.100000e+01]]> : tensor<2x8xf32>} : () -> tensor<2x8xf32> + %cst_1 = "tf.Const"() {device = "/job:localhost/replica:0/task:0/device:CPU:0", value = dense<1> : tensor} : () -> tensor + %cst_2 = "tf.Const"() {device = "/job:localhost/replica:0/task:0/device:CPU:0", value = dense<0> : tensor} : () -> tensor + %0 = "tf.GatherV2"(%arg4, %cst_2, %cst_2) {batch_dims = 0 : i64, device = "/job:localhost/replica:0/task:0/device:CPU:0"} : (tensor, tensor, tensor) -> tensor + %1 = "tf.AddV2"(%arg0, %cst_1) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : (tensor, tensor) -> tensor + %2 = "tf.AddV2"(%arg1, %cst_1) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : (tensor, tensor) -> tensor + %3 = "tf.GatherV2"(%cst_0, %arg1, %cst_2) {batch_dims = 0 : i64, device = "/job:localhost/replica:0/task:0/device:CPU:0"} : (tensor<2x8xf32>, tensor, tensor) -> tensor<8xf32> + %4 = "tf.Mul"(%0, %3) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : (tensor, tensor<8xf32>) -> tensor<8xf32> + %5 = "tf.TensorListSetItem"(%arg2, %arg1, %4) {device = 
"/job:localhost/replica:0/task:0/device:CPU:0", resize_if_index_out_of_bounds = false} : (tensor>>, tensor, tensor<8xf32>) -> tensor>> + %6 = "tf.GatherV2"(%cst, %arg1, %cst_2) {batch_dims = 0 : i64, device = "/job:localhost/replica:0/task:0/device:CPU:0"} : (tensor<2x8xf32>, tensor, tensor) -> tensor<8xf32> + %7 = "tf.Mul"(%0, %6) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : (tensor, tensor<8xf32>) -> tensor<8xf32> + %8 = "tf.TensorListSetItem"(%arg3, %arg1, %7) {device = "/job:localhost/replica:0/task:0/device:CPU:0", resize_if_index_out_of_bounds = false} : (tensor>>, tensor, tensor<8xf32>) -> tensor>> + return %1, %2, %5, %8, %arg4 : tensor, tensor, tensor>>, tensor>>, tensor +} + +// CHECK-LABEL: tf.MultiListWhileRegion_body/MapFnBody +// CHECK-NEXT: tf.Const +// CHECK-NEXT: tf.Const +// CHECK-NEXT: tf.Const +// CHECK-NEXT: tf.Const +// CHECK-NEXT: tf.GatherV2 +// CHECK-NEXT: tf.AddV2 +// CHECK-NEXT: tf.AddV2 +// CHECK-NEXT: tf.GatherV2 +// CHECK-NEXT: tf.Mul +// CHECK-NEXT: tf.GatherV2 +// CHECK-NEXT: tf.Mul +// CHECK-NEXT: tf_mlrt.tf_await +// CHECK-NEXT: tf_mlrt.tf_await +// CHECK-NEXT: tf.TensorListSetItem +// CHECK-NEXT: tf.TensorListSetItem +// CHECK-NEXT: tf_mlrt.tf_promise +// CHECK-NEXT: tf_mlrt.tf_promise +// CHECK-NEXT: return + +// CHECK-LABEL: tf.MultiListWhileRegion_cond +func.func private @tf.MultiListWhileRegion_cond(%arg0: tensor, %arg1: tensor, %arg2: tensor>>, %arg3: tensor>>, %arg4: tensor) -> tensor { + %cst = "tf.Const"() {device = "/job:localhost/replica:0/task:0/device:CPU:0", value = dense<2> : tensor} : () -> tensor + %0 = "tf.Less"(%arg0, %cst) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : (tensor, tensor) -> tensor + %1 = "tf.Less"(%arg1, %cst) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : (tensor, tensor) -> tensor + %2 = "tf.LogicalAnd"(%0, %1) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : (tensor, tensor) -> tensor + return %2 : tensor +} + +// CHECK-LABEL: 
multilist_serving +func.func private @multilist_serving(%arg0: tensor {tf.device = "/job:localhost/replica:0/task:0/device:CPU:0"}) -> (tensor<2x8xf32>, tensor<2x8xf32>) { + %cst = "tf.Const"() {device = "/job:localhost/replica:0/task:0/device:CPU:0", value = dense<0> : tensor} : () -> tensor + %cst_0 = "tf.Const"() {device = "/job:localhost/replica:0/task:0/device:CPU:0", value = dense<8> : tensor<1xi32>} : () -> tensor<1xi32> + %cst_1 = "tf.Const"() {device = "/job:localhost/replica:0/task:0/device:CPU:0", value = dense<-1> : tensor} : () -> tensor + %cst_2 = "tf.Const"() {device = "/job:localhost/replica:0/task:0/device:CPU:0", value = dense<2> : tensor} : () -> tensor + // CHECK: TensorListReserve + %0 = "tf.TensorListReserve"(%cst_1, %cst_2) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : (tensor, tensor) -> tensor>> + // CHECK-NEXT: tf_mlrt.tf_map_fn + %1:5 = "tf.While"(%cst, %cst, %0, %0, %arg0) {_lower_using_switch_merge = true, _num_original_outputs = 8 : i64, _read_only_resource_inputs = [], _xla_propagate_compile_time_consts = true, body = @tf.MultiListWhileRegion_body, cond = @tf.MultiListWhileRegion_cond, device = "/job:localhost/replica:0/task:0/device:CPU:0", is_stateless = true, parallel_iterations = 4 : i64, shape_invariant} : (tensor, tensor, tensor>>, tensor>>, tensor) -> (tensor, tensor, tensor>>, tensor>>, tensor) + // CHECK-NEXT: TensorListStack + %2 = "tf.TensorListStack"(%1#2, %cst_0) {device = "/job:localhost/replica:0/task:0/device:CPU:0", num_elements = 2 : i64} : (tensor>>, tensor<1xi32>) -> tensor<2x8xf32> + %3 = "tf.TensorListStack"(%1#3, %cst_0) {device = "/job:localhost/replica:0/task:0/device:CPU:0", num_elements = 2 : i64} : (tensor>>, tensor<1xi32>) -> tensor<2x8xf32> + return %3, %2 : tensor<2x8xf32>, tensor<2x8xf32> +} + + +// ----- + +// Convert a while with multiple tensor array to map_fn + +// CHECK-LABEL: tf.WhileRegion1_body( +func.func private @tf.WhileRegion1_body(%arg0: tensor, %arg1: tensor, %arg2: tensor, 
%arg3: tensor, %arg4: tensor, %arg5: tensor<2x!tf_type.resource>>, %arg6: tensor<2x!tf_type.resource>>, %arg7: tensor<*xi32>) -> (tensor, tensor, tensor, tensor, tensor, tensor<2x!tf_type.resource>>, tensor<2x!tf_type.resource>>, tensor<*xi32>) { + %cst = "tf.Const"() {device = "/job:localhost/replica:0/task:0/device:CPU:0", value = dense<0> : tensor} : () -> tensor + %cst_0 = "tf.Const"() {device = "/job:localhost/replica:0/task:0/device:CPU:0", value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> + %cst_1 = "tf.Const"() {device = "/job:localhost/replica:0/task:0/device:CPU:0", value = dense<1> : tensor} : () -> tensor + %0 = "tf.AddV2"(%arg0, %cst_1) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : (tensor, tensor) -> tensor + %1 = "tf.AddV2"(%arg1, %cst_1) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : (tensor, tensor) -> tensor + %3 = "tf.RaggedTensorToVariant"(%arg7) {RAGGED_RANK = 0 : i64, Tsplits = i64, Tvalues = i32, batched_input = false, device = "/job:localhost/replica:0/task:0/device:CPU:0"} : (tensor<*xi32>) -> tensor + %4 = "tf.TensorArrayWriteV3"(%arg5, %arg1, %3, %arg2) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : (tensor<2x!tf_type.resource>>, tensor, tensor, tensor) -> tensor + %5 = "tf.RaggedTensorToVariant"(%arg7) {RAGGED_RANK = 0 : i64, Tsplits = i64, Tvalues = f32, batched_input = false, device = "/job:localhost/replica:0/task:0/device:CPU:0"} : (tensor<*xi32>) -> tensor + %6 = "tf.TensorArrayWriteV3"(%arg6, %arg1, %5, %arg3) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : (tensor<2x!tf_type.resource>>, tensor, tensor, tensor) -> tensor + return %0, %1, %4, %6, %arg4, %arg5, %arg6, %arg7 : tensor, tensor, tensor, tensor, tensor, tensor<2x!tf_type.resource>>, tensor<2x!tf_type.resource>>, tensor<*xi32> +} + +// CHECK-LABEL: func.func private @"tf.WhileRegion1_body/MapFnBody"(%arg0: !mlrt.future, %arg1: !mlrt.promise, %arg2: !mlrt.future, %arg3: !mlrt.promise, %arg4: tensor, %arg5: tensor, 
%arg6: tensor, %arg7: tensor<2x!tf_type.resource>>, %arg8: tensor<2x!tf_type.resource>>, %arg9: tensor<*xi32>) attributes {tfrt.cost_threshold = 4294967295 : i64} +// CHECK: [[result_0:%.*]] = "tf.RaggedTensorToVariant" +// CHECK: [[result_1:%.*]] = "tf.RaggedTensorToVariant" +// CHECK-NEXT: [[flow_in_0:%.*]] = "tf_mlrt.tf_await"(%arg0) : (!mlrt.future) -> tensor +// CHECK-NEXT: [[flow_in_1:%.*]] = "tf_mlrt.tf_await"(%arg2) : (!mlrt.future) -> tensor +// CHECK-NEXT: [[flow_out_0:%.*]] = "tf.TensorArrayWriteV3"(%arg7, %arg5, [[result_0]], [[flow_in_0]]) +// CHECK-NEXT: [[flow_out_1:%.*]] = "tf.TensorArrayWriteV3"(%arg8, %arg5, [[result_1]], [[flow_in_1]]) +// CHECK-NEXT: "tf_mlrt.tf_promise"(%arg1, [[flow_out_0]]) : (!mlrt.promise, tensor) -> () +// CHECK-NEXT: "tf_mlrt.tf_promise"(%arg3, [[flow_out_1]]) : (!mlrt.promise, tensor) -> () +// CHECK-NEXT: return + +// CHECK-LABEL: tf.WhileRegion1_cond +func.func private @tf.WhileRegion1_cond(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor, %arg5: tensor<2x!tf_type.resource>>, %arg6: tensor<2x!tf_type.resource>>, %arg7: tensor<*xi32>) -> (tensor) { + %0 = "tf.Less"(%arg0, %arg4) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : (tensor, tensor) -> tensor<*xi1> + %1 = "tf.Less"(%arg1, %arg4) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : (tensor, tensor) -> tensor<*xi1> + %2 = "tf.LogicalAnd"(%0, %1) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : (tensor<*xi1>, tensor<*xi1>) -> tensor<*xi1> + %3 = "tf.ToBool"(%2) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : (tensor<*xi1>) -> tensor + return %3 : tensor +} + +// CHECK-LABEL: func.func private @tf.WhileRegion2_body( +func.func private @tf.WhileRegion2_body(%arg0: tensor<*xi32>) -> (tensor, tensor) { + %cst = "tf.Const"() {device = "/job:localhost/replica:0/task:0/device:CPU:0", value = dense<[-1, 4]> : tensor<2xi32>} : () -> tensor<2xi32> + %cst_0 = "tf.Const"() {device = 
"/job:localhost/replica:0/task:0/device:CPU:0", value = dense<0> : tensor} : () -> tensor + %1 = "tf.Const"() {device = "/job:localhost/replica:0/task:0/device:CPU:0", value = dense<0> : tensor} : () -> tensor + %2 = "tf.Const"() {device = "/job:localhost/replica:0/task:0/device:CPU:0", value = dense<0> : tensor} : () -> tensor + %max_iter = "tf.Const"() {device = "/job:localhost/replica:0/task:0/device:CPU:0", value = dense<4> : tensor} : () -> tensor + // CHECK: "tf.TensorArrayV3" + %handle_12, %flow_13 = "tf.TensorArrayV3"(%max_iter) {device = "/job:localhost/replica:0/task:0/device:CPU:0", dtype = !tf_type.variant, dynamic_size = false, element_shape = #tf_type.shape<*>, identical_element_shapes = true, tensor_array_name = ""} : (tensor) -> (tensor<2x!tf_type.resource>>, tensor) + // CHECK: "tf.TensorArrayV3" + %handle_14, %flow_15 = "tf.TensorArrayV3"(%max_iter) {device = "/job:localhost/replica:0/task:0/device:CPU:0", dtype = !tf_type.variant, dynamic_size = false, element_shape = #tf_type.shape<*>, identical_element_shapes = true, tensor_array_name = ""} : (tensor) -> (tensor<2x!tf_type.resource>>, tensor) + // CHECK: tf_mlrt.tf_map_fn + // CHECK-SAME: {body_fn = @"tf.WhileRegion1_body/MapFnBody", num_tensor_list_or_flow_in = 2 : i32} + %4:8 = "tf.While"(%cst_0, %cst_0, %flow_13, %flow_15, %max_iter, %handle_12, %handle_14, %arg0) {body = @tf.WhileRegion1_body, cond = @tf.WhileRegion1_cond, device = "/job:localhost/replica:0/task:0/device:CPU:0", is_stateless = false, parallel_iterations = 10 : i64, shape_invariant} : (tensor, tensor, tensor, tensor, tensor, tensor<2x!tf_type.resource>>, tensor<2x!tf_type.resource>>, tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi32>, tensor, tensor, tensor, tensor<2x!tf_type.resource>>, tensor<2x!tf_type.resource>>, tensor<*xi32>) + // CHECK: TensorArrayGatherV3 + %5 = "tf.TensorArrayGatherV3"(%handle_12, %1, %4#2) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : (tensor<2x!tf_type.resource>>, tensor, tensor) -> 
tensor + // CHECK: TensorArrayGatherV3 + %6 = "tf.TensorArrayGatherV3"(%handle_14, %2, %4#3) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : (tensor<2x!tf_type.resource>>, tensor, tensor) -> tensor + return %5, %6 : tensor, tensor +} + +// ----- + +// Test a while to map_fn conversion in which tensor array is used instead of +// tensor list and the tensor array size and the number of iterations are bounded +// by separate constants of the same value. + +// CHECK-LABEL: map2/while/LoopCond_body +func.func private @"map2/while/LoopCond_body"(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>, %arg2: tensor, %arg3: tensor, %arg4: tensor<2x!tf_type.resource>>, %arg5: tensor, %arg6: tensor<2x!tf_type.resource>>) -> (tensor<*xi32>, tensor<*xi32>, tensor, tensor, tensor<2x!tf_type.resource>>, tensor, tensor<2x!tf_type.resource>>) { + %outputs = "tf.Const"() {value = dense<224> : tensor<2xi32>} : () -> tensor<2xi32> + %outputs_0 = "tf.Const"() {value = dense<0> : tensor} : () -> tensor + %outputs_2 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + %outputs_4 = "tf.Identity"(%arg0) {device = ""} : (tensor<*xi32>) -> tensor<*xi32> + %outputs_6 = "tf.AddV2"(%outputs_4, %outputs_2) {device = ""} : (tensor<*xi32>, tensor) -> tensor<*xi32> + %outputs_8 = "tf.Identity"(%arg1) {device = ""} : (tensor<*xi32>) -> tensor<*xi32> + %outputs_10 = "tf.AddV2"(%outputs_8, %outputs_2) {device = ""} : (tensor<*xi32>, tensor) -> tensor<*xi32> + %outputs_12 = "tf.Identity"(%arg2) {device = ""} : (tensor) -> tensor + %outputs_14 = "tf.TensorArrayReadV3"(%arg4, %outputs_8, %arg5) {device = ""} : (tensor<2x!tf_type.resource>>, tensor<*xi32>, tensor) -> tensor<*x!tf_type.string> + %outputs_16 = "tf.DecodeJpeg"(%outputs_14) {acceptable_fraction = 1.000000e+00 : f32, channels = 3 : i64, dct_method = "INTEGER_FAST", device = "", fancy_upscaling = true, ratio = 1 : i64, try_recover_truncated = false} : (tensor<*x!tf_type.string>) -> tensor + %outputs_18 = "tf.ExpandDims"(%outputs_16, 
%outputs_0) {device = ""} : (tensor, tensor) -> tensor<1x?x?x3xui8> + %outputs_20 = "tf.ResizeBilinear"(%outputs_18, %outputs) {align_corners = false, device = "", half_pixel_centers = false} : (tensor<1x?x?x3xui8>, tensor<2xi32>) -> tensor<1x224x224x3xf32> + %outputs_22 = "tf.Squeeze"(%outputs_20) {device = "", squeeze_dims = [0]} : (tensor<1x224x224x3xf32>) -> tensor<224x224x3xf32> + %outputs_24 = "tf.Cast"(%outputs_22) {Truncate = false, device = ""} : (tensor<224x224x3xf32>) -> tensor<224x224x3xui8> + %outputs_26 = "tf.TensorArrayWriteV3"(%arg6, %outputs_8, %outputs_24, %outputs_12) {device = ""} : (tensor<2x!tf_type.resource>>, tensor<*xi32>, tensor<224x224x3xui8>, tensor) -> tensor + return %outputs_6, %outputs_10, %outputs_26, %arg3, %arg4, %arg5, %arg6: tensor<*xi32>, tensor<*xi32>, tensor, tensor, tensor<2x!tf_type.resource>>, tensor, tensor<2x!tf_type.resource>> +} + +// CHECK-LABEL: @"map2/while/LoopCond_body/MapFnBody" +// CHECK-NEXT: tf.Const +// CHECK-NEXT: tf.Const +// CHECK-NEXT: tf.Const +// CHECK-NEXT: tf.AddV2 +// CHECK-NEXT: tf.AddV2 +// CHECK-NEXT: tf.TensorArrayReadV3 +// CHECK-NEXT: tf.DecodeJpeg +// CHECK-NEXT: tf.ExpandDims +// CHECK-NEXT: tf.ResizeBilinear +// CHECK-NEXT: tf.Squeeze +// CHECK-NEXT: tf.Cast +// CHECK-NEXT: tf_mlrt.tf_await +// CHECK-NEXT: tf.TensorArrayWriteV3 +// CHECK-NEXT: tf_mlrt.tf_promise +// CHECK-NEXT: return + +// CHECK-LABEL: map2/while/LoopCond_cond +func.func private @"map2/while/LoopCond_cond"(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>, %arg2: tensor, %arg3: tensor, %arg4: tensor<2x!tf_type.resource>>, %arg5: tensor, %arg6: tensor<2x!tf_type.resource>>) -> tensor { + %cst = "tf.Const"() {value = dense<224> : tensor} : () -> tensor + %outputs = "tf.Less"(%arg0, %cst) {device = ""} : (tensor<*xi32>, tensor) -> tensor<*xi1> + %outputs_0 = "tf.Less"(%arg1, %cst) {device = ""} : (tensor<*xi32>, tensor) -> tensor<*xi1> + %outputs_2 = "tf.LogicalAnd"(%outputs, %outputs_0) {device = ""} : (tensor<*xi1>, tensor<*xi1>) 
-> tensor<*xi1> + %outputs_4 = "tf.ToBool"(%outputs_2) : (tensor<*xi1>) -> tensor + return %outputs_4 : tensor +} + +//CHECK-LABEL: map2_while_test +func.func private @map2_while_test(%arg0: tensor) -> tensor { + // CHECK-NEXT: tf.Const + %outputs = "tf.Const"() {value = dense<0> : tensor} : () -> tensor<1xi32> + // CHECK-NEXT: [[max_iter:%.*]] = "tf.Const" + %cst_0 = "tf.Const"() {value = dense<224> : tensor} : () -> tensor + %cst_1 = "tf.Const"() {value = dense<256> : tensor} : () -> tensor + %outputs_2 = "tf.Const"() {value = dense<0> : tensor} : () -> tensor + %outputs_4 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> + %outputs_6 = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32> + %outputs_8 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + %outputs_10 = "tf.Shape"(%arg0) {device = ""} : (tensor) -> tensor<1xi32> + // CHECK: tf.Range + %outputs_14 = "tf.Range"(%outputs_2, %cst_0, %outputs_8) {device = ""} : (tensor, tensor, tensor) -> tensor + // CHECK-NEXT: tf.TensorArrayV3 + %outputs_16:2 = "tf.TensorArrayV3"(%cst_0) {clear_after_read = true, device = "", dtype = !tf_type.string, dynamic_size = false, element_shape = #tf_type.shape<*>, identical_element_shapes = true, tensor_array_name = ""} : (tensor) -> (tensor<2x!tf_type.resource>>, tensor) + // CHECK-NEXT: tf.TensorArrayScatterV3 + %outputs_18 = "tf.TensorArrayScatterV3"(%outputs_16#0, %outputs_14, %arg0, %outputs_16#1) {device = ""} : (tensor<2x!tf_type.resource>>, tensor, tensor, tensor) -> tensor + // CHECK-NEXT: tf.Range + %outputs_20 = "tf.Range"(%outputs_2, %cst_0, %outputs_8) {device = ""} : (tensor, tensor, tensor) -> tensor + // CHECK-NEXT: [[tensor_array:%.*]], [[flow_in:%.*]] = "tf.TensorArrayV3" + %outputs_22:2 = "tf.TensorArrayV3"(%cst_0) {clear_after_read = true, device = "", dtype = ui8, dynamic_size = false, element_shape = #tf_type.shape<*>, identical_element_shapes = true, tensor_array_name = ""} : (tensor) -> 
(tensor<2x!tf_type.resource>>, tensor) + // CHECK-NEXT: tf_mlrt.tf_map_fn + // CHECK-SAME: ([[max_iter]], [[flow_in]], %cst_1 + // CHECK-SAME: {body_fn = @"map2/while/LoopCond_body/MapFnBody", num_tensor_list_or_flow_in = 1 : i32} + // CHECK-NOT: tf.While + %outputs_24:7 = "tf.While"(%outputs, %outputs, %outputs_22#1, %cst_1, %outputs_16#0, %outputs_18, %outputs_22#0) {_xla_propagate_compile_time_consts = true, body = @"map2/while/LoopCond_body", cond = @"map2/while/LoopCond_cond", device = "", is_stateless = false, parallel_iterations = 10 : i64, shape_invariant} : (tensor<1xi32>, tensor<1xi32>, tensor, tensor, tensor<2x!tf_type.resource>>, tensor, tensor<2x!tf_type.resource>>) -> (tensor<1xi32>, tensor<1xi32>, tensor, tensor, tensor<2x!tf_type.resource>>, tensor, tensor<2x!tf_type.resource>>) + // CHECK-NEXT: tf.TensorArrayGatherV3 + %outputs_26 = "tf.TensorArrayGatherV3"(%outputs_22#0, %outputs_20, %outputs_24#2) {device = "", element_shape = #tf_type.shape<224x224x3>} : (tensor<2x!tf_type.resource>>, tensor, tensor) -> tensor + return %outputs_26 : tensor +} + +// ----- +// Test a nest while in which the while body is after the usage. 
+ +// CHECK-LABEL: nested_while +func.func @nested_while(%arg0: tensor {tf.device = "/job:localhost/replica:0/task:0/device:CPU:0"}) -> (tensor<16x16x?xf32>) { + %cst = "tf.Const"() {device = "/job:localhost/replica:0/task:0/device:CPU:0", value = dense<0> : tensor} : () -> tensor + %cst_0 = "tf.Const"() {device = "/job:localhost/replica:0/task:0/device:CPU:0", value = dense<[16, -1]> : tensor<2xi32>} : () -> tensor<2xi32> + %cst_1 = "tf.Const"() {device = "/job:localhost/replica:0/task:0/device:CPU:0", value = dense<-1> : tensor} : () -> tensor + %cst_2 = "tf.Const"() {device = "/job:localhost/replica:0/task:0/device:CPU:0", value = dense<16> : tensor} : () -> tensor + // CHECK: tf.TensorListReserve + %0 = "tf.TensorListReserve"(%cst_1, %cst_2) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : (tensor, tensor) -> tensor>> + // CHECK-NEXT: tf_mlrt.tf_map_fn + %1:4 = "tf.While"(%cst, %cst, %0, %arg0) {_lower_using_switch_merge = true, _num_original_outputs = 6 : i64, _read_only_resource_inputs = [], _xla_propagate_compile_time_consts = true, body = @tf.NestedWhileRegion1_body, cond = @tf.NestedWhileRegion1_cond, device = "/job:localhost/replica:0/task:0/device:CPU:0", is_stateless = true, parallel_iterations = 4 : i64, shape_invariant} : (tensor, tensor, tensor>>, tensor) -> (tensor, tensor, tensor>>, tensor) + %2 = "tf.TensorListStack"(%1#2, %cst_0) {device = "/job:localhost/replica:0/task:0/device:CPU:0", num_elements = 16 : i64} : (tensor>>, tensor<2xi32>) -> tensor<16x16x?xf32> + return %2 : tensor<16x16x?xf32> +} +// CHECK-LABEL: tf.NestedWhileRegion1_body +func.func private @tf.NestedWhileRegion1_body(%arg0: tensor, %arg1: tensor, %arg2: tensor>>, %arg3: tensor) -> (tensor, tensor, tensor>>, tensor) { + %cst = "tf.Const"() {device = "/job:localhost/replica:0/task:0/device:CPU:0", value = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32> + %cst_0 = "tf.Const"() {device = "/job:localhost/replica:0/task:0/device:CPU:0", value = dense<0> : tensor} : () -> 
tensor + %cst_1 = "tf.Const"() {device = "/job:localhost/replica:0/task:0/device:CPU:0", value = dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : tensor<16xi32>} : () -> tensor<16xi32> + %cst_2 = "tf.Const"() {device = "/job:localhost/replica:0/task:0/device:CPU:0", value = dense<1> : tensor} : () -> tensor + %cst_3 = "tf.Const"() {device = "/job:localhost/replica:0/task:0/device:CPU:0", value = dense<16> : tensor} : () -> tensor + %cst_4 = "tf.Const"() {device = "/job:localhost/replica:0/task:0/device:CPU:0", value = dense<-1> : tensor} : () -> tensor + %0 = "tf.TensorListReserve"(%cst_4, %cst_3) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : (tensor, tensor) -> tensor>> + %1 = "tf.AddV2"(%arg0, %cst_2) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : (tensor, tensor) -> tensor + %2 = "tf.AddV2"(%arg1, %cst_2) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : (tensor, tensor) -> tensor + %3 = "tf.GatherV2"(%cst_1, %arg1, %cst_0) {batch_dims = 0 : i64, device = "/job:localhost/replica:0/task:0/device:CPU:0"} : (tensor<16xi32>, tensor, tensor) -> tensor + %4 = "tf.Cast"(%3) {Truncate = false, device = "/job:localhost/replica:0/task:0/device:CPU:0"} : (tensor) -> tensor + %5 = "tf.Mul"(%arg3, %4) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : (tensor, tensor) -> tensor + %6:4 = "tf.While"(%cst_0, %cst_0, %0, %5) {_lower_using_switch_merge = true, _num_original_outputs = 6 : i64, _read_only_resource_inputs = [], _xla_propagate_compile_time_consts = true, body = @tf.NestedWhileRegion_body, cond = @tf.NestedWhileRegion_cond, device = "/job:localhost/replica:0/task:0/device:CPU:0", is_stateless = true, parallel_iterations = 10 : i64, shape_invariant} : (tensor, tensor, tensor>>, tensor) -> (tensor, tensor, tensor>>, tensor) + %7 = "tf.TensorListStack"(%6#2, %cst) {device = "/job:localhost/replica:0/task:0/device:CPU:0", num_elements = 16 : i64} : (tensor>>, tensor<1xi32>) -> tensor<16x?xf32> + %8 = 
"tf.TensorListSetItem"(%arg2, %arg1, %7) {device = "/job:localhost/replica:0/task:0/device:CPU:0", resize_if_index_out_of_bounds = false} : (tensor>>, tensor, tensor<16x?xf32>) -> tensor>> + return %1, %2, %8, %arg3 : tensor, tensor, tensor>>, tensor +} + +//CHECK-LABEL: @"tf.NestedWhileRegion1_body/MapFnBody"(%arg0: !mlrt.future, %arg1: !mlrt.promise, %arg2: tensor, %arg3: tensor, %arg4: tensor) +// CHECK: tf.TensorListReserve +// CHECK-NEXT: tf.AddV2 +// CHECK-NEXT: tf.AddV2 +// CHECK-NEXT: tf.GatherV2 +// CHECK-NEXT: tf.Cast +// CHECK-NEXT: tf.Mul +// CHECK-NEXT: tf_mlrt.tf_map_fn +// CHECK-NEXT: tf.TensorListStack +// CHECK-NEXT: tf_mlrt.tf_await +// CHECK-NEXT: tf.TensorListSetItem +// CHECK-NEXT: tf_mlrt.tf_promise +// CHECK-NEXT: return + +// CHECK-LABEL: tf.NestedWhileRegion1_cond +func.func private @tf.NestedWhileRegion1_cond(%arg0: tensor, %arg1: tensor, %arg2: tensor>>, %arg3: tensor) -> tensor { + %cst = "tf.Const"() {device = "/job:localhost/replica:0/task:0/device:CPU:0", value = dense<16> : tensor} : () -> tensor + %0 = "tf.Less"(%arg0, %cst) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : (tensor, tensor) -> tensor + %1 = "tf.Less"(%arg1, %cst) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : (tensor, tensor) -> tensor + %2 = "tf.LogicalAnd"(%0, %1) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : (tensor, tensor) -> tensor + return %2 : tensor +} +// CHECK-LABEL: tf.NestedWhileRegion_body +func.func private @tf.NestedWhileRegion_body(%arg0: tensor, %arg1: tensor, %arg2: tensor>>, %arg3: tensor) -> (tensor, tensor, tensor>>, tensor) { + %cst = "tf.Const"() {device = "/job:localhost/replica:0/task:0/device:CPU:0", value = dense<0> : tensor} : () -> tensor + %cst_0 = "tf.Const"() {device = "/job:localhost/replica:0/task:0/device:CPU:0", value = dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : tensor<16xi32>} : () -> tensor<16xi32> + %cst_1 = "tf.Const"() {device = 
"/job:localhost/replica:0/task:0/device:CPU:0", value = dense<1> : tensor} : () -> tensor + %0 = "tf.AddV2"(%arg0, %cst_1) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : (tensor, tensor) -> tensor + %1 = "tf.AddV2"(%arg1, %cst_1) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : (tensor, tensor) -> tensor + %2 = "tf.GatherV2"(%cst_0, %arg1, %cst) {batch_dims = 0 : i64, device = "/job:localhost/replica:0/task:0/device:CPU:0"} : (tensor<16xi32>, tensor, tensor) -> tensor + %3 = "tf.Cast"(%2) {Truncate = false, device = "/job:localhost/replica:0/task:0/device:CPU:0"} : (tensor) -> tensor + %4 = "tf.Mul"(%arg3, %3) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : (tensor, tensor) -> tensor + %5 = "tf.TensorListSetItem"(%arg2, %arg1, %4) {device = "/job:localhost/replica:0/task:0/device:CPU:0", resize_if_index_out_of_bounds = false} : (tensor>>, tensor, tensor) -> tensor>> + return %0, %1, %5, %arg3 : tensor, tensor, tensor>>, tensor +} + +// CHECK-LABEL: tf.NestedWhileRegion_body/MapFnBody +// CHECK: tf.AddV2 +// CHECK-NEXT: tf.AddV2 +// CHECK-NEXT: tf.GatherV2 +// CHECK-NEXT: tf.Cast +// CHECK-NEXT: tf.Mul +// CHECK-NEXT: tf_mlrt.tf_await +// CHECK-NEXT: tf.TensorListSetItem +// CHECK-NEXT: "tf_mlrt.tf_promise +// CHECK-NEXT: return + +// CHECK-LABEL: tf.NestedWhileRegion_cond +func.func private @tf.NestedWhileRegion_cond(%arg0: tensor, %arg1: tensor, %arg2: tensor>>, %arg3: tensor) -> tensor { + %cst = "tf.Const"() {device = "/job:localhost/replica:0/task:0/device:CPU:0", value = dense<16> : tensor} : () -> tensor + %0 = "tf.Less"(%arg0, %cst) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : (tensor, tensor) -> tensor + %1 = "tf.Less"(%arg1, %cst) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : (tensor, tensor) -> tensor + %2 = "tf.LogicalAnd"(%0, %1) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : (tensor, tensor) -> tensor + return %2 : tensor +} + diff --git 
a/tensorflow/compiler/mlir/tfrt/tests/saved_model/saved_model_test.cc b/tensorflow/compiler/mlir/tfrt/tests/saved_model/saved_model_test.cc index 995bb242fe3..6f6aafa566d 100644 --- a/tensorflow/compiler/mlir/tfrt/tests/saved_model/saved_model_test.cc +++ b/tensorflow/compiler/mlir/tfrt/tests/saved_model/saved_model_test.cc @@ -15,6 +15,10 @@ limitations under the License. #include "tensorflow/compiler/mlir/tfrt/saved_model/saved_model.h" +#include +#include +#include + #include "absl/strings/match.h" #include "mlir/IR/Dialect.h" // from @llvm-project #include "mlir/Parser/Parser.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/BUILD b/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/BUILD index 4badbc11669..c73d06bd8ff 100644 --- a/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/BUILD +++ b/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/BUILD @@ -7,6 +7,7 @@ package( ) glob_lit_tests( + name = "all_tests", data = [":test_utilities"], driver = "//tensorflow/compiler/mlir:run_lit.sh", features = if_oss(["--path=org_tensorflow/tensorflow/compiler/mlir/tfrt"]), diff --git a/tensorflow/compiler/mlir/tfrt/tf-tfrt-opt.cc b/tensorflow/compiler/mlir/tfrt/tf-tfrt-opt.cc index eb5615dc2c6..eb2006e6849 100644 --- a/tensorflow/compiler/mlir/tfrt/tf-tfrt-opt.cc +++ b/tensorflow/compiler/mlir/tfrt/tf-tfrt-opt.cc @@ -17,10 +17,13 @@ limitations under the License. 
#include "mlir/InitAllDialects.h" // from @llvm-project #include "mlir/InitAllPasses.h" // from @llvm-project #include "mlir/Tools/mlir-opt/MlirOptMain.h" // from @llvm-project +#include "mlir/Transforms/Passes.h" // from @llvm-project #include "tensorflow/compiler/mlir/init_mlir.h" #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" +#include "tensorflow/compiler/mlir/tfrt/ir/mlrt/mlrt_dialect.h" +#include "tensorflow/compiler/mlir/tfrt/ir/mlrt/tf_mlrt_ops.h" #include "tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback.h" #include "tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback_async.h" #include "tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback_sync.h" @@ -28,6 +31,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_passes.h" #include "tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_test_passes.h" #include "tensorflow/compiler/mlir/tfrt/transforms/gpu_passes.h" +#include "tensorflow/compiler/mlir/tfrt/transforms/mlrt/passes.h" #include "tensorflow/compiler/mlir/tfrt/transforms/passes.h" #include "tensorflow/compiler/xla/mlir_hlo/gml_st/IR/gml_st_ops.h" #include "tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/passes.h" @@ -38,6 +42,8 @@ int main(int argc, char **argv) { tensorflow::InitMlir y(&argc, &argv); mlir::registerAllPasses(); + mlir::registerInlinerPass(); + mlir::registerTensorFlowPasses(); // Register passes for TF->JitRt compilation. 
@@ -45,6 +51,8 @@ int main(int argc, char **argv) { registerTfJitRtTestPasses(); mlir::gml_st::registerGmlStPasses(); + tensorflow::mlrt_compiler::RegisterMlrtPasses(); + mlir::DialectRegistry registry; mlir::registerAllDialects(registry); mlir::RegisterAllTensorFlowDialects(registry); @@ -56,6 +64,8 @@ int main(int argc, char **argv) { registry.insert(); registry.insert(); registry.insert(); + registry.insert(); tensorflow::RegisterTPUDialects(®istry); tensorflow::RegisterGpuDialects(®istry); diff --git a/tensorflow/compiler/mlir/tfrt/transforms/mlrt/BUILD b/tensorflow/compiler/mlir/tfrt/transforms/mlrt/BUILD new file mode 100644 index 00000000000..beb50129756 --- /dev/null +++ b/tensorflow/compiler/mlir/tfrt/transforms/mlrt/BUILD @@ -0,0 +1,195 @@ +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = [ + # copybara:uncomment "//learning/brain/experimental/tfrt:__subpackages__", + # copybara:uncomment "//learning/infra/mira/distributed:__subpackages__", + "//tensorflow/compiler/mlir/tfrt:__subpackages__", + "//tensorflow/core/tfrt:__subpackages__", + ], +) + +cc_library( + name = "parallelization", + srcs = ["parallelization.cc"], + hdrs = ["parallelization.h"], + deps = [ + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_analysis", + "//tensorflow/compiler/mlir/tfrt:constants", + "//tensorflow/compiler/mlir/tfrt:cost_analysis", + "//tensorflow/compiler/mlir/tfrt/ir/mlrt:mlrt_ops", + "//tensorflow/compiler/mlir/tfrt/ir/mlrt:tf_mlrt_ops", + "//tensorflow/core/tfrt/fallback:cost_recorder", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/strings", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@tf_runtime//:stream_analysis", + ], +) + +cc_library( + name = "assign_op_key", + srcs = ["assign_op_key.cc"], + hdrs = ["assign_op_key.h"], + deps = [ + ":util", + 
"//tensorflow/compiler/mlir/tfrt:constants", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + ], +) + +cc_library( + name = "util", + srcs = ["util.cc"], + hdrs = ["util.h"], + deps = [ + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops_a_m_inc_gen", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops_n_z_inc_gen", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_tfrt_ops_inc_gen", + "@llvm-project//mlir:IR", + ], +) + +cc_library( + name = "tf_to_mlrt", + srcs = ["tf_to_mlrt.cc"], + hdrs = ["tf_to_mlrt.h"], + deps = [ + ":execute_op_registry", + ":tpu_conversion_patterns", + ":util", + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow:export_tf_dialect_op", + "//tensorflow/compiler/mlir/tfrt:constants", + "//tensorflow/compiler/mlir/tfrt:tfrt_pipeline_options", + "//tensorflow/compiler/mlir/tfrt:transform_utils", + "//tensorflow/compiler/mlir/tfrt/ir/mlrt:mlrt_ops", + "//tensorflow/compiler/mlir/tfrt/ir/mlrt:tf_mlrt_ops", + "//tensorflow/compiler/mlir/tfrt/ir/mlrt:tf_mlrt_tpu_ops", + "//tensorflow/core/tfrt/fallback:fallback_state", + "//tensorflow/core/tfrt/fallback:op_kernel_runner_cache", + "//third_party/protobuf", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:FuncTransforms", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:TransformUtils", + ], +) + +cc_library( + name = "passes", + srcs = ["passes.cc"], + hdrs = ["passes.h"], + deps = [ + ":assign_op_key", + ":fuse_mlrt_ops", + ":parallelization", + ":tf_to_mlrt", + ":while_to_map_fn", + "//tensorflow/compiler/mlir/tfrt:tfrt_pipeline_options", + "//tensorflow/core/tfrt/fallback:cost_recorder", + "//tensorflow/core/tfrt/fallback:fallback_state", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Transforms", + ], +) + +cc_library( + name = "execute_op_registry", + hdrs = 
["execute_op_registry.h"], + deps = [ + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + ], +) + +cc_library( + name = "tpu_conversion_patterns", + srcs = ["tpu_conversion_patterns.cc"], + hdrs = ["tpu_conversion_patterns.h"], + deps = [ + ":execute_op_registry", + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tfrt:tfrt_pipeline_options", + "//tensorflow/compiler/mlir/tfrt:transform_utils", + "//tensorflow/compiler/mlir/tfrt/ir/mlrt:mlrt_ops", + "//tensorflow/compiler/mlir/tfrt/ir/mlrt:tf_mlrt_ops", + "//tensorflow/compiler/mlir/tfrt/ir/mlrt:tf_mlrt_tpu_ops", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:TransformUtils", + ], +) + +cc_library( + name = "fuse_mlrt_ops", + srcs = ["fuse_mlrt_ops.cc"], + hdrs = ["fuse_mlrt_ops.h"], + deps = [ + "//tensorflow/compiler/mlir/tfrt/ir/mlrt:mlrt_ops", + "//tensorflow/compiler/mlir/tfrt/ir/mlrt:tf_mlrt_ops", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:Pass", + ], +) + +cc_library( + name = "import_model", + srcs = ["import_model.cc"], + hdrs = ["import_model.h"], + deps = [ + ":assign_op_key", + ":passes", + ":while_to_map_fn", + "//base:vlog", + "//tensorflow/compiler/mlir/tensorflow:dump_mlir_util", + "//tensorflow/compiler/mlir/tensorflow:error_util", + "//tensorflow/compiler/mlir/tfrt:import_model", + "//tensorflow/compiler/mlir/tfrt:tf_to_tfrt", + "//tensorflow/compiler/mlir/tfrt:tfrt_compile_options", + "//tensorflow/compiler/mlir/tfrt:tfrt_pipeline_options", + "//tensorflow/compiler/mlir/tfrt/translate/mlrt:mlir_to_bytecode", + "//tensorflow/core/platform:errors", + "//tensorflow/core/platform:status", + "//tensorflow/core/platform:statusor", + "//tensorflow/core/tfrt/fallback:cost_recorder", + "//tensorflow/core/tfrt/fallback:fallback_state", + "//tensorflow/core/tfrt/mlrt/attribute", + "//tensorflow/core/tfrt/mlrt/bytecode", + "//tensorflow/tsl/platform:errors", + 
"//tensorflow/tsl/platform:status", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:Transforms", + ], +) + +cc_library( + name = "while_to_map_fn", + srcs = ["while_to_map_fn.cc"], + hdrs = ["while_to_map_fn.h"], + deps = [ + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tfrt/ir/mlrt:mlrt_ops", + "//tensorflow/compiler/mlir/tfrt/ir/mlrt:tf_mlrt_ops", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/strings", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", + ], +) diff --git a/tensorflow/compiler/mlir/tfrt/transforms/mlrt/assign_op_key.cc b/tensorflow/compiler/mlir/tfrt/transforms/mlrt/assign_op_key.cc new file mode 100644 index 00000000000..e2896f8a070 --- /dev/null +++ b/tensorflow/compiler/mlir/tfrt/transforms/mlrt/assign_op_key.cc @@ -0,0 +1,71 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/compiler/mlir/tfrt/transforms/mlrt/assign_op_key.h" + +#include + +#include + +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tfrt/constants.h" +#include "tensorflow/compiler/mlir/tfrt/transforms/mlrt/util.h" + +namespace tensorflow { +namespace mlrt_compiler { +namespace { + +class AssignOpKeyPass + : public mlir::PassWrapper> { + public: + AssignOpKeyPass() = default; + AssignOpKeyPass& operator=(const AssignOpKeyPass&) = delete; + + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(AssignOpKeyPass) + + private: + llvm::StringRef getArgument() const final { return "tf-mlrt-assign-op-key"; } + llvm::StringRef getDescription() const final { + return "tf-mlrt-assign-op-key"; + } + + void runOnOperation() override; +}; + +void AssignOpKeyPass::runOnOperation() { + auto module = getOperation(); + mlir::OpBuilder builder(module); + + int32_t op_key = 0; + module.walk([&builder, &op_key](mlir::Operation* op) mutable { + if (UseFallback(op)) { + op->setAttr(tensorflow::tfrt_compiler::kOpKeyAttrName, + builder.getI32IntegerAttr(op_key)); + op_key++; + } + }); +} + +} // namespace + +std::unique_ptr> CreateAssignOpKeyPass() { + return std::make_unique(); +} + +} // namespace mlrt_compiler +} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tfrt/transforms/mlrt/assign_op_key.h b/tensorflow/compiler/mlir/tfrt/transforms/mlrt/assign_op_key.h new file mode 100644 index 00000000000..6ed9f1e9198 --- /dev/null +++ b/tensorflow/compiler/mlir/tfrt/transforms/mlrt/assign_op_key.h @@ -0,0 +1,32 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_MLRT_ASSIGN_OP_KEY_H_ +#define TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_MLRT_ASSIGN_OP_KEY_H_ +#include + +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project + +namespace tensorflow { +namespace mlrt_compiler { + +// Create a pass that assigns an op_key to every fallback OP. The op_key +// provides a uniform key to look up online cost for a specific op. +// This pass is expected to run before parallerization. +std::unique_ptr> CreateAssignOpKeyPass(); + +} // namespace mlrt_compiler +} // namespace tensorflow +#endif // TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_MLRT_ASSIGN_OP_KEY_H_ diff --git a/tensorflow/compiler/mlir/tfrt/transforms/mlrt/execute_op_registry.h b/tensorflow/compiler/mlir/tfrt/transforms/mlrt/execute_op_registry.h new file mode 100644 index 00000000000..93dde8140c0 --- /dev/null +++ b/tensorflow/compiler/mlir/tfrt/transforms/mlrt/execute_op_registry.h @@ -0,0 +1,60 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_MLRT_EXECUTE_OP_REGISTRY_H_ +#define TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_MLRT_EXECUTE_OP_REGISTRY_H_ + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project + +namespace tensorflow { +namespace mlrt_compiler { + +class ExecuteOpRegistry { + public: + mlir::LogicalResult RegisterExecuteOp(mlir::Operation* op, uint32_t op_key) { + if (op_key >= execute_ops_.size()) { + execute_ops_.resize(op_key + 1); + } + if (auto* register_op = execute_ops_[op_key]) { + if (register_op->getName() != op->getName() || + register_op->getAttrs() != op->getAttrs()) { + return op->emitError() << "Key " << op_key << " already registered."; + } + return mlir::success(); + } + execute_ops_[op_key] = op; + return mlir::success(); + } + + void ReplaceExecuteOp(int64_t key, mlir::Operation* op) { + execute_ops_[key] = op; + } + + llvm::ArrayRef GetExecuteOps() const { + return execute_ops_; + } + + private: + // Using a vector to keep fallback ops in order, and the key for a fallback op + // is its corresponding index here. 
+ llvm::SmallVector execute_ops_; +}; + +} // namespace mlrt_compiler +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_MLRT_EXECUTE_OP_REGISTRY_H_ diff --git a/tensorflow/compiler/mlir/tfrt/transforms/mlrt/fuse_mlrt_ops.cc b/tensorflow/compiler/mlir/tfrt/transforms/mlrt/fuse_mlrt_ops.cc new file mode 100644 index 00000000000..a53404653fa --- /dev/null +++ b/tensorflow/compiler/mlir/tfrt/transforms/mlrt/fuse_mlrt_ops.cc @@ -0,0 +1,157 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/compiler/mlir/tfrt/transforms/mlrt/fuse_mlrt_ops.h" + +#include + +#include "llvm/ADT/SmallVector.h" +#include "tensorflow/compiler/mlir/tfrt/ir/mlrt/mlrt_dialect.h" +#include "tensorflow/compiler/mlir/tfrt/ir/mlrt/mlrt_ops.h" +#include "tensorflow/compiler/mlir/tfrt/ir/mlrt/tf_mlrt_ops.h" + +namespace tensorflow { +namespace mlrt_compiler { +namespace { + +class FuseMlrtOpPass + : public mlir::PassWrapper> { + public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(FuseMlrtOpPass) + + private: + llvm::StringRef getArgument() const final { return "tf-mlrt-fuse"; } + + llvm::StringRef getDescription() const final { + return "Fuse consecutive mlrt ops of the same kind into one."; + } + + void runOnOperation() override; +}; + +void FuseGetResourceOps(mlir::OpBuilder& builder, mlir::Block& block) { + llvm::SmallVector get_resource_ops; + for (auto& op : llvm::make_early_inc_range(block)) { + if (auto get_resource_op = llvm::dyn_cast(&op)) { + get_resource_ops.push_back(get_resource_op); + } + } + + if (get_resource_ops.empty()) return; + + // The last op is always a return op, so it is guaranteed to process all + // groups of the candidate ops. 
+ auto first_get = get_resource_ops.front(); + + builder.setInsertionPointAfter(first_get); + + llvm::SmallVector indices; + llvm::SmallVector result_types; + llvm::SmallVector old_values; + + indices.reserve(get_resource_ops.size()); + result_types.reserve(get_resource_ops.size()); + old_values.reserve(get_resource_ops.size()); + + for (auto op : get_resource_ops) { + auto indices_attr = op.getIndices(); + indices.append(indices_attr.begin(), indices_attr.end()); + result_types.append(op.result_type_begin(), op.result_type_end()); + old_values.append(op.result_begin(), op.result_end()); + } + + auto new_op = builder.create( + first_get.getLoc(), result_types, builder.getArrayAttr(indices)); + + for (auto [old_value, new_value] : + llvm::zip(old_values, new_op.getResults())) { + old_value.replaceAllUsesWith(new_value); + } + + for (auto get_resource_op : get_resource_ops) { + get_resource_op->erase(); + } +} + +template +void FuseAwaitOps(mlir::OpBuilder& builder, mlir::Block& block) { + llvm::SmallVector await_ops; + for (auto& op : llvm::make_early_inc_range(block)) { + if (auto await_op = llvm::dyn_cast(&op)) { + await_ops.push_back(await_op); + continue; + } + + // The last op is always a return op, so it is guaranteed to process all + // groups of the candidate ops. 
+ if (await_ops.size() > 1) { + auto last_await = await_ops.back(); + + builder.setInsertionPointAfter(last_await); + + llvm::SmallVector futures; + futures.reserve(await_ops.size()); + for (auto op : await_ops) { + futures.push_back(op.getOperand()); + } + + llvm::SmallVector result_types; + if constexpr (!std::is_same_v) { + result_types.assign(futures.size(), builder.getType()); + } + + auto await_all = + builder.create(op.getLoc(), result_types, futures); + + if constexpr (!std::is_same_v) { + for (auto [await_op, new_value] : + llvm::zip(await_ops, await_all.getResults())) { + await_op.getResult().replaceAllUsesWith(new_value); + } + } + + for (auto await_op : await_ops) { + await_op->erase(); + } + } + + await_ops.clear(); + } +} + +void FuseMlrtOpPass::runOnOperation() { + auto func = getOperation(); + + mlir::OpBuilder builder(func); + + FuseAwaitOps( + builder, func.front()); + FuseAwaitOps( + builder, func.front()); + FuseAwaitOps(builder, func.front()); + FuseGetResourceOps(builder, func.front()); +} + +} // namespace + +std::unique_ptr> +CreateFuseMlrtOpPass() { + return std::make_unique(); +} + +} // namespace mlrt_compiler +} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tfrt/transforms/mlrt/fuse_mlrt_ops.h b/tensorflow/compiler/mlir/tfrt/transforms/mlrt/fuse_mlrt_ops.h new file mode 100644 index 00000000000..6f772a895bb --- /dev/null +++ b/tensorflow/compiler/mlir/tfrt/transforms/mlrt/fuse_mlrt_ops.h @@ -0,0 +1,31 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_MLRT_FUSE_MLRT_OPS_H_ +#define TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_MLRT_FUSE_MLRT_OPS_H_ + +#include + +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project + +namespace tensorflow { +namespace mlrt_compiler { + +std::unique_ptr> CreateFuseMlrtOpPass(); + +} +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_MLRT_FUSE_MLRT_OPS_H_ diff --git a/tensorflow/compiler/mlir/tfrt/transforms/mlrt/import_model.cc b/tensorflow/compiler/mlir/tfrt/transforms/mlrt/import_model.cc new file mode 100644 index 00000000000..63b0de8e243 --- /dev/null +++ b/tensorflow/compiler/mlir/tfrt/transforms/mlrt/import_model.cc @@ -0,0 +1,136 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/compiler/mlir/tfrt/transforms/mlrt/import_model.h" + +#include + +#include "base/vlog_is_on.h" +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/OwningOpRef.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Transforms/Passes.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" +#include "tensorflow/compiler/mlir/tfrt/transforms/mlrt/assign_op_key.h" +#include "tensorflow/compiler/mlir/tfrt/transforms/mlrt/passes.h" +#include "tensorflow/compiler/mlir/tfrt/transforms/mlrt/while_to_map_fn.h" +#include "tensorflow/compiler/mlir/tfrt/transforms/passes.h" +#include "tensorflow/compiler/mlir/tfrt/transforms/tfrt_pipeline_options.h" +#include "tensorflow/compiler/mlir/tfrt/translate/import_model.h" +#include "tensorflow/compiler/mlir/tfrt/translate/mlrt/mlir_to_bytecode.h" +#include "tensorflow/compiler/mlir/tfrt/translate/tfrt_compile_options.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/statusor.h" +#include "tensorflow/core/tfrt/fallback/cost_recorder.h" +#include "tensorflow/core/tfrt/mlrt/attribute/attribute.h" +#include "tensorflow/core/tfrt/mlrt/bytecode/bytecode.h" +#include "tensorflow/tsl/platform/errors.h" +#include "tensorflow/tsl/platform/status.h" + +namespace tensorflow { +namespace mlrt_compiler { + +StatusOr ConvertTfMlirToBytecode( + const TfrtCompileOptions& options, + const tfrt_stub::FallbackState& fallback_state, mlir::ModuleOp module, + mlir::OwningOpRef* module_with_op_keys) { + mlrt::bc::Buffer bytecode_buffer; + TF_RETURN_IF_ERROR(ConvertTfMlirToRuntimeExecutable( + options, module, + [&bytecode_buffer, &fallback_state, 
module_with_op_keys]( + mlir::PassManager& pm, mlir::ModuleOp module, + const TfrtPipelineOptions& options) { + mlir::StatusScopedDiagnosticHandler diag_handler(module.getContext()); + + if (options.enable_while_parallel_iterations) { + pm.addPass(mlrt_compiler::CreateWhileToMapFnPass()); + // Remove unreachable private functions after mapfn conversion. + pm.addPass(mlir::createSymbolDCEPass()); + } + tensorflow::CreateTFExecutorToTFInvariantOptimizationPipelineHelper( + pm, options); + // TODO(b/283481729): Add test to cover unused constants that do not + // cause op_key discontinuity + pm.addNestedPass(mlir::createCanonicalizerPass()); + pm.addPass(mlrt_compiler::CreateAssignOpKeyPass()); + // Run passes until (including) AssignOpKeyPass. + if (mlir::failed(pm.run(module))) { + return diag_handler.Combine(absl::InternalError( + "failed to finish passes before (including) assign op keys.")); + } + if (VLOG_IS_ON(1)) { + tensorflow::DumpMlirOpToFile("tf_dialect_after_assign_op_key", + module); + } + // Save the module. + if (module_with_op_keys != nullptr) { + *module_with_op_keys = module.clone(); + } + // Clear passes already run. + pm.clear(); + // Create the remaining pipeline and run. + CreateTfToMlrtPipeline(pm, options, &fallback_state); + if (mlir::failed(pm.run(module))) { + return diag_handler.Combine(absl::InternalError( + "failed to lower TF Dialect to MLRT dialect.")); + } + // Generate bytecode. 
+ mlrt::AttributeEncoderRegistry registry; + registry.Register("tf_mlrt", + &tensorflow::tf_mlrt::EncodeTensorflowAttribute); + auto statusor = mlrt::EmitExecutable(registry, module); + if (!statusor.ok()) return statusor.status(); + bytecode_buffer = std::move(*statusor); + return OkStatus(); + })); + return bytecode_buffer; +} + +StatusOr ConvertTfMlirWithOpKeysToBytecode( + const TfrtCompileOptions& options, + const tfrt_stub::FallbackState& fallback_state, + mlir::ModuleOp module_with_op_keys, + const tfrt_stub::CostRecorder& cost_recorder) { + mlir::StatusScopedDiagnosticHandler diag_handler( + module_with_op_keys.getContext()); + if (VLOG_IS_ON(1)) { + tensorflow::DumpMlirOpToFile("tf_dialect_with_op_keys", + module_with_op_keys); + } + // Create the reconversion pipeline and run. + mlir::PassManager pm(module_with_op_keys.getContext()); + const auto pipeline_options = GetTfrtPipelineOptions(options); + CreateTfToMlrtPipeline(pm, *pipeline_options, &fallback_state, + &cost_recorder); + if (mlir::failed(pm.run(module_with_op_keys))) { + return diag_handler.Combine( + absl::InternalError("failed to lower TF Dialect to MLRT dialect.")); + } + // Generate bytecode. + mlrt::AttributeEncoderRegistry registry; + registry.Register("tf_mlrt", &tensorflow::tf_mlrt::EncodeTensorflowAttribute); + auto statusor = mlrt::EmitExecutable(registry, module_with_op_keys); + if (!statusor.ok()) return statusor.status(); + if (VLOG_IS_ON(1)) { + tensorflow::DumpMlirOpToFile("tfrt_dialect_from_tf_dialect_with_op_keys", + module_with_op_keys); + } + return std::move(*statusor); +} + +} // namespace mlrt_compiler +} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tfrt/transforms/mlrt/import_model.h b/tensorflow/compiler/mlir/tfrt/transforms/mlrt/import_model.h new file mode 100644 index 00000000000..37e0563c691 --- /dev/null +++ b/tensorflow/compiler/mlir/tfrt/transforms/mlrt/import_model.h @@ -0,0 +1,52 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_MLRT_IMPORT_MODEL_H_ +#define TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_MLRT_IMPORT_MODEL_H_ + +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/OwningOpRef.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tfrt/translate/tfrt_compile_options.h" +#include "tensorflow/core/platform/statusor.h" +#include "tensorflow/core/tfrt/fallback/cost_recorder.h" +#include "tensorflow/core/tfrt/fallback/fallback_state.h" +#include "tensorflow/core/tfrt/mlrt/bytecode/bytecode.h" + +namespace tensorflow { +namespace mlrt_compiler { + +// Converts an MLIR `module` in TF dialect to MLRT's bytecode format. If +// `module_with_op_keys` is non-null, the intermediate module on which passes +// until (including) AssignOpKeyPass have run will be cloned to it. +// +// This is for initial conversion. +StatusOr ConvertTfMlirToBytecode( + const TfrtCompileOptions& options, + const tfrt_stub::FallbackState& fallback_state, mlir::ModuleOp module, + mlir::OwningOpRef* module_with_op_keys = nullptr); + +// Converts an MLIR `module_with_op_keys` in TF dialect to MLRT's bytecode +// format, with op costs from `cost_recorder`. +// +// This is for re-conversion. 
+StatusOr ConvertTfMlirWithOpKeysToBytecode( + const TfrtCompileOptions& options, + const tfrt_stub::FallbackState& fallback_state, + mlir::ModuleOp module_with_op_keys, + const tfrt_stub::CostRecorder& cost_recorder); + +} // namespace mlrt_compiler +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_MLRT_IMPORT_MODEL_H_ diff --git a/tensorflow/compiler/mlir/tfrt/transforms/mlrt/parallelization.cc b/tensorflow/compiler/mlir/tfrt/transforms/mlrt/parallelization.cc new file mode 100644 index 00000000000..7cab8a9d528 --- /dev/null +++ b/tensorflow/compiler/mlir/tfrt/transforms/mlrt/parallelization.cc @@ -0,0 +1,833 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/compiler/mlir/tfrt/transforms/mlrt/parallelization.h" + +#include +#include +#include +#include + +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/strings/str_cat.h" +#include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/SetVector.h" +#include "mlir/IR/Block.h" // from @llvm-project +#include "mlir/IR/IRMapping.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tfrt/analysis/cost_analysis.h" +#include "tensorflow/compiler/mlir/tfrt/constants.h" +#include "tensorflow/compiler/mlir/tfrt/ir/mlrt/mlrt_dialect.h" +#include "tensorflow/compiler/mlir/tfrt/ir/mlrt/mlrt_ops.h" +#include "tensorflow/compiler/mlir/tfrt/ir/mlrt/tf_mlrt_ops.h" +#include "tensorflow/core/tfrt/fallback/cost_recorder.h" +#include "tfrt/compiler/stream_analysis.h" // from @tf_runtime + +namespace tensorflow { +namespace mlrt_compiler { +namespace { + +using tensorflow::tfrt_compiler::CostAnalysis; +using tfrt::compiler::Stream; +using tfrt::compiler::StreamAnalysis; + +std::string GetStreamFunctionName(absl::string_view func_name, + const Stream& stream) { + return absl::StrCat(func_name, "_stream_", stream.id()); +} + +bool IsConstant(mlir::Operation* op) { + return op && llvm::isa(op); +} + +// StreamInfo is a bookkeeping for inputs, futures, and promises for a stream. +struct StreamInfo { + const Stream* parent = nullptr; + + // The values that are produced by constant ops. Instead of using + // promise/await to pass these values between streams, we can just copying + // these ops to the streams that use these constants. + llvm::SetVector constants; + // The values that are the inputs to the stream. + llvm::SetVector inputs; + // The values that will be the futures to the stream. 
+ llvm::SetVector futures; + // The values that will be the control futures (i.e. futures with no data) to + // the stream. + llvm::SetVector control_futures; + // The values that will be the promises to the stream. + llvm::SetVector promises; + // The values that will be the control promises (i.e., promises with no data) + // to the stream. + llvm::SetVector control_promises; + // The values that are defined by the operations in the stream. Note that all + // values in `futures` will also be in `results`. + llvm::DenseSet results; + + bool contains_only_constants = true; + + bool IsRoot() const { return parent == nullptr; } +}; + +// Preprocess the block to produce StreamInfo for every stream. +llvm::DenseMap PreprocessStreamInfo( + mlir::Block& block, + const llvm::DenseMap>& + control_predecessors, + const StreamAnalysis& stream_analysis) { + llvm::DenseMap stream_map; + + // All values that will be promises in the block. + llvm::DenseSet promises; + + // All operations that will be control promises in the block. + llvm::DenseSet control_promises; + + // Keep track of all available values and controls as we traverse the stream + // tree in depth-first order. + llvm::DenseSet available_values; + llvm::DenseSet available_controls; + + struct Entry { + explicit Entry(const Stream* stream) : stream(stream) {} + + const Stream* stream = nullptr; + + // Keep track of the next operation to be processed. If all operations are + // processed, we can pop this stream from the DFS stack. + int op_idx = 0; + }; + + std::vector stack; + stack.reserve(stream_analysis.GetNumStreams()); + + // We first push the entry for the root stream. 
+ const auto& root_stream = stream_analysis.GetRootStream(); + auto& root_stream_info = stream_map[&root_stream]; + available_values.insert(block.getArguments().begin(), + block.getArguments().end()); + root_stream_info.results.insert(block.getArguments().begin(), + block.getArguments().end()); + stack.push_back(Entry(&root_stream)); + + // The root stream's first operation a dummy operation that defines all block + // arguments. + for (auto* child_stream : root_stream.GetChildStreamsForRootOp()) { + stream_map[child_stream].parent = &root_stream; + stack.push_back(Entry(child_stream)); + } + + // The first DFS traveral populates inputs and futures for every stream but + // not promises. We only know whether a value definition is a promise only + // after traversing all streams, so it is not possible to know it in the first + // pass. + while (!stack.empty()) { + auto& [stream, op_idx] = stack.back(); + auto& stream_info = stream_map[stream]; + + auto ops = stream->ops(); + + // If we finish processing all operations in the stream, we can pop this + // stream, as well as the values defined by its operations. + if (op_idx == ops.size()) { + for (auto* op : stream->ops()) { + // Erase the values and controls produced by the current stream. + for (auto result : op->getResults()) { + available_values.erase(result); + } + available_controls.erase(op); + } + // Futures and control futures will also be available, so we erase them as + // well. + for (auto future : stream_info.futures) { + available_values.erase(future); + } + for (auto* control_future : stream_info.control_futures) { + available_controls.erase(control_future); + } + + if (!stream_info.IsRoot()) { + // Merge inputs, futures, and promises into the parent stream, as they + // will be passed down from the root in the output program. 
+ DCHECK_GT(stream_map.count(stream_info.parent), 0); + auto& parent_info = stream_map[stream_info.parent]; + + for (const auto& input : stream_info.inputs) { + DCHECK(available_values.contains(input)); + if (!parent_info.results.contains(input)) { + // An input in the current stream will be an input in the parent + // stream only if it is not a result in the parent stream. + parent_info.inputs.insert(input); + } + } + + for (auto future : stream_info.futures) { + DCHECK(!available_values.contains(future)); + parent_info.futures.insert(future); + } + for (auto* control_future : stream_info.control_futures) { + DCHECK(!available_controls.contains(control_future)); + parent_info.control_futures.insert(control_future); + } + } + + // Update the global promise set. + promises.insert(stream_info.futures.begin(), stream_info.futures.end()); + control_promises.insert(stream_info.control_futures.begin(), + stream_info.control_futures.end()); + + stack.pop_back(); + continue; + } + + // We process the operations one by one. If the operation has child streams, + // we process the child streams first before continuing to the next + // operation. + bool has_child_streams = false; + for (; op_idx < ops.size() && !has_child_streams; ++op_idx) { + auto* op = ops[op_idx]; + + stream_info.contains_only_constants &= IsConstant(op); + + // Check every operand to see whether it is a future or input. + for (mlir::Value operand : op->getOperands()) { + // If the value is defined in the current stream, nothing needs to be + // done. + if (!stream_info.results.contains(operand)) { + if (available_values.insert(operand).second) { + // If the operand is not available in the current stream or any + // parent stream, it will be a future and then become a result. 
+ if (IsConstant(operand.getDefiningOp())) { + stream_info.constants.insert(operand); + } else { + stream_info.futures.insert(operand); + } + stream_info.results.insert(operand); + } else { + // If the operand is not available in the current stream but + // available in the parent stream, it is an input. + if (IsConstant(operand.getDefiningOp())) { + stream_info.constants.insert(operand); + } else { + stream_info.inputs.insert(operand); + } + } + } + } + + // Insert mlrt.await_control if this op has control deps on other ops. + if (auto ctrl_iter = control_predecessors.find(op); + ctrl_iter != control_predecessors.end()) { + const auto& ctrl_deps = ctrl_iter->second; + + for (mlir::Operation* control_dep : ctrl_deps) { + if (available_controls.insert(control_dep).second) { + // If the control is not already available, it will be a control + // future and then become available. + stream_info.control_futures.insert(control_dep); + } + } + } + + // Update results of this operations. + for (mlir::Value result : op->getResults()) { + available_values.insert(result); + stream_info.results.insert(result); + } + + // Update this op as an available control. + available_controls.insert(op); + + // Pause processing the current stream to process the child streams first. + const auto& child_streams = stream->GetChildStreams(op); + has_child_streams = !child_streams.empty(); + for (auto* child_stream : child_streams) { + stream_map[child_stream].parent = stream; + stack.push_back(Entry(child_stream)); + } + } + } + + // The second pass populates promises for each stream. We also need to merge + // promises in a child to its parent stream. We can do this by traversing the + // operation in reverse program order. 
+ for (auto& op : llvm::reverse(block)) { + const auto& stream = stream_analysis.GetStream(&op); + auto& stream_info = stream_map[&stream]; + + for (mlir::Value result : op.getResults()) { + if (promises.contains(result)) { + stream_info.promises.insert(result); + } + } + + if (control_promises.contains(&op)) { + stream_info.control_promises.insert(&op); + } + + for (const auto* child_stream : stream.GetChildStreams(&op)) { + const auto& child_info = stream_map[child_stream]; + + stream_info.promises.insert(child_info.promises.begin(), + child_info.promises.end()); + stream_info.control_promises.insert(child_info.control_promises.begin(), + child_info.control_promises.end()); + } + } + + // Special handling for the dummy operation in the root. + auto& root_info = stream_map[&root_stream]; + for (const auto* child_stream : root_stream.GetChildStreamsForRootOp()) { + const auto& child_info = stream_map[child_stream]; + + root_info.promises.insert(child_info.promises.begin(), + child_info.promises.end()); + root_info.control_promises.insert(child_info.control_promises.begin(), + child_info.control_promises.end()); + } + + return stream_map; +} + +// A custom struct that groups mappings for values, futures and promises for a +// stream during creating the corresponding stream function. +struct Mapping { + // This is the mappings for the SSA values used in the original and new + // operations. + mlir::IRMapping value_mapping; + + // Maps the original tensor value that will be a future to the corresponding + // !mlrt.future value. + mlir::IRMapping future_mapping; + + // Maps the original tensor value that will be a promise to the corresponding + // !mlrt.promise value. + mlir::IRMapping promise_mapping; + + // In addition to value mappings, we also need mappings for input control + // dependencies to the corresponding !mlrt.future and !mlrt.promise values. 
+ llvm::DenseMap future_control_mapping; + llvm::DenseMap promise_control_mapping; +}; + +mlrt::compiler::AsyncOp CreateAsyncOp( + mlir::OpBuilder& builder, absl::string_view function_name, + const llvm::DenseMap& stream_map, + const Stream& stream, const Mapping& mapping, mlir::Location loc) { + auto iter = stream_map.find(&stream); + DCHECK(iter != stream_map.end()); + const auto& stream_info = iter->second; + + if (stream_info.contains_only_constants) return nullptr; + + const auto& [value_mapping, future_mapping, promise_mapping, + future_control_mapping, promise_control_mapping] = mapping; + + llvm::SmallVector async_operands; + + for (auto input : stream_info.inputs) { + async_operands.push_back(value_mapping.lookup(input)); + DCHECK(async_operands.back()); + } + + for (auto future : stream_info.futures) { + async_operands.push_back(future_mapping.lookup(future)); + DCHECK(async_operands.back()); + } + + for (auto* control_future : stream_info.control_futures) { + DCHECK_GT(future_control_mapping.count(control_future), 0); + async_operands.push_back(future_control_mapping.lookup(control_future)); + DCHECK(async_operands.back()); + } + + for (auto promise : stream_info.promises) { + async_operands.push_back(promise_mapping.lookup(promise)); + DCHECK(async_operands.back()); + } + + for (auto* control_promise : stream_info.control_promises) { + DCHECK_GT(promise_control_mapping.count(control_promise), 0); + async_operands.push_back(promise_control_mapping.lookup(control_promise)); + DCHECK(async_operands.back()); + } + + return builder.create( + loc, builder.getType(), async_operands, + mlir::SymbolRefAttr::get(builder.getContext(), + GetStreamFunctionName(function_name, stream))); +} + +mlir::func::FuncOp CreateStreamFunction( + mlir::OpBuilder& builder, Mapping& mapping, absl::string_view name, + const Stream& stream, const StreamInfo& stream_info, mlir::Location loc) { + if (stream_info.contains_only_constants) return nullptr; + + auto& [value_mapping, 
future_mapping, promise_mapping, future_control_mapping, + promise_control_mapping] = mapping; + + llvm::SmallVector arg_types; + for (mlir::Value input : stream_info.inputs) { + arg_types.push_back(input.getType()); + } + + arg_types.append( + stream_info.futures.size() + stream_info.control_futures.size(), + builder.getType()); + arg_types.append( + stream_info.promises.size() + stream_info.control_promises.size(), + builder.getType()); + + // The stream function has no result. + auto func_type = builder.getFunctionType(arg_types, /*results=*/{}); + + auto func = builder.create( + loc, GetStreamFunctionName(name, stream), func_type); + func.setVisibility(mlir::func::FuncOp::Visibility::Private); + + // Populate the body of the stream function by copying over the operations + // in the stream. + auto* new_block = func.addEntryBlock(); + + // Replace inputs with the function arguments. + for (int i = 0; i < stream_info.inputs.size(); ++i) { + value_mapping.map(stream_info.inputs[i], new_block->getArgument(i)); + } + + // Maps the original tensor value that will be a future or a promise to + // the corresponding !mlrt.future or !mlrt.promise value. 
+ size_t start = stream_info.inputs.size(); + for (int i = 0; i < stream_info.futures.size(); ++i) { + future_mapping.map(stream_info.futures[i], + new_block->getArgument(i + start)); + } + + start += stream_info.futures.size(); + for (int i = 0; i < stream_info.control_futures.size(); ++i) { + future_control_mapping[stream_info.control_futures[i]] = + new_block->getArgument(i + start); + } + + start += stream_info.control_futures.size(); + for (int i = 0; i < stream_info.promises.size(); ++i) { + promise_mapping.map(stream_info.promises[i], + new_block->getArgument(i + start)); + } + + start += stream_info.promises.size(); + for (int i = 0; i < stream_info.control_promises.size(); ++i) { + promise_control_mapping[stream_info.control_promises[i]] = + new_block->getArgument(i + start); + } + + return func; +} + +void CreateAllocateFuturesOp(mlir::OpBuilder& builder, Mapping& mapping, + const StreamInfo& stream_info, + mlir::Location loc) { + auto& [value_mapping, future_mapping, promise_mapping, future_control_mapping, + promise_control_mapping] = mapping; + + DCHECK_EQ(stream_info.futures.size(), stream_info.promises.size()); + + llvm::SmallVector promise_types( + stream_info.promises.size(), + builder.getType()); + llvm::SmallVector future_types( + stream_info.futures.size(), + builder.getType()); + + if (!stream_info.futures.empty()) { + auto allocate_futures = builder.create( + loc, promise_types, future_types, stream_info.futures.size()); + for (int i = 0; i < stream_info.futures.size(); ++i) { + future_mapping.map(stream_info.futures[i], + allocate_futures.getFutures()[i]); + } + + for (int i = 0; i < stream_info.futures.size(); ++i) { + // Use the original values in `futures` to make sure futures[i] shares the + // state with promises[i]. 
+ DCHECK(stream_info.promises.contains(stream_info.futures[i])); + promise_mapping.map(stream_info.futures[i], + allocate_futures.getPromises()[i]); + } + } + + DCHECK_EQ(stream_info.control_futures.size(), + stream_info.control_promises.size()); + if (!stream_info.control_futures.empty()) { + promise_types.resize(stream_info.control_promises.size(), + builder.getType()); + future_types.resize(stream_info.control_futures.size(), + builder.getType()); + + auto allocate_control_futures = + builder.create( + loc, promise_types, future_types, + stream_info.control_futures.size()); + for (int i = 0; i < stream_info.control_futures.size(); ++i) { + future_control_mapping[stream_info.control_futures[i]] = + allocate_control_futures.getFutures()[i]; + } + for (int i = 0; i < stream_info.control_futures.size(); ++i) { + // Use the original operations in `control_futures` to make sure + // control_futures[i] shares the state with control_promises[i]. + DCHECK(stream_info.control_promises.contains( + stream_info.control_futures[i])); + promise_control_mapping[stream_info.control_futures[i]] = + allocate_control_futures.getPromises()[i]; + } + } +} + +class TensorflowCostModel : public StreamAnalysis::CostModelInterface { + public: + explicit TensorflowCostModel(CostAnalysis* cost_analysis) + : cost_analysis_(*cost_analysis) {} + + std::optional GetOperationCost(mlir::Operation* op) const override { + return cost_analysis_.GetCost(op); + } + + private: + const CostAnalysis& cost_analysis_; +}; + +bool SkipControlDep(mlir::Operation* op) { + // TODO(chky): Consider define side effects more properly for these ops. + return llvm::isa( + op); +} + +void ParallelizeBlock( + absl::string_view name, mlir::Block& block, + const mlir::TF::SideEffectAnalysis::Info& side_effect_analysis, + const tfrt_stub::CostRecorder* cost_recorder) { + // First, we use SideEffectAnalysis to find out control predecessors for each + // operation. We use this map later to insert control futures. 
+ llvm::DenseMap> + control_predecessors; + for (auto& op : block) { + auto& deps = control_predecessors[&op]; + for (auto* dep : side_effect_analysis.DirectControlPredecessors(&op)) { + // If we skip the control deps of `op`, then we need to use the control + // deps of these control deps instead. + if (SkipControlDep(dep)) { + for (auto* d : control_predecessors[dep]) { + DCHECK(!SkipControlDep(d)); + deps.insert(d); + } + } else { + deps.insert(dep); + } + } + } + + // Remove skipped control deps. + for (auto& op : block) { + if (SkipControlDep(&op)) { + control_predecessors.erase(&op); + } + } + + // Perform stream analysis. + CostAnalysis cost_analysis( + llvm::cast(block.getParentOp()), cost_recorder); + TensorflowCostModel cost_model(&cost_analysis); + StreamAnalysis stream_analysis(block, &cost_model); + + // Preprocess all streams to gather StreamInfos for all streams, without + // modifying the program. + llvm::DenseMap stream_map = + PreprocessStreamInfo(block, control_predecessors, stream_analysis); + + // Then we perform a DFS traversal to create stream functions and insert async + // operations. + std::vector stack; + stack.reserve(stream_analysis.GetNumStreams()); + + const auto& root_stream = stream_analysis.GetRootStream(); + stack.push_back(&root_stream); + + llvm::SmallVector to_remove; + + mlir::OpBuilder builder(block.getParentOp()); + + while (!stack.empty()) { + const auto* stream = stack.back(); + stack.pop_back(); + DCHECK(stream); + + DCHECK_GT(stream_map.count(stream), 0); + const auto& stream_info = stream_map[stream]; + + Mapping mapping; + auto& [value_mapping, future_mapping, promise_mapping, + future_control_mapping, promise_control_mapping] = mapping; + + // `async_handles` keeps the !mlrt.async_handle created in the stream. A + // mlrt.await_handle op will be inserted at the end of the stream function + // for each async handle. 
+ llvm::SmallVector async_handles; + + mlir::func::FuncOp stream_func; + if (!stream_info.IsRoot()) { + // If it is not a root stream, we need to create a new function for this + // stream. And futures and promises are also passed as parameters. For the + // root stream, futures and promises are allocated in the body. + + // Insert the stream function before the original function. + builder.setInsertionPoint(block.getParentOp()); + + stream_func = + CreateStreamFunction(builder, mapping, name, *stream, stream_info, + block.getParentOp()->getLoc()); + + if (stream_func) { + // Set the insertion point to the start of the new block in the + // function. + builder.setInsertionPointToStart(&stream_func.front()); + } + } else { + stream_func = llvm::cast(block.getParentOp()); + + DCHECK_EQ(stream, &root_stream); + // If it is the root stream, we insert new operations in the original + // function. And we need to allocate all the futures used here. + builder.setInsertionPointToStart(&block); + + // The block arguments of the root stream are in the `results`. There will + // be no additional inputs in `inputs`. + DCHECK(stream_info.inputs.empty()); + + // Put the original arguments in the mapping as they are not changed. + for (auto arg : block.getArguments()) { + value_mapping.map(arg, arg); + } + + // Insert a tf_mlrt.allocate_futures op to allocate all futures used. + CreateAllocateFuturesOp(builder, mapping, stream_info, + block.getParentOp()->getLoc()); + + // Lastly for the root stream, we need to handle the dummy op that defines + // the arguments. + for (const auto* child_stream : stream->GetChildStreamsForRootOp()) { + stack.push_back(child_stream); + if (auto async = + CreateAsyncOp(builder, name, stream_map, *child_stream, mapping, + block.getParentOp()->getLoc())) { + async_handles.push_back(async); + } + } + } + + for (auto* op : stream->ops()) { + to_remove.push_back(op); + } + + // Skip empty streams. 
+ if (!stream_func) continue; + + mlir::Operation* return_op = nullptr; + + // Cloning the operations in the stream. If the operand is a future, a + // tf_mlrt.Await op will be inserted. If the result is a promise, a + // tf_mlrt.Promise will be inserted. Similar to control futures and control + // promises. + for (auto* op : stream->ops()) { + // Clone the current op into the function of this stream, using the + // new operands, which can be futures. + for (mlir::Value operand : op->getOperands()) { + if (stream_info.constants.contains(operand) && + !value_mapping.contains(operand)) { + builder.clone(*operand.getDefiningOp(), value_mapping); + } else if (stream_info.futures.contains(operand) && + !value_mapping.contains(operand)) { + // Insert Await op if it is a future. + auto future_value = builder.create( + op->getLoc(), operand.getType(), future_mapping.lookup(operand)); + + // Now this future is available in the current stream, so it can be a + // normal value. + value_mapping.map(operand, future_value); + } + } + + if (auto ctrl_iter = control_predecessors.find(op); + ctrl_iter != control_predecessors.end()) { + const auto& ctrl_deps = ctrl_iter->second; + + for (mlir::Operation* control_dep : ctrl_deps) { + // This control may be available in the ancestors or in a previous + // AwaitControl, we only insert a new AwaitControl if it is not. + if (stream_info.control_futures.contains(control_dep)) { + if (auto iter = future_control_mapping.find(control_dep); + iter != future_control_mapping.end()) { + builder.create( + control_dep->getLoc(), iter->second); + + // Now we no longer need this control dep in this stream. + future_control_mapping.erase(iter); + } + } + } + } + + // Clone the op using the value mapping that includes values from futures. + auto* new_op = builder.clone(*op, value_mapping); + + // TODO(chky): Ensure the original return op is in the root stream. This + // is currently an implicit guarantee in stream analysis. 
+ if (llvm::isa(op)) { + DCHECK(stream_info.IsRoot()) << name << " " << stream->id(); + return_op = new_op; + } + + for (mlir::Value result : op->getResults()) { + if (stream_info.promises.contains(result)) { + // Insert Promise op if the result is a promise. + builder.create(op->getLoc(), + promise_mapping.lookup(result), + value_mapping.lookup(result)); + } + } + + if (stream_info.control_promises.contains(op)) { + // Insert Promise op if this op produce a control dependency to ops in + // other streams. + builder.create( + op->getLoc(), promise_control_mapping[op]); + } + + // If this op has child streams, insert mlrt.async ops. + for (auto* child_stream : stream->GetChildStreams(op)) { + stack.push_back(child_stream); + if (auto async = CreateAsyncOp(builder, name, stream_map, *child_stream, + mapping, op->getLoc())) { + async_handles.push_back(async); + } + } + } + + // Create the return op for non-root streams. + // + // TODO(chky): Ensure the original return op is in the root stream. This is + // currently an implicit guarantee in stream analysis. + if (!return_op) { + DCHECK(!stream_info.IsRoot()) << name << " " << stream->id(); + return_op = + builder.create(block.getParentOp()->getLoc()); + } + + // We need to wait for async executions at the end of the stream function, + // in order to manage resource lifetime and handle errors properly. These + // mlrt.await_handle ops are inserted before the return op. + builder.setInsertionPoint(return_op); + for (auto handle : async_handles) { + builder.create( + block.getParentOp()->getLoc(), handle); + } + } + + // Remove the operations in the original block. 
+ for (auto* op : llvm::reverse(to_remove)) { + op->dropAllDefinedValueUses(); + op->erase(); + } +} + +class ParallelizationPass + : public mlir::PassWrapper> { + public: + ParallelizationPass() = default; + ParallelizationPass(uint64_t cost_threshold, + bool merge_inter_dependent_streams, + const tfrt_stub::CostRecorder* cost_recorder) { + cost_threshold_ = cost_threshold; + merge_inter_dependent_streams_ = merge_inter_dependent_streams; + cost_recorder_ = cost_recorder; + } + ParallelizationPass(const ParallelizationPass&) {} + + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ParallelizationPass) + + private: + void getDependentDialects(mlir::DialectRegistry& registry) const override { + registry.insert(); + registry.insert(); + } + + llvm::StringRef getArgument() const final { + return "tf-mlrt-parallelization"; + } + + llvm::StringRef getDescription() const final { + return "Parallelize tf graphs by inserting mlrt async operations."; + } + + void runOnOperation() override { + auto module = getOperation(); + + mlir::Builder builder(module); + module->setAttr("tfrt.cost_threshold", + builder.getI64IntegerAttr(cost_threshold_)); + module->setAttr("tfrt.merge_inter_dependent_streams", + builder.getBoolAttr(merge_inter_dependent_streams_)); + + mlir::TF::SideEffectAnalysis side_effect_analysis(module); + + for (auto func_op : + llvm::make_early_inc_range(module.getOps())) { + ParallelizeBlock(func_op.getSymName(), func_op.front(), + side_effect_analysis.GetAnalysisForFunc(func_op), + cost_recorder_); + } + } + + Option cost_threshold_{ + *this, "tfrt-cost-threshold", + llvm::cl::desc("If a sequence of operations has a cost lower than the " + "cost-threshold, the sequence will be executed as a block " + "in the same thread."), + llvm::cl::init(1)}; + Option merge_inter_dependent_streams_{ + *this, "tfrt-merge-inter-dependent-streams", + llvm::cl::desc("If true, streams with inter data depenedencies will be " + "preferred to be merged for inline execution."), + 
llvm::cl::init(false)}; + const tfrt_stub::CostRecorder* cost_recorder_ = nullptr; +}; + +} // namespace + +std::unique_ptr> CreateParallelizationPass( + uint64_t cost_threshold, bool merge_inter_dependent_streams, + const tfrt_stub::CostRecorder* cost_recorder) { + return std::make_unique( + cost_threshold, merge_inter_dependent_streams, cost_recorder); +} + +std::unique_ptr> +CreateParallelizationPass() { + return std::make_unique(); +} + +} // namespace mlrt_compiler +} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tfrt/transforms/mlrt/parallelization.h b/tensorflow/compiler/mlir/tfrt/transforms/mlrt/parallelization.h new file mode 100644 index 00000000000..71221276fa9 --- /dev/null +++ b/tensorflow/compiler/mlir/tfrt/transforms/mlrt/parallelization.h @@ -0,0 +1,37 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_MLRT_PARALLELIZATION_H_ +#define TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_MLRT_PARALLELIZATION_H_ + +#include + +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "tensorflow/core/tfrt/fallback/cost_recorder.h" + +namespace tensorflow { +namespace mlrt_compiler { + +std::unique_ptr> CreateParallelizationPass( + uint64_t cost_threshold, bool merge_inter_dependent_streams, + const tfrt_stub::CostRecorder* cost_recorder = nullptr); + +std::unique_ptr> +CreateParallelizationPass(); + +} // namespace mlrt_compiler +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_MLRT_PARALLELIZATION_H_ diff --git a/tensorflow/compiler/mlir/tfrt/transforms/mlrt/passes.cc b/tensorflow/compiler/mlir/tfrt/transforms/mlrt/passes.cc new file mode 100644 index 00000000000..b55a7ff19bc --- /dev/null +++ b/tensorflow/compiler/mlir/tfrt/transforms/mlrt/passes.cc @@ -0,0 +1,64 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/compiler/mlir/tfrt/transforms/mlrt/passes.h" + +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Transforms/Passes.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tfrt/transforms/mlrt/assign_op_key.h" +#include "tensorflow/compiler/mlir/tfrt/transforms/mlrt/fuse_mlrt_ops.h" +#include "tensorflow/compiler/mlir/tfrt/transforms/mlrt/parallelization.h" +#include "tensorflow/compiler/mlir/tfrt/transforms/mlrt/tf_to_mlrt.h" +#include "tensorflow/compiler/mlir/tfrt/transforms/mlrt/while_to_map_fn.h" +#include "tensorflow/compiler/mlir/tfrt/transforms/tfrt_pipeline_options.h" +#include "tensorflow/core/tfrt/fallback/cost_recorder.h" + +namespace tensorflow { +namespace mlrt_compiler { + +void RegisterMlrtPasses() { + mlir::registerPass([]() { return CreateAssignOpKeyPass(); }); + mlir::registerPass([]() { return CreateParallelizationPass(); }); + mlir::registerPass([]() { return CreateWhileToMapFnPass(); }); + mlir::registerPass( + []() { return CreateTfToMlrtPreParallelizationConversionPass({}); }); + mlir::registerPass([]() { return CreateTfToMlrtConversionPass({}); }); + mlir::registerPass([]() { return CreateFuseMlrtOpPass(); }); +} + +void CreateTfToMlrtPipeline(mlir::OpPassManager &pm, + const TfrtPipelineOptions &options, + const tfrt_stub::FallbackState *fallback_state, + const tfrt_stub::CostRecorder *cost_recorder) { + pm.addPass( + mlrt_compiler::CreateTfToMlrtPreParallelizationConversionPass(options)); + pm.addPass(mlrt_compiler::CreateParallelizationPass( + options.cost_threshold, options.merge_inter_dependent_streams, + cost_recorder)); + + DCHECK(fallback_state); + pm.addPass( + mlrt_compiler::CreateTfToMlrtConversionPass(options, fallback_state)); + + // Perform optimizations in the lowered MLIR. 
+ pm.addNestedPass(mlrt_compiler::CreateFuseMlrtOpPass()); + pm.addNestedPass(mlir::createCanonicalizerPass()); + pm.addPass(mlir::createInlinerPass()); + pm.addNestedPass(mlir::createCSEPass()); +} + +} // namespace mlrt_compiler +} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tfrt/transforms/mlrt/passes.h b/tensorflow/compiler/mlir/tfrt/transforms/mlrt/passes.h new file mode 100644 index 00000000000..f9bf621b8bf --- /dev/null +++ b/tensorflow/compiler/mlir/tfrt/transforms/mlrt/passes.h @@ -0,0 +1,38 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_MLRT_PASSES_H_ +#define TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_MLRT_PASSES_H_ + +#include "mlir/Pass/PassOptions.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tfrt/transforms/tfrt_pipeline_options.h" +#include "tensorflow/core/tfrt/fallback/cost_recorder.h" +#include "tensorflow/core/tfrt/fallback/fallback_state.h" + +namespace tensorflow { +namespace mlrt_compiler { + +void RegisterMlrtPasses(); + +// Creates a pipeline of passes that lowers MLIR TF dialect to MLRT dialects. +// The op costs from `cost_recorder` (if non-null) are used for Stream Analysis. 
+void CreateTfToMlrtPipeline( + mlir::OpPassManager& pm, const TfrtPipelineOptions& options, + const tfrt_stub::FallbackState* fallback_state, + const tfrt_stub::CostRecorder* cost_recorder = nullptr); + +} // namespace mlrt_compiler +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_MLRT_PASSES_H_ diff --git a/tensorflow/compiler/mlir/tfrt/transforms/mlrt/tf_to_mlrt.cc b/tensorflow/compiler/mlir/tfrt/transforms/mlrt/tf_to_mlrt.cc new file mode 100644 index 00000000000..c9ec37e0aea --- /dev/null +++ b/tensorflow/compiler/mlir/tfrt/transforms/mlrt/tf_to_mlrt.cc @@ -0,0 +1,1146 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/compiler/mlir/tfrt/transforms/mlrt/tf_to_mlrt.h" + +#include + +#include +#include +#include +#include +#include + +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/Func/Transforms/FuncConversions.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinDialect.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/SymbolTable.h" // from @llvm-project +#include "mlir/Transforms/DialectConversion.h" // from @llvm-project +#include "third_party/protobuf/text_format.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h" +#include "tensorflow/compiler/mlir/tfrt/constants.h" +#include "tensorflow/compiler/mlir/tfrt/ir/mlrt/mlrt_dialect.h" +#include "tensorflow/compiler/mlir/tfrt/ir/mlrt/mlrt_ops.h" +#include "tensorflow/compiler/mlir/tfrt/ir/mlrt/tf_mlrt_ops.h" +#include "tensorflow/compiler/mlir/tfrt/ir/mlrt/tf_mlrt_tpu_ops.h" +#include "tensorflow/compiler/mlir/tfrt/transforms/mlrt/execute_op_registry.h" +#include "tensorflow/compiler/mlir/tfrt/transforms/mlrt/tpu_conversion_patterns.h" +#include "tensorflow/compiler/mlir/tfrt/transforms/mlrt/util.h" +#include "tensorflow/compiler/mlir/tfrt/transforms/utils.h" +#include "tensorflow/core/tfrt/fallback/fallback_state.h" +#include "tensorflow/core/tfrt/fallback/op_kernel_runner_cache.h" + +namespace tensorflow { +namespace mlrt_compiler { +namespace { + +// TODO(chky): Add registration interface for custom device +mlir::Value CreateCustomDevice(mlir::Location loc, llvm::StringRef device_name, + mlir::ConversionPatternRewriter 
&rewriter) { + if (device_name == kTpuHostDevice) { + return rewriter.create( + loc, rewriter.getType()); + } + + return nullptr; +} + +class FuncOpSignatureConversion final + : public mlir::OpConversionPattern { + public: + explicit FuncOpSignatureConversion( + mlir::MLIRContext *context, mlir::TypeConverter *type_converter, + const llvm::DenseMap> + *function_call_site_input_types) + : mlir::OpConversionPattern(context), + type_converter_(*type_converter), + function_call_site_input_types_(*function_call_site_input_types) {} + + mlir::LogicalResult matchAndRewrite( + mlir::func::FuncOp func_op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { + auto it = function_call_site_input_types_.find(func_op.getName()); + if (it == function_call_site_input_types_.end()) { + return mlir::failure(); + } + const llvm::SmallVector &call_site_input_types = it->second; + + mlir::FunctionType func_type = func_op.getFunctionType(); + DCHECK_EQ(func_type.getNumInputs(), call_site_input_types.size()); + + mlir::TypeConverter::SignatureConversion converted_signature( + func_type.getNumInputs()); + for (const auto &[index, value] : llvm::enumerate(call_site_input_types)) { + converted_signature.addInputs(index, value); + } + + // Update the function signature in-place. 
+ rewriter.updateRootInPlace(func_op, [&] { + func_op.setType(mlir::FunctionType::get( + func_op.getContext(), converted_signature.getConvertedTypes(), + func_type.getResults())); + }); + + // Update the entry block + if (rewriter.applySignatureConversion(&func_op.getBody(), + converted_signature, + &type_converter_) == nullptr) { + return mlir::failure(); + } + + return mlir::success(); + } + + private: + mlir::TypeConverter &type_converter_; + const llvm::DenseMap> + &function_call_site_input_types_; +}; + +class TFAwaitOpConversion final + : public mlir::OpConversionPattern { + public: + explicit TFAwaitOpConversion(mlir::MLIRContext *context) + : mlir::OpConversionPattern(context) {} + + mlir::LogicalResult matchAndRewrite( + tf_mlrt::TFAwaitOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { + auto new_op = rewriter.create( + op->getLoc(), rewriter.getType(), + adaptor.getFuture()); + rewriter.replaceOp(op, new_op.getResult()); + return mlir::success(); + } +}; + +class TFPromiseOpConversion final + : public mlir::OpConversionPattern { + public: + explicit TFPromiseOpConversion(mlir::MLIRContext *context) + : mlir::OpConversionPattern(context) {} + + mlir::LogicalResult matchAndRewrite( + tf_mlrt::TFPromiseOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { + if (llvm::isa<::mlrt::compiler::FutureType>( + adaptor.getTensor().getType())) { + auto new_op = rewriter.create( + op->getLoc(), adaptor.getPromise(), adaptor.getTensor()); + rewriter.replaceOp(op, new_op->getResults()); + + } else { + auto new_op = rewriter.create( + op->getLoc(), adaptor.getPromise(), adaptor.getTensor()); + rewriter.replaceOp(op, new_op->getResults()); + } + return mlir::success(); + } +}; + +// Convert tf_mlrt::MapFn's signature to tf_mlrt::TFTensorType +class TFMapFnOpConversion + : public mlir::OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + mlir::LogicalResult 
matchAndRewrite( + tf_mlrt::TFMapFnOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { + llvm::SmallVector result_types; + result_types.resize(op->getResultTypes().size(), + rewriter.getType()); + + auto new_op = rewriter.create( + op.getLoc(), result_types, adaptor.getOperands(), op->getAttrs()); + rewriter.replaceOp(op, new_op.getResult()); + return mlir::success(); + } +}; + +// Convert TF call ops (eg. StatefulPartitionedCall) to call. +template +class TFCallOpConversion : public mlir::OpConversionPattern { + public: + TFCallOpConversion(mlir::MLIRContext *context, + mlir::TypeConverter *type_converter) + : mlir::OpConversionPattern(context), + type_converter_(*type_converter) {} + + mlir::LogicalResult matchAndRewrite( + TFCallOp op, typename TFCallOp::Adaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { + if (auto xla_must_compile = + op->template getAttrOfType("_XlaMustCompile"); + xla_must_compile && xla_must_compile.getValue()) { + return mlir::failure(); + } + + auto callee = + op.getCallableForCallee().template dyn_cast(); + if (!callee) return mlir::failure(); + + llvm::SmallVector result_types; + for (auto type : op.getOperation()->getResultTypes()) { + if (failed(type_converter_.convertType(type, result_types))) + return mlir::failure(); + } + + auto new_op = rewriter.create( + op.getLoc(), result_types, callee.getRootReference().getValue(), + adaptor.getOperands()); + rewriter.replaceOp(op, new_op.getResults()); + return mlir::success(); + } + + private: + mlir::TypeConverter &type_converter_; +}; + +// Convert tf.Case op to mlrt.Case. +// +// TF dialect: +// %outputs = "tf.Case"(%idx_tensor, %arg, ...) { branches = [@branch0, +// @branch1], +// ...} +// +// lowered MLRT dialect: +// %branch_idx = tf_mlrt.tensor_to_int32(%idx_tensor) +// %outputs = mlrt.case %branch_idx [@branch0, @branch1] (%arg, ...) 
+class CaseOpConversion : public mlir::OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + mlir::LogicalResult matchAndRewrite( + mlir::TF::CaseOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { + mlir::ArrayAttr branches = op.getBranches(); + + llvm::SmallVector result_types; + result_types.resize(op->getResultTypes().size(), + rewriter.getType()); + + auto index_operand = rewriter.create( + op.getLoc(), rewriter.getI32Type(), adaptor.getBranchIndex()); + + auto new_op = rewriter.create( + op.getLoc(), result_types, index_operand.getResult(), branches, + adaptor.getInput()); + + rewriter.replaceOp(op, new_op.getResults()); + return mlir::success(); + } +}; + +class AsyncOpConversion + : public mlir::OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + // Hook for derived classes to implement combined matching and rewriting. + mlir::LogicalResult matchAndRewrite( + mlrt::compiler::AsyncOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { + rewriter.replaceOpWithNewOp( + op, op->getResultTypes(), adaptor.getOperands(), op.getCallee()); + return mlir::success(); + } +}; + +// SetResourceOpConversion lowers a TF SetResource op to a tf_mlrt.set_resource +// op. +class SetResourceOpConversion final + : public mlir::OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + mlir::LogicalResult matchAndRewrite( + mlir::TF::_TfrtSetResourceOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { + rewriter.replaceOpWithNewOp(op, adaptor.getArg(), + op.getIndex()); + return mlir::success(); + } +}; + +// GetResourceOpConversion lowers a TF GetResource op to a tf_mlrt.get_resource +// op. 
+class GetResourceOpConversion final + : public mlir::OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + mlir::LogicalResult matchAndRewrite( + mlir::TF::_TfrtGetResourceOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { + llvm::SmallVector result_types( + op.getNumResults(), rewriter.getType()); + auto new_op = rewriter.create( + op->getLoc(), result_types, op.getIndices()); + rewriter.replaceOp(op, new_op->getResults()); + return mlir::success(); + } +}; + +std::optional DecodeLongName(mlir::Location loc) { + if (auto name_loc = loc.dyn_cast()) { + return name_loc.getName().str(); + } + + if (auto fused_loc = loc.dyn_cast()) { + std::string fused_name; + for (auto l : fused_loc.getLocations()) { + if (auto n = DecodeLongName(l)) { + fused_name += *n; + } + } + return fused_name; + } + + return std::nullopt; +} + +std::string GetNodeName(mlir::Operation *op) { + auto name = [&]() -> std::string { + if (auto name = DecodeLongName(op->getLoc())) { + return *std::move(name); + } + + return op->getName().stripDialect().str(); + }(); + + for (char &c : name) { + if (c == ':') c = '/'; + } + return name; +} + +void CanonicalizeFunctionNameInNodeDef(const mlir::SymbolTable &symbol_table, + NodeDef &node_def) { + for (auto &p : *node_def.mutable_attr()) { + if (p.second.has_func()) { + auto *func = p.second.mutable_func(); + if (auto n = CanonicalizeTensorflowFunctionName( + symbol_table, func->name(), + /*use_mlir_func_name=*/false)) { + func->set_name(*n); + } + } + + if (p.second.has_list() && p.second.list().func_size() > 0) { + for (auto &func : *p.second.mutable_list()->mutable_func()) { + if (auto n = CanonicalizeTensorflowFunctionName( + symbol_table, func.name(), + /*use_mlir_func_name=*/false)) { + func.set_name(*n); + } + } + } + } +} + +class ExecuteOpConversion final : public mlir::ConversionPattern { + public: + ExecuteOpConversion(mlir::MLIRContext *context, + const mlir::SymbolTable 
*symbol_table, + mlir::TypeConverter *type_converter, + ExecuteOpRegistry *execute_op_registry, + tfrt_stub::OpKernelRunnerCache *op_kernel_cache, + const tfrt_stub::FallbackState *fallback_state) + : mlir::ConversionPattern(*type_converter, + mlir::Pattern::MatchAnyOpTypeTag(), + /*benefit=*/1, context), + symbol_table_(*symbol_table), + execute_op_registry_(*execute_op_registry), + op_kernel_cache_(*op_kernel_cache), + fallback_state_(*fallback_state) {} + + mlir::LogicalResult matchAndRewrite( + mlir::Operation *op, llvm::ArrayRef operands, + mlir::ConversionPatternRewriter &rewriter) const override { + // TODO(b/173017701): Avoid fallback for ops within XLA GPU clusters. + if (!UseFallback(op)) return mlir::failure(); + + // The assign_op_key pass should have ran. + if (!op->hasAttr(tensorflow::tfrt_compiler::kOpKeyAttrName)) + return op->emitError("does not have op_key defined"); + + std::string node_name = GetNodeName(op); + + uint32_t execute_key = op->getAttrOfType( + tensorflow::tfrt_compiler::kOpKeyAttrName) + .getInt(); + + absl::StrAppend(&node_name, "_", execute_key); + + auto statusor_node_def = tensorflow::ConvertTFDialectOpToNodeDef( + op, node_name, /*ignore_unregistered_attrs=*/false); + if (!statusor_node_def.ok()) + return op->emitWarning("failed to export NodeDef."); + auto &node_def = **statusor_node_def; + + CanonicalizeFunctionNameInNodeDef(symbol_table_, node_def); + + std::string node_def_text; + proto2::TextFormat::PrintToString(node_def, &node_def_text); + + auto op_kernel_runner = op_kernel_cache_.GetOrCreate( + tfrt::Location(nullptr, execute_key), node_def.op(), node_def.device(), + op->getNumOperands(), + [&](tensorflow::AttrValueMap *attr_value_map) { + *attr_value_map = node_def.attr(); + return OkStatus(); + }, + fallback_state_.device_manager(), + fallback_state_.process_function_library_runtime()); + LOG_IF(ERROR, !op_kernel_runner.ok()) << op_kernel_runner.status(); + + mlir::Value device; + if (auto custom_device = + 
op->getAttrOfType(kTfMlrtCustomDevice)) { + device = + CreateCustomDevice(op->getLoc(), custom_device.getValue(), rewriter); + if (!device) return op->emitWarning("Failed to create custom device."); + } + + mlir::Operation *new_op = nullptr; + if (op_kernel_runner.ok() && (*op_kernel_runner)->IsAsync()) { + // If it is an AsyncOpKernel, we lower it to tf_mlrt.async_executeop, + // which return !mlrt.futures. These results will be converted as + // necessary through the target materialization hook in the type + // converter. + llvm::SmallVector result_types( + op->getNumResults(), rewriter.getType()); + if (device) { + new_op = rewriter.replaceOpWithNewOp( + op, result_types, device, operands, node_def_text, execute_key); + } else { + new_op = rewriter.replaceOpWithNewOp( + op, result_types, operands, node_def_text, execute_key); + } + if (mlir::failed( + execute_op_registry_.RegisterExecuteOp(new_op, execute_key))) { + return op->emitWarning("Fail to register async op"); + } + } else { + // Otherwise, lower to tf_mlrt.executeop. + llvm::SmallVector result_types( + op->getNumResults(), rewriter.getType()); + if (device) { + new_op = rewriter.replaceOpWithNewOp( + op, result_types, device, operands, node_def_text, execute_key); + } else { + new_op = rewriter.replaceOpWithNewOp( + op, result_types, operands, node_def_text, execute_key); + } + + if (op_kernel_runner.ok()) { + // Only register this executeop if its opkernel can be created. + // Otherwise, it is an unused op so we don't need to create them at + // runtime. 
+ if (mlir::failed( + execute_op_registry_.RegisterExecuteOp(new_op, execute_key))) { + return op->emitWarning("Fail to register sync op"); + } + } + } + + return mlir::success(); + } + + private: + const mlir::SymbolTable &symbol_table_; + ExecuteOpRegistry &execute_op_registry_; + tfrt_stub::OpKernelRunnerCache &op_kernel_cache_; + const tfrt_stub::FallbackState &fallback_state_; +}; + +mlir::Value GetPredicate(mlir::Operation *op, mlir::Value cond_operand, + mlir::ConversionPatternRewriter &rewriter) { + return rewriter.create( + op->getLoc(), rewriter.getI1Type(), cond_operand); +} + +class CondOpConversion : public mlir::OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + mlir::LogicalResult matchAndRewrite( + mlir::TF::IfOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { + mlir::FlatSymbolRefAttr then_branch = op.getThenBranchAttr(); + mlir::FlatSymbolRefAttr else_branch = op.getElseBranchAttr(); + + llvm::SmallVector result_types( + op.getNumResults(), rewriter.getType()); + + auto bool_cond = GetPredicate(op, adaptor.getCond(), rewriter); + + auto new_op = rewriter.create( + op.getLoc(), result_types, bool_cond, adaptor.getInput(), then_branch, + else_branch); + + rewriter.replaceOp(op, new_op.getResults()); + + return mlir::success(); + } +}; + +// Convert TF WhileOp to mlrt.while. +// The pseudo code of mlrt.while is as follows: +// +// while(cond) { +// outputs, cond = body(inputs) +// inputs = outputs +// } +// return outputs, cond +// +// So we need to insert extra conversion kernels and merge functions when +// lowering tf.While to mlrt.while. 
+// +// %result = tf.While(%arg) {cond = @original_cond_fn, body = +// @original_body_fn} +// +// is converted to +// +// func @new_pred_fn(%arg) { +// %cond_tensor = func.call @original_cond_fn(%arg) +// %cond_bool = mlrt.predicate %cond_tensor +// return %cond_bool +// } +// +// func @new_while_body(%arg) { +// %result = func.call @original_body_fn(%arg) +// %cond_bool = func.call @new_pred_fn(%result) +// return%result, %cond_bool +// } +// +// %first_iter_cond = func.call @new_pred_fn(%arg) +// %result = mlrt.while %first_iter_cond @new_while_body(%arg) +// +class WhileOpConversion : public mlir::OpConversionPattern { + public: + WhileOpConversion(mlir::MLIRContext *context, + mlir::TypeConverter *type_converter, + mlir::SymbolTable *symbol_table) + : mlir::OpConversionPattern(*type_converter, context), + symbol_table_(*symbol_table) {} + + mlir::LogicalResult matchAndRewrite( + mlir::TF::WhileOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { + mlir::FlatSymbolRefAttr cond_fn = op.getCondAttr(); + mlir::FlatSymbolRefAttr body_fn = op.getBodyAttr(); + + // Create the predicate function that calls the original cond function and + // in addition convert the result to a boolean value. + mlir::func::FuncOp pred_fn = GetPredicateFunction( + op, cond_fn, adaptor.getOperands().getTypes(), rewriter); + if (!pred_fn) return mlir::failure(); + + // Insert a call op to call the pred function for the first iteration. + auto call_pred_fn = rewriter.create( + op.getLoc(), pred_fn.getFunctionType().getResults(), + pred_fn.getSymName(), adaptor.getOperands()); + + if (!call_pred_fn) return mlir::failure(); + + // Create the new while body function. + mlir::func::FuncOp new_body_fn = GetWhileBodyFunction( + op, body_fn, pred_fn, adaptor.getOperands().getTypes(), rewriter); + + // mlrt.while returns one more additional boolean value than tf.while. 
+ llvm::SmallVector while_result_types( + adaptor.getOperands().getTypes().begin(), + adaptor.getOperands().getTypes().end()); // = while_arg_types; + while_result_types.push_back(rewriter.getI1Type()); + auto new_op = rewriter.create( + op.getLoc(), while_result_types, call_pred_fn.getResult(0), + adaptor.getOperands(), new_body_fn.getSymName()); + + rewriter.replaceOp(op, new_op.getResults().drop_back()); + + return mlir::success(); + } + + private: + mlir::func::FuncOp GetPredicateFunction( + mlir::TF::WhileOp op, mlir::FlatSymbolRefAttr cond_fn, + mlir::TypeRange arg_types, + mlir::ConversionPatternRewriter &rewriter) const; + + mlir::func::FuncOp GetWhileBodyFunction( + mlir::TF::WhileOp op, mlir::FlatSymbolRefAttr body_fn, + mlir::func::FuncOp pred_fn, mlir::TypeRange arg_types, + mlir::ConversionPatternRewriter &rewriter) const; + + mlir::SymbolTable &symbol_table_; +}; + +// Create the pred function that contains a call to the original cond function +// and a predicate kernel that converts the cond tensor to a boolean value. eg. 
+// +// func @pred_fn( %arg) { +// %cond_tensor = tf_mlrt.call @original_cond_fn(%arg) +// %cond_bool = tf_mlrt.predicate %cond_tensor +// return %cond_bool +// } +// +mlir::func::FuncOp WhileOpConversion::GetPredicateFunction( + mlir::TF::WhileOp op, mlir::FlatSymbolRefAttr cond_fn, + mlir::TypeRange arg_types, + mlir::ConversionPatternRewriter &rewriter) const { + std::string pred_fn_name = + absl::StrCat(cond_fn.getValue().str(), "/tf_mlrt_predicate"); + + if (auto pred_fn = symbol_table_.lookup(pred_fn_name)) { + return pred_fn; + } + + auto func_op = op->getParentOfType(); + + mlir::ConversionPatternRewriter::InsertionGuard insertion_guard(rewriter); + rewriter.setInsertionPointAfter(func_op); + + auto func_type = rewriter.getFunctionType(arg_types, {rewriter.getI1Type()}); + + auto pred_fn = + rewriter.create(op.getLoc(), pred_fn_name, func_type); + + auto *block = pred_fn.addEntryBlock(); + rewriter.setInsertionPointToStart(block); + + auto call_cond_fn = rewriter.create( + op.getLoc(), arg_types.take_front(), cond_fn, block->getArguments()); + mlir::Value bool_cond = GetPredicate(op, call_cond_fn.getResult(0), rewriter); + rewriter.create(op.getLoc(), bool_cond); + + symbol_table_.insert(pred_fn); + + return pred_fn; +} + +// Create the new while body function that contains a call to original while +// body and then a call to the pred function. eg. 
+// +// func @while_body(%arg) { +// %result = mlrt.call @original_body(%arg) +// %cond_bool = mlrt.call @pred_function(%arg) +// mlrt.return %result, %cond_bool +// } +// +mlir::func::FuncOp WhileOpConversion::GetWhileBodyFunction( + mlir::TF::WhileOp op, mlir::FlatSymbolRefAttr original_body_fn, + mlir::func::FuncOp pred_fn, mlir::TypeRange arg_types, + mlir::ConversionPatternRewriter &rewriter) const { + std::string body_fn_name = + absl::StrCat(original_body_fn.getValue().str(), "/tf_mlrt_body"); + + if (auto body_fn = symbol_table_.lookup(body_fn_name)) { + return body_fn; + } + + auto func_op = op->getParentOfType(); + + mlir::ConversionPatternRewriter::InsertionGuard insertion_guard(rewriter); + rewriter.setInsertionPointAfter(func_op); + + llvm::SmallVector body_result_types(arg_types.begin(), + arg_types.end()); + // The last result of the while body function is the boolean condition. + body_result_types.push_back(rewriter.getI1Type()); + + auto func_type = rewriter.getFunctionType(arg_types, body_result_types); + auto body_fn = + rewriter.create(op.getLoc(), body_fn_name, func_type); + + auto *block = body_fn.addEntryBlock(); + rewriter.setInsertionPointToStart(block); + + // Insert a call to the original body function. + // The returned result type is also the original argument types. + auto call_original_body_fn = rewriter.create( + op.getLoc(), arg_types, original_body_fn, block->getArguments()); + + // Insert a call to the pred function, which contains a call to the original + // cond function and the predicate kernel that converts the tensor to boolean + // value. + auto call_pred_fn = rewriter.create( + op.getLoc(), pred_fn.getFunctionType().getResults(), pred_fn.getSymName(), + call_original_body_fn.getResults()); + + llvm::SmallVector body_results = + call_original_body_fn.getResults(); + + // The last result should be the boolean value converted from the condition. 
+ auto bool_cond = call_pred_fn.getResult(0); + body_results.push_back(bool_cond); + + rewriter.create(op.getLoc(), body_results); + + symbol_table_.insert(body_fn); + + return body_fn; +} + +class BatchFunctionOpConversion + : public mlir::OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + mlir::LogicalResult matchAndRewrite( + mlir::TF::BatchFunctionOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { + std::string node_name = GetNodeName(op); + + auto statusor_node_def = tensorflow::ConvertTFDialectOpToNodeDef( + op, node_name, /*ignore_unregistered_attrs=*/true); + if (!statusor_node_def.ok()) + return op->emitWarning("failed to export NodeDef."); + const auto &node_def = **statusor_node_def; + + std::string node_def_text; + proto2::TextFormat::PrintToString(node_def, &node_def_text); + + llvm::SmallVector result_types( + op->getNumResults(), rewriter.getType()); + + rewriter.replaceOpWithNewOp( + op, result_types, adaptor.getOperands(), node_def.device(), + op.getFAttr(), node_def_text); + + return mlir::success(); + } +}; + +void CreateFallbackInitializationFunction( + mlir::ModuleOp module, ExecuteOpRegistry &execute_op_registry) { + mlir::OpBuilder builder(&module.getBodyRegion()); + + auto func_op = builder.create( + module.getLoc(), "_tfrt_fallback_init", + mlir::FunctionType::get(module.getContext(), /*inputs=*/{}, + /*outputs=*/{})); + + auto *block = func_op.addEntryBlock(); + builder.setInsertionPointToStart(block); + + // Create operations for all fallback kernels in the module. + for (const auto &[op_index, op] : + llvm::enumerate(execute_op_registry.GetExecuteOps())) { + if (op) { + // There might be unused ops, and we don't need to create them at runtime. + // + // TODO(chky, deqiangc): Clean up unused ops before hand. 
+ builder.create( + func_op.getLoc(), /*resultTypes=*/mlir::TypeRange{}, + /*operands=*/mlir::ValueRange{}, op->getAttrs()); + } + } + + builder.create(func_op.getLoc()); +} + +// Move the tf_mlrt.await ops to right before their first uses to avoid +// unnecessary blocking. +void MoveAwaitOpToFirstUse(mlir::Block &block) { + llvm::SmallVector await_ops; + for (auto &op : block) { + if (auto await_op = llvm::dyn_cast(&op)) { + await_ops.push_back(await_op); + } + } + + for (auto op : await_ops) { + auto result = op.getResult(); + if (result.use_empty()) continue; + + mlir::Operation *first_user = *result.user_begin(); + for (auto *user : result.getUsers()) { + if (user->isBeforeInBlock(first_user)) { + first_user = user; + } + } + + op->moveBefore(first_user); + } +} + +const tfrt_stub::FallbackState &GetDefaultFallbackState() { + static const auto *const fallback_state = []() { + tensorflow::SessionOptions session_options; + tensorflow::FunctionDefLibrary fdef_lib; + auto fallback_state = + tfrt_stub::FallbackState::Create(session_options, fdef_lib).value(); + return fallback_state.release(); + }(); + + return *fallback_state; +} + +// The conversion pass that is run before 'tf-mlrt-parallelization' passes. The +// parallelization pass changes the graph content, so any rewrite/conversion +// that depends on the graph instead of individual ops should be done before +// parallelization. +class TfToMlrtPreParallelizationConversionPass + : public mlir::PassWrapper> { + public: + TfToMlrtPreParallelizationConversionPass() = default; + explicit TfToMlrtPreParallelizationConversionPass( + const TfrtPipelineOptions &options) { + // This is needed to progating user configs into this pass. 
+ options_.copyOptionValuesFrom(options); + } + TfToMlrtPreParallelizationConversionPass( + const TfToMlrtPreParallelizationConversionPass &other) {} + TfToMlrtPreParallelizationConversionPass &operator=( + const TfToMlrtPreParallelizationConversionPass &) = delete; + + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( + TfToMlrtPreParallelizationConversionPass) + + private: + void getDependentDialects(mlir::DialectRegistry ®istry) const override { + registry.insert(); + registry.insert(); + registry.insert(); + + RegisterTpuDialect(registry); + } + + llvm::StringRef getArgument() const final { + return "pre-parallel-tf-to-mlrt"; + } + llvm::StringRef getDescription() const final { + return "pre-parallel-tf-to-mlrt"; + } + + mlir::LogicalResult initialize(mlir::MLIRContext *context) override { + if (use_tpu_host_allocator_for_inputs_.hasValue()) { + options_.use_tpu_host_allocator_for_inputs = + use_tpu_host_allocator_for_inputs_; + } + + return mlir::success(); + } + + mlir::LogicalResult runOnFunction(mlir::func::FuncOp func) { + auto &context = getContext(); + mlir::ConversionTarget target(context); + mlir::RewritePatternSet patterns(&getContext()); + target.addLegalDialect(); + PopulateTpuPreParallelizationConversionPatterns(target, patterns, options_); + + return mlir::applyPartialConversion(func, target, std::move(patterns)); + } + + void runOnOperation() override { + auto module = getOperation(); + + for (auto func : module.getOps()) { + if (mlir::failed(runOnFunction(func))) { + signalPassFailure(); + return; + } + } + } + + Option use_tpu_host_allocator_for_inputs_{ + *this, "use-tpu-host-allocator-for-inputs", + llvm::cl::desc("If true, fallback executeops that produce inputs to tpu " + "program will use tpu host allocator."), + llvm::cl::init(false)}; + + TfrtPipelineOptions options_; +}; + +class TfToMlrtConversionPass + : public mlir::PassWrapper> { + public: + TfToMlrtConversionPass() + : TfToMlrtConversionPass({}, &GetDefaultFallbackState()) {} + 
explicit TfToMlrtConversionPass( + const TfrtPipelineOptions &options, + const tfrt_stub::FallbackState *fallback_state) + : fallback_state_(*fallback_state) { + // This is needed to progating user configs into this pass. + options_.copyOptionValuesFrom(options); + } + TfToMlrtConversionPass(const TfToMlrtConversionPass &other) + : fallback_state_(other.fallback_state_) {} + TfToMlrtConversionPass &operator=(const TfToMlrtConversionPass &) = delete; + + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TfToMlrtConversionPass) + + private: + void getDependentDialects(mlir::DialectRegistry ®istry) const override { + registry.insert(); + registry.insert(); + registry.insert(); + + RegisterTpuDialect(registry); + } + + llvm::StringRef getArgument() const final { return "tf-to-mlrt"; } + llvm::StringRef getDescription() const final { return "tf-to-mlrt"; } + + mlir::LogicalResult initialize(mlir::MLIRContext *context) override { + // TODO(b/285064425): See if this and below are the right way to + // accommodate other dialects. + type_converter_.addConversion([](mlir::Type type) { return type; }); + type_converter_.addConversion( + [=](mlir::TensorType type) -> std::optional { + // Ref types are not supported in both compiler and runtime. 
+ if (type.getElementType().isa()) + return std::nullopt; + return tf_mlrt::TFTensorType::get(context); + }); + + auto future_to_tensor_materialization = + [](mlir::OpBuilder &builder, mlir::Type desired_type, + mlir::ValueRange inputs, mlir::Location loc) -> mlir::Value { + if (inputs.size() != 1) return mlir::Value(); + + if (inputs[0].getType().isa()) { + if (desired_type.isa()) { + return builder.create(loc, desired_type, inputs[0]); + } + + return mlir::Value(); + } + + return inputs[0]; + }; + + type_converter_.addTargetMaterialization(future_to_tensor_materialization); + type_converter_.addArgumentMaterialization( + future_to_tensor_materialization); + type_converter_.addSourceMaterialization( + [](mlir::OpBuilder &builder, mlir::Type result_type, + mlir::ValueRange inputs, + mlir::Location loc) -> std::optional { + return builder + .create(loc, result_type, + inputs) + .getResult(0); + }); + + if (use_tpu_host_allocator_for_inputs_.hasValue()) { + options_.use_tpu_host_allocator_for_inputs = + use_tpu_host_allocator_for_inputs_; + } + + return mlir::success(); + } + + void runOnOperation() override { + auto module = getOperation(); + mlir::SymbolTable symbol_table(module); + + // Use llvm::make_early_inc_range instead of the stock range from + // module.getOps because conversions such as WhileOpConversion could insert + // new functions into the module ops list causing the stock range to not + // able to find next OP correctly. + for (auto func : + llvm::make_early_inc_range(module.getOps())) { + if (mlir::failed(runOnFunction(func, symbol_table))) { + signalPassFailure(); + return; + } + } + + // Some mlrt kernels such as tf_mlrt_tpu.CompileAndExecute produce futures, + // but function invoked by mlrt execute op are not aware of these changes. + // We add a post process to fix up this caller-callee mismatch. 
+ for (auto func : module.getOps()) { + CollectFunctionCallSiteInputTypes(func); + } + for (auto func : module.getOps()) { + if (mlir::failed(PostProcessFunctionSignature(func, symbol_table))) { + signalPassFailure(); + return; + } + // Move the tf_mlrt.await ops to right before their first uses to avoid + // unnecessary blocking. + MoveAwaitOpToFirstUse(func.getBlocks().front()); + } + + CreateFallbackInitializationFunction(module, execute_op_registry_); + + module.walk([&](mlir::UnrealizedConversionCastOp op) { + op->replaceAllUsesWith(op->getOperands()); + op->erase(); + }); + } + + mlir::LogicalResult PostProcessFunctionSignature( + mlir::func::FuncOp func, mlir::SymbolTable &symbol_table) { + mlir::ConversionTarget target(getContext()); + mlir::RewritePatternSet patterns(&getContext()); + + target.addDynamicallyLegalOp( + [this](mlir::func::FuncOp func) { + // By default, we assume callers are well behaved. + if (function_call_site_input_types_.find(func.getName()) == + function_call_site_input_types_.end()) { + return true; + } + DCHECK_EQ(function_call_site_input_types_.at(func.getName()).size(), + func.getFunctionType().getInputs().size()); + + for (auto [expected_input_type, call_site_type] : + llvm::zip(func.getFunctionType().getInputs(), + function_call_site_input_types_.at(func.getName()))) { + if (expected_input_type != call_site_type) { + return false; + } + } + return true; + }); + + patterns.add(&getContext(), &type_converter_, + &function_call_site_input_types_); + + return mlir::applyPartialConversion(func, target, std::move(patterns)); + } + + void CollectFunctionCallSiteInputTypes(mlir::func::FuncOp func) { + func.walk([&function_call_site_input_types = + function_call_site_input_types_]( + mlir::Operation *op) mutable { + // Only collect the call-site input types when a function is invoked + // by async op. This is the only known case that the previous pass + // may left un-match types between call-site and callee. 
+ if (auto async_op = llvm::dyn_cast(op)) { + function_call_site_input_types[async_op.getCallee() + .getLeafReference()] = + llvm::SmallVector(async_op.getOperandTypes().begin(), + async_op.getOperandTypes().end()); + } + }); + } + + mlir::LogicalResult runOnFunction(mlir::func::FuncOp func, + mlir::SymbolTable &symbol_table) { + auto &context = getContext(); + mlir::ConversionTarget target(context); + mlir::RewritePatternSet patterns(&getContext()); + target.addLegalDialect(); + target.addIllegalDialect(); + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + + target.addDynamicallyLegalOp( + [this](mlir::func::FuncOp op) { + return type_converter_.isSignatureLegal(op.getFunctionType()); + }); + target.addDynamicallyLegalOp( + [this](mlir::func::ReturnOp op) { + for (auto operand : op.getOperands()) { + if (!type_converter_.isLegal(operand.getType())) return false; + } + return true; + }); + target.addDynamicallyLegalOp( + [this](mlrt::compiler::AsyncOp op) { + for (auto operand : op.getOperands()) { + if (!type_converter_.isLegal(operand.getType())) return false; + } + return true; + }); + target.addDynamicallyLegalOp( + [this](mlir::func::CallOp op) { + for (auto operand : op.getOperands()) { + if (!type_converter_.isLegal(operand.getType())) return false; + } + return true; + }); + + // LINT.IfChange(fallback_allow_list) + // Order the list of added ops alphabetically. 
+ patterns.add(&context, &type_converter_, &symbol_table); + patterns.add(&context); + patterns.add(type_converter_, &context); + patterns.add(&context, &symbol_table, &type_converter_, + &execute_op_registry_, &op_kernel_cache_, + &fallback_state_); + patterns.add, + TFCallOpConversion, + TFCallOpConversion>(&context, + &type_converter_); + // LINT.ThenChange(util.cc:fallback_allow_list) + + mlir::populateFunctionOpInterfaceTypeConversionPattern( + patterns, type_converter_); + mlir::populateReturnOpTypeConversionPattern(patterns, type_converter_); + + PopulateTpuConversionPatterns(target, patterns, type_converter_, + execute_op_registry_, options_); + + return mlir::applyPartialConversion(func, target, std::move(patterns)); + } + + Option use_tpu_host_allocator_for_inputs_{ + *this, "use-tpu-host-allocator-for-inputs", + llvm::cl::desc("If true, fallback executeops that produce inputs to tpu " + "program will use tpu host allocator."), + llvm::cl::init(false)}; + + TfrtPipelineOptions options_; + mlir::TypeConverter type_converter_; + ExecuteOpRegistry execute_op_registry_; + tfrt_stub::OpKernelRunnerCache op_kernel_cache_; + const tfrt_stub::FallbackState &fallback_state_; + + // True input argument types for a given function at call site. 
+ llvm::DenseMap> + function_call_site_input_types_; +}; + +} // namespace + +std::unique_ptr> +CreateTfToMlrtPreParallelizationConversionPass( + const TfrtPipelineOptions &options) { + return std::make_unique(options); +} + +std::unique_ptr> +CreateTfToMlrtConversionPass(const TfrtPipelineOptions &options, + const tfrt_stub::FallbackState *fallback_state) { + return std::make_unique(options, fallback_state); +} + +std::unique_ptr> +CreateTfToMlrtConversionPass(const TfrtPipelineOptions &options) { + return CreateTfToMlrtConversionPass(options, &GetDefaultFallbackState()); +} + +} // namespace mlrt_compiler +} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tfrt/transforms/mlrt/tf_to_mlrt.h b/tensorflow/compiler/mlir/tfrt/transforms/mlrt/tf_to_mlrt.h new file mode 100644 index 00000000000..1206f66f72b --- /dev/null +++ b/tensorflow/compiler/mlir/tfrt/transforms/mlrt/tf_to_mlrt.h @@ -0,0 +1,48 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_MLRT_TF_TO_MLRT_H_ +#define TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_MLRT_TF_TO_MLRT_H_ +#include + +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tfrt/transforms/tfrt_pipeline_options.h" +#include "tensorflow/core/tfrt/fallback/fallback_state.h" + +namespace tensorflow { +namespace mlrt_compiler { + +// The conversion pass that is run before 'tf-mlrt-parallelization' passes. The +// parallelization pass changes the graph content, so any rewrite/conversion +// that depends on the graph instead of individual ops should be done before +// parallelization. +std::unique_ptr> +CreateTfToMlrtPreParallelizationConversionPass( + const TfrtPipelineOptions& options); + +// The conversion pass that is run after 'tf-mlrt-parallelization' passes. The +// parallelization pass changes the graph content, so this pass should only +// contain conversion that depends on individual ops. +std::unique_ptr> +CreateTfToMlrtConversionPass(const TfrtPipelineOptions& options); + +std::unique_ptr> +CreateTfToMlrtConversionPass(const TfrtPipelineOptions& options, + const tfrt_stub::FallbackState* fallback_state); + +} // namespace mlrt_compiler +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_MLRT_TF_TO_MLRT_H_ diff --git a/tensorflow/compiler/mlir/tfrt/transforms/mlrt/tpu_conversion_patterns.cc b/tensorflow/compiler/mlir/tfrt/transforms/mlrt/tpu_conversion_patterns.cc new file mode 100644 index 00000000000..0212c945de6 --- /dev/null +++ b/tensorflow/compiler/mlir/tfrt/transforms/mlrt/tpu_conversion_patterns.cc @@ -0,0 +1,169 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/compiler/mlir/tfrt/transforms/mlrt/tpu_conversion_patterns.h" + +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tfrt/ir/mlrt/mlrt_dialect.h" +#include "tensorflow/compiler/mlir/tfrt/ir/mlrt/tf_mlrt_ops.h" +#include "tensorflow/compiler/mlir/tfrt/ir/mlrt/tf_mlrt_tpu_ops.h" +#include "tensorflow/compiler/mlir/tfrt/transforms/mlrt/execute_op_registry.h" +#include "tensorflow/compiler/mlir/tfrt/transforms/utils.h" + +namespace tensorflow { +namespace mlrt_compiler { +namespace { + +class TPUCompileMlirAndExecuteOpPreParallelizationConversion + : public mlir::OpConversionPattern { + public: + TPUCompileMlirAndExecuteOpPreParallelizationConversion( + mlir::MLIRContext* context, bool use_tpu_host_allocator_for_inputs) + : OpConversionPattern(context), + use_tpu_host_allocator_for_inputs_(use_tpu_host_allocator_for_inputs) {} + + mlir::LogicalResult matchAndRewrite( + mlir::TF::TPUCompileMlirAndExecuteOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter& rewriter) const override { + llvm::SmallVector constant_operand_indices; + llvm::SmallVector non_constant_operand_indices; + + for (int i = 0; i < adaptor.getArgs().size(); ++i) { + auto operand = adaptor.getOperands()[i]; + auto original_operand = op.getOperand(i); + if (IsResultVariable(original_operand, operand)) { + // NOTE: It's important to populate constant_operand_indices in + // ascending order. 
+ constant_operand_indices.push_back(i); + } else { + non_constant_operand_indices.push_back(i); + } + } + + llvm::SmallVector operands = adaptor.getArgs(); + + size_t tensor_operands_size = operands.size(); + operands.append(adaptor.getStaticShapes().begin(), + adaptor.getStaticShapes().end()); + + auto producer_name = op->getAttrOfType("producer_name"); + + llvm::SmallVector operands_with_static_shapes; + if (adaptor.getOperandsWithStaticShape().has_value()) { + for (auto attr : adaptor.getOperandsWithStaticShapeAttr() + .getAsRange()) { + operands_with_static_shapes.push_back( + static_cast(attr.getInt())); + } + } + + if (use_tpu_host_allocator_for_inputs_) { + llvm::DenseMap replaced_ops; + + for (int i : non_constant_operand_indices) { + DCHECK_LT(i, op.getNumOperands()); + auto old_value = operands[i]; + mlir::Operation* def = old_value.getDefiningOp(); + + if (def && llvm::isa(def->getDialect())) { + auto*& op_with_device = replaced_ops[def]; + if (!op_with_device) { + mlir::ConversionPatternRewriter::InsertionGuard guard(rewriter); + rewriter.setInsertionPoint(def); + + op_with_device = rewriter.clone(*def); + op_with_device->setAttr(kTfMlrtCustomDevice, + rewriter.getStringAttr(kTpuHostDevice)); + rewriter.replaceOp(def, op_with_device->getResults()); + } + } + } + } + + auto compile_and_execute_op = + rewriter.create( + op.getLoc(), op.getResultTypes(), operands, + rewriter.getDenseI32ArrayAttr(constant_operand_indices), + op.getMetadataAttr(), op.getMlirModuleAttr(), + rewriter.getUI32IntegerAttr(tensor_operands_size), + rewriter.getDenseI32ArrayAttr(operands_with_static_shapes), + producer_name); + + rewriter.replaceOp(op, compile_and_execute_op->getResults()); + + return mlir::success(); + } + + private: + bool use_tpu_host_allocator_for_inputs_ = false; +}; + +class TPUCompileMlirAndExecuteOpConversion + : public mlir::OpConversionPattern { + public: + TPUCompileMlirAndExecuteOpConversion(mlir::TypeConverter* type_converter, + mlir::MLIRContext* 
context, + ExecuteOpRegistry* execute_op_registry) + : OpConversionPattern(*type_converter, context) {} + + mlir::LogicalResult matchAndRewrite( + tf_mlrt::TFTPUCompileAndExecuteOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter& rewriter) const override { + llvm::SmallVector operands = + adaptor.getOperandsAndStaticShapes(); + llvm::SmallVector result_types; + result_types.push_back(rewriter.getType()); + result_types.append(op.getResults().size(), + rewriter.getType()); + + auto compile_and_execute_op = + rewriter.create( + op.getLoc(), result_types, operands, op.getConstantOperandIndices(), + op.getMetadataAttr(), op.getMlirModuleAttr(), op.getNumOperands(), + op.getOperandsWithStaticShape(), op.getProducerName()); + + rewriter.replaceOp(op, compile_and_execute_op->getResults()); + + return mlir::success(); + } +}; + +} // namespace + +void PopulateTpuPreParallelizationConversionPatterns( + mlir::ConversionTarget& target, mlir::RewritePatternSet& patterns, + const TfrtPipelineOptions& options) { + target.addIllegalOp(); + patterns.add( + patterns.getContext(), options.use_tpu_host_allocator_for_inputs); +} + +void PopulateTpuConversionPatterns(mlir::ConversionTarget& target, + mlir::RewritePatternSet& patterns, + mlir::TypeConverter& type_converter, + ExecuteOpRegistry& execute_op_registry, + const TfrtPipelineOptions& options) { + target.addIllegalOp(); + target.addLegalDialect(); + + patterns.add( + &type_converter, patterns.getContext(), &execute_op_registry); +} + +void RegisterTpuDialect(mlir::DialectRegistry& registry) { + registry.insert(); +} + +} // namespace mlrt_compiler +} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tfrt/transforms/mlrt/tpu_conversion_patterns.h b/tensorflow/compiler/mlir/tfrt/transforms/mlrt/tpu_conversion_patterns.h new file mode 100644 index 00000000000..979b4b46033 --- /dev/null +++ b/tensorflow/compiler/mlir/tfrt/transforms/mlrt/tpu_conversion_patterns.h @@ -0,0 +1,45 @@ +/* Copyright 2021 The 
TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_MLRT_TPU_CONVERSION_PATTERNS_H_ +#define TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_MLRT_TPU_CONVERSION_PATTERNS_H_ + +#include "mlir/IR/DialectRegistry.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/Transforms/DialectConversion.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tfrt/transforms/mlrt/execute_op_registry.h" +#include "tensorflow/compiler/mlir/tfrt/transforms/tfrt_pipeline_options.h" + +namespace tensorflow { +namespace mlrt_compiler { + +inline constexpr char kTfMlrtCustomDevice[] = "tf_mlrt.custom_device"; +inline constexpr char kTpuHostDevice[] = "tpu_host_device"; + +void RegisterTpuDialect(mlir::DialectRegistry& registry); + +void PopulateTpuPreParallelizationConversionPatterns( + mlir::ConversionTarget& target, mlir::RewritePatternSet& patterns, + const TfrtPipelineOptions& options); + +void PopulateTpuConversionPatterns(mlir::ConversionTarget& target, + mlir::RewritePatternSet& patterns, + mlir::TypeConverter& type_converter, + ExecuteOpRegistry& execute_op_registry, + const TfrtPipelineOptions& options); + +} // namespace mlrt_compiler +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_MLRT_TPU_CONVERSION_PATTERNS_H_ diff --git 
a/tensorflow/compiler/mlir/tfrt/transforms/mlrt/util.cc b/tensorflow/compiler/mlir/tfrt/transforms/mlrt/util.cc new file mode 100644 index 00000000000..fb110fb01f2 --- /dev/null +++ b/tensorflow/compiler/mlir/tfrt/transforms/mlrt/util.cc @@ -0,0 +1,43 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/compiler/mlir/tfrt/transforms/mlrt/util.h" + +#include "mlir/IR/Operation.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.h.inc" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.h.inc" +#include "tensorflow/compiler/mlir/tensorflow/ir/tfrt_ops.h.inc" + +namespace tensorflow { +namespace mlrt_compiler { + +bool UseFallback(mlir::Operation *op) { + if (!llvm::isa(op->getDialect())) return false; + + // TODO(b/173017701): have a centralized place to hold the information + // whether a TF op should be lowered to FallbackExecute op. 
+ // LINT.IfChange(fallback_allow_list) + return !llvm::isa(op); + // LINT.ThenChange(tf_to_mlrt.cc:fallback_allow_list) +} + +} // namespace mlrt_compiler +} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tfrt/transforms/mlrt/util.h b/tensorflow/compiler/mlir/tfrt/transforms/mlrt/util.h new file mode 100644 index 00000000000..c47471f67cd --- /dev/null +++ b/tensorflow/compiler/mlir/tfrt/transforms/mlrt/util.h @@ -0,0 +1,30 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_MLRT_UTIL_H_ +#define TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_MLRT_UTIL_H_ + +#include "mlir/IR/Operation.h" // from @llvm-project + +namespace tensorflow { +namespace mlrt_compiler { + +// Use fallback by default for anything that does not have a native kernel +// with some exceptions. +bool UseFallback(mlir::Operation *op); + +} // namespace mlrt_compiler +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_MLRT_UTIL_H_ diff --git a/tensorflow/compiler/mlir/tfrt/transforms/mlrt/while_to_map_fn.cc b/tensorflow/compiler/mlir/tfrt/transforms/mlrt/while_to_map_fn.cc new file mode 100644 index 00000000000..a7975c40e1f --- /dev/null +++ b/tensorflow/compiler/mlir/tfrt/transforms/mlrt/while_to_map_fn.cc @@ -0,0 +1,944 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/compiler/mlir/tfrt/transforms/mlrt/while_to_map_fn.h" + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/log/check.h" +#include "absl/strings/str_cat.h" +#include "llvm/ADT/APInt.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/Casting.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Block.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/IRMapping.h" // from @llvm-project +#include "mlir/IR/Matchers.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/SymbolTable.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Transforms/RegionUtils.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tfrt/ir/mlrt/mlrt_dialect.h" +#include "tensorflow/compiler/mlir/tfrt/ir/mlrt/tf_mlrt_ops.h" + +namespace tensorflow { +namespace mlrt_compiler { 
+namespace { + +void RemoveIdentityOp(mlir::func::FuncOp func) { + auto &block = func.getBody().front(); + llvm::SmallVector identity_ops; + for (auto &op : block) { + if (auto identity_op = llvm::dyn_cast(&op)) { + identity_ops.push_back(identity_op); + } + } + + for (auto op : llvm::reverse(identity_ops)) { + op.getOutput().replaceAllUsesWith(op.getInput()); + } + + for (auto op : identity_ops) { + op->erase(); + } + + auto return_op = llvm::cast(block.getTerminator()); + + auto func_type = mlir::FunctionType::get( + func.getContext(), func.getArgumentTypes(), return_op.getOperandTypes()); + + func.setType(func_type); +} + +// tf.map_fn (https://www.tensorflow.org/api_docs/python/tf/map_fn) is converted +// to tf.while during lowering. tf.map_fn expects parallel execution of its body +// function but not all tf.while can guarantee parallel executions. The tf.while +// op that is converted from tf.map_fn has distinct programming patterns. This +// pass matches those patterns to convert applicable tf.while to tf_mlrt.map_fn +// for parallel execution of the body function. +// +// For example, tf.map_fn(fn, elems, ...) can be converted to the following: +// +// %tensor_list = "tf.TensorListReserve"(%per_iteration_shape, %max_iterations) +// +// %while_outputs:7 = "tf.While"(%loop_counter, +// %tensor_list_index, %other_args, %tensor_list) {body = @while_body, cond = +// @while_cond} +// +// %outputs = "tf.TensorListStack"(%while_outputs#2, %output_shape) +// +// in which +// +// while_cond: check loop_counter and tensor_list_index both smaller than +// max_iterations. +// +// while_body: loop_counter and tensor_list_index is incremented and returned; +// also gather input from elems based on un-incremented tensor_list_index, +// call fn and set output into a TensorList at tensor_list_index. +// +// This pass additionally assumes the following patterns to identify a tf.While +// that are converted from tf.map_fn: +// 1. 
Arguments have one loop_counter and one element_index that are initialized +// to be 0. +// 2. TensorList or TensorArray is reserved with max_iterations size. The +// max_iterations shall be a constant. +// 3. The predicate function check both loop_counter and element_index is less +// than max_iterations. +// 4. The body function increase loop_counter and element_index by 1 and use +// element_index to stores its result into tensor list or tensor array such +// that there is no overlap in write between iterations +// 5. The body function does not have side effects such that one iteration will +// impact the next iteration outside #4. +// +// After conversion, the pseudocode is +// +// %tensor_list = "tf.TensorListReserve"(%per_iteration_shape, %max_iterations) +// +// %updated_tensor_list = "tf_mlrt.map_fn" (%max_iterations, %tensor_list, +// %other_args) {body = @map_fn_body} +// +// %outputs = "tf.TensorListStack"(%updated_tensor_list, %output_shape) +// +// where +// +// tf_mlrt.map_fn leads to a blocking call and +// the argument list of tf_mlrt.map_fn is (%max_iterations, %tensor_list, +// tf.while's argument list minus loop_counter, tensor_list_index and +// tensor_list). tf_mlrt.map_fn is a block call and returns the updated tensor +// list. +// +// map_fn_body has an input signature of (%in_tensor_list_future, +// %out_tensor_list_promise, %loop_counter, %tensor_list_index, %other_args) and +// has not return values (the updated_tensor_list is delivered through +// %out_tensor_list_promise). +// +class WhileToMapFnPass + : public mlir::PassWrapper> { + public: + WhileToMapFnPass() = default; + WhileToMapFnPass &operator=(const WhileToMapFnPass &) = delete; + + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(WhileToMapFnPass) + + private: + struct LoopInfo { + // Argument indices in while op of key loop variables. 
+ int loop_counter = -1; + int element_index = -1; + std::vector tensor_list_or_flow_in; + // Max iteration may be passed in as an argument to while op. + std::optional max_iterations_arg_idx; + // Max itertions may be hard coded as constant inside while predicate + // function. + std::optional max_iterations_value; + // Defining Op of max_iterations. + mlir::Value max_iterations; + }; + + void getDependentDialects(mlir::DialectRegistry ®istry) const override { + registry.insert(); + registry.insert(); + } + + llvm::StringRef getArgument() const final { + return "tf-mlrt-while-to-map-fn"; + } + + llvm::StringRef getDescription() const final { + return "Convert tf.while to tf_mlrt.map_fn when possible for parallel " + "execution."; + } + + void runOnOperation() override { + mlir::ModuleOp module = getOperation(); + mlir::SymbolTable symbol_table(module); + + // Use make_early_inc_range because the processing might insert new node + // into the list + for (auto func_op : + llvm::make_early_inc_range(module.getOps())) { + MayConvertWhileToMapFn(func_op, symbol_table); + } + } + + // We match while op's predicate function and body function with known + // patterns from tf.map_fn. If matched, tf.while is converted to + // tf_mlrt.map_fn. + void MayConvertWhileToMapFn(mlir::func::FuncOp op, + mlir::SymbolTable &symbol_table) { + mlir::OpBuilder builder(op); + for (mlir::Operation &op : llvm::make_early_inc_range(op.front())) { + auto while_op = llvm::dyn_cast(&op); + if (!while_op) continue; + LoopInfo loop_info; + if (mlir::succeeded(MatchPredicate(while_op.getCondAttr(), symbol_table, + loop_info)) && + mlir::succeeded( + MatchBody(while_op.getBodyAttr(), symbol_table, loop_info)) && + mlir::succeeded(MatchInputSource(while_op, loop_info)) && + mlir::succeeded(MatchOutputUse(while_op, loop_info))) { + // Input, predicate function, body function and output are all following + // patterns, we can convert it to tf_mlrt.map_fn. 
+ mlir::func::FuncOp while_body_func = + symbol_table.lookup(while_op.getBody()); + auto map_fn_body_func = CreateMapFnBodyFunction( + builder, while_body_func, symbol_table, loop_info); + + mlir::OpBuilder::InsertionGuard guard(builder); + builder.setInsertionPointAfter(while_op); + std::vector invariant_arguments; + invariant_arguments.reserve(while_op->getNumOperands()); + + absl::flat_hash_set variant_arguments = {loop_info.loop_counter, + loop_info.element_index}; + variant_arguments.insert(loop_info.tensor_list_or_flow_in.begin(), + loop_info.tensor_list_or_flow_in.end()); + for (int i = 0; i < while_op->getNumOperands(); ++i) { + if (variant_arguments.contains(i)) { + continue; + } + invariant_arguments.push_back(while_op.getOperand(i)); + } + + llvm::SmallVector result_types; + llvm::SmallVector tensor_list_operands; + for (int i = 0; i < loop_info.tensor_list_or_flow_in.size(); ++i) { + tensor_list_operands.push_back( + while_op.getOperand(loop_info.tensor_list_or_flow_in[i])); + result_types.push_back( + while_op.getResult(loop_info.tensor_list_or_flow_in[i]) + .getType()); + } + + auto map_fn_op = builder.create( + while_op.getLoc(), result_types, loop_info.max_iterations, + tensor_list_operands, invariant_arguments, + map_fn_body_func.getSymName(), + loop_info.tensor_list_or_flow_in.size()); + + // MatchOutputUse already makes sure only the tensor_list or + // tensor_array output is used. 
+ absl::flat_hash_map old_arg_indx_to_new_index; + for (int i = 0; i < loop_info.tensor_list_or_flow_in.size(); ++i) { + old_arg_indx_to_new_index.insert( + {loop_info.tensor_list_or_flow_in[i], i}); + } + for (int i = 0; i < while_op.getResults().size(); ++i) { + if (!old_arg_indx_to_new_index.contains(i)) { + while_op.getResult(i).dropAllUses(); + } else { + while_op.getResult(i).replaceAllUsesWith( + map_fn_op.getResult()[old_arg_indx_to_new_index[i]]); + } + } + + while_op.erase(); + } + } + } + + // Match that (a) the tensor list or tensor array are reserved with + // max_iterations size such that parallel operations on tensor list or tensor + // array is thread safe; (b) loop_counter and element_index starts with 0. + // Also may identify source of max_iterations. + mlir::LogicalResult MatchInputSource(mlir::TF::WhileOp while_op, + LoopInfo &loop_info) { + // Element index and loop counter should start from 0. + if (!mlir::matchPattern(while_op.getOperand(loop_info.loop_counter), + mlir::m_Zero()) || + !mlir::matchPattern(while_op.getOperand(loop_info.element_index), + mlir::m_Zero())) { + return mlir::failure(); + } + + DCHECK_GE(loop_info.tensor_list_or_flow_in.size(), 1); + // Tensor list or a tensor array are reserved + + for (auto tensor_list_index : loop_info.tensor_list_or_flow_in) { + mlir::Operation *tensor_list_or_flow_in_defining_op = + while_op.getOperand(tensor_list_index).getDefiningOp(); + mlir::Operation *max_iterations = nullptr; + if (loop_info.max_iterations_arg_idx.has_value()) { + max_iterations = + while_op.getOperand(loop_info.max_iterations_arg_idx.value()) + .getDefiningOp(); + } + if (auto tensor_list_reserve = + llvm::dyn_cast( + tensor_list_or_flow_in_defining_op)) { + // Tensor list should resever for max_iterations. 
+ mlir::Operation *tensor_list_reserve_size = + tensor_list_reserve.getNumElements().getDefiningOp(); + + if (tensor_list_reserve_size != max_iterations) { + // if tensor list is not reserved by max_iteration variable, then + // another acceptable case is that both contain same constant values. + llvm::APInt reserved_cst; + if (!mlir::matchPattern(tensor_list_reserve_size, + mlir::m_ConstantInt(&reserved_cst)) || + !loop_info.max_iterations_value.has_value() || + reserved_cst.getZExtValue() != + loop_info.max_iterations_value.value()) { + return mlir::failure(); + } + } + // TensorListReserveOp has only one result and is already in used by + // while. + loop_info.max_iterations = tensor_list_reserve.getNumElements(); + } else if (auto tensor_array = llvm::dyn_cast( + tensor_list_or_flow_in_defining_op)) { + mlir::Operation *tensor_array_size = + tensor_array.getOperand().getDefiningOp(); + if (tensor_array_size != max_iterations) { + // if tensor array is not reserved by max_iteration variable, then + // another acceptable case is that both contain same constant values. + llvm::APInt reserved_cst; + if (!mlir::matchPattern(tensor_array_size, + mlir::m_ConstantInt(&reserved_cst)) || + !loop_info.max_iterations_value.has_value() || + reserved_cst.getZExtValue() != + loop_info.max_iterations_value.value()) { + return mlir::failure(); + } + } + + // Other than flow_in, the tensor array should be used by while as well. + if (!llvm::is_contained(while_op.getOperands(), + tensor_array.getHandle())) { + return mlir::failure(); + } + loop_info.max_iterations = tensor_array.getSize(); + } else { + return mlir::failure(); + } + } + return mlir::success(); + } + + // Match the map_attern that output of while op is subsequentially stacked. 
+ mlir::LogicalResult MatchOutputUse(mlir::TF::WhileOp &while_op, + const LoopInfo &loop_info) { + absl::flat_hash_set used_results; + used_results.insert(loop_info.tensor_list_or_flow_in.begin(), + loop_info.tensor_list_or_flow_in.end()); + for (int i = 0; i < while_op->getResults().size(); ++i) { + if (used_results.contains(i)) { + // Tensor list or flow in should be used next. + if (!while_op->getResult(i).hasOneUse()) { + return mlir::failure(); + } + } else { + // No other result should be used. + if (!while_op->getResult(i).use_empty()) { + return mlir::failure(); + } + } + } + + for (auto result_index : loop_info.tensor_list_or_flow_in) { + mlir::Operation *use_op = + *while_op->getResult(result_index).getUsers().begin(); + + if (!llvm::isa(use_op)) { + return mlir::failure(); + } + } + return mlir::success(); + } + + // Match that the while predicate function is doing just + // loop_counter < max iterations && element_index < max_iterations. + // Through this pattern, we also update the argument index of + // loop_counter, element_index and possibly max_iterations. + mlir::LogicalResult MatchPredicate(mlir::FlatSymbolRefAttr predicate_fn, + const mlir::SymbolTable &symbol_table, + LoopInfo &loop_info) { + mlir::func::FuncOp predicate_fn_op = + symbol_table.lookup(predicate_fn.getValue()); + + // The body of the predicate function should have two LessOp and one + // LogicalAndOp. It can optionally has IdentityOp and ToBoolOp. 
+ enum class PredicateBodyExpectingOp { + kExpectFirstLess, + kExpectSecondLess, + kExpectLogicalAnd, + kExpectTerminator + }; + std::vector less_ops; + less_ops.reserve(2); + PredicateBodyExpectingOp expecting_op = + PredicateBodyExpectingOp::kExpectFirstLess; + for (auto &body_op : predicate_fn_op.getBody().front()) { + switch (expecting_op) { + case PredicateBodyExpectingOp::kExpectFirstLess: + if (llvm::isa(body_op)) { + expecting_op = PredicateBodyExpectingOp::kExpectSecondLess; + less_ops.push_back(&body_op); + } else if (!llvm::isa( + body_op)) { + return mlir::failure(); + } + break; + case PredicateBodyExpectingOp::kExpectSecondLess: + if (llvm::isa(body_op)) { + expecting_op = PredicateBodyExpectingOp::kExpectLogicalAnd; + less_ops.push_back(&body_op); + } else if (!llvm::isa( + body_op)) { + return mlir::failure(); + } + break; + case PredicateBodyExpectingOp::kExpectLogicalAnd: + if (llvm::isa(body_op)) { + expecting_op = PredicateBodyExpectingOp::kExpectTerminator; + } else if (!llvm::isa(body_op)) { + return mlir::failure(); + } + break; + case PredicateBodyExpectingOp::kExpectTerminator: + if (!llvm::isa(body_op)) { + return mlir::failure(); + } + break; + default: + return mlir::failure(); + } + } + + // Identify loop_counter + int counter_index = -1; + auto counter_iter = + llvm::find(predicate_fn_op.getArguments(), less_ops[0]->getOperand(0)); + if (counter_iter != predicate_fn_op.getArguments().end()) { + counter_index = counter_iter->getArgNumber(); + if (!IsScalarOrUnrankedI32Tensor( + predicate_fn_op.getArgument(counter_index))) { + return mlir::failure(); + } + } + + // Find upper bound on loop_counter. + int max_iter_index_from_counter = -1; + int max_iter_value_from_counter = -1; + if (auto max_iter_iter = llvm::find(predicate_fn_op.getArguments(), + less_ops[0]->getOperand(1)); + max_iter_iter != predicate_fn_op.getArguments().end()) { + // Upper bound on loop_counter is from one argument. 
+ max_iter_index_from_counter = max_iter_iter->getArgNumber(); + // Argument has to be int32 + if (!IsScalarOrUnrankedI32Tensor( + predicate_fn_op.getArgument(max_iter_index_from_counter))) { + return mlir::failure(); + } + } else { + // If upper bound is not passed in, it has to be a constant + llvm::APInt value; + if (!mlir::matchPattern(less_ops[0]->getOperand(1).getDefiningOp(), + mlir::m_ConstantInt(&value))) { + return mlir::failure(); + } + max_iter_value_from_counter = value.getZExtValue(); + } + + // Identify element_index + int element_index = -1; + auto element_index_iter = + llvm::find(predicate_fn_op.getArguments(), less_ops[1]->getOperand(0)); + if (element_index_iter != predicate_fn_op.getArguments().end()) { + element_index = element_index_iter->getArgNumber(); + if (!IsScalarOrUnrankedI32Tensor( + predicate_fn_op.getArgument(element_index))) { + return mlir::failure(); + } + } + + // Find upper bound on element_index. + int max_iter_index_from_element = -1; + int max_iter_value_from_element = -1; + if (auto max_iter_iter = llvm::find(predicate_fn_op.getArguments(), + less_ops[1]->getOperand(1)); + max_iter_iter != predicate_fn_op.getArguments().end()) { + // Upper bound on element_index is from one argument. + max_iter_index_from_element = max_iter_iter->getArgNumber(); + // Upper bound argument needs to be int32 + if (!IsScalarOrUnrankedI32Tensor( + predicate_fn_op.getArgument(max_iter_index_from_element))) { + return mlir::failure(); + } + } else { + // If upper bound is not passed in, it has to be a constant + llvm::APInt value; + if (!mlir::matchPattern(less_ops[1]->getOperand(1).getDefiningOp(), + mlir::m_ConstantInt(&value))) { + return mlir::failure(); + } + max_iter_value_from_element = value.getZExtValue(); + } + + // Loop_counter is always available. + if (counter_index < 0) return mlir::failure(); + // element_index can change its location, but will always be provided. 
+ if (element_index < 0) return mlir::failure(); + + std::optional max_iter_const; + std::optional max_iter_index; + if (max_iter_index_from_counter < 0 && max_iter_index_from_element < 0) { + // If both loop counter and element index are not upper bounded by passing + // in arguments, they shall be upper bounded by constants of same value. + if (max_iter_value_from_element != max_iter_value_from_counter || + max_iter_value_from_element < 0 || max_iter_value_from_counter < 0) { + return mlir::failure(); + } else { + max_iter_const = max_iter_value_from_element; + } + } else if (max_iter_index_from_counter >= 0 && + max_iter_index_from_element >= 0) { + // Loop counter or element are upper bounded by pass-in arguments. + // They need to be upper bounded by the same argument + if (max_iter_index_from_element != max_iter_index_from_counter) { + return mlir::failure(); + } else { + max_iter_index = max_iter_index_from_counter; + } + } else { + // TODO(deqiangc): remove this clause after verifying grappler pass remove + // the case that one of them is bounded by pass-in argument and the other + // is bounded by constants. + max_iter_index = + std::max(max_iter_index_from_counter, max_iter_index_from_element); + max_iter_const = + std::max(max_iter_value_from_element, max_iter_value_from_counter); + } + + // Update hypothesis + loop_info.loop_counter = counter_index; + loop_info.element_index = element_index; + loop_info.max_iterations_arg_idx = max_iter_index; + loop_info.max_iterations_value = max_iter_const; + return mlir::success(); + } + + // Match that the current hypothesis of current loop_counter and element_index + // in the while body function based on the following simple pattern: + // %updated_loop_counter = %loop_counter + 1 + // %updated_element_index = %element_index + 1 + // %loaded_elem = tf.Gather(.., %element_index,... 
) + // DoSomething + // tf.TensorListSetItem(.., %element_index) + // return %update_loop_counter, %updated_element_index, + // %tensor_array_list, %max_iterations, %other_args + mlir::LogicalResult MatchLoopCounterElementIndexInBody( + mlir::func::FuncOp while_body_func, LoopInfo &loop_info) { + mlir::Block &block = while_body_func.getBlocks().front(); + + // Verify argument loop_counter is +1 and returned at the same location. + mlir::BlockArgument loop_counter = + block.getArgument(loop_info.loop_counter); + llvm::SmallVector loop_counter_users = + GetUsersIgnoringIdentityOp(loop_counter); + if (loop_counter_users.size() != 1 || + !llvm::isa( + loop_counter_users.front()) || + !mlir::matchPattern( + loop_counter_users.front()->getOperand(1).getDefiningOp(), + mlir::m_One())) { + return mlir::failure(); + } + + // loop_counter + 1 is in ReturnOp's operand. + if (loop_counter_users.front() != + GetDefiningOpIgnoringIdentityOp( + GetReturnedOperand(while_body_func, loop_info.loop_counter))) { + return mlir::failure(); + } + + // Verify element_index's usage and also identify the argument index of + // tensor list or tensor array flow_in. + std::vector tensor_list_or_flow_in_index; + mlir::BlockArgument element_index = + block.getArgument(loop_info.element_index); + for (auto *element_index_use : GetUsersIgnoringIdentityOp(element_index)) { + if (llvm::isa(element_index_use)) { + // One use of element_index is +1 and then returned at the same + // location. 
+ if (!mlir::matchPattern( + element_index_use->getOperand(1).getDefiningOp(), + mlir::m_One()) || + element_index_use != + GetDefiningOpIgnoringIdentityOp(GetReturnedOperand( + while_body_func, loop_info.element_index))) { + return mlir::failure(); + } + } else if (llvm::isa(element_index_use)) { + if (auto tensor_list_index = MayGetArgumentIndexIgnoringIdentityOp( + while_body_func, + llvm::dyn_cast(element_index_use) + .getInputHandle()); + !tensor_list_index.has_value()) { + return mlir::failure(); + } else { + tensor_list_or_flow_in_index.push_back(tensor_list_index.value()); + } + } else if (llvm::isa(element_index_use)) { + if (auto flow_in_index = MayGetArgumentIndexIgnoringIdentityOp( + while_body_func, llvm::dyn_cast( + element_index_use) + .getFlowIn()); + !flow_in_index.has_value()) { + return mlir::failure(); + } else { + tensor_list_or_flow_in_index.push_back(flow_in_index.value()); + } + } else if (!llvm::isa(element_index_use)) { + // The only other use is to either gather the input or set output. + return mlir::failure(); + } + } + + if (tensor_list_or_flow_in_index.empty()) { + return mlir::failure(); + } + + // Update hypothesis + loop_info.tensor_list_or_flow_in = std::move(tensor_list_or_flow_in_index); + + return mlir::success(); + } + + // Match that the while body function is the following simple pattern: + // %updated_loop_counter = %loop_counter + 1 + // %updated_element_index = %element_index + 1 + // %loaded_elem = tf.Gather(.., %element_index,... ) + // DoSomething + // tf.TensorListSetItem(.., %element_index) + // return %update_loop_counter, %updated_element_index, + // %tensor_array_list, %max_iterations, %other_args + // + // in which + // DoSomething has no side-effect on the next iteration. + // + // Also identify argument index for TensorList or TensorArray flow_in. 
+ mlir::LogicalResult MatchBody(mlir::FlatSymbolRefAttr while_body_func_name, + const mlir::SymbolTable &symbol_table, + LoopInfo &loop_info) { + mlir::func::FuncOp while_body_func = + symbol_table.lookup( + while_body_func_name.getValue()); + + if (mlir::failed( + MatchLoopCounterElementIndexInBody(while_body_func, loop_info))) { + // Swap the order of loop_counter and element_index in the current + // hypothesis and try again + int swap = loop_info.loop_counter; + loop_info.loop_counter = loop_info.element_index; + loop_info.element_index = swap; + if (mlir::failed( + MatchLoopCounterElementIndexInBody(while_body_func, loop_info))) { + return mlir::failure(); + } + } + + // The next iteration of while_body does not depend on the previous + // iteration except loop_counter, element_index, tensor_list_or_flow_in, and + // max_iterations. + absl::flat_hash_set allowed_variable_between_iterations; + allowed_variable_between_iterations.insert(loop_info.loop_counter); + allowed_variable_between_iterations.insert(loop_info.element_index); + if (loop_info.max_iterations_arg_idx.has_value()) { + allowed_variable_between_iterations.insert( + loop_info.max_iterations_arg_idx.value()); + } + allowed_variable_between_iterations.insert( + loop_info.tensor_list_or_flow_in.begin(), + loop_info.tensor_list_or_flow_in.end()); + for (int j = 0; j < while_body_func.getNumArguments(); j++) { + if (!allowed_variable_between_iterations.contains(j)) { + if (GetReturnedOperand(while_body_func, j) != + while_body_func.getArgument(j)) { + return mlir::failure(); + } + } + } + + return mlir::success(); + } + + // The map_fn body function is a clone of the while_body_func that + // canonicalize loop_counter and tensor_list_index to be the first two + // arguments. 
+ mlir::func::FuncOp CreateMapFnBodyFunction(mlir::OpBuilder &builder, + mlir::func::FuncOp while_body_func, + mlir::SymbolTable &symbol_table, + const LoopInfo &loop_info) { + std::string map_fn_body_name = + absl::StrCat(while_body_func.getSymName().str(), "/MapFnBody"); + + if (auto func = symbol_table.lookup(map_fn_body_name)) { + return func; + } + + RemoveIdentityOp(while_body_func); + + absl::flat_hash_set variant_arguments = {loop_info.loop_counter, + loop_info.element_index}; + variant_arguments.insert(loop_info.tensor_list_or_flow_in.begin(), + loop_info.tensor_list_or_flow_in.end()); + llvm::SmallVector remapped_input_type; + + for (int i = 0; i < loop_info.tensor_list_or_flow_in.size(); i++) { + remapped_input_type.push_back( + builder.getType()); + remapped_input_type.push_back( + builder.getType()); + } + + remapped_input_type.push_back( + while_body_func.getFunctionType().getInput(loop_info.loop_counter)); + remapped_input_type.push_back( + while_body_func.getFunctionType().getInput(loop_info.element_index)); + for (int i = 0; i < while_body_func.getFunctionType().getNumInputs(); i++) { + if (!variant_arguments.contains(i)) { + remapped_input_type.push_back( + while_body_func.getFunctionType().getInput(i)); + } + } + mlir::OpBuilder::InsertionGuard insertion_guard(builder); + builder.setInsertionPointAfter(while_body_func); + auto map_fn_body_func = builder.create( + while_body_func.getLoc(), map_fn_body_name, + mlir::FunctionType::get(while_body_func.getContext(), + remapped_input_type, {})); + + map_fn_body_func->setAttr( + "tfrt.cost_threshold", + builder.getI64IntegerAttr(std::numeric_limits::max())); + + if (while_body_func.getArgAttrs().has_value()) { + llvm::SmallVector remapped_input_attributes; + // No attributes carry over for tensor list future/promise. 
+ for (int i = 0; i < loop_info.tensor_list_or_flow_in.size(); i++) { + remapped_input_attributes.push_back(mlir::Attribute()); + remapped_input_attributes.push_back(mlir::Attribute()); + } + auto args_attrs = while_body_func.getArgAttrs().value(); + remapped_input_attributes.push_back(args_attrs[loop_info.loop_counter]); + remapped_input_attributes.push_back(args_attrs[loop_info.element_index]); + for (int i = 0; i < args_attrs.size(); i++) { + if (!variant_arguments.contains(i)) { + remapped_input_attributes.push_back(args_attrs[i]); + } + } + map_fn_body_func.setAllArgAttrs(remapped_input_attributes); + } + auto future_index = [](int i) { return 2 * i; }; + auto promise_index = [](int i) { return 2 * i + 1; }; + + if (while_body_func.getResAttrs().has_value()) { + // The order and types of results remain the same; so does attributes. + map_fn_body_func.setAllResultAttrs(while_body_func.getResAttrs().value()); + } + map_fn_body_func.setVisibility(mlir::func::FuncOp::Visibility::Private); + + builder.setInsertionPointToEnd(map_fn_body_func.addEntryBlock()); + + mlir::IRMapping mapping; + std::vector await_ops; + for (int i = 0; i < loop_info.tensor_list_or_flow_in.size(); i++) { + await_ops.push_back(builder.create( + while_body_func.getLoc(), + while_body_func.getArgument(loop_info.tensor_list_or_flow_in.at(i)) + .getType(), + map_fn_body_func.getArgument(future_index(i)))); + + mapping.map( + while_body_func.getArgument(loop_info.tensor_list_or_flow_in.at(i)), + await_ops.at(i)); + } + // Rest of argument start after promise + int map_fn_argument_index = + promise_index(loop_info.tensor_list_or_flow_in.size() - 1); + mapping.map(while_body_func.getArgument(loop_info.loop_counter), + map_fn_body_func.getArgument(++map_fn_argument_index)); + mapping.map(while_body_func.getArgument(loop_info.element_index), + map_fn_body_func.getArgument(++map_fn_argument_index)); + for (int i = 0; i < while_body_func.getNumArguments(); i++) { + if (!variant_arguments.contains(i)) 
{ + mapping.map(while_body_func.getArgument(i), + map_fn_body_func.getArgument(++map_fn_argument_index)); + } + } + + for (auto &op : while_body_func.getBody().front()) { + builder.clone(op, mapping); + } + + auto return_op = map_fn_body_func.getBody().front().getTerminator(); + + mlir::Operation *first_write = nullptr; + // Move tensor list write to the end of the block. + for (int index : loop_info.tensor_list_or_flow_in) { + auto *def = return_op->getOperand(index).getDefiningOp(); + CHECK(def); // Crash OK + def->moveBefore(return_op); + if (!first_write) first_write = def; + } + + // Move the await op before the first write. + for (auto tensor_list_or_flow_in : await_ops) { + tensor_list_or_flow_in->moveBefore(first_write); + } + + // Insert promise right before return + builder.setInsertionPoint(return_op); + for (int i = 0; i < await_ops.size(); i++) { + builder.create( + return_op->getLoc(), map_fn_body_func.getArgument(promise_index(i)), + return_op->getOperand(loop_info.tensor_list_or_flow_in.at(i))); + } + builder.create(return_op->getLoc()); + return_op->erase(); + + symbol_table.insert(map_fn_body_func); + + return map_fn_body_func; + } + + std::optional MayGetArgumentIndexIgnoringIdentityOp( + mlir::func::FuncOp func, mlir::Value value) const { + // Value may go through some identify chains. + while (value.getDefiningOp()) { + if (!llvm::isa(value.getDefiningOp())) { + return std::nullopt; + } + value = value.getDefiningOp()->getOperand(0); + } + + // Value is directly from argument since it has no defining op. + auto argument_iter = llvm::find(func.getArguments(), value); + if (argument_iter == func.getArguments().end()) { + return std::nullopt; + } + return argument_iter->getArgNumber(); + } + + // Given a value, find its use ignoring identify op. 
+ // For example, given the below chains: + // + // %original_value = OriginalDefinedOp() + // %value1 = tf.IdentifyOp(original_value) + // %value2 = tf.IdentifyOp(value1) + // UseOp(%value2) + // + // GetUseIgnroningIdentifyOp(%original_value) will return UseOp + llvm::SmallVector GetUsersIgnoringIdentityOp( + mlir::Value value) { + llvm::SmallVector users; + std::vector users_stack; + + for (auto *direct_user : value.getUsers()) { + users_stack.push_back(direct_user); + } + + while (!users_stack.empty()) { + mlir::Operation *descendent_user = users_stack.back(); + users_stack.pop_back(); + + if (!llvm::isa(descendent_user)) { + users.push_back(descendent_user); + } else { + // User of identify op is considered as user. + for (auto *user : descendent_user->getResult(0).getUsers()) { + users_stack.push_back(user); + } + } + } + return users; + } + + // Given a value, find its source defined op ignoring identify op. + // For example, given the below chains: + // + // %original_value = OriginalDefinedOp() + // %value1 = tf.IdentifyOp(original_value) + // %value2 = tf.IdentifyOp(value1) + // UseOp(%value2) + // + // GetDefiningOpIgnroningIdentifyOp(%value2) will return OriginalDefinedOp + mlir::Operation *GetDefiningOpIgnoringIdentityOp(mlir::Value value) { + mlir::Operation *source_op = value.getDefiningOp(); + while (llvm::isa(source_op)) { + source_op = source_op->getOperand(0).getDefiningOp(); + } + return source_op; + } + + mlir::Value GetReturnedOperand(const mlir::func::FuncOp func, + uint32_t result_index) { + auto return_op = llvm::dyn_cast( + func->getRegion(0).front().getTerminator()); + DCHECK_NE(return_op, nullptr); + return return_op->getOperand(result_index); + } + + bool IsScalarI32Tensor(mlir::Value value) const { + if (auto value_type = llvm::dyn_cast(value.getType())) { + if (value_type.getElementType().isInteger(32) && value_type.hasRank() && + value_type.getRank() == 0) { + return true; + } + } + return false; + } + + bool 
IsScalarOrUnrankedI32Tensor(mlir::Value value) const { + if (auto value_type = llvm::dyn_cast(value.getType())) { + if (value_type.getElementType().isInteger(32) && + ((value_type.hasRank() && value_type.getRank() == 0) || + !value_type.hasRank())) { + return true; + } + } + return false; + } +}; +} // namespace + +std::unique_ptr> CreateWhileToMapFnPass() { + return std::make_unique(); +} + +} // namespace mlrt_compiler +} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tfrt/transforms/mlrt/while_to_map_fn.h b/tensorflow/compiler/mlir/tfrt/transforms/mlrt/while_to_map_fn.h new file mode 100644 index 00000000000..a45c03871c7 --- /dev/null +++ b/tensorflow/compiler/mlir/tfrt/transforms/mlrt/while_to_map_fn.h @@ -0,0 +1,31 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_MLRT_WHILE_TO_MAP_FN_H_ +#define TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_MLRT_WHILE_TO_MAP_FN_H_ + +#include + +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project + +namespace tensorflow { +namespace mlrt_compiler { + +std::unique_ptr> CreateWhileToMapFnPass(); + +} // namespace mlrt_compiler +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_MLRT_WHILE_TO_MAP_FN_H_ diff --git a/tensorflow/compiler/mlir/tfrt/transforms/passes.cc b/tensorflow/compiler/mlir/tfrt/transforms/passes.cc index 2eda5bfd0e9..bf27bac6ffb 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/passes.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/passes.cc @@ -24,6 +24,7 @@ limitations under the License. #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" +#include "tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_asset_sinking_pass.h" #include "tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.h" #include "tensorflow/compiler/mlir/tfrt/transforms/set_shape_invariant_in_while_ops.h" #include "tensorflow/compiler/mlir/tfrt/transforms/tfrt_jitrt_stub.h" @@ -207,6 +208,11 @@ void CreateTFExecutorToTFInvariantOptimizationPipelineHelper( pm.addPass(CreateSinkInInvariantOpsPass()); } + if (!options.saved_model_dir.empty()) { + pm.addPass( + mlir::tf_saved_model::CreateAssetSinkingPass(options.saved_model_dir)); + } + pm.addPass(CreateLowerTFSavedModelPass( options.hoist_invariant_ops, options.fuse_get_resource_ops_in_hoisting)); } diff --git a/tensorflow/compiler/mlir/tfrt/transforms/tfrt_pipeline_options.h b/tensorflow/compiler/mlir/tfrt/transforms/tfrt_pipeline_options.h index 24d245b1714..a4c62f8bf20 100644 --- 
a/tensorflow/compiler/mlir/tfrt/transforms/tfrt_pipeline_options.h +++ b/tensorflow/compiler/mlir/tfrt/transforms/tfrt_pipeline_options.h @@ -24,6 +24,8 @@ namespace tensorflow { struct TfrtPipelineOptions : public mlir::PassPipelineOptions { + Option saved_model_dir{*this, "saved-model-dir", + llvm::cl::desc(""), llvm::cl::init("")}; Option default_device{ *this, "default-device", llvm::cl::desc("default device assignment"), llvm::cl::init("/job:localhost/replica:0/task:0/device:CPU:0")}; diff --git a/tensorflow/compiler/mlir/tfrt/transforms/update_op_cost_in_tfrt_mlir.cc b/tensorflow/compiler/mlir/tfrt/transforms/update_op_cost_in_tfrt_mlir.cc index d4e39782e2c..b14882d9ca8 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/update_op_cost_in_tfrt_mlir.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/update_op_cost_in_tfrt_mlir.cc @@ -39,8 +39,8 @@ void UpdateOpCostInTfrtMlir(mlir::ModuleOp op, if (!op_key_attr) return; // Set the cost attr with a new value. const int64_t op_key = op_key_attr.getInt(); - op->setAttr(kCostAttrName, builder.getI64IntegerAttr( - cost_recorder.GetCostNanosecond(op_key))); + op->setAttr(kCostAttrName, + builder.getI64IntegerAttr(cost_recorder.GetCost(op_key))); }); } diff --git a/tensorflow/compiler/mlir/tfrt/translate/import_model.cc b/tensorflow/compiler/mlir/tfrt/translate/import_model.cc index 1573306ba2d..9e93aa345ad 100644 --- a/tensorflow/compiler/mlir/tfrt/translate/import_model.cc +++ b/tensorflow/compiler/mlir/tfrt/translate/import_model.cc @@ -70,7 +70,7 @@ StatusOr> ExportXlaFunctions(mlir::ModuleOp module) { const auto func_op = symbol_table.lookup(func_name); if (!func_op) { - return tensorflow::errors::Internal( + return absl::InternalError( absl::StrCat("Function ", func_name, " is not found.")); } FunctionDef func_def; @@ -92,6 +92,14 @@ StatusOr> ExportXlaFunctions(mlir::ModuleOp module) { } } }); + + // Remove the function from the module, as it will be handled by XLA. 
+ // It is safe to remove the function, i.e., the function won't be invoked on + // CPU. This is because bridge guarantees that each function has only one + // use. We don't replace the uses of the function, because we iterate from + // the root caller and hence its uses should have been removed. + func_op->erase(); + visited.insert(func_name); } return xla_func_defs; @@ -111,9 +119,9 @@ Status ConvertFunctionToBef( tensorflow::ConvertFunctionToMlir(fbody, flib_def, &context); if (!expected_module.ok()) - return tensorflow::errors::Internal( + return absl::InternalError(absl::StrCat( "Failed to convert function to mlir for function ", function_name.str(), - ". Error: ", expected_module.status().message()); + ". Error: ", expected_module.status().message())); auto module = std::move(expected_module).value(); @@ -152,7 +160,7 @@ Status ConvertTfMlirToRuntimeExecutable( tensorflow::RunTPUBackwardCompatConversion(module, tpu_compile_options); if (mlir::failed(backward_compat_result)) { return diag_handler.Combine( - tensorflow::errors::Internal("Failed to handle legacy TPU Ops")); + absl::InternalError("Failed to handle legacy TPU Ops")); } if (VLOG_IS_ON(1)) { @@ -165,7 +173,7 @@ Status ConvertTfMlirToRuntimeExecutable( auto tpu_partitioned_call_fallback_compat_result = tensorflow::RunTPUPartitionedCallFallbackCompatConversion(module); if (mlir::failed(tpu_partitioned_call_fallback_compat_result)) { - return diag_handler.Combine(tensorflow::errors::Internal( + return diag_handler.Combine(absl::InternalError( "Failed to process TPUPartitionedCallOp for fallback execution")); } } else if (options.device_target == TfrtDeviceInfraTarget::kGpu && @@ -222,7 +230,7 @@ Status ConvertTfMlirToBef(const TfrtCompileOptions& options, if (VLOG_IS_ON(1)) { tensorflow::DumpMlirOpToFile("tf_to_corert_failure", module); } - return diag_handler.Combine(tensorflow::errors::Internal( + return diag_handler.Combine(absl::InternalError( "failed to lower TF Dialect to CoreRT dialect.")); } @@ 
-230,7 +238,7 @@ Status ConvertTfMlirToBef(const TfrtCompileOptions& options, tfrt::ConvertMLIRToBEF(module, /*disable_optional_sections=*/true); if (bef_buffer->empty()) return diag_handler.Combine( - tensorflow::errors::Internal("failed to convert MLIR to BEF.")); + absl::InternalError("failed to convert MLIR to BEF.")); bef_buffer->shrink_to_fit(); return OkStatus(); @@ -241,6 +249,9 @@ Status ConvertTfMlirToBef(const TfrtCompileOptions& options, std::unique_ptr GetTfrtPipelineOptions( const TfrtCompileOptions& options) { auto pipeline_options = std::make_unique(); + + pipeline_options->saved_model_dir = options.saved_model_dir; + if (!options.default_device.empty()) { pipeline_options->default_device = options.default_device; } diff --git a/tensorflow/compiler/mlir/tfrt/translate/mlrt/BUILD b/tensorflow/compiler/mlir/tfrt/translate/mlrt/BUILD new file mode 100644 index 00000000000..0d5fac70054 --- /dev/null +++ b/tensorflow/compiler/mlir/tfrt/translate/mlrt/BUILD @@ -0,0 +1,71 @@ +load( + "//tensorflow:tensorflow.bzl", + "tf_cc_test", +) + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = [ + # copybara:uncomment "//learning/brain/experimental/tfrt:__subpackages__", + # copybara:uncomment "//learning/infra/mira/distributed:__subpackages__", + # copybara:uncomment "//smartass/brain/ops/tfrt_kernels:__subpackages__", + "//tensorflow/compiler/mlir/tfrt/transforms/mlrt:__subpackages__", + "//tensorflow/core/tfrt:__subpackages__", + ], +) + +cc_library( + name = "mlir_to_bytecode", + srcs = ["mlir_to_bytecode.cc"], + hdrs = ["mlir_to_bytecode.h"], + deps = [ + "//tensorflow/core/tfrt/mlrt/bytecode", + "//tensorflow/core/tfrt/mlrt/bytecode:executable", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status:statusor", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncDialect", + 
"@llvm-project//mlir:IR", + ], +) + +tf_cc_test( + name = "mlir_to_bytecode_test", + srcs = ["mlir_to_bytecode_test.cc"], + data = glob(["testdata/**"]), + deps = [ + ":mlir_to_bytecode", + "//tensorflow/core/tfrt/mlrt/bytecode:executable", + "//tensorflow/core/tfrt/mlrt/interpreter:attribute_span", + "//tensorflow/tsl/platform:resource_loader", + "@com_google_googletest//:gtest_main", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:Parser", + ], +) + +cc_library( + name = "test_utils", + testonly = 1, + srcs = ["test_utils.cc"], + hdrs = ["test_utils.h"], + deps = [ + "//learning/brain/experimental/tfrt/native_lowering/kernels:sync_context", + "//tensorflow/core:core_cpu_base", + "//tensorflow/core:framework", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/tfrt/mlrt/attribute", + "//tensorflow/core/tfrt/mlrt/bytecode", + "//tensorflow/core/tfrt/mlrt/bytecode:kernel", + "//tensorflow/core/tfrt/mlrt/interpreter:context", + "//tensorflow/core/tfrt/mlrt/interpreter:interpreter_testutil", + "//tensorflow/core/tfrt/mlrt/interpreter:value", + "//tensorflow/core/tfrt/utils:tensor_util", + "//tensorflow/tsl/platform:errors", + "@com_google_absl//absl/status:statusor", + "@tf_runtime//:hostcontext", + "@tf_runtime//:tensor", + ], +) diff --git a/tensorflow/compiler/mlir/tfrt/translate/mlrt/mlir_to_bytecode.cc b/tensorflow/compiler/mlir/tfrt/translate/mlrt/mlir_to_bytecode.cc new file mode 100644 index 00000000000..895f705e6b1 --- /dev/null +++ b/tensorflow/compiler/mlir/tfrt/translate/mlrt/mlir_to_bytecode.cc @@ -0,0 +1,470 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/compiler/mlir/tfrt/translate/mlrt/mlir_to_bytecode.h" + +#include +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/TypeSwitch.h" +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "tensorflow/core/tfrt/mlrt/bytecode/executable.h" + +namespace mlrt { +namespace { + +// LINT.IfChange(mlrt_attributes) +bool CanBeInlined(mlir::Attribute attr, absl::string_view data) { + // FlatSymbolRefAttr is a special case as we are emitting it as integer. + return attr.isa() && + data.size() <= sizeof(uint32_t); +} +// LINT.ThenChange(../../../../../core/tfrt/mlrt/interpreter/attribute_span.h:mlrt_attributes) + +// Encode integer or float-point numbers as bytes. +template +std::string EncodeIntegerOrFloat(T attr) { + std::string data(sizeof(attr), '\0'); + std::memcpy(data.data(), &attr, sizeof(attr)); + return data; +} + +// Encode a list of I64 integers as bytes using bc::Vector. The bytes +// can be decoded directly using bc::Vector. If `array` is not a list +// I64 integers, a nullopt will be returned. 
+ +template +std::optional EncodeListOfInteger(mlir::ArrayAttr array) { + bc::Buffer buffer; + bc::Allocator allocator(&buffer); + auto ctor = bc::New>(&allocator, array.size()); + + mlir::Type type; + + for (int i = 0; i < array.size(); ++i) { + if (auto integer_attr = array[i].dyn_cast()) { + if (type && integer_attr.getType() != type) return std::nullopt; + type = integer_attr.getType(); + llvm::APInt value = integer_attr.getValue(); + if (value.getBitWidth() != sizeof(T) * 8) return std::nullopt; + ctor.ConstructAt(i, value.getZExtValue()); + } else { + return std::nullopt; + } + } + + return std::string(buffer.data(), buffer.size()); +} + +std::optional EncodeListOfSymbolRef( + const ModuleEmitterContext& module_context, mlir::ArrayAttr array) { + bc::Buffer buffer; + bc::Allocator allocator(&buffer); + auto ctor = bc::New>(&allocator, array.size()); + + for (int i = 0; i < array.size(); ++i) { + if (auto symbol_ref = array[i].dyn_cast()) { + ctor.ConstructAt(i, module_context.GetFunctionId(symbol_ref.getValue())); + } else { + return std::nullopt; + } + } + return std::string(buffer.data(), buffer.size()); +} + +template +std::optional EncodeDenseArray(llvm::ArrayRef array) { + bc::Buffer buffer; + bc::Allocator allocator(&buffer); + auto ctor = bc::New>(&allocator, array.size()); + + if (!array.empty()) { + ctor.Place(reinterpret_cast(array.data()), + array.size() * sizeof(T)); + } + + return std::string(buffer.data(), buffer.size()); +} + +// Encode a list of strings as bytes using bc::Vector. The bytes +// can be decoded directly using bc::Vector. If `array` is not a +// list of strings, a nullopt will be returned. 
+std::optional EncodeListOfString(mlir::ArrayAttr array) { + bc::Buffer buffer; + bc::Allocator allocator(&buffer); + auto ctor = bc::New>(&allocator, array.size()); + + for (int i = 0; i < array.size(); ++i) { + if (auto string_attr = array[i].dyn_cast()) { + ctor.ConstructAt(i, string_attr.getValue().str()); + } else { + return std::nullopt; + } + } + + return std::string(buffer.data(), buffer.size()); +} + +struct FunctionEmitterContext { + explicit FunctionEmitterContext(const ModuleEmitterContext* module_context) + : module_context(*module_context) {} + + const ModuleEmitterContext& module_context; + + struct RegInfo { + int num_uses = 0; + int id = -1; + }; + + int next_reg_id = 0; + llvm::DenseMap register_table; + std::vector free_regs; + + int AssignRegId() { + if (free_regs.empty()) { + return next_reg_id++; + } + int id = free_regs.back(); + free_regs.pop_back(); + return id; + } + + void FreeRegId(int id) { free_regs.push_back(id); } +}; + +// Emit the bytecode for a kernel. It uses the information in an MLIR operation +// and populates the bytecode using bc::Kernel::Constructor. For a kernel's +// bytecode format, please refer to kernel.h. +void EmitKernel(FunctionEmitterContext& function_context, + bc::Kernel::Constructor& constructor, mlir::Operation& op, + std::vector& function_output_regs, + std::vector& function_output_last_uses) { + // Assign reg ids for results first to make sure results does not reuse reg + // ids freed from args in the same operation. 
+ std::vector results; + results.reserve(op.getNumResults()); + for (auto result : op.getResults()) { + auto iter = function_context.register_table.find(result); + CHECK(iter != function_context.register_table.end()); // Crash Ok + CHECK_EQ(iter->second.id, -1); // Crash Ok + iter->second.id = function_context.AssignRegId(); + results.push_back(iter->second.id); + } + constructor.construct_results(results.size()) + .Assign(results.begin(), results.end()); + + std::vector arguments; + std::vector last_uses; + arguments.reserve(op.getNumOperands()); + last_uses.reserve(op.getNumOperands()); + for (auto operand : op.getOperands()) { + auto iter = function_context.register_table.find(operand); + CHECK(iter != function_context.register_table.end()); // Crash Ok + int id = iter->second.id; + CHECK_NE(id, -1); // Crash Ok + last_uses.push_back(0); + if (--iter->second.num_uses == 0) { + function_context.FreeRegId(id); + last_uses.back() = 1; + } + arguments.push_back(id); + } + + constructor.construct_arguments(arguments.size()) + .Assign(arguments.begin(), arguments.end()); + constructor.construct_last_uses(last_uses.size()) + .Assign(last_uses.begin(), last_uses.end()); + + std::vector attributes; + attributes.reserve(op.getAttrs().size()); + for (auto attr : op.getAttrs()) { + int attr_id = + function_context.module_context.GetAttributeId(attr.getValue()); + absl::string_view attr_data = + function_context.module_context.attributes().at(attr_id); + + if (CanBeInlined(attr.getValue(), attr_data)) { + uint32_t data = 0; + std::memcpy(&data, attr_data.data(), attr_data.size()); + attributes.push_back(data); + } else { + attributes.push_back(attr_id); + } + } + constructor.construct_attributes(attributes.size()) + .Assign(attributes.begin(), attributes.end()); + + if (op.hasTrait()) { + constructor.set_code(function_context.module_context.GetKernelId("return")); + + function_output_regs = std::move(arguments); + function_output_last_uses = std::move(last_uses); + + } else 
if (llvm::isa(&op)) { + constructor.set_code(function_context.module_context.GetKernelId("call")); + } else { + llvm::StringRef op_name = op.getName().getStringRef(); + constructor.set_code(function_context.module_context.GetKernelId(op_name)); + } +} + +// Emit the bytecode for a function. It uses information in an MLIR function or +// an MLIR region, and populates the bytecode using bc::Function::Constructor. +// For a function's bytecode format, please refer to function.h. +void EmitFunction(const ModuleEmitterContext& module_context, + bc::Function::Constructor& constructor, llvm::StringRef name, + mlir::Region& region) { + FunctionEmitterContext function_context(&module_context); + + constructor.construct_name(name.str()); + + DCHECK(llvm::hasSingleElement(region)) << "should have a single block"; + + auto& block = region.front(); + + auto& register_table = function_context.register_table; + + std::vector input_regs; + input_regs.reserve(block.getNumArguments()); + for (auto arg : block.getArguments()) { + int id = function_context.AssignRegId(); + input_regs.push_back(id); + register_table[arg] = {static_cast(std::distance(arg.getUses().begin(), + arg.getUses().end())), + id}; + } + constructor.construct_input_regs(input_regs); + + for (auto& op : block) { + for (auto result : op.getResults()) { + register_table[result] = {static_cast( + std::distance(result.getUses().begin(), result.getUses().end()))}; + } + } + + auto kernels_constructor = + constructor.construct_kernels(block.getOperations().size()); + + std::vector output_regs; + std::vector output_last_uses; + for (const auto& iter : llvm::enumerate(block.getOperations())) { + int i = iter.index(); + mlir::Operation& op = iter.value(); + auto kernel_ctor = kernels_constructor.ConstructAt(i); + EmitKernel(function_context, kernel_ctor, op, output_regs, + output_last_uses); + } + + constructor.set_num_regs(function_context.next_reg_id); + constructor.construct_output_regs(output_regs); + 
constructor.construct_output_last_uses(output_last_uses); +} + +// Emit the bytecode for an executable. It converts attributes, kernels, and +// functions in an MLIR module to bytecode using bc::Executable::Constructor. +// For an executable's bytecode format, please refer to executable.h. +absl::Status EmitExecutable(ModuleEmitterContext& module_context, + bc::Executable::Constructor& constructor, + mlir::ModuleOp module) { + module.walk( + [&](mlir::func::FuncOp func) { module_context.AddFunction(func); }); + + auto functions = module_context.functions(); + for (auto func : functions) { + if (!llvm::hasSingleElement(func.getRegion())) { + return absl::InvalidArgumentError("function should have a single block."); + } + auto& block = func.getRegion().front(); + + for (auto& op : block) { + if (llvm::isa(&op)) { + // Canonicalize the MLIR builtin call op's name to "call". + module_context.AddKernelName("call"); + } else if (op.hasTrait()) { + // Canonicalize the return op's name to "return". + if (op.getNumResults() != 0) { + return absl::InvalidArgumentError( + "Block terminator must be a return op."); + } + module_context.AddKernelName("return"); + } else { + module_context.AddKernelName(op.getName().getStringRef().str()); + } + + for (auto attr : op.getAttrs()) { + if (auto status = module_context.AddAttribute(&op, attr.getValue()); + !status.ok()) { + return status; + } + } + + // TODO(chky): Support inline regions. + } + } + + constructor.construct_kernel_names(module_context.kernels().size()) + .Assign(module_context.kernels().begin(), module_context.kernels().end()); + + auto functions_constructor = + constructor.construct_functions(functions.size()); + for (int i = 0; i < functions.size(); ++i) { + auto func = functions[i]; + auto function_ctor = functions_constructor.ConstructAt(i); + EmitFunction(module_context, function_ctor, func.getSymName(), + func.getRegion()); + } + + // Emit attributes after emitting functions as attributes might be large. 
+ // Large attributes may result in large offsets that do not fit into a + // unit32_t integer. Since functions section should fit into 2GB size limit, + // so we emit functions first. + constructor.construct_attributes(module_context.attributes().size()) + .Assign(module_context.attributes().begin(), + module_context.attributes().end()); + + return absl::OkStatus(); +} + +} // namespace + +absl::Status ModuleEmitterContext::AddAttribute(mlir::Operation* op, + mlir::Attribute attr) { + absl::StatusOr attr_data; + if (auto* encoder = attribute_encoder_registry_.Get( + op->getName().getDialectNamespace())) { + attr_data = (*encoder)(*this, attr); + } else { + attr_data = DefaultEncodeAttribute(attr); + } + if (!attr_data.ok()) return std::move(attr_data).status(); + + int id = AddData(std::move(*attr_data), attributes_, attribute_data_id_map_); + attribute_id_map_[attr] = id; + + return absl::OkStatus(); +} + +int ModuleEmitterContext::AddFunction(mlir::func::FuncOp func) { + int id = functions_.size(); + functions_.push_back(func); + DCHECK(!function_name_id_map_.contains(func.getSymName())); + function_name_id_map_[func.getSymName()] = id; + return id; +} + +std::optional EncodeSimpleAttribute( + const ModuleEmitterContext& module_context, mlir::Attribute attr) { + return llvm::TypeSwitch>(attr) + .Case( + [](const auto& str_attr) { return str_attr.str(); }) + .Case( + [](const auto& integer_attr) -> std::optional { + switch (llvm::APInt value = integer_attr.getValue(); + value.getBitWidth()) { + case 1: + return EncodeIntegerOrFloat(value.getZExtValue()); + case 32: + return EncodeIntegerOrFloat(value.getZExtValue()); + case 64: + return EncodeIntegerOrFloat(value.getZExtValue()); + default: + return std::nullopt; + } + }) + .Case( + [](const auto& float_attr) -> std::optional { + llvm::APFloat value = float_attr.getValue(); + if (float_attr.getType().isF32()) { + return EncodeIntegerOrFloat(value.convertToFloat()); + } + return std::nullopt; + }) + 
.Case([&](const auto& array_attr) + -> std::optional { + if (auto encoded_list_i32 = EncodeListOfInteger(array_attr)) { + return std::move(*encoded_list_i32); + } else if (auto encoded_list_i64 = + EncodeListOfInteger(array_attr)) { + return std::move(*encoded_list_i64); + } else if (auto encoded_list_string = EncodeListOfString(array_attr)) { + return std::move(*encoded_list_string); + } else if (auto encoded_list_symbol_ref = + EncodeListOfSymbolRef(module_context, array_attr)) { + return std::move(*encoded_list_symbol_ref); + } else { + return std::nullopt; + } + }) + .Case( + [](const auto& dense_array_i32) -> std::optional { + return EncodeDenseArray(dense_array_i32); + }) + .Case( + [](const auto& dense_array_i64) -> std::optional { + return EncodeDenseArray(dense_array_i64); + }) + .Case([&](const auto& symbol_ref) { + return EncodeIntegerOrFloat( + module_context.GetFunctionId(symbol_ref.getValue())); + }) + .Default([](const auto& attr) { return std::nullopt; }); +} + +// Encode mlir attributes with a limited support such as I64, string and array +// of I64. Returns an error if the attribute is not supported. +absl::StatusOr ModuleEmitterContext::DefaultEncodeAttribute( + mlir::Attribute attr) { + if (auto result = EncodeSimpleAttribute(*this, attr)) { + return std::move(*result); + } + + // TODO(chky): Add a unit test for the error below. This requires we + // propagate the error all the way back to the entry point. 
+ std ::string attr_str; + llvm::raw_string_ostream os(attr_str); + attr.print(os); + + return absl::InvalidArgumentError( + absl::StrCat("Try to encode unsupported attribute: ", attr_str)); +} + +absl::StatusOr EmitExecutable( + const AttributeEncoderRegistry& attribute_encoder_registry, + mlir::ModuleOp module) { + bc::Buffer buffer; + bc::Allocator allocator(&buffer); + + ModuleEmitterContext module_context(&attribute_encoder_registry); + + auto executable_ctor = bc::New(&allocator); + + if (auto status = EmitExecutable(module_context, executable_ctor, module); + !status.ok()) { + return status; + } + + return buffer; +} + +} // namespace mlrt diff --git a/tensorflow/compiler/mlir/tfrt/translate/mlrt/mlir_to_bytecode.h b/tensorflow/compiler/mlir/tfrt/translate/mlrt/mlir_to_bytecode.h new file mode 100644 index 00000000000..7f5416d230c --- /dev/null +++ b/tensorflow/compiler/mlir/tfrt/translate/mlrt/mlir_to_bytecode.h @@ -0,0 +1,128 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_TFRT_TRANSLATE_MLRT_MLIR_TO_BYTECODE_H_ +#define TENSORFLOW_COMPILER_MLIR_TFRT_TRANSLATE_MLRT_MLIR_TO_BYTECODE_H_ + +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/status/statusor.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "tensorflow/core/tfrt/mlrt/bytecode/bytecode.h" + +namespace mlrt { + +class ModuleEmitterContext; + +// Defines a custom attribute encoding registry. Users can register custom +// attribute encoding for their dialects in this registry. If no custom encoder +// is registered for a dialect, the default encoding with a limited support, the +// EncodeSimpleAttribute() below, will be used. +class AttributeEncoderRegistry { + public: + using EncoderFn = std::function( + const ModuleEmitterContext&, mlir::Attribute)>; + + void Register(absl::string_view dialect, EncoderFn encoder) { + encoders_[dialect] = std::move(encoder); + } + + // Returns the encoder for the specified dialect. It can be nullptr if it is + // not registered for this dialect. The returned reference will be invalidated + // if Register() is called. 
+ const EncoderFn* Get(absl::string_view dialect) const { + auto iter = encoders_.find(dialect); + if (iter != encoders_.end()) return &iter->second; + return nullptr; + } + + private: + absl::flat_hash_map encoders_; +}; + +class ModuleEmitterContext { + public: + explicit ModuleEmitterContext( + const AttributeEncoderRegistry* attribute_encoder_registry) + : attribute_encoder_registry_(*attribute_encoder_registry) {} + + void AddKernelName(std::string name) { + AddData(std::move(name), kernels_, kernel_id_map_); + } + + int GetKernelId(llvm::StringRef name) const { + return kernel_id_map_.at(name); + } + + absl::Status AddAttribute(mlir::Operation* op, mlir::Attribute attr); + + int GetAttributeId(mlir::Attribute attr) const { + return attribute_id_map_.lookup(attr); + } + + int AddFunction(mlir::func::FuncOp func); + + int GetFunctionId(absl::string_view name) const { + return function_name_id_map_.at(name); + } + + absl::Span kernels() const { return kernels_; } + absl::Span attributes() const { return attributes_; } + absl::Span functions() const { return functions_; } + + private: + int AddData(std::string data, std::vector& data_vector, + absl::flat_hash_map& data_map) { + auto iter = data_map.find(data); + if (iter != data_map.end()) return iter->second; + + int id = data_vector.size(); + data_map[data] = id; + data_vector.push_back(std::move(data)); + return id; + } + + absl::StatusOr DefaultEncodeAttribute(mlir::Attribute attr); + + const AttributeEncoderRegistry& attribute_encoder_registry_; + + std::vector kernels_; + absl::flat_hash_map kernel_id_map_; + + std::vector attributes_; + llvm::DenseMap attribute_id_map_; + absl::flat_hash_map attribute_data_id_map_; + + std::vector functions_; + absl::flat_hash_map function_name_id_map_; +}; + +// Encodes a few simple attributes. Users can use this function in their custom +// attribute encoder. 
+std::optional EncodeSimpleAttribute( + const ModuleEmitterContext& module_context, mlir::Attribute attr); + +absl::StatusOr EmitExecutable( + const AttributeEncoderRegistry& attribute_encoder_registry, + mlir::ModuleOp module); + +} // namespace mlrt + +#endif // TENSORFLOW_COMPILER_MLIR_TFRT_TRANSLATE_MLRT_MLIR_TO_BYTECODE_H_ diff --git a/tensorflow/compiler/mlir/tfrt/translate/mlrt/mlir_to_bytecode_test.cc b/tensorflow/compiler/mlir/tfrt/translate/mlrt/mlir_to_bytecode_test.cc new file mode 100644 index 00000000000..d94e8df3b2e --- /dev/null +++ b/tensorflow/compiler/mlir/tfrt/translate/mlrt/mlir_to_bytecode_test.cc @@ -0,0 +1,365 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/compiler/mlir/tfrt/translate/mlrt/mlir_to_bytecode.h" + +#include +#include +#include + +#include +#include +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Parser/Parser.h" // from @llvm-project +#include "tensorflow/core/tfrt/mlrt/bytecode/executable.h" +#include "tensorflow/core/tfrt/mlrt/interpreter/attribute_span.h" +#include "tensorflow/tsl/platform/resource_loader.h" + +namespace mlrt { +namespace { + +TEST(MlirToByteCodeTest, Basic) { + constexpr char kBasicMlir[] = + "tensorflow/compiler/mlir/tfrt/translate/mlrt/testdata/basic.mlir"; + + mlir::DialectRegistry registry; + registry.insert(); + mlir::MLIRContext mlir_context(registry); + mlir_context.allowUnregisteredDialects(); + auto mlir_module = mlir::parseSourceFile( + tsl::GetDataDependencyFilepath(kBasicMlir), &mlir_context); + + AttributeEncoderRegistry attribute_encoder_registry; + bc::Buffer buffer = + EmitExecutable(attribute_encoder_registry, mlir_module.get()).value(); + + bc::Executable executable(buffer.data()); + + auto kernel_names = executable.kernel_names(); + EXPECT_THAT(kernel_names, ::testing::ElementsAreArray({"test_mlbc.add.i32", + "test_mlbc.sub.i32", + "call", "return"})); + + auto functions = executable.functions(); + ASSERT_GE(functions.size(), 1); + + auto function = functions[0]; + EXPECT_EQ(function.name().str(), "add_i32_10"); + EXPECT_EQ(function.num_regs(), 5); + EXPECT_THAT(function.input_regs(), ::testing::ElementsAreArray({0})); + EXPECT_THAT(function.output_regs(), ::testing::ElementsAreArray({0, 2, 2})); + EXPECT_THAT(function.output_last_uses(), + ::testing::ElementsAreArray({true, false, true})); + + auto kernels = function.kernels(); + ASSERT_EQ(kernels.size(), 11); + + EXPECT_EQ(kernels[0].code(), 0); + EXPECT_THAT(kernels[0].arguments(), ::testing::ElementsAreArray({0, 0})); + EXPECT_THAT(kernels[0].results(), 
::testing::ElementsAreArray({1})); + EXPECT_THAT(kernels[0].last_uses(), ::testing::ElementsAreArray({0, 0})); + + for (int i = 1; i < 9; i++) { + EXPECT_EQ(kernels[i].code(), i % 2); + EXPECT_THAT(kernels[i].arguments(), + ::testing::ElementsAreArray({(i - 1) % 2 + 1, 0})); + EXPECT_THAT(kernels[i].results(), ::testing::ElementsAreArray({i % 2 + 1})); + EXPECT_THAT(kernels[i].last_uses(), ::testing::ElementsAreArray({1, 0})); + } + + EXPECT_EQ(kernels[9].code(), 2); + EXPECT_THAT(kernels[9].arguments(), ::testing::ElementsAreArray({1})); + EXPECT_THAT(kernels[9].last_uses(), ::testing::ElementsAreArray({true})); + EXPECT_THAT(kernels[9].results(), ::testing::ElementsAreArray({2, 3, 4})); + + EXPECT_EQ(kernels[10].code(), 3); + EXPECT_THAT(kernels[10].arguments(), ::testing::ElementsAreArray({0, 2, 2})); + EXPECT_THAT(kernels[10].last_uses(), + ::testing::ElementsAreArray({true, false, true})); + EXPECT_TRUE(kernels[10].results().empty()); +} + +template +absl::StatusOr DecodeAttribute(absl::string_view data) { + if (data.size() < sizeof(T)) + return absl::InvalidArgumentError("Invalid data size for attribute."); + + T value; + std::memcpy(&value, data.data(), sizeof(T)); + return value; +} + +TEST(MlirToByteCodeTest, BasicAttributes) { + constexpr char kBasicAttributesMlir[] = + "tensorflow/compiler/mlir/tfrt/translate/mlrt/testdata/" + "basic_attributes.mlir"; + + mlir::DialectRegistry registry; + registry.insert(); + mlir::MLIRContext mlir_context(registry); + mlir_context.allowUnregisteredDialects(); + auto mlir_module = mlir::parseSourceFile( + tsl::GetDataDependencyFilepath(kBasicAttributesMlir), &mlir_context); + + AttributeEncoderRegistry attribute_encoder_registry; + bc::Buffer buffer = + EmitExecutable(attribute_encoder_registry, mlir_module.get()).value(); + + bc::Executable executable(buffer.data()); + + auto attributes = executable.attributes(); + + ASSERT_EQ(attributes.size(), 14); + + auto attr_iter = attributes.begin(); + + EXPECT_EQ(*attr_iter, 
"test string"); + ++attr_iter; + + EXPECT_EQ(*attr_iter, "ts"); + ++attr_iter; + + EXPECT_THAT(DecodeAttribute(*attr_iter), + ::testing::status::IsOkAndHolds(100)); + ++attr_iter; + + EXPECT_THAT(DecodeAttribute(*attr_iter), + ::testing::status::IsOkAndHolds(200)); + ++attr_iter; + + EXPECT_THAT(DecodeAttribute(*attr_iter), + ::testing::status::IsOkAndHolds(::testing::FloatEq(3.0))); + ++attr_iter; + + EXPECT_THAT(DecodeAttribute(*attr_iter), + ::testing::status::IsOkAndHolds(0)); + ++attr_iter; + + bc::Vector list_of_i64((*attr_iter).data()); + EXPECT_THAT(list_of_i64, ::testing::ElementsAreArray({0, 1, 2, 3, 4})); + ++attr_iter; + + bc::Vector list_of_i32((*attr_iter).data()); + EXPECT_THAT(list_of_i32, ::testing::ElementsAreArray({0, 1, 2, 3})); + ++attr_iter; + + bc::Vector list_of_str((*attr_iter).data()); + EXPECT_THAT(list_of_str, + ::testing::ElementsAreArray({"string 0", "string 1"})); + ++attr_iter; + + EXPECT_THAT(DecodeAttribute(*attr_iter), + ::testing::status::IsOkAndHolds(1)); + EXPECT_EQ(executable.functions()[1].name().Get(), "callee"); + ++attr_iter; + + bc::Vector list_of_symbol_ref((*attr_iter).data()); + EXPECT_EQ(executable.functions()[2].name().Get(), "callee0"); + EXPECT_EQ(executable.functions()[3].name().Get(), "callee1"); + EXPECT_THAT(list_of_symbol_ref, ::testing::ElementsAreArray({2, 3})); + ++attr_iter; + + bc::Vector dense_array_of_i32((*attr_iter).data()); + EXPECT_THAT(dense_array_of_i32, ::testing::ElementsAreArray({0, 1, 2})); + ++attr_iter; + + bc::Vector dense_array_of_i64((*attr_iter).data()); + EXPECT_THAT(dense_array_of_i64, ::testing::ElementsAreArray({0, 1, 2})); + ++attr_iter; + + bc::Vector empty_dense_array((*attr_iter).data()); + EXPECT_TRUE(empty_dense_array.empty()); + + auto kernels = executable.functions()[0].kernels(); + ASSERT_EQ(kernels.size(), 15); + auto kernel_iter = kernels.begin(); + + auto attribute_span = [&](auto kernel_iter) { + return mlrt::AttributeSpan((*kernel_iter).attributes(), attributes); + }; + 
+ EXPECT_EQ(attribute_span(kernel_iter).GetAs(0).Get(), + "test string"); + ++kernel_iter; + + EXPECT_EQ(attribute_span(kernel_iter).GetAs(0).Get(), "ts"); + ++kernel_iter; + + EXPECT_EQ(attribute_span(kernel_iter).GetAs(0), 100); + ++kernel_iter; + + EXPECT_EQ(attribute_span(kernel_iter).GetAs(0), 200); + ++kernel_iter; + + EXPECT_THAT(attribute_span(kernel_iter).GetAs(0), + ::testing::FloatEq(3.0)); + ++kernel_iter; + + EXPECT_EQ(attribute_span(kernel_iter).GetAs(0), false); + ++kernel_iter; + + EXPECT_THAT(attribute_span(kernel_iter).GetAs>(0), + ::testing::ElementsAreArray({0, 1, 2, 3, 4})); + ++kernel_iter; + + EXPECT_THAT(attribute_span(kernel_iter).GetAs>(0), + ::testing::ElementsAreArray({0, 1, 2, 3})); + ++kernel_iter; + + EXPECT_THAT(attribute_span(kernel_iter).GetAs>(0), + ::testing::ElementsAreArray({"string 0", "string 1"})); + ++kernel_iter; + + EXPECT_EQ(attribute_span(kernel_iter).GetAs(0), 1); + ++kernel_iter; + + EXPECT_THAT(attribute_span(kernel_iter).GetAs>(0), + ::testing::ElementsAreArray({2, 3})); + ++kernel_iter; + + EXPECT_THAT(attribute_span(kernel_iter).GetAs>(0), + ::testing::ElementsAreArray({0, 1, 2})); + ++kernel_iter; + + EXPECT_THAT(attribute_span(kernel_iter).GetAs>(0), + ::testing::ElementsAreArray({0, 1, 2})); + ++kernel_iter; + + EXPECT_THAT(attribute_span(kernel_iter).GetAs>(0), + ::testing::IsEmpty()); +} + +TEST(MlirToByteCodeTest, UnsupportedAttributes) { + constexpr char kUnsupportedAttributesMlir[] = + "tensorflow/compiler/mlir/tfrt/translate/mlrt/testdata/" + "unsupported_attributes.mlir"; + + mlir::DialectRegistry registry; + registry.insert(); + mlir::MLIRContext mlir_context(registry); + mlir_context.allowUnregisteredDialects(); + auto mlir_module = mlir::parseSourceFile( + tsl::GetDataDependencyFilepath(kUnsupportedAttributesMlir), + &mlir_context); + + AttributeEncoderRegistry attribute_encoder_registry; + EXPECT_THAT(EmitExecutable(attribute_encoder_registry, mlir_module.get()), + 
::testing::status::CanonicalStatusIs( + absl::StatusCode::kInvalidArgument, + "Try to encode unsupported attribute: unit")); +} + +class CustomDense { + public: + struct StorageType { + using Self = StorageType; + DEFINE_BYTECODE_FIELD(bc::Vector, shape); + DEFINE_BYTECODE_FIELD(bc::Vector, data); + }; + + class Constructor { + public: + Constructor(bc::Allocator* allocator, bc::BcAddr_t address) + : allocator_(allocator), address_(address) {} + + template + auto construct_shape(Args&&... args) { + return StorageType::construct_shape(allocator_, address_, + std::forward(args)...); + } + template + auto construct_data(Args&&... args) { + return StorageType::construct_data(allocator_, address_, + std::forward(args)...); + } + + bc::BcAddr_t address() const { return address_; } + + private: + bc::Allocator* allocator_; + bc::BcAddr_t address_; + }; + using NonTrivialConstructorType = Constructor; + + explicit CustomDense(const char* p) : p_(p) {} + + bc::Vector shape() const { return StorageType::read_shape(p_); } + bc::Vector data() const { return StorageType::read_data(p_); } + + private: + const char* p_ = nullptr; +}; + +absl::StatusOr EncodeCustomDense(const ModuleEmitterContext&, + mlir::Attribute attr) { + auto dense_int_attr = attr.dyn_cast(); + if (!dense_int_attr) + return absl::InvalidArgumentError( + "The element of the custom dense attribute must be an integer."); + + if (dense_int_attr.getElementType().cast().getWidth() != + 32) { + return absl::InvalidArgumentError( + "The element of the custom dense attribute must be an i32 integer."); + } + + bc::Buffer buffer; + bc::Allocator allocator(&buffer); + auto custom_dense_ctor = bc::New(&allocator); + + auto shaped_type = dense_int_attr.getType(); + std::vector shape(shaped_type.getShape().begin(), + shaped_type.getShape().end()); + custom_dense_ctor.construct_shape(shape); + + custom_dense_ctor.construct_data(shaped_type.getNumElements()) + .Place(dense_int_attr.getRawData().data(), + 
dense_int_attr.getRawData().size()); + + return std::string(buffer.data(), buffer.size()); +} + +TEST(MlirToByteCodeTest, CustomDense) { + constexpr char kCustomAttributesMlir[] = + "tensorflow/compiler/mlir/tfrt/translate/mlrt/testdata/" + "custom_attributes.mlir"; + + mlir::DialectRegistry registry; + registry.insert(); + mlir::MLIRContext mlir_context(registry); + mlir_context.allowUnregisteredDialects(); + auto mlir_module = mlir::parseSourceFile( + tsl::GetDataDependencyFilepath(kCustomAttributesMlir), &mlir_context); + + AttributeEncoderRegistry attribute_encoder_registry; + attribute_encoder_registry.Register("test_custom", &EncodeCustomDense); + bc::Buffer buffer = + EmitExecutable(attribute_encoder_registry, mlir_module.get()).value(); + + bc::Executable executable(buffer.data()); + + auto attributes = executable.attributes(); + + ASSERT_EQ(attributes.size(), 10); + for (int i = 0; i < 10; ++i) { + bc::String attr_data = attributes[i]; + + CustomDense custom_dense(attr_data.data()); + EXPECT_THAT(custom_dense.shape(), ::testing::ElementsAreArray({1})); + EXPECT_THAT(custom_dense.data(), ::testing::ElementsAreArray({i})); + } +} + +} // namespace +} // namespace mlrt diff --git a/tensorflow/compiler/mlir/tfrt/translate/mlrt/test_utils.cc b/tensorflow/compiler/mlir/tfrt/translate/mlrt/test_utils.cc new file mode 100644 index 00000000000..b5a3cb9550c --- /dev/null +++ b/tensorflow/compiler/mlir/tfrt/translate/mlrt/test_utils.cc @@ -0,0 +1,174 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/compiler/mlir/tfrt/translate/mlrt/test_utils.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "tensorflow/core/tfrt/mlrt/bytecode/bytecode.h" +#include "tensorflow/core/tfrt/mlrt/bytecode/kernel.h" +#include "tensorflow/core/tfrt/mlrt/interpreter/context.h" +#include "tensorflow/core/tfrt/mlrt/interpreter/interpreter_testutil.h" + +namespace mlrt { +namespace testing { + +absl::StatusOr EncodeAttribute(const tensorflow::AttrValue& attr) { + if (attr.has_b()) { + std::string result; + result.resize(sizeof(uint8_t)); + uint8_t v = attr.b(); + std::memcpy(result.data(), &v, sizeof(v)); + return result; + } + + if (attr.has_i()) { + std::string result; + result.resize(sizeof(int64_t)); + int64_t v = attr.i(); + std::memcpy(result.data(), &v, sizeof(v)); + return result; + } + + if (attr.has_f()) { + std::string result; + result.resize(sizeof(float)); + float v = attr.f(); + std::memcpy(result.data(), &v, sizeof(v)); + return result; + } + + if (attr.has_s()) { + return attr.s(); + } + + if (attr.has_list()) { + if (attr.list().s_size() > 0) { + mlrt::bc::Buffer buffer; + mlrt::bc::Allocator allocator(&buffer); + auto ctor = mlrt::bc::New>( + &allocator, attr.list().s_size()); + + for (int i = 0; i < attr.list().s_size(); ++i) { + ctor.ConstructAt(i, attr.list().s(i)); + } + + return std::string(buffer.data(), buffer.size()); + } + } + + if (attr.has_tensor()) { + mlrt::bc::Buffer buffer; + mlrt::bc::Allocator allocator(&buffer); + + tensorflow::Tensor tensor; + if (!tensor.FromProto(attr.tensor())) { + return absl::InvalidArgumentError("Invalid tensor proto."); + } + + auto tensor_attr_ctor = mlrt::bc::New( + &allocator, tensor.dtype()); + + auto shape = tensor.shape().dim_sizes(); + + 
tensor_attr_ctor.construct_shape(shape.size()) + .Assign(shape.begin(), shape.end()); + + auto tensor_data = tensor.tensor_data(); + tensor_attr_ctor.construct_data(tensor_data.size()) + .Place(tensor_data.data(), tensor_data.size()); + + return std::string(buffer.data(), buffer.size()); + } + + // TODO(chky,rohitju): Add more attribute support. + + return absl::InvalidArgumentError("Unsupported attribute."); +} + +namespace { + +bool CanBeInlined(const tensorflow::AttrValue& attr) { + return attr.has_b() || attr.has_f(); +} + +} // namespace + +absl::Status EncodeAttributes(AttributeTable& attributes, + const tensorflow::AttrValueMap& attr_map) { + std::vector> attrs( + attr_map.begin(), attr_map.end()); + std::sort(attrs.begin(), attrs.end(), + [](const auto& x, const auto& y) { return x.first < y.first; }); + + for (int i = 0; i < attrs.size(); ++i) { + const tensorflow::AttrValue& attr = attrs[i].second; + TF_ASSIGN_OR_RETURN(auto attr_str, EncodeAttribute(attr)); + if (CanBeInlined(attr)) { + attributes.AddInline(absl::StrCat(i), attr_str); + } else { + attributes.Add(absl::StrCat(i), attr_str); + } + } + + return absl::OkStatus(); +} + +absl::StatusOr>> +CreateKernelAndAttrs(int num_inputs, int num_outputs, + mlrt::ExecutionContext& exec_ctx, mlrt::bc::Buffer* buffer, + const tensorflow::AttrValueMap& attrs) { + mlrt::bc::Allocator allocator(buffer); + auto attributes_ctor = mlrt::bc::New>( + &allocator, attrs.size()); + AttributeTable attribute_table(attributes_ctor); + TF_RETURN_IF_ERROR(EncodeAttributes(attribute_table, attrs)); + + auto kernel_ctor = mlrt::bc::New(&allocator); + kernel_ctor.set_code(0); + + std::vector input_indices(num_inputs); + std::iota(input_indices.begin(), input_indices.end(), 0); + kernel_ctor.construct_arguments(input_indices.size()) + .Assign(input_indices.begin(), input_indices.end()); + + std::vector output_indices(num_outputs); + std::iota(output_indices.begin(), output_indices.end(), num_inputs); + 
kernel_ctor.construct_results(output_indices.size()) + .Assign(output_indices.begin(), output_indices.end()); + + std::vector attr_indices; + attr_indices.reserve(attrs.size()); + for (int i = 0; i < attrs.size(); ++i) { + attr_indices.push_back(attribute_table.GetHandle(absl::StrCat(i))); + } + + kernel_ctor.construct_attributes(attr_indices.size()) + .Assign(attr_indices.begin(), attr_indices.end()); + + mlrt::bc::Vector attributes( + buffer->Get(attributes_ctor.address())); + mlrt::bc::Kernel kernel(buffer->Get(kernel_ctor.address())); + + return std::make_pair(kernel, attributes); +} + +} // namespace testing +} // namespace mlrt diff --git a/tensorflow/compiler/mlir/tfrt/translate/mlrt/test_utils.h b/tensorflow/compiler/mlir/tfrt/translate/mlrt/test_utils.h new file mode 100644 index 00000000000..fd2d491923f --- /dev/null +++ b/tensorflow/compiler/mlir/tfrt/translate/mlrt/test_utils.h @@ -0,0 +1,115 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_TFRT_TRANSLATE_MLRT_TEST_UTILS_H_ +#define TENSORFLOW_COMPILER_MLIR_TFRT_TRANSLATE_MLRT_TEST_UTILS_H_ + +#include +#include +#include +#include +#include + +#include "learning/brain/experimental/tfrt/native_lowering/kernels/sync_context.h" +#include "absl/status/statusor.h" +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/tfrt/mlrt/attribute/attribute.h" +#include "tensorflow/core/tfrt/mlrt/bytecode/bytecode.h" +#include "tensorflow/core/tfrt/mlrt/bytecode/kernel.h" +#include "tensorflow/core/tfrt/mlrt/interpreter/context.h" +#include "tensorflow/core/tfrt/mlrt/interpreter/interpreter_testutil.h" +#include "tensorflow/core/tfrt/mlrt/interpreter/value.h" +#include "tensorflow/core/tfrt/utils/tensor_util.h" +#include "tensorflow/tsl/platform/errors.h" +#include "tfrt/host_context/concurrent_work_queue.h" // from @tf_runtime +#include "tfrt/host_context/execution_context.h" // from @tf_runtime +#include "tfrt/host_context/host_allocator.h" // from @tf_runtime +#include "tfrt/host_context/host_context.h" // from @tf_runtime +#include "tfrt/tensor/dense_tensor_utils.h" // from @tf_runtime + +namespace mlrt { +namespace testing { + +absl::StatusOr EncodeAttribute(const tensorflow::AttrValue& attr); + +absl::Status EncodeAttributes(AttributeTable& attributes, + const tensorflow::AttrValueMap& attr_map); + +absl::StatusOr>> +CreateKernelAndAttrs(int num_inputs, int num_outputs, + mlrt::ExecutionContext& exec_ctx, mlrt::bc::Buffer* buffer, + const tensorflow::AttrValueMap& attrs = {}); + +template +absl::Status TestMlrtKernel( + absl::string_view kernel_name, absl::Span regs, + tfrt::HostContext* host, int num_inputs, int num_outputs, + absl::Span expected_outputs, + mlrt::KernelRegistry* registry, bool approx_equal = false, + 
const tensorflow::AttrValueMap& attrs = {}) { + mlrt::ExecutionContext execution_context(nullptr); + + mlrt::bc::Buffer buffer; + TF_ASSIGN_OR_RETURN(auto kernel_and_attrs, + CreateKernelAndAttrs(num_inputs, num_outputs, + execution_context, &buffer, attrs)); + + tfrt::ExecutionContext tfrt_execution_context( + *tfrt::RequestContextBuilder(host, nullptr).build()); + tensorflow::tfrt_stub::SyncResourceState sync_resource_state; + auto sync_context = + std::make_unique(*host, &sync_resource_state); + execution_context.AddUserContext(std::move(sync_context)); + + auto kernel_fn = registry->Get(kernel_name); + mlrt::KernelFrame::State state(regs, kernel_and_attrs.second, + &execution_context); + mlrt::KernelFrame frame(&state); + frame.set_kernel(kernel_and_attrs.first); + + kernel_fn(frame); + + TF_RETURN_IF_ERROR(execution_context.status()); + + for (int i = 0, j = num_inputs; i < expected_outputs.size(); ++i, ++j) { + const auto& expected_output = expected_outputs[i]; + auto expected_dht = tfrt::ConvertTfTensorToDHT(expected_output); + if (!expected_dht) { + return absl::InternalError(tfrt::StrCat(expected_dht.takeError())); + } + + if (!approx_equal) { + if (!tfrt::TensorEqual(regs[j].Get(), + *expected_dht)) { + return absl::InternalError( + absl::StrCat("wrong result for ", kernel_name)); + } + } else { + if (!tfrt::TensorApproxEqual(regs[j].Get(), + *expected_dht)) { + return absl::InternalError( + absl::StrCat("wrong result for ", kernel_name)); + } + } + } + + return absl::OkStatus(); +} + +} // namespace testing +} // namespace mlrt + +#endif // TENSORFLOW_COMPILER_MLIR_TFRT_TRANSLATE_MLRT_TEST_UTILS_H_ diff --git a/tensorflow/compiler/mlir/tfrt/translate/mlrt/testdata/basic.mlir b/tensorflow/compiler/mlir/tfrt/translate/mlrt/testdata/basic.mlir new file mode 100644 index 00000000000..a5d5f98332b --- /dev/null +++ b/tensorflow/compiler/mlir/tfrt/translate/mlrt/testdata/basic.mlir @@ -0,0 +1,13 @@ +func.func @add_i32_10(%c0: i32) -> (i32, i32, i32) { + %c1 = 
"test_mlbc.add.i32"(%c0, %c0) : (i32, i32) -> i32 + %c2 = "test_mlbc.sub.i32"(%c1, %c0) : (i32, i32) -> i32 + %c3 = "test_mlbc.add.i32"(%c2, %c0) : (i32, i32) -> i32 + %c4 = "test_mlbc.sub.i32"(%c3, %c0) : (i32, i32) -> i32 + %c5 = "test_mlbc.add.i32"(%c4, %c0) : (i32, i32) -> i32 + %c6 = "test_mlbc.sub.i32"(%c5, %c0) : (i32, i32) -> i32 + %c7 = "test_mlbc.add.i32"(%c6, %c0) : (i32, i32) -> i32 + %c8 = "test_mlbc.sub.i32"(%c7, %c0) : (i32, i32) -> i32 + %c9 = "test_mlbc.add.i32"(%c8, %c0) : (i32, i32) -> i32 + %c10, %c11, %c12 = call @add_i32_10(%c9) : (i32) -> (i32, i32, i32) + func.return %c0, %c10, %c10 : i32, i32, i32 +} diff --git a/tensorflow/compiler/mlir/tfrt/translate/mlrt/testdata/basic_attributes.mlir b/tensorflow/compiler/mlir/tfrt/translate/mlrt/testdata/basic_attributes.mlir new file mode 100644 index 00000000000..db72598bb4a --- /dev/null +++ b/tensorflow/compiler/mlir/tfrt/translate/mlrt/testdata/basic_attributes.mlir @@ -0,0 +1,29 @@ +func.func @simple_attributes() { + "test_custom.attribute"() {value = "test string"} : () -> () + "test_custom.attribute"() {value = "ts"} : () -> () + "test_custom.attribute"() {value = 100 : i32} : () -> () + "test_custom.attribute"() {value = 200 : i64} : () -> () + "test_custom.attribute"() {value = 3.0 : f32} : () -> () + "test_custom.attribute"() {value = false} : () -> () + "test_custom.attribute"() {value = [0, 1, 2, 3, 4]} : () -> () + "test_custom.attribute"() {value = [0 : i32, 1 : i32, 2 : i32, 3 : i32]} : () -> () + "test_custom.attribute"() {value = ["string 0", "string 1"]} : () -> () + "test_custom.attribute"() {value = @callee} : () -> () + "test_custom.attribute"() {value = [@callee0, @callee1]} : () -> () + "test_custom.attribute"() {value = array} : () -> () + "test_custom.attribute"() {value = array} : () -> () + "test_custom.attribute"() {value = array} : () -> () + func.return +} + +func.func @callee() { + return +} + +func.func @callee0() { + return +} + +func.func @callee1() { + return +} diff 
--git a/tensorflow/compiler/mlir/tfrt/translate/mlrt/testdata/custom_attributes.mlir b/tensorflow/compiler/mlir/tfrt/translate/mlrt/testdata/custom_attributes.mlir new file mode 100644 index 00000000000..54f325092f3 --- /dev/null +++ b/tensorflow/compiler/mlir/tfrt/translate/mlrt/testdata/custom_attributes.mlir @@ -0,0 +1,13 @@ +func.func @add_const_custom_dense_i32_10(%c0: i32) -> i32 { + %c1 = "test_custom.add.const.i32"(%c0) {value = dense<[0]> : tensor<1xi32>} : (i32) -> i32 + %c2 = "test_custom.add.const.i32"(%c1) {value = dense<[1]> : tensor<1xi32>} : (i32) -> i32 + %c3 = "test_custom.add.const.i32"(%c2) {value = dense<[2]> : tensor<1xi32>} : (i32) -> i32 + %c4 = "test_custom.add.const.i32"(%c3) {value = dense<[3]> : tensor<1xi32>} : (i32) -> i32 + %c5 = "test_custom.add.const.i32"(%c4) {value = dense<[4]> : tensor<1xi32>} : (i32) -> i32 + %c6 = "test_custom.add.const.i32"(%c5) {value = dense<[5]> : tensor<1xi32>} : (i32) -> i32 + %c7 = "test_custom.add.const.i32"(%c6) {value = dense<[6]> : tensor<1xi32>} : (i32) -> i32 + %c8 = "test_custom.add.const.i32"(%c7) {value = dense<[7]> : tensor<1xi32>} : (i32) -> i32 + %c9 = "test_custom.add.const.i32"(%c8) {value = dense<[8]> : tensor<1xi32>} : (i32) -> i32 + %c10 = "test_custom.add.const.i32"(%c9) {value = dense<[9]> : tensor<1xi32>} : (i32) -> i32 + func.return %c10 : i32 +} diff --git a/tensorflow/compiler/mlir/tfrt/translate/mlrt/testdata/unsupported_attributes.mlir b/tensorflow/compiler/mlir/tfrt/translate/mlrt/testdata/unsupported_attributes.mlir new file mode 100644 index 00000000000..4c060815c95 --- /dev/null +++ b/tensorflow/compiler/mlir/tfrt/translate/mlrt/testdata/unsupported_attributes.mlir @@ -0,0 +1,5 @@ +func.func @unsupported_attributes() { + "test_custom.attribute"() {unit} : () -> () + func.return +} + diff --git a/tensorflow/compiler/mlir/tfrt/translate/tfrt_compile_options.h b/tensorflow/compiler/mlir/tfrt/translate/tfrt_compile_options.h index e451cf737f3..7b731307531 100644 --- 
a/tensorflow/compiler/mlir/tfrt/translate/tfrt_compile_options.h +++ b/tensorflow/compiler/mlir/tfrt/translate/tfrt_compile_options.h @@ -38,6 +38,7 @@ enum class TfrtDeviceInfraTarget { std::ostream& operator<<(std::ostream& os, TfrtDeviceInfraTarget device_target); struct TfrtCompileOptions { + std::string saved_model_dir; // TODO(tfrt-devs): Ideally, compiler should make the decision where // to place the variable. std::string variable_device = "/job:localhost/replica:0/task:0/device:CPU:0"; diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/BUILD b/tensorflow/compiler/mlir/tools/kernel_gen/BUILD index 145cce5ac6b..29d96e79c47 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/BUILD +++ b/tensorflow/compiler/mlir/tools/kernel_gen/BUILD @@ -16,7 +16,6 @@ load( ) load( "//tensorflow/core/platform:build_config.bzl", - "if_llvm_aarch64_available", "if_llvm_system_z_available", "tf_proto_library", ) @@ -102,6 +101,7 @@ tf_cc_binary( "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/core:lib", "@com_google_absl//absl/strings", + "@llvm-project//llvm:AArch64CodeGen", # fixdeps: keep "@llvm-project//llvm:ARMCodeGen", # fixdeps: keep "@llvm-project//llvm:Analysis", "@llvm-project//llvm:CodeGen", @@ -119,8 +119,6 @@ tf_cc_binary( "@llvm-project//mlir:ToLLVMIRTranslation", ] + if_llvm_system_z_available([ "@llvm-project//llvm:SystemZCodeGen", # fixdeps: keep - ]) + if_llvm_aarch64_available([ - "@llvm-project//llvm:AArch64CodeGen", # fixdeps: keep ]), ) diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/ir/BUILD b/tensorflow/compiler/mlir/tools/kernel_gen/ir/BUILD index 29ae6752cd2..e4892ba1d2e 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/ir/BUILD +++ b/tensorflow/compiler/mlir/tools/kernel_gen/ir/BUILD @@ -83,6 +83,7 @@ cc_library( ":tf_framework_ops_inc_gen", ":tf_status_inc_gen", "//tensorflow/core/protobuf:error_codes_proto_impl_cc", + "//tensorflow/tsl/protobuf:error_codes_proto_impl_cc", "@com_google_absl//absl/status", 
"@llvm-project//mlir:AllocationOpInterface", "@llvm-project//mlir:BufferizationDialect", diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.cc b/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.cc index 48e288eb48d..e6f87f387ed 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.cc @@ -28,6 +28,7 @@ limitations under the License. // Generated dialect definitions. #include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_dialect.cc.inc" +#include "tensorflow/tsl/protobuf/error_codes.pb.h" namespace mlir { namespace kernel_gen { diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tests/BUILD b/tensorflow/compiler/mlir/tools/kernel_gen/tests/BUILD index 4d4abfd9d90..3fa42457e1d 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/tests/BUILD +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tests/BUILD @@ -7,6 +7,7 @@ package( ) glob_lit_tests( + name = "all_tests", data = [":test_utilities"], driver = "//tensorflow/compiler/mlir:run_lit.sh", test_file_exts = ["mlir"], diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_framework_legalize_to_llvm.mlir b/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_framework_legalize_to_llvm.mlir index e2a5601fc53..6d90a339e5c 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_framework_legalize_to_llvm.mlir +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_framework_legalize_to_llvm.mlir @@ -13,8 +13,9 @@ func.func @alloc(%ctx: !tf_framework.op_kernel_context, func.return %buf : memref } // Compute number of elements. 
-// CHECK: [[SIZE_1:%.*]] = llvm.mlir.constant(10 : index) : i64 -// CHECK: [[NUM_ELEM_0:%.*]] = llvm.mul [[SIZE_0]], [[SIZE_1]] : i64 +// CHECK: [[SIZE_1A:%.*]] = llvm.mlir.constant(10 : index) : i64 +// CHECK: [[SIZE_1B:%.*]] = llvm.mlir.constant(10 : index) : i64 +// CHECK: [[NUM_ELEM_0:%.*]] = llvm.mul [[SIZE_0]], [[SIZE_1B]] : i64 // CHECK: [[NUM_ELEMS:%.*]] = llvm.mul [[NUM_ELEM_0]], [[SIZE_2]] : i64 // Compute the size of an individual element. @@ -48,9 +49,9 @@ func.func @alloc(%ctx: !tf_framework.op_kernel_context, // CHECK: [[DESC_4:%.*]] = llvm.insertvalue [[SIZE_2]], [[DESC_3]][3, 2] // CHECK: [[DESC_5:%.*]] = llvm.insertvalue [[STRIDE_2]], [[DESC_4]][4, 2] // CHECK: [[STRIDE_1:%.*]] = llvm.mul [[STRIDE_2]], [[SIZE_2]] : i64 -// CHECK: [[DESC_6:%.*]] = llvm.insertvalue [[SIZE_1]], [[DESC_5]][3, 1] +// CHECK: [[DESC_6:%.*]] = llvm.insertvalue [[SIZE_1A]], [[DESC_5]][3, 1] // CHECK: [[DESC_7:%.*]] = llvm.insertvalue [[STRIDE_1]], [[DESC_6]][4, 1] -// CHECK: [[STRIDE_0:%.*]] = llvm.mul [[STRIDE_1]], [[SIZE_1]] : i64 +// CHECK: [[STRIDE_0:%.*]] = llvm.mul [[STRIDE_1]], [[SIZE_1A]] : i64 // CHECK: [[DESC_8:%.*]] = llvm.insertvalue [[SIZE_0]], [[DESC_7]][3, 0] // CHECK: [[DESC_9:%.*]] = llvm.insertvalue [[STRIDE_0]], [[DESC_8]][4, 0] // CHECK: llvm.return [[DESC_9]] : [[DESC_TY]] @@ -212,7 +213,7 @@ func.func @jit_execute(%ctx: !tf_framework.op_kernel_context, // CHECK: %[[ARG:.*]] = llvm.insertvalue %[[ARG_DESCR]], %[[T1]][1] // CHECK: %[[C1:.*]] = llvm.mlir.constant(1 : i64) // CHECK: %[[RESULT_PTR:.*]] = llvm.alloca %[[C1]] x !llvm.struct<(i64, ptr)> - + // Copy argument(s) to stack-allocated buffer. 
// CHECK: %[[NUM_ARGS:.*]] = llvm.mlir.constant(1 : i64) // CHECK: %[[ARGS_PTR:.*]] = llvm.alloca %[[NUM_ARGS]] x !llvm.struct<(i64, ptr)> diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_to_kernel/BUILD b/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_to_kernel/BUILD index 38f0b297272..d0caa87983d 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_to_kernel/BUILD +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_to_kernel/BUILD @@ -6,6 +6,7 @@ package( ) glob_lit_tests( + name = "all_tests", data = [":test_utilities"], default_tags = [ # We need access to the CUDA SDK. diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tf_framework_c_interface.cc b/tensorflow/compiler/mlir/tools/kernel_gen/tf_framework_c_interface.cc index fe9d26723b9..73cc324f405 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/tf_framework_c_interface.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tf_framework_c_interface.cc @@ -176,8 +176,8 @@ llvm::Expected> Compile( } // Create the kernel. - mlir::OwningOpRef module; mlir::MLIRContext context; + mlir::OwningOpRef module; if (item.result_module().empty()) { // Otherwise, compile the module now. diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tf_to_kernel.cc b/tensorflow/compiler/mlir/tools/kernel_gen/tf_to_kernel.cc index 8d5d583a2dc..e6ab814216a 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/tf_to_kernel.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tf_to_kernel.cc @@ -17,6 +17,7 @@ // This file implements the entry point to compile a tf op to a kernel. 
// //===----------------------------------------------------------------------===// +#include #include #include #include @@ -50,10 +51,15 @@ namespace { static llvm::codegen::RegisterCodeGenFlags CGF; -std::unique_ptr GetTargetMachine(llvm::Module* module) { +std::unique_ptr GetTargetMachine( + llvm::StringRef host_triple, llvm::Module* module) { llvm::Triple triple(module->getTargetTriple()); if (triple.getTriple().empty()) { - triple = llvm::Triple(llvm::sys::getDefaultTargetTriple()); + if (!host_triple.empty()) { + triple = llvm::Triple(host_triple); + } else { + triple = llvm::Triple(llvm::sys::getDefaultTargetTriple()); + } module->setTargetTriple(triple.getTriple()); } @@ -71,14 +77,15 @@ std::unique_ptr GetTargetMachine(llvm::Module* module) { } // Compiles the given MLIR module via LLVM into an executable binary format. -StatusOr EmitToBinary(mlir::ModuleOp module) { +StatusOr EmitToBinary(llvm::StringRef host_triple, + mlir::ModuleOp module) { // Translate the module. llvm::LLVMContext llvm_context; mlir::registerLLVMDialectTranslation(*module->getContext()); std::unique_ptr llvm_module = mlir::translateModuleToLLVMIR(module, llvm_context); - auto target_machine = GetTargetMachine(llvm_module.get()); + auto target_machine = GetTargetMachine(host_triple, llvm_module.get()); llvm_module->setDataLayout(target_machine->createDataLayout()); // Run LLVM's mid-level optimizer to clean up the IR. @@ -106,6 +113,7 @@ StatusOr EmitToBinary(mlir::ModuleOp module) { } Status Run(llvm::StringRef input_file, llvm::StringRef output_file, + llvm::StringRef host_triple, llvm::ArrayRef architectures, llvm::ArrayRef tile_sizes, llvm::ArrayRef unroll_factors, int64_t max_supported_rank, @@ -130,7 +138,7 @@ Status Run(llvm::StringRef input_file, llvm::StringRef output_file, /*apply_cl_options=*/true)); // Get binary. 
- TF_ASSIGN_OR_RETURN(std::string binary, EmitToBinary(*module)); + TF_ASSIGN_OR_RETURN(std::string binary, EmitToBinary(host_triple, *module)); // Write .a file. TF_RETURN_IF_ERROR( @@ -167,6 +175,8 @@ int main(int argc, char** argv) { llvm::cl::opt jit_compile( "jit", llvm::cl::desc("Generate only a JIT compiler invocation."), llvm::cl::init(false)); + llvm::cl::opt host_triple( + "host-triple", llvm::cl::desc("Override host triple for module")); llvm::cl::list architectures( "arch", llvm::cl::desc("target architectures (e.g. sm_70 or compute_75)"), llvm::cl::ZeroOrMore, llvm::cl::CommaSeparated); @@ -189,16 +199,25 @@ int main(int argc, char** argv) { llvm::cl::init(false)); tensorflow::InitMlir y(&argc, &argv); - llvm::InitializeNativeTarget(); - llvm::InitializeNativeTargetAsmPrinter(); + + LLVMInitializeX86Target(); + LLVMInitializeX86TargetInfo(); + LLVMInitializeX86TargetMC(); + LLVMInitializeX86AsmPrinter(); + + LLVMInitializeAArch64Target(); + LLVMInitializeAArch64TargetInfo(); + LLVMInitializeAArch64TargetMC(); + LLVMInitializeAArch64AsmPrinter(); + mlir::registerPassManagerCLOptions(); mlir::registerMLIRContextCLOptions(); llvm::cl::ParseCommandLineOptions(argc, argv, "TF op kernel generator\n"); auto status = tensorflow::kernel_gen::Run( - input_file, output_file, architectures, tile_sizes, unroll_factors, - max_supported_rank, print_ptx, print_llvmir, enable_ftz, index_64bit, - jit_compile, jit_i64_indexed_for_large_tensors); + input_file, output_file, host_triple, architectures, tile_sizes, + unroll_factors, max_supported_rank, print_ptx, print_llvmir, enable_ftz, + index_64bit, jit_compile, jit_i64_indexed_for_large_tensors); if (!status.ok()) { LOG(ERROR) << status; return 1; diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/buffer_reuse_pass.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/buffer_reuse_pass.cc index e2cf9051761..796f133cff5 100644 --- 
a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/buffer_reuse_pass.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/buffer_reuse_pass.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include +#include #include #include diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/copy_cleanup_pass.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/copy_cleanup_pass.cc index e54f36f5684..af0943ded1f 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/copy_cleanup_pass.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/copy_cleanup_pass.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include #include "llvm/ADT/ArrayRef.h" diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/embed_tf_framework_pass.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/embed_tf_framework_pass.cc index 1a8bc8882f8..ed1138849e5 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/embed_tf_framework_pass.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/embed_tf_framework_pass.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ +#include #include #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/fuse_inner_parallel_loops_pass.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/fuse_inner_parallel_loops_pass.cc index afba89b635d..5698d6c7025 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/fuse_inner_parallel_loops_pass.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/fuse_inner_parallel_loops_pass.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include + #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/Dialect/SCF/IR/SCF.h" // from @llvm-project #include "mlir/Dialect/SCF/Transforms/Transforms.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/gpu_kernel_to_blob_pass.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/gpu_kernel_to_blob_pass.cc index 0d76cd4c93c..561e87e6dda 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/gpu_kernel_to_blob_pass.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/gpu_kernel_to_blob_pass.cc @@ -13,6 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ +#include +#include +#include +#include +#include +#include + #include "llvm/Transforms/Utils/Cloning.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/Target/LLVMIR/Export.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/parallel_loops_to_sequential.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/parallel_loops_to_sequential.cc index 0bfacccc1a1..35b11c8abf0 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/parallel_loops_to_sequential.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/parallel_loops_to_sequential.cc @@ -13,6 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include +#include + #include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" // from @llvm-project #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/Dialect/SCF/IR/SCF.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h index 167b370e17f..84712606e98 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h @@ -17,6 +17,7 @@ limitations under the License. 
#define TENSORFLOW_COMPILER_MLIR_TOOLS_KERNEL_GEN_TRANSFORMS_PASSES_H_ #include +#include #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/Dialect/GPU/IR/GPUDialect.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/rewrite_tf_framework_assert.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/rewrite_tf_framework_assert.cc index b6f3a98237c..16116cac215 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/rewrite_tf_framework_assert.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/rewrite_tf_framework_assert.cc @@ -13,7 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include +#include #include +#include #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" // from @llvm-project #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/same_shape_propagation.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/same_shape_propagation.cc index 246556da0ca..1e422f6ab88 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/same_shape_propagation.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/same_shape_propagation.cc @@ -18,6 +18,8 @@ limitations under the License. // sizes of operands with equal shapes. 
#include +#include +#include #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/DenseMapInfo.h" diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/shape_to_descriptors_pass.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/shape_to_descriptors_pass.cc index 3be09ac912b..b308241aede 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/shape_to_descriptors_pass.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/shape_to_descriptors_pass.cc @@ -16,6 +16,9 @@ limitations under the License. // This file combines patterns for lowering shape dialect to standard ops, // structured control flow and descriptors. +#include +#include + #include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h" // from @llvm-project #include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/tf_framework_legalize_to_llvm.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/tf_framework_legalize_to_llvm.cc index 9a5b0749888..8adb4c1eebe 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/tf_framework_legalize_to_llvm.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/tf_framework_legalize_to_llvm.cc @@ -13,8 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include #include +#include #include "mlir/Conversion/LLVMCommon/Pattern.h" // from @llvm-project #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project @@ -125,7 +127,8 @@ class TFAllocOpConverter : public ConvertToLLVMCallOpPattern { llvm::to_vector<4>(adaptor.getDynSizes()), rewriter, sizes, strides, sizeBytes); // Get number of elements. 
- Value num_elements = getNumElements(loc, sizes, rewriter); + Value num_elements = + getNumElements(loc, memref_type, adaptor.getDynSizes(), rewriter); // Get element size. Value element_size = getSizeInBytes(loc, memref_type.getElementType(), rewriter); diff --git a/tensorflow/compiler/mlir/tosa/tests/BUILD b/tensorflow/compiler/mlir/tosa/tests/BUILD index e7c4a5b9a61..a523ba82942 100644 --- a/tensorflow/compiler/mlir/tosa/tests/BUILD +++ b/tensorflow/compiler/mlir/tosa/tests/BUILD @@ -7,6 +7,7 @@ package( ) glob_lit_tests( + name = "all_tests", data = [":test_utilities"], driver = "@llvm-project//mlir:run_lit.sh", size_override = { diff --git a/tensorflow/compiler/mlir/tosa/tests/strip-quant-types.mlir b/tensorflow/compiler/mlir/tosa/tests/strip-quant-types.mlir index fbadf264186..803c1415dc4 100644 --- a/tensorflow/compiler/mlir/tosa/tests/strip-quant-types.mlir +++ b/tensorflow/compiler/mlir/tosa/tests/strip-quant-types.mlir @@ -1,23 +1,25 @@ // RUN: tf-opt --split-input-file --tosa-strip-quant-types --verify-each %s | FileCheck %s -// CHECK-LABEL: @test_add_qi8 -// CHECK-SAME: %arg0: tensor) -> tensor -func.func @test_add_qi8(%arg0: tensor>) -> tensor> { - %0 = "tosa.add"(%arg0, %arg0) : (tensor>, tensor>) -> tensor> +// ----- - // CHECK: %[[VAR0:.+]] = "tosa.add"(%arg0, %arg0) : (tensor, tensor) -> tensor - // CHECK: return %[[VAR0]] : tensor - func.return %0 : tensor> +// CHECK-LABEL: @test_max_pool2d_qi8 +// CHECK-SAME: %arg0: tensor<1x4x4x4xi8>) -> tensor<1x4x4x4xi8> +func.func @test_max_pool2d_qi8(%arg0: tensor<1x4x4x4x!quant.uniform>) -> tensor<1x4x4x4x!quant.uniform> { + %0 = "tosa.max_pool2d"(%arg0) {kernel = array, pad = array, stride = array} : (tensor<1x4x4x4x!quant.uniform>) -> tensor<1x4x4x4x!quant.uniform> + + // CHECK: %[[VAR0:.+]] = "tosa.max_pool2d"(%arg0) <{kernel = array, pad = array, stride = array}> : (tensor<1x4x4x4xi8>) -> tensor<1x4x4x4xi8> + // CHECK: return %[[VAR0]] : tensor<1x4x4x4xi8> + func.return %0 : 
tensor<1x4x4x4x!quant.uniform> } -// ---- +// ----- -// CHECK-LABEL: @test_add_qu8 +// CHECK-LABEL: @test_bitwise_not_qu8 // CHECK-SAME: %arg0: tensor) -> tensor -func.func @test_add_qu8(%arg0: tensor>) -> tensor> { - %0 = "tosa.add"(%arg0, %arg0) : (tensor>, tensor>) -> tensor> +func.func @test_bitwise_not_qu8(%arg0: tensor>) -> tensor> { + %0 = "tosa.bitwise_not"(%arg0) : (tensor>) -> tensor> - // CHECK: %[[VAR0:.+]] = "tosa.add"(%arg0, %arg0) : (tensor, tensor) -> tensor + // CHECK: %[[VAR0:.+]] = "tosa.bitwise_not"(%arg0) : (tensor) -> tensor // CHECK: return %[[VAR0]] : tensor - func.return %0 : tensor> + func.return %0 : tensor> } diff --git a/tensorflow/compiler/mlir/tosa/tests/tf-to-tosa-pipeline.mlir b/tensorflow/compiler/mlir/tosa/tests/tf-to-tosa-pipeline.mlir index 5cacdf03552..47e2571e2bb 100644 --- a/tensorflow/compiler/mlir/tosa/tests/tf-to-tosa-pipeline.mlir +++ b/tensorflow/compiler/mlir/tosa/tests/tf-to-tosa-pipeline.mlir @@ -565,7 +565,7 @@ func.func @test_argmax(%arg0: tensor<13x21x3xf32>) -> tensor<21x3xi32> { // ----- // CHECK-LABEL: test_avg_pool2d -// CHECK: %[[VAR0:.*]] = "tosa.avg_pool2d"(%arg0) <{kernel = array, pad = array, stride = array} +// CHECK: %[[VAR0:.*]] = "tosa.avg_pool2d"(%arg0) <{acc_type = f32, kernel = array, pad = array, stride = array}> func.func @test_avg_pool2d(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> { %2 = "tf.AvgPool"(%arg0) {data_format = "NHWC", ksize = [1, 1, 1, 1], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> func.return %2 : tensor<1x32x32x8xf32> @@ -617,8 +617,8 @@ func.func @test_slice(%arg0: tensor<13x21x3xf32>) -> tensor<4x11x1xf32> { // CHECK-LABEL: test_strided_slice // CHECK-DAG: %[[VAR0:.*]] = "tosa.slice"(%arg0) <{size = array, start = array} -// CHECK-DAG: %[[VAR1:.*]] = "tosa.reshape"(%[[VAR0]]) <{new_shape = array} -// CHECK-DAG: %[[VAR2:.*]] = "tosa.slice"(%[[VAR1]]) <{size = array, start = array} +// CHECK-DAG: %[[VAR1:.*]] = 
"tosa.reshape"(%[[VAR0]]) <{new_shape = array} +// CHECK-DAG: %[[VAR2:.*]] = "tosa.slice"(%[[VAR1]]) <{size = array, start = array} // CHECK: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]]) <{new_shape = array} func.func @test_strided_slice(%arg0: tensor<13x21x3xf32>) -> tensor<9x7x2xf32> { %2 = "tf.Const"() {value = dense<[4, 0, 1]> : tensor<3xi64>} : () -> tensor<3xi64> diff --git a/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir b/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir index ddfc7eefe81..145b1877761 100644 --- a/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir +++ b/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir @@ -962,7 +962,7 @@ func.func @test_less_equal_dynamic(%arg0: tensor<13x1x3xf32>, %arg1: tensor<13x? // ----- // CHECK-LABEL: test_avg_pool2d -// CHECK: %[[VAR0:.*]] = "tosa.avg_pool2d"(%arg0) <{kernel = array, pad = array, stride = array}> +// CHECK: %[[VAR0:.*]] = "tosa.avg_pool2d"(%arg0) <{acc_type = f32, kernel = array, pad = array, stride = array}> func.func @test_avg_pool2d(%arg0: tensor<1x32x32x8xf32>) -> tensor<*xf32> { %0 = "tfl.average_pool_2d"(%arg0) {filter_height = 1 : i32, filter_width = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x32x32x8xf32>) -> tensor<*xf32> func.return %0 : tensor<*xf32> @@ -971,7 +971,7 @@ func.func @test_avg_pool2d(%arg0: tensor<1x32x32x8xf32>) -> tensor<*xf32> { // ----- // CHECK-LABEL: test_avg_pool2d_dynamic -// CHECK: %[[VAR0:.*]] = "tosa.avg_pool2d"(%arg0) <{kernel = array, pad = array, stride = array}> +// CHECK: %[[VAR0:.*]] = "tosa.avg_pool2d"(%arg0) <{acc_type = f32, kernel = array, pad = array, stride = array}> func.func @test_avg_pool2d_dynamic(%arg0: tensor) -> tensor<*xf32> { %0 = "tfl.average_pool_2d"(%arg0) {filter_height = 1 : i32, filter_width = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor) -> tensor<*xf32> 
func.return %0 : tensor<*xf32> @@ -1064,14 +1064,14 @@ func.func @test_slice(%arg0: tensor<13x21x3xf32>) -> tensor<*xf32> { // CHECK-LABEL: test_strided_slice_simple // CHECK-DAG: %[[VAR0:.*]] = "tosa.slice"(%arg0) <{size = array, start = array}> -// CHECK-DAG: %[[VAR1:.*]] = "tosa.reshape"(%[[VAR0]]) <{new_shape = array}> -// CHECK-DAG: %[[VAR2:.*]] = "tosa.slice"(%[[VAR1]]) <{size = array, start = array}> +// CHECK-DAG: %[[VAR1:.*]] = "tosa.reshape"(%[[VAR0]]) <{new_shape = array}> +// CHECK-DAG: %[[VAR2:.*]] = "tosa.slice"(%[[VAR1]]) <{size = array, start = array}> // CHECK: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]]) <{new_shape = array}> func.func @test_strided_slice_simple(%arg0: tensor<13x21x3xf32>) -> tensor<*xf32> { %cst = arith.constant dense<[4, 0, 1]> : tensor<3xi32> %cst_0 = arith.constant dense<[13, 21, 3]> : tensor<3xi32> %cst_1 = arith.constant dense<[1, 3, 1]> : tensor<3xi32> - %0 = "tfl.strided_slice"(%arg0, %cst, %cst_0, %cst_1) {begin_mask = 2 : i32, ellipsis_mask = 0 : i32, end_mask = 3 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 0 : i32} : (tensor<13x21x3xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<*xf32> + %0 = "tfl.strided_slice"(%arg0, %cst, %cst_0, %cst_1) {begin_mask = 2 : i32, ellipsis_mask = 0 : i32, end_mask = 3 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 0 : i32, offset = false} : (tensor<13x21x3xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<*xf32> func.return %0 : tensor<*xf32> } @@ -1079,14 +1079,14 @@ func.func @test_strided_slice_simple(%arg0: tensor<13x21x3xf32>) -> tensor<*xf32 // CHECK-LABEL: test_strided_slice_simple_negative // CHECK-DAG: %[[VAR0:.*]] = "tosa.slice"(%arg0) <{size = array, start = array}> -// CHECK-DAG: %[[VAR1:.*]] = "tosa.reshape"(%[[VAR0]]) <{new_shape = array}> -// CHECK-DAG: %[[VAR2:.*]] = "tosa.slice"(%[[VAR1]]) <{size = array, start = array}> +// CHECK-DAG: %[[VAR1:.*]] = "tosa.reshape"(%[[VAR0]]) <{new_shape = array}> +// CHECK-DAG: %[[VAR2:.*]] = 
"tosa.slice"(%[[VAR1]]) <{size = array, start = array}> // CHECK: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]]) <{new_shape = array}> func.func @test_strided_slice_simple_negative(%arg0: tensor<13x21x3xf32>) -> tensor<*xf32> { %cst = arith.constant dense<[4, 0, 1]> : tensor<3xi32> %cst_0 = arith.constant dense<[13, -3, 3]> : tensor<3xi32> %cst_1 = arith.constant dense<[1, 3, 1]> : tensor<3xi32> - %0 = "tfl.strided_slice"(%arg0, %cst, %cst_0, %cst_1) {begin_mask = 2 : i32, ellipsis_mask = 0 : i32, end_mask = 1 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 0 : i32} : (tensor<13x21x3xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<*xf32> + %0 = "tfl.strided_slice"(%arg0, %cst, %cst_0, %cst_1) {begin_mask = 2 : i32, ellipsis_mask = 0 : i32, end_mask = 1 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 0 : i32, offset = false} : (tensor<13x21x3xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<*xf32> func.return %0 : tensor<*xf32> } @@ -1099,7 +1099,7 @@ func.func @test_strided_slice_strideless(%arg0: tensor<13x21x3xf32>) -> tensor<* %cst = arith.constant dense<[4, 0, 1]> : tensor<3xi32> %cst_0 = arith.constant dense<[13, 21, 3]> : tensor<3xi32> %cst_1 = arith.constant dense<[1, 1, 1]> : tensor<3xi32> - %0 = "tfl.strided_slice"(%arg0, %cst, %cst_0, %cst_1) {begin_mask = 2 : i32, ellipsis_mask = 0 : i32, end_mask = 3 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 2 : i32} : (tensor<13x21x3xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<*xf32> + %0 = "tfl.strided_slice"(%arg0, %cst, %cst_0, %cst_1) {begin_mask = 2 : i32, ellipsis_mask = 0 : i32, end_mask = 3 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 2 : i32, offset = false} : (tensor<13x21x3xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<*xf32> func.return %0 : tensor<*xf32> } @@ -1107,14 +1107,14 @@ func.func @test_strided_slice_strideless(%arg0: tensor<13x21x3xf32>) -> tensor<* // CHECK-LABEL: test_strided_slice_shrink // CHECK-DAG: %[[VAR0:.*]] = 
"tosa.slice"(%arg0) <{size = array, start = array}> -// CHECK-DAG: %[[VAR1:.*]] = "tosa.reshape"(%[[VAR0]]) <{new_shape = array}> -// CHECK-DAG: %[[VAR2:.*]] = "tosa.slice"(%[[VAR1]]) <{size = array, start = array}> +// CHECK-DAG: %[[VAR1:.*]] = "tosa.reshape"(%[[VAR0]]) <{new_shape = array}> +// CHECK-DAG: %[[VAR2:.*]] = "tosa.slice"(%[[VAR1]]) <{size = array, start = array}> // CHECK: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]]) <{new_shape = array}> func.func @test_strided_slice_shrink(%arg0: tensor<13x21x3xf32>) -> tensor<*xf32> { %cst = arith.constant dense<[4, 0, 1]> : tensor<3xi32> %cst_0 = arith.constant dense<[13, 21, 3]> : tensor<3xi32> %cst_1 = arith.constant dense<[1, 3, 1]> : tensor<3xi32> - %0 = "tfl.strided_slice"(%arg0, %cst, %cst_0, %cst_1) {begin_mask = 2 : i32, ellipsis_mask = 0 : i32, end_mask = 3 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 5 : i32} : (tensor<13x21x3xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<*xf32> + %0 = "tfl.strided_slice"(%arg0, %cst, %cst_0, %cst_1) {begin_mask = 2 : i32, ellipsis_mask = 0 : i32, end_mask = 3 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 5 : i32, offset = false} : (tensor<13x21x3xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<*xf32> func.return %0 : tensor<*xf32> } @@ -1127,7 +1127,7 @@ func.func @test_strided_slice_shrink_ignore_stride(%arg0: tensor<13x21x3xf32>) - %cst = arith.constant dense<[4, 0, 1]> : tensor<3xi32> %cst_0 = arith.constant dense<[13, 21, 3]> : tensor<3xi32> %cst_1 = arith.constant dense<[1, 3, 1]> : tensor<3xi32> - %0 = "tfl.strided_slice"(%arg0, %cst, %cst_0, %cst_1) {begin_mask = 2 : i32, ellipsis_mask = 0 : i32, end_mask = 3 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 3 : i32} : (tensor<13x21x3xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<*xf32> + %0 = "tfl.strided_slice"(%arg0, %cst, %cst_0, %cst_1) {begin_mask = 2 : i32, ellipsis_mask = 0 : i32, end_mask = 3 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 3 : i32, 
offset = false} : (tensor<13x21x3xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<*xf32> func.return %0 : tensor<*xf32> } @@ -1142,7 +1142,7 @@ func.func @test_strided_slice_unstrided(%arg0: tensor<13x21x3xf32>) -> tensor<*x %cst = arith.constant dense<[4, 0, 1]> : tensor<3xi32> %cst_0 = arith.constant dense<[13, 21, 3]> : tensor<3xi32> %cst_1 = arith.constant dense<[1, 1, -1]> : tensor<3xi32> - %0 = "tfl.strided_slice"(%arg0, %cst, %cst_0, %cst_1) {begin_mask = 2 : i32, ellipsis_mask = 0 : i32, end_mask = 3 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 0 : i32} : (tensor<13x21x3xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<*xf32> + %0 = "tfl.strided_slice"(%arg0, %cst, %cst_0, %cst_1) {begin_mask = 2 : i32, ellipsis_mask = 0 : i32, end_mask = 3 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 0 : i32, offset = false} : (tensor<13x21x3xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<*xf32> func.return %0 : tensor<*xf32> } @@ -1157,7 +1157,7 @@ func.func @test_strided_slice_unstrided_shorter(%arg0: tensor<13x21x3xf32>) -> t %cst = arith.constant dense<[4, 0]> : tensor<2xi32> %cst_0 = arith.constant dense<[13, 21]> : tensor<2xi32> %cst_1 = arith.constant dense<[1, -1]> : tensor<2xi32> - %0 = "tfl.strided_slice"(%arg0, %cst, %cst_0, %cst_1) {begin_mask = 2 : i32, ellipsis_mask = 0 : i32, end_mask = 3 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 0 : i32} : (tensor<13x21x3xf32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<*xf32> + %0 = "tfl.strided_slice"(%arg0, %cst, %cst_0, %cst_1) {begin_mask = 2 : i32, ellipsis_mask = 0 : i32, end_mask = 3 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 0 : i32, offset = false} : (tensor<13x21x3xf32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<*xf32> func.return %0 : tensor<*xf32> } @@ -1171,7 +1171,7 @@ func.func @test_strided_slice_unstrided_shorter(%arg0: tensor<13x21x3xf32>) -> t func.func @test_strided_slice_dynamic_masked(%arg0: 
tensor<10x?x?xf32>, %arg1: tensor<3xi32>) -> tensor<*xf32> { %cst_0 = arith.constant dense<[13, -1, 3]> : tensor<3xi32> %cst_1 = arith.constant dense<[1, -1, -1]> : tensor<3xi32> - %0 = "tfl.strided_slice"(%arg0, %arg1, %cst_0, %cst_1) {begin_mask = 7 : i32, ellipsis_mask = 0 : i32, end_mask = 7 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 0 : i32} : (tensor<10x?x?xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<*xf32> + %0 = "tfl.strided_slice"(%arg0, %arg1, %cst_0, %cst_1) {begin_mask = 7 : i32, ellipsis_mask = 0 : i32, end_mask = 7 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 0 : i32, offset = false} : (tensor<10x?x?xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<*xf32> func.return %0 : tensor<*xf32> } @@ -1190,7 +1190,7 @@ func.func @test_strided_slice_dynamic_begin(%arg0: tensor<10x?x?xf32>) -> tensor // CHECK: %[[VAR0:.*]] = "tosa.reverse"(%arg0) <{axis = 1 : i64}> // CHECK: %[[VAR1:.*]] = "tosa.reverse"(%[[VAR0]]) <{axis = 2 : i64}> // CHECK: return %[[VAR1]] - %0 = "tfl.strided_slice"(%arg0, %cst, %cst_0, %cst_1) {begin_mask = 2 : i32, ellipsis_mask = 0 : i32, end_mask = 7 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 0 : i32} : (tensor<10x?x?xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<*xf32> + %0 = "tfl.strided_slice"(%arg0, %cst, %cst_0, %cst_1) {begin_mask = 2 : i32, ellipsis_mask = 0 : i32, end_mask = 7 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 0 : i32, offset = false} : (tensor<10x?x?xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<*xf32> func.return %0 : tensor<*xf32> } // ----- @@ -1203,10 +1203,10 @@ func.func @test_strided_slice_dynamic_end(%arg0: tensor<10x?x?xf32>) -> tensor<* %stride = arith.constant dense<[1, 2, -1]> : tensor<3xi32> // CHECK: %[[SLICE1:.+]] = "tosa.slice"(%arg0) <{size = array, start = array}> - // CHECK: %[[RESHAPE1:.+]] = "tosa.reshape"(%[[SLICE1]]) <{new_shape = array}> - // CHECK: %[[SLICE2:.+]] = "tosa.slice"(%[[RESHAPE1]]) <{size = array, 
start = array}> + // CHECK: %[[RESHAPE1:.+]] = "tosa.reshape"(%[[SLICE1]]) <{new_shape = array}> + // CHECK: %[[SLICE2:.+]] = "tosa.slice"(%[[RESHAPE1]]) <{size = array, start = array}> // CHECK: %[[RESHAPE2:.+]] = "tosa.reshape"(%[[SLICE2]]) <{new_shape = array}> - %0 = "tfl.strided_slice"(%arg0, %begin, %end, %stride) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 2 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 4 : i32} : (tensor<10x?x?xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<*xf32> + %0 = "tfl.strided_slice"(%arg0, %begin, %end, %stride) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 2 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 4 : i32, offset = false} : (tensor<10x?x?xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<*xf32> // CHECK: return %[[RESHAPE2]] func.return %0 : tensor<*xf32> } @@ -1882,7 +1882,7 @@ func.func @test_mul_qi8(%arg0: tensor<13x21x3x!quant.uniform, pad = array, quantization_info = #tosa.unary_quant, stride = array}> +// CHECK: %[[VAR0:.*]] = "tosa.avg_pool2d"(%arg0) <{acc_type = i32, kernel = array, pad = array, quantization_info = #tosa.unary_quant, stride = array}> // CHECK-SAME: -> tensor<1x32x32x8x!quant.uniform> func.func @test_avg_pool2d_qi8(%arg0: tensor<1x32x32x8x!quant.uniform>) -> tensor<*x!quant.uniform> { %0 = "tfl.average_pool_2d"(%arg0) {filter_height = 1 : i32, filter_width = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x32x32x8x!quant.uniform>) -> tensor<*x!quant.uniform> @@ -1892,7 +1892,7 @@ func.func @test_avg_pool2d_qi8(%arg0: tensor<1x32x32x8x!quant.uniform, pad = array, stride = array}> +// CHECK: %[[VAR0:.*]] = "tosa.avg_pool2d"(%arg0) <{acc_type = i32, kernel = array, pad = array, stride = array}> // CHECK-SAME: -> tensor<1x32x32x8xi16> func.func @test_avg_pool2d_i16(%arg0: tensor<1x32x32x8xi16>) -> tensor<*xi16> { %0 = "tfl.average_pool_2d"(%arg0) {filter_height = 1 : i32, 
filter_width = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x32x32x8xi16>) -> tensor<*xi16> diff --git a/tensorflow/compiler/mlir/tosa/transforms/convert_tfl_uint8.cc b/tensorflow/compiler/mlir/tosa/transforms/convert_tfl_uint8.cc index 62058cbe799..c3ef4e0bd67 100644 --- a/tensorflow/compiler/mlir/tosa/transforms/convert_tfl_uint8.cc +++ b/tensorflow/compiler/mlir/tosa/transforms/convert_tfl_uint8.cc @@ -27,7 +27,9 @@ limitations under the License. #include #include #include +#include #include +#include #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/Dialect/Tosa/IR/TosaOps.h" // from @llvm-project @@ -58,7 +60,7 @@ namespace { class ConvertUint8ToInt8 : public impl::TosaConvertTFLUint8PassBase { public: - explicit ConvertUint8ToInt8() {} + explicit ConvertUint8ToInt8() = default; void runOnOperation() override; }; diff --git a/tensorflow/compiler/mlir/tosa/transforms/dequantize_tfl_softmax.cc b/tensorflow/compiler/mlir/tosa/transforms/dequantize_tfl_softmax.cc index 13c60e8af21..b64e4eda6d5 100644 --- a/tensorflow/compiler/mlir/tosa/transforms/dequantize_tfl_softmax.cc +++ b/tensorflow/compiler/mlir/tosa/transforms/dequantize_tfl_softmax.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ +#include #include #include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project @@ -36,7 +37,7 @@ namespace { class TosaDequantizeTFLSoftmax : public impl::TosaDequantizeTFLSoftmaxPassBase { public: - explicit TosaDequantizeTFLSoftmax() {} + explicit TosaDequantizeTFLSoftmax() = default; void runOnOperation() override; }; diff --git a/tensorflow/compiler/mlir/tosa/transforms/fuse_bias_tf.cc b/tensorflow/compiler/mlir/tosa/transforms/fuse_bias_tf.cc index 1fc479502d2..a3521eea92b 100644 --- a/tensorflow/compiler/mlir/tosa/transforms/fuse_bias_tf.cc +++ b/tensorflow/compiler/mlir/tosa/transforms/fuse_bias_tf.cc @@ -19,8 +19,10 @@ limitations under the License. #include #include #include +#include #include #include +#include #include "mlir/Dialect/Tosa/IR/TosaOps.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project @@ -44,7 +46,7 @@ namespace { class FuseBiasTF : public impl::TosaFusebiasTFPassBase { public: - explicit FuseBiasTF() {} + explicit FuseBiasTF() = default; void runOnOperation() override; }; diff --git a/tensorflow/compiler/mlir/tosa/transforms/legalize_common.cc b/tensorflow/compiler/mlir/tosa/transforms/legalize_common.cc index 54539429695..e4cf77db4cd 100644 --- a/tensorflow/compiler/mlir/tosa/transforms/legalize_common.cc +++ b/tensorflow/compiler/mlir/tosa/transforms/legalize_common.cc @@ -23,12 +23,16 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/tosa/transforms/legalize_common.h" +#include #include +#include #include #include #include +#include #include #include +#include #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/STLExtras.h" @@ -1414,6 +1418,14 @@ std::optional convertSoftmaxOp(PatternRewriter& rewriter, Operation* op, return std::nullopt; } + // beta is not exposed from the TF API, assume only beta=1.0 is supported + // For more details: https://github.com/tensorflow/tensorflow/issues/60435 + if (beta != 1.0) { + (void)rewriter.notifyMatchFailure( + op, "beta values other than 1.0 are not supported"); + return std::nullopt; + } + // reduce_sum on last dimension int32_t input_rank = input_type.getShape().size(); ArrayRef logits_shape = output_type.getShape(); @@ -2271,9 +2283,7 @@ std::optional convertStridedSliceOp( // tensor // // 2. Reshape2: Reshape the tensor from (1) such that each dimension with - // stride is split into two dimensions of size_i/stride_i, stride_i. A naive - // implementation doubles the input tensor rank, but only dimensions being - // strided actually need to be doubled. + // abs(stride) != 1 is split into two dimensions of size_i/stride_i, stride_i. // // 3. Slice3: Slice the tensor from (2) such that we select index [0] from // each of the stride_i dimensions in (2) @@ -2316,7 +2326,6 @@ std::optional convertStridedSliceOp( int32_t strides_size = strides.size(); for (auto stride : strides) all_strides_one &= abs(stride) == 1; - // If all of the masks are set we can just bypass the entire thing. const int32_t all_masks_one = (1 << strides_size) - 1; @@ -2448,10 +2457,14 @@ std::optional convertStridedSliceOp( } // Step 2: reshape the sliced array - SmallVector a2_shape(input_rank * 2); + SmallVector a2_shape; for (int i = 0; i < input_rank; ++i) { - a2_shape[i * 2 + 0] = a1_size[i] == -1 ? 
-1 : a1_size[i] / abs(strides[i]); - a2_shape[i * 2 + 1] = abs(strides[i]); + int64_t abs_stride_i = abs(strides[i]); + a2_shape.push_back(a1_size[i] == -1 ? -1 : a1_size[i] / abs_stride_i); + if (abs_stride_i != 1) { + // only add a stride dimension if strides[i] != 1 + a2_shape.push_back(abs_stride_i); + } } auto a2_reshape_op = CreateOpAndInfer( @@ -2462,19 +2475,24 @@ std::optional convertStridedSliceOp( tensorflow::ConvertMlirShapeToTF(a2_shape))); // Step 3: take a slice along the strides - SmallVector a3_begin(input_rank * 2), a3_size(input_rank * 2); + SmallVector a3_begin, a3_size; for (int i = 0; i < input_rank; ++i) { - a3_begin[i * 2 + 0] = 0; - a3_begin[i * 2 + 1] = 0; + int64_t abs_stride_i = abs(strides[i]); + a3_begin.push_back(0); if (shrink_axis_mask & (1 << i)) { - a3_size[i * 2 + 0] = 1; + a3_size.push_back(1); } else { - a3_size[i * 2 + 0] = - (a1_size[i] == -1) ? -1 : (a1_size[i] / abs(strides[i])); + a3_size.push_back((a1_size[i] == -1) ? -1 : (a1_size[i] / abs_stride_i)); + } + if (abs_stride_i != 1) { + // previous reshape only adds a stride dimension if strides[i] != 1 + a3_begin.push_back(0); + a3_size.push_back(1); } - a3_size[i * 2 + 1] = 1; } + assert(a2_shape.size() == a3_begin.size()); + assert(a2_shape.size() == a3_size.size()); auto a3_slice_op = CreateOpAndInfer( rewriter, op->getLoc(), diff --git a/tensorflow/compiler/mlir/tosa/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/tosa/transforms/legalize_tf.cc index 5418eab622c..082cfe74018 100644 --- a/tensorflow/compiler/mlir/tosa/transforms/legalize_tf.cc +++ b/tensorflow/compiler/mlir/tosa/transforms/legalize_tf.cc @@ -632,8 +632,14 @@ LogicalResult ConvertTFAvgPoolOp::matchAndRewrite( return failure(); } - CreateReplaceOpAndInfer( - rewriter, op, output_type, tf_avgpool_op.getValue(), kernel, stride, pad); + // Tosa supports FP16 and FP32 accumulator type for FP16 input. 
When the time + // FP16 is supported, the accumulator type can be selected based on trade-off + // between performance and accuracy. Set to FP32 by default. + auto acc_attr = mlir::TypeAttr::get(rewriter.getF32Type()); + + CreateReplaceOpAndInfer(rewriter, op, output_type, + tf_avgpool_op.getValue(), kernel, + stride, pad, acc_attr); return success(); } diff --git a/tensorflow/compiler/mlir/tosa/transforms/legalize_tf_tfl.cc b/tensorflow/compiler/mlir/tosa/transforms/legalize_tf_tfl.cc index d40688d570d..72c86e40a7e 100644 --- a/tensorflow/compiler/mlir/tosa/transforms/legalize_tf_tfl.cc +++ b/tensorflow/compiler/mlir/tosa/transforms/legalize_tf_tfl.cc @@ -15,6 +15,9 @@ limitations under the License. // Legalize TensorFlow and TensorFlow Lite to TOSA +#include +#include + #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/Dialect/Tosa/IR/TosaOps.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project @@ -38,7 +41,7 @@ namespace { // Performs lowering to TOSA dialect class LegalizeTFTFL : public impl::TosaLegalizeTFTFLPassBase { public: - explicit LegalizeTFTFL() {} + explicit LegalizeTFTFL() = default; void runOnOperation() override; }; diff --git a/tensorflow/compiler/mlir/tosa/transforms/legalize_tfl.cc b/tensorflow/compiler/mlir/tosa/transforms/legalize_tfl.cc index 0162ddd4a8a..87573d30ed5 100644 --- a/tensorflow/compiler/mlir/tosa/transforms/legalize_tfl.cc +++ b/tensorflow/compiler/mlir/tosa/transforms/legalize_tfl.cc @@ -15,6 +15,7 @@ limitations under the License. // Legalize TensorFlow Lite to TOSA +#include #include #include #include @@ -23,6 +24,7 @@ limitations under the License. 
#include #include #include +#include #include #include #include @@ -681,7 +683,12 @@ static LogicalResult matchAndRewriteAddSub(Operation* op, Value output; if (output_is_qtype && input_lhs_is_qtype && input_rhs_is_qtype) { - ShapedType rescale_type = output_type.clone(rewriter.getI32Type()); + ShapedType rescale_type_output = output_type.clone(rewriter.getI32Type()); + ShapedType rescale_type_input_left = + input_lhs_type.clone(rewriter.getI32Type()); + ShapedType rescale_type_input_right = + input_rhs_type.clone(rewriter.getI32Type()); + UniformQuantizedType input_lhs_qtype = input_lhs_type.getElementType() .dyn_cast(); @@ -743,10 +750,11 @@ static LogicalResult matchAndRewriteAddSub(Operation* op, Value op1_none_half_scale_intermediate; if (output_qtype.getStorageTypeIntegralWidth() == 16) { auto tfl_add_lhs_casted = CreateOpAndInfer( - rewriter, op->getLoc(), rescale_type, tfl_add_op.getLhs()); + rewriter, op->getLoc(), rescale_type_input_left, + tfl_add_op.getLhs()); op1_none_half_scale_intermediate = CreateOpAndInfer( - rewriter, op->getLoc(), rescale_type, + rewriter, op->getLoc(), rescale_type_input_left, tfl_add_lhs_casted.getResult(), getTosaConstTensorSingleI32(rewriter, op, input_shift)); } else { @@ -773,10 +781,11 @@ static LogicalResult matchAndRewriteAddSub(Operation* op, Value op2_none_half_scale_intermediate; if (output_qtype.getStorageTypeIntegralWidth() == 16) { auto tfl_add_rhs_casted = CreateOpAndInfer( - rewriter, op->getLoc(), rescale_type, tfl_add_op.getRhs()); + rewriter, op->getLoc(), rescale_type_input_right, + tfl_add_op.getRhs()); op2_none_half_scale_intermediate = CreateOpAndInfer( - rewriter, op->getLoc(), rescale_type, + rewriter, op->getLoc(), rescale_type_input_right, tfl_add_rhs_casted.getResult(), getTosaConstTensorSingleI32(rewriter, op, input_shift)); } else { @@ -789,8 +798,9 @@ static LogicalResult matchAndRewriteAddSub(Operation* op, } #endif // TFLITE_DOUBLE_ROUNDING - auto op3_add_op1_op2 = CreateOpAndInfer( - rewriter, 
op->getLoc(), rescale_type, op1_rescale_lhs, op2_rescale_rhs); + auto op3_add_op1_op2 = + CreateOpAndInfer(rewriter, op->getLoc(), rescale_type_output, + op1_rescale_lhs, op2_rescale_rhs); Value op4_rescale_op3 = buildRescaleFromInt32( rewriter, op, output_type, op3_add_op1_op2.getResult(), output_rescale_scale, output_qtype.getZeroPoint()); @@ -1190,6 +1200,13 @@ LogicalResult ConvertTFLAveragePool2DOp::matchAndRewrite( auto average_etype = input_type.getElementType(); auto average_type = output_type.clone(average_etype); + // Tosa supports FP16 and FP32 accumulator type for FP16 input. When the time + // FP16 is supported, the accumulator type can be selected based on trade-off + // between performance and accuracy. Set to FP32 by default. + TypeAttr acc_attr = average_etype.isa() + ? mlir::TypeAttr::get(rewriter.getF32Type()) + : mlir::TypeAttr::get(rewriter.getIntegerType(32)); + Value result; if (average_etype.isa()) { // TensorFlow Lite doesn't use the zero point when calculating @@ -1200,11 +1217,11 @@ LogicalResult ConvertTFLAveragePool2DOp::matchAndRewrite( /*input_zp=*/0, /*output_zp=*/0); result = CreateOpAndInfer( rewriter, op->getLoc(), average_type, tfl_avgpool_op.getInput(), - kernel_size, stride, pad, quant_attr); + kernel_size, stride, pad, acc_attr, quant_attr); } else { result = CreateOpAndInfer( rewriter, op->getLoc(), average_type, tfl_avgpool_op.getInput(), - kernel_size, stride, pad); + kernel_size, stride, pad, acc_attr); } if (average_type != output_type) { result = CreateOpAndInfer(rewriter, op->getLoc(), output_type, diff --git a/tensorflow/compiler/mlir/tosa/transforms/legalize_utils.cc b/tensorflow/compiler/mlir/tosa/transforms/legalize_utils.cc index ff8616687a2..29f913edb87 100644 --- a/tensorflow/compiler/mlir/tosa/transforms/legalize_utils.cc +++ b/tensorflow/compiler/mlir/tosa/transforms/legalize_utils.cc @@ -15,6 +15,9 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/tosa/transforms/legalize_utils.h" +#include +#include +#include #include #include "llvm/ADT/SmallVector.h" diff --git a/tensorflow/compiler/mlir/tosa/transforms/legalize_utils.h b/tensorflow/compiler/mlir/tosa/transforms/legalize_utils.h index b2e76197fb5..07a781c6240 100644 --- a/tensorflow/compiler/mlir/tosa/transforms/legalize_utils.h +++ b/tensorflow/compiler/mlir/tosa/transforms/legalize_utils.h @@ -20,6 +20,7 @@ limitations under the License. #include #include #include +#include #include #include #include diff --git a/tensorflow/compiler/mlir/tosa/transforms/lower_complex_types.cc b/tensorflow/compiler/mlir/tosa/transforms/lower_complex_types.cc index da989275d52..0180574f1d0 100644 --- a/tensorflow/compiler/mlir/tosa/transforms/lower_complex_types.cc +++ b/tensorflow/compiler/mlir/tosa/transforms/lower_complex_types.cc @@ -31,6 +31,8 @@ limitations under the License. // resulting graph is free of illegal complex tensors. #include +#include +#include #include "mlir/Dialect/Tosa/IR/TosaOps.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project @@ -55,7 +57,7 @@ namespace { class LowerComplexTypes : public impl::TosaLowerComplexTypesPassBase { public: - explicit LowerComplexTypes() {} + explicit LowerComplexTypes() = default; void runOnOperation() override; }; diff --git a/tensorflow/compiler/mlir/tosa/transforms/strip_quant_types.cc b/tensorflow/compiler/mlir/tosa/transforms/strip_quant_types.cc index c55908f31fc..b0ce9d0d80f 100644 --- a/tensorflow/compiler/mlir/tosa/transforms/strip_quant_types.cc +++ b/tensorflow/compiler/mlir/tosa/transforms/strip_quant_types.cc @@ -27,7 +27,9 @@ limitations under the License. 
#include #include #include +#include #include +#include #include "mlir/Dialect/Tosa/IR/TosaOps.h" // from @llvm-project #include "mlir/Dialect/Tosa/Utils/QuantUtils.h" // from @llvm-project @@ -56,7 +58,7 @@ namespace { class StripQuantTypes : public impl::TosaStripQuantTypesPassBase { public: - explicit StripQuantTypes() {} + explicit StripQuantTypes() = default; void runOnOperation() override; }; diff --git a/tensorflow/compiler/mlir/utils/name_utils.cc b/tensorflow/compiler/mlir/utils/name_utils.cc index d966d887b1f..ba50a923bc4 100644 --- a/tensorflow/compiler/mlir/utils/name_utils.cc +++ b/tensorflow/compiler/mlir/utils/name_utils.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/utils/name_utils.h" #include +#include #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index 1fb31b9db74..d3ea4b077a9 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -1,16 +1,17 @@ +load("//tensorflow/compiler/tests:build_defs.bzl", "tf_xla_py_strict_test") +load("//tensorflow:strict.default.bzl", "py_strict_library", "py_strict_test") load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") -load("//tensorflow:tensorflow.default.bzl", "cuda_py_test", "tf_cuda_cc_test") +load("//tensorflow:tensorflow.default.bzl", "cuda_py_strict_test", "tf_cuda_cc_test") load("//tensorflow/compiler/aot:tfcompile.bzl", "tf_library") load( "//tensorflow/compiler/tests:build_defs.bzl", "generate_backend_suites", - "tf_xla_py_test", ) load( "//tensorflow/core/platform:build_config_root.bzl", "tf_cuda_tests_tags", ) -load("//tensorflow/python/tpu:tpu.bzl", "tpu_py_test") +load("//tensorflow/python/tpu:tpu.bzl", "tpu_py_strict_test") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -38,7 +39,7 @@ package_group( generate_backend_suites() -py_library( +py_strict_library( name = "xla_test", 
testonly = 1, srcs = ["xla_test.py"], @@ -46,21 +47,24 @@ py_library( visibility = [":friends"], deps = [ "//tensorflow/core:protos_all_py", - "//tensorflow/python:array_ops", - "//tensorflow/python:client", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework", - "//tensorflow/python:random_seed", - "//tensorflow/python:session", - "//tensorflow/python:variables", + "//tensorflow/python/client:session", "//tensorflow/python/compiler/xla:compiler_py", + "//tensorflow/python/eager:context", + "//tensorflow/python/framework:dtypes", + "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:random_seed", + "//tensorflow/python/framework:test_lib", + "//tensorflow/python/ops:array_ops", + "//tensorflow/python/ops:variables", + "//tensorflow/python/platform:client_testlib", "//tensorflow/python/platform:flags", "//tensorflow/python/platform:tf_logging", + "//tensorflow/python/tpu:tpu_py", "//third_party/py/numpy", ], ) -py_library( +py_strict_library( name = "test_utils", testonly = 1, srcs = [ @@ -74,7 +78,7 @@ py_library( ], ) -py_test( +py_strict_test( name = "xla_test_test", size = "small", srcs = ["xla_test_test.py"], @@ -88,7 +92,7 @@ py_test( ], ) -tf_xla_py_test( +tf_xla_py_strict_test( name = "adadelta_test", size = "medium", srcs = ["adadelta_test.py"], @@ -99,15 +103,16 @@ tf_xla_py_test( ], deps = [ ":xla_test", - "//tensorflow/python:array_ops", - "//tensorflow/python:framework", - "//tensorflow/python:platform_test", - "//tensorflow/python:training", + "//tensorflow/python/framework:constant_op", + "//tensorflow/python/ops:resource_variable_ops", + "//tensorflow/python/ops:variables", "//tensorflow/python/platform:client_testlib", + "//tensorflow/python/training", + "//third_party/py/numpy", ], ) -tf_xla_py_test( +tf_xla_py_strict_test( name = "adagrad_test", size = "small", srcs = ["adagrad_test.py"], @@ -118,16 +123,16 @@ tf_xla_py_test( ], deps = [ ":xla_test", - "//tensorflow/python:array_ops", - 
"//tensorflow/python:framework", - "//tensorflow/python:math_ops", - "//tensorflow/python:platform_test", - "//tensorflow/python:training", + "//tensorflow/python/framework:constant_op", + "//tensorflow/python/ops:resource_variable_ops", + "//tensorflow/python/ops:variables", "//tensorflow/python/platform:client_testlib", + "//tensorflow/python/training", + "//third_party/py/numpy", ], ) -tf_xla_py_test( +tf_xla_py_strict_test( name = "adagrad_da_test", size = "small", srcs = ["adagrad_da_test.py"], @@ -139,15 +144,17 @@ tf_xla_py_test( ], deps = [ ":xla_test", - "//tensorflow/python:array_ops", - "//tensorflow/python:framework", - "//tensorflow/python:platform_test", - "//tensorflow/python:training", + "//tensorflow/python/framework:constant_op", + "//tensorflow/python/framework:dtypes", + "//tensorflow/python/ops:resource_variable_ops", + "//tensorflow/python/ops:variables", "//tensorflow/python/platform:client_testlib", + "//tensorflow/python/training", + "//third_party/py/numpy", ], ) -tf_xla_py_test( +tf_xla_py_strict_test( name = "adam_test", size = "small", srcs = ["adam_test.py"], @@ -159,16 +166,19 @@ tf_xla_py_test( ], deps = [ ":xla_test", - "//tensorflow/python:array_ops", - "//tensorflow/python:framework", - "//tensorflow/python:math_ops", - "//tensorflow/python:platform_test", - "//tensorflow/python:training", + "//tensorflow/python/framework:constant_op", + "//tensorflow/python/framework:dtypes", + "//tensorflow/python/ops:array_ops", + "//tensorflow/python/ops:resource_variable_ops", + "//tensorflow/python/ops:variable_scope", + "//tensorflow/python/ops:variables", "//tensorflow/python/platform:client_testlib", + "//tensorflow/python/training", + "//third_party/py/numpy", ], ) -tf_xla_py_test( +tf_xla_py_strict_test( name = "add_n_test", size = "small", srcs = ["add_n_test.py"], @@ -181,16 +191,16 @@ tf_xla_py_test( ], deps = [ ":xla_test", - "//tensorflow/python:array_ops", - "//tensorflow/python:framework", - "//tensorflow/python:list_ops", - 
"//tensorflow/python:math_ops", - "//tensorflow/python:platform_test", + "//tensorflow/python/framework:dtypes", + "//tensorflow/python/framework:errors", + "//tensorflow/python/ops:array_ops", + "//tensorflow/python/ops:list_ops", + "//tensorflow/python/ops:math_ops", "//tensorflow/python/platform:client_testlib", ], ) -tf_xla_py_test( +tf_xla_py_strict_test( name = "argminmax_test", size = "small", srcs = ["argminmax_test.py"], @@ -202,15 +212,15 @@ tf_xla_py_test( ], deps = [ ":xla_test", - "//tensorflow/python:array_ops", - "//tensorflow/python:framework", - "//tensorflow/python:math_ops", - "//tensorflow/python:platform_test", + "//tensorflow/python/framework:dtypes", + "//tensorflow/python/ops:array_ops", + "//tensorflow/python/ops:math_ops", "//tensorflow/python/platform:client_testlib", + "//third_party/py/numpy", ], ) -tf_xla_py_test( +tf_xla_py_strict_test( name = "binary_ops_test", size = "medium", srcs = ["binary_ops_test.py"], @@ -224,19 +234,22 @@ tf_xla_py_test( ], deps = [ ":xla_test", - "//tensorflow/python:array_ops", - "//tensorflow/python:bitwise_ops", - "//tensorflow/python:framework", - "//tensorflow/python:math_ops", - "//tensorflow/python:math_ops_gen", - "//tensorflow/python:nn_ops", - "//tensorflow/python:nn_ops_gen", "//tensorflow/python:platform_test", + "//tensorflow/python/framework:dtypes", + "//tensorflow/python/framework:test_lib", + "//tensorflow/python/ops:array_ops", + "//tensorflow/python/ops:array_ops_gen", + "//tensorflow/python/ops:bitwise_ops", + "//tensorflow/python/ops:math_ops", + "//tensorflow/python/ops:math_ops_gen", + "//tensorflow/python/ops:nn_ops", + "//tensorflow/python/ops:nn_ops_gen", "//tensorflow/python/platform:client_testlib", + "//third_party/py/numpy", ], ) -tf_xla_py_test( +tf_xla_py_strict_test( name = "complex_div_test", size = "medium", srcs = ["complex_div_test.py"], @@ -254,16 +267,15 @@ tf_xla_py_test( ], deps = [ ":xla_test", - "//tensorflow/python:framework", - "//tensorflow/python:math_ops", - 
"//tensorflow/python:math_ops_gen", - "//tensorflow/python:nn_ops", - "//tensorflow/python:nn_ops_gen", "//tensorflow/python:platform_test", + "//tensorflow/python/framework:dtypes", + "//tensorflow/python/ops:array_ops", + "//tensorflow/python/ops:math_ops_gen", + "//third_party/py/numpy", ], ) -tf_xla_py_test( +tf_xla_py_strict_test( name = "bucketize_op_test", size = "small", srcs = ["bucketize_op_test.py"], @@ -274,15 +286,16 @@ tf_xla_py_test( ], deps = [ ":xla_test", - "//tensorflow/python:array_ops", - "//tensorflow/python:framework", - "//tensorflow/python:math_ops", - "//tensorflow/python:platform_test", + "//tensorflow/python/framework:dtypes", + "//tensorflow/python/framework:errors", + "//tensorflow/python/framework:test_lib", + "//tensorflow/python/ops:array_ops", + "//tensorflow/python/ops:math_ops", "//tensorflow/python/platform:client_testlib", ], ) -tf_xla_py_test( +tf_xla_py_strict_test( name = "categorical_op_test", size = "small", srcs = ["categorical_op_test.py"], @@ -294,15 +307,18 @@ tf_xla_py_test( ], deps = [ ":xla_test", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework", "//tensorflow/python:platform_test", - "//tensorflow/python:random_ops", - "//tensorflow/python:standard_ops", + "//tensorflow/python/framework:dtypes", + "//tensorflow/python/framework:random_seed", + "//tensorflow/python/ops:array_ops", + "//tensorflow/python/ops:random_ops", + "//tensorflow/python/ops:stateless_random_ops", + "//tensorflow/python/platform:client_testlib", + "//third_party/py/numpy", ], ) -tf_xla_py_test( +tf_xla_py_strict_test( name = "cholesky_op_test", size = "medium", srcs = ["cholesky_op_test.py"], @@ -315,17 +331,18 @@ tf_xla_py_test( ], deps = [ ":xla_test", - "//tensorflow/python:array_ops", - "//tensorflow/python:framework", - "//tensorflow/python:map_fn", - "//tensorflow/python:math_ops", - "//tensorflow/python:platform_test", - "//tensorflow/python:training", + "//tensorflow/python/framework:constant_op", + 
"//tensorflow/python/framework:dtypes", + "//tensorflow/python/framework:errors", + "//tensorflow/python/framework:test_lib", + "//tensorflow/python/ops:array_ops", + "//tensorflow/python/ops:linalg_ops", "//tensorflow/python/platform:client_testlib", + "//third_party/py/numpy", ], ) -tf_xla_py_test( +tf_xla_py_strict_test( name = "cond_test", size = "small", srcs = ["cond_test.py"], @@ -336,20 +353,26 @@ tf_xla_py_test( ], deps = [ ":xla_test", - "//tensorflow/compiler/tf2xla/python:xla", - "//tensorflow/python:cond", - "//tensorflow/python:control_flow_ops", - "//tensorflow/python:control_flow_switch_case", - "//tensorflow/python:framework", - "//tensorflow/python:platform_test", - "//tensorflow/python:tensor_array_ops", - "//tensorflow/python:training", - "//tensorflow/python/eager:function", + "//tensorflow/python/client:session", + "//tensorflow/python/compiler/xla", + "//tensorflow/python/eager:def_function", + "//tensorflow/python/framework:constant_op", + "//tensorflow/python/framework:dtypes", + "//tensorflow/python/framework:errors", + "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:test_lib", + "//tensorflow/python/ops:array_ops", + "//tensorflow/python/ops:cond", + "//tensorflow/python/ops:control_flow_ops", + "//tensorflow/python/ops:control_flow_switch_case", + "//tensorflow/python/ops:math_ops", + "//tensorflow/python/ops:random_ops", + "//tensorflow/python/ops:tensor_array_ops", "//tensorflow/python/platform:client_testlib", ], ) -tf_xla_py_test( +tf_xla_py_strict_test( name = "self_adjoint_eig_op_test", size = "medium", srcs = ["self_adjoint_eig_op_test.py"], @@ -361,18 +384,15 @@ tf_xla_py_test( ], deps = [ ":xla_test", - "//tensorflow/python:array_ops", - "//tensorflow/python:framework", - "//tensorflow/python:map_fn", - "//tensorflow/python:math_ops", - "//tensorflow/python:platform_test", - "//tensorflow/python:training", + "//tensorflow/python/ops:array_ops", + "//tensorflow/python/ops:linalg_ops", 
"//tensorflow/python/platform:client_testlib", + "//third_party/py/numpy", "@absl_py//absl/testing:parameterized", ], ) -tf_xla_py_test( +tf_xla_py_strict_test( name = "searchsorted_op_test", size = "small", timeout = "moderate", @@ -384,12 +404,13 @@ tf_xla_py_test( ], deps = [ ":xla_test", - "//tensorflow/python:platform_test", + "//tensorflow/python/ops:array_ops", "//tensorflow/python/platform:client_testlib", + "//third_party/py/numpy", ], ) -tf_xla_py_test( +tf_xla_py_strict_test( name = "svd_op_test", size = "medium", srcs = ["svd_op_test.py"], @@ -406,18 +427,17 @@ tf_xla_py_test( ], deps = [ ":xla_test", - "//tensorflow/python:array_ops", - "//tensorflow/python:framework", - "//tensorflow/python:map_fn", - "//tensorflow/python:math_ops", - "//tensorflow/python:platform_test", - "//tensorflow/python:training", + "//tensorflow/python/framework:tensor_shape", + "//tensorflow/python/ops:array_ops", + "//tensorflow/python/ops:linalg_ops", + "//tensorflow/python/ops:linalg_ops_gen", "//tensorflow/python/platform:client_testlib", + "//third_party/py/numpy", "@absl_py//absl/testing:parameterized", ], ) -tf_xla_py_test( +tf_xla_py_strict_test( name = "matrix_inverse_op_test", size = "small", timeout = "moderate", @@ -429,15 +449,16 @@ tf_xla_py_test( ], deps = [ ":xla_test", - "//tensorflow/python:array_ops", - "//tensorflow/python:framework", - "//tensorflow/python:linalg_ops", - "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", + "//tensorflow/python/framework:dtypes", + "//tensorflow/python/ops:array_ops", + "//tensorflow/python/ops:linalg_ops", + "//tensorflow/python/ops:math_ops", + "//third_party/py/numpy", ], ) -tf_xla_py_test( +tf_xla_py_strict_test( name = "matrix_solve_op_test", size = "small", timeout = "moderate", @@ -449,14 +470,15 @@ tf_xla_py_test( ], deps = [ ":xla_test", - "//tensorflow/python:linalg_ops", "//tensorflow/python:platform_test", - "//tensorflow/python:random_ops", + "//tensorflow/python/ops:linalg_ops", + 
"//tensorflow/python/ops:random_ops", + "//third_party/py/numpy", "@absl_py//absl/testing:parameterized", ], ) -tf_xla_py_test( +tf_xla_py_strict_test( name = "matrix_triangular_solve_op_test", size = "small", timeout = "moderate", @@ -469,16 +491,18 @@ tf_xla_py_test( ], deps = [ ":xla_test", - "//tensorflow/python:array_ops", - "//tensorflow/python:framework", - "//tensorflow/python:math_ops", - "//tensorflow/python:platform_test", - "//tensorflow/python:training", + "//tensorflow/python/framework:constant_op", + "//tensorflow/python/framework:dtypes", + "//tensorflow/python/framework:errors", + "//tensorflow/python/framework:test_lib", + "//tensorflow/python/ops:array_ops", + "//tensorflow/python/ops:linalg_ops", "//tensorflow/python/platform:client_testlib", + "//third_party/py/numpy", ], ) -tf_xla_py_test( +tf_xla_py_strict_test( name = "clustering_test", size = "small", srcs = ["clustering_test.py"], @@ -489,14 +513,17 @@ tf_xla_py_test( ], deps = [ ":xla_test", - "//tensorflow/python:array_ops", - "//tensorflow/python:framework", - "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", + "//tensorflow/python/framework:constant_op", + "//tensorflow/python/framework:dtypes", + "//tensorflow/python/framework:ops", + "//tensorflow/python/ops:array_ops", + "//tensorflow/python/ops:math_ops", + "//third_party/py/numpy", ], ) -tf_xla_py_test( +tf_xla_py_strict_test( name = "concat_ops_test", size = "medium", srcs = ["concat_ops_test.py"], @@ -507,17 +534,19 @@ tf_xla_py_test( ], deps = [ ":xla_test", - "//tensorflow/python:array_ops", - "//tensorflow/python:array_ops_gen", - "//tensorflow/python:framework", - "//tensorflow/python:gradient_checker", - "//tensorflow/python:gradients", - "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", + "//tensorflow/python/framework:constant_op", + "//tensorflow/python/framework:dtypes", + "//tensorflow/python/ops:array_ops", + "//tensorflow/python/ops:array_ops_gen", + 
"//tensorflow/python/ops:array_ops_stack", + "//tensorflow/python/ops:gradients_impl", + "//tensorflow/python/ops:math_ops", + "//third_party/py/numpy", ], ) -tf_xla_py_test( +tf_xla_py_strict_test( name = "conv2d_test", size = "medium", srcs = ["conv2d_test.py"], @@ -530,17 +559,38 @@ tf_xla_py_test( deps = [ ":test_utils", ":xla_test", - "//tensorflow/python:array_ops", - "//tensorflow/python:framework", - "//tensorflow/python:nn", - "//tensorflow/python:nn_ops", - "//tensorflow/python:nn_ops_gen", "//tensorflow/python:platform_test", + "//tensorflow/python/framework:dtypes", + "//tensorflow/python/ops:array_ops", + "//tensorflow/python/ops:nn_ops", + "//tensorflow/python/ops:nn_ops_gen", + "//third_party/py/numpy", "@absl_py//absl/testing:parameterized", ], ) -tf_xla_py_test( +tf_xla_py_strict_test( + name = "tensor_float_32_test", + size = "medium", + srcs = ["tensor_float_32_test.py"], + python_version = "PY3", + tags = [ + "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip + ], + use_xla_device = False, # Uses tf.function(jit_compile=True) + deps = [ + ":xla_test", + "//tensorflow/python:platform_test", + "//tensorflow/python/eager:def_function", + "//tensorflow/python/framework:config", + "//tensorflow/python/framework:ops", + "//tensorflow/python/ops:array_ops", + "//tensorflow/python/ops:math_ops", + "//tensorflow/python/ops:nn_ops", + ], +) + +tf_xla_py_strict_test( name = "conv3d_test", size = "medium", srcs = ["conv3d_test.py"], @@ -552,16 +602,18 @@ tf_xla_py_test( ], deps = [ ":xla_test", - "//tensorflow/python:array_ops", - "//tensorflow/python:framework", - "//tensorflow/python:nn", - "//tensorflow/python:nn_ops", - "//tensorflow/python:nn_ops_gen", "//tensorflow/python:platform_test", + "//tensorflow/python/framework:constant_op", + "//tensorflow/python/framework:dtypes", + "//tensorflow/python/ops:array_ops", + "//tensorflow/python/ops:gradient_checker", + "//tensorflow/python/ops:nn_grad", + 
"//tensorflow/python/ops:nn_ops", + "//third_party/py/numpy", ], ) -tf_xla_py_test( +tf_xla_py_strict_test( name = "depthwise_conv_op_test", size = "medium", srcs = ["depthwise_conv_op_test.py"], @@ -575,17 +627,19 @@ tf_xla_py_test( ], deps = [ ":xla_test", - "//tensorflow/python:array_ops", - "//tensorflow/python:framework", "//tensorflow/python:nn", - "//tensorflow/python:nn_ops", - "//tensorflow/python:nn_ops_gen", - "//tensorflow/python:platform_test", + "//tensorflow/python/framework:constant_op", + "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:test_lib", + "//tensorflow/python/ops:array_ops", + "//tensorflow/python/ops:nn_grad", + "//tensorflow/python/ops:nn_ops", "//tensorflow/python/platform:client_testlib", + "//third_party/py/numpy", ], ) -tf_xla_py_test( +tf_xla_py_strict_test( name = "dynamic_slice_ops_test", size = "small", srcs = ["dynamic_slice_ops_test.py"], @@ -595,15 +649,16 @@ tf_xla_py_test( "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip ], deps = [ - "//tensorflow/compiler/tests:xla_test", + ":xla_test", "//tensorflow/compiler/tf2xla/python:xla", - "//tensorflow/python:array_ops", - "//tensorflow/python:dtypes", + "//tensorflow/python/framework:dtypes", + "//tensorflow/python/ops:array_ops", "//tensorflow/python/platform:client_testlib", + "//third_party/py/numpy", ], ) -tf_xla_py_test( +tf_xla_py_strict_test( name = "einsum_op_test", size = "medium", srcs = ["einsum_op_test.py"], @@ -619,14 +674,15 @@ tf_xla_py_test( ], deps = [ ":xla_test", - "//tensorflow/python:array_ops", - "//tensorflow/python:framework", "//tensorflow/python:platform_test", - "//tensorflow/python:special_math_ops", + "//tensorflow/python/framework:dtypes", + "//tensorflow/python/ops:array_ops", + "//tensorflow/python/ops:special_math_ops", + "//third_party/py/numpy", ], ) -tf_xla_py_test( +tf_xla_py_strict_test( name = "reshape_op_test", size = "small", srcs = ["reshape_op_test.py"], @@ -636,16 +692,16 @@ 
tf_xla_py_test( "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip ], deps = [ - "//tensorflow/compiler/tests:xla_test", - "//tensorflow/compiler/tf2xla/python:xla", - "//tensorflow/python:array_ops", - "//tensorflow/python:dtypes", + ":xla_test", + "//tensorflow/python/framework:constant_op", + "//tensorflow/python/framework:dtypes", + "//tensorflow/python/ops:array_ops", "//tensorflow/python/platform:test", "@absl_py//absl/testing:parameterized", ], ) -tf_xla_py_test( +tf_xla_py_strict_test( name = "dynamic_stitch_test", size = "small", srcs = ["dynamic_stitch_test.py"], @@ -655,14 +711,15 @@ tf_xla_py_test( ], deps = [ ":xla_test", - "//tensorflow/python:array_ops", - "//tensorflow/python:data_flow_ops", - "//tensorflow/python:framework", "//tensorflow/python:platform_test", + "//tensorflow/python/framework:dtypes", + "//tensorflow/python/ops:array_ops", + "//tensorflow/python/ops:data_flow_ops", + "//third_party/py/numpy", ], ) -tf_xla_py_test( +tf_xla_py_strict_test( name = "extract_image_patches_op_test", size = "small", srcs = ["extract_image_patches_op_test.py"], @@ -672,14 +729,14 @@ tf_xla_py_test( ], deps = [ ":xla_test", - "//tensorflow/python:array_ops", - "//tensorflow/python:framework", - "//tensorflow/python:platform_test", + "//tensorflow/python/framework:dtypes", + "//tensorflow/python/ops:array_ops", "//tensorflow/python/platform:client_testlib", + "//third_party/py/numpy", ], ) -tf_xla_py_test( +tf_xla_py_strict_test( name = "eager_test", size = "medium", srcs = ["eager_test.py"], @@ -691,19 +748,33 @@ tf_xla_py_test( ], deps = [ ":xla_test", - "//tensorflow/python:array_ops", - "//tensorflow/python:cond", - "//tensorflow/python:framework", - "//tensorflow/python:layers", - "//tensorflow/python:math_ops", - "//tensorflow/python:nn", + "//tensorflow/core:protos_all_py", "//tensorflow/python:platform_test", - "//tensorflow/python:while_loop", - "//tensorflow/python/eager:function", + 
"//tensorflow/python/eager:backprop", + "//tensorflow/python/eager:context", + "//tensorflow/python/eager:def_function", + "//tensorflow/python/framework:constant_op", + "//tensorflow/python/framework:dtypes", + "//tensorflow/python/framework:indexed_slices", + "//tensorflow/python/framework:ops", + "//tensorflow/python/layers", + "//tensorflow/python/ops:array_ops", + "//tensorflow/python/ops:array_ops_stack", + "//tensorflow/python/ops:cond", + "//tensorflow/python/ops:embedding_ops", + "//tensorflow/python/ops:functional_ops", + "//tensorflow/python/ops:init_ops", + "//tensorflow/python/ops:math_ops", + "//tensorflow/python/ops:nn_ops", + "//tensorflow/python/ops:random_ops_gen", + "//tensorflow/python/ops:resource_variable_ops", + "//tensorflow/python/ops:while_loop", + "//tensorflow/python/training:adam", + "//third_party/py/numpy", ], ) -tf_xla_py_test( +tf_xla_py_strict_test( name = "fifo_queue_test", size = "medium", srcs = ["fifo_queue_test.py"], @@ -714,16 +785,14 @@ tf_xla_py_test( ], deps = [ ":xla_test", - "//tensorflow/python:array_ops", - "//tensorflow/python:data_flow_ops", "//tensorflow/python:extra_py_tests_deps", - "//tensorflow/python:framework", - "//tensorflow/python:platform_test", + "//tensorflow/python/framework:dtypes", + "//tensorflow/python/ops:data_flow_ops", "//tensorflow/python/platform:client_testlib", ], ) -tf_xla_py_test( +tf_xla_py_strict_test( name = "fft_test", size = "medium", srcs = ["fft_test.py"], @@ -737,15 +806,17 @@ tf_xla_py_test( ], deps = [ ":xla_test", - "//tensorflow/python:array_ops", "//tensorflow/python:extra_py_tests_deps", - "//tensorflow/python:framework", "//tensorflow/python:platform_test", + "//tensorflow/python/framework:dtypes", + "//tensorflow/python/ops:array_ops", + "//tensorflow/python/ops:gradients_impl", "//tensorflow/python/ops/signal", + "//third_party/py/numpy", ], ) -tf_xla_py_test( +tf_xla_py_strict_test( name = "slice_ops_test", size = "medium", srcs = ["slice_ops_test.py"], @@ -757,14 +828,15 
@@ tf_xla_py_test( ], deps = [ ":xla_test", - "//tensorflow/python:array_ops", - "//tensorflow/python:data_flow_ops", - "//tensorflow/python:framework", "//tensorflow/python:platform_test", + "//tensorflow/python/framework:dtypes", + "//tensorflow/python/framework:tensor_shape", + "//tensorflow/python/ops:array_ops", + "//tensorflow/python/ops:math_ops", ], ) -tf_xla_py_test( +tf_xla_py_strict_test( name = "ftrl_test", size = "medium", srcs = ["ftrl_test.py"], @@ -776,16 +848,16 @@ tf_xla_py_test( ], deps = [ ":xla_test", - "//tensorflow/python:array_ops", - "//tensorflow/python:framework", - "//tensorflow/python:math_ops", - "//tensorflow/python:platform_test", - "//tensorflow/python:training", + "//tensorflow/python/framework:constant_op", + "//tensorflow/python/ops:resource_variable_ops", + "//tensorflow/python/ops:variables", "//tensorflow/python/platform:client_testlib", + "//tensorflow/python/training", + "//third_party/py/numpy", ], ) -tf_xla_py_test( +tf_xla_py_strict_test( name = "ftrl_ops_test", size = "medium", srcs = ["ftrl_ops_test.py"], @@ -796,17 +868,16 @@ tf_xla_py_test( ], deps = [ ":xla_test", - "//tensorflow/python:array_ops", - "//tensorflow/python:framework", - "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", - "//tensorflow/python:training", - "//tensorflow/python:variables", - "//tensorflow/python/tpu", + "//tensorflow/python/framework:constant_op", + "//tensorflow/python/framework:test_lib", + "//tensorflow/python/ops:resource_variable_ops", + "//tensorflow/python/training", + "//third_party/py/numpy", ], ) -tf_xla_py_test( +tf_xla_py_strict_test( name = "function_test", size = "small", srcs = ["function_test.py"], @@ -817,13 +888,16 @@ tf_xla_py_test( ], deps = [ ":xla_test", - "//tensorflow/python:array_ops", - "//tensorflow/python:framework", "//tensorflow/python:platform_test", + "//tensorflow/python/framework:constant_op", + "//tensorflow/python/framework:dtypes", + "//tensorflow/python/framework:function", + 
"//tensorflow/python/ops:array_ops", + "//third_party/py/numpy", ], ) -tf_xla_py_test( +tf_xla_py_strict_test( name = "image_ops_test", size = "small", timeout = "long", @@ -837,20 +911,25 @@ tf_xla_py_test( python_version = "PY3", shard_count = 10, tags = [ + "no_oss", # TODO(b/282033702): Re-enable this test in OSS. "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip "optonly", # Times out frequently in fastbuild mode. ], deps = [ ":xla_test", - "//tensorflow/python:array_ops", - "//tensorflow/python:framework", "//tensorflow/python:image_ops", - "//tensorflow/python:platform_test", + "//tensorflow/python/framework:dtypes", + "//tensorflow/python/framework:ops", + "//tensorflow/python/ops:array_ops", + "//tensorflow/python/ops:array_ops_stack", + "//tensorflow/python/ops:image_ops_gen", "//tensorflow/python/platform:client_testlib", + "//third_party/py/numpy", + "@absl_py//absl/testing:parameterized", ], ) -tf_xla_py_test( +tf_xla_py_strict_test( name = "listdiff_op_test", size = "small", srcs = ["listdiff_op_test.py"], @@ -862,16 +941,14 @@ tf_xla_py_test( ], deps = [ ":xla_test", - "//tensorflow/python:array_ops", - "//tensorflow/python:data_flow_ops", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:framework_ops", - "//tensorflow/python:platform_test", + "//tensorflow/python/framework:for_generated_wrappers", + "//tensorflow/python/ops:array_ops", "//tensorflow/python/platform:client_testlib", + "//third_party/py/numpy", ], ) -tf_xla_py_test( +tf_xla_py_strict_test( name = "lrn_ops_test", size = "medium", srcs = ["lrn_ops_test.py"], @@ -882,15 +959,18 @@ tf_xla_py_test( ], deps = [ ":xla_test", - "//tensorflow/python:array_ops", - "//tensorflow/python:framework", "//tensorflow/python:nn", - "//tensorflow/python:nn_ops_gen", "//tensorflow/python:platform_test", + "//tensorflow/python/framework:constant_op", + "//tensorflow/python/framework:dtypes", + "//tensorflow/python/framework:ops", + 
"//tensorflow/python/ops:array_ops", + "//tensorflow/python/ops:nn_ops_gen", + "//third_party/py/numpy", ], ) -tf_xla_py_test( +tf_xla_py_strict_test( name = "manip_ops_test", size = "small", srcs = ["manip_ops_test.py"], @@ -901,14 +981,15 @@ tf_xla_py_test( ], deps = [ ":xla_test", - "//tensorflow/python:array_ops", - "//tensorflow/python:framework", - "//tensorflow/python:manip_ops", "//tensorflow/python:platform_test", + "//tensorflow/python/framework:dtypes", + "//tensorflow/python/ops:array_ops", + "//tensorflow/python/ops:manip_ops", + "//third_party/py/numpy", ], ) -tf_xla_py_test( +tf_xla_py_strict_test( name = "matrix_band_part_test", size = "medium", timeout = "long", @@ -921,15 +1002,16 @@ tf_xla_py_test( ], deps = [ ":xla_test", - "//tensorflow/python:array_ops", - "//tensorflow/python:framework", - "//tensorflow/python:platform_test", + "//tensorflow/python/framework:constant_op", + "//tensorflow/python/framework:dtypes", + "//tensorflow/python/ops:array_ops", "//tensorflow/python/platform:client_testlib", + "//third_party/py/numpy", "@absl_py//absl/testing:parameterized", ], ) -tf_xla_py_test( +tf_xla_py_strict_test( name = "matrix_diag_ops_test", size = "medium", timeout = "long", @@ -942,13 +1024,14 @@ tf_xla_py_test( ], deps = [ ":xla_test", - "//tensorflow/python:array_ops", - "//tensorflow/python:dtypes", "//tensorflow/python:platform_test", + "//tensorflow/python/ops:array_ops", + "//tensorflow/python/ops:array_ops_gen", + "//third_party/py/numpy", ], ) -tf_xla_py_test( +tf_xla_py_strict_test( name = "momentum_test", size = "small", srcs = ["momentum_test.py"], @@ -959,16 +1042,18 @@ tf_xla_py_test( ], deps = [ ":xla_test", - "//tensorflow/python:array_ops", - "//tensorflow/python:framework", - "//tensorflow/python:math_ops", - "//tensorflow/python:platform_test", - "//tensorflow/python:training", + "//tensorflow/python/framework:constant_op", + "//tensorflow/python/framework:dtypes", + "//tensorflow/python/ops:array_ops", + 
"//tensorflow/python/ops:resource_variable_ops", + "//tensorflow/python/ops:variables", "//tensorflow/python/platform:client_testlib", + "//tensorflow/python/training", + "//third_party/py/numpy", ], ) -tf_xla_py_test( +tf_xla_py_strict_test( name = "nary_ops_test", size = "small", srcs = ["nary_ops_test.py"], @@ -979,14 +1064,16 @@ tf_xla_py_test( ], deps = [ ":xla_test", - "//tensorflow/python:array_ops", - "//tensorflow/python:framework", - "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", + "//tensorflow/python/framework:dtypes", + "//tensorflow/python/framework:errors", + "//tensorflow/python/ops:array_ops", + "//tensorflow/python/ops:math_ops", + "//third_party/py/numpy", ], ) -tf_xla_py_test( +tf_xla_py_strict_test( name = "nullary_ops_test", size = "small", srcs = ["nullary_ops_test.py"], @@ -997,13 +1084,14 @@ tf_xla_py_test( ], deps = [ ":xla_test", - "//tensorflow/python:control_flow_ops", - "//tensorflow/python:framework", "//tensorflow/python:platform_test", + "//tensorflow/python/framework:constant_op", + "//tensorflow/python/ops:control_flow_ops", + "//third_party/py/numpy", ], ) -tf_xla_py_test( +tf_xla_py_strict_test( name = "pooling_ops_test", size = "medium", srcs = ["pooling_ops_test.py"], @@ -1015,15 +1103,19 @@ tf_xla_py_test( ], deps = [ ":xla_test", - "//tensorflow/python:array_ops", - "//tensorflow/python:framework", - "//tensorflow/python:nn_ops", - "//tensorflow/python:nn_ops_gen", "//tensorflow/python:platform_test", + "//tensorflow/python/framework:dtypes", + "//tensorflow/python/framework:errors", + "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:test_lib", + "//tensorflow/python/ops:array_ops", + "//tensorflow/python/ops:nn_ops", + "//tensorflow/python/ops:nn_ops_gen", + "//third_party/py/numpy", ], ) -tf_xla_py_test( +tf_xla_py_strict_test( name = "pooling_ops_3d_test", size = "medium", srcs = ["pooling_ops_3d_test.py"], @@ -1035,16 +1127,17 @@ tf_xla_py_test( ], deps = [ ":xla_test", - 
"//tensorflow/python:array_ops", - "//tensorflow/python:framework", - "//tensorflow/python:nn_ops", - "//tensorflow/python:nn_ops_gen", - "//tensorflow/python:platform_test", + "//tensorflow/python/framework:dtypes", + "//tensorflow/python/framework:ops", + "//tensorflow/python/ops:array_ops", + "//tensorflow/python/ops:nn_ops", + "//tensorflow/python/ops:nn_ops_gen", "//tensorflow/python/platform:client_testlib", + "//third_party/py/numpy", ], ) -tf_xla_py_test( +tf_xla_py_strict_test( name = "proximal_adagrad_test", size = "medium", srcs = ["proximal_adagrad_test.py"], @@ -1055,14 +1148,16 @@ tf_xla_py_test( ], deps = [ ":xla_test", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework", - "//tensorflow/python:training", + "//tensorflow/python/framework:constant_op", + "//tensorflow/python/ops:resource_variable_ops", + "//tensorflow/python/ops:variables", + "//tensorflow/python/platform:client_testlib", + "//tensorflow/python/training", + "//third_party/py/numpy", ], ) -tf_xla_py_test( +tf_xla_py_strict_test( name = "proximal_gradient_descent_test", size = "medium", srcs = ["proximal_gradient_descent_test.py"], @@ -1073,14 +1168,16 @@ tf_xla_py_test( ], deps = [ ":xla_test", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework", - "//tensorflow/python:training", + "//tensorflow/python/framework:constant_op", + "//tensorflow/python/ops:resource_variable_ops", + "//tensorflow/python/ops:variables", + "//tensorflow/python/platform:client_testlib", + "//tensorflow/python/training", + "//third_party/py/numpy", ], ) -tf_xla_py_test( +tf_xla_py_strict_test( name = "qr_op_test", size = "medium", srcs = ["qr_op_test.py"], @@ -1098,17 +1195,17 @@ tf_xla_py_test( ], deps = [ ":xla_test", - "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:math_ops", - "//tensorflow/python:platform_test", - 
"//tensorflow/python:training", + "//tensorflow/python/framework:test_lib", + "//tensorflow/python/ops:array_ops", + "//tensorflow/python/ops:linalg_ops", + "//tensorflow/python/ops:math_ops", "//tensorflow/python/platform:client_testlib", + "//third_party/py/numpy", "@absl_py//absl/testing:parameterized", ], ) -tf_xla_py_test( +tf_xla_py_strict_test( name = "unstack_test", size = "medium", srcs = ["unstack_test.py"], @@ -1126,17 +1223,15 @@ tf_xla_py_test( ], deps = [ ":xla_test", - "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:math_ops", - "//tensorflow/python:platform_test", - "//tensorflow/python:training", + "//tensorflow/python/ops:array_ops", + "//tensorflow/python/ops:array_ops_stack", "//tensorflow/python/platform:client_testlib", + "//third_party/py/numpy", "@absl_py//absl/testing:parameterized", ], ) -tf_xla_py_test( +tf_xla_py_strict_test( name = "random_ops_test", size = "medium", srcs = ["random_ops_test.py"], @@ -1147,17 +1242,18 @@ tf_xla_py_test( ], deps = [ ":xla_test", - "//tensorflow:tensorflow_py", - "//tensorflow/python:array_ops", - "//tensorflow/python:framework", - "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", - "//tensorflow/python:random_ops", - "//tensorflow/python:standard_ops", + "//tensorflow/python/framework:dtypes", + "//tensorflow/python/ops:array_ops", + "//tensorflow/python/ops:math_ops", + "//tensorflow/python/ops:random_ops", + "//tensorflow/python/ops/distributions:special_math", + "//third_party/py/numpy", + "@absl_py//absl/testing:parameterized", ], ) -tf_xla_py_test( +tf_xla_py_strict_test( name = "reduce_ops_test", size = "medium", srcs = ["reduce_ops_test.py"], @@ -1169,16 +1265,18 @@ tf_xla_py_test( ], deps = [ ":xla_test", - "//tensorflow/python:array_ops", - "//tensorflow/python:errors", - "//tensorflow/python:framework", - "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", + 
"//tensorflow/python/framework:dtypes", + "//tensorflow/python/framework:errors", + "//tensorflow/python/framework:test_lib", + "//tensorflow/python/ops:array_ops", + "//tensorflow/python/ops:math_ops", + "//third_party/py/numpy", "@absl_py//absl/testing:parameterized", ], ) -tf_xla_py_test( +tf_xla_py_strict_test( name = "reduce_window_test", size = "small", srcs = ["reduce_window_test.py"], @@ -1190,15 +1288,15 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/compiler/tf2xla/python:xla", - "//tensorflow/python:array_ops", - "//tensorflow/python:errors", - "//tensorflow/python:framework", - "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", + "//tensorflow/python/framework:dtypes", + "//tensorflow/python/framework:function", + "//tensorflow/python/ops:array_ops", + "//third_party/py/numpy", ], ) -tf_xla_py_test( +tf_xla_py_strict_test( name = "reverse_ops_test", size = "medium", srcs = ["reverse_ops_test.py"], @@ -1209,13 +1307,15 @@ tf_xla_py_test( ], deps = [ ":xla_test", - "//tensorflow/python:array_ops", - "//tensorflow/python:framework", + "//tensorflow/python/framework:constant_op", + "//tensorflow/python/framework:dtypes", + "//tensorflow/python/ops:array_ops", "//tensorflow/python/platform:test", + "//third_party/py/numpy", ], ) -tf_xla_py_test( +tf_xla_py_strict_test( name = "reverse_sequence_op_test", size = "medium", srcs = ["reverse_sequence_op_test.py"], @@ -1227,15 +1327,15 @@ tf_xla_py_test( ], deps = [ ":xla_test", - "//tensorflow/python:array_ops", - "//tensorflow/python:framework", - "//tensorflow/python:platform_test", + "//tensorflow/python/framework:dtypes", + "//tensorflow/python/ops:array_ops", "//tensorflow/python/platform:client_testlib", + "//third_party/py/numpy", ], ) # copybara:uncomment_begin(google-only) -# tf_xla_py_test( +# tf_xla_py_strict_test( # name = "reverse_sequence_op_args_test", # size = "medium", # srcs = ["reverse_sequence_op_args_test.py"], @@ -1249,17 +1349,16 @@ tf_xla_py_test( # deps = [ # 
":xla_test", # "//tensorflow/compiler/jit:xla_cpu_jit", # DisableOnExport -# "//tensorflow/python:array_ops", -# "//tensorflow/python:framework", -# "//tensorflow/python:platform_test", # "//tensorflow/python/compat:v2_compat", -# "//tensorflow/python/eager:function", +# "//tensorflow/python/eager:def_function", +# "//tensorflow/python/framework:errors", +# "//tensorflow/python/ops:array_ops", # "//tensorflow/python/platform:client_testlib", # ], # ) # copybara:uncomment_end -tf_xla_py_test( +tf_xla_py_strict_test( name = "rmsprop_test", size = "small", srcs = ["rmsprop_test.py"], @@ -1270,16 +1369,16 @@ tf_xla_py_test( ], deps = [ ":xla_test", - "//tensorflow/python:array_ops", - "//tensorflow/python:framework", - "//tensorflow/python:math_ops", - "//tensorflow/python:platform_test", - "//tensorflow/python:training", + "//tensorflow/python/framework:constant_op", + "//tensorflow/python/ops:resource_variable_ops", + "//tensorflow/python/ops:variables", "//tensorflow/python/platform:client_testlib", + "//tensorflow/python/training", + "//third_party/py/numpy", ], ) -tf_xla_py_test( +tf_xla_py_strict_test( name = "scan_ops_test", size = "medium", srcs = ["scan_ops_test.py"], @@ -1292,15 +1391,19 @@ tf_xla_py_test( ], deps = [ ":xla_test", - "//tensorflow/python:array_ops", - "//tensorflow/python:framework", - "//tensorflow/python:math_ops", - "//tensorflow/python:platform_test", + "//tensorflow/python/framework:constant_op", + "//tensorflow/python/framework:dtypes", + "//tensorflow/python/framework:errors", + "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:test_lib", + "//tensorflow/python/ops:array_ops", + "//tensorflow/python/ops:math_ops", "//tensorflow/python/platform:client_testlib", + "//third_party/py/numpy", ], ) -tf_xla_py_test( +tf_xla_py_strict_test( name = "segment_reduction_ops_test", size = "medium", srcs = ["segment_reduction_ops_test.py"], @@ -1312,15 +1415,16 @@ tf_xla_py_test( ], deps = [ ":xla_test", - 
"//tensorflow/python:array_ops", - "//tensorflow/python:framework", - "//tensorflow/python:math_ops", - "//tensorflow/python:math_ops_gen", "//tensorflow/python:platform_test", + "//tensorflow/python/client", + "//tensorflow/python/framework:dtypes", + "//tensorflow/python/ops:array_ops", + "//tensorflow/python/ops:math_ops", + "//third_party/py/numpy", ], ) -tf_xla_py_test( +tf_xla_py_strict_test( name = "spacetobatch_op_test", size = "medium", srcs = ["spacetobatch_op_test.py"], @@ -1332,17 +1436,18 @@ tf_xla_py_test( ], deps = [ ":xla_test", - "//tensorflow/python:array_ops", - "//tensorflow/python:framework", - "//tensorflow/python:math_ops", - "//tensorflow/python:platform_test", + "//tensorflow/python/framework:constant_op", + "//tensorflow/python/framework:dtypes", + "//tensorflow/python/ops:array_ops", + "//tensorflow/python/ops:array_ops_gen", "//tensorflow/python/platform:client_testlib", + "//third_party/py/numpy", ], ) -tf_xla_py_test( +tf_xla_py_strict_test( name = "sparse_to_dense_op_test", - size = "small", + size = "medium", srcs = ["sparse_to_dense_op_test.py"], enable_mlir_bridge = True, python_version = "PY3", @@ -1351,15 +1456,16 @@ tf_xla_py_test( ], deps = [ ":xla_test", - "//tensorflow/python:array_ops", - "//tensorflow/python:framework", - "//tensorflow/python:platform_test", - "//tensorflow/python:sparse_ops", + "//tensorflow/python/framework:dtypes", + "//tensorflow/python/framework:test_lib", + "//tensorflow/python/ops:array_ops", + "//tensorflow/python/ops:sparse_ops", "//tensorflow/python/platform:client_testlib", + "//third_party/py/numpy", ], ) -tf_xla_py_test( +tf_xla_py_strict_test( name = "stack_ops_test", size = "small", srcs = ["stack_ops_test.py"], @@ -1372,15 +1478,17 @@ tf_xla_py_test( use_xla_device = False, deps = [ ":xla_test", - "//tensorflow/python:array_ops", - "//tensorflow/python:data_flow_ops", - "//tensorflow/python:framework", - "//tensorflow/python:platform_test", + "//tensorflow/python/compiler/xla", + 
"//tensorflow/python/framework:dtypes", + "//tensorflow/python/framework:ops", + "//tensorflow/python/ops:array_ops", + "//tensorflow/python/ops:data_flow_ops_gen", "//tensorflow/python/platform:client_testlib", + "//third_party/py/numpy", ], ) -tf_xla_py_test( +tf_xla_py_strict_test( name = "stateful_random_ops_test", size = "medium", srcs = ["stateful_random_ops_test.py"], @@ -1395,17 +1503,25 @@ tf_xla_py_test( ], deps = [ ":xla_test", - "//tensorflow/python:framework", - "//tensorflow/python:platform_test", - "//tensorflow/python:standard_ops", - "//tensorflow/python:stateful_random_ops", + "//tensorflow/python/client", + "//tensorflow/python/eager:def_function", + "//tensorflow/python/framework:config", + "//tensorflow/python/framework:dtypes", + "//tensorflow/python/framework:errors", + "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:test_lib", "//tensorflow/python/kernel_tests/random:util", + "//tensorflow/python/ops:stateful_random_ops", + "//tensorflow/python/ops:stateful_random_ops_gen", + "//tensorflow/python/ops:variables", "//tensorflow/python/platform:client_testlib", "//tensorflow/python/platform:flags", + "//third_party/py/numpy", + "@absl_py//absl/testing:parameterized", ], ) -tf_xla_py_test( +tf_xla_py_strict_test( name = "stateless_random_ops_test", size = "medium", srcs = ["stateless_random_ops_test.py"], @@ -1418,16 +1534,25 @@ tf_xla_py_test( ], deps = [ ":xla_test", - "//tensorflow/python:framework", - "//tensorflow/python:platform_test", - "//tensorflow/python:standard_ops", - "//tensorflow/python:stateless_random_ops", + "//tensorflow/python/client", + "//tensorflow/python/eager:def_function", + "//tensorflow/python/framework:config", + "//tensorflow/python/framework:dtypes", + "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:test_lib", "//tensorflow/python/kernel_tests/random:util", + "//tensorflow/python/ops:array_ops", + "//tensorflow/python/ops:math_ops", + 
"//tensorflow/python/ops:stateless_random_ops", + "//tensorflow/python/ops:stateless_random_ops_v2_gen", + "//tensorflow/python/ops:variables", "//tensorflow/python/platform:client_testlib", + "//third_party/py/numpy", + "@absl_py//absl/testing:parameterized", ], ) -tf_xla_py_test( +tf_xla_py_strict_test( name = "tensor_array_ops_test", size = "medium", srcs = ["tensor_array_ops_test.py"], @@ -1443,21 +1568,27 @@ tf_xla_py_test( use_xla_device = False, deps = [ ":xla_test", - "//tensorflow/python:array_ops", - "//tensorflow/python:framework", - "//tensorflow/python:math_ops", - "//tensorflow/python:math_ops_gen", - "//tensorflow/python:nn_ops", - "//tensorflow/python:nn_ops_gen", - "//tensorflow/python:platform_test", - "//tensorflow/python:tensor_array_grad", - "//tensorflow/python:tensor_array_ops", - "//tensorflow/python:training", + "//tensorflow/python/compiler/xla", + "//tensorflow/python/framework:constant_op", + "//tensorflow/python/framework:dtypes", + "//tensorflow/python/framework:errors", + "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:test_lib", + "//tensorflow/python/ops:array_ops", + "//tensorflow/python/ops:control_flow_util", + "//tensorflow/python/ops:data_flow_ops_gen", + "//tensorflow/python/ops:gradients_impl", + "//tensorflow/python/ops:math_ops", + "//tensorflow/python/ops:resource_variable_ops", + "//tensorflow/python/ops:tensor_array_grad", + "//tensorflow/python/ops:tensor_array_ops", + "//tensorflow/python/ops:variables", "//tensorflow/python/platform:client_testlib", + "//third_party/py/numpy", ], ) -tf_xla_py_test( +tf_xla_py_strict_test( name = "tensor_list_ops_test", size = "small", srcs = ["tensor_list_ops_test.py"], @@ -1470,16 +1601,18 @@ tf_xla_py_test( ], deps = [ ":xla_test", - "//tensorflow/python:array_ops", - "//tensorflow/python:framework", - "//tensorflow/python:list_ops", - "//tensorflow/python:platform_test", - "//tensorflow/python/eager:function", + "//tensorflow/python/framework:constant_op", + 
"//tensorflow/python/framework:dtypes", + "//tensorflow/python/framework:errors", + "//tensorflow/python/ops:array_ops", + "//tensorflow/python/ops:list_ops", "//tensorflow/python/platform:client_testlib", + "//third_party/py/numpy", + "@absl_py//absl/testing:parameterized", ], ) -tf_xla_py_test( +tf_xla_py_strict_test( name = "ternary_ops_test", size = "medium", srcs = ["ternary_ops_test.py"], @@ -1491,38 +1624,46 @@ tf_xla_py_test( ], deps = [ ":xla_test", - "//tensorflow/python:array_ops", "//tensorflow/python:extra_py_tests_deps", - "//tensorflow/python:framework", - "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", + "//tensorflow/python/framework:dtypes", + "//tensorflow/python/framework:test_lib", + "//tensorflow/python/ops:array_ops", + "//tensorflow/python/ops:math_ops", + "//tensorflow/python/ops:math_ops_gen", + "//third_party/py/numpy", "@absl_py//absl/testing:parameterized", ], ) -tf_xla_py_test( +tf_xla_py_strict_test( name = "unary_ops_test", size = "medium", srcs = ["unary_ops_test.py"], enable_mlir_bridge = True, python_version = "PY3", - shard_count = 32, + shard_count = 50, tags = [ "no_cuda_asan", # times out "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip ], deps = [ ":xla_test", - "//tensorflow/python:array_ops", - "//tensorflow/python:framework", - "//tensorflow/python:math_ops", - "//tensorflow/python:nn_ops", - "//tensorflow/python:nn_ops_gen", "//tensorflow/python:platform_test", + "//tensorflow/python/framework:dtypes", + "//tensorflow/python/ops:array_ops", + "//tensorflow/python/ops:array_ops_stack", + "//tensorflow/python/ops:bitwise_ops", + "//tensorflow/python/ops:functional_ops_gen", + "//tensorflow/python/ops:math_ops", + "//tensorflow/python/ops:nn_ops", + "//tensorflow/python/ops:nn_ops_gen", + "//third_party/py/numpy", + "@six_archive//:six", ], ) -tf_xla_py_test( +tf_xla_py_strict_test( name = "fused_batchnorm_test", size = "medium", srcs = ["fused_batchnorm_test.py"], @@ 
-1535,20 +1676,18 @@ tf_xla_py_test( deps = [ ":test_utils", ":xla_test", - "//tensorflow/python:framework", - "//tensorflow/python:math_ops", - "//tensorflow/python:math_ops_gen", "//tensorflow/python:nn", - "//tensorflow/python:nn_ops", - "//tensorflow/python:nn_ops_gen", - "//tensorflow/python:platform_test", - "//tensorflow/python:training", + "//tensorflow/python/framework:test_lib", + "//tensorflow/python/ops:array_ops", + "//tensorflow/python/ops:gradient_checker", + "//tensorflow/python/ops:nn_ops_gen", "//tensorflow/python/platform:client_testlib", + "//third_party/py/numpy", "@absl_py//absl/testing:parameterized", ], ) -tf_xla_py_test( +tf_xla_py_strict_test( name = "variable_ops_test", size = "small", srcs = ["variable_ops_test.py"], @@ -1560,18 +1699,26 @@ tf_xla_py_test( ], deps = [ ":xla_test", - "//tensorflow/python:array_ops", - "//tensorflow/python:framework", - "//tensorflow/python:math_ops", - "//tensorflow/python:math_ops_gen", - "//tensorflow/python:nn_ops", - "//tensorflow/python:nn_ops_gen", "//tensorflow/python:platform_test", - "//tensorflow/python:training", + "//tensorflow/python/framework:constant_op", + "//tensorflow/python/framework:dtypes", + "//tensorflow/python/framework:errors", + "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:test_lib", + "//tensorflow/python/ops:array_ops", + "//tensorflow/python/ops:init_ops", + "//tensorflow/python/ops:math_ops", + "//tensorflow/python/ops:resource_variable_ops", + "//tensorflow/python/ops:state_ops", + "//tensorflow/python/ops:state_ops_gen", + "//tensorflow/python/ops:variable_scope", + "//tensorflow/python/ops:variables", + "//tensorflow/python/training", + "//third_party/py/numpy", ], ) -tf_xla_py_test( +tf_xla_py_strict_test( name = "while_test", size = "small", srcs = ["while_test.py"], @@ -1583,16 +1730,22 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/compiler/tf2xla/python:xla", - "//tensorflow/python:array_ops", - "//tensorflow/python:framework", - 
"//tensorflow/python:platform_test", - "//tensorflow/python:training", - "//tensorflow/python:while_loop", + "//tensorflow/python/framework:constant_op", + "//tensorflow/python/framework:dtypes", + "//tensorflow/python/framework:function", + "//tensorflow/python/framework:test_lib", + "//tensorflow/python/ops:array_ops", + "//tensorflow/python/ops:control_flow_ops", + "//tensorflow/python/ops:gradients_impl", + "//tensorflow/python/ops:map_fn", + "//tensorflow/python/ops:math_ops", + "//tensorflow/python/ops:while_loop", "//tensorflow/python/platform:client_testlib", + "//third_party/py/numpy", ], ) -tf_xla_py_test( +tf_xla_py_strict_test( name = "case_test", size = "small", srcs = ["case_test.py"], @@ -1605,18 +1758,17 @@ tf_xla_py_test( use_xla_device = False, # Uses tf.function(jit_compile=True) deps = [ ":xla_test", - "//tensorflow/compiler/tf2xla/python:xla", - "//tensorflow/python:array_ops", - "//tensorflow/python:control_flow_switch_case", - "//tensorflow/python:framework", "//tensorflow/python:image_ops", - "//tensorflow/python:platform_test", - "//tensorflow/python:training", + "//tensorflow/python/eager:def_function", + "//tensorflow/python/framework:ops", + "//tensorflow/python/ops:array_ops", + "//tensorflow/python/ops:control_flow_switch_case", + "//tensorflow/python/ops:io_ops", "//tensorflow/python/platform:client_testlib", ], ) -tf_xla_py_test( +tf_xla_py_strict_test( name = "cast_ops_test", size = "small", srcs = ["cast_ops_test.py"], @@ -1628,18 +1780,20 @@ tf_xla_py_test( use_xla_device = False, # Uses tf.function(jit_compile=True) deps = [ ":xla_test", - "//tensorflow/compiler/tf2xla/python:xla", - "//tensorflow/python:array_ops", - "//tensorflow/python:framework", "//tensorflow/python:image_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:platform_test", - "//tensorflow/python:training", + "//tensorflow/python/eager:def_function", + "//tensorflow/python/framework:dtypes", + "//tensorflow/python/framework:ops", + 
"//tensorflow/python/ops:array_ops", + "//tensorflow/python/ops:control_flow_ops", + "//tensorflow/python/ops:io_ops", + "//tensorflow/python/ops:math_ops", + "//tensorflow/python/ops:random_ops", "//tensorflow/python/platform:client_testlib", ], ) -tf_xla_py_test( +tf_xla_py_strict_test( name = "gather_test", size = "medium", srcs = ["gather_test.py"], @@ -1650,16 +1804,17 @@ tf_xla_py_test( ], deps = [ ":xla_test", - "//tensorflow/python:array_ops", - "//tensorflow/python:data_flow_ops", - "//tensorflow/python:framework", - "//tensorflow/python:platform_test", + "//tensorflow/python/framework:constant_op", + "//tensorflow/python/framework:dtypes", + "//tensorflow/python/ops:array_ops", + "//tensorflow/python/ops:variables", "//tensorflow/python/platform:client_testlib", "//tensorflow/python/platform:flags", + "//third_party/py/numpy", ], ) -tf_xla_py_test( +tf_xla_py_strict_test( name = "gather_nd_op_test", size = "medium", srcs = ["gather_nd_op_test.py"], @@ -1670,14 +1825,15 @@ tf_xla_py_test( ], deps = [ ":xla_test", - "//tensorflow/python:array_ops", - "//tensorflow/python:framework", - "//tensorflow/python:platform_test", + "//tensorflow/python/framework:errors", + "//tensorflow/python/framework:test_lib", + "//tensorflow/python/ops:array_ops", "//tensorflow/python/platform:client_testlib", + "//third_party/py/numpy", ], ) -tf_xla_py_test( +tf_xla_py_strict_test( name = "scatter_nd_op_test", size = "medium", srcs = ["scatter_nd_op_test.py"], @@ -1689,14 +1845,15 @@ tf_xla_py_test( ], deps = [ ":xla_test", - "//tensorflow/python:array_ops", - "//tensorflow/python:framework", - "//tensorflow/python:platform_test", + "//tensorflow/python/framework:errors", + "//tensorflow/python/framework:test_lib", + "//tensorflow/python/ops:array_ops", "//tensorflow/python/platform:client_testlib", + "//third_party/py/numpy", ], ) -tf_xla_py_test( +tf_xla_py_strict_test( name = "sort_ops_test", size = "medium", srcs = ["sort_ops_test.py"], @@ -1709,16 +1866,22 @@ 
tf_xla_py_test( "optonly", ], deps = [ - "//tensorflow/compiler/tests:xla_test", + ":xla_test", "//tensorflow/compiler/tf2xla/python:xla", - "//tensorflow/python:array_ops", - "//tensorflow/python:dtypes", + "//tensorflow/python/framework:dtypes", + "//tensorflow/python/framework:function", + "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:test_lib", + "//tensorflow/python/ops:array_ops", + "//tensorflow/python/ops:math_ops", + "//tensorflow/python/ops:nn_ops", "//tensorflow/python/platform:client_testlib", + "//third_party/py/numpy", "@absl_py//absl/testing:parameterized", ], ) -tf_xla_py_test( +tf_xla_py_strict_test( name = "data_format_ops_test", size = "small", srcs = ["data_format_ops_test.py"], @@ -1728,15 +1891,16 @@ tf_xla_py_test( "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip ], deps = [ - "//tensorflow/compiler/tests:xla_test", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:dtypes", - "//tensorflow/python:nn_ops", + ":xla_test", + "//tensorflow/python/framework:dtypes", + "//tensorflow/python/ops:array_ops", + "//tensorflow/python/ops:nn_ops", + "//tensorflow/python/platform:client_testlib", + "//third_party/py/numpy", ], ) -tf_xla_py_test( +tf_xla_py_strict_test( name = "xla_device_test", size = "small", srcs = ["xla_device_test.py"], @@ -1748,14 +1912,17 @@ tf_xla_py_test( ], deps = [ ":xla_test", - "//tensorflow/python:array_ops", - "//tensorflow/python:framework", - "//tensorflow/python:platform_test", + "//tensorflow/python/framework:dtypes", + "//tensorflow/python/framework:errors", + "//tensorflow/python/framework:ops", + "//tensorflow/python/ops:array_ops", + "//tensorflow/python/ops:control_flow_ops_gen", "//tensorflow/python/platform:client_testlib", + "//third_party/py/numpy", ], ) -cuda_py_test( +cuda_py_strict_test( name = "xla_device_gpu_test", size = "small", srcs = ["xla_device_gpu_test.py"], @@ -1765,16 +1932,16 @@ cuda_py_test( 
xla_enable_strict_auto_jit = False, xla_enabled = True, deps = [ - "//tensorflow/python:array_ops", - "//tensorflow/python:client", - "//tensorflow/python:client_testlib", - "//tensorflow/python:control_flow_ops", - "//tensorflow/python:framework", - "//tensorflow/python:math_ops", + "//tensorflow/python/client:session", + "//tensorflow/python/eager:context", + "//tensorflow/python/framework:dtypes", + "//tensorflow/python/framework:ops", + "//tensorflow/python/ops:array_ops", + "//tensorflow/python/platform:client_testlib", ], ) -cuda_py_test( +cuda_py_strict_test( name = "jit_test", size = "medium", srcs = ["jit_test.py"], @@ -1788,21 +1955,25 @@ cuda_py_test( deps = [ ":test_utils", "//tensorflow/core:protos_all_py", - "//tensorflow/python:array_ops", - "//tensorflow/python:client", - "//tensorflow/python:client_testlib", - "//tensorflow/python:cond", - "//tensorflow/python:control_flow_ops", - "//tensorflow/python:framework", - "//tensorflow/python:gradients", - "//tensorflow/python:math_ops", - "//tensorflow/python:nn_ops", - "//tensorflow/python:while_loop", + "//tensorflow/python/client:session", "//tensorflow/python/compiler/xla:compiler_py", + "//tensorflow/python/framework:constant_op", + "//tensorflow/python/framework:dtypes", + "//tensorflow/python/framework:function", + "//tensorflow/python/framework:ops", + "//tensorflow/python/ops:array_ops", + "//tensorflow/python/ops:cond", + "//tensorflow/python/ops:control_flow_ops", + "//tensorflow/python/ops:gradients_impl", + "//tensorflow/python/ops:math_ops", + "//tensorflow/python/ops:nn_ops", + "//tensorflow/python/ops:while_loop", + "//tensorflow/python/platform:client_testlib", + "//third_party/py/numpy", ], ) -cuda_py_test( +cuda_py_strict_test( name = "async_comp_test", size = "medium", srcs = ["async_comp_test.py"], @@ -1813,18 +1984,18 @@ cuda_py_test( xla_enable_strict_auto_jit = False, xla_enabled = True, deps = [ - ":test_utils", "//tensorflow/core:protos_all_py", - 
"//tensorflow/python:array_ops", - "//tensorflow/python:client", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework", - "//tensorflow/python:math_ops", - "//tensorflow/python/compiler/xla:compiler_py", + "//tensorflow/python/client:session", + "//tensorflow/python/framework:dtypes", + "//tensorflow/python/framework:function", + "//tensorflow/python/framework:ops", + "//tensorflow/python/ops:array_ops", + "//tensorflow/python/ops:math_ops", + "//tensorflow/python/platform:client_testlib", ], ) -cuda_py_test( +cuda_py_strict_test( name = "dense_layer_test", size = "medium", srcs = ["dense_layer_test.py"], @@ -1836,11 +2007,13 @@ cuda_py_test( deps = [ ":test_utils", "//tensorflow/core:protos_all_py", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:layers", - "//tensorflow/python:variables", "//tensorflow/python/compiler/xla:compiler_py", + "//tensorflow/python/framework:ops", + "//tensorflow/python/layers", + "//tensorflow/python/ops:array_ops", + "//tensorflow/python/ops:variables", + "//tensorflow/python/platform:client_testlib", + "//third_party/py/numpy", ], ) @@ -1962,22 +2135,23 @@ tf_cuda_cc_test( ], ) -py_library( +py_strict_library( name = "lstm", testonly = 1, srcs = ["lstm.py"], srcs_version = "PY3", deps = [ - "//tensorflow/python:array_ops", - "//tensorflow/python:framework", - "//tensorflow/python:math_ops", - "//tensorflow/python:random_ops", - "//tensorflow/python:variable_v1", + "//tensorflow/python/framework:dtypes", + "//tensorflow/python/framework:ops", + "//tensorflow/python/ops:array_ops", + "//tensorflow/python/ops:math_ops", + "//tensorflow/python/ops:random_ops", + "//tensorflow/python/ops:variable_v1", "@six_archive//:six", ], ) -cuda_py_test( +cuda_py_strict_test( name = "lstm_test", srcs = ["lstm_test.py"], tags = [ @@ -1988,13 +2162,17 @@ cuda_py_test( deps = [ ":lstm", ":xla_test", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - 
"//tensorflow/python:framework", - "//tensorflow/python:gradients", - "//tensorflow/python:init_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:variables", + "//tensorflow/python/framework:constant_op", + "//tensorflow/python/framework:dtypes", + "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:test_lib", + "//tensorflow/python/ops:array_ops", + "//tensorflow/python/ops:gradients_impl", + "//tensorflow/python/ops:init_ops", + "//tensorflow/python/ops:math_ops", + "//tensorflow/python/ops:variables", + "//tensorflow/python/platform:client_testlib", + "//third_party/py/numpy", ], ) @@ -2021,7 +2199,7 @@ tf_library( tfcompile_flags = ["--xla_cpu_multi_thread_eigen=false"], ) -tf_xla_py_test( +tf_xla_py_strict_test( name = "fake_quant_ops_test", size = "medium", srcs = ["fake_quant_ops_test.py"], @@ -2032,12 +2210,15 @@ tf_xla_py_test( ], deps = [ ":xla_test", - "//tensorflow/python:framework", "//tensorflow/python:platform_test", + "//tensorflow/python/framework:dtypes", + "//tensorflow/python/ops:array_ops", + "//tensorflow/python/ops:array_ops_gen", + "//third_party/py/numpy", ], ) -tf_xla_py_test( +tf_xla_py_strict_test( name = "placeholder_test", size = "small", srcs = ["placeholder_test.py"], @@ -2048,13 +2229,14 @@ tf_xla_py_test( ], deps = [ ":xla_test", - "//tensorflow/python:array_ops", - "//tensorflow/python:framework", "//tensorflow/python:platform_test", + "//tensorflow/python/ops:array_ops", + "//tensorflow/python/ops:resource_variable_ops", + "//tensorflow/python/ops:variables", ], ) -tf_xla_py_test( +tf_xla_py_strict_test( name = "quantized_ops_test", size = "medium", srcs = ["quantized_ops_test.py"], @@ -2066,16 +2248,18 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/compiler/tf2xla/python:xla", - "//tensorflow/python:array_ops", - "//tensorflow/python:bitwise_ops", - "//tensorflow/python:constant_op", - "//tensorflow/python:dtypes", - "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", + 
"//tensorflow/python/framework:constant_op", + "//tensorflow/python/framework:dtypes", + "//tensorflow/python/framework:ops", + "//tensorflow/python/ops:array_ops", + "//tensorflow/python/ops:bitwise_ops", + "//tensorflow/python/ops:math_ops", + "//third_party/py/numpy", ], ) -tf_xla_py_test( +tf_xla_py_strict_test( name = "xla_ops_test", size = "medium", srcs = ["xla_ops_test.py"], @@ -2086,16 +2270,26 @@ tf_xla_py_test( ], deps = [ ":xla_test", + "//tensorflow/compiler/tf2xla/ops:gen_xla_ops", "//tensorflow/compiler/tf2xla/python:xla", - "//tensorflow/python:array_ops", - "//tensorflow/python:errors", - "//tensorflow/python:framework", + "//tensorflow/compiler/xla:xla_data_proto_py", "//tensorflow/python:platform_test", + "//tensorflow/python/eager:def_function", + "//tensorflow/python/framework:dtypes", + "//tensorflow/python/framework:errors", + "//tensorflow/python/framework:function", + "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor_shape", + "//tensorflow/python/framework:test_lib", + "//tensorflow/python/ops:array_ops", + "//tensorflow/python/ops:array_ops_stack", + "//tensorflow/python/ops:stateless_random_ops", + "//third_party/py/numpy", "@absl_py//absl/testing:parameterized", ], ) -tf_xla_py_test( +tf_xla_py_strict_test( name = "xla_custom_call_ops_test", size = "small", srcs = ["xla_custom_call_ops_test.py"], @@ -2113,14 +2307,16 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/compiler/tf2xla/python:xla", - "//tensorflow/python:framework", - "//tensorflow/python:platform_test", - "//tensorflow/python:training", + "//tensorflow/python/eager:def_function", + "//tensorflow/python/framework:dtypes", + "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor_spec", + "//tensorflow/python/ops:random_ops", "//tensorflow/python/platform:client_testlib", ], ) -tf_xla_py_test( +tf_xla_py_strict_test( name = "runtime_shape_check_test", size = "small", srcs = ["runtime_shape_check_test.py"], @@ -2137,13 
+2333,17 @@ tf_xla_py_test( use_xla_device = False, deps = [ ":xla_test", - "//tensorflow/python:framework", - "//tensorflow/python:platform_test", + "//tensorflow/python/eager:def_function", + "//tensorflow/python/framework:constant_op", + "//tensorflow/python/framework:dtypes", + "//tensorflow/python/framework:errors", + "//tensorflow/python/framework:ops", + "//tensorflow/python/ops:array_ops", "//tensorflow/python/platform:client_testlib", ], ) -tf_xla_py_test( +tf_xla_py_strict_test( name = "conv_node_name_test", size = "medium", srcs = ["conv_node_name_test.py"], @@ -2156,17 +2356,16 @@ tf_xla_py_test( ], deps = [ ":xla_test", - "//tensorflow/python:array_ops", - "//tensorflow/python:framework", - "//tensorflow/python:layers", - "//tensorflow/python:nn", - "//tensorflow/python:nn_ops", - "//tensorflow/python:nn_ops_gen", "//tensorflow/python:platform_test", + "//tensorflow/python/framework:ops", + "//tensorflow/python/layers", + "//tensorflow/python/ops:array_ops", + "//tensorflow/python/ops:control_flow_ops", + "//third_party/py/numpy", ], ) -tf_xla_py_test( +tf_xla_py_strict_test( name = "tridiagonal_solve_ops_test", size = "medium", srcs = ["tridiagonal_solve_ops_test.py"], @@ -2178,15 +2377,20 @@ tf_xla_py_test( ], deps = [ ":xla_test", - "//tensorflow/python:framework", - "//tensorflow/python:linalg_ops", - "//tensorflow/python:platform_test", - "//tensorflow/python:standard_ops", + "//tensorflow/python:gradients", + "//tensorflow/python/framework:constant_op", + "//tensorflow/python/framework:dtypes", + "//tensorflow/python/framework:errors", + "//tensorflow/python/framework:test_lib", + "//tensorflow/python/ops:array_ops", + "//tensorflow/python/ops:math_ops", + "//tensorflow/python/ops/linalg:linalg_impl", "//tensorflow/python/platform:client_testlib", + "//third_party/py/numpy", ], ) -tf_xla_py_test( +tf_xla_py_strict_test( name = "tridiagonal_matmul_ops_test", size = "medium", srcs = ["tridiagonal_matmul_ops_test.py"], @@ -2198,15 +2402,22 @@ 
tf_xla_py_test( ], deps = [ ":xla_test", - "//tensorflow/python:framework", - "//tensorflow/python:gradient_checker_v2", - "//tensorflow/python:linalg_ops", "//tensorflow/python:platform_test", - "//tensorflow/python:standard_ops", + "//tensorflow/python/eager:def_function", + "//tensorflow/python/framework:constant_op", + "//tensorflow/python/framework:dtypes", + "//tensorflow/python/framework:ops", + "//tensorflow/python/ops:array_ops", + "//tensorflow/python/ops:array_ops_stack", + "//tensorflow/python/ops:gradient_checker_v2", + "//tensorflow/python/ops:math_ops", + "//tensorflow/python/ops:stateless_random_ops", + "//tensorflow/python/ops/linalg:linalg_impl", + "//third_party/py/numpy", ], ) -tf_xla_py_test( +tf_xla_py_strict_test( name = "special_math_test", size = "medium", srcs = ["special_math_test.py"], @@ -2219,14 +2430,20 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:extra_py_tests_deps", - "//tensorflow/python:gradient_checker_v2", - "//tensorflow/python:math_ops", + "//tensorflow/python/eager:def_function", + "//tensorflow/python/framework:constant_op", + "//tensorflow/python/ops:gradient_checker_v2", + "//tensorflow/python/ops:math_ops", + "//tensorflow/python/ops:math_ops_gen", + "//tensorflow/python/ops:random_ops_gen", "//tensorflow/python/platform:client_testlib", + "//third_party/py/numpy", + "@absl_py//absl/flags", "@absl_py//absl/testing:parameterized", ], ) -tf_xla_py_test( +tf_xla_py_strict_test( name = "repeat_op_test", size = "medium", srcs = ["repeat_op_test.py"], @@ -2239,13 +2456,14 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:extra_py_tests_deps", - "//tensorflow/python:math_ops", + "//tensorflow/python/eager:def_function", + "//tensorflow/python/framework:dtypes", + "//tensorflow/python/ops:array_ops", "//tensorflow/python/platform:client_testlib", - "@absl_py//absl/testing:parameterized", ], ) -tf_xla_py_test( +tf_xla_py_strict_test( name = "image_ops_jit_compile_test", size = "medium", srcs = 
["image_ops_jit_compile_test.py"], @@ -2262,13 +2480,20 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:extra_py_tests_deps", - "//tensorflow/python:math_ops", + "//tensorflow/python:image_ops", + "//tensorflow/python/eager:backprop", + "//tensorflow/python/eager:def_function", + "//tensorflow/python/framework:dtypes", + "//tensorflow/python/framework:errors", + "//tensorflow/python/framework:ops", + "//tensorflow/python/ops:array_ops", + "//tensorflow/python/ops:math_ops", + "//tensorflow/python/ops:variables", "//tensorflow/python/platform:client_testlib", - "@absl_py//absl/testing:parameterized", ], ) -tf_xla_py_test( +tf_xla_py_strict_test( name = "ensure_shape_op_test", size = "medium", srcs = ["ensure_shape_op_test.py"], @@ -2280,14 +2505,15 @@ tf_xla_py_test( ], deps = [ ":xla_test", - "//tensorflow/python:array_ops", - "//tensorflow/python:framework", - "//tensorflow/python:platform_test", + "//tensorflow/python/framework:dtypes", + "//tensorflow/python/framework:errors", + "//tensorflow/python/ops:array_ops", + "//tensorflow/python/ops:check_ops", "//tensorflow/python/platform:client_testlib", ], ) -tf_xla_py_test( +tf_xla_py_strict_test( name = "where_op_test", size = "small", srcs = ["where_op_test.py"], @@ -2303,17 +2529,16 @@ tf_xla_py_test( ], deps = [ ":xla_test", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:control_flow_ops", - "//tensorflow/python:errors", - "//tensorflow/python:framework", - "//tensorflow/python/compiler/xla:compiler_py", - "//tensorflow/python/tpu", + "//tensorflow/python/framework:config", + "//tensorflow/python/framework:dtypes", + "//tensorflow/python/ops:array_ops", + "//tensorflow/python/ops:math_ops", + "//tensorflow/python/platform:client_testlib", + "//tensorflow/python/tpu:tpu_py", ], ) -tf_xla_py_test( +tf_xla_py_strict_test( name = "where_op_tpu_test", size = "small", srcs = ["where_op_test.py"], @@ -2335,17 +2560,16 @@ tf_xla_py_test( ], deps = [ 
":xla_test", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:control_flow_ops", - "//tensorflow/python:errors", - "//tensorflow/python:framework", - "//tensorflow/python/compiler/xla:compiler_py", - "//tensorflow/python/tpu", + "//tensorflow/python/framework:config", + "//tensorflow/python/framework:dtypes", + "//tensorflow/python/ops:array_ops", + "//tensorflow/python/ops:math_ops", + "//tensorflow/python/platform:client_testlib", + "//tensorflow/python/tpu:tpu_py", ], ) -tf_xla_py_test( +tf_xla_py_strict_test( name = "risc_ops_test", size = "small", srcs = ["risc_ops_test.py"], @@ -2356,16 +2580,18 @@ tf_xla_py_test( ], deps = [ ":xla_test", - "//tensorflow/python:framework", - "//tensorflow/python:is_mlir_bridge_test_true", "//tensorflow/python:platform_test", - "//tensorflow/python/eager:function", + "//tensorflow/python/eager:def_function", + "//tensorflow/python/framework:constant_op", + "//tensorflow/python/framework:dtypes", + "//tensorflow/python/framework:is_mlir_bridge_test_true", + "//tensorflow/python/framework:ops", "//tensorflow/python/ops/risc:risc_ops", "//tensorflow/python/platform:client_testlib", ], ) -tf_xla_py_test( +tf_xla_py_strict_test( name = "const_arg_test", size = "small", srcs = ["const_arg_test.py"], @@ -2376,14 +2602,14 @@ tf_xla_py_test( ], deps = [ ":xla_test", - "//tensorflow/python:array_ops", - "//tensorflow/python:framework", - "//tensorflow/python:is_mlir_bridge_test_false", + "//tensorflow/compiler/tf2xla/python:xla", "//tensorflow/python:platform_test", + "//tensorflow/python/framework:dtypes", + "//tensorflow/python/ops:array_ops", ], ) -cuda_py_test( +cuda_py_strict_test( name = "const_test", size = "small", srcs = ["const_test.py"], @@ -2391,14 +2617,16 @@ cuda_py_test( xla_enable_strict_auto_jit = False, xla_enabled = True, deps = [ - ":xla_test", - "//tensorflow/python:constant_op", - "//tensorflow/python:framework", + "//tensorflow/python/eager:def_function", + 
"//tensorflow/python/framework:constant_op", + "//tensorflow/python/framework:dtypes", + "//tensorflow/python/framework:test_lib", "//tensorflow/python/platform:client_testlib", + "//third_party/py/numpy", ], ) -tpu_py_test( +tpu_py_strict_test( name = "giant_const_op_test", srcs = [ "giant_const_op_test.py", @@ -2410,14 +2638,19 @@ tpu_py_test( python_version = "PY3", tags = ["no_oss"], deps = [ - "//tensorflow/python/distribute/cluster_resolver:cluster_resolver_lib", + "//tensorflow/python/distribute:tpu_strategy", + "//tensorflow/python/distribute/cluster_resolver:tpu_cluster_resolver_py", + "//tensorflow/python/eager:def_function", "//tensorflow/python/eager:remote", "//tensorflow/python/eager:test", - "//tensorflow/python/tpu:tpu_lib", + "//tensorflow/python/framework:constant_op", + "//tensorflow/python/framework:dtypes", + "//tensorflow/python/platform:flags", + "//third_party/py/numpy", ], ) -tf_xla_py_test( +tf_xla_py_strict_test( name = "sharding_util_ops_test", srcs = ["sharding_util_ops_test.py"], disabled_backends = [ @@ -2434,40 +2667,41 @@ tf_xla_py_test( ], deps = [ ":xla_test", - "//tensorflow/python:constant_op", - "//tensorflow/python:framework_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:resource_variable_ops", - "//tensorflow/python:session", - "//tensorflow/python:variables", + "//tensorflow/python/client:session", + "//tensorflow/python/framework:constant_op", + "//tensorflow/python/framework:ops", + "//tensorflow/python/ops:math_ops", + "//tensorflow/python/ops:resource_variable_ops", + "//tensorflow/python/ops:tpu_ops_gen", + "//tensorflow/python/ops:variables", "//tensorflow/python/platform:client_testlib", - "//tensorflow/python/tpu:tpu_lib", + "//third_party/py/numpy", "@absl_py//absl/testing:parameterized", ], ) -tpu_py_test( +tpu_py_strict_test( name = "approx_topk_test", srcs = ["approx_topk_test.py"], disable_experimental = False, disable_mlir_bridge = False, tags = ["no_oss"], deps = [ - 
"//tensorflow/python:array_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:nn_ops", "//tensorflow/python/eager:backprop", "//tensorflow/python/eager:def_function", "//tensorflow/python/eager:test", "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:ops", - "//tensorflow/python/framework:test_lib", + "//tensorflow/python/ops:array_ops", + "//tensorflow/python/ops:math_ops", + "//tensorflow/python/ops:nn_ops", + "//tensorflow/python/ops:variables", "//third_party/py/numpy", "@absl_py//absl/testing:parameterized", ], ) -tf_xla_py_test( +tf_xla_py_strict_test( name = "xla_call_module_test", size = "small", srcs = ["xla_call_module_test.py"], @@ -2479,14 +2713,50 @@ tf_xla_py_test( use_xla_device = False, # Uses tf.function(jit_compile=True) deps = [ ":xla_test", + "//tensorflow/compiler/mlir/stablehlo", + "//tensorflow/compiler/tf2xla/ops:gen_xla_ops", "//tensorflow/compiler/tf2xla/python:xla", - "//tensorflow/python:framework", "//tensorflow/python:platform_test", - "//tensorflow/python:training", + "//tensorflow/python/eager:def_function", + "//tensorflow/python/framework:dtypes", + "//tensorflow/python/framework:errors", + "//tensorflow/python/framework:function", + "//tensorflow/python/framework:ops", + "//tensorflow/python/ops:array_ops", + "//tensorflow/python/platform:client_testlib", + "//third_party/py/numpy", ], ) -tf_xla_py_test( +tf_xla_py_strict_test( + name = "xla_call_module_no_platform_check_test", + size = "small", + srcs = ["xla_call_module_no_platform_check_test.py"], + enable_mlir_bridge = False, + env = {"TF_XLA_FLAGS": "--tf_xla_call_module_disabled_checks=platform"}, + python_version = "PY3", + tags = [ + "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip + ], + use_xla_device = False, # Uses tf.function(jit_compile=True) + deps = [ + ":xla_test", + "//tensorflow/compiler/mlir/stablehlo", + "//tensorflow/compiler/tf2xla/ops:gen_xla_ops", + "//tensorflow/compiler/tf2xla/python:xla", + 
"//tensorflow/python:platform_test", + "//tensorflow/python/eager:def_function", + "//tensorflow/python/framework:dtypes", + "//tensorflow/python/framework:errors", + "//tensorflow/python/framework:function", + "//tensorflow/python/framework:ops", + "//tensorflow/python/ops:array_ops", + "//tensorflow/python/platform:client_testlib", + "//third_party/py/numpy", + ], +) + +tf_xla_py_strict_test( name = "bincount_op_test", size = "small", srcs = ["bincount_op_test.py"], @@ -2499,10 +2769,12 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:platform_test", + "//tensorflow/python/framework:errors", + "//tensorflow/python/ops:math_ops_gen", ], ) -tf_xla_py_test( +tf_xla_py_strict_test( name = "unique_ops_test", size = "small", srcs = ["unique_ops_test.py"], @@ -2518,9 +2790,11 @@ tf_xla_py_test( ], deps = [ ":xla_test", - "//tensorflow/python:array_ops", - "//tensorflow/python:constant_op", - "//tensorflow/python:framework", "//tensorflow/python:platform_test", + "//tensorflow/python/framework:constant_op", + "//tensorflow/python/framework:dtypes", + "//tensorflow/python/ops:array_ops", + "//tensorflow/python/ops:array_ops_gen", + "//third_party/py/numpy", ], ) diff --git a/tensorflow/compiler/tests/build_defs.bzl b/tensorflow/compiler/tests/build_defs.bzl index 8bf82c34644..7343bb9b89e 100644 --- a/tensorflow/compiler/tests/build_defs.bzl +++ b/tensorflow/compiler/tests/build_defs.bzl @@ -1,6 +1,7 @@ """Build rules for Tensorflow/XLA testing.""" load("//tensorflow:tensorflow.bzl", "py_test") +load("//tensorflow:strict.default.bzl", "py_strict_test") load("//tensorflow/compiler/tests:plugin.bzl", "plugins") load( "//tensorflow/core/platform:build_config_root.bzl", @@ -21,6 +22,7 @@ def tf_xla_py_test( disabled_backends = None, use_xla_device = True, enable_mlir_bridge = True, + test_rule = py_test, **kwargs): """Generates py_test targets, one per XLA backend. 
@@ -111,7 +113,7 @@ def tf_xla_py_test( extra_tag = [] updated_name = test_name - mlir_bridge_dep = "//tensorflow/python:is_mlir_bridge_test_true" + mlir_bridge_dep = "//tensorflow/python/framework:is_mlir_bridge_test_true" has_mlir_dep = (mlir_bridge_dep in deps) if mlir_option: if updated_name.endswith("_test"): @@ -130,7 +132,7 @@ def tf_xla_py_test( # version. continue - py_test( + test_rule( name = updated_name, srcs = srcs, srcs_version = "PY3", @@ -145,6 +147,9 @@ def tf_xla_py_test( test_names.append(updated_name) native.test_suite(name = name, tests = test_names) +def tf_xla_py_strict_test(**kwargs): + tf_xla_py_test(test_rule = py_strict_test, **kwargs) + def generate_backend_suites(backends = []): """Generates per-backend test_suites that run all tests for a backend.""" if not backends: diff --git a/tensorflow/compiler/tests/giant_const_op_test.py b/tensorflow/compiler/tests/giant_const_op_test.py index 014b9d5f1eb..9a73a95cb34 100644 --- a/tensorflow/compiler/tests/giant_const_op_test.py +++ b/tensorflow/compiler/tests/giant_const_op_test.py @@ -25,7 +25,6 @@ from tensorflow.python.eager import test from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.platform import flags -from tensorflow.python.tpu import tpu_strategy_util FLAGS = flags.FLAGS flags.DEFINE_string("tpu", "", "Name of TPU to connect to.") @@ -45,7 +44,7 @@ def get_tpu_cluster_resolver(): def get_tpu_strategy(): resolver = get_tpu_cluster_resolver() remote.connect_to_cluster(resolver) - tpu_strategy_util.initialize_tpu_system(resolver) + tpu_cluster_resolver.initialize_tpu_system(resolver) return tpu_lib.TPUStrategyV2(resolver) diff --git a/tensorflow/compiler/tests/randomized_tests.cc b/tensorflow/compiler/tests/randomized_tests.cc index b5d3baec7fc..d4d9afaad54 100644 --- a/tensorflow/compiler/tests/randomized_tests.cc +++ b/tensorflow/compiler/tests/randomized_tests.cc @@ -43,8 +43,19 @@ limitations under the License. 
// * StridedSliceGrad (need to use shape function to compute sensible inputs) #include +#include +#include +#include +#include +#include +#include +#include +#include #include +#include #include +#include +#include #include "absl/algorithm/container.h" #include "absl/container/flat_hash_set.h" @@ -504,7 +515,7 @@ OpTest::OpTest() { << ". To reproduce the " "results of this test, pass flag --tf_xla_random_seed=" << seed; - generator_.reset(new std::mt19937(seed)); + generator_ = std::make_unique(seed); } namespace { @@ -532,7 +543,7 @@ template class TensorGenerator { public: explicit TensorGenerator(OpTest& test) : test_(test) {} - virtual ~TensorGenerator() {} + virtual ~TensorGenerator() = default; virtual DataType dtype() = 0; virtual void RandomVals(std::optional lo, std::optional hi, bool needs_unique_values, diff --git a/tensorflow/compiler/tests/sharding_util_ops_test.py b/tensorflow/compiler/tests/sharding_util_ops_test.py index 26e39ca2a2b..7d5ac5771f1 100644 --- a/tensorflow/compiler/tests/sharding_util_ops_test.py +++ b/tensorflow/compiler/tests/sharding_util_ops_test.py @@ -294,7 +294,7 @@ class XlaSplitNDOpTest(xla_test.XLATestCase, parameterized.TestCase): def testRanked(self, graph_fn, rank): num_splits = [2] * rank num_outputs = 2 << (rank - 1) - input_value = np.reshape(np.arange(np.product(num_splits)), num_splits) + input_value = np.reshape(np.arange(np.prod(num_splits)), num_splits) for dtype in self.numeric_types: with self.session() as sess, self.device_scope(): split = graph_fn( diff --git a/tensorflow/compiler/tests/tensor_float_32_test.py b/tensorflow/compiler/tests/tensor_float_32_test.py new file mode 100644 index 00000000000..f02b69948f4 --- /dev/null +++ b/tensorflow/compiler/tests/tensor_float_32_test.py @@ -0,0 +1,106 @@ +# Copyright 2023 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests that the PrecisionConfig is set if TF32 is disabled.""" + +from tensorflow.compiler.tests import xla_test +from tensorflow.python.eager import def_function +from tensorflow.python.framework import config +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_ops +from tensorflow.python.platform import googletest + + +class TensorFloat32ConvTest(xla_test.XLATestCase): + + def tearDown(self): + super().tearDown() + config.enable_tensor_float_32_execution(True) + + def _test_fn(self, fn, inputs): + with ops.device('device:{}:0'.format(self.device)): + # Test with TF32 disabled + config.enable_tensor_float_32_execution(False) + compiled_fn = def_function.function(fn, jit_compile=True) + hlo_text = compiled_fn.experimental_get_compiler_ir(*inputs)(stage='hlo') + self.assertIn('operand_precision={highest,highest}', hlo_text) + + # Test the output is sufficiently precise by comparing with FP64 results + out = compiled_fn(*inputs) + f64_out = compiled_fn(*[math_ops.cast(x, 'float64') for x in inputs]) + self.assertAllClose(out, f64_out, rtol=1e-5, atol=1e-5) + + # Test with TF32 enabled. Recompile fn because enabling TF32 does not + # reset function cache. 
+ config.enable_tensor_float_32_execution(True) + compiled_fn = def_function.function(fn, jit_compile=True) + hlo_text = compiled_fn.experimental_get_compiler_ir(*inputs)(stage='hlo') + # operand_precision is not in HLO if it's the default value. + self.assertNotIn('operand_precision', hlo_text) + + def test_matmul(self): + x = array_ops.fill((1024, 1024), 1 + 2**-12) + y = array_ops.fill((1024, 1024), 1.0) + + def matmul(x, y): + return math_ops.matmul(x, y) + + self._test_fn(matmul, [x, y]) + + def test_batch_matmul(self): + x = array_ops.fill((2, 1024, 1024), 1 + 2**-12) + y = array_ops.fill((2, 1024, 1024), 1.0) + + def batch_matmul(x, y): + return math_ops.matmul(x, y) + + self._test_fn(batch_matmul, [x, y]) + + def test_conv2d(self): + x = array_ops.fill((2, 20, 20, 32), 1 + 2**-12) + y = array_ops.fill((3, 3, 32, 32), 1.0) + + def conv2d(x, y): + return nn_ops.conv2d(x, y, [1, 1, 1, 1], padding='SAME') + + self._test_fn(conv2d, [x, y]) + + def test_conv2d_backprop_input(self): + y = array_ops.fill((3, 3, 32, 32), 1 + 2**-12) + out_backprop = array_ops.fill((2, 20, 20, 32), 1.0) + + def conv2d_backprop_input(y, out_backprop): + return nn_ops.conv2d_backprop_input( + (2, 20, 20, 32), y, out_backprop, [1, 1, 1, 1], padding='SAME' + ) + + self._test_fn(conv2d_backprop_input, [y, out_backprop]) + + def test_conv2d_backprop_filter(self): + x = array_ops.fill((2, 20, 20, 32), 1 + 2**-12) + out_backprop = array_ops.fill((2, 20, 20, 32), 1.0) + + def conv2d_backprop_filter(x, out_backprop): + return nn_ops.conv2d_backprop_filter( + x, (3, 3, 32, 32), out_backprop, [1, 1, 1, 1], padding='SAME' + ) + + self._test_fn(conv2d_backprop_filter, [x, out_backprop]) + + +if __name__ == '__main__': + ops.enable_eager_execution() + googletest.main() diff --git a/tensorflow/compiler/tests/unary_ops_test.py b/tensorflow/compiler/tests/unary_ops_test.py index 34aeeafe976..c944c9e22e0 100644 --- a/tensorflow/compiler/tests/unary_ops_test.py +++ 
b/tensorflow/compiler/tests/unary_ops_test.py @@ -953,6 +953,17 @@ class UnaryOpsTest(xla_test.XLATestCase): lambda x: array_ops.bitcast(x, dtypes.uint64), np.array([1, 0x100000003f800000], np.int64), expected=np.array([1, 0x100000003f800000], np.uint64)) + self._assertOpOutputMatchesExpected( + lambda x: array_ops.bitcast(x, dtypes.float64), + np.array( + [0, 0x3FF0000000000000, 0xc3af161421c8e000, 0x4032000000000007], + np.uint64, + ), + expected=np.array( + [0, 1.0, -1.12e+18, 18.000000000000024869], np.float64 + ), + atol=0 + ) def testBitcastInt8ToFloat(self): self._assertOpOutputMatchesExpected( diff --git a/tensorflow/compiler/tests/xla_call_module_no_platform_check_test.py b/tensorflow/compiler/tests/xla_call_module_no_platform_check_test.py new file mode 100644 index 00000000000..9146711bda3 --- /dev/null +++ b/tensorflow/compiler/tests/xla_call_module_no_platform_check_test.py @@ -0,0 +1,87 @@ +# Copyright 2022 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for XLA call module op wrapper with disabled platform check. 
+ +This test runs with --tf_xla_call_module_disabled_checks=platform +""" +from typing import Tuple + +import numpy as np + +from tensorflow.compiler.mlir.stablehlo import stablehlo +from tensorflow.compiler.tests import xla_test +from tensorflow.compiler.tf2xla.python import xla +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import googletest + + +def serialize(module_str: str) -> Tuple[str, int]: + target = stablehlo.get_minimum_version() + byte_str = stablehlo.serialize_portable_artifact(module_str, target) + return byte_str, xla.call_module_maximum_supported_version() + + +class XlaCallModuleOpTest(xla_test.XLATestCase): + + def _assertOpOutputMatchesExpected(self, + op, + args, + expected, + equality_fn=None): + """Asserts op(*args) == expected.""" + with self.session() as session: + with self.test_scope(): + placeholders = [ + array_ops.placeholder(dtypes.as_dtype(arg.dtype), arg.shape) + for arg in args + ] + feeds = {placeholders[i]: args[i] for i in range(0, len(args))} + output = op(*placeholders) + result = session.run(output, feeds) + if not equality_fn: + equality_fn = self.assertAllClose + equality_fn(result, expected, rtol=1e-3) + + def test_platforms_errors(self): + """Error reporting for the platforms attribute.""" + x = np.float32(0.) + + module_str = """ +module @jit_f.0 { + func.func public @main(%arg0: tensor) -> tensor { + return %arg0 : tensor + } +} +""" + module, version = serialize(module_str) + def f(x): + return xla.call_module( + [x], version=version, + module=module, + Tout=[np.float32], + Sout=[()], + platforms=['RANDOM_PLATFORM'], + disabled_checks=[]) + # No error even though the `platforms` does not match the testing platform + self._assertOpOutputMatchesExpected(f, (x,), (x,)) + + +if __name__ == '__main__': + # This test is using Tensorflow sessions which are not compatible with eager + # mode. 
+ ops.disable_eager_execution() + googletest.main() diff --git a/tensorflow/compiler/tests/xla_call_module_test.py b/tensorflow/compiler/tests/xla_call_module_test.py index 01f30718217..31abd1f700e 100644 --- a/tensorflow/compiler/tests/xla_call_module_test.py +++ b/tensorflow/compiler/tests/xla_call_module_test.py @@ -18,24 +18,24 @@ import unittest import numpy as np +from tensorflow.compiler.mlir.stablehlo import stablehlo from tensorflow.compiler.tests import xla_test from tensorflow.compiler.tf2xla.ops import gen_xla_ops from tensorflow.compiler.tf2xla.python import xla - from tensorflow.python.eager import def_function from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors +from tensorflow.python.framework import function from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.platform import googletest +from tensorflow.python.platform import test def serialize(module_str: str) -> Tuple[str, int]: - # TODO(b/274838200): error importing xla_extension in OSS - # target_version = '0.9.0' # TODO(gleasonk): use APIs to get this - # return xla_extension.mlir.serialize_portable_artifact( - # module_str, target_version), 4 - return module_str, 3 + target = stablehlo.get_minimum_version() + byte_str = stablehlo.serialize_portable_artifact(module_str, target) + return byte_str, xla.call_module_maximum_supported_version() class XlaCallModuleOpTest(xla_test.XLATestCase): @@ -64,7 +64,10 @@ class XlaCallModuleOpTest(xla_test.XLATestCase): if self.device in ['CPU', 'XLA_CPU']: return 'CPU' elif self.device in ['GPU', 'XLA_GPU']: - return 'CUDA' + if test.is_built_with_rocm(): + return 'ROCM' + else: + return 'CUDA' elif self.device in ['TPU', 'XLA_TPU']: return 'TPU' else: @@ -85,7 +88,34 @@ module @jit_f.0 { } """) return xla.call_module([x], version=version, - module=module, Tout=[x.dtype], Sout=[x.shape]) + module=module, Tout=[x.dtype], Sout=[x.shape], + 
platforms=[self.testing_platform()]) + + self._assertOpOutputMatchesExpected(f, (x,), (np.sin(np.cos(x)),)) + + def test_basic_with_token(self): + x = np.array([1.0, 2.0, 3.0], dtype=np.float32) + + def f(x): + # sin(cos(x)) + module, version = serialize(""" +module @jit_f.0 { + func.func public @main(%arg0: !stablehlo.token, %arg1: tensor<3xf32>) -> (!stablehlo.token, tensor<3xf32>) { + %0 = stablehlo.cosine %arg1 : tensor<3xf32> + %1 = stablehlo.sine %0 : tensor<3xf32> + return %arg0, %1 : !stablehlo.token, tensor<3xf32> + } +} +""") + return xla.call_module( + [x], + version=version, + module=module, + Tout=[x.dtype], + Sout=[x.shape], + has_token_input_output=True, + platforms=[self.testing_platform()], + ) self._assertOpOutputMatchesExpected(f, (x,), (np.sin(np.cos(x)),)) @@ -107,7 +137,8 @@ module @jit_f_jax.0 { return xla.call_module([x], version=version, module=module, Tout=[res.dtype], - Sout=[res.shape]) + Sout=[res.shape], + platforms=[self.testing_platform()],) self._assertOpOutputMatchesExpected(f, (x,), (res,)) @@ -129,17 +160,19 @@ module @jit_f.0 { return xla.call_module([x, y], version=version, module=module, Tout=[x.dtype, y.dtype], - Sout=[x.shape, y.shape]) + Sout=[x.shape, y.shape], + platforms=[self.testing_platform()],) self._assertOpOutputMatchesExpected(f, (x, y), (np.sin(x), np.cos(y))) + # TODO(b/283439649): remove dim_args_spec support def test_dim_var_basic(self): x = np.arange(6, dtype=np.float32).reshape((2, 3)) def f(x): # x: f32[2, b] # Module takes another argument which is the value of b # (sin(x), x.shape[1]) - module, version = serialize(""" + module, _ = serialize(""" module @jit_f.0 { func.func public @main(%arg0: tensor, %arg1: tensor<2x?xf32>) -> (tensor<2x?xf32>, tensor) { %0 = stablehlo.sine %arg1 : tensor<2x?xf32> @@ -147,21 +180,24 @@ module @jit_f.0 { } } """) - return xla.call_module([x], version=version, - module=module, - Tout=[x.dtype, np.int32], - Sout=[(None, 3), ()], - dim_args_spec=['0.1']) + return 
gen_xla_ops.xla_call_module( + [x], + version=4, + module=module, + Tout=[x.dtype, np.int32], + Sout=[(None, 3), ()], + dim_args_spec=['0.1']) self._assertOpOutputMatchesExpected(f, (x,), (np.sin(x), x.shape[1])) + # TODO(b/283439649): remove dim_args_spec support def test_dim_var_basic_dim_arg_i64(self): x = np.arange(6, dtype=np.float32).reshape((2, 3)) def f(x): # x: f32[2, b] # Module takes another argument which is the value of b # (sin(x), x.shape[1]) - module, version = serialize(""" + module, _ = serialize(""" module @jit_f.0 { func.func public @main(%arg0: tensor, %arg1: tensor<2x?xf32>) -> (tensor<2x?xf32>, tensor) { %0 = stablehlo.sine %arg1 : tensor<2x?xf32> @@ -169,11 +205,12 @@ module @jit_f.0 { } } """) - return xla.call_module([x], - module=module, version=version, - Tout=[x.dtype, np.int64], - Sout=[(None, 3), ()], - dim_args_spec=['0.1']) + return gen_xla_ops.xla_call_module( + [x], + module=module, version=4, + Tout=[x.dtype, np.int64], + Sout=[(None, 3), ()], + dim_args_spec=['0.1']) self._assertOpOutputMatchesExpected(f, (x,), (np.sin(x), x.shape[1])) @@ -199,89 +236,64 @@ module @jit_f.0 { return xla.call_module([x], module=module, version=version, Tout=[x.dtype, np.int32], - Sout=[(None, 3), ()]) + Sout=[(None, 3), ()], + platforms=[self.testing_platform()],) self._assertOpOutputMatchesExpected(f, (x,), (np.sin(x), x.shape[1])) - def test_dim_args_spec_errors(self): - # x, y: f32[2, b, c] - x = np.arange(24, dtype=np.float32).reshape((2, 3, 4)) - y = x + def test_wrong_actual_args_errors(self): + x = np.arange(6, dtype=np.float32).reshape((3, 2)) + y = np.arange(6, dtype=np.int32).reshape((2, 3)) - # Module takes two prefix arguments with the values of b and c - # return (sin(x + y), x.shape[1]) + # x: f32[a, 2], return x module, version = serialize(""" module @jit_f.0 { - func.func public @main(%arg0: tensor, %arg1: tensor, %arg2: tensor<2x?x?xf32>, %arg3: tensor<2x?x?xf32>) -> (tensor<2x?x?xf32>, tensor) { - %0 = stablehlo.add %arg2, %arg3 
: tensor<2x?x?xf32> - %1 = stablehlo.sine %0 : tensor<2x?x?xf32> - return %1, %arg0 : tensor<2x?x?xf32>, tensor + func.func public @main(%arg0: tensor, %arg1: tensor<*xi32>) -> tensor { + return %arg0 : tensor } } """) - dim_args_spec = ['0.1', '0.2'] def f(x, y): - return xla.call_module([x, y], - module=module, version=version, - Tout=[x.dtype, np.int32], - Sout=[(None, 3), ()], - dim_args_spec=dim_args_spec) - self._assertOpOutputMatchesExpected(f, (x, y), (np.sin(x + y), x.shape[1])) + return xla.call_module( + [x, y], + module=module, + version=version, + Tout=[x.dtype], + Sout=[(None, 2)], + platforms=[self.testing_platform()], + ) - dim_args_spec = ['0.0', '0.0', '0.0', '0.0'] # Too many dim_args_spec + self._assertOpOutputMatchesExpected(f, (x, y), (x,)) + + x_bad_etype = x.astype(np.int32) with self.assertRaisesRegex( errors.InvalidArgumentError, - 'The module should have 0 platform index arguments and ' - '4 dimension arguments, ' - 'but it has only 4 total arguments'): - self._assertOpOutputMatchesExpected(f, (x, y), - (np.sin(x + y), x.shape[1])) + 'Element type mismatch for argument 0 passed to XlaCallModule: ' + r'expecting tensor<\?x2xf32>, got tensor<3x2xi32>', + ): + self._assertOpOutputMatchesExpected(f, (x_bad_etype, y), (x_bad_etype,)) - dim_args_spec = ['0.0', '0.0', '0.0'] # dim_args_spec refers to non-scalar + y_bad_etype = y.astype(np.float32) with self.assertRaisesRegex( errors.InvalidArgumentError, - 'Module argument at index 2 should be a 0-dimensional integer-tensor ' - 'dimension argument but has type'): - self._assertOpOutputMatchesExpected(f, (x, y), - (np.sin(x + y), x.shape[1])) + 'Element type mismatch for argument 1 passed to XlaCallModule: ' + r'expecting tensor<\*xi32>, got tensor<2x3xf32>', + ): + self._assertOpOutputMatchesExpected(f, (x, y_bad_etype), (x,)) - dim_args_spec = ['1.0'] # Too few dim_args_spec + x_bad_shape = np.arange(15, dtype=np.float32).reshape(5, 3) with self.assertRaisesRegex( errors.InvalidArgumentError, - 
'Incorrect number of arguments passed to XlaCallModule: 2. ' - 'The module takes 4 arguments of which 0 platform index arguments ' - 'and 1 dimension arguments.'): - self._assertOpOutputMatchesExpected(f, (x, y), - (np.sin(x + y), x.shape[1])) - - dim_args_spec = ['0.b', '0.1'] # axis_idx not a number - with self.assertRaisesRegex( - errors.InvalidArgumentError, - "Syntax error in dim_args_spec '0.b'"): - self._assertOpOutputMatchesExpected(f, (x, y), - (np.sin(x + y), x.shape[1])) - - dim_args_spec = ['2.0', '0.1'] # arg_idx too large - with self.assertRaisesRegex( - errors.InvalidArgumentError, - 'Invalid argument index 2 when the number of non-dimension arguments ' - "is 2 in dim_arg_spec '2.0'"): - self._assertOpOutputMatchesExpected(f, (x, y), - (np.sin(x + y), x.shape[1])) - - dim_args_spec = ['0.3', '0.1'] # axis_idx too large - with self.assertRaisesRegex( - errors.InvalidArgumentError, - 'Invalid axis index 3 when the rank of non-dimension argument 0 ' - "is 3 in dim_arg_spec '0.3'"): - self._assertOpOutputMatchesExpected(f, (x, y), - (np.sin(x + y), x.shape[1])) + 'Shape mismatch for argument 0 passed to XlaCallModule: ' + r'expecting tensor<\?x2xf32>, got tensor<5x3xf32>', + ): + self._assertOpOutputMatchesExpected(f, (x_bad_shape, y), (x_bad_shape,)) def test_platforms_basic(self): x = np.float32(0.) - # returns x + 2. on CPU, x + 3. on GPU and x + 4. on TPU + # returns x + 2. on CPU, x + 3. on GPU (CUDA or ROCM) and x + 4. 
on TPU module, version = serialize(""" module @jit_f.0 { func.func public @main(%arg_platform_idx: tensor, %arg0: tensor) -> tensor { @@ -301,7 +313,7 @@ module @jit_f.0 { } """) - platforms = ['CPU', 'CUDA', 'TPU'] + platforms = ['CPU', 'CUDA', 'ROCM', 'TPU'] def f(x): return xla.call_module([x], version=version, module=module, @@ -309,40 +321,11 @@ module @jit_f.0 { Sout=[()], platforms=platforms) - expected_value = x + dict(CPU=2., CUDA=3., TPU=4.)[self.testing_platform()] + expected_value = ( + x + dict(CPU=2.0, CUDA=3.0, ROCM=3.0, TPU=4.0)[self.testing_platform()] + ) self._assertOpOutputMatchesExpected(f, (x,), (expected_value,)) - def test_platforms_with_dim_vars(self): - x = np.ones((3,), dtype=np.float32) - y = np.arange(3., dtype=np.float32) - - # returns x + x on CPU and x - x on TPU - module, version = serialize(""" -module @jit_f.0 { - func.func public @main(%arg_platform_idx: tensor, %arg_dim0: tensor, %arg0: tensor, %arg1: tensor) -> tensor { - %res = "stablehlo.case"(%arg_platform_idx) ({ - %0 = stablehlo.add %arg0, %arg1 : tensor - stablehlo.return %0 : tensor - }, { - %1 = stablehlo.subtract %arg0, %arg1 : tensor - stablehlo.return %1 : tensor - }) : (tensor) -> tensor - return %res : tensor - } -} -""") - def f(x, y): - return xla.call_module([x, y], version=version, - module=module, - Tout=[np.float32], - Sout=[(None,)], - platforms=['CPU', 'TPU'], - dim_args_spec=['0.0']) - - expected_value = x + (y if self.testing_platform() == 'CPU' else -y) - if self.testing_platform() in ['CPU', 'TPU']: - self._assertOpOutputMatchesExpected(f, (x, y), (expected_value,)) - def test_platforms_errors(self): """Error reporting for the platforms attribute.""" x = np.float32(0.) 
@@ -353,17 +336,26 @@ module @jit_f.0 { return %arg0 : tensor } } +""" + module_str_no_platform_arg = """ +module @jit_f.0 { + func.func public @main(%arg0: tensor) -> tensor { + return %arg0 : tensor + } +} """ module, version = serialize(module_str) - platforms = [] + platforms = [self.testing_platform()] + disabled_checks = [] def f(x): return xla.call_module([x], version=version, module=module, Tout=[np.float32], Sout=[()], - platforms=platforms) + platforms=platforms, + disabled_checks=disabled_checks) - # With empty platforms, there should be no platform_index argument + # With singleton `platforms`, there should be no platform_index argument with self.assertRaisesRegex( errors.InvalidArgumentError, 'Incorrect number of arguments passed to XlaCallModule: 1. ' @@ -371,23 +363,33 @@ module @jit_f.0 { 'and 0 dimension arguments.'): self._assertOpOutputMatchesExpected(f, (x,), (x,)) - # Same with a single platform - platforms = ['CPU'] - if self.testing_platform() == 'CPU': - with self.assertRaisesRegex( - errors.InvalidArgumentError, - 'Incorrect number of arguments passed to XlaCallModule: 1. ' - 'The module takes 2 arguments of which 0 platform index arguments ' - 'and 0 dimension arguments.'): - self._assertOpOutputMatchesExpected(f, (x,), (x,)) - platforms = ['RANDOM_PLATFORM_1', 'RANDOM_PLATFORM_2'] with self.assertRaisesRegex( errors.NotFoundError, 'The current platform .* is not among the platforms'): self._assertOpOutputMatchesExpected(f, (x,), (x,)) - platforms = ['CPU', 'CUDA'] + # Disable the check but have two platforms + platforms = ['RANDOM_PLATFORM_1', 'RANDOM_PLATFORM_2'] + disabled_checks = [xla.call_module_disable_check_platform()] + # No error + self._assertOpOutputMatchesExpected(f, (x,), (x,)) + + # Disable the check but have a single platform and hence no platform arg. 
+ platforms = ['RANDOM_PLATFORM_1'] + module, version = serialize(module_str_no_platform_arg) + # No error + self._assertOpOutputMatchesExpected(f, (x,), (x,)) + disabled_checks = [] + module, version = serialize(module_str) + + platforms = [] + with self.assertRaisesRegex( + errors.InvalidArgumentError, + 'must have non-empty platforms'): + self._assertOpOutputMatchesExpected(f, (x,), (x,)) + + platforms = ['CPU', 'CUDA', 'ROCM'] if self.testing_platform() not in platforms: with self.assertRaisesRegex( errors.NotFoundError, @@ -398,7 +400,7 @@ module @jit_f.0 { # The module cannot have i64 %arg_platform_idx module, version = serialize(module_str.replace('i32', 'i64')) - platforms = ['CPU', 'CUDA', 'TPU'] + platforms = ['CPU', 'CUDA', 'ROCM', 'TPU'] with self.assertRaisesRegex( errors.InvalidArgumentError, 'Module argument at index 0 should be a 0-dimensional ' @@ -428,7 +430,12 @@ module @jit_f.0 { # return np.arange(x.shape[0], dtype=np.int32) module, version = serialize(""" module @jit_fun.1 { - func.func public @main(%arg0: tensor, %arg1: tensor) -> tensor { + func.func public @main(%arg1: tensor) -> tensor { + %arg0_new = "stablehlo.get_dimension_size"(%arg1) {dimension = 0 : i64} : (tensor) -> tensor + %0 = call @dyn_main(%arg0_new, %arg1) : (tensor, tensor) -> tensor + return %0 : tensor + } + func.func private @dyn_main(%arg0: tensor, %arg1: tensor) -> tensor { %0 = stablehlo.reshape %arg0 : (tensor) -> tensor<1xi32> %1 = "stablehlo.dynamic_iota"(%0) {iota_dimension = 0 : i64} : (tensor<1xi32>) -> tensor return %1 : tensor @@ -439,7 +446,7 @@ module @jit_fun.1 { module=module, Tout=[res.dtype], Sout=[(None,)], - dim_args_spec=['0.0']) + platforms=[self.testing_platform()],) self._assertOpOutputMatchesExpected(f, (x,), (res,)) @@ -471,7 +478,12 @@ module @jit_f.0 { def f(x): # x: f32[b, 3] module, version = serialize(""" module @jit_fun_flat_jax { - func.func public @main(%arg0: tensor, %arg1: tensor) -> tensor { + func.func public @main(%arg1: tensor) -> 
tensor { + %arg0_new = "stablehlo.get_dimension_size"(%arg1) {dimension = 0 : i64} : (tensor) -> tensor + %0 = call @dyn_main(%arg0_new, %arg1) : (tensor, tensor) -> tensor + return %0 : tensor + } + func.func private @dyn_main(%arg0: tensor, %arg1: tensor) -> tensor { %0 = stablehlo.constant dense<3> : tensor %1 = stablehlo.multiply %arg0, %0 : tensor %2 = stablehlo.reshape %1 : (tensor) -> tensor<1xi32> @@ -484,7 +496,7 @@ module @jit_fun_flat_jax { module=module, Tout=[res.dtype], Sout=[(None,)], - dim_args_spec=['0.0']) + platforms=[self.testing_platform()],) self._assertOpOutputMatchesExpected(f, (x,), (res,)) @@ -495,7 +507,12 @@ module @jit_fun_flat_jax { def f(x): # x: f32[b, 4] module, version = serialize(""" module @jit_fun_flat_jax { - func.func public @main(%arg0: tensor, %arg1: tensor) -> tensor { + func.func public @main(%arg1: tensor) -> tensor { + %arg0_new = "stablehlo.get_dimension_size"(%arg1) {dimension = 0 : i64} : (tensor) -> tensor + %0 = call @dyn_main(%arg0_new, %arg1) : (tensor, tensor) -> tensor + return %0 : tensor + } + func.func private @dyn_main(%arg0: tensor, %arg1: tensor) -> tensor { %0 = stablehlo.constant dense<0> : tensor %1 = stablehlo.constant dense<0> : tensor<1xi64> %2 = stablehlo.reshape %arg0 : (tensor) -> tensor<1xi32> @@ -510,7 +527,7 @@ module @jit_fun_flat_jax { module=module, Tout=[res.dtype], Sout=[(None, 2)], - dim_args_spec=['0.0']) + platforms=[self.testing_platform()],) self._assertOpOutputMatchesExpected(f, (x,), (res,)) @@ -521,7 +538,12 @@ module @jit_fun_flat_jax { def f(x): # x: f32[b, 4] module, version = serialize(""" module @jit_fun_flat_jax { - func.func public @main(%arg0: tensor, %arg1: tensor) -> tensor<4xf32> { + func.func public @main(%arg1: tensor) -> tensor<4xf32> { + %arg0_new = "stablehlo.get_dimension_size"(%arg1) {dimension = 0 : i64} : (tensor) -> tensor + %0 = call @dyn_main(%arg0_new, %arg1) : (tensor, tensor) -> tensor<4xf32> + return %0 : tensor<4xf32> + } + func.func private 
@dyn_main(%arg0: tensor, %arg1: tensor) -> tensor<4xf32> { %0 = stablehlo.constant dense<-1> : tensor %1 = stablehlo.add %arg0, %0 : tensor %2 = stablehlo.reshape %1 : (tensor) -> tensor<1xi32> @@ -541,7 +563,7 @@ module @jit_fun_flat_jax { module=module, Tout=[x.dtype], Sout=[(4,)], - dim_args_spec=['0.0']) + platforms=[self.testing_platform()],) self._assertOpOutputMatchesExpected(f, (x,), (res,)) @@ -553,7 +575,12 @@ module @jit_fun_flat_jax { def f(x, idx): # x: f32[b, 4] idx: i32 module, version = serialize(""" module @jit_fun_flat_jax { - func.func public @main(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { + func.func public @main(%arg1: tensor, %arg2: tensor) -> tensor { + %arg0_new = "stablehlo.get_dimension_size"(%arg1) {dimension = 0 : i64} : (tensor) -> tensor + %0 = call @dyn_main(%arg0_new, %arg1, %arg2) : (tensor, tensor, tensor) -> tensor + return %0 : tensor + } + func.func private @dyn_main(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { %0 = stablehlo.constant dense<0> : tensor %1 = stablehlo.compare LT, %arg2, %0, SIGNED : (tensor, tensor) -> tensor %2 = stablehlo.add %arg2, %arg0 : tensor @@ -568,7 +595,7 @@ module @jit_fun_flat_jax { module=module, Tout=[res.dtype], Sout=[(None, 4)], - dim_args_spec=['0.0']) + platforms=[self.testing_platform()],) self._assertOpOutputMatchesExpected(f, (x, idx), (res,)) @@ -581,7 +608,12 @@ module @jit_fun_flat_jax { # return (np.broadcast_to(x, y.shape), x + y) module, version = serialize(""" module @jit_fun.0 { - func.func public @main(%arg0: tensor, %arg1: tensor, %arg2: tensor<2x?x4xf32>) -> (tensor<2x?x4xf32>, tensor<2x?x4xf32>) { + func.func public @main(%arg1: tensor, %arg2: tensor<2x?x4xf32>) -> (tensor<2x?x4xf32>, tensor<2x?x4xf32>) { + %arg0_new = "stablehlo.get_dimension_size"(%arg2) {dimension = 1 : i64} : (tensor<2x?x4xf32>) -> tensor + %0, %1 = call @dyn_main(%arg0_new, %arg1, %arg2) : (tensor, tensor, tensor<2x?x4xf32>) -> (tensor<2x?x4xf32>, tensor<2x?x4xf32>) + return %0, %1 
: tensor<2x?x4xf32>, tensor<2x?x4xf32> + } + func.func private @dyn_main(%arg0: tensor, %arg1: tensor, %arg2: tensor<2x?x4xf32>) -> (tensor<2x?x4xf32>, tensor<2x?x4xf32>) { %0 = stablehlo.constant dense<2> : tensor<1xi32> %2 = stablehlo.reshape %arg0 : (tensor) -> tensor<1xi32> %3 = stablehlo.constant dense<4> : tensor<1xi32> @@ -596,7 +628,7 @@ module @jit_fun.0 { module=module, Tout=[res[0].dtype, res[1].dtype], Sout=[(2, None, 4), (2, None, 4)], - dim_args_spec=['1.1']) + platforms=[self.testing_platform()],) self._assertOpOutputMatchesExpected(f, (x, y), res) @@ -608,14 +640,19 @@ module @jit_fun.0 { def f(x): # x: i32[b] module, version = serialize(""" module @jit_fun{ - func.func public @main(%arg0: tensor, %arg1: tensor) -> tensor { + func.func public @main(%arg1: tensor) -> tensor { + %arg0_new = "stablehlo.get_dimension_size"(%arg2) {dimension = 0 : i64} : (tensor) -> tensor + %0 = call @dyn_main(%arg0_new, %arg1) : (tensor, tensor) -> tensor + return %0 : tensor + } + func.func private @dyn_main(%arg0: tensor, %arg1: tensor) -> tensor { %0 = stablehlo.constant dense<0> : tensor %1 = stablehlo.reduce(%arg1 init: %0) across dimensions = [0] : (tensor, tensor) -> tensor reducer(%arg2: tensor, %arg3: tensor) { - %4 = mhlo.add %arg2, %arg3 : tensor - "mhlo.return"(%4) : (tensor) -> () + %4 = stablehlo.add %arg2, %arg3 : tensor + "stablehlo.return"(%4) : (tensor) -> () } - %2 = mhlo.multiply %1, %arg0 : tensor + %2 = stablehlo.multiply %1, %arg0 : tensor return %2 : tensor } } @@ -624,7 +661,7 @@ module @jit_fun{ module=module, Tout=[res.dtype], Sout=[res.shape], - dim_args_spec=['0.0']) + platforms=[self.testing_platform()],) self._assertOpOutputMatchesExpected(f, (x,), (res,)) @@ -635,7 +672,12 @@ module @jit_fun{ def f(x): # x: f32[b, 5] module, version = serialize(""" module @jit_fun_flat_jax { - func.func public @main(%arg0: tensor, %arg1: tensor) -> tensor { + func.func public @main(%arg1: tensor) -> tensor { + %arg0_new = 
"stablehlo.get_dimension_size"(%arg1) {dimension = 0 : i64} : (tensor) -> tensor + %0 = call @dyn_main(%arg0_new, %arg1) : (tensor, tensor) -> tensor + return %0 : tensor + } + func.func private @dyn_main(%arg0: tensor, %arg1: tensor) -> tensor { %0 = stablehlo.constant dense<0.000000e+00> : tensor %1 = stablehlo.reduce(%arg1 init: %0) across dimensions = [1] : (tensor, tensor) -> tensor reducer(%arg2: tensor, %arg3: tensor) { @@ -654,7 +696,7 @@ module @jit_fun_flat_jax { module=module, Tout=[res.dtype], Sout=[(None, 1)], - dim_args_spec=['0.0']) + platforms=[self.testing_platform()],) self._assertOpOutputMatchesExpected(f, (x,), (res,)) @@ -666,7 +708,12 @@ module @jit_fun_flat_jax { def f(x): # x: f32[b] module, version = serialize(""" module @jit_fun_3 { - func.func public @main(%arg0: tensor, %arg1: tensor) -> tensor { + func.func public @main(%arg1: tensor) -> tensor { + %arg0_new = "stablehlo.get_dimension_size"(%arg1) {dimension = 0 : i64} : (tensor) -> tensor + %0 = call @dyn_main(%arg0_new, %arg1) : (tensor, tensor) -> tensor + return %0 : tensor + } + func.func private @dyn_main(%arg0: tensor, %arg1: tensor) -> tensor { %0 = call @f(%arg0, %arg1) : (tensor, tensor) -> tensor return %0 : tensor } @@ -681,7 +728,7 @@ module @jit_fun_3 { module=module, Tout=[res.dtype], Sout=[()], - dim_args_spec=['0.0']) + platforms=[self.testing_platform()]) self._assertOpOutputMatchesExpected(f, (x,), (res,)) @@ -692,7 +739,12 @@ module @jit_fun_3 { def f(x): # x: f32[b] module, version = serialize(""" module @jit_fun_3 { - func.func public @main(%arg0: tensor, %arg1: tensor) -> tensor { + func.func public @main(%arg1: tensor) -> tensor { + %arg0_new = "stablehlo.get_dimension_size"(%arg1) {dimension = 0 : i64} : (tensor) -> tensor + %0 = call @dyn_main(%arg0_new, %arg1) : (tensor, tensor) -> tensor + return %0 : tensor + } + func.func private @dyn_main(%arg0: tensor, %arg1: tensor) -> tensor { return %arg1 : tensor } } @@ -701,7 +753,7 @@ module @jit_fun_3 { 
module=module, Tout=[res.dtype], Sout=[()], - dim_args_spec=['0.0']) + platforms=[self.testing_platform()]) self._assertOpOutputMatchesExpected(f, (x,), (res,)) @@ -717,7 +769,12 @@ module @jit_fun_3 { def f(x): # x: f32[b] module, version = serialize(""" module @jit_fun_flat_jax { - func.func public @main(%arg0: tensor, %arg1: tensor) -> (tensor, tensor) { + func.func public @main(%arg1: tensor) -> (tensor, tensor) { + %arg0_new = "stablehlo.get_dimension_size"(%arg1) {dimension = 0 : i64} : (tensor) -> tensor + %0, %1 = call @dyn_main(%arg0_new, %arg1) : (tensor, tensor) -> (tensor, tensor) + return %0, %1 : tensor, tensor + } + func.func private @dyn_main(%arg0: tensor, %arg1: tensor) -> (tensor, tensor) { %0 = stablehlo.constant dense<0> : tensor %1:2 = "stablehlo.while"(%arg1, %0) ({ ^bb0(%arg2: tensor, %arg3: tensor): @@ -741,10 +798,301 @@ module @jit_fun_flat_jax { module=module, Tout=[res0.dtype, res1.dtype], Sout=[(None,), res1.shape], - dim_args_spec=['0.0']) + platforms=[self.testing_platform()]) self._assertOpOutputMatchesExpected(f, (x,), (res0, res1)) + def test_tf_call_function(self): + """A TensorFlow function call inside StableHLO.""" + x = np.int32(2) + y = np.int32(3) + res = x + y + + @function.Defun(dtypes.int32, dtypes.int32) + def foo(x, y): + return x + y + + def f(x, y): + module, version = serialize(""" +module @jit_fun_flat_jax { + func.func public @main(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = stablehlo.custom_call @tf.call_tf_function(%arg0, %arg1) { + tf.backend_config = {called_index = 0} + } : (tensor, tensor) -> tensor + return %0 : tensor + } +} +""") + return xla.call_module( + [x, y], + version=version, + module=module, + Tout=[res.dtype], + Sout=[res.shape], + platforms=[self.testing_platform()], + function_list=(foo,), + ) + + self._assertOpOutputMatchesExpected(f, (x, y), (res,)) + + def test_tf_call_function_multiple_funcs(self): + """Multiple TensorFlow function calls inside StableHLO.""" + x = np.int32(2) + y = 
np.int32(3) + res = (x + y) + (x + y) + + @function.Defun(dtypes.int32, dtypes.int32) + def foo(x, y): + return x + y + + @function.Defun(dtypes.int32, dtypes.int32) + def bar(x, y): + return foo(x, y) + + def f(x, y): + module, version = serialize(""" +module @jit_fun_flat_jax { + func.func public @main(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = stablehlo.custom_call @tf.call_tf_function(%arg0, %arg1) { + tf.backend_config = {called_index = 0} + } : (tensor, tensor) -> tensor + %1 = stablehlo.custom_call @tf.call_tf_function(%arg0, %arg1) { + tf.backend_config = {called_index = 1} + } : (tensor, tensor) -> tensor + %2 = stablehlo.custom_call @tf.call_tf_function(%0, %1) { + tf.backend_config = {called_index = 1} + } : (tensor, tensor) -> tensor + return %2 : tensor + } +} +""") + return xla.call_module( + [x, y], + version=version, + module=module, + Tout=[res.dtype], + Sout=[res.shape], + platforms=[self.testing_platform()], + function_list=(foo, bar), + ) + + self._assertOpOutputMatchesExpected(f, (x, y), (res,)) + + def test_shape_polymorphic_tf_call_function(self): + """A TensorFlow function call inside StableHLO.""" + x = np.full((2,), 2, dtype=np.int32) + y = np.full((2,), 3, dtype=np.int32) + res = x + y + + @function.Defun(dtypes.int32, dtypes.int32) + def foo(x, y): + return x + y + + def f(x, y): + module, version = serialize(""" +module @jit_fun_flat_jax { + func.func public @main(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = stablehlo.get_dimension_size %arg0, dim = 0 : (tensor) -> tensor + %1 = stablehlo.custom_call @tf.call_tf_function(%arg0, %arg1, %0) { + tf.backend_config = {called_index = 0}, + indices_of_shape_operands = dense<[2]> : tensor<1xi64> + } : (tensor, tensor, tensor) -> tensor + return %1 : tensor + } +} +""") + return xla.call_module( + [x, y], + version=version, + module=module, + Tout=[res.dtype], + Sout=[res.shape], + platforms=[self.testing_platform()], + function_list=(foo,), + ) + + 
self._assertOpOutputMatchesExpected(f, (x, y), (res,)) + + def test_tf_call_function_with_token(self): + """A TensorFlow function call inside StableHLO.""" + x = np.int32(2) + y = np.int32(3) + res = x + y + + @function.Defun(dtypes.int32, dtypes.int32) + def foo(x, y): + return x + y + + def f(x, y): + module, version = serialize(""" +module @jit_fun_flat_jax { + func.func public @main(%arg0: !stablehlo.token, %arg1: tensor, %arg2: tensor) -> (!stablehlo.token, tensor) { + %0:2 = stablehlo.custom_call @tf.call_tf_function(%arg0, %arg1, %arg2) { + tf.backend_config = {called_index = 0, has_token_input_output = true} + } : (!stablehlo.token, tensor, tensor) -> (!stablehlo.token, tensor) + return %0#0, %0#1 : !stablehlo.token, tensor + } +} +""") + return xla.call_module( + [x, y], + version=version, + module=module, + Tout=[res.dtype], + Sout=[res.shape], + platforms=[self.testing_platform()], + function_list=(foo,), + has_token_input_output=True, + ) + + self._assertOpOutputMatchesExpected(f, (x, y), (res,)) + + def test_tf_call_function_nested(self): + """Nested XlaCallModule inside TensorFlow function calls.""" + x = np.int32(2) + y = np.int32(3) + res = x + y + + @function.Defun(dtypes.int32, dtypes.int32) + def add(x, y): + return x + y + + @function.Defun(dtypes.int32, dtypes.int32) + def nested_xla_call(x, y): + module, version = serialize(""" +module @jit_fun_flat_jax { + func.func public @main(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = stablehlo.custom_call @tf.call_tf_function(%arg0, %arg1) { + tf.backend_config = {called_index = 0} + } : (tensor, tensor) -> tensor + return %0 : tensor + } +} +""") + return xla.call_module( + [x, y], + version=version, + module=module, + Tout=[res.dtype], + Sout=[res.shape], + platforms=[self.testing_platform()], + function_list=(add,), + ) + + @function.Defun(dtypes.int32, dtypes.int32) + def call(x, y): + return nested_xla_call(x, y) + + def f(x, y): + module, version = serialize(""" +module @jit_fun_flat_jax { + 
func.func public @main(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = stablehlo.custom_call @tf.call_tf_function(%arg0, %arg1) { + tf.backend_config = {called_index = 0} + } : (tensor, tensor) -> tensor + return %0 : tensor + } +} +""") + return xla.call_module( + [x, y], + version=version, + module=module, + Tout=[res.dtype], + Sout=[res.shape], + platforms=[self.testing_platform()], + function_list=(call,), + ) + + self._assertOpOutputMatchesExpected(f, (x, y), (res,)) + + def test_tf_call_function_nested_func_renaming(self): + """Multiple custom calls with identically named private functions.""" + x = np.int32(2) + y = np.int32(3) + res0 = x + y + res1 = x - y + + # Verify that multiple inner TF function calls with the same private + # functions are properly renamed during MHLO import. This test case is + # carefully constructed such that one outer XlaCallModule op has two custom + # calls, each of which has the same private "@call" function with different + # body. This is to catch bugs in the func renaming logic. 
+ + @function.Defun(dtypes.int32, dtypes.int32) + def add(x, y): + module, version = serialize(""" +module @jit_fun_flat_jax { + func.func private @call(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = stablehlo.add %arg0, %arg1 : tensor + return %0 : tensor + } + + func.func public @main(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = func.call @call(%arg0, %arg1) : (tensor, tensor) -> tensor + return %0 : tensor + } +} +""") + return xla.call_module( + [x, y], + version=version, + module=module, + Tout=[res0.dtype], + Sout=[res0.shape], + platforms=[self.testing_platform()], + ) + + @function.Defun(dtypes.int32, dtypes.int32) + def subtract(x, y): + module, version = serialize(""" +module @jit_fun_flat_jax { + func.func private @call(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = stablehlo.subtract %arg0, %arg1 : tensor + return %0 : tensor + } + + func.func public @main(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = func.call @call(%arg0, %arg1) : (tensor, tensor) -> tensor + return %0 : tensor + } +} +""") + return xla.call_module( + [x, y], + version=version, + module=module, + Tout=[res1.dtype], + Sout=[res1.shape], + platforms=[self.testing_platform()], + ) + + def f(x, y): + module, version = serialize(""" +module @jit_fun_flat_jax { + func.func public @main(%arg0: tensor, %arg1: tensor) -> (tensor, tensor) { + %0 = stablehlo.custom_call @tf.call_tf_function(%arg0, %arg1) { + tf.backend_config = {called_index = 0} + } : (tensor, tensor) -> tensor + %1 = stablehlo.custom_call @tf.call_tf_function(%arg0, %arg1) { + tf.backend_config = {called_index = 1} + } : (tensor, tensor) -> tensor + return %0, %1 : tensor, tensor + } +} +""") + return xla.call_module( + [x, y], + version=version, + module=module, + Tout=[res0.dtype, res1.dtype], + Sout=[res0.shape, res1.shape], + platforms=[self.testing_platform()], + function_list=(add, subtract), + ) + + self._assertOpOutputMatchesExpected(f, (x, y), (res0, res1)) + def test_op_backward_compatibility(self): 
"""Test for ensuring XlaCallModuleOp backward compatiblity.""" x = np.array([1.0, 2.0, 3.0], dtype=np.float32) @@ -769,6 +1117,7 @@ module @jit_f.0 { module=module, Tout=[x.dtype], Sout=[x.shape], + platforms=[self.testing_platform()], ) self._assertOpOutputMatchesExpected(f, (x,), (np.sin(np.cos(x)),)) diff --git a/tensorflow/compiler/tf2tensorrt/BUILD b/tensorflow/compiler/tf2tensorrt/BUILD index bb48f9e806b..8f4c707e901 100644 --- a/tensorflow/compiler/tf2tensorrt/BUILD +++ b/tensorflow/compiler/tf2tensorrt/BUILD @@ -443,9 +443,9 @@ tf_custom_op_py_library( deps = [ ":_pywrap_py_utils", ":trt_ops", - "//tensorflow/python:errors", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:resources", + "//tensorflow/python/framework:errors", + "//tensorflow/python/framework:for_generated_wrappers", + "//tensorflow/python/ops:resources", ], ) @@ -1030,6 +1030,7 @@ pybind_extension( "@local_config_rocm//:__subpackages__", "@local_config_tensorrt//:__subpackages__", "@local_execution_config_platform//:__subpackages__", + "@ml_dtypes//:__subpackages__", "@nsync//:__subpackages__", "@platforms//:__subpackages__", "@pybind11//:__subpackages__", diff --git a/tensorflow/compiler/tf2tensorrt/common/utils.cc b/tensorflow/compiler/tf2tensorrt/common/utils.cc index 92166c2e79e..26ac37b237b 100644 --- a/tensorflow/compiler/tf2tensorrt/common/utils.cc +++ b/tensorflow/compiler/tf2tensorrt/common/utils.cc @@ -15,6 +15,8 @@ limitations under the License. 
#include "tensorflow/compiler/tf2tensorrt/common/utils.h" +#include + #if GOOGLE_CUDA && GOOGLE_TENSORRT #include "absl/base/call_once.h" #include "absl/strings/str_cat.h" diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index 0bf252386bc..676281bd6a4 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -210,9 +210,9 @@ filegroup( srcs = [ "xla_compiled_cpu_function.h", "//tensorflow/compiler/xla:cpu_runtime_hdrs", + "//tensorflow/compiler/xla/runtime:aot_ffi_execution_context_hdrs", "//tensorflow/compiler/xla/service:custom_call_status_hdrs", "//tensorflow/compiler/xla/service/cpu:runtime_hdrs", - "//tensorflow/compiler/xla/service/cpu:xla_runtime_runner_hdrs", "//tensorflow/core/kernels:xla_cpu_runtime_hdrs", "//tensorflow/core/platform:xla_cpu_runtime_srcs", "//tensorflow/tsl/framework:xla_cpu_runtime_hdrs", @@ -229,7 +229,6 @@ filegroup( "//tensorflow/compiler/xla:cpu_runtime_srcs", "//tensorflow/compiler/xla/service:custom_call_status_srcs", "//tensorflow/compiler/xla/service/cpu:runtime_srcs", - "//tensorflow/compiler/xla/service/cpu:xla_runtime_runner_srcs", "//tensorflow/core/kernels:xla_cpu_runtime_srcs", "//tensorflow/core/platform:xla_cpu_runtime_srcs", "//tensorflow/tsl/platform:xla_cpu_runtime_srcs", @@ -377,6 +376,7 @@ cc_library( # binary produced by tfcompile. 
"//tensorflow/compiler/xla:cpu_function_runtime", "//tensorflow/compiler/xla:executable_run_options", + "//tensorflow/compiler/xla/runtime:aot_ffi_execution_context", "//tensorflow/compiler/xla/service/cpu:buffer_desc", "//tensorflow/core/platform:types", ], @@ -513,6 +513,7 @@ cc_library( "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", "@com_google_absl//absl/types:variant", diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD index ac616e542a5..69adee9baab 100644 --- a/tensorflow/compiler/tf2xla/kernels/BUILD +++ b/tensorflow/compiler/tf2xla/kernels/BUILD @@ -11,6 +11,7 @@ load( "tf_cc_test", "tf_cuda_library", ) +load("//tensorflow/tsl/platform:build_config_root.bzl", "if_static") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -236,8 +237,10 @@ tf_kernel_library( "//tensorflow/core/kernels:stateless_random_ops_v2_header", "//tensorflow/core/tpu:tpu_defs", "//tensorflow/core/util:overflow", + "//tensorflow/tsl/platform:tensor_float_32_hdr_lib", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:optional", @@ -250,7 +253,7 @@ tf_kernel_library( ] + if_cuda_or_rocm( if_false = [], if_true = [":light_outside_compilation"], - ), + ) + if_static(["//tensorflow/tsl/platform:tensor_float_32_utils"]), ) tf_cuda_library( @@ -341,8 +344,9 @@ cc_library( "//tensorflow/core/framework:bounds_check", "//tensorflow/core/kernels:conv_grad_shape_utils", "//tensorflow/core/platform:statusor", + "//tensorflow/tsl/platform:tensor_float_32_hdr_lib", "@com_google_absl//absl/types:span", - ], + ] + 
if_static(["//tensorflow/tsl/platform:tensor_float_32_utils"]), ) cc_library( @@ -382,13 +386,19 @@ cc_library( srcs = ["xla_call_module_loader.cc"], hdrs = ["xla_call_module_loader.h"], deps = [ + "//tensorflow/compiler/jit:flags", + "//tensorflow/compiler/mlir/tensorflow:dump_mlir_util", + "//tensorflow/compiler/mlir/tensorflow:error_util", "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/mlir_hlo", - "//tensorflow/compiler/xla/pjrt:mlir_to_hlo", + "//tensorflow/compiler/xla/mlir_hlo:mhlo_passes", "//tensorflow/compiler/xla/translate/hlo_to_mhlo:hlo_utils", + "//tensorflow/compiler/xla/translate/mhlo_to_hlo:mlir_hlo_to_hlo", "//tensorflow/tsl/platform:errors", "//tensorflow/tsl/platform:regexp", "//tensorflow/tsl/platform:statusor", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", "@llvm-project//llvm:Support", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", @@ -409,27 +419,33 @@ tf_kernel_library( srcs = ["xla_call_module_op.cc"], deps = [ ":xla_call_module_loader", - "//tensorflow/compiler/tf2xla:common", + "//tensorflow/compiler/mlir/tensorflow:convert_type", + "//tensorflow/compiler/mlir/tensorflow:dump_mlir_util", "//tensorflow/compiler/tf2xla:side_effect_util", - "//tensorflow/compiler/tf2xla:tf2xla_util", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla:xla_helpers", "//tensorflow/compiler/tf2xla:xla_op_registry", "//tensorflow/compiler/tf2xla/ops:xla_ops", - "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:sharding_op_util", - "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", - "//tensorflow/compiler/xla/service:hlo_proto_cc", - "//tensorflow/compiler/xla/translate/hlo_to_mhlo:hlo_utils", + "//tensorflow/compiler/xla/mlir_hlo", + "//tensorflow/compiler/xla/mlir_hlo:mhlo_passes", + 
"//tensorflow/compiler/xla/translate/hlo_to_mhlo:hlo_to_mlir_hlo", + "//tensorflow/compiler/xla/translate/mhlo_to_hlo:type_to_shape", "//tensorflow/core:framework", "//tensorflow/core:protos_all_cc", - "//tensorflow/core/platform:status", "//tensorflow/core/tpu:tpu_defs", + "//tensorflow/tsl/platform:errors", + "//tensorflow/tsl/platform:statusor", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", ], ) diff --git a/tensorflow/compiler/tf2xla/kernels/arg_op.cc b/tensorflow/compiler/tf2xla/kernels/arg_op.cc index 007074d8a9d..8f3081515e5 100644 --- a/tensorflow/compiler/tf2xla/kernels/arg_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/arg_op.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include + #include "tensorflow/compiler/tf2xla/literal_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_compilation_device.h" diff --git a/tensorflow/compiler/tf2xla/kernels/assert_op.cc b/tensorflow/compiler/tf2xla/kernels/assert_op.cc index c40caa8fa10..c1c14d7dcaf 100644 --- a/tensorflow/compiler/tf2xla/kernels/assert_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/assert_op.cc @@ -26,7 +26,7 @@ namespace { class AssertOp : public XlaOpKernel { public: explicit AssertOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} - ~AssertOp() override {} + ~AssertOp() override = default; void Compile(XlaOpKernelContext* ctx) override { static mutex mu(tensorflow::LINKER_INITIALIZED); diff --git a/tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc b/tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc index 095bedcda95..76a91179da6 100644 --- a/tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc +++ 
b/tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include + #include "tensorflow/compiler/tf2xla/lib/util.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" @@ -20,6 +22,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/lib/math.h" #include "tensorflow/compiler/xla/client/lib/matrix.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/tsl/platform/tensor_float_32_utils.h" namespace tensorflow { namespace { @@ -41,10 +44,13 @@ class BatchMatMulOp : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { - auto result = - xla::BatchDot(MaybeConjugate(ctx->Input(0), adj_x_), adj_x_, - MaybeConjugate(ctx->Input(1), adj_y_), adj_y_, - xla::PrecisionConfig::DEFAULT, preferred_element_type_); + xla::PrecisionConfig::Precision precision = + tsl::tensor_float_32_execution_enabled() + ? xla::PrecisionConfig::DEFAULT + : xla::PrecisionConfig::HIGHEST; + auto result = xla::BatchDot(MaybeConjugate(ctx->Input(0), adj_x_), adj_x_, + MaybeConjugate(ctx->Input(1), adj_y_), adj_y_, + precision, preferred_element_type_); ctx->SetOutput(0, result); } diff --git a/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc b/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc index e340342b1c9..18526b68538 100644 --- a/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc @@ -17,6 +17,7 @@ limitations under the License. 
#include #include #include +#include #include "tensorflow/compiler/tf2xla/kernels/relu_op.h" #include "tensorflow/compiler/tf2xla/type_util.h" diff --git a/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc b/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc index 14e9f0b5590..5864da9885e 100644 --- a/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc @@ -13,6 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include +#include +#include + #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" @@ -21,7 +25,7 @@ limitations under the License. namespace tensorflow { namespace { -void BatchToSpace(XlaOpKernelContext* ctx, const xla::XlaOp& input, +void BatchToSpace(XlaOpKernelContext* ctx, const xla::XlaOp input, DataType input_dtype, const TensorShape& input_tensor_shape, absl::Span block_shape, const xla::Literal& crops) { diff --git a/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc b/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc index a8e2755bfe9..60c3077649c 100644 --- a/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc @@ -16,6 +16,8 @@ limitations under the License. // XLA-specific Ops for broadcasting used in gradient // code. +#include + #include "absl/strings/str_join.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" diff --git a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc index a970c873695..0e2afe33de6 100644 --- a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc @@ -15,6 +15,9 @@ limitations under the License. 
// Native XLA implementations of simple binary Ops +#include +#include + #include "tensorflow/compiler/tf2xla/kernels/cwise_ops.h" #include "tensorflow/compiler/tf2xla/lib/broadcast.h" #include "tensorflow/compiler/tf2xla/shape_util.h" diff --git a/tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc b/tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc index eef8e940feb..7db022b280f 100644 --- a/tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include + #include "tensorflow/compiler/tf2xla/lib/broadcast.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" diff --git a/tensorflow/compiler/tf2xla/kernels/bucketize_op.cc b/tensorflow/compiler/tf2xla/kernels/bucketize_op.cc index 5078f8662bd..cce0e332e68 100644 --- a/tensorflow/compiler/tf2xla/kernels/bucketize_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/bucketize_op.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include +#include #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" diff --git a/tensorflow/compiler/tf2xla/kernels/case_op.cc b/tensorflow/compiler/tf2xla/kernels/case_op.cc index 438d454cb21..e1b4ef94208 100644 --- a/tensorflow/compiler/tf2xla/kernels/case_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/case_op.cc @@ -15,6 +15,10 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/kernels/case_op.h" +#include +#include +#include + #include "tensorflow/compiler/tf2xla/kernels/if_while_utils.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/side_effect_util.h" diff --git a/tensorflow/compiler/tf2xla/kernels/case_op.h b/tensorflow/compiler/tf2xla/kernels/case_op.h index cac026d81b6..1aa64228591 100644 --- a/tensorflow/compiler/tf2xla/kernels/case_op.h +++ b/tensorflow/compiler/tf2xla/kernels/case_op.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_TF2XLA_KERNELS_CASE_OP_H_ #include +#include #include #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" diff --git a/tensorflow/compiler/tf2xla/kernels/categorical_op.cc b/tensorflow/compiler/tf2xla/kernels/categorical_op.cc index 534cfc58013..a89d3b5f2be 100644 --- a/tensorflow/compiler/tf2xla/kernels/categorical_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/categorical_op.cc @@ -15,6 +15,8 @@ limitations under the License. // XLA implementations of Categorical op. +#include + #include "tensorflow/compiler/tf2xla/kernels/random_ops_util.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" diff --git a/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc index 20934423141..833efb34649 100644 --- a/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc +++ b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc @@ -17,6 +17,11 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h" +#include +#include +#include +#include + #include "absl/types/span.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" @@ -39,10 +44,24 @@ limitations under the License. 
#include "tensorflow/core/kernels/conv_grad_shape_utils.h" #include "tensorflow/core/util/padding.h" #include "tensorflow/core/util/tensor_format.h" +#include "tensorflow/tsl/platform/tensor_float_32_utils.h" namespace tensorflow { namespace { +xla::PrecisionConfig GetPrecisionConfig() { + xla::PrecisionConfig::Precision precision = + tsl::tensor_float_32_execution_enabled() ? xla::PrecisionConfig::DEFAULT + : xla::PrecisionConfig::HIGHEST; + xla::PrecisionConfig config; + const int num_inputs = 2; + config.mutable_operand_precision()->Reserve(num_inputs); + for (int i = 0; i < num_inputs; ++i) { + config.add_operand_precision(precision); + } + return config; +} + // Returns the expanded size of a filter used for depthwise convolution. // If `shape` is [H, W, ..., M, N] returns [H, W, ..., 1, M*N]. xla::Shape GroupedFilterShapeForDepthwiseConvolution( @@ -187,9 +206,10 @@ StatusOr ConvOpAttrs::Create(int num_spatial_dims, bool depthwise, return attrs; } -StatusOr MakeXlaForwardConvOp( - StringPiece /*type_string*/, xla::XlaOp conv_input, xla::XlaOp filter, - const ConvOpAttrs& attrs, const xla::PrecisionConfig* precision_config) { +StatusOr MakeXlaForwardConvOp(StringPiece /*type_string*/, + xla::XlaOp conv_input, + xla::XlaOp filter, + const ConvOpAttrs& attrs) { TF_RETURN_IF_ERROR(CheckConvAttrs(attrs)); auto* builder = conv_input.builder(); @@ -277,6 +297,7 @@ StatusOr MakeXlaForwardConvOp( rhs_dilation[i], window_strides[i], attrs.padding, &unused_output_size, &padding[i].first, &padding[i].second)); } + xla::PrecisionConfig precision_config = GetPrecisionConfig(); if (padding_type != xla::PaddingType::PADDING_INVALID) { return xla::DynamicConvForward( @@ -284,20 +305,22 @@ StatusOr MakeXlaForwardConvOp( dims, /*feature_group_count=*/attrs.depthwise ? 
in_depth : feature_group_count, - /*batch_group_count=*/1, precision_config, padding_type); + /*batch_group_count=*/1, &precision_config, padding_type); } return xla::ConvGeneralDilated( conv_input, filter, window_strides, padding, lhs_dilation, rhs_dilation, dims, /*feature_group_count=*/attrs.depthwise ? in_depth : feature_group_count, - /*batch_group_count=*/1, precision_config); + /*batch_group_count=*/1, &precision_config); } -StatusOr MakeXlaBackpropInputConvOp( - StringPiece type_string, const xla::Shape& input_shape, xla::XlaOp filter, - xla::XlaOp out_backprop, const ConvOpAttrs& attrs, - const xla::PrecisionConfig* precision_config, xla::XlaOp* input_sizes) { +StatusOr MakeXlaBackpropInputConvOp(StringPiece type_string, + const xla::Shape& input_shape, + xla::XlaOp filter, + xla::XlaOp out_backprop, + const ConvOpAttrs& attrs, + xla::XlaOp* input_sizes) { TF_RETURN_IF_ERROR(CheckConvAttrs(attrs)); int num_dims = attrs.num_spatial_dims + 2; @@ -367,6 +390,7 @@ StatusOr MakeXlaBackpropInputConvOp( lhs_dilation[i] = dims.spatial_dims[i].stride; rhs_dilation[i] = attrs.dilations[dim]; } + xla::PrecisionConfig precision_config = GetPrecisionConfig(); if (feature_group_count != 1 && !attrs.depthwise) { filter = TransposeFilterForGroupConvolutionBackpropInput( @@ -381,7 +405,7 @@ StatusOr MakeXlaBackpropInputConvOp( lhs_dilation, rhs_dilation, dnums, /*feature_group_count=*/ feature_group_count, - /*batch_group_count=*/1, precision_config, padding_type); + /*batch_group_count=*/1, &precision_config, padding_type); } // activation gradients // = gradients (with padding and dilation) mirrored_weights @@ -389,13 +413,14 @@ StatusOr MakeXlaBackpropInputConvOp( padding, lhs_dilation, rhs_dilation, dnums, /*feature_group_count=*/ feature_group_count, - /*batch_group_count=*/1, precision_config); + /*batch_group_count=*/1, &precision_config); } -StatusOr MakeXlaBackpropFilterConvOp( - StringPiece type_string, xla::XlaOp activations, - const xla::Shape& filter_shape, 
xla::XlaOp gradients, - const ConvOpAttrs& attrs, const xla::PrecisionConfig* precision_config) { +StatusOr MakeXlaBackpropFilterConvOp(StringPiece type_string, + xla::XlaOp activations, + const xla::Shape& filter_shape, + xla::XlaOp gradients, + const ConvOpAttrs& attrs) { TF_RETURN_IF_ERROR(CheckConvAttrs(attrs)); auto* builder = activations.builder(); @@ -519,6 +544,7 @@ StatusOr MakeXlaBackpropFilterConvOp( : 0; padding[i] = {pad_before, pad_total - pad_before}; } + xla::PrecisionConfig precision_config = GetPrecisionConfig(); // Besides padding the input, we will also expand output_rows to // expanded_out_rows = (output_rows - 1) * stride + 1 @@ -533,14 +559,14 @@ StatusOr MakeXlaBackpropFilterConvOp( activations, gradients, window_strides, padding, /*lhs_dilation=*/ones, rhs_dilation, dnums, /*feature_group_count=*/1, - /*batch_group_count=*/batch_group_count, precision_config, + /*batch_group_count=*/batch_group_count, &precision_config, padding_type); } else { filter_backprop = xla::ConvGeneralDilated( activations, gradients, window_strides, padding, /*lhs_dilation=*/ones, rhs_dilation, dnums, /*feature_group_count=*/1, - /*batch_group_count=*/batch_group_count, precision_config); + /*batch_group_count=*/batch_group_count, &precision_config); } if (attrs.depthwise) { diff --git a/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h index 7922c6ba821..70c579cde73 100644 --- a/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h +++ b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h @@ -58,20 +58,19 @@ struct ConvOpAttrs { // Creates a new XLA forward or backward convolution with the given inputs and // attributes. 
-StatusOr MakeXlaForwardConvOp( - StringPiece type_string, xla::XlaOp conv_input, xla::XlaOp filter, - const ConvOpAttrs& attrs, - const xla::PrecisionConfig* precision_config = nullptr); +StatusOr MakeXlaForwardConvOp(StringPiece type_string, + xla::XlaOp conv_input, + xla::XlaOp filter, + const ConvOpAttrs& attrs); StatusOr MakeXlaBackpropInputConvOp( StringPiece type_string, const xla::Shape& input_shape, xla::XlaOp filter, xla::XlaOp out_backprop, const ConvOpAttrs& attrs, - const xla::PrecisionConfig* precision_config = nullptr, xla::XlaOp* input_sizes = nullptr); -StatusOr MakeXlaBackpropFilterConvOp( - StringPiece type_string, xla::XlaOp activations, - const xla::Shape& filter_shape, xla::XlaOp gradients, - const ConvOpAttrs& attrs, - const xla::PrecisionConfig* precision_config = nullptr); +StatusOr MakeXlaBackpropFilterConvOp(StringPiece type_string, + xla::XlaOp activations, + const xla::Shape& filter_shape, + xla::XlaOp gradients, + const ConvOpAttrs& attrs); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc index 1d94cf4969f..0f1b53c8a56 100644 --- a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc @@ -113,7 +113,7 @@ class ConvBackpropInputOp : public XlaOpKernel { xla::XlaOp input_sizes = ctx->Input(0); StatusOr in_backprop = MakeXlaBackpropInputConvOp( ctx->op_kernel().type_string(), input_shape, ctx->Input(1), - ctx->Input(2), attrs_, nullptr, &input_sizes); + ctx->Input(2), attrs_, &input_sizes); OP_REQUIRES_OK(ctx, in_backprop.status()); ctx->SetOutput(0, in_backprop.value()); } diff --git a/tensorflow/compiler/tf2xla/kernels/cross_op.cc b/tensorflow/compiler/tf2xla/kernels/cross_op.cc index a1ab899a7c1..923f578900c 100644 --- a/tensorflow/compiler/tf2xla/kernels/cross_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/cross_op.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and 
limitations under the License. ==============================================================================*/ +#include + #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" diff --git a/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc b/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc index 10570c91339..14678369741 100644 --- a/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc @@ -17,6 +17,10 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/kernels/cwise_ops.h" +#include +#include +#include + #include "tensorflow/compiler/tf2xla/lib/broadcast.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" diff --git a/tensorflow/compiler/tf2xla/kernels/cwise_ops.h b/tensorflow/compiler/tf2xla/kernels/cwise_ops.h index 199d3514c22..748ce28777f 100644 --- a/tensorflow/compiler/tf2xla/kernels/cwise_ops.h +++ b/tensorflow/compiler/tf2xla/kernels/cwise_ops.h @@ -18,6 +18,9 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_TF2XLA_KERNELS_CWISE_OPS_H_ #define TENSORFLOW_COMPILER_TF2XLA_KERNELS_CWISE_OPS_H_ +#include +#include + #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/xla_builder.h" @@ -41,7 +44,7 @@ class XlaBinaryOp : public XlaOpKernel { OP_REQUIRES(ctx, lhs == rhs, errors::InvalidArgument("Input types of binary op must match")); } - ~XlaBinaryOp() override {} + ~XlaBinaryOp() override = default; // Implement the (tensor,tensor)->tensor lambda that should be // applied to the inputs. 
The desired computation should be added to diff --git a/tensorflow/compiler/tf2xla/kernels/data_format_ops.cc b/tensorflow/compiler/tf2xla/kernels/data_format_ops.cc index e2b3e3ffcf5..5833480a664 100644 --- a/tensorflow/compiler/tf2xla/kernels/data_format_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/data_format_ops.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include #include diff --git a/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc b/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc index 6ca29c5526f..a8bad158812 100644 --- a/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include + #include "tensorflow/compiler/tf2xla/lib/data_format.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" diff --git a/tensorflow/compiler/tf2xla/kernels/device_index_op.cc b/tensorflow/compiler/tf2xla/kernels/device_index_op.cc index ff058f92cd7..3c4bbe78bfe 100644 --- a/tensorflow/compiler/tf2xla/kernels/device_index_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/device_index_op.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ +#include + #include "absl/container/flat_hash_map.h" #include "absl/strings/string_view.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" diff --git a/tensorflow/compiler/tf2xla/kernels/diag_op.cc b/tensorflow/compiler/tf2xla/kernels/diag_op.cc index 1ad75b65c66..e11844303da 100644 --- a/tensorflow/compiler/tf2xla/kernels/diag_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/diag_op.cc @@ -13,6 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include +#include + #include "tensorflow/compiler/tf2xla/lib/util.h" #include "tensorflow/compiler/tf2xla/mlir_xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/type_util.h" diff --git a/tensorflow/compiler/tf2xla/kernels/dynamic_partition_op.cc b/tensorflow/compiler/tf2xla/kernels/dynamic_partition_op.cc index d6de86a4ef8..635cad36675 100644 --- a/tensorflow/compiler/tf2xla/kernels/dynamic_partition_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/dynamic_partition_op.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include +#include #include #include "absl/algorithm/container.h" diff --git a/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc b/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc index cd03b617158..5a36c175478 100644 --- a/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc @@ -15,6 +15,9 @@ limitations under the License. // XLA-specific dynamic stitch Op. 
+#include +#include + #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" diff --git a/tensorflow/compiler/tf2xla/kernels/empty_op.cc b/tensorflow/compiler/tf2xla/kernels/empty_op.cc index 2b90a2c4d35..348cdf06ec6 100644 --- a/tensorflow/compiler/tf2xla/kernels/empty_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/empty_op.cc @@ -15,6 +15,8 @@ limitations under the License. // XLA-specific Empty Op. +#include + #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" diff --git a/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc b/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc index 55bce65bd8e..49e80226786 100644 --- a/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc @@ -13,6 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include +#include + #include "tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" diff --git a/tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc b/tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc index eb9de507fb0..d437e5476b3 100644 --- a/tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc @@ -13,6 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ +#include +#include + #include "tensorflow/compiler/tf2xla/mlir_xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" @@ -47,7 +50,7 @@ void CpuNudge(const float min, const float max, const float quant_min, // An XLA version of CpuNudge(). void XlaNudge(xla::XlaBuilder* b, const DataType data_type, - const xla::XlaOp& min, const xla::XlaOp& max, + const xla::XlaOp min, const xla::XlaOp max, const float quant_min_value, const float quant_max_value, xla::XlaOp* nudged_min, xla::XlaOp* nudged_max, xla::XlaOp* scale) { @@ -67,11 +70,10 @@ void XlaNudge(xla::XlaBuilder* b, const DataType data_type, *nudged_max = xla::Mul(xla::Sub(quant_max, nudged_zero_point), *scale); } -xla::XlaOp Quantize(xla::XlaBuilder* b, const xla::XlaOp& input, - const DataType data_type, - const xla::XlaOp& nudged_input_min, - const xla::XlaOp& nudged_input_max, - const xla::XlaOp& input_scale) { +xla::XlaOp Quantize(xla::XlaBuilder* b, const xla::XlaOp input, + const DataType data_type, const xla::XlaOp nudged_input_min, + const xla::XlaOp nudged_input_max, + const xla::XlaOp input_scale) { xla::XlaOp one = XlaHelpers::FloatLiteral(b, data_type, 1.0f); xla::XlaOp inv_scale = xla::Div(one, input_scale); xla::XlaOp half = XlaHelpers::FloatLiteral(b, data_type, 0.5f); diff --git a/tensorflow/compiler/tf2xla/kernels/fft_ops.cc b/tensorflow/compiler/tf2xla/kernels/fft_ops.cc index 1368d15a030..48726350c98 100644 --- a/tensorflow/compiler/tf2xla/kernels/fft_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/fft_ops.cc @@ -15,6 +15,9 @@ limitations under the License. // XLA-specific Ops for FFT. 
+#include +#include + #include "tensorflow/compiler/tf2xla/mlir_xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" diff --git a/tensorflow/compiler/tf2xla/kernels/fill_op.cc b/tensorflow/compiler/tf2xla/kernels/fill_op.cc index ebcbadb894e..3c5f41161ce 100644 --- a/tensorflow/compiler/tf2xla/kernels/fill_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/fill_op.cc @@ -15,6 +15,8 @@ limitations under the License. // XLA-specific Fill Op. +#include + #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" diff --git a/tensorflow/compiler/tf2xla/kernels/function_ops.cc b/tensorflow/compiler/tf2xla/kernels/function_ops.cc index 516e3aeaa88..3da7ce96bee 100644 --- a/tensorflow/compiler/tf2xla/kernels/function_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/function_ops.cc @@ -55,7 +55,7 @@ class AlwaysFailOp : public OpKernel { public: explicit AlwaysFailOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} - ~AlwaysFailOp() override {} + ~AlwaysFailOp() override = default; void Compute(OpKernelContext* ctx) override { ctx->CtxFailure(errors::FailedPrecondition( diff --git a/tensorflow/compiler/tf2xla/kernels/gather_op.cc b/tensorflow/compiler/tf2xla/kernels/gather_op.cc index a28a0e9eb26..807e4304e8d 100644 --- a/tensorflow/compiler/tf2xla/kernels/gather_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/gather_op.cc @@ -14,6 +14,8 @@ limitations under the License. 
==============================================================================*/ #include +#include +#include #include "absl/types/optional.h" #include "tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h" diff --git a/tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h b/tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h index 83ab17686e9..54d186dd12d 100644 --- a/tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h +++ b/tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h @@ -43,7 +43,7 @@ Status XlaGather(const xla::XlaOp& input, const TensorShape& input_shape, // the input instead of context->input(0) in order to allow ResourceGather to // handle obtaining the data from the ResourceVariable. Status XlaGatherWithBatchDimsOpImpl(XlaOpKernelContext* context, - const xla::XlaOp input, + xla::XlaOp input, const TensorShape& input_shape, int batch_dims, xla::XlaOp* gather_output); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/gather_scatter_ops.cc b/tensorflow/compiler/tf2xla/kernels/gather_scatter_ops.cc index 3162d197480..a7e47c3850a 100644 --- a/tensorflow/compiler/tf2xla/kernels/gather_scatter_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/gather_scatter_ops.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include + #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/xla_data.pb.h" diff --git a/tensorflow/compiler/tf2xla/kernels/if_op.cc b/tensorflow/compiler/tf2xla/kernels/if_op.cc index 7dd618aaf91..4a55c479ac0 100644 --- a/tensorflow/compiler/tf2xla/kernels/if_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/if_op.cc @@ -15,6 +15,8 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/kernels/if_op.h" +#include + #include "tensorflow/compiler/tf2xla/const_analysis.h" #include "tensorflow/compiler/tf2xla/kernels/if_while_utils.h" #include "tensorflow/compiler/tf2xla/shape_util.h" diff --git a/tensorflow/compiler/tf2xla/kernels/if_op.h b/tensorflow/compiler/tf2xla/kernels/if_op.h index 42f4e9d9e6b..11b196f939e 100644 --- a/tensorflow/compiler/tf2xla/kernels/if_op.h +++ b/tensorflow/compiler/tf2xla/kernels/if_op.h @@ -16,6 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_TF2XLA_KERNELS_IF_OP_H_ #define TENSORFLOW_COMPILER_TF2XLA_KERNELS_IF_OP_H_ +#include + #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/core/framework/attr_value.pb.h" diff --git a/tensorflow/compiler/tf2xla/kernels/if_while_utils.cc b/tensorflow/compiler/tf2xla/kernels/if_while_utils.cc index 15314b0434e..7d3a7c7d176 100644 --- a/tensorflow/compiler/tf2xla/kernels/if_while_utils.cc +++ b/tensorflow/compiler/tf2xla/kernels/if_while_utils.cc @@ -15,6 +15,11 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/kernels/if_while_utils.h" +#include +#include +#include +#include + #include "tensorflow/compiler/tf2xla/const_analysis.h" #include "tensorflow/compiler/tf2xla/literal_util.h" #include "tensorflow/compiler/xla/literal.h" diff --git a/tensorflow/compiler/tf2xla/kernels/if_while_utils.h b/tensorflow/compiler/tf2xla/kernels/if_while_utils.h index 631fedd25f7..15f30975076 100644 --- a/tensorflow/compiler/tf2xla/kernels/if_while_utils.h +++ b/tensorflow/compiler/tf2xla/kernels/if_while_utils.h @@ -16,6 +16,9 @@ limitations under the License. 
#ifndef TENSORFLOW_COMPILER_TF2XLA_KERNELS_IF_WHILE_UTILS_H_ #define TENSORFLOW_COMPILER_TF2XLA_KERNELS_IF_WHILE_UTILS_H_ +#include +#include + #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/core/lib/core/status.h" diff --git a/tensorflow/compiler/tf2xla/kernels/image_ops.cc b/tensorflow/compiler/tf2xla/kernels/image_ops.cc index 4abfb149792..8e8b7134413 100644 --- a/tensorflow/compiler/tf2xla/kernels/image_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/image_ops.cc @@ -13,7 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include +#include #include +#include #include "absl/types/span.h" #include "tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h" diff --git a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc index 48c63d56e4c..6d034b8c6c7 100644 --- a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc @@ -14,7 +14,10 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/compiler/tf2xla/kernels/image_resize_ops.h" +#include +#include #include +#include #include "absl/strings/str_format.h" #include "absl/types/span.h" @@ -213,7 +216,7 @@ xla::XlaOp MakeGeneralResizeKernelInDim(xla::XlaBuilder* builder, } xla::XlaOp BroadcastSpatialDimensions(xla::XlaBuilder* builder, - const xla::XlaOp& input, + const xla::XlaOp input, int32_t spatial_dimensions_offset, absl::Span in_size, absl::Span out_size) { @@ -235,7 +238,7 @@ xla::XlaOp BroadcastSpatialDimensions(xla::XlaBuilder* builder, } xla::XlaOp ResizeUsingDilationAndConvolution( - xla::XlaBuilder* builder, const xla::XlaOp& input, xla::PrimitiveType type, + xla::XlaBuilder* builder, const xla::XlaOp input, xla::PrimitiveType type, const int num_spatial_dims, absl::Span in_size, absl::Span out_size, const int64_t channels, const bool align_corners, bool is_kernel_bilinear) { @@ -381,7 +384,7 @@ xla::XlaOp ResizeUsingDilationAndConvolution( } xla::XlaOp ResizeUsingDilationAndConvolutionGradOp( - xla::XlaBuilder* builder, const xla::XlaOp& grad, xla::PrimitiveType type, + xla::XlaBuilder* builder, const xla::XlaOp grad, xla::PrimitiveType type, const int num_spatial_dims, absl::Span in_size, absl::Span grad_size, const int64_t channels, const bool align_corners, bool is_kernel_bilinear) { diff --git a/tensorflow/compiler/tf2xla/kernels/l2loss_op.cc b/tensorflow/compiler/tf2xla/kernels/l2loss_op.cc index 8e81356ea85..bb28f1ea0aa 100644 --- a/tensorflow/compiler/tf2xla/kernels/l2loss_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/l2loss_op.cc @@ -13,6 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ +#include +#include + #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" diff --git a/tensorflow/compiler/tf2xla/kernels/light_outside_compilation.cc b/tensorflow/compiler/tf2xla/kernels/light_outside_compilation.cc index aaf6a8f89eb..fa8a5ddf8f1 100644 --- a/tensorflow/compiler/tf2xla/kernels/light_outside_compilation.cc +++ b/tensorflow/compiler/tf2xla/kernels/light_outside_compilation.cc @@ -281,7 +281,7 @@ int GetOutputBufferId(int output_num, const TfCallbackData& callback_data) { int64_t BufferSize(const TfCallbackData::BufferDescription& descr) { TensorShape shape; - CHECK(TensorShape::BuildTensorShape(descr.shape(), &shape).ok()); // Crash OK + TF_CHECK_OK(TensorShape::BuildTensorShape(descr.shape(), &shape)); // Crash OK return shape.num_elements() * DataTypeSize(descr.type()); } diff --git a/tensorflow/compiler/tf2xla/kernels/light_outside_compilation.h b/tensorflow/compiler/tf2xla/kernels/light_outside_compilation.h index e4786f0142e..24675783495 100644 --- a/tensorflow/compiler/tf2xla/kernels/light_outside_compilation.h +++ b/tensorflow/compiler/tf2xla/kernels/light_outside_compilation.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_TF2XLA_KERNELS_LIGHT_OUTSIDE_COMPILATION_H_ #include +#include #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/core/platform/status.h" diff --git a/tensorflow/compiler/tf2xla/kernels/listdiff_op.cc b/tensorflow/compiler/tf2xla/kernels/listdiff_op.cc index e741d6dfcff..ec95ceccfe6 100644 --- a/tensorflow/compiler/tf2xla/kernels/listdiff_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/listdiff_op.cc @@ -16,8 +16,11 @@ limitations under the License. // XLA-specific ListDiff Op. This only supports constant DT_INT32 and DT_INT64 // input. 
+#include #include +#include +#include "absl/container/flat_hash_set.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" @@ -74,7 +77,7 @@ class ListDiffOp : public XlaOpKernel { TF_RETURN_IF_ERROR(context->ConstantInputAsIntVector(0, &x_input)); TF_RETURN_IF_ERROR(context->ConstantInputAsIntVector(1, &y_input)); - std::unordered_set y_input_set; + absl::flat_hash_set y_input_set; y_input_set.reserve(y_input.size()); for (auto y : y_input) { y_input_set.insert(y); diff --git a/tensorflow/compiler/tf2xla/kernels/matmul_op.cc b/tensorflow/compiler/tf2xla/kernels/matmul_op.cc index 785b7bea107..86c0d97e97f 100644 --- a/tensorflow/compiler/tf2xla/kernels/matmul_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/matmul_op.cc @@ -15,6 +15,8 @@ limitations under the License. // XLA-specific MatMul Op. +#include + #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" @@ -22,6 +24,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/tsl/platform/tensor_float_32_utils.h" namespace tensorflow { namespace { @@ -88,7 +91,12 @@ class MatMulOp : public XlaOpKernel { b = xla::ConvertElementType(b, xla::F32); } } - ctx->SetOutput(0, xla::BatchDot(a, transpose_a_, b, transpose_b_)); + xla::PrecisionConfig::Precision precision = + tsl::tensor_float_32_execution_enabled() + ? 
xla::PrecisionConfig::DEFAULT + : xla::PrecisionConfig::HIGHEST; + ctx->SetOutput(0, + xla::BatchDot(a, transpose_a_, b, transpose_b_, precision)); } private: diff --git a/tensorflow/compiler/tf2xla/kernels/matrix_diag_ops.cc b/tensorflow/compiler/tf2xla/kernels/matrix_diag_ops.cc index 2fcbf22a5f0..e9cb7b60db9 100644 --- a/tensorflow/compiler/tf2xla/kernels/matrix_diag_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/matrix_diag_ops.cc @@ -13,6 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include +#include +#include +#include + #include "tensorflow/compiler/tf2xla/mlir_xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" diff --git a/tensorflow/compiler/tf2xla/kernels/matrix_triangular_solve_op.cc b/tensorflow/compiler/tf2xla/kernels/matrix_triangular_solve_op.cc index b221f55655e..7edb6fbf3b3 100644 --- a/tensorflow/compiler/tf2xla/kernels/matrix_triangular_solve_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/matrix_triangular_solve_op.cc @@ -13,6 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include +#include + #include "tensorflow/compiler/tf2xla/lib/broadcast.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" diff --git a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc index 64351c6a741..83e8697d8c0 100644 --- a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc @@ -15,7 +15,9 @@ limitations under the License. // XLA specific pooling ops. 
+#include #include +#include #include "tensorflow/compiler/tf2xla/mlir_xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/shape_util.h" diff --git a/tensorflow/compiler/tf2xla/kernels/quantize_and_dequantize_op.cc b/tensorflow/compiler/tf2xla/kernels/quantize_and_dequantize_op.cc index ad1554312f2..f36b07cc93c 100644 --- a/tensorflow/compiler/tf2xla/kernels/quantize_and_dequantize_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/quantize_and_dequantize_op.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include +#include #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" diff --git a/tensorflow/compiler/tf2xla/kernels/random_ops.cc b/tensorflow/compiler/tf2xla/kernels/random_ops.cc index 176c83a8375..2a66980314d 100644 --- a/tensorflow/compiler/tf2xla/kernels/random_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/random_ops.cc @@ -17,6 +17,8 @@ limitations under the License. // TODO(misard,phawkins): handle random number generator seeds/states correctly. // TODO(misard,phawkins): add tests. +#include + #include "tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h" #include "tensorflow/compiler/tf2xla/lib/broadcast.h" #include "tensorflow/compiler/tf2xla/lib/random.h" diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc b/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc index 662052dac29..fef848224c1 100644 --- a/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc @@ -16,6 +16,9 @@ limitations under the License. // XLA-specific reduction Ops. 
#include "tensorflow/compiler/tf2xla/kernels/reduction_ops.h" + +#include + #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops.h b/tensorflow/compiler/tf2xla/kernels/reduction_ops.h index 8141dde7f2c..42631ae4b5b 100644 --- a/tensorflow/compiler/tf2xla/kernels/reduction_ops.h +++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops.h @@ -18,6 +18,8 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_TF2XLA_KERNELS_REDUCTION_OPS_H_ #define TENSORFLOW_COMPILER_TF2XLA_KERNELS_REDUCTION_OPS_H_ +#include + #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" @@ -34,7 +36,7 @@ namespace tensorflow { class XlaReductionOp : public XlaOpKernel { public: XlaReductionOp(OpKernelConstruction* ctx, DataType reduction_type); - ~XlaReductionOp() override {} + ~XlaReductionOp() override = default; // Return the base case for the reduction. virtual xla::XlaOp InitialValue(xla::XlaBuilder* builder) = 0; diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc b/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc index 95a2a454210..1194a2e0c70 100644 --- a/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc +++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc @@ -15,6 +15,8 @@ limitations under the License. // XLA-specific reduction Ops. 
+#include + #include "absl/strings/str_join.h" #include "tensorflow/compiler/tf2xla/kernels/reduction_ops.h" #include "tensorflow/compiler/tf2xla/type_util.h" diff --git a/tensorflow/compiler/tf2xla/kernels/reshape_op.cc b/tensorflow/compiler/tf2xla/kernels/reshape_op.cc index e6b21219894..fec4b5fea61 100644 --- a/tensorflow/compiler/tf2xla/kernels/reshape_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/reshape_op.cc @@ -15,6 +15,8 @@ limitations under the License. // XLA-specific reshape Op. +#include + #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" diff --git a/tensorflow/compiler/tf2xla/kernels/reverse_op.cc b/tensorflow/compiler/tf2xla/kernels/reverse_op.cc index 72932ea72ec..0df51930ffc 100644 --- a/tensorflow/compiler/tf2xla/kernels/reverse_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/reverse_op.cc @@ -15,6 +15,8 @@ limitations under the License. // XLA-specific reverse Op. +#include + #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" diff --git a/tensorflow/compiler/tf2xla/kernels/roll_op.cc b/tensorflow/compiler/tf2xla/kernels/roll_op.cc index ae0827391d8..b9b9939d0ee 100644 --- a/tensorflow/compiler/tf2xla/kernels/roll_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/roll_op.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ +#include + #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/lib/slicing.h" diff --git a/tensorflow/compiler/tf2xla/kernels/scan_ops.cc b/tensorflow/compiler/tf2xla/kernels/scan_ops.cc index 7be090adb4a..12bdc30a950 100644 --- a/tensorflow/compiler/tf2xla/kernels/scan_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/scan_ops.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include +#include #include #include "tensorflow/compiler/tf2xla/shape_util.h" diff --git a/tensorflow/compiler/tf2xla/kernels/scatter_nd_op.cc b/tensorflow/compiler/tf2xla/kernels/scatter_nd_op.cc index 1812ddab2b6..0e41300351f 100644 --- a/tensorflow/compiler/tf2xla/kernels/scatter_nd_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/scatter_nd_op.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include + #include "tensorflow/compiler/tf2xla/lib/scatter.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" diff --git a/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc b/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc index 46b0aab40dc..9b33307ac00 100644 --- a/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ +#include + #include "tensorflow/compiler/tf2xla/lib/scatter.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" diff --git a/tensorflow/compiler/tf2xla/kernels/select_op.cc b/tensorflow/compiler/tf2xla/kernels/select_op.cc index 446d7f3d7aa..c1ca1fc67d9 100644 --- a/tensorflow/compiler/tf2xla/kernels/select_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/select_op.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include +#include #include "tensorflow/compiler/tf2xla/lib/broadcast.h" #include "tensorflow/compiler/tf2xla/mlir_xla_op_kernel.h" diff --git a/tensorflow/compiler/tf2xla/kernels/shape_op.cc b/tensorflow/compiler/tf2xla/kernels/shape_op.cc index 60b1f5eea3a..39ca2030045 100644 --- a/tensorflow/compiler/tf2xla/kernels/shape_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/shape_op.cc @@ -15,6 +15,11 @@ limitations under the License. // XLA-specific Shape Ops. +#include +#include +#include + +#include "absl/container/flat_hash_set.h" #include "absl/strings/str_format.h" #include "tensorflow/compiler/tf2xla/kernels/shape_util.h" #include "tensorflow/compiler/tf2xla/kernels/tensor_list_utils.h" @@ -312,7 +317,7 @@ class SqueezeOp : public XlaOpKernel { xla::Shape shape = input_shape.value(); int64_t rank = shape.rank(); - std::unordered_set wrapped_squeeze_dims; + absl::flat_hash_set wrapped_squeeze_dims; wrapped_squeeze_dims.reserve(squeeze_dims_.size()); std::vector new_shape; // Validate squeeze dims against the input. 
@@ -360,7 +365,7 @@ class SqueezeOp : public XlaOpKernel { } private: - std::unordered_set squeeze_dims_; + absl::flat_hash_set squeeze_dims_; }; REGISTER_XLA_OP(Name("Squeeze"), SqueezeOp); diff --git a/tensorflow/compiler/tf2xla/kernels/sharding_op.cc b/tensorflow/compiler/tf2xla/kernels/sharding_op.cc index fdfef3a5355..14cdb4ab1e3 100644 --- a/tensorflow/compiler/tf2xla/kernels/sharding_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/sharding_op.cc @@ -13,6 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include +#include + #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" diff --git a/tensorflow/compiler/tf2xla/kernels/slice_op.cc b/tensorflow/compiler/tf2xla/kernels/slice_op.cc index 2c78ad74c16..0f930046c31 100644 --- a/tensorflow/compiler/tf2xla/kernels/slice_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/slice_op.cc @@ -15,6 +15,8 @@ limitations under the License. // XLA-specific Slice Op. +#include + #include "absl/types/span.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" diff --git a/tensorflow/compiler/tf2xla/kernels/softmax_op.cc b/tensorflow/compiler/tf2xla/kernels/softmax_op.cc index 520c73ffbbf..44c332c5eb5 100644 --- a/tensorflow/compiler/tf2xla/kernels/softmax_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/softmax_op.cc @@ -15,6 +15,9 @@ limitations under the License. // XLA-specific Ops for softmax. 
+#include +#include + #include "absl/strings/match.h" #include "tensorflow/compiler/tf2xla/lib/broadcast.h" #include "tensorflow/compiler/tf2xla/mlir_xla_op_kernel.h" diff --git a/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc b/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc index d6e38f1309f..39ce1057139 100644 --- a/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc @@ -13,6 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include +#include +#include + #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" @@ -22,7 +26,7 @@ limitations under the License. namespace tensorflow { namespace { -void SpaceToBatch(XlaOpKernelContext* ctx, const xla::XlaOp& input, +void SpaceToBatch(XlaOpKernelContext* ctx, const xla::XlaOp input, DataType input_dtype, const TensorShape& input_tensor_shape, absl::Span block_shape, const xla::Literal& paddings) { diff --git a/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc b/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc index 378aef0205d..be3d8e01b35 100644 --- a/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ +#include + #include "tensorflow/compiler/tf2xla/lib/data_format.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" diff --git a/tensorflow/compiler/tf2xla/kernels/sparse_to_dense_op.cc b/tensorflow/compiler/tf2xla/kernels/sparse_to_dense_op.cc index 4f3c7b79861..5d6ccd54a2d 100644 --- a/tensorflow/compiler/tf2xla/kernels/sparse_to_dense_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/sparse_to_dense_op.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include + #include "tensorflow/compiler/tf2xla/lib/scatter.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" diff --git a/tensorflow/compiler/tf2xla/kernels/split_op.cc b/tensorflow/compiler/tf2xla/kernels/split_op.cc index ae5150f14f9..4871a89b0af 100644 --- a/tensorflow/compiler/tf2xla/kernels/split_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/split_op.cc @@ -15,6 +15,8 @@ limitations under the License. // XLA-specific Ops for split. +#include + #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" diff --git a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc index ad33157ed78..a64c0d54a7b 100644 --- a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc @@ -14,6 +14,7 @@ limitations under the License. 
==============================================================================*/ #include +#include #include "tensorflow/compiler/tf2xla/kernels/random_ops_util.h" #include "tensorflow/compiler/tf2xla/lib/broadcast.h" diff --git a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops_v2.cc b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops_v2.cc index b5445aa3e90..76cd46e1893 100644 --- a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops_v2.cc +++ b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops_v2.cc @@ -16,6 +16,8 @@ limitations under the License. #include "tensorflow/core/kernels/stateless_random_ops_v2.h" #include +#include +#include #include "tensorflow/compiler/tf2xla/kernels/random_ops_util.h" #include "tensorflow/compiler/tf2xla/kernels/rng_converter_utils.h" diff --git a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc index a057b84c28b..d6e7f404fb9 100644 --- a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/core/util/strided_slice_op.h" +#include #include #include "absl/algorithm/container.h" diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc index a0f8f62cd57..fa70efbd906 100644 --- a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc @@ -118,8 +118,8 @@ Status GetTensorArrayShape(const XlaResource* resource, // Like XlaBuilder::DynamicUpdateSlice, but adds 'update' to the // relevant slice of 'operand'. 
-xla::XlaOp DynamicAddSlice(xla::XlaBuilder* builder, const xla::XlaOp& operand, - const xla::XlaOp& update, +xla::XlaOp DynamicAddSlice(xla::XlaBuilder* builder, const xla::XlaOp operand, + const xla::XlaOp update, absl::Span update_dims, absl::Span start_indices, DataType dtype) { diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_list_utils.cc b/tensorflow/compiler/tf2xla/kernels/tensor_list_utils.cc index f7e0f350a2f..544fc1d14ba 100644 --- a/tensorflow/compiler/tf2xla/kernels/tensor_list_utils.cc +++ b/tensorflow/compiler/tf2xla/kernels/tensor_list_utils.cc @@ -15,6 +15,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/kernels/tensor_list_utils.h" +#include + #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_list_utils.h b/tensorflow/compiler/tf2xla/kernels/tensor_list_utils.h index f31cfd9eafc..d422bb63afd 100644 --- a/tensorflow/compiler/tf2xla/kernels/tensor_list_utils.h +++ b/tensorflow/compiler/tf2xla/kernels/tensor_list_utils.h @@ -16,6 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_TF2XLA_KERNELS_TENSOR_LIST_UTILS_H_ #define TENSORFLOW_COMPILER_TF2XLA_KERNELS_TENSOR_LIST_UTILS_H_ +#include + #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/tensor_shape.h" diff --git a/tensorflow/compiler/tf2xla/kernels/transpose_op.cc b/tensorflow/compiler/tf2xla/kernels/transpose_op.cc index d873396a828..228d1c5bbe3 100644 --- a/tensorflow/compiler/tf2xla/kernels/transpose_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/transpose_op.cc @@ -18,6 +18,8 @@ limitations under the License. // handles all transposes, while Eigen needs a restricted DoTranspose // helper. 
+#include + #include "tensorflow/compiler/tf2xla/lib/scatter.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" diff --git a/tensorflow/compiler/tf2xla/kernels/unary_ops_composition.cc b/tensorflow/compiler/tf2xla/kernels/unary_ops_composition.cc index 3c992ee8407..8da1b6eb6eb 100644 --- a/tensorflow/compiler/tf2xla/kernels/unary_ops_composition.cc +++ b/tensorflow/compiler/tf2xla/kernels/unary_ops_composition.cc @@ -13,6 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include +#include +#include + #include "absl/container/flat_hash_map.h" #include "absl/strings/string_view.h" #include "tensorflow/compiler/tf2xla/kernels/cwise_ops.h" diff --git a/tensorflow/compiler/tf2xla/kernels/unique_op.cc b/tensorflow/compiler/tf2xla/kernels/unique_op.cc index c4389baf4a8..39fd56cd5ce 100644 --- a/tensorflow/compiler/tf2xla/kernels/unique_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/unique_op.cc @@ -15,6 +15,8 @@ limitations under the License. #include +#include +#include #include #include diff --git a/tensorflow/compiler/tf2xla/kernels/variable_ops.cc b/tensorflow/compiler/tf2xla/kernels/variable_ops.cc index 6fe7a44f32a..11b67ba9e54 100644 --- a/tensorflow/compiler/tf2xla/kernels/variable_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/variable_ops.cc @@ -13,6 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ +#include +#include + #include "tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h" #include "tensorflow/compiler/tf2xla/kernels/shape_util.h" #include "tensorflow/compiler/tf2xla/lib/scatter.h" @@ -208,7 +211,7 @@ class ResourceScatterAddOp : public ResourceScatterOp { : ResourceScatterOp(context, /*indices_are_vectors=*/false, Combine) {} private: - static xla::XlaOp Combine(const xla::XlaOp& x, const xla::XlaOp& y, + static xla::XlaOp Combine(const xla::XlaOp x, const xla::XlaOp y, xla::XlaBuilder* builder) { return xla::Add(x, y); } @@ -221,7 +224,7 @@ class ResourceScatterSubOp : public ResourceScatterOp { : ResourceScatterOp(context, /*indices_are_vectors=*/false, Combine) {} private: - static xla::XlaOp Combine(const xla::XlaOp& x, const xla::XlaOp& y, + static xla::XlaOp Combine(const xla::XlaOp x, const xla::XlaOp y, xla::XlaBuilder* builder) { return xla::Sub(x, y); } @@ -234,7 +237,7 @@ class ResourceScatterMulOp : public ResourceScatterOp { : ResourceScatterOp(context, /*indices_are_vectors=*/false, Combine) {} private: - static xla::XlaOp Combine(const xla::XlaOp& x, const xla::XlaOp& y, + static xla::XlaOp Combine(const xla::XlaOp x, const xla::XlaOp y, xla::XlaBuilder* builder) { return xla::Mul(x, y); } @@ -247,7 +250,7 @@ class ResourceScatterDivOp : public ResourceScatterOp { : ResourceScatterOp(context, /*indices_are_vectors=*/false, Combine) {} private: - static xla::XlaOp Combine(const xla::XlaOp& x, const xla::XlaOp& y, + static xla::XlaOp Combine(const xla::XlaOp x, const xla::XlaOp y, xla::XlaBuilder* builder) { return xla::Div(x, y); } @@ -260,7 +263,7 @@ class ResourceScatterMinOp : public ResourceScatterOp { : ResourceScatterOp(context, /*indices_are_vectors=*/false, Combine) {} private: - static xla::XlaOp Combine(const xla::XlaOp& x, const xla::XlaOp& y, + static xla::XlaOp Combine(const xla::XlaOp x, const xla::XlaOp y, xla::XlaBuilder* builder) { 
return xla::Min(x, y); } @@ -273,7 +276,7 @@ class ResourceScatterMaxOp : public ResourceScatterOp { : ResourceScatterOp(context, /*indices_are_vectors=*/false, Combine) {} private: - static xla::XlaOp Combine(const xla::XlaOp& x, const xla::XlaOp& y, + static xla::XlaOp Combine(const xla::XlaOp x, const xla::XlaOp y, xla::XlaBuilder* builder) { return xla::Max(x, y); } @@ -303,7 +306,7 @@ class ResourceScatterNdAddOp : public ResourceScatterOp { /*combiner=*/Combine) {} private: - static xla::XlaOp Combine(const xla::XlaOp& x, const xla::XlaOp& y, + static xla::XlaOp Combine(const xla::XlaOp x, const xla::XlaOp y, xla::XlaBuilder* builder) { return xla::Add(x, y); } @@ -317,7 +320,7 @@ class ResourceScatterNdSubOp : public ResourceScatterOp { /*combiner=*/Combine) {} private: - static xla::XlaOp Combine(const xla::XlaOp& x, const xla::XlaOp& y, + static xla::XlaOp Combine(const xla::XlaOp x, const xla::XlaOp y, xla::XlaBuilder* builder) { return xla::Sub(x, y); } diff --git a/tensorflow/compiler/tf2xla/kernels/where_op.cc b/tensorflow/compiler/tf2xla/kernels/where_op.cc index 107cf72eb5a..44ef36c063b 100644 --- a/tensorflow/compiler/tf2xla/kernels/where_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/where_op.cc @@ -13,6 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include +#include +#include + #include "tensorflow/compiler/tf2xla/literal_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" diff --git a/tensorflow/compiler/tf2xla/kernels/while_op.cc b/tensorflow/compiler/tf2xla/kernels/while_op.cc index 53b02fb5416..53a5e3c0525 100644 --- a/tensorflow/compiler/tf2xla/kernels/while_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/while_op.cc @@ -15,6 +15,10 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/kernels/while_op.h" +#include +#include +#include + #include "absl/strings/str_split.h" #include "tensorflow/compiler/tf2xla/const_analysis.h" #include "tensorflow/compiler/tf2xla/kernels/if_while_utils.h" @@ -246,7 +250,7 @@ StatusOr BuildWrappedBody( xla::XlaOp BuildWhile(XlaOpKernelContext* ctx, const xla::XlaComputation& wrapped_cond, const xla::XlaComputation& wrapped_body, - const xla::XlaOp& initial_values, + const xla::XlaOp initial_values, const std::vector& input_mapping, const std::vector& compile_time_const_arg_indices, int num_compile_time_const_args, @@ -347,12 +351,12 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { body_options.return_updated_values_for_all_resources = true; body_options.is_entry_computation = false; body_options.add_token_input_output = has_token_input_output_; - XlaCompiler::CompilationResult body; + auto body = std::make_unique(); OP_REQUIRES_OK(ctx, compiler->CompileFunction(body_options, body_name_attr_, - arguments, &body)); + arguments, body.get())); OP_REQUIRES_OK( ctx, ctx->xla_context()->RecordCollectiveInfoFromNestedCompilationResult( - body)); + *body.get())); // We must use a static shape for parameters to an XLA compilation. However, // we may not know the shape of a resource if it is first @@ -378,8 +382,8 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { << has_uninitialized_tensor_lists; // Initializes any uninitialized resource with zero values of the // shape determined by the first compilation. 
- for (int i = 0; i < body.resource_updates.size(); ++i) { - const XlaCompiler::ResourceUpdate& update = body.resource_updates[i]; + for (int i = 0; i < body->resource_updates.size(); ++i) { + const XlaCompiler::ResourceUpdate& update = body->resource_updates[i]; XlaResource* resource; OP_REQUIRES_OK(ctx, ctx->GetResourceInput(update.input_index, &resource)); @@ -416,7 +420,7 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { // Set the shape of any uninitialized TensorLists to the shape determined by // the first compilation. Note that, unlike resources, we do not initialize // the input list with zeros here, that is done later. - xla::Shape body_output_shape = body.xla_output_shape; + xla::Shape body_output_shape = body->xla_output_shape; OP_REQUIRES(ctx, body_output_shape.IsTuple(), errors::FailedPrecondition( "xla_output_shape of while body must be a tuple.")); @@ -431,9 +435,9 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { // Recompile the body with the "correct" resource shapes. 
VLOG(1) << "Recompiling body with corrected resource shapes"; - body = {}; + *body = {}; OP_REQUIRES_OK(ctx, compiler->CompileFunction(body_options, body_name_attr_, - arguments, &body)); + arguments, body.get())); } VLOG(1) << "Compiling condition"; @@ -446,9 +450,9 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { OP_REQUIRES_OK(ctx, compiler->CompileFunction(cond_options, cond_name_attr_, arguments, &cond)); - OP_REQUIRES(ctx, body.xla_input_shapes.size() == 1, + OP_REQUIRES(ctx, body->xla_input_shapes.size() == 1, errors::FailedPrecondition("Expected one input shape")); - xla::Shape body_input_shape = body.xla_input_shapes[0]; + xla::Shape body_input_shape = body->xla_input_shapes[0]; OP_REQUIRES(ctx, body_input_shape.IsTuple(), errors::FailedPrecondition("Expected tuple shape")); OP_REQUIRES(ctx, cond.xla_input_shapes.size() == 1, @@ -458,7 +462,7 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { errors::FailedPrecondition("Expected tuple shape")); VLOG(2) << "Body shape: " << xla::ShapeUtil::HumanString(body_input_shape) - << " -> " << xla::ShapeUtil::HumanString(body.xla_output_shape); + << " -> " << xla::ShapeUtil::HumanString(body->xla_output_shape); VLOG(2) << "Cond shape: " << xla::ShapeUtil::HumanString(cond_input_shape) << " -> " << xla::ShapeUtil::HumanString(cond.xla_output_shape); @@ -473,7 +477,7 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { // args (which are pruned from the body outputs in body_wapper) matches the // shape of the inputs. 
OP_REQUIRES_OK(ctx, VerifyBodyInputAndOutputShapeMatch( - ctx, compile_time_const_arg_indices, body, + ctx, compile_time_const_arg_indices, *body.get(), has_token_input_output_)); xla::Shape expected_cond_output_shape_without_side_effect = @@ -494,10 +498,10 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { "(pred[], token[]), got: ", xla::ShapeUtil::HumanString(cond.xla_output_shape))); - int num_inputs = body.input_mapping.size(); + int num_inputs = body->input_mapping.size(); std::vector inputs(num_inputs); for (int i = 0; i < num_inputs; ++i) { - int input_num = body.input_mapping[i]; + int input_num = body->input_mapping[i]; if (has_token_input_output_ && i == num_inputs - 1) { // Set token input for this "while" op. std::vector token_inputs; @@ -577,14 +581,14 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { // Remove compile time const args from the list of body outputs. StatusOr body_result = - BuildWrappedBody(ctx, body, compile_time_const_arg_indices, + BuildWrappedBody(ctx, *body.get(), compile_time_const_arg_indices, num_compile_time_const_args, has_token_input_output_); OP_REQUIRES_OK(ctx, body_result.status()); xla::XlaComputation wrapped_body = std::move(body_result.value()); // Builds the While op and pads its output with the compile time const args. xla::XlaOp while_result = - BuildWhile(ctx, wrapped_cond, wrapped_body, init, body.input_mapping, + BuildWhile(ctx, wrapped_cond, wrapped_body, init, body->input_mapping, compile_time_const_arg_indices, num_compile_time_const_args, has_token_input_output_); @@ -617,8 +621,8 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { } // Updates the values of any resource variables modified by the loop. 
- for (int i = 0; i < body.resource_updates.size(); ++i) { - const XlaCompiler::ResourceUpdate& update = body.resource_updates[i]; + for (int i = 0; i < body->resource_updates.size(); ++i) { + const XlaCompiler::ResourceUpdate& update = body->resource_updates[i]; XlaResource* resource; OP_REQUIRES_OK(ctx, ctx->GetResourceInput(update.input_index, &resource)); if (update.modified) { diff --git a/tensorflow/compiler/tf2xla/kernels/while_op.h b/tensorflow/compiler/tf2xla/kernels/while_op.h index 0e259b3bac0..2f0b6c3a7f4 100644 --- a/tensorflow/compiler/tf2xla/kernels/while_op.h +++ b/tensorflow/compiler/tf2xla/kernels/while_op.h @@ -16,6 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_TF2XLA_KERNELS_WHILE_OP_H_ #define TENSORFLOW_COMPILER_TF2XLA_KERNELS_WHILE_OP_H_ +#include + #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/core/framework/attr_value.pb.h" diff --git a/tensorflow/compiler/tf2xla/kernels/xla_broadcast_helper_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_broadcast_helper_op.cc index 44495e77e33..2beabc31f34 100644 --- a/tensorflow/compiler/tf2xla/kernels/xla_broadcast_helper_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/xla_broadcast_helper_op.cc @@ -13,6 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include +#include + #include "absl/algorithm/container.h" #include "absl/strings/str_join.h" #include "tensorflow/compiler/tf2xla/shape_util.h" diff --git a/tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.cc b/tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.cc index c8a82fbfa28..e265184ad2d 100644 --- a/tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.cc +++ b/tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.cc @@ -15,11 +15,15 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.h" +#include #include #include #include #include +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" #include "llvm/ADT/SmallVector.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/Block.h" // from @llvm-project @@ -44,10 +48,14 @@ limitations under the License. #include "stablehlo/dialect/StablehloOps.h" // from @stablehlo #include "stablehlo/dialect/VhloOps.h" // from @stablehlo #include "stablehlo/transforms/Passes.h" // from @stablehlo +#include "tensorflow/compiler/jit/flags.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/mlir_hlo/mhlo/IR/hlo_ops.h" -#include "tensorflow/compiler/xla/pjrt/mlir_to_hlo.h" +#include "tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/passes.h" #include "tensorflow/compiler/xla/translate/hlo_to_mhlo/hlo_utils.h" +#include "tensorflow/compiler/xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.h" #include "tensorflow/tsl/platform/errors.h" #include "tensorflow/tsl/platform/regexp.h" #include "tensorflow/tsl/platform/statusor.h" @@ -60,20 +68,37 @@ namespace { // version in the constructor in xla.py. // Version 1 used MHLO & CHLO, not supported anymore. // Version 2 supports StableHLO & CHLO. From 10/2022. -const int VERSION_START_STABLE_HLO = 2; +constexpr int VERSION_START_STABLE_HLO = 2; // Version 3 supports platform checking and multiple platforms. From 02/2023. -const int VERSION_START_PLATFORMS = 3; +constexpr int VERSION_START_PLATFORMS = 3; // Version 4 supports StableHLO with compatibility guarantees. -// Used from 03/2023. -const int VERSION_START_STABLE_HLO_COMPATIBILITY = 4; -// Version 5 add support to stablehlo.custom_call for host call tf graph. -// Used from 04/2023. 
-const int VERSION_SUPPORT_CUSTOM_CALL = 5; -const int VERSION_MINIMUM_SUPPORTED = VERSION_START_STABLE_HLO; -const int VERSION_MAXIMUM_SUPPORTED = VERSION_SUPPORT_CUSTOM_CALL; +// Used in jax2tf from March 15, 2023 (cl/516885716). Starting with +// March 28th, 2023 we stopped using dim_args_spec (cl/520033493). +// TODO(b/283439649): Remove support for dim_args_spec. +constexpr int VERSION_START_STABLE_HLO_COMPATIBILITY = 4; +// Version 5 adds support for call_tf_graph. This does not change the semantics +// of the op, but it allows the `function_list` attribute. +// Used in jax2tf from May 3rd, 2023 (cl/529106145). +constexpr int VERSION_START_SUPPORT_CALL_TF_GRAPH = 5; +// Version 6 adds support for the `disabled_checks` attribute. This version +// mandates a non-empty `platforms` attribute. +// Used in jax2tf since June 2023. +constexpr int VERSION_START_SUPPORT_DISABLED_CHECKS = 6; +constexpr int VERSION_MINIMUM_SUPPORTED = + VERSION_START_STABLE_HLO_COMPATIBILITY; + +constexpr int VERSION_MAXIMUM_SUPPORTED = VERSION_START_SUPPORT_DISABLED_CHECKS; + +constexpr absl::string_view DISABLED_CHECK_PLATFORM = "platform"; + +bool IsPlatformCheckDisabled(absl::Span disabled_checks) { + return std::find(disabled_checks.begin(), disabled_checks.end(), + DISABLED_CHECK_PLATFORM) != disabled_checks.end(); +} // Computes a dimension value from the dim_arg specification. // The specification is of the form ".". +// TODO(b/283439649): Remove support for dim_args_spec. 
tsl::StatusOr ComputeDimensionValue( int version, std::string dim_arg_spec, std::vector arguments, mlir::OpBuilder op_builder, mlir::Type dim_arg_type) { @@ -81,27 +106,27 @@ tsl::StatusOr ComputeDimensionValue( int arg_idx, arg_axis_idx; if (!RE2::FullMatch(dim_arg_spec, *dim_arg_spec_re, &arg_idx, &arg_axis_idx)) { - return tsl::errors::InvalidArgument("Syntax error in dim_args_spec '", - dim_arg_spec, "'"); + return absl::InvalidArgumentError( + absl::StrCat("Syntax error in dim_args_spec '", dim_arg_spec, "'")); } if (arg_idx < 0 || arg_idx >= arguments.size()) { - return tsl::errors::InvalidArgument( + return absl::InvalidArgumentError(absl::StrCat( "Invalid argument index ", arg_idx, " when the number of non-dimension arguments is ", arguments.size(), - " in dim_arg_spec '", dim_arg_spec, "'"); + " in dim_arg_spec '", dim_arg_spec, "'")); } mlir::RankedTensorType arg_type = arguments[arg_idx].getType().dyn_cast(); if (!arg_type) { - return tsl::errors::InvalidArgument( - "Argument ", arg_idx, " referenced in dim_arg_spec '", dim_arg_spec, - "' does not have a RankedTensorType"); + return absl::InvalidArgumentError( + absl::StrCat("Argument ", arg_idx, " referenced in dim_arg_spec '", + dim_arg_spec, "' does not have a RankedTensorType")); } if (arg_axis_idx < 0 || arg_axis_idx >= arg_type.getShape().size()) { - return tsl::errors::InvalidArgument( + return absl::InvalidArgumentError(absl::StrCat( "Invalid axis index ", arg_axis_idx, " when the rank of non-dimension argument ", arg_idx, " is ", - arg_type.getShape().size(), " in dim_arg_spec '", dim_arg_spec, "'"); + arg_type.getShape().size(), " in dim_arg_spec '", dim_arg_spec, "'")); } mlir::Value val; mlir::Type get_dim_type = @@ -120,27 +145,14 @@ tsl::StatusOr ComputeDimensionValue( tsl::StatusOr> XlaCallModuleLoader::Create( mlir::MLIRContext *context, int version, std::string module_str, - std::vector dim_args_spec, int platform_index) { - if (version < VERSION_MINIMUM_SUPPORTED) { - return 
tsl::errors::InvalidArgument( - "XlaCallModuleOp with version ", version, - " is not supported anymore. Must be >= ", VERSION_MINIMUM_SUPPORTED); - } - if (version > VERSION_MAXIMUM_SUPPORTED) { - return tsl::errors::InvalidArgument( - "XlaCallModuleOp with version ", version, - " is not supported by this build. Must be <= ", - VERSION_MAXIMUM_SUPPORTED); - } - - if (version < VERSION_START_PLATFORMS) { - platform_index = -1; - } - + std::vector dim_args_spec, + std::vector disabled_checks, + std::vector platforms, std::string loading_platform) { std::unique_ptr loader(new XlaCallModuleLoader); TF_RETURN_IF_ERROR(loader->LoadAndPreprocessModule( context, version, std::move(module_str), std::move(dim_args_spec), - platform_index)); + std::move(disabled_checks), std::move(platforms), + std::move(loading_platform))); return loader; } @@ -191,18 +203,18 @@ tsl::Status XlaCallModuleLoader::AddMainWrapper() { mlir::func::FuncOp orig_main = module_->lookupSymbol("main"); if (!orig_main) { - return tsl::errors::InvalidArgument("Cannot find 'main' in module"); + return absl::InvalidArgumentError("Cannot find 'main' in module"); } int nr_platform_args = 0; if (platform_index_ >= 0) { nr_platform_args = 1; } if (orig_main.getNumArguments() <= nr_platform_args + nr_dim_args) { - return tsl::errors::InvalidArgument( - "The module should have ", nr_platform_args, - " platform index arguments and ", nr_dim_args, - " dimension arguments, but it ", "has only ", - orig_main.getNumArguments(), " total arguments"); + return absl::InvalidArgumentError( + absl::StrCat("The module should have ", nr_platform_args, + " platform index arguments and ", nr_dim_args, + " dimension arguments, but it ", "has only ", + orig_main.getNumArguments(), " total arguments")); } mlir::Block &orig_main_body = orig_main.front(); @@ -237,18 +249,18 @@ tsl::Status XlaCallModuleLoader::AddMainWrapper() { !arg_ranked_type.getShape().empty()) { std::string argument_type = (i < nr_platform_args) ? 
"platform index" : "dimension"; - return tsl::errors::InvalidArgument( + return absl::InvalidArgumentError(absl::StrCat( "Module argument at index ", i, " should be a 0-dimensional integer-tensor ", argument_type, - " argument but has type ", mlir::debugString(arg_type)); + " argument but has type ", mlir::debugString(arg_type))); } if (i < nr_platform_args) { if (arg_ranked_type.getElementTypeBitWidth() != 32) { - return tsl::errors::InvalidArgument( - "Module argument at index ", i, - " should be a 0-dimensional 32-bit integer-tensor" - " platform index argument but has type ", - mlir::debugString(arg_type)); + return absl::InvalidArgumentError( + absl::StrCat("Module argument at index ", i, + " should be a 0-dimensional 32-bit integer-tensor" + " platform index argument but has type ", + mlir::debugString(arg_type))); } call_args[i] = op_builder.create( block_args[0].getLoc(), @@ -268,8 +280,10 @@ tsl::Status XlaCallModuleLoader::AddMainWrapper() { mlir::func::CallOp call_op = op_builder.create( loc, orig_main.getResultTypes(), orig_main.getSymName(), call_args); op_builder.create(loc, call_op.getResults()); - VLOG(3) << "XlaCallModule module with wrapper: " - << mlir::debugString(*module_); + + if (VLOG_IS_ON(5)) { + DumpMlirOpToFile("xla_call_module.after_add_main_wrapper", *module_); + } return tsl::OkStatus(); } @@ -283,35 +297,62 @@ tsl::Status XlaCallModuleLoader::RefineDynamicShapes( int nr_dim_args = dim_args_spec_.size(); int non_dimension_arguments = input_shapes.size(); if (non_dimension_arguments != main_body.getNumArguments()) { - return tsl::errors::InvalidArgument( + return absl::InvalidArgumentError(absl::StrCat( "Incorrect number of arguments passed to XlaCallModule: ", non_dimension_arguments, ". The module takes ", main_body.getNumArguments() + nr_platform_args + nr_dim_args, " arguments of which ", nr_platform_args, " platform index arguments and ", nr_dim_args, " dimension arguments. 
It must be called with ", - main_body.getNumArguments(), " arguments."); + main_body.getNumArguments(), " arguments.")); } mlir::Builder builder(module_->getContext()); std::vector static_array_input_types(non_dimension_arguments); for (int i = 0, end = non_dimension_arguments; i < end; ++i) { const xla::Shape &xla_shape = input_shapes[i]; - std::vector xla_dimensions(xla_shape.dimensions().begin(), - xla_shape.dimensions().end()); - TF_ASSIGN_OR_RETURN( - mlir::Type element_type, - ConvertPrimitiveTypeToMLIRType(xla_shape.element_type(), builder)); - mlir::Type type = mlir::RankedTensorType::get(xla_dimensions, element_type); - // TODO(burmako): This fails with an obscure compilation error. - // TF_ASSIGN_OR_RETURN( - // mlir::Type type, - // ConvertShapeToType(xla_shape, builder)); - VLOG(3) << "XlaCallModule static array input type #" << i << ": " - << mlir::debugString(type); - // TODO(b/278273480): Determine whether it's safe to override the element - // type using that from the input shape. - static_array_input_types[i] = type; + if (xla_shape.IsToken()) { + static_array_input_types[i] = mlir::stablehlo::TokenType::get(context_); + } else { + std::vector xla_dimensions(xla_shape.dimensions().begin(), + xla_shape.dimensions().end()); + TF_ASSIGN_OR_RETURN( + mlir::Type element_type, + ConvertPrimitiveTypeToMLIRType(xla_shape.element_type(), builder)); + mlir::RankedTensorType type = + mlir::RankedTensorType::get(xla_dimensions, element_type); + // TODO(burmako): This fails with an obscure compilation error. 
+ // TF_ASSIGN_OR_RETURN( + // mlir::Type type, + // ConvertShapeToType(xla_shape, builder)); + VLOG(3) << "XlaCallModule static array input type #" << i << ": " + << mlir::debugString(type); + mlir::TensorType arg_type = + main_body.getArgument(i).getType().dyn_cast(); + if (arg_type == nullptr) { + return absl::InvalidArgumentError(absl::StrCat( + "Argument ", i, " passed to XlaCallModule is not a tensor")); + } + + if (arg_type.getElementType() != type.getElementType()) { + return absl::InvalidArgumentError(absl::StrCat( + "Element type mismatch for argument ", i, + " passed to XlaCallModule: ", "expecting ", + mlir::debugString(arg_type), ", got ", mlir::debugString(type))); + } + + if (auto ranked_arg_type = arg_type.dyn_cast()) { + if (mlir::failed(mlir::verifyCompatibleShape(ranked_arg_type.getShape(), + type.getShape()))) { + return absl::InvalidArgumentError(absl::StrCat( + "Shape mismatch for argument ", i, + " passed to XlaCallModule: ", "expecting ", + mlir::debugString(arg_type), ", got ", mlir::debugString(type))); + } + } + + static_array_input_types[i] = type; + } } // Refine 'main' argument types to use static input types instead. @@ -320,13 +361,18 @@ tsl::Status XlaCallModuleLoader::RefineDynamicShapes( // shape refinement as explained below. // Before refining the argument types it is useful to run the inliner to // remove calls that may be called with the input arguments. 
- mlir::PassManager pm_inline(module_->getContext()); - pm_inline.addPass(mlir::createInlinerPass()); - if (!mlir::succeeded(pm_inline.run(*module_))) { - return tsl::errors::InvalidArgument("Module inlining failed"); + { + mlir::StatusScopedDiagnosticHandler diag_handler(module_->getContext()); + + mlir::PassManager pm_inline(module_->getContext()); + applyTensorflowAndCLOptions(pm_inline); + pm_inline.addPass(mlir::createInlinerPass()); + + if (mlir::failed(pm_inline.run(*module_))) { + return absl::InvalidArgumentError(absl::StrCat( + "Module inlining failed: ", diag_handler.ConsumeStatus().ToString())); + } } - VLOG(3) << "XlaCallModule module after inlining: " - << mlir::debugString(*module_); auto static_array_output_types = llvm::to_vector(main_.getResultTypes()); for (auto i = 0; i < main_body.getNumArguments(); ++i) { @@ -346,42 +392,49 @@ tsl::Status XlaCallModuleLoader::RefineDynamicShapes( } main_.setType(builder.getFunctionType(static_array_input_types, static_array_output_types)); + if (VLOG_IS_ON(5)) { + DumpMlirOpToFile("xla_call_module.after_refined_input_types", *module_); + } // Verify the module before running passes on it. // If the module doesn't pass verification, all sorts of weirdness might // happen if we run the pass manager. 
- if (failed(verify(*module_))) { - VLOG(3) << "XlaCallModule module with verification failed: " - << mlir::debugString(*module_); - return tsl::errors::InvalidArgument("Module verification failed"); - } - mlir::PassManager pm(module_->getContext()); - if (VLOG_IS_ON(3)) { - auto print_before = [](mlir::Pass *, mlir::Operation *) { return true; }; - auto print_after = [](mlir::Pass *, mlir::Operation *) { return true; }; - pm.enableIRPrinting(print_before, print_after, /*printModuleScope=*/true, - /*printAfterOnlyOnChange=*/false); - } - pm.addPass(mlir::createCSEPass()); - pm.addPass(mlir::stablehlo::createStablehloRefineShapesPass()); - pm.addNestedPass( - mlir::stablehlo::createStablehloCanonicalizeDynamismPass()); - if (!mlir::succeeded(pm.run(*module_))) { - return tsl::errors::InvalidArgument("Module shape refinement failed"); - } + { + mlir::StatusScopedDiagnosticHandler diag_handler(module_->getContext()); - VLOG(3) << "XlaCallModule module with refined shapes: " - << mlir::debugString(*module_); + if (failed(verify(*module_))) { + return absl::InvalidArgumentError( + absl::StrCat("Module verification failed: ", + diag_handler.ConsumeStatus().ToString())); + } + + mlir::PassManager pm(module_->getContext()); + applyTensorflowAndCLOptions(pm); + pm.addPass(mlir::createCSEPass()); + pm.addPass(mlir::stablehlo::createStablehloRefineShapesPass()); + pm.addNestedPass( + mlir::stablehlo::createStablehloCanonicalizeDynamismPass()); + if (mlir::failed(pm.run(*module_))) { + return absl::InvalidArgumentError( + absl::StrCat("Module shape refinement failed: ", + diag_handler.ConsumeStatus().ToString())); + } + + if (VLOG_IS_ON(3)) { + DumpMlirOpToFile("xla_call_module.after_shape_refinement", *module_); + } + } return tsl::OkStatus(); } tsl::Status XlaCallModuleLoader::LoadAndPreprocessModule( mlir::MLIRContext *context, int version, std::string module_str, - std::vector dim_args_spec, int platform_index) { + std::vector dim_args_spec, + std::vector disabled_checks, 
+ std::vector platforms, std::string loading_platform) { context_ = context; version_ = version; dim_args_spec_ = std::move(dim_args_spec); - platform_index_ = platform_index; // Load a superset of dialects; we should check at serialization time that // we only include allowable dialects. @@ -390,6 +443,13 @@ tsl::Status XlaCallModuleLoader::LoadAndPreprocessModule( context_->loadDialect(); context_->loadDialect(); context_->loadDialect(); + + if (version >= VERSION_START_SUPPORT_DISABLED_CHECKS && platforms.empty()) { + return absl::InvalidArgumentError( + absl::StrCat("XlaCallModuleOp with version ", version, + " must have non-empty platforms.")); + } + // Parses both IR text and bytecode. if (version >= VERSION_START_STABLE_HLO_COMPATIBILITY) { module_ = @@ -398,22 +458,75 @@ tsl::Status XlaCallModuleLoader::LoadAndPreprocessModule( module_ = mlir::parseSourceString(module_str, context_); } + std::vector loading_disabled_checks = disabled_checks; + loading_disabled_checks.insert( + loading_disabled_checks.end(), + GetXlaCallModuleFlags()->disabled_checks.begin(), + GetXlaCallModuleFlags()->disabled_checks.end()); if (!module_) { - return tsl::errors::InvalidArgument("Cannot deserialize computation"); + return absl::InvalidArgumentError("Cannot deserialize computation"); } - VLOG(3) << "Parsed serialized module (version " << version - << ", platform_index = " << platform_index_ << ", dim_args_spec = [" - << absl::StrJoin(dim_args_spec_, ", ") << "])\n" - << mlir::debugString(*module_); - if (failed(module_->verifyInvariants())) { - VLOG(1) << "MLIR verification failed."; - module_->dump(); - return tsl::errors::InvalidArgument("Error verifying module"); + VLOG(3) << "Parsed serialized module (version " << version + << ", platforms = [" << absl::StrJoin(platforms, ", ") + << "], loading_platform = " << loading_platform + << ", dim_args_spec = [" << absl::StrJoin(dim_args_spec_, ", ") + << "], disabled_checks = [" << absl::StrJoin(disabled_checks, ", ") + << "], 
loading_disabled_checks = [" + << absl::StrJoin(loading_disabled_checks, ", ") << "]), module = " + << DumpMlirOpToFile("xla_call_module.parsed", *module_); + + if (version < VERSION_MINIMUM_SUPPORTED) { + return absl::InvalidArgumentError(absl::StrCat( + "XlaCallModuleOp with version ", version, + " is not supported anymore. Must be >= ", VERSION_MINIMUM_SUPPORTED)); + } + if (version > VERSION_MAXIMUM_SUPPORTED) { + return absl::InvalidArgumentError( + absl::StrCat("XlaCallModuleOp with version ", version, + " is not supported by this build. Must be <= ", + VERSION_MAXIMUM_SUPPORTED)); + } + + platform_index_ = -1; + if (!platforms.empty()) { + auto found_platform = + std::find(platforms.begin(), platforms.end(), loading_platform); + if (found_platform == platforms.end()) { + if (!IsPlatformCheckDisabled(loading_disabled_checks)) { + return absl::NotFoundError(absl::StrCat( + "The current platform ", loading_platform, + " is not among the platforms required by the module: [", + absl::StrJoin(platforms, ", "), "]")); + } else { + if (platforms.size() > 1) { + platform_index_ = 0; + } + } + } else { + // We only use a platform index arguments if we support at least 2 + // platforms. 
+ if (platforms.size() > 1) { + platform_index_ = found_platform - platforms.begin(); + } + } + } + + if (version >= VERSION_START_SUPPORT_CALL_TF_GRAPH && + !dim_args_spec_.empty()) { + return absl::InvalidArgumentError( + "dim_args_spec not supported in this version"); + } + { + mlir::StatusScopedDiagnosticHandler diag_handler(module_->getContext()); + if (mlir::failed(mlir::verify(*module_))) { + return absl::InvalidArgumentError(absl::StrCat( + "Error verifying module: ", diag_handler.ConsumeStatus().ToString())); + } } main_ = module_->lookupSymbol("main"); if (!main_) { - return tsl::errors::InvalidArgument("Cannot find 'main' in module"); + return absl::InvalidArgumentError("Cannot find 'main' in module"); } if (!dim_args_spec_.empty() || platform_index_ >= 0) { @@ -423,9 +536,9 @@ tsl::Status XlaCallModuleLoader::LoadAndPreprocessModule( return tsl::OkStatus(); } -tsl::Status XlaCallModuleLoader::ValidateModule() { +tsl::Status XlaCallModuleLoader::ValidateDialect() { + mlir::StatusScopedDiagnosticHandler diag_handler(module_->getContext()); bool moduleHasUnsupportedDialects = false; - bool moduleHasDynamicShapes = false; module_->walk([&](mlir::Operation *op) { // StableHLO programs created by jax2tf only contain operations @@ -434,10 +547,23 @@ tsl::Status XlaCallModuleLoader::ValidateModule() { mlir::func::FuncDialect, mlir::stablehlo::StablehloDialect>( op->getDialect())) { moduleHasUnsupportedDialects = true; - VLOG(3) << "Operation has unsupported dialects: " - << mlir::debugString(*op); + op->emitOpError() << "is an op from an unsupported dialect"; } + }); + if (moduleHasUnsupportedDialects) { + return absl::InvalidArgumentError( + absl::StrCat("Module has unsupported dialects: ", + diag_handler.ConsumeStatus().ToString())); + } + return tsl::OkStatus(); +} + +tsl::Status XlaCallModuleLoader::ValidateStaticShapes() { + mlir::StatusScopedDiagnosticHandler diag_handler(module_->getContext()); + bool moduleHasDynamicShapes = false; + + 
module_->walk([&](mlir::Operation *op) { // It's sufficient to only check results because operands either come from // results or from block arguments which are checked below. auto hasDynamicShape = [](mlir::Value value) { @@ -452,22 +578,53 @@ tsl::Status XlaCallModuleLoader::ValidateModule() { } if (opHasDynamicShapes) { moduleHasDynamicShapes = true; - VLOG(3) << "Operation has dynamic shapes: " << mlir::debugString(*op); + op->emitOpError() << "has dynamic shapes"; } }); - if (moduleHasUnsupportedDialects) - return tsl::errors::InvalidArgument("Module has unsupported dialects"); - if (moduleHasDynamicShapes) - return tsl::errors::InvalidArgument("Module has dynamic shapes"); + if (moduleHasDynamicShapes) { + return absl::InvalidArgumentError( + absl::StrCat("Module has dynamic shapes: ", + diag_handler.ConsumeStatus().ToString())); + } return tsl::OkStatus(); } +absl::Status XlaCallModuleLoader::LowerModuleToMhlo() { + mlir::StatusScopedDiagnosticHandler diag_handler(module_->getContext()); + + mlir::PassManager pm(module_->getContext()); + applyTensorflowAndCLOptions(pm); + pm.addPass(mlir::mhlo::createStablehloLegalizeToHloPass()); + pm.addNestedPass( + mlir::mhlo::createLegalizeSparseChloToLinalgPass()); + pm.addNestedPass(mlir::mhlo::createChloLegalizeToHloPass( + /*legalizeBroadcasts=*/true, /*expandCompositions=*/true)); + pm.addNestedPass(mlir::createCanonicalizerPass()); + // In order to export to XLA, we must sink constants to control flow + // regions, since XLA uses functional control flow. 
+ pm.addNestedPass( + mlir::mhlo::createSinkConstantsToControlFlowPass()); + if (failed(pm.run(*module_))) { + return absl::InternalError( + absl::StrCat("MHLO->HLO lowering passes failed: ", + diag_handler.ConsumeStatus().ToString())); + } + + if (VLOG_IS_ON(5)) { + DumpMlirOpToFile("xla_call_module.after_mhlo_lowering", *module_); + } + + return absl::OkStatus(); +} + tsl::StatusOr XlaCallModuleLoader::ToXlaComputation() { - xla::XlaComputation xla_computation; + xla::HloProto proto; + mlir::MlirToHloConversionOptions options; TF_RETURN_IF_ERROR( - MlirToXlaComputation(*module_, xla_computation, false, false)); - return xla_computation; + mlir::ConvertMlirHloToHlo(*module_, &proto, /*use_tuple_args=*/false, + /*return_tuple=false*/ false, options)); + return xla::XlaComputation(std::move(*proto.mutable_hlo_module())); } } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.h b/tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.h index 6196cfe1f20..54aaa6ae58f 100644 --- a/tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.h +++ b/tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.h @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include #include #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project @@ -34,7 +35,9 @@ class XlaCallModuleLoader { public: static tsl::StatusOr> Create( mlir::MLIRContext* context, int version, std::string module_str, - std::vector dim_args_spec, int platform_index); + std::vector dim_args_spec, + std::vector disabled_checks, + std::vector platforms, std::string loading_platform); int nr_outputs() { return main_.getNumResults(); } mlir::TypeRange output_types() { return main_.getResultTypes(); } @@ -52,13 +55,26 @@ class XlaCallModuleLoader { // cause lifetime issues. 
tsl::Status RefineDynamicShapes(llvm::ArrayRef input_shapes); - // Validate that the module represents a statically-shaped StableHLO program, + // Validates that the module only contains ops from valid dialects. + tsl::Status ValidateDialect(); + + // Validates that the module represents a statically-shaped StableHLO program, // otherwise all sorts of weirdness might happen in the HLO exporter which is // much easier to detect here. - tsl::Status ValidateModule(); + tsl::Status ValidateStaticShapes(); + // Lowers the StableHLO module to MHLO in place. + absl::Status LowerModuleToMhlo(); + + // Lowers the MHLO module to XlaComputation and returns it. + // + // REQUIRES: `LowerModuleToMhlo()` is called beforehand. tsl::StatusOr ToXlaComputation(); + // Returns the deserialized stablehlo module. + mlir::ModuleOp module() & { return *module_; } + mlir::OwningOpRef module() && { return std::move(module_); } + private: XlaCallModuleLoader() = default; @@ -66,7 +82,9 @@ class XlaCallModuleLoader { tsl::Status LoadAndPreprocessModule(mlir::MLIRContext* context, int version, std::string module_str, std::vector dim_args_spec, - int platform_index); + std::vector disabled_checks, + std::vector platforms, + std::string loading_platform); // Adds a wrapper for the "main" function to compute the platform index and // the dimension arguments. @@ -75,6 +93,8 @@ class XlaCallModuleLoader { mlir::MLIRContext* context_; int version_; mlir::OwningOpRef module_; + // Index in platforms of the current platform, or -1 if module does not take + // a platform index arg. int platform_index_; std::vector dim_args_spec_; mlir::func::FuncOp main_; diff --git a/tensorflow/compiler/tf2xla/kernels/xla_call_module_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_call_module_op.cc index fbb853528fc..8e6c6e9af93 100644 --- a/tensorflow/compiler/tf2xla/kernels/xla_call_module_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/xla_call_module_op.cc @@ -19,22 +19,109 @@ limitations under the License. 
#include #include +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/OwningOpRef.h" // from @llvm-project +#include "mlir/IR/SymbolTable.h" // from @llvm-project +#include "mlir/IR/TypeRange.h" // from @llvm-project +#include "mlir/IR/TypeUtilities.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Support/DebugStringHelper.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h" #include "tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.h" +#include "tensorflow/compiler/tf2xla/side_effect_util.h" +#include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/mlir_hlo/mhlo/IR/hlo_ops.h" +#include "tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/passes.h" #include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/translate/hlo_to_mhlo/hlo_utils.h" +#include "tensorflow/compiler/xla/translate/hlo_to_mhlo/hlo_to_mlir_hlo.h" +#include "tensorflow/compiler/xla/translate/mhlo_to_hlo/type_to_shape.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/op_kernel.h" #include 
"tensorflow/core/framework/op_requires.h" #include "tensorflow/core/tpu/tpu_defs.h" +#include "tensorflow/tsl/platform/errors.h" +#include "tensorflow/tsl/platform/statusor.h" namespace tensorflow { namespace { +// Imports the given `XlaComputation` into StableHLO functions the MLIR module. +// Returns the MLIR function in the imported module that represents the entry +// function of the imported computation. +absl::StatusOr ImportXlaComputation( + mlir::SymbolTableCollection &symbol_table_collection, mlir::ModuleOp module, + const xla::XlaComputation &computation) { + mlir::MLIRContext *context = module.getContext(); + mlir::SymbolTable &symbol_table = + symbol_table_collection.getSymbolTable(module); + + mlir::OwningOpRef imported = + mlir::ModuleOp::create(mlir::UnknownLoc::get(context)); + context->loadDialect(); + context->loadDialect(); + TF_RETURN_IF_ERROR( + xla::ConvertHloToMlirHlo(*imported, &computation.proto(), + /*import_all_computations=*/true)); + if (VLOG_IS_ON(5)) { + DumpMlirOpToFile("xla_call_module.imported_tf_func", *imported); + } + + // Rename all functions beforehand in order to avoid conflicts. 
+ mlir::StringAttr main_func_name; + for (auto func : imported->getOps()) { + mlir::StringAttr name = func.getSymNameAttr(); + mlir::StringAttr new_name = name; + for (int i = 0; symbol_table.lookup(new_name) != nullptr; ++i) { + new_name = mlir::StringAttr::get( + context, absl::StrCat(absl::string_view(name.getValue()), i)); + } + if (new_name != name) { + if (failed(mlir::SymbolTable::replaceAllSymbolUses(func, new_name, + *imported))) { + return absl::InternalError( + absl::StrCat("Failed to replace all symbol uses of function '", + absl::string_view(func.getName()), "'")); + } + func.setSymNameAttr(new_name); + } + if (name.getValue() == "main") { + main_func_name = new_name; + } + } + if (!main_func_name) { + return absl::InternalError( + "HLO module lowered from TF function is missing a main function"); + } + + mlir::func::FuncOp main_func; + for (auto func : imported->getOps()) { + auto cloned = func.clone(); + cloned.setPrivate(); + symbol_table.insert(cloned); + if (func.getSymNameAttr() == main_func_name) { + main_func = cloned; + } + } + + return main_func; +} + class XlaCallModuleOp : public XlaOpKernel { public: explicit XlaCallModuleOp(OpKernelConstruction *ctx) : XlaOpKernel(ctx) { @@ -50,73 +137,105 @@ class XlaCallModuleOp : public XlaOpKernel { OP_REQUIRES_OK(ctx, ctx->GetAttr("dim_args_spec", &dim_args_spec)); OP_REQUIRES(ctx, expected_output_shapes.size() == expected_output_dtypes.size(), - errors::InvalidArgument("The size of Sout (", - expected_output_shapes.size(), - ") must match the size of Tout (", - expected_output_dtypes.size(), ")")); + absl::InvalidArgumentError(absl::StrCat( + "The size of Sout (", expected_output_shapes.size(), + ") must match the size of Tout (", + expected_output_dtypes.size(), ")"))); + std::vector disabled_checks; + OP_REQUIRES_OK(ctx, ctx->GetAttr("disabled_checks", &disabled_checks)); std::vector platforms; - // Index in platforms of the current platform, or -1 if module does not take - // a platform index 
arg. - int platform_index = -1; - if (ctx->HasAttr("platforms")) { - OP_REQUIRES_OK(ctx, ctx->GetAttr("platforms", &platforms)); - if (!platforms.empty()) { - string current_device_type = ctx->device_type().type_string(); - string current_platform = ""; - if (current_device_type == DEVICE_CPU_XLA_JIT) { - current_platform = "CPU"; - } else if (current_device_type == DEVICE_GPU_XLA_JIT) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("platforms", &platforms)); + + string loading_device_type = ctx->device_type().type_string(); + string loading_platform = ""; + if (loading_device_type == DEVICE_CPU_XLA_JIT) { + loading_platform = "CPU"; + } else if (loading_device_type == DEVICE_GPU_XLA_JIT) { #if GOOGLE_CUDA - current_platform = "CUDA"; + loading_platform = "CUDA"; #elif TENSORFLOW_USE_ROCM - current_platform = "ROCM"; + loading_platform = "ROCM"; #else - OP_REQUIRES(ctx, false, - errors::Unimplemented("CUDA or ROCM build required")); + OP_REQUIRES(ctx, false, + absl::UnimplementedError("CUDA or ROCM build required")); #endif - } else if (current_device_type == DEVICE_TPU_XLA_JIT) { - current_platform = "TPU"; - } else { - OP_REQUIRES(ctx, false, - errors::Unimplemented("Unexpected device type ", - current_device_type)); - } - VLOG(3) << "Initialized XlaCallModuleOp on " << current_platform; - auto found_platform = - std::find(platforms.begin(), platforms.end(), current_platform); - OP_REQUIRES(ctx, found_platform != platforms.end(), - errors::NotFound( - "The current platform ", current_platform, - " is not among the platforms required by the module: [", - absl::StrJoin(platforms, ", "), "]")); - // We only use a platform index arguments if we support at least 2 - // platforms. 
- if (platforms.size() > 1) { - platform_index = found_platform - platforms.begin(); - } - } + } else if (loading_device_type == DEVICE_TPU_XLA_JIT) { + loading_platform = "TPU"; + } else { + OP_REQUIRES(ctx, false, + absl::UnimplementedError(absl::StrCat( + "Unexpected device type ", loading_device_type))); + } + VLOG(3) << "Initialized XlaCallModuleOp on " << loading_platform; + { + auto loader = XlaCallModuleLoader::Create( + &context_, version, std::move(module_str), std::move(dim_args_spec), + std::move(disabled_checks), std::move(platforms), loading_platform); + OP_REQUIRES_OK(ctx, loader.status()); + loader_ = *std::move(loader); + } + OP_REQUIRES_OK(ctx, loader_->ValidateDialect()); + + if (!ctx->GetAttr("function_list", &function_list_).ok()) { + function_list_.clear(); } - auto loader = - XlaCallModuleLoader::Create(&context_, version, std::move(module_str), - std::move(dim_args_spec), platform_index); - OP_REQUIRES_OK(ctx, loader.status()); - loader_ = *std::move(loader); + if (!ctx->GetAttr("has_token_input_output", &module_has_token_input_output_) + .ok()) { + module_has_token_input_output_ = false; + } + if (!ctx->GetAttr(kXlaTokenInputNodesAttrName, &token_input_nodes_).ok()) { + token_input_nodes_.clear(); + op_has_token_input_output_ = false; + } else { + op_has_token_input_output_ = !token_input_nodes_.empty(); + } + if (!ctx->GetAttr(kXlaOriginalOutsideCompilationNodeName, + &original_node_name_) + .ok()) { + original_node_name_ = name(); + } } void Compile(XlaOpKernelContext *ctx) override { + XlaCompiler *const compiler = ctx->compiler(); + xla::XlaBuilder *const b = ctx->builder(); + std::vector input_shapes; + if (module_has_token_input_output_) { + input_shapes.push_back(xla::ShapeUtil::MakeTokenShape()); + } for (int i = 0; i < ctx->num_inputs(); ++i) { auto shape = ctx->InputXlaShape(i); OP_REQUIRES_OK(ctx, shape.status()); input_shapes.push_back(*std::move(shape)); } OP_REQUIRES_OK(ctx, loader_->RefineDynamicShapes(input_shapes)); - 
OP_REQUIRES_OK(ctx, loader_->ValidateModule()); + OP_REQUIRES_OK(ctx, loader_->ValidateStaticShapes()); + OP_REQUIRES_OK(ctx, loader_->LowerModuleToMhlo()); + if (!function_list_.empty()) { + OP_REQUIRES_OK(ctx, LowerTfFunctionCalls(ctx)); + } - std::vector inputs(ctx->num_inputs()); + std::vector inputs; + if (module_has_token_input_output_) { + // The main function expects a token input at the start. + if (!token_input_nodes_.empty()) { + std::vector token_inputs; + for (const string &node_name : token_input_nodes_) { + auto token = compiler->GetNodeToken(node_name); + OP_REQUIRES_OK(ctx, token.status()); + token_inputs.push_back(token.value()); + } + inputs.push_back(xla::AfterAll(b, token_inputs)); + } else { + // Generate a dummy token if the main function expects a token but the + // XlaCallModule doesn't take one. + inputs.push_back(xla::CreateToken(b)); + } + } for (int i = 0, end = ctx->num_inputs(); i < end; ++i) { - inputs[i] = ctx->Input(i); + inputs.push_back(ctx->Input(i)); } auto xla_computation = loader_->ToXlaComputation(); @@ -132,30 +251,268 @@ class XlaCallModuleOp : public XlaOpKernel { xla_computation->proto(), module_config)); xla::HloPrintOptions options; options = xla::HloPrintOptions::ShortParsable(); - VLOG(3) << "XlaCallModule converted to HLO module " - << hlo_module->ToString(options); + XLA_VLOG_LINES(3, absl::StrCat("XlaCallModule converted to HLO module ", + hlo_module->ToString(options))); } - xla::XlaOp output = xla::Call(ctx->builder(), *xla_computation, inputs); + xla::XlaOp output = xla::Call(b, *xla_computation, inputs); // Check that the resulting computation returns the expected shape - OP_REQUIRES_VALUE(xla::Shape found_output_shape, ctx, - ctx->builder()->GetShape(output)); + OP_REQUIRES_VALUE(xla::Shape found_output_shape, ctx, b->GetShape(output)); VLOG(3) << "XlaCallModule compiled output shape : " << xla::ShapeUtil::HumanString(found_output_shape); + std::vector outputs; if (loader_->nr_outputs() == 1) { - 
ctx->SetOutput(0, output); + outputs.push_back(output); } else { for (int i = 0; i < loader_->nr_outputs(); ++i) { - ctx->SetOutput(i, xla::GetTupleElement(output, i)); + outputs.push_back(xla::GetTupleElement(output, i)); } } + + xla::XlaOp token_output; + if (module_has_token_input_output_) { + // The main function returns a token as the first output. + token_output = outputs.front(); + outputs.erase(outputs.begin()); + auto shape = b->GetShape(token_output); + OP_REQUIRES_OK(ctx, shape.status()); + OP_REQUIRES(ctx, shape->IsToken(), + absl::FailedPreconditionError( + absl::StrCat("Token output is not token type: ", + xla::ShapeUtil::HumanString(*shape)))); + } + if (op_has_token_input_output_) { + if (token_output.IsUninitialized()) { + // The main function does not return any token, but the XlaCallModule is + // expected to return one. Create a dummy token. + token_output = xla::CreateToken(b); + } + OP_REQUIRES_OK(ctx, + compiler->SetNodeToken(original_node_name_, token_output)); + } + + for (int i = 0; i < outputs.size(); ++i) { + ctx->SetOutput(i, outputs[i]); + } } private: + // Lowers `mhlo.CustomCall` ops representing TF function calls into nested XLA + // computation. The called TF functions are lowered into MHLO and inserted as + // function calls in the main module. + // + // This is implemented here instead of in xla_call_module_loader.cc in order + // to prevent cyclic dependency with TF MLIR passes. 
+ absl::Status LowerTfFunctionCalls(XlaOpKernelContext *ctx) { + mlir::ModuleOp module = loader_->module(); + mlir::SymbolTableCollection symbol_table_collection; + + llvm::SmallDenseSet updated_funcs; + + auto lower = [&](mlir::mhlo::CustomCallOp custom_call) -> absl::Status { + if (custom_call.getCallTargetName() != "tf.call_tf_function") { + return absl::OkStatus(); + } + + NameAttrList f; + bool custom_call_has_token_input_output = false; + { + auto backend_config = custom_call->getAttrOfType( + "tf.backend_config"); + if (!backend_config) { + return absl::InternalError( + "TF function custom call must have 'tf.backend_config' " + "attribute"); + } + + auto called_index = + backend_config.getAs("called_index"); + if (!called_index) { + return absl::InternalError( + "TF function custom call must have 'called_index' in the " + "'tf.backend_config' attribute"); + } + + int index = called_index.getInt(); + if (index < 0 || index >= function_list_.size()) { + return absl::OutOfRangeError(absl::StrCat( + "XlaCallModule has function_list of size ", function_list_.size(), + " but TF function custom call references function #", index)); + } + f = function_list_[index]; + + // Whether the custom call takes a token argument and returns another + // token. Used to model side effects. + if (auto attr = + backend_config.getAs("has_token_input_output"); + attr != nullptr) { + custom_call_has_token_input_output = attr.getValue(); + } + } + + // Lower the called TF function into an HLO module. 
+ + std::vector arguments; + { + mlir::TypeRange input_types(custom_call->getOperandTypes()); + if (custom_call_has_token_input_output) { + if (input_types.empty() || + !input_types.front().isa()) { + return absl::InvalidArgumentError(absl::StrCat( + "stablehlo.custom_call with has_token_input_output = true is " + "expected to take !stablehlo.token as the first argument, but " + "got ", + mlir::debugString(custom_call))); + } + input_types = input_types.drop_front(); + } + for (mlir::Type input_type : input_types) { + XlaCompiler::Argument &argument = arguments.emplace_back(); + argument.kind = XlaCompiler::Argument::kParameter; + TF_RETURN_IF_ERROR(ConvertToDataType(input_type, &argument.type)); + argument.shape = xla::TypeToShape(input_type); + } + + mlir::TypeRange result_types(custom_call->getResultTypes()); + if (custom_call_has_token_input_output) { + if (result_types.empty() || + !result_types.front().isa()) { + return absl::InvalidArgumentError(absl::StrCat( + "stablehlo.custom_call with has_token_input_output = true is " + "expected to return !stablehlo.token as the first result, but " + "got ", + mlir::debugString(custom_call))); + } + } + } + + XlaCompiler::CompileOptions options; + options.use_tuple_arg = true; + options.always_return_tuple = true; + options.is_entry_computation = false; + // Propagate tokens from XlaCallModule to inner computation. + options.add_token_input_output = op_has_token_input_output_; + + XlaCompiler::CompilationResult result; + TF_RETURN_IF_ERROR( + ctx->compiler()->CompileFunction(options, f, arguments, &result)); + + // Import the lowered HLO module into StableHLO functions in `module`. The + // main function accepts tupled arguments and returns tupled results. + TF_ASSIGN_OR_RETURN(mlir::func::FuncOp main_func, + ImportXlaComputation(symbol_table_collection, module, + *result.computation)); + + // Replace the custom call with ops that call the imported main function. 
+ mlir::OpBuilder builder(custom_call); + auto loc = custom_call.getLoc(); + + // Pack all arguments into a tuple (`options.use_tuple_arg` is true). If + // `has_tuple_input_output` is true, the first argument is a token type. + mlir::Value arg_tuple; + { + llvm::SmallVector args(custom_call->getOperands()); + if (custom_call_has_token_input_output) { + // Adjust the indexes since custom calls with `has_token_input_output` + // takes a token as the first argument, but TF2XLA'ed computation + // expects the token to be the last argument. + std::rotate(args.begin(), args.begin() + 1, args.end()); + } else if (options.add_token_input_output) { + // Add a dummy token if the inner computation takes a token but the + // custom call doesn't have a token argument. + args.push_back(builder.create(loc)); + } + + llvm::SmallVector elements; + elements.reserve(result.input_mapping.size()); + for (int index : result.input_mapping) { + elements.push_back(args[index]); + } + arg_tuple = + builder.create(loc, elements).getResult(); + } + + // Call the lowered function. + auto call = builder.create( + loc, main_func, mlir::ValueRange(arg_tuple)); + + // Unpack the result tuple (`options.always_return_tuple` is true). If + // `has_tuple_input_output` is true, the first result is a token type. + { + llvm::SmallVector results(custom_call->getResults()); + if (custom_call_has_token_input_output) { + // Adjust the indexes since custom calls with `has_token_input_output` + // returns a token as the first result, but TF2XLA'ed computation + // returns the token as the last result. + std::rotate(results.begin(), results.begin() + 1, results.end()); + + if (!options.add_token_input_output) { + // If the custom call returns a token but the inner computation + // doesn't, replace the token result with a dummy token. 
+ mlir::Value token = results.back(); + if (!token.use_empty()) { + token.replaceAllUsesWith( + builder.create(loc)); + } + results.pop_back(); + } + } + + for (const auto &it : llvm::enumerate(results)) { + if (!it.value().use_empty()) { + auto get_tuple_element = + builder.create( + loc, call.getResults().front(), it.index()); + it.value().replaceAllUsesWith(get_tuple_element.getResult()); + } + } + } + + updated_funcs.insert(call->getParentOfType()); + custom_call->erase(); + + return absl::OkStatus(); + }; + + absl::Status status; + mlir::WalkResult result = module->walk([&](mlir::mhlo::CustomCallOp op) { + status.Update(lower(op)); + if (!status.ok()) { + return mlir::WalkResult::interrupt(); + } + return mlir::WalkResult::advance(); + }); + if (result.wasInterrupted()) { + return status; + } + + // If the call results are used by `func.return`, then we may need to update + // function result types. + for (auto func : updated_funcs) { + auto ret = llvm::cast( + func.getFunctionBody().front().getTerminator()); + func.setFunctionType(mlir::FunctionType::get( + &context_, func.getArgumentTypes(), ret.getOperandTypes())); + } + + if (VLOG_IS_ON(5)) { + DumpMlirOpToFile("xla_call_module.after_tf_func_call_import", module); + } + return absl::OkStatus(); + } + mlir::MLIRContext context_{mlir::MLIRContext::Threading::DISABLED}; std::unique_ptr loader_; + std::vector function_list_; + + // Whether the StableHLO module's main function has token input/output. + bool module_has_token_input_output_; + // Whether the XlaCallModule op has token input/output. 
+ bool op_has_token_input_output_; + std::vector token_input_nodes_; + std::string original_node_name_; }; REGISTER_XLA_OP(Name("XlaCallModule"), XlaCallModuleOp); diff --git a/tensorflow/compiler/tf2xla/kernels/xla_custom_call_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_custom_call_op.cc index 0c67ad7f8de..40770909df7 100644 --- a/tensorflow/compiler/tf2xla/kernels/xla_custom_call_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/xla_custom_call_op.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include + #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" diff --git a/tensorflow/compiler/tf2xla/kernels/xla_dot_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_dot_op.cc index 31f8f2840ba..63d10e399f7 100644 --- a/tensorflow/compiler/tf2xla/kernels/xla_dot_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/xla_dot_op.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include + #include "tensorflow/compiler/tf2xla/mlir_xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" diff --git a/tensorflow/compiler/tf2xla/kernels/xla_pad_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_pad_op.cc index 5e297a8c80b..9871ac537c0 100644 --- a/tensorflow/compiler/tf2xla/kernels/xla_pad_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/xla_pad_op.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ +#include + #include "absl/algorithm/container.h" #include "absl/strings/str_join.h" #include "tensorflow/compiler/tf2xla/shape_util.h" diff --git a/tensorflow/compiler/tf2xla/kernels/xla_svd_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_svd_op.cc index bf5244e52e6..699e9248eaa 100644 --- a/tensorflow/compiler/tf2xla/kernels/xla_svd_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/xla_svd_op.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include + #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/constants.h" diff --git a/tensorflow/compiler/tf2xla/lib/util.cc b/tensorflow/compiler/tf2xla/lib/util.cc index 748971271bf..c8c03fd2b97 100644 --- a/tensorflow/compiler/tf2xla/lib/util.cc +++ b/tensorflow/compiler/tf2xla/lib/util.cc @@ -18,6 +18,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/util.h" @@ -35,87 +36,35 @@ xla::XlaOp Zeros(xla::XlaBuilder* builder, const xla::Shape& shape) { xla::XlaOp FloatLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type, double value) { - switch (type) { - case xla::F16: - return xla::ConstantR0(builder, static_cast(value)); - break; - case xla::BF16: - return xla::ConstantR0(builder, static_cast(value)); - break; - case xla::F32: - return xla::ConstantR0(builder, static_cast(value)); - break; - case xla::F64: - return xla::ConstantR0(builder, value); - break; - case xla::C64: - return xla::ConstantR0(builder, value); - break; - case xla::C128: - return xla::ConstantR0(builder, value); - break; - default: - LOG(FATAL) << "unhandled element type " << type; - } + return xla::primitive_util::PrimitiveTypeSwitch( + [&](auto primitive_type_constant) -> xla::XlaOp { + if constexpr (xla::primitive_util::IsFloatingPointType( + primitive_type_constant) || + xla::primitive_util::IsComplexType( + primitive_type_constant)) { + using NativeT = + xla::primitive_util::NativeTypeOf; + return xla::ConstantR0(builder, static_cast(value)); + } + LOG(FATAL) << "unhandled element type " << type; + }, + type); } xla::XlaOp IntegerLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type, int64_t value) { - xla::Literal literal; - switch (type) { - case xla::U8: - literal = xla::LiteralUtil::CreateR0(value); - break; - case xla::U16: - literal = xla::LiteralUtil::CreateR0(value); - break; - case xla::U32: - literal = xla::LiteralUtil::CreateR0(value); - break; - case xla::U64: - literal = xla::LiteralUtil::CreateR0(value); - break; - case xla::S8: - literal = xla::LiteralUtil::CreateR0(value); - break; - 
case xla::S16: - literal = xla::LiteralUtil::CreateR0(value); - break; - case xla::S32: - literal = xla::LiteralUtil::CreateR0(value); - break; - case xla::S64: - literal = xla::LiteralUtil::CreateR0(value); - break; - case xla::F32: - literal = xla::LiteralUtil::CreateR0(value); - break; - case xla::F64: - literal = xla::LiteralUtil::CreateR0(value); - break; - case xla::C64: - literal = xla::LiteralUtil::CreateR0(value); - break; - case xla::C128: - literal = xla::LiteralUtil::CreateR0(value); - break; - case xla::PRED: - LOG(FATAL) << "pred element type is not integral"; - case xla::BF16: - literal = xla::LiteralUtil::CreateR0( - static_cast(value)); - break; - case xla::F16: - literal = - xla::LiteralUtil::CreateR0(static_cast(value)); - break; - case xla::TUPLE: - LOG(FATAL) << "tuple element type is not integral"; - case xla::OPAQUE_TYPE: - LOG(FATAL) << "opaque element type is not integral"; - default: - LOG(FATAL) << "unhandled element type " << type; - } + xla::Literal literal = xla::primitive_util::PrimitiveTypeSwitch( + [&](auto primitive_type_constant) -> xla::Literal { + if constexpr (xla::primitive_util::IsArrayType( + primitive_type_constant)) { + using NativeT = + xla::primitive_util::NativeTypeOf; + return xla::LiteralUtil::CreateR0( + static_cast(value)); + } + LOG(FATAL) << "unhandled element type " << type; + }, + type); return xla::ConstantLiteral(builder, literal); } diff --git a/tensorflow/compiler/tf2xla/mlir_bridge_pass.cc b/tensorflow/compiler/tf2xla/mlir_bridge_pass.cc index 3959ebb5771..3ff4fa845d2 100644 --- a/tensorflow/compiler/tf2xla/mlir_bridge_pass.cc +++ b/tensorflow/compiler/tf2xla/mlir_bridge_pass.cc @@ -159,18 +159,12 @@ bool EnableNonTpuBridge(const Graph& graph) { // // The config_proto param is a required input for all TF1 graphs but it is // redundant for TF2 graphs. 
-MlirOptimizationPassState MlirBridgePass::GetPassState( - const DeviceSet* device_set, const ConfigProto& config_proto, - const Graph& graph, - const FunctionLibraryDefinition& function_library) const { +MlirOptimizationPassState GetPassStateImpl( + bool run_tpu_bridge, const ConfigProto& config_proto, const Graph& graph, + const FunctionLibraryDefinition& function_library) { // Skip MLIR TF/XLA Bridge if no TPU devices and no qualified CPU/GPU // graphs are found. - bool has_tpu_device = device_set ? HasTPUDevice(*device_set) : false; - // GetPassState is called once before MlirBridgePass starts, and the pass - // gets skipped if it is disabled. Log such cases in this function. The cases - // where the pass is enabled will only be logged during their execution to - // prevent them from being counted twice. - if (device_set && !has_tpu_device && !EnableNonTpuBridge(graph)) { + if (!run_tpu_bridge && !EnableNonTpuBridge(graph)) { // Only record CPU/GPU graphs that are qualified but filtered out if (HasQualifiedNonTPUOp(graph)) { metrics::UpdateTfMlirBridgeFirstPhaseCounter( @@ -184,11 +178,17 @@ MlirOptimizationPassState MlirBridgePass::GetPassState( // We set `uses_uninitialized_resource_args` to false here because the first // phase of the bridge is not affected by uninitialized resource args. + // GetMlirBridgeRolloutPolicy will analyze a TPU graph if users have not + // explicltly requested a policy. MlirBridgeRolloutPolicy policy = GetMlirBridgeRolloutPolicy( - graph, &function_library, config_proto, /*is_tpu_graph*/ has_tpu_device, + graph, &function_library, config_proto, /*run_tpu_bridge*/ run_tpu_bridge, /*uses_uninitialized_resource_args=*/false, /*is_v1_compat=*/false, /*record_stats=*/false); - if (has_tpu_device) { + // GetPassState is called once before MlirBridgePass starts, and the pass + // gets skipped if it is disabled. Log such cases in this function. 
The cases + // where the pass is enabled will only be logged during their execution to + // prevent them from being counted twice. + if (run_tpu_bridge) { switch (policy) { case MlirBridgeRolloutPolicy::kEnabledByUser: return MlirOptimizationPassState::Enabled; @@ -236,6 +236,20 @@ MlirOptimizationPassState MlirBridgePass::GetPassState( } } +MlirOptimizationPassState MlirBridgePass::GetPassState( + const DeviceSet* device_set, const ConfigProto& config_proto, + const Graph& graph, + const FunctionLibraryDefinition& function_library) const { + if (!device_set) { + // This is not expected in practice. + VLOG(1) << "Device set is empty!"; + return MlirOptimizationPassState::Disabled; + } + + return GetPassStateImpl(/*run_tpu_bridge*/ HasTPUDevice(*device_set), + config_proto, graph, function_library); +} + // This runs the first phase of the "bridge", transforming the graph in a form // that can be executed with delegation of some computations to an accelerator. // This builds on the model of XLA where a subset of the graph is encapsulated @@ -252,22 +266,17 @@ Status MlirBridgePass::Run(const std::string& function_name, // Check if there are TPU devices or TPU ops. If not, then check if the // non TPU graph is qualified to run TF2XLA Bridge. // This check needs to precede GetPassState for instrumentation purposes. - bool is_qualified_for_tpu_bridge = HasTPUDevicesAndOps(module), - is_qualified_for_non_tpu_bridge = false; - if (!is_qualified_for_tpu_bridge) - is_qualified_for_non_tpu_bridge = EnableNonTpuBridge(graph); - if (!is_qualified_for_tpu_bridge && !is_qualified_for_non_tpu_bridge) { + bool run_tpu_bridge = HasTPUDevicesAndOps(module); + if (!run_tpu_bridge && !HasQualifiedNonTPUOp(graph)) { VLOG(1) << "Skipping MLIR TF2XLA Bridge, no qualified devices or ops found."; return OkStatus(); } - // Set device_set to nullptr here as the device specific checks are performed - // based on the devices in the module. 
// TODO(b/241853328): Add caching of pass state and call logging/metrics // related to graph analysis from here. - auto pass_state = GetPassState(/*device_set=*/nullptr, config_proto, graph, - function_library); + auto pass_state = + GetPassStateImpl(run_tpu_bridge, config_proto, graph, function_library); if (pass_state == MlirOptimizationPassState::Disabled) { // GetPassState is called before run() and run() will only be called if the @@ -278,7 +287,7 @@ Status MlirBridgePass::Run(const std::string& function_name, return OkStatus(); } - if (is_qualified_for_tpu_bridge) { + if (run_tpu_bridge) { bool fallback_enabled = false; if (pass_state == MlirOptimizationPassState::FallbackEnabled) { // We set `uses_uninitialized_resource_args` to false here because the @@ -310,7 +319,7 @@ MlirOptimizationPassState MlirBridgeV1CompatPass::GetPassState( // phase of the bridge is not affected by uninitialized resource args. MlirBridgeRolloutPolicy policy = GetMlirBridgeRolloutPolicy( graph, /*function_library=*/&function_library, config_proto, - /*is_tpu_graph*/ true, + /*run_tpu_bridge*/ true, /*uses_uninitialized_resource_args=*/false, /*is_v1_compat=*/true, /*record_stats=*/false); switch (policy) { diff --git a/tensorflow/compiler/tf2xla/ops/xla_ops.cc b/tensorflow/compiler/tf2xla/ops/xla_ops.cc index e536ffa3746..2d80bf3c2a2 100644 --- a/tensorflow/compiler/tf2xla/ops/xla_ops.cc +++ b/tensorflow/compiler/tf2xla/ops/xla_ops.cc @@ -1326,6 +1326,9 @@ REGISTER_OP("XlaCallModule") .Attr("dim_args_spec: list(string) = []") .Attr("platforms: list(string) = []") .Attr("function_list: list(func) = []") + .Attr("has_token_input_output: bool = false") + .Attr("disabled_checks: list(string) = []") + .SetIsStateful() .SetShapeFn([](shape_inference::InferenceContext* c) { std::vector args_shapes; TF_RETURN_IF_ERROR(c->input("args", &args_shapes)); @@ -1361,7 +1364,8 @@ version: Tracks changes the semantics of the op, to support backwards version 3, the op also supports the `platforms` 
attribute. From version 4, the op carries a StableHLO module with compatibility guarantees. From version 5, XLACallModule can include `stablehlo.custom_call` op to execute tf - functions. + functions. From version 6 the op supports the `disabled_checks` attribute. + See more versioning details at https://github.com/search?q=repo%3Atensorflow%2Ftensorflow+path%3Axla_call_module+%22int+VERSION_MAXIMUM_SUPPORTED%22&type=code. module: A serialized computation, a text or bytecode representation of an mlir.Module. The return type must be a tuple if and only if the `Sout` is a list with 0 or more than 1 elements. The length of `Tout` and @@ -1369,15 +1373,16 @@ module: A serialized computation, a text or bytecode representation of module returns a single result. Tout: List of output tensor data types. Sout: List of output tensor shapes. -platforms: the list of platforms supported by `module`. If the list is empty, - the `module` is platform independent or there should be no platform checking - or preprocessing. The list can contain the strings "CPU", "CUDA", "ROCM", - or "TPU". - If the list is not empty then it is an error to compile this op for a - platform that does not appear in the list. If the list contains more than +platforms: the list of platforms supported by `module`. The list can contain + the strings "CPU", "CUDA", "ROCM", or "TPU". It is an error to compile + this op for a platform that does not appear in the list. This check can be + disabled using `disabled_checks`. If the list contains more than one platform, then the `module` takes one additional 0-dimensional integer-tensor parameter in the first position, encoding the index in - `platforms` of the current compilation platform. + `platforms` of the current compilation platform. This parameter has value 0 + if the plaform is not among `platforms` and the check has been disabled. 
+ The list can be empty in old versions (earlier than 6) to denote that no + platform checking must be performed at loading time. dim_args_spec: in presence of dynamic shapes, this is the specification for the dimension arguments. In absence of dynamic shapes this list is empty. The `module` takes one 0-dimensional integer tensor dimension argument for each @@ -1386,11 +1391,26 @@ dim_args_spec: in presence of dynamic shapes, this is the specification for the string of the form "." that specifies that the value of the corresponding dimension argument must be "args[arg_idx].shape[axis_idx]", where "args" are the actual array arguments. + This attribute is not used anymore in modules serialized with version 5 + after March 28th, 2023 and JAX OSS versions higher than 0.4.6. + TODO(b/283439649): remove support for dim_args_spec. function_list: This list contains the TensorFlow FunctionDefs that are used by the XLACallModule. If the XLACallModule contains `stablehlo.custom_call` operations, they can call TensorFlow graph functions outside of the XLACallModule. This `function_list` attribute registers the dependency of the XLACallModule on those functions. This attribute was added in version 5. +has_token_input_output: If true, the embedded StableHLO module's main function + must take a `!stablehlo.token` as its first argument and returns a token as + its first result. This can be used in conjunction with the TF2XLA's side + effect mechanism in order to model side effects. +disabled_checks: A list of strings describing the safety checks that were + disabled at serialization time. This attribute was added in version 6. The + following directives are recognized: "platform" (allow a compilation platform + that is not among the `platforms`); "custom_call:xxx" (allow a custom call + with target function name "xxx" even if it is not known to JAX to be stable). 
+ This list, supplemented with a comma-separate list of directives specified + using the flag --tf_xla_call_module_disabled_checks, + is used at module loading time to skip the corresponding checks. )doc"); } // namespace diff --git a/tensorflow/compiler/tf2xla/python/BUILD b/tensorflow/compiler/tf2xla/python/BUILD index 815fc42b44a..c5e7ca1fbf4 100644 --- a/tensorflow/compiler/tf2xla/python/BUILD +++ b/tensorflow/compiler/tf2xla/python/BUILD @@ -33,5 +33,6 @@ tf_custom_op_py_library( deps = [ "//tensorflow/compiler/tf2xla/ops:gen_xla_ops", "//tensorflow/compiler/xla:xla_data_proto_py", + "//tensorflow/python/ops/numpy_ops:np_utils", ], ) diff --git a/tensorflow/compiler/tf2xla/python/xla.py b/tensorflow/compiler/tf2xla/python/xla.py index 61d2be76ac1..620535d40c2 100644 --- a/tensorflow/compiler/tf2xla/python/xla.py +++ b/tensorflow/compiler/tf2xla/python/xla.py @@ -604,12 +604,54 @@ def custom_call_v2( ) -def call_module(args, *, version=4, module, Tout, Sout, - dim_args_spec=(), platforms=(), function_list=()): - # See documentation for the XlaCallModule op. - return gen_xla_ops.xla_call_module( - args, version=version, module=module, dim_args_spec=dim_args_spec, - Tout=Tout, Sout=Sout, platforms=platforms, function_list=function_list) +# pylint: disable=g-doc-args +# pylint: disable=g-doc-return-or-yield +def call_module( + args, + *, + version=4, + module, + Tout, + Sout, + platforms=(), + function_list=(), + has_token_input_output=False, + disabled_checks=(), +): + """See documentation for the XlaCallModule op. 
+ + https://github.com/search?q=repo%3Atensorflow%2Ftensorflow+path%3Axla_ops.cc+xlacallmodule&type=code + """ + res = gen_xla_ops.xla_call_module( + args, + version=version, + module=module, + dim_args_spec=(), + Tout=Tout, + Sout=Sout, + platforms=platforms, + function_list=function_list, + has_token_input_output=has_token_input_output, + disabled_checks=disabled_checks, + ) + # Since XLACallModule op is stateful, zero return function will return the TF + # op under tf.function. It creates trouble for downstream codes. + # Here we force it return empty tuple to work around it. + # TODO(johnqiangzhang): Figure out a better way to handle control dependency. + if isinstance(res, ops.Operation): + res = () + return res +# pylint: enable=g-doc-args +# pylint: enable=g-doc-return-or-yield + + +def call_module_maximum_supported_version(): + return 6 + + +def call_module_disable_check_platform(): + # For use with xla_call_module.disabled_checks. + return "platform" def gather(operand, diff --git a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc index 1c24cffa93d..a30c29259d9 100644 --- a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc +++ b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc @@ -16,29 +16,65 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h" #include +#include #include #include "tensorflow/compiler/xla/cpu_function_runtime.h" +#include "tensorflow/compiler/xla/runtime/aot_ffi_execution_context.h" namespace tensorflow { +namespace { +// MemrefDesc's are part of the XLA Runtime ABI. Redefine them here (with a +// slightly different name to avoid confusion) because we cannot depend on +// XLA Runtime's headers. +// Note: this is an internal type, to be used exclusively in this file. 
+struct MemrefHolder { + MemrefHolder(const XlaCompiledCpuFunction::ShapeInfo& shape_info, + void* data_ptr) + : rank(shape_info.num_dimensions), data(data_ptr), offset(0) { + sizes.resize(shape_info.num_dimensions); + strides.resize(shape_info.num_dimensions); + int64_t multiplier = 1; + for (int i = shape_info.num_dimensions - 1; i >= 0; --i) { + int64_t size = shape_info.dimensions[i]; + sizes[i] = size; + strides[i] = multiplier; + multiplier *= size; + } + } + + unsigned rank = 0; + // Note: dtype is not needed here. + void* data = nullptr; + int64_t offset = 0; + std::vector sizes; + std::vector strides; +}; +} // namespace + XlaCompiledCpuFunction::XlaCompiledCpuFunction(const StaticData& static_data, AllocMode alloc_mode) : raw_function_(static_data.raw_function_), - run_function_(static_data.run_function_), + external_run_function_(static_data.external_run_function_), cpu_executable_(static_data.cpu_executable_), result_index_(static_data.result_index_), buffer_table_(new void*[static_data.num_buffers_]), buffer_infos_(static_data.buffer_infos_), num_buffers_(static_data.num_buffers_), + num_results_(static_data.num_results_), + result_index_table_(static_data.result_index_table_), arg_index_table_(static_data.arg_index_table_), num_args_(static_data.num_args_), num_variables_(static_data.num_variables_), + arg_shape_infos_(static_data.arg_shape_infos_), + result_shape_infos_(static_data.result_shape_infos_), arg_names_(static_data.arg_names_), variable_names_(static_data.variable_names_), result_names_(static_data.result_names_), program_shape_(static_data.program_shape_), - hlo_profile_printer_data_(static_data.hlo_profile_printer_data_) { + hlo_profile_printer_data_(static_data.hlo_profile_printer_data_), + use_xla_runtime_(static_data.use_xla_runtime_) { bool allocate_entry_params = alloc_mode == AllocMode::ARGS_VARIABLES_RESULTS_PROFILES_AND_TEMPS; // Allocate arg and temp buffers. 
@@ -56,11 +92,75 @@ XlaCompiledCpuFunction::XlaCompiledCpuFunction(const StaticData& static_data, } } +bool XlaCompiledCpuFunction::RunXlaRuntime() { + size_t num_memref_args = num_args_ + num_results_; + std::vector memref_args; + memref_args.reserve(num_memref_args); + + size_t num_ptrs = 1; // execution context. + + // Append arguments. + for (int i = 0; i < num_args_; ++i) { + const ShapeInfo& shape_info = arg_shape_infos_[i]; + memref_args.emplace_back(shape_info, buffer_table_[arg_index_table_[i]]); + num_ptrs += 3 + 2 * shape_info.num_dimensions; + } + + // Append results. + for (int i = 0; i < num_results_; ++i) { + const ShapeInfo& shape_info = result_shape_infos_[i]; + memref_args.emplace_back(shape_info, buffer_table_[result_index_table_[i]]); + num_ptrs += 3 + 2 * shape_info.num_dimensions; + + // Point to this result from the "result" entry in the buffer table. + void** results = static_cast(buffer_table_[result_index_]); + results[i] = buffer_table_[result_index_table_[i]]; + } + + std::vector call_frame; + call_frame.resize(num_ptrs); + size_t ptr_index = 1; + for (const MemrefHolder& memref : memref_args) { + auto cast = [](const void* p) { return const_cast(p); }; + call_frame[ptr_index + 0] = cast(&memref.data); // memref.basePtr + call_frame[ptr_index + 1] = cast(&memref.data); // memref.data + call_frame[ptr_index + 2] = cast(&memref.offset); + unsigned rank = memref.rank; + for (int64_t d = 0; d < rank; ++d) { + call_frame[ptr_index + 3 + d] = cast(&memref.sizes[d]); + call_frame[ptr_index + 3 + d + rank] = cast(&memref.strides[d]); + } + ptr_index += 3 + 2 * rank; + } + + assert(num_ptrs == ptr_index); + + xla::runtime::aot::ExecutionContext execution_context; + execution_context.custom_call_data = &run_options_; + xla::runtime::aot::ExecutionContext* execution_context_ptr = + &execution_context; + call_frame[0] = &execution_context_ptr; + + auto xla_runtime_func = + reinterpret_cast(raw_function_); + xla_runtime_func(call_frame.data()); + if 
(execution_context.error) { + // No error support in XLA; dump error message to stderr. + std::cerr << "XLA AOT error: " << execution_context.error << ".\n"; + return false; + } + return true; +} + bool XlaCompiledCpuFunction::Run() { - if (run_function_) { + if (use_xla_runtime_) { + return RunXlaRuntime(); + } + if (external_run_function_) { std::vector descriptor_table = MakeXlaRuntimeDescriptorTable(); - return run_function_(cpu_executable_, descriptor_table, &run_options_); + return external_run_function_(cpu_executable_, descriptor_table, + &run_options_); } XlaCustomCallStatus status; raw_function_(buffer_table_[result_index_], &run_options_, nullptr, diff --git a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h index 176f203e924..bde21d559c5 100644 --- a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h +++ b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h @@ -55,16 +55,30 @@ namespace tensorflow { // is guaranteed that no thread may call a non-const method. class XlaCompiledCpuFunction { public: - // Type of the raw function, produced by either JIT or AOT. + // Type of the raw XLA Classic function, produced by either JIT or AOT. using RawFunction = void (*)(void* result, const xla::ExecutableRunOptions* run_options, const void** args, void** temps, XlaCustomCallStatus*, int64_t* profile_counters); - using RunFunction = + + // Signature of the XLA Runtime raw function. Used only by XLA Runtime AOT. + using XlaRuntimeRawFunction = void (*)(void**); + + // Signature of an external run function. Used only by XLA Runtime JIT. + using ExternalRunFunction = bool (*)(const xla::cpu::CpuExecutable* cpu_executable, const std::vector& descriptor_table, const xla::ExecutableRunOptions* run_options); + // Simple struct to describe a tensor's shape. + // Note: this is a poor man's substitute for xla::ShapeProto, but we cannot + // depend on protobuf's in this library. 
+ // TODO(ecg): extend ShapeInfo to support tuples, if needed. + struct ShapeInfo { + const int32_t* dimensions = nullptr; + int32_t num_dimensions = 0; + }; + // StaticData represents the state necessary to run an XLA-compiled // function. For JIT this is backed by data in XlaJitCompiledCpuFunction; for // AOT this is backed by data compiled into the object file. @@ -76,13 +90,20 @@ class XlaCompiledCpuFunction { // The raw function to call. RawFunction raw_function_; - RunFunction run_function_ = nullptr; + ExternalRunFunction external_run_function_ = nullptr; const xla::cpu::CpuExecutable* cpu_executable_ = nullptr; // Contains information about the buffers used by the XLA computation. const xla::cpu_function_runtime::BufferInfo* buffer_infos_ = nullptr; int32_t num_buffers_ = 0; + // Result parameter i is described by + // buffer_infos[result_index_table[i]]. + const int32* result_index_table_ = nullptr; + + // There are num_results result parameters. + int64_t num_results_ = 0; + // Entry parameter i is described by // buffer_infos[arg_index_table[i]]. const int32* arg_index_table_ = nullptr; @@ -96,6 +117,9 @@ class XlaCompiledCpuFunction { // The 0-based index of the result tuple, in the temp buffers. size_t result_index_ = 0; + const ShapeInfo* arg_shape_infos_ = nullptr; + const ShapeInfo* result_shape_infos_ = nullptr; + // [Optional] Arrays of arg and result names. These are arrays of C-style // strings, where the array is terminated by nullptr. const char** arg_names_ = nullptr; @@ -115,6 +139,8 @@ class XlaCompiledCpuFunction { // declared so we don't have access to that information here. int64_t profile_counters_size_ = 0; + bool use_xla_runtime_ = false; + // Only XlaCompiledCpuFunction is allowed to read and write the above // fields. 
friend class XlaCompiledCpuFunction; @@ -166,6 +192,8 @@ class XlaCompiledCpuFunction { return buffer_table_[arg_index_table_[index]]; } + int num_results() const { return num_results_; } + int num_args() const { return num_args_; } int num_variables() const { return num_variables_; } @@ -291,9 +319,9 @@ class XlaCompiledCpuFunction { static_data->raw_function_ = raw_function; } - static void set_static_data_run_function(StaticData* static_data, - RunFunction run_function) { - static_data->run_function_ = run_function; + static void set_static_data_external_run_function( + StaticData* static_data, ExternalRunFunction external_run_function) { + static_data->external_run_function_ = external_run_function; } static void set_static_data_cpu_executable( @@ -312,6 +340,16 @@ class XlaCompiledCpuFunction { static_data->num_buffers_ = num_buffers; } + static void set_static_data_result_index_table( + StaticData* static_data, const int32* result_index_table) { + static_data->result_index_table_ = result_index_table; + } + + static void set_static_data_num_results(StaticData* static_data, + int64_t num_results) { + static_data->num_results_ = num_results; + } + static void set_static_data_arg_index_table(StaticData* static_data, const int32* arg_index_table) { static_data->arg_index_table_ = arg_index_table; @@ -332,6 +370,16 @@ class XlaCompiledCpuFunction { static_data->result_index_ = result_index; } + static void set_static_data_arg_shape_infos(StaticData* static_data, + const ShapeInfo* shape_infos) { + static_data->arg_shape_infos_ = shape_infos; + } + + static void set_static_data_result_shape_infos(StaticData* static_data, + const ShapeInfo* shape_infos) { + static_data->result_shape_infos_ = shape_infos; + } + static void set_static_data_arg_names(StaticData* static_data, const char** arg_names) { static_data->arg_names_ = arg_names; @@ -368,14 +416,19 @@ class XlaCompiledCpuFunction { static_data->profile_counters_size_ = profile_counters_size; } + static void 
set_static_data_use_xla_runtime(StaticData* static_data, + bool use_xla_runtime) { + static_data->use_xla_runtime_ = use_xla_runtime; + } + private: const RawFunction raw_function_; - // TODO(ecg): RunFunction and CpuExecutable should go away. Instead, we should - // have a pointer or reference to a minimal wrapper around CpuExecutable's - // Execute(), without CpuExecutable's dependences. We could call this wrapper - // "XlaRuntimeRunner". - const RunFunction run_function_; + + // [Optional] External Run() function. + const ExternalRunFunction external_run_function_; + // [Maybe Optional] CpuExecutable to be passed to external_run_function_. const xla::cpu::CpuExecutable* cpu_executable_; + const size_t result_index_; // Array containing pointers to argument and temp buffers (slots corresponding @@ -386,6 +439,10 @@ class XlaCompiledCpuFunction { const xla::cpu_function_runtime::BufferInfo* const buffer_infos_; const int32 num_buffers_; + // Indices of expanded result tuple. + const int32 num_results_; + const int32* const result_index_table_; + // Argument i needs to be placed in buffer_table_[arg_index_to_temp_index_[i]] // for XLA generated code to be able to find it. const int32* const arg_index_table_; @@ -396,6 +453,12 @@ class XlaCompiledCpuFunction { // The number of incoming variables. const int32 num_variables_; + // Shapes of the input arguments. + const ShapeInfo* const arg_shape_infos_; + + // Shapes of the results. + const ShapeInfo* const result_shape_infos_; + // Backing memory for buffer_table_ and args_, the latter depending on // AllocMode. void* alloc_buffer_table_ = nullptr; @@ -413,9 +476,13 @@ class XlaCompiledCpuFunction { const xla::ProgramShapeProto* program_shape_ = nullptr; const xla::HloProfilePrinterData* hlo_profile_printer_data_ = nullptr; + const bool use_xla_runtime_ = false; + // Creates a descriptor table for XLA Runtime. 
std::vector MakeXlaRuntimeDescriptorTable(); + bool RunXlaRuntime(); + // Add `XlaJitCompiledCpuFunction` as a friend so that it can access the // `set_static_data_*` static methods above. friend class XlaJitCompiledCpuFunction; diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index ef7c45f0a4b..cc951fe375e 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -16,14 +16,20 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include +#include #include #include +#include +#include #include +#include +#include #include #include "tensorflow/compiler/mlir/tf2xla/mlir_bridge_rollout_policy.h" #include "absl/container/flat_hash_map.h" #include "absl/memory/memory.h" +#include "absl/status/status.h" #include "absl/types/variant.h" #include "tensorflow/compiler/jit/defs.h" #include "tensorflow/compiler/jit/flags.h" @@ -572,7 +578,7 @@ Status XlaCompiler::FindFunctionBody(const NameAttrList& function, // function in flib_runtime_. 
auto status = GetFunctionBody(function, local_flib_runtime_, fbody); if (!status.ok()) { - if (!errors::IsNotFound(status)) { + if (!absl::IsNotFound(status)) { return status; } TF_RETURN_WITH_CONTEXT_IF_ERROR( @@ -750,8 +756,8 @@ Status XlaCompiler::CompileSingleOp( auto compile_with_old_bridge = [&]() { *result = {}; - return CompileGraph(compile_options, node_def.name(), std::move(graph), - args, result); + return ADD_SOURCE_LOCATION(CompileGraph(compile_options, node_def.name(), + std::move(graph), args, result)); }; const ConfigProto* config = &(single_op_compile_argument.config_proto); @@ -1426,6 +1432,11 @@ Status XlaCompiler::CompileGraph( std::unique_ptr graph, absl::Span args, CompilationResult* result) { VLOG(1) << "Executing graph symbolically to populate XlaBuilder.: " << name; + if (VLOG_IS_ON(2)) { + VLOG(2) << "XlaCompiler::CompileGraph: " + << DumpGraphToFile(absl::StrCat("xla_compile_graph_", name), *graph, + flib_runtime_->GetFunctionLibraryDefinition()); + } DummyStackTrace stack_trace; for (auto node : graph->nodes()) { @@ -1443,12 +1454,6 @@ Status XlaCompiler::CompileGraph( graph.get(), local_flib_def_.get(), pflr_->GetFunctionLibraryDefinition())); - if (VLOG_IS_ON(2)) { - VLOG(2) << "XlaCompiler::CompileGraph: " - << DumpGraphToFile(absl::StrCat("xla_compile_graph_", name), *graph, - flib_runtime_->GetFunctionLibraryDefinition()); - } - // Report the error here if initialization failed. TF_RETURN_IF_ERROR(initialization_status_); @@ -1456,8 +1461,8 @@ Status XlaCompiler::CompileGraph( // FunctionalizeControlFlow may remove some nodes from the graph. 
TF_RETURN_IF_ERROR(ValidateGraph(graph.get(), *options_.flib_def, options_.device_type, name)); - xla::XlaBuilder builder(name); - XlaContext* context = new XlaContext(this, &builder, graph.get()); + auto builder = std::make_unique(name); + XlaContext* context = new XlaContext(this, builder.get(), graph.get()); core::ScopedUnref context_unref(context); std::vector real_args(args.begin(), args.end()); @@ -1479,7 +1484,7 @@ Status XlaCompiler::CompileGraph( std::vector arg_expressions; TF_RETURN_IF_ERROR(BuildArguments( - *graph, real_args, options.use_tuple_arg, &builder, context, + *graph, real_args, options.use_tuple_arg, builder.get(), context, arg_shardings, &arg_expressions, &result->input_mapping, &result->xla_input_shapes, options.is_entry_computation)); context->set_args(std::move(arg_expressions)); @@ -1505,7 +1510,7 @@ Status XlaCompiler::CompileGraph( // Original token is manually created. if (HasSideEffectingNodes(*graph)) { TF_RETURN_IF_ERROR( - SetNodeToken(kXlaTokenArgNodeName, xla::CreateToken(&builder))); + SetNodeToken(kXlaTokenArgNodeName, xla::CreateToken(builder.get()))); } } @@ -1523,7 +1528,8 @@ Status XlaCompiler::CompileGraph( TF_RETURN_IF_ERROR(token_or.status()); token_inputs.push_back(token_or.value()); } - token_output.reset(new xla::XlaOp(xla::AfterAll(&builder, token_inputs))); + token_output = std::make_unique( + xla::AfterAll(builder.get(), token_inputs)); } TF_RETURN_IF_ERROR(PopNodeTokenMapping()); @@ -1532,7 +1538,8 @@ Status XlaCompiler::CompileGraph( result->computation = std::make_shared(); result->outputs.resize(context->retvals().size()); std::vector retvals = context->retvals(); - ConvertConstantsToExpressions(&builder, absl::Span(retvals)); + ConvertConstantsToExpressions(builder.get(), + absl::Span(retvals)); XlaShapeLayoutHelpers::ShapeDeterminationFns shape_determination_fns{ UseNoPreferenceLayoutFn(), IdentityShapeRepresentationFn()}; TF_RETURN_IF_ERROR(BuildComputation( @@ -1543,7 +1550,7 @@ Status 
XlaCompiler::CompileGraph( options.is_entry_computation, options.return_updated_values_for_all_resources, options.always_return_tuple, options.use_tuple_arg, - options.alias_resource_update, &builder, result->computation.get(), + options.alias_resource_update, builder.get(), result->computation.get(), &num_computation_outputs, &num_nonconst_outputs, &result->outputs, &result->resource_updates, &result->xla_output_shape, result->input_mapping)); diff --git a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc index 5c54551707b..a8cca7befd4 100644 --- a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc +++ b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc @@ -140,7 +140,8 @@ XlaJitCompiledCpuFunction::Compile( // Compute buffer infos and the result index, needed to run the raw function. std::vector buffer_infos = - xla::cpu::CreateBufferInfosFromBufferAssignment(buffer_assignment); + xla::cpu::CreateBufferInfosFromBufferAssignment(cpu_executable->module(), + buffer_assignment); std::vector arg_index_table = xla::cpu::CreateArgIndexTableFromBufferInfos(buffer_infos); TF_ASSIGN_OR_RETURN(size_t result_index, @@ -157,8 +158,8 @@ XlaJitCompiledCpuFunction::Compile( XlaCompiledCpuFunction::set_static_data_raw_function(&jit->static_data_, raw_function); if (cpu_executable->IsXlaRuntime()) { - XlaCompiledCpuFunction::set_static_data_run_function(&jit->static_data_, - RunXlaRuntime); + XlaCompiledCpuFunction::set_static_data_external_run_function( + &jit->static_data_, RunXlaRuntime); XlaCompiledCpuFunction::set_static_data_cpu_executable(&jit->static_data_, cpu_executable); } diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD index f29391811ed..7fc7ba9f111 100644 --- a/tensorflow/compiler/xla/BUILD +++ b/tensorflow/compiler/xla/BUILD @@ -121,8 +121,6 @@ xla_cc_test( deps = [ ":bit_cast", ":test", - "//tensorflow/tsl/platform:logging", - 
"//tensorflow/tsl/platform:status", "//tensorflow/tsl/platform:test_main", ], ) @@ -219,7 +217,6 @@ xla_cc_test( deps = [ ":test", ":types", - "//tensorflow/tsl/platform:logging", "//tensorflow/tsl/platform:test_main", ], ) @@ -231,7 +228,6 @@ cc_library( visibility = [":friends"], deps = [ ":status", - ":xla_cc_grpc_proto", ":xla_data_proto_cc", ":xla_proto_cc", ], @@ -263,7 +259,6 @@ xla_cc_test( ":test", ":test_helpers", "//tensorflow/tsl/platform:errors", - "//tensorflow/tsl/platform:logging", "//tensorflow/tsl/platform:test_main", ], ) @@ -377,7 +372,6 @@ cc_library( ":types", ":util", "//tensorflow/tsl/platform:env", - "//tensorflow/tsl/platform:logging", "//tensorflow/tsl/platform:path", "//tensorflow/tsl/platform:protobuf", "@com_google_absl//absl/hash", @@ -500,7 +494,6 @@ xla_cc_test( ":types", ":util", ":xla_data_proto_cc", - "//tensorflow/tsl/platform:logging", "//tensorflow/tsl/platform:test_main", ], ) @@ -563,6 +556,7 @@ cc_library( "//tensorflow/tsl/platform:protobuf", "//tensorflow/tsl/platform:status", "//tensorflow/tsl/util:byte_swap_array", + "//third_party/eigen3", "@com_google_absl//absl/base", "@com_google_absl//absl/functional:function_ref", "@com_google_absl//absl/strings", @@ -585,7 +579,6 @@ xla_cc_test( ":types", "//tensorflow/tsl/lib/core:status_test_util", "//tensorflow/tsl/platform:float8", - "//tensorflow/tsl/platform:test", "//tensorflow/tsl/platform:test_benchmark", "//tensorflow/tsl/platform:test_main", "@com_google_absl//absl/base", @@ -653,7 +646,6 @@ cc_library( visibility = ["//visibility:public"], deps = [ "//tensorflow/tsl/platform:logging", - "//tensorflow/tsl/platform:status", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", @@ -679,7 +671,6 @@ cc_library( deps = [ ":status", ":types", - "//tensorflow/tsl/platform:logging", "@com_google_absl//absl/functional:function_ref", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", 
@@ -880,8 +871,6 @@ xla_cc_test( ":types", "//tensorflow/tsl/lib/core:status_test_util", "//tensorflow/tsl/platform:env", - "//tensorflow/tsl/platform:logging", - "//tensorflow/tsl/platform:test", "//tensorflow/tsl/platform:test_main", ], ) @@ -912,10 +901,8 @@ xla_cc_test( ":shape_util", ":test", ":xla_data_proto_cc", - "//tensorflow/tsl/platform:test", "//tensorflow/tsl/platform:test_benchmark", "//tensorflow/tsl/platform:test_main", - "@com_google_absl//absl/memory", ], ) @@ -997,12 +984,10 @@ xla_cc_test( ":literal", ":reference_util", ":test", - ":util", ":xla_data_proto_cc", "//tensorflow/compiler/xla/client:padding", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/tsl/platform:test_main", - "@com_google_absl//absl/memory", ], ) @@ -1081,12 +1066,9 @@ xla_cc_test( deps = [ ":xla_proto_cc", - "//tensorflow/compiler/xla/hlo/ir:hlo", - "//tensorflow/tsl/platform:logging", "//tensorflow/tsl/platform:test", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", ], ) diff --git a/tensorflow/compiler/xla/array.h b/tensorflow/compiler/xla/array.h index 3238ffdf53d..55be3df5890 100644 --- a/tensorflow/compiler/xla/array.h +++ b/tensorflow/compiler/xla/array.h @@ -38,6 +38,10 @@ namespace xla { namespace array_impl { +template +using overload_for_float = std::enable_if_t< + is_specialized_floating_point_v && std::is_same::value, bool>; + // A type trait that is valid when all elements in a parameter pack are of // integral type. Not using an alias template to work around MSVC 14.00 bug. template @@ -110,12 +114,7 @@ class Array { // Creates a 1D array of a floating-point type (half, bfloat16, float, // or double) from an initializer list of float values. 
- template ::value || - std::is_same::value || - std::is_same::value || - std::is_same::value) && - std::is_same::value>::type> + template = true> Array(std::initializer_list values) : Array(ToInt64Array(values), no_default_init_t{}) { int64_t idx = 0; @@ -128,14 +127,7 @@ class Array { // Creates a 2D array of a floating-point type (float8, half, bfloat16, float, // or double) from an initializer list of float values. - template ::value || - std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value) && - std::is_same::value>::type> + template = true> Array(std::initializer_list> values) : Array(ToInt64Array(values), no_default_init_t{}) { int64_t idx = 0; @@ -166,12 +158,7 @@ class Array { // Creates a 3D array of a floating-point type (half, bfloat16, float, // or double) from an initializer list of float values. - template ::value || - std::is_same::value || - std::is_same::value || - std::is_same::value) && - std::is_same::value>::type> + template = true> Array(std::initializer_list>> values) : Array(ToInt64Array(values), no_default_init_t{}) { @@ -207,12 +194,7 @@ class Array { // Creates a 4D array of a floating-point type (half, bfloat16, float, // or double) from an initializer list of float values. - template ::value || - std::is_same::value || - std::is_same::value || - std::is_same::value) && - std::is_same::value>::type> + template = true> Array(std::initializer_list< std::initializer_list>>> values) @@ -510,6 +492,17 @@ class Array { // Performs a permutation of dimensions. 
void TransposeDimensions(absl::Span permutation) { + return TransposeDimensionsImpl(permutation); + } + void TransposeDimensions(absl::Span permutation) { + return TransposeDimensionsImpl(permutation); + } + void TransposeDimensions(std::initializer_list permutation) { + return TransposeDimensionsImpl(permutation); + } + template >* = nullptr> + void TransposeDimensionsImpl(absl::Span permutation) { CHECK_EQ(sizes_.size, permutation.size()); OwnedBuffer permuted_dims(permutation.size()); for (int64_t i = 0; i < permutation.size(); ++i) { diff --git a/tensorflow/compiler/xla/array2d.h b/tensorflow/compiler/xla/array2d.h index 2409fe6268b..834d602956b 100644 --- a/tensorflow/compiler/xla/array2d.h +++ b/tensorflow/compiler/xla/array2d.h @@ -52,15 +52,7 @@ class Array2D : public Array { // Creates an array of a floating-point type (float8, half, bfloat16, float, // or double) from the given nested initializer list of float values. - template ::value || - std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value) && - std::is_same::value>::type> + template = true> Array2D(std::initializer_list> values) : Array(values) {} diff --git a/tensorflow/compiler/xla/array3d.h b/tensorflow/compiler/xla/array3d.h index 5dec480fdad..2fb39461dd9 100644 --- a/tensorflow/compiler/xla/array3d.h +++ b/tensorflow/compiler/xla/array3d.h @@ -57,12 +57,7 @@ class Array3D : public Array { // Creates an array of a floating-point type (half, bfloat16, float, // or double) from the given nested initializer list of float values. 
- template ::value || - std::is_same::value || - std::is_same::value || - std::is_same::value) && - std::is_same::value>::type> + template = true> Array3D( std::initializer_list>> values) diff --git a/tensorflow/compiler/xla/array4d.h b/tensorflow/compiler/xla/array4d.h index 3e75023fd9c..e86f212cd2e 100644 --- a/tensorflow/compiler/xla/array4d.h +++ b/tensorflow/compiler/xla/array4d.h @@ -82,12 +82,7 @@ class Array4D : public Array { // Creates an array of a floating-point type (half, bfloat16, float, // or double) from the given nested initializer list of float values. - template ::value || - std::is_same::value || - std::is_same::value || - std::is_same::value) && - std::is_same::value>::type> + template = true> Array4D(std::initializer_list>>> values) diff --git a/tensorflow/compiler/xla/backends/interpreter/BUILD b/tensorflow/compiler/xla/backends/interpreter/BUILD index f431901a4e8..4017caaf013 100644 --- a/tensorflow/compiler/xla/backends/interpreter/BUILD +++ b/tensorflow/compiler/xla/backends/interpreter/BUILD @@ -56,6 +56,7 @@ cc_library( "//tensorflow/compiler/xla/service:map_inliner", "//tensorflow/compiler/xla/service:qr_expander", "//tensorflow/compiler/xla/service:reshape_mover", + "//tensorflow/compiler/xla/service:topk_rewriter", "//tensorflow/compiler/xla/service:triangular_solve_expander", "//tensorflow/compiler/xla/service:while_loop_simplifier", "//tensorflow/compiler/xla/stream_executor", diff --git a/tensorflow/compiler/xla/backends/interpreter/compiler.cc b/tensorflow/compiler/xla/backends/interpreter/compiler.cc index f614c6be078..d62b46f9051 100644 --- a/tensorflow/compiler/xla/backends/interpreter/compiler.cc +++ b/tensorflow/compiler/xla/backends/interpreter/compiler.cc @@ -38,6 +38,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/map_inliner.h" #include "tensorflow/compiler/xla/service/qr_expander.h" #include "tensorflow/compiler/xla/service/reshape_mover.h" +#include "tensorflow/compiler/xla/service/topk_rewriter.h" #include "tensorflow/compiler/xla/service/triangular_solve_expander.h" #include "tensorflow/compiler/xla/service/while_loop_simplifier.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -81,6 +82,7 @@ StatusOr HandleEvaluatorCustomCall( Status InterpreterCompiler::RunHloOptimization(HloModule* hlo_module) { HloPassPipeline pipeline("Interpreter"); + pipeline.AddPass(); pipeline.AddPass(); pipeline.AddPass(); pipeline.AddPass(); diff --git a/tensorflow/compiler/xla/backends/profiler/BUILD b/tensorflow/compiler/xla/backends/profiler/BUILD index c02ea2c3d0c..5327adc4175 100644 --- a/tensorflow/compiler/xla/backends/profiler/BUILD +++ b/tensorflow/compiler/xla/backends/profiler/BUILD @@ -1,4 +1,4 @@ -load("//tensorflow/tsl:tsl.bzl", "if_libtpu", "tsl_gpu_library") +load("//tensorflow/tsl:tsl.bzl", "if_with_tpu_support", "tsl_gpu_library") # copybara:uncomment package(default_applicable_licenses = ["//tensorflow:license"]) @@ -20,8 +20,10 @@ tsl_gpu_library( deps = [ "//tensorflow/compiler/xla/backends/profiler/cpu:host_tracer", "//tensorflow/compiler/xla/backends/profiler/cpu:metadata_collector", - ] + if_libtpu([ - "//tensorflow/compiler/xla/backends/profiler/tpu:tpu_tracer", - ]), + ] + if_with_tpu_support( + [ + "//tensorflow/compiler/xla/backends/profiler/tpu:tpu_tracer", + ], + ), alwayslink = True, ) diff --git a/tensorflow/compiler/xla/backends/profiler/gpu/BUILD b/tensorflow/compiler/xla/backends/profiler/gpu/BUILD index 5ae87af9e0a..a44c5735b5b 100644 --- a/tensorflow/compiler/xla/backends/profiler/gpu/BUILD +++ b/tensorflow/compiler/xla/backends/profiler/gpu/BUILD @@ -25,11 +25,11 @@ load( ) package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = 
["//tensorflow/compiler/xla:internal"], features = [ "-layering_check", ], - licenses = ["notice"], ) tsl_gpu_library( diff --git a/tensorflow/compiler/xla/client/BUILD b/tensorflow/compiler/xla/client/BUILD index bea85c5179d..a29c8df9d8b 100644 --- a/tensorflow/compiler/xla/client/BUILD +++ b/tensorflow/compiler/xla/client/BUILD @@ -125,7 +125,6 @@ cc_library( ":xla_computation", "//tensorflow/compiler/xla:executable_run_options", "//tensorflow/compiler/xla:shape_tree", - "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:xla_data_proto_cc", "//tensorflow/compiler/xla/service:backend", @@ -142,7 +141,6 @@ cc_library( "//tensorflow/compiler/xla/stream_executor:device_memory_allocator", "//tensorflow/compiler/xla/stream_executor/host:host_platform", "@com_google_absl//absl/types:span", - "@llvm-project//llvm:TargetParser", ], ) @@ -296,7 +294,8 @@ xla_cc_test( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto_cc", "//tensorflow/compiler/xla/hlo/ir:hlo", - "//tensorflow/compiler/xla/hlo/utils:hlo_matchers", + "//tensorflow/compiler/xla/service:pattern_matcher", + "//tensorflow/compiler/xla/service:pattern_matcher_gmock", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/tsl/platform:test", ], diff --git a/tensorflow/compiler/xla/client/client.cc b/tensorflow/compiler/xla/client/client.cc index bfd0e75d143..8f000fb3d8b 100644 --- a/tensorflow/compiler/xla/client/client.cc +++ b/tensorflow/compiler/xla/client/client.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include #include +#include #include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/client/xla_computation.h" diff --git a/tensorflow/compiler/xla/client/client.h b/tensorflow/compiler/xla/client/client.h index afcc953a9f2..fb9b1c19be1 100644 --- a/tensorflow/compiler/xla/client/client.h +++ b/tensorflow/compiler/xla/client/client.h @@ -17,6 +17,8 @@ limitations under the License. 
#define TENSORFLOW_COMPILER_XLA_CLIENT_CLIENT_H_ #include +#include +#include #include #include "absl/types/span.h" diff --git a/tensorflow/compiler/xla/client/client_library.cc b/tensorflow/compiler/xla/client/client_library.cc index 034ce6e927a..1868c59bc3c 100644 --- a/tensorflow/compiler/xla/client/client_library.cc +++ b/tensorflow/compiler/xla/client/client_library.cc @@ -16,6 +16,8 @@ limitations under the License. #include "tensorflow/compiler/xla/client/client_library.h" #include +#include +#include #include #include "tensorflow/compiler/xla/service/backend.h" diff --git a/tensorflow/compiler/xla/client/client_library.h b/tensorflow/compiler/xla/client/client_library.h index d86328320b2..af599f8eef8 100644 --- a/tensorflow/compiler/xla/client/client_library.h +++ b/tensorflow/compiler/xla/client/client_library.h @@ -85,7 +85,7 @@ class ClientLibrary { // created, for the given platform. static StatusOr GetOrCreateLocalClient( se::Platform* platform = nullptr, - const std::optional>& allowed_devices = std::nullopt); + const std::optional>& device_set = std::nullopt); static StatusOr GetOrCreateLocalClient( const LocalClientOptions& options); diff --git a/tensorflow/compiler/xla/client/compile_only_client.cc b/tensorflow/compiler/xla/client/compile_only_client.cc index 62d2057b5fd..2cea9024bf5 100644 --- a/tensorflow/compiler/xla/client/compile_only_client.cc +++ b/tensorflow/compiler/xla/client/compile_only_client.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/compile_only_client.h" #include +#include #include "llvm/ADT/Twine.h" #include "llvm/TargetParser/Triple.h" diff --git a/tensorflow/compiler/xla/client/compile_only_client.h b/tensorflow/compiler/xla/client/compile_only_client.h index 02524eaeb2a..30766ec2dd0 100644 --- a/tensorflow/compiler/xla/client/compile_only_client.h +++ b/tensorflow/compiler/xla/client/compile_only_client.h @@ -16,6 +16,9 @@ limitations under the License. 
#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_COMPILE_ONLY_CLIENT_H_ #define TENSORFLOW_COMPILER_XLA_CLIENT_COMPILE_ONLY_CLIENT_H_ +#include +#include + #include "tensorflow/compiler/xla/client/client.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/service/compile_only_service.h" @@ -52,7 +55,7 @@ class CompileOnlyClient : public Client { // code. |metadata|, if provided, is populated during compilation. StatusOr>> CompileAheadOfTime( - const absl::Span computations, + absl::Span computations, const AotCompilationOptions& options, std::unique_ptr* metadata = nullptr); diff --git a/tensorflow/compiler/xla/client/executable_build_options.cc b/tensorflow/compiler/xla/client/executable_build_options.cc index 30259c323a2..37b21a7d18e 100644 --- a/tensorflow/compiler/xla/client/executable_build_options.cc +++ b/tensorflow/compiler/xla/client/executable_build_options.cc @@ -16,7 +16,9 @@ limitations under the License. #include "tensorflow/compiler/xla/client/executable_build_options.h" #include +#include #include +#include #include "absl/strings/str_format.h" #include "tensorflow/compiler/xla/debug_options_flags.h" diff --git a/tensorflow/compiler/xla/client/executable_build_options.h b/tensorflow/compiler/xla/client/executable_build_options.h index ed4554e9c77..b859589cbf3 100644 --- a/tensorflow/compiler/xla/client/executable_build_options.h +++ b/tensorflow/compiler/xla/client/executable_build_options.h @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include #include #include @@ -138,6 +139,7 @@ class ExecutableBuildOptions { CHECK(device_assignment_.has_value()); return device_assignment_.value(); } + void clear_device_assignment() { device_assignment_.reset(); } // Whether input and output buffers are aliased if the associated parameter is // passed-through XLA modules without being changed. 
diff --git a/tensorflow/compiler/xla/client/global_data.cc b/tensorflow/compiler/xla/client/global_data.cc index 6785be501d5..49aea0c8566 100644 --- a/tensorflow/compiler/xla/client/global_data.cc +++ b/tensorflow/compiler/xla/client/global_data.cc @@ -15,8 +15,10 @@ limitations under the License. #include "tensorflow/compiler/xla/client/global_data.h" +#include #include #include +#include #include "absl/container/flat_hash_map.h" #include "tensorflow/compiler/xla/types.h" diff --git a/tensorflow/compiler/xla/client/lib/BUILD b/tensorflow/compiler/xla/client/lib/BUILD index 9d46f68b6fb..bc7ba74322a 100644 --- a/tensorflow/compiler/xla/client/lib/BUILD +++ b/tensorflow/compiler/xla/client/lib/BUILD @@ -103,6 +103,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto_cc", "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/tsl/platform:float8", ], ) diff --git a/tensorflow/compiler/xla/client/lib/approx_topk.cc b/tensorflow/compiler/xla/client/lib/approx_topk.cc index aafbda35c34..10cf46aa95d 100644 --- a/tensorflow/compiler/xla/client/lib/approx_topk.cc +++ b/tensorflow/compiler/xla/client/lib/approx_topk.cc @@ -17,6 +17,8 @@ limitations under the License. #include #include +#include +#include #include "absl/strings/str_format.h" #include "tensorflow/compiler/xla/client/lib/approx_topk_shape.h" diff --git a/tensorflow/compiler/xla/client/lib/approx_topk_shape.cc b/tensorflow/compiler/xla/client/lib/approx_topk_shape.cc index 2d5db586ea0..d0b1164a065 100644 --- a/tensorflow/compiler/xla/client/lib/approx_topk_shape.cc +++ b/tensorflow/compiler/xla/client/lib/approx_topk_shape.cc @@ -16,6 +16,8 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/lib/approx_topk_shape.h" #include +#include +#include #include "tensorflow/compiler/xla/util.h" diff --git a/tensorflow/compiler/xla/client/lib/approx_topk_shape.h b/tensorflow/compiler/xla/client/lib/approx_topk_shape.h index 027acc9defe..d1e2cb62fd5 100644 --- a/tensorflow/compiler/xla/client/lib/approx_topk_shape.h +++ b/tensorflow/compiler/xla/client/lib/approx_topk_shape.h @@ -16,6 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_APPROX_TOPK_SHAPE_H_ #define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_APPROX_TOPK_SHAPE_H_ +#include + #include "tensorflow/compiler/xla/statusor.h" namespace xla { diff --git a/tensorflow/compiler/xla/client/lib/arithmetic.cc b/tensorflow/compiler/xla/client/lib/arithmetic.cc index 058e6b301dc..f4cd43a2127 100644 --- a/tensorflow/compiler/xla/client/lib/arithmetic.cc +++ b/tensorflow/compiler/xla/client/lib/arithmetic.cc @@ -15,7 +15,10 @@ limitations under the License. #include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include +#include #include +#include #include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/client/lib/constants.h" diff --git a/tensorflow/compiler/xla/client/lib/arithmetic.h b/tensorflow/compiler/xla/client/lib/arithmetic.h index cdaa4f63b0a..449d7acb516 100644 --- a/tensorflow/compiler/xla/client/lib/arithmetic.h +++ b/tensorflow/compiler/xla/client/lib/arithmetic.h @@ -16,7 +16,9 @@ limitations under the License. 
#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_ARITHMETIC_H_ #define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_ARITHMETIC_H_ +#include #include +#include #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" diff --git a/tensorflow/compiler/xla/client/lib/arithmetic_test.cc b/tensorflow/compiler/xla/client/lib/arithmetic_test.cc index cf20b7c6549..f55aa3db0ee 100644 --- a/tensorflow/compiler/xla/client/lib/arithmetic_test.cc +++ b/tensorflow/compiler/xla/client/lib/arithmetic_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include #include #include "tensorflow/compiler/xla/client/xla_builder.h" diff --git a/tensorflow/compiler/xla/client/lib/comparators.cc b/tensorflow/compiler/xla/client/lib/comparators.cc index 5e628545bad..19403b287de 100644 --- a/tensorflow/compiler/xla/client/lib/comparators.cc +++ b/tensorflow/compiler/xla/client/lib/comparators.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/lib/comparators.h" #include +#include #include #include diff --git a/tensorflow/compiler/xla/client/lib/comparators.h b/tensorflow/compiler/xla/client/lib/comparators.h index 33a6a2a2ad0..81d71afa384 100644 --- a/tensorflow/compiler/xla/client/lib/comparators.h +++ b/tensorflow/compiler/xla/client/lib/comparators.h @@ -16,6 +16,8 @@ limitations under the License. 
#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_COMPARATORS_H_ #define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_COMPARATORS_H_ +#include +#include #include #include "tensorflow/compiler/xla/client/xla_builder.h" @@ -49,7 +51,7 @@ XlaComputation CreateScalarComparisonComputation( const std::string& name, const std::vector& operand_types, const std::vector< std::optional)>>& - comparators, + generators, XlaBuilder* builder); } // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/constants.cc b/tensorflow/compiler/xla/client/lib/constants.cc index aa48c54905e..0752bc99e24 100644 --- a/tensorflow/compiler/xla/client/lib/constants.cc +++ b/tensorflow/compiler/xla/client/lib/constants.cc @@ -15,8 +15,13 @@ limitations under the License. #include "tensorflow/compiler/xla/client/lib/constants.h" +#include + #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/tsl/platform/float8.h" namespace xla { @@ -41,24 +46,19 @@ XlaOp One(XlaBuilder* builder, PrimitiveType type) { } XlaOp Epsilon(XlaBuilder* builder, PrimitiveType type) { - switch (type) { - case F16: - return ConstantR0( - builder, - static_cast(Eigen::NumTraits::epsilon())); - case BF16: - return ConstantR0( - builder, static_cast( - Eigen::NumTraits::epsilon())); - case F32: - return ConstantR0(builder, std::numeric_limits::epsilon()); - case F64: - return ConstantR0(builder, - std::numeric_limits::epsilon()); - default: - return builder->ReportError(InvalidArgument( - "Invalid type for Epsilon (%s).", PrimitiveType_Name(type))); - } + return primitive_util::PrimitiveTypeSwitch( + [&](auto primitive_type_constant) -> XlaOp { + if constexpr (primitive_util::IsFloatingPointType( + primitive_type_constant)) { + using NativeT = typename primitive_util::PrimitiveTypeToNative< + primitive_type_constant>::type; + return ConstantR0(builder, + 
std::numeric_limits::epsilon()); + } + return builder->ReportError(InvalidArgument( + "Invalid type for Epsilon (%s).", PrimitiveType_Name(type))); + }, + type); } XlaOp MinValue(XlaBuilder* builder, PrimitiveType type) { @@ -66,39 +66,35 @@ XlaOp MinValue(XlaBuilder* builder, PrimitiveType type) { } XlaOp MinFiniteValue(XlaBuilder* builder, PrimitiveType type) { - switch (type) { - case F16: - return ConstantR0(builder, - Eigen::NumTraits::lowest()); - case BF16: - return ConstantR0( - builder, Eigen::NumTraits::lowest()); - case F32: - return ConstantR0(builder, -std::numeric_limits::max()); - case F64: - return ConstantR0(builder, -std::numeric_limits::max()); - default: - return MinValue(builder, type); - } + return primitive_util::PrimitiveTypeSwitch( + [&](auto primitive_type_constant) -> XlaOp { + if constexpr (primitive_util::IsFloatingPointType( + primitive_type_constant)) { + using NativeT = typename primitive_util::PrimitiveTypeToNative< + primitive_type_constant>::type; + return ConstantR0(builder, + std::numeric_limits::lowest()); + } + return MinValue(builder, type); + }, + type); } XlaOp MinPositiveNormalValue(XlaBuilder* builder, PrimitiveType type) { - switch (type) { - case F16: - return ConstantR0(builder, - std::numeric_limits::min()); - case BF16: - return ConstantR0( - builder, std::numeric_limits::min()); - case F32: - return ConstantR0(builder, std::numeric_limits::min()); - case F64: - return ConstantR0(builder, std::numeric_limits::min()); - default: - return builder->ReportError( - InvalidArgument("Invalid type for MinPositiveNormalValue (%s).", - PrimitiveType_Name(type))); - } + return primitive_util::PrimitiveTypeSwitch( + [&](auto primitive_type_constant) -> XlaOp { + if constexpr (primitive_util::IsFloatingPointType( + primitive_type_constant)) { + using NativeT = typename primitive_util::PrimitiveTypeToNative< + primitive_type_constant>::type; + return ConstantR0(builder, + std::numeric_limits::min()); + } + return 
builder->ReportError( + InvalidArgument("Invalid type for MinPositiveNormalValue (%s).", + PrimitiveType_Name(type))); + }, + type); } XlaOp MaxValue(XlaBuilder* builder, PrimitiveType type) { @@ -106,44 +102,34 @@ XlaOp MaxValue(XlaBuilder* builder, PrimitiveType type) { } XlaOp MaxFiniteValue(XlaBuilder* builder, PrimitiveType type) { - switch (type) { - case F16: - return ConstantR0(builder, - Eigen::NumTraits::highest()); - case BF16: - return ConstantR0( - builder, Eigen::NumTraits::highest()); - case F32: - return ConstantR0(builder, std::numeric_limits::max()); - case F64: - return ConstantR0(builder, std::numeric_limits::max()); - default: - return MaxValue(builder, type); - } + return primitive_util::PrimitiveTypeSwitch( + [&](auto primitive_type_constant) -> XlaOp { + if constexpr (primitive_util::IsFloatingPointType( + primitive_type_constant)) { + using NativeT = typename primitive_util::PrimitiveTypeToNative< + primitive_type_constant>::type; + return ConstantR0(builder, + std::numeric_limits::max()); + } + return MaxValue(builder, type); + }, + type); } XlaOp NanValue(XlaBuilder* builder, PrimitiveType type) { - return builder->ReportErrorOrReturn([&]() -> StatusOr { - switch (type) { - case F16: - return ConstantR0( - builder, Eigen::NumTraits::quiet_NaN()); - case BF16: - return ConstantR0( - builder, Eigen::NumTraits::quiet_NaN()); - case F32: - return ConstantR0(builder, - std::numeric_limits::quiet_NaN()); - case F64: - return ConstantR0(builder, - std::numeric_limits::quiet_NaN()); - default: - return InvalidArgument( - "Operand to NanValue was %s, but must be a real-valued " - "floating-point type.", - PrimitiveType_Name(type)); - } - }); + return primitive_util::PrimitiveTypeSwitch( + [&](auto primitive_type_constant) -> XlaOp { + if constexpr (primitive_util::IsFloatingPointType( + primitive_type_constant)) { + using NativeT = typename primitive_util::PrimitiveTypeToNative< + primitive_type_constant>::type; + return ConstantR0(builder, + 
std::numeric_limits::quiet_NaN()); + } + return builder->ReportError(InvalidArgument( + "Invalid type for NanValue (%s).", PrimitiveType_Name(type))); + }, + type); } } // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/constants.h b/tensorflow/compiler/xla/client/lib/constants.h index 779c40eee48..9fc69f45836 100644 --- a/tensorflow/compiler/xla/client/lib/constants.h +++ b/tensorflow/compiler/xla/client/lib/constants.h @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/tsl/platform/float8.h" namespace xla { @@ -45,42 +46,17 @@ XlaOp ConstantR0WithType(XlaBuilder* builder, PrimitiveType type, T value) { "Invalid cast from complex type to %s in ConstantR0WithType.", PrimitiveType_Name(type))); } - switch (type) { - case PRED: - return ConstantR0(builder, static_cast(value)); - case F16: - return ConstantR0(builder, static_cast(value)); - case BF16: - return ConstantR0(builder, static_cast(value)); - case F32: - return ConstantR0(builder, static_cast(value)); - case F64: - return ConstantR0(builder, static_cast(value)); - case C64: - return ConstantR0(builder, static_cast(value)); - case C128: - return ConstantR0(builder, static_cast(value)); - case U8: - return ConstantR0(builder, static_cast(value)); - case U16: - return ConstantR0(builder, static_cast(value)); - case U32: - return ConstantR0(builder, static_cast(value)); - case U64: - return ConstantR0(builder, static_cast(value)); - case S8: - return ConstantR0(builder, static_cast(value)); - case S16: - return ConstantR0(builder, static_cast(value)); - case S32: - return ConstantR0(builder, static_cast(value)); - case S64: - return ConstantR0(builder, static_cast(value)); - default: - return builder->ReportError( - InvalidArgument("Invalid type for ConstantR0WithType (%s).", - PrimitiveType_Name(type))); - } + return 
primitive_util::PrimitiveTypeSwitch( + [&](auto primitive_type_constant) -> XlaOp { + if constexpr (primitive_util::IsArrayType(primitive_type_constant)) { + using NativeT = primitive_util::NativeTypeOf; + return ConstantR0(builder, static_cast(value)); + } + return builder->ReportError( + InvalidArgument("Invalid type for ConstantR0WithType (%s).", + PrimitiveType_Name(type))); + }, + type); } // Returns a scalar containing 'value' cast to the same run-time type as diff --git a/tensorflow/compiler/xla/client/lib/constants_test.cc b/tensorflow/compiler/xla/client/lib/constants_test.cc index 5b034dde320..051cfc898da 100644 --- a/tensorflow/compiler/xla/client/lib/constants_test.cc +++ b/tensorflow/compiler/xla/client/lib/constants_test.cc @@ -15,6 +15,8 @@ limitations under the License. #include "tensorflow/compiler/xla/client/lib/constants.h" +#include + #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" diff --git a/tensorflow/compiler/xla/client/lib/conv_grad_size_util.cc b/tensorflow/compiler/xla/client/lib/conv_grad_size_util.cc index 3dde6cdcafe..8230f848df5 100644 --- a/tensorflow/compiler/xla/client/lib/conv_grad_size_util.cc +++ b/tensorflow/compiler/xla/client/lib/conv_grad_size_util.cc @@ -15,6 +15,8 @@ limitations under the License. #include "tensorflow/compiler/xla/client/lib/conv_grad_size_util.h" +#include + #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/tsl/platform/errors.h" diff --git a/tensorflow/compiler/xla/client/lib/logdet.cc b/tensorflow/compiler/xla/client/lib/logdet.cc index 3201323f4dc..b77694f0cbe 100644 --- a/tensorflow/compiler/xla/client/lib/logdet.cc +++ b/tensorflow/compiler/xla/client/lib/logdet.cc @@ -15,6 +15,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/lib/logdet.h" +#include #include #include diff --git a/tensorflow/compiler/xla/client/lib/logdet_test.cc b/tensorflow/compiler/xla/client/lib/logdet_test.cc index b5f78aea82d..ac61cbfad27 100644 --- a/tensorflow/compiler/xla/client/lib/logdet_test.cc +++ b/tensorflow/compiler/xla/client/lib/logdet_test.cc @@ -15,6 +15,8 @@ limitations under the License. #include "tensorflow/compiler/xla/client/lib/logdet.h" +#include + #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array3d.h" #include "tensorflow/compiler/xla/client/lib/matrix.h" diff --git a/tensorflow/compiler/xla/client/lib/loops.cc b/tensorflow/compiler/xla/client/lib/loops.cc index 7e7426812ee..4da691f3a9d 100644 --- a/tensorflow/compiler/xla/client/lib/loops.cc +++ b/tensorflow/compiler/xla/client/lib/loops.cc @@ -15,6 +15,10 @@ limitations under the License. #include "tensorflow/compiler/xla/client/lib/loops.h" +#include +#include +#include + #include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/shape_util.h" diff --git a/tensorflow/compiler/xla/client/lib/lu_decomposition.cc b/tensorflow/compiler/xla/client/lib/lu_decomposition.cc index dac74300215..7a52980e599 100644 --- a/tensorflow/compiler/xla/client/lib/lu_decomposition.cc +++ b/tensorflow/compiler/xla/client/lib/lu_decomposition.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/lib/lu_decomposition.h" +#include #include #include "tensorflow/compiler/xla/shape_util.h" diff --git a/tensorflow/compiler/xla/client/lib/math.cc b/tensorflow/compiler/xla/client/lib/math.cc index 25179617548..95fd4621d0b 100644 --- a/tensorflow/compiler/xla/client/lib/math.cc +++ b/tensorflow/compiler/xla/client/lib/math.cc @@ -15,7 +15,12 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/lib/math.h" +#include +#include #include +#include +#include +#include #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/lib/constants.h" @@ -153,6 +158,9 @@ XlaOp IsNegZero(XlaOp operand) { case F32: return Eq(BitcastConvertType(operand, U32), ConstantR0WithType(&b, U32, uint32_t{1} << 31)); + case F8E5M2: + case F8E4M3FN: + case F8E4M3B11FNUZ: case F16: case BF16: // Not all XLA backends handle U16 well, so we convert to F32/U32. @@ -293,10 +301,11 @@ XlaOp Erfc(XlaOp x) { } // Erf(c)Impl don't have enough precision when run with bf16 intermediates // (not surprising!), so upcast to f32 in this case. - return DoWithUpcastToF32(x, {BF16, F16}, [](XlaOp x) { - return Select(Gt(Abs(x), ScalarLike(x, 1)), ErfcImpl32(x), - ScalarLike(x, 1) - ErfImpl32Cephes(x)); - }); + return DoWithUpcastToF32( + x, {BF16, F16, F8E5M2, F8E4M3FN, F8E4M3B11FNUZ}, [](XlaOp x) { + return Select(Gt(Abs(x), ScalarLike(x, 1)), ErfcImpl32(x), + ScalarLike(x, 1) - ErfImpl32Cephes(x)); + }); }); } @@ -338,7 +347,7 @@ XlaOp Erf(XlaOp x) { } // Erf(c)Impl don't have enough precision when run with bf16 intermediates // (not surprising!), so upcast to f32 in this case. - return DoWithUpcastToF32(x, {BF16, F16}, + return DoWithUpcastToF32(x, {BF16, F16, F8E5M2, F8E4M3FN, F8E4M3B11FNUZ}, [](XlaOp x) { return ErfImpl32(x); }); }); } @@ -487,7 +496,7 @@ XlaOp ErfInv(XlaOp x) { if (shape.element_type() == F64) { return ErfInv64(x); } - return DoWithUpcastToF32(x, {BF16, F16}, + return DoWithUpcastToF32(x, {BF16, F16, F8E5M2, F8E4M3FN, F8E4M3B11FNUZ}, [](XlaOp x) { return ErfInv32(x); }); }); } @@ -616,7 +625,8 @@ XlaOp Lgamma(XlaOp input) { // F16 and BF16 don't provide sufficient precision for intermediate results // here (although it's better than you might expect!), so do the // computations in F32. 
- return DoWithUpcastToF32(input, {BF16, F16}, do_it); + return DoWithUpcastToF32( + input, {BF16, F16, F8E5M2, F8E4M3FN, F8E4M3B11FNUZ}, do_it); }); } @@ -711,7 +721,8 @@ XlaOp Digamma(XlaOp input) { auto& b = *input.builder(); return b.ReportErrorOrReturn([&]() -> StatusOr { TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("Digamma", input)); - return DoWithUpcastToF32(input, {BF16, F16}, do_it); + return DoWithUpcastToF32( + input, {BF16, F16, F8E5M2, F8E4M3FN, F8E4M3B11FNUZ}, do_it); }); } @@ -965,8 +976,13 @@ XlaOp Igamma(XlaOp a, XlaOp x) { } TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("Igamma", a)); PrimitiveType a_x_type = a_shape.element_type(); - bool needs_upcast = - a_shape.element_type() == F16 || a_shape.element_type() == BF16; + bool needs_upcast = false; + for (PrimitiveType type : {BF16, F16, F8E5M2, F8E4M3FN, F8E4M3B11FNUZ}) { + if (a_shape.element_type() == type) { + needs_upcast = true; + break; + } + } if (needs_upcast) { a = ConvertElementType(a, F32); @@ -1012,8 +1028,13 @@ XlaOp IgammaGradA(XlaOp a, XlaOp x) { a_shape.ToString(), x_shape.ToString()); } TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("IgammaGradA", a)); - bool needs_upcast = - a_shape.element_type() == F16 || a_shape.element_type() == BF16; + bool needs_upcast = false; + for (PrimitiveType type : {BF16, F16, F8E5M2, F8E4M3FN, F8E4M3B11FNUZ}) { + if (a_shape.element_type() == type) { + needs_upcast = true; + break; + } + } if (needs_upcast) { a = ConvertElementType(a, F32); diff --git a/tensorflow/compiler/xla/client/lib/math_test.cc b/tensorflow/compiler/xla/client/lib/math_test.cc index ccd4ee2b1cc..cd571a66978 100644 --- a/tensorflow/compiler/xla/client/lib/math_test.cc +++ b/tensorflow/compiler/xla/client/lib/math_test.cc @@ -15,7 +15,14 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/lib/math.h" +#include +#include +#include #include +#include +#include +#include +#include #include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/xla_builder.h" @@ -120,10 +127,6 @@ class MathTypedTest : public MathTest { // // For good measure, we also check pow with an exponent other than 0.5. void TestSqrtPowInequivalence() { - // TODO(b/145798892): test fails on GPU for double values. - if (std::is_same::value) { - return; - } SetFastMathDisabled(true); // Tests disable constant folding by default, but this test needs it @@ -222,10 +225,6 @@ XLA_TEST_F(MathTest, RealFpOnlyOps) { } else { continue; } - if (ty == F8E5M2 || ty == F8E4M3FN || ty == F8E4M3B11FNUZ) { - // TODO(b/259609697): Add FP8 support to math ops - continue; - } for (const auto& test : std::vector, std::string>>({ diff --git a/tensorflow/compiler/xla/client/lib/matrix.cc b/tensorflow/compiler/xla/client/lib/matrix.cc index f8eb44de9aa..eb4a8a5e0b5 100644 --- a/tensorflow/compiler/xla/client/lib/matrix.cc +++ b/tensorflow/compiler/xla/client/lib/matrix.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include #include +#include #include #include #include diff --git a/tensorflow/compiler/xla/client/lib/matrix.h b/tensorflow/compiler/xla/client/lib/matrix.h index 5ceda40af02..b24feca3ea8 100644 --- a/tensorflow/compiler/xla/client/lib/matrix.h +++ b/tensorflow/compiler/xla/client/lib/matrix.h @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include #include #include "absl/strings/string_view.h" diff --git a/tensorflow/compiler/xla/client/lib/pooling.cc b/tensorflow/compiler/xla/client/lib/pooling.cc index 42340910fd5..7db9f364d3a 100644 --- a/tensorflow/compiler/xla/client/lib/pooling.cc +++ b/tensorflow/compiler/xla/client/lib/pooling.cc @@ -15,6 +15,10 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/lib/pooling.h" +#include +#include +#include + #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/lib/conv_grad_size_util.h" diff --git a/tensorflow/compiler/xla/client/lib/pooling.h b/tensorflow/compiler/xla/client/lib/pooling.h index 9510193f8a6..3a26c02d0d5 100644 --- a/tensorflow/compiler/xla/client/lib/pooling.h +++ b/tensorflow/compiler/xla/client/lib/pooling.h @@ -16,6 +16,9 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_POOLING_H_ #define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_POOLING_H_ +#include +#include + #include "absl/container/inlined_vector.h" #include "tensorflow/compiler/xla/client/xla_builder.h" @@ -57,8 +60,7 @@ XlaOp MaxPool(XlaOp operand, absl::Span kernel_size, XlaOp AvgPool(XlaOp operand, absl::Span kernel_size, absl::Span stride, absl::Span> padding, - const TensorFormat& data_format, - const bool counts_include_padding); + const TensorFormat& data_format, bool counts_include_padding); // Returns the list of low and high padding elements in each spatial dimension // for the given 'padding' specification. @@ -72,8 +74,7 @@ XlaOp AvgPoolGrad(XlaOp out_backprop, absl::Span gradients_size, absl::Span kernel_size, absl::Span stride, absl::Span> spatial_padding, - const TensorFormat& data_format, - const bool counts_include_padding); + const TensorFormat& data_format, bool counts_include_padding); } // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/pooling_test.cc b/tensorflow/compiler/xla/client/lib/pooling_test.cc index 44d0091f0c9..496a9a931e1 100644 --- a/tensorflow/compiler/xla/client/lib/pooling_test.cc +++ b/tensorflow/compiler/xla/client/lib/pooling_test.cc @@ -14,6 +14,10 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/compiler/xla/client/lib/pooling.h" + +#include +#include + #include "absl/container/inlined_vector.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" diff --git a/tensorflow/compiler/xla/client/lib/prng.cc b/tensorflow/compiler/xla/client/lib/prng.cc index b0b66dd1b0a..4c6d0eae79d 100644 --- a/tensorflow/compiler/xla/client/lib/prng.cc +++ b/tensorflow/compiler/xla/client/lib/prng.cc @@ -15,7 +15,11 @@ limitations under the License. #include "tensorflow/compiler/xla/client/lib/prng.h" +#include #include +#include +#include +#include #include #include "tensorflow/compiler/xla/client/lib/constants.h" @@ -597,57 +601,59 @@ XlaOp PhiloxIncreaseCounter(XlaOp counter, XlaOp delta) { RngOutput ThreeFryBitGenerator(XlaOp key, XlaOp initial_state, const Shape& shape) { PrimitiveType type = shape.element_type(); - switch (type) { - case S8: - case U8: - case F16: - case U16: - case S16: - return ThreeFryRngBitNarrow(key, initial_state, shape); - case F32: - case U32: - case S32: - return ThreeFryRngBit32(key, initial_state, shape); - case F64: - case U64: - case S64: - return ThreeFryRngBit64(key, initial_state, shape); - default: - return { - key.builder()->ReportError(Unimplemented( - "Types other than F16, F32, F64, U16, S16, U32, S32, U64 and S64 " - "are not implemented by ThreeFryBitGenerator; got %s", - primitive_util::LowercasePrimitiveTypeName(type))), - initial_state}; - } + return primitive_util::PrimitiveTypeSwitch( + [&](auto primitive_type_constant) -> RngOutput { + if constexpr (primitive_util::IsArrayType(primitive_type_constant) && + !primitive_util::IsComplexType(primitive_type_constant) && + primitive_type_constant != PRED) { + const int kBits = primitive_util::BitWidth(primitive_type_constant); + if (kBits < 32) { + return ThreeFryRngBitNarrow(key, initial_state, shape); + } + if (kBits == 32) 
{ + return ThreeFryRngBit32(key, initial_state, shape); + } + if (kBits == 64) { + return ThreeFryRngBit64(key, initial_state, shape); + } + } + return { + key.builder()->ReportError(Unimplemented( + "Types other than F16, F32, F64, U16, S16, U32, S32, U64 and " + "S64 are not implemented by ThreeFryBitGenerator; got %s", + primitive_util::LowercasePrimitiveTypeName(type))), + initial_state}; + }, + type); } RngOutput PhiloxBitGenerator(XlaOp key, XlaOp initial_state, const Shape& shape) { PrimitiveType type = shape.element_type(); - switch (type) { - case S8: - case U8: - case F16: - case U16: - case S16: - return PhiloxRngBitNarrow(key, initial_state, shape); - case F32: - case U32: - case S32: - return PhiloxRngBit32(key, initial_state, shape); - case F64: - case U64: - case S64: - return PhiloxRngBit64(key, initial_state, shape); - default: - return { - key.builder()->ReportError(Unimplemented( - "Types other than F16, F32, F64, U16, S16, U32, S32, U64 and S64 " - "are not implemented by PhiloxBitGenerator; got %s", - primitive_util::LowercasePrimitiveTypeName(type))), - initial_state}; - } + return primitive_util::PrimitiveTypeSwitch( + [&](auto primitive_type_constant) -> RngOutput { + if constexpr (primitive_util::IsArrayType(primitive_type_constant) && + !primitive_util::IsComplexType(primitive_type_constant) && + primitive_type_constant != PRED) { + const int kBits = primitive_util::BitWidth(primitive_type_constant); + if (kBits < 32) { + return PhiloxRngBitNarrow(key, initial_state, shape); + } + if (kBits == 32) { + return PhiloxRngBit32(key, initial_state, shape); + } + if (kBits == 64) { + return PhiloxRngBit64(key, initial_state, shape); + } + } + return { + key.builder()->ReportError(Unimplemented( + "Types other than F16, F32, F64, U16, S16, U32, S32, U64 and " + "S64 are not implemented by PhiloxBitGenerator; got %s", + primitive_util::LowercasePrimitiveTypeName(type))), + initial_state}; + }, + type); } std::pair ScramblePhiloxKey(XlaOp key) { 
diff --git a/tensorflow/compiler/xla/client/lib/prng.h b/tensorflow/compiler/xla/client/lib/prng.h index ef60bd74486..35d6d05ac33 100644 --- a/tensorflow/compiler/xla/client/lib/prng.h +++ b/tensorflow/compiler/xla/client/lib/prng.h @@ -17,6 +17,8 @@ limitations under the License. #define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_PRNG_H_ #include +#include +#include #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/xla_data.pb.h" diff --git a/tensorflow/compiler/xla/client/lib/qr.cc b/tensorflow/compiler/xla/client/lib/qr.cc index 12ed7af8821..41ed04a7c87 100644 --- a/tensorflow/compiler/xla/client/lib/qr.cc +++ b/tensorflow/compiler/xla/client/lib/qr.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/lib/qr.h" +#include #include #include diff --git a/tensorflow/compiler/xla/client/lib/quantize_test.cc b/tensorflow/compiler/xla/client/lib/quantize_test.cc index 2dbbd21666c..52668d27ac8 100644 --- a/tensorflow/compiler/xla/client/lib/quantize_test.cc +++ b/tensorflow/compiler/xla/client/lib/quantize_test.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/lib/quantize.h" #include +#include #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/test.h" diff --git a/tensorflow/compiler/xla/client/lib/self_adjoint_eig.cc b/tensorflow/compiler/xla/client/lib/self_adjoint_eig.cc index 26f15fb3203..8caf8c1784e 100644 --- a/tensorflow/compiler/xla/client/lib/self_adjoint_eig.cc +++ b/tensorflow/compiler/xla/client/lib/self_adjoint_eig.cc @@ -16,6 +16,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/lib/self_adjoint_eig.h" #include +#include #include #include "tensorflow/compiler/xla/client/lib/arithmetic.h" diff --git a/tensorflow/compiler/xla/client/lib/self_adjoint_eig_test.cc b/tensorflow/compiler/xla/client/lib/self_adjoint_eig_test.cc index 0f47c41975c..6e48ae35cfa 100644 --- a/tensorflow/compiler/xla/client/lib/self_adjoint_eig_test.cc +++ b/tensorflow/compiler/xla/client/lib/self_adjoint_eig_test.cc @@ -15,6 +15,10 @@ limitations under the License. #include "tensorflow/compiler/xla/client/lib/self_adjoint_eig.h" +#include +#include +#include + #include "tensorflow/compiler/xla/array.h" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array3d.h" diff --git a/tensorflow/compiler/xla/client/lib/slicing.cc b/tensorflow/compiler/xla/client/lib/slicing.cc index d8a36e22aa9..91c35d2cd4c 100644 --- a/tensorflow/compiler/xla/client/lib/slicing.cc +++ b/tensorflow/compiler/xla/client/lib/slicing.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/lib/slicing.h" #include +#include #include #include diff --git a/tensorflow/compiler/xla/client/lib/slicing.h b/tensorflow/compiler/xla/client/lib/slicing.h index 3befbd311eb..2bf9a27aee5 100644 --- a/tensorflow/compiler/xla/client/lib/slicing.h +++ b/tensorflow/compiler/xla/client/lib/slicing.h @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ +#include + #include "absl/types/span.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/types.h" diff --git a/tensorflow/compiler/xla/client/lib/sorting.cc b/tensorflow/compiler/xla/client/lib/sorting.cc index cdd1f4a542a..32a59b6025d 100644 --- a/tensorflow/compiler/xla/client/lib/sorting.cc +++ b/tensorflow/compiler/xla/client/lib/sorting.cc @@ -15,6 +15,8 @@ limitations under the License. #include "tensorflow/compiler/xla/client/lib/sorting.h" +#include + #include "tensorflow/compiler/xla/client/lib/comparators.h" #include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/lib/loops.h" @@ -188,24 +190,9 @@ XlaOp TopKWithPartitions(XlaOp input, int64_t k, int64_t num_partitions, auto iota = values_and_indices[3]; // Slice value and indices for this partition. - XlaOp start; - switch (index_type) { - case PrimitiveType::S16: - start = Mul(Add(partition, ConstantR0(builder, 1)), - ConstantR0(builder, per_partition_size)); - break; - case PrimitiveType::S32: - start = Mul(Add(partition, ConstantR0(builder, 1)), - ConstantR0(builder, per_partition_size)); - break; - case PrimitiveType::S64: - start = Mul(Add(partition, ConstantR0(builder, 1)), - ConstantR0(builder, per_partition_size)); - break; - default: - LOG(FATAL) << "Unsupported index type " - << PrimitiveType_Name(index_type); - } + XlaOp start = + Mul(Add(partition, One(builder, index_type)), + ConstantR0WithType(builder, index_type, per_partition_size)); XlaOp sliced_input = DynamicSliceInMinorDims(input, {start}, {per_partition_size}); XlaOp sliced_indices = diff --git a/tensorflow/compiler/xla/client/lib/sorting_test.cc b/tensorflow/compiler/xla/client/lib/sorting_test.cc index 7d5de392067..8573329b5ae 100644 --- a/tensorflow/compiler/xla/client/lib/sorting_test.cc +++ b/tensorflow/compiler/xla/client/lib/sorting_test.cc @@ -15,7 +15,11 @@ 
limitations under the License. #include "tensorflow/compiler/xla/client/lib/sorting.h" +#include +#include #include +#include +#include #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/test.h" diff --git a/tensorflow/compiler/xla/client/lib/svd_test.cc b/tensorflow/compiler/xla/client/lib/svd_test.cc index 597c2d7747f..034771d2fb6 100644 --- a/tensorflow/compiler/xla/client/lib/svd_test.cc +++ b/tensorflow/compiler/xla/client/lib/svd_test.cc @@ -15,7 +15,9 @@ limitations under the License. #include "tensorflow/compiler/xla/client/lib/svd.h" +#include #include +#include #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array3d.h" diff --git a/tensorflow/compiler/xla/client/lib/testing.cc b/tensorflow/compiler/xla/client/lib/testing.cc index bb13a3b15c3..4f2aeb9438c 100644 --- a/tensorflow/compiler/xla/client/lib/testing.cc +++ b/tensorflow/compiler/xla/client/lib/testing.cc @@ -15,6 +15,9 @@ limitations under the License. #include "tensorflow/compiler/xla/client/lib/testing.h" +#include +#include + #include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/execution_options_util.h" diff --git a/tensorflow/compiler/xla/client/local_client.cc b/tensorflow/compiler/xla/client/local_client.cc index 3418616fcc9..ae82cd46167 100644 --- a/tensorflow/compiler/xla/client/local_client.cc +++ b/tensorflow/compiler/xla/client/local_client.cc @@ -18,15 +18,14 @@ limitations under the License. 
#include #include #include +#include -#include "llvm/TargetParser/Triple.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/service/backend.h" #include "tensorflow/compiler/xla/service/dump.h" #include "tensorflow/compiler/xla/service/service_executable_run_options.h" #include "tensorflow/compiler/xla/service/source_map_util.h" #include "tensorflow/compiler/xla/service/stream_pool.h" -#include "tensorflow/compiler/xla/status_macros.h" using xla::source_map_util::InvalidParameterArgument; @@ -167,8 +166,8 @@ LocalExecutable::RunHelper(const absl::Span argument_shapes, // ExecutableRunOptions.eigen_intra_op_thread_pool. // *) The thread pool used for XLA CPU ops is from // backend_->eigen_intra_op_thread_pool(). - ServiceExecutableRunOptions service_options(run_options, - backend_->StreamBorrower()); + ServiceExecutableRunOptions service_options( + run_options, backend_->StreamBorrowerWithPriority()); return std::make_pair(service_options, std::move(stream)); } diff --git a/tensorflow/compiler/xla/client/local_client.h b/tensorflow/compiler/xla/client/local_client.h index 1e2cbf11c60..c79425bca90 100644 --- a/tensorflow/compiler/xla/client/local_client.h +++ b/tensorflow/compiler/xla/client/local_client.h @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include #include #include "absl/types/span.h" @@ -49,7 +50,7 @@ class LocalExecutable { // Run the compiled computation with the given arguments and options and // return the result. StatusOr Run( - const absl::Span arguments, + absl::Span arguments, ExecutableRunOptions run_options); // Similar to Run(), but allows for donating argument buffers to the @@ -60,7 +61,7 @@ class LocalExecutable { // Similar to Run(), but need not block the host waiting for the computation // to complete before returning. 
StatusOr RunAsync( - const absl::Span arguments, + absl::Span arguments, ExecutableRunOptions run_options); // Similar to RunAsync(), but allows for donating argument buffers to the @@ -91,7 +92,7 @@ class LocalExecutable { StatusOr LiteralFromShapedBuffer(const ShapedBuffer& shaped_buffer); StatusOr> RunHelper( - const absl::Span argument_shapes, + absl::Span argument_shapes, ExecutableRunOptions run_options); // The ordinal of the device which this executable was compiled for. The @@ -143,7 +144,7 @@ class LocalClient : public Client { // environment variable. StatusOr>> Compile( const XlaComputation& computation, - const absl::Span argument_layouts, + absl::Span argument_layouts, const ExecutableBuildOptions& options); // Same as Compile() above, but return AotCompilationResult objects (instead @@ -151,7 +152,7 @@ class LocalClient : public Client { // LocalExecutable(s) using the Load() method below. StatusOr>> CompileAheadOfTime(const XlaComputation& computation, - const absl::Span argument_layouts, + absl::Span argument_layouts, const ExecutableBuildOptions& options); // Return a LocalExecutable object loaded from a serialized diff --git a/tensorflow/compiler/xla/client/padding.cc b/tensorflow/compiler/xla/client/padding.cc index 7fec04e2ac5..a78cd490a4c 100644 --- a/tensorflow/compiler/xla/client/padding.cc +++ b/tensorflow/compiler/xla/client/padding.cc @@ -16,6 +16,8 @@ limitations under the License. #include "tensorflow/compiler/xla/client/padding.h" #include +#include +#include #include "tensorflow/compiler/xla/util.h" #include "tensorflow/tsl/lib/math/math_util.h" diff --git a/tensorflow/compiler/xla/client/padding_test.cc b/tensorflow/compiler/xla/client/padding_test.cc index 1b249596138..79306a40d2e 100644 --- a/tensorflow/compiler/xla/client/padding_test.cc +++ b/tensorflow/compiler/xla/client/padding_test.cc @@ -15,6 +15,8 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/padding.h" +#include + #include "tensorflow/tsl/platform/test.h" namespace xla { diff --git a/tensorflow/compiler/xla/client/sharding_builder.cc b/tensorflow/compiler/xla/client/sharding_builder.cc index e3290f8afd1..718b411a6a9 100644 --- a/tensorflow/compiler/xla/client/sharding_builder.cc +++ b/tensorflow/compiler/xla/client/sharding_builder.cc @@ -15,6 +15,8 @@ limitations under the License. #include "tensorflow/compiler/xla/client/sharding_builder.h" +#include + namespace xla { namespace sharding_builder { diff --git a/tensorflow/compiler/xla/client/value_inference.cc b/tensorflow/compiler/xla/client/value_inference.cc index 1b211b5cb54..2ffa00e234c 100644 --- a/tensorflow/compiler/xla/client/value_inference.cc +++ b/tensorflow/compiler/xla/client/value_inference.cc @@ -16,6 +16,7 @@ limitations under the License. #include #include +#include #include #include #include @@ -609,8 +610,8 @@ StatusOr PostorderDFSVisitor::AnalyzeConstantValueFallback( TF_ASSIGN_OR_RETURN( auto computation, HloComputation::CreateFromProto(*computation_proto, {})); - return HloProtoEvaluator(evaluator, *root) - .WithOperands(operands) + return std::make_unique(evaluator, *root) + ->WithOperands(operands) .WithComputation(std::move(computation)) .WithSubshape(context.shape_index) .Evaluate(); @@ -629,8 +630,8 @@ StatusOr PostorderDFSVisitor::AnalyzeConstantValueFallback( .AddVisit([](Literal operand) { return operand; }); } return result.AddVisit([root, this](absl::Span operands) { - return HloProtoEvaluator(evaluator, *root) - .WithOperands(operands) + return std::make_unique(evaluator, *root) + ->WithOperands(operands) .Evaluate(); }); } @@ -763,8 +764,8 @@ StatusOr PostorderDFSVisitor::AnalyzeUpperBound( std::vector new_operands; new_operands.emplace_back(std::move(upper_bound)); new_operands.emplace_back(std::move(lower_bound)); - return HloProtoEvaluator(evaluator, *root) - .WithOperands(absl::MakeSpan(new_operands)) + return 
std::make_unique(evaluator, *root) + ->WithOperands(absl::MakeSpan(new_operands)) .Evaluate(); }); } @@ -796,8 +797,8 @@ StatusOr PostorderDFSVisitor::AnalyzeUpperBound( .AddDependency(root->operand_ids(1), PostorderDFSNodeType::kConstantValue, context) .AddVisit([root, this](absl::Span operands) { - return HloProtoEvaluator(evaluator, *root) - .WithOperands(operands) + return std::make_unique(evaluator, *root) + ->WithOperands(operands) .Evaluate(); }); } @@ -874,8 +875,8 @@ StatusOr PostorderDFSVisitor::AnalyzeLowerBound( PostorderDFSNodeType::kConstantUpperBound, context) .AddVisit( [root, this](absl::Span operands) -> StatusOr { - return HloProtoEvaluator(evaluator, *root) - .WithOperands(operands) + return std::make_unique(evaluator, *root) + ->WithOperands(operands) .Evaluate(); }); } @@ -886,8 +887,8 @@ StatusOr PostorderDFSVisitor::AnalyzeLowerBound( .AddDependency(root->operand_ids(1), PostorderDFSNodeType::kConstantValue, context) .AddVisit([root, this](absl::Span operands) { - return HloProtoEvaluator(evaluator, *root) - .WithOperands(operands) + return std::make_unique(evaluator, *root) + ->WithOperands(operands) .Evaluate(); }); } @@ -939,8 +940,8 @@ StatusOr PostorderDFSVisitor::AnalyzeConstant( } return result.AddVisit( [root, this](absl::Span operands) -> StatusOr { - return HloProtoEvaluator(evaluator, *root) - .WithOperands(operands) + return std::make_unique(evaluator, *root) + ->WithOperands(operands) .Evaluate(); }); } @@ -984,8 +985,8 @@ StatusOr PostorderDFSVisitor::AnalyzeConstant( TF_ASSIGN_OR_RETURN( auto computation, HloComputation::CreateFromProto(*computation_proto, {})); - return HloProtoEvaluator(evaluator, *root) - .WithOperands(operands) + return std::make_unique(evaluator, *root) + ->WithOperands(operands) .WithComputation(std::move(computation)) .WithSubshape(context.shape_index) .Evaluate(); @@ -1149,8 +1150,8 @@ StatusOr PostorderDFSVisitor::AnalyzeIsDynamic( case HloOpcode::kShiftRightArithmetic: case 
HloOpcode::kShiftRightLogical: { return result.AddVisit([root, this](absl::Span operands) { - return HloProtoEvaluator(evaluator, *root) - .WithOperands(operands) + return std::make_unique(evaluator, *root) + ->WithOperands(operands) .WithPrimitiveType(PRED) .WithOpCode(HloOpcode::kOr) .Evaluate(); @@ -1176,8 +1177,8 @@ StatusOr PostorderDFSVisitor::AnalyzeIsDynamic( .AddVisit([](Literal operand) { return operand; }); } return result.AddVisit([root, this](absl::Span operands) { - return HloProtoEvaluator(evaluator, *root) - .WithOperands(operands) + return std::make_unique(evaluator, *root) + ->WithOperands(operands) .WithPrimitiveType(PRED) .Evaluate(); }); @@ -1341,8 +1342,8 @@ StatusOr PostorderDFSVisitor::AnalyzeIsDynamic( reduce_or = b.Build(); } - return HloProtoEvaluator(evaluator, *root) - .WithOperands(operands) + return std::make_unique(evaluator, *root) + ->WithOperands(operands) .WithPrimitiveType(PRED) .WithComputation(std::move(reduce_or)) // Reduce could produce tuple shape, only fetch what we need. @@ -1429,8 +1430,8 @@ StatusOr PostorderDFSVisitor::AnalyzeIsDynamic( new_operands.emplace_back( optional_selector_literal.GetValue()->Clone()); - return HloProtoEvaluator(evaluator, *root) - .WithOperands(absl::MakeSpan(new_operands)) + return std::make_unique(evaluator, *root) + ->WithOperands(absl::MakeSpan(new_operands)) .WithPrimitiveType(PRED) .Evaluate(); }); @@ -1655,7 +1656,7 @@ StatusOr ValueInference::SimplifyOp(int64_t handle) { TF_ASSIGN_OR_RETURN(auto* inst, builder_->LookUpInstructionByHandle(handle)); TF_ASSIGN_OR_RETURN(HloOpcode opcode, StringToHloOpcode(inst->opcode())); std::vector operands; - auto output_shape = Shape(inst->shape()); + auto output_shape = std::make_unique(inst->shape()); switch (opcode) { case HloOpcode::kSlice: case HloOpcode::kConcatenate: @@ -1667,8 +1668,8 @@ StatusOr ValueInference::SimplifyOp(int64_t handle) { } // We put handles into the tensor and evaluate the results into a literal. 
// The literal also contain handles for each element position. - return HloProtoEvaluator(evaluator_, *inst) - .WithOperands(absl::MakeSpan(operands)) + return std::make_unique(evaluator_, *inst) + ->WithOperands(absl::MakeSpan(operands)) .WithPrimitiveType(S64) .Evaluate(); } @@ -1676,23 +1677,23 @@ StatusOr ValueInference::SimplifyOp(int64_t handle) { // Only identity kConvert can be optimized away. auto operand = builder_->LookUpInstructionByHandle(inst->operand_ids(0)).value(); - if (Shape::Equal()(output_shape, Shape(operand->shape()))) { + if (Shape::Equal()(*output_shape, Shape(operand->shape()))) { // Forward operand handle as result. return SimplifyOp(inst->operand_ids(0)); } else { - return CreateS64Literal(-1, output_shape); + return CreateS64Literal(-1, *output_shape); } } case HloOpcode::kAdd: { // a + (b - a) => b // a + b + (c - a) => b + c - if (output_shape.rank() == 0) { + if (output_shape->rank() == 0) { TF_ASSIGN_OR_RETURN(auto lhs, SimplifyOp(inst->operand_ids(0))); TF_ASSIGN_OR_RETURN(auto rhs, SimplifyOp(inst->operand_ids(1))); int64_t lhs_handle = lhs.Get({}); int64_t rhs_handle = rhs.Get({}); if (lhs_handle == -1 || rhs_handle == -1) { - return CreateS64Literal(-1, output_shape); + return CreateS64Literal(-1, *output_shape); } // Recursive lambda needs explicit signature. 
std::function(int64_t, int64_t)> @@ -1749,14 +1750,14 @@ StatusOr ValueInference::SimplifyOp(int64_t handle) { return LiteralUtil::CreateR0(new_sum.handle()); } else { - return CreateS64Literal(-1, output_shape); + return CreateS64Literal(-1, *output_shape); } } default: { - if (ShapeUtil::IsScalar(output_shape)) { + if (ShapeUtil::IsScalar(*output_shape)) { return LiteralUtil::CreateR0(handle); } else { - return CreateS64Literal(-1, output_shape); + return CreateS64Literal(-1, *output_shape); } } } diff --git a/tensorflow/compiler/xla/client/value_inference.h b/tensorflow/compiler/xla/client/value_inference.h index 2579f65059f..3d371eef283 100644 --- a/tensorflow/compiler/xla/client/value_inference.h +++ b/tensorflow/compiler/xla/client/value_inference.h @@ -16,6 +16,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_XLA_CLIENT_VALUE_INFERENCE_H_ #include +#include #include "absl/container/flat_hash_map.h" #include "tensorflow/compiler/xla/client/xla_builder.h" diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc index 339ce5b2ad8..b913f2c4f00 100644 --- a/tensorflow/compiler/xla/client/xla_builder.cc +++ b/tensorflow/compiler/xla/client/xla_builder.cc @@ -2603,30 +2603,15 @@ XlaOp XlaBuilder::RngBitGenerator(RandomAlgorithm algorithm, TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(shape)); TF_ASSIGN_OR_RETURN(Shape state_shape, GetShape(initial_state)); Shape output_shape = shape; - switch (output_shape.element_type()) { - case PrimitiveType::S8: - case PrimitiveType::U8: - output_shape.set_element_type(PrimitiveType::U8); - break; - case PrimitiveType::BF16: - case PrimitiveType::F16: - case PrimitiveType::S16: - case PrimitiveType::U16: - output_shape.set_element_type(PrimitiveType::U16); - break; - case PrimitiveType::F32: - case PrimitiveType::S32: - case PrimitiveType::U32: - output_shape.set_element_type(PrimitiveType::U32); - break; - case PrimitiveType::F64: - case 
PrimitiveType::S64: - case PrimitiveType::U64: - output_shape.set_element_type(PrimitiveType::U64); - break; - default: - return InvalidArgument("Unsupported shape for RngBitGenerator: %s", - PrimitiveType_Name(output_shape.element_type())); + output_shape.set_element_type(PRIMITIVE_TYPE_INVALID); + if (primitive_util::IsArrayType(shape.element_type())) { + output_shape.set_element_type( + primitive_util::UnsignedIntegralTypeForBitWidth( + primitive_util::BitWidth(shape.element_type()))); + } + if (!primitive_util::IsUnsignedIntegralType(output_shape.element_type())) { + return InvalidArgument("Unsupported shape for RngBitGenerator: %s", + PrimitiveType_Name(shape.element_type())); } return RngBitGeneratorInternal( ShapeUtil::MakeTupleShapeWithPtrs({&state_shape, &output_shape}), @@ -4994,6 +4979,18 @@ XlaOp AllReduce(const XlaOp operand, const XlaComputation& computation, use_global_device_ids); } +XlaOp AllReduceTuple(const absl::Span operands, + const XlaComputation& computation, + absl::Span replica_groups, + const std::optional& channel_id, + const std::optional& shape_with_layout, + const std::optional use_global_device_ids) { + CHECK(!operands.empty()); + return operands[0].builder()->AllReduce( + operands[0].builder()->Tuple(operands), computation, replica_groups, + channel_id, shape_with_layout, use_global_device_ids); +} + XlaOp ReduceScatter(const XlaOp operand, const XlaComputation& computation, int64_t scatter_dimension, int64_t shard_count, absl::Span replica_groups, diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h index bde606e2e35..ad74675704d 100644 --- a/tensorflow/compiler/xla/client/xla_builder.h +++ b/tensorflow/compiler/xla/client/xla_builder.h @@ -67,16 +67,16 @@ struct XlaBuilderFriend { XlaBuilder* builder, absl::Span operands, std::string execution_thread, const XlaComputation& called_computation, const Shape& shape); - static XlaOp BuildAsyncUpdate(XlaBuilder* builder, const XlaOp 
operands, + static XlaOp BuildAsyncUpdate(XlaBuilder* builder, XlaOp operands, std::string execution_thread, int64_t group_id, int64_t called_computation, const Shape& shape); - static XlaOp BuildAsyncUpdate(XlaBuilder* builder, const XlaOp operands, + static XlaOp BuildAsyncUpdate(XlaBuilder* builder, XlaOp operands, std::string execution_thread, int64_t called_computation, const Shape& shape); - static XlaOp BuildAsyncDone(XlaBuilder* builder, const XlaOp operands, + static XlaOp BuildAsyncDone(XlaBuilder* builder, XlaOp operands, std::string execution_thread, int64_t group_id, int64_t called_computation, const Shape& shape); - static XlaOp BuildAsyncDone(XlaBuilder* builder, const XlaOp operands, + static XlaOp BuildAsyncDone(XlaBuilder* builder, XlaOp operands, std::string execution_thread, int64_t called_computation, const Shape& shape); @@ -85,8 +85,8 @@ struct XlaBuilderFriend { int64_t shard_count, absl::Span replica_groups = {}, const std::optional& channel_id = std::nullopt, const std::optional& layout = std::nullopt, - const std::optional use_global_device_ids = std::nullopt); - static XlaOp BuildAllGatherDone(XlaBuilder* builder, const XlaOp operands, + std::optional use_global_device_ids = std::nullopt); + static XlaOp BuildAllGatherDone(XlaBuilder* builder, XlaOp operands, const Shape& shape); static XlaOp BuildAllReduceStart( @@ -94,22 +94,21 @@ struct XlaBuilderFriend { absl::Span replica_groups = {}, const std::optional& channel_id = std::nullopt, const std::optional& layout = std::nullopt, - const std::optional use_global_device_ids = std::nullopt); - static XlaOp BuildAllReduceDone(XlaBuilder* builder, const XlaOp operands, + std::optional use_global_device_ids = std::nullopt); + static XlaOp BuildAllReduceDone(XlaBuilder* builder, XlaOp operands, const Shape& shape); static XlaOp BuildCollectivePermuteStart( XlaBuilder* builder, XlaOp operand, const std::vector>& source_target_pairs, const std::optional& channel_id = std::nullopt); - static 
XlaOp BuildCollectivePermuteDone(XlaBuilder* builder, - const XlaOp operands, + static XlaOp BuildCollectivePermuteDone(XlaBuilder* builder, XlaOp operands, const Shape& shape); static XlaOp BuildCopyStart( XlaBuilder* builder, XlaOp operand, std::optional cross_program_prefetch_index = std::nullopt); - static XlaOp BuildCopyDone(XlaBuilder* builder, const XlaOp operand, + static XlaOp BuildCopyDone(XlaBuilder* builder, XlaOp operand, const Shape& shape); static XlaOp BuildFusion( @@ -135,9 +134,8 @@ struct XlaBuilderFriend { const Shape& shape, const ChannelHandle& handle, bool is_host_transfer); - static XlaOp BuildDomain(XlaBuilder* builder, XlaOp operand, - const OpSharding entry, const OpSharding exit, - const Shape& shape); + static XlaOp BuildDomain(XlaBuilder* builder, XlaOp operand, OpSharding entry, + OpSharding exit, const Shape& shape); static XlaOp BuildRngGetAndUpdateState(XlaBuilder* builder, int64_t delta, const Shape& shape); @@ -521,9 +519,8 @@ class XlaBuilder { XlaOp Broadcast(XlaOp operand, absl::Span broadcast_sizes); - XlaOp BroadcastInDim(XlaOp operand, - const absl::Span out_dim_size, - const absl::Span broadcast_dimensions); + XlaOp BroadcastInDim(XlaOp operand, absl::Span out_dim_size, + absl::Span broadcast_dimensions); XlaOp Pad(XlaOp operand, XlaOp padding_value, const PaddingConfig& padding_config); @@ -810,19 +807,18 @@ class XlaBuilder { XlaOp CrossReplicaSum(XlaOp operand, absl::Span replica_groups = {}); - XlaOp AllGather( - XlaOp operand, int64_t all_gather_dimension, int64_t shard_count, - absl::Span replica_groups = {}, - const std::optional& channel_id = std::nullopt, - const std::optional& layout = std::nullopt, - const std::optional use_global_device_ids = std::nullopt); + XlaOp AllGather(XlaOp operand, int64_t all_gather_dimension, + int64_t shard_count, + absl::Span replica_groups = {}, + const std::optional& channel_id = std::nullopt, + const std::optional& layout = std::nullopt, + std::optional use_global_device_ids = 
std::nullopt); - XlaOp AllReduce( - XlaOp operand, const XlaComputation& computation, - absl::Span replica_groups = {}, - const std::optional& channel_id = std::nullopt, - const std::optional& shape_with_layout = std::nullopt, - const std::optional use_global_device_ids = std::nullopt); + XlaOp AllReduce(XlaOp operand, const XlaComputation& computation, + absl::Span replica_groups = {}, + const std::optional& channel_id = std::nullopt, + const std::optional& shape_with_layout = std::nullopt, + std::optional use_global_device_ids = std::nullopt); XlaOp ReduceScatter( XlaOp operand, const XlaComputation& computation, @@ -830,7 +826,7 @@ class XlaBuilder { absl::Span replica_groups = {}, const std::optional& channel_id = std::nullopt, const std::optional& layout = std::nullopt, - const std::optional use_global_device_ids = std::nullopt); + std::optional use_global_device_ids = std::nullopt); XlaOp AllToAll(XlaOp operand, int64_t split_dimension, int64_t concat_dimension, int64_t split_count, @@ -937,12 +933,11 @@ class XlaBuilder { absl::Span branch_computations, absl::Span branch_operands); - XlaOp ReducePrecision(XlaOp operand, const int exponent_bits, - const int mantissa_bits); + XlaOp ReducePrecision(XlaOp operand, int exponent_bits, int mantissa_bits); virtual StatusOr ReducePrecisionInternal(const Shape& shape, XlaOp operand, - const int exponent_bits, - const int mantissa_bits); + int exponent_bits, + int mantissa_bits); XlaOp Gather(XlaOp input, XlaOp start_indices, const GatherDimensionNumbers& dimension_numbers, @@ -1083,7 +1078,7 @@ class XlaBuilder { // operation such as `RngNormal` or `Infeed`. The visitor walks the // computation starting at a given operation and sets is_constant to false iff // a parameter or stateful operation is encountered. 
- void IsConstantVisitor(const int64_t op_handle, int depth, + void IsConstantVisitor(int64_t op_handle, int depth, absl::flat_hash_set* visited, bool* is_constant) const; @@ -1176,9 +1171,9 @@ class XlaBuilder { friend XlaOp Broadcast(XlaOp operand, absl::Span broadcast_sizes); - friend XlaOp BroadcastInDim( - XlaOp operand, const absl::Span out_dim_size, - const absl::Span broadcast_dimensions); + friend XlaOp BroadcastInDim(XlaOp operand, + absl::Span out_dim_size, + absl::Span broadcast_dimensions); friend XlaOp Copy(XlaOp operand); @@ -1439,18 +1434,24 @@ class XlaBuilder { absl::Span replica_groups, const std::optional& channel_id, const std::optional& layout, - const std::optional use_global_device_ids); + std::optional use_global_device_ids); friend XlaOp AllReduce(XlaOp operand, const XlaComputation& computation, absl::Span replica_groups, const std::optional& channel_id, const std::optional& shape_with_layout, - const std::optional use_global_device_ids); + std::optional use_global_device_ids); + friend XlaOp AllReduceTuple(absl::Span operand, + const XlaComputation& computation, + absl::Span replica_groups, + const std::optional& channel_id, + const std::optional& shape_with_layout, + std::optional use_global_device_ids); friend XlaOp ReduceScatter(XlaOp operand, const XlaComputation& computation, int64_t scatter_dimension, int64_t shard_count, absl::Span replica_groups, const std::optional& channel_id, const std::optional& layout, - const std::optional use_global_device_ids); + std::optional use_global_device_ids); friend XlaOp AllToAll(XlaOp operand, int64_t split_dimension, int64_t concat_dimension, int64_t split_count, @@ -1546,8 +1547,8 @@ class XlaBuilder { XlaOp branch_index, absl::Span branch_computations, absl::Span branch_operands); - friend XlaOp ReducePrecision(XlaOp operand, const int exponent_bits, - const int mantissa_bits); + friend XlaOp ReducePrecision(XlaOp operand, int exponent_bits, + int mantissa_bits); friend XlaOp Gather(XlaOp 
input, XlaOp start_indices, const GatherDimensionNumbers& dimension_numbers, absl::Span slice_sizes, @@ -1604,15 +1605,13 @@ class XlaBuilder { absl::Span replica_groups, const std::optional& channel_id, const std::optional& layout, - const std::optional use_global_device_ids, - bool async); + std::optional use_global_device_ids, bool async); XlaOp AllReduceImpl(XlaOp operand, const XlaComputation& computation, absl::Span replica_groups, const std::optional& channel_id, const std::optional& layout, - const std::optional use_global_device_ids, - bool async); + std::optional use_global_device_ids, bool async); XlaOp CollectivePermuteImpl( XlaOp operand, @@ -1848,9 +1847,8 @@ XlaOp Broadcast(XlaOp operand, absl::Span broadcast_sizes); // will generate output // {{1 , 1}, // {2 , 2}} -XlaOp BroadcastInDim(XlaOp operand, - const absl::Span out_dim_size, - const absl::Span broadcast_dimensions); +XlaOp BroadcastInDim(XlaOp operand, absl::Span out_dim_size, + absl::Span broadcast_dimensions); // Copies the input operand to the output. This operation is for internal // purpose and is only used by the compiler for optimization purposes or to @@ -2428,7 +2426,7 @@ XlaOp AllGather(XlaOp operand, int64_t all_gather_dimension, absl::Span replica_groups = {}, const std::optional& channel_id = std::nullopt, const std::optional& layout = std::nullopt, - const std::optional use_global_device_ids = std::nullopt); + std::optional use_global_device_ids = std::nullopt); // Enqueues an operation that do an AllReduce of the operand cross cores. 
Here // AllReduce means doing a reduction on the input operand cross cores and then @@ -2453,14 +2451,21 @@ XlaOp AllReduce(XlaOp operand, const XlaComputation& computation, absl::Span replica_groups = {}, const std::optional& channel_id = std::nullopt, const std::optional& shape_with_layout = std::nullopt, - const std::optional use_global_device_ids = std::nullopt); + std::optional use_global_device_ids = std::nullopt); + +XlaOp AllReduceTuple( + absl::Span operand, const XlaComputation& computation, + absl::Span replica_groups = {}, + const std::optional& channel_id = std::nullopt, + const std::optional& shape_with_layout = std::nullopt, + std::optional use_global_device_ids = std::nullopt); XlaOp ReduceScatter( XlaOp operand, const XlaComputation& computation, int64_t scatter_dimension, int64_t shard_count, absl::Span replica_groups = {}, const std::optional& channel_id = std::nullopt, const std::optional& layout = std::nullopt, - const std::optional use_global_device_ids = std::nullopt); + std::optional use_global_device_ids = std::nullopt); // Enqueues an operation that do an Alltoall of the operand cross cores. // An optional `layout` can be specified to force the layout of the instruction. @@ -2702,8 +2707,7 @@ XlaOp Conditional(XlaOp branch_index, absl::Span branch_operands); // Enqueues a ReducePrecision node onto the computation. -XlaOp ReducePrecision(XlaOp operand, const int exponent_bits, - const int mantissa_bits); +XlaOp ReducePrecision(XlaOp operand, int exponent_bits, int mantissa_bits); // Enqueues a Gather node onto the computation. XlaOp Gather(XlaOp input, XlaOp start_indices, diff --git a/tensorflow/compiler/xla/client/xla_builder_test.cc b/tensorflow/compiler/xla/client/xla_builder_test.cc index 1b0eb3bc073..97c7dfcfc7b 100644 --- a/tensorflow/compiler/xla/client/xla_builder_test.cc +++ b/tensorflow/compiler/xla/client/xla_builder_test.cc @@ -29,7 +29,8 @@ limitations under the License. 
#include "tensorflow/compiler/xla/hlo/ir/hlo_casting_utils.h" #include "tensorflow/compiler/xla/hlo/ir/hlo_instructions.h" #include "tensorflow/compiler/xla/hlo/ir/hlo_module.h" -#include "tensorflow/compiler/xla/hlo/utils/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/pattern_matcher.h" +#include "tensorflow/compiler/xla/service/pattern_matcher_gmock.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" @@ -40,7 +41,7 @@ namespace xla { namespace { -namespace op = xla::testing::opcode_matchers; +namespace m = ::xla::match; using ::testing::HasSubstr; @@ -80,58 +81,57 @@ TEST_F(XlaBuilderTest, OnePlusTwo) { Add(ConstantR0(&b, 1.0), ConstantR0(&b, 2.0)); TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); auto root = module->entry_computation()->root_instruction(); - EXPECT_THAT(root, op::Add(op::Constant(), op::Constant())); + EXPECT_THAT(root, GmockMatch(m::Add(m::Constant(), m::Constant()))); } TEST_F(XlaBuilderTest, UnaryOperatorsBuildExpectedHLO) { - auto test_unary_operator = - [&](std::function op, - ::testing::Matcher matches_pattern) { - XlaBuilder b(TestName()); - op(ConstantR0(&b, 1)); - TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); - auto root = module->entry_computation()->root_instruction(); - EXPECT_THAT(root, matches_pattern); - }; - test_unary_operator([](XlaOp x) { return -x; }, op::Negate(op::Constant())); - test_unary_operator([](XlaOp x) { return ~x; }, op::Not(op::Constant())); + auto test_unary_operator = [&](std::function op, + auto matches_pattern) { + XlaBuilder b(TestName()); + op(ConstantR0(&b, 1)); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, matches_pattern); + }; + test_unary_operator([](XlaOp x) { return -x; }, + GmockMatch(m::Negate(m::Constant()))); + test_unary_operator([](XlaOp x) { return ~x; }, + 
GmockMatch(m::Not(m::Constant()))); } TEST_F(XlaBuilderTest, BinaryOperatorsBuildExpectedHLO) { - auto test_binary_operator = - [&](std::function op, - ::testing::Matcher matches_pattern) { - XlaBuilder b(TestName()); - op(ConstantR0(&b, 1), ConstantR0(&b, 2)); - TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); - auto root = module->entry_computation()->root_instruction(); - EXPECT_THAT(root, matches_pattern); - }; + auto test_binary_operator = [&](std::function op, + auto matches_pattern) { + XlaBuilder b(TestName()); + op(ConstantR0(&b, 1), ConstantR0(&b, 2)); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, matches_pattern); + }; test_binary_operator([](XlaOp x, XlaOp y) { return x + y; }, - op::Add(op::Constant(), op::Constant())); + GmockMatch(m::Add(m::Constant(), m::Constant()))); test_binary_operator([](XlaOp x, XlaOp y) { return x - y; }, - op::Subtract(op::Constant(), op::Constant())); + GmockMatch(m::Subtract(m::Constant(), m::Constant()))); test_binary_operator([](XlaOp x, XlaOp y) { return x * y; }, - op::Multiply(op::Constant(), op::Constant())); + GmockMatch(m::Multiply(m::Constant(), m::Constant()))); test_binary_operator([](XlaOp x, XlaOp y) { return x / y; }, - op::Divide(op::Constant(), op::Constant())); + GmockMatch(m::Divide(m::Constant(), m::Constant()))); test_binary_operator([](XlaOp x, XlaOp y) { return x & y; }, - op::And(op::Constant(), op::Constant())); + GmockMatch(m::And(m::Constant(), m::Constant()))); test_binary_operator([](XlaOp x, XlaOp y) { return x | y; }, - op::Or(op::Constant(), op::Constant())); + GmockMatch(m::Or(m::Constant(), m::Constant()))); test_binary_operator([](XlaOp x, XlaOp y) { return x ^ y; }, - op::Xor(op::Constant(), op::Constant())); + GmockMatch(m::Xor(m::Constant(), m::Constant()))); test_binary_operator([](XlaOp x, XlaOp y) { return x << y; }, - op::ShiftLeft(op::Constant(), op::Constant())); + 
GmockMatch(m::ShiftLeft(m::Constant(), m::Constant()))); test_binary_operator( [](XlaOp x, XlaOp y) { return x >> y; }, - op::ShiftRightArithmetic(op::Constant(), op::Constant())); + GmockMatch(m::ShiftRightArithmetic(m::Constant(), m::Constant()))); auto test_unsigned_binary_operator = - [&](std::function op, - ::testing::Matcher matches_pattern) { + [&](std::function op, auto matches_pattern) { XlaBuilder b(TestName()); op(ConstantR0(&b, 1), ConstantR0(&b, 2)); TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); @@ -140,7 +140,7 @@ TEST_F(XlaBuilderTest, BinaryOperatorsBuildExpectedHLO) { }; test_unsigned_binary_operator( [](XlaOp x, XlaOp y) { return x >> y; }, - op::ShiftRightLogical(op::Constant(), op::Constant())); + GmockMatch(m::ShiftRightLogical(m::Constant(), m::Constant()))); } TEST_F(XlaBuilderTest, VariadicAnd) { @@ -151,12 +151,12 @@ TEST_F(XlaBuilderTest, VariadicAnd) { TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); // Don't specify in the test whether And(x, y, z) is right- or // left-associative; accept either one. - EXPECT_THAT( - module->entry_computation()->root_instruction(), - ::testing::AnyOf(op::And(op::Parameter(0), - op::And(op::Parameter(1), op::Parameter(2))), - op::And(op::And(op::Parameter(0), op::Parameter(1)), - op::Parameter(2)))); + EXPECT_THAT(module->entry_computation()->root_instruction(), + ::testing::AnyOf( + GmockMatch(m::And(m::Parameter(0), + m::And(m::Parameter(1), m::Parameter(2)))), + GmockMatch(m::And(m::And(m::Parameter(0), m::Parameter(1)), + m::Parameter(2))))); } TEST_F(XlaBuilderTest, VariadicOr) { @@ -167,12 +167,12 @@ TEST_F(XlaBuilderTest, VariadicOr) { TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); // Don't specify in the test whether Or(x, y, z) is right- or // left-associative; accept either one. 
- EXPECT_THAT( - module->entry_computation()->root_instruction(), - ::testing::AnyOf( - op::Or(op::Parameter(0), op::Or(op::Parameter(1), op::Parameter(2))), - op::Or(op::Or(op::Parameter(0), op::Parameter(1)), - op::Parameter(2)))); + EXPECT_THAT(module->entry_computation()->root_instruction(), + ::testing::AnyOf( + GmockMatch(m::Or(m::Parameter(0), + m::Or(m::Parameter(1), m::Parameter(2)))), + GmockMatch(m::Or(m::Or(m::Parameter(0), m::Parameter(1)), + m::Parameter(2))))); } TEST_F(XlaBuilderTest, ShiftRightOperatorOnNonIntegerProducesError) { @@ -191,7 +191,8 @@ TEST_F(XlaBuilderTest, ParamPlusConstantHasScalarBroadcast) { Add(x, ConstantR0(&b, 1.0)); TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); auto root = module->entry_computation()->root_instruction(); - EXPECT_THAT(root, op::Add(op::Parameter(), op::Broadcast(op::Constant()))); + EXPECT_THAT(root, + GmockMatch(m::Add(m::Parameter(), m::Broadcast(m::Constant())))); } TEST_F(XlaBuilderTest, ParamPlusParamHasBroadcast) { @@ -207,7 +208,8 @@ TEST_F(XlaBuilderTest, ParamPlusParamHasBroadcast) { TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); auto root = module->entry_computation()->root_instruction(); - EXPECT_THAT(root, op::Add(op::Parameter(0), op::Broadcast(op::Parameter(1)))); + EXPECT_THAT( + root, GmockMatch(m::Add(m::Parameter(0), m::Broadcast(m::Parameter(1))))); } TEST_F(XlaBuilderTest, XPlusX) { @@ -216,7 +218,7 @@ TEST_F(XlaBuilderTest, XPlusX) { Add(x, x); TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); auto root = module->entry_computation()->root_instruction(); - EXPECT_THAT(root, op::Add(op::Parameter(0), op::Parameter(0))); + EXPECT_THAT(root, GmockMatch(m::Add(m::Parameter(0), m::Parameter(0)))); } TEST_F(XlaBuilderTest, ShapeInferenceError) { @@ -268,8 +270,8 @@ TEST_F(XlaBuilderTest, Call) { Add(Call(&b, call, {x, y}), Call(&b, call, {one, two})); TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); auto root = 
module->entry_computation()->root_instruction(); - EXPECT_THAT(root, op::Add(op::Call(op::Parameter(), op::Parameter()), - op::Call(op::Constant(), op::Constant()))); + EXPECT_THAT(root, GmockMatch(m::Add(m::Call(m::Parameter(), m::Parameter()), + m::Call(m::Constant(), m::Constant())))); } TEST_F(XlaBuilderTest, BinopHasDegenerateBroadcast) { @@ -289,8 +291,9 @@ TEST_F(XlaBuilderTest, BinopHasDegenerateBroadcast) { // \ / // add auto root = module->entry_computation()->root_instruction(); - EXPECT_THAT(root, op::Add(op::Parameter(0), - op::Broadcast(op::Reshape(op::Parameter(1))))); + EXPECT_THAT(root, + GmockMatch(m::Add(m::Parameter(0), + m::Broadcast(m::Reshape(m::Parameter(1)))))); } TEST_F(XlaBuilderTest, BinopHasInDimAndDegenerateBroadcast) { @@ -314,8 +317,9 @@ TEST_F(XlaBuilderTest, BinopHasInDimAndDegenerateBroadcast) { // \ / // add auto root = module->entry_computation()->root_instruction(); - EXPECT_THAT(root, op::Add(op::Broadcast(op::Parameter(0)), - op::Broadcast(op::Reshape(op::Parameter(1))))); + EXPECT_THAT(root, + GmockMatch(m::Add(m::Broadcast(m::Parameter(0)), + m::Broadcast(m::Reshape(m::Parameter(1)))))); } TEST_F(XlaBuilderTest, BroadcastInDim) { @@ -325,7 +329,7 @@ TEST_F(XlaBuilderTest, BroadcastInDim) { /*broadcast_dimensions=*/{0, 2}); TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); auto root = module->entry_computation()->root_instruction(); - EXPECT_THAT(root, op::Broadcast()); + EXPECT_THAT(root, GmockMatch(m::Broadcast())); } TEST_F(XlaBuilderTest, BroadcastInDimWithDegeneratedDim) { @@ -335,7 +339,7 @@ TEST_F(XlaBuilderTest, BroadcastInDimWithDegeneratedDim) { /*broadcast_dimensions=*/{0, 1, 2}); TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); EXPECT_THAT(module->entry_computation()->root_instruction(), - op::Broadcast(op::Reshape(op::Broadcast()))); + GmockMatch(m::Broadcast(m::Reshape(m::Broadcast())))); } TEST_F(XlaBuilderTest, BroadcastInDimWithNegativeSize) { @@ -368,7 +372,7 @@ TEST_F(XlaBuilderTest, 
ReshapeDefaultOrder) { Reshape(x, /*new_sizes=*/{6, 35}); TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); auto root = module->entry_computation()->root_instruction(); - EXPECT_THAT(root, op::Reshape(op::Parameter())); + EXPECT_THAT(root, GmockMatch(m::Reshape(m::Parameter()))); } TEST_F(XlaBuilderTest, ReshapeHasTranspose) { @@ -377,7 +381,7 @@ TEST_F(XlaBuilderTest, ReshapeHasTranspose) { Reshape(x, /*dimensions=*/{3, 2, 1, 0}, /*new_sizes=*/{6, 35}); TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); auto root = module->entry_computation()->root_instruction(); - EXPECT_THAT(root, op::Reshape(op::Transpose(op::Parameter()))); + EXPECT_THAT(root, GmockMatch(m::Reshape(m::Transpose(m::Parameter())))); } TEST_F(XlaBuilderTest, Transpose) { @@ -386,7 +390,7 @@ TEST_F(XlaBuilderTest, Transpose) { Transpose(x, /*permutation=*/{1, 0}); TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); auto root = module->entry_computation()->root_instruction(); - EXPECT_THAT(root, op::Transpose(op::Parameter())); + EXPECT_THAT(root, GmockMatch(m::Transpose(m::Parameter()))); } TEST_F(XlaBuilderTest, AllGatherR1) { @@ -481,14 +485,48 @@ TEST_F(XlaBuilderTest, AllToAllTuple) { TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); auto root = module->entry_computation()->root_instruction(); - // AllToAll is converted into a single all-to-all HloInstruction. - EXPECT_EQ(root->opcode(), HloOpcode::kAllToAll); + // Check shape and replica groups. 
auto expected_shape = ShapeUtil::MakeShapeWithDenseLayout(F32, /* dimensions= */ {2, 4}, /* minor_to_major= */ {0, 1}); - EXPECT_THAT(root, op::ShapeWithLayout(ShapeUtil::MakeTupleShape( - {expected_shape, expected_shape}))); - EXPECT_THAT(root, op::ReplicaGroups({{0, 1}})); + auto tuple_shape = + ShapeUtil::MakeTupleShape({expected_shape, expected_shape}); + auto is_replica_group_pred = [](const HloInstruction* instr) { + return instr->replica_groups().size() == 1 && + absl::c_equal(instr->replica_groups()[0].replica_ids(), + std::vector{0, 1}); + }; + + // AllToAll is converted into a single all-to-all HloInstruction. + EXPECT_THAT(root, GmockMatch(m::Op() + .WithOpcode(HloOpcode::kAllToAll) + .WithShapeEqualTo(&tuple_shape) + .WithPredicate(is_replica_group_pred))); +} + +TEST_F(XlaBuilderTest, AllReduceTuple) { + XlaBuilder b(TestName()); + auto shape0 = ShapeUtil::MakeShape(F32, {}); + auto shape1 = ShapeUtil::MakeShape(F32, {1, 2}); + auto p0 = Parameter(&b, 0, shape0, "p0"); + auto p1 = Parameter(&b, 1, shape1, "p1"); + + XlaBuilder bsum(TestName()); + auto f32Scalar = ShapeUtil::MakeShape(F32, {}); + Add(Parameter(&bsum, 0, f32Scalar, "x"), Parameter(&bsum, 1, f32Scalar, "y")); + TF_ASSERT_OK_AND_ASSIGN(auto sum, bsum.Build()); + + AllReduceTuple({p0, p1}, sum); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + auto root = module->entry_computation()->root_instruction(); + + // Check shape and replica groups. + auto tuple_shape = ShapeUtil::MakeTupleShape({shape0, shape1}); + + // AllToAll is converted into a single all-to-all HloInstruction. + EXPECT_THAT(root, GmockMatch(m::Op() + .WithOpcode(HloOpcode::kAllReduce) + .WithShapeEqualTo(&tuple_shape))); } TEST_F(XlaBuilderTest, CollectivePermute) { @@ -514,7 +552,7 @@ TEST_F(XlaBuilderTest, GetDimensionSizeConstant) { XlaBuilder b(TestName()); auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {5, 7}, {false, true}), "x"); - // Get dimension size from a contant dimension gives us a constant. 
+ // Get dimension size from a constant dimension gives us a constant. GetDimensionSize(x, 0); TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); auto root = module->entry_computation()->root_instruction(); @@ -536,7 +574,7 @@ TEST_F(XlaBuilderTest, ReportErrorOrReturnHandlesNonErrors) { Add(b.ReportErrorOrReturn(op), ConstantR0(&b, 2.0)); TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); auto root = module->entry_computation()->root_instruction(); - EXPECT_THAT(root, op::Add(op::Constant(), op::Constant())); + EXPECT_THAT(root, GmockMatch(m::Add(m::Constant(), m::Constant()))); } TEST_F(XlaBuilderTest, ReportErrorOrReturnHandlesErrors) { @@ -554,7 +592,7 @@ TEST_F(XlaBuilderTest, BuildWithSpecificRoot) { Add(constant, ConstantR0(&b, 2.0)); TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b, /*root=*/constant)); auto root = module->entry_computation()->root_instruction(); - EXPECT_THAT(root, op::Constant()); + EXPECT_THAT(root, GmockMatch(m::Constant())); } TEST_F(XlaBuilderTest, BuildWithSpecificRootAndMultipleParameters) { @@ -568,7 +606,7 @@ TEST_F(XlaBuilderTest, BuildWithSpecificRootAndMultipleParameters) { Add(x, Sub(y, z)); TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b, /*root=*/x)); auto root = module->entry_computation()->root_instruction(); - EXPECT_THAT(root, op::Parameter()); + EXPECT_THAT(root, GmockMatch(m::Parameter())); EXPECT_EQ(module->entry_computation()->num_parameters(), 3); EXPECT_EQ(module->entry_computation()->instruction_count(), 5); } @@ -821,19 +859,19 @@ TEST_F(XlaBuilderTest, SelectIntoConditional) { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, BuildHloModule(&b)); - EXPECT_THAT( - module->entry_computation()->root_instruction(), - op::Conditional(op::Parameter(0), op::Parameter(1), op::Parameter(2))); + EXPECT_THAT(module->entry_computation()->root_instruction(), + GmockMatch(m::Conditional(m::Parameter(0), m::Parameter(1), + m::Parameter(2)))); EXPECT_THAT(module->entry_computation() ->root_instruction() 
->branch_computation(0) ->root_instruction(), - op::Parameter(0)); + GmockMatch(m::Parameter(0))); EXPECT_THAT(module->entry_computation() ->root_instruction() ->branch_computation(1) ->root_instruction(), - op::Parameter(0)); + GmockMatch(m::Parameter(0))); } TEST_F(XlaBuilderTest, DynamicPad) { @@ -1420,7 +1458,7 @@ TEST_F(XlaBuilderTest, ComparisonType) { (void)Le(ConstantR0(&b, 1), ConstantR0(&b, 2)); TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); auto root = module->entry_computation()->root_instruction(); - ASSERT_THAT(root, op::Compare(op::Constant(), op::Constant())); + ASSERT_THAT(root, GmockMatch(m::Compare(m::Constant(), m::Constant()))); EXPECT_EQ(Comparison::Type::kSigned, DynCast(root)->type()); } diff --git a/tensorflow/compiler/xla/client/xla_computation.h b/tensorflow/compiler/xla/client/xla_computation.h index c7f8280a066..d8f2d0a4d5b 100644 --- a/tensorflow/compiler/xla/client/xla_computation.h +++ b/tensorflow/compiler/xla/client/xla_computation.h @@ -16,6 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_CLIENT_XLA_COMPUTATION_H_ #define TENSORFLOW_COMPILER_XLA_CLIENT_XLA_COMPUTATION_H_ +#include +#include #include #include "tensorflow/compiler/xla/service/hlo.pb.h" @@ -32,7 +34,7 @@ class XlaComputation { XlaComputation(HloModuleProto proto) : unique_id_(proto.id()), proto_(std::move(proto)) {} - ~XlaComputation() {} + ~XlaComputation() = default; XlaComputation(const XlaComputation&) = delete; XlaComputation& operator=(const XlaComputation&) = delete; diff --git a/tensorflow/compiler/xla/comparison_util.cc b/tensorflow/compiler/xla/comparison_util.cc index 69f4c0b2100..7cda7d4457c 100644 --- a/tensorflow/compiler/xla/comparison_util.cc +++ b/tensorflow/compiler/xla/comparison_util.cc @@ -31,37 +31,14 @@ namespace { // Verifies that this is a valid Comparison: (1) not a partial ordering on // integers, and (2) a valid PrimitiveType. 
bool IsValidComparison(xla::PrimitiveType type, Comparison::Order order) { - switch (type) { - case F16: - case F32: - case BF16: - case F64: - case F8E5M2: - case F8E4M3FN: - case F8E4M3B11FNUZ: - case C64: - case C128: - return true; - case S4: - case S8: - case S16: - case S32: - case S64: - case PRED: - case U4: - case U8: - case U16: - case U32: - case U64: - return order == Comparison::Order::kTotal; - case TUPLE: - case OPAQUE_TYPE: - case TOKEN: - case PRIMITIVE_TYPE_INVALID: - case PrimitiveType_INT_MAX_SENTINEL_DO_NOT_USE_: - case PrimitiveType_INT_MIN_SENTINEL_DO_NOT_USE_: - return false; + if (primitive_util::IsFloatingPointType(type) || + primitive_util::IsComplexType(type)) { + return true; } + if (primitive_util::IsIntegralType(type) || type == PRED) { + return order == Comparison::Order::kTotal; + } + LOG(FATAL) << "Unsupported type: " << PrimitiveType_Name(type); } // Returns the X32 primitive type for each Type. @@ -91,32 +68,14 @@ Comparison::Order DefaultOrdering(Comparison::Type type) { // Returns the expected ordering for each primitive type. Comparison::Order DefaultOrdering(PrimitiveType type) { - switch (type) { - case S4: - case S8: - case S16: - case S32: - case S64: - case PRED: - case U4: - case U8: - case U16: - case U32: - case U64: - return Comparison::Order::kTotal; - case F8E5M2: - case F8E4M3FN: - case F8E4M3B11FNUZ: - case BF16: - case F16: - case F32: - case F64: - case C64: - case C128: - return Comparison::Order::kPartial; - default: - LOG(FATAL) << "Unsupported type: " << PrimitiveType_Name(type); + if (primitive_util::IsFloatingPointType(type) || + primitive_util::IsComplexType(type)) { + return Comparison::Order::kPartial; } + if (primitive_util::IsIntegralType(type) || type == PRED) { + return Comparison::Order::kTotal; + } + LOG(FATAL) << "Unsupported type: " << PrimitiveType_Name(type); } // Returns the converse of `direction`. 
@@ -248,33 +207,17 @@ StatusOr StringToComparisonType( } Comparison::Type Comparison::DefaultComparisonType(PrimitiveType type) { - switch (type) { - case S4: - case S8: - case S16: - case S32: - case S64: - return Type::kSigned; - case PRED: - case U4: - case U8: - case U16: - case U32: - case U64: - return Type::kUnsigned; - case F8E5M2: - case F8E4M3FN: - case F8E4M3B11FNUZ: - case F16: - case F32: - case BF16: - case F64: - case C64: - case C128: - return Type::kFloat; - default: - LOG(FATAL) << "Unexpected: " << PrimitiveType_Name(type); + if (primitive_util::IsFloatingPointType(type) || + primitive_util::IsComplexType(type)) { + return Type::kFloat; } + if (primitive_util::IsSignedIntegralType(type)) { + return Type::kSigned; + } + if (primitive_util::IsUnsignedIntegralType(type) || type == PRED) { + return Type::kUnsigned; + } + LOG(FATAL) << "Unexpected: " << PrimitiveType_Name(type); } Comparison::Comparison(Direction dir, PrimitiveType type, Order order) @@ -312,36 +255,10 @@ std::optional Comparison::Inverse() const { // operand is NaN. 
return std::nullopt; } - switch (primitive_type_) { - case F16: - case F32: - case BF16: - case F64: - case F8E5M2: - case F8E4M3FN: - case F8E4M3B11FNUZ: - case C64: - case C128: - case S4: - case S8: - case S16: - case S32: - case S64: - case PRED: - case U4: - case U8: - case U16: - case U32: - case U64: - return Comparison(xla::Inverse(dir_), primitive_type_, order_); - case TUPLE: - case OPAQUE_TYPE: - case TOKEN: - case PRIMITIVE_TYPE_INVALID: - case PrimitiveType_INT_MAX_SENTINEL_DO_NOT_USE_: - case PrimitiveType_INT_MIN_SENTINEL_DO_NOT_USE_: - return std::nullopt; + if (primitive_util::IsArrayType(primitive_type_)) { + return Comparison(xla::Inverse(dir_), primitive_type_, order_); } + return std::nullopt; } bool Comparison::IsReflexive() const { diff --git a/tensorflow/compiler/xla/comparison_util.h b/tensorflow/compiler/xla/comparison_util.h index 1b6f349e8b9..5cd434104b2 100644 --- a/tensorflow/compiler/xla/comparison_util.h +++ b/tensorflow/compiler/xla/comparison_util.h @@ -16,6 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_COMPARISON_UTIL_H_ #define TENSORFLOW_COMPILER_XLA_COMPARISON_UTIL_H_ +#include +#include #include #include #include @@ -183,7 +185,8 @@ class Comparison { // Applies the comparison from this Comparison's direction and ordering for // integral types. - template ::value, int> = 0> + template ::is_integer, int> = 0> inline bool Compare(const T a, const T b) const { DCHECK(primitive_util::IsCanonicalRepresentation(primitive_type_)); return GetComparator()(a, b); @@ -192,9 +195,7 @@ class Comparison { // Applies the comparison from this Comparison's direction and ordering // for floating point types. 
template ::value || - std::is_same::value, - int> = 0> + absl::enable_if_t::is_integer, int> = 0> inline bool Compare(const T a, const T b) const { DCHECK(primitive_util::IsCanonicalRepresentation(primitive_type_)); if (IsTotalOrder()) { diff --git a/tensorflow/compiler/xla/cpu_function_runtime.h b/tensorflow/compiler/xla/cpu_function_runtime.h index 14c5fec2ff8..151fc90d2d0 100644 --- a/tensorflow/compiler/xla/cpu_function_runtime.h +++ b/tensorflow/compiler/xla/cpu_function_runtime.h @@ -24,17 +24,37 @@ limitations under the License. namespace xla { namespace cpu_function_runtime { + +struct EncodedBufferInfo { + uint64_t packed_kind_and_size = 0; + uint32_t entry_param_number = -1; + uint32_t result_param_number = -1; +}; + // Stores information about one buffer used by an XLA:CPU compiled function. // These buffers are used for holding inputs to the computation, outputs from // the computation and as temporary scratch space. class BufferInfo { public: // Creates a BufferInfo from a serialized encoding generated by `Encode`. - explicit BufferInfo(std::pair encoding) - : entry_param_number_(encoding.second) { + // TODO(ecg): remove once there are no users left. + explicit BufferInfo(uint64_t packed_kind_and_size, + uint32_t entry_param_number, uint32_t result_param_number) + : entry_param_number_(entry_param_number), + result_param_number_(result_param_number) { Kind kind; uint64_t size; - Unpack(encoding.first, &kind, &size); + Unpack(packed_kind_and_size, &kind, &size); + kind_ = kind; + size_ = size; + } + + explicit BufferInfo(const EncodedBufferInfo& encoded) + : entry_param_number_(encoded.entry_param_number), + result_param_number_(encoded.result_param_number) { + Kind kind; + uint64_t size; + Unpack(encoded.packed_kind_and_size, &kind, &size); kind_ = kind; size_ = size; } @@ -46,14 +66,31 @@ class BufferInfo { // Returns true if this buffer stores an entry parameter. 
These may or may // not need to be allocated by the runtime, depending on // XlaCompiledCpuFunction::AllocMode. - bool is_entry_parameter() const { return kind() == Kind::kEntryParameter; } + bool is_entry_parameter() const { + return kind() == Kind::kParameter && entry_param_number_ >= 0; + } // Returns the entry parameter number of this buffer. - uint64_t entry_parameter_number() const { + uint32_t entry_parameter_number() const { assert(is_entry_parameter()); return entry_param_number_; } + void set_result_parameter_number(uint32_t param_number) { + result_param_number_ = param_number; + } + + bool is_result_parameter() const { + // Note: the kind is not unique, e.g. could be a kTempBuffer, or a + // kParameter if it is an in-out argument. + return result_param_number_ >= 0; + } + + uint32_t result_parameter_number() const { + assert(is_result_parameter()); + return result_param_number_; + } + // Returns true if this buffer is temporary scratch space required by the XLA // computations. These are always allocated by the runtime. bool is_temp_buffer() const { return kind() == Kind::kTempBuffer; } @@ -69,11 +106,13 @@ class BufferInfo { // reconstruct the BufferInfo later using the constructor. We need this // because we use BufferInfo in places where using protocol buffers would // negatively impact binary size. 
- std::pair Encode() const { + EncodedBufferInfo Encode() const { static_assert(sizeof(*this) == 16, ""); - uint64_t upper = Pack(kind(), size_); - uint64_t lower = entry_param_number_; - return {upper, lower}; + EncodedBufferInfo ret; + ret.packed_kind_and_size = Pack(kind(), size_); + ret.entry_param_number = entry_param_number_; + ret.result_param_number = result_param_number_; + return ret; } bool operator==(const BufferInfo& buffer_info) const { @@ -87,20 +126,26 @@ class BufferInfo { // Factory methods: static BufferInfo MakeTempBuffer(uint64_t size) { - return BufferInfo(Kind::kTempBuffer, /*size=*/size, - /*entry_param_number=*/-1); + return BufferInfo(Kind::kTempBuffer, size); } static BufferInfo MakeConstant(uint64_t size) { - return BufferInfo(Kind::kConstant, /*size=*/size, - /*entry_param_number=*/-1); + return BufferInfo(Kind::kConstant, size); } - static BufferInfo MakeEntryParameter(uint64_t size, uint64_t param_number) { - return BufferInfo(Kind::kEntryParameter, /*size=*/size, - /*entry_param_number=*/param_number); + // Note: in-out parameters are possible by first creating an entry parameter + // and then calling set_result_parameter_number(). + static BufferInfo MakeEntryParameter(uint64_t size, + uint32_t entry_param_number) { + return BufferInfo(Kind::kParameter, size, entry_param_number); + } + // Only used in tests. Here we use kTempBuffer but it is unimportant. 
+ static BufferInfo MakeResultParameter(uint64_t size, + uint32_t result_param_number) { + // Here we + return BufferInfo(Kind::kTempBuffer, size, /*entry_param_number=*/-1, + result_param_number); } static BufferInfo MakeOnStackBuffer(uint64_t size) { - return BufferInfo(Kind::kOnStackBuffer, /*size=*/size, - /*entry_param_number=*/-1); + return BufferInfo(Kind::kOnStackBuffer, size); } private: @@ -109,14 +154,25 @@ class BufferInfo { enum class Kind : uint64_t { kConstant, kTempBuffer, - kEntryParameter, + kParameter, kOnStackBuffer }; Kind kind() const { return static_cast(kind_); } - explicit BufferInfo(Kind kind, uint64_t size, uint64_t entry_param_number) - : kind_(kind), size_(size), entry_param_number_(entry_param_number) {} + explicit BufferInfo(Kind kind, uint64_t size) + : BufferInfo(kind, size, + /*entry_param_number=*/-1, + /*result_param_number=*/-1) {} + explicit BufferInfo(Kind kind, uint64_t size, uint32_t entry_param_number) + : BufferInfo(kind, size, entry_param_number, + /*result_param_number=*/-1) {} + explicit BufferInfo(Kind kind, uint64_t size, uint32_t entry_param_number, + uint32_t result_param_number) + : kind_(kind), + size_(size), + entry_param_number_(entry_param_number), + result_param_number_(result_param_number) {} static uint64_t Pack(Kind kind, uint64_t size) { return (static_cast(size) << 2) | static_cast(kind); @@ -129,7 +185,8 @@ class BufferInfo { Kind kind_ : 2; uint64_t size_ : 62; - int64_t entry_param_number_; + int32_t entry_param_number_ = -1; + int32_t result_param_number_ = -1; }; // Align to 64-bytes, to mimic tsl::Allocator::kAllocatorAlignment. 
diff --git a/tensorflow/compiler/xla/debug_options_flags.cc b/tensorflow/compiler/xla/debug_options_flags.cc index dabed025c92..783af8e2300 100644 --- a/tensorflow/compiler/xla/debug_options_flags.cc +++ b/tensorflow/compiler/xla/debug_options_flags.cc @@ -81,6 +81,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { opts.set_xla_gpu_cuda_graph_instantiation_threshold(2); opts.set_xla_gpu_enable_persistent_temp_buffers(false); opts.set_xla_gpu_cuda_graph_capture_threshold(2); + opts.set_xla_gpu_cuda_graph_enable_concurrent_region(true); // Despite the name, fast min/max on GPUs does not seem to be any faster, and // adds very counter-intuitive "NaN-swallowing" behavior. @@ -90,6 +91,8 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { opts.set_xla_allow_excess_precision(true); opts.set_xla_force_host_platform_device_count(1); opts.set_xla_gpu_all_reduce_combine_threshold_bytes(30 * 1024 * 1024); + opts.set_xla_gpu_all_gather_combine_threshold_bytes(1024 * 1024 * 1024); + opts.set_xla_gpu_reduce_scatter_combine_threshold_bytes(30 * 1024 * 1024); opts.set_xla_gpu_enable_async_all_reduce(true); opts.set_xla_gpu_enable_reassociation_for_converted_ar(true); @@ -105,13 +108,14 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { // Set 4GB space limit for redzone scratch allocator. 
opts.set_xla_gpu_redzone_scratch_max_megabytes(1LL << 12); opts.set_xla_gpu_shape_checks(DebugOptions::RUNTIME); - opts.set_xla_gpu_enable_mlir_lowering(true); opts.set_xla_gpu_normalize_layouts(true); opts.set_xla_gpu_simplify_all_fp_conversions(true); opts.set_xla_dump_latency_hiding_schedule(false); opts.set_xla_gpu_enable_latency_hiding_scheduler(false); opts.set_xla_gpu_lhs_enable_gpu_async_tracker(false); - opts.set_xla_gpu_pgle_profile_directory(""); + opts.set_xla_gpu_pgle_profile_file_or_directory_path(""); + opts.set_xla_gpu_enable_highest_priority_async_stream(false); + opts.set_xla_gpu_enable_data_parallel_collective_optimizer(false); opts.set_xla_cpu_enable_mlir_tiling_and_fusion(true); opts.set_xla_cpu_enable_custom_matmul_tiling(false); @@ -128,11 +132,14 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { opts.set_xla_gpu_enable_cudnn_int8x32_convolution_reordering(true); opts.set_xla_gpu_triton_gemm_any(false); - // Moving reduce-scatter out of while loops can incrase memory footprint, so + // Moving reduce-scatter out of while loops can increase memory footprint, so // turning it off by default. opts.set_xla_gpu_enable_while_loop_reduce_scatter_code_motion(false); opts.set_xla_gpu_collective_inflation_factor(1); + + opts.set_xla_gpu_enable_experimental_block_size(false); + return opts; } @@ -265,11 +272,6 @@ void MakeDebugOptionsFlags(std::vector* flag_list, return true; }; - auto setter_for_xla_gpu_enable_mlir_lowering = [debug_options](bool value) { - debug_options->set_xla_gpu_enable_mlir_lowering(value); - return true; - }; - // Custom "sub-parser" lambda for xla_partitioning_algorithm. 
auto setter_for_xla_partitioning_algorithm = [debug_options](const std::string& value) { @@ -773,6 +775,18 @@ void MakeDebugOptionsFlags(std::vector* flag_list, &DebugOptions::set_xla_gpu_all_reduce_combine_threshold_bytes), debug_options->xla_gpu_all_reduce_combine_threshold_bytes(), "Size threshold (in bytes) for the GPU all-reduce combiner.")); + flag_list->push_back(tsl::Flag( + "xla_gpu_all_gather_combine_threshold_bytes", + int64_setter_for( + &DebugOptions::set_xla_gpu_all_gather_combine_threshold_bytes), + debug_options->xla_gpu_all_gather_combine_threshold_bytes(), + "Size threshold (in bytes) for the GPU all-gather combiner.")); + flag_list->push_back(tsl::Flag( + "xla_gpu_reduce_scatter_combine_threshold_bytes", + int64_setter_for( + &DebugOptions::set_xla_gpu_reduce_scatter_combine_threshold_bytes), + debug_options->xla_gpu_reduce_scatter_combine_threshold_bytes(), + "Size threshold (in bytes) for the GPU reduce-scatter combiner.")); flag_list->push_back(tsl::Flag( "xla_gpu_all_reduce_contiguous", bool_setter_for(&DebugOptions::set_xla_gpu_all_reduce_contiguous), @@ -844,6 +858,13 @@ void MakeDebugOptionsFlags(std::vector* flag_list, debug_options->xla_gpu_cuda_graph_capture_threshold(), "Capture a region as a function to be launched as cuda graph if the " "number of moved instructions reaches this threshold.")); + flag_list->push_back(tsl::Flag( + "xla_gpu_cuda_graph_enable_concurrent_region", + bool_setter_for( + &DebugOptions::set_xla_gpu_cuda_graph_enable_concurrent_region), + debug_options->xla_gpu_cuda_graph_enable_concurrent_region(), + "Identify concurrent regions in cuda graphs and execute them " + "concurrently.")); flag_list->push_back(tsl::Flag( "xla_gpu_enable_persistent_temp_buffers", @@ -904,10 +925,6 @@ void MakeDebugOptionsFlags(std::vector* flag_list, "xla_gpu_shape_checks", setter_for_xla_gpu_shape_checks, DebugOptions::ShapeChecks_Name(debug_options->xla_gpu_shape_checks()), "When to perform shape checks in XLA:GPU.")); - 
flag_list->push_back(tsl::Flag( - "xla_gpu_enable_mlir_lowering", setter_for_xla_gpu_enable_mlir_lowering, - debug_options->xla_gpu_enable_mlir_lowering(), - "Enable MLIR-based lowering in XLA:GPU instead of LLVM emitters.")); flag_list->push_back( tsl::Flag("xla_gpu_normalize_layouts", bool_setter_for(&DebugOptions::set_xla_gpu_normalize_layouts), @@ -969,15 +986,28 @@ void MakeDebugOptionsFlags(std::vector* flag_list, debug_options->xla_gpu_enable_latency_hiding_scheduler(), "Enable latency-hiding scheduler for XLA:GPU")); flag_list->push_back(tsl::Flag( - "xla_gpu_pgle_profile_directory", - string_setter_for(&DebugOptions::set_xla_gpu_pgle_profile_directory), - debug_options->xla_gpu_pgle_profile_directory(), - "Directory for PGLE profiles in XLA:GPU")); + "xla_gpu_pgle_profile_file_or_directory_path", + string_setter_for( + &DebugOptions::set_xla_gpu_pgle_profile_file_or_directory_path), + debug_options->xla_gpu_pgle_profile_file_or_directory_path(), + "Directory or file for PGLE profiles in XLA:GPU")); flag_list->push_back(tsl::Flag( "xla_gpu_lhs_enable_gpu_async_tracker", bool_setter_for(&DebugOptions::set_xla_gpu_lhs_enable_gpu_async_tracker), debug_options->xla_gpu_lhs_enable_gpu_async_tracker(), "Enable GPU async tracker for latency-hiding scheduler in XLA:GPU")); + flag_list->push_back(tsl::Flag( + "xla_gpu_enable_highest_priority_async_stream", + bool_setter_for( + &DebugOptions::set_xla_gpu_enable_highest_priority_async_stream), + debug_options->xla_gpu_enable_highest_priority_async_stream(), + "Enable async stream to have the highest priority.")); + flag_list->push_back(tsl::Flag( + "xla_gpu_enable_data_parallel_collective_optimizer", + bool_setter_for( + &DebugOptions::set_xla_gpu_enable_data_parallel_collective_optimizer), + debug_options->xla_gpu_enable_data_parallel_collective_optimizer(), + "Enable data parallel collective optimizer.")); flag_list->push_back(tsl::Flag( "xla_partitioning_algorithm", setter_for_xla_partitioning_algorithm, 
DebugOptions::PartitioningAlgorithm_Name( @@ -1001,6 +1031,12 @@ void MakeDebugOptionsFlags(std::vector* flag_list, debug_options->xla_gpu_triton_gemm_any(), "Use Triton-based matrix multiplication for any GEMM it " "supports without filtering only faster ones.")); + flag_list->push_back( + tsl::Flag("xla_gpu_enable_experimental_block_size", + bool_setter_for( + &DebugOptions::set_xla_gpu_enable_experimental_block_size), + debug_options->xla_gpu_enable_experimental_block_size(), + "Enable experimental block size.")); } // NOLINT(readability/fn_size) // Allocates flag_values and flag_objects; this function must not be called more diff --git a/tensorflow/compiler/xla/examples/axpy/BUILD b/tensorflow/compiler/xla/examples/axpy/BUILD index a2e266481dd..8c1442922c3 100644 --- a/tensorflow/compiler/xla/examples/axpy/BUILD +++ b/tensorflow/compiler/xla/examples/axpy/BUILD @@ -1,5 +1,7 @@ load("//tensorflow/compiler/xla:xla.bzl", "xla_cc_test") +# copybara:uncomment package(default_applicable_licenses = ["//tensorflow:license"]) + xla_cc_test( name = "stablehlo_compile_test", srcs = ["stablehlo_compile_test.cc"], diff --git a/tensorflow/compiler/xla/experiments/sm_bandwidth_benchmark/BUILD b/tensorflow/compiler/xla/experiments/sm_bandwidth_benchmark/BUILD index 769e8a76bb0..c09b88d80c9 100644 --- a/tensorflow/compiler/xla/experiments/sm_bandwidth_benchmark/BUILD +++ b/tensorflow/compiler/xla/experiments/sm_bandwidth_benchmark/BUILD @@ -1,9 +1,11 @@ load("@local_config_cuda//cuda:build_defs.bzl", "cuda_library", "if_cuda") +load("//tensorflow/compiler/xla:xla.bzl", "xla_cc_test") + +# copybara:uncomment package(default_applicable_licenses = ["//tensorflow:license"]) cc_library( name = "sm_bw_utils", hdrs = ["sm_bw_utils.h"], - defines = if_cuda(["GOOGLE_CUDA=1"]), deps = [ "//tensorflow/tsl/platform:logging", ] + if_cuda([ @@ -20,13 +22,15 @@ cuda_library( ], ) -cc_test( +xla_cc_test( name = "sm_bw_test", srcs = ["sm_bw_test.cc"], - tags = ["requires-gpu-sm80-only"], + tags 
= ["requires-gpu-nvidia"], deps = [ ":sm_bw_kernels", ":sm_bw_utils", "@com_google_googletest//:gtest_main", - ], + ] + if_cuda([ + "//tensorflow/tsl/platform:cuda", + ]), ) diff --git a/tensorflow/compiler/xla/experiments/sm_bandwidth_benchmark/sm_bw_kernels.cu.cc b/tensorflow/compiler/xla/experiments/sm_bandwidth_benchmark/sm_bw_kernels.cu.cc index d0bc62cd0de..e170e44d66d 100644 --- a/tensorflow/compiler/xla/experiments/sm_bandwidth_benchmark/sm_bw_kernels.cu.cc +++ b/tensorflow/compiler/xla/experiments/sm_bandwidth_benchmark/sm_bw_kernels.cu.cc @@ -77,21 +77,22 @@ template __launch_bounds__(kMaxBlockSize) __global__ void BenchmarkDeviceCopyKernel(const float* __restrict__ in, float* __restrict__ out, int64_t size) { + constexpr int kVecWidth = chunks < 4 ? 1 : 4; const int64_t lines = size / (blockDim.x * chunks); const int64_t start_line = lines * blockIdx.x / gridDim.x; const int64_t end_line = lines * (blockIdx.x + 1) / gridDim.x; const int64_t start_offset = - start_line * blockDim.x * chunks + 4 * threadIdx.x; + start_line * blockDim.x * chunks + kVecWidth * threadIdx.x; const int64_t end_offset = end_line * blockDim.x * chunks; - Vec buffer[chunks / 4]; + Vec buffer[chunks / kVecWidth]; for (int64_t i = start_offset; i < end_offset; i += blockDim.x * chunks) { #pragma unroll - for (int j = 0; j < chunks; j += 4) { - LoadNc(buffer[j / 4], in + i + blockDim.x * j, 0); + for (int j = 0; j < chunks; j += kVecWidth) { + LoadNc(buffer[j / kVecWidth], in + i + blockDim.x * j, 0); } #pragma unroll - for (int j = 0; j < chunks; j += 4) { - Store(buffer[j / 4], out + i + blockDim.x * j, 0); + for (int j = 0; j < chunks; j += kVecWidth) { + Store(buffer[j / kVecWidth], out + i + blockDim.x * j, 0); } } } diff --git a/tensorflow/compiler/xla/experiments/triton_autotuning/__init__.py b/tensorflow/compiler/xla/experiments/triton_autotuning/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git 
a/tensorflow/compiler/xla/experiments/triton_autotuning/check_csv.py b/tensorflow/compiler/xla/experiments/triton_autotuning/check_csv.py new file mode 100755 index 00000000000..c38fe4fef75 --- /dev/null +++ b/tensorflow/compiler/xla/experiments/triton_autotuning/check_csv.py @@ -0,0 +1,95 @@ +#!/usr/bin/python3 +# Copyright 2023 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Measures timings of tilings provided in a CSV file.""" +import sys + +from absl import app +from absl import flags +from matmul_lib import benchmark_matmul +from matmul_lib import MatmulSize +from matmul_lib import MatmulTiling +from matmul_lib import QuantizedInputType +import pandas as pd +import torch +import tqdm + +_DATA = flags.DEFINE_string('data', '', 'Data to check') +_OUTPUT_FILE = flags.DEFINE_string( + 'output_file', '/tmp/checked.csv', 'File to write output data to' +) +_NUM_SAMPLES = flags.DEFINE_integer( + 'num_samples', 100, 'Number of samples to check' +) +_M = flags.DEFINE_integer('m', 64, 'Size of first matrix') +_K = flags.DEFINE_integer('k', 64, 'Size of contracting dimension') +_N = flags.DEFINE_integer('n', 64, 'Size of second matrix') +_QUANTIZED_LHS = flags.DEFINE_enum_class( + 'quantized_lhs', + QuantizedInputType.FULL, + QuantizedInputType, + 'Type to use for LHS quantization', +) + + +def get_actual_time(r, s, pbar): + dims = 
MatmulSize(_M.value, _N.value, _K.value, _QUANTIZED_LHS.value) + return benchmark_matmul( + dims=dims, + pbar=pbar, + shared_stream=s, + tilings=[ + MatmulTiling( + r.block_m, + r.block_n, + r.block_k, + r.split_k, + r.num_stages, + r.num_warps, + ) + ], + repetitions_ms=300, + )[0].min_time_ms + + +def main(): + df = pd.read_csv(_DATA.value).sample(_NUM_SAMPLES.value) + shared_stream = torch.cuda.Stream() + measured_times = [] + pbar = tqdm.tqdm(total=_NUM_SAMPLES.value, ncols=0) + with torch.cuda.stream(shared_stream): + for _, r in df.iterrows(): + measured_times.append(get_actual_time(r, shared_stream, pbar)) + df = df.assign(measured_min_time_ms=measured_times) + pbar.close() + + def absolute_error(r): + return abs(r.measured_min_time_ms - r.min_time_ms) + + def relative_error(r): + return absolute_error(r) / r.min_time_ms + + errors = df.assign(absolute_error=absolute_error).assign( + relative_error=relative_error + )[['absolute_error', 'relative_error']] + print(errors) + print(errors.describe()) + df.to_csv(_OUTPUT_FILE.value) + + +if __name__ == '__main__': + app.parse_flags_with_usage(sys.argv) + main() diff --git a/tensorflow/compiler/xla/experiments/triton_autotuning/check_data.py b/tensorflow/compiler/xla/experiments/triton_autotuning/check_data.py new file mode 100755 index 00000000000..8ad71b671bd --- /dev/null +++ b/tensorflow/compiler/xla/experiments/triton_autotuning/check_data.py @@ -0,0 +1,81 @@ +#!/usr/bin/python3 +# Copyright 2023 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Plot actual min time vs estimated min time from Triton performance model.""" + +from collections.abc import Sequence +from absl import app +import pandas as pd +import plotext as plt +import torch +import triton + + +def main(argv: Sequence[str]) -> None: + if len(argv) != 2: + raise app.UsageError('Incorrect number of command-line arguments.') + f = argv[1] + + df = pd.read_csv( + f, + dtype={ + 'M': int, + 'N': int, + 'K': int, + 'BLOCK_M': int, + 'BLOCK_N': int, + 'BLOCK_K': int, + 'SPLIT_K': int, + 'num_stages': int, + 'num_warps': int, + 'min_time_ms': float, + }, + ) + grouped_df = df.groupby(['M', 'N', 'K']).min().sort_values('min_time_ms') + + estimated_times = [] + actual_times = [] + + matrix = torch.randn(1, 1, device='cuda', dtype=torch.float16) + for dims, r in grouped_df.iterrows(): + m, n, k = dims + estimated_time = triton.ops.matmul_perf_model.estimate_matmul_time( + num_warps=r.num_warps, + num_stages=r.num_stages, + A=matrix, + B=matrix, + C=matrix, + M=m, + N=n, + K=k, + BLOCK_M=r.BLOCK_M, + BLOCK_N=r.BLOCK_N, + BLOCK_K=r.BLOCK_K, + SPLIT_K=r.SPLIT_K, + ) + actual_times.append(r.min_time_ms) + estimated_times.append(estimated_time) + + plt.theme('dark') + plt.plot(actual_times, estimated_times) + plt.xlabel('Actual Time (ms)') + plt.ylabel('Estimated Time (ms)') + plt.title('Estimated time as a function of actual time') + plt.show() + + +if __name__ == '__main__': + app.run(main) diff --git a/tensorflow/compiler/xla/experiments/triton_autotuning/matmul_lib.py b/tensorflow/compiler/xla/experiments/triton_autotuning/matmul_lib.py new file mode 100755 index 00000000000..2497b2100bb --- /dev/null +++ b/tensorflow/compiler/xla/experiments/triton_autotuning/matmul_lib.py @@ -0,0 +1,451 @@ +# Copyright 2023 The TensorFlow Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Library for running matmuls.""" +import enum +import itertools +import logging +import math +import typing + +import torch +import tqdm +import triton +import triton.language as tl + +LOG = logging.getLogger(__name__) + +logging.basicConfig( + format=( + '%(asctime)s,%(msecs)03d %(levelname)-8s [%(filename)s:%(lineno)d]' + ' %(threadName)15s: %(message)s' + ), + datefmt='%Y-%m-%d:%H:%M:%S', + level=logging.INFO, +) + + +@enum.unique +class QuantizedInputType(enum.Enum): + """Type to use for quantized matmul inputs.""" + + FULL = 'full' + INT8 = 'int8' + FLOAT8 = 'float8' + + +class MatmulTiling(typing.NamedTuple): + """Tiling parameterization of a matmul.""" + + BLOCK_M: int + BLOCK_N: int + BLOCK_K: int + SPLIT_K: int + num_stages: int + num_warps: int + + +class MatmulSize(typing.NamedTuple): + """[M, K] @ [K, N].""" + + M: int + N: int + K: int + quantized_lhs: QuantizedInputType + + +class MatmulTiming(typing.NamedTuple): + """Timing result of a configuration.""" + + dims: MatmulSize + tiling: MatmulTiling + min_time_ms: float + + +def parse_int_list(v: str) -> typing.List[int]: + """Converts a string of comma-separated ints into a list of strings.""" + return list(map(int, v.split(','))) + + +def generate_tiling_configs( + tilings_m: typing.List[int], + tilings_n: typing.List[int], + tilings_k: typing.List[int], + split_ks: 
typing.List[int], + num_stages: typing.List[int], + num_warps: typing.List[int], +) -> typing.Iterator[MatmulTiling]: + """Generate a list of matmul configs to evaluate.""" + product = itertools.product( + tilings_m, + tilings_n, + tilings_k, + split_ks, + num_stages, + num_warps, + ) + return [MatmulTiling(*p) for p in product] + + +@triton.jit +def _fix_type_for_load(x): + """Bitcasts a pointer to a type that can be loaded by Triton.""" + load_dtype = x.dtype + if x.dtype == tl.pointer_type(tl.float8e5): + load_dtype = tl.pointer_type(tl.int8) + return x.to(load_dtype, bitcast=True) + + +@triton.jit +def _matmul_kernel( + lhs, + rhs, + out, + m: tl.constexpr, + n: tl.constexpr, + k: tl.constexpr, + stride_am: tl.constexpr, + stride_ak: tl.constexpr, + stride_bk: tl.constexpr, + stride_bn: tl.constexpr, + stride_cm: tl.constexpr, + stride_cn: tl.constexpr, + block_m: tl.constexpr, + block_n: tl.constexpr, + block_k: tl.constexpr, + group_m: tl.constexpr, + split_k: tl.constexpr, + acc_ty: tl.constexpr, + # Workaround for a bug in Triton cache: + # force recompilation on different num_warps/num_stages. 
+ force_num_warps: tl.constexpr, # pylint: disable=unused-argument + force_num_stages: tl.constexpr, # pylint: disable=unused-argument +): + """Computes a block-level matmul.""" + even_k = k % (block_k * split_k) == 0 + pid0 = tl.program_id(0) + pid1 = tl.program_id(1) + pid2 = tl.program_id(2) + grid_m = (m + block_m - 1) // block_m + grid_n = (n + block_n - 1) // block_n + # re-order program ID for better L2 performance + width = group_m * grid_n + group_id = pid0 // width + group_size = min(grid_m - group_id * group_m, group_m) + pid_m = group_id * group_m + pid0 % group_size + pid_n = (pid0 % width) // group_size + rm = pid_m * block_m + tl.arange(0, block_m) + rn = pid_n * block_n + tl.arange(0, block_n) + ram = tl.max_contiguous(tl.multiple_of(rm % m, block_m), block_m) + rbn = tl.max_contiguous(tl.multiple_of(rn % n, block_n), block_n) + rk = pid1 * block_k + tl.arange(0, block_k) + lhs += ram[:, None] * stride_am + rk[None, :] * stride_ak + pid2 * m * k + rhs += rk[:, None] * stride_bk + rbn[None, :] * stride_bn + acc = tl.zeros((block_m, block_n), dtype=acc_ty) + # for ki in range(0, k, block_k * split_k): # pytype: disable=wrong-arg-types + for ki in range(k, 0, -block_k * split_k): # pytype: disable=wrong-arg-types + if even_k: + a = tl.load(_fix_type_for_load(lhs)) + b = tl.load(rhs) + else: + a = tl.load(_fix_type_for_load(lhs), mask=rk[None, :] < ki, other=0) + b = tl.load(rhs, mask=rk[:, None] < ki, other=0) + casted_a = a.to(lhs.dtype.element_ty, bitcast=True).to(out.dtype.element_ty) + casted_b = b.to(out.dtype.element_ty) + acc += tl.dot(casted_a, casted_b, allow_tf32=True) + lhs += block_k * split_k * stride_ak + rhs += block_k * split_k * stride_bk + acc = acc.to(out.dtype.element_ty) + # rematerialize rm and rn to save registers + rm = pid_m * block_m + tl.arange(0, block_m) + rn = pid_n * block_n + tl.arange(0, block_n) + out += rm[:, None] * stride_cm + rn[None, :] * stride_cn + pid2 * m * n + out += m * n * pid1 + mask = (rm < m)[:, None] & 
(rn < n)[None, :] + tl.store(out, acc, mask=mask) + + +@triton.jit +def _reduce_kernel( + src, + dest, + row_size: tl.constexpr, + col_size: tl.constexpr, + row_block_size: tl.constexpr, +): + """Computes a column reduction.""" + pid0 = tl.program_id(0) + idx = pid0 * row_block_size + tl.arange(0, row_block_size) + src += idx + acc = tl.zeros((row_block_size,), dtype=dest.dtype.element_ty) + for _ in range(col_size): + acc += tl.load(src, mask=idx < row_size, other=0) + src += row_size + tl.store(dest + idx, acc, mask=idx < row_size) + + +@triton.jit +def _to_f8_kernel(src, dest, size, block_size: tl.constexpr): + pid = tl.program_id(0) + offs = pid * block_size + tl.arange(0, block_size) + mask = offs < size + x = tl.load(src + offs, mask=mask) + y = x.to(tl.float8e5) + tl.store(dest + offs, y, mask=mask) + + +def to_triton_f8(x: torch.Tensor) -> triton.TensorWrapper: + """Converts torch tensors to triton.language.float8e5.""" + assert x.is_contiguous(), 'Kernel only works for contiguous tensors' + ret = triton.reinterpret( + torch.empty(x.shape, dtype=torch.int8, device=x.device, layout=x.layout), + tl.float8e5, + ) + grid = lambda META: (triton.cdiv(x.numel(), META['block_size']),) + _to_f8_kernel[grid](ret, x, x.numel(), block_size=1024) + return ret + + +def benchmark_matmul_tiling( + dims: MatmulSize, + tiling: MatmulTiling, + s: torch.cuda.Stream, + shared_stream: torch.cuda.Stream, + a: torch.Tensor | triton.TensorWrapper, + b: torch.Tensor, + c: torch.Tensor, + scratchpad: torch.Tensor, # Largest size: c * SPLIT_K + repetitions_ms: int, + debug=False, +) -> typing.Optional[MatmulTiming]: + """Benchmarks a single matmul tiling.""" + grid = lambda META: ( # pylint: disable=g-long-lambda + triton.cdiv(dims.M, tiling.BLOCK_M) * triton.cdiv(dims.N, tiling.BLOCK_N), + tiling.SPLIT_K, + 1, # batch + ) + data_a = getattr(a, 'base', a) + + def run_matmul(): + used_output = c if tiling.SPLIT_K == 1 else scratchpad + _matmul_kernel[grid]( + a, + b, + used_output, + 
m=int(dims.M), + n=int(dims.N), + k=int(dims.K), + stride_am=data_a.stride(0), + stride_ak=data_a.stride(1), + stride_bk=b.stride(0), + stride_bn=b.stride(1), + stride_cm=c.stride(0), + stride_cn=c.stride(1), + block_m=int(tiling.BLOCK_M), + block_n=int(tiling.BLOCK_N), + block_k=int(tiling.BLOCK_K), + group_m=8, + split_k=tiling.SPLIT_K, + num_warps=tiling.num_warps, + num_stages=tiling.num_stages, + force_num_warps=tiling.num_warps, + force_num_stages=tiling.num_stages, + acc_ty=tl.float32, + ) + if tiling.SPLIT_K != 1: + # Run reduction kernel. + _reduce_kernel[(triton.cdiv(dims.M * dims.N, 1024),)]( + scratchpad, + c, + row_size=int(dims.M), + col_size=tiling.SPLIT_K, + num_stages=1, + num_warps=1024 // 32, + row_block_size=1024, + ) + + for dim in ['M', 'N', 'K']: + next_pow2 = lambda v: 2 ** int(math.ceil(math.log2(v))) + dim_size: int = getattr(dims, dim) + if dim == 'K': + dim_size = math.ceil(dim_size / tiling.SPLIT_K) + tile_size = getattr(tiling, f'BLOCK_{dim}') + if next_pow2(dim_size) < tile_size: + if debug: + LOG.error( + 'Tile %s larger than the dimension %s (%s)', + tile_size, + dim, + dim_size, + ) + return None + + if tiling.BLOCK_M * tiling.BLOCK_N > 131072: + if debug: + LOG.error('Overly large tile') + return None + + # TODO(cheshire): Compilation time is huge for such tiles. + if tiling.BLOCK_M > 512 or tiling.BLOCK_N > 512: + if debug: + LOG.error('Overly large tile') + return None + + max_shared_memory = triton.runtime.driver.utils.get_device_properties( + torch.cuda.current_device() + )['max_shared_mem'] + + required_shared_memory = ( + (tiling.BLOCK_M + tiling.BLOCK_N) + * tiling.BLOCK_K + * tiling.num_stages + * b.element_size() + ) + if required_shared_memory > max_shared_memory: + if debug: + LOG.error('Skipping %s due to exceeding shmem bound', tiling) + return None + with torch.cuda.stream(s): + try: + run_matmul() # Warmup on our own stream. 
+ except Exception as exc: + LOG.error('%s for %s generated %s', tiling, dims, exc, exc_info=True) + raise + + # Use shared stream to take actual measurements. + with torch.cuda.stream(shared_stream): + try: + percentiles = triton.testing.do_bench( + run_matmul, + warmup=0, + rep=repetitions_ms, + quantiles=(0.001, 0.1, 0.5, 0.9), + ) + min_ms = percentiles[0] + except Exception as exc: + LOG.error('%s for %s generated %s', tiling, dims, exc, exc_info=True) + raise + return MatmulTiming(dims, tiling, min_ms) + + +def benchmark_cublas(dims: MatmulSize) -> MatmulTiming: + """Measure cublas performance.""" + a = torch.randn(dims.M, dims.K, device='cuda', dtype=torch.bfloat16) + b = torch.randn(dims.K, dims.N, device='cuda', dtype=torch.bfloat16) + run_matmul = lambda: torch.matmul(a, b) + percentiles = triton.testing.do_bench( + run_matmul, warmup=0, rep=300, quantiles=(0.001, 0.1, 0.5, 0.9) + ) + min_ms = percentiles[0] + return min_ms + + +def benchmark_matmul( + dims: MatmulSize, + pbar: tqdm.std.tqdm, + shared_stream: torch.cuda.Stream, + tilings: typing.List[MatmulTiling], + repetitions_ms: int, + debug=False, +) -> typing.Sequence[MatmulTiming]: + """For a given matmul configuration, benchmark it. + + Args: + dims: the dimensions of the matmul + pbar: a progress bar + shared_stream: stream to execute benchmarks on + tilings: list of tilings to benchmark + repetitions_ms: how many milliseconds to spend running each configuration + debug: whether to print debug output + + Returns: + A sequence of matmul timings. + """ + out: list[MatmulTiming] = [] + largest_splitk = max(tilings, key=lambda t: t.SPLIT_K).SPLIT_K + + s = torch.cuda.Stream() + + # Use our own stream for compilation. 
+ with torch.cuda.stream(s): + if dims.quantized_lhs == QuantizedInputType.INT8: + a = torch.randint( + 0, 128, (dims.M, dims.K), device='cuda', dtype=torch.int8 + ) + elif dims.quantized_lhs == QuantizedInputType.FLOAT8: + a = to_triton_f8( + torch.randn(dims.M, dims.K, device='cuda', dtype=torch.bfloat16) + ) + else: + a = torch.randn(dims.M, dims.K, device='cuda', dtype=torch.bfloat16) + + b = torch.randn(dims.K, dims.N, device='cuda', dtype=torch.bfloat16) + data_a = getattr(a, 'base', a) + assert data_a.shape[1] == b.shape[0], 'incompatible dimensions' + assert data_a.is_contiguous(), 'matrix A must be contiguous' + assert b.is_contiguous(), 'matrix B must be contiguous' + c = torch.empty((dims.M, dims.N), device=a.device, dtype=torch.bfloat16) + scratchpad = torch.empty( + (largest_splitk, dims.M, dims.N), device=a.device, dtype=torch.bfloat16 + ) + + LOG.info('Autotuning for %s', dims) + + for tiling in tilings: + pbar.update(1) + + timing = benchmark_matmul_tiling( + dims, + tiling, + s, + shared_stream, + a, + b, + c, + scratchpad, + repetitions_ms=repetitions_ms, + debug=debug, + ) + if not timing: + continue + + out.append(timing) + return out + + +def print_roofline_performance(dims: MatmulSize, time_ms: float): + """Print theoretical roofline model performance.""" + gbps: float = triton.testing.get_dram_gbps() + tflops: float = triton.testing.get_max_tensorcore_tflops(torch.bfloat16) + lhs_size_bytes = dims.M * dims.K + rhs_size_bytes = dims.K * dims.N * 2 + out_size_bytes = dims.M * dims.N * 2 + + size_gb = (lhs_size_bytes + rhs_size_bytes + out_size_bytes) / 1e9 + roofline_time_ms_bw = (size_gb / gbps) * 1e3 + roofline_time_ms_flops = 2 * (dims.M * dims.N * dims.K) / (tflops * 1e9) + + best_time_ms = max(roofline_time_ms_bw, roofline_time_ms_flops) + bound = ( + 'bandwidth' if roofline_time_ms_bw > roofline_time_ms_flops else 'flops' + ) + + print( + f'Percentage of roofline: {(best_time_ms * 100 / time_ms):0.4f}%' + f' ({bound} bound)' + ) + + 
print(f'Roofline time if bandwidth bound: {roofline_time_ms_bw:0.4f}ms') + print(f'Roofline time if flops bound: {roofline_time_ms_flops:0.4f}ms') diff --git a/tensorflow/compiler/xla/experiments/triton_autotuning/run_single_matmul.py b/tensorflow/compiler/xla/experiments/triton_autotuning/run_single_matmul.py new file mode 100755 index 00000000000..382ff29b19f --- /dev/null +++ b/tensorflow/compiler/xla/experiments/triton_autotuning/run_single_matmul.py @@ -0,0 +1,91 @@ +#!/usr/bin/python3 +# Copyright 2023 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +"""Runs a single matmul with a supplied configuration.""" +import sys + +from absl import app +from absl import flags +from matmul_lib import benchmark_cublas +from matmul_lib import benchmark_matmul +from matmul_lib import MatmulSize +from matmul_lib import MatmulTiling +from matmul_lib import print_roofline_performance +from matmul_lib import QuantizedInputType +import torch +import tqdm + + +_M = flags.DEFINE_integer('m', 64, 'Size of first matrix') +_K = flags.DEFINE_integer('k', 64, 'Size of contracting dimension') +_N = flags.DEFINE_integer('n', 64, 'Size of second matrix') +_QUANTIZED_LHS = flags.DEFINE_enum_class( + 'quantized_lhs', + QuantizedInputType.FULL, + QuantizedInputType, + 'Type to use for LHS quantization', +) + +_BLOCK_M = flags.DEFINE_integer('block_m', 16, 'Tiling in M-dimension') +_BLOCK_N = flags.DEFINE_integer('block_n', 16, 'Tiling in N-dimension') +_BLOCK_K = flags.DEFINE_integer('block_k', 16, 'Tiling in K-dimension') + +_SPLIT_K = flags.DEFINE_integer( + 'split_k', 1, 'Number of splits for contracting dimension' +) +_NUM_STAGES = flags.DEFINE_integer( + 'num_stages', 1, 'Number of pipelining stages' +) +_NUM_WARPS = flags.DEFINE_integer( + 'num_warps', 4, 'Number of warps to allocate in a given block' +) +_DEBUG = flags.DEFINE_bool('debug', False, 'Print debug information') + + +def main(): + s = torch.cuda.Stream() + pbar = tqdm.tqdm(ncols=0) + dims = MatmulSize(_M.value, _N.value, _K.value, _QUANTIZED_LHS.value) + timing = benchmark_matmul( + dims=dims, + pbar=pbar, + shared_stream=s, + tilings=[ + MatmulTiling( + _BLOCK_M.value, + _BLOCK_N.value, + _BLOCK_K.value, + _SPLIT_K.value, + _NUM_STAGES.value, + _NUM_WARPS.value, + ) + ], + repetitions_ms=300, + debug=_DEBUG.value, + ) + if len(timing) != 1: + print('Failed to find working configuration') + sys.exit(1) + t = timing[0] + print(f'Timing: {t}') + print_roofline_performance(dims, t.min_time_ms) 
+ cublas_time = benchmark_cublas(dims) + print(f'Reference cuBLAS time (bf16xbf16->bf16): {cublas_time:0.4f}ms') + + +if __name__ == '__main__': + app.parse_flags_with_usage(sys.argv) + main() diff --git a/tensorflow/compiler/xla/experiments/triton_autotuning/search.py b/tensorflow/compiler/xla/experiments/triton_autotuning/search.py new file mode 100755 index 00000000000..4b2fb618c76 --- /dev/null +++ b/tensorflow/compiler/xla/experiments/triton_autotuning/search.py @@ -0,0 +1,223 @@ +#!/usr/bin/python3 +# Copyright 2023 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Launch Triton search for good tiling sizes, save to CSV.""" + +import concurrent.futures +import csv +import itertools +import logging +import os +import random +import sys +import time +import typing + +from absl import app +from absl import flags +from matmul_lib import benchmark_matmul +from matmul_lib import generate_tiling_configs +from matmul_lib import MatmulSize +from matmul_lib import MatmulTiming +from matmul_lib import parse_int_list +from matmul_lib import QuantizedInputType +import numpy as np +import torch +import tqdm +from tqdm.contrib.logging import logging_redirect_tqdm + +LOG = logging.getLogger(__name__) + +_OUTPUT_FILE = flags.DEFINE_string( + 'output_file', + 'out.csv', + """File to generate output into. 
+ +1) Output is streamed: for each point processed, incremental output is written +out. +2) Restarts with checkpointing are supported: the script will not regenerate data +for files already present. +""", +) +_MAX_WORKERS = flags.DEFINE_integer( + 'max_workers', 64, 'Number of threads to use' +) +_REPETITIONS_MS = flags.DEFINE_integer( + 'repetitions_ms', 300, 'Number of requests' +) +_NUM_SAMPLES = flags.DEFINE_integer('num_samples', 1000, 'Number of samples ') +_TILINGS_M = flags.DEFINE_string( + 'tilings_m', '32, 64, 128, 256', 'Tilings to try for M' +) +_TILINGS_N = flags.DEFINE_string( + 'tilings_n', '32, 64, 128, 256', 'Tilings to try for N' +) +_TILINGS_K = flags.DEFINE_string( + 'tilings_k', '32, 64, 128, 256, 512', 'Tilings to try for K' +) +_NUM_STAGES = flags.DEFINE_string( + 'num_stages', '1,2,3', 'Number of stages to try' +) +_NUM_WARPS = flags.DEFINE_string('num_warps', '4,8', 'Number of warps to try') +_SPLIT_KS = flags.DEFINE_string( + 'split_ks', '1,2,3,4,5', 'Number of split_k values to try' +) + +logging.basicConfig( + format=( + '%(asctime)s,%(msecs)03d %(levelname)-8s [%(filename)s:%(lineno)d]' + ' %(threadName)15s: %(message)s' + ), + datefmt='%Y-%m-%d:%H:%M:%S', + level=logging.INFO, +) + +# pylint: disable=g-long-lambda +# pylint: disable=g-complex-comprehension +# pylint: disable=cell-var-from-loop + + +def read_timings() -> typing.Set[MatmulSize]: + """Find timings already existing in the file.""" + out: typing.Set[MatmulSize] = set() + with open(_OUTPUT_FILE.value) as f: + reader = csv.reader(f) + for row in reader: + if row[0].isdigit(): + # M, N, K + quantized_lhs + out.add(MatmulSize(*map(int, row[:4]))) + return out + + +def write_csv_header() -> None: + """Write CSV file header.""" + with open(_OUTPUT_FILE.value, 'w') as f: + fieldnames = [ + 'M', + 'N', + 'K', + 'quantized_lhs', + 'BLOCK_M', + 'BLOCK_N', + 'BLOCK_K', + 'SPLIT_K', + 'num_stages', + 'num_warps', + 'min_time_ms', + ] + writer = csv.writer(f) + 
writer.writerow(fieldnames) + + +def write_timings(timings: typing.Sequence[MatmulTiming]) -> None: + """Write matmul timing data to CSV output.""" + with open(_OUTPUT_FILE.value, 'a') as f: + writer = csv.writer(f) + for d in timings: + writer.writerow([ + d.dims.M, + d.dims.N, + d.dims.K, + d.dims.quantized_lhs, + d.tiling.BLOCK_M, + d.tiling.BLOCK_N, + d.tiling.BLOCK_K, + d.tiling.SPLIT_K, + d.tiling.num_stages, + d.tiling.num_warps, + d.min_time_ms, + ]) + + +def generate_samples() -> typing.List[MatmulSize]: + """Generate a list of matmuls we will be benchmarking.""" + m_axis = np.unique(np.logspace(4, 13, num=200, dtype=np.int64, base=2)) + n_axis = np.unique(np.logspace(4, 13, num=200, dtype=np.int64, base=2)) + k_axis = np.unique(np.logspace(4, 13, num=200, dtype=np.int64, base=2)) + q = [QuantizedInputType.INT8] + out = [MatmulSize(*p) for p in itertools.product(m_axis, n_axis, k_axis, q)] + out = random.choices(out, k=_NUM_SAMPLES.value) + return out + + +def run_search( + existing_samples: typing.Set[MatmulSize], +) -> typing.Sequence[MatmulTiming]: + """Run search on a list of matmul configurations.""" + samples: typing.Sequence[MatmulSize] = [ + s for s in generate_samples() if s not in existing_samples + ] + t0 = time.time() + shared_stream = torch.cuda.Stream() + tilings = generate_tiling_configs( + parse_int_list(_TILINGS_M.value), + parse_int_list(_TILINGS_N.value), + parse_int_list(_TILINGS_K.value), + parse_int_list(_SPLIT_KS.value), + parse_int_list(_NUM_STAGES.value), + parse_int_list(_NUM_WARPS.value), + ) + + with concurrent.futures.ThreadPoolExecutor( + max_workers=_MAX_WORKERS.value + ) as executor: + pbar = tqdm.tqdm(total=len(samples) * len(tilings), ncols=0) + results = [] + with logging_redirect_tqdm(): + if _MAX_WORKERS.value == 1: + for c in samples: + res = benchmark_matmul( + c, pbar, shared_stream, tilings, _REPETITIONS_MS.value + ) + results.extend(res) + write_timings(res) + else: + future_to_dims = { + executor.submit( + 
benchmark_matmul, + c, + pbar, + shared_stream, + tilings, + _REPETITIONS_MS.value, + ): c + for c in samples + } + for future in concurrent.futures.as_completed(future_to_dims): + res = future.result() + results.extend(res) + write_timings(res) + + pbar.close() + + LOG.info('%d datapoints generated in %.2fs', len(results), (time.time() - t0)) + return results + + +def main() -> None: + existing_samples: typing.Set[MatmulSize] = set() + if os.path.isfile(_OUTPUT_FILE.value): + existing_samples = read_timings() + else: + write_csv_header() + + run_search(existing_samples) + + +if __name__ == '__main__': + random.seed(42) + app.parse_flags_with_usage(sys.argv) + main() diff --git a/tensorflow/compiler/xla/experiments/triton_autotuning/tune_single_matmul.py b/tensorflow/compiler/xla/experiments/triton_autotuning/tune_single_matmul.py new file mode 100755 index 00000000000..09971e6baaa --- /dev/null +++ b/tensorflow/compiler/xla/experiments/triton_autotuning/tune_single_matmul.py @@ -0,0 +1,148 @@ +#!/usr/bin/env python +# Copyright 2023 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +"""Finds best tuning for a single matmul.""" +import csv +import sys + +from absl import app +from absl import flags +from matmul_lib import benchmark_cublas +from matmul_lib import benchmark_matmul +from matmul_lib import generate_tiling_configs +from matmul_lib import MatmulSize +from matmul_lib import MatmulTiming +from matmul_lib import parse_int_list +from matmul_lib import print_roofline_performance +from matmul_lib import QuantizedInputType +import torch +import tqdm + +_M = flags.DEFINE_integer('m', 64, 'Size of first matrix') +_K = flags.DEFINE_integer('k', 64, 'Size of contracting dimension') +_N = flags.DEFINE_integer('n', 64, 'Size of second matrix') +_QUANTIZED_LHS = flags.DEFINE_enum_class( + 'quantized_lhs', + QuantizedInputType.FULL, + QuantizedInputType, + 'Type to use for LHS quantization', +) + +_TILINGS_M = flags.DEFINE_string( + 'tilings_m', '32, 64, 128, 256', 'Tilings to try for M' +) +_TILINGS_N = flags.DEFINE_string( + 'tilings_n', '32, 64, 128, 256', 'Tilings to try for N' +) +_TILINGS_K = flags.DEFINE_string( + 'tilings_k', '32, 64, 128, 256, 512', 'Tilings to try for K' +) +_NUM_STAGES = flags.DEFINE_string( + 'num_stages', '1,2,3', 'Number of stages to try' +) +_NUM_WARPS = flags.DEFINE_string('num_warps', '4,8', 'Number of warps to try') +_SPLIT_KS = flags.DEFINE_string( + 'split_ks', '1,2,3,4,5', 'Number of split_k values to try' +) +_DEBUG = flags.DEFINE_bool('debug', False, 'Print debug information') +_APPEND_TO_CSV = flags.DEFINE_string( + 'append_to_csv', + None, + 'If set, appends the best tiling to the CSV file passed', +) + + +def main() -> None: + dims = MatmulSize( + M=_M.value, N=_N.value, K=_K.value, quantized_lhs=_QUANTIZED_LHS.value + ) + s = torch.cuda.Stream() + tilings = generate_tiling_configs( + parse_int_list(_TILINGS_M.value), + parse_int_list(_TILINGS_N.value), + parse_int_list(_TILINGS_K.value), + parse_int_list(_SPLIT_KS.value), 
+ parse_int_list(_NUM_STAGES.value), + parse_int_list(_NUM_WARPS.value), + ) + pbar = tqdm.tqdm(total=len(tilings), ncols=0) + timings = sorted( + benchmark_matmul( + dims, pbar, s, tilings, repetitions_ms=300, debug=_DEBUG.value + ), + key=lambda t: t.min_time_ms, + ) + fastest: MatmulTiming = timings[0] + print(f'Fastest configuration: {fastest}') + + features_list = [ + 'BLOCK_M', + 'BLOCK_N', + 'BLOCK_K', + 'SPLIT_K', + 'num_stages', + 'num_warps', + ] + features = frozenset(features_list) + for f in features: + other_features = features - {f} + + def other_features_equal_to_best(t): + return all( + getattr(fastest.tiling, of) == getattr(t.tiling, of) + for of in other_features # pylint: disable=cell-var-from-loop + ) + + # Keep everyting but the currently evaluated feature fixed to the best + # value. + others_fixed = [t for t in timings if other_features_equal_to_best(t)] + + # TODO(cheshire): Visualize. + print( + f'Varying feature {f}:', + ', '.join( + f'{t.min_time_ms:0.4f} @ {f}={getattr(t.tiling, f)}' + for t in others_fixed + ), + ) + + print_roofline_performance(dims, fastest.min_time_ms) + cublas_time = benchmark_cublas(dims) + print(f'Reference cuBLAS time (bf16xbf16->bf16): {cublas_time:0.4f}ms') + + if _APPEND_TO_CSV.value: + fields = ( + ['M', 'N', 'K', 'quantized_lhs'] + + features_list + + ['min_time_ms', 'cublas_time_ms'] + ) + with open(_APPEND_TO_CSV.value, 'a') as f: + writer = csv.DictWriter(f, fieldnames=fields) + if f.tell() == 0: + writer.writeheader() + writer.writerow( + dict( + fastest.dims._asdict(), + **fastest.tiling._asdict(), + min_time_ms=fastest.min_time_ms, + cublas_time_ms=cublas_time, + ) + ) + + +if __name__ == '__main__': + app.parse_flags_with_usage(sys.argv) + main() diff --git a/tensorflow/compiler/xla/g3doc/operation_semantics.md b/tensorflow/compiler/xla/g3doc/operation_semantics.md index 1f9ffa35a3d..83e21ca58de 100644 --- a/tensorflow/compiler/xla/g3doc/operation_semantics.md +++ 
b/tensorflow/compiler/xla/g3doc/operation_semantics.md @@ -2950,6 +2950,36 @@ relative order of the equal values is preserved. Two elements `e1` and `e2` are equal if and only if `comparator(e1, e2) = comparator(e2, e1) = false`. By default, `is_stable` is set to false. +## Top-K + +See also the `jax.lax.top_k` operation. + +`TopK(operand)` + +Arguments | Type | Semantics +------------ | ---------------- | --------------------------------------------- +`operand` | `XlaOp` | N-dimensional array +`k` | `int64` | Integer specifying the number of top entries. +`comparator` | `XlaComputation` | The comparator computation to use. + +Returns top `k` values and their indices as a tuple, along the last dimension of +the operand using the given `comparator` (for usual topk behavior, it should be +strict-greater-than operation). + +For example, given strict `>` operator, `k=1` and the following operand of shape +`f32[2,3]`: + +``` +[[0.1, 0.3, 0.1], [0.7, 0.2, -0.1]] +``` + +The TopK application returns the following tuple of shape `(f32[2,1], +s32[2,1])`: + +``` +([[0.3], [0.7]], [[1], [0]]) +``` + ## Transpose See also the `tf.reshape` operation. diff --git a/tensorflow/compiler/xla/glob_lit_test.bzl b/tensorflow/compiler/xla/glob_lit_test.bzl index 8863805fff7..217dd18e608 100644 --- a/tensorflow/compiler/xla/glob_lit_test.bzl +++ b/tensorflow/compiler/xla/glob_lit_test.bzl @@ -65,6 +65,7 @@ def _run_lit_test(name, data, size, tags, driver, features, exec_properties): ) def glob_lit_tests( + name = None, exclude = [], test_file_exts = _default_test_file_exts, default_size = _default_size, @@ -79,6 +80,7 @@ def glob_lit_tests( """Creates all plausible Lit tests (and their inputs) under this directory. Args: + name: str, name of the test_suite rule to generate for running all tests. exclude: [str], paths to exclude (for tests and inputs). test_file_exts: [str], extensions for files that are tests. default_size: str, the test size for targets not in "size_override". 
@@ -104,7 +106,10 @@ def glob_lit_tests( # Run tests individually such that errors can be attributed to a specific # failure. + all_tests = [] for curr_test in tests: + all_tests.append(curr_test + ".test") + # Instantiate this test with updated parameters. _run_lit_test( name = curr_test + ".test", @@ -115,3 +120,11 @@ def glob_lit_tests( features = features, exec_properties = exec_properties, ) + + # TODO: remove this check after making it a required param. + if name: + native.test_suite( + name = name, + tests = all_tests, + tags = ["manual"], + ) diff --git a/tensorflow/compiler/xla/hlo/evaluator/BUILD b/tensorflow/compiler/xla/hlo/evaluator/BUILD index d9a7ef2218d..a4253c21179 100644 --- a/tensorflow/compiler/xla/hlo/evaluator/BUILD +++ b/tensorflow/compiler/xla/hlo/evaluator/BUILD @@ -63,6 +63,7 @@ cc_library( "//tensorflow/compiler/xla/service/cpu:runtime_single_threaded_matmul", "//tensorflow/tsl/lib/core:bitmap", "//tensorflow/tsl/platform:errors", + "//tensorflow/tsl/platform:float8", "//tensorflow/tsl/platform:logging", "//tensorflow/tsl/platform:protobuf", "//tensorflow/tsl/platform:status", @@ -104,11 +105,9 @@ xla_cc_test( "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep "//tensorflow/tsl/lib/core:status_test_util", - "//tensorflow/tsl/platform:logging", "//tensorflow/tsl/platform:status", "//tensorflow/tsl/platform:test", "//tensorflow/tsl/platform:test_benchmark", - "@com_google_absl//absl/memory", "@com_google_absl//absl/strings:str_format", ], ) diff --git a/tensorflow/compiler/xla/hlo/evaluator/hlo_evaluator.cc b/tensorflow/compiler/xla/hlo/evaluator/hlo_evaluator.cc index 67311db899f..9a6f0dc8cf9 100644 --- a/tensorflow/compiler/xla/hlo/evaluator/hlo_evaluator.cc +++ b/tensorflow/compiler/xla/hlo/evaluator/hlo_evaluator.cc @@ -15,6 +15,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/hlo/evaluator/hlo_evaluator.h" #include +#include #include #include #include @@ -64,6 +65,7 @@ limitations under the License. #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/tsl/lib/core/bitmap.h" #include "tensorflow/tsl/platform/errors.h" +#include "tensorflow/tsl/platform/float8.h" #include "tensorflow/tsl/platform/logging.h" #include "tensorflow/tsl/platform/protobuf.h" #include "tensorflow/tsl/platform/status.h" @@ -74,9 +76,7 @@ namespace xla { namespace { -template -using NativeTypeOf = - typename primitive_util::PrimitiveTypeToNative::type; +using primitive_util::NativeTypeOf; template StatusOr Compare(const Shape& shape, ComparisonDirection direction, @@ -237,48 +237,17 @@ struct PopulateImpl { // native types to avoid templating the whole implementations. template