diff --git a/.bazelrc b/.bazelrc index 94841167276..391fc927c27 100644 --- a/.bazelrc +++ b/.bazelrc @@ -184,6 +184,10 @@ build:android_x86_64 --config=android build:android_x86_64 --cpu=x86_64 build:android_x86_64 --fat_apk_cpu=x86_64 +# Build everything statically for Android since all static libs are later +# bundled together into a single .so for deployment. +build:android --dynamic_mode=off + # Sets the default Apple platform to macOS. build:macos --apple_platform_type=macos @@ -202,6 +206,8 @@ build:ios_armv7 --config=ios build:ios_armv7 --cpu=ios_armv7 build:ios_arm64 --config=ios build:ios_arm64 --cpu=ios_arm64 +build:ios_arm64e --config=ios +build:ios_arm64e --cpu=ios_arm64e build:ios_sim_arm64 --config=ios build:ios_sim_arm64 --cpu=ios_sim_arm64 build:ios_i386 --config=ios @@ -219,7 +225,9 @@ build:monolithic --define framework_shared_object=false build:monolithic --define tsl_protobuf_header_only=false build:monolithic --experimental_link_static_libraries_once=false # b/229868128 -# Please note that MKL on MacOS or windows is still not supported. +build:linux --define=build_with_onednn_v2=true + +# Please note that MKL on MacOS is still not supported. # If you would like to use a local MKL instead of downloading, please set the # environment variable "TF_MKL_ROOT" every time before build. build:mkl --define=build_with_mkl=true --define=enable_mkl=true @@ -551,8 +559,8 @@ build:rbe_linux_py3_base --python_path="/usr/local/bin/python3.9" build:rbe_linux_py3_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-gcc9_manylinux2014-cuda11.2-cudnn8.1-tensorrt7.2_config_python3.9" build:rbe_win --config=rbe -build:rbe_win --crosstool_top="//tensorflow/tools/toolchains/win/tf_win_02232023:toolchain" -build:rbe_win --extra_toolchains="//tensorflow/tools/toolchains/win/tf_win_02232023:cc-toolchain-x64_windows" +build:rbe_win --crosstool_top="//tensorflow/tools/toolchains/win/tf_win_05022023:toolchain" +build:rbe_win --extra_toolchains="//tensorflow/tools/toolchains/win/tf_win_05022023:cc-toolchain-x64_windows" build:rbe_win --extra_execution_platforms="//tensorflow/tools/toolchains/win:rbe_windows_ltsc2019" build:rbe_win --host_platform="//tensorflow/tools/toolchains/win:rbe_windows_ltsc2019" build:rbe_win --platforms="//tensorflow/tools/toolchains/win:rbe_windows_ltsc2019" @@ -683,10 +691,10 @@ build:ubsan --linkopt -fsanitize=undefined build:ubsan --linkopt -lubsan # Disable TFRT integration for now unless --config=tfrt is specified. 
-build --deleted_packages=tensorflow/compiler/mlir/tfrt,tensorflow/compiler/mlir/tfrt/benchmarks,tensorflow/compiler/mlir/tfrt/jit/python_binding,tensorflow/compiler/mlir/tfrt/jit/transforms,tensorflow/compiler/mlir/tfrt/python_tests,tensorflow/compiler/mlir/tfrt/tests,tensorflow/compiler/mlir/tfrt/tests/ir,tensorflow/compiler/mlir/tfrt/tests/analysis,tensorflow/compiler/mlir/tfrt/tests/jit,tensorflow/compiler/mlir/tfrt/tests/lhlo_to_tfrt,tensorflow/compiler/mlir/tfrt/tests/lhlo_to_jitrt,tensorflow/compiler/mlir/tfrt/tests/tf_to_corert,tensorflow/compiler/mlir/tfrt/tests/tf_to_tfrt_data,tensorflow/compiler/mlir/tfrt/tests/saved_model,tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu,tensorflow/core/runtime_fallback,tensorflow/core/runtime_fallback/conversion,tensorflow/core/runtime_fallback/kernel,tensorflow/core/runtime_fallback/opdefs,tensorflow/core/runtime_fallback/runtime,tensorflow/core/runtime_fallback/util,tensorflow/core/tfrt/eager,tensorflow/core/tfrt/eager/backends/cpu,tensorflow/core/tfrt/eager/backends/gpu,tensorflow/core/tfrt/eager/core_runtime,tensorflow/core/tfrt/eager/cpp_tests/core_runtime,tensorflow/core/tfrt/gpu,tensorflow/core/tfrt/run_handler_thread_pool,tensorflow/core/tfrt/runtime,tensorflow/core/tfrt/saved_model,tensorflow/core/tfrt/graph_executor,tensorflow/core/tfrt/saved_model/tests,tensorflow/core/tfrt/tpu,tensorflow/core/tfrt/utils +build --deleted_packages=tensorflow/compiler/mlir/tfrt,tensorflow/compiler/mlir/tfrt/benchmarks,tensorflow/compiler/mlir/tfrt/jit/python_binding,tensorflow/compiler/mlir/tfrt/jit/transforms,tensorflow/compiler/mlir/tfrt/python_tests,tensorflow/compiler/mlir/tfrt/tests,tensorflow/compiler/mlir/tfrt/tests/ir,tensorflow/compiler/mlir/tfrt/tests/analysis,tensorflow/compiler/mlir/tfrt/tests/jit,tensorflow/compiler/mlir/tfrt/tests/lhlo_to_tfrt,tensorflow/compiler/mlir/tfrt/tests/lhlo_to_jitrt,tensorflow/compiler/mlir/tfrt/tests/tf_to_corert,tensorflow/compiler/mlir/tfrt/tests/tf_to_tfrt_data,tensorflow/compiler/mlir/tfrt/tests/saved_model,tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu,tensorflow/core/runtime_fallback,tensorflow/core/runtime_fallback/conversion,tensorflow/core/runtime_fallback/kernel,tensorflow/core/runtime_fallback/opdefs,tensorflow/core/runtime_fallback/runtime,tensorflow/core/runtime_fallback/util,tensorflow/core/tfrt/eager,tensorflow/core/tfrt/eager/backends/cpu,tensorflow/core/tfrt/eager/backends/gpu,tensorflow/core/tfrt/eager/core_runtime,tensorflow/core/tfrt/eager/cpp_tests/core_runtime,tensorflow/core/tfrt/gpu,tensorflow/core/tfrt/run_handler_thread_pool,tensorflow/core/tfrt/runtime,tensorflow/core/tfrt/saved_model,tensorflow/core/tfrt/graph_executor,tensorflow/core/tfrt/saved_model/tests,tensorflow/core/tfrt/tpu,tensorflow/core/tfrt/utils,tensorflow/core/tfrt/utils/debug # TODO(b/240450920): We are in the process of migrating JitRt backend to XLA # and while we are doing this we can't keep it buildable/testable in OSS. 
-build:tfrt --deleted_packages=tensorflow/compiler/mlir/tfrt,tensorflow/compiler/mlir/tfrt/benchmarks,tensorflow/compiler/mlir/tfrt/jit/python_binding,tensorflow/compiler/mlir/tfrt/jit/transforms,tensorflow/compiler/mlir/tfrt/python_tests,tensorflow/compiler/mlir/tfrt/tests,tensorflow/compiler/mlir/tfrt/tests/ir,tensorflow/compiler/mlir/tfrt/tests/analysis,tensorflow/compiler/mlir/tfrt/tests/jit,tensorflow/compiler/mlir/tfrt/tests/lhlo_to_tfrt,tensorflow/compiler/mlir/tfrt/tests/lhlo_to_jitrt,tensorflow/compiler/mlir/tfrt/tests/tf_to_corert,tensorflow/compiler/mlir/tfrt/tests/tf_to_tfrt_data,tensorflow/compiler/mlir/tfrt/tests/saved_model,tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu,tensorflow/core/runtime_fallback,tensorflow/core/runtime_fallback/conversion,tensorflow/core/runtime_fallback/kernel,tensorflow/core/runtime_fallback/opdefs,tensorflow/core/runtime_fallback/runtime,tensorflow/core/runtime_fallback/util,tensorflow/core/tfrt/eager,tensorflow/core/tfrt/eager/backends/cpu,tensorflow/core/tfrt/eager/backends/gpu,tensorflow/core/tfrt/eager/core_runtime,tensorflow/core/tfrt/eager/cpp_tests/core_runtime,tensorflow/core/tfrt/gpu,tensorflow/core/tfrt/run_handler_thread_pool,tensorflow/core/tfrt/runtime,tensorflow/core/tfrt/saved_model,tensorflow/core/tfrt/graph_executor,tensorflow/core/tfrt/saved_model/tests,tensorflow/core/tfrt/tpu,tensorflow/core/tfrt/utils +build:tfrt --deleted_packages=tensorflow/compiler/mlir/tfrt,tensorflow/compiler/mlir/tfrt/benchmarks,tensorflow/compiler/mlir/tfrt/jit/python_binding,tensorflow/compiler/mlir/tfrt/jit/transforms,tensorflow/compiler/mlir/tfrt/python_tests,tensorflow/compiler/mlir/tfrt/tests,tensorflow/compiler/mlir/tfrt/tests/ir,tensorflow/compiler/mlir/tfrt/tests/analysis,tensorflow/compiler/mlir/tfrt/tests/jit,tensorflow/compiler/mlir/tfrt/tests/lhlo_to_tfrt,tensorflow/compiler/mlir/tfrt/tests/lhlo_to_jitrt,tensorflow/compiler/mlir/tfrt/tests/tf_to_corert,tensorflow/compiler/mlir/tfrt/tests/tf_to_tfrt_data,tensorflow/compiler/mlir/tfrt/tests/saved_model,tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu,tensorflow/core/runtime_fallback,tensorflow/core/runtime_fallback/conversion,tensorflow/core/runtime_fallback/kernel,tensorflow/core/runtime_fallback/opdefs,tensorflow/core/runtime_fallback/runtime,tensorflow/core/runtime_fallback/util,tensorflow/core/tfrt/eager,tensorflow/core/tfrt/eager/backends/cpu,tensorflow/core/tfrt/eager/backends/gpu,tensorflow/core/tfrt/eager/core_runtime,tensorflow/core/tfrt/eager/cpp_tests/core_runtime,tensorflow/core/tfrt/gpu,tensorflow/core/tfrt/run_handler_thread_pool,tensorflow/core/tfrt/runtime,tensorflow/core/tfrt/saved_model,tensorflow/core/tfrt/graph_executor,tensorflow/core/tfrt/saved_model/tests,tensorflow/core/tfrt/tpu,tensorflow/core/tfrt/utils,tensorflow/core/tfrt/utils/debug # TF Fuzztest config try-import fuzztest.bazelrc diff --git a/.github/ISSUE_TEMPLATE/tensorflow_issue_template.yaml b/.github/ISSUE_TEMPLATE/tensorflow_issue_template.yaml index 91c37cfe117..ac5643d9276 100644 --- a/.github/ISSUE_TEMPLATE/tensorflow_issue_template.yaml +++ b/.github/ISSUE_TEMPLATE/tensorflow_issue_template.yaml @@ -131,7 +131,6 @@ body: description: Also tell us, what did you expect to happen? placeholder: Tell us what you see! value: "A bug happened!" 
- render: shell validations: required: true - type: textarea diff --git a/.github/bot_config.yml b/.github/bot_config.yml index 45b4c58fc90..b5cf2a5a6c2 100644 --- a/.github/bot_config.yml +++ b/.github/bot_config.yml @@ -16,9 +16,8 @@ # A list of assignees assignees: - synandi - - tiruk007 + - SuryanarayanaY - tilakrayal - - pjpratik # A list of assignees for compiler folder compiler_assignees: - joker-eph diff --git a/.github/workflows/arm-ci-extended-cpp.yml b/.github/workflows/arm-ci-extended-cpp.yml new file mode 100644 index 00000000000..cfa3a214918 --- /dev/null +++ b/.github/workflows/arm-ci-extended-cpp.yml @@ -0,0 +1,61 @@ +# Copyright 2022 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +name: ARM CI Extended C++ + +on: + push: + tags: + - v2.** + schedule: + - cron: '0 2 * * *' + +jobs: + build: + if: github.repository == 'tensorflow/tensorflow' # Don't do this in forks + runs-on: [self-hosted, linux, ARM64] + strategy: + matrix: + pyver: ['3.10'] + steps: + - name: Stop old running containers (if any) + shell: bash + run: | + running_containers=$(docker ps -q) && \ + if [[ $running_containers == "" ]]; then + echo "No running containers"; + else + echo "Running container(s) found" && \ + docker stop $running_containers; + fi + docker container prune -f + docker image prune -af + - name: Clean repository + shell: bash + run: find /home/ubuntu/actions-runner/_work/tensorflow/tensorflow/. -name . -o -prune -exec sudo rm -rf -- {} + || true + - name: Checkout repository for nightly (skipped for releases) + if: ${{ github.event_name == 'schedule' }} + uses: actions/checkout@755da8c3cf115ac066823e79a1e1788f8940201b # v3.2.0 + with: + ref: 'nightly' + - name: Checkout repository + if: ${{ github.event_name == 'push' }} + uses: actions/checkout@755da8c3cf115ac066823e79a1e1788f8940201b # v3.2.0 + - name: Build binary and run C++ tests + shell: bash + run: | + is_nightly=0 && tf_project_name='tf_ci_ext_c' && ${{ github.event_name == 'schedule' }} && is_nightly=1 && tf_project_name='tf_nightly_ci_ext_c' + CI_DOCKER_BUILD_EXTRA_PARAMS="--build-arg py_major_minor_version=${{ matrix.pyver }} --build-arg is_nightly=${is_nightly} --build-arg tf_project_name=${tf_project_name}" \ + ./tensorflow/tools/ci_build/ci_build.sh cpu.arm64 bash tensorflow/tools/ci_build/rel/ubuntu/cpu_arm64_cpp.sh diff --git a/.github/workflows/arm-ci-extended.yml b/.github/workflows/arm-ci-extended.yml index 1592f4ed18a..7e32dafabe9 100644 --- a/.github/workflows/arm-ci-extended.yml +++ b/.github/workflows/arm-ci-extended.yml @@ -17,14 +17,10 @@ name: ARM CI Extended on: push: - branches: - - master - - r2.** - pull_request: - types: [opened, synchronize, reopened] - branches: - - master - - r2.** + tags: + - v2.** + schedule: + - cron: '0 4 * * *' jobs: build: @@ -49,10 +45,17 @@ jobs: - name: Clean repository shell: bash run: find /home/ubuntu/actions-runner/_work/tensorflow/tensorflow/. 
-name . -o -prune -exec sudo rm -rf -- {} + || true + - name: Checkout repository for nightly (skipped for releases) + if: ${{ github.event_name == 'schedule' }} + uses: actions/checkout@755da8c3cf115ac066823e79a1e1788f8940201b # v3.2.0 + with: + ref: 'nightly' - name: Checkout repository + if: ${{ github.event_name == 'push' }} uses: actions/checkout@755da8c3cf115ac066823e79a1e1788f8940201b # v3.2.0 - name: Build binary and run non-pip tests shell: bash run: | - CI_DOCKER_BUILD_EXTRA_PARAMS='--build-arg py_major_minor_version=${{ matrix.pyver }}' \ + is_nightly=0 && tf_project_name='tf_ci_ext' && ${{ github.event_name == 'schedule' }} && is_nightly=1 && tf_project_name='tf_nightly_ci_ext' + CI_DOCKER_BUILD_EXTRA_PARAMS="--build-arg py_major_minor_version=${{ matrix.pyver }} --build-arg is_nightly=${is_nightly} --build-arg tf_project_name=${tf_project_name}" \ ./tensorflow/tools/ci_build/ci_build.sh cpu.arm64 bash tensorflow/tools/ci_build/rel/ubuntu/cpu_arm64_nonpip.sh diff --git a/.github/workflows/arm-ci.yml b/.github/workflows/arm-ci.yml index e6ddbb9eec9..b0876ba60d7 100644 --- a/.github/workflows/arm-ci.yml +++ b/.github/workflows/arm-ci.yml @@ -54,7 +54,7 @@ jobs: - name: Build and test pip wheel shell: bash run: | - CI_DOCKER_BUILD_EXTRA_PARAMS='--build-arg py_major_minor_version=${{ matrix.pyver }}' \ + CI_DOCKER_BUILD_EXTRA_PARAMS="--build-arg py_major_minor_version=${{ matrix.pyver }} --build-arg is_nightly=1 --build-arg tf_project_name=tf_nightly_ci" \ ./tensorflow/tools/ci_build/ci_build.sh cpu.arm64 bash tensorflow/tools/ci_build/rel/ubuntu/cpu_arm64_pip.sh - name: Upload pip wheel to GitHub uses: actions/upload-artifact@83fd05a356d7e2593de66fc9913b3002723633cb # v3.1.1 diff --git a/.github/workflows/stale-issues.yml b/.github/workflows/stale-issues.yml index d4fd32171b4..b4579591c91 100644 --- a/.github/workflows/stale-issues.yml +++ b/.github/workflows/stale-issues.yml @@ -28,8 +28,16 @@ jobs: pull-requests: write steps: - name: Awaiting response issues - uses: actions/stale@v5 + uses: actions/stale@v7 with: + #Comma separated list of labels that can be assigned to issues to exclude them from being marked as stale + exempt-issue-labels: 'override-stale' + #Comma separated list of labels that can be assigned to PRs to exclude them from being marked as stale + exempt-pr-labels: "override-stale" + #Limit the No. of API calls in one run default value is 30. + operations-per-run: 1000 + #Prevent to remove stale label when PRs or issues are updated. + remove-stale-when-updated: false days-before-issue-stale: 7 days-before-issue-close: 7 stale-issue-label: "stale" @@ -48,8 +56,16 @@ jobs: close-pr-message: "This PR was closed because it has been inactive for 14 days since being marked as stale. Please reopen if you'd like to work on this further." repo-token: ${{ secrets.GITHUB_TOKEN }} - name: Contribution issues - uses: actions/stale@v5 + uses: actions/stale@v7 with: + #Comma separated list of labels that can be assigned to issues to exclude them from being marked as stale + exempt-issue-labels: 'override-stale' + #Comma separated list of labels that can be assigned to PRs to exclude them from being marked as stale + exempt-pr-labels: "override-stale" + #Limit the No. of API calls in one run default value is 30. + operations-per-run: 1000 + #Prevent to remove stale label when PRs or issues are updated. 
+ remove-stale-when-updated: false days-before-issue-stale: 180 days-before-issue-close: 365 stale-issue-label: "stale" diff --git a/.github/workflows/update-rbe.yml b/.github/workflows/update-rbe.yml index ce31d59868a..d32d7affd64 100644 --- a/.github/workflows/update-rbe.yml +++ b/.github/workflows/update-rbe.yml @@ -80,6 +80,18 @@ jobs: map sigbuild-r2.12-clang-python3.9 2.12-python3.9 map sigbuild-r2.12-clang-python3.10 2.12-python3.10 map sigbuild-r2.12-clang-python3.11 2.12-python3.11 + # TF 2.13 + map sigbuild-r2.13 2.13-python3.9 + map sigbuild-r2.13-python3.8 2.13-python3.8 + map sigbuild-r2.13-python3.9 2.13-python3.9 + map sigbuild-r2.13-python3.10 2.13-python3.10 + map sigbuild-r2.13-python3.11 2.13-python3.11 + # TF 2.13 + Clang (containers are the same, but env vars in configs.bzl are different) + map sigbuild-r2.13-clang 2.13-python3.9 + map sigbuild-r2.13-clang-python3.8 2.13-python3.8 + map sigbuild-r2.13-clang-python3.9 2.13-python3.9 + map sigbuild-r2.13-clang-python3.10 2.13-python3.10 + map sigbuild-r2.13-clang-python3.11 2.13-python3.11 - name: Create Pull Request with changes uses: peter-evans/create-pull-request@2b011faafdcbc9ceb11414d64d0573f37c774b04 # v4.2.3 with: diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index ccc170b5c6e..beea15f9bf0 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -19,39 +19,58 @@ Before sending your pull requests, make sure you do the following: ### Typical Pull Request Workflow - -**1. New PR** - As a contributor, you submit a New PR on GitHub. - We inspect -every incoming PR and add certain labels to the PR such as `size:`, `comp:` etc. -At this stage we check if the PR is valid and meets certain quality -requirements. - For example - We check if the CLA is signed, PR has sufficient -description, if applicable unit tests are added, if it is a reasonable -contribution meaning it is not a single liner cosmetic PR. +**1. New PR** -**2. Valid?** - If the PR passes all the quality checks then we go ahead and -assign a reviewer. - If the PR didn't meet the validation criteria, we request -for additional changes to be made to PR to pass quality checks and send it back -or on a rare occassion we may reject it. +- As a contributor, you submit a New PR on GitHub. +- We inspect every incoming PR and add certain labels to the PR such as `size:`, + `comp:` etc. At this stage we check if the PR is valid and meets certain + quality requirements. For example, we check if the CLA is signed, PR has + sufficient description, if applicable unit tests are added, if it is a + reasonable contribution (meaning it is not a single liner cosmetic PR). -**3. Review** - For Valid PR, reviewer (person familiar with the -code/functionality) checks if the PR looks good or needs additional changes. - -If all looks good, reviewer would approve the PR. - If a change is needed, the -contributor is requested to make suggested change. - You make the change and -submit for the review again. - This cycle repeats itself till the PR gets -approved. - Note: As a friendly reminder we may reach out to you if the PR is -awaiting your response for more than 2 weeks. +**2. Valid?** -**4. Approved** - Once the PR is approved, it gets `kokoro:force-run` label -applied and it initiates CI/CD tests. - We can't move forward if these tests -fail. - In such situations, we may request you to make further changes to your -PR for the tests to pass. - Once the tests pass, we now bring all the code in -the internal code base, using a job called "copybara". 
+- If the PR passes all the quality checks then we go ahead and assign a + reviewer. +- If the PR didn't meet the validation criteria, we request additional + changes to be made to the PR to pass quality checks and send it back or on a rare + occasion we may reject it. -**5. Copy to G3** - Once the PR is in Google codebase, we make sure it -integrates well with its dependencies and the rest of the system. - Rarely, but -If the tests fail at this stage, we cannot merge the code. - If needed, we may -come to you to make some changes. - At times, it may not be you, it may be us -who may have hit a snag. - Please be patient while we work to fix this. - Once -the internal tests pass, we go ahead and merge the code internally as well as -externally on GitHub. +**3. Review** + +- For a valid PR, a reviewer (person familiar with the code/functionality) checks if + the PR looks good or needs additional changes. +- If all looks good, the reviewer will approve the PR. +- If a change is needed, the contributor is requested to make the suggested change. +- You make the change and submit it for review again. +- This cycle repeats itself until the PR gets approved. +- Note: As a friendly reminder, we may reach out to you if the PR is awaiting + your response for more than 2 weeks. + +**4. Approved** + +- Once the PR is approved, it gets the `kokoro:force-run` label applied and it + initiates CI/CD tests. +- We can't move forward if these tests fail. +- In such situations, we may request you to make further changes to your PR for + the tests to pass. +- Once the tests pass, we now bring all the code into the internal code base, + using a job called "copybara". + +**5. Copy to Google Internal codebase and run internal CI** + +- Once the PR is in the Google codebase, we make sure it integrates well with its + dependencies and the rest of the system. +- Rarely, the tests may fail at this stage; if they do, we cannot merge the code. +- If needed, we may come to you to make some changes. At times, it may not be + you, it may be us who have hit a snag. Please be patient while we work to + fix this. +- Once the internal tests pass, we go ahead and merge the code internally as + well as externally on GitHub. + +In graphical form, the entire lifetime of a PR looks like this: + +![image](https://user-images.githubusercontent.com/323199/229561784-0a2f5509-b731-493f-ad88-bad487688c8d.png) ### Contributor License Agreements diff --git a/README.md b/README.md index 2e1f9c72183..fa7a6c45733 100644 --- a/README.md +++ b/README.md @@ -92,8 +92,8 @@ uphold this code.** **We use [GitHub issues](https://github.com/tensorflow/tensorflow/issues) for tracking requests and bugs, please see -[TensorFlow Discuss](https://groups.google.com/a/tensorflow.org/forum/#!forum/discuss) -for general questions and discussion, and please direct specific questions to +[TensorFlow Forum](https://discuss.tensorflow.org/) for general questions and +discussion, and please direct specific questions to [Stack Overflow](https://stackoverflow.com/questions/tagged/tensorflow).** The TensorFlow project strives to abide by generally accepted best practices in diff --git a/RELEASE.md b/RELEASE.md index 15bfd428d6e..87ebf46e557 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -1,3 +1,45 @@ +# Release 2.14.0 + + + +# Breaking Changes + +* +* + +* `tf.Tensor` + * The class hierarchy for `tf.Tensor` has changed, and there are now + explicit `EagerTensor` and `SymbolicTensor` classes for eager and + tf.function respectively. Users who relied on the exact type of Tensor + (e.g.
`type(t) == tf.Tensor`) will need to update their code to use + `isinstance(t, tf.Tensor)`. The `tf.is_symbolic_tensor` helper added in + 2.13 may be used when it is necessary to determine if a value is + specifically a symbolic tensor. + +# Known Caveats + +* +* +* + +# Major Features and Improvements + +* +* + +# Bug Fixes and Other Changes +* `tf.lite` + * Strided_Slice now supports `UINT32`. +* +* +* + +# Thanks to our Contributors + +This release contains contributions from many people at Google, as well as: + +, , , , , + # Release 2.13.0 ## Breaking Changes @@ -18,6 +60,26 @@ modifying H5 files saved by Keras under a `.keras` extension. If this breaks you, simply add `save_format="h5"` to your `.save()` call to revert back to the prior behavior. + * Added `keras.utils.TimedThread` utility to run a timed thread every x + seconds. It can be used to run a threaded function alongside model + training or any other snippet of code. + * In the `keras` PyPI package, accessible symbols are now restricted to + symbols that are intended to be public. + This may affect your code if you were using `import keras` and you used + `keras` functions that were not public APIs, but were accessible in + earlier versions with direct imports. In those cases, please use the + following guideline: + - The API may be available in the public Keras API under a different + name, so make sure to look for it on keras.io or TensorFlow docs + and switch to the public version. + - It could also be a simple python or TF utility that you could easily + copy over to your own codebase. In those case, just make it your own! + - If you believe it should definitely be a public Keras API, + please open a feature request in keras GitHub repo. + - As a workaround, you could import the same private symbol keras + `keras.src`, but keep in mind the `src` namespace is not stable and + those APIs may change or be removed in the future. + * The LMDB kernels have been changed to return an error. This is in preparation for completely removing them from TensorFlow. The LMDB dependency that these @@ -40,11 +102,19 @@ clustering. * Add int16x8 support for the built-in op `exp` * Add int16x8 support for the built-in op `mirror_pad` + * Add int16x8 support for the built-in ops `space_to_batch_nd` and + `batch_to_space_nd` * Add 16-bit int type support for built-in op `less`, `greater_than`, `equal` * Add 8-bit and 16-bit support for `floor_div` and `floor_mod`. + * Add 16-bit and 32-bit int support for the built-in op `bitcast`. + * Add 8-bit/16-bit/32-bit int/uint support for the built-in op `bitwise_xor` * Add int16 indices support for built-in op `gather` and `gather_nd`. + * Add 8-bit/16-bit/32-bit int/uint support for the built-in op `right_shift` * Add reference implementation for 16-bit int unquantized `add`. + * Add reference implementation for 16-bit int and 32-bit unsigned int unquantized `mul`. + * `add_op` supports broadcasting up to 6 dimensions. + * Add 16-bit support for `top_k`. * `tf.keras` @@ -57,6 +127,8 @@ libraries (like sklearn or pycocotools) into Keras as first-class Keras metrics. * Added `tf.keras.optimizers.Lion` optimizer. + * Added `tf.keras.layers.SpectralNormalization` layer wrapper to perform + spectral normalization on the weights of a target layer. * The `SidecarEvaluatorModelExport` callback has been added to Keras as `keras.callbacks.SidecarEvaluatorModelExport`. 
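For readers updating type checks for the 2.14 `tf.Tensor` class-hierarchy change described above, a minimal sketch of the recommended pattern follows; the `describe` helper and the sample shapes are illustrative, and only `isinstance` checks and `tf.is_symbolic_tensor` come from the notes.

```python
import tensorflow as tf

def describe(t):
    # Prefer isinstance over exact type comparison: both eager and symbolic
    # tensors remain subclasses of tf.Tensor after the hierarchy change.
    if not isinstance(t, tf.Tensor):
        return "not a tensor"
    # tf.is_symbolic_tensor (added in 2.13 per the notes) distinguishes
    # graph-time (tf.function) values from eager ones.
    return "symbolic" if tf.is_symbolic_tensor(t) else "eager"

print(describe(tf.constant([1.0, 2.0])))  # eager

@tf.function
def add_one(x):
    print(describe(x))  # printed while tracing: symbolic
    return x + 1.0

add_one(tf.constant([1.0, 2.0]))
```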
This callback allows for exporting the model the best-scoring model as evaluated by a @@ -76,6 +148,16 @@ `tf.keras.__internal__.RaggedKerasTensor` classes. You can use these classes to do instance type checking and type annotations for layer/model inputs and outputs. + * All the `tf.keras.dtensor.experimental.optimizers` classes have been + merged with `tf.keras.optimizers`. You can migrate your code to use + `tf.keras.optimizers` directly. The API namespace for + `tf.keras.dtensor.experimental.optimizers` will be removed in future + releases. + * Added support for `class_weight` for 3+ dimensional targets (e.g. + image segmentation masks) in `Model.fit`. + * Added a new loss, `keras.losses.CategoricalFocalCrossentropy`. + * Remove the `tf.keras.dtensor.experimental.layout_map_scope()`. You can + user the `tf.keras.dtensor.experimental.LayoutMap.scope()` instead. * `tf.function`: @@ -94,6 +176,22 @@ `tf.nn.safe_embedding_lookup_sparse`, which enables a simplified and typically faster lookup procedure. +* `tf.data` + + * `tf.data.Dataset.zip` now supports Python-style zipping, i.e. + `Dataset.zip(a, b, c)`. + * `tf.data.Dataset.shuffle` now supports full shuffling. To specify that + data should be fully shuffled, use + `dataset = dataset.shuffle(dataset.cardinality())`. This will load the + full dataset into memory so that it can be shuffled, so make sure to + only use this with datasets of filenames or other small datasets. + +* `tf.math` + + * `tf.nn.top_k` now supports specifying the output index type via parameter + `index_type`. Supported types are `tf.int16`, `tf.int32` + (default), and `tf.int64`. + * `tf.SavedModel` * Introduce class method @@ -109,6 +207,13 @@ * * +* `tf.Variable` + + * Changed resource variables to inherit from `tf.compat.v2.Variable` + instead of `tf.compat.v1.Variable`. Some checks for + `isinstance(v, tf.compat.v1.Variable)` that previously returned True + may now return False. + * `tf.distribute` * Opened an experimental API, @@ -124,6 +229,20 @@ * List of members of dtensor.Layout and dtensor.Mesh have slightly changed as part of efforts to consolidate the C++ and Python source code with pybind11. Most notably, Layout.serialized_string is removed. + * Minor API changes to represent Single Device Layout for non-distributed + Tensors inside DTensor functions. Runtime support will be added soon. + +* `tf.experimental.ExtensionType`: + + * `tf.experimental.ExtensionType` now supports Python `tuple` as + the type annotation of its fields. + +* `tf.nest`: + * Deprecated API `tf.nest.is_sequence` has now been deleted. + Please use `tf.nest.is_nested` instead. + +* `tf.lite`: + * Add UINT32 support to tfl.pack ## Thanks to our Contributors @@ -134,217 +253,166 @@ This release contains contributions from many people at Google, as well as: # Release 2.12.0 -# Breaking Changes - -* -* +### Breaking Changes * Build, Compilation and Packaging - * Removal of redundant packages: the `tensorflow-gpu` and `tf-nightly-gpu` - packages have been effectively removed and replaced with packages that - direct users to switch to `tensorflow` or `tf-nightly` respectively. - The naming difference was the only difference between the two sets of - packages ever since TensorFlow 2.1, so there is no loss of functionality - or GPU support. See - https://pypi.org/project/tensorflow-gpu for more details. + * Removed redundant packages `tensorflow-gpu` and `tf-nightly-gpu`. 
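The `tf.data` and `tf.math` notes above describe Python-style `Dataset.zip`, full shuffling via `cardinality()`, and an `index_type` parameter for `tf.nn.top_k`. A small sketch of how those calls might look on a TF 2.13-level runtime; the dataset contents and variable names are placeholders, not part of the patch.

```python
import tensorflow as tf

a = tf.data.Dataset.range(5)
b = tf.data.Dataset.range(5).map(lambda x: x * 10)
c = tf.data.Dataset.from_tensor_slices(["v", "w", "x", "y", "z"])

# Python-style zipping per the tf.data note; older releases required
# Dataset.zip((a, b, c)).
zipped = tf.data.Dataset.zip(a, b, c)

# Full shuffle: the buffer spans the whole (small, in-memory) dataset.
shuffled = zipped.shuffle(zipped.cardinality())
for x, y, z in shuffled:
    print(int(x), int(y), z.numpy())

# Narrower output index dtype, assuming the parameter is spelled index_type
# as in the tf.math note.
values, indices = tf.nn.top_k(tf.constant([1.0, 3.0, 2.0]), k=2,
                              index_type=tf.int16)
print(indices.dtype)  # int16
```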
These packages were removed and replaced with packages that direct users to switch to `tensorflow` or `tf-nightly` respectively. Since TensorFlow 2.1, the only difference between these two sets of packages was their names, so there is no loss of functionality or GPU support. See https://pypi.org/project/tensorflow-gpu for more details. * `tf.function`: - * tf.function now uses the Python inspect library directly for parsing - the signature of the Python function it is decorated on. - * This can break certain cases that were previously ignored where the - signature is malformed, e.g. - * Using functools.wraps on a function with different signature - * Using functools.partial with an invalid tf.function input - * tf.function now enforces input parameter names to be valid Python - identifiers. Incompatible names are automatically sanitized similarly to - existing SavedModel signature behavior. - * Parameterless tf.functions are assumed to have an empty input_signature - instead of an undefined one even if the input_signature is unspecified. - * tf.types.experimental.TraceType now requires an additional - `placeholder_value` method to be defined. - * tf.function now traces with placeholder values generated by TraceType - instead of the value itself. + * `tf.function` now uses the Python inspect library directly for parsing the signature of the Python function it is decorated on. This change may break code where the function signature is malformed, but was ignored previously, such as: + * Using `functools.wraps` on a function with different signature + * Using `functools.partial` with an invalid `tf.function` input + * `tf.function` now enforces input parameter names to be valid Python identifiers. Incompatible names are automatically sanitized similarly to existing SavedModel signature behavior. + * Parameterless `tf.function`s are assumed to have an empty `input_signature` instead of an undefined one even if the `input_signature` is unspecified. + * `tf.types.experimental.TraceType` now requires an additional `placeholder_value` method to be defined. + * `tf.function` now traces with placeholder values generated by TraceType instead of the value itself. -* `tf.config.experimental.enable_mlir_graph_optimization`: +* Experimental APIs `tf.config.experimental.enable_mlir_graph_optimization` and `tf.config.experimental.disable_mlir_graph_optimization` were removed. - * Experimental API removed. +### Major Features and Improvements -* `tf.config.experimental.disable_mlir_graph_optimization`: - - * Experimental API removed. - -* `tf.keras` - - * Moved all saving-related utilities to a new namespace, `keras.saving`, - i.e. `keras.saving.load_model`, `keras.saving.save_model`, - `keras.saving.custom_object_scope`, `keras.saving.get_custom_objects`, - `keras.saving.register_keras_serializable`, - `keras.saving.get_registered_name` and - `keras.saving.get_registered_object`. - The previous API locations (in `keras.utils` and `keras.models`) will - stay available indefinitely, but we recommend that you update your code - to point to the new API locations. - * Improvements and fixes in Keras loss masking: - * Whether you represent a ragged tensor as a `tf.RaggedTensor` or using - [keras masking](https://www.tensorflow.org/guide/keras/masking_and_padding), - the returned loss values should be the identical to each other. - In previous versions Keras may have silently ignored the mask. 
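As a companion to the `tf.function` signature-parsing notes above, a small sketch of a `functools.partial` usage that the inspect-based parsing is expected to handle; the `scale` and `double` names are illustrative and not from the patch.

```python
import functools
import tensorflow as tf

def scale(x, factor):
    return x * factor

# A partial whose remaining signature is a single, well-named tensor argument.
# This is the well-formed case; a partial that re-binds or shadows arguments is
# the kind of input the 2.12 notes say is no longer silently accepted.
double = tf.function(functools.partial(scale, factor=2.0))

print(double(tf.constant([1.0, 2.0])))  # [2. 4.]
```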
- * If you use masked losses with Keras the loss values may be different - in TensorFlow `2.12` compared to previous versions. - * In cases where the mask was previously ignored, you will now get - an error if you pass a mask with an incompatible shape. - -* `tf.SavedModel` - - * Introduce new class `tf.saved_model.experimental.Fingerprint` that - contains the fingerprint of the SavedModel. See the - [SavedModel Fingerprinting RFC](https://github.com/tensorflow/community/pull/415) - for details. - * Introduce API `tf.saved_model.experimental.read_fingerprint(export_dir)` - for reading the fingerprint of a SavedModel. - - -# Known Caveats - -* -* -* - -# Major Features and Improvements +* Support for Python 3.11 has been added. +* Support for Python 3.7 has been removed. We are not releasing any more patches for Python 3.7. * `tf.lite`: * Add 16-bit float type support for built-in op `fill`. * Transpose now supports 6D tensors. - * Float LSTM now supports diagonal recurrent tensors: - https://arxiv.org/abs/1903.08023 - -* `tf.keras`: - - * The new Keras model saving format (`.keras`) is available. You can start - using it via `model.save(f"{fname}.keras", save_format="keras_v3")`. In - the future it will become the default for all files with the `.keras` - extension. This file format targets the Python runtime only and makes - it possible to reload Python objects identical to the saved originals. - The format supports non-numerical state such as vocabulary files and - lookup tables, and it is easy to customize in the case of custom layers - with exotic elements of state (e.g. a FIFOQueue). The format - does not rely on bytecode or pickling, and is safe by default. Note - that as a result, Python `lambdas` are disallowed at loading time. If - you want to use `lambdas`, you can pass `safe_mode=False` to the loading - method (only do this if you trust the source of the model). - * Added a `model.export(filepath)` API to create a lightweight SavedModel - artifact that can be used for inference (e.g. with TF-Serving). - * Added `keras.export.ExportArchive` class for low-level customization of - the process of exporting SavedModel artifacts for inference. - Both ways of exporting models are based on `tf.function` tracing - and produce a TF program composed of TF ops. They are meant primarily - for environments where the TF runtime is available, - but not the Python interpreter, as is typical - for production with TF Serving. - * Added utility `tf.keras.utils.FeatureSpace`, a one-stop shop for - structured data preprocessing and encoding. - * Added `tf.SparseTensor` input support to `tf.keras.layers.Embedding` - layer. The layer now accepts a new boolean argument `sparse`. If - `sparse` is set to True, the layer returns a SparseTensor instead of a - dense Tensor. Defaults to False. - * Added `jit_compile` as a settable property to `tf.keras.Model`. - * Added `synchronized` optional parameter to `layers.BatchNormalization`. - * Added deprecation warning to - `layers.experimental.SyncBatchNormalization` and suggested to use - `layers.BatchNormalization` with `synchronized=True` instead. - * Updated `tf.keras.layers.BatchNormalization` to support masking of the - inputs (`mask` argument) when computing the mean and variance. - * Add `tf.keras.layers.Identity`, a placeholder pass-through layer. - * Add `show_trainable` option to `tf.keras.utils.model_to_dot` to display - layer trainable status in model plots. 
- * Add ability to save a `tf.keras.utils.FeatureSpace` object, via - `feature_space.save("myfeaturespace.keras")`, and reload it via - `feature_space = tf.keras.models.load_model("myfeaturespace.keras")`. - * Added utility `tf.keras.utils.to_ordinal` to convert class vector to - ordinal regression / classification matrix. + * Float LSTM now supports diagonal recurrent tensors: https://arxiv.org/abs/1903.08023 * `tf.experimental.dtensor`: - * Coordination service now works with - `dtensor.initialize_accelerator_system`, and enabled by default. - * Add `tf.experimental.dtensor.is_dtensor` to check if a tensor is a - DTensor instance. + * Coordination service now works with `dtensor.initialize_accelerator_system`, and enabled by default. + * Add `tf.experimental.dtensor.is_dtensor` to check if a tensor is a DTensor instance. * `tf.data`: - * Added support for alternative checkpointing protocol which makes it - possible to checkpoint the state of the input pipeline without having to - store the contents of internal buffers. The new functionality can be - enabled through the `experimental_symbolic_checkpoint` option of - `tf.data.Options()`. - * Added a new `rerandomize_each_iteration` argument for the - `tf.data.Dataset.random()` operation, which controls whether the - sequence of generated random numbers should be re-randomized every epoch - or not (the default behavior). If `seed` is set and - `rerandomize_each_iteration=True`, the `random()` operation will produce - a different (deterministic) sequence of numbers every epoch. - * Added a new `rerandomize_each_iteration` argument for the - `tf.data.Dataset.sample_from_datasets()` operation, which controls - whether the sequence of generated random numbers used for sampling - should be re-randomized every epoch or not. If `seed` is set and - `rerandomize_each_iteration=True`, the `sample_from_datasets()` - operation will use a different (deterministic) sequence of numbers every - epoch. - * Added a new field, `warm_start`, to - `tf.data.experimental.OptimizationOptions`. If it is set to `True`, - tf.data will start background threads of asynchronous - transformations upon iterator creation (as opposed to upon first call - to `GetNext`). To enable this behavior, set `warm_start=True` in - `tf.data.experimental.OptimizationOptions`. It should be noted that this - possibly improves the latency of the initial 'GetNext' call at the - expense of requiring more memory to hold prefetched elements between - the time of iterator construction and usage. + * Added support for alternative checkpointing protocol which makes it possible to checkpoint the state of the input pipeline without having to store the contents of internal buffers. The new functionality can be enabled through the `experimental_symbolic_checkpoint` option of `tf.data.Options()`. + * Added a new `rerandomize_each_iteration` argument for the `tf.data.Dataset.random()` operation, which controls whether the sequence of generated random numbers should be re-randomized every epoch or not (the default behavior). If `seed` is set and `rerandomize_each_iteration=True`, the `random()` operation will produce a different (deterministic) sequence of numbers every epoch. + * Added a new `rerandomize_each_iteration` argument for the `tf.data.Dataset.sample_from_datasets()` operation, which controls whether the sequence of generated random numbers used for sampling should be re-randomized every epoch or not. 
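A short sketch of the new `tf.data` knobs mentioned above, assuming a TF 2.12-level runtime; only `rerandomize_each_iteration` and `experimental_symbolic_checkpoint` are taken from the notes, the rest is illustrative.

```python
import tensorflow as tf

# Deterministic given the seed, but re-randomized on every pass over the data.
ds = tf.data.Dataset.random(seed=42, rerandomize_each_iteration=True).take(3)
for epoch in range(2):
    print([int(x) for x in ds])  # a different sequence each epoch

# Opt in to symbolic checkpointing of input-pipeline state.
options = tf.data.Options()
options.experimental_symbolic_checkpoint = True
ds = ds.with_options(options)
```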
If `seed` is set and `rerandomize_each_iteration=True`, the `sample_from_datasets()` operation will use a different (deterministic) sequence of numbers every epoch. + * `tf.test`: - * Added `tf.test.experimental.sync_devices`, which is useful for - accurately measuring performance in benchmarks. + * Added `tf.test.experimental.sync_devices`, which is useful for accurately measuring performance in benchmarks. * `tf.experimental.dtensor`: * Added experimental support to ReduceScatter fuse on GPU (NCCL). -# Bug Fixes and Other Changes - -* -* -* +### Bug Fixes and Other Changes +* `tf.SavedModel`: + * Introduced new class `tf.saved_model.experimental.Fingerprint` that contains the fingerprint of the SavedModel. See the [SavedModel Fingerprinting RFC](https://github.com/tensorflow/community/pull/415) for details. + * Introduced API `tf.saved_model.experimental.read_fingerprint(export_dir)` for reading the fingerprint of a SavedModel. * `tf.random` - * Added non-experimental aliases for `tf.random.split` and - `tf.random.fold_in`, the experimental endpoints are still available - so no code changes are necessary. + * Added non-experimental aliases for `tf.random.split` and `tf.random.fold_in`, the experimental endpoints are still available so no code changes are necessary. * `tf.experimental.ExtensionType` - * Added function `experimental.extension_type.as_dict()`, which converts an - instance of `tf.experimental.ExtensionType` to a `dict` representation. + * Added function `experimental.extension_type.as_dict()`, which converts an instance of `tf.experimental.ExtensionType` to a `dict` representation. * `stream_executor` - * Top level `stream_executor` directory has been deleted, users should use - equivalent headers and targets under `compiler/xla/stream_executor`. + * Top level `stream_executor` directory has been deleted, users should use equivalent headers and targets under `compiler/xla/stream_executor`. * `tf.nn` - * Added `tf.nn.experimental.general_dropout`, which is similar to - `tf.random.experimental.stateless_dropout` but accepts a custom sampler - function. + * Added `tf.nn.experimental.general_dropout`, which is similar to `tf.random.experimental.stateless_dropout` but accepts a custom sampler function. * `tf.types.experimental.GenericFunction` - * The `experimental_get_compiler_ir` method supports tf.TensorSpec - compilation arguments. + * The `experimental_get_compiler_ir` method supports tf.TensorSpec compilation arguments. * `tf.config.experimental.mlir_bridge_rollout` - * Removed enums `MLIR_BRIDGE_ROLLOUT_SAFE_MODE_ENABLED` and - `MLIR_BRIDGE_ROLLOUT_SAFE_MODE_FALLBACK_ENABLED` which are no longer used by - the tf2xla bridge + * Removed enums `MLIR_BRIDGE_ROLLOUT_SAFE_MODE_ENABLED` and `MLIR_BRIDGE_ROLLOUT_SAFE_MODE_FALLBACK_ENABLED` which are no longer used by the tf2xla bridge + +## Keras + + Keras is a framework built on top of the TensorFlow. See more details on the Keras [website](https://keras.io/). + +### Breaking Changes -# Thanks to our Contributors +`tf.keras`: + +* Moved all saving-related utilities to a new namespace, `keras.saving`, for example: `keras.saving.load_model`, `keras.saving.save_model`, `keras.saving.custom_object_scope`, `keras.saving.get_custom_objects`, `keras.saving.register_keras_serializable`,`keras.saving.get_registered_name` and `keras.saving.get_registered_object`. The previous API locations (in `keras.utils` and `keras.models`) will be available indefinitely, but we recommend you update your code to point to the new API locations. 
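A minimal sketch of the relocated `keras.saving` namespace described in the bullet above; the `Scale` layer and the `demo` package name are illustrative.

```python
import numpy as np
from tensorflow import keras

# Custom-object registration through the relocated keras.saving namespace.
@keras.saving.register_keras_serializable(package="demo")
class Scale(keras.layers.Layer):
    def __init__(self, factor=2.0, **kwargs):
        super().__init__(**kwargs)
        self.factor = factor

    def call(self, inputs):
        return inputs * self.factor

    def get_config(self):
        return {**super().get_config(), "factor": self.factor}

# The registry helpers also live under keras.saving now.
print(keras.saving.get_registered_name(Scale))                    # demo>Scale
print(keras.saving.get_registered_object("demo>Scale") is Scale)  # True

print(Scale(factor=0.5)(np.ones((1, 3))))
```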
+ * Improvements and fixes in Keras loss masking: + * Whether you represent a ragged tensor as a `tf.RaggedTensor` or using [keras masking](https://www.tensorflow.org/guide/keras/masking_and_padding), the returned loss values should be the identical to each other. In previous versions Keras may have silently ignored the mask. + * If you use masked losses with Keras the loss values may be different in TensorFlow `2.12` compared to previous versions. + * In cases where the mask was previously ignored, you will now get an error if you pass a mask with an incompatible shape. + +### Major Features and Improvements + +`tf.keras`: + + * The new Keras model saving format (`.keras`) is available. You can start using it via `model.save(f"{fname}.keras", save_format="keras_v3")`. In the future it will become the default for all files with the `.keras` extension. This file format targets the Python runtime only and makes it possible to reload Python objects identical to the saved originals. The format supports non-numerical state such as vocabulary files and lookup tables, and it is easy to customize in the case of custom layers with exotic elements of state (e.g. a FIFOQueue). The format does not rely on bytecode or pickling, and is safe by default. Note that as a result, Python `lambdas` are disallowed at loading time. If you want to use `lambdas`, you can pass `safe_mode=False` to the loading method (only do this if you trust the source of the model). +* Added a `model.export(filepath)` API to create a lightweight SavedModel artifact that can be used for inference (e.g. with TF-Serving). +* Added `keras.export.ExportArchive` class for low-level customization of the process of exporting SavedModel artifacts for inference. Both ways of exporting models are based on `tf.function` tracing and produce a TF program composed of TF ops. They are meant primarily for environments where the TF runtime is available, but not the Python interpreter, as is typical for production with TF Serving. + * Added utility `tf.keras.utils.FeatureSpace`, a one-stop shop for structured data preprocessing and encoding. + * Added `tf.SparseTensor` input support to `tf.keras.layers.Embedding` layer. The layer now accepts a new boolean argument `sparse`. If `sparse` is set to True, the layer returns a SparseTensor instead of a dense Tensor. Defaults to False. + * Added `jit_compile` as a settable property to `tf.keras.Model`. + * Added `synchronized` optional parameter to `layers.BatchNormalization`. + * Added deprecation warning to `layers.experimental.SyncBatchNormalization` and suggested to use `layers.BatchNormalization` with `synchronized=True` instead. + * Updated `tf.keras.layers.BatchNormalization` to support masking of the inputs (`mask` argument) when computing the mean and variance. + * Add `tf.keras.layers.Identity`, a placeholder pass-through layer. + * Add `show_trainable` option to `tf.keras.utils.model_to_dot` to display layer trainable status in model plots. + * Add ability to save a `tf.keras.utils.FeatureSpace` object, via `feature_space.save("myfeaturespace.keras")`, and reload it via `feature_space = tf.keras.models.load_model("myfeaturespace.keras")`. +* Added utility `tf.keras.utils.to_ordinal` to convert class vector to ordinal regression / classification matrix. + +### Bug Fixes and Other Changes + +* N/A + +## Security + +* Moving forward, TensorFlow will no longer update [TFSAs](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/security). 
Please refer instead to our [GitHub security advisories](https://github.com/tensorflow/tensorflow/security/advisories), which are attached to [CVEs](https://cve.mitre.org/cve/). +* Fixes an FPE in TFLite in conv kernel [CVE-2023-27579](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2023-27579) +* Fixes a double free in Fractional(Max/Avg)Pool [CVE-2023-25801](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2023-25801) +* Fixes a null dereference on ParallelConcat with XLA [CVE-2023-25676](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2023-25676) +* Fixes a segfault in Bincount with XLA [CVE-2023-25675](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2023-25675) +* Fixes an NPE in RandomShuffle with XLA enable [CVE-2023-25674](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2023-25674) +* Fixes an FPE in TensorListSplit with XLA [CVE-2023-25673](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2023-25673) +* Fixes segmentation fault in tfg-translate [CVE-2023-25671](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2023-25671) +* Fixes an NPE in QuantizedMatMulWithBiasAndDequantize [CVE-2023-25670](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2023-25670) +* Fixes an FPE in AvgPoolGrad with XLA [CVE-2023-25669](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2023-25669) +* Fixes a heap out-of-buffer read vulnerability in the QuantizeAndDequantize operation [CVE-2023-25668](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2023-25668) +* Fixes a segfault when opening multiframe gif [CVE-2023-25667](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2023-25667) +* Fixes an NPE in SparseSparseMaximum [CVE-2023-25665](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2023-25665) +* Fixes an FPE in AudioSpectrogram [CVE-2023-25666](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2023-25666) +* Fixes a heap-buffer-overflow in AvgPoolGrad [CVE-2023-25664](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2023-25664) +* Fixes a NPE in TensorArrayConcatV2 [CVE-2023-25663](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2023-25663) +* Fixes a Integer overflow in EditDistance [CVE-2023-25662](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2023-25662) +* Fixes a Seg fault in `tf.raw_ops.Print` [CVE-2023-25660](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2023-25660) +* Fixes a OOB read in DynamicStitch [CVE-2023-25659](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2023-25659) +* Fixes a OOB Read in GRUBlockCellGrad [CVE-2023-25658](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2023-25658) + +## Thanks to our Contributors This release contains contributions from many people at Google, as well as: -, , , , , +103yiran, 8bitmp3, Aakar, Aakar Dwivedi, Abinash Satapathy, Aditya Kane, ag.ramesh, Alexander Grund, Andrei Pikas, andreii, Andrew Goodbody, angerson, Anthony_256, Ashay Rane, Ashiq Imran, Awsaf, Balint Cristian, Banikumar Maiti (Intel Aipg), Ben Barsdell, bhack, cfRod, Chao Chen, chenchongsong, Chris Mc, Daniil Kutz, David Rubinstein, dianjiaogit, dixr, Dongfeng Yu, dongfengy, drah, Eric Kunze, Feiyue Chen, Frederic Bastien, Gauri1 Deshpande, guozhong.zhuang, hDn248, HYChou, ingkarat, James Hilliard, Jason Furmanek, Jaya, Jens Glaser, Jerry Ge, Jiao Dian'S Power Plant, Jie Fu, Jinzhe Zeng, Jukyy, Kaixi Hou, Kanvi Khanna, Karel Ha, karllessard, Koan-Sin Tan, Konstantin Beluchenko, Kulin Seth, Kun Lu, Kyle Gerard Felker, Leopold Cambier, Lianmin Zheng, linlifan, liuyuanqiang, Lukas Geiger, Luke Hutton, Mahmoud Abuzaina, Manas Mohanty, Mateo 
Fidabel, Maxiwell S. Garcia, Mayank Raunak, mdfaijul, meatybobby, Meenakshi Venkataraman, Michael Holman, Nathan John Sircombe, Nathan Luehr, nitins17, Om Thakkar, Patrice Vignola, Pavani Majety, per1234, Philipp Hack, pollfly, Prianka Liz Kariat, Rahul Batra, rahulbatra85, ratnam.parikh, Rickard Hallerbäck, Roger Iyengar, Rohit Santhanam, Roman Baranchuk, Sachin Muradi, sanadani, Saoirse Stewart, seanshpark, Shawn Wang, shuw, Srinivasan Narayanamoorthy, Stewart Miles, Sunita Nadampalli, SuryanarayanaY, Takahashi Shuuji, Tatwai Chong, Thibaut Goetghebuer-Planchon, tilakrayal, Tirumalesh, TJ, Tony Sung, Trevor Morris, unda, Vertexwahn, Vinila S, William Muir, Xavier Bonaventura, xiang.zhang, Xiao-Yong Jin, yleeeee, Yong Tang, Yuriy Chernyshov, Zhang, Xiangze, zhaozheng09 + + +# Release 2.11.1 + +**Note**: TensorFlow 2.10 was the last TensorFlow release that supported GPU on native-Windows. Starting with TensorFlow 2.11, you will need to install TensorFlow in WSL2, or install tensorflow-cpu and, optionally, try the TensorFlow-DirectML-Plugin. +* Security vulnerability fixes will no longer be patched to this Tensorflow version. The latest Tensorflow version includes the security vulnerability fixes. You can update to the latest version (recommended) or patch security vulnerabilities yourself [steps](https://github.com/tensorflow/tensorflow#patching-guidelines). You can refer to the [release notes](https://github.com/tensorflow/tensorflow/releases) of the latest Tensorflow version for a list of newly fixed vulnerabilities. If you have any questions, please create a GitHub issue to let us know. + +This release also introduces several vulnerability fixes: + +* Fixes an FPE in TFLite in conv kernel [CVE-2023-27579](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2023-27579) +* Fixes a double free in Fractional(Max/Avg)Pool [CVE-2023-25801](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2023-25801) +* Fixes a null dereference on ParallelConcat with XLA [CVE-2023-25676](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2023-25676) +* Fixes a segfault in Bincount with XLA [CVE-2023-25675](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2023-25675) +* Fixes an NPE in RandomShuffle with XLA enable [CVE-2023-25674](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2023-25674) +* Fixes an FPE in TensorListSplit with XLA [CVE-2023-25673](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2023-25673) +* Fixes segmentation fault in tfg-translate [CVE-2023-25671](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2023-25671) +* Fixes an NPE in QuantizedMatMulWithBiasAndDequantize [CVE-2023-25670](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2023-25670) +* Fixes an FPE in AvgPoolGrad with XLA [CVE-2023-25669](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2023-25669) +* Fixes a heap out-of-buffer read vulnerability in the QuantizeAndDequantize operation [CVE-2023-25668](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2023-25668) +* Fixes a segfault when opening multiframe gif [CVE-2023-25667](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2023-25667) +* Fixes an NPE in SparseSparseMaximum [CVE-2023-25665](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2023-25665) +* Fixes an FPE in AudioSpectrogram [CVE-2023-25666](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2023-25666) +* Fixes a heap-buffer-overflow in AvgPoolGrad [CVE-2023-25664](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2023-25664) +* Fixes a NPE in TensorArrayConcatV2 
[CVE-2023-25663](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2023-25663) +* Fixes a Integer overflow in EditDistance [CVE-2023-25662](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2023-25662) +* Fixes a Seg fault in `tf.raw_ops.Print` [CVE-2023-25660](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2023-25660) +* Fixes a OOB read in DynamicStitch [CVE-2023-25659](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2023-25659) +* Fixes a OOB Read in GRUBlockCellGrad [CVE-2023-25658](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2023-25658) + # Release 2.11.0 diff --git a/SECURITY.md b/SECURITY.md index 0964f7debb1..87a16f17538 100644 --- a/SECURITY.md +++ b/SECURITY.md @@ -279,9 +279,9 @@ For each vulnerability, we try to ingress it as soon as possible, given the size of the team and the number of reports. Vulnerabilities will, in general, be batched to be fixed at the same time as a quarterly release. -Past security advisories are listed +Security advisories from 2018 to March 2023 are listed [here](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/security/README.md). -In the future, we might sunset this list and only use GitHub's Security Advisory -format, to simplify the post-vulnerability-fix process. We credit reporters for -identifying security issues, although we keep your name confidential if you -request it. +From TF 2.13 onwards, we have sunset this list and only use GitHub's Security +Advisory format, to simplify the post-vulnerability-fix process. In both +locations, we credit reporters for identifying security issues, although we keep +your name confidential if you request it. diff --git a/tensorflow/BUILD b/tensorflow/BUILD index ec1887945c4..fce465ff1f2 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -32,6 +32,10 @@ load( "//third_party/mkl:build_defs.bzl", "if_mkl_ml", ) +load( + "//third_party/mkl_dnn:build_defs.bzl", + "if_onednn_v3", +) load("@bazel_skylib//:bzl_library.bzl", "bzl_library") load( "//tensorflow:tensorflow.default.bzl", @@ -124,7 +128,7 @@ PACKAGE_STATIC_DEPS = [ "@flatbuffers//:__subpackages__", "@nccl_archive//:__subpackages__", "@triton//:__subpackages__", -] + tsl_async_value_deps() +] + tsl_async_value_deps() + if_onednn_v3(["@onednn_v3//:__subpackages__"]) package( # copybara:uncomment default_applicable_licenses = [":license"], @@ -1025,8 +1029,10 @@ package_group( "//third_party/cloud_tpu/inference_converter/...", "//third_party/py/cloud_ml_autoflow/...", "//third_party/py/envlogger/...", + "//third_party/py/gldm/...", "//third_party/py/keras/...", "//third_party/yggdrasil_decision_forests/...", + "//waymo/ml/cn/...", ], ) @@ -1144,6 +1150,9 @@ tf_cc_shared_library( ], "//conditions:default": [ "-Wl,--version-script,$(location //tensorflow:tf_framework_version_script.lds)", + # copybara:uncomment_begin(google-only) + # "-Wl,--undefined-version", + # copybara:uncomment_end(google-only) ], }), linkstatic = 1, @@ -1350,6 +1359,7 @@ tf_cc_shared_library( "//tensorflow/core/data/service:server_lib", "//tensorflow/core/debug", "//tensorflow/core/distributed_runtime:server_lib", + "//tensorflow/core/framework:full_type_util", "//tensorflow/core/function/runtime_client:runtime_client_cc", "//tensorflow/core/grappler/clusters:cluster", "//tensorflow/core/grappler/clusters:single_machine", diff --git a/tensorflow/c/BUILD b/tensorflow/c/BUILD index d217e7a1f51..0e70244453f 100644 --- a/tensorflow/c/BUILD +++ b/tensorflow/c/BUILD @@ -22,6 +22,29 @@ package( licenses = ["notice"], ) +filegroup( + name = "safe_ptr_hdr", + 
srcs = ["safe_ptr.h"], + visibility = [ + "//tensorflow:internal", + ], +) + +cc_library( + name = "safe_ptr", + srcs = [ + "safe_ptr.cc", + "//tensorflow/c/eager:headers", + ], + hdrs = ["safe_ptr.h"], + visibility = [ + "//tensorflow:internal", + ], + deps = [ + ":c_api_internal", + ], +) + # ----------------------------------------------------------------------------- # Public targets @@ -62,10 +85,10 @@ filegroup( "*test*", ], ) + [ - "//tensorflow/tsl/c:srcs", - "//tensorflow/tsl/platform:ctstring", "//tensorflow/cc:srcs_no_runtime", "//tensorflow/core/distributed_runtime:server_lib.h", + "//tensorflow/tsl/c:srcs", + "//tensorflow/tsl/platform:ctstring", ], visibility = ["//visibility:public"], ) @@ -94,14 +117,17 @@ cc_library( name = "c_api_headers", hdrs = [ "c_api.h", - "c_api_macros.h", ], visibility = ["//visibility:public"], deps = [ + ":c_api_macros_hdrs", ":tf_attrtype", - ":tf_buffer", - ":tf_datatype", + ":tf_buffer_hdrs", + ":tf_datatype_hdrs", ":tf_status_headers", + ":tf_tensor_hdrs", + # TODO: Only include tf_tstring_hdrs. Don't expose the implementation of TF_TString to API + # users. ":tf_tstring", ], ) @@ -165,6 +191,14 @@ cc_library( visibility = ["//visibility:public"], ) +cc_library( + name = "c_api_macros_hdrs", + hdrs = [ + "c_api_macros.h", + ], + visibility = ["//visibility:public"], +) + cc_library( name = "c_api_macros", hdrs = [ @@ -195,8 +229,9 @@ tf_cuda_library( copts = tf_copts(), visibility = ["//visibility:public"], deps = [ - ":c_api_no_xla", ":c_api_internal", + ":c_api_macros_hdrs", + ":c_api_no_xla", ":tf_attrtype", ":tf_buffer", ":tf_file_statistics", @@ -207,8 +242,8 @@ tf_cuda_library( "//tensorflow/tsl/c:tsl_status", ] + select({ "//tensorflow:with_xla_support": [ - "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/jit", + "//tensorflow/compiler/tf2xla:xla_compiler", ], "//conditions:default": [], }) + if_tensorrt([ @@ -240,9 +275,9 @@ tf_cuda_library( deps = [ ":c_api_internal", ":tf_attrtype", - ":tf_datatype", ":tf_buffer", ":tf_buffer_internal", + ":tf_datatype", ":tf_status_internal", ] + select({ "//tensorflow:android": [ @@ -253,25 +288,25 @@ tf_cuda_library( ":logging", ":tf_status", ":tf_tensor", - "@com_google_absl//absl/strings", "//tensorflow/c/experimental/filesystem:modular_filesystem", - "//tensorflow/cc/saved_model:loader_lite", + "//tensorflow/cc:grad_ops", "//tensorflow/cc:gradients", "//tensorflow/cc:ops", - "//tensorflow/cc:grad_ops", "//tensorflow/cc:scope_internal", "//tensorflow/cc:while_loop", + "//tensorflow/cc/saved_model:loader_lite", + "//tensorflow/compiler/mlir/tfr:graph_decompose_pass", + "//tensorflow/compiler/mlir/tfr:node_expansion_pass", "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", - "//tensorflow/core:op_gen_lib", - "//tensorflow/core:protos_all_cc", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "//tensorflow/core:op_gen_lib", + "//tensorflow/core:protos_all_cc", "//tensorflow/core/distributed_runtime:server_lib", "//tensorflow/core/kernels:logging_ops", - "//tensorflow/compiler/mlir/tfr:node_expansion_pass", - "//tensorflow/compiler/mlir/tfr:graph_decompose_pass", + "@com_google_absl//absl/strings", ], }), alwayslink = 1, @@ -308,9 +343,10 @@ tf_cuda_library( "//tensorflow/core/transforms:__subpackages__", ], deps = [ - "//tensorflow/tsl/platform:status", + ":c_api_macros_hdrs", "//tensorflow/tsl/c:tsl_status", "//tensorflow/tsl/c:tsl_status_internal", + "//tensorflow/tsl/platform:status", ] + select({ 
"//tensorflow:android": [ "//tensorflow/core:portable_tensorflow_lib_lite", # TODO(annarev): exclude runtime srcs @@ -363,6 +399,7 @@ cc_library( copts = tf_copts(), visibility = ["//visibility:public"], deps = [ + ":c_api_macros_hdrs", ":tf_status_internal", "//tensorflow/tsl/c:tsl_status", ] + select({ @@ -380,7 +417,8 @@ cc_library( hdrs = ["tf_status.h"], visibility = ["//visibility:public"], deps = [ - "//tensorflow/tsl/c:tsl_status", + ":c_api_macros_hdrs", + "//tensorflow/tsl/c:tsl_status_headers", ], ) @@ -390,15 +428,15 @@ cc_library( "tf_tstring.cc", ], hdrs = [ - "c_api_macros.h", - "tf_datatype.h", - "tf_status.h", - "tf_tensor.h", "tf_tstring.h", ], copts = tf_copts(), visibility = ["//visibility:public"], deps = [ + ":c_api_macros_hdrs", + ":tf_datatype_hdrs", + ":tf_status_headers", + ":tf_tensor_hdrs", "//tensorflow/core/platform:status", "//tensorflow/core/platform:tstring", "//tensorflow/tsl/c:tsl_status", @@ -426,13 +464,23 @@ cc_library( }), ) +cc_library( + name = "tf_datatype_hdrs", + hdrs = ["tf_datatype.h"], + deps = [ + ":c_api_macros_hdrs", + ], +) + cc_library( name = "tf_datatype", srcs = ["tf_datatype.cc"], hdrs = ["tf_datatype.h"], copts = tf_copts(), visibility = ["//visibility:public"], - deps = select({ + deps = [ + ":c_api_macros_hdrs", + ] + select({ "//tensorflow:android": [ "//tensorflow/core:portable_tensorflow_lib_lite", # TODO(annarev): exclude runtime srcs ], @@ -443,6 +491,17 @@ cc_library( alwayslink = 1, ) +cc_library( + name = "tf_tensor_hdrs", + hdrs = ["tf_tensor.h"], + visibility = ["//visibility:public"], + deps = [ + ":c_api_macros_hdrs", + ":tf_datatype_hdrs", + ":tf_status_headers", + ], +) + cc_library( name = "tf_tensor", srcs = ["tf_tensor.cc"], @@ -493,6 +552,16 @@ tf_cuda_library( }), ) +cc_library( + name = "tf_buffer_hdrs", + hdrs = [ + "tf_buffer.h", + ], + deps = [ + ":c_api_macros_hdrs", + ], +) + cc_library( name = "tf_buffer", srcs = [ @@ -504,6 +573,7 @@ cc_library( copts = tf_copts(), visibility = ["//visibility:public"], deps = [ + ":c_api_macros_hdrs", ":tf_buffer_internal", ":tf_status", ":tf_tensor_internal", @@ -525,6 +595,7 @@ tf_cuda_library( "//tensorflow/c:__subpackages__", ], deps = [ + ":c_api_macros_hdrs", ":tf_status", ":tf_tensor_internal", "//tensorflow/core/platform:protobuf", @@ -545,6 +616,7 @@ tf_cuda_library( deps = [ ":c_api", ":c_api_internal", + ":c_api_macros_hdrs", ":checkpoint_reader", ":tf_buffer", ":tf_buffer_internal", @@ -635,9 +707,9 @@ tf_cuda_library( ], }) + [ ":c_api_macros", + ":tf_file_statistics", ":tf_status", ":tf_status_helper", - ":tf_file_statistics", "//tensorflow/core/platform:env", "//tensorflow/core/platform:path", "//tensorflow/core/platform:types", @@ -652,10 +724,11 @@ cc_library( ], visibility = ["//tensorflow:internal"], deps = [ - ":c_api_internal", - ":tf_datatype", - ":tf_status", - ":tf_tensor", + ":c_api_headers", + ":c_api_macros_hdrs", + ":tf_datatype_hdrs", + ":tf_status_headers", + ":tf_tensor_hdrs", "//tensorflow/c/experimental/stream_executor:stream_executor_hdrs", ], ) @@ -671,6 +744,7 @@ tf_cuda_library( copts = tf_copts(), visibility = ["//visibility:public"], deps = [ + ":c_api_macros_hdrs", ":tf_buffer", ":tf_buffer_internal", ":tf_status", @@ -685,12 +759,14 @@ tf_cuda_library( "//conditions:default": [ ":c_api_internal", ":tf_tensor", - "//tensorflow/compiler/xla/stream_executor:stream_executor", + "//tensorflow/c/experimental/stream_executor", + "//tensorflow/c/experimental/stream_executor:stream_executor_internal", + 
"//tensorflow/compiler/xla/stream_executor", "//tensorflow/core:framework", "//tensorflow/core:framework_lite", "//tensorflow/core:protos_all_cc", - "//tensorflow/c/experimental/stream_executor:stream_executor", - "//tensorflow/c/experimental/stream_executor:stream_executor_internal", + "//tensorflow/tsl/framework:device_id_utils", + "//tensorflow/tsl/platform:statusor", ], }), ) @@ -699,7 +775,10 @@ cc_library( name = "kernels_experimental_hdrs", hdrs = ["kernels_experimental.h"], visibility = ["//tensorflow:internal"], - deps = [":kernels_hdrs"], + deps = [ + ":c_api_macros_hdrs", + ":kernels_hdrs", + ], ) tf_cuda_library( @@ -709,6 +788,7 @@ tf_cuda_library( copts = tf_copts(), visibility = ["//visibility:public"], deps = [ + ":c_api_macros_hdrs", ":kernels", ":tf_status_helper", ":tf_status_internal", @@ -739,6 +819,7 @@ tf_cuda_library( copts = tf_copts(), visibility = ["//visibility:public"], deps = [ + ":c_api_macros_hdrs", ":tf_datatype", ":tf_status", ":tf_status_helper", @@ -758,6 +839,7 @@ cc_library( hdrs = ["ops.h"], visibility = ["//tensorflow:internal"], deps = [ + ":c_api_macros_hdrs", ":tf_datatype", ":tf_status", ], diff --git a/tensorflow/c/c_api.h b/tensorflow/c/c_api.h index fb951559a0e..e4c6499506e 100644 --- a/tensorflow/c/c_api.h +++ b/tensorflow/c/c_api.h @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "tensorflow/c/c_api_macros.h" #include "tensorflow/c/tf_attrtype.h" #include "tensorflow/c/tf_buffer.h" #include "tensorflow/c/tf_datatype.h" @@ -72,25 +73,6 @@ limitations under the License. // and the API just provides high level controls over the number of // devices of each type. -// Macro to control visibility of exported symbols in the shared library (.so, -// .dylib, .dll). -// This duplicates the TF_EXPORT macro definition in -// tensorflow/core/platform/macros.h in order to keep this .h file independent -// of any other includes. -#ifdef SWIG -#define TF_CAPI_EXPORT -#else -#if defined(_WIN32) -#ifdef TF_COMPILE_LIBRARY -#define TF_CAPI_EXPORT __declspec(dllexport) -#else -#define TF_CAPI_EXPORT __declspec(dllimport) -#endif // TF_COMPILE_LIBRARY -#else -#define TF_CAPI_EXPORT __attribute__((visibility("default"))) -#endif // _WIN32 -#endif // SWIG - #ifdef __cplusplus extern "C" { #endif diff --git a/tensorflow/c/c_api_experimental.cc b/tensorflow/c/c_api_experimental.cc index 3a05e1e64db..45697e20d1e 100644 --- a/tensorflow/c/c_api_experimental.cc +++ b/tensorflow/c/c_api_experimental.cc @@ -190,7 +190,7 @@ const char* TF_GraphDebugString(TF_Graph* graph, size_t* len) { } char* TF_FunctionDebugString(TF_Function* func, size_t* len) { - const auto& debug_str = DebugString(func->fdef); + const auto& debug_str = DebugString(func->record->fdef()); *len = debug_str.size(); char* ret = static_cast(malloc(*len + 1)); memcpy(ret, debug_str.c_str(), *len + 1); diff --git a/tensorflow/c/c_api_experimental.h b/tensorflow/c/c_api_experimental.h index aec1e875eaf..abae68cfe48 100644 --- a/tensorflow/c/c_api_experimental.h +++ b/tensorflow/c/c_api_experimental.h @@ -20,6 +20,7 @@ limitations under the License. #include #include "tensorflow/c/c_api.h" +#include "tensorflow/c/c_api_macros.h" #include "tensorflow/c/eager/c_api.h" // -------------------------------------------------------------------------- @@ -28,25 +29,6 @@ limitations under the License. // The API here is subject to changes in the future. 
// -------------------------------------------------------------------------- -// Macro to control visibility of exported symbols in the shared library (.so, -// .dylib, .dll). -// This duplicates the TF_EXPORT macro definition in -// tensorflow/core/platform/macros.h in order to keep this .h file independent -// of any other includes.$a -#ifdef SWIG -#define TF_CAPI_EXPORT -#else -#if defined(_WIN32) -#ifdef TF_COMPILE_LIBRARY -#define TF_CAPI_EXPORT __declspec(dllexport) -#else -#define TF_CAPI_EXPORT __declspec(dllimport) -#endif // TF_COMPILE_LIBRARY -#else -#define TF_CAPI_EXPORT __attribute__((visibility("default"))) -#endif // _WIN32 -#endif // SWIG - #ifdef __cplusplus extern "C" { #endif diff --git a/tensorflow/c/c_api_function.cc b/tensorflow/c/c_api_function.cc index a13a1458553..2fd92bd7dc0 100644 --- a/tensorflow/c/c_api_function.cc +++ b/tensorflow/c/c_api_function.cc @@ -16,11 +16,13 @@ limitations under the License. #include #include #include +#include #include "absl/strings/match.h" #include "tensorflow/c/c_api_internal.h" #include "tensorflow/c/tf_buffer_internal.h" #include "tensorflow/core/framework/attr_value_util.h" +#include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/framework/graph_to_functiondef.h" #include "tensorflow/core/framework/node_def.pb.h" @@ -30,6 +32,7 @@ limitations under the License. #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/platform/base64.h" #include "tensorflow/core/platform/strcat.h" +#include "tensorflow/core/util/debug_data_dumper.h" using tensorflow::errors::InvalidArgument; @@ -203,23 +206,31 @@ TF_Function* TF_GraphToFunctionWithControlOutputs( } // Do the actual function creation. - TF_Function* tf_function = new TF_Function(); DCHECK(append_hash_to_fn_name <= 1); + tensorflow::FunctionDef fdef; status->status = tensorflow::GraphToFunctionDef( fn_body->graph, fn_name, append_hash_to_fn_name != 0, /*set_stateful_from_nodes=*/true, /*copy_placeholder_attrs_from_nodes=*/true, body_nodes, input_tensors, output_tensors, output_names_vec, control_output_nodes, - control_output_names_vec, description, &tf_function->fdef); + control_output_names_vec, description, &fdef); if (TF_GetCode(status) != TF_OK) { - TF_DeleteFunction(tf_function); return nullptr; } + // Dump the op creation stacktraces for debugging purpose. + DEBUG_DATA_DUMPER()->DumpOpCreationStackTraces( + fn_name, kDebugGroupOpStacktrace, "initial", &fn_body->graph); + + tensorflow::StackTracesMap stack_traces; for (const Node* n : fn_body->graph.nodes()) { - tf_function->stack_traces[n->name()] = n->GetStackTrace(); + stack_traces[n->name()] = n->GetStackTrace(); } + TF_Function* tf_function = new TF_Function(); + tf_function->record = new tensorflow::FunctionRecord( + std::move(fdef), std::move(stack_traces), false); + return tf_function; } @@ -238,7 +249,7 @@ TF_Function* TF_GraphToFunction(const TF_Graph* fn_body, const char* fn_name, } const char* TF_FunctionName(TF_Function* func) { - return func->fdef.signature().name().c_str(); + return func->record->fdef().signature().name().c_str(); } void TF_GraphCopyFunction(TF_Graph* g, const TF_Function* func, @@ -249,19 +260,20 @@ void TF_GraphCopyFunction(TF_Graph* g, const TF_Function* func, return; } - // TODO(iga): Add AddFunctionDef() and AddGradientDef() methods to graph - // to avoid the extra copy here. 
- tensorflow::FunctionDefLibrary fdef_lib; - *fdef_lib.add_function() = func->fdef; - if (grad) { - *fdef_lib.add_function() = grad->fdef; - tensorflow::GradientDef* gdef = fdef_lib.add_gradient(); - gdef->set_function_name(func->fdef.signature().name()); - gdef->set_gradient_func(grad->fdef.signature().name()); - } - tensorflow::mutex_lock l(g->mu); - status->status = g->graph.AddFunctionLibrary(fdef_lib); + status->status = g->graph.AddFunctionDef(func->record->fdef(), + func->record->stack_traces()); + if (TF_GetCode(status) != TF_OK) return; + if (!grad) return; + + status->status = g->graph.AddFunctionDef(grad->record->fdef(), + grad->record->stack_traces()); + if (TF_GetCode(status) != TF_OK) return; + + tensorflow::GradientDef gdef; + gdef.set_function_name(func->record->fdef().signature().name()); + gdef.set_gradient_func(grad->record->fdef().signature().name()); + status->status = g->graph.AddGradientDef(std::move(gdef)); } int TF_GraphNumFunctions(TF_Graph* g) { @@ -279,7 +291,7 @@ int TF_GraphGetFunctions(TF_Graph* g, TF_Function** funcs, int max_func, const auto len = std::min(max_func, static_cast(lib.function_size())); for (int i = 0; i < len; ++i) { TF_Function* func = new TF_Function(); - func->fdef = lib.function(i); + func->record = new tensorflow::FunctionRecord(lib.function(i), {}, false); funcs[i] = func; } status->status = ::tensorflow::OkStatus(); @@ -288,18 +300,21 @@ int TF_GraphGetFunctions(TF_Graph* g, TF_Function** funcs, int max_func, void TF_FunctionToFunctionDef(TF_Function* func, TF_Buffer* output_func_def, TF_Status* status) { - status->status = MessageToBuffer(func->fdef, output_func_def); + status->status = MessageToBuffer(func->record->fdef(), output_func_def); } TF_Function* TF_FunctionImportFunctionDef(const void* proto, size_t proto_len, TF_Status* status) { - TF_Function* func = new TF_Function(); - if (!func->fdef.ParseFromArray(proto, proto_len)) { + tensorflow::FunctionDef fdef; + bool success = fdef.ParseFromArray(proto, proto_len); + if (!success) { status->status = InvalidArgument( "Invalid FunctionDef given to TF_FunctionImportFunctionDef"); - TF_DeleteFunction(func); return nullptr; } + + TF_Function* func = new TF_Function(); + func->record = new tensorflow::FunctionRecord(std::move(fdef), {}, false); status->status = ::tensorflow::OkStatus(); return func; } @@ -314,21 +329,37 @@ void TF_FunctionSetAttrValueProto(TF_Function* func, const char* attr_name, "TF_FunctionSetAttrValueProto"); return; } - (*func->fdef.mutable_attr())[string(attr_name)] = attr_value; + + auto fdef_or = func->record->mutable_fdef(); + if (!fdef_or.ok()) { + status->status = fdef_or.status(); + return; + } + + (*(fdef_or.value()->mutable_attr()))[string(attr_name)] = attr_value; + status->status = ::tensorflow::OkStatus(); } void TF_FunctionGetAttrValueProto(TF_Function* func, const char* attr_name, TF_Buffer* output_attr_value, TF_Status* status) { - const auto& it = func->fdef.attr().find(attr_name); - if (it == func->fdef.attr().end()) { + const auto& it = func->record->fdef().attr().find(attr_name); + if (it == func->record->fdef().attr().end()) { status->status = - InvalidArgument("Function '", func->fdef.signature().name(), + InvalidArgument("Function '", func->record->fdef().signature().name(), "' has no attr named '", attr_name, "'."); return; } status->status = MessageToBuffer(it->second, output_attr_value); } -void TF_DeleteFunction(TF_Function* func) { delete func; } +void TF_DeleteFunction(TF_Function* func) { + if (func == nullptr) { + return; + } + + 
func->record->Unref(); + func->record = nullptr; + delete func; +} diff --git a/tensorflow/c/c_api_function_test.cc b/tensorflow/c/c_api_function_test.cc index ec8cfe4a31a..0f177ed30ae 100644 --- a/tensorflow/c/c_api_function_test.cc +++ b/tensorflow/c/c_api_function_test.cc @@ -17,6 +17,7 @@ limitations under the License. #include "tensorflow/c/c_api_internal.h" #include "tensorflow/c/c_test_util.h" #include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/framework/op_def.pb.h" #include "tensorflow/core/lib/hash/hash.h" @@ -1210,6 +1211,25 @@ TEST_F(CApiFunctionTest, OutputOpNotInBody) { string(TF_Message(s_))); } +class TestStackTrace : public AbstractStackTrace { + absl::Span<StackFrame const> ToFrames() const override { return frames_; } + + StackFrame LastUserFrame() const override { return frames_.back(); } + + std::vector<StackFrame> GetUserFrames(int limit) const override { + return frames_; + } + + string ToString(const TracePrintingOptions& opts) const override { + auto frame = LastUserFrame(); + return absl::StrCat(frame.file_name, ":", frame.line_number, ":", + frame.function_name); + } + + std::vector<StackFrame> frames_{ + StackFrame({"dummy_file_name", 10, "dummy_function_name"})}; +}; + void DefineFunction(const char* name, TF_Function** func, const char* description = nullptr, bool append_hash = false) { @@ -1221,6 +1241,9 @@ void DefineFunction(const char* name, TF_Function** func, TF_Operation* feed = Placeholder(func_graph.get(), s.get()); TF_Operation* neg = Neg(feed, func_graph.get(), s.get()); + feed->node.SetStackTrace(std::make_shared<TestStackTrace>()); + neg->node.SetStackTrace(std::make_shared<TestStackTrace>()); + TF_Output inputs[] = {{feed, 0}}; TF_Output outputs[] = {{neg, 0}}; *func = TF_GraphToFunction(func_graph.get(), name, append_hash, -1, @@ -1270,11 +1293,11 @@ TEST_F(CApiFunctionTest, GraphToFunctionDefWithPlaceholderAttr) { ASSERT_NE(func_, nullptr); // Verify that FunctionDef has 2 attributes, "v1" and "v2". - ASSERT_EQ(func_->fdef.signature().attr().size(), 2); - EXPECT_EQ(func_->fdef.signature().attr(0).name(), "v1"); - EXPECT_EQ(func_->fdef.signature().attr(0).type(), "int"); - EXPECT_EQ(func_->fdef.signature().attr(1).name(), "v2"); - EXPECT_EQ(func_->fdef.signature().attr(1).type(), "int"); + ASSERT_EQ(func_->record->fdef().signature().attr().size(), 2); + EXPECT_EQ(func_->record->fdef().signature().attr(0).name(), "v1"); + EXPECT_EQ(func_->record->fdef().signature().attr(0).type(), "int"); + EXPECT_EQ(func_->record->fdef().signature().attr(1).name(), "v2"); + EXPECT_EQ(func_->record->fdef().signature().attr(1).type(), "int"); } void NodeWithAttrHelper(TF_Graph* graph, TF_Status* s, const char* name, @@ -1308,14 +1331,65 @@ TEST_F(CApiFunctionTest, GraphToFunctionDefWithArgAttr) { ASSERT_NE(func_, nullptr); // Verify that FunctionDef ArgDef has attributes.
- ASSERT_EQ(func_->fdef.arg_attr_size(), 1); - auto arg_attrs = func_->fdef.arg_attr().find(0); - ASSERT_NE(arg_attrs, func_->fdef.arg_attr().end()); + ASSERT_EQ(func_->record->fdef().arg_attr_size(), 1); + auto arg_attrs = func_->record->fdef().arg_attr().find(0); + ASSERT_NE(arg_attrs, func_->record->fdef().arg_attr().end()); auto iter = arg_attrs->second.attr().find("_test_attr"); ASSERT_NE(iter, arg_attrs->second.attr().end()); EXPECT_EQ(iter->second.s(), "value"); } +TEST_F(CApiFunctionTest, TFGraphToFunctionWithStackTraces) { + DefineFunction(func_name_, &func_); + auto stack_traces = func_->record->stack_traces(); + + EXPECT_EQ(stack_traces.size(), 4); + EXPECT_EQ(stack_traces["neg"]->ToString({}), + "dummy_file_name:10:dummy_function_name"); + EXPECT_EQ(stack_traces["feed"]->ToString({}), + "dummy_file_name:10:dummy_function_name"); +} + +TEST_F(CApiFunctionTest, TFGraphCopyFunctionWithStackTraces) { + // Define the function and its grad + DefineFunction(func_name_, &func_); + TF_Function* grad_func; + DefineFunction("MyGrad", &grad_func); + + // Add func and its gradient to host graph + TF_GraphCopyFunction(host_graph_, func_, grad_func, s_); + + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + + TF_DeleteFunction(grad_func); + + const StackTracesMap* func_stack_traces; + const StackTracesMap* grad_stack_traces; + + { + mutex_lock l(host_graph_->mu); + auto flib_def = host_graph_->graph.flib_def(); + func_stack_traces = flib_def.GetStackTraces(func_name_); + grad_stack_traces = flib_def.GetStackTraces("MyGrad"); + } + + // Verify that stack traces of func is copied to graph function library. + ASSERT_NE(func_stack_traces, nullptr); + EXPECT_EQ(func_stack_traces->size(), 4); + EXPECT_EQ(func_stack_traces->at("neg")->ToString({}), + "dummy_file_name:10:dummy_function_name"); + EXPECT_EQ(func_stack_traces->at("feed")->ToString({}), + "dummy_file_name:10:dummy_function_name"); + + // Verify that stack traces of grad_func is copied to graph function library. + ASSERT_NE(grad_stack_traces, nullptr); + EXPECT_EQ(grad_stack_traces->size(), 4); + EXPECT_EQ(grad_stack_traces->at("neg")->ToString({}), + "dummy_file_name:10:dummy_function_name"); + EXPECT_EQ(grad_stack_traces->at("feed")->ToString({}), + "dummy_file_name:10:dummy_function_name"); +} + TEST_F(CApiFunctionTest, SetGradientAndRun) { // Define the function and its grad DefineFunction(func_name_, &func_); diff --git a/tensorflow/c/c_api_internal.h b/tensorflow/c/c_api_internal.h index a34e11a3e4c..92f63553ee1 100644 --- a/tensorflow/c/c_api_internal.h +++ b/tensorflow/c/c_api_internal.h @@ -16,14 +16,14 @@ limitations under the License. #ifndef TENSORFLOW_C_C_API_INTERNAL_H_ #define TENSORFLOW_C_C_API_INTERNAL_H_ -#include "tensorflow/c/c_api.h" - #include #include #include #include #include +#include "tensorflow/c/c_api.h" + // clang-format off // Required for IS_MOBILE_PLATFORM #include "tensorflow/core/platform/platform.h" @@ -34,11 +34,12 @@ limitations under the License. 
#if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD) #include "tensorflow/core/framework/op_gen_lib.h" #endif // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD) +#include "tensorflow/core/common_runtime/graph_constructor.h" #include "tensorflow/core/common_runtime/shape_refiner.h" +#include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/graph/graph.h" -#include "tensorflow/core/common_runtime/graph_constructor.h" #include "tensorflow/core/graph/node_builder.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/status.h" @@ -159,8 +160,7 @@ struct TF_DeviceList { }; struct TF_Function { - tensorflow::FunctionDef fdef; - tensorflow::StackTracesMap stack_traces; + tensorflow::FunctionRecord* record; }; struct TF_ApiDefMap { diff --git a/tensorflow/c/c_api_test.cc b/tensorflow/c/c_api_test.cc index 051c81fc782..008e2d772a3 100644 --- a/tensorflow/c/c_api_test.cc +++ b/tensorflow/c/c_api_test.cc @@ -243,7 +243,7 @@ void TestEncodeDecode(int line, const std::vector& data) { src.flat()(i) = data[i]; } TF_Tensor* dst = TF_TensorFromTensor(src, &status); - ASSERT_TRUE(status.ok()) << status.error_message(); + ASSERT_TRUE(status.ok()) << status.message(); // Convert back to a C++ Tensor and ensure we get expected output. Tensor output; @@ -1435,7 +1435,7 @@ TEST(CAPI, SavedModel) { ASSERT_TRUE(input_op != nullptr); Status status; csession.SetInputs({{input_op, TF_TensorFromTensor(input, &status)}}); - ASSERT_TRUE(status.ok()) << status.error_message(); + ASSERT_TRUE(status.ok()) << status.message(); const tensorflow::string output_op_name( tensorflow::ParseTensorName(output_name).first); diff --git a/tensorflow/c/checkpoint_reader.cc b/tensorflow/c/checkpoint_reader.cc index 4a613d874a2..3ed513f0caa 100644 --- a/tensorflow/c/checkpoint_reader.cc +++ b/tensorflow/c/checkpoint_reader.cc @@ -42,7 +42,7 @@ CheckpointReader::CheckpointReader(const string& filename, TF_Status* status) v2_reader_.reset( new BundleReader(Env::Default(), filename /* prefix to a V2 ckpt */)); if (!v2_reader_->status().ok()) { - Set_TF_Status_from_Status(status, v2_reader_->status()); + tsl::Set_TF_Status_from_Status(status, v2_reader_->status()); return; } auto result = BuildV2VarMaps(); @@ -51,7 +51,7 @@ CheckpointReader::CheckpointReader(const string& filename, TF_Status* status) } else { reader_.reset(new TensorSliceReader(filename)); if (!reader_->status().ok()) { - Set_TF_Status_from_Status(status, reader_->status()); + tsl::Set_TF_Status_from_Status(status, reader_->status()); return; } var_to_shape_map_.reset( @@ -102,7 +102,7 @@ void CheckpointReader::GetTensor( } } if (!status.ok()) { - Set_TF_Status_from_Status(out_status, status); + tsl::Set_TF_Status_from_Status(out_status, status); } } diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD index 1fb1d367dfd..dd61bd26bc1 100644 --- a/tensorflow/c/eager/BUILD +++ b/tensorflow/c/eager/BUILD @@ -37,20 +37,17 @@ tf_cuda_library( ], "//conditions:default": [ ":immediate_execution_context", + ":immediate_execution_distributed_manager", ":immediate_execution_operation", ":immediate_execution_tensor_handle", - ":immediate_execution_distributed_manager", - ":tfe_context_internal", ":tfe_cancellation_manager_internal", + ":tfe_context_internal", ":tfe_executor_internal", ":tfe_monitoring_internal", ":tfe_op_attrs_internal", ":tfe_op_internal", ":tfe_tensor_debug_info_internal", ":tfe_tensorhandle_internal", 
- "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/types:span", - "@com_google_absl//absl/types:variant", "//tensorflow/c:c_api", "//tensorflow/c:c_api_internal", "//tensorflow/c:tf_buffer", @@ -58,6 +55,12 @@ tf_cuda_library( "//tensorflow/c:tf_status_internal", "//tensorflow/c:tf_tensor_internal", "//tensorflow/core:core_cpu", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", "//tensorflow/core/common_runtime/eager:attr_builder", "//tensorflow/core/common_runtime/eager:context", "//tensorflow/core/common_runtime/eager:context_distributed_manager", @@ -65,34 +68,32 @@ tf_cuda_library( "//tensorflow/core/common_runtime/eager:custom_device", "//tensorflow/core/common_runtime/eager:eager_executor", "//tensorflow/core/common_runtime/eager:execute", - "//tensorflow/core/common_runtime/eager:tensor_handle", "//tensorflow/core/common_runtime/eager:placement_utils", - "//tensorflow/core:core_cpu_internal", - "//tensorflow/core:framework", - "//tensorflow/core:framework_internal", - "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", - "//tensorflow/core:protos_all_cc", + "//tensorflow/core/common_runtime/eager:tensor_handle", "//tensorflow/core/profiler/lib:traceme", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/types:span", + "@com_google_absl//absl/types:variant", ], }) + [ - "@com_google_absl//absl/memory", ":abstract_tensor_handle", + "//tensorflow/c:c_api_macros_hdrs", + "//tensorflow/core:gpu_runtime", "//tensorflow/core/common_runtime/eager:eager_operation", - "//tensorflow/core/distributed_runtime/eager:remote_mgr", + "//tensorflow/core/distributed_runtime:remote_device", + "//tensorflow/core/distributed_runtime:server_lib", + "//tensorflow/core/distributed_runtime:worker_env", + "//tensorflow/core/distributed_runtime:worker_interface", "//tensorflow/core/distributed_runtime/eager:cluster_function_library_runtime", "//tensorflow/core/distributed_runtime/eager:eager_client", - "//tensorflow/core/distributed_runtime/rpc/eager:grpc_eager_client", + "//tensorflow/core/distributed_runtime/eager:remote_mgr", "//tensorflow/core/distributed_runtime/rpc:grpc_channel", "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib", "//tensorflow/core/distributed_runtime/rpc:grpc_worker_cache", "//tensorflow/core/distributed_runtime/rpc:grpc_worker_service", "//tensorflow/core/distributed_runtime/rpc:rpc_rendezvous_mgr", - "//tensorflow/core/distributed_runtime:remote_device", - "//tensorflow/core/distributed_runtime:server_lib", - "//tensorflow/core/distributed_runtime:worker_env", - "//tensorflow/core/distributed_runtime:worker_interface", - "//tensorflow/core:gpu_runtime", + "//tensorflow/core/distributed_runtime/rpc/eager:grpc_eager_client", + "@com_google_absl//absl/memory", "@com_google_absl//absl/strings:str_format", ] + internal_tfrt_deps(), alwayslink = 1, @@ -541,7 +542,9 @@ cc_library( cc_library( name = "tfe_op_attrs_internal", hdrs = ["tfe_op_attrs_internal.h"], - visibility = ["//visibility:private"], + visibility = [ + "//tensorflow:internal", + ], deps = [ ":abstract_op_attrs", "//tensorflow/c:conversion_macros", @@ -836,64 +839,84 @@ tf_cuda_library( "//tensorflow/core:portable_tensorflow_lib_lite", ], "//conditions:default": [ + ":abstract_context", + ":abstract_operation", + ":abstract_tensor_handle", ":c_api", ":c_api_internal", ":graph_function", + 
":immediate_execution_context", + ":immediate_execution_tensor_handle", ":tfe_context_internal", ":tfe_op_internal", ":tfe_tensorhandle_internal", - ":abstract_operation", - ":abstract_context", - ":abstract_tensor_handle", - ":immediate_execution_tensor_handle", - ":immediate_execution_context", - "//tensorflow/core/lib/llvm_rtti", "//tensorflow/c:c_api", "//tensorflow/c:c_api_internal", + "//tensorflow/c:conversion_macros", "//tensorflow/core:core_cpu", - "//tensorflow/core/common_runtime/eager:attr_builder", - "//tensorflow/core/common_runtime/eager:context", - "//tensorflow/core/common_runtime/eager:eager_executor", - "//tensorflow/core/common_runtime/eager:eager_operation", - "//tensorflow/core/common_runtime/eager:execute", - "//tensorflow/core/common_runtime/eager:kernel_and_device", - "//tensorflow/core/common_runtime/eager:tensor_handle", - "//tensorflow/core/common_runtime/eager:copy_to_device_node", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", + "//tensorflow/core/common_runtime/eager:attr_builder", + "//tensorflow/core/common_runtime/eager:context", + "//tensorflow/core/common_runtime/eager:copy_to_device_node", + "//tensorflow/core/common_runtime/eager:eager_executor", + "//tensorflow/core/common_runtime/eager:eager_operation", + "//tensorflow/core/common_runtime/eager:execute", + "//tensorflow/core/common_runtime/eager:kernel_and_device", + "//tensorflow/core/common_runtime/eager:tensor_handle", + "//tensorflow/core/lib/llvm_rtti", "@com_google_absl//absl/types:variant", - "//tensorflow/c:conversion_macros", ], }) + select({ "//tensorflow:with_xla_support": [ - "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/jit", "//tensorflow/compiler/jit:xla_device", + "//tensorflow/compiler/tf2xla:xla_compiler", ], "//conditions:default": [], }) + [ - "@com_google_absl//absl/memory", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/time", - "@com_google_absl//absl/container:flat_hash_map", "//tensorflow/c:tf_status_helper", + "//tensorflow/core:gpu_runtime", + "//tensorflow/core/distributed_runtime:remote_device", + "//tensorflow/core/distributed_runtime:server_lib", + "//tensorflow/core/distributed_runtime:worker_env", "//tensorflow/core/distributed_runtime/coordination:coordination_service_error_util", "//tensorflow/core/distributed_runtime/eager:eager_client", - "//tensorflow/core/distributed_runtime/rpc/eager:grpc_eager_client", "//tensorflow/core/distributed_runtime/rpc:grpc_channel", "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib", "//tensorflow/core/distributed_runtime/rpc:grpc_worker_cache", "//tensorflow/core/distributed_runtime/rpc:grpc_worker_service", "//tensorflow/core/distributed_runtime/rpc:rpc_rendezvous_mgr", - "//tensorflow/core/distributed_runtime:remote_device", - "//tensorflow/core/distributed_runtime:server_lib", - "//tensorflow/core/distributed_runtime:worker_env", - "//tensorflow/core:gpu_runtime", + "//tensorflow/core/distributed_runtime/rpc/eager:grpc_eager_client", "//tensorflow/tsl/distributed_runtime/coordination:coordination_service_agent", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/time", + ], + alwayslink = 1, +) + +cc_library( + name = "c_api_experimental_reader", + testonly = True, + srcs = [ + "c_api_experimental_reader.cc", + ], + hdrs = [ + 
"c_api_experimental_reader.h", + "tfe_monitoring_reader_internal.h", + ], + visibility = ["//tensorflow:internal"], + deps = [ + ":c_api", + "//tensorflow/c:c_api", + "//tensorflow/core/lib/monitoring:cell_reader", + "@com_google_absl//absl/memory", ], alwayslink = 1, ) @@ -920,6 +943,29 @@ tf_cuda_cc_test( ], ) +tf_cuda_cc_test( + name = "c_api_experimental_reader_test", + size = "small", + srcs = [ + "c_api_experimental_reader_test.cc", + ], + args = ["--heap_check="], + tags = tf_cuda_tests_tags() + ["nomac"], + deps = [ + ":c_api", + ":c_api_experimental", + ":c_api_experimental_reader", + ":c_api_test_util", + "//tensorflow/c:c_test_util", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core/platform:status", + "@com_google_absl//absl/strings", + ], +) + tf_cuda_cc_test( name = "c_api_unified_experimental_test", size = "small", @@ -1009,6 +1055,23 @@ filegroup( visibility = ["//tensorflow:__subpackages__"], ) +filegroup( + name = "pywrap_headers_monitoring_reader", + srcs = [ + "c_api_experimental_reader.h", + "tfe_monitoring_reader_internal.h", + ], + visibility = ["//tensorflow:__subpackages__"], +) + +filegroup( + name = "headers_monitoring_reader", + srcs = [ + "c_api_experimental_reader.h", + ], + visibility = ["//tensorflow:__subpackages__"], +) + cc_library( name = "dlpack", srcs = ["dlpack.cc"], @@ -1046,6 +1109,9 @@ filegroup( ], exclude = [ "c_api_experimental.cc", + "c_api_experimental_reader.cc", + "c_api_experimental_reader.h", + "tfe_monitoring_reader_internal.h", "c_api_unified_experimental.cc", "c_api_unified_experimental_eager.cc", "c_api_unified_experimental_graph.cc", diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index a4e11b63576..8503485f63c 100644 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -19,6 +19,7 @@ limitations under the License. 
#include #include #include +#include #include #include "absl/algorithm/container.h" @@ -137,14 +138,14 @@ TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) { std::unique_ptr device_mgr( new tensorflow::DynamicDeviceMgr(std::move(devices))); - tensorflow::Rendezvous* r = - new tensorflow::IntraProcessRendezvous(device_mgr.get()); + auto r = tsl::core::RefCountPtr( + new tensorflow::IntraProcessRendezvous(device_mgr.get())); tensorflow::EagerContext* eager_context = new tensorflow::EagerContext( opts->session_options.options, static_cast( opts->device_placement_policy), opts->async, device_mgr.release(), - /*device_mgr_owned*/ true, r, + /*device_mgr_owned*/ true, std::move(r), /*cluster_flr=*/nullptr, /*collective_executor_mgr=*/nullptr, /*run_eager_op_as_function=*/opts->run_eager_op_as_function, @@ -931,9 +932,32 @@ void TFE_ContextAddFunctionDef(TFE_Context* ctx, void TFE_ContextAddFunction(TFE_Context* ctx, TF_Function* function, TF_Status* status) { - AnnotateEagerRuntimeConstructionContext(function->fdef); + auto fdef_or = function->record->mutable_fdef(); + if (!fdef_or.ok()) { + status->status = fdef_or.status(); + return; + } + + AnnotateEagerRuntimeConstructionContext(*fdef_or.value()); status->status = tensorflow::unwrap(ctx)->AddFunctionDefWithStackTraces( - function->fdef, function->stack_traces); + *fdef_or.value(), function->record->stack_traces()); +} + +TF_Function* TFE_ContextGetFunction(TFE_Context* ctx, const char* name, + TF_Status* status) { + tensorflow::core::RefCountPtr record = + tensorflow::unwrap(ctx)->FindRecord(name); + + if (record == nullptr) { + status->status = tensorflow::errors::NotFound( + "Unable to find Function with name: ", name); + return nullptr; + } + + TF_Function* result = new TF_Function(); + record->Ref(); + result->record = record.get(); + return result; } void TFE_ContextRemoveFunction(TFE_Context* ctx, const char* name, diff --git a/tensorflow/c/eager/c_api.h b/tensorflow/c/eager/c_api.h index 7ad77587d6f..7f458ac50ab 100644 --- a/tensorflow/c/eager/c_api.h +++ b/tensorflow/c/eager/c_api.h @@ -21,25 +21,7 @@ limitations under the License. // stable and can change without notice. #include "tensorflow/c/c_api.h" - -// Macro to control visibility of exported symbols in the shared library (.so, -// .dylib, .dll). 
-// This duplicates the TF_EXPORT macro definition in -// tensorflow/core/platform/macros.h in order to keep this .h file independent -// of any other includes.$a -#ifdef SWIG -#define TF_CAPI_EXPORT -#else -#if defined(_WIN32) -#ifdef TF_COMPILE_LIBRARY -#define TF_CAPI_EXPORT __declspec(dllexport) -#else -#define TF_CAPI_EXPORT __declspec(dllimport) -#endif // TF_COMPILE_LIBRARY -#else -#define TF_CAPI_EXPORT __attribute__((visibility("default"))) -#endif // _WIN32 -#endif // SWIG +#include "tensorflow/c/c_api_macros.h" #ifdef __cplusplus extern "C" { diff --git a/tensorflow/c/eager/c_api_cluster_test.cc b/tensorflow/c/eager/c_api_cluster_test.cc index 7a604950a63..c4b58c3dd73 100644 --- a/tensorflow/c/eager/c_api_cluster_test.cc +++ b/tensorflow/c/eager/c_api_cluster_test.cc @@ -150,7 +150,7 @@ void TestRemoteExecuteChangeServerDef(bool async) { updated_server_def.set_task_index(1); tensorflow::Status s = tensorflow::GrpcServer::Create( updated_server_def, tensorflow::Env::Default(), &worker_server); - ASSERT_TRUE(s.ok()) << s.error_message(); + ASSERT_TRUE(s.ok()) << s.message(); ASSERT_TRUE(worker_server->Start().ok()); TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status); diff --git a/tensorflow/c/eager/c_api_distributed_test.cc b/tensorflow/c/eager/c_api_distributed_test.cc index efd9e8a0a35..e35bc962525 100644 --- a/tensorflow/c/eager/c_api_distributed_test.cc +++ b/tensorflow/c/eager/c_api_distributed_test.cc @@ -434,6 +434,7 @@ class FunctionErrorInjectionPass : public tensorflow::FunctionOptimizationPass { tensorflow::Status Run(const std::string& function_name, const tensorflow::DeviceSet& device_set, const tensorflow::ConfigProto& config_proto, + absl::string_view xla_compile_device_type, std::unique_ptr* graph, tensorflow::FunctionLibraryDefinition* flib_def, std::vector* control_ret_node_names, diff --git a/tensorflow/c/eager/c_api_experimental.cc b/tensorflow/c/eager/c_api_experimental.cc index 2490fc440ed..6fbcb7bb56a 100644 --- a/tensorflow/c/eager/c_api_experimental.cc +++ b/tensorflow/c/eager/c_api_experimental.cc @@ -15,8 +15,10 @@ limitations under the License. #include "tensorflow/c/eager/c_api_experimental.h" +#include #include +#include "absl/container/flat_hash_map.h" #include "absl/strings/match.h" #include "absl/time/time.h" #include "tensorflow/c/c_api.h" @@ -29,6 +31,8 @@ limitations under the License. 
#include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/eager/eager_operation.h" #include "tensorflow/core/distributed_runtime/coordination/coordination_service_error_util.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/graph_debug_info.pb.h" #include "tensorflow/core/lib/monitoring/counter.h" #include "tensorflow/core/lib/monitoring/gauge.h" #include "tensorflow/core/lib/monitoring/sampler.h" @@ -80,7 +84,7 @@ TFE_MonitoringCounter0* TFE_MonitoringNewCounter0(const char* name, TF_Status* status, const char* description) { auto* result = new TFE_MonitoringCounter0({name, description}); - Set_TF_Status_from_Status(status, result->counter->GetStatus()); + tsl::Set_TF_Status_from_Status(status, result->counter->GetStatus()); if (!result->counter->GetStatus().ok()) { delete result; return nullptr; @@ -103,7 +107,7 @@ TFE_MonitoringCounter1* TFE_MonitoringNewCounter1(const char* name, const char* description, const char* label1) { auto* result = new TFE_MonitoringCounter1({name, description, label1}); - Set_TF_Status_from_Status(status, result->counter->GetStatus()); + tsl::Set_TF_Status_from_Status(status, result->counter->GetStatus()); if (!result->counter->GetStatus().ok()) { delete result; return nullptr; @@ -128,7 +132,7 @@ TFE_MonitoringCounter2* TFE_MonitoringNewCounter2(const char* name, const char* label2) { auto* result = new TFE_MonitoringCounter2({name, description, label1, label2}); - Set_TF_Status_from_Status(status, result->counter->GetStatus()); + tsl::Set_TF_Status_from_Status(status, result->counter->GetStatus()); if (!result->counter->GetStatus().ok()) { delete result; return nullptr; @@ -159,7 +163,7 @@ TFE_MonitoringIntGauge0* TFE_MonitoringNewIntGauge0(const char* name, TF_Status* status, const char* description) { auto* result = new TFE_MonitoringIntGauge0({name, description}); - Set_TF_Status_from_Status(status, result->gauge->GetStatus()); + tsl::Set_TF_Status_from_Status(status, result->gauge->GetStatus()); if (!result->gauge->GetStatus().ok()) { delete result; return nullptr; @@ -182,7 +186,7 @@ TFE_MonitoringIntGauge1* TFE_MonitoringNewIntGauge1(const char* name, const char* description, const char* label1) { auto* result = new TFE_MonitoringIntGauge1({name, description, label1}); - Set_TF_Status_from_Status(status, result->gauge->GetStatus()); + tsl::Set_TF_Status_from_Status(status, result->gauge->GetStatus()); if (!result->gauge->GetStatus().ok()) { delete result; return nullptr; @@ -207,7 +211,7 @@ TFE_MonitoringIntGauge2* TFE_MonitoringNewIntGauge2(const char* name, const char* label2) { auto* result = new TFE_MonitoringIntGauge2({name, description, label1, label2}); - Set_TF_Status_from_Status(status, result->gauge->GetStatus()); + tsl::Set_TF_Status_from_Status(status, result->gauge->GetStatus()); if (!result->gauge->GetStatus().ok()) { delete result; return nullptr; @@ -245,7 +249,7 @@ const void TFE_MonitoringStringGaugeCellValue( TFE_MonitoringStringGauge0* TFE_MonitoringNewStringGauge0( const char* name, TF_Status* status, const char* description) { auto* result = new TFE_MonitoringStringGauge0({name, description}); - Set_TF_Status_from_Status(status, result->gauge->GetStatus()); + tsl::Set_TF_Status_from_Status(status, result->gauge->GetStatus()); if (!result->gauge->GetStatus().ok()) { delete result; return nullptr; @@ -267,7 +271,7 @@ TFE_MonitoringStringGauge1* TFE_MonitoringNewStringGauge1( const char* name, TF_Status* status, const char* description, const char* label1) { 
auto* result = new TFE_MonitoringStringGauge1({name, description, label1}); - Set_TF_Status_from_Status(status, result->gauge->GetStatus()); + tsl::Set_TF_Status_from_Status(status, result->gauge->GetStatus()); if (!result->gauge->GetStatus().ok()) { delete result; return nullptr; @@ -290,7 +294,7 @@ TFE_MonitoringStringGauge2* TFE_MonitoringNewStringGauge2( const char* label1, const char* label2) { auto* result = new TFE_MonitoringStringGauge2({name, description, label1, label2}); - Set_TF_Status_from_Status(status, result->gauge->GetStatus()); + tsl::Set_TF_Status_from_Status(status, result->gauge->GetStatus()); if (!result->gauge->GetStatus().ok()) { delete result; return nullptr; @@ -313,7 +317,7 @@ TFE_MonitoringStringGauge3* TFE_MonitoringNewStringGauge3( const char* label1, const char* label2, const char* label3) { auto* result = new TFE_MonitoringStringGauge3( {name, description, label1, label2, label3}); - Set_TF_Status_from_Status(status, result->gauge->GetStatus()); + tsl::Set_TF_Status_from_Status(status, result->gauge->GetStatus()); if (!result->gauge->GetStatus().ok()) { delete result; return nullptr; @@ -338,7 +342,7 @@ TFE_MonitoringStringGauge4* TFE_MonitoringNewStringGauge4( const char* label4) { auto* result = new TFE_MonitoringStringGauge4( {name, description, label1, label2, label3, label4}); - Set_TF_Status_from_Status(status, result->gauge->GetStatus()); + tsl::Set_TF_Status_from_Status(status, result->gauge->GetStatus()); if (!result->gauge->GetStatus().ok()) { delete result; return nullptr; @@ -370,7 +374,7 @@ TFE_MonitoringBoolGauge0* TFE_MonitoringNewBoolGauge0(const char* name, TF_Status* status, const char* description) { auto* result = new TFE_MonitoringBoolGauge0({name, description}); - Set_TF_Status_from_Status(status, result->gauge->GetStatus()); + tsl::Set_TF_Status_from_Status(status, result->gauge->GetStatus()); if (!result->gauge->GetStatus().ok()) { delete result; return nullptr; @@ -393,7 +397,7 @@ TFE_MonitoringBoolGauge1* TFE_MonitoringNewBoolGauge1(const char* name, const char* description, const char* label1) { auto* result = new TFE_MonitoringBoolGauge1({name, description, label1}); - Set_TF_Status_from_Status(status, result->gauge->GetStatus()); + tsl::Set_TF_Status_from_Status(status, result->gauge->GetStatus()); if (!result->gauge->GetStatus().ok()) { delete result; return nullptr; @@ -418,7 +422,7 @@ TFE_MonitoringBoolGauge2* TFE_MonitoringNewBoolGauge2(const char* name, const char* label2) { auto* result = new TFE_MonitoringBoolGauge2({name, description, label1, label2}); - Set_TF_Status_from_Status(status, result->gauge->GetStatus()); + tsl::Set_TF_Status_from_Status(status, result->gauge->GetStatus()); if (!result->gauge->GetStatus().ok()) { delete result; return nullptr; @@ -472,7 +476,7 @@ TFE_MonitoringSampler0* TFE_MonitoringNewSampler0( const char* description) { auto* result = new TFE_MonitoringSampler0( {name, buckets->create_buckets(), description}); - Set_TF_Status_from_Status(status, result->sampler->GetStatus()); + tsl::Set_TF_Status_from_Status(status, result->sampler->GetStatus()); if (!result->sampler->GetStatus().ok()) { delete result; return nullptr; @@ -495,7 +499,7 @@ TFE_MonitoringSampler1* TFE_MonitoringNewSampler1( const char* description, const char* label1) { auto* result = new TFE_MonitoringSampler1( {name, buckets->create_buckets(), description, label1}); - Set_TF_Status_from_Status(status, result->sampler->GetStatus()); + tsl::Set_TF_Status_from_Status(status, result->sampler->GetStatus()); if 
(!result->sampler->GetStatus().ok()) { delete result; return nullptr; @@ -518,7 +522,7 @@ TFE_MonitoringSampler2* TFE_MonitoringNewSampler2( const char* description, const char* label1, const char* label2) { auto* result = new TFE_MonitoringSampler2( {name, buckets->create_buckets(), description, label1, label2}); - Set_TF_Status_from_Status(status, result->sampler->GetStatus()); + tsl::Set_TF_Status_from_Status(status, result->sampler->GetStatus()); if (!result->sampler->GetStatus().ok()) { delete result; return nullptr; @@ -628,6 +632,30 @@ void TFE_ContextGetFunctionDef(TFE_Context* ctx, const char* function_name, status->status = ::tensorflow::OkStatus(); } +void TFE_ContextGetGraphDebugInfo(TFE_Context* ctx, const char* function_name, + TF_Buffer* buf, TF_Status* status) { + auto function_record = tensorflow::unwrap(ctx)->FindRecord(function_name); + if (function_record == nullptr) { + status->status = tensorflow::errors::NotFound( + "Unable to find function with name: ", function_name); + return; + } + + tensorflow::GraphDebugInfo debug_info = + tensorflow::StackTracesMapToGraphDebugInfo( + function_record->stack_traces()); + + string str = debug_info.SerializeAsString(); + void* data = tensorflow::port::Malloc(str.length()); + str.copy(static_cast(data), str.length(), 0); + buf->data = data; + buf->length = str.length(); + buf->data_deallocator = [](void* data, size_t length) { + tensorflow::port::Free(data); + }; + status->status = ::tensorflow::OkStatus(); +} + TF_Tensor* TFE_AllocateHostTensor(TFE_Context* ctx, TF_DataType dtype, const int64_t* dims, int num_dims, TF_Status* status) { @@ -884,7 +912,7 @@ void TFE_GetTaskStates(TFE_Context* ctx, const TF_Buffer& tasks, void* states, const auto& result = (*results)[i]; TF_Status s; TF_SetStatus(&s, static_cast(result.error_code()), - result.error_message().data()); + std::string(result.error_message()).data()); if (TF_GetCode(&s) != TF_Code::TF_OK) { tensorflow::CoordinationServiceError error; *error.mutable_source_task() = result.error_payload().source_task(); diff --git a/tensorflow/c/eager/c_api_experimental.h b/tensorflow/c/eager/c_api_experimental.h index 48c5fe70ce0..fcbced2080a 100644 --- a/tensorflow/c/eager/c_api_experimental.h +++ b/tensorflow/c/eager/c_api_experimental.h @@ -612,6 +612,17 @@ TF_CAPI_EXPORT extern void TFE_ContextGetFunctionDef(TFE_Context* ctx, TF_Buffer* buf, TF_Status* status); +// Get GraphDebugInfo containing stack traces mapping to node names +TF_CAPI_EXPORT extern void TFE_ContextGetGraphDebugInfo( + TFE_Context* ctx, const char* function_name, TF_Buffer* buf, + TF_Status* status); + +// Extracts a TF_Function from the context. +// Must call TF_DeleteFunction on the returned value. +TF_CAPI_EXPORT extern TF_Function* TFE_ContextGetFunction(TFE_Context* ctx, + const char* name, + TF_Status* status); + // Allocate and return a new Tensor on the host. // // The caller must set the Tensor values by writing them to the pointer returned diff --git a/tensorflow/c/eager/c_api_experimental_reader.cc b/tensorflow/c/eager/c_api_experimental_reader.cc new file mode 100644 index 00000000000..0959580a104 --- /dev/null +++ b/tensorflow/c/eager/c_api_experimental_reader.cc @@ -0,0 +1,42 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License");; +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/c/eager/c_api_experimental_reader.h" + +#include "tensorflow/c/eager/tfe_monitoring_reader_internal.h" + +template +int64_t TFE_MonitoringCounterReader::Read(const LabelType&... labels) { + return counter->Read(labels...); +} + +TFE_MonitoringCounterReader* TFE_MonitoringNewCounterReader(const char* name) { + auto* result = new TFE_MonitoringCounterReader(name); + + return result; +} + +int64_t TFE_MonitoringReadCounter0(TFE_MonitoringCounterReader* cell_reader) { + int64_t result = cell_reader->Read(); + + return result; +} + +int64_t TFE_MonitoringReadCounter1(TFE_MonitoringCounterReader* cell_reader, + const char* label) { + int64_t result = cell_reader->Read(label); + + return result; +} diff --git a/tensorflow/c/eager/c_api_experimental_reader.h b/tensorflow/c/eager/c_api_experimental_reader.h new file mode 100644 index 00000000000..71c2e4650f0 --- /dev/null +++ b/tensorflow/c/eager/c_api_experimental_reader.h @@ -0,0 +1,60 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License");; +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_EAGER_C_API_EXPERIMENTAL_READER_H_ +#define TENSORFLOW_C_EAGER_C_API_EXPERIMENTAL_READER_H_ + +#include "tensorflow/c/eager/c_api.h" + +#ifdef __cplusplus +extern "C" { +#endif + +// Test only exports of the monitoring Cell Reader API which allows tests to +// read current values from streamz counters defined in other modules. +// +// The code under test will have created streamz counters like this: +// auto* streamz = tensorflow::monitoring::Counter<1>::New("name", +// "description", "label"); +// and then incremented that counter for various values of label: +// streamz->GetCell("label-value")->IncrementBy(1); +// +// The test code can then read and test the value of that counter: +// +// auto* reader = TFE_MonitoringNewCounterReader("name"); +// test(); +// int64_t value = TFE_MonitoringReadCounter1(reader, "label-value"); + +// Opaque handle to a reader. +typedef struct TFE_MonitoringCounterReader TFE_MonitoringCounterReader; + +// Returns a handle to be used for reading values from streamz counter. The +// counter can have been created with any number of labels. +TF_CAPI_EXPORT extern TFE_MonitoringCounterReader* +TFE_MonitoringNewCounterReader(const char* name); + +// Reads the value of a counter that was created with 0 labels. 
+TF_CAPI_EXPORT extern int64_t TFE_MonitoringReadCounter0( + TFE_MonitoringCounterReader*); + +// Reads the value of specific cell of a counter that was created with 1 label. +TF_CAPI_EXPORT extern int64_t TFE_MonitoringReadCounter1( + TFE_MonitoringCounterReader*, const char* label_value); + +#ifdef __cplusplus +} /* end extern "C" */ +#endif + +#endif // TENSORFLOW_C_EAGER_C_API_EXPERIMENTAL_READER_H_ diff --git a/tensorflow/c/eager/c_api_experimental_reader_test.cc b/tensorflow/c/eager/c_api_experimental_reader_test.cc new file mode 100644 index 00000000000..3c7a09891a6 --- /dev/null +++ b/tensorflow/c/eager/c_api_experimental_reader_test.cc @@ -0,0 +1,86 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/c/eager/c_api_experimental_reader.h" + +#include + +#include "tensorflow/c/eager/c_api_experimental.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { + +TFE_MonitoringCounter0* CreateCounter0(const char* counter_name); +TFE_MonitoringCounter1* CreateCounter1(const char* counter_name, + const char* label); +void IncrementCounter0(TFE_MonitoringCounter0* counter, int64_t delta = 1); +void IncrementCounter1(TFE_MonitoringCounter1* counter, const char* label, + int64_t delta = 1); + +TEST(CAPI, MonitoringCellReader0) { + auto counter_name = "test/counter0"; + auto* counter = CreateCounter0(counter_name); + auto* reader = TFE_MonitoringNewCounterReader(counter_name); + IncrementCounter0(counter); + + int64_t actual = TFE_MonitoringReadCounter0(reader); + + CHECK_EQ(actual, 1); +} + +TEST(CAPI, MonitoringCellReader1) { + auto counter_name = "test/counter1"; + auto label_name = "test/label"; + auto* counter = CreateCounter1(counter_name, label_name); + auto* reader = TFE_MonitoringNewCounterReader(counter_name); + IncrementCounter1(counter, label_name); + + int64_t actual = TFE_MonitoringReadCounter1(reader, label_name); + + CHECK_EQ(actual, 1); +} + +TFE_MonitoringCounter0* CreateCounter0(const char* counter_name) { + TF_Status* status = TF_NewStatus(); + auto* counter = + TFE_MonitoringNewCounter0(counter_name, status, "description"); + TF_DeleteStatus(status); + return counter; +} + +void IncrementCounter0(TFE_MonitoringCounter0* counter, int64_t delta) { + auto* cell = TFE_MonitoringGetCellCounter0(counter); + TFE_MonitoringCounterCellIncrementBy(cell, delta); +} + +TFE_MonitoringCounter1* CreateCounter1(const char* counter_name, + const char* label) { + TF_Status* status = TF_NewStatus(); + auto* counter = + TFE_MonitoringNewCounter1(counter_name, status, "description", label); + TF_DeleteStatus(status); + return counter; +} + +void IncrementCounter1(TFE_MonitoringCounter1* counter, const char* label, + int64_t delta) { + auto* cell = TFE_MonitoringGetCellCounter1(counter, label); + TFE_MonitoringCounterCellIncrementBy(cell, delta); +} + +} // namespace +} // 
namespace tensorflow diff --git a/tensorflow/c/eager/c_api_test_util.cc b/tensorflow/c/eager/c_api_test_util.cc index 8f600c5de8f..1fb76748059 100644 --- a/tensorflow/c/eager/c_api_test_util.cc +++ b/tensorflow/c/eager/c_api_test_util.cc @@ -15,6 +15,10 @@ limitations under the License. #include "tensorflow/c/eager/c_api_test_util.h" +#include +#include +#include + #include "tensorflow/c/eager/c_api.h" #include "tensorflow/c/eager/c_api_experimental.h" #include "tensorflow/c/tf_datatype.h" @@ -434,6 +438,8 @@ tensorflow::ServerDef GetServerDef(const string& job_name, int num_tasks) { int port = tensorflow::testing::PickUnusedPortOrDie(); job_def->mutable_tasks()->insert( {i, tensorflow::strings::StrCat("localhost:", port)}); + LOG(INFO) << "Picked test port: " << port << " for job: " << job_name + << ", task: " << i; } return server_def; } diff --git a/tensorflow/c/eager/c_api_unified_experimental.cc b/tensorflow/c/eager/c_api_unified_experimental.cc index 543976b4a6b..53f340ee2aa 100644 --- a/tensorflow/c/eager/c_api_unified_experimental.cc +++ b/tensorflow/c/eager/c_api_unified_experimental.cc @@ -76,7 +76,7 @@ static TracingContext* CreateTracingExecutionContext(const char* fn_name, if (default_factory) { return default_factory(fn_name, s); } - Set_TF_Status_from_Status( + tsl::Set_TF_Status_from_Status( s, errors::FailedPrecondition("default_factory is nullptr")); return nullptr; } @@ -109,7 +109,7 @@ using tensorflow::tracing::TracingOperation; using tensorflow::tracing::TracingTensorHandle; void TF_SetTracingImplementation(const char* name, TF_Status* s) { - Set_TF_Status_from_Status(s, SetDefaultTracingEngine(name)); + tsl::Set_TF_Status_from_Status(s, SetDefaultTracingEngine(name)); } // Creates a new TensorFlow function, it is an execution context attached to a @@ -123,12 +123,13 @@ TF_AbstractFunction* TF_FinalizeFunction(TF_ExecutionContext* ctx, AbstractFunction* func; TracingContext* tracing_ctx = dyn_cast(unwrap(ctx)); if (!tracing_ctx) { - Set_TF_Status_from_Status( + tsl::Set_TF_Status_from_Status( s, tensorflow::errors::InvalidArgument( "Only TracingContext can be converted into a function.")); return nullptr; } - Set_TF_Status_from_Status(s, tracing_ctx->Finalize(unwrap(outputs), &func)); + tsl::Set_TF_Status_from_Status(s, + tracing_ctx->Finalize(unwrap(outputs), &func)); TF_DeleteExecutionContext(ctx); return wrap(func); } @@ -140,7 +141,7 @@ TF_AbstractTensor* TF_AddFunctionParameter(TF_ExecutionContext* func, TracingTensorHandle* t; TracingContext* tracing_ctx = dyn_cast(unwrap(func)); if (!tracing_ctx) { - Set_TF_Status_from_Status( + tsl::Set_TF_Status_from_Status( s, tensorflow::errors::InvalidArgument( "TF_AddFunctionParameter must be called on a TracingContext.")); return nullptr; @@ -152,11 +153,11 @@ TF_AbstractTensor* TF_AddFunctionParameter(TF_ExecutionContext* func, reinterpret_cast(shape.dim_sizes), shape.num_dims, &partial_shape); if (!status.ok()) { - Set_TF_Status_from_Status(s, status); + tsl::Set_TF_Status_from_Status(s, status); return nullptr; } } - Set_TF_Status_from_Status( + tsl::Set_TF_Status_from_Status( s, tracing_ctx->AddParameter(static_cast(dtype), partial_shape, &t)); return wrap(t); @@ -193,20 +194,21 @@ void TF_OutputListPushBack(TF_OutputList* o, TF_AbstractTensor* tensor, void TF_AbstractOpSetOpType(TF_AbstractOp* op, const char* const op_type, TF_Status* s) { - Set_TF_Status_from_Status(s, unwrap(op)->Reset(op_type, - /*raw_device_name=*/nullptr)); + tsl::Set_TF_Status_from_Status( + s, unwrap(op)->Reset(op_type, + 
/*raw_device_name=*/nullptr)); } void TF_AbstractOpSetOpName(TF_AbstractOp* op, const char* const op_name, TF_Status* s) { TracingOperation* tracing_op = dyn_cast(unwrap(op)); if (!tracing_op) { - Set_TF_Status_from_Status( + tsl::Set_TF_Status_from_Status( s, tensorflow::errors::InvalidArgument( "TF_AbstractOpSetOpName must be called on a TracingOperation.")); return; } - Set_TF_Status_from_Status(s, tracing_op->SetOpName(op_name)); + tsl::Set_TF_Status_from_Status(s, tracing_op->SetOpName(op_name)); } void TF_AbstractOpSetAttrType(TF_AbstractOp* op, const char* const attr_name, @@ -214,20 +216,20 @@ void TF_AbstractOpSetAttrType(TF_AbstractOp* op, const char* const attr_name, Status status = unwrap(op)->SetAttrType(attr_name, static_cast(value)); TF_SetStatus(s, static_cast(status.code()), - status.error_message().c_str()); + tsl::NullTerminatedMessage(status)); } void TF_ExecuteOperation(TF_AbstractOp* op, int num_inputs, TF_AbstractTensor* const* inputs, TF_OutputList* o, TF_Status* s) { for (int i = 0; i < num_inputs; i++) { - Set_TF_Status_from_Status(s, unwrap(op)->AddInput(unwrap(inputs[i]))); + tsl::Set_TF_Status_from_Status(s, unwrap(op)->AddInput(unwrap(inputs[i]))); if (TF_GetCode(s) != TF_OK) { return; } } int num_outputs = unwrap(o)->expected_num_outputs; - Set_TF_Status_from_Status( + tsl::Set_TF_Status_from_Status( s, unwrap(op)->Execute( absl::MakeSpan(reinterpret_cast( unwrap(o)->outputs.data()), @@ -242,5 +244,6 @@ void TF_DeleteAbstractFunction(TF_AbstractFunction* func) { void TF_ExecutionContextRegisterFunction(TF_ExecutionContext* ctx, TF_AbstractFunction* func, TF_Status* s) { - Set_TF_Status_from_Status(s, unwrap(ctx)->RegisterFunction(unwrap(func))); + tsl::Set_TF_Status_from_Status(s, + unwrap(ctx)->RegisterFunction(unwrap(func))); } diff --git a/tensorflow/c/eager/c_api_unified_experimental_graph.cc b/tensorflow/c/eager/c_api_unified_experimental_graph.cc index 7d36cb0ad12..af8797c2932 100644 --- a/tensorflow/c/eager/c_api_unified_experimental_graph.cc +++ b/tensorflow/c/eager/c_api_unified_experimental_graph.cc @@ -14,6 +14,7 @@ limitations under the License. 
==============================================================================*/ #include +#include #include #include "absl/strings/str_cat.h" @@ -204,7 +205,7 @@ class GraphOperation : public TracingOperation { Status SetAttrType(const char* const attr_name, DataType value) override { if (!op_) { return Status( - error::Code::FAILED_PRECONDITION, + absl::StatusCode::kFailedPrecondition, "op_type and op_name must be specified before specifying attrs."); } op_->node_builder.Attr(attr_name, value); @@ -387,7 +388,7 @@ class GraphContext : public TracingContext { inputs_.size(), inputs_.data(), graph_outputs.size(), graph_outputs.data(), nullptr, nullptr, name_.data(), s); - *f = new GraphFunction(std::move(func->fdef)); + *f = new GraphFunction(std::move(func->record->fdef())); TF_DeleteFunction(func); TF_RETURN_IF_ERROR(StatusFromTF_Status(s)); TF_DeleteStatus(s); diff --git a/tensorflow/c/eager/c_api_unified_experimental_test.cc b/tensorflow/c/eager/c_api_unified_experimental_test.cc index 4814344e405..edaf3d8e579 100644 --- a/tensorflow/c/eager/c_api_unified_experimental_test.cc +++ b/tensorflow/c/eager/c_api_unified_experimental_test.cc @@ -47,7 +47,7 @@ class UnifiedCAPI TF_StatusPtr status(TF_NewStatus()); TF_SetTracingImplementation(std::get<0>(GetParam()), status.get()); Status s = StatusFromTF_Status(status.get()); - CHECK_EQ(errors::OK, s.code()) << s.error_message(); + CHECK_EQ(errors::OK, s.code()) << s.message(); } }; diff --git a/tensorflow/c/eager/gradient_checker_test.cc b/tensorflow/c/eager/gradient_checker_test.cc index 4a688cec241..e012b29e93f 100644 --- a/tensorflow/c/eager/gradient_checker_test.cc +++ b/tensorflow/c/eager/gradient_checker_test.cc @@ -41,13 +41,13 @@ void CompareNumericalAndManualGradients( AbstractTensorHandle* numerical_grad_raw; s = CalcNumericalGrad(ctx, model, inputs, input_index, use_function, &numerical_grad_raw); - ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + ASSERT_EQ(errors::OK, s.code()) << s.message(); numerical_grad.reset(numerical_grad_raw); } TF_Tensor* numerical_tensor; s = GetValue(numerical_grad.get(), &numerical_tensor); - ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + ASSERT_EQ(errors::OK, s.code()) << s.message(); auto num_elem_numerical = TF_TensorElementCount(numerical_tensor); ASSERT_EQ(num_elem_numerical, num_grad); @@ -90,14 +90,14 @@ class GradientCheckerTest { Status s = StatusFromTF_Status(status.get()); - CHECK_EQ(errors::OK, s.code()) << s.error_message(); + CHECK_EQ(errors::OK, s.code()) << s.message(); } { AbstractContext* ctx_raw = nullptr; Status s = BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw); - ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + ASSERT_EQ(errors::OK, s.code()) << s.message(); ctx_.reset(ctx_raw); } @@ -122,7 +122,7 @@ TEST_P(GradientCheckerTest, TestMatMul) { AbstractTensorHandle* A_raw; Status s = TestTensorHandleWithDims(ctx_.get(), A_vals, A_dims, 2, &A_raw); - ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + ASSERT_EQ(errors::OK, s.code()) << s.message(); A.reset(A_raw); } float B_vals[] = {.5f, -1.0f, 1.0f, 1.0f}; @@ -132,7 +132,7 @@ TEST_P(GradientCheckerTest, TestMatMul) { AbstractTensorHandle* B_raw; Status s = TestTensorHandleWithDims(ctx_.get(), B_vals, B_dims, 2, &B_raw); - ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + ASSERT_EQ(errors::OK, s.code()) << s.message(); B.reset(B_raw); } @@ -148,7 +148,7 @@ TEST_P(GradientCheckerTest, TestMul) { AbstractTensorHandle* x_raw = nullptr; Status s = TestScalarTensorHandle(ctx_.get(), 
2.0f, &x_raw); - ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + ASSERT_EQ(errors::OK, s.code()) << s.message(); x.reset(x_raw); } @@ -157,7 +157,7 @@ TEST_P(GradientCheckerTest, TestMul) { AbstractTensorHandle* y_raw = nullptr; Status s = TestScalarTensorHandle(ctx_.get(), 7.0f, &y_raw); - ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + ASSERT_EQ(errors::OK, s.code()) << s.message(); y.reset(y_raw); } diff --git a/tensorflow/c/eager/gradients_test.cc b/tensorflow/c/eager/gradients_test.cc index a24e97f9981..a345240e8c3 100644 --- a/tensorflow/c/eager/gradients_test.cc +++ b/tensorflow/c/eager/gradients_test.cc @@ -53,7 +53,7 @@ class CppGradients TF_StatusPtr status(TF_NewStatus()); TF_SetTracingImplementation(std::get<0>(GetParam()), status.get()); Status s = StatusFromTF_Status(status.get()); - CHECK_EQ(errors::OK, s.code()) << s.error_message(); + CHECK_EQ(errors::OK, s.code()) << s.message(); } }; @@ -70,7 +70,7 @@ TEST_P(CppGradients, TestSetAttrString) { AbstractContext* ctx_raw = nullptr; Status s = BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw); - ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + ASSERT_EQ(errors::OK, s.code()) << s.message(); ctx.reset(ctx_raw); } @@ -78,7 +78,7 @@ TEST_P(CppGradients, TestSetAttrString) { { AbstractTensorHandle* x_raw = nullptr; Status s = TestScalarTensorHandle(ctx.get(), 1.0f, &x_raw); - ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + ASSERT_EQ(errors::OK, s.code()) << s.message(); t.reset(x_raw); } @@ -86,31 +86,31 @@ TEST_P(CppGradients, TestSetAttrString) { ForwardOperation forward_op; Status s = Reset(check_numerics_op.get(), "CheckNumerics", /*raw_device_name=*/nullptr, &forward_op); - ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + ASSERT_EQ(errors::OK, s.code()) << s.message(); if (isa(check_numerics_op.get())) { s = dyn_cast(check_numerics_op.get()) ->SetOpName("check_numerics"); - ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + ASSERT_EQ(errors::OK, s.code()) << s.message(); } s = AddInput(check_numerics_op.get(), t.get(), &forward_op); - ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + ASSERT_EQ(errors::OK, s.code()) << s.message(); string message = "This is the way!"; s = SetAttrString(check_numerics_op.get(), "message", message.data(), message.length(), &forward_op); - ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + ASSERT_EQ(errors::OK, s.code()) << s.message(); int num_retvals = 1; std::vector outputs(1); GradientRegistry registry; s = RegisterGradients(®istry); - ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + ASSERT_EQ(errors::OK, s.code()) << s.message(); auto tape = std::make_unique(/*persistent=*/false); s = Execute(check_numerics_op.get(), ctx.get(), absl::MakeSpan(outputs), &num_retvals, &forward_op, tape.get(), registry); - ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + ASSERT_EQ(errors::OK, s.code()) << s.message(); string read_message; s = forward_op.attrs.Get("message", &read_message); - ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + ASSERT_EQ(errors::OK, s.code()) << s.message(); ASSERT_EQ(read_message, message); } @@ -136,7 +136,7 @@ TEST_P(CppGradients, TestRecordOperationWithNullGradientFunctionRaises) { AbstractContext* ctx_raw = nullptr; Status s = BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw); - ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + ASSERT_EQ(errors::OK, s.code()) << s.message(); ctx.reset(ctx_raw); } @@ -144,7 +144,7 @@ TEST_P(CppGradients, 
TestRecordOperationWithNullGradientFunctionRaises) { { AbstractTensorHandle* x_raw = nullptr; Status s = TestScalarTensorHandle(ctx.get(), 2.0f, &x_raw); - ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + ASSERT_EQ(errors::OK, s.code()) << s.message(); x.reset(x_raw); } @@ -157,7 +157,7 @@ TEST_P(CppGradients, TestRecordOperationWithNullGradientFunctionRaises) { "Provided null gradient_function for 'Neg'.\nIf the intent is to treat " "this op as non-differentiable consider using RegisterNotDifferentiable " "or NotDifferentiableGradientFunction.", - s.error_message()); + s.message()); ASSERT_EQ(nullptr, outputs[0]); } diff --git a/tensorflow/c/eager/immediate_execution_context.h b/tensorflow/c/eager/immediate_execution_context.h index 930a26bb120..f4eb7a05367 100644 --- a/tensorflow/c/eager/immediate_execution_context.h +++ b/tensorflow/c/eager/immediate_execution_context.h @@ -134,6 +134,10 @@ class ImmediateExecutionContext : public AbstractContext { // Find and return a added function by its name. virtual const FunctionDef* FindFunctionDef(const string& name) const = 0; + // Find and return a function record added by its name. + virtual core::RefCountPtr FindRecord( + const string& name) const = 0; + // Return the ParsedName of Host CPU device. virtual const DeviceNameUtils::ParsedName& HostCPUParsedName() const = 0; virtual const string& HostCPUName() const = 0; @@ -249,6 +253,7 @@ class ImmediateExecutionContext : public AbstractContext { int64_t kernel_cache_size; int64_t device_cache_size; std::map func_kernel_cache_entries; + int64_t local_rendezvous_cache_active_size; }; virtual CacheStats GetCacheStats() = 0; diff --git a/tensorflow/c/eager/parallel_device/BUILD b/tensorflow/c/eager/parallel_device/BUILD index 0de029ff449..6195d5d1e85 100644 --- a/tensorflow/c/eager/parallel_device/BUILD +++ b/tensorflow/c/eager/parallel_device/BUILD @@ -77,6 +77,7 @@ cc_library( visibility = ["//tensorflow:internal"], deps = [ "//tensorflow/c:c_api", + "//tensorflow/c:safe_ptr", "//tensorflow/c:tf_status_internal", "//tensorflow/c/eager:c_api", "//tensorflow/c/eager:c_api_experimental", diff --git a/tensorflow/c/eager/parallel_device/parallel_device.cc b/tensorflow/c/eager/parallel_device/parallel_device.cc index 71a5c46b7ea..8c51559d3f7 100644 --- a/tensorflow/c/eager/parallel_device/parallel_device.cc +++ b/tensorflow/c/eager/parallel_device/parallel_device.cc @@ -211,7 +211,7 @@ int ParallelTensorNumDims(void* data, TF_Status* status) { const std::vector* shape; Status s = reinterpret_cast(data)->Shape(&shape); if (!s.ok()) { - Set_TF_Status_from_Status(status, s); + tsl::Set_TF_Status_from_Status(status, s); return -1; } return shape->size(); @@ -223,7 +223,7 @@ int64_t ParallelTensorDim(void* data, int dim_index, TF_Status* status) { const std::vector* shape; Status s = reinterpret_cast(data)->Shape(&shape); if (!s.ok()) { - Set_TF_Status_from_Status(status, s); + tsl::Set_TF_Status_from_Status(status, s); return -1; } return (*shape)[dim_index]; @@ -234,7 +234,7 @@ TF_Buffer* ParallelTensorSummarize(void* data, TF_Status* status) { std::string summary; Status cpp_status = parallel_tensor->SummarizeValue(summary); if (!cpp_status.ok()) { - Set_TF_Status_from_Status(status, cpp_status); + tsl::Set_TF_Status_from_Status(status, cpp_status); return nullptr; } return TF_NewBufferFromString(summary.data(), summary.size()); diff --git a/tensorflow/c/eager/parallel_device/parallel_device_lib.cc b/tensorflow/c/eager/parallel_device/parallel_device_lib.cc index d3ff26b0a74..0522ad3b730 100644 
--- a/tensorflow/c/eager/parallel_device/parallel_device_lib.cc +++ b/tensorflow/c/eager/parallel_device/parallel_device_lib.cc @@ -368,6 +368,27 @@ void ParallelDevice::StartExecute(TFE_Context* context, } } +void ParallelDevice::StartExecute( + TFE_Context* context, + const std::vector>& inputs, + const char* operation_name, const TFE_OpAttrs* attributes, + int expected_max_outputs, CancellationManager& cancellation_manager, + absl::optional step_id) const { + for (int device_index = 0; device_index < underlying_devices_.size(); + ++device_index) { + DeviceThread* device_thread = device_threads_[device_index].get(); + std::vector device_inputs; + device_inputs.reserve(inputs.size()); + for (int input_index = 0; input_index < inputs.size(); ++input_index) { + // Parallel tensors are divided between operations by device. + device_inputs.push_back(inputs[input_index][device_index]); + } + device_thread->StartExecute( + context, operation_name, std::move(device_inputs), attributes, + expected_max_outputs, cancellation_manager, step_id); + } +} + void ParallelDevice::AsyncWait(TFE_Context* context, TF_Status* status) const { StatusPtr first_bad_status(nullptr); @@ -486,6 +507,11 @@ std::unique_ptr ParallelTensor::FromTensorHandles( const ParallelDevice& parallel_device, std::vector components, absl::Span shape, TF_Status* status) { + if (components.empty()) { + TF_SetStatus(status, TF_INTERNAL, + "No components are provide for creating a ParallelTensor"); + return nullptr; + } TFE_TensorHandleGetStatus(components[0].get(), status); if (!status->status.ok()) { return nullptr; @@ -513,6 +539,11 @@ std::unique_ptr ParallelTensor::FromTensorHandles( std::unique_ptr ParallelTensor::FromTensorHandles( const ParallelDevice& parallel_device, std::vector components, TF_Status* status) { + if (components.empty()) { + TF_SetStatus(status, TF_INTERNAL, + "No components are provided for creating a ParallelTensor"); + return nullptr; + } TFE_TensorHandleGetStatus(components[0].get(), status); if (!status->status.ok()) { return nullptr; diff --git a/tensorflow/c/eager/parallel_device/parallel_device_lib.h b/tensorflow/c/eager/parallel_device/parallel_device_lib.h index 4b87ad4c106..b1b96d3b410 100644 --- a/tensorflow/c/eager/parallel_device/parallel_device_lib.h +++ b/tensorflow/c/eager/parallel_device/parallel_device_lib.h @@ -19,6 +19,7 @@ limitations under the License. #include #include #include +#include #include #include "absl/types/optional.h" @@ -28,6 +29,7 @@ limitations under the License. #include "tensorflow/c/eager/c_api.h" #include "tensorflow/c/eager/c_api_experimental.h" #include "tensorflow/c/eager/tfe_op_internal.h" +#include "tensorflow/c/safe_ptr.h" #include "tensorflow/core/framework/cancellation.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/types.h" @@ -35,19 +37,7 @@ limitations under the License. namespace tensorflow { namespace parallel_device { -// Functor for making unique_ptrs slightly more ergonomic. Using -// decltype(delete_fn) in the unique_ptr's second template argument requires -// passing a function pointer to delete_fn when constructing the unique_ptr. -class TensorHandleDeleter { - public: - void operator()(TFE_TensorHandle* to_delete) const { - TFE_DeleteTensorHandle(to_delete); - } -}; - -// TODO(b/256016071): Replace this with `Safe_TFE_TensorHandlePtr` when -// `Safe_TFE_TensorHandlePtr` is marked to be compatible on non-prod env. 
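The new StartExecute overload takes its inputs in input-major order: inputs[i][d] is input i's component on device d, and each per-device column is handed to the matching device thread. A minimal caller-side sketch, assuming the handle, attribute, and ParallelDevice variables already exist and using an illustrative op name:

  // Two inputs, each pre-split across the two underlying devices.
  std::vector<std::vector<TFE_TensorHandle*>> inputs = {
      {x_on_dev0, x_on_dev1},  // input 0: one component per device
      {y_on_dev0, y_on_dev1},  // input 1
  };
  tensorflow::CancellationManager cancellation_manager;
  parallel_device.StartExecute(context, inputs, "AddV2", attributes,
                               /*expected_max_outputs=*/1,
                               cancellation_manager);
  // A subsequent Join (not shown here) collects one ParallelTensor per output.

The empty-components checks added to both ParallelTensor::FromTensorHandles overloads cover the reverse direction: wrapping per-device handles into a ParallelTensor now reports TF_INTERNAL for an empty vector instead of dereferencing components[0].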
-using TensorHandlePtr = std::unique_ptr; +using TensorHandlePtr = tensorflow::Safe_TFE_TensorHandlePtr; class ParallelTensor; class DeviceThread; @@ -128,6 +118,13 @@ class ParallelDevice { CancellationManager& cancellation_manager, std::optional step_id = std::nullopt) const; + void StartExecute(TFE_Context* context, + const std::vector>& inputs, + const char* operation_name, const TFE_OpAttrs* attributes, + int expected_max_outputs, + CancellationManager& cancellation_manager, + std::optional step_id = std::nullopt) const; + // Blocks until the previous `StartExecute` has run `TFE_Execute` on each // device. If is_async=false (constructor argument) this means the ops have // run and have results. If is_async=true it means that all of the @@ -206,6 +203,17 @@ class ParallelTensor { // component device. Status SummarizeValue(std::string& summary); + std::vector release_tensors() { return std::move(tensors_); } + + std::vector tensors() const { + std::vector result; + result.reserve(tensors_.size()); + for (const TensorHandlePtr& tensor : tensors_) { + result.emplace_back(tensor.get()); + } + return result; + } + private: ParallelTensor(const ParallelDevice& device, std::vector tensors, @@ -222,7 +230,7 @@ class ParallelTensor { dtype_(dtype) {} const ParallelDevice& device_; - const std::vector tensors_; + std::vector tensors_; // Parallel tensors are immutable but compute their shape lazily unless it is // provided on construction. The optional has a value if the lazy computation // has been completed or the shape was provided on construction. diff --git a/tensorflow/c/eager/parallel_device/parallel_device_remote_test.cc b/tensorflow/c/eager/parallel_device/parallel_device_remote_test.cc index 41d6f14e068..9f157ae760e 100644 --- a/tensorflow/c/eager/parallel_device/parallel_device_remote_test.cc +++ b/tensorflow/c/eager/parallel_device/parallel_device_remote_test.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include +#include #include #include "tensorflow/c/c_api.h" @@ -37,6 +38,8 @@ tensorflow::ServerDef GetServerDef(const std::string& job_name, int num_tasks) { int port = tensorflow::testing::PickUnusedPortOrDie(); job_def->mutable_tasks()->insert( {i, tensorflow::strings::StrCat("localhost", ":", port)}); + LOG(INFO) << "Picked test port: " << port << " for job: " << job_name + << ", task: " << i; } return server_def; } diff --git a/tensorflow/c/eager/tfe_monitoring_reader_internal.h b/tensorflow/c/eager/tfe_monitoring_reader_internal.h new file mode 100644 index 00000000000..3c63e6725f1 --- /dev/null +++ b/tensorflow/c/eager/tfe_monitoring_reader_internal.h @@ -0,0 +1,34 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_C_EAGER_TFE_MONITORING_READER_INTERNAL_H_ +#define TENSORFLOW_C_EAGER_TFE_MONITORING_READER_INTERNAL_H_ + +#include + +#include "tensorflow/core/lib/monitoring/cell_reader.h" + +struct TFE_MonitoringCounterReader { + explicit TFE_MonitoringCounterReader(const char* name) { + counter = std::make_unique< + ::tensorflow::monitoring::testing::CellReader>(name); + } + template + int64_t Read(const LabelType&... labels); + std::unique_ptr<::tensorflow::monitoring::testing::CellReader> + counter; +}; + +#endif // TENSORFLOW_C_EAGER_TFE_MONITORING_READER_INTERNAL_H_ diff --git a/tensorflow/c/eager/unified_api_test.cc b/tensorflow/c/eager/unified_api_test.cc index a9204f4462c..27e42be5bcc 100644 --- a/tensorflow/c/eager/unified_api_test.cc +++ b/tensorflow/c/eager/unified_api_test.cc @@ -30,7 +30,7 @@ class UnifiedAPI TF_StatusPtr status(TF_NewStatus()); TF_SetTracingImplementation(std::get<0>(GetParam()), status.get()); Status s = StatusFromTF_Status(status.get()); - CHECK_EQ(errors::OK, s.code()) << s.error_message(); + CHECK_EQ(errors::OK, s.code()) << s.message(); } public: @@ -61,7 +61,7 @@ TEST_P(UnifiedAPI, TestTensorShapeScalar) { AbstractContext* ctx_raw = nullptr; Status s = BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw); - ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + ASSERT_EQ(errors::OK, s.code()) << s.message(); ctx.reset(ctx_raw); } @@ -69,7 +69,7 @@ TEST_P(UnifiedAPI, TestTensorShapeScalar) { { AbstractTensorHandle* x_raw = nullptr; Status s = TestScalarTensorHandle(ctx.get(), 2.0f, &x_raw); - ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + ASSERT_EQ(errors::OK, s.code()) << s.message(); x.reset(x_raw); } @@ -77,7 +77,7 @@ TEST_P(UnifiedAPI, TestTensorShapeScalar) { /*inputs=*/{x.get()}, /*outputs=*/{}, /*use_function=*/UseFunction()); - ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + ASSERT_EQ(errors::OK, s.code()) << s.message(); } // Checks that inputs[0] is a matrix with shape 2x4. 
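The error_message() to message() replacements running through these test files track the Status API change: message() returns a string piece rather than a std::string reference, so streaming it into LOG, CHECK, or ASSERT remains a one-token change, while call sites that hand the text to a C API expecting a NUL-terminated char* switch to tsl::NullTerminatedMessage instead (as in the TF_AbstractOpSetAttrType hunk earlier). A small sketch of the two patterns, with the status-producing call and the TF_Status* variable assumed:

  Status s = SomeTensorFlowCall();  // hypothetical call returning a Status
  if (!s.ok()) {
    LOG(ERROR) << "call failed: " << s.message();  // streaming: message() suffices
  }
  TF_SetStatus(c_status, static_cast<TF_Code>(s.code()),
               tsl::NullTerminatedMessage(s));  // C string needed: use the helper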
@@ -111,7 +111,7 @@ TEST_P(UnifiedAPI, TestTensorShape2x4) { AbstractContext* ctx_raw = nullptr; Status s = BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw); - ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + ASSERT_EQ(errors::OK, s.code()) << s.message(); ctx.reset(ctx_raw); } @@ -122,7 +122,7 @@ TEST_P(UnifiedAPI, TestTensorShape2x4) { int64_t dim_sizes[] = {2, 4}; Status s = TestTensorHandleWithDims(ctx.get(), data, dim_sizes, 2, &x_raw); - ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + ASSERT_EQ(errors::OK, s.code()) << s.message(); x.reset(x_raw); } @@ -130,7 +130,7 @@ TEST_P(UnifiedAPI, TestTensorShape2x4) { /*inputs=*/{x.get()}, /*outputs=*/{}, /*use_function=*/UseFunction()); - ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + ASSERT_EQ(errors::OK, s.code()) << s.message(); } TEST_P(UnifiedAPI, TestUnknownShapeTracing) { @@ -148,13 +148,13 @@ TEST_P(UnifiedAPI, TestUnknownShapeTracing) { PartialTensorShape shape; Status s = dyn_cast(ctx.get())->AddParameter( DT_FLOAT, shape, &x_raw); - ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + ASSERT_EQ(errors::OK, s.code()) << s.message(); x.reset(x_raw); } PartialTensorShape shape; Status s = x->Shape(&shape); - ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + ASSERT_EQ(errors::OK, s.code()) << s.message(); ASSERT_TRUE(shape.unknown_rank()); } @@ -172,16 +172,16 @@ TEST_P(UnifiedAPI, TestPartialShapeTracing) { PartialTensorShape shape; int64_t dim_sizes[] = {2, -1}; Status s = PartialTensorShape::MakePartialShape(dim_sizes, 2, &shape); - ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + ASSERT_EQ(errors::OK, s.code()) << s.message(); s = dyn_cast(ctx.get())->AddParameter( DT_FLOAT, shape, &x_raw); - ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + ASSERT_EQ(errors::OK, s.code()) << s.message(); x.reset(x_raw); } PartialTensorShape shape; Status s = x->Shape(&shape); - ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + ASSERT_EQ(errors::OK, s.code()) << s.message(); ASSERT_FALSE(shape.unknown_rank()); ASSERT_EQ(2, shape.dim_size(0)); diff --git a/tensorflow/c/experimental/gradients/BUILD b/tensorflow/c/experimental/gradients/BUILD index 1788cbd6551..65f580deee9 100644 --- a/tensorflow/c/experimental/gradients/BUILD +++ b/tensorflow/c/experimental/gradients/BUILD @@ -178,9 +178,9 @@ tf_cuda_cc_test( "//tensorflow/c/eager:unified_api_testutil", "//tensorflow/c/experimental/gradients/tape:tape_context", "//tensorflow/c/experimental/ops:nn_ops", - "//tensorflow/core/platform:tensor_float_32_utils", "//tensorflow/core:test", "//tensorflow/core:test_main", + "//tensorflow/core/platform:tensor_float_32_utils", ] + if_libtpu( if_false = ["//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration"], if_true = [], @@ -204,9 +204,9 @@ tf_cuda_cc_test( "//tensorflow/c/eager:unified_api_testutil", "//tensorflow/c/experimental/gradients/tape:tape_context", "//tensorflow/c/experimental/ops:math_ops", - "//tensorflow/core/platform:tensor_float_32_utils", "//tensorflow/core:test", "//tensorflow/core:test_main", + "//tensorflow/core/platform:tensor_float_32_utils", ] + if_libtpu( if_false = ["//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration"], if_true = [], @@ -222,17 +222,17 @@ tf_cuda_cc_test( args = ["--heap_check="], # TODO(b/174752220): Remove tags = tf_cuda_tests_tags() + ["no_cuda_asan"], # b/173654156, deps = [ - ":grad_test_helper", ":array_grad", + ":grad_test_helper", "//tensorflow/c:tf_status_helper", "//tensorflow/c/eager:c_api_test_util", - 
"//tensorflow/c/experimental/gradients/tape:tape_context", - "//tensorflow/c/experimental/ops:array_ops", - "//tensorflow/core/platform:tensor_float_32_utils", - "//tensorflow/core:test", - "//tensorflow/core:test_main", "//tensorflow/c/eager:c_api_unified_internal", "//tensorflow/c/eager:unified_api_testutil", + "//tensorflow/c/experimental/gradients/tape:tape_context", + "//tensorflow/c/experimental/ops:array_ops", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core/platform:tensor_float_32_utils", ] + if_libtpu( if_false = ["//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration"], if_true = [], diff --git a/tensorflow/c/experimental/gradients/array_grad_test.cc b/tensorflow/c/experimental/gradients/array_grad_test.cc index 61c0bce6664..fcaafd693e1 100644 --- a/tensorflow/c/experimental/gradients/array_grad_test.cc +++ b/tensorflow/c/experimental/gradients/array_grad_test.cc @@ -51,13 +51,13 @@ class CppGradients TF_StatusPtr status(TF_NewStatus()); TF_SetTracingImplementation(std::get<0>(GetParam()), status.get()); status_ = StatusFromTF_Status(status.get()); - ASSERT_EQ(errors::OK, status_.code()) << status_.error_message(); + ASSERT_EQ(errors::OK, status_.code()) << status_.message(); { AbstractContext* ctx_raw = nullptr; status_ = BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw); - ASSERT_EQ(errors::OK, status_.code()) << status_.error_message(); + ASSERT_EQ(errors::OK, status_.code()) << status_.message(); immediate_execution_ctx_.reset(ctx_raw); } @@ -86,7 +86,7 @@ TEST_P(CppGradients, TestIdentityNGrad) { AbstractTensorHandle* x1_raw = nullptr; status_ = TestScalarTensorHandle( immediate_execution_ctx_.get(), 1.0f, &x1_raw); - ASSERT_EQ(errors::OK, status_.code()) << status_.error_message(); + ASSERT_EQ(errors::OK, status_.code()) << status_.message(); x1.reset(x1_raw); } @@ -95,19 +95,19 @@ TEST_P(CppGradients, TestIdentityNGrad) { AbstractTensorHandle* x2_raw = nullptr; status_ = TestScalarTensorHandle( immediate_execution_ctx_.get(), 1.0f, &x2_raw); - ASSERT_EQ(errors::OK, status_.code()) << status_.error_message(); + ASSERT_EQ(errors::OK, status_.code()) << status_.message(); x2.reset(x2_raw); } status_ = registry_.Register("IdentityN", IdentityNRegisterer); - ASSERT_EQ(errors::OK, status_.code()) << status_.error_message(); + ASSERT_EQ(errors::OK, status_.code()) << status_.message(); auto IdentityNGradModel = BuildGradModel(IdentityNModel, registry_); std::vector outputs(2); status_ = RunModel(IdentityNGradModel, immediate_execution_ctx_.get(), {x1.get(), x2.get()}, absl::MakeSpan(outputs), UseFunction()); - ASSERT_EQ(errors::OK, status_.code()) << status_.error_message(); + ASSERT_EQ(errors::OK, status_.code()) << status_.message(); EXPECT_EQ(outputs[0], nullptr); ASSERT_NO_FATAL_FAILURE(CheckTensorValue(outputs[1], {1.0f}, /*dims*/ {}, /*abs_error*/ 0)); diff --git a/tensorflow/c/experimental/gradients/custom_gradient_test.cc b/tensorflow/c/experimental/gradients/custom_gradient_test.cc index d447073b36a..cce9a051a74 100644 --- a/tensorflow/c/experimental/gradients/custom_gradient_test.cc +++ b/tensorflow/c/experimental/gradients/custom_gradient_test.cc @@ -38,7 +38,7 @@ class CustomGradientTest TF_StatusPtr status(TF_NewStatus()); TF_SetTracingImplementation(std::get<0>(GetParam()), status.get()); Status s = StatusFromTF_Status(status.get()); - CHECK_EQ(errors::OK, s.code()) << s.error_message(); + CHECK_EQ(errors::OK, s.code()) << s.message(); } }; @@ -92,7 +92,7 @@ TEST_P(CustomGradientTest, 
ExpWithPassThroughGrad) { AbstractContext* ctx_raw = nullptr; Status s = BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw); - ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + ASSERT_EQ(errors::OK, s.code()) << s.message(); ctx.reset(ctx_raw); } @@ -100,7 +100,7 @@ TEST_P(CustomGradientTest, ExpWithPassThroughGrad) { { AbstractTensorHandle* x_raw = nullptr; Status s = TestScalarTensorHandle(ctx.get(), 1.0f, &x_raw); - ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + ASSERT_EQ(errors::OK, s.code()) << s.message(); x.reset(x_raw); } @@ -113,11 +113,11 @@ TEST_P(CustomGradientTest, ExpWithPassThroughGrad) { Status s = RunModel(ExpWithPassThroughGrad, ctx.get(), {x.get()}, absl::MakeSpan(outputs), /*use_function=*/!std::get<2>(GetParam())); - ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + ASSERT_EQ(errors::OK, s.code()) << s.message(); TF_Tensor* result_tensor; s = GetValue(outputs[0], &result_tensor); - ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + ASSERT_EQ(errors::OK, s.code()) << s.message(); auto result_value = static_cast(TF_TensorData(result_tensor)); EXPECT_EQ(*result_value, 1.0); outputs[0]->Unref(); diff --git a/tensorflow/c/experimental/gradients/grad_test_helper.cc b/tensorflow/c/experimental/gradients/grad_test_helper.cc index 1bcb72175f7..a4b71ea6d3b 100644 --- a/tensorflow/c/experimental/gradients/grad_test_helper.cc +++ b/tensorflow/c/experimental/gradients/grad_test_helper.cc @@ -30,7 +30,7 @@ void CompareNumericalAndAutodiffGradients( std::vector outputs(num_inputs); auto s = RunModel(grad_model, ctx, inputs, absl::MakeSpan(outputs), /*use_function=*/use_function); - ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + ASSERT_EQ(errors::OK, s.code()) << s.message(); for (int i = 0; i < num_inputs; ++i) { if (!outputs[i]) continue; @@ -41,18 +41,18 @@ void CompareNumericalAndAutodiffGradients( s = CalcNumericalGrad(ctx, model, inputs, /*input_index=*/i, use_function, &numerical_grad_raw); - ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + ASSERT_EQ(errors::OK, s.code()) << s.message(); numerical_grad.reset(numerical_grad_raw); } TF_Tensor* numerical_tensor; s = GetValue(numerical_grad.get(), &numerical_tensor); - ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + ASSERT_EQ(errors::OK, s.code()) << s.message(); auto num_elem_numerical = TF_TensorElementCount(numerical_tensor); TF_Tensor* analytical_tensor; s = GetValue(outputs[i], &analytical_tensor); - ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + ASSERT_EQ(errors::OK, s.code()) << s.message(); auto num_elem_analytical = TF_TensorElementCount(analytical_tensor); ASSERT_EQ(num_elem_numerical, num_elem_analytical); @@ -79,7 +79,7 @@ void CheckTensorValue(AbstractTensorHandle* t, absl::Span manuals, absl::Span dims, double abs_error) { TF_Tensor* analytical_tensor; auto s = GetValue(t, &analytical_tensor); - ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + ASSERT_EQ(errors::OK, s.code()) << s.message(); int64_t num_elem_analytical = 1; auto num_dims_analytical = TF_NumDims(analytical_tensor); diff --git a/tensorflow/c/experimental/gradients/math_grad_test.cc b/tensorflow/c/experimental/gradients/math_grad_test.cc index c528fc1ae40..d0d08db8fd4 100644 --- a/tensorflow/c/experimental/gradients/math_grad_test.cc +++ b/tensorflow/c/experimental/gradients/math_grad_test.cc @@ -86,13 +86,13 @@ class CppGradients TF_StatusPtr status(TF_NewStatus()); TF_SetTracingImplementation(std::get<0>(GetParam()), status.get()); status_ = 
StatusFromTF_Status(status.get()); - ASSERT_EQ(errors::OK, status_.code()) << status_.error_message(); + ASSERT_EQ(errors::OK, status_.code()) << status_.message(); { AbstractContext* ctx_raw = nullptr; status_ = BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw); - ASSERT_EQ(errors::OK, status_.code()) << status_.error_message(); + ASSERT_EQ(errors::OK, status_.code()) << status_.message(); immediate_execution_ctx_.reset(ctx_raw); } @@ -117,7 +117,7 @@ TEST_P(CppGradients, TestAddGrad) { AbstractTensorHandle* x_raw = nullptr; status_ = TestScalarTensorHandle( immediate_execution_ctx_.get(), 2.0f, &x_raw); - ASSERT_EQ(errors::OK, status_.code()) << status_.error_message(); + ASSERT_EQ(errors::OK, status_.code()) << status_.message(); x.reset(x_raw); } @@ -126,14 +126,14 @@ TEST_P(CppGradients, TestAddGrad) { AbstractTensorHandle* y_raw = nullptr; status_ = TestScalarTensorHandle( immediate_execution_ctx_.get(), 2.0f, &y_raw); - ASSERT_EQ(errors::OK, status_.code()) << status_.error_message(); + ASSERT_EQ(errors::OK, status_.code()) << status_.message(); y.reset(y_raw); } // TODO(srbs): Rename ops::Add to ops::AddV2 and AddRegister to // AddV2Registerer. status_ = registry_.Register("AddV2", AddRegisterer); - ASSERT_EQ(errors::OK, status_.code()) << status_.error_message(); + ASSERT_EQ(errors::OK, status_.code()) << status_.message(); ASSERT_NO_FATAL_FAILURE(CompareNumericalAndAutodiffGradients( AddModel, BuildGradModel(AddModel, registry_), @@ -146,12 +146,12 @@ TEST_P(CppGradients, TestExpGrad) { AbstractTensorHandle* x_raw = nullptr; status_ = TestScalarTensorHandle( immediate_execution_ctx_.get(), 2.0f, &x_raw); - ASSERT_EQ(errors::OK, status_.code()) << status_.error_message(); + ASSERT_EQ(errors::OK, status_.code()) << status_.message(); x.reset(x_raw); } status_ = registry_.Register("Exp", ExpRegisterer); - ASSERT_EQ(errors::OK, status_.code()) << status_.error_message(); + ASSERT_EQ(errors::OK, status_.code()) << status_.message(); ASSERT_NO_FATAL_FAILURE(CompareNumericalAndAutodiffGradients( ExpModel, BuildGradModel(ExpModel, registry_), @@ -171,7 +171,7 @@ TEST_P(CppGradients, TestMatMulGrad) { AbstractTensorHandle* A_raw; status_ = TestTensorHandleWithDims( immediate_execution_ctx_.get(), A_vals, A_dims, 2, &A_raw); - ASSERT_EQ(errors::OK, status_.code()) << status_.error_message(); + ASSERT_EQ(errors::OK, status_.code()) << status_.message(); A.reset(A_raw); } @@ -182,12 +182,12 @@ TEST_P(CppGradients, TestMatMulGrad) { AbstractTensorHandle* B_raw; status_ = TestTensorHandleWithDims( immediate_execution_ctx_.get(), B_vals, B_dims, 2, &B_raw); - ASSERT_EQ(errors::OK, status_.code()) << status_.error_message(); + ASSERT_EQ(errors::OK, status_.code()) << status_.message(); B.reset(B_raw); } status_ = registry_.Register("MatMul", MatMulRegisterer); - ASSERT_EQ(errors::OK, status_.code()) << status_.error_message(); + ASSERT_EQ(errors::OK, status_.code()) << status_.message(); for (bool transpose_a : {false, true}) { for (bool transpose_b : {false, true}) { @@ -214,7 +214,7 @@ TEST_P(CppGradients, TestMatMulGradManual) { AbstractTensorHandle* A_raw; status_ = TestTensorHandleWithDims( immediate_execution_ctx_.get(), A_vals, A_dims, 2, &A_raw); - ASSERT_EQ(errors::OK, status_.code()) << status_.error_message(); + ASSERT_EQ(errors::OK, status_.code()) << status_.message(); A.reset(A_raw); } @@ -225,12 +225,12 @@ TEST_P(CppGradients, TestMatMulGradManual) { AbstractTensorHandle* B_raw; status_ = TestTensorHandleWithDims( immediate_execution_ctx_.get(), B_vals, B_dims, 2, 
&B_raw); - ASSERT_EQ(errors::OK, status_.code()) << status_.error_message(); + ASSERT_EQ(errors::OK, status_.code()) << status_.message(); B.reset(B_raw); } status_ = registry_.Register("MatMul", MatMulRegisterer); - ASSERT_EQ(errors::OK, status_.code()) << status_.error_message(); + ASSERT_EQ(errors::OK, status_.code()) << status_.message(); bool transpose_a_vals[] = {false, false, true, true}; bool transpose_b_vals[] = {false, true, false, true}; @@ -259,7 +259,7 @@ TEST_P(CppGradients, TestMatMulGradManual) { status_ = RunModel(MatMulGradModel, immediate_execution_ctx_.get(), {A.get(), B.get()}, absl::MakeSpan(outputs), UseFunction()); - ASSERT_EQ(errors::OK, status_.code()) << status_.error_message(); + ASSERT_EQ(errors::OK, status_.code()) << status_.message(); ASSERT_NO_FATAL_FAILURE(CheckTensorValue(outputs[0], dA_vals[i], /*dims*/ {3, 3}, /*abs_error*/ 0)); @@ -277,12 +277,12 @@ TEST_P(CppGradients, TestSqrtGrad) { AbstractTensorHandle* x_raw = nullptr; status_ = TestScalarTensorHandle( immediate_execution_ctx_.get(), 2.0f, &x_raw); - ASSERT_EQ(errors::OK, status_.code()) << status_.error_message(); + ASSERT_EQ(errors::OK, status_.code()) << status_.message(); x.reset(x_raw); } status_ = registry_.Register("Sqrt", SqrtRegisterer); - ASSERT_EQ(errors::OK, status_.code()) << status_.error_message(); + ASSERT_EQ(errors::OK, status_.code()) << status_.message(); ASSERT_NO_FATAL_FAILURE(CompareNumericalAndAutodiffGradients( SqrtModel, BuildGradModel(SqrtModel, registry_), @@ -295,12 +295,12 @@ TEST_P(CppGradients, TestNegGrad) { AbstractTensorHandle* x_raw = nullptr; status_ = TestScalarTensorHandle( immediate_execution_ctx_.get(), 2.0f, &x_raw); - ASSERT_EQ(errors::OK, status_.code()) << status_.error_message(); + ASSERT_EQ(errors::OK, status_.code()) << status_.message(); x.reset(x_raw); } status_ = registry_.Register("Neg", NegRegisterer); - ASSERT_EQ(errors::OK, status_.code()) << status_.error_message(); + ASSERT_EQ(errors::OK, status_.code()) << status_.message(); ASSERT_NO_FATAL_FAILURE(CompareNumericalAndAutodiffGradients( NegModel, BuildGradModel(NegModel, registry_), @@ -313,7 +313,7 @@ TEST_P(CppGradients, TestSubGrad) { AbstractTensorHandle* x_raw = nullptr; status_ = TestScalarTensorHandle( immediate_execution_ctx_.get(), 2.0f, &x_raw); - ASSERT_EQ(errors::OK, status_.code()) << status_.error_message(); + ASSERT_EQ(errors::OK, status_.code()) << status_.message(); x.reset(x_raw); } @@ -322,12 +322,12 @@ TEST_P(CppGradients, TestSubGrad) { AbstractTensorHandle* y_raw = nullptr; status_ = TestScalarTensorHandle( immediate_execution_ctx_.get(), 2.0f, &y_raw); - ASSERT_EQ(errors::OK, status_.code()) << status_.error_message(); + ASSERT_EQ(errors::OK, status_.code()) << status_.message(); y.reset(y_raw); } status_ = registry_.Register("Sub", SubRegisterer); - ASSERT_EQ(errors::OK, status_.code()) << status_.error_message(); + ASSERT_EQ(errors::OK, status_.code()) << status_.message(); ASSERT_NO_FATAL_FAILURE(CompareNumericalAndAutodiffGradients( SubModel, BuildGradModel(SubModel, registry_), @@ -340,7 +340,7 @@ TEST_P(CppGradients, TestMulGrad) { AbstractTensorHandle* x_raw = nullptr; status_ = TestScalarTensorHandle( immediate_execution_ctx_.get(), 2.0f, &x_raw); - ASSERT_EQ(errors::OK, status_.code()) << status_.error_message(); + ASSERT_EQ(errors::OK, status_.code()) << status_.message(); x.reset(x_raw); } @@ -349,12 +349,12 @@ TEST_P(CppGradients, TestMulGrad) { AbstractTensorHandle* y_raw = nullptr; status_ = TestScalarTensorHandle( immediate_execution_ctx_.get(), 2.0f, 
&y_raw); - ASSERT_EQ(errors::OK, status_.code()) << status_.error_message(); + ASSERT_EQ(errors::OK, status_.code()) << status_.message(); y.reset(y_raw); } status_ = registry_.Register("Mul", MulRegisterer); - ASSERT_EQ(errors::OK, status_.code()) << status_.error_message(); + ASSERT_EQ(errors::OK, status_.code()) << status_.message(); ASSERT_NO_FATAL_FAILURE(CompareNumericalAndAutodiffGradients( MulModel, BuildGradModel(MulModel, registry_), @@ -367,12 +367,12 @@ TEST_P(CppGradients, TestLog1pGrad) { AbstractTensorHandle* x_raw = nullptr; status_ = TestScalarTensorHandle( immediate_execution_ctx_.get(), 2.0f, &x_raw); - ASSERT_EQ(errors::OK, status_.code()) << status_.error_message(); + ASSERT_EQ(errors::OK, status_.code()) << status_.message(); x.reset(x_raw); } status_ = registry_.Register("Log1p", Log1pRegisterer); - ASSERT_EQ(errors::OK, status_.code()) << status_.error_message(); + ASSERT_EQ(errors::OK, status_.code()) << status_.message(); ASSERT_NO_FATAL_FAILURE(CompareNumericalAndAutodiffGradients( Log1pModel, BuildGradModel(Log1pModel, registry_), @@ -381,7 +381,7 @@ TEST_P(CppGradients, TestLog1pGrad) { TEST_P(CppGradients, TestDivNoNanGrad) { status_ = registry_.Register("DivNoNan", DivNoNanRegisterer); - ASSERT_EQ(errors::OK, status_.code()) << status_.error_message(); + ASSERT_EQ(errors::OK, status_.code()) << status_.message(); auto DivNoNanGradModel = BuildGradModel(DivNoNanModel, registry_); @@ -390,7 +390,7 @@ TEST_P(CppGradients, TestDivNoNanGrad) { AbstractTensorHandle* x_raw = nullptr; status_ = TestScalarTensorHandle( immediate_execution_ctx_.get(), 2.0f, &x_raw); - ASSERT_EQ(errors::OK, status_.code()) << status_.error_message(); + ASSERT_EQ(errors::OK, status_.code()) << status_.message(); x.reset(x_raw); } @@ -399,7 +399,7 @@ TEST_P(CppGradients, TestDivNoNanGrad) { AbstractTensorHandle* y_raw = nullptr; status_ = TestScalarTensorHandle( immediate_execution_ctx_.get(), 2.0f, &y_raw); - ASSERT_EQ(errors::OK, status_.code()) << status_.error_message(); + ASSERT_EQ(errors::OK, status_.code()) << status_.message(); y.reset(y_raw); } @@ -413,14 +413,14 @@ TEST_P(CppGradients, TestDivNoNanGrad) { AbstractTensorHandle* z_raw = nullptr; status_ = TestScalarTensorHandle( immediate_execution_ctx_.get(), 0.0f, &z_raw); - ASSERT_EQ(errors::OK, status_.code()) << status_.error_message(); + ASSERT_EQ(errors::OK, status_.code()) << status_.message(); z.reset(z_raw); } std::vector outputs(2); status_ = RunModel(DivNoNanGradModel, immediate_execution_ctx_.get(), {x.get(), z.get()}, absl::MakeSpan(outputs), UseFunction()); - ASSERT_EQ(errors::OK, status_.code()) << status_.error_message(); + ASSERT_EQ(errors::OK, status_.code()) << status_.message(); ASSERT_NO_FATAL_FAILURE(CheckTensorValue(outputs[0], {0.0f}, /*dims*/ {}, /*abs_error*/ 0)); ASSERT_NO_FATAL_FAILURE(CheckTensorValue(outputs[1], {0.0f}, /*dims*/ {}, diff --git a/tensorflow/c/experimental/gradients/nn_grad_test.cc b/tensorflow/c/experimental/gradients/nn_grad_test.cc index 15552eed3ca..d6d0d4dd524 100644 --- a/tensorflow/c/experimental/gradients/nn_grad_test.cc +++ b/tensorflow/c/experimental/gradients/nn_grad_test.cc @@ -67,13 +67,13 @@ class CppGradients TF_StatusPtr status(TF_NewStatus()); TF_SetTracingImplementation(std::get<0>(GetParam()), status.get()); status_ = StatusFromTF_Status(status.get()); - ASSERT_EQ(errors::OK, status_.code()) << status_.error_message(); + ASSERT_EQ(errors::OK, status_.code()) << status_.message(); { AbstractContext* ctx_raw = nullptr; status_ = 
BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw); - ASSERT_EQ(errors::OK, status_.code()) << status_.error_message(); + ASSERT_EQ(errors::OK, status_.code()) << status_.message(); immediate_execution_ctx_.reset(ctx_raw); } @@ -94,7 +94,7 @@ class CppGradients TEST_P(CppGradients, TestReluGrad) { status_ = registry_.Register("Relu", ReluRegisterer); - ASSERT_EQ(errors::OK, status_.code()) << status_.error_message(); + ASSERT_EQ(errors::OK, status_.code()) << status_.message(); auto ReluGradModel = BuildGradModel(ReluModel, registry_); @@ -105,7 +105,7 @@ TEST_P(CppGradients, TestReluGrad) { AbstractTensorHandle* X_raw; status_ = TestTensorHandleWithDims( immediate_execution_ctx_.get(), X_vals, X_dims, 2, &X_raw); - ASSERT_EQ(errors::OK, status_.code()) << status_.error_message(); + ASSERT_EQ(errors::OK, status_.code()) << status_.message(); X.reset(X_raw); } @@ -120,14 +120,14 @@ TEST_P(CppGradients, TestReluGrad) { AbstractTensorHandle* Y_raw; status_ = TestScalarTensorHandle( immediate_execution_ctx_.get(), 0.0f, &Y_raw); - ASSERT_EQ(errors::OK, status_.code()) << status_.error_message(); + ASSERT_EQ(errors::OK, status_.code()) << status_.message(); Y.reset(Y_raw); } std::vector outputs(1); status_ = RunModel(ReluGradModel, immediate_execution_ctx_.get(), {Y.get()}, absl::MakeSpan(outputs), UseFunction()); - ASSERT_EQ(errors::OK, status_.code()) << status_.error_message(); + ASSERT_EQ(errors::OK, status_.code()) << status_.message(); ASSERT_NO_FATAL_FAILURE(CheckTensorValue(outputs[0], {0.0f}, /*dims*/ {}, /*abs_error*/ 0)); outputs[0]->Unref(); @@ -148,7 +148,7 @@ TEST_P(CppGradients, TestSparseSoftmaxCrossEntropyWithLogitsGrad) { AbstractTensorHandle* X_raw; status_ = TestTensorHandleWithDims( immediate_execution_ctx_.get(), X_vals, X_dims, 2, &X_raw); - ASSERT_EQ(errors::OK, status_.code()) << status_.error_message(); + ASSERT_EQ(errors::OK, status_.code()) << status_.message(); X.reset(X_raw); } // Label @@ -159,13 +159,13 @@ TEST_P(CppGradients, TestSparseSoftmaxCrossEntropyWithLogitsGrad) { AbstractTensorHandle* Y_raw; status_ = TestTensorHandleWithDims( immediate_execution_ctx_.get(), Y_vals, Y_dims, 1, &Y_raw); - ASSERT_EQ(errors::OK, status_.code()) << status_.error_message(); + ASSERT_EQ(errors::OK, status_.code()) << status_.message(); Y.reset(Y_raw); } status_ = registry_.Register("SparseSoftmaxCrossEntropyWithLogits", SparseSoftmaxCrossEntropyWithLogitsRegisterer); - ASSERT_EQ(errors::OK, status_.code()) << status_.error_message(); + ASSERT_EQ(errors::OK, status_.code()) << status_.message(); ASSERT_NO_FATAL_FAILURE(CompareNumericalAndAutodiffGradients( SparseSoftmaxCrossEntropyWithLogitsModel, @@ -186,7 +186,7 @@ TEST_P(CppGradients, TestBiasAddGrad) { AbstractTensorHandle* A_raw; status_ = TestTensorHandleWithDims( immediate_execution_ctx_.get(), A_vals, A_dims, 2, &A_raw); - ASSERT_EQ(errors::OK, status_.code()) << status_.error_message(); + ASSERT_EQ(errors::OK, status_.code()) << status_.message(); A.reset(A_raw); } // Bias @@ -197,12 +197,12 @@ TEST_P(CppGradients, TestBiasAddGrad) { AbstractTensorHandle* Bias_raw; status_ = TestTensorHandleWithDims( immediate_execution_ctx_.get(), Bias_vals, Bias_dims, 1, &Bias_raw); - ASSERT_EQ(errors::OK, status_.code()) << status_.error_message(); + ASSERT_EQ(errors::OK, status_.code()) << status_.message(); Bias.reset(Bias_raw); } status_ = registry_.Register("BiasAdd", BiasAddRegisterer); - ASSERT_EQ(errors::OK, status_.code()) << status_.error_message(); + ASSERT_EQ(errors::OK, status_.code()) << 
status_.message(); ASSERT_NO_FATAL_FAILURE(CompareNumericalAndAutodiffGradients( BiasAddModel, BuildGradModel(BiasAddModel, registry_), diff --git a/tensorflow/c/experimental/grappler/grappler_test.cc b/tensorflow/c/experimental/grappler/grappler_test.cc index c7f2739601f..ed4b4d92362 100644 --- a/tensorflow/c/experimental/grappler/grappler_test.cc +++ b/tensorflow/c/experimental/grappler/grappler_test.cc @@ -94,7 +94,7 @@ TEST(Grappler, DeviceTypeNotSet) { tensorflow::Status status = InitGraphPlugin(plugin_init); ASSERT_EQ(status.code(), tensorflow::error::FAILED_PRECONDITION); ASSERT_EQ( - status.error_message(), + status.message(), "'device_type' field in TP_OptimizerRegistrationParams must be set."); } @@ -109,7 +109,7 @@ TEST(Grappler, OptimizeFuncNotSet) { tensorflow::Status status = InitGraphPlugin(plugin_init); ASSERT_EQ(status.code(), tensorflow::error::FAILED_PRECONDITION); - ASSERT_EQ(status.error_message(), + ASSERT_EQ(status.message(), "'optimize_func' field in TP_Optimizer must be set."); } diff --git a/tensorflow/c/experimental/next_pluggable_device/BUILD b/tensorflow/c/experimental/next_pluggable_device/BUILD index 89c718ec5d8..eda00deb59c 100644 --- a/tensorflow/c/experimental/next_pluggable_device/BUILD +++ b/tensorflow/c/experimental/next_pluggable_device/BUILD @@ -12,6 +12,7 @@ cc_library( visibility = ["//visibility:public"], deps = [ "//tensorflow/c:c_api_headers", + "//tensorflow/c:c_api_macros_hdrs", "//tensorflow/c:kernels_experimental_hdrs", "//tensorflow/c:kernels_hdrs", "//tensorflow/c:tf_buffer_internal", @@ -41,6 +42,7 @@ cc_library( visibility = ["//visibility:public"], deps = [ "//tensorflow/c:c_api_headers", + "//tensorflow/c:c_api_macros_hdrs", "//tensorflow/c:kernels_hdrs", "//tensorflow/c:tf_buffer_internal", "//tensorflow/compiler/xla/pjrt/c:pjrt_c_api_hdrs", diff --git a/tensorflow/c/experimental/next_pluggable_device/c_api.cc b/tensorflow/c/experimental/next_pluggable_device/c_api.cc index 3d7150433b9..caa49be2d3f 100644 --- a/tensorflow/c/experimental/next_pluggable_device/c_api.cc +++ b/tensorflow/c/experimental/next_pluggable_device/c_api.cc @@ -56,7 +56,7 @@ void TF_CreatePluginResource(TF_OpKernelContext* ctx, auto cc_status = cc_ctx->resource_manager()->Create( container_name, plugin_resource_name, cc_resource_ptr); - Set_TF_Status_from_Status(status, cc_status); + tsl::Set_TF_Status_from_Status(status, cc_status); } void TF_LookupOrCreatePluginResource( @@ -86,7 +86,7 @@ void TF_LookupOrCreatePluginResource( } else { *result_plugin_resource = nullptr; } - Set_TF_Status_from_Status(status, cc_status); + tsl::Set_TF_Status_from_Status(status, cc_status); } // ------------------------- VariableInfo ------------------------------------ @@ -113,7 +113,7 @@ TF_VariableInfo* TF_CreateVariableInfoFromContext(TF_OpKernelContext* ctx, cc_status = tsl::errors::InvalidArgument( "Trying to obtain resource handle from Input[", index, "], which is not type DT_RESOURCE."); - Set_TF_Status_from_Status(status, cc_status); + tsl::Set_TF_Status_from_Status(status, cc_status); return nullptr; } const tensorflow::ResourceHandle& handle = @@ -141,20 +141,20 @@ void TF_AllocateTempForVariableInfo(TF_OpKernelContext* ctx, tsl::Status cc_status; if (var_info == nullptr) { cc_status = tsl::errors::InvalidArgument("TF_VariableInfo is NULL."); - Set_TF_Status_from_Status(status, cc_status); + tsl::Set_TF_Status_from_Status(status, cc_status); return; } if (var_info->var_info.var() == nullptr) { cc_status = tsl::errors::InvalidArgument( "VariableInfo does not track a 
resource variable."); - Set_TF_Status_from_Status(status, cc_status); + tsl::Set_TF_Status_from_Status(status, cc_status); return; } cc_status = cc_ctx->allocate_temp(var_info->var_info.var()->tensor()->dtype(), var_info->var_info.var()->tensor()->shape(), var_info->var_info.var()->tensor()); - Set_TF_Status_from_Status(status, cc_status); + tsl::Set_TF_Status_from_Status(status, cc_status); } TF_Tensor* TF_GetTensorFromVariableInfo(TF_VariableInfo* var_info, @@ -162,20 +162,20 @@ TF_Tensor* TF_GetTensorFromVariableInfo(TF_VariableInfo* var_info, tsl::Status cc_status; if (var_info == nullptr) { cc_status = tsl::errors::InvalidArgument("TF_VariableInfo is NULL."); - Set_TF_Status_from_Status(status, cc_status); + tsl::Set_TF_Status_from_Status(status, cc_status); return nullptr; } if (var_info->var_info.var() == nullptr) { cc_status = tsl::errors::InvalidArgument( "VariableInfo does not track a resource variable."); - Set_TF_Status_from_Status(status, cc_status); + tsl::Set_TF_Status_from_Status(status, cc_status); return nullptr; } tensorflow::Tensor* tensor = var_info->var_info.var()->tensor(); TF_Tensor* result_tensor = tensorflow::TF_TensorFromTensor(*tensor, &cc_status); - Set_TF_Status_from_Status(status, cc_status); + tsl::Set_TF_Status_from_Status(status, cc_status); return result_tensor; } @@ -323,6 +323,13 @@ void TF_CreatePjRtBuffer(TF_Tensor* c_tensor, PJRT_Buffer* c_buffer, } tensorflow::AsyncValueTensor* av_tensor = tensorflow::AsyncValueTensor::FromTensor(&tensor); + if (av_tensor == nullptr) { + tensorflow::Set_TF_Status_from_Status( + status, + tsl::errors::Internal( + "The tensor to set PjRtBuffer is not an AsyncValueTensor.")); + return; + } av_tensor->SetBuffer( std::make_unique(pjrt_c_api_client, c_buffer)); TF_SetStatus(status, TF_OK, ""); diff --git a/tensorflow/c/experimental/next_pluggable_device/c_api.h b/tensorflow/c/experimental/next_pluggable_device/c_api.h index 4c476a68322..f8f3db2c737 100644 --- a/tensorflow/c/experimental/next_pluggable_device/c_api.h +++ b/tensorflow/c/experimental/next_pluggable_device/c_api.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_C_EXPERIMENTAL_NEXT_PLUGGABLE_DEVICE_C_API_H_ #include "tensorflow/c/c_api.h" +#include "tensorflow/c/c_api_macros.h" #include "tensorflow/c/kernels.h" #include "tensorflow/c/kernels_experimental.h" #include "tensorflow/c/tf_buffer.h" @@ -26,25 +27,6 @@ limitations under the License. // C API for device. The API is under active development and eventually // should allow registering a plugin device with TensorFlow. -// Macro to control visibility of exported symbols in the shared library (.so, -// .dylib, .dll). -// This duplicates the TF_EXPORT macro definition in -// tensorflow/core/platform/macros.h in order to keep this .h file independent -// of any other includes. 
-#ifdef SWIG -#define TF_CAPI_EXPORT -#else -#if defined(_WIN32) -#ifdef TF_COMPILE_LIBRARY -#define TF_CAPI_EXPORT __declspec(dllexport) -#else -#define TF_CAPI_EXPORT __declspec(dllimport) -#endif // TF_COMPILE_LIBRARY -#else -#define TF_CAPI_EXPORT __attribute__((visibility("default"))) -#endif // _WIN32 -#endif // SWIG - #ifdef __cplusplus extern "C" { #endif diff --git a/tensorflow/c/experimental/saved_model/core/BUILD b/tensorflow/c/experimental/saved_model/core/BUILD index d72cf86a7bc..f35d4c1ee04 100644 --- a/tensorflow/c/experimental/saved_model/core/BUILD +++ b/tensorflow/c/experimental/saved_model/core/BUILD @@ -16,7 +16,6 @@ package( # copybara:uncomment() "//learning/brain/tfrt/aot:__pkg__", "//tensorflow/c:__subpackages__", "//tensorflow/c/experimental/saved_model/internal:__pkg__", - "//tensorflow/cc/experimental/libtf:__pkg__", "//tensorflow/core:__subpackages__", ], licenses = ["notice"], diff --git a/tensorflow/c/experimental/saved_model/core/ops/restore_ops_test.cc b/tensorflow/c/experimental/saved_model/core/ops/restore_ops_test.cc index 52a652a90ef..9f63038ac4c 100644 --- a/tensorflow/c/experimental/saved_model/core/ops/restore_ops_test.cc +++ b/tensorflow/c/experimental/saved_model/core/ops/restore_ops_test.cc @@ -96,7 +96,7 @@ TEST_F(RestoreOpsTest, BadCheckpointPrefixShouldFail) { Status status = internal::SingleRestore( context(), CheckpointPrefix("unknown_bad_checkpoint_prefix"), "x/.ATTRIBUTES/VARIABLE_VALUE", DT_FLOAT, &x_handle); - EXPECT_FALSE(status.ok()) << status.error_message(); + EXPECT_FALSE(status.ok()) << status.message(); } TEST_F(RestoreOpsTest, BadCheckpointKeyShouldFail) { @@ -104,7 +104,7 @@ TEST_F(RestoreOpsTest, BadCheckpointKeyShouldFail) { Status status = internal::SingleRestore( context(), CheckpointPrefix("VarsAndArithmeticObjectGraph"), "bad_checkpoint_key", DT_FLOAT, &x_handle); - EXPECT_FALSE(status.ok()) << status.error_message(); + EXPECT_FALSE(status.ok()) << status.message(); } } // namespace diff --git a/tensorflow/c/experimental/saved_model/core/revived_types/flat_tensor_function.cc b/tensorflow/c/experimental/saved_model/core/revived_types/flat_tensor_function.cc index 59f7306fedc..d6e568090f7 100644 --- a/tensorflow/c/experimental/saved_model/core/revived_types/flat_tensor_function.cc +++ b/tensorflow/c/experimental/saved_model/core/revived_types/flat_tensor_function.cc @@ -41,7 +41,7 @@ FlatTensorFunction::~FlatTensorFunction() { Status status = ctx_->RemoveFunction(name_); if (!status.ok()) { LOG(ERROR) << "Failed to remove functiondef " << name_ << ". 
" - << status.error_message(); + << status.message(); } } diff --git a/tensorflow/c/experimental/saved_model/core/revived_types/restored_resource.cc b/tensorflow/c/experimental/saved_model/core/revived_types/restored_resource.cc index 43b8c3ee303..3bcaee4852a 100644 --- a/tensorflow/c/experimental/saved_model/core/revived_types/restored_resource.cc +++ b/tensorflow/c/experimental/saved_model/core/revived_types/restored_resource.cc @@ -71,7 +71,7 @@ RestoredResource::~RestoredResource() { if (!status.ok()) { LOG(WARNING) << "Failed executing destroy_resource function for RestoredResource: " - << status.error_message(); + << status.message(); } } } diff --git a/tensorflow/c/experimental/saved_model/core/saved_variable_loading_test.cc b/tensorflow/c/experimental/saved_model/core/saved_variable_loading_test.cc index 6947b6eb28d..7d6b50fa6b5 100644 --- a/tensorflow/c/experimental/saved_model/core/saved_variable_loading_test.cc +++ b/tensorflow/c/experimental/saved_model/core/saved_variable_loading_test.cc @@ -126,7 +126,7 @@ TEST_P(SavedVariableLoadingTest, AssignAndReadVariableSuccesful) { ImmediateTensorHandlePtr expected_handle = testing::CreateTensorHandle(context(), dtype, shape_vector, 42); AbstractTensorPtr expected_tensor(expected_handle->Resolve(&status)); - TF_EXPECT_OK(status) << status.error_message(); + TF_EXPECT_OK(status) << status.message(); // Assign the tensorhandle to the variable. TF_EXPECT_OK(var->Assign(expected_handle.get())); @@ -135,7 +135,7 @@ TEST_P(SavedVariableLoadingTest, AssignAndReadVariableSuccesful) { ImmediateTensorHandlePtr output_handle; TF_EXPECT_OK(var->ReadValue(&output_handle)); AbstractTensorPtr output_tensor(output_handle->Resolve(&status)); - TF_EXPECT_OK(status) << status.error_message(); + TF_EXPECT_OK(status) << status.message(); // Check that output_tensor == expected_tensor EXPECT_EQ(output_tensor->Type(), expected_tensor->Type()); diff --git a/tensorflow/c/experimental/saved_model/core/test_utils.cc b/tensorflow/c/experimental/saved_model/core/test_utils.cc index 88d44011f09..0f1f9ca4488 100644 --- a/tensorflow/c/experimental/saved_model/core/test_utils.cc +++ b/tensorflow/c/experimental/saved_model/core/test_utils.cc @@ -139,7 +139,7 @@ void CheckBufferDataIsEqual(DataType dtype, int64_t num_elements, void* a, AbstractTensorPtr TensorHandleToTensor(ImmediateExecutionTensorHandle* handle) { Status status; AbstractTensorPtr tensor(handle->Resolve(&status)); - CHECK(status.ok()) << status.error_message(); + CHECK(status.ok()) << status.message(); CHECK_NE(tensor.get(), nullptr); return tensor; } diff --git a/tensorflow/c/experimental/saved_model/core/tf_concrete_function_loading_test.cc b/tensorflow/c/experimental/saved_model/core/tf_concrete_function_loading_test.cc index ae3460f7a61..92a21ae6e04 100644 --- a/tensorflow/c/experimental/saved_model/core/tf_concrete_function_loading_test.cc +++ b/tensorflow/c/experimental/saved_model/core/tf_concrete_function_loading_test.cc @@ -81,8 +81,7 @@ TEST_F(SavedConcreteFunctionLoadingTest, TooFewInputsInSavedConcreteFunction) { std::unique_ptr result; Status status = internal::LoadTFConcreteFunction(saved, &func, {}, context(), &result); - EXPECT_EQ(status.code(), error::FAILED_PRECONDITION) - << status.error_message(); + EXPECT_EQ(status.code(), error::FAILED_PRECONDITION) << status.message(); } // A SavedConcreteFunction whose canonicalized input signature length + @@ -105,8 +104,7 @@ TEST_F(SavedConcreteFunctionLoadingTest, std::unique_ptr result; Status status = internal::LoadTFConcreteFunction(saved, 
&func, captures, context(), &result); - EXPECT_EQ(status.code(), error::FAILED_PRECONDITION) - << status.error_message(); + EXPECT_EQ(status.code(), error::FAILED_PRECONDITION) << status.message(); } // A SavedConcreteFunction whose canonicalized input signature @@ -124,8 +122,7 @@ TEST_F(SavedConcreteFunctionLoadingTest, TooManyInputsInSavedConcreteFunction) { std::unique_ptr result; Status status = internal::LoadTFConcreteFunction(saved, &func, {}, context(), &result); - EXPECT_EQ(status.code(), error::FAILED_PRECONDITION) - << status.error_message(); + EXPECT_EQ(status.code(), error::FAILED_PRECONDITION) << status.message(); } // A SavedConcreteFunction whose canonicalized input signature @@ -149,8 +146,7 @@ TEST_F(SavedConcreteFunctionLoadingTest, std::unique_ptr result; Status status = internal::LoadTFConcreteFunction(saved, &func, captures, context(), &result); - EXPECT_EQ(status.code(), error::FAILED_PRECONDITION) - << status.error_message(); + EXPECT_EQ(status.code(), error::FAILED_PRECONDITION) << status.message(); } // A SavedConcreteFunction whose capture refers to an index not in the capture @@ -174,8 +170,7 @@ TEST_F(SavedConcreteFunctionLoadingTest, ImproperCaptureIndex) { std::unique_ptr result; Status status = internal::LoadTFConcreteFunction(saved, &func, captures, context(), &result); - EXPECT_EQ(status.code(), error::FAILED_PRECONDITION) - << status.error_message(); + EXPECT_EQ(status.code(), error::FAILED_PRECONDITION) << status.message(); } // A SavedConcreteFunction whose outputs are fewer than its corresponding @@ -193,8 +188,7 @@ TEST_F(SavedConcreteFunctionLoadingTest, TooFewOutputsInSavedConcreteFunction) { std::unique_ptr result; Status status = internal::LoadTFConcreteFunction(saved, &func, {}, context(), &result); - EXPECT_EQ(status.code(), error::FAILED_PRECONDITION) - << status.error_message(); + EXPECT_EQ(status.code(), error::FAILED_PRECONDITION) << status.message(); } // A SavedConcreteFunction whose outputs exceed its corresponding functiondef @@ -213,8 +207,7 @@ TEST_F(SavedConcreteFunctionLoadingTest, std::unique_ptr result; Status status = internal::LoadTFConcreteFunction(saved, &func, {}, context(), &result); - EXPECT_EQ(status.code(), error::FAILED_PRECONDITION) - << status.error_message(); + EXPECT_EQ(status.code(), error::FAILED_PRECONDITION) << status.message(); } // A SavedConcreteFunction whose (inputs + captures) = functiondef inputs, @@ -238,7 +231,7 @@ TEST_F(SavedConcreteFunctionLoadingTest, SuccessfulLoad) { std::unique_ptr result; Status status = internal::LoadTFConcreteFunction(saved, &func, captures, context(), &result); - TF_EXPECT_OK(status) << status.error_message(); + TF_EXPECT_OK(status) << status.message(); } // A TFConcreteFunction should register functiondefs on creation, and @@ -257,7 +250,7 @@ TEST_F(SavedConcreteFunctionLoadingTest, RegistersAndRemovesFunctionDefs) { std::unique_ptr result; Status status = internal::LoadTFConcreteFunction(saved, &func, {}, context(), &result); - TF_EXPECT_OK(status) << status.error_message(); + TF_EXPECT_OK(status) << status.message(); // The function should be registered with context. 
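
Editor's note: the error_message() to message() replacements in this file and throughout the rest of the diff track a Status API change in which the accessor evidently returns a string view rather than an owning std::string (cc_ops_test.cc later in this diff copies the result into a std::string for exactly that reason). A minimal sketch of the migration pattern, under that assumption:

#include <string>

#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/status.h"

// Hypothetical helper illustrating the two common call-site shapes.
void LogIfError(const tensorflow::Status& s) {
  if (!s.ok()) {
    // Streaming the view works unchanged in LOG/EXPECT macros.
    LOG(ERROR) << s.message();
    // Copy explicitly when an owning string is needed beyond the Status lifetime.
    std::string owned(s.message());
  }
}
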
EXPECT_TRUE(context()->FindFunctionByName(func_name)); } diff --git a/tensorflow/c/experimental/stream_executor/BUILD b/tensorflow/c/experimental/stream_executor/BUILD index 2a7c0a8c419..85647f78b7b 100644 --- a/tensorflow/c/experimental/stream_executor/BUILD +++ b/tensorflow/c/experimental/stream_executor/BUILD @@ -26,7 +26,7 @@ cc_library( hdrs = ["stream_executor.h"], visibility = ["//tensorflow:internal"], deps = [ - "//tensorflow/c:c_api_macros", + "//tensorflow/c:c_api_macros_hdrs", "//tensorflow/c:tf_status_headers", ], ) diff --git a/tensorflow/c/experimental/stream_executor/stream_executor.cc b/tensorflow/c/experimental/stream_executor/stream_executor.cc index 2ba7d3cc953..3c984bcc15c 100644 --- a/tensorflow/c/experimental/stream_executor/stream_executor.cc +++ b/tensorflow/c/experimental/stream_executor/stream_executor.cc @@ -204,7 +204,7 @@ struct HostCallbackContext { void HostCallbackTrampoline(void* ctx, TF_Status* status) { HostCallbackContext* host_ctx = static_cast(ctx); tsl::Status s = std::move(host_ctx->callback)(); - Set_TF_Status_from_Status(status, s); + tsl::Set_TF_Status_from_Status(status, s); delete host_ctx; } @@ -237,7 +237,7 @@ class CStreamExecutor : public internal::StreamExecutorInterface { stream_executor_->allocate(&device_, size, memory_space, &mem); tsl::Status status = ValidateSPDeviceMemoryBase(mem); if (!status.ok()) { - LOG(ERROR) << status.error_message(); + LOG(ERROR) << status.message(); } return DeviceMemoryBaseFromC(mem); } @@ -284,7 +284,7 @@ class CStreamExecutor : public internal::StreamExecutorInterface { } tsl::Status status = ValidateSPAllocatorStats(c_stats); if (!status.ok()) { - LOG(ERROR) << status.error_message(); + LOG(ERROR) << status.message(); return absl::nullopt; } ::stream_executor::AllocatorStats stats; diff --git a/tensorflow/c/experimental/stream_executor/stream_executor_test.cc b/tensorflow/c/experimental/stream_executor/stream_executor_test.cc index cf21374c48f..90b4dad5daa 100644 --- a/tensorflow/c/experimental/stream_executor/stream_executor_test.cc +++ b/tensorflow/c/experimental/stream_executor/stream_executor_test.cc @@ -65,7 +65,7 @@ TEST(StreamExecutor, NameNotSet) { tsl::Status status = InitStreamExecutorPlugin(plugin_init, &device_type, &platform_name); ASSERT_EQ(status.code(), tensorflow::error::FAILED_PRECONDITION); - ASSERT_EQ(status.error_message(), "'name' field in SP_Platform must be set."); + ASSERT_EQ(status.message(), "'name' field in SP_Platform must be set."); } TEST(StreamExecutor, InvalidNameWithSemicolon) { @@ -81,7 +81,7 @@ TEST(StreamExecutor, InvalidNameWithSemicolon) { InitStreamExecutorPlugin(plugin_init, &device_type, &platform_name); ASSERT_EQ(status.code(), tensorflow::error::FAILED_PRECONDITION); EXPECT_THAT( - status.error_message(), + status.message(), testing::ContainsRegex("Device name/type 'INVALID:NAME' must match")); } @@ -97,7 +97,7 @@ TEST(StreamExecutor, InvalidNameWithSlash) { tsl::Status status = InitStreamExecutorPlugin(plugin_init, &device_type, &platform_name); ASSERT_EQ(status.code(), tensorflow::error::FAILED_PRECONDITION); - EXPECT_THAT(status.error_message(), + EXPECT_THAT(status.message(), testing::ContainsRegex("Device name/type 'INVALID/' must match")); } @@ -113,7 +113,7 @@ TEST(StreamExecutor, CreateDeviceNotSet) { tsl::Status status = InitStreamExecutorPlugin(plugin_init, &device_type, &platform_name); ASSERT_EQ(status.code(), tensorflow::error::FAILED_PRECONDITION); - ASSERT_EQ(status.error_message(), + ASSERT_EQ(status.message(), "'create_device' field in 
SP_PlatformFns must be set."); } @@ -130,7 +130,7 @@ TEST(StreamExecutor, UnifiedMemoryAllocateNotSet) { InitStreamExecutorPlugin(plugin_init, &device_type, &platform_name); ASSERT_EQ(status.code(), tensorflow::error::FAILED_PRECONDITION); ASSERT_EQ( - status.error_message(), + status.message(), "'unified_memory_allocate' field in SP_StreamExecutor must be set."); } @@ -327,7 +327,7 @@ TEST_F(StreamExecutorTest, StreamStatus) { status_ok = false; auto updated_status = stream.RefreshStatus(); ASSERT_FALSE(stream.ok()); - ASSERT_EQ(updated_status.error_message(), "Test error"); + ASSERT_EQ(updated_status.message(), "Test error"); } TEST_F(StreamExecutorTest, CreateEvent) { diff --git a/tensorflow/c/experimental/stream_executor/stream_executor_test_util.cc b/tensorflow/c/experimental/stream_executor/stream_executor_test_util.cc index 6722b86c0ef..41928bc469c 100644 --- a/tensorflow/c/experimental/stream_executor/stream_executor_test_util.cc +++ b/tensorflow/c/experimental/stream_executor/stream_executor_test_util.cc @@ -83,9 +83,6 @@ void SynchronizeAllActivity(const SP_Device* const device, TF_Bool HostCallback(const SP_Device* const device, SP_Stream stream, SE_StatusCallbackFn const callback_fn, void* const callback_arg) { - TSL_Status* status_ignored = TSL_NewStatus(); - callback_fn(callback_arg, status_ignored); - TSL_DeleteStatus(status_ignored); return true; } diff --git a/tensorflow/c/kernels.cc b/tensorflow/c/kernels.cc index 9f34547f9ee..59f978000e6 100644 --- a/tensorflow/c/kernels.cc +++ b/tensorflow/c/kernels.cc @@ -36,6 +36,9 @@ limitations under the License. #if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD) #include "tensorflow/c/experimental/stream_executor/stream_executor_internal.h" #include "tensorflow/compiler/xla/stream_executor/stream.h" +#include "tensorflow/core/framework/device.h" +#include "tensorflow/tsl/framework/device_id_utils.h" +#include "tensorflow/tsl/platform/statusor.h" #endif // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD) using tensorflow::errors::InvalidArgument; @@ -660,12 +663,12 @@ TF_Buffer* TF_OpKernelConstruction_GetAttrFunction(TF_OpKernelConstruction* ctx, tensorflow::NameAttrList function; auto cc_status = cc_ctx->GetAttr(attr_name, &function); if (!cc_status.ok()) { - Set_TF_Status_from_Status(status, cc_status); + tsl::Set_TF_Status_from_Status(status, cc_status); return nullptr; } TF_Buffer* buffer = TF_NewBuffer(); cc_status = tensorflow::MessageToBuffer(function, buffer); - Set_TF_Status_from_Status(status, cc_status); + tsl::Set_TF_Status_from_Status(status, cc_status); if (!cc_status.ok()) return nullptr; else @@ -753,10 +756,19 @@ int64_t TF_GetStepId(TF_OpKernelContext* ctx) { int TF_GetDeviceId(TF_OpKernelContext* ctx) { // TensorFlow always sets device in OpKernelContext. 
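
Editor's note: the hunk continuing below rewrites TF_GetDeviceId so that, outside mobile/slim builds, it resolves the id through tsl::GetDeviceIdFromDeviceParsedName and returns -1 when no id can be derived. A hedged sketch of how a plugin kernel might consume it; MyKernelCompute is a made-up name, only TF_GetDeviceId itself comes from this diff:

#include "tensorflow/c/kernels.h"

// Hypothetical plugin kernel compute function.
void MyKernelCompute(void* kernel, TF_OpKernelContext* ctx) {
  // Per the new implementation, -1 means no usable device id was found.
  const int device_id = TF_GetDeviceId(ctx);
  if (device_id < 0) {
    return;  // A real kernel would fall back to a default device here.
  }
  // Index into per-device plugin state with device_id here.
}
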
- auto* device = - reinterpret_cast<::tensorflow::OpKernelContext*>(ctx)->device(); - if (!device->parsed_name().has_id) return -1; - return device->parsed_name().id; + const tensorflow::DeviceBase* device_base = + reinterpret_cast(ctx)->device(); +#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD) + if (!device_base->parsed_name().has_id) return -1; + return device_base->parsed_name().id; +#else + const auto* device = reinterpret_cast( + device_base->UnderlyingDevice()); + const tsl::StatusOr id = tsl::GetDeviceIdFromDeviceParsedName( + device->parsed_name(), tensorflow::DeviceType(device->device_type())); + if (!id.ok()) return -1; + return *id; +#endif // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD) } TF_StringView TF_GetOpKernelName(TF_OpKernelContext* ctx) { @@ -791,8 +803,6 @@ TF_Tensor* TF_AllocateOutput(TF_OpKernelContext* context, int index, int num_dims, size_t len, TF_Status* status) { TF_SetStatus(status, TF_OK, ""); auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(context); - static_assert(sizeof(int64_t) == sizeof(int64_t), - "64-bit int types should match in size"); tensorflow::gtl::ArraySlice dimarray( reinterpret_cast(dims), num_dims); tensorflow::Tensor* tensor; @@ -818,8 +828,6 @@ TF_Tensor* TF_ForwardInputOrAllocateOutput( TF_SetStatus(status, TF_OK, ""); auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(context); - static_assert(sizeof(int64_t) == sizeof(int64_t), - "64-bit int types should match in size"); tensorflow::gtl::ArraySlice input_indices_array( candidate_input_indices, num_candidate_input_indices); tensorflow::gtl::ArraySlice output_dimarray( @@ -847,8 +855,6 @@ TF_Tensor* TF_AllocateTemp(TF_OpKernelContext* context, TF_DataType dtype, TF_Status* status) { auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(context); TF_SetStatus(status, TF_OK, ""); - static_assert(sizeof(int64_t) == sizeof(int64_t), - "64-bit int types should match in size"); tensorflow::gtl::ArraySlice dimarray( reinterpret_cast(dims), num_dims); if (attributes && !attributes->struct_size) { diff --git a/tensorflow/c/kernels.h b/tensorflow/c/kernels.h index 7759c02daa2..665aff8f17a 100644 --- a/tensorflow/c/kernels.h +++ b/tensorflow/c/kernels.h @@ -19,30 +19,12 @@ limitations under the License. #include #include "tensorflow/c/c_api.h" +#include "tensorflow/c/c_api_macros.h" #include "tensorflow/c/experimental/stream_executor/stream_executor.h" #include "tensorflow/c/tf_datatype.h" #include "tensorflow/c/tf_status.h" #include "tensorflow/c/tf_tensor.h" -// Macro to control visibility of exported symbols in the shared library (.so, -// .dylib, .dll). -// This duplicates the TF_EXPORT macro definition in -// tensorflow/core/platform/macros.h in order to keep this .h file independent -// of any other includes. -#ifdef SWIG -#define TF_CAPI_EXPORT -#else -#if defined(_WIN32) -#ifdef TF_COMPILE_LIBRARY -#define TF_CAPI_EXPORT __declspec(dllexport) -#else -#define TF_CAPI_EXPORT __declspec(dllimport) -#endif // TF_COMPILE_LIBRARY -#else -#define TF_CAPI_EXPORT __attribute__((visibility("default"))) -#endif // _WIN32 -#endif // SWIG - #ifdef __cplusplus extern "C" { #endif @@ -283,7 +265,11 @@ TF_CAPI_EXPORT extern int64_t TF_GetIterId(TF_OpKernelContext* ctx); // Returns the Step ID of the given context. TF_CAPI_EXPORT extern int64_t TF_GetStepId(TF_OpKernelContext* ctx); -// Returns the Device ID of the device that the context possesses. +// Returns the Device ID of the device that the context possesses. 
Returns the +// PlatformDeviceId if a mapping between between TfDeviceId and PlatformDeviceId +// is set; otherwise returns the id in the device name. Please refer to +// tensorflow/tsl/framework/device_id.h for more details. +// For mobile or slim build, returns the id in the device name. TF_CAPI_EXPORT extern int TF_GetDeviceId(TF_OpKernelContext* ctx); // Returns the graph def version of the given context. diff --git a/tensorflow/c/kernels_experimental.cc b/tensorflow/c/kernels_experimental.cc index 7590921d952..259d1cac9df 100644 --- a/tensorflow/c/kernels_experimental.cc +++ b/tensorflow/c/kernels_experimental.cc @@ -262,7 +262,7 @@ void TF_AssignUpdateVariable(TF_OpKernelContext* ctx, int input_index, Status status = LookupResource(context, HandleFromInput(context, input_index), &variable); if (!status.ok()) { - printf("Failed with error: %s\n", status.error_message().c_str()); + printf("Failed with error: %s\n", tsl::NullTerminatedMessage(status)); abort(); } const Tensor& value = context->input(value_index); @@ -475,6 +475,118 @@ static Status ValidateVariantType(const Variant& variant) { return ::tensorflow::OkStatus(); } +static Status VariantBinaryAddFunc( + ::tensorflow::OpKernelContext* cc_ctx, const Variant& a, const Variant& b, + Variant* out, + void (*binary_add_func)(TF_OpKernelContext* ctx, TF_Tensor* a, TF_Tensor* b, + TF_Tensor* out)); + +static Status CCBinaryAddFunc( + ::tensorflow::OpKernelContext* cc_ctx, const Tensor& cc_a, + const Tensor& cc_b, Tensor* cc_out, + void (*binary_add_func)(TF_OpKernelContext* ctx, TF_Tensor* a, TF_Tensor* b, + TF_Tensor* out)) { + if (cc_a.dtype() == ::tensorflow::DT_INVALID) { + *cc_out = cc_b; + return ::tensorflow::OkStatus(); + } + if (cc_b.dtype() == ::tensorflow::DT_INVALID) { + *cc_out = cc_a; + return ::tensorflow::OkStatus(); + } + + Status status; + TF_Tensor* a = TF_TensorFromTensor(cc_a, &status); + TF_RETURN_IF_ERROR(status); + + TF_Tensor* b = TF_TensorFromTensor(cc_b, &status); + if (!status.ok()) { + TF_DeleteTensor(a); + return status; + } + + ::tensorflow::AllocatorAttributes attr; + if (cc_a.dtype() == ::tensorflow::DT_VARIANT) { + attr.set_on_host(true); + } + + status = cc_ctx->allocate_temp(cc_a.dtype(), cc_a.shape(), cc_out, attr); + if (!status.ok()) { + TF_DeleteTensor(a); + TF_DeleteTensor(b); + return status; + } + + TF_Tensor* out = TF_TensorFromTensor(*cc_out, &status); + if (!status.ok()) { + TF_DeleteTensor(a); + TF_DeleteTensor(b); + return status; + } + + auto* ctx = reinterpret_cast(cc_ctx); + if (cc_a.dtype() == ::tensorflow::DT_VARIANT) { + return VariantBinaryAddFunc( + cc_ctx, cc_a.scalar()(), cc_b.scalar()(), + cc_out->scalar().data(), binary_add_func); + } else { + binary_add_func(ctx, a, b, out); + return cc_ctx->status(); + } +}; + +static Status VariantBinaryAddFunc( + ::tensorflow::OpKernelContext* cc_ctx, const Variant& a, const Variant& b, + Variant* out, + void (*binary_add_func)(TF_OpKernelContext* ctx, TF_Tensor* a, TF_Tensor* b, + TF_Tensor* out)) { + auto cc_binary_add = [binary_add_func](::tensorflow::OpKernelContext* cc_ctx, + const Tensor& cc_a, const Tensor& cc_b, + Tensor* cc_out) { + return CCBinaryAddFunc(cc_ctx, cc_a, cc_b, cc_out, binary_add_func); + }; + + if (out == nullptr) { + return ::tensorflow::errors::Internal( + "The output variant hasn't been initialized"); + } + + if (a.TypeId() != b.TypeId()) { + return ::tensorflow::errors::Internal( + "BinaryOpVariants: Variants a and b have different " + "type ids. Type names: '", + a.TypeName(), "' vs. 
'", b.TypeName(), "'"); + } + + if (a.TypeId() == tensorflow::TypeIndex::Make<::tensorflow::TensorList>()) { + TF_RETURN_IF_ERROR(ValidateVariantType<::tensorflow::TensorList>(a)); + *out = ::tensorflow::TensorList(); + + return ::tensorflow::TensorListBinaryAdd( + cc_ctx, *a.get<::tensorflow::TensorList>(), + *b.get<::tensorflow::TensorList>(), + out->get<::tensorflow::TensorList>(), cc_binary_add); + } else if (a.TypeId() == tensorflow::TypeIndex::Make< + ::tensorflow::data::OptionalVariant>()) { + TF_RETURN_IF_ERROR( + ValidateVariantType<::tensorflow::data::OptionalVariant>(a)); + *out = ::tensorflow::data::OptionalVariant(); + + return ::tensorflow::data::OptionalBinaryAdd( + cc_ctx, *a.get<::tensorflow::data::OptionalVariant>(), + *b.get<::tensorflow::data::OptionalVariant>(), + out->get<::tensorflow::data::OptionalVariant>(), cc_binary_add); + } + + const std::string type_index_name = + ::tensorflow::port::MaybeAbiDemangle(a.TypeId().name()); + + return ::tensorflow::errors::Internal( + "No unary variant binary_op function found for op ADD Variant " + "type_name: ", + type_index_name, " for device type: ", cc_ctx->device()->name()); +} + void TF_AddNVariant(TF_OpKernelContext* ctx, void (*binary_add_func)(TF_OpKernelContext* ctx, TF_Tensor* a, TF_Tensor* b, @@ -482,97 +594,11 @@ void TF_AddNVariant(TF_OpKernelContext* ctx, TF_Status* status) { auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx); - auto cc_binary_add_func = [binary_add_func]( - ::tensorflow::OpKernelContext* cc_ctx, - const Tensor& cc_a, const Tensor& cc_b, - Tensor* cc_out) { - if (cc_a.dtype() == ::tensorflow::DT_INVALID) { - *cc_out = cc_b; - return ::tensorflow::OkStatus(); - } - if (cc_b.dtype() == ::tensorflow::DT_INVALID) { - *cc_out = cc_a; - return ::tensorflow::OkStatus(); - } - - Status status; - TF_Tensor* a = TF_TensorFromTensor(cc_a, &status); - TF_RETURN_IF_ERROR(status); - - TF_Tensor* b = TF_TensorFromTensor(cc_b, &status); - if (!status.ok()) { - TF_DeleteTensor(a); - return status; - } - - ::tensorflow::AllocatorAttributes attr; - if (cc_a.dtype() == ::tensorflow::DT_VARIANT) { - attr.set_on_host(true); - } - - status = cc_ctx->allocate_temp(cc_a.dtype(), cc_a.shape(), cc_out, attr); - if (!status.ok()) { - TF_DeleteTensor(a); - TF_DeleteTensor(b); - return status; - } - - TF_Tensor* out = TF_TensorFromTensor(*cc_out, &status); - if (!status.ok()) { - TF_DeleteTensor(a); - TF_DeleteTensor(b); - return status; - } - - auto* ctx = reinterpret_cast(cc_ctx); - binary_add_func(ctx, a, b, out); - return cc_ctx->status(); - }; - - auto binary_add_variant = [cc_binary_add_func]( - ::tensorflow::OpKernelContext* cc_ctx, - const Variant& a, const Variant& b, - Variant* out) { - if (out == nullptr) { - return ::tensorflow::errors::Internal( - "The output variant hasn't been initialized"); - } - - if (a.TypeId() != b.TypeId()) { - return ::tensorflow::errors::Internal( - "BinaryOpVariants: Variants a and b have different " - "type ids. Type names: '", - a.TypeName(), "' vs. 
'", b.TypeName(), "'"); - } - - if (a.TypeId() == tensorflow::TypeIndex::Make<::tensorflow::TensorList>()) { - TF_RETURN_IF_ERROR(ValidateVariantType<::tensorflow::TensorList>(a)); - *out = ::tensorflow::TensorList(); - - return ::tensorflow::TensorListBinaryAdd( - cc_ctx, *a.get<::tensorflow::TensorList>(), - *b.get<::tensorflow::TensorList>(), - out->get<::tensorflow::TensorList>(), cc_binary_add_func); - } else if (a.TypeId() == tensorflow::TypeIndex::Make< - ::tensorflow::data::OptionalVariant>()) { - TF_RETURN_IF_ERROR( - ValidateVariantType<::tensorflow::data::OptionalVariant>(a)); - *out = ::tensorflow::data::OptionalVariant(); - - return ::tensorflow::data::OptionalBinaryAdd( - cc_ctx, *a.get<::tensorflow::data::OptionalVariant>(), - *b.get<::tensorflow::data::OptionalVariant>(), - out->get<::tensorflow::data::OptionalVariant>(), cc_binary_add_func); - } - - const std::string type_index_name = - ::tensorflow::port::MaybeAbiDemangle(a.TypeId().name()); - - return ::tensorflow::errors::Internal( - "No unary variant binary_op function found for op ADD Variant " - "type_name: ", - type_index_name, " for device type: ", cc_ctx->device()->name()); - }; + auto binary_add_variant = + [binary_add_func](::tensorflow::OpKernelContext* cc_ctx, const Variant& a, + const Variant& b, Variant* out) { + return VariantBinaryAddFunc(cc_ctx, a, b, out, binary_add_func); + }; ::tensorflow::AddNVariant(cc_ctx, binary_add_variant); ::tensorflow::Set_TF_Status_from_Status(status, cc_ctx->status()); } diff --git a/tensorflow/c/kernels_experimental.h b/tensorflow/c/kernels_experimental.h index fbf0247f1c0..a36ea55e311 100644 --- a/tensorflow/c/kernels_experimental.h +++ b/tensorflow/c/kernels_experimental.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_C_KERNELS_EXPERIMENTAL_H_ #define TENSORFLOW_C_KERNELS_EXPERIMENTAL_H_ +#include "tensorflow/c/c_api_macros.h" #include "tensorflow/c/kernels.h" // -------------------------------------------------------------------------- @@ -24,25 +25,6 @@ limitations under the License. // The API here is subject to changes in the future. // -------------------------------------------------------------------------- -// Macro to control visibility of exported symbols in the shared library (.so, -// .dylib, .dll). -// This duplicates the TF_EXPORT macro definition in -// tensorflow/core/platform/macros.h in order to keep this .h file independent -// of any other includes. -#ifdef SWIG -#define TF_CAPI_EXPORT -#else -#if defined(_WIN32) -#ifdef TF_COMPILE_LIBRARY -#define TF_CAPI_EXPORT __declspec(dllexport) -#else -#define TF_CAPI_EXPORT __declspec(dllimport) -#endif // TF_COMPILE_LIBRARY -#else -#define TF_CAPI_EXPORT __attribute__((visibility("default"))) -#endif // _WIN32 -#endif // SWIG - #ifdef __cplusplus extern "C" { #endif diff --git a/tensorflow/c/ops.h b/tensorflow/c/ops.h index 7463809e35b..5d3a1e8965d 100644 --- a/tensorflow/c/ops.h +++ b/tensorflow/c/ops.h @@ -73,23 +73,10 @@ limitations under the License. 
#include #include +#include "tensorflow/c/c_api_macros.h" #include "tensorflow/c/tf_datatype.h" #include "tensorflow/c/tf_status.h" -#ifdef SWIG -#define TF_CAPI_EXPORT -#else -#if defined(_WIN32) -#ifdef TF_COMPILE_LIBRARY -#define TF_CAPI_EXPORT __declspec(dllexport) -#else -#define TF_CAPI_EXPORT __declspec(dllimport) -#endif // TF_COMPILE_LIBRARY -#else -#define TF_CAPI_EXPORT __attribute__((visibility("default"))) -#endif // _WIN32 -#endif // SWIG - #ifdef __cplusplus extern "C" { #endif diff --git a/tensorflow/python/lib/core/safe_ptr.cc b/tensorflow/c/safe_ptr.cc similarity index 95% rename from tensorflow/python/lib/core/safe_ptr.cc rename to tensorflow/c/safe_ptr.cc index ce852a4f009..fa200b0712f 100644 --- a/tensorflow/python/lib/core/safe_ptr.cc +++ b/tensorflow/c/safe_ptr.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/python/lib/core/safe_ptr.h" +#include "tensorflow/c/safe_ptr.h" namespace tensorflow { diff --git a/tensorflow/python/lib/core/safe_ptr.h b/tensorflow/c/safe_ptr.h similarity index 90% rename from tensorflow/python/lib/core/safe_ptr.h rename to tensorflow/c/safe_ptr.h index 00f47d7bbe6..8d8b8141b0b 100644 --- a/tensorflow/python/lib/core/safe_ptr.h +++ b/tensorflow/c/safe_ptr.h @@ -13,16 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_PYTHON_LIB_CORE_SAFE_PTR_H_ -#define TENSORFLOW_PYTHON_LIB_CORE_SAFE_PTR_H_ - -#include +#ifndef TENSORFLOW_C_SAFE_PTR_H_ +#define TENSORFLOW_C_SAFE_PTR_H_ #include #include "tensorflow/c/c_api.h" #include "tensorflow/c/eager/c_api.h" -#include "tensorflow/python/lib/core/safe_pyobject_ptr.h" namespace tensorflow { namespace detail { @@ -68,4 +65,4 @@ Safe_TF_BufferPtr make_safe(TF_Buffer* buffer); } // namespace tensorflow -#endif // TENSORFLOW_PYTHON_LIB_CORE_SAFE_PTR_H_ +#endif // TENSORFLOW_C_SAFE_PTR_H_ diff --git a/tensorflow/c/tf_buffer.h b/tensorflow/c/tf_buffer.h index f18f2116536..71a9aef844c 100644 --- a/tensorflow/c/tf_buffer.h +++ b/tensorflow/c/tf_buffer.h @@ -18,24 +18,7 @@ limitations under the License. #include -// Macro to control visibility of exported symbols in the shared library (.so, -// .dylib, .dll). -// This duplicates the TF_EXPORT macro definition in -// tensorflow/core/platform/macros.h in order to keep this .h file independent -// of any other includes. -#ifdef SWIG -#define TF_CAPI_EXPORT -#else -#if defined(_WIN32) -#ifdef TF_COMPILE_LIBRARY -#define TF_CAPI_EXPORT __declspec(dllexport) -#else -#define TF_CAPI_EXPORT __declspec(dllimport) -#endif // TF_COMPILE_LIBRARY -#else -#define TF_CAPI_EXPORT __attribute__((visibility("default"))) -#endif // _WIN32 -#endif // SWIG +#include "tensorflow/c/c_api_macros.h" #ifdef __cplusplus extern "C" { diff --git a/tensorflow/c/tf_buffer_internal.h b/tensorflow/c/tf_buffer_internal.h index a538de7e895..805f632cf72 100644 --- a/tensorflow/c/tf_buffer_internal.h +++ b/tensorflow/c/tf_buffer_internal.h @@ -22,11 +22,7 @@ limitations under the License. 
#include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/status.h" -namespace tsl { -class Status; -} namespace tensorflow { -using tsl::Status; Status MessageToBuffer(const tensorflow::protobuf::MessageLite& in, TF_Buffer* out); diff --git a/tensorflow/c/tf_datatype.h b/tensorflow/c/tf_datatype.h index df0c1fb45b0..1f5597fe99a 100644 --- a/tensorflow/c/tf_datatype.h +++ b/tensorflow/c/tf_datatype.h @@ -18,24 +18,7 @@ limitations under the License. #include -// Macro to control visibility of exported symbols in the shared library (.so, -// .dylib, .dll). -// This duplicates the TF_EXPORT macro definition in -// tensorflow/core/platform/macros.h in order to keep this .h file independent -// of any other includes. -#ifdef SWIG -#define TF_CAPI_EXPORT -#else -#if defined(_WIN32) -#ifdef TF_COMPILE_LIBRARY -#define TF_CAPI_EXPORT __declspec(dllexport) -#else -#define TF_CAPI_EXPORT __declspec(dllimport) -#endif // TF_COMPILE_LIBRARY -#else -#define TF_CAPI_EXPORT __attribute__((visibility("default"))) -#endif // _WIN32 -#endif // SWIG +#include "tensorflow/c/c_api_macros.h" #ifdef __cplusplus extern "C" { diff --git a/tensorflow/c/tf_status.h b/tensorflow/c/tf_status.h index db1d32bf8e7..22b237e16df 100644 --- a/tensorflow/c/tf_status.h +++ b/tensorflow/c/tf_status.h @@ -16,22 +16,9 @@ limitations under the License. #ifndef TENSORFLOW_C_TF_STATUS_H_ #define TENSORFLOW_C_TF_STATUS_H_ +#include "tensorflow/c/c_api_macros.h" #include "tensorflow/tsl/c/tsl_status.h" -#ifdef SWIG -#define TF_CAPI_EXPORT -#else -#if defined(_WIN32) -#ifdef TF_COMPILE_LIBRARY -#define TF_CAPI_EXPORT __declspec(dllexport) -#else -#define TF_CAPI_EXPORT __declspec(dllimport) -#endif // TF_COMPILE_LIBRARY -#else -#define TF_CAPI_EXPORT __attribute__((visibility("default"))) -#endif // _WIN32 -#endif // SWIG - #ifdef __cplusplus extern "C" { #endif diff --git a/tensorflow/c/tf_tensor.cc b/tensorflow/c/tf_tensor.cc index 90e19c55c89..d4efcaf50aa 100644 --- a/tensorflow/c/tf_tensor.cc +++ b/tensorflow/c/tf_tensor.cc @@ -183,7 +183,7 @@ void TF_TensorBitcastFrom(const TF_Tensor* from, TF_DataType type, *tensorflow::down_cast( from->tensor), static_cast(type), new_dims, num_new_dims)); - Set_TF_Status_from_Status(status, cc_status); + tsl::Set_TF_Status_from_Status(status, cc_status); } namespace tensorflow { diff --git a/tensorflow/c/tf_tensor.h b/tensorflow/c/tf_tensor.h index e8bef826599..05c74b8f342 100644 --- a/tensorflow/c/tf_tensor.h +++ b/tensorflow/c/tf_tensor.h @@ -23,25 +23,6 @@ limitations under the License. #include "tensorflow/c/tf_datatype.h" #include "tensorflow/c/tf_status.h" -// Macro to control visibility of exported symbols in the shared library (.so, -// .dylib, .dll). -// This duplicates the TF_EXPORT macro definition in -// tensorflow/core/platform/macros.h in order to keep this .h file independent -// of any other includes. -#ifdef SWIG -#define TF_CAPI_EXPORT -#else -#if defined(_WIN32) -#ifdef TF_COMPILE_LIBRARY -#define TF_CAPI_EXPORT __declspec(dllexport) -#else -#define TF_CAPI_EXPORT __declspec(dllimport) -#endif // TF_COMPILE_LIBRARY -#else -#define TF_CAPI_EXPORT __attribute__((visibility("default"))) -#endif // _WIN32 -#endif // SWIG - #ifdef __cplusplus extern "C" { #endif diff --git a/tensorflow/c/tf_tstring.h b/tensorflow/c/tf_tstring.h index f9fb2fe083f..876fd5f384f 100644 --- a/tensorflow/c/tf_tstring.h +++ b/tensorflow/c/tf_tstring.h @@ -15,23 +15,10 @@ limitations under the License. 
#ifndef TENSORFLOW_C_TF_TSTRING_H_ #define TENSORFLOW_C_TF_TSTRING_H_ +#include "tensorflow/c/c_api_macros.h" #include "tensorflow/c/tf_tensor.h" #include "tensorflow/core/platform/ctstring.h" -#ifdef SWIG -#define TF_CAPI_EXPORT -#else -#if defined(_WIN32) -#ifdef TF_COMPILE_LIBRARY -#define TF_CAPI_EXPORT __declspec(dllexport) -#else -#define TF_CAPI_EXPORT __declspec(dllimport) -#endif // TF_COMPILE_LIBRARY -#else -#define TF_CAPI_EXPORT __attribute__((visibility("default"))) -#endif // _WIN32 -#endif // SWIG - #ifdef __cplusplus extern "C" { #endif diff --git a/tensorflow/cc/client/client_session.cc b/tensorflow/cc/client/client_session.cc index 2ea322ffcb2..a6c7be07554 100644 --- a/tensorflow/cc/client/client_session.cc +++ b/tensorflow/cc/client/client_session.cc @@ -108,10 +108,14 @@ Status ClientSession::Run(const RunOptions& run_options, const FeedType& inputs, std::vector* outputs, RunMetadata* run_metadata) const { std::vector> feeds; + feeds.reserve(inputs.size()); for (auto const& feed : inputs) { TF_RETURN_IF_ERROR(feed.second.status); - feeds.emplace_back(feed.first.name(), feed.second.tensor); + feeds.emplace_back(std::piecewise_construct, + std::forward_as_tuple(feed.first.name()), + std::forward_as_tuple(feed.second.tensor)); } + std::vector output_tensor_names; output_tensor_names.reserve(fetch_outputs.size()); for (auto const& output : fetch_outputs) { diff --git a/tensorflow/cc/experimental/libtf/tests/function_test.cc b/tensorflow/cc/experimental/libtf/tests/function_test.cc index 226cbf2afa7..a9b4061f1a0 100644 --- a/tensorflow/cc/experimental/libtf/tests/function_test.cc +++ b/tensorflow/cc/experimental/libtf/tests/function_test.cc @@ -50,7 +50,7 @@ class FunctionTest impl::TaggedValueTensor CreateScalarTensor(T val) { AbstractTensorHandle* raw = nullptr; Status s = TestScalarTensorHandle(ctx_.get(), val, &raw); - CHECK_EQ(tensorflow::errors::OK, s.code()) << s.error_message(); + CHECK_EQ(tensorflow::errors::OK, s.code()) << s.message(); return impl::TaggedValueTensor(raw, /*add_ref=*/false); } @@ -64,12 +64,12 @@ class FunctionTest TF_StatusPtr status(TF_NewStatus()); TF_SetTracingImplementation(std::get<0>(GetParam()), status.get()); Status s = tensorflow::StatusFromTF_Status(status.get()); - CHECK_EQ(tensorflow::errors::OK, s.code()) << s.error_message(); + CHECK_EQ(tensorflow::errors::OK, s.code()) << s.message(); // Set the runtime impl, Core RT vs TFRT. 
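
Editor's note: the client_session.cc hunk earlier in this block reserves the feeds vector and switches to std::piecewise_construct, so each name/tensor pair is constructed in place rather than through a temporary pair. A self-contained sketch of the idiom, with a plain int payload standing in for Tensor:

#include <string>
#include <tuple>
#include <utility>
#include <vector>

int main() {
  std::vector<std::pair<std::string, int>> feeds;
  feeds.reserve(2);
  // Build each pair member in place: the string from a char*, the int from 42.
  feeds.emplace_back(std::piecewise_construct, std::forward_as_tuple("x:0"),
                     std::forward_as_tuple(42));
  feeds.emplace_back(std::piecewise_construct, std::forward_as_tuple("y:0"),
                     std::forward_as_tuple(7));
  return 0;
}
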
AbstractContext* ctx_raw = nullptr; s = BuildImmediateExecutionContext(UseTfrt(), &ctx_raw); - CHECK_EQ(tensorflow::errors::OK, s.code()) << s.error_message(); + CHECK_EQ(tensorflow::errors::OK, s.code()) << s.message(); ctx_.reset(ctx_raw); } }; @@ -139,7 +139,7 @@ template void ExpectEquals(AbstractTensorHandle* t, T expected) { TF_Tensor* result_t; Status s = tensorflow::GetValue(t, &result_t); - ASSERT_TRUE(s.ok()) << s.error_message(); + ASSERT_TRUE(s.ok()) << s.message(); auto value = static_cast(TF_TensorData(result_t)); EXPECT_EQ(*value, expected); TF_DeleteTensor(result_t); @@ -156,10 +156,10 @@ TEST_P(FunctionTest, Square) { PartialTensorShape unknown_shape; TaggedValue signature(unknown_shape, DT_FLOAT); Status s = tf_function.RegisterTrace(std::move(trace), signature, signature); - ASSERT_TRUE(s.ok()) << s.error_message(); + ASSERT_TRUE(s.ok()) << s.message(); TaggedValue args(std::move(x)); StatusOr v = tf_function.Execute(ctx_.get(), args); - ASSERT_TRUE(v.ok()) << v.status().error_message(); + ASSERT_TRUE(v.ok()) << v.status().message(); const TaggedValue& result = v.value(); AbstractTensorHandle* t = result.tensor().get(); ExpectEquals(t, 4.0f); @@ -178,12 +178,12 @@ TEST_P(FunctionTest, Add) { input_signature.tuple().emplace_back(tensor_spec); Status s = tf_function.RegisterTrace(std::move(trace), input_signature, tensor_spec); - ASSERT_TRUE(s.ok()) << s.error_message(); + ASSERT_TRUE(s.ok()) << s.message(); TaggedValue args = TaggedValue::Tuple(); args.tuple().emplace_back(TaggedValue(x)); args.tuple().emplace_back(TaggedValue(x)); StatusOr v = tf_function.Execute(ctx_.get(), args); - ASSERT_TRUE(v.ok()) << v.status().error_message(); + ASSERT_TRUE(v.ok()) << v.status().message(); const TaggedValue& result = v.value(); ExpectEquals(result.tensor().get(), 4.0f); } @@ -200,12 +200,12 @@ TEST_P(FunctionTest, IdentityN) { signature.tuple().emplace_back(tensor_spec); signature.tuple().emplace_back(tensor_spec); Status s = tf_function.RegisterTrace(std::move(trace), signature, signature); - ASSERT_TRUE(s.ok()) << s.error_message(); + ASSERT_TRUE(s.ok()) << s.message(); TaggedValue args = TaggedValue::Tuple(); args.tuple().emplace_back(TaggedValue(x)); args.tuple().emplace_back(TaggedValue(y)); StatusOr v = tf_function.Execute(ctx_.get(), args); - ASSERT_TRUE(v.ok()) << v.status().error_message(); + ASSERT_TRUE(v.ok()) << v.status().message(); const TaggedValue& result = v.value(); ExpectEquals(result.tuple()[0].tensor().get(), 2.0f); ExpectEquals(result.tuple()[1].tensor().get(), 4.0f); @@ -220,13 +220,13 @@ TEST_P(FunctionTest, UnaryFuncCalledWithMultipleArgsFails) { PartialTensorShape unknown_shape; TaggedValue signature(unknown_shape, DT_FLOAT); Status s = tf_function.RegisterTrace(std::move(trace), signature, signature); - ASSERT_TRUE(s.ok()) << s.error_message(); + ASSERT_TRUE(s.ok()) << s.message(); TaggedValue args = TaggedValue::Tuple(); args.tuple().emplace_back(TaggedValue(x)); args.tuple().emplace_back(TaggedValue(x)); StatusOr v = tf_function.Execute(ctx_.get(), args); ASSERT_TRUE(tensorflow::errors::IsInvalidArgument(v.status())); - ASSERT_TRUE(absl::StrContains(v.status().error_message(), "No match")); + ASSERT_TRUE(absl::StrContains(v.status().message(), "No match")); } TEST_P(FunctionTest, IncorrectArityOfOutputSignatureFails) { @@ -248,13 +248,13 @@ TEST_P(FunctionTest, IncorrectArityOfOutputSignatureFails) { TaggedValue output_signature(unknown_shape, DT_FLOAT); Status s = tf_function.RegisterTrace(std::move(trace), input_signature, output_signature); - 
ASSERT_TRUE(s.ok()) << s.error_message(); + ASSERT_TRUE(s.ok()) << s.message(); TaggedValue args = TaggedValue::Tuple(); args.tuple().emplace_back(TaggedValue(x)); args.tuple().emplace_back(TaggedValue(y)); StatusOr v = tf_function.Execute(ctx_.get(), args); ASSERT_TRUE(tensorflow::errors::IsInvalidArgument(v.status())) << v.status(); - ASSERT_TRUE(absl::StrContains(v.status().error_message(), + ASSERT_TRUE(absl::StrContains(v.status().message(), "Expecting 2 outputs, but *num_retvals is 1")); } @@ -273,15 +273,15 @@ TEST_P(FunctionTest, IncorrectDtypeInOutputSignatureFails) { TaggedValue output_tensor_spec(unknown_shape, tensorflow::DT_INT64); Status s = tf_function.RegisterTrace(std::move(trace), input_signature, output_tensor_spec); - ASSERT_TRUE(s.ok()) << s.error_message(); + ASSERT_TRUE(s.ok()) << s.message(); TaggedValue args = TaggedValue::Tuple(); args.tuple().emplace_back(TaggedValue(x)); args.tuple().emplace_back(TaggedValue(x)); StatusOr v = tf_function.Execute(ctx_.get(), args); ASSERT_TRUE(tensorflow::errors::IsInternal(v.status())) << v.status(); - ASSERT_TRUE(absl::StrContains(v.status().error_message(), - "Shape and dtype of tensor")); - ASSERT_TRUE(absl::StrContains(v.status().error_message(), + ASSERT_TRUE( + absl::StrContains(v.status().message(), "Shape and dtype of tensor")); + ASSERT_TRUE(absl::StrContains(v.status().message(), "does not match that in signature")); } diff --git a/tensorflow/cc/experimental/libtf/tests/tensor_test.cc b/tensorflow/cc/experimental/libtf/tests/tensor_test.cc index 0115d0ac50f..3f4708f0f0d 100644 --- a/tensorflow/cc/experimental/libtf/tests/tensor_test.cc +++ b/tensorflow/cc/experimental/libtf/tests/tensor_test.cc @@ -43,7 +43,7 @@ class UnifiedCAPI TF_StatusPtr status(TF_NewStatus()); TF_SetTracingImplementation(std::get<0>(GetParam()), status.get()); Status s = tensorflow::StatusFromTF_Status(status.get()); - CHECK_EQ(tensorflow::errors::OK, s.code()) << s.error_message(); + CHECK_EQ(tensorflow::errors::OK, s.code()) << s.message(); } }; @@ -52,7 +52,7 @@ template TaggedValue MakeContext(T runtime) { AbstractContext* ctx_raw = nullptr; Status s = BuildImmediateExecutionContext(runtime, &ctx_raw); - // ASSERT_EQ(tensorflow::errors::OK, s.code()) << s.error_message(); + // ASSERT_EQ(tensorflow::errors::OK, s.code()) << s.message(); return TaggedValue::Capsule(static_cast(ctx_raw), [](void* p) { tensorflow::internal::AbstractContextDeleter()( static_cast(p)); @@ -67,7 +67,7 @@ TEST_P(UnifiedCAPI, HoldTensors) { AbstractContext* ctx_raw = nullptr; Status s = BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw); - ASSERT_EQ(tensorflow::errors::OK, s.code()) << s.error_message(); + ASSERT_EQ(tensorflow::errors::OK, s.code()) << s.message(); ctx.reset(ctx_raw); } @@ -76,7 +76,7 @@ TEST_P(UnifiedCAPI, HoldTensors) { { AbstractTensorHandle* x_raw = nullptr; Status s = TestScalarTensorHandle(ctx.get(), 2.0f, &x_raw); - ASSERT_EQ(tensorflow::errors::OK, s.code()) << s.error_message(); + ASSERT_EQ(tensorflow::errors::OK, s.code()) << s.message(); x.reset(x_raw, false); } // Manually copy pointer so we can later compare the reference count. 
diff --git a/tensorflow/cc/experimental/libtf/tests/variable_test.cc b/tensorflow/cc/experimental/libtf/tests/variable_test.cc index 402943a58ca..8e7aca22bdc 100644 --- a/tensorflow/cc/experimental/libtf/tests/variable_test.cc +++ b/tensorflow/cc/experimental/libtf/tests/variable_test.cc @@ -48,7 +48,7 @@ class VariableTest impl::TaggedValueTensor CreateScalarTensor(T val) { AbstractTensorHandle* raw = nullptr; Status s = TestScalarTensorHandle(ctx_.get(), val, &raw); - CHECK_EQ(tensorflow::errors::OK, s.code()) << s.error_message(); + CHECK_EQ(tensorflow::errors::OK, s.code()) << s.message(); return impl::TaggedValueTensor(raw, /*add_ref=*/false); } @@ -62,12 +62,12 @@ class VariableTest TF_StatusPtr status(TF_NewStatus()); TF_SetTracingImplementation(std::get<0>(GetParam()), status.get()); Status s = tensorflow::StatusFromTF_Status(status.get()); - CHECK_EQ(tensorflow::errors::OK, s.code()) << s.error_message(); + CHECK_EQ(tensorflow::errors::OK, s.code()) << s.message(); // Set the runtime impl, Core RT vs TFRT. AbstractContext* ctx_raw = nullptr; s = BuildImmediateExecutionContext(UseTfrt(), &ctx_raw); - CHECK_EQ(tensorflow::errors::OK, s.code()) << s.error_message(); + CHECK_EQ(tensorflow::errors::OK, s.code()) << s.message(); ctx_.reset(ctx_raw); } }; @@ -76,7 +76,7 @@ template void ExpectEquals(AbstractTensorHandle* t, T expected) { TF_Tensor* result_t; Status s = tensorflow::GetValue(t, &result_t); - ASSERT_TRUE(s.ok()) << s.error_message(); + ASSERT_TRUE(s.ok()) << s.message(); auto value = static_cast(TF_TensorData(result_t)); EXPECT_EQ(*value, expected); TF_DeleteTensor(result_t); @@ -89,7 +89,7 @@ TEST_P(VariableTest, CreateAssignReadDestroy) { AbstractTensorHandle* var_ptr = nullptr; PartialTensorShape scalar_shape; TF_EXPECT_OK( - PartialTensorShape::MakePartialShape({}, 0, &scalar_shape)); + PartialTensorShape::MakePartialShape({}, 0, &scalar_shape)); TF_EXPECT_OK(tensorflow::ops::VarHandleOp(ctx_.get(), &var_ptr, DT_FLOAT, scalar_shape)); var.reset(var_ptr); diff --git a/tensorflow/cc/framework/cc_ops_test.cc b/tensorflow/cc/framework/cc_ops_test.cc index 178b4da972a..4c978da32ea 100644 --- a/tensorflow/cc/framework/cc_ops_test.cc +++ b/tensorflow/cc/framework/cc_ops_test.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ +#include + #include "tensorflow/cc/client/client_session.h" #include "tensorflow/cc/framework/testutil.h" #include "tensorflow/cc/ops/standard_ops.h" @@ -241,7 +243,7 @@ TEST(CCOpTest, InvalidFinalize) { ops::ReaderReadUpTo(root, Variable(root, {}, DT_STRING), Variable(root, {}, DT_STRING), static_cast(2)); EXPECT_FALSE(root.status().ok()); - auto err_msg = root.status().error_message(); + auto err_msg = std::string(root.status().message()); EXPECT_NE(err_msg.find("'num_records' passed int32 expected int64"), string::npos); } diff --git a/tensorflow/cc/framework/gradients_test.cc b/tensorflow/cc/framework/gradients_test.cc index 75291678177..2256d795422 100644 --- a/tensorflow/cc/framework/gradients_test.cc +++ b/tensorflow/cc/framework/gradients_test.cc @@ -459,7 +459,7 @@ TEST_F(GradientsTest, UnreachableInput) { Status status = AddSymbolicGradients(scope_test_, {m1}, {z}, {dm1}, &grad_outputs); EXPECT_EQ(status.code(), error::INVALID_ARGUMENT); - EXPECT_EQ(status.error_message(), + EXPECT_EQ(status.message(), "Cannot compute the partial derivative" " for node 'z' as it's unreachable from the output node(s)."); } diff --git a/tensorflow/cc/ops/while_loop_test.cc b/tensorflow/cc/ops/while_loop_test.cc index 18b8be3794f..1e9338eb0c2 100644 --- a/tensorflow/cc/ops/while_loop_test.cc +++ b/tensorflow/cc/ops/while_loop_test.cc @@ -42,7 +42,7 @@ class WhileLoopTest : public ::testing::Test { Status s = ops::BuildWhileLoop(scope_, inputs_, cond, body, kFrameName, &outputs_); EXPECT_EQ(s.code(), error_code); - EXPECT_EQ(s.error_message(), error_msg); + EXPECT_EQ(s.message(), error_msg); } template diff --git a/tensorflow/cc/saved_model/BUILD b/tensorflow/cc/saved_model/BUILD index a3e4c20b7c9..d52db030b1b 100644 --- a/tensorflow/cc/saved_model/BUILD +++ b/tensorflow/cc/saved_model/BUILD @@ -26,7 +26,9 @@ package( licenses = ["notice"], ) -exports_files(["loader.h"]) +exports_files([ + "loader.h", +]) cc_library( name = "constants", @@ -58,9 +60,9 @@ cc_library( hdrs = ["reader.h"], deps = [ ":constants", - "//tensorflow/core:protos_all_cc", ":metrics", ":util", + "//tensorflow/core:protos_all_cc", ] + if_not_mobile([ # TODO(b/111634734): :lib and :protos_all contain dependencies that # cannot be built on mobile platforms. Instead, include the appropriate @@ -158,6 +160,7 @@ cc_library( "//tensorflow/core/util/tensor_bundle", "//tensorflow/core/util/tensor_bundle:byteswaptensor", "@com_google_absl//absl/container:flat_hash_set", + "@jsoncpp_git//:jsoncpp", ], ) @@ -186,6 +189,11 @@ tf_cc_test( "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core/platform:test", + "//tensorflow/core/protobuf:for_core_protos_cc", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@jsoncpp_git//:jsoncpp", ], ) @@ -331,7 +339,12 @@ cc_library( "//tensorflow/python:__pkg__", "//tensorflow/security/fuzzing/cc/ops:__pkg__", # TODO(b/261455394): Remove. 
], - deps = if_not_mobile(["//tensorflow/core:lib"]) + if_android(["//tensorflow/core:portable_tensorflow_lib_lite"]), + deps = [ + "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@jsoncpp_git//:jsoncpp", + ] + if_not_mobile(["//tensorflow/core:lib"]) + if_android(["//tensorflow/core:portable_tensorflow_lib_lite"]), alwayslink = True, ) @@ -341,7 +354,11 @@ cc_library( visibility = ["//tensorflow/python/saved_model:__subpackages__"], deps = if_static([ ":metrics_impl", - ]) + if_not_mobile(["//tensorflow/core:lib"]) + if_android(["//tensorflow/core:portable_tensorflow_lib_lite"]), + ]) + if_not_mobile(["//tensorflow/core:lib"]) + if_android(["//tensorflow/core:portable_tensorflow_lib_lite"]) + [ + "//tensorflow/core/protobuf:for_core_protos_cc", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + ], ) tf_cc_test( @@ -350,9 +367,10 @@ tf_cc_test( srcs = ["metrics_test.cc"], deps = [ ":metrics", - "//tensorflow/core:lib", "//tensorflow/core:test", "//tensorflow/core:test_main", + "@com_google_googletest//:gtest_main", + "@jsoncpp_git//:jsoncpp", ], ) @@ -392,14 +410,14 @@ cc_library( ], deps = [ ":constants", + "//tensorflow/core:protos_all_cc", "//tensorflow/core/graph/regularization:simple_delete", "//tensorflow/core/graph/regularization:util", - "//tensorflow/core:protos_all_cc", "//tensorflow/core/util/tensor_bundle:naming", "//tensorflow/tsl/platform:types", - "@com_google_protobuf//:protobuf_headers", "@com_google_absl//absl/container:btree", "@com_google_absl//absl/strings", + "@com_google_protobuf//:protobuf_headers", ] + if_not_mobile(["//tensorflow/core:lib"]) + if_android(["//tensorflow/core:portable_tensorflow_lib_lite"]), alwayslink = True, ) @@ -407,7 +425,12 @@ cc_library( cc_library( name = "fingerprinting", hdrs = ["fingerprinting.h"], - visibility = ["//tensorflow/python/saved_model:__subpackages__"], + visibility = [ + "//learning/brain/contrib/hub/server/distro:__subpackages__", + "//learning/brain/contrib/tpu_modeling:__subpackages__", + "//learning/tfx/pipeline/util:__subpackages__", + "//tensorflow/python/saved_model:__subpackages__", + ], deps = if_static([ ":fingerprinting_impl", "@com_google_absl//absl/strings", diff --git a/tensorflow/cc/saved_model/bundle_v2.cc b/tensorflow/cc/saved_model/bundle_v2.cc index 90f220575c5..21692edbf40 100644 --- a/tensorflow/cc/saved_model/bundle_v2.cc +++ b/tensorflow/cc/saved_model/bundle_v2.cc @@ -73,8 +73,8 @@ Status ReadSavedModelProto(const string& export_dir, Status err; if (found_pb.code() == found_pbtxt.code()) { - err = Status(found_pb.code(), StrCat(found_pb.error_message(), "\n", - found_pbtxt.error_message())); + err = Status(found_pb.code(), + StrCat(found_pb.message(), "\n", found_pbtxt.message())); } else if (found_pb.code() == NOT_FOUND) { err = found_pbtxt; } else if (found_pbtxt.code() == NOT_FOUND) { @@ -171,11 +171,17 @@ Status SavedModelV2Bundle::Load(const std::string& export_dir, // Read the fingerprint. auto fingerprint_proto = saved_model::fingerprinting::ReadSavedModelFingerprint(export_dir); + std::string singleprint = ""; if (fingerprint_proto.ok()) { - // Set gauge cell with saved_model_checksum. 
metrics::SavedModelReadFingerprint().Set( - std::to_string(fingerprint_proto->saved_model_checksum())); + metrics::MakeFingerprintJson(fingerprint_proto.value())); + + singleprint = + saved_model::fingerprinting::Singleprint(fingerprint_proto.value()); } + + metrics::SavedModelReadPathAndSingleprint().Set( + metrics::MakeSavedModelPathAndSingleprint(export_dir, singleprint)); return OkStatus(); } diff --git a/tensorflow/cc/saved_model/bundle_v2.h b/tensorflow/cc/saved_model/bundle_v2.h index 76e6ce20e70..e199bd1cc5d 100644 --- a/tensorflow/cc/saved_model/bundle_v2.h +++ b/tensorflow/cc/saved_model/bundle_v2.h @@ -25,8 +25,8 @@ limitations under the License. #include #include "absl/container/flat_hash_set.h" +#include "tensorflow/core/framework/graph_debug_info.pb.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/protobuf/graph_debug_info.pb.h" #include "tensorflow/core/protobuf/meta_graph.pb.h" #include "tensorflow/core/protobuf/saved_object_graph.pb.h" #include "tensorflow/core/protobuf/trackable_object_graph.pb.h" diff --git a/tensorflow/cc/saved_model/bundle_v2_test.cc b/tensorflow/cc/saved_model/bundle_v2_test.cc index f6434914455..6dc3be0bf56 100644 --- a/tensorflow/cc/saved_model/bundle_v2_test.cc +++ b/tensorflow/cc/saved_model/bundle_v2_test.cc @@ -17,19 +17,25 @@ limitations under the License. #include #include +#include #include +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "json/json.h" +#include "json/reader.h" +#include "json/value.h" #include "tensorflow/cc/saved_model/metrics.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/path.h" #include "tensorflow/core/platform/test.h" +#include "tensorflow/core/protobuf/fingerprint.pb.h" +#include "tensorflow/tsl/lib/core/status_test_util.h" namespace tensorflow { namespace { constexpr char kTestData[] = "cc/saved_model/testdata"; -// This is the value in testdata/VarsAndArithmeticObjectGraph/fingerprint.pb -constexpr char kV2ModuleSavedModelChecksum[] = "15788619162413586750"; class BundleV2Test : public ::testing::Test { protected: @@ -116,10 +122,33 @@ TEST_F(BundleV2Test, UpdatesMetrics) { EXPECT_EQ(metrics::SavedModelReadCount("2").value(), read_count + 1); EXPECT_EQ(metrics::SavedModelReadApi(kCCLoadBundleV2Label).value(), api_count + 1); - // Check that the gauge contains the fingerprint. - EXPECT_EQ(metrics::SavedModelReadFingerprint().value(), - kV2ModuleSavedModelChecksum); + // Check that the gauge contains the path and fingerprint. 
EXPECT_EQ(metrics::SavedModelReadPath().value(), export_dir); + + Json::Value fingerprint = Json::objectValue; + Json::Reader reader = Json::Reader(); + reader.parse(metrics::SavedModelReadFingerprint().value(), fingerprint); + EXPECT_EQ(fingerprint["saved_model_checksum"].asUInt64(), + 15788619162413586750ULL); + EXPECT_EQ(fingerprint["graph_def_program_hash"].asUInt64(), + 706963557435316516ULL); + EXPECT_EQ(fingerprint["signature_def_hash"].asUInt64(), + 5693392539583495303ULL); + EXPECT_EQ(fingerprint["saved_object_graph_hash"].asUInt64(), + 12074714563970609759ULL); + EXPECT_EQ(fingerprint["checkpoint_hash"].asUInt64(), 10788359570789890102ULL); + + // TODO(adamcogdell): add ASSERT_OK_AND_ASSIGN here after migrating + // cc/saved_model code from the tsl version of StatusOr to absl::StatusOr + auto [path, singleprint] = metrics::ParseSavedModelPathAndSingleprint( + metrics::SavedModelReadPathAndSingleprint().value()); + EXPECT_TRUE(absl::StrContains( + path, absl::StrCat(kTestData, "/VarsAndArithmeticObjectGraph"))); + EXPECT_EQ(singleprint, + "706963557435316516/" // graph_def_program_hash + "5693392539583495303/" // signature_def_hash + "12074714563970609759/" // saved_object_graph_hash + "10788359570789890102"); // checkpoint_hash } } // namespace diff --git a/tensorflow/cc/saved_model/fingerprinting.cc b/tensorflow/cc/saved_model/fingerprinting.cc index 0bb064c107f..389b28bf278 100644 --- a/tensorflow/cc/saved_model/fingerprinting.cc +++ b/tensorflow/cc/saved_model/fingerprinting.cc @@ -152,16 +152,14 @@ StatusOr ReadSavedModelFingerprint( const string fingerprint_pb_path = io::JoinPath(export_dir, kFingerprintFilenamePb); Status found_pb = Env::Default()->FileExists(fingerprint_pb_path); - if (found_pb.ok()) { - FingerprintDef fingerprint_proto; - Status result = ReadBinaryProto(Env::Default(), fingerprint_pb_path, - &fingerprint_proto); - if (result.ok()) { - return fingerprint_proto; - } - return result; - } - return found_pb; + if (!found_pb.ok()) return found_pb; + + FingerprintDef fingerprint_proto; + Status result = + ReadBinaryProto(Env::Default(), fingerprint_pb_path, &fingerprint_proto); + if (!result.ok()) return result; + + return fingerprint_proto; } std::string Singleprint(uint64 graph_def_program_hash, diff --git a/tensorflow/cc/saved_model/fingerprinting_test.cc b/tensorflow/cc/saved_model/fingerprinting_test.cc index 7e298cfc844..1c1e12440d6 100644 --- a/tensorflow/cc/saved_model/fingerprinting_test.cc +++ b/tensorflow/cc/saved_model/fingerprinting_test.cc @@ -136,7 +136,8 @@ TEST(FingerprintingTest, TestReadValidFingerprint) { TEST(FingerprintingTest, TestReadNonexistentFingerprint) { const std::string export_dir = io::JoinPath( testing::TensorFlowSrcRoot(), "cc/saved_model/testdata", "AssetModule"); - EXPECT_FALSE(ReadSavedModelFingerprint(export_dir).ok()); + EXPECT_EQ(ReadSavedModelFingerprint(export_dir).status().code(), + absl::StatusCode::kNotFound); } TEST(FingerprintingTest, TestSingleprint) { diff --git a/tensorflow/cc/saved_model/loader.cc b/tensorflow/cc/saved_model/loader.cc index 75869afe687..b9544bc7555 100644 --- a/tensorflow/cc/saved_model/loader.cc +++ b/tensorflow/cc/saved_model/loader.cc @@ -26,6 +26,7 @@ limitations under the License. 
#include "tensorflow/cc/saved_model/util.h" #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/core/framework/graph_debug_info.pb.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/op_def.pb.h" #include "tensorflow/core/framework/tensor.pb.h" @@ -38,7 +39,6 @@ limitations under the License. #include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/file_system_helper.h" #include "tensorflow/core/platform/statusor.h" -#include "tensorflow/core/protobuf/graph_debug_info.pb.h" #include "tensorflow/core/protobuf/meta_graph.pb.h" #include "tensorflow/core/protobuf/saver.pb.h" #include "tensorflow/core/public/session.h" diff --git a/tensorflow/cc/saved_model/loader.h b/tensorflow/cc/saved_model/loader.h index 9d43f4ecc76..f2d318a25b7 100644 --- a/tensorflow/cc/saved_model/loader.h +++ b/tensorflow/cc/saved_model/loader.h @@ -21,8 +21,8 @@ limitations under the License. #include #include +#include "tensorflow/core/framework/graph_debug_info.pb.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/protobuf/graph_debug_info.pb.h" #include "tensorflow/core/protobuf/meta_graph.pb.h" #include "tensorflow/core/public/session.h" diff --git a/tensorflow/cc/saved_model/metrics.cc b/tensorflow/cc/saved_model/metrics.cc index 86ff72a7839..f44abe8b659 100644 --- a/tensorflow/cc/saved_model/metrics.cc +++ b/tensorflow/cc/saved_model/metrics.cc @@ -16,10 +16,15 @@ limitations under the License. #include "tensorflow/cc/saved_model/metrics.h" #include +#include +#include "json/config.h" +#include "json/json.h" +#include "json/writer.h" #include "tensorflow/core/lib/monitoring/counter.h" #include "tensorflow/core/lib/monitoring/gauge.h" #include "tensorflow/core/lib/monitoring/sampler.h" +#include "tensorflow/core/protobuf/fingerprint.pb.h" namespace tensorflow { namespace metrics { @@ -64,6 +69,17 @@ auto* saved_model_write_path = monitoring::Gauge::New( "/tensorflow/core/saved_model/write/path", "The path (saved_model_path) of the exported SavedModel."); +// Gauge that contains the path (saved_model_path) and the singleprint +// (concatenation of graph_def_program_hash, signature_def_hash, +// saved_object_graph_hash, and checkpoint_hash) of the newly written +// SavedModel. +auto* saved_model_write_path_and_singleprint = + monitoring::Gauge::New( + "/tensorflow/core/saved_model/write/path_and_singleprint", + "The path (saved_model_path) and singleprint (concatenation of " + "graph_def_program_hash, signature_def_hash, saved_object_graph_hash, " + "and checkpoint_hash) of the newly written SavedModel."); + // Gauge that contains the fingerprint (saved_model_checksum) of the loaded // SavedModel. auto* saved_model_read_fingerprint = monitoring::Gauge::New( @@ -75,6 +91,15 @@ auto* saved_model_read_path = monitoring::Gauge::New( "/tensorflow/core/saved_model/read/path", "The path (saved_model_path) of the loaded SavedModel."); +// Gauge that contains the path (saved_model_path) and the singleprint +// (concatenation of graph_def_program_hash, signature_def_hash, +// saved_object_graph_hash, and checkpoint_hash) of the loaded SavedModel. 
+auto* saved_model_read_path_and_singleprint = monitoring::Gauge::New( + "/tensorflow/core/saved_model/read/path_and_singleprint", + "The path (saved_model_path) and singleprint (concatenation of " + "graph_def_program_hash, signature_def_hash, saved_object_graph_hash, " + "and checkpoint_hash) of the loaded SavedModel."); + // Distribution of checkpoint write durations. auto* checkpoint_write_durations = monitoring::Sampler<1>::New( { @@ -153,6 +178,10 @@ monitoring::GaugeCell& SavedModelReadPath() { return *saved_model_read_path->GetCell(); } +monitoring::GaugeCell& SavedModelReadPathAndSingleprint() { + return *saved_model_read_path_and_singleprint->GetCell(); +} + monitoring::GaugeCell& SavedModelWriteFingerprint() { return *saved_model_write_fingerprint->GetCell(); } @@ -161,6 +190,41 @@ monitoring::GaugeCell& SavedModelWritePath() { return *saved_model_write_path->GetCell(); } +monitoring::GaugeCell& SavedModelWritePathAndSingleprint() { + return *saved_model_write_path_and_singleprint->GetCell(); +} + +string MakeFingerprintJson(FingerprintDef fingerprint_serialized) { + Json::Value fingerprint = Json::objectValue; + fingerprint["saved_model_checksum"] = + Json::UInt64(fingerprint_serialized.saved_model_checksum()); + fingerprint["graph_def_program_hash"] = + Json::UInt64(fingerprint_serialized.graph_def_program_hash()); + fingerprint["signature_def_hash"] = + Json::UInt64(fingerprint_serialized.signature_def_hash()); + fingerprint["saved_object_graph_hash"] = + Json::UInt64(fingerprint_serialized.saved_object_graph_hash()); + fingerprint["checkpoint_hash"] = + Json::UInt64(fingerprint_serialized.checkpoint_hash()); + + Json::StreamWriterBuilder json_factory; + return Json::writeString(json_factory, fingerprint); +} + +string MakeSavedModelPathAndSingleprint(string path, string singleprint) { + return absl::StrCat(path, ":", singleprint); +} + +std::pair ParseSavedModelPathAndSingleprint( + string path_and_singleprint) { + size_t delimiter = path_and_singleprint.rfind(':'); + if (delimiter == std::string::npos) { + return std::pair("", ""); + } + return std::pair(path_and_singleprint.substr(0, delimiter), + path_and_singleprint.substr(delimiter + 1)); +} + monitoring::SamplerCell& CheckpointReadDuration(absl::string_view api_label) { return *checkpoint_read_durations->GetCell(std::string(api_label)); } diff --git a/tensorflow/cc/saved_model/metrics.h b/tensorflow/cc/saved_model/metrics.h index f89374af0fa..c39b9c3bc8f 100644 --- a/tensorflow/cc/saved_model/metrics.h +++ b/tensorflow/cc/saved_model/metrics.h @@ -20,11 +20,13 @@ limitations under the License. #ifndef TENSORFLOW_CC_SAVED_MODEL_METRICS_H_ #define TENSORFLOW_CC_SAVED_MODEL_METRICS_H_ -#include +#include +#include "absl/status/status.h" #include "tensorflow/core/lib/monitoring/counter.h" #include "tensorflow/core/lib/monitoring/gauge.h" #include "tensorflow/core/lib/monitoring/sampler.h" +#include "tensorflow/core/protobuf/fingerprint.pb.h" namespace tensorflow { namespace metrics { @@ -49,6 +51,12 @@ monitoring::GaugeCell& SavedModelWriteFingerprint(); // the saved_model_path of the SM when it is exported. monitoring::GaugeCell& SavedModelWritePath(); +// Returns "/tensorflow/core/saved_model/write/path_and_fingerprint" cell, which +// contains the path (saved_model_path) and fingerprint (concatenation of +// graph_def_program_hash, signature_def_hash, saved_object_graph_hash, +// and checkpoint_hash) of the SavedModel when it is exported. 
+monitoring::GaugeCell& SavedModelWritePathAndSingleprint(); + // Returns "/tensorflow/core/saved_model/read/fingerprint" cell, wich contains // the saved_model_checksum of the SM's fingerprint when it is imported. monitoring::GaugeCell& SavedModelReadFingerprint(); @@ -57,6 +65,24 @@ monitoring::GaugeCell& SavedModelReadFingerprint(); // the saved_model_path of the SM when it is imported. monitoring::GaugeCell& SavedModelReadPath(); +// Returns "/tensorflow/core/saved_model/read/path_and_singleprint" cell, which +// contains the path (saved_model_path) and singleprint (concatenation of +// graph_def_program_hash, signature_def_hash, saved_object_graph_hash, +// and checkpoint_hash) of the SavedModel when it is imported. +monitoring::GaugeCell& SavedModelReadPathAndSingleprint(); + +// Returns the fingerprint as a Json string. +string MakeFingerprintJson(FingerprintDef fingerprint_serialized); + +// Returns canonical string concatenation of path and singleprint. +string MakeSavedModelPathAndSingleprint(string path, string singleprint); + +// TODO(adamcogdell): change to StatusOr<> to account for missing delimiter +// Returns path and singleprint as a pair, parsed canonically from the string +// metric. +std::pair ParseSavedModelPathAndSingleprint( + string path_and_singleprint); + +// Returns "/tensorflow/core/saved_model/write/api" cell. This metric has 1 // field "api_label" which corresponds to a SavedModel write API. The cell for // `foo` should be incremented when the write API `foo` is called. diff --git a/tensorflow/cc/saved_model/metrics_test.cc b/tensorflow/cc/saved_model/metrics_test.cc index 0f876040b44..b4901c942c1 100644 --- a/tensorflow/cc/saved_model/metrics_test.cc +++ b/tensorflow/cc/saved_model/metrics_test.cc @@ -15,6 +15,10 @@ limitations under the License.
#include "tensorflow/cc/saved_model/metrics.h" +#include +#include +#include "json/json.h" +#include "json/reader.h" #include "tensorflow/core/platform/test.h" namespace tensorflow { @@ -89,6 +93,17 @@ TEST(MetricsTest, TestWritePath) { EXPECT_EQ(SavedModelWritePath().value(), "bar"); } +TEST(MetricsTest, TestWritePathAndSingleprint) { + EXPECT_EQ(SavedModelWritePathAndSingleprint().value(), ""); + SavedModelWritePathAndSingleprint().Set("foo"); + EXPECT_EQ(SavedModelWritePathAndSingleprint().value(), "foo"); + SavedModelWritePathAndSingleprint().Set("bar"); + EXPECT_EQ(SavedModelWritePathAndSingleprint().value(), "bar"); + + EXPECT_EQ(MakeSavedModelPathAndSingleprint("path", "singleprint"), + "path:singleprint"); +} + TEST(MetricsTest, TestReadFingerprint) { EXPECT_EQ(SavedModelReadFingerprint().value(), ""); SavedModelReadFingerprint().Set("foo"); @@ -105,5 +120,44 @@ TEST(MetricsTest, TestReadPath) { EXPECT_EQ(SavedModelReadPath().value(), "bar"); } +TEST(MetricsTest, TestReadPathAndSingleprint) { + EXPECT_EQ(SavedModelReadPathAndSingleprint().value(), ""); + SavedModelReadPathAndSingleprint().Set("foo"); + EXPECT_EQ(SavedModelReadPathAndSingleprint().value(), "foo"); + SavedModelReadPathAndSingleprint().Set("bar"); + EXPECT_EQ(SavedModelReadPathAndSingleprint().value(), "bar"); + + auto [path, singleprint] = + ParseSavedModelPathAndSingleprint("path/model:name:singleprint"); + EXPECT_EQ(path, "path/model:name"); + EXPECT_EQ(singleprint, "singleprint"); +} + +TEST(MetricsTest, TestMakeFingerprintJson) { + FingerprintDef fingerprint; + fingerprint.set_saved_model_checksum(1); + fingerprint.set_graph_def_program_hash(2); + fingerprint.set_signature_def_hash(3); + fingerprint.set_saved_object_graph_hash(4); + fingerprint.set_checkpoint_hash(5); + + string serialized_fingerprint_json = MakeFingerprintJson(fingerprint); + + EXPECT_EQ( + serialized_fingerprint_json, + "{\n\t\"checkpoint_hash\" : 5,\n\t\"graph_def_program_hash\" : " + "2,\n\t\"saved_model_checksum\" : 1,\n\t\"saved_object_graph_hash\" : " + "4,\n\t\"signature_def_hash\" : 3\n}"); + + Json::Value fingerprint_json = Json::objectValue; + Json::Reader reader = Json::Reader(); + reader.parse(serialized_fingerprint_json, fingerprint_json); + EXPECT_EQ(fingerprint_json["saved_model_checksum"].asUInt64(), 1); + EXPECT_EQ(fingerprint_json["graph_def_program_hash"].asUInt64(), 2); + EXPECT_EQ(fingerprint_json["signature_def_hash"].asUInt64(), 3); + EXPECT_EQ(fingerprint_json["saved_object_graph_hash"].asUInt64(), 4); + EXPECT_EQ(fingerprint_json["checkpoint_hash"].asUInt64(), 5); +} + } // namespace metrics } // namespace tensorflow diff --git a/tensorflow/cc/saved_model/reader.h b/tensorflow/cc/saved_model/reader.h index 2c2cb865b93..f51fbeb557f 100644 --- a/tensorflow/cc/saved_model/reader.h +++ b/tensorflow/cc/saved_model/reader.h @@ -21,8 +21,8 @@ limitations under the License. 
#include #include +#include "tensorflow/core/framework/graph_debug_info.pb.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/protobuf/graph_debug_info.pb.h" #include "tensorflow/core/protobuf/meta_graph.pb.h" namespace tensorflow { diff --git a/tensorflow/cc/saved_model/reader_test.cc b/tensorflow/cc/saved_model/reader_test.cc index 443c04efe45..4b8b5cde20d 100644 --- a/tensorflow/cc/saved_model/reader_test.cc +++ b/tensorflow/cc/saved_model/reader_test.cc @@ -71,9 +71,9 @@ TEST_F(ReaderTest, NoTagMatch) { &meta_graph_def); EXPECT_FALSE(st.ok()); EXPECT_TRUE(absl::StrContains( - st.error_message(), + st.message(), "Could not find meta graph def matching supplied tags: { missing-tag }")) - << st.error_message(); + << st.message(); } TEST_F(ReaderTest, NoTagMatchMultiple) { @@ -84,9 +84,8 @@ TEST_F(ReaderTest, NoTagMatchMultiple) { export_dir, {kSavedModelTagServe, "missing-tag"}, &meta_graph_def); EXPECT_FALSE(st.ok()); EXPECT_TRUE(absl::StrContains( - st.error_message(), - "Could not find meta graph def matching supplied tags: ")) - << st.error_message(); + st.message(), "Could not find meta graph def matching supplied tags: ")) + << st.message(); } TEST_F(ReaderTest, PbtxtFormat) { diff --git a/tensorflow/cc/saved_model/saved_model_bundle_lite_test.cc b/tensorflow/cc/saved_model/saved_model_bundle_lite_test.cc index 604fc412800..5d3690ea1a5 100644 --- a/tensorflow/cc/saved_model/saved_model_bundle_lite_test.cc +++ b/tensorflow/cc/saved_model/saved_model_bundle_lite_test.cc @@ -159,9 +159,9 @@ TEST_F(LoaderTest, NoTagMatch) { {"missing-tag"}, &bundle); EXPECT_FALSE(st.ok()); EXPECT_TRUE(absl::StrContains( - st.error_message(), + st.message(), "Could not find meta graph def matching supplied tags: { missing-tag }")) - << st.error_message(); + << st.message(); } TEST_F(LoaderTest, NoTagMatchMultiple) { @@ -175,9 +175,8 @@ TEST_F(LoaderTest, NoTagMatchMultiple) { {kSavedModelTagServe, "missing-tag"}, &bundle); EXPECT_FALSE(st.ok()); EXPECT_TRUE(absl::StrContains( - st.error_message(), - "Could not find meta graph def matching supplied tags: ")) - << st.error_message(); + st.message(), "Could not find meta graph def matching supplied tags: ")) + << st.message(); } TEST_F(LoaderTest, SessionCreationFailure) { @@ -194,8 +193,7 @@ TEST_F(LoaderTest, SessionCreationFailure) { Status st = LoadSavedModel(session_options, run_options, export_dir, {kSavedModelTagServe}, &bundle); EXPECT_FALSE(st.ok()); - EXPECT_TRUE(absl::StrContains(st.error_message(), kInvalidTarget)) - << st.error_message(); + EXPECT_TRUE(absl::StrContains(st.message(), kInvalidTarget)) << st.message(); } TEST_F(LoaderTest, PbtxtFormat) { diff --git a/tensorflow/cc/saved_model/saved_model_bundle_test.cc b/tensorflow/cc/saved_model/saved_model_bundle_test.cc index 7e78aee67b1..eda63fba4fe 100644 --- a/tensorflow/cc/saved_model/saved_model_bundle_test.cc +++ b/tensorflow/cc/saved_model/saved_model_bundle_test.cc @@ -189,9 +189,9 @@ TEST_F(LoaderTest, NoTagMatch) { {"missing-tag"}, &bundle); EXPECT_FALSE(st.ok()); EXPECT_TRUE(absl::StrContains( - st.error_message(), + st.message(), "Could not find meta graph def matching supplied tags: { missing-tag }")) - << st.error_message(); + << st.message(); } TEST_F(LoaderTest, NoTagMatchMultiple) { @@ -205,9 +205,8 @@ TEST_F(LoaderTest, NoTagMatchMultiple) { {kSavedModelTagServe, "missing-tag"}, &bundle); EXPECT_FALSE(st.ok()); EXPECT_TRUE(absl::StrContains( - st.error_message(), - "Could not find meta graph def matching supplied tags: ")) - << st.error_message(); + 
st.message(), "Could not find meta graph def matching supplied tags: ")) + << st.message(); } TEST_F(LoaderTest, SessionCreationFailure) { @@ -224,8 +223,7 @@ TEST_F(LoaderTest, SessionCreationFailure) { Status st = LoadSavedModel(session_options, run_options, export_dir, {kSavedModelTagServe}, &bundle); EXPECT_FALSE(st.ok()); - EXPECT_TRUE(absl::StrContains(st.error_message(), kInvalidTarget)) - << st.error_message(); + EXPECT_TRUE(absl::StrContains(st.message(), kInvalidTarget)) << st.message(); } TEST_F(LoaderTest, PbtxtFormat) { @@ -317,9 +315,8 @@ TEST_F(LoaderTest, NegativeShapeDimension) { Status st = LoadSavedModel(session_options, run_options, export_dir, {kSavedModelTagServe}, &bundle); EXPECT_FALSE(st.ok()); - EXPECT_NE( - st.error_message().find("initializes from a tensor with -1 elements"), - std::string::npos); + EXPECT_NE(st.message().find("initializes from a tensor with -1 elements"), + std::string::npos); } TEST_F(LoaderTest, ConstNoValue) { @@ -332,9 +329,8 @@ TEST_F(LoaderTest, ConstNoValue) { Status st = LoadSavedModel(session_options, run_options, export_dir, {kSavedModelTagServe}, &bundle); EXPECT_FALSE(st.ok()); - EXPECT_NE( - st.error_message().find("constant tensor but no value has been provided"), - std::string::npos); + EXPECT_NE(st.message().find("constant tensor but no value has been provided"), + std::string::npos); } TEST_F(LoaderTest, BadNodeAttr) { @@ -347,9 +343,8 @@ TEST_F(LoaderTest, BadNodeAttr) { Status st = LoadSavedModel(session_options, run_options, export_dir, {kSavedModelTagServe}, &bundle); EXPECT_FALSE(st.ok()); - EXPECT_NE( - st.error_message().find("constant tensor but no value has been provided"), - std::string::npos); + EXPECT_NE(st.message().find("constant tensor but no value has been provided"), + std::string::npos); } TEST_F(LoaderTest, UpdateMetricsV2) { diff --git a/tensorflow/cc/training/coordinator.h b/tensorflow/cc/training/coordinator.h index ca2b5f956bf..016b27101b9 100644 --- a/tensorflow/cc/training/coordinator.h +++ b/tensorflow/cc/training/coordinator.h @@ -37,7 +37,8 @@ class RunnerInterface { virtual ~RunnerInterface() {} virtual Status Join() = 0; virtual Status ExportCostGraph(CostGraphDef* cost_graph) const { - return Status(error::INVALID_ARGUMENT, "No cost model to export."); + return Status(absl::StatusCode::kInvalidArgument, + "No cost model to export."); } /// Returns true iff the runner is running, i.e. if it is trying to populate /// its queue. 
diff --git a/tensorflow/cc/training/coordinator_test.cc b/tensorflow/cc/training/coordinator_test.cc index 75e0da6f8f0..75793297ddd 100644 --- a/tensorflow/cc/training/coordinator_test.cc +++ b/tensorflow/cc/training/coordinator_test.cc @@ -179,21 +179,24 @@ TEST(CoordinatorTest, StatusReporting) { BlockingCounter counter(3); std::unique_ptr qr1(new MockQueueRunner(&coord)); - qr1->StartSettingStatus(Status(Code::CANCELLED, ""), &counter, &start); + qr1->StartSettingStatus(Status(absl::StatusCode::kCancelled, ""), &counter, + &start); TF_ASSERT_OK(coord.RegisterRunner(std::move(qr1))); std::unique_ptr qr2(new MockQueueRunner(&coord)); - qr2->StartSettingStatus(Status(Code::INVALID_ARGUMENT, ""), &counter, &start); + qr2->StartSettingStatus(Status(absl::StatusCode::kInvalidArgument, ""), + &counter, &start); TF_ASSERT_OK(coord.RegisterRunner(std::move(qr2))); std::unique_ptr qr3(new MockQueueRunner(&coord)); - qr3->StartSettingStatus(Status(Code::OUT_OF_RANGE, ""), &counter, &start); + qr3->StartSettingStatus(Status(absl::StatusCode::kOutOfRange, ""), &counter, + &start); TF_ASSERT_OK(coord.RegisterRunner(std::move(qr3))); start.Notify(); counter.Wait(); TF_EXPECT_OK(coord.RequestStop()); - EXPECT_EQ(coord.Join().code(), Code::INVALID_ARGUMENT); + EXPECT_EQ(coord.Join().code(), absl::StatusCode::kInvalidArgument); } TEST(CoordinatorTest, JoinWithoutStop) { diff --git a/tensorflow/compiler/aot/codegen_test.cc b/tensorflow/compiler/aot/codegen_test.cc index 8a4414a96d7..18e3182e686 100644 --- a/tensorflow/compiler/aot/codegen_test.cc +++ b/tensorflow/compiler/aot/codegen_test.cc @@ -41,8 +41,8 @@ using ::xla::cpu_function_runtime::BufferInfo; void ExpectErrorContains(const Status& status, absl::string_view str) { EXPECT_NE(OkStatus(), status); - EXPECT_TRUE(absl::StrContains(status.error_message(), str)) - << "expected error: " << status.error_message() << " to contain: " << str; + EXPECT_TRUE(absl::StrContains(status.message(), str)) + << "expected error: " << status.message() << " to contain: " << str; } TEST(ValidateCppIdent, Simple) { diff --git a/tensorflow/compiler/aot/compile.cc b/tensorflow/compiler/aot/compile.cc index 3217174b79e..fd3bf0bb7e9 100644 --- a/tensorflow/compiler/aot/compile.cc +++ b/tensorflow/compiler/aot/compile.cc @@ -68,7 +68,7 @@ Status CompileXla(xla::CompileOnlyClient* client, client->GetComputationShape(computation); if (!pshape_or.ok()) { return errors::Unknown("Couldn't get XLA program shape: ", - pshape_or.status().error_message()); + pshape_or.status().message()); } compile_result->program_shape = pshape_or.value()->ToProto(); xla::ProgramShapeProto* pshape = &compile_result->program_shape; @@ -91,7 +91,7 @@ Status CompileXla(xla::CompileOnlyClient* client, aot_or = client->CompileAheadOfTime({instance}, aot_opts); if (!aot_or.ok()) { return errors::Unknown("XLA compilation failed: ", - aot_or.status().error_message()); + aot_or.status().message()); } compile_result->aot = xla::unique_ptr_static_cast( @@ -260,7 +260,7 @@ Status Main(const MainFlags& flags) { CompileGraph(std::move(graph_def), config, flags, &compile_result); if (!status.ok()) { return errors::CreateWithUpdatedMessage( - status, InterpolateErrorMessage(status.error_message())); + status, InterpolateErrorMessage(std::string(status.message()))); } // Write output files. 
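One subtlety in the compile.cc hunk above: since message() returns absl::string_view rather than an owned string, call sites that need a std::string (such as the InterpolateErrorMessage argument) add an explicit conversion. A hedged sketch of that pattern; RewriteMessage below is a hypothetical stand-in for any helper taking an owned string:

#include <string>

#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/status.h"

// Hypothetical helper that, like InterpolateErrorMessage, consumes an owned
// std::string, forcing an explicit copy of the string_view from message().
std::string RewriteMessage(std::string msg) { return "tfcompile: " + msg; }

tensorflow::Status WithRewrittenMessage(const tensorflow::Status& status) {
  if (status.ok()) return status;
  // CreateWithUpdatedMessage keeps the original error code and swaps in the
  // new message, mirroring the usage in tensorflow/compiler/aot/compile.cc.
  return tensorflow::errors::CreateWithUpdatedMessage(
      status, RewriteMessage(std::string(status.message())));
}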
diff --git a/tensorflow/compiler/aot/tests/BUILD b/tensorflow/compiler/aot/tests/BUILD index 4e4c4bad3a3..f04aa37c887 100644 --- a/tensorflow/compiler/aot/tests/BUILD +++ b/tensorflow/compiler/aot/tests/BUILD @@ -69,17 +69,17 @@ py_binary( srcs_version = "PY3", deps = [ "//tensorflow/core:protos_all_py", - "//tensorflow/python", # TODO(b/34059704): remove when fixed "//tensorflow/python:array_ops", "//tensorflow/python:client", + "//tensorflow/python:cond", "//tensorflow/python:control_flow_assert", "//tensorflow/python:control_flow_ops", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:math_ops", "//tensorflow/python:nn_ops", - "//tensorflow/python:platform", "//tensorflow/python:session", "//tensorflow/python:training", + "//tensorflow/python:variable_v1", "//tensorflow/python:variables", "@absl_py//absl:app", "@six_archive//:six", @@ -325,6 +325,112 @@ tfcompile_test_dep_configs = [ for suffix, mlir_component in tfcompile_test_dep_configs ] +tfcompile_bench_tfmatmul_mkn = [ + # Intentionally empty to avoid running unnecessary tests. + # Add here your desired (M, K, N) parameters, e.g. + # (1, 1, 256), + # (1, 2, 256), +] + +tfcompile_bench_tfmatmul = [ + ( + "bench_graph_tfmatmul_%sx%sx%s" % (m, k, n), + "bench_graph_tfmatmul_%sx%sx%s.config.pbtxt" % (m, k, n), + "bench_graph_tfmatmul.template.pbtxt", + "-e \"s||%s|g\" -e \"s||%s|g\" -e \"s||%s|g\"" % (m, k, n), + ) + for (m, k, n) in tfcompile_bench_tfmatmul_mkn +] + +test_suite( + name = "all_tfmatmul_benchmarks", + tags = ["manual"], + tests = [ + (":%s_test" % bench_name) + for (bench_name, _, _, _) in tfcompile_bench_tfmatmul + ], + visibility = ["//visibility:public"], +) + +test_suite( + name = "all_tfmatmul_mlir_benchmarks", + tags = ["manual"], + tests = [ + (":%s_mlir_test" % bench_name) + for (bench_name, _, _, _) in tfcompile_bench_tfmatmul + ], + visibility = ["//visibility:public"], +) + +[[ + genrule( + name = "gen_" + config_file, + testonly = 1, + srcs = [template_file], + outs = [config_file], + cmd = ("sed " + sed_replace + " " + + "$(location " + template_file + ") " + + "> $(OUTS)"), + tags = ["manual"], + ), + tf_library( + name = bench_name, + testonly = 1, + config = config_file, + cpp_class = "foo::bar::MatMulComp", + graph = "test_graph_tfmatmul.pb", + tags = [ + "manual", + "no_mac", # TODO(b/228273415) + ], + ), +] for (bench_name, config_file, template_file, sed_replace) in tfcompile_bench_tfmatmul] + +tfcompile_bench_tfmatmul_tile_mkn = [ + # Intentionally empty to avoid running unnecessary tests. + # Add here your desired (M, K, N) parameters, e.g. 
+ # (1, 2, 8), + # (1, 4, 4), +] + +tfcompile_bench_tfmatmul_custom_tiling = [ + ( + "bench_graph_tfmatmul_%sx%sx%s_tiled_%sx%sx%s_mlir" % (m, k, n, tm, tk, tn), + "bench_graph_tfmatmul_%sx%sx%s_tiled_%sx%sx%s_mlir.config.pbtxt" % (m, k, n, tm, tk, tn), + "bench_graph_tfmatmul.template.pbtxt", + "-e \"s||%s|g\" -e \"s||%s|g\" -e \"s||%s|g\"" % (m, k, n), + "--xla_cpu_enable_custom_matmul_tiling --xla_cpu_matmul_tiling_m_dim=%s --xla_cpu_matmul_tiling_k_dim=%s --xla_cpu_matmul_tiling_n_dim=%s" % (tm, tk, tn), + ) + for (m, k, n) in tfcompile_bench_tfmatmul_mkn + for (tm, tk, tn) in tfcompile_bench_tfmatmul_tile_mkn +] + +[[ + genrule( + name = "gen_" + config_file, + testonly = 1, + srcs = [template_file], + outs = [config_file], + cmd = ("sed " + sed_replace + " " + + "$(location " + template_file + ") " + + "> $(OUTS)"), + tags = ["manual"], + ), + tf_library( + name = bench_name, + testonly = 1, + config = config_file, + cpp_class = "foo::bar::MatMulComp", + graph = "test_graph_tfmatmul.pb", + mlir_components = "HloLowering", # XLA:CPU-Next only. + tags = [ + "manual", + "no_mac", # TODO(b/228273415) + ], + xla_flags = xla_flags, + ), +] for (bench_name, config_file, template_file, sed_replace, xla_flags) in tfcompile_bench_tfmatmul_custom_tiling] + tf_cc_test( name = "tfcompile_test", srcs = ["tfcompile_test.cc"], diff --git a/tensorflow/compiler/aot/tests/bench_graph_tfmatmul.template.pbtxt b/tensorflow/compiler/aot/tests/bench_graph_tfmatmul.template.pbtxt new file mode 100644 index 00000000000..5f8f68c8492 --- /dev/null +++ b/tensorflow/compiler/aot/tests/bench_graph_tfmatmul.template.pbtxt @@ -0,0 +1,18 @@ +# Text form of tensorflow.tf2xla.Config proto. +feed { + id { node_name: "x_hold" } + shape { + dim { size: } + dim { size: } + } +} +feed { + id { node_name: "y_hold" } + shape { + dim { size: } + dim { size: } + } +} +fetch { + id { node_name: "x_y_prod" } +} diff --git a/tensorflow/compiler/aot/tests/make_test_graphs.py b/tensorflow/compiler/aot/tests/make_test_graphs.py index 07d715725a2..56bea7413ef 100644 --- a/tensorflow/compiler/aot/tests/make_test_graphs.py +++ b/tensorflow/compiler/aot/tests/make_test_graphs.py @@ -31,11 +31,13 @@ from tensorflow.python.framework import function from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops_stack # pylint: disable=g-direct-tensorflow-import +from tensorflow.python.ops import cond from tensorflow.python.ops import control_flow_assert from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import control_flow_util from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops +from tensorflow.python.ops import variable_v1 from tensorflow.python.ops import variables from tensorflow.python.training import saver as saver_lib @@ -50,7 +52,7 @@ def tfadd(_): def tfadd_with_ckpt(out_dir): x = array_ops.placeholder(dtypes.int32, name='x_hold') - y = variables.VariableV1(constant_op.constant([0]), name='y_saved') + y = variable_v1.VariableV1(constant_op.constant([0]), name='y_saved') math_ops.add(x, y, name='x_y_sum') init_op = variables.global_variables_initializer() @@ -65,7 +67,7 @@ def tfadd_with_ckpt(out_dir): def tfadd_with_ckpt_saver(out_dir): x = array_ops.placeholder(dtypes.int32, name='x_hold') - y = variables.VariableV1(constant_op.constant([0]), name='y_saved') + y = variable_v1.VariableV1(constant_op.constant([0]), name='y_saved') math_ops.add(x, y, name='x_y_sum') init_op = 
variables.global_variables_initializer() @@ -94,7 +96,7 @@ def tfcond(_): p = array_ops.placeholder(dtypes.bool, name='p_hold') x = array_ops.placeholder(dtypes.int32, name='x_hold') y = array_ops.placeholder(dtypes.int32, name='y_hold') - z = control_flow_ops.cond(p, lambda: x, lambda: y) + z = cond.cond(p, lambda: x, lambda: y) array_ops.identity(z, name='result') diff --git a/tensorflow/compiler/aot/tfcompile.bzl b/tensorflow/compiler/aot/tfcompile.bzl index 6621f46e866..c1f8fdc089a 100644 --- a/tensorflow/compiler/aot/tfcompile.bzl +++ b/tensorflow/compiler/aot/tfcompile.bzl @@ -55,6 +55,7 @@ def _tfcompile_model_library_rule_impl(ctx): "--xla_cpu_fast_math_honor_functions=false " + "--xla_cpu_fast_math_honor_division=false " + "--xla_cpu_enable_fast_min_max=true " + + ctx.attr.xla_flags + " " + "$${XLA_FLAGS:-}' "), "CUDA_VISIBLE_DEVICES": "", } @@ -127,6 +128,7 @@ _tfcompile_model_library = rule( "dfsan_abilists": attr.label_list(default = [], allow_files = True), "is_linux": attr.bool(), "gen_compiler_log": attr.bool(), + "xla_flags": attr.string(), }, ) @@ -151,7 +153,8 @@ def _tf_library( mlir_components = "None", deps = None, tags = [], - copts = []): + copts = [], + xla_flags = None): if not cpp_class: fail("cpp_class must be specified") @@ -268,7 +271,7 @@ def _tf_library( tfcompile_config = config, entry_point = ep, cpp_class = cpp_class, - target_cpu = tfcompile_target_cpu(), + target_cpu = tfcompile_target_cpu(name), target_triple = target_llvm_triple(), flags = flags, extra_flags = debug_info_flags + profiling_flags + mlir_flags + traceme_flags, @@ -281,6 +284,7 @@ def _tf_library( visibility = visibility, testonly = testonly, tags = tags, + xla_flags = xla_flags, ) tfcompile_gen_object_files = tfcompile_gen + "_object_files" @@ -327,6 +331,10 @@ def _tf_library( mlir_components.count("HloLowering") > 0 and [ "//tensorflow/compiler/xla/service/cpu:runtime_mlir_utils", ] or [] + ) + ( + include_standard_runtime_deps and mlir_components == "HloLowering" and [ + "//tensorflow/compiler/xla/service/cpu/runtime:retain", + ] or [] ) + (deps or []), tags = tags, copts = copts, @@ -391,6 +399,7 @@ def _tf_library( ]), tags = tags, extra_copts = copts, + visibility = visibility, ) if gen_benchmark: @@ -437,6 +446,7 @@ def _tf_library( "//tensorflow/compiler/aot:benchmark_extra_android", ]), tags = tags, + visibility = visibility, ) def tf_library( @@ -460,7 +470,8 @@ def tf_library( mlir_components = "None", deps = None, tags = [], - copts = []): + copts = [], + xla_flags = None): """Compiles a TensorFlow graph into an executable with fast math enabled. Given an invocation of tf_library(name="foo", ...), generates the following @@ -543,6 +554,7 @@ def tf_library( deps, tags, copts, + xla_flags, ) if mlir_components == "None": _tf_library( @@ -567,6 +579,7 @@ def tf_library( deps, tags + ["notap", "local", "manual"], copts, + xla_flags, ) def target_llvm_triple(): diff --git a/tensorflow/compiler/aot/tfcompile_main.cc b/tensorflow/compiler/aot/tfcompile_main.cc index c1f14d30de7..da4fa91867f 100644 --- a/tensorflow/compiler/aot/tfcompile_main.cc +++ b/tensorflow/compiler/aot/tfcompile_main.cc @@ -87,7 +87,7 @@ int main(int argc, char** argv) { "other than flags. 
See --help.\n\n"; tensorflow::Status status = tensorflow::tfcompile::Main(flags); if (status.code() == absl::StatusCode::kInvalidArgument) { - std::cerr << "INVALID ARGUMENTS: " << status.error_message() << "\n\n"; + std::cerr << "INVALID ARGUMENTS: " << status.message() << "\n\n"; return 1; } else { TF_QCHECK_OK(status); diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index 0b6869a327d..b3fd29ff259 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -78,6 +78,7 @@ cc_library( "//tensorflow/compiler/jit/kernels:xla_ops", "//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops", "//tensorflow/compiler/tf2xla/kernels:xla_ops", + "//tensorflow/core/tfrt/common:pjrt_cpu_client_registration", ] + if_libtpu( if_false = ["//tensorflow/compiler/xla/service:cpu_plugin"], if_true = [], @@ -95,6 +96,7 @@ cc_library( "//tensorflow/compiler/tf2xla/kernels:xla_ops", "//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops", "//tensorflow/compiler/xla/service:gpu_plugin", + "//tensorflow/core/tfrt/common:pjrt_gpu_client_registration", ]), alwayslink = 1, ) @@ -120,7 +122,6 @@ cc_library( ":jit_compilation_passes", ":xla_device", ":xla_kernel_creator", # buildcleaner: keep - "@com_google_absl//absl/memory", "//tensorflow/compiler/jit/kernels:xla_ops", "//tensorflow/compiler/tf2xla:layout_util", "//tensorflow/compiler/tf2xla:xla_compiler", @@ -128,6 +129,7 @@ cc_library( "//tensorflow/compiler/tf2xla/kernels:xla_ops", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", ] + if_libtpu( if_false = [ "//tensorflow/compiler/xla/service:cpu_plugin", # buildcleaner: keep @@ -146,10 +148,8 @@ cc_library( ":flags", ":jit_compilation_passes", ":xla_device", - ":xla_kernel_creator", # buildcleaner: keep ":xla_device_no_jit_rewrite_registration", - "@com_google_absl//absl/memory", - "@com_google_absl//absl/strings", + ":xla_kernel_creator", # buildcleaner: keep "//tensorflow/compiler/jit/kernels:xla_ops", "//tensorflow/compiler/tf2xla:layout_util", "//tensorflow/compiler/tf2xla:xla_compiler", @@ -158,6 +158,8 @@ cc_library( "//tensorflow/compiler/xla/stream_executor/gpu:gpu_init", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", ] + if_libtpu( if_false = [ "//tensorflow/compiler/xla/service:gpu_plugin", # buildcleaner: keep @@ -174,14 +176,21 @@ cc_library( visibility = [":friends"], deps = [ ":xla_device", + ":xla_device_context", ":xla_kernel_creator", # buildcleaner: keep - "@com_google_absl//absl/types:optional", "//tensorflow/compiler/jit/kernels:xla_ops", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:layout_util", "//tensorflow/compiler/tf2xla:tf2xla_util", "//tensorflow/compiler/tf2xla:xla_helpers", "//tensorflow/compiler/tf2xla:xla_op_registry", + "//tensorflow/compiler/xla/stream_executor/tpu:c_api_conversions", + "//tensorflow/compiler/xla/stream_executor/tpu:status_helper", + "//tensorflow/compiler/xla/stream_executor/tpu:tpu_api", + "//tensorflow/compiler/xla/stream_executor/tpu:tpu_executor_base", + "//tensorflow/compiler/xla/stream_executor/tpu:tpu_node_context", + "//tensorflow/compiler/xla/stream_executor/tpu:tpu_platform_interface", + "//tensorflow/compiler/xla/stream_executor/tpu:tpu_stream_interface", "//tensorflow/core:framework_internal", "//tensorflow/core:lib_proto_parsing", "//tensorflow/core:protos_all_cc", @@ -190,16 +199,10 @@ cc_library( "//tensorflow/core/common_runtime:device_factory", 
"//tensorflow/core/common_runtime:dma_helper", "//tensorflow/core/platform:status", - "//tensorflow/compiler/xla/stream_executor/tpu:tpu_api", "//tensorflow/core/tpu:tpu_defs", "//tensorflow/core/tpu:tpu_node_device_util", "//tensorflow/core/tpu:virtual_device", - "//tensorflow/compiler/xla/stream_executor/tpu:c_api_conversions", - "//tensorflow/compiler/xla/stream_executor/tpu:status_helper", - "//tensorflow/compiler/xla/stream_executor/tpu:tpu_executor_base", - "//tensorflow/compiler/xla/stream_executor/tpu:tpu_node_context", - "//tensorflow/compiler/xla/stream_executor/tpu:tpu_platform_interface", - "//tensorflow/compiler/xla/stream_executor/tpu:tpu_stream_interface", + "@com_google_absl//absl/types:optional", ] + if_static([ "//tensorflow/core/common_runtime:copy_tensor", ":jit_compilation_passes", @@ -289,13 +292,39 @@ XLA_DEVICE_DEPS = [ "//tensorflow/compiler/xla/stream_executor/platform", ] +cc_library( + name = "xla_device_context", + srcs = ["xla_device_context.cc"], + hdrs = ["xla_device_context.h"], + visibility = ["//visibility:public"], + deps = [ + ":xla_launch_util", + ":xla_tensor", + "//tensorflow/compiler/tf2xla:common", + "//tensorflow/compiler/tf2xla:layout_util", + "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/tf2xla:xla_helpers", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/client:global_data", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/stream_executor/platform", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib_proto_parsing", + "//tensorflow/core:portable_gif_internal", + "//tensorflow/core/common_runtime:device", + "//tensorflow/core/common_runtime:dma_helper", + "//tensorflow/core/framework:allocator", + "@com_google_absl//absl/synchronization", + ], +) + cc_library( name = "xla_device_no_jit_rewrite_registration", srcs = [ "xla_compile_on_demand_op.cc", "xla_compiler_options_util.cc", "xla_device.cc", - "xla_device_context.cc", "xla_device_ops.cc", "xla_ops_on_regular_devices.cc", "xla_platform_info.cc", @@ -304,7 +333,6 @@ cc_library( "xla_compile_on_demand_op.h", "xla_compiler_options_util.h", "xla_device.h", - "xla_device_context.h", "xla_device_ops.h", "xla_platform_info.h", ], @@ -313,12 +341,24 @@ cc_library( deps = XLA_DEVICE_DEPS + [ ":device_compilation_cache", ":device_compilation_profiler", + ":device_compiler", ":device_compiler_client", ":device_executable_persistor", ":flags_headers", - ":device_compiler", + ":pjrt_base_device", + ":pjrt_device_compiler_client", ":xla_device_compiler_client", + ":xla_device_context", + "//tensorflow/compiler/xla:executable_run_options", + "//tensorflow/compiler/xla/hlo/ir:hlo", + "//tensorflow/compiler/xla/pjrt:pjrt_client", + "//tensorflow/compiler/xla/service:executable", + "//tensorflow/core/tfrt/common:create_pjrt_client_util", + "//tensorflow/core/tfrt/common:pjrt_util", "//tensorflow/core/tpu:tpu_defs", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/types:span", ], alwayslink = 1, ) @@ -328,7 +368,6 @@ cc_library( hdrs = [ "xla_compile_on_demand_op.h", "xla_device.h", - "xla_device_context.h", "xla_device_ops.h", ], # Public visibility is needed for external TF/XLA backends. 
@@ -337,6 +376,7 @@ cc_library( ":device_compilation_profiler", ":jit_compilation_passes", ":xla_device_no_jit_rewrite_registration", + "//tensorflow/compiler/xla/pjrt:pjrt_client", ], ) @@ -364,9 +404,11 @@ cc_library( "//tensorflow/compiler/mlir/tensorflow:dump_graph", "//tensorflow/compiler/xla:parse_flags_from_env", "//tensorflow/core:framework_internal", + "//tensorflow/core:framework_types_hdr", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "@com_google_absl//absl/base", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", ], @@ -379,12 +421,11 @@ cc_library( hdrs = ["flags.h"], visibility = [":friends"], deps = [ - "//tensorflow/compiler/mlir/tensorflow:dump_graph", - "//tensorflow/compiler/xla:parse_flags_from_env", "//tensorflow/core:framework_internal", + "//tensorflow/core:framework_types_hdr", "//tensorflow/core:lib", "//tensorflow/core/protobuf:for_core_protos_cc", - "@com_google_absl//absl/strings", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/types:optional", ], ) @@ -475,6 +516,7 @@ cc_library( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/pjrt:pjrt_client", "//tensorflow/compiler/xla/service:shaped_buffer", "//tensorflow/compiler/xla/stream_executor:device_memory_allocator", "//tensorflow/core:core_cpu_internal", @@ -484,8 +526,43 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", + "//tensorflow/core/tfrt/common:async_value_tensor", + "//tensorflow/tsl/framework:device_id_utils", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/cleanup", + "@com_google_absl//absl/container:flat_hash_set", + ], +) + +tf_cc_test( + name = "xla_launch_util_test", + srcs = ["xla_launch_util_test.cc"], + deps = [ + ":device_compiler", + ":flags_headers", + ":pjrt_device_compiler_client", + ":variable_info", + ":variable_info_util", + ":xla_cpu_device", + ":xla_cpu_jit", + ":xla_device_no_jit_rewrite_registration", + ":xla_launch_util", + "//tensorflow/compiler/tf2xla:xla_op_registry", + "//tensorflow/compiler/xla/pjrt:pjrt_client", + "//tensorflow/compiler/xla/pjrt:tfrt_cpu_pjrt_client", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/core:framework", + "//tensorflow/core/framework:fake_input", + "//tensorflow/core/framework:tensor_testutil", + "//tensorflow/core/kernels:ops_testutil", + "//tensorflow/core/platform:refcount", + "//tensorflow/core/tfrt/common:create_pjrt_client_util", + "//tensorflow/core/tfrt/common:pjrt_util", + "//tensorflow/tsl/lib/core:status_test_util", + "//tensorflow/tsl/platform:status", + "//tensorflow/tsl/platform:statusor", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_googletest//:gtest_main", ], ) @@ -511,6 +588,7 @@ tf_cc_test( "xla_compile_util_test.cc", ], deps = [ + ":flags_headers", ":xla_compile_util", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/core:test_main", @@ -658,6 +736,8 @@ cc_library( name = "xla_kernel_creator", srcs = [ "xla_kernel_creator.cc", + ], + hdrs = [ "xla_kernel_creator.h", ], visibility = [ @@ -777,6 +857,7 @@ cc_library( ":shape_inference", "//tensorflow/compiler/xla:status_macros", "//tensorflow/core:core_cpu", + "//tensorflow/core:core_cpu_base", "//tensorflow/core:framework", "//tensorflow/core:lib", ], @@ -1393,6 +1474,19 @@ cc_library( ], ) +cc_library( + name = 
"pjrt_base_device", + srcs = ["pjrt_base_device.cc"], + hdrs = ["pjrt_base_device.h"], + # Public visibility is needed for external TF/XLA backends. + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/compiler/tf2xla:layout_util", + "//tensorflow/core:framework", + "//tensorflow/core/common_runtime:local_device", + ], +) + cc_library( name = "pjrt_device_context", srcs = [ @@ -1404,12 +1498,14 @@ cc_library( visibility = ["//visibility:public"], deps = [ "//tensorflow/compiler/tf2xla:common", + "//tensorflow/compiler/tf2xla:layout_util", "//tensorflow/compiler/xla/pjrt:pjrt_client", "//tensorflow/core:framework", "//tensorflow/core/platform:status", "//tensorflow/core/profiler/lib:traceme", "//tensorflow/core/tfrt/common:async_value_tensor", "//tensorflow/core/tfrt/common:create_pjrt_client_util", + "//tensorflow/tsl/framework:device_id_utils", ], ) @@ -1522,11 +1618,15 @@ tf_cuda_cc_test( tf_cuda_cc_test( name = "device_context_test", srcs = ["device_context_test.cc"], - tags = tf_cuda_tests_tags(), + tags = tf_cuda_tests_tags() + [ + "config-cuda-only", + "no_oss", # Temporarily disable OSS. + ], deps = [ ":flags", ":xla_device", ":xla_gpu_device", + ":xla_gpu_jit", "//tensorflow/compiler/tf2xla:xla_op_registry", "//tensorflow/core:framework_internal", "//tensorflow/core:test", @@ -1541,6 +1641,7 @@ tf_cuda_cc_test( tags = tf_cuda_tests_tags(), deps = [ ":flags", + ":test_util", ":xla_device_no_jit_rewrite_registration", ":xla_gpu_device", ":xla_gpu_jit", @@ -1554,3 +1655,29 @@ tf_cuda_cc_test( "@com_google_googletest//:gtest_main", ], ) + +tf_cuda_cc_test( + name = "xla_platform_info_test", + srcs = ["xla_platform_info_test.cc"], + tags = tf_cuda_tests_tags() + ["config-cuda-only"], + deps = [ + ":flags_headers", + ":test_util", + ":xla_device_no_jit_rewrite_registration", + ":xla_gpu_device", + ":xla_gpu_jit", + "//tensorflow/compiler/tf2xla:layout_util", + "//tensorflow/compiler/xla/pjrt:tfrt_cpu_pjrt_client", + "//tensorflow/core:framework_types_hdr", + "//tensorflow/core:lib_proto_parsing", + "//tensorflow/core:test", + "//tensorflow/core/platform:refcount", + "//tensorflow/core/platform:status_matchers", + "//tensorflow/core/platform:statusor", + "//tensorflow/core/protobuf:error_codes_proto_impl_cc", + "//tensorflow/core/tfrt/common:create_pjrt_client_util", + "//tensorflow/core/tfrt/common:pjrt_util", + "//tensorflow/core/tpu:tpu_defs", + "@com_google_googletest//:gtest_main", + ], +) diff --git a/tensorflow/compiler/jit/compilability_check_util.cc b/tensorflow/compiler/jit/compilability_check_util.cc index 01f2e1cf24b..cd922170c3c 100644 --- a/tensorflow/compiler/jit/compilability_check_util.cc +++ b/tensorflow/compiler/jit/compilability_check_util.cc @@ -202,7 +202,7 @@ bool RecursiveCompilabilityChecker::HasXLAKernel( Status s = FindKernelDef(jit_device_type_, node.def(), nullptr, nullptr); if (!s.ok()) { - *uncompilable_reason = s.error_message(); + *uncompilable_reason = s.message(); return false; } return true; diff --git a/tensorflow/compiler/jit/device_compiler_disable_test.cc b/tensorflow/compiler/jit/device_compiler_disable_test.cc index cf4b5461861..7853014b2ea 100644 --- a/tensorflow/compiler/jit/device_compiler_disable_test.cc +++ b/tensorflow/compiler/jit/device_compiler_disable_test.cc @@ -68,24 +68,21 @@ TEST(DeviceCompilerTest, TestDisabledXlaCompilation) { XlaCompiler::Options{}, fn, args, XlaCompiler::CompileOptions{}, DeviceCompileMode::kStrict, profiler, &compilation_result, &executable); EXPECT_FALSE(status.ok()); - EXPECT_TRUE( - 
absl::StrContains(status.error_message(), "XLA compilation disabled")); + EXPECT_TRUE(absl::StrContains(status.message(), "XLA compilation disabled")); // Check that async compilation is disallowed. status = xla_device_compiler->CompileIfNeeded( XlaCompiler::Options{}, fn, args, XlaCompiler::CompileOptions{}, DeviceCompileMode::kAsync, profiler, &compilation_result, &executable); EXPECT_FALSE(status.ok()); - EXPECT_TRUE( - absl::StrContains(status.error_message(), "XLA compilation disabled")); + EXPECT_TRUE(absl::StrContains(status.message(), "XLA compilation disabled")); // Check that lazy compilation is disallowed. status = xla_device_compiler->CompileIfNeeded( XlaCompiler::Options{}, fn, args, XlaCompiler::CompileOptions{}, DeviceCompileMode::kLazy, profiler, &compilation_result, &executable); EXPECT_FALSE(status.ok()); - EXPECT_TRUE( - absl::StrContains(status.error_message(), "XLA compilation disabled")); + EXPECT_TRUE(absl::StrContains(status.message(), "XLA compilation disabled")); } } // namespace diff --git a/tensorflow/compiler/jit/device_context_test.cc b/tensorflow/compiler/jit/device_context_test.cc index 7bad65bcf5a..2328ec42d97 100644 --- a/tensorflow/compiler/jit/device_context_test.cc +++ b/tensorflow/compiler/jit/device_context_test.cc @@ -28,14 +28,21 @@ namespace tensorflow { namespace { static bool Initialized = [] { + auto& rollout_config = GetXlaOpsCommonFlags()->tf_xla_use_device_api; + rollout_config.enabled_for_xla_launch_ = true; + rollout_config.enabled_for_compile_on_demand_ = true; + tensorflow::GetXlaDeviceFlags()->tf_xla_enable_xla_devices = true; - tensorflow::GetXlaOpsCommonFlags()->tf_xla_use_device_api = true; return true; }(); class DeviceContextTest : public ::testing::Test { public: void SetDevice(const string& device_type) { + auto& rollout_config = GetXlaOpsCommonFlags()->tf_xla_use_device_api; + rollout_config.AllowForDeviceInXlaLaunch(DeviceType(device_type)); + rollout_config.AllowForDeviceInXlaCompileOnDemand(DeviceType(device_type)); + auto device_factory = DeviceFactory::GetFactory(device_type); SessionOptions options; std::vector> devices; diff --git a/tensorflow/compiler/jit/flags.cc b/tensorflow/compiler/jit/flags.cc index dbd202bda97..6e8fceaf47d 100644 --- a/tensorflow/compiler/jit/flags.cc +++ b/tensorflow/compiler/jit/flags.cc @@ -213,7 +213,8 @@ void AllocateAndParseFlags() { ops_flags = new XlaOpsCommonFlags; ops_flags->tf_xla_always_defer_compilation = false; ops_flags->tf_xla_async_compilation = false; - ops_flags->tf_xla_use_device_api = false; + ops_flags->tf_xla_use_device_api.enabled_for_xla_launch_ = false; + ops_flags->tf_xla_use_device_api.enabled_for_compile_on_demand_ = false; // The `enable_mlir_bridge` flag allows the user to explicitly request that // their program is (or isn't) compiled using the MLIR-based TF-to-XLA bridge. @@ -267,9 +268,15 @@ void AllocateAndParseFlags() { "When lazy compilation is enabled, asynchronous compilation starts " "the cluster compilation in the background, and the fallback path " "is executed until the compilation has finished."), - Flag("tf_xla_use_device_api", &ops_flags->tf_xla_use_device_api, - "If true, uses the Device API (PjRt) for single device compilation." - " Defaults to false."), + Flag("tf_xla_use_device_api_for_xla_launch", + &ops_flags->tf_xla_use_device_api.enabled_for_xla_launch_, + "If true, uses Device API (PjRt) for single device compilation and " + "execution of functions marked for JIT compilation i.e. " + "jit_compile=True. 
Defaults to false."), + Flag("tf_xla_use_device_api_for_compile_on_demand", + &ops_flags->tf_xla_use_device_api.enabled_for_compile_on_demand_, + "If true, uses Device API (PjRt) for compiling and executing ops " + "one by one in 'on-demand' mode. Defaults to false."), Flag("tf_mlir_enable_mlir_bridge", &enable_mlir_bridge, "Enables experimental MLIR-Based TensorFlow Compiler Bridge.", diff --git a/tensorflow/compiler/jit/flags.h b/tensorflow/compiler/jit/flags.h index 650a53293fa..9f151b89eb7 100644 --- a/tensorflow/compiler/jit/flags.h +++ b/tensorflow/compiler/jit/flags.h @@ -20,7 +20,9 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_set.h" #include "absl/types/optional.h" +#include "tensorflow/core/framework/types.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/protobuf/config.pb.h" #include "tensorflow/core/util/command_line_flags.h" @@ -123,9 +125,55 @@ struct XlaOpsCommonFlags { // If true, _XlaCompile compiles the cluster asynchronously with respect to // the main execution. The fallback path is taken while compilation happens. bool tf_xla_async_compilation; - // If true, uses Device API (PjRt) for single device compilation. Defaults to - // false. - bool tf_xla_use_device_api; + + class PjRtForSingleDeviceCompilationRollout { + public: + // Allow using Device API (PjRt) for `device_type` in the XlaLaunch op. + // Please note that `enabled_for_xla_launch_` needs to be true in addition + // to the `device_type` being allowed in order to use the Device API for + // single device compilation and execution in the XlaLaunch op. + void AllowForDeviceInXlaLaunch(const DeviceType& device_type) { + xla_launch_allowed_devices_.insert(device_type.type_string()); + } + + bool IsEnabledInXlaLaunchForDevice(const DeviceType& device_type) const { + return enabled_for_xla_launch_ && + xla_launch_allowed_devices_.contains(device_type.type_string()); + } + + // Allow using Device API (PjRt) for `device_type` in the XlaCompileOnDemand + // op. Please note that `enabled_for_compile_on_demand_` needs to be true in + // addition to the `device_type` being allowed in order to use the Device + // API for single device compilation and execution in the XlaCompileOnDemand + // op. + void AllowForDeviceInXlaCompileOnDemand(const DeviceType& device_type) { + xla_compile_on_demand_allowed_devices_.insert(device_type.type_string()); + } + + bool IsEnabledInXlaCompileOnDemandForDevice( + const DeviceType& device_type) const { + return enabled_for_compile_on_demand_ && + xla_compile_on_demand_allowed_devices_.contains( + device_type.type_string()); + } + + // If true, uses Device API (PjRt) for single device compilation and + // execution of functions marked for JIT compilation i.e. jit_compile=True. + // Defaults to false. + bool enabled_for_xla_launch_; + + // If true, uses Device API (PjRt) for compiling and executing ops one by + // one in "on-demand" mode. Defaults to false. + bool enabled_for_compile_on_demand_; + + private: + // Devices for which using Device API (PjRt) is allowed in the XlaLaunch op. + // This can only be modified programmatically. + absl::flat_hash_set xla_launch_allowed_devices_; + // Devices for which using Device API (PjRt) is allowed in the + // XlaCompileOnDemand op. This can only be modified programmatically. + absl::flat_hash_set xla_compile_on_demand_allowed_devices_; + } tf_xla_use_device_api; }; // Flags for the build_xla_ops pass. 
diff --git a/tensorflow/compiler/jit/kernels/BUILD b/tensorflow/compiler/jit/kernels/BUILD index 1f7436cdc95..8046207ed54 100644 --- a/tensorflow/compiler/jit/kernels/BUILD +++ b/tensorflow/compiler/jit/kernels/BUILD @@ -57,6 +57,7 @@ cc_library( "//tensorflow/compiler/jit:tf_graph_to_hlo_compiler", "//tensorflow/compiler/jit:tf_to_hlo_compiler", "//tensorflow/compiler/jit:xla_compile_util", + "//tensorflow/compiler/xla/pjrt:pjrt_client", "//tensorflow/core/platform:refcount", ], alwayslink = 1, diff --git a/tensorflow/compiler/jit/kernels/xla_ops.cc b/tensorflow/compiler/jit/kernels/xla_ops.cc index b2547aa7e09..913cca35be3 100644 --- a/tensorflow/compiler/jit/kernels/xla_ops.cc +++ b/tensorflow/compiler/jit/kernels/xla_ops.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/jit/kernels/xla_ops.h" +#include #include #include #include @@ -34,6 +35,7 @@ limitations under the License. #include "tensorflow/compiler/jit/xla_activity_listener.h" #include "tensorflow/compiler/jit/xla_compile_util.h" #include "tensorflow/compiler/jit/xla_compiler_options_util.h" +#include "tensorflow/compiler/jit/xla_launch_util.h" #include "tensorflow/compiler/jit/xla_platform_info.h" #include "tensorflow/compiler/tf2xla/tf2xla_util.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" @@ -41,6 +43,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/executable_run_options.h" +#include "tensorflow/compiler/xla/pjrt/pjrt_client.h" #include "tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/core/framework/allocator.h" @@ -51,6 +54,7 @@ limitations under the License. #include "tensorflow/core/framework/types.h" #include "tensorflow/core/lib/monitoring/counter.h" #include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/refcount.h" #include "tensorflow/core/platform/statusor.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" @@ -74,6 +78,8 @@ namespace tensorflow { namespace { using XlaDeviceCompiler = DeviceCompiler; +using PjRtDeviceCompiler = + DeviceCompiler; auto* xla_launch_counter = monitoring::Counter<1>::New( "/tensorflow/core/xla_launch_counter", @@ -233,21 +239,19 @@ GetXlaCompilerArgsAndSnapshotVariables( return result; } -} // namespace +XlaCompiler::CompileOptions GenerateCompileOptions( + bool has_ref_vars, bool may_alias_resource_update) { + XlaCompiler::CompileOptions compile_options; + compile_options.is_entry_computation = true; + // Optimization: where possible, have the computation return a naked array + // rather than a one-element tuple. 
+ compile_options.always_return_tuple = false; + compile_options.alias_resource_update = + !has_ref_vars && may_alias_resource_update; + return compile_options; +} -XlaLocalLaunchBase::XlaLocalLaunchBase(OpKernelConstruction* ctx, - const std::vector& constants, - const std::vector& resources, - const NameAttrList& function, - bool has_ref_vars) - : AsyncOpKernel(ctx), - constants_(constants), - resources_(resources), - function_(function), - platform_info_(XlaPlatformInfoFromDevice(ctx->device())), - has_ref_vars_(has_ref_vars) {} - -static Status CompileToLocalExecutable( +Status CompileToLocalExecutable( OpKernelContext* ctx, const NameAttrList& function, bool has_ref_vars, const XlaPlatformInfo& platform_info, const std::vector& args, @@ -288,19 +292,78 @@ static Status CompileToLocalExecutable( *xla_device_compiler, *ctx->function_library(), ctx->device(), GetStream(ctx), platform_info, has_ref_vars); - XlaCompiler::CompileOptions compile_options; - compile_options.is_entry_computation = true; - // Optimization: where possible, have the computation return a naked array - // rather than a one-element tuple. - compile_options.always_return_tuple = false; - compile_options.alias_resource_update = - !has_ref_vars && may_alias_resource_update; + XlaCompiler::CompileOptions compile_options = + GenerateCompileOptions(has_ref_vars, may_alias_resource_update); return xla_device_compiler->CompileIfNeeded( options, function, args, compile_options, compile_mode, profiler, compilation_result, executable); } +Status CompileToPjRtLoadedExecutable( + const OpKernelContext& ctx, const XlaPlatformInfo& platform_info, + const NameAttrList& function, + const std::vector& args, + DeviceCompileMode compile_mode, bool has_ref_vars, + bool may_alias_resource_update, + const XlaCompiler::CompilationResult** compilation_result, + xla::PjRtClient** client, xla::PjRtLoadedExecutable** executable) { + // We store information about the JIT-compiled XLA computation + // in the ResourceMgr. + ResourceMgr* rm = ctx.resource_manager(); + if (!rm) { + return errors::Internal("No resource manager."); + } + + PjRtDeviceCompiler* pjrt_device_compiler; + TF_RETURN_IF_ERROR(rm->LookupOrCreate( + rm->default_container(), "pjrt_device_compiler", &pjrt_device_compiler, + [&](PjRtDeviceCompiler** pjrt_device_compiler) { + return BuildPjRtDeviceCompiler(platform_info, ctx.function_library(), + pjrt_device_compiler); + })); + DeviceCompilationProfiler* profiler; + TF_RETURN_IF_ERROR(rm->LookupOrCreate( + rm->default_container(), "pjrt_device_compilation_profiler", &profiler, + [](DeviceCompilationProfiler** profiler) { + *profiler = new DeviceCompilationProfiler(); + return OkStatus(); + })); + // Hold the reference to the PJRT device compiler and profiler during + // evaluation. (We could probably free them sooner because the ResourceMgr + // will retain references, but this is more obviously correct.) 
+ core::ScopedUnref pjrt_device_compiler_ref(pjrt_device_compiler); + core::ScopedUnref profiler_ref(profiler); + + *client = pjrt_device_compiler->client(); + + XlaCompiler::Options options = GenerateCompilerOptionsForPjRt( + *ctx.function_library(), ctx.device(), platform_info); + + XlaCompiler::CompileOptions compile_options = + GenerateCompileOptions(has_ref_vars, may_alias_resource_update); + + return pjrt_device_compiler->CompileIfNeeded( + options, function, args, compile_options, compile_mode, profiler, + compilation_result, executable); +} + +Status GetUpdatedVariables( + const OpKernelContext* ctx, absl::Span inputs, + absl::Span variable_indices, + const XlaCompiler::CompilationResult& compilation_result, + std::vector* variable_infos) { + std::set variables_updated; + for (const auto& resource_update : compilation_result.resource_updates) { + if (resource_update.modified) { + variables_updated.insert(resource_update.input_index); + } + } + return GetVariableInfosFromInputs(ctx->resource_manager(), ctx->device(), + inputs, variable_indices, + &variables_updated, variable_infos); +} + // Get-or-create thread pool for a given collective. static thread::ThreadPool* GetOrCreateThreadPoolForCollective( const XlaCompilationResult::CollectiveInfo& collective_info) { @@ -321,6 +384,33 @@ static thread::ThreadPool* GetOrCreateThreadPoolForCollective( return &it->second; } +void RunInThreadPoolIfCollectivesPresent( + const XlaCompiler::CompilationResult& compilation_result, + std::function execution_fn) { + // If we are using collectives, we need to run in a separate threadpool. + if (compilation_result.collective_info.has_value()) { + GetOrCreateThreadPoolForCollective(*compilation_result.collective_info) + ->Schedule(execution_fn); + } else { + // Otherwise, just run normally: we merely "pretend" to be asynchronous. + execution_fn(); + } +} + +} // namespace + +XlaLocalLaunchBase::XlaLocalLaunchBase(OpKernelConstruction* ctx, + const std::vector& constants, + const std::vector& resources, + const NameAttrList& function, + bool has_ref_vars) + : AsyncOpKernel(ctx), + constants_(constants), + resources_(resources), + function_(function), + platform_info_(XlaPlatformInfoFromDevice(ctx->device())), + has_ref_vars_(has_ref_vars) {} + void XlaLocalLaunchBase::ComputeAsync(OpKernelContext* ctx, DoneCallback done) { VLOG(1) << "XlaLocalLaunchOpBase::Compute " << Canonicalize(function_.name(), AttrSlice(&function_.attr())); @@ -328,10 +418,14 @@ void XlaLocalLaunchBase::ComputeAsync(OpKernelContext* ctx, DoneCallback done) { ->IncrementBy(1); std::vector inputs = InputsFromContext(ctx); - xla::LocalClient* client; - const XlaCompiler::CompilationResult* compilation_result; - xla::LocalExecutable* executable; std::vector xla_compiler_args; + const XlaCompiler::CompilationResult* compilation_result; + + xla::LocalClient* client; // Not owned. + xla::LocalExecutable* executable; // Not owned. + + xla::PjRtClient* pjrt_client; // Not owned. + xla::PjRtLoadedExecutable* pjrt_executable; // Not owned. // Note that here we assume the shape of the variables don't change between // compilation and execution. 
The locks on the variables are released before @@ -357,6 +451,50 @@ void XlaLocalLaunchBase::ComputeAsync(OpKernelContext* ctx, DoneCallback done) { OP_REQUIRES_OK_ASYNC(ctx, status_or_xla_compiler_args.status(), done); xla_compiler_args = std::move(status_or_xla_compiler_args.value()); } + + bool use_pjrt = GetXlaOpsCommonFlags() + ->tf_xla_use_device_api.IsEnabledInXlaLaunchForDevice( + platform_info_.device_type()); + if (use_pjrt) { + VLOG(2) << "Compiling using PJRT"; + Status status = CompileToPjRtLoadedExecutable( + *ctx, platform_info_, function_, xla_compiler_args, + DeviceCompileMode::kStrict, has_ref_vars_, + /*may_alias_resource_update=*/true, &compilation_result, &pjrt_client, + &pjrt_executable); + OP_REQUIRES_OK_ASYNC(ctx, status, done); + + VLOG(2) << "Compiled using PJRT: " << status; + VLOG(2) << "pjrt_executable != nullptr: " << (pjrt_executable != nullptr); + VLOG(2) << "compilation_result != nullptr: " + << (compilation_result != nullptr); + VLOG(2) << "Executing using PJRT."; + + auto run_pjrt_cluster = [ctx, pjrt_client, pjrt_executable, + compilation_result, done, inputs, + resources = resources_]() { + auto platform_info = XlaPlatformInfoFromDevice(ctx->device()); + std::vector variable_infos; + OP_REQUIRES_OK_ASYNC( + ctx, + GetUpdatedVariables(ctx, inputs, resources, *compilation_result, + &variable_infos), + done); + OP_REQUIRES_OK_ASYNC(ctx, LockVariables(absl::MakeSpan(variable_infos)), + done); + OP_REQUIRES_OK_ASYNC( + ctx, + RunPjRtExecutable(*pjrt_client, inputs, variable_infos, + *compilation_result, pjrt_executable, ctx), + done); + VLOG(2) << "Done executing with PJRT."; + done(); + }; + + RunInThreadPoolIfCollectivesPresent(*compilation_result, run_pjrt_cluster); + return; + } + Status status = CompileToLocalExecutable( ctx, function_, /*has_ref_vars=*/has_ref_vars_, platform_info_, xla_compiler_args, DeviceCompileMode::kStrict, @@ -369,17 +507,11 @@ void XlaLocalLaunchBase::ComputeAsync(OpKernelContext* ctx, DoneCallback done) { inputs, resources = resources_]() { auto platform_info = XlaPlatformInfoFromDevice(ctx->device()); std::vector variable_infos; - std::set variables_updated; - for (const auto& resource_update : compilation_result->resource_updates) { - if (resource_update.modified) { - variables_updated.insert(resource_update.input_index); - } - } - OP_REQUIRES_OK_ASYNC(ctx, - GetVariableInfosFromInputs( - ctx->resource_manager(), ctx->device(), inputs, - resources, &variables_updated, &variable_infos), - done); + OP_REQUIRES_OK_ASYNC( + ctx, + GetUpdatedVariables(ctx, inputs, resources, *compilation_result, + &variable_infos), + done); OP_REQUIRES_OK_ASYNC(ctx, LockVariables(absl::MakeSpan(variable_infos)), done); std::map resource_var_ptrs; @@ -435,14 +567,7 @@ void XlaLocalLaunchBase::ComputeAsync(OpKernelContext* ctx, DoneCallback done) { done(); }; - // If we are using collectives, we need to run in a separate threadpool. - if (compilation_result->collective_info.has_value()) { - GetOrCreateThreadPoolForCollective(*compilation_result->collective_info) - ->Schedule(run_xla_cluster); - } else { - // Otherwise, just run normally: we merely "pretend" to be asynchronous. 
- run_xla_cluster(); - } + RunInThreadPoolIfCollectivesPresent(*compilation_result, run_xla_cluster); } namespace { diff --git a/tensorflow/compiler/jit/pjrt_base_device.cc b/tensorflow/compiler/jit/pjrt_base_device.cc new file mode 100644 index 00000000000..d7c12921c71 --- /dev/null +++ b/tensorflow/compiler/jit/pjrt_base_device.cc @@ -0,0 +1,60 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/compiler/jit/pjrt_base_device.h" + +namespace tensorflow { +namespace { + +DeviceAttributes BuildPjRtBaseDeviceAttributes(const string& name_prefix, + const string& device_name, + int device_ordinal) { + return Device::BuildDeviceAttributes( + absl::StrCat(name_prefix, "/device:", device_name, ":", device_ordinal), + DeviceType(device_name), Bytes(16ULL << 30), DeviceLocality(), + absl::StrCat("device: ", device_name, " device")); +} + +} // namespace + +PjRtBaseDevice::PjRtBaseDevice(const SessionOptions& session_options, + const Options& options) + : LocalDevice(session_options, + BuildPjRtBaseDeviceAttributes(options.device_name_prefix, + options.device_name, + options.device_ordinal)), + metadata_(DeviceType(options.compilation_device_name), + options.shape_determination_fns) { + if (options.shape_determination_fns.empty()) { + LOG(ERROR) << "shape_representation_fns must be non-empty."; + } + VLOG(1) << "Created PJRT base device " << options.compilation_device_name + << " device_name: " << name(); +} + +/*static*/ StatusOr +PjRtBaseDevice::GetMetadataFromDevice(DeviceBase* device) { + PjRtBaseDevice* pjrt_device = + dynamic_cast(device->UnderlyingDevice()); + if (pjrt_device == nullptr) { + return errors::Internal( + "Cannot get device metadata from non-PJRT device \"", device->name(), + "\". GetMetadata must only be called on a device derived from " + "PjRtBaseDevice. Either an internal bug has been triggered, or an " + "XLA-specific op has been placed on the wrong device."); + } + return &pjrt_device->metadata_; +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/pjrt_base_device.h b/tensorflow/compiler/jit/pjrt_base_device.h new file mode 100644 index 00000000000..26c8f88efab --- /dev/null +++ b/tensorflow/compiler/jit/pjrt_base_device.h @@ -0,0 +1,111 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_JIT_PJRT_BASE_DEVICE_H_ +#define TENSORFLOW_COMPILER_JIT_PJRT_BASE_DEVICE_H_ + +#include +#include +#include + +#include "tensorflow/compiler/tf2xla/layout_util.h" +#include "tensorflow/core/common_runtime/local_device.h" +#include "tensorflow/core/framework/device_base.h" + +namespace tensorflow { + +// tensorflow::PjRtBaseDevice replaces the deprecated tensorflow::XlaDevice. +// This accelerator agnostic device is mainly used to store metadata. +class PjRtBaseDevice : public LocalDevice { + public: + // Stores metadata about the PjRtBaseDevice. + class Metadata { + public: + Metadata(const DeviceType& jit_device_type, + std::vector + shape_determination_fns) + : jit_device_type_(jit_device_type), + shape_determination_fns_(std::move(shape_determination_fns)) {} + + // The index of the device on this host. + int device_ordinal() const; + + const DeviceType& jit_device_type() const { return jit_device_type_; } + const XlaShapeLayoutHelpers::ShapeDeterminationFns& + default_shape_determination_fns() const { + return shape_determination_fns_.at(0); + } + + const XlaShapeLayoutHelpers::ShapeDeterminationFns& + shape_determination_fns_at(int i) const { + return shape_determination_fns_[i]; + } + + private: + const DeviceType jit_device_type_; + std::vector + shape_determination_fns_; + + TF_DISALLOW_COPY_AND_ASSIGN(Metadata); + }; + + struct Options { + // The device name's prefix (e.g., "/task:7") + std::string device_name_prefix; + + // The name of the device (e.g., "TPU") + std::string device_name; + + // The index of the device. + int device_ordinal = -1; + + // The name of the compilation device, also referred to as jit_device_type. + // (e.g., "XLA_CPU_JIT"); + std::string compilation_device_name; + + // A vector of ShapeDeterminationFn (i.e., a bundle of LayoutSelectionFn, + // ShapeRepresentationFn). Each bundle describes how the on-host shapes of + // a) argument and return value, for entry computations b) variables, for + // all computations, should be represented in XLA. Parameters/return values + // will be shaped according to the function pair, and reshaped back to/from + // their declared shapes for computations. Must be non-empty. + std::vector + shape_determination_fns; + + Options(std::string device_name_prefix, std::string device_name, + int device_ordinal, std::string compilation_device_name, + std::vector + shape_determination_fns) + : device_name_prefix(device_name_prefix), + device_name(device_name), + device_ordinal(device_ordinal), + compilation_device_name(compilation_device_name), + shape_determination_fns(shape_determination_fns) {} + }; + + // Creates a new PJRT base device. + PjRtBaseDevice(const SessionOptions& session_options, const Options& options); + + static StatusOr GetMetadataFromDevice( + DeviceBase* device); + + private: + // The metadata of this PjRtBaseDevice. + const Metadata metadata_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_PJRT_BASE_DEVICE_H_ diff --git a/tensorflow/compiler/jit/pjrt_device_context.cc b/tensorflow/compiler/jit/pjrt_device_context.cc index 6c0cd4c50cb..90e12d218d7 100644 --- a/tensorflow/compiler/jit/pjrt_device_context.cc +++ b/tensorflow/compiler/jit/pjrt_device_context.cc @@ -16,41 +16,68 @@ limitations under the License. 
#include "tensorflow/compiler/jit/pjrt_device_context.h" #include +#include #include #include "tensorflow/compiler/tf2xla/literal_util.h" -#include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/xla/pjrt/pjrt_client.h" #include "tensorflow/core/framework/device.h" #include "tensorflow/core/profiler/lib/traceme.h" #include "tensorflow/core/tfrt/common/async_value_tensor.h" #include "tensorflow/core/tfrt/common/create_pjrt_client_util.h" +#include "tensorflow/tsl/framework/device_id_utils.h" namespace tensorflow { namespace { StatusOr> HostTensorToPjRtBuffer( const tensorflow::Tensor* cpu_tensor, tensorflow::Device* device, - xla::PjRtClient* pjrt_client) { - // TODO(b/262472386): Consider layout_preference_fn and - // shape_representation_fn. - xla::Shape shape; - TF_RETURN_IF_ERROR( - TensorShapeToXLAShape(cpu_tensor->dtype(), cpu_tensor->shape(), &shape)); + xla::PjRtClient* pjrt_client, + const XlaShapeLayoutHelpers::ShapeDeterminationFns + shape_determination_fns) { + XlaLayoutPreference layout_preference = + shape_determination_fns.layout_preference_fn( + cpu_tensor->shape(), cpu_tensor->dtype(), std::nullopt); + TF_ASSIGN_OR_RETURN(xla::Shape shape, + shape_determination_fns.shape_representation_fn( + cpu_tensor->shape(), cpu_tensor->dtype(), + /*fast_mem=*/false, layout_preference)); + const xla::Layout* device_layout = &(shape.layout()); + // The device id should matche the local_hardware_id in + // tensorflow/compiler/xla/pjrt/pjrt_client.h. TF_ASSIGN_OR_RETURN( - xla::PjRtDevice * pjrt_device, - pjrt_client->LookupAddressableDevice(device->parsed_name().id)); - TF_ASSIGN_OR_RETURN( - std::unique_ptr buffer, - pjrt_client->BufferFromHostBuffer( - cpu_tensor->data(), shape.element_type(), shape.dimensions(), - /*byte_strides=*/std::nullopt, - xla::PjRtClient::HostBufferSemantics::kZeroCopy, - /*on_done_with_host_buffer=*/ - [cpu_tensor = *cpu_tensor]() { /* frees tensor */ }, pjrt_device)); - return buffer; + const int pjrt_device_id, + tsl::GetDeviceIdFromDeviceParsedName(device->parsed_name(), + DeviceType(device->device_type()))); + TF_ASSIGN_OR_RETURN(xla::PjRtDevice * pjrt_device, + pjrt_client->LookupAddressableDevice(pjrt_device_id)); + auto first_try_buffer = pjrt_client->BufferFromHostBuffer( + cpu_tensor->data(), shape.element_type(), shape.dimensions(), + /*byte_strides=*/std::nullopt, + xla::PjRtClient::HostBufferSemantics::kZeroCopy, + /*on_done_with_host_buffer=*/ + [cpu_tensor = *cpu_tensor]() { /* frees tensor */ }, pjrt_device, + device_layout); + if (first_try_buffer.ok()) { + return std::move(*first_try_buffer); + } + if (first_try_buffer.status().code() == absl::StatusCode::kUnimplemented) { + LOG_FIRST_N(WARNING, 1) + << first_try_buffer.status() + << "; fallback to BufferFromHostBuffer without device layout."; + TF_ASSIGN_OR_RETURN( + std::unique_ptr second_try_buffer, + pjrt_client->BufferFromHostBuffer( + cpu_tensor->data(), shape.element_type(), shape.dimensions(), + /*byte_strides=*/std::nullopt, + xla::PjRtClient::HostBufferSemantics::kZeroCopy, + /*on_done_with_host_buffer=*/ + [cpu_tensor = *cpu_tensor]() { /* frees tensor */ }, pjrt_device)); + return second_try_buffer; + } else { + return first_try_buffer.status(); + } } - } // namespace void PjRtDeviceContext::CopyDeviceTensorToCPU(const Tensor* device_tensor, @@ -101,18 +128,16 @@ void PjRtDeviceContext::CopyCPUTensorToDevice(const Tensor* cpu_tensor, done(pjrt_client.status()); return; } - StatusOr> buffer_or = - HostTensorToPjRtBuffer(cpu_tensor, device, *pjrt_client); + 
StatusOr> buffer_or = HostTensorToPjRtBuffer( + cpu_tensor, device, *pjrt_client, shape_determination_fns_); if (!buffer_or.ok()) { done(buffer_or.status()); return; } - std::unique_ptr device_buffer = std::move(buffer_or.value()); + result_tensor->SetBuffer(std::move(*buffer_or)); // TODO(b/244666476): evaluate the performance impact of marking ready when - // the data in device buffer is computed. In `tpu_device_context`, it is - // marked done when the allocation finished. - device_buffer->GetReadyFuture().OnReady(std::move(done)); - result_tensor->SetBuffer(std::move(device_buffer)); + // the data in device buffer is computed. + result_tensor->GetBuffer()->GetReadyFuture().OnReady(std::move(done)); } void PjRtDeviceContext::CopyTensorInSameDevice(const Tensor* input_tensor, diff --git a/tensorflow/compiler/jit/pjrt_device_context.h b/tensorflow/compiler/jit/pjrt_device_context.h index 42e72dbd9d7..519598d3fe8 100644 --- a/tensorflow/compiler/jit/pjrt_device_context.h +++ b/tensorflow/compiler/jit/pjrt_device_context.h @@ -16,9 +16,9 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_JIT_PJRT_DEVICE_CONTEXT_H_ #define TENSORFLOW_COMPILER_JIT_PJRT_DEVICE_CONTEXT_H_ -#include +#include -#include "tensorflow/compiler/xla/pjrt/pjrt_client.h" +#include "tensorflow/compiler/tf2xla/layout_util.h" #include "tensorflow/core/framework/device_base.h" #include "tensorflow/core/platform/status.h" @@ -28,6 +28,10 @@ namespace tensorflow { // devices using PjRt. class PjRtDeviceContext : public DeviceContext { public: + explicit PjRtDeviceContext( + XlaShapeLayoutHelpers::ShapeDeterminationFns shape_determination_fns) + : shape_determination_fns_(std::move(shape_determination_fns)) {} + void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device, Tensor* device_tensor, StatusCallback done, bool sync_dst_compute) const override; @@ -37,6 +41,9 @@ class PjRtDeviceContext : public DeviceContext { void CopyTensorInSameDevice(const Tensor* input_tensor, Device* device, Tensor* output_tensor, StatusCallback done) const override; + + private: + XlaShapeLayoutHelpers::ShapeDeterminationFns shape_determination_fns_; }; } // namespace tensorflow diff --git a/tensorflow/compiler/jit/test_util.cc b/tensorflow/compiler/jit/test_util.cc index 6add3dae494..8c1268fc09b 100644 --- a/tensorflow/compiler/jit/test_util.cc +++ b/tensorflow/compiler/jit/test_util.cc @@ -15,8 +15,14 @@ limitations under the License. 
#include "tensorflow/compiler/jit/test_util.h" +#include +#include +#include + #include "tensorflow/compiler/jit/shape_inference.h" #include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/core/framework/device_factory.h" +#include "tensorflow/core/public/version.h" namespace tensorflow { @@ -54,4 +60,39 @@ Status ShapeAnnotationsMatch( return OkStatus(); } +void DeviceSetup::AddDevicesAndSetUp( + const std::vector& device_names) { + SessionOptions options; + auto* device_count = options.config.mutable_device_count(); + for (const auto& device_name : device_names) { + device_count->insert({device_name, 1}); + } + + std::vector> devices; + TF_CHECK_OK(DeviceFactory::AddDevices( + options, "/job:localhost/replica:0/task:0", &devices)); + device_mgr_ = std::make_unique(std::move(devices)); + + OptimizerOptions opts; + lib_def_ = std::make_unique(OpRegistry::Global(), + FunctionDefLibrary()); + pflr_ = std::make_unique( + device_mgr_.get(), Env::Default(), /*config=*/nullptr, + TF_GRAPH_DEF_VERSION, lib_def_.get(), opts, + /*default_thread_pool=*/nullptr, /*cluster_flr=*/nullptr); + flr_ = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:0"); +} + +Device* DeviceSetup::GetDevice(const string& device_name) { + if (device_mgr_ == nullptr) { + return nullptr; + } + + string full_device_name = absl::StrCat( + "/job:localhost/replica:0/task:0/device:", device_name, ":0"); + Device* device; + TF_CHECK_OK(device_mgr_->LookupDevice(full_device_name, &device)); + return device; +} + } // namespace tensorflow diff --git a/tensorflow/compiler/jit/test_util.h b/tensorflow/compiler/jit/test_util.h index b5982c490df..aad58daab2a 100644 --- a/tensorflow/compiler/jit/test_util.h +++ b/tensorflow/compiler/jit/test_util.h @@ -19,11 +19,15 @@ limitations under the License. #define TENSORFLOW_COMPILER_JIT_TEST_UTIL_H_ #include -#include +#include +#include +#include #include #include "tensorflow/compiler/jit/shape_inference.h" +#include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/common_runtime/optimization_registry.h" +#include "tensorflow/core/common_runtime/process_function_library_runtime.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/partial_tensor_shape.h" @@ -62,7 +66,20 @@ struct GraphOptimizationPassWrapper { SessionOptions session_options; }; +// Helps set up devices for unit tests. 
+class DeviceSetup { + public: + void AddDevicesAndSetUp(const std::vector& device_names); + Device* GetDevice(const string& device_name); + FunctionLibraryRuntime* flr() { return flr_; } + + private: + FunctionLibraryRuntime* flr_; + std::unique_ptr device_mgr_; + std::unique_ptr lib_def_; + std::unique_ptr pflr_; +}; + } // namespace tensorflow - #endif // TENSORFLOW_COMPILER_JIT_TEST_UTIL_H_ diff --git a/tensorflow/compiler/jit/tests/device_compiler_serialize_options_test.cc b/tensorflow/compiler/jit/tests/device_compiler_serialize_options_test.cc index 3eaa8202261..3da7ac13eae 100644 --- a/tensorflow/compiler/jit/tests/device_compiler_serialize_options_test.cc +++ b/tensorflow/compiler/jit/tests/device_compiler_serialize_options_test.cc @@ -41,7 +41,7 @@ TEST_F(DeviceCompilerSerializeTest, PersistentCacheOptionsTest) { AlterPersistentCacheEntryHloModuleNames(tensorflow::testing::TmpDir()); EXPECT_FALSE(status.ok()); EXPECT_TRUE(absl::StrContains( - status.error_message(), + status.message(), "Did not find any persistent XLA compilation cache entries to alter.")); TF_ASSERT_OK(AlterPersistentCacheEntryHloModuleNames( diff --git a/tensorflow/compiler/jit/tests/device_compiler_serialize_test.cc b/tensorflow/compiler/jit/tests/device_compiler_serialize_test.cc index 984b9852535..9233d8e43e5 100644 --- a/tensorflow/compiler/jit/tests/device_compiler_serialize_test.cc +++ b/tensorflow/compiler/jit/tests/device_compiler_serialize_test.cc @@ -57,8 +57,8 @@ TEST_F(DeviceCompilerSerializeTest, PersistentCacheTest) { for (int b = 1; b < 4; ++b) { auto status = ExecuteWithBatch(graph, b); EXPECT_FALSE(status.ok()); - EXPECT_TRUE(absl::StrContains(status.error_message(), - "Serialized HLO does not match.")); + EXPECT_TRUE( + absl::StrContains(status.message(), "Serialized HLO does not match.")); } } diff --git a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc index ae74dccec69..f6bdaf4e0bc 100644 --- a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc +++ b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc @@ -17,33 +17,93 @@ limitations under the License. 
#include "tensorflow/compiler/jit/xla_compile_on_demand_op.h" +#include #include #include +#include +#include "absl/log/check.h" +#include "absl/log/log.h" #include "absl/memory/memory.h" +#include "absl/types/span.h" #include "tensorflow/compiler/jit/device_compilation_profiler.h" +#include "tensorflow/compiler/jit/device_compiler.h" +#include "tensorflow/compiler/jit/flags.h" +#include "tensorflow/compiler/jit/variable_info.h" +#include "tensorflow/compiler/jit/variable_info_util.h" +#include "tensorflow/compiler/jit/xla_compile_util.h" #include "tensorflow/compiler/jit/xla_compiler_options_util.h" -#include "tensorflow/compiler/jit/xla_device.h" #include "tensorflow/compiler/jit/xla_launch_util.h" #include "tensorflow/compiler/jit/xla_platform_info.h" #include "tensorflow/compiler/tf2xla/const_analysis.h" #include "tensorflow/compiler/tf2xla/tf2xla_util.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/executable_run_options.h" +#include "tensorflow/compiler/xla/hlo/ir/hlo_input_output_alias_config.h" +#include "tensorflow/compiler/xla/pjrt/pjrt_client.h" +#include "tensorflow/compiler/xla/service/executable.h" #include "tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h" #include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/op_requires.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/refcount.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/statusor.h" +#include "tensorflow/tsl/platform/errors.h" namespace tensorflow { namespace { using XlaDeviceCompiler = DeviceCompiler; +using PjRtDeviceCompiler = + DeviceCompiler; + +XlaCompiler::CompileOptions GetCompileOptions(bool for_pjrt = false) { + XlaCompiler::CompileOptions compile_options; + compile_options.is_entry_computation = true; + // Optimization: where possible, have the computation return a naked array + // rather than a one-element tuple. + compile_options.always_return_tuple = false; + if (for_pjrt) { + compile_options.use_tuple_arg = false; + compile_options.always_return_tuple = true; + } + + return compile_options; +} + +// Gets `variables` from `ctx`, locks them and builds XlaCompiler::Arguments +// using them. Stores the arguments in `args`. `variables` and `args` passed in +// will be cleared before populating them. 
+Status GetAndLockVariablesAndBuildXlaCompilerArguments( + const OpKernelContext& ctx, const std::vector& inputs, + const std::vector& constant_indices, + const std::vector& variable_indices, + std::vector* variables, + std::vector* args) { + variables->clear(); + args->clear(); + TF_RETURN_IF_ERROR(GetVariableInfosFromInputs(ctx.resource_manager(), + ctx.device(), inputs, + variable_indices, variables)); + TF_RETURN_IF_ERROR(LockVariables(absl::MakeSpan(*variables))); + TF_ASSIGN_OR_RETURN(*args, + XlaComputationLaunchContext::BuildXlaCompilerArguments( + constant_indices, inputs, *variables, + static_cast(ctx.device()))); + return OkStatus(); +} } // namespace -Status XlaCompileOnDemandOp::Run(OpKernelContext* ctx, - XlaDeviceCompiler* xla_device_compiler, +Status XlaCompileOnDemandOp::Run(const ResourceVarsSnapshot& variable_args, const XlaCompiler::CompilationResult* result, + const XlaDeviceCompiler* xla_device_compiler, xla::LocalExecutable* executable, - const ResourceVarsSnapshot& variable_args) { + OpKernelContext* ctx) { xla::LocalClient* client = static_cast(xla_device_compiler->client()); @@ -104,14 +164,48 @@ Status XlaCompileOnDemandOp::Run(OpKernelContext* ctx, } Status XlaCompileOnDemandOp::Compile( - OpKernelContext* ctx, const XlaCompiler::CompilationResult** result, - XlaDeviceCompiler** xla_device_compiler, - DeviceCompilationProfiler** profiler, ResourceVarsSnapshot* variable_args, - xla::LocalExecutable** executable) { - TF_ASSIGN_OR_RETURN(std::vector constant_input_indices, - GetConstantInputIndicesFromContext(ctx)); - std::vector inputs = InputsFromContext(ctx); + const std::vector& args, OpKernelContext* ctx, + PjRtDeviceCompiler** pjrt_device_compiler, + DeviceCompilationProfiler** profiler, + const XlaCompiler::CompilationResult** result, + xla::PjRtLoadedExecutable** executable) { + // We store information about the JIT-compiled XLA computation + // in the ResourceMgr. + ResourceMgr* rm = ctx->resource_manager(); + if (!rm) { + return errors::Internal("No resource manager."); + } + TF_RETURN_IF_ERROR(rm->LookupOrCreate( + rm->default_container(), "pjrt_device_compiler", pjrt_device_compiler, + [&](PjRtDeviceCompiler** pjrt_device_compiler) { + return BuildPjRtDeviceCompiler(platform_info_, ctx->function_library(), + pjrt_device_compiler); + })); + TF_RETURN_IF_ERROR(rm->LookupOrCreate( + rm->default_container(), "pjrt_device_compilation_profiler", profiler, + [](DeviceCompilationProfiler** profiler) { + *profiler = new DeviceCompilationProfiler(); + return OkStatus(); + })); + + XlaCompiler::Options options = GenerateCompilerOptionsForPjRt( + *(ctx->function_library()), ctx->device(), platform_info_); + // No detailed logging for on demand op. + options.detailed_logging = false; + XlaCompiler::CompileOptions compile_options = GetCompileOptions(true); + + return (*pjrt_device_compiler) + ->CompileSingleOpIfNeeded(options, args, compile_options, ctx, *profiler, + result, executable); +} + +Status XlaCompileOnDemandOp::Compile( + const std::vector& args, OpKernelContext* ctx, + XlaDeviceCompiler** xla_device_compiler, + DeviceCompilationProfiler** profiler, + const XlaCompiler::CompilationResult** result, + xla::LocalExecutable** executable) { // We store information about the JIT-compiled XLA computation // in the ResourceMgr. ResourceMgr* rm = ctx->resource_manager(); @@ -137,54 +231,87 @@ Status XlaCompileOnDemandOp::Compile( platform_info_, /*has_ref_vars=*/true); // No detailed logging from on demand op. 
options.detailed_logging = false; - XlaCompiler::CompileOptions compile_options; - compile_options.is_entry_computation = true; - // Optimization: where possible, have the computation return a naked array - // rather than a one-element tuple. - compile_options.always_return_tuple = false; - - std::vector variables_indices = - GetResourceVariableIndicesFromContext(ctx); - StatusOr> args; - { - std::vector variable_infos; - TF_RETURN_IF_ERROR( - GetVariableInfosFromInputs(ctx->resource_manager(), ctx->device(), - inputs, variables_indices, &variable_infos)); - - TF_RETURN_IF_ERROR(LockVariables(absl::MakeSpan(variable_infos))); - TF_RETURN_IF_ERROR(SnapshotResourceVariables( - ctx, variables_indices, variable_infos, variable_args)); - - args = XlaComputationLaunchContext::BuildXlaCompilerArguments( - constant_input_indices, inputs, variable_infos, - static_cast(ctx->device())); - TF_RETURN_IF_ERROR(args.status()); - } + XlaCompiler::CompileOptions compile_options = GetCompileOptions(); return (*xla_device_compiler) - ->CompileSingleOpIfNeeded(options, *args, compile_options, ctx, *profiler, + ->CompileSingleOpIfNeeded(options, args, compile_options, ctx, *profiler, result, executable); } void XlaCompileOnDemandOp::Compute(OpKernelContext* ctx) { const XlaCompiler::CompilationResult* result; - xla::LocalExecutable* executable; - ResourceVarsSnapshot variable_args; - XlaDeviceCompiler* xla_device_compiler; DeviceCompilationProfiler* profiler; + OP_REQUIRES(ctx, ctx->function_library(), errors::Internal("Function library missing")); - OP_REQUIRES_OK(ctx, Compile(ctx, &result, &xla_device_compiler, &profiler, - &variable_args, &executable)); - // Hold the reference to the XLA device compiler and profiler during - // evaluation. (We could probably free them sooner because the ResourceMgr - // will retain references, but this is more obviously correct.) - core::ScopedUnref xla_device_compiler_ref(xla_device_compiler); - core::ScopedUnref profiler_ref(profiler); - OP_REQUIRES_OK( - ctx, Run(ctx, xla_device_compiler, result, executable, variable_args)); + // Get constants, inputs and variables from the OpKernelContext. + auto constant_indices_or = GetConstantInputIndicesFromContext(ctx); + OP_REQUIRES_OK(ctx, constant_indices_or.status()); + std::vector inputs = InputsFromContext(ctx); + std::vector variable_indices = + GetResourceVariableIndicesFromContext(ctx); + + bool use_pjrt = + GetXlaOpsCommonFlags() + ->tf_xla_use_device_api.IsEnabledInXlaCompileOnDemandForDevice( + platform_info_.device_type()); + if (use_pjrt) { + std::vector variables; + std::vector args; + // Lock variables for the whole duration of compile + execute. + OP_REQUIRES_OK(ctx, GetAndLockVariablesAndBuildXlaCompilerArguments( + *ctx, inputs, *constant_indices_or, + variable_indices, &variables, &args)); + + PjRtDeviceCompiler* pjrt_device_compiler; + xla::PjRtLoadedExecutable* pjrt_executable; + OP_REQUIRES_OK(ctx, Compile(args, ctx, &pjrt_device_compiler, &profiler, + &result, &pjrt_executable)); + // Hold the reference to the XLA device compiler and profiler during + // evaluation. (We could probably free them sooner because the ResourceMgr + // will retain references, but this is more obviously correct.) 
+ core::ScopedUnref pjrt_device_compiler_ref(pjrt_device_compiler); + core::ScopedUnref profiler_ref(profiler); + + VLOG(2) << "Compiled op with PJRT: " << ctx->status(); + VLOG(2) << "result != nullptr: " << (result != nullptr); + VLOG(2) << "pjrt_executable != nullptr: " << (pjrt_executable != nullptr); + VLOG(2) << "Executing with PJRT ..."; + + OP_REQUIRES_OK(ctx, + RunPjRtExecutable(*pjrt_device_compiler->client(), inputs, + variables, *result, pjrt_executable, ctx)); + + VLOG(2) << "Completed executing with PJRT!"; + } else { + ResourceVarsSnapshot variable_args; + std::vector args; + // Lock variables only for generating XlaCompiler::Arguments and then + // release them. + { + std::vector variables; + OP_REQUIRES_OK(ctx, GetAndLockVariablesAndBuildXlaCompilerArguments( + *ctx, inputs, *constant_indices_or, + variable_indices, &variables, &args)); + OP_REQUIRES_OK(ctx, SnapshotResourceVariables(ctx, variable_indices, + variables, &variable_args)); + } + + XlaDeviceCompiler* xla_device_compiler; + xla::LocalExecutable* executable; + OP_REQUIRES_OK(ctx, Compile(args, ctx, &xla_device_compiler, &profiler, + &result, &executable)); + // Hold the reference to the XLA device compiler and profiler during + // evaluation. (We could probably free them sooner because the ResourceMgr + // will retain references, but this is more obviously correct.) + core::ScopedUnref xla_device_compiler_ref(xla_device_compiler); + core::ScopedUnref profiler_ref(profiler); + + // Locks are acquired again when populating the `ctx` outputs. + OP_REQUIRES_OK( + ctx, Run(variable_args, result, xla_device_compiler, executable, ctx)); + } } } // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_compile_on_demand_op.h b/tensorflow/compiler/jit/xla_compile_on_demand_op.h index c5e6e8a8c72..ced95edc604 100644 --- a/tensorflow/compiler/jit/xla_compile_on_demand_op.h +++ b/tensorflow/compiler/jit/xla_compile_on_demand_op.h @@ -19,14 +19,16 @@ limitations under the License. 
#ifndef TENSORFLOW_COMPILER_JIT_XLA_COMPILE_ON_DEMAND_OP_H_ #define TENSORFLOW_COMPILER_JIT_XLA_COMPILE_ON_DEMAND_OP_H_ +#include + #include "tensorflow/compiler/jit/device_compilation_profiler.h" #include "tensorflow/compiler/jit/variable_info.h" #include "tensorflow/compiler/jit/variable_info_util.h" -#include "tensorflow/compiler/jit/xla_device.h" #include "tensorflow/compiler/jit/xla_launch_util.h" #include "tensorflow/compiler/jit/xla_platform_info.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/pjrt/pjrt_client.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/types.h" @@ -45,21 +47,27 @@ class XlaCompileOnDemandOp : public OpKernel { void Compute(OpKernelContext* ctx) override; private: - XlaCompiler::Argument CreateCompilerArgument(OpKernelContext* ctx, int64_t i); - Status Compile(OpKernelContext* ctx, - const XlaCompiler::CompilationResult** result, + Status Compile(const std::vector& args, + OpKernelContext* ctx, DeviceCompiler** xla_device_compiler, DeviceCompilationProfiler** profiler, - ResourceVarsSnapshot* variable_args, + const XlaCompiler::CompilationResult** result, xla::LocalExecutable** executable); - Status Run(OpKernelContext* ctx, - DeviceCompiler* - xla_device_compiler, + Status Compile(const std::vector& args, + OpKernelContext* ctx, + DeviceCompiler** + pjrt_device_compiler, + DeviceCompilationProfiler** profiler, + const XlaCompiler::CompilationResult** result, + xla::PjRtLoadedExecutable** executable); + + Status Run(const ResourceVarsSnapshot& variable_args, const XlaCompiler::CompilationResult* result, - xla::LocalExecutable* executable, - const ResourceVarsSnapshot& variable_args); + const DeviceCompiler* + xla_device_compiler, + xla::LocalExecutable* executable, OpKernelContext* ctx); const XlaPlatformInfo platform_info_; }; diff --git a/tensorflow/compiler/jit/xla_compile_util.cc b/tensorflow/compiler/jit/xla_compile_util.cc index 8d72d20ba55..e5256a8b2c9 100644 --- a/tensorflow/compiler/jit/xla_compile_util.cc +++ b/tensorflow/compiler/jit/xla_compile_util.cc @@ -66,8 +66,10 @@ StatusOr> CreateSingleOpGraph( return graph; } -bool UsePjRtForSingleDeviceCompilation() { - return GetXlaOpsCommonFlags()->tf_xla_use_device_api; +bool UsePjRtForSingleDeviceCompilation(const DeviceType& device_type) { + const auto& rollout_config = GetXlaOpsCommonFlags()->tf_xla_use_device_api; + return rollout_config.IsEnabledInXlaLaunchForDevice(device_type) || + rollout_config.IsEnabledInXlaCompileOnDemandForDevice(device_type); } } // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_compile_util.h b/tensorflow/compiler/jit/xla_compile_util.h index bdc0ebafad5..345c55a86e5 100644 --- a/tensorflow/compiler/jit/xla_compile_util.h +++ b/tensorflow/compiler/jit/xla_compile_util.h @@ -44,7 +44,9 @@ StatusOr> CreateSingleOpGraph( const NodeDef& node_def, absl::Span args, absl::Span result_types); -bool UsePjRtForSingleDeviceCompilation(); +// Checks if single device compilation and execution with PJRT is enabled for +// `device_type` in either the XlaLaunch op or the XlaCompileOnDemand op. 
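+// Note: this returns false by default; both the corresponding rollout flag
+// and a per-device allowlist entry are required (see the examples in
+// xla_compile_util_test.cc).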
+bool UsePjRtForSingleDeviceCompilation(const DeviceType& device_type); } // namespace tensorflow #endif // TENSORFLOW_COMPILER_JIT_XLA_COMPILE_UTIL_H_ diff --git a/tensorflow/compiler/jit/xla_compile_util_test.cc b/tensorflow/compiler/jit/xla_compile_util_test.cc index 0e971a6b4db..9fc706fb649 100644 --- a/tensorflow/compiler/jit/xla_compile_util_test.cc +++ b/tensorflow/compiler/jit/xla_compile_util_test.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include "tensorflow/compiler/jit/flags.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/core/framework/fake_input.h" #include "tensorflow/core/kernels/ops_testutil.h" @@ -73,5 +74,49 @@ TEST_F(OpsTestBase, CreateSingleOpGraph) { EXPECT_EQ(retval_input_node->name(), "identity_op"); } +TEST(XlaCompileUtilTest, PjRtXlaLaunchFlagTest) { + EXPECT_FALSE(UsePjRtForSingleDeviceCompilation(DeviceType(DEVICE_CPU))); + + // Flag is turned on, but no device is allowlisted. + auto& rollout_config = GetXlaOpsCommonFlags()->tf_xla_use_device_api; + rollout_config.enabled_for_xla_launch_ = true; + + EXPECT_FALSE(UsePjRtForSingleDeviceCompilation(DeviceType(DEVICE_CPU))); + + // Flag is turned on, some device is allowlisted, but the requested one isn't. + rollout_config.AllowForDeviceInXlaLaunch(DeviceType(DEVICE_GPU)); + EXPECT_FALSE(UsePjRtForSingleDeviceCompilation(DeviceType(DEVICE_CPU))); + + // Flag is turned on and the requested device is allowlisted. + rollout_config.AllowForDeviceInXlaLaunch(DeviceType(DEVICE_CPU)); + EXPECT_TRUE(UsePjRtForSingleDeviceCompilation(DeviceType(DEVICE_CPU))); + + // The requested device is allowlisted, but the flag is turned off. + rollout_config.enabled_for_xla_launch_ = false; + EXPECT_FALSE(UsePjRtForSingleDeviceCompilation(DeviceType(DEVICE_CPU))); +} + +TEST(XlaCompileUtilTest, PjRtXlaCompileOnDemandFlagTest) { + EXPECT_FALSE(UsePjRtForSingleDeviceCompilation(DeviceType(DEVICE_CPU))); + + // Flag is turned on, but no device is allowlisted. + auto& rollout_config = GetXlaOpsCommonFlags()->tf_xla_use_device_api; + rollout_config.enabled_for_compile_on_demand_ = true; + + EXPECT_FALSE(UsePjRtForSingleDeviceCompilation(DeviceType(DEVICE_CPU))); + + // Flag is turned on, some device is allowlisted, but the requested one isn't. + rollout_config.AllowForDeviceInXlaCompileOnDemand(DeviceType(DEVICE_GPU)); + EXPECT_FALSE(UsePjRtForSingleDeviceCompilation(DeviceType(DEVICE_CPU))); + + // Flag is turned on and the requested device is allowlisted. + rollout_config.AllowForDeviceInXlaCompileOnDemand(DeviceType(DEVICE_CPU)); + EXPECT_TRUE(UsePjRtForSingleDeviceCompilation(DeviceType(DEVICE_CPU))); + + // The requested device is allowlisted, but the flag is turned off. 
+ rollout_config.enabled_for_compile_on_demand_ = false; + EXPECT_FALSE(UsePjRtForSingleDeviceCompilation(DeviceType(DEVICE_CPU))); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_compiler_options_util.cc b/tensorflow/compiler/jit/xla_compiler_options_util.cc index c8ffcfa6d8b..8580bcfbeef 100644 --- a/tensorflow/compiler/jit/xla_compiler_options_util.cc +++ b/tensorflow/compiler/jit/xla_compiler_options_util.cc @@ -86,8 +86,13 @@ XlaCompiler::Options GenerateCompilerOptionsForPjRt( options.device_ordinal = device_base->parsed_name().id; options.flib_def = function_library.GetFunctionLibraryDefinition(); options.graph_def_version = function_library.graph_def_version(); - if (platform_info.xla_device_metadata()) { - auto metadata = platform_info.xla_device_metadata(); + if (const auto* metadata = platform_info.xla_device_metadata(); + metadata != nullptr) { + options.device_type = metadata->jit_device_type(); + options.shape_determination_fns = + metadata->default_shape_determination_fns(); + } else if (const auto* metadata = platform_info.pjrt_device_metadata(); + metadata != nullptr) { options.device_type = metadata->jit_device_type(); options.shape_determination_fns = metadata->default_shape_determination_fns(); diff --git a/tensorflow/compiler/jit/xla_compiler_options_util_test.cc b/tensorflow/compiler/jit/xla_compiler_options_util_test.cc index 06bcfe2facb..2a4742567e4 100644 --- a/tensorflow/compiler/jit/xla_compiler_options_util_test.cc +++ b/tensorflow/compiler/jit/xla_compiler_options_util_test.cc @@ -23,6 +23,8 @@ limitations under the License. #include #include "tensorflow/compiler/jit/flags.h" +#include "tensorflow/compiler/jit/test_util.h" +#include "tensorflow/compiler/jit/xla_platform_info.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" @@ -52,8 +54,8 @@ XlaDeviceCompiler* CreateXlaDeviceCompiler( std::move(compiler_client)); } -std::unique_ptr CreateXlaDeviceMetadata( - DeviceType compilation_device_type) { +std::vector +GetShapeDeterminationFns() { XlaHelpers::ShapeRepresentationFn shape_representation_fn = [](const TensorShape&, DataType, bool, XlaLayoutPreference) { return xla::Shape(); @@ -62,73 +64,83 @@ std::unique_ptr CreateXlaDeviceMetadata( [](const TensorShape&, DataType, std::optional) { return tensorflow::XlaLayoutPreference::kTpuPreferLinearLayout; }; - std::vector - shape_determination_fns = {XlaShapeLayoutHelpers::ShapeDeterminationFns{ - layout_preference_fn, shape_representation_fn}}; + return {XlaShapeLayoutHelpers::ShapeDeterminationFns{ + layout_preference_fn, shape_representation_fn}}; +} + +std::unique_ptr CreateXlaDeviceMetadata( + DeviceType compilation_device_type) { return std::make_unique( /*device_ordinal=*/0, /*platform=*/nullptr, compilation_device_type, - shape_determination_fns, XlaDevice::PaddedShapeFn(), + GetShapeDeterminationFns(), XlaDevice::PaddedShapeFn(), /*use_multiple_streams=*/false); } +std::unique_ptr CreatePjRtDeviceMetadata( + DeviceType compilation_device_type) { + return std::make_unique(compilation_device_type, + GetShapeDeterminationFns()); +} + class XlaCompilerOptionsTest : public ::testing::Test { protected: void SetUp() override { tensorflow::GetXlaDeviceFlags()->tf_xla_enable_xla_devices = true; } - void AddDevicesAndSetUp(const std::vector& device_names) { - SessionOptions options; - auto* device_count = options.config.mutable_device_count(); - for (const auto& 
device_name : device_names) { - device_count->insert({device_name, 1}); - } - - std::vector> devices; - TF_CHECK_OK(DeviceFactory::AddDevices( - options, "/job:localhost/replica:0/task:0", &devices)); - device_mgr_ = std::make_unique(std::move(devices)); - - OptimizerOptions opts; - lib_def_ = std::make_unique( - OpRegistry::Global(), FunctionDefLibrary()); - pflr_ = std::make_unique( - device_mgr_.get(), Env::Default(), /*config=*/nullptr, - TF_GRAPH_DEF_VERSION, lib_def_.get(), opts, - /*default_thread_pool=*/nullptr, /*cluster_flr=*/nullptr); - flr_ = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:0"); - } - - Device* GetXlaGpuDevice() { - if (device_mgr_ == nullptr) { - return nullptr; - } - - Device* device; - TF_CHECK_OK(device_mgr_->LookupDevice( - "/job:localhost/replica:0/task:0/device:XLA_GPU:0", &device)); - return device; - } - - FunctionLibraryRuntime* flr_; - std::unique_ptr device_mgr_; - std::unique_ptr lib_def_; - std::unique_ptr pflr_; + DeviceSetup device_setup_; }; TEST_F(XlaCompilerOptionsTest, PjRtOptionsXlaDevice) { - AddDevicesAndSetUp({DEVICE_XLA_GPU}); - Device* device = GetXlaGpuDevice(); + device_setup_.AddDevicesAndSetUp({DEVICE_XLA_GPU}); + Device* device = device_setup_.GetDevice(DEVICE_XLA_GPU); DeviceType compilation_device_type = DeviceType(DEVICE_GPU_XLA_JIT); se::Platform::Id platform_id = nullptr; auto xla_device_metadata = CreateXlaDeviceMetadata(compilation_device_type); std::shared_ptr custom_allocator; - XlaPlatformInfo platform_info(compilation_device_type, platform_id, - xla_device_metadata.get(), custom_allocator); + XlaPlatformInfo platform_info( + compilation_device_type, platform_id, xla_device_metadata.get(), + /*pjrt_device_metadata=*/nullptr, custom_allocator); - XlaCompiler::Options options = - GenerateCompilerOptionsForPjRt(*flr_, device, platform_info); + XlaCompiler::Options options = GenerateCompilerOptionsForPjRt( + *device_setup_.flr(), device, platform_info); + + EXPECT_EQ(options.device_type, compilation_device_type); + EXPECT_EQ(options.device_ordinal, 0); + EXPECT_NE(options.flib_def, nullptr); + EXPECT_EQ(options.graph_def_version, TF_GRAPH_DEF_VERSION); + EXPECT_FALSE(options.allow_cpu_custom_calls); + EXPECT_FALSE(options.alias_passthrough_params); + EXPECT_FALSE(options.detailed_logging); + // Check if options have the supplied shape determination functions set. + TF_ASSERT_OK_AND_ASSIGN( + auto shape, options.shape_determination_fns.shape_representation_fn( + TensorShape(), DT_FLOAT, false, + tensorflow::XlaLayoutPreference::kTpuPreferLinearLayout)); + EXPECT_EQ(shape, xla::Shape()); + EXPECT_EQ(options.shape_determination_fns.layout_preference_fn( + TensorShape(), DT_FLOAT, std::nullopt), + tensorflow::XlaLayoutPreference::kTpuPreferLinearLayout); +} + +TEST_F(XlaCompilerOptionsTest, PjRtOptionsPjRtBaseDevice) { + // Although DEVICE_CPU isn't a PjRtBaseDevice, we use it here just for testing + // purposes and to keep things simple. Creating a TpuDevice or + // NextPluggableDevice in the context of this unit test is non-trivial. 
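+  // The PjRtBaseDevice::Metadata that GenerateCompilerOptionsForPjRt reads
+  // from the platform info can be constructed directly from a jit device type
+  // and shape determination fns, so no PJRT-backed device is required.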
+ device_setup_.AddDevicesAndSetUp({DEVICE_CPU}); + Device* device = device_setup_.GetDevice(DEVICE_CPU); + DeviceType compilation_device_type = DeviceType(DEVICE_CPU_XLA_JIT); + + auto pjrt_device_metadata = CreatePjRtDeviceMetadata(compilation_device_type); + XlaPlatformInfo platform_info( + compilation_device_type, /*platform_id=*/nullptr, + /*xla_device_metadata=*/nullptr, + /*pjrt_device_metadata=*/pjrt_device_metadata.get(), + /*device_allocator=*/nullptr); + + XlaCompiler::Options options = GenerateCompilerOptionsForPjRt( + *device_setup_.flr(), device, platform_info); EXPECT_EQ(options.device_type, compilation_device_type); EXPECT_EQ(options.device_ordinal, 0); @@ -149,12 +161,12 @@ TEST_F(XlaCompilerOptionsTest, PjRtOptionsXlaDevice) { } TEST_F(XlaCompilerOptionsTest, XlaOptions) { - AddDevicesAndSetUp({DEVICE_XLA_CPU}); - Device* device = device_mgr_->HostCPU(); + device_setup_.AddDevicesAndSetUp({DEVICE_XLA_GPU}); + Device* device = device_setup_.GetDevice(DEVICE_XLA_GPU); xla::LocalClient* client = xla::ClientLibrary::LocalClientOrDie(); - DeviceType device_type = DeviceType(DEVICE_XLA_CPU); - DeviceType compilation_device_type = DeviceType(DEVICE_CPU_XLA_JIT); + DeviceType device_type = DeviceType(DEVICE_XLA_GPU); + DeviceType compilation_device_type = DeviceType(DEVICE_GPU_XLA_JIT); auto xla_device_compiler = CreateXlaDeviceCompiler( XlaDeviceExecutablePersistor::Config(), compilation_device_type, client); @@ -163,11 +175,13 @@ TEST_F(XlaCompilerOptionsTest, XlaOptions) { se::Platform::Id platform_id = se::host::kHostPlatformId; auto xla_device_metadata = CreateXlaDeviceMetadata(compilation_device_type); std::shared_ptr custom_allocator; - XlaPlatformInfo platform_info(device_type, platform_id, - xla_device_metadata.get(), custom_allocator); + XlaPlatformInfo platform_info( + device_type, platform_id, xla_device_metadata.get(), + /*pjrt_device_metadata=*/nullptr, custom_allocator); - XlaCompiler::Options options = GenerateCompilerOptions( - *xla_device_compiler, *flr_, device, nullptr, platform_info, false); + XlaCompiler::Options options = + GenerateCompilerOptions(*xla_device_compiler, *device_setup_.flr(), + device, nullptr, platform_info, false); EXPECT_EQ(options.device_type, compilation_device_type); EXPECT_NE(options.flib_def, nullptr); @@ -187,8 +201,8 @@ TEST_F(XlaCompilerOptionsTest, XlaOptions) { } TEST_F(XlaCompilerOptionsTest, XlaOptionsHasRefVarsNoXlaDeviceMetadata) { - AddDevicesAndSetUp({DEVICE_CPU}); - Device* device = device_mgr_->HostCPU(); + device_setup_.AddDevicesAndSetUp({DEVICE_CPU}); + Device* device = device_setup_.GetDevice(DEVICE_CPU); xla::LocalClient* client = xla::ClientLibrary::LocalClientOrDie(); DeviceType device_type = DeviceType(DEVICE_CPU); @@ -200,11 +214,13 @@ TEST_F(XlaCompilerOptionsTest, XlaOptionsHasRefVarsNoXlaDeviceMetadata) { se::Platform::Id platform_id = se::host::kHostPlatformId; std::shared_ptr custom_allocator; - XlaPlatformInfo platform_info(device_type, platform_id, nullptr, - custom_allocator); + XlaPlatformInfo platform_info( + device_type, platform_id, /*xla_device_metadata=*/nullptr, + /*pjrt_device_metadata=*/nullptr, custom_allocator); - XlaCompiler::Options options = GenerateCompilerOptions( - *xla_device_compiler, *flr_, device, nullptr, platform_info, false); + XlaCompiler::Options options = + GenerateCompilerOptions(*xla_device_compiler, *device_setup_.flr(), + device, nullptr, platform_info, false); EXPECT_EQ(options.device_type, compilation_device_type); EXPECT_NE(options.flib_def, nullptr); @@ -227,7 +243,7 @@ 
TEST_F(XlaCompilerOptionsTest, XlaOptionsHasRefVarsNoXlaDeviceMetadata) { } TEST_F(XlaCompilerOptionsTest, TfRtTpuOptions) { - AddDevicesAndSetUp({DEVICE_TPU_NODE}); + device_setup_.AddDevicesAndSetUp({DEVICE_TPU_NODE}); // Just use the default local client for testing purposes. xla::LocalClient* client = xla::ClientLibrary::LocalClientOrDie(); @@ -237,8 +253,8 @@ TEST_F(XlaCompilerOptionsTest, TfRtTpuOptions) { XlaDeviceExecutablePersistor::Config(), compilation_device_type, client); core::ScopedUnref xla_device_compiler_ref(xla_device_compiler); - XlaCompiler::Options options = - GenerateCompilerOptionsForTfrtTpu(*xla_device_compiler, *flr_); + XlaCompiler::Options options = GenerateCompilerOptionsForTfrtTpu( + *xla_device_compiler, *device_setup_.flr()); EXPECT_EQ(options.device_type, compilation_device_type); EXPECT_NE(options.flib_def, nullptr); diff --git a/tensorflow/compiler/jit/xla_device.cc b/tensorflow/compiler/jit/xla_device.cc index 883f0edd91b..4742f8a72ea 100644 --- a/tensorflow/compiler/jit/xla_device.cc +++ b/tensorflow/compiler/jit/xla_device.cc @@ -29,7 +29,6 @@ limitations under the License. #include "tensorflow/compiler/jit/xla_compile_on_demand_op.h" #include "tensorflow/compiler/jit/xla_compile_util.h" #include "tensorflow/compiler/jit/xla_device_context.h" -#include "tensorflow/compiler/jit/xla_device_ops.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/client_library.h" @@ -213,6 +212,7 @@ XlaDevice::XlaDevice(const SessionOptions& session_options, : DefaultPaddedShapeFn, options.use_multiple_streams), device_ordinal_(options.device_ordinal), + device_name_(options.device_name), jit_device_name_(options.compilation_device_name), platform_(options.platform), intra_op_parallelism_threads_( @@ -272,7 +272,7 @@ Allocator* XlaDevice::GetAllocatorLocked(AllocatorAttributes attr) { } if (xla_allocator_ == nullptr) { - if (UsePjRtForSingleDeviceCompilation()) { + if (UsePjRtForSingleDeviceCompilation(device_name_)) { VLOG(1) << "XlaDevice " << this << " uses AsyncValueAllocator"; pjrt_allocator_ = std::make_unique(); xla_allocator_ = pjrt_allocator_.get(); @@ -308,16 +308,14 @@ Status XlaDevice::EnsureStreamOkLocked(xla::Backend* backend, } StatusOr> XlaDevice::GetDeviceContextLocked() { - if (UsePjRtForSingleDeviceCompilation()) { - // TODO(b/262472386) Support shape_determination_fns with PJRT. - if (shape_determination_fns_.size() > 1) { - return errors::Unimplemented( - "Use PJRT with multiple ShapeDeterminationFn is not implemented."); - } + if (UsePjRtForSingleDeviceCompilation(device_name_)) { if (device_contexts_.empty()) { - device_contexts_.emplace_back(new PjRtDeviceContext()); - VLOG(1) << "XlaDevice " << this << " new PjRtDeviceContext " - << device_contexts_[0]; + for (const auto& iter : shape_determination_fns_) { + auto device_context = new PjRtDeviceContext(iter); + VLOG(1) << "XlaDevice " << this << " new PjRtDeviceContext " + << device_context; + device_contexts_.emplace_back(device_context); + } if (use_accelerator_device_info_) { auto accelerator_device_info = std::make_unique(); diff --git a/tensorflow/compiler/jit/xla_device.h b/tensorflow/compiler/jit/xla_device.h index e902afb8425..26c7a8d9a1b 100644 --- a/tensorflow/compiler/jit/xla_device.h +++ b/tensorflow/compiler/jit/xla_device.h @@ -27,7 +27,6 @@ limitations under the License. 
#include #include "absl/types/optional.h" -#include "tensorflow/compiler/jit/xla_device_context.h" #include "tensorflow/compiler/jit/xla_tensor.h" #include "tensorflow/compiler/tf2xla/layout_util.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" @@ -226,6 +225,8 @@ class XlaDevice : public LocalDevice { const Metadata xla_metadata_; // Which hardware device in the client's platform this XlaDevice controls. const int device_ordinal_; + // The name/type of this XlaDevice. eg. "XLA_GPU". + const DeviceType device_name_; // The name of the device that is used to compile Ops for this XlaDevice. const DeviceType jit_device_name_; // The platform for this device. diff --git a/tensorflow/compiler/jit/xla_device_compiler_client.cc b/tensorflow/compiler/jit/xla_device_compiler_client.cc index 37ebffd95dc..46689a0d547 100644 --- a/tensorflow/compiler/jit/xla_device_compiler_client.cc +++ b/tensorflow/compiler/jit/xla_device_compiler_client.cc @@ -102,6 +102,8 @@ XlaDeviceCompilerClient::LoadExecutable( } void XlaDeviceCompilerClient::WaitForProgramsToFinish() { + if (client_ == nullptr) return; + for (auto* executor : client_->backend().stream_executors()) { bool ok = executor->SynchronizeAllActivity(); if (!ok) { diff --git a/tensorflow/compiler/jit/xla_device_context.cc b/tensorflow/compiler/jit/xla_device_context.cc index ee00464178f..0309086b41d 100644 --- a/tensorflow/compiler/jit/xla_device_context.cc +++ b/tensorflow/compiler/jit/xla_device_context.cc @@ -20,7 +20,6 @@ limitations under the License. #include #include -#include "tensorflow/compiler/jit/xla_device.h" #include "tensorflow/compiler/jit/xla_launch_util.h" #include "tensorflow/compiler/tf2xla/literal_util.h" #include "tensorflow/compiler/tf2xla/shape_util.h" diff --git a/tensorflow/compiler/jit/xla_device_ops.cc b/tensorflow/compiler/jit/xla_device_ops.cc index 82aaa368d93..9305de9e47d 100644 --- a/tensorflow/compiler/jit/xla_device_ops.cc +++ b/tensorflow/compiler/jit/xla_device_ops.cc @@ -17,7 +17,6 @@ limitations under the License. #include -#include "tensorflow/compiler/jit/xla_device_context.h" #include "tensorflow/compiler/jit/xla_tensor.h" namespace tensorflow { diff --git a/tensorflow/compiler/jit/xla_kernel_creator.cc b/tensorflow/compiler/jit/xla_kernel_creator.cc index 934071d7ca4..fbf24aeda65 100644 --- a/tensorflow/compiler/jit/xla_kernel_creator.cc +++ b/tensorflow/compiler/jit/xla_kernel_creator.cc @@ -14,6 +14,9 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/jit/xla_kernel_creator.h" +#include +#include + #include "absl/memory/memory.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" @@ -25,9 +28,18 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/mlir_bridge_pass.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/common_runtime/function_body.h" +#include "tensorflow/core/common_runtime/function_utils.h" +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/device.h" #include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/node_properties.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/tsl/platform/errors.h" namespace tensorflow { @@ -87,7 +99,7 @@ Status XlaKernelCreator::CreateKernel( return CreateXlaKernel(flr, props->node_def, kernel); } -static bool RegisterLaunchOpCreator() { +bool RegisterLaunchOpCreator() { XlaKernelCreator* xla_kernel_creator = new XlaKernelCreator(); RegisterDefaultCustomKernelCreator(xla_kernel_creator); return true; diff --git a/tensorflow/compiler/jit/xla_kernel_creator.h b/tensorflow/compiler/jit/xla_kernel_creator.h index 856701a791d..843a21acd19 100644 --- a/tensorflow/compiler/jit/xla_kernel_creator.h +++ b/tensorflow/compiler/jit/xla_kernel_creator.h @@ -15,8 +15,11 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_JIT_XLA_KERNEL_CREATOR_H_ #define TENSORFLOW_COMPILER_JIT_XLA_KERNEL_CREATOR_H_ +#include + #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/node_properties.h" #include "tensorflow/core/lib/core/status.h" namespace tensorflow { @@ -39,6 +42,8 @@ class XlaKernelCreator : public CustomKernelCreator { std::unique_ptr* kernel) const override; }; +bool RegisterLaunchOpCreator(); + } // namespace tensorflow #endif // TENSORFLOW_COMPILER_JIT_XLA_KERNEL_CREATOR_H_ diff --git a/tensorflow/compiler/jit/xla_launch_util.cc b/tensorflow/compiler/jit/xla_launch_util.cc index bc348948e4f..0ae56482025 100644 --- a/tensorflow/compiler/jit/xla_launch_util.cc +++ b/tensorflow/compiler/jit/xla_launch_util.cc @@ -23,6 +23,7 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/cleanup/cleanup.h" +#include "absl/container/flat_hash_set.h" #include "tensorflow/compiler/jit/variable_info.h" #include "tensorflow/compiler/jit/variable_info_util.h" #include "tensorflow/compiler/tf2xla/const_analysis.h" @@ -41,8 +42,11 @@ limitations under the License. 
#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/refcount.h" #include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/tfrt/common/async_value_tensor.h" #include "tensorflow/core/util/stream_executor_util.h" +#include "tensorflow/tsl/framework/device_id_utils.h" #include "tensorflow/tsl/platform/status.h" +#include "tensorflow/tsl/platform/statusor.h" namespace tensorflow { namespace { @@ -60,10 +64,17 @@ se::Platform::Id XlaPlatformInfoFromDevice(DeviceBase* device_base) { return platform_id; } +absl::flat_hash_map CreateVariableLookup( + const std::vector& variables) { + absl::flat_hash_map variable_lookup; + for (int i = 0; i < variables.size(); i++) { + variable_lookup[variables[i].index()] = i; + } + return variable_lookup; +} + } // anonymous namespace - - std::vector InputsFromContext(OpKernelContext* ctx) { std::vector inputs; inputs.reserve(ctx->num_inputs()); @@ -576,4 +587,157 @@ XlaComputationLaunchContext::BuildXlaCompilerArguments( return out; } +void PreparePjRtExecutableArguments( + const std::vector& input_mapping, + const std::vector& inputs, + const std::vector& variables, + std::vector* args, + absl::flat_hash_set* non_donatable_input_indices) { + const auto& variable_lookup = CreateVariableLookup(variables); + + for (auto arg_num : input_mapping) { + const Tensor* tensor; + if (auto it = variable_lookup.find(arg_num); it != variable_lookup.end()) { + tensor = variables[it->second].var()->tensor(); + } else { + tensor = inputs[arg_num]; + } + if (!tensor->RefCountIsOne()) { + non_donatable_input_indices->insert(arg_num); + } + + AsyncValueTensor* av_tensor = AsyncValueTensor::FromTensor(tensor); + if (av_tensor->GetBuffer() == nullptr) { + // TODO(b/260799971): verify size 0 argument is supported. + CHECK_EQ(tensor->NumElements(), 0); // Crash OK + continue; + } + args->push_back(av_tensor->GetBuffer().get()); + } +} + +Status PopulateCtxOutputsFromPjRtExecutableOutputs( + const std::vector& inputs, + const std::vector& variables, + const XlaCompiler::CompilationResult& compilation_result, + std::vector>& executable_outputs, + OpKernelContext* ctx) { + const auto& variable_lookup = CreateVariableLookup(variables); + + // Copy XLA results to the OpOutputList. + int output_num = 0; + for (int i = 0, end = ctx->num_outputs(); i < end; ++i) { + const DataType& type = compilation_result.outputs[i].type; + VLOG(2) << "Populating output for retval " << i << " type " + << DataTypeString(type); + + if (compilation_result.outputs[i].is_constant) { + bool requires_copy_to_device = GetDeviceType(ctx) != DEVICE_CPU; + TF_RETURN_IF_ERROR(SetOutputForConstant(ctx, requires_copy_to_device, + &compilation_result, i)); + } else if (type == DT_RESOURCE) { + int input_index = compilation_result.outputs[i].input_index; + TF_RET_CHECK(input_index >= 0 && input_index < ctx->num_inputs()) + << "Invalid input for outputs " << i << ": " << input_index; + ctx->set_output(i, *inputs[input_index]); + } else { + Tensor* output_tensor; + TF_ASSIGN_OR_RETURN( + xla::Shape device_shape, + executable_outputs[output_num]->logical_on_device_shape()); + TensorShape tensor_shape; + TF_RETURN_IF_ERROR(XLAShapeToTensorShape(device_shape, &tensor_shape)); + TF_RETURN_IF_ERROR(ctx->allocate_output(i, tensor_shape, &output_tensor)); + auto output_avt = AsyncValueTensor::FromTensor(output_tensor); + output_avt->SetBuffer(std::move(executable_outputs[output_num])); + ++output_num; + } + } + + // Apply variable updates, if any. 
+ for (int i = 0, end = compilation_result.resource_updates.size(); i < end; + ++i) { + const XlaCompiler::ResourceUpdate& write = + compilation_result.resource_updates[i]; + int actual_input_index = write.input_index; + CHECK_GE(actual_input_index, 0); // Crash OK + CHECK_LT(actual_input_index, ctx->num_inputs()); // Crash OK + auto it = variable_lookup.find(actual_input_index); + if (it == variable_lookup.end()) { + continue; + } + Var* var = variables[it->second].var(); + CHECK(var); // Crash OK + + VLOG(2) << "Updating variable #" << i + << " at input index: " << actual_input_index << " with shape " + << write.shape.DebugString() << "; variable tensor has shape: " + << var->tensor()->shape().DebugString(); + + if (var->is_initialized && var->tensor()->dtype() != write.type) { + return errors::Internal("Mismatched type in variable write"); + } + + TF_RETURN_IF_ERROR(ctx->allocate_temp( + var->tensor()->dtype(), var->tensor()->shape(), var->tensor())); + AsyncValueTensor::FromTensor(var->tensor()) + ->SetBuffer(std::move(executable_outputs[output_num])); + var->is_initialized |= write.modified; + ++output_num; + } + return OkStatus(); +} + +xla::ExecuteOptions GetPjRtExecuteOptions( + absl::flat_hash_set non_donatable_input_indices) { + xla::ExecuteOptions options; + options.arguments_are_tupled = false; + options.untuple_result = true; + // Note: TF does not use PJRT host callbacks as of today. Setting this option + // to true to workaround an ExecuteOptions check: [1]. + // + // [1]: + // tensorflow/compiler/xla/pjrt/pjrt_c_api_client.cc;l=923-927;rcl=519286815 + options.use_major_to_minor_data_layout_for_callbacks = true; + options.non_donatable_input_indices = std::move(non_donatable_input_indices); + return options; +} + +DeviceType GetDeviceType(OpKernelContext* ctx) { + auto* device = + tensorflow::down_cast(ctx->device()->UnderlyingDevice()); + return DeviceType(device->device_type()); +} + +Status RunPjRtExecutable( + const xla::PjRtClient& pjrt_client, + const std::vector& inputs, + const std::vector& variables, + const XlaCompiler::CompilationResult& compilation_result, + xla::PjRtLoadedExecutable* executable, OpKernelContext* ctx) { + TF_ASSIGN_OR_RETURN(const int pjrt_device_id, + tsl::GetDeviceIdFromDeviceParsedName( + ctx->device()->parsed_name(), GetDeviceType(ctx))); + TF_ASSIGN_OR_RETURN(xla::PjRtDevice * device, + pjrt_client.LookupAddressableDevice(pjrt_device_id)); + + std::vector executable_args; + executable_args.reserve(compilation_result.input_mapping.size()); + absl::flat_hash_set non_donatable_input_indices; + PreparePjRtExecutableArguments(compilation_result.input_mapping, inputs, + variables, &executable_args, + &non_donatable_input_indices); + // TODO(b/257548614): currently PJRT is compiled as portable (num_replica = 1 + // and num_partition = 1). Support multiple partitions case. + TF_ASSIGN_OR_RETURN( + std::vector> execute_outputs, + executable->ExecutePortable( + executable_args, device, + GetPjRtExecuteOptions(std::move(non_donatable_input_indices)))); + + TF_RETURN_IF_ERROR(PopulateCtxOutputsFromPjRtExecutableOutputs( + inputs, variables, compilation_result, execute_outputs, ctx)); + return OkStatus(); +} + } // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_launch_util.h b/tensorflow/compiler/jit/xla_launch_util.h index 0e7a806d79c..1a9771068fc 100644 --- a/tensorflow/compiler/jit/xla_launch_util.h +++ b/tensorflow/compiler/jit/xla_launch_util.h @@ -19,6 +19,7 @@ limitations under the License. 
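Editor's note: taken together, the helpers added above form the new PJRT execution path. PreparePjRtExecutableArguments() turns kernel inputs and resource variables into PjRtBuffers, ExecutePortable() runs the executable with the options from GetPjRtExecuteOptions(), and PopulateCtxOutputsFromPjRtExecutableOutputs() writes results and variable updates back into the OpKernelContext. RunPjRtExecutable() bundles those steps; the sketch below shows how a launch-style kernel might call it, assuming the compilation result and loaded executable were produced earlier by a PjRtDeviceCompiler. The wrapper name RunThroughPjRt is illustrative and not part of this diff.

```c++
#include "tensorflow/compiler/jit/variable_info_util.h"
#include "tensorflow/compiler/jit/xla_launch_util.h"
#include "tensorflow/core/tfrt/common/create_pjrt_client_util.h"

namespace tensorflow {

// Illustrative wrapper: run an already-compiled executable through PJRT and
// populate `ctx` with its outputs.
Status RunThroughPjRt(OpKernelContext* ctx,
                      const XlaCompiler::CompilationResult& compilation_result,
                      xla::PjRtLoadedExecutable* executable) {
  std::vector<const Tensor*> inputs = InputsFromContext(ctx);
  std::vector<int> variable_indices =
      GetResourceVariableIndicesFromContext(ctx);

  std::vector<VariableInfo> variables;
  variables.reserve(variable_indices.size());
  TF_RETURN_IF_ERROR(GetVariableInfosFromInputs(ctx->resource_manager(),
                                                ctx->device(), inputs,
                                                variable_indices, &variables));
  // Variables must be locked while their buffers are read and updated.
  TF_RETURN_IF_ERROR(LockVariables(absl::MakeSpan(variables)));

  TF_ASSIGN_OR_RETURN(xla::PjRtClient * pjrt_client,
                      GetOrCreatePjRtClient(GetDeviceType(ctx)));
  // Prepares the PjRtBuffer arguments, calls ExecutePortable() and copies the
  // results (including resource updates) back into `ctx`.
  return RunPjRtExecutable(*pjrt_client, inputs, variables, compilation_result,
                           executable, ctx);
}

}  // namespace tensorflow
```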
#define TENSORFLOW_COMPILER_JIT_XLA_LAUNCH_UTIL_H_ #include +#include #include #include @@ -26,6 +27,7 @@ limitations under the License. #include "tensorflow/compiler/jit/xla_tensor.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/pjrt/pjrt_client.h" #include "tensorflow/compiler/xla/service/shaped_buffer.h" #include "tensorflow/compiler/xla/stream_executor/device_memory_allocator.h" #include "tensorflow/core/framework/allocation_description.pb.h" @@ -51,6 +53,53 @@ Status SetOutputForConstant( OpKernelContext* ctx, bool requires_copy_to_device, const XlaCompiler::CompilationResult* compilation_result, int output_num); +// Converts input tensors and variables which are parameters of the +// XlaComputation into PjRtBuffers to be fed as input to the +// PjRtLoadedExecutable. `input_mapping` is a vector that maps from the +// parameters of the XlaComputation to their original argument positions. This +// can be sourced from `XlaCompiler::CompilationResult::input_mapping`. +// +// The obtained PjRtBuffers are populated to `args` vector. +// `non_donatable_input_indices` will also be set, which contains the indices of +// the input that should not be donated to output. +void PreparePjRtExecutableArguments( + const std::vector& input_mapping, + const std::vector& inputs, + const std::vector& variables, + std::vector* args, + absl::flat_hash_set* non_donatable_input_indices); + +// Populates the OpKernelContext outputs with the outputs of the +// PjRtLoadedExecutable. Requires the `compilation_result` used to build the +// PjRtLoadedExecutable. +Status PopulateCtxOutputsFromPjRtExecutableOutputs( + const std::vector& inputs, + const std::vector& variables, + const XlaCompiler::CompilationResult& compilation_result, + std::vector>& executable_outputs, + OpKernelContext* ctx); + +// Returns the options used for executing a PjRtLoadedExecutable. +xla::ExecuteOptions GetPjRtExecuteOptions( + absl::flat_hash_set non_donatable_input_indices); + +// Returns the device ordinal from the parsed name of the device. +int GetDeviceOrdinal(const DeviceBase* device); + +// Returns the device type from the OpKernelContext. +DeviceType GetDeviceType(OpKernelContext* ctx); + +// Runs `executable` and populates the outputs in `ctx`. `inputs` and +// `variables` are the input arguments to the computation, usually read from the +// OpKernelContext, `ctx`. Requires the device-appropriate `pjrt_client` and the +// `compilation_result` used to build the `executable`. +Status RunPjRtExecutable( + const xla::PjRtClient& pjrt_client, + const std::vector& inputs, + const std::vector& variables, + const XlaCompiler::CompilationResult& compilation_result, + xla::PjRtLoadedExecutable* executable, OpKernelContext* ctx); + // Helper class to perform the marshalling of TensorFlow inputs and outputs to // ShapedBuffers suitable for passing to an XLA computation. class XlaComputationLaunchContext { diff --git a/tensorflow/compiler/jit/xla_launch_util_test.cc b/tensorflow/compiler/jit/xla_launch_util_test.cc new file mode 100644 index 00000000000..789daa318e1 --- /dev/null +++ b/tensorflow/compiler/jit/xla_launch_util_test.cc @@ -0,0 +1,542 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/jit/xla_launch_util.h" + +#include +#include +#include +#include + +#include +#include "absl/container/flat_hash_set.h" +#include "tensorflow/compiler/jit/device_compiler.h" +#include "tensorflow/compiler/jit/flags.h" +#include "tensorflow/compiler/jit/pjrt_device_compiler_client.h" +#include "tensorflow/compiler/jit/variable_info.h" +#include "tensorflow/compiler/jit/variable_info_util.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/pjrt/pjrt_client.h" +#include "tensorflow/compiler/xla/pjrt/tfrt_cpu_pjrt_client.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/device.h" +#include "tensorflow/core/framework/fake_input.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/kernels/ops_testutil.h" +#include "tensorflow/core/platform/refcount.h" +#include "tensorflow/core/tfrt/common/create_pjrt_client_util.h" +#include "tensorflow/core/tfrt/common/pjrt_util.h" +#include "tensorflow/tsl/framework/allocator.h" +#include "tensorflow/tsl/lib/core/status_test_util.h" +#include "tensorflow/tsl/platform/status.h" +#include "tensorflow/tsl/platform/statusor.h" + +namespace tensorflow { +namespace { +using PjRtDeviceCompiler = + DeviceCompiler; +using PjRtDeviceExecutablePersistor = + DeviceExecutablePersistor; + +class PjRtExecutionUtilTest : public OpsTestBase { + public: + PjRtExecutionUtilTest() { + // Set flag to use PJRT for device compilation and execution. + auto& rollout_config = GetXlaOpsCommonFlags()->tf_xla_use_device_api; + rollout_config.enabled_for_xla_launch_ = true; + rollout_config.enabled_for_compile_on_demand_ = true; + + // Set flag to enable using XLA devices. PJRT currently is only supported + // for XLA devices. + GetXlaDeviceFlags()->tf_xla_enable_xla_devices = true; + + // Add and setup the XLA_CPU device. + auto device_type = DeviceType(DEVICE_XLA_CPU); + rollout_config.AllowForDeviceInXlaLaunch(device_type); + rollout_config.AllowForDeviceInXlaCompileOnDemand(device_type); + + auto jit_device_type = DeviceType(DEVICE_CPU_XLA_JIT); + auto device = + DeviceFactory::NewDevice(device_type.type_string(), SessionOptions(), + "/job:localhost/replica:0/task:0"); + device_ = device.get(); + SetDevice(device_type, std::move(device)); + + // Create PjRtClient for XLA_CPU. + TF_CHECK_OK(SetPjRtClientInTFGlobalResourceManager( + device_type, + xla::GetTfrtCpuClient(/*asynchronous=*/true, /*cpu_device_count=*/1) + .value())); + + // device_context_ should be a PjRtDeviceContext. + TF_CHECK_OK(device_->TryGetDeviceContext(&device_context_)); + + // Get the host allocator. + AllocatorAttributes host_alloc_attr; + host_alloc_attr.set_on_host(true); + host_allocator_ = device_->GetAllocator(host_alloc_attr); + + // Get the device allocator. 
This should give us an AsyncValueAllocator. + AllocatorAttributes device_alloc_attr; + device_alloc_attr.set_on_host(false); + device_allocator_ = device_->GetAllocator(device_alloc_attr); + + // Create the DeviceCompiler to help with compiling executables. + auto pjrt_client_or = GetOrCreatePjRtClient(device_type_); + TF_CHECK_OK(pjrt_client_or.status()); + pjrt_client_ = pjrt_client_or.value(); + device_compiler_ = new PjRtDeviceCompiler( + std::make_unique( + PjRtDeviceExecutablePersistor::Config(), jit_device_type), + std::make_unique(pjrt_client_)); + profiler_ = new DeviceCompilationProfiler(); + + compiler_options_.device_type = jit_device_type; + compiler_options_.client = nullptr; + compiler_options_.flib_def = flib_def_.get(); + } + + ~PjRtExecutionUtilTest() override { + for (const auto& tensor : tensors_) { + delete tensor; + } + tensors_.clear(); + device_context_->Unref(); + core::ScopedUnref device_compiler_ref(device_compiler_); + core::ScopedUnref profiler_ref(profiler_); + } + + // Creates a Tensor on host using the host_allocator_ + template + Tensor* CreateHostTensor(const TensorShape& shape, + const gtl::ArraySlice data) { + Tensor* host_tensor = + new Tensor(host_allocator_, DataTypeToEnum::v(), shape); + test::FillValues(host_tensor, data); + tensors_.push_back(host_tensor); + return host_tensor; + } + + // Creates a Tensor on device using the device_allocator_ + template + Tensor* CreateDeviceTensor(const TensorShape& shape, + const gtl::ArraySlice data) { + Tensor* host_tensor = CreateHostTensor(shape, data); + Tensor* device_tensor = + new Tensor(device_allocator_, DataTypeToEnum::v(), shape); + TF_EXPECT_OK(device_context_->CopyCPUTensorToDeviceSync( + host_tensor, device_, device_tensor)); + + tensors_.push_back(device_tensor); + return device_tensor; + } + + // Gets the `output_index`-th output set in the context_ + Tensor* GetOutput(int output_index) { + CHECK_LT(output_index, context_->num_outputs()); + Tensor* device_tensor = context_->mutable_output(output_index); + managed_outputs_.resize(context_->num_outputs()); + if (managed_outputs_[output_index]) { + return managed_outputs_[output_index]; + } + + Tensor* host_tensor = new Tensor(host_allocator_, device_tensor->dtype(), + device_tensor->shape()); + TF_EXPECT_OK(device_context_->CopyDeviceTensorToCPUSync( + device_tensor, "", device_, host_tensor)); + managed_outputs_[output_index] = host_tensor; + return host_tensor; + } + + // Compiles the op set in the context_ to a PjRtLoadedExecutable + void CompileToExecutable(const std::vector& args, + const XlaCompiler::CompilationResult** result, + xla::PjRtLoadedExecutable** executable, + XlaCompiler::CompileOptions compile_options = {}) { + TF_EXPECT_OK(device_compiler_->CompileSingleOpIfNeeded( + compiler_options_, args, compile_options, context_.get(), profiler_, + result, executable)); + } + + // Runs a PjRtLoadedExecutable with the given inputs, variables. Requires the + // XlaCompiler::CompilationResult that was used to build the executable. 
+ StatusOr>> RunExecutable( + const std::vector& inputs, + const std::vector& variables, + const XlaCompiler::CompilationResult* result, + xla::PjRtLoadedExecutable* executable) { + TF_ASSIGN_OR_RETURN(auto pjrt_device, pjrt_client_->LookupAddressableDevice( + device_->parsed_name().id)); + + std::vector executable_args; + executable_args.reserve(result->input_mapping.size()); + absl::flat_hash_set non_donatable_input_indices; + PreparePjRtExecutableArguments(result->input_mapping, inputs, variables, + &executable_args, + &non_donatable_input_indices); + + xla::ExecuteOptions exe_options; + exe_options.arguments_are_tupled = false; + exe_options.untuple_result = true; + + // TODO(b/257548614): currently PJRT is compiled as portable (num_replica = + // 1 and num_partition = 1). Support multiple partitions case. + return executable->ExecutePortable(executable_args, pjrt_device, + exe_options); + } + + // Creates a Variable. Doesn't add it to the resource manager. + template + Var* CreateVariable(const string& name, const TensorShape& shape, + const gtl::ArraySlice data) { + Tensor* init_var_value = CreateDeviceTensor(shape, data); + Var* var = new Var(DataTypeToEnum::v()); + *var->tensor() = *init_var_value; + var->is_initialized = true; + + return var; + } + + // Creates a Variable, adds it to the resource manager and also adds it as one + // of the inputs in the context_ + template + void AddVariableInput(const string& name, const TensorShape& shape, + const gtl::ArraySlice data) { + Var* var = CreateVariable(name, shape, data); + ResourceMgr* rm = device_->resource_manager(); + TF_ASSERT_OK(rm->Create(rm->default_container(), name, var)); + + ResourceHandle handle; + handle.set_device(device_->name()); + handle.set_container(rm->default_container()); + handle.set_name(name); + TypeIndex type_index = TypeIndex::Make(); + handle.set_hash_code(type_index.hash_code()); + handle.set_maybe_type_name(type_index.name()); + + Tensor* input = new Tensor(host_allocator_, DT_RESOURCE, TensorShape({})); + input->scalar()() = handle; + tensors_.push_back(input); + inputs_.push_back({nullptr, input}); + } + + protected: + DeviceContext* device_context_; + Allocator* host_allocator_; + Allocator* device_allocator_; + + XlaCompiler::Options compiler_options_; + xla::PjRtClient* pjrt_client_; + PjRtDeviceCompiler* device_compiler_; + DeviceCompilationProfiler* profiler_; +}; + +TEST_F(PjRtExecutionUtilTest, PreparePjRtExecutableArguments) { + std::vector inputs; + inputs.push_back(CreateDeviceTensor(TensorShape({1, 3}), {0, 0, 0})); + inputs.push_back(CreateDeviceTensor(TensorShape({1, 3}), {1, 2, 3})); + inputs.push_back(CreateDeviceTensor(TensorShape({1, 3}), {4, 5, 6})); + std::vector input_mapping{1, 2}; + + std::vector exec_args; + exec_args.reserve(input_mapping.size()); + absl::flat_hash_set non_donatable_input_indices; + PreparePjRtExecutableArguments(input_mapping, inputs, {}, &exec_args, + &non_donatable_input_indices); + + EXPECT_EQ(exec_args.size(), 2); + + std::shared_ptr literal1 = *exec_args[0]->ToLiteralSync(); + EXPECT_TRUE(xla::LiteralTestUtil::Equal( + *literal1, xla::LiteralUtil::CreateR2({{1, 2, 3}}))); + + std::shared_ptr literal2 = *exec_args[1]->ToLiteralSync(); + EXPECT_TRUE(xla::LiteralTestUtil::Equal( + *literal2, xla::LiteralUtil::CreateR2({{4, 5, 6}}))); +} + +TEST_F(PjRtExecutionUtilTest, PreparePjRtExecutableArgumentsVariableInputs) { + std::vector variables; + Var* var1 = CreateVariable("v1", TensorShape({1, 2}), {1, 2}); + variables.emplace_back(1, "v1", var1); + Var* var2 
= CreateVariable("v2", TensorShape({1, 2}), {3, 4}); + variables.emplace_back(2, "v2", var2); + + std::vector inputs; + inputs.push_back(CreateDeviceTensor(TensorShape({1, 3}), {0, 0, 0})); + std::vector input_mapping{1, 2}; + + std::vector exec_args; + exec_args.reserve(input_mapping.size()); + absl::flat_hash_set non_donatable_input_indices; + PreparePjRtExecutableArguments(input_mapping, inputs, variables, &exec_args, + &non_donatable_input_indices); + + EXPECT_EQ(exec_args.size(), 2); + + std::shared_ptr literal1 = *exec_args[0]->ToLiteralSync(); + EXPECT_TRUE(xla::LiteralTestUtil::Equal( + *literal1, xla::LiteralUtil::CreateR2({{1, 2}}))); + + std::shared_ptr literal2 = *exec_args[1]->ToLiteralSync(); + EXPECT_TRUE(xla::LiteralTestUtil::Equal( + *literal2, xla::LiteralUtil::CreateR2({{3, 4}}))); +} + +TEST_F(PjRtExecutionUtilTest, PopulateCtxOutputs) { + XlaOpRegistry::RegisterCompilationKernels(); + TF_EXPECT_OK(NodeDefBuilder("AddV2", "AddV2") + .Input(FakeInput(DT_INT32)) + .Input(FakeInput(DT_INT32)) + .Attr("T", DT_INT32) + .Device("/job:localhost/replica:0/task:0/device:XLA_CPU:0") + .Finalize(node_def())); + TF_EXPECT_OK(InitOp()); + + // Add inputs. + Tensor* a = CreateDeviceTensor(TensorShape({1, 3}), {1, 2, 3}); + Tensor* b = CreateDeviceTensor(TensorShape({1, 3}), {4, 5, 6}); + inputs_.push_back({nullptr, a}); + inputs_.push_back({nullptr, b}); + + CreateContext(); + + std::vector args(2); + args[0].kind = XlaCompiler::Argument::kParameter; + args[0].type = DT_INT32; + args[0].shape = TensorShape({1, 3}); + args[1].kind = XlaCompiler::Argument::kParameter; + args[1].type = DT_INT32; + args[1].shape = TensorShape({1, 3}); + + const XlaCompiler::CompilationResult* result; + xla::PjRtLoadedExecutable* executable; + CompileToExecutable(args, &result, &executable); + + std::vector inputs; + inputs.push_back(a); + inputs.push_back(b); + TF_ASSERT_OK_AND_ASSIGN(auto execute_outputs, + RunExecutable(inputs, {}, result, executable)); + + TF_EXPECT_OK(PopulateCtxOutputsFromPjRtExecutableOutputs( + inputs, {}, *result, execute_outputs, context_.get())); + + Tensor* expected = CreateHostTensor(TensorShape({1, 3}), {5, 7, 9}); + test::ExpectTensorEqual(*expected, *GetOutput(0)); +} + +TEST_F(PjRtExecutionUtilTest, PopulateCtxOutputsDynamicShape) { + XlaOpRegistry::RegisterCompilationKernels(); + TF_EXPECT_OK(NodeDefBuilder("testWhere", "Where") + .Input(FakeInput(DT_FLOAT)) + .Attr("T", DT_FLOAT) + .Device("/job:localhost/replica:0/task:0/device:XLA_CPU:0") + .Finalize(node_def())); + TF_EXPECT_OK(InitOp()); + + // Add inputs. + Tensor* a = + CreateDeviceTensor(TensorShape({2, 3}), {0., 1., 1., 0., 0., 0.}); + inputs_.push_back({nullptr, a}); + + CreateContext(); + + std::vector args(1); + args[0].kind = XlaCompiler::Argument::kParameter; + args[0].type = DT_FLOAT; + args[0].shape = TensorShape({2, 3}); + + const XlaCompiler::CompilationResult* result; + xla::PjRtLoadedExecutable* executable; + CompileToExecutable(args, &result, &executable); + + std::vector inputs; + inputs.push_back(a); + TF_ASSERT_OK_AND_ASSIGN(auto execute_outputs, + RunExecutable(inputs, {}, result, executable)); + + TF_EXPECT_OK(PopulateCtxOutputsFromPjRtExecutableOutputs( + inputs, {}, *result, execute_outputs, context_.get())); + // The expected output is indices of non-zero inputs. 
+ Tensor* expected = CreateHostTensor(TensorShape({2, 2}), {0, 1, 0, 2}); + test::ExpectTensorEqual(*expected, *GetOutput(0)); +} + +TEST_F(PjRtExecutionUtilTest, PopulateCtxOutputsVariableInputs) { + XlaOpRegistry::RegisterCompilationKernels(); + TF_EXPECT_OK(NodeDefBuilder("AddV2", "AddV2") + .Input(FakeInput(DT_INT32)) + .Input(FakeInput(DT_INT32)) + .Attr("T", DT_INT32) + .Device("/job:localhost/replica:0/task:0/device:XLA_CPU:0") + .Finalize(node_def())); + TF_EXPECT_OK(InitOp()); + + AddVariableInput("var1", TensorShape({1, 2}), {1, 2}); + AddVariableInput("var2", TensorShape({1, 2}), {3, 4}); + + CreateContext(); + + std::vector args(2); + args[0].kind = XlaCompiler::Argument::kParameter; + args[0].initialized = true; + args[0].type = DT_INT32; + args[0].shape = TensorShape({1, 2}); + args[1].kind = XlaCompiler::Argument::kParameter; + args[1].initialized = true; + args[1].type = DT_INT32; + args[1].shape = TensorShape({1, 2}); + + const XlaCompiler::CompilationResult* result; + xla::PjRtLoadedExecutable* executable; + CompileToExecutable(args, &result, &executable); + + std::vector inputs = InputsFromContext(context_.get()); + std::vector variables_indices = + GetResourceVariableIndicesFromContext(context_.get()); + std::vector variables; + variables.reserve(variables_indices.size()); + TF_ASSERT_OK(GetVariableInfosFromInputs(context_->resource_manager(), + context_->device(), inputs, + variables_indices, &variables)); + TF_ASSERT_OK_AND_ASSIGN(auto execute_outputs, + RunExecutable(inputs, variables, result, executable)); + TF_EXPECT_OK(PopulateCtxOutputsFromPjRtExecutableOutputs( + inputs, variables, *result, execute_outputs, context_.get())); + + Tensor* expected = CreateHostTensor(TensorShape({1, 2}), {4, 6}); + test::ExpectTensorEqual(*expected, *GetOutput(0)); +} + +TEST_F(PjRtExecutionUtilTest, PopulateCtxOutputsResourceUpdates) { + XlaOpRegistry::RegisterCompilationKernels(); + TF_EXPECT_OK(NodeDefBuilder("AssignAddVariableOp", "AssignAddVariableOp") + .Input(FakeInput(DT_RESOURCE)) + .Input(FakeInput(DT_INT32)) + .Attr("dtype", DT_INT32) + .Device("/job:localhost/replica:0/task:0/device:XLA_CPU:0") + .Finalize(node_def())); + TF_EXPECT_OK(InitOp()); + + AddVariableInput("var", TensorShape({1, 3}), {1, 2, 3}); + Tensor* a = CreateDeviceTensor(TensorShape({1, 3}), {2, 2, 2}); + inputs_.push_back({nullptr, a}); + + CreateContext(); + + std::vector inputs = InputsFromContext(context_.get()); + std::vector variables_indices = + GetResourceVariableIndicesFromContext(context_.get()); + std::vector variables; + variables.reserve(variables_indices.size()); + TF_ASSERT_OK(GetVariableInfosFromInputs(context_->resource_manager(), + context_->device(), inputs, + variables_indices, &variables)); + TF_ASSERT_OK_AND_ASSIGN(std::vector constant_input_indices, + GetConstantInputIndicesFromContext(context_.get())); + TF_ASSERT_OK(LockVariables(absl::MakeSpan(variables))); + TF_ASSERT_OK_AND_ASSIGN( + std::vector args, + XlaComputationLaunchContext::BuildXlaCompilerArguments( + constant_input_indices, inputs, variables, + static_cast(context_->device()))); + + const XlaCompiler::CompilationResult* result; + xla::PjRtLoadedExecutable* executable; + CompileToExecutable(args, &result, &executable); + TF_ASSERT_OK_AND_ASSIGN(auto execute_outputs, + RunExecutable(inputs, variables, result, executable)); + + TF_EXPECT_OK(PopulateCtxOutputsFromPjRtExecutableOutputs( + inputs, variables, *result, execute_outputs, context_.get())); + + // Verify that there are no outputs. 
+ EXPECT_EQ(context_->num_outputs(), 0); + + // Verify that the original variable was updated. + ResourceMgr* rm = device_->resource_manager(); + Var* var = nullptr; + TF_ASSERT_OK(rm->Lookup(rm->default_container(), "var", &var)); + core::ScopedUnref var_ref(var); + + Tensor* device_tensor = var->tensor(); + Tensor* host_tensor = new Tensor(host_allocator_, device_tensor->dtype(), + device_tensor->shape()); + tensors_.push_back(host_tensor); + TF_ASSERT_OK(device_context_->CopyDeviceTensorToCPUSync( + device_tensor, "", device_, host_tensor)); + + Tensor* expected = CreateHostTensor(TensorShape({1, 3}), {3, 4, 5}); + test::ExpectTensorEqual(*expected, *host_tensor); +} + +TEST(XlaLaunchUtilTest, GetPjRtExecuteOptions) { + xla::ExecuteOptions options = GetPjRtExecuteOptions({}); + EXPECT_FALSE(options.arguments_are_tupled); + EXPECT_TRUE(options.untuple_result); + EXPECT_TRUE(options.use_major_to_minor_data_layout_for_callbacks); +} + +TEST_F(PjRtExecutionUtilTest, RunPjRtExecutable) { + XlaOpRegistry::RegisterCompilationKernels(); + TF_EXPECT_OK(NodeDefBuilder("AddV2", "AddV2") + .Input(FakeInput(DT_INT32)) + .Input(FakeInput(DT_INT32)) + .Attr("T", DT_INT32) + .Device("/job:localhost/replica:0/task:0/device:XLA_CPU:0") + .Finalize(node_def())); + TF_EXPECT_OK(InitOp()); + + AddVariableInput("var1", TensorShape({1, 2}), {1, 2}); + AddVariableInput("var2", TensorShape({1, 2}), {3, 4}); + + CreateContext(); + + std::vector args(2); + args[0].kind = XlaCompiler::Argument::kParameter; + args[0].initialized = true; + args[0].type = DT_INT32; + args[0].shape = TensorShape({1, 2}); + args[1].kind = XlaCompiler::Argument::kParameter; + args[1].initialized = true; + args[1].type = DT_INT32; + args[1].shape = TensorShape({1, 2}); + + const XlaCompiler::CompilationResult* result; + xla::PjRtLoadedExecutable* executable; + CompileToExecutable(args, &result, &executable); + + std::vector inputs = InputsFromContext(context_.get()); + std::vector variables_indices = + GetResourceVariableIndicesFromContext(context_.get()); + std::vector variables; + variables.reserve(variables_indices.size()); + TF_ASSERT_OK(GetVariableInfosFromInputs(context_->resource_manager(), + context_->device(), inputs, + variables_indices, &variables)); + TF_ASSERT_OK(RunPjRtExecutable(*pjrt_client_, inputs, variables, *result, + executable, context_.get())); + + Tensor* expected = CreateHostTensor(TensorShape({1, 2}), {4, 6}); + test::ExpectTensorEqual(*expected, *GetOutput(0)); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_platform_info.cc b/tensorflow/compiler/jit/xla_platform_info.cc index df4ab4460f5..b311faa13ab 100644 --- a/tensorflow/compiler/jit/xla_platform_info.cc +++ b/tensorflow/compiler/jit/xla_platform_info.cc @@ -17,20 +17,62 @@ limitations under the License. 
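Editor's note: in the xla_platform_info.cc changes that follow, the visible-GPU handling is factored into a GetAllowedGpus() helper that reads gpu_options().visible_device_list() from the FunctionLibraryRuntime's ConfigProto and forwards the parsed list to the LocalClient options and, in the PJRT path, to GetOrCreatePjRtClient(). A small sketch of the configuration side, assuming the caller sets up SessionOptions in C++; the proto fields are the real ConfigProto ones, while the helper function itself is illustrative.

```c++
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/public/session_options.h"

// Illustrative: restrict the GPUs visible to TensorFlow (and therefore the
// device list GetAllowedGpus() reports) to GPUs 0 and 2.
tensorflow::SessionOptions MakeSessionOptionsWithVisibleGpus() {
  tensorflow::SessionOptions options;
  options.config.mutable_gpu_options()->set_visible_device_list("0,2");
  return options;
}
```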
#include #include +#include #include #include #include "tensorflow/compiler/jit/device_executable_persistor.h" #include "tensorflow/compiler/jit/flags.h" +#include "tensorflow/compiler/jit/pjrt_device_compiler_client.h" #include "tensorflow/compiler/jit/xla_device_compiler_client.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/pjrt/pjrt_client.h" +#include "tensorflow/core/tfrt/common/create_pjrt_client_util.h" +#include "tensorflow/core/tfrt/common/pjrt_util.h" #include "tensorflow/core/tpu/tpu_defs.h" namespace tensorflow { namespace { using XlaDeviceCompiler = DeviceCompiler; +using PjRtDeviceCompiler = + DeviceCompiler; +using XlaDeviceExecutablePersistor = + DeviceExecutablePersistor; +using PjRtDeviceExecutablePersistor = + DeviceExecutablePersistor; + +XlaDeviceCompiler* CreateXlaDeviceCompiler( + const XlaDeviceExecutablePersistor::Config& persistor_config, + DeviceType device_type, xla::LocalClient* local_client) { + return new XlaDeviceCompiler( + std::make_unique( + std::move(persistor_config), device_type), + std::make_unique(local_client)); +} + +PjRtDeviceCompiler* CreatePjRtDeviceCompiler( + const PjRtDeviceExecutablePersistor::Config& persistor_config, + DeviceType device_type, xla::PjRtClient* pjrt_client) { + return new PjRtDeviceCompiler( + std::make_unique( + std::move(persistor_config), device_type), + std::make_unique(pjrt_client)); +} + +StatusOr>> GetAllowedGpus( + FunctionLibraryRuntime* flr) { + std::optional> gpu_ids = std::nullopt; + + if (flr->config_proto()) { + string allowed_gpus = + flr->config_proto()->gpu_options().visible_device_list(); + TF_ASSIGN_OR_RETURN(gpu_ids, ParseVisibleDeviceList(allowed_gpus)); + } + + return gpu_ids; +} } // namespace xla::StatusOr>> ParseVisibleDeviceList( @@ -57,32 +99,27 @@ xla::StatusOr>> ParseVisibleDeviceList( Status BuildXlaDeviceCompiler(DeviceBase* device, FunctionLibraryRuntime* flr, const XlaPlatformInfo& platform_info, XlaDeviceCompiler** xla_device_compiler) { - using XlaDeviceExecutablePersistor = - DeviceExecutablePersistor; XlaDeviceExecutablePersistor::Config persistor_config( GetMarkForCompilationPassFlags()->tf_xla_persistent_cache_directory, GetMarkForCompilationPassFlags()->tf_xla_disable_strict_signature_checks, GetMarkForCompilationPassFlags()->tf_xla_persistent_cache_prefix); if (platform_info.xla_device_metadata()) { - auto persistor = std::make_unique( - std::move(persistor_config), - platform_info.xla_device_metadata()->jit_device_type()); - auto compiler_client = std::make_unique( + *xla_device_compiler = CreateXlaDeviceCompiler( + persistor_config, + platform_info.xla_device_metadata()->jit_device_type(), platform_info.xla_device_metadata()->client()); - *xla_device_compiler = - new XlaDeviceCompiler(std::move(persistor), std::move(compiler_client)); return OkStatus(); } // TFRT-TPU is used if device type is `DEVICE_TPU` and platform_info does not - // have `xla_device_metadata`. + // have `xla_device_metadata`. This is used for TFRT-TPU when + // BuildXlaDeviceCompiler() is called in GetCompilerIr(). Currently only + // lowering to HLO is needed there and xla::LocalClient doesn't support + // building the executable for TFRT-TPU and hence, is set to nullptr here. 
if (platform_info.device_type() == DEVICE_TPU) { - auto persistor = std::make_unique( - std::move(persistor_config), DeviceType(DEVICE_TPU_XLA_JIT)); - auto compiler_client = std::make_unique(nullptr); - *xla_device_compiler = - new XlaDeviceCompiler(std::move(persistor), std::move(compiler_client)); + *xla_device_compiler = CreateXlaDeviceCompiler( + persistor_config, DeviceType(DEVICE_TPU_XLA_JIT), nullptr); return OkStatus(); } @@ -118,13 +155,8 @@ Status BuildXlaDeviceCompiler(DeviceBase* device, FunctionLibraryRuntime* flr, client_options.set_intra_op_parallelism_threads( device->tensorflow_cpu_worker_threads()->num_threads); - if (flr->config_proto()) { - string allowed_gpus = - flr->config_proto()->gpu_options().visible_device_list(); - TF_ASSIGN_OR_RETURN(std::optional> gpu_ids, - ParseVisibleDeviceList(allowed_gpus)); - client_options.set_allowed_devices(gpu_ids); - } + TF_ASSIGN_OR_RETURN(auto allowed_gpus, GetAllowedGpus(flr)); + client_options.set_allowed_devices(allowed_gpus); auto client = xla::ClientLibrary::GetOrCreateLocalClient(client_options); if (!client.ok()) { @@ -137,13 +169,77 @@ Status BuildXlaDeviceCompiler(DeviceBase* device, FunctionLibraryRuntime* flr, platform_info.device_type().type()); } - auto persistor = std::make_unique( - std::move(persistor_config), - DeviceType(registration->compilation_device_name)); - auto compiler_client = - std::make_unique(client.value()); - *xla_device_compiler = - new XlaDeviceCompiler(std::move(persistor), std::move(compiler_client)); + *xla_device_compiler = CreateXlaDeviceCompiler( + persistor_config, DeviceType(registration->compilation_device_name), + client.value()); + return OkStatus(); +} + +Status BuildPjRtDeviceCompiler(const XlaPlatformInfo& platform_info, + FunctionLibraryRuntime* flr, + PjRtDeviceCompiler** pjrt_device_compiler) { + PjRtDeviceExecutablePersistor::Config persistor_config( + GetMarkForCompilationPassFlags()->tf_xla_persistent_cache_directory, + GetMarkForCompilationPassFlags()->tf_xla_disable_strict_signature_checks, + GetMarkForCompilationPassFlags()->tf_xla_persistent_cache_prefix); + + DeviceType device_type = platform_info.device_type(); + + if (platform_info.xla_device_metadata()) { + VLOG(2) << "Building PjRtDeviceCompiler using " + "platform_info.xla_device_metadata()."; + + DeviceType compilation_device_type = + platform_info.xla_device_metadata()->jit_device_type(); + TF_ASSIGN_OR_RETURN(auto pjrt_client, GetOrCreatePjRtClient(device_type)); + + *pjrt_device_compiler = CreatePjRtDeviceCompiler( + persistor_config, compilation_device_type, pjrt_client); + return OkStatus(); + } + if (platform_info.pjrt_device_metadata()) { + VLOG(2) << "Building PjRtDeviceCompiler using " + "platform_info.pjrt_device_metadata()."; + + DeviceType compilation_device_type = + platform_info.pjrt_device_metadata()->jit_device_type(); + TF_ASSIGN_OR_RETURN(auto pjrt_client, GetOrCreatePjRtClient(device_type)); + + *pjrt_device_compiler = CreatePjRtDeviceCompiler( + persistor_config, compilation_device_type, pjrt_client); + return OkStatus(); + } + + // TFRT-TPU is used if device_type is `DEVICE_TPU` and platform_info does not + // have `xla_device_metadata`. 
+ if (device_type == DEVICE_TPU) { + TF_ASSIGN_OR_RETURN(auto pjrt_client, GetOrCreatePjRtClient(device_type)); + *pjrt_device_compiler = CreatePjRtDeviceCompiler( + persistor_config, DeviceType(DEVICE_TPU_XLA_JIT), pjrt_client); + return OkStatus(); + } + + VLOG(2) << "platform_info.xla_device_metadata not found and " + "platform_info.device_type() != DEVICE_TPU. Building " + "PjRtDeviceCompiler for non-XLA device."; + + const XlaOpRegistry::DeviceRegistration* registration; + if (!XlaOpRegistry::GetCompilationDevice(device_type.type(), ®istration)) { + return errors::InvalidArgument("No JIT device registered for ", + device_type.type()); + } + auto compilation_device_type = + DeviceType(registration->compilation_device_name); + + TF_ASSIGN_OR_RETURN(auto allowed_gpus, GetAllowedGpus(flr)); + // TODO(b/255826209): Set platform, intra op parallelism threads if required + // and when supported by GetOrCreatePjRtClient(). + // The `allowed_gpus` argument is used only if the `device_type` is GPU. + TF_ASSIGN_OR_RETURN(auto pjrt_client, + GetOrCreatePjRtClient(device_type, allowed_gpus)); + + *pjrt_device_compiler = CreatePjRtDeviceCompiler( + persistor_config, compilation_device_type, pjrt_client); return OkStatus(); } @@ -151,6 +247,7 @@ XlaPlatformInfo XlaPlatformInfoFromDevice(DeviceBase* device_base) { auto device = static_cast(device_base); se::Platform::Id platform_id = nullptr; const XlaDevice::Metadata* xla_device_metadata = nullptr; + const PjRtBaseDevice::Metadata* pjrt_device_metadata = nullptr; std::shared_ptr custom_allocator; if (device->device_type() == DEVICE_CPU) { @@ -174,10 +271,14 @@ XlaPlatformInfo XlaPlatformInfoFromDevice(DeviceBase* device_base) { platform_id = xla_device_metadata->platform()->id(); custom_allocator = xla_device_metadata->client()->backend().shared_memory_allocator(); + } else if (auto metadata = PjRtBaseDevice::GetMetadataFromDevice(device); + metadata.ok()) { + pjrt_device_metadata = *metadata; } return XlaPlatformInfo(DeviceType(device->device_type()), platform_id, - xla_device_metadata, custom_allocator); + xla_device_metadata, pjrt_device_metadata, + custom_allocator); } std::shared_ptr GetAllocator( diff --git a/tensorflow/compiler/jit/xla_platform_info.h b/tensorflow/compiler/jit/xla_platform_info.h index 97ed0b4a9db..725a876904d 100644 --- a/tensorflow/compiler/jit/xla_platform_info.h +++ b/tensorflow/compiler/jit/xla_platform_info.h @@ -20,6 +20,7 @@ limitations under the License. #include #include "tensorflow/compiler/jit/device_compiler.h" +#include "tensorflow/compiler/jit/pjrt_base_device.h" #include "tensorflow/compiler/jit/xla_device.h" #include "tensorflow/compiler/xla/stream_executor/tf_allocator_adapter.h" @@ -27,7 +28,8 @@ namespace tensorflow { // Holds some information about the platform on which an // XlaLaunch/_XlaCompile/_XlaRun op must run on. Provides a common layer of -// abstraction for normal and XLA devices. +// abstraction for normal, XLA devices and devices inheriting from +// PjRtBaseDevice. 
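Editor's note: the intended call pattern for the new BuildPjRtDeviceCompiler() mirrors the existing BuildXlaDeviceCompiler() one and is exercised by the unit tests added later in this change. A hedged sketch of the kernel-side wiring, under the assumption that the op runs on a device for which a PjRtClient can be obtained; the wrapper function below is illustrative and not part of this diff.

```c++
#include "tensorflow/compiler/jit/device_compiler.h"
#include "tensorflow/compiler/jit/xla_platform_info.h"
#include "tensorflow/core/platform/refcount.h"

namespace tensorflow {

// Illustrative: build a PjRtDeviceCompiler for the device the calling kernel
// is placed on.
Status GetPjRtDeviceCompilerForOp(OpKernelContext* ctx,
                                  FunctionLibraryRuntime* flr) {
  XlaPlatformInfo platform_info = XlaPlatformInfoFromDevice(ctx->device());

  DeviceCompiler<xla::PjRtLoadedExecutable, xla::PjRtClient>*
      pjrt_device_compiler = nullptr;
  TF_RETURN_IF_ERROR(
      BuildPjRtDeviceCompiler(platform_info, flr, &pjrt_device_compiler));
  // The device compiler is ref-counted; release the reference when done.
  core::ScopedUnref compiler_ref(pjrt_device_compiler);

  // ... compile with pjrt_device_compiler (e.g. CompileSingleOpIfNeeded) and
  // run the resulting executable via RunPjRtExecutable() ...
  return OkStatus();
}

}  // namespace tensorflow
```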
class XlaPlatformInfo { public: XlaPlatformInfo() : device_type_("") {} @@ -35,10 +37,12 @@ class XlaPlatformInfo { explicit XlaPlatformInfo( const DeviceType device_type, se::Platform::Id platform_id, const XlaDevice::Metadata* xla_device_metadata, + const PjRtBaseDevice::Metadata* pjrt_device_metadata, std::shared_ptr device_allocator) : device_type_(device_type), platform_id_(platform_id), xla_device_metadata_(xla_device_metadata), + pjrt_device_metadata_(pjrt_device_metadata), device_allocator_(device_allocator) {} XlaPlatformInfo& operator=(XlaPlatformInfo&& other) = default; @@ -65,6 +69,10 @@ class XlaPlatformInfo { } bool is_on_xla_device() const { return xla_device_metadata() != nullptr; } + const PjRtBaseDevice::Metadata* pjrt_device_metadata() const { + return pjrt_device_metadata_; + } + private: DeviceType device_type_; se::Platform::Id platform_id_; @@ -74,6 +82,11 @@ class XlaPlatformInfo { // XlaLaunch/_XlaCompile/_XlaRun OpKernel. const XlaDevice::Metadata* xla_device_metadata_; + // pjrt_device_metadata_ lives in tensorflow::PjRtBaseDevice in which the + // XlaLaunch/XlaCompileOnDemand op is placed and thus does not die before the + // op kernel. + const PjRtBaseDevice::Metadata* pjrt_device_metadata_; + // If the op associated with this XlaPlatformInfo is placed on an XLA device // then device_allocator_ is the xla::Backend's memory allocator. If the op // is placed on a regular CPU or GPU device then device_allocator_ is null. @@ -90,13 +103,28 @@ class XlaPlatformInfo { StatusOr>> ParseVisibleDeviceList( absl::string_view visible_device_list); -// Returns created XLA compilation cache. +// Builds a DeviceCompiler that uses xla::LocalClient using `platform_info` and +// sets *xla_device_compiler to point to it. Uses flags from +// `MarkForCompilationPassFlags` for configuring the persistor used in the +// DeviceCompiler. Status BuildXlaDeviceCompiler( DeviceBase* dev, FunctionLibraryRuntime* flr, const XlaPlatformInfo& platform_info, DeviceCompiler** xla_device_compiler); +// Builds a DeviceCompiler that uses xla::PjRtClient using an appropriate +// PjRtClient for `platform_info.device_type()` and sets *pjrt_device_compiler +// to point to it. Uses flags from `MarkForCompilationPassFlags` for configuring +// the persistor used in the DeviceCompiler. Please note that non-XLA devices +// aren't supported yet. This is because: +// 1. PjRtClient doesn't support data transfer for non-XLA devices yet +// 2. Fetching the PjRtClient for non-XLA devices is also not supported yet +Status BuildPjRtDeviceCompiler( + const XlaPlatformInfo& platform_info, FunctionLibraryRuntime* flr, + DeviceCompiler** + pjrt_device_compiler); + // Returns information about the platform from kernel context. XlaPlatformInfo XlaPlatformInfoFromDevice(DeviceBase* device); diff --git a/tensorflow/compiler/jit/xla_platform_info_test.cc b/tensorflow/compiler/jit/xla_platform_info_test.cc new file mode 100644 index 00000000000..0dedbb39bb9 --- /dev/null +++ b/tensorflow/compiler/jit/xla_platform_info_test.cc @@ -0,0 +1,170 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/jit/xla_platform_info.h" + +#include +#include + +#include +#include "tensorflow/compiler/jit/flags.h" +#include "tensorflow/compiler/jit/test_util.h" +#include "tensorflow/compiler/xla/pjrt/tfrt_cpu_pjrt_client.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/refcount.h" +#include "tensorflow/core/platform/status_matchers.h" +#include "tensorflow/core/platform/statusor.h" +#include "tensorflow/core/protobuf/error_codes.pb.h" +#include "tensorflow/core/tfrt/common/create_pjrt_client_util.h" +#include "tensorflow/core/tfrt/common/pjrt_util.h" +#include "tensorflow/core/tpu/tpu_defs.h" + +namespace tensorflow { +namespace { +using XlaDeviceCompiler = + DeviceCompiler; +using PjRtDeviceCompiler = + DeviceCompiler; + +class XlaPlatformInfoTest : public ::testing::Test { + protected: + void SetUp() override { + tensorflow::GetXlaDeviceFlags()->tf_xla_enable_xla_devices = true; + } + + DeviceSetup device_setup_; +}; + +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM +TEST_F(XlaPlatformInfoTest, BuildXlaDeviceCompilerXlaDeviceMetadata) { + device_setup_.AddDevicesAndSetUp({DEVICE_XLA_GPU}); + + Device* device = device_setup_.GetDevice(DEVICE_XLA_GPU); + const XlaDevice::Metadata* metadata = nullptr; + TF_CHECK_OK(XlaDevice::GetMetadataFromDevice(device, &metadata)); + XlaPlatformInfo platform_info = XlaPlatformInfoFromDevice(device); + + XlaDeviceCompiler* xla_device_compiler = nullptr; + TF_EXPECT_OK(BuildXlaDeviceCompiler(device, device_setup_.flr(), + platform_info, &xla_device_compiler)); + core::ScopedUnref xla_device_compiler_ref(xla_device_compiler); + + EXPECT_EQ(xla_device_compiler->device_type(), metadata->jit_device_type()); + EXPECT_EQ(xla_device_compiler->client(), metadata->client()); +} + +TEST_F(XlaPlatformInfoTest, BuildXlaDeviceCompilerNonXlaDevice) { + device_setup_.AddDevicesAndSetUp({DEVICE_GPU}); + Device* device = device_setup_.GetDevice(DEVICE_GPU); + + XlaPlatformInfo platform_info = XlaPlatformInfoFromDevice(device); + XlaDeviceCompiler* xla_device_compiler = nullptr; + TF_EXPECT_OK(BuildXlaDeviceCompiler(device, device_setup_.flr(), + platform_info, &xla_device_compiler)); + core::ScopedUnref xla_device_compiler_ref(xla_device_compiler); + + EXPECT_EQ(xla_device_compiler->device_type(), DeviceType(DEVICE_GPU_XLA_JIT)); + EXPECT_TRUE(xla_device_compiler->client() != nullptr); +} + +TEST_F(XlaPlatformInfoTest, BuildPjRtDeviceCompilerTestXlaDevice) { + DeviceType device_type = DeviceType(DEVICE_XLA_GPU); + device_setup_.AddDevicesAndSetUp({device_type.type()}); + + Device* device = device_setup_.GetDevice(device_type.type()); + const XlaDevice::Metadata* metadata = nullptr; + TF_CHECK_OK(XlaDevice::GetMetadataFromDevice(device, &metadata)); + XlaPlatformInfo platform_info = XlaPlatformInfoFromDevice(device); + + PjRtDeviceCompiler* pjrt_device_compiler = nullptr; + TF_EXPECT_OK(BuildPjRtDeviceCompiler(platform_info, device_setup_.flr(), + 
&pjrt_device_compiler)); + core::ScopedUnref pjrt_device_compiler_ref(pjrt_device_compiler); + + TF_ASSERT_OK_AND_ASSIGN(auto pjrt_client, GetOrCreatePjRtClient(device_type)); + EXPECT_EQ(pjrt_device_compiler->device_type(), metadata->jit_device_type()); + EXPECT_EQ(pjrt_device_compiler->client(), pjrt_client); +} + +TEST_F(XlaPlatformInfoTest, BuildPjRtDeviceCompilerTestGpuDevice) { + device_setup_.AddDevicesAndSetUp({DEVICE_GPU}); + Device* device = device_setup_.GetDevice(DEVICE_GPU); + XlaPlatformInfo platform_info = XlaPlatformInfoFromDevice(device); + PjRtDeviceCompiler* pjrt_device_compiler = nullptr; + TF_EXPECT_OK(BuildPjRtDeviceCompiler(platform_info, device_setup_.flr(), + &pjrt_device_compiler)); + core::ScopedUnref pjrt_device_compiler_ref(pjrt_device_compiler); +} +#endif + +TEST_F(XlaPlatformInfoTest, BuildXlaDeviceCompilerTpuDevice) { + DeviceType compilation_device_type = DeviceType(DEVICE_TPU_XLA_JIT); + + // Instead of creating/initializing a TPU device, create a dummy platform_info + // and use a nullptr for Device for testing purposes. Only + // XlaPlatformInfo::device_type() is needed to build the appropriate + // XlaDeviceCompiler. + Device* device = nullptr; + XlaPlatformInfo platform_info(DeviceType(DEVICE_TPU), /*platform_id=*/nullptr, + /*xla_device_metadata=*/nullptr, + /*pjrt_device_metadata=*/nullptr, + /*device_allocator=*/nullptr); + + XlaDeviceCompiler* xla_device_compiler = nullptr; + TF_EXPECT_OK(BuildXlaDeviceCompiler(device, nullptr, platform_info, + &xla_device_compiler)); + core::ScopedUnref xla_device_compiler_ref(xla_device_compiler); + + EXPECT_EQ(xla_device_compiler->device_type(), compilation_device_type); + // TFRT-TPU is used if device type is `DEVICE_TPU` and `platform_info` does + // not have `xla_device_metadata`. XlaDeviceCompiler/xla::LocalClient is not + // used in this case. + EXPECT_EQ(xla_device_compiler->client(), nullptr); +} + +// TODO(b/255826209): Look into using an actual TPU device for the unit test, +// and move this out of OSS. +TEST_F(XlaPlatformInfoTest, BuildPjRtDeviceCompilerTpuDevice) { + DeviceType device_type = DeviceType(DEVICE_TPU); + DeviceType compilation_device_type = DeviceType(DEVICE_TPU_XLA_JIT); + // Use a CPU PjRtClient instead of a TPU one just for testing whether + // GetOrCreatePjRtClient() is being called with the correct arguments. + TF_CHECK_OK(SetPjRtClientInTFGlobalResourceManager( + device_type, + xla::GetTfrtCpuClient(/*asynchronous=*/true, /*cpu_device_count=*/1) + .value())); + TF_ASSERT_OK_AND_ASSIGN(auto pjrt_client, GetOrCreatePjRtClient(device_type)); + + // Instead of creating/initializing a TPU device, create a dummy platform_info + // for testing purposes. Only XlaPlatformInfo::device_type() is needed to + // build the appropriate PjRtDeviceCompiler. 
+ XlaPlatformInfo platform_info(device_type, /*platform_id=*/nullptr, + /*xla_device_metadata=*/nullptr, + /*pjrt_device_metadata=*/nullptr, + /*device_allocator=*/nullptr); + + PjRtDeviceCompiler* pjrt_device_compiler = nullptr; + TF_EXPECT_OK( + BuildPjRtDeviceCompiler(platform_info, nullptr, &pjrt_device_compiler)); + core::ScopedUnref pjrt_device_compiler_ref(pjrt_device_compiler); + + EXPECT_EQ(pjrt_device_compiler->device_type(), compilation_device_type); + EXPECT_EQ(pjrt_device_compiler->client(), pjrt_client); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_tpu_device.cc b/tensorflow/compiler/jit/xla_tpu_device.cc index e6047e68bde..1f4db51e417 100644 --- a/tensorflow/compiler/jit/xla_tpu_device.cc +++ b/tensorflow/compiler/jit/xla_tpu_device.cc @@ -18,6 +18,7 @@ limitations under the License. #include "absl/types/optional.h" #include "tensorflow/compiler/jit/kernels/xla_ops.h" #include "tensorflow/compiler/jit/xla_device.h" +#include "tensorflow/compiler/jit/xla_device_context.h" #include "tensorflow/compiler/jit/xla_device_ops.h" #include "tensorflow/compiler/tf2xla/layout_util.h" #include "tensorflow/compiler/tf2xla/shape_util.h" diff --git a/tensorflow/compiler/mlir/BUILD b/tensorflow/compiler/mlir/BUILD index 923b203ab38..b359e05aac7 100644 --- a/tensorflow/compiler/mlir/BUILD +++ b/tensorflow/compiler/mlir/BUILD @@ -60,8 +60,8 @@ cc_library( "//tensorflow/compiler/mlir/tensorflow:tf_legalize_hlo", "//tensorflow/compiler/mlir/tensorflow:tf_saved_model_passes", # buildcleaner:keep "//tensorflow/compiler/mlir/tf2xla:compile_mlir_util", - "//tensorflow/compiler/mlir/tf2xla:tf_xla_passes", - "//tensorflow/compiler/mlir/tf2xla:xla_legalize_tf", + "//tensorflow/compiler/mlir/tf2xla/transforms:tf_xla_passes", + "//tensorflow/compiler/mlir/tf2xla/transforms:xla_legalize_tf", "//tensorflow/compiler/mlir/tosa:tf_passes", "//tensorflow/compiler/mlir/tosa:tf_tfl_passes", "//tensorflow/compiler/mlir/tosa:tfl_passes", @@ -249,6 +249,7 @@ tf_cc_test( "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core/framework:tensor_testutil", + "//tensorflow/core/lib/monitoring:cell_reader", "@llvm-project//mlir:IR", ], ) diff --git a/tensorflow/compiler/mlir/g3doc/_includes/tf_passes.md b/tensorflow/compiler/mlir/g3doc/_includes/tf_passes.md index 17996380f68..9405aa417df 100644 --- a/tensorflow/compiler/mlir/g3doc/_includes/tf_passes.md +++ b/tensorflow/compiler/mlir/g3doc/_includes/tf_passes.md @@ -178,6 +178,11 @@ func @_func(%arg0: tensor) -> tensor { return %identity : tensor } ``` + +#### Options +``` +-globally-unique-func-names : If true, the pass adds extra identifiers to make function names globally unique within a process, not just within a module. +``` ### `-tf-device-constant-sinking`: Sinks constants implicitly captured in a tf_device.cluster region. This pass sinks implicitly captured constants (`tf.Const` ops) used by and into a `tf_device.cluster` region. Performing this prior to outlining will reduce the @@ -244,6 +249,11 @@ func @_func(%arg0: tensor) -> tensor { return %identity : tensor } ``` + +#### Options +``` +-globally-unique-func-names : If true, the pass adds extra identifiers to make function names globally unique within a process, not just within a module. 
+``` ### `-tf-device-mark-input-output-aliases`: Marks device cluster inputs-output pairs that read/write to the same variable as aliases This pass analyzes the inputs and outputs to device cluster and marks those input-output pairs as aliases (using `tf.aliasing_output` attribute) which read @@ -259,6 +269,9 @@ inside device cluster. This would allow shape inference pass to further refine operand/result shapes of these ops. This is only safe to do when compiling to XLA. ### `-tf-einsum`: Transform Einsum to other TF Ops for the supported variants +### `-tf-embedding-pipelining`: Rewrite graph for embedding pipelining +For architectures that support accelerated embedding lookups, this pass will +rewrite the graph to use pipelining for better device utilization. ### `-tf-executor-break-up-islands`: Transform from TF control dialect to TF executor dialect. ### `-tf-executor-check-control-dependencies`: Checks control dependencies This pass analyzes control dependencies between islands and warns about @@ -726,6 +739,11 @@ func @outside_compilation() -> tensor { return %0 : tensor } ``` +### `-tf-extract-tpu-copy-with-dynamic-shape-op`: Extract the TPUCopyWithDynamicShapeOp out of the host launch and place it on device launch +This pass looks for TPUCopyWithDynamicShapeOp which wraps in a +`tf_device.launch` with host device attribute. It extracts the ops and wrap +them in `tf_device.launch` with tpu device attribute so that ops can be +run on TPU instead of CPU while still being compiled on host. ### `-tf-functional-control-flow-to-cfg`: Transform functional control flow Ops to MLIR Control Form Graph (CFG) form ### `-tf-functional-control-flow-to-regions`: Transforms functional control flow operations to their region-based counterparts This pass transforms functional control flow operations in the TensorFlow @@ -1007,7 +1025,7 @@ Would become the following ops (unimportant attribute, type are omitted): "tf_device.launch"() { "tf.B"() {_xla_outside_compilation = "cluster1"} tf_device.return - } {device = "TPU_REPLICATED_HOST"} : () -> () + } {device = "TPU_REPLICATED_HOST_0"} : () -> () "tf.C"() tf_device.return }) {num_cores_per_replica = 1, topology = "", device_assignment = []} @@ -1161,6 +1179,12 @@ region and hoists them out. It also makes `tf.Shape` ops replicate invariant if possible. This currently updates or replaces `tf.Shape` ops of replicated arguments, either tensors or resources. +The primary benefit of the pass is to hoist `num_replicas` `_TPUCompile`s +into a single `_TPUCompile`. + +This pass assumes that when a `tf.Shape` directly inputs from `replicate` +params, then it is the same shape across replicas. + For example, the following ```mlir @@ -1409,6 +1433,10 @@ func @main(%arg0: tensor<8x4xf32>) { return } ``` +### `-tf-tpu-annotate-dynamic-shape-inputs`: Annotate the inputs returned by TPUCopyWithDynamicShapeOp with dynamic shape +This pass looks for the usage of the result of TPUCopyWithDynamicShapeOp +and sets the shape of these inputs to be dynamic shaped. This will ensure +that the generated HLO program is correctly reflecting the dynamic shape. ### `-tf-tpu-cleanup-cluster-attributes`: Eliminate _replication_info and other attributes from ops in a cluster This pass eliminate `_replication_info` and `device` attribute on operations that are contained in a tf_device.cluster op. 
@@ -1508,6 +1536,25 @@ Then said `ReadVariableOp` is going to get replaced by: tf_device.return %2 : tensor<4xf32> }) {...} : () -> tensor<4xf32> ``` +### `-tf-tpu-colocate-splits`: Colocates each Split op with its predecessor +It is beneficial for performance to assign a `Split` op to the same device +as its predecessor. This is because the weight of cut edges is always +minimized when the `Split` is with its predecessor. This colocation +constraint will be used by the placer graph optimization to assign a device +to the op. + +This pass should run in the export pipeline after tf-replicate-to-island so +each replica has its own distinct (predecessor, Split) pair. + +The colocation class (`_class`) of the `Split` is set to the same class as +its predecessor: + +```mlir +%outputs1:2, %control1 = tf_executor.island wraps "tf.IteratorGetNext"(%arg) + {_class = ["loc:@dataset_iterator_1"]} +%outputs2:2, %control2 = tf_executor.island wraps "tf.Split"(%outputs0, %outputs1#1) + {_class = ["loc:@dataset_iterator_1", num_split = 2 : i32} +``` ### `-tf-tpu-device-propagation`: Propagates TPU devices from ops to users ### `-tf-tpu-dynamic-layout-pass`: Inserts TPU layout ops to determine layout at run time. A pass that allows TPU input layout to be determined after JIT compilation. @@ -1740,7 +1787,7 @@ will be rewritten as: ```mlir func @tf_tpu_rewrite(%arg0: tensor, %arg1: tensor) { - %0:2 = tf_device.replicate([%arg0, %arg1] as %arg2: tensor) {devices = {TPU_REPLICATED_CORE_0 = ["/job:worker/replica:0/task:0/device:TPU:0", "/job:worker/replica:0/task:0/device:TPU:1"], TPU_REPLICATED_HOST = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:CPU:0"]}, n = 2 : i32} { + %0:2 = tf_device.replicate([%arg0, %arg1] as %arg2: tensor) {devices = {TPU_REPLICATED_CORE_0 = ["/job:worker/replica:0/task:0/device:TPU:0", "/job:worker/replica:0/task:0/device:TPU:1"], TPU_REPLICATED_HOST_0 = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:CPU:0"]}, n = 2 : i32} { %1:2 = "tf_device.launch"() ( { %compilation_status, %program = "tf._TPUCompileMlir"() {mlir_module = ""} : () -> (tensor, tensor<3x!tf_type.string>) tf_device.return %compilation_status, %program : tensor, tensor<3x!tf_type.string> diff --git a/tensorflow/compiler/mlir/lite/BUILD b/tensorflow/compiler/mlir/lite/BUILD index f749b3e1221..27c2706622a 100644 --- a/tensorflow/compiler/mlir/lite/BUILD +++ b/tensorflow/compiler/mlir/lite/BUILD @@ -309,6 +309,7 @@ cc_library( "ir/tfl_ops.h", "transforms/passes.h", "utils/attribute_utils.h", + "utils/utils.h", "@llvm-project//mlir:include/mlir/Transforms/InliningUtils.h", ], deps = [ @@ -665,8 +666,8 @@ cc_library( "//tensorflow/compiler/mlir/tensorflow:tf_legalize_hlo", "//tensorflow/compiler/mlir/tensorflow:unroll_batch_matmul_pass", "//tensorflow/compiler/mlir/tensorflow:verification_utils", - "//tensorflow/compiler/mlir/tf2xla:xla_legalize_tf", - "//tensorflow/compiler/mlir/tf2xla:xla_legalize_tf_with_tf2xla", + "//tensorflow/compiler/mlir/tf2xla/transforms:xla_legalize_tf", + "//tensorflow/compiler/mlir/tf2xla/transforms:xla_legalize_tf_with_tf2xla", "//tensorflow/compiler/xla:status", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla/mlir_hlo", @@ -966,29 +967,35 @@ cc_library( ":tensorflow_lite", "//tensorflow/compiler/mlir:op_or_arg_name_mapper", "//tensorflow/compiler/mlir/lite/metrics:error_collector_inst", + "//tensorflow/compiler/mlir/lite/quantization/ir:QuantOps", "//tensorflow/compiler/mlir/tensorflow", 
"//tensorflow/compiler/mlir/tensorflow:convert_tensor", "//tensorflow/compiler/mlir/tensorflow:dynamic_shape_utils", "//tensorflow/compiler/mlir/tensorflow:export_tf_dialect_op", "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", - "//tensorflow/compiler/xla:statusor", + "//tensorflow/core:framework", + "//tensorflow/core:portable_gif_internal", "//tensorflow/core:protos_all_cc", - "//tensorflow/core/platform:errors", - "//tensorflow/core/platform:logging", - "//tensorflow/core/platform:status", + "//tensorflow/lite:graph_info", "//tensorflow/lite:schema_fbs_version", "//tensorflow/lite:string_util", + "//tensorflow/lite/core:framework", + "//tensorflow/lite/core/c:private_common", "//tensorflow/lite/delegates/flex:allowlisted_flex_ops_lib", "//tensorflow/lite/experimental/remat:metadata_util", - "//tensorflow/lite/kernels/internal:kernel_utils", + "//tensorflow/lite/python/metrics:converter_error_data_proto_cc", "//tensorflow/lite/schema:schema_conversion_utils", "//tensorflow/lite/schema:schema_fbs", "//tensorflow/lite/toco:toco_flags_proto_cc", "//tensorflow/lite/tools/versioning", "//tensorflow/lite/tools/versioning:gpu_compatibility", + "//tensorflow/tsl/platform:status", + "//tensorflow/tsl/platform:tstring", + "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@flatbuffers", @@ -998,7 +1005,6 @@ cc_library( "@llvm-project//mlir:IR", "@llvm-project//mlir:QuantOps", "@llvm-project//mlir:Support", - "@llvm-project//mlir:TranslateLib", ], ) @@ -1268,6 +1274,7 @@ cc_library( "//tensorflow/compiler/mlir/lite/stablehlo:stablehlo_tfl", "//tensorflow/compiler/mlir/lite/stablehlo:transforms", "//tensorflow/compiler/mlir/lite/stablehlo/serializer:flatbuffer_export", + "//tensorflow/compiler/mlir/quantization/stablehlo:quantize_passes", "//tensorflow/compiler/mlir/quantization/tensorflow:quantization_options_proto_cc", "//tensorflow/compiler/mlir/quantization/tensorflow:quantize_passes", "//tensorflow/compiler/mlir/quantization/tensorflow:quantize_preprocess", @@ -1282,6 +1289,8 @@ cc_library( "//tensorflow/compiler/mlir/tensorflow:translate_lib", "//tensorflow/core:framework", "//tensorflow/core:protos_all_cc", + "//tensorflow/core/platform:status", + "//tensorflow/lite/python/metrics:converter_error_data_proto_cc", "//tensorflow/lite/toco:toco_flags_proto_cc", "//tensorflow/lite/tools/optimize:quantize_weights", "//tensorflow/lite/tools/optimize:reduced_precision_support", diff --git a/tensorflow/compiler/mlir/lite/experimental/common/outline_operations.cc b/tensorflow/compiler/mlir/lite/experimental/common/outline_operations.cc index bfae3e96202..5f77797b9aa 100644 --- a/tensorflow/compiler/mlir/lite/experimental/common/outline_operations.cc +++ b/tensorflow/compiler/mlir/lite/experimental/common/outline_operations.cc @@ -35,6 +35,7 @@ limitations under the License. 
#include "mlir/IR/Visitors.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/tensorflow/utils/cluster_util.h" namespace mlir { @@ -44,7 +45,7 @@ namespace common { bool IsConstantOrNone(Operation* op) { return (op->getNumResults() == 1 && op->getResult(0).getType().isa()) || - matchPattern(op, m_Constant()); + matchPattern(op, m_Constant()) || isa(op); } // Pre-order traverse, adding results and BlockArgs to `been_defined` and diff --git a/tensorflow/compiler/mlir/lite/experimental/tac/hardwares/target_hardware.h b/tensorflow/compiler/mlir/lite/experimental/tac/hardwares/target_hardware.h index 99928fcf4d8..38286ed3cfe 100644 --- a/tensorflow/compiler/mlir/lite/experimental/tac/hardwares/target_hardware.h +++ b/tensorflow/compiler/mlir/lite/experimental/tac/hardwares/target_hardware.h @@ -94,6 +94,8 @@ class TargetHardware { // Usually should be something like mlir::TypeID::get() virtual mlir::TypeID GetTypeId() const = 0; + virtual void GetDependentDialects(mlir::DialectRegistry& registry) const {} + protected: // All registered hardware ops. std::vector> hardware_ops_; diff --git a/tensorflow/compiler/mlir/lite/experimental/tac/py_wrapper/BUILD b/tensorflow/compiler/mlir/lite/experimental/tac/py_wrapper/BUILD index 84e3df38b08..d363334fb5f 100644 --- a/tensorflow/compiler/mlir/lite/experimental/tac/py_wrapper/BUILD +++ b/tensorflow/compiler/mlir/lite/experimental/tac/py_wrapper/BUILD @@ -1,5 +1,9 @@ load("//tensorflow:tensorflow.default.bzl", "pybind_extension") load("//tensorflow:tensorflow.bzl", "VERSION") +load( + "//third_party/mkl_dnn:build_defs.bzl", + "if_onednn_v3", +) package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -96,7 +100,7 @@ pybind_extension( "@upb//:__subpackages__", "@XNNPACK//:__subpackages__", "@zlib//:__subpackages__", - ], + ] + if_onednn_v3(["@onednn_v3//:__subpackages__"]), deps = [ ":tac_wrapper_lib", "//tensorflow/python:pybind11_lib", diff --git a/tensorflow/compiler/mlir/lite/experimental/tac/tests/raise-target-subgraphs.mlir b/tensorflow/compiler/mlir/lite/experimental/tac/tests/raise-target-subgraphs.mlir index 9e14f1eae7e..18b9e0fd605 100644 --- a/tensorflow/compiler/mlir/lite/experimental/tac/tests/raise-target-subgraphs.mlir +++ b/tensorflow/compiler/mlir/lite/experimental/tac/tests/raise-target-subgraphs.mlir @@ -1,4 +1,6 @@ // RUN: tac-opt-all-backends -tfl-raise-target-subgraphs %s -split-input-file | FileCheck %s +// RUN: tac-opt-all-backends -tfl-raise-target-subgraphs="skip-raise-cpu-ops=true" %s -split-input-file | FileCheck %s --check-prefixes=CHECK-SKIP-CPU +// RUN: tac-opt-all-backends -tfl-raise-target-subgraphs="ignore-inference-type=true" %s -split-input-file | FileCheck %s --check-prefixes=CHECK-IGNORE-INFERENCE-TYPE module { func.func @simpleWhile(%arg0: tensor) -> tensor { @@ -502,3 +504,69 @@ func.func @cond_false_72730(%arg0: tensor, %arg1: tensor // CHECK: return %1, %5, %7, %11, %13, %15, %16, %18, %20, %21 : tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor // CHECK: } + +// ----- + +// CHECK-SKIP-CPU-LABEL: testSkipCpuOps +func.func @testSkipCpuOps(%arg0: tensor<1xf32>) -> (tensor<1xf32>, tensor<1xf32>) { + %0 = "tfl.add"(%arg0, %arg0) {tac.device = "GPU", fused_activation_function = "RELU6", tac.inference_type = "FLOAT"} : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> + %1 = "tfl.add"(%arg0, %0) 
{tac.device = "CPU", fused_activation_function = "RELU6", tac.inference_type = "FLOAT"} : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> + func.return %0, %1 : tensor<1xf32>, tensor<1xf32> +} + +// CHECK-SKIP-CPU: %[[RES0:.*]] = call @func_0_GPU_FLOAT(%arg0) {tac.device = "GPU", tac.inference_type = "FLOAT", tac.interface_name = "func_0"} : (tensor<1xf32>) -> tensor<1xf32> +// CHECK-SKIP-CPU: %[[RES1:.*]] = tfl.add %arg0, %[[RES0]] {fused_activation_function = "RELU6", tac.device = "CPU", tac.inference_type = "FLOAT"} : tensor<1xf32> +// CHECK-SKIP-CPU: return %[[RES0]], %[[RES1]] : tensor<1xf32>, tensor<1xf32> +// CHECK-SKIP-CPU: } +// CHECK-SKIP-CPU: func.func private @func_0_GPU_FLOAT(%arg0: tensor<1xf32>) -> tensor<1xf32> attributes {tac.device = "GPU", tac.inference_type = "FLOAT", tac.interface_name = "func_0"} { +// CHECK-SKIP-CPU: %[[RES2:.*]] = tfl.add %arg0, %arg0 {fused_activation_function = "RELU6", tac.device = "GPU", tac.inference_type = "FLOAT"} : tensor<1xf32> +// CHECK-SKIP-CPU: return %[[RES2]] : tensor<1xf32> +// CHECK-SKIP-CPU: } + +// ----- + +// CHECK-SKIP-CPU-LABEL: testSkipCpuOpsWithinLoop +func.func @testSkipCpuOpsWithinLoop(%arg0: tensor) -> tensor { + %0 = "tfl.while"(%arg0) ({ + ^bb0(%block: tensor): + "tfl.yield"(%block) : (tensor) -> () + },{ + ^bb0(%block: tensor): + %0 = "tfl.add"(%arg0, %block) {tac.device = "GPU", fused_activation_function = "RELU6", tac.inference_type = "FLOAT"} : (tensor, tensor) -> tensor + "tfl.yield"(%0) : (tensor) -> () + }) {tac.device = "CPU", fused_activation_function = "RELU6", tac.inference_type = "FLOAT"} : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-SKIP-CPU: "tfl.while" +// CHECK-SKIP-CPU: ^bb0(%[[ARG0:.*]]: tensor): +// CHECK-SKIP-CPU: "tfl.yield"(%[[ARG0]]) : (tensor) -> () +// CHECK-SKIP-CPU: }, { +// CHECK-SKIP-CPU: ^bb0(%[[ARG1:.*]]: tensor): +// CHECK-SKIP-CPU: %[[RES0:.*]] = func.call @func_0_GPU_FLOAT(%{{.*}}, %[[ARG1]]) {tac.device = "GPU", tac.inference_type = "FLOAT", tac.interface_name = "func_0"} : (tensor, tensor) -> tensor +// CHECK-SKIP-CPU: "tfl.yield"(%[[RES0]]) : (tensor) -> () +// CHECK-SKIP-CPU: }) {fused_activation_function = "RELU6", tac.device = "CPU", tac.inference_type = "FLOAT"} : (tensor) -> tensor + +// CHECK-SKIP-CPU: func.func private @func_0_GPU_FLOAT(%[[ARG2:.*]]: tensor, %[[ARG3:.*]]: tensor) -> tensor attributes {tac.device = "GPU", tac.inference_type = "FLOAT", tac.interface_name = "func_0"} { +// CHECK-SKIP-CPU: %[[RES1:.*]] = tfl.add %[[ARG2]], %[[ARG3]] {fused_activation_function = "RELU6", tac.device = "GPU", tac.inference_type = "FLOAT"} : tensor +// CHECK-SKIP-CPU: return %[[RES1]] : tensor +// CHECK-SKIP-CPU: } + +// ----- + +// CHECK-IGNORE-INFERENCE-TYPE-LABEL: testIgnoreInferenceType +func.func @testIgnoreInferenceType(%arg0: tensor<1x384x384xf32>, %arg1: tensor<1x1x384x!quant.uniform>) -> (tensor<1x384x384xf32>, tensor<1x1x384x!quant.uniform>) { + // These 2 ops are clustered together when `ignore-inference-type` sets to true. 
+ %0 = "tfl.add"(%arg0, %arg0) {tac.device = "GPU", tac.inference_type = "FLOAT", fused_activation_function = "NONE"} : (tensor<1x384x384xf32>, tensor<1x384x384xf32>) -> tensor<1x384x384xf32> + %1 = "tfl.mul"(%arg1, %arg1) {tac.device = "GPU", tac.inference_type = "QUANTIZED_INT8", fused_activation_function = "NONE"} : (tensor<1x1x384x!quant.uniform>, tensor<1x1x384x!quant.uniform>) -> tensor<1x1x384x!quant.uniform> + func.return %0, %1: tensor<1x384x384xf32>, tensor<1x1x384x!quant.uniform> +} + +// CHECK-IGNORE-INFERENCE-TYPE: %[[RES0:.*]]:2 = call @[[FUNC_NAME:.*]](%arg0, %arg1) {tac.device = "GPU", tac.inference_type = "FLOAT", tac.interface_name = "func_0"} : (tensor<1x384x384xf32>, tensor<1x1x384x!quant.uniform>) -> (tensor<1x384x384xf32>, tensor<1x1x384x!quant.uniform>) +// CHECK-IGNORE-INFERENCE-TYPE: return %[[RES0]]#0, %[[RES0]]#1 : tensor<1x384x384xf32>, tensor<1x1x384x!quant.uniform> +// CHECK-IGNORE-INFERENCE-TYPE: } +// CHECK-IGNORE-INFERENCE-TYPE: func.func private @[[FUNC_NAME]](%arg0: tensor<1x384x384xf32>, %arg1: tensor<1x1x384x!quant.uniform>) -> (tensor<1x384x384xf32>, tensor<1x1x384x!quant.uniform>) attributes {tac.device = "GPU", tac.inference_type = "FLOAT", tac.interface_name = "func_0"} { +// CHECK-IGNORE-INFERENCE-TYPE: %[[RES1:.*]] = tfl.add %arg0, %arg0 {fused_activation_function = "NONE", tac.device = "GPU", tac.inference_type = "FLOAT"} : tensor<1x384x384xf32> +// CHECK-IGNORE-INFERENCE-TYPE: %[[RES2:.*]] = tfl.mul %arg1, %arg1 {fused_activation_function = "NONE", tac.device = "GPU", tac.inference_type = "QUANTIZED_INT8"} : tensor<1x1x384x!quant.uniform> +// CHECK-IGNORE-INFERENCE-TYPE: return %[[RES1]], %[[RES2]] : tensor<1x384x384xf32>, tensor<1x1x384x!quant.uniform> +// CHECK-IGNORE-INFERENCE-TYPE: } diff --git a/tensorflow/compiler/mlir/lite/experimental/tac/transforms/cost_model.cc b/tensorflow/compiler/mlir/lite/experimental/tac/transforms/cost_model.cc index 15fb7e66477..4efdd053eec 100644 --- a/tensorflow/compiler/mlir/lite/experimental/tac/transforms/cost_model.cc +++ b/tensorflow/compiler/mlir/lite/experimental/tac/transforms/cost_model.cc @@ -65,7 +65,9 @@ int64_t GetTransferredTensorBytes(func::CallOp from_graph, if (IsQUI8Type(input_type) || IsQI8Type(input_type)) { total_size_transferred += input_type.getNumElements() * 8; } else { - total_size_transferred += input_type.cast().getSizeInBits(); + auto s_type = input_type.cast(); + total_size_transferred += + s_type.getNumElements() * s_type.getElementTypeBitWidth(); } } } diff --git a/tensorflow/compiler/mlir/lite/experimental/tac/transforms/passes.h b/tensorflow/compiler/mlir/lite/experimental/tac/transforms/passes.h index 201ce1690d3..f738b2e7a60 100644 --- a/tensorflow/compiler/mlir/lite/experimental/tac/transforms/passes.h +++ b/tensorflow/compiler/mlir/lite/experimental/tac/transforms/passes.h @@ -37,8 +37,12 @@ std::unique_ptr> CreateTargetAnnotationPass( std::unique_ptr> CreateTargetAnnotationPass( const TacModule* module); -// Create an instance of the RaiseTargetSubgraphsPass. -std::unique_ptr> CreateRaiseTargetSubgraphsPass(); +// Create an instance of the RaiseTargetSubgraphsPass. If `skip_raise_cpu_ops`, +// we skip clustering for CPU ops for better clustering of ops running on other +// ML accelerators. When `ignore_inference_type` is set to true, the inference +// types are set to "NOT_CARE" for better clustering. 
+std::unique_ptr> CreateRaiseTargetSubgraphsPass( + bool skip_raise_cpu_ops = false, bool ignore_inference_type = false); // Create an instance of the AlternativeSubgraphPass. std::unique_ptr> CreateAlternativeSubgraphPass( diff --git a/tensorflow/compiler/mlir/lite/experimental/tac/transforms/raise_target_subgraphs.cc b/tensorflow/compiler/mlir/lite/experimental/tac/transforms/raise_target_subgraphs.cc index 92ac79aef63..4fd9f945764 100644 --- a/tensorflow/compiler/mlir/lite/experimental/tac/transforms/raise_target_subgraphs.cc +++ b/tensorflow/compiler/mlir/lite/experimental/tac/transforms/raise_target_subgraphs.cc @@ -41,6 +41,7 @@ limitations under the License. #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Pass/PassRegistry.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Transforms/RegionUtils.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/experimental/common/outline_operations.h" #include "tensorflow/compiler/mlir/lite/experimental/tac/common/subgraph.h" #include "tensorflow/compiler/mlir/lite/experimental/tac/common/targets.h" @@ -65,7 +66,28 @@ class RaiseTargetSubgraphsPass public: MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(RaiseTargetSubgraphsPass) + RaiseTargetSubgraphsPass() = default; + RaiseTargetSubgraphsPass(const RaiseTargetSubgraphsPass& other) { + this->skip_raise_cpu_ops_ = other.skip_raise_cpu_ops_; + this->ignore_inference_type_ = other.ignore_inference_type_; + } + explicit RaiseTargetSubgraphsPass(bool skip_raise_cpu_ops, + bool ignore_inference_type) { + skip_raise_cpu_ops_ = skip_raise_cpu_ops; + ignore_inference_type_ = ignore_inference_type; + } + private: + Option skip_raise_cpu_ops_{ + *this, "skip-raise-cpu-ops", + llvm::cl::desc("Whether to cluster and raise CPU ops."), + llvm::cl::init(false)}; + + Option ignore_inference_type_{ + *this, "ignore-inference-type", + llvm::cl::desc("Whether to ignore the inference type in clustering."), + llvm::cl::init(false)}; + llvm::StringRef getArgument() const final { return "tfl-raise-target-subgraphs"; } @@ -189,8 +211,11 @@ void RaiseTargetSubgraphsPass::RaiseTargetSubgraphsForBlock( return std::string(""); } std::string concat_inference_device_type_string = - absl::StrCat(device_type.value().hardware, "_", - GetInferenceString(device_type.value().inference_type)); + ignore_inference_type_ + ? device_type.value().hardware + : absl::StrCat( + device_type.value().hardware, "_", + GetInferenceString(device_type.value().inference_type)); return concat_inference_device_type_string; }; @@ -208,6 +233,20 @@ void RaiseTargetSubgraphsPass::RaiseTargetSubgraphsForBlock( extract(cluster.ops); } } + if (skip_cpu) { + for (auto& op : block) { + auto op_device = GetInferenceDeviceTypeForOp(&op); + if (op_device_is(op, kCpuDeviceName)) + // The recently raised func is device type cpu & `op` is a "CPU". + // Recursivley call again to raise any non-"CPU" subgraphs contained + // within nested region of `op`. 
+ for (auto& region : op.getRegions()) + for (auto& block : region.getBlocks()) + RaiseTargetSubgraphsForBlock(block, builder, module, + /*skip_cpu=*/true, func_count, + side_effect_info); + } + } } void RaiseTargetSubgraphsPass::runOnOperation() { @@ -220,15 +259,18 @@ void RaiseTargetSubgraphsPass::runOnOperation() { for (auto& block : func) { OpBuilder builder = OpBuilder::atBlockBegin(&block); RaiseTargetSubgraphsForBlock(block, builder, module, - /*skip_cpu=*/false, func_count, info); + /*skip_cpu=*/skip_raise_cpu_ops_, func_count, + info); } } } } // namespace -std::unique_ptr> CreateRaiseTargetSubgraphsPass() { - return std::make_unique(); +std::unique_ptr> CreateRaiseTargetSubgraphsPass( + bool skip_raise_cpu_ops, bool ignore_inference_type) { + return std::make_unique(skip_raise_cpu_ops, + ignore_inference_type); } static PassRegistration pass; diff --git a/tensorflow/compiler/mlir/lite/experimental/tac/transforms/tac_pass.h b/tensorflow/compiler/mlir/lite/experimental/tac/transforms/tac_pass.h index 78e47dcdee9..392a2713e95 100644 --- a/tensorflow/compiler/mlir/lite/experimental/tac/transforms/tac_pass.h +++ b/tensorflow/compiler/mlir/lite/experimental/tac/transforms/tac_pass.h @@ -41,7 +41,8 @@ class TacPass : public OperationPass { ~TacPass() override {} - const TargetHardware* GetTargetHardware(const std::string& hardware_name) { + const TargetHardware* GetTargetHardware( + const std::string& hardware_name) const { return module_ != nullptr ? module_->GetTargetHardware(hardware_name) : mlir::TFL::tac::GetTargetHardware(hardware_name); diff --git a/tensorflow/compiler/mlir/lite/experimental/tac/transforms/target_annotation.cc b/tensorflow/compiler/mlir/lite/experimental/tac/transforms/target_annotation.cc index 009f27d936e..2dddad4e9a8 100644 --- a/tensorflow/compiler/mlir/lite/experimental/tac/transforms/target_annotation.cc +++ b/tensorflow/compiler/mlir/lite/experimental/tac/transforms/target_annotation.cc @@ -69,6 +69,16 @@ class TargetAnnotationPass : public TacFunctionPass { llvm::cl::desc( "comma separated list of device specs, like CPU, GPU, Hexagon."), llvm::cl::ZeroOrMore}; + + void getDependentDialects(mlir::DialectRegistry& registry) const override { + if (!module_) { + for (const auto& device : device_specs_flag_) { + auto* hardware = this->GetTargetHardware(device); + if (hardware == nullptr) continue; + hardware->GetDependentDialects(registry); + } + } + } }; void SetAnnotation(Operation* op, std::string attribute, std::string annotation, diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_export.cc b/tensorflow/compiler/mlir/lite/flatbuffer_export.cc index c7ca7accb66..1e4475bd4b3 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_export.cc +++ b/tensorflow/compiler/mlir/lite/flatbuffer_export.cc @@ -19,54 +19,73 @@ limitations under the License. 
#include #include +#include #include +#include +#include +#include +#include #include #include +#include #include +#include #include #include +#include "absl/algorithm/container.h" #include "absl/base/attributes.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" -#include "absl/strings/match.h" +#include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" #include "absl/strings/string_view.h" -#include "flatbuffers/flatbuffers.h" // from @flatbuffers +#include "flatbuffers/buffer.h" // from @flatbuffers +#include "flatbuffers/flatbuffer_builder.h" // from @flatbuffers #include "flatbuffers/flexbuffers.h" // from @flatbuffers +#include "flatbuffers/vector.h" // from @flatbuffers #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" +#include "llvm/ADT/StringSwitch.h" #include "llvm/Support/Casting.h" -#include "llvm/Support/CommandLine.h" #include "llvm/Support/FormatVariadic.h" -#include "llvm/Support/ToolOutputFile.h" #include "llvm/Support/raw_ostream.h" #include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributeInterfaces.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/Diagnostics.h" // from @llvm-project #include "mlir/IR/Location.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/TypeUtilities.h" // from @llvm-project #include "mlir/IR/Types.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/IR/Visitors.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project -#include "mlir/Tools/mlir-translate/Translation.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/flatbuffer_operator.h" #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/lite/metrics/error_collector_inst.h" +#include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h" #include "tensorflow/compiler/mlir/lite/utils/convert_type.h" #include "tensorflow/compiler/mlir/lite/utils/low_bit_utils.h" #include "tensorflow/compiler/mlir/lite/utils/stateful_ops_utils.h" #include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h" @@ -74,28 +93,33 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h" #include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h" #include "tensorflow/compiler/mlir/tensorflow/utils/dynamic_shape_utils.h" -#include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/node_def.pb.h" -#include "tensorflow/core/platform/errors.h" -#include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/tstring.h" +#include "tensorflow/lite/core/c/builtin_op_data.h" +#include "tensorflow/lite/core/interpreter.h" #include "tensorflow/lite/delegates/flex/allowlisted_flex_ops.h" #include "tensorflow/lite/experimental/remat/metadata_util.h" -#include "tensorflow/lite/kernels/internal/kernel_utils.h" +#include "tensorflow/lite/graph_info.h" +#include "tensorflow/lite/python/metrics/converter_error_data.pb.h" #include "tensorflow/lite/schema/schema_conversion_utils.h" #include "tensorflow/lite/schema/schema_generated.h" #include "tensorflow/lite/string_util.h" +#include "tensorflow/lite/toco/toco_flags.pb.h" #include "tensorflow/lite/tools/versioning/gpu_compatibility.h" #include "tensorflow/lite/tools/versioning/op_version.h" #include "tensorflow/lite/tools/versioning/runtime_version.h" #include "tensorflow/lite/version.h" +#include "tensorflow/tsl/platform/status.h" +#include "tensorflow/tsl/platform/tstring.h" using llvm::dyn_cast; using llvm::formatv; using llvm::isa; using llvm::StringRef; -using llvm::Twine; using mlir::Dialect; using mlir::ElementsAttr; using mlir::MLIRContext; @@ -105,6 +129,7 @@ using mlir::Operation; using mlir::Region; using mlir::StringAttr; using mlir::TensorType; +using mlir::Twine; using mlir::Type; using mlir::UnknownLoc; using mlir::Value; @@ -124,7 +149,6 @@ using VectorBufferOffset = flatbuffers::Offset>; using CustomOptionsOffset = VectorBufferOffset; -namespace error = tensorflow::error; namespace tfl = mlir::TFL; ABSL_CONST_INIT const absl::string_view kFlexOpNamePrefix = "Flex"; @@ -142,7 +166,7 @@ static StatusOr GetTFLiteType(Type type, return tflite::TensorType_UINT8; } if (!is_signed) { - return Status(error::INVALID_ARGUMENT, + return Status(absl::StatusCode::kInvalidArgument, "'isSigned' can only be set for 8-bits integer type"); } @@ -164,14 +188,14 @@ static StatusOr GetTFLiteType(Type type, if (ftype.isF64()) { return tflite::TensorType_COMPLEX128; } - return Status(error::INVALID_ARGUMENT, "Unsupported type"); + return Status(absl::StatusCode::kInvalidArgument, "Unsupported type"); } else if (auto itype = type.dyn_cast()) { switch (itype.getWidth()) { case 1: return tflite::TensorType_BOOL; case 4: if (itype.isUnsigned()) { - return Status(error::INVALID_ARGUMENT, + return Status(absl::StatusCode::kInvalidArgument, "Unsupported 4bit unsigned int type"); } else { return tflite::TensorType_INT4; @@ -207,7 +231,7 @@ static StatusOr GetTFLiteType(Type type, } // TFLite export fills FLOAT32 for unknown data types. Returning an error // for now for safety and this could be revisited when required. 
- return Status(error::INVALID_ARGUMENT, "Unsupported type"); + return Status(absl::StatusCode::kInvalidArgument, "Unsupported type"); } static bool IsConst(Operation* op) { @@ -335,7 +359,7 @@ static bool HasValidTFLiteType(Value value, T& error_handler) { if (!status.ok()) { return error_handler.emitError( formatv("Failed to convert element type '{0}': {1}", - element_type, status.status().error_message())), + element_type, status.status().message())), false; } return true; @@ -1553,7 +1577,7 @@ std::optional> Translator::BuildSubGraph( } bool failed_once = false; - for (auto& item : llvm::enumerate(bb)) { + for (const auto& item : llvm::enumerate(bb)) { Operation& inst = item.value(); const int operation_index = item.index(); if (inst.hasTrait()) break; diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_import.cc b/tensorflow/compiler/mlir/lite/flatbuffer_import.cc index 6421598e76a..487b3edd60a 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_import.cc +++ b/tensorflow/compiler/mlir/lite/flatbuffer_import.cc @@ -647,7 +647,7 @@ static StatusOr BuildSparseConstOp( } std::vector dense_buffer( value_type.getElementType().getIntOrFloatBitWidth() / CHAR_BIT); - mlir::Attribute dummy_value = + mlir::TypedAttr dummy_value = mlir::DenseIntOrFPElementsAttr::getFromRawBuffer(value_type, dense_buffer); @@ -1376,7 +1376,7 @@ StatusOr ConvertSubgraph( } // Construct MLIR operators from TFLite operators - for (auto& it : llvm::enumerate(subgraph.operators)) { + for (const auto& it : llvm::enumerate(subgraph.operators)) { auto& op = it.value(); if (experimental_prune_unreachable_nodes_unconditionally && @@ -1612,8 +1612,7 @@ OwningOpRef tflite::FlatBufferToMlir( model_control_dependencies[subgraph_index]); if (!func_or_error.ok()) { return emitError(base_loc, "could not translate function ") - << subgraph->name << ": " - << func_or_error.status().error_message(), + << subgraph->name << ": " << func_or_error.status().message(), nullptr; } module.push_back(std::move(func_or_error).value()); diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc index 62f9e220665..f23dfd96e88 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc +++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -18,6 +18,7 @@ limitations under the License. #include #include #include +#include #include #include #include @@ -282,6 +283,11 @@ bool IsI32Type(Type element_type) { return element_type.isInteger(32) && !element_type.isUnsignedInteger(); } +// Return true when the given element_type is UI32. +bool IsUI32Type(Type element_type) { + return element_type.isInteger(32) && element_type.isUnsignedInteger(); +} + // Return true when the given element_type is I64. bool IsI64Type(Type element_type) { return element_type.isInteger(64) && !element_type.isUnsignedInteger(); @@ -335,7 +341,7 @@ bool VerifyAddOpShapeConstraints(AddOp op) { IsI32Type(element_type) || IsI64Type(element_type)) { return VerifyOperandsHaveSameShapesOrBroadcastableShape( /*op=*/op.getOperation(), /*indices=*/ArrayRef{0, 1}, - /*max_bcast_rank=*/4); + /*max_bcast_rank=*/6); } // Allows QI16 output when operands have the same shape. 
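The `add` verifier above now accepts broadcastable shapes up to rank 6, and the `mul` verifier below gains unsigned 32-bit and 16-bit integer element types. For reference, the shape rule both delegate to is numpy-style trailing-dimension broadcasting capped at a maximum rank; a rough standalone restatement for static shapes (the real check is `VerifyOperandsHaveSameShapesOrBroadcastableShape`):

```c++
#include <algorithm>
#include <cstdint>
#include <vector>

// Returns true when two static shapes are identical, or broadcast against
// each other with neither exceeding `max_bcast_rank` dimensions.
bool SameOrBroadcastableUpToRank(const std::vector<int64_t>& lhs,
                                 const std::vector<int64_t>& rhs,
                                 int max_bcast_rank) {
  if (lhs == rhs) return true;
  if (static_cast<int>(std::max(lhs.size(), rhs.size())) > max_bcast_rank)
    return false;
  // Compare dimensions from the innermost outward; a dimension of 1
  // broadcasts against anything.
  for (auto l = lhs.rbegin(), r = rhs.rbegin();
       l != lhs.rend() && r != rhs.rend(); ++l, ++r) {
    if (*l != *r && *l != 1 && *r != 1) return false;
  }
  return true;
}
```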
@@ -389,8 +395,9 @@ bool VerifyMulOpShapeConstraints(MulOp op) { // Allows I32, I64, QI16 and F32 outputs when the operands have valid shapes, // which are broadcastable shapes up to four dimension or have same shapes. - if (IsI32Type(element_type) || IsI64Type(element_type) || - IsQI16Type(element_type) || element_type.isa() || + if (IsI32Type(element_type) || IsUI32Type(element_type) || + IsI64Type(element_type) || IsQI16Type(element_type) || + IsI16Type(element_type) || element_type.isa() || element_type.isF32()) { return VerifyOperandsHaveSameShapesOrBroadcastableShape( /*op=*/op.getOperation(), /*indices=*/ArrayRef{0, 1}, @@ -961,9 +968,9 @@ mlir::LogicalResult CustomOp::verify() { LogicalResult CustomTfOp::inferReturnTypes( MLIRContext*, std::optional location, ValueRange operands, - DictionaryAttr attr, RegionRange ranges, + DictionaryAttr attr, OpaqueProperties, RegionRange ranges, SmallVectorImpl& inferredReturnTypes) { - CustomTfOpAdaptor op(operands, attr, ranges); + CustomTfOpAdaptor op(operands, attr, {}, ranges); if (op.getRegions().empty()) return success(); auto* real_op = &op.getBody().front().front(); @@ -1226,7 +1233,7 @@ static LogicalResult ComputeConvWindowedOutputSize( LogicalResult Conv2DOp::inferReturnTypes( MLIRContext*, std::optional location, ValueRange operands, - DictionaryAttr attr, RegionRange, + DictionaryAttr attr, OpaqueProperties, RegionRange, SmallVectorImpl& inferredReturnTypes) { Conv2DOpAdaptor op(operands, attr); @@ -1711,7 +1718,7 @@ struct ConvertShapeTo1D : public OpRewritePattern { return failure(); } // It is already a 1-D constant, no change. - auto old_shape = shape.getType().getShape(); + auto old_shape = shape.getShapedType().getShape(); if (old_shape.size() == 1) { return failure(); } @@ -1724,7 +1731,7 @@ struct ConvertShapeTo1D : public OpRewritePattern { } } auto new_shape = shape.reshape(tensorflow::GetTypeFromTFTensorShape( - {*old_shape.rbegin()}, shape.getType().getElementType())); + {*old_shape.rbegin()}, shape.getShapedType().getElementType())); rewriter.replaceOpWithNewOp( reshape.getShape().getDefiningOp(), new_shape); return success(); @@ -1907,7 +1914,7 @@ mlir::LogicalResult ReshapeOp::verify() { LogicalResult ReshapeOp::inferReturnTypes( MLIRContext* context, std::optional location, ValueRange operands, - DictionaryAttr attr, RegionRange, + DictionaryAttr attr, OpaqueProperties, RegionRange, SmallVectorImpl& inferredReturnTypes) { ReshapeOpAdaptor op(operands, attr); const Value input = op.getInput(); @@ -2222,7 +2229,7 @@ static void BuildTopKOp(OpBuilder* builder, OperationState& result, Value input, if (!val_type.hasRank()) return TFL::TopKV2Op::build( *builder, result, UnrankedTensorType::get(val_type.getElementType()), - UnrankedTensorType::get(builder->getIntegerType(32)), input, k); + UnrankedTensorType::get(k.getType()), input, k); // Resultant shape is value.shape[:-1] + [k] std::vector shape(val_type.getShape()); @@ -2230,8 +2237,7 @@ static void BuildTopKOp(OpBuilder* builder, OperationState& result, Value input, TFL::TopKV2Op::build( *builder, result, tensorflow::GetTypeFromTFTensorShape(shape, val_type.getElementType()), - tensorflow::GetTypeFromTFTensorShape(shape, builder->getIntegerType(32)), - input, k); + tensorflow::GetTypeFromTFTensorShape(shape, k.getType()), input, k); } //===----------------------------------------------------------------------===// @@ -2285,7 +2291,7 @@ void FakeQuantOp::getCanonicalizationPatterns(RewritePatternSet& results, LogicalResult UnpackOp::inferReturnTypes( MLIRContext* 
context, std::optional loc, ValueRange operands, - DictionaryAttr attributes, RegionRange regions, + DictionaryAttr attributes, OpaqueProperties, RegionRange regions, SmallVectorImpl& inferredReturnTypes) { UnpackOpAdaptor op(operands, attributes); // TODO(jpienaar): Refactor verify @@ -2646,7 +2652,7 @@ mlir::LogicalResult UnidirectionalSequenceLSTMOp::verify() { LogicalResult UnidirectionalSequenceLSTMOp::inferReturnTypes( MLIRContext*, std::optional, ValueRange operands, - DictionaryAttr attr, RegionRange, + DictionaryAttr attr, OpaqueProperties, RegionRange, SmallVectorImpl& inferredReturnTypes) { Value input = operands[0]; auto input_type = input.getType().dyn_cast_or_null(); @@ -2922,7 +2928,7 @@ OpFoldResult RankOp::fold(FoldAdaptor adaptor) { assert(operands.size() == 1); auto result_type = getType().cast(); if (auto elements_attr = operands[0].dyn_cast_or_null()) { - auto rank = static_cast(elements_attr.getType().getRank()); + auto rank = static_cast(elements_attr.getShapedType().getRank()); return DenseElementsAttr::get(result_type, {rank}); } @@ -3145,9 +3151,9 @@ OpFoldResult RangeOp::fold(FoldAdaptor adaptor) { auto delta_tensor = operands[2].dyn_cast_or_null(); if (start_tensor && limit_tensor && delta_tensor) { // Operands should all be scalars - assert(start_tensor.getType().getRank() == 0 && - limit_tensor.getType().getRank() == 0 && - delta_tensor.getType().getRank() == 0); + assert(start_tensor.getShapedType().getRank() == 0 && + limit_tensor.getShapedType().getRank() == 0 && + delta_tensor.getShapedType().getRank() == 0); Type elem_type = getType().cast().getElementType(); if (elem_type.isSignlessInteger()) { auto start_attr = start_tensor.getValues()[0]; @@ -3328,9 +3334,12 @@ namespace { // The function recursively traverses the dimensions of the output tensor in // a row-major order and writes the value in the output tensor into // `new_values`. -void ComputePermutation(ElementsAttr input_tensor, ArrayRef perm, - ArrayRef output_shape, int num_dimensions, - int output_axis, std::vector* input_indices, +void ComputePermutation(mlir::detail::ElementsAttrRange< + mlir::detail::ElementsAttrIterator> + input_tensor_values, + ArrayRef perm, ArrayRef output_shape, + const int num_dimensions, const int output_axis, + std::vector* input_indices, std::vector* new_values) { // Refer to the implementation of `Transpose` function in // tensorflow/lite/kernels/internal/reference/reference_ops.h @@ -3343,11 +3352,11 @@ void ComputePermutation(ElementsAttr input_tensor, ArrayRef perm, // recurse into the next axis. 
const bool is_last_axis = output_axis == num_dimensions - 1; if (is_last_axis) { - new_values->push_back( - input_tensor.getValues()[*input_indices]); + new_values->push_back(input_tensor_values[*input_indices]); } else { - ComputePermutation(input_tensor, perm, output_shape, num_dimensions, - output_axis + 1, input_indices, new_values); + ComputePermutation(input_tensor_values, perm, output_shape, + num_dimensions, output_axis + 1, input_indices, + new_values); } } } @@ -3366,11 +3375,11 @@ OpFoldResult TransposeOp::fold(FoldAdaptor adaptor) { if (!getType().cast().getElementType().isSignlessIntOrFloat()) return nullptr; - assert(perm_tensor.getType().getRank() == 1); - const int num_dimensions = input_tensor.getType().getRank(); - assert(perm_tensor.getType().getNumElements() == num_dimensions); + assert(perm_tensor.getShapedType().getRank() == 1); + const int num_dimensions = input_tensor.getShapedType().getRank(); + assert(perm_tensor.getShapedType().getNumElements() == num_dimensions); - ArrayRef input_shape = input_tensor.getType().getShape(); + ArrayRef input_shape = input_tensor.getShapedType().getShape(); auto output_type = getType().cast(); SmallVector perm; @@ -3385,9 +3394,10 @@ OpFoldResult TransposeOp::fold(FoldAdaptor adaptor) { } std::vector new_values; - new_values.reserve(input_tensor.getType().getNumElements()); + new_values.reserve(input_tensor.getShapedType().getNumElements()); std::vector input_indices(num_dimensions); - ComputePermutation(input_tensor, perm, output_shape, num_dimensions, + auto input_tensor_values = input_tensor.getValues(); + ComputePermutation(input_tensor_values, perm, output_shape, num_dimensions, /*output_axis=*/0, &input_indices, &new_values); auto result_type = tensorflow::GetTypeFromTFTensorShape( output_shape, output_type.getElementType()); @@ -3542,7 +3552,7 @@ void IfOp::getSuccessorRegions(std::optional index, // Otherwise, the successor is dependent on the condition. bool condition; if (auto cond_attr = operands.front().dyn_cast_or_null()) { - condition = cond_attr.getValue().isOneValue(); + condition = cond_attr.getValue().isOne(); } else { // If the condition isn't constant, both regions may be executed. regions.push_back(RegionSuccessor(&getThenRegion())); @@ -3703,9 +3713,9 @@ struct WhileResultOperandsMatchAndImplicitCapture // Replace with new While with matching operands and results. Operation* op = while_op.getOperation(); - Operation* new_op = rewriter.insert( - Operation::create(op->getLoc(), op->getName(), types, new_operands, - op->getAttrs(), {}, /*numRegions=*/2)); + Operation* new_op = rewriter.insert(Operation::create( + op->getLoc(), op->getName(), types, new_operands, op->getAttrs(), + op->getPropertiesStorage(), {}, /*numRegions=*/2)); for (int i = 0; i < 2; ++i) new_op->getRegion(i).takeBody(op->getRegion(i)); int new_index = 0; @@ -4057,7 +4067,7 @@ OpFoldResult EmbeddingLookupOp::fold(FoldAdaptor adaptor) { std::vector new_shape = value_attr.getType().getShape().vec(); new_shape[0] = lookup_attr.getType().getShape()[0]; - Type new_type = value_attr.getType().clone(new_shape); + auto new_type = value_attr.getType().clone(new_shape); return DenseElementsAttr::get(new_type, new_values); } @@ -4086,7 +4096,96 @@ Attribute ConstBytesAttr::parse(AsmParser& parser, Type type) { void ConstBytesAttr::print(mlir::AsmPrinter& printer) const { StringRef bytes_str = getValue(); - printer << " : \"0x" << llvm::toHex(bytes_str) << "\""; + // Elide the attribute if flag is set. 
+ std::optional limit = OpPrintingFlags().getLargeElementsAttrLimit(); + printer << " : \""; + if (limit && limit.value() < bytes_str.size()) { + printer << "__elided__"; + } else { + printer << "0x" << llvm::toHex(bytes_str); + } + printer << "\""; +} + +//===----------------------------------------------------------------------===// +// BitcastOp +//===----------------------------------------------------------------------===// + +int64_t GetTypeBitWidth(mlir::Type type) { + if (auto quant_type = type.dyn_cast()) { + return quant_type.getStorageTypeIntegralWidth(); + } + if (type.isIntOrFloat()) { + return std::max(type.getIntOrFloatBitWidth(), + static_cast(CHAR_BIT)); + } + return -1; +} + +LogicalResult BitcastOp::verify() { + BitcastOp op = *this; + auto input_type = op.getInput().getType().cast(); + auto output_type = op.getOutput().getType().cast(); + + auto input_element_type = input_type.getElementType(); + auto output_element_type = output_type.getElementType(); + + if (input_type.hasStaticShape()) { + const int input_element_type_bitwidth = GetTypeBitWidth(input_element_type); + const int output_element_type_bitwidth = + GetTypeBitWidth(output_element_type); + + if (input_element_type_bitwidth < 0 || output_element_type_bitwidth < 0) { + // Only supports quantized type, int and float types. + return op.emitOpError("Unsupported element type."); + } + + if (input_element_type_bitwidth < output_element_type_bitwidth) { + if (output_element_type_bitwidth % input_element_type_bitwidth != 0) { + return op.emitOpError( + "output element bitwidth is not multiple of input element " + "bitwidth"); + } + if (input_type.getShape().empty() || + input_type.getShape().back() % (output_element_type_bitwidth / + input_element_type_bitwidth) != + 0) { + return op.emitOpError( + "input rightmost dimension size is not multiple of the divisor"); + } + } else if (input_element_type_bitwidth > output_element_type_bitwidth) { + if (input_element_type_bitwidth % output_element_type_bitwidth != 0) { + return op.emitOpError( + "input element bitwidth is not multiple of output element " + "bitwidth"); + } + } + } + return success(); +} + +OpFoldResult BitcastOp::fold(FoldAdaptor adaptor) { + if (getType() == getInput().getType()) return getInput(); + return {}; +} + +//===----------------------------------------------------------------------===// +// DynamicUpdateSliceOp +//===----------------------------------------------------------------------===// + +OpFoldResult DynamicUpdateSliceOp::fold(FoldAdaptor) { + // Check if update replaces the whole tensor, meaning operand and update has + // the same shape and all start indices are zero. 
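Before continuing with the `DynamicUpdateSliceOp` fold, the `BitcastOp` verifier's bitwidth rules can be restated as a standalone predicate. This is a sketch over static shapes only; the op carries no options and derives all types from the surrounding tensors.

```c++
#include <cstdint>
#include <vector>

// Mirrors the verifier: widening requires the output width to be a multiple
// of the input width and the innermost dimension to split evenly into groups
// of that ratio (e.g. i8 -> i32 needs the last dimension to be a multiple of
// 4); narrowing only requires the input width to be a multiple of the output
// width; equal widths always pass.
bool BitcastWidthsCompatible(const std::vector<int64_t>& input_shape,
                             int64_t input_bits, int64_t output_bits) {
  if (input_bits <= 0 || output_bits <= 0) return false;  // unsupported types
  if (input_bits < output_bits) {
    if (output_bits % input_bits != 0) return false;
    const int64_t ratio = output_bits / input_bits;
    return !input_shape.empty() && input_shape.back() % ratio == 0;
  }
  if (input_bits > output_bits) return input_bits % output_bits == 0;
  return true;
}
```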
+ DenseIntElementsAttr indices_attr; + if (matchPattern(getStartIndices(), m_Constant(&indices_attr)) && + indices_attr.isSplat() && indices_attr.getSplatValue() == 0 && + getOperand().getType().hasStaticShape() && + getUpdate().getType().hasStaticShape() && + getOperand().getType() == getUpdate().getType()) { + return getUpdate(); + } + + return {}; } //===----------------------------------------------------------------------===// @@ -4133,7 +4232,7 @@ Operation* TFLDialect::materializeConstant(OpBuilder& builder, Attribute value, value.cast().getType() != type)) return builder.create(loc, type, value.cast()); if (arith::ConstantOp::isBuildableWith(value, type)) - return builder.create(loc, type, value); + return builder.create(loc, type, cast(value)); if (NoValueOp::isBuildableWith(value, type)) return builder.create(loc, type, value.cast()); return nullptr; diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.h b/tensorflow/compiler/mlir/lite/ir/tfl_ops.h index e1cb8de2c57..73740be2310 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.h +++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.h @@ -35,6 +35,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/ir/tfl_ops_enums.h.inc" #include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h" #include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h" +#include "tensorflow/compiler/mlir/lite/utils/utils.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_traits.h" #include "tensorflow/lite/schema/schema_generated.h" #define GET_ATTRDEF_CLASSES diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td index c7991a28ba7..8266fc605c0 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td +++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
@@ -2362,13 +2362,13 @@ equivalent to setting: }]; let arguments = (ins - TFL_TensorOf<[F32, I32, I64, I8, UI8, I1, TFL_Str, QI8, QUI8, TFL_Quint8, QI16]>:$input, + TFL_TensorOf<[F32, I32, I64, I8, UI8, UI32, I1, TFL_Str, QI8, QUI8, TFL_Quint8, QI16]>:$input, TFL_I32OrI64Tensor:$begin, TFL_I32OrI64Tensor:$size ); let results = (outs - TFL_TensorOf<[F32, I32, I64, I8, UI8, I1, TFL_Str, QI8, QUI8, TFL_Quint8, QI16]>:$output + TFL_TensorOf<[F32, I32, I64, I8, UI8, UI32, I1, TFL_Str, QI8, QUI8, TFL_Quint8, QI16]>:$output ); let hasVerifier = 1; @@ -2528,11 +2528,11 @@ def TFL_MulOp : TFL_Op<"mul", [ }]; let arguments = ( - ins TFL_TensorOf<[F32, I32, I64, QI8, QUI8, QI16, Complex>]>:$lhs, - TFL_TensorOf<[F32, I32, I64, QI8, QUI8, QI16, Complex>]>:$rhs, + ins TFL_TensorOf<[F32, I32, UI32, I64, QI8, QUI8, QI16, I16, Complex>]>:$lhs, + TFL_TensorOf<[F32, I32, UI32, I64, QI8, QUI8, QI16, I16, Complex>]>:$rhs, TFL_AFAttr:$fused_activation_function); - let results = (outs TFL_TensorOf<[F32, I32, I64, QI8, QUI8, QI16, Complex>]>:$output); + let results = (outs TFL_TensorOf<[F32, I32, UI32, I64, QI8, QUI8, QI16, I16, Complex>]>:$output); let hasFolder = 1; @@ -2612,14 +2612,14 @@ def TFL_PackOp : TFL_Op<"pack", [ }]; let arguments = (ins - TFL_VariadicTensorOf<[F32, I8, I16, I32, I64, UI8, QI8, QUI8, QI16, TFL_Quint8]>:$values, + TFL_VariadicTensorOf<[F32, I8, I16, I32, I64, UI8, UI32, QI8, QUI8, QI16, TFL_Quint8]>:$values, ConfinedAttr:$values_count, I32Attr:$axis ); let results = (outs - TFL_TensorOf<[F32, I8, I16, I32, I64, UI8, QI8, QUI8, QI16, TFL_Quint8]>:$output + TFL_TensorOf<[F32, I8, I16, I32, I64, UI8, UI32, QI8, QUI8, QI16, TFL_Quint8]>:$output ); let hasVerifier = 1; @@ -3128,11 +3128,11 @@ def TFL_SelectOp : TFL_Op<"select", [ let arguments = (ins TFL_BoolTensor:$condition, - TFL_TensorOf<[F32, I1, I8, I16, I32, I64, QI8, QUI8, QI16, TFL_Quint8]>:$x, - TFL_TensorOf<[F32, I1, I8, I16, I32, I64, QI8, QUI8, QI16, TFL_Quint8]>:$y); + TFL_TensorOf<[F32, I1, I8, I16, I32, I64, UI32, QI8, QUI8, QI16, TFL_Quint8]>:$x, + TFL_TensorOf<[F32, I1, I8, I16, I32, I64, UI32, QI8, QUI8, QI16, TFL_Quint8]>:$y); let results = (outs - TFL_TensorOf<[F32, I1, I8, I16, I32, I64, QI8, QUI8, QI16, TFL_Quint8]>:$output); + TFL_TensorOf<[F32, I1, I8, I16, I32, I64, UI32, QI8, QUI8, QI16, TFL_Quint8]>:$output); // TODO(jpienaar): autogenerate this. let builders = [ @@ -3167,11 +3167,11 @@ def TFL_SelectV2Op : TFL_Op<"select_v2", [ let arguments = (ins TFL_BoolTensor:$condition, - TFL_TensorOf<[F32, I1, I8, I16, I32, I64, QI8, QUI8, QI16, TFL_Quint8]>:$x, - TFL_TensorOf<[F32, I1, I8, I16, I32, I64, QI8, QUI8, QI16, TFL_Quint8]>:$y); + TFL_TensorOf<[F32, I1, I8, I16, I32, I64, UI32, QI8, QUI8, QI16, TFL_Quint8]>:$x, + TFL_TensorOf<[F32, I1, I8, I16, I32, I64, UI32, QI8, QUI8, QI16, TFL_Quint8]>:$y); let results = (outs - TFL_TensorOf<[F32, I1, I8, I16, I32, I64, QI8, QUI8, QI16, TFL_Quint8]>:$output); + TFL_TensorOf<[F32, I1, I8, I16, I32, I64, UI32, QI8, QUI8, QI16, TFL_Quint8]>:$output); let builders = [ OpBuilder<(ins "Value":$cond, "Value":$x, "Value":$y), @@ -3235,12 +3235,13 @@ def TFL_SoftmaxOp : TFL_Op<"softmax", [ // FixedOutputRangeInterface: quant::UniformQuantizedType GetFixedOutputRange( bool is_signed, int bit_width) { + if (bit_width != 8 && bit_width != 16) { return nullptr; } auto result_type = getOutput().getType(); // zero_point = 0 // scale = 1. 
/ (max_value + 1) return quant::GetFixedOutputRange(is_signed, bit_width, result_type, - /*scale=*/1.0 / (1<<(bit_width)), - /*zero_point=*/-(1<<(bit_width-1))); + /*scale=*/1.0 / (bit_width == 8 ? (1<<(bit_width)) : (1<<(bit_width-1))), + /*zero_point=*/bit_width == 8 ? -(1<<(bit_width-1)): 0); } }]; } @@ -3457,12 +3458,12 @@ def TFL_TopKV2Op: TFL_Op<"topk_v2", [ }]; let arguments = (ins - TFL_TensorOf<[F32, I8, I32, I64, UI8, QI8, QUI8]>:$input, - TFL_I32Tensor:$k); + TFL_TensorOf<[F32, I8, I16, I32, I64, UI8, QI8, QUI8]>:$input, + TFL_TensorOf<[I16, I32]>:$k); let results = (outs - TFL_TensorOf<[F32, I8, I32, I64, UI8, QI8, QUI8]>:$values, - TFL_I32Tensor:$indices); + TFL_TensorOf<[F32, I8, I16, I32, I64, UI8, QI8, QUI8]>:$values, + TFL_TensorOf<[I16, I32]>:$indices); let builders = [ OpBuilder<(ins "Value":$input, "Value":$k), @@ -3587,13 +3588,13 @@ def TFL_BatchToSpaceNdOp: TFL_Op<"batch_to_space_nd", [ }]; let arguments = (ins - TFL_TensorOf<[F32, I8, I32, I64, UI8, QI8, QUI8]>:$input, + TFL_TensorOf<[F32, I8, I32, I64, UI8, QI8, QUI8, QI16]>:$input, TFL_TensorOf<[I32]>:$block_shape, TFL_TensorOf<[I32]>:$indices ); let results = (outs - TFL_TensorOf<[F32, I16, I32, I64, UI8, QI8, QUI8]>:$output + TFL_TensorOf<[F32, I16, I32, I64, UI8, QI8, QUI8, QI16]>:$output ); } @@ -3612,13 +3613,13 @@ def TFL_SpaceToBatchNdOp: TFL_Op<"space_to_batch_nd", [ }]; let arguments = (ins - TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8]>:$input, + TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8, QI16]>:$input, TFL_I32Tensor:$block_shape, TFL_I32Tensor:$paddings ); let results = (outs - TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8]>:$output + TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8, QI16]>:$output ); } @@ -3863,7 +3864,7 @@ def TFL_StridedSliceOp: TFL_Op<"strided_slice", [ }]; let arguments = (ins - TFL_TensorOf<[F32, I32, I64, I8, UI8, QI8, QUI8, I1, I16, QI16, TFL_Quint8, TFL_Str]>:$input, + TFL_TensorOf<[F32, I32, I64, I8, UI8, UI32, QI8, QUI8, I1, I16, QI16, TFL_Quint8, TFL_Str]>:$input, TFL_I32Tensor:$begin, TFL_I32Tensor:$end, TFL_I32Tensor:$strides, @@ -3876,7 +3877,7 @@ def TFL_StridedSliceOp: TFL_Op<"strided_slice", [ ); let results = (outs - TFL_TensorOf<[F32, I32, I64, I8, UI8, QI8, QUI8, I1, I16, QI16, TFL_Quint8, TFL_Str]>:$output + TFL_TensorOf<[F32, I32, I64, I8, UI8, UI32, QI8, QUI8, I1, I16, QI16, TFL_Quint8, TFL_Str]>:$output ); // TFLite kernel only supports up to 5D input including added axis. @@ -4028,6 +4029,67 @@ def TFL_DynamicUpdateSliceOp: TFL_Op<"dynamic_update_slice", [ let results = ( outs TFL_TensorOf<[I1, I8, I32, I64, F32]>:$output); + + let hasFolder = 1; +} + +def TFL_BitcastOp : TFL_Op<"bitcast", [Pure]> { + let summary = "Bitcast operator"; + + let description = [{ + Bitcasts a tensor from one type to another. + }]; + + let arguments = (ins AnyTensor:$input); + + let results = (outs AnyTensor:$output); + + // TFLite's bitcast bitop does not utilize options, instead derives types + // from the TfLiteTensors. + let hasOptions = 0; + + let hasFolder = 1; + + let hasVerifier = 1; +} + +def TFL_BitwiseXorOp : TFL_Op<"bitwise_xor", [ + Commutative, + SameOperandsAndResultElementType, + Pure]> { + let summary = "Bitwise Xor operator"; + + let description = [{ + Elementwise computes the bitwise XOR of `lhs` and `rhs`. 
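For the `SoftmaxOp::GetFixedOutputRange` change above, the fixed quantization parameters now depend on bit width, and the representable output ranges follow from `real = scale * (q - zero_point)`. A quick standalone check of those ranges (a throwaway sketch, not library code):

```c++
#include <cstdint>
#include <cstdio>

// Dequantized value for a given quantized value q.
double Dequantize(int32_t q, double scale, int32_t zero_point) {
  return scale * (q - zero_point);
}

int main() {
  // 8-bit: scale = 1/256, zero_point = -128  -> covers [0, 255/256].
  std::printf("int8  range: [%f, %f]\n", Dequantize(-128, 1.0 / 256, -128),
              Dequantize(127, 1.0 / 256, -128));
  // 16-bit: scale = 1/32768, zero_point = 0  -> covers [-1, 32767/32768].
  std::printf("int16 range: [%f, %f]\n", Dequantize(-32768, 1.0 / 32768, 0),
              Dequantize(32767, 1.0 / 32768, 0));
  return 0;
}
```

Both ranges comfortably contain the softmax output interval [0, 1) for 8-bit and the signed [-1, 1) convention used for 16-bit.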
+ }]; + + let arguments = (ins + TFL_TensorOf<[I8, UI8, I16, UI16, I32, UI32]>:$lhs, + TFL_TensorOf<[I8, UI8, I16, UI16, I32, UI32]>:$rhs + ); + + let results = (outs + TFL_TensorOf<[I8, UI8, I16, UI16, I32, UI32]>:$output + ); +} + +def TFL_RightShiftOp : TFL_Op<"right_shift", [ + SameOperandsAndResultElementType, + Pure]> { + let summary = "Right Shift operator"; + + let description = [{ + Elementwise computes the bitwise right-shift of `lhs` by `rhs`. + }]; + + let arguments = (ins + TFL_TensorOf<[I8, UI8, I16, UI16, I32, UI32]>:$lhs, + TFL_TensorOf<[I8, UI8, I16, UI16, I32, UI32]>:$rhs + ); + + let results = (outs + TFL_TensorOf<[I8, UI8, I16, UI16, I32, UI32]>:$output + ); } //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.cc b/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.cc index 2249909aa0e..85c87fd66ad 100644 --- a/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.cc +++ b/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.cc @@ -35,10 +35,10 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h" #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" #include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/graph_debug_info.pb.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/platform/status.h" -#include "tensorflow/core/protobuf/graph_debug_info.pb.h" #include "tensorflow/lite/toco/model_flags.pb.h" #include "tensorflow/lite/toco/toco_flags.pb.h" #include "tensorflow/lite/toco/types.pb.h" diff --git a/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.h b/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.h index 6d90c2d08f4..e69d3c718d9 100644 --- a/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.h +++ b/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.h @@ -16,8 +16,8 @@ limitations under the License. #define TENSORFLOW_COMPILER_MLIR_LITE_PYTHON_GRAPHDEF_TO_TFL_FLATBUFFER_H_ #include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/graph_debug_info.pb.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/protobuf/graph_debug_info.pb.h" #include "tensorflow/lite/toco/model_flags.pb.h" #include "tensorflow/lite/toco/toco_flags.pb.h" diff --git a/tensorflow/compiler/mlir/lite/python/jax_to_tfl_flatbuffer.cc b/tensorflow/compiler/mlir/lite/python/jax_to_tfl_flatbuffer.cc index 998734c8d2a..1b0f22c7cd1 100644 --- a/tensorflow/compiler/mlir/lite/python/jax_to_tfl_flatbuffer.cc +++ b/tensorflow/compiler/mlir/lite/python/jax_to_tfl_flatbuffer.cc @@ -44,11 +44,11 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/translate/hlo_to_mhlo/hlo_to_mlir_hlo.h" #include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/graph_debug_info.pb.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/status.h" -#include "tensorflow/core/protobuf/graph_debug_info.pb.h" #include "tensorflow/lite/toco/model_flags.pb.h" #include "tensorflow/lite/toco/toco_flags.pb.h" #include "tensorflow/lite/toco/types.pb.h" diff --git a/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc b/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc index 9c3ab396b5d..74c09b3e9e6 100644 --- a/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc +++ b/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc @@ -39,10 +39,10 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h" #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" #include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/graph_debug_info.pb.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/platform/status.h" -#include "tensorflow/core/protobuf/graph_debug_info.pb.h" #include "tensorflow/lite/toco/model_flags.pb.h" #include "tensorflow/lite/toco/toco_flags.pb.h" #include "tensorflow/lite/toco/types.pb.h" diff --git a/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.h b/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.h index ed339ca64b9..362e9e39ae5 100644 --- a/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.h +++ b/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.h @@ -16,8 +16,8 @@ limitations under the License. #define TENSORFLOW_COMPILER_MLIR_LITE_PYTHON_SAVED_MODEL_TO_TFL_FLATBUFFER_H_ #include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/graph_debug_info.pb.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/protobuf/graph_debug_info.pb.h" #include "tensorflow/lite/toco/model_flags.pb.h" #include "tensorflow/lite/toco/toco_flags.pb.h" diff --git a/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc b/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc index bb36ebe81f4..5cfbc0c937a 100644 --- a/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc +++ b/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc @@ -35,10 +35,10 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" #include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h" #include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/graph_debug_info.pb.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/platform/status.h" -#include "tensorflow/core/protobuf/graph_debug_info.pb.h" #include "tensorflow/lite/toco/model_flags.pb.h" #include "tensorflow/lite/toco/toco_flags.pb.h" #include "tensorflow/lite/toco/types.pb.h" diff --git a/tensorflow/compiler/mlir/lite/quantization/ir/ConvertConst.cc b/tensorflow/compiler/mlir/lite/quantization/ir/ConvertConst.cc index 89eec9c7349..ae9b67e9e60 100644 --- a/tensorflow/compiler/mlir/lite/quantization/ir/ConvertConst.cc +++ b/tensorflow/compiler/mlir/lite/quantization/ir/ConvertConst.cc @@ -16,6 +16,7 @@ limitations under the License. #include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributeInterfaces.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/Matchers.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project @@ -96,7 +97,7 @@ LogicalResult QuantizedConstRewrite::matchAndRewrite( auto fusedLoc = rewriter.getFusedLoc( {qbarrier.getArg().getDefiningOp()->getLoc(), qbarrier.getLoc()}); auto newConstOp = rewriter.create( - fusedLoc, newConstValueType, newConstValue); + fusedLoc, newConstValueType, cast(newConstValue)); rewriter.replaceOpWithNewOp(qbarrier, qbarrier.getType(), newConstOp); return success(); diff --git a/tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.cc b/tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.cc index 3bd80ad4a7b..d111141958c 100644 --- a/tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.cc +++ b/tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.cc @@ -104,7 +104,7 @@ LogicalResult StatisticsOp::verify() { // Verify layerStats attribute. { - auto layerStatsType = getLayerStats().getType(); + auto layerStatsType = getLayerStats().getShapedType(); if (!layerStatsType.getElementType().isa()) { return emitOpError("layerStats must have a floating point element type"); } @@ -121,7 +121,7 @@ LogicalResult StatisticsOp::verify() { std::accumulate(std::next(shape.begin(), *getAxis()), shape.end(), 1, std::multiplies()); - auto axisStatsType = getAxisStats()->getType(); + auto axisStatsType = getAxisStats()->getShapedType(); if (!axisStatsType.getElementType().isa()) { return emitOpError("axisStats must have a floating point element type"); } diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.cc b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.cc index e21105fc5c4..29216f3be16 100644 --- a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.cc +++ b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.cc @@ -124,7 +124,7 @@ TfLiteStatus QuantizeModel( // If the first or final ops are not quantized, remove QDQ. 
pm.addPass(TFL::CreatePostQuantizeRemoveQDQPass()); if (failed(pm.run(module.get()))) { - const std::string& err = statusHandler.ConsumeStatus().error_message(); + const std::string err(statusHandler.ConsumeStatus().message()); error_reporter->Report("Failed to quantize: %s", err.c_str()); return kTfLiteError; } diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_weights.cc b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_weights.cc index ce87e5d8f92..e784cf7a2eb 100644 --- a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_weights.cc +++ b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_weights.cc @@ -147,7 +147,7 @@ TfLiteStatus QuantizeWeights( tensorflow::AddDynamicRangeQuantizationPasses(quant_specs, pm); if (failed(pm.run(module.get()))) { - absl::string_view err = statusHandler.ConsumeStatus().error_message(); + absl::string_view err = statusHandler.ConsumeStatus().message(); error_reporter->Report("Failed to quantize: %s", err); return kTfLiteError; } diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_driver.cc b/tensorflow/compiler/mlir/lite/quantization/quantization_driver.cc index 91bde7ede70..57a5c93556c 100644 --- a/tensorflow/compiler/mlir/lite/quantization/quantization_driver.cc +++ b/tensorflow/compiler/mlir/lite/quantization/quantization_driver.cc @@ -335,7 +335,7 @@ class QuantizationDriver { fn_.walk([&](Operation *op) { std::unique_ptr scale_spec = GetQuantScaleSpec(op); if (op->hasTrait() || - (IsOpNotQuantizable(op) && !scale_spec->has_same_scale_requirement) || + (!IsOpQuantizable(op) && !scale_spec->has_same_scale_requirement) || llvm::isa(op)) { return; @@ -841,7 +841,7 @@ void QuantizationDriver::SetupAllStates() { fn_.walk([&](Operation *op) { std::unique_ptr scale_spec = GetQuantScaleSpec(op); - if (IsOpNotQuantizable(op) && !scale_spec->has_same_scale_requirement) { + if (!IsOpQuantizable(op) && !scale_spec->has_same_scale_requirement) { return; } work_list_.push_back(op); diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_utils.cc b/tensorflow/compiler/mlir/lite/quantization/quantization_utils.cc index c559dd2403f..9a151a80e8f 100644 --- a/tensorflow/compiler/mlir/lite/quantization/quantization_utils.cc +++ b/tensorflow/compiler/mlir/lite/quantization/quantization_utils.cc @@ -16,12 +16,15 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h" #include +#include #include +#include #include #include #include #include #include +#include #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" @@ -187,27 +190,26 @@ quant::UniformQuantizedPerAxisType ResetAxisAndBroadcast( } // namespace -bool IsOpNotQuantizable(Operation* op) { - // If it is terminator or not quantizable or any ops form the mlir quant - // ops dialect, we shouldn't rewrite. - bool attr_enforced_quantizable = +bool IsOpQuantizable(Operation* op) { + if (isa(op)) { + // Constant ops do not have QuantizableResult attribute but they can deal + // with quantized tensors. + return true; + } else if (op->hasTrait() || + isa(op)) { + // Terminators, qcast and decast are not quantizable. + return false; + } + + const bool attr_enforced_quantizable = op->hasAttrOfType(kQuantTraitAttrName) && op->getAttrOfType(kQuantTraitAttrName).getValue().str() == QuantTraitValues[QuantizationTrait::FullyQuantizable]; - // Constant ops do not have QuantizableResult attribute but they can deal with - // quantized tensors. 
- if (llvm::isa( - op)) - return false; - - bool prop_enforced_quantizable = + const bool trait_enforced_quantizable = op->hasTrait(); - return op->hasTrait() || - llvm::isa( - op) || - (!attr_enforced_quantizable && !prop_enforced_quantizable); + return attr_enforced_quantizable || trait_enforced_quantizable; } // Returns the quantized type for the diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h b/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h index b91db7965b0..1113bb868fa 100644 --- a/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h +++ b/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h @@ -20,31 +20,38 @@ limitations under the License. #define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_UTILS_H_ #include +#include +#include #include #include #include #include #include +#include +#include #include "absl/container/flat_hash_set.h" #include "absl/strings/string_view.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" -#include "llvm/ADT/Twine.h" #include "llvm/Support/Casting.h" -#include "llvm/Support/Debug.h" -#include "llvm/Support/raw_ostream.h" #include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project -#include "mlir/IR/IRMapping.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project -#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/IRMapping.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/Matchers.h" // from @llvm-project #include "mlir/IR/OpDefinition.h" // from @llvm-project +#include "mlir/IR/OperationSupport.h" // from @llvm-project #include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/quantization/ir/FakeQuantSupport.h" @@ -183,7 +190,7 @@ quant::QuantizedType DownCastScale(quant::QuantizedType type, quant::QuantizedType DownCastScale(quant::QuantizedType type, double min, double max, Location loc); -bool IsOpNotQuantizable(Operation* op); +bool IsOpQuantizable(Operation* op); // Specialized version of location to string for flatbuffer exported locations. inline std::string GetTensorNameFromLoc(Location loc) { @@ -439,7 +446,7 @@ class QuantizationPattern : public RewritePattern { return failure(); } - if (IsOpNotQuantizable(quantizing_op) && + if (!IsOpQuantizable(quantizing_op) && !static_cast(this)->IsQuantizableCustomOp( quantizing_op, custom_map)) { if (!(enable_verify && enable_whole_model_verify)) { @@ -646,7 +653,7 @@ class QuantizationPattern : public RewritePattern { // compared against in parallel. // N.B. the return op will use this floating-point result. Value result; - if (IsOpNotQuantizable(float_op)) { + if (!IsOpQuantizable(float_op)) { // For not quantizable ops, search for dequantize attached to the // quantized op of the output. 
if (Operation* quantize_op = dyn_cast_or_null( diff --git a/tensorflow/compiler/mlir/lite/sparsity/sparsify_model.cc b/tensorflow/compiler/mlir/lite/sparsity/sparsify_model.cc index a9614c0e62c..8c9035f2184 100644 --- a/tensorflow/compiler/mlir/lite/sparsity/sparsify_model.cc +++ b/tensorflow/compiler/mlir/lite/sparsity/sparsify_model.cc @@ -65,7 +65,7 @@ TfLiteStatus SparsifyModel(const tflite::ModelT& input_model, pm.addPass(TFL::CreateDenseToSparsePass()); if (failed(pm.run(module.get()))) { - const std::string& err = statusHandler.ConsumeStatus().error_message(); + const std::string err(statusHandler.ConsumeStatus().message()); error_reporter->Report("Failed to sparsify: %s", err.c_str()); return kTfLiteError; } diff --git a/tensorflow/compiler/mlir/lite/stablehlo/BUILD b/tensorflow/compiler/mlir/lite/stablehlo/BUILD index 27ca62b52cb..258da6bc55c 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/BUILD +++ b/tensorflow/compiler/mlir/lite/stablehlo/BUILD @@ -132,7 +132,7 @@ cc_library( ":stablehlo_util", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:lower_tf_lib", - "//tensorflow/compiler/mlir/tf2xla:xla_legalize_tf", + "//tensorflow/compiler/mlir/tf2xla/transforms:xla_legalize_tf", "//tensorflow/compiler/xla/mlir_hlo", "//tensorflow/compiler/xla/mlir_hlo:hlo_dialect_registration", "//tensorflow/compiler/xla/mlir_hlo:mhlo_passes", @@ -199,7 +199,7 @@ cc_library( "//tensorflow/compiler/mlir/quantization/tensorflow:passes", "//tensorflow/compiler/mlir/tensorflow:tensorflow_passes", "//tensorflow/compiler/mlir/tensorflow:tf_saved_model_passes", - "//tensorflow/compiler/mlir/tf2xla:tf_xla_passes", + "//tensorflow/compiler/mlir/tf2xla/transforms:tf_xla_passes", "//tensorflow/compiler/xla/mlir_hlo:mhlo_passes", "@llvm-project//mlir:Pass", "@llvm-project//mlir:Transforms", @@ -317,6 +317,33 @@ cc_library( alwayslink = 1, ) +cc_library( + name = "legalize_tf_xla_call_module_to_stablehlo_pass", + srcs = [ + "transforms/legalize_tf_xla_call_module_to_stablehlo_pass.cc", + ], + hdrs = [ + "transforms/legalize_tf_xla_call_module_to_stablehlo_pass.h", + ], + copts = [ + "-Ithird_party", + ], + deps = [ + "//tensorflow/compiler/mlir/tensorflow", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:QuantOps", + "@llvm-project//mlir:ShapeDialect", + "@llvm-project//mlir:Transforms", + "@stablehlo//:stablehlo_ops", + "@stablehlo//:stablehlo_serialization", + "@stablehlo//:vhlo_ops", + ], + alwayslink = 1, +) + cc_library( name = "optimize", srcs = [ @@ -362,7 +389,7 @@ tf_cc_binary( "//tensorflow/compiler/mlir/tensorflow:tensorflow_passes", "//tensorflow/compiler/mlir/tensorflow:tf_graph_optimization_pass", "//tensorflow/compiler/mlir/tf2xla:compile_mlir_util", - "//tensorflow/compiler/mlir/tf2xla:legalize_tf", + "//tensorflow/compiler/mlir/tf2xla/transforms:legalize_tf", "//tensorflow/compiler/xla/mlir/framework/transforms:passes", "//tensorflow/compiler/xla/mlir_hlo:all_passes", "//tensorflow/compiler/xla/mlir_hlo:hlo_dialect_registration", @@ -387,6 +414,7 @@ tf_cc_binary( deps = [ ":fold_broadcast_pass", ":fuse_convolution_pass", + ":legalize_tf_xla_call_module_to_stablehlo_pass", ":optimize", ":stablehlo_tfl", ":tf_stablehlo", diff --git a/tensorflow/compiler/mlir/lite/stablehlo/odml_to_stablehlo.cc b/tensorflow/compiler/mlir/lite/stablehlo/odml_to_stablehlo.cc index 88330c7356b..525d73c1b79 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/odml_to_stablehlo.cc 
+++ b/tensorflow/compiler/mlir/lite/stablehlo/odml_to_stablehlo.cc @@ -115,6 +115,11 @@ opt elide_large_elements_attrs( "e", llvm::cl::desc("Elide large elements attrs."), llvm::cl::Optional, llvm::cl::init(false)); +// NOLINTNEXTLINE +opt debug_info( + "debug-info", llvm::cl::desc("Include MLIR debug location info in output."), + llvm::cl::Optional, llvm::cl::init(false)); + // NOLINTNEXTLINE opt allow_tf("allow-tf", llvm::cl::desc("Allow TF dialect."), llvm::cl::Optional, llvm::cl::init(false)); @@ -143,6 +148,11 @@ opt freeze_tf_graph( llvm::cl::desc("Freeze TF graph to remove tf.ResourceVariable, etc."), llvm::cl::Optional, llvm::cl::init(false)); +// NOLINTNEXTLINE +opt exported_model_signatures( + "exported_model_signatures", llvm::cl::desc("Comma-separated list of exported model signatures."), + llvm::cl::Optional, llvm::cl::init("serving_default")); + namespace mlir { namespace odml { @@ -165,7 +175,8 @@ tensorflow::StatusOr> ImportSavedModelOrMLIR( // TODO(pulkitb): Remove hard-coded tag. std::unordered_set tags({"serve"}); - auto exported_names_in_vector = std::vector({}); + std::vector exported_names_in_vector = + absl::StrSplit(exported_model_signatures, ','); absl::Span exported_names(exported_names_in_vector); std::vector custom_opdefs; @@ -217,6 +228,9 @@ tensorflow::Status ExportModule(mlir::ModuleOp module, std::string result; llvm::raw_string_ostream os(result); OpPrintingFlags printing_flags; + if (debug_info) { + printing_flags.enableDebugInfo(); + } if (elide_large_elements_attrs) { printing_flags.elideLargeElementsAttrs(); } @@ -232,7 +246,10 @@ tensorflow::Status ExportModule(mlir::ModuleOp module, tensorflow::Status ConvertTFToStableHLO( ModuleOp tf_module, const PassPipelineCLParser& pass_pipeline) { PassManager pm(tf_module.getContext()); - applyPassManagerCLOptions(pm); + if (failed(applyPassManagerCLOptions(pm))) { + return tensorflow::errors::Aborted( + "Failed to apply MLIR pass manager CL options."); + } auto error_handler = [&](const Twine& msg) { emitError(UnknownLoc::get(pm.getContext())) << msg; @@ -330,14 +347,14 @@ tensorflow::Status RunConverter(const PassPipelineCLParser& pass_pipeline) { ExportModule(*module, output_path, elide_large_elements_attrs); if (!conversion_status.ok()) { LOG(ERROR) << "TF to StableHLO conversion failed: " - << conversion_status.error_message(); + << conversion_status.message(); auto debug_export_status = ExportModule( *module, absl::StrCat(verbose_dir, "/debug_stablehlo.mlir"), elide_large_elements_attrs); if (!debug_export_status.ok()) { LOG(ERROR) << "Failed to export debug_stablehlo.mlir: " - << debug_export_status.error_message(); + << debug_export_status.message(); } return conversion_status; diff --git a/tensorflow/compiler/mlir/lite/stablehlo/serializer/flatbuffer_translator.cc b/tensorflow/compiler/mlir/lite/stablehlo/serializer/flatbuffer_translator.cc index 74708abebe8..29f4977b1bf 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/serializer/flatbuffer_translator.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/serializer/flatbuffer_translator.cc @@ -759,7 +759,7 @@ Translator::BuildSubGraph(const std::string& name, Region* region, int index) { } bool failed_once = false; - for (auto& item : llvm::enumerate(bb)) { + for (const auto& item : llvm::enumerate(bb)) { Operation& inst = item.value(); const int operation_index = item.index(); if (inst.hasTrait()) break; diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/fuse_convolution_pass.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/fuse_convolution_pass.cc index
45c8edc1ec5..9918d044c5b 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/fuse_convolution_pass.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/fuse_convolution_pass.cc @@ -84,7 +84,7 @@ class FuseMhloMulAndConvolutionPattern : public OpRewritePattern { // Only fuses multiplier if all dimensions other than the out channel // dimension are equal to 1. if (!TFL::IsDimensionsDegenerateExceptLastOne( - mul_value.getType().getShape())) { + mul_value.getShapedType().getShape())) { return rewriter.notifyMatchFailure(mul_op, [&](::mlir::Diagnostic &diag) { diag << "entities 'mul_value' failed to satisfy constraint: " "unsupported dimensions"; @@ -97,9 +97,10 @@ class FuseMhloMulAndConvolutionPattern : public OpRewritePattern { } // Rewrite - broadcast_dims = broadcast_op.getBroadcastDimensions(); + broadcast_dims = + broadcast_op ? broadcast_op.getBroadcastDimensions() : nullptr; if (broadcast_dims == nullptr) { - const auto filter_rank = filter_value.getType().getRank(); + const auto filter_rank = filter_value.getShapedType().getRank(); auto dimsType = RankedTensorType::get({1}, rewriter.getIntegerType(64)); broadcast_dims = DenseIntElementsAttr::get(dimsType, {filter_rank - 1}); } diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_tf_xla_call_module_to_stablehlo_pass.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_tf_xla_call_module_to_stablehlo_pass.cc new file mode 100644 index 00000000000..b7277ae0415 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_tf_xla_call_module_to_stablehlo_pass.cc @@ -0,0 +1,176 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_tf_xla_call_module_to_stablehlo_pass.h" + +#include +#include +#include +#include + +#include "llvm/ADT/SmallVector.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project +#include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project +#include "stablehlo/dialect/Serialization.h" // from @stablehlo +#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo +#include "stablehlo/dialect/VhloOps.h" // from @stablehlo +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" + +namespace mlir { +namespace odml { + +static constexpr std::string_view kStablehloModuleDefaultEntryFuncName = "main"; +static constexpr std::string_view kStablehloFuncNamePrefix = "XlaCallModule"; + +class ConvertTFXlaCallModuleOp + : public mlir::OpRewritePattern<mlir::TF::XlaCallModuleOp> { + public: + explicit ConvertTFXlaCallModuleOp(MLIRContext *context, ModuleOp module_op) + : OpRewritePattern(context), + module_op_(module_op) {} + using OpRewritePattern::OpRewritePattern; + + private: + ModuleOp module_op_; + mlir::LogicalResult matchAndRewrite( + mlir::TF::XlaCallModuleOp op, PatternRewriter &rewriter) const override { + mlir::OwningOpRef<ModuleOp> stablehlo_module_op = + mlir::stablehlo::deserializePortableArtifact(op.getModuleAttr(), + getContext()); + if (stablehlo_module_op.get() == nullptr) { + return mlir::failure(); + } + SymbolTable parent_module_symbol_table(module_op_); + SymbolTable stablehlo_module_symbol_table(stablehlo_module_op.get()); + if (stablehlo_module_symbol_table.lookup( + kStablehloModuleDefaultEntryFuncName) == nullptr) { + return rewriter.notifyMatchFailure( + op, "could not find main function in XlaCallModuleOp"); + } + mlir::Builder stablehlo_builder(stablehlo_module_op.get().getContext()); + // Rename XlaCallModuleOp's functions to avoid naming conflicts. + for (auto func_op : + stablehlo_module_op.get().getOps<func::FuncOp>()) { + const std::string new_func_name = + CreateNewFuncName(func_op.getSymName(), parent_module_symbol_table); + if (failed(stablehlo_module_symbol_table.replaceAllSymbolUses( + func_op, stablehlo_builder.getStringAttr(new_func_name), + stablehlo_module_op.get()))) { + return mlir::failure(); + } + mlir::SymbolTable::setSymbolName(func_op, new_func_name); + } + // Move all functions from the XlaCallModuleOp's stablehlo module to the + // parent module, and mark the stablehlo module's entry function as private.
+ mlir::func::FuncOp main_fn; + for (auto func_op : + stablehlo_module_op.get().getOps<func::FuncOp>()) { + mlir::func::FuncOp cloned_func_op = func_op.clone(); + if (cloned_func_op.getSymName().contains( + kStablehloModuleDefaultEntryFuncName)) { + main_fn = cloned_func_op; + main_fn.setSymVisibility(stablehlo_builder.getStringAttr("private")); + } + parent_module_symbol_table.insert(cloned_func_op); + } + + // The stablehlo module main function's input tensor types might be + // different from the XlaCallModuleOp's input tensor types. For example, + // the XlaCallModuleOp's input is tensor<*xf32> while the function's + // argument type is tensor<1x2xf32>. + llvm::SmallVector<Value> casted_operands; + casted_operands.reserve(main_fn.getNumArguments()); + for (const auto &operand_and_type : + zip(op.getOperands(), main_fn.getFunctionType().getInputs())) { + Value operand = std::get<0>(operand_and_type); + Type expected_type = std::get<1>(operand_and_type); + if (operand.getType() != expected_type) { + operand = rewriter.create<TF::CastOp>( + op.getLoc(), expected_type, operand, + /*Truncate=*/rewriter.getBoolAttr(false)); + } + casted_operands.push_back(operand); + } + + auto call = rewriter.create<func::CallOp>( + op->getLoc(), main_fn.getSymName(), main_fn.getResultTypes(), + casted_operands); + rewriter.replaceOp(op, call->getResults()); + + return mlir::success(); + } + + // Creates a new function name to avoid collision. The naming scheme is + // XlaCallModule_%s_%d where %s is the original function name and %d is the + // counter. + std::string CreateNewFuncName(const StringRef func_name, + SymbolTable &symbol_table) const { + int suffix_id = 0; + std::string new_func_name = absl::StrCat(kStablehloFuncNamePrefix, "_", + func_name.str(), "_", suffix_id); + while (symbol_table.lookup(new_func_name)) { + suffix_id++; + new_func_name = absl::StrCat(kStablehloFuncNamePrefix, "_", + func_name.str(), "_", suffix_id); + } + return new_func_name; + } +}; + +class TFXlaCallModuleOpToStablehloPass + : public PassWrapper<TFXlaCallModuleOpToStablehloPass, OperationPass<ModuleOp>> { + public: + StringRef getArgument() const final { + return "tf-xla-call-module-op-to-stablehlo-pass"; + } + StringRef getDescription() const final { + return "Legalize TF_XlaCallModule Op to stablehlo"; + } + void getDependentDialects(::mlir::DialectRegistry &registry) const override { + registry.insert(); + } + + void runOnOperation() override { + ModuleOp module_op = getOperation(); + RewritePatternSet patterns(&getContext()); + patterns.add<ConvertTFXlaCallModuleOp>(&getContext(), module_op); + if (failed(applyPatternsAndFoldGreedily(module_op, std::move(patterns)))) { + return signalPassFailure(); + } + } +}; + +std::unique_ptr<OperationPass<ModuleOp>> +CreateLegalizeTFXlaCallModuleToStablehloPass() { + return std::make_unique<TFXlaCallModuleOpToStablehloPass>(); +} + +static PassRegistration<TFXlaCallModuleOpToStablehloPass> pass; + +} // namespace odml +} // namespace mlir diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_tf_xla_call_module_to_stablehlo_pass.h b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_tf_xla_call_module_to_stablehlo_pass.h new file mode 100644 index 00000000000..9bcee095f27 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_tf_xla_call_module_to_stablehlo_pass.h @@ -0,0 +1,35 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License.
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_TF_XLA_CALL_MODULE_TO_STABLEHLO_PASS_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_TF_XLA_CALL_MODULE_TO_STABLEHLO_PASS_H_ + +#include <memory> + +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project + +namespace mlir { +namespace odml { + +// Creates a pass that transforms TF_XlaCallModule ops into StableHLO ops. +// Note that this pass only supports static shape tensors for now. +std::unique_ptr<OperationPass<ModuleOp>> +CreateLegalizeTFXlaCallModuleToStablehloPass(); + +} // namespace odml +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_TF_XLA_CALL_MODULE_TO_STABLEHLO_PASS_H_ diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/transforms.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/transforms.cc index 14cbe5963e3..476b02e0bd8 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/transforms.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/transforms.cc @@ -65,7 +65,6 @@ void AddTFToStablehloPasses(OpPassManager& pm, bool skip_resize, pm.addNestedPass(CreateSmuggleDisallowedOpsPass()); pm.addPass(mlir::createCanonicalizerPass()); } - pm.addPass(CreateDropSavedModelSemanticsPass()); } void AddStablehloOptimizationPasses(OpPassManager& pm) { diff --git a/tensorflow/compiler/mlir/lite/tests/canonicalize.mlir b/tensorflow/compiler/mlir/lite/tests/canonicalize.mlir index 082c0627b09..17b724051cd 100644 --- a/tensorflow/compiler/mlir/lite/tests/canonicalize.mlir +++ b/tensorflow/compiler/mlir/lite/tests/canonicalize.mlir @@ -326,3 +326,21 @@ func.func @broadcast_to_to_reshape_i64_const(%arg0: tensor<4x4x4xf32>) -> tensor // CHECK-SAME: (tensor<4x4x4xf32>, tensor<4xi32>) -> tensor<1x4x4x4xf32> func.return %0 : tensor<1x4x4x4xf32> } + +// ----- + +func.func @trivial_dynamic_update_slice(%arg0: tensor<2x7x14xf32>, %arg1: tensor<2x7x14xf32>) -> tensor<2x7x14xf32> { + %0 = arith.constant dense<0> : tensor<3xi32> + %1 = "tfl.dynamic_update_slice"(%arg0, %arg1, %0) : (tensor<2x7x14xf32>, tensor<2x7x14xf32>, tensor<3xi32>) -> tensor<2x7x14xf32> + // CHECK: return %arg1 + func.return %1 : tensor<2x7x14xf32> +} + +// ----- + +func.func @trivial_dynamic_update_slice_wrong_update_shape(%arg0: tensor<2x7x14xf32>, %arg1: tensor<2x7x7xf32>) -> tensor<2x7x14xf32> { + %0 = arith.constant dense<0> : tensor<3xi32> + %1 = "tfl.dynamic_update_slice"(%arg0, %arg1, %0) : (tensor<2x7x14xf32>, tensor<2x7x7xf32>, tensor<3xi32>) -> tensor<2x7x14xf32> + // CHECK: "tfl.dynamic_update_slice" + func.return %1 : tensor<2x7x14xf32> +} diff --git a/tensorflow/compiler/mlir/lite/tests/end2end/disallow_stateful_partitioned_call.pbtxt b/tensorflow/compiler/mlir/lite/tests/end2end/disallow_stateful_partitioned_call.pbtxt new file mode 100644 index 00000000000..db6998bd7d9 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/tests/end2end/disallow_stateful_partitioned_call.pbtxt @@ -0,0 +1,195 @@ +# RUN: not tf_tfl_translate -tf-input-arrays=input0
-tf-input-shapes=-1 -tf-input-data-types=DT_FLOAT -tf-output-arrays=add %s 2>&1 | FileCheck %s +# CHECK: error: The Graph contains unsupported `StatefulPartionedCallOp`(s) + +node { + name: "input0" + op: "Placeholder" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } +} +node { + name: "args_0" + op: "_Arg" + attr { + key: "T" + value { + type: DT_RESOURCE + } + } + attr { + key: "index" + value { + i: 0 + } + } +} +node { + name: "spc1" + op: "StatefulPartitionedCall" + input: "input0" + input: "args_0" + attr { + key: "Tin" + value { + list { + type: DT_FLOAT + type: DT_RESOURCE + } + } + } + attr { + key: "Tout" + value { + list { + type: DT_FLOAT + type: DT_RESOURCE + } + } + } + attr { + key: "config" + value { + s: "" + } + } + attr { + key: "config_proto" + value { + s: "" + } + } + attr { + key: "executor_type" + value { + s: "" + } + } + attr { + key: "f" + value { + func { + name: "function" + } + } + } +} +node { + name: "spc2" + op: "StatefulPartitionedCall" + input: "input0" + input: "args_0" + attr { + key: "Tin" + value { + list { + type: DT_FLOAT + type: DT_RESOURCE + } + } + } + attr { + key: "Tout" + value { + list { + type: DT_FLOAT + type: DT_RESOURCE + } + } + } + attr { + key: "config" + value { + s: "" + } + } + attr { + key: "config_proto" + value { + s: "" + } + } + attr { + key: "executor_type" + value { + s: "" + } + } + attr { + key: "f" + value { + func { + name: "function" + } + } + } +} +node { + name: "add" + op: "Add" + input: "spc1" + input: "spc2" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +library { + function { + signature { + name: "function" + input_arg { + name: "inputs" + type: DT_FLOAT + } + input_arg { + name: "statefulpartitionedcall_args_1" + type: DT_RESOURCE + } + output_arg { + name: "identity" + type: DT_FLOAT + } + is_stateful: true + } + node_def { + name: "Identity" + op: "Identity" + input: "inputs" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + } + ret { + key: "identity" + value: "Identity:output:0" + } + arg_attr { + key: 0 + value { + attr { + key: "_user_specified_name" + value { + s: "inputs" + } + } + } + } + arg_attr { + key: 1 + value { + } + } + } +} +versions { + producer: 121 +} diff --git a/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir index 109ab804748..4f58b7af868 100644 --- a/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir +++ b/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir @@ -1865,26 +1865,29 @@ func.func @maximum_with_6d_broadcasting(%arg0: tensor<1x1x1x1x8x16xf32>, %arg1: // ----- -func.func @add_with_int32_5d_inputs(%arg0: tensor<1x1x1x3x1xi32>, %arg1 : tensor<1x1x1x1x4xi32>) -> tensor<1x1x1x3x4xi32> { +func.func @test5DAddWithImplicitBroadcast(%arg0: tensor<1x1x1x3x1xi32>, %arg1 : tensor<1x1x1x1x4xi32>) -> tensor<1x1x1x3x4xi32> { %0 = "tf.Add"(%arg0, %arg1): (tensor<1x1x1x3x1xi32>, tensor<1x1x1x1x4xi32>) -> tensor<1x1x1x3x4xi32> func.return %0 : tensor<1x1x1x3x4xi32> -// CHECK-LABEL: add_with_int32_5d_inputs -// CHECK: [[CST:%.*]] = arith.constant dense<[1, 1, 1, 3, 4]> : tensor<5xi64> -// CHECK: [[BCT:%.*]] = "tfl.broadcast_to"(%arg0, [[CST]]) -// CHECK: [[BCT_0:%.*]] = "tfl.broadcast_to"(%arg1, [[CST]]) -// CHECK: tfl.add [[BCT]], [[BCT_0]] +// CHECK-LABEL: test5DAddWithImplicitBroadcast +// CHECK: %0 = tfl.add(%arg0, %arg1) {fused_activation_function = "NONE"} : (tensor<1x1x1x3x1xi32>, tensor<1x1x1x1x4xi32>) -> tensor<1x1x1x3x4xi32> } -// CHECK-LABEL: testAddWithBroadcastToOps -func.func 
@testAddWithBroadcastToOps(%arg0: tensor<1x2x1x4x5x6xi32>, %arg1: tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32> { - // CHECK: [[CST:%.*]] = arith.constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi64> - // CHECK: [[BCAST:%.*]] = "tfl.broadcast_to"(%arg0, [[CST]]) - // CHECK: [[BCAST_1:%.*]] = "tfl.broadcast_to"(%arg1, [[CST]]) - // CHECK: tfl.add [[BCAST]], [[BCAST_1]] {fused_activation_function = "NONE"} : tensor<1x2x3x4x5x6xi32> +func.func @test6DAddWithImplicitBroadcast(%arg0: tensor<1x2x1x4x5x6xi32>, %arg1: tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32> { +// CHECK-LABEL: test6DAddWithImplicitBroadcast +// CHECK: %0 = tfl.add(%arg0, %arg1) {fused_activation_function = "NONE"} : (tensor<1x2x1x4x5x6xi32>, tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32> %0 = "tf.Add"(%arg0, %arg1) : (tensor<1x2x1x4x5x6xi32>, tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32> func.return %0 : tensor<1x2x3x4x5x6xi32> } +func.func @add_with_int32_7d_inputs(%arg0: tensor<1x1x1x1x1x3x1xi32>, %arg1 : tensor<1x1x1x1x1x1x4xi32>) -> tensor<1x1x1x1x1x3x4xi32> { + %0 = "tf.Add"(%arg0, %arg1): (tensor<1x1x1x1x1x3x1xi32>, tensor<1x1x1x1x1x1x4xi32>) -> tensor<1x1x1x1x1x3x4xi32> + func.return %0 : tensor<1x1x1x1x1x3x4xi32> +// CHECK-LABEL: add_with_int32_7d_inputs +// CHECK: %0 = "tfl.broadcast_to"(%arg0, %cst) : (tensor<1x1x1x1x1x3x1xi32>, tensor<7xi64>) -> tensor<1x1x1x1x1x3x4xi32> +// CHECK: %1 = "tfl.broadcast_to"(%arg1, %cst) : (tensor<1x1x1x1x1x1x4xi32>, tensor<7xi64>) -> tensor<1x1x1x1x1x3x4xi32> +// CHECK: %2 = tfl.add %0, %1 {fused_activation_function = "NONE"} : tensor<1x1x1x1x1x3x4xi32> +} + // CHECK-LABEL: testSubWithBroadcastToOps func.func @testSubWithBroadcastToOps(%arg0: tensor<1x2x1x4x5x6xi32>, %arg1: tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32> { // CHECK: [[CST:%.*]] = arith.constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi64> @@ -2338,6 +2341,24 @@ func.func @mul_i64(%arg0: tensor<14xi64>, %arg1: tensor<14xi64>) -> tensor<14xi6 // CHECK: return } +func.func @mul_i16(%arg0: tensor<14xi16>, %arg1: tensor<14xi16>) -> tensor<14xi16> { + %0 = "tf.Mul"(%arg0, %arg1) : (tensor<14xi16>, tensor<14xi16>) -> tensor<14xi16> + func.return %0: tensor<14xi16> + +// CHECK-LABEL: mul_i16 +// CHECK: tfl.mul %arg0, %arg1 {fused_activation_function = "NONE"} : tensor<14xi16> +// CHECK: return +} + +func.func @mul_ui32(%arg0: tensor<14xui32>, %arg1: tensor<14xui32>) -> tensor<14xui32> { + %0 = "tf.Mul"(%arg0, %arg1) : (tensor<14xui32>, tensor<14xui32>) -> tensor<14xui32> + func.return %0: tensor<14xui32> + +// CHECK-LABEL: mul_ui32 +// CHECK: tfl.mul %arg0, %arg1 {fused_activation_function = "NONE"} : tensor<14xui32> +// CHECK: return +} + func.func @mul_complex32(%arg0: tensor<14xcomplex>, %arg1: tensor<14xcomplex>) -> tensor<14xcomplex> { %0 = "tf.Mul"(%arg0, %arg1) : (tensor<14xcomplex>, tensor<14xcomplex>) -> tensor<14xcomplex> func.return %0: tensor<14xcomplex> @@ -2515,6 +2536,69 @@ func.func @sign(%arg0: tensor<8xf32>) -> tensor<8xf32> { // CHECK: return %[[RES0]] : tensor<8xf32> } +func.func @bitcast(%arg0: tensor<8xi32>) -> tensor<8xui32> { + %0 = "tf.Bitcast"(%arg0) : (tensor<8xi32>) -> tensor<8xui32> + func.return %0 : tensor<8xui32> + +// CHECK-LABEL: bitcast +// CHECK: %[[RES0:.*]] = "tfl.bitcast"(%arg0) : (tensor<8xi32>) -> tensor<8xui32> +// CHECK: return %[[RES0]] : tensor<8xui32> +} + +func.func @bitcastI32ToI16(%arg0: tensor<8xi32>) -> tensor<8x2xi16> { + %0 = "tf.Bitcast"(%arg0) : (tensor<8xi32>) -> tensor<8x2xi16> + func.return %0 : tensor<8x2xi16> + +// CHECK-LABEL: bitcastI32ToI16 
+// CHECK: %[[RES0:.*]] = "tfl.bitcast"(%arg0) : (tensor<8xi32>) -> tensor<8x2xi16> +// CHECK: return %[[RES0]] : tensor<8x2xi16> +} + +func.func @bitcastI16ToUI32(%arg0: tensor<8x2xi16>) -> tensor<8xui32> { + %0 = "tf.Bitcast"(%arg0) : (tensor<8x2xi16>) -> tensor<8xui32> + func.return %0 : tensor<8xui32> + +// CHECK-LABEL: bitcastI16ToUI32 +// CHECK: %[[RES0:.*]] = "tfl.bitcast"(%arg0) : (tensor<8x2xi16>) -> tensor<8xui32> +// CHECK: return %[[RES0]] : tensor<8xui32> +} + +func.func @bitcastFloatToI16(%arg0: tensor<8xf32>) -> tensor<8x2xi16> { + %0 = "tf.Bitcast"(%arg0) : (tensor<8xf32>) -> tensor<8x2xi16> + func.return %0 : tensor<8x2xi16> + +// CHECK-LABEL: bitcastFloatToI16 +// CHECK: %[[RES0:.*]] = "tfl.bitcast"(%arg0) : (tensor<8xf32>) -> tensor<8x2xi16> +// CHECK: return %[[RES0]] : tensor<8x2xi16> +} + +func.func @bitcastI16ToFloat(%arg0: tensor<8x2xi16>) -> tensor<8xf32> { + %0 = "tf.Bitcast"(%arg0) : (tensor<8x2xi16>) -> tensor<8xf32> + func.return %0 : tensor<8xf32> + +// CHECK-LABEL: bitcastI16ToFloat +// CHECK: %[[RES0:.*]] = "tfl.bitcast"(%arg0) : (tensor<8x2xi16>) -> tensor<8xf32> +// CHECK: return %[[RES0]] : tensor<8xf32> +} + +func.func @testBitwiseXor(%arg0: tensor<8xui32>, %arg1: tensor<8xui32>) -> tensor<8xui32> { + %0 = "tf.BitwiseXor"(%arg0, %arg1) : (tensor<8xui32>, tensor<8xui32>) -> tensor<8xui32> + func.return %0 : tensor<8xui32> + + // CHECK-LABEL: testBitwiseXor + // CHECK: %[[RES0:.*]] = "tfl.bitwise_xor"(%arg0, %arg1) : (tensor<8xui32>, tensor<8xui32>) -> tensor<8xui32> + // CHECK: return %[[RES0]] : tensor<8xui32> +} + +func.func @testRightShift(%arg0: tensor<8xui32>, %arg1: tensor<8xui32>) -> tensor<8xui32> { + %0 = "tf.RightShift"(%arg0, %arg1) : (tensor<8xui32>, tensor<8xui32>) -> tensor<8xui32> + func.return %0 : tensor<8xui32> + + // CHECK-LABEL: testRightShift + // CHECK: %[[RES0:.*]] = "tfl.right_shift"(%arg0, %arg1) : (tensor<8xui32>, tensor<8xui32>) -> tensor<8xui32> + // CHECK: return %[[RES0]] : tensor<8xui32> +} + // ============================================================================= // Training OPs // ============================================================================= diff --git a/tensorflow/compiler/mlir/lite/tests/ops.mlir b/tensorflow/compiler/mlir/lite/tests/ops.mlir index b424a26964e..628e523c488 100644 --- a/tensorflow/compiler/mlir/lite/tests/ops.mlir +++ b/tensorflow/compiler/mlir/lite/tests/ops.mlir @@ -352,6 +352,22 @@ func.func @testMul(tensor, tensor) -> tensor { func.return %0#0 : tensor } +// CHECK-LABEL: testMul32BitUInt +func.func @testMul32BitUInt(tensor, tensor) -> tensor { +^bb0(%arg0: tensor, %arg1: tensor): + // CHECK: tfl.mul %arg0, %arg1 {fused_activation_function = "RELU6"} + %0 = tfl.mul %arg0, %arg1 {fused_activation_function = "RELU6"} : tensor + func.return %0#0 : tensor +} + +// CHECK-LABEL: testMul16BitInt +func.func @testMul16BitInt(tensor, tensor) -> tensor { +^bb0(%arg0: tensor, %arg1: tensor): + // CHECK: tfl.mul %arg0, %arg1 {fused_activation_function = "RELU6"} + %0 = tfl.mul %arg0, %arg1 {fused_activation_function = "RELU6"} : tensor + func.return %0#0 : tensor +} + // CHECK-LABEL: testMulComplex func.func @testMulComplex(tensor>, tensor>) -> tensor> { ^bb0(%arg0: tensor>, %arg1: tensor>): @@ -397,6 +413,14 @@ func.func @mul_with_i32_five_dim_broadcasting(tensor<1x1x1x1x1xi32>, tensor<1xi3 // ----- +func.func @add_with_i32_five_dim_broadcasting(tensor<1x1x1x1x1xi32>, tensor<1xi32>) -> tensor<1x1x1x1x1xi32> { +^bb0(%arg0: tensor<1x1x1x1x1xi32>, %arg1: tensor<1xi32>): + %0 = "tfl.add"(%arg0, 
%arg1) {fused_activation_function = "RELU6"} : (tensor<1x1x1x1x1xi32>, tensor<1xi32>) -> tensor<1x1x1x1x1xi32> + func.return %0#0 : tensor<1x1x1x1x1xi32> +} + +// ----- + func.func @mul_with_quantized_i16_five_dim_broadcasting(tensor<1x1x1x1x1x!quant.any>, tensor<1x!quant.any>) -> tensor<1x1x1x1x1x!quant.any> { ^bb0(%arg0: tensor<1x1x1x1x1x!quant.any>, %arg1: tensor<1x!quant.any>): // expected-error @+1 {{Operands do not have valid shapes}} @@ -1429,6 +1453,7 @@ func.func @unpackQuantized(%arg0: tensor<2x3x!quant.uniform>) -> t // ----- func.func @unpack(%arg0: tensor<2x3xi32>) -> tensor<2xi32> { + // expected-error @+2 {{failed to infer returned types}} // expected-error @+1 {{output count should match 'num' attribute}} %0:3 = "tfl.unpack"(%arg0) {axis = 1 : i32, num = 2 : i32} : (tensor<2x3xi32>) -> (tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) func.return %0#0 : tensor<2xi32> @@ -1437,6 +1462,7 @@ func.func @unpack(%arg0: tensor<2x3xi32>) -> tensor<2xi32> { // ----- func.func @unpack(%arg0: tensor<2x3xi32>) -> tensor<2xi32> { + // expected-error @+2 {{failed to infer returned types}} // expected-error @+1 {{attribute 'axis' should be in range [-rank, rank), got axis = 2, and rank = 2}} %0:3 = "tfl.unpack"(%arg0) {axis = 2 : i32, num = 3 : i32} : (tensor<2x3xi32>) -> (tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) func.return %0#0 : tensor<2xi32> @@ -1445,6 +1471,7 @@ func.func @unpack(%arg0: tensor<2x3xi32>) -> tensor<2xi32> { // ----- func.func @unpack(%arg0: tensor<2x3xi32>) -> tensor<2xi32> { + // expected-error @+2 {{failed to infer returned types}} // expected-error @+1 {{attribute 'axis' should be in range [-rank, rank), got axis = -3, and rank = 2}} %0:3 = "tfl.unpack"(%arg0) {axis = -3 : i32, num = 3 : i32} : (tensor<2x3xi32>) -> (tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) func.return %0#0 : tensor<2xi32> @@ -1453,6 +1480,7 @@ func.func @unpack(%arg0: tensor<2x3xi32>) -> tensor<2xi32> { // ----- func.func @unpack(%arg0: tensor) -> tensor<2xi32> { + // expected-error @+2 {{failed to infer returned types}} // expected-error @+1 {{input should be of rank larger than 0}} %0:3 = "tfl.unpack"(%arg0) {axis = 0 : i32, num = 3 : i32} : (tensor) -> (tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) func.return %0#0 : tensor<2xi32> @@ -1461,6 +1489,7 @@ func.func @unpack(%arg0: tensor) -> tensor<2xi32> { // ----- func.func @unpack(%arg0: tensor<2x3xi32>) -> tensor<2xi32> { + // expected-error @+2 {{failed to infer returned types}} // expected-error @+1 {{op inferred type(s) 'tensor<2xi32>', 'tensor<2xi32>', 'tensor<2xi32>' are incompatible with return type(s) of operation 'tensor<2xi32>', 'tensor<2x1xi32>', 'tensor<2xi32>'}} %0:3 = "tfl.unpack"(%arg0) {axis = 1 : i32, num = 3 : i32} : (tensor<2x3xi32>) -> (tensor<2xi32>, tensor<2x1xi32>, tensor<2xi32>) func.return %0#0 : tensor<2xi32> @@ -3121,3 +3150,34 @@ func.func @testUnsortedSegmentMin(%arg0: tensor<8xf32>, %arg1: tensor<8xi32>, % func.return %0 : tensor<8xf32> // CHECK: return %0 : tensor<8xf32> } + + +// ----- + +// CHECK-LABEL: testBitcast +func.func @testBitcast(%arg0: tensor<8xui32>) -> tensor<8xi32> { + // CHECK: "tfl.bitcast"(%arg0) + %0 = "tfl.bitcast"(%arg0) : (tensor<8xui32>) -> tensor<8xi32> + func.return %0 : tensor<8xi32> + // CHECK: return %0 : tensor<8xi32> +} + +// ----- + +// CHECK-LABEL: testBitwiseXor +func.func @testBitwiseXor(%arg0: tensor<8xui32>, %arg1: tensor<8xui32>) -> tensor<8xui32> { + // CHECK: "tfl.bitwise_xor"(%arg0, %arg1) + %0 = "tfl.bitwise_xor"(%arg0, %arg1) : (tensor<8xui32>, tensor<8xui32>) -> 
tensor<8xui32> + func.return %0 : tensor<8xui32> + // CHECK: return %0 : tensor<8xui32> +} + +// ----- + +// CHECK-LABEL: testRightShift +func.func @testRightShift(%arg0: tensor<8xui32>, %arg1: tensor<8xui32>) -> tensor<8xui32> { + // CHECK: "tfl.right_shift"(%arg0, %arg1) + %0 = "tfl.right_shift"(%arg0, %arg1) : (tensor<8xui32>, tensor<8xui32>) -> tensor<8xui32> + func.return %0 : tensor<8xui32> + // CHECK: return %0 : tensor<8xui32> +} diff --git a/tensorflow/compiler/mlir/lite/tests/optimize.mlir b/tensorflow/compiler/mlir/lite/tests/optimize.mlir index 05880e8ef43..8d57178a47f 100644 --- a/tensorflow/compiler/mlir/lite/tests/optimize.mlir +++ b/tensorflow/compiler/mlir/lite/tests/optimize.mlir @@ -13,7 +13,7 @@ func.func @fusedConv2dRelu(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<16x3x3x %0 = "tfl.conv_2d"(%arg0, %arg1, %arg2) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<256x32x32x3xf32>, tensor<16x3x3x3xf32>, tensor<16xf32>) -> tensor<256x32x32x16xf32> %1 = "tfl.relu"(%0) : (tensor<256x32x32x16xf32>) -> tensor<256x32x32x16xf32> func.return %1 : tensor<256x32x32x16xf32> - + // CHECK: %0 = "tfl.conv_2d"(%arg0, %arg1, %arg2) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "RELU", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<256x32x32x3xf32>, tensor<16x3x3x3xf32>, tensor<16xf32>) -> tensor<256x32x32x16xf32> // CHECK: return %0 } @@ -60,6 +60,25 @@ func.func @fuseAddIntoConv2d(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<16x3x // CHECK: %0 = "tfl.conv_2d"(%arg0, %arg1, %cst) } +// CHECK-LABEL: fuse4DAddIntoConv2d +func.func @fuse4DAddIntoConv2d(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<2x3x3x3xf32>) -> tensor<256x32x32x2xf32> { + %cst = arith.constant dense<[[[[1.0, 2.0]]]]> : tensor<1x1x1x2xf32> + %cst_0 = arith.constant dense<[1.0, 2.0]> : tensor<2xf32> + %0 = "tfl.conv_2d"(%arg0, %arg1, %cst_0) { + dilation_h_factor = 1 : i32, + dilation_w_factor = 1 : i32, + fused_activation_function = "NONE", + padding = "SAME", + stride_h = 1 : i32, + stride_w = 1 : i32 + } : (tensor<256x32x32x3xf32>, tensor<2x3x3x3xf32>, tensor<2xf32>) -> tensor<256x32x32x2xf32> + %1 = "tfl.add"(%0, %cst) {fused_activation_function = "NONE"} : (tensor<256x32x32x2xf32>, tensor<1x1x1x2xf32>) -> tensor<256x32x32x2xf32> + func.return %1 : tensor<256x32x32x2xf32> + + // CHECK-DAG: %cst = arith.constant dense<[2.000000e+00, 4.000000e+00]> : tensor<2xf32> + // CHECK: %0 = "tfl.conv_2d"(%arg0, %arg1, %cst) +} + // CHECK-LABEL: fuseSubIntoConv2d func.func @fuseSubIntoConv2d(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<16x3x3x3xf32>) -> tensor<256x32x32x16xf32> { %cst = arith.constant dense<0.5> : tensor<16xf32> @@ -217,12 +236,20 @@ func.func @fuseSubIntoDepthwiseConv2d(%arg0: tensor<256x32x32x3xf32>, %arg1: ten } // CHECK-LABEL: dontFuseSubIntoDepthwiseConv2d -func.func @dontFuseSubIntoDepthwiseConv2d(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<3x3x3x16xf32>) -> tensor<256x30x30x16xf32> { - %cst = arith.constant dense<0.5> : tensor<1x16xf32> - %cst_0 = arith.constant dense<[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0]> : tensor<16xf32> - %0 = "tfl.depthwise_conv_2d"(%arg0, %arg1, %cst_0) {depth_multiplier = 4 : i32, dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<256x32x32x3xf32>, 
tensor<3x3x3x16xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32> - %1 = "tfl.sub"(%0, %cst) {fused_activation_function = "NONE"} : (tensor<256x30x30x16xf32>, tensor<1x16xf32>) -> tensor<256x30x30x16xf32> - func.return %1 : tensor<256x30x30x16xf32> +func.func @dontFuseSubIntoDepthwiseConv2d(%arg0: tensor<256x3x3x3xf32>, %arg1: tensor<3x3x3x5xf32>) -> tensor<256x2x2x4xf32> { + %cst = arith.constant dense<[[1.0, 2.0, 3.0, 4.0], [-1.0, -2.0, -3.0, -4.0]]> : tensor<2x4xf32> + %cst_0 = arith.constant dense<[1.0, 2.0, 3.0, 4.0]> : tensor<4xf32> + %0 = "tfl.depthwise_conv_2d"(%arg0, %arg1, %cst_0) { + depth_multiplier = 4 : i32, + dilation_h_factor = 2 : i32, + dilation_w_factor = 3 : i32, + fused_activation_function = "NONE", + padding = "SAME", + stride_h = 4 : i32, + stride_w = 5 : i32 + } : (tensor<256x3x3x3xf32>, tensor<3x3x3x5xf32>, tensor<4xf32>) -> tensor<256x2x2x4xf32> + %1 = "tfl.sub"(%0, %cst) {fused_activation_function = "NONE"} : (tensor<256x2x2x4xf32>, tensor<2x4xf32>) -> tensor<256x2x2x4xf32> + func.return %1 : tensor<256x2x2x4xf32> // CHECK: "tfl.depthwise_conv_2d" // CHECK: tfl.sub @@ -432,6 +459,23 @@ func.func @fuseMulIntoDepthwiseConv2d(%arg0: tensor<1x112x112x2xf32>) -> tensor< // CHECK: return %0 } +// CHECK-LABEL: @fuse4DMulIntoDepthwiseConv2d +func.func @fuse4DMulIntoDepthwiseConv2d(%arg0: tensor<1x112x112x2xf32>) -> tensor<1x112x112x2xf32> { + %cst0 = arith.constant dense<[[[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], [[7.0, 8.0], [9.0, 10.0], [11.0, 12.0]], [[13.0, 14.0], [15.0, 16.0], [17.0, 18.0]]]]> : tensor<1x3x3x2xf32> + %cst1 = arith.constant dense<2.0> : tensor<2xf32> + %cst2 = arith.constant dense<[[[[1.0, 2.0]]]]> : tensor<1x1x1x2xf32> + + %0 = "tfl.depthwise_conv_2d"(%arg0, %cst0, %cst1) {depth_multiplier = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x112x112x2xf32>, tensor<1x3x3x2xf32>, tensor<2xf32>) -> tensor<1x112x112x2xf32> + %1 = "tfl.mul"(%0, %cst2) {fused_activation_function = "RELU6"} : (tensor<1x112x112x2xf32>, tensor<1x1x1x2xf32>) -> tensor<1x112x112x2xf32> + + func.return %1 : tensor<1x112x112x2xf32> + +// CHECK-DAG: %cst = arith.constant dense<{{\[\[\[\[}}1.000000e+00, 4.000000e+00], [3.000000e+00, 8.000000e+00], [5.000000e+00, 1.200000e+01]], {{\[\[}}7.000000e+00, 1.600000e+01], [9.000000e+00, 2.000000e+01], [1.100000e+01, 2.400000e+01]], {{\[\[}}1.300000e+01, 2.800000e+01], [1.500000e+01, 3.200000e+01], [1.700000e+01, 3.600000e+01]]]]> : tensor<1x3x3x2xf32> +// CHECK-DAG: %cst_0 = arith.constant dense<[2.000000e+00, 4.000000e+00]> : tensor<2xf32> +// CHECK: %0 = "tfl.depthwise_conv_2d"(%arg0, %cst, %cst_0) {depth_multiplier = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "RELU6", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x112x112x2xf32>, tensor<1x3x3x2xf32>, tensor<2xf32>) -> tensor<1x112x112x2xf32> +// CHECK: return %0 +} + // CHECK-LABEL: @notFuseMulIntoDepthwiseConv2d func.func @notFuseMulIntoDepthwiseConv2d(%arg0: tensor<1x4x4x2xf32>) -> tensor<1x4x4x2xf32> { %cst0 = arith.constant dense<[[[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], [[7.0, 8.0], [9.0, 10.0], [11.0, 12.0]], [[13.0, 14.0], [15.0, 16.0], [17.0, 18.0]]]]> : tensor<1x3x3x2xf32> @@ -464,6 +508,21 @@ func.func @FuseFullyConnectedAddWithNoBias(%arg0: tensor<40x37xf32>, %arg1: tens // CHECK: return %[[fc]] } +// CHECK-LABEL: @FuseFullyConnectedReducedAddWithNoBias +func.func 
@FuseFullyConnectedReducedAddWithNoBias(%arg0: tensor<1024x1x126xf32>, %arg1: tensor<128x126xf32>) -> tensor<1024x1x128xf32> { + %cst = "tfl.no_value"() {value} : () -> none + %cst2 = arith.constant dense<2.0> : tensor<1x1x128xf32> + + %0 = "tfl.fully_connected" (%arg0, %arg1, %cst) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<1024x1x126xf32>, tensor<128x126xf32>, none) -> (tensor<1024x1x128xf32>) + %1 = "tfl.add"(%0, %cst2) {fused_activation_function = "NONE"} : (tensor<1024x1x128xf32>, tensor<1x1x128xf32>) -> tensor<1024x1x128xf32> + + func.return %1 : tensor<1024x1x128xf32> + + // CHECK-DAG: %cst = arith.constant dense<2.000000e+00> : tensor<128xf32> + // CHECK: %[[fc:.*]] = "tfl.fully_connected"(%arg0, %arg1, %cst) + // CHECK: return %[[fc]] +} + // CHECK-LABEL: @FuseFullyConnectedAddWithExistingBias func.func @FuseFullyConnectedAddWithExistingBias(%arg0: tensor<40x37xf32>, %arg1: tensor<40x37xf32>) -> tensor<40x40xf32> { %cst = arith.constant dense<3.0> : tensor<40xf32> @@ -552,6 +611,38 @@ func.func @FuseFullyConnectedReshapeAddConst(%arg0: tensor<40x37xf32>, %arg1: te // FOLD: return %[[fc]] } +// CHECK-LABEL: @RemoveRedundantReshapeUsedAsInputToBinaryOp +func.func @RemoveRedundantReshapeUsedAsInputToBinaryOp(%arg0: tensor<128xf32>, %arg1: tensor<1x512x512x128xf32>, %arg2: tensor<1x512x512x128xf32>) -> (tensor<1x512x512x128xf32>, tensor<1x512x512x128xf32>) { + %cst_10 = arith.constant dense<[1, 1, 1, 128]> : tensor<4xi32> + + %894 = "tfl.reshape"(%arg0, %cst_10) : (tensor<128xf32>, tensor<4xi32>) -> tensor<1x1x1x128xf32> + %895 = "tfl.mul"(%894, %arg1) {fused_activation_function = "NONE"} : (tensor<1x1x1x128xf32>, tensor<1x512x512x128xf32>) -> tensor<1x512x512x128xf32> + %896 = "tfl.mul"(%arg2, %894) {fused_activation_function = "NONE"} : (tensor<1x512x512x128xf32>, tensor<1x1x1x128xf32>) -> tensor<1x512x512x128xf32> + + return %895, %896 : tensor<1x512x512x128xf32>, tensor<1x512x512x128xf32> + + // CHECK: %0 = tfl.mul(%arg0, %arg1) + // CHECK: %1 = tfl.mul(%arg2, %arg0) + // CHECK: return %0, %1 +} + +// CHECK-LABEL: @RetainRedundantReshapeUseInNonBinaryOp +func.func @RetainRedundantReshapeUseInNonBinaryOp(%arg0: tensor<128xf32>, %arg1: tensor<1x512x512x128xf32>, %arg2: tensor<1x512x512x128xf32>) -> (tensor<1x512x512x128xf32>, tensor<128xf32>) { + %cst = arith.constant dense<0> : tensor<1xi32> + %cst_10 = arith.constant dense<[1, 1, 1, 128]> : tensor<4xi32> + %894 = "tfl.reshape"(%arg0, %cst_10) : (tensor<128xf32>, tensor<4xi32>) -> tensor<1x1x1x128xf32> + %895 = "tfl.mul"(%894, %arg1) {fused_activation_function = "NONE"} : (tensor<1x1x1x128xf32>, tensor<1x512x512x128xf32>) -> tensor<1x512x512x128xf32> + %896 = "tfl.reduce_max"(%894, %cst) {keep_dims = false} : (tensor<1x1x1x128xf32>, tensor<1xi32>) -> tensor<128xf32> + return %895, %896 : tensor<1x512x512x128xf32>, tensor<128xf32> + + // CHECK-DAG: %cst = arith.constant dense<0> : tensor<1xi32> + // CHECK-DAG: %cst_0 = arith.constant dense<[1, 1, 1, 128]> : tensor<4xi32> + // CHECK: %0 = "tfl.reshape"(%arg0, %cst_0) : (tensor<128xf32>, tensor<4xi32>) -> tensor<1x1x1x128xf32> + // CHECK: %1 = tfl.mul(%0, %arg1) {fused_activation_function = "NONE"} : (tensor<1x1x1x128xf32>, tensor<1x512x512x128xf32>) -> tensor<1x512x512x128xf32> + // CHECK: %2 = "tfl.reduce_max"(%0, %cst) {keep_dims = false} : (tensor<1x1x1x128xf32>, tensor<1xi32>) -> tensor<128xf32> + // CHECK: return %1, %2 +} + // CHECK-LABEL: @FuseFullyConnectedReshapeAddConstWithOptionalAttribute // FOLD-LABEL: 
@FuseFullyConnectedReshapeAddConstWithOptionalAttribute func.func @FuseFullyConnectedReshapeAddConstWithOptionalAttribute(%arg0: tensor<40x37xf32>, %arg1: tensor<40x37xf32>) -> tensor<40x40xf32> { @@ -618,6 +709,24 @@ func.func @FuseFullyConnectedReshapeAdd2DConst(%arg0: tensor<40x37xf32>, %arg1: // CHECK: return %[[rs]] } +// CHECK-LABEL: @FuseFCReshapeAdd2DConst2 +func.func @FuseFCReshapeAdd2DConst2(%arg0: tensor<40x37xf32>, %arg1: tensor<40x37xf32>) -> tensor<1x40x4x10xf32> { + %cst = "tfl.no_value"() {value} : () -> none + %cst2 = arith.constant dense<2.0> : tensor<1x1x4x10xf32> + %shape = arith.constant dense<[1, 40, 4, 10]> : tensor<4xi32> + + %0 = "tfl.fully_connected"(%arg0, %arg1, %cst) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<40x37xf32>, tensor<40x37xf32>, none) -> (tensor<40x40xf32>) + %1 = "tfl.reshape"(%0, %shape) : (tensor<40x40xf32>, tensor<4xi32>) -> tensor<1x40x4x10xf32> + %2 = "tfl.add"(%1, %cst2) {fused_activation_function = "NONE"} : (tensor<1x40x4x10xf32>, tensor<1x1x4x10xf32>) -> tensor<1x40x4x10xf32> + + func.return %2 : tensor<1x40x4x10xf32> + + // CHECK-DAG: %[[cst:.*]] = arith.constant dense<2.000000e+00> : tensor<40xf32> + // CHECK: %[[fc:.*]] = "tfl.fully_connected"(%arg0, %arg1, %[[cst]]) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} + // CHECK: %[[rs:.*]] = "tfl.reshape"(%[[fc]] + // CHECK: return %[[rs]] +} + // CHECK-LABEL: @FuseFullyConnectedReshapeAdd2DConstWithActivation func.func @FuseFullyConnectedReshapeAdd2DConstWithActivation(%arg0: tensor<40x37xf32>, %arg1: tensor<40x37xf32>) -> tensor<1x40x4x10xf32> { %cst = "tfl.no_value"() {value} : () -> none @@ -636,6 +745,24 @@ func.func @FuseFullyConnectedReshapeAdd2DConstWithActivation(%arg0: tensor<40x37 // CHECK: return %[[rs]] } +// CHECK-LABEL: @FuseFCReshapeAdd2DConstWithActvtn2 +func.func @FuseFCReshapeAdd2DConstWithActvtn2(%arg0: tensor<40x37xf32>, %arg1: tensor<40x37xf32>) -> tensor<1x40x4x10xf32> { + %cst = "tfl.no_value"() {value} : () -> none + %cst2 = arith.constant dense<2.0> : tensor<1x1x4x10xf32> + %shape = arith.constant dense<[1, 40, 4, 10]> : tensor<4xi32> + + %0 = "tfl.fully_connected"(%arg0, %arg1, %cst) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<40x37xf32>, tensor<40x37xf32>, none) -> (tensor<40x40xf32>) + %1 = "tfl.reshape"(%0, %shape) : (tensor<40x40xf32>, tensor<4xi32>) -> tensor<1x40x4x10xf32> + %2 = "tfl.add"(%1, %cst2) {fused_activation_function = "RELU6"} : (tensor<1x40x4x10xf32>, tensor<1x1x4x10xf32>) -> tensor<1x40x4x10xf32> + + func.return %2 : tensor<1x40x4x10xf32> + + // CHECK-DAG: %[[cst:.*]] = arith.constant dense<2.000000e+00> : tensor<40xf32> + // CHECK: %[[fc:.*]] = "tfl.fully_connected"(%arg0, %arg1, %[[cst]]) {fused_activation_function = "RELU6", keep_num_dims = false, weights_format = "DEFAULT"} + // CHECK: %[[rs:.*]] = "tfl.reshape"(%[[fc]] + // CHECK: return %[[rs]] +} + // CHECK-LABEL: @FuseFullyConnectedReshapeAdd2DConstWithExistingBias func.func @FuseFullyConnectedReshapeAdd2DConstWithExistingBias(%arg0: tensor<40x37xf32>, %arg1: tensor<40x37xf32>) -> tensor<1x40x4x10xf32> { %cst = arith.constant dense<3.0> : tensor<40xf32> @@ -775,6 +902,17 @@ func.func @ReorderElementwiseValueOpAndMoveOp(%arg0: tensor<40x40x1xf32>) -> ten // CHECK: return %[[rs2]] } +// CHECK-LABEL: @MinimumOfReluAnd6ToRelu6 +func.func @MinimumOfReluAnd6ToRelu6(%arg0: tensor<40x40xf32>) -> tensor<40x40xf32> { + %cst = arith.constant dense<6.0> 
: tensor + %2 = "tfl.relu"(%arg0) : (tensor<40x40xf32>) -> tensor<40x40xf32> + %3 = "tfl.minimum"(%2, %cst) : (tensor<40x40xf32>, tensor) -> tensor<40x40xf32> + func.return %3 : tensor<40x40xf32> + + // CHECK: %[[rs1:.*]] = "tfl.relu6"(%arg0 + // CHECK: return %[[rs1]] +} + // CHECK-LABEL: @NotReorderElementwiseValueOpAndMoveOp func.func @NotReorderElementwiseValueOpAndMoveOp(%arg0: tensor<40x40x1xf32>) -> (tensor<40x40xf32>, tensor<40x40xf32>) { %shape = arith.constant dense<[40, 40]> : tensor<2xi32> @@ -1809,6 +1947,32 @@ func.func @DontConvertConstSelectMixed(%arg0: tensor<2xf32>, %arg1: tensor<2xf32 // CHECK: return %0, %1 } +// CHECK-LABEL: FuseBroadcastToIntoSelect +func.func @FuseBroadcastToIntoSelect(%arg0: tensor<1x8x1024x2048xf32>, %arg1: tensor<1x8x1024x2048xf32>, %arg2: tensor<1x1x1x2048xi1>) -> (tensor<1x8x1024x2048xf32>, tensor<1x8x1024x2048xf32>) { + %cst_0 = arith.constant dense<[1, 8, 1024, 2048]> : tensor<4xi32> + %0 = "tfl.broadcast_to"(%arg2, %cst_0) : (tensor<1x1x1x2048xi1>, tensor<4xi32>) -> tensor<1x8x1024x2048xi1> + %1 = "tfl.select"(%0, %arg0, %arg1) : (tensor<1x8x1024x2048xi1>, tensor<1x8x1024x2048xf32>, tensor<1x8x1024x2048xf32>) -> tensor<1x8x1024x2048xf32> + %2 = "tfl.select_v2"(%0, %arg0, %arg1) : (tensor<1x8x1024x2048xi1>, tensor<1x8x1024x2048xf32>, tensor<1x8x1024x2048xf32>) -> tensor<1x8x1024x2048xf32> + func.return %1, %2 : tensor<1x8x1024x2048xf32>, tensor<1x8x1024x2048xf32> + // CHECK: %0 = "tfl.select_v2"(%arg2, %arg0, %arg1) : (tensor<1x1x1x2048xi1>, tensor<1x8x1024x2048xf32>, tensor<1x8x1024x2048xf32>) -> tensor<1x8x1024x2048xf32> + // CHECK: %1 = "tfl.select_v2"(%arg2, %arg0, %arg1) : (tensor<1x1x1x2048xi1>, tensor<1x8x1024x2048xf32>, tensor<1x8x1024x2048xf32>) -> tensor<1x8x1024x2048xf32> + // CHECK: return %0, %1 +} + +// CHECK-LABEL: FuseBroadcastToIntoSelect1 +func.func @FuseBroadcastToIntoSelect1(%arg0: tensor<1x1x8x1024x2048xf32>, %arg1: tensor<1x1x8x1024x2048xf32>, %arg2: tensor<1x1x1x1x2048xi1>) -> tensor<1x1x8x1024x2048xf32> { + %cst_0 = arith.constant dense<[1, 1, 8, 1024, 2048]> : tensor<5xi32> + %0 = "tfl.broadcast_to"(%arg2, %cst_0) : (tensor<1x1x1x1x2048xi1>, tensor<5xi32>) -> tensor<1x1x8x1024x2048xi1> + %1 = "tfl.select"(%0, %arg0, %arg1) : (tensor<1x1x8x1024x2048xi1>, tensor<1x1x8x1024x2048xf32>, tensor<1x1x8x1024x2048xf32>) -> tensor<1x1x8x1024x2048xf32> + + func.return %1 : tensor<1x1x8x1024x2048xf32> + // CHECK-DAG: %cst = arith.constant dense<[1, 1, 8, 1024, 2048]> : tensor<5xi32> + // CHECK: %0 = "tfl.broadcast_to"(%arg2, %cst) : (tensor<1x1x1x1x2048xi1>, tensor<5xi32>) -> tensor<1x1x8x1024x2048xi1> + // CHECK: %1 = "tfl.select"(%0, %arg0, %arg1) : (tensor<1x1x8x1024x2048xi1>, tensor<1x1x8x1024x2048xf32>, tensor<1x1x8x1024x2048xf32>) -> tensor<1x1x8x1024x2048xf32> + + // CHECK: return %1 +} + // CHECK-LABEL: CheckSelectNegated func.func @CheckSelectNegated(%arg0: tensor<1x2x3x4xi1>, %arg1: tensor<1x2x3x4xf32>, %arg2: tensor<1x2x3x4xf32>) -> (tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>) { %not = "tfl.logical_not"(%arg0) : (tensor<1x2x3x4xi1>) -> tensor<1x2x3x4xi1> @@ -2381,51 +2545,42 @@ func.func @fuseUnpackAndConcatToReshape(%arg0: tensor<1x3x2xf32>) -> tensor<1x6x // CHECK: return %[[RES]] } -// CHECK-LABEL: replaceReshapeEqualWithOneHot -func.func @replaceReshapeEqualWithOneHot(%arg: tensor<2xi32>) -> tensor<2x3xi1> { - // Good match: Replace with one_hot - %shape = arith.constant dense<[2, 1]> : tensor<2xi32> +// CHECK-LABEL: replaceReshapeEqualWithOneHotSingleDim +func.func @replaceReshapeEqualWithOneHotSingleDim(%arg: 
tensor<1xi32>) -> tensor<3xi1> { %cst = arith.constant dense<[0, 1, 2]> : tensor<3xi32> - %tmp = "tfl.reshape"(%arg, %shape) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2x1xi32> - %result = "tfl.equal"(%tmp, %cst) : (tensor<2x1xi32>, tensor<3xi32>) -> tensor<2x3xi1> + %result = "tfl.equal"(%arg, %cst) : (tensor<1xi32>, tensor<3xi32>) -> tensor<3xi1> + func.return %result : tensor<3xi1> + + // CHECK-NOT: tfl.one_hot +} + +// CHECK-LABEL: replaceReshapeEqualWithOneHot +func.func @replaceReshapeEqualWithOneHot(%arg: tensor<2x1xi32>) -> tensor<2x3xi1> { + // Good match: Replace with one_hot + %cst = arith.constant dense<[0, 1, 2]> : tensor<3xi32> + %result = "tfl.equal"(%arg, %cst) : (tensor<2x1xi32>, tensor<3xi32>) -> tensor<2x3xi1> func.return %result : tensor<2x3xi1> // CHECK-DAG: %[[CST1:.*]] = arith.constant dense<3> : tensor // CHECK-DAG: %[[CST2:.*]] = arith.constant dense : tensor // CHECK-DAG: %[[CST3:.*]] = arith.constant dense : tensor - // CHECK: %[[RES:.*]] = "tfl.one_hot"(%arg0, %[[CST1]], %[[CST2]], %[[CST3]]) {axis = -1 : i32} : (tensor<2xi32>, tensor, tensor, tensor) -> tensor<2x3xi1> + // CHECK-DAG: %[[CST4:.*]] = arith.constant dense<2> : tensor<1xi32> + // CHECK: %[[RESHAPE:.*]] = "tfl.reshape"(%arg0, %[[CST4]]) : (tensor<2x1xi32>, tensor<1xi32>) -> tensor<2xi32> + // CHECK: %[[RES:.*]] = "tfl.one_hot"(%[[RESHAPE]], %[[CST1]], %[[CST2]], %[[CST3]]) {axis = -1 : i32} : (tensor<2xi32>, tensor, tensor, tensor) -> tensor<2x3xi1> } -// CHECK-LABEL: replaceReshapeEqualWithOneHotWithNonTrivialReshape -func.func @replaceReshapeEqualWithOneHotWithNonTrivialReshape(%arg: tensor<4x4xi32>) -> tensor<16x3xi1> { - // Good match: Replace with one_hot - %shape = arith.constant dense<[16, 1]> : tensor<2xi32> +// CHECK-LABEL: ReplaceReshapeEqualWithOneHotWithBatchingDim +func.func @ReplaceReshapeEqualWithOneHotWithBatchingDim(%arg: tensor<2x2x1xi32>) -> tensor<2x2x3xi1> { %cst = arith.constant dense<[0, 1, 2]> : tensor<3xi32> - %tmp = "tfl.reshape"(%arg, %shape) : (tensor<4x4xi32>, tensor<2xi32>) -> tensor<16x1xi32> - %result = "tfl.equal"(%tmp, %cst) : (tensor<16x1xi32>, tensor<3xi32>) -> tensor<16x3xi1> - func.return %result : tensor<16x3xi1> + %result = "tfl.equal"(%arg, %cst) : (tensor<2x2x1xi32>, tensor<3xi32>) -> tensor<2x2x3xi1> + func.return %result : tensor<2x2x3xi1> // CHECK-DAG: %[[CST1:.*]] = arith.constant dense<3> : tensor // CHECK-DAG: %[[CST2:.*]] = arith.constant dense : tensor // CHECK-DAG: %[[CST3:.*]] = arith.constant dense : tensor - // CHECK-DAG: %[[CST4:.*]] = arith.constant dense<16> : tensor<1xi32> - // CHECK-DAG: %[[TMP:.*]] = "tfl.reshape"(%arg0, %[[CST4]]) : (tensor<4x4xi32>, tensor<1xi32>) -> tensor<16xi32> - // CHECK: %[[RES:.*]] = "tfl.one_hot"(%[[TMP]], %[[CST1]], %[[CST2]], %[[CST3]]) {axis = -1 : i32} : (tensor<16xi32>, tensor, tensor, tensor) -> tensor<16x3xi1> -} - -// CHECK-LABEL: noReplaceReshapeEqualWithOneHotWithBatchingDim -func.func @noReplaceReshapeEqualWithOneHotWithBatchingDim(%arg: tensor<2xi32>) -> tensor<1x2x3xi1> { - // Do not replace: shape length longer than 2 - %shape = arith.constant dense<[1, 2, 1]> : tensor<3xi32> - %cst = arith.constant dense<[0, 1, 2]> : tensor<3xi32> - %tmp = "tfl.reshape"(%arg, %shape) : (tensor<2xi32>, tensor<3xi32>) -> tensor<1x2x1xi32> - %result = "tfl.equal"(%tmp, %cst) : (tensor<1x2x1xi32>, tensor<3xi32>) -> tensor<1x2x3xi1> - func.return %result : tensor<1x2x3xi1> - - // CHECK-DAG: %[[CST1:.*]] = arith.constant dense<[1, 2, 1]> : tensor<3xi32> - // CHECK-DAG: %[[CST2:.*]] = arith.constant dense<[0, 1, 2]> : 
tensor<3xi32> - // CHECK: %[[TMP:.*]] = "tfl.reshape"(%arg0, %[[CST1]]) : (tensor<2xi32>, tensor<3xi32>) -> tensor<1x2x1xi32> - // CHECK: %[[RES:.*]] = "tfl.equal"(%[[TMP]], %[[CST2]]) : (tensor<1x2x1xi32>, tensor<3xi32>) -> tensor<1x2x3xi1> + // CHECK-DAG: %[[CST4:.*]] = arith.constant dense<2> : tensor<2xi32> + // CHECK: %[[RESHAPE:.*]] = "tfl.reshape"(%arg0, %[[CST4]]) : (tensor<2x2x1xi32>, tensor<2xi32>) -> tensor<2x2xi32> + // CHECK: %[[RES:.*]] = "tfl.one_hot"(%[[RESHAPE]], %[[CST1]], %[[CST2]], %[[CST3]]) {axis = -1 : i32} : (tensor<2x2xi32>, tensor, tensor, tensor) -> tensor<2x2x3xi1> } // CHECK-LABEL: noReplaceReshapeEqualWithOneHotBadShape @@ -2549,8 +2704,8 @@ func.func @dontReplaceOneHotFullyConnectedWithLookupBadIndexTypeWithOptionalAttr // CHECK: %[[RES:.*]] = "tfl.fully_connected"(%[[TMP]], %[[CST4]], %[[CST5]]) {asymmetric_quantize_inputs = true, } -// CHECK-LABEL: dontReplaceOneHotFullyConnectedWithLookupBadIndexRank -func.func @dontReplaceOneHotFullyConnectedWithLookupBadIndexRank(%arg: tensor<11x2xi32>) -> tensor<11x2x5xf32> { +// CHECK-LABEL: ReplaceOneHotFullyConnectedWithLookup2DRank +func.func @ReplaceOneHotFullyConnectedWithLookup2DRank(%arg: tensor<11x2xi32>) -> tensor<11x2x5xf32> { %depth = arith.constant dense<3> : tensor %on = arith.constant dense<1.0> : tensor %off = arith.constant dense<0.0> : tensor @@ -2558,18 +2713,16 @@ func.func @dontReplaceOneHotFullyConnectedWithLookupBadIndexRank(%arg: tensor<11 %bias = "tfl.no_value"() {value} : () -> none %tmp = "tfl.one_hot"(%arg, %depth, %on, %off) {axis = -1 : i32} : (tensor<11x2xi32>, tensor, tensor, tensor) -> tensor<11x2x3xf32> - %result = "tfl.fully_connected"(%tmp, %filter, %bias) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<11x2x3xf32>, tensor<5x3xf32>, none) -> tensor<11x2x5xf32> + %result = "tfl.fully_connected"(%tmp, %filter, %bias) {fused_activation_function = "NONE", keep_num_dims = true, weights_format = "DEFAULT"} : (tensor<11x2x3xf32>, tensor<5x3xf32>, none) -> tensor<11x2x5xf32> func.return %result : tensor<11x2x5xf32> - // CHECK-NOT: "tfl.embedding_lookup" - // CHECK-DAG: %[[CST1:.*]] = arith.constant dense<3> : tensor - // CHECK-DAG: %[[CST2:.*]] = arith.constant dense<1.000000e+00> : tensor - // CHECK-DAG: %[[CST3:.*]] = arith.constant dense<0.000000e+00> : tensor - // CHECK-DAG: %[[CST4:.*]] = arith.constant dense<7.000000e+00> : tensor<5x3xf32> - // CHECK-DAG: %[[CST5:.*]] = "tfl.no_value"() {value} : () -> none - // CHECK: %[[TMP:.*]] = "tfl.one_hot"(%arg0, %[[CST1]], %[[CST2]], %[[CST3]]) {axis = -1 : i32} : (tensor<11x2xi32>, tensor, tensor, tensor) -> tensor<11x2x3xf32> - // CHECK: %[[RES:.*]] = "tfl.fully_connected"(%[[TMP]], %[[CST4]], %[[CST5]]) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<11x2x3xf32>, tensor<5x3xf32>, none) -> tensor<11x2x5xf32> + // CHECK-DAG: %[[CST0:.*]] = arith.constant dense<22> : tensor<1xi32> + // CHECK-DAG: %[[CST1:.*]] = arith.constant dense<7.000000e+00> : tensor<3x5xf32> + // CHECK-DAG: %[[CST2:.*]] = arith.constant dense<[11, 2, 5]> : tensor<3xi32> + // CHECK: %[[RESHAPE:.*]] = "tfl.reshape"(%arg0, %[[CST0]]) : (tensor<11x2xi32>, tensor<1xi32>) -> tensor<22xi32> + // CHECK: %[[TMP:.*]] = "tfl.embedding_lookup"(%[[RESHAPE]], %[[CST1]]) : (tensor<22xi32>, tensor<3x5xf32>) -> tensor<22x5xf32> + // CHECK: %[[RES:.*]] = "tfl.reshape"(%[[TMP]], %[[CST2]]) : (tensor<22x5xf32>, tensor<3xi32>) -> tensor<11x2x5xf32> // CHECK: return %[[RES]] : tensor<11x2x5xf32> } 
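The one_hot → embedding_lookup rewrite exercised by the test above relies on the fact that a fully_connected applied to a one-hot vector simply selects one column of the filter, i.e. one row of the transposed filter. Below is a minimal standalone C++ sketch (plain loops and an illustrative [units x depth] filter layout, not the actual TFLite kernels) that checks this equivalence numerically:

#include <array>
#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  constexpr int kDepth = 3;  // one_hot depth
  constexpr int kUnits = 5;  // fully_connected output units
  // Filter laid out as [units x depth], mirroring the tfl.fully_connected weights.
  std::vector<float> filter(kUnits * kDepth);
  for (int u = 0; u < kUnits; ++u)
    for (int d = 0; d < kDepth; ++d) filter[u * kDepth + d] = u * 10 + d;

  const std::vector<int32_t> indices = {2, 0, 1, 2};
  for (int32_t index : indices) {
    // Path 1: one_hot(index, kDepth) followed by fully_connected.
    std::array<float, kDepth> one_hot{};
    one_hot[index] = 1.0f;
    std::vector<float> fc_out(kUnits, 0.0f);
    for (int u = 0; u < kUnits; ++u)
      for (int d = 0; d < kDepth; ++d)
        fc_out[u] += one_hot[d] * filter[u * kDepth + d];

    // Path 2: embedding lookup of row `index` in transpose(filter), shape [depth x units].
    for (int u = 0; u < kUnits; ++u) assert(fc_out[u] == filter[u * kDepth + index]);
  }
  return 0;
}

The reshape ops in the expected IR only flatten the batch dimensions before the lookup and restore them afterwards; the per-index selection itself is what the loop above demonstrates.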
diff --git a/tensorflow/compiler/mlir/lite/tests/prepare-quantize-post-training-16bits.mlir b/tensorflow/compiler/mlir/lite/tests/prepare-quantize-post-training-16bits.mlir index 17c793cd19f..d2e04734e0e 100644 --- a/tensorflow/compiler/mlir/lite/tests/prepare-quantize-post-training-16bits.mlir +++ b/tensorflow/compiler/mlir/lite/tests/prepare-quantize-post-training-16bits.mlir @@ -222,8 +222,8 @@ func.func @QuantizeFixedOutputRangeInterfaceOpSoftmax(%arg0: tensor<1x1xf32>) -> // CHECK: %[[q1:.*]] = "tfl.quantize"(%arg0) {qtype = tensor<1x1x!quant.uniform>, volatile} : (tensor<1x1xf32>) -> tensor<1x1x!quant.uniform> // CHECK-NEXT: %[[dq1:.*]] = "tfl.dequantize"(%[[q1]]) : (tensor<1x1x!quant.uniform>) -> tensor<1x1xf32> // CHECK-NEXT: %[[sm:.*]] = "tfl.softmax"(%[[dq1]]) {{{.*}}} : (tensor<1x1xf32>) -> tensor<1x1xf32> -// CHECK-NEXT: %[[q2:.*]] = "tfl.quantize"(%[[sm]]) {qtype = tensor<1x1x!quant.uniform>, volatile} : (tensor<1x1xf32>) -> tensor<1x1x!quant.uniform> -// CHECK-NEXT: %[[dq2:.*]] = "tfl.dequantize"(%[[q2]]) : (tensor<1x1x!quant.uniform>) -> tensor<1x1xf32> +// CHECK-NEXT: %[[q2:.*]] = "tfl.quantize"(%[[sm]]) {qtype = tensor<1x1x!quant.uniform>, volatile} : (tensor<1x1xf32>) -> tensor<1x1x!quant.uniform> +// CHECK-NEXT: %[[dq2:.*]] = "tfl.dequantize"(%[[q2]]) : (tensor<1x1x!quant.uniform>) -> tensor<1x1xf32> } // CHECK-LABEL: QuantizeFixedOutputRangeInterfaceOpL2Normalization diff --git a/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir b/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir index 9836ec1ba15..a668475a9e2 100644 --- a/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir +++ b/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir @@ -193,6 +193,18 @@ func.func @identity(%arg0: tensor<10xi32>, %arg1: tensor<20xi32>, %arg2: tensor< // CHECK: return %arg0, %arg1, %arg2, %0 } +func.func @sharding(%arg0: tensor<10x10xi32>) -> (tensor<10x10xi32>) { + %0 = "tf.MatMul"(%arg0, %arg0) {device = "", transpose_a = false, transpose_b = false} : (tensor<10x10xi32>, tensor<10x10xi32>) -> tensor<10x10xi32> + %1 = "tf.MatMul"(%arg0, %arg0) {device = "", transpose_a = false, transpose_b = false} : (tensor<10x10xi32>, tensor<10x10xi32>) -> tensor<10x10xi32> + %2 = "tf.XlaSharding"(%0) {_XlaSharding = "\08\03\1A\02\01\01\22\01\00", device = "", sharding = "\08\03\1A\02\01\01\22\01\00", unspecified_dims = []} : (tensor<10x10xi32>) -> tensor<10x10xi32> + %3 = "tf.XlaSharding"(%1) {_XlaSharding = "\08\03\1A\02\01\01\22\01\00", device = "", sharding = "\08\03\1A\02\01\01\22\01\00", unspecified_dims = []} : (tensor<10x10xi32>) -> tensor<10x10xi32> + %4 = "tf.AddV2"(%2, %3) {device = ""} : (tensor<10x10xi32>, tensor<10x10xi32>) -> tensor<10x10xi32> + func.return %4 : tensor<10x10xi32> + +// CHECK-LABEL: sharding +// CHECK-NOT: %2 = "tf.XlaSharding"(%0) {_XlaSharding = "\08\03\1A\02\01\01\22\01\00", device = "", sharding = "\08\03\1A\02\01\01\22\01\00", unspecified_dims = []} : (tensor<10x10xi32>) -> tensor<10x10xi32> +// CHECK-NOT: %3 = "tf.XlaSharding"(%1) {_XlaSharding = "\08\03\1A\02\01\01\22\01\00", device = "", sharding = "\08\03\1A\02\01\01\22\01\00", unspecified_dims = []} : (tensor<10x10xi32>) -> tensor<10x10xi32> +} func.func @matmulNoTransposeAOrB(%arg0: tensor<1x1280xf32>, %arg1: tensor<1280x1000xf32>) -> tensor<1x1000xf32> { %166 = "tf.MatMul"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", _output_shapes = ["tfshape$dim { size = 1} dim { size = 1000}"], device = "", name = "matmul", transpose_a = false, transpose_b = false} : (tensor<1x1280xf32>, tensor<1280x1000xf32>) -> 
tensor<1x1000xf32> diff --git a/tensorflow/compiler/mlir/lite/tests/shape-inference.mlir b/tensorflow/compiler/mlir/lite/tests/shape-inference.mlir index 34b9a54bc91..5baa9811229 100644 --- a/tensorflow/compiler/mlir/lite/tests/shape-inference.mlir +++ b/tensorflow/compiler/mlir/lite/tests/shape-inference.mlir @@ -68,6 +68,7 @@ func.func @testConv2dShapeInferenceDynamic(%arg0: tensor<1x?x?x128xf32>, %arg1: module attributes {tf.versions = {producer = 888 : i32}} { func.func @testConv2dShapeInvalidRanks(%arg0: tensor<1x112x80xf32>, %arg1: tensor<128x3x3x128xf32>, %arg2: tensor<128xf32>) -> tensor<1x?x?x128xf32> { + // expected-error @+2 {{'tfl.conv_2d' op failed to infer returned types}} // expected-error @+1 {{Invalid ranks}} %0 = "tfl.conv_2d"(%arg0, %arg1, %arg2) {dilation_h_factor = 2 : i32, dilation_w_factor = 2 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x112x80xf32>, tensor<128x3x3x128xf32>, tensor<128xf32>) -> tensor<1x?x?x128xf32> func.return %0 : tensor<1x?x?x128xf32> diff --git a/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc b/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc index e7a613b3184..86dbe9c513e 100644 --- a/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc +++ b/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc @@ -24,6 +24,7 @@ limitations under the License. #include #include "absl/types/span.h" +#include "llvm/Support/Casting.h" #include "llvm/Support/raw_ostream.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project @@ -44,10 +45,12 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/transforms.h" #include "tensorflow/compiler/mlir/lite/tf_tfl_passes.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/quantize_passes.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/quantize_passes.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/quantize_preprocess.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_freeze_variables.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_passes.h" @@ -56,6 +59,8 @@ limitations under the License. #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_def.pb.h" #include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/lite/python/metrics/converter_error_data.pb.h" #include "tensorflow/lite/tools/optimize/quantize_weights.h" #include "tensorflow/lite/tools/optimize/reduced_precision_support.h" #include "tensorflow/tsl/platform/statusor.h" @@ -93,6 +98,25 @@ mlir::LogicalResult IsValidGraph(mlir::ModuleOp module) { return mlir::success(); } +mlir::LogicalResult GraphContainsStatefulPartitionedOp(mlir::ModuleOp module) { + auto result = module.walk([&](Operation* op) { + return llvm::isa_and_nonnull(op) + ? mlir::WalkResult::interrupt() + : mlir::WalkResult::advance(); + }); + if (result.wasInterrupted()) { + // StatefulPartitionedCall ops are not supported by the tflite runtime. 
+ mlir::TFL::AttachErrorCode( + module.emitError( + "The Graph contains unsupported `StatefulPartionedCallOp`(s), will " + "retry with `guarantee_all_funcs_used_once`"), + tflite::metrics::ConverterErrorData:: + ERROR_STATEFUL_PARTITIONED_CALL_IN_FINAL_IR); + return mlir::failure(); + } + return mlir::success(); +} + // Util that registers 'extra_tf_opdefs' to the TF global registry. // Return OK on success, failure if registering failed. Status RegisterExtraTfOpDefs(absl::Span extra_tf_opdefs) { @@ -143,17 +167,19 @@ StatusOr> LoadFromGraphdefOrMlirSource( if (use_splatted_constant) { return tensorflow::GraphdefToSplattedMlirTranslateFunction( - file->getBuffer(), debug_info_file, input_arrays, input_dtypes, - input_shapes, output_arrays, control_output_arrays, - specs.prune_unused_nodes, /*convert_legacy_fed_inputs=*/true, + file->getBuffer(), debug_info_file, /*xla_compile_device_type=*/"", + input_arrays, input_dtypes, input_shapes, output_arrays, + control_output_arrays, specs.prune_unused_nodes, + /*convert_legacy_fed_inputs=*/true, /*graph_as_function=*/false, specs.upgrade_legacy, /*enable_shape_inference=*/false, /*unconditionally_use_set_output_shapes=*/true, context); } return tensorflow::GraphdefToMlirTranslateFunction( - file->getBuffer(), debug_info_file, input_arrays, input_dtypes, - input_shapes, output_arrays, control_output_arrays, - specs.prune_unused_nodes, /*convert_legacy_fed_inputs=*/true, + file->getBuffer(), debug_info_file, /*xla_compile_device_type=*/"", + input_arrays, input_dtypes, input_shapes, output_arrays, + control_output_arrays, specs.prune_unused_nodes, + /*convert_legacy_fed_inputs=*/true, /*graph_as_function=*/false, specs.upgrade_legacy, /*enable_shape_inference=*/false, /*unconditionally_use_set_output_shapes=*/true, context); @@ -212,20 +238,25 @@ Status ConvertTFExecutorToStablehloFlatbuffer( return errors::Aborted("Failed to preprocess & freeze TF graph"); } - // The default minimum number of elements a weights array must have to be - // quantized by this transformation. - const int kWeightsMinNumElementsDefault = 1024; + // TODO(b/264218457): Refactor the component below once StableHLO Quantizer + // can run DRQ. Temporarily using TF Quantization for StableHLO DRQ. + if (!toco_flags.has_quantization_options()) { + // The default minimum number of elements a weights array must have to be + // quantized by this transformation. 
+ const int kWeightsMinNumElementsDefault = 1024; - tensorflow::quantization::QuantizationOptions quantization_options; + tensorflow::quantization::QuantizationOptions quantization_options; - quantization_options.mutable_quantization_method()->set_experimental_method( - tensorflow::quantization::QuantizationMethod::DYNAMIC_RANGE); - quantization_options.set_op_set( - tensorflow::quantization::UNIFORM_QUANTIZED); - quantization_options.set_min_num_elements_for_weights( - kWeightsMinNumElementsDefault); - tensorflow::quantization::AddQuantizePtqDynamicRangePasses( - pass_manager, quantization_options); + quantization_options.mutable_quantization_method() + ->set_experimental_method( + tensorflow::quantization::QuantizationMethod::DYNAMIC_RANGE); + quantization_options.set_op_set( + tensorflow::quantization::UNIFORM_QUANTIZED); + quantization_options.set_min_num_elements_for_weights( + kWeightsMinNumElementsDefault); + tensorflow::quantization::AddQuantizePtqDynamicRangePasses( + pass_manager, quantization_options); + } if (failed(pass_manager.run(module))) { return statusHandler.ConsumeStatus(); } @@ -237,6 +268,10 @@ Status ConvertTFExecutorToStablehloFlatbuffer( // Print out a detailed report of non-converted stats. pass_manager.addPass(mlir::odml::createPrintOpStatsPass()); mlir::odml::AddStablehloOptimizationPasses(pass_manager); + if (toco_flags.has_quantization_options()) { + stablehlo::quantization::AddQuantizationPasses( + pass_manager, toco_flags.quantization_options()); + } if (failed(pass_manager.run(module))) { return statusHandler.ConsumeStatus(); } @@ -285,7 +320,10 @@ Status ConvertTFExecutorToTFLOrFlatbuffer( mlir::PassManager pass_manager(module.getContext()); mlir::registerPassManagerCLOptions(); - mlir::applyPassManagerCLOptions(pass_manager); + if (mlir::failed(mlir::applyPassManagerCLOptions(pass_manager))) { + return tensorflow::FromAbslStatus( + absl::UnknownError("failed to apply MLIR pass manager CL options")); + } pass_manager.addInstrumentation( std::make_unique( pass_manager.getContext())); @@ -345,6 +383,10 @@ Status ConvertTFExecutorToTFLOrFlatbuffer( return status; } + if (failed(GraphContainsStatefulPartitionedOp(module))) { + return statusHandler.ConsumeStatus(); + } + if (export_to_mlir) { llvm::raw_string_ostream os(*result); module.print(os); diff --git a/tensorflow/compiler/mlir/lite/transforms/default_quant_params.cc b/tensorflow/compiler/mlir/lite/transforms/default_quant_params.cc index 76a8de86dc3..bf4224c7631 100644 --- a/tensorflow/compiler/mlir/lite/transforms/default_quant_params.cc +++ b/tensorflow/compiler/mlir/lite/transforms/default_quant_params.cc @@ -116,8 +116,7 @@ void DefaultQuantParamsPass::runOnOperation() { } func.walk([&](Operation *op) { - if (quant::IsOpNotQuantizable(op) || - op->getParentOfType()) { + if (!quant::IsOpQuantizable(op) || op->getParentOfType()) { return; } diff --git a/tensorflow/compiler/mlir/lite/transforms/dilated_conv.h b/tensorflow/compiler/mlir/lite/transforms/dilated_conv.h index 27c4763f179..51068fcf4ac 100644 --- a/tensorflow/compiler/mlir/lite/transforms/dilated_conv.h +++ b/tensorflow/compiler/mlir/lite/transforms/dilated_conv.h @@ -368,7 +368,7 @@ LogicalResult ConvertTFDilatedConvOp::matchAndRewrite( "SpaceToBatchND op's padding doesn't have same shape/type with " "BatchToSpaceND op's crops"); } - int64_t m = stb_paddings_attr.getType().getDimSize(0); + int64_t m = stb_paddings_attr.getShapedType().getDimSize(0); // padding - crop. 
for (uint64_t i = 0; i < m; ++i) { for (uint64_t j = 0; j < 2; ++j) { diff --git a/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td b/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td index a8f0ea36135..a020a4be43a 100644 --- a/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td +++ b/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td @@ -21,6 +21,7 @@ include "mlir/Dialect/Arith/IR/ArithOps.td" include "mlir/Dialect/Func/IR/FuncOps.td" include "tensorflow/compiler/mlir/lite/ir/tfl_ops.td" include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td" +include "tensorflow/compiler/mlir/lite/utils/utils.td" def CreateEmptyBoolAttr : NativeCodeCall<"::mlir::BoolAttr()">; @@ -29,10 +30,10 @@ def DenseElementsAttr : ElementsAttrBase< "non-opaque constant tensor">; def F32ElementsAttr : ElementsAttrBase< - CPred<"$_self.cast().getType().getElementType().isF32()">, "float constant tensor">; + CPred<"$_self.cast().getShapedType().getElementType().isF32()">, "float constant tensor">; def Int64ElementsAttr : ElementsAttrBase< - CPred<"$_self.cast().getType().getElementType().isInteger(64)">, "Int 64 constant tensor">; + CPred<"$_self.cast().getShapedType().getElementType().isInteger(64)">, "Int 64 constant tensor">; // Extract the ith int element from an ArrayAttr $0 as an 32-bit IntegerAttr // with builder. @@ -65,11 +66,6 @@ def ExtractSingleElementAsInt32 : NativeCodeCall< def CreateTFCastToInt32Op : NativeCodeCall< "CreateCastToInt32($0, $_loc, $_builder)">; -// Checks whether the given operation has static shapes and same shapes of all inputs. -def HasSameStaticShapesPred : CPred<"HasSameStaticShapes($0.getDefiningOp())">; -def HasSameStaticShapes : Constraint; -def HasNotSameStaticShapes : Constraint, "op must have not static same input shapes">; - def CreateNoneValue : NativeCodeCall< "$_builder.create($0.getLoc(), $_builder.getUnitAttr())">; @@ -234,11 +230,11 @@ def LegalizeSelect : Pat<(TF_SelectOp $cond, $x, $y), (TFL_SelectOp $cond, $x, $y)>; def LegalizeSelectV2SameStaticShape : Pat<(TF_SelectV2Op:$src_op $cond, $x, $y), (TFL_SelectOp $cond, $x, $y), - [(HasSameStaticShapes $src_op)]>; + [(OpHasSameStaticShapes $src_op)]>; def LegalizeSelectV2NotSameStaticShape : Pat< (TF_SelectV2Op:$src_op $cond, $x, $y), (TFL_SelectV2Op $cond, $x, $y), - [(HasNotSameStaticShapes $src_op)]>; + [(OpHasNotSameStaticShapes $src_op)]>; def LegalizeShape : Pat<(TF_ShapeOp $arg), (TFL_ShapeOp $arg)>; def LegalizeSigmoid : Pat<(TF_SigmoidOp $arg), (TFL_LogisticOp $arg)>; def LegalizeSin : Pat<(TF_SinOp F32Tensor:$arg), (TFL_SinOp $arg)>; @@ -577,6 +573,14 @@ def LegalizeAtan2 : Pat<(TF_Atan2Op $y, $x), (TFL_Atan2Op $y, $x)>; def LegalizeSign : Pat<(TF_SignOp $x), (TFL_SignOp $x)>; +def LegalizeBitcast : Pat<(TF_BitcastOp $x), (TFL_BitcastOp $x)>; + +def LegalizeBitwiseXor : Pat<(TF_BitwiseXorOp $l, $r), + (TFL_BitwiseXorOp $l, $r)>; + +def LegalizeRightShift : Pat<(TF_RightShiftOp $l, $r), + (TFL_RightShiftOp $l, $r)>; + // ============================================================================= // Training OPs // ============================================================================= diff --git a/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc index 555a879c956..2a80ef14cc4 100644 --- a/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc +++ b/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc @@ -92,28 +92,6 @@ class LegalizeTFPass : public impl::LegalizeTFPassBase { void 
runOnOperation() override; }; -// Returns true if all tensor value in `values` has static shape and same shape. -bool HasSameStaticShapes(Operation* op) { - auto values = op->getOperands(); - int index = 0; - ArrayRef shape; - for (Value value : values) { - auto shaped_type = value.getType().dyn_cast(); - if (!shaped_type || !shaped_type.hasStaticShape()) { - return false; - } - if (index == 0) { - shape = shaped_type.getShape(); - } else { - if (shape != shaped_type.getShape()) { - return false; - } - } - ++index; - } - return true; -} - // Util that casts 'val' to Int32 by adding a cast Op. Value CreateCastToInt32(Value val, Location loc, PatternRewriter& rewriter) { IntegerType new_ele_type = rewriter.getIntegerType(32); diff --git a/tensorflow/compiler/mlir/lite/transforms/lift_tflite_flex_ops.cc b/tensorflow/compiler/mlir/lite/transforms/lift_tflite_flex_ops.cc index d939d74c5dd..91abd715cdf 100644 --- a/tensorflow/compiler/mlir/lite/transforms/lift_tflite_flex_ops.cc +++ b/tensorflow/compiler/mlir/lite/transforms/lift_tflite_flex_ops.cc @@ -230,7 +230,7 @@ class LiftFlexCustomOp : public OpRewritePattern { StatusOr mlir_attr = tensorflow::ConvertAttributeValue(attr_value, &builder); if (!mlir_attr.ok()) { - return emitError(loc, mlir_attr.status().error_message()); + return emitError(loc, mlir_attr.status().message()); } attributes.push_back(builder.getNamedAttr(attr_name, *mlir_attr)); } diff --git a/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc b/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc index 5be3dcac0e0..ae09b8faded 100644 --- a/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc +++ b/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc @@ -632,7 +632,7 @@ struct ConvertTensorListInitOp : public TensorListOpConverterBase { // as specified by element_dtype. RankedTensorType zero_type = tensorflow::GetTypeFromTFTensorShape({}, element_dtype); - Attribute zero_attr = rewriter.getZeroAttr(zero_type); + auto zero_attr = rewriter.getZeroAttr(zero_type); auto zero = rewriter.create(loc, zero_type, zero_attr); rewriter.replaceOpWithNewOp(op, result_type, list_shape, zero); diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize.cc b/tensorflow/compiler/mlir/lite/transforms/optimize.cc index a7e404233ce..9b74c6bf606 100644 --- a/tensorflow/compiler/mlir/lite/transforms/optimize.cc +++ b/tensorflow/compiler/mlir/lite/transforms/optimize.cc @@ -17,14 +17,17 @@ limitations under the License. // optimizes them to resulting operations in TensorFlowLite dialect. #include +#include #include #include #include #include #include +#include #include #include #include +#include #include "llvm/ADT/APFloat.h" #include "llvm/ADT/APInt.h" @@ -40,6 +43,7 @@ limitations under the License. 
#include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/Matchers.h" // from @llvm-project @@ -70,6 +74,14 @@ constexpr char kRelu[] = "RELU"; constexpr char kRelu6[] = "RELU6"; constexpr char kRelu1[] = "RELU_N1_TO_1"; +ElementsAttr FlattenTo1D(Attribute a) { + auto elements = a.cast(); + const std::array flattened_shape = {elements.getNumElements()}; + auto new_type = RankedTensorType::get(flattened_shape, + elements.getType().getElementType()); + return elements.reshape(new_type); +} + bool L2NormalizeReduceAxis(Value sq_op, DenseElementsAttr axis) { if (axis.getNumElements() == 0) { return false; @@ -138,14 +150,49 @@ bool IsTailOfShape(Type type1, Type type2) { return std::equal(i1, e1, i2); } +// This function removes explicit broadcasting on type1 and returns whether if +// the reduced `type1` dimensions are the same as the ending dimensions +// of `type2`. +bool IsReducedTailOfShape(Type type1, Type type2) { + auto tail_type = type1.dyn_cast(); + auto full_type = type2.dyn_cast(); + if (!tail_type || !full_type || !tail_type.hasRank() || !full_type.hasRank()) + return false; + + auto i1 = tail_type.getShape().rbegin(); + auto reduced_e1 = tail_type.getShape().rend(); + auto i2 = full_type.getShape().rbegin(); + + while ((std::distance(i1, reduced_e1) > 0) && (*(reduced_e1 - 1) == 1)) { + reduced_e1--; + } + + return (std::distance(i1, reduced_e1) > 0) && + (std::distance(i1, reduced_e1) <= full_type.getRank()) && + (std::equal(i1, reduced_e1, i2)); +} + +// Check if the value of the last dimension of type1 is equal to the number of +// elements in type2. This is a required condition to flatten type2 to form a +// 1D array and allow the binaryOp handle the broadcasting implicitly. +bool IsLastDimEqualToNumElements(Type type1, Type type2) { + return (type1.cast().getRank() >= 1 && + type1.cast().getDimSize( + type1.cast().getRank() - 1) == + type2.cast().getNumElements()); +} + bool CanFuseConvOrDepthwiseConvShapes(const ArrayRef filter_shape, const ArrayRef elements_shape, bool is_depthwise) { - // Also, val tensor must be of rank 1 or 0 (scalar). - const auto elements_rank = elements_shape.size(); - if (elements_rank != 1 && elements_rank != 0) { - return false; + // Val tensor must be a scalar or of a shape [1, ... , 1, elements_depth]. + const int elements_rank = elements_shape.size(); + for (int i = 0; i < elements_rank - 1; ++i) { + if (elements_shape[i] != 1) { + return false; + } } + auto elements_depth = elements_shape.empty() ? 1 : elements_shape.back(); // If elements depth equals 1 (i.e., scalar or tensor with 1 element), then we // can let binary op to broadcast elements. @@ -313,16 +360,6 @@ DenseElementsAttr GetShape(Value output_val) { llvm::ArrayRef(shape)); } -static Type GetShapeStrippedType(TypeAttr type_attr) { - auto type = type_attr.getValue(); - auto shaped_type = type.dyn_cast(); - if (shaped_type) { - return shaped_type.getElementType(); - } else { - return type; - } -} - // Returns `true` if reducing `axes` in `input` with `keep_dims=true` results in // the specified `shape` and `false` otherwise. 
static bool ShapeMatchesReduceWithKeepAxes(Value input, @@ -396,65 +433,63 @@ static bool FloatValueEquals(const Attribute &attr, double value) { }); } +// Returns true if `value` is compile-time constant and its splat value equals +// to `raw_value`. +template +bool IsConstantValueOf(mlir::TypedAttr value, T raw_value) { + auto element_type = value.getType().cast().getElementType(); + + if (element_type.isa()) { + return FloatValueEquals(value, raw_value); + } else if (element_type.isa()) { + auto int_attr = value.dyn_cast_or_null(); + if (!int_attr) return false; + + if (int_attr.isSplat()) { + return int_attr.getSplatValue() == raw_value; + } + return llvm::all_of(int_attr.getValues(), + [raw_value](const APInt &f) { return f == raw_value; }); + } + + return false; +} + // Returns true if the value's element type is F32. bool IsF32Value(Value value) { return value.getType().cast().getElementType().isF32(); } -// Returns the number of elements in attr if it is a DenseElementsAttr, 1 -// otherwise, as an unranked int32 Attribute. -Attribute GetNumElementsOrOne(Attribute attr) { - const auto dense_attr = attr.dyn_cast_or_null(); - int32_t num_elements = dense_attr ? dense_attr.getNumElements() : 1; +// Returns the number of elements in attr if it is a static shape, 1 otherwise, +// as an unranked int32 Attribute. +TypedAttr GetNumElementsOrOne(Type type) { + auto shaped_type = type.cast(); + int32_t num_elements = + shaped_type.hasStaticShape() ? shaped_type.getNumElements() : 1; - OpBuilder builder(attr.getContext()); + OpBuilder builder(type.getContext()); return DenseIntElementsAttr::get( RankedTensorType::get({}, builder.getI32Type()), {llvm::APInt(32, num_elements, true)}); } -bool HasExactlyTwoElements(Attribute attr) { - const auto values = attr.dyn_cast_or_null(); - if (!values) return false; - return values.getNumElements() == 2; -} - -// Returns true if attr is a DenseIntElementsAttr with the last element equal 1. -bool IsLastElementEqualsOne(Attribute attr) { - const auto ints = attr.dyn_cast_or_null(); - if (!ints) return false; - if (ints.empty()) return false; - const auto last_element_index = ints.getNumElements() - 1; - const auto iterator = ints.value_begin(); - const int last_element = iterator[last_element_index]; - return last_element == 1; -} - // Reshapes value to a given shape. -Value ReshapeValueDroppingLastDim(OpBuilder &builder, Value value, - Attribute shape) { - // This function is always guarded with IsLastElementEqualsOne(), so we could - // cast safely here. - const auto old_shape = shape.cast(); - auto iterator = old_shape.value_begin(); - SmallVector new_shape; - SmallVector new_shape_i64; - for (int i = 0; i < old_shape.size() - 1; ++i) { - new_shape.push_back(*iterator); - new_shape_i64.push_back(*iterator); - ++iterator; +Value ReshapeValueDroppingLastDim(OpBuilder &builder, Value value) { + // This function is always guarded with HasTrivialShapeExceptSecondLastDim(), + // so we could cast safely here. 
+ auto type = value.getType().cast(); + SmallVector new_shape; + for (int64_t dim : type.getShape().drop_back()) { + new_shape.push_back(dim); } return builder.create( - value.getLoc(), - RankedTensorType::get( - new_shape_i64, value.getType().cast().getElementType()), - value, + value.getLoc(), value, builder.create( - value.getLoc(), DenseIntElementsAttr::get( - RankedTensorType::get({old_shape.size() - 1}, - builder.getI32Type()), - new_shape))); + value.getLoc(), + DenseIntElementsAttr::get( + RankedTensorType::get(type.getRank() - 1, builder.getI32Type()), + new_shape))); } // Returns true if val has a static shape and the last dimension equals 1. @@ -467,6 +502,27 @@ bool IsLastDimensionEqualOne(Value val) { return last_element == 1; } +// Returns true if the supplied value- +// 1) Has only one use or +// 2) Is only used by binary op like AddOp, SubOp, MulOp or DivOp. +bool HasOneUseOrUsedByOnlyBinaryOps(Value out_value) { + if (out_value.hasOneUse()) { + return true; + } + + for (auto &use : out_value.getUses()) { + mlir::Operation *owner = use.getOwner(); + if (!llvm::isa(owner) && + !llvm::isa(owner) && + !llvm::isa(owner) && + !llvm::isa(owner)) { + return false; + } + } + + return true; +} + // Returns true if attr is a DenseIntElementsAttr of int32 or int64 values or an // incrementing sequence from 0 to N-1. // @@ -481,7 +537,10 @@ bool IsOneHotIndexAttribute(Attribute attr) { if (index_elem_bits != 32 && index_elem_bits != 64) { return false; } - if (index_type.getRank() != 1) { + // Checks that the index has shape of [1, 1, 1, ..., 1, N]. + if (index_type.getRank() < 1 || + llvm::any_of(index_type.getShape().drop_back(), + [](int64_t dim) { return dim != 1; })) { return false; } const auto elems = dense_attr.value_begin(); @@ -493,6 +552,32 @@ bool IsOneHotIndexAttribute(Attribute attr) { return true; } +Value Get1DShapeValue(OpBuilder &builder, Value value) { + auto type = value.getType().cast(); + if (!type.hasStaticShape()) { + return nullptr; + } + auto output_type = RankedTensorType::get({1}, builder.getI32Type()); + const int num_elements = type.getNumElements(); + return builder.create( + value.getLoc(), output_type, + DenseIntElementsAttr::get(output_type, num_elements)); +} + +Type GetEmbeddingLookupShape(Value lookup, Value value) { + auto lookup_type = lookup.getType().cast(); + if (!lookup_type.hasStaticShape()) { + return nullptr; + } + auto value_type = value.getType().cast(); + if (!value_type.hasStaticShape() || value_type.getRank() != 2) { + return nullptr; + } + SmallVector new_shape = {lookup_type.getNumElements(), + value_type.getDimSize(0)}; + return value_type.clone(new_shape); +} + // Creates FullyConnected op from params and returns the output. mlir::Value GetFcOutput(OpBuilder *builder, ::mlir::Operation::result_range result, Value input, @@ -521,7 +606,7 @@ bool AllValuesAreZero(mlir::Value value) { // Converts an Attribute with a single value of float or integral type to an // Attribute holding a single value of float type. If attr has no elements, the // result is 0.0f. 
-Attribute ConvertSingleElementAttrToFloatAttr(Attribute attr) { +TypedAttr ConvertSingleElementAttrToFloatAttr(Attribute attr) { const auto dense_fp_attr = attr.dyn_cast_or_null(); if (dense_fp_attr) { // Already float => return @@ -578,8 +663,6 @@ struct FuseFullyConnectedAndAdd : public OpRewritePattern { bool is_scalar_rhs = false; if (constant_val_type.getRank() == 0) { is_scalar_rhs = true; - } else if (constant_val_type.getRank() != 1) { - return failure(); } Value filter = fc_op.getFilter(); @@ -593,7 +676,18 @@ struct FuseFullyConnectedAndAdd : public OpRewritePattern { // Rewrite if (is_none_bias) { - if (is_scalar_rhs) { + if (constant_val_type.getRank() == 1) { + // If there no pre-existing bias and the `constant_val` is 1D, simply + // use `constant_val` as bias. + bias = constant_val; + } else { + if (!is_scalar_rhs && + !(IsReducedTailOfShape(constant_val.getType(), filter.getType()) && + IsLastDimEqualToNumElements(filter.getType(), + constant_val.getType()))) { + return failure(); + } + // If the `constant_val` is scalar, we must the shape of filter // to properly broadcast the scalar to `{num_channels}` shape. @@ -606,29 +700,48 @@ struct FuseFullyConnectedAndAdd : public OpRewritePattern { } int num_channels = filter_type.getShape()[0]; - // Create a zero tensor with shape {num_channels}, and the type need to - // be the same as constant_val. - // This is a way to gracefully handle scalar tensor. The Add will always - // be constant-folded away regardless if `constant_val` is a scalar or - // not. + // Create a zero tensor with shape {num_channels}, and the type need + // to be the same as constant_val. This is a way to gracefully handle + // scalar tensor. The Add will always be constant-folded away + // regardless if `constant_val` is a scalar or not. RankedTensorType type = RankedTensorType::get( {num_channels}, constant_val_type.getElementType()); auto attr = rewriter.getZeroAttr(type); bias = rewriter.create(add_op.getLoc(), type, attr); auto none_af = rewriter.getStringAttr("NONE"); - bias = - rewriter.create(add_op.getLoc(), bias, constant_val, none_af) - .getOutput(); - } else { - // If there no pre-existing bias and the `constant_val` is 1D, simply - // use `constant_val` as bias. 
- bias = constant_val; + if (is_scalar_rhs) { + bias = + rewriter + .create(add_op.getLoc(), bias, constant_val, none_af) + .getOutput(); + } else { + // If the RHS is neither a scalar constant nor a 1d constant, look + // if there is opportunity to reduce the dimensionality and allow + // implicit broadcasting + + auto new_added_value = added_value.reshape(RankedTensorType::get( + {added_value.getType().cast().getNumElements()}, + added_value.getType().cast().getElementType())); + + ::mlir::arith::ConstantOp new_constant_val = + rewriter.create<::mlir::arith::ConstantOp>( + add_op.getLoc(), + /*value=*/new_added_value); + + bias = rewriter + .create<::mlir::TFL::AddOp>( + add_op.getLoc(), + /*lhs=*/bias, + /*rhs=*/new_constant_val.getResult(), + /*fused_activation_function=*/none_af) + .getOutput(); + } } } else { - auto none_af = rewriter.getStringAttr("NONE"); - bias = - rewriter.create(add_op.getLoc(), bias, constant_val, none_af) - .getOutput(); + bias = rewriter + .create(add_op.getLoc(), bias, constant_val, + rewriter.getStringAttr("NONE")) + .getOutput(); } auto fc = rewriter.create( @@ -641,7 +754,8 @@ struct FuseFullyConnectedAndAdd : public OpRewritePattern { rewriter.getStringAttr(add_op.getFusedActivationFunction()), /*weights_format=*/rewriter.getStringAttr(fc_op.getWeightsFormat()), /*keep_num_dims=*/rewriter.getBoolAttr(fc_op.getKeepNumDims()), - /*asymmetric_quantize_inputs=*/fc_op.getAsymmetricQuantizeInputsAttr()); + /*asymmetric_quantize_inputs=*/ + fc_op.getAsymmetricQuantizeInputsAttr()); rewriter.replaceOp(add_op, fc.getOutput()); return success(); @@ -1488,7 +1602,7 @@ struct FuseUnpackAndConcatToReshape if (!unpack_op || unpack_op.getNumResults() != concat_op.getNumOperands()) { return failure(); } - for (auto &index_and_value : llvm::enumerate(concat_op.getValues())) { + for (const auto &index_and_value : llvm::enumerate(concat_op.getValues())) { if (index_and_value.value() != unpack_op.getResult(index_and_value.index())) { return failure(); diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td b/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td index e772f9cb88d..216ac15c034 100644 --- a/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td +++ b/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td @@ -25,12 +25,12 @@ include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td" // Checks if the param passed is a F32 ElementsAttr. def F32ElementsAttr : ElementsAttrBase< - CPred<"$_self.isa() && $_self.cast().getType().getElementType().isF32()">, + CPred<"$_self.isa() && $_self.cast().getShapedType().getElementType().isF32()">, "32 bit float constant tensor">; // Checks if the param passed is a float ElementsAttr. def FloatElementsAttr : ElementsAttrBase< - CPred<"$_self.isa() && $_self.cast().getType().getElementType().isa()">, + CPred<"$_self.isa() && $_self.cast().getShapedType().getElementType().isa()">, "float constant tensor">; // Checks if the param passed is of NoneType. @@ -44,11 +44,28 @@ class HasRankAtMost : Constraint< CPred<"$0.getType().cast().hasRank() && " "$0.getType().cast().getRank() <= " # n>>; +// Checks if the value has rank at least 'n'. +class HasRankAtLeast : Constraint< + CPred<"$0.getType().cast().hasRank() && " + "$0.getType().cast().getRank() >= " # n>>; + // Checks if the value has rank 'n'. class HasRank : Constraint< CPred<"$0.getType().cast().hasRank() && " "$0.getType().cast().getRank() == " # n>>; +// Flattens a constant tensor to 1D.
+def FlattenTo1D : NativeCodeCall<"FlattenTo1D($0)">; + +def HasOneUse : Constraint>; + +def HasSameStaticShapes : Constraint< + CPred<"$0.getType().cast().hasStaticShape() && " + "$1.getType().cast().hasStaticShape() && " + "$0.getType().cast().getShape() ==" + "$1.getType().cast().getShape()">, + "have the same static shape">; + //===----------------------------------------------------------------------===// // Ternary ops patterns. //===----------------------------------------------------------------------===// @@ -111,7 +128,7 @@ multiclass FuseBinaryOpToPrecedingAffine { (Arith_ConstantOp FloatElementsAttr:$value), $act_fn), (TFL_Conv2DOp $input, $filter, (binaryOp (Arith_ConstantOp $bias), - (Arith_ConstantOp $value), TFL_AF_None), + (Arith_ConstantOp (FlattenTo1D $value)), TFL_AF_None), $h_factor, $w_factor, $act_fn, $padding, $stride_h, $stride_w), [(CanFuseConvOrDepthwiseConv<"false"> $filter, $value), (HasOneUse $output)]>; @@ -122,11 +139,11 @@ multiclass FuseBinaryOpToPrecedingAffine { $stride_w, $multiplier), (Arith_ConstantOp FloatElementsAttr:$value), $act_fn), (TFL_DepthwiseConv2DOp $input, $filter, - (binaryOp (Arith_ConstantOp $bias), (Arith_ConstantOp $value), TFL_AF_None), + (binaryOp (Arith_ConstantOp $bias), + (Arith_ConstantOp (FlattenTo1D $value)), TFL_AF_None), $h_factor, $w_factor, $act_fn, $padding, $stride_h, $stride_w, $multiplier), [(CanFuseConvOrDepthwiseConv<"true"> $filter, $value), - (HasRank<1> $value), (HasOneUse $output)]>; def FuseBinaryOpWithTransposeConv#binaryOp : Pat< (binaryOp (TFL_TransposeConvOp:$output $output_shape, $weights, $input, @@ -181,12 +198,11 @@ multiclass FuseMulOrDivWithConv2dOrDepthwiseConv2d { TFL_AF_None), (BinaryOp (Arith_ConstantOp $bias), - (Arith_ConstantOp $value), + (Arith_ConstantOp (FlattenTo1D $value)), TFL_AF_None), $h_factor, $w_factor, $act_fn, $padding, $stride_h, $stride_w, $multiplier), [(CanFuseConvOrDepthwiseConv<"true"> $filter, $value), - (HasRank<1> $value), (HasOneUse $output)]>; def FuseMulOrDivWithConv#BinaryOp : Pat< (BinaryOp (TFL_Conv2DOp:$conv_output $input, @@ -200,7 +216,7 @@ multiclass FuseMulOrDivWithConv2dOrDepthwiseConv2d { (Arith_ConstantOp (ExpandTo4DForConv $value)), TFL_AF_None), (BinaryOp (Arith_ConstantOp $bias), - (Arith_ConstantOp $value), + (Arith_ConstantOp (FlattenTo1D $value)), TFL_AF_None), $h_factor, $w_factor, $act_fn, $padding, $stride_h, $stride_w), [(CanFuseConvOrDepthwiseConv<"false"> $filter, $value), @@ -395,19 +411,30 @@ def OperandsBroadcastToOutputType : Constraint>; +def IsReducedTailOfShape : Constraint>; + +def IsRankLessThanEqualTo : Constraint().getRank() <= " + "$1.getType().cast().getRank()">>; + def Flatten : NativeCodeCall< "$0.cast()" ".reshape(RankedTensorType::get({$0.getType().cast().getNumElements()}, " "$0.getType().cast().getElementType()))">; def IsLastDimEqualToNumElements : Constraint().getRank() >= 1 && " - "$0.getType().cast().getDimSize($0.getType().cast().getRank() - 1) == " - "$1.getType().cast().getNumElements()">>; + "TFL::IsLastDimEqualToNumElements($0.getType(), $1.getType())">>; def IsDefinedByFullyConnectedOp : Constraint() != nullptr">>; +// Returns true if the supplied value- +// 1) Has only one use or +// 2) Is only used by binary op like AddOp, SubOp, MulOp or DivOp. +def HasOneUseOrUsedByOnlyBinaryOps : Constraint>; + // Pattern for skipping Tile if it is mainly for broadcasting and the // Op is already supporting broadcasting. 
multiclass FuseTileBroadcastIntoFollowingBinary { @@ -475,43 +502,94 @@ foreach BinaryOp = [TFL_AddOp, TFL_SubOp, TFL_DivOp, TFL_MulOp] in { (HasRankAtMost<4> $rhs), (SameElementType $input, $rhs)]>; - // Move binary op before reshape: - // binary(reshape(lhs), reshape(rhs)) => reshape(binary(lhs, rhs)) - // This is valid only when both side of the binary operand is reshaped, and - // the sizes are the same both before and after the reshape. - def MoveBinaryOpBeforeReshape#BinaryOp : Pat< - (BinaryOp (TFL_ReshapeOp:$lhs $input1, (Arith_ConstantOp:$shape1 $s1)), - (TFL_ReshapeOp:$rhs $input2, (Arith_ConstantOp:$shape2 $s2)), - $act_fn), - (TFL_ReshapeOp (BinaryOp $input1, $input2, $act_fn), $shape1), - [(IsTailOfShape $rhs, $lhs), - (IsTailOfShape $lhs, $rhs), - (IsTailOfShape $input1, $input2), - (IsTailOfShape $input2, $input1), - (SameElementType $input1, $input2)]>; + // Move binary op before reshape: + // binary(reshape(lhs), reshape(rhs)) => reshape(binary(lhs, rhs)) + // This is valid only when both side of the binary operand is reshaped, and + // the sizes are the same both before and after the reshape. + def MoveBinaryOpBeforeReshape#BinaryOp : Pat< + (BinaryOp (TFL_ReshapeOp:$lhs $input1, (Arith_ConstantOp:$shape1 $s1)), + (TFL_ReshapeOp:$rhs $input2, (Arith_ConstantOp:$shape2 $s2)), + $act_fn), + (TFL_ReshapeOp (BinaryOp $input1, $input2, $act_fn), $shape1), + [(IsTailOfShape $rhs, $lhs), + (IsTailOfShape $lhs, $rhs), + (IsTailOfShape $input1, $input2), + (IsTailOfShape $input2, $input1), + (SameElementType $input1, $input2)]>; - // Move binary op before reshape: - // binary(reshape(lhs), rhs) => reshape(binary(lhs, flatten(rhs))) - // This is valid only when the last dimension of lhs is equal to the - // number of elements in constant rhs. - // Therefore, after transformation broadcast of binary op is always - // applied to the last dimension of $input. - def MoveBinaryOpFlattenConstBeforeReshape#BinaryOp : Pat< - (BinaryOp (TFL_ReshapeOp:$lhs $input, (Arith_ConstantOp:$shape $s)), - (Arith_ConstantOp:$rhs ElementsAttr:$rhs_attr), $act_fn), - (TFL_ReshapeOp (BinaryOp $input, (Arith_ConstantOp (Flatten $rhs_attr)), - $act_fn), - $shape), - [(AnyStaticShapeTensor $input), - (IsTailOfShape $rhs, $lhs), - (IsLastDimEqualToNumElements $input, $rhs), - (HasOneUse $lhs), - // Restrict operands to have at most rank 4 because TFLite binary - // kernel supports up to 4D broadcast. - (HasRankAtMost<4> $input), - (HasRankAtMost<4> $lhs), - (HasRankAtMost<4> $rhs), - (IsDefinedByFullyConnectedOp $input)]>; + // Move binary op batched RHS before reshape: + // binary(reshape(lhs), rhs) => reshape(binary(lhs, flatten(rhs))) + // Pattern targeted here is as follows- + // [input, lhs, rhs] == [<1x1024x128>, <1x1024x8x16>, <1x1x8x16xf32>] + // This is valid only when the- + // 1.last dimension of lhs is equal to the number of elements in constant rhs. + // 2.Reduced shape of rhs, here <8x16> is equal to last dimensions of lhs. + // Therefore, after transformation broadcast of binary op is always + // applied to the last dimension of $input.
+ def MoveBinaryOpFlattenConstBeforeReshape#BinaryOp : Pat< + (BinaryOp (TFL_ReshapeOp:$lhs $input, (Arith_ConstantOp:$shape $s)), + (Arith_ConstantOp:$rhs ElementsAttr:$rhs_attr), $act_fn), + (TFL_ReshapeOp (BinaryOp $input, (Arith_ConstantOp (Flatten $rhs_attr)), + $act_fn), + $shape), + [(AnyStaticShapeTensor $input), + (IsReducedTailOfShape $rhs, $lhs), + (IsLastDimEqualToNumElements $input, $rhs), + (HasOneUse $lhs), + // Restrict operands to have at most rank 4 because TFLite binary + // kernel supports up to 4D broadcast. + (HasRankAtMost<4> $input), + (HasRankAtMost<4> $lhs), + (HasRankAtMost<4> $rhs), + (IsDefinedByFullyConnectedOp $input)]>; + + // Pattern to remove redundant reshape op used as LHS to binary ops + // Binary(Reshape(input, shape), rhs) -> Binary(input, rhs) + // This pattern is valid only if- + // 1. The shape is only adding broadcasting that can otherwise be implicitly + // handled by the binary op. Ex- shape == [1, 1, 1, 128] + // 2. The rank of the input to reshape is <= reshape output. + // 3. The rank of the output to reshape is <= binary rhs. + // The conditions 2 and 3 will make sure any required increase in + // dimensionality due to reshape op is not lost. + def RemoveRedundantReshapeUsedAsLhsTo#BinaryOp : Pat< + (BinaryOp (TFL_ReshapeOp:$lhs $input, (Arith_ConstantOp:$shape $s)), + $rhs, $act_fn), + (BinaryOp $input, $rhs, $act_fn), + [(AnyStaticShapeTensor $input), + (AnyStaticShapeTensor $rhs), + (IsRankLessThanEqualTo $input, $lhs), + (IsRankLessThanEqualTo $lhs, $rhs), + (IsReducedTailOfShape $lhs, $input), + (HasOneUseOrUsedByOnlyBinaryOps $lhs), + // Restrict operands to have at most rank 4 because TFLite binary + // kernel supports up to 4D broadcast. + (HasRankAtMost<4> $input), + (HasRankAtMost<4> $rhs)]>; + + // Pattern to remove redundant reshape op used as RHS to binary ops + // Binary(lhs, Reshape(input, shape)) -> Binary(lhs, input) + // This pattern is valid only if- + // 1. The shape is only adding broadcasting that can otherwise be implicitly + // handled by the binary op. Ex- shape == [1, 1, 1, 128] + // 2. The rank of the input to reshape is <= reshape output. + // 3. The rank of the output to reshape is <= binary lhs. + // The conditions 2 and 3 will make sure any required increase in + // dimensionality due to reshape op is not lost. + def RemoveRedundantReshapeUsedAsRhsTo#BinaryOp : Pat< + (BinaryOp $lhs, (TFL_ReshapeOp:$rhs $input, (Arith_ConstantOp:$shape $s)), + $act_fn), + (BinaryOp $lhs, $input, $act_fn), + [(AnyStaticShapeTensor $input), + (AnyStaticShapeTensor $lhs), + (IsRankLessThanEqualTo $input, $rhs), + (IsRankLessThanEqualTo $rhs, $lhs), + (IsReducedTailOfShape $rhs, $input), + (HasOneUseOrUsedByOnlyBinaryOps $rhs), + // Restrict operands to have at most rank 4 because TFLite binary + // kernel supports up to 4D broadcast. + (HasRankAtMost<4> $input), + (HasRankAtMost<4> $lhs)]>; } foreach BinaryOp = [TFL_FloorDivOp, TFL_FloorModOp, TFL_MinimumOp, @@ -620,12 +698,22 @@ def ConvertExpandDimsToReshape : Pat< class FloatValueEquals : Constraint>; +// Here, the element type can be any integer or float type.
+class IsConstantValueOf : Constraint>; + // ReLU patterns def MatchReluPattern : Pat< (TFL_MaximumOp $input, (Arith_ConstantOp $Zero)), (TFL_ReluOp $input), [(FloatValueEquals<"0"> $Zero)]>; +// Optimize Minimum of tf.Relu and constant six to tf.Relu6 +def MinimumOfReluAnd6ToRelu6 : + Pat<(TFL_MinimumOp (TFL_ReluOp $x), (Arith_ConstantOp $y)), + (TFL_Relu6Op $x), + [(IsConstantValueOf<6> $y)]>; + def MatchRelu1Pattern1 : Pat< (TFL_MinimumOp (TFL_MaximumOp $input, (Arith_ConstantOp $NegOne)), (Arith_ConstantOp $One)), @@ -855,6 +943,23 @@ foreach SelectOp = [TFL_SelectOp, TFL_SelectV2Op] in { def Optimize#SelectOp#Not : Pat< (SelectOp (TFL_LogicalNotOp $condition), $input1, $input2), (SelectOp $condition, $input2, $input1)>; + + // Fuse select(broadcast_to(input, shape), x, y) -> selectV2(input, x, y) + // Also, fuse selectv2(broadcast_to(input, shape), x, y) -> selectV2(input, x, y) + // It is safe to perform this transform here because- + // the shapes of `pre_broadcast` and `dim` must be broadcast + // compatible for the `broadcast_to` op to be valid. + // And considering, `shape(post_broadcast)` == `shape(%input1)`, + // `post_broadcast` is broadcast compatible with `input1`. + def FuseBroadcastInto#SelectOp : Pat< + (SelectOp + (TFL_BroadcastToOp:$post_broadcast AnyStaticShapeTensor:$pre_broadcast, $dim), + AnyStaticShapeTensor:$input1, AnyStaticShapeTensor:$input2), + (TFL_SelectV2Op $pre_broadcast, $input1, $input2), + [(HasSameStaticShapes $post_broadcast, $input1), + (HasRankAtMost<4> $post_broadcast), + (HasRankAtMost<4> $input1), + (HasRankAtMost<4> $input2)]>; } def EliminateLogicalAndTrue : Pat< @@ -914,38 +1019,33 @@ def OptimizeSliceOp : Pat< (replaceWithValue $input), [(CanOptimizeIdentitySliceOp $input, $begin, $size)]>; -def GetNumElementsOrOne: NativeCodeCall<"GetNumElementsOrOne($0)">; +def GetNumElementsOrOne: NativeCodeCall<"GetNumElementsOrOne($0.getType())">; def ReshapeValueDroppingLastDim : NativeCodeCall< - "ReshapeValueDroppingLastDim($_builder, $0, $1)">; - -def HasExactlyTwoElements : Constraint>; - -def IsLastElementEqualsOne : Constraint>; + "ReshapeValueDroppingLastDim($_builder, $0)">; def IsOneHotIndexAttribute : Constraint>; +// Checks if the shape has shape with last dimension equals 1. +def IsLastDimensionEqualOne : Constraint>; + // Replace -// Equal(Reshape(X, shape), indices) +// Equal(X, indices) // With -// OneHot(Reshape(X, shape[:-1]), N, true, false, -1) +// OneHot(Reshape(X), N, true, false, -1) // where -// - shape has length 2 (unnecessary, just to be conservative) -// - last value in shape is 1 +// - last dimension of the LHS of the equal is 1, and the rank is at least 2. // - indices is a incrementing series from 0 to N-1. (N elements total.) 
def ReshapeEqualOpToOneHotOp : Pat< - (TFL_EqualOp (TFL_ReshapeOp $x, (Arith_ConstantOp $shape)), - (Arith_ConstantOp $series)), - (TFL_OneHotOp (ReshapeValueDroppingLastDim $x, $shape), + (TFL_EqualOp $x, (Arith_ConstantOp $series)), + (TFL_OneHotOp (ReshapeValueDroppingLastDim $x), (Arith_ConstantOp (GetNumElementsOrOne $series)), (Arith_ConstantOp ConstantAttr, "true">), (Arith_ConstantOp ConstantAttr, "false">), ConstantAttr), - [(HasExactlyTwoElements $shape), - (IsLastElementEqualsOne $shape), + [(IsLastDimensionEqualOne $x), + (HasRankAtLeast<2> $x), (IsOneHotIndexAttribute $series)]>; def F32ElementsVal : Constraint; +def Get1DShapeValue: NativeCodeCall<"Get1DShapeValue($_builder, $0)">; + +class GetIthValue : NativeCodeCall<"$0[" # index # "]">; + +def GetEmbeddingLookupShape: NativeCodeCall<"GetEmbeddingLookupShape($0, $1)">; + // Replace // OneHot(index, depth, on=1.0f, off=0.0f, axis=-1) * filter // With @@ -996,26 +1102,29 @@ def FuseOneHotAndCastToFloat : Pat< // This is exactly what the EmbeddedLookup operator is doing, on the transposed // matrix, without doing any arithmetic but only memcpy. def ReplaceOneHotFullyConnectedWithLookup : Pat< - (TFL_FullyConnectedOp + (TFL_FullyConnectedOp:$outputs (TFL_OneHotOp - $indices, + AnyStaticShapeTensor:$indices, (Arith_ConstantOp $depth), (Arith_ConstantOp ConstantAttr, "1.0f">), (Arith_ConstantOp ConstantAttr, "0.0f">), ConstantAttr), - $filter, + StaticShapeTensorOf<[F32, I8, UI8]>:$filter, $bias, TFL_AF_None, TFL_FCWO_Default, - ConstBoolAttrFalse, + $keep_num_dims, $asymmetric_quantize_inputs), + (TFL_ReshapeOp (TFL_EmbeddingLookupOp - $indices, + (TFL_ReshapeOp $indices, (Get1DShapeValue $indices)), (TFL_TransposeOp $filter, - (Arith_ConstantOp ConstantAttr, "{1,0}"> ))), + (Arith_ConstantOp ConstantAttr, "{1,0}">)), + (returnType (GetEmbeddingLookupShape $indices, $filter)) + ), + (Arith_ConstantOp (GetShape (GetIthValue<0> $outputs)))), [(I32ElementsVal $indices), // lookup is not implemented for i64 - (HasRank<1> $indices), // lookup isn't implemented for any other rank (IsNoneType $bias)]>; // Maybe folded into the lookup matrix later def AreInputDimensionsOneInAxes : Constraint; - -// Checks if the shape has shape with last dimension equals 1. -def IsLastDimensionEqualOne : Constraint>; - // Fetches the output of FC op, from the provided arguments. 
def GetFcOutput : NativeCodeCall< "GetFcOutput(&$_builder, $0, $1, $2, $3, $4, $5, $6, $7)">; diff --git a/tensorflow/compiler/mlir/lite/transforms/post_quantize.cc b/tensorflow/compiler/mlir/lite/transforms/post_quantize.cc index 21313752165..efd3506a8aa 100644 --- a/tensorflow/compiler/mlir/lite/transforms/post_quantize.cc +++ b/tensorflow/compiler/mlir/lite/transforms/post_quantize.cc @@ -240,10 +240,10 @@ struct FoldTransposeOp : public OpRewritePattern { ElementsAttr input_tensor = qconst_op.getValue(); assert(perm_tensor.getType().getRank() == 1); - const int num_dimensions = input_tensor.getType().getRank(); + const int num_dimensions = input_tensor.getShapedType().getRank(); assert(perm_tensor.getType().getNumElements() == num_dimensions); - ArrayRef input_shape = input_tensor.getType().getShape(); + ArrayRef input_shape = input_tensor.getShapedType().getShape(); auto output_type = op.getOutput().getType().cast(); SmallVector perm; @@ -258,7 +258,7 @@ struct FoldTransposeOp : public OpRewritePattern { } std::vector new_values; - new_values.reserve(input_tensor.getType().getNumElements()); + new_values.reserve(input_tensor.getShapedType().getNumElements()); std::vector input_indices(num_dimensions); ComputePermutation(input_tensor, perm, output_shape, num_dimensions, /*output_axis=*/0, &input_indices, &new_values); diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_patterns.td b/tensorflow/compiler/mlir/lite/transforms/prepare_patterns.td index 7803940f65d..9064d6c7f50 100644 --- a/tensorflow/compiler/mlir/lite/transforms/prepare_patterns.td +++ b/tensorflow/compiler/mlir/lite/transforms/prepare_patterns.td @@ -76,6 +76,7 @@ def ConvertPlaceholderWithDefault : Pat<(TF_PlaceholderWithDefaultOp $arg), (TF_ //===----------------------------------------------------------------------===// // Op removal patterns. //===----------------------------------------------------------------------===// +def RemoveXlaSharding : Pat<(TF_XlaShardingOp $a, $b, $c), (replaceWithValue $a)>; def RemoveIdentityN : Pat<(TF_IdentityNOp $arg), (replaceWithValue $arg)>; //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_quantize_dynamic_range.cc b/tensorflow/compiler/mlir/lite/transforms/prepare_quantize_dynamic_range.cc index 04f7fb84011..a19c29a666f 100644 --- a/tensorflow/compiler/mlir/lite/transforms/prepare_quantize_dynamic_range.cc +++ b/tensorflow/compiler/mlir/lite/transforms/prepare_quantize_dynamic_range.cc @@ -108,7 +108,7 @@ class PrepareDynamicRangeQuantizableOp return failure(); } - // 2. Quantize collected ops. It is immediatly quantized by inserting Q-DQ + // 2. Quantize collected ops. It is immediately quantized by inserting Q-DQ // pair for int8 while it is lazily applied for float16 by inserting CastOp. if (!(quantizeOps(rewriter, op, quantizable_ops))) { return failure(); @@ -160,7 +160,7 @@ class PrepareDynamicRangeQuantizableOp // Insert CastOp which is used to for converting float32 ConstantOp into // float16 quantization. If there is an existing CastOp connected to the // ConstantOp, the quantize_op will be rewired to the existing CastOp. This - // guarentees at most one CastOp is created for float32 to float16 conversion. + // guarantees at most one CastOp is created for float32 to float16 conversion. 
void quantizeOpAsFloat16(PatternRewriter& rewriter, arith::ConstantOp op, std::pair quant_op) const { Operation* quantize_op = quant_op.first; diff --git a/tensorflow/compiler/mlir/lite/transforms/quantize_patterns.td b/tensorflow/compiler/mlir/lite/transforms/quantize_patterns.td index ae2d501643c..4a0c7c42e90 100644 --- a/tensorflow/compiler/mlir/lite/transforms/quantize_patterns.td +++ b/tensorflow/compiler/mlir/lite/transforms/quantize_patterns.td @@ -24,7 +24,7 @@ include "tensorflow/compiler/mlir/lite/ir/tfl_ops.td" // Quantize attribute $0 by using quantization parameter from %1. def QuantizeByQuantizedType : NativeCodeCall<"quant::Quantize($0, $1.getValue())">; def F32ElementsAttr : ElementsAttrBase< - CPred<"$_self.cast().getType().getElementType().isF32()">, "float constant tensor">; + CPred<"$_self.cast().getShapedType().getElementType().isF32()">, "float constant tensor">; // Squash tfl.dequantize and tfl.quantize pairs. // TODO(fengliuai): Compare the scale of input and output. This can also be diff --git a/tensorflow/compiler/mlir/lite/transforms/raise_custom_ops.cc b/tensorflow/compiler/mlir/lite/transforms/raise_custom_ops.cc index 2a85b5d54aa..ce2d51a66e5 100644 --- a/tensorflow/compiler/mlir/lite/transforms/raise_custom_ops.cc +++ b/tensorflow/compiler/mlir/lite/transforms/raise_custom_ops.cc @@ -90,7 +90,7 @@ void RaiseCustomOpsPass::runOnOperation() { new_block->addArguments(op->getOperandTypes(), SmallVector(op->getNumOperands(), loc)); - for (auto &idx_args : llvm::enumerate(new_block->getArguments())) { + for (const auto &idx_args : llvm::enumerate(new_block->getArguments())) { inner_op->setOperand(idx_args.index(), idx_args.value()); } custom_op->setAttrs(inner_op->getAttrs()); diff --git a/tensorflow/compiler/mlir/lite/utils/attribute_utils.cc b/tensorflow/compiler/mlir/lite/utils/attribute_utils.cc index 60a81091be8..20336080cc2 100644 --- a/tensorflow/compiler/mlir/lite/utils/attribute_utils.cc +++ b/tensorflow/compiler/mlir/lite/utils/attribute_utils.cc @@ -20,8 +20,8 @@ namespace mlir { namespace TFL { FloatAttr ExtractSingleElementAsFloat(ElementsAttr attr) { - if (attr.getType().getNumElements() != 1 || - !attr.getType().getElementType().isa()) { + if (attr.getShapedType().getNumElements() != 1 || + !attr.getShapedType().getElementType().isa()) { return {}; } return attr.getSplatValue(); @@ -36,8 +36,8 @@ FloatAttr GetSingleElementAsFloatOrSelf(Attribute attr) { } IntegerAttr ExtractSingleElementAsInteger(ElementsAttr attr) { - if (attr.getType().getNumElements() != 1 || - !attr.getType().getElementType().isSignlessInteger()) { + if (attr.getShapedType().getNumElements() != 1 || + !attr.getShapedType().getElementType().isSignlessInteger()) { return {}; } return attr.getSplatValue(); diff --git a/tensorflow/compiler/mlir/lite/utils/constant_utils.cc b/tensorflow/compiler/mlir/lite/utils/constant_utils.cc index b49f2a10bc5..9f2301d4803 100644 --- a/tensorflow/compiler/mlir/lite/utils/constant_utils.cc +++ b/tensorflow/compiler/mlir/lite/utils/constant_utils.cc @@ -33,7 +33,7 @@ tsl::StatusOr CreateConstOpWithSingleValue( int value) { Type element_type = shaped_type.getElementType(); ShapedType scalar_type = RankedTensorType::get({}, element_type); - Attribute attr; + TypedAttr attr; if (element_type.isF16()) { auto floatType = mlir::FloatType::getF16(element_type.getContext()); auto floatAttr = mlir::FloatAttr::get(floatType, static_cast(value)); @@ -118,7 +118,8 @@ tsl::StatusOr CreateConstOpWithSingleValue( return 
tensorflow::Status(absl::StatusCode::kInvalidArgument, "Unsupported type"); } - return rewriter->create(loc, scalar_type, attr); + return rewriter->create(loc, scalar_type, + cast(attr)); } } // namespace TFL diff --git a/tensorflow/compiler/mlir/lite/utils/low_bit_utils.cc b/tensorflow/compiler/mlir/lite/utils/low_bit_utils.cc index 3311a75e387..aa2e9697595 100644 --- a/tensorflow/compiler/mlir/lite/utils/low_bit_utils.cc +++ b/tensorflow/compiler/mlir/lite/utils/low_bit_utils.cc @@ -22,14 +22,18 @@ limitations under the License. namespace tflite { std::vector PackInt4ValuesDensely(std::vector src_buffer) { - std::vector packed_buffer((src_buffer.size() + 1) / 2); + auto num_elements = src_buffer.size(); + auto packed_size = (num_elements + 1) / 2; + std::vector packed_buffer((num_elements + 1) / 2); - for (int i = 0; i < src_buffer.size(); ++i) { - if (i % 2 == 0) { - packed_buffer.at(i / 2) = src_buffer[i]; - } else { - packed_buffer.at(i / 2) |= src_buffer[i] << 4; - } + for (int i = 0; i < num_elements - 1; i += 2) { + packed_buffer[i / 2] = src_buffer[i] & 0x0F; + packed_buffer[i / 2] |= src_buffer[i + 1] << 4; + } + + // Copy the final nibble if the buffer is odd-lengthed + if (num_elements % 2 != 0) { + packed_buffer[packed_size - 1] = src_buffer[num_elements - 1] & 0x0F; } return packed_buffer; diff --git a/tensorflow/compiler/mlir/lite/utils/utils.h b/tensorflow/compiler/mlir/lite/utils/utils.h new file mode 100644 index 00000000000..7878b675895 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/utils/utils.h @@ -0,0 +1,55 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_UTILS_UTILS_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_UTILS_UTILS_H_ + +#include "llvm/ADT/ArrayRef.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project + +namespace mlir { +namespace TFL { + +using llvm::ArrayRef; +using mlir::Operation; +using mlir::ShapedType; +using mlir::Value; + +// Returns true if all tensor value in `values` has static shape and same shape. 
+inline bool OpHasSameStaticShapes(Operation* op) { + auto values = op->getOperands(); + int operand_num = 0; + ArrayRef shape; + for (Value value : values) { + auto shaped_type = value.getType().dyn_cast(); + if (!shaped_type || !shaped_type.hasStaticShape()) { + return false; + } + if (operand_num == 0) { + shape = shaped_type.getShape(); + } else { + if (shape != shaped_type.getShape()) { + return false; + } + } + ++operand_num; + } + return true; +} +} // namespace TFL +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_UTILS_UTILS_H_ diff --git a/tensorflow/compiler/mlir/lite/utils/utils.td b/tensorflow/compiler/mlir/lite/utils/utils.td index bd832527ce5..4c8485c3551 100644 --- a/tensorflow/compiler/mlir/lite/utils/utils.td +++ b/tensorflow/compiler/mlir/lite/utils/utils.td @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -27,5 +27,11 @@ def NotFromQuantOpOrSameQuantType : Constraint< def SameElementType : Constraint< CPred<"getElementTypeOrSelf($0) == getElementTypeOrSelf($1)">>; -// Checks if the value has only one user. -def HasOneUse : Constraint>; +// Checks if all of an ops inputs are the same static shape. +// BUILD NOTE: "OpHasSameStaticShapes" here refers to the C++ function defined +// in `utils/utils.h`. The `utils.h` header is included in `tfl_ops.h` so all +// of our files will have access to `OpHasSameStaticShapes` when including files +// generated from table-gen. +def OpHasSameStaticShapesPred : CPred<"OpHasSameStaticShapes($0.getDefiningOp())">; +def OpHasSameStaticShapes : Constraint; +def OpHasNotSameStaticShapes : Constraint, "op must have not static same input shapes">; \ No newline at end of file diff --git a/tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc b/tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc index 90b9d8cc854..dc69f3d64bb 100644 --- a/tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc +++ b/tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc @@ -57,6 +57,14 @@ auto* mlir_graph_optimization_pass_fallback_count = monitoring::Counter<1>::New( "used", /* metric field */ "status"); +auto* mlir_function_pass_graph_conversion_count = monitoring::Counter<1>::New( + /* metric name */ + "/tensorflow/core/mlir_function_pass_graph_conversion_count", + /* metric description */ + "Track success/failure of Graph to MLIR conversions in function " + "optimization pass", + /* metric field */ "status"); + // The status metric field is used to record success/failure of mlir // function/graph optimization passes. 
constexpr char kSuccess[] = "kSuccess"; @@ -76,8 +84,8 @@ static void DumpModule(mlir::ModuleOp module, std::string file_prefix) { auto* env = tensorflow::Env::Default(); auto status = env->RecursivelyCreateDir(prefix); if (!status.ok()) { - LOG(WARNING) << "cannot create directory '" + prefix + - "': " + status.error_message(); + LOG(WARNING) << "cannot create directory '" << prefix + << "': " << status.message(); return; } @@ -90,8 +98,7 @@ static void DumpModule(mlir::ModuleOp module, std::string file_prefix) { std::unique_ptr file_writer; status = env->NewWritableFile(prefix, &file_writer); if (!status.ok()) { - LOG(WARNING) << "cannot open file '" + prefix + - "': " + status.error_message(); + LOG(WARNING) << "cannot open file '" << prefix << "': " << status.message(); return; } @@ -104,21 +111,14 @@ static void DumpModule(mlir::ModuleOp module, std::string file_prefix) { status = file_writer->Append(txt_module); if (!status.ok()) { - LOG(WARNING) << "error writing to file '" + prefix + - "': " + status.error_message(); + LOG(WARNING) << "error writing to file '" << prefix + << "': " << status.message(); return; } (void)file_writer->Close(); VLOG(1) << "Dumped MLIR module to " << prefix; } -static std::string GetModuleText(mlir::ModuleOp module) { - std::string module_txt; - llvm::raw_string_ostream os(module_txt); - module.print(os); - return module_txt; -} - MlirOptimizationPassRegistry& MlirOptimizationPassRegistry::Global() { static auto* global = new MlirOptimizationPassRegistry(); return *global; @@ -137,8 +137,8 @@ static void RegisterDialects(mlir::DialectRegistry& registry) { Status MlirFunctionOptimizationPass::Run( const std::string& function_name, const DeviceSet& device_set, - const ConfigProto& config_proto, std::unique_ptr* graph, - FunctionLibraryDefinition* flib_def, + const ConfigProto& config_proto, absl::string_view xla_compile_device_type, + std::unique_ptr* graph, FunctionLibraryDefinition* flib_def, std::vector* control_ret_node_names, bool* control_rets_updated) { // overall_state equals to: @@ -208,6 +208,7 @@ Status MlirFunctionOptimizationPass::Run( // the shape inference pass is run early in the pass pipeline, shape inference // during import is not necessary. 
import_config.enable_shape_inference = false; + import_config.xla_compile_device_type = xla_compile_device_type; static const char* kTfMlirCategory = "TfMlir"; tensorflow::metrics::ScopedCounter<2> timings( @@ -216,6 +217,9 @@ Status MlirFunctionOptimizationPass::Run( auto module_ref_status = ConvertGraphToMlir(**graph, debug_info, *flib_def, import_config, &context); + mlir_function_pass_graph_conversion_count + ->GetCell(absl::StatusCodeToString(module_ref_status.status().code())) + ->IncrementBy(1); timings.ReportAndStop(); if (!module_ref_status.ok()) { @@ -237,8 +241,14 @@ Status MlirFunctionOptimizationPass::Run( for (auto& pass_registration : registry_->passes()) { llvm::StringRef name = pass_registration.pass->name(); - DUMP_MLIR_MODULE(function_name, llvm::formatv("mlir_{0}_before", name), - GetModuleText(*module_ref), VLOG_IS_ON(1)); + if (DEBUG_DATA_DUMPER()->ShouldDump(function_name, kDebugGroupMain) || + VLOG_IS_ON(1)) { + ::tensorflow::DumpMlirOpToFile( + DEBUG_DATA_DUMPER()->GetDumpFilename( + function_name, kDebugGroupMain, + llvm::formatv("mlir_{0}_before", name)), + *module_ref, llvm::StringRef(), nullptr); + } Status pass_status = OkStatus(); auto pass_state = per_pass_state[per_pass_state_index++]; @@ -247,8 +257,8 @@ Status MlirFunctionOptimizationPass::Run( VLOG(2) << "Graph #nodes " << (*graph)->num_nodes() << " #edges " << (*graph)->num_edges(); timings.Reset({kTfMlirCategory, name.str()}); - pass_status = pass_registration.pass->Run(config_proto, *module_ref, - **graph, *flib_def); + pass_status = pass_registration.pass->Run( + function_name, config_proto, *module_ref, **graph, *flib_def); timings.ReportAndStop(); if (pass_status.ok()) { VLOG(2) << "Finished MLIR graph optimization pass: " @@ -266,8 +276,8 @@ Status MlirFunctionOptimizationPass::Run( // module in case of no failures. 
auto module_ref_clone = module_ref->clone(); timings.Reset({kTfMlirCategory, name.str() + "_fallback"}); - pass_status = pass_registration.pass->Run(config_proto, module_ref_clone, - **graph, *flib_def); + pass_status = pass_registration.pass->Run( + function_name, config_proto, module_ref_clone, **graph, *flib_def); timings.ReportAndStop(); if (pass_status.ok()) { @@ -304,8 +314,13 @@ Status MlirFunctionOptimizationPass::Run( } } - DUMP_MLIR_MODULE(function_name, llvm::formatv("mlir_{0}_after", name), - GetModuleText(*module_ref), VLOG_IS_ON(1)); + if (DEBUG_DATA_DUMPER()->ShouldDump(function_name, kDebugGroupMain) || + VLOG_IS_ON(1)) { + ::tensorflow::DumpMlirOpToFile(DEBUG_DATA_DUMPER()->GetDumpFilename( + function_name, kDebugGroupMain, + llvm::formatv("mlir_{0}_after", name)), + *module_ref, llvm::StringRef(), nullptr); + } } if (!is_module_updated) { diff --git a/tensorflow/compiler/mlir/mlir_graph_optimization_pass.h b/tensorflow/compiler/mlir/mlir_graph_optimization_pass.h index 8fe0ccbd00e..d3a8420af94 100644 --- a/tensorflow/compiler/mlir/mlir_graph_optimization_pass.h +++ b/tensorflow/compiler/mlir/mlir_graph_optimization_pass.h @@ -65,7 +65,8 @@ class MlirOptimizationPass { const Graph& graph, const FunctionLibraryDefinition& function_library) const = 0; - virtual Status Run(const ConfigProto& config_proto, mlir::ModuleOp module, + virtual Status Run(const std::string& function_name, + const ConfigProto& config_proto, mlir::ModuleOp module, const Graph& graph, const FunctionLibraryDefinition& function_library) = 0; }; @@ -118,8 +119,9 @@ class MlirFunctionOptimizationPass : public FunctionOptimizationPass { // Executes all of the underlying registered MlirOptimizationPasses. Status Run(const std::string& function_name, const DeviceSet& device_set, - const ConfigProto& config_proto, std::unique_ptr* graph, - FunctionLibraryDefinition* flib_def, + const ConfigProto& config_proto, + absl::string_view xla_compile_device_type, + std::unique_ptr* graph, FunctionLibraryDefinition* flib_def, std::vector* control_ret_node_names, bool* control_rets_updated) override; diff --git a/tensorflow/compiler/mlir/mlir_graph_optimization_pass_test.cc b/tensorflow/compiler/mlir/mlir_graph_optimization_pass_test.cc index 36ba9160f59..4e7d1449946 100644 --- a/tensorflow/compiler/mlir/mlir_graph_optimization_pass_test.cc +++ b/tensorflow/compiler/mlir/mlir_graph_optimization_pass_test.cc @@ -15,11 +15,15 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/mlir_graph_optimization_pass.h" +#include #include +#include +#include #include #include "mlir/IR/Builders.h" // from @llvm-project #include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/lib/monitoring/cell_reader.h" #include "tensorflow/core/platform/test.h" namespace tensorflow { @@ -29,6 +33,11 @@ using ::testing::NiceMock; using ::testing::Return; using ::testing::Test; +constexpr char kOk[] = "OK"; +constexpr char kInvalidArgument[] = "INVALID_ARGUMENT"; +constexpr char kSuccess[] = "kSuccess"; +constexpr char kFailure[] = "kFailure"; + class MockMlirOptimizationPass : public MlirOptimizationPass { public: // MOCK_METHOD does not work on Windows build, using MOCK_CONST_METHODX @@ -39,7 +48,8 @@ class MockMlirOptimizationPass : public MlirOptimizationPass { const DeviceSet* device_set, const ConfigProto& config_proto, const Graph& graph, const FunctionLibraryDefinition& function_library)); - MOCK_METHOD4(Run, Status(const ConfigProto& config_proto, + MOCK_METHOD5(Run, Status(const std::string& function_name, + const ConfigProto& config_proto, mlir::ModuleOp module, const Graph& graph, const FunctionLibraryDefinition& function_library)); }; @@ -72,8 +82,8 @@ class ModifyMlirModulePass : public MlirOptimizationPass { // Just modify MLIR module so that we can check whether original TF graph // has changed or not. - Status Run(const ConfigProto& config_proto, mlir::ModuleOp module, - const Graph& graph, + Status Run(const std::string& function_name, const ConfigProto& config_proto, + mlir::ModuleOp module, const Graph& graph, const FunctionLibraryDefinition& function_library) override { mlir::Builder b(module.getContext()); auto producer = b.getNamedAttr("producer", b.getI32IntegerAttr(0)); @@ -123,10 +133,11 @@ class MlirGraphOptimizationPassTest : public Test { ON_CALL(*optimization_pass, GetPassState(_, _, _, _)) .WillByDefault(Return(pass_state)); - ON_CALL(*optimization_pass, Run(_, _, _, _)) + ON_CALL(*optimization_pass, Run(_, _, _, _, _)) .WillByDefault(Return(pass_run_result)); MlirOptimizationPassRegistry::Global().Add(pass_priority++, std::move(optimization_pass)); + pass_result_expected_[pass_state][pass_run_result.ok()]++; } flib_ = std::make_unique(graph_->flib_def()); @@ -141,6 +152,7 @@ class MlirGraphOptimizationPassTest : public Test { .WillByDefault(Return(pass_state)); MlirOptimizationPassRegistry::Global().Add(10, std::move(optimization_pass)); + pass_result_expected_[pass_state][run_status.ok()]++; } void TearDown() override { @@ -164,31 +176,60 @@ class MlirGraphOptimizationPassTest : public Test { #endif } + void verifyCounters() { + EXPECT_EQ(mlir_function_pass_fallback_count_.Read(kSuccess), + pass_result_expected_[MlirOptimizationPassState::FallbackEnabled] + [true]); + EXPECT_EQ(mlir_function_pass_fallback_count_.Read(kFailure), + pass_result_expected_[MlirOptimizationPassState::FallbackEnabled] + [false]); + EXPECT_EQ(mlir_function_pass_graph_conversion_count_.Read(kOk), 1); + } + ConfigProto config_proto_; MlirFunctionOptimizationPass function_optimization_pass_; DeviceSet device_set_; std::unique_ptr graph_; std::unique_ptr flib_; std::vector control_ret_node_names_; + std::string xla_compile_device_type_; bool control_rets_updated_{false}; + monitoring::testing::CellReader mlir_function_pass_fallback_count_ = + monitoring::testing::CellReader( + /* metric name */ + "/tensorflow/core/mlir_function_pass_fallback_count"); + monitoring::testing::CellReader + 
mlir_graph_optimization_pass_fallback_count_ = + monitoring::testing::CellReader( + /* metric name */ + "/tensorflow/core/mlir_graph_optimization_pass_fallback_count"); + monitoring::testing::CellReader + mlir_function_pass_graph_conversion_count_ = + monitoring::testing::CellReader( + /* metric name */ + "/tensorflow/core/mlir_function_pass_graph_conversion_count"); + std::map> + pass_result_expected_; }; TEST_F(MlirGraphOptimizationPassTest, OptimizationPassFailsNoFallback) { - Init(Status(error::Code::ABORTED, "aborted"), + Init(Status(absl::StatusCode::kAborted, "aborted"), {MlirOptimizationPassState::Enabled}); GraphDef original_graph_def; graph_->ToGraphDef(&original_graph_def); EXPECT_EQ(function_optimization_pass_.Run( - "test_func", device_set_, config_proto_, &graph_, flib_.get(), + "test_func", device_set_, config_proto_, + xla_compile_device_type_, &graph_, flib_.get(), &control_ret_node_names_, &control_rets_updated_), - Status(error::Code::ABORTED, "aborted")); + Status(absl::StatusCode::kAborted, "aborted")); verifyGraph(original_graph_def); + verifyCounters(); } TEST_F(MlirGraphOptimizationPassTest, OptimizationPassFailsDisabledFallback) { - Init(Status(error::Code::ABORTED, "aborted"), + Init(Status(absl::StatusCode::kAborted, "aborted"), {MlirOptimizationPassState::Disabled, MlirOptimizationPassState::FallbackEnabled}); @@ -203,13 +244,15 @@ TEST_F(MlirGraphOptimizationPassTest, OptimizationPassFailsDisabledFallback) { GraphDef original_graph_def; graph_->ToGraphDef(&original_graph_def); AddModuleModificationPass(MlirOptimizationPassState::FallbackEnabled, - Status(error::Code::ABORTED, "aborted")); + Status(absl::StatusCode::kAborted, "aborted")); EXPECT_EQ(function_optimization_pass_.Run( - "test_func", device_set_, config_proto_, &graph_, flib_.get(), + "test_func", device_set_, config_proto_, + xla_compile_device_type_, &graph_, flib_.get(), &control_ret_node_names_, &control_rets_updated_), OkStatus()); verifyGraph(original_graph_def); + verifyCounters(); } TEST_F(MlirGraphOptimizationPassTest, OptimizationPassDoesNotFailFallback) { @@ -221,11 +264,32 @@ TEST_F(MlirGraphOptimizationPassTest, OptimizationPassDoesNotFailFallback) { AddModuleModificationPass(MlirOptimizationPassState::FallbackEnabled, OkStatus()); EXPECT_EQ(function_optimization_pass_.Run( - "test_func", device_set_, config_proto_, &graph_, flib_.get(), + "test_func", device_set_, config_proto_, + xla_compile_device_type_, &graph_, flib_.get(), &control_ret_node_names_, &control_rets_updated_), OkStatus()); verifyGraph(original_graph_def, true); + verifyCounters(); +} + +TEST_F(MlirGraphOptimizationPassTest, GraphDoesntConvertUpdatesCounter) { + Init(OkStatus(), {MlirOptimizationPassState::FallbackEnabled}); + + graph_ = std::make_unique(OpRegistry::Global()); + control_ret_node_names_.push_back("foo"); + + AddModuleModificationPass(MlirOptimizationPassState::FallbackEnabled, + OkStatus()); + EXPECT_EQ(function_optimization_pass_.Run( + "test_func", device_set_, config_proto_, + xla_compile_device_type_, &graph_, flib_.get(), + &control_ret_node_names_, &control_rets_updated_), + OkStatus()); + + EXPECT_EQ(mlir_function_pass_graph_conversion_count_.Read(kOk), 0); + EXPECT_EQ(mlir_function_pass_graph_conversion_count_.Read(kInvalidArgument), + 1); } TEST(MlirOptimizationPassRegistry, RegisterPassesWithTheSamePriorityFails) { diff --git a/tensorflow/compiler/mlir/op_or_arg_name_mapper.cc b/tensorflow/compiler/mlir/op_or_arg_name_mapper.cc index 0f15052fa32..cbd03639c02 100644 --- 
a/tensorflow/compiler/mlir/op_or_arg_name_mapper.cc +++ b/tensorflow/compiler/mlir/op_or_arg_name_mapper.cc @@ -42,7 +42,8 @@ namespace tensorflow { OpOrArgNameMapper::~OpOrArgNameMapper() {} -llvm::StringRef OpOrArgNameMapper::GetUniqueName(llvm::StringRef prefix) { +llvm::StringRef OpOrArgNameMapper::GetUniqueName(llvm::StringRef prefix, + int hash_value) { // Insert/find if prefix is unique. auto prefix_it = name_to_count_.try_emplace(prefix, 0); if (prefix_it.second && IsUnique(prefix)) { @@ -55,8 +56,11 @@ llvm::StringRef OpOrArgNameMapper::GetUniqueName(llvm::StringRef prefix) { // Add increasing number (count) to end of prefix until it is determined // to be unique. auto& val = prefix_it.first->second; - llvm::SmallString<64> probe_name(prefix); - probe_name.append(GetSuffixSeparator()); + auto prefix_name = hash_value == 0 ? prefix.str() + GetSuffixSeparator().str() + : prefix.str() + GetDashSeparator().str() + + std::to_string(hash_value) + + GetDashSeparator().str(); + llvm::SmallString<64> probe_name(prefix_name); const int probe_prefix_size = probe_name.size(); while (true) { probe_name.resize(probe_prefix_size); @@ -75,11 +79,12 @@ llvm::StringRef OpOrArgNameMapper::GetUniqueName(llvm::StringRef prefix) { } } -llvm::StringRef OpOrArgNameMapper::GetUniqueName(OpOrVal op_or_val) { +llvm::StringRef OpOrArgNameMapper::GetUniqueName(OpOrVal op_or_val, + int hash_value) { auto& name = op_or_val_to_name_[op_or_val]; if (!name.empty()) return StringViewToRef(name); // Update the value in the map with unique name. - llvm::StringRef ref = GetUniqueName(GetName(op_or_val)); + llvm::StringRef ref = GetUniqueName(GetName(op_or_val), hash_value); name = StringRefToView(ref); return ref; } diff --git a/tensorflow/compiler/mlir/op_or_arg_name_mapper.h b/tensorflow/compiler/mlir/op_or_arg_name_mapper.h index f4aa8626f43..d8ff9cebc01 100644 --- a/tensorflow/compiler/mlir/op_or_arg_name_mapper.h +++ b/tensorflow/compiler/mlir/op_or_arg_name_mapper.h @@ -36,10 +36,10 @@ using OpOrVal = llvm::PointerUnion; class OpOrArgNameMapper { public: // Returns unique name for the given prefix. - llvm::StringRef GetUniqueName(llvm::StringRef prefix); + llvm::StringRef GetUniqueName(llvm::StringRef prefix, int hash_value = 0); // Returns unique name for the operation or value. - llvm::StringRef GetUniqueName(OpOrVal op_or_val); + llvm::StringRef GetUniqueName(OpOrVal op_or_val, int hash_value = 0); // Returns unique name as a string_view for the operation or value. absl::string_view GetUniqueNameView(OpOrVal op_or_val); @@ -67,6 +67,8 @@ class OpOrArgNameMapper { // Returns the separator used before uniqueing suffix. virtual llvm::StringRef GetSuffixSeparator() { return ""; } + virtual llvm::StringRef GetDashSeparator() { return "_"; } + private: // Returns name from the location of the operation or value. 
virtual std::string GetName(OpOrVal op_or_val) = 0; diff --git a/tensorflow/compiler/mlir/python/BUILD b/tensorflow/compiler/mlir/python/BUILD index 1596b976e6d..0afe50ac2e7 100644 --- a/tensorflow/compiler/mlir/python/BUILD +++ b/tensorflow/compiler/mlir/python/BUILD @@ -20,16 +20,20 @@ cc_library( "@llvm-project//mlir:AllPassesAndDialects", "@llvm-project//mlir:BytecodeWriter", "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:ShapeDialect", + "@llvm-project//mlir:TosaDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Parser", "@llvm-project//mlir:Pass", "@llvm-project//mlir:Support", + "@stablehlo//:register", "//tensorflow/c:tf_status", "//tensorflow/c:tf_status_helper", "//tensorflow/c/eager:c_api", "//tensorflow/c/eager:tfe_context_internal", "//tensorflow/compiler/xla/mlir/framework/transforms:passes", "//tensorflow/compiler/xla/mlir_hlo:all_passes", + "//tensorflow/compiler/mlir/lite:flatbuffer_import", "//tensorflow/compiler/mlir/lite:tensorflow_lite", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:import_model", @@ -40,7 +44,7 @@ cc_library( "//tensorflow/compiler/mlir/tensorflow:tensorflow_passes", "//tensorflow/compiler/mlir/tensorflow:tf_saved_model_passes", "//tensorflow/compiler/mlir/tensorflow:translate_lib", - "//tensorflow/compiler/mlir/tf2xla:xla_legalize_tf", + "//tensorflow/compiler/mlir/tf2xla/transforms:xla_legalize_tf", "//tensorflow/compiler/mlir/tools/kernel_gen/ir:tf_framework_ops", "//tensorflow/compiler/mlir/tosa:passes_header", "//tensorflow/compiler/mlir/tosa:tf_passes", diff --git a/tensorflow/compiler/mlir/python/mlir.cc b/tensorflow/compiler/mlir/python/mlir.cc index 39593e2ded0..7cc1d25355e 100644 --- a/tensorflow/compiler/mlir/python/mlir.cc +++ b/tensorflow/compiler/mlir/python/mlir.cc @@ -29,6 +29,8 @@ limitations under the License. #include "llvm/Support/raw_ostream.h" #include "mlir/Bytecode/BytecodeWriter.h" // from @llvm-project #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project +#include "mlir/Dialect/Tosa/IR/TosaOps.h" // from @llvm-project #include "mlir/IR/AsmState.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/InitAllPasses.h" // from @llvm-project @@ -36,10 +38,12 @@ limitations under the License. 
#include "mlir/Pass/PassManager.h" // from @llvm-project #include "mlir/Pass/PassRegistry.h" // from @llvm-project #include "mlir/Support/FileUtilities.h" // from @llvm-project +#include "stablehlo/dialect/Register.h" // from @stablehlo #include "tensorflow/c/eager/c_api.h" #include "tensorflow/c/eager/tfe_context_internal.h" #include "tensorflow/c/tf_status.h" #include "tensorflow/c/tf_status_helper.h" +#include "tensorflow/compiler/mlir/lite/flatbuffer_import.h" #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" @@ -119,7 +123,7 @@ std::string RunPassPipelineOnModule(mlir::ModuleOp module, mlir::StatusScopedDiagnosticHandler statusHandler(module.getContext()); if (failed(pm.run(module))) { - Set_TF_Status_from_Status(status, statusHandler.ConsumeStatus()); + tsl::Set_TF_Status_from_Status(status, statusHandler.ConsumeStatus()); return "// error"; } } @@ -137,13 +141,13 @@ static std::string ImportGraphDefImpl(const std::string& proto, GraphDef graphdef; auto s = tensorflow::LoadProtoFromBuffer(proto, &graphdef); if (!s.ok()) { - Set_TF_Status_from_Status(status, s); + tsl::Set_TF_Status_from_Status(status, s); return "// error"; } mlir::MLIRContext context; auto module = ConvertGraphdefToMlir(graphdef, debug_info, specs, &context); if (!module.ok()) { - Set_TF_Status_from_Status(status, module.status()); + tsl::Set_TF_Status_from_Status(status, module.status()); return "// error"; } @@ -158,7 +162,7 @@ std::string ImportFunction(const std::string& functiondef_proto, FunctionDef functiondef; auto s = tensorflow::LoadProtoFromBuffer(functiondef_proto, &functiondef); if (!s.ok()) { - Set_TF_Status_from_Status(status, s); + tsl::Set_TF_Status_from_Status(status, s); return "// error"; } @@ -168,7 +172,7 @@ std::string ImportFunction(const std::string& functiondef_proto, const tensorflow::FunctionDef* fdef = flib_def.Find(function_name); if (fdef == nullptr) { s = tensorflow::errors::NotFound("Cannot find function ", function_name); - Set_TF_Status_from_Status(status, s); + tsl::Set_TF_Status_from_Status(status, s); return "// error"; } @@ -176,14 +180,14 @@ std::string ImportFunction(const std::string& functiondef_proto, s = FunctionDefToBodyHelper(*fdef, tensorflow::AttrSlice(), &flib_def, &fbody); if (!s.ok()) { - Set_TF_Status_from_Status(status, s); + tsl::Set_TF_Status_from_Status(status, s); return "// error"; } mlir::MLIRContext context; auto module = ConvertFunctionToMlir(fbody.get(), flib_def, &context); if (!module.ok()) { - Set_TF_Status_from_Status(status, module.status()); + tsl::Set_TF_Status_from_Status(status, module.status()); return "// error"; } @@ -211,7 +215,7 @@ std::string ImportGraphDef(const std::string& proto, auto s = ParseInputArrayInfo(input_names, input_data_types, input_data_shapes, &specs.inputs); if (!s.ok()) { - Set_TF_Status_from_Status(status, s); + tsl::Set_TF_Status_from_Status(status, s); return "// error"; } if (!output_names.empty()) { @@ -230,7 +234,7 @@ std::string ExperimentalConvertSavedModelToMlir( auto load_status = tensorflow::SavedModelV2Bundle::Load(saved_model_path, &bundle); if (!load_status.ok()) { - Set_TF_Status_from_Status(status, load_status); + tsl::Set_TF_Status_from_Status(status, load_status); return "// error"; } @@ -242,7 +246,7 @@ std::string ExperimentalConvertSavedModelToMlir( auto module_or = ConvertSavedModelToMlir( &bundle, &context, absl::Span(exported_names)); if 
(!module_or.status().ok()) { - Set_TF_Status_from_Status(status, module_or.status()); + tsl::Set_TF_Status_from_Status(status, module_or.status()); return "// error"; } @@ -266,7 +270,7 @@ std::string ExperimentalConvertSavedModelV1ToMlirLite( saved_model_path, tag_set, absl::Span(exported_names), &context, import_options); if (!module_or.status().ok()) { - Set_TF_Status_from_Status(status, module_or.status()); + tsl::Set_TF_Status_from_Status(status, module_or.status()); return "// error"; } @@ -275,7 +279,8 @@ std::string ExperimentalConvertSavedModelV1ToMlirLite( std::string ExperimentalConvertSavedModelV1ToMlir( const std::string& saved_model_path, const std::string& exported_names_str, - const std::string& tags, bool lift_variables, bool upgrade_legacy, + const std::string& tags, bool lift_variables, + bool include_variables_in_initializers, bool upgrade_legacy, bool show_debug_info, TF_Status* status) { // Load the saved model into a SavedModelBundle. @@ -286,7 +291,7 @@ std::string ExperimentalConvertSavedModelV1ToMlir( auto load_status = tensorflow::LoadSavedModel({}, {}, saved_model_path, tag_set, &bundle); if (!load_status.ok()) { - Set_TF_Status_from_Status(status, load_status); + tsl::Set_TF_Status_from_Status(status, load_status); return "// error"; } @@ -297,11 +302,13 @@ std::string ExperimentalConvertSavedModelV1ToMlir( tensorflow::MLIRImportOptions import_options; import_options.upgrade_legacy = upgrade_legacy; import_options.lift_variables = lift_variables; + import_options.include_variables_in_initializers = + include_variables_in_initializers; auto module_or = ConvertSavedModelV1ToMlir(bundle, absl::Span(exported_names), &context, import_options); if (!module_or.status().ok()) { - Set_TF_Status_from_Status(status, module_or.status()); + tsl::Set_TF_Status_from_Status(status, module_or.status()); return "// error"; } @@ -317,7 +324,7 @@ std::string ExperimentalConvertSavedModelV1ToMlir( mlir::StatusScopedDiagnosticHandler diagnostic_handler(&context); if (failed(pm.run(*module))) { - Set_TF_Status_from_Status(status, diagnostic_handler.ConsumeStatus()); + tsl::Set_TF_Status_from_Status(status, diagnostic_handler.ConsumeStatus()); return "// error"; } return MlirModuleToString(*module, show_debug_info); @@ -330,13 +337,16 @@ std::string ExperimentalRunPassPipeline(const std::string& mlir_txt, RegisterPasses(); mlir::DialectRegistry registry; mlir::RegisterAllTensorFlowDialects(registry); + mlir::stablehlo::registerAllDialects(registry); + registry.insert(); mlir::MLIRContext context(registry); mlir::OwningOpRef module; { mlir::StatusScopedDiagnosticHandler diagnostic_handler(&context); module = mlir::parseSourceString(mlir_txt, &context); if (!module) { - Set_TF_Status_from_Status(status, diagnostic_handler.ConsumeStatus()); + tsl::Set_TF_Status_from_Status(status, + diagnostic_handler.ConsumeStatus()); return "// error"; } } @@ -353,7 +363,7 @@ std::string ExperimentalRunPassPipeline(const std::string& mlir_txt, mlir::StatusScopedDiagnosticHandler diagnostic_handler(&context); if (failed(pm.run(*module))) { - Set_TF_Status_from_Status(status, diagnostic_handler.ConsumeStatus()); + tsl::Set_TF_Status_from_Status(status, diagnostic_handler.ConsumeStatus()); return "// error"; } return MlirModuleToString(*module, show_debug_info); @@ -363,13 +373,16 @@ void ExperimentalWriteBytecode(const std::string& filename, const std::string& mlir_txt, TF_Status* status) { mlir::DialectRegistry registry; mlir::RegisterAllTensorFlowDialects(registry); + 
mlir::stablehlo::registerAllDialects(registry); + registry.insert(); mlir::MLIRContext context(registry); mlir::OwningOpRef module; + mlir::StatusScopedDiagnosticHandler diagnostic_handler(&context); { - mlir::StatusScopedDiagnosticHandler diagnostic_handler(&context); module = mlir::parseSourceString(mlir_txt, &context); if (!module) { - Set_TF_Status_from_Status(status, diagnostic_handler.ConsumeStatus()); + tsl::Set_TF_Status_from_Status(status, + diagnostic_handler.ConsumeStatus()); return; } } @@ -378,13 +391,74 @@ void ExperimentalWriteBytecode(const std::string& filename, std::string error; std::unique_ptr outputFile = mlir::openOutputFile(filename, &error); + if (!error.empty()) { + TF_SetStatus(status, TF_INVALID_ARGUMENT, + ("Unable to create output file " + error).c_str()); + return; + } + outputFile->keep(); + if (failed(mlir::writeBytecodeToFile(*module, outputFile->os(), + writer_config))) { + tsl::Set_TF_Status_from_Status(status, diagnostic_handler.ConsumeStatus()); + } +} + +void ExperimentalTFLiteToTosaBytecode( + const std::string& flatbuffer_file, const std::string& tosa_bytecode_file, + bool use_external_constant, + const std::vector& ordered_input_arrays, + const std::vector& ordered_output_arrays, TF_Status* status) { + mlir::DialectRegistry registry; + mlir::RegisterAllTensorFlowDialects(registry); + registry.insert(); + mlir::MLIRContext context(registry); + mlir::OwningOpRef module; + mlir::StatusScopedDiagnosticHandler diagnostic_handler(&context); + { + mlir::Location loc = mlir::UnknownLoc::get(&context); + std::string error; + std::unique_ptr buffer = + mlir::openInputFile(flatbuffer_file, &error); + if (buffer == nullptr) { + TF_SetStatus(status, TF_INVALID_ARGUMENT, + ("Unable to load input file " + error).c_str()); + return; + } + + auto buffer_view = + std::string_view(buffer->getBufferStart(), buffer->getBufferSize()); + module = tflite::FlatBufferToMlir( + buffer_view, &context, loc, use_external_constant, ordered_input_arrays, + ordered_output_arrays); + mlir::PassManager pm(&context, module.get()->getName().getStringRef(), + mlir::PassManager::Nesting::Implicit); + mlir::tosa::TOSATFLLegalizationPipelineOptions opts; + // This flow is specific to compilation backend, so set to true. + opts.target_compilation_backend = true; + // Temporary work-around for https://github.com/openxla/iree/issues/8974 + opts.dequantize_tfl_softmax = true; + createTFLtoTOSALegalizationPipeline(pm, opts); + if (failed(pm.run(*module))) { + tsl::Set_TF_Status_from_Status(status, + diagnostic_handler.ConsumeStatus()); + return; + } + } + mlir::FallbackAsmResourceMap fallback_resource_map; + mlir::BytecodeWriterConfig writer_config(fallback_resource_map); + std::string error; + std::unique_ptr outputFile = + mlir::openOutputFile(tosa_bytecode_file, &error); if (!error.empty()) { TF_SetStatus(status, TF_INVALID_ARGUMENT, ("Unable to create output file" + error).c_str()); return; } outputFile->keep(); - mlir::writeBytecodeToFile(*module, outputFile->os(), writer_config); + if (failed(mlir::writeBytecodeToFile(*module, outputFile->os(), + writer_config))) { + tsl::Set_TF_Status_from_Status(status, diagnostic_handler.ConsumeStatus()); + } } } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/python/mlir.h b/tensorflow/compiler/mlir/python/mlir.h index 740971d4fb8..a17f4f2843e 100644 --- a/tensorflow/compiler/mlir/python/mlir.h +++ b/tensorflow/compiler/mlir/python/mlir.h @@ -19,6 +19,7 @@ limitations under the License. 
#define TENSORFLOW_COMPILER_MLIR_PYTHON_MLIR_H_ #include +#include #include "absl/strings/string_view.h" #include "tensorflow/c/eager/c_api.h" @@ -95,7 +96,8 @@ std::string ExperimentalConvertSavedModelV1ToMlirLite( // A string of textual MLIR representing the raw imported SavedModel. std::string ExperimentalConvertSavedModelV1ToMlir( const std::string &saved_model_path, const std::string &exported_names_str, - const std::string &tags, bool lift_variables, bool upgrade_legacy, + const std::string &tags, bool lift_variables, + bool include_variables_in_initializers, bool upgrade_legacy, bool show_debug_info, TF_Status *status); std::string ExperimentalRunPassPipeline(const std::string &mlir_txt, @@ -107,6 +109,16 @@ std::string ExperimentalRunPassPipeline(const std::string &mlir_txt, void ExperimentalWriteBytecode(const std::string &filename, const std::string &mlir_txt, TF_Status *status); +// Loads a TFLite flatbuffer, convert to TOSA for backend compilation and +// produce an MLIR bytecode file as output. +// TODO(jpienaar): Refactor this when we use more implicit module passing +// between calls to avoid serialization overhead. +void ExperimentalTFLiteToTosaBytecode( + const std::string &flatbuffer_file, const std::string &tosa_bytecode_file, + bool use_external_constant, + const std::vector &ordered_input_arrays, + const std::vector &ordered_output_arrays, TF_Status *status); + } // namespace tensorflow #endif // TENSORFLOW_COMPILER_MLIR_PYTHON_MLIR_H_ diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/BUILD b/tensorflow/compiler/mlir/quantization/stablehlo/BUILD index 0ae09a43dd3..f85f8f13882 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/BUILD +++ b/tensorflow/compiler/mlir/quantization/stablehlo/BUILD @@ -1,21 +1,97 @@ +load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_cloud") load("//tensorflow/core/platform:build_config.bzl", "tf_proto_library") load("//tensorflow/compiler/mlir/quantization/stablehlo:internal_visibility_allowlist.bzl", "internal_visibility_allowlist") +load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library") + +# TODO(b/264218457): Create stablehlo-quantization-opt and register passes to actually test. package_group( name = "internal_visibility_allowlist_package", packages = [ "//tensorflow/compiler/mlir/lite/...", "//tensorflow/compiler/mlir/quantization/...", + "//tensorflow/lite/...", "//third_party/cloud_tpu/inference_converter/...", # TPU Inference Converter V1 ] + internal_visibility_allowlist(), ) +# TODO(b/264218457): Add quantize and post_quantize passes. +cc_library( + name = "passes", + srcs = [ + "passes/quantize_weight.cc", + ], + hdrs = [ + "passes/passes.h", + ], + compatible_with = get_compatible_with_cloud(), + deps = [ + ":quantization_options_proto_cc", + ":stablehlo_passes_inc_gen", + "//tensorflow/compiler/mlir/lite/quantization:quantization_config", + "//tensorflow/compiler/mlir/lite/quantization/ir:QuantOps", + "//tensorflow/core/platform:path", + "//third_party/eigen3", + "@com_google_absl//absl/container:flat_hash_set", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:QuantOps", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", + "@stablehlo//:stablehlo_ops", + ], + # Alwayslink is required for registering the MLIR passes. + # TODO(b/255530126): Split the pass registration from the definitions to avoid binary size bloat. 
+ alwayslink = True, +) + +cc_library( + name = "quantize_passes", + srcs = [ + "quantize_passes.cc", + ], + hdrs = [ + "quantize_passes.h", + ], + compatible_with = get_compatible_with_cloud(), + visibility = [":internal_visibility_allowlist_package"], + deps = [ + ":passes", + ":quantization_options_proto_cc", + "//tensorflow/compiler/mlir/quantization/tensorflow:quantization_options_proto_cc", + "//tensorflow/compiler/mlir/tensorflow:tf_dialect_passes", + "//tensorflow/core/platform:path", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:Pass", + ], +) + +gentbl_cc_library( + name = "stablehlo_passes_inc_gen", + compatible_with = get_compatible_with_cloud(), + tbl_outs = [ + ( + [ + "-gen-pass-decls", + ], + "passes/passes.h.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "passes/passes.td", + deps = [ + "@llvm-project//mlir:PassBaseTdFiles", + ], +) + tf_proto_library( name = "quantization_options_proto", srcs = ["quantization_options.proto"], cc_api_version = 2, - make_default_target_header_only = True, - visibility = [":internal_visibility_allowlist_package"], + visibility = ["//visibility:public"], ) # copybara:uncomment_begin(google-only) @@ -26,3 +102,7 @@ tf_proto_library( # deps = [":quantization_options_proto"], # ) # copybara:uncomment_end + +exports_files([ + "run_lit.sh", +]) diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/internal_visibility_allowlist.bzl b/tensorflow/compiler/mlir/quantization/stablehlo/internal_visibility_allowlist.bzl index 0e302a08fd5..310b10e5d0f 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/internal_visibility_allowlist.bzl +++ b/tensorflow/compiler/mlir/quantization/stablehlo/internal_visibility_allowlist.bzl @@ -1,10 +1,5 @@ """Internal visibility rules.""" def internal_visibility_allowlist(): - """Returns a list of g3 packages that can depend on internal targets.""" - return [ - "//learning/brain/experimental/mlir/quantization/...", - "//learning/brain/mlir/quantization/tensorflow/...", - "//learning/brain/mobile/programmability/...", - "//learning/brain/experimental/tfq/...", - ] + """Returns a list of the packages that can depend on internal targets.""" + return [] diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.h b/tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.h new file mode 100644 index 00000000000..788a00f349c --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.h @@ -0,0 +1,39 @@ +/* Copyright 2023 The StableHLO Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_PASSES_PASSES_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_PASSES_PASSES_H_ + +#include +#include + +#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_options.pb.h" + +#define GEN_PASS_DECL +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.h.inc" + +namespace mlir { +namespace stablehlo { + +// Creates a pass that quantizes weight component of StableHLO graph. +std::unique_ptr> CreateQuantizeWeightPass( + ::stablehlo::quantization::QuantizationOptions quantization_options); + +} // namespace stablehlo +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_PASSES_PASSES_H_ diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.td b/tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.td new file mode 100644 index 00000000000..959121888b6 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.td @@ -0,0 +1,22 @@ +/* Copyright 2023 The StableHLO Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +include "mlir/Pass/PassBase.td" + +def QuantizeWeightPass : Pass<"stablehlo-quantize-weight", "mlir::func::FuncOp"> { + let summary = "Quantizes the weight component of StableHLO graph."; + let constructor = "CreateQuantizeWeightPass()"; + let dependentDialects = ["stablehlo::StablehloDialect"]; +} diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize_weight.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize_weight.cc new file mode 100644 index 00000000000..9d5d0cc8e91 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize_weight.cc @@ -0,0 +1,244 @@ +/* Copyright 2023 The StableHLO Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include +#include +#include +#include +#include + +#include "third_party/eigen3/Eigen/Core" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Debug.h" +#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project +#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/Dialect.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project +#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo +#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_options.pb.h" + +// NOLINTNEXTLINE +//===----------------------------------------------------------------------===// +// The Quantization Pass for Weight. +//===----------------------------------------------------------------------===// +namespace mlir { +namespace stablehlo { + +namespace { +#define GEN_PASS_DEF_QUANTIZEWEIGHTPASS +#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.h.inc" + +using QuantizationUnits = llvm::SetVector>; + +// Min/Max values used for creating ConstantOp. +constexpr float kMaxFloat16Value = 65504.f; +constexpr float kMinFloat16Value = -65504.f; + +class QuantizeWeightPass + : public impl::QuantizeWeightPassBase { + public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(QuantizeWeightPass) + + explicit QuantizeWeightPass( + ::stablehlo::quantization::QuantizationOptions quantization_options) {} + + StringRef getArgument() const final { + // This is the argument used to refer to the pass in + // the textual format (on the commandline for example). + return "stablehlo-quantize-weight"; + } + + StringRef getDescription() const final { + return "Apply the specified quantization methods to weights."; + } + + private: + void runOnOperation() override; +}; + +// Collects quantizable target ops, then insert Q-DQ quantization patterns. +class QuantizeWeight : public OpRewritePattern { + public: + explicit QuantizeWeight(MLIRContext* context) + : OpRewritePattern(context) {} + + LogicalResult matchAndRewrite(ConstantOp op, + PatternRewriter& rewriter) const override { + // 1. Collect quantizable ops. + QuantizationUnits quantizable_ops = GetQuantizableOps(op); + if (quantizable_ops.empty()) { + return failure(); + } + + // 2. Quantize collected ops. + if (!QuantizeOps(rewriter, op, quantizable_ops)) { + return failure(); + } + + // 3. Complete the Q-DQ pair for each inference type. + if (!ConvertToFloat16Constant(rewriter, op)) { + return failure(); + } + return success(); + } + + private: + // Marks users that are applicable for quantization where the criteria for + // determining quantizable ops differs by the inference type. + QuantizationUnits GetQuantizableOps(ConstantOp op) const { + // Non-float tensors do not need quantization. 
+ QuantizationUnits quantizable_ops; + ShapedType type = op.getType().dyn_cast(); + if (!type || !type.getElementType().isF32()) return quantizable_ops; + + Value value = op.getResult(); + + for (OpOperand& use : value.getUses()) { + Operation* user = use.getOwner(); + int operand_num = use.getOperandNumber(); + quantizable_ops.insert({user, operand_num}); + } + return quantizable_ops; + } + + // Returns whether quantization is applied to filtered users. + bool QuantizeOps(PatternRewriter& rewriter, ConstantOp op, + const QuantizationUnits& quantizable_ops) const { + // TODO(b/212514817): refactor mode checking to improve code quality. + for (const std::pair& quant_op : quantizable_ops) { + // For f16 quantization, quantize all constant ops as float16. + QuantizeOpAsFloat16(rewriter, op, quant_op); + } + // TODO(b/264218457): Return a value that accurately captures result status. + return true; + } + + // Inserts ConvertOp which is used for converting float32 ConstantOp into + // float16 quantization. If there is an existing ConvertOp connected to the + // ConstantOp, the quantizable_op will be rewired to the existing ConvertOp. + // This guarantees at most one ConvertOp is created for float32 to float16 + // conversion. + void QuantizeOpAsFloat16(PatternRewriter& rewriter, ConstantOp op, + const std::pair quant_op) const { + auto [quantizable_op, quantize_operand_num] = quant_op; + // If the constant is an output tensor, do nothing. + if (isa(quantizable_op)) { + return; + } + + TensorType old_result_type = + op.getResult().getType().dyn_cast(); + FloatType quantized_type = FloatType::getF16(op.getContext()); + ShapedType new_result_type = old_result_type.clone(quantized_type); + + // Insert ConvertOp if it does not exist yet. Otherwise, just rewire without + // creating a ConvertOp. + for (OpOperand& connected_op : op.getResult().getUses()) { + ConvertOp convert_op = + dyn_cast_or_null(connected_op.getOwner()); + // ConvertOp already exists. Rewire the existing convert op into f16. + if (convert_op && convert_op.getType() == new_result_type) { + quantizable_op->setOperand(quantize_operand_num, convert_op); + return; + } + } + rewriter.setInsertionPointAfter(op); + ConvertOp new_convert_op = rewriter.create( + op->getLoc(), new_result_type, op.getResult()); + quantizable_op->setOperand(quantize_operand_num, + new_convert_op.getResult()); + } + + // Returns whether a ConvertOp-Operation sequence can be converted into new + // ConstantOp-Convert-Operation. The new ConstantOp has float16 data type. + bool ConvertToFloat16Constant(PatternRewriter& rewriter, + ConstantOp op) const { + for (Operation* connected_op : op.getResult().getUsers()) { + ConvertOp convert_op = dyn_cast_or_null(connected_op); + // Skip if no convert op exists. + if (!convert_op || convert_op.getResult().use_empty()) continue; + + // Get types. + Type old_result_type = op.getResult().getType(); + ShapedType new_result_type = convert_op.getType().dyn_cast(); + + // Proceeds only if the converting is to float16. + if (!new_result_type.getElementType().isF16()) continue; + + // Convert values. 
+ std::vector new_values; + DenseFPElementsAttr value_attr = + op.getValue().cast(); + new_values.reserve(value_attr.getNumElements()); + + for (float value : value_attr.getValues()) { + new_values.push_back(Eigen::half( + std::min(std::max(value, kMinFloat16Value), kMaxFloat16Value))); + } + DenseElementsAttr new_value_attr = DenseFPElementsAttr::get( + new_result_type, ArrayRef(new_values)); + // Create new ConstantOp-ConvertOp-Operation sequences. At this moment, + // old ConstantOp is guaranteed to have one F32->F16 convert op regardless + // of its number of users. + rewriter.setInsertionPointAfter(op); + // create new F16 constant op in that location + ConstantOp new_const = rewriter.create( + op->getLoc(), new_result_type, new_value_attr); + ConvertOp dcast = + rewriter.create(op->getLoc(), old_result_type, new_const); + // replace all convert ops with dq op. + convert_op->replaceAllUsesWith(dcast); + // Return without scanning for the next ConvertOp as only one ConvertOp is + // connected to all quantizable ops. + return true; + } + return false; + } +}; + +// TODO(b/264218457): Refactors the current file to parse preset quantization +// options and allow modular control of quantization specs. +void QuantizeWeightPass::runOnOperation() { + func::FuncOp func = getOperation(); + MLIRContext* ctx = func.getContext(); + + RewritePatternSet patterns(ctx); + patterns.add(ctx); + + FrozenRewritePatternSet frozen_patterns(std::move(patterns)); + + if (failed(applyPatternsAndFoldGreedily(func, frozen_patterns))) { + signalPassFailure(); + } +} + +} // namespace + +// Creates an instance of the StableHLO dialect Quantize Weight pass. +std::unique_ptr> CreateQuantizeWeightPass( + ::stablehlo::quantization::QuantizationOptions quantization_options) { + return std::make_unique(quantization_options); +} +} // namespace stablehlo +} // namespace mlir diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/quantization_options.proto b/tensorflow/compiler/mlir/quantization/stablehlo/quantization_options.proto index 22163b54a6d..41834f95fd8 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/quantization_options.proto +++ b/tensorflow/compiler/mlir/quantization/stablehlo/quantization_options.proto @@ -29,7 +29,7 @@ message QuantizationMethod { // NEXT ID: 2 message PresetQuantizationMethod { // Preset quantization methods that are supported as a stable API. - // NEXT ID: 3 + // NEXT ID: 5 enum PresetMethod { // TODO(b/266173150): Update preset methods after redefining quantization // pattern matching in DarwiNN. @@ -37,14 +37,24 @@ message PresetQuantizationMethod { METHOD_UNSPECIFIED = 0; // go/do-include-enum-unspecified // Apply default weight-only quantization. Weights are quantized during - // conversion, then dequantized during inference. Data type is as follows: - // Weight: i8, Bias: f32, Activation: f32, Input/output: f32 + // conversion, then dequantized during inference. + // Activation: f32, Weight: qi8, Bias: f32 WEIGHT_ONLY = 1; // Apply default dynamic range quantization. Quantized tensor value's - // ranges are determined during graph runtime. Data type is as follows: - // Weight: i8, Bias: f32, Activation: f32, Input/output: f32 - DYNAMIC_RANGE = 2; + // ranges are determined during graph runtime. + // Activation: f32, Weight: qi8, Bias: f32 + POST_TRAINING_QUANTIZATION_DYNAMIC_RANGE = 2; + + // Apply float16 quantization to all the weights. Quantized weights will be + // dequantized before running inference. 
+ // Activation: f32, Weight: f16, Bias: f16 + FLOAT16 = 3; + + // Apply static range quantization. The quantization range is determined + // via calibration phase and quantized during conversion. + // Activation: qi8, Weight: qi8, Bias: qi32 + POST_TRAINING_QUANTIZATION_STATIC_RANGE = 4; } PresetMethod preset_method = 1; } diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/quantize_passes.cc b/tensorflow/compiler/mlir/quantization/stablehlo/quantize_passes.cc new file mode 100644 index 00000000000..05290bcb126 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/quantize_passes.cc @@ -0,0 +1,32 @@ +/* Copyright 2023 The StableHLO Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/compiler/mlir/quantization/stablehlo/quantize_passes.h" + +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" + +namespace stablehlo { +namespace quantization { + +void AddQuantizationPasses(mlir::PassManager& pass_manager, + const QuantizationOptions& quantization_options) { + pass_manager.addNestedPass( + mlir::stablehlo::CreateQuantizeWeightPass(quantization_options)); +} + +} // namespace quantization +} // namespace stablehlo diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/quantize_passes.h b/tensorflow/compiler/mlir/quantization/stablehlo/quantize_passes.h new file mode 100644 index 00000000000..d754be94fc6 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/quantize_passes.h @@ -0,0 +1,31 @@ +/* Copyright 2023 The StableHLO Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_QUANTIZE_PASSES_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_QUANTIZE_PASSES_H_ + +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_options.pb.h" + +namespace stablehlo { +namespace quantization { +// Adds passes for quantization of individual quantizable components. +// (i.e. 
activation, weight, bias) +void AddQuantizationPasses(mlir::PassManager& pass_manager, + const QuantizationOptions& quantization_options); + +} // namespace quantization +} // namespace stablehlo + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_QUANTIZE_PASSES_H_ diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/BUILD b/tensorflow/compiler/mlir/quantization/stablehlo/tests/BUILD new file mode 100644 index 00000000000..00c76a029e9 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/BUILD @@ -0,0 +1,29 @@ +load("//tensorflow:tensorflow.default.bzl", "filegroup") +load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + licenses = ["notice"], +) + +glob_lit_tests( + data = [":test_utilities"], + driver = "//tensorflow/compiler/mlir/quantization/stablehlo:run_lit.sh", + size_override = { + }, + tags_override = { + }, + test_file_exts = ["mlir"], +) + +# Bundle together all of the test utilities that are used by tests. +filegroup( + name = "test_utilities", + testonly = True, + data = [ + "@llvm-project//llvm:FileCheck", + "@llvm-project//llvm:not", + "@llvm-project//mlir:run_lit.sh", + # TODO(b/254144841): Add tests in this directory with the proper stablehlo-opt. + ], +) diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/BUILD b/tensorflow/compiler/mlir/quantization/tensorflow/BUILD index a8045b28116..2d42d137f9b 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/BUILD +++ b/tensorflow/compiler/mlir/quantization/tensorflow/BUILD @@ -8,8 +8,8 @@ load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") package_group( name = "internal_visibility_allowlist_package", packages = [ - "//tensorflow/compiler/mlir/quantization/...", "//tensorflow/compiler/mlir/lite/...", + "//tensorflow/compiler/mlir/quantization/...", ] + internal_visibility_allowlist(), ) @@ -430,7 +430,7 @@ cc_library( "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops", "//tensorflow/compiler/mlir/tensorflow:tensorflow_passes", "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", - "//tensorflow/compiler/mlir/tf2xla:xla_legalize_tf", + "//tensorflow/compiler/mlir/tf2xla/transforms:xla_legalize_tf", "//tensorflow/compiler/mlir/utils:name_utils", "//tensorflow/compiler/xla:xla_data_proto_cc", "//tensorflow/compiler/xla/mlir_hlo", @@ -442,6 +442,7 @@ cc_library( "//tensorflow/core/platform:env", "//tensorflow/core/platform:macros", "//tensorflow/core/platform:path", + "//tensorflow/core/tpu:tpu_defs", "//tensorflow/lite/kernels:padding", "//tensorflow/lite/kernels/internal:quantization_util", "//tensorflow/tsl/platform:str_util", diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/ops/tf_op_quant_spec.cc b/tensorflow/compiler/mlir/quantization/tensorflow/ops/tf_op_quant_spec.cc index f5d0d9a3542..17e12765f6f 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/ops/tf_op_quant_spec.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/ops/tf_op_quant_spec.cc @@ -48,7 +48,8 @@ std::unique_ptr GetTFOpQuantSpec(Operation* op) { } } else if (function_name.contains("matmul")) { spec->coeff_op_quant_dim[1] = -1; - if (function_name.contains("with_bias")) { + if (function_name.contains("with_bias") || + function_name.contains("and_bias")) { spec->biases_params[2] = {{0, 1}, quant::GetUniformQuantizedTypeForBias}; } diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/cast_bf16_ops_to_f32.cc 
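For context, the new AddQuantizationPasses entry point is meant to be driven from an MLIR pass pipeline. A minimal, hypothetical caller is sketched below; the module/context handling is illustrative, and the commented-out setter chain for picking a preset method is an assumption about the generated proto API rather than something shown in this diff:

#include "mlir/IR/BuiltinOps.h"          // from @llvm-project
#include "mlir/Pass/PassManager.h"       // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_options.pb.h"
#include "tensorflow/compiler/mlir/quantization/stablehlo/quantize_passes.h"

// Sketch only: runs the StableHLO weight-quantization pipeline on `module`.
mlir::LogicalResult QuantizeModuleWeights(mlir::ModuleOp module,
                                          mlir::MLIRContext& context) {
  ::stablehlo::quantization::QuantizationOptions options;
  // Assumed accessor chain; consult the generated proto API for the real one.
  // options.mutable_quantization_method()
  //     ->mutable_preset_quantization_method()
  //     ->set_preset_method(PresetQuantizationMethod::FLOAT16);

  mlir::PassManager pm(&context);
  ::stablehlo::quantization::AddQuantizationPasses(pm, options);
  return pm.run(module);
}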
b/tensorflow/compiler/mlir/quantization/tensorflow/passes/cast_bf16_ops_to_f32.cc index a7839292fee..f8d46612814 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/cast_bf16_ops_to_f32.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/cast_bf16_ops_to_f32.cc @@ -47,6 +47,66 @@ class CastBf16OpsToF32Pass void runOnOperation() override; }; +class CastBf16OpsToF32 : public RewritePattern { + public: + explicit CastBf16OpsToF32(MLIRContext* context) + : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {} + + private: + LogicalResult match(Operation* op) const override { + if (isa(op) || + op->getName().hasTrait()) { + return failure(); + } + for (Value input : op->getOperands()) { + if (getElementTypeOrSelf(input).isBF16()) { + return success(); + } + } + for (Value value : op->getResults()) { + if (getElementTypeOrSelf(value).isBF16()) { + return success(); + } + } + return failure(); + } + + void rewrite(Operation* op, PatternRewriter& rewriter) const override { + // Casts inputs of the operation. + for (int i = 0; i < op->getNumOperands(); i++) { + Value input = op->getOperand(i); + if (getElementTypeOrSelf(input).isBF16()) { + Value f32_cast = rewriter.create( + op->getLoc(), + CloneTypeWithNewElementType(input.getType(), rewriter.getF32Type()), + input); + op->setOperand(i, f32_cast); + } + } + + // Casts BF16 outputs of the operation. + for (Value value : op->getResults()) { + if (getElementTypeOrSelf(value).isBF16()) { + value.setType(CloneTypeWithNewElementType(value.getType(), + rewriter.getF32Type())); + rewriter.setInsertionPointAfterValue(value); + for (Operation* user : op->getUsers()) { + for (int i = 0; i < user->getNumOperands(); i++) { + if (user->getOperand(i) == value) { + Value bf16_cast = rewriter.create( + user->getLoc(), + CloneTypeWithNewElementType(value.getType(), + rewriter.getBF16Type()), + value); + user->setOperand(i, bf16_cast); + } + } + } + } + } + } +}; + #include "tensorflow/compiler/mlir/quantization/tensorflow/passes/cast_bf16_ops_to_f32.inc" void CastBf16OpsToF32Pass::runOnOperation() { @@ -54,6 +114,7 @@ void CastBf16OpsToF32Pass::runOnOperation() { RewritePatternSet patterns(ctx); auto module_op = getOperation(); + patterns.add(ctx); populateWithGenerated(patterns); if (failed(applyPatternsAndFoldGreedily(module_op, std::move(patterns)))) { diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/cast_bf16_ops_to_f32.td b/tensorflow/compiler/mlir/quantization/tensorflow/passes/cast_bf16_ops_to_f32.td index 5e38c6b1681..ace1a77e6f3 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/cast_bf16_ops_to_f32.td +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/cast_bf16_ops_to_f32.td @@ -32,148 +32,3 @@ def RemoveUnneededCastOps : Pat< (replaceWithValue $input), [(AreTheSameElementType $input, $output)]>; -// Cast BF16 Conv2D ops to FP32 Conv2D ops. Inputs and -// filters will be casted to fp32 as well, and unused -// BF16 constant values will be removed by the compiler. 
-def CastBFloat16ConvToFloat32 : Pat< - (TF_Conv2DOp:$res - $input, $filter, $strides, $use_cudnn_on_gpu, $padding, - $explicit_paddings, IsDataFormatNHWC:$data_format, $dilations), - (TF_CastOp - (TF_Conv2DOp - (TF_CastOp - $input, /*truncate=*/ConstBoolAttrFalse, - (returnType (CloneTypeWithF32ElementType $input))), - (TF_CastOp - $filter, /*truncate=*/ConstBoolAttrFalse, - (returnType (CloneTypeWithF32ElementType $filter))), - $strides, $use_cudnn_on_gpu, $padding, - $explicit_paddings, IsDataFormatNHWC:$data_format, $dilations, - (returnType (CloneTypeWithF32ElementType $res))), - /*truncate=*/ConstBoolAttrFalse), - [(IsBF16ElementType $input), - (IsBF16ElementType $filter)], - (addBenefit 1)>; - -// Casts BF16 BiasAdd ops to F32 to optimize quantizable ops followed by -// BiasAdd ops. This cast will cover Conv + BiasAdd, MatMul + BiasAdd, -// etc. -def CastBFloat16BiasAddToFloat32 : Pat< - (TF_BiasAddOp:$res - $input, $bias, IsDataFormatNHWC:$bias_data_format), - (TF_CastOp - (TF_BiasAddOp - (TF_CastOp - $input, /*truncate=*/ConstBoolAttrFalse, - (returnType (CloneTypeWithF32ElementType $input))), - (TF_CastOp - $bias, /*truncate=*/ConstBoolAttrFalse, - (returnType (CloneTypeWithF32ElementType $bias))), - IsDataFormatNHWC:$bias_data_format, - (returnType (CloneTypeWithF32ElementType $res))), - /*truncate=*/ConstBoolAttrFalse), - [(IsBF16ElementType $input), - (IsBF16ElementType $bias)], - (addBenefit 1)>; - -def CastBFloat16AvgPoolToFloat32 : Pat< - (TF_AvgPoolOp:$res - $input, $ksize, $strides, $padding, - IsDataFormatNHWC:$data_format), - (TF_CastOp - (TF_AvgPoolOp - (TF_CastOp - $input, /*truncate=*/ConstBoolAttrFalse, - (returnType (CloneTypeWithF32ElementType $input))), - $ksize, $strides, $padding, - IsDataFormatNHWC:$data_format, - (returnType (CloneTypeWithF32ElementType $res))), - /*truncate=*/ConstBoolAttrFalse), - [(IsBF16ElementType $input)], - (addBenefit 1)>; - -def CastBFloat16MatMulToFloat32 : Pat< - (TF_MatMulOp:$res - $input, $filter, $transpose_a, $transpose_b), - (TF_CastOp - (TF_MatMulOp - (TF_CastOp - $input, /*truncate=*/ConstBoolAttrFalse, - (returnType (CloneTypeWithF32ElementType $input))), - (TF_CastOp - $filter, /*truncate=*/ConstBoolAttrFalse, - (returnType (CloneTypeWithF32ElementType $filter))), - $transpose_a, $transpose_b, - (returnType (CloneTypeWithF32ElementType $res))), - /*truncate=*/ConstBoolAttrFalse), - [(IsBF16ElementType $input), - (IsBF16ElementType $filter)], - (addBenefit 1)>; - -def CastBFloat16BatchMatMulV2ToFloat32 : Pat< - (TF_BatchMatMulV2Op:$res - $input, $filter, $adj_x, $adj_y), - (TF_CastOp - (TF_BatchMatMulV2Op - (TF_CastOp - $input, /*truncate=*/ConstBoolAttrFalse, - (returnType (CloneTypeWithF32ElementType $input))), - (TF_CastOp - $filter, /*truncate=*/ConstBoolAttrFalse, - (returnType (CloneTypeWithF32ElementType $filter))), - $adj_x, $adj_y, - (returnType (CloneTypeWithF32ElementType $res))), - /*truncate=*/ConstBoolAttrFalse), - [(IsBF16ElementType $input), - (IsBF16ElementType $filter)], - (addBenefit 1)>; - -def CastBFloat16DepthwiseConvToFloat32 : Pat< - (TF_DepthwiseConv2dNativeOp:$res - $input, $filter, $strides, $padding, - $explicit_paddings, IsDataFormatNHWC:$data_format, $dilations), - (TF_CastOp - (TF_DepthwiseConv2dNativeOp - (TF_CastOp - $input, /*truncate=*/ConstBoolAttrFalse, - (returnType (CloneTypeWithF32ElementType $input))), - (TF_CastOp - $filter, /*truncate=*/ConstBoolAttrFalse, - (returnType (CloneTypeWithF32ElementType $filter))), - $strides, $padding, $explicit_paddings, - 
IsDataFormatNHWC:$data_format, $dilations, - (returnType (CloneTypeWithF32ElementType $res))), - /*truncate=*/ConstBoolAttrFalse), - [(IsBF16ElementType $input), - (IsBF16ElementType $filter)], - (addBenefit 1)>; - -def CastBFloat16GatherToFloat32 : Pat< - (TF_GatherV2Op:$res - $params, $indices, $axis, $batch_dims), - (TF_CastOp - (TF_GatherV2Op - (TF_CastOp - $params, /*truncate=*/ConstBoolAttrFalse, - (returnType (CloneTypeWithF32ElementType $params))), - $indices, $axis, $batch_dims, - (returnType (CloneTypeWithF32ElementType $res))), - /*truncate=*/ConstBoolAttrFalse), - [(IsBF16ElementType $params), - (IsConstTensor $params)], - (addBenefit 1)>; - -// Converts an AddV2 op accepting two bfloat16 operands into the one taking two -// float32 operands. -def CastBFloat16AddV2ToFloat32 : Pat< - (TF_AddV2Op:$res $x, $y), - (TF_CastOp - (TF_AddV2Op - (TF_CastOp $x, /*Truncate=*/ConstBoolAttrFalse, - (returnType (CloneTypeWithF32ElementType $x))), - (TF_CastOp $y, /*Truncate=*/ConstBoolAttrFalse, - (returnType (CloneTypeWithF32ElementType $y)))), - /*Truncate=*/ConstBoolAttrFalse), - [(IsBF16ElementType $x), - (IsBF16ElementType $y), - (IsBF16ElementType $res)]>; diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/convert_tpu_model_to_cpu.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/convert_tpu_model_to_cpu.cc index 586fe870808..bb606714023 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/convert_tpu_model_to_cpu.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/convert_tpu_model_to_cpu.cc @@ -23,7 +23,9 @@ limitations under the License. #include "mlir/Transforms/Passes.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/transforms/passes.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/passes/passes.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/remove_identity_op_pattern.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/core/tpu/tpu_defs.h" namespace mlir { namespace quant { @@ -56,17 +58,24 @@ class RemoveTpuOp : public RewritePattern { : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {} private: - LogicalResult matchAndRewrite(Operation* call_op, + LogicalResult matchAndRewrite(Operation* op, PatternRewriter& rewriter) const override { + // Remove `_tpu_replicate` attributes on each operation first. + if (op->hasAttr(tensorflow::kTPUReplicateAttr)) { + op->removeAttr(tensorflow::kTPUReplicateAttr); + return success(); + } + + // Remove TPU operations. if (isa(call_op)) { - call_op->erase(); + TF::TPUOrdinalSelectorOp>(op)) { + op->erase(); } else if (auto replicated_input_op = - dyn_cast_or_null(call_op)) { + dyn_cast_or_null(op)) { // TODO(b/267700110): Handle multiple input/output cases. rewriter.replaceOp(replicated_input_op, replicated_input_op.getInputs()); } else if (auto replicated_output_op = - dyn_cast_or_null(call_op)) { + dyn_cast_or_null(op)) { // TODO(b/267700110): Handle multiple input/output cases. 
rewriter.replaceOp(replicated_output_op, replicated_output_op.getInput()); } else { @@ -115,6 +124,7 @@ void ConvertTpuModelToCpuPass::runOnOperation() { patterns.add(ctx); patterns.add(ctx); + patterns.add(ctx); if (failed(applyPatternsAndFoldGreedily(module_op, std::move(patterns)))) { module_op.emitError() << "quant-convert-tpu-model-to-cpu pattern " diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/insert_quantized_functions.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/insert_quantized_functions.cc index 84e6f0781fb..96d42ebedf3 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/insert_quantized_functions.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/insert_quantized_functions.cc @@ -171,9 +171,8 @@ void InsertQuantizedFunctionsPass::runOnOperation() { StatusScopedDiagnosticHandler diagnostic_handler(context); if (failed(pm.run(*module_ref))) { - emitError(module.getLoc()) - << "failed to apply the optimization: " - << diagnostic_handler.ConsumeStatus().error_message(); + emitError(module.getLoc()) << "failed to apply the optimization: " + << diagnostic_handler.ConsumeStatus().message(); signalPassFailure(); return; } diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/lift_quantizable_spots_as_functions.td b/tensorflow/compiler/mlir/quantization/tensorflow/passes/lift_quantizable_spots_as_functions.td index 04934a479de..1eaaecf5b61 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/lift_quantizable_spots_as_functions.td +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/lift_quantizable_spots_as_functions.td @@ -152,6 +152,22 @@ def LiftMatmulWithBias : Pat< (NamedAttr<"transpose_b"> $transpose_b))), [(IsNotInLiftedFunc $res)], (addBenefit 5)>; +// TODO(b/278493977): Create generic implementation of lifting any fused op +// with any reshaping op +def LiftMatmulWithReshapeAndBias : Pat< + (TF_BiasAddOp:$res + (TF_ReshapeOp:$out + (TF_MatMulOp $a, $b, $transpose_a, $transpose_b), + $shape), + $bias, IsDataFormatNHWC:$bias_data_format), + (LiftAsFunctionCall<"composite_matmul_with_reshape_and_bias_fn"> + (ArgumentList $a, $b, $bias, $shape), + (ResultList $res), + (NamedAttributeList + (NamedAttr<"transpose_a"> $transpose_a), + (NamedAttr<"transpose_b"> $transpose_b))), + [(IsNotInLiftedFunc $res)], (addBenefit 5)>; + def LiftConv3dWithBias : Pat< (TF_BiasAddOp:$res (TF_Conv3DOp $input, $filter, $strides, $padding, diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_lifting.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_lifting.cc index 6c73b266837..1d4db2b7067 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_lifting.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_lifting.cc @@ -13,28 +13,50 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ +#include +#include #include #include +#include #include #include #include "absl/algorithm/container.h" +#include "absl/container/flat_hash_set.h" +#include "absl/strings/match.h" +#include "absl/strings/str_cat.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/Sequence.h" #include "llvm/ADT/StringRef.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/CommandLine.h" #include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/DialectRegistry.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Matchers.h" // from @llvm-project #include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project -#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h" +#include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/passes/remove_identity_op_pattern.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/passes/utils.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/einsum.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" namespace mlir { namespace quant { @@ -47,7 +69,11 @@ class PrepareLiftingPass PrepareLiftingPass() = default; - explicit PrepareLiftingPass(const OpSet op_set) : op_set_(op_set) {} + explicit PrepareLiftingPass(OpSet op_set) { op_set_ = op_set; } + + PrepareLiftingPass(const PrepareLiftingPass& other) { + op_set_ = other.op_set_; + } StringRef getArgument() const final { // This is the argument used to refer to the pass in @@ -68,7 +94,15 @@ class PrepareLiftingPass void runOnOperation() override; private: - OpSet op_set_; + Option op_set_{ + *this, "target-opset", llvm::cl::init(OpSet::TF), + llvm::cl::desc("Choose target opset."), + llvm::cl::values( + clEnumValN(OpSet::TF, "TF", + "Uses TF ops that mimic quantization behavior"), + clEnumValN(OpSet::XLA, "XLA", "Uses TF XLA ops"), + clEnumValN(OpSet::UNIFORM_QUANTIZED, "UNIFORM_QUANTIZED", + "Uses TF Uniform Quantized ops"))}; }; // Check if given indices in `val1` has same number of elements as given @@ -116,10 +150,14 @@ LogicalResult MatchSupportedAffineOp(Operation* op, Value& binding_output, is_supported_affine_op = data_format.getValue().equals("NHWC") || data_format.getValue().equals("NDHWC"); } - } else if (llvm::isa(op)) { + } else if (llvm::isa(op)) { if (const auto adj_y = op->getAttrOfType("adj_y")) { is_supported_affine_op = !adj_y.getValue(); } 
+ } else if (llvm::isa(op)) { + if (const auto adj_y = op->getAttrOfType("transpose_b")) { + is_supported_affine_op = !adj_y.getValue(); + } } if (!is_supported_affine_op) return failure(); @@ -141,7 +179,7 @@ Value MakeOneDimValueBroadcastable(OpBuilder& builder, Location loc, } int64_t num_elements = value_shape.getNumElements(); - llvm::SmallVector new_shape; + SmallVector new_shape; for (auto idx : llvm::reverse(llvm::seq(0, rhs_shape.getRank()))) { const int64_t rhs_dim = rhs_shape.getDimSize(idx); if (num_elements % rhs_dim != 0) { @@ -260,6 +298,243 @@ Value MultiplyFakeQuantValue(OpBuilder& builder, Location loc, Value value, return ConstantFoldOpIfPossible(dequantize).front(); } +// Generate an einsum equation from the given DotDimensionNumber. +std::string CreateEinsumEquation( + const xla::DotDimensionNumbers& dot_dimension_numbers, const int lhs_rank, + const int rhs_rank) { + // Prepare necessary indices. + absl::flat_hash_set lhs_batch_idx, rhs_batch_idx; + absl::flat_hash_set lhs_contract_idx, rhs_contract_idx; + lhs_batch_idx.insert(dot_dimension_numbers.lhs_batch_dimensions().begin(), + dot_dimension_numbers.lhs_batch_dimensions().end()); + lhs_contract_idx.insert( + dot_dimension_numbers.lhs_contracting_dimensions().begin(), + dot_dimension_numbers.lhs_contracting_dimensions().end()); + rhs_batch_idx.insert(dot_dimension_numbers.rhs_batch_dimensions().begin(), + dot_dimension_numbers.rhs_batch_dimensions().end()); + rhs_contract_idx.insert( + dot_dimension_numbers.rhs_contracting_dimensions().begin(), + dot_dimension_numbers.rhs_contracting_dimensions().end()); + + // Generate equation. + std::string lhs_eq = ""; + std::string rhs_eq = ""; + std::string out_eq = ""; + char c = 'a'; + std::vector lhs_batch_dims; + std::vector lhs_contract_dims; + for (int i = 0; i < lhs_rank; i++) { + absl::StrAppend(&lhs_eq, std::string(1, c)); + if (lhs_batch_idx.contains(i)) { + lhs_batch_dims.push_back(c); + } else if (lhs_contract_idx.contains(i)) { + lhs_contract_dims.push_back(c); + } + c++; + } + + int batch_trace_idx = 0; + int contract_trace_idx = 0; + const bool rhs_only_batch = lhs_batch_dims.empty(); + for (int i = 0; i < rhs_rank; i++) { + if (rhs_batch_idx.contains(i)) { + if (rhs_only_batch) { + rhs_eq.push_back(c); + lhs_batch_dims.push_back(c); + c++; + } else { + rhs_eq.push_back(lhs_batch_dims[batch_trace_idx]); + batch_trace_idx++; + } + } else if (rhs_contract_idx.contains(i)) { + absl::StrAppend(&rhs_eq, + std::string(1, lhs_contract_dims[contract_trace_idx])); + contract_trace_idx++; + } else { + rhs_eq += c; + c++; + } + } + + // Create out_eq by merging lhs and rhs. + // In XlaDotv2 style - batch dim - leftover from lhs - leftover from rhs. 
+ for (const char c : lhs_batch_dims) { + absl::StrAppend(&out_eq, std::string(1, c)); + } + for (const char c : lhs_eq) { + if (!absl::StrContains(out_eq, c) && !absl::StrContains(rhs_eq, c)) { + absl::StrAppend(&out_eq, std::string(1, c)); + } + } + for (const char c : rhs_eq) { + if (!absl::StrContains(out_eq, c) && !absl::StrContains(lhs_eq, c)) { + absl::StrAppend(&out_eq, std::string(1, c)); + } + } + + return absl::StrCat(lhs_eq, ",", rhs_eq, "->", out_eq); +} + +Value CreateEinsumOpFromXlaDotV2Op(OpBuilder& builder, const Location loc, + Value lhs, Value rhs, Value output, + StringAttr dot_dimension_numbers_str) { + xla::DotDimensionNumbers dot_dimension_numbers; + dot_dimension_numbers.ParseFromString(dot_dimension_numbers_str.str()); + SmallVector input_arguments = {lhs, rhs}; + const int lhs_rank = + lhs.getType().template cast().getShape().size(); + const int rhs_rank = + rhs.getType().template cast().getShape().size(); + + const std::string einsum_equation = + CreateEinsumEquation(dot_dimension_numbers, lhs_rank, rhs_rank); + + return builder.create(loc, output.getType(), input_arguments, + builder.getStringAttr(einsum_equation)); +} + +// Restores the collapsed dimensions to the `tensor_type`. `collapsed_dims` +// designate the dimension indices that were collapsed to produce `tensor_type`. +// The restored dimensions' sizes are 1, according to the semantics of +// `XlaGatherOp (https://www.tensorflow.org/xla/operation_semantics#gather). The +// resulting type's shape has `tensor_type.size() + collapsed_dims.size()` +// dimensions. +RankedTensorType RestoreCollapsedDimensions( + const RankedTensorType tensor_type, + const absl::flat_hash_set& collapsed_dims) { + ArrayRef original_tensor_shape = tensor_type.getShape(); + const int output_tensor_rank = + original_tensor_shape.size() + collapsed_dims.size(); + auto shape_itr = tensor_type.getShape().begin(); + + // Populate the dimensions of the output shape, including the restored + // dimensions. + SmallVector output_shape(output_tensor_rank); + for (int i = 0; i < output_tensor_rank; i++) { + if (collapsed_dims.contains(i)) { + // The collapsed dimension's size should have been 1, so it restores the + // dimension with size 1. + output_shape[i] = 1; + } else { + output_shape[i] = *shape_itr; + shape_itr++; + } + } + + return RankedTensorType::get(output_shape, tensor_type.getElementType()); +} + +// Determines the output type of the `SliceOp` when it is being inserted in +// place of a `XlaGatherOp`. When the dimensions of `xla_gather_op_output_type` +// is known, the `collapsed_dims` are restored. `xla_gather_op_output_type` is +// the result of collapsing the `collapsed_dims`, but the `SliceOp`'s output +// should not have the dimensions collapsed already. Returns +// `xla_gather_op_output_type` unchanged if the rank is unknown. +// +// Examples: +// * If `xla_gather_op_output_type` == tensor<*xf32>, then it returns: +// tensor<*xf32>. +// * If `xla_gather_op_output_type` == tensor<3x5xi32> and `collapsed_dims` == +// {0}, then it returns: tensor<1x3x5xi32>. +// * If `xla_gather_op_output_type` == tensor<3x5xf32> and `collapsed_dims` == +// {1, 3}, then it returns: tensor<3x1x5x1xf32>. 
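A concrete illustration of the equation builder (values chosen for illustration only; CreateEinsumEquation is file-local, so this is not a runnable test against the real symbol): for a plain batched matmul with rank-3 operands, batch dimension 0 on both sides, and contracting dimensions lhs=2 / rhs=1, the labels come out as lhs="abc", rhs="acd", out="abd".

// Illustrative dimension numbers for a rank-3 batched matmul.
xla::DotDimensionNumbers dnums;
dnums.add_lhs_batch_dimensions(0);
dnums.add_rhs_batch_dimensions(0);
dnums.add_lhs_contracting_dimensions(2);
dnums.add_rhs_contracting_dimensions(1);
// CreateEinsumEquation(dnums, /*lhs_rank=*/3, /*rhs_rank=*/3) then produces
//   "abc,acd->abd"
// i.e. the usual batched-matmul einsum: batch dim first, then the leftover
// lhs dim, then the leftover rhs dim.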
+Type GetSliceOpOutputType(Type xla_gather_op_output_type, + const absl::flat_hash_set& collapsed_dims) { + if (auto ranked_output_type = + xla_gather_op_output_type.dyn_cast(); + ranked_output_type) { + return RestoreCollapsedDimensions(ranked_output_type, collapsed_dims); + } + + return xla_gather_op_output_type; +} + +// TODO (b/275225582): Supports Xla Gather op in general case. +bool IsXlaGatherWithoutBatch(Value operand, Value start_indices) { + auto operand_type = operand.getType().dyn_cast_or_null(); + auto start_indices_type = + start_indices.getType().dyn_cast_or_null(); + if (start_indices_type == nullptr || operand_type == nullptr) return false; + return start_indices_type.getShape().size() == 1; +} + +Value CreateSliceAndReshapeOpFromXlaGatherOpWithoutBatch( + OpBuilder& builder, const Location loc, Value operand, Value start_indices, + Value slice_sizes, Value output, StringAttr dimension_numbers_str) { + // Reads dimension numbers. + xla::GatherDimensionNumbers dimension_numbers; + dimension_numbers.ParseFromString(dimension_numbers_str.str()); + + // Construct full start_indices with given start_indices and + // start_index_map. + const ArrayRef operand_shape = + operand.getType().cast().getShape(); + const int64_t operand_rank = operand_shape.size(); + + // Fills zeros if start_index is not given in start_indices. + Value empty_start_indices = builder.create( + loc, RankedTensorType::get({operand_rank}, builder.getI64Type()), + /*shape=*/Create1DConstValue(builder, loc, {operand_rank}), + /*value=*/CreateScalarConstValue(builder, loc, 0)); + + // Converts start_index_map proto to tensor. + const int64_t index_map_size = dimension_numbers.start_index_map().size(); + SmallVector indices(index_map_size); + for (int64_t i = 0; i < index_map_size; i++) { + indices[i] = dimension_numbers.start_index_map()[i]; + } + + // Fill elements from start_indices with start_index_map + Value scattered_start_indices = builder.create( + loc, empty_start_indices, + /*indices=*/ + builder.create( + loc, RankedTensorType::get({index_map_size, 1}, builder.getI64Type()), + Create1DConstValue(builder, loc, indices), + Create1DConstValue(builder, loc, {index_map_size, 1})), + /*value=*/ + builder.create( + loc, + RankedTensorType::get( + start_indices.getType().template cast().getShape(), + builder.getI64Type()), + start_indices)); + + absl::flat_hash_set collapsed_dims; + collapsed_dims.insert(dimension_numbers.collapsed_slice_dims().begin(), + dimension_numbers.collapsed_slice_dims().end()); + + // Slice operand by constructed start_indices and slice_sizes. + auto slice_op = builder.create( + loc, GetSliceOpOutputType(output.getType(), collapsed_dims), operand, + /*start_indices=*/scattered_start_indices, + /*slice_sizes=*/ + builder.create( + loc, + RankedTensorType::get( + slice_sizes.getType().template cast().getShape(), + builder.getI64Type()), + slice_sizes)); + + // Collapses dimensions by reshaping. 
+ SmallVector new_shape(operand_rank - collapsed_dims.size()); + for (int64_t i = 0, j = 0; i < operand_rank; i++) { + if (!collapsed_dims.contains(i)) { + new_shape[j++] = operand_shape[i]; + } + } + if (!new_shape.empty()) new_shape[0] = -1; + return builder.create( + loc, output.getType(), slice_op, + Create1DConstValue(builder, loc, new_shape)); +} + +bool IsPrecisionEmpty(StringAttr prec_str) { + xla::PrecisionConfig prec; + prec.ParseFromString(prec_str.str()); + return !prec.operand_precision_size(); +} + #include "tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_lifting.inc" void PrepareLiftingPass::runOnOperation() { diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_lifting.td b/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_lifting.td index 52d4505781c..6f6e6d89da6 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_lifting.td +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_lifting.td @@ -21,12 +21,48 @@ include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td" include "tensorflow/compiler/mlir/quantization/tensorflow/passes/utils.td" include "mlir/Dialect/Arith/IR/ArithOps.td" +// Creates Einsum Op from XlaDotV2 Op by generating equation. +def CreateEinsumOpFromXlaDotV2Op : NativeCodeCall< + "CreateEinsumOpFromXlaDotV2Op($_builder, $_loc, $0...)">; + +// Only handles the case where precision config is default. +def IsPrecisionEmpty : + Constraint>; + +// Convert XlaDotV2 Op to Einsum Op with above two functions. +def ConvertXlaDotV2OpToEinsumOp : Pat< + (TF_XlaDotV2Op:$dot $lhs, $rhs, $dot_dimension_numbers, $precision_config), + (CreateEinsumOpFromXlaDotV2Op $lhs, $rhs, $dot, $dot_dimension_numbers), + [(IsPrecisionEmpty $precision_config)]>; + // Converts arith.constant ops from freezing passes back to tf.Const ops. def ConvertArithConstToTfConst : Pat< (Arith_ConstantOp:$res DenseElementsAttr:$value), (TF_ConstOp $value), [(AnyStaticShapeTensor $res)]>; +// Converts CheckNumerics op to Identity +def ConvertCheckNumerics : Pat< + (TF_CheckNumericsOp $arg, $msg), + (TF_IdentityOp $arg)>; + +// Only handles the case where batch_dimension is empty. +def IsXlaGatherWithoutBatch : + Constraint>; + +// Create Slice op from XlaGather op without batch dimension. +def CreateSliceAndReshapeOpFromXlaGatherOpWithoutBatch : NativeCodeCall< + "CreateSliceAndReshapeOpFromXlaGatherOpWithoutBatch($_builder, $_loc, $0...)">; + +// Convert XlaGather op without batch to Slice op with above two functions. +def ConvertXlaGatherOpWithoutBatch : Pat< + (TF_XlaGatherOp:$gather $operand, + $start_indices, $slice_sizes, $dimension_numbers, $indices_are_sorted), + (CreateSliceAndReshapeOpFromXlaGatherOpWithoutBatch $operand, + $start_indices, $slice_sizes, $gather, $dimension_numbers), + [(IsXlaGatherWithoutBatch $operand, $start_indices)]>; + + // Converts tf.FusedBatchNormV3 into a sequence of more primitive arithmetic // operations. 
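To make the gather-to-slice rewrite above concrete, here is a worked example under assumed shapes (not taken from a test in this change): with an operand of shape [16, 64], rank-1 start_indices, start_index_map = [0], collapsed_slice_dims = [0], and slice_sizes = [1, 64], the scattered start indices become [index, 0], the Slice that replaces the XlaGather produces a [1, 64] result (GetSliceOpOutputType restores the collapsed dimension with size 1), and the final Reshape with new_shape = [-1] collapses it back to the original gather output shape [64].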
Specifically, performs the following calculation: // @@ -104,6 +140,19 @@ def ConvertAddToBiasAdd : Pat< [(HasRankOf<1> $add_rhs_value), (HasEqualElementSize<[-1], [0]> $conv_out, $add_rhs)]>; +// TODO(b/278493977): Create generic implementation of lifting any fused op +// with any reshaping op +def ConvertAddWithReshapeToBiasAddWithReshape : Pat< + (TF_AddV2Op + (TF_ReshapeOp:$reshape_out + (SupportedAffineOpMatcher $_, $_, $_), + $_ + ), + (TF_ConstOp:$add_rhs IsFloatElementsAttr:$add_rhs_value)), + (TF_BiasAddOp $reshape_out, $add_rhs, (CreateStringAttr<"NHWC">)), + [(HasRankOf<1> $add_rhs_value), + (HasEqualElementSize<[-1], [0]> $reshape_out, $add_rhs)]>; + // Fuse consecutive BiasAddOp and an AddV2Op. def FuseBiasAndAddV2 : Pat< (TF_AddV2Op diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/quantize.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/quantize.cc index ce8aac4b8fa..bf93774e67f 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/quantize.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/quantize.cc @@ -15,38 +15,42 @@ limitations under the License. // Copied and modified from // //third_party/tensorflow/compiler/mlir/lite/transforms/quantize.cc // This transformation pass applies quantization on TF dialect. +#include #include #include #include "absl/container/flat_hash_set.h" +#include "llvm/ADT/DenseMap.h" #include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/StringSwitch.h" #include "llvm/Support/Casting.h" +#include "llvm/Support/CommandLine.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project -#include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project -#include "mlir/IR/Matchers.h" // from @llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/OperationSupport.h" // from @llvm-project #include "mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/IR/Types.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project -#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h" #include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h" -#include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h" #include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/ops/tf_op_quant_spec.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/passes/utils.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/core/framework/types.pb.h" namespace mlir { namespace quant { @@ -81,14 +85,28 @@ struct TFQuantizationBase // range 
quantization. static bool AllowDynamicRangeQuantizedOperand( Operation* quantized_op, const CustomMap& custom_op_map) { - return quantization_trait == kDynamicRangeQuantization; + auto call_op = cast(quantized_op); + StringRef function_name = + call_op.getFAttr().cast().getValue(); + // The below can be generalized as there are more read-only ops added such + // as slice. + const bool is_gather = function_name.contains("gather"); + return quantization_trait != kFullQuantization || is_gather; } // All the quantized ops are supported if the quantization method is dynamic // range quantization. static bool AllowDynamicRangeQuantizedResult(Operation* quantized_op, const CustomMap& custom_op_map) { - return quantization_trait == kDynamicRangeQuantization; + auto call_op = cast(quantized_op); + StringRef function_name = + call_op.getFAttr().cast().getValue(); + // The below can be generalized as there are more read-only ops added such + // as slice. + bool is_gather = false; + if (function_name.contains("gather")) is_gather = true; + return quantization_trait != kFullQuantization || + (quantization_trait == kFullQuantization && is_gather); } // If weight_only_quantization is true, the legacy weight-only quantization is @@ -164,7 +182,7 @@ class QuantizeSameScaleOpsPattern LogicalResult matchAndRewrite(quantfork::DequantizeCastOp op, PatternRewriter& rewriter) const override { - llvm::SmallVector quantizing_ops; + SmallVector quantizing_ops; auto users = op.getResult().getUsers(); quantizing_ops.append(users.begin(), users.end()); diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/quantize_composite_functions.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/quantize_composite_functions.cc index 31f0ade7ef8..6c374141025 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/quantize_composite_functions.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/quantize_composite_functions.cc @@ -452,24 +452,47 @@ LogicalResult TransferTFAttributesToTFUniformAttributes( // Set the attributes for ops with the attr_map attribute. 
for (Operation& inner_op : quantized_func.getBody().front().getOperations()) { if (auto uniform_op = - llvm::dyn_cast(inner_op)) { + llvm::dyn_cast(inner_op); + uniform_op != nullptr) { if (failed(FillAttributesForUniformQuantizedConvolutionOp( rewriter, uniform_op, identifier_to_attr, quantization_method, enable_per_channel_quantization))) return failure(); } else if (auto uniform_op = - llvm::dyn_cast( - inner_op)) { + llvm::dyn_cast(inner_op); + uniform_op != nullptr) { if (failed(FillAttributesForUniformQuantizedConvolutionOp( rewriter, uniform_op, identifier_to_attr, quantization_method, enable_per_channel_quantization))) return failure(); } else if (auto uniform_op = - llvm::dyn_cast(inner_op)) { + llvm::dyn_cast(inner_op); + uniform_op != nullptr) { if (failed(FillAttributesForUniformQuantizedDotOp( rewriter, uniform_op, identifier_to_attr, quantization_method, enable_per_channel_quantization))) return failure(); + } else if (auto uniform_op = + llvm::dyn_cast(inner_op); + uniform_op != nullptr) { + if (failed(FillAttributesForUniformQuantizedAddOp( + rewriter, uniform_op, identifier_to_attr, quantization_method, + enable_per_channel_quantization))) + return failure(); + } else if (auto uniform_op = + llvm::dyn_cast(inner_op); + uniform_op != nullptr) { + if (failed(FillAttributesForUniformQuantizedClipByValueOp( + rewriter, uniform_op, identifier_to_attr, quantization_method, + enable_per_channel_quantization))) + return failure(); + } else if (auto uniform_op = + llvm::dyn_cast(inner_op); + uniform_op != nullptr) { + if (failed(FillAttributesForUniformRequantizeOp( + rewriter, uniform_op, identifier_to_attr, quantization_method, + enable_per_channel_quantization))) + return failure(); } } return success(); @@ -535,15 +558,28 @@ LogicalResult TransferAttributes(func::FuncOp float_func, } // Get the corresponding quantized function name from the given function name. -std::string GetQuantizedFunctionName(StringRef func_name) { +std::string GetQuantizedFunctionName(StringRef func_name, + const bool is_hybrid) { if (func_name.startswith(kQuantizedFuncPrefix)) return func_name.str(); if (!func_name.startswith(kCompositeFuncPrefix)) return ""; - return llvm::Twine(kQuantizedFuncPrefix) - .concat(llvm::Twine( - func_name.substr(kCompositeFuncPrefix.size()).rsplit("_fn").first)) - .concat("_fn") - .str(); + auto base_function_name = + llvm::Twine(kQuantizedFuncPrefix) + .concat(llvm::Twine(func_name.substr(kCompositeFuncPrefix.size()) + .rsplit("_fn") + .first)); + + return is_hybrid + ? 
base_function_name.concat("_float_output").concat("_fn").str() + : base_function_name.concat("_fn").str(); +} + +bool ContainsQuantizedReusltType(ArrayRef result_types) { + for (auto current_type : result_types) { + if (!current_type.dyn_cast().getElementType().isF32()) + return true; + } + return false; } // Unwraps quantization parameters of PartitionedCall ops with quantized @@ -554,20 +590,17 @@ class QuantizeFunctionPattern explicit QuantizeFunctionPattern(MLIRContext* context, const QuantMethod quantization_method, const OpSet target_opset, - const bool enable_per_channel_quantization, - const bool enable_legacy_weight_only) + const bool enable_per_channel_quantization) : OpRewritePattern(context), quantization_method_(quantization_method), target_opset_(target_opset), - enable_per_channel_quantization_(enable_per_channel_quantization), - enable_legacy_weight_only_(enable_legacy_weight_only) {} + enable_per_channel_quantization_(enable_per_channel_quantization) {} private: QuantMethod quantization_method_ = tensorflow::quantization::QuantizationMethod::STATIC_RANGE; OpSet target_opset_ = OpSet::TF; bool enable_per_channel_quantization_; - bool enable_legacy_weight_only_; LogicalResult matchAndRewrite(TF::PartitionedCallOp call_op, PatternRewriter& rewriter) const override { @@ -579,24 +612,20 @@ class QuantizeFunctionPattern if (!f_attr.getValue().startswith(kCompositeFuncPrefix)) { return failure(); } - // Determines if all required float input/outputs are now quantized. - bool has_quantized_types = true; - switch (quantization_method_) { - case tensorflow::quantization::QuantizationMethod::DYNAMIC_RANGE: - has_quantized_types &= IsQuantizedCallforDynamicRange(call_op); - break; - case tensorflow::quantization::QuantizationMethod::STATIC_RANGE: - has_quantized_types &= IsQuantizedCallforStaticRange(call_op); - break; - case tensorflow::quantization::QuantizationMethod::WEIGHT_ONLY: - // Skipping input type check for weight-only quantization as it can be - // dequantized beforehand for the legacy scheme. - has_quantized_types &= !enable_legacy_weight_only_; - break; - default: - call_op->emitError("The quantization method is not supported."); - return failure(); + + bool has_quantized_types = false; + if (quantization_method_ == + tensorflow::quantization::QuantizationMethod::WEIGHT_ONLY) { + // Skipping input type check for weight-only quantization as it can be + // dequantized beforehand for the legacy scheme. + has_quantized_types = true; + } else { + // Determines if all required float input/outputs are now quantized. + // Either one of the criteria needs to meet. + has_quantized_types |= IsQuantizedCallforDynamicRange(call_op); + has_quantized_types |= IsQuantizedCallforStaticRange(call_op); } + if (!has_quantized_types) return failure(); SmallVector args; @@ -703,7 +732,6 @@ class QuantizeFunctionPattern result_types.push_back(result_type); continue; } - if (target_opset_ == OpSet::UNIFORM_QUANTIZED) { ShapedType new_result_type = ConvertIntToQint( result_type.cast(), rewriter.getContext()); @@ -730,8 +758,14 @@ class QuantizeFunctionPattern dyn_cast(symbol_table.lookup(f_attr.getValue())); rewriter.setInsertionPointAfter(float_func); + // Applies only for hybrid ops in SRQ. 
+ const bool is_hybrid = + !ContainsQuantizedReusltType(result_types) && + (quantization_method_ == + tensorflow::quantization::QuantizationMethod::STATIC_RANGE); const std::string quantized_function_name = - GetQuantizedFunctionName(f_attr.getValue()); + GetQuantizedFunctionName(f_attr.getValue(), is_hybrid); + const mlir::func::FuncOp quantized_func = dyn_cast(symbol_table.lookup(quantized_function_name)); mlir::func::FuncOp new_quantized_func = @@ -816,7 +850,7 @@ class QuantizeFunctionPattern // the length of the "_fn" suffix. const size_t fn_suffix_length = 3; std::string quantized_function_name = - GetQuantizedFunctionName(f_attr.getValue()); + GetQuantizedFunctionName(f_attr.getValue(), /*is_hybrid=*/false); quantized_function_name.replace( quantized_function_name.size() - fn_suffix_length, fn_suffix_length, kFloatOutputFuncPrefix); @@ -905,7 +939,8 @@ class QuantizeConstPattern // TODO(b/225793355): It adds TensorProtoAttr to the constant as a // workaround. tensorflow::TensorProto tensor_proto; - if (!mlir::tfg::ConvertToTensorProto(tensor_proto_attr, &tensor_proto) + if (!mlir::tfg::ConvertToTensorProto( + tensor_proto_attr.cast(), &tensor_proto) .ok()) { return failure(); } @@ -1047,7 +1082,8 @@ class QuantizationSummary { // Get the representative name attribute value of a composite function. FailureOr GetRepresentativeName(StringRef func_name) { - std::string quantized_func_name = GetQuantizedFunctionName(func_name); + std::string quantized_func_name = + GetQuantizedFunctionName(func_name, /*is_hybrid=*/false); auto quantized_func = dyn_cast_or_null( symbol_table_.lookup(quantized_func_name)); // Quantized function does not exist for weight-only case. @@ -1125,13 +1161,16 @@ void QuantizeCompositeFunctionsPass::runOnOperation() { signalPassFailure(); } - RewritePatternSet patterns(ctx); - patterns.add( - ctx, quantization_method_, target_opset_, - enable_per_channel_quantization_, enable_legacy_weight_only_); + // Legacy weight-only does not require quantized ops. 
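A quick naming example for the is_hybrid plumbing above, assuming the usual "composite_" / "quantized_" values behind kCompositeFuncPrefix and kQuantizedFuncPrefix (the prefix constants themselves are not shown in this diff):

// GetQuantizedFunctionName("composite_matmul_with_bias_fn_1",
//                          /*is_hybrid=*/false)
//   -> "quantized_matmul_with_bias_fn"
// GetQuantizedFunctionName("composite_matmul_with_bias_fn_1",
//                          /*is_hybrid=*/true)
//   -> "quantized_matmul_with_bias_float_output_fn"
//
// is_hybrid is true only under static-range quantization when none of the
// call's result types is quantized, which routes the call to the
// "*_float_output_fn" entries of the quantized function library.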
+ if (!enable_legacy_weight_only_) { + RewritePatternSet patterns(ctx); + patterns.add(ctx, quantization_method_, + target_opset_, + enable_per_channel_quantization_); - if (failed(applyPatternsAndFoldGreedily(module, std::move(patterns)))) { - signalPassFailure(); + if (failed(applyPatternsAndFoldGreedily(module, std::move(patterns)))) { + signalPassFailure(); + } } // Constant quantization is a lossy transformation, so they are applied only diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/quantized_function_library.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/passes/quantized_function_library.mlir index 07883e99afd..7ccff9d7091 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/quantized_function_library.mlir +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/quantized_function_library.mlir @@ -393,6 +393,36 @@ module { } } // end for + // TODO(b/278493977): Create generic implementation of lifting any fused op + // with any reshaping op + for main_op in ["MatMul"] { + parameters[ + {"quantized_ops": ["${main_op}", "Reshape", "BiasAdd"], "act_func": "internal_requantize_no_activation_fn", "output_type": "i8"}, + {"quantized_ops": ["${main_op}", "Reshape", "BiasAdd"], "act_func": "internal_dequantize_no_activation_fn", "output_type": "f32"}, + ] + func.func @GenerateQuantizedFunctionName(${quantized_ops}, "${output_type}")(%input : tensor<*xi8>, + %filter : tensor<*xi8>, %bias : tensor<*xi32>, %shape : tensor<*xi32>, + %input_scale : tensor<*xf32>, %input_zp : tensor<*xi32>, + %filter_scale : tensor<*xf32>, %filter_zp : tensor<*xi32>, + %bias_scale : tensor<*xf32>, %bias_zp : tensor<*xi32>, + %out_scale : tensor<*xf32>, %out_zp : tensor<*xi32>) -> tensor<*x${output_type}> + attributes {tf_quant.quantized_ops = ${quantized_ops}} { + %0 = "tf.PartitionedCall"(%input, %filter, %input_scale, %input_zp, + %filter_scale, %filter_zp) { + config = "", config_proto = "", executor_type = "", f=@GenerateImplFunctionName(${main_op}) + } : (tensor<*xi8>, tensor<*xi8>, tensor<*xf32>, tensor<*xi32>, + tensor<*xf32>, tensor<*xi32>) -> tensor<*xi32> + %1 = "tf.Reshape"(%0, %shape) : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi32> + %2 = "tf.AddV2"(%1, %bias) : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi32> + %3 = "tf.PartitionedCall"(%2, %input_scale, %input_zp, %filter_scale, %filter_zp, + %out_scale, %out_zp) { + config = "", config_proto = "", executor_type = "", f=@${act_func} + } : (tensor<*xi32>, tensor<*xf32>, tensor<*xi32>, tensor<*xf32>, tensor<*xi32>, + tensor<*xf32>, tensor<*xi32>) -> tensor<*x${output_type}> + func.return %3 : tensor<*x${output_type}> + } + } // end for + func.func @quantize_i8(%input : tensor<*xf32>, %scale : tensor<*xf32>, %zp : tensor<*xi32>) -> tensor<*xi8> { %float_zp = "tf.Cast"(%zp) {Truncate = false} : (tensor<*xi32>) -> tensor<*xf32> %div = "tf.Div"(%input, %scale) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> @@ -421,20 +451,49 @@ module { // Weight-only functions. //===----------------------------------------------------------------------===// + func.func private @internal_dequantize_i8_in_f32_fn( + %input : tensor<*xi8>, %weight_scale : tensor<*xf32>) -> tensor<*xf32> { + %input_f32 = "tf.Cast"(%input) : (tensor<*xi8>) -> tensor<*xf32> + %mul = "tf.Mul"(%input_f32, %weight_scale) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> + func.return %mul : tensor<*xf32> + } + // Note that input i64 type is also supported by this. 
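The new internal_dequantize_i8_in_f32_fn helper above is simply the symmetric weight-only dequantization (cast, then scale). A scalar C++ sketch of the same computation, for reference only:

#include <cstdint>

// Symmetric weight dequantization: the zero point is implicitly 0, so the
// original float is recovered (approximately) as weight_scale * quantized.
float DequantizeWeight(std::int8_t quantized, float weight_scale) {
  return static_cast<float>(quantized) * weight_scale;
}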
+ // As the output is quantized type, output scale/zp is required for the arguments. parameters[ - {"quantized_ops": ["Gather"], "output_type": "i8"} + {"quantized_ops": ["Gather"], "act_func": "internal_identity_fn", "output_type": "i8"} ] - func.func @GenerateQuantizedFunctionName(${quantized_ops})( + func.func @GenerateQuantizedFunctionName(${quantized_ops}, "${output_type}")( %weight : tensor<*xi8>, %input : tensor<*xi32>, %axis : tensor, %weight_scale : tensor<*xf32>, %weight_zp : tensor<*xi32>, %out_scale : tensor<*xf32>, %out_zp : tensor<*xi32>) -> tensor<*x${output_type}> - attributes {tf_quant.quantized_ops = ${quantized_ops}} - { + attributes {tf_quant.quantized_ops = ${quantized_ops}} { + + %out = "tf.GatherV2"(%weight, %input, %axis) { + batch_dims = 0 : i64, attr_map = "batch_dims:0"} : (tensor<*xi8>, tensor<*xi32>, tensor) -> tensor<*xi8> + + func.return %out : tensor<*x${output_type}> + } + + // Note that input i64 type is also supported by this. + // The dequantization is merged to the quantized function. + // As the output type is specified to f32, the quantized function has "_float_output_fn" tag at the end. + parameters[ + {"quantized_ops": ["Gather"], "act_func": "internal_dequantize_i8_in_f32_fn", "output_type": "f32"} + ] + func.func @GenerateQuantizedFunctionName(${quantized_ops}, "${output_type}")( + %weight : tensor<*xi8>, %input : tensor<*xi32>, %axis : tensor, + %weight_scale : tensor<*xf32>, %weight_zp : tensor<*xi32>) -> tensor<*x${output_type}> + attributes {tf_quant.quantized_ops = ${quantized_ops}} { + %accum_out = "tf.GatherV2"(%weight, %input, %axis) { batch_dims = 0 : i64, attr_map = "batch_dims:0"} : (tensor<*xi8>, tensor<*xi32>, tensor) -> tensor<*xi8> - func.return %accum_out : tensor<*x${output_type}> + %out = "tf.PartitionedCall"(%accum_out, %weight_scale) { + config = "", config_proto = "", executor_type = "", f=@${act_func} + } : (tensor<*xi8>, tensor<*xf32>) -> tensor<*x${output_type}> + + func.return %out : tensor<*x${output_type}> } } diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/quantized_function_library_uniform_quantized.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/passes/quantized_function_library_uniform_quantized.mlir index 2225e588e39..0d95b8eda87 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/quantized_function_library_uniform_quantized.mlir +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/quantized_function_library_uniform_quantized.mlir @@ -51,7 +51,13 @@ module { %filter_scale, %filter_zp, %out_scale, %out_zp) { config = "", config_proto = "", executor_type = "", f=@GenerateImplFunctionName(${main_op}) } : (tensor<*x!tf_type.qint8>, tensor<*x!tf_type.qint8>, tensor<*xf32>, tensor<*xi32>, tensor<*xf32>, tensor<*xi32>, tensor<*xf32>, tensor<*xi32>) -> tensor<*x!tf_type.qint32> - %add = "tf.UniformQuantizedAdd"(%main_out, %bias, %input_scale, %input_zp, %bias_scale, %bias_zp, %out_scale, %out_zp) { + // Extract channel shape from filter, and ensure input/output scale/zp's have the same channel size. 
+ %filter_shape = "tf.Shape" (%filter_scale) : (tensor<*xf32>) -> tensor<*xi32> + %input_scale_filled = "tf.Fill" (%filter_shape, %input_scale) : (tensor<*xi32>, tensor<*xf32>) -> tensor<*xf32> + %input_zp_filled = "tf.Fill" (%filter_shape, %input_zp) : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi32> + %out_scale_filled = "tf.Fill" (%filter_shape, %out_scale) : (tensor<*xi32>, tensor<*xf32>) -> tensor<*xf32> + %out_zp_filled = "tf.Fill" (%filter_shape, %out_zp) : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi32> + %add = "tf.UniformQuantizedAdd"(%main_out, %bias, %input_scale_filled, %input_zp_filled, %bias_scale, %bias_zp, %out_scale_filled, %out_zp_filled) { lhs_quantization_axis = -1, lhs_quantization_min_val = -128, lhs_quantization_max_val = 127, @@ -64,9 +70,9 @@ module { T = "tfdtype$DT_QINT32", attr_map = "" } : (tensor<*x!tf_type.qint32>, tensor<*x!tf_type.qint32>, tensor<*xf32>, tensor<*xi32>, tensor<*xf32>, tensor<*xi32>, tensor<*xf32>, tensor<*xi32>) -> tensor<*x!tf_type.qint32> - %act = "tf.PartitionedCall"(%add, %input_scale, %input_zp, %out_scale, %out_zp) { + %act = "tf.PartitionedCall"(%add, %input_scale_filled, %input_zp_filled, %out_scale_filled, %out_zp_filled, %out_scale, %out_zp) { config = "", config_proto = "", executor_type = "", f=@${act_func} - } : (tensor<*x!tf_type.qint32>, tensor<*xf32>, tensor<*xi32>, tensor<*xf32>, tensor<*xi32>) -> tensor<*x${output_type}> + } : (tensor<*x!tf_type.qint32>, tensor<*xf32>, tensor<*xi32>, tensor<*xf32>, tensor<*xi32>, tensor<*xf32>, tensor<*xi32>) -> tensor<*x${output_type}> func.return %act : tensor<*x${output_type}> } @@ -85,9 +91,14 @@ module { %filter_scale, %filter_zp, %out_scale, %out_zp) { config = "", config_proto = "", executor_type = "", f=@GenerateImplFunctionName(${main_op}) } : (tensor<*x!tf_type.qint8>, tensor<*x!tf_type.qint8>, tensor<*xf32>, tensor<*xi32>, tensor<*xf32>, tensor<*xi32>, tensor<*xf32>, tensor<*xi32>) -> tensor<*x!tf_type.qint32> - %act = "tf.PartitionedCall"(%main_out, %input_scale, %input_zp, %out_scale, %out_zp) { + %filter_shape = "tf.Shape" (%filter_scale) : (tensor<*xf32>) -> tensor<*xi32> + %input_scale_filled = "tf.Fill" (%filter_shape, %input_scale) : (tensor<*xi32>, tensor<*xf32>) -> tensor<*xf32> + %input_zp_filled = "tf.Fill" (%filter_shape, %input_zp) : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi32> + %out_scale_filled = "tf.Fill" (%filter_shape, %out_scale) : (tensor<*xi32>, tensor<*xf32>) -> tensor<*xf32> + %out_zp_filled = "tf.Fill" (%filter_shape, %out_zp) : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi32> + %act = "tf.PartitionedCall"(%main_out, %input_scale_filled, %input_zp_filled, %out_scale_filled, %out_zp_filled, %out_scale, %out_zp) { config = "", config_proto = "", executor_type = "", f=@${act_func} - } : (tensor<*x!tf_type.qint32>, tensor<*xf32>, tensor<*xi32>, tensor<*xf32>, tensor<*xi32>) -> tensor<*x${output_type}> + } : (tensor<*x!tf_type.qint32>, tensor<*xf32>, tensor<*xi32>, tensor<*xf32>, tensor<*xi32>, tensor<*xf32>, tensor<*xi32>) -> tensor<*x${output_type}> func.return %act : tensor<*x${output_type}> } } // end for @@ -198,7 +209,7 @@ module { // Requantizes and applies quantized Relu by clipping. 
func.func private @internal_requantize_no_activation_fn(%input : tensor<*x!tf_type.qint32>, %input_scale : tensor<*xf32>, %input_zp : tensor<*xi32>, - %out_scale : tensor<*xf32>, %out_zp : tensor<*xi32>) -> tensor<*x!tf_type.qint8> { + %out_scale : tensor<*xf32>, %out_zp : tensor<*xi32>, %out_scale_single : tensor<*xf32>, %out_zp_single : tensor<*xi32>) -> tensor<*x!tf_type.qint8> { %q_out = "tf.PartitionedCall"(%input, %input_scale, %input_zp, %out_scale, %out_zp) { config = "", config_proto = "", executor_type = "", f=@internal_requantize_qi8_fn } : (tensor<*x!tf_type.qint32>, tensor<*xf32>, tensor<*xi32>, tensor<*xf32>, tensor<*xi32>) -> tensor<*x!tf_type.qint8> @@ -207,20 +218,23 @@ module { // Requantizes and applies quantized Relu6 by clipping. func.func private @internal_requantize_and_relu_fn(%input : tensor<*x!tf_type.qint32>, %input_scale : tensor<*xf32>, %input_zp : tensor<*xi32>, - %out_scale : tensor<*xf32>, %out_zp : tensor<*xi32>) -> tensor<*x!tf_type.qint8> { + %out_scale : tensor<*xf32>, %out_zp : tensor<*xi32>, %out_scale_single : tensor<*xf32>, %out_zp_single : tensor<*xi32>) -> tensor<*x!tf_type.qint8> { + %filter_shape = "tf.Shape" (%input_scale) : (tensor<*xf32>) -> tensor<*xi32> %i8_min = "tf.Const"() {value = dense<-128.0> : tensor} : () -> tensor %i8_max = "tf.Const"() {value = dense<127.0> : tensor} : () -> tensor + %i8_min_filled = "tf.Fill" (%filter_shape, %i8_min) : (tensor<*xi32>, tensor) -> tensor<*xf32> + %i8_max_filled = "tf.Fill" (%filter_shape, %i8_max) : (tensor<*xi32>, tensor) -> tensor<*xf32> %float_out_zp = "tf.Cast"(%out_zp) {Truncate = false} : (tensor<*xi32>) -> tensor<*xf32> - %clip_min = "tf.Maximum"(%i8_min, %float_out_zp) : (tensor, tensor<*xf32>) -> tensor - %qclip_min = "tf.Cast"(%i8_min) {Truncate = false} : (tensor) -> tensor - %qi8_max = "tf.Cast"(%i8_max) {Truncate = false} : (tensor) -> tensor - %relu = "tf.UniformQuantizedClipByValue"(%input, %qclip_min, %qi8_max, %out_scale, %out_zp) { + %clip_min = "tf.Maximum"(%i8_min_filled, %float_out_zp) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> + %qclip_min = "tf.Cast"(%clip_min) {Truncate = false} : (tensor<*xf32>) -> tensor<*x!tf_type.qint32> + %qclip_max = "tf.Cast"(%i8_max_filled) {Truncate = false} : (tensor<*xf32>) -> tensor<*x!tf_type.qint32> + %relu = "tf.UniformQuantizedClipByValue"(%input, %qclip_min, %qclip_max, %out_scale, %out_zp) { T = "tfdtype$DT_QINT32", quantization_axis = -1, quantization_min_val = -128, quantization_max_val = 127, attr_map = "" - } : (tensor<*x!tf_type.qint32>, tensor, tensor, tensor<*xf32>, tensor<*xi32>) -> tensor<*x!tf_type.qint32> + } : (tensor<*x!tf_type.qint32>, tensor<*x!tf_type.qint32>, tensor<*x!tf_type.qint32>, tensor<*xf32>, tensor<*xi32>) -> tensor<*x!tf_type.qint32> %requantize = "tf.PartitionedCall"(%relu, %input_scale, %input_zp, %out_scale, %out_zp) { config = "", config_proto = "", executor_type = "", f=@internal_requantize_qi8_fn } : (tensor<*x!tf_type.qint32>, tensor<*xf32>, tensor<*xi32>, tensor<*xf32>, tensor<*xi32>) -> tensor<*x!tf_type.qint8> @@ -229,30 +243,34 @@ module { // Apply requantization and relu6. 
func.func private @internal_requantize_and_relu6_fn(%input : tensor<*x!tf_type.qint32>, %input_scale : tensor<*xf32>, %input_zp : tensor<*xi32>, - %out_scale : tensor<*xf32>, %out_zp : tensor<*xi32>) -> tensor<*x!tf_type.qint8> { + %out_scale : tensor<*xf32>, %out_zp : tensor<*xi32>, %out_scale_single : tensor<*xf32>, %out_zp_single : tensor<*xi32>) -> tensor<*x!tf_type.qint8> { + %filter_shape = "tf.Shape" (%input_scale) : (tensor<*xf32>) -> tensor<*xi32> %i8_min = "tf.Const"() {value = dense<-128.0> : tensor} : () -> tensor %i8_max = "tf.Const"() {value = dense<127.0> : tensor} : () -> tensor %act_max = "tf.Const"() {value = dense<6.0> : tensor} : () -> tensor - %i8_act_max_0 = "tf.PartitionedCall"(%act_max, %input_scale, %input_zp) { + // Singular scale/zp is needed to ensure quantization is per-tensor for this variable. + %i8_act_max_0 = "tf.PartitionedCall"(%act_max, %out_scale_single, %out_zp_single) { config = "", config_proto = "", executor_type = "", f=@quantize_i8 } : (tensor, tensor<*xf32>, tensor<*xi32>) -> tensor<*x!tf_type.qint8> %i8_act_max_1 = "tf.Cast"(%i8_act_max_0) {Truncate = false} : (tensor<*x!tf_type.qint8>) -> tensor %float_out_zp = "tf.Cast"(%out_zp) {Truncate = false} : (tensor<*xi32>) -> tensor<*xf32> - %clip_min = "tf.Maximum"(%i8_min, %float_out_zp) : (tensor, tensor<*xf32>) -> tensor - %clip_max = "tf.Minimum"(%i8_max, %i8_act_max_1) : (tensor, tensor) -> tensor - %qclip_min = "tf.Cast"(%i8_min) {Truncate = false} : (tensor) -> tensor - %qclip_max = "tf.Cast"(%i8_max) {Truncate = false} : (tensor) -> tensor + %i8_min_filled = "tf.Fill" (%filter_shape, %i8_min) : (tensor<*xi32>, tensor) -> tensor<*xf32> + %i8_max_filled = "tf.Fill" (%filter_shape, %i8_max) : (tensor<*xi32>, tensor) -> tensor<*xf32> + %i8_act_max_1_filled = "tf.Fill" (%filter_shape, %i8_act_max_1) : (tensor<*xi32>, tensor) -> tensor<*xf32> + %clip_min = "tf.Maximum"(%i8_min_filled, %float_out_zp) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> + %clip_max = "tf.Minimum"(%i8_max_filled, %i8_act_max_1_filled) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> + %qclip_min = "tf.Cast"(%clip_min) {Truncate = false} : (tensor<*xf32>) -> tensor<*x!tf_type.qint32> + %qclip_max = "tf.Cast"(%clip_max) {Truncate = false} : (tensor<*xf32>) -> tensor<*x!tf_type.qint32> %relu = "tf.UniformQuantizedClipByValue"(%input, %qclip_min, %qclip_max, %out_scale, %out_zp) { T = "tfdtype$DT_QINT32", quantization_axis = -1, quantization_min_val = -128, quantization_max_val = 127, attr_map = "" - } : (tensor<*x!tf_type.qint32>, tensor, tensor, tensor<*xf32>, tensor<*xi32>) -> tensor<*x!tf_type.qint32> + } : (tensor<*x!tf_type.qint32>, tensor<*x!tf_type.qint32>, tensor<*x!tf_type.qint32>, tensor<*xf32>, tensor<*xi32>) -> tensor<*x!tf_type.qint32> %requantize = "tf.PartitionedCall"(%relu, %input_scale, %input_zp, %out_scale, %out_zp) { config = "", config_proto = "", executor_type = "", f=@internal_requantize_qi8_fn } : (tensor<*x!tf_type.qint32>, tensor<*xf32>, tensor<*xi32>, tensor<*xf32>, tensor<*xi32>) -> tensor<*x!tf_type.qint8> func.return %requantize : tensor<*x!tf_type.qint8> } } - diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/utils.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/utils.cc index c65f7ac7906..1491ccc049f 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/utils.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/utils.cc @@ -25,7 +25,7 @@ namespace mlir { namespace quant { bool HasQuantizedTensors(Operation* op) { - if 
(IsOpNotQuantizable(op)) return false; + if (!IsOpQuantizable(op)) return false; for (Type operand_type : op->getOperandTypes()) { auto tensor_type = operand_type.dyn_cast(); if (tensor_type && tensor_type.getElementType().isa()) { diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/BUILD b/tensorflow/compiler/mlir/quantization/tensorflow/python/BUILD index 9372a6ca393..3b5a9d55f5f 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/python/BUILD +++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/BUILD @@ -85,10 +85,12 @@ cc_library( hdrs = ["quantize_model.h"], compatible_with = get_compatible_with_cloud(), deps = if_static([":quantize_model_cc_impl"]) + [ + "//tensorflow/compiler/mlir/quantization/tensorflow:exported_model_proto_cc", + "//tensorflow/compiler/mlir/quantization/tensorflow:quantization_options_proto_cc", + "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", - "//tensorflow/compiler/mlir/quantization/tensorflow:exported_model_proto_cc", - "//tensorflow/core:protos_all_cc", ], ) @@ -124,7 +126,7 @@ tf_py_test( deps = [ ":pywrap_quantize_model", "//tensorflow:tensorflow_py", - "//tensorflow/python/platform", + "//tensorflow/python/platform:client_testlib", ], ) @@ -147,7 +149,6 @@ pytype_strict_library( "//tensorflow/python/saved_model:tag_constants", "//tensorflow/python/training:saver", "//tensorflow/python/training:training_lib", - "//tensorflow/python/types", "@absl_py//absl/logging", ], ) @@ -171,14 +172,15 @@ pytype_strict_library( "//tensorflow/python/client:session", "//tensorflow/python/eager:context", "//tensorflow/python/eager:wrap_function", + "//tensorflow/python/framework:tensor_conversion", "//tensorflow/python/lib/io:lib", - "//tensorflow/python/platform", + "//tensorflow/python/platform:tf_logging", "//tensorflow/python/saved_model:load", "//tensorflow/python/saved_model:loader", "//tensorflow/python/saved_model:signature_constants", "//tensorflow/python/saved_model:tag_constants", "//tensorflow/python/trackable:autotrackable", - "//tensorflow/python/types", + "//tensorflow/python/types:core", "//third_party/py/numpy", "@absl_py//absl/logging", ], @@ -198,6 +200,7 @@ tf_py_test( "//tensorflow/core:protos_all_py", "//tensorflow/python:client_testlib", "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/platform:tf_logging", "//tensorflow/python/saved_model:tag_constants", "//third_party/py/numpy", "@absl_py//absl/testing:parameterized", @@ -236,7 +239,7 @@ pytype_library( "//tensorflow/python/saved_model:signature_def_utils", "//tensorflow/python/trackable:asset", "//tensorflow/python/trackable:autotrackable", - "//tensorflow/python/types", + "//tensorflow/python/types:core", "//third_party/py/numpy", "@absl_py//absl/testing:parameterized", ], @@ -264,8 +267,8 @@ pytype_strict_library( visibility = ["//visibility:public"], deps = [ "//tensorflow/python/client:session", - "//tensorflow/python/platform", - "//tensorflow/python/types", + "//tensorflow/python/platform:tf_logging", + "//tensorflow/python/types:core", ], ) @@ -279,7 +282,7 @@ tf_py_test( "//tensorflow/python/framework:ops", "//tensorflow/python/framework:test_lib", "//tensorflow/python/platform:client_testlib", - "//tensorflow/python/types", + "//tensorflow/python/types:core", "//third_party/py/numpy", ], ) diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/integration_test/quantize_model_test.py 
b/tensorflow/compiler/mlir/quantization/tensorflow/python/integration_test/quantize_model_test.py index 6f8df273bd3..892dfde7c9a 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/python/integration_test/quantize_model_test.py +++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/integration_test/quantize_model_test.py @@ -16,7 +16,7 @@ # TODO(b/264234648): Refactor and cleanup this file. import itertools import os -from typing import List, Mapping, Optional, Sequence, Tuple, Union +from typing import Mapping, Optional, Sequence, Tuple, Union from absl.testing import parameterized import numpy as np @@ -135,6 +135,14 @@ class MultipleSignatureModel(module.Module): Used to test where the quantizer has to handle multiple signatures. """ + def __init__(self): + self.matmul_filters = random_ops.random_uniform( + shape=(4, 3), minval=-1.0, maxval=1.0 + ) + self.conv_filters = np.random.uniform( + low=-10, high=10, size=(2, 3, 3, 2) + ).astype('f4') + @def_function.function( input_signature=[ tensor_spec.TensorSpec(shape=[1, 4], dtype=dtypes.float32) @@ -149,8 +157,7 @@ class MultipleSignatureModel(module.Module): Returns: A map of: output key -> output result. """ - filters = random_ops.random_uniform(shape=(4, 3), minval=-1.0, maxval=1.0) - out = math_ops.matmul(matmul_input, filters) + out = math_ops.matmul(matmul_input, self.matmul_filters) return {'output': out} @@ -168,12 +175,9 @@ class MultipleSignatureModel(module.Module): Returns: A map of: output key -> output result. """ - filters = np.random.uniform(low=-10, high=10, size=(2, 3, 3, 2)).astype( - 'f4' - ) out = nn_ops.conv2d( conv_input, - filters, + self.conv_filters, strides=[1, 1, 2, 1], dilations=[1, 1, 1, 1], padding='SAME', @@ -183,6 +187,8 @@ class MultipleSignatureModel(module.Module): return {'output': out} +# TODO(b/280208261): Add unit tests for comparing unquantized and +# quantized results @test_util.run_all_in_graph_and_eager_modes class QuantizationOptionsTest(quantize_model_test_base.QuantizedModelTest): """Test cases regarding the use of QuantizationOptions proto. @@ -192,6 +198,10 @@ class QuantizationOptionsTest(quantize_model_test_base.QuantizedModelTest): """ class SimpleModel(module.Module): + def __init__(self): + self.filters = np.random.uniform(low=-1.0, high=1.0, size=(4, 3)).astype( + 'f4' + ) @def_function.function( input_signature=[ @@ -207,9 +217,8 @@ class QuantizationOptionsTest(quantize_model_test_base.QuantizedModelTest): Returns: A map of: output key -> output result. 
""" - filters = np.random.uniform(low=-1.0, high=1.0, size=(4, 3)).astype('f4') - out = math_ops.matmul(input_tensor, filters) + out = math_ops.matmul(input_tensor, self.filters) return {'output': out} def _simple_model_data_gen(self) -> repr_dataset.RepresentativeDataset: @@ -352,6 +361,56 @@ class QuantizationOptionsTest(quantize_model_test_base.QuantizedModelTest): threshold=0.3, ) + @test_util.run_in_graph_and_eager_modes + def test_force_graph_mode_calibration(self): + input_type = dtypes.int32 + input_placeholder = self._create_and_save_tf1_gather_model( + self._input_saved_model_path, + signature_key=signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY, + tags={tag_constants.SERVING}, + input_key='x', + output_key='output', + input_type=input_type, + ) + + data_gen = self._create_data_generator( + input_key='x', + shape=input_placeholder.shape, + minval=0, + maxval=10, + dtype=input_type, + ) + + options = quant_opts_pb2.QuantizationOptions( + quantization_method=quant_opts_pb2.QuantizationMethod( + experimental_method=_ExperimentalMethod.STATIC_RANGE + ), + force_graph_mode_calibration=True, + ) + + with self.assertLogs(level='INFO') as info_logs: + # Save the logger verbosity. + prev_log_level = logging.get_verbosity() + logging.set_verbosity(logging.INFO) + + try: + quantize_model.quantize( + self._input_saved_model_path, + quantization_options=options, + representative_dataset=data_gen, + ) + finally: + # Restore the logger verbosity. + logging.set_verbosity(prev_log_level) + + self.assertNotEmpty(info_logs.records) + self.assertTrue( + self._any_log_contains( + 'Calibration step is executed in graph mode.', + info_logs.records, + ) + ) + class TensorNamePreservationTest(quantize_model_test_base.QuantizedModelTest): @@ -495,24 +554,6 @@ class TensorNamePreservationTest(quantize_model_test_base.QuantizedModelTest): class StaticRangeQuantizationTest(quantize_model_test_base.QuantizedModelTest): - def _any_warning_contains( - self, substring: str, warnings_list: List['LogRecord'] - ) -> bool: - """Returns True if any of the warnings contains a given substring. - - Args: - substring: A piece of string to check whether it exists in the warning - message. - warnings_list: A list of `absl.logging.LogRecord`s. - - Returns: - True if and only if the substring exists in any of the warnings in - `warnings_list`. 
- """ - return any( - map(lambda warning: substring in str(warning.message), warnings_list) - ) - @parameterized.parameters( parameter_combinations([{ 'shapes': [ @@ -775,7 +816,6 @@ class StaticRangeQuantizationTest(quantize_model_test_base.QuantizedModelTest): ) return {'output': q_out} - np.random.seed(1234) model = ConvModel() saved_model_save.save(model, self._input_saved_model_path) @@ -1112,7 +1152,9 @@ class StaticRangeQuantizationTest(quantize_model_test_base.QuantizedModelTest): ): input_type = dtypes.int32 model = self._create_simple_gather_and_conv_model( - input_type, filter_shape=(2, 3, 3, 1024), is_qat_model=True + input_type, + filter_shape=(2, 3, 3, 1024), + is_qat_model=True, ) saved_model_save.save(model, self._input_saved_model_path) @@ -1126,25 +1168,16 @@ class StaticRangeQuantizationTest(quantize_model_test_base.QuantizedModelTest): op_set=quant_opts_pb2.XLA, ) - data_gen = self._create_data_generator( - input_key='input_tensor', - shape=(6), - minval=0, - maxval=10, - dtype=input_type, - ) - converted_model = quantize_model.quantize( self._input_saved_model_path, ['serving_default'], tags, self._output_saved_model_path, quantization_options, - representative_dataset=data_gen, ) self.assertIsNotNone(converted_model) self.assertSizeRatioLessThan( - self._output_saved_model_path, self._input_saved_model_path, 1 / 3 + self._output_saved_model_path, self._input_saved_model_path, 0.5 ) # TODO(b/244276332): Allow table initialization in TF2 eager mode. @@ -1220,6 +1253,77 @@ class StaticRangeQuantizationTest(quantize_model_test_base.QuantizedModelTest): self.assertAllClose(lookup_val, [1.0, 2.0, 0.0]) + # TODO(b/244276332): Allow table initialization in TF2 eager mode. + @test_util.deprecated_graph_mode_only + def test_qat_file_init_hash_table_lookup_model_tf1(self): + tags = {tag_constants.SERVING} + signature_def_key = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY + + # Create and save a simple model that involves a hash table. + inputs, outputs = self._create_and_save_file_init_hash_table_qat_model_tf1( + self._input_saved_model_path, tags, signature_def_key + ) + + # Make sure that the desired input key and output key is present. + self.assertIn('input_vocabs', inputs.keys()) + self.assertIn('lookup', outputs.keys()) + + # Representative dataset is composed of a set of vocabs for table lookup. + repr_ds = [ + {'input_vocabs': np.array([b'static', b'range', b'quantization'])} + for _ in range(4) + ] + + quantization_options = quant_opts_pb2.QuantizationOptions( + quantization_method=quant_opts_pb2.QuantizationMethod( + experimental_method=_ExperimentalMethod.STATIC_RANGE + ) + ) + signature_def_keys = [signature_def_key] + + quantize_model.quantize( + self._input_saved_model_path, + signature_def_keys, + tags, + self._output_saved_model_path, + quantization_options, + representative_dataset=repr_ds, + ) + + # Tests table lookup to make sure the table has been initialized + # successfully. + with session.Session(graph=ops.Graph()) as sess: + output_meta_graph_def = saved_model_loader.load( + sess, tags=tags, export_dir=self._output_saved_model_path + ) + + # The graph should contain a quantized function call (it contains a + # single f32 matmul node). 
+ self.assertTrue( + self._contains_quantized_function_call( + output_meta_graph_def.graph_def + ) + ) + self.assertCountEqual( + output_meta_graph_def.signature_def.keys(), signature_def_keys + ) + + signature_def = output_meta_graph_def.signature_def[signature_def_key] + input_tensor_name = signature_def.inputs['input_vocabs'].name + input_tensor = sess.graph.get_tensor_by_name(input_tensor_name) + lookup_tensor_name = signature_def.outputs['lookup'].name + lookup_tensor = sess.graph.get_tensor_by_name(lookup_tensor_name) + + lookup_val = sess.run( + lookup_tensor, + feed_dict={ + input_tensor: np.array([b'dynamic', b'quantization', b'range']) + }, + ) + + # "dynamic" is not in the table: -1 (default value) + self.assertAllClose(lookup_val, [-1.0, 2.0, 1.0]) + # Run this test only with the eager mode. @test_util.run_v2_only def test_ptq_model_with_variable(self): @@ -1309,140 +1413,143 @@ class StaticRangeQuantizationTest(quantize_model_test_base.QuantizedModelTest): output_graphdef = output_loader.get_meta_graph_def_from_tags(tags).graph_def self.assertTrue(self._contains_quantized_function_call(output_graphdef)) - # TODO(b/263830952): Use dictionaries instead of tuples for parameters. + # Check only the most simple case and the most complicated cases. @parameterized.named_parameters( - ('none', None, False, False, quant_opts_pb2.TF, False, False), - ('relu', nn_ops.relu, False, False, quant_opts_pb2.TF, False, False), - ('relu6', nn_ops.relu6, False, False, quant_opts_pb2.TF, False, False), - ('bn', None, False, True, quant_opts_pb2.TF, False, False), - ( - 'bn_and_relu', - nn_ops.relu, - False, - True, - quant_opts_pb2.TF, - False, - False, - ), - ('with_bias', None, True, False, quant_opts_pb2.TF, False, False), - ('with_bias_and_bn', None, True, True, quant_opts_pb2.TF, False, False), - ( - 'with_bias_and_bn_and_relu', - nn_ops.relu, - True, - True, - quant_opts_pb2.TF, - False, - False, - ), - ( - 'with_bias_and_relu', - nn_ops.relu, - True, - False, - quant_opts_pb2.TF, - False, - False, - ), - ( - 'with_bias_and_relu6', - nn_ops.relu6, - True, - False, - quant_opts_pb2.TF, - False, - False, - ), - ( - 'with_bias_and_bn_to_xla', - None, - True, - True, - quant_opts_pb2.XLA, - False, - False, - ), - ( - 'with_bias_and_relu6_to_xla', - nn_ops.relu6, - True, - False, - quant_opts_pb2.XLA, - False, - False, - ), - ( - 'with_bias_and_bn_to_xla_dynamic', - None, - True, - True, - quant_opts_pb2.XLA, - True, - False, - ), - ( - 'with_bias_and_relu6_to_xla_dynamic', - nn_ops.relu6, - True, - False, - quant_opts_pb2.XLA, - True, - False, - ), - ( - 'none_to_uq', - None, - False, - False, - quant_opts_pb2.UNIFORM_QUANTIZED, - False, - False, - ), - ( - 'none_to_uq_per_channel', - None, - False, - False, - quant_opts_pb2.UNIFORM_QUANTIZED, - False, - True, - ), - ( - 'relu_to_uq', - nn_ops.relu, - False, - False, - quant_opts_pb2.UNIFORM_QUANTIZED, - False, - False, - ), - ( - 'with_bias_to_uq', - None, - True, - False, - quant_opts_pb2.UNIFORM_QUANTIZED, - False, - False, - ), - ( - 'with_bias_and_relu_to_uq', - nn_ops.relu, - True, - False, - quant_opts_pb2.UNIFORM_QUANTIZED, - False, - False, - ), - ( - 'with_bias_and_relu6_to_uq', - nn_ops.relu6, - True, - False, - quant_opts_pb2.UNIFORM_QUANTIZED, - False, - False, - ), + { + 'testcase_name': 'none', + 'activation_fn': None, + 'has_bias': False, + 'has_batch_norm': False, + 'target_opset': quant_opts_pb2.TF, + 'input_shape_dynamic': False, + 'enable_per_channel_quantization': False, + }, + { + 'testcase_name': 'relu', + 'activation_fn': 
nn_ops.relu, + 'has_bias': False, + 'has_batch_norm': False, + 'target_opset': quant_opts_pb2.TF, + 'input_shape_dynamic': False, + 'enable_per_channel_quantization': False, + }, + { + 'testcase_name': 'relu6', + 'activation_fn': nn_ops.relu6, + 'has_bias': False, + 'has_batch_norm': False, + 'target_opset': quant_opts_pb2.TF, + 'input_shape_dynamic': False, + 'enable_per_channel_quantization': False, + }, + { + 'testcase_name': 'bn', + 'activation_fn': None, + 'has_bias': False, + 'has_batch_norm': True, + 'target_opset': quant_opts_pb2.TF, + 'input_shape_dynamic': False, + 'enable_per_channel_quantization': False, + }, + { + 'testcase_name': 'with_bias', + 'activation_fn': None, + 'has_bias': True, + 'has_batch_norm': False, + 'target_opset': quant_opts_pb2.TF, + 'input_shape_dynamic': False, + 'enable_per_channel_quantization': False, + }, + { + 'testcase_name': 'with_bias_and_relu6', + 'activation_fn': nn_ops.relu6, + 'has_bias': True, + 'has_batch_norm': False, + 'target_opset': quant_opts_pb2.TF, + 'input_shape_dynamic': False, + 'enable_per_channel_quantization': False, + }, + { + 'testcase_name': 'with_bias_and_bn_and_relu6', + 'activation_fn': nn_ops.relu6, + 'has_bias': True, + 'has_batch_norm': True, + 'target_opset': quant_opts_pb2.TF, + 'input_shape_dynamic': False, + 'enable_per_channel_quantization': False, + }, + { + 'testcase_name': 'with_bias_and_relu6_to_xla', + 'activation_fn': nn_ops.relu6, + 'has_bias': True, + 'has_batch_norm': False, + 'target_opset': quant_opts_pb2.XLA, + 'input_shape_dynamic': False, + 'enable_per_channel_quantization': False, + }, + { + 'testcase_name': 'with_bias_and_bn_and_relu6_to_xla', + 'activation_fn': nn_ops.relu6, + 'has_bias': True, + 'has_batch_norm': True, + 'target_opset': quant_opts_pb2.XLA, + 'input_shape_dynamic': False, + 'enable_per_channel_quantization': False, + }, + { + 'testcase_name': 'with_bias_and_relu6_to_xla_dynamic', + 'activation_fn': nn_ops.relu6, + 'has_bias': True, + 'has_batch_norm': False, + 'target_opset': quant_opts_pb2.XLA, + 'input_shape_dynamic': True, + 'enable_per_channel_quantization': False, + }, + { + 'testcase_name': 'with_bias_and_bn_and_relu6_to_xla_dynamic', + 'activation_fn': nn_ops.relu6, + 'has_bias': True, + 'has_batch_norm': True, + 'target_opset': quant_opts_pb2.XLA, + 'input_shape_dynamic': True, + 'enable_per_channel_quantization': False, + }, + { + 'testcase_name': 'with_bias_and_relu6_to_uq', + 'activation_fn': nn_ops.relu6, + 'has_bias': True, + 'has_batch_norm': False, + 'target_opset': quant_opts_pb2.UNIFORM_QUANTIZED, + 'input_shape_dynamic': False, + 'enable_per_channel_quantization': False, + }, + { + 'testcase_name': 'with_bias_and_bn_and_relu6_to_uq', + 'activation_fn': nn_ops.relu6, + 'has_bias': True, + 'has_batch_norm': True, + 'target_opset': quant_opts_pb2.UNIFORM_QUANTIZED, + 'input_shape_dynamic': False, + 'enable_per_channel_quantization': False, + }, + { + 'testcase_name': 'with_bias_and_relu6_to_uq_per_channel', + 'activation_fn': nn_ops.relu6, + 'has_bias': True, + 'has_batch_norm': False, + 'target_opset': quant_opts_pb2.UNIFORM_QUANTIZED, + 'input_shape_dynamic': False, + 'enable_per_channel_quantization': True, + }, + { + 'testcase_name': 'with_bias_and_bn_and_relu6_to_uq_per_channel', + 'activation_fn': nn_ops.relu6, + 'has_bias': True, + 'has_batch_norm': True, + 'target_opset': quant_opts_pb2.UNIFORM_QUANTIZED, + 'input_shape_dynamic': False, + 'enable_per_channel_quantization': True, + }, ) @test_util.run_in_graph_and_eager_modes def test_conv_ptq_model( @@ -1457,7 
+1564,6 @@ class StaticRangeQuantizationTest(quantize_model_test_base.QuantizedModelTest): input_shape = [None, None, None, 3] if input_shape_dynamic else [1, 3, 4, 3] filter_shape = [2, 3, 3, 2] - np.random.seed(1234) model = self._create_conv2d_model( input_shape, filter_shape, has_bias, has_batch_norm, activation_fn ) @@ -1614,7 +1720,7 @@ class StaticRangeQuantizationTest(quantize_model_test_base.QuantizedModelTest): if target_opset == quant_opts_pb2.UNIFORM_QUANTIZED: self.assertSizeRatioGreaterThan( - self._output_saved_model_path, self._input_saved_model_path, 0.7 + self._output_saved_model_path, self._input_saved_model_path, 0.68 ) self.assertTrue( self._contains_op(output_graphdef, 'UniformQuantizedConvolution') @@ -1628,140 +1734,143 @@ class StaticRangeQuantizationTest(quantize_model_test_base.QuantizedModelTest): else: self.assertTrue(self._contains_quantized_function_call(output_graphdef)) - # TODO(b/263830952): Use dictionaries instead of tuples for parameters. + # Check only the most simple case and the most complicated cases. @parameterized.named_parameters( - ('none', None, False, False, quant_opts_pb2.TF, False, False), - ('relu', nn_ops.relu, False, False, quant_opts_pb2.TF, False, False), - ('relu6', nn_ops.relu6, False, False, quant_opts_pb2.TF, False, False), - ('bn', None, False, True, quant_opts_pb2.TF, False, False), - ( - 'bn_and_relu', - nn_ops.relu, - False, - True, - quant_opts_pb2.TF, - False, - False, - ), - ('with_bias', None, True, False, quant_opts_pb2.TF, False, False), - ('with_bias_and_bn', None, True, True, quant_opts_pb2.TF, False, False), - ( - 'with_bias_and_bn_and_relu', - nn_ops.relu, - True, - True, - quant_opts_pb2.TF, - False, - False, - ), - ( - 'with_bias_and_relu', - nn_ops.relu, - True, - False, - quant_opts_pb2.TF, - False, - False, - ), - ( - 'with_bias_and_relu6', - nn_ops.relu6, - True, - False, - quant_opts_pb2.TF, - False, - False, - ), - ( - 'with_bias_and_bn_to_xla', - None, - True, - True, - quant_opts_pb2.XLA, - False, - False, - ), - ( - 'with_bias_and_relu6_to_xla', - nn_ops.relu6, - True, - False, - quant_opts_pb2.XLA, - False, - False, - ), - ( - 'with_bias_and_bn_to_xla_dynamic', - None, - True, - True, - quant_opts_pb2.XLA, - True, - False, - ), - ( - 'with_bias_and_relu6_to_xla_dynamic', - nn_ops.relu6, - True, - False, - quant_opts_pb2.XLA, - True, - False, - ), - ( - 'none_to_uq', - None, - False, - False, - quant_opts_pb2.UNIFORM_QUANTIZED, - False, - False, - ), - ( - 'none_to_uq_per_channel', - None, - False, - False, - quant_opts_pb2.UNIFORM_QUANTIZED, - False, - True, - ), - ( - 'relu_to_uq', - nn_ops.relu, - False, - False, - quant_opts_pb2.UNIFORM_QUANTIZED, - False, - False, - ), - ( - 'with_bias_to_uq', - None, - True, - False, - quant_opts_pb2.UNIFORM_QUANTIZED, - False, - False, - ), - ( - 'with_bias_and_relu_to_uq', - nn_ops.relu, - True, - False, - quant_opts_pb2.UNIFORM_QUANTIZED, - False, - False, - ), - ( - 'with_bias_and_relu6_to_uq', - nn_ops.relu6, - True, - False, - quant_opts_pb2.UNIFORM_QUANTIZED, - False, - False, - ), + { + 'testcase_name': 'none', + 'activation_fn': None, + 'has_bias': False, + 'has_batch_norm': False, + 'target_opset': quant_opts_pb2.TF, + 'input_shape_dynamic': False, + 'enable_per_channel_quantization': False, + }, + { + 'testcase_name': 'relu', + 'activation_fn': nn_ops.relu, + 'has_bias': False, + 'has_batch_norm': False, + 'target_opset': quant_opts_pb2.TF, + 'input_shape_dynamic': False, + 'enable_per_channel_quantization': False, + }, + { + 'testcase_name': 'relu6', + 
'activation_fn': nn_ops.relu6, + 'has_bias': False, + 'has_batch_norm': False, + 'target_opset': quant_opts_pb2.TF, + 'input_shape_dynamic': False, + 'enable_per_channel_quantization': False, + }, + { + 'testcase_name': 'bn', + 'activation_fn': None, + 'has_bias': False, + 'has_batch_norm': True, + 'target_opset': quant_opts_pb2.TF, + 'input_shape_dynamic': False, + 'enable_per_channel_quantization': False, + }, + { + 'testcase_name': 'with_bias', + 'activation_fn': None, + 'has_bias': True, + 'has_batch_norm': False, + 'target_opset': quant_opts_pb2.TF, + 'input_shape_dynamic': False, + 'enable_per_channel_quantization': False, + }, + { + 'testcase_name': 'with_bias_and_relu6', + 'activation_fn': nn_ops.relu6, + 'has_bias': True, + 'has_batch_norm': False, + 'target_opset': quant_opts_pb2.TF, + 'input_shape_dynamic': False, + 'enable_per_channel_quantization': False, + }, + { + 'testcase_name': 'with_bias_and_bn_and_relu6', + 'activation_fn': nn_ops.relu6, + 'has_bias': True, + 'has_batch_norm': True, + 'target_opset': quant_opts_pb2.TF, + 'input_shape_dynamic': False, + 'enable_per_channel_quantization': False, + }, + { + 'testcase_name': 'with_bias_and_relu6_to_xla', + 'activation_fn': nn_ops.relu6, + 'has_bias': True, + 'has_batch_norm': False, + 'target_opset': quant_opts_pb2.XLA, + 'input_shape_dynamic': False, + 'enable_per_channel_quantization': False, + }, + { + 'testcase_name': 'with_bias_and_bn_and_relu6_to_xla', + 'activation_fn': nn_ops.relu6, + 'has_bias': True, + 'has_batch_norm': True, + 'target_opset': quant_opts_pb2.XLA, + 'input_shape_dynamic': False, + 'enable_per_channel_quantization': False, + }, + { + 'testcase_name': 'with_bias_and_relu6_to_xla_dynamic', + 'activation_fn': nn_ops.relu6, + 'has_bias': True, + 'has_batch_norm': False, + 'target_opset': quant_opts_pb2.XLA, + 'input_shape_dynamic': True, + 'enable_per_channel_quantization': False, + }, + { + 'testcase_name': 'with_bias_and_bn_and_relu6_to_xla_dynamic', + 'activation_fn': nn_ops.relu6, + 'has_bias': True, + 'has_batch_norm': True, + 'target_opset': quant_opts_pb2.XLA, + 'input_shape_dynamic': True, + 'enable_per_channel_quantization': False, + }, + { + 'testcase_name': 'with_bias_and_relu6_to_uq', + 'activation_fn': nn_ops.relu6, + 'has_bias': True, + 'has_batch_norm': False, + 'target_opset': quant_opts_pb2.UNIFORM_QUANTIZED, + 'input_shape_dynamic': False, + 'enable_per_channel_quantization': False, + }, + { + 'testcase_name': 'with_bias_and_bn_and_relu6_to_uq', + 'activation_fn': nn_ops.relu6, + 'has_bias': True, + 'has_batch_norm': True, + 'target_opset': quant_opts_pb2.UNIFORM_QUANTIZED, + 'input_shape_dynamic': False, + 'enable_per_channel_quantization': False, + }, + { + 'testcase_name': 'with_bias_and_relu6_to_uq_per_channel', + 'activation_fn': nn_ops.relu6, + 'has_bias': True, + 'has_batch_norm': False, + 'target_opset': quant_opts_pb2.UNIFORM_QUANTIZED, + 'input_shape_dynamic': False, + 'enable_per_channel_quantization': True, + }, + { + 'testcase_name': 'with_bias_and_bn_and_relu6_to_uq_per_channel', + 'activation_fn': nn_ops.relu6, + 'has_bias': True, + 'has_batch_norm': True, + 'target_opset': quant_opts_pb2.UNIFORM_QUANTIZED, + 'input_shape_dynamic': False, + 'enable_per_channel_quantization': True, + }, ) @test_util.run_in_graph_and_eager_modes def test_depthwise_conv_ptq_model( @@ -1778,7 +1887,6 @@ class StaticRangeQuantizationTest(quantize_model_test_base.QuantizedModelTest): model = self._create_depthwise_conv2d_model( input_shape, filter_shape, has_bias, has_batch_norm, activation_fn 
) - np.random.seed(1234) saved_model_save.save(model, self._input_saved_model_path) def data_gen() -> repr_dataset.RepresentativeDataset: @@ -1910,7 +2018,6 @@ class StaticRangeQuantizationTest(quantize_model_test_base.QuantizedModelTest): batch_sizes: Sequence[int], target_opset: quant_opts_pb2.OpSet, ): - np.random.seed(1234) lhs_batch_size, rhs_batch_size = batch_sizes input_shape = (*lhs_batch_size, 1, 1024) filter_shape = (*rhs_batch_size, 1024, 3) @@ -1922,15 +2029,14 @@ class StaticRangeQuantizationTest(quantize_model_test_base.QuantizedModelTest): has_bias, activation_fn, ) + rng = np.random.default_rng(seed=1234) def data_gen() -> repr_dataset.RepresentativeDataset: for _ in range(500): yield { - 'input_tensor': ops.convert_to_tensor( - np.random.uniform( - low=0.0, high=1.0, size=static_input_shape - ).astype('f4') - ), + 'input_tensor': rng.uniform( + low=0.0, high=1.0, size=static_input_shape + ).astype(np.float32) } tags = {tag_constants.SERVING} @@ -1961,15 +2067,16 @@ class StaticRangeQuantizationTest(quantize_model_test_base.QuantizedModelTest): self.assertTrue(self._contains_quantized_function_call(output_graphdef)) input_data = ops.convert_to_tensor( - np.random.uniform(low=0.0, high=1.0, size=static_input_shape).astype( - 'f4' + rng.uniform(low=0.0, high=1.0, size=static_input_shape).astype( + np.float32 ) ) expected_outputs = model.matmul(input_data) got_outputs = converted_model.signatures['serving_default']( input_tensor=ops.convert_to_tensor(input_data) ) - self.assertAllClose(expected_outputs, got_outputs, atol=0.1674) + # The atol value is arbitrary. + self.assertAllClose(expected_outputs, got_outputs, atol=0.22) # Check the converted model in the target opset. quantization_options = quant_opts_pb2.QuantizationOptions( @@ -2003,8 +2110,82 @@ class StaticRangeQuantizationTest(quantize_model_test_base.QuantizedModelTest): input_tensor=ops.convert_to_tensor(input_data) ) # The difference between TF and target path is expected to be small. - self.assertAllClose(new_outputs, got_outputs, atol=0.1202) - self.assertAllClose(new_outputs, expected_outputs, atol=0.1023) + # The atol value is arbitrary. 
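The matmul-with-reshape-and-bias test added just below uses a bias whose size (4) differs from the filter width (8), so the model reshapes the matmul result before adding the bias. A small NumPy walk-through of the shape arithmetic for the 'with_biasadd' / 'with_addv2' parameterization (the array values are placeholders):

import numpy as np

input_tensor = np.random.uniform(0.0, 1.0, size=(32, 16)).astype(np.float32)
filters = np.random.uniform(-1.0, 1.0, size=(16, 8)).astype(np.float32)
bias = np.random.uniform(-1.0, 1.0, size=(4,)).astype(np.float32)

out = input_tensor @ filters   # (32, 8)
# 32 * 8 elements regroup cleanly into rows of bias_size = 4, which matches
# the divisibility condition the test base asserts for bias_size.
out = out.reshape(-1, 4)       # (64, 4)
out = out + bias               # stands in for BiasAdd or AddV2
print(out.shape)               # (64, 4)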
+ self.assertAllClose(new_outputs, got_outputs, atol=0.13) + self.assertAllClose(new_outputs, expected_outputs, atol=0.13) + + @parameterized.named_parameters( + { + 'testcase_name': 'with_biasadd', + 'input_shape': (32, 16), + 'filter_shape': (16, 8), + 'bias_size': 4, + 'use_biasadd': True, + 'activation_fn': nn_ops.relu, + }, + { + 'testcase_name': 'with_addv2', + 'input_shape': (32, 16), + 'filter_shape': (16, 8), + 'bias_size': 4, + 'use_biasadd': False, + 'activation_fn': nn_ops.relu, + }, + ) + def test_matmul_with_reshape_and_bias_ptq_model( + self, input_shape, filter_shape, bias_size, activation_fn, use_biasadd + ): + + model = self._create_matmul_model( + input_shape, + filter_shape, + self._input_saved_model_path, + True, + activation_fn, + bias_size, + use_biasadd, + ) + + rng = np.random.default_rng(seed=1234) + + def data_gen() -> repr_dataset.RepresentativeDataset: + for _ in range(5): + yield { + 'input_tensor': rng.uniform( + low=0.0, high=1.0, size=input_shape + ).astype(np.float32) + } + + tags = {tag_constants.SERVING} + + quantization_options = quant_opts_pb2.QuantizationOptions( + quantization_method=quant_opts_pb2.QuantizationMethod( + experimental_method=_ExperimentalMethod.STATIC_RANGE + ), + op_set=quant_opts_pb2.OpSet.XLA, + ) + + converted_model = quantize_model.quantize( + self._input_saved_model_path, + ['serving_default'], + tags, + self._output_saved_model_path, + quantization_options, + representative_dataset=data_gen(), + ) + + input_data = ops.convert_to_tensor( + rng.uniform(low=0.0, high=1.0, size=input_shape).astype( + np.float32 + ) + ) + expected_outputs = model.matmul(input_data) + + got_outputs = converted_model.signatures['serving_default']( + input_tensor=ops.convert_to_tensor(input_data) + ) + + self.assertAllClose(expected_outputs, got_outputs, atol=0.05) @parameterized.parameters( ('abc,cde->abde', (2, 2, 64), (64, 3, 3), (3, 3), quant_opts_pb2.XLA), @@ -2177,6 +2358,68 @@ class StaticRangeQuantizationTest(quantize_model_test_base.QuantizedModelTest): func.node_def, op_name='XlaConvV2', attr_name='', attr_val=None ) + @test_util.run_in_graph_and_eager_modes + def test_function_alias_preserved_in_qat(self): + _, y_shape, _, x_signature, y_signature = ( + self._prepare_sample_einsum_datashapes('ab,bc->ac') + ) + model = self._create_einsum_model_with_fake_quant( + 'ab,bc->ac', y_shape, x_signature, y_signature + ) + + signatures = { + 'serving_default': model.einsum_with_kernel.get_concrete_function(), + } + save_opts = save_options.SaveOptions( + function_aliases={'einsum_with_kernel': model.einsum_with_kernel} + ) + + saved_model_save.save( + model, self._input_saved_model_path, signatures, save_opts + ) + + tags = {tag_constants.SERVING} + + quantization_options = quant_opts_pb2.QuantizationOptions( + quantization_method=quant_opts_pb2.QuantizationMethod( + experimental_method=_ExperimentalMethod.STATIC_RANGE + ), + op_set=quant_opts_pb2.OpSet.XLA, + ) + + converted_model = quantize_model.quantize( + self._input_saved_model_path, + ['serving_default'], + tags, + self._output_saved_model_path, + quantization_options, + ) + + self.assertIsNotNone(converted_model) + self.assertCountEqual( + converted_model.signatures._signatures.keys(), {'serving_default'} + ) + + # Test whether the aliased function exists. + output_loader = saved_model_loader.SavedModelLoader( + self._output_saved_model_path + ) + + # Confirm that the function alias is preserved. 
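The alias check that follows relies on the model having been saved with a function alias. A self-contained sketch of that setup using the public tf.saved_model API (the module class, fixed kernel, and export path here are invented for illustration; the test builds its model through its own helpers):

import tensorflow as tf

class EinsumModule(tf.Module):

  @tf.function(input_signature=[tf.TensorSpec(shape=[2, 2], dtype=tf.float32)])
  def einsum_with_kernel(self, x):
    # A fixed kernel keeps the example self-contained.
    return tf.einsum('ab,bc->ac', x, tf.constant([[1.0, 0.0], [0.0, 1.0]]))

m = EinsumModule()
tf.saved_model.save(
    m,
    '/tmp/einsum_aliased',
    signatures={'serving_default': m.einsum_with_kernel.get_concrete_function()},
    options=tf.saved_model.SaveOptions(
        function_aliases={'einsum_with_kernel': m.einsum_with_kernel}),
)

After quantization the alias is expected to survive in meta_info_def.function_aliases, now keyed by the rewritten function name, which is what the assertions below verify.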
+ meta_graph_def = output_loader.get_meta_graph_def_from_tags(tags) + function_aliases = meta_graph_def.meta_info_def.function_aliases + self.assertNotEmpty(function_aliases) + self.assertCountEqual(function_aliases.values(), {'einsum_with_kernel'}) + + # Test that the aliased function contains a quantized op. + for func_name, alias in function_aliases.items(): + if alias == 'einsum_with_kernel': + for func in meta_graph_def.graph_def.library.function: + if func.signature.name == func_name: + self._contains_op_with_name_and_attribute( + func.node_def, op_name='XlaDotV2', attr_name='', attr_val=None + ) + @test_util.deprecated_graph_mode_only def test_matmul_ptq_model_with_unfreeze_constants(self): # Uses large weight to exceed the constant size threshold of 64KiB @@ -2574,11 +2817,9 @@ class StaticRangeQuantizationTest(quantize_model_test_base.QuantizedModelTest): self.assertNotEmpty(warning_logs.records) # Warning message should contain the function name. + self.assertTrue(self._any_log_contains('matmul', warning_logs.records)) self.assertTrue( - self._any_warning_contains('matmul', warning_logs.records) - ) - self.assertTrue( - self._any_warning_contains( + self._any_log_contains( 'does not have min or max values', warning_logs.records ) ) @@ -2599,6 +2840,21 @@ class StaticRangeQuantizationTest(quantize_model_test_base.QuantizedModelTest): class IfModel(module.Module): """A model that contains a branching op.""" + def __init__(self): + self.filters_0 = np.random.uniform( + low=-1.0, high=1.0, size=(4, 3) + ).astype('f4') + self.bias_0 = np.random.uniform(low=-1.0, high=1.0, size=(3,)).astype( + 'f4' + ) + + self.filters_1 = np.random.uniform( + low=-1.0, high=1.0, size=(4, 3) + ).astype('f4') + self.bias_1 = np.random.uniform(low=-1.0, high=1.0, size=(3,)).astype( + 'f4' + ) + @def_function.function( input_signature=[ tensor_spec.TensorSpec(shape=[1, 4], dtype=dtypes.float32) @@ -2617,20 +2873,12 @@ class StaticRangeQuantizationTest(quantize_model_test_base.QuantizedModelTest): A map of: output key -> output result. """ if math_ops.reduce_sum(x) > 10.0: - filters = np.random.uniform(low=-1.0, high=1.0, size=(4, 3)).astype( - 'f4' - ) - bias = np.random.uniform(low=-1.0, high=1.0, size=(3,)).astype('f4') - out = math_ops.matmul(x, filters) - out = nn_ops.bias_add(out, bias) + out = math_ops.matmul(x, self.filters_0) + out = nn_ops.bias_add(out, self.bias_0) return {'output': out} - filters = np.random.uniform(low=-1.0, high=1.0, size=(4, 3)).astype( - 'f4' - ) - bias = np.random.uniform(low=-1.0, high=1.0, size=(3,)).astype('f4') - out = math_ops.matmul(x, filters) - out = nn_ops.bias_add(out, bias) + out = math_ops.matmul(x, self.filters_1) + out = nn_ops.bias_add(out, self.bias_1) return {'output': out} model = IfModel() @@ -2675,14 +2923,12 @@ class StaticRangeQuantizationTest(quantize_model_test_base.QuantizedModelTest): # Warning message should contain the function name. The uncalibrated path # is when the condition is true, so 'cond_true' function must be part of # the warning message. 
- self.assertTrue( - self._any_warning_contains('cond_true', warning_logs.records) - ) + self.assertTrue(self._any_log_contains('cond_true', warning_logs.records)) self.assertFalse( - self._any_warning_contains('cond_false', warning_logs.records) + self._any_log_contains('cond_false', warning_logs.records) ) self.assertTrue( - self._any_warning_contains( + self._any_log_contains( 'does not have min or max values', warning_logs.records ) ) @@ -3515,7 +3761,6 @@ class StaticRangeQuantizationTest(quantize_model_test_base.QuantizedModelTest): out = activation_fn(out) return {'output': out} - np.random.seed(1234) model = ConvModel() saved_model_save.save(model, self._input_saved_model_path) @@ -4083,7 +4328,7 @@ class DynamicRangeQuantizationTest(quantize_model_test_base.QuantizedModelTest): if target_opset == quant_opts_pb2.UNIFORM_QUANTIZED: self.assertSizeRatioGreaterThan( - self._output_saved_model_path, self._input_saved_model_path, 0.7 + self._output_saved_model_path, self._input_saved_model_path, 0.65 ) self.assertTrue( self._contains_op( @@ -4392,6 +4637,61 @@ class DynamicRangeQuantizationTest(quantize_model_test_base.QuantizedModelTest): self.assertAllClose(lookup_val, [1.0, 2.0, 0.0]) + @test_util.deprecated_graph_mode_only + def test_file_init_hash_table_lookup_model(self): + tags = {tag_constants.SERVING} + signature_def_key = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY + + # Create and save a simple model that involves a hash table. + inputs, outputs = self._create_and_save_file_init_hash_table_model_tf1( + self._input_saved_model_path, tags, signature_def_key + ) + # Make sure that the desired input key and output key is present. + self.assertIn('input_vocabs', inputs.keys()) + self.assertIn('lookup', outputs.keys()) + + signature_def_keys = [signature_def_key] + quantize_model.quantize( + self._input_saved_model_path, + signature_def_keys, + tags, + self._output_saved_model_path, + quantization_options=quant_opts_pb2.QuantizationOptions( + quantization_method=quant_opts_pb2.QuantizationMethod( + experimental_method=_ExperimentalMethod.DYNAMIC_RANGE + ), + ), + ) + + # Tests table lookup to make sure the table has been initialized + # successfully. + with session.Session(graph=ops.Graph()) as sess: + output_meta_graph_def = saved_model_loader.load( + sess, tags=tags, export_dir=self._output_saved_model_path + ) + + self.assertCountEqual( + output_meta_graph_def.signature_def.keys(), signature_def_keys + ) + + signature_def = output_meta_graph_def.signature_def[signature_def_key] + + input_tensor_name = signature_def.inputs['input_vocabs'].name + input_tensor = sess.graph.get_tensor_by_name(input_tensor_name) + + lookup_tensor_name = signature_def.outputs['lookup'].name + lookup_tensor = sess.graph.get_tensor_by_name(lookup_tensor_name) + + lookup_val = sess.run( + lookup_tensor, + feed_dict={ + input_tensor: np.array([b'dynamic', b'quantization', b'range']) + }, + ) + + # "dynamic" is not in the table: -1 (default value) + self.assertAllClose(lookup_val, [-1.0, 2.0, 1.0]) + class WeightOnlyQuantizationTest(quantize_model_test_base.QuantizedModelTest): """Test cases for weight-only quantization. 
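The expected lookup values in the hash-table tests above ([-1.0, 2.0, 1.0]) follow directly from how the vocab asset file is written by the helpers in the next file: one vocabulary entry per line, the line number as the value, and -1 as the default. A tiny pure-Python stand-in for the table lookup:

vocab = ['static', 'range', 'quantization']       # rows 0, 1, 2 of the asset file
table = {word: float(index) for index, word in enumerate(vocab)}

query = ['dynamic', 'quantization', 'range']
print([table.get(word, -1.0) for word in query])  # [-1.0, 2.0, 1.0]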
diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/integration_test/quantize_model_test_base.py b/tensorflow/compiler/mlir/quantization/tensorflow/python/integration_test/quantize_model_test_base.py index 3341e1d84c4..f2593d336f7 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/python/integration_test/quantize_model_test_base.py +++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/integration_test/quantize_model_test_base.py @@ -42,6 +42,7 @@ from tensorflow.python.ops import string_ops from tensorflow.python.ops import variables from tensorflow.python.ops.ragged import ragged_string_ops from tensorflow.python.platform import test +from tensorflow.python.platform import tf_logging as logging from tensorflow.python.saved_model import builder from tensorflow.python.saved_model import save as saved_model_save from tensorflow.python.saved_model import signature_def_utils_impl @@ -84,6 +85,27 @@ class QuantizedModelTest(test.TestCase, parameterized.TestCase): total += os.path.getsize(os.path.join(root, filename)) return total + def _any_log_contains( + self, substring: str, log_record_list: List['logging.LogRecord'] + ) -> bool: + """Returns True if any of the log contains a given substring. + + Args: + substring: A piece of string to check whether it exists in the log + message. + log_record_list: A list of `absl.logging.LogRecord`s. + + Returns: + True if and only if the substring exists in any of the log in + `log_record_list`. + """ + return any( + map( + lambda log_record: substring in str(log_record.message), + log_record_list, + ) + ) + def assertSizeRatioGreaterThan( self, path_a: str, path_b: str, threshold: float ): @@ -530,7 +552,9 @@ class QuantizedModelTest(test.TestCase, parameterized.TestCase): shape=array_ops.shape(input_vocabs_placeholder), dtype=dtypes.float32 ) # shape: (?, 2) - weight = array_ops.transpose_v2(array_ops.stack([weight_row, weight_row])) + weight = array_ops.transpose_v2( + array_ops_stack.stack([weight_row, weight_row]) + ) # shape: (2, 2) output_tensor = math_ops.matmul(matmul_input, weight) @@ -725,6 +749,126 @@ class QuantizedModelTest(test.TestCase, parameterized.TestCase): return input_vocabs_placeholder, lookup_vals, output_tensor + def _create_table_init_from_file_qat_model_tf1( + self, sess: session.Session + ) -> Tuple[core.Tensor, core.Tensor, core.Tensor]: + """Creates a simple QAT model that initializes a table from an asset file. + + This model creates an asset file at "vocab_file.txt" containing + comma-separated vocabularies and uses it to initialize a + `StaticVocabularyTable`. For inference, the model performs a lookup with a + 1D string tensor input vocabs. + + Args: + sess: Tensorflow Session to create the model in. + + Returns: + (input_vocabs_placeholder, lookup_vals, output_tensor), where + * input_vocabs_placeholder is a placeholder tensor of 1D strings + * lookup_vals is an output tensor that is a direct result of table lookup + * output_tensor is a float 2x2 matrix + """ + # Creates and populates an asset file. 
+ asset_dir = self.create_tempdir('assets').full_path + asset_file = os.path.join(asset_dir, 'vocab_file.txt') + content = '\n'.join(['static', 'range', 'quantization']) + file_io.write_string_to_file(filename=asset_file, file_content=content) + + # The resulting table looks like: + # "static" -> 0 + # "range" -> 1 + # "quantization" -> 2 + # default -> -1 + init = lookup_ops.TextFileInitializer( + filename=asset_file, + key_dtype=dtypes.string, + key_index=lookup_ops.TextFileIndex.WHOLE_LINE, + value_dtype=dtypes.int64, + value_index=lookup_ops.TextFileIndex.LINE_NUMBER, + ) + table = lookup_ops.StaticHashTable(init, default_value=-1) + + input_vocabs_placeholder = array_ops.placeholder( + dtypes.string, shape=(None,), name='input_vocabs' + ) + + # Introduce a matmul op that takes the lookup values to observe the + # effects of quantization. + lookup_vals = math_ops.cast( + table.lookup(input_vocabs_placeholder), dtypes.float32 + ) + # shape: (2, ?) + matmul_input = array_ops_stack.stack([lookup_vals, lookup_vals]) + matmul_input = array_ops.fake_quant_with_min_max_args( + matmul_input, min=-0.3, max=0.3, num_bits=8, narrow_range=False + ) + + # Create a dummy weight matrix filled with ones. + weight_row = array_ops.ones( + shape=array_ops.shape(input_vocabs_placeholder), dtype=dtypes.float32 + ) + # shape: (?, 2) + weight = array_ops.transpose_v2( + array_ops_stack.stack([weight_row, weight_row]) + ) + weight = array_ops.fake_quant_with_min_max_args( + weight, min=-0.1, max=0.2, num_bits=8, narrow_range=False + ) + + # shape: (2, 2) + output_tensor = math_ops.matmul(matmul_input, weight) + output_tensor = array_ops.fake_quant_with_min_max_args( + output_tensor, min=-0.2, max=0.2, num_bits=8, narrow_range=False + ) + + return input_vocabs_placeholder, lookup_vals, output_tensor + + def _create_and_save_file_init_hash_table_qat_model_tf1( + self, + output_path: str, + tags: Collection[str], + signature_def_key: str, + ) -> Tuple[Mapping[str, core.Tensor], Mapping[str, core.Tensor]]: + """Creates and saves a QAT model that uses a file-initialized table. + + The asset file "vocab_file.txt" is used to initialize a hash table. + + Args: + output_path: Path to the directory to save the created model. + tags: Set of strings that identifies the saved meta graph. + signature_def_key: Name of the SignatureDef. Used to identify the + SignatureDef within the meta graph. + + Returns: + inputs: A mapping of input_key -> input_tensor (placeholder). The input + key is "input_vocabs". + outputs: A mapping of output_key -> output_tensor. The output keys are + "lookup" and "output". 
+ """ + with session.Session(graph=ops.Graph()) as sess: + input_vocabs_placeholder, lookup_tensor, output_tensor = ( + self._create_table_init_from_file_qat_model_tf1(sess) + ) + + inputs = {'input_vocabs': input_vocabs_placeholder} + outputs = { + 'lookup': lookup_tensor, + 'output': output_tensor, + } + + self._save_tf1_model( + sess, + output_path, + signature_def_key, + tags, + inputs=inputs, + outputs=outputs, + init_op=lookup_ops.tables_initializer(), + assets_collection=ops.get_collection(ops.GraphKeys.ASSET_FILEPATHS), + ) + + return inputs, outputs + def _create_data_generator( self, input_key: str, @@ -804,8 +948,16 @@ class QuantizedModelTest(test.TestCase, parameterized.TestCase): def __init__(self): """Initializes a SimpleGatherAndConvModel.""" - embedding_w_val = np.random.randn(1024, 3, 4, 3).astype('f4') - self.embedding_w = embedding_w_val + self.embedding_w = np.random.randn(1024, 3, 4, 3).astype('f4') + + self.conv_filters = np.random.uniform( + low=-10, high=10, size=filter_shape + ).astype('f4') + + second_conv_filter_shape = (3, 3, filter_shape[-1], 1) + self.second_conv_filters = np.random.uniform( + low=-10, high=10, size=second_conv_filter_shape + ).astype('f4') @def_function.function( input_signature=[ @@ -823,23 +975,39 @@ class QuantizedModelTest(test.TestCase, parameterized.TestCase): Returns: A map of: output key -> output result. """ - conv_filters = np.random.uniform( - low=-10, high=10, size=filter_shape - ).astype('f4') out = array_ops.gather_v2(self.embedding_w, input_tensor) + + # One pure conv + out = nn_ops.conv2d( + out, + self.conv_filters, + strides=(1, 1, 2, 1), + dilations=(1, 1, 1, 1), + padding='SAME', + data_format='NHWC', + ) + + # One fakequant attached conv if is_qat_model: out = array_ops.fake_quant_with_min_max_args( out, min=-0.1, max=0.2, num_bits=8, narrow_range=False ) - conv_filters = array_ops.fake_quant_with_min_max_args( - conv_filters, min=-0.1, max=0.2, num_bits=8, narrow_range=True + second_conv_filters = array_ops.fake_quant_with_min_max_args( + self.second_conv_filters, + min=-0.1, + max=0.2, + num_bits=8, + narrow_range=True, ) + else: + second_conv_filters = self.second_conv_filters + out = nn_ops.conv2d( out, - conv_filters, - strides=[1, 1, 2, 1], - dilations=[1, 1, 1, 1], + second_conv_filters, + strides=(1, 1, 2, 1), + dilations=(1, 1, 1, 1), padding='SAME', data_format='NHWC', ) @@ -945,6 +1113,16 @@ class QuantizedModelTest(test.TestCase, parameterized.TestCase): class DepthwiseConvModel(module.Module): """A simple model with a single depthwise conv2d, bias and relu.""" + def __init__(self): + self.filters = np.random.uniform( + low=-10, high=10, size=filter_shape + ).astype('f4') + + self.out_channel_size = filter_shape[2] * filter_shape[3] + self.bias = np.random.uniform( + low=0, high=10, size=(self.out_channel_size) + ).astype('f4') + @def_function.function( input_signature=[ tensor_spec.TensorSpec(shape=input_shape, dtype=dtypes.float32) @@ -961,25 +1139,19 @@ class QuantizedModelTest(test.TestCase, parameterized.TestCase): Returns: A map of: output key -> output result. 
""" - filters = np.random.uniform(low=-10, high=10, size=filter_shape).astype( - 'f4' - ) - out_channel_size = filter_shape[2] * filter_shape[3] - bias = np.random.uniform( - low=0, high=10, size=(out_channel_size) - ).astype('f4') - scale, offset = [1.0] * out_channel_size, [0.5] * out_channel_size + scale = [1.0] * self.out_channel_size + offset = [0.5] * self.out_channel_size mean, variance = scale, offset out = nn_ops.depthwise_conv2d_native( input_tensor, - filters, + self.filters, strides=[1, 2, 2, 1], dilations=[1, 1, 1, 1], padding='SAME', data_format='NHWC', ) if has_bias: - out = nn_ops.bias_add(out, bias) + out = nn_ops.bias_add(out, self.bias) if has_batch_norm: # Fusing is supported for non-training case. out, _, _, _, _, _ = nn_ops.fused_batch_norm_v3( @@ -1005,6 +1177,16 @@ class QuantizedModelTest(test.TestCase, parameterized.TestCase): class ConvModel(module.Module): """A simple model with a single conv2d, bias and relu.""" + def __init__(self): + self.filters = np.random.uniform( + low=-10, high=10, size=filter_shape + ).astype('f4') + + self.out_channel_size = filter_shape[-1] + self.bias = np.random.uniform( + low=0, high=10, size=(self.out_channel_size) + ).astype('f4') + @def_function.function( input_signature=[ tensor_spec.TensorSpec(shape=input_shape, dtype=dtypes.float32) @@ -1019,25 +1201,19 @@ class QuantizedModelTest(test.TestCase, parameterized.TestCase): Returns: A map of: output key -> output result. """ - filters = np.random.uniform(low=-10, high=10, size=filter_shape).astype( - 'f4' - ) - out_channel_size = filter_shape[-1] - bias = np.random.uniform( - low=0, high=10, size=(out_channel_size) - ).astype('f4') - scale, offset = [1.0] * out_channel_size, [0.5] * out_channel_size + scale = [1.0] * self.out_channel_size + offset = [0.5] * self.out_channel_size mean, variance = scale, offset out = nn_ops.conv2d( input_tensor, - filters, + self.filters, strides=[1, 1, 2, 1], dilations=[1, 1, 1, 1], padding='SAME', data_format='NHWC', ) if has_bias: - out = nn_ops.bias_add(out, bias, data_format='NHWC') + out = nn_ops.bias_add(out, self.bias, data_format='NHWC') if has_batch_norm: # Fusing is supported for non-training case. out, _, _, _, _, _ = nn_ops.fused_batch_norm_v3( @@ -1056,6 +1232,8 @@ class QuantizedModelTest(test.TestCase, parameterized.TestCase): saved_model_path: str, has_bias: bool = False, activation_fn: Optional[ops.Operation] = None, + bias_size: Optional[int] = None, + use_biasadd: bool = True, ) -> module.Module: class MatmulModel(module.Module): """A simple model with a single matmul. @@ -1066,21 +1244,32 @@ class QuantizedModelTest(test.TestCase, parameterized.TestCase): def __init__( self, weight_shape: Sequence[int], - has_bias: bool = False, + bias_size: Optional[int] = None, activation_fn: Optional[ops.Operation] = None, + use_biasadd: bool = True, ) -> None: """Initializes a MatmulModel. Args: weight_shape: Shape of the weight tensor. - has_bias: If True, creates and adds a bias term. + bias_size: If None, do not use bias. Else, use given size as bias. activation_fn: The activation function to be used. No activation function if None. + use_biasadd: If True, use BiasAdd for adding bias, else use AddV2. 
""" - self.has_bias = has_bias + self.bias_size = bias_size self.activation_fn = activation_fn + self.use_biasadd = use_biasadd self.filters = np.random.uniform(low=-1.0, high=1.0, size=weight_shape) - self.bias = np.random.uniform(low=-1.0, high=1.0, size=weight_shape[-1]) + + if bias_size is not None: + self.bias = np.random.uniform(low=-1.0, high=1.0, size=bias_size) + + def has_bias(self) -> bool: + return self.bias_size is not None + + def has_reshape(self) -> bool: + return self.has_bias() and self.bias_size != self.filters.shape[-1] @def_function.function def matmul(self, input_tensor: core.Tensor) -> Mapping[str, core.Tensor]: @@ -1098,15 +1287,40 @@ class QuantizedModelTest(test.TestCase, parameterized.TestCase): """ out = math_ops.matmul(input_tensor, self.filters) - if self.has_bias: - out = nn_ops.bias_add(out, self.bias) + if self.has_reshape(): + input_shape = input_tensor.shape + if len(input_shape) == 3: + reshape_shape = (input_shape[0], -1, self.bias_size) + else: + reshape_shape = (-1, self.bias_size) + + out = array_ops.reshape(out, reshape_shape) + + if self.has_bias(): + if self.use_biasadd: + out = nn_ops.bias_add(out, self.bias) + else: + out = math_ops.add_v2(out, self.bias) if self.activation_fn is not None: out = self.activation_fn(out) return {'output': out} - model = MatmulModel(weight_shape, has_bias, activation_fn) + # If bias_size is not explictly given, it should default to width of weight. + if bias_size is None and has_bias: + bias_size = weight_shape[-1] + + # Verify that when bias_size is not None, has_bias should be True. + # And if bias_size is None, has_bias should be False using XNOR + assert (not ((bias_size is not None) ^ has_bias)) + + # Verify that bias size is correct + if bias_size: + input_height = input_shape[0] if len(input_shape) == 2 else input_shape[1] + assert input_height * weight_shape[-1] % bias_size == 0 + + model = MatmulModel(weight_shape, bias_size, activation_fn) saved_model_save.save( model, saved_model_path, diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/pywrap_quantize_model.cc b/tensorflow/compiler/mlir/quantization/tensorflow/python/pywrap_quantize_model.cc index 7aeadeb212f..4b083d6f96c 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/python/pywrap_quantize_model.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/pywrap_quantize_model.cc @@ -161,10 +161,11 @@ PYBIND11_MODULE(pywrap_quantize_model, m) { [](const absl::string_view saved_model_path, const std::vector& signature_keys, const std::unordered_set& tags, - const QuantizationOptions& quant_opts) + const QuantizationOptions& quant_opts, + const absl::flat_hash_map& function_aliases) -> absl::StatusOr { return QuantizeQatModel(saved_model_path, signature_keys, tags, - quant_opts); + quant_opts, function_aliases); }, R"pbdoc( Returns serialized ExportedModel that contains the quantized model's diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.cc b/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.cc index 95feea40ee1..0a3f2e95c36 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.cc @@ -24,6 +24,7 @@ limitations under the License. 
#include "absl/container/flat_hash_set.h" #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/strings/string_view.h" #include "llvm/ADT/SmallVector.h" @@ -267,8 +268,8 @@ absl::StatusOr ConvertMlirModuleToExportedModel( if (const auto status = ConvertMlirToGraph(module_op, config, &graph, &flib_def, &control_ret_nodes); !status.ok()) { - return absl::InternalError("Failed to convert MLIR to GraphDef. " + - status.error_message()); + return absl::InternalError( + absl::StrCat("Failed to convert MLIR to GraphDef. ", status.message())); } GraphDef graph_def{}; @@ -285,6 +286,39 @@ absl::StatusOr ConvertMlirModuleToExportedModel( function_aliases, asset_file_defs); } +// Returns the updated function aliases. `module_op` may have different function +// names from the original model, so it re-associates the aliases with the new +// function names. Both the input `function_aliases` and the returned value +// are function name -> alias mappings. `function_aliases` is the function alias +// mapping of the original function. +absl::flat_hash_map UpdateFunctionAliases( + const absl::flat_hash_map function_aliases, + mlir::ModuleOp module_op) { + absl::flat_hash_map updated_function_aliases; + + module_op->walk([&](mlir::func::FuncOp func_op) { + // We may retrieve the original function's name from the attribute. + // Functions without this attribute are ignored. + auto original_func_name = + func_op->getAttrOfType("tf._original_func_name"); + if (original_func_name) { + if (auto alias_itr = function_aliases.find(original_func_name.str()); + alias_itr != function_aliases.end()) { + const std::string alias = alias_itr->second; + const std::string new_func_name = func_op.getSymName().str(); + + updated_function_aliases[new_func_name] = alias; + + VLOG(1) << "Updated function alias. Alias: " << alias + << ", New function name: " << new_func_name + << ", Old function name: " << original_func_name.str(); + } + } + }); + + return updated_function_aliases; +} + // Runs MLIR passes with `module_op`. The passes are added by calling // `add_passes_func`, which is a callable receiving mlir::PassManager& as its // only argument. `name` identifies the set of passes added by `add_passes_func` @@ -310,7 +344,7 @@ absl::Status RunPasses(const absl::string_view name, FuncT add_passes_func, if (failed(pm.run(module_op))) { return absl::InternalError( absl::StrFormat("Failed to run pass: %s. %s", name, - diagnostic_handler.ConsumeStatus().error_message())); + diagnostic_handler.ConsumeStatus().message())); } return absl::OkStatus(); @@ -421,13 +455,15 @@ absl::StatusOr QuantizeQatModel( const absl::string_view saved_model_path, const std::vector &signature_keys, const std::unordered_set &tags, - const QuantizationOptions &quantization_options) { + const QuantizationOptions &quantization_options, + const absl::flat_hash_map &function_aliases) { // Convert the SavedModelBundle to an MLIR module. 
mlir::MLIRContext context = CreateMlirContextForTfQuantization(); MLIRImportOptions import_options; import_options.upgrade_legacy = true; import_options.lift_variables = false; + import_options.include_variables_in_initializers = true; auto bundle = std::make_unique(); // TODO(b/213406917): Add support for the object graph based saved model input @@ -437,14 +473,33 @@ absl::StatusOr QuantizeQatModel( absl::MakeSpan(exported_names), &context, import_options, &bundle); if (!module.status().ok()) { - return absl::InternalError("Failed to import SavedModel: " + - module.status().error_message()); + return absl::InternalError(absl::StrCat("Failed to import SavedModel: ", + module.status().message())); } mlir::OwningOpRef module_ref = std::move(module).value(); - TF_QUANT_RETURN_IF_ERROR(PreprocessAndFreezeGraph( - module_ref.get(), &context, bundle ? bundle->GetSession() : nullptr)); + const absl::flat_hash_map updated_function_aliases = + UpdateFunctionAliases(function_aliases, *module_ref); + + // Collect the names of the functions that have aliases so that they may not + // be inlined. + absl::flat_hash_set aliased_function_names; + absl::c_for_each(updated_function_aliases, [&](const auto &aliases) { + return aliased_function_names.insert(aliases.first); + }); + + // TODO(b/274858158): Removing this triggers an error on unit test. + if (aliased_function_names.empty()) { + TF_QUANT_RETURN_IF_ERROR(PreprocessAndFreezeGraph( + module_ref.get(), &context, bundle ? bundle->GetSession() : nullptr)); + } else { + TF_QUANT_RETURN_IF_ERROR(PreprocessAndFreezeGraph( + /*mlir_dump_file_prefix=*/kDefaultTfQuantMlirDumpFilePrefix, + /*is_inliner_run=*/false, + /*noinline_functions=*/aliased_function_names, module_ref.get(), + &context, bundle ? bundle->GetSession() : nullptr)); + } TF_QUANT_RETURN_IF_ERROR( RunPasses(/*name=*/kTfQuantQatStepName, @@ -468,44 +523,10 @@ absl::StatusOr QuantizeQatModel( RunExportPasses(export_opts, context, *module_ref)); return ConvertMlirModuleToExportedModel( - *module_ref, checkpoint_dir, - /*function_aliases=*/{}, + *module_ref, checkpoint_dir, updated_function_aliases, {asset_file_defs.begin(), asset_file_defs.end()}); } -// Returns the updated function aliases. `module_op` may have different function -// names from the original model, so it re-associates the aliases with the new -// function names. Both the input `function_aliases` and the returned value -// are function name -> alias mappings. `function_aliases` is the function alias -// mapping of the original function. -absl::flat_hash_map UpdateFunctionAliases( - const absl::flat_hash_map function_aliases, - mlir::ModuleOp module_op) { - absl::flat_hash_map updated_function_aliases; - - module_op->walk([&](mlir::func::FuncOp func_op) { - // We may retrieve the original function's name from the attribute. - // Functions without this attribute are ignored. - auto original_func_name = - func_op->getAttrOfType("tf._original_func_name"); - if (original_func_name) { - if (auto alias_itr = function_aliases.find(original_func_name.str()); - alias_itr != function_aliases.end()) { - const std::string alias = alias_itr->second; - const std::string new_func_name = func_op.getSymName().str(); - - updated_function_aliases[new_func_name] = alias; - - VLOG(1) << "Updated function alias. 
Alias: " << alias - << ", New function name: " << new_func_name - << ", Old function name: " << original_func_name.str(); - } - } - }); - - return updated_function_aliases; -} - absl::StatusOr QuantizePtqModelPreCalibration( const absl::string_view saved_model_path, const std::vector &signature_keys, @@ -518,6 +539,7 @@ absl::StatusOr QuantizePtqModelPreCalibration( MLIRImportOptions import_options; import_options.upgrade_legacy = true; import_options.lift_variables = false; + import_options.include_variables_in_initializers = true; auto bundle = std::make_unique(); // TODO(b/213406917): Add support for the object graph based saved model input @@ -528,8 +550,8 @@ absl::StatusOr QuantizePtqModelPreCalibration( &context, import_options, &bundle); if (!module.status().ok()) { - return absl::InternalError("Failed to import SavedModel: " + - module.status().error_message()); + return absl::InternalError(absl::StrCat("Failed to import SavedModel: ", + module.status().message())); } mlir::OwningOpRef module_ref = std::move(module).value(); @@ -589,6 +611,7 @@ absl::StatusOr QuantizePtqModelPostCalibration( MLIRImportOptions import_options; import_options.upgrade_legacy = true; import_options.lift_variables = false; + import_options.include_variables_in_initializers = true; auto bundle = std::make_unique(); // TODO(b/213406917): Add support for the object graph based saved model input @@ -599,8 +622,8 @@ absl::StatusOr QuantizePtqModelPostCalibration( &context, import_options, &bundle); if (!module.status().ok()) { - return absl::InternalError("Failed to import SavedModel: " + - module.status().error_message()); + return absl::InternalError(absl::StrCat("Failed to import SavedModel: ", + module.status().message())); } mlir::OwningOpRef module_ref = std::move(module).value(); @@ -661,6 +684,7 @@ absl::StatusOr QuantizePtqDynamicRange( MLIRImportOptions import_options; import_options.upgrade_legacy = true; import_options.lift_variables = false; + import_options.include_variables_in_initializers = true; auto bundle = std::make_unique(); // TODO(b/213406917): Add support for the object graph based saved model input @@ -671,8 +695,8 @@ absl::StatusOr QuantizePtqDynamicRange( &context, import_options, &bundle); if (!module.status().ok()) { - return absl::InternalError("Failed to import SavedModel: " + - module.status().error_message()); + return absl::InternalError(absl::StrCat("Failed to import SavedModel: ", + module.status().message())); } mlir::OwningOpRef module_ref = std::move(module).value(); diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.h b/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.h index c3747fee523..f17f20df4b6 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.h +++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.h @@ -47,7 +47,8 @@ absl::StatusOr QuantizeQatModel( absl::string_view saved_model_path, const std::vector& signature_keys, const std::unordered_set& tags, - const QuantizationOptions& quant_opts); + const QuantizationOptions& quant_opts, + const absl::flat_hash_map& function_aliases); // Apply post-training dynamic range quantization to the model. 
absl::StatusOr QuantizePtqDynamicRange( diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.py b/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.py index 56057c222cd..758f99b62a0 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.py +++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.py @@ -33,6 +33,7 @@ from tensorflow.python.client import session from tensorflow.python.eager import context from tensorflow.python.eager import wrap_function from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_conversion from tensorflow.python.lib.io import file_io from tensorflow.python.platform import tf_logging as logging from tensorflow.python.saved_model import loader_impl as saved_model_loader @@ -192,7 +193,9 @@ def _convert_values_to_tf_tensors( if isinstance(tensorlike_value, core.Tensor): tensor_value = tensorlike_value else: - tensor_value = ops.convert_to_tensor_v2_with_dispatch(tensorlike_value) + tensor_value = tensor_conversion.convert_to_tensor_v2_with_dispatch( + tensorlike_value + ) tensor_mapping[name] = tensor_value @@ -461,6 +464,7 @@ def _run_graph_for_calibration( signature_keys: Sequence[str], tags: Collection[str], representative_dataset: repr_dataset.RepresentativeDatasetOrMapping, + force_graph_mode_calibration: bool, ) -> None: """Runs the graph for calibration using representative datasets. @@ -475,6 +479,8 @@ def _run_graph_for_calibration( `signature_keys` contains more than one signature key, `representative_dataset` should be a mapping that maps each signature key to the corresponding representative dataset. + force_graph_mode_calibration: If set to true, it forces calibration in graph + mode instead of eager mode when the context is in eager mode. Raises: ValueError iff: @@ -495,11 +501,13 @@ def _run_graph_for_calibration( representative_dataset_map = {signature_keys[0]: representative_dataset} try: - if context.executing_eagerly(): + if context.executing_eagerly() and not force_graph_mode_calibration: + logging.info('Calibration step is executed in eager mode.') _run_graph_for_calibration_eager_mode( float_model_dir, tags, representative_dataset_map ) else: + logging.info('Calibration step is executed in graph mode.') _run_graph_for_calibration_graph_mode( float_model_dir, tags, representative_dataset_map ) @@ -511,85 +519,6 @@ def _run_graph_for_calibration( logging.info('Calibration step complete.') -def _run_static_range_qat( - src_saved_model_path: str, - dst_saved_model_path: str, - signature_def_keys: Sequence[str], - tags: Collection[str], - quant_opts: quant_opts_pb2.QuantizationOptions, - signature_def_map: _SignatureDefMap, -) -> None: - """Runs static-range quantization for a Quantization-Aware Trained model. - - Runs the quantization for a model trained using QAT. - - Args: - src_saved_model_path: Path to the source SavedModel directory. - dst_saved_model_path: Path to the destination SavedModel directory. - signature_def_keys: Keys of the signatures of the functions that are the - target for quantization. - tags: Tags identifying the MetaGraphDef. - quant_opts: Quantization options. - signature_def_map: Signature def key -> SignatureDef mapping.
- """ - logging.info('Running static-range quantization for QAT model.') - exported_model_serialized = pywrap_quantize_model.quantize_qat_model( - src_saved_model_path, - list(signature_def_keys), - set(tags), - quant_opts.SerializeToString(), - ) - - exported_model = exported_model_pb2.ExportedModel.FromString( - exported_model_serialized - ) - - save_model.save_model_v1( - exported_model.graph_def, - dst_saved_model_path, - signature_def_map, - tags, - init_op_name=exported_model.init_node_name, - saver_def=_get_saver_def_or_none(exported_model), - checkpoint_dir=exported_model.checkpoint_dir, - function_aliases=exported_model.function_aliases, - asset_file_defs=exported_model.asset_file_defs, - ) - - -def _add_calibration_statistics(graph_def: graph_pb2.GraphDef) -> None: - """Adds calibration statistics to the graph def. - - This function must be run after running the graph with a representative - dataset. Retrieves calibration statistics from the global calibrator and adds - them to the corresponding nodes as attributes. - - Args: - graph_def: GraphDef to add calibration statistics to. - """ - for function_def in graph_def.library.function: - for node_def in function_def.node_def: - if node_def.op != 'CustomAggregator': - continue - - node_id = node_def.attr['id'].s - try: - min_val = pywrap_quantize_model.get_min_from_calibrator(node_id) - max_val = pywrap_quantize_model.get_max_from_calibrator(node_id) - pywrap_quantize_model.clear_data_from_calibrator(node_id) - node_def.attr['min'].f = float(min_val) - node_def.attr['max'].f = float(max_val) - except ValueError: - logging.warn( - ( - 'CustomAggregator id "%s" from FunctionDef "%s" does not have ' - 'min or max values. Parts of this function are not quantized.' - ), - node_id.decode('utf-8'), - function_def.signature.name, - ) - - def _copy_assets(src_path: str, dst_path: str) -> None: """Copies the assets directory of the saved model. @@ -623,6 +552,94 @@ def _copy_assets(src_path: str, dst_path: str) -> None: ) +def _run_static_range_qat( + src_saved_model_path: str, + dst_saved_model_path: str, + signature_def_keys: Sequence[str], + tags: Collection[str], + quant_opts: quant_opts_pb2.QuantizationOptions, + signature_def_map: _SignatureDefMap, +) -> None: + """Runs static-range quantization for a Quantization-Aware Trained model. + + Runs the quantization for a model trained using QAT. + + Args: + src_saved_model_path: Path to the source SavedModel directory. + dst_saved_model_path: Path to the destination SavedModel directory. + signature_def_keys: Keys of the signatures of the functions that are the + target for quantization. + tags: Tags identifying the MetaGraphDef. + quant_opts: Quantization options. + signature_def_map: Signature def key -> SignatureDef mapping. 
+ """ + logging.info('Running static-range quantization for QAT model.') + + loader = saved_model_loader.SavedModelLoader(src_saved_model_path) + function_aliases = loader.get_meta_graph_def_from_tags( + tags + ).meta_info_def.function_aliases + + exported_model_serialized = pywrap_quantize_model.quantize_qat_model( + src_saved_model_path, + list(signature_def_keys), + set(tags), + quant_opts.SerializeToString(), + dict(function_aliases), + ) + + exported_model = exported_model_pb2.ExportedModel.FromString( + exported_model_serialized + ) + + save_model.save_model_v1( + exported_model.graph_def, + dst_saved_model_path, + signature_def_map, + tags, + init_op_name=exported_model.init_node_name, + saver_def=_get_saver_def_or_none(exported_model), + checkpoint_dir=exported_model.checkpoint_dir, + function_aliases=exported_model.function_aliases, + asset_file_defs=exported_model.asset_file_defs, + ) + + _copy_assets(src_saved_model_path, dst_saved_model_path) + + +def _add_calibration_statistics(graph_def: graph_pb2.GraphDef) -> None: + """Adds calibration statistics to the graph def. + + This function must be run after running the graph with a representative + dataset. Retrieves calibration statistics from the global calibrator and adds + them to the corresponding nodes as attributes. + + Args: + graph_def: GraphDef to add calibration statistics to. + """ + for function_def in graph_def.library.function: + for node_def in function_def.node_def: + if node_def.op != 'CustomAggregator': + continue + + node_id = node_def.attr['id'].s + try: + min_val = pywrap_quantize_model.get_min_from_calibrator(node_id) + max_val = pywrap_quantize_model.get_max_from_calibrator(node_id) + pywrap_quantize_model.clear_data_from_calibrator(node_id) + node_def.attr['min'].f = float(min_val) + node_def.attr['max'].f = float(max_val) + except ValueError: + logging.warn( + ( + 'CustomAggregator id "%s" from FunctionDef "%s" does not have ' + 'min or max values. Parts of this function are not quantized.' 
+ ), + node_id.decode('utf-8'), + function_def.signature.name, + ) + + def _get_saver_def_or_none( exported_model: exported_model_pb2.ExportedModel, ) -> Optional[saver_pb2.SaverDef]: @@ -721,6 +738,7 @@ def _run_static_range_ptq( signature_def_keys, tags, representative_dataset, + quant_opts.force_graph_mode_calibration, ) _add_calibration_statistics(graph_def) @@ -911,7 +929,7 @@ def _dynamic_range_quantize( # please also update default value in tflite converter: # tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc;l=201 if quantization_options.min_num_elements_for_weights == 0: - (quantization_options.min_num_elements_for_weights) = ( + quantization_options.min_num_elements_for_weights = ( _DYNAMIC_RANGE_DEFAULT_MIN_NUM_ELEMENTS_FOR_WEIGHTS ) logging.warn( @@ -941,9 +959,14 @@ def _dynamic_range_quantize( exported_model.graph_def, output_directory, signature_def_map, - tags=tags, + tags, init_op_name=exported_model.init_node_name, + saver_def=_get_saver_def_or_none(exported_model), + checkpoint_dir=exported_model.checkpoint_dir, + function_aliases=exported_model.function_aliases, + asset_file_defs=exported_model.asset_file_defs, ) + _copy_assets(saved_model_path, output_directory) return saved_model_load(output_directory) diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.proto b/tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.proto index 5e3f3bba9a5..6c5c520bffc 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.proto +++ b/tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.proto @@ -132,7 +132,7 @@ message FreezeAllVariables { // 2) A set of supported operations. // 3) Unit wise quantization precision. // 4) Target hardware name. -// NEXT ID: 11 +// NEXT ID: 12 message QuantizationOptions { // The default quantization configuration for the model. If the below // unit-wise configuration does not exist, we use this default quantization @@ -181,4 +181,8 @@ message QuantizationOptions { // Produces legacy weight-only graph where the qconst op (containing quantized // values) is followed by a dequantization op. bool enable_legacy_weight_only = 10; + + // If set to true, it forces calibration in graph mode instead of eager mode + // when the context is in eager mode.
+ bool force_graph_mode_calibration = 11; } diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/quantize_passes.cc b/tensorflow/compiler/mlir/quantization/tensorflow/quantize_passes.cc index 362971c5d42..a3b43e62e5e 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/quantize_passes.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/quantize_passes.cc @@ -93,6 +93,9 @@ void AddQuantizePtqDynamicRangePasses( pm.addNestedPass( mlir::TF::CreateUnrollBatchMatMulPassPass()); pm.addPass(mlir::TF::CreateTFShapeInferencePass()); + if (quantization_options.experimental_enable_tpu_model_support()) { + pm.addPass(mlir::quant::CreateConvertTpuModelToCpuPass()); + } pm.addNestedPass( mlir::quant::CreatePrepareLiftingPass(quantization_options.op_set())); pm.addPass(mlir::quant::CreateLiftQuantizableSpotsAsFunctionsDRQPass( @@ -134,11 +137,11 @@ void AddQuantizePtqPreCalibrationPasses( mlir::TF::CreateUnrollBatchMatMulPassPass()); } pm.addPass(mlir::TF::CreateTFShapeInferencePass()); - pm.addNestedPass( - mlir::quant::CreatePrepareLiftingPass(quantization_options.op_set())); if (quantization_options.experimental_enable_tpu_model_support()) { pm.addPass(mlir::quant::CreateConvertTpuModelToCpuPass()); } + pm.addNestedPass( + mlir::quant::CreatePrepareLiftingPass(quantization_options.op_set())); pm.addPass(mlir::quant::CreateLiftQuantizableSpotsAsFunctionsPass( quantization_options.op_set(), quantization_options.enable_two_input_tensors())); diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/cast_bf16_ops_to_f32.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/tests/cast_bf16_ops_to_f32.mlir index 61d9288a5a8..4fc6cbf3f97 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/tests/cast_bf16_ops_to_f32.mlir +++ b/tensorflow/compiler/mlir/quantization/tensorflow/tests/cast_bf16_ops_to_f32.mlir @@ -47,8 +47,8 @@ func.func @cast_bf16_avg_pool_to_fp32(%arg0: tensor<1x3x4x3xf32>) -> (tensor<1x3 // CHECK: func @cast_bf16_avg_pool_to_fp32 // CHECK-DAG: %[[cst:.*]] = "tf.Const"() {value = dense<{{.*}}> : tensor<2x3x3x2xf32>} : () -> tensor<2x3x3x2xf32> -// CHECK: %[[conv:.*]] = "tf.Conv2D"(%arg0, %[[cst]]) {data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x2x2xf32> -// CHECK: %[[avg_pool:.*]] = "tf.AvgPool"(%[[conv]]) {data_format = "NHWC", ksize = [1, 1, 1, 1], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<1x3x2x2xf32>) -> tensor<1x3x2x2xf32> +// CHECK: %[[conv:.*]] = "tf.Conv2D"(%arg0, %[[cst]]) +// CHECK: %[[avg_pool:.*]] = "tf.AvgPool"(%[[conv]]) // CHECK: %[[identity:.*]] = "tf.IdentityN"(%[[avg_pool]]) {device = ""} : (tensor<1x3x2x2xf32>) -> tensor<1x3x2x2xf32> // CHECK: return %[[identity]] : tensor<1x3x2x2xf32> @@ -63,7 +63,7 @@ func.func @cast_bf16_matmul_to_fp32(%arg0: tensor<1x10xf32>) -> (tensor<1x2xf32> // CHECK: func @cast_bf16_matmul_to_fp32 // CHECK-DAG: %[[cst:.*]] = "tf.Const"() {value = dense<{{.*}}> : tensor<10x2xf32>} : () -> tensor<10x2xf32> -// CHECK: %[[matmul:.*]] = "tf.MatMul"(%arg0, %[[cst]]) {transpose_a = false, transpose_b = false} : (tensor<1x10xf32>, tensor<10x2xf32>) -> tensor<1x2xf32> +// CHECK: %[[matmul:.*]] = "tf.MatMul"(%arg0, %[[cst]]) // CHECK: %[[identity:.*]] = "tf.IdentityN"(%[[matmul]]) // CHECK: return %[[identity]] : tensor<1x2xf32> @@ -78,7 +78,7 @@ func.func @cast_bf16_depthwise_conv_to_fp32(%arg0: tensor<1x3x4x3xf32>) -> (tens // CHECK: func 
@cast_bf16_depthwise_conv_to_fp32 // CHECK-DAG: %[[cst:.*]] = "tf.Const"() {value = dense<{{.*}}> : tensor<2x3x3x2xf32>} : () -> tensor<2x3x3x2xf32> -// CHECK: %[[depthwise_conv:.*]] = "tf.DepthwiseConv2dNative"(%arg0, %[[cst]]) {data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 2, 2, 1]} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x2x2x6xf32> +// CHECK: %[[depthwise_conv:.*]] = "tf.DepthwiseConv2dNative"(%arg0, %[[cst]]) // CHECK: %[[identity:.*]] = "tf.IdentityN"(%[[depthwise_conv]]) {device = ""} : (tensor<1x2x2x6xf32>) -> tensor<1x2x2x6xf32> // CHECK: return %[[identity]] : tensor<1x2x2x6xf32> @@ -97,35 +97,18 @@ func.func @cast_bf16_batch_matmul_v2_to_fp32(%arg0: tensor<1x1x10xf32>) -> (tens // CHECK: %[[identity:.*]] = "tf.IdentityN"(%[[batch_matmul]]) {device = ""} : (tensor<1x1x2xf32>) -> tensor<1x1x2xf32> // CHECK: return %[[identity]] : tensor<1x1x2xf32> -func.func @cast_bf16_gather_v2_to_fp32(%arg0: tensor<1xi64>) -> (tensor<1x3x4x3xf32>) { - %cst = "tf.Const"() {device = "", value = dense<0> : tensor} : () -> tensor - %cst_0 = "tf.Const"() {device = "", value = dense<1.000000e+01> : tensor<1024x3x4x3xbf16>} : () -> tensor<1024x3x4x3xbf16> - %0 = "tf.GatherV2"(%cst_0, %arg0, %cst) {batch_dims = 0 : i64, device = ""} : (tensor<1024x3x4x3xbf16>, tensor<1xi64>, tensor) -> tensor<1x3x4x3xbf16> - %1 = "tf.Cast"(%0) {Truncate = false} : (tensor<1x3x4x3xbf16>) -> tensor<1x3x4x3xf32> - %2 = "tf.IdentityN"(%1) {device = ""} : (tensor<1x3x4x3xf32>) -> tensor<1x3x4x3xf32> - return %2 : tensor<1x3x4x3xf32> -} - -// CHECK: func @cast_bf16_gather_v2_to_fp32 -// CHECK-DAG: %[[cst:.*]] = "tf.Const"() {value = dense<{{.*}}> : tensor<1024x3x4x3xf32>} : () -> tensor<1024x3x4x3xf32> -// CHECK-DAG: %[[cst_0:.*]] = "tf.Const"() {device = "", value = dense<0> : tensor} : () -> tensor -// CHECK: %[[gather:.*]] = "tf.GatherV2"(%[[cst]], %arg0, %[[cst_0]]) {batch_dims = 0 : i64} : (tensor<1024x3x4x3xf32>, tensor<1xi64>, tensor) -> tensor<1x3x4x3xf32> -// CHECK: %[[identity:.*]] = "tf.IdentityN"(%[[gather]]) -// CHECK: return %[[identity]] : tensor<1x3x4x3xf32> - // Tests that an AddV2 op accepting two bf16 operands is transformed into // an AddV2 op that accepts two fp32 operands. -func.func @cast_bf16_add_v2_to_fp32(%arg0: tensor<2xbf16>, %arg1: tensor<2xbf16>) -> tensor<2xbf16> { +func.func @cast_bf16_add_v2_to_fp32(%arg0: tensor<2xbf16>, %arg1: tensor<2xbf16>) -> tensor<2xf32> { %0 = "tf.AddV2"(%arg0, %arg1) : (tensor<2xbf16>, tensor<2xbf16>) -> tensor<2xbf16> - return %0 : tensor<2xbf16> + %1 = "tf.Cast"(%0) {Truncate = false} : (tensor<2xbf16>) -> tensor<2xf32> + return %1 : tensor<2xf32> } // The signature of the function is not changed. -// CHECK: func @cast_bf16_add_v2_to_fp32(%[[ARG_0:.*]]: tensor<2xbf16>, %[[ARG_1:.*]]: tensor<2xbf16>) -> tensor<2xbf16> +// CHECK: func @cast_bf16_add_v2_to_fp32(%[[ARG_0:.*]]: tensor<2xbf16>, %[[ARG_1:.*]]: tensor<2xbf16>) -> tensor<2xf32> // bfloat16 operands are cast to f32 operands. // CHECK-DAG: %[[CAST_0:.*]] = "tf.Cast"(%[[ARG_0]]) {Truncate = false} : (tensor<2xbf16>) -> tensor<2xf32> // CHECK-DAG: %[[CAST_1:.*]] = "tf.Cast"(%[[ARG_1]]) {Truncate = false} : (tensor<2xbf16>) -> tensor<2xf32> // CHECK: %[[ADD:.*]] = "tf.AddV2"(%[[CAST_0]], %[[CAST_1]]) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> -// f32 outputs are cast back to bfloat16. 
-// CHECK: %[[CAST_2:.*]] = "tf.Cast"(%[[ADD]]) {Truncate = false} : (tensor<2xf32>) -> tensor<2xbf16> -// CHECK: return %[[CAST_2]] : tensor<2xbf16> +// CHECK: return %[[ADD]] : tensor<2xf32> diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/convert_tpu_model_to_cpu.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/tests/convert_tpu_model_to_cpu.mlir index cebbba385e5..f7c8c6aaabb 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/tests/convert_tpu_model_to_cpu.mlir +++ b/tensorflow/compiler/mlir/quantization/tensorflow/tests/convert_tpu_model_to_cpu.mlir @@ -19,7 +19,7 @@ func.func private @tpu_func_0_optim0(%arg0: tensor<1x3x4x3xf32>) -> tensor<1x3x2 %2 = "tf.Transpose"(%0, %cst_0) {device = ""} : (tensor<1x3x4x3xbf16>, tensor<4xi32>) -> tensor<1x3x3x4xbf16> %3 = "tf.TPUReplicatedInput"(%2) {device = "", index = -1 : i64, is_mirrored_variable = false, is_packed = false} : (tensor<1x3x3x4xbf16>) -> tensor<1x3x3x4xbf16> %4 = "tf.Transpose"(%3, %cst_1) {_tpu_replicate = "cluster", device = ""} : (tensor<1x3x3x4xbf16>, tensor<4xi32>) -> tensor<1x3x4x3xbf16> - %5 = "tf.Conv2D"(%4, %cst) {data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true} : (tensor<1x3x4x3xbf16>, tensor<2x3x3x2xbf16>) -> tensor<1x3x2x2xbf16> + %5 = "tf.Conv2D"(%4, %cst) {_tpu_replicate = "cluster", data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true} : (tensor<1x3x4x3xbf16>, tensor<2x3x3x2xbf16>) -> tensor<1x3x2x2xbf16> %6 = "tf.TPUReplicatedOutput"(%5) {device = ""} : (tensor<1x3x2x2xbf16>) -> tensor<1x3x2x2xbf16> %7 = "tf.Cast"(%6) {Truncate = false} : (tensor<1x3x2x2xbf16>) -> tensor<1x3x2x2xf32> func.return %7 : tensor<1x3x2x2xf32> @@ -43,9 +43,7 @@ func.func @serving_default(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor // The contents of `@serving_default` should have been inlined to `@batch_func`. 
// CHECK: func.func @serving_default(%[[ARG0:.*]]: tensor<1xf32>, %[[ARG1:.*]]: tensor<1xf32>) -> tensor<1xf32> // CHECK-NOT: tf.BatchFunction -// CHECK: %[[IDENTITY0:.*]] = "tf.Identity"(%[[ARG0]]) -// CHECK: %[[IDENTITY1:.*]] = "tf.Identity"(%[[ARG1]]) -// CHECK: %[[ADD0:.*]] = "tf.AddV2"(%[[IDENTITY0]], %[[IDENTITY1]]) +// CHECK: %[[ADD0:.*]] = "tf.AddV2"(%[[ARG0]], %[[ARG1]]) // CHECK: return %[[ADD0]] : tensor<1xf32> func.func private @batched_func(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<1xf32> { diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/lift_quantizable_spots_as_functions.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/tests/lift_quantizable_spots_as_functions.mlir index debff9b9e26..f61a9fbe9fe 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/tests/lift_quantizable_spots_as_functions.mlir +++ b/tensorflow/compiler/mlir/quantization/tensorflow/tests/lift_quantizable_spots_as_functions.mlir @@ -195,6 +195,35 @@ func.func @float_matmul( // ----- +func.func @float_matmul_with_reshape(%arg0: tensor<1x10xf32>, %arg1: tensor<10x10xf32>) -> (tensor<*xf32>) { + %cst = "tf.Const"() {value = dense<0.000000e+00> : tensor<10xf32>} : () -> tensor<10xf32> + %cst_0 = "tf.Const"() {value = dense<[-1, 10]> : tensor<2xi32>} : () -> tensor<2xi32> + %1 = "tf.MatMul"(%arg0, %arg1) { + transpose_a = false, transpose_b = true + } : (tensor<1x10xf32>, tensor<10x10xf32>) -> tensor<*xf32> + %2 = "tf.Reshape"(%1, %cst_0) : (tensor<*xf32>, tensor<2xi32>) -> tensor<*xf32> + %3 = "tf.BiasAdd"(%2, %cst) {data_format = "NHWC", device = ""} : (tensor<*xf32>, tensor<10xf32>) -> tensor<*xf32> + + func.return %3 : tensor<*xf32> + + +// CHECK-DAG: %[[CONST_0:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<10xf32>} +// CHECK-DAG: %[[SHAPE:.*]] = "tf.Const"() {value = dense<[-1, 10]> : tensor<2xi32>} +// CHECK: %[[PARTITIONEDCALL_0:.*]] = "tf.PartitionedCall"(%arg0, %arg1, %[[CONST_0]], %[[SHAPE]]) +// CHECK-SAME: f = @composite_matmul_with_reshape_and_bias_fn_1} +// CHECK: return %[[PARTITIONEDCALL_0]] +// CHECK: } + +// CHECK-LABEL: private @composite_matmul_with_reshape_and_bias_fn_1 +// CHECK-NEXT: tf.MatMul"(%arg0, %arg1) +// CHECK-SAME: attr_map = "0:transpose_a,1:transpose_b" +// CHECK-NEXT: tf.Reshape +// CHECK-NEXT: tf.BiasAdd +// CHECK-NEXT: return +} + +// ----- + // CHECK-LABEL: float_conv_no_bias func.func @float_conv_no_bias(%arg0: tensor<1x3x4x3xf32>, %arg1: tensor<2x3x3x2xf32>) -> (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>) { %0 = "tf.Conv2D"(%arg0, %arg1) { diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/prepare_lifting.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/tests/prepare_lifting.mlir index 9d0a807aa52..ee97c375ba9 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/tests/prepare_lifting.mlir +++ b/tensorflow/compiler/mlir/quantization/tensorflow/tests/prepare_lifting.mlir @@ -1,4 +1,5 @@ -// RUN: tf-quant-opt %s -quant-prepare-lifting | FileCheck %s +// RUN: tf-quant-opt %s -quant-prepare-lifting -split-input-file | FileCheck %s +// RUN: tf-quant-opt %s -quant-prepare-lifting='target-opset=XLA' | FileCheck --check-prefix=XLA-CHECK %s func.func @decompose_batch_norm(%arg0: tensor<*xf32>) -> (tensor<*xf32>) { %cst = "tf.Const"() {value = dense<1.000000e+00> : tensor<2xf32>} : () -> tensor<2xf32> @@ -13,6 +14,8 @@ func.func @decompose_batch_norm(%arg0: tensor<*xf32>) -> (tensor<*xf32>) { // CHECK: %[[add:.*]] = "tf.AddV2"(%[[mul]], %[[CONST]]) : (tensor<*xf32>, tensor<2xf32>) -> 
tensor<*xf32> // CHECK-NEXT: return %[[add]] : tensor<*xf32> +// ----- + func.func @not_decompose_batch_norm(%arg0: tensor<*xf32>) -> (tensor<*xf32>) { %cst = "tf.Const"() {value = dense<1.000000e+00> : tensor<2xf32>} : () -> tensor<2xf32> %cst_0 = "tf.Const"() {value = dense<0.500000e+00> : tensor<2xf32>} : () -> tensor<2xf32> @@ -25,6 +28,8 @@ func.func @not_decompose_batch_norm(%arg0: tensor<*xf32>) -> (tensor<*xf32>) { // CHECK: %[[bn:.*]], %batch_mean, %batch_variance, %reserve_space_1, %reserve_space_2, %reserve_space_3 = "tf.FusedBatchNormV3"(%arg0, %[[CONST]], %[[CONST_0]], %[[CONST_0]], %[[CONST]]) {data_format = "NHWC", device = "", epsilon = 9.99999974E-5 : f32, exponential_avg_factor = 1.000000e+00 : f32, is_training = true} : (tensor<*xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>) -> (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>, tensor<*xf32>, tensor<*xf32>, tensor<*xf32>) // CHECK-NEXT: return %[[bn]] : tensor<*xf32> +// ----- + func.func @convert_add_to_biasadd(%arg0: tensor<1x3x4x3xf32>) -> (tensor<1x3x2x2xf32>) { %cst = "tf.Const"() {value = dense<1.000000e+00> : tensor<2x3x3x2xf32>} : () -> tensor<2x3x3x2xf32> %cst_0 = "tf.Const"() {value = dense<0.500000e+00> : tensor<2xf32>} : () -> tensor<2xf32> @@ -39,6 +44,8 @@ func.func @convert_add_to_biasadd(%arg0: tensor<1x3x4x3xf32>) -> (tensor<1x3x2x2 // CHECK-NEXT: %[[BIASADD:.*]] = "tf.BiasAdd"(%[[CONV2D]], %[[CONST_0]]) {data_format = "NHWC"} : (tensor<1x3x2x2xf32>, tensor<2xf32>) -> tensor<1x3x2x2xf32> // CHECK-NEXT: return %[[BIASADD]] : tensor<1x3x2x2xf32> +// ----- + func.func @not_convert_add_to_biasadd(%arg0: tensor<1x3x4x3xf32>) -> (tensor<1x3x2x3xf32>) { %cst = "tf.Const"() {value = dense<1.000000e+00> : tensor<2x3x3x3xf32>} : () -> tensor<2x3x3x3xf32> %cst_0 = "tf.Const"() {value = dense<0.500000e+00> : tensor<1x3x2x3xf32>} : () -> tensor<1x3x2x3xf32> @@ -53,6 +60,8 @@ func.func @not_convert_add_to_biasadd(%arg0: tensor<1x3x4x3xf32>) -> (tensor<1x3 // CHECK-NEXT: %[[ADD:.*]] = "tf.AddV2"(%[[CONV2D]], %[[CONST_0]]) : (tensor<1x3x2x3xf32>, tensor<1x3x2x3xf32>) -> tensor<1x3x2x3xf32> // CHECK-NEXT: return %[[ADD]] : tensor<1x3x2x3xf32> +// ----- + func.func @fuse_conv2d_and_mul(%arg0: tensor<1x3x4x3xf32>) -> (tensor<1x3x2x2xf32>) { %cst = "tf.Const"() {value = dense<2.000000e+00> : tensor<2x3x3x2xf32>} : () -> tensor<2x3x3x2xf32> %cst_0 = "tf.Const"() {value = dense<0.400000e+00> : tensor<2xf32>} : () -> tensor<2xf32> @@ -65,6 +74,8 @@ func.func @fuse_conv2d_and_mul(%arg0: tensor<1x3x4x3xf32>) -> (tensor<1x3x2x2xf3 // CHECK-NEXT: %[[CONV2D:.*]] = "tf.Conv2D"(%arg0, %[[CONST]]) {data_format = "NHWC", dilations = [1, 1, 2, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x2x2xf32> // CHECK-NEXT: return %[[CONV2D]] : tensor<1x3x2x2xf32> +// ----- + func.func @not_fuse_conv2d_and_mul(%arg0: tensor<1x3x4x3xf32>) -> (tensor<1x3x2x2xf32>) { %cst = "tf.Const"() {value = dense<2.000000e+00> : tensor<2x3x3x2xf32>} : () -> tensor<2x3x3x2xf32> %cst_0 = "tf.Const"() {value = dense<0.400000e+00> : tensor<2x2xf32>} : () -> tensor<2x2xf32> @@ -79,6 +90,8 @@ func.func @not_fuse_conv2d_and_mul(%arg0: tensor<1x3x4x3xf32>) -> (tensor<1x3x2x // CHECK-NEXT: %[[ADD:.*]] = "tf.Mul"(%[[CONV2D]], %[[CONST_0]]) : (tensor<1x3x2x2xf32>, tensor<2x2xf32>) -> tensor<1x3x2x2xf32> // CHECK-NEXT: return %[[ADD]] : tensor<1x3x2x2xf32> +// ----- + func.func @fuse_conv2d_with_bias_and_mul(%arg0: tensor<1x3x4x3xf32>) -> 
(tensor<1x3x2x2xf32>) { %cst = "tf.Const"() {value = dense<2.000000e+00> : tensor<2x3x3x2xf32>} : () -> tensor<2x3x3x2xf32> %cst_0 = "tf.Const"() {value = dense<0.400000e+00> : tensor<2xf32>} : () -> tensor<2xf32> @@ -95,6 +108,8 @@ func.func @fuse_conv2d_with_bias_and_mul(%arg0: tensor<1x3x4x3xf32>) -> (tensor< // CHECK-NEXT: %[[BIASADD:.*]] = "tf.BiasAdd"(%[[CONV2D]], %[[CONST_0]]) {data_format = "NHWC"} : (tensor<1x3x2x2xf32>, tensor<2xf32>) -> tensor<1x3x2x2xf32> // CHECK-NEXT: return %[[BIASADD]] : tensor<1x3x2x2xf32> +// ----- + func.func @not_fuse_conv2d_with_bias_and_mul(%arg0: tensor<1x3x4x3xf32>) -> (tensor<1x3x2x2xf32>, tensor<1x3x2x2xf32>) { %cst = "tf.Const"() {value = dense<2.000000e+00> : tensor<2x3x3x2xf32>} : () -> tensor<2x3x3x2xf32> %cst_0 = "tf.Const"() {value = dense<0.400000e+00> : tensor<2xf32>} : () -> tensor<2xf32> @@ -113,6 +128,8 @@ func.func @not_fuse_conv2d_with_bias_and_mul(%arg0: tensor<1x3x4x3xf32>) -> (ten // CHECK-NEXT: %[[MUL:.*]] = "tf.Mul"(%[[CONV2D]], %[[CONST_1]]) : (tensor<1x3x2x2xf32>, tensor<2xf32>) -> tensor<1x3x2x2xf32> // CHECK-NEXT: return %[[BIASADD]], %[[MUL]] : tensor<1x3x2x2xf32>, tensor<1x3x2x2xf32> +// ----- + func.func @fuse_conv2d_with_bias_and_add(%arg0: tensor<1x3x4x3xf32>) -> (tensor<1x3x2x2xf32>) { %cst = "tf.Const"() {value = dense<2.000000e+00> : tensor<2x3x3x2xf32>} : () -> tensor<2x3x3x2xf32> %cst_0 = "tf.Const"() {value = dense<0.500000e+00> : tensor<2xf32>} : () -> tensor<2xf32> @@ -129,6 +146,8 @@ func.func @fuse_conv2d_with_bias_and_add(%arg0: tensor<1x3x4x3xf32>) -> (tensor< // CHECK-NEXT: %[[BIASADD:.*]] = "tf.BiasAdd"(%[[CONV2D]], %[[CONST_0]]) {data_format = "NHWC"} : (tensor<1x3x2x2xf32>, tensor<2xf32>) -> tensor<1x3x2x2xf32> // CHECK-NEXT: return %[[BIASADD]] : tensor<1x3x2x2xf32> +// ----- + func.func @not_fuse_conv2d_with_bias_and_add(%arg0: tensor<1x3x4x3xf32>, %arg1: tensor<2xf32>) -> (tensor<1x3x2x2xf32>) { %cst = "tf.Const"() {value = dense<2.000000e+00> : tensor<2x3x3x2xf32>} : () -> tensor<2x3x3x2xf32> %cst_0 = "tf.Const"() {value = dense<0.400000e+00> : tensor<2xf32>} : () -> tensor<2xf32> @@ -145,6 +164,8 @@ func.func @not_fuse_conv2d_with_bias_and_add(%arg0: tensor<1x3x4x3xf32>, %arg1: // CHECK-NEXT: %[[ADD:.*]] = "tf.AddV2"(%[[BIASADD]], %arg1) : (tensor<1x3x2x2xf32>, tensor<2xf32>) -> tensor<1x3x2x2xf32> // CHECK-NEXT: return %[[ADD]] : tensor<1x3x2x2xf32> +// ----- + func.func @match_depthwise_conv2d_and_add(%arg0: tensor<*xf32>) -> (tensor<*xf32>) { %cst = "tf.Const"() {value = dense<2.000000e+00> : tensor<2x3x3x1xf32>} : () -> tensor<2x3x3x1xf32> %cst_0 = "tf.Const"() {value = dense<0.400000e+00> : tensor<3xf32>} : () -> tensor<3xf32> @@ -159,6 +180,8 @@ func.func @match_depthwise_conv2d_and_add(%arg0: tensor<*xf32>) -> (tensor<*xf32 // CHECK-NEXT: %[[BIASADD:.*]] = "tf.BiasAdd"(%[[DEPTHWISE_CONV2D]], %[[CONST_0]]) {data_format = "NHWC"} : (tensor, tensor<3xf32>) -> tensor<*xf32> // CHECK-NEXT: return %[[BIASADD]] : tensor<*xf32> +// ----- + func.func @match_depthwise_conv2d_and_mul(%arg0: tensor<*xf32>) -> (tensor) { %cst = "tf.Const"() {value = dense<2.000000e+00> : tensor<2x3x3x1xf32>} : () -> tensor<2x3x3x1xf32> %cst_0 = "tf.Const"() {value = dense<0.400000e+00> : tensor<3xf32>} : () -> tensor<3xf32> @@ -171,6 +194,8 @@ func.func @match_depthwise_conv2d_and_mul(%arg0: tensor<*xf32>) -> (tensor, tensor<2x3x3x1xf32>) -> tensor // CHECK-NEXT: return %[[DEPTHWISE_CONV2D]] : tensor +// ----- + func.func @match_depthwise_conv2d_with_bias_and_add(%arg0: tensor<*xf32>) -> (tensor) { %cst = "tf.Const"() 
{value = dense<2.000000e+00> : tensor<2x3x3x1xf32>} : () -> tensor<2x3x3x1xf32> %cst_0 = "tf.Const"() {value = dense<0.400000e+00> : tensor<3xf32>} : () -> tensor<3xf32> @@ -187,6 +212,8 @@ func.func @match_depthwise_conv2d_with_bias_and_add(%arg0: tensor<*xf32>) -> (te // CHECK-NEXT: %[[BIASADD:.*]] = "tf.BiasAdd"(%[[DEPTHWISE_CONV2D]], %[[CONST_0]]) {data_format = "NHWC"} : (tensor, tensor<3xf32>) -> tensor // CHECK-NEXT: return %[[BIASADD]] : tensor +// ----- + func.func @match_depthwise_conv2d_with_bias_and_mul(%arg0: tensor<*xf32>) -> (tensor) { %cst = "tf.Const"() {value = dense<2.000000e+00> : tensor<2x3x3x1xf32>} : () -> tensor<2x3x3x1xf32> %cst_0 = "tf.Const"() {value = dense<0.400000e+00> : tensor<3xf32>} : () -> tensor<3xf32> @@ -203,6 +230,8 @@ func.func @match_depthwise_conv2d_with_bias_and_mul(%arg0: tensor<*xf32>) -> (te // CHECK-NEXT: %[[BIASADD:.*]] = "tf.BiasAdd"(%[[DEPTHWISE_CONV2D]], %[[CONST_0]]) {data_format = "NHWC"} : (tensor, tensor<3xf32>) -> tensor // CHECK-NEXT: return %[[BIASADD]] : tensor +// ----- + func.func @lower_einsum(%arg0: tensor<3x4x5xf32>, %arg1: tensor<3x5x6xf32>) -> tensor<3x4x6xf32> { %0 = "tf.Einsum"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", equation = "ijk,ikm->ijm"}: (tensor<3x4x5xf32>, tensor<3x5x6xf32>) -> tensor<3x4x6xf32> func.return %0 : tensor<3x4x6xf32> @@ -210,6 +239,7 @@ func.func @lower_einsum(%arg0: tensor<3x4x5xf32>, %arg1: tensor<3x5x6xf32>) -> t // CHECK-LABEL: lower_einsum // CHECK: "tf.BatchMatMulV2"(%arg0, %arg1) {adj_x = false, adj_y = false} : (tensor<3x4x5xf32>, tensor<3x5x6xf32>) -> tensor<3x4x6xf32> +// ----- func.func @removing_identity_after_const(%arg0: tensor<*xf32>) -> (tensor<*xf32>) { %cst = "tf.Const"() {value = dense<2.000000e+00> : tensor<2x3x3x1xf32>} : () -> tensor<2x3x3x1xf32> @@ -225,6 +255,8 @@ func.func @removing_identity_after_const(%arg0: tensor<*xf32>) -> (tensor<*xf32> // CHECK-DAG: %[[CONST:.*]] = "tf.Const"() {value = dense<1.000000e+00> : tensor<2x3x3x1xf32>} : () -> tensor<2x3x3x1xf32> // CHECK: %[[DEPTHWISE_CONV2D:.*]] = "tf.DepthwiseConv2dNative"(%arg0, %[[CONST]]) +// ----- + func.func @not_removing_identity_of_returning_value(%arg0: tensor<*xf32>) -> (tensor<*xf32>) { %cst = "tf.Const"() {value = dense<2.000000e+00> : tensor<2x3x3x1xf32>} : () -> tensor<2x3x3x1xf32> %cst_0 = "tf.Const"() {value = dense<0.400000e+00> : tensor<3xf32>} : () -> tensor<3xf32> @@ -239,22 +271,24 @@ func.func @not_removing_identity_of_returning_value(%arg0: tensor<*xf32>) -> (te // CHECK: %[[identity:.*]] = "tf.Identity" // CHECK: return %[[identity]] : tensor<*xf32> +// ----- + func.func @batch_norm_with_q_dq(%arg0: tensor<1x3x4x3xf32>) -> (tensor<1x3x2x2xf32>) { - %cst = "tf.Const"() {device = "", value = dense<1.000000e+00> : tensor<2xf32>} : () -> tensor<2xf32> - %cst_0 = "tf.Const"() {device = "", value = dense<5.000000e-01> : tensor<2xf32>} : () -> tensor<2xf32> - %cst_1 = "tf.Const"() {device = "", value = dense<5.000000e-01> : tensor<2x3x3x2xf32>} : () -> tensor<2x3x3x2xf32> - %0 = "quantfork.qcast"(%cst_1) : (tensor<2x3x3x2xf32>) -> tensor<2x3x3x2x!quant.uniform:f32:3, {0.003937007874015748,0.003937007874015748}>> - %1 = "quantfork.dcast"(%0) : (tensor<2x3x3x2x!quant.uniform:f32:3, {0.003937007874015748,0.003937007874015748}>>) -> tensor<2x3x3x2xf32> - %2 = "quantfork.qcast"(%arg0) : (tensor<1x3x4x3xf32>) -> tensor<1x3x4x3x!quant.uniform> - %3 = "quantfork.dcast"(%2) : (tensor<1x3x4x3x!quant.uniform>) -> tensor<1x3x4x3xf32> - %4 = "tf.Conv2D"(%3, %1) {data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], 
explicit_paddings = [], padding = "SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x2x2xf32> - %y, %batch_mean, %batch_variance, %reserve_space_1, %reserve_space_2, %reserve_space_3 = "tf.FusedBatchNormV3"(%4, %cst, %cst_0, %cst, %cst_0) {data_format = "NHWC", device = "", epsilon = 9.99999974E-5 : f32, exponential_avg_factor = 1.000000e+00 : f32, is_training = false} : (tensor<1x3x2x2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>) -> (tensor<1x3x2x2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<*xf32>) - %5 = "tf.Relu6"(%y) {device = ""} : (tensor<1x3x2x2xf32>) -> tensor<1x3x2x2xf32> - %6 = "quantfork.qcast"(%5) : (tensor<1x3x2x2xf32>) -> tensor<1x3x2x2x!quant.uniform:f32:3, {0.0026771653824903836:-60,0.0032283464285332388:-28}>> - %7 = "quantfork.dcast"(%6) : (tensor<1x3x2x2x!quant.uniform:f32:3, {0.0026771653824903836:-60,0.0032283464285332388:-28}>>) -> tensor<1x3x2x2xf32> - %8 = "tf.Identity"(%7) {device = ""} : (tensor<1x3x2x2xf32>) -> tensor<1x3x2x2xf32> - %9 = "tf.Identity"(%8) {device = ""} : (tensor<1x3x2x2xf32>) -> tensor<1x3x2x2xf32> - return %9 : tensor<1x3x2x2xf32> + %cst = "tf.Const"() {device = "", value = dense<1.000000e+00> : tensor<2xf32>} : () -> tensor<2xf32> + %cst_0 = "tf.Const"() {device = "", value = dense<5.000000e-01> : tensor<2xf32>} : () -> tensor<2xf32> + %cst_1 = "tf.Const"() {device = "", value = dense<5.000000e-01> : tensor<2x3x3x2xf32>} : () -> tensor<2x3x3x2xf32> + %0 = "quantfork.qcast"(%cst_1) : (tensor<2x3x3x2xf32>) -> tensor<2x3x3x2x!quant.uniform:f32:3, {0.003937007874015748,0.003937007874015748}>> + %1 = "quantfork.dcast"(%0) : (tensor<2x3x3x2x!quant.uniform:f32:3, {0.003937007874015748,0.003937007874015748}>>) -> tensor<2x3x3x2xf32> + %2 = "quantfork.qcast"(%arg0) : (tensor<1x3x4x3xf32>) -> tensor<1x3x4x3x!quant.uniform> + %3 = "quantfork.dcast"(%2) : (tensor<1x3x4x3x!quant.uniform>) -> tensor<1x3x4x3xf32> + %4 = "tf.Conv2D"(%3, %1) {data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x2x2xf32> + %y, %batch_mean, %batch_variance, %reserve_space_1, %reserve_space_2, %reserve_space_3 = "tf.FusedBatchNormV3"(%4, %cst, %cst_0, %cst, %cst_0) {data_format = "NHWC", device = "", epsilon = 9.99999974E-5 : f32, exponential_avg_factor = 1.000000e+00 : f32, is_training = false} : (tensor<1x3x2x2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>) -> (tensor<1x3x2x2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<*xf32>) + %5 = "tf.Relu6"(%y) {device = ""} : (tensor<1x3x2x2xf32>) -> tensor<1x3x2x2xf32> + %6 = "quantfork.qcast"(%5) : (tensor<1x3x2x2xf32>) -> tensor<1x3x2x2x!quant.uniform:f32:3, {0.0026771653824903836:-60,0.0032283464285332388:-28}>> + %7 = "quantfork.dcast"(%6) : (tensor<1x3x2x2x!quant.uniform:f32:3, {0.0026771653824903836:-60,0.0032283464285332388:-28}>>) -> tensor<1x3x2x2xf32> + %8 = "tf.Identity"(%7) {device = ""} : (tensor<1x3x2x2xf32>) -> tensor<1x3x2x2xf32> + %9 = "tf.Identity"(%8) {device = ""} : (tensor<1x3x2x2xf32>) -> tensor<1x3x2x2xf32> + return %9 : tensor<1x3x2x2xf32> } // CHECK: func @batch_norm_with_q_dq @@ -267,3 +301,80 @@ func.func @batch_norm_with_q_dq(%arg0: tensor<1x3x4x3xf32>) -> (tensor<1x3x2x2xf // CHECK: %[[conv:.*]] = "tf.Conv2D"(%[[dq_input]], %[[dq_weight]]) // CHECK: %[[bias:.*]] = "tf.BiasAdd"(%[[conv]], 
%[[cst_0]]) {data_format = "NHWC"} // CHECK: %[[relu6:.*]] = "tf.Relu6"(%[[bias]]) + +// ----- + +func.func @xla_dot_v2(%arg0: tensor, %arg1: tensor<3x4x5xf32>) -> (tensor) { + %0 = "tf.XlaDotV2"(%arg0, %arg1) {device = "", dimension_numbers = "\0A\01\02\12\01\00", precision_config = ""} : (tensor, tensor<3x4x5xf32>) -> tensor + func.return %0 : tensor +} + +// CHECK: func @xla_dot_v2 +// CHECK-DAG: %[[cst:.*]] = "tf.Const"() {value = dense<[3, 20]> : tensor<2xi64>} : () -> tensor<2xi64> +// CHECK-DAG: %[[cst_0:.*]] = "tf.Const"() {value = dense<[-1, 2, 4, 5]> : tensor<4xi64>} : () -> tensor<4xi64> +// CHECK: %[[reshape:.*]] = "tf.Reshape"(%arg1, %[[cst]]) : (tensor<3x4x5xf32>, tensor<2xi64>) -> tensor<3x20xf32> +// CHECK: %[[batch_matmul:.*]] = "tf.BatchMatMulV2"(%arg0, %[[reshape]]) {adj_x = false, adj_y = false} : (tensor, tensor<3x20xf32>) -> tensor +// CHECK: %[[reshape_0:.*]] = "tf.Reshape"(%[[batch_matmul]], %[[cst_0]]) : (tensor, tensor<4xi64>) -> tensor +// CHECK: return %[[reshape_0]] : tensor + +// XLA-CHECK: func @xla_dot_v2 +// XLA-CHECK: %[[einsum:.*]] = "tf.Einsum"(%arg0, %arg1) {equation = "abc,cde->abde"} : (tensor, tensor<3x4x5xf32>) -> tensor +// XLA-CHECK: return %[[einsum]] : tensor + +// ----- + +// dimension_numbers: { +// offset_dims: 0 +// collapsed_slice_dims: 1 +// start_index_map: 1 +// } +func.func @xla_gather(%arg0: tensor, %arg1: tensor<1xi32>, %arg2: tensor<2xi32>) -> tensor<*xf32> { + %0 = "tf.XlaGather"(%arg0, %arg1, %arg2) {device = "", dimension_numbers = "\0A\01\00\12\01\01\1A\01\01", indices_are_sorted = true} : (tensor, tensor<1xi32>, tensor<2xi32>) -> tensor<*xf32> + func.return %0 : tensor<*xf32> +} + +// CHECK: func @xla_gather +// CHECK-DAG: %[[cst:.*]] = "tf.Const"() {value = dense<0> : tensor<2xi64>} : () -> tensor<2xi64> +// CHECK-DAG: %[[cst_0:.*]] = "tf.Const"() {value = dense<1> : tensor<1x1xi64>} : () -> tensor<1x1xi64> +// CHECK-DAG: %[[cst_1:.*]] = "tf.Const"() {value = dense<-1> : tensor<1xi64>} : () -> tensor<1xi64> +// CHECK: %[[arg1_i64:.*]] = "tf.Cast"(%arg1) {Truncate = false} : (tensor<1xi32>) -> tensor<1xi64> +// CHECK: %[[tensor_scatter_update:.*]] = "tf.TensorScatterUpdate"(%[[cst]], %[[cst_0]], %[[arg1_i64]]) : (tensor<2xi64>, tensor<1x1xi64>, tensor<1xi64>) -> tensor<2xi64> +// CHECK: %[[arg2_i64:.*]] = "tf.Cast"(%arg2) {Truncate = false} : (tensor<2xi32>) -> tensor<2xi64> +// CHECK: %[[slice:.*]] = "tf.Slice"(%arg0, %[[tensor_scatter_update]], %[[arg2_i64]]) : (tensor, tensor<2xi64>, tensor<2xi64>) -> tensor<*xf32> +// CHECK: %[[reshape:.*]] = "tf.Reshape"(%[[slice]], %[[cst_1]]) : (tensor<*xf32>, tensor<1xi64>) -> tensor<*xf32> +// CHECK: return %[[reshape]] : tensor<*xf32> + +// ----- + +// Tests that the converted `tf.Slice` has the correct number of dimensions +// when the output shape is known (`tensor` instead of `tensor<*xi32>`). 
+ +func.func @xla_gather_known_output_shape(%arg0: tensor<5xi32>, %arg1: tensor<1xi32>, %arg2: tensor<1xi32>) -> tensor { + // dimension_numbers: { + // collapsed_slice_dims: 0 + // start_index_map: 0 + // } + %0 = "tf.XlaGather"(%arg0, %arg1, %arg2) {device = "", dimension_numbers = "\12\01\00\1A\01\00", indices_are_sorted = true} : (tensor<5xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + func.return %0 : tensor +} + +// CHECK: func @xla_gather_known_output_shape +// CHECK-DAG: %[[cst:.*]] = "tf.Const"() {value = dense<0> : tensor<1xi64>} : () -> tensor<1xi64> +// CHECK-DAG: %[[cst_0:.*]] = "tf.Const"() {value = dense<0> : tensor<1x1xi64>} : () -> tensor<1x1xi64> +// CHECK-DAG: %[[cst_1:.*]] = "tf.Const"() {value = dense<> : tensor<0xi64>} : () -> tensor<0xi64> +// CHECK: %[[arg1_i64:.*]] = "tf.Cast"(%arg1) {Truncate = false} : (tensor<1xi32>) -> tensor<1xi64> +// CHECK: %[[tensor_scatter_update:.*]] = "tf.TensorScatterUpdate"(%[[cst]], %[[cst_0]], %[[arg1_i64]]) : (tensor<1xi64>, tensor<1x1xi64>, tensor<1xi64>) -> tensor<1xi64> +// CHECK: %[[arg2_i64:.*]] = "tf.Cast"(%arg2) {Truncate = false} : (tensor<1xi32>) -> tensor<1xi64> +// CHECK: %[[slice:.*]] = "tf.Slice"(%arg0, %[[tensor_scatter_update]], %[[arg2_i64]]) : (tensor<5xi32>, tensor<1xi64>, tensor<1xi64>) -> tensor<1xi32> +// CHECK: %[[reshape:.*]] = "tf.Reshape"(%[[slice]], %[[cst_1]]) : (tensor<1xi32>, tensor<0xi64>) -> tensor +// CHECK: return %[[reshape]] : tensor + +// ----- + +func.func @replace_checknumerics_to_identity(%arg0: tensor<*xf32>) -> tensor<*xf32> { + %0 = "tf.CheckNumerics"(%arg0) {device = "", message = "transformer"} : (tensor<*xf32>) -> tensor<*xf32> + func.return %0 : tensor<*xf32> +} + +// CHECK: func @replace_checknumerics_to_identity +// CHECK: %[[out:.*]] = "tf.Identity"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> \ No newline at end of file diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/quantize.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/tests/quantize.mlir index 10bedcff581..d04ec262f6f 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/tests/quantize.mlir +++ b/tensorflow/compiler/mlir/quantization/tensorflow/tests/quantize.mlir @@ -19,8 +19,8 @@ func.func private @conv(%input: tensor<1x3x4x3xf32> {tf._user_specified_name = " func.return %dq_res : tensor<*xf32> } -// CHECK-DAG: [[bias:%.+]] = "arith.constant"() {value = dense<[7.11401462, 7.05456924]> : tensor<2xf32>} : () -> tensor<2xf32> -// CHECK-DAG: [[weight:%.+]] = "arith.constant"() {value = dense_resource<__elided__> : tensor<2x3x3x2xf32>} : () -> tensor<2x3x3x2x!quant.uniform> +// CHECK-DAG: [[bias:%.+]] = "arith.constant"() <{value = dense<[7.11401462, 7.05456924]> : tensor<2xf32>}> : () -> tensor<2xf32> +// CHECK-DAG: [[weight:%.+]] = "arith.constant"() <{value = dense_resource<__elided__> : tensor<2x3x3x2xf32>}> : () -> tensor<2x3x3x2x!quant.uniform> // CHECK: [[q_input:%.+]] = "quantfork.qcast"(%arg0) : (tensor<1x3x4x3xf32>) -> tensor<1x3x4x3x!quant.uniform> // CHECK-NEXT: [[q_bias:%.+]] = "quantfork.qcast"([[bias]]) : (tensor<2xf32>) -> tensor<2x!quant.uniform> // CHECK-NEXT: [[conv:%.+]] = "tf.PartitionedCall"([[q_input]], [[weight]], [[q_bias]]) {_tfl_quant_trait = "fully_quantizable", config = "", config_proto = "", executor_type = "", f = @[[composite_fn:composite_conv2d_with_bias_and_relu6_fn.*]]} : (tensor<1x3x4x3x!quant.uniform>, tensor<2x3x3x2x!quant.uniform>, tensor<2x!quant.uniform>) -> tensor<*x!quant.uniform> diff --git 
a/tensorflow/compiler/mlir/quantization/tensorflow/tests/quantize_composite_functions_xla.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/tests/quantize_composite_functions_xla.mlir index 5ba40e0eb1d..663b2efd580 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/tests/quantize_composite_functions_xla.mlir +++ b/tensorflow/compiler/mlir/quantization/tensorflow/tests/quantize_composite_functions_xla.mlir @@ -139,3 +139,59 @@ module { // CHECK: Number of quantize layers added: 1 // CHECK: Number of dequantize layers added: 1 } + +// ----- + +module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 1219 : i32}, tf_saved_model.semantics} { + func.func @embedding_with_one_float_conv_and_one_quantized_conv(%arg0: tensor<1xi32> {tf_saved_model.index_path = ["input"]}) -> (tensor<1x3x1x1xf32> {tf_saved_model.index_path = ["output"]}) attributes {tf.entry_function = {control_outputs = "", inputs = "serving_default_input:0", outputs = "PartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} { + + %cst = "tf.Const"() {value = dense<0.000000e+00> : tensor<3x3x1024x1xf32>} : () -> tensor<3x3x1024x1xf32> + %cst_0 = "tf.Const"() {value = dense<0.000000e+00> : tensor<1024x3x4x3xf32>} : () -> tensor<1024x3x4x3xf32> + %cst_1 = "tf.Const"() {value = dense<0> : tensor} : () -> tensor + %cst_2 = "tf.Const"() {value = dense<0.000000e+00> : tensor<2x3x3x1024xf32>} : () -> tensor<2x3x3x1024xf32> + + %0 = "tf.PartitionedCall"(%cst_0, %arg0, %cst_1) {_tfl_quant_trait = "fully_quantizable", config = "", config_proto = "", executor_type = "", f = @composite_gather_fn_1} : (tensor<1024x3x4x3xf32>, tensor<1xi32>, tensor) -> tensor<1x3x4x3xf32> + %1 = "tf.PartitionedCall"(%0, %cst_2) {_tfl_quant_trait = "fully_quantizable", config = "", config_proto = "", executor_type = "", f = @composite_conv2d_fn_2} : (tensor<1x3x4x3xf32>, tensor<2x3x3x1024xf32>) -> tensor<1x3x2x1024xf32> + %2 = "quantfork.qcast"(%1) : (tensor<1x3x2x1024xf32>) -> tensor<1x3x2x1024x!quant.uniform> + %3 = "quantfork.dcast"(%2) : (tensor<1x3x2x1024x!quant.uniform>) -> tensor<1x3x2x1024xf32> + %4 = "tf.PartitionedCall"(%3, %cst) {_tfl_quant_trait = "fully_quantizable", config = "", config_proto = "", executor_type = "", f = @composite_conv2d_fn_1} : (tensor<1x3x2x1024xf32>, tensor<3x3x1024x1xf32>) -> tensor<1x3x1x1xf32> + %5 = "quantfork.qcast"(%4) : (tensor<1x3x1x1xf32>) -> tensor<1x3x1x1x!quant.uniform> + %6 = "quantfork.dcast"(%5) : (tensor<1x3x1x1x!quant.uniform>) -> tensor<1x3x1x1xf32> + return %6 : tensor<1x3x1x1xf32> + } + func.func private @composite_gather_fn_1(%arg0: tensor<1024x3x4x3xf32>, %arg1: tensor<1xi32>, %arg2: tensor) -> tensor<1x3x4x3xf32> attributes {tf_quant.composite_function} { + %0 = "tf.GatherV2"(%arg0, %arg1, %arg2) {attr_map = "0:batch_dims", batch_dims = 0 : i64, device = ""} : (tensor<1024x3x4x3xf32>, tensor<1xi32>, tensor) -> tensor<1x3x4x3xf32> + return %0 : tensor<1x3x4x3xf32> + } + func.func private @composite_conv2d_fn_2(%arg0: tensor<1x3x4x3xf32>, %arg1: tensor<2x3x3x1024xf32>) -> tensor<1x3x2x1024xf32> attributes {tf_quant.composite_function} { + %0 = "tf.Conv2D"(%arg0, %arg1) {attr_map = "0:strides,1:use_cudnn_on_gpu,2:padding,3:explicit_paddings,4:dilations", data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true} : (tensor<1x3x4x3xf32>, tensor<2x3x3x1024xf32>) -> tensor<1x3x2x1024xf32> + return %0 : tensor<1x3x2x1024xf32> + } + func.func private 
@composite_conv2d_fn_1(%arg0: tensor<1x3x2x1024xf32>, %arg1: tensor<3x3x1024x1xf32>) -> tensor<1x3x1x1xf32> attributes {tf_quant.composite_function} { + %0 = "tf.Conv2D"(%arg0, %arg1) {attr_map = "0:strides,1:use_cudnn_on_gpu,2:padding,3:explicit_paddings,4:dilations", data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true} : (tensor<1x3x2x1024xf32>, tensor<3x3x1024x1xf32>) -> tensor<1x3x1x1xf32> + return %0 : tensor<1x3x1x1xf32> + } + +// CHECK-LABEL: func @embedding_with_one_float_conv_and_one_quantized_conv + +// CHECK: %[[quantized_gather:.*]] = "tf.PartitionedCall"( +// CHECK-SAME: f = @quantized_gather_float_output_fn_0 +// CHECK: %[[float_conv:.*]] = "tf.PartitionedCall"(%[[quantized_gather]] +// CHECK-SAME: f = @composite_conv2d_fn_2 +// CHECK: %[[quantize:.*]] = "tf.PartitionedCall"(%[[float_conv]] +// CHECK-SAME: f = @quantize_i8 +// CHECK: %[[quantized_conv:.*]] = "tf.PartitionedCall"(%[[quantize]] +// CHECK-SAME: f = @quantized_conv2d_float_output_fn_0 + +// CHECK: -------- Quantization Summary -------- +// CHECK: Number of quantized layers in the model +// CHECK: -------------------------------- +// CHECK: Name Count/Total +// CHECK: ================================ +// CHECK: Gather 1/1 +// CHECK: Conv2D 1/2 + +// CHECK: Number of quantized layers with quantized outputs: 0/2 +// CHECK: Number of quantize layers added: 1 +// CHECK: Number of dequantize layers added: 0 +} diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/quantize_drq.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/tests/quantize_drq.mlir index abe0c997195..c500b3c72e8 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/tests/quantize_drq.mlir +++ b/tensorflow/compiler/mlir/quantization/tensorflow/tests/quantize_drq.mlir @@ -13,7 +13,7 @@ module { return %0 : tensor<*xf32> } -// CHECK: %[[cst:.*]] = "arith.constant"() {value = dense<0.000000e+00> : tensor<2x1024xf32>} : () -> tensor<2x1024xf32> +// CHECK: %[[cst:.*]] = "arith.constant"() <{value = dense<0.000000e+00> : tensor<2x1024xf32>}> : () -> tensor<2x1024xf32> // CHECK: %[[q_cst:.*]] = "quantfork.qcast"(%[[cst]]) : (tensor<2x1024xf32>) -> tensor<2x1024x!quant.uniform:f32, 3.9370078740157481E-9>> // CHECK: %[[out:.*]] = "tf.PartitionedCall"(%arg0, %[[q_cst]]) {_tfl_quant_trait = "fully_quantizable", config = "", config_proto = "", executor_type = "", f = @composite_matmul_fn} : (tensor<1x2x2x3xf32>, tensor<2x1024x!quant.uniform:f32, 3.9370078740157481E-9>>) -> tensor<*xf32> // CHECK: "func.return"(%[[out]]) : (tensor<*xf32>) -> () diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/quantize_xla.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/tests/quantize_xla.mlir index 9123e41967e..4356d084a56 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/tests/quantize_xla.mlir +++ b/tensorflow/compiler/mlir/quantization/tensorflow/tests/quantize_xla.mlir @@ -19,8 +19,8 @@ func.func private @conv(%input: tensor<1x3x4x3xf32> {tf._user_specified_name = " func.return %dq_res : tensor<*xf32> } -// CHECK-DAG: [[bias:%.+]] = "arith.constant"() {value = dense<[7.11401462, 7.05456924]> : tensor<2xf32>} : () -> tensor<2xf32> -// CHECK-DAG: [[weight:%.+]] = "arith.constant"() {value = dense_resource<__elided__> : tensor<2x3x3x2xf32>} : () -> tensor<2x3x3x2x!quant.uniform> +// CHECK-DAG: [[bias:%.+]] = "arith.constant"() <{value = dense<[7.11401462, 7.05456924]> : tensor<2xf32>}> : () -> tensor<2xf32> +// 
CHECK-DAG: [[weight:%.+]] = "arith.constant"() <{value = dense_resource<__elided__> : tensor<2x3x3x2xf32>}> : () -> tensor<2x3x3x2x!quant.uniform> // CHECK: [[q_input:%.+]] = "quantfork.qcast"(%arg0) : (tensor<1x3x4x3xf32>) -> tensor<1x3x4x3x!quant.uniform> // CHECK-NEXT: [[q_bias:%.+]] = "quantfork.qcast"([[bias]]) : (tensor<2xf32>) -> tensor<2x!quant.uniform> // CHECK-NEXT: [[conv:%.+]] = "tf.PartitionedCall"([[q_input]], [[weight]], [[q_bias]]) {_tfl_quant_trait = "fully_quantizable", config = "", config_proto = "", executor_type = "", f = @[[composite_fn:composite_conv2d_with_bias_and_relu6_fn.*]]} : (tensor<1x3x4x3x!quant.uniform>, tensor<2x3x3x2x!quant.uniform>, tensor<2x!quant.uniform>) -> tensor<*x!quant.uniform> diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/utils/BUILD b/tensorflow/compiler/mlir/quantization/tensorflow/utils/BUILD index f2d89b2df75..9f02d680300 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/utils/BUILD +++ b/tensorflow/compiler/mlir/quantization/tensorflow/utils/BUILD @@ -63,6 +63,7 @@ cc_library( "//tensorflow/compiler/mlir/quantization/tensorflow:uniform_op_quant_spec", "//tensorflow/core:protos_all_cc", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", ], diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/utils/tf_to_uniform_attribute_utils.cc b/tensorflow/compiler/mlir/quantization/tensorflow/utils/tf_to_uniform_attribute_utils.cc index e6f74a654aa..cb301ec8276 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/utils/tf_to_uniform_attribute_utils.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/utils/tf_to_uniform_attribute_utils.cc @@ -14,10 +14,12 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/mlir/quantization/tensorflow/utils/tf_to_uniform_attribute_utils.h" +#include #include #include #include #include +#include #include "absl/container/flat_hash_map.h" #include "llvm/ADT/StringMap.h" @@ -35,6 +37,26 @@ namespace mlir::quant { using QuantMethod = tensorflow::quantization::QuantizationMethod::ExperimentalMethod; +enum class OpType { + kDynamicRangeOp, // Dynamic Range kernels only have rhs attr. + kUnaryOp, // Unary ops have one min/max attr. + kBinaryOp, // Binary ops have lhs/rhs attr. + kQuantizationOp, // Quantization ops have input/output attr. +}; + +// For each op type, the following axis carries axis information: +// kDynamicRangeOp: rhs_quantization_axis will carry axis information. +// kUnaryOp: quantization_axis will carry axis information. +// kBinaryOp: Among {lhs, rhs, output}_quantization_axis, only check rhs. +// kQuantizationOp: Among {input, output}_quantization_axis, only check input. +// We therefore check exemplary 3 axes {rhs_, input_, }quantization_axis from +// previous accumulations. +constexpr std::array kQuantizationAxisAttrs = { + "input_quantization_axis", "quantization_axis", "rhs_quantization_axis"}; + +// Common suffixes for attributes used in FillQuantizationAttributes. 
+constexpr std::array kSuffixes = {"_min_val", "_max_val"}; + Attribute GetWindowStridesValue( PatternRewriter& rewriter, llvm::StringMap& identifier_to_attr) { ArrayAttr stride = identifier_to_attr["strides"].dyn_cast(); @@ -103,50 +125,73 @@ Attribute GetBatchGroupCountValue( return rewriter.getI64IntegerAttr(1); } +Attribute GetQuantizationAxis(PatternRewriter& rewriter, Operation* op, + const int operand_index) { + auto* defining_op = op->getOperand(operand_index).getDefiningOp(); + for (auto attr : kQuantizationAxisAttrs) { + if (defining_op->hasAttr(attr)) { + return defining_op->getAttr(attr); + } + } + // Not found. + return rewriter.getI64IntegerAttr(-1); +} + void FillQuantizationAttributes(PatternRewriter& rewriter, Operation* op, NamedAttrList& attrs, llvm::StringMap& identifier_to_attr, - QuantMethod quantization_method) { + OpType op_type) { // TODO(b/259374419): Support broader quantization schemes absl::flat_hash_map min_max_scheme_for_8bit_narrow; min_max_scheme_for_8bit_narrow = {{"min", -127}, {"max", 127}}; - std::set quantization_attributes; - if (quantization_method == - tensorflow::quantization::QuantizationMethod::DYNAMIC_RANGE) { - quantization_attributes = { - "rhs_quantization_min_val", - "rhs_quantization_max_val", - }; - } else { - quantization_attributes = { - "lhs_quantization_min_val", "lhs_quantization_max_val", - "rhs_quantization_min_val", "rhs_quantization_max_val", - "output_quantization_min_val", "output_quantization_max_val", - }; + std::vector quantization_attributes; + switch (op_type) { + case OpType::kDynamicRangeOp: + quantization_attributes = {"rhs_quantization"}; + break; + case OpType::kUnaryOp: + quantization_attributes = {"quantization"}; + break; + case OpType::kBinaryOp: + quantization_attributes = {"lhs_quantization", "rhs_quantization", + "output_quantization"}; + break; + case OpType::kQuantizationOp: + quantization_attributes = {"input_quantization", "output_quantization"}; + break; + default: + quantization_attributes = {}; + break; } for (const auto& attr : quantization_attributes) { - auto quant_val = absl::StrContains(attr, "min") - ? min_max_scheme_for_8bit_narrow["min"] - : min_max_scheme_for_8bit_narrow["max"]; - auto quant_val_attr = rewriter.getI64IntegerAttr(quant_val); - attrs.push_back(rewriter.getNamedAttr(attr, quant_val_attr)); + for (int i = 0; i < kSuffixes.size(); i++) { + auto quant_val = i == 0 ? min_max_scheme_for_8bit_narrow["min"] + : min_max_scheme_for_8bit_narrow["max"]; + std::string attr_minmax = absl::StrCat(attr, kSuffixes[i]); + attrs.push_back(rewriter.getNamedAttr( + attr_minmax, rewriter.getI64IntegerAttr(quant_val))); + } } } +// This LogicalResult covers both the hybrid and fully quantized op cases. LogicalResult FillAttributesForUniformQuantizedDotOp( PatternRewriter& rewriter, Operation* op, llvm::StringMap& identifier_to_attr, QuantMethod quantization_method, bool enable_per_channel_quantization) { NamedAttrList attrs; - // Fill quantization related attributes. - FillQuantizationAttributes(rewriter, op, attrs, identifier_to_attr, - quantization_method); - - if (!(quantization_method == - tensorflow::quantization::QuantizationMethod::DYNAMIC_RANGE)) { + if (quantization_method == + tensorflow::quantization::QuantizationMethod::DYNAMIC_RANGE) { + // Fill quantization related attributes for Hybrid op. + FillQuantizationAttributes(rewriter, op, attrs, identifier_to_attr, + OpType::kDynamicRangeOp); + } else { + // Fill quantization related attributes for fully quantized op. 
+ FillQuantizationAttributes(rewriter, op, attrs, identifier_to_attr, + OpType::kBinaryOp); // Per-channel activation is not supported attrs.push_back(rewriter.getNamedAttr("lhs_quantization_axis", rewriter.getI64IntegerAttr(-1))); @@ -158,7 +203,7 @@ LogicalResult FillAttributesForUniformQuantizedDotOp( absl::flat_hash_set operands = spec->quantizable_operands; int quant_dim = -1; if (enable_per_channel_quantization && operands.size() == 1) { - quant_dim = spec->coeff_op_quant_dim[*(spec->quantizable_operands.begin())]; + quant_dim = spec->coeff_op_quant_dim[*(operands.begin())]; } attrs.push_back(rewriter.getNamedAttr("rhs_quantization_axis", rewriter.getI64IntegerAttr(quant_dim))); @@ -168,6 +213,7 @@ LogicalResult FillAttributesForUniformQuantizedDotOp( return success(); } +// This LogicalResult covers both the hybrid and fully quantized op cases. LogicalResult FillAttributesForUniformQuantizedConvolutionOp( PatternRewriter& rewriter, Operation* op, llvm::StringMap& identifier_to_attr, @@ -211,9 +257,16 @@ LogicalResult FillAttributesForUniformQuantizedConvolutionOp( attrs.push_back(rewriter.getNamedAttr( feature_group_cnt_attr, rewriter.getI64IntegerAttr(feature_group_cnt))); - // Fill quantization related attributes. - FillQuantizationAttributes(rewriter, op, attrs, identifier_to_attr, - quantization_method); + if (quantization_method == + tensorflow::quantization::QuantizationMethod::DYNAMIC_RANGE) { + // Fill quantization related attributes for Hybrid op. + FillQuantizationAttributes(rewriter, op, attrs, identifier_to_attr, + OpType::kDynamicRangeOp); + } else { + // Fill quantization related attributes for fully quantized op. + FillQuantizationAttributes(rewriter, op, attrs, identifier_to_attr, + OpType::kBinaryOp); + } if (quantization_method != tensorflow::quantization::QuantizationMethod::DYNAMIC_RANGE) { @@ -228,7 +281,7 @@ LogicalResult FillAttributesForUniformQuantizedConvolutionOp( absl::flat_hash_set operands = spec->quantizable_operands; int quant_dim = -1; if (enable_per_channel_quantization && operands.size() == 1) { - quant_dim = spec->coeff_op_quant_dim[*(spec->quantizable_operands.begin())]; + quant_dim = spec->coeff_op_quant_dim[*(operands.begin())]; } attrs.push_back(rewriter.getNamedAttr("rhs_quantization_axis", rewriter.getI64IntegerAttr(quant_dim))); @@ -238,4 +291,84 @@ LogicalResult FillAttributesForUniformQuantizedConvolutionOp( return success(); } +LogicalResult FillAttributesForUniformQuantizedAddOp( + PatternRewriter& rewriter, Operation* op, + llvm::StringMap& identifier_to_attr, + const QuantMethod quantization_method, + const bool enable_per_channel_quantization) { + NamedAttrList attrs; + + // Fill quantization related attributes. + FillQuantizationAttributes(rewriter, op, attrs, identifier_to_attr, + OpType::kBinaryOp); + Attribute activation_quantization_axis = rewriter.getI64IntegerAttr(-1); + if (enable_per_channel_quantization) { + // If either of lhs or rhs is per-channel quantized, the quantization axis + // must match for lhs, rhs, and output. 
+ activation_quantization_axis = + GetQuantizationAxis(rewriter, op, /*operand_index=*/0); + if (activation_quantization_axis == rewriter.getI64IntegerAttr(-1)) { + activation_quantization_axis = + GetQuantizationAxis(rewriter, op, /*operand_index=*/1); + } + } + attrs.push_back(rewriter.getNamedAttr("lhs_quantization_axis", + activation_quantization_axis)); + attrs.push_back(rewriter.getNamedAttr("rhs_quantization_axis", + activation_quantization_axis)); + attrs.push_back(rewriter.getNamedAttr("output_quantization_axis", + activation_quantization_axis)); + op->setAttrs(rewriter.getDictionaryAttr(attrs)); + + return success(); +} + +LogicalResult FillAttributesForUniformQuantizedClipByValueOp( + PatternRewriter& rewriter, Operation* op, + llvm::StringMap& identifier_to_attr, + QuantMethod quantization_method, bool enable_per_channel_quantization) { + NamedAttrList attrs; + + // Fill quantization related attributes. + FillQuantizationAttributes(rewriter, op, attrs, identifier_to_attr, + OpType::kUnaryOp); + + Attribute activation_quantization_axis = rewriter.getI64IntegerAttr(-1); + if (enable_per_channel_quantization) { + activation_quantization_axis = + GetQuantizationAxis(rewriter, op, /*operand_index=*/0); + } + attrs.push_back( + rewriter.getNamedAttr("quantization_axis", activation_quantization_axis)); + op->setAttrs(rewriter.getDictionaryAttr(attrs)); + + return success(); +} + +LogicalResult FillAttributesForUniformRequantizeOp( + PatternRewriter& rewriter, Operation* op, + llvm::StringMap& identifier_to_attr, + QuantMethod quantization_method, bool enable_per_channel_quantization) { + NamedAttrList attrs; + + // Fill quantization related attributes. + FillQuantizationAttributes(rewriter, op, attrs, identifier_to_attr, + OpType::kQuantizationOp); + + Attribute activation_quantization_axis = rewriter.getI64IntegerAttr(-1); + if (enable_per_channel_quantization) { + activation_quantization_axis = + GetQuantizationAxis(rewriter, op, /*operand_index=*/0); + } + // For per-axis -> per-axis requantization, input and output quantization axis + // must be equal. 
+ attrs.push_back(rewriter.getNamedAttr("input_quantization_axis", + activation_quantization_axis)); + attrs.push_back(rewriter.getNamedAttr("output_quantization_axis", + activation_quantization_axis)); + op->setAttrs(rewriter.getDictionaryAttr(attrs)); + + return success(); +} + } // namespace mlir::quant diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/utils/tf_to_uniform_attribute_utils.h b/tensorflow/compiler/mlir/quantization/tensorflow/utils/tf_to_uniform_attribute_utils.h index 547473f3d90..b8e2a8bcd4f 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/utils/tf_to_uniform_attribute_utils.h +++ b/tensorflow/compiler/mlir/quantization/tensorflow/utils/tf_to_uniform_attribute_utils.h @@ -39,6 +39,27 @@ LogicalResult FillAttributesForUniformQuantizedConvolutionOp( quantization_method, bool enable_per_channel_quantization); +LogicalResult FillAttributesForUniformQuantizedAddOp( + PatternRewriter& rewriter, Operation* op, + llvm::StringMap& identifier_to_attr, + tensorflow::quantization::QuantizationMethod::ExperimentalMethod + quantization_method, + bool enable_per_channel_quantization); + +LogicalResult FillAttributesForUniformQuantizedClipByValueOp( + PatternRewriter& rewriter, Operation* op, + llvm::StringMap& identifier_to_attr, + tensorflow::quantization::QuantizationMethod::ExperimentalMethod + quantization_method, + bool enable_per_channel_quantization); + +LogicalResult FillAttributesForUniformRequantizeOp( + PatternRewriter& rewriter, Operation* op, + llvm::StringMap& identifier_to_attr, + tensorflow::quantization::QuantizationMethod::ExperimentalMethod + quantization_method, + bool enable_per_channel_quantization); + } // namespace mlir::quant #endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_UTILS_TF_TO_UNIFORM_ATTRIBUTE_UTILS_H_ diff --git a/tensorflow/compiler/mlir/runlit.cfg.py b/tensorflow/compiler/mlir/runlit.cfg.py index db721642303..74a804ceb97 100644 --- a/tensorflow/compiler/mlir/runlit.cfg.py +++ b/tensorflow/compiler/mlir/runlit.cfg.py @@ -71,6 +71,7 @@ tool_names = [ 'flatbuffer_to_string', 'flatbuffer_translate', 'hlo_to_llvm_ir', + 'ifrt-opt', 'json_to_flatbuffer', 'kernel-gen-opt', 'lhlo-tfrt-opt', diff --git a/tensorflow/compiler/mlir/runlit.site.cfg.py b/tensorflow/compiler/mlir/runlit.site.cfg.py index 293118cda2b..0fa778bcc3a 100644 --- a/tensorflow/compiler/mlir/runlit.site.cfg.py +++ b/tensorflow/compiler/mlir/runlit.site.cfg.py @@ -53,6 +53,7 @@ mlir_tf_tools_dirs = [ 'tensorflow/compiler/xla/mlir/tools/mlir_bisect', 'tensorflow/compiler/xla/mlir_hlo', 'tensorflow/compiler/xla/mlir_hlo/tosa', + 'tensorflow/compiler/xla/python/ifrt/ir/tests', 'tensorflow/compiler/xla/service/gpu/tests', 'tensorflow/compiler/xla/service/mlir_gpu', 'tensorflow/compiler/xla/translate', diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD index b5b6746d645..d49bca20c10 100644 --- a/tensorflow/compiler/mlir/tensorflow/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/BUILD @@ -407,7 +407,7 @@ cc_library( ":tensorflow_passes", ":tf_saved_model_passes", "//tensorflow/compiler/mlir/tf2xla:compile_mlir_util", - "//tensorflow/compiler/mlir/tf2xla:legalize_tf", + "//tensorflow/compiler/mlir/tf2xla/transforms:legalize_tf", "//tensorflow/compiler/xla/mlir_hlo:mhlo_passes", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", @@ -459,8 +459,8 @@ cc_library( srcs = [ "ir/tf_dialect.h", "ir/tf_ops.h", - "ir/tfrt_ops.h", "ir/tf_remaining_ops.h", + "ir/tfrt_ops.h", ] + ["ir/tf_" + target["name"] + ".cc" for 
target in tf_ops_category_list] + ["ir/tf_" + target["name"] + ".cc.inc" for target in tf_ops_category_list] + ["ir/tf_" + target["name"] + ".h" for target in tf_ops_category_list], @@ -475,8 +475,8 @@ cc_library( deps = [ ":attribute_utils", ":dynamic_shape_utils", - ":tensorflow_attributes", ":rewrite_util", + ":tensorflow_attributes", ":tensorflow_canonicalize_inc_gen", ":tensorflow_op_interfaces", ":tensorflow_op_interfaces_inc_gen", @@ -484,24 +484,24 @@ cc_library( ":tensorflow_structs", ":tensorflow_traits", ":tensorflow_types", + ":tf_arith_ops_folder", ":tf_ops_canonicalization_helper", ":tf_ops_device_helper", ":tf_ops_layout_helper", ":tf_ops_tensor_helper", - ":tf_arith_ops_folder", + "//tensorflow/core:framework", + "//tensorflow/core:lib", "@llvm-project//llvm:Support", "@llvm-project//mlir:ControlFlowInterfaces", "@llvm-project//mlir:DerivedAttributeOpInterface", "@llvm-project//mlir:Dialect", + "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:InferTypeOpInterface", "@llvm-project//mlir:LoopLikeInterface", "@llvm-project//mlir:Parser", "@llvm-project//mlir:SideEffectInterfaces", - "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:Support", - "//tensorflow/core:framework", - "//tensorflow/core:lib", ] + [":tensorflow_" + target["name"] + "_inc_gen" for target in tf_ops_category_list], ) @@ -510,8 +510,8 @@ cc_library( srcs = [ "ir/tf_dialect.h", "ir/tf_ops.h", - "ir/tf_remaining_ops.h", "ir/tf_remaining_ops.cc", + "ir/tf_remaining_ops.h", "ir/tfrt_ops.h", ] + ["ir/tf_" + target["name"] + ".h" for target in tf_ops_category_list], hdrs = [ @@ -555,9 +555,9 @@ cc_library( srcs = [ "ir/tf_dialect.h", "ir/tf_ops.h", - "ir/tfrt_ops.h", - "ir/tfrt_ops.cc", "ir/tf_remaining_ops.h", + "ir/tfrt_ops.cc", + "ir/tfrt_ops.h", ] + ["ir/tf_" + target["name"] + ".h" for target in tf_ops_category_list], hdrs = [ ], @@ -1014,7 +1014,6 @@ cc_library( ":session_utils", ":tensorflow", ":tensorflow_ops", - ":tensorflow_passes", "//tensorflow/core:core_cpu_base", "//tensorflow/core:framework_internal", "@llvm-project//llvm:Support", @@ -1209,12 +1208,14 @@ cc_library( "transforms/device_index_selector.cc", "transforms/drop_while_shape_invariant.cc", "transforms/einsum.cc", + "transforms/embedding_pipelining.cc", "transforms/executor_island_coarsening.cc", "transforms/executor_tpuv1_inline_tpu_island.cc", "transforms/executor_tpuv1_island_coarsening.cc", "transforms/executor_tpuv1_outline_tpu_island.cc", "transforms/extract_head_tail_outside_compilation.cc", "transforms/extract_outside_compilation.cc", + "transforms/extract_tpu_copy_with_dynamic_shape_op.cc", "transforms/fold_broadcast.cc", "transforms/functional_control_flow_to_cfg.cc", "transforms/functional_control_flow_to_regions.cc", @@ -1267,9 +1268,11 @@ cc_library( "transforms/test_resource_alias_analysis.cc", "transforms/tf_data_optimization_pass.cc", "transforms/tf_device_assignment.cc", + "transforms/tpu_annotate_dynamic_shape_inputs.cc", "transforms/tpu_cluster_cleanup_attributes.cc", "transforms/tpu_cluster_formation.cc", "transforms/tpu_colocate_composite_resource_ops.cc", + "transforms/tpu_colocate_splits.cc", "transforms/tpu_device_propagation.cc", "transforms/tpu_dynamic_layout_pass.cc", "transforms/tpu_host_computation_expansion.cc", @@ -1354,8 +1357,8 @@ cc_library( "//tensorflow/compiler/jit:flags_headers", "//tensorflow/compiler/mlir:op_or_arg_name_mapper", "//tensorflow/compiler/mlir/lite:validators", - "//tensorflow/compiler/mlir/tf2xla:xla_legalize_tf", - 
"//tensorflow/compiler/mlir/tf2xla:xla_legalize_tf_with_tf2xla", + "//tensorflow/compiler/mlir/tf2xla/transforms:xla_legalize_tf", + "//tensorflow/compiler/mlir/tf2xla/transforms:xla_legalize_tf_with_tf2xla", "//tensorflow/compiler/tf2xla:side_effect_util", "//tensorflow/compiler/xla:window_util", "//tensorflow/compiler/xla:xla_data_proto_cc", @@ -1416,10 +1419,10 @@ cc_library( ":serialize_mlir_module_utils", ":shape_inference_utils", ":tensorflow", - ":tensorflow_types", ":tf_device_pass_inc_gen", ":tf_pass_inc_gen", ":translate_utils", + "//tensorflow/compiler/tf2xla/kernels:xla_call_module_loader", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:window_util", "//tensorflow/compiler/xla:xla_data_proto_cc", @@ -1582,6 +1585,7 @@ cc_library( "//tensorflow/core:graph", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", + "//tensorflow/core/graph/regularization:util", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", @@ -1996,19 +2000,22 @@ cc_library( ], visibility = ["//visibility:public"], deps = [ - ":constant_fold_utils", ":convert_tensor", - ":eval_util", + ":export_graphdef", + ":export_tf_dialect_op", ":tensorflow", ":tensorflow_traits", ":tensorflow_types", - "//tensorflow/c:tf_status", "//tensorflow/compiler/xla/stream_executor", + "//tensorflow/core:all_kernels", + "//tensorflow/core:direct_session", "//tensorflow/core:framework", "//tensorflow/core:lib", + "//tensorflow/core/ops", + "//tensorflow/core/tfrt/fallback:fallback_state", + "//tensorflow/core/tfrt/fallback:op_kernel_runner", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", - "@llvm-project//mlir:SideEffectInterfaces", "@llvm-project//mlir:Support", ], alwayslink = 1, @@ -2084,7 +2091,7 @@ cc_library( ":mlir_import_options", ":mlir_roundtrip_flags", "//tensorflow/cc/saved_model:bundle_v2", - "//tensorflow/cc/saved_model:loader", + "//tensorflow/cc/saved_model:loader_lite", "//tensorflow/cc/saved_model:reader", "//tensorflow/core:graph", "//tensorflow/core:lib", @@ -2349,7 +2356,9 @@ cc_library( srcs = ["utils/tpu_rewrite_device_util.cc"], hdrs = ["utils/tpu_rewrite_device_util.h"], deps = [ + ":device_util", ":tensorflow", + ":tensorflow_types", "//tensorflow/compiler/mlir/utils:string_container_utils", "//tensorflow/compiler/xla:array4d", "//tensorflow/compiler/xla:xla_data_proto_cc", @@ -2370,6 +2379,8 @@ tf_cc_test( srcs = ["utils/tpu_rewrite_device_util_test.cc"], deps = [ ":device_util", + ":serialize_mlir_module_utils", + ":tensorflow", ":tpu_rewrite_device_util", "//tensorflow/core:framework", "//tensorflow/core:test", @@ -2513,8 +2524,14 @@ tf_cc_test( cc_library( name = "bridge_logger", - srcs = ["utils/bridge_logger.cc"], - hdrs = ["utils/bridge_logger.h"], + srcs = [ + "utils/bridge_logger.cc", + "utils/data_dumper_logger_config.cc", + ], + hdrs = [ + "utils/bridge_logger.h", + "utils/data_dumper_logger_config.h", + ], deps = [ ":dump_mlir_util", "@com_google_absl//absl/strings", diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_device.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_device.cc index c7374f6fa72..b6d0ff71211 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_device.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_device.cc @@ -147,7 +147,7 @@ LogicalResult ParallelExecuteOp::verify() { } int output_index = 0; - for (auto& region_and_index : llvm::enumerate(regions)) { + for (const auto& region_and_index : llvm::enumerate(regions)) { auto& region = 
region_and_index.value(); auto* region_terminator = region.front().getTerminator(); diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td index c11c6edd591..14ff8f37ae8 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td @@ -671,7 +671,7 @@ array([b'3.14', b'2.72'], dtype=object) }]; let arguments = (ins - TensorOf<[TF_Bfloat16, TF_Bool, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8, TF_Variant]>:$input, + TensorOf<[TF_Bfloat16, TF_Bool, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Str, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8, TF_Variant]>:$input, DefaultValuedOptionalAttr:$precision, DefaultValuedOptionalAttr:$scientific, @@ -19882,7 +19882,7 @@ If two elements are equal, the lower-index element appears first. let arguments = (ins Arg:$input, - Arg, [{0-D. Number of top elements to look for along the last dimension (along each row for matrices).}]>:$k, DefaultValuedOptionalAttr:$sorted @@ -19890,10 +19890,12 @@ row for matrices).}]>:$k, let results = (outs Res:$values, - Res:$indices + Res, [{The indices of `values` within the last dimension of `input`.}]>:$indices ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; + TF_DerivedOperandTypeAttr Tk = TF_DerivedOperandTypeAttr<1>; + TF_DerivedResultTypeAttr index_type = TF_DerivedResultTypeAttr<1>; let hasVerifier = 1; } @@ -20216,6 +20218,8 @@ Must have same shape with `output_scales`.}]>:$output_zero_points, ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; + + let hasVerifier = 1; } def TF_UniformQuantizedClipByValueOp : TF_Op<"UniformQuantizedClipByValue", [Pure]> { @@ -20248,6 +20252,8 @@ Same shape condition as scales.}]>:$zero_points, ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; + + let hasVerifier = 1; } def TF_UniformQuantizedConvolutionOp : TF_Op<"UniformQuantizedConvolution", [Pure]> { @@ -21048,6 +21054,33 @@ where(input) ==> [[0, 0, 0], TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } +def TF_WriteTrainingPredictionsOp : TF_Op<"WriteTrainingPredictions", [DeclareOpInterfaceMethods]> { + let summary = [{ +Writes the given predictions into a RecordIO file using a previously + }]; + + let description = [{ +initialized global TrainingPredictionWriter. The predictions are transformed +into a PredictionData proto before they are written to the file. + }]; + + let arguments = (ins + Arg:$keys, + Arg, [{A list of float tensors containing prediction values.}]>:$predictions_list, + Arg:$step, + Arg:$timestamp_usec, + + StrArrayAttr:$prediction_names, + BoolAttr:$training, + StrAttr:$file_path + ); + + let results = (outs); + + TF_DerivedOperandSizeAttr num_predictions = TF_DerivedOperandSizeAttr<1>; +} + def TF_XdivyOp : TF_Op<"Xdivy", [Pure, ResultsBroadcastableShape, TF_SameOperandsAndResultElementTypeResolveRef]>, WithBroadcastableBinOpBuilder { let summary = "Returns 0 if x == 0, and x / y otherwise, elementwise."; @@ -21123,8 +21156,8 @@ def TF_XlaCallModuleOp : TF_Op<"XlaCallModule", [Pure]> { let summary = "Invokes a StableHLO module."; let description = [{ -This op is experimental and is intended for use with JAX native serialization -in a TensorFlow context. 
+This op is used with JAX native serialization in a TensorFlow context with +stability guarantees. }]; let arguments = (ins @@ -21137,7 +21170,8 @@ platform argument (see `platforms`) nor the dimension arguments (see StrAttr:$module, TF_ShapeAttrArray:$Sout, DefaultValuedOptionalAttr:$dim_args_spec, - DefaultValuedOptionalAttr:$platforms + DefaultValuedOptionalAttr:$platforms, + DefaultValuedOptionalAttr:$function_list ); let results = (outs @@ -22609,7 +22643,8 @@ A pseudo-op to represent host-side computation in an XLA program. StrAttr:$send_key, StrAttr:$recv_key, - DefaultValuedOptionalAttr:$host_mlir_module + DefaultValuedOptionalAttr:$host_mlir_module, + DefaultValuedOptionalAttr:$manual_sharding ); let results = (outs @@ -22636,7 +22671,8 @@ A placeholder op to receive values from a running XLA computation. execution the transfer corresponds to.}]>:$dynamic_key, StrAttr:$key, - I64Attr:$device_ordinal + I64Attr:$device_ordinal, + DefaultValuedOptionalAttr:$device_type ); let results = (outs @@ -22656,7 +22692,8 @@ A placeholder op to receive values from a running XLA computation with support f execution the transfer corresponds to.}]>:$dynamic_key, Arg:$device_ordinal, - StrAttr:$key + StrAttr:$key, + DefaultValuedOptionalAttr:$device_type ); let results = (outs @@ -22675,7 +22712,8 @@ def TF__XlaSendFromHostOp : TF_Op<"_XlaSendFromHost", [DeclareOpInterfaceMethods execution the transfer corresponds to.}]>:$dynamic_key, StrAttr:$key, - I64Attr:$device_ordinal + I64Attr:$device_ordinal, + DefaultValuedOptionalAttr:$device_type ); let results = (outs); @@ -22694,7 +22732,8 @@ A placeholder op to send values to a running XLA computation with support for a execution the transfer corresponds to.}]>:$dynamic_key, Arg:$device_ordinal, - StrAttr:$key + StrAttr:$key, + DefaultValuedOptionalAttr:$device_type ); let results = (outs); diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td index 8a15e3b3e85..83525c93047 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td @@ -188,6 +188,7 @@ def TF_TPUExecuteResource : TF_ResourceBase<"TPUExecute">; def TF_RandomGeneratorResource : TF_ResourceBase<"RandomGenerator">; def TF_XlaHostComputeResource : TF_ResourceBase<"XlaHostCompute">; def TF_XlaLaunchResource : TF_ResourceBase<"XlaLaunch">; +def TF_WriteTrainingPredictionsResource : TF_ResourceBase<"WriteTrainingPredictions">; def TF_CollectiveReduceOrderingResource : TF_ResourceBase<"CollectiveReduceOrdering">; def TF_NcclAllReduceOrderingResource : TF_ResourceBase<"NcclAllReduceOrdering">; @@ -252,6 +253,7 @@ def TF_XlaHostComputeSideEffect : MemoryEffects<[MemWrite]>; +def TF_WriteTrainingPredictions : MemoryEffects<[MemWrite]>; def TF_RandomGeneratorSideEffect : MemoryEffects<[MemWrite]>; // Special effect for keeping `CollectiveReduce` ops in order. @@ -294,6 +296,13 @@ def TF_ShapeAttr : TF_TensorFlowAttr<"Shape", "shape"> { def TF_ShapeAttrArray : TypedArrayAttrBase; +// An array of FlatSymbolRef attributes that can be used as a default valued +// attribute. 
+def TF_SymbolRefArrayAttr : + TypedArrayAttrBase { + let constBuilderCall = "::mlir::ArrayAttr::get($_builder.getContext(), $0)"; +} + //===----------------------------------------------------------------------===// // TensorFlow type definitions //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.h b/tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.h index 88cbf879d56..db8208893e7 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.h +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.h @@ -123,15 +123,15 @@ ResourceHandleValueAndId GetResourceHandleValueAndIdBase( // and have at least one operand, result type can be inferred using the first // operand's type. -#define INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(Op) \ - LogicalResult Op::inferReturnTypeComponents( \ - MLIRContext* context, std::optional location, \ - ValueShapeRange operands, DictionaryAttr attributes, \ - RegionRange regions, \ - SmallVectorImpl& inferredReturnShapes) { \ - return inferReturnTypeComponentsFromOperands(context, location, operands, \ - attributes, regions, \ - inferredReturnShapes); \ +#define INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(Op) \ + LogicalResult Op::inferReturnTypeComponents( \ + MLIRContext* context, std::optional location, \ + ValueShapeRange operands, DictionaryAttr attributes, \ + OpaqueProperties properties, RegionRange regions, \ + SmallVectorImpl& inferredReturnShapes) { \ + return inferReturnTypeComponentsFromOperands( \ + context, location, operands, attributes, properties, regions, \ + inferredReturnShapes); \ } #include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.h.inc" diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td index 80857c23765..c9a890778f6 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td @@ -1868,4 +1868,101 @@ def TF_SelectV2Op : TF_Op<"SelectV2", [Pure, ResultsBroadcastableShape]> { ]; } +def TF_TPUCopyWithDynamicShapeOp : TF_Op<"TPUCopyWithDynamicShape", [Pure, AttrSizedOperandSegments]> { + let summary = [{ +Op that copies host tensors to device with bounded dynamic shape support. + }]; + + let description = [{ +This op copies the padded tensor on cpu to TPU without the padded data. `tensors` +is a list of cpu tensors with padded data. `unpadded_sizes` is a list of shape +tensors which describes unpadded size of each dimension for each cpu tensor. +The size of the `unpadded_sizes` should be the same as `tensors`. They are both +on host. `tpu_tensors` are list of tpu device tensors without the padded data. +`tpu_tensors` also has the same size of the `tensors` and the shapes of +`tpu_tensors` are determined by the `unpadded_sizes`. + }]; + + let arguments = (ins + Variadic:$tensors, + Variadic:$unpadded_sizes + ); + + let results = (outs + Variadic:$tpu_tensors + ); + + TF_DerivedOperandSizeAttr N = TF_DerivedOperandSizeAttr<1>; + TF_DerivedOperandTypeListAttr T = TF_DerivedOperandTypeListAttr<0>; +} + +def TF_TPUAnnotateTensorsWithDynamicShapeOp : TF_Op<"TPUAnnotateTensorsWithDynamicShape", [Pure]> { + let summary = [{ +Placeholder op which takes the output of TPUCopyWithDynamicShapeOp and pass +them to the following tpu ops. + }]; + + let description = [{ +This op serves as an annotation for the dynamic shaped tensor and will be +removed during the bridge rewrite. 
+ }]; + + let arguments = (ins + Variadic:$tensors + ); + + let results = (outs + Variadic:$tpu_tensors + ); + + TF_DerivedOperandTypeListAttr T = TF_DerivedOperandTypeListAttr<0>; +} + +def TF_ConvertToCooTensorOp : TF_Op<"ConvertToCooTensor", [Pure]> { + let summary = [{ +Op that converts tensors into coo format. + }]; + + let description = [{ +This op coverts the dense, sparse and ragged tensor into standard coo tensor +format which contains three 1D tensors. + }]; + + let arguments = (ins + TF_Int32Tensor:$indices_or_row_splits, + TF_Int32Tensor:$values, + TF_Float32Tensor:$weights, + + ConfinedAttr]>:$sample_count, + StrAttr:$combiner + ); + + let results = (outs + TF_Int32Tensor:$row_ids, + TF_Int32Tensor:$col_ids, + TF_Float32Tensor:$gains + ); +} + +def TF_ResourceGatherNdOp : TF_Op<"ResourceGatherNd", []> { + let summary = "GatherNd on a resource."; + + let description = [{ +This op reads the variable referenced by the first argument, and +then performs a GatherNd operation on it. + }]; + + let arguments = (ins + Arg:$resource, + TF_I32OrI64Tensor:$indices + ); + + let results = (outs + TF_Tensor:$output + ); + + TF_DerivedOperandTypeAttr Tindices = TF_DerivedOperandTypeAttr<1>; + TF_DerivedResultTypeAttr dtype = TF_DerivedResultTypeAttr<0>; +} + #endif // TF_OPS diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc index e69b91c198e..dfa46846aa1 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc @@ -1693,7 +1693,7 @@ void ConstOp::build(OpBuilder& builder, OperationState& result, Type type, LogicalResult ConstOp::inferReturnTypes( MLIRContext* context, std::optional location, ValueRange operands, - DictionaryAttr attributes, RegionRange regions, + DictionaryAttr attributes, OpaqueProperties, RegionRange regions, SmallVectorImpl& inferredReturnTypes) { auto value = attributes.get("value"); if (!value) return emitOptionalError(location, "missing attribute 'value'"); @@ -1936,7 +1936,8 @@ static LogicalResult inferConvReturnTypeComponents( LogicalResult Conv2DOp::inferReturnTypeComponents( MLIRContext* context, std::optional location, - ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions, + ValueShapeRange operands, DictionaryAttr attributes, OpaqueProperties, + RegionRange regions, SmallVectorImpl& inferredReturnShapes) { Conv2DOpAdaptor op(operands.getValues(), attributes); ArrayRef explicit_padding; @@ -2134,7 +2135,8 @@ StringRef Conv2DBackpropInputOp::GetOptimalLayout( LogicalResult Conv3DOp::inferReturnTypeComponents( MLIRContext* context, std::optional location, - ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions, + ValueShapeRange operands, DictionaryAttr attributes, OpaqueProperties, + RegionRange regions, SmallVectorImpl& inferredReturnShapes) { Conv3DOpAdaptor op(operands.getValues(), attributes); ArrayRef explicit_padding; diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc index e4ec395b1cb..36b9d6c6e20 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc @@ -1,3 +1,4 @@ + /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 
Licensed under the Apache License, Version 2.0 (the "License"); @@ -717,9 +718,9 @@ OpFoldResult RangeOp::fold(FoldAdaptor adaptor) { if (!(start_tensor && limit_tensor && delta_tensor)) return nullptr; // Operands should all be scalars - assert(start_tensor.getType().getRank() == 0 && - limit_tensor.getType().getRank() == 0 && - delta_tensor.getType().getRank() == 0); + assert(start_tensor.getShapedType().getRank() == 0 && + limit_tensor.getShapedType().getRank() == 0 && + delta_tensor.getShapedType().getRank() == 0); Type elem_type = getType().cast().getElementType(); if (elem_type.isSignlessInteger() || elem_type.isUnsignedInteger()) { auto start_attr = start_tensor.getValues()[0]; @@ -2320,6 +2321,18 @@ void TPUExecuteOp::getEffects( } } +//===----------------------------------------------------------------------===// +// WriteTrainingPredictions +//===----------------------------------------------------------------------===// + +void WriteTrainingPredictionsOp::getEffects( + SmallVectorImpl> + &effects) { + effects.reserve(1); + effects.emplace_back(MemoryEffects::Write::get(), + ResourceEffects::WriteTrainingPredictions::get()); +} + //===----------------------------------------------------------------------===// // TPUExecuteAndUpdateVariablesOp //===----------------------------------------------------------------------===// @@ -2372,7 +2385,7 @@ void TPUExecuteAndUpdateVariablesOp::getEffects( .isa(); }); - for (auto &entry : llvm::enumerate(resource_handles)) { + for (const auto &entry : llvm::enumerate(resource_handles)) { Value value = entry.value(); effects.emplace_back(MemoryEffects::Read::get(), value, ResourceEffects::Variable::get()); @@ -2660,7 +2673,7 @@ void ToBoolOp::getCanonicalizationPatterns(RewritePatternSet &results, LogicalResult ToBoolOp::inferReturnTypes( MLIRContext *context, std::optional location, ValueRange operands, - DictionaryAttr attributes, RegionRange regions, + DictionaryAttr attributes, OpaqueProperties, RegionRange regions, SmallVectorImpl &inferredReturnTypes) { inferredReturnTypes.push_back( tensorflow::GetTypeFromTFTensorShape({}, IntegerType::get(context, 1))); @@ -2734,11 +2747,13 @@ LogicalResult TransposeOp::verify() { const int64_t y_idx = e.index(); const int64_t y_dim = y_type.getDimSize(y_idx); int64_t x_idx = e.value().getSExtValue(); - if (x_idx < 0) x_idx += x_type.getRank(); - if (x_idx < 0) { + int64_t x_rank = x_type.getRank(); + if (x_idx < -x_rank || x_idx >= x_rank) { return op.emitOpError( - llvm::formatv("perm[{0}] must be in [-rank, rank)", x_idx)); + llvm::formatv("perm[{0}]={1} must be in range [-{2}, {2})", y_idx, + x_idx, x_rank)); } + if (x_idx < 0) x_idx += x_rank; const int64_t x_dim = x_type.getDimSize(x_idx); if (!ShapedType::isDynamic(y_dim) && !ShapedType::isDynamic(x_dim) && y_dim != x_dim) { @@ -3467,7 +3482,8 @@ void XdivyOp::getCanonicalizationPatterns(RewritePatternSet &results, LogicalResult XlaBroadcastHelperOp::inferReturnTypeComponents( MLIRContext *context, std::optional location, - ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions, + ValueShapeRange operands, DictionaryAttr attributes, OpaqueProperties, + RegionRange regions, SmallVectorImpl &inferredReturnShapes) { XlaBroadcastHelperOpAdaptor op(operands.getValues(), attributes); Value lhs = op.getLhs(); @@ -3490,7 +3506,7 @@ LogicalResult XlaBroadcastHelperOp::inferReturnTypeComponents( return set_unranked_results(); } - if (dims.size() == 0) { + if (dims.empty()) { if (lhs_rank != rhs_rank && lhs_rank != 0 && rhs_rank != 0) { 
return emitOptionalError( location, @@ -3605,7 +3621,8 @@ LogicalResult XlaConvV2Op::verify() { LogicalResult XlaSetDynamicDimensionSizeOp::inferReturnTypeComponents( MLIRContext *context, std::optional location, - ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions, + ValueShapeRange operands, DictionaryAttr attributes, OpaqueProperties, + RegionRange regions, SmallVectorImpl &inferredReturnShapes) { XlaSetDynamicDimensionSizeOpAdaptor op(operands.getValues(), attributes); @@ -3675,9 +3692,10 @@ LogicalResult XlaReduceWindowOp::verify() { auto check = [&](mlir::Value val, std::string attr_name) -> LogicalResult { ElementsAttr attr; if (matchPattern(val, m_Constant(&attr))) { - if (attr.getType().getRank() != 1) { - return op.emitOpError() << "expects the rank of " << attr_name - << "to be 1, got " << attr.getType().getRank(); + if (attr.getShapedType().getRank() != 1) { + return op.emitOpError() + << "expects the rank of " << attr_name << "to be 1, got " + << attr.getShapedType().getRank(); } if (input_ty.hasRank()) { int64_t input_rank = input_ty.getRank(); @@ -3705,11 +3723,11 @@ LogicalResult XlaReduceWindowOp::verify() { ElementsAttr padding; if (matchPattern(op.getPadding(), m_Constant(&padding))) { - const ShapedType &padding_ty = padding.getType(); + const ShapedType &padding_ty = cast(padding.getType()); if (padding_ty.getRank() != 2 || padding_ty.getDimSize(1) != 2) { return op.emitOpError() << "expects padding to be a matrix with minor dimension 2, got " - << padding.getType().getShape(); + << padding.getShapedType().getShape(); } } @@ -3762,11 +3780,11 @@ LogicalResult XlaSelectAndScatterOp::verify() { ElementsAttr padding; if (matchPattern(op.getPadding(), m_Constant(&padding))) { - const ShapedType &padding_ty = padding.getType(); + const ShapedType &padding_ty = cast(padding.getType()); if (padding_ty.getRank() != 2 || padding_ty.getDimSize(1) != 2) { return op.emitOpError() << "expects padding to be a matrix with minor dimension 2, got " - << padding.getType().getShape(); + << padding.getShapedType().getShape(); } } @@ -3922,8 +3940,8 @@ LogicalResult XlaVariadicSortOp::verify() { ElementsAttr dimension; if (matchPattern(op.getDimension(), m_Constant(&dimension))) { - if (dimension.getType().getRank() != 0 || - dimension.getType().getNumElements() != 1) + if (dimension.getShapedType().getRank() != 0 || + dimension.getShapedType().getNumElements() != 1) return op.emitOpError() << "dimension must be a scalar"; } @@ -4108,6 +4126,26 @@ LogicalResult UniformQuantizedConvolutionOp::verify() { return VerifyLhsRhsBothUniformQuantizedOp(*this); } +//===----------------------------------------------------------------------===// +// UniformQuantizedAddOp +//===----------------------------------------------------------------------===// +// + +LogicalResult UniformQuantizedAddOp::verify() { + return VerifyLhsRhsBothUniformQuantizedOp(*this); +} + +//===----------------------------------------------------------------------===// +// UniformQuantizedClipByValueOp +//===----------------------------------------------------------------------===// +// + +LogicalResult UniformQuantizedClipByValueOp::verify() { + UniformQuantizedClipByValueOp op = *this; + return VerifyScalesAndZeroPoints(op, op.getScales(), op.getZeroPoints(), + op.getQuantizationAxis()); +} + } // namespace TF } // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_remaining_ops.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_remaining_ops.cc index d872f8ecd04..247a85804d8 100644 --- 
a/tensorflow/compiler/mlir/tensorflow/ir/tf_remaining_ops.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_remaining_ops.cc @@ -95,7 +95,7 @@ LogicalResult _XlaHostComputeMlirOp::verify() { if (!status.ok()) { return op.emitError() << "attribute 'host_mlir_module' can not be deserialized. " - << status.error_message(); + << status.message(); } func::FuncOp func = module_for_func->lookupSymbol("host_func"); diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_side_effects.h b/tensorflow/compiler/mlir/tensorflow/ir/tf_side_effects.h index e1c02b8d9c9..2fe672f1477 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_side_effects.h +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_side_effects.h @@ -108,6 +108,11 @@ struct XlaLaunch : public ::mlir::SideEffects::Resource::Base { StringRef getName() final { return "XlaLaunch"; } }; +struct WriteTrainingPredictions + : public ::mlir::SideEffects::Resource::Base { + StringRef getName() final { return "WriteTrainingPredictions"; } +}; + // Returns true iff resource type with given ID is only self-dependent, i.e., // there are no dependencies to other resource types (including unknown resource // type). diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_traits.h b/tensorflow/compiler/mlir/tensorflow/ir/tf_traits.h index 87def73cd5b..62f6192c1f8 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_traits.h +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_traits.h @@ -151,7 +151,8 @@ class SameOperandsAndResultTypeResolveRef static LogicalResult inferReturnTypeComponentsFromOperands( MLIRContext*, std::optional location, ValueShapeRange operands, - DictionaryAttr attributes, RegionRange regions, + DictionaryAttr attributes, OpaqueProperties properties, + RegionRange regions, SmallVectorImpl& inferredReturnShapes) { if (operands.empty()) return emitOptionalError( diff --git a/tensorflow/compiler/mlir/tensorflow/tests/einsum.mlir b/tensorflow/compiler/mlir/tensorflow/tests/einsum.mlir index 67a3e54979c..e1071f3e899 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/einsum.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/einsum.mlir @@ -1,5 +1,43 @@ // RUN: tf-opt -split-input-file -verify-diagnostics -tf-einsum %s | FileCheck %s +func.func @unary_einsum_reduce_sum_transpose(%arg0: tensor<3x4x5x6xf32>) -> tensor<3x5x4xf32> { + %0 = "tf.Einsum"(%arg0) {T = "tfdtype$DT_FLOAT", equation = "...gse->...sg"}: (tensor<3x4x5x6xf32>) -> tensor<3x5x4xf32> + func.return %0 : tensor<3x5x4xf32> + // CHECK-LABEL: unary_einsum_reduce_sum_transpose + // CHECK-DAG: %[[cst:.*]] = arith.constant dense<3> : tensor<1xi32> + // CHECK-DAG: %[[cst_1:.*]] = arith.constant dense<[0, 2, 1]> : tensor<3xi32> + // CHECK: %[[v0:.*]] = "tf.Sum"(%arg0, %[[cst]]) {keep_dims = false} : (tensor<3x4x5x6xf32>, tensor<1xi32>) -> tensor<3x4x5xf32> + // CHECK: %[[v1:.*]] = "tf.Transpose"(%[[v0]], %[[cst_1]]) : (tensor<3x4x5xf32>, tensor<3xi32>) -> tensor<3x5x4xf32> + // CHECK: return %[[v1]] : tensor<3x5x4xf32> +} + +func.func @unary_einsum_reduce_sum_transpose1(%arg0: tensor<3x4x5x6xf32>) -> tensor<3x4x5xf32> { + %0 = "tf.Einsum"(%arg0) {T = "tfdtype$DT_FLOAT", equation = "...gse->...gs"}: (tensor<3x4x5x6xf32>) -> tensor<3x4x5xf32> + func.return %0 : tensor<3x4x5xf32> + // CHECK-LABEL: unary_einsum_reduce_sum_transpose1 + // CHECK-DAG: %[[cst:.*]] = arith.constant dense<3> : tensor<1xi32> + // CHECK: %[[v0:.*]] = "tf.Sum"(%arg0, %[[cst]]) {keep_dims = false} : (tensor<3x4x5x6xf32>, tensor<1xi32>) -> tensor<3x4x5xf32> + // CHECK: return %[[v0]] : tensor<3x4x5xf32> +} + 
+func.func @unary_einsum_transpose(%arg0: tensor<3x4x5xf32>) -> tensor<3x5x4xf32> { + %0 = "tf.Einsum"(%arg0) {T = "tfdtype$DT_FLOAT", equation = "ijk->ikj"}: (tensor<3x4x5xf32>) -> tensor<3x5x4xf32> + func.return %0 : tensor<3x5x4xf32> + // CHECK-LABEL: unary_einsum_transpose + // CHECK-DAG: %[[cst:.*]] = arith.constant dense<[0, 2, 1]> : tensor<3xi32> + // CHECK: %[[v0:.*]] = "tf.Transpose"(%arg0, %[[cst]]) : (tensor<3x4x5xf32>, tensor<3xi32>) -> tensor<3x5x4xf32> + // CHECK: return %[[v0]] : tensor<3x5x4xf32> +} + +func.func @unary_einsum_reduce_sum(%arg0: tensor<4x5x6xf32>) -> tensor<4xf32> { + %0 = "tf.Einsum"(%arg0) {T = "tfdtype$DT_FLOAT", equation = "ijk->i"}: (tensor<4x5x6xf32>) -> tensor<4xf32> + func.return %0 : tensor<4xf32> + // CHECK-LABEL: unary_einsum_reduce_sum + // CHECK-DAG: %[[cst:.*]] = arith.constant dense<[1, 2]> : tensor<2xi32> + // CHECK: %[[v0:.*]] = "tf.Sum"(%arg0, %[[cst]]) {keep_dims = false} : (tensor<4x5x6xf32>, tensor<2xi32>) -> tensor<4xf32> + // CHECK: return %[[v0]] +} + func.func @einsum_basic(%arg0: tensor<3x4x5xf32>, %arg1: tensor<3x5x6xf32>) -> tensor<3x4x6xf32> { %0 = "tf.Einsum"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", equation = "ijk,ikm->ijm"}: (tensor<3x4x5xf32>, tensor<3x5x6xf32>) -> tensor<3x4x6xf32> func.return %0 : tensor<3x4x6xf32> diff --git a/tensorflow/compiler/mlir/tensorflow/tests/embedding_pipelining.mlir b/tensorflow/compiler/mlir/tensorflow/tests/embedding_pipelining.mlir new file mode 100644 index 00000000000..f6bd3d4d586 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/embedding_pipelining.mlir @@ -0,0 +1,319 @@ +// RUN: tf-opt %s -split-input-file -verify-diagnostics -tf-embedding-pipelining | FILECHECK_OPTS="" FileCheck %s + +// This test verifies the handling of TPU replicated inputs and outputs as well as the extraction of the four main functions. +module { + func.func @main() { + %cst_main = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + %0 = "tf.While"(%cst_main) {body = @while_body, cond = @while_cond, is_stateless = false} : (tensor) -> (tensor) + return + } + func.func private @while_body(%arg0: tensor) -> (tensor) { + // Verify that everything is extracted into one of the four functions. + // The order of these functions is also significant. 
+ // CHECK: {{.*StatefulPartitionedCall.* f = @_func_non_tpu.*}} + // CHECK-NEXT: {{.*StatefulPartitionedCall.* f = @_func_sc_forward.*}} + // CHECK-NEXT: {{.*StatefulPartitionedCall.* f = @_func_core_tpu.*}} + // CHECK-NEXT: {{.*StatefulPartitionedCall.* f = @_func_sc_backward.*}} + // CHECK-NEXT: return + // metadata ops + "tf.TPUReplicateMetadata"() {_has_manual_control_dependencies = true, _replication_info = "repl_info", num_replicas = 2 : i64} : () -> () + %comp_res = "tf.TPUCompilationResult"() {_tpu_compilation_status = "repl_info"} : () -> tensor + + // forward_ops + %res_f = "tf.Const"() {_embedding_pipelining = "forward", _replication_info = "repl_info", value = dense<2> : tensor} : () -> tensor + + // core_tpu ops: + %res_t = "tf.Identity"(%res_f) {_replication_info = "repl_info"} : (tensor) -> tensor + + // backward_ops + %res_b = "tf.Identity"(%res_t) {_embedding_pipelining = "backward", _replication_info = "repl_info"} : (tensor) -> tensor + + // non_tpu_ops + %res_n = "tf.Identity"(%arg0) : (tensor) -> tensor + + return %res_n : tensor + } + func.func private @while_cond(%arg0: tensor) -> tensor { + %0 = "tf.Less"(%arg0, %arg0) : (tensor, tensor) -> tensor + return %0 : tensor + } + // Generated functions + // non_tpu should have no TPU ops - just identity and return (in this test). + // CHECK: func.func private @_func_non_tpu + // CHECK-NEXT: tf.Identity + // CHECK-NEXT: return + + // sc_forward should have TPU ops including replicated outputs but not inputs + // CHECK: func.func private @_func_sc_forward + // CHECK-NOT: TPUReplicatedInput + // CHECK-DAG: TPUReplicateMetadata + // CHECK-DAG: TPUCompilationResult + // CHECK-DAG: TPUReplicatedOutput + // CHECK: return + + // core_tpu should have TPU ops including both replicated inputs and outputs + // CHECK: func.func private @_func_core_tpu + // CHECK-DAG: TPUReplicatedInput + // CHECK-DAG: TPUReplicateMetadata + // CHECK-DAG: TPUCompilationResult + // CHECK-DAG: TPUReplicatedOutput + // CHECK: return + + // sc_backward should have TPU ops including replicated inputs but not outputs + // CHECK: func.func private @_func_sc_backward + // CHECK-NOT: TPUReplicatedOutput + // CHECK-DAG: TPUReplicateMetadata + // CHECK-DAG: TPUCompilationResult + // CHECK-DAG: TPUReplicatedInput + // CHECK: return +} + +// ----- +// This test verifies that the extraction works correctly for evaluation-only models.
+module { + func.func @main() { + %cst = "tf.Const"() {value = dense<2> : tensor} : () -> tensor + %0 = "tf.While"(%cst) {body = @while_body, cond = @while_cond, is_stateless = false} : (tensor) -> (tensor) + return + } + func.func private @while_body(%arg0: tensor) -> (tensor) { + // CHECK: {{.*StatefulPartitionedCall.* f = @_func_non_tpu.*}} + // CHECK-NEXT: {{.*StatefulPartitionedCall.* f = @_func_sc_forward.*}} + // CHECK-NEXT: {{.*StatefulPartitionedCall.* f = @_func_core_tpu.*}} + // metadata ops + "tf.TPUReplicateMetadata"() {_has_manual_control_dependencies = true, _replication_info = "repl_info", num_replicas = 1 : i64} : () -> () + %1 = "tf.TPUCompilationResult"() {_tpu_compilation_status = "repl_info"} : () -> tensor + + // forward_ops + %res_f = "tf.Identity"(%arg0) {_embedding_pipelining = "forward", _replication_info = "repl_info"} : (tensor) -> tensor + + // core_tpu ops: + %res_t = "tf.Identity"(%res_f) {_replication_info = "repl_info"} : (tensor) -> tensor + + // non_tpu_ops + %res_n = "tf.Const"() {value = dense<2> : tensor} : () -> tensor + + return %res_n : tensor + } + func.func private @while_cond(%arg0: tensor) -> tensor { + %0 = "tf.Less"(%arg0, %arg0) : (tensor, tensor) -> tensor + return %0 : tensor + } + // Only verify sc_backward. The previous test case verifies everything else. + // CHECK: func.func private @_func_sc_backward + // CHECK-NEXT: return +} + +// ----- +// A test verifying too many TPUReplicateMetadataOp ops. Same logic tests too many TPUCompilationResultOp ops. +module { + func.func @main(%arg0: tensor<*x!tf_type.resource>, %arg1: tensor<*x!tf_type.resource>, %arg2: tensor<*x!tf_type.resource>>) { + %cst = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + %0 = "tf.While"(%cst) {body = @while_body, cond = @while_cond, is_stateless = false} : (tensor) -> (tensor) + return + } + // expected-error @+1 {{number of tf.TPUReplicateMetadata in loop body is not 1}} + func.func private @while_body(%arg0: tensor) -> (tensor) { + // metadata ops + %embedding_pass_trigger = "tf.Const"() {_embedding_pipelining = "forward", _replication_info = "repl_info", value = dense<2> : tensor} : () -> tensor + "tf.TPUReplicateMetadata"() {_has_manual_control_dependencies = true, _replication_info = "repl_info", num_replicas = 1 : i64} : () -> () + "tf.TPUReplicateMetadata"() {_has_manual_control_dependencies = true, _replication_info = "repl_info", num_replicas = 1 : i64} : () -> () + %1 = "tf.TPUCompilationResult"() {_tpu_compilation_status = "repl_info"} : () -> tensor + return %arg0 : tensor + } + func.func private @while_cond(%arg0: tensor) -> tensor { + return %arg0 : tensor + } +} + +// ----- +// A test verifying the replication region of TPUReplicateMetadataOp ops. Same logic tests too many TPUCompilationResultOp ops. 
+module { + func.func @main(%arg0: tensor<*x!tf_type.resource>, %arg1: tensor<*x!tf_type.resource>, %arg2: tensor<*x!tf_type.resource>>) { + %cst = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + %0 = "tf.While"(%cst) {body = @while_body, cond = @while_cond, is_stateless = false} : (tensor) -> (tensor) + return + } + func.func private @while_body(%arg0: tensor) -> (tensor) { + // metadata ops + %embedding_pass_trigger = "tf.Const"() {_embedding_pipelining = "forward", _replication_info = "repl_info", value = dense<2> : tensor} : () -> tensor + "tf.TPUReplicateMetadata"() {_has_manual_control_dependencies = true, _replication_info = "repl_info", num_replicas = 1 : i64} : () -> () + // expected-error @+1 {{'tf.TPUCompilationResult' op is not part of the replication region "repl_info" vs "wrong_repl_info"}} + %1 = "tf.TPUCompilationResult"() {_tpu_compilation_status = "wrong_repl_info"} : () -> tensor + return %arg0 : tensor + } + func.func private @while_cond(%arg0: tensor) -> tensor { + return %arg0 : tensor + } +} + +// ----- +// A test verifying TPUReplicatedOutput in the input graph doesn't trigger +// any additional TPUReplicatedInput or TPUReplicatedOutput ops. +module { + func.func @main() { + %cst_1 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + %cst_2 = "tf.Const"() {value = dense<2> : tensor} : () -> tensor + %0:2 = "tf.While"(%cst_1, %cst_2) {body = @while_body, cond = @while_cond, is_stateless = false} : (tensor, tensor) -> (tensor, tensor) + return + } + func.func private @while_body(%arg0: tensor, %arg1: tensor) -> (tensor, tensor) { + // CHECK: {{.*StatefulPartitionedCall.* f = @_func_non_tpu.*}} + // CHECK-NEXT: {{.*StatefulPartitionedCall.* f = @_func_sc_forward.*}} + // CHECK-NEXT: {{.*StatefulPartitionedCall.* f = @_func_core_tpu.*}} + // metadata ops + "tf.TPUReplicateMetadata"() {_has_manual_control_dependencies = true, _replication_info = "repl_info", num_replicas = 2 : i64} : () -> () + %1 = "tf.TPUCompilationResult"() {_tpu_compilation_status = "repl_info"} : () -> tensor + %2 = "tf.Const"() {_embedding_pipelining = "forward", _replication_info = "repl_info", value = dense<3> : tensor} : () -> tensor + %3:2 = "tf.TPUReplicatedOutput"(%2) {device = ""} : (tensor) -> (tensor, tensor) + + // core_tpu ops: + %res_t = "tf.Const"() {_replication_info = "repl_info", value = dense<4> : tensor} : () -> tensor + + // non_tpu_ops + %res_n = "tf.Const"() {value = dense<5> : tensor} : () -> tensor + + return %res_n, %3#1 : tensor, tensor + } + func.func private @while_cond(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = "tf.Less"(%arg1, %arg0) : (tensor, tensor) -> tensor + return %0 : tensor + } + // CHECK-DAG: TPUReplicatedOutput + // CHECK-NOT: TPUReplicatedoutput + // CHECK-NOT: TPUReplicatedInput +} + +// ----- +// Verify error for backward pass with no forward pass. 
+module { + func.func @main() { + %cst_main = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + %0 = "tf.While"(%cst_main) {body = @while_body, cond = @while_cond, is_stateless = false} : (tensor) -> (tensor) + return + } + func.func private @while_body(%arg0: tensor) -> (tensor) { + "tf.TPUReplicateMetadata"() {_has_manual_control_dependencies = true, _replication_info = "repl_info", num_replicas = 2 : i64} : () -> () + %comp_res = "tf.TPUCompilationResult"() {_tpu_compilation_status = "repl_info"} : () -> tensor + + // forward_ops + %res_f = "tf.Const"() {value = dense<2> : tensor} : () -> tensor + + // core_tpu ops: + %res_t = "tf.Identity"(%res_f) {_replication_info = "repl_info"} : (tensor) -> tensor + + // backward_ops + // expected-error @+1 {{'tf.Identity' op embedding backwards pass op with no forwards pass ops}} + %res_b = "tf.Identity"(%res_t) {_embedding_pipelining = "backward", _replication_info = "repl_info"} : (tensor) -> tensor + + // non_tpu_ops + %res_n = "tf.Identity"(%arg0) : (tensor) -> tensor + + return %res_n : tensor + } + func.func private @while_cond(%arg0: tensor) -> tensor { + %0 = "tf.Less"(%arg0, %arg0) : (tensor, tensor) -> tensor + return %0 : tensor + } +} + +// ----- +// Verify error for unknown _embedding_pipelining attribute value. +module { + func.func @main() { + %cst_main = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + %0 = "tf.While"(%cst_main) {body = @while_body, cond = @while_cond, is_stateless = false} : (tensor) -> (tensor) + return + } + func.func private @while_body(%arg0: tensor) -> (tensor) { + "tf.TPUReplicateMetadata"() {_has_manual_control_dependencies = true, _replication_info = "repl_info", num_replicas = 2 : i64} : () -> () + %comp_res = "tf.TPUCompilationResult"() {_tpu_compilation_status = "repl_info"} : () -> tensor + + // forward_ops + %res_f = "tf.Const"() {_embedding_pipelining = "forward", _replication_info = "repl_info", value = dense<2> : tensor} : () -> tensor + + // core_tpu ops: + %res_t = "tf.Identity"(%res_f) {_replication_info = "repl_info"} : (tensor) -> tensor + + // backward_ops + // expected-error @+1 {{'tf.Identity' op embedding op has unknown _embedding_pipelining attribute value garbage.}} + %res_b = "tf.Identity"(%res_t) {_embedding_pipelining = "garbage", _replication_info = "repl_info"} : (tensor) -> tensor + + // non_tpu_ops + %res_n = "tf.Identity"(%arg0) : (tensor) -> tensor + + return %res_n : tensor + } + func.func private @while_cond(%arg0: tensor) -> tensor { + %0 = "tf.Less"(%arg0, %arg0) : (tensor, tensor) -> tensor + return %0 : tensor + } +} + +// ----- +// Verify error for multiple WhileOp use of while_body function. 
+module { + func.func @main() { + %cst_main = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + %0 = "tf.While"(%cst_main) {body = @while_body, cond = @while_cond, is_stateless = false} : (tensor) -> (tensor) + // expected-error @+1 {{'tf.While' op multiple users of function.}} + %1 = "tf.While"(%cst_main) {body = @while_body, cond = @while_cond, is_stateless = false} : (tensor) -> (tensor) + return + } + func.func private @while_body(%arg0: tensor) -> (tensor) { + "tf.TPUReplicateMetadata"() {_has_manual_control_dependencies = true, _replication_info = "repl_info", num_replicas = 2 : i64} : () -> () + %comp_res = "tf.TPUCompilationResult"() {_tpu_compilation_status = "repl_info"} : () -> tensor + + // forward_ops + %res_f = "tf.Const"() {_embedding_pipelining = "forward", _replication_info = "repl_info", value = dense<2> : tensor} : () -> tensor + + // core_tpu ops: + %res_t = "tf.Identity"(%res_f) {_replication_info = "repl_info"} : (tensor) -> tensor + + // backward_ops + %res_b = "tf.Identity"(%res_t) {_embedding_pipelining = "backward", _replication_info = "repl_info"} : (tensor) -> tensor + + // non_tpu_ops + %res_n = "tf.Identity"(%arg0) : (tensor) -> tensor + + return %res_n : tensor + } + func.func private @while_cond(%arg0: tensor) -> tensor { + %0 = "tf.Less"(%arg0, %arg0) : (tensor, tensor) -> tensor + return %0 : tensor + } +} + +// ----- +// Verify error for non-WhileOp use of while_body function. +module { + func.func @main() { + %cst_main = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + %0 = "tf.While"(%cst_main) {body = @while_body, cond = @while_cond, is_stateless = false} : (tensor) -> (tensor) + // expected-error @+1 {{'tf.StatefulPartitionedCall' op non while use of function.}} + %38 = "tf.StatefulPartitionedCall"(%cst_main) {config = "", config_proto = "", executor_type = "", f = @while_body} : (tensor) -> tensor + return + } + func.func private @while_body(%arg0: tensor) -> (tensor) { + "tf.TPUReplicateMetadata"() {_has_manual_control_dependencies = true, _replication_info = "repl_info", num_replicas = 2 : i64} : () -> () + %comp_res = "tf.TPUCompilationResult"() {_tpu_compilation_status = "repl_info"} : () -> tensor + + // forward_ops + %res_f = "tf.Const"() {_embedding_pipelining = "forward", _replication_info = "repl_info", value = dense<2> : tensor} : () -> tensor + + // core_tpu ops: + %res_t = "tf.Identity"(%res_f) {_replication_info = "repl_info"} : (tensor) -> tensor + + // backward_ops + %res_b = "tf.Identity"(%res_t) {_embedding_pipelining = "backward", _replication_info = "repl_info"} : (tensor) -> tensor + + // non_tpu_ops + %res_n = "tf.Identity"(%arg0) : (tensor) -> tensor + + return %res_n : tensor + } + func.func private @while_cond(%arg0: tensor) -> tensor { + %0 = "tf.Less"(%arg0, %arg0) : (tensor, tensor) -> tensor + return %0 : tensor + } +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/extract_head_tail_outside_compilation.mlir b/tensorflow/compiler/mlir/tensorflow/tests/extract_head_tail_outside_compilation.mlir index ffcd2a25923..baf243c9b5f 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/extract_head_tail_outside_compilation.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/extract_head_tail_outside_compilation.mlir @@ -152,7 +152,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // CHECK-NEXT: %[[A_OUT:.*]] = "tf.A"(%[[RI]]) // CHECK-NOT: _xla_outside_compilation // CHECK-NEXT: tf_device.return %[[A_OUT]] - // CHECK-NEXT: device = "TPU_REPLICATED_HOST" + // CHECK-NEXT: 
device = "TPU_REPLICATED_HOST_0" // // CHECK: "tf_device.cluster" // CHECK-NEXT: "tf.B"(%[[LAUNCH_OUT]]) @@ -370,7 +370,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // CHECK-NEXT: %[[B_OUT:.*]] = "tf.B"(%[[CLUSTER_OUT]], %[[RI]]) // CHECK-NOT: _xla_outside_compilation // CHECK-NEXT: tf_device.return - // CHECK-NEXT: device = "TPU_REPLICATED_HOST" + // CHECK-NEXT: device = "TPU_REPLICATED_HOST_0" tf_device.replicate([%arg0, %arg1] as %ri : tensor) {n = 2 : i32} { "tf_device.cluster"() ({ %a = "tf.A"(%ri) : (tensor) -> tensor @@ -439,7 +439,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // CHECK-NEXT: %[[A_OUT:.*]] = "tf.A"(%[[RI]]) // CHECK-NOT: _xla_outside_compilation // CHECK-NEXT: tf_device.return %[[A_OUT]] - // CHECK-NEXT: device = "TPU_REPLICATED_HOST" + // CHECK-NEXT: device = "TPU_REPLICATED_HOST_0" // // CHECK: %[[CLUSTER_OUT:.*]] = "tf_device.cluster" // CHECK-NEXT: %[[B_OUT:.*]] = "tf.B" @@ -456,7 +456,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // CHECK-NEXT: "tf.D"(%[[HEAD_LAUNCH_OUT]], %[[CLUSTER_OUT]], %[[RI]]) // CHECK-NOT: _xla_outside_compilation // CHECK-NEXT: tf_device.return - // CHECK-NEXT: device = "TPU_REPLICATED_HOST" + // CHECK-NEXT: device = "TPU_REPLICATED_HOST_0" tf_device.replicate([%arg0, %arg1] as %ri : tensor) {n = 2 : i32} { "tf_device.cluster"() ({ %a = "tf.A"(%ri) {_xla_outside_compilation = "cluster1"} : (tensor) -> tensor diff --git a/tensorflow/compiler/mlir/tensorflow/tests/extract_outside_compilation.mlir b/tensorflow/compiler/mlir/tensorflow/tests/extract_outside_compilation.mlir index 0f2941bf317..f9a097d8fef 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/extract_outside_compilation.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/extract_outside_compilation.mlir @@ -104,7 +104,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // CHECK-NEXT: "tf_device.launch" // CHECK: "tf.B" // CHECK-NEXT: tf_device.return - // CHECK-NEXT: device = "TPU_REPLICATED_HOST" + // CHECK-NEXT: device = "TPU_REPLICATED_HOST_0" // CHECK: %[[TPU_CLUSTER_OUTPUT:[0-9]*]] = "tf_device.cluster" // CHECK: tf_device.return // CHECK: tf_device.return %[[TPU_CLUSTER_OUTPUT]] @@ -215,6 +215,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"() // CHECK: "tf._XlaSendFromHost"(%[[B_OUTPUT]], %[[PROGRAM_OUTPUT]]) // CHECK-SAME: device_ordinal = 0 + // CHECK-SAME: device_type = "TPU" // CHECK-SAME: key = "host_compute_channel_0_retvals" // CHECK: "tf_device.cluster" // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A" @@ -227,7 +228,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor %2 = "tf.B"() {_xla_outside_compilation = "cluster1"} : () -> (tensor<2xi32>) %3 = "tf.C"(%2) : (tensor<2xi32>) -> tensor<2xi32> tf_device.return %3 : tensor<2xi32> - }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<2xi32> + }) {_xla_compile_device_type = "TPU", num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<2xi32> func.return %0 : tensor<2xi32> } @@ -2067,3 +2068,33 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:loc func.return %0 : tensor<2xi32> } } + +// ----- +module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:localhost/replica:0/task:0/device:CPU:0"]} { + // CHECK-LABEL: func 
@single_outside_compiled_output_device_type + func.func @single_outside_compiled_output_device_type(%arg0: tensor<2xi32>) -> tensor<2xi32> { + // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute" + // CHECK-NEXT: "tf_device.launch" + // CHECK: %[[PROGRAM_OUTPUT:[a-z_0-9]*]] = "tf.Const"() {value = dense<""> : tensor<3x!tf_type.string>} : () -> tensor<3x!tf_type.string> + // CHECK-NOT: "tf._TPUDeviceOrdinalPlaceholder" + // CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"() + // CHECK: "tf._XlaSendFromHost"(%[[B_OUTPUT]], %[[PROGRAM_OUTPUT]]) + // CHECK-SAME: device_ordinal = 0 + // CHECK-SAME: device_type = "CPU" + // CHECK-SAME: key = "host_compute_channel_0_retvals" + // CHECK: "tf_device.cluster" + // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A" + // CHECK: %[[HOST_OUTPUT:[0-9]*]] = "tf._XlaHostComputeMlir"() + // CHECK-SAME: recv_key = "host_compute_channel_0_retvals" + // CHECK-SAME: send_key = "host_compute_channel_0_args" + // CHECK: "tf.C"(%[[HOST_OUTPUT]]) + %0 = "tf_device.cluster"() ({ + %1 = "tf.A"() : () -> (tensor<2xi32>) + %2 = "tf.B"() {_xla_outside_compilation = "cluster1"} : () -> (tensor<2xi32>) + %3 = "tf.C"(%2) : (tensor<2xi32>) -> tensor<2xi32> + tf_device.return %3 : tensor<2xi32> + }) {_xla_compile_device_type = "CPU"} : () -> tensor<2xi32> + + func.return %0 : tensor<2xi32> + } +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/extract_tpu_copy_with_dynamic_shape_op.mlir b/tensorflow/compiler/mlir/tensorflow/tests/extract_tpu_copy_with_dynamic_shape_op.mlir new file mode 100644 index 00000000000..fed754ea3c1 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/extract_tpu_copy_with_dynamic_shape_op.mlir @@ -0,0 +1,43 @@ +// RUN: tf-opt -split-input-file -verify-diagnostics -tf-extract-tpu-copy-with-dynamic-shape-op %s | FileCheck %s + +// Test that extract TPUCopyWithDynamicShape from host launch to device launch + +// CHECK-LABEL: func @valid_copy_op_in_replicated_host + +// CHECK: "tf_device.launch" +// CHECK: "TPU_REPLICATED_HOST_0" +// CHECK: "tf_device.launch" +// CHECK: "tf.TPUCopyWithDynamicShape" +// CHECK: "TPU_REPLICATED_CORE_0" +func.func @valid_copy_op_in_replicated_host( + %arg0: tensor<2048xi64> {tf.device = "/job:localhost/replica:0/task:0/device:CPU:0"}, + %arg1: tensor<2048xi64> {tf.device = "/job:localhost/replica:0/task:0/device:CPU:0"}) -> (tensor<2048xi32>, tensor<2048xi32>) { + %cst = "tf.Const"() {value = dense<1024> : tensor} : () -> tensor + %0:2 = "tf_device.launch"() ({ + %1 = "tf.Cast"(%arg0) {Truncate = false} : (tensor<2048xi64>) -> tensor<2048xi32> + %2 = "tf.Cast"(%arg1) {Truncate = false} : (tensor<2048xi64>) -> tensor<2048xi32> + %3:2 = "tf.TPUCopyWithDynamicShape"(%1, %2, %cst, %cst) {operand_segment_sizes = array} : (tensor<2048xi32>, tensor<2048xi32>, tensor, tensor) -> (tensor<2048xi32>, tensor<2048xi32>) + tf_device.return %3#0, %3#1 : tensor<2048xi32>, tensor<2048xi32> + }) {device = "TPU_REPLICATED_HOST_0"} : () -> (tensor<2048xi32>, tensor<2048xi32>) + return %0#0, %0#1: tensor<2048xi32>, tensor<2048xi32> +} + +// CHECK-LABEL: func @valid_copy_op_in_non_replicated_host + +// CHECK: "tf_device.launch" +// CHECK: "/job:localhost/replica:0/task:0/device:CPU:0" +// CHECK: "tf_device.launch" +// CHECK: "tf.TPUCopyWithDynamicShape" +// CHECK: "/job:localhost/replica:0/task:0/device:TPU:0" +func.func @valid_copy_op_in_non_replicated_host( + %arg0: tensor<2048xi64> {tf.device = "/job:localhost/replica:0/task:0/device:CPU:0"}, + %arg1: tensor<2048xi64> {tf.device = "/job:localhost/replica:0/task:0/device:CPU:0"}) 
-> (tensor<2048xi32>, tensor<2048xi32>) { + %cst = "tf.Const"() {value = dense<1024> : tensor} : () -> tensor + %0:2 = "tf_device.launch"() ({ + %1 = "tf.Cast"(%arg0) {Truncate = false} : (tensor<2048xi64>) -> tensor<2048xi32> + %2 = "tf.Cast"(%arg1) {Truncate = false} : (tensor<2048xi64>) -> tensor<2048xi32> + %3:2 = "tf.TPUCopyWithDynamicShape"(%1, %2, %cst, %cst) {operand_segment_sizes = array} : (tensor<2048xi32>, tensor<2048xi32>, tensor, tensor) -> (tensor<2048xi32>, tensor<2048xi32>) + tf_device.return %3#0, %3#1 : tensor<2048xi32>, tensor<2048xi32> + }) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : () -> (tensor<2048xi32>, tensor<2048xi32>) + return %0#0, %0#1: tensor<2048xi32>, tensor<2048xi32> +} \ No newline at end of file diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-as-function.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-as-function.pbtxt index b7e47126779..d4d5b8e3c52 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-as-function.pbtxt +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-as-function.pbtxt @@ -1,10 +1,11 @@ -# RUN: tf-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false %s -tf-graph-as-function -o - | FileCheck %s +# RUN: tf-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false %s -tf-graph-as-function -tf-xla-compile-device-type="GPU" -o - | FileCheck %s # Verify main graph was converted to a function, args/rets are mapped correctly, # and ops in the main graph are retained. In addition, check if subsequent # functions are converted. # CHECK: func @main(%arg0: tensor<*x!tf_type.resource>, %arg1: tensor<*x!tf_type.resource>>, %arg2: tensor<*xf32>, %arg3: tensor<2x4x6x8xi32>) -> (tensor<*xf32>, tensor<*xf32>) +# CHECK-SAME: _xla_compile_device_type = "GPU" # CHECK-SAME: control_outputs = "" # CHECK-SAME: inputs = "args_0,args_1,args_2,args_3" # CHECK-SAME: outputs = "rets_0,rets_1" diff --git a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/parse_example.mlir b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/parse_example.mlir index e6c041168d5..446af8cfb3f 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/parse_example.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/parse_example.mlir @@ -2,10 +2,10 @@ // CHECK: name: "tf.ParseExample" // CHECK-NEXT: op: "ParseExample" -// CHECK-NEXT: input: "tf.Const3" +// CHECK-NEXT: input: "tf.Const{{_.*_3}}" // CHECK-NEXT: input: "tf.Const" -// CHECK-NEXT: input: "tf.Const1" -// CHECK-NEXT: input: "tf.Const2" +// CHECK-NEXT: input: "tf.Const{{_.*_1}}" +// CHECK-NEXT: input: "tf.Const{{_.*_2}}" // CHECK-NEXT: attr { // CHECK-NEXT: key: "Ndense" // CHECK-NEXT: value { diff --git a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/parse_example_v2.mlir b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/parse_example_v2.mlir index a79a6c772d6..bf69559780c 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/parse_example_v2.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/parse_example_v2.mlir @@ -15,12 +15,12 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr // CHECK: name: "ParseExample" // CHECK-NEXT: op: "ParseExampleV2" // CHECK-NEXT: input: "input0" - // CHECK-NEXT: input: "tf.Const3" - // CHECK-NEXT: input: "tf.Const5" - // CHECK-NEXT: input: "tf.Const2" - // CHECK-NEXT: input: "tf.Const4" + // CHECK-NEXT: input: "tf.Const{{_.*_3}}" + 
// CHECK-NEXT: input: "tf.Const{{_.*_5}}" + // CHECK-NEXT: input: "tf.Const{{_.*_2}}" + // CHECK-NEXT: input: "tf.Const{{_.*_4}}" // CHECK-NEXT: input: "tf.Const" - // CHECK-NEXT: input: "tf.Const1" + // CHECK-NEXT: input: "tf.Const{{_.*_1}}" // CHECK-NEXT: attr { // CHECK-NEXT: key: "Tdense" // CHECK-NEXT: value { diff --git a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/preserve-entry-func-names.mlir b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/preserve-entry-func-names.mlir index 9570d2cdb94..fdbfc839e55 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/preserve-entry-func-names.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/preserve-entry-func-names.mlir @@ -25,4 +25,4 @@ attributes {tf.entry_function = {inputs = "foo,bar", outputs = "Add"}} { // CHECK-NEXT: input: "[[BAR_ID_0]]" // CHECK: name: "Add" // CHECK-NEXT: op: "_Retval" -// CHECK-NEXT: input: "Add1" +// CHECK-NEXT: input: "Add{{_.*_1}}" diff --git a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/tf_add.mlir b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/tf_add.mlir index 0cc07f8816c..d608c8550c6 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/tf_add.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/tf_add.mlir @@ -16,14 +16,14 @@ attributes {tf.entry_function = {inputs = "input0,input1", outputs = "Add"}} { // CHECK-NEXT: name: "input1" // CHECK-NEXT: op: "_Arg" // CHECK: node { -// CHECK-NEXT: name: "Add1" +// CHECK-NEXT: name: "Add{{_.*_1}}" // CHECK-NEXT: op: "Add" // CHECK-NEXT: input: "input0" // CHECK-NEXT: input: "input1" // CHECK: node { // CHECK-NEXT: name: "Add" // CHECK-NEXT: op: "_Retval" -// CHECK-NEXT: input: "Add1" +// CHECK-NEXT: input: "Add{{_.*_1}}" // CHECK-NEXT: attr { // CHECK-NEXT: key: "T" // CHECK-NEXT: value { diff --git a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/unique_name.mlir b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/unique_name.mlir index 72b445341ea..dc569a9e94f 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/unique_name.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/unique_name.mlir @@ -4,11 +4,11 @@ func.func @main() { tf_executor.graph { // CHECK: name: "foo" %0:2 = tf_executor.island wraps "tf.Const"() {dtype = "tfdtype$DT_INT32", value = dense<0> : tensor} : () -> (tensor) loc("foo") - // CHECK: name: "foo1" + // CHECK: name: "foo{{_.*_1}}" %1:2 = tf_executor.island wraps "tf.Const"() {dtype = "tfdtype$DT_INT32", value = dense<1> : tensor} : () -> (tensor) loc("foo") - // CHECK: name: "foo11" + // CHECK: name: "foo1" %2:2 = tf_executor.island wraps "tf.Const"() {dtype = "tfdtype$DT_INT32", value = dense<2> : tensor} : () -> (tensor) loc("foo1") - // CHECK: name: "foo2" + // CHECK: name: "foo{{_.*_2}}" %3:2 = tf_executor.island wraps "tf.Const"() {dtype = "tfdtype$DT_INT32", value = dense<2> : tensor} : () -> (tensor) loc("foo") // CHECK: name: "2" %4:2 = tf_executor.island wraps "tf.Const"() {dtype = "tfdtype$DT_INT32", value = dense<3> : tensor} : () -> (tensor) loc("2") diff --git a/tensorflow/compiler/mlir/tensorflow/tests/outside_compiled_to_host_launch.mlir b/tensorflow/compiler/mlir/tensorflow/tests/outside_compiled_to_host_launch.mlir index e2d94c9c6e7..2f744534abd 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/outside_compiled_to_host_launch.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/outside_compiled_to_host_launch.mlir @@ -48,7 +48,7 @@ module attributes {tf.versions = 
{producer = 888 : i32}, tf.devices = ["/job:wor // CHECK-NEXT: "tf.C" // CHECK-NOT: _xla_outside_compilation // CHECK: tf_device.return - // CHECK-NEXT: device = "TPU_REPLICATED_HOST" + // CHECK-NEXT: device = "TPU_REPLICATED_HOST_0" // CHECK: device_assignment = [], num_cores_per_replica = 1 : i64, topology = "" %0 = "tf.A"(%arg0) : (tensor) -> tensor tf_device.replicate([%0, %arg0] as %ri_0: tensor) {n = 2 : i32} { diff --git a/tensorflow/compiler/mlir/tensorflow/tests/prepare_tpu_computation_for_tf_export.mlir b/tensorflow/compiler/mlir/tensorflow/tests/prepare_tpu_computation_for_tf_export.mlir index 66b0395f00a..45ee57ad75d 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/prepare_tpu_computation_for_tf_export.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/prepare_tpu_computation_for_tf_export.mlir @@ -163,3 +163,13 @@ func.func @UnsupportedOp(%arg0: tensor) -> tensor { func.return %0 : tensor } +// ----- + +// _XlaHostComputeMlir with manual_sharding should not fall back to +// XlaHostCompute, because XlaHostCompute does not support manual_sharding. + +func.func @HostComputeManualNoFallback(%arg0: tensor) -> () { + // expected-error @+1 {{manual_sharding not supported with fallback}} + %1 = "tf._XlaHostComputeMlir"(%arg0) {recv_key = "host_compute_channel_recv1", send_key = "host_compute_channel_send1", host_mlir_module = "", manual_sharding = true} : (tensor) -> (tensor) + func.return +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/replicate_invariant_op_hoisting.mlir b/tensorflow/compiler/mlir/tensorflow/tests/replicate_invariant_op_hoisting.mlir index 2892a011923..b34a26431c0 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/replicate_invariant_op_hoisting.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/replicate_invariant_op_hoisting.mlir @@ -32,6 +32,18 @@ func.func @invariant_shape(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) { // CHECK: tf_device.return %[[SHAPE]] +// CHECK-LABEL: func @not_invariant_ordinal_placeholder +func.func @not_invariant_ordinal_placeholder(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) { + // CHECK: tf_device.replicate + // CHECK: tf._TPUDeviceOrdinalPlaceholder + %0:2 = tf_device.replicate([%arg0, %arg1] as %ri: tensor<*xf32>) {n = 2: i32} { + %1 = "tf._TPUDeviceOrdinalPlaceholder"() : () -> tensor + tf_device.return %1 : tensor + } + func.return +} + + // CHECK-LABEL: func @replicate_resource_var_arg_shape // CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor<*x!tf_type.resource>, %[[ARG_1:[a-z0-9]*]]: tensor<*x!tf_type.resource>) func.func @replicate_resource_var_arg_shape(%arg0: tensor<*x!tf_type.resource>, %arg1: tensor<*x!tf_type.resource>) { @@ -190,3 +202,21 @@ func.func @do_not_hoist_ops_with_virtual_device(%arg0: tensor<*xf32>, %arg1: ten // CHECK: tf_device.return [[OP_C]] : tensor<*xi32> // CHECK: }) {device = "c"} : () -> tensor<*xi32> // CHECK: tf_device.return [[SHAPE]], [[OP_A]], [[LAUNCH_B]], [[LAUNCH_C]] + + +// Checks that the argument to a Shape that has a virtual device is not changed. 
+ +// CHECK-LABEL: func @do_not_mutate_shape_op_with_virtual_device +// CHECK: tf_device.replicate +// CHECK-SAME: as [[RI:%.*]]: tensor<*xf32> +// CHECK: "tf.Shape"([[RI]]) +func.func @do_not_mutate_shape_op_with_virtual_device(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) { + tf_device.replicate([%arg0, %arg1] as %ri: tensor<*xf32>) {devices = {TPU_REPLICATED_HOST_0 = ["/device:CPU:0", "/device:CPU:1"]}, n = 2: i32} { + "tf_device.launch"() ({ + %1 = "tf.Shape"(%ri) {T = "tfdtype$DT_FLOAT", out_type = "tfdtype$DT_INT32"} : (tensor<*xf32>) -> tensor + tf_device.return + }) {device = "TPU_REPLICATED_HOST_0"} : () -> () + tf_device.return + } + func.return +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir b/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir index 3a1e3316c26..71fbf7cca9e 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir @@ -1282,6 +1282,21 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr func.return %1#1 : tensor<*x!quant.uniform> } + // CHECK-LABEL: func @xla_call_module + // CHECK-SAME: (%arg0: tensor) -> tensor + func.func @xla_call_module(%arg0: tensor) -> tensor<*xf32> { + // Equivalent to the following: + // + // module @jit_sin { + // func.func public @main(%arg0: tensor) -> tensor { + // %0 = stablehlo.sine %arg0 : tensor + // return %0 : tensor + // } + // } + %0 = "tf.XlaCallModule"(%arg0) {Sout = [#tf_type.shape<*>], device = "", dim_args_spec = [], module = "ML\EFR\03MLIRxxx-trunk\00\01\17\05\01\05\01\03\05\03\07\07\t\0B\03K5\07\01\1B\07\0B\13\0B3\0B\0B\0B\0B\0F\0B\13\0B\03\1B\0F\1B\0B\0B\0B\0B\0B\0F\13\0B\0B\0B\0B\03\07\0F\17\07\02\A7\1F\05\0D\03\03\03\07\05\0F\03\0B\0B\1B\0D'\0F)\031\113\05\11\05\13\05\15\05\17\1D\15\17\05\19\17\19\EF\01\05\1B\03\03\1D\0D\05\1F!#%\1D\1D\1D\1F\1D!\1D##\03\03\03+\0D\03-/\1D%\1D'\1D)\1D+)\01\05\11\03\01\03\01\t\04A\05\01\11\01\05\07\03\01\05\03\11\01\t\05\03\05\0B\03\01\01\05\06\13\03\01\03\01\07\04\01\03\03\06\03\01\05\01\00\9A\04-\0F\0B\03!\1B\1D\05\1B\83/\1F\15\1D\15\11\13\15\11\11\0F\0B\11builtin\00vhlo\00module\00func_v1\00sine_v1\00return_v1\00sym_name\00jit_sin\00arg_attrs\00function_type\00res_attrs\00sym_visibility\00jit(sin)/jit(main)/sin\00third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py\00jax.arg_info\00x\00mhlo.sharding\00{replicated}\00jax.result_info\00\00main\00public\00", platforms = [], version = 4 : i64} : (tensor) -> tensor<*xf32> + func.return %0 : tensor<*xf32> + } + // CHECK-LABEL: func @xla_host_compute_mlir_empty_module func.func @xla_host_compute_mlir_empty_module(%arg0: tensor<2xf32>) -> tensor<*xf32> { // CHECK: "tf._XlaHostComputeMlir" diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir index 17259f35fc2..74363ecc967 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir @@ -701,6 +701,7 @@ func.func @testConv2D(%arg0: tensor<256x32x3xf32>, %arg1: tensor<3x3x3x16xf32>) // ----- func.func @testConv3D(%arg0: tensor<256x32x32x32x3xf32>, %arg1: tensor<3x3x3x3x16xf32>) -> tensor<256x32x32x16xf32> { + // expected-error @+2 {{'tf.Conv3D' op failed to infer returned types}} // expected-error @+1 {{op inferred type(s) 'tensor<256x32x32x32x16xf32>' are incompatible with return type(s) of operation 'tensor<256x32x32x16xf32>'}} %0 = "tf.Conv3D"(%arg0, %arg1) {padding = "SAME", strides = [1, 
1, 1, 1, 1]} : (tensor<256x32x32x32x3xf32>, tensor<3x3x3x3x16xf32>) -> tensor<256x32x32x16xf32> func.return %0 : tensor<256x32x32x16xf32> @@ -757,6 +758,7 @@ func.func @testConv2D(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<3x3x3x16xf32 // ----- func.func @testConv2D(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<3x3x3x16xf32>) -> tensor<256x30x30x16xf32> { + // expected-error @+2 {{'tf.Conv2D' op failed to infer returned types}} // expected-error @+1 {{op inferred type(s) 'tensor<256x16x11x16xf32>' are incompatible with return type(s) of operation 'tensor<256x30x30x16xf32>'}} %0 = "tf.Conv2D"(%arg0, %arg1) {padding = "SAME", strides = [1, 2, 3, 1]} : (tensor<256x32x32x3xf32>, tensor<3x3x3x16xf32>) -> tensor<256x30x30x16xf32> func.return %0 : tensor<256x30x30x16xf32> @@ -765,6 +767,7 @@ func.func @testConv2D(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<3x3x3x16xf32 // ----- func.func @testConv2D(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<3x3x3x16xf32>) -> tensor<256x16x30x16xf32> { + // expected-error @+2 {{'tf.Conv2D' op failed to infer returned types}} // expected-error @+1 {{op inferred type(s) 'tensor<256x16x11x16xf32>' are incompatible with return type(s) of operation 'tensor<256x16x30x16xf32>'}} %0 = "tf.Conv2D"(%arg0, %arg1) {padding = "SAME", strides = [1, 2, 3, 1]} : (tensor<256x32x32x3xf32>, tensor<3x3x3x16xf32>) -> tensor<256x16x30x16xf32> func.return %0 : tensor<256x16x30x16xf32> @@ -773,6 +776,7 @@ func.func @testConv2D(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<3x3x3x16xf32 // ----- func.func @testConv2D(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<3x3x3x16xf32>) -> tensor<256x32x32x16xf32> { + // expected-error @+2 {{'tf.Conv2D' op failed to infer returned types}} // expected-error @+1 {{op inferred type(s) 'tensor<256x6x6x16xf32>' are incompatible with return type(s) of operation 'tensor<256x32x32x16xf32>'}} %0 = "tf.Conv2D"(%arg0, %arg1) {padding = "EXPLICIT", dilations = [1, 2, 3, 4], explicit_paddings = [1, 2, 3, 4, 5, 6, 7, 8], strides = [5, 6, 7, 8]} : (tensor<256x32x32x3xf32>, tensor<3x3x3x16xf32>) -> tensor<256x32x32x16xf32> func.return %0 : tensor<256x32x32x16xf32> @@ -781,6 +785,7 @@ func.func @testConv2D(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<3x3x3x16xf32 // ----- func.func @testConv2D(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<3x3x3x16xf32>) -> tensor<256x32x32x16xf32> { + // expected-error @+2 {{'tf.Conv2D' op failed to infer returned types}} // expected-error @+1 {{op inferred type(s) 'tensor<256x30x30x16xf32>' are incompatible with return type(s) of operation 'tensor<256x32x32x16xf32>'}} %0 = "tf.Conv2D"(%arg0, %arg1) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<256x32x32x3xf32>, tensor<3x3x3x16xf32>) -> tensor<256x32x32x16xf32> func.return %0 : tensor<256x32x32x16xf32> @@ -2554,6 +2559,7 @@ func.func @testConst() -> tensor { // Test invalid tf.ToBool func.func @testInvalidToBool(%arg0: tensor) -> tensor<1xi1> { + // expected-error @+2 {{'tf.ToBool' op failed to infer returned types}} // expected-error @+1 {{op inferred type(s) 'tensor' are incompatible with return type(s) of operation 'tensor<1xi1>'}} %0 = "tf.ToBool"(%arg0) : (tensor) -> tensor<1xi1> func.return %0 : tensor<1xi1> @@ -2639,7 +2645,7 @@ func.func @testTranspose(tensor<2x3xf32>) -> tensor<3x2xf32> { func.func @testTranspose(tensor<2x2xf32>) -> tensor<2x2xf32> { ^bb0(%arg0: tensor<2x2xf32>): %cst = arith.constant dense<[1, -3]> : tensor<2xi32> - // expected-error @+1 {{perm[-1] must be in [-rank, rank)}} + // expected-error @+1 {{'tf.Transpose' op perm[1]=-3 must be 
in range [-2, 2)}} %0 = "tf.Transpose"(%arg0, %cst) {T = "tfdtype$DT_FLOAT", Tperm = "tfdtype$DT_INT32"} : (tensor<2x2xf32>, tensor<2xi32>) -> tensor<2x2xf32> func.return %0 : tensor<2x2xf32> } @@ -4341,6 +4347,7 @@ func.func @testVarHandleOp() -> tensor<*x!tf_type.resource> { func.func @testXlaBroadcastHelper(%arg0: tensor<2x3x5xi32>, %arg1: tensor<5x2xi32>) -> () { %0 = "tf.Const"() {value = dense<2> : tensor<1xi64>} : () -> tensor<1xi64> + // expected-error @+2 {{'tf.XlaBroadcastHelper' op failed to infer returned types}} // expected-error @+1 {{broadcast_dims must have size equal to the smaller argument rank}} %lhs_output, %rhs_output = "tf.XlaBroadcastHelper"(%arg0, %arg1, %0) : (tensor<2x3x5xi32>, tensor<5x2xi32>, tensor<1xi64>) -> (tensor<2x3x5xi32>, tensor<2x1x5xi32>) func.return @@ -4350,6 +4357,7 @@ func.func @testXlaBroadcastHelper(%arg0: tensor<2x3x5xi32>, %arg1: tensor<5x2xi3 func.func @testXlaBroadcastHelper(%arg0: tensor<2x3x5xi32>, %arg1: tensor<5x2xi32>) -> () { %0 = "tf.Const"() {value = dense<> : tensor<0xi64>} : () -> tensor<0xi64> + // expected-error @+2 {{'tf.XlaBroadcastHelper' op failed to infer returned types}} // expected-error @+1 {{if broadcast_dims is empty, both arguments must have equal rank or at least one argument must be a scalar}} %lhs_output, %rhs_output = "tf.XlaBroadcastHelper"(%arg0, %arg1, %0) : (tensor<2x3x5xi32>, tensor<5x2xi32>, tensor<0xi64>) -> (tensor<2x3x5xi32>, tensor<2x1x5xi32>) func.return @@ -4359,6 +4367,7 @@ func.func @testXlaBroadcastHelper(%arg0: tensor<2x3x5xi32>, %arg1: tensor<5x2xi3 func.func @testXlaBroadcastHelper(%arg0: tensor<5x2xi32>, %arg1: tensor<2x3x5xi32>) -> () { %0 = "tf.Const"() {value = dense<0> : tensor<2xi64>} : () -> tensor<2xi64> + // expected-error @+2 {{'tf.XlaBroadcastHelper' op failed to infer returned types}} // expected-error @+1 {{broadcast_dims has duplicates}} %lhs_output, %rhs_output = "tf.XlaBroadcastHelper"(%arg0, %arg1, %0) : (tensor<5x2xi32>, tensor<2x3x5xi32>, tensor<2xi64>) -> (tensor<2x1x5xi32>, tensor<2x3x5xi32>) func.return @@ -4674,6 +4683,7 @@ func.func @testReluStaticShapeInputAndDynamicShapeOutput(%arg0: tensor<8x16xf32> func.func @set_dynamic_dimension_size(%input: tensor<4xf32>, %size: tensor) -> tensor { %dimension = "tf.Const"() { value = dense<1> : tensor } : () -> tensor + // expected-error @+2 {{'tf.XlaSetDynamicDimensionSize' op failed to infer returned types}} // expected-error @+1 {{dim_index (1) is out of range [0, 1)}} %0 = "tf.XlaSetDynamicDimensionSize"(%input, %dimension, %size) : (tensor<4xf32>, tensor, tensor) -> tensor func.return %0 : tensor @@ -5023,6 +5033,52 @@ func.func @testUniformQuantizedConvolution( func.return } +// ----- + +func.func @testUniformQuantizedAdd( + %input: tensor<2x2x!tf_type.qint32>, %bias: tensor<2x!tf_type.qint32>, + %input_scales: tensor, %input_zps: tensor, + %bias_scales: tensor, %bias_zps: tensor, + %output_scales: tensor<2xf32>, %output_zps: tensor) -> () { + // expected-error @below {{'tf.UniformQuantizedAdd' op quantization_axis is -1, scales must have 0 rank.}} + %1 = "tf.UniformQuantizedAdd"( + %input, %bias, + %input_scales, %input_zps, + %bias_scales, %bias_zps, + %output_scales, %output_zps) { + lhs_quantization_axis = -1 : i64, + lhs_quantization_min_val = -2147483648 : i64, + lhs_quantization_max_val = 2147483647 : i64, + rhs_quantization_axis = -1 : i64, + rhs_quantization_min_val = -2147483648 : i64, + rhs_quantization_max_val = 2147483647 : i64, + output_quantization_axis = -1 : i64, + output_quantization_min_val = -2147483648 : 
i64, + output_quantization_max_val = 2147483647 : i64} : ( + tensor<2x2x!tf_type.qint32>, tensor<2x!tf_type.qint32>, + tensor, tensor, + tensor, tensor, + tensor<2xf32>, tensor) -> tensor<2x2x!tf_type.qint32> + func.return +} + +// ----- + +func.func @testUniformQuantizedClipByValue( + %operand: tensor<*x!tf_type.qint32>, %min: tensor, %max: tensor, + %scales: tensor<2xf32>, %zps: tensor) -> () { + // expected-error @below {{'tf.UniformQuantizedClipByValue' op quantization_axis is -1, scales must have 0 rank.}} + %0 = "tf.UniformQuantizedClipByValue"(%operand, %min, %max, %scales, %zps) { + quantization_axis = -1 : i64, + quantization_min_val = -2147483648 : i64, + quantization_max_val = 2147483647 : i64 + } : ( + tensor<*x!tf_type.qint32>, tensor, tensor, + tensor<2xf32>, tensor + ) -> tensor<*x!tf_type.qint32> + func.return +} + // Following tests are for LegacyCall symbol use verifier. // ----- diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_executor_ops_location_roundtrip.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf_executor_ops_location_roundtrip.mlir index d5e1a637b7f..fd69ecd1436 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf_executor_ops_location_roundtrip.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_executor_ops_location_roundtrip.mlir @@ -18,6 +18,7 @@ // tf_executor.island, tf.Identity, and tf_executor.yield). // CHECK-LABEL: "func.func" +// CHECK: sym_name = "island_one_op_all_locs_same" // CHECK: "tf_executor.graph"() ({ // CHECK-NEXT: "tf_executor.island"() ({ // CHECK-NEXT: "tf.Identity"(%{{.*}}) : (tensor) -> tensor loc("identity@some_function") @@ -26,7 +27,6 @@ // CHECK-NEXT: "tf_executor.fetch"(%{{.*}}) : (tensor) -> () loc(unknown) // CHECK-NEXT: }) : () -> tensor loc(unknown) // CHECK-NEXT: "func.return"(%{{.*}}) : (tensor) -> () loc(unknown) -// CHECK-NEXT: sym_name = "island_one_op_all_locs_same" func.func @island_one_op_all_locs_same(%arg0: tensor) -> tensor { %0 = "tf_executor.graph"() ({ @@ -45,6 +45,7 @@ func.func @island_one_op_all_locs_same(%arg0: tensor) -> tensor { // don't have identical locations. // CHECK-LABEL: "func.func" +// CHECK: sym_name = "island_one_op_all_locs_NOT_same" // CHECK: "tf_executor.graph"() ({ // CHECK-NEXT: "tf_executor.island"() ({ // CHECK-NEXT: "tf.Identity"(%{{.*}}) : (tensor) -> tensor loc("identity@some_function") @@ -53,7 +54,6 @@ func.func @island_one_op_all_locs_same(%arg0: tensor) -> tensor { // CHECK-NEXT: "tf_executor.fetch"(%{{.*}}) : (tensor) -> () loc(unknown) // CHECK-NEXT: }) : () -> tensor loc(unknown) // CHECK-NEXT: "func.return"(%{{.*}}) : (tensor) -> () loc(unknown) -// CHECK-NEXT: sym_name = "island_one_op_all_locs_NOT_same" func.func @island_one_op_all_locs_NOT_same(%arg0: tensor) -> tensor { %0 = "tf_executor.graph"() ({ diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/common_v1.py b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/common_v1.py index ef0a95d756f..b8e7715c593 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/common_v1.py +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/common_v1.py @@ -48,6 +48,7 @@ def do_test( show_debug_info=False, use_lite=False, lift_variables=True, + include_variables_in_initializers=False, ): """Runs test. @@ -70,6 +71,9 @@ def do_test( use_lite: If true, importer will not do any graph transformation such as lift variables. lift_variables: If false, no variable lifting will be done on the graph. 
+ include_variables_in_initializers: If false, removes variables in + initializer functions before lifting variables or adding new variable + initialization patterns in the initializer function. """ # Make LOG(ERROR) in C++ code show up on the console. @@ -124,6 +128,7 @@ def do_test( exported_names, ','.join([tf.saved_model.tag_constants.SERVING]), lift_variables, + include_variables_in_initializers, upgrade_legacy, show_debug_info, ) diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/include_variables_in_init_v1.py b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/include_variables_in_init_v1.py new file mode 100644 index 00000000000..2f99fae8d8d --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/include_variables_in_init_v1.py @@ -0,0 +1,87 @@ +# Copyright 2023 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +# RUN: %p/include_variables_in_init_v1 | FileCheck %s + +# pylint: disable=missing-docstring,line-too-long +import tensorflow.compat.v1 as tf +from tensorflow.compiler.mlir.tensorflow.tests.tf_saved_model import common_v1 + +# Verify that the tf.versions attribute exists. It is difficult to enforce +# contents, since the version numbers change over time. The conversion logic +# itself is verified in the common graphdef converter, so here just assert +# it is being invoked. +# CHECK: module +# CHECK-SAME: tf.versions +# CHECK-SAME: bad_consumers +# CHECK-SAME: min_consumer +# CHECK-SAME: producer + +# CHECK: "tf_saved_model.global_tensor"() +# CHECK: "tf_saved_model.session_initializer"() {initializers = [@[[INIT_FUNC:[a-zA-Z_0-9]+]]]} : () -> () + +# Initializer function. This should contain the initialization sequence for the +# variable. +# CHECK: func @[[INIT_FUNC]](%[[ARG_0:.*]]: tensor>> {tf_saved_model.bound_input = @y}) attributes { +# CHECK-SAME: tf_saved_model.exported_names = ["__tf_saved_model_session_initializer_init"] +# CHECK-SAME: tf_saved_model.initializer_type = "init_op" +# CHECK-SAME: } +# CHECK-DAG: %[[CST_0:.*]] = "tf.Const"() {{{.*dense<.*> : tensor<2xi32>.*}}} : () -> tensor<2xi32> +# CHECK: %[[RAND_STD_NORMAL:.*]] = "tf.RandomStandardNormal"(%[[CST_0]]) +# CHECK: "tf.AssignVariableOp"(%[[ARG_0]], %[[RAND_STD_NORMAL]]){{.*}}: (tensor>>, tensor<1x3xf32>) -> () +# CHECK: return + +# The function for the signature "key". 
+# CHECK: func {{@[a-zA-Z_0-9]+}}( +# CHECK-SAME: %[[ARG_1:.*]]: tensor<3x1xf32> {tf_saved_model.index_path = ["x"]} +# CHECK-SAME: %[[ARG_2:.*]]: tensor>> {tf_saved_model.bound_input = @y} +# CHECK-SAME: -> (tensor<3x3xf32> {tf_saved_model.index_path = ["r"]}) +# CHECK-SAME: attributes {{.*}} tf_saved_model.exported_names = ["key"] +# CHECK-NEXT: %[[READ_VAR_0:.*]] = "tf.ReadVariableOp"(%[[ARG_2]]) {{{.*}}} : (tensor>>) -> tensor<1x3xf32> +# CHECK-NEXT: %[[MATMUL_0:.*]] = "tf.MatMul"(%[[ARG_1]], %[[READ_VAR_0]]) {{{.*}}} : (tensor<3x1xf32>, tensor<1x3xf32>) -> tensor<3x3xf32> +# CHECK-NEXT: return %[[MATMUL_0]] : tensor<3x3xf32> + + +def Test(): + x = tf.constant([[1.0], [1.0], [1.0]]) + y = tf.compat.v1.get_variable( + name='y', + shape=(1, 3), + initializer=tf.random_normal_initializer(), + trainable=True, + ) + r = tf.matmul(x, y) + + tensor_info_x = tf.compat.v1.saved_model.utils.build_tensor_info(x) + tensor_info_r = tf.compat.v1.saved_model.utils.build_tensor_info(r) + + return ( + { + 'key': ( + tf.compat.v1.saved_model.signature_def_utils.build_signature_def( + inputs={'x': tensor_info_x}, + outputs={'r': tensor_info_r}, + method_name='some_function', + ) + ) + }, + tf.initializers.global_variables(), + None, + ) + + +if __name__ == '__main__': + common_v1.set_tf_options() + common_v1.do_test(Test, include_variables_in_initializers=True) diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu-annotate-dynamic-shape-inputs.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu-annotate-dynamic-shape-inputs.mlir new file mode 100644 index 00000000000..accdbdacca8 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu-annotate-dynamic-shape-inputs.mlir @@ -0,0 +1,28 @@ +// RUN: tf-opt -split-input-file -verify-diagnostics -tf-tpu-annotate-dynamic-shape-inputs %s | FileCheck %s + +// Test that annotate the inputs of the cluster func to be dynamic shaped. 
+ +module attributes {tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { + func.func @main( + %arg0: tensor<2048xi64> {tf.device = "/job:localhost/replica:0/task:0/device:CPU:0"}, + %arg1: tensor<2048xi64> {tf.device = "/job:localhost/replica:0/task:0/device:CPU:0"}) -> tensor<2048xi32> { + %cst = "tf.Const"() {value = dense<1024> : tensor} : () -> tensor + %0:2 = "tf_device.launch"() ({ + %1 = "tf.Cast"(%arg0) {Truncate = false} : (tensor<2048xi64>) -> tensor<2048xi32> + %2 = "tf.Cast"(%arg1) {Truncate = false} : (tensor<2048xi64>) -> tensor<2048xi32> + %3:2 = "tf.TPUCopyWithDynamicShape"(%1, %2, %cst, %cst) {operand_segment_sizes = array} : (tensor<2048xi32>, tensor<2048xi32>, tensor, tensor) -> (tensor<2048xi32>, tensor<2048xi32>) + // CHECK-NOT: tf.TPUAnnotateTensorsWithDynamicShape + %4:2 = "tf.TPUAnnotateTensorsWithDynamicShape"(%3#0, %3#1) : (tensor<2048xi32>, tensor<2048xi32>) -> (tensor<2048xi32>, tensor<2048xi32>) + tf_device.return %4#0, %4#1 : tensor<2048xi32>, tensor<2048xi32> + }) {device = "TPU_REPLICATED_HOST_0"} : () -> (tensor<2048xi32>, tensor<2048xi32>) + %1 = "tf_device.cluster_func"(%0#0, %0#1) {_replication_info = "cluster_test_fn", func = @tpu_func} : (tensor<2048xi32>, tensor<2048xi32>) -> tensor<2048xi32> + return %1: tensor<2048xi32> + } + // CHECK-LABEL: func @tpu_func + // CHECK: mhlo.type_extensions + func.func @tpu_func ( + %arg0: tensor<2048xi32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg1: tensor<2048xi32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}) -> (tensor<2048xi32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}) { + %0 = "tf.AddV2"(%arg0, %arg1) : (tensor<2048xi32>, tensor<2048xi32>) -> tensor<2048xi32> + return %0 : tensor<2048xi32> + } +} \ No newline at end of file diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu_cluster_formation.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu_cluster_formation.mlir index 355085be8b4..db266ed4afe 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tpu_cluster_formation.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu_cluster_formation.mlir @@ -540,6 +540,90 @@ func.func @replicated_non_replicated_output() { // ----- +// TF produces Identity ops between TPUReplicatedOutput and +// TPUPartitionedOutputV2 ops. This test ensures that they are erased +// and not considered within the clustered computation. It also ensures that +// the expected interleaving pattern is present in the output. 
+ +func.func @partitioned_outputs(%arg0: tensor) -> (tensor, tensor, tensor, tensor) { + %pi0 = "tf.TPUPartitionedInputV2"(%arg0) {N = 2, partition_dims = [], _XlaSharding = "", is_packed = true} : (tensor) -> (tensor) + %pi1 = "tf.TPUPartitionedInputV2"(%arg0) {N = 2, partition_dims = [], _XlaSharding = "", is_packed = true} : (tensor) -> (tensor) + %1 = "tf.TPUReplicatedInput"(%pi0, %pi1) {is_mirrored_variable = true, is_packed = false} : (tensor, tensor) -> (tensor) + %2 = "tf.opA"(%1) {_xla_compile_device_type = "TPU", _replication_info = "replicate", device = "/device:TPU:0", is_stateless = true} : (tensor) -> (tensor) + %3:2 = "tf.TPUReplicatedOutput"(%2) : (tensor) -> (tensor, tensor) + %4 = "tf.Identity"(%3#0) : (tensor) -> (tensor) + %5:2 = "tf.TPUPartitionedOutputV2"(%4) {_XlaSharding = "", partition_dims = []} : (tensor) -> (tensor, tensor) + %6 = "tf.Identity"(%3#1) : (tensor) -> (tensor) + %7:2 = "tf.TPUPartitionedOutputV2"(%6) {_XlaSharding = "", partition_dims = []} : (tensor) -> (tensor, tensor) + "tf.TPUReplicateMetadata"() {_xla_compile_device_type = "TPU", _replication_info = "replicate", device = "/device:TPU:0", num_replicas = 2, num_cores_per_replica = 2, topology = "topology"} : () -> () + func.return %5#0, %5#1, %7#0, %7#1 : tensor, tensor, tensor, tensor +} + +// CHECK: [[REPLICATE:%.+]]:4 = tf_device.replicate +// CHECK: return [[REPLICATE]]#0, [[REPLICATE]]#2, [[REPLICATE]]#1, [[REPLICATE]]#3 + +// ----- + +// Ensures that mixed partitioned and replicated outputs +// works in the multi-replica case. +func.func @mixed_partitioned_outputs(%arg0: tensor) -> (tensor, tensor) { + %pi0 = "tf.TPUPartitionedInputV2"(%arg0) {N = 2, partition_dims = [], _XlaSharding = "", is_packed = true} : (tensor) -> (tensor) + %pi1 = "tf.TPUPartitionedInputV2"(%arg0) {N = 2, partition_dims = [], _XlaSharding = "", is_packed = true} : (tensor) -> (tensor) + %1 = "tf.TPUReplicatedInput"(%pi0, %pi1) {is_mirrored_variable = true, is_packed = false} : (tensor, tensor) -> (tensor) + %2:2 = "tf.opA"(%1) {_xla_compile_device_type = "TPU", _replication_info = "replicate", device = "/device:TPU:0", is_stateless = true} : (tensor) -> (tensor, tensor) + %3:2 = "tf.TPUReplicatedOutput"(%2#0) : (tensor) -> (tensor, tensor) + %5:2 = "tf.TPUPartitionedOutputV2"(%3#0) {_XlaSharding = "", partition_dims = []} : (tensor) -> (tensor, tensor) + %7:2 = "tf.TPUPartitionedOutputV2"(%3#1) {_XlaSharding = "", partition_dims = []} : (tensor) -> (tensor, tensor) + %8:2 = "tf.TPUReplicatedOutput"(%2#1) : (tensor) -> (tensor, tensor) + %9 = "tf.opB"(%5#0, %5#1, %7#0, %7#1) : (tensor, tensor, tensor, tensor) -> (tensor) + %10 = "tf.opC"(%8#0, %8#1) : (tensor, tensor) -> (tensor) + "tf.TPUReplicateMetadata"() {_xla_compile_device_type = "TPU", _replication_info = "replicate", device = "/device:TPU:0", num_replicas = 2, num_cores_per_replica = 2, topology = "topology"} : () -> () + func.return %9, %10 : tensor, tensor +} + +// CHECK: [[REPLICATE:%.+]]:6 = tf_device.replicate +// CHECK: [[OP_B:%.+]] = "tf.opB"([[REPLICATE]]#0, [[REPLICATE]]#2, [[REPLICATE]]#1, [[REPLICATE]]#3) +// CHECK: [[OP_C:%.+]] = "tf.opC"([[REPLICATE]]#4, [[REPLICATE]]#5) + +// ----- + +// For the single replica case: +// - Ensures that Identity ops are ignored. +// - Checks that mixing TPUPartitionedOutputV2 and TPUReplicatedOutput works. 
+ +func.func @single_replica_mixed_partitioned_outputs(%arg0: tensor) -> (tensor, tensor, tensor) { + %0 = "tf.TPUPartitionedInputV2"(%arg0) {N = 2, partition_dims = [], _XlaSharding = "", is_packed = true} : (tensor) -> (tensor) + %1 = "tf.TPUReplicatedInput"(%0) {is_mirrored_variable = true, is_packed = false} : (tensor) -> (tensor) + %2:2 = "tf.opA"(%1) {_xla_compile_device_type = "TPU", _replication_info = "replicate", device = "/device:TPU:0", is_stateless = true} : (tensor) -> (tensor, tensor) + %3 = "tf.TPUReplicatedOutput"(%2#0) : (tensor) -> (tensor) + %4 = "tf.Identity"(%3) : (tensor) -> (tensor) + %5:2 = "tf.TPUPartitionedOutputV2"(%4) {_XlaSharding = "", partition_dims = []} : (tensor) -> (tensor, tensor) + %6 = "tf.TPUReplicatedOutput"(%2#1) : (tensor) -> (tensor) + "tf.TPUReplicateMetadata"() {_xla_compile_device_type = "TPU", _replication_info = "replicate", device = "/device:TPU:0", num_replicas = 1, num_cores_per_replica = 2, topology = "topology"} : () -> () + func.return %5#0, %5#1, %6 : tensor, tensor, tensor +} + +// CHECK: [[CLUSTER:%.+]]:2 = "tf_device.cluster" +// CHECK: [[OUTPUT:%.+]]:2 = "tf.TPUPartitionedOutputV2"([[CLUSTER]]#0) +// CHECK: return [[OUTPUT]]#0, [[OUTPUT]]#1, [[CLUSTER]]#1 + +// ----- + +func.func @replica_mismatch(%arg0: tensor) { + %pi0 = "tf.TPUPartitionedInputV2"(%arg0) {N = 2, partition_dims = [], _XlaSharding = "", is_packed = true} : (tensor) -> (tensor) + %pi1 = "tf.TPUPartitionedInputV2"(%arg0) {N = 2, partition_dims = [], _XlaSharding = "", is_packed = true} : (tensor) -> (tensor) + %1 = "tf.TPUReplicatedInput"(%pi0, %pi1) {is_mirrored_variable = true, is_packed = false} : (tensor, tensor) -> (tensor) + %2 = "tf.opA"(%1) {_xla_compile_device_type = "TPU", _replication_info = "replicate", device = "/device:TPU:0", is_stateless = true} : (tensor) -> (tensor) + %3:2 = "tf.TPUReplicatedOutput"(%2) : (tensor) -> (tensor, tensor) + %4 = "tf.Identity"(%3#0) : (tensor) -> (tensor) + // expected-error@+1 {{expected zero or 2 'TPUPartitionedOutput' op(s), instead got 1}} + %5:2 = "tf.TPUPartitionedOutputV2"(%4) {_XlaSharding = "", partition_dims = []} : (tensor) -> (tensor, tensor) + "tf.TPUReplicateMetadata"() {_xla_compile_device_type = "TPU", _replication_info = "replicate", device = "/device:TPU:0", num_replicas = 2, num_cores_per_replica = 2, topology = "topology"} : () -> () + func.return +} + +// ----- + // Test cluster with missing `num_replicas` attribute. func.func @missing_num_replicas() { @@ -707,6 +791,19 @@ func.func @valid_compilation_cluster_no_replication_op_device() { // ----- +// Check conflicting device names +// CHECK: "tf_device.cluster"() +// CHECK: "tf.opA"() +// CHECK: "tf.opB"() +// CHECK-NOT: device = +func.func @do_nothing_if_short_names_conflict() { + "tf.opA"() { _xla_compile_device_type = "TPU", device = "/replica:1/task:2/device:TPU:1"} : () -> () + "tf.opB"() { _xla_compile_device_type = "TPU", device = "/replica:3/task:4/device:TPU:1"} : () -> () + func.return +} + +// ----- + // Check non-replicated case, including expected device attr in cluster. 
// CHECK: "tf_device.cluster"() // CHECK: "tf.opA"() @@ -924,4 +1021,25 @@ func.func @gpu_device() { func.return } +// ----- +// CHECK-LABEL: func @gather_nd +func.func @gather_nd(%arg0: tensor<*x!tf_type.resource>>, + %arg1: tensor<3xf32>) { + // CHECK: ResourceGatherNd + // CHECK: tf_device.cluster + // CHECK: Add + // CHECK: ResourceGatherNd + %0 = "tf.Const"() {value = dense<32> : tensor} : () -> tensor + %1 = "tf.ResourceGatherNd"(%arg0, %0) { + Tindices = i32 + } : (tensor<*x!tf_type.resource>>, tensor) -> tensor<1x80xf32> + %2 = "tf.Add"(%1, %1) { + _xla_compile_device_type = "TPU", + device = "/task:0/device:TPU:0", dtype = f32 + } : (tensor<1x80xf32>, tensor<1x80xf32>) -> tensor<1x80xf32> + %3 = "tf.ResourceGatherNd"(%arg0, %0) { + Tindices = i32 + } : (tensor<*x!tf_type.resource>>, tensor) -> tensor<1x80xf32> + func.return +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu_colocate_composite_resource_ops.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu_colocate_composite_resource_ops.mlir index 05a3f483767..b2896fa543f 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tpu_colocate_composite_resource_ops.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu_colocate_composite_resource_ops.mlir @@ -101,7 +101,7 @@ func.func @testNonTPUDeviceReplicationIgnored(%arg0: tensor<*x!tf_type.resource< // CHECK-SAME: (%[[ARG0]] as %[[RI_0:[a-z0-9]*]]: tensor<*x!tf_type.resource>>) tf_device.replicate(%arg0 as %arg1: tensor<*x!tf_type.resource>>) { _mirrored_variable_indices = [0], - devices = {TPU_REPLICATED_HOST = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:CPU:1"]}, + devices = {TPU_REPLICATED_HOST_0 = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:CPU:1"]}, n = 2 : i32} { // CHECK: %[[VAL_OUT:.*]] = "tf.A"() : () -> tensor<4xf32> // CHECK-NEXT: "tf.AssignVariableOp"(%[[RI_0]], %[[VAL_OUT]]) @@ -111,7 +111,7 @@ func.func @testNonTPUDeviceReplicationIgnored(%arg0: tensor<*x!tf_type.resource< "tf_device.launch"() ({ "tf.TPUExecuteAndUpdateVariables"(%arg1, %2) {device_var_reads_indices = [0], device_var_updates_indices = [-1]} : (tensor<*x!tf_type.resource>>, tensor<2x!tf_type.string>) -> () tf_device.return - }) {device = "TPU_REPLICATED_HOST"} : () -> () + }) {device = "TPU_REPLICATED_HOST_0"} : () -> () tf_device.return } func.return diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu_colocate_splits.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu_colocate_splits.mlir new file mode 100644 index 00000000000..7c97e85c081 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu_colocate_splits.mlir @@ -0,0 +1,44 @@ +// RUN: tf-opt %s -split-input-file -verify-diagnostics -tf-tpu-colocate-splits | FileCheck %s + +// CHECK-LABEL: func @colocate_split_with_pred +func.func @colocate_split_with_pred() { + // CHECK: Split + // CHECK-SAME: _class = ["loc:@class"] + tf_executor.graph { + %c, %control0 = tf_executor.island wraps "tf.Const"() {value = dense<0> : tensor} : () -> tensor + %a, %control1 = tf_executor.island wraps "tf.A"() {_class = ["loc:@class"]} : () -> (tensor<2xf32>) + %s:2, %control2 = tf_executor.island wraps "tf.Split"(%c, %a) {num_split = 2 : i32} : (tensor, tensor<2xf32>) -> (tensor<1xf32>, tensor<1xf32>) + tf_executor.fetch + } + func.return +} + +// ----- + +// CHECK-LABEL: func @colocate_split_with_pred_results +func.func @colocate_split_with_pred_results() { + // CHECK: Split + // CHECK-SAME: _class = ["loc:@class"] + tf_executor.graph { + %c, %control0 = 
tf_executor.island wraps "tf.Const"() {value = dense<0> : tensor} : () -> tensor + %a:2, %control1 = tf_executor.island wraps "tf.A"() {_class = ["loc:@class"]} : () -> (tensor<2xf32>, tensor<2xf32>) + %s:2, %control2 = tf_executor.island wraps "tf.Split"(%c, %a#1) {num_split = 2 : i32} : (tensor, tensor<2xf32>) -> (tensor<1xf32>, tensor<1xf32>) + tf_executor.fetch + } + func.return +} + +// ----- + +// CHECK-LABEL: func @no_colocate_split_has_device +func.func @no_colocate_split_has_device() { + // CHECK: Split + // CHECK-NOT: _class = ["loc:@class"] + tf_executor.graph { + %c, %control0 = tf_executor.island wraps "tf.Const"() {value = dense<0> : tensor} : () -> tensor + %a, %control1 = tf_executor.island wraps "tf.A"() {_class = ["loc:@class"]} : () -> tensor<2xf32> + %s:2, %control2 = tf_executor.island wraps "tf.Split"(%c, %a) {num_split = 2 : i32, device = "device"} : (tensor, tensor<2xf32>) -> (tensor<1xf32>, tensor<1xf32>) + tf_executor.fetch + } + func.return +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu_resource_partitioning.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu_resource_partitioning.mlir index 0e1c5c79e22..91e4ff2b714 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tpu_resource_partitioning.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu_resource_partitioning.mlir @@ -190,7 +190,7 @@ func.func @with_host_process(%arg0: tensor>>, %arg "tf_device.launch"() ({ "tf.OpA"(%1) : (tensor) -> () tf_device.return - }) {device = "TPU_REPLICATED_HOST"} : () -> () + }) {device = "TPU_REPLICATED_HOST_0"} : () -> () tf_device.return }, { %3 = "tf_device.cluster_func"(%1) {func = @computation, use_spmd_for_xla_partitioning = true} : (tensor) -> tensor @@ -231,7 +231,7 @@ func.func @non_replicated_sharding(%arg0: tensor>> "tf_device.launch"() ({ "tf.OpA"(%1) : (tensor) -> () tf_device.return - }) {device = "TPU_REPLICATED_HOST"} : () -> () + }) {device = "TPU_REPLICATED_HOST_0"} : () -> () tf_device.return }, { %3 = "tf_device.cluster_func"(%1) {func = @computation, use_spmd_for_xla_partitioning = true} : (tensor) -> tensor @@ -251,7 +251,7 @@ func.func @packed_replicated(%arg0: tensor>> {tf.d "tf_device.launch"() ({ "tf.OpA"(%1) : (tensor) -> () tf_device.return - }) {device = "TPU_REPLICATED_HOST"} : () -> () + }) {device = "TPU_REPLICATED_HOST_0"} : () -> () tf_device.return }, { %3 = "tf_device.cluster_func"(%1) {func = @computation, use_spmd_for_xla_partitioning = true} : (tensor) -> tensor diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu_rewrite.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu_rewrite.mlir index c81a69f791f..5896c243e2c 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tpu_rewrite.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu_rewrite.mlir @@ -752,7 +752,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate // CHECK-SAME: ([%[[A_OUTPUT]], %[[ARG_0]]] as %[[RI_0:[a-z0-9]*]]: tensor) - // CHECK-SAME: devices = {TPU_REPLICATED_CORE_0 = ["/job:worker/replica:0/task:0/device:TPU:0", "/job:worker/replica:0/task:0/device:TPU:1"], TPU_REPLICATED_HOST = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:CPU:0"]} + // CHECK-SAME: devices = {TPU_REPLICATED_CORE_0 = ["/job:worker/replica:0/task:0/device:TPU:0", "/job:worker/replica:0/task:0/device:TPU:1"], TPU_REPLICATED_HOST_0 = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:CPU:0"]} // CHECK-SAME: n = 2 
%1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor) {n = 2 : i32} { // CHECK: %[[A_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[RI_0]]) @@ -1585,7 +1585,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:loc // CHECK-LABEL: func @replicated_parallel_execute func.func @replicated_parallel_execute(%arg0: tensor<8xi32>, %arg1: tensor<8xi32>) -> (tensor<8xi32>, tensor<8xi32>) { // CHECK: tf_device.replicate - // CHECK-SAME: devices = {TPU_REPLICATED_CORE_0 = ["/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:1/device:TPU:1"], TPU_REPLICATED_CORE_1 = ["/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:1/device:TPU:0"], TPU_REPLICATED_HOST = ["/job:localhost/replica:0/task:0/device:CPU:0", "/job:localhost/replica:0/task:1/device:CPU:0"]} + // CHECK-SAME: devices = {TPU_REPLICATED_CORE_0 = ["/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:1/device:TPU:1"], TPU_REPLICATED_CORE_1 = ["/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:1/device:TPU:0"], TPU_REPLICATED_HOST_0 = ["/job:localhost/replica:0/task:0/device:CPU:0", "/job:localhost/replica:0/task:1/device:CPU:0"], TPU_REPLICATED_HOST_1 = ["/job:localhost/replica:0/task:0/device:CPU:0", "/job:localhost/replica:0/task:1/device:CPU:0"]} %0:2 = tf_device.replicate([%arg0, %arg1] as %ri: tensor<8xi32>) {n = 2 : i32} { // CHECK-NEXT: %[[COMPILE:[a-z0-9]+]]:3 = "tf_device.launch" // CHECK-NEXT: "tf._TPUCompileMlir"() @@ -2643,3 +2643,90 @@ module attributes {tf.devices = {"/job:localhost/replica:0/task:0/device:CPU:0", return %1 : tensor } } + +// ----- + +// The following xla.OpSharding is used: +// Proto debug string: +// type : OTHER +// tile_assignment_dimensions: 1 +// tile_assignment_dimensions: 1 +// tile_assignment_dimensions: 1 +// tile_assignment_dimensions: 1 +// tile_assignment_devices: 0 +// last_tile_dims: REPLICATED +// Serialized string: +// "\08\03\1A\06\01\01\01\01\01\01\22\01\00B\01\00" + +// Test that an input sharding with last_tile_dims REPLICATED won't generate SplitOp. 
+//CHECK-NOT: tf.Split +module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0", "/job:worker/replica:0/task:0/device:TPU:1"]} { + func.func @cluster_to_single_core(%arg0: tensor<128xf32>) -> tensor<128xf32> { + %0 = "tf_device.cluster_func"(%arg0) {_xla_compile_device_type = "TPU", _replication_info = "cluster1", func = @_func, num_replica = 1, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_ENTRY", topology = "", device_assignment = [], input_sharding_configuration = ["\08\03\1A\06\01\01\01\01\01\01\22\01\00B\01\00"], output_sharding_configuration = [""], use_spmd_for_xla_partitioning = false, use_tpu = true} : (tensor<128xf32>) -> tensor<128xf32> + func.return %0 : tensor<128xf32> + } + func.func @_func(%arg0: tensor<128xf32>) -> tensor<128xf32> { + func.return %arg0 : tensor<128xf32> + } +} + +// ----- + +// CHECK-LABEL: func @annotate_dynamic_shape_tensor +module attributes {tf.devices = {"/job:localhost/replica:0/task:0/device:COMPOSITE:0", "/job:localhost/replica:0/task:0/device:CPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0"}, tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 1437 : i32}} { + func.func @annotate_dynamic_shape_tensor(%arg0: tensor<512xi64> {tf._user_specified_name = "190", tf.device = "/job:localhost/replica:0/task:0/device:CPU:0"}) -> (tensor<512xi32>) { + %0 = "tf.TPUCompilationResult"() {_tpu_compilation_status = "cluster_test_fn", device = ""} : () -> tensor + %cst = "tf.Const"() {value = dense<512> : tensor} : () -> tensor + %2:4 = "tf_device.launch"() ({ + %4 = "tf.Cast"(%arg0) {Truncate = false} : (tensor<512xi64>) -> tensor<512xi32> + %5 = "tf.TPUCopyWithDynamicShape"(%4, %cst) {operand_segment_sizes = array} : (tensor<512xi32>, tensor) -> tensor<512xi32> + tf_device.return %5 : tensor<512xi32> + }) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : () -> (tensor<512xi32>, tensor<1024xi32>, tensor<1024xi32>, tensor<1024xf32>) + // CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:4 = "tf_device.launch" + // CHECK: "tf._TPUCompileMlir"() + // CHECK: is_bounded_dynamic_dim: true + %3 = "tf_device.cluster_func"(%2#0) {_dynamic_arg_index = [0 : i32], _has_manual_control_dependencies = true, _replication_info = "cluster_test_fn", _xla_compile_device_type = "TPU", allow_soft_placement = false, computation_shape = [], device = "", device_assignment = [], func = @_func, host_compute_core = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], num_cores_per_replica = 1 : i64, output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], padding_map = [], step_marker_location = "STEP_MARK_AT_ENTRY", topology = "", tpu_compile_options_proto = "", use_spmd_for_xla_partitioning = false, use_tpu = true} : (tensor<512xi32>) -> tensor<512xi32> + return %3: tensor<512xi32> + } +func.func private @_func(%arg0: tensor> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}) -> (tensor<512xi32>) { + %0 = "tf.A"(%arg0) {} : (tensor>) -> tensor<512xi32> + return %0 : tensor<512xi32> + } +} + +// ----- + +// The following xla.OpSharding is used: +// Proto debug string: +// type : OTHER +// tile_assignment_dimensions: 1 +// tile_assignment_dimensions: 1 +// tile_assignment_dimensions: 4 +// tile_assignment_devices: 0 +// tile_assignment_devices: 1 +// tile_assignment_devices: 2 +// tile_assignment_devices: 3 +// last_tile_dims: REPLICATED 
+// Serialized string: +// "\08\03\1A\03\01\01\04\22\04\00\01\02\03B\01\00" + +// Test that SplitOp is not generated when an input sharding has +// last_tile_dims REPLICATED and more tile_assignment_dimensions +// than tensor dimensions, even when the SPMD sharding is enabled and +// num_cores_per_replica is more than 1. +// Test that ConcatV2 Op is not generated when an output sharding has +// last_tile_dims REPLICATED and more tile_assignment_dimensions +// than tensor dimensions, even when the SPMD sharding is enabled and +// num_cores_per_replica is more than 1. +//CHECK-NOT: tf.Split +// CHECK-NOT: tf.ConcatV2 +module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0", "/job:worker/replica:0/task:0/device:TPU:1", "/job:worker/replica:0/task:0/device:TPU:2", "/job:worker/replica:0/task:0/device:TPU:3"]} { + func.func @cluster_to_single_core(%arg0: tensor<4x128xf32>) -> tensor<4x128xf32> { + %0 = "tf_device.cluster_func"(%arg0) {_xla_compile_device_type = "TPU", _replication_info = "cluster1", func = @_func, num_replica = 1, num_cores_per_replica = 4, step_marker_location = "STEP_MARK_AT_ENTRY", topology = "\0A\04\02\02\01\01\10\01\18\04\22\10\00\00\00\00\01\00\00\00\00\01\00\00\01\01\00\00", device_assignment = [0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 1, 0, 0], input_sharding_configuration = ["\08\03\1A\03\01\01\04\22\04\00\01\02\03B\01\00"], output_sharding_configuration = ["\08\03\1A\03\01\01\04\22\04\00\01\02\03B\01\00"], use_spmd_for_xla_partitioning = true, use_tpu = true} : (tensor<4x128xf32>) -> tensor<4x128xf32> + func.return %0 : tensor<4x128xf32> + } + func.func @_func(%arg0: tensor<4x128xf32>) -> tensor<4x128xf32> { + func.return %arg0 : tensor<4x128xf32> + } +} + diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu_sharding_identification.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu_sharding_identification.mlir index 921248cf473..468e3495439 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tpu_sharding_identification.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu_sharding_identification.mlir @@ -364,6 +364,86 @@ func.func @cluster_func(%arg0: tensor<*xf32>) { // ----- +// Tests TPIv2 with a "partially tiled" XLA annotation where: +// type: OTHER +// tile_assignment_dimensions: [4, 1, 1, 1, 2] +// tile_assignment_devices: [0, 1, 2, 3, 4, 5, 6, 7] +// replicate_on_last_tile_dim: true +// Serialized string: +// "\08\03\1A\05\04\01\01\01\02\22\08\00\01\02\03\04\05\06\070\01" + +// CHECK-LABEL: func @partial_tile_partitioned_variable +func.func @partial_tile_partitioned_variable(%arg0: tensor>>) { + %0 = "tf.TPUPartitionedInputV2"(%arg0) {_XlaSharding = "\08\03\1A\05\04\01\01\01\02\22\08\00\01\02\03\04\05\06\070\01", partition_dims = [4, 1, 1, 1, 2], is_packed = true} : (tensor>>) -> tensor>> + %1 = "tf.ReadVariableOp"(%0) : (tensor>>) -> tensor<4x4x4x4xf32> + // CHECK: tf_device.cluster_func + // CHECK-SAME: input_sharding_configuration = ["\08\03\1A\05\04\01\01\01\02\22\08\00\01\02\03\04\05\06\070\01"] + // CHECK-SAME: output_sharding_configuration = [] + // CHECK-SAME: use_spmd_for_xla_partitioning = true + "tf_device.cluster_func"(%1) {func = @cluster_func, use_spmd_for_xla_partitioning = true, num_cores_per_replica = 8 : i64} : (tensor<4x4x4x4xf32>) -> () + func.return +} + +// CHECK-LABEL: func @cluster_func +// CHECK-SAME: ({{.+}}: tensor<4x4x4x4xf32> {mhlo.sharding = 
"\08\03\1A\05\04\01\01\01\02\22\08\00\01\02\03\04\05\06\070\01"}) +func.func @cluster_func(%arg0: tensor<4x4x4x4xf32>) { + func.return +} + +// ----- + +// Tests TPIv2 with a "subgroup tiled" XLA annotation where: +// type: OTHER +// tile_assignment_dimensions: [4, 1, 1, 1, 2] +// tile_assignment_devices: [0, 1, 2, 3, 4, 5, 6, 7] +// last_tile_dims: [REPLICATED] +// Serialized string: +// "\08\03\1A\05\04\01\01\01\02\22\08\00\01\02\03\04\05\06\07B\01\00" + +// CHECK-LABEL: func @subgroup_tile_partitioned_variable +func.func @subgroup_tile_partitioned_variable(%arg0: tensor>>) { + %0 = "tf.TPUPartitionedInputV2"(%arg0) {_XlaSharding = "\08\03\1A\05\04\01\01\01\02\22\08\00\01\02\03\04\05\06\07B\01\00", partition_dims = [4, 1, 1, 1, 2], is_packed = true} : (tensor>>) -> tensor>> + %1 = "tf.ReadVariableOp"(%0) : (tensor>>) -> tensor<4x4x4x4xf32> + // CHECK: tf_device.cluster_func + // CHECK-SAME: input_sharding_configuration = ["\08\03\1A\05\04\01\01\01\02\22\08\00\01\02\03\04\05\06\07B\01\00"] + // CHECK-SAME: output_sharding_configuration = [] + // CHECK-SAME: use_spmd_for_xla_partitioning = true + "tf_device.cluster_func"(%1) {func = @cluster_func, use_spmd_for_xla_partitioning = true, num_cores_per_replica = 8 : i64} : (tensor<4x4x4x4xf32>) -> () + func.return +} + +// CHECK-LABEL: func @cluster_func +// CHECK-SAME: ({{.+}}: tensor<4x4x4x4xf32> {mhlo.sharding = "\08\03\1A\05\04\01\01\01\02\22\08\00\01\02\03\04\05\06\07B\01\00"}) +func.func @cluster_func(%arg0: tensor<4x4x4x4xf32>) { + func.return +} + +// ----- + +// Tests TPIv2 with a "partially tiled" XLA annotation where: +// type: OTHER +// tile_assignment_dimensions: [4, 1, 1, 1, 2] +// tile_assignment_devices: [0, 1, 2, 3, 4, 5, 6, 7] +// replicate_on_last_tile_dim: true +// Serialized string: +// "\08\03\1A\05\04\01\01\01\02\22\08\00\01\02\03\04\05\06\070\01" + +// This sharding has an extra dimension than the TPIv2's rank, causing an error. + +func.func @partitioned_input_rank_mismatch(%arg0: tensor>>) { + // expected-error @+1 {{rank}} + %0 = "tf.TPUPartitionedInputV2"(%arg0) {_XlaSharding = "\08\03\1A\05\04\01\01\01\02\22\08\00\01\02\03\04\05\06\070\01", partition_dims = [4, 1, 1, 2], is_packed = true} : (tensor>>) -> tensor>> + %1 = "tf.ReadVariableOp"(%0) : (tensor>>) -> tensor<4x4x4xf32> + "tf_device.cluster_func"(%1) {func = @cluster_func, use_spmd_for_xla_partitioning = true, num_cores_per_replica = 8 : i64} : (tensor<4x4x4xf32>) -> () + func.return +} + +func.func @cluster_func(%arg0: tensor<4x4x4xf32>) { + func.return +} + +// ----- + // Tests partitioned inputs/outputs with no sharding (via XLA SPMD) defaults to // replicate sharding (""). @@ -484,6 +564,43 @@ func.func @func(%arg0: tensor<*xi32> {tf.aliasing_output = 1 : i64}, // ----- +// Partial tiled inputs using XlaSharding ops identified as REPLICATED should keep the sharding configuration. 
+// The following xla.OpSharding is used: +// Proto debug string: +// type : OTHER +// tile_assignment_dimensions: 1 +// tile_assignment_dimensions: 1 +// tile_assignment_dimensions: 2 +// tile_assignment_devices: 0 +// tile_assignment_devices: 1 +// last_tile_dims: REPLICATED +// Serialized string: +// "\08\03\1A\03\01\01\02\22\02\00\01B\01\00" + +// CHECK-LABEL: func @check_partial_tile_mpmd_fallback +func.func @check_partial_tile_mpmd_fallback(%arg0: tensor<2x7xi64>) -> tensor<2x7xi32> { + // CHECK: tf_device.cluster_func + // CHECK-SAME: input_sharding_configuration = ["\08\03\1A\03\01\01\02\22\02\00\01B\01\00"] + // CHECK-SAME: output_sharding_configuration = [""] + // CHECK-SAME: use_spmd_for_xla_partitioning = true + %0 = "tf_device.cluster_func"(%arg0) { + func = @func, + use_spmd_for_xla_partitioning = true, num_cores_per_replica = 2 : i64 + } : (tensor<2x7xi64>) -> (tensor<2x7xi32>) + %1 = "tf.Identity"(%0) : (tensor<2x7xi32>) -> tensor<2x7xi32> + func.return %1 : tensor<2x7xi32> +} + +// CHECK-LABEL: func @func +// CHECK-SAME: %arg0: tensor<2x7xi64> {mhlo.sharding = "\08\03\1A\03\01\01\02\22\02\00\01B\01\00" +func.func @func(%arg0: tensor<2x7xi64>) -> (tensor<2x7xi32>) { + %0 = "tf.Cast"(%arg0) {Truncate = false} : (tensor<2x7xi64>) -> tensor<2x7xi32> + %1 = "tf.XlaSharding"(%0) {_XlaSharding = "\08\03\1A\03\01\01\02\22\02\00\01B\01\00", sharding = "\08\03\1A\03\01\01\02\22\02\00\01B\01\00", unspecified_dims = []} : (tensor<2x7xi32>) -> tensor<2x7xi32> + func.return %0 : tensor<2x7xi32> +} + +// ----- + // CHECK-LABEL: func @check_arg_sharding_errors func.func @check_arg_sharding_errors(%arg0: tensor<1x2x3xi32>) { // CHECK: tf_device.cluster_func diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu_validate_inputs.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu_validate_inputs.mlir index d61f4cfeadf..7edb50ada79 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tpu_validate_inputs.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu_validate_inputs.mlir @@ -2,59 +2,142 @@ // CHECK-LABEL: func @num_replicas_replicated func.func @num_replicas_replicated(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor) -> (tensor, tensor) { - "tf.TPUReplicateMetadata"() {_xla_compile_device_type = "TPU", _tpu_replicate = "cluster", device = "/device:TPU:0", num_replicas = 2, topology = "topology"} : () -> () - %ri = "tf.TPUReplicatedInput"(%arg0, %arg1) {index = 1 : i64, is_mirrored_variable = false, is_packed = false} : (tensor, tensor) -> tensor - %out = "tf.opA"(%ri) : (tensor) -> tensor - %ro:2 = "tf.TPUReplicatedOutput"(%out) : (tensor) -> (tensor, tensor) - func.return %ro#0, %ro#1 : tensor, tensor + %0:2 = tf_executor.graph { + %control = tf_executor.island() wraps "tf.TPUReplicateMetadata"() {_xla_compile_device_type = "TPU", _tpu_replicate = "cluster", device = "/device:TPU:0", num_replicas = 2, topology = "topology"} : () -> () + %ri, %c0 = tf_executor.island wraps "tf.TPUReplicatedInput"(%arg0, %arg1) {index = 1 : i64, is_mirrored_variable = false, is_packed = false} : (tensor, tensor) -> tensor + %out, %c1 = tf_executor.island wraps "tf.opA"(%ri) {_tpu_replicate = "cluster"} : (tensor) -> tensor + %ro:2, %c2 = tf_executor.island wraps "tf.TPUReplicatedOutput"(%out) : (tensor) -> (tensor, tensor) + tf_executor.fetch %ro#0, %ro#1 : tensor, tensor + } + return %0#0, %0#1 : tensor, tensor } // ----- func.func @num_replicas_replicated_input(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor) -> (tensor, tensor) { - "tf.TPUReplicateMetadata"() 
{_xla_compile_device_type = "TPU", _tpu_replicate = "cluster", device = "/device:TPU:0", num_replicas = 2, topology = "topology"} : () -> () - // expected-error @+1 {{'tf.TPUReplicatedInput' op TF/XLA TPU bridge input check: number of inputs inconsistent. num_replicas=2 no. of inputs=3}} - %ri = "tf.TPUReplicatedInput"(%arg0, %arg1, %arg1) {index = 1 : i64, is_mirrored_variable = false, is_packed = false} : (tensor, tensor, tensor) -> tensor - %ro:2 = "tf.TPUReplicatedOutput"(%ri) : (tensor) -> (tensor, tensor) - func.return %ro#0, %ro#1 : tensor, tensor + %0:2 = tf_executor.graph { + %control = tf_executor.island() wraps "tf.TPUReplicateMetadata"() {_xla_compile_device_type = "TPU", _tpu_replicate = "cluster", device = "/device:TPU:0", num_replicas = 2, topology = "topology"} : () -> () + // expected-error @+1 {{'tf.TPUReplicatedInput' op TF2XLA TPU bridge input check: number of inputs inconsistent. num_replicas=2 no. of inputs=3}} + %ri, %c0 = tf_executor.island wraps "tf.TPUReplicatedInput"(%arg0, %arg1, %arg1) {index = 1 : i64, is_mirrored_variable = false, is_packed = false} : (tensor, tensor, tensor) -> tensor + %out, %c1 = tf_executor.island wraps "tf.opA"(%ri) {_tpu_replicate = "cluster"} : (tensor) -> tensor + %ro:2, %c2 = tf_executor.island wraps "tf.TPUReplicatedOutput"(%out) : (tensor) -> (tensor, tensor) + tf_executor.fetch %ro#0, %ro#1 : tensor, tensor + } + return %0#0, %0#1 : tensor, tensor } // ----- func.func @num_replicas_replicated_input_packed(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor) -> (tensor, tensor) { - "tf.TPUReplicateMetadata"() {_xla_compile_device_type = "TPU", _tpu_replicate = "cluster", device = "/device:TPU:0", num_replicas = 2, topology = "topology"} : () -> () - // expected-error @+1 {{'tf.TPUReplicatedInput' op TF/XLA TPU bridge input check: packed with number of inputs not 1. num_replicas=2 no. of inputs=2}} - %ri = "tf.TPUReplicatedInput"(%arg0, %arg1) {index = 1 : i64, is_mirrored_variable = false, is_packed = true} : (tensor, tensor) -> tensor - %ro:2 = "tf.TPUReplicatedOutput"(%ri) : (tensor) -> (tensor, tensor) - func.return %ro#0, %ro#1 : tensor, tensor + %0:2 = tf_executor.graph { + %control = tf_executor.island() wraps "tf.TPUReplicateMetadata"() {_xla_compile_device_type = "TPU", _tpu_replicate = "cluster", device = "/device:TPU:0", num_replicas = 2, topology = "topology"} : () -> () + // expected-error @+1 {{'tf.TPUReplicatedInput' op TF2XLA TPU bridge input check: packed with number of inputs not 1. num_replicas=2 no. 
of inputs=2}} + %ri, %c0 = tf_executor.island wraps "tf.TPUReplicatedInput"(%arg0, %arg1) {index = 1 : i64, is_mirrored_variable = false, is_packed = true} : (tensor, tensor) -> tensor + %out, %c1 = tf_executor.island wraps "tf.opA"(%ri) {_tpu_replicate = "cluster"} : (tensor) -> tensor + %ro:2, %c2 = tf_executor.island wraps "tf.TPUReplicatedOutput"(%out) : (tensor) -> (tensor, tensor) + tf_executor.fetch %ro#0, %ro#1 : tensor, tensor + } + return %0#0, %0#1 : tensor, tensor } // ----- func.func @num_replicas_replicated_output(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor) -> (tensor, tensor) { - "tf.TPUReplicateMetadata"() {_xla_compile_device_type = "TPU", _tpu_replicate = "cluster", device = "/device:TPU:0", num_replicas = 2, topology = "topology"} : () -> () - %ri = "tf.TPUReplicatedInput"(%arg0, %arg1) {index = 1 : i64, is_mirrored_variable = false, is_packed = false} : (tensor, tensor) -> tensor - // expected-error @+1 {{'tf.TPUReplicatedOutput' op TF/XLA TPU bridge input check: number of outputs inconsistent. num_replicas=2 no. of outputs=3}} - %ro:3 = "tf.TPUReplicatedOutput"(%ri) : (tensor) -> (tensor, tensor, tensor) - func.return %ro#0, %ro#1 : tensor, tensor + %0:2 = tf_executor.graph { + %control = tf_executor.island() wraps "tf.TPUReplicateMetadata"() {_xla_compile_device_type = "TPU", _tpu_replicate = "cluster", device = "/device:TPU:0", num_replicas = 2, topology = "topology"} : () -> () + %ri, %c0 = tf_executor.island wraps "tf.TPUReplicatedInput"(%arg0, %arg1) {index = 1 : i64, is_mirrored_variable = false, is_packed = false} : (tensor, tensor) -> tensor + %out, %c1 = tf_executor.island wraps "tf.opA"(%ri) {_tpu_replicate = "cluster"} : (tensor) -> tensor + // expected-error @+1 {{'tf.TPUReplicatedOutput' op TF2XLA TPU bridge input check: number of outputs inconsistent. num_replicas=2 no. of outputs=3}} + %ro:3, %c2 = tf_executor.island wraps "tf.TPUReplicatedOutput"(%out) : (tensor) -> (tensor, tensor, tensor) + tf_executor.fetch %ro#0, %ro#1 : tensor, tensor + } + return %0#0, %0#1 : tensor, tensor } // ----- func.func @num_core_per_replica_partitioned_input(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor) -> (tensor, tensor) { - "tf.TPUReplicateMetadata"() {_xla_compile_device_type = "TPU", _tpu_replicate = "cluster", device = "/device:TPU:0", num_cores_per_replica = 2 : i64, num_replicas = 1 : i64, topology = "topology"} : () -> () - // expected-error @+1 {{'tf.TPUPartitionedInput' op TF/XLA TPU bridge input check: number of inputs inconsistent. num_cores_per_replica=2 no. of inputs=3}} - %pi = "tf.TPUPartitionedInput"(%arg0, %arg1, %arg1) {index = 1 : i64} : (tensor, tensor, tensor) -> tensor - %po:2 = "tf.TPUPartitionedOutput"(%pi) : (tensor) -> (tensor, tensor) - func.return %po#0, %po#1 : tensor, tensor + %0:2 = tf_executor.graph { + %control = tf_executor.island() wraps "tf.TPUReplicateMetadata"() {_xla_compile_device_type = "TPU", _tpu_replicate = "cluster", device = "/device:TPU:0", num_cores_per_replica = 2 : i64, num_replicas = 1 : i64, topology = "topology"} : () -> () + // expected-error @+1 {{'tf.TPUPartitionedInput' op TF2XLA TPU bridge input check: number of inputs inconsistent. num_cores_per_replica=2 no. 
of inputs=3}} + %pi, %c0 = tf_executor.island wraps "tf.TPUPartitionedInput"(%arg0, %arg1, %arg1) {index = 1 : i64} : (tensor, tensor, tensor) -> tensor + %out, %c1 = tf_executor.island wraps "tf.opA"(%pi) {_tpu_replicate = "cluster"} : (tensor) -> tensor + %po:2, %c2 = tf_executor.island wraps "tf.TPUPartitionedOutput"(%out) : (tensor) -> (tensor, tensor) + tf_executor.fetch %po#0, %po#1 : tensor, tensor + } + return %0#0, %0#1 : tensor, tensor } // ----- func.func @num_core_per_replica_partitioned_output(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor) -> (tensor, tensor) { - "tf.TPUReplicateMetadata"() {_xla_compile_device_type = "TPU", _tpu_replicate = "cluster", device = "/device:TPU:0", num_cores_per_replica = 2 : i64, num_replicas = 1 : i64, topology = "topology"} : () -> () - %pi = "tf.TPUPartitionedInput"(%arg0, %arg1) {index = 1 : i64} : (tensor, tensor) -> tensor - // expected-error @+1 {{'tf.TPUPartitionedOutput' op TF/XLA TPU bridge input check: number of outputs inconsistent. num_cores_per_replica=2 no. of outputs=3}} - %po:3 = "tf.TPUPartitionedOutput"(%pi) : (tensor) -> (tensor, tensor, tensor) - func.return %po#0, %po#1 : tensor, tensor -} \ No newline at end of file + %0:2 = tf_executor.graph { + %control = tf_executor.island() wraps "tf.TPUReplicateMetadata"() {_xla_compile_device_type = "TPU", _tpu_replicate = "cluster", device = "/device:TPU:0", num_cores_per_replica = 2 : i64, num_replicas = 1 : i64, topology = "topology"} : () -> () + %pi, %c0 = tf_executor.island wraps "tf.TPUPartitionedInput"(%arg0, %arg1) {index = 1 : i64} : (tensor, tensor) -> tensor + %out, %c1 = tf_executor.island wraps "tf.opA"(%pi) {_tpu_replicate = "cluster"} : (tensor) -> tensor + // expected-error @+1 {{'tf.TPUPartitionedOutput' op TF2XLA TPU bridge input check: number of outputs inconsistent. num_cores_per_replica=2 no. 
of outputs=3}} + %po:3, %c2 = tf_executor.island wraps "tf.TPUPartitionedOutput"(%out) : (tensor) -> (tensor, tensor, tensor) + tf_executor.fetch %po#0, %po#1 : tensor, tensor + } + return %0#0, %0#1 : tensor, tensor +} + +// ----- + +func.func @validate_tpu_replicate_no_attr(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor) -> (tensor, tensor) { + %0:2 = tf_executor.graph { + %control = tf_executor.island() wraps "tf.TPUReplicateMetadata"() {_xla_compile_device_type = "TPU", _tpu_replicate = "cluster", device = "/device:TPU:0", num_replicas = 2, topology = "topology"} : () -> () + %ri, %c0 = tf_executor.island wraps "tf.TPUReplicatedInput"(%arg0, %arg1) {index = 1 : i64, is_mirrored_variable = false, is_packed = false} : (tensor, tensor) -> tensor + %out, %c1 = tf_executor.island wraps "tf.opA"(%ri) {_tpu_replicate="cluster"}: (tensor) -> tensor + // expected-warning @+1 {{TF2XLA TPU bridge input check: cluster op = tf.opA with cluster = cluster has successor as non cluster op tf.opB}} + %out2, %c2 = tf_executor.island wraps "tf.opB"(%out) : (tensor) -> tensor + // expected-error @+1 {{tf.TPUReplicatedOutput' op TF2XLA TPU bridge input check: non-cluster op = tf.opB has invalid successor op = tf.TPUReplicatedOutput}} + %ro:2, %c4 = tf_executor.island wraps "tf.TPUReplicatedOutput"(%out2) : (tensor) -> (tensor, tensor) + tf_executor.fetch %ro#0, %ro#1 : tensor, tensor + } + return %0#0, %0#1 : tensor, tensor +} + +// ----- + +func.func @validate_tpu_replicate_wrong_attr(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor) -> (tensor, tensor) { + %0:2 = tf_executor.graph { + %control = tf_executor.island() wraps "tf.TPUReplicateMetadata"() {_xla_compile_device_type = "TPU", _tpu_replicate = "cluster", device = "/device:TPU:0", num_replicas = 2, topology = "topology"} : () -> () + %ri, %c0 = tf_executor.island wraps "tf.TPUReplicatedInput"(%arg0, %arg1) {index = 1 : i64, is_mirrored_variable = false, is_packed = false} : (tensor, tensor) -> tensor + %out, %c1 = tf_executor.island wraps "tf.opA"(%ri) {_tpu_replicate = "cluster_wrong"}: (tensor) -> tensor + // expected-error @+1 {{'tf.opB' op TF2XLA TPU bridge input check: mismatch clusters tpu_replicate attr. 
Parent op tf.opA with cluster = cluster_wrong has successor cluster op tf.opB with cluster = cluster}} + %out2, %c2 = tf_executor.island wraps "tf.opB"(%out) {_tpu_replicate = "cluster"}: (tensor) -> tensor + %ro:2, %c3 = tf_executor.island wraps "tf.TPUReplicatedOutput"(%out2) : (tensor) -> (tensor, tensor) + tf_executor.fetch %ro#0, %ro#1 : tensor, tensor + } + return %0#0, %0#1 : tensor, tensor +} + +// ----- + +func.func @valid_xla_nonxla(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor) -> (tensor, tensor) { + %0:2 = tf_executor.graph { + %control = tf_executor.island wraps "tf.TPUReplicateMetadata"() {_xla_compile_device_type = "TPU", _tpu_replicate = "cluster", device = "/device:TPU:0", num_replicas = 2, topology = "topology"} : () -> () + %ri, %c0 = tf_executor.island wraps "tf.TPUReplicatedInput"(%arg0, %arg1) {index = 1 : i64, is_mirrored_variable = false, is_packed = false} : (tensor, tensor) -> tensor + %out, %c1 = tf_executor.island wraps "tf.opA"(%ri) {_tpu_replicate = "cluster", device = "TPU"} : (tensor) -> tensor + %ro:2, %c2 = tf_executor.island wraps "tf.TPUReplicatedOutput"(%out) : (tensor) -> (tensor, tensor) + tf_executor.fetch %ro#0, %ro#1 : tensor, tensor + } + return %0#0, %0#1 : tensor, tensor +} + +// ----- + +func.func @valid_xla_nonxla_warning(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor) -> (tensor<*x!tf_type.string>, tensor<*x!tf_type.string>) { + %0:2 = tf_executor.graph { + %control = tf_executor.island wraps "tf.TPUReplicateMetadata"() {_xla_compile_device_type = "TPU", _tpu_replicate = "cluster", device = "/device:TPU:0", num_replicas = 2, topology = "topology"} : () -> () + %ri, %c0 = tf_executor.island wraps "tf.TPUReplicatedInput"(%arg0, %arg1) {index = 1 : i64, is_mirrored_variable = false, is_packed = false} : (tensor, tensor) -> tensor<*x!tf_type.string> + // expected-warning @+1 {{TF/XLA TPU bridge input check: found invalid op. 
tf.Identity can't be both xla and non-xla}} + %out, %c1 = tf_executor.island(%c0) wraps "tf.Identity"(%ri) {_tpu_replicate = "cluster", device = ""} : (tensor<*x!tf_type.string>) -> tensor<*x!tf_type.string> + %ro:2, %c2 = tf_executor.island wraps "tf.TPUReplicatedOutput"(%out) : (tensor<*x!tf_type.string>) -> (tensor<*x!tf_type.string>, tensor<*x!tf_type.string>) + tf_executor.fetch %ro#0, %ro#1 : tensor<*x!tf_type.string>, tensor<*x!tf_type.string> + } + return %0#0, %0#1 : tensor<*x!tf_type.string>, tensor<*x!tf_type.string> +} + +// ----- \ No newline at end of file diff --git a/tensorflow/compiler/mlir/tensorflow/tests/transpose-op.mlir b/tensorflow/compiler/mlir/tensorflow/tests/transpose-op.mlir new file mode 100644 index 00000000000..d719977dc36 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/transpose-op.mlir @@ -0,0 +1,10 @@ +// RUN: tf-opt %s -split-input-file -verify-diagnostics + +func.func @out_of_bounds_check(%arg0: tensor<1x4x4x8xf32>) -> tensor<1x4x4x8xf32> { + %0 = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32> + %1 = "tf.Const"() {value = dense<[0, 0x4141, 3, 1]> : tensor<4xi32>} : () -> tensor<4xi32> + %2 = "tf.Transpose"(%arg0, %0) : (tensor<1x4x4x8xf32>, tensor<4xi32>) -> tensor<1x8x4x4xf32> + // expected-error @+1 {{'tf.Transpose' op perm[1]=16705 must be in range [-4, 4)}} + %3 = "tf.Transpose"(%2, %1) : (tensor<1x8x4x4xf32>, tensor<4xi32>) -> tensor<1x4x4x8xf32> + func.return %3 : tensor<1x4x4x8xf32> +} \ No newline at end of file diff --git a/tensorflow/compiler/mlir/tensorflow/tests/xla_rewrite.mlir b/tensorflow/compiler/mlir/tensorflow/tests/xla_rewrite.mlir index 973fe031d75..675ae224f6b 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/xla_rewrite.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/xla_rewrite.mlir @@ -1,72 +1,50 @@ // RUN: tf-opt %s -split-input-file -tf-xla-rewrite | FileCheck %s -// CHECK-LABEL: func.func @convert_partitioned_call -func.func @convert_partitioned_call(%arg0: tensor) -> tensor { - %0 = "tf_device.cluster"() ({ - // CHECK: "tf.XlaLaunch"(%arg0) {_xla_compile_device_type = "CPU", device = "/device:CPU:0", function = @pcall_func, operand_segment_sizes = array} : (tensor) -> tensor - %1 = "tf.PartitionedCall"(%arg0) {_xla_compile_device_type = "CPU", config = "", config_proto = "", device = "/device:CPU:0", executor_type = "", f = @pcall_func} : (tensor) -> (tensor) - tf_device.return %1 : tensor - }) : () -> tensor - func.return %0 : tensor -} -func.func @pcall_func(%arg0: tensor) -> tensor { - func.return %arg0 : tensor +module attributes {tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:GPU:0"]} { + // CHECK-LABEL: func.func @convert_cluster_func + func.func @convert_cluster_func(%arg0: tensor) -> tensor { + // CHECK: "tf.XlaLaunch"(%arg0) {function = @func, operand_segment_sizes = array} : (tensor) -> tensor + %0 = "tf_device.cluster_func"(%arg0) {func = @func} : (tensor) -> tensor + func.return %0 : tensor + } + + func.func @func(%arg0: tensor) -> tensor { + func.return %arg0 : tensor + } } // ----- -// CHECK-LABEL: func.func @convert_stateful_partitioned_call -func.func @convert_stateful_partitioned_call(%arg0: tensor) -> tensor { - %0 = "tf_device.cluster"() ({ - // CHECK: "tf.XlaLaunch"(%arg0) {_xla_compile_device_type = "CPU", device = "/device:CPU:0", function = @stateful_pcall_func, operand_segment_sizes = array} : (tensor) -> tensor - %1 = "tf.StatefulPartitionedCall"(%arg0) {_xla_compile_device_type = "CPU", config = 
"", config_proto = "", device = "/device:CPU:0", executor_type = "", f = @stateful_pcall_func} : (tensor) -> (tensor) - tf_device.return %1 : tensor - }) : () -> tensor +module attributes {tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:GPU:0"]} { + // CHECK-LABEL: func.func @convert_cluster_func_with_resources_in_order + func.func @convert_cluster_func_with_resources_in_order(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "tf.XlaLaunch"(%arg1, %arg0) {function = @func_with_resources_in_order, operand_segment_sizes = array} : (tensor, tensor) -> tensor + %0 = "tf_device.cluster_func"(%arg1, %arg0) {func = @func_with_resources_in_order} : (tensor, tensor) -> (tensor) + func.return %0 : tensor + } - func.return %0 : tensor -} - -func.func @stateful_pcall_func(%arg0: tensor) -> tensor { - func.return %arg0 : tensor + func.func @func_with_resources_in_order(%arg0 : tensor, %arg1 : tensor) -> tensor { + func.return %arg0 : tensor + } } // ----- -// CHECK-LABEL: func.func @convert_stateful_partitioned_call_with_resources_in_order -func.func @convert_stateful_partitioned_call_with_resources_in_order(%arg0: tensor, %arg1: tensor) -> tensor { - %0 = "tf_device.cluster"() ({ - // CHECK: "tf.XlaLaunch"(%arg1, %arg0) {_xla_compile_device_type = "CPU", device = "/device:CPU:0", function = @stateful_pcall_func_with_resources_in_order, operand_segment_sizes = array} : (tensor, tensor) -> tensor - %1 = "tf.StatefulPartitionedCall"(%arg1, %arg0) {_xla_compile_device_type = "CPU", config = "", config_proto = "", device = "/device:CPU:0", executor_type = "", f = @stateful_pcall_func_with_resources_in_order} : (tensor, tensor) -> (tensor) - tf_device.return %1 : tensor - }) : () -> tensor - func.return %0 : tensor -} +module attributes {tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:GPU:0"]} { + // CHECK-LABEL: func.func @convert_cluster_func_with_resources + func.func @convert_cluster_func_with_resources(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "tf.XlaLaunch"(%arg1, %arg0) {function = @func_with_resources, operand_segment_sizes = array} : (tensor, tensor) -> tensor + %0 = "tf_device.cluster_func"(%arg0, %arg1) {func = @func_with_resources} : (tensor, tensor) -> tensor + // CHECK: "tf.XlaLaunch"(%arg1, %arg0) {function = @func_with_resources, operand_segment_sizes = array} : (tensor, tensor) -> tensor + %1 = "tf_device.cluster_func"(%arg0, %arg1) {func = @func_with_resources} : (tensor, tensor) -> tensor + return %0 : tensor + } -func.func @stateful_pcall_func_with_resources_in_order(%arg0 : tensor, %arg1 : tensor) -> tensor { - func.return %arg0 : tensor -} - -// ----- - -// CHECK-LABEL: func.func @convert_stateful_partitioned_call_with_resources -func.func @convert_stateful_partitioned_call_with_resources(%arg0: tensor, %arg1: tensor) -> tensor { - %0 = "tf_device.cluster"() ({ - // CHECK: "tf.XlaLaunch"(%arg1, %arg0) {_xla_compile_device_type = "CPU", device = "/device:CPU:0", function = @stateful_pcall_func_with_resources, operand_segment_sizes = array} : (tensor, tensor) -> tensor - %2 = "tf.StatefulPartitionedCall"(%arg0, %arg1) {_xla_compile_device_type = "CPU", config = "", config_proto = "", device = "/device:CPU:0", executor_type = "", f = @stateful_pcall_func_with_resources} : (tensor, tensor) -> tensor - tf_device.return %2 : tensor - }) : () -> tensor - %1 = "tf_device.cluster"() ({ - // CHECK: "tf.XlaLaunch"(%arg1, %arg0) {_xla_compile_device_type = "CPU", device = 
"/device:CPU:0", function = @stateful_pcall_func_with_resources, operand_segment_sizes = array} : (tensor, tensor) -> tensor - %2 = "tf.StatefulPartitionedCall"(%arg0, %arg1) {_xla_compile_device_type = "CPU", config = "", config_proto = "", device = "/device:CPU:0", executor_type = "", f = @stateful_pcall_func_with_resources} : (tensor, tensor) -> tensor - tf_device.return %2 : tensor - }) : () -> tensor - return %0 : tensor -} - -// CHECK-LABEL: func.func @stateful_pcall_func_with_resources -// CHECK-SAME: (%arg0: tensor, %arg1: tensor) -> tensor -// CHECK: return %arg0 : tensor -func.func @stateful_pcall_func_with_resources(%arg0 : tensor, %arg1: tensor) -> tensor { - func.return %arg1 : tensor + // CHECK-LABEL: func.func @func_with_resources + // CHECK-SAME: (%arg0: tensor, %arg1: tensor) -> tensor + // CHECK: return %arg0 : tensor + func.func @func_with_resources(%arg0 : tensor, %arg1: tensor) -> tensor { + func.return %arg1 : tensor + } } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc b/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc index d604cb247b7..74f5a458b0b 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc @@ -26,22 +26,29 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" #include "tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/data_dumper_logger_config.h" #include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h" #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" #include "tensorflow/core/framework/metrics.h" #include "tensorflow/core/platform/error_payloads.h" #include "tensorflow/core/protobuf/core_platform_payloads.pb.h" +#include "tensorflow/core/util/debug_data_dumper.h" namespace mlir { namespace { - // Add logger to bridge passmanager. // Enable timing statistics per pass for the bridge passmanager. -void EnableDetailedLogging(PassManager *pm) { +void EnableDetailedLogging(PassManager *pm, + llvm::StringRef module_name = llvm::StringRef()) { // Print the whole module after each pass, which requires disabling // multi-threading as well. pm->getContext()->disableMultithreading(); - pm->enableIRPrinting(std::make_unique( + pm->enableIRPrinting(std::make_unique<::tensorflow::DataDumperLoggerConfig>( + [module_name](const std::string &pass_tag_name) { + return DEBUG_DATA_DUMPER()->GetDumpFilename( + module_name.str(), kDebugGroupBridgePhase1, pass_tag_name); + }, + "", /*print_module_scope=*/true)); pm->enableTiming(); } @@ -50,11 +57,24 @@ void EnableDetailedLogging(PassManager *pm) { namespace TFTPU { namespace { +std::string GetMLIRModuleText(mlir::Operation *op, + const mlir::PassManager *pass_manager) { + std::string module_txt; + llvm::raw_string_ostream os(module_txt); + + if (pass_manager) ::tensorflow::PrintPassPipeline(*pass_manager, op, os); + + op->print(os, mlir::OpPrintingFlags().useLocalScope()); + + return os.str(); +} + // Run the TF XLA Bridge based on the input pipeline, which can be either TPU // bridge pipeline or non TPU bridge pipeline. tensorflow::Status RunTFXLABridge( - ModuleOp module, bool enable_logging, - llvm::function_ref pipeline_builder) { + ModuleOp module, + llvm::function_ref pipeline_builder, + llvm::StringRef module_name = llvm::StringRef()) { // Explicitly check that the TensorFlow dialect can constant fold ops. 
// Constant folding is essential for the bridge. Without this check, the // bridge may fail with an error that is difficult to understand and not @@ -76,18 +96,35 @@ tensorflow::Status RunTFXLABridge( module.getContext(), /*propagate=*/false, /*filter_stack=*/!VLOG_IS_ON(1)); - if (enable_logging || VLOG_IS_ON(1)) { - tensorflow::DumpMlirOpToFile("tf_xla_bridge_before", module, "", &bridge); - if (VLOG_IS_ON(2)) EnableDetailedLogging(&bridge); + if (VLOG_IS_ON(1) || + DEBUG_DATA_DUMPER()->ShouldDump(module_name.str(), kDebugGroupMain)) { + ::tensorflow::DumpMlirOpToFile( + DEBUG_DATA_DUMPER()->GetDumpFilename(module_name.str(), kDebugGroupMain, + "tf_xla_bridge_before"), + module, llvm::StringRef(), &bridge); } + + if (VLOG_IS_ON(2) || DEBUG_DATA_DUMPER()->ShouldDump( + module_name.str(), kDebugGroupBridgePhase1)) { + EnableDetailedLogging(&bridge, module_name); + } + LogicalResult result = bridge.run(module); (void)result; - if (enable_logging || VLOG_IS_ON(1)) - tensorflow::DumpMlirOpToFile("tf_xla_bridge_after", module, "", &bridge); + + if (VLOG_IS_ON(1) || + DEBUG_DATA_DUMPER()->ShouldDump(module_name.str(), kDebugGroupMain)) { + ::tensorflow::DumpMlirOpToFile( + DEBUG_DATA_DUMPER()->GetDumpFilename(module_name.str(), kDebugGroupMain, + "tf_xla_bridge_after"), + module, llvm::StringRef(), &bridge); + } + return diag_handler.ConsumeStatus(); } -void CreateTPUBridgePipelineImpl(OpPassManager &pm) { +void CreateTPUBridgePipelineImpl( + OpPassManager &pm, llvm::StringRef module_name = llvm::StringRef()) { // The following ops must be preserved regardless of reachability. Ideally, // all graphs should have control dependencies to enforce this but this is // currently not the case (see b/177478741). @@ -111,6 +148,7 @@ void CreateTPUBridgePipelineImpl(OpPassManager &pm) { pm.addNestedPass( CreateTPUReorderReplicateAndPartitionedInputsPass()); pm.addNestedPass(TF::CreateDecomposeReduceDatasetPass()); + pm.addPass(TFDevice::CreateEmbeddingPipeliningPass()); pm.addPass(CreateTPUClusterFormationPass()); // Run TPU cluster cleanup attributes so ops with no outside compiled // attribute have no host device attribute. 
@@ -192,11 +230,13 @@ void CreateTPUBridgePipelineImpl(OpPassManager &pm) { pm.addPass(TFDevice::CreateAnnotateParameterReplicationPass()); pm.addNestedPass( mlir::TF::CreateRewriteTPUEmbeddingOpsPass()); - pm.addPass(CreateTPURewritePass()); + pm.addPass(CreateTPUAnnotateDynamicShapeInputsPass()); + pm.addPass(CreateTPURewritePass(module_name)); pm.addPass(createSymbolDCEPass()); pm.addNestedPass( TFDevice::CreateReplicateInvariantOpHoistingPass()); pm.addPass(CreateTPUMergeVariablesWithExecutePass()); + pm.addNestedPass(CreateExtractTPUCopyWithDynamicShapeOpPass()); pm.addNestedPass( TF::CreateHoistReplicateInvariantResourceWritesPass()); pm.addNestedPass(CreateTPUColocateCompositeResourceOps()); @@ -205,11 +245,11 @@ void CreateTPUBridgePipelineImpl(OpPassManager &pm) { } } // namespace -void CreateTPUBridgePipeline(OpPassManager &pm) { +void CreateTPUBridgePipeline(OpPassManager &pm, llvm::StringRef module_name) { pm.addPass(CreateTPUValidateInputsPass()); pm.addNestedPass( TF::CreateCanonicalizeCompileAndReplicateAttributesPass()); - CreateTPUBridgePipelineImpl(pm); + CreateTPUBridgePipelineImpl(pm, module_name); } void CreateTPUBridgePipelineV1(OpPassManager &pm) { @@ -238,27 +278,32 @@ void CreateTPUBridgePipelineV1(OpPassManager &pm) { CreateConvertToLegacyCompileAndReplicateAttributesPass()); } -tensorflow::Status TPUBridge(ModuleOp module, bool enable_logging, - bool fallback_enabled) { - Status status = RunTFXLABridge(module, enable_logging, [](OpPassManager &pm) { - CreateTPUBridgePipeline(pm); - // Add set of passes to lower back to graph (from tf_executor). - // Use graph export pipline V2 in TPU Bridge. - // TODO(hanxiong): Completely replace AddGraphExportLoweringPasses with - // AddGraphExortLoweringPassessV2 in all the code paths (V1 compat pipeline, - // CPU/GPU bridge, etc.) - TF::AddGraphExportLoweringPassesV2(pm); - }); +tensorflow::Status TPUBridge(ModuleOp module, bool fallback_enabled, + llvm::StringRef module_name) { + Status status = RunTFXLABridge( + module, + [module_name](OpPassManager &pm) { + CreateTPUBridgePipeline(pm, module_name); + // Add set of passes to lower back to graph + // (from tf_executor). Use graph export + // pipline V2 in TPU Bridge. + // TODO(hanxiong): Completely replace + // AddGraphExportLoweringPasses with + // AddGraphExortLoweringPassessV2 in all the + // code paths (V1 compat pipeline, CPU/GPU + // bridge, etc.) + TF::AddGraphExportLoweringPassesV2(pm); + }, + module_name); tensorflow::metrics::UpdateTfMlirBridgeFirstPhaseCounter( "tpu", "v2", fallback_enabled, status.ok() ? "success" : "failure"); - OkOrSetErrorCounterPayload( + tsl::OkOrSetErrorCounterPayload( tensorflow::core::platform::ErrorSourceProto::MLIR_BRIDGE_PHASE_1, status); return status; } -tensorflow::Status TPUBridgeV1Compat(ModuleOp module, bool enable_logging, - bool fallback_enabled) { - Status status = RunTFXLABridge(module, enable_logging, [](OpPassManager &pm) { +tensorflow::Status TPUBridgeV1Compat(ModuleOp module, bool fallback_enabled) { + Status status = RunTFXLABridge(module, [](OpPassManager &pm) { CreateTPUBridgePipelineV1(pm); // Add set of passes to lower back to graph (from tf_executor). 
TF::AddGraphExportLoweringPasses(pm); @@ -272,6 +317,8 @@ tensorflow::Status TPUBridgeV1Compat(ModuleOp module, bool enable_logging, namespace TF { +void NoCanonicalization(OpPassManager &pm) {} + void AddGraphExportLoweringPasses(OpPassManager &pm) { auto add_pass = [&](std::unique_ptr pass) { pm.addNestedPass(std::move(pass)); @@ -286,6 +333,7 @@ void AddGraphExportLoweringPasses(OpPassManager &pm) { add_pass(TFDevice::CreateLaunchToDeviceAttributePass( /*legacy_graph_export=*/true)); pm.addNestedPass(TFTPU::CreateTPUDevicePropagationPass()); + pm.addNestedPass(TFTPU::CreateTPUColocateSplitsPass()); pm.addPass(createSymbolDCEPass()); if (tensorflow::GetMlirCommonFlags() ->tf_mlir_enable_convert_control_to_data_outputs_pass) { @@ -317,6 +365,7 @@ void AddGraphExportLoweringPassesV2(OpPassManager &pm) { pm.addPass(tf_executor::CreateTFExecutorUpdateControlDependenciesPass()); pm.addNestedPass(TFTPU::CreateTPUDevicePropagationPass()); + pm.addNestedPass(TFTPU::CreateTPUColocateSplitsPass()); pm.addPass(createSymbolDCEPass()); if (tensorflow::GetMlirCommonFlags() ->tf_mlir_enable_convert_control_to_data_outputs_pass) { @@ -382,6 +431,13 @@ void CreateTFXLABridgePipeline(OpPassManager &pm) { // shapes. pm.addPass(TF::CreateTFShapeInferencePass()); pm.addNestedPass(createCanonicalizerPass()); + // Inline all the function calls. Do not call canonicalizer to prevent it from + // moving the definition of any constant operand of ops within a cluster to + // its outside. This may cause the op to fail to verify after the cluster is + // outlined, as the constant operand is replaced by an argument. + pm.addPass(mlir::createInlinerPass({}, NoCanonicalization)); + // Lift resource operations out of device computation. This step needs to be + // done after inlining. pm.addPass(TFDevice::CreateResourceOpLiftingPass()); // TODO(b/267193636): Remove this flag when outside compilation // for generic pipeline is landed. @@ -391,10 +447,10 @@ void CreateTFXLABridgePipeline(OpPassManager &pm) { pm.addPass(TFDevice::CreateExtractHeadTailOutsideCompilationPass()); pm.addPass(TFDevice::CreateExtractOutsideCompilationPass()); } + // Outline clusters into cluster functions. + pm.addPass(TFDevice::CreateClusterOutliningPass()); // Rewrite cluster functions into XLA launch ops. pm.addPass(TFDevice::CreateXlaRewritePass()); - // Inline the cluster ops. - pm.addPass(TFDevice::CreateXlaInlineDeviceOpsPass()); // Re-run the canonicalizer pass as some cleanup during resource op lifting // pass opens up some opportunities for canonicalization of cluster ops. // Specifically, we want to eliminate pass through results from the cluster @@ -406,13 +462,16 @@ void CreateTFXLABridgePipeline(OpPassManager &pm) { pm.addPass(TF::CreateTFRegionControlFlowToFunctional()); } -tensorflow::Status RunTFXLABridge(ModuleOp module, bool enable_logging) { +tensorflow::Status RunTFXLABridge(ModuleOp module, + llvm::StringRef module_name) { Status status = mlir::TFTPU::RunTFXLABridge( - module, enable_logging, [](OpPassManager &pm) { + module, + [](OpPassManager &pm) { CreateTFXLABridgePipeline(pm); // Add set of passes to lower back to graph (from tf_executor). 
TF::AddGraphExportLoweringPasses(pm); - }); + }, + module_name); tensorflow::metrics::UpdateTfMlirBridgeFirstPhaseCounter( /*device type*/ "cpu/gpu", /*bridge version*/ "tfxla", /*fallback_enabled*/ false, diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/bridge.h b/tensorflow/compiler/mlir/tensorflow/transforms/bridge.h index 925149dd843..b0125cc592a 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/bridge.h +++ b/tensorflow/compiler/mlir/tensorflow/transforms/bridge.h @@ -16,6 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_BRIDGE_H_ #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_BRIDGE_H_ +#include + #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "tensorflow/core/lib/core/status.h" @@ -24,20 +26,18 @@ namespace mlir { namespace TFTPU { // Run all the passes involved in transforming the graph before execution so -// that it is suitable for targeting TPUs. When enable_logging is true, enables -// tensorflow::BridgeLogger. When fallback_enabled is true, it means if the -// bridge fails the old bridge will run. This is used for logging and doesn't -// affect any logic. -tensorflow::Status TPUBridge(ModuleOp module, bool enable_logging, - bool fallback_enabled = false); +// that it is suitable for targeting TPUs. When fallback_enabled is true, it +// means if the bridge fails the old bridge will run. This is used for logging +// and doesn't affect any logic. +tensorflow::Status TPUBridge(ModuleOp module, bool fallback_enabled = false, + llvm::StringRef module_name = llvm::StringRef()); // Run all the passes involved in transforming the graph before execution so -// that it is suitable for targeting TPUs. When enable_logging is true, enables -// tensorflow::BridgeLogger. When fallback_enabled is true, it means if the -// bridge fails the old bridge will run. This is used for logging and doesn't -// affect any logic. +// that it is suitable for targeting TPUs. When fallback_enabled is true, it +// means if the bridge fails the old bridge will run. This is used for logging +// and doesn't affect any logic. // This variant of `TPUBridge` is intended for TensorFlow V1 compatibility. -tensorflow::Status TPUBridgeV1Compat(ModuleOp module, bool enable_logging, +tensorflow::Status TPUBridgeV1Compat(ModuleOp module, bool fallback_enabled = false); } // namespace TFTPU @@ -56,7 +56,8 @@ tensorflow::Status RunBridgeWithStandardPipeline(ModuleOp module, bool enable_inliner); // Runs all passes for non TPU (GPU and CPU) graph. -tensorflow::Status RunTFXLABridge(ModuleOp module, bool enable_logging); +tensorflow::Status RunTFXLABridge( + ModuleOp module, llvm::StringRef module_name = llvm::StringRef()); } // namespace TF } // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/bridge_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/bridge_pass.cc index c05581fd202..19850ddc3aa 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/bridge_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/bridge_pass.cc @@ -33,7 +33,9 @@ mlir::PassPipelineRegistration<> tpu_pipeline( "tf-tpu-bridge", "Run all the passes involved in transforming the graph before execution so " "that it is suitable for targeting TPUs.", - mlir::TFTPU::CreateTPUBridgePipeline); + [](mlir::OpPassManager& pm) { + return mlir::TFTPU::CreateTPUBridgePipeline(pm); + }); // Registers a pipeline builder function for TF TPU V1 bridge. 
mlir::PassPipelineRegistration<> tpu_pipeline_v1( diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/call_graph_util.h b/tensorflow/compiler/mlir/tensorflow/transforms/call_graph_util.h index e04d1323352..6d27780316f 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/call_graph_util.h +++ b/tensorflow/compiler/mlir/tensorflow/transforms/call_graph_util.h @@ -47,7 +47,7 @@ LogicalResult GetOutermostOpsOfType( auto v = symtab.lookup(sym.getRootReference()); if (!v) { // This is not expected to happen in practice. - v.emitError() << "Cannot find function " << sym.getRootReference(); + op->emitError() << "Cannot find function " << sym.getRootReference(); return WalkResult::interrupt(); } worklist.push(v); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/cluster_ops_by_policy.cc b/tensorflow/compiler/mlir/tensorflow/transforms/cluster_ops_by_policy.cc index 25118033f65..4b409ffe1f6 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/cluster_ops_by_policy.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/cluster_ops_by_policy.cc @@ -428,7 +428,7 @@ static ClusteringState InitializeClusteringState( } // Initialize mapping from the member operation (block argument) to the id. - for (auto &tuple : llvm::enumerate(state.members)) { + for (const auto &tuple : llvm::enumerate(state.members)) { state.member_ids.try_emplace(tuple.value().source, tuple.index()); } @@ -471,7 +471,7 @@ static bool RunClusteringPass(ClusteringState &state, const ClusteringPolicySet &policies) { bool clustered = false; - for (auto &tuple : llvm::enumerate(state.members)) { + for (const auto &tuple : llvm::enumerate(state.members)) { size_t member_id = tuple.index(); Member &member = tuple.value(); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/cluster_outlining.cc b/tensorflow/compiler/mlir/tensorflow/transforms/cluster_outlining.cc index 23986108112..7bb1ef5a10b 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/cluster_outlining.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/cluster_outlining.cc @@ -13,6 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include +#include + +#include "absl/strings/str_cat.h" #include "llvm/ADT/SmallVector.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project @@ -20,6 +24,7 @@ limitations under the License. 
#include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/OperationSupport.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Pass/PassRegistry.h" // from @llvm-project #include "mlir/Transforms/RegionUtils.h" // from @llvm-project @@ -40,6 +45,10 @@ constexpr char kFuncAttr[] = "func"; struct ClusterOutliningPass : public impl::ClusterOutliningPassBase { + explicit ClusterOutliningPass(bool globally_unique_func_names) { + globally_unique_func_names_ = globally_unique_func_names; + } + void runOnOperation() override; }; @@ -48,6 +57,10 @@ struct ClusterOutliningPass struct LaunchOutliningPass : public impl::LaunchOutliningPassBase { + explicit LaunchOutliningPass(bool globally_unique_func_names) { + globally_unique_func_names_ = globally_unique_func_names; + } + void runOnOperation() override; }; @@ -62,17 +75,29 @@ void ReplaceClusterReturnWithReturn(tf_device::ReturnOp cluster_return_op, // and inserts built function into given module. template func::FuncOp BuildFunction(llvm::ArrayRef live_ins, ClusterOrLaunchOp op, - SymbolTable* symbol_table, OpBuilder* builder) { + SymbolTable* symbol_table, OpBuilder* builder, + bool globally_unique_func_names) { llvm::SmallVector operand_types; operand_types.reserve(live_ins.size()); for (Value v : live_ins) operand_types.emplace_back(v.getType()); auto func_type = builder->getFunctionType(operand_types, op.getResultTypes()); - // TODO(lyandy): Define better name for outlined function. Potentially some - // name can be added during cluster formation. + std::string func_name; + if (globally_unique_func_names) { + // While processing XLA launch ops, signatures are created for each function + // to decide if a function has been compiled. Function signatures are + // decided by function name and input types. By giving each function a + // unique name, we make sure the same signature is not incorrectly given to + // functions of different graphs with same name and input type. + func_name = + absl::StrCat("_func_", size_t(OperationEquivalence::computeHash(op))); + } else { + func_name = "_func"; + } + func::FuncOp outlined_func = - func::FuncOp::create(op.getLoc(), "_func", func_type); + func::FuncOp::create(op.getLoc(), func_name, func_type); // This function is not externally visible and marking it private would allow // symbol-dce pass to remove it when it is not referenced anymore. @@ -108,13 +133,14 @@ func::FuncOp BuildFunction(llvm::ArrayRef live_ins, ClusterOrLaunchOp op, // `tf_device.cluster_func` to invoke that function. `tf_device.cluster` is // removed afterwards.` void OutlineCluster(tf_device::ClusterOp cluster_op, SymbolTable* symbol_table, - OpBuilder* builder) { + OpBuilder* builder, bool globally_unique_func_names) { llvm::SetVector live_ins; getUsedValuesDefinedAbove(cluster_op.getBody(), cluster_op.getBody(), live_ins); func::FuncOp outlined_func = - BuildFunction(live_ins.getArrayRef(), cluster_op, symbol_table, builder); + BuildFunction(live_ins.getArrayRef(), cluster_op, symbol_table, builder, + globally_unique_func_names); cluster_op->setAttr( builder->getStringAttr(kFuncAttr), mlir::SymbolRefAttr::get(builder->getContext(), outlined_func.getName())); @@ -135,12 +161,13 @@ void OutlineCluster(tf_device::ClusterOp cluster_op, SymbolTable* symbol_table, // `tf_device.launch_func` to invoke that function. 
`tf_device.launch` is // removed afterwards.` void OutlineLaunch(tf_device::LaunchOp launch_op, SymbolTable* symbol_table, - OpBuilder* builder) { + OpBuilder* builder, bool globally_unique_func_names) { llvm::SetVector live_ins; getUsedValuesDefinedAbove(launch_op.getBody(), launch_op.getBody(), live_ins); func::FuncOp outlined_func = - BuildFunction(live_ins.getArrayRef(), launch_op, symbol_table, builder); + BuildFunction(live_ins.getArrayRef(), launch_op, symbol_table, builder, + globally_unique_func_names); launch_op->setAttr( builder->getStringAttr(kFuncAttr), mlir::SymbolRefAttr::get(builder->getContext(), outlined_func.getName())); @@ -159,7 +186,8 @@ void ClusterOutliningPass::runOnOperation() { SymbolTable symbol_table(module); OpBuilder builder(module.getContext()); module.walk([&](tf_device::ClusterOp cluster) { - OutlineCluster(cluster, &symbol_table, &builder); + OutlineCluster(cluster, &symbol_table, &builder, + globally_unique_func_names_.getValue()); }); } @@ -168,18 +196,21 @@ void LaunchOutliningPass::runOnOperation() { SymbolTable symbol_table(module); OpBuilder builder(module.getContext()); module.walk([&](tf_device::LaunchOp launch) { - OutlineLaunch(launch, &symbol_table, &builder); + OutlineLaunch(launch, &symbol_table, &builder, + globally_unique_func_names_.getValue()); }); } } // namespace -std::unique_ptr> CreateClusterOutliningPass() { - return std::make_unique(); +std::unique_ptr> CreateClusterOutliningPass( + bool globally_unique_func_names) { + return std::make_unique(globally_unique_func_names); } -std::unique_ptr> CreateLaunchOutliningPass() { - return std::make_unique(); +std::unique_ptr> CreateLaunchOutliningPass( + bool globally_unique_func_names) { + return std::make_unique(globally_unique_func_names); } } // namespace TFDevice diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.cc b/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.cc index 4eb8f987a4d..481f2d868e1 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.cc @@ -16,6 +16,10 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.h" #include +#include +#include +#include +#include #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/OpDefinition.h" // from @llvm-project @@ -24,13 +28,27 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_traits.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" -#include "tensorflow/compiler/mlir/tensorflow/transforms/constant_fold_utils.h" -#include "tensorflow/compiler/mlir/tensorflow/utils/eval_util.h" +#include "tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h" #include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/tfrt/fallback/fallback_state.h" +#include "tensorflow/core/tfrt/fallback/op_kernel_runner.h" +#include "tensorflow/tsl/util/device_name_utils.h" namespace mlir { namespace TF { +static bool IsOk(const tensorflow::Status& s) { + if (s.ok()) return true; + VLOG(2) << s.message(); + return false; +} + +#define RETURN_FAILURE_IF_ERROR(expr) \ + if (!IsOk(expr)) { \ + return mlir::failure(); \ + } + // Implements a TF specific policy on when constant folding is allowed. 
// Policy: // @@ -71,15 +89,129 @@ static bool ShouldBeFolded(Operation* inst) { #ifdef TF_DISABLE_CONSTANT_FOLDING constexpr int64_t kResultsSizeThreshold = 0; #else - constexpr int64_t kResultsSizeThreshold = (1 << 23); // 1 MB + constexpr int64_t kResultsSizeThreshold = (1 << 23); // 1 MB #endif - constexpr int64_t kOperandsSizeThreshold = (1 << 30); // 1 GB + constexpr int64_t kOperandsSizeThreshold = (1 << 30); // 128 MB return (operands_size <= kOperandsSizeThreshold) && (has_unknown_shape || (results_size <= kResultsSizeThreshold) || (results_size <= kSizeFactor * operands_size)); } +static const tensorflow::tfrt_stub::FallbackState& GetDefaultFallbackState() { + static const auto* const fallback_state = []() { + tensorflow::SessionOptions session_options; + tensorflow::FunctionDefLibrary fdef_lib; + auto fallback_state = + tensorflow::tfrt_stub::FallbackState::CreateWithCpuDevice( + session_options, fdef_lib) + .value(); + return fallback_state.release(); + }(); + + return *fallback_state; +} + +static std::function)>* GetDefaultRunner() { + static auto* const default_runner = + new std::function)>( + [](const std::function& f) { f(); }); + return default_runner; +} + +static mlir::LogicalResult EvaluateOperation( + mlir::Operation* inst, llvm::ArrayRef operands, + llvm::SmallVectorImpl* results) { + // If any operand is nullptr returns true for a failure. + // TODO(b/120678030): remove this constraint if we find operators can be + // evaluated with some unknown operands. + if (std::any_of(operands.begin(), operands.end(), + [](mlir::Attribute operand) { return !operand; })) { + VLOG(1) << "Can't evaluate since not all operands are constant."; + return mlir::failure(); + } + + // Builds TF operation and sets all the attributes. + std::string node_name = "unnamed"; + if (auto attr = inst->getAttrOfType("name")) { + node_name = std::string(attr.getValue()); + } + auto node_def_or = tensorflow::ConvertTFDialectOpToNodeDef( + inst, node_name.c_str(), /*ignore_unregistered_attrs=*/true); + RETURN_FAILURE_IF_ERROR(node_def_or.status()); + const auto& node_def = node_def_or.value(); + + const auto& fallback_state = GetDefaultFallbackState(); + + // Explicitly set device to Host CPU instead of the device present in device + // attribute of the MLIR op. The assigned device might be remote, not + // available during compilation or compilation only device for on demand + // execution which may create a recursion if used for constant folding. + auto host_cpu = tensorflow::DeviceNameUtils::FullName( + /*job=*/"localhost", /*replica=*/0, /*task=*/0, /*type=*/"CPU", /*id=*/0); + + auto statusor_runner = tensorflow::tfrt_stub::OpKernelRunner::Create( + node_def->op(), node_def->name(), host_cpu, operands.size(), + [&](tensorflow::AttrValueMap* attr_value_map) { + *attr_value_map = node_def->attr(); + return tensorflow::OkStatus(); + }, + fallback_state.device_manager(), + fallback_state.process_function_library_runtime()); + RETURN_FAILURE_IF_ERROR(statusor_runner.status()); + const auto& runner = *statusor_runner; + + VLOG(1) << "Start to evaluate node: " << node_def->DebugString(); + + std::vector inputs; + + // Adds inputs to the TF operation. 
+ for (const auto operand : operands) { + tensorflow::Tensor tensor; + RETURN_FAILURE_IF_ERROR(tensorflow::ConvertToTensor(operand, &tensor)); + inputs.push_back(std::move(tensor)); + } + + std::vector input_values; + for (auto& tensor : inputs) { + input_values.emplace_back(); + input_values.back().tensor = &tensor; + } + + tensorflow::OpKernelContext::Params params; + params.inputs = input_values; + params.device = runner.device(); + params.op_kernel = runner.op_kernel(); + // Still use original device's resource_manager. + params.resource_manager = runner.resource_manager(); + params.input_alloc_attrs = runner.input_alloc_attrs(); + params.output_attr_array = runner.output_alloc_attrs().data(); + // Following two parameters are used to support executing tf.data via + // fallback. + params.function_library = runner.function_library_runtime(); + params.runner = GetDefaultRunner(); + + // Executes the TF operation. + tensorflow::OpKernelContext op_kernel_context(¶ms); + runner.Run(&op_kernel_context); + RETURN_FAILURE_IF_ERROR(op_kernel_context.status()); + + // Converts the outputs to MLIR attributes. + mlir::Builder builder(inst->getContext()); + + for (int i = 0; i < op_kernel_context.num_outputs(); ++i) { + DCHECK(op_kernel_context.mutable_output(i)); + auto attr_or = tensorflow::ConvertTensor( + *op_kernel_context.mutable_output(i), &builder); + RETURN_FAILURE_IF_ERROR(attr_or.status()); + results->push_back(attr_or.value()); + } + + VLOG(1) << "Evaluate node " << node_name << " successfully!"; + + return mlir::success(); +} + LogicalResult ConstantFoldFallbackHook( Operation* inst, ArrayRef operands, SmallVectorImpl& results) { // NOLINT @@ -136,13 +268,6 @@ LogicalResult ConstantFoldFallbackHook( // size/size increase due to folding. if (!ShouldBeFolded(inst)) return failure(); - // TODO(jpienaar): Currently this persists the entire program execution. This - // should instead be per module/set from the Graph being executed in TF (if - // any) so that the value of variables in the context could be read. - // Note: Sharing the context is fine as ops are side-effect free. - static TFE_Context* ctx = GetContextForConstantFold(); - if (!ctx) return failure(); - // Returns directly if any of the operands is not an elements attributes. if (std::any_of(operands.begin(), operands.end(), [](Attribute attr) { return !attr || !attr.isa(); @@ -160,8 +285,7 @@ LogicalResult ConstantFoldFallbackHook( static auto* mu = new tensorflow::mutex(); tensorflow::mutex_lock l(*mu); SmallVector constants; - LogicalResult status = - tensorflow::EvaluateOperation(inst, inputs, ctx, &constants); + LogicalResult status = EvaluateOperation(inst, inputs, &constants); results.assign(constants.begin(), constants.end()); return status; } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/einsum.cc b/tensorflow/compiler/mlir/tensorflow/transforms/einsum.cc index 4c379b4e5b5..51438ac4901 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/einsum.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/einsum.cc @@ -48,6 +48,7 @@ limitations under the License. 
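For the constant_fold.cc change above, the folding result itself is unchanged; what changes is the mechanism: the op's registered CPU kernel now runs through the TFRT fallback (OpKernelRunner) instead of a process-wide eager TFE_Context, and evaluation always happens on host CPU even when the op carries a device attribute. A schematic example (values and the device string are illustrative):

  %0 = "tf.Const"() {value = dense<2.0> : tensor<f32>} : () -> tensor<f32>
  %1 = "tf.Const"() {value = dense<3.0> : tensor<f32>} : () -> tensor<f32>
  %2 = "tf.Mul"(%0, %1) {device = "/job:worker/replica:0/task:0/device:TPU:0"} : (tensor<f32>, tensor<f32>) -> tensor<f32>

  // The hook evaluates tf.Mul on localhost CPU and rewrites its uses to:
  %3 = "tf.Const"() {value = dense<6.0> : tensor<f32>} : () -> tensor<f32>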
#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/transforms/collection_ops_util.h" #include "tensorflow/compiler/mlir/tensorflow/utils/dynamic_shape_utils.h" #include "tensorflow/compiler/mlir/tensorflow/utils/verification_utils.h" #include "tensorflow/core/util/matmul_bcast.h" @@ -82,6 +83,27 @@ arith::ConstantOp createI64ConstantOp(llvm::ArrayRef values, return rewriter->create(loc, values_type, constant_attr); } +// Function to create a tf.SumOp to sum the element in 'value' reduced along the +// 'redux_axes'. +TF::SumOp createSumOp(Value value, Location loc, + llvm::ArrayRef redux_axes, + PatternRewriter* rewriter) { + Value redux_op = createI32ConstantOp(redux_axes, loc, rewriter); + + auto value_type = value.getType().cast(); + auto shape = value_type.getShape(); + llvm::SmallVector sum_shape; + for (int i = 0; i < shape.size(); ++i) { + if (std::find(redux_axes.begin(), redux_axes.end(), i) == + redux_axes.end()) { + sum_shape.push_back(shape[i]); + } + } + return rewriter->create( + loc, RankedTensorType::get(sum_shape, value_type.getElementType()), value, + redux_op); +} + TF::TransposeOp createTransposeOp(Value value, Location loc, llvm::ArrayRef permutation, PatternRewriter* rewriter) { @@ -344,6 +366,61 @@ std::tuple FlattenEllipsis( return std::make_tuple(new_lhs, new_rhs, new_output); } +// vectors/maps to map the dimensions of lhs with output in unary einsum op +std::optional GetEinsumDimensionNumbersUnary( + llvm::StringRef equation, RankedTensorType lhs_ty) { + llvm::StringRef lhs; + llvm::StringRef out; + std::tie(lhs, out) = equation.split("->"); + if (lhs.empty() || out.empty()) return std::nullopt; + + // Try to flatten the "..." if possible. + int lhs_named_label, rhs_named_label; + + // following rhs and rhs_ty variables are non-functional here only created to + // comply with the existing API + llvm::StringRef rhs; + RankedTensorType rhs_ty; + + auto available_labels = + GetAvailableLabels(lhs, rhs, &lhs_named_label, &rhs_named_label); + if (!available_labels.has_value()) return std::nullopt; + + auto flattended_labels = + FlattenEllipsis(lhs, lhs_named_label, rhs, rhs_named_label, out, lhs_ty, + rhs_ty, available_labels.value()); + + lhs = std::get<0>(flattended_labels); + out = std::get<2>(flattended_labels); + + auto lhs_map_or = EquationToMap(lhs); + if (!lhs_map_or.has_value()) return std::nullopt; + auto lhs_map = lhs_map_or.value(); + + auto out_map_or = EquationToMap(out); + if (!out_map_or.has_value()) return std::nullopt; + auto out_map = out_map_or.value(); + + EinsumDimensionNumbers dnums; + for (int64_t i = 0; i < lhs.size(); ++i) { + auto out_index = out_map.find(lhs[i]); + if (out_index == out_map.end()) { + dnums.lhs.emplace_back(i); + } else { + dnums.lhs_out.emplace_back(i, out_index->second); + } + } + + for (int64_t i = 0; i < out.size(); ++i) { + auto lhs_index = lhs_map.find(out[i]); + if (lhs_index == lhs_map.end()) { + // out only isn't supported + return std::nullopt; + } + } + return dnums; +} + std::optional GetEinsumDimensionNumbers( llvm::StringRef equation, RankedTensorType lhs_ty, RankedTensorType rhs_ty) { @@ -419,6 +496,62 @@ std::optional GetEinsumDimensionNumbers( return dnums; } +// Function to replace a unary einsum op, that can undergo simple transpose, to +// an explicit transpose op. 
+LogicalResult rewriteToReduceSumAndTranspose(TF::EinsumOp op, + EinsumDimensionNumbers dnums, + PatternRewriter& rewriter) { + auto inputs = op.getInputs(); + Value lhs = inputs.front(); + + // Having indices in dnums.lhs list indicates that the ranks of the input and + // output to the unary einsum are not equal making it non-candidate for simple + // transpose. + bool needs_reduce_sum = false; + if (!dnums.lhs.empty()) { + needs_reduce_sum = true; + llvm::SmallVector reduce_idcs(dnums.lhs.size()); + for (int64_t i = 0; i < dnums.lhs.size(); ++i) { + reduce_idcs[i] = dnums.lhs[i]; + } + + lhs = createSumOp(lhs, lhs.getLoc(), reduce_idcs, &rewriter); + } + + llvm::SmallVector lhs_transpose; + lhs_transpose.reserve(dnums.lhs_out.size()); + + llvm::SmallDenseMap out_lhs_map(dnums.lhs_out.size()); + for (int64_t i = 0; i < dnums.lhs_out.size(); ++i) { + out_lhs_map[std::get<1>(dnums.lhs_out[i])] = std::get<0>(dnums.lhs_out[i]); + } + + bool needs_transpose = false; + for (int64_t i = 0; i < dnums.lhs_out.size(); ++i) { + if (std::get<0>(dnums.lhs_out[i]) > + lhs.getType().cast().getRank() - 1) { + continue; + } + + if (std::get<0>(dnums.lhs_out[i]) != std::get<1>(dnums.lhs_out[i])) { + needs_transpose = true; + } + lhs_transpose.push_back(out_lhs_map[i]); + } + + if (!needs_reduce_sum && !needs_transpose) { + return rewriter.notifyMatchFailure( + op, "unary einsum equation does not require transpose"); + } else if (needs_reduce_sum && !needs_transpose) { + rewriter.replaceOp(op, lhs); + return success(); + } + + lhs = createTransposeOp(lhs, lhs.getLoc(), lhs_transpose, &rewriter); + rewriter.replaceOp(op, lhs); + return success(); +} + std::vector inverseTransposeVector( llvm::ArrayRef input, llvm::ArrayRef permutation) { std::vector output(input.size()); @@ -682,6 +815,27 @@ LogicalResult rewriteToBatchMatmul(TF::EinsumOp op, return success(); } +LogicalResult matchAndRewriteUnaryEinsumOp(TF::EinsumOp op, + PatternRewriter& rewriter) { + if (op->getNumOperands() != 1) { + return rewriter.notifyMatchFailure( + op, "Function only supports unary einsum op"); + } + RankedTensorType lhs = + op.getOperand(0).getType().dyn_cast_or_null(); + if (!lhs) { + return failure(); + } + // unary einsum op is only supported to the case where the operation can be + // replaced using reduce_sum and/or transpose + if (const auto dnums_or = + GetEinsumDimensionNumbersUnary(op.getEquation(), lhs)) { + return rewriteToReduceSumAndTranspose(op, dnums_or.value(), rewriter); + } + + return rewriter.notifyMatchFailure(op, "unsupported einsum lowering"); +} + #define GEN_PASS_DEF_TRANSFORMEINSUMPASS #include "tensorflow/compiler/mlir/tensorflow/transforms/tf_passes.h.inc" @@ -703,6 +857,10 @@ void TransformEinsumPass::runOnOperation() { LogicalResult ConvertTFEinsumOp::matchAndRewrite( TF::EinsumOp op, PatternRewriter& rewriter) const { + if (op->getNumOperands() == 1) { + return matchAndRewriteUnaryEinsumOp(op, rewriter); + } + RankedTensorType lhs = op.getOperand(0).getType().dyn_cast_or_null(); RankedTensorType rhs = @@ -711,10 +869,10 @@ LogicalResult ConvertTFEinsumOp::matchAndRewrite( return failure(); } - // TODO(b/162328998) Better support Einsum with dynamic input. Currently, one - // dynamic dimension is always supported. If there are two or more dynamic - // dimensions, it is supported if they only exist in a single component - // among: L0,...,Ln R0,...,Rn or C0,...,Cn. + // TODO(b/162328998) Better support Einsum with dynamic input. Currently, + // one dynamic dimension is always supported. 
If there are two or more + // dynamic dimensions, it is supported if they only exist in a single + // component among: L0,...,Ln R0,...,Rn or C0,...,Cn. if (const auto dnums_or = GetEinsumDimensionNumbers(op.getEquation(), lhs, rhs)) return rewriteToBatchMatmul(op, dnums_or.value(), rewriter); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/embedding_pipelining.cc b/tensorflow/compiler/mlir/tensorflow/transforms/embedding_pipelining.cc new file mode 100644 index 00000000000..e5671bf5961 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/transforms/embedding_pipelining.cc @@ -0,0 +1,922 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This pass implements automated pipelining for TPU embeddings defined using +// the TF2 Embedding API. This is designed for applications that have an +// embedding lookup on the SparseCore, followed by one or more dense layers on +// TensorCores, optionally followed by a backward pass (training update) with +// more ops on the SparseCore. Ops are broken up into: +// 1. SC forward pass +// 2. TC forward/backward pass +// 3. SC backward pass +// 4. non-TPU loop counter updates +// These 4 functions are then staggered so as to enable parallel execution. 
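The net effect of this pass on a training loop body is easiest to see in IR. A schematic sketch (call signatures are collapsed to a single tensor each and the data flow is simplified; the four function names follow the "_func_" + name convention used when the bodies are extracted further down):

  func.func private @while_body(%arg0: tensor<i64>, %arg1: tensor<128xf32>)
      -> (tensor<i64>, tensor<128xf32>) {
    // 4. non-TPU loop counter updates
    %ctr = "tf.StatefulPartitionedCall"(%arg0) {f = @_func_non_tpu, config = "", config_proto = "", executor_type = ""} : (tensor<i64>) -> tensor<i64>
    // 1. SC forward pass (embedding lookups)
    %act = "tf.StatefulPartitionedCall"(%arg1) {f = @_func_sc_forward, config = "", config_proto = "", executor_type = ""} : (tensor<128xf32>) -> tensor<128xf32>
    // 2. TC forward/backward pass (dense layers and gradients)
    %grad = "tf.StatefulPartitionedCall"(%act) {f = @_func_core_tpu, config = "", config_proto = "", executor_type = ""} : (tensor<128xf32>) -> tensor<128xf32>
    // 3. SC backward pass (embedding update)
    %upd = "tf.StatefulPartitionedCall"(%grad) {f = @_func_sc_backward, config = "", config_proto = "", executor_type = ""} : (tensor<128xf32>) -> tensor<128xf32>
    return %ctr, %upd : tensor<i64>, tensor<128xf32>
  }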
+ +#include +#include +#include +#include +#include + +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/Support/Casting.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/Region.h" // from @llvm-project +#include "mlir/IR/SymbolTable.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/IR/Visitors.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Transforms/InliningUtils.h" // from @llvm-project +#include "mlir/Transforms/RegionUtils.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h" + +#define GEN_PASS_DEF_EMBEDDINGPIPELININGPASS +#include "tensorflow/compiler/mlir/tensorflow/transforms/tf_passes.h.inc" + +static constexpr char kEmbeddingPipelining[] = "_embedding_pipelining"; +static constexpr char kEmbeddingForward[] = "forward"; +static constexpr char kEmbeddingBackward[] = "backward"; +static constexpr char kDevice[] = "device"; +static constexpr llvm::StringRef kTpuCompilationStatus = + "_tpu_compilation_status"; + +namespace mlir { +namespace TFDevice { +namespace { + +struct EmbeddingPipeliningPass + : public ::impl::EmbeddingPipeliningPassBase { + void getDependentDialects(mlir::DialectRegistry& registry) const override { + registry.insert(); + } + + void runOnOperation() override; +}; + +template +std::vector GetValueTypes(const InputContainer& input) { + // Convert a list of mlir::Value's into a list of mlir::Type's + std::vector types; + types.reserve(input.size()); + for (auto val : input) types.push_back(val.getType()); + return types; +} + +bool IsResourceType(Type val_type) { + if (auto tensor_type = val_type.dyn_cast()) { + if (tensor_type.getElementType().isa()) { + return true; + } + } + return false; +} + +bool IsTPUOp(mlir::Operation* op) { + return op->hasAttr(TF::kReplicationInfoAttr); +} + +StringAttr GetReplicationAttr(mlir::Operation* op) { + return op->getAttrOfType(TF::kReplicationInfoAttr); +} + +StringAttr GetReplicationAttr(TF::TPUCompilationResultOp op) { + // Special case for getting the replication region for + // TPUCompilationResultsOp. + return op->getAttrOfType(kTpuCompilationStatus); +} + +int64_t GetNumOps(func::FuncOp func) { + int64_t num_ops = 0; + for (auto it = func.begin(); it != func.end(); ++it) ++num_ops; + return num_ops; +} + +void GatherOpsForExtraction(mlir::SetVector* operations, + const mlir::SetVector& ops_to_avoid, + bool predecessors, bool successors) { + // Walk the input and output dependencies of the Ops in `operations` to form + // the closer of Ops needed to evaluate 'operations'. Input dependencies are + // walked if 'predecessors' is true and output dependencies are walked if + // 'successors' is true. In either case, if a discoverd Op is in the + // 'ops_to_avoid' set, then the dependency walking is terminated. 
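A small example of the closure walk GatherOpsForExtraction performs (the ops are arbitrary TF ops chosen only to form a producer/consumer chain):

  %a = "tf.Const"() {value = dense<1.0> : tensor<f32>} : () -> tensor<f32>
  %b = "tf.Abs"(%a) : (tensor<f32>) -> tensor<f32>
  %c = "tf.Neg"(%b) : (tensor<f32>) -> tensor<f32>
  // Starting from operations = {Abs}:
  //   predecessors = true pulls in Const (operands are walked transitively),
  //   successors   = true pulls in Neg   (users are walked transitively),
  // and the walk skips block arguments and return ops, and never continues
  // past anything in ops_to_avoid.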
+ llvm::SetVector ops_to_process(*operations); + llvm::SetVector new_ops; + + while (!ops_to_process.empty()) { + for (Operation* op : ops_to_process) { + if (predecessors) { + for (Value operand : op->getOperands()) { + // Stop at the block boundary. + if (operand.isa()) continue; + + Operation* predecessor = operand.getDefiningOp(); + if (!operations->contains(predecessor) && + !ops_to_avoid.contains(predecessor)) { + new_ops.insert(operand.getDefiningOp()); + operations->insert(operand.getDefiningOp()); + } + } + } + if (successors) { + for (mlir::Operation* successor : op->getUsers()) { + // Don't include the return op + if (llvm::isa(successor)) continue; + + if (!operations->contains(successor) && + !ops_to_avoid.contains(successor)) { + new_ops.insert(successor); + operations->insert(successor); + } + } + } + } + ops_to_process.swap(new_ops); + new_ops.clear(); + } +} + +TF::StatefulPartitionedCallOp MakeFuncCaller( + mlir::OpBuilder& builder, const Location& loc, func::FuncOp func, + const llvm::SetVector& operands) { + // Constructs a tf.StatefulPartitionedCall to the function provided in 'func' + // using the operands in 'operands'. Assumes the insertion point on builder is + // already set. + auto symbol = + mlir::SymbolRefAttr::get(builder.getContext(), func.getSymName()); + auto result_types = func.getResultTypes(); + auto caller = builder.create( + loc, result_types, operands.getArrayRef(), symbol, + /*config=*/builder.getStringAttr(""), + /*config_proto=*/builder.getStringAttr(""), + /*executor_type=*/builder.getStringAttr("")); + caller.setFAttr(symbol); + return caller; +} + +func::FuncOp CreateFnWithSignature(ModuleOp module, + const llvm::SetVector& inputs, + const llvm::SetVector& outputs, + const std::string& name) { + // Creates an empty func.FuncOp with a signature compatible with 'inputs' + // (operands) and 'outputs' (results). + OpBuilder builder(module); + + std::vector input_types = GetValueTypes(inputs); + std::vector output_types = GetValueTypes(outputs); + builder.setInsertionPointToEnd(&module.getBodyRegion().back()); + func::FuncOp func_op = builder.create( + module.getLoc(), name, + builder.getFunctionType(input_types, output_types)); + func_op.setPrivate(); + + return func_op; +} + +TF::StatefulPartitionedCallOp EncapsulateOpsInFunc( + OpBuilder& builder, const llvm::SetVector& ops, + const llvm::SetVector& inputs, const llvm::SetVector& outputs, + func::FuncOp parent_func, ModuleOp module, const std::string& name) { + // Moves all of the Operations in 'ops' into a newly created func.FuncOp + // function named 'name' and replaces the original ops with a call to the + // newly created function using a tf.StatefulPartitionedCall. Here, + // 'parent_func' is the function that holds the original set of ops. + // Note, 'inputs' and 'outputs' are the predetermined set of values that + // should become the operands and return values, respectively. + auto insertion_point = builder.saveInsertionPoint(); + func::FuncOp new_func = CreateFnWithSignature(module, inputs, outputs, + absl::StrCat("_func_", name)); + + // This preserves the order of the ops that was in the original parent + // funtion. This is critical for preserving correctness in the presence of + // resource variables and stateful functions. 
+ std::vector topological_order; + for (Operation& op : parent_func.getOps()) + if (ops.contains(&op)) topological_order.push_back(&op); + + // Create the partitioned call + builder.restoreInsertionPoint(insertion_point); + auto caller = MakeFuncCaller(builder, module.getLoc(), new_func, inputs); + + Block* block = new_func.addEntryBlock(); + + for (Operation* op : topological_order) op->moveBefore(block, block->end()); + + // Replace the 'inputs' values with the new function's arguments. + for (auto p : llvm::zip(inputs, new_func.getArguments())) + replaceAllUsesInRegionWith(std::get<0>(p), std::get<1>(p), + new_func.getBody()); + + builder.setInsertionPointToEnd(block); + builder.create(parent_func.getLoc(), outputs.getArrayRef()); + + // Replace the original 'outputs' values with the result of the call to the + // new function. + for (auto p : llvm::zip(outputs, caller->getResults())) + replaceAllUsesInRegionWith(std::get<0>(p), std::get<1>(p), + parent_func.getBody()); + + return caller; +} + +void UpdateAndInsertTPUOps(TF::StatefulPartitionedCallOp caller, + TF::TPUReplicateMetadataOp metadata_op, + TF::TPUCompilationResultOp compilation_op, + StringAttr old_group) { + // Adds the TPUReplicateMetatdataOp and TPUCompilationResultOp ops to the + // function called by the provided 'caller'. + mlir::CallInterfaceCallable callable = caller.getCallableForCallee(); + mlir::SymbolRefAttr sym = callable.dyn_cast(); + auto func = llvm::dyn_cast( + mlir::SymbolTable::lookupNearestSymbolFrom(caller, sym)); + OpBuilder builder(func.getBody()); + + StringAttr new_group = builder.getStringAttr( + absl::StrCat(old_group.getValue().str(), caller.getF().str())); + + builder.insert(metadata_op.clone()); + for (Operation& op : func.getOps()) { + if (!IsTPUOp(&op)) continue; + op.setAttr(TF::kReplicationInfoAttr, new_group); + } + TF::TPUCompilationResultOp new_result = compilation_op.clone(); + new_result->setAttr(kTpuCompilationStatus, new_group); + builder.insert(new_result); +} + +template +LogicalResult FindAndExcludeOp(func::FuncOp func, + const StringAttr& replication_attr, + llvm::SetVector& merged_set, + OpType& found_op) { + // Find the TPUReplicationMetadata or TPUCompilationResult ops which will be + // cloned/inserted into each region. We add them to the merged_set so that + // they're ignored when extracting the four main functions. + found_op = nullptr; + for (OpType op : func.getOps()) { + if (found_op != nullptr) { + func.emitOpError() << "number of " << found_op.getOperationName() + << " in loop body is not 1"; + return LogicalResult::failure(); + } + if (GetReplicationAttr(op) != replication_attr) { + op.emitOpError() << "is not part of the replication region " + << replication_attr << " vs " << GetReplicationAttr(op); + return LogicalResult::failure(); + } + found_op = op; + merged_set.insert(found_op); + } + return LogicalResult::success(); +} + +LogicalResult FindOwningWhileOp(func::FuncOp body_func, ModuleOp module, + TF::WhileOp* while_op) { + // Given a while loop body function 'body_func', find the tf.While Op that + // uses it. 
+ auto uses_optional = body_func.getSymbolUses(module); + if (!uses_optional.has_value()) { + body_func.emitOpError() << "no use of while loop body"; + return LogicalResult::failure(); + } + *while_op = nullptr; + for (auto& use : uses_optional.value()) { + if (llvm::isa(use.getUser())) { + if (*while_op != nullptr) { + use.getUser()->emitOpError() << "multiple users of function."; + return LogicalResult::failure(); + } else { + *while_op = llvm::cast(use.getUser()); + } + } else { + use.getUser()->emitOpError() << "non while use of function."; + return LogicalResult::failure(); + } + } + // TODO(bfontain): If the while op is not present we could just split things + // or we wait until the compiler supports multiple regions? + if (while_op == nullptr) { + body_func.emitOpError() << "unable to find while body user."; + return LogicalResult::failure(); + } + return LogicalResult::success(); +} + +LogicalResult FindForwardPassOps(OpBuilder& builder, + llvm::SetVector& forward_pass_ops, + llvm::SetVector& backward_pass_ops, + llvm::SetVector& merged_set, + func::FuncOp loop_body_func, + const int num_replicas) { + // Find all the ops that are to be included in the 'sc_forward' function which + // will be executed on the SparseCore. Note, 'forward_pass_ops' is initially + // seeded with ops from the input MLIR graph that have the + // _embedding_pipelining="forward" attribute which is set by the TF2 Embedding + // API. + // + // When outputs of the forward pass function are used outside of it, we'll + // need to insert a TPUReplicatedOutput Op and include that in the + // forward_pass_ops. And if that usage is also on the TPU (either TensorCore + // or SparseCore) we'll need to insert a matching TPUReplicatedInput. We do + // this before the Ops are removed from the original function/graph so that + // function operands and return values are handled automatically. + + // First, walk the op dependencies. + GatherOpsForExtraction(&forward_pass_ops, merged_set, /*predecessors=*/true, + /*successors=*/false); + + // Locate which variable inputs are part of the forwards pass. These will + // also be used in the backwards pass. We need to create a 'private' copy + // of the TpuReplicatedInput for for the fowards pass if there are users + // outside the pass. Note that in the case of the backwards pass existing + // this will be the case. + // This means that when we have put all out sections together some resource + // inputs will have multiple TPUReplicateInput nodes, so we will need a final + // pass to merge these together into the earliest copy. + llvm::SetVector forward_variable_inputs; + + // Validate that the only resource inputs that are read by ops in + // forward_pass_ops are dataset and variable ops. + int64_t resource_count = 0; + for (auto argument : loop_body_func.getArguments()) { + // Check that all resource arguments are either fed to iterator get next + // or a TPUReplicatedInput with is_packed. 
+ + if (IsResourceType(argument.getType())) { + resource_count++; + bool is_variable = false; + bool is_non_variable = false; + bool use_in_forward = false; + bool use_in_not_forward = false; + for (auto user : argument.getUsers()) { + if (llvm::isa(user)) continue; + if (!forward_pass_ops.contains(user)) { + use_in_not_forward = true; + } else { + use_in_forward = true; + } + if (TF::TPUReplicatedInputOp input = + llvm::dyn_cast(user)) { + if (!input.getIsPacked()) { + input.emitOpError() << "unexpected variable input, not packed"; + return LogicalResult::failure(); + } + + if (is_variable) { + input.emitOpError() << "unexpected multiple TPUReplicatedInputOp " + << "for single argument"; + return LogicalResult::failure(); + } + is_variable = true; + } else { + is_non_variable = true; + } + } + if (use_in_forward && use_in_not_forward) { + loop_body_func.emitOpError() + << "resource input " << argument.getArgNumber() + << " is used both in the forwards and " + << "not forward passes dataset"; + return LogicalResult::failure(); + } + if (is_non_variable && is_variable) { + loop_body_func.emitOpError() + << "resource input " << argument.getArgNumber() + << " is used both as a varible and not " + << " a variable"; + return LogicalResult::failure(); + } + if (is_variable && use_in_forward) + forward_variable_inputs.insert(argument.getArgNumber()); + } + } + + VLOG(3) << "Found " << forward_variable_inputs.size() + << " variables used in forward pass of " << resource_count + << " total resource inputs"; + + // Clone the TPUReplicatedInputs. + int64_t cloned_inputs = 0; + for (int64_t index : forward_variable_inputs) { + Value argument = loop_body_func.getArgument(index); + // Uses of this argument should only be the return and the + // TPUReplicateInputOp. This is checked by the loop above. + Operation* input_ptr = nullptr; + for (Operation* user : argument.getUsers()) { + if (llvm::isa(user)) { + input_ptr = user; + break; + } + } + TF::TPUReplicatedInputOp input = + llvm::cast(input_ptr); + + // Validate that all users of the TPUReplicatedInput are ReadVariable + // or AssignVariable ops and check if any are outside the forwards pass. + bool duplicate_needed = false; + for (Operation* next_user : input.getOutput().getUsers()) { + if (!llvm::isa(next_user) && + !llvm::isa(next_user)) { + next_user->emitOpError() + << "unexpected user of output of TPUReplicatedInputOp"; + return LogicalResult::failure(); + } + if (!forward_pass_ops.contains(next_user)) duplicate_needed = true; + } + if (!duplicate_needed) continue; + + cloned_inputs++; + builder.setInsertionPointAfter(input); + forward_pass_ops.remove(input); + + TF::TPUReplicatedInputOp private_input = input.clone(); + builder.insert(private_input); + forward_pass_ops.insert(private_input); + for (OpOperand& next_use : input.getOutput().getUses()) { + if (!forward_pass_ops.contains(next_use.getOwner())) continue; + next_use.getOwner()->setOperand(next_use.getOperandNumber(), + private_input.getOutput()); + } + } + + VLOG(2) << "Cloned " << cloned_inputs << " TPUReplicatedInputOps"; + + // Add TPUReplicatedInput/TPUReplicatedOutput pairs along each edge. + llvm::SetVector new_forward_ops; + for (Operation* op : forward_pass_ops) { + // TODO(bfontain): Should validate that all the TPU ops are in the same + // replication region. 
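Schematically, the replicated input/output pair that this loop inserts on each forward-pass result consumed by a TPU op outside the forward pass, for num_replicas = 2 (types are illustrative):

  // %act is produced inside the forward pass and used by a TPU op outside it.
  %out:2 = "tf.TPUReplicatedOutput"(%act) {device = ""} : (tensor<128xf32>) -> (tensor<128xf32>, tensor<128xf32>)
  %in = "tf.TPUReplicatedInput"(%out#0, %out#1) {device = ""} : (tensor<128xf32>, tensor<128xf32>) -> tensor<128xf32>
  // The out-of-region uses of %act are rewired to %in, so the replicated
  // output stays with the forward pass and the replicated input travels with
  // the consumer when the functions are later split apart.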
+ if (!IsTPUOp(op)) continue; + for (Value result : op->getResults()) { + std::vector> out_of_region_use; + for (OpOperand& use : result.getUses()) { + auto use_owner = use.getOwner(); + // TODO(bfontain): Error check here, if the use.getOwner() is not a TPU + // then this op must be a TPUReplicatedOutputOp. + if (IsTPUOp(use_owner) && !forward_pass_ops.contains(use_owner)) + out_of_region_use.push_back( + std::make_pair(use_owner, use.getOperandNumber())); + } + if (out_of_region_use.empty()) continue; + builder.setInsertionPointAfter(op); + std::vector types(num_replicas, result.getType()); + TF::TPUReplicatedOutputOp replicated_output = + builder.create(op->getLoc(), + TypeRange(types), result); + new_forward_ops.insert(replicated_output); + // TODO(bfontain): Check for other attributes. + replicated_output->setAttr(kDevice, builder.getStringAttr("")); + TF::TPUReplicatedInputOp input = builder.create( + op->getLoc(), result.getType(), replicated_output.getResults()); + input->setAttr(kDevice, builder.getStringAttr("")); + mlir::Value new_value = input.getOutput(); + + if (mlir::isa( + result.getDefiningOp())) { + TF::TPUAnnotateTensorsWithDynamicShapeOp annotate_op = + builder.create( + op->getLoc(), result.getType(), new_value, + result.getDefiningOp()->getAttrs()); + for (auto [operation, index] : out_of_region_use) { + if (!backward_pass_ops.contains(operation)) { + operation->emitOpError() + << "expect all dynamic inputs consumed by backwards pass."; + return LogicalResult::failure(); + } + } + + backward_pass_ops.insert(annotate_op); + new_value = annotate_op->getResult(0); + } + for (auto [operation, index] : out_of_region_use) + operation->setOperand(index, new_value); + } + } + + VLOG(2) << "inserted " << new_forward_ops.size() << " TPU Input/Output ops"; + forward_pass_ops.insert(new_forward_ops.begin(), new_forward_ops.end()); + return LogicalResult::success(); +} + +LogicalResult FindBackwardPassOps( + OpBuilder& builder, llvm::SetVector& backward_pass_ops, + llvm::SetVector& merged_set, const int num_replicas) { + // Find all the ops that are to be included in the 'sc_backward' function + // which will be executed on the SparseCore. Note, 'backward_pass_ops' is + // initially seeded with ops from the input MLIR graph that have the + // _embedding_pipelining="backward" attribute which is set by the TF2 + // Embedding API. + // + // Since we're inserting a replication boundary around the backward pass + // function, we'll also need to make sure TPUReplicatedInputOp and + // TPUReplicatedOutputOp ops are inserted as necessary. + + // First, walk the Ops dependencies. + GatherOpsForExtraction(&backward_pass_ops, merged_set, /*predecessors=*/false, + /*successors=*/true); + + VLOG(3) << "found " << backward_pass_ops.size() << " backwards pass ops"; + + // If any inputs are to the backward_pass_ops region are direct + // TPUReplicatedInput ops, then include (if this is the only use) or + // clone the op. This will be the case for all Read/Assign variable ops. + + llvm::SetVector to_clone; + llvm::SetVector to_insert; + + for (Operation* op : backward_pass_ops) { + for (OpOperand& input_value : op->getOpOperands()) { + Operation* predecessor_op = input_value.get().getDefiningOp(); + if (TF::TPUReplicatedInputOp input = + llvm::dyn_cast(predecessor_op)) { + if (to_clone.contains(input) || to_insert.contains(input)) continue; + // Check if all uses in backwards pass. 
+ bool all_in_backwards = true; + for (Operation* user : input->getUsers()) + if (!backward_pass_ops.contains(user)) all_in_backwards = false; + if (all_in_backwards) + to_insert.insert(input); + else + to_clone.insert(input); + } + } + } + backward_pass_ops.insert(to_insert.begin(), to_insert.end()); + for (TF::TPUReplicatedInputOp input : to_clone) { + builder.setInsertionPointAfter(input); + TF::TPUReplicatedInputOp private_input = input.clone(); + builder.insert(private_input); + backward_pass_ops.insert(private_input); + for (OpOperand& next_use : input.getOutput().getUses()) { + if (!backward_pass_ops.contains(next_use.getOwner())) continue; + next_use.getOwner()->setOperand(next_use.getOperandNumber(), + private_input.getOutput()); + } + } + + VLOG(2) << " cloned " << to_clone.size() << " and inserted " + << to_insert.size() << " TPUReplicatedInput ops"; + + // For all other inputs that go from TPU op to TPU op, insert the + // TPUOutput/Input pair. + + // Add TPUReplicatedInput/TPUReplicatedOutput pairs along each edge. + // TODO(bfontain): Should be merged with the above loop. + llvm::SetVector values_to_add_nodes; + + for (Operation* op : backward_pass_ops) { + // TODO(bfontain): Should validate that all the TPU ops are in the same + // replication region. + // If the op is already a replicated input, no need to to anything. + if (!IsTPUOp(op) || llvm::isa(op)) continue; + for (OpOperand& input_value : op->getOpOperands()) + // TODO(bfontain): Error check here, this line should never be false, + // since we skip the TF::TPUReplicatedInputOp case. + if (IsTPUOp(input_value.get().getDefiningOp()) && + !backward_pass_ops.contains(input_value.get().getDefiningOp())) + values_to_add_nodes.insert(input_value.get()); + } + + for (Value value : values_to_add_nodes) { + builder.setInsertionPointAfter(value.getDefiningOp()); + std::vector types(num_replicas, value.getType()); + Location loc = value.getDefiningOp()->getLoc(); + TF::TPUReplicatedOutputOp output = + builder.create(loc, TypeRange(types), value); + // TODO(bfontain): Check for other attributes. + output->setAttr(kDevice, builder.getStringAttr("")); + TF::TPUReplicatedInputOp input = builder.create( + loc, value.getType(), output.getResults()); + input->setAttr(kDevice, builder.getStringAttr("")); + for (OpOperand& use : value.getUses()) + if (backward_pass_ops.contains(use.getOwner())) + use.getOwner()->setOperand(use.getOperandNumber(), input.getOutput()); + backward_pass_ops.insert(input); + } + + VLOG(2) << " inserted " << values_to_add_nodes.size() + << " TPUReplicatedInput/Output pairs"; + return LogicalResult::success(); +} + +LogicalResult FindCoreTPUOps( + llvm::SetVector& core_tpu_ops, + const llvm::SetVector& forward_pass_ops, + const llvm::SetVector& backward_pass_ops, + const llvm::SetVector& merged_set, + func::FuncOp loop_body_func) { + // Find all of the Ops that are part of the forward/backward pass but aren't + // targeting the SparseCore. Note that we need to include some non-TPU ops + // that flow out of the forward pass function. Otherwise, they would get + // absorbed into the non_tpu function which breaks the pipelining + // decomposition strategy. + // + // Find all the outputs of the forward pass that aren't fed into the backward + // pass. 
+ for (Operation* op : forward_pass_ops) { + for (Value res : op->getResults()) { + for (auto user : res.getUsers()) { + if (!forward_pass_ops.contains(user) && + !backward_pass_ops.contains(user)) { + core_tpu_ops.insert(user); + } + } + } + } + + // Gather all TPU ops marked for compilation in this while loop body that also + // are not in one of the two other sets. + for (Operation& op : loop_body_func.getOps()) { + // Find all TPU ops that don't belong to the forward or backward pass. + if (merged_set.contains(&op) || llvm::isa(op) || + !IsTPUOp(&op) || op.hasAttr(kEmbeddingPipelining)) + continue; + // TODO(bfontain): only collect those ops in a fixed TPUReplica. + core_tpu_ops.insert(&op); + } + + GatherOpsForExtraction(&core_tpu_ops, merged_set, /*predecessors=*/true, + /*successors=*/true); + + // TODO(patn): Verify that all the ops here fall between the forward pass + // and backward pass ops (i.e., not before the forward pass or after the + // backward pass). + return LogicalResult::success(); +} + +LogicalResult FindNonTPUOps(llvm::SetVector& non_tpu_ops, + const llvm::SetVector& merged_set, + func::FuncOp loop_body_func) { + // Find all of the left over Ops after the sc_forward, sc_backward and + // core_tpu ops have been identified. What's left are just the ops necessary + // for updating loop counters etc. + llvm::SetVector non_tpu_args; + for (Operation& op : loop_body_func.getOps()) { + if (merged_set.contains(&op) || llvm::isa(op) || + op.hasAttr(kEmbeddingPipelining)) + continue; + // Note, there should be no TPU ops left at this point. If this trips, + // there's likely a bug in this pass. + if (IsTPUOp(&op)) { + loop_body_func.emitOpError() + << "Unexpcted TPU op found while identifying non-TPU ops."; + return LogicalResult::failure(); + } + non_tpu_ops.insert(&op); + } + + // Validate that remainder_ops takes and returns a subset of the loop carried + // args. This will basically be our set increment fn. + for (Operation* op : non_tpu_ops) + for (Value input : op->getOperands()) + if (BlockArgument arg = llvm::dyn_cast(input)) + // TODO(bfontain): Check that this is actually an argument to the loop + // body. + non_tpu_args.insert(arg.getArgNumber()); + + // All funcs have a return op so this should be safe. + func::ReturnOp return_op = *loop_body_func.getOps().begin(); + + for (OpOperand& operand : return_op->getOpOperands()) { + if (non_tpu_args.contains(operand.getOperandNumber())) { + if (BlockArgument argument = + llvm::dyn_cast(operand.get())) { + if (argument.getArgNumber() != operand.getOperandNumber()) { + return_op.emitOpError() + << "non TPU ops do not divide state into two pieces."; + return LogicalResult::failure(); + } + } else if (!non_tpu_ops.contains(operand.get().getDefiningOp())) { + return_op.emitOpError() + << "non TPU ops do not divide state into two pieces."; + return LogicalResult::failure(); + } + } + } + return LogicalResult::success(); +} + +LogicalResult ExtractOpsAsFunc( + OpBuilder& builder, ModuleOp module, llvm::SetVector& ops, + StringAttr replication_attr, TF::TPUReplicateMetadataOp metadata_op, + TF::TPUCompilationResultOp compilation_op, func::FuncOp parent_func, + const std::string& func_name, Operation** caller) { + // Move the given set of 'ops' into it's own function and replace them with a + // call to that function ('caller'). if 'metadata_op' and 'compilation_op' are + // non-null, also insert those (i.e., target the resulting function to the + // TPU). Here, 'parent_func' is the func.FuncOp that owns the ops in 'ops'. 
+ // + // Returns in 'caller' a tf.StatefulPartitionedCallOp that calls the function + // that was extracted.. + + // Find the input edges to form the set of operands to the new function call. + llvm::SetVector inputs; + for (Operation* op : ops) { + for (Value operand : op->getOperands()) { + Operation* defining_op = operand.getDefiningOp(); + if (!ops.contains(defining_op)) inputs.insert(operand); + } + } + // Find the output edges to form the set of resutls of the new function call. + llvm::SetVector results; + for (Operation* op : ops) { + for (auto result : op->getResults()) { + for (const OpOperand& operand : result.getUsers()) { + if (!ops.contains(operand.getOwner())) { + results.insert(result); + break; + } + } + } + } + llvm::SetVector outputs; + for (auto output : results) outputs.insert(output); + auto tf_caller = EncapsulateOpsInFunc(builder, ops, inputs, outputs, + parent_func, module, func_name); + if (!ops.empty() && metadata_op != nullptr && compilation_op != nullptr) + UpdateAndInsertTPUOps(tf_caller, metadata_op, compilation_op, + replication_attr); + *caller = tf_caller; + return LogicalResult::success(); +} + +void EmbeddingPipeliningPass::runOnOperation() { + ModuleOp module = getOperation(); + + llvm::SetVector forward_pass_ops; + llvm::SetVector backward_pass_ops; + + // Find all ops that we know compose the embedding forward and backward pass. + // These ops are only tagged if one enables the + // `pipeline_execution_with_tensor_core` flag in the mid-level API. + WalkResult walk_result = module.walk([&](Operation* op) -> WalkResult { + if (op->hasAttr(kEmbeddingPipelining)) { + const std::string region = + op->getAttrOfType(kEmbeddingPipelining).getValue().str(); + if (region == kEmbeddingForward) { + forward_pass_ops.insert(op); + } else if (region == kEmbeddingBackward) { + backward_pass_ops.insert(op); + } else { + return op->emitOpError() + << "embedding op has unknown " << kEmbeddingPipelining + << " attribute value " << region << "."; + } + op->removeAttr(kEmbeddingPipelining); + } + return WalkResult::advance(); + }); + if (walk_result.wasInterrupted()) return signalPassFailure(); + + // If there are no forward pass ops, there is no SC, so we end early. + if (forward_pass_ops.empty()) { + if (backward_pass_ops.empty()) { + return; + } else { + (*backward_pass_ops.begin())->emitOpError() + << "embedding backwards pass op with no forwards pass ops."; + return signalPassFailure(); + } + } + + // Ensure that all ops are in the same region, and have the same replication + // info. + // TODO(bfontain): Allow for multiple regions/loops in one module. + // TODO(patn): move this pass after cluster formation to remove the complexity + // with replication info and metadata, cluster checking and generalizing to + // multiple TPU clusters. 
+ Region* region = (*forward_pass_ops.begin())->getParentRegion(); + StringAttr replication_attr = GetReplicationAttr(*forward_pass_ops.begin()); + llvm::SmallVector checkset(forward_pass_ops.getArrayRef()); + checkset.append(backward_pass_ops.begin(), backward_pass_ops.end()); + for (Operation* op : checkset) { + if (op->getParentRegion() != region) { + op->emitOpError() << "embedding ops in two different regions"; + return signalPassFailure(); + } + if (GetReplicationAttr(op) != replication_attr) { + op->emitOpError() << "embedding ops with different replication info " + << replication_attr << " vs " << GetReplicationAttr(op); + return signalPassFailure(); + } + } + + // TODO(bfontain): Check that the region here is the region + // of the loop body func. + // Find the FuncOp for the surrounding while loop body. + func::FuncOp loop_body_func = + (*forward_pass_ops.begin())->getParentOfType(); + + // merged_set will keep track of which ops are to be avoided when gather ops + // for inclusion into the four extracted functions. + llvm::SetVector merged_set; + + // Find the TPUReplicationMetadata and TPUCompilationResult ops and delete + // them. These will be cloned/inserted into each region. + TF::TPUReplicateMetadataOp metadata_op; + auto result = FindAndExcludeOp(loop_body_func, replication_attr, merged_set, + metadata_op); + if (failed(result)) return signalPassFailure(); + const int num_replicas = metadata_op.getNumReplicas(); + + TF::TPUCompilationResultOp compilation_op; + result = FindAndExcludeOp( + loop_body_func, replication_attr, merged_set, compilation_op); + if (failed(result)) return signalPassFailure(); + + TF::WhileOp while_op = nullptr; + result = FindOwningWhileOp(loop_body_func, module, &while_op); + if (failed(result)) return signalPassFailure(); + + OpBuilder builder(module); + + result = FindForwardPassOps(builder, forward_pass_ops, backward_pass_ops, + merged_set, loop_body_func, num_replicas); + if (failed(result)) return signalPassFailure(); + merged_set.insert(forward_pass_ops.begin(), forward_pass_ops.end()); + + result = + FindBackwardPassOps(builder, backward_pass_ops, merged_set, num_replicas); + if (failed(result)) return signalPassFailure(); + merged_set.insert(backward_pass_ops.begin(), backward_pass_ops.end()); + + llvm::SetVector core_tpu_ops; + result = FindCoreTPUOps(core_tpu_ops, forward_pass_ops, backward_pass_ops, + merged_set, loop_body_func); + if (failed(result)) return signalPassFailure(); + merged_set.insert(core_tpu_ops.begin(), core_tpu_ops.end()); + + llvm::SetVector non_tpu_ops; + result = FindNonTPUOps(non_tpu_ops, merged_set, loop_body_func); + if (failed(result)) return signalPassFailure(); + merged_set.insert(non_tpu_ops.begin(), non_tpu_ops.end()); + + VLOG(2) << "Forwards pass " << forward_pass_ops.size() + << " ops, backwards pass " << backward_pass_ops.size() + << " ops, core " << core_tpu_ops.size() + << " ops. 
Total = " << merged_set.size() << " of " + << GetNumOps(loop_body_func) << ".\n"; + + builder.setInsertionPointAfter(*non_tpu_ops.begin()); + Operation* non_tpu_caller = nullptr; + result = + ExtractOpsAsFunc(builder, module, non_tpu_ops, replication_attr, nullptr, + nullptr, loop_body_func, "non_tpu", &non_tpu_caller); + if (failed(result)) return signalPassFailure(); + + builder.setInsertionPointAfter(non_tpu_caller); + Operation* forward_caller = nullptr; + result = ExtractOpsAsFunc(builder, module, forward_pass_ops, replication_attr, + metadata_op, compilation_op, loop_body_func, + "sc_forward", &forward_caller); + if (failed(result)) return signalPassFailure(); + + // Create tpu_core function + builder.setInsertionPointAfter(forward_caller); + Operation* core_tpu_caller = nullptr; + result = ExtractOpsAsFunc(builder, module, core_tpu_ops, replication_attr, + metadata_op, compilation_op, loop_body_func, + "core_tpu", &core_tpu_caller); + if (failed(result)) return signalPassFailure(); + + builder.setInsertionPointAfter(core_tpu_caller); + Operation* backwards_pass_caller = nullptr; + result = ExtractOpsAsFunc( + builder, module, backward_pass_ops, replication_attr, metadata_op, + compilation_op, loop_body_func, "sc_backward", &backwards_pass_caller); + if (failed(result)) return signalPassFailure(); + + metadata_op->erase(); + compilation_op->erase(); +} + +} // namespace + +std::unique_ptr> CreateEmbeddingPipeliningPass() { + return std::make_unique(); +} + +} // namespace TFDevice +} // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/extract_outside_compilation.cc b/tensorflow/compiler/mlir/tensorflow/transforms/extract_outside_compilation.cc index 1d7b2c10ba6..9c3e82e88e1 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/extract_outside_compilation.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/extract_outside_compilation.cc @@ -169,13 +169,14 @@ Operation* CreateSendFromHostOp(OpBuilder& builder, Location loc, ValueRange inputs, Value compilation_key, Value device_ordinal, int default_device_ordinal, + StringAttr device_type_attr, llvm::StringRef communication_key) { if (device_ordinal) return ApplyXlaHostTransferAttr( builder.create( loc, inputs, /*dynamic_key=*/compilation_key, device_ordinal, - builder.getStringAttr(communication_key)), + builder.getStringAttr(communication_key), device_type_attr), builder); return ApplyXlaHostTransferAttr( @@ -183,7 +184,8 @@ Operation* CreateSendFromHostOp(OpBuilder& builder, Location loc, loc, inputs, /*dynamic_key=*/compilation_key, builder.getStringAttr(communication_key), - /*device_ordinal=*/builder.getI64IntegerAttr(default_device_ordinal)), + /*device_ordinal=*/builder.getI64IntegerAttr(default_device_ordinal), + device_type_attr), builder); } @@ -192,19 +194,21 @@ Operation* CreateSendFromHostOp(OpBuilder& builder, Location loc, Operation* CreateRecvAtHostOp(OpBuilder& builder, Location loc, TypeRange output_types, Value compilation_key, Value device_ordinal, int default_device_ordinal, + StringAttr device_type_attr, llvm::StringRef communication_key) { if (device_ordinal) return ApplyXlaHostTransferAttr( builder.create( loc, output_types, /*dynamic_key=*/compilation_key, device_ordinal, - builder.getStringAttr(communication_key)), + builder.getStringAttr(communication_key), device_type_attr), builder); return ApplyXlaHostTransferAttr( builder.create( loc, output_types, /*dynamic_key=*/compilation_key, builder.getStringAttr(communication_key), - 
/*device_ordinal=*/builder.getI64IntegerAttr(default_device_ordinal)), + /*device_ordinal=*/builder.getI64IntegerAttr(default_device_ordinal), + device_type_attr), builder); } @@ -332,14 +336,6 @@ bool HasDynamicExternalValues(Operation* op) { .wasInterrupted(); } -// Checks if `type` is allowed for XLA. String and resources are not XLA types. -// There are other TF types that are not XLA types which will be removed by -// successive passes in TF/XLA bridge phase 2. -bool TypeValidForXLA(const Type& type) { - const Type elem = getElementTypeOrSelf(type); - return !elem.isa() && !elem.isa(); -} - // Returns operands of `cluster_ops` that need to be // communicated from device->host. This is for the case when all operands have a // static shape. @@ -354,7 +350,7 @@ llvm::SmallSetVector GetStaticExternalOperands( walked_op)) return WalkResult::advance(); for (Value v : walked_op->getOperands()) { - if (!TypeValidForXLA(v.getType())) continue; + if (!tensorflow::TypeValidForXLA(v.getType())) continue; if (auto* defining_op = v.getDefiningOp()) { if (!op->isAncestor(defining_op) && device_cluster->isAncestor(defining_op) && @@ -385,7 +381,7 @@ llvm::SmallSetVector GetAllExternalOperands( for (Operation* op : cluster_ops) { op->walk([&](Operation* walked_op) { for (Value v : walked_op->getOperands()) { - if (!TypeValidForXLA(v.getType())) continue; + if (!tensorflow::TypeValidForXLA(v.getType())) continue; Operation* defining_op = v.getDefiningOp(); if (!defining_op || !cluster_ops.count(defining_op)) { external_values.insert(v); @@ -431,8 +427,8 @@ void GetExternalOutputs(const llvm::SmallSetVector& cluster_ops, HasDynamicOutputs(user)) { if (!user_set.insert(user).second) continue; for (Value v : user->getOperands()) { - if (TypeValidForXLA(v.getType()) && v.getDefiningOp() == op && - !isa(user)) + if (tensorflow::TypeValidForXLA(v.getType()) && + v.getDefiningOp() == op && !isa(user)) external_outputs.insert(v); if (v.getDefiningOp() == op && isa(user)) tmp_host_outputs.push_back(v); @@ -489,7 +485,7 @@ bool ShouldCloseCluster(llvm::ArrayRef outputs) { return true; } } - if (!TypeValidForXLA(v.getType())) + if (!tensorflow::TypeValidForXLA(v.getType())) for (const Operation* user : v.getUsers()) if (!isa(user)) has_nonxla_output = true; } @@ -570,8 +566,8 @@ void MoveOpsToHost(const llvm::SmallSetVector& clustered_ops, const llvm::SmallSetVector& external_operands, const llvm::SmallSetVector& external_outputs, Operation* insertion_point, Value compilation_key, - Value device_ordinal, int default_device_ordignal, - int& communication_key_index) { + Value device_ordinal, int default_device_ordinal, + StringAttr device_type_attr, int& communication_key_index) { OpBuilder builder(insertion_point); Operation& op = *clustered_ops.back(); std::string args_communication_key = @@ -612,7 +608,7 @@ void MoveOpsToHost(const llvm::SmallSetVector& clustered_ops, Operation* recv_at_host = CreateRecvAtHostOp( builder, op.getLoc(), host_operand_types, compilation_key, device_ordinal, - default_device_ordignal, args_communication_key); + default_device_ordinal, device_type_attr, args_communication_key); Block* original_op_block = op.getBlock(); Operation* after_op = recv_at_host; for (Operation* cluster_op : clustered_ops) { @@ -624,7 +620,8 @@ void MoveOpsToHost(const llvm::SmallSetVector& clustered_ops, if (!external_outputs.empty()) { CreateSendFromHostOp(builder, op.getLoc(), external_outputs.getArrayRef(), compilation_key, device_ordinal, - default_device_ordignal, retvals_communication_key); + 
default_device_ordinal, device_type_attr, + retvals_communication_key); } if (external_operands.empty()) { @@ -656,7 +653,7 @@ void MoveOpsToHost(const llvm::SmallSetVector& clustered_ops, // its value. LogicalResult MoveOpsToHost( tf_device::ClusterOp device_cluster, Block* src, Operation* insertion_point, - Value compilation_key, Value device_ordinal, int default_device_ordignal, + Value compilation_key, Value device_ordinal, int default_device_ordinal, int& communication_key_index, llvm::SmallVector* return_value_from_host = nullptr) { // Contains all of the outside compiled operations that should be moved to the @@ -664,6 +661,8 @@ LogicalResult MoveOpsToHost( // single op except in the case where some of the input/output shapes are // non-static. llvm::SmallSetVector clustered_ops; + auto device_type_attr = + device_cluster->getAttrOfType(TF::kCompileDeviceTypeAttr); for (Operation& op : llvm::make_early_inc_range(*src)) { if (HasOutsideCompilationAncestorExclusive(&op) || @@ -687,7 +686,8 @@ LogicalResult MoveOpsToHost( } MoveOpsToHost(clustered_ops, external_operands, external_outputs, insertion_point, compilation_key, device_ordinal, - default_device_ordignal, communication_key_index); + default_device_ordinal, device_type_attr, + communication_key_index); clustered_ops.clear(); } @@ -710,7 +710,8 @@ LogicalResult MoveOpsToHost( MoveOpsToHost(clustered_ops, external_operands, external_outputs, insertion_point, compilation_key, device_ordinal, - default_device_ordignal, communication_key_index); + default_device_ordinal, device_type_attr, + communication_key_index); clustered_ops.clear(); } } @@ -740,24 +741,22 @@ void GetReturnValueFromDevice( // `communication_key_index` when creating communication ops. LogicalResult DecomposeControlFlow(tf_device::ClusterOp device_cluster, Value compilation_key, Value device_ordinal, - int default_device_ordignal, + int default_device_ordinal, int& communication_key_index) { auto result = device_cluster.GetBody().walk([&](Operation* op) { if (auto if_op = llvm::dyn_cast(op)) { if (!HasOutsideCompilationNested(op)) return WalkResult::advance(); OpBuilder builder(if_op); auto host_if = CloneEmptyIfWithPredicate(if_op, builder); - if (failed(MoveOpsToHost(device_cluster, &if_op.getThenBranch().front(), - host_if.getThenBranch().front().getTerminator(), - compilation_key, device_ordinal, - default_device_ordignal, - communication_key_index))) + if (failed(MoveOpsToHost( + device_cluster, &if_op.getThenBranch().front(), + host_if.getThenBranch().front().getTerminator(), compilation_key, + device_ordinal, default_device_ordinal, communication_key_index))) return WalkResult::interrupt(); - if (failed(MoveOpsToHost(device_cluster, &if_op.getElseBranch().front(), - host_if.getElseBranch().front().getTerminator(), - compilation_key, device_ordinal, - default_device_ordignal, - communication_key_index))) + if (failed(MoveOpsToHost( + device_cluster, &if_op.getElseBranch().front(), + host_if.getElseBranch().front().getTerminator(), compilation_key, + device_ordinal, default_device_ordinal, communication_key_index))) return WalkResult::interrupt(); // Mark op as stateful due to side-effecting communication ops. 
if_op->setAttr("is_stateless", builder.getBoolAttr(false)); @@ -782,21 +781,21 @@ LogicalResult DecomposeControlFlow(tf_device::ClusterOp device_cluster, builder.setInsertionPointToEnd(&cond.front()); auto recv_condition_at_host = CreateRecvAtHostOp( builder, while_op.getLoc(), TypeRange{condition.getType()}, - compilation_key, device_ordinal, default_device_ordignal, + compilation_key, device_ordinal, default_device_ordinal, + device_cluster->getAttrOfType(TF::kCompileDeviceTypeAttr), condition_send_recv_key); builder.create(while_op.getLoc(), recv_condition_at_host->getResults()); if (failed(MoveOpsToHost(device_cluster, &while_op.getCond().front(), recv_condition_at_host, compilation_key, - device_ordinal, default_device_ordignal, + device_ordinal, default_device_ordinal, communication_key_index))) return WalkResult::interrupt(); - if (failed(MoveOpsToHost(device_cluster, &while_op.getBody().front(), - host_while.getBody().front().getTerminator(), - compilation_key, device_ordinal, - default_device_ordignal, - communication_key_index))) + if (failed(MoveOpsToHost( + device_cluster, &while_op.getBody().front(), + host_while.getBody().front().getTerminator(), compilation_key, + device_ordinal, default_device_ordinal, communication_key_index))) return WalkResult::interrupt(); // Mark op as stateful due to side-effecting communication ops. while_op->setAttr("is_stateless", builder.getBoolAttr(false)); @@ -1167,7 +1166,7 @@ LogicalResult CreateParallelExecuteForOutsideCompilation( // have a valid XLA type. LogicalResult CheckClusterResults(tf_device::ClusterOp cluster) { for (OpResult result : cluster.getResults()) { - if (!TypeValidForXLA(result.getType())) { + if (!tensorflow::TypeValidForXLA(result.getType())) { cluster.emitError() << "The ExtractHeadTailOutsideCompilation pass produced a Device " "cluster with a result with a non-XLA type: " diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/extract_tpu_copy_with_dynamic_shape_op.cc b/tensorflow/compiler/mlir/tensorflow/transforms/extract_tpu_copy_with_dynamic_shape_op.cc new file mode 100644 index 00000000000..9284fd2bc0b --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/transforms/extract_tpu_copy_with_dynamic_shape_op.cc @@ -0,0 +1,199 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include +#include + +#include "llvm/Support/Casting.h" +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h" + +#define DEBUG_TYPE "tf-extract-tpu-copy-with-dynamic-shape-op" + +namespace mlir { +namespace TFTPU { + +namespace { + +#define GEN_PASS_DEF_EXTRACTTPUCOPYWITHDYNAMICSHAPEOPPASS +#include "tensorflow/compiler/mlir/tensorflow/transforms/tf_passes.h.inc" + +class ExtractTPUCopyWithDynamicShapeOpPass + : public impl::ExtractTPUCopyWithDynamicShapeOpPassBase< + ExtractTPUCopyWithDynamicShapeOpPass> { + void runOnOperation() override; +}; + +// Finds op that created a given value. If the value is a BlockArgument, this +// returns the owner of the Block. +Operation* GetOpOfValue(Value value) { + if (auto block_arg = value.dyn_cast()) + return block_arg.getOwner()->getParentOp(); + + return value.getDefiningOp(); +} + +// Check if the TPUCopyWithDynamicShapeOp is valid. +// 1. The op should be wrapped inside a launch op. +// 2. The wrapped launch op should be placed on CPU. +bool IsOpValid(Operation* op) { + auto launch_op = llvm::dyn_cast(op->getParentOp()); + if (!launch_op) return false; + std::string device_str = launch_op.getDeviceAttr().getValue().str(); + return device_str == tensorflow::GetDeviceAliasForHostOfLogicalCore(0) || + device_str == "/job:localhost/replica:0/task:0/device:CPU:0"; +} + +// Get the new launch op results. This is the results if the copy op is removed +// from the old launch op. +llvm::SmallVector CreateNewLaunchOpResults( + tf_device::LaunchOp* old_launch_op, + Operation* tpu_copy_with_dynamic_shape_op) { + llvm::SmallSetVector new_launch_op_results; + + new_launch_op_results.insert( + old_launch_op->GetBody().getTerminator()->getOperands().begin(), + old_launch_op->GetBody().getTerminator()->getOperands().end()); + + for (Value operand : tpu_copy_with_dynamic_shape_op->getOperands()) { + if (GetOpOfValue(operand)->getParentRegion() == + tpu_copy_with_dynamic_shape_op->getParentRegion()) { + new_launch_op_results.insert(operand); + } + } + + for (Value result : tpu_copy_with_dynamic_shape_op->getResults()) { + new_launch_op_results.remove(result); + } + + return new_launch_op_results.takeVector(); +} + +// Create a new host launch op which contains all the old launch op body +// except the dynamic shape copy op. +tf_device::LaunchOp CreateNewHostLaunchOpWithNewResult( + tf_device::LaunchOp* old_launch_op, + llvm::SmallVector& new_launch_op_results) { + OpBuilder builder(*old_launch_op); + + builder.setInsertionPointAfter(*old_launch_op); + + llvm::SmallVector new_launch_op_results_types; + for (Value result : new_launch_op_results) + new_launch_op_results_types.push_back(result.getType()); + + auto new_launch_op = builder.create( + old_launch_op->getLoc(), old_launch_op->getDeviceAttr(), + /*result_types=*/new_launch_op_results_types); + + new_launch_op.getBody().takeBody(old_launch_op->getBody()); + new_launch_op.GetBody().getTerminator()->setOperands(new_launch_op_results); + + return new_launch_op; +} + +// Create the new device launch op which wraps the copy op. 
+tf_device::LaunchOp CreateNewDeviceLaunchOp( + Operation* tpu_copy_with_dynamic_shape_op, bool replicated) { + OpBuilder builder(tpu_copy_with_dynamic_shape_op); + + builder.setInsertionPointAfter(tpu_copy_with_dynamic_shape_op); + + std::string device_str; + if (replicated) { + device_str = tensorflow::GetDeviceAliasForLogicalCore(0); + } else { + device_str = "/job:localhost/replica:0/task:0/device:TPU:0"; + } + + auto new_device_launch_op = builder.create( + tpu_copy_with_dynamic_shape_op->getLoc(), + builder.getStringAttr(device_str), + /*result_types=*/tpu_copy_with_dynamic_shape_op->getResultTypes()); + + new_device_launch_op.getBody().push_back(new Block); + builder.setInsertionPointToEnd(&new_device_launch_op.GetBody()); + auto* return_op = builder + .create( + tpu_copy_with_dynamic_shape_op->getLoc(), + tpu_copy_with_dynamic_shape_op->getResults()) + .getOperation(); + tpu_copy_with_dynamic_shape_op->moveBefore(return_op); + return new_device_launch_op; +} + +// Update all the usage of tf_device.return op with launch op result. +void UpdateReturnOpResultWithLaunchOpResult(tf_device::LaunchOp* launch_op) { + auto operand_not_in_launch = [&](OpOperand& operand) { + return !launch_op->getOperation()->isProperAncestor(operand.getOwner()); + }; + + for (auto result : + llvm::zip(launch_op->getResults(), + launch_op->GetBody().getTerminator()->getOperands())) + std::get<1>(result).replaceUsesWithIf(std::get<0>(result), + operand_not_in_launch); +} + +void ExtractTPUCopyWithDynamicShapeOpPass::runOnOperation() { + llvm::SmallVector tpu_copy_with_dynamic_shape_ops; + getOperation().walk([&](Operation* op) { + if (isa(op)) { + if (!IsOpValid(op)) return signalPassFailure(); + tpu_copy_with_dynamic_shape_ops.push_back(op); + } + }); + + for (Operation* op : tpu_copy_with_dynamic_shape_ops) { + OpBuilder builder(op); + + auto old_launch_op = llvm::dyn_cast(op->getParentOp()); + + bool replicated = old_launch_op.getDeviceAttr().getValue().str() == + tensorflow::GetDeviceAliasForHostOfLogicalCore(0); + + for (auto result : + llvm::zip(old_launch_op->getResults(), + old_launch_op.GetBody().getTerminator()->getOperands())) + std::get<0>(result).replaceAllUsesWith(std::get<1>(result)); + + llvm::SmallVector new_launch_op_results = + CreateNewLaunchOpResults(&old_launch_op, op); + + op->moveAfter(old_launch_op); + + auto new_host_launch_op = CreateNewHostLaunchOpWithNewResult( + &old_launch_op, new_launch_op_results); + UpdateReturnOpResultWithLaunchOpResult(&new_host_launch_op); + + old_launch_op->erase(); + + auto new_device_launch_op = CreateNewDeviceLaunchOp(op, replicated); + UpdateReturnOpResultWithLaunchOpResult(&new_device_launch_op); + } +} + +} // namespace + +std::unique_ptr> +CreateExtractTPUCopyWithDynamicShapeOpPass() { + return std::make_unique(); +} +} // namespace TFTPU +} // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/graph_optimization_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/graph_optimization_pass.cc index b5af8f60bd4..dff2223b115 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/graph_optimization_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/graph_optimization_pass.cc @@ -15,6 +15,8 @@ limitations under the License. 
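The new extract_tpu_copy_with_dynamic_shape_op.cc pass above splits one host launch into a host launch plus a device launch around the copy op. The core bookkeeping is `CreateNewLaunchOpResults`: start from the old terminator operands, add copy-op operands that are defined inside the same region, and drop the copy op's own results. A minimal standalone model of that set logic, with integers standing in for MLIR values (names mirror the helper but are illustrative only):

```c++
#include <algorithm>
#include <iostream>
#include <vector>

// Values are modeled as integers; "defined in the launch" is a membership test.
std::vector<int> CreateNewLaunchResults(const std::vector<int>& old_results,
                                        const std::vector<int>& copy_operands,
                                        const std::vector<int>& copy_results,
                                        const std::vector<int>& defined_in_launch) {
  std::vector<int> results = old_results;
  auto contains = [](const std::vector<int>& v, int x) {
    return std::find(v.begin(), v.end(), x) != v.end();
  };
  // Operands of the copy op produced inside the launch must now be returned by
  // the host launch so the new device launch can consume them.
  for (int operand : copy_operands) {
    if (contains(defined_in_launch, operand) && !contains(results, operand))
      results.push_back(operand);
  }
  // Results produced by the copy op are no longer produced by the host launch;
  // the device launch that wraps the copy op returns them instead.
  results.erase(std::remove_if(results.begin(), results.end(),
                               [&](int r) { return contains(copy_results, r); }),
                results.end());
  return results;
}

int main() {
  // Old host launch returned {10, 11}; the copy op consumed 5 (defined in the
  // launch) and produced 11.
  for (int r : CreateNewLaunchResults({10, 11}, {5}, {11}, {5, 10}))
    std::cout << r << " ";  // prints "10 5"
  std::cout << "\n";
}
```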
#include "tensorflow/compiler/mlir/tensorflow/transforms/graph_optimization_pass.h" +#include + #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project @@ -33,7 +35,8 @@ using Graph = ::tensorflow::Graph; } // namespace Status MlirGraphOptimizationPass::Run( - const ConfigProto& config_proto, ModuleOp module, const Graph& graph, + const std::string& function_name, const ConfigProto& config_proto, + ModuleOp module, const Graph& graph, const tensorflow::FunctionLibraryDefinition& function_library) { if (GetPassState(/*device_set=*/nullptr, config_proto, graph, function_library) == diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/graph_optimization_pass.h b/tensorflow/compiler/mlir/tensorflow/transforms/graph_optimization_pass.h index 4da3e14721b..4390e59ca80 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/graph_optimization_pass.h +++ b/tensorflow/compiler/mlir/tensorflow/transforms/graph_optimization_pass.h @@ -16,6 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_GRAPH_OPTIMIZATION_PASS_H_ #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_GRAPH_OPTIMIZATION_PASS_H_ +#include + #include "tensorflow/compiler/mlir/mlir_graph_optimization_pass.h" namespace mlir { @@ -39,6 +41,7 @@ class MlirGraphOptimizationPass : public ::tensorflow::MlirOptimizationPass { } ::tensorflow::Status Run( + const std::string& function_name, const ::tensorflow::ConfigProto& config_proto, ModuleOp module, const ::tensorflow::Graph& graph, const tensorflow::FunctionLibraryDefinition& function_library) override; diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/group_by_dialect.cc b/tensorflow/compiler/mlir/tensorflow/transforms/group_by_dialect.cc index 805fb19742a..2edd6d76f03 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/group_by_dialect.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/group_by_dialect.cc @@ -164,7 +164,7 @@ void wrapOpsInFunction(std::vector& ops, int function_id, auto call = builder.create( ops[0]->getLoc(), func.getFunctionType().getResults(), func.getSymName(), inputs); - for (auto& v : llvm::enumerate(outputs)) { + for (const auto& v : llvm::enumerate(outputs)) { v.value().replaceUsesWithIf(call.getResult(v.index()), [=](OpOperand& o) { // Outside of what we're moving, results of our operations need to // be replaced by results from the function call. 
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/initialize_variables_in_session_init.cc b/tensorflow/compiler/mlir/tensorflow/transforms/initialize_variables_in_session_init.cc index 44e178ac76c..3b974c39570 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/initialize_variables_in_session_init.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/initialize_variables_in_session_init.cc @@ -114,8 +114,8 @@ LogicalResult InitializeVariablesInSessionInitializer( const tensorflow::DeviceMgr* mgr = nullptr; auto status = session->LocalDeviceManager(&mgr); if (!status.ok()) { - module->emitError("failed to fetch device manager: " + - status.error_message()); + module->emitError( + absl::StrCat("failed to fetch device manager: ", status.message())); return failure(); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/lift_variables.cc b/tensorflow/compiler/mlir/tensorflow/transforms/lift_variables.cc index aa2605cbb33..bc0534fdb0b 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/lift_variables.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/lift_variables.cc @@ -82,7 +82,7 @@ LogicalResult LiftVariablesFromSession( /*target_tensor_names=*/{}, &resource_tensors); if (!status.ok()) { return module.emitOpError() - << "failed to run the provided session: " << status.error_message(); + << "failed to run the provided session: " << status.message(); } const DeviceMgr* device_manager; diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/mark_initialized_variables.cc b/tensorflow/compiler/mlir/tensorflow/transforms/mark_initialized_variables.cc index 25c9e4c0749..54fef16b043 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/mark_initialized_variables.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/mark_initialized_variables.cc @@ -56,8 +56,8 @@ LogicalResult MarkInitializedVariablesInFunction(func::FuncOp function, const tensorflow::DeviceMgr* mgr = nullptr; auto status = session->LocalDeviceManager(&mgr); if (!status.ok()) - return function->emitError("failed to fetch device manager: " + - status.error_message()); + return function->emitError( + absl::StrCat("failed to fetch device manager: ", status.message())); // Fetch all varHandleOp in the function. 
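Several hunks above replace `status.error_message()` with `status.message()` and build the diagnostic with `absl::StrCat`. The newer accessor returns a string view rather than an owned `std::string`, so plain `operator+` with a literal no longer applies and concatenation has to go through something view-aware. A std-only sketch with a stand-in Status type (the real classes live in tsl/absl):

```c++
#include <iostream>
#include <string>
#include <string_view>

// Stand-in for a Status whose message() returns a view, as in the newer API.
class Status {
 public:
  explicit Status(std::string msg) : msg_(std::move(msg)) {}
  std::string_view message() const { return msg_; }

 private:
  std::string msg_;
};

// View-aware concatenation, playing the role absl::StrCat plays in the diff.
std::string StrCat(std::string_view a, std::string_view b) {
  std::string out;
  out.reserve(a.size() + b.size());
  out.append(a);
  out.append(b);
  return out;
}

int main() {
  Status status("no devices registered");
  // "failed to fetch device manager: " + status.message() would not compile:
  // there is no operator+ for (const char*, std::string_view).
  std::cout << StrCat("failed to fetch device manager: ", status.message())
            << "\n";
}
```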
llvm::SmallVector var_ops; diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/merge_control_flow.cc b/tensorflow/compiler/mlir/tensorflow/transforms/merge_control_flow.cc index af856295f24..709e4532c12 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/merge_control_flow.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/merge_control_flow.cc @@ -367,7 +367,8 @@ llvm::SmallVector GetReturnIndicesToKeep( } return false; }; - for (auto& index_and_value : llvm::enumerate(current_if_op.getResults())) { + for (const auto& index_and_value : + llvm::enumerate(current_if_op.getResults())) { if (!llvm::all_of(index_and_value.value().getUsers(), is_op_inside_IfRegions)) { return_indices_to_keep.push_back(index_and_value.index()); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/mlprogram.cc b/tensorflow/compiler/mlir/tensorflow/transforms/mlprogram.cc index 8594e5ad65a..32ff28ea968 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/mlprogram.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/mlprogram.cc @@ -64,7 +64,7 @@ void PopulateLowerToMlProgramAndHloPipeline(mlir::OpPassManager& pm) { pm.addPass(mlir::TF::CreateTFShapeInferencePass()); llvm::StringRef tf2xla_fallback_device_type = "XLA_CPU_JIT"; - pm.addNestedPass(mlir::mhlo::createLegalizeTFPass( + pm.addPass(mlir::mhlo::createLegalizeTFPass( /*allow_partial_conversion=*/true, /*legalize_chlo=*/true, tf2xla_fallback_device_type, /*prefer_tf2xla=*/false)); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/optimize.td b/tensorflow/compiler/mlir/tensorflow/transforms/optimize.td index 034fbe1d840..be01d276902 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/optimize.td +++ b/tensorflow/compiler/mlir/tensorflow/transforms/optimize.td @@ -43,7 +43,7 @@ def CanFuseMulAndConv2D : Constraint>; def F32ElementsAttr : ElementsAttrBase< - CPred<"$_self.cast().getType().getElementType().isF32()">, "float constant tensor">; + CPred<"$_self.cast().getShapedType().getElementType().isF32()">, "float constant tensor">; def DefinedByConv2D : Constraint($0.getDefiningOp())">>; // Checks if the value has only one user. def HasOneUse : Constraint>; diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h index b17084201a4..4092b90411e 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h +++ b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h @@ -377,10 +377,12 @@ std::unique_ptr> CreateClusterConstantSinkingPass( llvm::function_ref filter = {}); // Creates a pass that outlines regions of tf_device.cluster operations. -std::unique_ptr> CreateClusterOutliningPass(); +std::unique_ptr> CreateClusterOutliningPass( + bool globally_unique_func_names = true); // Creates a pass that outlines regions of tf_device.launch operations. -std::unique_ptr> CreateLaunchOutliningPass(); +std::unique_ptr> CreateLaunchOutliningPass( + bool globally_unique_func_names = true); // Creates a pass that converts tf_device::LaunchFuncOp into // TF::PartitionedCallOp. @@ -432,6 +434,10 @@ std::unique_ptr> CreateReplicateToIslandPass( std::unique_ptr> CreateReplicaIDToDeviceOrdinalPass(); +// Creates a pass that adds pipelining to a graph that contains device +// accelerated embeddings. +std::unique_ptr> CreateEmbeddingPipeliningPass(); + // Creates a pass that creates `tf_executor.island` from a single // `tf_device.parallel_execute` island. 
std::unique_ptr> CreateParallelExecuteToIslandsPass( @@ -529,9 +535,16 @@ CreateTPUReorderReplicateAndPartitionedInputsPass(); std::unique_ptr> CreateTPUResourceReadsWritesPartitioningPass(); +// Creates a pass that looks for usage of the result of +// TPUCopyWithDynamicShapeOp and annotate these values to be dynamic shape. This +// ensures that the generated tpu program has the correct inputs annotation. +std::unique_ptr> +CreateTPUAnnotateDynamicShapeInputsPass(); + // Creates a pass that rewrites `tf_device.launch_func` on TPUs into TPU runtime // ops. -std::unique_ptr> CreateTPURewritePass(); +std::unique_ptr> CreateTPURewritePass( + llvm::StringRef module_name = llvm::StringRef()); // Creates a pass that identifies XLASharding ops in launch op for TPU // computation. @@ -549,6 +562,12 @@ CreateTPUParallelExecuteSinkResourceWritePass(); std::unique_ptr> CreateTPUMergeVariablesWithExecutePass(); +// Create a pass that extract TPUCopyWithDynamicShapeOp from the host launch op +// and wrap them in device launch op. This allows this op executed on TPU while +// still compiled on host. +std::unique_ptr> +CreateExtractTPUCopyWithDynamicShapeOpPass(); + // Creates a pass that wraps ReadVariableOp/AssignVariable op that consumes a // packed tensor to have same device placement as underlying TPU device. std::unique_ptr> @@ -578,9 +597,13 @@ CreateTPUUpdateEmbeddingEnqueueOpInputsPass(); // Creates a pass that propagates TPU devices to users. std::unique_ptr> CreateTPUDevicePropagationPass(); +// Create a pass that colocates each `Split` with its predecessor. +std::unique_ptr> CreateTPUColocateSplitsPass(); + // Populates the supplied passmanager with the passes required to run the // bridge. -void CreateTPUBridgePipeline(OpPassManager& pm); +void CreateTPUBridgePipeline(OpPassManager& pm, + llvm::StringRef module_name = llvm::StringRef()); // Populates the supplied passmanager with the passes required to run the // bridge in V1 mode. @@ -681,6 +704,7 @@ enum MoveTransposeDirection { kBegin, kEnd }; #define GEN_PASS_DECL_TPUHOSTCOMPUTATIONEXPANSIONPASS #define GEN_PASS_DECL_TPUIDENTITYPRUNINGPASS #define GEN_PASS_DECL_TPUMERGEVARIABLESWITHEXECUTEPASS +#define GEN_PASS_DECL_EXTRACTTPUCOPYWITHDYNAMICSHAPEOPPASS #define GEN_PASS_DECL_TPUPARALLELEXECUTESINKRESOURCEWRITEPASS #define GEN_PASS_DECL_TPUREORDERREPLICATEANDPARTITIONEDINPUTSPASS #define GEN_PASS_DECL_TPURESOURCEREADFORWRITEPASS diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/prepare_tpu_computation_for_tf_export.cc b/tensorflow/compiler/mlir/tensorflow/transforms/prepare_tpu_computation_for_tf_export.cc index 457dca838af..81affc412bc 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/prepare_tpu_computation_for_tf_export.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/prepare_tpu_computation_for_tf_export.cc @@ -67,6 +67,12 @@ class RewriteXlaHostComputeMlir LogicalResult matchAndRewrite(TF::_XlaHostComputeMlirOp op, PatternRewriter& rewriter) const override { + if (op.getManualSharding()) { + op.emitOpError() << "manual_sharding not supported with fallback of " + "phase 2 legalize TF/XLA bridge. 
manual_sharding is " + "used by map_outside_compilation"; + return failure(); + } llvm::SmallVector shape_attrs; shape_attrs.reserve(op.getNumResults()); for (Type ty : op.getResultTypes()) { @@ -99,7 +105,8 @@ class RewriteXlaHostComputeMlir auto recv_at_host = rewriter.create( func.getLoc(), op.getOperandTypes(), /*dynamic_key=*/dynamic_key, op.getSendKeyAttr(), - /*device_ordinal=*/rewriter.getI64IntegerAttr(0)); + /*device_ordinal=*/rewriter.getI64IntegerAttr(0), + rewriter.getStringAttr("TPU")); for (auto result : llvm::zip(cloned_func.getArguments(), recv_at_host->getResults())) { std::get<0>(result).replaceAllUsesWith(std::get<1>(result)); @@ -110,7 +117,8 @@ class RewriteXlaHostComputeMlir func.getLoc(), cloned_func.getBody().front().getTerminator()->getOperands(), /*dynamic_key=*/dynamic_key, op.getRecvKeyAttr(), - /*device_ordinal=*/rewriter.getI64IntegerAttr(0)); + /*device_ordinal=*/rewriter.getI64IntegerAttr(0), + rewriter.getStringAttr("TPU")); } constexpr int64_t kDefaultCostEstimate = 1000000; diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/replicate_invariant_op_hoisting.cc b/tensorflow/compiler/mlir/tensorflow/transforms/replicate_invariant_op_hoisting.cc index 1831ecc68db..49b913963af 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/replicate_invariant_op_hoisting.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/replicate_invariant_op_hoisting.cc @@ -17,6 +17,7 @@ limitations under the License. // result(s) regardless of replication, out of their respective replicate. #include +#include #include "llvm/ADT/STLExtras.h" #include "llvm/Support/Casting.h" @@ -45,8 +46,46 @@ struct ReplicateInvariantOpHoistingPass void runOnOperation() override; }; +// Check if op directly uses a key in `virtual_devices`. +bool DirectUseOfVirtualDevice(const DictionaryAttr& virtual_devices, + Operation* op) { + StringAttr op_device = op->getAttrOfType(kDeviceAttr); + if (!op_device) return false; + if (virtual_devices.get(op_device.getValue())) return true; + return false; +} + +// Check if op or its ancestor uses a key in `virtual_devices`. +bool AncestorUsesVirtualDevice( + const std::optional& virtual_devices, Operation* op) { + if (!virtual_devices.has_value()) return false; + if (!op) return false; + if (llvm::isa(op)) return false; + if (DirectUseOfVirtualDevice(*virtual_devices, op)) return true; + return AncestorUsesVirtualDevice(virtual_devices, op->getParentOp()); +} + +// Check if op or its descendant uses a key in `virtual_devices`. +bool DescendantUsesVirtualDevice( + const std::optional& virtual_devices, + Operation* operation) { + if (!virtual_devices.has_value()) return false; + + auto result = operation->walk([&](Operation* op) { + if (DirectUseOfVirtualDevice(*virtual_devices, op)) + return WalkResult::interrupt(); + return WalkResult::advance(); + }); + return result.wasInterrupted(); +} + +// Make invariant the `ShapeOp`s or a `ReadVariableOp` that's the `ShapeOp`'s +// predecessor. void MakeShapeOpInvariant(tf_device::ReplicateOp replicate_op, int num_replicas, Block* replicate_block, TF::ShapeOp shape_op) { + // Ignore ShapeOps that have virtual devices. + if (AncestorUsesVirtualDevice(replicate_op.getDevices(), shape_op)) return; + Value input = shape_op.getInput(); // If ShapeOp operand is replicate tensor block argument, replace with the // associated first replica operand. @@ -85,22 +124,6 @@ void MakeShapeOpInvariant(tf_device::ReplicateOp replicate_op, int num_replicas, } } -// Check if op uses a device from a list of virtual devices. 
-bool UsesVirtualDevice(const std::optional& virtual_devices, - Operation* operation) { - if (!virtual_devices.has_value()) return false; - - auto result = operation->walk([&](Operation* op) { - StringAttr op_device = op->getAttrOfType(kDeviceAttr); - if (!op_device) return WalkResult::advance(); - - if (virtual_devices.value().get(op_device.getValue())) - return WalkResult::interrupt(); - return WalkResult::advance(); - }); - return result.wasInterrupted(); -} - // Checks if op and inner op operands are all replicate invariant. bool IsOpReplicateInvariant(Region* replicate_region, Operation* op) { auto ancestor_of_replicate = [&](Region* region) { @@ -110,6 +133,9 @@ bool IsOpReplicateInvariant(Region* replicate_region, Operation* op) { for (Value operand : op->getOperands()) if (!ancestor_of_replicate(operand.getParentRegion())) return false; + // _TPUDeviceOrdinalPlaceholder implicitly depends on the replica. + if (llvm::isa(op)) return false; + bool has_replicate_operands = false; visitUsedValuesDefinedAbove(op->getRegions(), [&](OpOperand* operand) { if (!ancestor_of_replicate(operand->get().getParentRegion())) @@ -127,6 +153,10 @@ void HoistReplicateInvariantOps(tf_device::ReplicateOp replicate_op) { const int num_replicas = replicate_op.getN(); Block* replicate_block = &replicate_op.GetBody(); + // A `ShapeOp` that directly depends on a `tf_device.replicate` param and does + // not have a virtual device is assumed to return the same shape across all + // replicas. Thus it is invariant across replicas. + // TODO(b/277936694): Remove this assumption and special case. replicate_op.walk([&](TF::ShapeOp shape_op) { MakeShapeOpInvariant(replicate_op, num_replicas, replicate_block, shape_op); }); @@ -138,7 +168,7 @@ void HoistReplicateInvariantOps(tf_device::ReplicateOp replicate_op) { if (llvm::isa(inner_op)) continue; // Skip hoisting if the inner op device attribute is a virtual device // defined by tf_device.replicate. - if (UsesVirtualDevice(virtual_device_list, &inner_op)) continue; + if (DescendantUsesVirtualDevice(virtual_device_list, &inner_op)) continue; if (IsOpReplicateInvariant(replicate_region, &inner_op)) inner_op.moveBefore(replicate_op); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc b/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc index aa4941ec5b6..a8bfb700209 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc @@ -464,7 +464,8 @@ void RegionResourceHoister::ReplaceOpWithNewOp() { // Clone this old operation but with new result types. Operation* new_op = Operation::create( op_->getLoc(), op_->getName(), new_result_types, op_->getOperands(), - op_->getAttrs(), op_->getSuccessors(), op_->getNumRegions()); + op_->getAttrs(), op_->getPropertiesStorage(), op_->getSuccessors(), + op_->getNumRegions()); builder.insert(new_op); // Move regions to the new op. 
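The replicate_invariant_op_hoisting.cc hunk above splits the old single walk into `AncestorUsesVirtualDevice` (climb enclosing ops) and `DescendantUsesVirtualDevice` (walk nested ops): a `tf.Shape` is only rewritten when neither it nor a wrapping launch sits on a virtual device, while hoisting still checks everything nested inside a candidate op. A standalone sketch on a toy op tree; the tree structure is invented for illustration, and the real ancestor walk also stops at the enclosing replicate op, which is omitted here.

```c++
#include <iostream>
#include <set>
#include <string>
#include <vector>

struct Op {
  std::string device;         // "" when the op has no device attribute
  Op* parent = nullptr;       // nullptr at the region root
  std::vector<Op*> children;  // nested ops
};

bool DirectUseOfVirtualDevice(const std::set<std::string>& virtual_devices,
                              const Op& op) {
  return !op.device.empty() && virtual_devices.count(op.device) > 0;
}

// Walks *up*: does this op or any enclosing op use a virtual device?
bool AncestorUsesVirtualDevice(const std::set<std::string>& virtual_devices,
                               const Op* op) {
  for (; op != nullptr; op = op->parent)
    if (DirectUseOfVirtualDevice(virtual_devices, *op)) return true;
  return false;
}

// Walks *down*: does this op or anything nested inside it use one?
bool DescendantUsesVirtualDevice(const std::set<std::string>& virtual_devices,
                                 const Op& op) {
  if (DirectUseOfVirtualDevice(virtual_devices, op)) return true;
  for (const Op* child : op.children)
    if (DescendantUsesVirtualDevice(virtual_devices, *child)) return true;
  return false;
}

int main() {
  std::set<std::string> virtual_devices{"TPU_REPLICATED_CORE_0"};
  Op launch{"TPU_REPLICATED_CORE_0"};
  Op shape{"", &launch};
  launch.children.push_back(&shape);

  std::cout << AncestorUsesVirtualDevice(virtual_devices, &shape) << "\n";   // 1
  std::cout << DescendantUsesVirtualDevice(virtual_devices, shape) << "\n";  // 0
}
```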
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting_cleanup.cc b/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting_cleanup.cc index 08e7b308b10..99693a91b2f 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting_cleanup.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting_cleanup.cc @@ -97,7 +97,8 @@ void EliminateUnusedResults( OpBuilder builder(op); Operation *new_op = Operation::create( op->getLoc(), op->getName(), new_result_types, op->getOperands(), - op->getAttrs(), op->getSuccessors(), op->getNumRegions()); + op->getAttrs(), op->getPropertiesStorage(), op->getSuccessors(), + op->getNumRegions()); builder.insert(new_op); // Move region bodies to the new operation. diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc index 3b176cb0ba3..1edc7f4bb73 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc @@ -19,9 +19,13 @@ limitations under the License. #include #include #include +#include #include #include #include +#include +#include +#include #include "absl/container/flat_hash_set.h" #include "llvm/ADT/ArrayRef.h" @@ -45,6 +49,7 @@ limitations under the License. #include "mlir/IR/Diagnostics.h" // from @llvm-project #include "mlir/IR/FunctionInterfaces.h" // from @llvm-project #include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/Matchers.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/OperationSupport.h" // from @llvm-project @@ -69,6 +74,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.h" #include "tensorflow/compiler/mlir/tensorflow/utils/shape_inference_utils.h" #include "tensorflow/compiler/mlir/tensorflow/utils/translate_utils.h" +#include "tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.h" #include "tensorflow/compiler/xla/service/shape_inference.h" #include "tensorflow/compiler/xla/shape.h" #include "tensorflow/compiler/xla/translate/hlo_to_mhlo/hlo_utils.h" @@ -427,6 +433,35 @@ Type GetType(Attribute shape_attr, Attribute type_attr) { return UnrankedTensorType::get(type.getValue()); } +// Returns a new arg type based on the shape and element type. If there are +// dynamic bounds attribute to the arg, update the bounds based on the shape +// as well. +Type GetNewArgType(Type old_arg_type, ArrayRef shape, + Type element_type, mlir::MLIRContext* context) { + Type new_arg_type = tensorflow::GetTypeFromTFTensorShape(shape, element_type); + + if (auto input_ty = old_arg_type.dyn_cast()) { + ArrayRef bounds = hlo::encodingToBounds(input_ty.getEncoding()); + // The input type has bounded dynamic dimension. + if (!bounds.empty()) { + SmallVector new_bounds(bounds.begin(), bounds.end()); + SmallVector new_shape(shape.begin(), shape.end()); + // If dimension of the input type is dynamic. Update the + // bounds of the dim with the new type if needed. 
+ for (int i = 0; i < input_ty.getShape().size(); i++) { + if (hlo::isDynamicDimSize(input_ty.getShape()[i])) { + new_bounds[i] = new_shape[i]; + new_shape[i] = ShapedType::kDynamic; + } + } + new_arg_type = tensorflow::GetTypeFromTFTensorShape( + new_shape, element_type, + mhlo::TypeExtensionsAttr::get(context, new_bounds)); + } + } + return new_arg_type; +} + } // namespace // Returns whether type can be further refined. @@ -883,6 +918,10 @@ class ShapeInference { // yields. bool InferShapeForCaseRegion(CaseRegionOp op); + // Infers the shape CaseRegion outputs based on the embedded StableHLO module. + // Returns true if a return type was changed. + bool InferShapeForXlaCallModule(XlaCallModuleOp op); + // Infers the shape of _XlaHostComputeMlir based on the host computation // module. Returns true if a return type was changed. bool InferShapeForXlaHostComputeMlir(_XlaHostComputeMlirOp op); @@ -955,6 +994,14 @@ class ShapeInference { // TODO(b/154065712): Remove propagate_caller_callee_constants once using // SCCP pass instead. bool propagate_caller_callee_constants_; + + // XlaCallModule loader, which is used to deserialize the StableHLO module in + // each `XlaCallModule` op. Uses its own MLIRContext since the loader needs to + // load additional dialects, which is not allowed for the main context since + // shape inference may be called from a pass. + MLIRContext xla_call_module_context_; + DenseMap> + xla_call_module_loaders_; }; ShapeInference::ShapeInference(int64_t graph_version, ModuleOp module, @@ -1141,6 +1188,74 @@ bool ShapeInference::InferShapeForCaseRegion(CaseRegionOp op) { return changed; } +bool ShapeInference::InferShapeForXlaCallModule(XlaCallModuleOp op) { + tensorflow::XlaCallModuleLoader* loader; + { + const auto [it, inserted] = xla_call_module_loaders_.insert({op, nullptr}); + + // Lazily parse XlaCallModule's embedded HLO module and cache the loader to + // avoid repeatedly parsing the module. + if (inserted) { + std::vector dim_args_spec; + for (auto attr : op.getDimArgsSpec().getAsRange()) { + dim_args_spec.push_back(attr.getValue().str()); + } + + // Always use the first platform. The assumption is that shape inference + // results should be the same regardless of which platform is chosen. + int platform_index = op.getPlatforms().size() > 1 ? 0 : -1; + + auto l = tensorflow::XlaCallModuleLoader::Create( + &xla_call_module_context_, op.getVersion(), op.getModule().str(), + std::move(dim_args_spec), platform_index); + if (!l.ok()) { + LLVM_DEBUG(llvm::dbgs() << "Parsing error in XlaCallModule: " + << l.status().ToString() << "\n"); + return false; + } + it->second = *std::move(l); + } + + loader = it->second.get(); + } + + // Cannot pass `op.getArgs().getTypes()` to `loader->RefineDynamicShapes` + // because `op` and `loader` are using different MLIR contexts. See comments + // on `xla_call_module_context_` for details. 
+ std::vector input_shapes; + input_shapes.reserve(op.getArgs().size()); + for (mlir::Type type : op.getArgs().getTypes()) { + input_shapes.push_back(xla::TypeToShape(type)); + } + + tsl::Status status = loader->RefineDynamicShapes(input_shapes); + if (!status.ok()) { + LLVM_DEBUG(llvm::dbgs() << "Failed during XlaCallModule shape refinement: " + << status.ToString()); + return false; + } + + bool changed = false; + for (auto [result, type] : + llvm::zip(op.getResults(), loader->output_types())) { + auto ranked = type.dyn_cast(); + if (ranked == nullptr) { + LLVM_DEBUG(llvm::dbgs() + << "Unsupported XlaCallModule result type: " << type); + continue; + } + + // Build a new type object from `type` and `elem_type`. `type` is owned by + // `xla_call_module_context_` and should not be mixed with op's context. + auto new_type = RankedTensorType::get( + ranked.getShape(), getElementTypeOrSelf(result.getType())); + + changed = RefineResultType(op, result, new_type) || changed; + } + + return changed; +} + bool ShapeInference::InferShapeForXlaHostComputeMlir( _XlaHostComputeMlirOp host_compute_op) { // Extract the module and function. @@ -1741,14 +1856,14 @@ bool ShapeInference::InferShapeForXlaGatherOp(XlaGatherOp op) { auto output_shape = xla::ShapeInference::InferGatherShape( input_shape, start_indices_shape, gather_dim_numbers, slice_sizes); if (!output_shape.ok()) { - op->emitError(output_shape.status().error_message()); + op->emitError() << output_shape.status().message(); return false; } auto refined_type = xla::ConvertShapeToType( *output_shape, mlir::Builder(op)); if (!refined_type.ok()) { - op->emitError(refined_type.status().error_message()); + op->emitError() << refined_type.status().message(); return false; } @@ -2030,7 +2145,8 @@ bool ShapeInference::RefineWithInferTypeOpInterface( SmallVector inferred; LogicalResult res = infer_ti.inferReturnTypes( op->getContext(), op->getLoc(), op->getOperands(), - op->getAttrDictionary(), op->getRegions(), inferred); + op->getAttrDictionary(), op->getPropertiesStorage(), op->getRegions(), + inferred); if (failed(res)) { op->emitOpError("failed to refine type as inference failed"); return false; @@ -2319,6 +2435,10 @@ bool ShapeInference::InferShapeForSingleOperation(Operation* op, while_region, while_region.getBody().front().getTerminator()->getOperandTypes()); + if (auto xla_call_module = dyn_cast(op)) { + return InferShapeForXlaCallModule(xla_call_module); + } + if (auto host_compute_op = dyn_cast<_XlaHostComputeMlirOp>(op)) { return InferShapeForXlaHostComputeMlir(host_compute_op); } @@ -2952,8 +3072,9 @@ FailureOr InferShapeForFunction(func::FuncOp func, element_type = unranked_input_ty.getElementType(); } - auto new_arg_type = - tensorflow::GetTypeFromTFTensorShape(shape, element_type); + auto new_arg_type = GetNewArgType(func_type.getInput(i), shape, + element_type, func.getContext()); + if (new_arg_type != func_type.getInput(i)) { // If the new type is more detailed, trigger shape inference. 
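`InferShapeForXlaCallModule` above caches one deserialized loader per `XlaCallModule` op with the insert-then-fill idiom, so each embedded module is parsed at most once even when shape inference revisits the op; the real code additionally builds the loader in a separate MLIRContext, a detail elided below. A std-only sketch of the caching pattern, where `Loader` is a placeholder for `XlaCallModuleLoader`:

```c++
#include <iostream>
#include <map>
#include <memory>
#include <string>

// Placeholder for the expensive-to-build per-op state.
struct Loader {
  explicit Loader(std::string module) : module(std::move(module)) {
    std::cout << "parsing " << this->module << "\n";  // happens once per key
  }
  std::string module;
};

using OpHandle = int;  // stands in for the XlaCallModuleOp key

Loader* GetOrCreateLoader(std::map<OpHandle, std::unique_ptr<Loader>>& cache,
                          OpHandle op, const std::string& serialized_module) {
  // Insert a null slot first; only build the loader if the key was new.
  auto [it, inserted] = cache.insert({op, nullptr});
  if (inserted) it->second = std::make_unique<Loader>(serialized_module);
  return it->second.get();
}

int main() {
  std::map<OpHandle, std::unique_ptr<Loader>> cache;
  GetOrCreateLoader(cache, /*op=*/7, "stablehlo-module-A");
  GetOrCreateLoader(cache, /*op=*/7, "stablehlo-module-A");  // cache hit, no reparse
  std::cout << "cache size: " << cache.size() << "\n";       // 1
}
```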
func.getArgument(i).setType(new_arg_type); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tf_device_passes.td b/tensorflow/compiler/mlir/tensorflow/transforms/tf_device_passes.td index f46fb14d2ca..e307266c93f 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tf_device_passes.td +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tf_device_passes.td @@ -170,7 +170,7 @@ def HostLaunchToOutsideCompiledPass : Pass<"tf-device-host-launch-to-outside-com "tf_device.launch"() { "tf.B"() tf_device.return - } {device = "TPU_REPLICATED_HOST"} : () -> () + } {device = "TPU_REPLICATED_HOST_0"} : () -> () "tf.C"() tf_device.return }) {num_cores_per_replica = 1, topology = "", device_assignment = []} diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tf_graph_optimization_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tf_graph_optimization_pass.cc index e84cf959800..46a23a48a6d 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tf_graph_optimization_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tf_graph_optimization_pass.cc @@ -15,6 +15,8 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/transforms/tf_graph_optimization_pass.h" +#include + #include "llvm/Support/CommandLine.h" #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/Location.h" // from @llvm-project @@ -26,9 +28,9 @@ limitations under the License. #include "tensorflow/core/common_runtime/graph_constructor.h" #include "tensorflow/core/common_runtime/optimization_registry.h" #include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/graph_debug_info.pb.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/protobuf/graph_debug_info.pb.h" #include "tensorflow/core/public/session_options.h" #include "tensorflow/tsl/platform/statusor.h" @@ -72,7 +74,7 @@ void GraphOptPass::runOnOperation() { auto graph = std::make_unique(flib_def); Status status = ConvertMlirToGraph(module_in, confs, &graph, &flib_def); if (!status.ok()) { - mlir::emitError(mlir::UnknownLoc::get(&ctx)) << status.error_message(); + mlir::emitError(mlir::UnknownLoc::get(&ctx)) << status.message(); return signalPassFailure(); } @@ -92,7 +94,7 @@ void GraphOptPass::runOnOperation() { Status status = pass->Run(options); if (!status.ok()) { mlir::emitError(mlir::UnknownLoc::get(&ctx)) - << pass->name() << ": " << status.error_message(); + << pass->name() << ": " << status.message(); return signalPassFailure(); } } @@ -104,7 +106,7 @@ void GraphOptPass::runOnOperation() { ConvertGraphToMlir(**options.graph, debug_info, flib_def, specs, &ctx); if (!module_or_status.ok()) { mlir::emitError(mlir::UnknownLoc::get(&ctx)) - << module_or_status.status().error_message(); + << module_or_status.status().message(); return signalPassFailure(); } auto module_out = std::move(module_or_status).value(); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tf_passes.td b/tensorflow/compiler/mlir/tensorflow/transforms/tf_passes.td index 94d5ee37e05..839d9d601d9 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tf_passes.td +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tf_passes.td @@ -393,6 +393,15 @@ def ReplicaIDToDeviceOrdinalPass : Pass<"tf-replica-id-to-device-ordinal", "mlir }]; } +def EmbeddingPipeliningPass : Pass<"tf-embedding-pipelining", "mlir::ModuleOp"> { + let summary = "Rewrite graph for embedding pipelining"; + let constructor = 
"TFDevice::CreateEmbeddingPipeliningPass()"; + let description = [{ + For architectures that support accelerated embedding lookups, this pass will + rewrite the graph to use pipelining for better device utilization. + }]; +} + def ConvertReadonlyReferenceVariablesToResourceVariablesPass : Pass<"tf-readonly-references-to-resources", "mlir::func::FuncOp"> { let summary = "Convert readonly reference variables to resource variables."; @@ -1115,6 +1124,13 @@ def ClusterOutliningPass : Pass<"tf-device-cluster-outlining", "ModuleOp"> { }]; let constructor = "TFDevice::CreateClusterOutliningPass()"; + + let options = [ + Option<"globally_unique_func_names_", "globally-unique-func-names", "bool", + /*default=*/"true", + "If true, the pass adds extra identifiers to make function names " + "globally unique within a process, not just within a module."> + ]; } def ConvertTfControlFlowToScfPass : Pass<"convert-tf-control-flow-to-scf", "ModuleOp"> { @@ -1168,6 +1184,13 @@ def LaunchOutliningPass : Pass<"tf-device-launch-outlining", "ModuleOp"> { }]; let constructor = "TFDevice::CreateLaunchOutliningPass()"; + + let options = [ + Option<"globally_unique_func_names_", "globally-unique-func-names", "bool", + /*default=*/"true", + "If true, the pass adds extra identifiers to make function names " + "globally unique within a process, not just within a module."> + ]; } def ConvertLaunchFuncToTFCallPass : Pass<"tf-device-convert-launch-func-to-tf-call", "ModuleOp"> { @@ -1575,7 +1598,7 @@ def TPURewritePass : Pass<"tf-tpu-rewrite", "ModuleOp"> { ```mlir func @tf_tpu_rewrite(%arg0: tensor, %arg1: tensor) { - %0:2 = tf_device.replicate([%arg0, %arg1] as %arg2: tensor) {devices = {TPU_REPLICATED_CORE_0 = ["/job:worker/replica:0/task:0/device:TPU:0", "/job:worker/replica:0/task:0/device:TPU:1"], TPU_REPLICATED_HOST = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:CPU:0"]}, n = 2 : i32} { + %0:2 = tf_device.replicate([%arg0, %arg1] as %arg2: tensor) {devices = {TPU_REPLICATED_CORE_0 = ["/job:worker/replica:0/task:0/device:TPU:0", "/job:worker/replica:0/task:0/device:TPU:1"], TPU_REPLICATED_HOST_0 = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:CPU:0"]}, n = 2 : i32} { %1:2 = "tf_device.launch"() ( { %compilation_status, %program = "tf._TPUCompileMlir"() {mlir_module = ""} : () -> (tensor, tensor<3x!tf_type.string>) tf_device.return %compilation_status, %program : tensor, tensor<3x!tf_type.string> @@ -1781,6 +1804,31 @@ def TPUMergeVariablesWithExecutePass : Pass<"tf-tpu-merge-variables-with-execute let constructor = "TFTPU::CreateTPUMergeVariablesWithExecutePass()"; } +def ExtractTPUCopyWithDynamicShapeOpPass : Pass<"tf-extract-tpu-copy-with-dynamic-shape-op", "mlir::func::FuncOp"> { + let summary = "Extract the TPUCopyWithDynamicShapeOp out of the host launch and place it on device launch"; + + let description = [{ + This pass looks for TPUCopyWithDynamicShapeOp which wraps in a + `tf_device.launch` with host device attribute. It extracts the ops and wrap + them in `tf_device.launch` with tpu device attribute so that ops can be + run on TPU instead of CPU while still being compiled on host. 
+ }]; + + let constructor = "TFTPU::CreateExtractTPUCopyWithDynamicShapeOpPass()"; +} + +def TPUAnnotateDynamicShapeInputsPass : Pass<"tf-tpu-annotate-dynamic-shape-inputs", "ModuleOp"> { + let summary = "Annotate the inputs returned by TPUCopyWithDynamicShapeOp with dynamic shape"; + + let description = [{ + This pass looks for the usage of the result of TPUCopyWithDynamicShapeOp + and sets the shape of these inputs to be dynamic shaped. This will ensure + that the generated HLO program is correctly reflecting the dynamic shape. + }]; + + let constructor = "TFTPU::CreateTPUAnnotateDynamicShapeInputsPass()"; +} + def ReplicateInvariantOpHoistingPass : Pass<"tf-replicate-invariant-op-hoisting", "mlir::func::FuncOp"> { let summary = "Hoists replicate invariant operations out of replicate"; @@ -1790,6 +1838,12 @@ def ReplicateInvariantOpHoistingPass : Pass<"tf-replicate-invariant-op-hoisting" if possible. This currently updates or replaces `tf.Shape` ops of replicated arguments, either tensors or resources. + The primary benefit of the pass is to hoist `num_replicas` `_TPUCompile`s + into a single `_TPUCompile`. + + This pass assumes that when a `tf.Shape` directly inputs from `replicate` + params, then it is the same shape across replicas. + For example, the following ```mlir @@ -1878,7 +1932,7 @@ def OutsideCompiledToHostLaunchPass : Pass<"tf-outside-compiled-to-host-launch", "tf_device.launch"() { "tf.B"() {_xla_outside_compilation = "cluster1"} tf_device.return - } {device = "TPU_REPLICATED_HOST"} : () -> () + } {device = "TPU_REPLICATED_HOST_0"} : () -> () "tf.C"() tf_device.return }) {num_cores_per_replica = 1, topology = "", device_assignment = []} @@ -2357,6 +2411,31 @@ def TPUDevicePropagationPass : Pass<"tf-tpu-device-propagation", "mlir::func::Fu let constructor = "TFTPU::CreateTPUDevicePropagationPass()"; } +def TPUColocateSplitsPass : Pass<"tf-tpu-colocate-splits", "mlir::func::FuncOp"> { + let summary = "Colocates each Split op with its predecessor"; + let constructor = "TFTPU::CreateTPUColocateSplitsPass()"; + let description = [{ + It is beneficial for performance to assign a `Split` op to the same device + as its predecessor. This is because the weight of cut edges is always + minimized when the `Split` is with its predecessor. This colocation + constraint will be used by the placer graph optimization to assign a device + to the op. + + This pass should run in the export pipeline after tf-replicate-to-island so + each replica has its own distinct (predecessor, Split) pair. 
+ + The colocation class (`_class`) of the `Split` is set to the same class as + its predecessor: + + ```mlir + %outputs1:2, %control1 = tf_executor.island wraps "tf.IteratorGetNext"(%arg) + {_class = ["loc:@dataset_iterator_1"]} + %outputs2:2, %control2 = tf_executor.island wraps "tf.Split"(%outputs0, %outputs1#1) + {_class = ["loc:@dataset_iterator_1", num_split = 2 : i32} + ``` + }]; +} + def TPUIdentityPruningPass : Pass<"tf-tpu-identity-pruning", "ModuleOp"> { let summary = "Removes Identity/IdentityN ops from the TPU computation"; let constructor = "TFTPU::CreateTPUIdentityPruningPass()"; diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_freeze_variables.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_freeze_variables.cc index a1552ffc4ac..141807309c4 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_freeze_variables.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_freeze_variables.cc @@ -363,8 +363,8 @@ LogicalResult FreezeVariables(ModuleOp module, tensorflow::Session* session) { const tensorflow::DeviceMgr* mgr = nullptr; auto status = session->LocalDeviceManager(&mgr); if (!status.ok()) { - module->emitError("failed to fetch device manager: " + - status.error_message()); + module->emitError( + absl::StrCat("failed to fetch device manager: ", status.message())); return failure(); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_annotate_dynamic_shape_inputs.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_annotate_dynamic_shape_inputs.cc new file mode 100644 index 00000000000..76f64d18935 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_annotate_dynamic_shape_inputs.cc @@ -0,0 +1,159 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
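The TPUColocateSplitsPass description above boils down to copying the predecessor's colocation class (`_class`) onto each `Split` so the placer keeps the two on one device. A toy model of that attribute copy; the graph structure and attribute plumbing are simplified stand-ins, not the pass's data structures.

```c++
#include <iostream>
#include <string>
#include <vector>

struct Node {
  std::string name;
  std::vector<std::string> colocation_class;  // the "_class" attribute
  Node* predecessor = nullptr;                // defining op of the split input
};

// If the Split's predecessor carries a colocation class, copy that class onto
// the Split so the placer assigns both nodes to the same device.
void ColocateSplitWithPredecessor(Node& split) {
  if (split.predecessor == nullptr) return;
  if (split.predecessor->colocation_class.empty()) return;
  split.colocation_class = split.predecessor->colocation_class;
}

int main() {
  Node iterator{"IteratorGetNext", {"loc:@dataset_iterator_1"}};
  Node split{"Split", {}, &iterator};
  ColocateSplitWithPredecessor(split);
  std::cout << split.name << " -> " << split.colocation_class.front() << "\n";
}
```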
+==============================================================================*/ + +#include +#include + +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/Casting.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/IR/Visitors.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h" +#include "tensorflow/compiler/xla/mlir_hlo/mhlo/IR/hlo_ops.h" + +#define DEBUG_TYPE "tf-tpu-annotate-dynamic-shape-inputs" + +namespace mlir { +namespace TFTPU { + +namespace { + +#define GEN_PASS_DEF_TPUANNOTATEDYNAMICSHAPEINPUTSPASS +#include "tensorflow/compiler/mlir/tensorflow/transforms/tf_passes.h.inc" + +class TPUAnnotateDynamicShapeInputsPass + : public impl::TPUAnnotateDynamicShapeInputsPassBase< + TPUAnnotateDynamicShapeInputsPass> { + void runOnOperation() override; +}; + +// Finds op that created a given value. If the value is a BlockArgument, this +// returns the owner of the Block. +Operation* GetOpOfValue(Value value) { + if (auto block_arg = value.dyn_cast()) + return block_arg.getOwner()->getParentOp(); + + return value.getDefiningOp(); +} + +void TPUAnnotateDynamicShapeInputsPass::runOnOperation() { + getOperation().walk([&](tf_device::ClusterFuncOp cluster_func_op) { + Builder builder(cluster_func_op->getContext()); + // Skip non-tpu device cluster_func. + auto cluster_id = + cluster_func_op->getAttrOfType(TF::kReplicationInfoAttr); + if (!cluster_id) return WalkResult::advance(); + + llvm::SmallVector dynamic_shape_arg_index; + + // Traverse the operands of the cluster func op and find which operand + // is returned by TPUAnnotateTensorsWithDynamicShapeOp. + for (const auto& cluster_func_operand : + llvm::enumerate(cluster_func_op.getOperands())) { + auto device_launch_op = llvm::dyn_cast( + GetOpOfValue(cluster_func_operand.value())); + if (!device_launch_op) continue; + for (auto result : llvm::zip( + device_launch_op.getResults(), + device_launch_op.GetBody().getTerminator()->getOperands())) { + if (std::get<0>(result) == cluster_func_operand.value() && + llvm::isa( + std::get<1>(result).getDefiningOp())) { + dynamic_shape_arg_index.push_back(cluster_func_operand.index()); + } + } + } + + cluster_func_op->setAttr(TF::kDynamicArgIndexAttr, + builder.getI32ArrayAttr(dynamic_shape_arg_index)); + + FlatSymbolRefAttr func_attr = cluster_func_op.getFuncAttr(); + func::FuncOp func = + cluster_func_op->getParentOfType().lookupSymbol( + func_attr.getValue()); + + // Update the marked argument with dynamic shapes. + for (int index : dynamic_shape_arg_index) { + BlockArgument arg = func.getArgument(index); + auto inputType = arg.getType().dyn_cast(); + // Only rank 1 tensor is supported for now. + if (!inputType || inputType.getRank() != 1) continue; + auto shape = llvm::to_vector<4>(inputType.getShape()); + llvm::SmallVector bounds(shape.begin(), shape.end()); + // Mark the dim as dynamic dim. 
+ shape[0] = ShapedType::kDynamic; + auto extensions = + mhlo::TypeExtensionsAttr::get(func->getContext(), bounds); + auto resultType = + RankedTensorType::get(shape, inputType.getElementType(), extensions); + arg.setType(resultType); + } + llvm::SmallVector arg_types; + for (auto arg : func.getArguments()) arg_types.push_back(arg.getType()); + func.setType( + FunctionType::get(func.getContext(), arg_types, + func.front().getTerminator()->getOperandTypes())); + return WalkResult::advance(); + }); + + // Remove the annotated op after since it is just a placeholder. + DenseSet launch_ops; + getOperation().walk([&](Operation* op) { + if (llvm::isa(op)) { + for (auto result : llvm::zip(op->getOperands(), op->getResults())) { + std::get<1>(result).replaceAllUsesWith(std::get<0>(result)); + } + launch_ops.insert(op->getParentOfType()); + op->erase(); + } + return WalkResult::advance(); + }); + + for (auto launch_op : launch_ops) { + Block& block = launch_op.GetBody(); + if (&block.front() == &block.back()) { + // The tf_device.launch is empty (except for the return). + // Remove the whole tf_device.launch, since later passes will make it send + // the arguments back and forth between the devices. + Operation* return_op = &block.back(); + assert(llvm::isa(return_op)); + for (auto [inner, outer] : + llvm::zip(return_op->getOperands(), launch_op->getResults())) { + outer.replaceAllUsesWith(inner); + } + launch_op->erase(); + } + } +} + +} // namespace + +std::unique_ptr> +CreateTPUAnnotateDynamicShapeInputsPass() { + return std::make_unique(); +} +} // namespace TFTPU +} // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_cluster_formation.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_cluster_formation.cc index c9dd2824ce1..bb2f8f26b65 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_cluster_formation.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_cluster_formation.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include +#include #include #include #include @@ -22,6 +23,7 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_map.h" #include "absl/strings/match.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/DenseMap.h" @@ -142,6 +144,7 @@ LogicalResult CollectAndGroupClusterOps(Block* block, ClusterMap* clusters, std::string& device) { bool has_replicated_compiled_op = false; bool has_non_replicated_compiled_op = false; + bool has_local_device_name_collisions = false; // Use ordered set here to make error message below deterministic. std::set device_types; std::unordered_map devices; @@ -197,17 +200,20 @@ LogicalResult CollectAndGroupClusterOps(Block* block, ClusterMap* clusters, // information such as task, replica, job etc. An example fullname is // "/job:foo_bar/replica:1/task:2/device:GPU:3" if (devices.count(device_local_name)) { + std::string device1 = devices[device_local_name]; + std::string device2 = device_attr.str(); + // Is either of the two devices just a substring of the other? If + // not, we treat them as different devices, and we have a collision. + if (device1.find(device2) == std::string::npos && + device2.find(device1) == std::string::npos) { + has_local_device_name_collisions = true; + LOG(WARNING) << "found two devices with same local name " + << device_local_name + << " but conflicting fullname: " << device1 << " and " + << device2; + } + // Always keep the longer name. 
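In the tpu_annotate_dynamic_shape_inputs.cc file above, once the placeholder annotate ops are erased, a launch whose body holds nothing but its return is dropped entirely and its results are forwarded to the return operands. A simplified std-only model of that forwarding step (the structs here are invented stand-ins for the MLIR ops):

```c++
#include <iostream>
#include <string>
#include <vector>

struct Launch {
  std::vector<std::string> body_ops;         // ops other than the terminator
  std::vector<std::string> return_operands;  // values yielded by tf_device.return
  std::vector<std::string> results;          // values the launch produces
};

// If the launch body only contains the return, replace each launch result with
// the corresponding return operand and tell the caller to erase the launch.
bool CollapseEmptyLaunch(Launch& launch,
                         std::vector<std::string>& uses_of_results) {
  if (!launch.body_ops.empty()) return false;
  for (size_t i = 0; i < launch.results.size(); ++i)
    for (auto& use : uses_of_results)
      if (use == launch.results[i]) use = launch.return_operands[i];
  return true;
}

int main() {
  Launch launch{{}, {"%arg0"}, {"%launch_result"}};
  std::vector<std::string> uses{"%launch_result"};
  bool erased = CollapseEmptyLaunch(launch, uses);
  std::cout << erased << " " << uses.front() << "\n";  // 1 %arg0
}
```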
if (devices[device_local_name].size() < device_attr.str().size()) { - // If for same local name, the smaller device fullname is not - // a substring of larger device fullname, then there is definitely - // some issue with device names. - if (device_attr.str().find(devices[device_local_name]) == - std::string::npos) { - LOG(WARNING) << "found two devices with same local name but " - "conflicting fullname: " - << device_attr.str() << " and " - << devices[device_local_name]; - } devices[device_local_name] = device_attr.str(); } } else { @@ -233,9 +239,10 @@ LogicalResult CollectAndGroupClusterOps(Block* block, ClusterMap* clusters, for (const auto& device_names : devices) { LOG(WARNING) << device_names.first << ", " << device_names.second; } - } - if (devices.size() == 1 && - absl::StrContains(devices.begin()->second, "TPU:")) { + } else if (has_local_device_name_collisions) { + LOG(WARNING) << "Not assigning device because of conflicting fullnames."; + } else if (devices.size() == 1 && + absl::StrContains(devices.begin()->second, "TPU:")) { device = devices.begin()->second; } } @@ -437,12 +444,191 @@ tf_device::ClusterOp CreateClusterOp( return cluster; } +// Returns an op of the given type that uses the result, along with +// a list of identity ops along the way. +template +std::tuple> GetSingleUserOfType( + OpResult result) { + llvm::SmallVector identity_ops; + + do { + Operation* user = result.hasOneUse() ? *result.getUsers().begin() : nullptr; + if (auto t = llvm::dyn_cast_or_null(user)) { + return std::make_tuple(t, identity_ops); + } else if (auto identity = llvm::dyn_cast_or_null(user)) { + identity_ops.emplace_back(identity); + result = identity->getResult(0); + } else { + result = OpResult(); // reset to stop iterating + } + } while (result); + + return std::make_tuple(T(), identity_ops); +} + +using PartitionedClusterOutputMap = + absl::flat_hash_map>; + +// Returns the partitioned output ops from the cluster if there are any, +// along with any single user identity ops between them. Not all outputs +// of a cluster must be partitioned, so the output is a map from cluster +// output ids to ops. +std::tuple> +GetPartitionedOutputsAndIdentityOps(tf_device::ClusterOp cluster) { + PartitionedClusterOutputMap partitioned_outputs; + llvm::SmallVector erase_list; + + for (auto [cluster_result_id, cluster_result] : + llvm::enumerate(cluster.getResults())) { + auto [replicated_output, _] = + GetSingleUserOfType(cluster_result); + if (replicated_output) { + for (OpResult per_replica_result : replicated_output->getResults()) { + auto [partitioned_output, id_ops] = + GetSingleUserOfType( + per_replica_result); + if (partitioned_output) { + erase_list.insert(erase_list.end(), id_ops.begin(), id_ops.end()); + partitioned_outputs[cluster_result_id].emplace_back( + partitioned_output); + } + } + } + } + + return std::forward_as_tuple(partitioned_outputs, erase_list); +} + +// Inlines the partitioned output ops into the cluster, and updates +// their users to point to the replicate op instead. 
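The tpu_cluster_formation.cc hunk above refines the device-name check: two full device names that share a local name only count as a collision when neither is a substring of the other, the longer (more qualified) spelling is kept, and a genuine collision blocks device assignment. The substring rule in isolation, with plain std::string:

```c++
#include <iostream>
#include <string>

// Two full device names that share a local name are considered compatible if
// one is simply a more qualified spelling of the other; otherwise they collide.
bool IsNameCollision(const std::string& device1, const std::string& device2) {
  return device1.find(device2) == std::string::npos &&
         device2.find(device1) == std::string::npos;
}

// When the names are compatible, the longer (more qualified) one is kept.
std::string MergeDeviceNames(const std::string& device1,
                             const std::string& device2) {
  return device1.size() < device2.size() ? device2 : device1;
}

int main() {
  std::string a = "/task:0/device:TPU:0";
  std::string b = "/job:worker/replica:0/task:0/device:TPU:0";
  std::string c = "/job:other/replica:0/task:0/device:TPU:0";

  std::cout << IsNameCollision(a, b) << "\n";   // 0: b merely qualifies a
  std::cout << IsNameCollision(b, c) << "\n";   // 1: genuinely different devices
  std::cout << MergeDeviceNames(a, b) << "\n";  // keeps the longer name
}
```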
+Operation* BuildPartitionedOutputs( + OpBuilder& builder, tf_device::ClusterOp cluster, + tf_device::ReplicateOp replicate_op, + PartitionedClusterOutputMap& partitioned_outputs, + llvm::SmallVector& erase_list, + llvm::SmallVector& result_types, int num_replicas) { + Operation* result_op; + llvm::SmallVector results; + uint64_t num_results = cluster.getNumResults(); + for (uint64_t result_id = 0; result_id < num_results; ++result_id) { + auto search = partitioned_outputs.find(result_id); + if (search == partitioned_outputs.end()) { + // If the output is not partitioned, directly pass it through. + results.emplace_back(cluster.getResult(result_id)); + + continue; + } + + // Otherwise, "inline" the partitioned output ops by: + // - Building a new op within the cluster. + // - Replacing all the uses of the original ops with the cluster's outputs. + llvm::SmallVector& ops = search->second; + for (auto [replica_id, partitioned_output] : llvm::enumerate(ops)) { + for (auto [core_id, result] : + llvm::enumerate(partitioned_output->getResults())) { + // outputs from replicate op are interleaved: + // [(replica:0,core:0), (replica:1,core:0), ..., + // (replica:0,core:1), (replica:1,core:1), ...] + uint64_t output_id = + core_id * num_replicas + replica_id + results.size(); + result.replaceAllUsesWith(replicate_op.getResult(output_id)); + } + } + + // Assume all the replicas have the same structure. + TF::TPUPartitionedOutputV2Op first_op = *(ops.begin()); + ArrayAttr dims = first_op.getPartitionDimsAttr(); + StringAttr sharding = first_op.get_XlaShardingAttr(); + Operation::result_type_range output_types = first_op.getResultTypes(); + result_op = builder.create( + replicate_op.getLoc(), output_types, cluster.getResult(result_id), dims, + sharding); + + results.insert(results.end(), result_op->getResults().begin(), + result_op->getResults().end()); + } + + // Once we've accumulated all the cluster's results, build a return op. + builder.create(result_op->getLoc(), results); + + // Then erase all the identity and partitioned output ops. + for (auto [_, ops] : partitioned_outputs) { + for (TF::TPUPartitionedOutputV2Op op : ops) { + op->erase(); + } + } + + for (TF::IdentityOp to_erase : erase_list) { + to_erase->erase(); + } + + return result_op; +} + +// Return the cluster's per-replica result type, converting any full-shaped +// tensor types into sharded-shaped ones if they're partitioned. +llvm::SmallVector GetClusterResultTypes( + tf_device::ClusterOp cluster, + const PartitionedClusterOutputMap& partitioned_outputs) { + llvm::SmallVector result_types; + Operation::result_type_range cluster_result_types = cluster.getResultTypes(); + if (partitioned_outputs.empty()) { + // Directly pass through the cluster's outputs if none are partitioned. + result_types.insert(result_types.end(), cluster_result_types.begin(), + cluster_result_types.end()); + } else { + // For each output of the cluster... + for (auto [output_id, result_type] : + llvm::enumerate(cluster_result_types)) { + auto search = partitioned_outputs.find(output_id); + if (search == std::end(partitioned_outputs)) { + // If it's not partitioned, directly pass it through. + result_types.emplace_back(result_type); + } else { + // Otherwise, pass through the result shard types. 
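The interleaved layout that BuildPartitionedOutputs relies on is easiest to see with concrete numbers. A small sketch, assuming 2 replicas, 2 cores per replica, and a made-up base offset of 3 for the first shard of one partitioned cluster output within the replicate op's result list:

#include <cstdint>
#include <cstdio>

// Illustrative sketch, not part of the patch: shards of one partitioned
// cluster output are laid out replica-major within each core, so the result
// for (replica, core) lives at base + core * num_replicas + replica.
int main() {
  const uint64_t num_replicas = 2, num_cores_per_replica = 2, base = 3;
  for (uint64_t core = 0; core < num_cores_per_replica; ++core)
    for (uint64_t replica = 0; replica < num_replicas; ++replica)
      std::printf("replica=%llu core=%llu -> replicate result #%llu\n",
                  static_cast<unsigned long long>(replica),
                  static_cast<unsigned long long>(core),
                  static_cast<unsigned long long>(
                      base + core * num_replicas + replica));
  // Prints results 3, 4, 5, 6 for (replica, core) = (0,0), (1,0), (0,1), (1,1).
  return 0;
}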
+ Operation::result_type_range partitioned_result_types = + (*search->second.begin())->getResultTypes(); + result_types.insert(result_types.end(), + partitioned_result_types.begin(), + partitioned_result_types.end()); + } + } + } + return result_types; +} + // Creates a `tf_device.replicate` to represent replication for the cluster, if -// necessary. +// necessary. Erases Identity ops between partitioned and replicated output ops. LogicalResult ReplicateCluster(tf_device::ClusterOp cluster, int num_replicas, int num_cores_per_replica) { + OpBuilder builder(cluster); + auto [partitioned_outputs, erase_list] = + GetPartitionedOutputsAndIdentityOps(cluster); + + for (auto [_, ops] : partitioned_outputs) { + if (!(ops.empty() || ops.size() == num_replicas)) { + return (ops.begin())->emitOpError() + << "expected zero or " << num_replicas + << " 'TPUPartitionedOutput' op(s), instead got " + << partitioned_outputs.size(); + } + } + // No need to replicate. - if (num_replicas == 1) return success(); + if (num_replicas == 1) { + // Collapse all the Identity ops between the TRO and TPO ops. + if (!partitioned_outputs.empty()) { + for (TF::IdentityOp to_erase : erase_list) { + Value in = to_erase->getOperand(0); + OpResult out = to_erase->getResult(0); + out.replaceAllUsesWith(in); + to_erase->erase(); + } + } + + return success(); + } if (num_replicas < 1) return cluster.emitError() << "requires '" << kNumReplicasAttr @@ -494,7 +680,7 @@ LogicalResult ReplicateCluster(tf_device::ClusterOp cluster, int num_replicas, llvm::SmallVector packed_inputs; llvm::SmallVector replicated_ops; llvm::SmallVector packed_ops; - for (auto& pos_and_input : llvm::enumerate(replicated_input_ops)) { + for (const auto& pos_and_input : llvm::enumerate(replicated_input_ops)) { auto input = pos_and_input.value(); bool is_packed = input.getIsPacked(); const int num_operands = input->getNumOperands(); @@ -528,24 +714,28 @@ LogicalResult ReplicateCluster(tf_device::ClusterOp cluster, int num_replicas, } // Create replicate op. - OpBuilder builder(cluster); + auto result_types = GetClusterResultTypes(cluster, partitioned_outputs); auto replicate_op = builder.create( cluster.getLoc(), num_replicas, llvm::SmallDenseMap>(), - replicated_inputs, packed_inputs, cluster.getResultTypes()); + replicated_inputs, packed_inputs, result_types); if (!mirrored_variable_indices.empty()) replicate_op->setAttr(kMirroredVariableIndicesAttr, builder.getI64ArrayAttr(mirrored_variable_indices)); // Replace replicated cluster results with replicate op results. - for (auto result_and_idx : llvm::enumerate(cluster.getResults())) { - Value result = result_and_idx.value(); - int idx = result_and_idx.index(); - auto replicate_outputs = llvm::make_range( - std::next(replicate_op.result_begin(), idx * num_replicas), - std::next(replicate_op.result_begin(), (idx + 1) * num_replicas)); + uint64_t offset = 0; + for (auto [idx, result] : llvm::enumerate(cluster.getResults())) { + if (partitioned_outputs.contains(idx)) { + // Partitioned output propagation happens in BuildPartitionedOutputs. 
+ offset += num_replicas * num_cores_per_replica; + continue; + } + auto replicate_outputs = llvm::make_range( + std::next(replicate_op.result_begin(), offset), + std::next(replicate_op.result_begin(), offset + num_replicas)); for (auto& use : llvm::make_early_inc_range(result.getUses())) { Operation* def = use.getOwner(); if (!llvm::isa(def)) { @@ -562,6 +752,8 @@ LogicalResult ReplicateCluster(tf_device::ClusterOp cluster, int num_replicas, def->replaceAllUsesWith(replicate_outputs); } + + offset += num_replicas; } // Collect all `tf.TPUPartitionedInputV2` ops to be moved inside the @@ -587,11 +779,20 @@ LogicalResult ReplicateCluster(tf_device::ClusterOp cluster, int num_replicas, // Create terminator for replicate op and move `tf_device.cluster` and // `tf.TPUPartitionedInputV2`(s) into replicate body. builder.setInsertionPointToEnd(&replicate_op.GetBody()); - auto return_op = builder.create(replicate_op.getLoc(), - cluster.getResults()); - for (auto pi : partitioned_inputs) pi->moveBefore(return_op); - cluster.getOperation()->moveBefore(return_op); + Operation* result_op; + if (!partitioned_outputs.empty()) { + result_op = BuildPartitionedOutputs(builder, cluster, replicate_op, + partitioned_outputs, erase_list, + result_types, num_replicas); + } else { + result_op = builder.create(replicate_op.getLoc(), + cluster.getResults()); + } + + for (auto pi : partitioned_inputs) pi->moveBefore(result_op); + + cluster.getOperation()->moveBefore(result_op); return success(); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_colocate_splits.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_colocate_splits.cc new file mode 100644 index 00000000000..1e7f9c5a4d2 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_colocate_splits.cc @@ -0,0 +1,85 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" + +namespace mlir { +namespace TFTPU { + +namespace { + +#define GEN_PASS_DEF_TPUCOLOCATESPLITSPASS +#include "tensorflow/compiler/mlir/tensorflow/transforms/tf_passes.h.inc" + +constexpr char kDeviceAttr[] = "device"; +// Attribute of colocation classes. +constexpr char kClassAttr[] = "_class"; + +bool HasDevice(Operation* op) { + auto attr = op->getAttrOfType(kDeviceAttr); + if (!attr) return false; + return !attr.getValue().empty(); +} + +// Returns the predecessors of `op` when `op`'s predecessors are wrapped by +// islands. 
+llvm::SmallVector IslandPredecessors(Operation* op) { + llvm::SmallVector predecessors; + for (Value operand : op->getOperands()) { + if (Operation* pred = operand.getDefiningOp()) { + int result_number = llvm::cast(operand).getResultNumber(); + if (auto pred_island = llvm::dyn_cast(pred)) { + Value yield_operand = pred_island.GetYield().getOperand(result_number); + predecessors.push_back(yield_operand.getDefiningOp()); + } + } + } + return predecessors; +} + +struct TPUColocateSplits + : public impl::TPUColocateSplitsPassBase { + void runOnOperation() override; +}; + +void TPUColocateSplits::runOnOperation() { + getOperation().walk([&](Operation* op) { + if (auto split = llvm::dyn_cast(op)) { + if (HasDevice(split) || split->getAttrOfType(kClassAttr)) + return WalkResult::advance(); + for (Operation* pred : IslandPredecessors(split)) { + if (auto colocation_classes = + pred->getAttrOfType(kClassAttr)) { + split->setAttr(kClassAttr, colocation_classes); + return WalkResult::advance(); + } + } + } + return WalkResult::advance(); + }); +} + +} // namespace + +std::unique_ptr> CreateTPUColocateSplitsPass() { + return std::make_unique(); +} + +} // namespace TFTPU +} // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_dynamic_layout_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_dynamic_layout_pass.cc index c8ad200e328..04b488a3804 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_dynamic_layout_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_dynamic_layout_pass.cc @@ -172,7 +172,7 @@ bool HandleReplicatedInputs( MutableArrayRef inputs = replicate.GetOperandsForBlockArgument(replicate_arg); - for (auto entry : llvm::enumerate(inputs)) { + for (const auto& entry : llvm::enumerate(inputs)) { auto input_op = entry.value().get().getDefiningOp(); if (!input_op || !IsSupportedInputOp(input_op, resource_alias_analysis)) return false; @@ -181,7 +181,7 @@ bool HandleReplicatedInputs( auto get_layout = BuildGetLayout(execute_arg_index, compilation_key, compile_launch, &builder); builder.setInsertionPoint(replicate); - for (auto entry : llvm::enumerate(inputs)) { + for (const auto& entry : llvm::enumerate(inputs)) { auto copy_with_layout = BuildCopyWithLayout(execute_launch, compile_launch, get_layout, entry.value().get(), &builder); @@ -222,7 +222,7 @@ void HandleCompileAndExecutes( llvm::cast(execute_launch.GetBody().front()); const auto& input_mapping = std::get<1>(execute_and_input_mapping); - for (auto& input_and_idx : llvm::enumerate(execute.getArgs())) { + for (const auto& input_and_idx : llvm::enumerate(execute.getArgs())) { Value input = input_and_idx.value(); const int64_t execute_arg_index = input_and_idx.index(); if (auto block_arg = input.dyn_cast()) { diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc index b8eb4c22598..f20db8a9976 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include +#include #include #include #include @@ -21,6 +22,7 @@ limitations under the License. 
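For reference, the TPUColocateSplits pass above amounts to copying a colocation class from the first island predecessor that carries one onto a Split op that has neither a device nor a class of its own. A simplified sketch with a hypothetical FakeOp type standing in for MLIR operations:

#include <map>
#include <string>
#include <vector>

// Illustrative sketch, not part of the patch: propagate "_class" from the
// first predecessor that has it, but only when the split itself has neither
// a "device" nor a "_class" attribute.
struct FakeOp {
  std::map<std::string, std::string> attrs;
  std::vector<FakeOp*> predecessors;
};

void ColocateSplit(FakeOp& split) {
  if (split.attrs.count("device") || split.attrs.count("_class")) return;
  for (FakeOp* pred : split.predecessors) {
    auto it = pred->attrs.find("_class");
    if (it != pred->attrs.end()) {
      split.attrs["_class"] = it->second;  // e.g. "loc:@some_variable"
      return;
    }
  }
}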
#include "absl/strings/match.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallSet.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/Casting.h" @@ -29,6 +31,7 @@ limitations under the License. #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/Diagnostics.h" // from @llvm-project @@ -82,7 +85,12 @@ constexpr char kBadArrayAttrLengthMsg[] = namespace { struct TPURewritePass : public impl::TPURewritePassBase { + explicit TPURewritePass(llvm::StringRef _module_name) + : module_name(_module_name) {} + void runOnOperation() override; + + llvm::StringRef module_name; }; // Creates a missing attribute error message. @@ -90,7 +98,8 @@ std::string CreateMissingAttributeMsg(llvm::StringRef attribute) { return llvm::formatv("requires attribute '{0}'", attribute).str(); } -LogicalResult EncapsulateFuncAndSerialize(func::FuncOp entry_func, +LogicalResult EncapsulateFuncAndSerialize(const std::string& module_name, + func::FuncOp entry_func, std::string* serialized_func_module) { ModuleOp module = entry_func->getParentOfType(); SymbolTable entry_module_table(module); @@ -98,7 +107,8 @@ LogicalResult EncapsulateFuncAndSerialize(func::FuncOp entry_func, // Create a new module to hold func and all referenced functions. OwningOpRef module_for_func = - ModuleOp::create(mlir::UnknownLoc::get(entry_func.getContext())); + ModuleOp::create(mlir::UnknownLoc::get(entry_func.getContext()), + absl::StrCat("module_", module_name)); auto parent_module = entry_func->getParentOfType(); auto versions_attr = parent_module->getAttr(kVersionsAttr); if (!versions_attr) @@ -207,6 +217,15 @@ LogicalResult SetMetadataProtoArgs( // Set args metadata in proto. mlir::StringAttr replication_attr_name = mlir::StringAttr::get( op.getContext(), "mhlo.is_same_data_across_replicas"); + + auto dynamic_arg_idx = op->getAttrOfType(TF::kDynamicArgIndexAttr); + llvm::SmallSet dynamic_arg_idx_set; + if (dynamic_arg_idx) { + for (auto idx : dynamic_arg_idx.getValue()) { + dynamic_arg_idx_set.insert(idx.dyn_cast().getInt()); + } + } + for (auto operand_type_and_idx : llvm::enumerate(op.getOperandTypes())) { Type operand_type = operand_type_and_idx.value(); int index = operand_type_and_idx.index(); @@ -217,7 +236,7 @@ LogicalResult SetMetadataProtoArgs( if (!status.ok()) return op.emitOpError( llvm::formatv("failed to determine operand type at index {0}: {1}", - index, status.error_message())); + index, status.message())); arg->set_dtype(dtype); // TODO(lyandy): Support other arg kinds. @@ -247,6 +266,10 @@ LogicalResult SetMetadataProtoArgs( mlir::UnitAttr attr = op.getFuncOp().getArgAttrOfType( index, replication_attr_name); arg->set_is_same_data_across_replicas(attr != nullptr); + + // Currently only support first dimension to be bounded dynamic. + arg->mutable_is_bounded_dynamic_dim()->Add( + dynamic_arg_idx_set.contains(index)); } return success(); @@ -336,12 +359,14 @@ tf_device::LaunchOp WrapOpInLaunch(OpBuilder* builder, Location loc, // Create a `tf._TPUCompileMlir` that contains a MLIR module that is // functionally equivalent to the function referenced by cluster_func. 
Operation* BuildCompileOp( - tf_device::ClusterFuncOp cluster_func, int num_replicas, - int num_cores_per_replica, llvm::StringRef compilation_device, + llvm::StringRef module_name, tf_device::ClusterFuncOp cluster_func, + int num_replicas, int num_cores_per_replica, + llvm::StringRef compilation_device, std::optional&& xla_device_assignment, OpBuilder* builder, bool tpu_compile_metadata_debug) { // Set metadata from attributes. tensorflow::tpu::TPUCompileMetadataProto metadata; + if (!module_name.empty()) metadata.set_module_name(module_name.str()); if (failed(SetMetadataProtoFromClusterFuncOp( cluster_func, num_replicas, num_cores_per_replica, std::move(xla_device_assignment), &metadata))) @@ -373,7 +398,10 @@ Operation* BuildCompileOp( func_attr.getValue()); std::string txt_module; - if (failed(EncapsulateFuncAndSerialize(func, &txt_module))) return nullptr; + if (failed(EncapsulateFuncAndSerialize( + module_name.empty() ? "unknown_graph" : module_name.str(), func, + &txt_module))) + return nullptr; auto compilation_status_type = RankedTensorType::get({}, builder->getType()); @@ -419,24 +447,24 @@ void AssignDevicesToReplicate( for (int core = 0; core < num_cores_per_replica; ++core) { llvm::SmallVector devices_by_core; devices_by_core.reserve(num_replicas); - for (int replica = 0; replica < num_replicas; ++replica) + llvm::SmallVector hosts_by_core; + hosts_by_core.reserve(num_replicas); + for (int replica = 0; replica < num_replicas; ++replica) { devices_by_core.push_back(tpu_devices[replica][core].device); + hosts_by_core.push_back(tpu_devices[replica][core].host); + } device_attrs.push_back( builder->getNamedAttr(tensorflow::GetDeviceAliasForLogicalCore(core), builder->getStrArrayAttr(devices_by_core))); + + // For data parallelism, also add replicated host devices, as these are + // necessary for outside compilation. + device_attrs.push_back(builder->getNamedAttr( + tensorflow::GetDeviceAliasForHostOfLogicalCore(core), + builder->getStrArrayAttr(hosts_by_core))); } - // For data parallelism, also add replicated host devices, as these are - // necessary for outside compilation. - llvm::SmallVector hosts; - hosts.reserve(num_replicas); - for (int replica = 0; replica < num_replicas; ++replica) - hosts.push_back(tpu_devices[replica][0].host); - - device_attrs.push_back(builder->getNamedAttr( - tensorflow::kTPUReplicatedHost, builder->getStrArrayAttr(hosts))); - replicate->setAttr(kDevicesAttr, builder->getDictionaryAttr(device_attrs)); } @@ -715,7 +743,7 @@ int GetNumResultsPreCluster(tf_device::ParallelExecuteOp parallel_execute) { } LogicalResult Rewrite( - tf_device::ClusterFuncOp cluster_func, + llvm::StringRef module_name, tf_device::ClusterFuncOp cluster_func, llvm::ArrayRef devices, ArrayRef compilation_result, OpBuilder* builder, bool tpu_compile_metadata_debug) { @@ -782,7 +810,7 @@ LogicalResult Rewrite( if (!status_or_device_coodinates.ok()) return cluster_func.emitError() << "error in fetching tpu device coordinates: " - << status_or_device_coodinates.status().error_message(); + << status_or_device_coodinates.status().message(); // Determine compilation and execution devices. auto status_or_tpu_device_assignment = @@ -792,7 +820,7 @@ LogicalResult Rewrite( if (!status_or_tpu_device_assignment.ok()) return cluster_func.emitError() << "error in fetching TPU compilation/execution devices: " - << status_or_tpu_device_assignment.status().error_message(); + << status_or_tpu_device_assignment.status().message(); // Create compile op. 
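The AssignDevicesToReplicate change above now emits a host alias per logical core, mirroring the per-core device alias, with one entry per replica in each list. A simplified sketch using plain strings; the literal alias prefixes are assumptions, since the real names come from GetDeviceAliasForLogicalCore and GetDeviceAliasForHostOfLogicalCore:

#include <map>
#include <string>
#include <utility>
#include <vector>

// Illustrative sketch, not part of the patch: build, per logical core, both a
// device list and a host list, each indexed by replica.
std::map<std::string, std::vector<std::string>> BuildReplicateDeviceMap(
    const std::vector<std::vector<std::pair<std::string, std::string>>>&
        tpu_devices /* [replica][core] -> {device, host} */) {
  std::map<std::string, std::vector<std::string>> device_map;
  if (tpu_devices.empty()) return device_map;
  for (size_t core = 0; core < tpu_devices[0].size(); ++core) {
    for (size_t replica = 0; replica < tpu_devices.size(); ++replica) {
      device_map["TPU_REPLICATED_CORE_" + std::to_string(core)].push_back(
          tpu_devices[replica][core].first);
      device_map["TPU_REPLICATED_HOST_" + std::to_string(core)].push_back(
          tpu_devices[replica][core].second);
    }
  }
  return device_map;
}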
auto& tpu_device_assignment = status_or_tpu_device_assignment.value(); @@ -800,11 +828,11 @@ LogicalResult Rewrite( // Create the TPUCompileMlir and TPUCompileSucceededAssert outside of // the parallel_execute. builder->setInsertionPoint(old_parallel_execute); - Operation* compile_op = - BuildCompileOp(cluster_func, num_replicas, num_cores_per_replica, - tpu_device_assignment.compilation_device, - std::move(tpu_device_assignment.xla_device_assignment), - builder, tpu_compile_metadata_debug); + Operation* compile_op = BuildCompileOp( + module_name, cluster_func, num_replicas, num_cores_per_replica, + tpu_device_assignment.compilation_device, + std::move(tpu_device_assignment.xla_device_assignment), builder, + tpu_compile_metadata_debug); if (!compile_op) return failure(); // This replaces _TPUCompileMlir placeholder ops that are required @@ -940,7 +968,7 @@ void TPURewritePass::runOnOperation() { auto cluster_id = op->getAttrOfType(TF::kReplicationInfoAttr); if (!cluster_id) return WalkResult::advance(); - if (failed(Rewrite(op, devices.device_names(), + if (failed(Rewrite(module_name, op, devices.device_names(), compilation_results[cluster_id], &builder, tpu_compile_metadata_debug_))) return WalkResult::interrupt(); @@ -970,8 +998,9 @@ void TPURewritePass::runOnOperation() { } // namespace -std::unique_ptr> CreateTPURewritePass() { - return std::make_unique(); +std::unique_ptr> CreateTPURewritePass( + llvm::StringRef module_name) { + return std::make_unique(module_name); } } // namespace TFTPU diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_sharding_identification_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_sharding_identification_pass.cc index e3f9d15e5c6..bb8dd429174 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_sharding_identification_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_sharding_identification_pass.cc @@ -16,6 +16,8 @@ limitations under the License. #include #include #include +#include +#include #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/STLExtras.h" @@ -35,6 +37,7 @@ limitations under the License. #include "mlir/Interfaces/CallInterfaces.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" @@ -46,6 +49,9 @@ namespace mlir { namespace TFTPU { namespace { +using OpShardingVariant = std::variant; +using OpShardingVector = llvm::SmallVector; + constexpr char kReplicateSharding[] = ""; constexpr char kShardingAttr[] = "mhlo.sharding"; constexpr char kUseSpmdAttr[] = "use_spmd_for_xla_partitioning"; @@ -65,20 +71,79 @@ std::string CreateMissingAttributeMsg(llvm::StringRef attribute) { return llvm::formatv("requires attribute '{0}'", attribute).str(); } -// Returns XLA sharding from TPUPartitionedInput op connected to a -// `tf_device.cluster_func` operand value. If value is a resource type then +// Returns nullptr if the op does not have a sharding attribute. +template +mlir::Operation* NullUnlessSharded(PartitionedOp op) { + return op.get_XlaSharding() ? op : nullptr; +} + +// Returns a TPUPartitionedInput op connected to a `tf_device.cluster_func` +// operand value if it has an XLA sharding. 
If value is a resource type then // TPUPartitionedInput op will be connected to a ReadVariable op that feeds into // a `tf_device.cluster_func`. -std::optional GetXlaShardingFromOperand(Value value) { +mlir::Operation* GetXlaShardingFromOperand(Value value) { Value value_to_visit = value; if (auto read_var = value_to_visit.getDefiningOp()) value_to_visit = read_var.getResource(); if (auto partitioned_input = - value_to_visit.getDefiningOp()) - return partitioned_input.get_XlaSharding(); + value_to_visit.getDefiningOp()) { + return NullUnlessSharded(partitioned_input); + } - return std::nullopt; + return nullptr; +} + +// Returns the op sharding attribute from a partitioned operator. +std::optional GetXlaShardingFromOperator(mlir::Operation* op) { + if (auto partitioned_output = + llvm::dyn_cast(op)) { + return partitioned_output.get_XlaSharding(); + } else if (auto partitioned_input = + llvm::dyn_cast(op)) { + return partitioned_input.get_XlaSharding(); + } else { + return std::nullopt; + } +} + +// Returns the sharding string from a op-sharding variant if it is available. +std::optional GetShardingStringFromVariant( + const OpShardingVariant& sharding_or_op) { + return std::visit( + [](auto&& sharding_or_op) -> std::optional { + using T = std::decay_t; + if constexpr (std::is_same_v) { + return sharding_or_op; + } else { + return GetXlaShardingFromOperator(sharding_or_op); + } + }, + sharding_or_op); +} + +// Returns the sharding from a op-sharding variant if it is available and valid. +std::optional GetShardingFromVariant( + const OpShardingVariant& sharding_or_op) { + xla::OpSharding sharding; + const auto sharding_string = GetShardingStringFromVariant(sharding_or_op); + if (sharding_string && sharding.ParseFromString(sharding_string->str())) { + return sharding; + } else { + return std::nullopt; + } +} + +// Converts an op-sharding vector into a string attr using the builder. +mlir::ArrayAttr GetStrArrayAttr(Builder* builder, + const OpShardingVector& vect) { + llvm::SmallVector strings; + for (const auto& sharding_or_op : vect) { + if (const auto sharding = GetShardingStringFromVariant(sharding_or_op)) { + strings.emplace_back(builder->getStringAttr(*sharding)); + } + } + return builder->getArrayAttr(strings); } // Given a `tf_device.cluster_func` operand value return true iff it a device @@ -96,19 +161,37 @@ bool IsMaximalVariable(Value value) { // on CPU) // If the sharding is incorrect, return failure. If it's good, or if we can't // verify it, return success. -LogicalResult VerifySharding(Type type, StringRef sharding_string) { - xla::OpSharding sharding; - if (!sharding.ParseFromString(sharding_string.str())) { +LogicalResult VerifySharding(mlir::Type type, + const OpShardingVariant& sharding_or_op) { + auto* partitioned_op = + std::holds_alternative(sharding_or_op) + ? std::get(sharding_or_op) + : nullptr; + const auto sharding = GetShardingFromVariant(sharding_or_op); + if (!sharding || sharding->type() != xla::OpSharding::OTHER) { // Some test cases use \01\02\03 as sharding, to test propagation. Treat - // a non-proto sharding as valid, and don't verify further. - return success(); - } - if (sharding.type() != xla::OpSharding::OTHER) { - // We currently only verify shardings that actually break a tensor apart. + // a non-proto sharding as valid, and don't verify further. We also only + // verify shardings that actually break a tensor apart. 
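The sharding identification rework above keeps either a sharding string already in hand or the partitioned op it still has to be read from, and dispatches on that with std::visit. A self-contained sketch of the pattern, with a hypothetical FakeShardedOp standing in for the TPUPartitionedInputV2/OutputV2 ops:

#include <optional>
#include <string>
#include <type_traits>
#include <variant>

// Illustrative sketch, not part of the patch: resolve a sharding string from
// either alternative of the variant.
struct FakeShardedOp {
  std::optional<std::string> sharding;
};

using ShardingOrOp = std::variant<std::string, FakeShardedOp*>;

std::optional<std::string> GetSharding(const ShardingOrOp& sharding_or_op) {
  return std::visit(
      [](auto&& value) -> std::optional<std::string> {
        using T = std::decay_t<decltype(value)>;
        if constexpr (std::is_same_v<T, std::string>) {
          return value;
        } else {
          return value->sharding;  // std::nullopt if the op is unsharded
        }
      },
      sharding_or_op);
}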
return success(); } if (RankedTensorType ranked_type = type.dyn_cast()) { - if (ranked_type.getRank() < sharding.tile_assignment_dimensions_size()) { + const int64_t tensor_rank = ranked_type.getRank(); + int tile_assignment_rank = sharding->tile_assignment_dimensions_size(); + + // When a tensor is partial or subgroup tiled, its tile assignment will + // have one or more dimension(s) than its rank; so, we subtract them to + // determine which rank the sharding is compatible with. + tile_assignment_rank -= (int)sharding->replicate_on_last_tile_dim(); + tile_assignment_rank -= sharding->last_tile_dims_size(); + + if (tensor_rank < tile_assignment_rank) { + if (partitioned_op) { + partitioned_op->emitError() + << "tensor of type " << ranked_type << " (rank=" << tensor_rank + << ") sharded in " << (tile_assignment_rank - tensor_rank) + << " extra dimension(s) by: " << sharding->DebugString(); + } + return failure(); } } @@ -116,21 +199,20 @@ LogicalResult VerifySharding(Type type, StringRef sharding_string) { } // Verify sharding for all arguments and return values. -LogicalResult VerifyShardings( - mlir::func::FuncOp func, - const llvm::SmallVectorImpl& sharding_for_args, - const llvm::SmallVectorImpl& sharding_for_rets) { +LogicalResult VerifyShardings(mlir::func::FuncOp func, + const OpShardingVector& sharding_for_args, + const OpShardingVector& sharding_for_rets) { Block& function_block = func.front(); for (auto sharding_and_arg : llvm::zip(sharding_for_args, function_block.getArguments())) { - StringRef sharding = std::get<0>(sharding_and_arg); + const auto& sharding = std::get<0>(sharding_and_arg); BlockArgument arg = std::get<1>(sharding_and_arg); if (failed(VerifySharding(arg.getType(), sharding))) return failure(); } Operation* terminator = function_block.getTerminator(); for (auto sharding_and_retval : llvm::zip(sharding_for_rets, terminator->getOpOperands())) { - StringRef sharding = std::get<0>(sharding_and_retval); + const auto& sharding = std::get<0>(sharding_and_retval); OpOperand& retval = std::get<1>(sharding_and_retval); if (failed(VerifySharding(retval.get().getType(), sharding))) return failure(); @@ -215,8 +297,7 @@ std::optional GetXlaShardingFromArg( void IdentifyXlaShardingForComputationInputs( const llvm::SmallVector& logical_device_vec, bool use_spmd, bool infer_from_computation, tf_device::ClusterFuncOp cluster_func, - func::FuncOp func, Builder* builder, - llvm::SmallVectorImpl& sharding_for_args) { + func::FuncOp func, Builder* builder, OpShardingVector& sharding_for_args) { // Look up function definition from module. Block& function_block = func.front(); @@ -245,7 +326,7 @@ void IdentifyXlaShardingForComputationInputs( BlockArgument arg = std::get<1>(operand_and_arg); if (auto operand_sharding = GetXlaShardingFromOperand(operand)) { - sharding_for_args.push_back(operand_sharding.value()); + sharding_for_args.push_back(operand_sharding); continue; } @@ -271,24 +352,24 @@ void IdentifyXlaShardingForComputationInputs( } } -// Returns XLA sharding from TPUPartitionedOutput or TPUPartitionedInput (via -// AssignVariableOp/resource write) op connected to a `tf_device.cluster_func` -// result value. -std::optional GetXlaShardingFromResult(Value value) { - if (!value.hasOneUse()) return std::nullopt; +// Returns a TPUPartitionedOutput or TPUPartitionedInput op with XLA sharding +// connected to a `tf_device.cluster_func` result value (via AssignVariableOp/ +// resource write). 
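The rank check in VerifySharding above discounts trailing tile-assignment dimensions used for partial replication or subgroups before comparing against the tensor rank. A worked sketch with made-up numbers (the helper name is hypothetical):

#include <cassert>

// Illustrative sketch, not part of the patch: compare the tensor rank against
// the tile assignment rank after removing replication/subgroup dimensions.
bool ShardingRankCompatible(int tensor_rank, int tile_assignment_dims,
                            bool replicate_on_last_tile_dim,
                            int last_tile_dims) {
  int tile_assignment_rank = tile_assignment_dims;
  tile_assignment_rank -= replicate_on_last_tile_dim ? 1 : 0;
  tile_assignment_rank -= last_tile_dims;
  return tensor_rank >= tile_assignment_rank;
}

int main() {
  // A rank-2 tensor with tile assignment [2, 2, 4] and
  // replicate_on_last_tile_dim=true has effective sharded rank 2, so it
  // passes; the same tiling without the replicated trailing dimension would
  // shard a rank-2 tensor in one extra dimension and fail.
  assert(ShardingRankCompatible(/*tensor_rank=*/2, /*tile_assignment_dims=*/3,
                                /*replicate_on_last_tile_dim=*/true,
                                /*last_tile_dims=*/0));
  assert(!ShardingRankCompatible(2, 3, false, 0));
  return 0;
}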
+mlir::Operation* GetXlaShardingFromResult(Value value) { + if (!value.hasOneUse()) return nullptr; Operation* user = *value.getUsers().begin(); if (auto partitioned_output = llvm::dyn_cast(user)) - return partitioned_output.get_XlaSharding(); + return NullUnlessSharded(partitioned_output); if (auto assign_var = llvm::dyn_cast(user)) if (auto partitioned_input = assign_var.getResource() .getDefiningOp()) - return partitioned_input.get_XlaSharding(); + return NullUnlessSharded(partitioned_input); - return std::nullopt; + return nullptr; } // Looks up arg->retval aliases for every argument, and builds a reverse map. @@ -307,12 +388,12 @@ void ExtractAliases(func::FuncOp func, llvm::SmallVectorImpl& aliases) { // Returns XLA sharding from argument connected via tf.aliasing_output. std::optional GetXlaShardingFromAlias( Value value, llvm::SmallVectorImpl& aliases, - const llvm::SmallVectorImpl& sharding_for_args) { + const OpShardingVector& sharding_for_args) { int retval_index = value.cast().getResultNumber(); if (retval_index >= 0 && retval_index < aliases.size()) { int arg_index = aliases[retval_index]; if (arg_index >= 0 && arg_index < sharding_for_args.size()) { - return sharding_for_args[arg_index]; + return GetShardingStringFromVariant(sharding_for_args[arg_index]); } } return std::nullopt; @@ -394,8 +475,8 @@ void IdentifyXlaShardingForComputationOutputs( const llvm::SmallVector& logical_device_vec, bool use_spmd, bool infer_from_computation, tf_device::ClusterFuncOp cluster_func, func::FuncOp func, Builder* builder, - const llvm::SmallVectorImpl& sharding_for_args, - llvm::SmallVectorImpl& sharding_for_rets) { + const OpShardingVector& sharding_for_args, + OpShardingVector& sharding_for_rets) { Block& function_block = func.front(); Operation* terminator = function_block.getTerminator(); sharding_for_rets.reserve(terminator->getNumOperands()); @@ -418,7 +499,7 @@ void IdentifyXlaShardingForComputationOutputs( OpOperand& retval = std::get<1>(result_and_retval); if (auto result_sharding = GetXlaShardingFromResult(result)) { - sharding_for_rets.push_back(result_sharding.value()); + sharding_for_rets.push_back(result_sharding); continue; } @@ -477,21 +558,21 @@ LogicalResult IdentifyXlaShardingForTPUComputation( xla::sharding_builder::AssignDevice(idx).SerializeAsString(); } - llvm::SmallVector sharding_for_args; + OpShardingVector sharding_for_args; IdentifyXlaShardingForComputationInputs(logical_device_vec, use_spmd, /*infer_from_computation=*/true, cluster_func, func, builder, sharding_for_args); - llvm::SmallVector sharding_for_rets; + OpShardingVector sharding_for_rets; IdentifyXlaShardingForComputationOutputs( logical_device_vec, use_spmd, /*infer_from_computation=*/true, cluster_func, func, builder, sharding_for_args, sharding_for_rets); - auto has_maximal_sharding = [](llvm::StringRef sharding_string) -> bool { - xla::OpSharding sharding; - sharding.ParseFromString(sharding_string.str()); - return sharding.type() == xla::OpSharding::MAXIMAL; + auto has_maximal_sharding = + [](const OpShardingVariant& sharding_or_op) -> bool { + const auto sharding = GetShardingFromVariant(sharding_or_op); + return sharding && sharding->type() == xla::OpSharding::MAXIMAL; }; // XLA SPMD only supports cases where all inputs/outputs exist on every @@ -523,26 +604,30 @@ LogicalResult IdentifyXlaShardingForTPUComputation( Block& function_block = func.front(); for (auto sharding_and_arg : llvm::zip(sharding_for_args, function_block.getArguments())) { - StringRef sharding = std::get<0>(sharding_and_arg); 
BlockArgument arg = std::get<1>(sharding_and_arg); - func.setArgAttr(arg.getArgNumber(), kShardingAttr, - builder->getStringAttr(sharding)); + const auto& sharding_or_op = std::get<0>(sharding_and_arg); + if (auto sharding = GetShardingStringFromVariant(sharding_or_op)) { + func.setArgAttr(arg.getArgNumber(), kShardingAttr, + builder->getStringAttr(*sharding)); + } } Operation* terminator = function_block.getTerminator(); for (auto sharding_and_retval : llvm::zip(sharding_for_rets, terminator->getOpOperands())) { - StringRef sharding = std::get<0>(sharding_and_retval); OpOperand& retval = std::get<1>(sharding_and_retval); - func.setResultAttr(retval.getOperandNumber(), kShardingAttr, - builder->getStringAttr(sharding)); + const auto& sharding_or_op = std::get<0>(sharding_and_retval); + if (auto sharding = GetShardingStringFromVariant(sharding_or_op)) { + func.setResultAttr(retval.getOperandNumber(), kShardingAttr, + builder->getStringAttr(*sharding)); + } } // Update input/output sharding attributes on tf_device.cluster_func op. cluster_func->setAttr(tensorflow::kInputShardingAttr, - builder->getStrArrayAttr(sharding_for_args)); + GetStrArrayAttr(builder, sharding_for_args)); cluster_func->setAttr(tensorflow::kOutputShardingAttr, - builder->getStrArrayAttr(sharding_for_rets)); + GetStrArrayAttr(builder, sharding_for_rets)); return success(); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_validate_inputs.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_validate_inputs.cc index c6909581e17..dd5465bf5d7 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_validate_inputs.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_validate_inputs.cc @@ -12,11 +12,23 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ +#include #include +#include +#include +#include "absl/strings/match.h" +#include "llvm/Support/raw_ostream.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/Visitors.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h" namespace mlir { namespace TFTPU { @@ -26,34 +38,151 @@ namespace { #define GEN_PASS_DEF_TPUVALIDATEINPUTSPASS #include "tensorflow/compiler/mlir/tensorflow/transforms/tf_passes.h.inc" +typedef std::unordered_map MetadataMap; + struct TPUValidateInputsPass : public impl::TPUValidateInputsPassBase { void runOnOperation() override; }; +bool IsTpuRegularOp(Operation* op) { + static auto* ops = [] { + llvm::SmallDenseSet* ops_set = + new llvm::SmallDenseSet{ + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + }; + return ops_set; + }(); + auto abstractOp = op->getRegisteredInfo(); + if (!abstractOp) return true; + return ops->count(abstractOp->getTypeID()) == 0; +} -bool ValidateReplicatedInput(TF::TPUReplicatedInputOp rep, int num_replicas) { - int arity = rep.getInputs().size(); - if (rep.getIsPacked() && arity != 1) { - rep.emitOpError( - "TF/XLA TPU bridge input check: packed with number of inputs not 1.") - << " num_replicas=" << num_replicas << " no. of inputs=" << arity; +bool IsIntersectionXlaNonXlaOps(Operation* op) { + static auto* ops = [] { + llvm::SmallDenseSet* ops_set = + new llvm::SmallDenseSet{ + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + }; + return ops_set; + }(); + auto abstractOp = op->getRegisteredInfo(); + if (!abstractOp) return true; + return ops->count(abstractOp->getTypeID()) == 0; +} + +bool IsPartitionedOp(Operation* op) { + static auto* ops = [] { + llvm::SmallDenseSet* ops_set = + new llvm::SmallDenseSet{ + TypeID::get(), + TypeID::get(), + TypeID::get(), + }; + return ops_set; + }(); + auto abstractOp = op->getRegisteredInfo(); + if (!abstractOp) return false; + return ops->count(abstractOp->getTypeID()) != 0; +} + +// Gets the successors of an op wrapped in a tf_executor.island. +llvm::SmallVector GetSuccessors(Operation* op) { + llvm::SmallVector successors; + for (auto result : op->getParentOp()->getOpResults()) { + for (auto& use : result.getUses()) { + auto succ = use.getOwner(); + successors.push_back(succ); + } + } + return successors; +} +// Gets the predecessors of an op wrapped in tf_executor.island. 
+llvm::SmallVector GetPredecessors(Operation* op) { + llvm::SmallVector predecessors; + for (auto operand : op->getOperands()) { + if (Operation* pred = operand.getDefiningOp()) { + pred->walk([&](mlir::Operation* opinexecutor) { + predecessors.push_back(opinexecutor); + }); + } + } + return predecessors; +} + +bool CheckTpuReplicateAttr(Operation* op, StringAttr attr, + std::function errormsg) { + if (!op->hasAttr(TF::kTpuReplicateAttr)) { + op->emitOpError("TF2XLA TPU bridge input check: " + errormsg() + + "missing _tpu_replicate attr"); return false; - } else if (!rep.getIsPacked() && arity != num_replicas) { - rep.emitOpError( - "TF/XLA TPU bridge input check: number of inputs inconsistent.") - << " num_replicas=" << num_replicas << " no. of inputs=" << arity; + } + auto opattr = op->getAttr(TF::kTpuReplicateAttr); + if (opattr != attr) { + op->emitOpError("TF2XLA TPU bridge input check: " + errormsg() + + "invalid _tpu_replicate attr.") + << " Expected attr: " << attr << ", Actual attr: " << opattr; return false; } return true; } -bool ValidateReplicatedOutput(TF::TPUReplicatedOutputOp rep, int num_replicas) { + +bool ValidateReplicatedInput(TF::TPUReplicatedInputOp rep, int num_replicas, + StringAttr attr) { + int arity = rep.getInputs().size(); + if (rep.getIsPacked() && arity != 1) { + rep.emitOpError( + "TF2XLA TPU bridge input check: packed with number of inputs not 1.") + << " num_replicas=" << num_replicas << " no. of inputs=" << arity; + return false; + } else if (!rep.getIsPacked() && arity != num_replicas) { + rep.emitOpError( + "TF2XLA TPU bridge input check: number of inputs inconsistent.") + << " num_replicas=" << num_replicas << " no. of inputs=" << arity; + return false; + } + for (auto& succ : GetSuccessors(rep)) { + if (!IsTpuRegularOp(succ)) continue; + auto errormsg = [&]() -> std::string { + return rep->getName().getStringRef().str() + " op has successor op " + + succ->getName().getStringRef().str() + " with error: "; + }; + if (!CheckTpuReplicateAttr(succ, attr, errormsg)) return false; + } + return true; +} +bool ValidateReplicatedOutput(TF::TPUReplicatedOutputOp rep, int num_replicas, + StringAttr attr) { int arity = rep.getOutputs().size(); if (arity != num_replicas) { rep.emitOpError( - "TF/XLA TPU bridge input check: number of outputs inconsistent.") + "TF2XLA TPU bridge input check: number of outputs inconsistent.") << " num_replicas=" << num_replicas << " no. of outputs=" << arity; return false; } + for (auto& pred : GetPredecessors(rep)) { + if (!IsTpuRegularOp(pred)) continue; + auto errormsg = [&]() -> std::string { + return rep->getName().getStringRef().str() + " op has predecessor op " + + pred->getName().getStringRef().str() + " with error: "; + }; + if (!CheckTpuReplicateAttr(pred, attr, errormsg)) return false; + } return true; } bool ValidatePartitionedInput(TF::TPUPartitionedInputOp rep, @@ -61,7 +190,7 @@ bool ValidatePartitionedInput(TF::TPUPartitionedInputOp rep, int arity = rep.getInputs().size(); if (arity != num_cores_per_replica) { rep.emitOpError( - "TF/XLA TPU bridge input check: number of inputs inconsistent.") + "TF2XLA TPU bridge input check: number of inputs inconsistent.") << " num_cores_per_replica=" << num_cores_per_replica << " no. 
of inputs=" << arity; return false; @@ -73,13 +202,13 @@ bool ValidatePartitionedInputV2(TF::TPUPartitionedInputV2Op rep, int arity = rep.getInputs().size(); if (rep.getIsPacked() && arity != 1) { rep.emitOpError( - "TF/XLA TPU bridge input check: packed with number of inputs not 1.") + "TF2XLA TPU bridge input check: packed with number of inputs not 1.") << " num_cores_per_replicas=" << num_cores_per_replica << " no. of inputs=" << arity; return false; } else if (!rep.getIsPacked() && arity != num_cores_per_replica) { rep.emitOpError( - "TF/XLA TPU bridge input check: number of inputs inconsistent.") + "TF2XLA TPU bridge input check: number of inputs inconsistent.") << " num_cores_per_replica=" << num_cores_per_replica << " no. of inputs=" << arity; return false; @@ -91,49 +220,229 @@ bool ValidatePartitionedOutput(T rep, int num_cores_per_replica) { int arity = rep.getOutput().size(); if (arity != num_cores_per_replica) { rep.emitOpError( - "TF/XLA TPU bridge input check: number of outputs inconsistent.") + "TF2XLA TPU bridge input check: number of outputs inconsistent.") << " num_cores_per_replica=" << num_cores_per_replica << " no. of outputs=" << arity; return false; } return true; } + +bool CheckReplicatedIOOp(Operation* op, TF::TPUReplicateMetadataOp metadata, + Operation* parent) { + int num_replicas = metadata.getNumReplicas(); + int num_cores_per_replica = metadata.getNumCoresPerReplica(); + StringAttr tpu_replicate_attr = + metadata->getAttrOfType(TF::kTpuReplicateAttr); + if (auto repinput = dyn_cast(op)) { + if (!ValidateReplicatedInput(repinput, num_replicas, tpu_replicate_attr)) + return false; + } + if (auto repoutput = dyn_cast(op)) { + if (!ValidateReplicatedOutput(repoutput, num_replicas, tpu_replicate_attr)) + return false; + } + if (auto partinput = dyn_cast(op)) { + if (!ValidatePartitionedInput(partinput, num_cores_per_replica)) + return false; + } + if (auto partinput = dyn_cast(op)) { + if (!ValidatePartitionedInputV2(partinput, num_cores_per_replica)) + return false; + } + if (auto partoutput = dyn_cast(op)) { + if (!ValidatePartitionedOutput(partoutput, num_cores_per_replica)) + return false; + } + if (auto partoutput = dyn_cast(op)) { + if (!ValidatePartitionedOutput(partoutput, num_cores_per_replica)) + return false; + } + return true; +} +// Checking op which is successor to a cluster op. +bool CheckClusterSuccessors(Operation* op, std::string cluster, + Operation* parent, MetadataMap& metadata_map) { + std::string cluster_succ = ""; + if (op->hasAttr(TF::kTpuReplicateAttr)) { + cluster_succ = op->getAttrOfType(TF::kTpuReplicateAttr).str(); + } + if (cluster_succ.empty()) { + // TODO (b/269195256#comment16): Change to error after resolving issue + // with test. Will fix it after the upstream code is fixed. + op->emitWarning("TF2XLA TPU bridge input check: cluster op = ") + << parent->getName() << " with cluster = " << cluster + << " has successor as non cluster op " << op->getName(); + return true; + } + if (cluster != cluster_succ) { + op->emitOpError( + "TF2XLA TPU bridge input check: mismatch clusters tpu_replicate " + "attr. Parent op ") + << parent->getName() << " with cluster = " << cluster + << " has successor cluster op " << op->getName() + << " with cluster = " << cluster_succ; + return false; + } + return true; +} + +// Checking op which is a predecessor to a non-cluster op. 
+bool CheckNonClusterSuccessors(Operation* op, Operation* parent, + MetadataMap& metadata_map) { + if (!IsTpuRegularOp(op)) { + if (isa(op)) { + op->emitOpError("TF2XLA TPU bridge input check: non-cluster op = ") + << parent->getName() + << " has invalid successor op = " << op->getName(); + return false; + } else { + return true; + } + } + return true; +} +// Checking op which is a successor to a non-cluster op. +bool CheckNonClusterPredecessors(Operation* op, Operation* parent, + MetadataMap& metadata_map) { + if (!IsTpuRegularOp(op)) { + if (isa(op)) { + op->emitOpError("TF2XLA TPU bridge input check: non-cluster op = ") + << parent->getName() + << " has invalid predecessor op = " << op->getName(); + return false; + } else { + return true; + } + } + return true; +} + +bool CheckOpsClusterIO(Operation* op, MetadataMap& metadata_map) { + bool is_cluster_op = false; + std::string cluster = ""; + if (op->hasAttr(TF::kTpuReplicateAttr)) { + cluster = op->getAttrOfType(TF::kTpuReplicateAttr).str(); + if (cluster.empty()) { + op->emitOpError("TF2XLA TPU bridge input check: empty _tpu_replicate") + << " attr for op = " << op->getName(); + return false; + } + is_cluster_op = true; + } + bool has_cluster_metadata = + (metadata_map.find(cluster) != metadata_map.end()); + + for (auto pred : GetPredecessors(op)) { + if (is_cluster_op && !IsTpuRegularOp(pred) && has_cluster_metadata) { + if (!CheckReplicatedIOOp(pred, metadata_map[cluster], op)) return false; + } + if (!is_cluster_op) { + if (!CheckNonClusterPredecessors(pred, op, metadata_map)) return false; + } + } + + for (auto succ : GetSuccessors(op)) { + if (is_cluster_op && !IsTpuRegularOp(succ) && has_cluster_metadata) { + if (!CheckReplicatedIOOp(succ, metadata_map[cluster], op)) return false; + } + if (is_cluster_op && IsTpuRegularOp(succ)) { + if (!CheckClusterSuccessors(succ, cluster, op, metadata_map)) + return false; + } + if (!is_cluster_op) { + if (!CheckNonClusterSuccessors(succ, op, metadata_map)) return false; + } + } + return true; +} + +bool TypeMustBeNonXLA(const Type& type) { + const Type elem = getElementTypeOrSelf(type); + return !elem.isa() && !tensorflow::TypeValidForXLA(type); +} + +// Check if the op cannot be XLA compiled. If the op does not satisfy this +// criteria, then it is possible for the op to be XLA and non-XLA. But this +// function specifically checks if the op must be non-xla. +bool IsMustNotBeXlaOp(Operation* op) { + for (auto& input : op->getOpOperands()) { + if (TypeMustBeNonXLA(input.get().getType())) return true; + } + for (auto output_types : op->getResultTypes()) { + if (TypeMustBeNonXLA(output_types)) return true; + } + return false; +} + +// Check if the op must be compiled with XLA. If the op does not satisfy this +// critiria for "must be xla" then it is still possible for this op to be xla +// and non-xla as well. But below function specifically checks for the op to be +// only XLA op. +bool IsMustBeXlaOp(Operation* op, MetadataMap metadata_map) { + // All PartitionedCall are inlined-out before XLA. 
+ // So MustBeXLA should return false + if (IsPartitionedOp(op)) return false; + if (!op->hasAttr(TF::kTpuReplicateAttr)) return false; + auto cluster = op->getAttrOfType(TF::kTpuReplicateAttr).str(); + if (metadata_map.find(cluster) == metadata_map.end()) return false; + auto metadata = metadata_map[cluster]; + if (!metadata.getAllowSoftPlacement() && + !op->hasAttr(TF::kXlaOutsideCompilationAttr)) + return true; + std::string device = ""; + if (op->hasAttr(TF::kDeviceAttr)) + device = op->getAttrOfType(TF::kDeviceAttr).str(); + else + return false; + if (absl::StrContains(device, TF::kTpuDevice)) return true; + return false; +} +bool ValidateIntersectionXlaNonXlaOps(Operation* op, MetadataMap metadata_map) { + if (isa(op) || + isa(op) || isa(op) || + isa(op) || + isa(op) || + isa(op) || + isa(op)) + return true; + if (IsMustBeXlaOp(op, metadata_map) && IsMustNotBeXlaOp(op)) { + // TODO(b/269195256#comment19) change the warning for Identity op to error + // when issue with input graph is resolved. Possible issue with python layer + // inserting Identity op incorrectly. + if (isa(op)) { + op->emitWarning("TF/XLA TPU bridge input check: found invalid op. ") + << op->getName() << " can't be both xla and non-xla"; + return true; + } + op->emitOpError("TF/XLA TPU bridge input check: found invalid op. ") + << "Can't be both xla and non-xla"; + return false; + } + return true; +} + void TPUValidateInputsPass::runOnOperation() { ModuleOp module = getOperation(); bool success = true; int num_metadata = 0; TF::TPUReplicateMetadataOp metadata; + MetadataMap metadata_map; module.walk([&](TF::TPUReplicateMetadataOp meta) { ++num_metadata; metadata = meta; + metadata_map[meta->getAttrOfType(TF::kTpuReplicateAttr).str()] = + meta; + }); + + getOperation().walk([&](mlir::Operation* op) { + if (IsTpuRegularOp(op)) { + success &= CheckOpsClusterIO(op, metadata_map); + } + if (IsIntersectionXlaNonXlaOps(op)) { + success &= ValidateIntersectionXlaNonXlaOps(op, metadata_map); + } }); - // TODO(b/269195256): support multi-TPUReplicateMetadata case. - // Currently handling case with one metadata op / cluster. Further CLs will - // address cases with multi-TPUReplicatedMetadata. 
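The intersection check above flags an op only when both predicates hold at once: its cluster metadata forces XLA compilation while one of its operand or result types can only live outside XLA. A compact sketch of that decision (the struct and field names are hypothetical):

// Illustrative sketch, not part of the patch.
struct OpClassification {
  bool must_be_xla;      // clustered without soft placement, or on a TPU device
  bool must_not_be_xla;  // carries a type that is not valid for XLA
};

bool IsInvalidXlaIntersection(const OpClassification& op) {
  return op.must_be_xla && op.must_not_be_xla;
}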
- if (num_metadata == 1) { - int num_replicas = metadata.getNumReplicas(); - int num_cores_per_replica = metadata.getNumCoresPerReplica(); - module.walk([&](mlir::Operation* op) { - if (auto repinput = dyn_cast(op)) { - success &= ValidateReplicatedInput(repinput, num_replicas); - } - if (auto repoutput = dyn_cast(op)) { - success &= ValidateReplicatedOutput(repoutput, num_replicas); - } - if (auto partinput = dyn_cast(op)) { - success &= ValidatePartitionedInput(partinput, num_cores_per_replica); - } - if (auto partinput = dyn_cast(op)) { - success &= ValidatePartitionedInputV2(partinput, num_cores_per_replica); - } - if (auto partoutput = dyn_cast(op)) { - success &= ValidatePartitionedOutput(partoutput, num_cores_per_replica); - } - if (auto partoutput = dyn_cast(op)) { - success &= ValidatePartitionedOutput(partoutput, num_cores_per_replica); - } - }); - } if (!success) { signalPassFailure(); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/xla_cluster_formation.cc b/tensorflow/compiler/mlir/tensorflow/transforms/xla_cluster_formation.cc index df3aba4eeb0..f876231ab00 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/xla_cluster_formation.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/xla_cluster_formation.cc @@ -57,6 +57,9 @@ void EncapsulatePartitionedCall(Operation *call_op) { builder.setInsertionPointToEnd(&cluster.GetBody()); builder.create(call_op->getLoc(), call_op->getResults()); + // Propagate necessary attributes to the cluster so that when it's outlined, + // the function will have correct attributes. + TF::CopyDeviceAndUnderscoredAttributes(call_op, cluster); } void XlaClusterFormationPass::runOnOperation() { diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/xla_rewrite.cc b/tensorflow/compiler/mlir/tensorflow/transforms/xla_rewrite.cc index 550e5804430..c59d6e532d0 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/xla_rewrite.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/xla_rewrite.cc @@ -60,14 +60,11 @@ void MoveResourceArgsToEnd(func::FuncOp callee) { callee.getResultTypes())); } -template ::value>::type * = nullptr> -void RewriteCall(OpT call_op, SymbolTable &symtab) { +void RewriteCall(tf_device::ClusterFuncOp cluster_func_op, SymbolTable &symtab, + OpBuilder &builder) { llvm::SmallVector non_resource_args, resource_args; bool has_resources = false, in_order = true; - for (const Value &arg : call_op.getArgs()) { + for (const Value &arg : cluster_func_op.getOperands()) { if (!getElementTypeOrSelf(arg.getType()).template isa()) { non_resource_args.push_back(arg); if (has_resources) in_order = false; @@ -80,33 +77,26 @@ void RewriteCall(OpT call_op, SymbolTable &symtab) { if (!in_order) { // Functions do not get reused in practice, so skip the check for if the // callee has been updated. 
- StringAttr callee_sym = - cast(call_op.getFAttr()).getRootReference(); + StringAttr callee_sym = cluster_func_op.getFuncAttr().getAttr(); MoveResourceArgsToEnd(symtab.lookup(callee_sym)); } - OpBuilder builder(call_op->getContext()); - builder.setInsertionPoint(call_op); + builder.setInsertionPoint(cluster_func_op); auto xla_launch_op = builder.create( - call_op.getLoc(), call_op.getResultTypes(), + cluster_func_op.getLoc(), cluster_func_op.getResultTypes(), /*constants=*/ValueRange({}), ValueRange(non_resource_args), - ValueRange(resource_args), call_op.getFAttr()); + ValueRange(resource_args), cluster_func_op.getFuncAttr()); - CopyDeviceAndUnderscoredAttributes(call_op, xla_launch_op); - call_op.replaceAllUsesWith(xla_launch_op.getResults()); - call_op.erase(); + CopyDeviceAndUnderscoredAttributes(cluster_func_op, xla_launch_op); + cluster_func_op.replaceAllUsesWith(xla_launch_op.getResults()); + cluster_func_op.erase(); } void XlaRewritePass::runOnOperation() { ModuleOp module = getOperation(); SymbolTable symtab(module); - module.walk([&](tf_device::ClusterOp cluster_op) { - cluster_op.getBody().walk([&](mlir::Operation *op) { - if (auto call_op = llvm::dyn_cast(op)) { - RewriteCall(call_op, symtab); - } else if (auto call_op = llvm::dyn_cast(op)) { - RewriteCall(call_op, symtab); - } - }); + OpBuilder builder(&getContext()); + module.walk([&](tf_device::ClusterFuncOp cluster_func_op) { + RewriteCall(cluster_func_op, symtab, builder); }); // Verify that there are no nested XLA launch ops. diff --git a/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc b/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc index 4aff79c8585..61490f6a749 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc @@ -63,6 +63,7 @@ limitations under the License. #include "tensorflow/core/framework/versions.pb.h" #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/graph/regularization/util.h" #include "tensorflow/core/graph/tensor_id.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" @@ -135,7 +136,9 @@ class Exporter { private: explicit Exporter(Graph* graph, const Dialect* tf_dialect) - : graph_(graph), tf_dialect_(tf_dialect) {} + : graph_(graph), tf_dialect_(tf_dialect) { + graph_->ToGraphDef(&graphdef_); + } Status AddArgumentNode(BlockArgument arg, unsigned index, llvm::StringRef name); @@ -158,6 +161,7 @@ class Exporter { Status AddEdgeBetweenNodes(Value src, Node* dst_node, unsigned dst_index); Graph* graph_; + GraphDef graphdef_; LegalizedOpOrValLocNameMapper op_to_name_; absl::flat_hash_map nodes_; llvm::DenseMap args_; @@ -358,7 +362,8 @@ Status Exporter::AddEdge(Operation* inst) { Status Exporter::AddInstructionNode(Operation* inst) { std::unique_ptr node_def; - auto name = op_to_name_.GetUniqueName(inst); + int graph_hash_value = graph_regularization::ComputeHash(graphdef_); + auto name = op_to_name_.GetUniqueName(inst, graph_hash_value); // Convert registered TF ops to NodeDef. Only registered ops are handled to // ensure that PopulateDerivedAttrs adds the correct attributes. 
TF_ASSIGN_OR_RETURN(node_def, diff --git a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc index ce9d086f22a..b8ba989b33b 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc @@ -23,6 +23,7 @@ limitations under the License. #include #include #include +#include #include #include #include @@ -99,6 +100,7 @@ limitations under the License. #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/graph_debug_info.pb.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/op.h" @@ -124,7 +126,6 @@ limitations under the License. #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/threadpool.h" #include "tensorflow/core/platform/types.h" -#include "tensorflow/core/protobuf/graph_debug_info.pb.h" #include "tensorflow/core/protobuf/meta_graph.pb.h" #include "tensorflow/core/protobuf/saved_object_graph.pb.h" #include "tensorflow/core/protobuf/saver.pb.h" @@ -1089,7 +1090,7 @@ StatusOr ImporterBase::InferOutputType(const Node& node, int idx, return errors::InvalidArgument( "Node '", node.name(), " has an invalid ", kOutputShapesAttrName, " attribute (shape #", idx, " error:'", - s.error_message(), "')"); + s.message(), "')"); c->set_output(idx, h); } } @@ -1680,7 +1681,7 @@ Status ImporterBase::ConvertFunctionArgAndRets( } llvm::SmallVector inst_to_return; - for (auto ret_and_idx : llvm::enumerate(ret_nodes)) { + for (const auto& ret_and_idx : llvm::enumerate(ret_nodes)) { const auto& ret = ret_and_idx.value(); auto* inst = node_values_[ret.node->id()]; if (ret.node->IsRetval()) { @@ -1772,6 +1773,7 @@ mlir::Location ImporterBase::GetLocation(const Node& node) { mlir::FileLineColLoc::get(file_name, frame.line_number, 1); locations.push_back(file_line_loc); } + stack_trace->WipeCache(); } else { DVLOG(1) << "No stack trace for " << node.name(); const auto location_it = debug_info.find(debug_info_key); @@ -2353,7 +2355,8 @@ class GraphDefImporter : public ImporterBase { mlir::MLIRContext* context, const Graph& graph, const GraphDebugInfo& debug_info, const FunctionLibraryDefinition& flib_def, const GraphImportConfig& specs, - std::unordered_map& tf_name_to_mlir_name); + std::unordered_map& tf_name_to_mlir_name, + bool disable_crash_analysis = false); private: explicit GraphDefImporter( @@ -2395,7 +2398,8 @@ StatusOr> GraphDefImporter::Convert( mlir::MLIRContext* context, const Graph& graph, const GraphDebugInfo& debug_info, const FunctionLibraryDefinition& flib_def, const GraphImportConfig& specs, - std::unordered_map& tf_name_to_mlir_name) { + std::unordered_map& tf_name_to_mlir_name, + bool disable_crash_analysis) { LoadImporterDialects(*context); mlir::OwningOpRef module = mlir::ModuleOp::create(mlir::UnknownLoc::get(context)); @@ -2405,22 +2409,28 @@ StatusOr> GraphDefImporter::Convert( // via conversion to the graph def first. Convert graph to graph_def here // first and avoid extra copies later. 
auto graph_def = std::make_unique(); - graph.ToGraphDef(graph_def.get()); + graph.ToGraphDef(graph_def.get(), /*include_flib_def=*/false); - static std::atomic counter(0); - uint32 current_file_prefix = counter++; - const auto* graph_crash_handle = crash_analysis::ReportProtoDataOnCrash( - absl::StrCat(current_file_prefix, "_mlir_import_graph.pbtxt"), - *graph_def); - auto reachable_flib = flib_def.ReachableDefinitions(*graph_def); - const auto* flib_crash_handle = crash_analysis::ReportProtoDataOnCrash( - absl::StrCat(current_file_prefix, "_mlir_import_flib.pbtxt"), - reachable_flib.ToProto()); + auto scope_exit = [&]() { + std::function cleanup = []() {}; + if (!disable_crash_analysis) { + static std::atomic counter(0); + uint32 current_file_prefix = counter++; + const auto* graph_crash_handle = crash_analysis::ReportProtoDataOnCrash( + absl::StrCat(current_file_prefix, "_mlir_import_graph.pbtxt"), + *graph_def); + auto reachable_flib = flib_def.ReachableDefinitions(*graph_def); + const auto* flib_crash_handle = crash_analysis::ReportProtoDataOnCrash( + absl::StrCat(current_file_prefix, "_mlir_import_flib.pbtxt"), + reachable_flib.ToProto()); + cleanup = [=]() { + crash_analysis::RemoveReportData(graph_crash_handle); + crash_analysis::RemoveReportData(flib_crash_handle); + }; + } - auto scope_exit = llvm::make_scope_exit([&]() { - crash_analysis::RemoveReportData(graph_crash_handle); - crash_analysis::RemoveReportData(flib_crash_handle); - }); + return llvm::make_scope_exit(std::move(cleanup)); + }(); VLOG(2) << "Importing: " << ::tensorflow::DumpGraphToFile("tf_mlir_importer_base", graph, @@ -2471,6 +2481,11 @@ StatusOr> GraphDefImporter::Convert( attrs.push_back(b.getNamedAttr( "tf.entry_function", b.getDictionaryAttr({inputs, outputs, control_outputs}))); + if (!specs.xla_compile_device_type.empty()) { + attrs.push_back( + b.getNamedAttr("_xla_compile_device_type", + b.getStringAttr(specs.xla_compile_device_type))); + } } else { // Collects the argument and return nodes by looking up the node names // specified by the user. @@ -2539,7 +2554,7 @@ StatusOr GraphDefImporter::InferMainFunctionType( // Feeds have been remapped to single output nodes (Placeholder), so an exact // name match is sufficient. 
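The crash-analysis hunk above guards the ReportProtoDataOnCrash registration behind `disable_crash_analysis` by building the cleanup inside an immediately-invoked lambda and returning a single `llvm::make_scope_exit` guard from it. A standalone sketch of just that control-flow pattern, with hypothetical Register/Unregister helpers standing in for the crash_analysis calls:

#include <atomic>
#include <cstdint>
#include <functional>
#include <iostream>
#include <string>

#include "llvm/ADT/ScopeExit.h"

// Hypothetical stand-ins so the sketch is self-contained; the real code uses
// crash_analysis::ReportProtoDataOnCrash / RemoveReportData.
int RegisterCrashReport(const std::string& name) {
  std::cout << "registered " << name << "\n";
  return 42;  // pretend handle
}
void UnregisterCrashReport(int handle) {
  std::cout << "unregistered handle " << handle << "\n";
}

void Convert(bool disable_crash_analysis) {
  // Immediately-invoked lambda: decide on the cleanup up front, then hand it
  // to a single scope guard. With crash analysis disabled the guard holds a
  // no-op; either way the cleanup runs on every exit path of Convert.
  auto scope_exit = [&]() {
    std::function<void()> cleanup = []() {};
    if (!disable_crash_analysis) {
      static std::atomic<uint32_t> counter(0);
      const uint32_t prefix = counter++;
      const int handle = RegisterCrashReport(
          std::to_string(prefix) + "_mlir_import_graph.pbtxt");
      cleanup = [=]() { UnregisterCrashReport(handle); };
    }
    return llvm::make_scope_exit(std::move(cleanup));
  }();
  // ... conversion work would happen here ...
}

int main() {
  Convert(/*disable_crash_analysis=*/false);
  Convert(/*disable_crash_analysis=*/true);
}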
absl::flat_hash_map inputs; - for (auto input_and_idx : llvm::enumerate(specs.inputs)) { + for (const auto& input_and_idx : llvm::enumerate(specs.inputs)) { TensorId tensor = ParseTensorName(input_and_idx.value().first); auto remapped_it = remapped_feeds_.find(tensor); if (remapped_it != remapped_feeds_.end()) { @@ -2700,7 +2715,7 @@ GraphDefImporter::GetArgsRetsAndTypesFromFunctionGraph( mlir::Builder builder(context); llvm::SmallVector arg_types; arg_types.reserve(arg_nodes->size()); - for (auto arg_node_and_idx : llvm::enumerate(*arg_nodes)) { + for (const auto& arg_node_and_idx : llvm::enumerate(*arg_nodes)) { auto& arg_node = arg_node_and_idx.value(); if (arg_node.node == nullptr) return errors::InvalidArgument("Graph missing _Arg at index ", @@ -2713,7 +2728,7 @@ GraphDefImporter::GetArgsRetsAndTypesFromFunctionGraph( llvm::SmallVector ret_types; ret_types.reserve(ret_nodes->size()); - for (auto ret_node_and_idx : llvm::enumerate(*ret_nodes)) { + for (const auto& ret_node_and_idx : llvm::enumerate(*ret_nodes)) { auto& ret_node = ret_node_and_idx.value(); if (ret_node.node == nullptr) return errors::InvalidArgument("Graph missing _Retval at index ", @@ -2733,7 +2748,7 @@ Status GraphDefImporter::GetControlRetsFromGraph( if (control_outputs.empty()) return OkStatus(); llvm::SmallDenseMap controls_to_idx; - for (auto control_and_idx : llvm::enumerate(control_outputs)) + for (const auto& control_and_idx : llvm::enumerate(control_outputs)) controls_to_idx.insert({control_and_idx.value(), control_and_idx.index()}); if (controls_to_idx.size() != control_outputs.size()) @@ -3411,12 +3426,12 @@ Status CreateSavedModelIR( function.concrete_functions(0), "' (", input_index_paths.size(), " vs ", bound_input_base, ")"); } - for (auto index_path : llvm::enumerate(input_index_paths)) { + for (const auto& index_path : llvm::enumerate(input_index_paths)) { func.setArgAttr(index_path.index(), kTfSavedModelIndexPathAttr, index_path.value()); } - for (auto& bound_input : + for (const auto& bound_input : llvm::enumerate(concrete_function.bound_inputs())) { int arg_index = bound_input_base + bound_input.index(); auto symbol_ref = mlir::SymbolRefAttr::get( @@ -3438,7 +3453,7 @@ Status CreateSavedModelIR( function.concrete_functions(0), "' (", output_index_paths.size(), " vs ", func.getNumResults(), ")"); } - for (auto index_path : llvm::enumerate(output_index_paths)) { + for (const auto& index_path : llvm::enumerate(output_index_paths)) { func.setResultAttr(index_path.index(), kTfSavedModelIndexPathAttr, index_path.value()); } @@ -3560,8 +3575,8 @@ SavedModelObjectGraphImporter::Convert(SavedModelV2Bundle* saved_model, TF_RETURN_IF_ERROR(PreprocessGraphDef(nullptr, &preprocessed_graphdef)); } - TF_RETURN_IF_ERROR( - ConvertGraphDefToGraph(options, preprocessed_graphdef, &graph)); + TF_RETURN_IF_ERROR(ConvertGraphDefToGraph( + options, std::move(preprocessed_graphdef), &graph)); NameUniquifier function_name_uniquifier(graph.flib_def()); SavedModelObjectGraphImporter importer(graph.flib_def(), debug_info, specs, @@ -3615,7 +3630,7 @@ class SimpleSavedModelMLIRImportInput : public SavedModelMLIRImportInput { const MLIRImportOptions& import_options, const MetaGraphDef* meta_graph_def, const GraphDebugInfo& debug_info) { DCHECK(meta_graph_def); - GraphDef graph_def = meta_graph_def->graph_def(); + GraphDef graph_def(meta_graph_def->graph_def()); auto graph = std::make_unique(OpRegistry::Global()); if (import_options.upgrade_legacy) { @@ -3626,8 +3641,8 @@ class SimpleSavedModelMLIRImportInput : public 
SavedModelMLIRImportInput { GraphConstructorOptions graph_ctor_options; graph_ctor_options.allow_internal_ops = true; graph_ctor_options.add_default_attributes = true; - TF_RETURN_IF_ERROR( - ConvertGraphDefToGraph(graph_ctor_options, graph_def, graph.get())); + TF_RETURN_IF_ERROR(ConvertGraphDefToGraph( + graph_ctor_options, std::move(graph_def), graph.get())); if (import_options.upgrade_legacy) { // TODO(jpienaar): Remove need to const_cast. @@ -3941,7 +3956,8 @@ SavedModelSignatureDefImporterLite::ConvertGraph( // Convert sub-graph to MLIR module. return GraphDefImporter::Convert(module_->getContext(), *subgraph, input_.debug_info(), subgraph->flib_def(), - specs, tf_name_to_mlir_name); + specs, tf_name_to_mlir_name, + /*disable_crash_analysis=*/true); } Status SavedModelSignatureDefImporterLite::ConvertSignature( @@ -3983,11 +3999,11 @@ Status SavedModelSignatureDefImporterLite::ConvertSignature( builder.getStrArrayAttr({sig_def_key})); // Transfer input and output parameter names to index_path attributes. - for (auto input_and_idx : llvm::enumerate(inputs)) { + for (const auto& input_and_idx : llvm::enumerate(inputs)) { func_op.setArgAttr(input_and_idx.index(), kTfSavedModelIndexPathAttr, builder.getStrArrayAttr({input_and_idx.value().first})); } - for (auto output_and_idx : llvm::enumerate(outputs)) { + for (const auto& output_and_idx : llvm::enumerate(outputs)) { func_op.setResultAttr( output_and_idx.index(), kTfSavedModelIndexPathAttr, builder.getStrArrayAttr({output_and_idx.value().first})); @@ -4170,7 +4186,9 @@ class SavedModelSignatureDefImporter { mlir::OpBuilder builder(module->getContext()); (*module)->setAttr("tf_saved_model.under_construction", builder.getUnitAttr()); - TF_RETURN_IF_ERROR(LiftVariables(bundle, *module, options.lift_variables)); + TF_RETURN_IF_ERROR( + LiftVariables(bundle, *module, options.lift_variables, + options.include_variables_in_initializers)); (*module)->removeAttr("tf_saved_model.under_construction"); return module; @@ -4178,14 +4196,21 @@ class SavedModelSignatureDefImporter { private: // Lifts the variables in `module`. + // If `include_variables_in_initializers` is set to false, then it removes all + // variables from the initializer functions (registered in the + // `tf_saved_model::SessionInitializerOp`) by running the + // `RemoveVariablesInSessionInitializerPass`, regardless of whether + // `lift_variable_ops_to_args` is true or not. 
static Status LiftVariables(const SavedModelBundle& bundle, mlir::ModuleOp module, - bool lift_varhandle_ops_to_args); + bool lift_varhandle_ops_to_args, + bool include_variables_in_initializers); }; Status SavedModelSignatureDefImporter::LiftVariables( const SavedModelBundle& bundle, mlir::ModuleOp module, - bool lift_varhandle_ops_to_args) { + const bool lift_varhandle_ops_to_args, + const bool include_variables_in_initializers) { mlir::StatusScopedDiagnosticHandler diag_handler(module.getContext()); mlir::PassManager pm(module.getContext()); @@ -4194,8 +4219,10 @@ Status SavedModelSignatureDefImporter::LiftVariables( mlir::tf_executor::CreateTFExecutorGraphPruningPass()); pm.addNestedPass( mlir::CreateExecutorDialectToFunctionalConversionPass()); - pm.addPass( - mlir::tf_saved_model::CreateRemoveVariablesInSessionInitializerPass()); + if (!include_variables_in_initializers) { + pm.addPass( + mlir::tf_saved_model::CreateRemoveVariablesInSessionInitializerPass()); + } pm.addNestedPass( mlir::TF:: CreateConvertReadonlyReferenceVariablesToResourceVariablesPass()); diff --git a/tensorflow/compiler/mlir/tensorflow/translate/import_model.h b/tensorflow/compiler/mlir/tensorflow/translate/import_model.h index ac10baa94c3..182a53078ba 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/import_model.h +++ b/tensorflow/compiler/mlir/tensorflow/translate/import_model.h @@ -30,8 +30,8 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/graph_debug_info.pb.h" #include "tensorflow/core/graph/graph.h" -#include "tensorflow/core/protobuf/graph_debug_info.pb.h" namespace tensorflow { diff --git a/tensorflow/compiler/mlir/tensorflow/translate/mlir_import_options.h b/tensorflow/compiler/mlir/tensorflow/translate/mlir_import_options.h index e1b45dda1c5..44262d0bd08 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/mlir_import_options.h +++ b/tensorflow/compiler/mlir/tensorflow/translate/mlir_import_options.h @@ -32,9 +32,20 @@ struct MLIRImportOptions { // Apply default attributes from the op definition to the loaded op. bool add_default_attributes = true; - // If set, promote tf.VarHandleOp to resource arguments for all functions. + // If set, promote tf.VarHandleOp to resource arguments for all functions. bool lift_variables = true; + // Keeps the variables in initializers before lifting variables (when + // `lift_variables == true`) or newly adding variable initialization patterns + // in the initializer functions. One might want to set this to `true` because + // the `RemoveVariablesInSessionInitializerPass` pass, which runs otherwise, + // may unexpectedly also remove the initialization patterns for non-variable + // resources (like hash tables) if they involve variables. Such a case is + // illustrated in the test file + // "../tests/tf_saved_model_remove_vars_in_session_initializer.mlir". + // This defaults to `false` to avoid breaking existing uses. + bool include_variables_in_initializers = false; + // Load the model without restoring associated variables from disk. Enables // loading raw programs without checkpoints. 
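A caller-side sketch of `include_variables_in_initializers`: opting in keeps variable-initialization patterns in the session initializer so that hash-table style resources initialized from variables are not stripped during import. The helper below is hypothetical; the other fields shown already exist on MLIRImportOptions.

#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_import_options.h"

// Hypothetical helper: builds the options that would be handed to one of the
// SavedModel import entry points declared in import_model.h.
tensorflow::MLIRImportOptions MakeImportOptions() {
  tensorflow::MLIRImportOptions options;
  options.upgrade_legacy = true;
  options.lift_variables = true;
  // Opt in: keep variable initialization in the initializer functions so the
  // RemoveVariablesInSessionInitializerPass is not run during import.
  options.include_variables_in_initializers = true;
  return options;
}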
bool allow_uninitialized_variables = false; diff --git a/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.cc b/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.cc index 12db69d867b..3c703722b82 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.cc @@ -62,6 +62,7 @@ std::string GraphImportConfig::str() const { ss << "\nenable_shape_inference: " << enable_shape_inference; ss << "\nunconditionally_use_set_output_shapes: " << unconditionally_use_set_output_shapes; + ss << "\nxla_compile_device_type: " << xla_compile_device_type; return ss.str(); } @@ -245,7 +246,7 @@ static StatusOr> ParseDTypesHelper( bool inside_subtype = false; int cur_pos = 0; std::vector dtypes; - for (auto& it : llvm::enumerate(data_types_str)) { + for (const auto& it : llvm::enumerate(data_types_str)) { char c = it.value(); int i = it.index(); // Skip parsing the subtypes of a type diff --git a/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h b/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h index 5eb7b25a126..79d364bf6b2 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h +++ b/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h @@ -90,6 +90,9 @@ struct GraphImportConfig { // so make it opt-in to consider it unconditionally also when importing the // graph. bool unconditionally_use_set_output_shapes = false; + // If set, use the value as the device type and mark the function graph for + // XLA compilation. + string xla_compile_device_type; }; struct GraphExportConfig { diff --git a/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_pass.cc b/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_pass.cc index 65a6dbaa1c5..84ae5a522e2 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_pass.cc @@ -15,6 +15,8 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_pass.h" +#include + #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project @@ -25,7 +27,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/common_runtime/graph_constructor.h" -#include "tensorflow/core/protobuf/graph_debug_info.pb.h" +#include "tensorflow/core/framework/graph_debug_info.pb.h" namespace tensorflow { diff --git a/tensorflow/compiler/mlir/tensorflow/translate/split_into_island_per_op_pass.cc b/tensorflow/compiler/mlir/tensorflow/translate/split_into_island_per_op_pass.cc index a1dced4bf5e..c5e059e3a67 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/split_into_island_per_op_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/split_into_island_per_op_pass.cc @@ -176,7 +176,7 @@ void SplitIsland(mlir::tf_executor::IslandOp island_op, // `island_op.getControl().dropAllUses();` of a control dep that's only used // in a graph's fetch, immediately leads to a segfault. Turns out we need to // drop its uses manually so that we don't leave dangling controls. 
- for (auto& fetch : llvm::enumerate(graph_op.GetFetch().getFetches())) { + for (const auto& fetch : llvm::enumerate(graph_op.GetFetch().getFetches())) { if (fetch.value() == island_op.getControl()) { graph_op.GetFetch().getFetchesMutable().erase(fetch.index(), 1); break; diff --git a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.cc b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.cc index 66f511d1a93..233d35d8c01 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.cc @@ -34,18 +34,19 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/utils/import_utils.h" #include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h" #include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/graph_debug_info.pb.h" #include "tensorflow/core/framework/versions.pb.h" #include "tensorflow/core/graph/tensor_id.h" #include "tensorflow/core/grappler/utils/transitive_fanin.h" #include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/protobuf.h" -#include "tensorflow/core/protobuf/graph_debug_info.pb.h" #include "tensorflow/core/util/tensor_bundle/byte_swap_tensor.h" namespace tensorflow { static StatusOr> GraphdefToMlirImport( llvm::StringRef input, absl::string_view debug_info_file, + absl::string_view xla_compile_device_type, const std::vector& input_arrays, const std::vector& input_dtypes, const std::vector>>& input_shapes, @@ -73,6 +74,7 @@ static StatusOr> GraphdefToMlirImport( specs.enable_shape_inference = enable_shape_inference; specs.unconditionally_use_set_output_shapes = unconditionally_use_set_output_shapes; + specs.xla_compile_device_type = xla_compile_device_type; TF_RETURN_IF_ERROR(ParseInputArrayInfo(input_arrays, input_dtypes, input_shapes, &specs.inputs)); TF_RETURN_IF_ERROR(ParseOutputArrayInfo(output_arrays, &specs.outputs)); @@ -108,6 +110,7 @@ static StatusOr> GraphdefToMlirImport( StatusOr> GraphdefToMlirTranslateFunction( llvm::StringRef input, absl::string_view debug_info_file, + absl::string_view xla_compile_device_type, const std::vector& input_arrays, const std::vector& input_dtypes, const std::vector>>& input_shapes, @@ -117,10 +120,11 @@ StatusOr> GraphdefToMlirTranslateFunction( bool graph_as_function, bool upgrade_legacy, bool enable_shape_inference, bool unconditionally_use_set_output_shapes, mlir::MLIRContext* context) { auto module_or = GraphdefToMlirImport( - input, debug_info_file, input_arrays, input_dtypes, input_shapes, - output_arrays, control_output_arrays, prune_unused_nodes, - convert_legacy_fed_inputs, graph_as_function, upgrade_legacy, - enable_shape_inference, unconditionally_use_set_output_shapes, context); + input, debug_info_file, xla_compile_device_type, input_arrays, + input_dtypes, input_shapes, output_arrays, control_output_arrays, + prune_unused_nodes, convert_legacy_fed_inputs, graph_as_function, + upgrade_legacy, enable_shape_inference, + unconditionally_use_set_output_shapes, context); if (!module_or.status().ok()) { LOG(ERROR) << "Graph import failed: " << module_or.status(); } @@ -129,12 +133,12 @@ StatusOr> GraphdefToMlirTranslateFunction( StatusOr> GraphdefToMlirTranslateFunction( llvm::StringRef input, absl::string_view debug_info_file, - absl::string_view input_arrays, absl::string_view input_dtypes, - absl::string_view input_shapes, absl::string_view output_arrays, - absl::string_view control_output_arrays, bool prune_unused_nodes, - 
bool convert_legacy_fed_inputs, bool graph_as_function, bool upgrade_legacy, - bool enable_shape_inference, bool unconditionally_use_set_output_shapes, - mlir::MLIRContext* context) { + absl::string_view xla_compile_device_type, absl::string_view input_arrays, + absl::string_view input_dtypes, absl::string_view input_shapes, + absl::string_view output_arrays, absl::string_view control_output_arrays, + bool prune_unused_nodes, bool convert_legacy_fed_inputs, + bool graph_as_function, bool upgrade_legacy, bool enable_shape_inference, + bool unconditionally_use_set_output_shapes, mlir::MLIRContext* context) { std::vector input_array_vector; std::vector input_dtype_vector; std::vector>> input_shapes_vector; @@ -147,11 +151,11 @@ StatusOr> GraphdefToMlirTranslateFunction( TF_RETURN_IF_ERROR( ParseNodeNames(control_output_arrays, control_output_array_vector)); return GraphdefToMlirTranslateFunction( - input, debug_info_file, input_array_vector, input_dtype_vector, - input_shapes_vector, output_array_vector, control_output_array_vector, - prune_unused_nodes, convert_legacy_fed_inputs, graph_as_function, - upgrade_legacy, enable_shape_inference, - unconditionally_use_set_output_shapes, context); + input, debug_info_file, xla_compile_device_type, input_array_vector, + input_dtype_vector, input_shapes_vector, output_array_vector, + control_output_array_vector, prune_unused_nodes, + convert_legacy_fed_inputs, graph_as_function, upgrade_legacy, + enable_shape_inference, unconditionally_use_set_output_shapes, context); } StatusOr> SavedModelObjectGraphToMlirImport( @@ -249,6 +253,7 @@ SavedModelSignatureDefsToMlirImportLite( StatusOr> GraphdefToSplattedMlirTranslateFunction( llvm::StringRef input, absl::string_view debug_info_file, + absl::string_view xla_compile_device_type, const std::vector& input_arrays, const std::vector& input_dtypes, const std::vector>>& input_shapes, @@ -258,10 +263,11 @@ GraphdefToSplattedMlirTranslateFunction( bool graph_as_function, bool upgrade_legacy, bool enable_shape_inference, bool unconditionally_use_set_output_shapes, mlir::MLIRContext* context) { auto module_or = GraphdefToMlirImport( - input, debug_info_file, input_arrays, input_dtypes, input_shapes, - output_arrays, control_output_arrays, prune_unused_nodes, - convert_legacy_fed_inputs, graph_as_function, upgrade_legacy, - enable_shape_inference, unconditionally_use_set_output_shapes, context); + input, debug_info_file, xla_compile_device_type, input_arrays, + input_dtypes, input_shapes, output_arrays, control_output_arrays, + prune_unused_nodes, convert_legacy_fed_inputs, graph_as_function, + upgrade_legacy, enable_shape_inference, + unconditionally_use_set_output_shapes, context); if (!module_or.status().ok()) { LOG(ERROR) << "Graph import failed: " << module_or.status(); return module_or.status(); @@ -274,7 +280,7 @@ GraphdefToSplattedMlirTranslateFunction( auto attr_id = mlir::StringAttr::get(context, "value"); if (auto attr = inst.getAttrOfType(attr_id)) { mlir::Attribute rand_val; - mlir::Type element_type = attr.getType().getElementType(); + mlir::Type element_type = attr.getShapedType().getElementType(); if (element_type.isa()) { rand_val = mlir::IntegerAttr::get(element_type, std::rand()); } else if (element_type.isF16() || element_type.isF32() || @@ -288,8 +294,8 @@ GraphdefToSplattedMlirTranslateFunction( << "an unsupported attribute type " << element_type; continue; } - auto new_attr = - mlir::DenseElementsAttr::get(attr.getType(), rand_val); + auto new_attr = mlir::DenseElementsAttr::get( + 
llvm::cast(attr.getType()), rand_val); inst.setAttr(attr_id, new_attr); } } @@ -301,12 +307,12 @@ GraphdefToSplattedMlirTranslateFunction( StatusOr> GraphdefToSplattedMlirTranslateFunction( llvm::StringRef input, absl::string_view debug_info_file, - absl::string_view input_arrays, absl::string_view input_dtypes, - absl::string_view input_shapes, absl::string_view output_arrays, - absl::string_view control_output_arrays, bool prune_unused_nodes, - bool convert_legacy_fed_inputs, bool graph_as_function, bool upgrade_legacy, - bool enable_shape_inference, bool unconditionally_use_set_output_shapes, - mlir::MLIRContext* context) { + absl::string_view xla_compile_device_type, absl::string_view input_arrays, + absl::string_view input_dtypes, absl::string_view input_shapes, + absl::string_view output_arrays, absl::string_view control_output_arrays, + bool prune_unused_nodes, bool convert_legacy_fed_inputs, + bool graph_as_function, bool upgrade_legacy, bool enable_shape_inference, + bool unconditionally_use_set_output_shapes, mlir::MLIRContext* context) { std::vector input_array_vector; std::vector input_dtype_vector; std::vector>> input_shapes_vector; @@ -319,11 +325,11 @@ GraphdefToSplattedMlirTranslateFunction( TF_RETURN_IF_ERROR( ParseNodeNames(control_output_arrays, control_output_array_vector)); return GraphdefToSplattedMlirTranslateFunction( - input, debug_info_file, input_array_vector, input_dtype_vector, - input_shapes_vector, output_array_vector, control_output_array_vector, - prune_unused_nodes, convert_legacy_fed_inputs, graph_as_function, - upgrade_legacy, enable_shape_inference, - unconditionally_use_set_output_shapes, context); + input, debug_info_file, xla_compile_device_type, input_array_vector, + input_dtype_vector, input_shapes_vector, output_array_vector, + control_output_array_vector, prune_unused_nodes, + convert_legacy_fed_inputs, graph_as_function, upgrade_legacy, + enable_shape_inference, unconditionally_use_set_output_shapes, context); } } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h index 33435cea739..677c09dd027 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h +++ b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h @@ -41,6 +41,7 @@ using tsl::StatusOr; // Creates MLIR entities into the given MLIR `context`. StatusOr> GraphdefToMlirTranslateFunction( llvm::StringRef input, absl::string_view debug_info_file, + absl::string_view xla_compile_device_type, const std::vector& input_arrays, const std::vector& input_dtypes, const std::vector>>& input_shapes, @@ -59,10 +60,11 @@ ABSL_DEPRECATED( // Creates MLIR entities into the given MLIR `context`. StatusOr> GraphdefToMlirTranslateFunction( llvm::StringRef input, absl::string_view debug_info_file, - absl::string_view input_arrays, absl::string_view input_dtypes, - absl::string_view input_shapes, absl::string_view output_arrays, - absl::string_view control_output_arrays, bool prune_unused_nodes, - bool convert_legacy_fed_inputs, bool graph_as_function, bool upgrade_legacy, + absl::string_view xla_compile_device_type, absl::string_view input_arrays, + absl::string_view input_dtypes, absl::string_view input_shapes, + absl::string_view output_arrays, absl::string_view control_output_arrays, + bool prune_unused_nodes, bool convert_legacy_fed_inputs, + bool graph_as_function, bool upgrade_legacy, // TODO(jpienaar): Remove these. 
bool enable_shape_inference, bool unconditionally_use_set_output_shapes, mlir::MLIRContext* context); @@ -72,6 +74,7 @@ StatusOr> GraphdefToMlirTranslateFunction( StatusOr> GraphdefToSplattedMlirTranslateFunction( llvm::StringRef input, absl::string_view debug_info_file, + absl::string_view xla_compile_device_type, const std::vector& input_arrays, const std::vector& input_dtypes, const std::vector>& input_shapes, @@ -91,10 +94,11 @@ ABSL_DEPRECATED( StatusOr> GraphdefToSplattedMlirTranslateFunction( llvm::StringRef input, absl::string_view debug_info_file, - absl::string_view input_arrays, absl::string_view input_dtypes, - absl::string_view input_shapes, absl::string_view output_arrays, - absl::string_view control_output_arrays, bool prune_unused_nodes, - bool convert_legacy_fed_inputs, bool graph_as_function, bool upgrade_legacy, + absl::string_view xla_compile_device_type, absl::string_view input_arrays, + absl::string_view input_dtypes, absl::string_view input_shapes, + absl::string_view output_arrays, absl::string_view control_output_arrays, + bool prune_unused_nodes, bool convert_legacy_fed_inputs, + bool graph_as_function, bool upgrade_legacy, // TODO(jpienaar): Remove these. bool enable_shape_inference, bool unconditionally_use_set_output_shapes, mlir::MLIRContext* context); diff --git a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_cl.cc b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_cl.cc index fdcfb18cd58..d739b3997c5 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_cl.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_cl.cc @@ -86,6 +86,12 @@ opt debug_info_file( llvm::cl::desc("Path to the debug info file of the input graph def"), llvm::cl::init("")); +// NOLINTNEXTLINE +opt xla_compile_device_type( + "tf-xla-compile-device-type", + llvm::cl::desc("Sets the compilation device type of the input graph def"), + llvm::cl::init("")); + // TODO(b/134792656): If pruning is moved into TF dialect as a pass // we should remove this. 
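The new `tf-xla-compile-device-type` option above flows into `GraphImportConfig::xla_compile_device_type`; when the value is non-empty, the GraphDef importer also attaches an `_xla_compile_device_type` attribute to the imported entry function (see the import_model.cc hunk earlier). A minimal sketch of populating the config programmatically; the field values are illustrative only:

#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"

// Illustrative only: mirrors what GraphdefToMlirImport now does with the new
// xla_compile_device_type parameter before invoking GraphDefImporter::Convert.
tensorflow::GraphImportConfig MakeImportConfig() {
  tensorflow::GraphImportConfig specs;
  specs.enable_shape_inference = false;
  specs.unconditionally_use_set_output_shapes = true;
  // A non-empty value marks the function graph for XLA compilation with this
  // device type and surfaces as the `_xla_compile_device_type` attribute.
  specs.xla_compile_device_type = "TPU";
  return specs;
}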
// NOLINTNEXTLINE diff --git a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_cl.h b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_cl.h index aaf0b5c4c74..af50bdc185f 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_cl.h +++ b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_cl.h @@ -36,6 +36,7 @@ extern llvm::cl::opt inference_type; extern llvm::cl::opt min_values; extern llvm::cl::opt max_values; extern llvm::cl::opt debug_info_file; +extern llvm::cl::opt xla_compile_device_type; extern llvm::cl::opt prune_unused_nodes; extern llvm::cl::opt convert_legacy_fed_inputs; extern llvm::cl::opt graph_as_function; diff --git a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_registration.cc b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_registration.cc index f1c39aba7ad..6ce04664a7b 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_registration.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_registration.cc @@ -45,10 +45,11 @@ inline absl::string_view StringRefToView(llvm::StringRef ref) { static OwningOpRef GraphdefToMlirTranslateFunction( llvm::StringRef input, MLIRContext* context) { auto module_or = tensorflow::GraphdefToMlirTranslateFunction( - input, debug_info_file, input_arrays, input_dtypes, input_shapes, - output_arrays, control_output_arrays, prune_unused_nodes, - convert_legacy_fed_inputs, graph_as_function, upgrade_legacy, - enable_shape_inference, unconditionally_use_set_output_shapes, context); + input, debug_info_file, xla_compile_device_type, input_arrays, + input_dtypes, input_shapes, output_arrays, control_output_arrays, + prune_unused_nodes, convert_legacy_fed_inputs, graph_as_function, + upgrade_legacy, enable_shape_inference, + unconditionally_use_set_output_shapes, context); if (!module_or.status().ok()) return nullptr; return std::move(module_or).value(); } @@ -59,10 +60,11 @@ static TranslateToMLIRRegistration GraphdefToMlirTranslate( static OwningOpRef GraphdefToSplattedMlirTranslateFunction( llvm::StringRef input, MLIRContext* context) { auto module_or = tensorflow::GraphdefToSplattedMlirTranslateFunction( - input, debug_info_file, input_arrays, input_dtypes, input_shapes, - output_arrays, control_output_arrays, prune_unused_nodes, - convert_legacy_fed_inputs, graph_as_function, upgrade_legacy, - enable_shape_inference, unconditionally_use_set_output_shapes, context); + input, debug_info_file, xla_compile_device_type, input_arrays, + input_dtypes, input_shapes, output_arrays, control_output_arrays, + prune_unused_nodes, convert_legacy_fed_inputs, graph_as_function, + upgrade_legacy, enable_shape_inference, + unconditionally_use_set_output_shapes, context); if (!module_or.status().ok()) return nullptr; return std::move(module_or).value(); } diff --git a/tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h b/tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h index 0c6a4733dc9..95066de457a 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h @@ -43,12 +43,20 @@ inline constexpr llvm::StringRef kReplicationInfoAttr = "_replication_info"; inline constexpr llvm::StringRef kTpuReplicateAttr = "_tpu_replicate"; // Device types. 
inline constexpr llvm::StringRef kTpuDevice = "TPU"; +// _xla_outside_compilation +inline constexpr llvm::StringRef kXlaOutsideCompilationAttr = + "_xla_outside_compilation"; +// device attr +inline constexpr llvm::StringRef kDeviceAttr = "device"; // Function attribute to signal that a function should be skipped from TPU // island outlining. The attribute is set in // `TpuV1BridgeExecutorIslandCoarsening` and removed in the subsequent // `TPUBridgeExecutorIslandOutlining` pass. inline constexpr llvm::StringRef kSkipIslandOutlining = "_skip_island_outlining"; +// Function attribute to signal which argument contains bounded dynamic +// dimension. +inline constexpr llvm::StringRef kDynamicArgIndexAttr = "_dynamic_arg_index"; // This string attribute encodes parallel execution groups and their associated // branches. It has the following format: diff --git a/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.cc b/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.cc index 42716d6e9ec..fce8c6f8dcf 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.cc @@ -219,7 +219,7 @@ StatusOr ConvertTensorProto(const TensorProto& input_tensor, llvm::SmallVector original_dimensions; for (auto dim : input_tensor_shape) original_dimensions.push_back(dim.size); return ElementsAttr(mlir::SplatElementsAttr::get( - single_attr.getType().clone(original_dimensions), + single_attr.getShapedType().clone(original_dimensions), single_attr.getValues()[0])); } @@ -404,7 +404,7 @@ void ConvertFloat8ElementsAttr(const mlir::DenseElementsAttr attr, } Status ConvertToTensorProto(const ElementsAttr attr, TensorProto* output) { - auto type = attr.getType(); + auto type = attr.getShapedType(); auto shape = type.getShape(); DataType output_dtype; TF_RETURN_IF_ERROR(ConvertToDataType(type, &output_dtype)); diff --git a/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor_test.cc b/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor_test.cc index 115a1cbbfd2..373e88f7413 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor_test.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor_test.cc @@ -119,7 +119,7 @@ class ConvertTensorTest : public ::testing::Test { TF_ASSERT_OK(value_or.status()); auto attr = value_or.value(); - EXPECT_EQ(attr.getType().getElementType(), expected_ty); + EXPECT_EQ(attr.getShapedType().getElementType(), expected_ty); Tensor out; TF_ASSERT_OK(ConvertToTensor(attr, &out)); diff --git a/tensorflow/compiler/mlir/tensorflow/utils/data_dumper_logger_config.cc b/tensorflow/compiler/mlir/tensorflow/utils/data_dumper_logger_config.cc new file mode 100644 index 00000000000..f49950cdf1e --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/utils/data_dumper_logger_config.cc @@ -0,0 +1,60 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/mlir/tensorflow/utils/data_dumper_logger_config.h" + +#include +#include +#include + +#include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h" + +namespace tensorflow { +DataDumperLoggerConfig::DataDumperLoggerConfig( + std::function get_filename, + const std::string &pass_prefix, bool print_module_scope, + bool print_after_only_on_change) + : ::tensorflow::BridgeLoggerConfig(print_module_scope, + print_after_only_on_change), + get_filename_(get_filename), + pass_prefix_(pass_prefix) {} + +void DataDumperLoggerConfig::printBeforeIfEnabled( + mlir::Pass *pass, mlir::Operation *op, PrintCallbackFn print_callback) { + std::string pass_name = pass->getName().str(); + std::string filename = get_filename_(pass_prefix_ + "before_" + pass_name); + + DumpMlir(filename, print_callback); +} + +void DataDumperLoggerConfig::printAfterIfEnabled( + mlir::Pass *pass, mlir::Operation *op, PrintCallbackFn print_callback) { + std::string pass_name = pass->getName().str(); + std::string filename = get_filename_(pass_prefix_ + "after_" + pass_name); + + DumpMlir(filename, print_callback); +} + +void DataDumperLoggerConfig::DumpMlir( + const std::string &filename, + BridgeLoggerConfig::PrintCallbackFn print_callback) { + std::unique_ptr os; + std::string filepath; + if (tensorflow::CreateFileForDumping(filename, &os, &filepath).ok()) { + print_callback(*os); + LOG(INFO) << "Dumped MLIR module to " << filepath; + } +} +} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tensorflow/utils/data_dumper_logger_config.h b/tensorflow/compiler/mlir/tensorflow/utils/data_dumper_logger_config.h new file mode 100644 index 00000000000..c962d68c02f --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/utils/data_dumper_logger_config.h @@ -0,0 +1,54 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_DATA_DUMPER_LOGGER_CONFIG_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_DATA_DUMPER_LOGGER_CONFIG_H_ + +#include +#include + +#include "tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.h" + +namespace tensorflow { + +class DataDumperLoggerConfig : public ::tensorflow::BridgeLoggerConfig { + public: + explicit DataDumperLoggerConfig( + std::function get_filename, + const std::string &pass_prefix = "", bool print_module_scope = false, + bool print_after_only_on_change = true); + + void printBeforeIfEnabled(mlir::Pass *pass, mlir::Operation *op, + PrintCallbackFn print_callback) override; + + void printAfterIfEnabled(mlir::Pass *pass, mlir::Operation *op, + PrintCallbackFn print_callback) override; + + private: + static void DumpMlir(const std::string &filename, + BridgeLoggerConfig::PrintCallbackFn print_callback); + + // The function to dump the target MLIR string to file. 
+ // The parameter that will be sent to the dump_func_ is: + // The pass name (std::string) + std::function get_filename_; + + // The pass prefix. + std::string pass_prefix_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_DATA_DUMPER_LOGGER_CONFIG_H_ diff --git a/tensorflow/compiler/mlir/tensorflow/utils/device_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/device_util.cc index dbc8f07c4a6..51db1be0820 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/device_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/device_util.cc @@ -64,7 +64,7 @@ mlir::LogicalResult GetDevicesFromOp(mlir::Operation* op, mlir::TF::RuntimeDevices* devices) { DeviceNameUtils::ParsedName device; - for (auto& kv : llvm::enumerate(array_attr)) { + for (const auto& kv : llvm::enumerate(array_attr)) { const int idx = kv.index(); auto string_attr = kv.value().dyn_cast(); diff --git a/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc index c45ef133240..efcbca84872 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc @@ -148,7 +148,7 @@ Status CreateFileForDumping(llvm::StringRef name, dir = GetDumpDirFromEnvVar(); if (dir.empty()) { - return Status(error::Code::INVALID_ARGUMENT, + return Status(absl::StatusCode::kInvalidArgument, "(TF_DUMP_GRAPH_PREFIX not specified)"); } @@ -164,7 +164,7 @@ Status CreateFileForDumping(llvm::StringRef name, if (!status.ok()) { LOG(WARNING) << "Failed to create '" << dir << "' directory for dumping: " << status; - return Status(error::Code::UNAVAILABLE, "(unavailable)"); + return Status(absl::StatusCode::kUnavailable, "(unavailable)"); } *filepath = io::JoinPath(dir, MakeUniqueFilename(std::string(name))); @@ -173,7 +173,7 @@ Status CreateFileForDumping(llvm::StringRef name, status = env->NewWritableFile(*filepath, &file); if (!status.ok()) { LOG(WARNING) << "Failed to create file '" << filepath << "': " << status; - return Status(error::Code::UNAVAILABLE, "(unavailable)"); + return Status(absl::StatusCode::kUnavailable, "(unavailable)"); } file = std::make_unique(std::move(file)); *os = std::make_unique(std::move(file)); @@ -202,7 +202,7 @@ std::string DumpMlirOpToFile(llvm::StringRef name, mlir::Operation* op, std::unique_ptr os; std::string filepath; Status result = CreateFileForDumping(name, &os, &filepath, dirname); - if (!result.ok()) return result.error_message(); + if (!result.ok()) return std::string(result.message()); if (pass_manager) PrintPassPipeline(*pass_manager, op, *os); op->print(*os, mlir::OpPrintingFlags().useLocalScope()); @@ -236,7 +236,7 @@ std::string DumpRawStringToFile(llvm::StringRef name, llvm::StringRef content, std::unique_ptr os; std::string filepath; Status result = CreateFileForDumping(name, &os, &filepath, dirname); - if (!result.ok()) return result.error_message(); + if (!result.ok()) return std::string(result.message()); (*os) << content; LOG(INFO) << "Outputted requested string to '" << filepath << "'"; @@ -276,8 +276,8 @@ void SetCrashReproducer(mlir::PassManager& pm, llvm::StringRef dir_path) { auto* env = tensorflow::Env::Default(); auto status = env->RecursivelyCreateDir(path); if (!status.ok()) { - LOG(WARNING) << "cannot create directory '" + path + - "': " + status.error_message(); + LOG(WARNING) << "cannot create directory '" << path + << "': " << status.message(); return; } @@ -307,7 +307,7 @@ void 
SetCrashReproducer(mlir::PassManager& pm, llvm::StringRef dir_path) { if (!status.ok()) { error = absl::StrCat("Failed to create file '", path, - "': ", status.error_message()); + "': ", status.message()); return nullptr; } return std::make_unique( @@ -318,7 +318,11 @@ void SetCrashReproducer(mlir::PassManager& pm, llvm::StringRef dir_path) { void applyTensorflowAndCLOptions(mlir::PassManager& pm, llvm::StringRef dir_path) { - mlir::applyPassManagerCLOptions(pm); + mlir::registerPassManagerCLOptions(); + if (!mlir::succeeded(mlir::applyPassManagerCLOptions(pm))) { + LOG(ERROR) << "cannot apply MLIR pass manager CL options"; + return; + } SetCrashReproducer(pm, dir_path); } diff --git a/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h b/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h index 5287a7d2d25..6069b8ca2ad 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h @@ -96,6 +96,9 @@ void SetCrashReproducer(mlir::PassManager& pm, llvm::StringRef dir_path = ""); void applyTensorflowAndCLOptions(mlir::PassManager& pm, llvm::StringRef dir_path = ""); +// Prints the pass pipeline of `pass_manager` to `os`. +void PrintPassPipeline(const mlir::PassManager& pass_manager, + mlir::Operation* op, llvm::raw_ostream& os); } // namespace tensorflow #endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_DUMP_MLIR_UTIL_H_ diff --git a/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util_test.cc b/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util_test.cc index a68a66a7136..908bf40f834 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util_test.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util_test.cc @@ -27,6 +27,7 @@ limitations under the License. 
#include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/InitAllDialects.h" // from @llvm-project #include "mlir/InitAllPasses.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project #include "mlir/Pass/PassRegistry.h" // from @llvm-project #include "mlir/Support/FileUtilities.h" // from @llvm-project #include "mlir/Tools/mlir-opt/MlirOptMain.h" // from @llvm-project @@ -88,6 +89,7 @@ TEST(DumpMlirModuleTest, Valid) { } TEST(DumpCrashReproducerTest, RoundtripDumpAndReadValid) { + mlir::registerPassManagerCLOptions(); mlir::MLIRContext context; mlir::OwningOpRef module_ref = mlir::ModuleOp::create(mlir::UnknownLoc::get(&context)); @@ -119,11 +121,13 @@ TEST(DumpCrashReproducerTest, RoundtripDumpAndReadValid) { mlir::registerTensorFlowPasses(); EXPECT_TRUE(mlir::MlirOptMain(output_stream->os(), std::move(input_file), - passPipeline, registry, - /*splitInputFile=*/false, - /*verifyDiagnostics=*/false, - /*verifyPasses=*/false, - /*allowUnregisteredDialects=*/false) + registry, + mlir::MlirOptMainConfig{} + .splitInputFile(false) + .verifyDiagnostics(false) + .verifyPasses(false) + .allowUnregisteredDialects(false) + .setPassPipelineParser(passPipeline)) .succeeded()); } diff --git a/tensorflow/compiler/mlir/tensorflow/utils/error_util_test.cc b/tensorflow/compiler/mlir/tensorflow/utils/error_util_test.cc index 4ef6340f39e..3cf746cd226 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/error_util_test.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/error_util_test.cc @@ -62,10 +62,9 @@ TEST(ErrorUtilTest, StatusScopedDiagnosticHandler) { StatusScopedDiagnosticHandler ssdh(&context); Status s = ssdh.Combine(function()); ASSERT_TRUE(tensorflow::errors::IsInternal(s)); - EXPECT_THAT(s.error_message(), HasSubstr("Passed in error")); - EXPECT_THAT(s.error_message(), HasSubstr("Diagnostic message reported")); - EXPECT_THAT(s.error_message(), - HasSubstr("Second diagnostic message reported")); + EXPECT_THAT(s.message(), HasSubstr("Passed in error")); + EXPECT_THAT(s.message(), HasSubstr("Diagnostic message reported")); + EXPECT_THAT(s.message(), HasSubstr("Second diagnostic message reported")); } } @@ -111,11 +110,11 @@ TEST(ErrorUtilTest, StatusScopedDiagnosticHandlerWithFilter) { emitError(callsite_loc3) << "Error 3"; Status s_filtered = ssdh_filter.ConsumeStatus(); // Check for the files that should not be filtered. - EXPECT_THAT(s_filtered.error_message(), HasSubstr("keras")); - EXPECT_THAT(s_filtered.error_message(), HasSubstr("test.py")); - EXPECT_THAT(s_filtered.error_message(), HasSubstr("show_file")); + EXPECT_THAT(s_filtered.message(), HasSubstr("keras")); + EXPECT_THAT(s_filtered.message(), HasSubstr("test.py")); + EXPECT_THAT(s_filtered.message(), HasSubstr("show_file")); // Verify the filtered files are not present. - EXPECT_THAT(s_filtered.error_message(), Not(HasSubstr("filtered_file"))); + EXPECT_THAT(s_filtered.message(), Not(HasSubstr("filtered_file"))); } TEST(ErrorUtilTest, StatusScopedDiagnosticHandlerWithoutFilter) { @@ -151,10 +150,10 @@ TEST(ErrorUtilTest, StatusScopedDiagnosticHandlerWithoutFilter) { emitError(callsite_loc2) << "Error 2"; Status s_no_filter = ssdh_no_filter.ConsumeStatus(); // All files should be present, especially the 'filtered' ones. 
- EXPECT_THAT(s_no_filter.error_message(), HasSubstr("keras")); - EXPECT_THAT(s_no_filter.error_message(), HasSubstr("my_op")); - EXPECT_THAT(s_no_filter.error_message(), HasSubstr("filtered_file_A")); - EXPECT_THAT(s_no_filter.error_message(), HasSubstr("filtered_file_B")); + EXPECT_THAT(s_no_filter.message(), HasSubstr("keras")); + EXPECT_THAT(s_no_filter.message(), HasSubstr("my_op")); + EXPECT_THAT(s_no_filter.message(), HasSubstr("filtered_file_A")); + EXPECT_THAT(s_no_filter.message(), HasSubstr("filtered_file_B")); } } // namespace diff --git a/tensorflow/compiler/mlir/tensorflow/utils/eval_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/eval_util.cc index c03ba4c2f8a..925c2dfc57b 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/eval_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/eval_util.cc @@ -50,7 +50,7 @@ static bool IsOk(const TF_Status* s) { static bool IsOk(const Status& s) { if (s.ok()) return true; - VLOG(2) << s.error_message(); + VLOG(2) << s.message(); return false; } diff --git a/tensorflow/compiler/mlir/tensorflow/utils/session_utils.cc b/tensorflow/compiler/mlir/tensorflow/utils/session_utils.cc index 5b89105156d..fdb1ebc39a9 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/session_utils.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/session_utils.cc @@ -71,7 +71,7 @@ absl::StatusOr> GetResourcesFromSession( auto status = session->Run({}, variable_names, {}, &resource_tensors); if (!status.ok()) - return absl::Status(absl::StatusCode::kInternal, status.error_message()); + return absl::Status(absl::StatusCode::kInternal, status.message()); return resource_tensors; } } // namespace tf_saved_model diff --git a/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.cc index 16c0e316204..449e0532cf0 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.cc @@ -30,6 +30,8 @@ limitations under the License. #include "llvm/Support/FormatVariadic.h" #include "mlir/IR/Attributes.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h" #include "tensorflow/compiler/mlir/utils/string_container_utils.h" #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/service/computation_placer.h" @@ -48,22 +50,25 @@ constexpr int kTPUTopologyRank = 4; constexpr char kDeviceTPUSystem[] = "TPU_SYSTEM"; constexpr char kDeviceTPU[] = "TPU"; constexpr char kTPUReplicatedCore[] = "TPU_REPLICATED_CORE"; +constexpr char kTPUReplicatedHost[] = "TPU_REPLICATED_HOST"; constexpr char kBadIntArrayElementMsg[] = "bad '{0}' attribute at index {1}, not an int"; -using Device = DeviceNameUtils::ParsedName; -using Devices = llvm::ArrayRef; +using ParsedDevice = DeviceNameUtils::ParsedName; +using ParsedDevices = llvm::ArrayRef; namespace { -// Finds matching devices in `devices` based on pattern `spec`. -void FindMatchingDevices(Devices devices, const Device& spec, - llvm::SmallVectorImpl* matched_devices) { +// Find matching devices in `devices` based on pattern `spec`. 
+llvm::SmallVector FindMatchingDevices( + ParsedDevices devices, const ParsedDevice& spec) { + llvm::SmallVector matching_devices; for (const auto& device : devices) if (DeviceNameUtils::IsCompleteSpecification(spec, device)) - matched_devices->push_back(device); + matching_devices.push_back(device); + return matching_devices; } -// Creates error message for a conflicting attribute of a device. +// Create error message for a conflicting attribute of a device. template Status MismatchedTPUSystemAttributeErr(absl::string_view attribute, T a, T b) { return errors::InvalidArgument("found ", kDeviceTPUSystem, @@ -71,20 +76,20 @@ Status MismatchedTPUSystemAttributeErr(absl::string_view attribute, T a, T b) { a, "' and '", b, "'"); } -// Finds TPU_SYSTEM:0 devices in `devices`. If multiple TPU_SYSTEM devices are +// Find TPU_SYSTEM:0 devices in `devices`. If multiple TPU_SYSTEM devices are // found, the first one lexicographically is returned. If no TPU_SYSTEM device // is found or if there are multiple TPU_SYSTEM devices with different jobs or // replicas, a failure will be returned. -Status GetTPUSystemDevices(Devices devices, - llvm::SmallVectorImpl* matched_devices) { - Device spec; +StatusOr> GetTPUSystemDevices( + ParsedDevices devices) { + ParsedDevice spec; spec.type = kDeviceTPUSystem; spec.has_type = true; spec.id = 0; spec.has_id = true; - llvm::SmallVector system_devices; - FindMatchingDevices(devices, spec, &system_devices); + llvm::SmallVector system_devices = + FindMatchingDevices(devices, spec); if (system_devices.empty()) return errors::InvalidArgument("no ", kDeviceTPUSystem, " devices found"); @@ -103,33 +108,36 @@ Status GetTPUSystemDevices(Devices devices, // Sort by task to be deterministic. std::sort(system_devices.begin(), system_devices.end(), - [](const Device& a, const Device& b) { return a.task < b.task; }); + [](const ParsedDevice& a, const ParsedDevice& b) { + return a.task < b.task; + }); - matched_devices->swap(system_devices); - - return OkStatus(); + return system_devices; } -// Finds TPU devices associated to system device based on spec (e.g. from +// Find TPU devices associated to system device based on spec (e.g. from // GetTPUSystemDevices). If the number of TPU devices per host do not match for // every host, a failure will be returned. -Status GetTPUDevices( - Devices devices, llvm::ArrayRef system_devices, - llvm::SmallVectorImpl>* tpu_devices) { - tpu_devices->reserve(system_devices.size()); +StatusOr, 8>> +GetTPUDevices(ParsedDevices devices, + llvm::ArrayRef system_devices) { + llvm::SmallVector, 8> tpu_devices; + tpu_devices.reserve(system_devices.size()); - auto lookup = [&devices](Device device_spec) { + auto lookup = [&devices](ParsedDevice device_spec) { device_spec.has_type = true; device_spec.type = kDeviceTPU; // Enumerate all the available TPUs. device_spec.has_id = false; - llvm::SmallVector host_tpu_devices; - FindMatchingDevices(devices, device_spec, &host_tpu_devices); + llvm::SmallVector host_tpu_devices = + FindMatchingDevices(devices, device_spec); // Sort devices by id. 
std::sort(host_tpu_devices.begin(), host_tpu_devices.end(), - [](const Device& i, const Device& j) { return i.id < j.id; }); + [](const ParsedDevice& i, const ParsedDevice& j) { + return i.id < j.id; + }); return host_tpu_devices; }; @@ -138,7 +146,7 @@ Status GetTPUDevices( const auto& device = system_devices[0]; auto host_tpu_devices = lookup(device); num_tpus_per_host = host_tpu_devices.size(); - tpu_devices->push_back(std::move(host_tpu_devices)); + tpu_devices.push_back(std::move(host_tpu_devices)); } for (const auto& device_spec : llvm::make_range( @@ -151,14 +159,15 @@ Status GetTPUDevices( "expected the number of TPU devices per host to be ", num_tpus_per_host, ", got ", host_tpu_devices.size()); - tpu_devices->push_back(std::move(host_tpu_devices)); + tpu_devices.push_back(std::move(host_tpu_devices)); } - return OkStatus(); + return tpu_devices; } -// Finds the compilation device from system device. -std::string GetTPUCompilationDevice(Device system_device) { +// Find the compilation device from system device with `DEVICE_CPU` as its +// type. +std::string GetTPUCompilationDevice(ParsedDevice system_device) { // TODO(b/110910013) GetTPUSystemDevices parses the spec and returns the // TPU_SYSTEM device, which we replace with the CPU device. We do this // replacement because we want to place the `tf._TPUCompileMlir` explicitly on @@ -167,21 +176,22 @@ std::string GetTPUCompilationDevice(Device system_device) { return DeviceNameUtils::ParsedNameToString(system_device); } -// Finds the host CPU device for a given TPU device. -std::string GetCPUHostDeviceForTPUDevice(Device tpu_device) { +// Find the host CPU device for a given TPU device with `DEVICE_CPU` as its +// type and `id` 0. +std::string GetCPUHostDeviceForTPUDevice(ParsedDevice tpu_device) { tpu_device.type = DEVICE_CPU; tpu_device.id = 0; return DeviceNameUtils::ParsedNameToString(tpu_device); } -// Determines execution devices when topology and device assignment are not +// Determine execution devices when topology and device assignment are not // defined. This is a special case where a single core computation is replicated // to every core in the mesh. TPU devices are simply added to // `execution_devices` of one replica. `num_replicas` must be 1 or the total // number of TPU devices available, and `num_cores_per_replica` must be 1. StatusOr GetFullMeshTPUExecutionDeviceAssignment( int num_replicas, int num_cores_per_replica, - llvm::ArrayRef> tpu_devices) { + llvm::ArrayRef> tpu_devices) { const int num_tasks = tpu_devices.size(); const int num_tpus_per_task = tpu_devices[0].size(); const int num_tpu_devices = num_tasks * num_tpus_per_task; @@ -219,14 +229,14 @@ struct TaskAndDevice { int device = -1; }; -// Checks if device coordinate is outside of topology mesh shape bounds. +// Check if device coordinate is outside of topology mesh shape bounds. bool DeviceCoordinateOutOfBound(int x, int y, int z, int core, int bound_x, int bound_y, int bound_z, int bound_core) { return x < 0 || x >= bound_x || y < 0 || y >= bound_y || z < 0 || z >= bound_z || core < 0 || core >= bound_core; } -// Creates error message for an out of bound device coordinate. +// Create error message for an out of bound device coordinate. 
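This file's refactor consistently replaces out-parameters (for example the old `GetTPUSystemDevices(devices, &system_devices)`) with `StatusOr` return values that callers unwrap via `TF_ASSIGN_OR_RETURN`, as at the call site further below. The general shape of that conversion, sketched with placeholder names:

#include <vector>

#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/statusor.h"

// Placeholder example of the out-parameter -> StatusOr conversion pattern.
// Before: Status CollectIds(int count, std::vector<int>* out);
tensorflow::StatusOr<std::vector<int>> CollectIds(int count) {
  if (count < 0)
    return tensorflow::errors::InvalidArgument("count must be non-negative");
  return std::vector<int>(count, 0);
}

tensorflow::Status UseIds() {
  // TF_ASSIGN_OR_RETURN forwards the error status and binds the value on success.
  TF_ASSIGN_OR_RETURN(std::vector<int> ids, CollectIds(4));
  (void)ids;
  return tensorflow::OkStatus();
}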
Status DeviceCoordinateErrorMsg(absl::string_view attribute, int x, int y, int z, int core, int bound_x, int bound_y, int bound_z, int bound_core) { @@ -236,7 +246,7 @@ Status DeviceCoordinateErrorMsg(absl::string_view attribute, int x, int y, bound_y, ", ", bound_z, ", ", bound_core, ")"); } -// Creates error message for a duplicate device coordinate. +// Create error message for a duplicate device coordinate. Status DuplicateCoordinateErrorMsg(absl::string_view attribute, int x, int y, int z, int core) { return errors::InvalidArgument("'", attribute, @@ -244,7 +254,7 @@ Status DuplicateCoordinateErrorMsg(absl::string_view attribute, int x, int y, y, ", ", z, ", ", core, ")"); } -// Parses and validates topology (serialized string of TopologyProto), and maps +// Parse and validate topology (serialized string of TopologyProto), and maps // device coordinate (x, y, z, core) to task and device (of available TPUs). // Topology attribute device coordinates are ordered by task then device (major // to minor). @@ -326,7 +336,7 @@ StatusOr> ParseTopologyAttr( return topology; } -// Determines execution devices when topology and device assignment are defined. +// Determine execution devices when topology and device assignment are defined. // With a topology device coordinate to task and device mapping, device // assignment device coordinates can then be mapped to task and device for TPU // devices. The device assignment array is also validated. @@ -340,7 +350,7 @@ StatusOr> ParseTopologyAttr( StatusOr> GetGeneralTPUExecutionDeviceAssignment( int num_replicas, int num_cores_per_replica, - llvm::ArrayRef> tpu_devices, + llvm::ArrayRef> tpu_devices, llvm::StringRef topology_attr, llvm::ArrayRef device_assignment_attr) { const int num_tasks = tpu_devices.size(); @@ -441,59 +451,149 @@ mlir::LogicalResult GetHostDeviceOCInGenericPipeline( return mlir::success(); } -mlir::LogicalResult GetHostDeviceOCInTPUPipeline( - mlir::TF::RuntimeDevices devices, mlir::tf_device::ClusterOp cluster, - std::string* host_device) { - auto replicate = cluster->getParentOfType(); - if (replicate) { - *host_device = tensorflow::kTPUReplicatedHost; - return mlir::success(); - } - - auto topology_attr = +mlir::LogicalResult GetTopology(mlir::tf_device::ClusterOp cluster, + std::string& topology) { + mlir::StringAttr topology_attr = cluster->getAttrOfType(tensorflow::kTopologyAttr); - if (!topology_attr) - return cluster.emitOpError("cluster op missing `topology` attribute"); - - auto num_cores_per_replica_attr = cluster->getAttrOfType( - tensorflow::kNumCoresPerReplicaAttr); - if (!num_cores_per_replica_attr) + if (topology_attr) { + topology = topology_attr.getValue(); + return mlir::success(); + } else { return cluster.emitOpError( - llvm::formatv("requires attribute '{0}'", - tensorflow::kNumCoresPerReplicaAttr) + llvm::formatv("requires attribute '{0}'", tensorflow::kTopologyAttr) .str()); + } +} - auto device_assignment_attr = cluster->getAttrOfType( - tensorflow::kDeviceAssignmentAttr); +mlir::LogicalResult GetDeviceAssignmentCoordinates( + mlir::tf_device::ClusterOp cluster, + llvm::SmallVector& device_coordinates) { + mlir::ArrayAttr device_assignment_attr = + cluster->getAttrOfType( + tensorflow::kDeviceAssignmentAttr); if (!device_assignment_attr) return cluster.emitOpError(llvm::formatv("requires attribute '{0}'", tensorflow::kDeviceAssignmentAttr) .str()); + if (StatusOr> fetched_device_coordinates = + tensorflow::GetDeviceCoordinates(device_assignment_attr); + fetched_device_coordinates.ok()) { + 
device_coordinates = *fetched_device_coordinates; + return mlir::success(); + } else { + return cluster.emitError() << "error in fetching tpu device coordinates: " + << fetched_device_coordinates.status().message(); + } +} - auto status_or_device_coodinates = +int GetNumCoresPerReplica(mlir::tf_device::ClusterOp cluster) { + mlir::IntegerAttr num_cores_per_replica_attr = + cluster->getAttrOfType(kNumCoresPerReplicaAttr); + if (num_cores_per_replica_attr) { + return num_cores_per_replica_attr.getInt(); + } else { + return 1; + } +} - tensorflow::GetDeviceCoordinates(device_assignment_attr); +// Get the TPUDevicesAndHosts for a cluster that is not replicated. +mlir::LogicalResult GetTPUDevicesAndHostsNotReplicated( + mlir::TF::RuntimeDevices devices, mlir::tf_device::ClusterOp cluster, + tensorflow::TPUDevicesAndHosts& devices_and_hosts) { + std::string topology; + if (failed(GetTopology(cluster, topology))) { + return mlir::failure(); + } + + llvm::SmallVector device_coordinates; + if (failed(GetDeviceAssignmentCoordinates(cluster, device_coordinates))) { + return mlir::failure(); + } // Determine compilation and execution devices. - auto status_or_tpu_device_assignment = - tensorflow::GetTPUCompilationAndExecutionDevices( - devices.device_names(), /*num_replicas=*/1, - num_cores_per_replica_attr.getInt(), topology_attr.getValue(), - std::move(status_or_device_coodinates).value()); - if (!status_or_tpu_device_assignment.ok()) + if (StatusOr tpu_device_assignment = + tensorflow::GetTPUCompilationAndExecutionDevices( + devices.device_names(), /*num_replicas=*/1, + GetNumCoresPerReplica(cluster), topology, device_coordinates); + tpu_device_assignment.ok()) { + devices_and_hosts = tpu_device_assignment->tpu_devices; + return mlir::success(); + } else { return cluster.emitError() << "error in fetching TPU compilation/execution devices: " - << status_or_tpu_device_assignment.status().error_message(); - auto& tpu_device_assignment = status_or_tpu_device_assignment.value(); + << tpu_device_assignment.status().message(); + } +} - *host_device = tpu_device_assignment.tpu_devices[0][0].host; +mlir::LogicalResult GetHostDeviceOCInTPUPipeline( + mlir::TF::RuntimeDevices devices, mlir::tf_device::ClusterOp cluster, + std::string& host_device) { + mlir::tf_device::ReplicateOp replicate = + cluster->getParentOfType(); + if (replicate) { + host_device = GetDeviceAliasForHostOfLogicalCore(0); + return mlir::success(); + } + + tensorflow::TPUDevicesAndHosts devices_and_hosts; + if (failed(GetTPUDevicesAndHostsNotReplicated(devices, cluster, + devices_and_hosts))) { + return mlir::failure(); + } else { + host_device = devices_and_hosts[0][0].host; + return mlir::success(); + } +} + +// Get the map from `core` to `TPU_REPLICATED_HOST_{core}` for a replicated +// TPU cluster. +// TPU_REPLICATED_HOST_{core} is the host that corresponds to the TPU core. +// Different TPU_REPLICATED_HOST_*s can map to the same physical host within the +// same replica. Also, TPU_REPLICATED_HOST_{core} in different replicas can map +// to the same physical host. For example, if there are 2 hosts, num_replicas=8, +// and num_cores_per_replica=2, then all cores in the first 4 replicas will map +// to the first host and all cores in the second 4 replicas will map to the +// second host.
+llvm::SmallVector GetTPUToHostMapReplicated( + mlir::tf_device::ClusterOp cluster) { + int num_cores_per_replica = GetNumCoresPerReplica(cluster); + llvm::SmallVector core_to_host; + core_to_host.reserve(num_cores_per_replica); + for (int core = 0; core < num_cores_per_replica; ++core) { + core_to_host.push_back(GetDeviceAliasForHostOfLogicalCore(core)); + } + return core_to_host; +} + +// Get the map from `core` to host device for a non-replicated TPU cluster. +mlir::LogicalResult GetTPUToHostMapNotReplicated( + mlir::TF::RuntimeDevices devices, mlir::tf_device::ClusterOp cluster, + llvm::SmallVector& core_to_host) { + tensorflow::TPUDevicesAndHosts devices_and_hosts; + if (failed(GetTPUDevicesAndHostsNotReplicated(devices, cluster, + devices_and_hosts))) { + return mlir::failure(); + } + + // core_to_host is the list of hosts in replica 0, which is the only replica. + core_to_host.reserve(GetNumCoresPerReplica(cluster)); + for (const auto& device_and_host : devices_and_hosts[0]) { + core_to_host.push_back(device_and_host.host); + } return mlir::success(); } +// Get the map from `core` to host device for a TPU cluster. +mlir::LogicalResult GetTPUToHostMap( + mlir::TF::RuntimeDevices devices, mlir::tf_device::ClusterOp cluster, + llvm::SmallVector& core_to_host) { + if (cluster->getParentOfType()) { + core_to_host = GetTPUToHostMapReplicated(cluster); + return mlir::success(); + } + return GetTPUToHostMapNotReplicated(devices, cluster, core_to_host); +} + } // anonymous namespace StatusOr> GetDeviceCoordinates( @@ -518,16 +618,14 @@ StatusOr> GetDeviceCoordinates( } StatusOr GetTPUCompilationAndExecutionDevices( - Devices devices, int num_replicas, int num_cores_per_replica, + ParsedDevices devices, int num_replicas, int num_cores_per_replica, llvm::StringRef topology_attr, llvm::ArrayRef device_assignment_attr) { // Collect TPU_SYSTEM devices. - llvm::SmallVector system_devices; - TF_RETURN_IF_ERROR(GetTPUSystemDevices(devices, &system_devices)); + TF_ASSIGN_OR_RETURN(auto system_devices, GetTPUSystemDevices(devices)); // Collect TPU devices based on TPU_SYSTEM devices collected earlier. 
- llvm::SmallVector, 8> tpu_devices; - TF_RETURN_IF_ERROR(GetTPUDevices(devices, system_devices, &tpu_devices)); + TF_ASSIGN_OR_RETURN(auto tpu_devices, GetTPUDevices(devices, system_devices)); std::string compilation_device = GetTPUCompilationDevice(system_devices[0]); @@ -553,10 +651,14 @@ StatusOr GetTPUCompilationAndExecutionDevices( std::move(devices_and_ids.second)); } -std::string GetDeviceAliasForLogicalCore(int core_index) { +std::string GetDeviceAliasForLogicalCore(const int core_index) { return llvm::formatv("{0}_{1}", kTPUReplicatedCore, core_index).str(); } +std::string GetDeviceAliasForHostOfLogicalCore(const int core_index) { + return llvm::formatv("{0}_{1}", kTPUReplicatedHost, core_index).str(); +} + bool HasModelParallelism(mlir::tf_device::ClusterOp cluster) { mlir::IntegerAttr num_cores_per_replica_attr = cluster->getAttrOfType( @@ -576,13 +678,15 @@ mlir::LogicalResult GetHostDeviceOutsideComputation( mlir::TF::RuntimeDevices devices, mlir::tf_device::ClusterOp cluster, std::string* host_device) { if (HasTPUDevice(devices) || - cluster->getParentOfType()) - return GetHostDeviceOCInTPUPipeline(devices, cluster, host_device); - return GetHostDeviceOCInGenericPipeline(devices, host_device); + cluster->getParentOfType()) { + return GetHostDeviceOCInTPUPipeline(devices, cluster, *host_device); + } else { + return GetHostDeviceOCInGenericPipeline(devices, host_device); + } } bool IsTPUDevice(llvm::StringRef device) { - Device parsed_device; + ParsedDevice parsed_device; if (!DeviceNameUtils::ParseFullName(mlir::StringRefToView(device), &parsed_device)) return false; @@ -590,10 +694,41 @@ bool IsTPUDevice(llvm::StringRef device) { } bool IsTPUReplicatedCore(llvm::StringRef device) { - Device parsed_device; + ParsedDevice parsed_device; if (!DeviceNameUtils::ParseFullName(mlir::StringRefToView(device), &parsed_device)) return false; return parsed_device.has_type && parsed_device.type == kTPUReplicatedCore; } + +bool TypeValidForXLA(const mlir::Type& type) { + const mlir::Type elem = getElementTypeOrSelf(type); + return !elem.isa() && + !elem.isa(); +} + +mlir::LogicalResult GetDeviceToHostMap( + mlir::tf_device::ClusterOp cluster, + llvm::SmallVector& core_to_host) { + mlir::TF::RuntimeDevices devices; + if (failed(tensorflow::GetDevicesFromOp( + cluster->getParentOfType(), &devices))) { + return mlir::failure(); + } + + if (tensorflow::HasTPUDevice(devices) || + cluster->getParentOfType()) { + return GetTPUToHostMap(devices, cluster, core_to_host); + } + + std::string host_device; + if (failed(tensorflow::GetHostDeviceOCInGenericPipeline(devices, + &host_device))) { + return mlir::failure(); + } else { + core_to_host.push_back(host_device); + return mlir::success(); + } +} + } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h b/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h index f4780d6abc0..77f853be582 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h @@ -35,7 +35,6 @@ limitations under the License. 
namespace tensorflow { using tsl::StatusOr; -inline constexpr absl::string_view kTPUReplicatedHost = "TPU_REPLICATED_HOST"; inline constexpr absl::string_view kNumCoresPerReplicaAttr = "num_cores_per_replica"; inline constexpr absl::string_view kTopologyAttr = "topology"; @@ -238,10 +237,14 @@ StatusOr GetTPUCompilationAndExecutionDevices( int num_cores_per_replica, llvm::StringRef topology_attr, llvm::ArrayRef device_assignment_attr); -// Virtual device is used for evice assignment for executing ops on a specified -// logical core. +// Virtual device name of the passed logical core. The logical core is the index +// of a core within a replica. std::string GetDeviceAliasForLogicalCore(int core_index); +// Virtual device name of the host that is associated with the passed logical +// core. The logical core is the index of a core within a replica. +std::string GetDeviceAliasForHostOfLogicalCore(int core_index); + // Returns true if cluster contains model parallelism based on // `num_cores_per_replica_attribute`. Otherwise returns false. bool HasModelParallelism(mlir::tf_device::ClusterOp cluster); @@ -251,7 +254,8 @@ bool HasTPUDevice(const mlir::TF::RuntimeDevices& devices); // Parses XLA compilation and execution devices from a tf_device.cluster and // returns the host device for the head and tail computations. For TPU device, -// if the computation is replicated, kTPUReplicatedHost is returned instead. +// if the computation is replicated, GetDeviceAliasForHostOfLogicalCore(0) is +// returned instead. mlir::LogicalResult GetHostDeviceOutsideComputation( mlir::TF::RuntimeDevices devices, mlir::tf_device::ClusterOp cluster, std::string* host_device); @@ -262,6 +266,20 @@ bool IsTPUDevice(llvm::StringRef device); // Checks if a device string is a TPU replicated core device. bool IsTPUReplicatedCore(llvm::StringRef device); +// Checks if `type` is allowed for XLA. String and resources are not XLA types. +// There are other TF types that are not XLA types which will be removed by +// successive passes in TF/XLA bridge phase 2. +bool TypeValidForXLA(const mlir::Type& type); + +// Returns the map from core to the host that is associated with the +// core. If `cluster` is not replicated then the core is a physical core index +// and the host is a physical host name. If `cluster` is replicated then the +// core with index `i` is a logical core (`TPU_REPLICATED_CORE_i`), and the host +// is the associated virtual device name (`TPU_REPLICATED_HOST_i`). +mlir::LogicalResult GetDeviceToHostMap( + mlir::tf_device::ClusterOp cluster, + llvm::SmallVector& core_to_host); + } // namespace tensorflow #endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_TPU_REWRITE_DEVICE_UTIL_H_ diff --git a/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util_test.cc b/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util_test.cc index 8cb93df6922..2f33ccd88b2 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util_test.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util_test.cc @@ -19,12 +19,15 @@ limitations under the License. 
#include #include +#include "llvm/ADT/StringRef.h" #include "llvm/Support/FormatVariadic.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h" #include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/protobuf/tpu/topology.pb.h" @@ -33,6 +36,20 @@ limitations under the License. namespace tensorflow { namespace { +tsl::StatusOr> GetMlirModuleFromString( + llvm::StringRef string, mlir::MLIRContext* context) { + mlir::DialectRegistry mlir_registry; + RegisterAllTensorFlowDialects(mlir_registry); + context->appendDialectRegistry(mlir_registry); + mlir::OwningOpRef mlir_module; + auto status = + tensorflow::DeserializeMlirModule(string, context, &mlir_module); + if (!status.ok()) { + return status; + } + return mlir_module; +} + using Device = DeviceNameUtils::ParsedName; bool DeviceNamesToParsedNames(llvm::ArrayRef device_names, @@ -63,7 +80,7 @@ TEST_P(ParameterizedDeviceSetTest, BadDeviceSet) { devices, /*num_replicas=*/1, /*num_cores_per_replica=*/1, topology_attr, device_assignment_attr); ASSERT_FALSE(status_or.ok()); - EXPECT_EQ(status_or.status().error_message(), std::get<1>(GetParam())); + EXPECT_EQ(status_or.status().message(), std::get<1>(GetParam())); } INSTANTIATE_TEST_SUITE_P( @@ -110,7 +127,7 @@ TEST_P(ParameterizedMetadataTest, BadMetadata) { devices, std::get<0>(GetParam()), std::get<1>(GetParam()), std::get<2>(GetParam()), std::get<3>(GetParam())); ASSERT_FALSE(status_or.ok()); - EXPECT_EQ(status_or.status().error_message(), std::get<4>(GetParam())); + EXPECT_EQ(status_or.status().message(), std::get<4>(GetParam())); } std::string TopologyWithMeshShape(llvm::ArrayRef mesh_shape) { @@ -310,7 +327,7 @@ TEST(TPURewriteDeviceUtilTest, device_assignment_attr); ASSERT_FALSE(status_or.ok()); - EXPECT_EQ(status_or.status().error_message(), + EXPECT_EQ(status_or.status().message(), "no TPU device found for 'device_assignment' device coordinate (1, " "0, 0, 0)"); } @@ -622,7 +639,7 @@ TEST(TPURewriteDeviceUtilTest, TestInvalidAttrForDeviceAssignmentDisallowed) { auto status_or_device_coodinates = GetDeviceCoordinates(device_assignment_attr); ASSERT_TRUE(!status_or_device_coodinates.ok()); - EXPECT_EQ(status_or_device_coodinates.status().error_message(), + EXPECT_EQ(status_or_device_coodinates.status().message(), "bad 'device_assignment' attribute at index 0, not an int"); } @@ -830,7 +847,7 @@ TEST(TPURewriteDeviceUtilTest, TestGetHostDeviceTPUReplicate) { std::string host_device; EXPECT_TRUE(mlir::succeeded( GetHostDeviceOutsideComputation(runtime_devices, cluster, &host_device))); - EXPECT_EQ(host_device, kTPUReplicatedHost); + EXPECT_EQ(host_device, GetDeviceAliasForHostOfLogicalCore(0)); } TEST(TPURewriteDeviceUtilTest, TestGetHostDeviceNotReplicated) { @@ -917,5 +934,124 @@ TEST(TPURewriteDeviceUtilTest, TestIsTPUDevice) { EXPECT_FALSE(IsTPUDevice("INVALID_DEVICE")); } +TEST(TPURewriteDeviceUtilTest, TestDeviceToHostMapBadTopology) { + static const char* const module_str = + R"( +module attributes {tf.devices = {"/job:localhost/replica:0/task:0/device:CPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", 
"/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0"}} { + func.func @main() -> () { + "tf_device.cluster"() ({ + tf_device.return + }) {device_assignment = [0, 0, 0, 0, 0, 0, 0, 1], num_cores_per_replica = 2 : i64} : () -> () + func.return + } +})"; + mlir::MLIRContext context; + TF_ASSERT_OK_AND_ASSIGN(mlir::OwningOpRef module, + GetMlirModuleFromString(module_str, &context)); + mlir::tf_device::ClusterOp cluster; + module->walk( + [&](mlir::tf_device::ClusterOp descendant) { cluster = descendant; }); + llvm::SmallVector core_to_host; + EXPECT_TRUE(mlir::failed(GetDeviceToHostMap(cluster, core_to_host))); +} + +TEST(TPURewriteDeviceUtilTest, TestDeviceToHostMapBadDeviceAssignment) { + static const char* const module_str = + R"( +module attributes {tf.devices = {"/job:localhost/replica:0/task:0/device:CPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0"}} { + func.func @main() -> () { + "tf_device.cluster"() ({ + tf_device.return + }) {num_cores_per_replica = 2 : i64, topology = "\0A\04\01\01\01\02\10\01\18\02\22\08\00\00\00\00\00\00\00\01*\02\08\01"} : () -> () + func.return + } +})"; + mlir::MLIRContext context; + TF_ASSERT_OK_AND_ASSIGN(mlir::OwningOpRef module, + GetMlirModuleFromString(module_str, &context)); + mlir::tf_device::ClusterOp cluster; + module->walk( + [&](mlir::tf_device::ClusterOp descendant) { cluster = descendant; }); + llvm::SmallVector core_to_host; + EXPECT_TRUE(mlir::failed(GetDeviceToHostMap(cluster, core_to_host))); +} + +// Tests `GetDeviceToHostMap` on a non-replicated TPU cluster. +TEST(TPURewriteDeviceUtilTest, TestDeviceToHostMapNotReplicated) { + static const char* const module_str = + R"( +module attributes {tf.devices = {"/job:localhost/replica:0/task:0/device:CPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0"}} { + func.func @main() -> () { + "tf_device.cluster"() ({ + tf_device.return + }) {device_assignment = [0, 0, 0, 0, 0, 0, 0, 1], num_cores_per_replica = 2 : i64, topology = "\0A\04\01\01\01\02\10\01\18\02\22\08\00\00\00\00\00\00\00\01*\02\08\01"} : () -> () + func.return + } +})"; + mlir::MLIRContext context; + TF_ASSERT_OK_AND_ASSIGN(mlir::OwningOpRef module, + GetMlirModuleFromString(module_str, &context)); + mlir::tf_device::ClusterOp cluster; + module->walk( + [&](mlir::tf_device::ClusterOp descendant) { cluster = descendant; }); + llvm::SmallVector core_to_host; + EXPECT_TRUE(mlir::succeeded(GetDeviceToHostMap(cluster, core_to_host))); + EXPECT_EQ(core_to_host.size(), 2); + EXPECT_EQ(core_to_host[0], "/job:localhost/replica:0/task:0/device:CPU:0"); + EXPECT_EQ(core_to_host[1], "/job:localhost/replica:0/task:0/device:CPU:0"); +} + +// Tests `GetDeviceToHostMap` on a replicated TPU cluster. 
+TEST(TPURewriteDeviceUtilTest, TestDeviceToHostMapReplicated) { + static const char* const module_str = + R"( +module attributes {tf.devices = {"/job:localhost/replica:0/task:0/device:CPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:0/device:TPU:2", "/job:localhost/replica:0/task:0/device:TPU:3", "/job:localhost/replica:0/task:0/device:TPU:4", "/job:localhost/replica:0/task:0/device:TPU:5", "/job:localhost/replica:0/task:0/device:TPU:6", "/job:localhost/replica:0/task:0/device:TPU:7", "/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0"}} { + func.func @main() -> () { + tf_device.replicate() {n = 4 : i32} { + "tf_device.cluster"() ({ + tf_device.return + }) {device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1], num_cores_per_replica = 2 : i64, topology = "\0A\04\02\02\01\02\10\01\18\08\22 \00\00\00\00\00\00\00\01\01\00\00\00\01\00\00\01\00\01\00\00\00\01\00\01\01\01\00\00\01\01\00\01*\02\08\01"} : () -> () + tf_device.return + } + func.return + } +})"; + mlir::MLIRContext context; + TF_ASSERT_OK_AND_ASSIGN(mlir::OwningOpRef module, + GetMlirModuleFromString(module_str, &context)); + mlir::tf_device::ClusterOp cluster; + module->walk( + [&](mlir::tf_device::ClusterOp descendant) { cluster = descendant; }); + llvm::SmallVector core_to_host; + EXPECT_TRUE(mlir::succeeded(GetDeviceToHostMap(cluster, core_to_host))); + EXPECT_EQ(core_to_host.size(), 2); + EXPECT_EQ(core_to_host[0], "TPU_REPLICATED_HOST_0"); + EXPECT_EQ(core_to_host[1], "TPU_REPLICATED_HOST_1"); +} + +// Tests `GetDeviceToHostMap` on a CPU cluster. +TEST(TPURewriteDeviceUtilTest, TestDeviceToHostMapCPU) { + static const char* const module_str = + R"( +module attributes {tf.devices = {"/job:localhost/replica:0/task:0/device:CPU:0"}} { + func.func @main() -> () { + "tf_device.cluster"() ({ + tf_device.return + }) {} : () -> () + func.return + } +})"; + mlir::MLIRContext context; + TF_ASSERT_OK_AND_ASSIGN(mlir::OwningOpRef module, + GetMlirModuleFromString(module_str, &context)); + mlir::tf_device::ClusterOp cluster; + module->walk( + [&](mlir::tf_device::ClusterOp descendant) { cluster = descendant; }); + llvm::SmallVector core_to_host; + EXPECT_TRUE(mlir::succeeded(GetDeviceToHostMap(cluster, core_to_host))); + EXPECT_EQ(core_to_host.size(), 1); + EXPECT_EQ(core_to_host[0], "/job:localhost/replica:0/task:0/device:CPU:0"); +} + } // anonymous namespace } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.cc index 2a6d94828a3..e55ba55caf9 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.h" #include +#include #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/DenseMap.h" @@ -40,6 +41,47 @@ namespace { constexpr char kNumSplitAttr[] = "num_split"; +// Gets the proper tensor dimension from XLA OpSharding. +// "replicate_on_last_tile_dim" and "last_tile_dims" should be deducted from the +// real Tensor dimensions when tiled. +// For example: +// f32[8,512](sharding={devices=[1,1,2]0,1 last_tile_dims={REPLICATED}) +// also means a replicated tensor over all devices. +// +// See xla_data.proto for detailed explanations on the fields. 
+int GetDimsFromXLAShardingTiled(const xla::OpSharding& xla_sharding) { + return xla_sharding.tile_assignment_dimensions_size() - + (xla_sharding.replicate_on_last_tile_dim() ? 1 : 0) - + xla_sharding.last_tile_dims_size(); +} + +// A sharding with OTHER type may be REPLICATED if: +// 'replicate_on_last_tile_dim' is true OR +// 'last_tile_dims' is not empty +// AND +// other than replicated last tile dims, all other dims are not sharded. +bool IsOtherReplicatedSharding(const xla::OpSharding& xla_sharding) { + int max_dim = GetDimsFromXLAShardingTiled(xla_sharding); + for (int i = 0; i < max_dim; ++i) { + if (xla_sharding.tile_assignment_dimensions(i) != 1) { + return false; + } + } + return xla_sharding.type() == xla::OpSharding::OTHER && + (xla_sharding.replicate_on_last_tile_dim() || + !xla_sharding.last_tile_dims().empty()); +} + +bool IsSplitSharding(const xla::OpSharding& sharding) { + return sharding.type() == xla::OpSharding::OTHER && + !IsOtherReplicatedSharding(sharding); +} + +bool IsReplicatedSharding(const xla::OpSharding& sharding) { + return sharding.type() == xla::OpSharding::REPLICATED || + IsOtherReplicatedSharding(sharding); +} + // Creates a tf::SplitOp that splits 'src_input' into 'num_splits' ways // in 'split_dimension' dimension and returns the split values. mlir::LogicalResult CreateSplitOp(const int num_split, @@ -147,7 +189,7 @@ mlir::LogicalResult HandleTileShardedInputs( // Split nodes at ith depth from the original input node represent nodes // that split the input data at i-th dimension. const auto& dimension_splits = input_sharding.tile_assignment_dimensions(); - for (auto num_splits_and_index : llvm::enumerate(dimension_splits)) { + for (const auto& num_splits_and_index : llvm::enumerate(dimension_splits)) { const int num_splits = num_splits_and_index.value(); const int dimension_index = num_splits_and_index.index(); if (num_splits == 1) continue; @@ -256,7 +298,7 @@ mlir::LogicalResult ExtractInputsForLogicalDevices( << input_index << "-th input"; if (input_sharding_type == xla::OpSharding::REPLICATED) { - for (auto& index_and_inputs : llvm::enumerate(*input_list)) { + for (const auto& index_and_inputs : llvm::enumerate(*input_list)) { index_and_inputs.value().emplace_back( partitioned_input.getOperand(index_and_inputs.index())); } @@ -276,7 +318,7 @@ mlir::LogicalResult ExtractInputsForLogicalDevices( continue; } - if (input_sharding_type == xla::OpSharding::OTHER) { + if (IsSplitSharding(sharding)) { llvm::SmallVector tiled_inputs; auto result = HandleTileShardedInputs( cluster_func.getLoc(), sharding, input_value, builder, &tiled_inputs); @@ -290,7 +332,7 @@ mlir::LogicalResult ExtractInputsForLogicalDevices( const int assigned_logical_device = sharding.tile_assignment_devices(i); (*input_list)[assigned_logical_device].emplace_back(tiled_inputs[i]); } - } else if (input_sharding_type == xla::OpSharding::REPLICATED) { + } else if (IsReplicatedSharding(sharding)) { for (auto& inputs : *input_list) inputs.emplace_back(input_value); } else { assert(input_sharding_type == xla::OpSharding::MAXIMAL); @@ -317,7 +359,7 @@ mlir::LogicalResult ParseAndValidateOutputSharding( if (output_sharding_attrs.size() != cluster_func.getNumResults()) return cluster_func.emitError("incorrect number of output sharding"); - for (auto output_sharding_and_index : + for (const auto& output_sharding_and_index : llvm::enumerate(output_sharding_attrs)) { const auto& output_sharding = output_sharding_and_index.value(); const int sharding_index = output_sharding_and_index.index(); @@ 
-472,7 +514,7 @@ mlir::LogicalResult ValidateAndGetTiledExecuteOutputShape( mlir::Type* tiled_logical_computation_type) { auto new_output_shape = llvm::to_vector<4>(cluster_func_output_type.getShape()); - for (auto dimension_and_output_splits : + for (const auto& dimension_and_output_splits : llvm::enumerate(output_sharding.tile_assignment_dimensions())) { const auto dimension_index = dimension_and_output_splits.index(); const auto output_splits = dimension_and_output_splits.value(); @@ -515,17 +557,17 @@ mlir::LogicalResult GetOutputTypesForLogicalDeviceComputation( output_types->reserve(cluster_func.getNumResults()); int core_index = 0; - for (auto result_and_index : llvm::enumerate(cluster_func.getResults())) { + for (const auto& result_and_index : + llvm::enumerate(cluster_func.getResults())) { const auto output_index = result_and_index.index(); const auto& output_sharding = output_sharding_config[output_index]; - const auto output_sharding_type = output_sharding.type(); const auto cluster_func_output_type = result_and_index.value().getType().cast(); // If output shape of cluster func is statically known and output is tiled // sharded, then the corresponding output shape of cluster func must be // evenly divisible number of shardings. - if (output_sharding_type == xla::OpSharding::OTHER) { + if (IsSplitSharding(output_sharding)) { mlir::Type tiled_logical_computation_type; if (cluster_func_output_type.hasRank()) { auto result = ValidateAndGetTiledExecuteOutputShape( @@ -537,7 +579,7 @@ mlir::LogicalResult GetOutputTypesForLogicalDeviceComputation( } cluster_to_core_index->emplace_back(core_index++); output_types->emplace_back(tiled_logical_computation_type); - } else if (output_sharding_type == xla::OpSharding::REPLICATED || + } else if (IsReplicatedSharding(output_sharding) || IsAssignedToLogicalDevice(core_id, output_sharding)) { cluster_to_core_index->emplace_back(core_index++); output_types->emplace_back(cluster_func_output_type); @@ -557,7 +599,7 @@ mlir::LogicalResult RemapOutputsFromLogicalDevices( mlir::tf_device::ParallelExecuteOp old_parallel_execute, int cluster_idx, mlir::tf_device::ParallelExecuteOp new_parallel_execute, mlir::OpBuilder* builder) { - for (auto& result_and_index : + for (const auto& result_and_index : llvm::enumerate(old_parallel_execute.getResults())) { const auto output_index = result_and_index.index(); const auto old_parallel_execute_output = result_and_index.value(); @@ -605,10 +647,11 @@ mlir::LogicalResult RemapOutputsFromLogicalDevices( if (output_sharding_type == xla::OpSharding::REPLICATED) { for (const auto& index_and_output : llvm::enumerate(partitioned_output.getOutput())) { + auto idx = (cluster_idx + index_and_output.index()) % + new_parallel_execute->getNumRegions(); const auto output_from_logical_device = new_parallel_execute.GetRegionOutputs( - cluster_idx + - index_and_output.index())[tpu_cluster_output_index]; + idx)[tpu_cluster_output_index]; index_and_output.value().replaceAllUsesWith( output_from_logical_device); } @@ -627,7 +670,7 @@ mlir::LogicalResult RemapOutputsFromLogicalDevices( continue; } - if (output_sharding_type == xla::OpSharding::OTHER) { + if (IsSplitSharding(output_sharding)) { if (failed(HandleTileShardedOutputs( tpu_cluster_output_index, output_sharding_config, cluster_to_core_index, location, old_parallel_execute_output, diff --git a/tensorflow/compiler/mlir/tf2xla/BUILD b/tensorflow/compiler/mlir/tf2xla/BUILD index 963bd8cfaa2..5605cbed225 100644 --- a/tensorflow/compiler/mlir/tf2xla/BUILD +++ 
b/tensorflow/compiler/mlir/tf2xla/BUILD @@ -2,9 +2,13 @@ # TF2XLA Bridge and related components. load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") -load("//tensorflow:tensorflow.bzl", "tf_cc_test") -load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library") -load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_cloud") + +package_group( + name = "tensorflow_mlir_tf2xla", + packages = [ + "//tensorflow/compiler/mlir/tf2xla/...", + ], +) package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -26,427 +30,7 @@ cc_library( ], ) -cc_library( - name = "compile_mlir_util_no_tf_dialect_passes", - srcs = ["api/v0/compile_mlir_util.cc"], - hdrs = ["api/v0/compile_mlir_util.h"], - deps = [ - "//tensorflow/compiler/mlir/tensorflow", - "//tensorflow/compiler/mlir/tensorflow:bridge_logger", - "//tensorflow/compiler/mlir/tensorflow:convert_tensor", - "//tensorflow/compiler/mlir/tensorflow:convert_type", - "//tensorflow/compiler/mlir/tensorflow:dump_mlir_util", - "//tensorflow/compiler/mlir/tensorflow:dynamic_shape_utils", - "//tensorflow/compiler/mlir/tensorflow:error_util", - "//tensorflow/compiler/mlir/tensorflow:export_graphdef", - "//tensorflow/compiler/mlir/tensorflow:import_model", - "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags", - "//tensorflow/compiler/mlir/tensorflow:serialize_mlir_module_utils", - "//tensorflow/compiler/mlir/tensorflow:shape_inference_pass", - "//tensorflow/compiler/mlir/tensorflow:tensorflow_passes", - "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", - "//tensorflow/compiler/mlir/tensorflow:translate_utils", - "//tensorflow/compiler/mlir/tf2xla:mlir_bridge_rollout_policy", - "//tensorflow/compiler/mlir/tf2xla:tf_xla_passes", - "//tensorflow/compiler/mlir/tf2xla:xla_legalize_targets", - "//tensorflow/compiler/mlir/tf2xla:xla_legalize_tf", - "//tensorflow/compiler/mlir/tf2xla:xla_legalize_tf_with_tf2xla", - "//tensorflow/compiler/tf2xla:common", - "//tensorflow/compiler/tf2xla:layout_util", - "//tensorflow/compiler/tf2xla:xla_argument", - "//tensorflow/compiler/tf2xla:xla_helpers", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:xla_data_proto_cc", - "//tensorflow/compiler/xla/client:xla_computation", - "//tensorflow/compiler/xla/hlo/ir:hlo", - "//tensorflow/compiler/xla/mlir/framework/transforms:passes", - "//tensorflow/compiler/xla/mlir_hlo", - "//tensorflow/compiler/xla/mlir_hlo:hlo_dialect_registration", - "//tensorflow/compiler/xla/mlir_hlo:mhlo_passes", - "//tensorflow/compiler/xla/translate/mhlo_to_hlo:layout_util", - "//tensorflow/compiler/xla/translate/mhlo_to_hlo:mlir_hlo_to_hlo", - "//tensorflow/compiler/xla/translate/mhlo_to_hlo:type_to_shape", - "//tensorflow/core:framework", - "//tensorflow/core:protos_all_cc", - "//tensorflow/core/common_runtime:core_cpu_internal", - "//tensorflow/core/platform:error_payloads", - "//tensorflow/core/platform:errors", - "//tensorflow/core/platform:logging", - "//tensorflow/core/tpu:tpu_defs", - "@com_google_absl//absl/synchronization", - "@com_google_absl//absl/types:optional", - "@com_google_absl//absl/types:variant", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:FuncDialect", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:Pass", - "@llvm-project//mlir:ShapeDialect", - "@llvm-project//mlir:Support", - "@llvm-project//mlir:TensorDialect", - "@llvm-project//mlir:TransformUtils", - "@llvm-project//mlir:Transforms", - "@stablehlo//:register", - ], -) - -tf_cc_test( - name = "compile_mlir_util_test", - srcs = 
["api/v0/compile_mlir_util_test.cc"], - deps = [ - ":compile_mlir_util", - "//tensorflow/compiler/mlir/tensorflow", - "//tensorflow/compiler/mlir/tensorflow:serialize_mlir_module_utils", - "//tensorflow/compiler/tf2xla:xla_helpers", - "//tensorflow/core:framework", - "@com_google_absl//absl/strings", - "@com_google_googletest//:gtest_main", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:Pass", - ], -) - alias( name = "compile_mlir_util", - actual = ":compile_mlir_util_no_tf_dialect_passes", -) - -gentbl_cc_library( - name = "legalize_tf_patterns_inc_gen", - compatible_with = get_compatible_with_cloud(), - tbl_outs = [ - ( - ["-gen-rewriters"], - "transforms/generated_legalize_tf.inc", - ), - ], - tblgen = "@llvm-project//mlir:mlir-tblgen", - td_file = "transforms/legalize_tf_patterns.td", - deps = [ - "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops_td_files", - "//tensorflow/compiler/xla/mlir_hlo:hlo_ops_td_files", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:FuncTdFiles", - "@llvm-project//mlir:TensorOpsTdFiles", - ], -) - -gentbl_cc_library( - name = "xla_legalize_tf_passes_inc_gen", - compatible_with = get_compatible_with_cloud(), - tbl_outs = [ - ( - [ - "-gen-pass-decls", - "-name=LegalizeTf", - ], - "transforms/xla_legalize_tf_passes.h.inc", - ), - ], - tblgen = "@llvm-project//mlir:mlir-tblgen", - td_file = "transforms/xla_legalize_tf_passes.td", - deps = [ - "@llvm-project//mlir:PassBaseTdFiles", - ], -) - -gentbl_cc_library( - name = "tf_xla_passes_inc_gen", - compatible_with = get_compatible_with_cloud(), - tbl_outs = [ - ( - [ - "-gen-pass-decls", - "-name=TfXla", - ], - "transforms/tf_xla_passes.h.inc", - ), - ], - tblgen = "@llvm-project//mlir:mlir-tblgen", - td_file = "transforms/tf_xla_passes.td", - deps = [ - "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops_td_files", - "//tensorflow/compiler/xla/mlir_hlo:hlo_ops_td_files", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:FuncTdFiles", - "@llvm-project//mlir:PassBaseTdFiles", - "@llvm-project//mlir:SparseTensorDialect", - "@llvm-project//mlir:TensorOpsTdFiles", - ], -) - -cc_library( - name = "tf_xla_passes", - srcs = [ - "transforms/xla_legalize_tf_passes.h.inc", - ], - hdrs = [ - "transforms/passes.h", - ], - deps = [ - ":tf_xla_passes_inc_gen", - ":xla_legalize_tf", - "//tensorflow/compiler/xla/mlir_hlo", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:FuncDialect", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:MemRefDialect", - "@llvm-project//mlir:Pass", - "@llvm-project//mlir:SparseTensorDialect", - "@llvm-project//mlir:Support", - "@llvm-project//mlir:TransformUtils", - ], -) - -cc_library( - name = "legalize_utils", - srcs = ["transforms/utils.cc"], - hdrs = ["transforms/utils.h"], - deps = [ - "//tensorflow/compiler/xla/mlir_hlo", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:IR", - ], -) - -cc_library( - name = "legalize_tf", - srcs = [ - "transforms/generated_legalize_tf.inc", - "transforms/legalize_tf.cc", - ], - hdrs = [ - "transforms/passes.h", - ], - deps = [ - ":legalize_tf_patterns_inc_gen", - ":legalize_utils", - ":tf_xla_passes_inc_gen", - ":xla_legalize_tf_passes_inc_gen", - "//tensorflow/compiler/mlir/tensorflow", - "//tensorflow/compiler/mlir/tensorflow:dynamic_shape_utils", - "//tensorflow/compiler/xla:xla_data_proto_cc", - "//tensorflow/compiler/xla/client:padding", - "//tensorflow/compiler/xla/client:sharding_builder", - "//tensorflow/compiler/xla/client/lib:conv_grad_size_util", - "//tensorflow/compiler/xla/mlir_hlo", - 
"//tensorflow/compiler/xla/mlir_hlo:convert_op_folder", - "//tensorflow/compiler/xla/translate/hlo_to_mhlo:attribute_importer", - "//tensorflow/core:framework", - "//tensorflow/core/kernels:conv_grad_shape_utils", - "//tensorflow/tsl/platform:bfloat16", - "//tensorflow/tsl/platform:status", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:ArithDialect", - "@llvm-project//mlir:Dialect", - "@llvm-project//mlir:FuncDialect", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:MemRefDialect", - "@llvm-project//mlir:Pass", - "@llvm-project//mlir:ShapeDialect", - "@llvm-project//mlir:Support", - "@llvm-project//mlir:TensorDialect", - "@llvm-project//mlir:TransformUtils", - "@stablehlo//:chlo_ops", - ], -) - -cc_library( - name = "xla_legalize_targets", - srcs = [ - "transforms/xla_legalize_targets.cc", - ], - hdrs = [ - "transforms/xla_legalize_targets.h", - ], - deps = [ - "//tensorflow/compiler/mlir/tensorflow", - "//tensorflow/compiler/xla/mlir_hlo", - "@llvm-project//mlir:ArithDialect", - "@llvm-project//mlir:FuncDialect", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:ShapeDialect", - "@llvm-project//mlir:TensorDialect", - "@llvm-project//mlir:TransformUtils", - "@stablehlo//:chlo_ops", - ], -) - -tf_cc_test( - name = "xla_legalize_targets_test", - srcs = ["transforms/xla_legalize_targets_test.cc"], - deps = [ - ":xla_legalize_targets", - "//tensorflow/compiler/mlir/tensorflow", - "@com_google_googletest//:gtest_main", - "@llvm-project//mlir:ArithDialect", - "@llvm-project//mlir:FuncDialect", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:ShapeDialect", - "@llvm-project//mlir:TensorDialect", - "@llvm-project//mlir:TransformUtils", - "@stablehlo//:chlo_ops", - ], -) - -tf_cc_test( - name = "verify_tfxla_legalization_test", - srcs = ["transforms/verify_tfxla_legalization_test.cc"], - deps = [ - ":legalize_tf", - "//tensorflow/compiler/mlir/tensorflow", - "//tensorflow/compiler/mlir/tensorflow:serialize_mlir_module_utils", - "//tensorflow/compiler/mlir/tensorflow:tensorflow_passes", - "//tensorflow/core/lib/monitoring:cell_reader", - "//tensorflow/core/platform:errors", - "//tensorflow/tsl/lib/core:status_test_util", - "//tensorflow/tsl/platform:statusor", - "@com_google_absl//absl/strings", - "@com_google_googletest//:gtest_main", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:FuncDialect", - "@llvm-project//mlir:Pass", - ], -) - -cc_library( - name = "xla_legalize_tf", - srcs = [ - "transforms/convert_mhlo_quant_to_int.cc", - "transforms/infeed_ops_xla_adjust_layout.cc", - "transforms/legalize_tf_collective.cc", - "transforms/legalize_tf_communication.cc", - "transforms/legalize_tf_types.cc", - "transforms/tf_xla_passes.h.inc", - "transforms/tfxla_device_specific_transforms.cc", - "transforms/verify_tfxla_legalization.cc", - "transforms/xla_legalize_tf.cc", - "transforms/xla_legalize_tf_passes.h.inc", - ], - hdrs = [ - "transforms/passes.h", - ], - deps = [ - ":legalize_tf", - ":legalize_utils", - ":xla_legalize_targets", - ":xla_legalize_tf_no_fallback", - ":xla_legalize_tf_passes_inc_gen", - ":xla_legalize_tf_with_tf2xla", - "//tensorflow/compiler/mlir/tensorflow", - "//tensorflow/compiler/mlir/tensorflow:lower_tf_lib", - "//tensorflow/compiler/mlir/tensorflow:mangling_util", - "//tensorflow/compiler/mlir/tensorflow:set_tpu_infeed_layout", - "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", - "//tensorflow/compiler/tf2xla/kernels:rng_converter_utils", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:side_effect_util", - 
"//tensorflow/compiler/xla:xla_data_proto_cc", - "//tensorflow/compiler/xla/client:padding", - "//tensorflow/compiler/xla/client:sharding_builder", - "//tensorflow/compiler/xla/mlir_hlo", - "//tensorflow/compiler/xla/mlir_hlo:chlo_legalize_to_hlo", - "//tensorflow/compiler/xla/mlir_hlo:convert_op_folder", - "//tensorflow/compiler/xla/stream_executor/tpu:c_api_conversions", - "//tensorflow/compiler/xla/stream_executor/tpu:tpu_api", - "//tensorflow/compiler/xla/translate/hlo_to_mhlo:attribute_importer", - "//tensorflow/compiler/xla/translate/mhlo_to_hlo:type_to_shape", - "//tensorflow/core:framework", - "//tensorflow/core:lib", - "//tensorflow/core:protos_all_cc", - "//tensorflow/core/util/quantization:uniform_quant_ops_params", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:span", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:ArithDialect", - "@llvm-project//mlir:FuncDialect", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:MemRefDialect", - "@llvm-project//mlir:Pass", - "@llvm-project//mlir:QuantOps", - "@llvm-project//mlir:ShapeDialect", - "@llvm-project//mlir:SparseTensorDialect", - "@llvm-project//mlir:Support", - "@llvm-project//mlir:TensorDialect", - "@llvm-project//mlir:Transforms", - "@stablehlo//:chlo_ops", - ], -) - -cc_library( - name = "xla_legalize_tf_no_fallback", - srcs = [ - "transforms/xla_legalize_tf_no_fallback.cc", - "transforms/xla_legalize_tf_passes.h.inc", - ], - hdrs = [ - "transforms/passes.h", - ], - deps = [ - ":legalize_tf", - ":tf_xla_passes_inc_gen", - ":xla_legalize_tf_passes_inc_gen", - "//tensorflow/compiler/mlir/tensorflow", - "//tensorflow/compiler/mlir/tensorflow:lower_tf_lib", - "//tensorflow/compiler/xla/mlir_hlo", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:ArithDialect", - "@llvm-project//mlir:FuncDialect", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:MemRefDialect", - "@llvm-project//mlir:Pass", - "@llvm-project//mlir:ShapeDialect", - "@llvm-project//mlir:SparseTensorDialect", - "@llvm-project//mlir:Support", - "@llvm-project//mlir:TensorDialect", - "@llvm-project//mlir:Transforms", - "@stablehlo//:chlo_ops", - ], -) - -cc_library( - name = "xla_legalize_tf_with_tf2xla", - srcs = [ - "transforms/legalize_tf_with_tf2xla.cc", - ], - hdrs = [ - "transforms/passes.h", - ], - deps = [ - ":tf_xla_passes_inc_gen", - ":xla_legalize_tf_passes_inc_gen", - "//tensorflow/compiler/mlir:op_or_arg_name_mapper", - "//tensorflow/compiler/mlir/tensorflow", - "//tensorflow/compiler/mlir/tensorflow:convert_tensor", - "//tensorflow/compiler/mlir/tensorflow:convert_type", - "//tensorflow/compiler/mlir/tensorflow:export_tf_dialect_op", - "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops", - "//tensorflow/compiler/mlir/tensorflow:tpu_embedding_ops_registry", - "//tensorflow/compiler/mlir/tensorflow:translate_utils", - "//tensorflow/compiler/tf2xla:xla_compilation_device", - "//tensorflow/compiler/tf2xla:xla_context", - "//tensorflow/compiler/tf2xla:xla_expression", - "//tensorflow/compiler/tf2xla:xla_helpers", - "//tensorflow/compiler/tf2xla:xla_op_registry", - "//tensorflow/compiler/xla/client:xla_builder", - "//tensorflow/compiler/xla/mlir_hlo", - "//tensorflow/compiler/xla/stream_executor:timer", - "//tensorflow/compiler/xla/translate/hlo_to_mhlo:mlir_hlo_builder", - "//tensorflow/core:core_cpu_lib", - "//tensorflow/core:framework", - "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", - "//tensorflow/core:protos_all_cc", - "//tensorflow/core:session_options", - 
"@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/memory", - "@com_google_absl//absl/strings", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:FuncDialect", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:Pass", - "@llvm-project//mlir:SparseTensorDialect", - "@llvm-project//mlir:Support", - "@llvm-project//mlir:TensorDialect", - "@llvm-project//mlir:TransformUtils", - ], + actual = "//tensorflow/compiler/mlir/tf2xla/api/v0:compile_mlir_util_no_tf_dialect_passes", ) diff --git a/tensorflow/compiler/mlir/tf2xla/api/v0/BUILD b/tensorflow/compiler/mlir/tf2xla/api/v0/BUILD new file mode 100644 index 00000000000..18744b3032f --- /dev/null +++ b/tensorflow/compiler/mlir/tf2xla/api/v0/BUILD @@ -0,0 +1,138 @@ +load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") +load("//tensorflow:tensorflow.bzl", "tf_cc_test") + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = ["//visibility:public"], +) + +cc_library( + name = "compile_mlir_util_no_tf_dialect_passes", + srcs = ["compile_mlir_util.cc"], + hdrs = ["compile_mlir_util.h"], + deps = [ + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow:bridge_logger", + "//tensorflow/compiler/mlir/tensorflow:convert_tensor", + "//tensorflow/compiler/mlir/tensorflow:convert_type", + "//tensorflow/compiler/mlir/tensorflow:dump_mlir_util", + "//tensorflow/compiler/mlir/tensorflow:dynamic_shape_utils", + "//tensorflow/compiler/mlir/tensorflow:error_util", + "//tensorflow/compiler/mlir/tensorflow:export_graphdef", + "//tensorflow/compiler/mlir/tensorflow:import_model", + "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags", + "//tensorflow/compiler/mlir/tensorflow:serialize_mlir_module_utils", + "//tensorflow/compiler/mlir/tensorflow:shape_inference_pass", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_passes", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", + "//tensorflow/compiler/mlir/tensorflow:translate_utils", + "//tensorflow/compiler/mlir/tf2xla:mlir_bridge_rollout_policy", + "//tensorflow/compiler/mlir/tf2xla/transforms:tf_xla_passes", + "//tensorflow/compiler/mlir/tf2xla/transforms:xla_legalize_targets", + "//tensorflow/compiler/mlir/tf2xla/transforms:xla_legalize_tf", + "//tensorflow/compiler/mlir/tf2xla/transforms:xla_legalize_tf_with_tf2xla", + "//tensorflow/compiler/tf2xla:common", + "//tensorflow/compiler/tf2xla:layout_util", + "//tensorflow/compiler/tf2xla:xla_argument", + "//tensorflow/compiler/tf2xla:xla_helpers", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:xla_data_proto_cc", + "//tensorflow/compiler/xla/client:xla_computation", + "//tensorflow/compiler/xla/hlo/ir:hlo", + "//tensorflow/compiler/xla/mlir/framework/transforms:passes", + "//tensorflow/compiler/xla/mlir_hlo", + "//tensorflow/compiler/xla/mlir_hlo:hlo_dialect_registration", + "//tensorflow/compiler/xla/mlir_hlo:mhlo_passes", + "//tensorflow/compiler/xla/translate/mhlo_to_hlo:layout_util", + "//tensorflow/compiler/xla/translate/mhlo_to_hlo:mlir_hlo_to_hlo", + "//tensorflow/compiler/xla/translate/mhlo_to_hlo:type_to_shape", + "//tensorflow/core:framework", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/common_runtime:core_cpu_internal", + "//tensorflow/core/platform:error_payloads", + "//tensorflow/core/platform:errors", + "//tensorflow/core/platform:logging", + "//tensorflow/core/tpu:tpu_defs", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + 
"@llvm-project//mlir:Pass", + "@llvm-project//mlir:ShapeDialect", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TensorDialect", + "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:Transforms", + "@stablehlo//:register", + ], +) + +tf_cc_test( + name = "compile_mlir_util_test", + srcs = ["compile_mlir_util_test.cc"], + deps = [ + ":compile_mlir_util_no_tf_dialect_passes", + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow:serialize_mlir_module_utils", + "//tensorflow/compiler/tf2xla:xla_helpers", + "//tensorflow/core:framework", + "//tensorflow/core/lib/monitoring:cell_reader", + "@com_google_absl//absl/strings", + "@com_google_googletest//:gtest_main", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:Pass", + ], +) + +cc_library( + name = "compile_tf_graph", + srcs = ["compile_tf_graph.cc"], + hdrs = ["compile_tf_graph.h"], + deps = [ + ":compile_mlir_util_no_tf_dialect_passes", + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow:dump_mlir_util", + "//tensorflow/compiler/mlir/tensorflow:error_util", + "//tensorflow/compiler/mlir/tensorflow:import_model", + "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags", + "//tensorflow/compiler/mlir/tensorflow:serialize_mlir_module_utils", + "//tensorflow/compiler/mlir/tensorflow:set_tpu_infeed_layout", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_passes", + "//tensorflow/compiler/mlir/tensorflow:translate_utils", + "//tensorflow/compiler/tf2xla:layout_util", + "//tensorflow/compiler/tf2xla:xla_helpers", + "//tensorflow/compiler/xla/client:compile_only_client", + "//tensorflow/compiler/xla/pjrt:compile_options_proto_cc", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/tpu:tpu_compile", + "//tensorflow/core/tpu/kernels:tpu_compile_op_support", + "//tensorflow/core/tpu/kernels:tpu_compile_proto_cc", + "//tensorflow/core/tpu/kernels:tpu_util", + "//tensorflow/tsl/platform:status", + "//tensorflow/tsl/platform:statusor", + "@com_google_absl//absl/status", + "@com_google_absl//absl/types:variant", + "@llvm-project//mlir:IR", + ], +) + +tf_cc_test( + name = "compile_tf_graph_test", + testonly = 1, + srcs = ["compile_tf_graph_test.cc"], + linkstatic = 1, + deps = [ + ":compile_tf_graph", + "//tensorflow/compiler/jit", + "//tensorflow/compiler/xla/client:client_library", + "//tensorflow/core:framework", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core/lib/monitoring:cell_reader", + "//tensorflow/core/protobuf/tpu:compile_metadata_proto_cc", + "//tensorflow/core/tpu/kernels:tpu_compile_op_support", + "//tensorflow/tsl/lib/core:status_test_util", + "//tensorflow/tsl/lib/monitoring:test_utils", + ], +) diff --git a/tensorflow/compiler/mlir/tf2xla/api/v0/compile_mlir_util.cc b/tensorflow/compiler/mlir/tf2xla/api/v0/compile_mlir_util.cc index c800a6fce7a..19c148214e1 100644 --- a/tensorflow/compiler/mlir/tf2xla/api/v0/compile_mlir_util.cc +++ b/tensorflow/compiler/mlir/tf2xla/api/v0/compile_mlir_util.cc @@ -16,11 +16,9 @@ limitations under the License. #include "tensorflow/compiler/mlir/tf2xla/api/v0/compile_mlir_util.h" #include +#include #include "tensorflow/compiler/mlir/tf2xla/mlir_bridge_rollout_policy.h" -#include "absl/synchronization/mutex.h" -#include "absl/types/optional.h" -#include "absl/types/variant.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/STLExtras.h" @@ -52,6 +50,7 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.h" #include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h" #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/data_dumper_logger_config.h" #include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h" #include "tensorflow/compiler/mlir/tensorflow/utils/dynamic_shape_utils.h" #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" @@ -78,10 +77,10 @@ limitations under the License. #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/protobuf/core_platform_payloads.pb.h" #include "tensorflow/core/tpu/tpu_defs.h" +#include "tensorflow/core/util/debug_data_dumper.h" namespace tensorflow { namespace { - constexpr absl::string_view kGroupSizeAttrName = "tf2xla.collective_info.group_size"; constexpr absl::string_view kGroupKeyAttrName = @@ -336,7 +335,7 @@ void AddLegalizationPasses(mlir::OpPassManager& pm, bool legalize_chlo, // in VerifyTFXLALegalization that full conversion happened. // TODO(b/188389290): Cleanup allow_partial_conversion as a legalization // parameter. - pm.addNestedPass(mlir::mhlo::createLegalizeTFPass( + pm.addPass(mlir::mhlo::createLegalizeTFPass( /*allow_partial_conversion=*/true, legalize_chlo, /*tf2xla_fallback_device_type=*/device_type, enable_op_fallback)); @@ -356,35 +355,6 @@ void AddLegalizationPasses(mlir::OpPassManager& pm, bool legalize_chlo, pm.addPass(mlir::TF::CreateTFShapeInferencePass()); } -// The default LLVM MLIR Inliner always runs canonicalization, however there -// is a bug where dumping the pass pipeline and recreating it in offline -// tools doesn't run canonicalization. To ensure prod and offline tools -// inlining are equal, explicitly create the Inliner with canonicalization so -// that the canonicalizer is dumped as part of pipeline passes. -// See https://github.com/llvm/llvm-project/issues/60960. -ABSL_CONST_INIT absl::Mutex pass_registration_lock(absl::kConstInit); -std::unique_ptr CreateInlinerWithCanonicalization() { - // This is really wonky. Pass Registration isn't thread safe in LLVM, so we - // need a mutex to guard pass registration. Pass registration also needs - // to happen once per thread, so make this thread local. - // TODO(b/268509024): Delete this whole function once the upstream LLVM issue - // is resolved. - static thread_local bool pass_registered = false; - if (!pass_registered) { - absl::MutexLock lock(&pass_registration_lock); - mlir::registerCanonicalizerPass(); - pass_registered = true; - } - - auto inliner = mlir::createInlinerPass(/*opPipelines=*/{}, - /*defaultPipelineBuilder=*/{}); - if (inliner->initializeOptions("default-pipeline=canonicalize").failed()) { - return nullptr; - } - - return inliner; -} - } // namespace void CreateConvertMlirToXlaHloPipeline( @@ -401,7 +371,7 @@ void CreateConvertMlirToXlaHloPipeline( // Note that the region-based control-flow produced here still contains // function call ops which get inlined by the subsequent inliner pass. pm.addPass(mlir::TF::CreateTFFunctionalControlFlowToRegions()); - pm.addPass(CreateInlinerWithCanonicalization()); + pm.addPass(mlir::createInlinerPass()); pm.addNestedPass( mlir::TF::CreateDropWhileShapeInvariantPass()); // Create a replicated TensorList initialization ops for all of its uses. 
This @@ -457,8 +427,6 @@ void CreateConvertMlirToXlaHloPipeline( pm.addNestedPass(mlir::TF::CreateLowerQuantizedPass()); pm.addPass(mlir::mhlo::CreateLegalizeTfTypesPass()); - pm.addPass(mlir::mhlo::createLegalizeTFModulePass( - /*tf2xla_fallback_device_type=*/device_type)); for (auto& target_pass : custom_legalization_passes) { pm.addNestedPass(std::move(target_pass)); @@ -481,7 +449,7 @@ void CreateConvertMlirToXlaHloPipeline( } if (CanInlineFunctionsPostLegalization(device_type)) { - pm.addPass(CreateInlinerWithCanonicalization()); + pm.addPass(mlir::createInlinerPass()); } // In order to export to XLA, we must sink constants to control flow regions, @@ -543,20 +511,34 @@ Status RefineShapes(llvm::ArrayRef arg_shapes, Status LegalizeToHlo(mlir::ModuleOp module_op, llvm::StringRef device_type, bool enable_op_fallback, llvm::MutableArrayRef> - custom_legalization_passes) { + custom_legalization_passes, + llvm::StringRef module_name = llvm::StringRef()) { mlir::PassManager tf2xla(module_op.getContext()); applyTensorflowAndCLOptions(tf2xla); CreateConvertMlirToXlaHloPipeline(tf2xla, device_type, enable_op_fallback, custom_legalization_passes); - if (VLOG_IS_ON(1)) - tensorflow::DumpMlirOpToFile("legalize_hlo_before", module_op, "", &tf2xla); - if (VLOG_IS_ON(2)) { + if (DEBUG_DATA_DUMPER()->ShouldDump(module_name.str(), kDebugGroupMain) || + VLOG_IS_ON(1)) { + tensorflow::DumpMlirOpToFile( + DEBUG_DATA_DUMPER()->GetDumpFilename(module_name.str(), kDebugGroupMain, + "legalize_hlo_before"), + module_op, "", &tf2xla); + } + + if (VLOG_IS_ON(2) || DEBUG_DATA_DUMPER()->ShouldDump( + module_name.str(), kDebugGroupBridgePhase2)) { // Print the whole module after each pass which requires disabling // multi-threading as well. module_op.getContext()->disableMultithreading(); - tf2xla.enableIRPrinting(std::make_unique( - /*print_module_scope=*/true)); + tf2xla.enableIRPrinting( + std::make_unique<::tensorflow::DataDumperLoggerConfig>( + [module_name](const std::string& pass_tag_name) { + return DEBUG_DATA_DUMPER()->GetDumpFilename( + module_name.str(), kDebugGroupBridgePhase2, pass_tag_name); + }, + "", + /*print_module_scope=*/true)); } // Make sure we catch any error reported by MLIR and forward it to the TF @@ -572,8 +554,14 @@ Status LegalizeToHlo(mlir::ModuleOp module_op, llvm::StringRef device_type, return error_handler.Combine(status); } - if (VLOG_IS_ON(1)) - tensorflow::DumpMlirOpToFile("legalize_hlo_after", module_op, "", &tf2xla); + if (DEBUG_DATA_DUMPER()->ShouldDump(module_name.str(), kDebugGroupMain) || + VLOG_IS_ON(1)) { + tensorflow::DumpMlirOpToFile( + DEBUG_DATA_DUMPER()->GetDumpFilename(module_name.str(), kDebugGroupMain, + "legalize_hlo_after"), + module_op, "", &tf2xla); + } + Status status = error_handler.ConsumeStatus(); tensorflow::OkOrSetErrorCounterPayload( tensorflow::core::platform::ErrorSourceProto::MLIR_BRIDGE_PHASE_2, @@ -602,9 +590,10 @@ Status ConvertMLIRToXlaComputation( bool enable_op_fallback, bool return_tuple, const XlaShapeLayoutHelpers::ShapeDeterminationFns shape_determination_fns, llvm::MutableArrayRef> - custom_legalization_passes) { + custom_legalization_passes, + llvm::StringRef module_name) { TF_RETURN_IF_ERROR(LegalizeToHlo(module_op, device_type, enable_op_fallback, - custom_legalization_passes)); + custom_legalization_passes, module_name)); mlir::MlirToHloConversionOptions options; options.layout_preference_fn = @@ -722,7 +711,8 @@ Status CompileMlirToXlaHlo( XlaShapeLayoutHelpers::ShapeDeterminationFns shape_determination_fns, XlaCompilationResult* 
compilation_result, llvm::MutableArrayRef> - custom_legalization_passes) { + custom_legalization_passes, + llvm::StringRef module_name) { if (enable_op_fallback && GetMlirBridge2ndPhaseRolloutPolicy(module_op) == MlirBridgeRolloutPolicy::kDisabledAfterGraphAnalysis) { @@ -736,7 +726,7 @@ Status CompileMlirToXlaHlo( TF_RETURN_IF_ERROR(ConvertMLIRToXlaComputation( module_op, device_type, compilation_result->computation.get(), use_tuple_args, enable_op_fallback, use_return_tuple, - shape_determination_fns, custom_legalization_passes)); + shape_determination_fns, custom_legalization_passes, module_name)); TF_RETURN_IF_ERROR(PopulateCollectiveInfo(module_op, compilation_result)); @@ -751,7 +741,8 @@ Status CompileSerializedMlirToXlaHlo( const XlaShapeLayoutHelpers::ShapeDeterminationFns shape_determination_fns, XlaCompilationResult* compilation_result, llvm::MutableArrayRef> - custom_legalization_passes) { + custom_legalization_passes, + llvm::StringRef module_name) { mlir::DialectRegistry mlir_registry; RegisterDialects(mlir_registry); mlir::MLIRContext mlir_context(mlir_registry); @@ -767,7 +758,7 @@ Status CompileSerializedMlirToXlaHlo( mlir_module.get(), tensor_or_resource_shapes, device_type, use_tuple_args, enable_op_fallback, /*use_return_tuple=*/true, /*use_resource_updates_for_aliases=*/false, shape_determination_fns, - compilation_result, custom_legalization_passes); + compilation_result, custom_legalization_passes, module_name); } // Rewrites the given module with specified args. For each of the constant args, diff --git a/tensorflow/compiler/mlir/tf2xla/api/v0/compile_mlir_util.h b/tensorflow/compiler/mlir/tf2xla/api/v0/compile_mlir_util.h index 19f1551382d..84cf70f0b3b 100644 --- a/tensorflow/compiler/mlir/tf2xla/api/v0/compile_mlir_util.h +++ b/tensorflow/compiler/mlir/tf2xla/api/v0/compile_mlir_util.h @@ -29,8 +29,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/framework/graph_debug_info.pb.h" #include "tensorflow/core/framework/tensor_shape.h" -#include "tensorflow/core/protobuf/graph_debug_info.pb.h" namespace tensorflow { @@ -72,7 +72,8 @@ Status ConvertMLIRToXlaComputation( const XlaShapeLayoutHelpers::ShapeDeterminationFns shape_determination_fns = {}, llvm::MutableArrayRef> - custom_legalization_passes = {}); + custom_legalization_passes = {}, + llvm::StringRef module_name = llvm::StringRef()); // Creates a MLIR pipeline that lowers MLIR module to MHLO dialect. The input // module should only contain operations in tf dialect. For example, if the @@ -144,7 +145,8 @@ Status CompileMlirToXlaHlo( XlaShapeLayoutHelpers::ShapeDeterminationFns shape_determination_fns, XlaCompilationResult* compilation_result, llvm::MutableArrayRef> - custom_legalization_passes); + custom_legalization_passes, + llvm::StringRef module_name = llvm::StringRef()); // Compiles a serialized MLIR module into XLA HLO, generates all accompanying // metadata and stores them in CompilationResult. 
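With module_name plumbed through as above, callers opt into the grouped IR dumps simply by passing a name. A minimal call-site sketch follows; serialized_mlir_module and "my_cluster" are placeholders, and the argument types follow the declaration added to the header above (with the TensorShape element type assumed, as in the existing tests).

XlaCompilationResult compilation_result;
std::vector<TensorShape> arg_shapes = {TensorShape({1})};
std::vector<std::unique_ptr<mlir::Pass>> custom_passes;  // none needed here

Status status = CompileSerializedMlirToXlaHlo(
    serialized_mlir_module, arg_shapes, /*device_type=*/"XLA_TPU_JIT",
    /*use_tuple_args=*/true, /*enable_op_fallback=*/false,
    /*shape_determination_fns=*/{}, &compilation_result, custom_passes,
    /*module_name=*/"my_cluster");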
@@ -154,7 +156,8 @@ Status CompileSerializedMlirToXlaHlo( const XlaShapeLayoutHelpers::ShapeDeterminationFns shape_determination_fns, XlaCompilationResult* compilation_result, llvm::MutableArrayRef> - custom_legalization_passes = {}); + custom_legalization_passes = {}, + llvm::StringRef module_name = llvm::StringRef()); // Compiles a TensorFlow Graph (already converted to MLIR, imported with // tf_executor dialect still present) into XLA HLO, generates all accompanying diff --git a/tensorflow/compiler/mlir/tf2xla/api/v0/compile_mlir_util_test.cc b/tensorflow/compiler/mlir/tf2xla/api/v0/compile_mlir_util_test.cc index 31dd9aeb551..d3158b5a917 100644 --- a/tensorflow/compiler/mlir/tf2xla/api/v0/compile_mlir_util_test.cc +++ b/tensorflow/compiler/mlir/tf2xla/api/v0/compile_mlir_util_test.cc @@ -27,11 +27,13 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/lib/monitoring/cell_reader.h" namespace tensorflow { namespace { using ::mlir::OpPassManager; +using ::tensorflow::monitoring::testing::CellReader; using ::testing::HasSubstr; static constexpr char kMlirModuleStr[] = R"( @@ -64,6 +66,8 @@ TEST(LegalizeMlirTest, FailsLegalizesModule) { func.return %0 : tensor<1xi32> } })"; + CellReader count( + "/tensorflow/core/tf2xla/v0/mlir_failed_xla_legalize_tf_pass_count"); std::vector arg_shapes; XlaCompilationResult compilation_result; @@ -73,6 +77,7 @@ TEST(LegalizeMlirTest, FailsLegalizesModule) { /*shape_determination_fns=*/{}, &compilation_result); EXPECT_FALSE(status.ok()); + EXPECT_EQ(count.Delta("tf.DoesntExist", "Unknown"), 1); } TEST(CompileMlirUtil, CreatesPipeline) { @@ -89,9 +94,7 @@ TEST(CompileMlirUtil, CreatesPipeline) { TEST(CompileMlirUtil, HasLegalizationPass) { OpPassManager pass_manager; llvm::StringRef device_type = "XLA_CPU_JIT"; - absl::string_view kLegalizeTfPass = - "xla-legalize-tf{allow-partial-conversion=false device-type=XLA_CPU_JIT " - "legalize-chlo=true prefer-tf2xla=true use-tf2xla-fallback=true})"; + absl::string_view kLegalizeTfPass = "xla-legalize-tf"; CreateConvertMlirToXlaHloPipeline(pass_manager, device_type, /*enable_op_fallback=*/true, @@ -121,5 +124,24 @@ TEST(CompileMlirUtil, CanonicalizationIsExplicitDuringInlining) { EXPECT_THAT(pass_description, HasSubstr(kInlinePass)); } +TEST(LegalizeMlirTest, LegalizesModuleWithDynamicShape) { + constexpr char legalization[] = R"( + module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 268 : i32}} { + func.func @main(%arg0: tensor>) -> tensor> { + %0 = "tf.Identity"(%arg0) : (tensor>) -> tensor> + func.return %0 : tensor> + } + })"; + + std::vector arg_shapes = {{1}}; + XlaCompilationResult compilation_result; + Status status = CompileSerializedMlirToXlaHlo( + legalization, arg_shapes, /*device_type=*/"XLA_TPU_JIT", + /*use_tuple_args=*/true, /*enable_op_fallback=*/false, + /*shape_determination_fns=*/{}, &compilation_result); + + EXPECT_TRUE(status.ok()); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tf2xla/api/v0/compile_tf_graph.cc b/tensorflow/compiler/mlir/tf2xla/api/v0/compile_tf_graph.cc new file mode 100644 index 00000000000..9df38ac36cb --- /dev/null +++ b/tensorflow/compiler/mlir/tf2xla/api/v0/compile_tf_graph.cc @@ -0,0 +1,257 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/tf2xla/api/v0/compile_tf_graph.h" + +#include +#include +#include +#include + +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h" +#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" +#include "tensorflow/compiler/mlir/tensorflow/transforms/set_tpu_infeed_layout.h" +#include "tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h" +#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/translate_utils.h" +#include "tensorflow/compiler/mlir/tf2xla/api/v0/compile_mlir_util.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/xla/client/compile_only_client.h" +#include "tensorflow/core/framework/versions.pb.h" +#include "tensorflow/core/lib/monitoring/counter.h" +#include "tensorflow/core/lib/monitoring/sampler.h" +#include "tensorflow/core/tpu/kernels/tpu_compile_op_support.h" +#include "tensorflow/core/tpu/tpu_compile.h" +#include "tensorflow/tsl/platform/status.h" +#include "tensorflow/tsl/platform/statusor.h" + +namespace tensorflow { +namespace tf2xla { +namespace v0 { + +using ::tensorflow::tpu::FunctionToHloArgs; +using ::tensorflow::tpu::GuaranteedConsts; +using ::tensorflow::tpu::MlirToHloArgs; +using ::tensorflow::tpu::ShardingAndIndex; + +auto* phase2_bridge_compilation_status = + tensorflow::monitoring::Counter<1>::New( + "/tensorflow/core/tf2xla/api/v0/" + "phase2_compilation_status", /*metric_name*/ + "Tracks the compilation status of the non-mlir bridge", + /* metric description */ "status" /* metric label */); + +auto* phase2_bridge_compilation_time = tsl::monitoring::Sampler<1>::New( + {"/tensorflow/core/tf2xla/api/v0/phase2_compilation_time", + "The wall-clock time spent on executing graphs in milliseconds.", + "configuration"}, + // Power of 1.5 with bucket count 45 (> 23 hours) + {tsl::monitoring::Buckets::Exponential(1, 1.5, 45)}); + +// There were no MLIR ops so the old bridge was called successfully. +constexpr char kOldBridgeNoMlirSuccess[] = "kOldBridgeNoMlirSuccess"; +// There were no MLIR ops so the old bridge was called but it failed. +constexpr char kOldBridgeNoMlirFailure[] = "kOldBridgeNoMlirFailure"; + +namespace { + +// Time the execution of kernels (in CPU cycles). Meant to be used as RAII. 
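As a rough check on the bucket comment above: with a scale of 1 ms and a growth factor of 1.5 over 45 buckets, the largest finite boundary is on the order of 1.5^45 ms, roughly 8.4e7 ms or about 23.3 hours (give or take one growth step depending on how the boundaries are indexed), which is where the "> 23 hours" note comes from; anything slower still lands in the final overflow bucket.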
+struct CompilationTimer { + uint64 start_cycles = profile_utils::CpuUtils::GetCurrentClockCycle(); + + uint64 ElapsedCycles() { + return profile_utils::CpuUtils::GetCurrentClockCycle() - start_cycles; + } + + int64_t ElapsedCyclesInMilliseconds() { + std::chrono::duration duration = + profile_utils::CpuUtils::ConvertClockCycleToTime(ElapsedCycles()); + + return std::chrono::duration_cast(duration) + .count(); + } +}; + +// Populates input_output_alias field in the HLO Module proto. +Status PopulateInputOutputAliasing( + mlir::func::FuncOp main_fn, + XlaCompiler::CompilationResult* compilation_result, bool use_tuple_args) { + constexpr char kAliasingAttr[] = "tf.aliasing_output"; + llvm::SmallDenseMap output_to_input_alias; + unsigned num_arguments = main_fn.getNumArguments(); + for (unsigned arg_index = 0; arg_index < num_arguments; ++arg_index) { + if (auto aliasing_output = main_fn.getArgAttrOfType( + arg_index, kAliasingAttr)) + output_to_input_alias[aliasing_output.getInt()] = arg_index; + } + + if (output_to_input_alias.empty()) return OkStatus(); + + xla::HloModuleProto* module_proto = + compilation_result->computation->mutable_proto(); + StatusOr program_shape_or_status = + compilation_result->computation->GetProgramShape(); + TF_RET_CHECK(program_shape_or_status.ok()); + + xla::ProgramShape& program_shape = program_shape_or_status.value(); + if (!program_shape.result().IsTuple()) + return errors::Internal("Expect result to have tuple shape"); + + xla::HloInputOutputAliasConfig config(program_shape.result()); + for (auto alias : output_to_input_alias) { + if (use_tuple_args) { + TF_RETURN_IF_ERROR(config.SetUpAlias( + xla::ShapeIndex({alias.first}), 0, xla::ShapeIndex({alias.second}), + xla::HloInputOutputAliasConfig::AliasKind::kMayAlias)); + } else { + TF_RETURN_IF_ERROR(config.SetUpAlias( + xla::ShapeIndex({alias.first}), alias.second, xla::ShapeIndex({}), + xla::HloInputOutputAliasConfig::AliasKind::kMayAlias)); + } + } + *module_proto->mutable_input_output_alias() = config.ToProto(); + return OkStatus(); +} + +// Transforms the given module to be suitable for export to TensorFlow GraphDef +// and then exports all functions to the given library. +Status PrepareAndExportToLibrary(mlir::ModuleOp module, + FunctionLibraryDefinition* flib_def) { + // Pass pipeline is defined here instead of leveraging the phase one export + // pipeline because only the functional to executor dialect conversion and + // breakup islands passes are common between the export pipeline and here. + // Reconsider this if there is more commonality in the future with more + // passes. 
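The timer defined above is consumed a little further down in this file; reduced to its essentials, the pattern looks like the following. The cell label here is a placeholder; the real code builds labels such as graph_old_bridge_has_mlir.

{
  CompilationTimer timer;
  // ... run the compilation being measured ...
  phase2_bridge_compilation_time->GetCell("example_config")
      ->Add(timer.ElapsedCyclesInMilliseconds());
}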
+ mlir::PassManager manager(module.getContext()); + applyTensorflowAndCLOptions(manager); + manager.addPass(mlir::TF::CreatePrepareTpuComputationForTfExportPass()); + manager.addPass(mlir::TF::CreateTFRegionControlFlowToFunctional()); + manager.addNestedPass( + mlir::CreateFunctionalToExecutorDialectConversionPass()); + manager.addPass(mlir::CreateBreakUpIslandsPass()); + + mlir::StatusScopedDiagnosticHandler diag_handler(module.getContext()); + if (failed(manager.run(module))) return diag_handler.ConsumeStatus(); + + GraphExportConfig config; + config.export_entry_func_to_flib = true; + return tensorflow::ConvertMlirToGraph(module, config, /*graph=*/nullptr, + flib_def); +} + +} // namespace + +tsl::Status CompileTensorflowGraphToHlo( + const std::variant& computation, + const tpu::TPUCompileMetadataProto& metadata, bool use_tuple_args, + const XlaShapeLayoutHelpers::ShapeDeterminationFns + shape_determination_funcs, + const std::vector& arg_shapes, + std::vector* arg_core_mapping, + std::vector>* per_core_arg_shapes, + xla::CompileOnlyClient* client, + XlaCompiler::CompilationResult* compilation_result) { + LOG_FIRST_N(INFO, 1) << "Compiling MLIR computation to XLA HLO using the " + "old (non-MLIR) tf2xla bridge"; + + *compilation_result = {}; + bool has_mlir = computation.index() == 0; + + std::string mlir_string = has_mlir ? "has_mlir" : "has_function_to_hlo"; + const std::string kBridgePhase2Config = + absl::StrCat("graph_old_bridge_", mlir_string); + CompilationTimer timer; + + if (!has_mlir) { + FunctionToHloArgs function_computation = std::get<1>(computation); + Status comp_status = CompileTFFunctionToHlo( + *function_computation.flib_def, function_computation.graph_def_version, + shape_determination_funcs, arg_shapes, + function_computation.guaranteed_constants, + *function_computation.function, metadata, client, arg_core_mapping, + per_core_arg_shapes, use_tuple_args, compilation_result); + if (comp_status.ok()) { + phase2_bridge_compilation_status->GetCell(kOldBridgeNoMlirSuccess) + ->IncrementBy(1); + } else { + phase2_bridge_compilation_status->GetCell(kOldBridgeNoMlirFailure) + ->IncrementBy(1); + } + + phase2_bridge_compilation_time->GetCell(kBridgePhase2Config) + ->Add(timer.ElapsedCyclesInMilliseconds()); + return comp_status; + } + + mlir::DialectRegistry registry; + mlir::RegisterAllTensorFlowDialects(registry); + mlir::MLIRContext context(registry); + + mlir::OwningOpRef mlir_module; + TF_RETURN_IF_ERROR(DeserializeMlirModule(std::get<0>(computation).mlir_module, + &context, &mlir_module)); + if (!mlir::SetTPUInfeedLayout(mlir_module)) + return errors::Internal("Failed to set layouts attribute"); + + if (VLOG_IS_ON(2)) { + tensorflow::DumpMlirOpToFile("legalize_with_old_bridge", mlir_module.get()); + } + constexpr char kEntryFuncName[] = "main"; + auto main_fn = mlir_module->lookupSymbol(kEntryFuncName); + if (!main_fn) { + return errors::Internal( + "TPU compile op requires module with a entry function main"); + } + + // Export functions to the library. 
+ auto flib_def = std::make_unique( + OpRegistry::Global(), FunctionDefLibrary()); + TF_RETURN_IF_ERROR(PrepareAndExportToLibrary(*mlir_module, flib_def.get())); + + if (VLOG_IS_ON(2)) { + tensorflow::DumpMlirOpToFile("legalize_with_old_bridge_post_transform", + mlir_module.get()); + } + VersionDef versions; + if (mlir::failed(ExtractTfVersions(*mlir_module, &versions))) { + return errors::Internal( + "module attribute in _TPUCompileMlir op is missing tf versions."); + } + + NameAttrList func; + func.set_name(kEntryFuncName); + GuaranteedConsts consts; + + *compilation_result = {}; + + TF_RETURN_IF_ERROR(CompileTFFunctionToHlo( + *flib_def, versions.producer(), shape_determination_funcs, arg_shapes, + consts, func, metadata, client, arg_core_mapping, per_core_arg_shapes, + use_tuple_args, compilation_result)); + + phase2_bridge_compilation_time->GetCell(kBridgePhase2Config) + ->Add(timer.ElapsedCyclesInMilliseconds()); + + return PopulateInputOutputAliasing(main_fn, compilation_result, + use_tuple_args); +} + +}; // namespace v0 +}; // namespace tf2xla +}; // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tf2xla/api/v0/compile_tf_graph.h b/tensorflow/compiler/mlir/tf2xla/api/v0/compile_tf_graph.h new file mode 100644 index 00000000000..b249b228330 --- /dev/null +++ b/tensorflow/compiler/mlir/tf2xla/api/v0/compile_tf_graph.h @@ -0,0 +1,51 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TF2XLA_API_V0_COMPILE_TF_GRAPH_H_ +#define TENSORFLOW_COMPILER_MLIR_TF2XLA_API_V0_COMPILE_TF_GRAPH_H_ + +#include +#include + +#include "absl/status/status.h" +#include "absl/types/variant.h" +#include "tensorflow/compiler/xla/client/compile_only_client.h" +#include "tensorflow/compiler/xla/pjrt/compile_options.pb.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/tpu/kernels/tpu_compile.pb.h" +#include "tensorflow/core/tpu/kernels/tpu_compile_op_support.h" + +namespace tensorflow { +namespace tf2xla { +namespace v0 { + +// Compiles the given Tensorflow graph into xla::HLO. The result is in +// compilation_result. If the input computation is in MLIR, it will be +// converted to a Tensorflow graph. Otherwise, the graph compiler will be run. 
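A call-site sketch may make the contract concrete. It is adapted from the unit test added further below, so every value is a test-style placeholder (host platform, empty metadata, a serialized module that must contain a main function) rather than production configuration, and the snippet assumes the tensorflow namespace.

se::Platform* platform =
    se::MultiPlatformManager::PlatformWithName("Host").value();
xla::CompileOnlyClient* client =
    xla::ClientLibrary::GetOrCreateCompileOnlyClient(platform).value();

tpu::MlirToHloArgs computation;
computation.mlir_module = serialized_module_with_main;  // placeholder text

tpu::TPUCompileMetadataProto metadata;
std::vector<TensorShape> arg_shapes;
std::vector<tpu::ShardingAndIndex> arg_core_mapping;
std::vector<std::vector<xla::Shape>> per_core_arg_shapes;
XlaCompiler::CompilationResult result;

tsl::Status status = CompileTensorflowGraphToHlo(
    computation, metadata, /*use_tuple_args=*/true,
    /*shape_determination_funcs=*/{}, arg_shapes, &arg_core_mapping,
    &per_core_arg_shapes, client, &result);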
+tsl::Status CompileTensorflowGraphToHlo( + const std::variant& computation, + const tpu::TPUCompileMetadataProto& metadata, bool use_tuple_args, + XlaShapeLayoutHelpers::ShapeDeterminationFns shape_determination_funcs, + const std::vector& arg_shapes, + std::vector* arg_core_mapping, + std::vector>* per_core_arg_shapes, + xla::CompileOnlyClient* client, + XlaCompiler::CompilationResult* compilation_result); + +} // namespace v0 +} // namespace tf2xla +}; // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TF2XLA_API_V0_COMPILE_TF_GRAPH_H_ diff --git a/tensorflow/compiler/mlir/tf2xla/api/v0/compile_tf_graph_test.cc b/tensorflow/compiler/mlir/tf2xla/api/v0/compile_tf_graph_test.cc new file mode 100644 index 00000000000..678e1ab7243 --- /dev/null +++ b/tensorflow/compiler/mlir/tf2xla/api/v0/compile_tf_graph_test.cc @@ -0,0 +1,130 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/tf2xla/api/v0/compile_tf_graph.h" + +#include +#include + +#include "tensorflow/compiler/xla/client/client_library.h" +#include "tensorflow/core/lib/monitoring/cell_reader.h" +#include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h" +#include "tensorflow/core/tpu/kernels/tpu_compile_op_support.h" +#include "tensorflow/tsl/lib/core/status_test_util.h" +#include "tensorflow/tsl/lib/monitoring/test_utils.h" + +namespace tensorflow { +namespace tf2xla { +namespace v0 { +namespace { + +using ::tensorflow::monitoring::testing::CellReader; +using ::tensorflow::tpu::FunctionToHloArgs; +using ::tensorflow::tpu::MlirToHloArgs; +using ::tensorflow::tpu::ShardingAndIndex; +using ::tsl::monitoring::testing::Histogram; + +static constexpr char kCompilationTimeStreamzName[] = + "/tensorflow/core/tf2xla/api/v0/phase2_compilation_time"; + +static constexpr char kCompilationStatusStreamzName[] = + "/tensorflow/core/tf2xla/api/v0/phase2_compilation_status"; + +MlirToHloArgs CreateTestMlirToHloArgs() { + static constexpr char kMlirModuleStr[] = R"( + module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 268 : i32}} { + func.func @main() -> () { + func.return + } + })"; + + MlirToHloArgs mlir_to_hlo_args; + mlir_to_hlo_args.rollout_state = + ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_DISABLED; + mlir_to_hlo_args.mlir_module = kMlirModuleStr; + + return mlir_to_hlo_args; +} + +class CompileTFGraphTest : public ::testing::Test { + public: + tsl::Status CompileWithComputation( + const std::variant + computation) { + se::Platform* platform = + se::MultiPlatformManager::PlatformWithName("Host").value(); + auto client = + xla::ClientLibrary::GetOrCreateCompileOnlyClient(platform).value(); + + std::vector arg_shapes; + bool use_tuple_args = true; + std::vector arg_core_mapping; + std::vector> per_core_arg_shapes; + XlaCompiler::CompilationResult result; + tpu::TPUCompileMetadataProto metadata_proto; + 
XlaShapeLayoutHelpers::ShapeDeterminationFns shape_determination_funcs; + + return CompileTensorflowGraphToHlo( + computation, metadata_proto, use_tuple_args, shape_determination_funcs, + arg_shapes, &arg_core_mapping, &per_core_arg_shapes, client, &result); + } +}; + +TEST_F(CompileTFGraphTest, RecordsStreamzForMlirFallback) { + CellReader compilation_time(kCompilationTimeStreamzName); + + MlirToHloArgs mlir_to_hlo_args = CreateTestMlirToHloArgs(); + + TF_EXPECT_OK(CompileWithComputation(mlir_to_hlo_args)); + + Histogram histogram = compilation_time.Delta("graph_old_bridge_has_mlir"); + + EXPECT_EQ(histogram.num(), 1); +} + +TEST_F(CompileTFGraphTest, RecordsStreamzForFunctionToHlo) { + CellReader compilation_time(kCompilationTimeStreamzName); + CellReader compilation_status(kCompilationStatusStreamzName); + + FunctionDef empty_function = + tensorflow::FunctionDefHelper::Create("empty", {}, {}, {}, {}, {}); + + tensorflow::FunctionDefLibrary fdef; + *(fdef.add_function()) = empty_function; + tensorflow::FunctionLibraryDefinition flib_def( + tensorflow::OpRegistry::Global(), fdef); + + OpInputList guaranteed_constants; + NameAttrList function; + function.set_name("empty"); + + FunctionToHloArgs function_to_hlo_args = {&function, + &flib_def, + /*graph_def_version=*/0, + {&guaranteed_constants}}; + + TF_EXPECT_OK(CompileWithComputation(function_to_hlo_args)); + + Histogram histogram = + compilation_time.Delta("graph_old_bridge_has_function_to_hlo"); + + EXPECT_EQ(histogram.num(), 1); + EXPECT_EQ(compilation_status.Delta("kOldBridgeNoMlirSuccess"), 1); +} + +} // namespace +} // namespace v0 +} // namespace tf2xla +} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tf2xla/api/v1/BUILD b/tensorflow/compiler/mlir/tf2xla/api/v1/BUILD new file mode 100644 index 00000000000..a95e558f506 --- /dev/null +++ b/tensorflow/compiler/mlir/tf2xla/api/v1/BUILD @@ -0,0 +1,82 @@ +load("//tensorflow/core/platform:build_config.bzl", "tf_proto_library") +load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") +load("//tensorflow:tensorflow.bzl", "tf_cc_test") + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = [ + ":__subpackages__", + ":tf2xla_users", + ], +) + +# Please reach out to tf-bridge-team@ before using the TF2XLA bridge. 
+package_group(name = "tf2xla_users") + +cc_library( + name = "legalize_tf", + srcs = ["legalize_tf.cc"], + hdrs = ["legalize_tf.h"], + deps = [ + ":device_type_proto_cc", + "//tensorflow/compiler/jit:flags_headers", + "//tensorflow/compiler/jit:shape_inference", + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow:dump_mlir_util", + "//tensorflow/compiler/mlir/tensorflow:error_util", + "//tensorflow/compiler/mlir/tensorflow:export_graphdef", + "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags", + "//tensorflow/compiler/mlir/tensorflow:serialize_mlir_module_utils", + "//tensorflow/compiler/mlir/tensorflow:set_tpu_infeed_layout", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_passes", + "//tensorflow/compiler/mlir/tensorflow:translate_utils", + "//tensorflow/compiler/mlir/tf2xla:compile_mlir_util", + "//tensorflow/compiler/mlir/tf2xla/api/v0:compile_tf_graph", + "//tensorflow/compiler/tf2xla:layout_util", + "//tensorflow/compiler/tf2xla:tf2xla_util", + "//tensorflow/compiler/tf2xla:xla_helpers", + "//tensorflow/compiler/xla/client:compile_only_client", + "//tensorflow/compiler/xla/mlir_hlo:hlo_dialect_registration", + "//tensorflow/compiler/xla/pjrt:compile_options_proto_cc", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core/platform:statusor", + "//tensorflow/core/tpu:tpu_compile", + "//tensorflow/core/tpu/kernels:tpu_compile_op_support", + "//tensorflow/core/tpu/kernels:tpu_compile_proto_cc", + "//tensorflow/core/tpu/kernels:tpu_util_hdrs", + "//tensorflow/tsl/platform:status", + "//tensorflow/tsl/platform:statusor", + "@com_google_absl//absl/status", + "@com_google_absl//absl/types:variant", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@stablehlo//:register", + ], +) + +tf_cc_test( + name = "legalize_tf_test", + srcs = ["legalize_tf_test.cc"], + deps = [ + ":device_type_proto_cc", + ":legalize_tf", + "//tensorflow/compiler/jit", + "//tensorflow/compiler/tf2xla:xla_helpers", + "//tensorflow/compiler/xla/client:client_library", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core/lib/monitoring:cell_reader", + "//tensorflow/core/protobuf/tpu:compile_metadata_proto_cc", + "//tensorflow/core/tpu/kernels:tpu_compile_op_support", + "//tensorflow/tsl/lib/monitoring:test_utils", + "//tensorflow/tsl/platform:statusor", + "@com_google_googletest//:gtest_main", + ], +) + +tf_proto_library( + name = "device_type_proto", + srcs = ["device_type.proto"], + cc_api_version = 2, +) diff --git a/tensorflow/compiler/mlir/tf2xla/api/v1/device_type.proto b/tensorflow/compiler/mlir/tf2xla/api/v1/device_type.proto new file mode 100644 index 00000000000..6bca9312b4a --- /dev/null +++ b/tensorflow/compiler/mlir/tf2xla/api/v1/device_type.proto @@ -0,0 +1,11 @@ +syntax = "proto2"; + +package tensorflow.tf2xla.v1; + +// The requested device type to compile for. +enum DeviceType { + DEVICE_TYPE_UNSPECIFIED = 0; + XLA_TPU_JIT = 1; + XLA_CPU_JIT = 2; + XLA_GPU_JIT = 3; +} diff --git a/tensorflow/compiler/mlir/tf2xla/api/v1/legalize_tf.cc b/tensorflow/compiler/mlir/tf2xla/api/v1/legalize_tf.cc new file mode 100644 index 00000000000..f5f6818d33e --- /dev/null +++ b/tensorflow/compiler/mlir/tf2xla/api/v1/legalize_tf.cc @@ -0,0 +1,281 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/tf2xla/api/v1/legalize_tf.h" + +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/types/variant.h" +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "stablehlo/dialect/Register.h" // from @stablehlo +#include "tensorflow/compiler/jit/flags.h" +#include "tensorflow/compiler/jit/shape_inference.h" +#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h" +#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" +#include "tensorflow/compiler/mlir/tensorflow/transforms/set_tpu_infeed_layout.h" +#include "tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h" +#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/translate_utils.h" +#include "tensorflow/compiler/mlir/tf2xla/api/v0/compile_mlir_util.h" +#include "tensorflow/compiler/mlir/tf2xla/api/v0/compile_tf_graph.h" +#include "tensorflow/compiler/tf2xla/layout_util.h" +#include "tensorflow/compiler/tf2xla/tf2xla_util.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/xla/mlir_hlo/mhlo/IR/register.h" +#include "tensorflow/core/lib/monitoring/counter.h" +#include "tensorflow/core/lib/monitoring/sampler.h" +#include "tensorflow/core/platform/statusor.h" +#include "tensorflow/core/tpu/kernels/tpu_compile_op_support.h" +#include "tensorflow/core/tpu/kernels/tpu_util.h" +#include "tensorflow/core/tpu/tpu_compile.h" +#include "tensorflow/tsl/platform/status.h" +#include "tensorflow/tsl/platform/statusor.h" + +namespace tensorflow { +namespace tf2xla { +namespace v1 { + +using tpu::FunctionToHloArgs; +using tpu::MlirToHloArgs; +using tpu::ShardingAndIndex; + +auto* mlir_second_phase_count = tensorflow::monitoring::Counter<1>::New( + "/tensorflow/core/tf2xla/api/v1/phase2_compilation_status" /*metric_name*/, + "Counts the number of graphs that were analyzed prior deciding whether " + "the MLIR or the old bridge will be used" /* metric description */, + "status" /* metric label */); + +auto* phase2_bridge_compilation_time = tsl::monitoring::Sampler<1>::New( + {"/tensorflow/core/tf2xla/api/v1/phase2_compilation_time", + "The wall-clock time spent on executing graphs in milliseconds.", + "configuration"}, + // Power of 1.5 with bucket count 45 (> 23 hours) + {tsl::monitoring::Buckets::Exponential(1, 1.5, 45)}); + +// The label `status` is used to count the following events: +// MLIR bridge phase 2 was executed and the graph was processed successfully +// (fallback enabled). +constexpr char kMlirWithFallbackModeSuccess[] = "kMlirWithFallbackModeSuccess"; +// MLIR bridge phase 2 compilation was failure (fallback enabled). 
+constexpr char kMlirWithFallbackModeFailure[] = "kMlirWithFallbackModeFailure"; +// MLIR bridge phase 2 compilation was successful (manually enabled). +constexpr char kMlirModeSuccess[] = "kMlirModeSuccess"; +// MLIR bridge phase 2 compilation fails (manually enabled) +constexpr char kMlirModeFailure[] = "kMlirModeFailure"; +// Old bridge compilation was run successfully (was run because MLIR bridge +// could not process the graph). +constexpr char kOldBridgeMlirFilteredSuccess[] = + "kOldBridgeMlirFilteredSuccess"; +// Old bridge failed (was run b/c MLIR bridge could not process the graph). +constexpr char kOldBridgeMlirFilteredFailure[] = + "kOldBridgeMlirFilteredFailure"; +// Old bridge compilation was successfully run after MLIR bridge ran and failed. +constexpr char kOldBridgeWithFallbackModeSuccess[] = + "kOldBridgeWithFallbackModeSuccess"; +// Old Bridge failed in fallback (was run because MLIR bridge failed first). +constexpr char kOldBridgeWithFallbackModeFailure[] = + "kOldBridgeWithFallbackModeFailure"; + +// Time the execution of kernels (in CPU cycles). Meant to be used as RAII. +struct CompilationTimer { + uint64 start_cycles = profile_utils::CpuUtils::GetCurrentClockCycle(); + + uint64 ElapsedCycles() { + return profile_utils::CpuUtils::GetCurrentClockCycle() - start_cycles; + } + + int64_t ElapsedCyclesInMilliseconds() { + std::chrono::duration duration = + profile_utils::CpuUtils::ConvertClockCycleToTime(ElapsedCycles()); + + return std::chrono::duration_cast(duration) + .count(); + } +}; + +namespace { + +bool ShouldFallbackToGraphCompiler( + const std::variant& computation) { + if (computation.index() == 1) return true; + + return std::get<0>(computation).rollout_state == + ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_DISABLED; +} + +Status CompileFromMlirToXlaHlo( + bool enable_op_fallback, + const std::variant& computation, + const tpu::TPUCompileMetadataProto& metadata, llvm::StringRef device_type, + const XlaShapeLayoutHelpers::ShapeDeterminationFns& shape_determination_fns, + bool use_tuple_args, XlaCompiler::CompilationResult* compilation_result, + std::vector>& custom_legalization_passes, + const std::vector& arg_shapes, + std::vector* arg_core_mapping, + std::vector>* per_core_arg_shapes) { + if (enable_op_fallback) { + LOG_FIRST_N(INFO, 1) + << "Compiling MLIR computation to XLA HLO using MLIR tf2xla bridge in " + "the op by op fallback mode. This is Phase 2 of the TF2XLA Bridge. " + "Old (non-MLIR) bridge may be used in case of unsupported feature " + "or compilation failure from the MLIR bridge (full fallback mode)."; + } else { + LOG_FIRST_N(INFO, 1) + << "Compiling MLIR computation to XLA HLO using MLIR tf2xla bridge " + "phase 2. Fallback to the old (non-MLIR) bridge is disabled. 
" + "Op-by-op fallback is also disabled."; + } + + mlir::DialectRegistry registry; + mlir::RegisterAllTensorFlowDialects(registry); + mlir::mhlo::registerAllMhloDialects(registry); + mlir::stablehlo::registerAllDialects(registry); + mlir::MLIRContext context(registry); + mlir::OwningOpRef mlir_module; + TF_RETURN_IF_ERROR(DeserializeMlirModule(std::get<0>(computation).mlir_module, + &context, &mlir_module)); + if (!mlir::SetTPUInfeedLayout(mlir_module)) + return errors::Internal("Failed to set layouts attribute"); + + TF_RETURN_IF_ERROR(CompileSerializedMlirToXlaHlo( + SerializeMlirModule(mlir_module.get()), arg_shapes, device_type, + use_tuple_args, enable_op_fallback, shape_determination_fns, + compilation_result, custom_legalization_passes, metadata.module_name())); + + // Compute how arguments are shared across different cores. + return tpu::GetShardingInfo(metadata, arg_shapes, shape_determination_fns, + arg_core_mapping, per_core_arg_shapes); +} + +} // namespace + +tsl::StatusOr LegalizeMlirToHlo( + const std::variant& computation, + const tpu::TPUCompileMetadataProto& metadata, bool use_tuple_args, + llvm::StringRef device_type, + std::vector>& custom_legalization_passes, + XlaShapeLayoutHelpers::ShapeDeterminationFns shape_determination_fns, + const std::vector& arg_shapes, + std::vector* arg_core_mapping, + std::vector>* per_core_arg_shapes, + xla::CompileOnlyClient* client) { + XlaCompilationResult compilation_result; + // If there are no MLIR args, compile the given function in the library. + if (ShouldFallbackToGraphCompiler(computation)) { + TF_RETURN_IF_ERROR(tf2xla::v0::CompileTensorflowGraphToHlo( + computation, metadata, use_tuple_args, shape_determination_fns, + arg_shapes, arg_core_mapping, per_core_arg_shapes, client, + &compilation_result)); + return compilation_result; + } + + // We could only end up here if the MLIR bridge was explicitly enabled or + // if it was in the default/unspecified state and graph analysis in the first + // phase has not identified unsupported features. + // Enabling op fallback also enables whole graph fallback if op by op + // fallback failed. + bool enable_op_fallback = + std::get<0>(computation).rollout_state != + ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED; + + Status mlir_bridge_status = tsl::OkStatus(); + { + CompilationTimer timer; + std::string enabled_string = enable_op_fallback ? "enabled" : "disabled"; + const std::string kMlirBridgeFallback = + absl::StrCat("mlir_bridge_op_fallback_", enabled_string); + + mlir_bridge_status = CompileFromMlirToXlaHlo( + enable_op_fallback, computation, metadata, device_type, + shape_determination_fns, use_tuple_args, &compilation_result, + custom_legalization_passes, arg_shapes, arg_core_mapping, + per_core_arg_shapes); + + phase2_bridge_compilation_time->GetCell(kMlirBridgeFallback) + ->Add(timer.ElapsedCyclesInMilliseconds()); + } + + if (mlir_bridge_status.ok()) { + if (enable_op_fallback) { + VLOG(1) << "Successfully compiled MLIR computation to XLA HLO using MLIR " + "tf2xla bridge"; + mlir_second_phase_count->GetCell(kMlirWithFallbackModeSuccess) + ->IncrementBy(1); + } else { + mlir_second_phase_count->GetCell(kMlirModeSuccess)->IncrementBy(1); + } + return compilation_result; + } else if (!enable_op_fallback) { + // Don't fallback to the old bridge if op-by-op fallback isn't enabled. 
+ mlir_second_phase_count->GetCell(kMlirModeFailure)->IncrementBy(1); + return mlir_bridge_status; + } + + bool filtered_graph = false; + if (mlir_bridge_status == CompileToHloGraphAnalysisFailedError()) { + VLOG(1) << "Filtered out MLIR computation to XLA HLO using MLIR tf2xla " + "bridge. Falling back to old (non-MLIR) bridge."; + filtered_graph = true; + } else { + mlir_second_phase_count->GetCell(kMlirWithFallbackModeFailure) + ->IncrementBy(1); + + VLOG(1) << "Failed to compile MLIR computation to XLA HLO using MLIR " + "tf2xla bridge. Falling back to old (non-MLIR) bridge. MLIR " + "bridge compilation status: " + << mlir_bridge_status; + } + + Status old_bridge_status = tf2xla::v0::CompileTensorflowGraphToHlo( + computation, metadata, use_tuple_args, shape_determination_fns, + arg_shapes, arg_core_mapping, per_core_arg_shapes, client, + &compilation_result); + + // Record filter/failure stats only if the old bridge succeeds. This removes + // noise from invalid inputs. + if (!old_bridge_status.ok()) { + // If the old bridge failed for this input as well. Mark the input as + // invalid. This might be incorrect in case of old bridge bugs but that + // should be rare. + if (filtered_graph) { + mlir_second_phase_count->GetCell(kOldBridgeMlirFilteredFailure) + ->IncrementBy(1); + } else { + mlir_second_phase_count->GetCell(kOldBridgeWithFallbackModeFailure) + ->IncrementBy(1); + } + return old_bridge_status; + } + + if (filtered_graph) { + mlir_second_phase_count->GetCell(kOldBridgeMlirFilteredSuccess) + ->IncrementBy(1); + } else { + mlir_second_phase_count->GetCell(kOldBridgeWithFallbackModeSuccess) + ->IncrementBy(1); + } + return compilation_result; +} + +}; // namespace v1 +}; // namespace tf2xla +}; // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tf2xla/api/v1/legalize_tf.h b/tensorflow/compiler/mlir/tf2xla/api/v1/legalize_tf.h new file mode 100644 index 00000000000..f9ddb0cad78 --- /dev/null +++ b/tensorflow/compiler/mlir/tf2xla/api/v1/legalize_tf.h @@ -0,0 +1,68 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TF2XLA_API_V1_LEGALIZE_TF_H_ +#define TENSORFLOW_COMPILER_MLIR_TF2XLA_API_V1_LEGALIZE_TF_H_ + +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/types/variant.h" +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tf2xla/api/v1/device_type.pb.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/xla/client/compile_only_client.h" +#include "tensorflow/compiler/xla/pjrt/compile_options.pb.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/tpu/kernels/tpu_compile.pb.h" +#include "tensorflow/core/tpu/kernels/tpu_compile_op_support.h" +#include "tensorflow/tsl/platform/statusor.h" + +namespace tensorflow { +namespace tf2xla { +namespace v1 { + +// Legalizes the given mlir::Module into XLA HLO. If successful, returns the +// compiled XLA HLO. V1 of the tf2xla uses MLIR whereas V0 does not use MLIR. +// +// Inputs: +// computation - The MLIR module op. It currently takes in +// tpu::FunctionToHloArgs but this is deprecated. arg_shapes - The shapes of +// the arguments in module_op. device_type - The device type to compile for. +// use_tuple_args - Pack the incoming arg shapes into a single tuple. +// custom_legalization_passes - Extra passes to lower from TF -> MHLO. +// arg_shapes - The shapes of the args. +// arg_core_mapping - Which args go on which cores. +// per_core_arg_shapes - For each core, the shapes for each argument. +// client - The Xla Compilation client. +tsl::StatusOr LegalizeMlirToHlo( + const std::variant& computation, + const tpu::TPUCompileMetadataProto& metadata, bool use_tuple_args, + llvm::StringRef device_type, + std::vector>& custom_legalization_passes, + XlaShapeLayoutHelpers::ShapeDeterminationFns shape_determination_fns, + const std::vector& arg_shapes, + std::vector* arg_core_mapping, + std::vector>* per_core_arg_shapes, + xla::CompileOnlyClient* client); + +}; // namespace v1 +}; // namespace tf2xla +}; // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TF2XLA_API_V1_LEGALIZE_TF_H_ diff --git a/tensorflow/compiler/mlir/tf2xla/api/v1/legalize_tf_test.cc b/tensorflow/compiler/mlir/tf2xla/api/v1/legalize_tf_test.cc new file mode 100644 index 00000000000..32e90358ab9 --- /dev/null +++ b/tensorflow/compiler/mlir/tf2xla/api/v1/legalize_tf_test.cc @@ -0,0 +1,149 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/mlir/tf2xla/api/v1/legalize_tf.h" + +#include +#include + +#include +#include +#include "tensorflow/compiler/mlir/tf2xla/api/v1/device_type.pb.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/xla/client/client_library.h" +#include "tensorflow/core/lib/monitoring/cell_reader.h" +#include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h" +#include "tensorflow/core/tpu/kernels/tpu_compile_op_support.h" +#include "tensorflow/tsl/lib/monitoring/test_utils.h" +#include "tensorflow/tsl/platform/statusor.h" + +namespace tensorflow { +namespace tf2xla { +namespace v1 { + +using ::tensorflow::monitoring::testing::CellReader; +using tpu::FunctionToHloArgs; +using tpu::MlirToHloArgs; +using tpu::ShardingAndIndex; +using tpu::TPUCompileMetadataProto; +using ::tsl::monitoring::testing::Histogram; + +static constexpr char kCompilationTimeStreamzName[] = + "/tensorflow/core/tf2xla/api/v1/phase2_compilation_time"; +static constexpr char kCompilationStatusStreamzName[] = + "/tensorflow/core/tf2xla/api/v1/phase2_compilation_status"; + +static constexpr char kMlirModuleStr[] = R"( + module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 268 : i32}} { + func.func @main() -> () { + func.return + } +})"; + +tsl::StatusOr CompileMlirModule( + ConfigProto::Experimental::MlirBridgeRollout rollout_state) { + MlirToHloArgs mlir_to_hlo_args; + mlir_to_hlo_args.rollout_state = rollout_state; + mlir_to_hlo_args.mlir_module = kMlirModuleStr; + + se::Platform* platform = + se::MultiPlatformManager::PlatformWithName("Host").value(); + auto client = + xla::ClientLibrary::GetOrCreateCompileOnlyClient(platform).value(); + + std::vector arg_shapes; + TPUCompileMetadataProto metadata_proto; + bool use_tuple_args = true; + std::vector arg_core_mapping; + std::vector> per_core_arg_shapes; + std::vector> custom_legalization_passes; + + return LegalizeMlirToHlo(mlir_to_hlo_args, metadata_proto, use_tuple_args, + /*device_type=*/"XLA_TPU_JIT", + custom_legalization_passes, + /*shape_determination_fns=*/{}, arg_shapes, + &arg_core_mapping, &per_core_arg_shapes, client); +} + +TEST(LegalizeTFTest, RecordsStreamzForMlirBridge) { + CellReader compilation_time(kCompilationTimeStreamzName); + CellReader compilation_status(kCompilationStatusStreamzName); + + TF_ASSERT_OK_AND_ASSIGN( + XlaCompiler::CompilationResult result, + CompileMlirModule( + ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED)); + + Histogram histogram = + compilation_time.Delta("mlir_bridge_op_fallback_disabled"); + EXPECT_EQ(histogram.num(), 1); + EXPECT_EQ(compilation_status.Delta("kMlirModeSuccess"), 1); +} + +TEST(LegalizeTFTest, RecordsStreamzForMlirOpFallback) { + CellReader compilation_time(kCompilationTimeStreamzName); + + TF_ASSERT_OK_AND_ASSIGN( + XlaCompiler::CompilationResult result, + CompileMlirModule( + ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_UNSPECIFIED)); + + Histogram histogram = + compilation_time.Delta("mlir_bridge_op_fallback_enabled"); + EXPECT_EQ(histogram.num(), 1); +} + +TEST(LegalizeTFTest, RecordsStreamzForNoMlirFallback) { + FunctionDef my_func = + tensorflow::FunctionDefHelper::Create("empty", {}, {}, {}, {}, {}); + + tensorflow::FunctionDefLibrary fdef; + *(fdef.add_function()) = my_func; + tensorflow::FunctionLibraryDefinition flib_def( + tensorflow::OpRegistry::Global(), fdef); + + OpInputList guaranteed_constants; + NameAttrList function; 
+ FunctionToHloArgs function_to_hlo_args{&function, + &flib_def, + /*graph_def_version=*/0, + {&guaranteed_constants}}; + + se::Platform* cpu_platform = + se::MultiPlatformManager::PlatformWithName("Host").value(); + auto client = + xla::ClientLibrary::GetOrCreateCompileOnlyClient(cpu_platform).value(); + + std::vector arg_shapes; + TPUCompileMetadataProto metadata_proto; + bool use_tuple_args = true; + std::vector arg_core_mapping; + std::vector> per_core_arg_shapes; + std::vector> custom_legalization_passes; + + // This doesn't actually compile correctly. + tsl::StatusOr compile_result = + LegalizeMlirToHlo(function_to_hlo_args, metadata_proto, use_tuple_args, + /*device_type=*/"XLA_CPU_JIT", + custom_legalization_passes, + /*shape_determination_fns=*/{}, arg_shapes, + &arg_core_mapping, &per_core_arg_shapes, client); + + EXPECT_FALSE(compile_result.ok()); +} + +} // namespace v1 +} // namespace tf2xla +} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tf2xla/mlir_bridge_rollout_policy.cc b/tensorflow/compiler/mlir/tf2xla/mlir_bridge_rollout_policy.cc index b4736462e26..6479253dd6e 100644 --- a/tensorflow/compiler/mlir/tf2xla/mlir_bridge_rollout_policy.cc +++ b/tensorflow/compiler/mlir/tf2xla/mlir_bridge_rollout_policy.cc @@ -15,6 +15,8 @@ limitations under the License. #include "tensorflow/compiler/mlir/tf2xla/mlir_bridge_rollout_policy.h" +#include + #include "tensorflow/compiler/jit/flags.h" namespace tensorflow { @@ -22,7 +24,7 @@ namespace tensorflow { MlirBridgeRolloutPolicy GetMlirBridgeRolloutPolicy( const tensorflow::Graph& graph, const FunctionLibraryDefinition* function_library, - std::optional config_proto, + std::optional config_proto, bool is_tpu_graph, bool uses_uninitialized_resource_args, bool is_v1_compat, bool record_stats) { switch (GetMlirBridgeRolloutState(config_proto)) { diff --git a/tensorflow/compiler/mlir/tf2xla/mlir_bridge_rollout_policy.h b/tensorflow/compiler/mlir/tf2xla/mlir_bridge_rollout_policy.h index 262ebc0fd2e..9f67442205d 100644 --- a/tensorflow/compiler/mlir/tf2xla/mlir_bridge_rollout_policy.h +++ b/tensorflow/compiler/mlir/tf2xla/mlir_bridge_rollout_policy.h @@ -16,8 +16,9 @@ limitations under the License. 
#ifndef TENSORFLOW_COMPILER_MLIR_TF2XLA_MLIR_BRIDGE_ROLLOUT_POLICY_H_ #define TENSORFLOW_COMPILER_MLIR_TF2XLA_MLIR_BRIDGE_ROLLOUT_POLICY_H_ +#include + #include "mlir/IR/BuiltinOps.h" -#include "absl/types/optional.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/protobuf/config.pb.h" @@ -52,7 +53,7 @@ enum class MlirBridgeRolloutPolicy { MlirBridgeRolloutPolicy GetMlirBridgeRolloutPolicy( const tensorflow::Graph& graph, const FunctionLibraryDefinition* function_library, - std::optional config_proto, + std::optional config_proto, bool is_tpu_graph, bool uses_uninitialized_resource_args, bool is_v1_compat, bool record_stats); diff --git a/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-communication.mlir b/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-communication.mlir index f0cc4783d9d..be4026d77c6 100644 --- a/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-communication.mlir +++ b/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-communication.mlir @@ -858,3 +858,15 @@ func.func @multi_block_func() { %0 = "tf.XlaRecvFromHost"() {key = "recv_key", shape = #tf_type.shape<>} : () -> tensor func.return } + +// ----- + +// CHECK-LABEL: func @host_compute_manual_sharding +func.func @host_compute_manual_sharding(%arg0: tensor) { + // CHECK: "mhlo.send" + // CHECK-SAME: mhlo.sharding = "\08\04" + // CHECK: "mhlo.recv" + // CHECK-SAME: mhlo.sharding = "\08\04" + %0 = "tf._XlaHostComputeMlir"(%arg0) {recv_key = "host_compute_channel_recv", send_key = "host_compute_channel_send", host_mlir_module = "", manual_sharding = true} : (tensor) -> tensor + func.return +} diff --git a/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-include-tf2xla-fallback.mlir b/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-include-tf2xla-fallback.mlir index ab6d07c3e84..6382f05f708 100644 --- a/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-include-tf2xla-fallback.mlir +++ b/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-include-tf2xla-fallback.mlir @@ -1,7 +1,7 @@ -// RUN: tf-opt "-xla-legalize-tf=allow-partial-conversion use-tf2xla-fallback=false" -verify-diagnostics %s | FileCheck --check-prefix NO_FALLBACK %s -// RUN: tf-opt "-xla-legalize-tf=use-tf2xla-fallback=true device-type=XLA_CPU_JIT" -verify-diagnostics %s | FileCheck --check-prefix SUPPORTED_FALLBACK_DEVICE %s -// RUN: tf-opt "-xla-legalize-tf=allow-partial-conversion use-tf2xla-fallback=true" %s | FileCheck --check-prefix UNSPECIFIED_FALLBACK_DEVICE %s -// RUN: tf-opt "-xla-legalize-tf=allow-partial-conversion use-tf2xla-fallback=true device-type=INVALID_DEVICE_TYPE" %s | FileCheck --check-prefix UNSUPPORTED_FALLBACK_DEVICE %s +// RUN: tf-opt "-xla-legalize-tf=allow-partial-conversion use-tf2xla-fallback=false use-tf2xla-hlo-importer=false" -verify-diagnostics %s | FileCheck --check-prefix NO_FALLBACK %s +// RUN: tf-opt "-xla-legalize-tf=use-tf2xla-fallback=true device-type=XLA_CPU_JIT use-tf2xla-hlo-importer=false" -verify-diagnostics %s | FileCheck --check-prefix SUPPORTED_FALLBACK_DEVICE %s +// RUN: tf-opt "-xla-legalize-tf=allow-partial-conversion use-tf2xla-fallback=true use-tf2xla-hlo-importer=false" %s | FileCheck --check-prefix UNSPECIFIED_FALLBACK_DEVICE %s +// RUN: tf-opt "-xla-legalize-tf=allow-partial-conversion use-tf2xla-fallback=true device-type=INVALID_DEVICE_TYPE use-tf2xla-hlo-importer=false" %s | FileCheck --check-prefix UNSUPPORTED_FALLBACK_DEVICE %s // We run this test four times: // 1) Legalize without using TF2XLA fallback (ops cannot be legalized). 
diff --git a/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-prefer-tf2xla.mlir b/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-prefer-tf2xla.mlir index 2f69349d13b..a2a20eb558c 100644 --- a/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-prefer-tf2xla.mlir +++ b/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-prefer-tf2xla.mlir @@ -1,5 +1,5 @@ -// RUN: tf-opt "-xla-legalize-tf=allow-partial-conversion device-type=XLA_CPU_JIT legalize-chlo=false use-tf2xla-fallback=true prefer-tf2xla=true" %s | FileCheck %s -// RUN: tf-opt "-xla-legalize-tf=allow-partial-conversion device-type=XLA_CPU_JIT legalize-chlo=false prefer-tf2xla=true" %s | FileCheck --check-prefix NOFALLBACK %s +// RUN: tf-opt "-xla-legalize-tf=allow-partial-conversion device-type=XLA_CPU_JIT legalize-chlo=false use-tf2xla-fallback=true prefer-tf2xla=true use-tf2xla-hlo-importer=false" %s | FileCheck %s +// RUN: tf-opt "-xla-legalize-tf=allow-partial-conversion device-type=XLA_CPU_JIT legalize-chlo=false prefer-tf2xla=true use-tf2xla-hlo-importer=false" %s | FileCheck --check-prefix NOFALLBACK %s module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 268 : i32}} { @@ -92,4 +92,51 @@ func.func @simple_strided_slice(%input: tensor<4x8xf32>) -> tensor<3x2xf32> { func.return %output : tensor<3x2xf32> } +//===----------------------------------------------------------------------===// +// Fused op legalizations. +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: fused_conv2d +func.func @fused_conv2d(%input: tensor<1x300x300x40xi8>, + %filter: tensor<3x3x40x40xi8>, + %bias: tensor<40xf32>, + %act: tensor<0xi8>) -> tensor<1x300x300x40xi8> { + + // CHECK: %[[v0:.*]] = mhlo.constant dense<1.000000e+00> : tensor + // CHECK-NEXT: %[[v1:.*]] = mhlo.constant dense<2.000000e+00> : tensor + // CHECK-NEXT: %[[v2:.*]] = mhlo.constant dense<1.000000e+00> : tensor + // CHECK-NEXT: %[[v3:.*]] = mhlo.constant dense<2.000000e+00> : tensor + // CHECK-NEXT: %[[v4:.*]] = mhlo.convert %arg0 : (tensor<1x300x300x40xi8>) -> tensor<1x300x300x40xf32> + // CHECK-NEXT: %[[v5:.*]] = mhlo.convert %arg1 : (tensor<3x3x40x40xi8>) -> tensor<3x3x40x40xf32> + // CHECK: %[[v6:.*]] = mhlo.convolution(%[[v4]], %[[v5]]) + // CHECK-SAME{LITERAL}: dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f] + // CHECK-SAME{LITERAL}: window = {stride = [1, 1], pad = [[1, 1], [1, 1]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]} + // CHECK-SAME: batch_group_count = 1 + // CHECK-SAME: feature_group_count = 1 + // CHECK-NEXT: %[[v7:.*]] = mhlo.convert %[[v6]] : tensor<1x300x300x40xf32> + // CHECK-NEXT: %[[v8:.*]] = "mhlo.broadcast_in_dim"(%[[v2]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor) -> tensor<1x300x300x40xf32> + // CHECK-NEXT: %[[v9:.*]] = mhlo.multiply %[[v7]], %[[v8]] : tensor<1x300x300x40xf32> + // CHECK-NEXT: %[[v10:.*]] = mhlo.convert %arg2 : tensor<40xf32> + // CHECK-NEXT: %[[v11:.*]] = "mhlo.broadcast_in_dim"(%[[v10]]) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<40xf32>) -> tensor<1x300x300x40xf32> + // CHECK-NEXT: %[[v12:.*]] = mhlo.add %[[v9]], %[[v11]] : tensor<1x300x300x40xf32> + // CHECK-NEXT: %[[v13:.*]] = mhlo.constant dense<0.000000e+00> : tensor + // CHECK-NEXT: %[[v14:.*]] = "mhlo.broadcast_in_dim"(%[[v13]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor) -> tensor<1x300x300x40xf32> + // CHECK-NEXT: %[[v15:.*]] = mhlo.maximum %[[v12]], %[[v14]] : tensor<1x300x300x40xf32> + // CHECK-DAG: %[[v16:.*]] = mhlo.constant 
dense<-1.280000e+02> : tensor + // CHECK-DAG: %[[v17:.*]] = mhlo.constant dense<1.270000e+02> : tensor + // CHECK-DAG: %[[v18:.*]] = "mhlo.broadcast_in_dim"(%[[v16]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor) -> tensor<1x300x300x40xf32> + // CHECK-DAG: %[[v19:.*]] = "mhlo.broadcast_in_dim"(%[[v17]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor) -> tensor<1x300x300x40xf32> + // CHECK: %[[v20:.*]] = mhlo.clamp %[[v18]], %[[v15]], %[[v19]] : tensor<1x300x300x40xf32> + // CHECK-NEXT: %[[v21:.*]] = mhlo.round_nearest_even %[[v20]] : tensor<1x300x300x40xf32> + // CHECK-NEXT: %[[v22:.*]] = mhlo.convert %[[v21]] : (tensor<1x300x300x40xf32>) -> tensor<1x300x300x40xi8> + // CHECK-NEXT: return %[[v22]] : tensor<1x300x300x40xi8> + + %input_scale = "tf.Const"() {value = dense<1.0> : tensor} : () -> tensor + %side_input_scale = "tf.Const"() {value = dense<2.0> : tensor} : () -> tensor + %conv2d = "tf._FusedConv2D"(%input, %filter, %bias, %act, %input_scale, %side_input_scale) { + data_format = "NHWC", dilations = [1, 1, 1, 1], epsilon = 9.99999974E-5 : f32, explicit_paddings = [], filter_format = "HWIO", fused_ops = ["BiasAdd", "Relu"], leakyrelu_alpha = 2.000000e-01 : f32, num_args = 2 : i64, operand_segment_sizes = array, padding = "SAME", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true + } : (tensor<1x300x300x40xi8>, tensor<3x3x40x40xi8>, tensor<40xf32>, tensor<0xi8>, tensor, tensor) -> tensor<1x300x300x40xi8> + func.return %conv2d : tensor<1x300x300x40xi8> } + +} \ No newline at end of file diff --git a/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-with-tf2xla-hlo-importer.mlir b/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-with-tf2xla-hlo-importer.mlir new file mode 100644 index 00000000000..1d6bfb6bcd7 --- /dev/null +++ b/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-with-tf2xla-hlo-importer.mlir @@ -0,0 +1,556 @@ +// RUN: tf-opt "-xla-legalize-tf=device-type=XLA_CPU_JIT allow-partial-conversion=true prefer-tf2xla=true use-tf2xla-fallback=true use-tf2xla-hlo-importer=true" %s -verify-diagnostics -mlir-disable-threading | FileCheck %s + +module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 268 : i32}} { + // CHECK-LABEL: binary_op + func.func @binary_op(%arg0: tensor<2xf32>, %arg1: tensor<2xf32>) -> tensor<2xf32> { + // CHECK: mhlo.atan2 + %0 = "tf.Atan2"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> + func.return %0 : tensor<2xf32> + } + + // CHECK-LABEL: multiple_return_values + func.func @multiple_return_values(%arg0: tensor<3xi64>) -> tensor { + %0:3 = "tf.Unpack"(%arg0) {axis = 0 : i64} : (tensor<3xi64>) -> (tensor, tensor, tensor) + // CHECK: return %1 : tensor + func.return %0#0 : tensor + } + + // CHECK-LABEL: constant_parameter + func.func @constant_parameter(%arg0: tensor<2xf32>) -> tensor<2xf32> { + %0 = "tf.Const"() {value = dense<1.42> : tensor<2xf32>} : () -> tensor<2xf32> + // CHECK: mhlo.atan2 %arg0, %4 + %1 = "tf.Atan2"(%arg0, %0) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> + func.return %0 : tensor<2xf32> + } + + // CHECK-LABEL: uses_translated_return_type + func.func @uses_translated_return_type(%arg0: tensor<3xf32>) -> tensor { + // CHECK: tensor.cast %{{[0-9]+}} : tensor> to tensor + %y, %idx = "tf.Unique"(%arg0) {device = ""} : (tensor<3xf32>) -> (tensor, tensor<3xi32>) + return %y : tensor + } + + // CHECK-LABEL: @abs + func.func @abs(%arg0: tensor<2xf32>) -> tensor<2xf32> { + // CHECK-NOT: tf.Abs + %0 = "tf.Abs"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + func.return %0 : 
tensor<2xf32> + } + + // CHECK-LABEL: func @testBroadcastGradientArgs + func.func @testBroadcastGradientArgs(%s0: tensor<4xi32>, %s1: tensor<4xi32>) -> (tensor<1xi32>, tensor<0xi32>) { + // CHECK: tf.BroadcastGradientArgs + %r0, %r1 = "tf.BroadcastGradientArgs"(%s0, %s1) : (tensor<4xi32>, tensor<4xi32>) -> (tensor<1xi32>, tensor<0xi32>) + func.return %r0, %r1 : tensor<1xi32>, tensor<0xi32> + } + + // CHECK-LABEL: @acos + func.func @acos(%arg0: tensor<2xf32>) -> tensor<2xf32> { + // CHECK-NOT: tf.Acos + %0 = "tf.Acos"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + func.return %0 : tensor<2xf32> + } + + // CHECK-LABEL: strided_slice_uses_mlir + func.func @strided_slice_uses_mlir(%input: tensor<4x8xf32>) -> tensor<3x2xf32> { + %begin = "tf.Const"() {value = dense<[0, 1]> : tensor<2xi32>} : () -> (tensor<2xi32>) + %end = "tf.Const"() {value = dense<[3, 7]> : tensor<2xi32>} : () -> (tensor<2xi32>) + %strides = "tf.Const"() {value = dense<[1, 3]> : tensor<2xi32>} : () -> (tensor<2xi32>) + + // CHECK-NOT: tf.StridedSlice + %output = "tf.StridedSlice"(%input, %begin, %end, %strides) + : (tensor<4x8xf32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<3x2xf32> + func.return %output : tensor<3x2xf32> + } + + // CHECK-LABEL: func @random_uniform_uses_mlir + func.func @random_uniform_uses_mlir(%arg0: tensor<3xi32>) -> tensor<12x?x64xf32> { + // CHECK-NOT: tf.RandomUniform + %0 = "tf.RandomUniform"(%arg0) : (tensor<3xi32>) -> tensor<12x?x64xf32> + func.return %0 : tensor<12x?x64xf32> + } + + // CHECK-LABEL: unknown_op + func.func @unknown_op(%arg0: tensor<2xf32>) -> tensor<2xf32> { + // CHECK: tf.CustomTestOp + %0 = "tf.CustomTestOp"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + + func.return %0 : tensor<2xf32> + } + + // CHECK-LABEL: add_v2 + func.func @add_v2(%arg0: tensor<2xi32>) -> tensor<2xi32> { + // CHECK: mhlo.add %arg0, %arg0 : tensor<2xi32> + %0 = "tf.AddV2"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> + func.return %0: tensor<2xi32> + } + + // CHECK-LABEL: not_allowlisted_op + func.func @not_allowlisted_op(%arg0: tensor<3xi32>, %arg1: tensor, %arg2: tensor) -> tensor { + // CHECK: tf.TensorListReserve + %0 = "tf.TensorListReserve"(%arg0, %arg1) : (tensor<3xi32>, tensor) -> tensor>> + // CHECK: tf.TensorListGetItem + %1 = "tf.TensorListGetItem"(%0, %arg2, %arg0) : (tensor>>, tensor, tensor<3xi32>) -> tensor + func.return %1 : tensor + } + + // CHECK-LABEL: unranked_operand + func.func @unranked_operand(%arg0: tensor<*xf32>) -> tensor<*xf32> { + // CHECK: tf.Atan2 + // expected-remark@+1 {{lowering requires bounded tensor operands}} + %0 = "tf.Atan2"(%arg0, %arg0) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> + + func.return %0 : tensor<*xf32> + } + + // CHECK-LABEL: dynamic_operand + func.func @dynamic_operand(%arg0: tensor) -> tensor { + // CHECK: tf.Atan2 + // expected-remark@+1 {{lowering requires bounded tensor operands}} + %0 = "tf.Atan2"(%arg0, %arg0) : (tensor, tensor) -> tensor + + func.return %0 : tensor + } + + // CHECK-LABEL: tuple_type + func.func @tuple_type(%arg0: tuple, tensor>) -> tensor { + // Verifies that the pass can handle operands of non-tensor type like tuple + // from non TensorFlow ops. 
+ %0 = "mhlo.get_tuple_element"(%arg0) {index = 0 : i32} : (tuple, tensor>) -> tensor + func.return %0 : tensor + } + + // CHECK-LABEL: unsupported_dtype + func.func @unsupported_dtype(%arg0: tensor<2x!tf_type.variant>) -> tensor<2x!tf_type.variant> { + // CHECK: tf.AddN + // expected-remark@+1 {{skipping legalization due to unsupported type 'tensor<2x!tf_type.variant>'}} + %0 = "tf.AddN"(%arg0, %arg0) : (tensor<2x!tf_type.variant>, tensor<2x!tf_type.variant>) -> tensor<2x!tf_type.variant> + + func.return %0 : tensor<2x!tf_type.variant> + } + + // CHECK-LABEL: multiple_dialect_ops + func.func @multiple_dialect_ops(%arg0: tensor<2xf32>) -> tensor<2xf32> { + // CHECK: mhlo.negate + %0 = "mhlo.negate"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + // CHECK: mhlo.atan2 + %1 = "tf.Atan2"(%arg0, %0) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> + + func.return %1 : tensor<2xf32> + } + + // CHECK-LABEL: binary_op_broadcast + func.func @binary_op_broadcast(%arg0: tensor<4x1xf32>, %arg1: tensor<4x1x4xf32>) -> tensor<4x4x4xf32> { + // CHECK: %[[BROADCAST0:.*]] = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<4x1xf32>) -> tensor<4x4x1xf32> + // CHECK: %[[RESHAPE0:.*]] = mhlo.reshape %[[BROADCAST0]] : (tensor<4x4x1xf32>) -> tensor<4x4xf32> + // CHECK: %[[UPDATED_ARG0:.*]] = "mhlo.broadcast_in_dim"(%[[RESHAPE0]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<4x4xf32>) -> tensor<4x4x4xf32> + + // CHECK: %[[RESHAPE1:.*]] = mhlo.reshape %arg1 : (tensor<4x1x4xf32>) -> tensor<4x4xf32> + // CHECK: %[[UPDATED_ARG1:.*]] = "mhlo.broadcast_in_dim"(%[[RESHAPE1]]) {broadcast_dimensions = dense<[0, 2]> : tensor<2xi64>} : (tensor<4x4xf32>) -> tensor<4x4x4xf32> + + // CHECK: %[[RESULT:.*]] = mhlo.atan2 %[[UPDATED_ARG0]], %[[UPDATED_ARG1]] : tensor<4x4x4xf32> + // CHECK: return %[[RESULT]] : tensor<4x4x4xf32> + + %0 = "tf.Atan2"(%arg0, %arg1) : (tensor<4x1xf32>, tensor<4x1x4xf32>) -> tensor<4x4x4xf32> + func.return %0: tensor<4x4x4xf32> + } + + // CHECK-LABEL: func @ternary_op + func.func @ternary_op(%arg0: tensor<2xi1>, %arg1: tensor<2xi32>, %arg2: tensor<2xi32>) -> tensor<2xi32> { + // CHECK: mhlo.select %arg0, %arg1, %arg2 + %0 = "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor<2xi1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> + func.return %0: tensor<2xi32> + } + + // CHECK-LABEL: func @convert + func.func @convert(%arg0: tensor<2xi32>) -> tensor<2xf32> { + // CHECK: mhlo.convert %arg0 : (tensor<2xi32>) -> tensor<2xf32> + %0 = "tf.Cast"(%arg0) {Truncate = false} : (tensor<2xi32>) -> tensor<2xf32> + func.return %0 : tensor<2xf32> + } + + // CHECK-LABEL: func @constant + func.func @constant(%arg0: tensor) -> tensor { + // CHECK: %[[ONE:.*]] = mhlo.constant dense<1.000000e+00> : tensor + // CHECK: %[[RESULT:.*]] = mhlo.divide %[[ONE]], %arg0 : tensor + // CHECK: return %[[RESULT]] + + %0 = "tf.Inv"(%arg0) : (tensor) -> tensor + func.return %0 : tensor + } + + // CHECK-LABEL: func @const_inputs + // CHECK-SAME: (%[[ARG0:.*]]: tensor<2x2xf64>, %[[ARG1:.*]]: tensor, + func.func @const_inputs(%arg0: tensor<2x2xf64>, %arg1: tensor, %arg2: tensor<2xi32>, %arg3: tensor<2xi32>, %arg4: tensor<2xi32>) -> tensor<6x5xf64> { + + // CHECK: "mhlo.pad"(%[[ARG0]], %[[ARG1]]) + // CHECK-SAME-DAG: edge_padding_high = dense<[1, 2]> : tensor<2xi64> + // CHECK-SAME-DAG: edge_padding_low = dense<[2, 1]> : tensor<2xi64> + // CHECK-SAME-DAG: interior_padding = dense<[1, 0]> : tensor<2xi64> + + %0 = mhlo.constant dense<[2, 1]> : tensor<2xi32> + %1 = mhlo.constant dense<[1, 
2]> : tensor<2xi32> + %2 = mhlo.constant dense<[1, 0]> : tensor<2xi32> + %3 = "tf.XlaPad"(%arg0, %arg1, %0, %1, %2) : (tensor<2x2xf64>, tensor, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<6x5xf64> + func.return %3 : tensor<6x5xf64> + } + + func.func @non_const_inputs(%arg0: tensor<2x2xf64>, %arg1: tensor, %arg2: tensor<2xi32>, %arg3: tensor<2xi32>, %arg4: tensor<2xi32>) -> tensor<6x5xf64> { + // expected-remark@+1 {{lowering requires operand #2 to be a constant}} + %0 = "tf.XlaPad"(%arg0, %arg1, %arg2, %arg3, %arg4) : (tensor<2x2xf64>, tensor, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<6x5xf64> + func.return %0 : tensor<6x5xf64> + } + + // CHECK-LABEL: dynamic_result_type + func.func @dynamic_result_type(%arg0: tensor<2xf32>) -> tensor<*xf32> { + // CHECK: %[[RESULT:.*]] = mhlo.atan2 %arg0, %arg0 : tensor<2xf32> + // CHECK: tensor.cast %[[RESULT]] : tensor<2xf32> to tensor<*xf32> + %0 = "tf.Atan2"(%arg0, %arg0) : (tensor<2xf32>, tensor<2xf32>) -> tensor<*xf32> + + // return %[[RESULT]] + func.return %0 : tensor<*xf32> + } + + func.func @truncated_normal() -> tensor<2x2xf32> { + // CHECK-NOT: tf.TruncatedNormal + %0 = mhlo.constant dense<[2, 2]> : tensor<2xi32> + %1 = "tf.TruncatedNormal"(%0) {T = i32, device = "", dtype = f32, seed = 0 : i64, seed2 = 1950157571 : i64} : (tensor<2xi32>) -> tensor<2x2xf32> + func.return %1 : tensor<2x2xf32> + } + + // CHECK-LABEL: dynamic_update_slice + // CHECK-SAME: (%[[ARG0:.*]]: tensor<3x4xi32>, %[[ARG1:.*]]: tensor<2x2xi32>, %[[ARG2:.*]]: tensor<2xi32> + func.func @dynamic_update_slice(%arg0: tensor<3x4xi32>, %arg1: tensor<2x2xi32>, %arg2: tensor<2xi32>) -> tensor<3x4xi32> { + + // CHECK: %[[SLICE0:.*]] = "mhlo.slice"(%[[ARG2]]) + // CHECK-DAG-SAME: start_indices = dense<0> : tensor<1xi64> + // CHECK-DAG-SAME: limit_indices = dense<1> : tensor<1xi64> + // CHECK-DAG-SAME: strides = dense<1> : tensor<1xi64> + // CHECK-SAME: (tensor<2xi32>) -> tensor<1xi32> + // CHECK: %[[DIM0:.*]] = mhlo.reshape %[[SLICE0]] : (tensor<1xi32>) -> tensor + + // CHECK: %[[SLICE1:.*]] = "mhlo.slice"(%[[ARG2]]) + // CHECK-DAG-SAME: start_indices = dense<1> : tensor<1xi64> + // CHECK-DAG-SAME: limit_indices = dense<2> : tensor<1xi64> + // CHECK-DAG-SAME: strides = dense<1> : tensor<1xi64> + // CHECK-SAME: (tensor<2xi32>) -> tensor<1xi32> + // CHECK: %[[DIM1:.*]] = mhlo.reshape %[[SLICE1]] : (tensor<1xi32>) -> tensor + + // CHECK: mhlo.dynamic_update_slice %[[ARG0]], %[[ARG1]], %[[DIM0]], %[[DIM1]] + + %0 = "tf.XlaDynamicUpdateSlice"(%arg0, %arg1, %arg2) : (tensor<3x4xi32>, tensor<2x2xi32>, tensor<2xi32>) -> tensor<3x4xi32> + func.return %0: tensor<3x4xi32> + } + + // CHECK-LABEL: @sparse_to_dense + // CHECK-SAME: (%[[ARG0:.*]]: tensor<3x2xi32>, %[[ARG1:.*]]: tensor<3xf32>, %[[ARG2:.*]]: tensor) + func.func @sparse_to_dense(%arg0: tensor<3x2xi32>, %arg1: tensor<3xf32>, %arg2: tensor) -> tensor<3x3xf32> { + + // CHECK: %[[DEFAULT:.*]] = "mhlo.broadcast_in_dim"(%[[ARG2]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor) -> tensor<3x3xf32> + + // CHECK: %[[RESULT:.*]] = "mhlo.scatter"(%[[DEFAULT]], %[[ARG0]], %[[ARG1]]) ({ + // CHECK: ^bb0(%[[ARG3:.*]]: tensor, %[[ARG4:.*]]: tensor): + // CHECK: mhlo.return %[[ARG4]] : tensor + // CHECK: }) + // CHECK-SAME: indices_are_sorted = false + // CHECK-SAME: scatter_dimension_numbers + // CHECK-SAME: inserted_window_dims = [0, 1] + // CHECK-SAME: scatter_dims_to_operand_dims = [0, 1] + // CHECK-SAME: index_vector_dim = 1 + // CHECK-SAME: unique_indices = false + // CHECK-SAME: (tensor<3x3xf32>, 
tensor<3x2xi32>, tensor<3xf32>) -> tensor<3x3xf32> + + // return %[[RESULT]] : tensor<3x3xf32> + + %cst = mhlo.constant dense<3> : tensor<2xi32> + %0 = "tf.SparseToDense"(%arg0, %cst, %arg1, %arg2) {validate_indices = true}: (tensor<3x2xi32>, tensor<2xi32>, tensor<3xf32>, tensor) -> tensor<3x3xf32> + func.return %0 : tensor<3x3xf32> + } + + // CHECK-LABEL: reverse_sequence + func.func @reverse_sequence(%arg0: tensor<4x2x3x1x1xi32>, %arg1: tensor<3xi32>) -> tensor<4x2x3x1x1xi32> { + // CHECK-NOT: tf.ReverseSequence + %0 = "tf.ReverseSequence"(%arg0, %arg1) {batch_dim = 2 : i64, seq_dim = 0 : i64}: (tensor<4x2x3x1x1xi32>, tensor<3xi32>) -> tensor<4x2x3x1x1xi32> + func.return %0 : tensor<4x2x3x1x1xi32> + } + + // CHECK-LABEL: mirror_pad + func.func @mirror_pad(%arg0: tensor<2x3xcomplex>) -> tensor<4x7xcomplex> { + %0 = mhlo.constant dense<[[1, 1], [2, 2]]> : tensor<2x2xi32> + // CHECK-NOT: tf.MirrorPad + %1 = "tf.MirrorPad"(%arg0, %0) {mode = "SYMMETRIC"} : (tensor<2x3xcomplex>, tensor<2x2xi32>) -> tensor<4x7xcomplex> + func.return %1 : tensor<4x7xcomplex> + } + + // CHECK-LABEL: bucketize + func.func @bucketize(%arg0: tensor<2x5xf32>) -> tensor<2x5xi32> { + // CHECK-NOT: tf.Bucketize + %0 = "tf.Bucketize"(%arg0) {boundaries = [0.000000e+00 : f32, 3.000000e+00 : f32, 8.000000e+00 : f32, 1.100000e+01 : f32]} : (tensor<2x5xf32>) -> tensor<2x5xi32> + func.return %0 : tensor<2x5xi32> + } + + // CHECK-LABEL: arg_min + func.func @arg_min(%arg0: tensor<6xf64>) -> tensor { + // CHECK-NOT: ArgMin + %0 = mhlo.constant dense<0> : tensor + %1 = "tf.ArgMin"(%arg0, %0) : (tensor<6xf64>, tensor) -> tensor + func.return %1 : tensor + } + + // CHECK-LABEL: non_max_suppression_v4 + func.func @non_max_suppression_v4(%arg0: tensor<3x4xf32>, %arg1: tensor<3xf32>, %arg2: tensor, %arg3: tensor) -> tensor<2xi32> { + %max_size = mhlo.constant dense<2> : tensor + // CHECK-NOT: tf.NonMaxSuppressionV4 + %0:2 = "tf.NonMaxSuppressionV4"(%arg0, %arg1, %max_size, %arg2, %arg3) {pad_to_max_output_size = true}: (tensor<3x4xf32>, tensor<3xf32>, tensor, tensor, tensor) -> (tensor<2xi32>, tensor) + func.return %0#0 : tensor<2xi32> + } + + // CHECK-LABEL: bessel_i0e + func.func @bessel_i0e(%arg0: tensor<3xf16>, %arg1: tensor<3xf32>, %arg2: tensor<3xf64>) -> (tensor<3xf16>, tensor<3xf32>, tensor<3xf64>) { + // CHECK-NOT: tf.BesselI0e + %0 = "tf.BesselI0e"(%arg0) : (tensor<3xf16>) -> (tensor<3xf16>) + %1 = "tf.BesselI0e"(%arg1) : (tensor<3xf32>) -> (tensor<3xf32>) + %2 = "tf.BesselI0e"(%arg2) : (tensor<3xf64>) -> (tensor<3xf64>) + func.return %0, %1, %2 : tensor<3xf16>, tensor<3xf32>, tensor<3xf64> + } + + // CHECK-LABEL: bessel_i1e + func.func @bessel_i1e(%arg0: tensor<3xf16>, %arg1: tensor<3xf32>, %arg2: tensor<3xf64>) -> (tensor<3xf16>, tensor<3xf32>, tensor<3xf64>) { + // CHECK-NOT: tf.BesselI1e + %0 = "tf.BesselI1e"(%arg0) : (tensor<3xf16>) -> (tensor<3xf16>) + %1 = "tf.BesselI1e"(%arg1) : (tensor<3xf32>) -> (tensor<3xf32>) + %2 = "tf.BesselI1e"(%arg2) : (tensor<3xf64>) -> (tensor<3xf64>) + func.return %0, %1, %2 : tensor<3xf16>, tensor<3xf32>, tensor<3xf64> + } + + // CHECK-LABEL: diag + func.func @diag(%arg0: tensor<2xf32>) -> tensor<2x2xf32> { + // CHECK-NOT: tf.Diag + %0 = "tf.Diag"(%arg0) : (tensor<2xf32>) -> tensor<2x2xf32> + func.return %0 : tensor<2x2xf32> + } + + // CHECK-LABEL: random_uniform_int + func.func @random_uniform_int(%arg0: tensor, %arg1: tensor) -> tensor<1000xi32> { + %0 = "tf.Const"() {value = dense<1000> : tensor<1xi32>} : () -> tensor<1xi32> + // CHECK-NOT: tf.RandomUniformInt + %1 = 
"tf.RandomUniformInt"(%0, %arg0, %arg1) {seed = 0 : i64, seed2 = 0 : i64} : (tensor<1xi32>, tensor, tensor) -> tensor<1000xi32> + func.return %1 : tensor<1000xi32> + } + + // CHECK-LABEL: multinomial + func.func @multinomial(%arg0: tensor<2x4xf32>, %seed: tensor, %seed2: tensor) -> tensor<2x10xi32> { + // CHECK-NOT: tf.Multinomial + %samples = "tf.Const"() { value = dense<10> : tensor } : () -> tensor + %1 = "tf.Multinomial"(%arg0, %samples) {seed = 0, seed2 = 0}: (tensor<2x4xf32>, tensor) -> tensor<2x10xi32> + func.return %1 : tensor<2x10xi32> + } + + // CHECK-LABEL: @set_dynamic_dimension_size + func.func @set_dynamic_dimension_size(%input: tensor<4xf32>, %size: tensor) -> tensor { + %dimension = "tf.Const"() { value = dense<0> : tensor } : () -> tensor + // CHECK: mhlo.set_dimension_size + // CHECK-SAME: {dimension = 0 : i64} : (tensor<4xf32>, tensor) -> tensor> + %0 = "tf.XlaSetDynamicDimensionSize"(%input, %dimension, %size) : (tensor<4xf32>, tensor, tensor) -> tensor + func.return %0 : tensor + } + + // CHECK-LABEL: unique + func.func @unique(%arg0: tensor<5xf32>) -> (tensor, tensor) { + // CHECK-NOT: tf.Unique + %0, %1 = "tf.Unique"(%arg0) : (tensor<5xf32>) -> (tensor, tensor) + func.return %0, %1 : tensor , tensor + } + + // CHECK-LABEL: @erfinv + func.func @erfinv(%input: tensor<4xf32>) -> tensor<4xf32> { + // CHECK-NOT: tf.Erfinv + %0 = "tf.Erfinv"(%input) : (tensor<4xf32>) -> tensor<4xf32> + func.return %0 : tensor<4xf32> + } + + // CHECK-LABEL: @ndtri + func.func @ndtri(%input: tensor<4xf32>) -> tensor<4xf32> { + // CHECK-NOT: tf.Ndtri + %0 = "tf.Ndtri"(%input) : (tensor<4xf32>) -> tensor<4xf32> + func.return %0 : tensor<4xf32> + } + + // CHECK-LABEL: @fake_param + func.func @fake_param() -> tensor<4xf32> { + // CHECK-NOT: tf.FakeParam + %0 = "tf.FakeParam"() {shape = #tf_type.shape<4>} : () -> tensor<4xf32> + func.return %0 : tensor<4xf32> + } + + // CHECK-LABEL: @parameterized_truncated_normal + func.func @parameterized_truncated_normal(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor) -> tensor<10000000xf32> { + %0 = "tf.Const"() {value = dense<10000000> : tensor<1xi32>} : () -> tensor<1xi32> + // CHECK-NOT: tf.ParameterizedTruncatedNormal + %1 = "tf.ParameterizedTruncatedNormal"(%0, %arg0, %arg1, %arg2, %arg3) {seed = 0 : i64, seed2 = 0 : i64} : (tensor<1xi32>, tensor, tensor, tensor, tensor) -> tensor<10000000xf32> + func.return %1 : tensor<10000000xf32> + } + + // Check XlaSpmdFullToShardShape's conversion from split sharding to manual + // sharding. 
+ // The split sharding is: + // type: OTHER + // tile_assignment_dimensions: 2 + // tile_assignment_dimensions: 1 + // tile_assignment_devices: 0 + // tile_assignment_devices: 1 + // Serialized string: + // "\08\03\1A\02\02\01\22\02\00\01" + // The manual sharding is: + // type: MANUAL + // Serialized string: + // "\08\04" + + // CHECK-LABEL: @xla_spmd_full_to_shard_shape + func.func @xla_spmd_full_to_shard_shape(%arg0: tensor<2x2xi64>) -> (tensor<1x2xi64>) { + // CHECK: %[[SHARDING:.*]] = mhlo.custom_call @Sharding(%arg0) {backend_config = "", mhlo.sharding = "{devices=[2,1]0,1}"} : (tensor<2x2xi64>) -> tensor<2x2xi64> + // CHECK: %[[MANUAL:.*]] = mhlo.custom_call @SPMDFullToShardShape(%[[SHARDING]]) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<2x2xi64>) -> tensor<1x2xi64> + // CHECK: return %[[MANUAL]] + %0 = "tf.XlaSpmdFullToShardShape"(%arg0) {dim = -1 : i64, manual_sharding = "\08\03\1A\02\02\01\22\02\00\01", unspecified_dims = []} : (tensor<2x2xi64>) -> tensor<1x2xi64> + func.return %0 : tensor<1x2xi64> + } + + // Check XlaSpmdShardToFullShape's conversion from manual sharding to split + // sharding. + // The manual sharding is: + // type: MANUAL + // Serialized string: + // "\08\04" + // The split sharding is: + // type: OTHER + // tile_assignment_dimensions: 2 + // tile_assignment_dimensions: 1 + // tile_assignment_devices: 0 + // tile_assignment_devices: 1 + // Serialized string: + // "\08\03\1A\02\02\01\22\02\00\01" + + // CHECK-LABEL: @xla_spmd_shard_to_full_shape + func.func @xla_spmd_shard_to_full_shape(%arg0: tensor<1x2xi64>) -> (tensor<2x2xi64>) { + // CHECK: %[[SHARDING:.*]] = mhlo.custom_call @Sharding(%arg0) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<1x2xi64>) -> tensor<1x2xi64> + // CHECK: %[[FULL:.*]] = mhlo.custom_call @SPMDShardToFullShape(%[[SHARDING]]) {backend_config = "", mhlo.sharding = "{devices=[2,1]0,1}"} : (tensor<1x2xi64>) -> tensor<2x2xi64> + // CHECK: return %[[FULL]] + %0 = "tf.XlaSpmdShardToFullShape"(%arg0) {dim = -1 : i64, full_shape = #tf_type.shape<2x2>, manual_sharding = "\08\03\1A\02\02\01\22\02\00\01", unspecified_dims = []} : (tensor<1x2xi64>) -> tensor<2x2xi64> + func.return %0 : tensor<2x2xi64> + } + + // CHECK-LABEL: @xla_svd + func.func @xla_svd(%arg0: tensor<1x1xf32>) -> (tensor<1xf32>, tensor<1x1xf32>, tensor<1x1xf32>) { + // CHECK-NOT: XlaSvd + %s, %u, %v = "tf.XlaSvd"(%arg0) {max_iter = 1, epsilon = 1.0E-09 : f32, precision_config = ""} : (tensor<1x1xf32>) -> (tensor<1xf32>, tensor<1x1xf32>, tensor<1x1xf32>) + func.return %s, %u, %v : tensor<1xf32>, tensor<1x1xf32>, tensor<1x1xf32> + } + + func.func @identity(%arg0: f32) -> f32 { + func.return %arg0 : f32 + } + + // This test verifies that legalization for ops with a symbol reference attribute + // is not attempted even if they are in the allow-list. XLA op kernels for these + // ops compile the function to HLO on demand, which won't work in our case: the + // function may contain ops that are unsupported in the fallback, nor do we provide + // an XlaCompiler to the kernel. Using an allowed op, Atan2, to protect against + // future addition of a new op with a symbol ref.
+ + // CHECK-LABEL: @atan2_with_symbol_ref + func.func @atan2_with_symbol_ref(%arg0: tensor<2xf32>) -> tensor<2xf32> { + // CHECK: tf.Atan2 + // expected-remark@+1 {{ops with symbol references are not supported}} + %0 = "tf.Atan2"(%arg0, %arg0) {_body = @identity} : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> + + func.return %0 : tensor<2xf32> + } + + func.func private @branch0(tensor<2xf32>) -> tensor<2xf32> + func.func private @branch1(tensor<2xf32>) -> tensor<2xf32> + + func.func @case_with_symbol_ref(%arg0: tensor, %arg1: tensor<2xf32>) -> tensor<2xf32> { + // CHECK: tf.Case + // expected-remark@+1 {{ops with symbol references are not supported}} + %0 = "tf.Case"(%arg0, %arg1) {branches = [@branch0, @branch1], is_stateless = false} : (tensor, tensor<2xf32>) -> tensor<2xf32> + func.return %0 : tensor<2xf32> + } + + // CHECK-LABEL: const + func.func @const() -> tensor<2xf32> { + // CHECK: mhlo.const + %cst = "tf.Const"() {value = dense<2.0> : tensor<2xf32>} : () -> tensor<2xf32> + func.return %cst : tensor<2xf32> + } + + // CHECK-LABEL: @bounds_propagation + func.func @bounds_propagation(%input: tensor<4xf32>, %size: tensor) -> tensor { + %dimension = "tf.Const"() { value = dense<0> : tensor } : () -> tensor + // CHECK: %[[BOUNDED:.*]] = "mhlo.set_dimension_size" + // CHECK-SAME: {dimension = 0 : i64} : (tensor<4xf32>, tensor) -> tensor> + %0 = "tf.XlaSetDynamicDimensionSize"(%input, %dimension, %size) : (tensor<4xf32>, tensor, tensor) -> tensor + + %axis = "tf.Const"() { value = dense<0> : tensor<1xi32> } : () -> tensor<1xi32> + // CHECK: %[[REVERSED:.*]] = "mhlo.reverse"(%[[BOUNDED]]) + // CHECK-SAME: {dimensions = dense<0> : tensor<1xi64>} + // CHECK-SAME: (tensor>) -> tensor> + %1 = "tf.ReverseV2"(%0, %axis) : (tensor, tensor<1xi32>) -> tensor + + // CHECK: %[[RESULT:.*]] = tensor.cast %[[REVERSED]] : tensor> to tensor + // CHECK: return %[[RESULT]] : tensor + func.return %1 : tensor + } + + // CHECK-LABEL: @bounds_propagation_skip_symbol_ref_ops + func.func @bounds_propagation_skip_symbol_ref_ops(%input: tensor<4xf32>, %size: tensor) -> tensor { + %dimension = "tf.Const"() { value = dense<0> : tensor } : () -> tensor + // CHECK: %[[BOUNDED:.*]] = "mhlo.set_dimension_size" + // CHECK-SAME: {dimension = 0 : i64} : (tensor<4xf32>, tensor) -> tensor> + %0 = "tf.XlaSetDynamicDimensionSize"(%input, %dimension, %size) : (tensor<4xf32>, tensor, tensor) -> tensor + + // CHECK: %[[ORIGINAL:.*]] = tensor.cast %[[BOUNDED]] : tensor> to tensor + + %axis = "tf.Const"() { value = dense<0> : tensor<1xi32> } : () -> tensor<1xi32> + // CHECK: tf.ReverseV2 + // CHECK-SAME: (tensor, tensor<1xi32>) -> tensor + // expected-remark@+1 {{lowering requires bounded tensor operands}} + %1 = "tf.ReverseV2"(%0, %axis) {_body = @identity} : (tensor, tensor<1xi32>) -> tensor + + func.return %1 : tensor + } + + // CHECK-LABEL: func @set_bound + func.func @set_bound(%arg0: tensor) -> tensor { + %bound = "tf.Const"() {value = dense<16> : tensor} : () -> tensor + + // CHECK: %[[RESULT:.*]] = mhlo.custom_call @SetBound(%arg0) {backend_config = "", mhlo.literal = dense<16> : tensor} + %bounded = "tf.XlaSetBound"(%arg0, %bound) : (tensor, tensor) -> tensor + + // CHECK: return %[[RESULT]] + func.return %bounded : tensor + } + + // CHECK-LABEL: func @greater + func.func @greater(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2xi1> { + // CHECK-NEXT: mhlo.compare GT, %arg0, %arg1 + %0 = "tf.Greater"(%arg0, %arg1) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> + func.return %0: tensor<2xi1> + } + + // 
CHECK-LABEL: batchmatmulv2 + func.func @batchmatmulv2(%arg0: tensor<1x4x2xf32>, %arg1: tensor<3x2x4xf32>) -> tensor<3x4x4xf32> { + // CHECK: mhlo.reduce + // CHECK: mhlo.dot_general + // CHECK: mhlo.transpose + %0 = "tf.BatchMatMulV2"(%arg0, %arg1) {T = f32, adj_x = false, adj_y = false, device = ""} : (tensor<1x4x2xf32>, tensor<3x2x4xf32>) -> tensor<3x4x4xf32> + func.return %0 : tensor<3x4x4xf32> + } + + // CHECK-LABEL: approx_topk + func.func @approx_topk(%arg0: tensor>> {tf._user_specified_name = "db"}) -> (tensor<10x10xbf16>) { + %0 = "tf.ReadVariableOp"(%arg0) {device = ""} : (tensor>>) -> tensor<10x500xbf16> + // CHECK: mhlo.compare GT + %values, %indices = "tf.ApproxTopK"(%0) {aggregate_to_topk = true, device = "", is_max_k = true, k = 10 : i64, recall_target = 0.949999988 : f32, reduction_dimension = -1 : i64, reduction_input_size_override = -1 : i64} : (tensor<10x500xbf16>) -> (tensor<10x10xbf16>, tensor<10x10xi32>) + return %values : tensor<10x10xbf16> + } +} diff --git a/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-with-tf2xla.mlir b/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-with-tf2xla.mlir index cc2e7f24709..3e550e0366c 100644 --- a/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-with-tf2xla.mlir +++ b/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-with-tf2xla.mlir @@ -1,4 +1,4 @@ -// RUN: tf-opt "-xla-legalize-tf=device-type=XLA_CPU_JIT allow-partial-conversion=true prefer-tf2xla=true use-tf2xla-fallback=true" %s -verify-diagnostics | FileCheck %s +// RUN: tf-opt "-xla-legalize-tf=device-type=XLA_CPU_JIT allow-partial-conversion=true prefer-tf2xla=true use-tf2xla-fallback=true use-tf2xla-hlo-importer=false" %s -verify-diagnostics | FileCheck %s module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 268 : i32}} { @@ -337,6 +337,54 @@ func.func @parameterized_truncated_normal(%arg0: tensor, %arg1: tensor func.return %1 : tensor<10000000xf32> } +// Check XlaSpmdFullToShardShape's conversion from split sharding to manual +// sharding. +// The split sharding is: +// type: OTHER +// tile_assignment_dimensions: 2 +// tile_assignment_dimensions: 1 +// tile_assignment_devices: 0 +// tile_assignment_devices: 1 +// Serialized string: +// "\08\03\1A\02\02\01\22\02\00\01" +// The manual sharding is: +// type: MANUAL +// Serialized string: +// "\08\04" + +// CHECK-LABEL: @xla_spmd_full_to_shard_shape +func.func @xla_spmd_full_to_shard_shape(%arg0: tensor<2x2xi64>) -> (tensor<1x2xi64>) { + // CHECK: %[[SHARDING:.*]] = mhlo.custom_call @Sharding(%arg0) {backend_config = "", mhlo.sharding = "{devices=[2,1]0,1}"} : (tensor<2x2xi64>) -> tensor<2x2xi64> + // CHECK: %[[MANUAL:.*]] = mhlo.custom_call @SPMDFullToShardShape(%[[SHARDING]]) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<2x2xi64>) -> tensor<1x2xi64> + // CHECK: return %[[MANUAL]] + %0 = "tf.XlaSpmdFullToShardShape"(%arg0) {dim = -1 : i64, manual_sharding = "\08\03\1A\02\02\01\22\02\00\01", unspecified_dims = []} : (tensor<2x2xi64>) -> tensor<1x2xi64> + func.return %0 : tensor<1x2xi64> +} + +// Check XlaSpmdShardToFullShape's conversion from manual sharding to split +// sharding. 
+// The manual sharding is: +// type: MANUAL +// Serialized string: +// "\08\04" +// The split sharding is: +// type: OTHER +// tile_assignment_dimensions: 2 +// tile_assignment_dimensions: 1 +// tile_assignment_devices: 0 +// tile_assignment_devices: 1 +// Serialized string: +// "\08\03\1A\02\02\01\22\02\00\01" + +// CHECK-LABEL: @xla_spmd_shard_to_full_shape +func.func @xla_spmd_shard_to_full_shape(%arg0: tensor<1x2xi64>) -> (tensor<2x2xi64>) { + // CHECK: %[[SHARDING:.*]] = mhlo.custom_call @Sharding(%arg0) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<1x2xi64>) -> tensor<1x2xi64> + // CHECK: %[[FULL:.*]] = mhlo.custom_call @SPMDShardToFullShape(%[[SHARDING]]) {backend_config = "", mhlo.sharding = "{devices=[2,1]0,1}"} : (tensor<1x2xi64>) -> tensor<2x2xi64> + // CHECK: return %[[FULL]] + %0 = "tf.XlaSpmdShardToFullShape"(%arg0) {dim = -1 : i64, full_shape = #tf_type.shape<2x2>, manual_sharding = "\08\03\1A\02\02\01\22\02\00\01", unspecified_dims = []} : (tensor<1x2xi64>) -> tensor<2x2xi64> + func.return %0 : tensor<2x2xi64> +} + // CHECK-LABEL: @xla_svd func.func @xla_svd(%arg0: tensor<1x1xf32>) -> (tensor<1xf32>, tensor<1x1xf32>, tensor<1x1xf32>) { // CHECK-NOT: XlaSvd @@ -428,6 +476,21 @@ func.func @set_bound(%arg0: tensor) -> tensor { func.return %bounded : tensor } +// CHECK-LABEL: @XlaScatterOpNotSupported +func.func @XlaScatterOpNotSupported(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>, %arg2: tensor<4xi32>) -> tensor<8xi32> { + // CHECK: tf.XlaScatter + %0 = "tf.XlaScatter"(%arg0, %arg1, %arg2) {dimension_numbers = "\18\03 \042\03\00\01\02@\04P\04Z\03\01\02\03b\03\01\02\03", indices_are_sorted = false, update_computation = @no_reducer} : (tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) -> tensor<8xi32> + func.return %0 : tensor<8xi32> +} + +// CHECK-LABEL: approx_topk +func.func @approx_topk(%arg0: tensor>> {tf._user_specified_name = "db"}) -> (tensor<10x10xbf16>) { + %0 = "tf.ReadVariableOp"(%arg0) {device = ""} : (tensor>>) -> tensor<10x500xbf16> + // CHECK: mhlo.compare GT + %values, %indices = "tf.ApproxTopK"(%0) {aggregate_to_topk = true, device = "", is_max_k = true, k = 10 : i64, recall_target = 0.949999988 : f32, reduction_dimension = -1 : i64, reduction_input_size_override = -1 : i64} : (tensor<10x500xbf16>) -> (tensor<10x10xbf16>, tensor<10x10xi32>) + return %values : tensor<10x10xbf16> +} + // TODO(hinsu): Add a test with a valid TF op for which tf2xla kernel is // available but doesn't support this instance. 
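Editor's note on the serialized sharding strings quoted in the XlaSpmdFullToShardShape / XlaSpmdShardToFullShape tests above: "\08\03" encodes proto field 1 (type) set to OTHER, "\08\04" encodes type MANUAL, "\1A\02\02\01" is the packed tile_assignment_dimensions [2, 1], and "\22\02\00\01" is the packed tile_assignment_devices [0, 1], matching the text-format dumps in the test comments. A minimal C++ sketch of how such strings could be produced from the xla::OpSharding proto follows; it is illustrative only, not part of this patch, and the include path is assumed from the //tensorflow/compiler/xla:xla_data_proto_cc dependency used elsewhere in this diff.

#include <string>

#include "tensorflow/compiler/xla/xla_data.pb.h"

// Split sharding: a 2x1 tile assignment over devices 0 and 1.
// Expected to serialize to "\08\03\1A\02\02\01\22\02\00\01".
std::string MakeSplitSharding() {
  xla::OpSharding sharding;
  sharding.set_type(xla::OpSharding::OTHER);
  sharding.add_tile_assignment_dimensions(2);
  sharding.add_tile_assignment_dimensions(1);
  sharding.add_tile_assignment_devices(0);
  sharding.add_tile_assignment_devices(1);
  return sharding.SerializeAsString();
}

// Manual sharding. Expected to serialize to "\08\04".
std::string MakeManualSharding() {
  xla::OpSharding sharding;
  sharding.set_type(xla::OpSharding::MANUAL);
  return sharding.SerializeAsString();
}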
} diff --git a/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf.mlir index 78f254d7a23..19fe43f0250 100644 --- a/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf.mlir +++ b/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf.mlir @@ -4059,6 +4059,34 @@ func.func @rng_uniform(%arg0: tensor<3xi32>) -> tensor<12x?x64xf32> { // ----- +// CHECK-LABEL: func @random_uniform_simple +func.func @random_uniform_simple(%arg0: tensor<3xi32>) -> tensor<12x?x64xf32> { + // CHECK-DAG: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor + // CHECK-DAG: %[[ONE:.*]] = mhlo.constant dense<1.000000e+00> : tensor + // CHECK: %[[CONV:.*]] = mhlo.convert %arg0 : (tensor<3xi32>) -> tensor<3xi64> + // CHECK: %[[F32:.*]] = "mhlo.rng"(%[[ZERO]], %[[ONE]], %[[CONV]]) {{.*UNIFORM.*}} -> tensor<12x?x64xf32> + %0 = "tf.RandomUniform"(%arg0) : (tensor<3xi32>) -> tensor<12x?x64xf32> + // CHECK: return %[[F32]] + func.return %0 : tensor<12x?x64xf32> +} + +// ----- + +// CHECK-LABEL: func @random_uniform_with_seeds +func.func @random_uniform_with_seeds(%arg0: tensor<4xi32>) -> tensor<32x12x12x64xf32> { + // CHECK: %0 = mhlo.constant dense<[32, 12, 12, 64]> : tensor<4xi32> + // CHECK-NEXT: %1 = mhlo.constant dense<0.000000e+00> : tensor + // CHECK-NEXT: %2 = mhlo.constant dense<1.000000e+00> : tensor + // CHECK-NEXT: %3 = mhlo.convert %0 : (tensor<4xi32>) -> tensor<4xi64> + // CHECK-NEXT: %4 = "mhlo.rng"(%1, %2, %3) {rng_distribution = #mhlo.rng_distribution} : (tensor, tensor, tensor<4xi64>) -> tensor<32x12x12x64xf32> + %cst = "tf.Const"() {value = dense<[32, 12, 12, 64]> : tensor<4xi32>} : () -> tensor<4xi32> + %0 = "tf.RandomUniform"(%cst) {seed = 87654321 : i64, seed2 = 0 : i64} : (tensor<4xi32>) -> tensor<32x12x12x64xf32> + // CHECK: return %4 : tensor<32x12x12x64xf32> + func.return %0 : tensor<32x12x12x64xf32> +} + +// ----- + // CHECK-LABEL: func @rng_std_normal func.func @rng_std_normal(%arg0: tensor<3xi32>) -> tensor<12x?x64xf32> { // CHECK-DAG: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor @@ -6264,6 +6292,208 @@ func.func @uniform_quantized_convolution(%input: tensor<1x2x2x3xf32>) -> () { func.return } +//===----------------------------------------------------------------------===// +// tf.UniformQuantizedAdd legalization +//===----------------------------------------------------------------------===// + +// ----- + +// CHECK-LABEL: func @uniform_quantized_add +func.func @uniform_quantized_add(%input: tensor<3x2xf32>) -> () { + %input_scales = "tf.Const"() { value = dense<2.0> : tensor } : () -> tensor + %input_zps = "tf.Const"() { value = dense<4> : tensor } : () -> tensor + // tensor_proto that points to dense<127> of type !tf_type.qint32. 
+ %bias = "tf.Const"() { value = #tf_type : tensor<2x!tf_type.qint32> } : () -> tensor<2x!tf_type.qint32> + %bias_scales = "tf.Const"() { value = dense<2.0> : tensor } : () -> tensor + %bias_zps = "tf.Const"() { value = dense<4> : tensor } : () -> tensor + + %output_scales = "tf.Const"() { value = dense<2.0> : tensor } : () -> tensor + %output_zps = "tf.Const"() { value = dense<4> : tensor } : () -> tensor + + // CHECK-DAG: %[[LHS:.*]] = mhlo.uniform_quantize %arg0 : (tensor<3x2xf32>) -> tensor<3x2x!quant.uniform> + // CHECK-DAG: %[[RHS:.*]] = mhlo.constant() + // CHECK-SAME{LITERAL}: {value = dense<127> : tensor<2xi32>} : () -> tensor<2x!quant.uniform> + // CHECK: chlo.broadcast_add %[[LHS]], %[[RHS]] {broadcast_dimensions = dense<1> : tensor<1xi64>} : + // CHECK-SAME: (tensor<3x2x!quant.uniform>, tensor<2x!quant.uniform>) + // CHECK-SAME: -> tensor<3x2x!quant.uniform> + + %0 = "tf.UniformQuantize"(%input, %input_scales, %input_zps) { + quantization_axis = -1 : i64, quantization_min_val = -2147483648 : i64, quantization_max_val = 2147483647 : i64 + } : (tensor<3x2xf32>, tensor, tensor) -> tensor<3x2x!tf_type.qint32> + %1 = "tf.UniformQuantizedAdd"( + %0, %bias, + %input_scales, %input_zps, + %bias_scales, %bias_zps, + %output_scales, %output_zps) { + lhs_quantization_axis = -1 : i64, + lhs_quantization_min_val = -2147483648 : i64, + lhs_quantization_max_val = 2147483647 : i64, + rhs_quantization_axis = -1 : i64, + rhs_quantization_min_val = -2147483648 : i64, + rhs_quantization_max_val = 2147483647 : i64, + output_quantization_axis = -1 : i64, + output_quantization_min_val = -2147483648 : i64, + output_quantization_max_val = 2147483647 : i64} : ( + tensor<3x2x!tf_type.qint32>, tensor<2x!tf_type.qint32>, + tensor, tensor, + tensor, tensor, + tensor, tensor) -> tensor<3x2x!tf_type.qint32> + func.return +} + +// ----- + +// CHECK-LABEL: func @uniform_quantized_add_unknown_lhs_rank +func.func @uniform_quantized_add_unknown_lhs_rank(%input: tensor<*x!tf_type.qint32>) -> () { + %input_scales = "tf.Const"() { value = dense<2.0> : tensor } : () -> tensor + %input_zps = "tf.Const"() { value = dense<4> : tensor } : () -> tensor + // tensor_proto that points to dense<127> of type !tf_type.qint32. 
+ %bias = "tf.Const"() { value = #tf_type : tensor<2x!tf_type.qint32> } : () -> tensor<2x!tf_type.qint32> + %bias_scales = "tf.Const"() { value = dense<2.0> : tensor } : () -> tensor + %bias_zps = "tf.Const"() { value = dense<4> : tensor } : () -> tensor + + %output_scales = "tf.Const"() { value = dense<2.0> : tensor } : () -> tensor + %output_zps = "tf.Const"() { value = dense<4> : tensor } : () -> tensor + %1 = "tf.UniformQuantizedAdd"( + %input, %bias, + %input_scales, %input_zps, + %bias_scales, %bias_zps, + %output_scales, %output_zps) { + lhs_quantization_axis = -1 : i64, + lhs_quantization_min_val = -2147483648 : i64, + lhs_quantization_max_val = 2147483647 : i64, + rhs_quantization_axis = -1 : i64, + rhs_quantization_min_val = -2147483648 : i64, + rhs_quantization_max_val = 2147483647 : i64, + output_quantization_axis = -1 : i64, + output_quantization_min_val = -2147483648 : i64, + output_quantization_max_val = 2147483647 : i64} : ( + tensor<*x!tf_type.qint32>, tensor<2x!tf_type.qint32>, + tensor, tensor, + tensor, tensor, + tensor, tensor) -> tensor<*x!tf_type.qint32> + func.return +} + +// ----- + +// CHECK-LABEL: func @uniform_quantized_add_non_constant_lhs_scales +func.func @uniform_quantized_add_non_constant_lhs_scales( + %input: tensor<*x!tf_type.qint32>, %input_scales: tensor) -> () { + %input_zps = "tf.Const"() { value = dense<4> : tensor } : () -> tensor + // tensor_proto that points to dense<127> of type !tf_type.qint32. + %bias = "tf.Const"() { value = #tf_type : tensor<2x!tf_type.qint32> } : () -> tensor<2x!tf_type.qint32> + %bias_scales = "tf.Const"() { value = dense<2.0> : tensor } : () -> tensor + %bias_zps = "tf.Const"() { value = dense<4> : tensor } : () -> tensor + + %output_scales = "tf.Const"() { value = dense<2.0> : tensor } : () -> tensor + %output_zps = "tf.Const"() { value = dense<4> : tensor } : () -> tensor + %1 = "tf.UniformQuantizedAdd"( + %input, %bias, + %input_scales, %input_zps, + %bias_scales, %bias_zps, + %output_scales, %output_zps) { + lhs_quantization_axis = -1 : i64, + lhs_quantization_min_val = -2147483648 : i64, + lhs_quantization_max_val = 2147483647 : i64, + rhs_quantization_axis = -1 : i64, + rhs_quantization_min_val = -2147483648 : i64, + rhs_quantization_max_val = 2147483647 : i64, + output_quantization_axis = -1 : i64, + output_quantization_min_val = -2147483648 : i64, + output_quantization_max_val = 2147483647 : i64} : ( + tensor<*x!tf_type.qint32>, tensor<2x!tf_type.qint32>, + tensor, tensor, + tensor, tensor, + tensor, tensor) -> tensor<*x!tf_type.qint32> + func.return +} + +//===----------------------------------------------------------------------===// +// tf.UniformQuantizedClipByValue legalization +//===----------------------------------------------------------------------===// + +// ----- + +// CHECK-LABEL: func @uniform_quantized_clip_by_value +func.func @uniform_quantized_clip_by_value(%input: tensor<3x2xf32>) -> () { + %scales = "tf.Const"() { value = dense<2.0> : tensor<2xf32> } : () -> tensor<2xf32> + %zps = "tf.Const"() { value = dense<4> : tensor<2xi32> } : () -> tensor<2xi32> + // tensor_proto that points to dense<127> of type !tf_type.qint32. 
+ %min = "tf.Const"() { value = #tf_type : tensor<2x!tf_type.qint32> } : () -> tensor<2x!tf_type.qint32> + %max = "tf.Const"() { value = #tf_type : tensor<2x!tf_type.qint32> } : () -> tensor<2x!tf_type.qint32> + + // CHECK-DAG: %[[OPERAND:.*]] = mhlo.uniform_quantize %arg0 : (tensor<3x2xf32>) -> tensor<3x2x!quant.uniform> + // CHECK-DAG: %[[MIN:.*]] = mhlo.constant() + // CHECK-SAME{LITERAL}: {value = dense<127> : tensor<2xi32>} : () -> tensor<2x!quant.uniform> + // CHECK: %[[MAX:.*]] = mhlo.constant() + // CHECK-SAME{LITERAL}: {value = dense<127> : tensor<2xi32>} : () -> tensor<2x!quant.uniform> + // CHECK: %[[MIN_CLIPPED:.*]] = chlo.broadcast_maximum %[[OPERAND]], %[[MIN]] {broadcast_dimensions = dense<1> : tensor<1xi64>} : + // CHECK-SAME: (tensor<3x2x!quant.uniform>, tensor<2x!quant.uniform>) + // CHECK-SAME: -> tensor<3x2x!quant.uniform> + // CHECK: chlo.broadcast_minimum %[[MIN_CLIPPED]], %[[MAX]] {broadcast_dimensions = dense<1> : tensor<1xi64>} : + // CHECK-SAME: (tensor<3x2x!quant.uniform>, tensor<2x!quant.uniform>) + // CHECK-SAME: -> tensor<3x2x!quant.uniform> + + %0 = "tf.UniformQuantize"(%input, %scales, %zps) { + quantization_axis = 1 : i64, quantization_min_val = -2147483648 : i64, quantization_max_val = 2147483647 : i64 + } : (tensor<3x2xf32>, tensor<2xf32>, tensor<2xi32>) -> tensor<3x2x!tf_type.qint32> + %1 = "tf.UniformQuantizedClipByValue"(%0, %min, %max, %scales, %zps) { + quantization_axis = 1 : i64, + quantization_min_val = -2147483648 : i64, + quantization_max_val = 2147483647 : i64 + } : (tensor<3x2x!tf_type.qint32>, tensor<2x!tf_type.qint32>, tensor<2x!tf_type.qint32>, tensor<2xf32>, tensor<2xi32>) -> tensor<3x2x!tf_type.qint32> + func.return +} + +// ----- + +// CHECK-LABEL: func @uniform_quantized_clip_by_value_min_not_const +func.func @uniform_quantized_clip_by_value_min_not_const(%input: tensor<3x2x!tf_type.qint32>, %min: tensor<2x!tf_type.qint32>) -> () { + %scales = "tf.Const"() { value = dense<2.0> : tensor<2xf32> } : () -> tensor<2xf32> + %zps = "tf.Const"() { value = dense<4> : tensor<2xi32> } : () -> tensor<2xi32> + // tensor_proto that points to dense<127> of type !tf_type.qint32. + %max = "tf.Const"() { value = #tf_type : tensor<2x!tf_type.qint32> } : () -> tensor<2x!tf_type.qint32> + %0 = "tf.UniformQuantizedClipByValue"(%input, %min, %max, %scales, %zps) { + quantization_axis = 1 : i64, + quantization_min_val = -2147483648 : i64, + quantization_max_val = 2147483647 : i64 + } : (tensor<3x2x!tf_type.qint32>, tensor<2x!tf_type.qint32>, tensor<2x!tf_type.qint32>, tensor<2xf32>, tensor<2xi32>) -> tensor<3x2x!tf_type.qint32> + func.return +} + +// ----- + +// CHECK-LABEL: func @uniform_quantized_clip_by_value_max_not_const +func.func @uniform_quantized_clip_by_value_max_not_const(%input: tensor<3x2x!tf_type.qint32>, %max: tensor<2x!tf_type.qint32>) -> () { + %scales = "tf.Const"() { value = dense<2.0> : tensor<2xf32> } : () -> tensor<2xf32> + %zps = "tf.Const"() { value = dense<4> : tensor<2xi32> } : () -> tensor<2xi32> + // tensor_proto that points to dense<127> of type !tf_type.qint32. 
+ %min = "tf.Const"() { value = #tf_type : tensor<2x!tf_type.qint32> } : () -> tensor<2x!tf_type.qint32> + %0 = "tf.UniformQuantizedClipByValue"(%input, %min, %max, %scales, %zps) { + quantization_axis = 1 : i64, + quantization_min_val = -2147483648 : i64, + quantization_max_val = 2147483647 : i64 + } : (tensor<3x2x!tf_type.qint32>, tensor<2x!tf_type.qint32>, tensor<2x!tf_type.qint32>, tensor<2xf32>, tensor<2xi32>) -> tensor<3x2x!tf_type.qint32> + func.return +} + +// ----- + +// CHECK-LABEL: func @uniform_quantized_clip_by_value_scales_not_const +func.func @uniform_quantized_clip_by_value_scales_not_const(%input: tensor<3x2x!tf_type.qint32>, %scales: tensor<2xf32>) -> () { + %zps = "tf.Const"() { value = dense<4> : tensor<2xi32> } : () -> tensor<2xi32> + // tensor_proto that points to dense<127> of type !tf_type.qint32. + %min = "tf.Const"() { value = #tf_type : tensor<2x!tf_type.qint32> } : () -> tensor<2x!tf_type.qint32> + %max = "tf.Const"() { value = #tf_type : tensor<2x!tf_type.qint32> } : () -> tensor<2x!tf_type.qint32> + %0 = "tf.UniformQuantizedClipByValue"(%input, %min, %max, %scales, %zps) { + quantization_axis = 1 : i64, + quantization_min_val = -2147483648 : i64, + quantization_max_val = 2147483647 : i64 + } : (tensor<3x2x!tf_type.qint32>, tensor<2x!tf_type.qint32>, tensor<2x!tf_type.qint32>, tensor<2xf32>, tensor<2xi32>) -> tensor<3x2x!tf_type.qint32> + func.return +} + //===----------------------------------------------------------------------===// // tf.Softplus legalization //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/BUILD b/tensorflow/compiler/mlir/tf2xla/transforms/BUILD new file mode 100644 index 00000000000..df4bf8fa204 --- /dev/null +++ b/tensorflow/compiler/mlir/tf2xla/transforms/BUILD @@ -0,0 +1,493 @@ +# Description: +# TF2XLA Bridge transforms + +load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") +load("//tensorflow:tensorflow.bzl", "tf_cc_test") +load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library") +load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_cloud") + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = ["//visibility:public"], + licenses = ["notice"], +) + +gentbl_cc_library( + name = "legalize_tf_patterns_inc_gen", + compatible_with = get_compatible_with_cloud(), + tbl_outs = [ + ( + ["-gen-rewriters"], + "generated_legalize_tf.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "legalize_tf_patterns.td", + deps = [ + "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops_td_files", + "//tensorflow/compiler/xla/mlir_hlo:hlo_ops_td_files", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncTdFiles", + "@llvm-project//mlir:TensorOpsTdFiles", + ], +) + +gentbl_cc_library( + name = "xla_legalize_tf_passes_inc_gen", + compatible_with = get_compatible_with_cloud(), + tbl_outs = [ + ( + [ + "-gen-pass-decls", + "-name=LegalizeTf", + ], + "xla_legalize_tf_passes.h.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "xla_legalize_tf_passes.td", + deps = [ + "@llvm-project//mlir:PassBaseTdFiles", + ], +) + +gentbl_cc_library( + name = "tf_xla_passes_inc_gen", + compatible_with = get_compatible_with_cloud(), + tbl_outs = [ + ( + [ + "-gen-pass-decls", + "-name=TfXla", + ], + "tf_xla_passes.h.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "tf_xla_passes.td", + deps = [ + 
"//tensorflow/compiler/mlir/tensorflow:tensorflow_ops_td_files", + "//tensorflow/compiler/xla/mlir_hlo:hlo_ops_td_files", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncTdFiles", + "@llvm-project//mlir:PassBaseTdFiles", + "@llvm-project//mlir:SparseTensorDialect", + "@llvm-project//mlir:TensorOpsTdFiles", + ], +) + +cc_library( + name = "tf_xla_passes", + srcs = [ + "xla_legalize_tf_passes.h.inc", + ], + hdrs = [ + "passes.h", + ], + deps = [ + ":tf_xla_passes_inc_gen", + ":xla_legalize_tf", + "//tensorflow/compiler/xla/mlir_hlo", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:MemRefDialect", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:SparseTensorDialect", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", + ], +) + +cc_library( + name = "legalize_utils", + srcs = ["utils.cc"], + hdrs = ["utils.h"], + deps = [ + "//tensorflow/compiler/xla/mlir_hlo", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + ], +) + +cc_library( + name = "test_utils", + testonly = True, + srcs = ["test_utils.cc"], + hdrs = ["test_utils.h"], + deps = [ + "//tensorflow/compiler/mlir:register_common_dialects", + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow:serialize_mlir_module_utils", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_passes", + "//tensorflow/core/lib/monitoring:cell_reader", + "//tensorflow/core/platform:errors", + "//tensorflow/tsl/lib/core:status_test_util", + "//tensorflow/tsl/platform:statusor", + "@com_google_absl//absl/strings", + "@com_google_googletest//:gtest_main", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:Pass", + ], +) + +cc_library( + name = "legalize_tf", + srcs = [ + "generated_legalize_tf.inc", + "legalize_tf.cc", + ], + hdrs = [ + "passes.h", + ], + deps = [ + ":legalize_tf_patterns_inc_gen", + ":legalize_utils", + ":tf_xla_passes_inc_gen", + ":xla_legalize_tf_passes_inc_gen", + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow:dynamic_shape_utils", + "//tensorflow/compiler/xla:xla_data_proto_cc", + "//tensorflow/compiler/xla/client:padding", + "//tensorflow/compiler/xla/client:sharding_builder", + "//tensorflow/compiler/xla/client/lib:conv_grad_size_util", + "//tensorflow/compiler/xla/mlir_hlo", + "//tensorflow/compiler/xla/mlir_hlo:convert_op_folder", + "//tensorflow/compiler/xla/translate/hlo_to_mhlo:attribute_importer", + "//tensorflow/core:framework", + "//tensorflow/core/kernels:conv_grad_shape_utils", + "//tensorflow/tsl/platform:bfloat16", + "//tensorflow/tsl/platform:status", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:Dialect", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:MemRefDialect", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:ShapeDialect", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TensorDialect", + "@llvm-project//mlir:TransformUtils", + "@stablehlo//:chlo_ops", + ], +) + +cc_library( + name = "xla_legalize_targets", + srcs = [ + "xla_legalize_targets.cc", + ], + hdrs = [ + "xla_legalize_targets.h", + ], + deps = [ + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/xla/mlir_hlo", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:ShapeDialect", + "@llvm-project//mlir:TensorDialect", + 
"@llvm-project//mlir:TransformUtils", + "@stablehlo//:chlo_ops", + ], +) + +tf_cc_test( + name = "xla_legalize_targets_test", + srcs = ["xla_legalize_targets_test.cc"], + deps = [ + ":xla_legalize_targets", + "//tensorflow/compiler/mlir/tensorflow", + "@com_google_googletest//:gtest_main", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:ShapeDialect", + "@llvm-project//mlir:TensorDialect", + "@llvm-project//mlir:TransformUtils", + "@stablehlo//:chlo_ops", + ], +) + +tf_cc_test( + name = "verify_tfxla_legalization_test", + srcs = ["verify_tfxla_legalization_test.cc"], + deps = [ + ":legalize_tf", + ":test_utils", + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow:serialize_mlir_module_utils", + "//tensorflow/core/lib/monitoring:cell_reader", + "//tensorflow/tsl/platform:statusor", + "@com_google_absl//absl/strings", + "@com_google_googletest//:gtest_main", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:Pass", + ], +) + +cc_library( + name = "xla_legalize_tf", + srcs = [ + "convert_mhlo_quant_to_int.cc", + "infeed_ops_xla_adjust_layout.cc", + "legalize_tf_collective.cc", + "legalize_tf_communication.cc", + "legalize_tf_types.cc", + "tf_xla_passes.h.inc", + "tfxla_device_specific_transforms.cc", + "verify_tfxla_legalization.cc", + "xla_legalize_tf.cc", + "xla_legalize_tf_passes.h.inc", + ], + hdrs = [ + "passes.h", + ], + deps = [ + ":legalize_tf", + ":legalize_utils", + ":xla_legalize_targets", + ":xla_legalize_tf_no_fallback", + ":xla_legalize_tf_passes_inc_gen", + ":xla_legalize_tf_with_tf2xla", + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow:lower_tf_lib", + "//tensorflow/compiler/mlir/tensorflow:mangling_util", + "//tensorflow/compiler/mlir/tensorflow:set_tpu_infeed_layout", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", + "//tensorflow/compiler/tf2xla/kernels:rng_converter_utils", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:side_effect_util", + "//tensorflow/compiler/xla:xla_data_proto_cc", + "//tensorflow/compiler/xla/client:padding", + "//tensorflow/compiler/xla/client:sharding_builder", + "//tensorflow/compiler/xla/mlir_hlo", + "//tensorflow/compiler/xla/mlir_hlo:chlo_legalize_to_hlo", + "//tensorflow/compiler/xla/mlir_hlo:convert_op_folder", + "//tensorflow/compiler/xla/stream_executor/tpu:c_api_conversions", + "//tensorflow/compiler/xla/stream_executor/tpu:tpu_api", + "//tensorflow/compiler/xla/translate/hlo_to_mhlo:attribute_importer", + "//tensorflow/compiler/xla/translate/mhlo_to_hlo:type_to_shape", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/util/quantization:uniform_quant_ops_params", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:MemRefDialect", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:QuantOps", + "@llvm-project//mlir:ShapeDialect", + "@llvm-project//mlir:SparseTensorDialect", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TensorDialect", + "@llvm-project//mlir:Transforms", + "@stablehlo//:chlo_ops", + ], +) + +cc_library( + name = "xla_legalize_tf_no_fallback", + srcs = [ + "xla_legalize_tf_no_fallback.cc", + "xla_legalize_tf_passes.h.inc", + ], + hdrs = [ + 
"passes.h", + ], + deps = [ + ":legalize_tf", + ":tf_xla_passes_inc_gen", + ":xla_legalize_tf_passes_inc_gen", + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow:lower_tf_lib", + "//tensorflow/compiler/xla/mlir_hlo", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:MemRefDialect", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:ShapeDialect", + "@llvm-project//mlir:SparseTensorDialect", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TensorDialect", + "@llvm-project//mlir:Transforms", + "@stablehlo//:chlo_ops", + ], +) + +cc_library( + name = "tf2xla_rewriter", + srcs = [ + "tf2xla_rewriter.cc", + ], + hdrs = [ + "tf2xla_rewriter.h", + ], + visibility = ["//visibility:private"], + deps = [ + ":legalize_tf", + "//tensorflow/compiler/mlir:op_or_arg_name_mapper", + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow:convert_tensor", + "//tensorflow/compiler/mlir/tensorflow:convert_type", + "//tensorflow/compiler/mlir/tensorflow:export_tf_dialect_op", + "//tensorflow/compiler/mlir/tensorflow:tpu_embedding_ops_registry", + "//tensorflow/compiler/mlir/tensorflow:translate_utils", + "//tensorflow/compiler/tf2xla:xla_compilation_device", + "//tensorflow/compiler/tf2xla:xla_context", + "//tensorflow/compiler/tf2xla:xla_expression", + "//tensorflow/compiler/tf2xla:xla_helpers", + "//tensorflow/compiler/tf2xla:xla_op_registry", + "//tensorflow/compiler/xla:xla_data_proto_cc", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", + "//tensorflow/compiler/xla/hlo/ir:hlo", + "//tensorflow/compiler/xla/mlir_hlo", + "//tensorflow/compiler/xla/service:hlo_proto_cc", + "//tensorflow/compiler/xla/translate/hlo_to_mhlo:hlo_module_importer", + "//tensorflow/compiler/xla/translate/hlo_to_mhlo:hlo_to_mlir_hlo", + "//tensorflow/compiler/xla/translate/hlo_to_mhlo:mlir_hlo_builder", + "//tensorflow/compiler/xla/translate/mhlo_to_hlo:type_to_shape", + "//tensorflow/core:core_cpu", + "//tensorflow/core:core_cpu_base", + "//tensorflow/core:framework", + "//tensorflow/core:framework_types_hdr", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/framework:allocator", + "//tensorflow/core/protobuf:for_core_protos_cc", + "//tensorflow/tsl/platform:env", + "//tensorflow/tsl/platform:errors", + "//tensorflow/tsl/platform:status", + "//tensorflow/tsl/platform:statusor", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:SparseTensorDialect", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TensorDialect", + "@llvm-project//mlir:TransformUtils", + ], +) + +tf_cc_test( + name = "tf2xla_rewriter_test", + srcs = [ + "tf2xla_rewriter_test.cc", + ], + deps = [ + ":test_utils", + ":tf2xla_rewriter", + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/tf2xla:xla_op_registry", + "//tensorflow/compiler/tf2xla/kernels:xla_ops", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:xla_data_proto_cc", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", + "//tensorflow/compiler/xla/mlir_hlo", + "//tensorflow/core:framework", + "//tensorflow/core:ops", + 
"//tensorflow/tsl/lib/core:status_test_util", + "//tensorflow/tsl/platform:errors", + "//tensorflow/tsl/platform:status", + "//tensorflow/tsl/platform:statusor", + "@com_google_absl//absl/memory", + "@com_google_googletest//:gtest_main", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", + ], +) + +cc_library( + name = "xla_legalize_tf_with_tf2xla", + srcs = [ + "legalize_tf_with_tf2xla.cc", + ], + hdrs = [ + "passes.h", + ], + deps = [ + ":tf2xla_rewriter", + ":tf_xla_passes_inc_gen", + ":xla_legalize_tf_passes_inc_gen", + "//tensorflow/compiler/mlir:op_or_arg_name_mapper", + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow:convert_tensor", + "//tensorflow/compiler/mlir/tensorflow:convert_type", + "//tensorflow/compiler/mlir/tensorflow:export_tf_dialect_op", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops", + "//tensorflow/compiler/mlir/tensorflow:tpu_embedding_ops_registry", + "//tensorflow/compiler/mlir/tensorflow:translate_utils", + "//tensorflow/compiler/tf2xla:xla_compilation_device", + "//tensorflow/compiler/tf2xla:xla_context", + "//tensorflow/compiler/tf2xla:xla_expression", + "//tensorflow/compiler/tf2xla:xla_helpers", + "//tensorflow/compiler/tf2xla:xla_op_registry", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/mlir_hlo", + "//tensorflow/compiler/xla/stream_executor:timer", + "//tensorflow/compiler/xla/translate/hlo_to_mhlo:mlir_hlo_builder", + "//tensorflow/core:core_cpu_lib", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:session_options", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:SparseTensorDialect", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TensorDialect", + "@llvm-project//mlir:TransformUtils", + ], +) + +tf_cc_test( + name = "xla_legalize_tf_test", + srcs = [ + "xla_legalize_tf_test.cc", + ], + deps = [ + ":tf_xla_passes", + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow:serialize_mlir_module_utils", + "//tensorflow/compiler/tf2xla:xla_compilation_device", + "//tensorflow/compiler/tf2xla:xla_context", + "//tensorflow/compiler/tf2xla:xla_expression", + "//tensorflow/compiler/tf2xla:xla_helpers", + "//tensorflow/compiler/tf2xla:xla_op_registry", + "//tensorflow/core:core_cpu_base", + "//tensorflow/core/framework:allocator", + "//tensorflow/core/lib/monitoring:cell_reader", + "//tensorflow/tsl/platform:statusor", + "@com_google_absl//absl/strings", + "@com_google_googletest//:gtest_main", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:TransformUtils", + ], +) diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf.cc index 39742383ebf..06d6df007f2 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf.cc +++ b/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf.cc @@ -137,7 +137,7 @@ static IntegerAttr GetHLOAxisFromTFAxis(Attribute attr, int64_t rank, Builder *b) { IntegerAttr intAttr = attr.dyn_cast_or_null(); 
if (auto elementAttr = attr.dyn_cast_or_null()) { - SmallVector index(elementAttr.getType().getRank(), 0); + SmallVector index(elementAttr.getShapedType().getRank(), 0); intAttr = elementAttr.getValues()[index]; } @@ -259,7 +259,7 @@ static RankedTensorType GetStaticBroadcastType( shape_large.end()); // Update according to the broadcast dimensions. - for (auto &index_pair : llvm::enumerate(broadcast_dimensions)) { + for (const auto &index_pair : llvm::enumerate(broadcast_dimensions)) { auto old_value = out_shape[index_pair.value()]; auto new_value = shape_small[index_pair.index()]; out_shape[index_pair.value()] = std::max(old_value, new_value); @@ -554,7 +554,7 @@ static DenseIntElementsAttr SliceDenseIntElementsAttrColumn2D( llvm::SmallVector values; values.reserve(shaped_type.getNumElements() / shape[1]); - for (auto &it : llvm::enumerate(int_attr.getValues())) { + for (const auto &it : llvm::enumerate(int_attr.getValues())) { if (static_cast(it.index() % shape[1]) == column) { values.push_back(it.value().getSExtValue()); } @@ -568,7 +568,7 @@ static DenseIntElementsAttr SliceDenseIntElementsAttrColumn2D( // Returns interior padding to use in HLO Pad op based on the TensorFlow padding // in TensorFlow PadV2 op. static DenseIntElementsAttr GetInteriorPadding(ElementsAttr tf_padding) { - auto length = tf_padding.getType().getShape()[0]; + auto length = tf_padding.getShapedType().getShape()[0]; auto element_type = IntegerType::get(tf_padding.getContext(), 64); return DenseIntElementsAttr::get( tensorflow::GetTypeFromTFTensorShape({length}, element_type), 0); @@ -3403,7 +3403,7 @@ class ConvertSplitVOp : public OpRewritePattern { std::optional dynamic_dim_index; split_sizes.reserve( split_sizes_attr.getType().cast().getNumElements()); - for (auto &dim : llvm::enumerate(split_sizes_attr)) { + for (const auto &dim : llvm::enumerate(split_sizes_attr)) { int64_t dim_val = dim.value().getSExtValue(); split_sizes.push_back(dim_val); if (dim_val == -1) { @@ -4072,7 +4072,8 @@ class GenericConvertReductionOp : public OpRewritePattern { // that this is a restricted form of shape manipulation that is just adding // unit dims. 
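A minimal sketch of the two idioms these legalize_tf.cc hunks adopt, namely querying an ElementsAttr's rank/shape through getShapedType() instead of getType() and binding llvm::enumerate results by const reference; the attribute below is made up purely for illustration and is not part of the patch:

#include <cstdint>
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"

void ShapedTypeIdioms() {
  mlir::MLIRContext context;
  mlir::Builder builder(&context);
  // Throwaway 2x2 i64 attribute used only to demonstrate the accessors.
  auto dense = mlir::DenseIntElementsAttr::get(
      mlir::RankedTensorType::get({2, 2}, builder.getI64Type()),
      llvm::ArrayRef<int64_t>({1, 2, 3, 4}));
  // Rank/shape queries now go through the ElementsAttr interface's
  // getShapedType() rather than the removed typed getType().
  int64_t rank = dense.cast<mlir::ElementsAttr>().getShapedType().getRank();
  (void)rank;  // rank == 2
  // llvm::enumerate yields proxy objects, so bind them by const reference.
  for (const auto &it : llvm::enumerate(dense.getValues<int64_t>())) {
    (void)it.index();
    (void)it.value();
  }
}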
if (op.getKeepDims()) { - for (auto &dim_is_reduced : llvm::enumerate(reduced_dimensions_bitmap)) { + for (const auto &dim_is_reduced : + llvm::enumerate(reduced_dimensions_bitmap)) { if (dim_is_reduced.value()) { auto index_attr = GetI32ElementsAttr( {static_cast(dim_is_reduced.index())}, &rewriter); @@ -5318,7 +5319,7 @@ class ConvertInfeedDequeueTupleOp } llvm::SmallVector results; results.reserve(result_types.size()); - for (auto &idx_and_type : llvm::enumerate(result_types)) { + for (const auto &idx_and_type : llvm::enumerate(result_types)) { results.push_back(data_and_token.getResult(idx_and_type.index())); } rewriter.replaceOp(op, ValueRange(results)); @@ -6772,7 +6773,7 @@ class LowerControlFlowOp : public OpConversionPattern { if constexpr (std::is_same::value) { TypeConverter::SignatureConversion signature(num_results); Block &block = region.front(); - for (auto &[block_idx, original_ty] : + for (const auto &[block_idx, original_ty] : llvm::enumerate(block.getArgumentTypes())) { TensorType updated_ty = UpdateElementTypeTo(original_ty, element_types[block_idx]); diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_communication.cc b/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_communication.cc index 2e51ccd7901..94bb9deb14a 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_communication.cc +++ b/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_communication.cc @@ -212,13 +212,15 @@ LogicalResult GetFunctionsToRewrite( return success(); } -// Assigns op sharding to full tensor on `kShardingTpuCore`. -void SetOpSharding(Operation* op) { - std::string sharding_serialized = - ::xla::sharding_builder::AssignDevice(kShardingTpuCore) - .SerializeAsString(); +// Assigns either MAXIMAL or MANUAL sharding. The MAXIMAL sharding sends/recvs +// one message from core `kShardingTpuCore` with the full tensor. MANUAL +// sharding sends/recvs one message for each core with the core's shard. +void SetOpSharding(Operation* op, bool manual_sharding) { + xla::OpSharding sharding = + manual_sharding ? ::xla::sharding_builder::Manual() + : ::xla::sharding_builder::AssignDevice(kShardingTpuCore); op->setAttr(kShardingAttr, - StringAttr::get(op->getContext(), sharding_serialized)); + StringAttr::get(op->getContext(), sharding.SerializeAsString())); } // Assigns frontend attributes holding information about data type and @@ -263,7 +265,7 @@ void SetFrontendAttributes(Operation* op, int32_t index, StringRef key, // Creates a `mhlo.send` op for sending value `operand`. Value CreateSendOp(OpBuilder& builder, int64_t& channel_id, Location loc, Value operand, StringRef key, size_t index, Value token, - StringRef host_handler_name) { + StringRef host_handler_name, bool manual_sharding) { // type 2 == DEVICE_TO_HOST auto channel_handle = ChannelHandleAttr::get(builder.getContext(), /*handle=*/channel_id++, @@ -275,7 +277,7 @@ Value CreateSendOp(OpBuilder& builder, int64_t& channel_id, Location loc, SetFrontendAttributes(send, index, key, operand.getType(), /*device_to_host=*/true, host_handler_name); - SetOpSharding(send); + SetOpSharding(send, manual_sharding); return send.getResult(); } @@ -283,7 +285,7 @@ Value CreateSendOp(OpBuilder& builder, int64_t& channel_id, Location loc, // Creates a `mhlo.recv` op for receiving a value. 
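SetOpSharding now selects between the two sharding protos and serializes the choice onto the op; the manual_sharding flag is threaded in from XlaHostComputeOp's getManualSharding() further down. A hedged sketch of that selection in isolation, assuming the sharding_builder header path and using an illustrative attribute key in place of the pass's kShardingAttr:

#include <string>
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Operation.h"
#include "tensorflow/compiler/xla/client/sharding_builder.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

// MANUAL sharding sends/recvs one message per core with that core's shard;
// MAXIMAL pins the full tensor to a single core.
void AttachSharding(mlir::Operation* op, bool manual_sharding,
                    int assigned_core) {
  xla::OpSharding sharding =
      manual_sharding ? xla::sharding_builder::Manual()
                      : xla::sharding_builder::AssignDevice(assigned_core);
  // "mhlo.sharding" is an assumed stand-in for the pass's kShardingAttr key.
  op->setAttr("mhlo.sharding",
              mlir::StringAttr::get(op->getContext(),
                                    sharding.SerializeAsString()));
}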
Value CreateRecvOp(OpBuilder& builder, int64_t& channel_id, Location loc, Value result, StringRef key, size_t index, Value token, - StringRef host_handler_name) { + StringRef host_handler_name, bool manual_sharding) { // type 3 == HOST_TO_DEVICE auto channel_handle = ChannelHandleAttr::get(builder.getContext(), /*handle=*/channel_id++, @@ -297,7 +299,7 @@ Value CreateRecvOp(OpBuilder& builder, int64_t& channel_id, Location loc, SetFrontendAttributes(recv, index, key, result_type, /*device_to_host=*/false, host_handler_name); - SetOpSharding(recv); + SetOpSharding(recv, manual_sharding); result.replaceAllUsesWith(recv.getResult(0)); @@ -328,12 +330,14 @@ Value RewriteHostComputeOp(OpBuilder& builder, int64_t& channel_id, Value token) { builder.setInsertionPoint(host_compute); Location loc = host_compute.getLoc(); + bool manual_sharding = host_compute.getManualSharding(); SmallVector send_tokens; for (auto operand : llvm::enumerate(host_compute.getInputs())) { auto send_token = CreateSendOp( builder, channel_id, loc, operand.value(), host_compute.getSendKey(), - operand.index(), token, xla::kXlaHostTransferTfRendezvousHandlerName); + operand.index(), token, xla::kXlaHostTransferTfRendezvousHandlerName, + manual_sharding); send_tokens.push_back(send_token); } token = CreateSinkToken(builder, loc, send_tokens, token); @@ -342,7 +346,8 @@ Value RewriteHostComputeOp(OpBuilder& builder, int64_t& channel_id, for (auto result : llvm::enumerate(host_compute.getOutputs())) { auto recv_token = CreateRecvOp( builder, channel_id, loc, result.value(), host_compute.getRecvKey(), - result.index(), token, xla::kXlaHostTransferTfRendezvousHandlerName); + result.index(), token, xla::kXlaHostTransferTfRendezvousHandlerName, + manual_sharding); recv_tokens.push_back(recv_token); } token = CreateSinkToken(builder, loc, recv_tokens, token); @@ -358,7 +363,8 @@ Value RewriteSendToHostOp(OpBuilder& builder, int64_t& channel_id, token = CreateSendOp(builder, channel_id, send_to_host.getLoc(), send_to_host.getInput(), send_to_host.getKey(), /*index=*/0, token, - xla::kXlaHostTransferTfRendezvousHandlerName); + xla::kXlaHostTransferTfRendezvousHandlerName, + /*manual_sharding=*/false); send_to_host.erase(); return token; @@ -371,7 +377,8 @@ Value RewriteRecvFromHostOp(OpBuilder& builder, int64_t& channel_id, token = CreateRecvOp(builder, channel_id, recv_from_host.getLoc(), recv_from_host.getOutput(), recv_from_host.getKey(), /*index=*/0, token, - xla::kXlaHostTransferTfRendezvousHandlerName); + xla::kXlaHostTransferTfRendezvousHandlerName, + /*manual_sharding=*/false); recv_from_host.erase(); return token; diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_patterns.td b/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_patterns.td index c78e0e8a709..f28ea6958d3 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_patterns.td +++ b/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_patterns.td @@ -269,7 +269,7 @@ def : EqualityPat>; //===----------------------------------------------------------------------===// def OneElementAttrPred - : CPred<"$_self.cast().getType().getNumElements() == 1">; + : CPred<"$_self.cast().getShapedType().getNumElements() == 1">; def OneElementAttr : ElementsAttrBase, diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_with_tf2xla.cc b/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_with_tf2xla.cc index 8f74a84288d..ddd3b091e23 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_with_tf2xla.cc +++ 
b/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_with_tf2xla.cc @@ -25,7 +25,6 @@ limitations under the License. #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" // from @llvm-project #include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project @@ -34,21 +33,14 @@ limitations under the License. #include "mlir/IR/IRMapping.h" // from @llvm-project #include "mlir/IR/Location.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project -#include "mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/IR/Types.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project -#include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project -#include "mlir/Transforms/DialectConversion.h" // from @llvm-project -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project #include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tpu_embedding_ops_registry.h" -#include "tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h" -#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h" -#include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h" -#include "tensorflow/compiler/mlir/tensorflow/utils/translate_utils.h" #include "tensorflow/compiler/mlir/tf2xla/transforms/passes.h" +#include "tensorflow/compiler/mlir/tf2xla/transforms/tf2xla_rewriter.h" #include "tensorflow/compiler/tf2xla/xla_compilation_device.h" #include "tensorflow/compiler/tf2xla/xla_context.h" #include "tensorflow/compiler/tf2xla/xla_expression.h" @@ -56,7 +48,6 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/mlir_hlo/mhlo/IR/hlo_ops.h" -#include "tensorflow/compiler/xla/stream_executor/stream_executor.h" #include "tensorflow/compiler/xla/translate/hlo_to_mhlo/mlir_hlo_builder.h" #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/device_factory.h" @@ -72,8 +63,6 @@ limitations under the License. 
#include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/framework/types.pb.h" -#include "tensorflow/core/protobuf/config.pb.h" -#include "tensorflow/core/public/session_options.h" #include "tensorflow/tsl/platform/env.h" #include "tensorflow/tsl/platform/status.h" #include "tensorflow/tsl/platform/statusor.h" @@ -102,6 +91,7 @@ bool IsOpAllowedTf2XlaFallback(Operation* op) { TypeID::get(), TypeID::get(), TypeID::get(), + TypeID::get(), TypeID::get(), TypeID::get(), TypeID::get(), @@ -288,6 +278,8 @@ bool IsOpAllowedTf2XlaFallback(Operation* op) { TypeID::get(), TypeID::get(), TypeID::get(), + TypeID::get(), + TypeID::get(), TypeID::get(), }; @@ -347,9 +339,8 @@ bool IsOpAllowedTf2XlaPreferred(Operation* op) { TypeID::get(), TypeID::get(), TypeID::get(), - TypeID::get(), TypeID::get(), - TypeID::get(), + TypeID::get(), TypeID::get(), TypeID::get(), TypeID::get(), @@ -368,7 +359,6 @@ bool IsOpAllowedTf2XlaPreferred(Operation* op) { TypeID::get(), TypeID::get(), TypeID::get(), - TypeID::get(), TypeID::get(), TypeID::get(), TypeID::get(), @@ -431,141 +421,13 @@ bool IsOpAllowedTf2XlaPreferred(Operation* op) { } // LINT.ThenChange() -// List of ops that require falling back to XlaOpKernel legalizations and also -// require the ability to create functions. -bool IsOpAllowedTf2XlaFallbackAndCreateFunctions(Operation* op) { - static auto* ops = new llvm::SmallDenseSet{ - TypeID::get(), - }; - auto abstractOp = op->getRegisteredInfo(); - if (!abstractOp) return false; - return ops->count(abstractOp->getTypeID()); -} - bool HasTf2XlaFallback(Operation* op) { return IsOpAllowedTf2XlaFallback(op) || - IsOpAllowedTf2XlaFallbackAndCreateFunctions(op) || IsOpAllowedTf2XlaPreferred(op); } namespace { -template -using InlinedVector = tensorflow::gtl::InlinedVector; // non-absl ok - -static std::unique_ptr CreateDeviceMgr( - const std::string& device_type) { - // Register compilation kernels for all registered XLA backends. - tensorflow::XlaOpRegistry::RegisterCompilationKernels(); - - auto device = std::make_unique( - tensorflow::SessionOptions(), tensorflow::DeviceType(device_type)); - return std::make_unique(std::move(device)); -} - -class Tf2XlaRewriter { - public: - static LogicalResult RewriteOp(Operation* op, PatternRewriter& rewriter, - const std::string& device_type, - bool is_module_pass) { - Tf2XlaRewriter tf2xla_rewriter(op, rewriter, device_type, is_module_pass); - return tf2xla_rewriter.LegalizeOp(); - } - - private: - Tf2XlaRewriter(Operation* op, PatternRewriter& rewriter, - const std::string& device_type, bool is_module_pass) - : op_(op), - device_type_(device_type), - rewriter_(rewriter), - hlo_builder_(op->getName().getStringRef().str(), rewriter_, - op->getLoc(), /*build_functions=*/is_module_pass), - context_(nullptr) {} - - ~Tf2XlaRewriter() { - if (context_) context_->Unref(); - } - - // Prepares OpKernelContext params common to all the ops. - // Emits an error on failure. - LogicalResult PrepareParams(); - - // Tries to legalize the specified TensorFlow op, if supported. - // - // Emits an error and returns failure if an error is encountered during - // conversion. Note that success return value doesn't mean successful - // legalization. - LogicalResult LegalizeOp(); - - // Converts the given operand to expression of kind kConstant or kXlaOp. - // Emits a remark and returns expression of kind kInvalid on failure. 
- tensorflow::XlaExpression GetExprForOperand(Value operand, Operation* op); - - Operation* op_; - std::string device_type_; - - PatternRewriter& rewriter_; - ::xla::MlirHloBuilder hlo_builder_; - tensorflow::OpOrArgLocNameMapper name_mapper_; - - tensorflow::XlaContext* context_; // Ref-counted. - - std::unique_ptr device_mgr_; - tensorflow::Device* device_; // Owned by device_mgr_; - std::unique_ptr step_container_; - std::unique_ptr flib_def_; - std::unique_ptr pflr_; - tensorflow::OpKernelContext::Params params_; -}; - -LogicalResult Tf2XlaRewriter::PrepareParams() { - // XlaCompiler within the context is only used by the functional ops to - // compile functions. We are not handling those at the moment so XlaCompiler - // is not required. - context_ = new tensorflow::XlaContext(/*compiler=*/nullptr, &hlo_builder_, - /*graph=*/nullptr); - context_->Ref(); - - device_mgr_ = CreateDeviceMgr(device_type_); - if (!device_mgr_) return failure(); - - // Type of params_.device is DeviceBase* so store it as Device* to access - // derived class method. - device_ = device_mgr_->ListDevices().front(); - params_.device = device_; - params_.resource_manager = device_->resource_manager(); - - // Resources are cleared at the time of device manager destruction so pass - // no-op cleanup function. - auto cleanup = [](const std::string& name) {}; - // Use step_id zero as we only have a single context concurrently and - // concurrently running each of the MLIR functions create a new device. - step_container_ = std::make_unique( - /*step_id=*/0, cleanup); - tsl::Status status = step_container_->Create( - device_->resource_manager(), - tensorflow::XlaContext::kXlaContextResourceName, context_); - if (!status.ok()) { - return emitRemark(op_->getLoc()) - << "failed to create XlaContext resource: " << status.ToString(); - } - params_.step_container = step_container_.get(); - - tsl::StatusOr version_or = tensorflow::GetTfGraphProducerVersion( - op_->getParentOfType()); - if (!version_or.ok()) { - return emitError(op_->getLoc()) << version_or.status().ToString(); - } - - flib_def_ = std::make_unique( - tensorflow::OpRegistry::Global(), tensorflow::FunctionDefLibrary()); - pflr_ = std::make_unique( - device_mgr_.get(), tensorflow::Env::Default(), /*config=*/nullptr, - version_or.value(), flib_def_.get(), tensorflow::OptimizerOptions()); - params_.function_library = pflr_->GetFLR(device_->name()); - return success(); -} - // Returns true if the given type is a ranked tensor type with static or bounded // dimensions. bool IsBounded(Type ty) { @@ -601,183 +463,16 @@ bool HasSymbolRefAttr(Operation* op) { return false; } -LogicalResult Tf2XlaRewriter::LegalizeOp() { - for (Type ty : op_->getOperandTypes()) { - auto ranked_ty = ty.dyn_cast(); - // Only bounded operands are supported in the XLA builders. 
- if (!IsBounded(ranked_ty)) { - return op_->emitRemark() - << "lowering requires bounded tensor operands " << ranked_ty; - } - } - - if (HasSymbolRefAttr(op_)) { - return op_->emitRemark() << "ops with symbol references are not supported"; - } - - auto nodedef_or = tensorflow::ConvertTFDialectOpToNodeDef( - op_, name_mapper_.GetUniqueName(op_), /*ignore_unregistered_attrs=*/true); - if (!nodedef_or.ok()) { - return op_->emitRemark() << "failed to convert op to NodeDef: " - << nodedef_or.status().ToString(); - } - - if (failed(PrepareParams())) return failure(); - - std::shared_ptr props; - tsl::Status status = tensorflow::NodeProperties::CreateFromNodeDef( - *nodedef_or.value(), - params_.function_library->GetFunctionLibraryDefinition(), &props); - if (!status.ok()) { - return op_->emitRemark() - << "failed to create NodeProperties: " << status.ToString(); - } - tensorflow::OpKernel* op_kernel_raw; - status = params_.function_library->CreateKernel(props, &op_kernel_raw); - if (!status.ok()) { - return op_->emitRemark() - << "failed to create tf2xla kernel: " << status.ToString(); - } - // Transfer ownership of the kernel to a local smart pointer. - auto op_kernel = absl::WrapUnique(op_kernel_raw); - - std::vector required_constants; - status = tensorflow::XlaOpRegistry::CompileTimeConstantInputs( - *op_kernel, &required_constants); - if (!status.ok()) { - return op_->emitRemark() - << "failed to compute required constants: " << status.ToString(); - } - llvm::SmallDenseSet required_consts; - required_consts.insert(required_constants.begin(), required_constants.end()); - - // TensorValue in inputs are backed by tensors which in turn depend on - // expressions. So, pre-allocate them to the required size. - InlinedVector expressions; - InlinedVector tensors; - InlinedVector inputs; - expressions.reserve(op_->getNumOperands()); - tensors.reserve(op_->getNumOperands()); - inputs.reserve(op_->getNumOperands()); - - // Prepare the list of Tensor inputs for the kernel. - for (auto it : llvm::enumerate(op_->getOperands())) { - Value operand = it.value(); - size_t idx = it.index(); - - tensorflow::XlaExpression expr = GetExprForOperand(operand, op_); - tensorflow::XlaExpression::Kind kind = expr.kind(); - if (kind == tensorflow::XlaExpression::Kind::kInvalid) return failure(); - if (required_consts.count(idx) && - kind != tensorflow::XlaExpression::Kind::kConstant) { - return op_->emitRemark() - << "lowering requires operand #" << idx << " to be a constant"; - } - expressions.push_back(expr); - - if (!tensorflow::DataTypeCanUseMemcpy(expr.dtype())) { - return op_->emitRemark() - << "skipping legalization due to unsupported type " - << operand.getType(); - } - - auto shape_or = expr.GetShape(); - if (!shape_or.ok()) { - return op_->emitRemark() - << "failed to get shape for expression. " << expr.HumanString(); - } - - tensors.emplace_back( - device_->GetAllocator(tensorflow::AllocatorAttributes()), expr.dtype(), - shape_or.value()); - tensorflow::Tensor& tensor = tensors.back(); - tensorflow::XlaExpression::AssignExpressionToTensor(expr, &tensor); - inputs.emplace_back(&tensor); - } - - params_.inputs = inputs; - params_.op_kernel = op_kernel.get(); - llvm::SmallVector output_attr( - op_->getNumResults()); - params_.output_attr_array = output_attr.data(); - - hlo_builder_.setInsertionPoint(op_); - hlo_builder_.SetLocation(op_->getLoc()); - - // Execute the kernel. 
- tensorflow::OpKernelContext op_context(¶ms_, op_->getNumResults()); - device_->Compute(params_.op_kernel, &op_context); - - status = op_context.status(); - status.Update(hlo_builder_.GetCurrentStatus()); - if (!status.ok()) { - return op_->emitRemark() - << "compilation to HLO failed: " << status.ToString(); - } - - // Replace uses of old results using the corresponding value after the - // lowering. - llvm::SmallVector values; - values.reserve(op_->getNumResults()); - for (int i = 0, e = op_->getNumResults(); i < e; i++) { - tensorflow::Tensor* output = op_context.mutable_output(i); - const tensorflow::XlaExpression* expr = - tensorflow::XlaExpression::CastExpressionFromTensor(*output); - if (expr->kind() != tensorflow::XlaExpression::Kind::kXlaOp && - expr->kind() != tensorflow::XlaExpression::Kind::kConstant) { - return op_->emitRemark( - "expects XlaExpression of kind kXlaOp or kConstant in compiled " - "output"); - } - mlir::Value value = hlo_builder_.GetValue(expr->AsXlaOp(&hlo_builder_)); - values.push_back(value); - } - rewriter_.replaceOp(op_, values); - return success(); -} - -tensorflow::XlaExpression Tf2XlaRewriter::GetExprForOperand(Value operand, - Operation* op) { - ElementsAttr const_attr; - auto defining_op = operand.getDefiningOp(); - if (defining_op && matchPattern(defining_op, m_Constant(&const_attr))) { - tensorflow::Tensor tensor; - auto status = tensorflow::ConvertToTensor(const_attr, &tensor); - if (!status.ok()) { - op->emitRemark() << "skipping legalization due to failed const conversion" - << status.ToString(); - return tensorflow::XlaExpression::Invalid(); - } - return tensorflow::XlaExpression::Constant(tensor); - } - - // Skip this op if XLA doesn't support this operand type. - auto xla_op_or = hlo_builder_.MakeXlaOp(operand); - if (!xla_op_or.ok()) { - op->emitRemark() << "skipping legalization due to " - << xla_op_or.status().ToString(); - return tensorflow::XlaExpression::Invalid(); - } - ::xla::XlaOp xla_op = xla_op_or.value(); - - tensorflow::DataType dtype; - auto status = tensorflow::ConvertToDataType(operand.getType(), &dtype); - if (!status.ok()) { - op->emitRemark() << "skipping legalization due to " << status.ToString(); - return tensorflow::XlaExpression::Invalid(); - } - return tensorflow::XlaExpression::XlaOp(xla_op, dtype); -} - class Tf2XlaRewritePattern : public ConversionPattern { public: explicit Tf2XlaRewritePattern(MLIRContext* ctx, TypeConverter& converter, const std::string& device_type, - bool prefer_tf2xla, bool is_module_pass) + bool prefer_tf2xla, + bool use_tf2xla_hlo_importer) : ConversionPattern(converter, MatchAnyOpTypeTag(), /*benefit=*/1, ctx), device_type_(device_type), prefer_tf2xla_(prefer_tf2xla), - is_module_pass_(is_module_pass) {} + use_tf2xla_hlo_importer_(use_tf2xla_hlo_importer) {} LogicalResult matchAndRewrite( Operation* op, ArrayRef operands, @@ -790,25 +485,19 @@ class Tf2XlaRewritePattern : public ConversionPattern { if (old_val.getType() != new_val.getType()) return failure(); } - if (is_module_pass_) { - // Module passes should only ever legalize ops that have been specifically - // whitelisted for legalization within a module pass. They will never - // legalize any ops whitelisted for legalization within a func pass. 
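With the module-pass-only allowlist removed, the fallback pattern is configured solely by use_tf2xla_hlo_importer, and PopulateLegalizeTfWithTf2XlaPatterns (see the passes.h hunk below) gains the same flag. A hypothetical caller wiring it up; only the signature and flag names are taken from the patch, the surrounding helper is illustrative:

#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/PatternMatch.h"
#include "tensorflow/compiler/mlir/tf2xla/transforms/passes.h"

// Populate the TF->MHLO fallback patterns with the HLO-importer path enabled.
// The device type mirrors the one used by the new unit tests.
void AddTf2XlaFallbackPatterns(mlir::MLIRContext* ctx,
                               mlir::RewritePatternSet& patterns,
                               mlir::mhlo::Tf2XlaTypeConverter& converter) {
  mlir::mhlo::PopulateLegalizeTfWithTf2XlaPatterns(
      /*device_type=*/"XLA_CPU_JIT", patterns, ctx, converter,
      /*prefer_tf2xla=*/false,
      /*use_tf2xla_hlo_importer=*/true);
}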
- if (!IsOpAllowedTf2XlaFallbackAndCreateFunctions(op)) { - return failure(); - } - } else if (!(IsOpAllowedTf2XlaFallback(op) || - (prefer_tf2xla_ && IsOpAllowedTf2XlaPreferred(op)))) { + if (!(IsOpAllowedTf2XlaFallback(op) || + (prefer_tf2xla_ && IsOpAllowedTf2XlaPreferred(op)))) { return failure(); } + return Tf2XlaRewriter::RewriteOp(op, rewriter, device_type_, - is_module_pass_); + use_tf2xla_hlo_importer_); } private: std::string device_type_; bool prefer_tf2xla_; - bool is_module_pass_; + bool use_tf2xla_hlo_importer_; }; bool ShouldRefineTypeTo(Type original_ty, Type updated_ty) { @@ -897,12 +586,15 @@ Tf2XlaTypeConverter::Tf2XlaTypeConverter() { addSourceMaterialization(cast_value); } -void PopulateLegalizeTfWithTf2XlaPatterns( - llvm::StringRef device_type, RewritePatternSet& patterns, MLIRContext* ctx, - Tf2XlaTypeConverter& converter, bool prefer_tf2xla, bool is_module_pass) { +void PopulateLegalizeTfWithTf2XlaPatterns(llvm::StringRef device_type, + RewritePatternSet& patterns, + MLIRContext* ctx, + Tf2XlaTypeConverter& converter, + bool prefer_tf2xla, + bool use_tf2xla_hlo_importer) { patterns.add(ctx); patterns.add(ctx, converter, device_type.str(), - prefer_tf2xla, is_module_pass); + prefer_tf2xla, use_tf2xla_hlo_importer); } } // end namespace mhlo diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/passes.h b/tensorflow/compiler/mlir/tf2xla/transforms/passes.h index 4438756a419..e805b069f86 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/passes.h +++ b/tensorflow/compiler/mlir/tf2xla/transforms/passes.h @@ -46,16 +46,14 @@ namespace mhlo { /// patterns from TF2XLA fallback for provided device type (see /// legalize_tf_with_tf2xla.cc for details). By default, TF2XLA fallback is not /// used. -std::unique_ptr> createLegalizeTFPass( +/// Note: This is a module pass because when legalizing with TF2XLA fallback, +/// functions are imported into the module. Importing functions into a +/// module is not thread safe. +std::unique_ptr> createLegalizeTFPass( bool allow_partial_conversion = false, bool legalize_chlo = true, std::optional tf2xla_fallback_device_type = std::nullopt, bool prefer_tf2xla = false); -/// Legalize whitelisted Ops using TF2XLA fallback for ops that must also be -/// able to create new functions. -std::unique_ptr> createLegalizeTFModulePass( - StringRef tf2xla_fallback_device_type = ""); - // Legalizes from MHLO quantized ops with MHLO quant types to MHLO primitive ops // like int ops. std::unique_ptr> createConvertMHLOQuantToIntPass(); @@ -84,7 +82,7 @@ void PopulateLegalizeTfWithTf2XlaPatterns(llvm::StringRef device_type, MLIRContext* ctx, Tf2XlaTypeConverter& converter, bool prefer_tf2xla = false, - bool is_module_pass = false); + bool use_tf2xla_hlo_importer = false); /// Adds the TF to TF lowerings and TF to XLA rewrite patterns to the pattern /// list. diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/test_utils.cc b/tensorflow/compiler/mlir/tf2xla/transforms/test_utils.cc new file mode 100644 index 00000000000..a8d36fe1fce --- /dev/null +++ b/tensorflow/compiler/mlir/tf2xla/transforms/test_utils.cc @@ -0,0 +1,55 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/tf2xla/transforms/test_utils.h" + +#include +#include +#include "absl/strings/string_view.h" +#include "llvm/ADT/StringRef.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "tensorflow/compiler/mlir/register_common_dialects.h" +#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.h" +#include "tensorflow/tsl/platform/statusor.h" + +namespace mlir { +namespace mhlo { +namespace test { + +using ::mlir::DialectRegistry; +using ::mlir::MLIRContext; +using ::mlir::ModuleOp; +using ::mlir::OwningOpRef; +using ::tsl::StatusOr; + +StatusOr> GetMlirModuleFromString( + absl::string_view module_string, MLIRContext* context) { + DialectRegistry mlir_registry; + RegisterCommonToolingDialects(mlir_registry); + context->appendDialectRegistry(mlir_registry); + + OwningOpRef mlir_module; + auto status = + tensorflow::DeserializeMlirModule(module_string, context, &mlir_module); + if (!status.ok()) { + return status; + } + return mlir_module; +} + +} // namespace test +} // namespace mhlo +} // namespace mlir diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/test_utils.h b/tensorflow/compiler/mlir/tf2xla/transforms/test_utils.h new file mode 100644 index 00000000000..15ea2bc7412 --- /dev/null +++ b/tensorflow/compiler/mlir/tf2xla/transforms/test_utils.h @@ -0,0 +1,38 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TF2XLA_TRANSFORMS_TEST_UTILS_H_ +#define TENSORFLOW_COMPILER_MLIR_TF2XLA_TRANSFORMS_TEST_UTILS_H_ + +#include "absl/strings/string_view.h" +#include "llvm/ADT/StringRef.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.h" +#include "tensorflow/tsl/platform/statusor.h" + +namespace mlir { +namespace mhlo { +namespace test { + +// Given a raw string, return a ModuleOp that can be used with the given +// MLIRContext. 
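The helper declared just below wraps DeserializeMlirModule with common-dialect registration for use in tests. A minimal hypothetical use, with a made-up module string:

#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/OwningOpRef.h"
#include "tensorflow/compiler/mlir/tf2xla/transforms/test_utils.h"
#include "tensorflow/tsl/platform/status.h"
#include "tensorflow/tsl/platform/statusor.h"

tsl::Status ParseExampleModule() {
  mlir::MLIRContext context;
  // Parse an illustrative empty module; real tests pass TF dialect IR.
  TF_ASSIGN_OR_RETURN(
      mlir::OwningOpRef<mlir::ModuleOp> module,
      mlir::mhlo::test::GetMlirModuleFromString(
          R"(module { func.func @main() { func.return } })", &context));
  return tsl::OkStatus();
}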
+tsl::StatusOr> GetMlirModuleFromString( + absl::string_view module_string, MLIRContext* mlir_context); + +} // namespace test +} // namespace mhlo +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TF2XLA_TRANSFORMS_TEST_UTILS_H_ diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/tf2xla_rewriter.cc b/tensorflow/compiler/mlir/tf2xla/transforms/tf2xla_rewriter.cc new file mode 100644 index 00000000000..4117b5ce026 --- /dev/null +++ b/tensorflow/compiler/mlir/tf2xla/transforms/tf2xla_rewriter.cc @@ -0,0 +1,554 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/compiler/mlir/tf2xla/transforms/tf2xla_rewriter.h" + +#include +#include +#include +#include +#include +#include + +#include "absl/container/inlined_vector.h" +#include "absl/memory/memory.h" +#include "absl/strings/string_view.h" +#include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/raw_ostream.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" // from @llvm-project +#include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/Diagnostics.h" // from @llvm-project +#include "mlir/IR/IRMapping.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/OwningOpRef.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Transforms/DialectConversion.h" // from @llvm-project +#include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tpu_embedding_ops_registry.h" +#include "tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/translate_utils.h" +#include "tensorflow/compiler/mlir/tf2xla/transforms/passes.h" +#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" +#include "tensorflow/compiler/tf2xla/xla_context.h" +#include "tensorflow/compiler/tf2xla/xla_expression.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include 
"tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/hlo/ir/hlo_instruction.h" +#include "tensorflow/compiler/xla/hlo/ir/hlo_opcode.h" +#include "tensorflow/compiler/xla/mlir_hlo/mhlo/IR/hlo_ops.h" +#include "tensorflow/compiler/xla/service/hlo.pb.h" +#include "tensorflow/compiler/xla/translate/hlo_to_mhlo/hlo_function_importer.h" +#include "tensorflow/compiler/xla/translate/hlo_to_mhlo/hlo_to_mlir_hlo.h" +#include "tensorflow/compiler/xla/translate/hlo_to_mhlo/mlir_hlo_builder.h" +#include "tensorflow/compiler/xla/translate/mhlo_to_hlo/type_to_shape.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/common_runtime/process_function_library_runtime.h" +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/protobuf/config.pb.h" +#include "tensorflow/core/public/session_options.h" +#include "tensorflow/tsl/platform/env.h" +#include "tensorflow/tsl/platform/errors.h" +#include "tensorflow/tsl/platform/status.h" +#include "tensorflow/tsl/platform/statusor.h" + +namespace mlir { +namespace mhlo { +namespace { + +using ::mlir::FunctionType; +using ::mlir::ModuleOp; +using ::mlir::OwningOpRef; +using ::mlir::func::FuncOp; +using ::tensorflow::Tensor; +using ::tsl::StatusOr; +using ::xla::XlaComputation; + +static std::unique_ptr CreateDeviceMgr( + const std::string& device_type) { + // Register compilation kernels for all registered XLA backends. 
+ tensorflow::XlaOpRegistry::RegisterCompilationKernels(); + + auto device = std::make_unique( + tensorflow::SessionOptions(), tensorflow::DeviceType(device_type)); + return std::make_unique(std::move(device)); +} + +bool RootInstructionIsTuple(const xla::HloModule& hlo_module) { + xla::HloInstruction* root_instruction = + hlo_module.entry_computation()->root_instruction(); + + return root_instruction->opcode() == xla::HloOpcode::kTuple; +} + +}; // namespace + +LogicalResult Tf2XlaRewriter::RewriteOp(Operation* op, + PatternRewriter& rewriter, + const std::string& device_type, + bool use_tf2xla_hlo_importer) { + Tf2XlaRewriter tf2xla_rewriter(op, rewriter, device_type, + use_tf2xla_hlo_importer); + return tf2xla_rewriter.LegalizeOp(); +} + +Tf2XlaRewriter::Tf2XlaRewriter(Operation* op, PatternRewriter& rewriter, + const std::string& device_type, + bool use_tf2xla_hlo_importer) + : op_(op), + device_type_(device_type), + rewriter_(rewriter), + hlo_builder_(op->getName().getStringRef().str(), rewriter_, op->getLoc(), + /*build_functions=*/true), + context_(nullptr), + use_tf2xla_hlo_importer_(use_tf2xla_hlo_importer), + xla_builder_(op_->getName().getStringRef().str()) {} + +Tf2XlaRewriter::~Tf2XlaRewriter() { + if (context_) context_->Unref(); +} + +tsl::StatusOr Tf2XlaRewriter::ImportXlaComputation( + XlaComputation& computation) { + xla::DebugOptions debug_options; + TF_ASSIGN_OR_RETURN(auto hlo_module_config, + xla::HloModule::CreateModuleConfigFromProto( + computation.proto(), debug_options)); + TF_ASSIGN_OR_RETURN( + std::unique_ptr hlo_module, + xla::HloModule::CreateFromProto(computation.proto(), hlo_module_config)); + + if (!RootInstructionIsTuple(*hlo_module)) { + return tsl::errors::InvalidArgument("Imported XLA Root is not a tuple op"); + } + + ModuleOp mlir_module = op_->getParentOfType(); + mlir::OpBuilder builder(op_); + mlir::SymbolTable symbol_table(mlir_module); + + llvm::SmallVector arguments; + for (int i = 0; i < op_->getNumOperands(); i++) { + arguments.push_back(op_->getOperand(i)); + } + + // Ideally we could use the Function Importer but it increases compilation + // time when we have a model with thousands of tf2xla op fallbacks. At time + // of writing, this caused compilation time to be greater than 2x slower. + // So we have to directly import these instructions. + TF_ASSIGN_OR_RETURN( + mlir::Value root_value, + xla::HloFunctionImporter::ImportInstructions( + *hlo_module->entry_computation(), arguments, symbol_table, &builder)); + + mhlo::TupleOp root_tuple = + mlir::dyn_cast_or_null(root_value.getDefiningOp()); + if (!root_tuple) { + return tsl::errors::InvalidArgument( + "Imported XLA Root Value is not a tuple op"); + } + + return root_tuple; +} + +LogicalResult Tf2XlaRewriter::PrepareParams() { + // XlaCompiler within the context is only used by the functional ops to + // compile functions. We are not handling those at the moment so + // XlaCompiler is not required. + if (use_tf2xla_hlo_importer_) { + context_ = new tensorflow::XlaContext(/*compiler=*/nullptr, &xla_builder_, + /*graph=*/nullptr); + } else { + context_ = new tensorflow::XlaContext(/*compiler=*/nullptr, &hlo_builder_, + /*graph=*/nullptr); + } + context_->Ref(); + + device_mgr_ = CreateDeviceMgr(device_type_); + if (!device_mgr_) return failure(); + + // Type of params_.device is DeviceBase* so store it as Device* to access + // derived class method. 
+ device_ = device_mgr_->ListDevices().front(); + params_.device = device_; + params_.resource_manager = device_->resource_manager(); + + // Resources are cleared at the time of device manager destruction so pass + // no-op cleanup function. + auto cleanup = [](const std::string& name) {}; + // Use step_id zero as we only have a single context concurrently and + // concurrently running each of the MLIR functions create a new device. + step_container_ = std::make_unique( + /*step_id=*/0, cleanup); + tsl::Status status = step_container_->Create( + device_->resource_manager(), + tensorflow::XlaContext::kXlaContextResourceName, context_); + if (!status.ok()) { + return emitRemark(op_->getLoc()) + << "failed to create XlaContext resource: " << status.ToString(); + } + params_.step_container = step_container_.get(); + + tsl::StatusOr version_or = tensorflow::GetTfGraphProducerVersion( + op_->getParentOfType()); + if (!version_or.ok()) { + return emitError(op_->getLoc()) << version_or.status().ToString(); + } + + flib_def_ = std::make_unique( + tensorflow::OpRegistry::Global(), tensorflow::FunctionDefLibrary()); + pflr_ = std::make_unique( + device_mgr_.get(), tensorflow::Env::Default(), /*config=*/nullptr, + version_or.value(), flib_def_.get(), tensorflow::OptimizerOptions()); + params_.function_library = pflr_->GetFLR(device_->name()); + return success(); +} + +// Returns true if the given type is a ranked tensor type with static or +// bounded dimensions. +bool IsBounded(Type ty) { + auto ranked_ty = ty.dyn_cast(); + if (!ranked_ty) return false; + + if (ranked_ty.hasStaticShape()) return true; + + auto encoding = + ranked_ty.getEncoding().dyn_cast_or_null(); + if (!encoding) return false; + + for (int i = 0; i < ranked_ty.getRank(); ++i) { + if (ranked_ty.isDynamicDim(i) && + encoding.getBounds()[i] == ShapedType::kDynamic) { + return false; + } + } + return true; +} + +bool HasSymbolRefAttr(Operation* op) { + for (const auto& attr : op->getAttrs()) { + Attribute attr_value = attr.getValue(); + if (attr_value.isa()) { + return true; + } else if (auto array_attr = attr_value.dyn_cast()) { + if (!array_attr.empty() && array_attr.begin()->isa()) { + return true; + } + } + } + return false; +} + +LogicalResult Tf2XlaRewriter::PrepareKernelInputs( + const llvm::SmallDenseSet& required_consts, + std::vector& expressions, + std::vector& tensors, + std::vector& inputs) { + // Prepare the list of Tensor inputs for the kernel. + for (auto it : llvm::enumerate(op_->getOperands())) { + Value operand = it.value(); + size_t idx = it.index(); + + tensorflow::XlaExpression expr = GetExprForOperand(operand, op_, idx); + tensorflow::XlaExpression::Kind kind = expr.kind(); + if (kind == tensorflow::XlaExpression::Kind::kInvalid) return failure(); + if (required_consts.count(idx) && + kind != tensorflow::XlaExpression::Kind::kConstant) { + return op_->emitRemark() + << "lowering requires operand #" << idx << " to be a constant"; + } + expressions.push_back(expr); + + if (!tensorflow::DataTypeCanUseMemcpy(expr.dtype())) { + return op_->emitRemark() + << "skipping legalization due to unsupported type " + << operand.getType(); + } + + auto shape_or = expr.GetShape(); + if (!shape_or.ok()) { + return op_->emitRemark() + << "failed to get shape for expression. 
" << expr.HumanString(); + } + + tensors.emplace_back( + device_->GetAllocator(tensorflow::AllocatorAttributes()), expr.dtype(), + shape_or.value()); + + tensorflow::Tensor& tensor = tensors.back(); + tensorflow::XlaExpression::AssignExpressionToTensor(expr, &tensor); + inputs.emplace_back(&tensor); + } + + return success(); +} + +LogicalResult Tf2XlaRewriter::LegalizeOp() { + for (Type ty : op_->getOperandTypes()) { + auto ranked_ty = ty.dyn_cast(); + // Only bounded operands are supported in the XLA builders. + if (!IsBounded(ranked_ty)) { + return op_->emitRemark() + << "lowering requires bounded tensor operands " << ranked_ty; + } + } + + if (HasSymbolRefAttr(op_)) { + return op_->emitRemark() << "ops with symbol references are not supported"; + } + + auto nodedef_or = tensorflow::ConvertTFDialectOpToNodeDef( + op_, name_mapper_.GetUniqueName(op_), + /*ignore_unregistered_attrs=*/true); + if (!nodedef_or.ok()) { + return op_->emitRemark() << "failed to convert op to NodeDef: " + << nodedef_or.status().ToString(); + } + + if (failed(PrepareParams())) return failure(); + + std::shared_ptr props; + tsl::Status status = tensorflow::NodeProperties::CreateFromNodeDef( + *nodedef_or.value(), + params_.function_library->GetFunctionLibraryDefinition(), &props); + if (!status.ok()) { + return op_->emitRemark() + << "failed to create NodeProperties: " << status.ToString(); + } + tensorflow::OpKernel* op_kernel_raw; + status = params_.function_library->CreateKernel(props, &op_kernel_raw); + if (!status.ok()) { + return op_->emitRemark() + << "failed to create tf2xla kernel: " << status.ToString(); + } + // Transfer ownership of the kernel to a local smart pointer. + auto op_kernel = absl::WrapUnique(op_kernel_raw); + + std::vector required_constants; + status = tensorflow::XlaOpRegistry::CompileTimeConstantInputs( + *op_kernel, &required_constants); + if (!status.ok()) { + return op_->emitRemark() + << "failed to compute required constants: " << status.ToString(); + } + + llvm::SmallDenseSet required_consts; + required_consts.insert(required_constants.begin(), required_constants.end()); + + // TensorValue in inputs are backed by tensors which in turn depend on + // expressions. So, pre-allocate them to the required size. Subtle note: + // Since these are assigned to params_, these have to live past the kernel + // compilation. 
+ std::vector expressions; + std::vector tensors; + std::vector inputs; + expressions.reserve(op_->getNumOperands()); + tensors.reserve(op_->getNumOperands()); + inputs.reserve(op_->getNumOperands()); + + if (failed( + PrepareKernelInputs(required_consts, expressions, tensors, inputs))) + return failure(); + + params_.inputs = inputs; + params_.op_kernel = op_kernel.get(); + llvm::SmallVector output_attr( + op_->getNumResults()); + params_.output_attr_array = output_attr.data(); + + hlo_builder_.setInsertionPoint(op_); + hlo_builder_.SetLocation(op_->getLoc()); + + tensorflow::OpKernelContext op_context(¶ms_, op_->getNumResults()); + device_->Compute(params_.op_kernel, &op_context); + + status = op_context.status(); + status.Update(hlo_builder_.GetCurrentStatus()); + if (!status.ok()) { + return op_->emitRemark() + << "compilation to HLO failed: " << status.ToString(); + } + + if (failed(VerifyOpResults(op_context))) return failure(); + + mhlo::TupleOp tuple_result; + if (use_tf2xla_hlo_importer_) { + StatusOr tuple_result_or_status = + CompileWithHloImporter(op_context); + if (!tuple_result_or_status.ok()) { + return op_->emitRemark() << tuple_result_or_status.status().ToString(); + } + tuple_result = tuple_result_or_status.value(); + } + + llvm::SmallVector output_values; + if (failed(GetKernelOutputs(op_context, tuple_result, output_values))) { + return failure(); + } + + rewriter_.replaceOp(op_, output_values); + return success(); +} + +tsl::StatusOr Tf2XlaRewriter::CompileWithHloImporter( + tensorflow::OpKernelContext& op_context) { + if (!use_tf2xla_hlo_importer_) { + return tsl::errors::InvalidArgument( + "Cannot compile with HloImporter because it isn't supported"); + } + + // XLA can only return a single value. Wrap all output op return values + // in a Tuple op that gets unpacked later. + std::vector output_values; + for (int i = 0, e = op_->getNumResults(); i < e; i++) { + tensorflow::Tensor* output = op_context.mutable_output(i); + const tensorflow::XlaExpression* expr = + tensorflow::XlaExpression::CastExpressionFromTensor(*output); + output_values.push_back(expr->AsXlaOp(&xla_builder_)); + } + + absl::Span return_values(output_values); + xla::XlaOp root_value = xla::Tuple(&xla_builder_, return_values); + + TF_ASSIGN_OR_RETURN(XlaComputation computation, + xla_builder_.Build(root_value, + /*remove_dynamic_dimensions=*/false)); + + return ImportXlaComputation(computation); +} + +mlir::LogicalResult Tf2XlaRewriter::VerifyOpResults( + tensorflow::OpKernelContext& op_context) { + for (int i = 0, e = op_->getNumResults(); i < e; i++) { + tensorflow::Tensor* output = op_context.mutable_output(i); + const tensorflow::XlaExpression* expr = + tensorflow::XlaExpression::CastExpressionFromTensor(*output); + + if (expr->kind() != tensorflow::XlaExpression::Kind::kXlaOp && + expr->kind() != tensorflow::XlaExpression::Kind::kConstant) { + return op_->emitRemark(absl::StrCat( + "expects XlaExpression of kind kXlaOp or kConstant in compiled " + "output index ", + i)); + } + } + return success(); +} + +// XLA computations can only return a single value, but TF ops can return +// multiple values. We get around this by returning a tuple as an XLA op. We +// then unpack it here to return the multiple values instead. 
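The producer side of the tuple workaround used by CompileWithHloImporter can be sketched in isolation: because an XlaComputation has a single root, every logical output is wrapped in one xla::Tuple and unpacked again after import. In this hedged sketch, constants stand in for the kernel's real outputs:

#include <vector>
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"

xla::XlaComputation BuildTupleRootedComputation() {
  xla::XlaBuilder builder("multi_output_example");
  std::vector<xla::XlaOp> outputs;
  outputs.push_back(xla::ConstantR0<float>(&builder, 1.0f));
  outputs.push_back(xla::ConstantR0<float>(&builder, 2.0f));
  // Wrap all outputs in a single tuple so the computation has one root.
  xla::XlaOp root = xla::Tuple(&builder, outputs);
  return builder.Build(root).value();
}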
+mlir::LogicalResult Tf2XlaRewriter::UnpackTupleResults( + mhlo::TupleOp tuple_result, llvm::SmallVector& outputs) { + if (tuple_result->getNumOperands() != op_->getNumResults()) { + return op_->emitRemark() << "Translated TF2XLA tuple has different " + "number of results than original op"; + } + + for (int i = 0; i < tuple_result->getNumOperands(); i++) { + outputs.push_back(tuple_result->getOperand(i)); + } + + tuple_result.getOperation()->erase(); + return success(); +} + +mlir::LogicalResult Tf2XlaRewriter::GetKernelOutputs( + tensorflow::OpKernelContext& op_context, mhlo::TupleOp tuple_results, + llvm::SmallVector& outputs) { + outputs.reserve(op_->getNumResults()); + + if (use_tf2xla_hlo_importer_) { + return UnpackTupleResults(tuple_results, outputs); + } + + for (int i = 0, e = op_->getNumResults(); i < e; i++) { + tensorflow::Tensor* output = op_context.mutable_output(i); + const tensorflow::XlaExpression* expr = + tensorflow::XlaExpression::CastExpressionFromTensor(*output); + + mlir::Value value = hlo_builder_.GetValue(expr->AsXlaOp(&hlo_builder_)); + outputs.push_back(value); + } + + return success(); +} + +tensorflow::XlaExpression Tf2XlaRewriter::GetExprForOperand( + Value operand, Operation* op, int64_t operand_index) { + ElementsAttr const_attr; + auto defining_op = operand.getDefiningOp(); + + ::xla::XlaOp xla_op; + if (use_tf2xla_hlo_importer_) { + xla_op = xla::Parameter(&xla_builder_, operand_index, + xla::TypeToShape(operand.getType()), + std::to_string(operand_index)); + } + + if (defining_op && matchPattern(defining_op, m_Constant(&const_attr))) { + tensorflow::Tensor tensor; + auto status = tensorflow::ConvertToTensor(const_attr, &tensor); + if (!status.ok()) { + op->emitRemark() << "skipping legalization due to failed const conversion" + << status.ToString(); + return tensorflow::XlaExpression::Invalid(); + } + + return tensorflow::XlaExpression::Constant(tensor); + } + + if (!use_tf2xla_hlo_importer_) { + auto xla_op_or = hlo_builder_.MakeXlaOp(operand); + if (!xla_op_or.ok()) { + op->emitRemark() << "skipping legalization due to " + << xla_op_or.status().ToString(); + return tensorflow::XlaExpression::Invalid(); + } + xla_op = xla_op_or.value(); + } + + tensorflow::DataType dtype; + auto status = tensorflow::ConvertToDataType(operand.getType(), &dtype); + if (!status.ok()) { + op->emitRemark() << "skipping legalization due to " << status.ToString(); + return tensorflow::XlaExpression::Invalid(); + } + return tensorflow::XlaExpression::XlaOp(xla_op, dtype); +} + +} // namespace mhlo +} // namespace mlir diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/tf2xla_rewriter.h b/tensorflow/compiler/mlir/tf2xla/transforms/tf2xla_rewriter.h new file mode 100644 index 00000000000..642674469d1 --- /dev/null +++ b/tensorflow/compiler/mlir/tf2xla/transforms/tf2xla_rewriter.h @@ -0,0 +1,130 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TF2XLA_TRANSFORMS_TF2XLA_REWRITER_H_ +#define TENSORFLOW_COMPILER_MLIR_TF2XLA_TRANSFORMS_TF2XLA_REWRITER_H_ + +#include +#include +#include + +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h" +#include "tensorflow/compiler/tf2xla/xla_context.h" +#include "tensorflow/compiler/tf2xla/xla_expression.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/mlir_hlo/mhlo/IR/hlo_ops.h" +#include "tensorflow/compiler/xla/translate/hlo_to_mhlo/mlir_hlo_builder.h" +#include "tensorflow/core/common_runtime/device_mgr.h" +#include "tensorflow/core/framework/op_kernel.h" + +namespace mlir { +namespace mhlo { + +class Tf2XlaRewriterTestPeer; + +class Tf2XlaRewriter { + public: + static mlir::LogicalResult RewriteOp(mlir::Operation* op, + mlir::PatternRewriter& rewriter, + const std::string& device_type, + bool use_tf2xla_hlo_importer); + + private: + friend class Tf2XlaRewriterTestPeer; + + Tf2XlaRewriter(mlir::Operation* op, mlir::PatternRewriter& rewriter, + const std::string& device_type, bool use_tf2xla_hlo_importer); + + ~Tf2XlaRewriter(); + + // Compiles the given Operation with XlaBuilder and imports the generated HLO + // via the HLO -> MHLO importer. + tsl::StatusOr CompileWithHloImporter( + tensorflow::OpKernelContext& op_context); + + // Import the given XlaComputation into the parent module. Returns the given + // generated function. + tsl::StatusOr ImportXlaComputation( + xla::XlaComputation& computation); + + // Prepares OpKernelContext params common to all the ops. + // Emits an error on failure. + mlir::LogicalResult PrepareParams(); + + // Given the required_consts, it will fill the 3 output vectors with + // their respective data. + // Expressions: Output XLA expressions as required by the compiled kernel. + // Tensors: Vector of tensors that back the TensorValue inputs + // Inputs: Vector of inputs that are backed by tensors. + mlir::LogicalResult PrepareKernelInputs( + const llvm::SmallDenseSet& required_consts, + std::vector& expressions, + std::vector& tensors, + std::vector& inputs); + + mlir::LogicalResult VerifyOpResults(tensorflow::OpKernelContext& op_context); + mlir::LogicalResult GetKernelOutputs(tensorflow::OpKernelContext& op_context, + mhlo::TupleOp tuple_results, + llvm::SmallVector& outputs); + + // Given a translated function with a single return value, unpack the tuple + // results. + mlir::LogicalResult UnpackTupleResults(mhlo::TupleOp tuple_result, + llvm::SmallVector& outputs); + + // Tries to legalize the specified TensorFlow op, if supported. + // + // Emits an error and returns failure if an error is encountered during + // conversion. Note that success return value doesn't mean successful + // legalization. + mlir::LogicalResult LegalizeOp(); + + // Converts the given operand to expression of kind kConstant or kXlaOp. + // Emits a remark and returns expression of kind kInvalid on failure. 
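On the importer path, GetExprForOperand (defined in tf2xla_rewriter.cc above) surfaces each MLIR operand to the standalone XlaBuilder as an xla::Parameter numbered by operand index. A hedged standalone sketch; the shape and builder name here are illustrative:

#include <string>
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

// Create a parameter for an operand; in the rewriter the shape comes from
// xla::TypeToShape(operand.getType()) rather than a fixed f32[2,3].
xla::XlaOp MakeOperandParameter(xla::XlaBuilder& builder, int64_t index) {
  xla::Shape shape = xla::ShapeUtil::MakeShape(xla::F32, {2, 3});
  return xla::Parameter(&builder, /*parameter_number=*/index, shape,
                        std::to_string(index));
}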
+ tensorflow::XlaExpression GetExprForOperand(mlir::Value operand, + mlir::Operation* op, + int64_t operand_index); + + mlir::Operation* op_; + std::string device_type_; + + mlir::PatternRewriter& rewriter_; + ::xla::MlirHloBuilder hlo_builder_; + tensorflow::OpOrArgLocNameMapper name_mapper_; + + tensorflow::XlaContext* context_; // Ref-counted. + + std::unique_ptr device_mgr_; + tensorflow::Device* device_; // Owned by device_mgr_; + std::unique_ptr step_container_; + std::unique_ptr flib_def_; + std::unique_ptr pflr_; + tensorflow::OpKernelContext::Params params_; + + bool use_tf2xla_hlo_importer_; + xla::XlaBuilder xla_builder_; +}; + +} // namespace mhlo +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TF2XLA_TRANSFORMS_TF2XLA_REWRITER_H_ diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/tf2xla_rewriter_test.cc b/tensorflow/compiler/mlir/tf2xla/transforms/tf2xla_rewriter_test.cc new file mode 100644 index 00000000000..4aeb42bd7bd --- /dev/null +++ b/tensorflow/compiler/mlir/tf2xla/transforms/tf2xla_rewriter_test.cc @@ -0,0 +1,324 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/compiler/mlir/tf2xla/transforms/tf2xla_rewriter.h" + +#include +#include +#include +#include + +#include +#include +#include "absl/memory/memory.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/SourceMgr.h" +#include "llvm/Support/raw_ostream.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/OwningOpRef.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/Visitors.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Transforms/DialectConversion.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tf2xla/transforms/test_utils.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/mlir_hlo/mhlo/IR/hlo_ops.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/tsl/lib/core/status_test_util.h" +#include "tensorflow/tsl/platform/errors.h" +#include "tensorflow/tsl/platform/status.h" +#include "tensorflow/tsl/platform/statusor.h" + +namespace mlir { +namespace mhlo { + +using ::mlir::LogicalResult; +using ::mlir::ModuleOp; +using ::mlir::OpBuilder; +using ::mlir::Operation; +using ::mlir::func::FuncOp; +using ::tsl::Status; +using ::tsl::StatusOr; 
+using ::xla::ReplicaGroup;
+using ::xla::ShapeUtil;
+using ::xla::XlaBuilder;
+using ::xla::XlaComputation;
+using ::xla::XlaOp;
+
+static constexpr char kMlirModuleStr[] = R"(
+module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 1442 : i32}} {
+  func.func @main(%arg0: tensor<3xi64> {tf._user_specified_name = "resource", tf.aliasing_output = 3 : i64}) -> () attributes {tf.entry_function = {control_outputs = "stateful_normal/RngReadAndSkip,stateful_uniform/RngReadAndSkip,stateful_uniform_full_int/RngReadAndSkip", inputs = "stateful_normal_rngreadandskip_resource", outputs = "identity_RetVal,identity_1_RetVal,identity_2_RetVal"}} {
+    %0:3 = "tf.Unpack"(%arg0) {axis = 0 : i64} : (tensor<3xi64>) -> (tensor<i64>, tensor<i64>, tensor<i64>)
+    return
+  }
+})";
+
+XlaComputation GetTestXlaComputation() {
+  XlaBuilder xla_builder("test");
+  XlaOp add = xla::Add(xla::ConstantR0<float>(&xla_builder, 1.0),
+                       xla::ConstantR0<float>(&xla_builder, 2.0));
+
+  std::vector<XlaOp> tuple_values;
+  tuple_values.push_back(add);
+
+  xla::Tuple(&xla_builder, tuple_values);
+  return xla_builder.Build().value();
+}
+
+class EmptyPatternRewriter : public mlir::PatternRewriter {
+ public:
+  explicit EmptyPatternRewriter(const OpBuilder& other_builder)
+      : mlir::PatternRewriter(other_builder) {}
+  ~EmptyPatternRewriter() override = default;
+};
+
+class Tf2XlaRewriterTestPeer {
+ public:
+  explicit Tf2XlaRewriterTestPeer() = delete;
+  explicit Tf2XlaRewriterTestPeer(mlir::Operation* op)
+      : op_builder_(op),
+        empty_rewriter_(op_builder_),
+        tf2xla_rewriter_(op, empty_rewriter_,
+                         /*device_type=*/"XLA_CPU_JIT",
+                         /*use_tf2xla_hlo_importer=*/true) {}
+
+  tsl::StatusOr<TupleOp> ImportXlaComputationIntoModule(
+      XlaComputation& computation) {
+    return tf2xla_rewriter_.ImportXlaComputation(computation);
+  }
+
+ private:
+  OpBuilder op_builder_;
+  EmptyPatternRewriter empty_rewriter_;
+  Tf2XlaRewriter tf2xla_rewriter_;
+};
+
+// This should only have unit tests. End to end tests should be done with
+// FileCheck and MLIR tests.
+class Tf2XlaRewriterTest : public ::testing::Test {
+ public:
+  void SetUp() override {
+    tensorflow::XlaOpRegistry::RegisterCompilationKernels();
+  }
+
+  Status CreateMlirModule(std::string module_string = kMlirModuleStr) {
+    TF_ASSIGN_OR_RETURN(
+        module_, test::GetMlirModuleFromString(module_string, &context_));
+
+    context_.loadAllAvailableDialects();
+    return tsl::OkStatus();
+  }
+
+  Status LegalizeSingleOp(bool use_tf2xla_hlo_importer, Operation& op) {
+    SourceMgrDiagnosticHandler sourceMgrHandler(source_manager_, &context_);
+
+    OpBuilder op_builder(&op);
+    EmptyPatternRewriter pattern_rewriter(op_builder);
+
+    LogicalResult result = Tf2XlaRewriter::RewriteOp(
+        &op, pattern_rewriter,
+        /*device_type=*/"XLA_CPU_JIT", use_tf2xla_hlo_importer);
+    if (!result.succeeded()) {
+      return tsl::errors::Internal("Failed to rewrite op");
+    }
+
+    return tsl::OkStatus();
+  }
+
+  Status LegalizeModule(bool use_tf2xla_hlo_importer,
+                        std::string module_string = kMlirModuleStr) {
+    TF_EXPECT_OK(CreateMlirModule(module_string));
+    FuncOp main = module_->lookupSymbol<FuncOp>("main");
+    if (!main) {
+      return tsl::errors::InvalidArgument("Could not find a main function");
+    }
+
+    WalkResult walk_result = main.walk([&](Operation* op) {
+      if (op->getDialect()->getNamespace() !=
+          TF::TensorFlowDialect::getDialectNamespace()) {
+        return WalkResult::advance();
+      }
+
+      if (!LegalizeSingleOp(use_tf2xla_hlo_importer, *op).ok()) {
+        return WalkResult::interrupt();
+      }
+
+      return WalkResult::advance();
+    });
+
+    if (walk_result.wasInterrupted()) {
+      return tsl::errors::Internal("Could not legalize all ops");
+    }
+
+    return tsl::OkStatus();
+  }
+
+  mlir::func::FuncOp GetMainFunc() {
+    func::FuncOp main_func = module_->lookupSymbol<mlir::func::FuncOp>("main");
+    EXPECT_TRUE(main_func);
+
+    return main_func;
+  }
+
+  mlir::Operation& GetFirstOpFromMain() {
+    mlir::func::FuncOp main_func = GetMainFunc();
+    return main_func.getBody().front().front();
+  }
+
+  StatusOr<TupleOp> ImportXlaComputationIntoModule(
+      XlaComputation& computation) {
+    SourceMgrDiagnosticHandler sourceMgrHandler(source_manager_, &context_);
+
+    mlir::Operation& first_op = GetFirstOpFromMain();
+
+    Tf2XlaRewriterTestPeer test_peer(&first_op);
+    return test_peer.ImportXlaComputationIntoModule(computation);
+  }
+
+ protected:
+  MLIRContext context_;
+  OwningOpRef<ModuleOp> module_;
+  llvm::SourceMgr source_manager_;
+};
+
+TEST_F(Tf2XlaRewriterTest, LegalizesOp) {
+  TF_EXPECT_OK(LegalizeModule(/*use_tf2xla_hlo_importer=*/false));
+}
+
+TEST_F(Tf2XlaRewriterTest, LegalizesOpWithTf2xlaHloImporter) {
+  TF_EXPECT_OK(LegalizeModule(/*use_tf2xla_hlo_importer=*/true));
+
+  int num_tuple_ops = 0;
+  module_->walk([&num_tuple_ops](TupleOp tuple_op) { num_tuple_ops += 1; });
+
+  EXPECT_EQ(num_tuple_ops, 0);
+}
+
+TEST_F(Tf2XlaRewriterTest, ImportsXlaComputationIntoModule) {
+  TF_ASSERT_OK(CreateMlirModule());
+
+  XlaComputation computation = GetTestXlaComputation();
+
+  TF_ASSERT_OK_AND_ASSIGN(TupleOp root_tuple,
+                          ImportXlaComputationIntoModule(computation));
+
+  ModuleOp parent_module =
+      root_tuple.getOperation()->getParentOfType<ModuleOp>();
+  EXPECT_EQ(parent_module, *module_);
+}
+
+TEST_F(Tf2XlaRewriterTest, FailsWithoutRootTuple) {
+  TF_ASSERT_OK(CreateMlirModule());
+
+  XlaBuilder xla_builder("test_fail");
+  xla::Add(xla::ConstantR0<float>(&xla_builder, 1.0),
+           xla::ConstantR0<float>(&xla_builder, 2.0));
+  XlaComputation bad_computation = xla_builder.Build().value();
+
+  EXPECT_FALSE(ImportXlaComputationIntoModule(bad_computation).ok());
+}
+
+TEST_F(Tf2XlaRewriterTest, ImportsSingleComputation) {
+  XlaBuilder
builder("test_builder"); + XlaComputation to_apply; + { + auto sub_builder = builder.CreateSubBuilder("add"); + auto arg0 = Parameter(sub_builder.get(), 0, + ShapeUtil::MakeScalarShape(xla::F32), "x"); + auto arg1 = Parameter(sub_builder.get(), 1, + ShapeUtil::MakeScalarShape(xla::F32), "y"); + Add(arg0, arg1); + TF_ASSERT_OK_AND_ASSIGN(to_apply, sub_builder->Build()); + } + auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(xla::F32, {4, 16}), "x"); + ReplicaGroup group; + group.add_replica_ids(0); + group.add_replica_ids(1); + XlaOp reduce_scatter = + ReduceScatter(x, to_apply, /*scatter_dimension=*/1, /*shard_count=*/2, + /*replica_groups=*/{group}); + + std::vector tuple_values; + tuple_values.push_back(reduce_scatter); + xla::Tuple(&builder, tuple_values); + + TF_ASSERT_OK_AND_ASSIGN(XlaComputation computation, builder.Build()); + EXPECT_EQ(computation.proto().computations_size(), 2); + + TF_ASSERT_OK(CreateMlirModule()); + TF_ASSERT_OK_AND_ASSIGN(TupleOp root_tuple, + ImportXlaComputationIntoModule(computation)); + EXPECT_TRUE(root_tuple); + + int num_func_ops = 0; + module_->walk([&num_func_ops](func::FuncOp func_op) { num_func_ops++; }); + + // Ensure that only a single computation was imported. + EXPECT_EQ(num_func_ops, 1); +} + +TEST_F(Tf2XlaRewriterTest, InsertsConstantParameters) { + static constexpr char kModuleWithConstParam[] = R"( + module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 1442 : i32}} { + func.func @main(%arg0: tensor<2xf32>) -> tensor<2xf32> { + %0 = "tf.Const"() {value = dense<1.42> : tensor<2xf32>} : () -> tensor<2xf32> + %1 = "tf.Atan2"(%arg0, %0) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> + func.return %0 : tensor<2xf32> + } + })"; + + TF_ASSERT_OK( + LegalizeModule(/*use_tf2xla_hlo_importer=*/true, kModuleWithConstParam)); +} + +TEST_F(Tf2XlaRewriterTest, DISABLED_ImportsPrivateFunctions) { + XlaBuilder builder("test_builder"); + XlaComputation to_apply; + { + auto sub_builder = builder.CreateSubBuilder("add"); + auto arg0 = Parameter(sub_builder.get(), 0, + ShapeUtil::MakeScalarShape(xla::F32), "x"); + auto arg1 = Parameter(sub_builder.get(), 1, + ShapeUtil::MakeScalarShape(xla::F32), "y"); + Add(arg0, arg1); + TF_ASSERT_OK_AND_ASSIGN(to_apply, sub_builder->Build()); + } + auto a = Parameter(&builder, 0, ShapeUtil::MakeScalarShape(xla::F32), "a"); + auto b = Parameter(&builder, 1, ShapeUtil::MakeScalarShape(xla::F32), "b"); + XlaOp call_op = xla::Call(&builder, to_apply, {a, b}); + + std::vector tuple_values; + tuple_values.push_back(call_op); + xla::Tuple(&builder, tuple_values); + + TF_ASSERT_OK_AND_ASSIGN(XlaComputation computation, builder.Build()); + EXPECT_EQ(computation.proto().computations_size(), 2); + + TF_ASSERT_OK(CreateMlirModule()); + TF_ASSERT_OK_AND_ASSIGN(TupleOp root_tuple, + ImportXlaComputationIntoModule(computation)); + EXPECT_TRUE(root_tuple); +} + +} // namespace mhlo +} // namespace mlir diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/verify_tfxla_legalization_test.cc b/tensorflow/compiler/mlir/tf2xla/transforms/verify_tfxla_legalization_test.cc index 0a3de2f45ac..264d64eda8a 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/verify_tfxla_legalization_test.cc +++ b/tensorflow/compiler/mlir/tf2xla/transforms/verify_tfxla_legalization_test.cc @@ -25,16 +25,17 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h" #include "tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.h" #include "tensorflow/compiler/mlir/tf2xla/transforms/passes.h" +#include "tensorflow/compiler/mlir/tf2xla/transforms/test_utils.h" #include "tensorflow/core/lib/monitoring/cell_reader.h" +#include "tensorflow/tsl/platform/statusor.h" namespace tensorflow { namespace { -using ::llvm::StringRef; -using ::mlir::DialectRegistry; using ::mlir::MLIRContext; using ::mlir::ModuleOp; using ::mlir::OwningOpRef; +using ::mlir::mhlo::test::GetMlirModuleFromString; using ::tensorflow::monitoring::testing::CellReader; // Using a string constant here instead of testdata to make this compatible @@ -50,21 +51,6 @@ static constexpr char kMlirModuleStr[] = R"( static constexpr char kFailedLegalizationStreamz[] = "/tensorflow/core/tf2xla/mlir_second_phase_failed_legalization_op_count"; -tsl::StatusOr> GetMlirModuleFromString( - StringRef string, MLIRContext* context) { - DialectRegistry mlir_registry; - RegisterAllTensorFlowDialects(mlir_registry); - context->appendDialectRegistry(mlir_registry); - - OwningOpRef mlir_module; - auto status = - tensorflow::DeserializeMlirModule(string, context, &mlir_module); - if (!status.ok()) { - return status; - } - return mlir_module; -} - TEST(VerifyTfxlaLegalizationTest, RecordsStreamzFailedVerification) { MLIRContext context; TF_ASSERT_OK_AND_ASSIGN(OwningOpRef module, diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/xla_legalize_tf.cc b/tensorflow/compiler/mlir/tf2xla/transforms/xla_legalize_tf.cc index 3f993f270c5..09d5b91f05a 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/xla_legalize_tf.cc +++ b/tensorflow/compiler/mlir/tf2xla/transforms/xla_legalize_tf.cc @@ -38,6 +38,7 @@ limitations under the License. #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/BuiltinAttributeInterfaces.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/PatternMatch.h" // from @llvm-project @@ -52,12 +53,15 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.h" #include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h" #include "tensorflow/compiler/mlir/tf2xla/transforms/passes.h" +#include "tensorflow/compiler/mlir/tf2xla/transforms/utils.h" #include "tensorflow/compiler/mlir/tf2xla/transforms/xla_legalize_targets.h" #include "tensorflow/compiler/xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/rewriters.h" #include "tensorflow/compiler/xla/translate/hlo_to_mhlo/attribute_importer.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/framework/numeric_types.h" #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/lib/monitoring/counter.h" #include "tensorflow/core/util/quantization/uniform_quant_ops_attr.pb.h" #include "tensorflow/core/util/quantization/uniform_quant_ops_params.h" @@ -68,6 +72,10 @@ namespace { #define GEN_PASS_DEF_LEGALIZETF #include "tensorflow/compiler/mlir/tf2xla/transforms/xla_legalize_tf_passes.h.inc" +auto *mlir_failed_legalization_count = tensorflow::monitoring::Counter<2>::New( + "/tensorflow/core/tf2xla/v0/mlir_failed_xla_legalize_tf_pass_count", + "Counts the failure of legalization of ops", "op_name", "legality"); + class LegalizeTF : public impl::LegalizeTFBase { public: explicit LegalizeTF(bool allow_partial_conversion, bool legalize_chlo, @@ -87,17 +95,6 @@ class LegalizeTF : public impl::LegalizeTFBase { #define GEN_PASS_DEF_LEGALIZETFMODULEPASS #include "tensorflow/compiler/mlir/tf2xla/transforms/xla_legalize_tf_passes.h.inc" -class LegalizeTFModulePass - : public impl::LegalizeTFModulePassBase { - public: - explicit LegalizeTFModulePass(StringRef tf2xla_fallback_device_type) { - device_type_ = tf2xla_fallback_device_type.str(); - } - - /// Performs the lowering to XLA dialect. - void runOnOperation() override; -}; - FailureOr GetStorageType(Operation *op, Type original_output_element_type, PatternRewriter &rewriter) { @@ -169,13 +166,15 @@ FailureOr GetUniformQuantizedType( return GetSameShapeTensorType(original_type.cast(), elem_ty); } -template -FailureOr CreateConstantOpForQint8Rhs( - UniformQuantizedOp op, TensorType new_rhs_type, PatternRewriter &rewriter) { +template +FailureOr CreateConstantOp(UniformQuantizedOp op, + Value original_operand, + TensorType new_operand_type, + PatternRewriter &rewriter) { // Check whether the rhs operand has constant op. 
TF::TensorProtoAttr tensor_proto_attr; - if (!matchPattern(op.getRhs(), m_Constant(&tensor_proto_attr))) { - return rewriter.notifyMatchFailure(op, "rhs must be constant."); + if (!matchPattern(original_operand, m_Constant(&tensor_proto_attr))) { + return rewriter.notifyMatchFailure(op, "operand must be constant."); } llvm::StringRef mangled_tensor = tensor_proto_attr.getValue(); @@ -186,7 +185,7 @@ FailureOr CreateConstantOpForQint8Rhs( tensorflow::Status status = tensorflow::mangling_util::DemangleTensor(tensor_view, &tensor_proto); if (!status.ok()) { - return rewriter.notifyMatchFailure(op, status.error_message()); + return rewriter.notifyMatchFailure(op, status.message()); } tensorflow::Tensor t; @@ -194,11 +193,13 @@ FailureOr CreateConstantOpForQint8Rhs( return op.emitError("Failed to convert tensor proto to Tensor."); } - auto arr = t.flat(); + auto arr = t.flat(); auto dense_attr = mlir::DenseElementsAttr::get( - GetSameShapeTensorType(new_rhs_type, rewriter.getIntegerType(8)), + GetSameShapeTensorType( + new_operand_type, + rewriter.getIntegerType(8 * sizeof(TFQuantizedType))), llvm::ArrayRef(arr.data(), arr.size())); - return rewriter.create(op.getLoc(), new_rhs_type, + return rewriter.create(op.getLoc(), new_operand_type, dense_attr); } @@ -360,7 +361,8 @@ class ConvertUniformQuantizedDotHybridOp return failure(); } - auto rhs = CreateConstantOpForQint8Rhs(op, *rhs_type, rewriter); + auto rhs = CreateConstantOp(op, op.getRhs(), *rhs_type, + rewriter); if (failed(rhs)) { return failure(); } @@ -388,7 +390,8 @@ class ConvertUniformQuantizedConvolutionHybridOp return failure(); } - auto rhs = CreateConstantOpForQint8Rhs(op, *rhs_type, rewriter); + auto rhs = CreateConstantOp(op, op.getRhs(), *rhs_type, + rewriter); if (failed(rhs)) { return failure(); } @@ -498,7 +501,8 @@ class ConvertUniformQuantizedDotOp return failure(); } - auto rhs_or = CreateConstantOpForQint8Rhs(op, *rhs_type, rewriter); + auto rhs_or = CreateConstantOp(op, op.getRhs(), + *rhs_type, rewriter); if (failed(rhs_or)) { return failure(); } @@ -539,7 +543,8 @@ class ConvertUniformQuantizedConvolutionOp return failure(); } - auto rhs_or = CreateConstantOpForQint8Rhs(op, *rhs_type, rewriter); + auto rhs_or = CreateConstantOp(op, op.getRhs(), + *rhs_type, rewriter); if (failed(rhs_or)) { return failure(); } @@ -565,6 +570,110 @@ class ConvertUniformQuantizedConvolutionOp } }; +class ConvertUniformQuantizedAddOp + : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + TF::UniformQuantizedAddOp op, TF::UniformQuantizedAddOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value lhs = adaptor.getLhs(); + + auto lhs_type = lhs.getType().cast(); + if (!lhs_type.hasRank()) { + return rewriter.notifyMatchFailure( + op, "Legalization supports cases where only lhs rank known."); + } + // rhs (bias) is always 1D that broadcasts to the last dim of lhs. 
+ auto broadcast_dims = + GetI64ElementsAttr({lhs_type.getRank() - 1}, &rewriter); + + auto rhs_type = GetUniformQuantizedType( + op, adaptor.getRhs().getType(), op.getRhsScales(), + op.getRhsZeroPoints(), + /*expressed_type=*/rewriter.getF32Type(), op.getRhsQuantizationMinVal(), + op.getRhsQuantizationMaxVal(), op.getRhsQuantizationAxis(), rewriter); + if (failed(rhs_type)) { + return failure(); + } + + auto rhs_or = CreateConstantOp(op, op.getRhs(), + *rhs_type, rewriter); + if (failed(rhs_or)) { + return failure(); + } + + auto output_type = GetUniformQuantizedType( + op, op.getOutput().getType(), op.getOutputScales(), + op.getOutputZeroPoints(), + /*expressed_type=*/rewriter.getF32Type(), + op.getOutputQuantizationMinVal(), op.getOutputQuantizationMaxVal(), + op.getOutputQuantizationAxis(), rewriter); + if (failed(output_type)) { + return failure(); + } + + // lhs, rhs, output scales and zero_points are guaranteed (by the TF + // quantizer) to be identical, respectively. + rewriter.replaceOpWithNewOp(op, *output_type, lhs, + *rhs_or, broadcast_dims); + return success(); + } +}; + +class ConvertUniformQuantizedClipByValueOp + : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + TF::UniformQuantizedClipByValueOp op, + TF::UniformQuantizedClipByValueOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value operand = adaptor.getOperand(); + + const int64_t quantization_axis = op.getQuantizationAxis(); + llvm::SmallVector broadcast_dims_values = {}; + if (quantization_axis >= 0) { + broadcast_dims_values.push_back(quantization_axis); + } + auto broadcast_dims = GetI64ElementsAttr(broadcast_dims_values, &rewriter); + + auto min_max_type = GetUniformQuantizedType( + op, adaptor.getMin().getType(), op.getScales(), op.getZeroPoints(), + /*expressed_type=*/rewriter.getF32Type(), op.getQuantizationMinVal(), + op.getQuantizationMaxVal(), op.getQuantizationAxis(), rewriter); + if (failed(min_max_type)) { + return failure(); + } + auto min_or = CreateConstantOp(op, op.getMin(), + *min_max_type, rewriter); + if (failed(min_or)) { + return failure(); + } + auto max_or = CreateConstantOp(op, op.getMax(), + *min_max_type, rewriter); + if (failed(max_or)) { + return failure(); + } + + auto output_type = GetUniformQuantizedType( + op, op.getOutput().getType(), op.getScales(), op.getZeroPoints(), + /*expressed_type=*/rewriter.getF32Type(), op.getQuantizationMinVal(), + op.getQuantizationMaxVal(), op.getQuantizationAxis(), rewriter); + if (failed(output_type)) { + return failure(); + } + + Value res_min_clipped = rewriter.create( + op->getLoc(), *output_type, operand, *min_or, broadcast_dims); + rewriter.replaceOpWithNewOp( + op, *output_type, res_min_clipped, *max_or, broadcast_dims); + return success(); + } +}; + // Emits debug information which includes the number of ops of each type which // failed to legalize. 
void EmitLegalizationErrors(Operation *op, @@ -736,25 +845,64 @@ RewritePatternSet PatternsIncludeOps( return to; } +std::string OperationLegalityString(Operation *op, + const ConversionTarget &target) { + auto op_name = op->getName(); + auto action = target.getOpAction(op_name); + if (!action.has_value()) { + return "Unknown"; + } + switch (action.value_or(ConversionTarget::LegalizationAction::Legal)) { + case ConversionTarget::LegalizationAction::Legal: + return "Legal"; + case ConversionTarget::LegalizationAction::Dynamic: + return "Dynamic"; + case ConversionTarget::LegalizationAction::Illegal: + return "Illegal"; + default: + return "Invalid"; + } +} + +void IncrementFailedLegalizationCount(Operation *op, + const ConversionTarget &target) { + auto op_name = op->getName(); + auto name_string = op_name.getStringRef().str(); + auto op_legality = OperationLegalityString(op, target); + + mlir_failed_legalization_count->GetCell(name_string, op_legality) + ->IncrementBy(1); +} + mlir::LogicalResult ApplyPatterns(Operation *op, RewritePatternSet &patterns, bool legalize_chlo) { ConversionTarget target = GetDefaultLegalConversionTargets(*op->getContext(), legalize_chlo); - return applyPartialConversion(op, target, std::move(patterns)); + DenseSet unconverted_ops; + auto result = + applyPartialConversion(op, target, std::move(patterns), &unconverted_ops); + if (failed(result)) { + IncrementFailedLegalizationCount(op, target); + } + for (const auto &unconverted_op : unconverted_ops) { + IncrementFailedLegalizationCount(unconverted_op, target); + } + return result; } /// When `tf2xla_fallback_device_type` is not `None`, also uses legalization /// patterns from TF2XLA fallback for provided device type (see -/// legalize_tf_with_tf2xla.cc for details). By default, TF2XLA fallback is not -/// used. +/// legalize_tf_with_tf2xla.cc for details). By default, TF2XLA fallback is +/// not used. LogicalResult legalizeTF(Operation *op, bool legalize_chlo, std::optional tf2xla_fallback_device_type, - bool prefer_tf2xla) { + bool prefer_tf2xla, bool use_tf2xla_hlo_importer) { MLIRContext *context = op->getContext(); RewritePatternSet legalize_lower_patterns(context); // Note that the `OperationConverter` orders patterns lexicographically by: - // 1) Ascending legalization depth (i.e., minimum number of patterns necessary + // 1) Ascending legalization depth (i.e., minimum number of patterns + // necessary // to arrive at conversion target). This requires relevant patterns to // specify the list of ops generated by it which most of patterns // implemented in C++ don't do so this comparison doesn't work in those @@ -791,9 +939,9 @@ LogicalResult legalizeTF(Operation *op, bool legalize_chlo, Tf2XlaTypeConverter converter; if (tf2xla_fallback_device_type) { // Add TF->HLO legalization patterns via TF2XLA fallback. - PopulateLegalizeTfWithTf2XlaPatterns(tf2xla_fallback_device_type.value(), - patterns, context, converter, - prefer_tf2xla); + PopulateLegalizeTfWithTf2XlaPatterns( + tf2xla_fallback_device_type.value(), patterns, context, converter, + prefer_tf2xla, use_tf2xla_hlo_importer); } // Populate with CHLO->HLO lowerings to account for TF ops legalized to @@ -817,28 +965,8 @@ void LegalizeTF::runOnOperation() { tf2xla_fallback_device_type = device_type_; } if (failed(legalizeTF(getOperation(), legalize_chlo_, - tf2xla_fallback_device_type, prefer_tf2xla_))) { - signalPassFailure(); - } -} - -void LegalizeTFModulePass::runOnOperation() { - // This pass should only be run when a fallback device is present. 
- if (!device_type_.hasValue()) { - return; - } - VLOG(1) << "TF to XLA legalization patterns include TF2XLA fallback " - "patterns for Ops that need to create functions."; - Operation *op = getOperation(); - MLIRContext *context = op->getContext(); - RewritePatternSet patterns(context); - Tf2XlaTypeConverter converter; - PopulateLegalizeTfWithTf2XlaPatterns(device_type_, patterns, context, - converter, /*prefer_tf2xla=*/false, - /*is_module_pass=*/true); - - if (failed(ApplyPatterns(op, patterns, - /*legalize_chlo=*/false))) { + tf2xla_fallback_device_type, prefer_tf2xla_, + use_tf2xla_hlo_importer_))) { signalPassFailure(); } } @@ -847,14 +975,16 @@ void LegalizeTFModulePass::runOnOperation() { void PopulateLegalizeTfQuantizationPatterns(MLIRContext *context, RewritePatternSet *patterns) { - patterns->add(context); + patterns + ->add(context); } -std::unique_ptr> createLegalizeTFPass( +std::unique_ptr> createLegalizeTFPass( bool allow_partial_conversion, bool legalize_chlo, std::optional tf2xla_fallback_device_type, bool prefer_tf2xla) { return std::make_unique(allow_partial_conversion, legalize_chlo, @@ -862,10 +992,5 @@ std::unique_ptr> createLegalizeTFPass( prefer_tf2xla); } -std::unique_ptr> createLegalizeTFModulePass( - StringRef tf2xla_fallback_device_type) { - return std::make_unique(tf2xla_fallback_device_type); -} - } // end namespace mhlo } // end namespace mlir diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/xla_legalize_tf_passes.td b/tensorflow/compiler/mlir/tf2xla/transforms/xla_legalize_tf_passes.td index 4d1b9388af2..cfec5714798 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/xla_legalize_tf_passes.td +++ b/tensorflow/compiler/mlir/tf2xla/transforms/xla_legalize_tf_passes.td @@ -17,7 +17,7 @@ limitations under the License. include "mlir/Pass/PassBase.td" -def LegalizeTF : Pass<"xla-legalize-tf", "mlir::func::FuncOp"> { +def LegalizeTF : Pass<"xla-legalize-tf", "ModuleOp"> { let summary = "Legalize from TF dialect's or HLO dialect's control flow."; let description = [{ @@ -44,7 +44,12 @@ def LegalizeTF : Pass<"xla-legalize-tf", "mlir::func::FuncOp"> { Option<"prefer_tf2xla_", "prefer-tf2xla", "bool", /*default=*/"false", "Prioritize tf2xla fallback legalization over MLIR legalization " - "patterns"> + "patterns">, + Option<"use_tf2xla_hlo_importer_", "use-tf2xla-hlo-importer", + "bool", /*default=*/"false", + "Use the experimental HLO to MHLO importer for per-op fallback calls " + " from MLIR bridge to TF2XLA." + "Users should not set this flag and ideally this goes away."> ]; let constructor = "mlir::mhlo::createLegalizeTFPass()"; @@ -56,26 +61,6 @@ def LegalizeTF : Pass<"xla-legalize-tf", "mlir::func::FuncOp"> { "sparse_tensor::SparseTensorDialect"]; } -def LegalizeTFModulePass : Pass<"xla-fallback-legalize-tf-module-pass", "ModuleOp"> { - let summary = "Legalize whitelisted Ops using TF2XLA fallback for ops that " - "must also be able to create new functions."; - - let description = [{ - Legalizes whitelisted Ops from TF dialect to HLO dialect using TF2XLA - fallback for ops that must be allowed to create new functions. - }]; - let options = [ - Option<"device_type_", "device-type", "std::string", - /*default=*/"\"INVALID_DEVICE_TYPE\"", - "The device type used by TF2XLA fallback. 
Required.">, - ]; - - let constructor = "mlir::mhlo::createLegalizeTFModulePass()"; - let dependentDialects = ["arith::ArithDialect, chlo::ChloDialect", - "mhlo::MhloDialect", - "shape::ShapeDialect", "func::FuncDialect", "sparse_tensor::SparseTensorDialect"]; -} - def ConvertMHLOQuantToInt : Pass<"convert-mhlo-quant-to-int", "mlir::func::FuncOp"> { let summary = "Convert from MHLO quantized ops to MHLO primitive ops."; diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/xla_legalize_tf_test.cc b/tensorflow/compiler/mlir/tf2xla/transforms/xla_legalize_tf_test.cc new file mode 100644 index 00000000000..7cc4d39676a --- /dev/null +++ b/tensorflow/compiler/mlir/tf2xla/transforms/xla_legalize_tf_test.cc @@ -0,0 +1,115 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include + +#include +#include "absl/strings/string_view.h" +#include "llvm/ADT/StringRef.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/DialectRegistry.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/OwningOpRef.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Transforms/DialectConversion.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.h" +#include "tensorflow/compiler/mlir/tf2xla/transforms/passes.h" +#include "tensorflow/core/lib/monitoring/cell_reader.h" +#include "tensorflow/tsl/platform/statusor.h" + +namespace tensorflow { +namespace { + +using ::mlir::MLIRContext; +using ::mlir::ModuleOp; +using ::mlir::OwningOpRef; +using ::mlir::PassManager; +using ::tensorflow::monitoring::testing::CellReader; + +StatusOr> GetMlirModuleFromString( + absl::string_view module_string, MLIRContext* context) { + mlir::DialectRegistry mlir_registry; + RegisterAllTensorFlowDialects(mlir_registry); + context->appendDialectRegistry(mlir_registry); + + OwningOpRef mlir_module; + auto status = + tensorflow::DeserializeMlirModule(module_string, context, &mlir_module); + if (!status.ok()) { + return status; + } + return mlir_module; +} + +bool BuildAndRunPipeline(absl::string_view module_string, + const std::function& passes) { + mlir::registerPassManagerCLOptions(); + MLIRContext context; + + OwningOpRef module = + GetMlirModuleFromString(module_string, &context).value(); + + PassManager pm(&context); + + if (mlir::failed(mlir::applyPassManagerCLOptions(pm))) return false; + passes(&pm); + + return pm.run(module.get()).succeeded(); +} + +std::function legalizeTFPasses() { + return [](PassManager* pm) { + pm->addPass(mlir::mhlo::createLegalizeTFPass( + /* allow_partial_conversion=*/false, /* legalize_chlo=*/true, + llvm::StringRef("gpu/xpu"), /* prefer_tf2xla=*/false)); + }; +} + 
+TEST(XlaLegalizeTest, IllegalOp) {
+  constexpr char kMlirIllegalOpStr[] = R"(
+  module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 268 : i32}} {
+    func.func @main() -> tensor<1xi32> {
+      %0 = "tf.DoesntExist"() : () -> tensor<1xi32>
+      func.return %0 : tensor<1xi32>
+    }
+  })";
+  CellReader<int64_t> legalize_failure_count(
+      "/tensorflow/core/tf2xla/v0/mlir_failed_xla_legalize_tf_pass_count");
+
+  auto status = BuildAndRunPipeline(kMlirIllegalOpStr, legalizeTFPasses());
+
+  EXPECT_TRUE(status);
+  EXPECT_EQ(legalize_failure_count.Read("tf.DoesntExist", "Unknown"), 1);
+}
+
+TEST(XlaLegalizeTest, LegalOp) {
+  // We expect legalization to fail for a legal op with dynamic shapes:
+  static constexpr char kMlirLegalOpStr[] = R"(
+  func.func @infeed_dequeue_tuple_dynamic_error() -> (tensor<3x3xf32>, tensor<4x?xf32>) {
+    %0:2 = "tf.InfeedDequeueTuple"() : () -> (tensor<3x3xf32>, tensor<4x?xf32>)
+    func.return %0#0, %0#1 : tensor<3x3xf32>, tensor<4x?xf32>
+  })";
+  CellReader<int64_t> legalize_failure_count(
+      "/tensorflow/core/tf2xla/v0/mlir_failed_xla_legalize_tf_pass_count");
+
+  auto status = BuildAndRunPipeline(kMlirLegalOpStr, legalizeTFPasses());
+
+  EXPECT_TRUE(status);
+  EXPECT_EQ(legalize_failure_count.Read("tf.InfeedDequeueTuple", "Unknown"), 1);
+}
+}  // namespace
+
+}  // namespace tensorflow
diff --git a/tensorflow/compiler/mlir/tfr/BUILD b/tensorflow/compiler/mlir/tfr/BUILD
index 93e904f3d90..f9fff19986e 100644
--- a/tensorflow/compiler/mlir/tfr/BUILD
+++ b/tensorflow/compiler/mlir/tfr/BUILD
@@ -286,6 +286,7 @@ tf_py_test(
     deps = [
         "//tensorflow/compiler/mlir/tfr/resources:composite_ops",
         "//tensorflow/python/eager:def_function",
+        "//tensorflow/python/platform:client_testlib",
     ],
 )
 
@@ -317,6 +318,7 @@ tf_py_test(
     ],
     deps = [
         "//tensorflow/compiler/mlir/tfr/resources:composite_ops",
+        "//tensorflow/python/platform:client_testlib",
     ],
 )
 
@@ -354,13 +356,20 @@ py_library(
         "//tensorflow/compiler/mlir/tfr:tfr_wrapper",
         "//tensorflow/python/autograph/converters:control_flow",
         "//tensorflow/python/autograph/converters:return_statements",
-        "//tensorflow/python/autograph/impl",
-        "//tensorflow/python/autograph/pyct",
-        "//tensorflow/python/autograph/pyct/static_analysis",
+        "//tensorflow/python/autograph/impl:api",
+        "//tensorflow/python/autograph/pyct:anno",
+        "//tensorflow/python/autograph/pyct:cfg",
+        "//tensorflow/python/autograph/pyct:qual_names",
+        "//tensorflow/python/autograph/pyct:transformer",
+        "//tensorflow/python/autograph/pyct:transpiler",
+        "//tensorflow/python/autograph/pyct/static_analysis:activity",
+        "//tensorflow/python/autograph/pyct/static_analysis:reaching_definitions",
+        "//tensorflow/python/autograph/pyct/static_analysis:reaching_fndefs",
+        "//tensorflow/python/autograph/pyct/static_analysis:type_inference",
         "//tensorflow/python/framework",
         "//tensorflow/python/framework:dtypes",
         "//tensorflow/python/framework:op_def_registry",
-        "//tensorflow/python/platform",
+        "//tensorflow/python/platform:tf_logging",
         "//tensorflow/python/util:tf_inspect",
         "@gast_archive//:gast",
     ],
@@ -380,6 +389,7 @@ tf_py_test(
         "//tensorflow/compiler/mlir/tfr/resources:test_ops",
         "//tensorflow/python:array_ops",
         "//tensorflow/python:math_ops",
+        "//tensorflow/python/platform:client_testlib",
     ],
 )
 
@@ -389,6 +399,8 @@ py_library(
     srcs_version = "PY3",
     deps = [
         "//tensorflow:tensorflow_py",
+        "//tensorflow/python/autograph/pyct:transformer",
+        "//tensorflow/python/autograph/pyct:transpiler",
     ],
 )
 
@@ -403,6 +415,7 @@ tf_py_test(
         ":composite",
         ":op_reg_gen",
"//tensorflow/compiler/mlir/python/mlir_wrapper:filecheck_wrapper", + "//tensorflow/python/platform:client_testlib", ], ) @@ -412,6 +425,7 @@ py_library( srcs_version = "PY3", deps = [ "//tensorflow:tensorflow_py", + "//tensorflow/python/platform:client_testlib", ], ) diff --git a/tensorflow/compiler/mlir/tfr/integration/graph_decompose_pass.cc b/tensorflow/compiler/mlir/tfr/integration/graph_decompose_pass.cc index 206e5ef13f8..50a8686ad4d 100644 --- a/tensorflow/compiler/mlir/tfr/integration/graph_decompose_pass.cc +++ b/tensorflow/compiler/mlir/tfr/integration/graph_decompose_pass.cc @@ -40,7 +40,8 @@ MlirOptimizationPassState GraphDecomposePass::GetPassState( } Status GraphDecomposePass::Run( - const ConfigProto& config_proto, mlir::ModuleOp module, const Graph& graph, + const std::string& function_name, const ConfigProto& config_proto, + mlir::ModuleOp module, const Graph& graph, const FunctionLibraryDefinition& function_library) { if (GetPassState(/*device_set=*/nullptr, config_proto, graph, function_library) == MlirOptimizationPassState::Disabled) { diff --git a/tensorflow/compiler/mlir/tfr/integration/graph_decompose_pass.h b/tensorflow/compiler/mlir/tfr/integration/graph_decompose_pass.h index e415f5cbea9..575fd2d178d 100644 --- a/tensorflow/compiler/mlir/tfr/integration/graph_decompose_pass.h +++ b/tensorflow/compiler/mlir/tfr/integration/graph_decompose_pass.h @@ -15,6 +15,8 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_TFR_INTEGRATION_GRAPH_DECOMPOSE_PASS_H_ #define TENSORFLOW_COMPILER_MLIR_TFR_INTEGRATION_GRAPH_DECOMPOSE_PASS_H_ +#include + #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "tensorflow/compiler/mlir/mlir_graph_optimization_pass.h" #include "tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx.h" @@ -40,8 +42,8 @@ class GraphDecomposePass : public MlirOptimizationPass { // This should be used as a thin mapper around mlir::ModulePass::runOnModule // API integrated with the Tensorflow runtime. - Status Run(const ConfigProto& config_proto, mlir::ModuleOp module, - const Graph& graph, + Status Run(const std::string& function_name, const ConfigProto& config_proto, + mlir::ModuleOp module, const Graph& graph, const FunctionLibraryDefinition& function_library) override; }; diff --git a/tensorflow/compiler/mlir/tfr/ir/tfr_ops.cc b/tensorflow/compiler/mlir/tfr/ir/tfr_ops.cc index 8c33af424b8..91a306c1fba 100644 --- a/tensorflow/compiler/mlir/tfr/ir/tfr_ops.cc +++ b/tensorflow/compiler/mlir/tfr/ir/tfr_ops.cc @@ -34,6 +34,7 @@ limitations under the License. 
#include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributeInterfaces.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project @@ -144,7 +145,8 @@ TFRDialect::TFRDialect(MLIRContext *context) Operation *TFRDialect::materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc) { if (arith::ConstantOp::isBuildableWith(value, type)) - return builder.create(loc, type, value); + return builder.create(loc, type, + value.cast()); if (func::ConstantOp::isBuildableWith(value, type)) return builder.create(loc, type, value.cast()); @@ -923,6 +925,16 @@ ArrayRef TFRFuncOp::getCallableResults() { return getFunctionType().getResults(); } +// CallableOpInterface +::mlir::ArrayAttr TFRFuncOp::getCallableArgAttrs() { + return getArgAttrs().value_or(nullptr); +} + +// CallableOpInterface +::mlir::ArrayAttr TFRFuncOp::getCallableResAttrs() { + return getResAttrs().value_or(nullptr); +} + //===----------------------------------------------------------------------===// // Dialect type definitions //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/mlir/tfr/passes/decompose.cc b/tensorflow/compiler/mlir/tfr/passes/decompose.cc index 2e12356f03f..9a76d68efd9 100644 --- a/tensorflow/compiler/mlir/tfr/passes/decompose.cc +++ b/tensorflow/compiler/mlir/tfr/passes/decompose.cc @@ -79,8 +79,8 @@ namespace TFR { namespace { // Quantize the float value based on given scale and zero point attributes. -Attribute Quantize(float value, Attribute scale_attr, Attribute zp_attr, - OpBuilder builder) { +IntegerAttr Quantize(float value, Attribute scale_attr, Attribute zp_attr, + OpBuilder builder) { double scale = scale_attr.cast().getValueAsDouble(); int64_t zp = zp_attr.cast().getInt(); @@ -223,8 +223,8 @@ LogicalResult DecomposeTFOpsPass::RewriteUnregisteredTFOps() { attr_cst = builder.create(op->getLoc(), output_type, attribute); } else { - attr_cst = - builder.create(op->getLoc(), attribute); + attr_cst = builder.create( + op->getLoc(), cast(attribute)); } new_operands.push_back(attr_cst); } diff --git a/tensorflow/compiler/mlir/tfrt/BUILD b/tensorflow/compiler/mlir/tfrt/BUILD index 1d748bd6ae9..068b7cabf22 100644 --- a/tensorflow/compiler/mlir/tfrt/BUILD +++ b/tensorflow/compiler/mlir/tfrt/BUILD @@ -24,8 +24,8 @@ package_group( "//tensorflow/core/runtime_fallback/...", "//tensorflow/core/tfrt/eager/...", "//tensorflow/core/tfrt/experimental/data/...", - "//tensorflow/core/tfrt/saved_model/...", "//tensorflow/core/tfrt/graph_executor/...", + "//tensorflow/core/tfrt/saved_model/...", "//tensorflow/core/tfrt/tfrt_session/...", ] + if_google([ "//learning/brain/experimental/mlir/tflite/tfmrt/...", @@ -112,7 +112,7 @@ cc_library( deps = [ "//tensorflow/compiler/jit:flags", "//tensorflow/compiler/mlir/tensorflow:tensorflow_passes", - "//tensorflow/compiler/mlir/tf2xla:xla_legalize_tf", + "//tensorflow/compiler/mlir/tf2xla/transforms:xla_legalize_tf", "//tensorflow/compiler/mlir/tfrt/jit/transforms:tf_jitrt_passes", "//tensorflow/compiler/xla/mlir/backends/cpu/transforms:passes", "//tensorflow/compiler/xla/mlir/runtime/transforms:compiler", @@ -300,7 +300,7 @@ cc_library( "//tensorflow/core:framework", "//tensorflow/core/platform:env", 
"//tensorflow/core/platform:threadpool_interface", - "//tensorflow/core/runtime_fallback/kernel:kernel_fallback_execute_compat", + "//tensorflow/core/runtime_fallback/kernel:kernel_fallback_execute_compat_eager", "//tensorflow/core/runtime_fallback/runtime:kernel_utils", "//tensorflow/core/runtime_fallback/runtime:runtime_fallback_alwayslink", "//tensorflow/core/tfrt/utils:fallback_tensor", @@ -394,12 +394,12 @@ cc_library( "transforms/merge_tf_if_ops.cc", "transforms/optimize.cc", "transforms/optimize_tf_control_flow_side_effect.cc", + "transforms/passes.cc", "transforms/remove_device_attribute.cc", "transforms/remove_tf_if_const_args.cc", "transforms/reorder_assert.cc", "transforms/sink_in_invariant_ops.cc", "transforms/tf_to_tfrt.cc", - "transforms/tpu_passes.h", "transforms/xla_rewrite_pass.cc", ], hdrs = [ @@ -411,44 +411,38 @@ cc_library( ":cost_analysis", ":fallback_converter", ":tensor_array_side_effect_analysis", - ":tf_jitrt_opdefs", - ":tf_jitrt_pipeline", + ":tfrt_jitrt_stub", ":tfrt_pipeline_options", + ":tpu_passes", + ":transform_utils", ":transforms/gpu_passes", ":transforms/set_shape_invariant_in_while_ops", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:Pass", - "@llvm-project//mlir:FuncDialect", - "@llvm-project//mlir:Support", - "@llvm-project//mlir:Transforms", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:bridge_logger", "//tensorflow/compiler/mlir/tensorflow:convert_tensor", - "//tensorflow/compiler/mlir/tensorflow:device_util", - "//tensorflow/compiler/mlir/tensorflow:serialize_mlir_module_utils", "//tensorflow/compiler/mlir/tensorflow:tensorflow_analysis", "//tensorflow/compiler/mlir/tensorflow:tensorflow_op_interfaces", "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops", "//tensorflow/compiler/mlir/tensorflow:tensorflow_passes", "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", - "//tensorflow/compiler/mlir/tfrt/jit/transforms:tf_jitrt_passes", - "//tensorflow/compiler/mlir/tfrt/jit/transforms:tf_jitrt_clustering", - "//tensorflow/compiler/mlir/tfrt/ir:tfrt_fallback_opdefs", - "//tensorflow/core:framework", - "//tensorflow/core/platform:tstring", "//tensorflow/compiler/mlir/tfrt/ir:tfrt_fallback_async_opdefs", + "//tensorflow/compiler/mlir/tfrt/ir:tfrt_fallback_opdefs", "//tensorflow/compiler/mlir/tfrt/ir:tfrt_gpu_opdefs", + "//tensorflow/core:framework", + "//tensorflow/core/platform:status", + "//tensorflow/core/platform:tstring", + "//tensorflow/tsl/platform:status", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:Transforms", "@tf_runtime//:basic_kernels_opdefs", "@tf_runtime//:core_runtime_opdefs", - "@tf_runtime//backends/jitrt:jitrt_opdefs", "@tf_runtime//:stream_analysis", "@tf_runtime//:test_kernels_opdefs", - ":transform_utils", - "//tensorflow/tsl/platform:status", - ] + if_google([ - "//learning/brain/tfrt/tpu/compiler/mlir:tf_to_tfrt_tpu", - ]), + ], alwayslink = 1, ) @@ -504,30 +498,28 @@ cc_library( ], deps = [ ":tf_to_tfrt", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/strings", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:Pass", - "@llvm-project//mlir:FuncDialect", "//tensorflow/compiler/mlir/tensorflow", - "//tensorflow/compiler/mlir/tensorflow:import_model", "//tensorflow/compiler/mlir/tensorflow:convert_type", "//tensorflow/compiler/mlir/tensorflow:dump_mlir_util", 
"//tensorflow/compiler/mlir/tensorflow:error_util", + "//tensorflow/compiler/mlir/tensorflow:import_model", "//tensorflow/compiler/mlir/tensorflow:tensorflow_passes", "//tensorflow/compiler/mlir/tensorflow:tf_dialect_passes", "//tensorflow/compiler/mlir/tensorflow:translate_lib", "//tensorflow/core:framework", "//tensorflow/core:protos_all_cc", "//tensorflow/core/platform:status", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/strings", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", "@tf_runtime//:bef", "@tf_runtime//:core_runtime", "@tf_runtime//:hostcontext", "@tf_runtime//:mlirtobef", "@tf_runtime//:tensor", - ] + if_google([ - "//learning/brain/tfrt/tpu/compiler/mlir:tf_to_tfrt_tpu", - ]), + ], ) cc_library( @@ -539,6 +531,7 @@ cc_library( "translate/import_model.h", ], visibility = [ + # copybara:uncomment "//learning/brain/experimental/tfrt/mlrt/application/tensorflow/compiler/transforms:__pkg__", # copybara:uncomment "//learning/brain/experimental/tfrt/visualization:__pkg__", "//tensorflow/compiler/mlir/tfrt/tests/saved_model:__pkg__", "//tensorflow/core/tfrt/eager:__pkg__", @@ -549,33 +542,35 @@ cc_library( ":function", ":tf_to_tfrt", ":tfrt_compile_options", - "@com_google_absl//absl/strings", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:Pass", - "@llvm-project//mlir:Support", + ":tfrt_pipeline_options", "//tensorflow/compiler/mlir/tensorflow", - "//tensorflow/compiler/mlir/tensorflow:import_model", - "@llvm-project//mlir:FuncDialect", "//tensorflow/compiler/mlir/tensorflow:dump_mlir_util", "//tensorflow/compiler/mlir/tensorflow:error_util", + "//tensorflow/compiler/mlir/tensorflow:import_model", "//tensorflow/compiler/mlir/tensorflow:tensorflow_passes", "//tensorflow/core:framework", "//tensorflow/core/common_runtime:function_body", "//tensorflow/core/common_runtime:function_def_utils", - "//tensorflow/core/tfrt/fallback:fallback_state", "//tensorflow/core/platform:status", + "//tensorflow/core/tfrt/fallback:fallback_state", + "//tensorflow/tsl/platform:errors", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", "@tf_runtime//:bef", "@tf_runtime//:mlirtobef", - ] + if_google([ - "//learning/brain/tfrt/tpu/compiler/mlir:tf_to_tfrt_tpu", - ]), + ], ) cc_library( name = "tfrt_compile_options", srcs = ["translate/tfrt_compile_options.cc"], hdrs = ["translate/tfrt_compile_options.h"], - deps = ["@com_google_absl//absl/strings"], + compatible_with = get_compatible_with_portable(), + deps = [ + "//tensorflow/core/protobuf:for_core_protos_cc", + "@com_google_absl//absl/strings", + ], ) cc_library( @@ -583,6 +578,7 @@ cc_library( srcs = ["analysis/cost_analysis.cc"], hdrs = ["analysis/cost_analysis.h"], deps = [ + ":constants", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/core/platform:status", "//tensorflow/core/tfrt/fallback:cost_recorder", @@ -639,12 +635,10 @@ cc_library( ":__subpackages__", ], deps = [ + "//tensorflow/compiler/mlir/tfrt:tf_to_tfrt", "//tensorflow/compiler/mlir/tfrt/jit/transforms:tf_jitrt_passes", "//tensorflow/compiler/mlir/tfrt/jit/transforms:tf_jitrt_test_passes", - "//tensorflow/compiler/mlir/tfrt:tf_to_tfrt", - ] + if_google([ - "//learning/brain/tfrt/tpu/compiler/mlir:tf_to_tfrt_tpu", - ]), + ], ) cc_library( @@ -659,8 +653,8 @@ cc_library( ], ) -tf_cc_binary( - name = "tf-tfrt-opt", +cc_library( + name = "tf_tfrt_opt_lib", testonly = True, srcs = ["tf-tfrt-opt.cc"], deps = [ @@ -669,6 +663,7 @@ tf_cc_binary( 
":test_tensor_array_side_effect_analysis", ":tf_jitrt_opdefs", ":tf_to_tfrt", + ":tfrt_jitrt_passes", "//tensorflow/compiler/mlir:init_mlir", "//tensorflow/compiler/mlir:passes", "//tensorflow/compiler/mlir/lite:tensorflow_lite", @@ -694,6 +689,12 @@ tf_cc_binary( ], ) +tf_cc_binary( + name = "tf-tfrt-opt", + testonly = True, + deps = [":tf_tfrt_opt_lib"], +) + tf_cc_binary( name = "lhlo-tfrt-opt", srcs = ["lhlo-tfrt-opt.cc"], @@ -778,11 +779,11 @@ tf_cc_binary( ], visibility = [":friends"], deps = [ - "@llvm-project//mlir:TranslateLib", - "//tensorflow/compiler/mlir/tfrt/ir:tfrt_fallback_opdefs", - "//tensorflow/compiler/mlir/tfrt:tf_jitrt_registration", "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tfrt:tf_jitrt_registration", "//tensorflow/compiler/mlir/tfrt:tfrt_fallback_registration", + "//tensorflow/compiler/mlir/tfrt/ir:tfrt_fallback_opdefs", + "@llvm-project//mlir:TranslateLib", "@tf_runtime//:init_tfrt_dialects", "@tf_runtime//:mlirtobef_translate", ] + if_google( @@ -839,3 +840,55 @@ cc_library( "@llvm-project//mlir:IR", ], ) + +cc_library( + name = "tpu_passes", + hdrs = ["transforms/tpu_passes.h"], + deps = [ + ":fallback_converter", + ":tfrt_compile_options", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", + ], +) + +cc_library( + name = "tfrt_jitrt_passes", + srcs = ["transforms/tfrt_jitrt_passes.cc"], + deps = [ + ":fallback_converter", + ":tf_jitrt_opdefs", + ":tf_jitrt_pipeline", + ":tfrt_jitrt_stub", + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_passes", + "//tensorflow/compiler/mlir/tfrt/ir:tfrt_fallback_async_opdefs", + "//tensorflow/compiler/mlir/tfrt/jit/transforms:tf_jitrt_clustering", + "//tensorflow/compiler/mlir/tfrt/jit/transforms:tf_jitrt_passes", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:TransformUtils", + "@tf_runtime//:basic_kernels_opdefs", + "@tf_runtime//backends/jitrt:jitrt_opdefs", + ], + alwayslink = 1, +) + +cc_library( + name = "tfrt_jitrt_stub", + srcs = ["transforms/tfrt_jitrt_stub.cc"], + hdrs = ["transforms/tfrt_jitrt_stub.h"], + deps = [ + ":corert_converter", + ":tfrt_pipeline_options", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:TransformUtils", + ], +) + +cc_library( + name = "constants", + hdrs = ["constants.h"], +) diff --git a/tensorflow/compiler/mlir/tfrt/analysis/cost_analysis.cc b/tensorflow/compiler/mlir/tfrt/analysis/cost_analysis.cc index 9426580bf13..c7d02332839 100644 --- a/tensorflow/compiler/mlir/tfrt/analysis/cost_analysis.cc +++ b/tensorflow/compiler/mlir/tfrt/analysis/cost_analysis.cc @@ -18,6 +18,7 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tfrt/constants.h" #include "tensorflow/core/tfrt/fallback/cost_recorder.h" namespace tensorflow { @@ -141,6 +142,7 @@ void CostAnalysis::AnalyzeArguments(mlir::func::FuncOp func_op) { // Use the max size among function inputs as the default size of dynamic // shaped tensors in the function. for (auto arg : func_op.getArguments()) { + if (!arg.getType().isa()) continue; auto type = arg.getType().cast(); if (type.hasRank()) { max_arg_size_ = std::max(max_arg_size_, GetRankedTensorSize(type)); @@ -160,15 +162,6 @@ void CostAnalysis::EvaluateCost(mlir::Operation* op) { return; } - // These ops are cheap regardless of their input sizes. 
- // - // TODO(chky): Find a more scalable way to figure out cheap ops. - if (llvm::isa(op)) { - cost_map_[op] = kDefaultCheapCost; - return; - } - // Try to use its cost function if it is registered. const auto& registry = GetCostFunctionRegistry(); absl::string_view op_name = op->getName().getStringRef(); @@ -180,6 +173,25 @@ void CostAnalysis::EvaluateCost(mlir::Operation* op) { return; } + // Try to use the recorded cost if any. + if (cost_recorder_ != nullptr) { + const auto op_key_attr = + op->getAttrOfType(kOpKeyAttrName); + if (op_key_attr) { + cost_map_[op] = cost_recorder_->GetCostNanosecond(op_key_attr.getInt()); + return; + } + } + + // These ops are cheap regardless of their input sizes. + // + // TODO(chky): Find a more scalable way to figure out cheap ops. + if (llvm::isa(op)) { + cost_map_[op] = kDefaultCheapCost; + return; + } + // For other ops, use the sum of input sizes as its cost. int64_t cost = kDefaultCheapCost; for (auto operand : op->getOperands()) { diff --git a/tensorflow/compiler/mlir/tfrt/analysis/cost_analysis.h b/tensorflow/compiler/mlir/tfrt/analysis/cost_analysis.h index 8ed554de919..fa01b38dd64 100644 --- a/tensorflow/compiler/mlir/tfrt/analysis/cost_analysis.h +++ b/tensorflow/compiler/mlir/tfrt/analysis/cost_analysis.h @@ -20,6 +20,7 @@ limitations under the License. #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "tensorflow/core/platform/status.h" +#include "tensorflow/core/tfrt/fallback/cost_recorder.h" #include "tensorflow/core/tfrt/fallback/op_cost_map.pb.h" namespace tensorflow { @@ -36,7 +37,10 @@ namespace tfrt_compiler { // class CostAnalysis { public: - explicit CostAnalysis(mlir::func::FuncOp func_op) { + explicit CostAnalysis( + mlir::func::FuncOp func_op, + const tfrt_stub::CostRecorder* cost_recorder = nullptr) { + cost_recorder_ = cost_recorder; AnalyzeArguments(func_op); AnalyzeBlock(&func_op.front()); } @@ -50,6 +54,7 @@ class CostAnalysis { int64_t max_arg_size_ = 1; llvm::DenseMap cost_map_; + const tfrt_stub::CostRecorder* cost_recorder_; }; struct CostContext { diff --git a/tensorflow/compiler/mlir/tfrt/constants.h b/tensorflow/compiler/mlir/tfrt/constants.h new file mode 100644 index 00000000000..dfbb9ba4898 --- /dev/null +++ b/tensorflow/compiler/mlir/tfrt/constants.h @@ -0,0 +1,34 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TFRT_CONSTANTS_H_ +#define TENSORFLOW_COMPILER_MLIR_TFRT_CONSTANTS_H_ + +namespace tensorflow { +namespace tfrt_compiler { + +// Use __ prefix to indicate this is internal attribute. 
+inline constexpr char kOpKeyAttrName[] = "__op_key"; + +} // namespace tfrt_compiler + +namespace mlrt_compiler { + +inline constexpr char kArgPassByValue[] = "mlrt.__pass_by_value"; + +} // namespace mlrt_compiler +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TFRT_CONSTANTS_H_ diff --git a/tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback_sync.td b/tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback_sync.td index d61d3235e0f..daf76268bc2 100644 --- a/tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback_sync.td +++ b/tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback_sync.td @@ -60,6 +60,26 @@ def SetResourceOp : FallbackSync_Op<"set_resource", [CoreRT_TypedAttributeTrait] let assemblyFormat = "operands attr-dict"; } +def SetResourceDhtOp : FallbackSync_Op<"set_resource_dht", [CoreRT_TypedAttributeTrait]> { + let summary = "Set a DHT in resource array"; + + let description = [{ + Set a DHT in resource array. + + arg: the tensor to be set in the resource array. + index: the index in the resource array + }]; + + let arguments = (ins + TensorType:$arg, + I64Attr:$index + ); + + let results = (outs); + + let assemblyFormat = "operands attr-dict"; +} + def GetResourceOp : FallbackSync_Op<"get_resource", [CoreRT_TypedAttributeTrait]> { let summary = "get a tensor in resource array"; @@ -82,6 +102,28 @@ def GetResourceOp : FallbackSync_Op<"get_resource", let assemblyFormat = "attr-dict `:` type($results)"; } +def GetResourceDhtOp : FallbackSync_Op<"get_resource_dht", + [CoreRT_TypedAttributeTrait]> { + let summary = "get a DHT in resource array"; + + let description = [{ + Get a tensor in resource array. + + indices: the indices in the resource array. + results: the tensor values for the corresponding indices. + }]; + + let arguments = (ins + I64ArrayAttr:$indices + ); + + let results = (outs + Variadic:$results + ); + + let assemblyFormat = "attr-dict `:` type($results)"; +} + def CreateOp: FallbackSync_Op<"createop", []> { let summary = "The Fallback CreateOp"; diff --git a/tensorflow/compiler/mlir/tfrt/jit/tf_jitrt_kernels.cc b/tensorflow/compiler/mlir/tfrt/jit/tf_jitrt_kernels.cc index a7cf4379ccb..6fe3091fed3 100644 --- a/tensorflow/compiler/mlir/tfrt/jit/tf_jitrt_kernels.cc +++ b/tensorflow/compiler/mlir/tfrt/jit/tf_jitrt_kernels.cc @@ -651,7 +651,7 @@ struct DebugListener : public SpecializationListener { std::string message; llvm::raw_string_ostream os(message); os << "Specialized operands:\n"; - for (auto& tuple : llvm::enumerate(llvm::zip(operands, attrs))) { + for (const auto& tuple : llvm::enumerate(llvm::zip(operands, attrs))) { mlir::Type type = std::get<0>(tuple.value()); mlir::Attribute attr = std::get<1>(tuple.value()); os << "%arg" << tuple.index() << ": " << type << " " << attr << "\n"; diff --git a/tensorflow/compiler/mlir/tfrt/jit/tf_jitrt_pipeline.cc b/tensorflow/compiler/mlir/tfrt/jit/tf_jitrt_pipeline.cc index 946cc1c4bb6..327cfa45b9a 100644 --- a/tensorflow/compiler/mlir/tfrt/jit/tf_jitrt_pipeline.cc +++ b/tensorflow/compiler/mlir/tfrt/jit/tf_jitrt_pipeline.cc @@ -98,7 +98,7 @@ void CreateTfJitRtPipeline(OpPassManager& pm, pm.addPass(mlir::TF::CreateTFFunctionalControlFlowToRegions()); // Transform TF operation to HLO. - pm.addNestedPass(mlir::mhlo::createLegalizeTFPass()); + pm.addPass(mlir::mhlo::createLegalizeTFPass()); if (options.legalize_i1_tensors) { // Convert 'i1' tensors into 'i8' tensors. @@ -130,7 +130,7 @@ void CreateTfJitRtPipeline(OpPassManager& pm, // Transform HLO operations to Linalg and Standard. 
pm.addNestedPass(mlir::mhlo::createLegalizeControlFlowPass()); pm.addNestedPass(mlir::mhlo::createLegalizeSortPass()); - pm.addNestedPass(xla::cpu::createLegalizeCollectiveOpsPass()); + pm.addNestedPass(xla::cpu::createLegalizeLibraryOpsPass()); if (options.vectorize) { pm.addNestedPass(mlir::mhlo::createLegalizeMHLOToTHLOPass()); @@ -170,6 +170,7 @@ void CreateTfJitRtPipeline(OpPassManager& pm, mlir::gml_st::getDefaultCPUPipelineOptions(llvm::sys::getHostCPUName()); gml_st_opts.matmulTileSizes = options.matmul_tile_sizes; gml_st_opts.lowerToMmt4d = options.lower_to_mmt4d; + gml_st_opts.reductionEnableHeuristic = true; mlir::gml_st::addCPUTilingPipeline(pm, gml_st_opts); } else { pm.addNestedPass(CreateFusionPass()); diff --git a/tensorflow/compiler/mlir/tfrt/jit/tf_jitrt_test.cc b/tensorflow/compiler/mlir/tfrt/jit/tf_jitrt_test.cc index 2418798f19d..c0bb497078e 100644 --- a/tensorflow/compiler/mlir/tfrt/jit/tf_jitrt_test.cc +++ b/tensorflow/compiler/mlir/tfrt/jit/tf_jitrt_test.cc @@ -29,7 +29,6 @@ namespace tensorflow { using ::tfrt::AsyncValue; -using ::tfrt::DType; using ::tfrt::RCReference; using ::tfrt::RemainingResults; diff --git a/tensorflow/compiler/mlir/tfrt/jit/transforms/BUILD b/tensorflow/compiler/mlir/tfrt/jit/transforms/BUILD index c185cb0f9c3..b1ce160fa74 100644 --- a/tensorflow/compiler/mlir/tfrt/jit/transforms/BUILD +++ b/tensorflow/compiler/mlir/tfrt/jit/transforms/BUILD @@ -63,7 +63,7 @@ cc_library( "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops", "//tensorflow/compiler/mlir/tensorflow:tensorflow_passes", "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", - "//tensorflow/compiler/mlir/tf2xla:xla_legalize_tf", + "//tensorflow/compiler/mlir/tf2xla/transforms:xla_legalize_tf", "//tensorflow/compiler/xla/mlir_hlo", "//tensorflow/compiler/xla/mlir_hlo:gml_st", "//tensorflow/compiler/xla/mlir_hlo:gml_st_passes", diff --git a/tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_clustering.cc b/tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_clustering.cc index 6fe59cf34e7..4ae62f262a5 100644 --- a/tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_clustering.cc +++ b/tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_clustering.cc @@ -852,7 +852,7 @@ void populateTfJitRtConstraintsPolicies(ClusteringPolicySet& policies, mlir::LogicalResult IsCompilableConstant(mlir::ElementsAttr value) { return success(value.getNumElements() <= 16 && - value.getType().getElementType().isIntOrIndexOrFloat()); + value.getShapedType().getElementType().isIntOrIndexOrFloat()); } static bool IsI1Integer(Type type) { diff --git a/tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_fusion.cc b/tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_fusion.cc index 0e3a24ee5c1..65456ca8c0f 100644 --- a/tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_fusion.cc +++ b/tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_fusion.cc @@ -39,9 +39,9 @@ using mlir::AffineMap; using mlir::MLIRContext; using mlir::Operation; using mlir::OpOperand; -using mlir::OpResult; using mlir::RewritePatternSet; +namespace affine = mlir::affine; namespace linalg = mlir::linalg; namespace tensor = mlir::tensor; @@ -140,7 +140,7 @@ struct FusionPass : public impl::FusionBase { linalg::populateConstantFoldLinalgOperations(patterns, ControlElementwiseOpsFusion); - mlir::AffineApplyOp::getCanonicalizationPatterns(patterns, context); + affine::AffineApplyOp::getCanonicalizationPatterns(patterns, context); linalg::GenericOp::getCanonicalizationPatterns(patterns, context); 
tensor::ExpandShapeOp::getCanonicalizationPatterns(patterns, context); tensor::CollapseShapeOp::getCanonicalizationPatterns(patterns, context); diff --git a/tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_passes.cc b/tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_passes.cc index f81159c9699..64237f0bb08 100644 --- a/tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_passes.cc +++ b/tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_passes.cc @@ -19,8 +19,6 @@ limitations under the License. namespace tensorflow { -using ::mlir::Operation; - bool IsContiguousMemref(mlir::Value value) { auto memref_type = value.getType().dyn_cast(); if (!memref_type) return false; diff --git a/tensorflow/compiler/mlir/tfrt/python_tests/regression_tests/build_defs.bzl b/tensorflow/compiler/mlir/tfrt/python_tests/regression_tests/build_defs.bzl index d6327bf24a5..fb183e02362 100644 --- a/tensorflow/compiler/mlir/tfrt/python_tests/regression_tests/build_defs.bzl +++ b/tensorflow/compiler/mlir/tfrt/python_tests/regression_tests/build_defs.bzl @@ -32,7 +32,10 @@ def _run_regression_test(name, compare_with_tensorflow, vectorize, data): "//tensorflow/compiler/mlir/tfrt/jit/python_binding:tf_jitrt", "//tensorflow/compiler/mlir/tfrt/jit/python_binding:tfrt_fallback", "//tensorflow/python:client_testlib", - "//tensorflow/python/platform", + "//tensorflow/python/platform:tf_logging", + "//tensorflow/python/platform:client_testlib", + "//tensorflow/python/platform:resource_loader", + "//tensorflow/python/platform:gfile", ], ) diff --git a/tensorflow/compiler/mlir/tfrt/runtime_fallback/runtime_fallback_executor.cc b/tensorflow/compiler/mlir/tfrt/runtime_fallback/runtime_fallback_executor.cc index 62ec862a393..0f16091d799 100644 --- a/tensorflow/compiler/mlir/tfrt/runtime_fallback/runtime_fallback_executor.cc +++ b/tensorflow/compiler/mlir/tfrt/runtime_fallback/runtime_fallback_executor.cc @@ -33,13 +33,14 @@ limitations under the License. #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/platform/threadpool.h" #include "tensorflow/core/platform/threadpool_interface.h" -#include "tensorflow/core/runtime_fallback/kernel/kernel_fallback_execute_compat.h" +#include "tensorflow/core/runtime_fallback/kernel/kernel_fallback_execute_compat_eager.h" #include "tensorflow/core/runtime_fallback/runtime/kernel_utils.h" #include "tensorflow/core/tfrt/utils/fallback_tensor.h" #include "tfrt/bef/bef_buffer.h" // from @tf_runtime #include "tfrt/bef_converter/mlir_to_bef.h" // from @tf_runtime #include "tfrt/bef_executor/bef_file.h" // from @tf_runtime #include "tfrt/host_context/async_value.h" // from @tf_runtime +#include "tfrt/host_context/chain.h" // from @tf_runtime #include "tfrt/host_context/execution_context.h" // from @tf_runtime #include "tfrt/host_context/function.h" // from @tf_runtime #include "tfrt/host_context/host_context.h" // from @tf_runtime @@ -52,12 +53,9 @@ using ::tfrt::AsyncValue; using ::tfrt::BEFFile; using ::tfrt::ExecutionContext; using ::tfrt::Function; -using ::tfrt::HostContext; using ::tfrt::MakeAvailableAsyncValueRef; using ::tfrt::RCReference; -using ::tfrt::RequestContext; using ::tfrt::RequestContextBuilder; -using ::tfrt::ResourceContext; using ::tensorflow::Env; using ::tensorflow::thread::ThreadPool; @@ -112,8 +110,7 @@ RuntimeFallbackExecutor::RuntimeFallbackExecutor(int64_t num_threads) // Initialize fallback kernels state with a custom intra-op thread pool. 
auto status = tensorflow::tfd::SetUpKernelFallbackCompatRequestContext( &builder, /*runner_table=*/nullptr, eager_context, intra_op_.get()); - CHECK(status.ok()) << "Failed to setup request context: " - << status.error_message(); + CHECK(status.ok()) << "Failed to setup request context: " << status.message(); auto req_ctx = std::move(builder).build(); if (auto err = req_ctx.takeError()) diff --git a/tensorflow/compiler/mlir/tfrt/tests/hoist_invariant_ops.mlir b/tensorflow/compiler/mlir/tfrt/tests/hoist_invariant_ops.mlir index 1cafb216743..ceda5aecef7 100644 --- a/tensorflow/compiler/mlir/tfrt/tests/hoist_invariant_ops.mlir +++ b/tensorflow/compiler/mlir/tfrt/tests/hoist_invariant_ops.mlir @@ -244,6 +244,34 @@ func.func private @some_func(%arg: tensor) -> tensor { module attributes {tf_saved_model.semantics} { +// Test not hoisting callees in xla launch functions. + +// CHECK-LABEL: func private @xla_func +func.func private @xla_func(%arg0: tensor<1x3xf32>) -> tensor<1x3xf32> + attributes {tf._input_shapes = [#tf_type.shape<1x3>, #tf_type.shape<*>], tf.signature.is_stateful} { + // CHECK-NOT: tf._TfrtGetResource + %0 = "tf.VarHandleOp"() {device = "/device:CPU:0", container = "", shared_name = "variable"} : () -> tensor>> + %1 = "tf.ReadVariableOp"(%0) {device = "/device:CPU:0"} : (tensor>>) -> tensor<1x3xf32> + %2 = "tf.AddV2"(%arg0, %1) {device = "/device:CPU:0"} : (tensor<1x3xf32>, tensor<1x3xf32>) -> tensor<1x3xf32> + %3 = "tf.Identity"(%2) {device = "/device:CPU:0"} : (tensor<1x3xf32>) -> tensor<1x3xf32> + func.return %3 : tensor<1x3xf32> +} + +// CHECK-LABEL: func @main +func.func @main(%arg0: tensor<1x3xf32> {tf_saved_model.index_path = ["input"]}) -> (tensor<*xf32> {tf_saved_model.index_path = ["r"]}) + attributes {tf_saved_model.exported_names = ["main"]} { + %0 = "tf.VarHandleOp"() {device = "/device:CPU:0", container = "", shared_name = "variable"} : () -> tensor>> + %1 = "tf.XlaLaunch"(%arg0, %0) {device = "/device:GPU:0", function = @xla_func, operand_segment_sizes = array} : (tensor<1x3xf32>, tensor>>) -> tensor<*xf32> + func.return %1 : tensor<*xf32> + +} + +} + +// ----- + +module attributes {tf_saved_model.semantics} { + // Test not hoisting in TPU functions. // CHECK-LABEL: func @_tfrt_resource_init @@ -260,4 +288,4 @@ func.func private @func2(%arg: tensor) -> tensor { func.return %r : tensor } -} +} \ No newline at end of file diff --git a/tensorflow/compiler/mlir/tfrt/tests/hoist_invariant_ops_mlrt.mlir b/tensorflow/compiler/mlir/tfrt/tests/hoist_invariant_ops_mlrt.mlir new file mode 100644 index 00000000000..7b797b357a1 --- /dev/null +++ b/tensorflow/compiler/mlir/tfrt/tests/hoist_invariant_ops_mlrt.mlir @@ -0,0 +1,31 @@ +// RUN: tf-tfrt-opt -split-input-file -tfrt-lower-tf-savedmodel="hoist-invariant-ops=true fuse-get-resource-ops=false" %s | FileCheck %s --dump-input=fail --dump-input-filter=all + +module attributes {tf_saved_model.semantics} { + +// Test hoisting hash table op. 
+ +// CHECK-LABEL: func @_tfrt_resource_init +// CHECK: [[handle:%.*]] = "tf.HashTableV2"() +// CHECK-SAME: shared_name = "x" +// CHECK: "tf._TfrtSetResource"([[handle]]) {device = "/job:localhost/replica:0/task:0/device:CPU:0", index = [[handle_id:.*]] : i64} +// CHECK: [[x:%.*]] = "tf.LookupTableSizeV2"([[handle]]) +// CHECK: "tf._TfrtSetResource"([[x]]) {device = "/job:localhost/replica:0/task:0/device:CPU:0", index = [[size_id:.*]] : i64} : (tensor) -> () + +// CHECK: func @test_hoist_hash_table +func.func @hoist_hash_table(%arg: tensor {tf_saved_model.index_path = ["input"]}, %default: tensor {tf_saved_model.index_path = ["default"]}) -> (tensor {tf_saved_model.index_path = ["r"]}, tensor<*xi64> {tf_saved_model.index_path = ["r1"]}) + attributes {tf_saved_model.exported_names = ["test_hoist_hash_table"]} { + // CHECK-NOT: tf.HashTableV2 + // CHECK-NOT: tf.LookupTableSizeV2 + // CHECK-DAG: [[v0:%.*]] = "tf._TfrtGetResource"() {container = [""], device = "/job:localhost/replica:0/task:0/device:CPU:0", indices = [[[handle_id]]], shared_name = [{{.*}}]} + // CHECK-DAG: [[v1:%.*]] = "tf._TfrtGetResource"() {container = [""], device = "/job:localhost/replica:0/task:0/device:CPU:0", indices = [[[size_id]]], shared_name = [{{.*}}]} + // CHECK-DAG: [[r:%.*]] = "tf.LookupTableFindV2"([[v0]] + // CHECK-DAG: return [[v1]], [[r]] + %0 = "tf.HashTableV2"() {container = "", device = "", key_dtype = !tf_type.string, shared_name = "x", use_node_name_sharing = false, value_dtype = i64} : () -> tensor + %1 = "tf.LookupTableSizeV2"(%0) {device = ""} : (tensor) -> tensor + %2 = "tf.LookupTableFindV2"(%0, %arg, %default) {device = "/CPU:0"} : (tensor, tensor, tensor) -> tensor<*xi64> + func.return %1, %2 : tensor, tensor<*xi64> +} + +} + +// ----- diff --git a/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/const_tensor.mlir b/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/const_tensor.mlir index 7e8655ce4e0..b208fe390ac 100644 --- a/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/const_tensor.mlir +++ b/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/const_tensor.mlir @@ -16,7 +16,7 @@ func.func @dense_tensor() -> tensor<4xui64> { %0 = "tf.Const"() {value = dense<[1, 2, 3, 4]> : tensor<4xui64>} : () -> tensor<4xui64> // CHECK: corert.const_dense_tensor dense<1.000000e+00> : tensor<1xbf16> %1 = "tf.Const"() {device = "/device:CPU:0", value = dense<[1.0]> : tensor<1xbf16>} : () -> tensor<4xbf16> - // CHECK-NOT: corert.executeop + // CHECK: corert.executeop({{.*}}) "tf.Const"() {dtype = ui64, value = dense<[1, 2, 3, 4]> : tensor<4xui64>} : 1 %2 = "tf.Const"() {device = "/device:GPU:0", value = dense<[1, 2, 3, 4]> : tensor<4xui64>} : () -> tensor<4xui64> func.return %0 : tensor<4xui64> } diff --git a/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/device_conversion.mlir b/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/device_conversion.mlir index 149fee8f244..4c5777c28e2 100644 --- a/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/device_conversion.mlir +++ b/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/device_conversion.mlir @@ -5,7 +5,7 @@ func.func @device_test( %arg0: tensor<3x1xf32> {tf_saved_model.index_path = [0]}, %arg1: tensor<1x3xf32> {tf_saved_model.index_path = [0]}) -> (tensor<3x3xf32> {tf_saved_model.index_path = []}) { - // CHECK: device("/device:GPU:0") + // CHECK: {{%.*}} = corert.get_op_handler %arg0 "/device:GPU:0" %2 = "tf.MatMul"(%arg0, %arg1) {T = f32, _output_shapes = ["tfshape$dim { size: 3 } dim { size: 3 }"], device = "/device:GPU:0", transpose_a = false, transpose_b = 
false} : (tensor<3x1xf32>, tensor<1x3xf32>) -> tensor<3x3xf32> func.return %2 : tensor<3x3xf32> } diff --git a/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/fallback.mlir b/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/fallback.mlir index 0e605ccc6af..8f59a1a42d7 100644 --- a/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/fallback.mlir +++ b/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/fallback.mlir @@ -62,9 +62,9 @@ func.func @no_native(%arg0: tensor<3x1xf32>, %arg1: tensor, %arg1: tensor>>) -> tensor<3x3xf32> { - // CHECK-NOT: corert.executeop - // CHECK: tfrt_fallback_async.executeop.seq({{.*}}) key({{.*}}) cost({{.*}}) device("/device:GPU:0") "tf.ReadVariableOp" - // CHECK: tfrt_fallback_async.executeop key({{.*}}) cost({{.*}}) device("/device:GPU:0") "tf.MatMul" + // CHECK: {{%.*}} = corert.get_op_handler %arg0 "/device:GPU:0" + // CHECK: {{.*}} = corert.executeop.seq({{.*}}) "tf.ReadVariableOp"({{.*}}) {dtype = f32} : 1 + // CHECK: {{.*}} = corert.executeop({{.*}}) "tf.MatMul"({{.*}}) {T = f32, transpose_a = false, transpose_b = false} : 1 %0 = "tf.ReadVariableOp"(%arg1) {device = "/device:GPU:0", dtype = f32} : (tensor>>) -> tensor<1x3xf32> %1 = "tf.MatMul"(%arg0, %0) {T = f32, device = "/device:GPU:0", transpose_a = false, transpose_b = false} : (tensor<3x1xf32>, tensor<1x3xf32>) -> tensor<3x3xf32> func.return %1 : tensor<3x3xf32> @@ -117,12 +117,3 @@ func.func @tensor_array() -> (tensor<1x1x512xf32>) { %result = "tf.TensorArrayGatherV3"(%handle, %indices, %flow_1) {device = "/job:localhost/replica:0/task:0/device:CPU:0", element_shape = #tf_type.shape<1x512>} : (tensor<2x!tf_type.resource>>, tensor<1xi32>, tensor) -> tensor<1x1x512xf32> func.return %result : tensor<1x1x512xf32> } - -// CHECK-LABEL: func @gpu_device_cost -func.func @gpu_device_cost(%arg0: tensor<3x1xf32>, %arg1: tensor>>) -> tensor<3x3xf32> { - // CHECK: tfrt_fallback_async.executeop.seq({{.*}}) key({{.*}}) cost({{1}}) device({{.*}}) "tf.ReadVariableOp" - // CHECK: tfrt_fallback_async.executeop key({{.*}}) cost({{1}}) device({{.*}}) "tf.MatMul" - %0 = "tf.ReadVariableOp"(%arg1) {device = "/job:localhost/replica:0/task:0/device:GPU:0", dtype = f32} : (tensor>>) -> tensor<1x3xf32> - %1 = "tf.MatMul"(%arg0, %0) {T = f32, device = "/job:localhost/replica:0/task:0/device:GPU:0", transpose_a = false, transpose_b = false} : (tensor<3x1xf32>, tensor<1x3xf32>) -> tensor<3x3xf32> - func.return %1 : tensor<3x3xf32> -} diff --git a/tensorflow/compiler/mlir/tfrt/transforms/lower_saved_model.cc b/tensorflow/compiler/mlir/tfrt/transforms/lower_saved_model.cc index 47c0277670b..0ec42b59fd6 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/lower_saved_model.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/lower_saved_model.cc @@ -82,80 +82,6 @@ struct HoistInfo { hoisted_values; }; -void ReplaceHoistedValues( - llvm::ArrayRef> - hoisted_values, - mlir::OpBuilder &builder) { - struct HoistedValueInfo { - llvm::SmallVector hoisted_values; - llvm::SmallVector indices; - llvm::SmallVector shared_names; - llvm::SmallVector containers; - }; - // Rearrange the hoisted values by each function and each device. - llvm::DenseMap> - hoisted_values_by_block_device; - - // Find a block where to place tf._TfrtGetResource operation. We do not place - // get resource operations inside the `tf_device.cluster` operations, because - // these blocks are intended for later on-device compilation. Insert resource - // reads to the closest block outside of the `tf_device.cluster` operation. 
- auto hoist_into_block = [](mlir::Value value) -> mlir::Block * { - mlir::Operation *cluster_op = - value.getDefiningOp()->getParentOfType(); - return cluster_op ? cluster_op->getBlock() : value.getParentBlock(); - }; - - for (auto iter : llvm::enumerate(hoisted_values)) { - auto value = iter.value().first; - auto index = iter.index(); - auto &device_map = hoisted_values_by_block_device[hoist_into_block(value)]; - - assert(value.getDefiningOp() && "hoisted values must not be arguments."); - llvm::StringRef device = kCpuDeviceName; - if (auto device_attr = - value.getDefiningOp()->getAttrOfType("device")) { - if (!device_attr.getValue().empty()) device = device_attr.getValue(); - } - - auto &item = device_map[device]; - - item.hoisted_values.push_back(value); - item.indices.push_back(index); - item.shared_names.push_back(iter.value().second.name); - item.containers.push_back(iter.value().second.container); - } - - // Create tf._TfrtGetResource op for each function and device. - for (const auto &block_iter : hoisted_values_by_block_device) { - auto *block = block_iter.first; - const auto &device_map = block_iter.second; - - builder.setInsertionPointToStart(block); - for (const auto &device_iter : device_map) { - llvm::StringRef device = device_iter.getKey(); - mlir::ValueRange old_values = device_iter.getValue().hoisted_values; - const auto &indices = device_iter.getValue().indices; - const auto &shared_name_arr = device_iter.getValue().shared_names; - const auto &container_arr = device_iter.getValue().containers; - - auto get_resource_op = builder.create( - block->getParentOp()->getLoc(), old_values.getTypes(), - builder.getI64ArrayAttr(indices), - builder.getStrArrayAttr(shared_name_arr), - builder.getStrArrayAttr(container_arr)); - get_resource_op->setAttr("device", builder.getStringAttr(device)); - - auto new_values = get_resource_op.getResults(); - for (auto iter : llvm::zip(old_values, new_values)) { - auto old_value = std::get<0>(iter); - auto new_value = std::get<1>(iter); - old_value.replaceAllUsesWith(new_value); - } - } - } -} - bool OnlyHasReadOrNoEffect(mlir::Operation *op) { auto interface = llvm::dyn_cast(op); if (!interface) return false; @@ -275,136 +201,35 @@ void HoistInvariantOpsInFunction( } } +void FindCalleesRecursiveForOp(const mlir::SymbolTable &symbol_table, + mlir::Operation *op, + llvm::StringSet<> &callees) { + for (const auto &named_attr : op->getAttrs()) { + if (auto symbol_attr = + named_attr.getValue().dyn_cast()) { + auto symbol = symbol_attr.getValue(); + if (!callees.contains(symbol)) { + callees.insert(symbol); + + auto func = symbol_table.lookup(symbol); + if (!func) continue; + + func.walk([&](mlir::Operation *op) { + FindCalleesRecursiveForOp(symbol_table, op, callees); + }); + } + } + } +} + void FindCalleesRecursive(const mlir::SymbolTable &symbol_table, mlir::func::FuncOp func, llvm::StringSet<> &callees) { assert(func); func.walk([&](mlir::Operation *op) { - for (const auto &named_attr : op->getAttrs()) { - if (auto symbol_attr = - named_attr.getValue().dyn_cast()) { - auto symbol = symbol_attr.getValue(); - if (!callees.contains(symbol)) { - callees.insert(symbol); - - auto func = symbol_table.lookup(symbol); - if (!func) continue; - - FindCalleesRecursive(symbol_table, func, callees); - } - } - } + FindCalleesRecursiveForOp(symbol_table, op, callees); }); } -void HoistInvariantOps(mlir::ModuleOp module) { - mlir::SymbolTable symbol_table(module); - - // Find all resources used in non-init functions. 
- llvm::DenseMap> - resources; - - // Find all callees referenced in the initialization functions. - llvm::StringSet<> init_callees; - - module.walk([&](mlir::Operation *op) { - if (llvm::isa(op)) { - auto func = op->getParentOfType(); - if (IsSessionInitializer(func)) return; - resources[GetResourceHandle(op)].push_back(op); - } else if (auto func = llvm::dyn_cast(op)) { - if (!IsSessionInitializer(func)) return; - FindCalleesRecursive(symbol_table, func, init_callees); - } - }); - - llvm::DenseSet read_only_vars; - for (const auto &iter : resources) { - const auto &key = iter.first; - const auto &vars = iter.second; - if (std::all_of(vars.begin(), vars.end(), [](mlir::Operation *op) { - for (auto *user : op->getUsers()) { - if (!OnlyHasReadOrNoEffect(user)) return false; - } - return true; - })) { - read_only_vars.insert(key); - } - } - - mlir::TF::SideEffectAnalysis side_effect_analysis(module); - - mlir::OpBuilder builder(&module.getBodyRegion()); - // "_tfrt_resource_init" is the special function that executes all invariant - // ops (eg. read-only variables) used in the model. This function should be - // executed after user-specified initialization. - auto init_func_op = builder.create( - module.getLoc(), "_tfrt_resource_init", - mlir::FunctionType::get(module.getContext(), /*inputs=*/{}, - /*results=*/{})); - auto *block = init_func_op.addEntryBlock(); - builder.setInsertionPointToStart(block); - - HoistInfo module_hoist_info; - - for (auto func : module.getOps()) { - // Skips hoisting if this function is an init function or any callees, - // including recursive ones, of an init functions, because otherwise the - // hoisted values won't be initialized when this function is called. - if (IsSessionInitializer(func) || - init_callees.contains(func.getSymName()) || func == init_func_op) - continue; - - // Skips hoisting if this function runs on TPU. This is will happen when - // fallback to TPUPartitionedCallOp is enabled for SPMD. - // TODO(b/214039254): remove this once tfrt support native SPMD. - bool has_tpu_op = false; - func.walk([&has_tpu_op](mlir::Operation *op) { - if (op->hasAttr("_tpu_replicate")) has_tpu_op = true; - }); - if (has_tpu_op) continue; - - HoistInvariantOpsInFunction(func, read_only_vars, - side_effect_analysis.GetAnalysisForFunc(func), - builder, module_hoist_info); - } - - // Create tf._TfrtSetResource ops in the init function. - for (auto iter : llvm::enumerate(module_hoist_info.hoisted_values)) { - mlir::Value value = iter.value().first; - int64_t index = iter.index(); - - auto new_value = module_hoist_info.value_mapping.lookup(value); - auto *new_op = new_value.getDefiningOp(); - assert(new_op); - builder.setInsertionPointAfter(new_op); - auto set_resource_op = builder.create( - new_op->getLoc(), new_value, index); - - // Preserve the device attribute. - llvm::StringRef device = kCpuDeviceName; - if (auto device_attr = new_op->getAttrOfType("device")) { - if (!device_attr.getValue().empty()) device = device_attr.getValue(); - } - set_resource_op->setAttr("device", builder.getStringAttr(device)); - } - - builder.setInsertionPointToEnd(block); - // Finish building the init function by inserting an return op. - builder.create(init_func_op.getLoc()); - - // Now that we have the index for each value that will be replaced, we can - // create the tf._TfrtGetResource op in each function using these indices. - ReplaceHoistedValues(module_hoist_info.hoisted_values, builder); - - // Lastly, erase the hoisted ops in reverse topological order. 
- for (auto *op : - llvm::reverse(module_hoist_info.hoists_in_topological_order)) { - assert(op->use_empty()); - op->erase(); - } -} - // This pass rewrites tf_saved_model dialect's ops according to TFRT's // requirements: // @@ -416,11 +241,17 @@ void HoistInvariantOps(mlir::ModuleOp module) { class LowerTFSavedModelPass : public mlir::PassWrapper> { + void getDependentDialects(mlir::DialectRegistry ®istry) const override { + registry.insert(); + } + public: MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LowerTFSavedModelPass) - explicit LowerTFSavedModelPass(bool hoist_invariant_ops) { + explicit LowerTFSavedModelPass(bool hoist_invariant_ops, + bool fuse_get_resource_ops) { hoist_invariant_ops_ = hoist_invariant_ops; + fuse_get_resource_ops_ = fuse_get_resource_ops; } LowerTFSavedModelPass() = default; LowerTFSavedModelPass(const LowerTFSavedModelPass &) {} @@ -512,11 +343,230 @@ class LowerTFSavedModelPass } private: + void HoistInvariantOps(mlir::ModuleOp module); + void ReplaceHoistedValues( + llvm::ArrayRef> + hoisted_values, + mlir::OpBuilder &builder); + Option hoist_invariant_ops_{*this, "hoist-invariant-ops", llvm::cl::desc("hoist-invariant-ops"), llvm::cl::init(false)}; + Option fuse_get_resource_ops_{*this, "fuse-get-resource-ops", + llvm::cl::desc("fuse get resource ops"), + llvm::cl::init(true)}; }; +void LowerTFSavedModelPass::HoistInvariantOps(mlir::ModuleOp module) { + mlir::SymbolTable symbol_table(module); + + // Find all resources used in non-init functions. + llvm::DenseMap> + resources; + + // Find all callees referenced in the initialization functions. + llvm::StringSet<> init_callees; + + // Recursively find all callees referenced in the tf.XlaLaunch op. + // At and after the point of calling this pass, the MLIR xla function is no + // longer used. So there is no point to do hoisting for xla functions. + llvm::StringSet<> xla_launch_callees; + + module.walk([&](mlir::Operation *op) { + if (llvm::isa(op)) { + auto func = op->getParentOfType(); + if (IsSessionInitializer(func)) return; + resources[GetResourceHandle(op)].push_back(op); + } else if (auto func = llvm::dyn_cast(op)) { + if (!IsSessionInitializer(func)) return; + FindCalleesRecursive(symbol_table, func, init_callees); + } else if (op->getName().getStringRef().str() == "tf.XlaLaunch") { + // TODO(b/275095412): Clean up MLIR XLA functions after they are written + // back to function library, so that we don't need to do special handling + // for those functions here. + FindCalleesRecursiveForOp(symbol_table, op, xla_launch_callees); + } + }); + + llvm::DenseSet read_only_vars; + for (const auto &iter : resources) { + const auto &key = iter.first; + const auto &vars = iter.second; + if (std::all_of(vars.begin(), vars.end(), [](mlir::Operation *op) { + for (auto *user : op->getUsers()) { + if (!OnlyHasReadOrNoEffect(user)) return false; + } + return true; + })) { + read_only_vars.insert(key); + } + } + + mlir::TF::SideEffectAnalysis side_effect_analysis(module); + + mlir::OpBuilder builder(&module.getBodyRegion()); + // "_tfrt_resource_init" is the special function that executes all invariant + // ops (eg. read-only variables) used in the model. This function should be + // executed after user-specified initialization. 
+  auto init_func_op = builder.create(
+      module.getLoc(), "_tfrt_resource_init",
+      mlir::FunctionType::get(module.getContext(), /*inputs=*/{},
+                              /*results=*/{}));
+  auto *block = init_func_op.addEntryBlock();
+  builder.setInsertionPointToStart(block);
+
+  HoistInfo module_hoist_info;
+
+  for (auto func : module.getOps()) {
+    // Skips hoisting if this function is an init function or any callee,
+    // including recursive ones, of an init function, because otherwise the
+    // hoisted values won't be initialized when this function is called.
+    if (IsSessionInitializer(func) ||
+        init_callees.contains(func.getSymName()) || func == init_func_op ||
+        xla_launch_callees.contains(func.getSymName()))
+      continue;
+
+    // Skips hoisting if this function runs on TPU. This will happen when
+    // fallback to TPUPartitionedCallOp is enabled for SPMD.
+    // TODO(b/214039254): remove this once tfrt supports native SPMD.
+    bool has_tpu_op = false;
+    func.walk([&has_tpu_op](mlir::Operation *op) {
+      if (op->hasAttr("_tpu_replicate")) has_tpu_op = true;
+    });
+    if (has_tpu_op) continue;
+
+    HoistInvariantOpsInFunction(func, read_only_vars,
+                                side_effect_analysis.GetAnalysisForFunc(func),
+                                builder, module_hoist_info);
+  }
+
+  // Create tf._TfrtSetResource ops in the init function.
+  for (auto iter : llvm::enumerate(module_hoist_info.hoisted_values)) {
+    mlir::Value value = iter.value().first;
+    int64_t index = iter.index();
+
+    auto new_value = module_hoist_info.value_mapping.lookup(value);
+    auto *new_op = new_value.getDefiningOp();
+    assert(new_op);
+    builder.setInsertionPointAfter(new_op);
+    auto set_resource_op = builder.create(
+        new_op->getLoc(), new_value, index);
+
+    // Preserve the device attribute.
+    llvm::StringRef device = kCpuDeviceName;
+    if (auto device_attr = new_op->getAttrOfType("device")) {
+      if (!device_attr.getValue().empty()) device = device_attr.getValue();
+    }
+    set_resource_op->setAttr("device", builder.getStringAttr(device));
+  }
+
+  builder.setInsertionPointToEnd(block);
+  // Finish building the init function by inserting a return op.
+  builder.create(init_func_op.getLoc());
+
+  // Now that we have the index for each value that will be replaced, we can
+  // create the tf._TfrtGetResource op in each function using these indices.
+  ReplaceHoistedValues(module_hoist_info.hoisted_values, builder);
+
+  // Lastly, erase the hoisted ops in reverse topological order.
+  for (auto *op :
+       llvm::reverse(module_hoist_info.hoists_in_topological_order)) {
+    assert(op->use_empty());
+    op->erase();
+  }
+}
+
+void LowerTFSavedModelPass::ReplaceHoistedValues(
+    llvm::ArrayRef>
+        hoisted_values,
+    mlir::OpBuilder &builder) {
+  struct HoistedValueInfo {
+    llvm::SmallVector hoisted_values;
+    llvm::SmallVector indices;
+    llvm::SmallVector shared_names;
+    llvm::SmallVector containers;
+  };
+  // Rearrange the hoisted values by each function and each device.
+  llvm::DenseMap>
+      hoisted_values_by_block_device;
+
+  // Find a block where to place tf._TfrtGetResource operation. We do not place
+  // get resource operations inside the `tf_device.cluster` operations, because
+  // these blocks are intended for later on-device compilation. Insert resource
+  // reads to the closest block outside of the `tf_device.cluster` operation.
+  auto hoist_into_block = [](mlir::Value value) -> mlir::Block * {
+    mlir::Operation *cluster_op =
+        value.getDefiningOp()->getParentOfType();
+    return cluster_op ?
cluster_op->getBlock() : value.getParentBlock(); + }; + + for (auto iter : llvm::enumerate(hoisted_values)) { + auto value = iter.value().first; + auto index = iter.index(); + auto &device_map = hoisted_values_by_block_device[hoist_into_block(value)]; + + assert(value.getDefiningOp() && "hoisted values must not be arguments."); + llvm::StringRef device = kCpuDeviceName; + if (auto device_attr = + value.getDefiningOp()->getAttrOfType("device")) { + if (!device_attr.getValue().empty()) device = device_attr.getValue(); + } + + auto &item = device_map[device]; + + item.hoisted_values.push_back(value); + item.indices.push_back(index); + item.shared_names.push_back(iter.value().second.name); + item.containers.push_back(iter.value().second.container); + } + + // Create tf._TfrtGetResource op for each function and device. + for (const auto &block_iter : hoisted_values_by_block_device) { + auto *block = block_iter.first; + const auto &device_map = block_iter.second; + + builder.setInsertionPointToStart(block); + for (const auto &device_iter : device_map) { + llvm::StringRef device = device_iter.getKey(); + mlir::ValueRange old_values = device_iter.getValue().hoisted_values; + const auto &indices = device_iter.getValue().indices; + const auto &shared_name_arr = device_iter.getValue().shared_names; + const auto &container_arr = device_iter.getValue().containers; + + llvm::SmallVector new_values; + + if (fuse_get_resource_ops_) { + auto get_resource_op = builder.create( + block->getParentOp()->getLoc(), old_values.getTypes(), + builder.getI64ArrayAttr(indices), + builder.getStrArrayAttr(shared_name_arr), + builder.getStrArrayAttr(container_arr)); + get_resource_op->setAttr("device", builder.getStringAttr(device)); + new_values = get_resource_op.getResults(); + } else { + for (int i = 0; i < old_values.size(); ++i) { + auto get_resource_op = builder.create( + block->getParentOp()->getLoc(), + mlir::TypeRange(old_values[i].getType()), + builder.getI64ArrayAttr(indices[i]), + builder.getStrArrayAttr(shared_name_arr[i]), + builder.getStrArrayAttr(container_arr[i])); + get_resource_op->setAttr("device", builder.getStringAttr(device)); + new_values.append(get_resource_op->result_begin(), + get_resource_op->result_end()); + } + } + + for (auto iter : llvm::zip(old_values, new_values)) { + auto old_value = std::get<0>(iter); + auto new_value = std::get<1>(iter); + old_value.replaceAllUsesWith(new_value); + } + } + } +} + static llvm::SmallVector CompareTypes(mlir::TypeRange x, mlir::TypeRange y) { llvm::SmallVector results; @@ -672,8 +722,10 @@ void ConvertReferenceVariableToResourceVariablePass::runOnOperation() { } // namespace std::unique_ptr> -CreateLowerTFSavedModelPass(bool hoist_invariant_ops) { - return std::make_unique(hoist_invariant_ops); +CreateLowerTFSavedModelPass(bool hoist_invariant_ops, + bool fuse_get_resource_ops) { + return std::make_unique(hoist_invariant_ops, + fuse_get_resource_ops); } std::unique_ptr> diff --git a/tensorflow/compiler/mlir/tfrt/transforms/passes.cc b/tensorflow/compiler/mlir/tfrt/transforms/passes.cc new file mode 100644 index 00000000000..2eda5bfd0e9 --- /dev/null +++ b/tensorflow/compiler/mlir/tfrt/transforms/passes.cc @@ -0,0 +1,238 @@ + +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/tfrt/transforms/passes.h" + +#include + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Pass/PassOptions.h" +#include "mlir/Transforms/Passes.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.h" +#include "tensorflow/compiler/mlir/tfrt/transforms/set_shape_invariant_in_while_ops.h" +#include "tensorflow/compiler/mlir/tfrt/transforms/tfrt_jitrt_stub.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/util/device_name_utils.h" + +namespace tensorflow { +namespace { + +// Assigns devices so that later passes can utilize device information. +// Device assignment might have not been done by the upstream pipeline, or get +// removed by previous passes. However, we assume most of the device assignment +// has been done by the upstream pipeline, so we simply assign the default +// device to unassigned ops. Specifically, we do assignment for ConstOp first to +// place it on the same device as its user operation, instead of placing it on +// the default device blindly. +// TODO(b/221297389): Figure out a more robust way to handle dropped device +// assignment. +void AddTfDeviceAssignmentPasses(mlir::OpPassManager &pm, + const TfrtPipelineOptions &options) { + pm.addPass(mlir::TF::CreateConstantOpDeviceAssignmentPass()); + pm.addNestedPass( + mlir::TF::CreateTFDeviceAssignmentByFuncAttrPass()); + pm.addNestedPass( + mlir::TF::CreateSimpleTFDeviceAssignmentPass(options.default_device)); +} + +} // namespace + +void CreateTFExecutorToTFPreInvariantOptimizationPipelineHelper( + mlir::OpPassManager &pm, const TfrtPipelineOptions &options) { + // Due to b/191304670, functionalized while ops might not have the + // shape_invariant attribute set correctly, which leads to failure in shape + // inference. As a workaround, we conservatively (e.g., we place less + // restrictions on tf.while which will avoid failures but lead to potentially + // less exact shape inference) set the shape_invariant attribute in all + // tf.While ops before performing shape inference. + // + // Note that this pass might not work well with TF XLA bridge, but this is + // fine as TF XLA bridge is run before this pipeline. For CPU ops, less exact + // shape inference may lead to fewer optimizations but it should be fine as it + // is limited to while ops currently. + // + // TODO(b/191304670): Remove this pass once the shape_invariant attribute is + // set correctly in the upstream. + pm.addNestedPass( + tfrt_compiler::CreateSetShapeInvariantInWhileOps()); + + // We pass the MLIR module through the TF standard pipeline, which for + // instances does shape inference, canonicalization, inlining, etc. 
+  pm.addNestedPass(
+      mlir::tf_executor::CreateTFExecutorGraphPruningPass());
+  pm.addNestedPass(
+      mlir::tf_executor::CreateTFExecutorIslandCoarseningPass());
+
+  AddTfDeviceAssignmentPasses(pm, options);
+
+  pm.addPass(tfrt_compiler::CreateTfrtXlaRewritePass());
+
+  // Here we perform TFRT-specific optimization before standard TF
+  // optimization, as TFRT-specific optimization may create more opportunities.
+  pm.addNestedPass(
+      tfrt_compiler::CreateOptimizeTfForTfrtPass());
+  pm.addNestedPass(mlir::createCanonicalizerPass());
+  // Guarantee all functions have one use, which enables more exact shape
+  // inference.
+  pm.addPass(mlir::TF::CreateGuaranteeAllFuncsOneUsePass());
+  pm.addPass(mlir::TF::CreateTFShapeInferencePass());
+  pm.addPass(mlir::createInlinerPass());
+  pm.addPass(mlir::createSymbolDCEPass());
+  pm.addNestedPass(mlir::TF::CreateTFOptimizePass());
+  pm.addNestedPass(mlir::createCSEPass());
+
+  AddTfDeviceAssignmentPasses(pm, options);
+
+  // After the standard passes, we now have MLIR in the TF dialect, and we now
+  // convert reference variables to resource variables, which is best-effort.
+  pm.addPass(CreateConvertReferenceVariableToResourceVariablePass());
+
+  // Move the tf.Assert op to the end of the function, so that it does not
+  // impose unnecessary control dependencies on other ops.
+  pm.addPass(tfrt_compiler::CreateReorderTfAssertPass());
+
+  // Optimize the side effects of control flow ops by examining the ops in
+  // their callees.
+  pm.addPass(tfrt_compiler::CreateOptimizeTfControlFlowSideEffectPass());
+
+  // Remove tf.If ops' operands that are produced by tf.Const ops.
+  pm.addPass(tfrt_compiler::CreateRemoveTfIfConstArgsPass());
+
+  // Merge non-side-effecting tf.If ops if their operands are the same.
+  pm.addPass(tfrt_compiler::CreateMergeTfIfOpsPass());
+
+  // Deduplicate functions invoked by tf.BatchFunction with the same
+  // shared_name.
+  pm.addPass(
+      tfrt_compiler::CreateDeduplicateFunctionsInovkedByBatchFunctionPass());
+
+  // RemoveUnusedWhileResultsPass operates on the region-based control flow, so
+  // the functional control flow is first converted to region-based control
+  // flow, which is converted back after the optimization passes are performed.
+  pm.addPass(mlir::TF::CreateTFFunctionalControlFlowToRegions());
+  pm.addPass(mlir::createInlinerPass());
+  pm.addNestedPass(
+      mlir::TF::CreateRemoveUnusedWhileResultsPass());
+  pm.addPass(mlir::TF::CreateTFRegionControlFlowToFunctional());
+
+  // Apply standard optimization after optimizing control flow ops.
+  pm.addPass(mlir::createInlinerPass());
+  pm.addNestedPass(mlir::createCSEPass());
+
+  // TODO(b/187876545): An extra shape inference pass is added because it does
+  // not work well with the tf.Identity op that removes the ref type. So we
+  // work around this by performing shape inference again after the reference
+  // variable to resource variable conversion. We should remove this after
+  // b/187876545 is fixed.
+  pm.addPass(mlir::TF::CreateTFShapeInferencePass());
+
+  pm.addNestedPass(
+      mlir::TFDevice::CreateLaunchToDeviceAttributePass());
+
+  // After all standard passes, run layout optimization to assign the optimal
+  // data format for all layout-sensitive operations.
+  mlir::TF::LayoutOptimizationPipelineOptions layout_optimization_options;
+  layout_optimization_options.force_data_format =
+      options.force_data_format.getValue();
+  // TODO(b/191304261): Folding transpose in ops is buggy in the layout
+  // optimization pass. Disable it to avoid errors in b/191304261.
This should + // not affect CPU performance as it does not change the number of ops, nor + // does it change the types of the ops. + layout_optimization_options.skip_fold_transpose_in_ops = true; + mlir::TF::CreateLayoutOptimizationPipeline(pm.nest(), + layout_optimization_options); + + // Run canonicalization pipeline to remove unused constants and bypassed + // transpose operations left in the IR after layout optimization. + pm.addNestedPass(mlir::createCanonicalizerPass()); + + // Decompose resource ops as resource variables will be converted to tensors + // directly. + if (options.decompose_resource_ops) + pm.addNestedPass( + mlir::TFDevice::CreateDecomposeResourceOpsPass()); + + AddTfDeviceAssignmentPasses(pm, options); + + pm.addNestedPass( + mlir::TF::CreateTensorDeviceCopyConversionPass()); + + AddTfrtJitRtPasses(options, pm); + + // Rewriter operation sequences to device specific fusions. + DeviceNameUtils::ParsedName parsed_name; + + // Ignore error. + bool success = + DeviceNameUtils::ParseFullName(options.default_device, &parsed_name); + assert(success && "default device is invalid"); + (void)success; + + if (parsed_name.has_type && parsed_name.type == DEVICE_GPU) + pm.addNestedPass(mlir::TF::CreateGpuOpFusionPass()); + + if (parsed_name.has_type && parsed_name.type == DEVICE_CPU) + pm.addNestedPass( + mlir::TF::CreateFusedKernelMatcherPass()); + + if (options.tpu_fuse_ops) { + pm.addNestedPass( + tfrt_compiler::CreateFuseTpuCompileAndExecutePass()); + // Remove ops for the input to _TPUCompileMlirOp, which are no longer needed + // after CreateFuseTpuCompileAndExecutePass + pm.addNestedPass(mlir::createCanonicalizerPass()); + } + + AddTfDeviceAssignmentPasses(pm, options); +} + +void CreateTFExecutorToTFInvariantOptimizationPipelineHelper( + mlir::OpPassManager &pm, const TfrtPipelineOptions &options) { + if (options.sink_in_invariant_ops) { + pm.addPass(CreateSinkInInvariantOpsPass()); + } + + pm.addPass(CreateLowerTFSavedModelPass( + options.hoist_invariant_ops, options.fuse_get_resource_ops_in_hoisting)); +} + +Status ValidateTfrtPipelineOptions(const TfrtPipelineOptions &options) { + if (options.target_tpurt && + (options.target_gpu || options.use_bridge_for_gpu)) { + return tensorflow::errors::Internal( + "Invalid pipeline options. Targeting both TPU and GPU is not " + "supported."); + } + return OkStatus(); +} + +Status CreateTFExecutorToTFPreInvariantOptimizationPipeline( + mlir::PassManager &pm, const TfrtPipelineOptions &options) { + TF_RETURN_IF_ERROR(ValidateTfrtPipelineOptions(options)); + if (VLOG_IS_ON(1)) { + // Print the whole module after each pass, which requires disabling + // multi-threading as well. + pm.getContext()->disableMultithreading(); + pm.enableIRPrinting(std::make_unique( + /*print_module_scope=*/true)); + } + CreateTFExecutorToTFPreInvariantOptimizationPipelineHelper(pm, options); + return OkStatus(); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tfrt/transforms/passes.h b/tensorflow/compiler/mlir/tfrt/transforms/passes.h index 01f1010788c..2502623adfa 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/passes.h +++ b/tensorflow/compiler/mlir/tfrt/transforms/passes.h @@ -25,7 +25,7 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h" #include "tensorflow/compiler/mlir/tfrt/transforms/tfrt_pipeline_options.h" #include "tensorflow/compiler/mlir/tfrt/transforms/tpu_passes.h" -#include "tensorflow/tsl/platform/status.h" +#include "tensorflow/core/platform/status.h" namespace mlir { class PassManager; @@ -88,7 +88,8 @@ CreateSinkInInvariantOpsPass(); // Create a pass that rewrites tf_saved_model dialect's ops according to TFRT's // requirements. std::unique_ptr> -CreateLowerTFSavedModelPass(bool hoist_invariant_ops); +CreateLowerTFSavedModelPass(bool hoist_invariant_ops, + bool fuse_get_resource_ops); // Create a pass that converts ref variables to resource variables in a limited // number of cases. @@ -116,19 +117,28 @@ CreateCrossDeviceTransferPass(); std::unique_ptr> CreateTfToTfrtConversionPass(const TfrtPipelineOptions& options); -// Creates a pipeline of passes that lowers MLIR TF Executor dialect to TF -// dialect for CoreRT purposes. -tsl::Status CreateTFExecutorToTFPipeline(mlir::PassManager& pm, - const TfrtPipelineOptions& options); - // Creates a pipeline of passes that lowers MLIR TF dialect to TFRT dialects. void CreateTfToTfrtPipeline(mlir::OpPassManager& pm, const TfrtPipelineOptions& options); // Creates a pipeline of passes that lowers MLIR TF dialect from tf.function to // TFRT dialect. SavedModel related conversions are not included. -tsl::Status CreateTfExecutorToTfrtPipeline(mlir::PassManager& pm, - const TfrtPipelineOptions& options); +Status CreateTfExecutorToTfrtPipeline(mlir::PassManager& pm, + const TfrtPipelineOptions& options); + +// Creates a pipeline of passes that lowers MLIR TF Executor dialect to TF +// dialect for CoreRT purposes. +Status CreateTFExecutorToTFPipeline(mlir::PassManager& pm, + const TfrtPipelineOptions& options); + +// TODO(deqiangc): refactor below helpers once mlrt is OSSed. +void CreateTFExecutorToTFPreInvariantOptimizationPipelineHelper( + mlir::OpPassManager& pm, const TfrtPipelineOptions& options); +void CreateTFExecutorToTFInvariantOptimizationPipelineHelper( + mlir::OpPassManager& pm, const TfrtPipelineOptions& options); + +Status CreateTFExecutorToTFPreInvariantOptimizationPipeline( + mlir::PassManager& pm, const TfrtPipelineOptions& options); } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tfrt/transforms/tf_to_tfrt.cc b/tensorflow/compiler/mlir/tfrt/transforms/tf_to_tfrt.cc index 3205154d0e5..15973c75a9c 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/tf_to_tfrt.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/tf_to_tfrt.cc @@ -26,53 +26,41 @@ limitations under the License. 
#include "mlir/IR/Dialect.h" #include "mlir/IR/OperationSupport.h" #include "mlir/IR/Types.h" -#include "mlir/Pass/PassManager.h" #include "mlir/Pass/PassOptions.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/Passes.h" -#include "mlir/Transforms/RegionUtils.h" -#include "llvm/ADT/DenseSet.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/iterator_range.h" #include "llvm/Support/Casting.h" -#include "llvm/Support/FormatVariadic.h" -#include "llvm/Support/raw_ostream.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.h" -#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" -#include "tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.h" #include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h" #include "tensorflow/compiler/mlir/tfrt/analysis/cost_analysis.h" #include "tensorflow/compiler/mlir/tfrt/analysis/tensor_array_side_effect_analysis.h" #include "tensorflow/compiler/mlir/tfrt/ir/gpu_ops.h" #include "tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback.h" #include "tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback_async.h" -#include "tensorflow/compiler/mlir/tfrt/jit/opdefs/tf_jitrt_ops.h" -#include "tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_clustering.h" -#include "tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_passes.h" #include "tensorflow/compiler/mlir/tfrt/transforms/attr_lowering_utils.h" #include "tensorflow/compiler/mlir/tfrt/transforms/corert_converter.h" #include "tensorflow/compiler/mlir/tfrt/transforms/fallback_converter.h" #include "tensorflow/compiler/mlir/tfrt/transforms/gpu_passes.h" #include "tensorflow/compiler/mlir/tfrt/transforms/passes.h" -#include "tensorflow/compiler/mlir/tfrt/transforms/set_shape_invariant_in_while_ops.h" +#include "tensorflow/compiler/mlir/tfrt/transforms/tfrt_jitrt_stub.h" #include "tensorflow/compiler/mlir/tfrt/transforms/utils.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/types.h" -#include "tensorflow/core/platform/tstring.h" -#include "tfrt/jitrt/opdefs/jitrt_ops.h" // from @tf_runtime #include "tfrt/basic_kernels/opdefs/basic_kernels.h" // from @tf_runtime #include "tfrt/basic_kernels/opdefs/tfrt_base.h" // from @tf_runtime #include "tfrt/basic_kernels/opdefs/types.h" // from @tf_runtime -#include "tfrt/core_runtime/opdefs/attributes.h" // from @tf_runtime #include "tfrt/core_runtime/opdefs/core_runtime.h" // from @tf_runtime #include "tfrt/core_runtime/opdefs/types.h" // from @tf_runtime #include "tfrt/test_kernels/opdefs/test_kernels.h" // from @tf_runtime @@ -94,7 +82,9 @@ constexpr int64_t kDefaultCheapCost = 1; void getDependentConversionDialects(mlir::DialectRegistry ®istry) { registry.insert(); + tfrt::compiler::TFRTDialect>(); + + RegisterJitRtDialects(registry); } mlir::Value GetFunctionInputChain(mlir::Operation *op) { @@ -237,7 +227,13 @@ class FallbackExecuteOpConversion : public mlir::ConversionPattern { // 
called by a TPUPartitionedCall op and will be compiled in // TPUPartitionedCall op via FunctionLibraryRuntime and not be processed // by BEFExecutor. - is_tpu_op) { + // + // We also avoid creating tfrt_fallback_async.createop for all GPU ops + // except for tf.XlaLaunch. This is correct as long as we only run XLA + // clusters on GPU and all other ops on CPU. + is_tpu_op || + (parsed_device_name->device_type == DEVICE_GPU && + op->getName().getStringRef().str() != "tf.XlaLaunch")) { return ConvertToCoreRTExecuteOp( op, operands, parsed_device_name->op_handler_name, op_attrs, op_func_attrs, op_name, rewriter); @@ -450,7 +446,7 @@ class FallbackConstOpConversion tensorflow::TensorProto tensor_proto; auto status = ConvertToTensorProto(op.getValue(), &tensor_proto); - if (!status.ok()) return op.emitError(status.error_message()); + if (!status.ok()) return op.emitError(tsl::NullTerminatedMessage(status)); rewriter.replaceOpWithNewOp( op, rewriter.getType(), @@ -812,7 +808,7 @@ class CoreRTConstStringTensorOpConversion llvm::StringRef(element.data(), element.size()))); // Create the shape attribute from the tensor shape. - ArrayRef shape = op.getValue().getType().getShape(); + ArrayRef shape = op.getValue().getShapedType().getShape(); llvm::SmallVector dims; dims.reserve(shape.size()); auto i64_type = rewriter.getIntegerType(64); @@ -1414,43 +1410,6 @@ mlir::func::FuncOp TFRTWhileOpConversion::GetWhileBodyFunction( return body_fn; } -// TODO(ezhulenev): tf_device.cluster operations after auto-fusion should -// have the correct device assigned based on the fused operations. We should -// use this device to convert operands and results from/to corert handles. -// For now it is safe to assume that it is "CPU" because we do not support -// any other devices and do not support distributed models. -constexpr char kJitRtDevice[] = "/job:localhost/replica:0/task:0/device:CPU:0"; - -// Convert jitrt.call operations to the tf_jitrt.fallback.execute operation. -class JitRtCallToJitRtCompileAndExecuteConversion - : public OpConversionPattern { - public: - explicit JitRtCallToJitRtCompileAndExecuteConversion(MLIRContext *context) - : OpConversionPattern(context) {} - - LogicalResult matchAndRewrite( - tfrt::jitrt::CallOp call, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - // Convert operands to fallback tensors. - llvm::SmallVector fallback_operands; - if (failed(tfrt_compiler::ConvertFallbackOperands( - call, kJitRtDevice, adaptor.getOperands(), &fallback_operands, - rewriter))) - return rewriter.notifyMatchFailure(call, "failed to convert operand"); - - // tf_jitrt.fallback.execute always produces fallback tensors. - llvm::SmallVector result_types( - call->getNumResults(), - rewriter.getType()); - - // Replace jitrt.call operation with a tf_jitrt.fallback.execute operation. - rewriter.replaceOpWithNewOp( - call, result_types, call.getCallee(), fallback_operands, kJitRtDevice); - - return success(); - } -}; - // Helper function for specifying legal dialects for conversion to CoreRT. 
void SetUpTFToTFRTConversionLegality(mlir::ConversionTarget *target, mlir::TypeConverter *func_type_converter, @@ -1459,10 +1418,8 @@ void SetUpTFToTFRTConversionLegality(mlir::ConversionTarget *target, target->addLegalDialect(); target->addLegalDialect(); target->addLegalDialect(); - target->addLegalDialect(); target->addIllegalDialect(); target->addIllegalDialect(); - target->addIllegalDialect(); target->addDynamicallyLegalOp([func_type_converter, chain_type]( func::FuncOp op) { @@ -1477,14 +1434,6 @@ void SetUpTFToTFRTConversionLegality(mlir::ConversionTarget *target, }); } -// Helper function for inserting TFRT JitRt dialect conversions. -void PopulateJitRtConversionPatterns(MLIRContext *context, - RewritePatternSet *patterns, - CoreRTConverter *corert_converter) { - // Lower jitrt.call to the pair of compile and execute operations. - patterns->add(context); -} - // Helper function for inserting TF dialect to TFRT dialect op conversion // patterns. void PopulateTFToTFRTConversionPatterns( @@ -1613,7 +1562,9 @@ class TfToTfrtConversionPass } SetUpTFToTFRTConversionLegality(&target, func_type_converter, corert_converter.chain_type()); - PopulateJitRtConversionPatterns(&context, &patterns, &corert_converter); + + PopulateJitRtConversionPatterns(&target, &context, &patterns, + &corert_converter); PopulateTFToTFRTConversionPatterns( &context, &patterns, &corert_converter, &fallback_converter, @@ -1737,31 +1688,8 @@ class TfToTfrtConversionPass chain_value = create_op; } - // Pre-compile all JIT compiled kernels found in the module. - llvm::SmallVector compiled; - - // A set SymbolRef attributes referencing compiled kernels. - llvm::DenseSet kernels; - - // Compile all kernels in parallell. - module.walk([&](tf_jitrt::FallbackExecuteOp execute) { - // Do not compiled the same kernel multiple times. - if (kernels.contains(execute.getKernel())) return; - - auto compile = builder.create( - execute.getLoc(), chain_type, execute.getKernel(), - execute.getDevice()); - compiled.push_back(compile.getResult()); - kernels.insert(compile.getKernel()); - }); - - // Wait for the compilation completion before returning from init function. - if (!compiled.empty()) { - // Do not forget to wait for the fallback kernels initialization. - compiled.insert(compiled.begin(), chain_value); - chain_value = builder.create( - func_op.getLoc(), chain_type, compiled); - } + chain_value = + CreateJitRtFallbackCompileKernel(builder, module, chain_value); builder.create(func_op.getLoc(), chain_value); } @@ -1877,25 +1805,6 @@ class TfToTfrtConversionPass "currently experimental."), llvm::cl::init(false)}; }; - -// Assigns devices so that later passes can utilize device information. -// Device assignement might have not been done by the upstream pipeline, or get -// removed by previous passes. However, we assume most of the device assignment -// has been done by the upstream pipeline, so we simply assign the default -// device to unassigned ops. Specifically, we do assignment for ConstOp first to -// place it on the same device as its user operation, instead of placing it on -// the default device blindly. -// TODO(b/221297389): Figure out a more robust way to handle dropped device -// assignment. 
-void AddTfDeviceAssignmentPasses(mlir::OpPassManager &pm, - const TfrtPipelineOptions &options) { - pm.addPass(mlir::TF::CreateConstantOpDeviceAssignmentPass()); - pm.addNestedPass( - mlir::TF::CreateTFDeviceAssignmentByFuncAttrPass()); - pm.addNestedPass( - mlir::TF::CreateSimpleTFDeviceAssignmentPass(options.default_device)); -} - } // namespace std::unique_ptr> @@ -1904,425 +1813,6 @@ CreateTfToTfrtConversionPass(const TfrtPipelineOptions &options) { } // -------------------------------------------------------------------------- // -// Outline tf_device.cluster operation regions into functions in the nested -// modules and replaces all cluster operations with jitrt.call operations. -// -------------------------------------------------------------------------- // - -class OutlineJitRtClustersPass - : public PassWrapper> { - public: - llvm::StringRef getArgument() const final { - return "tf-outline-jitrt-cluster"; - } - llvm::StringRef getDescription() const final { - return "Outlines `tf_device.cluster` operations into functions and " - "replaces them with `jitrt.call` operations."; - } - - void runOnOperation() override; - - void getDependentDialects(mlir::DialectRegistry ®istry) const override { - registry.insert(); - } - - private: - struct CompiledModule { - ModuleOp module; - func::FuncOp entrypoint; - llvm::SetVector operands; - }; - - // Creates a nested module with a single function that will be compiled into - // the kernel at runtime. - CompiledModule CreateCompiledModule(tf_device::ClusterOp cluster, - int64_t max_arg_size, - SymbolTable *symbol_table); - - // Update compiled module entrypoint signature with inferred operands - // constraints. - LogicalResult SetEntrypointConstraints(CompiledModule &compiled); - - // Outlines cluster operation regions into compiled modules, and replaces - // cluster operation with a jitrt.call operation. - LogicalResult OutlineClusterOp(tf_device::ClusterOp cluster, - int64_t max_arg_size, - SymbolTable *symbol_table); - - // Mapping from the outlined module string representation to the module itself - // and an entrypoint function. Used to deduplicate identical modules during - // the `tf_device.cluster` outlining. - llvm::StringMap> outlined_; -}; - -OutlineJitRtClustersPass::CompiledModule -OutlineJitRtClustersPass::CreateCompiledModule(tf_device::ClusterOp cluster, - int64_t max_arg_size, - SymbolTable *symbol_table) { - MLIRContext *ctx = cluster->getContext(); - Location loc = cluster.getLoc(); - - // Create a module that will hold compiled function and async wrappers. - // TODO(ezhulenev): Give better names to module and function. - auto compiled_module = ModuleOp::create(loc, {"kernel"}); - compiled_module->setAttr("tfrt.compiled", UnitAttr::get(ctx)); - compiled_module->setAttr( - "tfrt.max-arg-size", - IntegerAttr::get(IntegerType::get(ctx, 64), max_arg_size)); - - SymbolTable compiled_module_symbol_table(compiled_module); - - // Find out the cluster arguments and their types. - llvm::SetVector live_ins; - getUsedValuesDefinedAbove(cluster.getBody(), cluster.getBody(), live_ins); - - llvm::SmallVector operand_types; - operand_types.reserve(live_ins.size()); - for (Value v : live_ins) operand_types.emplace_back(v.getType()); - - // Create a function in the compiled module. 
- auto compiled_func_type = - FunctionType::get(ctx, operand_types, cluster->getResultTypes()); - auto compiled_func = func::FuncOp::create(loc, "compute", compiled_func_type); - compiled_module_symbol_table.insert(compiled_func); - - // Replace uses of live-in values within cluster region with block arguments. - Block *compiled_func_block = compiled_func.addEntryBlock(); - for (auto p : llvm::zip(live_ins, compiled_func_block->getArguments())) - replaceAllUsesInRegionWith(std::get<0>(p), std::get<1>(p), - cluster.getBody()); - - // Move all operations in cluster into compiled_func's entry block. - auto &cluster_body = cluster.GetBody().getOperations(); - compiled_func_block->getOperations().splice( - compiled_func_block->end(), cluster_body, cluster_body.begin(), - cluster_body.end()); - - // Replace `tf_device.return` terminator with `func.return` in the function - // body. - auto device_return = - cast(compiled_func_block->getTerminator()); - OpBuilder builder(device_return.getOperation()); - builder.create(device_return.getLoc(), - device_return.getOperands()); - device_return.erase(); - - // TODO(ezhulenev): MLIR doesn't define operation equivalence upstream yet, - // replace module printing with a more principled solution when available. - // Operations in the cluster can be in different order, however define the - // identical Tensorflow programs, with current approach we'll not be able - // to detect duplicates like this. - - // Remove location attribute attached to Tensorflow operations to be able to - // deduplicate compiled clusters with the same set of operations. - // - // TODO(ezhulenev): Figure out how to propagate locations for error reporting, - // right now JitRt will ignore them anyway. - compiled_module.walk([](Operation *op) { op->removeAttr("_class"); }); - - // Serialize prepared module to string. - std::string serialized; - llvm::raw_string_ostream os(serialized); - compiled_module.print(os); - - // Try to find if identical module was already outlined. - auto it = outlined_.find(serialized); - - // Return identical module that was already outlined earlier. - if (it != outlined_.end()) { - compiled_module.erase(); // erase identical module - return {it->second.first, it->second.second, live_ins}; - } - - // Insert compiled module into the symbol table and assign it a unique name. - symbol_table->insert(compiled_module); - - // Cache unique module. - outlined_.insert({std::move(serialized), {compiled_module, compiled_func}}); - - return {compiled_module, compiled_func, live_ins}; -} - -LogicalResult OutlineJitRtClustersPass::SetEntrypointConstraints( - CompiledModule &compiled) { - func::FuncOp func = compiled.entrypoint; - - // Functions outlined from jitrt device clusters must have a single block. - assert(func.getBody().getBlocks().size() == 1 && "expected single block"); - - mlir::TFDevice::ClusteringPolicySet policies; - populateTfJitRtConstraintsPolicies(policies); - - // Infer constraints on the values defined in the entrypoint function - // (including function entry block arguments). - mlir::TFDevice::ValuesConstraintSet constraints; - if (failed(mlir::TFDevice::PropagateValuesConstraints( - func.getBody(), policies, constraints, /*resolve=*/true))) - return failure(); - - // Annotate arguments with inferred constraints. 
- for (unsigned i = 0; i < func.getNumArguments(); ++i) { - if (auto constraint = constraints.GetConstraint(func.getArgument(i))) { - auto constraint_name = mlir::StringAttr::get( - &getContext(), llvm::formatv("{0}", *constraint).str()); - func.setArgAttr(i, "rt.constraint", constraint_name); - } - } - - return success(); -} - -LogicalResult OutlineJitRtClustersPass::OutlineClusterOp( - tf_device::ClusterOp cluster, int64_t max_arg_size, - SymbolTable *symbol_table) { - Location loc = cluster->getLoc(); - OpBuilder builder(cluster); - - CompiledModule compiled_module = - CreateCompiledModule(cluster, max_arg_size, symbol_table); - func::FuncOp compiled_func = compiled_module.entrypoint; - - // Add constraints to the entrypoint arguments. - if (failed(SetEntrypointConstraints(compiled_module))) return failure(); - - // Replace device cluster with a jitrt.call operation. - auto module_name = *compiled_module.module.getSymName(); - auto func_name = compiled_func.getSymName(); - auto func_flat_ref = - mlir::SymbolRefAttr::get(builder.getContext(), func_name); - auto func_ref = mlir::SymbolRefAttr::get(builder.getContext(), module_name, - {func_flat_ref}); - - auto cluster_func_op = builder.create( - loc, cluster.getResultTypes(), func_ref, - compiled_module.operands.getArrayRef()); - - cluster.replaceAllUsesWith(cluster_func_op); - cluster.erase(); - - return success(); -} - -void OutlineJitRtClustersPass::runOnOperation() { - ModuleOp module = getOperation(); - SymbolTable symbol_table(module); - - // Keep track of the maximum argument size for each function with tf_device - // cluster operations in the function body. We need to pass it to the compiled - // module to correctly compute its cost later. - llvm::DenseMap max_arg_size_map; - - auto get_max_arg_size = [&](mlir::func::FuncOp func) -> int64_t { - auto it = max_arg_size_map.find(func); - if (it != max_arg_size_map.end()) return it->second; - return max_arg_size_map[func] = tf_jitrt::GetMaxArgSize(func); - }; - - OpBuilder builder(module.getContext()); - auto result = module.walk([&](tf_device::ClusterOp cluster) -> WalkResult { - // Ensure that cluster was formed for TFRT JIT compilation. - auto policy = cluster->getAttr("policy").dyn_cast_or_null(); - if (!policy || policy.getValue() != "tfrt.auto-fusion") - return WalkResult::advance(); - - // Get the maximum argument size of the parent function. - mlir::func::FuncOp parent_func = - cluster->getParentOfType(); - int64_t max_arg_size = get_max_arg_size(parent_func); - - if (failed(OutlineClusterOp(cluster, max_arg_size, &symbol_table))) - return WalkResult::interrupt(); - return WalkResult::advance(); - }); - - if (result.wasInterrupted()) { - module->emitError("Failed to outline tf_device.cluster operations"); - signalPassFailure(); - } -} - -static std::unique_ptr CreateOutlineJitRtClustersPass() { - return std::make_unique(); -} - -// -------------------------------------------------------------------------- // - -static void CreateTFExecutorToTFPipelineHelper( - mlir::OpPassManager &pm, const TfrtPipelineOptions &options) { - // Due to b/191304670, functionalized while ops might not have the - // shape_invariant attribute set correctly, which leads to failure in shape - // inference. As a workaround, we conservatively (e.g., we place less - // restrictions on tf.while which will avoid failures but lead to potentially - // less exact shape inference) set the shape_invariant attribute in all - // tf.While ops before performing shape inference. 
- // - // Note that this pass might not work well with TF XLA bridge, but this is - // fine as TF XLA bridge is run before this pipeline. For CPU ops, less exact - // shape inference may lead to fewer optimizations but it should be fine as it - // is limited to while ops currently. - // - // TODO(b/191304670): Remove this pass once the shape_invariant attribute is - // set correctly in the upstream. - pm.addNestedPass( - tfrt_compiler::CreateSetShapeInvariantInWhileOps()); - - // We pass the MLIR module through the TF standard pipeline, which for - // instances does shape inference, canonicalization, inlining, etc. - pm.addNestedPass( - mlir::tf_executor::CreateTFExecutorGraphPruningPass()); - pm.addNestedPass( - mlir::tf_executor::CreateTFExecutorIslandCoarseningPass()); - - AddTfDeviceAssignmentPasses(pm, options); - - pm.addPass(tfrt_compiler::CreateTfrtXlaRewritePass()); - - // Here we perform TFRT specific optimization before standard TF optimization, - // as TFRT-specific optimization may create more opportunities. - pm.addNestedPass( - tfrt_compiler::CreateOptimizeTfForTfrtPass()); - pm.addNestedPass(mlir::createCanonicalizerPass()); - // Guarantee all functions have one use, which enables more exact shape - // inference. - pm.addPass(mlir::TF::CreateGuaranteeAllFuncsOneUsePass()); - pm.addPass(mlir::TF::CreateTFShapeInferencePass()); - pm.addPass(mlir::createInlinerPass()); - pm.addPass(mlir::createSymbolDCEPass()); - pm.addNestedPass(mlir::TF::CreateTFOptimizePass()); - pm.addNestedPass(mlir::createCSEPass()); - - AddTfDeviceAssignmentPasses(pm, options); - - // After the standard pass, we now have MLIR in TF dialect, and now we convert - // reference variable to resource variables, which is besteffort. - pm.addPass(CreateConvertReferenceVariableToResourceVariablePass()); - - // Move the tf.Assert op to the end of the function, so that it does not - // impose unnecessary control dependencies on other ops. - pm.addPass(tfrt_compiler::CreateReorderTfAssertPass()); - - // Optimze the side-effects of control flow ops by examining the ops in its - // callees. - pm.addPass(tfrt_compiler::CreateOptimizeTfControlFlowSideEffectPass()); - - // Remove tf.If ops' operands that are produced by tf.Const ops. - pm.addPass(tfrt_compiler::CreateRemoveTfIfConstArgsPass()); - - // Merge non-side-effecting tf.If ops if their operands are the same. - pm.addPass(tfrt_compiler::CreateMergeTfIfOpsPass()); - - // Deduplicate functions invoked by tf.BatchFunction with the same - // shared_name - pm.addPass( - tfrt_compiler::CreateDeduplicateFunctionsInovkedByBatchFunctionPass()); - - // RemoveUnusedWhileResultsPass operates on the region-based control flow, so - // the functional control flow is first converted to region-based control - // flow, which is converted back after the optimization passes are performed. - pm.addPass(mlir::TF::CreateTFFunctionalControlFlowToRegions()); - pm.addPass(mlir::createInlinerPass()); - pm.addNestedPass( - mlir::TF::CreateRemoveUnusedWhileResultsPass()); - pm.addPass(mlir::TF::CreateTFRegionControlFlowToFunctional()); - - // Apply standard optimization after optimizing control flow ops. - pm.addPass(mlir::createInlinerPass()); - pm.addNestedPass(mlir::createCSEPass()); - - // TODO(b/187876545): An extra shape inference pass is added because it does - // not work well with tf.Identity op that remove ref type. So we work around - // by performing shape inference again after reference variable to resource - // variable conversion. 
We should remove this after b/187876545 is fixed. - pm.addPass(mlir::TF::CreateTFShapeInferencePass()); - - pm.addNestedPass( - mlir::TFDevice::CreateLaunchToDeviceAttributePass()); - - // After all standard passes run layout optimization to assign optimal data - // format for all layout sensitive operations. - mlir::TF::LayoutOptimizationPipelineOptions layout_optimization_options; - layout_optimization_options.force_data_format = - options.force_data_format.getValue(); - // TODO(b/191304261): Folding transpose in ops is buggy in the layout - // optimization pass. Disable it to avoid errors in b/191304261. This should - // not affect CPU performance as it does not change the number of ops, nor - // does it change the types of the ops. - layout_optimization_options.skip_fold_transpose_in_ops = true; - mlir::TF::CreateLayoutOptimizationPipeline(pm.nest(), - layout_optimization_options); - - // Run canonicalization pipeline to remove unused constants and bypassed - // transpose operations left in the IR after layout optimization. - pm.addNestedPass(mlir::createCanonicalizerPass()); - - // Decompose resource ops as resource variables will be converted to tensors - // directly. - if (options.decompose_resource_ops) - pm.addNestedPass( - mlir::TFDevice::CreateDecomposeResourceOpsPass()); - - AddTfDeviceAssignmentPasses(pm, options); - - pm.addNestedPass( - mlir::TF::CreateTensorDeviceCopyConversionPass()); - - // Outline auto-fusion clusters into tf_device.cluster_operations and then - // convert them to functions. We currently support only tfrt fallback tensors - // as operands, so we disable these passes if we can have native ops after - // lowering. - pm.addNestedPass(CreateTfJitRtClusteringPass( - options.auto_fusion_oplist, options.auto_fusion_min_cluster_size)); - - // Sink small constants into the outlined clusters to reduce the number of - // arguments for each of the execute operations. - auto is_compilable_const = [](mlir::tf_device::ClusterOp cluster, - mlir::ElementsAttr value) -> bool { - // Ensure that cluster was formed for TFRT JIT compilation. - auto policy = cluster->getAttr("policy").dyn_cast_or_null(); - if (!policy || policy.getValue() != "tfrt.auto-fusion") return false; - - // Check that TF->JitRt compiler supports constant compilation. - return mlir::succeeded(IsCompilableConstant(value)); - }; - - pm.addNestedPass( - mlir::TFDevice::CreateClusterConstantSinkingPass(is_compilable_const)); - - // Outline formed JIT compiled device clusters into function. - pm.addPass(CreateOutlineJitRtClustersPass()); - - // Rewriter operation sequences to device specific fusions. - DeviceNameUtils::ParsedName parsed_name; - - // Ignore error. 
- bool success = - DeviceNameUtils::ParseFullName(options.default_device, &parsed_name); - assert(success && "default device is invalid"); - (void)success; - - if (parsed_name.has_type && parsed_name.type == DEVICE_GPU) - pm.addNestedPass(mlir::TF::CreateGpuOpFusionPass()); - - if (parsed_name.has_type && parsed_name.type == DEVICE_CPU) - pm.addNestedPass( - mlir::TF::CreateFusedKernelMatcherPass()); - - if (options.tpu_fuse_ops) { - pm.addNestedPass( - tfrt_compiler::CreateFuseTpuCompileAndExecutePass()); - // Remove ops for the input to _TPUCompileMlirOp, which are no longer needed - // after CreateFuseTpuCompileAndExecutePass - pm.addNestedPass(mlir::createCanonicalizerPass()); - } - - AddTfDeviceAssignmentPasses(pm, options); - - if (options.sink_in_invariant_ops) { - pm.addPass(CreateSinkInInvariantOpsPass()); - } - - pm.addPass(CreateLowerTFSavedModelPass(options.hoist_invariant_ops)); -} - void CreateTfToTfrtPipeline(mlir::OpPassManager &pm, const TfrtPipelineOptions &options) { pm.addPass(CreateTfToTfrtConversionPass(options)); @@ -2341,49 +1831,33 @@ void CreateTfToTfrtPipeline(mlir::OpPassManager &pm, static void CreateTfExecutorToTfrtPipelineHelper( mlir::OpPassManager &pm, const TfrtPipelineOptions &options) { - CreateTFExecutorToTFPipelineHelper(pm, options); + CreateTFExecutorToTFPreInvariantOptimizationPipelineHelper(pm, options); + CreateTFExecutorToTFInvariantOptimizationPipelineHelper(pm, options); CreateTfToTfrtPipeline(pm, options); } -Status ValidateTfrtPipelineOptions(const TfrtPipelineOptions &options) { - if (options.target_tpurt && - (options.target_gpu || options.use_bridge_for_gpu)) { - return tensorflow::errors::Internal( - "Invalid pipeline options. Targeting both TPU and GPU is not " - "supported."); - } - return OkStatus(); -} - // If verbose logging is on, dump the output of each pass to a file directory, // set via env var TF_DUMP_GRAPH_PREFIX. e.g.: // export TF_DUMP_GRAPH_PREFIX=/tmp/mlir Status CreateTfExecutorToTfrtPipeline(mlir::PassManager &pm, const TfrtPipelineOptions &options) { - TF_RETURN_IF_ERROR(CreateTFExecutorToTFPipeline(pm, options)); + TF_RETURN_IF_ERROR( + CreateTFExecutorToTFPreInvariantOptimizationPipeline(pm, options)); + CreateTFExecutorToTFInvariantOptimizationPipelineHelper(pm, options); CreateTfToTfrtPipeline(pm, options); return OkStatus(); } Status CreateTFExecutorToTFPipeline(mlir::PassManager &pm, const TfrtPipelineOptions &options) { - TF_RETURN_IF_ERROR(ValidateTfrtPipelineOptions(options)); - if (VLOG_IS_ON(1)) { - // Print the whole module after each pass, which requires disabling - // multi-threading as well. 
- pm.getContext()->disableMultithreading(); - pm.enableIRPrinting(std::make_unique( - /*print_module_scope=*/true)); - } - CreateTFExecutorToTFPipelineHelper(pm, options); + TF_RETURN_IF_ERROR( + CreateTFExecutorToTFPreInvariantOptimizationPipeline(pm, options)); + CreateTFExecutorToTFInvariantOptimizationPipelineHelper(pm, options); return OkStatus(); } static mlir::PassRegistration tf_to_tfrt_pass; -static mlir::PassRegistration - tf_outline_jitrt_cluster_pass(CreateOutlineJitRtClustersPass); - static mlir::PassPipelineRegistration tf_pipeline( "tf-executor-to-tfrt-pipeline", "Convert Tensorflow Executor dialect to TFRT dialect and " diff --git a/tensorflow/compiler/mlir/tfrt/transforms/tfrt_jitrt_passes.cc b/tensorflow/compiler/mlir/tfrt/transforms/tfrt_jitrt_passes.cc new file mode 100644 index 00000000000..91a4c1d61fd --- /dev/null +++ b/tensorflow/compiler/mlir/tfrt/transforms/tfrt_jitrt_passes.cc @@ -0,0 +1,414 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include + +#include "llvm/Support/FormatVariadic.h" +#include "mlir/Transforms/RegionUtils.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" +#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" +#include "tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback_async.h" +#include "tensorflow/compiler/mlir/tfrt/jit/opdefs/tf_jitrt_ops.h" +#include "tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_clustering.h" +#include "tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_passes.h" +#include "tensorflow/compiler/mlir/tfrt/transforms/fallback_converter.h" +#include "tensorflow/compiler/mlir/tfrt/transforms/tfrt_jitrt_stub.h" +#include "tfrt/jitrt/opdefs/jitrt_ops.h" // from @tf_runtime +#include "tfrt/basic_kernels/opdefs/basic_kernels.h" // from @tf_runtime + +namespace tensorflow { +namespace { + +class TfrtJitRtStubImpl : public TfrtJitRtStub { + void RegisterJitRtDialects(mlir::DialectRegistry ®istry) override; + + void PopulateJitRtConversionPatterns( + mlir::ConversionTarget *target, mlir::MLIRContext *context, + mlir::RewritePatternSet *patterns, + CoreRTConverter *corert_converter) override; + + mlir::Value CreateJitRtFallbackCompileKernel( + mlir::OpBuilder &builder, mlir::ModuleOp module, + mlir::Value chain_value) override; + + void AddTfrtJitRtPasses(const TfrtPipelineOptions &options, + mlir::OpPassManager &pm) override; +}; + +void TfrtJitRtStubImpl::RegisterJitRtDialects(mlir::DialectRegistry ®istry) { + registry.insert(); +} + +// TODO(ezhulenev): tf_device.cluster operations after auto-fusion should +// have the correct device assigned based on the fused operations. We should +// use this device to convert operands and results from/to corert handles. +// For now it is safe to assume that it is "CPU" because we do not support +// any other devices and do not support distributed models. 
+constexpr char kJitRtDevice[] = "/job:localhost/replica:0/task:0/device:CPU:0";
+
+// Convert jitrt.call operations to the tf_jitrt.fallback.execute operation.
+class JitRtCallToJitRtCompileAndExecuteConversion
+    : public OpConversionPattern {
+ public:
+  explicit JitRtCallToJitRtCompileAndExecuteConversion(MLIRContext *context)
+      : OpConversionPattern(context) {}
+
+  LogicalResult matchAndRewrite(
+      tfrt::jitrt::CallOp call, OpAdaptor adaptor,
+      ConversionPatternRewriter &rewriter) const override {
+    // Convert operands to fallback tensors.
+    llvm::SmallVector fallback_operands;
+    if (failed(tfrt_compiler::ConvertFallbackOperands(
+            call, kJitRtDevice, adaptor.getOperands(), &fallback_operands,
+            rewriter)))
+      return rewriter.notifyMatchFailure(call, "failed to convert operand");
+
+    // tf_jitrt.fallback.execute always produces fallback tensors.
+    llvm::SmallVector result_types(
+        call->getNumResults(),
+        rewriter.getType());
+
+    // Replace jitrt.call operation with a tf_jitrt.fallback.execute operation.
+    rewriter.replaceOpWithNewOp(
+        call, result_types, call.getCallee(), fallback_operands, kJitRtDevice);
+
+    return success();
+  }
+};
+
+// Helper function for inserting TFRT JitRt dialect conversions.
+void TfrtJitRtStubImpl::PopulateJitRtConversionPatterns(
+    mlir::ConversionTarget *target, MLIRContext *context,
+    RewritePatternSet *patterns, CoreRTConverter *corert_converter) {
+  target->addLegalDialect();
+  target->addIllegalDialect();
+  // Lower jitrt.call to the pair of compile and execute operations.
+  patterns->add(context);
+}
+
+mlir::Value TfrtJitRtStubImpl::CreateJitRtFallbackCompileKernel(
+    mlir::OpBuilder &builder, mlir::ModuleOp module, mlir::Value chain_value) {
+  // Pre-compile all JIT compiled kernels found in the module.
+  llvm::SmallVector compiled;
+
+  // A set of SymbolRef attributes referencing compiled kernels.
+  llvm::DenseSet kernels;
+
+  // Compile all kernels in parallel.
+  module.walk([&](tf_jitrt::FallbackExecuteOp execute) {
+    // Do not compile the same kernel multiple times.
+    if (kernels.contains(execute.getKernel())) return;
+
+    auto compile = builder.create(
+        execute.getLoc(), builder.getType(),
+        execute.getKernel(), execute.getDevice());
+    compiled.push_back(compile.getResult());
+    kernels.insert(compile.getKernel());
+  });
+
+  // Wait for the compilation to complete before returning from the init
+  // function.
+  if (!compiled.empty()) {
+    // Do not forget to wait for the fallback kernels initialization.
+    compiled.insert(compiled.begin(), chain_value);
+    chain_value = builder.create(
+        module.getLoc(), builder.getType(),
+        compiled);
+  }
+
+  return chain_value;
+}
+
+// -------------------------------------------------------------------------- //
+// Outline tf_device.cluster operation regions into functions in the nested
+// modules and replaces all cluster operations with jitrt.call operations.
+// -------------------------------------------------------------------------- // + +class OutlineJitRtClustersPass + : public PassWrapper> { + public: + llvm::StringRef getArgument() const final { + return "tf-outline-jitrt-cluster"; + } + llvm::StringRef getDescription() const final { + return "Outlines `tf_device.cluster` operations into functions and " + "replaces them with `jitrt.call` operations."; + } + + void runOnOperation() override; + + void getDependentDialects(mlir::DialectRegistry ®istry) const override { + registry.insert(); + } + + public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OutlineJitRtClustersPass) + + private: + struct CompiledModule { + ModuleOp module; + func::FuncOp entrypoint; + llvm::SetVector operands; + }; + + // Creates a nested module with a single function that will be compiled into + // the kernel at runtime. + CompiledModule CreateCompiledModule(tf_device::ClusterOp cluster, + int64_t max_arg_size, + SymbolTable *symbol_table); + + // Update compiled module entrypoint signature with inferred operands + // constraints. + LogicalResult SetEntrypointConstraints(CompiledModule &compiled); + + // Outlines cluster operation regions into compiled modules, and replaces + // cluster operation with a jitrt.call operation. + LogicalResult OutlineClusterOp(tf_device::ClusterOp cluster, + int64_t max_arg_size, + SymbolTable *symbol_table); + + // Mapping from the outlined module string representation to the module itself + // and an entrypoint function. Used to deduplicate identical modules during + // the `tf_device.cluster` outlining. + llvm::StringMap> outlined_; +}; + +OutlineJitRtClustersPass::CompiledModule +OutlineJitRtClustersPass::CreateCompiledModule(tf_device::ClusterOp cluster, + int64_t max_arg_size, + SymbolTable *symbol_table) { + MLIRContext *ctx = cluster->getContext(); + Location loc = cluster.getLoc(); + + // Create a module that will hold compiled function and async wrappers. + // TODO(ezhulenev): Give better names to module and function. + auto compiled_module = ModuleOp::create(loc, {"kernel"}); + compiled_module->setAttr("tfrt.compiled", UnitAttr::get(ctx)); + compiled_module->setAttr( + "tfrt.max-arg-size", + IntegerAttr::get(IntegerType::get(ctx, 64), max_arg_size)); + + SymbolTable compiled_module_symbol_table(compiled_module); + + // Find out the cluster arguments and their types. + llvm::SetVector live_ins; + getUsedValuesDefinedAbove(cluster.getBody(), cluster.getBody(), live_ins); + + llvm::SmallVector operand_types; + operand_types.reserve(live_ins.size()); + for (Value v : live_ins) operand_types.emplace_back(v.getType()); + + // Create a function in the compiled module. + auto compiled_func_type = + FunctionType::get(ctx, operand_types, cluster->getResultTypes()); + auto compiled_func = func::FuncOp::create(loc, "compute", compiled_func_type); + compiled_module_symbol_table.insert(compiled_func); + + // Replace uses of live-in values within cluster region with block arguments. + Block *compiled_func_block = compiled_func.addEntryBlock(); + for (auto p : llvm::zip(live_ins, compiled_func_block->getArguments())) + replaceAllUsesInRegionWith(std::get<0>(p), std::get<1>(p), + cluster.getBody()); + + // Move all operations in cluster into compiled_func's entry block. 
+ auto &cluster_body = cluster.GetBody().getOperations(); + compiled_func_block->getOperations().splice( + compiled_func_block->end(), cluster_body, cluster_body.begin(), + cluster_body.end()); + + // Replace `tf_device.return` terminator with `func.return` in the function + // body. + auto device_return = + cast(compiled_func_block->getTerminator()); + OpBuilder builder(device_return.getOperation()); + builder.create(device_return.getLoc(), + device_return.getOperands()); + device_return.erase(); + + // TODO(ezhulenev): MLIR doesn't define operation equivalence upstream yet, + // replace module printing with a more principled solution when available. + // Operations in the cluster can be in different order, however define the + // identical Tensorflow programs, with current approach we'll not be able + // to detect duplicates like this. + + // Remove location attribute attached to Tensorflow operations to be able to + // deduplicate compiled clusters with the same set of operations. + // + // TODO(ezhulenev): Figure out how to propagate locations for error reporting, + // right now JitRt will ignore them anyway. + compiled_module.walk([](Operation *op) { op->removeAttr("_class"); }); + + // Serialize prepared module to string. + std::string serialized; + llvm::raw_string_ostream os(serialized); + compiled_module.print(os); + + // Try to find if identical module was already outlined. + auto it = outlined_.find(serialized); + + // Return identical module that was already outlined earlier. + if (it != outlined_.end()) { + compiled_module.erase(); // erase identical module + return {it->second.first, it->second.second, live_ins}; + } + + // Insert compiled module into the symbol table and assign it a unique name. + symbol_table->insert(compiled_module); + + // Cache unique module. + outlined_.insert({std::move(serialized), {compiled_module, compiled_func}}); + + return {compiled_module, compiled_func, live_ins}; +} + +LogicalResult OutlineJitRtClustersPass::SetEntrypointConstraints( + CompiledModule &compiled) { + func::FuncOp func = compiled.entrypoint; + + // Functions outlined from jitrt device clusters must have a single block. + assert(func.getBody().getBlocks().size() == 1 && "expected single block"); + + mlir::TFDevice::ClusteringPolicySet policies; + populateTfJitRtConstraintsPolicies(policies); + + // Infer constraints on the values defined in the entrypoint function + // (including function entry block arguments). + mlir::TFDevice::ValuesConstraintSet constraints; + if (failed(mlir::TFDevice::PropagateValuesConstraints( + func.getBody(), policies, constraints, /*resolve=*/true))) + return failure(); + + // Annotate arguments with inferred constraints. + for (unsigned i = 0; i < func.getNumArguments(); ++i) { + if (auto constraint = constraints.GetConstraint(func.getArgument(i))) { + auto constraint_name = mlir::StringAttr::get( + &getContext(), llvm::formatv("{0}", *constraint).str()); + func.setArgAttr(i, "rt.constraint", constraint_name); + } + } + + return success(); +} + +LogicalResult OutlineJitRtClustersPass::OutlineClusterOp( + tf_device::ClusterOp cluster, int64_t max_arg_size, + SymbolTable *symbol_table) { + Location loc = cluster->getLoc(); + OpBuilder builder(cluster); + + CompiledModule compiled_module = + CreateCompiledModule(cluster, max_arg_size, symbol_table); + func::FuncOp compiled_func = compiled_module.entrypoint; + + // Add constraints to the entrypoint arguments. 
+ if (failed(SetEntrypointConstraints(compiled_module))) return failure(); + + // Replace device cluster with a jitrt.call operation. + auto module_name = *compiled_module.module.getSymName(); + auto func_name = compiled_func.getSymName(); + auto func_flat_ref = + mlir::SymbolRefAttr::get(builder.getContext(), func_name); + auto func_ref = mlir::SymbolRefAttr::get(builder.getContext(), module_name, + {func_flat_ref}); + + auto cluster_func_op = builder.create( + loc, cluster.getResultTypes(), func_ref, + compiled_module.operands.getArrayRef()); + + cluster.replaceAllUsesWith(cluster_func_op); + cluster.erase(); + + return success(); +} + +void OutlineJitRtClustersPass::runOnOperation() { + ModuleOp module = getOperation(); + SymbolTable symbol_table(module); + + // Keep track of the maximum argument size for each function with tf_device + // cluster operations in the function body. We need to pass it to the compiled + // module to correctly compute its cost later. + llvm::DenseMap max_arg_size_map; + + auto get_max_arg_size = [&](mlir::func::FuncOp func) -> int64_t { + auto it = max_arg_size_map.find(func); + if (it != max_arg_size_map.end()) return it->second; + return max_arg_size_map[func] = tf_jitrt::GetMaxArgSize(func); + }; + + OpBuilder builder(module.getContext()); + auto result = module.walk([&](tf_device::ClusterOp cluster) -> WalkResult { + // Ensure that cluster was formed for TFRT JIT compilation. + auto policy = cluster->getAttr("policy").dyn_cast_or_null(); + if (!policy || policy.getValue() != "tfrt.auto-fusion") + return WalkResult::advance(); + + // Get the maximum argument size of the parent function. + mlir::func::FuncOp parent_func = + cluster->getParentOfType(); + int64_t max_arg_size = get_max_arg_size(parent_func); + + if (failed(OutlineClusterOp(cluster, max_arg_size, &symbol_table))) + return WalkResult::interrupt(); + return WalkResult::advance(); + }); + + if (result.wasInterrupted()) { + module->emitError("Failed to outline tf_device.cluster operations"); + signalPassFailure(); + } +} + +std::unique_ptr CreateOutlineJitRtClustersPass() { + return std::make_unique(); +} + +void TfrtJitRtStubImpl::AddTfrtJitRtPasses(const TfrtPipelineOptions &options, + mlir::OpPassManager &pm) { + // Outline auto-fusion clusters into tf_device.cluster_operations and then + // convert them to functions. We currently support only tfrt fallback tensors + // as operands, so we disable these passes if we can have native ops after + // lowering. + pm.addNestedPass(CreateTfJitRtClusteringPass( + options.auto_fusion_oplist, options.auto_fusion_min_cluster_size)); + + // Sink small constants into the outlined clusters to reduce the number of + // arguments for each of the execute operations. + auto is_compilable_const = [](mlir::tf_device::ClusterOp cluster, + mlir::ElementsAttr value) -> bool { + // Ensure that cluster was formed for TFRT JIT compilation. + auto policy = cluster->getAttr("policy").dyn_cast_or_null(); + if (!policy || policy.getValue() != "tfrt.auto-fusion") return false; + + // Check that TF->JitRt compiler supports constant compilation. + return mlir::succeeded(IsCompilableConstant(value)); + }; + + pm.addNestedPass( + mlir::TFDevice::CreateClusterConstantSinkingPass(is_compilable_const)); + + // Outline formed JIT compiled device clusters into function. 
+ pm.addPass(CreateOutlineJitRtClustersPass()); +} + +mlir::PassRegistration tf_outline_jitrt_cluster_pass( + CreateOutlineJitRtClustersPass); + +const bool kUnused = + (RegisterTfrtJitRtStub(std::make_unique()), true); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tfrt/transforms/tfrt_jitrt_stub.cc b/tensorflow/compiler/mlir/tfrt/transforms/tfrt_jitrt_stub.cc new file mode 100644 index 00000000000..1bde6382c79 --- /dev/null +++ b/tensorflow/compiler/mlir/tfrt/transforms/tfrt_jitrt_stub.cc @@ -0,0 +1,76 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/tfrt/transforms/tfrt_jitrt_stub.h" + +#include +#include +#include + +namespace tensorflow { +namespace { + +class TfrtJitRtStubRegistry { + public: + TfrtJitRtStubRegistry() : stub_(std::make_unique()) {} + + void Register(std::unique_ptr stub) { + stub_ = std::move(stub); + } + + TfrtJitRtStub &Get() { return *stub_; } + + private: + std::unique_ptr stub_; +}; + +TfrtJitRtStubRegistry &GetGlobalTfrtJitRtStubRegistry() { + static auto *const stub = new TfrtJitRtStubRegistry; + return *stub; +} + +} // namespace + +void RegisterTfrtJitRtStub(std::unique_ptr stub) { + GetGlobalTfrtJitRtStubRegistry().Register(std::move(stub)); +} + +void RegisterJitRtDialects(mlir::DialectRegistry ®istry) { + GetGlobalTfrtJitRtStubRegistry().Get().RegisterJitRtDialects(registry); +} + +// Helper function for inserting TFRT JitRt dialect conversions. +void PopulateJitRtConversionPatterns(mlir::ConversionTarget *target, + mlir::MLIRContext *context, + mlir::RewritePatternSet *patterns, + CoreRTConverter *corert_converter) { + GetGlobalTfrtJitRtStubRegistry().Get().PopulateJitRtConversionPatterns( + target, context, patterns, corert_converter); +} + +mlir::Value CreateJitRtFallbackCompileKernel(mlir::OpBuilder &builder, + mlir::ModuleOp module, + mlir::Value chain_value) { + return GetGlobalTfrtJitRtStubRegistry() + .Get() + .CreateJitRtFallbackCompileKernel(builder, module, chain_value); +} + +void AddTfrtJitRtPasses(const TfrtPipelineOptions &options, + mlir::OpPassManager &pm) { + GetGlobalTfrtJitRtStubRegistry().Get().AddTfrtJitRtPasses(options, pm); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tfrt/transforms/tfrt_jitrt_stub.h b/tensorflow/compiler/mlir/tfrt/transforms/tfrt_jitrt_stub.h new file mode 100644 index 00000000000..d9c00c4d376 --- /dev/null +++ b/tensorflow/compiler/mlir/tfrt/transforms/tfrt_jitrt_stub.h @@ -0,0 +1,71 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_TFRT_JITRT_STUB_H_ +#define TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_TFRT_JITRT_STUB_H_ + +#include + +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/DialectRegistry.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Transforms/DialectConversion.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tfrt/transforms/corert_converter.h" +#include "tensorflow/compiler/mlir/tfrt/transforms/tfrt_pipeline_options.h" + +namespace tensorflow { + +class TfrtJitRtStub { + public: + virtual ~TfrtJitRtStub() = default; + + virtual void RegisterJitRtDialects(mlir::DialectRegistry ®istry) {} + + virtual void PopulateJitRtConversionPatterns( + mlir::ConversionTarget *target, mlir::MLIRContext *context, + mlir::RewritePatternSet *patterns, CoreRTConverter *corert_converter) {} + + virtual mlir::Value CreateJitRtFallbackCompileKernel( + mlir::OpBuilder &builder, mlir::ModuleOp module, + mlir::Value chain_value) { + return chain_value; + } + + virtual void AddTfrtJitRtPasses(const TfrtPipelineOptions &options, + mlir::OpPassManager &pm) {} +}; + +void RegisterTfrtJitRtStub(std::unique_ptr stub); + +void RegisterJitRtDialects(mlir::DialectRegistry ®istry); + +// Helper function for inserting TFRT JitRt dialect conversions. +void PopulateJitRtConversionPatterns(mlir::ConversionTarget *target, + mlir::MLIRContext *context, + mlir::RewritePatternSet *patterns, + CoreRTConverter *corert_converter); + +mlir::Value CreateJitRtFallbackCompileKernel(mlir::OpBuilder &builder, + mlir::ModuleOp module, + mlir::Value chain_value); + +void AddTfrtJitRtPasses(const TfrtPipelineOptions &options, + mlir::OpPassManager &pm); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_TFRT_JITRT_STUB_H_ diff --git a/tensorflow/compiler/mlir/tfrt/transforms/tfrt_pipeline_options.h b/tensorflow/compiler/mlir/tfrt/transforms/tfrt_pipeline_options.h index cdd221750dc..24d245b1714 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/tfrt_pipeline_options.h +++ b/tensorflow/compiler/mlir/tfrt/transforms/tfrt_pipeline_options.h @@ -122,6 +122,11 @@ struct TfrtPipelineOptions "out to run during loading."), llvm::cl::init(false)}; + Option fuse_get_resource_ops_in_hoisting{ + *this, "fuse-get-resource-ops-in-hoisting", + llvm::cl::desc("If true, get_resource_op will be fused during hoisting"), + llvm::cl::init(true)}; + Option sink_in_invariant_ops{ *this, "sink-in-invariant-ops", llvm::cl::desc("If true, sink the selected invariant ops in to the " diff --git a/tensorflow/compiler/mlir/tfrt/translate/import_model.cc b/tensorflow/compiler/mlir/tfrt/translate/import_model.cc index 2980c5486eb..1573306ba2d 100644 --- a/tensorflow/compiler/mlir/tfrt/translate/import_model.cc +++ b/tensorflow/compiler/mlir/tfrt/translate/import_model.cc @@ -16,21 +16,26 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/tfrt/translate/import_model.h" #include +#include #include #include #include -#include "absl/strings/match.h" #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/bridge.h" #include "tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h" #include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h" #include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" +#include "tensorflow/compiler/mlir/tfrt/transforms/passes.h" +#include "tensorflow/compiler/mlir/tfrt/transforms/tfrt_pipeline_options.h" #include "tensorflow/compiler/mlir/tfrt/translate/tfrt_compile_options.h" #include "tensorflow/core/common_runtime/function_body.h" #include "tensorflow/core/common_runtime/function_def_utils.h" +#include "tensorflow/tsl/platform/errors.h" #include "tfrt/bef_converter/mlir_to_bef.h" // from @tf_runtime namespace tensorflow { @@ -108,7 +113,7 @@ Status ConvertFunctionToBef( if (!expected_module.ok()) return tensorflow::errors::Internal( "Failed to convert function to mlir for function ", function_name.str(), - ". Error: ", expected_module.status().error_message()); + ". Error: ", expected_module.status().message()); auto module = std::move(expected_module).value(); @@ -165,8 +170,7 @@ Status ConvertTfMlirToRuntimeExecutable( } } else if (options.device_target == TfrtDeviceInfraTarget::kGpu && options.use_bridge_for_gpu) { - TF_RETURN_IF_ERROR( - mlir::TF::RunTFXLABridge(module, /*enable_logging=*/VLOG_IS_ON(1))); + TF_RETURN_IF_ERROR(mlir::TF::RunTFXLABridge(module)); // GPU XLA clusters are wrapped in functions, which could be transformed by // bridge. Hence, the MLIR functions for XLA clusters are exported and added @@ -187,44 +191,13 @@ Status ConvertTfMlirToRuntimeExecutable( // Lower MLIR TF Dialect to MLIR TFRT CoreRT dialect. mlir::PassManager pm(module.getContext()); - tensorflow::TfrtPipelineOptions pass_options; - if (!options.default_device.empty()) { - pass_options.default_device = options.default_device; - } - if (!options.force_data_format.empty()) { - pass_options.force_data_format = options.force_data_format; - } - - // TODO(b/187991150): Consider only decomposing read-only resource variable - // ops. 
- pass_options.decompose_resource_ops = options.decompose_resource_ops; - pass_options.enable_optimizer = options.enable_optimizer; - pass_options.target_tpurt = - (options.device_target == TfrtDeviceInfraTarget::kTpurt); - pass_options.target_gpu = - (options.device_target == TfrtDeviceInfraTarget::kGpu); - pass_options.use_bridge_for_gpu = options.use_bridge_for_gpu; - pass_options.tpu_fuse_ops = options.tpu_fuse_ops; - pass_options.use_tpu_host_allocator_for_inputs = - options.use_tpu_host_allocator_for_inputs; - pass_options.tpu_allow_unpadded_batch = options.tpu_allow_unpadded_batch; - pass_options.sink_in_invariant_ops = options.sink_in_invariant_ops; - pass_options.hoist_invariant_ops = options.hoist_invariant_ops; - pass_options.func_use_fallback_tensor = true; - pass_options.enable_while_parallel_iterations = - options.enable_while_parallel_iterations; - pass_options.auto_fusion_oplist = options.auto_fusion_oplist; - pass_options.auto_fusion_min_cluster_size = - options.auto_fusion_min_cluster_size; - pass_options.cost_threshold = options.cost_threshold; - pass_options.upper_cost_threshold = options.upper_cost_threshold; - pass_options.merge_inter_dependent_streams = - options.merge_inter_dependent_streams; + auto pipeline_options = GetTfrtPipelineOptions(options); TF_RETURN_IF_ERROR( - tensorflow::CreateTFExecutorToTFPipeline(pm, pass_options)); + tensorflow::CreateTFExecutorToTFPreInvariantOptimizationPipeline( + pm, *pipeline_options)); - auto status = emit_executable(pm, module, pass_options); + auto status = emit_executable(pm, module, *pipeline_options); if (VLOG_IS_ON(1)) { tensorflow::DumpMlirOpToFile("tfrt_dialect", module); @@ -241,11 +214,17 @@ Status ConvertTfMlirToBef(const TfrtCompileOptions& options, [bef_buffer](mlir::PassManager& pm, mlir::ModuleOp module, const tensorflow::TfrtPipelineOptions& options) { mlir::StatusScopedDiagnosticHandler diag_handler(module.getContext()); + tensorflow::CreateTFExecutorToTFInvariantOptimizationPipelineHelper( + pm, options); tensorflow::CreateTfToTfrtPipeline(pm, options); - if (mlir::failed(pm.run(module))) + if (mlir::failed(pm.run(module))) { + if (VLOG_IS_ON(1)) { + tensorflow::DumpMlirOpToFile("tf_to_corert_failure", module); + } return diag_handler.Combine(tensorflow::errors::Internal( "failed to lower TF Dialect to CoreRT dialect.")); + } *bef_buffer = tfrt::ConvertMLIRToBEF(module, /*disable_optional_sections=*/true); @@ -259,4 +238,45 @@ Status ConvertTfMlirToBef(const TfrtCompileOptions& options, fallback_state); } +std::unique_ptr GetTfrtPipelineOptions( + const TfrtCompileOptions& options) { + auto pipeline_options = std::make_unique(); + if (!options.default_device.empty()) { + pipeline_options->default_device = options.default_device; + } + if (!options.force_data_format.empty()) { + pipeline_options->force_data_format = options.force_data_format; + } + + // TODO(b/187991150): Consider only decomposing read-only resource variable + // ops. 
+ pipeline_options->decompose_resource_ops = options.decompose_resource_ops; + pipeline_options->enable_optimizer = options.enable_optimizer; + pipeline_options->target_tpurt = + (options.device_target == TfrtDeviceInfraTarget::kTpurt); + pipeline_options->target_gpu = + (options.device_target == TfrtDeviceInfraTarget::kGpu); + pipeline_options->use_bridge_for_gpu = options.use_bridge_for_gpu; + pipeline_options->tpu_fuse_ops = options.tpu_fuse_ops; + pipeline_options->use_tpu_host_allocator_for_inputs = + options.use_tpu_host_allocator_for_inputs; + pipeline_options->tpu_allow_unpadded_batch = options.tpu_allow_unpadded_batch; + pipeline_options->sink_in_invariant_ops = options.sink_in_invariant_ops; + pipeline_options->hoist_invariant_ops = options.hoist_invariant_ops; + pipeline_options->fuse_get_resource_ops_in_hoisting = + options.fuse_get_resource_ops_in_hoisting; + pipeline_options->func_use_fallback_tensor = true; + pipeline_options->enable_while_parallel_iterations = + options.enable_while_parallel_iterations; + pipeline_options->auto_fusion_oplist = options.auto_fusion_oplist; + pipeline_options->auto_fusion_min_cluster_size = + options.auto_fusion_min_cluster_size; + pipeline_options->cost_threshold = options.cost_threshold; + pipeline_options->upper_cost_threshold = options.upper_cost_threshold; + pipeline_options->merge_inter_dependent_streams = + options.merge_inter_dependent_streams; + + return pipeline_options; +} + } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tfrt/translate/import_model.h b/tensorflow/compiler/mlir/tfrt/translate/import_model.h index 9df6ae57137..2b2dc6cc987 100644 --- a/tensorflow/compiler/mlir/tfrt/translate/import_model.h +++ b/tensorflow/compiler/mlir/tfrt/translate/import_model.h @@ -16,15 +16,15 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_TFRT_TRANSLATE_IMPORT_MODEL_H_ #define TENSORFLOW_COMPILER_MLIR_TFRT_TRANSLATE_IMPORT_MODEL_H_ +#include #include -#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project -#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" #include "tensorflow/compiler/mlir/tfrt/function/function.h" #include "tensorflow/compiler/mlir/tfrt/transforms/passes.h" +#include "tensorflow/compiler/mlir/tfrt/transforms/tfrt_pipeline_options.h" #include "tensorflow/compiler/mlir/tfrt/translate/tfrt_compile_options.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/platform/status.h" @@ -61,6 +61,9 @@ Status ConvertTfMlirToRuntimeExecutable( emit_executable, tfrt_stub::FallbackState* fallback_state = nullptr); +std::unique_ptr GetTfrtPipelineOptions( + const TfrtCompileOptions& options); + } // namespace tensorflow #endif // TENSORFLOW_COMPILER_MLIR_TFRT_TRANSLATE_IMPORT_MODEL_H_ diff --git a/tensorflow/compiler/mlir/tfrt/translate/tfrt_compile_options.h b/tensorflow/compiler/mlir/tfrt/translate/tfrt_compile_options.h index 1f6bcb54baf..e451cf737f3 100644 --- a/tensorflow/compiler/mlir/tfrt/translate/tfrt_compile_options.h +++ b/tensorflow/compiler/mlir/tfrt/translate/tfrt_compile_options.h @@ -21,6 +21,8 @@ limitations under the License. #include #include +#include "tensorflow/core/protobuf/config.pb.h" + namespace tensorflow { enum class TfrtDeviceInfraTarget { @@ -47,6 +49,10 @@ struct TfrtCompileOptions { // If true, run grappler passes before compiling. 
bool enable_grappler = true; + // Graph rewrite options that will be applied on GraphDef before converting to + // MLIR. + GraphOptions graph_options; + // Force data format for all layout sensitive operations, eg. setting it to // "NHWC" will changes all data format in the graph to "NHWC" by inserting // or removing related tf.Transpose op. Currently the supported formats are @@ -97,6 +103,9 @@ struct TfrtCompileOptions { // supposed to be turned on by default. bool hoist_invariant_ops = false; + // If true, get_resource_op will be fused during hoisting. + bool fuse_get_resource_ops_in_hoisting = true; + // If true, the compiler will try to sink in the invariant ops (e.g. const // ops, var handle ops, etc.) to the nested function (e.g. batch function) to // facilitate invariant ops hoisting. diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/BUILD b/tensorflow/compiler/mlir/tools/kernel_gen/BUILD index c815a19f411..145cce5ac6b 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/BUILD +++ b/tensorflow/compiler/mlir/tools/kernel_gen/BUILD @@ -45,7 +45,7 @@ cc_library( deps = [ "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:dump_mlir_util", - "//tensorflow/compiler/mlir/tf2xla:xla_legalize_tf_no_fallback", + "//tensorflow/compiler/mlir/tf2xla/transforms:xla_legalize_tf_no_fallback", "//tensorflow/compiler/mlir/tools/kernel_gen/transforms:bufferize", "//tensorflow/compiler/mlir/tools/kernel_gen/transforms:gpu_passes", # fixdeps: keep "//tensorflow/compiler/mlir/tools/kernel_gen/transforms:passes", @@ -60,11 +60,13 @@ cc_library( "@llvm-project//mlir:ArithDialect", "@llvm-project//mlir:ArithTransforms", "@llvm-project//mlir:BufferizationTransforms", + "@llvm-project//mlir:BuiltinToLLVMIRTranslation", "@llvm-project//mlir:ComplexToStandard", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:FuncToLLVM", "@llvm-project//mlir:GPUDialect", "@llvm-project//mlir:GPUToGPURuntimeTransforms", + "@llvm-project//mlir:GPUToLLVMIRTranslation", "@llvm-project//mlir:GPUToNVVMTransforms", "@llvm-project//mlir:GPUTransforms", "@llvm-project//mlir:IR", @@ -96,24 +98,24 @@ tf_cc_binary( ], deps = [ ":kernel_creator", - "@llvm-project//llvm:TargetParser", "//tensorflow/compiler/mlir:init_mlir", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/core:lib", "@com_google_absl//absl/strings", - "@llvm-project//llvm:Analysis", "@llvm-project//llvm:ARMCodeGen", # fixdeps: keep + "@llvm-project//llvm:Analysis", "@llvm-project//llvm:CodeGen", "@llvm-project//llvm:Core", "@llvm-project//llvm:MC", "@llvm-project//llvm:PowerPCCodeGen", # fixdeps: keep "@llvm-project//llvm:Support", "@llvm-project//llvm:Target", + "@llvm-project//llvm:TargetParser", "@llvm-project//llvm:X86CodeGen", # fixdeps: keep "@llvm-project//llvm:X86Disassembler", # fixdeps: keep "@llvm-project//mlir:ExecutionEngineUtils", - "@llvm-project//mlir:Pass", "@llvm-project//mlir:LLVMToLLVMIRTranslation", + "@llvm-project//mlir:Pass", "@llvm-project//mlir:ToLLVMIRTranslation", ] + if_llvm_system_z_available([ "@llvm-project//llvm:SystemZCodeGen", # fixdeps: keep @@ -188,14 +190,14 @@ cc_library( "-DTENSORFLOW_USE_ROCM=1", ]), deps = [ - "@llvm-project//mlir:mlir_runner_utils", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/strings", + "//tensorflow/compiler/xla/stream_executor:stream_executor_headers", "//tensorflow/core:framework", "//tensorflow/core/platform:logging", "//tensorflow/core/platform:mutex", - "//tensorflow/compiler/xla/stream_executor:stream_executor_headers", 
"//tensorflow/tsl/platform:hash", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/strings", + "@llvm-project//mlir:mlir_runner_utils", ] + if_cuda_is_configured([ "@local_config_cuda//cuda:cuda_headers", "//tensorflow/compiler/xla/stream_executor/cuda:stream_executor_cuda", diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/ir/BUILD b/tensorflow/compiler/mlir/tools/kernel_gen/ir/BUILD index f040ca2af3b..29ae6752cd2 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/ir/BUILD +++ b/tensorflow/compiler/mlir/tools/kernel_gen/ir/BUILD @@ -83,6 +83,7 @@ cc_library( ":tf_framework_ops_inc_gen", ":tf_status_inc_gen", "//tensorflow/core/protobuf:error_codes_proto_impl_cc", + "@com_google_absl//absl/status", "@llvm-project//mlir:AllocationOpInterface", "@llvm-project//mlir:BufferizationDialect", "@llvm-project//mlir:ControlFlowInterfaces", diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.cc b/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.cc index ac791ace79c..48e288eb48d 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.cc @@ -120,43 +120,43 @@ std::optional JITExecuteOp::buildClone(OpBuilder &builder, Value alloc) { .getResult(); } -::tensorflow::error::Code ConvertAttrToEnumValue(ErrorCode error_code) { +absl::StatusCode ConvertAttrToEnumValue(ErrorCode error_code) { using ::tensorflow::error::Code; switch (error_code) { case ErrorCode::OK: - return Code::OK; + return absl::StatusCode::kOk; case ErrorCode::CANCELLED: - return Code::CANCELLED; + return absl::StatusCode::kCancelled; case ErrorCode::UNKNOWN: - return Code::UNKNOWN; + return absl::StatusCode::kUnknown; case ErrorCode::INVALID_ARGUMENT: - return Code::INVALID_ARGUMENT; + return absl::StatusCode::kInvalidArgument; case ErrorCode::DEADLINE_EXCEEDED: - return Code::DEADLINE_EXCEEDED; + return absl::StatusCode::kDeadlineExceeded; case ErrorCode::NOT_FOUND: - return Code::NOT_FOUND; + return absl::StatusCode::kNotFound; case ErrorCode::ALREADY_EXISTS: - return Code::ALREADY_EXISTS; + return absl::StatusCode::kAlreadyExists; case ErrorCode::PERMISSION_DENIED: - return Code::PERMISSION_DENIED; + return absl::StatusCode::kPermissionDenied; case ErrorCode::UNAUTHENTICATED: - return Code::UNAUTHENTICATED; + return absl::StatusCode::kUnauthenticated; case ErrorCode::RESOURCE_EXHAUSTED: - return Code::RESOURCE_EXHAUSTED; + return absl::StatusCode::kResourceExhausted; case ErrorCode::FAILED_PRECONDITION: - return Code::FAILED_PRECONDITION; + return absl::StatusCode::kFailedPrecondition; case ErrorCode::ABORTED: - return Code::ABORTED; + return absl::StatusCode::kAborted; case ErrorCode::OUT_OF_RANGE: - return Code::OUT_OF_RANGE; + return absl::StatusCode::kOutOfRange; case ErrorCode::UNIMPLEMENTED: - return Code::UNIMPLEMENTED; + return absl::StatusCode::kUnimplemented; case ErrorCode::INTERNAL: - return Code::INTERNAL; + return absl::StatusCode::kInternal; case ErrorCode::UNAVAILABLE: - return Code::UNAVAILABLE; + return absl::StatusCode::kUnavailable; case ErrorCode::DATA_LOSS: - return Code::DATA_LOSS; + return absl::StatusCode::kDataLoss; } } diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h b/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h index 6f05c194093..c5f011f25cf 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h +++ b/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h @@ -18,6 +18,7 
@@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_TOOLS_KERNEL_GEN_IR_TF_FRAMEWORK_OPS_H_ #define TENSORFLOW_COMPILER_MLIR_TOOLS_KERNEL_GEN_IR_TF_FRAMEWORK_OPS_H_ +#include "absl/status/status.h" #include "mlir/Dialect/Bufferization/IR/AllocationOpInterface.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project @@ -48,7 +49,7 @@ class JITCallableType using Base::Base; }; -::tensorflow::error::Code ConvertAttrToEnumValue(ErrorCode error_code); +absl::StatusCode ConvertAttrToEnumValue(ErrorCode error_code); } // namespace tf_framework } // namespace kernel_gen diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.cc b/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.cc index 2f946e28bd8..0a1977380da 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.cc @@ -46,6 +46,8 @@ limitations under the License. #include "mlir/Parser/Parser.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h" // from @llvm-project +#include "mlir/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.h" // from @llvm-project #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" // from @llvm-project #include "mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h" // from @llvm-project #include "mlir/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.h" // from @llvm-project @@ -409,6 +411,8 @@ StatusOr> SetupContextAndParseModule( mlir::DialectRegistry registry; mlir::RegisterAllTensorFlowDialects(registry); registry.insert(); + mlir::registerBuiltinDialectTranslation(registry); + mlir::registerGPUDialectTranslation(registry); mlir::registerLLVMDialectTranslation(registry); mlir::registerNVVMDialectTranslation(registry); mlir::registerROCDLDialectTranslation(registry); diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_framework_legalize_to_llvm.mlir b/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_framework_legalize_to_llvm.mlir index bd5ef12ba3f..e2a5601fc53 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_framework_legalize_to_llvm.mlir +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_framework_legalize_to_llvm.mlir @@ -1,10 +1,10 @@ // RUN: kernel-gen-opt %s -tf-kernel-to-llvm -split-input-file | FileCheck %s // CHECK: llvm.func @_mlir_ciface_tf_alloc -// CHECK-SAME: (!llvm.ptr, i64, i64, i32, i32, !llvm.ptr) -> !llvm.ptr +// CHECK-SAME: (!llvm.ptr, i64, i64, i32, i32, !llvm.ptr) -> !llvm.ptr // CHECK-LABEL: llvm.func @alloc( -// CHECK-SAME: [[TF_CTX:%.*]]: !llvm.ptr, +// CHECK-SAME: [[TF_CTX:%.*]]: !llvm.ptr, // CHECK-SAME: [[SIZE_0:%.*]]: i64, // CHECK-SAME: [[SIZE_2:%.*]]: i64) -> [[DESC_TY:!.*]] { func.func @alloc(%ctx: !tf_framework.op_kernel_context, @@ -18,16 +18,16 @@ func.func @alloc(%ctx: !tf_framework.op_kernel_context, // CHECK: [[NUM_ELEMS:%.*]] = llvm.mul [[NUM_ELEM_0]], [[SIZE_2]] : i64 // Compute the size of an individual element. 
-// CHECK: [[NULL:%.*]] = llvm.mlir.null : !llvm.ptr +// CHECK: [[NULL:%.*]] = llvm.mlir.null : !llvm.ptr // CHECK: [[GEP:%.*]] = llvm.getelementptr [[NULL]]{{\[}}1] -// CHECK-SAME: (!llvm.ptr) -> !llvm.ptr +// CHECK-SAME: (!llvm.ptr) -> !llvm.ptr, f32 // CHECK: [[SIZE_OF_FLOAT:%.*]] = llvm.ptrtoint [[GEP]] -// CHECK-SAME: !llvm.ptr to i64 +// CHECK-SAME: !llvm.ptr to i64 // Compute output index (-1) and candidate indices (0, NULL). // CHECK: [[OUTPUT_INDEX:%.*]] = llvm.mlir.constant(-1 : i32) : i32 // CHECK-NEXT: [[NUM_CANDIDATES:%.*]] = llvm.mlir.constant(0 : i32) : i32 -// CHECK-NEXT: [[CANDIDATES_PTR:%.*]] = llvm.mlir.null : !llvm.ptr +// CHECK-NEXT: [[CANDIDATES_PTR:%.*]] = llvm.mlir.null : !llvm.ptr // Allocate memory. // CHECK: [[BYTES_PTR:%.*]] = llvm.call @{{.*}}([[TF_CTX]], [[NUM_ELEMS]], @@ -38,10 +38,8 @@ func.func @alloc(%ctx: !tf_framework.op_kernel_context, // CHECK: [[DESC_0:%.*]] = llvm.mlir.undef : [[DESC_TY]] // Set pointers and offset. -// CHECK: [[FLOAT_PTR:%.*]] = llvm.bitcast [[BYTES_PTR]] -// CHECK-SAME: !llvm.ptr to !llvm.ptr -// CHECK: [[DESC_1:%.*]] = llvm.insertvalue [[FLOAT_PTR]], [[DESC_0]][0] -// CHECK: [[DESC_2:%.*]] = llvm.insertvalue [[FLOAT_PTR]], [[DESC_1]][1] +// CHECK: [[DESC_1:%.*]] = llvm.insertvalue [[BYTES_PTR]], [[DESC_0]][0] +// CHECK: [[DESC_2:%.*]] = llvm.insertvalue [[BYTES_PTR]], [[DESC_1]][1] // CHECK: [[C0:%.*]] = llvm.mlir.constant(0 : index) : i64 // CHECK: [[DESC_3:%.*]] = llvm.insertvalue [[C0]], [[DESC_2]][2] : [[DESC_TY]] @@ -59,10 +57,10 @@ func.func @alloc(%ctx: !tf_framework.op_kernel_context, // ----- -// CHECK: llvm.func @_mlir_ciface_tf_dealloc(!llvm.ptr, !llvm.ptr) +// CHECK: llvm.func @_mlir_ciface_tf_dealloc(!llvm.ptr, !llvm.ptr) // CHECK-LABEL: llvm.func @dealloc( -// CHECK-SAME: [[TF_CTX:%.*]]: !llvm.ptr, +// CHECK-SAME: [[TF_CTX:%[a-z0-9]*]]: !llvm.ptr func.func @dealloc(%ctx: !tf_framework.op_kernel_context, %memref : memref) { tf_framework.dealloc(%ctx, %memref) : memref @@ -71,29 +69,27 @@ func.func @dealloc(%ctx: !tf_framework.op_kernel_context, // Extract allocated ptr from the memref descriptor. // CHECK: %{{.*}} = llvm.mlir.undef : [[DESC_TY:!.*]] // CHECK: [[FLOAT_PTR:%.*]] = llvm.extractvalue %{{.*}}[0] : [[DESC_TY]] -// CHECK-NEXT: [[VOID_PTR:%.*]] = llvm.bitcast [[FLOAT_PTR]] -// CHECK-SAME: !llvm.ptr to !llvm.ptr // Deallocate. 
// CHECK: llvm.call @_mlir_ciface_tf_dealloc( -// CHECK-SAME: [[TF_CTX]], [[VOID_PTR]]) : (!llvm.ptr, !llvm.ptr) -> () +// CHECK-SAME: [[TF_CTX]], [[FLOAT_PTR]]) : (!llvm.ptr, !llvm.ptr) -> () // ----- -// CHECK-LABEL: llvm.func @_mlir_ciface_tf_report_error(!llvm.ptr, i32, !llvm.ptr) +// CHECK-LABEL: llvm.func @_mlir_ciface_tf_report_error(!llvm.ptr, i32, !llvm.ptr) // CHECK: llvm.mlir.global internal constant [[MSG_CONST:@error_message_[0-9]+]]("Everything is awesome\00") func.func @report_error(%ctx: !tf_framework.op_kernel_context) { tf_framework.report_error %ctx, "INVALID_ARGUMENT", "Everything is awesome" loc(unknown) func.return } -// CHECK: llvm.func @report_error([[CTX:%.*]]: !llvm.ptr) +// CHECK: llvm.func @report_error([[CTX:%.*]]: !llvm.ptr) // CHECK-NEXT: [[ADDR:%.*]] = llvm.mlir.addressof [[MSG_CONST]] // CHECK: [[MSG:%.*]] = llvm.getelementptr [[ADDR]] // CHECK: [[CODE:%.*]] = llvm.mlir.constant({{.*}}) : i32 // CHECK: llvm.call @{{.*}}_tf_report_error([[CTX]], [[CODE]], [[MSG]]) -// ---- +// ----- // CHECK-LABEL: llvm.func @unranked_null_memref() func.func @unranked_null_memref() { @@ -101,12 +97,12 @@ func.func @unranked_null_memref() { func.return } // CHECK: [[C0:%.*]] = llvm.mlir.constant(0 : index) : i64 -// CHECK: [[DESC_0:%.*]] = llvm.mlir.undef : !llvm.struct<(i64, ptr)> +// CHECK: [[DESC_0:%.*]] = llvm.mlir.undef : !llvm.struct<(i64, ptr)> // CHECK: [[DESC_1:%.*]] = llvm.insertvalue [[C0]], [[DESC_0]][0] // CHECK: [[PTR:%.*]] = llvm.alloca {{.*}} x i8 // CHECK: [[DESC_2:%.*]] = llvm.insertvalue [[PTR]], [[DESC_1]][1] -// ---- +// ----- // CHECK-LABEL: llvm.func @ranked_null_memref() func.func @ranked_null_memref() { @@ -119,9 +115,9 @@ func.func @ranked_null_memref() { // CHECK-NEXT: %[[C1_:.*]] = llvm.mlir.constant(1 : index) : i64 // CHECK: llvm.mlir.null -// CHECK: %[[NULL:.*]] = llvm.mlir.null : !llvm.ptr +// CHECK: %[[NULL:.*]] = llvm.mlir.null : !llvm.ptr // CHECK-NEXT: %[[DESC_0:.*]] = llvm.mlir.undef : -// CHECK-SAME: !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK-SAME: !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK-NEXT: %[[DESC_1:.*]] = llvm.insertvalue %[[NULL]], %[[DESC_0]][0] // CHECK-NEXT: %[[DESC_2:.*]] = llvm.insertvalue %[[NULL]], %[[DESC_1]][1] // CHECK-NEXT: %[[DESC_3:.*]] = llvm.insertvalue %[[C0]], %[[DESC_2]][2] @@ -130,7 +126,7 @@ func.func @ranked_null_memref() { // CHECK-NEXT: %[[DESC_6:.*]] = llvm.insertvalue %[[C1]], %[[DESC_5]][3, 1] // CHECK-NEXT: %[[DESC_7:.*]] = llvm.insertvalue %[[C1_]], %[[DESC_6]][4, 1] -// ---- +// ----- // CHECK-LABEL: llvm.func @is_valid_memref func.func @is_valid_memref(%buf: memref) -> i1 { @@ -146,19 +142,18 @@ func.func @is_valid_memref(%buf: memref) -> i1 { // CHECK-NEXT: %[[IS_EMPTY_:.*]] = llvm.or %[[IS_EMPTY]], %[[IS_ZERO]] : i1 // CHECK-NEXT: %[[PTR_F32:.*]] = llvm.extractvalue %[[MEMREF]][0] -// CHECK-NEXT: %[[VOID_PTR:.*]] = llvm.bitcast %[[PTR_F32]] : !llvm.ptr to !llvm.ptr -// CHECK-NEXT: %[[NULL_PTR:.*]] = llvm.mlir.null : !llvm.ptr -// CHECK-NEXT: %[[NOT_NULL:.*]] = llvm.icmp "ne" %[[VOID_PTR]], %[[NULL_PTR]] +// CHECK-NEXT: %[[NULL_PTR:.*]] = llvm.mlir.null : !llvm.ptr +// CHECK-NEXT: %[[NOT_NULL:.*]] = llvm.icmp "ne" %[[PTR_F32]], %[[NULL_PTR]] // CHECK-NEXT: %[[PRED:.*]] = llvm.or %[[NOT_NULL]], %[[IS_EMPTY_]] : i1 // CHECK-NEXT: llvm.return %[[PRED]] // ----- -// CHECK-LABEL: llvm.func @_mlir_ciface_tf_jit_compile(!llvm.ptr, !llvm.ptr, i64, !llvm.ptr, i64, !llvm.ptr, i64, i1, i1, i1) -> !llvm.ptr +// CHECK-LABEL: llvm.func 
@_mlir_ciface_tf_jit_compile(!llvm.ptr, !llvm.ptr, i64, !llvm.ptr, i64, !llvm.ptr, i64, i1, i1, i1) -> !llvm.ptr // CHECK: llvm.mlir.global internal constant @[[CODE:jit_module_code_[0-9]+]]("placeholder\00") -// CHECK: @jit_compile_from_str(%[[CTX:.*]]: !llvm.ptr) +// CHECK: @jit_compile_from_str(%[[CTX:.*]]: !llvm.ptr) func.func @jit_compile_from_str(%ctx: !tf_framework.op_kernel_context) -> !tf_framework.jit_callable { // CHECK: %[[ADDR:.*]] = llvm.mlir.addressof @[[CODE]] @@ -205,10 +200,10 @@ func.func @jit_compile_from_str(%ctx: !tf_framework.op_kernel_context) // ----- -// CHECK-LABEL: llvm.func @_mlir_ciface_tf_jit_execute(!llvm.ptr, !llvm.ptr, !llvm.ptr, i64, !llvm.ptr) +// CHECK-LABEL: llvm.func @_mlir_ciface_tf_jit_execute(!llvm.ptr, !llvm.ptr, !llvm.ptr, i64, !llvm.ptr) // CHECK: @jit_execute -// CHECK-SAME: (%[[CTX:.*]]: !llvm.ptr, %[[CALLABLE:.*]]: !llvm.ptr, %[[RANK:.*]]: i64, %[[ARG_DESCR:.*]]: !llvm.ptr) +// CHECK-SAME: (%[[CTX:.*]]: !llvm.ptr, %[[CALLABLE:.*]]: !llvm.ptr, %[[RANK:.*]]: i64, %[[ARG_DESCR:.*]]: !llvm.ptr) func.func @jit_execute(%ctx: !tf_framework.op_kernel_context, %callable : !tf_framework.jit_callable, %arg : memref<*xf32>) -> memref<*xf32> { @@ -216,24 +211,21 @@ func.func @jit_execute(%ctx: !tf_framework.op_kernel_context, // CHECK: %[[T1:.*]] = llvm.insertvalue %[[RANK]], %[[T0]][0] // CHECK: %[[ARG:.*]] = llvm.insertvalue %[[ARG_DESCR]], %[[T1]][1] // CHECK: %[[C1:.*]] = llvm.mlir.constant(1 : i64) - // CHECK: %[[RESULT_PTR:.*]] = llvm.alloca %[[C1]] x !llvm.struct<(i64, ptr)> - // CHECK: %[[RESULT_PTR_:.*]] = llvm.bitcast %[[RESULT_PTR]] - + // CHECK: %[[RESULT_PTR:.*]] = llvm.alloca %[[C1]] x !llvm.struct<(i64, ptr)> + // Copy argument(s) to stack-allocated buffer. // CHECK: %[[NUM_ARGS:.*]] = llvm.mlir.constant(1 : i64) - // CHECK: %[[ARGS_PTR:.*]] = llvm.alloca %[[NUM_ARGS]] x !llvm.struct<(i64, ptr)> + // CHECK: %[[ARGS_PTR:.*]] = llvm.alloca %[[NUM_ARGS]] x !llvm.struct<(i64, ptr)> // CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : i64) // CHECK: %[[ARGS0_PTR:.*]] = llvm.getelementptr %[[ARGS_PTR]][%[[C0]]] // CHECK: llvm.store %[[ARG]], %[[ARGS0_PTR]] - // CHECK: %[[ARGS_PTR_:.*]] = llvm.bitcast %[[ARGS_PTR]] - // CHECK: llvm.call @_mlir_ciface_tf_jit_execute(%[[CTX]], %[[CALLABLE]], %[[RESULT_PTR_]], %[[NUM_ARGS]], %[[ARGS_PTR_]]) + // CHECK: llvm.call @_mlir_ciface_tf_jit_execute(%[[CTX]], %[[CALLABLE]], %[[RESULT_PTR]], %[[NUM_ARGS]], %[[ARGS_PTR]]) // CHECK: %[[RESULT:.*]] = llvm.load %[[RESULT_PTR]] // Copy unranked memref descriptor to stack-allocated memory. // ... - // CHECK: %[[RESULT_DESCR_SIZE:.*]] = llvm.add %16, %20 // CHECK: %[[FALSE:.*]] = llvm.mlir.constant(false) - // CHECK: %[[STACK_RESULT_DESCR:.*]] = llvm.alloca %[[RESULT_DESCR_SIZE]] x i8 + // CHECK: %[[STACK_RESULT_DESCR:.*]] = llvm.alloca %[[RESULT_DESCR_SIZE:[0-9]*]] x i8 // CHECK: %[[RESULT_DESCR:.*]] = llvm.extractvalue %[[RESULT]][1] // CHECK: "llvm.intr.memcpy"(%[[STACK_RESULT_DESCR]], %[[RESULT_DESCR]], %[[RESULT_DESCR_SIZE]], %[[FALSE]]) // CHECK: llvm.call @free(%[[RESULT_DESCR]]) @@ -244,9 +236,8 @@ func.func @jit_execute(%ctx: !tf_framework.op_kernel_context, // Copy unranked memref descriptor to heap-allocated memory for return. // ... 
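For orientation only (not introduced by this patch): the !llvm.struct<(i64, ptr)> values that the jit_execute checks above build, store, and memcpy follow MLIR's standard unranked-memref C ABI. A minimal host-side sketch of that layout, with illustrative field names:

#include <cstdint>

// Hedged sketch of the unranked-memref descriptor the CHECK lines above refer to:
// a rank plus a pointer to the ranked descriptor (allocated/aligned pointers,
// offset, sizes, strides). Field names are illustrative, not from the patch.
struct UnrankedMemRefDescriptor {
  int64_t rank;      // number of dimensions of the ranked payload
  void *descriptor;  // pointer to the ranked descriptor blob
};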
- // CHECK: %[[RESULT_DESCR_SIZE:.*]] = llvm.add %33, %37 // CHECK: %[[FALSE:.*]] = llvm.mlir.constant(false) - // CHECK: %[[HEAP_RESULT_DESCR:.*]] = llvm.call @malloc(%[[RESULT_DESCR_SIZE]]) + // CHECK: %[[HEAP_RESULT_DESCR:.*]] = llvm.call @malloc(%[[RESULT_DESCR_SIZE:[0-9]*]]) // CHECK: %[[STACK_RESULT_DESCR:.*]] = llvm.extractvalue %[[RESULT]][1] // CHECK: "llvm.intr.memcpy"(%[[HEAP_RESULT_DESCR]], %[[STACK_RESULT_DESCR]], %[[RESULT_DESCR_SIZE]], %[[FALSE]]) // CHECK: %[[T0:.*]] = llvm.mlir.undef diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_kernel_gpu_launch_to_llvm.mlir b/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_kernel_gpu_launch_to_llvm.mlir index cf5b8f9620f..cf26ac7d6a1 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_kernel_gpu_launch_to_llvm.mlir +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_kernel_gpu_launch_to_llvm.mlir @@ -10,20 +10,20 @@ gpu.module @kernel_module attributes {gpu.binary_blob = "BLOB!"} { } } -// CHECK: llvm.func @_mlir_ciface_tf_launch_kernel(!llvm.ptr, !llvm.ptr, !llvm.ptr, i64, i64, i64, i64, i64, i64, !llvm.ptr>) +// CHECK: llvm.func @_mlir_ciface_tf_launch_kernel(!llvm.ptr, !llvm.ptr, !llvm.ptr, i64, i64, i64, i64, i64, i64, !llvm.ptr) // CHECK-DAG: llvm.mlir.global internal constant @kernel_module_the_kernel_kernel_name("the_kernel\00") // CHECK-DAG: llvm.mlir.global internal constant @kernel_module_blob("BLOB!") // CHECK-LABEL: llvm.func @launch -// CHECK-SAME: (%[[CTX:.*]]: !llvm.ptr, %{{.*}}: !llvm.ptr, %arg2: !llvm.ptr, %arg3: i64, %arg4: i64, %arg5: i64, %arg6: i64, %arg7: i64 +// CHECK-SAME: (%[[CTX:.*]]: !llvm.ptr, %{{.*}}: !llvm.ptr, %{{.*}}: !llvm.ptr, %{{.*}}: i64, %{{.*}}: i64, %{{.*}}: i64, %{{.*}}: i64, %{{.*}}: i64 func.func @launch(%ctx: !tf_framework.op_kernel_context, %memref: memref) { // CHECK: %[[C1:.*]] = llvm.mlir.constant(1 : index) : i64 - // CHECK: %[[BLOB:.*]] = llvm.mlir.addressof @kernel_module_blob : !llvm.ptr> - // CHECK: %[[BLOB_PTR:.*]] = llvm.getelementptr %[[BLOB]][0, 0] : (!llvm.ptr>) -> !llvm.ptr - // CHECK: %[[NAME:.*]] = llvm.mlir.addressof @kernel_module_the_kernel_kernel_name : !llvm.ptr> - // CHECK: %[[NAME_PTR:.*]] = llvm.getelementptr %[[NAME]][0, 0] : (!llvm.ptr>) -> !llvm.ptr + // CHECK: %[[BLOB:.*]] = llvm.mlir.addressof @kernel_module_blob : !llvm.ptr + // CHECK: %[[BLOB_PTR:.*]] = llvm.getelementptr %[[BLOB]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<5 x i8> + // CHECK: %[[NAME:.*]] = llvm.mlir.addressof @kernel_module_the_kernel_kernel_name : !llvm.ptr + // CHECK: %[[NAME_PTR:.*]] = llvm.getelementptr %[[NAME]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<11 x i8> // CHECK: %[[C7:.*]] = llvm.mlir.constant(7 : i32) : i32 - // CHECK: %[[ARGS:.*]] = llvm.alloca %22 x !llvm.ptr : (i32) -> !llvm.ptr> + // CHECK: %[[ARGS:.*]] = llvm.alloca %22 x !llvm.ptr : (i32) -> !llvm.ptr // CHECK: llvm.call @_mlir_ciface_tf_launch_kernel(%[[CTX]], %[[BLOB_PTR]], %[[NAME_PTR]], %[[C1]], %[[C1]], %[[C1]], %[[C1]], %[[C1]], %[[C1]], %[[ARGS]]) %c1 = arith.constant 1 : index gpu.launch_func @kernel_module::@the_kernel diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_to_jit_invocations.mlir b/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_to_jit_invocations.mlir index 79b1ca008b9..2d8ef4cd763 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_to_jit_invocations.mlir +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_to_jit_invocations.mlir @@ -88,3 +88,40 @@ func.func @binary_sub(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>) -> 
tensor<*x // CHECK-SAME: } // CHECK: %[[RES:.*]] = tf_framework.jit_execute %[[CALLABLE]](%[[ARG0]], %[[ARG1]]) // CHECK: return %[[RES]] + +// CHECK-JFLT-LABEL: @binary_sub +// CHECK-JFLT: %[[ARG0:.*]]: tensor<*xf32>, %[[ARG1:.*]]: tensor<*xf32> +// CHECK-JFLT: %[[LIMIT:.*]] = arith.constant 4294967296 +// CHECK-JFLT: %[[SHAPE1:.*]] = shape.shape_of %[[ARG0]] : tensor<*xf32> -> tensor +// CHECK-JFLT: %[[ELEMENTCOUNT1:.*]] = shape.num_elements %[[SHAPE1]] : tensor -> index +// CHECK-JFLT: %[[COMP1:.*]] = arith.cmpi sgt, %[[ELEMENTCOUNT1]], %[[LIMIT]] : index +// CHECK-JFLT: %[[SHAPE2:.*]] = shape.shape_of %[[ARG1]] : tensor<*xf32> -> tensor +// CHECK-JFLT: %[[ELEMENTCOUNT2:.*]] = shape.num_elements %[[SHAPE2]] : tensor -> index +// CHECK-JFLT: %[[COMP2:.*]] = arith.cmpi sgt, %[[ELEMENTCOUNT2]], %[[LIMIT]] : index +// CHECK-JFLT: %[[COMPRES:.*]] = arith.ori %[[COMP1]], %[[COMP2]] : i1 +// CHECK-JFLT: %[[IFRES:.*]] = scf.if %[[COMPRES]] -> (tensor<*xf32>) { +// CHECK-JFLT: %[[CALLABLE:.*]] = tf_framework.jit_compile_from_str +// CHECK-JFLT-SAME: " +// CHECK-JFLT-SAME: module { +// CHECK-JFLT-SAME: func @main(%[[ARG0_JIT:.*]]: tensor<*xf32>, %[[ARG1_JIT:.*]]: tensor<*xf32>) -> tensor<*xf32> +// CHECK-JFLT-SAME: attributes {tf_entry} +// CHECK-JFLT-SAME: { +// CHECK-JFLT-SAME: %[[RES_JIT:.*]] = \22tf.Sub\22(%[[ARG0_JIT]], %[[ARG1_JIT]]) +// CHECK-JFLT-SAME: return %[[RES_JIT]] +// CHECK-JFLT-SAME: } +// CHECK-JFLT-SAME: } +// CHECK-JFLT-SAME: " +// CHECK-JFLT-SAME: { +// CHECK-JFLT-SAME: cpuCodegen = false +// CHECK-JFLT-SAME: enableFtz = false +// CHECK-JFLT-SAME: maxSupportedRank = 32 : i64 +// CHECK-JFLT-SAME: tileSizes = [1, 2, 3] +// CHECK-JFLT-SAME: unrollFactors = [3, 2, 1] +// CHECK-JFLT-SAME: } +// CHECK-JFLT: %[[RES:.*]] = tf_framework.jit_execute %[[CALLABLE]](%[[ARG0]], %[[ARG1]]) +// CHECK-JFLT: scf.yield %[[RES]] : tensor<*xf32> +// CHECK-JFLT: } else { +// CHECK-JFLT: %[[RES2:.*]] = "tf.Sub"(%[[ARG0]], %[[ARG1]]) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> +// CHECK-JFLT: scf.yield %[[RES2]] : tensor<*xf32> +// CHECK-JFLT: } +// CHECK-JFLT: return %[[IFRES]] diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tf_framework_c_interface.cc b/tensorflow/compiler/mlir/tools/kernel_gen/tf_framework_c_interface.cc index 5b98741b053..fe9d26723b9 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/tf_framework_c_interface.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tf_framework_c_interface.cc @@ -130,8 +130,8 @@ std::string GetFileCachePath(const std::string cache_dir, llvm::orc::SymbolMap TFFrameworkSymbolMap(llvm::orc::MangleAndInterner mangle) { llvm::orc::SymbolMap symbol_map; auto bind = [&](llvm::StringRef name, auto symbol_ptr) { - symbol_map[mangle(name)] = llvm::JITEvaluatedSymbol( - llvm::pointerToJITTargetAddress(symbol_ptr), llvm::JITSymbolFlags()); + symbol_map[mangle(name)] = {llvm::orc::ExecutorAddr::fromPtr(symbol_ptr), + llvm::JITSymbolFlags()}; }; // Register TF framework symbols. 
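The TFFrameworkSymbolMap change above replaces llvm::JITEvaluatedSymbol / llvm::pointerToJITTargetAddress with the newer ExecutorAddr-based ORC registration. A minimal sketch of the same pattern in isolation, assuming a hypothetical host function my_host_fn and helper MakeSymbolMap (names not from the patch):

#include "llvm/ExecutionEngine/Orc/Core.h"
#include "llvm/ExecutionEngine/Orc/Mangling.h"

extern "C" void my_host_fn();  // hypothetical host symbol exposed to JITed code

llvm::orc::SymbolMap MakeSymbolMap(llvm::orc::MangleAndInterner &mangle) {
  llvm::orc::SymbolMap symbol_map;
  // ExecutorAddr::fromPtr replaces llvm::pointerToJITTargetAddress; the braced
  // initializer pairs the executor address with its JITSymbolFlags.
  symbol_map[mangle("my_host_fn")] = {
      llvm::orc::ExecutorAddr::fromPtr(&my_host_fn), llvm::JITSymbolFlags()};
  return symbol_map;
}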
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tf_to_kernel.cc b/tensorflow/compiler/mlir/tools/kernel_gen/tf_to_kernel.cc index 4798d8508a0..8d5d583a2dc 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/tf_to_kernel.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tf_to_kernel.cc @@ -29,6 +29,7 @@ #include "llvm/IR/Module.h" #include "llvm/MC/TargetRegistry.h" #include "llvm/Support/CommandLine.h" +#include "llvm/Support/SourceMgr.h" #include "llvm/Support/TargetSelect.h" #include "llvm/Target/TargetMachine.h" #include "llvm/TargetParser/Host.h" @@ -117,6 +118,9 @@ Status Run(llvm::StringRef input_file, llvm::StringRef output_file, // Compile. mlir::MLIRContext context; + llvm::SourceMgr source_mgr; + mlir::SourceMgrDiagnosticHandler source_mgr_handler(source_mgr, &context); + TF_ASSIGN_OR_RETURN( mlir::OwningOpRef module, GenerateKernelForTfCode(context, tf_code, architectures, tile_sizes, diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/BUILD b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/BUILD index 79a43aeb240..70127f58c4a 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/BUILD +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/BUILD @@ -103,43 +103,57 @@ cc_library( deps = [ ":embed_tf_framework", ":kernel_gen_passes_inc_gen", + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tools/kernel_gen/ir:tf_framework_ops", + "//tensorflow/compiler/xla:debug_options_flags", + "//tensorflow/compiler/xla/mlir_hlo", + "//tensorflow/compiler/xla/mlir_hlo:gml_st", + "//tensorflow/compiler/xla/mlir_hlo:lhlo", + "//tensorflow/compiler/xla/mlir_hlo:mhlo_passes", + "//tensorflow/compiler/xla/mlir_hlo:type_conversion", + "//tensorflow/compiler/xla/service:hlo_module_config", + "//tensorflow/compiler/xla/service/gpu:gpu_asm_opts_util", + "//tensorflow/compiler/xla/service/gpu:target_constants", + "//tensorflow/compiler/xla/service/gpu/llvm_gpu_backend", + "//tensorflow/core:lib", + "//tensorflow/core/platform:errors", "@llvm-project//llvm:Support", "@llvm-project//llvm:TransformUtils", + "@llvm-project//mlir:AffineDialect", "@llvm-project//mlir:ArithDialect", "@llvm-project//mlir:ArithToLLVM", "@llvm-project//mlir:ArithTransforms", - "@llvm-project//mlir:ControlFlowDialect", - "@llvm-project//mlir:ControlFlowToLLVM", - "@llvm-project//mlir:MathToLibm", - "@llvm-project//mlir:MathToLLVM", - "@llvm-project//mlir:MemRefDialect", - "@llvm-project//mlir:AffineDialect", "@llvm-project//mlir:ComplexDialect", "@llvm-project//mlir:ComplexToLLVM", + "@llvm-project//mlir:ControlFlowDialect", + "@llvm-project//mlir:ControlFlowToLLVM", "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:FuncToLLVM", + "@llvm-project//mlir:FuncTransforms", "@llvm-project//mlir:GPUDialect", "@llvm-project//mlir:GPUToGPURuntimeTransforms", "@llvm-project//mlir:GPUTransforms", "@llvm-project//mlir:IR", + "@llvm-project//mlir:LLVMCommonConversion", "@llvm-project//mlir:LLVMDialect", - "@llvm-project//mlir:FuncToLLVM", "@llvm-project//mlir:LinalgDialect", "@llvm-project//mlir:LinalgTransforms", "@llvm-project//mlir:MathDialect", + "@llvm-project//mlir:MathToLLVM", + "@llvm-project//mlir:MathToLibm", + "@llvm-project//mlir:MemRefDialect", "@llvm-project//mlir:MemRefToLLVM", "@llvm-project//mlir:MemRefTransforms", "@llvm-project//mlir:NVVMToLLVMIRTranslation", # buildcleaner: keep - "@llvm-project//mlir:ReconcileUnrealizedCasts", "@llvm-project//mlir:Pass", - "@llvm-project//mlir:LLVMCommonConversion", "@llvm-project//mlir:ROCDLToLLVMIRTranslation", 
# buildcleaner: keep + "@llvm-project//mlir:ReconcileUnrealizedCasts", "@llvm-project//mlir:SCFDialect", - "@llvm-project//mlir:SCFTransforms", "@llvm-project//mlir:SCFToControlFlow", + "@llvm-project//mlir:SCFTransforms", "@llvm-project//mlir:ShapeDialect", "@llvm-project//mlir:ShapeToStandard", "@llvm-project//mlir:ShapeTransforms", - "@llvm-project//mlir:FuncTransforms", "@llvm-project//mlir:Support", "@llvm-project//mlir:TensorDialect", "@llvm-project//mlir:TensorTransforms", @@ -148,20 +162,6 @@ cc_library( "@llvm-project//mlir:VectorDialect", "@llvm-project//mlir:VectorToLLVM", "@llvm-project//mlir:VectorTransforms", - "//tensorflow/compiler/xla/mlir_hlo", - "//tensorflow/compiler/xla/mlir_hlo:mhlo_passes", - "//tensorflow/compiler/xla/mlir_hlo:lhlo", - "//tensorflow/compiler/xla/mlir_hlo:gml_st", - "//tensorflow/compiler/xla/mlir_hlo:type_conversion", - "//tensorflow/compiler/mlir/tensorflow", - "//tensorflow/compiler/mlir/tools/kernel_gen/ir:tf_framework_ops", - "//tensorflow/compiler/xla/service/gpu/llvm_gpu_backend", - "//tensorflow/compiler/xla/service/gpu:gpu_asm_opts_util", - "//tensorflow/compiler/xla/service/gpu:target_constants", - "//tensorflow/compiler/xla/service:hlo_module_config", - "//tensorflow/compiler/xla:debug_options_flags", - "//tensorflow/core:lib", - "//tensorflow/core/platform:errors", ] + if_cuda_is_configured([ "//tensorflow/tsl/platform:cuda_libdevice_path", "//tensorflow/compiler/xla/stream_executor/gpu:asm_compiler", diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/gpu_kernel_to_blob_pass.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/gpu_kernel_to_blob_pass.cc index dcb59b2ae06..0d76cd4c93c 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/gpu_kernel_to_blob_pass.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/gpu_kernel_to_blob_pass.cc @@ -70,7 +70,7 @@ class GpuKernelToBlobPass return; } // Forward the error by attaching the message to the gpu module. - gpu_module.emitError(blob_or.status().error_message()); + gpu_module.emitError(blob_or.status().message()); return signalPassFailure(); } diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/tf_framework_legalize_to_llvm.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/tf_framework_legalize_to_llvm.cc index c91ed7c427c..9a5b0749888 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/tf_framework_legalize_to_llvm.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/tf_framework_legalize_to_llvm.cc @@ -62,14 +62,14 @@ class ConvertToLLVMCallOpPattern : public ConvertOpToLLVMPattern { Location loc, Type size_ty, Type element_ty, std::optional attr, ConversionPatternRewriter *rewriter, std::function create_element) const { - Type element_ptr_ty = LLVM::LLVMPointerType::get(element_ty); + Type ptr_ty = LLVM::LLVMPointerType::get(element_ty.getContext()); // If the attribute is missing or empty, set the element count to 0 and // return NULL. 
if (!attr.has_value() || attr.value().empty()) { Value zero = rewriter->create( loc, size_ty, rewriter->getIntegerAttr(size_ty, 0)); - Value null_ptr = rewriter->create(loc, element_ptr_ty); + Value null_ptr = rewriter->create(loc, ptr_ty); return std::make_pair(zero, null_ptr); } @@ -78,12 +78,12 @@ class ConvertToLLVMCallOpPattern : public ConvertOpToLLVMPattern { Value array_size = rewriter->create( loc, size_ty, rewriter->getIntegerAttr(size_ty, array_attr.size())); Value array_ptr = rewriter->create( - loc, element_ptr_ty, array_size, /*alignment=*/0); - for (auto &e : llvm::enumerate(array_attr)) { + loc, ptr_ty, element_ty, array_size, /*alignment=*/0); + for (const auto &e : llvm::enumerate(array_attr)) { Value index = rewriter->create( loc, size_ty, rewriter->getIntegerAttr(size_ty, e.index())); - Value element_ptr = - rewriter->create(loc, element_ptr_ty, array_ptr, index); + Value element_ptr = rewriter->create(loc, ptr_ty, element_ty, + array_ptr, index); Value element = create_element(e.value()); rewriter->create(loc, element, element_ptr); } @@ -169,17 +169,15 @@ class TFAllocOpConverter : public ConvertToLLVMCallOpPattern { Type GetFuncType() const override { Type llvm_i32_type = IntegerType::get(getDialect().getContext(), 32); - Type llvm_i32_ptr_type = LLVM::LLVMPointerType::get(llvm_i32_type); - Type llvm_void_ptr_type = getVoidPtrType(); + Type llvm_ptr_type = LLVM::LLVMPointerType::get(getDialect().getContext()); return LLVM::LLVMFunctionType::get( - llvm_void_ptr_type, - llvm::ArrayRef( - {/*void* op_kernel_ctx*/ llvm_void_ptr_type, - /*size_t num_elements*/ getIndexType(), - /*size_t element_size*/ getIndexType(), - /*int32_t output_index*/ llvm_i32_type, - /*int32_t num_candidates*/ llvm_i32_type, - /*int32_t* candidate_input_indices*/ llvm_i32_ptr_type})); + llvm_ptr_type, + llvm::ArrayRef({/*void* op_kernel_ctx*/ llvm_ptr_type, + /*size_t num_elements*/ getIndexType(), + /*size_t element_size*/ getIndexType(), + /*int32_t output_index*/ llvm_i32_type, + /*int32_t num_candidates*/ llvm_i32_type, + /*int32_t* candidate_input_indices*/ llvm_ptr_type})); } private: @@ -193,10 +191,8 @@ class TFAllocOpConverter : public ConvertToLLVMCallOpPattern { rewriter, loc, typeConverter->convertType(memref_type)); // TF AllocateRaw returns aligned pointer => AllocatedPtr == AlignedPtr. - Value allocated_type_ptr = rewriter.create( - loc, getElementPtrType(memref_type), allocated_byte_ptr); - memref_desc.setAllocatedPtr(rewriter, loc, allocated_type_ptr); - memref_desc.setAlignedPtr(rewriter, loc, allocated_type_ptr); + memref_desc.setAllocatedPtr(rewriter, loc, allocated_byte_ptr); + memref_desc.setAlignedPtr(rewriter, loc, allocated_byte_ptr); memref_desc.setConstantOffset(rewriter, loc, 0); if (memref_type.getRank() == 0) { @@ -230,9 +226,7 @@ class TFDeallocOpConverter : public ConvertToLLVMCallOpPattern { if (!op.getMemref().getType().isa()) return failure(); MemRefDescriptor memref(adaptor.getMemref()); - Value allocated_bytes_ptr = rewriter.create( - op.getLoc(), getVoidPtrType(), - memref.allocatedPtr(rewriter, op.getLoc())); + Value allocated_bytes_ptr = memref.allocatedPtr(rewriter, op.getLoc()); // Insert function call. 
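A minimal sketch of the opaque-pointer idiom these conversion patterns migrate to, mirroring the calls above (illustrative only; rewriter, loc, and elem_ty stand for the values available inside such a pattern and are not new names from the patch). The pointer type no longer encodes an element type, so alloca/GEP/load take it as an explicit argument and the former bitcasts to and from !llvm.ptr<i8> become unnecessary:

Type ptr_ty = LLVM::LLVMPointerType::get(rewriter.getContext());  // opaque !llvm.ptr
Type i64_ty = rewriter.getI64Type();
Value one = rewriter.create<LLVM::ConstantOp>(loc, i64_ty, rewriter.getI64IntegerAttr(1));
Value zero = rewriter.create<LLVM::ConstantOp>(loc, i64_ty, rewriter.getI64IntegerAttr(0));
// The element type is passed to the ops instead of living in the pointer type.
Value buffer = rewriter.create<LLVM::AllocaOp>(loc, ptr_ty, elem_ty, one, /*alignment=*/0);
Value slot = rewriter.create<LLVM::GEPOp>(loc, ptr_ty, elem_ty, buffer, zero);
Value loaded = rewriter.create<LLVM::LoadOp>(loc, elem_ty, slot);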
FlatSymbolRefAttr tf_func_ref = @@ -296,18 +290,16 @@ class JITCompileFromStrOpConverter StringRef GetFuncName() const override { return kCInterfaceJITCompile; } Type GetFuncType() const override { - auto i8_ptr_ty = - LLVM::LLVMPointerType::get(IntegerType::get(getContext(), 8)); + auto ptr_ty = LLVM::LLVMPointerType::get(getContext()); auto i64_ty = IntegerType::get(getContext(), 64); - Type i64_ptr_ty = LLVM::LLVMPointerType::get(i64_ty); auto i1_ty = IntegerType::get(getContext(), 1); return LLVM::LLVMFunctionType::get( getVoidPtrType(), {/*void* op_kernel_ctx*/ getVoidPtrType(), - /*char* code*/ i8_ptr_ty, + /*char* code*/ ptr_ty, /*int64_t num_tile_sizes*/ i64_ty, - /*int64_t* tile_sizes_ptr*/ i64_ptr_ty, + /*int64_t* tile_sizes_ptr*/ ptr_ty, /*int64_t num_unroll_factors*/ i64_ty, - /*int64_t* unroll_factors_ptr*/ i64_ptr_ty, + /*int64_t* unroll_factors_ptr*/ ptr_ty, /*int64_t max_supported_rank*/ i64_ty, /*bool enable_ftz*/ i1_ty, /*bool index_64bit*/ i1_ty, @@ -331,47 +323,42 @@ class JITExecuteOpConverter : public ConvertToLLVMCallOpPattern { auto loc = op.getLoc(); Type result_ty = getTypeConverter()->convertType(op->getResultTypes().front()); - Type result_ptr_ty = LLVM::LLVMPointerType::get(result_ty); + Type ptr_ty = LLVM::LLVMPointerType::get(getContext()); Type i64_ty = rewriter.getI64Type(); Value one = rewriter.create( loc, i64_ty, rewriter.getI64IntegerAttr(1)); auto result_ptr = - rewriter.create(loc, result_ptr_ty, one, std::nullopt); - Type void_ptr_ty = getVoidPtrType(); - auto result_void_ptr = - rewriter.create(loc, void_ptr_ty, result_ptr); + rewriter.create(loc, ptr_ty, result_ty, one); // Pass the buffer arguments as a stack-allocated array. - Type arg_ptr_ty = - LLVM::LLVMPointerType::get(adaptor.getInputs().front().getType()); + Type args_elem_ty = adaptor.getInputs().front().getType(); Value num_args = rewriter.create( loc, i64_ty, rewriter.getI64IntegerAttr( static_cast(adaptor.getInputs().size()))); - Value args_ptr = rewriter.create(loc, arg_ptr_ty, num_args, - /*alignment=*/0); + Value args_ptr = + rewriter.create(loc, ptr_ty, args_elem_ty, num_args, + /*alignment=*/0); for (const auto &it : llvm::enumerate(adaptor.getInputs())) { Value index = rewriter.create( loc, i64_ty, rewriter.getI64IntegerAttr(it.index())); - Value element_ptr = - rewriter.create(loc, arg_ptr_ty, args_ptr, index); + Value element_ptr = rewriter.create( + loc, ptr_ty, args_elem_ty, args_ptr, index); rewriter.create(loc, it.value(), element_ptr); } - auto args_void_ptr = - rewriter.create(loc, void_ptr_ty, args_ptr); // Materialize runtime call. FlatSymbolRefAttr tf_func_ref = GetOrInsertLLVMFunction(GetFuncName(), GetFuncType(), op, &rewriter); rewriter.create( loc, std::nullopt, tf_func_ref, - ValueRange{adaptor.getCtx(), adaptor.getCallable(), result_void_ptr, - num_args, args_void_ptr}); + ValueRange{adaptor.getCtx(), adaptor.getCallable(), result_ptr, + num_args, args_ptr}); // Copy result (including the descriptor) to a stack-allocated buffer and // free the old descriptor. 
llvm::SmallVector final_result = { - rewriter.create(loc, result_ptr)}; + rewriter.create(loc, result_ty, result_ptr)}; if (failed(copyUnrankedDescriptors(rewriter, loc, op->getResultTypes(), final_result, /*toDynamic=*/false))) { @@ -387,13 +374,13 @@ class JITExecuteOpConverter : public ConvertToLLVMCallOpPattern { Type GetFuncType() const override { auto i64_ty = IntegerType::get(getContext(), 64); - auto void_ptr_ty = getVoidPtrType(); + auto ptr_ty = LLVM::LLVMPointerType::get(getContext()); return LLVM::LLVMFunctionType::get(getVoidType(), - {/*void* op_kernel_ctx*/ void_ptr_ty, - /*void* callable*/ void_ptr_ty, - /*void* result*/ void_ptr_ty, + {/*void* op_kernel_ctx*/ ptr_ty, + /*void* callable*/ ptr_ty, + /*void* result*/ ptr_ty, /*int64_t num_args*/ i64_ty, - /*void* args_ptr*/ void_ptr_ty}); + /*void* args_ptr*/ ptr_ty}); } }; @@ -426,10 +413,10 @@ class ReportErrorOpConverter StringRef GetFuncName() const override { return kCInterfaceReportError; } Type GetFuncType() const override { MLIRContext *ctx = &getTypeConverter()->getContext(); - auto i8_ptr_type = LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8)); + auto ptr_type = LLVM::LLVMPointerType::get(ctx); auto i32_type = IntegerType::get(ctx, 32); - return LLVM::LLVMFunctionType::get( - getVoidType(), {getVoidPtrType(), i32_type, i8_ptr_type}); + return LLVM::LLVMFunctionType::get(getVoidType(), + {getVoidPtrType(), i32_type, ptr_type}); } private: @@ -474,6 +461,7 @@ class NullMemRefOpConverter : public ConvertOpToLLVMPattern { ConversionPatternRewriter &rewriter) const override { Location loc = null_memref_op->getLoc(); LLVMTypeConverter type_converter = *getTypeConverter(); + MLIRContext *ctx = null_memref_op.getContext(); mlir::Operation *op = null_memref_op.getOperation(); auto shaped_result_type = null_memref_op.getType().cast(); @@ -481,9 +469,8 @@ class NullMemRefOpConverter : public ConvertOpToLLVMPattern { shaped_result_type.getMemorySpace().dyn_cast_or_null(); unsigned address_space = static_cast(mem_space ? mem_space.getInt() : 0); - - Type elem_type = shaped_result_type.getElementType(); - Type llvm_elem_type = type_converter.convertType(elem_type); + LLVM::LLVMPointerType llvm_ptr_type = + LLVM::LLVMPointerType::get(ctx, address_space); Value zero = createIndexConstant(rewriter, loc, 0); if (auto result_type = null_memref_op.getType().dyn_cast()) { @@ -497,8 +484,7 @@ class NullMemRefOpConverter : public ConvertOpToLLVMPattern { // Prepare packed args [allocatedPtr, alignedPtr, offset, sizes, strides] // to create a memref descriptor. - Value null = rewriter.create( - loc, LLVM::LLVMPointerType::get(llvm_elem_type, address_space)); + Value null = rewriter.create(loc, llvm_ptr_type); SmallVector packed_values{null, null, zero}; packed_values.append(sizes); packed_values.append(strides); @@ -529,21 +515,18 @@ class NullMemRefOpConverter : public ConvertOpToLLVMPattern { UnrankedMemRefDescriptor::computeSizes(rewriter, loc, *getTypeConverter(), desc, addressSpace, sizes); Value underlying_desc_ptr = rewriter.create( - loc, getVoidPtrType(), sizes.front(), std::nullopt); + loc, getVoidPtrType(), IntegerType::get(getContext(), 8), + sizes.front()); // Populate underlying ranked descriptor. 
- LLVM::LLVMPointerType elem_ptr_ptr_type = LLVM::LLVMPointerType::get( - LLVM::LLVMPointerType::get(llvm_elem_type, address_space)); - - Value null = rewriter.create( - loc, LLVM::LLVMPointerType::get(llvm_elem_type, address_space)); + Value null = rewriter.create(loc, llvm_ptr_type); UnrankedMemRefDescriptor::setAllocatedPtr( - rewriter, loc, underlying_desc_ptr, elem_ptr_ptr_type, null); + rewriter, loc, underlying_desc_ptr, llvm_ptr_type, null); UnrankedMemRefDescriptor::setAlignedPtr(rewriter, loc, *getTypeConverter(), - underlying_desc_ptr, - elem_ptr_ptr_type, null); + underlying_desc_ptr, llvm_ptr_type, + null); UnrankedMemRefDescriptor::setOffset(rewriter, loc, *getTypeConverter(), - underlying_desc_ptr, elem_ptr_ptr_type, + underlying_desc_ptr, llvm_ptr_type, zero); desc.setMemRefDescPtr(rewriter, loc, underlying_desc_ptr); @@ -576,8 +559,7 @@ class IsValidMemRefOpConverter rewriter.create(loc, is_empty_shape, is_zero_size); } - Value ptr = rewriter.create( - loc, getVoidPtrType(), desc.allocatedPtr(rewriter, loc)); + Value ptr = desc.allocatedPtr(rewriter, loc); Value null = rewriter.create(loc, getVoidPtrType()); Value is_not_nullptr = rewriter.create( loc, rewriter.getI1Type(), LLVM::ICmpPredicate::ne, ptr, null); diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/tf_kernel_to_llvm_pass.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/tf_kernel_to_llvm_pass.cc index 8f7ce2f0a0c..136b278e8c9 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/tf_kernel_to_llvm_pass.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/tf_kernel_to_llvm_pass.cc @@ -13,7 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include +#include #include "llvm/ADT/STLExtras.h" #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" // from @llvm-project @@ -37,6 +39,7 @@ limitations under the License. 
#include "mlir/Dialect/LLVMIR/LLVMTypes.h" // from @llvm-project #include "mlir/Dialect/Math/IR/Math.h" // from @llvm-project #include "mlir/Dialect/MemRef/Transforms/Passes.h" // from @llvm-project +#include "mlir/Dialect/MemRef/Transforms/Transforms.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/Transforms/DialectConversion.h" // from @llvm-project #include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h" @@ -77,10 +80,7 @@ class ConvertLaunchFuncOpToTfRuntimeCallPattern MLIRContext *context_ = &this->getTypeConverter()->getContext(); Type llvm_void_type_ = LLVM::LLVMVoidType::get(context_); - Type llvm_pointer_type_ = - LLVM::LLVMPointerType::get(IntegerType::get(context_, 8)); - Type llvm_pointer_pointer_type_ = - LLVM::LLVMPointerType::get(llvm_pointer_type_); + Type llvm_pointer_type_ = LLVM::LLVMPointerType::get(context_); Type llvm_int8_type_ = IntegerType::get(context_, 8); Type llvm_int32_type_ = IntegerType::get(context_, 32); Type llvm_int64_type_ = IntegerType::get(context_, 64); @@ -119,25 +119,24 @@ Value ConvertLaunchFuncOpToTfRuntimeCallPattern::generateParamsArray( auto one = builder.create(loc, llvm_int32_type_, builder.getI32IntegerAttr(1)); auto struct_ptr = builder.create( - loc, LLVM::LLVMPointerType::get(struct_type), one, /*alignment=*/0); + loc, llvm_pointer_type_, struct_type, one, /*alignment=*/0); auto array_size = builder.create( loc, llvm_int32_type_, builder.getI32IntegerAttr(num_arguments)); auto array_ptr = builder.create( - loc, llvm_pointer_pointer_type_, array_size, /*alignment=*/0); + loc, llvm_pointer_type_, llvm_pointer_type_, array_size, /*alignment=*/0); auto zero = builder.create(loc, llvm_int32_type_, builder.getI32IntegerAttr(0)); for (auto en : llvm::enumerate(arguments)) { auto index = builder.create( loc, llvm_int32_type_, builder.getI32IntegerAttr(en.index())); auto field_ptr = builder.create( - loc, LLVM::LLVMPointerType::get(argument_types[en.index()]), struct_ptr, + loc, llvm_pointer_type_, struct_type, struct_ptr, ArrayRef{zero, index.getResult()}); builder.create(loc, en.value(), field_ptr); - auto element_ptr = builder.create( - loc, llvm_pointer_pointer_type_, array_ptr, index.getResult()); - auto casted = - builder.create(loc, llvm_pointer_type_, field_ptr); - builder.create(loc, casted, element_ptr); + auto element_ptr = + builder.create(loc, llvm_pointer_type_, llvm_pointer_type_, + array_ptr, index.getResult()); + builder.create(loc, field_ptr, element_ptr); } return array_ptr; } @@ -179,7 +178,7 @@ LogicalResult ConvertLaunchFuncOpToTfRuntimeCallPattern::matchAndRewrite( name_buffer.append("_blob"); Value module_blob = LLVM::createGlobalString(loc, rewriter, name_buffer.str(), binary_attr.getValue(), - LLVM::Linkage::Internal, false); + LLVM::Linkage::Internal, true); // Make sure the trailing zero is included in the constant. auto kernel_name = launch_op.getKernelName().getValue(); @@ -193,7 +192,7 @@ LogicalResult ConvertLaunchFuncOpToTfRuntimeCallPattern::matchAndRewrite( .toStringRef(kernel_name_global_name_buffer); auto kernel_name_global = LLVM::createGlobalString( loc, rewriter, kernel_name_global_name, kernel_name_buffer, - LLVM::Linkage::Internal, false); + LLVM::Linkage::Internal, true); // The TensorFlow OpKernelContext is the first argument of the surrounding // LLVMFunc. 
@@ -208,19 +207,18 @@ LogicalResult ConvertLaunchFuncOpToTfRuntimeCallPattern::matchAndRewrite( if (!function) { PatternRewriter::InsertionGuard guard(rewriter); auto function_type = LLVM::LLVMFunctionType::get( - llvm_void_type_, - { - llvm_pointer_type_, /* void* context */ - llvm_pointer_type_, /* void* module_blob */ - llvm_pointer_type_, /* void* function_name */ - llvm_intptr_type_, /* intptr_t grid_x_dim */ - llvm_intptr_type_, /* intptr_t grid_y_dim */ - llvm_intptr_type_, /* intptr_t grid_z_dim */ - llvm_intptr_type_, /* intptr_t block_x_dim */ - llvm_intptr_type_, /* intptr_t block_y_dim */ - llvm_intptr_type_, /* intptr_t block_z_dim */ - llvm_pointer_pointer_type_, /* void **kernel_params */ - }); + llvm_void_type_, { + llvm_pointer_type_, /* void* context */ + llvm_pointer_type_, /* void* module_blob */ + llvm_pointer_type_, /* void* function_name */ + llvm_intptr_type_, /* intptr_t grid_x_dim */ + llvm_intptr_type_, /* intptr_t grid_y_dim */ + llvm_intptr_type_, /* intptr_t grid_z_dim */ + llvm_intptr_type_, /* intptr_t block_x_dim */ + llvm_intptr_type_, /* intptr_t block_y_dim */ + llvm_intptr_type_, /* intptr_t block_z_dim */ + llvm_pointer_type_, /* void **kernel_params */ + }); rewriter.setInsertionPointToStart( launch_op->getParentOfType().getBody()); function = rewriter.create( @@ -257,15 +255,13 @@ class TFKernelToLLVMPass // Populate type conversions. MLIRContext *ctx = m.getContext(); - // TODO(b/267828330): Migrate to opaque pointers. LowerToLLVMOptions options(&getContext()); - options.useOpaquePointers = false; LLVMTypeConverter type_converter(ctx, options); type_converter.addConversion([&](tf_framework::OpKernelContextType type) { - return LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8)); + return LLVM::LLVMPointerType::get(type.getContext()); }); type_converter.addConversion([&](tf_framework::JITCallableType type) { - return LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8)); + return LLVM::LLVMPointerType::get(type.getContext()); }); // Populate patterns. diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/tf_to_jit_invocations.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/tf_to_jit_invocations.cc index a1e6e7b6c3c..1afa6372434 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/tf_to_jit_invocations.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/tf_to_jit_invocations.cc @@ -64,6 +64,10 @@ bool IsUnaryTFOperation(Operation *op) { return IsSingleResultTFOperation(op) && op->getNumOperands() == 1; } +bool IsBinaryTFOperation(Operation *op) { + return IsSingleResultTFOperation(op) && op->getNumOperands() == 2; +} + struct TFToJITInvocationsPattern : public RewritePattern { explicit TFToJITInvocationsPattern(MLIRContext *ctx) : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, ctx) {} @@ -116,20 +120,30 @@ struct TFToI64JITInvocationForLargeTensorsPattern : public RewritePattern { LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override { - if (!IsUnaryTFOperation(op) || + if ((!IsUnaryTFOperation(op) && !IsBinaryTFOperation(op)) || !llvm::isa(op->getParentOp())) { return failure(); } // Create large argument condition. 
auto loc = op->getLoc(); - auto arg = op->getOperands().front(); - auto shape = rewriter.create(loc, arg); - auto num_elems = rewriter.create(loc, shape); + auto arg_1 = op->getOperands().front(); + auto shape_1 = rewriter.create(loc, arg_1); + auto num_elems_1 = rewriter.create(loc, shape_1); Value cst_i32_limit = rewriter.create(loc, i32Limit); Value large_tensor_predicate = rewriter.create( - loc, arith::CmpIPredicate::sgt, num_elems, cst_i32_limit); + loc, arith::CmpIPredicate::sgt, num_elems_1, cst_i32_limit); + if (IsBinaryTFOperation(op)) { + auto arg_2 = op->getOperands().back(); + auto shape_2 = rewriter.create(loc, arg_2); + auto num_elems_2 = rewriter.create(loc, shape_2); + large_tensor_predicate = rewriter.create( + loc, large_tensor_predicate, + // Compare op to check size of the second op + rewriter.create(loc, arith::CmpIPredicate::sgt, + num_elems_2, cst_i32_limit)); + } // Create dispatch code. auto jit_body_builder_fn = [&](OpBuilder &b, Location loc) { @@ -152,9 +166,10 @@ struct TFToI64JITInvocationForLargeTensorsPattern : public RewritePattern { } // Create JIT execute op. + assert(op->getOperands().size() == 1 || op->getOperands().size() == 2); auto jit_execute_op = b.create( loc, op->getResultTypes().front(), /*ctx=*/Value(), - jit_compile_op.getResult(), arg); + jit_compile_op.getResult(), op->getOperands()); b.create(loc, jit_execute_op.getResult()); }; auto aot_body_builder_fn = [&](OpBuilder &b, Location loc) { diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/utils.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/utils.cc index c1dbe67f5b6..b1c909bb523 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/utils.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/utils.cc @@ -19,6 +19,7 @@ limitations under the License. 
#include "llvm/Support/FormatVariadic.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project namespace mlir { namespace kernel_gen { @@ -49,16 +50,18 @@ Value CreateOrFindGlobalStringConstant(Location loc, StringRef global_name, Operation* global_constant = SymbolTable::lookupNearestSymbolFrom( module, b->getStringAttr(global_name)); if (global_constant) { - Value global_ptr = b->create( - loc, cast(global_constant)); + auto global_op = cast(global_constant); + StringRef symbol_name = global_op.getName(); + Type symbol_type = global_op.getType(); + Type ptr_type = LLVM::LLVMPointerType::get(b->getContext()); + Value global_ptr = b->create(loc, ptr_type, symbol_name); Value c0 = b->create(loc, b->getI64Type(), b->getIndexAttr(0)); - return b->create( - loc, LLVM::LLVMPointerType::get(b->getIntegerType(8)), global_ptr, - ValueRange{c0, c0}); + return b->create(loc, ptr_type, symbol_type, global_ptr, + ValueRange{c0, c0}); } return LLVM::createGlobalString(loc, *b, global_name, content, - LLVM::Linkage::Internal, false); + LLVM::Linkage::Internal, true); } } // namespace transforms diff --git a/tensorflow/compiler/mlir/tosa/BUILD b/tensorflow/compiler/mlir/tosa/BUILD index b456496bc5b..586204d9594 100644 --- a/tensorflow/compiler/mlir/tosa/BUILD +++ b/tensorflow/compiler/mlir/tosa/BUILD @@ -67,6 +67,7 @@ cc_library( deps = [ "//tensorflow/compiler/mlir/lite:tensorflow_lite", "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", ], ) @@ -178,11 +179,16 @@ cc_library( name = "tfl_passes", srcs = [ "tfl_passes.cc", + "transforms/convert_metadata.cc", "transforms/convert_tfl_uint8.cc", "transforms/legalize_tfl.cc", "transforms/lower_complex_types.cc", + "transforms/lower_global_tensors.cc", + "transforms/retain_call_once_funcs.cc", + "transforms/strip_metadata.cc", "transforms/strip_quant_types.cc", "transforms/tfl_legalize_patterns.inc", + "transforms/verify_fully_converted.cc", ], hdrs = [ "tfl_passes.h", @@ -202,8 +208,10 @@ cc_library( "@llvm-project//mlir:Dialect", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", + "@llvm-project//mlir:MLProgramDialect", "@llvm-project//mlir:Pass", "@llvm-project//mlir:QuantOps", + "@llvm-project//mlir:ReconcileUnrealizedCasts", "@llvm-project//mlir:Support", "@llvm-project//mlir:TosaDialect", "@llvm-project//mlir:Transforms", diff --git a/tensorflow/compiler/mlir/tosa/tests/convert_metadata.mlir b/tensorflow/compiler/mlir/tosa/tests/convert_metadata.mlir new file mode 100644 index 00000000000..7fb03c7728c --- /dev/null +++ b/tensorflow/compiler/mlir/tosa/tests/convert_metadata.mlir @@ -0,0 +1,25 @@ +// RUN: tf-opt --split-input-file --pass-pipeline='builtin.module(func.func(tosa-tflite-convert-function-metadata))' %s | FileCheck %s + +module attributes {tfl.schema_version = 3 : i32} { + // CHECK: func.func @main( + // CHECK-SAME: %arg0: tensor {ml_program.identifier = "input0"}, + // CHECK-SAME: %arg1: tensor {ml_program.identifier = "input1"} + // CHECK-SAME: ) -> ( + // CHECK-SAME: tensor {ml_program.identifier = "output0"}, + // CHECK-SAME: tensor {ml_program.identifier = "output1"}) + func.func @main(%arg0: tensor, %arg1: tensor) -> (tensor, tensor) attributes { + tf.entry_function = {inputs = "input0,input1", outputs = "output0,output1"} + } { + return %arg0, %arg1 : tensor, tensor + } + + // CHECK: func.func @no_input( + // CHECK-SAME: ) -> ( + // CHECK-SAME: tensor<1xf32> {ml_program.identifier = "output0"}) + func.func @no_input() 
-> (tensor<1xf32>) attributes { + tf.entry_function = {outputs = "output0"} + } { + %0 = "tfl.pseudo_const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor<1xf32> + return %0 : tensor<1xf32> + } +} diff --git a/tensorflow/compiler/mlir/tosa/tests/lower_global_tensors.mlir b/tensorflow/compiler/mlir/tosa/tests/lower_global_tensors.mlir new file mode 100644 index 00000000000..5b8bd2cc3c0 --- /dev/null +++ b/tensorflow/compiler/mlir/tosa/tests/lower_global_tensors.mlir @@ -0,0 +1,145 @@ +// RUN: tf-opt --split-input-file --pass-pipeline='builtin.module(tflite-lower-global-tensors)' %s | FileCheck %s + +module { + // CHECK: ml_program.global private mutable @Variable(dense<1.000000e+00> : tensor<16x16xf32>) + // CHECK-LABEL: func.func @state + func.func @state(%arg0: tensor<16x16xf32>) -> () { + "tfl.call_once"() {session_init_function = "StateInit"} : () -> () + return + } + + func.func private @StateInit() { + %0 = "tfl.var_handle"() {container = "", shared_name = "Variable"} : () -> tensor<*x!tf_type.resource> + %1 = "tfl.pseudo_const"() {value = dense<1.000000e+00> : tensor<16x16xf32>} : () -> tensor<16x16xf32> + "tfl.assign_variable"(%0, %1) : (tensor<*x!tf_type.resource>, tensor<16x16xf32>) -> () + return + } +} + +// ----- + +module { + // CHECK: ml_program.global private mutable @Variable(dense<1.000000e+00> : tensor<16x16xf32>) + + // CHECK-LABEL: func.func @assign + func.func @assign(%arg0: tensor<16x16xf32>) -> () { + "tfl.call_once"() {session_init_function = "AssignInit"} : () -> () + %0 = "tfl.var_handle"() {container = "", shared_name = "Variable"} : () -> tensor<*x!tf_type.resource> + + // CHECK: ml_program.global_store @Variable = %arg0 + "tfl.assign_variable"(%0, %arg0) : (tensor<*x!tf_type.resource>, tensor<16x16xf32>) -> () + return + } + + func.func private @AssignInit() { + %0 = "tfl.var_handle"() {container = "", shared_name = "Variable"} : () -> tensor<*x!tf_type.resource> + %1 = "tfl.pseudo_const"() {value = dense<1.000000e+00> : tensor<16x16xf32>} : () -> tensor<16x16xf32> + "tfl.assign_variable"(%0, %1) : (tensor<*x!tf_type.resource>, tensor<16x16xf32>) -> () + return + } +} + +// ----- + +module { + // CHECK: ml_program.global private mutable @Variable(dense<1.000000e+00> : tensor<16x16xf32>) + + // CHECK-LABEL: func.func @read + func.func @read(%arg0: tensor<16x16xf32>) -> (tensor<16x16xf32>) { + "tfl.call_once"() {session_init_function = "ReadInit"} : () -> () + + %0 = "tfl.var_handle"() {container = "", shared_name = "Variable"} : () -> tensor<*x!tf_type.resource> + + // CHECK: %[[LOAD:.+]] = ml_program.global_load @Variable : tensor<16x16xf32> + %1 = "tfl.read_variable"(%0) : (tensor<*x!tf_type.resource>) -> tensor<16x16xf32> + return %1 : tensor<16x16xf32> + } + + func.func private @ReadInit() { + %0 = "tfl.var_handle"() {container = "", shared_name = "Variable"} : () -> tensor<*x!tf_type.resource> + %1 = "tfl.pseudo_const"() {value = dense<1.000000e+00> : tensor<16x16xf32>} : () -> tensor<16x16xf32> + "tfl.assign_variable"(%0, %1) : (tensor<*x!tf_type.resource>, tensor<16x16xf32>) -> () + return + } +} + +// ----- + +module { + // CHECK: ml_program.global private mutable @Variable(dense<2.000000e+00> : tensor<16x16xf32>) + + // CHECK-LABEL: func.func @readAssign + func.func @readAssign(%arg0: tensor<16x16xf32>) -> (tensor<16x16xf32>) { + "tfl.call_once"() {session_init_function = "ReadAssignInit"} : () -> () + %0 = "tfl.var_handle"() {container = "", shared_name = "Variable"} : () -> tensor<*x!tf_type.resource> + + // CHECK: %[[LOAD:.+]] = 
ml_program.global_load @Variable : tensor<16x16xf32> + %1 = "tfl.read_variable"(%0) : (tensor<*x!tf_type.resource>) -> tensor<16x16xf32> + + // CHECK: %[[ADD:.+]] = tfl.add %[[LOAD]], %arg0 + %2 = tfl.add %1, %arg0 {fused_activation_function = "NONE"} : tensor<16x16xf32> + + // CHECK: ml_program.global_store @Variable = %[[ADD]] + "tfl.assign_variable"(%0, %2) : (tensor<*x!tf_type.resource>, tensor<16x16xf32>) -> () + return %2 : tensor<16x16xf32> + } + func.func private @ReadAssignInit() { + %0 = "tfl.var_handle"() {container = "", shared_name = "Variable"} : () -> tensor<*x!tf_type.resource> + %1 = "tfl.pseudo_const"() {value = dense<2.000000e+00> : tensor<16x16xf32>} : () -> tensor<16x16xf32> + "tfl.assign_variable"(%0, %1) : (tensor<*x!tf_type.resource>, tensor<16x16xf32>) -> () + return + } +} + +// ----- + +module { + // CHECK: ml_program.global private mutable @Variable(dense<42> : tensor<2x3xi8>) + // CHECK-LABEL: func.func @readAssignQuant + func.func @readAssignQuant(%arg0: tensor<2x3x!quant.uniform>) -> (tensor<2x3x!quant.uniform>) { + "tfl.call_once"() {session_init_function = "ReadAssignInit"} : () -> () + %0 = "tfl.var_handle"() {container = "", shared_name = "Variable"} : () -> tensor<*x!tf_type.resource> + + // CHECK: %[[ADDR:.+]] = ml_program.global_load @Variable : tensor<2x3xi8> + // CHECK: %[[CAST:.+]] = builtin.unrealized_conversion_cast %[[ADDR]] : tensor<2x3xi8> to tensor<2x3x!quant.uniform> + %1 = "tfl.read_variable"(%0) : (tensor<*x!tf_type.resource>) -> tensor<2x3x!quant.uniform> + + // CHECK: %[[ADD:.+]] = tfl.add %[[CAST]], %arg0 {fused_activation_function = "NONE"} + %2 = tfl.add %1, %arg0 {fused_activation_function = "NONE"} : tensor<2x3x!quant.uniform> + + // CHECK: %[[CAST2:.+]] = builtin.unrealized_conversion_cast %[[ADD]] : tensor<2x3x!quant.uniform> to tensor<2x3xi8> + // CHECK: ml_program.global_store @Variable = %[[CAST2]] + "tfl.assign_variable"(%0, %2) : (tensor<*x!tf_type.resource>, tensor<2x3x!quant.uniform>) -> () + return %2 : tensor<2x3x!quant.uniform> + } + func.func private @ReadAssignInit() { + %0 = "tfl.var_handle"() {container = "", shared_name = "Variable"} : () -> tensor<*x!tf_type.resource> + %1 = "tfl.pseudo_const"() {qtype = tensor<2x3x!quant.uniform>, value = dense<42> : tensor<2x3xi8>} : () -> tensor<2x3x!quant.uniform> + "tfl.assign_variable"(%0, %1) : (tensor<*x!tf_type.resource>, tensor<2x3x!quant.uniform>) -> () + return + } +} + +// ----- + +module { + // CHECK-label: @nostate + func.func @nostate(%arg0: tensor<16x16xf32>) -> (tensor<16x16xf32>) { + "tfl.call_once"() {session_init_function = "NoStateInit"} : () -> () + // CHECK: tfl.var_handle + %0 = "tfl.var_handle"() {container = "", shared_name = "Variable"} : () -> tensor<*x!tf_type.resource> + + // CHECK: tfl.read_variable + %1 = "tfl.read_variable"(%0) : (tensor<*x!tf_type.resource>) -> tensor<16x16xf32> + + %2 = tfl.add %1, %arg0 {fused_activation_function = "NONE"} : tensor<16x16xf32> + + // CHECK: tfl.assign_variable + "tfl.assign_variable"(%0, %2) : (tensor<*x!tf_type.resource>, tensor<16x16xf32>) -> () + return %2 : tensor<16x16xf32> + } + func.func private @NoStateInit() { + return + } +} + diff --git a/tensorflow/compiler/mlir/tosa/tests/multi_add.mlir b/tensorflow/compiler/mlir/tosa/tests/multi_add.mlir new file mode 100644 index 00000000000..c513f2ec936 --- /dev/null +++ b/tensorflow/compiler/mlir/tosa/tests/multi_add.mlir @@ -0,0 +1,16 @@ +// RUN: tf-opt --tfl-to-tosa-pipeline=target-compilation-backend %s | FileCheck %s + +// CHECK: tensor<1x8x8x3xf32> 
{ml_program.identifier = "a"} +// CHECK-SAME: tensor<1x8x8x3xf32> {ml_program.identifier = "b"} +// CHECK-SAME: tensor<1x8x8x3xf32> {ml_program.identifier = "c"} +// CHECK-SAME: tensor<1x8x8x3xf32> {ml_program.identifier = "d"} +// CHECK-SAME: -> (tensor<1x8x8x3xf32> {ml_program.identifier = "x"}, tensor<1x8x8x3xf32> {ml_program.identifier = "y"}) + +module attributes {tfl.schema_version = 3 : i32} { + func.func @main(%arg0: tensor<1x8x8x3xf32>, %arg1: tensor<1x8x8x3xf32>, %arg2: tensor<1x8x8x3xf32>, %arg3: tensor<1x8x8x3xf32>) -> (tensor<1x8x8x3xf32>, tensor<1x8x8x3xf32>) attributes {tf.entry_function = {inputs = "a,b,c,d", outputs = "x,y"}} { + %0 = tfl.add %arg1, %arg2 {fused_activation_function = "NONE"} : tensor<1x8x8x3xf32> + %1 = tfl.add %arg0, %0 {fused_activation_function = "NONE"} : tensor<1x8x8x3xf32> + %2 = tfl.add %arg3, %0 {fused_activation_function = "NONE"} : tensor<1x8x8x3xf32> + return %1, %2 : tensor<1x8x8x3xf32>, tensor<1x8x8x3xf32> + } +} diff --git a/tensorflow/compiler/mlir/tosa/tests/retain_call_once_funcs.mlir b/tensorflow/compiler/mlir/tosa/tests/retain_call_once_funcs.mlir new file mode 100644 index 00000000000..5719fd35989 --- /dev/null +++ b/tensorflow/compiler/mlir/tosa/tests/retain_call_once_funcs.mlir @@ -0,0 +1,21 @@ +// RUN: tf-opt --split-input-file --pass-pipeline='builtin.module(tflite-retain-call-once-funcs)' %s | FileCheck %s + +// CHECK-LABEL: module { +module { + // CHECK-LABEL: @main + func.func @main(%arg0: tensor<16x16xf32>) -> (tensor<16x16xf32>) { + // CHECK: "tfl.call_once"() {session_init_function = "NoOp", session_init_function_symbol = @NoOp} : () -> () + "tfl.call_once"() {session_init_function = "NoOp"} : () -> () + %0 = "tfl.var_handle"() {container = "", shared_name = "Variable"} : () -> tensor<*x!tf_type.resource> + %1 = "tfl.read_variable"(%0) : (tensor<*x!tf_type.resource>) -> tensor<16x16xf32> + %2 = tfl.add %1, %arg0 {fused_activation_function = "NONE"} : tensor<16x16xf32> + "tfl.assign_variable"(%0, %2) : (tensor<*x!tf_type.resource>, tensor<16x16xf32>) -> () + return %2 : tensor<16x16xf32> + } + func.func private @NoOp() { + %0 = "tfl.var_handle"() {container = "", shared_name = "Variable"} : () -> tensor<*x!tf_type.resource> + %1 = "tfl.pseudo_const"() {value = dense<0.000000e+00> : tensor<16x16xf32>} : () -> tensor<16x16xf32> + "tfl.assign_variable"(%0, %1) : (tensor<*x!tf_type.resource>, tensor<16x16xf32>) -> () + return + } +} diff --git a/tensorflow/compiler/mlir/tosa/tests/strip_metadata.mlir b/tensorflow/compiler/mlir/tosa/tests/strip_metadata.mlir new file mode 100644 index 00000000000..f2198823a6d --- /dev/null +++ b/tensorflow/compiler/mlir/tosa/tests/strip_metadata.mlir @@ -0,0 +1,14 @@ +// RUN: tf-opt --pass-pipeline='builtin.module(tosa-tflite-strip-module-metadata,func.func(tosa-tflite-strip-function-metadata))' %s | FileCheck %s + +// CHECK-LABEL: module { +// CHECK-NOT: tf.schema_version +module attributes {tfl.schema_version = 3 : i32} { + // CHECK: func.func @main + // CHECK-NOT: tf.entry_function + func.func @main(%arg0: tensor<1x8x8x3xf32>) -> tensor<1x8x8x3xf32> attributes {tf.entry_function = {inputs = "input", outputs = "output"}} { + // CHECK-NEXT: tfl.add + %0 = tfl.add %arg0, %arg0 {fused_activation_function = "NONE"} : tensor<1x8x8x3xf32> + %1 = tfl.add %0, %arg0 {fused_activation_function = "NONE"} : tensor<1x8x8x3xf32> + return %1 : tensor<1x8x8x3xf32> + } +} diff --git a/tensorflow/compiler/mlir/tosa/tests/tf-to-tosa-pipeline.mlir b/tensorflow/compiler/mlir/tosa/tests/tf-to-tosa-pipeline.mlir index 
c0e5b23e3b8..5cacdf03552 100644 --- a/tensorflow/compiler/mlir/tosa/tests/tf-to-tosa-pipeline.mlir +++ b/tensorflow/compiler/mlir/tosa/tests/tf-to-tosa-pipeline.mlir @@ -7,10 +7,10 @@ // ----- // CHECK-LABEL: test_conv2d -// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() {value = dense<0.000000e+00> : tensor<16xf32>} -// CHECK-DAG: %[[VAR1:.*]] = "tosa.const"() {value = dense<[3, 0, 1, 2]> : tensor<4xi32>} +// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<16xf32>} +// CHECK-DAG: %[[VAR1:.*]] = "tosa.const"() <{value = dense<[3, 0, 1, 2]> : tensor<4xi32>} // CHECK-DAG: %[[VAR2:.*]] = "tosa.transpose"(%arg1, %[[VAR1]]) -// CHECK: %[[VAR3:.*]] = "tosa.conv2d"(%arg0, %[[VAR2]], %[[VAR0]]) {dilation = array, pad = array, stride = array} +// CHECK: %[[VAR3:.*]] = "tosa.conv2d"(%arg0, %[[VAR2]], %[[VAR0]]) <{dilation = array, pad = array, stride = array} func.func @test_conv2d(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<2x2x8x16xf32>) -> tensor<1x32x32x16xf32> { %3 = "tf.Conv2D"(%arg0, %arg1) {data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true} : (tensor<1x32x32x8xf32>, tensor<2x2x8x16xf32>) -> tensor<1x32x32x16xf32> func.return %3 : tensor<1x32x32x16xf32> @@ -19,8 +19,8 @@ func.func @test_conv2d(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<2x2x8x16xf32> // ----- // CHECK-LABEL: test_depthwise_conv2d -// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() {value = dense<0.000000e+00> : tensor<16xf32>} -// CHECK: %[[VAR1:.*]] = "tosa.depthwise_conv2d"(%arg0, %arg1, %0) {dilation = array, pad = array, stride = array} +// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<16xf32>} +// CHECK: %[[VAR1:.*]] = "tosa.depthwise_conv2d"(%arg0, %arg1, %0) <{dilation = array, pad = array, stride = array} func.func @test_depthwise_conv2d(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<2x2x8x2xf32>) -> tensor<1x32x32x16xf32> { %5 = "tf.DepthwiseConv2dNative"(%arg0, %arg1) {data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<1x32x32x8xf32>, tensor<2x2x8x2xf32>) -> tensor<1x32x32x16xf32> %6 = "tf.Identity"(%5) : (tensor<1x32x32x16xf32>) -> tensor<1x32x32x16xf32> @@ -31,9 +31,9 @@ func.func @test_depthwise_conv2d(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<2x2 // CHECK-LABEL: @test_transpose_conv2d // CHECK-SAME: %[[ARG0:.*]]: tensor<1x32x32x8xf32>, %[[ARG1:.*]]: tensor<1x1x16x8xf32> -// CHECK: %[[CONST:.*]] = "tosa.const"() {value = dense<0.000000e+00> : tensor<16xf32>} -// CHECK: %[[RESHAPE:.*]] = "tosa.reshape"(%[[ARG1]]) {new_shape = array} -// CHECK: %[[TRANSPOSE:.*]] = "tosa.transpose_conv2d"(%[[ARG0]], %[[RESHAPE]], %[[CONST]]) {out_pad = array, out_shape = array, stride = array} +// CHECK: %[[CONST:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<16xf32>} +// CHECK: %[[RESHAPE:.*]] = "tosa.reshape"(%[[ARG1]]) <{new_shape = array} +// CHECK: %[[TRANSPOSE:.*]] = "tosa.transpose_conv2d"(%[[ARG0]], %[[RESHAPE]], %[[CONST]]) <{out_pad = array, out_shape = array, stride = array} // CHECK: return %[[TRANSPOSE]] func.func @test_transpose_conv2d(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<1x1x16x8xf32>) -> tensor<1x32x32x16xf32> { %3 = "tf.Const"() {value = dense<[1, 32, 32, 16]> : tensor<4xi32>} : () -> tensor<4xi32> @@ -46,10 +46,10 @@ func.func @test_transpose_conv2d(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<1x1 // CHECK-LABEL: test_conv3d // CHECK-SAME: %[[VAL_0:.*]]: tensor<2x4x128x128x8xf32> // 
CHECK-SAME: %[[VAL_1:.*]]: tensor<2x3x3x2x4xf32> -// CHECK-DAG: %[[VAL_2:.*]] = "tosa.const"() {value = dense<0.000000e+00> : tensor<4xf32>} -// CHECK-DAG: %[[VAL_3:.*]] = "tosa.const"() {value = dense<[4, 0, 1, 2, 3]> : tensor<5xi32>} +// CHECK-DAG: %[[VAL_2:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<4xf32>} +// CHECK-DAG: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<[4, 0, 1, 2, 3]> : tensor<5xi32>} // CHECK: %[[VAL_4:.*]] = "tosa.transpose"(%[[VAL_1]], %[[VAL_3]]) -// CHECK: %[[VAL_5:.*]] = "tosa.conv3d"(%[[VAL_0]], %[[VAL_4]], %[[VAL_2]]) {dilation = array, pad = array, stride = array} +// CHECK: %[[VAL_5:.*]] = "tosa.conv3d"(%[[VAL_0]], %[[VAL_4]], %[[VAL_2]]) <{dilation = array, pad = array, stride = array} func.func @test_conv3d(%arg0: tensor<2x4x128x128x8xf32>, %arg1: tensor<2x3x3x2x4xf32>) -> tensor<2x4x64x64x4xf32> { %0 = "tf.Conv3D"(%arg0, %arg1) {data_format = "NDHWC", device = "", dilations = [1, 1, 1, 1, 1], padding = "SAME", strides = [1, 1, 2, 2, 1]} : (tensor<2x4x128x128x8xf32>, tensor<2x3x3x2x4xf32>) -> tensor<2x4x64x64x4xf32> return %0 : tensor<2x4x64x64x4xf32> @@ -61,9 +61,9 @@ func.func @test_conv3d(%arg0: tensor<2x4x128x128x8xf32>, %arg1: tensor<2x3x3x2x4 // CHECK-SAME: %[[VAL_0:.*]]: tensor<3x32x16x16x5xf32> // CHECK-SAME: %[[VAL_1:.*]]: tensor<2x3x3x5x10xf32> // CHECK-SAME: %[[VAL_2:.*]]: tensor<10xf32>) -> tensor<3x32x16x16x10xf32> -// CHECK-DAG: %[[VAL_3:.*]] = "tosa.const"() {value = dense<[4, 0, 1, 2, 3]> : tensor<5xi32>} +// CHECK-DAG: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<[4, 0, 1, 2, 3]> : tensor<5xi32>} // CHECK: %[[VAL_4:.*]] = "tosa.transpose"(%[[VAL_1]], %[[VAL_3]]) -// CHECK: %[[VAL_5:.*]] = "tosa.conv3d"(%[[VAL_0]], %[[VAL_4]], %[[VAL_2]]) {dilation = array, pad = array, stride = array} +// CHECK: %[[VAL_5:.*]] = "tosa.conv3d"(%[[VAL_0]], %[[VAL_4]], %[[VAL_2]]) <{dilation = array, pad = array, stride = array} func.func @test_conv3d_bias(%arg0: tensor<3x32x16x16x5xf32>, %arg1: tensor<2x3x3x5x10xf32>, %bias: tensor<10xf32>) -> tensor<3x32x16x16x10xf32> { %0 = "tf.Conv3D"(%arg0, %arg1) {data_format = "NDHWC", device = "", dilations = [1, 1, 1, 1, 1], padding = "SAME", strides = [1, 1, 1, 1, 1]} : (tensor<3x32x16x16x5xf32>, tensor<2x3x3x5x10xf32>) -> tensor<3x32x16x16x10xf32> %1 = "tf.BiasAdd"(%0, %bias) {data_format = "NHWC", device = ""} : (tensor<3x32x16x16x10xf32>, tensor<10xf32>) -> tensor<3x32x16x16x10xf32> @@ -91,7 +91,7 @@ func.func @test_sub(%arg0: tensor<1x21x3xf32>, %arg1: tensor<13x21x3xf32>) -> te // ----- // CHECK-LABEL: test_mul -// CHECK: %[[VAR0:.*]] = "tosa.mul"(%arg0, %arg1) {shift = 0 : i32} +// CHECK: %[[VAR0:.*]] = "tosa.mul"(%arg0, %arg1) <{shift = 0 : i32} func.func @test_mul(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1x3xf32>) -> tensor<13x21x3xf32> { %2 = "tf.Mul"(%arg0, %arg1) : (tensor<13x21x3xf32>, tensor<13x1x3xf32>) -> tensor<13x21x3xf32> func.return %2 : tensor<13x21x3xf32> @@ -137,7 +137,7 @@ func.func @test_rcp(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { // ----- // CHECK-LABEL: test_relu -// CHECK: %[[VAR0:.*]] = "tosa.clamp"(%arg0) {max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} +// CHECK: %[[VAR0:.*]] = "tosa.clamp"(%arg0) <{max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} func.func @test_relu(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { %2 = "tf.Relu"(%arg0) : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32> func.return %2 : tensor<13x21x3xf32> @@ -146,7 
+146,7 @@ func.func @test_relu(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { // ----- // CHECK-LABEL: test_relu6 -// CHECK: %[[VAR0:.*]] = "tosa.clamp"(%arg0) {max_fp = 6.000000e+00 : f32, max_int = 6 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} +// CHECK: %[[VAR0:.*]] = "tosa.clamp"(%arg0) <{max_fp = 6.000000e+00 : f32, max_int = 6 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} func.func @test_relu6(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { %2 = "tf.Relu6"(%arg0) : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32> func.return %2 : tensor<13x21x3xf32> @@ -155,9 +155,9 @@ func.func @test_relu6(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { // ----- // CHECK-LABEL: test_leaky_relu -// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() {value = dense<0.000000e+00> : tensor<1x1xf32>} -// CHECK-DAG: %[[VAR1:.*]] = "tosa.const"() {value = dense<5.000000e-01> : tensor<1x1xf32>} -// CHECK-DAG: %[[VAR2:.*]] = "tosa.mul"(%arg0, %[[VAR1]]) {shift = 0 : i32} +// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<1x1xf32>} +// CHECK-DAG: %[[VAR1:.*]] = "tosa.const"() <{value = dense<5.000000e-01> : tensor<1x1xf32>} +// CHECK-DAG: %[[VAR2:.*]] = "tosa.mul"(%arg0, %[[VAR1]]) <{shift = 0 : i32} // CHECK-DAG: %[[VAR3:.*]] = "tosa.greater_equal"(%arg0, %[[VAR0]]) // CHECK: %[[VAR6:.*]] = "tosa.select"(%[[VAR3]], %arg0, %[[VAR2]]) func.func @test_leaky_relu(%arg0: tensor<4x4xf32>) -> tensor<4x4xf32> { @@ -168,7 +168,7 @@ func.func @test_leaky_relu(%arg0: tensor<4x4xf32>) -> tensor<4x4xf32> { // ----- // CHECK-LABEL: test_concat -// CHECK: %[[VAR0:.*]] = "tosa.concat"(%arg0, %arg1) {axis = 0 : i64} +// CHECK: %[[VAR0:.*]] = "tosa.concat"(%arg0, %arg1) <{axis = 0 : i64} func.func @test_concat(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x21x3xf32>) -> tensor<26x21x3xf32> { %2 = "tf.Const"() {value = dense<0> : tensor} : () -> tensor %3 = "tf.ConcatV2"(%arg0, %arg1, %2) : (tensor<13x21x3xf32>, tensor<13x21x3xf32>, tensor) -> tensor<26x21x3xf32> @@ -241,8 +241,8 @@ func.func @test_logical_not(%arg0: tensor<1x21x3xi1>) -> tensor<1x21x3xi1> { // ----- // CHECK-LABEL: test_reduce_any -// CHECK-DAG: %[[VAR0:.*]] = "tosa.reduce_any"(%arg0) {axis = 0 : i64} -// CHECK: %[[VAR1:.*]] = "tosa.reshape"(%[[VAR0]]) {new_shape = array} +// CHECK-DAG: %[[VAR0:.*]] = "tosa.reduce_any"(%arg0) <{axis = 0 : i64} +// CHECK: %[[VAR1:.*]] = "tosa.reshape"(%[[VAR0]]) <{new_shape = array} func.func @test_reduce_any(%arg0: tensor<13x21x3xi1>) -> tensor<21x3xi1> { %2 = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32> %3 = "tf.Any"(%arg0, %2) {keep_dims = false} : (tensor<13x21x3xi1>, tensor<1xi32>) -> tensor<21x3xi1> @@ -252,8 +252,8 @@ func.func @test_reduce_any(%arg0: tensor<13x21x3xi1>) -> tensor<21x3xi1> { // ----- // CHECK-LABEL: test_reduce_all -// CHECK-DAG: %[[VAR0:.*]] = "tosa.reduce_all"(%arg0) {axis = 0 : i64} -// CHECK: %[[VAR1:.*]] = "tosa.reshape"(%[[VAR0]]) {new_shape = array} +// CHECK-DAG: %[[VAR0:.*]] = "tosa.reduce_all"(%arg0) <{axis = 0 : i64} +// CHECK: %[[VAR1:.*]] = "tosa.reshape"(%[[VAR0]]) <{new_shape = array} func.func @test_reduce_all(%arg0: tensor<13x21x3xi1>) -> tensor<21x3xi1> { %2 = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32> %3 = "tf.All"(%arg0, %2) {keep_dims = false} : (tensor<13x21x3xi1>, tensor<1xi32>) -> tensor<21x3xi1> @@ -263,8 +263,8 @@ func.func @test_reduce_all(%arg0: tensor<13x21x3xi1>) -> tensor<21x3xi1> { // ----- // CHECK-LABEL: test_reduce_min -// CHECK-DAG: %[[VAR0:.*]] = 
"tosa.reduce_min"(%arg0) {axis = 0 : i64} -// CHECK: %[[VAR1:.*]] = "tosa.reshape"(%[[VAR0]]) {new_shape = array} +// CHECK-DAG: %[[VAR0:.*]] = "tosa.reduce_min"(%arg0) <{axis = 0 : i64} +// CHECK: %[[VAR1:.*]] = "tosa.reshape"(%[[VAR0]]) <{new_shape = array} func.func @test_reduce_min(%arg0: tensor<13x21x3xf32>) -> tensor<21x3xf32> { %2 = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32> %3 = "tf.Min"(%arg0, %2) {keep_dims = false} : (tensor<13x21x3xf32>, tensor<1xi32>) -> tensor<21x3xf32> @@ -274,8 +274,8 @@ func.func @test_reduce_min(%arg0: tensor<13x21x3xf32>) -> tensor<21x3xf32> { // ----- // CHECK-LABEL: test_reduce_max -// CHECK-DAG: %[[VAR0:.*]] = "tosa.reduce_max"(%arg0) {axis = 0 : i64} -// CHECK: %[[VAR1:.*]] = "tosa.reshape"(%[[VAR0]]) {new_shape = array} +// CHECK-DAG: %[[VAR0:.*]] = "tosa.reduce_max"(%arg0) <{axis = 0 : i64} +// CHECK: %[[VAR1:.*]] = "tosa.reshape"(%[[VAR0]]) <{new_shape = array} func.func @test_reduce_max(%arg0: tensor<13x21x3xf32>) -> tensor<21x3xf32> { %2 = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32> %3 = "tf.Max"(%arg0, %2) {keep_dims = false} : (tensor<13x21x3xf32>, tensor<1xi32>) -> tensor<21x3xf32> @@ -285,8 +285,8 @@ func.func @test_reduce_max(%arg0: tensor<13x21x3xf32>) -> tensor<21x3xf32> { // ----- // CHECK-LABEL: test_reduce_sum -// CHECK-DAG: %[[VAR0:.*]] = "tosa.reduce_sum"(%arg0) {axis = 0 : i64} -// CHECK: %[[VAR1:.*]] = "tosa.reshape"(%[[VAR0]]) {new_shape = array} +// CHECK-DAG: %[[VAR0:.*]] = "tosa.reduce_sum"(%arg0) <{axis = 0 : i64} +// CHECK: %[[VAR1:.*]] = "tosa.reshape"(%[[VAR0]]) <{new_shape = array} func.func @test_reduce_sum(%arg0: tensor<13x21x3xf32>) -> tensor<21x3xf32> { %2 = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32> %3 = "tf.Sum"(%arg0, %2) {keep_dims = false} : (tensor<13x21x3xf32>, tensor<1xi32>) -> tensor<21x3xf32> @@ -297,11 +297,11 @@ func.func @test_reduce_sum(%arg0: tensor<13x21x3xf32>) -> tensor<21x3xf32> { // CHECK-LABEL: test_reduce_sum_nonzero_axis // CHECK-SAME: %[[VAL_0:.*]]: tensor<10x20x30x40x50xf32> -// CHECK: %[[VAL_1:.*]] = "tosa.const"() {value = dense<[0, 1, 2, 4, 3]> : tensor<5xi32>} : () -> tensor<5xi32> +// CHECK: %[[VAL_1:.*]] = "tosa.const"() <{value = dense<[0, 1, 2, 4, 3]> : tensor<5xi32>}> : () -> tensor<5xi32> // CHECK: %[[VAL_2:.*]] = "tosa.transpose"(%[[VAL_0]], %[[VAL_1]]) : (tensor<10x20x30x40x50xf32>, tensor<5xi32>) -> tensor<10x20x30x50x40xf32> -// CHECK: %[[VAL_3:.*]] = "tosa.reshape"(%[[VAL_2]]) {new_shape = array} : (tensor<10x20x30x50x40xf32>) -> tensor<300000x40xf32> -// CHECK: %[[VAL_4:.*]] = "tosa.reduce_sum"(%[[VAL_3]]) {axis = 1 : i64} : (tensor<300000x40xf32>) -> tensor<300000x1xf32> -// CHECK: %[[VAL_5:.*]] = "tosa.reshape"(%[[VAL_4]]) {new_shape = array} : (tensor<300000x1xf32>) -> tensor<10x20x30x50xf32> +// CHECK: %[[VAL_3:.*]] = "tosa.reshape"(%[[VAL_2]]) <{new_shape = array}> : (tensor<10x20x30x50x40xf32>) -> tensor<300000x40xf32> +// CHECK: %[[VAL_4:.*]] = "tosa.reduce_sum"(%[[VAL_3]]) <{axis = 1 : i64}> : (tensor<300000x40xf32>) -> tensor<300000x1xf32> +// CHECK: %[[VAL_5:.*]] = "tosa.reshape"(%[[VAL_4]]) <{new_shape = array}> : (tensor<300000x1xf32>) -> tensor<10x20x30x50xf32> // CHECK: return %[[VAL_5]] : tensor<10x20x30x50xf32> func.func @test_reduce_sum_nonzero_axis(%arg0: tensor<10x20x30x40x50xf32> {tf._user_specified_name = "inp_list"}) -> tensor<10x20x30x50xf32> { %cst = "tf.Const"() {device = "", value = dense<3> : tensor} : () -> tensor @@ -313,10 +313,10 @@ func.func 
@test_reduce_sum_nonzero_axis(%arg0: tensor<10x20x30x40x50xf32> {tf._u // ----- // CHECK-LABEL: test_reduce_mean -// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() {value = dense<0.0769230798> : tensor<1x1xf32>} -// CHECK-DAG: %[[VAR1:.*]] = "tosa.reduce_sum"(%arg0) {axis = 0 : i64} -// CHECK-DAG: %[[VAR2:.*]] = "tosa.reshape"(%[[VAR1]]) {new_shape = array} -// CHECK: %[[VAR4:.*]] = "tosa.mul"(%[[VAR2]], %[[VAR0]]) {shift = 0 : i32} +// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() <{value = dense<0.0769230798> : tensor<1x1xf32>} +// CHECK-DAG: %[[VAR1:.*]] = "tosa.reduce_sum"(%arg0) <{axis = 0 : i64} +// CHECK-DAG: %[[VAR2:.*]] = "tosa.reshape"(%[[VAR1]]) <{new_shape = array} +// CHECK: %[[VAR4:.*]] = "tosa.mul"(%[[VAR2]], %[[VAR0]]) <{shift = 0 : i32} func.func @test_reduce_mean(%arg0: tensor<13x21x3xf32>) -> tensor<21x3xf32> { %2 = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32> %3 = "tf.Mean"(%arg0, %2) {keep_dims = false} : (tensor<13x21x3xf32>, tensor<1xi32>) -> tensor<21x3xf32> @@ -326,8 +326,8 @@ func.func @test_reduce_mean(%arg0: tensor<13x21x3xf32>) -> tensor<21x3xf32> { // ----- // CHECK-LABEL: test_reduce_product -// CHECK-DAG: %[[VAR0:.*]] = "tosa.reduce_prod"(%arg0) {axis = 0 : i64} -// CHECK: %[[VAR1:.*]] = "tosa.reshape"(%[[VAR0]]) {new_shape = array} +// CHECK-DAG: %[[VAR0:.*]] = "tosa.reduce_prod"(%arg0) <{axis = 0 : i64} +// CHECK: %[[VAR1:.*]] = "tosa.reshape"(%[[VAR0]]) <{new_shape = array} func.func @test_reduce_product(%arg0: tensor<13x21x3xf32>) -> tensor<21x3xf32> { %2 = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32> %3 = "tf.Prod"(%arg0, %2) {keep_dims = false} : (tensor<13x21x3xf32>, tensor<1xi32>) -> tensor<21x3xf32> @@ -420,12 +420,12 @@ func.func @test_rsqrt(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { // CHECK-LABEL: test_sin // CHECK-SAME: -> tensor<10xf32> func.func @test_sin(%arg0: tensor<10xf32>) -> tensor<*xf32> { - // CHECK-DAG: %[[ONE:.+]] = "tosa.const"() {value = dense<1.000000e+00> : tensor<1xf32>} - // CHECK-DAG: %[[TWO:.+]] = "tosa.const"() {value = dense<2.000000e+00> : tensor<1xf32>} - // CHECK-DAG: %[[RESULT_SCALE:.+]] = "tosa.const"() {value = dense<2.38418579E-7> : tensor<1xf32>} - // CHECK-DAG: %[[INT_MAX:.+]] = "tosa.const"() {value = dense<3.276700e+04> : tensor<1xf32>} - // CHECK-DAG: %[[IN_SCALE:.+]] = "tosa.const"() {value = dense<0.159154937> : tensor<1xf32>} - // CHECK-DAG: %[[TBLVAL:.+]] = "tosa.const"() {value = dense<{{.+}}> : tensor<513xi16>} + // CHECK-DAG: %[[ONE:.+]] = "tosa.const"() <{value = dense<1.000000e+00> : tensor<1xf32>} + // CHECK-DAG: %[[TWO:.+]] = "tosa.const"() <{value = dense<2.000000e+00> : tensor<1xf32>} + // CHECK-DAG: %[[RESULT_SCALE:.+]] = "tosa.const"() <{value = dense<2.38418579E-7> : tensor<1xf32>} + // CHECK-DAG: %[[INT_MAX:.+]] = "tosa.const"() <{value = dense<3.276700e+04> : tensor<1xf32>} + // CHECK-DAG: %[[IN_SCALE:.+]] = "tosa.const"() <{value = dense<0.159154937> : tensor<1xf32>} + // CHECK-DAG: %[[TBLVAL:.+]] = "tosa.const"() <{value = dense<{{.+}}> : tensor<513xi16>} // CHECK-DAG: %[[IN_SCALED:.+]] = "tosa.mul"(%arg0, %[[IN_SCALE]]) // CHECK-DAG: %[[FLOOR:.+]] = "tosa.floor"(%[[IN_SCALED]]) // CHECK-DAG: %[[SUB1:.+]] = "tosa.sub"(%[[IN_SCALED]], %[[FLOOR]]) @@ -447,13 +447,13 @@ func.func @test_sin(%arg0: tensor<10xf32>) -> tensor<*xf32> { // CHECK-LABEL: test_cos // CHECK-SAME: -> tensor<10xf32> func.func @test_cos(%arg0: tensor<10xf32>) -> tensor<*xf32> { - // CHECK-DAG: %[[RESULT_SCALE:.+]] = "tosa.const"() {value = dense<2.38418579E-7> : tensor<1xf32>} 
- // CHECK-DAG: %[[INT_MAX:.+]] = "tosa.const"() {value = dense<3.276700e+04> : tensor<1xf32>} - // CHECK-DAG: %[[ONE:.+]] = "tosa.const"() {value = dense<1.000000e+00> : tensor<1xf32>} - // CHECK-DAG: %[[TWO:.+]] = "tosa.const"() {value = dense<2.000000e+00> : tensor<1xf32>} - // CHECK-DAG: %[[IN_SCALE:.+]] = "tosa.const"() {value = dense<0.159154937> : tensor<1xf32>} - // CHECK-DAG: %[[HALF_PI:.+]] = "tosa.const"() {value = dense<1.57079637> : tensor<1xf32>} - // CHECK-DAG: %[[TBLVAL:.+]] = "tosa.const"() {value = dense<{{.+}}> : tensor<513xi16>} + // CHECK-DAG: %[[RESULT_SCALE:.+]] = "tosa.const"() <{value = dense<2.38418579E-7> : tensor<1xf32>} + // CHECK-DAG: %[[INT_MAX:.+]] = "tosa.const"() <{value = dense<3.276700e+04> : tensor<1xf32>} + // CHECK-DAG: %[[ONE:.+]] = "tosa.const"() <{value = dense<1.000000e+00> : tensor<1xf32>} + // CHECK-DAG: %[[TWO:.+]] = "tosa.const"() <{value = dense<2.000000e+00> : tensor<1xf32>} + // CHECK-DAG: %[[IN_SCALE:.+]] = "tosa.const"() <{value = dense<0.159154937> : tensor<1xf32>} + // CHECK-DAG: %[[HALF_PI:.+]] = "tosa.const"() <{value = dense<1.57079637> : tensor<1xf32>} + // CHECK-DAG: %[[TBLVAL:.+]] = "tosa.const"() <{value = dense<{{.+}}> : tensor<513xi16>} // CHECK-DAG: %[[IN_TRANSLATE:.+]] = "tosa.add"(%arg0, %[[HALF_PI]]) // CHECK-DAG: %[[IN_SCALED:.+]] = "tosa.mul"(%[[IN_TRANSLATE]], %[[IN_SCALE]]) // CHECK-DAG: %[[FLOOR:.+]] = "tosa.floor"(%[[IN_SCALED]]) @@ -473,6 +473,22 @@ func.func @test_cos(%arg0: tensor<10xf32>) -> tensor<*xf32> { // ----- +// CHECK-LABEL: test_sign +// CHECK-SAME: %[[VAL_0:.*]]: tensor<8x33xf32> +// CHECK-DAG: %[[VAL_1:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<1x1xf32>} +// CHECK-DAG: %[[VAL_2:.*]] = "tosa.const"() <{value = dense<-1.000000e+00> : tensor<1x1xf32>} +// CHECK-DAG: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<1.000000e+00> : tensor<1x1xf32>} +// CHECK: %[[VAL_4:.*]] = "tosa.greater"(%[[VAL_0]], %[[VAL_1]]) +// CHECK: %[[VAL_5:.*]] = "tosa.greater"(%[[VAL_1]], %[[VAL_0]]) +// CHECK: %[[VAL_6:.*]] = "tosa.select"(%[[VAL_5]], %[[VAL_2]], %[[VAL_1]]) +// CHECK: %[[VAL_7:.*]] = "tosa.select"(%[[VAL_4]], %[[VAL_3]], %[[VAL_6]]) +func.func @test_sign(%arg0: tensor<8x33xf32>) -> tensor<8x33xf32> { + %0 = "tf.Sign"(%arg0) : (tensor<8x33xf32>) -> tensor<8x33xf32> + func.return %0 : tensor<8x33xf32> +} + +// ----- + // CHECK-LABEL: test_sigmoid // CHECK: %[[VAR0:.*]] = "tosa.sigmoid"(%arg0) func.func @test_sigmoid(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { @@ -483,7 +499,7 @@ func.func @test_sigmoid(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { // ----- // CHECK-LABEL: test_square -// CHECK: %[[VAR0:.*]] = "tosa.mul"(%arg0, %arg0) {shift = 0 : i32} +// CHECK: %[[VAR0:.*]] = "tosa.mul"(%arg0, %arg0) <{shift = 0 : i32} func.func @test_square(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { %2 = "tf.Square"(%arg0) : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32> func.return %2 : tensor<13x21x3xf32> @@ -539,7 +555,7 @@ func.func @test_less_equal(%arg0: tensor<13x21x3xf32>, %arg1: tensor<1x21x3xf32> // ----- // CHECK-LABEL: test_argmax -// CHECK: %[[VAR0:.*]] = "tosa.argmax"(%arg0) {axis = 0 : i64} +// CHECK: %[[VAR0:.*]] = "tosa.argmax"(%arg0) <{axis = 0 : i64} func.func @test_argmax(%arg0: tensor<13x21x3xf32>) -> tensor<21x3xi32> { %2 = "tf.Const"() {value = dense<0> : tensor} : () -> tensor %3 = "tf.ArgMax"(%arg0, %2) : (tensor<13x21x3xf32>, tensor) -> tensor<21x3xi32> @@ -549,7 +565,7 @@ func.func @test_argmax(%arg0: tensor<13x21x3xf32>) -> tensor<21x3xi32> { // ----- // 
CHECK-LABEL: test_avg_pool2d -// CHECK: %[[VAR0:.*]] = "tosa.avg_pool2d"(%arg0) {kernel = array, pad = array, stride = array} +// CHECK: %[[VAR0:.*]] = "tosa.avg_pool2d"(%arg0) <{kernel = array, pad = array, stride = array} func.func @test_avg_pool2d(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> { %2 = "tf.AvgPool"(%arg0) {data_format = "NHWC", ksize = [1, 1, 1, 1], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> func.return %2 : tensor<1x32x32x8xf32> @@ -558,7 +574,7 @@ func.func @test_avg_pool2d(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32 // ----- // CHECK-LABEL: test_max_pool2d -// CHECK: %[[VAR0:.*]] = "tosa.max_pool2d"(%arg0) {kernel = array, pad = array, stride = array} +// CHECK: %[[VAR0:.*]] = "tosa.max_pool2d"(%arg0) <{kernel = array, pad = array, stride = array} func.func @test_max_pool2d(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> { %2 = "tf.MaxPool"(%arg0) {data_format = "NHWC", explicit_paddings = [], ksize = [1, 1, 1, 1], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> func.return %2 : tensor<1x32x32x8xf32> @@ -567,7 +583,7 @@ func.func @test_max_pool2d(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32 // ----- // CHECK-LABEL: test_reshape -// CHECK: %[[VAR0:.*]] = "tosa.reshape"(%arg0) {new_shape = array} +// CHECK: %[[VAR0:.*]] = "tosa.reshape"(%arg0) <{new_shape = array} func.func @test_reshape(%arg0: tensor<13x21x3xf32>) -> tensor<1x819xf32> { %0 = "tf.Const"() {value = dense<[1, 819]> : tensor<2xi32>} : () -> tensor<2xi32> %3 = "tf.Reshape"(%arg0, %0) : (tensor<13x21x3xf32>, tensor<2xi32>) -> tensor<1x819xf32> @@ -578,7 +594,7 @@ func.func @test_reshape(%arg0: tensor<13x21x3xf32>) -> tensor<1x819xf32> { // ----- // CHECK-LABEL: test_transpose -// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() {value = dense<[2, 0, 1]> : tensor<3xi32>} +// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() <{value = dense<[2, 0, 1]> : tensor<3xi32>} // CHECK: %[[VAR1:.*]] = "tosa.transpose"(%arg0, %[[VAR0]]) func.func @test_transpose(%arg0: tensor<13x21x3xf32>) -> tensor<3x13x21xf32> { %2 = "tf.Const"() {value = dense<[2, 0, 1]> : tensor<3xi32>} : () -> tensor<3xi32> @@ -589,7 +605,7 @@ func.func @test_transpose(%arg0: tensor<13x21x3xf32>) -> tensor<3x13x21xf32> { // ----- // CHECK-LABEL: test_slice -// CHECK: %[[VAR0:.*]] = "tosa.slice"(%arg0) {size = array, start = array} +// CHECK: %[[VAR0:.*]] = "tosa.slice"(%arg0) <{size = array, start = array} func.func @test_slice(%arg0: tensor<13x21x3xf32>) -> tensor<4x11x1xf32> { %2 = "tf.Const"() {value = dense<[6, 8, 0]> : tensor<3xi64>} : () -> tensor<3xi64> %3 = "tf.Const"() {value = dense<[4, 11, 1]> : tensor<3xi64>} : () -> tensor<3xi64> @@ -600,10 +616,10 @@ func.func @test_slice(%arg0: tensor<13x21x3xf32>) -> tensor<4x11x1xf32> { // ----- // CHECK-LABEL: test_strided_slice -// CHECK-DAG: %[[VAR0:.*]] = "tosa.slice"(%arg0) {size = array, start = array} -// CHECK-DAG: %[[VAR1:.*]] = "tosa.reshape"(%[[VAR0]]) {new_shape = array} -// CHECK-DAG: %[[VAR2:.*]] = "tosa.slice"(%[[VAR1]]) {size = array, start = array} -// CHECK: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]]) {new_shape = array} +// CHECK-DAG: %[[VAR0:.*]] = "tosa.slice"(%arg0) <{size = array, start = array} +// CHECK-DAG: %[[VAR1:.*]] = "tosa.reshape"(%[[VAR0]]) <{new_shape = array} +// CHECK-DAG: %[[VAR2:.*]] = "tosa.slice"(%[[VAR1]]) <{size = array, start = array} +// CHECK: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]]) <{new_shape = array} func.func @test_strided_slice(%arg0: 
tensor<13x21x3xf32>) -> tensor<9x7x2xf32> { %2 = "tf.Const"() {value = dense<[4, 0, 1]> : tensor<3xi64>} : () -> tensor<3xi64> %3 = "tf.Const"() {value = dense<[13, 21, 3]> : tensor<3xi64>} : () -> tensor<3xi64> @@ -615,7 +631,7 @@ func.func @test_strided_slice(%arg0: tensor<13x21x3xf32>) -> tensor<9x7x2xf32> { // ----- // CHECK-LABEL: test_select -// CHECK: %[[VAR1:.*]] = "tosa.reshape"(%arg2) {new_shape = array} : (tensor<1xi1>) -> tensor<1x1x1xi1> +// CHECK: %[[VAR1:.*]] = "tosa.reshape"(%arg2) <{new_shape = array}> : (tensor<1xi1>) -> tensor<1x1x1xi1> // CHECK: %[[VAR2:.*]] = "tosa.select"(%[[VAR1]], %arg0, %arg1) func.func @test_select(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x21x3xf32>, %arg2: tensor<1xi1>) -> tensor<13x21x3xf32> { %2 = "tf.SelectV2"(%arg2, %arg0, %arg1) : (tensor<1xi1>, tensor<13x21x3xf32>, tensor<13x21x3xf32>) -> tensor<13x21x3xf32> @@ -636,7 +652,7 @@ func.func @test_addn(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x21x3xf32>, %ar // ----- // CHECK-LABEL: test_concatv2 -// CHECK: %[[VAR0:.*]] = "tosa.concat"(%arg0, %arg1, %arg2, %arg3) {axis = 0 : i64} +// CHECK: %[[VAR0:.*]] = "tosa.concat"(%arg0, %arg1, %arg2, %arg3) <{axis = 0 : i64} func.func @test_concatv2(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x21x3xf32>, %arg2: tensor<13x21x3xf32>, %arg3: tensor<13x21x3xf32>) -> tensor<52x21x3xf32> { %2 = "tf.Const"() {value = dense<0> : tensor} : () -> tensor %3 = "tf.ConcatV2"(%arg0, %arg1, %arg2, %arg3, %2) : (tensor<13x21x3xf32>, tensor<13x21x3xf32>, tensor<13x21x3xf32>, tensor<13x21x3xf32>, tensor) -> tensor<52x21x3xf32> @@ -646,8 +662,8 @@ func.func @test_concatv2(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x21x3xf32>, // ----- // CHECK-LABEL: test_stack -// CHECK-DAG: %[[VAR0:.*]] = "tosa.concat"(%arg0, %arg1, %arg2, %arg3) {axis = 0 : i64} -// CHECK: %[[VAR1:.*]] = "tosa.reshape"(%[[VAR0]]) {new_shape = array} +// CHECK-DAG: %[[VAR0:.*]] = "tosa.concat"(%arg0, %arg1, %arg2, %arg3) <{axis = 0 : i64} +// CHECK: %[[VAR1:.*]] = "tosa.reshape"(%[[VAR0]]) <{new_shape = array} func.func @test_stack(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x21x3xf32>, %arg2: tensor<13x21x3xf32>, %arg3: tensor<13x21x3xf32>) -> tensor<4x13x21x3xf32> { %2 = "tf.Pack"(%arg0, %arg1, %arg2, %arg3) {axis = 0 : i64} : (tensor<13x21x3xf32>, tensor<13x21x3xf32>, tensor<13x21x3xf32>, tensor<13x21x3xf32>) -> tensor<4x13x21x3xf32> func.return %2 : tensor<4x13x21x3xf32> @@ -656,7 +672,7 @@ func.func @test_stack(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x21x3xf32>, %a // ----- // CHECK-LABEL: test_unstack -// CHECK: %[[VAR1:.*]] = "tosa.reshape"(%arg0) {new_shape = array} +// CHECK: %[[VAR1:.*]] = "tosa.reshape"(%arg0) <{new_shape = array} func.func @test_unstack(%arg0: tensor<1x32x32x8xf32>) -> tensor<32x32x8xf32> { %2 = "tf.Unpack"(%arg0) {axis = 0 : i64} : (tensor<1x32x32x8xf32>) -> tensor<32x32x8xf32> %3 = "tf.Identity"(%2) : (tensor<32x32x8xf32>) -> tensor<32x32x8xf32> @@ -666,8 +682,8 @@ func.func @test_unstack(%arg0: tensor<1x32x32x8xf32>) -> tensor<32x32x8xf32> { // ----- // CHECK-LABEL: test_pad -// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() {value = dense<1> : tensor<3x2xi32>} -// CHECK-DAG: %[[PVAL:.*]] = "tosa.const"() {value = dense<0.000000e+00> : tensor} +// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() <{value = dense<1> : tensor<3x2xi32>} +// CHECK-DAG: %[[PVAL:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor} // CHECK: %[[VAR1:.*]] = "tosa.pad"(%arg0, %[[VAR0]], %[[PVAL]]) func.func @test_pad(%arg0: tensor<13x21x3xf32>) -> tensor<15x23x5xf32> { %2 = "tf.Const"() {value 
= dense<1> : tensor<3x2xi32>} : () -> tensor<3x2xi32> @@ -678,7 +694,7 @@ func.func @test_pad(%arg0: tensor<13x21x3xf32>) -> tensor<15x23x5xf32> { // ----- // CHECK-LABEL: test_expand_dims -// CHECK: %[[VAR0:.*]] = "tosa.reshape"(%arg0) {new_shape = array} +// CHECK: %[[VAR0:.*]] = "tosa.reshape"(%arg0) <{new_shape = array} func.func @test_expand_dims(%arg0: tensor<13x21x3xf32>) -> tensor<1x13x21x3xf32> { %2 = "tf.Const"() {value = dense<0> : tensor} : () -> tensor %3 = "tf.ExpandDims"(%arg0, %2) : (tensor<13x21x3xf32>, tensor) -> tensor<1x13x21x3xf32> @@ -688,7 +704,7 @@ func.func @test_expand_dims(%arg0: tensor<13x21x3xf32>) -> tensor<1x13x21x3xf32> // ----- // CHECK-LABEL: test_expand_dims_negative_index -// CHECK: %[[VAR0:.*]] = "tosa.reshape"(%arg0) {new_shape = array} +// CHECK: %[[VAR0:.*]] = "tosa.reshape"(%arg0) <{new_shape = array} func.func @test_expand_dims_negative_index(%arg0: tensor<13x21x3xf32>) -> tensor<13x1x21x3xf32> { %2 = "tf.Const"() {value = dense<-2> : tensor<1xi32>} : () -> tensor<1xi32> %3 = "tf.ExpandDims"(%arg0, %2) : (tensor<13x21x3xf32>, tensor<1xi32>) -> tensor<13x1x21x3xf32> @@ -698,7 +714,7 @@ func.func @test_expand_dims_negative_index(%arg0: tensor<13x21x3xf32>) -> tensor // ----- // CHECK-LABEL: test_shape -// CHECK: %[[VAR0:.*]] = "tosa.const"() {value = dense<[13, 21, 3]> : tensor<3xi32>} +// CHECK: %[[VAR0:.*]] = "tosa.const"() <{value = dense<[13, 21, 3]> : tensor<3xi32>} func.func @test_shape() -> tensor<3xi32> { %3 = "tf.Const"() {value = dense<[13, 21, 3]> : tensor<3xi32>} : () -> tensor<3xi32> func.return %3 : tensor<3xi32> @@ -707,7 +723,7 @@ func.func @test_shape() -> tensor<3xi32> { // ----- // CHECK-LABEL: test_rank -// CHECK: %[[VAR0:.*]] = "tosa.const"() {value = dense<3> : tensor} +// CHECK: %[[VAR0:.*]] = "tosa.const"() <{value = dense<3> : tensor} func.func @test_rank() -> tensor { %3 = "tf.Const"() {value = dense<3> : tensor} : () -> tensor func.return %3 : tensor @@ -716,8 +732,8 @@ func.func @test_rank() -> tensor { // ----- // CHECK-LABEL: test_elu -// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() {value = dense<1.000000e+00> : tensor<1x1x1xf32>} -// CHECK-DAG: %[[VAR1:.*]] = "tosa.const"() {value = dense<0.000000e+00> : tensor<1x1x1xf32>} +// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() <{value = dense<1.000000e+00> : tensor<1x1x1xf32>} +// CHECK-DAG: %[[VAR1:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<1x1x1xf32>} // CHECK-DAG: %[[VAR2:.*]] = "tosa.exp"(%arg0) // CHECK-DAG: %[[VAR4:.*]] = "tosa.sub"(%[[VAR2]], %[[VAR0]]) // CHECK-DAG: %[[VAR6:.*]] = "tosa.greater_equal"(%arg0, %[[VAR1]]) @@ -730,10 +746,12 @@ func.func @test_elu(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { // ----- // CHECK-LABEL: test_softmax -// CHECK-DAG: %[[VAR0:.*]] = "tosa.exp"(%arg0) -// CHECK-DAG: %[[VAR1:.*]] = "tosa.reduce_sum"(%[[VAR0]]) {axis = 2 : i64} -// CHECK-DAG: %[[VAR2:.*]] = "tosa.reciprocal"(%[[VAR1]]) -// CHECK: %[[VAR3:.*]] = "tosa.mul"(%[[VAR0]], %[[VAR2]]) {shift = 0 : i32} +// CHECK-DAG: %[[VAR0:.*]] = "tosa.reduce_max"(%arg0) +// CHECK-DAG: %[[VAR1:.*]] = "tosa.sub"(%arg0, %[[VAR0]]) +// CHECK-DAG: %[[VAR2:.*]] = "tosa.exp"(%[[VAR1]]) +// CHECK-DAG: %[[VAR3:.*]] = "tosa.reduce_sum"(%[[VAR2]]) <{axis = 2 : i64} +// CHECK-DAG: %[[VAR4:.*]] = "tosa.reciprocal"(%[[VAR3]]) +// CHECK: %[[VAR5:.*]] = "tosa.mul"(%[[VAR2]], %[[VAR4]]) <{shift = 0 : i32} func.func @test_softmax(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { %2 = "tf.Softmax"(%arg0) : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32> func.return %2 : 
tensor<13x21x3xf32> @@ -743,9 +761,9 @@ func.func @test_softmax(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { // CHECK-LABEL: test_log_softmax // CHECK-DAG: %[[VAR0:.*]] = "tosa.exp"(%arg0) -// CHECK-DAG: %[[VAR1:.*]] = "tosa.reduce_sum"(%[[VAR0]]) {axis = 2 : i64} +// CHECK-DAG: %[[VAR1:.*]] = "tosa.reduce_sum"(%[[VAR0]]) <{axis = 2 : i64} // CHECK-DAG: %[[VAR2:.*]] = "tosa.reciprocal"(%[[VAR1]]) -// CHECK-DAG: %[[VAR3:.*]] = "tosa.mul"(%[[VAR0]], %[[VAR2]]) {shift = 0 : i32} +// CHECK-DAG: %[[VAR3:.*]] = "tosa.mul"(%[[VAR0]], %[[VAR2]]) <{shift = 0 : i32} // CHECK: %[[VAR4:.*]] = "tosa.log"(%[[VAR3]]) func.func @test_log_softmax(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { %2 = "tf.LogSoftmax"(%arg0) : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32> @@ -764,10 +782,10 @@ func.func @test_batch_matmul_3d(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x3x4 // ----- // CHECK-LABEL: test_batch_matmul_4d -// CHECK-DAG: %[[VAR0:.*]] = "tosa.reshape"(%arg0) {new_shape = array} -// CHECK-DAG: %[[VAR1:.*]] = "tosa.reshape"(%arg1) {new_shape = array} +// CHECK-DAG: %[[VAR0:.*]] = "tosa.reshape"(%arg0) <{new_shape = array} +// CHECK-DAG: %[[VAR1:.*]] = "tosa.reshape"(%arg1) <{new_shape = array} // CHECK-DAG: %[[VAR2:.*]] = "tosa.matmul"(%[[VAR0]], %[[VAR1]]) -// CHECK: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]]) {new_shape = array} +// CHECK: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]]) <{new_shape = array} func.func @test_batch_matmul_4d(%arg0: tensor<5x13x21x3xf32>, %arg1: tensor<5x13x3x42xf32>) -> tensor<5x13x21x42xf32> { %0 = "tf.BatchMatMulV2"(%arg0, %arg1) {adj_x = false, adj_y = false, device = ""} : (tensor<5x13x21x3xf32>, tensor<5x13x3x42xf32>) -> tensor<5x13x21x42xf32> func.return %0 : tensor<5x13x21x42xf32> @@ -776,10 +794,10 @@ func.func @test_batch_matmul_4d(%arg0: tensor<5x13x21x3xf32>, %arg1: tensor<5x13 // ----- // CHECK-LABEL: test_matmul -// CHECK-DAG: %[[VAR0:.*]] = "tosa.reshape"(%arg0) {new_shape = array} -// CHECK-DAG: %[[VAR1:.*]] = "tosa.reshape"(%arg1) {new_shape = array} +// CHECK-DAG: %[[VAR0:.*]] = "tosa.reshape"(%arg0) <{new_shape = array} +// CHECK-DAG: %[[VAR1:.*]] = "tosa.reshape"(%arg1) <{new_shape = array} // CHECK-DAG: %[[VAR2:.*]] = "tosa.matmul"(%[[VAR0]], %[[VAR1]]) -// CHECK: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]]) {new_shape = array} +// CHECK: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]]) <{new_shape = array} func.func @test_matmul(%arg0: tensor<14x19xf32>, %arg1: tensor<19x28xf32>) -> tensor<14x28xf32> { %2 = "tf.MatMul"(%arg0, %arg1) {transpose_a = false, transpose_b = false} : (tensor<14x19xf32>, tensor<19x28xf32>) -> tensor<14x28xf32> func.return %2 : tensor<14x28xf32> @@ -788,7 +806,7 @@ func.func @test_matmul(%arg0: tensor<14x19xf32>, %arg1: tensor<19x28xf32>) -> te // ----- // CHECK-LABEL: test_add_scalar -// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() {value = dense<1.000000e+00> : tensor<1x1x1xf32>} +// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() <{value = dense<1.000000e+00> : tensor<1x1x1xf32>} // CHECK: %[[VAR2:.*]] = "tosa.add"(%arg0, %[[VAR0]]) func.func @test_add_scalar(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { %2 = "tf.Const"() {value = dense<1.000000e+00> : tensor} : () -> tensor @@ -799,8 +817,8 @@ func.func @test_add_scalar(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { // ----- // CHECK-LABEL: test_add_1d -// CHECK-DAG: %[[VAR0:.*]] = "tosa.reduce_sum"(%arg1) {axis = 0 : i64} -// CHECK-DAG: %[[VAR1:.*]] = "tosa.reduce_sum"(%[[VAR0]]) {axis = 1 : i64} +// CHECK-DAG: %[[VAR0:.*]] = "tosa.reduce_sum"(%arg1) <{axis = 0 : i64} +// 
CHECK-DAG: %[[VAR1:.*]] = "tosa.reduce_sum"(%[[VAR0]]) <{axis = 1 : i64} // CHECK: %[[VAR2:.*]] = "tosa.add"(%arg0, %[[VAR1]]) func.func @test_add_1d(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { %0 = "tf.Const"() {value = dense<[0, 1]> : tensor<2xi32>} : () -> tensor<2xi32> @@ -812,9 +830,9 @@ func.func @test_add_1d(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x21x3xf32>) - // ----- // CHECK-LABEL: test_split -// CHECK-DAG: %[[VAR0:.*]] = "tosa.slice"(%arg0) {size = array, start = array} -// CHECK-DAG: %[[VAR1:.*]] = "tosa.slice"(%arg0) {size = array, start = array} -// CHECK: %[[VAR2:.*]] = "tosa.slice"(%arg0) {size = array, start = array} +// CHECK-DAG: %[[VAR0:.*]] = "tosa.slice"(%arg0) <{size = array, start = array} +// CHECK-DAG: %[[VAR1:.*]] = "tosa.slice"(%arg0) <{size = array, start = array} +// CHECK: %[[VAR2:.*]] = "tosa.slice"(%arg0) <{size = array, start = array} func.func @test_split(%arg0: tensor<13x21x3xf32>) -> (tensor<13x7x3xf32>, tensor<13x7x3xf32>, tensor<13x7x3xf32>) { %6 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor %7:3 = "tf.Split"(%6, %arg0) : (tensor, tensor<13x21x3xf32>) -> (tensor<13x7x3xf32>, tensor<13x7x3xf32>, tensor<13x7x3xf32>) @@ -845,13 +863,13 @@ func.func @test_reverse(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { // ----- // CHECK-LABEL: test_space_to_batch -// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() {value = dense<{{\[}}[0, 0], [0, 1], [0, 0]]> : tensor<3x2xi32>} -// CHECK-DAG: %[[VAR1:.*]] = "tosa.const"() {value = dense<[2, 0, 1, 3]> : tensor<4xi32>} -// CHECK-DAG: %[[PVAL:.*]] = "tosa.const"() {value = dense<0.000000e+00> : tensor} +// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() <{value = dense<{{\[}}[0, 0], [0, 1], [0, 0]]> : tensor<3x2xi32>} +// CHECK-DAG: %[[VAR1:.*]] = "tosa.const"() <{value = dense<[2, 0, 1, 3]> : tensor<4xi32>} +// CHECK-DAG: %[[PVAL:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor} // CHECK-DAG: %[[VAR2:.*]] = "tosa.pad"(%arg0, %[[VAR0]], %[[PVAL]]) -// CHECK-DAG: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]]) {new_shape = array} +// CHECK-DAG: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]]) <{new_shape = array} // CHECK-DAG: %[[VAR4:.*]] = "tosa.transpose"(%[[VAR3]], %[[VAR1]]) -// CHECK: %[[VAR5:.*]] = "tosa.reshape"(%[[VAR4]]) {new_shape = array} +// CHECK: %[[VAR5:.*]] = "tosa.reshape"(%[[VAR4]]) <{new_shape = array} func.func @test_space_to_batch(%arg0: tensor<13x21x3xf32>) -> tensor<26x11x3xf32> { %2 = "tf.Const"() {value = dense<2> : tensor<1xi32>} : () -> tensor<1xi32> %3 = "tf.Const"() {value = dense<[[0, 1]]> : tensor<1x2xi32>} : () -> tensor<1x2xi32> @@ -862,12 +880,12 @@ func.func @test_space_to_batch(%arg0: tensor<13x21x3xf32>) -> tensor<26x11x3xf32 // ----- // CHECK-LABEL: test_batch_to_space -// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() {value = dense<[3, 1, 2, 0]> : tensor<4xi32>} -// CHECK-DAG: %[[VAR1:.*]] = "tosa.const"() {value = dense<[2, 3, 0, 4, 1, 5]> : tensor<6xi32>} +// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() <{value = dense<[3, 1, 2, 0]> : tensor<4xi32>} +// CHECK-DAG: %[[VAR1:.*]] = "tosa.const"() <{value = dense<[2, 3, 0, 4, 1, 5]> : tensor<6xi32>} // CHECK-DAG: %[[VAR2:.*]] = "tosa.transpose"(%arg0, %[[VAR0]]) -// CHECK-DAG: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]]) {new_shape = array} +// CHECK-DAG: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]]) <{new_shape = array} // CHECK-DAG: %[[VAR4:.*]] = "tosa.transpose"(%[[VAR3]], %[[VAR1]]) -// CHECK-DAG: %[[VAR5:.*]] = "tosa.reshape"(%[[VAR4]]) {new_shape = array} +// CHECK-DAG: %[[VAR5:.*]] = 
"tosa.reshape"(%[[VAR4]]) <{new_shape = array} // CHECK: return %[[VAR5]] func.func @test_batch_to_space(%arg0: tensor<1x32x32x8xf32>) -> tensor<2x64x64x1xf32> { %2 = "tf.Const"() {value = dense<2> : tensor<2xi32>} : () -> tensor<2xi32> @@ -881,10 +899,10 @@ func.func @test_batch_to_space(%arg0: tensor<1x32x32x8xf32>) -> tensor<2x64x64x1 // ----- // CHECK-LABEL: test_space_to_depth -// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() {value = dense<[0, 1, 3, 2, 4, 5]> : tensor<6xi32>} -// CHECK-DAG: %[[VAR1:.*]] = "tosa.reshape"(%arg0) {new_shape = array} +// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() <{value = dense<[0, 1, 3, 2, 4, 5]> : tensor<6xi32>} +// CHECK-DAG: %[[VAR1:.*]] = "tosa.reshape"(%arg0) <{new_shape = array} // CHECK-DAG: %[[VAR2:.*]] = "tosa.transpose"(%[[VAR1]], %[[VAR0]]) -// CHECK: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]]) {new_shape = array} +// CHECK: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]]) <{new_shape = array} func.func @test_space_to_depth(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x16x16x32xf32> { %2 = "tf.SpaceToDepth"(%arg0) {block_size = 2 : i64, data_format = "NHWC"} : (tensor<1x32x32x8xf32>) -> tensor<1x16x16x32xf32> func.return %2 : tensor<1x16x16x32xf32> @@ -893,10 +911,10 @@ func.func @test_space_to_depth(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x16x16x3 // ----- // CHECK-LABEL: test_depth_to_space -// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() {value = dense<[0, 1, 3, 2, 4, 5]> : tensor<6xi32>} -// CHECK-DAG: %[[VAR1:.*]] = "tosa.reshape"(%arg0) {new_shape = array} +// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() <{value = dense<[0, 1, 3, 2, 4, 5]> : tensor<6xi32>} +// CHECK-DAG: %[[VAR1:.*]] = "tosa.reshape"(%arg0) <{new_shape = array} // CHECK-DAG: %[[VAR2:.*]] = "tosa.transpose"(%[[VAR1]], %[[VAR0]]) -// CHECK: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]]) {new_shape = array} +// CHECK: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]]) <{new_shape = array} func.func @test_depth_to_space(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x64x64x2xf32> { %2 = "tf.DepthToSpace"(%arg0) {block_size = 2 : i64, data_format = "NHWC"} : (tensor<1x32x32x8xf32>) -> tensor<1x64x64x2xf32> func.return %2 : tensor<1x64x64x2xf32> @@ -914,7 +932,7 @@ func.func @test_left_shift(%arg0: tensor<4x4xi32>, %arg1: tensor<1x1xi32>) -> te // ----- // CHECK-LABEL: test_right_shift -// CHECK: %[[VAR0:.*]] = "tosa.arithmetic_right_shift"(%arg0, %arg1) {round = false} +// CHECK: %[[VAR0:.*]] = "tosa.arithmetic_right_shift"(%arg0, %arg1) <{round = false} func.func @test_right_shift(%arg0: tensor<4x4xi32>, %arg1: tensor<1x1xi32>) -> tensor<4x4xi32> { %0 = "tf.RightShift"(%arg0, %arg1) : (tensor<4x4xi32>, tensor<1x1xi32>) -> tensor<4x4xi32> func.return %0 : tensor<4x4xi32> @@ -924,13 +942,13 @@ func.func @test_right_shift(%arg0: tensor<4x4xi32>, %arg1: tensor<1x1xi32>) -> t // CHECK-LABEL: @test_one_hot // CHECK-SAME: %[[ARG0_0:.*]]: tensor<4x4xi32>, %[[ARG1_0:.*]]: tensor, %[[ARG2:.*]]: tensor -// CHECK: %[[RESHAPE_0:.*]] = "tosa.reshape"(%[[ARG1_0]]) {new_shape = array} -// CHECK: %[[TILE:.*]] = "tosa.tile"(%[[RESHAPE_0]]) {multiples = array} -// CHECK: %[[RESHAPE_1:.*]] = "tosa.reshape"(%[[ARG2]]) {new_shape = array} -// CHECK: %[[TILE_0:.*]] = "tosa.tile"(%[[RESHAPE_1]]) {multiples = array} -// CHECK: %[[RESHAPE_2:.*]] = "tosa.reshape"(%[[ARG0_0]]) {new_shape = array} +// CHECK: %[[RESHAPE_0:.*]] = "tosa.reshape"(%[[ARG1_0]]) <{new_shape = array} +// CHECK: %[[TILE:.*]] = "tosa.tile"(%[[RESHAPE_0]]) <{multiples = array} +// CHECK: %[[RESHAPE_1:.*]] = "tosa.reshape"(%[[ARG2]]) <{new_shape = array} +// CHECK: %[[TILE_0:.*]] 
= "tosa.tile"(%[[RESHAPE_1]]) <{multiples = array} +// CHECK: %[[RESHAPE_2:.*]] = "tosa.reshape"(%[[ARG0_0]]) <{new_shape = array} // CHECK: %[[SCATTER:.*]] = "tosa.scatter"(%[[TILE_0]], %[[RESHAPE_2]], %[[TILE]]) -// CHECK: %[[RESHAPE_3:.*]] = "tosa.reshape"(%[[SCATTER]]) {new_shape = array} +// CHECK: %[[RESHAPE_3:.*]] = "tosa.reshape"(%[[SCATTER]]) <{new_shape = array} // CHECK: return %[[RESHAPE_3]] func.func @test_one_hot(%arg0: tensor<4x4xi32>, %arg1: tensor, %arg2: tensor) -> tensor<4x4x2xf32> { %0 = "tf.Const"() {value = dense<2> : tensor} : () -> tensor @@ -941,18 +959,18 @@ func.func @test_one_hot(%arg0: tensor<4x4xi32>, %arg1: tensor, %arg2: tenso // ----- // CHECK-LABEL: test_fakequant_with_min_max_args -// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() {value = dense<-2.00003052> : tensor<1x1x1xf32>} -// CHECK-DAG: %[[VAR1:.*]] = "tosa.const"() {value = dense<1.99996948> : tensor<1x1x1xf32>} -// CHECK-DAG: %[[VAR2:.*]] = "tosa.const"() {value = dense<6.10360876E-5> : tensor<1x1x1xf32>} -// CHECK-DAG: %[[VAR3:.*]] = "tosa.const"() {value = dense<16383.75> : tensor<1x1x1xf32>} -// CHECK-DAG: %[[VAR4:.*]] = "tosa.const"() {value = dense<5.000000e-01> : tensor<1x1x1xf32>} +// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() <{value = dense<-2.00003052> : tensor<1x1x1xf32>} +// CHECK-DAG: %[[VAR1:.*]] = "tosa.const"() <{value = dense<1.99996948> : tensor<1x1x1xf32>} +// CHECK-DAG: %[[VAR2:.*]] = "tosa.const"() <{value = dense<6.10360876E-5> : tensor<1x1x1xf32>} +// CHECK-DAG: %[[VAR3:.*]] = "tosa.const"() <{value = dense<16383.75> : tensor<1x1x1xf32>} +// CHECK-DAG: %[[VAR4:.*]] = "tosa.const"() <{value = dense<5.000000e-01> : tensor<1x1x1xf32>} // CHECK-DAG: %[[VAR6:.*]] = "tosa.minimum"(%arg0, %[[VAR1]]) // CHECK-DAG: %[[VAR8:.*]] = "tosa.maximum"(%[[VAR6]], %[[VAR0]]) // CHECK-DAG: %[[VAR10:.*]] = "tosa.sub"(%[[VAR8]], %[[VAR0]]) -// CHECK-DAG: %[[VAR12:.*]] = "tosa.mul"(%[[VAR10]], %[[VAR3]]) {shift = 0 : i32} +// CHECK-DAG: %[[VAR12:.*]] = "tosa.mul"(%[[VAR10]], %[[VAR3]]) <{shift = 0 : i32} // CHECK-DAG: %[[VAR14:.*]] = "tosa.add"(%[[VAR12]], %[[VAR4]]) // CHECK-DAG: %[[VAR15:.*]] = "tosa.floor"(%[[VAR14]]) -// CHECK-DAG: %[[VAR17:.*]] = "tosa.mul"(%[[VAR15]], %[[VAR2]]) {shift = 0 : i32} +// CHECK-DAG: %[[VAR17:.*]] = "tosa.mul"(%[[VAR15]], %[[VAR2]]) <{shift = 0 : i32} // CHECK: %[[VAR19:.*]] = "tosa.add"(%[[VAR17]], %[[VAR0]]) func.func @test_fakequant_with_min_max_args(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { %2 = "tf.FakeQuantWithMinMaxArgs"(%arg0) {max = 2.000000e+00 : f32, min = -2.000000e+00 : f32, narrow_range = false, num_bits = 16 : i64} : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32> @@ -961,10 +979,10 @@ func.func @test_fakequant_with_min_max_args(%arg0: tensor<13x21x3xf32>) -> tenso // ----- // CHECK-LABEL: test_gather -// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() {value = dense<{{.*}}> : tensor<1x49xi32>} -// CHECK-DAG: %[[VAR4:.*]] = "tosa.reshape"(%arg0) {new_shape = array} +// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() <{value = dense<{{.*}}> : tensor<1x49xi32>} +// CHECK-DAG: %[[VAR4:.*]] = "tosa.reshape"(%arg0) <{new_shape = array} // CHECK-DAG: %[[VAR6:.*]] = "tosa.gather"(%[[VAR4]], %[[VAR0]]) -// CHECK-DAG: %[[VAR7:.*]] = "tosa.reshape"(%[[VAR6]]) {new_shape = array} +// CHECK-DAG: %[[VAR7:.*]] = "tosa.reshape"(%[[VAR6]]) <{new_shape = array} // CHECK: return %[[VAR7]] func.func @test_gather(%arg0: tensor<13x21x3xf32>) -> tensor<7x7x21x3xf32> { %0 = "tf.Const"() {device = "", value = dense<0> : tensor} : () -> tensor @@ -976,10 +994,10 @@ func.func 
@test_gather(%arg0: tensor<13x21x3xf32>) -> tensor<7x7x21x3xf32> { // ----- // CHECK-LABEL: test_gather_nd -// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() {value = dense<{{.*}}> : tensor<1x42xi32>} -// CHECK-DAG: %[[VAR1:.*]] = "tosa.reshape"(%arg0) {new_shape = array} +// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() <{value = dense<{{.*}}> : tensor<1x42xi32>} +// CHECK-DAG: %[[VAR1:.*]] = "tosa.reshape"(%arg0) <{new_shape = array} // CHECK-DAG: %[[VAR2:.*]] = "tosa.gather"(%[[VAR1]], %[[VAR0]]) -// CHECK: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]]) {new_shape = array} +// CHECK: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]]) <{new_shape = array} func.func @test_gather_nd(%arg0: tensor<13x21x3xf32>) -> tensor<6x7x21x3xf32> { %0 = "tf.Const"() {device = "", value = dense<[[[0], [5], [3], [12], [2], [4], [3]], [[11], [1], [11], [10], [3], [12], [8]], [[5], [3], [1], [11], [3], [10], [0]], [[0], [8], [4], [7], [3], [12], [2]], [[7], [6], [11], [4], [2], [10], [11]], [[11], [1], [11], [1], [1], [11], [8]]]> : tensor<6x7x1xi32>} : () -> tensor<6x7x1xi32> %1 = "tf.GatherNd"(%arg0, %0) {device = ""} : (tensor<13x21x3xf32>, tensor<6x7x1xi32>) -> tensor<6x7x21x3xf32> @@ -992,16 +1010,16 @@ func.func @test_gather_nd(%arg0: tensor<13x21x3xf32>) -> tensor<6x7x21x3xf32> { // CHECK-LABEL: test_fused_batch_norm func.func @test_fused_batch_norm(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) { - // CHECK: %[[ONE:.+]] = "tosa.const"() {value = dense<1.000000e-03> : tensor<1xf32>} - // CHECK: %[[RES0:.+]] = "tosa.reshape"(%arg3) {new_shape = array} + // CHECK: %[[ONE:.+]] = "tosa.const"() <{value = dense<1.000000e-03> : tensor<1xf32>} + // CHECK: %[[RES0:.+]] = "tosa.reshape"(%arg3) <{new_shape = array} // CHECK: %[[SUB0:.+]] = "tosa.sub"(%arg0, %[[RES0]]) // CHECK: %[[ADD0:.+]] = "tosa.add"(%arg4, %[[ONE]]) // CHECK: %[[RSQR:.+]] = "tosa.rsqrt"(%[[ADD0]]) - // CHECK: %[[RES1:.+]] = "tosa.reshape"(%[[RSQR]]) {new_shape = array} - // CHECK: %[[MUL0:.+]] = "tosa.mul"(%[[SUB0]], %[[RES1]]) {shift = 0 : i32} - // CHECK: %[[RES1:.+]] = "tosa.reshape"(%arg1) {new_shape = array} - // CHECK: %[[MUL1:.+]] = "tosa.mul"(%[[MUL0]], %[[RES1]]) {shift = 0 : i32} - // CHECK: %[[RES2:.+]] = "tosa.reshape"(%arg2) {new_shape = array} + // CHECK: %[[RES1:.+]] = "tosa.reshape"(%[[RSQR]]) <{new_shape = array} + // CHECK: %[[MUL0:.+]] = "tosa.mul"(%[[SUB0]], %[[RES1]]) <{shift = 0 : i32} + // CHECK: %[[RES1:.+]] = "tosa.reshape"(%arg1) <{new_shape = array} + // CHECK: %[[MUL1:.+]] = "tosa.mul"(%[[MUL0]], %[[RES1]]) <{shift = 0 : i32} + // CHECK: %[[RES2:.+]] = "tosa.reshape"(%arg2) <{new_shape = array} // CHECK: %[[ADD1:.+]] = "tosa.add"(%[[MUL1]], %[[RES2]]) %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<*xf32>) @@ -1022,14 +1040,14 @@ func.func @test_fused_batch_norm_training(%arg0: tensor<8x8x8x8xf32>, %arg1: ten // CHECK-LABEL: mirrorpad_symmetric // CHECK-SAME: %[[VAL_0:.*]]: tensor<5x10xf32> -// CHECK: %[[VAL_1:.*]] = "tosa.slice"(%[[VAL_0]]) {size = array, start = array} : (tensor<5x10xf32>) -// CHECK: %[[VAL_2:.*]] = "tosa.slice"(%[[VAL_0]]) {size = array, start = array} : (tensor<5x10xf32>) -// CHECK: %[[VAL_3:.*]] = "tosa.reverse"(%[[VAL_2]]) {axis = 0 : i64} : 
(tensor<2x10xf32>) -// CHECK: %[[VAL_4:.*]] = "tosa.concat"(%[[VAL_1]], %[[VAL_0]], %[[VAL_3]]) {axis = 0 : i64} : (tensor<1x10xf32>, tensor<5x10xf32>, tensor<2x10xf32>) -// CHECK: %[[VAL_5:.*]] = "tosa.slice"(%[[VAL_4]]) {size = array, start = array} : (tensor<8x10xf32>) -// CHECK: %[[VAL_6:.*]] = "tosa.slice"(%[[VAL_4]]) {size = array, start = array} : (tensor<8x10xf32>) -// CHECK: %[[VAL_7:.*]] = "tosa.reverse"(%[[VAL_6]]) {axis = 1 : i64} : (tensor<8x2xf32>) -// CHECK: %[[VAL_8:.*]] = "tosa.concat"(%[[VAL_5]], %[[VAL_4]], %[[VAL_7]]) {axis = 1 : i64} : (tensor<8x1xf32>, tensor<8x10xf32>, tensor<8x2xf32>) +// CHECK: %[[VAL_1:.*]] = "tosa.slice"(%[[VAL_0]]) <{size = array, start = array}> : (tensor<5x10xf32>) +// CHECK: %[[VAL_2:.*]] = "tosa.slice"(%[[VAL_0]]) <{size = array, start = array}> : (tensor<5x10xf32>) +// CHECK: %[[VAL_3:.*]] = "tosa.reverse"(%[[VAL_2]]) <{axis = 0 : i64}> : (tensor<2x10xf32>) +// CHECK: %[[VAL_4:.*]] = "tosa.concat"(%[[VAL_1]], %[[VAL_0]], %[[VAL_3]]) <{axis = 0 : i64}> : (tensor<1x10xf32>, tensor<5x10xf32>, tensor<2x10xf32>) +// CHECK: %[[VAL_5:.*]] = "tosa.slice"(%[[VAL_4]]) <{size = array, start = array}> : (tensor<8x10xf32>) +// CHECK: %[[VAL_6:.*]] = "tosa.slice"(%[[VAL_4]]) <{size = array, start = array}> : (tensor<8x10xf32>) +// CHECK: %[[VAL_7:.*]] = "tosa.reverse"(%[[VAL_6]]) <{axis = 1 : i64}> : (tensor<8x2xf32>) +// CHECK: %[[VAL_8:.*]] = "tosa.concat"(%[[VAL_5]], %[[VAL_4]], %[[VAL_7]]) <{axis = 1 : i64}> : (tensor<8x1xf32>, tensor<8x10xf32>, tensor<8x2xf32>) func.func @mirrorpad_symmetric(%arg0: tensor<5x10xf32>) -> tensor<8x13xf32> { %cst = "tf.Const"() {device = "", value = dense<[[1, 2], [1, 2]]> : tensor<2x2xi32>} : () -> tensor<2x2xi32> %0 = "tf.MirrorPad"(%arg0, %cst) {device = "", mode = "SYMMETRIC"} : (tensor<5x10xf32>, tensor<2x2xi32>) -> tensor<8x13xf32> @@ -1041,12 +1059,12 @@ func.func @mirrorpad_symmetric(%arg0: tensor<5x10xf32>) -> tensor<8x13xf32> { // CHECK-LABEL: mirrorpad_reflect // CHECK-SAME: %[[VAL_0:.*]]: tensor<13x21x3xf32> -// CHECK: %[[VAL_1:.*]] = "tosa.slice"(%[[VAL_0]]) {size = array, start = array} : (tensor<13x21x3xf32>) -// CHECK: %[[VAL_2:.*]] = "tosa.concat"(%[[VAL_1]], %[[VAL_0]]) {axis = 0 : i64} : (tensor<1x21x3xf32>, tensor<13x21x3xf32>) -// CHECK: %[[VAL_3:.*]] = "tosa.slice"(%[[VAL_2]]) {size = array, start = array} : (tensor<14x21x3xf32>) -// CHECK: %[[VAL_4:.*]] = "tosa.concat"(%[[VAL_3]], %[[VAL_2]]) {axis = 1 : i64} : (tensor<14x1x3xf32>, tensor<14x21x3xf32>) -// CHECK: %[[VAL_5:.*]] = "tosa.slice"(%[[VAL_4]]) {size = array, start = array} : (tensor<14x22x3xf32>) -// CHECK: %[[VAL_6:.*]] = "tosa.concat"(%[[VAL_5]], %[[VAL_4]]) {axis = 2 : i64} : (tensor<14x22x1xf32>, tensor<14x22x3xf32>) +// CHECK: %[[VAL_1:.*]] = "tosa.slice"(%[[VAL_0]]) <{size = array, start = array}> : (tensor<13x21x3xf32>) +// CHECK: %[[VAL_2:.*]] = "tosa.concat"(%[[VAL_1]], %[[VAL_0]]) <{axis = 0 : i64}> : (tensor<1x21x3xf32>, tensor<13x21x3xf32>) +// CHECK: %[[VAL_3:.*]] = "tosa.slice"(%[[VAL_2]]) <{size = array, start = array}> : (tensor<14x21x3xf32>) +// CHECK: %[[VAL_4:.*]] = "tosa.concat"(%[[VAL_3]], %[[VAL_2]]) <{axis = 1 : i64}> : (tensor<14x1x3xf32>, tensor<14x21x3xf32>) +// CHECK: %[[VAL_5:.*]] = "tosa.slice"(%[[VAL_4]]) <{size = array, start = array}> : (tensor<14x22x3xf32>) +// CHECK: %[[VAL_6:.*]] = "tosa.concat"(%[[VAL_5]], %[[VAL_4]]) <{axis = 2 : i64}> : (tensor<14x22x1xf32>, tensor<14x22x3xf32>) func.func @mirrorpad_reflect(%arg0: tensor<13x21x3xf32>) -> tensor<14x22x4xf32> { %cst = "tf.Const"() {device = "", value = 
dense<[[1, 0], [1, 0], [1, 0]]> : tensor<3x2xi32>} : () -> tensor<3x2xi32> %0 = "tf.MirrorPad"(%arg0, %cst) {device = "", mode = "REFLECT"} : (tensor<13x21x3xf32>, tensor<3x2xi32>) -> tensor<14x22x4xf32> diff --git a/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline-filtered.mlir b/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline-filtered.mlir index 95c8f252767..469728b361f 100644 --- a/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline-filtered.mlir +++ b/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline-filtered.mlir @@ -26,7 +26,7 @@ func.func @test_softmax(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { // CHECK-LABEL: test_matmul // CHECK-DAG: %[[VAR0:.*]] = arith.constant dense<[1, 0]> : tensor<2xi32> -// CHECK-DAG: %[[VAR1:.*]] = "tosa.const"() {value = dense<0.000000e+00> : tensor<28xf32>} +// CHECK-DAG: %[[VAR1:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<28xf32>}> // CHECK: %[[VAR2:.*]] = "tosa.transpose"(%arg1, %[[VAR0]]) // CHECK: %[[VAR3:.*]] = "tosa.fully_connected"(%arg0, %[[VAR2]], %[[VAR1]]) func.func @test_matmul(%arg0: tensor<14x19xf32>, %arg1: tensor<19x28xf32>) -> tensor<*xf32> { diff --git a/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir b/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir index b16b7ffa83b..ddfc7eefe81 100644 --- a/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir +++ b/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir @@ -11,8 +11,8 @@ // ----- // CHECK-LABEL: test_conv2d -// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() {value = dense<0.000000e+00> : tensor<16xf32>} -// CHECK: %[[VAR1:.*]] = "tosa.conv2d"(%arg0, %arg1, %[[VAR0]]) {dilation = array, pad = array, stride = array} +// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<16xf32>}> +// CHECK: %[[VAR1:.*]] = "tosa.conv2d"(%arg0, %arg1, %[[VAR0]]) <{dilation = array, pad = array, stride = array}> func.func @test_conv2d(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x2x2x8xf32>) -> tensor<*xf32> { %cst = arith.constant dense<0.000000e+00> : tensor<16xf32> %0 = "tfl.conv_2d"(%arg0, %arg1, %cst) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x32x32x8xf32>, tensor<16x2x2x8xf32>, tensor<16xf32>) -> tensor<*xf32> @@ -33,7 +33,7 @@ func.func @test_conv2d_dynamic(%arg0: tensor, %arg1: tensor<16x1x // ----- // CHECK-LABEL: test_conv2d_bias -// CHECK: %[[VAR0:.*]] = "tosa.conv2d"(%arg0, %arg1, %arg2) {dilation = array, pad = array, stride = array} +// CHECK: %[[VAR0:.*]] = "tosa.conv2d"(%arg0, %arg1, %arg2) <{dilation = array, pad = array, stride = array}> // CHECK-SAME: tensor<1x32x32x16xf32> func.func @test_conv2d_bias(%arg0: tensor<1x32x32x8xf32>, %cst: tensor<16x2x2x8xf32>, %cst_0: tensor<16xf32>) -> tensor<*xf32> { %0 = "tfl.conv_2d"(%arg0, %cst, %cst_0) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x32x32x8xf32>, tensor<16x2x2x8xf32>, tensor<16xf32>) -> tensor<*xf32> @@ -43,8 +43,8 @@ func.func @test_conv2d_bias(%arg0: tensor<1x32x32x8xf32>, %cst: tensor<16x2x2x8x // ----- // CHECK-LABEL: test_transpose_conv2d -// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() {value = dense<0.000000e+00> : tensor<16xf32>} -// CHECK: %[[VAR1:.*]] = "tosa.transpose_conv2d"(%arg0, %arg1, %[[VAR0]]) {out_pad = array, out_shape = array, stride = array} +// CHECK-DAG: 
%[[VAR0:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<16xf32>}> +// CHECK: %[[VAR1:.*]] = "tosa.transpose_conv2d"(%arg0, %arg1, %[[VAR0]]) <{out_pad = array, out_shape = array, stride = array}> func.func @test_transpose_conv2d(%arg0: tensor<1x32x32x8xf32>, %cst_0: tensor<16x1x1x8xf32>) -> tensor<1x32x32x16xf32> { %cst = arith.constant dense<[1, 32, 32, 16]> : tensor<4xi32> %cst_1 = "tfl.no_value"() {value = unit} : () -> none @@ -55,9 +55,9 @@ func.func @test_transpose_conv2d(%arg0: tensor<1x32x32x8xf32>, %cst_0: tensor<16 // ----- // CHECK-LABEL: test_transpose_conv2d_relu -// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() {value = dense<0.000000e+00> : tensor<16xf32>} -// CHECK: %[[VAR1:.*]] = "tosa.transpose_conv2d"(%arg0, %arg1, %[[VAR0]]) {out_pad = array, out_shape = array, stride = array} -// CHECK: %[[VAR2:.*]] = "tosa.clamp"(%[[VAR1]]) {max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} +// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<16xf32>}> +// CHECK: %[[VAR1:.*]] = "tosa.transpose_conv2d"(%arg0, %arg1, %[[VAR0]]) <{out_pad = array, out_shape = array, stride = array}> +// CHECK: %[[VAR2:.*]] = "tosa.clamp"(%[[VAR1]]) <{max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64}> func.func @test_transpose_conv2d_relu(%arg0: tensor<1x32x32x8xf32>, %cst_0: tensor<16x1x1x8xf32>) -> tensor<1x32x32x16xf32> { %cst = arith.constant dense<[1, 32, 32, 16]> : tensor<4xi32> %cst_1 = "tfl.no_value"() {value = unit} : () -> none @@ -68,9 +68,9 @@ func.func @test_transpose_conv2d_relu(%arg0: tensor<1x32x32x8xf32>, %cst_0: tens // ----- // CHECK-LABEL: test_conv2d_qi8 -// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() {value = dense<{{.*}}> : tensor<16x2x2x8xi8>} -// CHECK-DAG: %[[VAR1:.*]] = "tosa.const"() {value = dense<0> : tensor<16xi32>} -// CHECK-DAG: %[[VAR2:.*]] = "tosa.conv2d"(%arg0, %[[VAR0]], %[[VAR1]]) {dilation = array, pad = array, quantization_info = #tosa.conv_quant, stride = array} +// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() <{value = dense<{{.*}}> : tensor<16x2x2x8xi8>}> +// CHECK-DAG: %[[VAR1:.*]] = "tosa.const"() <{value = dense<0> : tensor<16xi32>}> +// CHECK-DAG: %[[VAR2:.*]] = "tosa.conv2d"(%arg0, %[[VAR0]], %[[VAR1]]) <{dilation = array, pad = array, quantization_info = #tosa.conv_quant, stride = array}> // CHECK: %[[VAR3:.*]] = "tosa.rescale"(%[[VAR2]]) func.func @test_conv2d_qi8(%arg0: tensor<1x32x32x8x!quant.uniform>) -> tensor<1x32x32x16x!quant.uniform> { %0 = "tfl.pseudo_qconst"() {qtype = tensor<16x2x2x8x!quant.uniform:f32:0, {0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1}>>, value = dense<42> : tensor<16x2x2x8xi8>} : () -> tensor<16x2x2x8x!quant.uniform:f32:0, {0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1} >> @@ -82,9 +82,9 @@ func.func @test_conv2d_qi8(%arg0: tensor<1x32x32x8x!quant.uniform : tensor<16xi48>} -// CHECK-DAG: %[[VAR1:.*]] = "tosa.const"() {value = dense<{{.*}}> : tensor<16x1x1x8xi8>} -// CHECK-DAG: %[[VAR2:.*]] = "tosa.conv2d"(%arg0, %[[VAR1]], %[[VAR0]]) {dilation = array, pad = array, quantization_info = #tosa.conv_quant, stride = array} +// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() <{value = dense<0> : tensor<16xi48>}> +// CHECK-DAG: %[[VAR1:.*]] = "tosa.const"() <{value = dense<{{.*}}> : tensor<16x1x1x8xi8>}> +// CHECK-DAG: %[[VAR2:.*]] = "tosa.conv2d"(%arg0, %[[VAR1]], %[[VAR0]]) <{dilation = array, pad = array, quantization_info = #tosa.conv_quant, stride = array}> 
// CHECK: %[[VAR3:.*]] = "tosa.rescale"(%[[VAR2]]) func.func @test_conv2d_qi16(%arg0: tensor<1x32x32x8x!quant.uniform>) -> tensor<1x32x32x16x!quant.uniform> { %0 = "tfl.pseudo_qconst"() {qtype = tensor<16x1x1x8x!quant.uniform>, value = dense<42> : tensor<16x1x1x8xi8>} : () -> tensor<16x1x1x8x!quant.uniform> @@ -97,10 +97,9 @@ func.func @test_conv2d_qi16(%arg0: tensor<1x32x32x8x!quant.uniform // CHECK-LABEL: @test_depthwise_conv2d_bias_qi8 // CHECK-SAME: %[[ARG0:.*]]: tensor<1x32x32x8x!quant.uniform> -// CHECK-DAG: %[[CONST:.*]] = "tosa.const"() {value = dense<{{.*}}> : tensor<16xi32>} -// CHECK-DAG: %[[CONST_0:.*]] = "tosa.const"() {value = dense<{{.*}}> : tensor<1x2x2x16xi8>} -// CHECK-DAG: %[[RESHAPE:.*]] = "tosa.reshape"(%[[CONST_0]]) {new_shape = array} -// CHECK-DAG: %[[DEPTHWISE:.*]] = "tosa.depthwise_conv2d"(%[[ARG0]], %[[RESHAPE]], %[[CONST]]) {dilation = array, pad = array, quantization_info = #tosa.conv_quant, stride = array} +// CHECK-DAG: %[[CONST:.*]] = "tosa.const"() <{value = dense<{{.*}}> : tensor<16xi32>}> +// CHECK-DAG: %[[CONST_0:.*]] = "tosa.const"() <{value = dense<{{.*}}> : tensor<2x2x8x2xi8>}> +// CHECK-DAG: %[[DEPTHWISE:.*]] = "tosa.depthwise_conv2d"(%[[ARG0]], %[[CONST_0]], %[[CONST]]) <{dilation = array, pad = array, quantization_info = #tosa.conv_quant, stride = array}> // CHECK: %[[RESCALE:.*]] = "tosa.rescale"(%[[DEPTHWISE]]) // CHECK-SAME: multiplier = array // CHECK-SAME: shift = array @@ -114,6 +113,75 @@ func.func @test_depthwise_conv2d_bias_qi8(%arg0: tensor<1x32x32x8x!quant.uniform // ----- +// CHECK-LABEL: @test_conv2d_grouped_convolution +// CHECK-DAG: %[[INPUT_SLICE_1:.*]] = "tosa.slice"(%arg0) <{size = array, start = array}> +// CHECK-DAG: %[[FILTER_SLICE_1:.*]] = "tosa.slice"(%arg1) <{size = array, start = array}> +// CHECK-DAG: %[[BIAS_SLICE_1:.*]] = "tosa.slice"(%arg2) <{size = array, start = array}> +// CHECK-DAG: %[[CONV_1:.*]] = "tosa.conv2d"(%[[INPUT_SLICE_1]], %[[FILTER_SLICE_1]], %[[BIAS_SLICE_1]]) <{dilation = array, pad = array, stride = array}> +// CHECK-DAG: %[[INPUT_SLICE_2:.*]] = "tosa.slice"(%arg0) <{size = array, start = array}> +// CHECK-DAG: %[[FILTER_SLICE_2:.*]] = "tosa.slice"(%arg1) <{size = array, start = array}> +// CHECK-DAG: %[[BIAS_SLICE_2:.*]] = "tosa.slice"(%arg2) <{size = array, start = array}> +// CHECK-DAG: %[[CONV_2:.*]] = "tosa.conv2d"(%[[INPUT_SLICE_2]], %[[FILTER_SLICE_2]], %[[BIAS_SLICE_2]]) <{dilation = array, pad = array, stride = array}> +// CHECK-DAG: %[[CONCAT:.*]] = "tosa.concat"(%[[CONV_1]], %[[CONV_2]]) <{axis = 3 : i64}> +// CHECK: return %[[CONCAT]] +func.func @test_conv2d_grouped_convolution(%input: tensor<1x4x1x128xf32>, %weights: tensor<128x1x1x64xf32>, %bias: tensor<128xf32>) -> tensor<1x4x1x128xf32> { + %0 = "tfl.conv_2d"(%input, %weights, %bias) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x4x1x128xf32>, tensor<128x1x1x64xf32>, tensor<128xf32>) -> (tensor<1x4x1x128xf32>) + return %0 : tensor<1x4x1x128xf32> +} + +// ----- + +// CHECK-LABEL: @test_conv2d_grouped_strided_convolution +// CHECK-DAG: %[[INPUT_SLICE_1:.*]] = "tosa.slice"(%arg0) <{size = array, start = array}> +// CHECK-DAG: %[[FILTER_SLICE_1:.*]] = "tosa.slice"(%arg1) <{size = array, start = array}> +// CHECK-DAG: %[[BIAS_SLICE_1:.*]] = "tosa.slice"(%arg2) <{size = array, start = array}> +// CHECK-DAG: %[[CONV_1:.*]] = "tosa.conv2d"(%[[INPUT_SLICE_1]], %[[FILTER_SLICE_1]], %[[BIAS_SLICE_1]]) <{dilation = array, pad = 
array, stride = array}> +// CHECK-DAG: %[[INPUT_SLICE_2:.*]] = "tosa.slice"(%arg0) <{size = array, start = array}> +// CHECK-DAG: %[[FILTER_SLICE_2:.*]] = "tosa.slice"(%arg1) <{size = array, start = array}> +// CHECK-DAG: %[[BIAS_SLICE_2:.*]] = "tosa.slice"(%arg2) <{size = array, start = array}> +// CHECK-DAG: %[[CONV_2:.*]] = "tosa.conv2d"(%[[INPUT_SLICE_2]], %[[FILTER_SLICE_2]], %[[BIAS_SLICE_2]]) <{dilation = array, pad = array, stride = array}> +// CHECK-DAG: %[[INPUT_SLICE_3:.*]] = "tosa.slice"(%arg0) <{size = array, start = array}> +// CHECK-DAG: %[[FILTER_SLICE_3:.*]] = "tosa.slice"(%arg1) <{size = array, start = array}> +// CHECK-DAG: %[[BIAS_SLICE_3:.*]] = "tosa.slice"(%arg2) <{size = array, start = array}> +// CHECK-DAG: %[[CONV_3:.*]] = "tosa.conv2d"(%[[INPUT_SLICE_3]], %[[FILTER_SLICE_3]], %[[BIAS_SLICE_3]]) <{dilation = array, pad = array, stride = array}> +// CHECK-DAG: %[[INPUT_SLICE_4:.*]] = "tosa.slice"(%arg0) <{size = array, start = array}> +// CHECK-DAG: %[[FILTER_SLICE_4:.*]] = "tosa.slice"(%arg1) <{size = array, start = array}> +// CHECK-DAG: %[[BIAS_SLICE_4:.*]] = "tosa.slice"(%arg2) <{size = array, start = array}> +// CHECK-DAG: %[[CONV_4:.*]] = "tosa.conv2d"(%[[INPUT_SLICE_4]], %[[FILTER_SLICE_4]], %[[BIAS_SLICE_4]]) <{dilation = array, pad = array, stride = array}> +// CHECK-DAG: %[[CONCAT:.*]] = "tosa.concat"(%[[CONV_1]], %[[CONV_2]], %[[CONV_3]], %[[CONV_4]]) <{axis = 3 : i64}> +// CHECK: return %[[CONCAT]] +func.func @test_conv2d_grouped_strided_convolution(%input: tensor<1x3x1x64xf32>, %weights: tensor<512x3x1x16xf32>, %bias: tensor<512xf32>) -> tensor<1x2x1x512xf32> { + %0 = "tfl.conv_2d"(%input, %weights, %bias) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 2 : i32, stride_w = 1 : i32} : (tensor<1x3x1x64xf32>, tensor<512x3x1x16xf32>, tensor<512xf32>) -> (tensor<1x2x1x512xf32>) + return %0 : tensor<1x2x1x512xf32> +} + +// ----- + +// CHECK-LABEL: @test_conv2d_q_grouped_convolution +// CHECK-DAG: %[[BIAS:.*]] = "tosa.const"() <{value = dense<0> : tensor<16xi32>}> +// CHECK-DAG: %[[FILTER:.*]] = "tosa.const"() <{value = dense<42> : tensor<16x1x1x8xi8>}> +// CHECK-DAG: %[[INPUT_SLICE_1:.*]] = "tosa.slice"(%arg0) <{size = array, start = array}> +// CHECK-DAG: %[[FILTER_SLICE_1:.*]] = "tosa.slice"(%[[FILTER]]) <{size = array, start = array}> +// CHECK-DAG: %[[BIAS_SLICE_1:.*]] = "tosa.slice"(%[[BIAS]]) <{size = array, start = array}> +// CHECK-DAG: %[[CONV_1:.*]] = "tosa.conv2d"(%[[INPUT_SLICE_1]], %[[FILTER_SLICE_1]], %[[BIAS_SLICE_1]]) <{dilation = array, pad = array, quantization_info = #tosa.conv_quant, stride = array}> +// CHECK-DAG: %[[RESCALE_1:.*]] = "tosa.rescale"(%[[CONV_1]]) <{double_round = true, input_zp = 0 : i32, multiplier = array, output_zp = 0 : i32, per_channel = true, scale32 = true, shift = array}> +// CHECK-DAG: %[[INPUT_SLICE_2:.*]] = "tosa.slice"(%arg0) <{size = array, start = array}> +// CHECK-DAG: %[[FILTER_SLICE_2:.*]] = "tosa.slice"(%[[FILTER]]) <{size = array, start = array}> +// CHECK-DAG: %[[BIAS_SLICE_2:.*]] = "tosa.slice"(%[[BIAS]]) <{size = array, start = array}> +// CHECK-DAG: %[[CONV_2:.*]] = "tosa.conv2d"(%[[INPUT_SLICE_2]], %[[FILTER_SLICE_2]], %[[BIAS_SLICE_2]]) <{dilation = array, pad = array, quantization_info = #tosa.conv_quant, stride = array}> +// CHECK-DAG: %[[RESCALE_2:.*]] = "tosa.rescale"(%[[CONV_2]]) <{double_round = true, input_zp = 0 : i32, multiplier = array, output_zp = 0 : i32, per_channel = true, scale32 = true, shift = array}> +// 
CHECK-DAG: %[[CONCAT:.*]] = "tosa.concat"(%[[RESCALE_1]], %[[RESCALE_2]]) <{axis = 3 : i64}> +// CHECK: return %[[CONCAT]] + +func.func @test_conv2d_q_grouped_convolution(%input: tensor<1x4x1x16x!quant.uniform>) -> tensor<1x4x1x16x!quant.uniform> { + %0 = "tfl.pseudo_qconst"() {qtype = tensor<16x1x1x8x!quant.uniform:f32:0, {0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1}>>, value = dense<42> : tensor<16x1x1x8xi8>} : () -> tensor<16x1x1x8x!quant.uniform:f32:0, {0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1} >> + %1 = "tfl.pseudo_qconst"() {qtype = tensor<16x!quant.uniform>, value = dense<0> : tensor<16xi32>} : () -> tensor<16x!quant.uniform> + %2 = "tfl.conv_2d"(%input, %0, %1) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x4x1x16x!quant.uniform>, tensor<16x1x1x8x!quant.uniform:f32:0, {0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1}>>, tensor<16x!quant.uniform>) -> tensor<1x4x1x16x!quant.uniform> + return %2 : tensor<1x4x1x16x!quant.uniform> +} + +// ----- + // CHECK-LABEL: test_depthwise_conv2d_bias_inferred func.func @test_depthwise_conv2d_bias_inferred(%arg0: tensor, %arg1 : tensor<1x1x1x16xf32>, %arg2 : tensor<16xf32>) -> tensor { // CHECK: tosa.depthwise_conv2d @@ -127,10 +195,10 @@ func.func @test_depthwise_conv2d_bias_inferred(%arg0: tensor, %ar // CHECK-LABEL: test_conv3d // CHECK-SAME: %[[VAL_0:.*]]: tensor<2x2x7x7x2xf32> // CHECK-SAME: %[[VAL_1:.*]]: tensor<2x3x3x2x4xf32> -// CHECK-DAG: %[[VAL_2:.*]] = "tosa.const"() {value = dense<0.000000e+00> : tensor<4xf32>} -// CHECK-DAG: %[[VAL_3:.*]] = "tosa.const"() {value = dense<[4, 0, 1, 2, 3]> : tensor<5xi32>} +// CHECK-DAG: %[[VAL_2:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<4xf32>}> +// CHECK-DAG: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<[4, 0, 1, 2, 3]> : tensor<5xi32>}> // CHECK: %[[VAL_4:.*]] = "tosa.transpose"(%[[VAL_1]], %[[VAL_3]]) -// CHECK: %[[VAL_5:.*]] = "tosa.conv3d"(%[[VAL_0]], %[[VAL_4]], %[[VAL_2]]) {dilation = array, pad = array, stride = array} +// CHECK: %[[VAL_5:.*]] = "tosa.conv3d"(%[[VAL_0]], %[[VAL_4]], %[[VAL_2]]) <{dilation = array, pad = array, stride = array}> func.func @test_conv3d(%arg0: tensor<2x2x7x7x2xf32>, %arg1: tensor<2x3x3x2x4xf32>) -> tensor<2x2x7x7x4xf32> { %cst = "tfl.no_value"() {value} : () -> none %0 = "tfl.conv_3d"(%arg0, %arg1, %cst) {dilation_d_factor = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_d = 1 : i32, stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<2x2x7x7x2xf32>, tensor<2x3x3x2x4xf32>, none) -> tensor<2x2x7x7x4xf32> @@ -142,10 +210,10 @@ func.func @test_conv3d(%arg0: tensor<2x2x7x7x2xf32>, %arg1: tensor<2x3x3x2x4xf32 // CHECK-LABEL: test_conv3d_dynamic // CHECK-SAME: %[[VAL_0:.*]]: tensor // CHECK-SAME: %[[VAL_1:.*]]: tensor<3x1x1x8x16xf32> -// CHECK-DAG: %[[VAL_2:.*]] = "tosa.const"() {value = dense<0.000000e+00> : tensor<16xf32>} -// CHECK-DAG: %[[VAL_3:.*]] = "tosa.const"() {value = dense<[4, 0, 1, 2, 3]> : tensor<5xi32>} +// CHECK-DAG: %[[VAL_2:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<16xf32>}> +// CHECK-DAG: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<[4, 0, 1, 2, 3]> : tensor<5xi32>}> // CHECK: %[[VAL_4:.*]] = "tosa.transpose"(%[[VAL_1]], %[[VAL_3]]) -// CHECK: %[[VAL_5:.*]] = "tosa.conv3d"(%[[VAL_0]], %[[VAL_4]], %[[VAL_2]]) {dilation = array, pad = array, stride = array} +// 
CHECK: %[[VAL_5:.*]] = "tosa.conv3d"(%[[VAL_0]], %[[VAL_4]], %[[VAL_2]]) <{dilation = array, pad = array, stride = array}> func.func @test_conv3d_dynamic(%arg0: tensor, %arg1: tensor<3x1x1x8x16xf32>) -> tensor<*xf32> { %cst = "tfl.no_value"() {value} : () -> none %0 = "tfl.conv_3d"(%arg0, %arg1, %cst) {dilation_d_factor = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_d = 1 : i32, stride_h = 1 : i32, stride_w = 1 : i32} : (tensor, tensor<3x1x1x8x16xf32>, none) -> tensor<*xf32> @@ -158,9 +226,9 @@ func.func @test_conv3d_dynamic(%arg0: tensor, %arg1: tensor<3x // CHECK-SAME: %[[VAL_0:.*]]: tensor<10x3x64x64x12xf32> // CHECK-SAME: %[[VAL_1:.*]]: tensor<16x2x2x12x8xf32> // CHECK-SAME: %[[VAL_2:.*]]: tensor<8xf32> -// CHECK-DAG: %[[VAL_3:.*]] = "tosa.const"() {value = dense<[4, 0, 1, 2, 3]> +// CHECK-DAG: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<[4, 0, 1, 2, 3]> // CHECK: %[[VAL_4:.*]] = "tosa.transpose"(%[[VAL_1]], %[[VAL_3]]) -// CHECK: %[[VAL_5:.*]] = "tosa.conv3d"(%[[VAL_0]], %[[VAL_4]], %[[VAL_2]]) {dilation = array, pad = array, stride = array} +// CHECK: %[[VAL_5:.*]] = "tosa.conv3d"(%[[VAL_0]], %[[VAL_4]], %[[VAL_2]]) <{dilation = array, pad = array, stride = array}> func.func @test_conv3d_bias(%arg0: tensor<10x3x64x64x12xf32>, %arg1: tensor<16x2x2x12x8xf32>, %cst: tensor<8xf32>) -> tensor<10x3x64x64x8xf32> { %0 = "tfl.conv_3d"(%arg0, %arg1, %cst) {dilation_d_factor = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_d = 1 : i32, stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<10x3x64x64x12xf32>, tensor<16x2x2x12x8xf32>, tensor<8xf32>) -> tensor<10x3x64x64x8xf32> func.return %0 : tensor<10x3x64x64x8xf32> @@ -171,20 +239,18 @@ func.func @test_conv3d_bias(%arg0: tensor<10x3x64x64x12xf32>, %arg1: tensor<16x2 // CHECK-LABEL: test_conv3d_qi8( // CHECK-SAME: %[[VAL_0:.*]]: tensor<1x4x8x21x17x!quant.uniform> // CHECK-SAME: %[[VAL_1:.*]]: tensor<2x3x3x17x34xf32>) -> tensor<1x4x8x11x34x!quant.uniform> -// CHECK-DAG: %[[VAL_2:.*]] = "tosa.const"() {value = dense<0.000000e+00> : tensor<1x1x1x1x1xf32>} -// CHECK-DAG: %[[VAL_3:.*]] = "tosa.const"() {value = dense<0.0156862643> : tensor<1x1x1x1x1xf32>} -// CHECK-DAG: %[[VAL_4:.*]] = "tosa.const"() {value = dense<1.11982894> : tensor<1x1x1x1x1xf32>} -// CHECK-DAG: %[[VAL_5:.*]] = "tosa.const"() {value = dense<-4.000000e+00> : tensor<1x1x1x1x1xf32>} -// CHECK-DAG: %[[VAL_6:.*]] = "tosa.const"() {value = dense<0.000000e+00> : tensor<34xf32>} -// CHECK-DAG: %[[VAL_7:.*]] = "tosa.const"() {value = dense<[4, 0, 1, 2, 3]> : tensor<5xi32>} -// CHECK: %[[VAL_8:.*]] = "tosa.cast"(%[[VAL_0]]) -// CHECK: %[[VAL_9:.*]] = "tosa.sub"(%[[VAL_8]], %[[VAL_2]]) -// CHECK: %[[VAL_10:.*]] = "tosa.mul"(%[[VAL_9]], %[[VAL_3]]) {shift = 0 : i32} -// CHECK: %[[VAL_11:.*]] = "tosa.transpose"(%[[VAL_1]], %[[VAL_7]]) -// CHECK: %[[VAL_12:.*]] = "tosa.conv3d"(%[[VAL_10]], %[[VAL_11]], %[[VAL_6]]) {dilation = array, pad = array, stride = array} -// CHECK: %[[VAL_13:.*]] = "tosa.mul"(%[[VAL_12]], %[[VAL_4]]) {shift = 0 : i32} -// CHECK: %[[VAL_14:.*]] = "tosa.add"(%[[VAL_13]], %[[VAL_5]]) -// CHECK: %[[VAL_15:.*]] = "tosa.cast"(%[[VAL_14]]) +// CHECK-DAG: %[[VAL_2:.*]] = "tosa.const"() <{value = dense<0.0156862643> : tensor<1x1x1x1x1xf32>}> +// CHECK-DAG: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<1.11982894> : tensor<1x1x1x1x1xf32>}> +// CHECK-DAG: %[[VAL_4:.*]] = "tosa.const"() <{value = dense<-4.000000e+00> : 
tensor<1x1x1x1x1xf32>}> +// CHECK-DAG: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<34xf32>}> +// CHECK-DAG: %[[VAL_6:.*]] = "tosa.const"() <{value = dense<[4, 0, 1, 2, 3]> : tensor<5xi32>}> +// CHECK: %[[VAL_7:.*]] = "tosa.cast"(%[[VAL_0]]) +// CHECK: %[[VAL_8:.*]] = "tosa.mul"(%[[VAL_7]], %[[VAL_2]]) <{shift = 0 : i32}> +// CHECK: %[[VAL_9:.*]] = "tosa.transpose"(%[[VAL_1]], %[[VAL_6]]) +// CHECK: %[[VAL_10:.*]] = "tosa.conv3d"(%[[VAL_8]], %[[VAL_9]], %[[VAL_5]]) <{dilation = array, pad = array, stride = array}> +// CHECK: %[[VAL_11:.*]] = "tosa.mul"(%[[VAL_10]], %[[VAL_3]]) <{shift = 0 : i32}> +// CHECK: %[[VAL_12:.*]] = "tosa.add"(%[[VAL_11]], %[[VAL_4]]) +// CHECK: %[[VAL_13:.*]] = "tosa.cast"(%[[VAL_12]]) func.func @test_conv3d_qi8(%arg0: tensor<1x4x8x21x17x!quant.uniform>, %arg1: tensor<2x3x3x17x34xf32>) -> (tensor<1x4x8x11x34x!quant.uniform>) { %0 = "tfl.dequantize"(%arg0) : (tensor<1x4x8x21x17x!quant.uniform>) -> tensor<1x4x8x21x17xf32> %2 = "tfl.no_value"() {value} : () -> none @@ -232,7 +298,7 @@ func.func @test_sub_unranked(%arg0: tensor<1x21x3xf32>, %arg1: tensor<1x1x1xf32> // ----- // CHECK-LABEL: test_mul -// CHECK: %[[VAR0:.*]] = "tosa.mul"(%arg0, %arg1) {shift = 0 : i32} +// CHECK: %[[VAR0:.*]] = "tosa.mul"(%arg0, %arg1) <{shift = 0 : i32}> func.func @test_mul(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1x3xf32>) -> tensor<13x21x3xf32> { %0 = "tfl.mul"(%arg0, %arg1) {fused_activation_function = "NONE"} : (tensor<13x21x3xf32>, tensor<13x1x3xf32>) -> tensor<13x21x3xf32> func.return %0 : tensor<13x21x3xf32> @@ -241,7 +307,7 @@ func.func @test_mul(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1x3xf32>) -> te // ----- // CHECK-LABEL: test_mul_unranked -// CHECK: %[[VAR0:.*]] = "tosa.mul"(%arg0, %arg1) {shift = 0 : i32} +// CHECK: %[[VAR0:.*]] = "tosa.mul"(%arg0, %arg1) <{shift = 0 : i32}> func.func @test_mul_unranked(%arg0: tensor<13x21x3xf32>, %arg1: tensor<1x1x1xf32>) -> tensor<*xf32> { %0 = "tfl.mul"(%arg0, %arg1) {fused_activation_function = "NONE"} : (tensor<13x21x3xf32>, tensor<1x1x1xf32>) -> tensor<*xf32> func.return %0 : tensor<*xf32> @@ -289,7 +355,7 @@ func.func @test_floor_div(%arg0: tensor<13x21x3xi32>, %arg1: tensor) -> ten // ----- // CHECK-LABEL: test_relu1 -// CHECK: %[[VAL0:.*]] = "tosa.clamp"(%arg0) {max_fp = 1.000000e+00 : f32, max_int = 1 : i64, min_fp = -1.000000e+00 : f32, min_int = -1 : i64} +// CHECK: %[[VAL0:.*]] = "tosa.clamp"(%arg0) <{max_fp = 1.000000e+00 : f32, max_int = 1 : i64, min_fp = -1.000000e+00 : f32, min_int = -1 : i64}> func.func @test_relu1(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { %0 = "tfl.relu_n1_to_1"(%arg0) : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32> func.return %0 : tensor<13x21x3xf32> @@ -298,7 +364,7 @@ func.func @test_relu1(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { // ----- // CHECK-LABEL: test_relu0To1 -// CHECK: %[[VAL0:.*]] = "tosa.clamp"(%arg0) {max_fp = 1.000000e+00 : f32, max_int = 1 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} +// CHECK: %[[VAL0:.*]] = "tosa.clamp"(%arg0) <{max_fp = 1.000000e+00 : f32, max_int = 1 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64}> func.func @test_relu0To1(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { %0 = "tfl.relu_0_to_1"(%arg0) : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32> func.return %0 : tensor<13x21x3xf32> @@ -307,7 +373,7 @@ func.func @test_relu0To1(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { // ----- // CHECK-LABEL: test_relu6 -// CHECK: %[[VAR0:.*]] = "tosa.clamp"(%arg0) {max_fp = 6.000000e+00 : f32, 
max_int = 6 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} +// CHECK: %[[VAR0:.*]] = "tosa.clamp"(%arg0) <{max_fp = 6.000000e+00 : f32, max_int = 6 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64}> func.func @test_relu6(%arg0: tensor<13x21x3xf32>) -> tensor<*xf32> { %0 = "tfl.relu6"(%arg0) : (tensor<13x21x3xf32>) -> tensor<*xf32> func.return %0 : tensor<*xf32> @@ -316,7 +382,7 @@ func.func @test_relu6(%arg0: tensor<13x21x3xf32>) -> tensor<*xf32> { // ----- // CHECK-LABEL: test_relu6_dynamic -// CHECK: %[[VAR0:.*]] = "tosa.clamp"(%arg0) {max_fp = 6.000000e+00 : f32, max_int = 6 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} +// CHECK: %[[VAR0:.*]] = "tosa.clamp"(%arg0) <{max_fp = 6.000000e+00 : f32, max_int = 6 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64}> // CHECK-SAME: -> tensor func.func @test_relu6_dynamic(%arg0: tensor) -> tensor { %0 = "tfl.relu6"(%arg0) : (tensor) -> tensor @@ -326,8 +392,8 @@ func.func @test_relu6_dynamic(%arg0: tensor) -> tensor { // ----- // CHECK-LABEL: test_leaky_relu -// CHECK: %[[VAR0:.*]] = "tosa.const"() {value = dense<0.707330704> : tensor<1x1x1xf32>} -// CHECK: %[[VAR1:.*]] = "tosa.mul"(%arg0, %[[VAR0]]) {shift = 0 : i32} +// CHECK: %[[VAR0:.*]] = "tosa.const"() <{value = dense<0.707330704> : tensor<1x1x1xf32>}> +// CHECK: %[[VAR1:.*]] = "tosa.mul"(%arg0, %[[VAR0]]) <{shift = 0 : i32}> // CHECK: %[[VAR2:.*]] = "tosa.maximum"(%[[VAR1]], %arg0) : (tensor<13x21x3xf32>, tensor<13x21x3xf32>) -> tensor<13x21x3xf32> // CHECK: return %[[VAR2]] : tensor<13x21x3xf32> func.func @test_leaky_relu(%arg0: tensor<13x21x3xf32>) -> tensor<*xf32> { @@ -338,9 +404,9 @@ func.func @test_leaky_relu(%arg0: tensor<13x21x3xf32>) -> tensor<*xf32> { // ----- // CHECK-LABEL: test_prelu -// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() {value = dense<0.000000e+00> : tensor<1x1x1xf32>} -// CHECK-DAG: %[[VAR1:.*]] = "tosa.reshape"(%arg1) {new_shape = array} -// CHECK-DAG: %[[VAR2:.*]] = "tosa.mul"(%arg0, %[[VAR1]]) {shift = 0 : i32} +// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<1x1x1xf32>}> +// CHECK-DAG: %[[VAR1:.*]] = "tosa.reshape"(%arg1) <{new_shape = array}> +// CHECK-DAG: %[[VAR2:.*]] = "tosa.mul"(%arg0, %[[VAR1]]) <{shift = 0 : i32}> // CHECK-DAG: %[[VAR3:.*]] = "tosa.greater_equal"(%arg0, %[[VAR0]]) // CHECK: %[[VAR4:.*]] = "tosa.select"(%[[VAR3]], %arg0, %[[VAR2]]) func.func @test_prelu(%arg0: tensor<4x2x3xf32>, %arg1: tensor<2x3xf32>) -> tensor<4x2x3xf32> { @@ -352,20 +418,20 @@ func.func @test_prelu(%arg0: tensor<4x2x3xf32>, %arg1: tensor<2x3xf32>) -> tenso // CHECK-LABEL: test_prelu_qu8 // CHECK-SAME: %[[VAL_0:.*]]: tensor<1x8x4x17x!quant.uniform> -// CHECK-DAG: %[[VAL_1:.*]] = "tosa.const"() {value = dense<0> : tensor<1x1x1x1xi32>} -// CHECK-DAG: %[[VAL_2:.*]] = "tosa.const"() {value = dense<{{.*}}> : tensor<8x4x17xi8>} -// CHECK: %[[VAL_3:.*]] = "tosa.rescale"(%[[VAL_0]]) {double_round = false, input_zp = 128 : i32, multiplier = array, output_zp = 0 : i32, per_channel = false, scale32 = true, shift = array} -// CHECK: %[[VAL_4:.*]] = "tosa.rescale"(%[[VAL_3]]) {double_round = true, input_zp = 0 : i32, multiplier = array, output_zp = 0 : i32, per_channel = false, scale32 = true, shift = array} -// CHECK: %[[VAL_5:.*]] = "tosa.rescale"(%[[VAL_4]]) {double_round = true, input_zp = 0 : i32, multiplier = array, output_zp = 0 : i32, per_channel = false, scale32 = true, shift = array} +// CHECK-DAG: %[[VAL_1:.*]] = "tosa.const"() <{value = dense<0> : tensor<1x1x1x1xi32>}> +// CHECK-DAG: %[[VAL_2:.*]] = "tosa.const"() 
<{value = dense<{{.*}}> : tensor<8x4x17xi8>}> +// CHECK: %[[VAL_3:.*]] = "tosa.rescale"(%[[VAL_0]]) <{double_round = false, input_zp = 128 : i32, multiplier = array, output_zp = 0 : i32, per_channel = false, scale32 = true, shift = array}> +// CHECK: %[[VAL_4:.*]] = "tosa.rescale"(%[[VAL_3]]) <{double_round = true, input_zp = 0 : i32, multiplier = array, output_zp = 0 : i32, per_channel = false, scale32 = true, shift = array}> +// CHECK: %[[VAL_5:.*]] = "tosa.rescale"(%[[VAL_4]]) <{double_round = true, input_zp = 0 : i32, multiplier = array, output_zp = 0 : i32, per_channel = false, scale32 = true, shift = array}> // CHECK: %[[VAL_6:.*]] = "tosa.greater_equal"(%[[VAL_5]], %[[VAL_1]]) : (tensor<1x8x4x17xi32>, tensor<1x1x1x1xi32>) -// CHECK: %[[VAL_7:.*]] = "tosa.rescale"(%[[VAL_2]]) {double_round = true, input_zp = 0 : i32, multiplier = array, output_zp = 0 : i32, per_channel = false, scale32 = true, shift = array} -// CHECK: %[[VAL_8:.*]] = "tosa.reshape"(%[[VAL_7]]) {new_shape = array} : (tensor<8x4x17xi32>) -// CHECK: %[[VAL_9:.*]] = "tosa.mul"(%[[VAL_5]], %[[VAL_8]]) {shift = 0 : i32} : (tensor<1x8x4x17xi32>, tensor<1x8x4x17xi32>) -// CHECK: %[[VAL_10:.*]] = "tosa.rescale"(%[[VAL_9]]) {double_round = true, input_zp = 0 : i32, multiplier = array, output_zp = 5 : i32, per_channel = false, scale32 = true, shift = array} -// CHECK: %[[VAL_11:.*]] = "tosa.rescale"(%[[VAL_4]]) {double_round = true, input_zp = 0 : i32, multiplier = array, output_zp = 5 : i32, per_channel = false, scale32 = true, shift = array} +// CHECK: %[[VAL_7:.*]] = "tosa.rescale"(%[[VAL_2]]) <{double_round = true, input_zp = 0 : i32, multiplier = array, output_zp = 0 : i32, per_channel = false, scale32 = true, shift = array}> +// CHECK: %[[VAL_8:.*]] = "tosa.reshape"(%[[VAL_7]]) <{new_shape = array}> : (tensor<8x4x17xi32>) +// CHECK: %[[VAL_9:.*]] = "tosa.mul"(%[[VAL_5]], %[[VAL_8]]) <{shift = 0 : i32}> : (tensor<1x8x4x17xi32>, tensor<1x8x4x17xi32>) +// CHECK: %[[VAL_10:.*]] = "tosa.rescale"(%[[VAL_9]]) <{double_round = true, input_zp = 0 : i32, multiplier = array, output_zp = 5 : i32, per_channel = false, scale32 = true, shift = array}> +// CHECK: %[[VAL_11:.*]] = "tosa.rescale"(%[[VAL_4]]) <{double_round = true, input_zp = 0 : i32, multiplier = array, output_zp = 5 : i32, per_channel = false, scale32 = true, shift = array}> // CHECK: %[[VAL_12:.*]] = "tosa.select"(%[[VAL_6]], %[[VAL_11]], %[[VAL_10]]) -// CHECK: %[[VAL_13:.*]] = "tosa.rescale"(%[[VAL_12]]) {double_round = true, input_zp = 5 : i32, multiplier = array, output_zp = 5 : i32, per_channel = false, scale32 = true, shift = array} -// CHECK: %[[VAL_14:.*]] = "tosa.rescale"(%[[VAL_13]]) {double_round = false, input_zp = 5 : i32, multiplier = array, output_zp = 133 : i32, per_channel = false, scale32 = true, shift = array} +// CHECK: %[[VAL_13:.*]] = "tosa.rescale"(%[[VAL_12]]) <{double_round = true, input_zp = 5 : i32, multiplier = array, output_zp = 5 : i32, per_channel = false, scale32 = true, shift = array}> +// CHECK: %[[VAL_14:.*]] = "tosa.rescale"(%[[VAL_13]]) <{double_round = false, input_zp = 5 : i32, multiplier = array, output_zp = 133 : i32, per_channel = false, scale32 = true, shift = array}> func.func @test_prelu_qu8(%arg0: tensor<1x8x4x17x!quant.uniform>) -> tensor<1x8x4x17x!quant.uniform> { %0 = "tfl.quantize"(%arg0) {qtype = tensor<1x8x4x17x!quant.uniform>} : (tensor<1x8x4x17x!quant.uniform>) -> tensor<1x8x4x17x!quant.uniform> %1 = "tfl.pseudo_qconst"() {qtype = tensor<8x4x17x!quant.uniform:f32, 0.023982547223567963>>, value = 
dense<"0x191D0557FF212FA1137FDE2B247CE8BA2A8B2213F6B109FA12232EC613FEEE03EF2D265BE5E4F6CB0E09F7F0A95606DA1709EDE632D0F92A2002E98E61F9213997D3FCEBFA0D2DFC4DD00D0700C60C0705F3CFCB01D30C3617C7144C294DAE27061A62E70665021AF50827F40EC9E0172D42B9FB01FB076A09553006F7F710211A031EC9F11BCF130FCC1906D5FED8E5F64E06EAEAFEFD2515F20BB6E3401023C89DFCF8DEC0390B37D8CA2001E1F7BC270ADDE92DFC6D230CE1FEEE1DE8F90ABF9E3ECAEEBC311DF6FDE41F0E31ED0AC309B3121533E7EC2D1B0F1E04D44513E627F4ED5E491D10E53EEA45FF23E31D11D1DE2E0A3B1015AF06102329DEED5C1C180402000B0D071BF0D4FBC0DE0C3BF012E018D80716351D1922F8D508CF2708BA0CEAFE14E4972732FDFD283ED9342A1506F4F137200A12F436D6C9EC071FBCBDEBF4F8051426B8201EC410F9C3C7EFF7CD04D7AC34E2F9D73A5A05CFFA0FF7FD21D6BBEA03F16AF8330C1105285605C9FFE72BE04726DA06F2DCDCDC14C1310CF4E32F06BE0941420B10C9293DD10EFE28D4D20716E6E6EE0A101FFE3AAF1716120EF62FECEBC0F0D72A0903F9E74425EDF82E290E0413BB69F3F45AF30A22D4D024411B4D243BE13FB9CBE0F5FA16A1D7532007AEF62837C42406E3ED3CCE0408CA1C0CFA18B40C0BF7261E06D3E504B8E714BCF6F010DB12373739E200E609E9DAEF1922A2C338FEF2C519F0E5101E2AE917DCA3FA27D245DD10F0EBCE"> : tensor<8x4x17xi8>} : () -> tensor<8x4x17x!quant.uniform:f32, 0.023982547223567963>> @@ -378,15 +444,15 @@ func.func @test_prelu_qu8(%arg0: tensor<1x8x4x17x!quant.uniform> -// CHECK-DAG: %[[VAL_1:.*]] = "tosa.const"() {value = dense<0> : tensor<1x1x1x1xi32>} -// CHECK-DAG: %[[VAL_2:.*]] = "tosa.const"() {value = dense<{{.*}}> : tensor<8x4x17xi8>} -// CHECK: %[[VAL_3:.*]] = "tosa.rescale"(%[[VAL_0]]) {double_round = true, input_zp = 0 : i32, multiplier = array, output_zp = 0 : i32, per_channel = false, scale32 = true, shift = array} +// CHECK-DAG: %[[VAL_1:.*]] = "tosa.const"() <{value = dense<0> : tensor<1x1x1x1xi32>}> +// CHECK-DAG: %[[VAL_2:.*]] = "tosa.const"() <{value = dense<{{.*}}> : tensor<8x4x17xi8>}> +// CHECK: %[[VAL_3:.*]] = "tosa.rescale"(%[[VAL_0]]) <{double_round = true, input_zp = 0 : i32, multiplier = array, output_zp = 0 : i32, per_channel = false, scale32 = true, shift = array}> // CHECK: %[[VAL_4:.*]] = "tosa.greater_equal"(%[[VAL_3]], %[[VAL_1]]) : (tensor<1x8x4x17xi32>, tensor<1x1x1x1xi32>) -// CHECK: %[[VAL_5:.*]] = "tosa.rescale"(%[[VAL_2]]) {double_round = true, input_zp = 0 : i32, multiplier = array, output_zp = 0 : i32, per_channel = false, scale32 = true, shift = array} -// CHECK: %[[VAL_6:.*]] = "tosa.reshape"(%[[VAL_5]]) {new_shape = array} : (tensor<8x4x17xi32>) -// CHECK: %[[VAL_7:.*]] = "tosa.mul"(%[[VAL_3]], %[[VAL_6]]) {shift = 0 : i32} : (tensor<1x8x4x17xi32>, tensor<1x8x4x17xi32>) -// CHECK: %[[VAL_8:.*]] = "tosa.rescale"(%[[VAL_7]]) {double_round = true, input_zp = 0 : i32, multiplier = array, output_zp = 1 : i32, per_channel = false, scale32 = true, shift = array} -// CHECK: %[[VAL_9:.*]] = "tosa.rescale"(%[[VAL_0]]) {double_round = true, input_zp = 0 : i32, multiplier = array, output_zp = 1 : i32, per_channel = false, scale32 = true, shift = array} +// CHECK: %[[VAL_5:.*]] = "tosa.rescale"(%[[VAL_2]]) <{double_round = true, input_zp = 0 : i32, multiplier = array, output_zp = 0 : i32, per_channel = false, scale32 = true, shift = array}> +// CHECK: %[[VAL_6:.*]] = "tosa.reshape"(%[[VAL_5]]) <{new_shape = array}> : (tensor<8x4x17xi32>) +// CHECK: %[[VAL_7:.*]] = "tosa.mul"(%[[VAL_3]], %[[VAL_6]]) <{shift = 0 : i32}> : (tensor<1x8x4x17xi32>, tensor<1x8x4x17xi32>) +// CHECK: %[[VAL_8:.*]] = "tosa.rescale"(%[[VAL_7]]) <{double_round = true, input_zp = 0 : i32, multiplier = array, output_zp = 1 : i32, per_channel = false, scale32 = true, shift = array}> +// CHECK: %[[VAL_9:.*]] 
= "tosa.rescale"(%[[VAL_0]]) <{double_round = true, input_zp = 0 : i32, multiplier = array, output_zp = 1 : i32, per_channel = false, scale32 = true, shift = array}> // CHECK: %[[VAL_10:.*]] = "tosa.select"(%[[VAL_4]], %[[VAL_9]], %[[VAL_8]]) func.func @test_prelu_qi8(%arg0: tensor<1x8x4x17x!quant.uniform>) -> tensor<1x8x4x17x!quant.uniform> { %0 = "tfl.pseudo_qconst"() {qtype = tensor<8x4x17x!quant.uniform:f32, 0.021805247291922569>>, value = dense<"0xDAFDEBC120CBE1E028231F05CF04F52484B2F0AC0041E618200308F820FE308FFCF2E1E02A06D00606FB1044C928D8D811E3FCCE350E25C4DE2B0D00E20AC1E215940D0D12C809290D480FE9E2DB26E31E50F5F4FDD31EFF21C210E717E187144F27C848E820C5D503E31729218D96D2D6D3D9C43BF13014EFCB043631AE4403FE2D4CDF1F16E2D13BA20AE92CEAB7323405F728CF3DF4E9BBFAFEFEE120ECA7FA120609030FF0FCF0E5D40939172EE7E256BADEC5ECFFB32C35F4E936E2F8092FE2E3EFE22B0C02F5EE1D36DE03CBE02FF346081C30ED882AECCAF4E4E3361604EABF133CB6371DDAFCDA4F2D32034A270BF0120A0048131331E50D11CAEB1DEE0ADFC0F12531E8351DD7BDEB2821FF3ECC34F8D42EE4D6FF2AE5FEEDFC3DF7463CED10192CE4B728151827A92E000EE31CF3C5DF193DAC2836181BD916D339E914192B14F0163C58C500BDC6BAEFFB03EC33DA24E7FF0E292CE30504B3070AB5FDE6D7E7CB4CB0D818F90919EAEF5DFDF2DB6C4132DF8EF2E40AF7EA04F1D496F22F2971420FF01D012E2954D5081C0AF2C5E5DED2CCD8C6157416201AFF3A2B29FBDD9EF06340B021F45C322A202DDD86111EBDF44BE9110E29F3FE7FDEDDFB5FDEDBD933E2ED0DD4E21C4BC6FD28E31934C821CE10F61C12740A100F1BE205CC01434BD7E3FB14F01CE0E406710022E464E0F0D8FB3D01C733C9C94017FAC50BE812D202E2B10C04E70AF326CEFD0DE20ABD153D3D14171C34061DE5FC5A"> : tensor<8x4x17xi8>} : () -> tensor<8x4x17x!quant.uniform:f32, 0.021805247291922569>> @@ -434,7 +500,7 @@ func.func @test_reduce_sum_axis_out_of_bounds(%arg0: tensor<13x21x3xf32>) -> ten // CHECK-LABEL: test_reduce_all_axis_1_keep_true // CHECK-SAME: %[[VAL_0:.+]]: tensor<1x4x8x19xi1> -// CHECK: %[[VAL_1:.*]] = "tosa.reduce_all"(%[[VAL_0]]) {axis = 1 : i64} : (tensor<1x4x8x19xi1>) -> tensor<1x1x8x19xi1> +// CHECK: %[[VAL_1:.*]] = "tosa.reduce_all"(%[[VAL_0]]) <{axis = 1 : i64}> : (tensor<1x4x8x19xi1>) -> tensor<1x1x8x19xi1> func.func @test_reduce_all_axis_1_keep_true(%arg0: tensor<1x4x8x19xi1>) -> tensor<1x1x8x19xi1> { %cst = arith.constant dense<1> : tensor<1xi32> %0 = "tfl.reduce_all"(%arg0, %cst) {keep_dims = true} : (tensor<1x4x8x19xi1>, tensor<1xi32>) -> tensor<1x1x8x19xi1> @@ -445,8 +511,8 @@ func.func @test_reduce_all_axis_1_keep_true(%arg0: tensor<1x4x8x19xi1>) -> tenso // CHECK-LABEL: test_reduce_all_axis_1_keep_false // CHECK-SAME: %[[VAL_0:.+]]: tensor<1x4x8x19xi1> -// CHECK: %[[VAL_1:.*]] = "tosa.reduce_all"(%[[VAL_0]]) {axis = 1 : i64} : (tensor<1x4x8x19xi1>) -> tensor<1x1x8x19xi1> -// CHECK: %[[VAL_2:.*]] = "tosa.reshape"(%[[VAL_1]]) {new_shape = array} : (tensor<1x1x8x19xi1>) -> tensor<1x8x19xi1> +// CHECK: %[[VAL_1:.*]] = "tosa.reduce_all"(%[[VAL_0]]) <{axis = 1 : i64}> : (tensor<1x4x8x19xi1>) -> tensor<1x1x8x19xi1> +// CHECK: %[[VAL_2:.*]] = "tosa.reshape"(%[[VAL_1]]) <{new_shape = array}> : (tensor<1x1x8x19xi1>) -> tensor<1x8x19xi1> func.func @test_reduce_all_axis_1_keep_false(%arg0: tensor<1x4x8x19xi1>) -> tensor<1x8x19xi1> { %cst = arith.constant dense<1> : tensor<1xi32> %0 = "tfl.reduce_all"(%arg0, %cst) {keep_dims = false} : (tensor<1x4x8x19xi1>, tensor<1xi32>) -> tensor<1x8x19xi1> @@ -457,7 +523,7 @@ func.func @test_reduce_all_axis_1_keep_false(%arg0: tensor<1x4x8x19xi1>) -> tens // CHECK-LABEL: test_reduce_all_axis_2_keep_true // CHECK-SAME: %[[VAL_0:.+]]: tensor<1x4x8x19xi1> -// CHECK: %[[VAL_1:.*]] = "tosa.reduce_all"(%[[VAL_0]]) {axis = 2 : i64} 
: (tensor<1x4x8x19xi1>) -> tensor<1x4x1x19xi1> +// CHECK: %[[VAL_1:.*]] = "tosa.reduce_all"(%[[VAL_0]]) <{axis = 2 : i64}> : (tensor<1x4x8x19xi1>) -> tensor<1x4x1x19xi1> func.func @test_reduce_all_axis_2_keep_true(%arg0: tensor<1x4x8x19xi1>) -> tensor<1x4x1x19xi1> { %cst = arith.constant dense<2> : tensor<1xi32> %0 = "tfl.reduce_all"(%arg0, %cst) {keep_dims = true} : (tensor<1x4x8x19xi1>, tensor<1xi32>) -> tensor<1x4x1x19xi1> @@ -468,8 +534,8 @@ func.func @test_reduce_all_axis_2_keep_true(%arg0: tensor<1x4x8x19xi1>) -> tenso // CHECK-LABEL: test_reduce_all_axis_2_keep_false // CHECK-SAME: %[[VAL_0:.+]]: tensor<1x4x8x19xi1> -// CHECK: %[[VAL_1:.*]] = "tosa.reduce_all"(%[[VAL_0]]) {axis = 2 : i64} : (tensor<1x4x8x19xi1>) -> tensor<1x4x1x19xi1> -// CHECK: %[[VAL_2:.*]] = "tosa.reshape"(%[[VAL_1]]) {new_shape = array} : (tensor<1x4x1x19xi1>) -> tensor<1x4x19xi1> +// CHECK: %[[VAL_1:.*]] = "tosa.reduce_all"(%[[VAL_0]]) <{axis = 2 : i64}> : (tensor<1x4x8x19xi1>) -> tensor<1x4x1x19xi1> +// CHECK: %[[VAL_2:.*]] = "tosa.reshape"(%[[VAL_1]]) <{new_shape = array}> : (tensor<1x4x1x19xi1>) -> tensor<1x4x19xi1> func.func @test_reduce_all_axis_2_keep_false(%arg0: tensor<1x4x8x19xi1>) -> tensor<1x4x19xi1> { %cst = arith.constant dense<2> : tensor<1xi32> %0 = "tfl.reduce_all"(%arg0, %cst) {keep_dims = false} : (tensor<1x4x8x19xi1>, tensor<1xi32>) -> tensor<1x4x19xi1> @@ -479,8 +545,8 @@ func.func @test_reduce_all_axis_2_keep_false(%arg0: tensor<1x4x8x19xi1>) -> tens // ----- // CHECK-LABEL: test_reduce_any -// CHECK-DAG: %[[VAR0:.*]] = "tosa.reduce_any"(%arg0) {axis = 0 : i64} -// CHECK: %[[VAR1:.*]] = "tosa.reshape"(%[[VAR0]]) {new_shape = array} +// CHECK-DAG: %[[VAR0:.*]] = "tosa.reduce_any"(%arg0) <{axis = 0 : i64}> +// CHECK: %[[VAR1:.*]] = "tosa.reshape"(%[[VAR0]]) <{new_shape = array}> func.func @test_reduce_any(%arg0: tensor<13x21x3xi1>) -> tensor<21x3xi1> { %cst = arith.constant dense<0> : tensor<1xi32> %0 = "tfl.reduce_any"(%arg0, %cst) {keep_dims = false} : (tensor<13x21x3xi1>, tensor<1xi32>) -> tensor<21x3xi1> @@ -490,8 +556,8 @@ func.func @test_reduce_any(%arg0: tensor<13x21x3xi1>) -> tensor<21x3xi1> { // ----- // CHECK-LABEL: test_reduce_min -// CHECK-DAG: %[[VAR0:.*]] = "tosa.reduce_min"(%arg0) {axis = 0 : i64} -// CHECK: %[[VAR1:.*]] = "tosa.reshape"(%[[VAR0]]) {new_shape = array} +// CHECK-DAG: %[[VAR0:.*]] = "tosa.reduce_min"(%arg0) <{axis = 0 : i64}> +// CHECK: %[[VAR1:.*]] = "tosa.reshape"(%[[VAR0]]) <{new_shape = array}> func.func @test_reduce_min(%arg0: tensor<13x21x3xf32>) -> tensor<21x3xf32> { %cst = arith.constant dense<0> : tensor<1xi32> %0 = "tfl.reduce_min"(%arg0, %cst) {keep_dims = false} : (tensor<13x21x3xf32>, tensor<1xi32>) -> tensor<21x3xf32> @@ -501,8 +567,8 @@ func.func @test_reduce_min(%arg0: tensor<13x21x3xf32>) -> tensor<21x3xf32> { // ----- // CHECK-LABEL: test_reduce_max -// CHECK-DAG: %[[VAR0:.*]] = "tosa.reduce_max"(%arg0) {axis = 0 : i64} -// CHECK: %[[VAR1:.*]] = "tosa.reshape"(%[[VAR0]]) {new_shape = array} +// CHECK-DAG: %[[VAR0:.*]] = "tosa.reduce_max"(%arg0) <{axis = 0 : i64}> +// CHECK: %[[VAR1:.*]] = "tosa.reshape"(%[[VAR0]]) <{new_shape = array}> func.func @test_reduce_max(%arg0: tensor<13x21x3xf32>) -> tensor<21x3xf32> { %cst = arith.constant dense<0> : tensor<1xi32> %0 = "tfl.reduce_max"(%arg0, %cst) {keep_dims = false} : (tensor<13x21x3xf32>, tensor<1xi32>) -> tensor<21x3xf32> @@ -512,21 +578,23 @@ func.func @test_reduce_max(%arg0: tensor<13x21x3xf32>) -> tensor<21x3xf32> { // ----- // CHECK-LABEL: test_reduce_sum -// CHECK-DAG: %[[VAR0:.*]] = 
"tosa.reduce_sum"(%arg0) {axis = 0 : i64} -// CHECK: %[[VAR1:.*]] = "tosa.reshape"(%[[VAR0]]) {new_shape = array} +// CHECK-DAG: %[[VAR0:.*]] = "tosa.reduce_sum"(%arg0) <{axis = 0 : i64}> +// CHECK: %[[VAR1:.*]] = "tosa.reshape"(%[[VAR0]]) <{new_shape = array}> func.func @test_reduce_sum(%arg0: tensor<13x21x3xf32>) -> tensor<21x3xf32> { %cst = arith.constant dense<0> : tensor<1xi32> %0 = "tfl.sum"(%arg0, %cst) {keep_dims = false} : (tensor<13x21x3xf32>, tensor<1xi32>) -> tensor<21x3xf32> func.return %0 : tensor<21x3xf32> } +// ----- + // CHECK-LABEL: test_reduce_sum_nonzero_axis // CHECK-SAME: %[[VAL_0:.*]]: tensor<10x20x30x40x50xf32> -// CHECK: %[[VAL_1:.*]] = "tosa.const"() {value = dense<[0, 1, 2, 4, 3]> : tensor<5xi32>} : () -> tensor<5xi32> +// CHECK: %[[VAL_1:.*]] = "tosa.const"() <{value = dense<[0, 1, 2, 4, 3]> : tensor<5xi32>}> : () -> tensor<5xi32> // CHECK: %[[VAL_2:.*]] = "tosa.transpose"(%[[VAL_0]], %[[VAL_1]]) : (tensor<10x20x30x40x50xf32>, tensor<5xi32>) -> tensor<10x20x30x50x40xf32> -// CHECK: %[[VAL_3:.*]] = "tosa.reshape"(%[[VAL_2]]) {new_shape = array} : (tensor<10x20x30x50x40xf32>) -> tensor<300000x40xf32> -// CHECK: %[[VAL_4:.*]] = "tosa.reduce_sum"(%[[VAL_3]]) {axis = 1 : i64} : (tensor<300000x40xf32>) -> tensor<300000x1xf32> -// CHECK: %[[VAL_5:.*]] = "tosa.reshape"(%[[VAL_4]]) {new_shape = array} : (tensor<300000x1xf32>) -> tensor<10x20x30x50xf32> +// CHECK: %[[VAL_3:.*]] = "tosa.reshape"(%[[VAL_2]]) <{new_shape = array}> : (tensor<10x20x30x50x40xf32>) -> tensor<300000x40xf32> +// CHECK: %[[VAL_4:.*]] = "tosa.reduce_sum"(%[[VAL_3]]) <{axis = 1 : i64}> : (tensor<300000x40xf32>) -> tensor<300000x1xf32> +// CHECK: %[[VAL_5:.*]] = "tosa.reshape"(%[[VAL_4]]) <{new_shape = array}> : (tensor<300000x1xf32>) -> tensor<10x20x30x50xf32> // CHECK: return %[[VAL_5]] : tensor<10x20x30x50xf32> func.func @test_reduce_sum_nonzero_axis(%arg0: tensor<10x20x30x40x50xf32> {tf._user_specified_name = "inp_list"}) -> tensor<10x20x30x50xf32> { %cst = arith.constant dense<3> : tensor @@ -536,16 +604,14 @@ func.func @test_reduce_sum_nonzero_axis(%arg0: tensor<10x20x30x40x50xf32> {tf._u // ----- -// ----- - // CHECK-LABEL: test_reduce_sum_5D func.func @test_reduce_sum_5D(%arg0: tensor<4x5x6x7x8xf32>) -> tensor<6x8xf32> { %cst = arith.constant dense<[0, 1, 3]> : tensor<3xi32> - // CHECK-DAG: %[[PERM:.+]] = "tosa.const"() {value = dense<[2, 4, 0, 1, 3]> : tensor<5xi32>} + // CHECK-DAG: %[[PERM:.+]] = "tosa.const"() <{value = dense<[2, 4, 0, 1, 3]> : tensor<5xi32>}> // CHECK-DAG: %[[TRANSPOSE:.+]] = "tosa.transpose"(%arg0, %[[PERM]]) - // CHECK-DAG: %[[RESHAPE0:.+]] = "tosa.reshape"(%[[TRANSPOSE:.+]]) {new_shape = array} - // CHECK-DAG: %[[REDUCE:.+]] = "tosa.reduce_sum"(%[[RESHAPE0]]) {axis = 1 : i64} - // CHECK: %[[RESHAPE1:.+]] = "tosa.reshape"(%[[REDUCE]]) {new_shape = array} + // CHECK-DAG: %[[RESHAPE0:.+]] = "tosa.reshape"(%[[TRANSPOSE:.+]]) <{new_shape = array}> + // CHECK-DAG: %[[REDUCE:.+]] = "tosa.reduce_sum"(%[[RESHAPE0]]) <{axis = 1 : i64}> + // CHECK: %[[RESHAPE1:.+]] = "tosa.reshape"(%[[REDUCE]]) <{new_shape = array}> %0 = "tfl.sum"(%arg0, %cst) {keep_dims = false} : (tensor<4x5x6x7x8xf32>, tensor<3xi32>) -> tensor<6x8xf32> func.return %0 : tensor<6x8xf32> } @@ -553,10 +619,10 @@ func.func @test_reduce_sum_5D(%arg0: tensor<4x5x6x7x8xf32>) -> tensor<6x8xf32> { // ----- // CHECK-LABEL: test_reduce_mean -// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() {value = dense<0.0769230798> : tensor<1x1xf32>} -// CHECK-DAG: %[[VAR1:.*]] = "tosa.reduce_sum"(%arg0) {axis = 0 : i64} -// CHECK-DAG: 
%[[VAR2:.*]] = "tosa.reshape"(%[[VAR1]]) {new_shape = array} -// CHECK: %[[VAR4:.*]] = "tosa.mul"(%[[VAR2]], %[[VAR0]]) {shift = 0 : i32} +// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() <{value = dense<0.0769230798> : tensor<1x1xf32>}> +// CHECK-DAG: %[[VAR1:.*]] = "tosa.reduce_sum"(%arg0) <{axis = 0 : i64}> +// CHECK-DAG: %[[VAR2:.*]] = "tosa.reshape"(%[[VAR1]]) <{new_shape = array}> +// CHECK: %[[VAR4:.*]] = "tosa.mul"(%[[VAR2]], %[[VAR0]]) <{shift = 0 : i32}> func.func @test_reduce_mean(%arg0: tensor<13x21x3xf32>) -> tensor<21x3xf32> { %cst = arith.constant dense<0> : tensor<1xi32> %0 = "tfl.mean"(%arg0, %cst) {keep_dims = false} : (tensor<13x21x3xf32>, tensor<1xi32>) -> tensor<21x3xf32> @@ -576,8 +642,8 @@ func.func @test_reduce_mean_out_of_bounds(%arg0: tensor<13x21x3xf32>) -> tensor< // ----- // CHECK-LABEL: test_reduce_product -// CHECK-DAG: %[[VAR0:.*]] = "tosa.reduce_prod"(%arg0) {axis = 0 : i64} -// CHECK: %[[VAR1:.*]] = "tosa.reshape"(%[[VAR0]]) {new_shape = array} +// CHECK-DAG: %[[VAR0:.*]] = "tosa.reduce_prod"(%arg0) <{axis = 0 : i64}> +// CHECK: %[[VAR1:.*]] = "tosa.reshape"(%[[VAR0]]) <{new_shape = array}> func.func @test_reduce_product(%arg0: tensor<13x21x3xf32>) -> tensor<21x3xf32> { %cst = arith.constant dense<0> : tensor<1xi32> %0 = "tfl.reduce_prod"(%arg0, %cst) {keep_dims = false} : (tensor<13x21x3xf32>, tensor<1xi32>) -> tensor<21x3xf32> @@ -679,10 +745,38 @@ func.func @test_negate(%arg0: tensor<13x21x3xf32>) -> tensor<*xf32> { // ----- // CHECK-LABEL: test_rsqrt -// CHECK: %[[VAR0:.*]] = "tosa.rsqrt"(%arg0) -func.func @test_rsqrt(%arg0: tensor<13x21x3xf32>) -> tensor<*xf32> { - %0 = "tfl.rsqrt"(%arg0) : (tensor<13x21x3xf32>) -> tensor<*xf32> - func.return %0 : tensor<*xf32> +// CHECK-SAME: %[[VAL_0:.*]]: tensor<13x21x3xf32> +// CHECK: %[[VAL_1:.*]] = "tosa.rsqrt"(%[[VAL_0]]) : (tensor<13x21x3xf32>) +func.func @test_rsqrt(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { + %0 = "tfl.rsqrt"(%arg0) : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32> + func.return %0 : tensor<13x21x3xf32> +} + +// ----- + +// CHECK-LABEL: test_rsqrt_qi8 +// CHECK-SAME: %[[VAL_0:.*]]: tensor<13x21x3x!quant.uniform> +// CHECK: %[[VAL_1:.*]] = "tosa.const"() <{value = dense<{{.+}}> : tensor<256xi8>}> +// CHECK: %[[VAL_2:.*]] = "tosa.table"(%[[VAL_0]], %[[VAL_1]]) +func.func @test_rsqrt_qi8(%arg0: tensor<13x21x3x!quant.uniform>) -> (tensor<13x21x3x!quant.uniform>) { + %0 = "tfl.rsqrt"(%arg0) : (tensor<13x21x3x!quant.uniform>) -> tensor<13x21x3x!quant.uniform> + func.return %0 : tensor<13x21x3x!quant.uniform> +} + +// ----- + +// CHECK-LABEL: test_sign +// CHECK-SAME: %[[VAL_0:.*]]: tensor<21x45xi32> +// CHECK-DAG: %[[VAL_1:.*]] = "tosa.const"() <{value = dense<0> : tensor<1x1xi32>}> +// CHECK-DAG: %[[VAL_2:.*]] = "tosa.const"() <{value = dense<-1> : tensor<1x1xi32>}> +// CHECK-DAG: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<1> : tensor<1x1xi32>}> +// CHECK: %[[VAL_4:.*]] = "tosa.greater"(%[[VAL_0]], %[[VAL_1]]) +// CHECK: %[[VAL_5:.*]] = "tosa.greater"(%[[VAL_1]], %[[VAL_0]]) +// CHECK: %[[VAL_6:.*]] = "tosa.select"(%[[VAL_5]], %[[VAL_2]], %[[VAL_1]]) +// CHECK: %[[VAL_7:.*]] = "tosa.select"(%[[VAL_4]], %[[VAL_3]], %[[VAL_6]]) +func.func @test_sign(%arg0: tensor<21x45xi32>) -> tensor<21x45xi32> { + %0 = "tfl.sign"(%arg0) : (tensor<21x45xi32>) -> tensor<21x45xi32> + func.return %0 : tensor<21x45xi32> } // ----- @@ -690,12 +784,12 @@ func.func @test_rsqrt(%arg0: tensor<13x21x3xf32>) -> tensor<*xf32> { // CHECK-LABEL: test_sin // CHECK-SAME: -> tensor<10xf32> func.func @test_sin(%arg0: 
tensor<10xf32>) -> tensor<*xf32> { - // CHECK-DAG: %[[ONE:.+]] = "tosa.const"() {value = dense<1.000000e+00> : tensor<1xf32>} - // CHECK-DAG: %[[TWO:.+]] = "tosa.const"() {value = dense<2.000000e+00> : tensor<1xf32>} - // CHECK-DAG: %[[RESULT_SCALE:.+]] = "tosa.const"() {value = dense<2.38418579E-7> : tensor<1xf32>} - // CHECK-DAG: %[[INT_MAX:.+]] = "tosa.const"() {value = dense<3.276700e+04> : tensor<1xf32>} - // CHECK-DAG: %[[IN_SCALE:.+]] = "tosa.const"() {value = dense<0.159154937> : tensor<1xf32>} - // CHECK-DAG: %[[TBLVAL:.+]] = "tosa.const"() {value = dense<{{.+}}> : tensor<513xi16>} + // CHECK-DAG: %[[ONE:.+]] = "tosa.const"() <{value = dense<1.000000e+00> : tensor<1xf32>}> + // CHECK-DAG: %[[TWO:.+]] = "tosa.const"() <{value = dense<2.000000e+00> : tensor<1xf32>}> + // CHECK-DAG: %[[RESULT_SCALE:.+]] = "tosa.const"() <{value = dense<2.38418579E-7> : tensor<1xf32>}> + // CHECK-DAG: %[[INT_MAX:.+]] = "tosa.const"() <{value = dense<3.276700e+04> : tensor<1xf32>}> + // CHECK-DAG: %[[IN_SCALE:.+]] = "tosa.const"() <{value = dense<0.159154937> : tensor<1xf32>}> + // CHECK-DAG: %[[TBLVAL:.+]] = "tosa.const"() <{value = dense<{{.+}}> : tensor<513xi16>}> // CHECK-DAG: %[[IN_SCALED:.+]] = "tosa.mul"(%arg0, %[[IN_SCALE]]) // CHECK-DAG: %[[FLOOR:.+]] = "tosa.floor"(%[[IN_SCALED]]) // CHECK-DAG: %[[SUB1:.+]] = "tosa.sub"(%[[IN_SCALED]], %[[FLOOR]]) @@ -717,13 +811,13 @@ func.func @test_sin(%arg0: tensor<10xf32>) -> tensor<*xf32> { // CHECK-LABEL: test_cos // CHECK-SAME: -> tensor<10xf32> func.func @test_cos(%arg0: tensor<10xf32>) -> tensor<*xf32> { - // CHECK-DAG: %[[RESULT_SCALE:.+]] = "tosa.const"() {value = dense<2.38418579E-7> : tensor<1xf32>} - // CHECK-DAG: %[[INT_MAX:.+]] = "tosa.const"() {value = dense<3.276700e+04> : tensor<1xf32>} - // CHECK-DAG: %[[ONE:.+]] = "tosa.const"() {value = dense<1.000000e+00> : tensor<1xf32>} - // CHECK-DAG: %[[TWO:.+]] = "tosa.const"() {value = dense<2.000000e+00> : tensor<1xf32>} - // CHECK-DAG: %[[IN_SCALE:.+]] = "tosa.const"() {value = dense<0.159154937> : tensor<1xf32>} - // CHECK-DAG: %[[HALF_PI:.+]] = "tosa.const"() {value = dense<1.57079637> : tensor<1xf32>} - // CHECK-DAG: %[[TBLVAL:.+]] = "tosa.const"() {value = dense<{{.+}}> : tensor<513xi16>} + // CHECK-DAG: %[[RESULT_SCALE:.+]] = "tosa.const"() <{value = dense<2.38418579E-7> : tensor<1xf32>}> + // CHECK-DAG: %[[INT_MAX:.+]] = "tosa.const"() <{value = dense<3.276700e+04> : tensor<1xf32>}> + // CHECK-DAG: %[[ONE:.+]] = "tosa.const"() <{value = dense<1.000000e+00> : tensor<1xf32>}> + // CHECK-DAG: %[[TWO:.+]] = "tosa.const"() <{value = dense<2.000000e+00> : tensor<1xf32>}> + // CHECK-DAG: %[[IN_SCALE:.+]] = "tosa.const"() <{value = dense<0.159154937> : tensor<1xf32>}> + // CHECK-DAG: %[[HALF_PI:.+]] = "tosa.const"() <{value = dense<1.57079637> : tensor<1xf32>}> + // CHECK-DAG: %[[TBLVAL:.+]] = "tosa.const"() <{value = dense<{{.+}}> : tensor<513xi16>}> // CHECK-DAG: %[[IN_TRANSLATE:.+]] = "tosa.add"(%arg0, %[[HALF_PI]]) // CHECK-DAG: %[[IN_SCALED:.+]] = "tosa.mul"(%[[IN_TRANSLATE]], %[[IN_SCALE]]) // CHECK-DAG: %[[FLOOR:.+]] = "tosa.floor"(%[[IN_SCALED]]) @@ -745,27 +839,27 @@ func.func @test_cos(%arg0: tensor<10xf32>) -> tensor<*xf32> { // CHECK-LABEL: test_atan2 // CHECK-SAME: -> tensor<13x21x3xf32> -// CHECK-DAG: %[[VAL_2:.*]] = "tosa.const"() {value = dense<2.000000e+00> : tensor<1x1x1xf32>} : () -> tensor<1x1x1xf32> -// CHECK-DAG: %[[VAL_3:.*]] = "tosa.const"() {value = dense<1.000000e+00> : tensor<1x1x1xf32>} : () -> tensor<1x1x1xf32> -// CHECK-DAG: %[[VAL_4:.*]] = "tosa.const"() {value = 
dense<3.276700e+04> : tensor<1x1x1xf32>} : () -> tensor<1x1x1xf32> -// CHECK-DAG: %[[VAL_5:.*]] = "tosa.const"() {value = dense<2.38418579E-7> : tensor<1x1x1xf32>} : () -> tensor<1x1x1xf32> -// CHECK-DAG: %[[VAL_6:.*]] = "tosa.const"() {value = dense<1.57079637> : tensor<1x1x1xf32>} : () -> tensor<1x1x1xf32> -// CHECK-DAG: %[[VAL_7:.*]] = "tosa.const"() {value = dense<3.14159274> : tensor<1x1x1xf32>} : () -> tensor<1x1x1xf32> -// CHECK-DAG: %[[VAL_8:.*]] = "tosa.const"() {value = dense<0.000000e+00> : tensor<1x1x1xf32>} : () -> tensor<1x1x1xf32> -// CHECK-DAG: %[[VAL_9:.*]] = "tosa.const"() {value = dense<{{.+}}> : tensor<513xi16>} : () -> tensor<513xi16> +// CHECK-DAG: %[[VAL_2:.*]] = "tosa.const"() <{value = dense<2.000000e+00> : tensor<1x1x1xf32>}> : () -> tensor<1x1x1xf32> +// CHECK-DAG: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<1.000000e+00> : tensor<1x1x1xf32>}> : () -> tensor<1x1x1xf32> +// CHECK-DAG: %[[VAL_4:.*]] = "tosa.const"() <{value = dense<3.276700e+04> : tensor<1x1x1xf32>}> : () -> tensor<1x1x1xf32> +// CHECK-DAG: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<2.38418579E-7> : tensor<1x1x1xf32>}> : () -> tensor<1x1x1xf32> +// CHECK-DAG: %[[VAL_6:.*]] = "tosa.const"() <{value = dense<1.57079637> : tensor<1x1x1xf32>}> : () -> tensor<1x1x1xf32> +// CHECK-DAG: %[[VAL_7:.*]] = "tosa.const"() <{value = dense<3.14159274> : tensor<1x1x1xf32>}> : () -> tensor<1x1x1xf32> +// CHECK-DAG: %[[VAL_8:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<1x1x1xf32>}> : () -> tensor<1x1x1xf32> +// CHECK-DAG: %[[VAL_9:.*]] = "tosa.const"() <{value = dense<{{.+}}> : tensor<513xi16>}> : () -> tensor<513xi16> // CHECK: %[[VAL_10:.*]] = "tosa.abs"(%arg0) : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32> // CHECK: %[[VAL_11:.*]] = "tosa.abs"(%arg1) : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32> // CHECK: %[[VAL_12:.*]] = "tosa.minimum"(%[[VAL_10]], %[[VAL_11]]) : (tensor<13x21x3xf32>, tensor<13x21x3xf32>) -> tensor<13x21x3xf32> // CHECK: %[[VAL_13:.*]] = "tosa.maximum"(%[[VAL_10]], %[[VAL_11]]) : (tensor<13x21x3xf32>, tensor<13x21x3xf32>) -> tensor<13x21x3xf32> // CHECK: %[[VAL_14:.*]] = "tosa.reciprocal"(%[[VAL_13]]) : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32> -// CHECK: %[[VAL_15:.*]] = "tosa.mul"(%[[VAL_14]], %[[VAL_12]]) {shift = 0 : i32} : (tensor<13x21x3xf32>, tensor<13x21x3xf32>) -> tensor<13x21x3xf32> -// CHECK: %[[VAL_16:.*]] = "tosa.mul"(%[[VAL_15]], %[[VAL_2]]) {shift = 0 : i32} : (tensor<13x21x3xf32>, tensor<1x1x1xf32>) -> tensor<13x21x3xf32> +// CHECK: %[[VAL_15:.*]] = "tosa.mul"(%[[VAL_14]], %[[VAL_12]]) <{shift = 0 : i32}> : (tensor<13x21x3xf32>, tensor<13x21x3xf32>) -> tensor<13x21x3xf32> +// CHECK: %[[VAL_16:.*]] = "tosa.mul"(%[[VAL_15]], %[[VAL_2]]) <{shift = 0 : i32}> : (tensor<13x21x3xf32>, tensor<1x1x1xf32>) -> tensor<13x21x3xf32> // CHECK: %[[VAL_17:.*]] = "tosa.sub"(%[[VAL_16]], %[[VAL_3]]) : (tensor<13x21x3xf32>, tensor<1x1x1xf32>) -> tensor<13x21x3xf32> -// CHECK: %[[VAL_18:.*]] = "tosa.mul"(%[[VAL_17]], %[[VAL_4]]) {shift = 0 : i32} : (tensor<13x21x3xf32>, tensor<1x1x1xf32>) -> tensor<13x21x3xf32> +// CHECK: %[[VAL_18:.*]] = "tosa.mul"(%[[VAL_17]], %[[VAL_4]]) <{shift = 0 : i32}> : (tensor<13x21x3xf32>, tensor<1x1x1xf32>) -> tensor<13x21x3xf32> // CHECK: %[[VAL_19:.*]] = "tosa.cast"(%[[VAL_18]]) : (tensor<13x21x3xf32>) -> tensor<13x21x3xi16> // CHECK: %[[VAL_20:.*]] = "tosa.table"(%[[VAL_19]], %[[VAL_9]]) : (tensor<13x21x3xi16>, tensor<513xi16>) -> tensor<13x21x3xi32> // CHECK: %[[VAL_21:.*]] = "tosa.cast"(%[[VAL_20]]) : (tensor<13x21x3xi32>) -> 
tensor<13x21x3xf32> -// CHECK: %[[VAL_22:.*]] = "tosa.mul"(%[[VAL_21]], %[[VAL_5]]) {shift = 0 : i32} : (tensor<13x21x3xf32>, tensor<1x1x1xf32>) -> tensor<13x21x3xf32> +// CHECK: %[[VAL_22:.*]] = "tosa.mul"(%[[VAL_21]], %[[VAL_5]]) <{shift = 0 : i32}> : (tensor<13x21x3xf32>, tensor<1x1x1xf32>) -> tensor<13x21x3xf32> // CHECK: %[[VAL_23:.*]] = "tosa.sub"(%[[VAL_6]], %[[VAL_22]]) : (tensor<1x1x1xf32>, tensor<13x21x3xf32>) -> tensor<13x21x3xf32> // CHECK: %[[VAL_24:.*]] = "tosa.greater"(%[[VAL_10]], %[[VAL_11]]) : (tensor<13x21x3xf32>, tensor<13x21x3xf32>) -> tensor<13x21x3xi1> // CHECK: %[[VAL_25:.*]] = "tosa.select"(%[[VAL_24]], %[[VAL_23]], %[[VAL_22]]) : (tensor<13x21x3xi1>, tensor<13x21x3xf32>, tensor<13x21x3xf32>) -> tensor<13x21x3xf32> @@ -794,7 +888,7 @@ func.func @test_sigmoid(%arg0: tensor<13x21x3xf32>) -> tensor<*xf32> { // ----- // CHECK-LABEL: test_square -// CHECK: %[[VAR0:.*]] = "tosa.mul"(%arg0, %arg0) {shift = 0 : i32} +// CHECK: %[[VAR0:.*]] = "tosa.mul"(%arg0, %arg0) <{shift = 0 : i32}> func.func @test_square(%arg0: tensor<13x21x3xf32>) -> tensor<*xf32> { %0 = "tfl.square"(%arg0) : (tensor<13x21x3xf32>) -> tensor<*xf32> func.return %0 : tensor<*xf32> @@ -868,7 +962,7 @@ func.func @test_less_equal_dynamic(%arg0: tensor<13x1x3xf32>, %arg1: tensor<13x? // ----- // CHECK-LABEL: test_avg_pool2d -// CHECK: %[[VAR0:.*]] = "tosa.avg_pool2d"(%arg0) {kernel = array, pad = array, stride = array} +// CHECK: %[[VAR0:.*]] = "tosa.avg_pool2d"(%arg0) <{kernel = array, pad = array, stride = array}> func.func @test_avg_pool2d(%arg0: tensor<1x32x32x8xf32>) -> tensor<*xf32> { %0 = "tfl.average_pool_2d"(%arg0) {filter_height = 1 : i32, filter_width = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x32x32x8xf32>) -> tensor<*xf32> func.return %0 : tensor<*xf32> @@ -877,7 +971,7 @@ func.func @test_avg_pool2d(%arg0: tensor<1x32x32x8xf32>) -> tensor<*xf32> { // ----- // CHECK-LABEL: test_avg_pool2d_dynamic -// CHECK: %[[VAR0:.*]] = "tosa.avg_pool2d"(%arg0) {kernel = array, pad = array, stride = array} +// CHECK: %[[VAR0:.*]] = "tosa.avg_pool2d"(%arg0) <{kernel = array, pad = array, stride = array}> func.func @test_avg_pool2d_dynamic(%arg0: tensor) -> tensor<*xf32> { %0 = "tfl.average_pool_2d"(%arg0) {filter_height = 1 : i32, filter_width = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor) -> tensor<*xf32> func.return %0 : tensor<*xf32> @@ -886,7 +980,7 @@ func.func @test_avg_pool2d_dynamic(%arg0: tensor) -> tensor<*xf32 // ----- // CHECK-LABEL: test_max_pool2d -// CHECK: %[[VAR0:.*]] = "tosa.max_pool2d"(%arg0) {kernel = array, pad = array, stride = array} +// CHECK: %[[VAR0:.*]] = "tosa.max_pool2d"(%arg0) <{kernel = array, pad = array, stride = array}> func.func @test_max_pool2d(%arg0: tensor<1x32x32x8xf32>) -> tensor<*xf32> { %0 = "tfl.max_pool_2d"(%arg0) {filter_height = 1 : i32, filter_width = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x32x32x8xf32>) -> tensor<*xf32> func.return %0 : tensor<*xf32> @@ -895,7 +989,7 @@ func.func @test_max_pool2d(%arg0: tensor<1x32x32x8xf32>) -> tensor<*xf32> { // ----- // CHECK-LABEL: test_max_pool2d_dynamic -// CHECK: %[[VAR0:.*]] = "tosa.max_pool2d"(%arg0) {kernel = array, pad = array, stride = array} +// CHECK: %[[VAR0:.*]] = "tosa.max_pool2d"(%arg0) <{kernel = array, pad = array, stride = array}> func.func @test_max_pool2d_dynamic(%arg0: tensor) -> tensor<*xf32> { 
%0 = "tfl.max_pool_2d"(%arg0) {filter_height = 1 : i32, filter_width = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor) -> tensor<*xf32> func.return %0 : tensor<*xf32> @@ -904,7 +998,7 @@ func.func @test_max_pool2d_dynamic(%arg0: tensor) -> tensor<*xf32 // ----- // CHECK-LABEL: test_reshape -// CHECK: %[[VAR0:.*]] = "tosa.reshape"(%arg0) {new_shape = array} +// CHECK: %[[VAR0:.*]] = "tosa.reshape"(%arg0) <{new_shape = array}> func.func @test_reshape(%arg0: tensor<13x21x3xf32>) -> tensor<*xf32> { %cst = arith.constant dense<[1, 819]> : tensor<2xi32> %0 = "tfl.reshape"(%arg0, %cst) : (tensor<13x21x3xf32>, tensor<2xi32>) -> tensor<*xf32> @@ -914,7 +1008,7 @@ func.func @test_reshape(%arg0: tensor<13x21x3xf32>) -> tensor<*xf32> { // ----- // CHECK-LABEL: test_reshape_unknown -// CHECK: %[[VAR0:.*]] = "tosa.reshape"(%arg0) {new_shape = array} +// CHECK: %[[VAR0:.*]] = "tosa.reshape"(%arg0) <{new_shape = array}> // CHECK-SAME: -> tensor<9x91xf32> func.func @test_reshape_unknown(%arg0: tensor<13x21x3xf32>) -> tensor<*xf32> { %cst = arith.constant dense<[9, -1]> : tensor<2xi32> @@ -925,7 +1019,7 @@ func.func @test_reshape_unknown(%arg0: tensor<13x21x3xf32>) -> tensor<*xf32> { // ----- // CHECK-LABEL: test_reshape_dynamic -// CHECK: %[[VAR0:.*]] = "tosa.reshape"(%arg0) {new_shape = array} +// CHECK: %[[VAR0:.*]] = "tosa.reshape"(%arg0) <{new_shape = array}> // CHECK-SAME: -> tensor<3x?xf32> func.func @test_reshape_dynamic(%arg0: tensor<13x21x?xf32>) -> tensor<*xf32> { %cst = arith.constant dense<[3, -1]> : tensor<2xi32> @@ -936,7 +1030,7 @@ func.func @test_reshape_dynamic(%arg0: tensor<13x21x?xf32>) -> tensor<*xf32> { // ----- // CHECK-LABEL: test_transpose -// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() {value = dense<[2, 0, 1]> : tensor<3xi32>} +// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() <{value = dense<[2, 0, 1]> : tensor<3xi32>}> // CHECK: %[[VAR1:.*]] = "tosa.transpose"(%arg0, %[[VAR0]]) func.func @test_transpose(%arg0: tensor<13x21x3xf32>) -> tensor<*xf32> { %cst = arith.constant dense<[2, 0, 1]> : tensor<3xi32> @@ -947,7 +1041,7 @@ func.func @test_transpose(%arg0: tensor<13x21x3xf32>) -> tensor<*xf32> { // ----- // CHECK-LABEL: test_transpose -// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() {value = dense<[2, 0, 1]> : tensor<3xi32>} +// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() <{value = dense<[2, 0, 1]> : tensor<3xi32>}> // CHECK: %[[VAR1:.*]] = "tosa.transpose"(%arg0, %[[VAR0]]) func.func @test_transpose(%arg0: tensor<13x?x3xf32>) -> tensor<*xf32> { %cst = arith.constant dense<[2, 0, 1]> : tensor<3xi32> @@ -958,7 +1052,7 @@ func.func @test_transpose(%arg0: tensor<13x?x3xf32>) -> tensor<*xf32> { // ----- // CHECK-LABEL: test_slice -// CHECK: %[[VAR0:.*]] = "tosa.slice"(%arg0) {size = array, start = array} +// CHECK: %[[VAR0:.*]] = "tosa.slice"(%arg0) <{size = array, start = array}> func.func @test_slice(%arg0: tensor<13x21x3xf32>) -> tensor<*xf32> { %cst = arith.constant dense<[6, 8, 0]> : tensor<3xi32> %cst_0 = arith.constant dense<[4, 11, 1]> : tensor<3xi32> @@ -969,10 +1063,10 @@ func.func @test_slice(%arg0: tensor<13x21x3xf32>) -> tensor<*xf32> { // ----- // CHECK-LABEL: test_strided_slice_simple -// CHECK-DAG: %[[VAR0:.*]] = "tosa.slice"(%arg0) {size = array, start = array} -// CHECK-DAG: %[[VAR1:.*]] = "tosa.reshape"(%[[VAR0]]) {new_shape = array} -// CHECK-DAG: %[[VAR2:.*]] = "tosa.slice"(%[[VAR1]]) {size = array, start = array} -// CHECK: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]]) {new_shape = array} +// CHECK-DAG: %[[VAR0:.*]] 
= "tosa.slice"(%arg0) <{size = array, start = array}> +// CHECK-DAG: %[[VAR1:.*]] = "tosa.reshape"(%[[VAR0]]) <{new_shape = array}> +// CHECK-DAG: %[[VAR2:.*]] = "tosa.slice"(%[[VAR1]]) <{size = array, start = array}> +// CHECK: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]]) <{new_shape = array}> func.func @test_strided_slice_simple(%arg0: tensor<13x21x3xf32>) -> tensor<*xf32> { %cst = arith.constant dense<[4, 0, 1]> : tensor<3xi32> %cst_0 = arith.constant dense<[13, 21, 3]> : tensor<3xi32> @@ -984,10 +1078,10 @@ func.func @test_strided_slice_simple(%arg0: tensor<13x21x3xf32>) -> tensor<*xf32 // ----- // CHECK-LABEL: test_strided_slice_simple_negative -// CHECK-DAG: %[[VAR0:.*]] = "tosa.slice"(%arg0) {size = array, start = array} -// CHECK-DAG: %[[VAR1:.*]] = "tosa.reshape"(%[[VAR0]]) {new_shape = array} -// CHECK-DAG: %[[VAR2:.*]] = "tosa.slice"(%[[VAR1]]) {size = array, start = array} -// CHECK: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]]) {new_shape = array} +// CHECK-DAG: %[[VAR0:.*]] = "tosa.slice"(%arg0) <{size = array, start = array}> +// CHECK-DAG: %[[VAR1:.*]] = "tosa.reshape"(%[[VAR0]]) <{new_shape = array}> +// CHECK-DAG: %[[VAR2:.*]] = "tosa.slice"(%[[VAR1]]) <{size = array, start = array}> +// CHECK: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]]) <{new_shape = array}> func.func @test_strided_slice_simple_negative(%arg0: tensor<13x21x3xf32>) -> tensor<*xf32> { %cst = arith.constant dense<[4, 0, 1]> : tensor<3xi32> %cst_0 = arith.constant dense<[13, -3, 3]> : tensor<3xi32> @@ -999,8 +1093,8 @@ func.func @test_strided_slice_simple_negative(%arg0: tensor<13x21x3xf32>) -> ten // ----- // CHECK-LABEL: test_strided_slice_strideless -// CHECK-DAG: %[[VAR0:.*]] = "tosa.slice"(%arg0) {size = array, start = array} -// CHECK: %[[VAR1:.*]] = "tosa.reshape"(%[[VAR0]]) {new_shape = array} +// CHECK-DAG: %[[VAR0:.*]] = "tosa.slice"(%arg0) <{size = array, start = array}> +// CHECK: %[[VAR1:.*]] = "tosa.reshape"(%[[VAR0]]) <{new_shape = array}> func.func @test_strided_slice_strideless(%arg0: tensor<13x21x3xf32>) -> tensor<*xf32> { %cst = arith.constant dense<[4, 0, 1]> : tensor<3xi32> %cst_0 = arith.constant dense<[13, 21, 3]> : tensor<3xi32> @@ -1012,10 +1106,10 @@ func.func @test_strided_slice_strideless(%arg0: tensor<13x21x3xf32>) -> tensor<* // ----- // CHECK-LABEL: test_strided_slice_shrink -// CHECK-DAG: %[[VAR0:.*]] = "tosa.slice"(%arg0) {size = array, start = array} -// CHECK-DAG: %[[VAR1:.*]] = "tosa.reshape"(%[[VAR0]]) {new_shape = array} -// CHECK-DAG: %[[VAR2:.*]] = "tosa.slice"(%[[VAR1]]) {size = array, start = array} -// CHECK: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]]) {new_shape = array} +// CHECK-DAG: %[[VAR0:.*]] = "tosa.slice"(%arg0) <{size = array, start = array}> +// CHECK-DAG: %[[VAR1:.*]] = "tosa.reshape"(%[[VAR0]]) <{new_shape = array}> +// CHECK-DAG: %[[VAR2:.*]] = "tosa.slice"(%[[VAR1]]) <{size = array, start = array}> +// CHECK: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]]) <{new_shape = array}> func.func @test_strided_slice_shrink(%arg0: tensor<13x21x3xf32>) -> tensor<*xf32> { %cst = arith.constant dense<[4, 0, 1]> : tensor<3xi32> %cst_0 = arith.constant dense<[13, 21, 3]> : tensor<3xi32> @@ -1027,8 +1121,8 @@ func.func @test_strided_slice_shrink(%arg0: tensor<13x21x3xf32>) -> tensor<*xf32 // ----- // CHECK-LABEL: test_strided_slice_shrink_ignore_stride -// CHECK-DAG: %[[VAR0:.*]] = "tosa.slice"(%arg0) {size = array, start = array} -// CHECK: %[[VAR1:.*]] = "tosa.reshape"(%[[VAR0]]) {new_shape = array} +// CHECK-DAG: %[[VAR0:.*]] = "tosa.slice"(%arg0) <{size = array, start = array}> +// 
CHECK: %[[VAR1:.*]] = "tosa.reshape"(%[[VAR0]]) <{new_shape = array}> func.func @test_strided_slice_shrink_ignore_stride(%arg0: tensor<13x21x3xf32>) -> tensor<*xf32> { %cst = arith.constant dense<[4, 0, 1]> : tensor<3xi32> %cst_0 = arith.constant dense<[13, 21, 3]> : tensor<3xi32> @@ -1041,8 +1135,8 @@ func.func @test_strided_slice_shrink_ignore_stride(%arg0: tensor<13x21x3xf32>) - // CHECK-LABEL: test_strided_slice_unstrided // CHECK-SAME: -> tensor<9x21x2xf32> -// CHECK: %[[VAR0:.*]] = "tosa.slice"(%arg0) {size = array, start = array} -// CHECK: %[[VAR1:.*]] = "tosa.reverse"(%[[VAR0]]) {axis = 2 : i64} +// CHECK: %[[VAR0:.*]] = "tosa.slice"(%arg0) <{size = array, start = array}> +// CHECK: %[[VAR1:.*]] = "tosa.reverse"(%[[VAR0]]) <{axis = 2 : i64}> // CHECK: return %[[VAR1]] func.func @test_strided_slice_unstrided(%arg0: tensor<13x21x3xf32>) -> tensor<*xf32> { %cst = arith.constant dense<[4, 0, 1]> : tensor<3xi32> @@ -1056,8 +1150,8 @@ func.func @test_strided_slice_unstrided(%arg0: tensor<13x21x3xf32>) -> tensor<*x // CHECK-LABEL: test_strided_slice_unstrided_shorter // CHECK: -> tensor<9x21x3xf32> -// CHECK: %[[VAR0:.*]] = "tosa.slice"(%arg0) {size = array, start = array} -// CHECK: %[[VAR1:.*]] = "tosa.reverse"(%[[VAR0]]) {axis = 1 : i64} +// CHECK: %[[VAR0:.*]] = "tosa.slice"(%arg0) <{size = array, start = array}> +// CHECK: %[[VAR1:.*]] = "tosa.reverse"(%[[VAR0]]) <{axis = 1 : i64}> // CHECK: return %[[VAR1]] func.func @test_strided_slice_unstrided_shorter(%arg0: tensor<13x21x3xf32>) -> tensor<*xf32> { %cst = arith.constant dense<[4, 0]> : tensor<2xi32> @@ -1071,8 +1165,8 @@ func.func @test_strided_slice_unstrided_shorter(%arg0: tensor<13x21x3xf32>) -> t // CHECK-LABEL: test_strided_slice_dynamic_masked // CHECK-SAME: -> tensor<10x?x?xf32> -// CHECK: %[[VAR0:.*]] = "tosa.reverse"(%arg0) {axis = 1 : i64} -// CHECK: %[[VAR1:.*]] = "tosa.reverse"(%[[VAR0]]) {axis = 2 : i64} +// CHECK: %[[VAR0:.*]] = "tosa.reverse"(%arg0) <{axis = 1 : i64}> +// CHECK: %[[VAR1:.*]] = "tosa.reverse"(%[[VAR0]]) <{axis = 2 : i64}> // CHECK: return %[[VAR1]] func.func @test_strided_slice_dynamic_masked(%arg0: tensor<10x?x?xf32>, %arg1: tensor<3xi32>) -> tensor<*xf32> { %cst_0 = arith.constant dense<[13, -1, 3]> : tensor<3xi32> @@ -1093,8 +1187,8 @@ func.func @test_strided_slice_dynamic_begin(%arg0: tensor<10x?x?xf32>) -> tensor %cst = arith.constant dense<[0, 2, 0]> : tensor<3xi32> %cst_0 = arith.constant dense<[13, -1, 3]> : tensor<3xi32> %cst_1 = arith.constant dense<[1, -1, -1]> : tensor<3xi32> - // CHECK: %[[VAR0:.*]] = "tosa.reverse"(%arg0) {axis = 1 : i64} - // CHECK: %[[VAR1:.*]] = "tosa.reverse"(%[[VAR0]]) {axis = 2 : i64} + // CHECK: %[[VAR0:.*]] = "tosa.reverse"(%arg0) <{axis = 1 : i64}> + // CHECK: %[[VAR1:.*]] = "tosa.reverse"(%[[VAR0]]) <{axis = 2 : i64}> // CHECK: return %[[VAR1]] %0 = "tfl.strided_slice"(%arg0, %cst, %cst_0, %cst_1) {begin_mask = 2 : i32, ellipsis_mask = 0 : i32, end_mask = 7 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 0 : i32} : (tensor<10x?x?xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<*xf32> func.return %0 : tensor<*xf32> @@ -1108,10 +1202,10 @@ func.func @test_strided_slice_dynamic_end(%arg0: tensor<10x?x?xf32>) -> tensor<* %end = arith.constant dense<[7, -1, 6]> : tensor<3xi32> %stride = arith.constant dense<[1, 2, -1]> : tensor<3xi32> - // CHECK: %[[SLICE1:.+]] = "tosa.slice"(%arg0) {size = array, start = array} - // CHECK: %[[RESHAPE1:.+]] = "tosa.reshape"(%[[SLICE1]]) {new_shape = array} - // CHECK: %[[SLICE2:.+]] = "tosa.slice"(%[[RESHAPE1]]) 
{size = array, start = array} - // CHECK: %[[RESHAPE2:.+]] = "tosa.reshape"(%[[SLICE2]]) {new_shape = array} + // CHECK: %[[SLICE1:.+]] = "tosa.slice"(%arg0) <{size = array, start = array}> + // CHECK: %[[RESHAPE1:.+]] = "tosa.reshape"(%[[SLICE1]]) <{new_shape = array}> + // CHECK: %[[SLICE2:.+]] = "tosa.slice"(%[[RESHAPE1]]) <{size = array, start = array}> + // CHECK: %[[RESHAPE2:.+]] = "tosa.reshape"(%[[SLICE2]]) <{new_shape = array}> %0 = "tfl.strided_slice"(%arg0, %begin, %end, %stride) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 2 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 4 : i32} : (tensor<10x?x?xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<*xf32> // CHECK: return %[[RESHAPE2]] func.return %0 : tensor<*xf32> @@ -1120,7 +1214,7 @@ func.func @test_strided_slice_dynamic_end(%arg0: tensor<10x?x?xf32>) -> tensor<* // ----- // CHECK-LABEL: test_select -// CHECK: %[[VAR1:.*]] = "tosa.reshape"(%arg2) {new_shape = array} : (tensor<1xi1>) -> tensor<1x1x1xi1> +// CHECK: %[[VAR1:.*]] = "tosa.reshape"(%arg2) <{new_shape = array}> : (tensor<1xi1>) -> tensor<1x1x1xi1> // CHECK: %[[VAR2:.*]] = "tosa.select"(%[[VAR1]], %arg0, %arg1) func.func @test_select(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x21x3xf32>, %arg2: tensor<1xi1>) -> tensor<13x21x3xf32> { %0 = "tfl.select_v2"(%arg2, %arg0, %arg1) : (tensor<1xi1>, tensor<13x21x3xf32>, tensor<13x21x3xf32>) -> tensor<13x21x3xf32> @@ -1151,7 +1245,7 @@ func.func @test_addn(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x21x3xf32>, %ar // ----- // CHECK-LABEL: test_concatv2 -// CHECK: %[[VAR0:.*]] = "tosa.concat"(%arg0, %arg1, %arg2, %arg3) {axis = 0 : i64} +// CHECK: %[[VAR0:.*]] = "tosa.concat"(%arg0, %arg1, %arg2, %arg3) <{axis = 0 : i64}> func.func @test_concatv2(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x21x3xf32>, %arg2: tensor<13x21x3xf32>, %arg3: tensor<13x21x3xf32>) -> tensor<52x21x3xf32> { %0 = "tfl.concatenation"(%arg0, %arg1, %arg2, %arg3) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<13x21x3xf32>, tensor<13x21x3xf32>, tensor<13x21x3xf32>, tensor<13x21x3xf32>) -> tensor<52x21x3xf32> func.return %0 : tensor<52x21x3xf32> @@ -1160,8 +1254,8 @@ func.func @test_concatv2(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x21x3xf32>, // ----- // CHECK-LABEL: test_stack -// CHECK-DAG: %[[VAR0:.*]] = "tosa.concat"(%arg0, %arg1, %arg2, %arg3) {axis = 0 : i64} -// CHECK: %[[VAR1:.*]] = "tosa.reshape"(%[[VAR0]]) {new_shape = array} +// CHECK-DAG: %[[VAR0:.*]] = "tosa.concat"(%arg0, %arg1, %arg2, %arg3) <{axis = 0 : i64}> +// CHECK: %[[VAR1:.*]] = "tosa.reshape"(%[[VAR0]]) <{new_shape = array}> func.func @test_stack(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x21x3xf32>, %arg2: tensor<13x21x3xf32>, %arg3: tensor<13x21x3xf32>) -> tensor<4x13x21x3xf32> { %0 = "tfl.pack"(%arg0, %arg1, %arg2, %arg3) {axis = 0 : i32, values_count = 4 : i32} : (tensor<13x21x3xf32>, tensor<13x21x3xf32>, tensor<13x21x3xf32>, tensor<13x21x3xf32>) -> tensor<4x13x21x3xf32> func.return %0 : tensor<4x13x21x3xf32> @@ -1170,9 +1264,9 @@ func.func @test_stack(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x21x3xf32>, %a // ----- // CHECK-LABEL: test_stack_end -// CHECK-DAG: %[[PERM:.*]] = "tosa.const"() {value = dense<[1, 2, 3, 0]> : tensor<4xi32>} -// CHECK-DAG: %[[VAR0:.*]] = "tosa.concat"(%arg0, %arg1) {axis = 0 : i64} -// CHECK: %[[VAR1:.*]] = "tosa.reshape"(%[[VAR0]]) {new_shape = array} +// CHECK-DAG: %[[PERM:.*]] = "tosa.const"() <{value = dense<[1, 2, 3, 0]> : tensor<4xi32>}> +// CHECK-DAG: %[[VAR0:.*]] = "tosa.concat"(%arg0, %arg1) 
<{axis = 0 : i64}> +// CHECK: %[[VAR1:.*]] = "tosa.reshape"(%[[VAR0]]) <{new_shape = array}> // CHECK: %[[TRANSPOSE:.*]] = "tosa.transpose"(%[[VAR1]], %[[PERM]]) func.func @test_stack_end(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x21x3xf32>) -> tensor<13x21x3x2xf32> { %0 = "tfl.pack"(%arg0, %arg1) {axis = 3 : i32, values_count = 2 : i32} : (tensor<13x21x3xf32>, tensor<13x21x3xf32>) -> tensor<13x21x3x2xf32> @@ -1182,7 +1276,7 @@ func.func @test_stack_end(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x21x3xf32> // ----- // CHECK-LABEL: test_unstack -// CHECK: %[[VAR1:.*]] = "tosa.reshape"(%arg0) {new_shape = array} +// CHECK: %[[VAR1:.*]] = "tosa.reshape"(%arg0) <{new_shape = array}> func.func @test_unstack(%arg0: tensor<1x32x32x8xf32>) -> tensor<*xf32> { %0 = "tfl.unpack"(%arg0) {axis = 0 : i32, num = 1 : i32} : (tensor<1x32x32x8xf32>) -> tensor<*xf32> func.return %0 : tensor<*xf32> @@ -1191,8 +1285,8 @@ func.func @test_unstack(%arg0: tensor<1x32x32x8xf32>) -> tensor<*xf32> { // ----- // CHECK-LABEL: test_pad -// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() {value = dense<{{\[\[}}1, 1], {{\[}}2, 2]]> : tensor<2x2xi32>} -// CHECK-DAG: %[[PVAL:.*]] = "tosa.const"() {value = dense<0.000000e+00> : tensor} +// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() <{value = dense<{{\[\[}}1, 1], {{\[}}2, 2]]> : tensor<2x2xi32>}> +// CHECK-DAG: %[[PVAL:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor}> // CHECK: %[[VAR1:.*]] = "tosa.pad"(%arg0, %[[VAR0]], %[[PVAL]]) func.func @test_pad(%arg0: tensor<2x3xf32>) -> tensor<*xf32> { %cst = arith.constant dense<[[1, 1], [2, 2]]> : tensor<2x2xi32> @@ -1206,10 +1300,10 @@ func.func @test_pad(%arg0: tensor<2x3xf32>) -> tensor<*xf32> { // CHECK-LABEL: test_pad_v2 // CHECK-SAME: -> tensor<1x257x9x28xf32> func.func @test_pad_v2(%arg0: tensor<1x256x8x25xf32>) -> (tensor<*xf32>) { - // CHECK-DAG: %[[PADDING:.+]] = "tosa.const"() {value = dense<{{\[\[}}0, 0], [1, 0], [0, 1], [1, 2]]> : tensor<4x2xi32>} + // CHECK-DAG: %[[PADDING:.+]] = "tosa.const"() <{value = dense<{{\[\[}}0, 0], [1, 0], [0, 1], [1, 2]]> : tensor<4x2xi32>}> %0 = "tfl.pseudo_const"() {value = dense<[[0, 0], [1, 0], [0, 1], [1, 2]]> : tensor<4x2xi32>} : () -> tensor<4x2xi32> - // CHECK-DAG: %[[VAL:.+]] = "tosa.const"() {value = dense<-3.40282347E+38> : tensor} + // CHECK-DAG: %[[VAL:.+]] = "tosa.const"() <{value = dense<-3.40282347E+38> : tensor}> %1 = "tfl.pseudo_const"() {value = dense<-3.40282347E+38> : tensor} : () -> tensor // CHECK-DAG: %[[PAD:.+]] = "tosa.pad"(%arg0, %[[PADDING]], %[[VAL]]) : (tensor<1x256x8x25xf32>, tensor<4x2xi32>, tensor) -> tensor<1x257x9x28xf32> @@ -1222,7 +1316,7 @@ func.func @test_pad_v2(%arg0: tensor<1x256x8x25xf32>) -> (tensor<*xf32>) { // ----- // CHECK-LABEL: test_expand_dims -// CHECK: %[[VAR0:.*]] = "tosa.reshape"(%arg0) {new_shape = array} +// CHECK: %[[VAR0:.*]] = "tosa.reshape"(%arg0) <{new_shape = array}> func.func @test_expand_dims(%arg0: tensor<13x21x3xf32>) -> tensor<*xf32> { %cst = arith.constant dense<[1, 13, 21, 3]> : tensor<4xi32> %0 = "tfl.reshape"(%arg0, %cst) : (tensor<13x21x3xf32>, tensor<4xi32>) -> tensor<*xf32> @@ -1232,7 +1326,7 @@ func.func @test_expand_dims(%arg0: tensor<13x21x3xf32>) -> tensor<*xf32> { // ----- // CHECK-LABEL: test_shape -// CHECK: %[[VAR0:.*]] = "tosa.const"() {value = dense<[13, 21, 3]> : tensor<3xi32>} +// CHECK: %[[VAR0:.*]] = "tosa.const"() <{value = dense<[13, 21, 3]> : tensor<3xi32>}> func.func @test_shape() -> tensor<3xi32> { %cst = arith.constant dense<[13, 21, 3]> : tensor<3xi32> func.return %cst : tensor<3xi32> @@ 
-1241,7 +1335,7 @@ func.func @test_shape() -> tensor<3xi32> { // ----- // CHECK-LABEL: test_rank -// CHECK: %[[VAR0:.*]] = "tosa.const"() {value = dense<3> : tensor} +// CHECK: %[[VAR0:.*]] = "tosa.const"() <{value = dense<3> : tensor}> func.func @test_rank() -> tensor { %cst = arith.constant dense<3> : tensor func.return %cst : tensor @@ -1250,8 +1344,8 @@ func.func @test_rank() -> tensor { // ----- // CHECK-LABEL: test_elu -// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() {value = dense<1.000000e+00> : tensor<1x1x1xf32>} -// CHECK-DAG: %[[VAR1:.*]] = "tosa.const"() {value = dense<0.000000e+00> : tensor<1x1x1xf32>} +// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() <{value = dense<1.000000e+00> : tensor<1x1x1xf32>}> +// CHECK-DAG: %[[VAR1:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<1x1x1xf32>}> // CHECK-DAG: %[[VAR2:.*]] = "tosa.exp"(%arg0) // CHECK-DAG: %[[VAR4:.*]] = "tosa.sub"(%[[VAR2]], %[[VAR0]]) // CHECK-DAG: %[[VAR6:.*]] = "tosa.greater_equal"(%arg0, %[[VAR1]]) @@ -1264,10 +1358,12 @@ func.func @test_elu(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { // ----- // CHECK-LABEL: test_softmax -// CHECK-DAG: %[[VAR0:.*]] = "tosa.exp"(%arg0) -// CHECK-DAG: %[[VAR1:.*]] = "tosa.reduce_sum"(%[[VAR0]]) {axis = 2 : i64} -// CHECK-DAG: %[[VAR2:.*]] = "tosa.reciprocal"(%[[VAR1]]) -// CHECK: %[[VAR3:.*]] = "tosa.mul"(%[[VAR0]], %[[VAR2]]) {shift = 0 : i32} +// CHECK-DAG: %[[VAR0:.*]] = "tosa.reduce_max"(%arg0) +// CHECK-DAG: %[[VAR1:.*]] = "tosa.sub"(%arg0, %[[VAR0]]) +// CHECK-DAG: %[[VAR2:.*]] = "tosa.exp"(%[[VAR1]]) +// CHECK-DAG: %[[VAR3:.*]] = "tosa.reduce_sum"(%[[VAR2]]) <{axis = 2 : i64}> +// CHECK-DAG: %[[VAR4:.*]] = "tosa.reciprocal"(%[[VAR3]]) +// CHECK: %[[VAR5:.*]] = "tosa.mul"(%[[VAR2]], %[[VAR4]]) <{shift = 0 : i32}> func.func @test_softmax(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { %0 = "tfl.softmax"(%arg0) {beta = 1.000000e+00 : f32} : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32> func.return %0 : tensor<13x21x3xf32> @@ -1277,13 +1373,13 @@ func.func @test_softmax(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { // CHECK-LABEL: test_l2normalization func.func @test_l2normalization(%arg0: tensor<16x16xf32>) -> (tensor<16x16xf32>) { - // CHECK-DAG: %[[MIN:.+]] = "tosa.const"() {value = dense<1.08420217E-19> : tensor<1x1xf32>} - // CHECK-DAG: %[[SQR:.+]] = "tosa.mul"(%arg0, %arg0) {shift = 0 : i32} - // CHECK-DAG: %[[SUM:.+]] = "tosa.reduce_sum"(%[[SQR]]) {axis = 1 : i64} + // CHECK-DAG: %[[MIN:.+]] = "tosa.const"() <{value = dense<1.08420217E-19> : tensor<1x1xf32>}> + // CHECK-DAG: %[[SQR:.+]] = "tosa.mul"(%arg0, %arg0) <{shift = 0 : i32}> + // CHECK-DAG: %[[SUM:.+]] = "tosa.reduce_sum"(%[[SQR]]) <{axis = 1 : i64}> // CHECK-DAG: %[[MAX:.+]] = "tosa.maximum"(%[[SUM]], %[[MIN]]) // CHECK-DAG: %[[RSQRT:.+]] = "tosa.rsqrt"(%[[MAX]]) // CHECK-DAG: %[[MUL:.+]] = "tosa.mul"(%[[RSQRT]], %arg0) - // CHECK: %[[CLAMP:.+]] = "tosa.clamp"(%[[MUL]]) {max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} + // CHECK: %[[CLAMP:.+]] = "tosa.clamp"(%[[MUL]]) <{max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64}> %0 = "tfl.l2_normalization"(%arg0) {fused_activation_function = "RELU"} : (tensor<16x16xf32>) -> tensor<16x16xf32> func.return %0 : tensor<16x16xf32> } @@ -1292,9 +1388,9 @@ func.func @test_l2normalization(%arg0: tensor<16x16xf32>) -> (tensor<16x16xf32>) // CHECK-LABEL: test_log_softmax // CHECK-DAG: %[[VAR0:.*]] = "tosa.exp"(%arg0) -// CHECK-DAG: %[[VAR1:.*]] = 
"tosa.reduce_sum"(%[[VAR0]]) {axis = 2 : i64} +// CHECK-DAG: %[[VAR1:.*]] = "tosa.reduce_sum"(%[[VAR0]]) <{axis = 2 : i64}> // CHECK-DAG: %[[VAR2:.*]] = "tosa.reciprocal"(%[[VAR1]]) -// CHECK-DAG: %[[VAR3:.*]] = "tosa.mul"(%[[VAR0]], %[[VAR2]]) {shift = 0 : i32} +// CHECK-DAG: %[[VAR3:.*]] = "tosa.mul"(%[[VAR0]], %[[VAR2]]) <{shift = 0 : i32}> // CHECK: %[[VAR4:.*]] = "tosa.log"(%[[VAR3]]) func.func @test_log_softmax(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { %0 = "tfl.log_softmax"(%arg0) : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32> @@ -1304,8 +1400,8 @@ func.func @test_log_softmax(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { // ----- // CHECK-LABEL: test_matmul -// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} -// CHECK-DAG: %[[VAR1:.*]] = "tosa.const"() {value = dense<0.000000e+00> : tensor<28xf32>} +// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() <{value = dense<[1, 0]> : tensor<2xi32>}> +// CHECK-DAG: %[[VAR1:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<28xf32>}> // CHECK: %[[VAR2:.*]] = "tosa.transpose"(%arg1, %[[VAR0]]) // CHECK: %[[VAR3:.*]] = "tosa.fully_connected"(%arg0, %[[VAR2]], %[[VAR1]]) func.func @test_matmul(%arg0: tensor<14x19xf32>, %arg1: tensor<19x28xf32>) -> tensor<*xf32> { @@ -1377,10 +1473,10 @@ func.func @test_batch_matmul(%arg0: tensor<1x16x128xf32>, %arg1: tensor<1x128x32 // CHECK-LABEL: @test_batch_matmul_4d func.func @test_batch_matmul_4d(%arg0: tensor<4x5x16x128xf32>, %arg1: tensor<4x5x128x32xf32>) -> (tensor<4x5x16x32xf32> ) { - // CHECK: %[[R0:.*]] = "tosa.reshape"(%arg0) {new_shape = array} - // CHECK: %[[R1:.*]] = "tosa.reshape"(%arg1) {new_shape = array} + // CHECK: %[[R0:.*]] = "tosa.reshape"(%arg0) <{new_shape = array}> + // CHECK: %[[R1:.*]] = "tosa.reshape"(%arg1) <{new_shape = array}> // CHECK: %[[MM:.*]] = "tosa.matmul"(%[[R0]], %[[R1]]) - // CHECK: "tosa.reshape"(%[[MM]]) {new_shape = array} + // CHECK: "tosa.reshape"(%[[MM]]) <{new_shape = array}> %0 = "tfl.batch_matmul"(%arg0, %arg1) {adj_x = false, adj_y = false} : (tensor<4x5x16x128xf32>, tensor<4x5x128x32xf32>) -> tensor<4x5x16x32xf32> func.return %0 : tensor<4x5x16x32xf32> } @@ -1389,7 +1485,7 @@ func.func @test_batch_matmul_4d(%arg0: tensor<4x5x16x128xf32>, %arg1: tensor<4x5 // CHECK-LABEL: @test_batch_matmul_transpose func.func @test_batch_matmul_transpose(%arg0: tensor<1x16x128xf32>, %arg1: tensor<1x128x32xf32>) -> (tensor<1x32x16xf32> ) { - // CHECK-DAG: %[[PERM:.+]] = "tosa.const"() {value = dense<[0, 2, 1]> : tensor<3xi32>} + // CHECK-DAG: %[[PERM:.+]] = "tosa.const"() <{value = dense<[0, 2, 1]> : tensor<3xi32>}> // CHECK-DAG: %[[TP0:.+]] = "tosa.transpose"(%arg0, %[[PERM]]) // CHECK-DAG: %[[TP1:.+]] = "tosa.transpose"(%arg1, %[[PERM]]) // CHECK: "tosa.matmul"(%[[TP1]], %[[TP0]]) @@ -1400,7 +1496,7 @@ func.func @test_batch_matmul_transpose(%arg0: tensor<1x16x128xf32>, %arg1: tenso // ----- // CHECK-LABEL: test_add_scalar -// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() {value = dense<1.000000e+00> : tensor<1x1x1xf32>} +// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() <{value = dense<1.000000e+00> : tensor<1x1x1xf32>}> // CHECK: %[[VAR2:.*]] = "tosa.add"(%arg0, %[[VAR0]]) func.func @test_add_scalar(%arg0: tensor<13x21x3xf32>) -> tensor<*xf32> { %cst = arith.constant dense<1.000000e+00> : tensor @@ -1411,8 +1507,8 @@ func.func @test_add_scalar(%arg0: tensor<13x21x3xf32>) -> tensor<*xf32> { // ----- // CHECK-LABEL: test_add_1d -// CHECK-DAG: %[[VAR0:.*]] = "tosa.reduce_sum"(%arg1) {axis = 0 : i64} -// CHECK-DAG: %[[VAR1:.*]] = 
"tosa.reduce_sum"(%[[VAR0]]) {axis = 1 : i64} +// CHECK-DAG: %[[VAR0:.*]] = "tosa.reduce_sum"(%arg1) <{axis = 0 : i64}> +// CHECK-DAG: %[[VAR1:.*]] = "tosa.reduce_sum"(%[[VAR0]]) <{axis = 1 : i64}> // CHECK: %[[VAR2:.*]] = "tosa.add"(%arg0, %[[VAR1]]) func.func @test_add_1d(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x21x3xf32>) -> tensor<*xf32> { %cst = arith.constant dense<[0, 1]> : tensor<2xi32> @@ -1429,7 +1525,7 @@ func.func @test_fused_activation_relun_clamp( %arg1: tensor<10x!quant.uniform>) -> tensor<10x!quant.uniform> { %cst = arith.constant dense<1.000000e+00> : tensor - // CHECK: "tosa.clamp"(%{{.+}}) {max_fp = 0.000000e+00 : f32, max_int = -67 : i64, min_fp = 0.000000e+00 : f32, min_int = -127 : i64} + // CHECK: "tosa.clamp"(%{{.+}}) <{max_fp = 0.000000e+00 : f32, max_int = -67 : i64, min_fp = 0.000000e+00 : f32, min_int = -127 : i64}> %0 = "tfl.add"(%arg0, %arg0) {fused_activation_function = "RELU6"} : (tensor<10x!quant.uniform>, tensor<10x!quant.uniform>) -> tensor<10x!quant.uniform> func.return %0 : tensor<10x!quant.uniform> } @@ -1442,7 +1538,7 @@ func.func @test_fused_activation_relun_noclamp( %arg1: tensor<10x!quant.uniform>) -> tensor<10x!quant.uniform> { %cst = arith.constant dense<1.000000e+00> : tensor - // CHECK: "tosa.clamp"(%{{.+}}) {max_fp = 0.000000e+00 : f32, max_int = 127 : i64, min_fp = 0.000000e+00 : f32, min_int = -128 : i64} + // CHECK: "tosa.clamp"(%{{.+}}) <{max_fp = 0.000000e+00 : f32, max_int = 127 : i64, min_fp = 0.000000e+00 : f32, min_int = -128 : i64}> %0 = "tfl.add"(%arg0, %arg0) {fused_activation_function = "RELU6"} : (tensor<10x!quant.uniform>, tensor<10x!quant.uniform>) -> tensor<10x!quant.uniform> func.return %0 : tensor<10x!quant.uniform> } @@ -1454,7 +1550,7 @@ func.func @test_fused_activation_relun1to1_noclamp( %arg0: tensor<10x!quant.uniform>, %arg1: tensor<10x!quant.uniform>) -> tensor<10x!quant.uniform> { %cst = arith.constant dense<1.000000e+00> : tensor - // CHECK: "tosa.clamp"(%{{.}}) {max_fp = 0.000000e+00 : f32, max_int = 127 : i64, min_fp = 0.000000e+00 : f32, min_int = -128 : i64} + // CHECK: "tosa.clamp"(%{{.}}) <{max_fp = 0.000000e+00 : f32, max_int = 127 : i64, min_fp = 0.000000e+00 : f32, min_int = -128 : i64}> %0 = "tfl.add"(%arg0, %arg0) {fused_activation_function = "RELU_N1_TO_1"} : (tensor<10x!quant.uniform>, tensor<10x!quant.uniform>) -> tensor<10x!quant.uniform> func.return %0 : tensor<10x!quant.uniform> } @@ -1466,7 +1562,7 @@ func.func @test_fused_activation_relun1to1_clamp( %arg0: tensor<10x!quant.uniform>, %arg1: tensor<10x!quant.uniform>) -> tensor<10x!quant.uniform> { %cst = arith.constant dense<1.000000e+00> : tensor - // CHECK: "tosa.clamp"(%{{.}}) {max_fp = 0.000000e+00 : f32, max_int = 90 : i64, min_fp = 0.000000e+00 : f32, min_int = -110 : i64} + // CHECK: "tosa.clamp"(%{{.}}) <{max_fp = 0.000000e+00 : f32, max_int = 90 : i64, min_fp = 0.000000e+00 : f32, min_int = -110 : i64}> %0 = "tfl.add"(%arg0, %arg0) {fused_activation_function = "RELU_N1_TO_1"} : (tensor<10x!quant.uniform>, tensor<10x!quant.uniform>) -> tensor<10x!quant.uniform> func.return %0 : tensor<10x!quant.uniform> } @@ -1474,9 +1570,9 @@ func.func @test_fused_activation_relun1to1_clamp( // ----- // CHECK-LABEL: test_split -// CHECK-DAG: %[[VAR0:.*]] = "tosa.slice"(%arg0) {size = array, start = array} -// CHECK-DAG: %[[VAR1:.*]] = "tosa.slice"(%arg0) {size = array, start = array} -// CHECK: %[[VAR2:.*]] = "tosa.slice"(%arg0) {size = array, start = array} +// CHECK-DAG: %[[VAR0:.*]] = "tosa.slice"(%arg0) <{size = array, start = array}> +// CHECK-DAG: 
%[[VAR1:.*]] = "tosa.slice"(%arg0) <{size = array, start = array}> +// CHECK: %[[VAR2:.*]] = "tosa.slice"(%arg0) <{size = array, start = array}> func.func @test_split(%arg0: tensor<13x21x3xf32>) -> (tensor<13x7x3xf32>, tensor<13x7x3xf32>, tensor<13x7x3xf32>) { %cst_0 = arith.constant dense<1> : tensor %0:3 = "tfl.split"(%cst_0, %arg0) {num_splits = 3 : i32} : (tensor, tensor<13x21x3xf32>) -> (tensor<13x7x3xf32>, tensor<13x7x3xf32>, tensor<13x7x3xf32>) @@ -1488,13 +1584,13 @@ func.func @test_split(%arg0: tensor<13x21x3xf32>) -> (tensor<13x7x3xf32>, tensor // CHECK-LABEL: test_split_dynamic func.func @test_split_dynamic(%arg0: tensor<13x?x3xf32>) -> (tensor<13x?x3xf32>, tensor<13x?x3xf32>, tensor<13x?x3xf32>) { %cst_0 = arith.constant dense<1> : tensor - // CHECK-DAG: %[[VAR0:.+]] = "tosa.reshape"(%arg0) {new_shape = array} - // CHECK-DAG: %[[VAR1:.+]] = "tosa.slice"(%[[VAR0]]) {size = array, start = array} - // CHECK-DAG: %[[VAR2:.+]] = "tosa.slice"(%[[VAR0]]) {size = array, start = array} - // CHECK-DAG: %[[VAR3:.+]] = "tosa.slice"(%[[VAR0]]) {size = array, start = array} - // CHECK-DAG: %[[VAR4:.+]] = "tosa.reshape"(%[[VAR1]]) {new_shape = array} - // CHECK-DAG: %[[VAR5:.+]] = "tosa.reshape"(%[[VAR2]]) {new_shape = array} - // CHECK-DAG: %[[VAR6:.+]] = "tosa.reshape"(%[[VAR3]]) {new_shape = array} + // CHECK-DAG: %[[VAR0:.+]] = "tosa.reshape"(%arg0) <{new_shape = array}> + // CHECK-DAG: %[[VAR1:.+]] = "tosa.slice"(%[[VAR0]]) <{size = array, start = array}> + // CHECK-DAG: %[[VAR2:.+]] = "tosa.slice"(%[[VAR0]]) <{size = array, start = array}> + // CHECK-DAG: %[[VAR3:.+]] = "tosa.slice"(%[[VAR0]]) <{size = array, start = array}> + // CHECK-DAG: %[[VAR4:.+]] = "tosa.reshape"(%[[VAR1]]) <{new_shape = array}> + // CHECK-DAG: %[[VAR5:.+]] = "tosa.reshape"(%[[VAR2]]) <{new_shape = array}> + // CHECK-DAG: %[[VAR6:.+]] = "tosa.reshape"(%[[VAR3]]) <{new_shape = array}> // CHECK: return %[[VAR4]], %[[VAR5]], %[[VAR6]] %0:3 = "tfl.split"(%cst_0, %arg0) {num_splits = 3 : i32} : (tensor, tensor<13x?x3xf32>) -> (tensor<13x?x3xf32>, tensor<13x?x3xf32>, tensor<13x?x3xf32>) func.return %0#0, %0#1, %0#2 : tensor<13x?x3xf32>, tensor<13x?x3xf32>, tensor<13x?x3xf32> @@ -1503,9 +1599,9 @@ func.func @test_split_dynamic(%arg0: tensor<13x?x3xf32>) -> (tensor<13x?x3xf32>, // ----- // CHECK-LABEL: test_split_neg -// CHECK-DAG: %[[VAR0:.*]] = "tosa.slice"(%arg0) {size = array, start = array} -// CHECK-DAG: %[[VAR1:.*]] = "tosa.slice"(%arg0) {size = array, start = array} -// CHECK: %[[VAR2:.*]] = "tosa.slice"(%arg0) {size = array, start = array} +// CHECK-DAG: %[[VAR0:.*]] = "tosa.slice"(%arg0) <{size = array, start = array}> +// CHECK-DAG: %[[VAR1:.*]] = "tosa.slice"(%arg0) <{size = array, start = array}> +// CHECK: %[[VAR2:.*]] = "tosa.slice"(%arg0) <{size = array, start = array}> func.func @test_split_neg(%arg0: tensor<13x21x3xf32>) -> (tensor<13x7x3xf32>, tensor<13x7x3xf32>, tensor<13x7x3xf32>) { %cst_0 = arith.constant dense<-2> : tensor %0:3 = "tfl.split"(%cst_0, %arg0) {num_splits = 3 : i32} : (tensor, tensor<13x21x3xf32>) -> (tensor<13x7x3xf32>, tensor<13x7x3xf32>, tensor<13x7x3xf32>) @@ -1515,9 +1611,9 @@ func.func @test_split_neg(%arg0: tensor<13x21x3xf32>) -> (tensor<13x7x3xf32>, te // ----- // CHECK-LABEL: test_split_axis_0 -// CHECK-DAG: %[[VAR0:.*]] = "tosa.slice"(%arg0) {size = array, start = array} -// CHECK-DAG: %[[VAR1:.*]] = "tosa.slice"(%arg0) {size = array, start = array} -// CHECK: %[[VAR2:.*]] = "tosa.slice"(%arg0) {size = array, start = array} +// CHECK-DAG: %[[VAR0:.*]] = "tosa.slice"(%arg0) 
<{size = array, start = array}> +// CHECK-DAG: %[[VAR1:.*]] = "tosa.slice"(%arg0) <{size = array, start = array}> +// CHECK: %[[VAR2:.*]] = "tosa.slice"(%arg0) <{size = array, start = array}> func.func @test_split_axis_0(%arg0: tensor<21x13x3xf32>) -> (tensor<7x13x3xf32>, tensor<7x13x3xf32>, tensor<7x13x3xf32>) { %cst_0 = arith.constant dense<0> : tensor %0:3 = "tfl.split"(%cst_0, %arg0) {num_splits = 3 : i32} : (tensor, tensor<21x13x3xf32>) -> (tensor<7x13x3xf32>, tensor<7x13x3xf32>, tensor<7x13x3xf32>) @@ -1527,8 +1623,8 @@ func.func @test_split_axis_0(%arg0: tensor<21x13x3xf32>) -> (tensor<7x13x3xf32>, // ----- // CHECK-LABEL: test_split_v_neg_axis -// CHECK-DAG: %[[VAR0:.*]] = "tosa.slice"(%arg0) {size = array, start = array} -// CHECK: %[[VAR1:.*]] = "tosa.slice"(%arg0) {size = array, start = array} +// CHECK-DAG: %[[VAR0:.*]] = "tosa.slice"(%arg0) <{size = array, start = array}> +// CHECK: %[[VAR1:.*]] = "tosa.slice"(%arg0) <{size = array, start = array}> func.func @test_split_v_neg_axis(%arg0: tensor<2x3x3x8xf32>) -> (tensor<2x3x3x3xf32>, tensor<2x3x3x5xf32>) { %split_size = arith.constant dense<[3, 5]> : tensor<2xi32> %axis = arith.constant dense<-1> : tensor @@ -1549,13 +1645,13 @@ func.func @test_tile(%arg0: tensor<13x21x3xf32>) -> tensor<*xf32> { // ----- // CHECK-LABEL: test_space_to_batch -// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() {value = dense<{{\[}}[0, 0], [0, 1], [0, 0]]> : tensor<3x2xi32>} -// CHECK-DAG: %[[VAR1:.*]] = "tosa.const"() {value = dense<[2, 0, 1, 3]> : tensor<4xi32>} -// CHECK-DAG: %[[PVAL:.*]] = "tosa.const"() {value = dense<0.000000e+00> : tensor} +// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() <{value = dense<{{\[}}[0, 0], [0, 1], [0, 0]]> : tensor<3x2xi32>}> +// CHECK-DAG: %[[VAR1:.*]] = "tosa.const"() <{value = dense<[2, 0, 1, 3]> : tensor<4xi32>}> +// CHECK-DAG: %[[PVAL:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor}> // CHECK-DAG: %[[VAR2:.*]] = "tosa.pad"(%arg0, %[[VAR0]], %[[PVAL]]) -// CHECK-DAG: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]]) {new_shape = array} +// CHECK-DAG: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]]) <{new_shape = array}> // CHECK-DAG: %[[VAR4:.*]] = "tosa.transpose"(%[[VAR3]], %[[VAR1]]) -// CHECK: %[[VAR5:.*]] = "tosa.reshape"(%[[VAR4]]) {new_shape = array} +// CHECK: %[[VAR5:.*]] = "tosa.reshape"(%[[VAR4]]) <{new_shape = array}> func.func @test_space_to_batch(%arg0: tensor<13x21x3xf32>) -> tensor<26x11x3xf32> { %cst = arith.constant dense<2> : tensor<1xi32> %cst_0 = arith.constant dense<[[0, 1]]> : tensor<1x2xi32> @@ -1566,13 +1662,13 @@ func.func @test_space_to_batch(%arg0: tensor<13x21x3xf32>) -> tensor<26x11x3xf32 // ----- // CHECK-LABEL: test_space_to_batch_dyn -// CHECK-DAG: %[[C0:.+]] = "tosa.const"() {value = dense<0.000000e+00> : tensor} -// CHECK-DAG: %[[C1:.+]] = "tosa.const"() {value = dense<{{\[\[}}0, 0], [0, 2], [0, 0], [0, 0]]> : tensor<4x2xi32>} -// CHECK-DAG: %[[C2:.+]] = "tosa.const"() {value = dense<[2, 4, 0, 1, 3, 5]> : tensor<6xi32>} +// CHECK-DAG: %[[C0:.+]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor}> +// CHECK-DAG: %[[C1:.+]] = "tosa.const"() <{value = dense<{{\[\[}}0, 0], [0, 2], [0, 0], [0, 0]]> : tensor<4x2xi32>}> +// CHECK-DAG: %[[C2:.+]] = "tosa.const"() <{value = dense<[2, 4, 0, 1, 3, 5]> : tensor<6xi32>}> // CHECK-DAG: %[[PAD:.+]] = "tosa.pad"(%arg0, %[[C1]], %[[C0]]) : (tensor, tensor<4x2xi32>, tensor) -> tensor -// CHECK-DAG: %[[R0:.+]] = "tosa.reshape"(%[[PAD]]) {new_shape = array} +// CHECK-DAG: %[[R0:.+]] = "tosa.reshape"(%[[PAD]]) <{new_shape = array}> // CHECK-DAG: %[[T:.+]] 
= "tosa.transpose"(%[[R0]], %[[C2]]) -// CHECK-DAG: %[[R1:.+]] = "tosa.reshape"(%[[T]]) {new_shape = array} +// CHECK-DAG: %[[R1:.+]] = "tosa.reshape"(%[[T]]) <{new_shape = array}> // CHECK: return %[[R1]] : tensor func.func @test_space_to_batch_dyn(%arg0 : tensor) -> (tensor) { %0 = "tfl.pseudo_const"() {value = dense<[3, 1]> : tensor<2xi32>} : () -> tensor<2xi32> @@ -1584,12 +1680,12 @@ func.func @test_space_to_batch_dyn(%arg0 : tensor) -> (tensor : tensor<4xi32>} -// CHECK-DAG: %[[VAR1:.*]] = "tosa.const"() {value = dense<[2, 3, 0, 4, 1, 5]> : tensor<6xi32>} +// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() <{value = dense<[3, 1, 2, 0]> : tensor<4xi32>}> +// CHECK-DAG: %[[VAR1:.*]] = "tosa.const"() <{value = dense<[2, 3, 0, 4, 1, 5]> : tensor<6xi32>}> // CHECK-DAG: %[[VAR2:.*]] = "tosa.transpose"(%arg0, %[[VAR0]]) -// CHECK-DAG: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]]) {new_shape = array} +// CHECK-DAG: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]]) <{new_shape = array}> // CHECK-DAG: %[[VAR4:.*]] = "tosa.transpose"(%[[VAR3]], %[[VAR1]]) -// CHECK-DAG: %[[VAR5:.*]] = "tosa.reshape"(%[[VAR4]]) {new_shape = array} +// CHECK-DAG: %[[VAR5:.*]] = "tosa.reshape"(%[[VAR4]]) <{new_shape = array}> // CHECK: return %[[VAR5:.*]] func.func @test_batch_to_space(%arg0: tensor<1x32x32x8xf32>) -> tensor<2x64x64x1xf32> { %cst = arith.constant dense<2> : tensor<2xi32> @@ -1603,11 +1699,11 @@ func.func @test_batch_to_space(%arg0: tensor<1x32x32x8xf32>) -> tensor<2x64x64x1 // ----- // CHECK-LABEL: @test_batch_to_space_dyn -// CHECK-DAG: %[[C0:.+]] = "tosa.const"() {value = dense<[2, 3, 0, 4, 1, 5]> : tensor<6xi32>} -// CHECK-DAG: %[[R0:.+]] = "tosa.reshape"(%arg0) {new_shape = array} +// CHECK-DAG: %[[C0:.+]] = "tosa.const"() <{value = dense<[2, 3, 0, 4, 1, 5]> : tensor<6xi32>}> +// CHECK-DAG: %[[R0:.+]] = "tosa.reshape"(%arg0) <{new_shape = array}> // CHECK-DAG: %[[T:.+]] = "tosa.transpose"(%[[R0]], %[[C0]]) -// CHECK-DAG: %[[R1:.+]] = "tosa.reshape"(%[[T]]) {new_shape = array} -// CHECK-DAG: %[[SLICE:.+]] = "tosa.slice"(%[[R1]]) {size = array, start = array} +// CHECK-DAG: %[[R1:.+]] = "tosa.reshape"(%[[T]]) <{new_shape = array}> +// CHECK-DAG: %[[SLICE:.+]] = "tosa.slice"(%[[R1]]) <{size = array, start = array}> // CHECK: return %[[SLICE]] func.func @test_batch_to_space_dyn(%arg0 : tensor) -> (tensor) { %0 = "tfl.pseudo_const"() {value = dense<[3, 1]> : tensor<2xi32>} : () -> tensor<2xi32> @@ -1619,10 +1715,10 @@ func.func @test_batch_to_space_dyn(%arg0 : tensor) -> (tensor : tensor<6xi32>} -// CHECK-DAG: %[[VAR1:.*]] = "tosa.reshape"(%arg0) {new_shape = array} +// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() <{value = dense<[0, 1, 3, 2, 4, 5]> : tensor<6xi32>}> +// CHECK-DAG: %[[VAR1:.*]] = "tosa.reshape"(%arg0) <{new_shape = array}> // CHECK-DAG: %[[VAR2:.*]] = "tosa.transpose"(%[[VAR1]], %[[VAR0]]) -// CHECK: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]]) {new_shape = array} +// CHECK: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]]) <{new_shape = array}> func.func @test_space_to_depth(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x16x16x32xf32> { %0 = "tfl.space_to_depth"(%arg0) {block_size = 2 : i32} : (tensor<1x32x32x8xf32>) -> tensor<1x16x16x32xf32> func.return %0 : tensor<1x16x16x32xf32> @@ -1631,10 +1727,10 @@ func.func @test_space_to_depth(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x16x16x3 // ----- // CHECK-LABEL: test_depth_to_space -// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() {value = dense<[0, 1, 3, 2, 4, 5]> : tensor<6xi32>} -// CHECK-DAG: %[[VAR1:.*]] = "tosa.reshape"(%arg0) {new_shape = array} +// CHECK-DAG: %[[VAR0:.*]] 
= "tosa.const"() <{value = dense<[0, 1, 3, 2, 4, 5]> : tensor<6xi32>}> +// CHECK-DAG: %[[VAR1:.*]] = "tosa.reshape"(%arg0) <{new_shape = array}> // CHECK-DAG: %[[VAR2:.*]] = "tosa.transpose"(%[[VAR1]], %[[VAR0]]) -// CHECK: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]]) {new_shape = array} +// CHECK: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]]) <{new_shape = array}> func.func @test_depth_to_space(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x64x64x2xf32> { %0 = "tfl.depth_to_space"(%arg0) {block_size = 2 : i32} : (tensor<1x32x32x8xf32>) -> tensor<1x64x64x2xf32> func.return %0 : tensor<1x64x64x2xf32> @@ -1643,12 +1739,12 @@ func.func @test_depth_to_space(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x64x64x2 // ----- // CHECK-LABEL: @test_bucketize -// CHECK-DAG: %[[VAL_0:.*]] = "tosa.const"() {value = dense<{{\[\[\[}}0.000000e+00, 3.000000e+00, 8.000000e+00, 1.100000e+01]]]> : tensor<1x1x4xf32>} -// CHECK: %[[VAL_1:.*]] = "tosa.reshape"(%arg0) {new_shape = array} +// CHECK-DAG: %[[VAL_0:.*]] = "tosa.const"() <{value = dense<{{\[\[\[}}0.000000e+00, 3.000000e+00, 8.000000e+00, 1.100000e+01]]]> : tensor<1x1x4xf32>}> +// CHECK: %[[VAL_1:.*]] = "tosa.reshape"(%arg0) <{new_shape = array}> // CHECK: %[[VAL_2:.*]] = "tosa.greater_equal"(%[[VAL_1]], %[[VAL_0]]) // CHECK: %[[VAL_3:.*]] = "tosa.cast"(%[[VAL_2]]) : (tensor<2x5x4xi1>) -> tensor<2x5x4xi32> -// CHECK: %[[VAL_4:.*]] = "tosa.reduce_sum"(%[[VAL_3]]) {axis = 2 : i64} -// CHECK: %[[VAL_5:.*]] = "tosa.reshape"(%[[VAL_4]]) {new_shape = array} +// CHECK: %[[VAL_4:.*]] = "tosa.reduce_sum"(%[[VAL_3]]) <{axis = 2 : i64}> +// CHECK: %[[VAL_5:.*]] = "tosa.reshape"(%[[VAL_4]]) <{new_shape = array}> func.func @test_bucketize(%arg0: tensor<2x5xf32>) -> tensor<2x5xi32> { %0 = "tfl.bucketize"(%arg0) {boundaries = [0.000000e+00 : f32, 3.000000e+00 : f32, 8.000000e+00 : f32, 1.100000e+01 : f32]} : (tensor<2x5xf32>) -> tensor<2x5xi32> func.return %0 : tensor<2x5xi32> @@ -1657,14 +1753,14 @@ func.func @test_bucketize(%arg0: tensor<2x5xf32>) -> tensor<2x5xi32> { // ----- // CHECK-LABEL: @test_bucketize_cast_boundaries -// CHECK-DAG: %[[VAL_0:.*]] = "tosa.const"() {value = dense<{{\[}}0.000000e+00, 3.000000e+00, 8.000000e+00, 1.100000e+01]> : tensor<4xf32>} +// CHECK-DAG: %[[VAL_0:.*]] = "tosa.const"() <{value = dense<{{\[}}0.000000e+00, 3.000000e+00, 8.000000e+00, 1.100000e+01]> : tensor<4xf32>}> // CHECK: %[[VAL_1:.*]] = "tosa.cast"(%[[VAL_0]]) : (tensor<4xf32>) -> tensor<4xi32> -// CHECK: %[[VAL_2:.*]] = "tosa.reshape"(%arg0) {new_shape = array} -// CHECK: %[[VAL_3:.*]] = "tosa.reshape"(%[[VAL_1]]) {new_shape = array} +// CHECK: %[[VAL_2:.*]] = "tosa.reshape"(%arg0) <{new_shape = array}> +// CHECK: %[[VAL_3:.*]] = "tosa.reshape"(%[[VAL_1]]) <{new_shape = array}> // CHECK: %[[VAL_4:.*]] = "tosa.greater_equal"(%[[VAL_2]], %[[VAL_3]]) // CHECK: %[[VAL_5:.*]] = "tosa.cast"(%[[VAL_4]]) : (tensor<2x5x4xi1>) -> tensor<2x5x4xi32> -// CHECK: %[[VAL_6:.*]] = "tosa.reduce_sum"(%[[VAL_5]]) {axis = 2 : i64} -// CHECK: %[[VAL_7:.*]] = "tosa.reshape"(%[[VAL_6]]) {new_shape = array} +// CHECK: %[[VAL_6:.*]] = "tosa.reduce_sum"(%[[VAL_5]]) <{axis = 2 : i64}> +// CHECK: %[[VAL_7:.*]] = "tosa.reshape"(%[[VAL_6]]) <{new_shape = array}> func.func @test_bucketize_cast_boundaries(%arg0: tensor<2x5xi32>) -> tensor<2x5xi32> { %0 = "tfl.bucketize"(%arg0) {boundaries = [0.000000e+00 : f32, 3.000000e+00 : f32, 8.000000e+00 : f32, 1.100000e+01 : f32]} : (tensor<2x5xi32>) -> tensor<2x5xi32> func.return %0 : tensor<2x5xi32> @@ -1674,13 +1770,13 @@ func.func @test_bucketize_cast_boundaries(%arg0: 
tensor<2x5xi32>) -> tensor<2x5x // CHECK-LABEL: @test_one_hot // CHECK-SAME: %[[ARG0:.*]]: tensor<4x4xi32>, %[[ARG1:.*]]: tensor, %[[ARG2:.*]]: tensor -// CHECK-DAG: %[[RESHAPE:.*]] = "tosa.reshape"(%[[ARG1]]) {new_shape = array} -// CHECK-DAG: %[[TILE:.*]] = "tosa.tile"(%[[RESHAPE]]) {multiples = array} -// CHECK-DAG: %[[RESHAPE_0:.*]] = "tosa.reshape"(%[[ARG2]]) {new_shape = array} -// CHECK-DAG: %[[TILE_0:.*]] = "tosa.tile"(%[[RESHAPE_0]]) {multiples = array} -// CHECK-DAG: %[[RESHAPE_1:.*]] = "tosa.reshape"(%[[ARG0]]) {new_shape = array} +// CHECK-DAG: %[[RESHAPE:.*]] = "tosa.reshape"(%[[ARG1]]) <{new_shape = array}> +// CHECK-DAG: %[[TILE:.*]] = "tosa.tile"(%[[RESHAPE]]) <{multiples = array}> +// CHECK-DAG: %[[RESHAPE_0:.*]] = "tosa.reshape"(%[[ARG2]]) <{new_shape = array}> +// CHECK-DAG: %[[TILE_0:.*]] = "tosa.tile"(%[[RESHAPE_0]]) <{multiples = array}> +// CHECK-DAG: %[[RESHAPE_1:.*]] = "tosa.reshape"(%[[ARG0]]) <{new_shape = array}> // CHECK-DAG: %[[SCATTER:.*]] = "tosa.scatter"(%[[TILE_0]], %[[RESHAPE_1]], %[[TILE]]) -// CHECK-DAG: %[[RESHAPE_2:.*]] = "tosa.reshape"(%[[SCATTER]]) {new_shape = array} +// CHECK-DAG: %[[RESHAPE_2:.*]] = "tosa.reshape"(%[[SCATTER]]) <{new_shape = array}> // CHECK: return %[[RESHAPE_2]] func.func @test_one_hot(%arg0: tensor<4x4xi32>, %arg1: tensor, %arg2: tensor) -> tensor<4x4x2xf32> { %0 = arith.constant dense<2> : tensor @@ -1691,15 +1787,12 @@ func.func @test_one_hot(%arg0: tensor<4x4xi32>, %arg1: tensor, %arg2: tenso // ----- // CHECK-LABEL: test_fakequant_with_min_max_args -// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() {value = dense<16383.75> : tensor<1x1x1xf32>} -// CHECK-DAG: %[[VAR1:.*]] = "tosa.const"() {value = dense<0.000000e+00> : tensor<1x1x1xf32>} -// CHECK-DAG: %[[VAR2:.*]] = "tosa.const"() {value = dense<6.10360876E-5> : tensor<1x1x1xf32>} -// CHECK-DAG: %[[VAR4:.*]] = "tosa.mul"(%arg0, %[[VAR0]]) {shift = 0 : i32} -// CHECK-DAG: %[[VAR5:.*]] = "tosa.add"(%[[VAR4]], %[[VAR1]]) -// CHECK-DAG: %[[VAR7:.*]] = "tosa.cast"(%[[VAR5]]) -// CHECK-DAG: %[[VAR8:.*]] = "tosa.cast"(%[[VAR7]]) -// CHECK-DAG: %[[VAR10:.*]] = "tosa.sub"(%[[VAR8]], %[[VAR1]]) -// CHECK-DAG: %[[VAR12:.*]] = "tosa.mul"(%[[VAR10]], %[[VAR2]]) {shift = 0 : i32} +// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() <{value = dense<16383.75> : tensor<1x1x1xf32>}> +// CHECK-DAG: %[[VAR1:.*]] = "tosa.const"() <{value = dense<6.10360876E-5> : tensor<1x1x1xf32>}> +// CHECK-DAG: %[[VAR2:.*]] = "tosa.mul"(%arg0, %[[VAR0]]) <{shift = 0 : i32}> +// CHECK-DAG: %[[VAR3:.*]] = "tosa.cast"(%[[VAR2]]) +// CHECK-DAG: %[[VAR4:.*]] = "tosa.cast"(%[[VAR3]]) +// CHECK-DAG: %[[VAR5:.*]] = "tosa.mul"(%[[VAR4]], %[[VAR1]]) <{shift = 0 : i32}> func.func @test_fakequant_with_min_max_args(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { %0 = "tfl.quantize"(%arg0) {qtype = tensor<13x21x3x!quant.uniform>} : (tensor<13x21x3xf32>) -> tensor<*x!quant.uniform> %1 = "tfl.dequantize"(%0) : (tensor<*x!quant.uniform>) -> tensor<13x21x3xf32> @@ -1720,7 +1813,7 @@ func.func @test_dequantize_float(%arg0: tensor<10xf16>) -> tensor<*xf32> { // CHECK-LABEL: @test_dequantize_quant_uniform func.func @test_dequantize_quant_uniform(%arg0: tensor<4x!quant.uniform>) -> tensor<*xf32> { - // CHECK-DAG: %[[VAL0:.+]] = "tosa.const"() {value = dense<-1.000000e+00> : tensor<1xf32>} + // CHECK-DAG: %[[VAL0:.+]] = "tosa.const"() <{value = dense<-1.000000e+00> : tensor<1xf32>}> // CHECK-DAG: %[[VAL1:.+]] = "tosa.cast"(%arg0) // CHECK-DAG: %[[VAL2:.+]] = "tosa.sub"(%[[VAL1]], %[[VAL0]]) %0 = "tfl.dequantize"(%arg0) : 
(tensor<4x!quant.uniform>) -> tensor<*xf32> @@ -1730,11 +1823,11 @@ func.func @test_dequantize_quant_uniform(%arg0: tensor<4x!quant.uniform>) -> tensor<*xf32> { - // CHECK-DAG: %[[VAL0:.+]] = "tosa.const"() {value = dense<{{\[}}[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]]> : tensor<1x4xf32>} - // CHECK-DAG: %[[VAL1:.+]] = "tosa.const"() {value = dense<{{\[}}[5.000000e+00, 6.000000e+00, 7.000000e+00, 8.000000e+00]]> : tensor<1x4xf32>} + // CHECK-DAG: %[[VAL0:.+]] = "tosa.const"() <{value = dense<{{\[}}[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]]> : tensor<1x4xf32>}> + // CHECK-DAG: %[[VAL1:.+]] = "tosa.const"() <{value = dense<{{\[}}[5.000000e+00, 6.000000e+00, 7.000000e+00, 8.000000e+00]]> : tensor<1x4xf32>}> // CHECK-DAG: %[[VAL2:.+]] = "tosa.cast"(%arg0) : (tensor<1x4x!quant.uniform>) -> tensor<1x4xf32> // CHECK-DAG: %[[VAL3:.+]] = "tosa.sub"(%[[VAL2]], %[[VAL1]]) : (tensor<1x4xf32>, tensor<1x4xf32>) -> tensor<1x4xf32> - // CHECK: %[[VAL4:.+]] = "tosa.mul"(%[[VAL3]], %[[VAL0]]) {shift = 0 : i32} : (tensor<1x4xf32>, tensor<1x4xf32>) -> tensor<1x4xf32> + // CHECK: %[[VAL4:.+]] = "tosa.mul"(%[[VAL3]], %[[VAL0]]) <{shift = 0 : i32}> : (tensor<1x4xf32>, tensor<1x4xf32>) -> tensor<1x4xf32> %0 = "tfl.dequantize"(%arg0) : (tensor<1x4x!quant.uniform>) -> tensor<*xf32> func.return %0 : tensor<*xf32> } @@ -1751,11 +1844,11 @@ func.func @test_quantfork.stats(%arg0: tensor<2x1xf32>) -> (tensor<2x1xf32>) { // ----- // CHECK-LABEL: test_add_qi8 -// CHECK-DAG: %[[VAL_0:.*]] = "tosa.rescale"(%arg0) {double_round = true, input_zp = -1 : i32, multiplier = array, output_zp = 0 : i32, per_channel = false, scale32 = true, shift = array} -// CHECK-DAG: %[[VAL_1:.*]] = "tosa.rescale"(%[[VAL_0]]) {double_round = true, input_zp = 0 : i32, multiplier = array, output_zp = 0 : i32, per_channel = false, scale32 = true, shift = array} -// CHECK-DAG: %[[VAL_2:.*]] = "tosa.rescale"(%arg1) {double_round = true, input_zp = -1 : i32, multiplier = array, output_zp = 0 : i32, per_channel = false, scale32 = true, shift = array} +// CHECK-DAG: %[[VAL_0:.*]] = "tosa.rescale"(%arg0) <{double_round = true, input_zp = -1 : i32, multiplier = array, output_zp = 0 : i32, per_channel = false, scale32 = true, shift = array}> +// CHECK-DAG: %[[VAL_1:.*]] = "tosa.rescale"(%[[VAL_0]]) <{double_round = true, input_zp = 0 : i32, multiplier = array, output_zp = 0 : i32, per_channel = false, scale32 = true, shift = array}> +// CHECK-DAG: %[[VAL_2:.*]] = "tosa.rescale"(%arg1) <{double_round = true, input_zp = -1 : i32, multiplier = array, output_zp = 0 : i32, per_channel = false, scale32 = true, shift = array}> // CHECK-DAG: %[[VAL_3:.*]] = "tosa.add"(%[[VAL_1]], %[[VAL_2]]) -// CHECK: %[[VAL_4:.*]] = "tosa.rescale"(%[[VAL_3]]) {double_round = true, input_zp = 0 : i32, multiplier = array, output_zp = -1 : i32, per_channel = false, scale32 = true, shift = array} +// CHECK: %[[VAL_4:.*]] = "tosa.rescale"(%[[VAL_3]]) <{double_round = true, input_zp = 0 : i32, multiplier = array, output_zp = -1 : i32, per_channel = false, scale32 = true, shift = array}> func.func @test_add_qi8(%arg0: tensor<13x21x1x!quant.uniform>, %arg1: tensor<13x21x3x!quant.uniform>) -> tensor<13x21x3x!quant.uniform> { %0 = tfl.add(%arg0, %arg1) {fused_activation_function = "NONE"} : (tensor<13x21x1x!quant.uniform>, tensor<13x21x3x!quant.uniform>) -> tensor<13x21x3x!quant.uniform> func.return %0 : tensor<13x21x3x!quant.uniform> @@ -1764,11 +1857,11 @@ func.func @test_add_qi8(%arg0: tensor<13x21x1x!quant.uniform, output_zp = 0 : i32, per_channel = 
false, scale32 = true, shift = array} -// CHECK-DAG: %[[VAL_1:.*]] = "tosa.rescale"(%[[VAL_0]]) {double_round = true, input_zp = 0 : i32, multiplier = array, output_zp = 0 : i32, per_channel = false, scale32 = true, shift = array} -// CHECK-DAG: %[[VAL_2:.*]] = "tosa.rescale"(%arg1) {double_round = true, input_zp = -1 : i32, multiplier = array, output_zp = 0 : i32, per_channel = false, scale32 = true, shift = array} +// CHECK-DAG: %[[VAL_0:.*]] = "tosa.rescale"(%arg0) <{double_round = true, input_zp = -1 : i32, multiplier = array, output_zp = 0 : i32, per_channel = false, scale32 = true, shift = array}> +// CHECK-DAG: %[[VAL_1:.*]] = "tosa.rescale"(%[[VAL_0]]) <{double_round = true, input_zp = 0 : i32, multiplier = array, output_zp = 0 : i32, per_channel = false, scale32 = true, shift = array}> +// CHECK-DAG: %[[VAL_2:.*]] = "tosa.rescale"(%arg1) <{double_round = true, input_zp = -1 : i32, multiplier = array, output_zp = 0 : i32, per_channel = false, scale32 = true, shift = array}> // CHECK-DAG: %[[VAL_3:.*]] = "tosa.sub"(%[[VAL_1]], %[[VAL_2]]) -// CHECK: %[[VAL_4:.*]] = "tosa.rescale"(%[[VAL_3]]) {double_round = true, input_zp = 0 : i32, multiplier = array, output_zp = 0 : i32, per_channel = false, scale32 = true, shift = array} +// CHECK: %[[VAL_4:.*]] = "tosa.rescale"(%[[VAL_3]]) <{double_round = true, input_zp = 0 : i32, multiplier = array, output_zp = 0 : i32, per_channel = false, scale32 = true, shift = array}> func.func @test_sub_qi8(%arg0: tensor<1x21x3x!quant.uniform>, %arg1: tensor<13x21x3x!quant.uniform>) -> tensor<13x21x3x!quant.uniform> { %0 = tfl.sub(%arg0, %arg1) {fused_activation_function = "NONE"} : (tensor<1x21x3x!quant.uniform>, tensor<13x21x3x!quant.uniform>) -> tensor<13x21x3x!quant.uniform> func.return %0 : tensor<13x21x3x!quant.uniform> @@ -1779,7 +1872,7 @@ func.func @test_sub_qi8(%arg0: tensor<1x21x3x!quant.uniform // CHECK: %[[VAR3:.*]] = "tosa.rescale"(%[[VAR2]]) func.func @test_mul_qi8(%arg0: tensor<13x21x3x!quant.uniform>, %arg1: tensor<13x21x3x!quant.uniform>) -> tensor<*x!quant.uniform> { %0 = "tfl.mul"(%arg0, %arg1) {fused_activation_function = "NONE"} : (tensor<13x21x3x!quant.uniform>, tensor<13x21x3x!quant.uniform>) -> tensor<*x!quant.uniform> @@ -1789,7 +1882,7 @@ func.func @test_mul_qi8(%arg0: tensor<13x21x3x!quant.uniform, pad = array, quantization_info = #tosa.unary_quant, stride = array} +// CHECK: %[[VAR0:.*]] = "tosa.avg_pool2d"(%arg0) <{kernel = array, pad = array, quantization_info = #tosa.unary_quant, stride = array}> // CHECK-SAME: -> tensor<1x32x32x8x!quant.uniform> func.func @test_avg_pool2d_qi8(%arg0: tensor<1x32x32x8x!quant.uniform>) -> tensor<*x!quant.uniform> { %0 = "tfl.average_pool_2d"(%arg0) {filter_height = 1 : i32, filter_width = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x32x32x8x!quant.uniform>) -> tensor<*x!quant.uniform> @@ -1799,7 +1892,7 @@ func.func @test_avg_pool2d_qi8(%arg0: tensor<1x32x32x8x!quant.uniform, pad = array, stride = array} +// CHECK: %[[VAR0:.*]] = "tosa.avg_pool2d"(%arg0) <{kernel = array, pad = array, stride = array}> // CHECK-SAME: -> tensor<1x32x32x8xi16> func.func @test_avg_pool2d_i16(%arg0: tensor<1x32x32x8xi16>) -> tensor<*xi16> { %0 = "tfl.average_pool_2d"(%arg0) {filter_height = 1 : i32, filter_width = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x32x32x8xi16>) -> tensor<*xi16> @@ -1809,7 +1902,7 @@ func.func @test_avg_pool2d_i16(%arg0: tensor<1x32x32x8xi16>) -> 
tensor<*xi16> { // ----- // CHECK-LABEL: test_max_pool2d_qi8 -// CHECK: %[[VAR0:.*]] = "tosa.max_pool2d"(%arg0) {kernel = array, pad = array, stride = array} +// CHECK: %[[VAR0:.*]] = "tosa.max_pool2d"(%arg0) <{kernel = array, pad = array, stride = array}> func.func @test_max_pool2d_qi8(%arg0: tensor<1x32x32x8x!quant.uniform>) -> tensor<*x!quant.uniform> { %0 = "tfl.max_pool_2d"(%arg0) {filter_height = 1 : i32, filter_width = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x32x32x8x!quant.uniform>) -> tensor<*x!quant.uniform> func.return %0 : tensor<*x!quant.uniform> @@ -1818,24 +1911,24 @@ func.func @test_max_pool2d_qi8(%arg0: tensor<1x32x32x8x!quant.uniform : tensor<1x1x1xi32>} -// CHECK-DAG: %[[VAR2:.*]] = "tosa.const"() {value = dense<4> : tensor<1x1x1xi32>} -// CHECK-DAG: %[[VAR3:.*]] = "tosa.const"() {value = dense<536870912> : tensor<1x1x1xi32>} -// CHECK-DAG: %[[VAR4:.*]] = "tosa.const"() {value = dense<1515870810> : tensor<1x1x1xi32>} -// CHECK-DAG: %[[VAR5:.*]] = "tosa.const"() {value = dense<-1010580540> : tensor<1x1x1xi32>} -// CHECK-DAG: %[[VAR6:.*]] = "tosa.const"() {value = dense<1> : tensor<1x1x1xi32>} -// CHECK-DAG: %[[VAR7:.*]] = "tosa.const"() {value = dense<12> : tensor<1x1x1xi32>} -// CHECK-DAG: %[[VAR8:.*]] = "tosa.const"() {value = dense<7> : tensor<1x1x1xi32>} -// CHECK-DAG: %[[VAR9:.*]] = "tosa.const"() {value = dense<9> : tensor<1x1x1xi32>} -// CHECK-DAG: %[[VAR10:.*]] = "tosa.const"() {value = dense<17> : tensor<1x1x1xi32>} -// CHECK-DAG: %[[VAR11:.*]] = "tosa.const"() {value = dense<"0x5{{.*}}"> : tensor<513xi16>} -// CHECK-DAG: %[[VAR12:.*]] = "tosa.const"() {value = dense<"0xE{{.*}}"> : tensor<513xi16>} -// CHECK-DAG: %[[VAR13:.*]] = "tosa.const"() {value = dense<"0x4{{.*}}"> : tensor<513xi16>} -// CHECK-DAG: %[[VAR14:.*]] = "tosa.const"() {value = dense<"0x0{{.*}}"> : tensor<513xi16>} -// CHECK-DAG: %[[VAR15:.*]] = "tosa.rescale"(%arg0) {double_round = false, input_zp = -1 : i32, multiplier = array, output_zp = 0 : i32, per_channel = false, scale32 = true, shift = array} -// CHECK-DAG: %[[VAR16:.*]] = "tosa.reduce_max"(%[[VAR15]]) {axis = 2 : i64} +// CHECK-DAG: %[[VAR1:.*]] = "tosa.const"() <{value = dense<35> : tensor<1x1x1xi32>}> +// CHECK-DAG: %[[VAR2:.*]] = "tosa.const"() <{value = dense<4> : tensor<1x1x1xi32>}> +// CHECK-DAG: %[[VAR3:.*]] = "tosa.const"() <{value = dense<536870912> : tensor<1x1x1xi32>}> +// CHECK-DAG: %[[VAR4:.*]] = "tosa.const"() <{value = dense<1515870810> : tensor<1x1x1xi32>}> +// CHECK-DAG: %[[VAR5:.*]] = "tosa.const"() <{value = dense<-1010580540> : tensor<1x1x1xi32>}> +// CHECK-DAG: %[[VAR6:.*]] = "tosa.const"() <{value = dense<1> : tensor<1x1x1xi32>}> +// CHECK-DAG: %[[VAR7:.*]] = "tosa.const"() <{value = dense<12> : tensor<1x1x1xi32>}> +// CHECK-DAG: %[[VAR8:.*]] = "tosa.const"() <{value = dense<7> : tensor<1x1x1xi32>}> +// CHECK-DAG: %[[VAR9:.*]] = "tosa.const"() <{value = dense<9> : tensor<1x1x1xi32>}> +// CHECK-DAG: %[[VAR10:.*]] = "tosa.const"() <{value = dense<17> : tensor<1x1x1xi32>}> +// CHECK-DAG: %[[VAR11:.*]] = "tosa.const"() <{value = dense<"0x5{{.*}}"> : tensor<513xi16>}> +// CHECK-DAG: %[[VAR12:.*]] = "tosa.const"() <{value = dense<"0xE{{.*}}"> : tensor<513xi16>}> +// CHECK-DAG: %[[VAR13:.*]] = "tosa.const"() <{value = dense<"0x4{{.*}}"> : tensor<513xi16>}> +// CHECK-DAG: %[[VAR14:.*]] = "tosa.const"() <{value = dense<"0x0{{.*}}"> : tensor<513xi16>}> +// CHECK-DAG: %[[VAR15:.*]] = "tosa.rescale"(%arg0) <{double_round = false, input_zp = -1 : i32, 
multiplier = array, output_zp = 0 : i32, per_channel = false, scale32 = true, shift = array}> +// CHECK-DAG: %[[VAR16:.*]] = "tosa.reduce_max"(%[[VAR15]]) <{axis = 2 : i64}> // CHECK-DAG: %[[VAR17:.*]] = "tosa.sub"(%[[VAR15]], %[[VAR16]]) -// CHECK-DAG: %[[VAR18:.*]] = "tosa.rescale"(%[[VAR17]]) {double_round = false, input_zp = 0 : i32, multiplier = array, output_zp = 0 : i32, per_channel = false, scale32 = true, shift = array} +// CHECK-DAG: %[[VAR18:.*]] = "tosa.rescale"(%[[VAR17]]) <{double_round = false, input_zp = 0 : i32, multiplier = array, output_zp = 0 : i32, per_channel = false, scale32 = true, shift = array}> // CHECK-DAG: %[[VAR19:.*]] = "tosa.table"(%[[VAR18]], %[[VAR14]]) // CHECK-DAG: %[[VAR20:.*]] = "tosa.table"(%[[VAR18]], %[[VAR13]]) // CHECK-DAG: %[[VAR21:.*]] = "tosa.table"(%[[VAR18]], %[[VAR12]]) @@ -1843,36 +1936,36 @@ func.func @test_max_pool2d_qi8(%arg0: tensor<1x32x32x8x!quant.uniform // CHECK-DAG: %[[VAR27:.*]] = "tosa.add"(%[[VAR23]], %[[VAR24]]) // CHECK-DAG: %[[VAR28:.*]] = "tosa.add"(%[[VAR27]], %[[VAR25]]) // CHECK-DAG: %[[VAR29:.*]] = "tosa.add"(%[[VAR28]], %[[VAR26]]) -// CHECK-DAG: %[[VAR30:.*]] = "tosa.arithmetic_right_shift"(%[[VAR29]], %[[VAR7]]) {round = true} -// CHECK-DAG: %[[VAR31:.*]] = "tosa.reduce_sum"(%[[VAR30]]) {axis = 2 : i64} +// CHECK-DAG: %[[VAR30:.*]] = "tosa.arithmetic_right_shift"(%[[VAR29]], %[[VAR7]]) <{round = true}> +// CHECK-DAG: %[[VAR31:.*]] = "tosa.reduce_sum"(%[[VAR30]]) <{axis = 2 : i64}> // CHECK-DAG: %[[VAR32:.*]] = "tosa.clz"(%[[VAR31]]) // CHECK-DAG: %[[VAR33:.*]] = "tosa.sub"(%[[VAR32]], %[[VAR6]]) // CHECK-DAG: %[[VAR34:.*]] = "tosa.logical_left_shift"(%[[VAR31]], %[[VAR33]]) -// CHECK-DAG: %[[VAR35:.*]] = "tosa.mul"(%[[VAR34]], %[[VAR5]]) {shift = 31 : i32} +// CHECK-DAG: %[[VAR35:.*]] = "tosa.mul"(%[[VAR34]], %[[VAR5]]) <{shift = 31 : i32}> // CHECK-DAG: %[[VAR36:.*]] = "tosa.add"(%[[VAR35]], %[[VAR4]]) -// CHECK-DAG: %[[VAR37:.*]] = "tosa.mul"(%[[VAR36]], %[[VAR34]]) {shift = 31 : i32} +// CHECK-DAG: %[[VAR37:.*]] = "tosa.mul"(%[[VAR36]], %[[VAR34]]) <{shift = 31 : i32}> // CHECK-DAG: %[[VAR38:.*]] = "tosa.sub"(%[[VAR3]], %[[VAR37]]) -// CHECK-DAG: %[[VAR39:.*]] = "tosa.mul"(%[[VAR36]], %[[VAR38]]) {shift = 31 : i32} -// CHECK-DAG: %[[VAR40:.*]] = "tosa.mul"(%[[VAR39]], %[[VAR2]]) {shift = 0 : i32} +// CHECK-DAG: %[[VAR39:.*]] = "tosa.mul"(%[[VAR36]], %[[VAR38]]) <{shift = 31 : i32}> +// CHECK-DAG: %[[VAR40:.*]] = "tosa.mul"(%[[VAR39]], %[[VAR2]]) <{shift = 0 : i32}> // CHECK-DAG: %[[VAR41:.*]] = "tosa.add"(%[[VAR36]], %[[VAR40]]) -// CHECK-DAG: %[[VAR42:.*]] = "tosa.mul"(%[[VAR41]], %[[VAR34]]) {shift = 31 : i32} +// CHECK-DAG: %[[VAR42:.*]] = "tosa.mul"(%[[VAR41]], %[[VAR34]]) <{shift = 31 : i32}> // CHECK-DAG: %[[VAR43:.*]] = "tosa.sub"(%[[VAR3]], %[[VAR42]]) -// CHECK-DAG: %[[VAR44:.*]] = "tosa.mul"(%[[VAR41]], %[[VAR43]]) {shift = 31 : i32} -// CHECK-DAG: %[[VAR45:.*]] = "tosa.mul"(%[[VAR44]], %[[VAR2]]) {shift = 0 : i32} +// CHECK-DAG: %[[VAR44:.*]] = "tosa.mul"(%[[VAR41]], %[[VAR43]]) <{shift = 31 : i32}> +// CHECK-DAG: %[[VAR45:.*]] = "tosa.mul"(%[[VAR44]], %[[VAR2]]) <{shift = 0 : i32}> // CHECK-DAG: %[[VAR46:.*]] = "tosa.add"(%[[VAR41]], %[[VAR45]]) -// CHECK-DAG: %[[VAR47:.*]] = "tosa.mul"(%[[VAR46]], %[[VAR34]]) {shift = 31 : i32} +// CHECK-DAG: %[[VAR47:.*]] = "tosa.mul"(%[[VAR46]], %[[VAR34]]) <{shift = 31 : i32}> // CHECK-DAG: %[[VAR48:.*]] = "tosa.sub"(%[[VAR3]], %[[VAR47]]) -// CHECK-DAG: %[[VAR49:.*]] = "tosa.mul"(%[[VAR46]], %[[VAR48]]) {shift = 31 : i32} -// CHECK-DAG: %[[VAR50:.*]] = 
"tosa.mul"(%[[VAR49]], %[[VAR2]]) {shift = 0 : i32} +// CHECK-DAG: %[[VAR49:.*]] = "tosa.mul"(%[[VAR46]], %[[VAR48]]) <{shift = 31 : i32}> +// CHECK-DAG: %[[VAR50:.*]] = "tosa.mul"(%[[VAR49]], %[[VAR2]]) <{shift = 0 : i32}> // CHECK-DAG: %[[VAR51:.*]] = "tosa.add"(%[[VAR46]], %[[VAR50]]) -// CHECK-DAG: %[[VAR52:.*]] = "tosa.mul"(%[[VAR29]], %[[VAR51]]) {shift = 30 : i32} +// CHECK-DAG: %[[VAR52:.*]] = "tosa.mul"(%[[VAR29]], %[[VAR51]]) <{shift = 30 : i32}> // CHECK-DAG: %[[VAR53:.*]] = "tosa.sub"(%[[VAR1]], %[[VAR32]]) -// CHECK-DAG: %[[VAR54:.*]] = "tosa.arithmetic_right_shift"(%[[VAR52]], %[[VAR53]]) {round = true} -// CHECK: %[[VAR55:.*]] = "tosa.rescale"(%[[VAR54]]) {double_round = false, input_zp = 0 : i32, multiplier = array, output_zp = -128 : i32, per_channel = false, scale32 = true, shift = array} +// CHECK-DAG: %[[VAR54:.*]] = "tosa.arithmetic_right_shift"(%[[VAR52]], %[[VAR53]]) <{round = true}> +// CHECK: %[[VAR55:.*]] = "tosa.rescale"(%[[VAR54]]) <{double_round = false, input_zp = 0 : i32, multiplier = array, output_zp = -128 : i32, per_channel = false, scale32 = true, shift = array}> func.func @test_softmax_qi8(%arg0: tensor<13x21x3x!quant.uniform>) -> tensor<13x21x3x!quant.uniform> { %0 = "tfl.softmax"(%arg0) {beta = 1.000000e+00 : f32} : (tensor<13x21x3x!quant.uniform>) -> tensor<13x21x3x!quant.uniform> func.return %0 : tensor<13x21x3x!quant.uniform> @@ -1882,37 +1975,37 @@ func.func @test_softmax_qi8(%arg0: tensor<13x21x3x!quant.uniform : tensor<1x1xi32>} -// CHECK-DAG: %[[VAR1:.*]] = "tosa.const"() {value = dense<7> : tensor<1x1xi32>} -// CHECK-DAG: %[[VAR2:.*]] = "tosa.const"() {value = dense<32768> : tensor<1x1xi32>} -// CHECK-DAG: %[[VAR3:.*]] = "tosa.const"() {value = dense<14> : tensor<1x1xi32>} -// CHECK-DAG: %[[VAR4:.*]] = "tosa.const"() {value = dense<1073741824> : tensor<1x1xi32>} -// CHECK-DAG: %[[VAR5:.*]] = "tosa.const"() {value = dense<1> : tensor<1x1xi32>} -// CHECK-DAG: %[[VAR6:.*]] = "tosa.const"() {value = dense<32767> : tensor<1x1xi32>} -// CHECK-DAG: %[[VAR7:.*]] = "tosa.const"() {value = dense<"0xF{{.*}}> -// CHECK-DAG: %[[VAR8:.*]] = "tosa.const"() {value = dense<"0x0{{.*}}> : tensor<513xi16>} -// CHECK-DAG: %[[VAR9:.*]] = "tosa.rescale"(%arg0) {double_round = false, input_zp = 0 : i32, multiplier = array, output_zp = 0 : i32, per_channel = false, scale32 = true, shift = array} -// CHECK-DAG: %[[VAR10:.*]] = "tosa.reduce_max"(%[[VAR9]]) {axis = 1 : i64} +// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() <{value = dense<31> : tensor<1x1xi32>}> +// CHECK-DAG: %[[VAR1:.*]] = "tosa.const"() <{value = dense<7> : tensor<1x1xi32>}> +// CHECK-DAG: %[[VAR2:.*]] = "tosa.const"() <{value = dense<32768> : tensor<1x1xi32>}> +// CHECK-DAG: %[[VAR3:.*]] = "tosa.const"() <{value = dense<14> : tensor<1x1xi32>}> +// CHECK-DAG: %[[VAR4:.*]] = "tosa.const"() <{value = dense<1073741824> : tensor<1x1xi32>}> +// CHECK-DAG: %[[VAR5:.*]] = "tosa.const"() <{value = dense<1> : tensor<1x1xi32>}> +// CHECK-DAG: %[[VAR6:.*]] = "tosa.const"() <{value = dense<32767> : tensor<1x1xi32>}> +// CHECK-DAG: %[[VAR7:.*]] = "tosa.const"() <{value = dense<"0xF{{.*}}> +// CHECK-DAG: %[[VAR8:.*]] = "tosa.const"() <{value = dense<"0x0{{.*}}> : tensor<513xi16>}> +// CHECK-DAG: %[[VAR9:.*]] = "tosa.rescale"(%arg0) <{double_round = false, input_zp = 0 : i32, multiplier = array, output_zp = 0 : i32, per_channel = false, scale32 = true, shift = array}> +// CHECK-DAG: %[[VAR10:.*]] = "tosa.reduce_max"(%[[VAR9]]) <{axis = 1 : i64}> // CHECK-DAG: %[[VAR11:.*]] = "tosa.sub"(%[[VAR9]], %[[VAR10]]) -// 
CHECK-DAG: %[[VAR12:.*]] = "tosa.rescale"(%[[VAR11]]) {double_round = true, input_zp = 0 : i32, multiplier = array, output_zp = 0 : i32, per_channel = false, scale32 = true, shift = array} +// CHECK-DAG: %[[VAR12:.*]] = "tosa.rescale"(%[[VAR11]]) <{double_round = true, input_zp = 0 : i32, multiplier = array, output_zp = 0 : i32, per_channel = false, scale32 = true, shift = array}> // CHECK-DAG: %[[VAR13:.*]] = "tosa.add"(%[[VAR12]], %[[VAR6]]) // CHECK-DAG: %[[VAR14:.*]] = "tosa.cast"(%[[VAR13]]) // CHECK-DAG: %[[VAR15:.*]] = "tosa.table"(%[[VAR14]], %[[VAR8]]) -// CHECK-DAG: %[[VAR16:.*]] = "tosa.arithmetic_right_shift"(%[[VAR15]], %[[VAR1]]) {round = true} -// CHECK-DAG: %[[VAR17:.*]] = "tosa.reduce_sum"(%[[VAR16]]) {axis = 1 : i64} +// CHECK-DAG: %[[VAR16:.*]] = "tosa.arithmetic_right_shift"(%[[VAR15]], %[[VAR1]]) <{round = true}> +// CHECK-DAG: %[[VAR17:.*]] = "tosa.reduce_sum"(%[[VAR16]]) <{axis = 1 : i64}> // CHECK-DAG: %[[VAR18:.*]] = "tosa.clz"(%[[VAR17]]) // CHECK-DAG: %[[VAR19:.*]] = "tosa.sub"(%[[VAR18]], %[[VAR5]]) // CHECK-DAG: %[[VAR20:.*]] = "tosa.logical_left_shift"(%[[VAR17]], %[[VAR19]]) // CHECK-DAG: %[[VAR21:.*]] = "tosa.sub"(%[[VAR20]], %[[VAR4]]) -// CHECK-DAG: %[[VAR22:.*]] = "tosa.arithmetic_right_shift"(%[[VAR21]], %[[VAR3]]) {round = true} +// CHECK-DAG: %[[VAR22:.*]] = "tosa.arithmetic_right_shift"(%[[VAR21]], %[[VAR3]]) <{round = true}> // CHECK-DAG: %[[VAR23:.*]] = "tosa.sub"(%[[VAR22]], %[[VAR2]]) // CHECK-DAG: %[[VAR24:.*]] = "tosa.cast"(%[[VAR23]]) // CHECK-DAG: %[[VAR25:.*]] = "tosa.table"(%[[VAR24]], %[[VAR7]]) -// CHECK-DAG: %[[VAR26:.*]] = "tosa.arithmetic_right_shift"(%[[VAR25]], %[[VAR1]]) {round = true} -// CHECK-DAG: %[[VAR27:.*]] = "tosa.mul"(%[[VAR26]], %[[VAR16]]) {shift = 0 : i32} +// CHECK-DAG: %[[VAR26:.*]] = "tosa.arithmetic_right_shift"(%[[VAR25]], %[[VAR1]]) <{round = true}> +// CHECK-DAG: %[[VAR27:.*]] = "tosa.mul"(%[[VAR26]], %[[VAR16]]) <{shift = 0 : i32}> // CHECK-DAG: %[[VAR28:.*]] = "tosa.sub"(%[[VAR0]], %[[VAR18]]) -// CHECK-DAG: %[[VAR29:.*]] = "tosa.arithmetic_right_shift"(%[[VAR27]], %[[VAR28]]) {round = true} -// CHECK: %[[VAR30:.*]] = "tosa.rescale"(%[[VAR29]]) {double_round = false, input_zp = 0 : i32, multiplier = array, output_zp = 0 : i32, per_channel = false, scale32 = true, shift = array} +// CHECK-DAG: %[[VAR29:.*]] = "tosa.arithmetic_right_shift"(%[[VAR27]], %[[VAR28]]) <{round = true}> +// CHECK: %[[VAR30:.*]] = "tosa.rescale"(%[[VAR29]]) <{double_round = false, input_zp = 0 : i32, multiplier = array, output_zp = 0 : i32, per_channel = false, scale32 = true, shift = array}> func.func @test_softmax_qi16(%arg0: tensor<14x19x!quant.uniform>) -> tensor<14x19x!quant.uniform> { %0 = "tfl.softmax"(%arg0) {beta = 1.000000e+00 : f32} : (tensor<14x19x!quant.uniform>) -> tensor<14x19x!quant.uniform> func.return %0 : tensor<14x19x!quant.uniform> @@ -1921,7 +2014,7 @@ func.func @test_softmax_qi16(%arg0: tensor<14x19x!quant.uniform : tensor<256xi8>} +// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() <{value = dense<{{.*}}> : tensor<256xi8>}> // CHECK: %[[VAR1:.*]] = "tosa.table"(%arg0, %[[VAR0]]) func.func @test_sigmoid_qi8(%arg0: tensor<13x21x3x!quant.uniform>) -> tensor<*x!quant.uniform> { %0 = "tfl.logistic"(%arg0) : (tensor<13x21x3x!quant.uniform>) -> tensor<*x!quant.uniform> @@ -1931,7 +2024,7 @@ func.func @test_sigmoid_qi8(%arg0: tensor<13x21x3x!quant.uniform : tensor<256xi8>} +// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() <{value = dense<{{.*}}> : tensor<256xi8>}> // CHECK: %[[VAR1:.*]] = "tosa.table"(%arg0, %[[VAR0]]) func.func 
@test_tanh_qi8(%arg0: tensor<13x21x3x!quant.uniform>) -> tensor<*x!quant.uniform> { %0 = "tfl.tanh"(%arg0) : (tensor<13x21x3x!quant.uniform>) -> tensor<*x!quant.uniform> @@ -1942,7 +2035,7 @@ func.func @test_tanh_qi8(%arg0: tensor<13x21x3x!quant.uniform func.func @test_relu_qi8(%arg0: tensor<13x21x3x!quant.uniform>) -> tensor<13x21x3x!quant.uniform> { %0 = "tfl.relu"(%arg0) : (tensor<13x21x3x!quant.uniform>) -> tensor<13x21x3x!quant.uniform> func.return %0 : tensor<13x21x3x!quant.uniform> @@ -1952,7 +2045,7 @@ func.func @test_relu_qi8(%arg0: tensor<13x21x3x!quant.uniform func.func @test_relu0To1_qi8(%arg0: tensor<13x21x3x!quant.uniform>) -> tensor<13x21x3x!quant.uniform> { %0 = "tfl.relu_n1_to_1"(%arg0) : (tensor<13x21x3x!quant.uniform>) -> tensor<13x21x3x!quant.uniform> func.return %0 : tensor<13x21x3x!quant.uniform> @@ -1962,7 +2055,7 @@ func.func @test_relu0To1_qi8(%arg0: tensor<13x21x3x!quant.uniform func.func @test_relu6_qi8(%arg0: tensor<13x21x3x!quant.uniform>) -> tensor<13x21x3x!quant.uniform> { %0 = "tfl.relu6"(%arg0) : (tensor<13x21x3x!quant.uniform>) -> tensor<13x21x3x!quant.uniform> func.return %0 : tensor<13x21x3x!quant.uniform> @@ -1974,7 +2067,7 @@ func.func @test_relu6_qi8(%arg0: tensor<13x21x3x!quant.uniform // CHECK: %[[VAL_5:.*]] = "tosa.rescale"(%[[VAL_4]]) // CHECK: %[[VAL_6:.*]] = "tosa.rescale"(%[[VAL_5]]) func.func @test_relu6_qu8(%arg0: tensor<13x21x3x!quant.uniform>) -> tensor<13x21x3x!quant.uniform> { @@ -1987,10 +2080,10 @@ func.func @test_relu6_qu8(%arg0: tensor<13x21x3x!quant.uniform> func.func @test_leaky_relu_qi8(%arg0: tensor<14x19x!quant.uniform>) -> tensor<*x!quant.uniform> { %0 = "tfl.leaky_relu"(%arg0) {alpha = 0.948724806 : f32} : (tensor<14x19x!quant.uniform>) -> tensor<*x!quant.uniform> @@ -2000,10 +2093,10 @@ func.func @test_leaky_relu_qi8(%arg0: tensor<14x19x!quant.uniform> func.func @test_leaky_relu_qi16(%arg0: tensor<14x19x!quant.uniform>) -> tensor<*x!quant.uniform> { %0 = "tfl.leaky_relu"(%arg0) {alpha = 1.048724806 : f32} : (tensor<14x19x!quant.uniform>) -> tensor<*x!quant.uniform> @@ -2013,8 +2106,8 @@ func.func @test_leaky_relu_qi16(%arg0: tensor<14x19x!quant.uniform, mode = "BILINEAR", offset = array, scale = array} -// CHECK: %[[VAR2:.*]] = "tosa.rescale"(%[[VAR1]]) {double_round = false, input_zp = 0 : i32, multiplier = array, output_zp = 0 : i32, per_channel = false, scale32 = true, shift = array} +// CHECK-DAG: %[[VAR1:.*]] = "tosa.resize"(%arg0) <{border = array, mode = "BILINEAR", offset = array, scale = array}> +// CHECK: %[[VAR2:.*]] = "tosa.rescale"(%[[VAR1]]) <{double_round = false, input_zp = 0 : i32, multiplier = array, output_zp = 0 : i32, per_channel = false, scale32 = true, shift = array}> func.func @test_resize_bilinear_qi8(%arg0: tensor<1x80x80x2x!quant.uniform>) -> tensor<1x640x640x2x!quant.uniform> { %0 = "tfl.pseudo_const"() {value = dense<640> : tensor<2xi32>} : () -> tensor<2xi32> %1 = "tfl.resize_bilinear"(%arg0, %0) {align_corners = false, half_pixel_centers = false} : (tensor<1x80x80x2x!quant.uniform>, tensor<2xi32>) -> tensor<1x640x640x2x!quant.uniform> @@ -2024,7 +2117,7 @@ func.func @test_resize_bilinear_qi8(%arg0: tensor<1x80x80x2x!quant.uniform, mode = "BILINEAR", offset = array, scale = array} +// CHECK-DAG: %[[VAR1:.*]] = "tosa.resize"(%arg0) <{border = array, mode = "BILINEAR", offset = array, scale = array}> func.func @test_resize_bilinear_half_qi8(%arg0: tensor<1x80x80x2x!quant.uniform>) -> tensor<1x640x640x2x!quant.uniform> { %0 = "tfl.pseudo_const"() {value = dense<640> : tensor<2xi32>} : () -> 
tensor<2xi32> %1 = "tfl.resize_bilinear"(%arg0, %0) {align_corners = false, half_pixel_centers = true} : (tensor<1x80x80x2x!quant.uniform>, tensor<2xi32>) -> tensor<1x640x640x2x!quant.uniform> @@ -2034,7 +2127,7 @@ func.func @test_resize_bilinear_half_qi8(%arg0: tensor<1x80x80x2x!quant.uniform< // ----- // CHECK-LABEL: test_resize_bilinear_align_qi8 -// CHECK-DAG: %[[VAR1:.*]] = "tosa.resize"(%arg0) {border = array, mode = "BILINEAR", offset = array, scale = array} +// CHECK-DAG: %[[VAR1:.*]] = "tosa.resize"(%arg0) <{border = array, mode = "BILINEAR", offset = array, scale = array}> func.func @test_resize_bilinear_align_qi8(%arg0: tensor<1x80x80x2x!quant.uniform>) -> tensor<1x640x640x2x!quant.uniform> { %0 = "tfl.pseudo_const"() {value = dense<640> : tensor<2xi32>} : () -> tensor<2xi32> %1 = "tfl.resize_bilinear"(%arg0, %0) {align_corners = true, half_pixel_centers = false} : (tensor<1x80x80x2x!quant.uniform>, tensor<2xi32>) -> tensor<1x640x640x2x!quant.uniform> @@ -2044,7 +2137,7 @@ func.func @test_resize_bilinear_align_qi8(%arg0: tensor<1x80x80x2x!quant.uniform // ----- // CHECK-LABEL: test_resize_bilinear_align_half_qi8 -// CHECK-DAG: %[[VAR1:.*]] = "tosa.resize"(%arg0) {border = array, mode = "BILINEAR", offset = array, scale = array} +// CHECK-DAG: %[[VAR1:.*]] = "tosa.resize"(%arg0) <{border = array, mode = "BILINEAR", offset = array, scale = array}> func.func @test_resize_bilinear_align_half_qi8(%arg0: tensor<1x80x80x2x!quant.uniform>) -> tensor<1x640x640x2x!quant.uniform> { %0 = "tfl.pseudo_const"() {value = dense<640> : tensor<2xi32>} : () -> tensor<2xi32> %1 = "tfl.resize_bilinear"(%arg0, %0) {align_corners = true, half_pixel_centers = true} : (tensor<1x80x80x2x!quant.uniform>, tensor<2xi32>) -> tensor<1x640x640x2x!quant.uniform> @@ -2054,7 +2147,7 @@ func.func @test_resize_bilinear_align_half_qi8(%arg0: tensor<1x80x80x2x!quant.un // ----- // CHECK-LABEL: test_resize_nearest_qi8 -// CHECK: %[[VAR1:.*]] = "tosa.resize"(%arg0) {border = array, mode = "NEAREST_NEIGHBOR", offset = array, scale = array} +// CHECK: %[[VAR1:.*]] = "tosa.resize"(%arg0) <{border = array, mode = "NEAREST_NEIGHBOR", offset = array, scale = array}> func.func @test_resize_nearest_qi8(%arg0: tensor<1x80x80x2x!quant.uniform>) -> tensor<1x640x640x2x!quant.uniform> { %0 = "tfl.pseudo_const"() {value = dense<640> : tensor<2xi32>} : () -> tensor<2xi32> %1 = "tfl.resize_nearest_neighbor"(%arg0, %0) {align_corners = false, half_pixel_centers = false} : (tensor<1x80x80x2x!quant.uniform>, tensor<2xi32>) -> tensor<1x640x640x2x!quant.uniform> @@ -2065,7 +2158,7 @@ func.func @test_resize_nearest_qi8(%arg0: tensor<1x80x80x2x!quant.uniform, mode = "NEAREST_NEIGHBOR", offset = array, scale = array} +// CHECK: %[[VAR1:.*]] = "tosa.resize"(%arg0) <{border = array, mode = "NEAREST_NEIGHBOR", offset = array, scale = array}> func.func @test_resize_nearest_half_qi8(%arg0: tensor<1x80x80x2x!quant.uniform>) -> tensor<1x640x640x2x!quant.uniform> { %0 = "tfl.pseudo_const"() {value = dense<640> : tensor<2xi32>} : () -> tensor<2xi32> %1 = "tfl.resize_nearest_neighbor"(%arg0, %0) {align_corners = false, half_pixel_centers = true} : (tensor<1x80x80x2x!quant.uniform>, tensor<2xi32>) -> tensor<1x640x640x2x!quant.uniform> @@ -2075,7 +2168,7 @@ func.func @test_resize_nearest_half_qi8(%arg0: tensor<1x80x80x2x!quant.uniform, mode = "NEAREST_NEIGHBOR", offset = array, scale = array} +// CHECK: %[[VAR1:.*]] = "tosa.resize"(%arg0) <{border = array, mode = "NEAREST_NEIGHBOR", offset = array, scale = array}> func.func 
@test_resize_nearest_align_qi8(%arg0: tensor<1x80x80x2x!quant.uniform>) -> tensor<1x640x640x2x!quant.uniform> { %0 = "tfl.pseudo_const"() {value = dense<640> : tensor<2xi32>} : () -> tensor<2xi32> %1 = "tfl.resize_nearest_neighbor"(%arg0, %0) {align_corners = true, half_pixel_centers = false} : (tensor<1x80x80x2x!quant.uniform>, tensor<2xi32>) -> tensor<1x640x640x2x!quant.uniform> @@ -2085,7 +2178,7 @@ func.func @test_resize_nearest_align_qi8(%arg0: tensor<1x80x80x2x!quant.uniform< // ----- // CHECK-LABEL: test_resize_nearest_align_half_qi8 -// CHECK: %[[VAR1:.*]] = "tosa.resize"(%arg0) {border = array, mode = "NEAREST_NEIGHBOR", offset = array, scale = array} +// CHECK: %[[VAR1:.*]] = "tosa.resize"(%arg0) <{border = array, mode = "NEAREST_NEIGHBOR", offset = array, scale = array}> func.func @test_resize_nearest_align_half_qi8(%arg0: tensor<1x80x80x2x!quant.uniform>) -> tensor<1x640x640x2x!quant.uniform> { %0 = "tfl.pseudo_const"() {value = dense<640> : tensor<2xi32>} : () -> tensor<2xi32> %1 = "tfl.resize_nearest_neighbor"(%arg0, %0) {align_corners = true, half_pixel_centers = true} : (tensor<1x80x80x2x!quant.uniform>, tensor<2xi32>) -> tensor<1x640x640x2x!quant.uniform> @@ -2094,12 +2187,74 @@ func.func @test_resize_nearest_align_half_qi8(%arg0: tensor<1x80x80x2x!quant.uni // ----- +// CHECK-LABEL: test_resize_bilinear_f32_scalar_input +// CHECK: %[[VAL_1:.*]] = "tosa.resize"(%arg0) <{border = array, mode = "BILINEAR", offset = array, scale = array}> +func.func @test_resize_bilinear_f32_scalar_input(%arg0: tensor<3x1x1x7xf32>) -> tensor<3x2x2x7xf32> { + %0 = "tfl.pseudo_const"() {value = dense<2> : tensor<2xi32>} : () -> tensor<2xi32> + %1 = "tfl.resize_bilinear"(%arg0, %0) {align_corners = false, half_pixel_centers = false} : (tensor<3x1x1x7xf32>, tensor<2xi32>) -> tensor<3x2x2x7xf32> + func.return %1 : tensor<3x2x2x7xf32> +} + +// ----- + +// CHECK-LABEL: test_resize_bilinear_half_qi8_scalar_input +// CHECK: %[[VAL_1:.*]] = "tosa.resize"(%arg0) <{border = array, mode = "BILINEAR", offset = array, scale = array}> +// CHECK: %[[VAL_2:.*]] = "tosa.rescale"(%[[VAL_1]]) <{double_round = false, input_zp = 0 : i32, multiplier = array, output_zp = 0 : i32, per_channel = false, scale32 = true, shift = array}> +func.func @test_resize_bilinear_half_qi8_scalar_input(%arg0: tensor<3x1x1x7x!quant.uniform>) -> tensor<3x2x2x7x!quant.uniform> { + %0 = "tfl.pseudo_const"() {value = dense<2> : tensor<2xi32>} : () -> tensor<2xi32> + %1 = "tfl.resize_bilinear"(%arg0, %0) {align_corners = false, half_pixel_centers = true} : (tensor<3x1x1x7x!quant.uniform>, tensor<2xi32>) -> tensor<3x2x2x7x!quant.uniform> + func.return %1 : tensor<3x2x2x7x!quant.uniform> +} + +// ----- + +// CHECK-LABEL: test_resize_bilinear_align_qi8_scalar_input +// CHECK: %[[VAL_1:.*]] = "tosa.resize"(%arg0) <{border = array, mode = "BILINEAR", offset = array, scale = array}> +// CHECK: %[[VAL_2:.*]] = "tosa.rescale"(%[[VAL_1]]) <{double_round = false, input_zp = 0 : i32, multiplier = array, output_zp = 0 : i32, per_channel = false, scale32 = true, shift = array}> +func.func @test_resize_bilinear_align_qi8_scalar_input(%arg0: tensor<3x1x1x7x!quant.uniform>) -> tensor<3x2x2x7x!quant.uniform> { + %0 = "tfl.pseudo_const"() {value = dense<2> : tensor<2xi32>} : () -> tensor<2xi32> + %1 = "tfl.resize_bilinear"(%arg0, %0) {align_corners = true, half_pixel_centers = false} : (tensor<3x1x1x7x!quant.uniform>, tensor<2xi32>) -> tensor<3x2x2x7x!quant.uniform> + func.return %1 : tensor<3x2x2x7x!quant.uniform> +} + +// ----- + +// CHECK-LABEL: 
test_resize_nearest_f32_scalar_input +// CHECK: %[[VAL_1:.*]] = "tosa.resize"(%arg0) <{border = array, mode = "NEAREST_NEIGHBOR", offset = array, scale = array}> +func.func @test_resize_nearest_f32_scalar_input(%arg0: tensor<3x1x1x7xf32>) -> tensor<3x2x2x7xf32> { + %0 = "tfl.pseudo_const"() {value = dense<2> : tensor<2xi32>} : () -> tensor<2xi32> + %1 = "tfl.resize_nearest_neighbor"(%arg0, %0) {align_corners = false, half_pixel_centers = false} : (tensor<3x1x1x7xf32>, tensor<2xi32>) -> tensor<3x2x2x7xf32> + func.return %1 : tensor<3x2x2x7xf32> +} + +// ----- + +// CHECK-LABEL: test_resize_nearest_half_qi8_scalar_input +// CHECK: %[[VAL_1:.*]] = "tosa.resize"(%arg0) <{border = array, mode = "NEAREST_NEIGHBOR", offset = array, scale = array}> +func.func @test_resize_nearest_half_qi8_scalar_input(%arg0: tensor<3x1x1x7x!quant.uniform>) -> tensor<3x2x2x7x!quant.uniform> { + %0 = "tfl.pseudo_const"() {value = dense<2> : tensor<2xi32>} : () -> tensor<2xi32> + %1 = "tfl.resize_nearest_neighbor"(%arg0, %0) {align_corners = false, half_pixel_centers = true} : (tensor<3x1x1x7x!quant.uniform>, tensor<2xi32>) -> tensor<3x2x2x7x!quant.uniform> + func.return %1 : tensor<3x2x2x7x!quant.uniform> +} + +// ----- + +// CHECK-LABEL: test_resize_nearest_align_qi8_scalar_input +// CHECK: %[[VAL_1:.*]] = "tosa.resize"(%arg0) <{border = array, mode = "NEAREST_NEIGHBOR", offset = array, scale = array}> +func.func @test_resize_nearest_align_qi8_scalar_input(%arg0: tensor<3x1x1x7x!quant.uniform>) -> tensor<3x2x2x7x!quant.uniform> { + %0 = "tfl.pseudo_const"() {value = dense<2> : tensor<2xi32>} : () -> tensor<2xi32> + %1 = "tfl.resize_nearest_neighbor"(%arg0, %0) {align_corners = true, half_pixel_centers = false} : (tensor<3x1x1x7x!quant.uniform>, tensor<2xi32>) -> tensor<3x2x2x7x!quant.uniform> + func.return %1 : tensor<3x2x2x7x!quant.uniform> +} + +// ----- + // CHECK-LABEL: test_fullyconnected_qi8 -// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} -// CHECK-DAG: %[[VAR1:.*]] = "tosa.const"() {value = dense<0> : tensor<28xi32>} +// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() <{value = dense<[1, 0]> : tensor<2xi32>}> +// CHECK-DAG: %[[VAR1:.*]] = "tosa.const"() <{value = dense<0> : tensor<28xi32>}> // CHECK-DAG: %[[VAR2:.*]] = "tosa.transpose"(%arg1, %[[VAR0]]) -// CHECK-DAG: %[[VAR3:.*]] = "tosa.fully_connected"(%arg0, %[[VAR2]], %[[VAR1]]) {quantization_info = #tosa.conv_quant} -// CHECK: %[[VAR4:.*]] = "tosa.rescale"(%[[VAR3]]) {double_round = true, input_zp = 0 : i32, multiplier = array, output_zp = 3 : i32, per_channel = false, scale32 = true, shift = array} +// CHECK-DAG: %[[VAR3:.*]] = "tosa.fully_connected"(%arg0, %[[VAR2]], %[[VAR1]]) <{quantization_info = #tosa.conv_quant}> +// CHECK: %[[VAR4:.*]] = "tosa.rescale"(%[[VAR3]]) <{double_round = true, input_zp = 0 : i32, multiplier = array, output_zp = 3 : i32, per_channel = false, scale32 = true, shift = array}> func.func @test_fullyconnected_qi8(%arg0: tensor<14x19x!quant.uniform>, %arg1: tensor<19x28x!quant.uniform>) -> tensor<14x28x!quant.uniform> { %0 = "tfl.pseudo_const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32> %1 = "tfl.transpose"(%arg1, %0) : (tensor<19x28x!quant.uniform>, tensor<2xi32>) -> tensor<28x19x!quant.uniform> @@ -2110,10 +2265,10 @@ func.func @test_fullyconnected_qi8(%arg0: tensor<14x19x!quant.uniform} -// CHECK-DAG: %[[VAR5:.*]] = "tosa.reshape"(%arg1) {new_shape = array} +// CHECK-DAG: %[[VAR4:.*]] = "tosa.reshape"(%arg0) <{new_shape = array}> +// CHECK-DAG: %[[VAR5:.*]] = 
"tosa.reshape"(%arg1) <{new_shape = array}> // CHECK-DAG: %[[VAR6:.*]] = "tosa.gather"(%[[VAR4]], %[[VAR5]]) -// CHECK-DAG: %[[VAR7:.*]] = "tosa.reshape"(%[[VAR6]]) {new_shape = array} +// CHECK-DAG: %[[VAR7:.*]] = "tosa.reshape"(%[[VAR6]]) <{new_shape = array}> // CHECK: return %[[VAR7]] func.func @test_gather(%arg0: tensor<13x21x3xf32>, %arg1: tensor<7x7xi32>) -> tensor<*xf32> { %2 = "tfl.gather"(%arg0, %arg1) {axis = 0 : i32} : (tensor<13x21x3xf32>, tensor<7x7xi32>) -> tensor<*xf32> @@ -2122,10 +2277,10 @@ func.func @test_gather(%arg0: tensor<13x21x3xf32>, %arg1: tensor<7x7xi32>) -> te // ----- // CHECK-LABEL: test_gather_dyn -// CHECK-DAG: %[[VAR4:.*]] = "tosa.reshape"(%arg0) {new_shape = array} -// CHECK-DAG: %[[VAR5:.*]] = "tosa.reshape"(%arg1) {new_shape = array} +// CHECK-DAG: %[[VAR4:.*]] = "tosa.reshape"(%arg0) <{new_shape = array}> +// CHECK-DAG: %[[VAR5:.*]] = "tosa.reshape"(%arg1) <{new_shape = array}> // CHECK-DAG: %[[VAR6:.*]] = "tosa.gather"(%[[VAR4]], %[[VAR5]]) -// CHECK-DAG: %[[VAR7:.*]] = "tosa.reshape"(%[[VAR6]]) {new_shape = array} +// CHECK-DAG: %[[VAR7:.*]] = "tosa.reshape"(%[[VAR6]]) <{new_shape = array}> // CHECK: return %[[VAR7]] func.func @test_gather_dyn(%arg0: tensor, %arg1 : tensor<7x7xi32>) -> tensor<*xf32> { %2 = "tfl.gather"(%arg0, %arg1) {axis = 0 : i32} : (tensor, tensor<7x7xi32>) -> tensor<*xf32> @@ -2135,10 +2290,10 @@ func.func @test_gather_dyn(%arg0: tensor, %arg1 : tensor<7x7xi32>) - // ----- // CHECK-LABEL: test_gather_channel_dyn -// CHECK-DAG: %[[VAR4:.*]] = "tosa.reshape"(%arg0) {new_shape = array} -// CHECK-DAG: %[[VAR5:.*]] = "tosa.reshape"(%arg1) {new_shape = array} +// CHECK-DAG: %[[VAR4:.*]] = "tosa.reshape"(%arg0) <{new_shape = array}> +// CHECK-DAG: %[[VAR5:.*]] = "tosa.reshape"(%arg1) <{new_shape = array}> // CHECK-DAG: %[[VAR6:.*]] = "tosa.gather"(%[[VAR4]], %[[VAR5]]) -// CHECK-DAG: %[[VAR7:.*]] = "tosa.reshape"(%[[VAR6]]) {new_shape = array} +// CHECK-DAG: %[[VAR7:.*]] = "tosa.reshape"(%[[VAR6]]) <{new_shape = array}> // CHECK: return %[[VAR7]] func.func @test_gather_channel_dyn(%arg0: tensor<13x21x?xf32>, %arg1: tensor<7x7xi32>) -> tensor<*xf32> { %2 = "tfl.gather"(%arg0, %arg1) {axis = 0 : i32} : (tensor<13x21x?xf32>, tensor<7x7xi32>) -> tensor<*xf32> @@ -2147,10 +2302,10 @@ func.func @test_gather_channel_dyn(%arg0: tensor<13x21x?xf32>, %arg1: tensor<7x7 // ----- // CHECK-LABEL: test_gather_indices_dyn -// CHECK-DAG: %[[VAR4:.*]] = "tosa.reshape"(%arg0) {new_shape = array} -// CHECK-DAG: %[[VAR5:.*]] = "tosa.reshape"(%arg1) {new_shape = array} +// CHECK-DAG: %[[VAR4:.*]] = "tosa.reshape"(%arg0) <{new_shape = array}> +// CHECK-DAG: %[[VAR5:.*]] = "tosa.reshape"(%arg1) <{new_shape = array}> // CHECK-DAG: %[[VAR6:.*]] = "tosa.gather"(%[[VAR4]], %[[VAR5]]) -// CHECK-DAG: %[[VAR7:.*]] = "tosa.reshape"(%[[VAR6]]) {new_shape = array} +// CHECK-DAG: %[[VAR7:.*]] = "tosa.reshape"(%[[VAR6]]) <{new_shape = array}> // CHECK: return %[[VAR7]] func.func @test_gather_indices_dyn(%arg0: tensor<13x21x3xf32>, %arg1: tensor) -> tensor<*xf32> { %2 = "tfl.gather"(%arg0, %arg1) {axis = 0 : i32} : (tensor<13x21x3xf32>, tensor) -> tensor<*xf32> @@ -2160,9 +2315,9 @@ func.func @test_gather_indices_dyn(%arg0: tensor<13x21x3xf32>, %arg1: tensor} +// CHECK-DAG: %[[VAR1:.*]] = "tosa.reshape"(%arg0) <{new_shape = array}> // CHECK-DAG: %[[VAR2:.*]] = "tosa.gather"(%[[VAR1]], %[[VAR0]]) -// CHECK-DAG: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]]) {new_shape = array} +// CHECK-DAG: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]]) <{new_shape = array}> // CHECK: return %[[VAR3]] 
func.func @test_gather_batch(%arg0: tensor<1x4x4x4xi32>) -> tensor<1x3x4x4xi32> { %0 = "tfl.pseudo_const"() {value = dense<[[0, 3, 1]]> : tensor<1x3xi32>} : () -> tensor<1x3xi32> @@ -2172,9 +2327,9 @@ func.func @test_gather_batch(%arg0: tensor<1x4x4x4xi32>) -> tensor<1x3x4x4xi32> // ----- // CHECK-LABEL: test_gather_batch_dyn -// CHECK-DAG: %[[VAR1:.*]] = "tosa.reshape"(%arg0) {new_shape = array} +// CHECK-DAG: %[[VAR1:.*]] = "tosa.reshape"(%arg0) <{new_shape = array}> // CHECK-DAG: %[[VAR2:.*]] = "tosa.gather"(%[[VAR1]], %arg1) -// CHECK-DAG: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]]) {new_shape = array} +// CHECK-DAG: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]]) <{new_shape = array}> // CHECK: return %[[VAR3]] func.func @test_gather_batch_dyn(%arg0: tensor, %arg1: tensor) -> tensor { %1 = "tfl.gather"(%arg0, %arg1) {axis = 1 : i32, batch_dims = 1 : i32} : (tensor, tensor) -> tensor @@ -2184,13 +2339,13 @@ func.func @test_gather_batch_dyn(%arg0: tensor, %arg1: tensor} -// CHECK-DAG: %[[VAR3:.*]] = "tosa.reshape"(%arg1) {new_shape = array} -// CHECK-DAG: %[[VAR5:.*]] = "tosa.mul"(%[[VAR3]], %[[VAR1]]) {shift = 0 : i32} -// CHECK-DAG: %[[VAR6:.*]] = "tosa.reduce_sum"(%[[VAR5]]) {axis = 1 : i64} -// CHECK-DAG: %[[VAR7:.*]] = "tosa.reshape"(%[[VAR6]]) {new_shape = array} +// CHECK-DAG: %[[VAR2:.*]] = "tosa.reshape"(%arg0) <{new_shape = array}> +// CHECK-DAG: %[[VAR3:.*]] = "tosa.reshape"(%arg1) <{new_shape = array}> +// CHECK-DAG: %[[VAR5:.*]] = "tosa.mul"(%[[VAR3]], %[[VAR1]]) <{shift = 0 : i32}> +// CHECK-DAG: %[[VAR6:.*]] = "tosa.reduce_sum"(%[[VAR5]]) <{axis = 1 : i64}> +// CHECK-DAG: %[[VAR7:.*]] = "tosa.reshape"(%[[VAR6]]) <{new_shape = array}> // CHECK-DAG: %[[VAR8:.*]] = "tosa.gather"(%[[VAR2]], %[[VAR7]]) -// CHECK: %[[VAR9:.*]] = "tosa.reshape"(%[[VAR8]]) {new_shape = array} +// CHECK: %[[VAR9:.*]] = "tosa.reshape"(%[[VAR8]]) <{new_shape = array}> func.func @test_gather_nd(%arg0: tensor<13x21x3xf32>, %arg1: tensor<6x7x2xi32>) -> tensor<6x7x3xf32> { %1 = "tfl.gather_nd"(%arg0, %arg1) : (tensor<13x21x3xf32>, tensor<6x7x2xi32>) -> tensor<6x7x3xf32> func.return %1 : tensor<6x7x3xf32> @@ -2199,10 +2354,10 @@ func.func @test_gather_nd(%arg0: tensor<13x21x3xf32>, %arg1: tensor<6x7x2xi32>) // ----- // CHECK-LABEL: test_gather_cast // CHECK-DAG: %[[VAR1:.*]] = "tosa.cast"(%arg1) -// CHECK-DAG: %[[VAR2:.*]] = "tosa.reshape"(%arg0) {new_shape = array} -// CHECK-DAG: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR1]]) {new_shape = array} +// CHECK-DAG: %[[VAR2:.*]] = "tosa.reshape"(%arg0) <{new_shape = array}> +// CHECK-DAG: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR1]]) <{new_shape = array}> // CHECK-DAG: %[[VAR4:.*]] = "tosa.gather"(%[[VAR2]], %[[VAR3]]) -// CHECK-DAG: %[[VAR5:.*]] = "tosa.reshape"(%[[VAR4]]) {new_shape = array} +// CHECK-DAG: %[[VAR5:.*]] = "tosa.reshape"(%[[VAR4]]) <{new_shape = array}> // CHECK: return %[[VAR5]] func.func @test_gather_cast(%arg0: tensor<13x21x3xf32>, %arg1: tensor<7x7xi64>) -> tensor<*xf32> { %2 = "tfl.gather"(%arg0, %arg1) {axis = 0 : i32} : (tensor<13x21x3xf32>, tensor<7x7xi64>) -> tensor<*xf32> @@ -2211,15 +2366,15 @@ func.func @test_gather_cast(%arg0: tensor<13x21x3xf32>, %arg1: tensor<7x7xi64>) // ----- -// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() {value = dense<{{\[\[}}48, 1]]> : tensor<1x2xi32>} -// CHECK-DAG: %[[VAR1:.*]] = "tosa.const"() {value = dense<-1> : tensor<1x48x1xi64>} +// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() <{value = dense<{{\[\[}}48, 1]]> : tensor<1x2xi32>}> +// CHECK-DAG: %[[VAR1:.*]] = "tosa.const"() <{value = dense<-1> : tensor<1x48x1xi64>}> // 
CHECK-DAG: %[[VAR2:.*]] = "tosa.cast"(%arg0) -// CHECK-DAG: %[[VAR4:.*]] = "tosa.mul"(%[[VAR2]], %[[VAR0]]) {shift = 0 : i32} -// CHECK-DAG: %[[VAR5:.*]] = "tosa.reduce_sum"(%[[VAR4]]) {axis = 1 : i64} -// CHECK-DAG: %[[VAR6:.*]] = "tosa.reshape"(%arg1) {new_shape = array} -// CHECK-DAG: %[[VAR7:.*]] = "tosa.reshape"(%[[VAR5]]) {new_shape = array} +// CHECK-DAG: %[[VAR4:.*]] = "tosa.mul"(%[[VAR2]], %[[VAR0]]) <{shift = 0 : i32}> +// CHECK-DAG: %[[VAR5:.*]] = "tosa.reduce_sum"(%[[VAR4]]) <{axis = 1 : i64}> +// CHECK-DAG: %[[VAR6:.*]] = "tosa.reshape"(%arg1) <{new_shape = array}> +// CHECK-DAG: %[[VAR7:.*]] = "tosa.reshape"(%[[VAR5]]) <{new_shape = array}> // CHECK-DAG: %[[VAR8:.*]] = "tosa.scatter"(%[[VAR1]], %[[VAR7]], %[[VAR6]]) -// CHECK-DAG: %[[VAR9:.*]] = "tosa.reshape"(%[[VAR8]]) {new_shape = array} +// CHECK-DAG: %[[VAR9:.*]] = "tosa.reshape"(%[[VAR8]]) <{new_shape = array}> // CHECK: return %[[VAR9]] func.func @sparse_to_dense(%arg0 : tensor, %arg1 : tensor) -> (tensor<1x48xi64>) { %0 = arith.constant dense<[1, 48]> : tensor<2xi64> @@ -2232,7 +2387,7 @@ func.func @sparse_to_dense(%arg0 : tensor, %arg1 : tensor) -> (t // CHECK-LABEL: @test_arg_max func.func @test_arg_max(%arg0: tensor<13x21x3xf32>) -> tensor<*xf32> { - // CHECK: %[[ARGMAX:.+]] = "tosa.argmax"(%arg0) {axis = 1 : i64} + // CHECK: %[[ARGMAX:.+]] = "tosa.argmax"(%arg0) <{axis = 1 : i64}> %0 = "tfl.pseudo_const"() {value = dense<1> : tensor} : () -> tensor %1 = "tfl.arg_max"(%arg0, %0) : (tensor<13x21x3xf32>, tensor) -> tensor<*xf32> func.return %1 : tensor<*xf32> @@ -2242,7 +2397,7 @@ func.func @test_arg_max(%arg0: tensor<13x21x3xf32>) -> tensor<*xf32> { // CHECK-LABEL: @test_arg_max_negative_dim func.func @test_arg_max_negative_dim(%arg0: tensor<13x21x3xf32>) -> tensor<13x21xf32> { - // CHECK: %[[ARGMAX:.+]] = "tosa.argmax"(%arg0) {axis = 2 : i64} + // CHECK: %[[ARGMAX:.+]] = "tosa.argmax"(%arg0) <{axis = 2 : i64}> %0 = "tfl.pseudo_const"() {value = dense<-1> : tensor} : () -> tensor %1 = "tfl.arg_max"(%arg0, %0) : (tensor<13x21x3xf32>, tensor) -> tensor<13x21xf32> func.return %1 : tensor<13x21xf32> @@ -2253,7 +2408,7 @@ func.func @test_arg_max_negative_dim(%arg0: tensor<13x21x3xf32>) -> tensor<13x21 // CHECK-LABEL: @test_arg_min_f32 func.func @test_arg_min_f32(%arg0: tensor<13x21x3xf32>) -> tensor<*xf32> { // CHECK: %[[NEG:.+]] = "tosa.negate"(%arg0) : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32> - // CHECK: "tosa.argmax"(%[[NEG]]) {axis = 1 : i64} + // CHECK: "tosa.argmax"(%[[NEG]]) <{axis = 1 : i64}> %0 = "tfl.pseudo_const"() {value = dense<1> : tensor} : () -> tensor %1 = "tfl.arg_min"(%arg0, %0) : (tensor<13x21x3xf32>, tensor) -> tensor<*xf32> func.return %1 : tensor<*xf32> @@ -2263,9 +2418,9 @@ func.func @test_arg_min_f32(%arg0: tensor<13x21x3xf32>) -> tensor<*xf32> { // CHECK-LABEL: @test_arg_min_i32 func.func @test_arg_min_i32(%arg0: tensor<13x21x3xi32>) -> tensor<*xi32> { - // CHECK: %[[ONE:.+]] = "tosa.const"() {value = dense<-1> : tensor<1x1x1xi32>} + // CHECK: %[[ONE:.+]] = "tosa.const"() <{value = dense<-1> : tensor<1x1x1xi32>}> // CHECK: %[[SUB:.+]] = "tosa.sub"(%[[ONE]], %arg0) - // CHECK: %[[ARGMAX:.+]] = "tosa.argmax"(%[[SUB]]) {axis = 1 : i64} + // CHECK: %[[ARGMAX:.+]] = "tosa.argmax"(%[[SUB]]) <{axis = 1 : i64}> %0 = "tfl.pseudo_const"() {value = dense<1> : tensor} : () -> tensor %1 = "tfl.arg_min"(%arg0, %0) : (tensor<13x21x3xi32>, tensor) -> tensor<*xi32> func.return %1 : tensor<*xi32> @@ -2275,9 +2430,9 @@ func.func @test_arg_min_i32(%arg0: tensor<13x21x3xi32>) -> tensor<*xi32> { // CHECK-LABEL: 
@test_arg_min_ui8 func.func @test_arg_min_ui8(%arg0: tensor<13x21x3xui8>) -> tensor<*xui8> { - // CHECK: %[[MAX:.+]] = "tosa.const"() {value = dense<255> : tensor<1x1x1xui8>} + // CHECK: %[[MAX:.+]] = "tosa.const"() <{value = dense<255> : tensor<1x1x1xui8>}> // CHECK: %[[SUB:.+]] = "tosa.sub"(%[[MAX]], %arg0) - // CHECK: %[[ARGMAX:.+]] = "tosa.argmax"(%[[SUB]]) {axis = 1 : i64} + // CHECK: %[[ARGMAX:.+]] = "tosa.argmax"(%[[SUB]]) <{axis = 1 : i64}> %0 = "tfl.pseudo_const"() {value = dense<1> : tensor} : () -> tensor %1 = "tfl.arg_min"(%arg0, %0) : (tensor<13x21x3xui8>, tensor) -> tensor<*xui8> func.return %1 : tensor<*xui8> @@ -2286,18 +2441,18 @@ func.func @test_arg_min_ui8(%arg0: tensor<13x21x3xui8>) -> tensor<*xui8> { // ----- // CHECK-LABEL: test_fakequant -// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() {value = dense<-2.00003052> : tensor<1x1x1xf32>} -// CHECK-DAG: %[[VAR1:.*]] = "tosa.const"() {value = dense<1.99996948> : tensor<1x1x1xf32>} -// CHECK-DAG: %[[VAR2:.*]] = "tosa.const"() {value = dense<6.10360876E-5> : tensor<1x1x1xf32>} -// CHECK-DAG: %[[VAR3:.*]] = "tosa.const"() {value = dense<16383.75> : tensor<1x1x1xf32>} -// CHECK-DAG: %[[VAR4:.*]] = "tosa.const"() {value = dense<5.000000e-01> : tensor<1x1x1xf32>} +// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() <{value = dense<-2.00003052> : tensor<1x1x1xf32>}> +// CHECK-DAG: %[[VAR1:.*]] = "tosa.const"() <{value = dense<1.99996948> : tensor<1x1x1xf32>}> +// CHECK-DAG: %[[VAR2:.*]] = "tosa.const"() <{value = dense<6.10360876E-5> : tensor<1x1x1xf32>}> +// CHECK-DAG: %[[VAR3:.*]] = "tosa.const"() <{value = dense<16383.75> : tensor<1x1x1xf32>}> +// CHECK-DAG: %[[VAR4:.*]] = "tosa.const"() <{value = dense<5.000000e-01> : tensor<1x1x1xf32>}> // CHECK-DAG: %[[VAR6:.*]] = "tosa.minimum"(%arg0, %[[VAR1]]) // CHECK-DAG: %[[VAR8:.*]] = "tosa.maximum"(%[[VAR6]], %[[VAR0]]) // CHECK-DAG: %[[VAR10:.*]] = "tosa.sub"(%[[VAR8]], %[[VAR0]]) -// CHECK-DAG: %[[VAR12:.*]] = "tosa.mul"(%[[VAR10]], %[[VAR3]]) {shift = 0 : i32} +// CHECK-DAG: %[[VAR12:.*]] = "tosa.mul"(%[[VAR10]], %[[VAR3]]) <{shift = 0 : i32}> // CHECK-DAG: %[[VAR14:.*]] = "tosa.add"(%[[VAR12]], %[[VAR4]]) // CHECK-DAG: %[[VAR15:.*]] = "tosa.floor"(%[[VAR14]]) -// CHECK-DAG: %[[VAR17:.*]] = "tosa.mul"(%[[VAR15]], %[[VAR2]]) {shift = 0 : i32} +// CHECK-DAG: %[[VAR17:.*]] = "tosa.mul"(%[[VAR15]], %[[VAR2]]) <{shift = 0 : i32}> // CHECK: %[[VAR19:.*]] = "tosa.add"(%[[VAR17]], %[[VAR0]]) func.func @test_fakequant(%arg0: tensor<13x21x3xf32>) -> tensor<*xf32> { %2 = "tfl.fake_quant"(%arg0) {max = 2.000000e+00 : f32, min = -2.000000e+00 : f32, narrow_range = false, num_bits = 16 : i32} : (tensor<13x21x3xf32>) -> tensor<*xf32> @@ -2307,14 +2462,12 @@ func.func @test_fakequant(%arg0: tensor<13x21x3xf32>) -> tensor<*xf32> { // ----- // CHECK-LABEL: @test_fullyconnected_hybrid -func.func @test_fullyconnected_hybrid(%arg0: tensor<14x19xf32>) -> tensor<*xf32> { +func.func @test_fullyconnected_hybrid(%arg0: tensor<14x19xf32>, %arg1: tensor<28x19x!quant.uniform>, %arg2: tensor<28xf32>) -> tensor<*xf32> { // This verifies that the constant is decomposed into a dequantization via a // cast, subtract, and multiplication. 
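// That is, the hybrid (float activation, quantized weight) path is expected to
// materialize roughly float_weight = (cast(quant_weight) - zero_point) * scale
// before the float fully_connected, hence the cast/sub/mul sequence checked below.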
// CHECK: "tosa.sub" // CHECK: "tosa.fully_connected" - %0 = "tfl.pseudo_qconst"() {qtype = tensor<36x36x!quant.uniform>, value = dense<42> : tensor<28x19xi8>} : () -> tensor<28x19x!quant.uniform> - %1 = "tfl.pseudo_const"() {value = dense<0.0> : tensor<28xf32>} : () -> tensor<28xf32> - %2 = "tfl.fully_connected"(%arg0, %0, %1) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<14x19xf32>, tensor<28x19x!quant.uniform>, tensor<28xf32>) -> tensor<*xf32> + %2 = "tfl.fully_connected"(%arg0, %arg1, %arg2) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<14x19xf32>, tensor<28x19x!quant.uniform>, tensor<28xf32>) -> tensor<*xf32> func.return %2 : tensor<*xf32> } @@ -2355,19 +2508,19 @@ func.func @test_squeeze_neg(%arg0: tensor<2x1x3x1xf32>) -> tensor<2x1x3xf32> { // CHECK-LABEL: test_gelu // CHECK-SAME: %[[VAL_0:.*]]: tensor<1x4x8x19xf32> -// CHECK-DAG: %[[VAL_1:.*]] = "tosa.const"() {value = dense<3.000000e+00> : tensor<1x1x1x1xf32>} -// CHECK-DAG: %[[VAL_2:.*]] = "tosa.const"() {value = dense<4.471500e-02> : tensor<1x1x1x1xf32>} -// CHECK-DAG: %[[VAL_3:.*]] = "tosa.const"() {value = dense<0.797884583> : tensor<1x1x1x1xf32>} -// CHECK-DAG: %[[VAL_4:.*]] = "tosa.const"() {value = dense<1.000000e+00> : tensor<1x1x1x1xf32>} -// CHECK-DAG: %[[VAL_5:.*]] = "tosa.const"() {value = dense<5.000000e-01> : tensor<1x1x1x1xf32>} +// CHECK-DAG: %[[VAL_1:.*]] = "tosa.const"() <{value = dense<3.000000e+00> : tensor<1x1x1x1xf32>}> +// CHECK-DAG: %[[VAL_2:.*]] = "tosa.const"() <{value = dense<4.471500e-02> : tensor<1x1x1x1xf32>}> +// CHECK-DAG: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<0.797884583> : tensor<1x1x1x1xf32>}> +// CHECK-DAG: %[[VAL_4:.*]] = "tosa.const"() <{value = dense<1.000000e+00> : tensor<1x1x1x1xf32>}> +// CHECK-DAG: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<5.000000e-01> : tensor<1x1x1x1xf32>}> // CHECK: %[[VAL_6:.*]] = "tosa.pow"(%[[VAL_0]], %[[VAL_1]]) -// CHECK: %[[VAL_7:.*]] = "tosa.mul"(%[[VAL_6]], %[[VAL_2]]) {shift = 0 : i32} +// CHECK: %[[VAL_7:.*]] = "tosa.mul"(%[[VAL_6]], %[[VAL_2]]) <{shift = 0 : i32}> // CHECK: %[[VAL_8:.*]] = "tosa.add"(%[[VAL_0]], %[[VAL_7]]) -// CHECK: %[[VAL_9:.*]] = "tosa.mul"(%[[VAL_8]], %[[VAL_3]]) {shift = 0 : i32} +// CHECK: %[[VAL_9:.*]] = "tosa.mul"(%[[VAL_8]], %[[VAL_3]]) <{shift = 0 : i32}> // CHECK: %[[VAL_10:.*]] = "tosa.tanh"(%[[VAL_9]]) // CHECK: %[[VAL_11:.*]] = "tosa.add"(%[[VAL_10]], %[[VAL_4]]) -// CHECK: %[[VAL_12:.*]] = "tosa.mul"(%[[VAL_0]], %[[VAL_5]]) {shift = 0 : i32} -// CHECK: %[[VAL_13:.*]] = "tosa.mul"(%[[VAL_12]], %[[VAL_11]]) {shift = 0 : i32} +// CHECK: %[[VAL_12:.*]] = "tosa.mul"(%[[VAL_0]], %[[VAL_5]]) <{shift = 0 : i32}> +// CHECK: %[[VAL_13:.*]] = "tosa.mul"(%[[VAL_12]], %[[VAL_11]]) <{shift = 0 : i32}> func.func @test_gelu(%arg0: tensor<1x4x8x19xf32>) -> tensor<1x4x8x19xf32> { %0 = "tfl.gelu"(%arg0) {approximate = true} : (tensor<1x4x8x19xf32>) -> tensor<1x4x8x19xf32> func.return %0 : tensor<1x4x8x19xf32> @@ -2377,8 +2530,8 @@ func.func @test_gelu(%arg0: tensor<1x4x8x19xf32>) -> tensor<1x4x8x19xf32> { // CHECK-LABEL: test_gelu_qi8 // CHECK-SAME: %[[VAR0:.*]]: tensor<1x4x4x4x!quant.uniform> -// CHECK: %[[VAR1:.*]] = "tosa.const"() {value = dense<{{.*}}> : tensor<256xi8>} -// CHECK: %[[VAR2:.*]] = "tosa.table"(%[[VAR0]], %[[VAR1]]) : (tensor<1x4x4x4x!quant.uniform>, tensor<256xi8>) +// CHECK: %[[VAR1:.*]] = "tosa.const"() <{value = dense<{{.*}}> : tensor<256xi8>}> +// CHECK: %[[VAR2:.*]] = "tosa.table"(%[[VAR0]], %[[VAR1]]) : 
(tensor<1x4x4x4x!quant.uniform>, tensor<256x!quant.uniform>) func.func @test_gelu_qi8(%arg0: tensor<1x4x4x4x!quant.uniform>) -> tensor<1x4x4x4x!quant.uniform> { %0 = "tfl.gelu"(%arg0) {approximate = true} : (tensor<1x4x4x4x!quant.uniform>) -> tensor<1x4x4x4x!quant.uniform> func.return %0 : tensor<1x4x4x4x!quant.uniform> @@ -2388,14 +2541,14 @@ func.func @test_gelu_qi8(%arg0: tensor<1x4x4x4x!quant.uniform> -// CHECK: %[[VAL_1:.*]] = "tosa.slice"(%[[VAL_0]]) {size = array, start = array} : (tensor<4x9x!quant.uniform>) -// CHECK: %[[VAL_2:.*]] = "tosa.reverse"(%[[VAL_1]]) {axis = 0 : i64} : (tensor<2x9x!quant.uniform>) -// CHECK: %[[VAL_3:.*]] = "tosa.slice"(%[[VAL_0]]) {size = array, start = array} : (tensor<4x9x!quant.uniform>) -// CHECK: %[[VAL_4:.*]] = "tosa.concat"(%[[VAL_2]], %[[VAL_0]], %[[VAL_3]]) {axis = 0 : i64} : (tensor<2x9x!quant.uniform>, tensor<4x9x!quant.uniform>, tensor<1x9x!quant.uniform>) -// CHECK: %[[VAL_5:.*]] = "tosa.slice"(%[[VAL_4]]) {size = array, start = array} : (tensor<7x9x!quant.uniform>) -// CHECK: %[[VAL_6:.*]] = "tosa.reverse"(%[[VAL_5]]) {axis = 1 : i64} : (tensor<7x2x!quant.uniform>) -// CHECK: %[[VAL_7:.*]] = "tosa.slice"(%[[VAL_4]]) {size = array, start = array} : (tensor<7x9x!quant.uniform>) -// CHECK: %[[VAL_8:.*]] = "tosa.concat"(%[[VAL_6]], %[[VAL_4]], %[[VAL_7]]) {axis = 1 : i64} : (tensor<7x2x!quant.uniform>, tensor<7x9x!quant.uniform>, tensor<7x1x!quant.uniform>) +// CHECK: %[[VAL_1:.*]] = "tosa.slice"(%[[VAL_0]]) <{size = array, start = array}> : (tensor<4x9x!quant.uniform>) +// CHECK: %[[VAL_2:.*]] = "tosa.reverse"(%[[VAL_1]]) <{axis = 0 : i64}> : (tensor<2x9x!quant.uniform>) +// CHECK: %[[VAL_3:.*]] = "tosa.slice"(%[[VAL_0]]) <{size = array, start = array}> : (tensor<4x9x!quant.uniform>) +// CHECK: %[[VAL_4:.*]] = "tosa.concat"(%[[VAL_2]], %[[VAL_0]], %[[VAL_3]]) <{axis = 0 : i64}> : (tensor<2x9x!quant.uniform>, tensor<4x9x!quant.uniform>, tensor<1x9x!quant.uniform>) +// CHECK: %[[VAL_5:.*]] = "tosa.slice"(%[[VAL_4]]) <{size = array, start = array}> : (tensor<7x9x!quant.uniform>) +// CHECK: %[[VAL_6:.*]] = "tosa.reverse"(%[[VAL_5]]) <{axis = 1 : i64}> : (tensor<7x2x!quant.uniform>) +// CHECK: %[[VAL_7:.*]] = "tosa.slice"(%[[VAL_4]]) <{size = array, start = array}> : (tensor<7x9x!quant.uniform>) +// CHECK: %[[VAL_8:.*]] = "tosa.concat"(%[[VAL_6]], %[[VAL_4]], %[[VAL_7]]) <{axis = 1 : i64}> : (tensor<7x2x!quant.uniform>, tensor<7x9x!quant.uniform>, tensor<7x1x!quant.uniform>) func.func @mirrorpad_reflect(%arg0: tensor<4x9x!quant.uniform>) -> tensor<7x12x!quant.uniform> { %0 = "tfl.pseudo_const"() {value = dense<[[2, 1], [2, 1]]> : tensor<2x2xi32>} : () -> tensor<2x2xi32> %1 = "tfl.mirror_pad"(%arg0, %0) {mode = #tfl} : (tensor<4x9x!quant.uniform>, tensor<2x2xi32>) -> tensor<7x12x!quant.uniform> @@ -2406,12 +2559,12 @@ func.func @mirrorpad_reflect(%arg0: tensor<4x9x!quant.uniform -// CHECK: %[[VAL_1:.*]] = "tosa.slice"(%[[VAL_0]]) {size = array, start = array} : (tensor<15x23x2xf32>) -// CHECK: %[[VAL_2:.*]] = "tosa.concat"(%[[VAL_1]], %[[VAL_0]]) {axis = 0 : i64} : (tensor<1x23x2xf32>, tensor<15x23x2xf32>) -// CHECK: %[[VAL_3:.*]] = "tosa.slice"(%[[VAL_2]]) {size = array, start = array} : (tensor<16x23x2xf32>) -// CHECK: %[[VAL_4:.*]] = "tosa.concat"(%[[VAL_3]], %[[VAL_2]]) {axis = 1 : i64} : (tensor<16x1x2xf32>, tensor<16x23x2xf32>) -// CHECK: %[[VAL_5:.*]] = "tosa.slice"(%[[VAL_4]]) {size = array, start = array} : (tensor<16x24x2xf32>) -// CHECK: %[[VAL_6:.*]] = "tosa.concat"(%[[VAL_5]], %[[VAL_4]]) {axis = 2 : i64} : (tensor<16x24x1xf32>, 
tensor<16x24x2xf32>) +// CHECK: %[[VAL_1:.*]] = "tosa.slice"(%[[VAL_0]]) <{size = array, start = array}> : (tensor<15x23x2xf32>) +// CHECK: %[[VAL_2:.*]] = "tosa.concat"(%[[VAL_1]], %[[VAL_0]]) <{axis = 0 : i64}> : (tensor<1x23x2xf32>, tensor<15x23x2xf32>) +// CHECK: %[[VAL_3:.*]] = "tosa.slice"(%[[VAL_2]]) <{size = array, start = array}> : (tensor<16x23x2xf32>) +// CHECK: %[[VAL_4:.*]] = "tosa.concat"(%[[VAL_3]], %[[VAL_2]]) <{axis = 1 : i64}> : (tensor<16x1x2xf32>, tensor<16x23x2xf32>) +// CHECK: %[[VAL_5:.*]] = "tosa.slice"(%[[VAL_4]]) <{size = array, start = array}> : (tensor<16x24x2xf32>) +// CHECK: %[[VAL_6:.*]] = "tosa.concat"(%[[VAL_5]], %[[VAL_4]]) <{axis = 2 : i64}> : (tensor<16x24x1xf32>, tensor<16x24x2xf32>) func.func @mirrorpad_symmetric(%arg0: tensor<15x23x2xf32>) -> tensor<16x24x3xf32> { %0 = "tfl.pseudo_const"() {value = dense<[[1, 0], [1, 0], [1, 0]]> : tensor<3x2xi32>} : () -> tensor<3x2xi32> %1 = "tfl.mirror_pad"(%arg0, %0) {mode = #tfl} : (tensor<15x23x2xf32>, tensor<3x2xi32>) -> tensor<16x24x3xf32> @@ -2422,8 +2575,8 @@ func.func @mirrorpad_symmetric(%arg0: tensor<15x23x2xf32>) -> tensor<16x24x3xf32 // CHECK-LABEL: @test_reverse_works func.func @test_reverse_works(%arg0: tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32> { - // CHECK: %[[VAL0:.+]] = "tosa.reverse"(%arg0) {axis = 1 : i64} - // CHECK: %[[VAL1:.+]] = "tosa.reverse"(%[[VAL0]]) {axis = 2 : i64} + // CHECK: %[[VAL0:.+]] = "tosa.reverse"(%arg0) <{axis = 1 : i64}> + // CHECK: %[[VAL1:.+]] = "tosa.reverse"(%[[VAL0]]) <{axis = 2 : i64}> %0 = "tfl.pseudo_const"() {value = dense<[1, -2]> : tensor<2xi32>} : () -> tensor<2xi32> %1 = "tfl.reverse_v2"(%arg0, %0): (tensor<1x2x3x4xf32>, tensor<2xi32>) -> tensor<1x2x3x4xf32> func.return %1 : tensor<1x2x3x4xf32> @@ -2443,7 +2596,7 @@ func.func @test_reverse_fail(%arg0: tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32> // CHECK-LABEL: test_tfl_custom // CHECK-SAME: %[[ARG_0:.*]]: tensor<1x64x64x32xf32> -// CHECK: %[[VAL_0:.*]] = "tosa.custom"(%[[ARG_0]]) {config = "TFL", identifier = "MaxPoolingWithArgmax2D", implementation_attrs = "{{.*}}"} : (tensor<1x64x64x32xf32>) -> (tensor<1x32x32x32xf32>, tensor<1x32x32x32xf32>) +// CHECK: %[[VAL_0:.*]] = "tosa.custom"(%[[ARG_0]]) <{config = "TFL", identifier = "MaxPoolingWithArgmax2D", implementation_attrs = "{{.*}}"}> : (tensor<1x64x64x32xf32>) -> (tensor<1x32x32x32xf32>, tensor<1x32x32x32xf32>) func.func @test_tfl_custom(%arg0: tensor<1x64x64x32xf32>) -> (tensor<1x32x32x32xf32>, tensor<1x32x32x32xf32>) { // custom op for "tfl.max_pooling_with_argmax_2d"(%arg0) {filter_h = 2 : i32, filter_w = 2 : i32, padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<1x64x64x32xf32>) -> (tensor<1x32x32x32xf32>, tensor<1x32x32x32xf32>) %0, %1 = "tfl.custom"(%arg0) {custom_option = #tfl, custom_code = "MaxPoolingWithArgmax2D"} : (tensor<1x64x64x32xf32>) -> (tensor<1x32x32x32xf32>, tensor<1x32x32x32xf32>) @@ -2453,15 +2606,15 @@ func.func @test_tfl_custom(%arg0: tensor<1x64x64x32xf32>) -> (tensor<1x32x32x32x // ----- // CHECK-LABEL: test_tfl_while_loop // CHECK: %[[VAL_0:.*]]: tensor<1x4x4x4xf32> {tf_saved_model.index_path = ["placeholder_0"]}) -> (tensor<1x4x4x4xf32> {tf_saved_model.index_path = ["output_0"]}) { -// CHECK: %[[VAL_1:.*]] = "tosa.const"() {value = dense<2.000000e+00> : tensor<1xf32>} : () -> tensor<1xf32> +// CHECK: %[[VAL_1:.*]] = "tosa.const"() <{value = dense<2.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32> // CHECK: %[[VAL_2:.*]] = "tosa.while_loop"(%[[VAL_0]]) ({ // CHECK: ^bb0(%[[VAL_3:.*]]: tensor<1x4x4x4xf32>): -// 
CHECK: %[[VAL_4:.*]] = "tosa.reduce_sum"(%[[VAL_3]]) {axis = 1 : i64} : (tensor<1x4x4x4xf32>) -> tensor<1x1x4x4xf32> -// CHECK: %[[VAL_5:.*]] = "tosa.reduce_sum"(%[[VAL_4]]) {axis = 2 : i64} : (tensor<1x1x4x4xf32>) -> tensor<1x1x1x4xf32> -// CHECK: %[[VAL_6:.*]] = "tosa.reduce_sum"(%[[VAL_5]]) {axis = 3 : i64} : (tensor<1x1x1x4xf32>) -> tensor<1x1x1x1xf32> -// CHECK: %[[VAL_7:.*]] = "tosa.reshape"(%[[VAL_6]]) {new_shape = array} : (tensor<1x1x1x1xf32>) -> tensor<1xf32> +// CHECK: %[[VAL_4:.*]] = "tosa.reduce_sum"(%[[VAL_3]]) <{axis = 1 : i64}> : (tensor<1x4x4x4xf32>) -> tensor<1x1x4x4xf32> +// CHECK: %[[VAL_5:.*]] = "tosa.reduce_sum"(%[[VAL_4]]) <{axis = 2 : i64}> : (tensor<1x1x4x4xf32>) -> tensor<1x1x1x4xf32> +// CHECK: %[[VAL_6:.*]] = "tosa.reduce_sum"(%[[VAL_5]]) <{axis = 3 : i64}> : (tensor<1x1x1x4xf32>) -> tensor<1x1x1x1xf32> +// CHECK: %[[VAL_7:.*]] = "tosa.reshape"(%[[VAL_6]]) <{new_shape = array}> : (tensor<1x1x1x1xf32>) -> tensor<1xf32> // CHECK: %[[VAL_8:.*]] = "tosa.greater"(%[[VAL_1]], %[[VAL_7]]) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xi1> -// CHECK: %[[VAL_9:.*]] = "tosa.reshape"(%[[VAL_8]]) {new_shape = array} : (tensor<1xi1>) -> tensor +// CHECK: %[[VAL_9:.*]] = "tosa.reshape"(%[[VAL_8]]) <{new_shape = array}> : (tensor<1xi1>) -> tensor // CHECK: "tosa.yield"(%[[VAL_9]]) : (tensor) -> () // CHECK: }, { // CHECK: ^bb0(%[[VAL_10:.*]]: tensor<1x4x4x4xf32>): @@ -2500,10 +2653,82 @@ func.func private @result_body(%arg0: tensor<1x4x4x4xf32>) -> tensor<1x4x4x4xf32 // ----- +// CHECK-LABEL: test_rfft2d +// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x8x16xf32> +// CHECK: %[[VAL_1:.*]], %[[VAL_2:.*]] = "tosa.rfft2d"(%[[VAL_0]]) : (tensor<1x8x16xf32>) -> (tensor<1x8x9xf32>, tensor<1x8x9xf32>) +// CHECK: %[[VAL_3:.*]] = "tosa.reshape"(%[[VAL_1]]) <{new_shape = array}> : (tensor<1x8x9xf32>) -> tensor<1x8x9x1xf32> +// CHECK: %[[VAL_4:.*]] = "tosa.reshape"(%[[VAL_2]]) <{new_shape = array}> : (tensor<1x8x9xf32>) -> tensor<1x8x9x1xf32> +// CHECK: %[[VAL_5:.*]] = "tosa.concat"(%[[VAL_3]], %[[VAL_4]]) <{axis = 3 : i64}> : (tensor<1x8x9x1xf32>, tensor<1x8x9x1xf32>) -> tensor<1x8x9x2xf32> +// CHECK: return %[[VAL_5]] : tensor<1x8x9x2xf32> +func.func @test_rfft2d(%arg0: tensor<1x8x16xf32>) -> tensor<1x8x9xcomplex> { + %0 = "tfl.pseudo_const"() {value = dense<[8, 16]> : tensor<2xi32>} : () -> tensor<2xi32> + %1 = "tfl.rfft2d"(%arg0, %0) : (tensor<1x8x16xf32>, tensor<2xi32>) -> tensor<1x8x9xcomplex> + return %1 : tensor<1x8x9xcomplex> +} + +// ----- + +// CHECK-LABEL: test_rfft2d_crop_input +// CHECK-SAME: %[[VAL_0:.*]]: tensor<13x21x3xf32> +// CHECK: %[[VAL_1:.*]] = "tosa.slice"(%[[VAL_0]]) <{size = array, start = array}> : (tensor<13x21x3xf32>) -> tensor<13x2x2xf32> +// CHECK: %[[VAL_2:.*]], %[[VAL_3:.*]] = "tosa.rfft2d"(%[[VAL_1]]) : (tensor<13x2x2xf32>) -> (tensor<13x2x2xf32>, tensor<13x2x2xf32>) +// CHECK: %[[VAL_4:.*]] = "tosa.reshape"(%[[VAL_2]]) <{new_shape = array}> : (tensor<13x2x2xf32>) -> tensor<13x2x2x1xf32> +// CHECK: %[[VAL_5:.*]] = "tosa.reshape"(%[[VAL_3]]) <{new_shape = array}> : (tensor<13x2x2xf32>) -> tensor<13x2x2x1xf32> +// CHECK: %[[VAL_6:.*]] = "tosa.concat"(%[[VAL_4]], %[[VAL_5]]) <{axis = 3 : i64}> : (tensor<13x2x2x1xf32>, tensor<13x2x2x1xf32>) -> tensor<13x2x2x2xf32> +// CHECK: return %[[VAL_6]] : tensor<13x2x2x2xf32> +func.func @test_rfft2d_crop_input(%arg0: tensor<13x21x3xf32>) -> tensor<13x2x2xcomplex> { + %0 = "tfl.pseudo_const"() {value = dense<2> : tensor<2xi32>} : () -> tensor<2xi32> + %1 = "tfl.rfft2d"(%arg0, %0) : (tensor<13x21x3xf32>, tensor<2xi32>) -> 
tensor<13x2x2xcomplex> + return %1 : tensor<13x2x2xcomplex> +} + +// ----- + +// CHECK-LABEL: test_rfft2d_pad_input +// CHECK-SAME: %[[VAL_0:.*]]: tensor<13x21x3xf32> +// CHECK-DAG: %[[VAL_1:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor}> : () -> tensor +// CHECK-DAG: %[[VAL_2:.*]] = "tosa.const"() <{value = dense<{{\[\[}}0, 0], [0, 11], [0, 5]]> : tensor<3x2xi32>}> : () -> tensor<3x2xi32> +// CHECK: %[[VAL_3:.*]] = "tosa.pad"(%[[VAL_0]], %[[VAL_2]], %[[VAL_1]]) : (tensor<13x21x3xf32>, tensor<3x2xi32>, tensor) -> tensor<13x32x8xf32> +// CHECK: %[[VAL_4:.*]], %[[VAL_5:.*]] = "tosa.rfft2d"(%[[VAL_3]]) : (tensor<13x32x8xf32>) -> (tensor<13x32x5xf32>, tensor<13x32x5xf32>) +// CHECK: %[[VAL_6:.*]] = "tosa.reshape"(%[[VAL_4]]) <{new_shape = array}> : (tensor<13x32x5xf32>) -> tensor<13x32x5x1xf32> +// CHECK: %[[VAL_7:.*]] = "tosa.reshape"(%[[VAL_5]]) <{new_shape = array}> : (tensor<13x32x5xf32>) -> tensor<13x32x5x1xf32> +// CHECK: %[[VAL_8:.*]] = "tosa.concat"(%[[VAL_6]], %[[VAL_7]]) <{axis = 3 : i64}> : (tensor<13x32x5x1xf32>, tensor<13x32x5x1xf32>) -> tensor<13x32x5x2xf32> +// CHECK: return %[[VAL_8]] : tensor<13x32x5x2xf32> +func.func @test_rfft2d_pad_input(%arg0: tensor<13x21x3xf32>) -> (tensor<13x32x5xcomplex>) { + %0 = "tfl.pseudo_const"() {value = dense<[[0, 0], [0, 11], [0, 5]]> : tensor<3x2xi32>} : () -> tensor<3x2xi32> + %1 = "tfl.pad"(%arg0, %0) : (tensor<13x21x3xf32>, tensor<3x2xi32>) -> tensor<13x32x8xf32> + %2 = "tfl.pseudo_const"() {value = dense<[32, 8]> : tensor<2xi32>} : () -> tensor<2xi32> + %3 = "tfl.rfft2d"(%1, %2) : (tensor<13x32x8xf32>, tensor<2xi32>) -> tensor<13x32x5xcomplex> + return %3 : tensor<13x32x5xcomplex> +} + +// ----- + +// CHECK-LABEL: test_rfft2d_crop_height_pad_width +// CHECK-SAME: %[[VAL_0:.*]]: tensor<13x21x3xf32> +// CHECK-DAG: %[[VAL_1:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor}> : () -> tensor +// CHECK-DAG: %[[VAL_2:.*]] = "tosa.const"() <{value = dense<{{\[\[}}0, 0], [0, 0], [0, 13]]> : tensor<3x2xi32>}> : () -> tensor<3x2xi32> +// CHECK: %[[VAL_3:.*]] = "tosa.pad"(%[[VAL_0]], %[[VAL_2]], %[[VAL_1]]) : (tensor<13x21x3xf32>, tensor<3x2xi32>, tensor) -> tensor<13x21x16xf32> +// CHECK: %[[VAL_4:.*]] = "tosa.slice"(%[[VAL_3]]) <{size = array, start = array}> : (tensor<13x21x16xf32>) -> tensor<13x2x16xf32> +// CHECK: %[[VAL_5:.*]], %[[VAL_6:.*]] = "tosa.rfft2d"(%[[VAL_4]]) : (tensor<13x2x16xf32>) -> (tensor<13x2x9xf32>, tensor<13x2x9xf32>) +// CHECK: %[[VAL_7:.*]] = "tosa.reshape"(%[[VAL_5]]) <{new_shape = array}> : (tensor<13x2x9xf32>) -> tensor<13x2x9x1xf32> +// CHECK: %[[VAL_8:.*]] = "tosa.reshape"(%[[VAL_6]]) <{new_shape = array}> : (tensor<13x2x9xf32>) -> tensor<13x2x9x1xf32> +// CHECK: %[[VAL_9:.*]] = "tosa.concat"(%[[VAL_7]], %[[VAL_8]]) <{axis = 3 : i64}> : (tensor<13x2x9x1xf32>, tensor<13x2x9x1xf32>) -> tensor<13x2x9x2xf32> +// CHECK: return %[[VAL_9]] : tensor<13x2x9x2xf32> +func.func @test_rfft2d_crop_height_pad_width(%arg0: tensor<13x21x3xf32>) -> (tensor<13x2x9xcomplex>) { + %0 = "tfl.pseudo_const"() {value = dense<[[0, 0], [0, 0], [0, 13]]> : tensor<3x2xi32>} : () -> tensor<3x2xi32> + %1 = "tfl.pad"(%arg0, %0) : (tensor<13x21x3xf32>, tensor<3x2xi32>) -> tensor<13x21x16xf32> + %2 = "tfl.pseudo_const"() {value = dense<[2, 16]> : tensor<2xi32>} : () -> tensor<2xi32> + %3 = "tfl.rfft2d"(%1, %2) : (tensor<13x21x16xf32>, tensor<2xi32>) -> tensor<13x2x9xcomplex> + return %3 : tensor<13x2x9xcomplex> +} + +// ----- + // CHECK-LABEL: test_real // CHECK-SAME: %[[VAL_0:.*]]: tensor<1x8x9x2xf32> -// CHECK: %[[VAL_1:.*]] = 
"tosa.slice"(%[[VAL_0]]) {size = array, start = array} : (tensor<1x8x9x2xf32>) -> tensor<1x8x9x1xf32> -// CHECK: %[[VAL_2:.*]] = "tosa.reshape"(%[[VAL_1]]) {new_shape = array} : (tensor<1x8x9x1xf32>) -> tensor<1x8x9xf32> +// CHECK: %[[VAL_1:.*]] = "tosa.slice"(%[[VAL_0]]) <{size = array, start = array}> : (tensor<1x8x9x2xf32>) -> tensor<1x8x9x1xf32> +// CHECK: %[[VAL_2:.*]] = "tosa.reshape"(%[[VAL_1]]) <{new_shape = array}> : (tensor<1x8x9x1xf32>) -> tensor<1x8x9xf32> // CHECK: return %[[VAL_2]] : tensor<1x8x9xf32> func.func @test_real(%arg0: tensor<1x8x9xcomplex>) -> (tensor<1x8x9xf32>) { %0 = "tfl.real"(%arg0) {} : (tensor<1x8x9xcomplex>) -> tensor<1x8x9xf32> @@ -2525,8 +2750,8 @@ func.func @test_real_non_complex(%arg0: tensor<1x8x9xf32>) -> (tensor<1x8x9xf32> // CHECK-LABEL: test_imag // CHECK-SAME: %[[VAL_0:.*]]: tensor<1x8x9x2xf32> -// CHECK: %[[VAL_1:.*]] = "tosa.slice"(%[[VAL_0]]) {size = array, start = array} : (tensor<1x8x9x2xf32>) -> tensor<1x8x9x1xf32> -// CHECK: %[[VAL_2:.*]] = "tosa.reshape"(%[[VAL_1]]) {new_shape = array} : (tensor<1x8x9x1xf32>) -> tensor<1x8x9xf32> +// CHECK: %[[VAL_1:.*]] = "tosa.slice"(%[[VAL_0]]) <{size = array, start = array}> : (tensor<1x8x9x2xf32>) -> tensor<1x8x9x1xf32> +// CHECK: %[[VAL_2:.*]] = "tosa.reshape"(%[[VAL_1]]) <{new_shape = array}> : (tensor<1x8x9x1xf32>) -> tensor<1x8x9xf32> // CHECK: return %[[VAL_2]] : tensor<1x8x9xf32> func.func @test_imag(%arg0: tensor<1x8x9xcomplex>) -> (tensor<1x8x9xf32>) { %0 = "tfl.imag"(%arg0) {} : (tensor<1x8x9xcomplex>) -> tensor<1x8x9xf32> @@ -2537,7 +2762,7 @@ func.func @test_imag(%arg0: tensor<1x8x9xcomplex>) -> (tensor<1x8x9xf32>) { // CHECK-LABEL: test_imag_non_complex // CHECK-SAME: %[[VAL_0:.*]]: tensor<1x8x9xf32> -// CHECK-DAG: %[[VAL_1:.*]] = "tosa.const"() {value = dense<0.000000e+00> : tensor<1x8x9xf32>} : () -> tensor<1x8x9xf32> +// CHECK-DAG: %[[VAL_1:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<1x8x9xf32>}> : () -> tensor<1x8x9xf32> // CHECK: return %[[VAL_1]] : tensor<1x8x9xf32> func.func @test_imag_non_complex(%arg0: tensor<1x8x9xf32>) -> (tensor<1x8x9xf32>) { %0 = "tfl.imag"(%arg0) {} : (tensor<1x8x9xf32>) -> tensor<1x8x9xf32> diff --git a/tensorflow/compiler/mlir/tosa/tests/verify_fully_converted.mlir b/tensorflow/compiler/mlir/tosa/tests/verify_fully_converted.mlir new file mode 100644 index 00000000000..3783c379908 --- /dev/null +++ b/tensorflow/compiler/mlir/tosa/tests/verify_fully_converted.mlir @@ -0,0 +1,19 @@ +// RUN: tf-opt %s --tosa-tflite-verify-fully-converted --split-input-file -verify-diagnostics + +// CHECK-LABEL: func.func @main +func.func @main(%arg0: tensor<2xf32>) -> (tensor<2xf32>) { + // CHECK: "tosa.add" + %0 = "tosa.add"(%arg0, %arg0) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> + return %0 : tensor<2xf32> +} + +// ----- + +// expected-error@below {{The following illegal operations still remain}} +func.func @main(%arg0: tensor<1x8x8x3xf32>) -> tensor<1x8x8x3xf32> attributes {tf.entry_function = {inputs = "input", outputs = "output"}} { + // expected-error@+1 {{'tfl.add' op : illegal op still exists}} + %0 = tfl.add %arg0, %arg0 {fused_activation_function = "NONE"} : tensor<1x8x8x3xf32> + // expected-error@+1 {{'tfl.sub' op : illegal op still exists}} + %1 = tfl.sub %0, %arg0 {fused_activation_function = "NONE"} : tensor<1x8x8x3xf32> + return %1 : tensor<1x8x8x3xf32> +} diff --git a/tensorflow/compiler/mlir/tosa/tf_passes.cc b/tensorflow/compiler/mlir/tosa/tf_passes.cc index caedab20ccf..f1e7191e2ca 100644 --- 
a/tensorflow/compiler/mlir/tosa/tf_passes.cc +++ b/tensorflow/compiler/mlir/tosa/tf_passes.cc @@ -36,8 +36,8 @@ void createTFtoTOSALegalizationPipeline( pm.addPass(mlir::createCanonicalizerPass()); pm.addPass(mlir::createCSEPass()); - pm.addPass(mlir::createLoopFusionPass()); - pm.addPass(mlir::createAffineScalarReplacementPass()); + pm.addPass(mlir::affine::createLoopFusionPass()); + pm.addPass(mlir::affine::createAffineScalarReplacementPass()); //---------------------------------------------------------------------------- // Perform main conversion. diff --git a/tensorflow/compiler/mlir/tosa/tf_tfl_passes.cc b/tensorflow/compiler/mlir/tosa/tf_tfl_passes.cc index 98cd3514561..2b31e3246fd 100644 --- a/tensorflow/compiler/mlir/tosa/tf_tfl_passes.cc +++ b/tensorflow/compiler/mlir/tosa/tf_tfl_passes.cc @@ -40,8 +40,8 @@ void createTFTFLtoTOSALegalizationPipeline( pm.addPass(mlir::createCanonicalizerPass()); pm.addPass(mlir::createCSEPass()); - pm.addPass(mlir::createLoopFusionPass()); - pm.addPass(mlir::createAffineScalarReplacementPass()); + pm.addPass(mlir::affine::createLoopFusionPass()); + pm.addPass(mlir::affine::createAffineScalarReplacementPass()); //---------------------------------------------------------------------------- // Perform main conversion. diff --git a/tensorflow/compiler/mlir/tosa/tfl_passes.cc b/tensorflow/compiler/mlir/tosa/tfl_passes.cc index 9f352f9c4a3..ff3c38e381e 100644 --- a/tensorflow/compiler/mlir/tosa/tfl_passes.cc +++ b/tensorflow/compiler/mlir/tosa/tfl_passes.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/tosa/tfl_passes.h" +#include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h" // from @llvm-project #include "mlir/Dialect/Affine/Passes.h" // from @llvm-project #include "mlir/Dialect/Tosa/Transforms/Passes.h" // from @llvm-project #include "mlir/Transforms/Passes.h" // from @llvm-project @@ -29,8 +30,17 @@ void createTFLtoTOSALegalizationPipeline( //---------------------------------------------------------------------------- // Prepare TFL module for conversion //---------------------------------------------------------------------------- + if (opts.target_compilation_backend) { + pm.addPass(createRetainCallOnceFuncsPass()); + } // Inline all functions into main and then delete the functions themselves. pm.addPass(mlir::createInlinerPass()); + pm.addPass(createCanonicalizerPass()); + pm.addPass(createSymbolDCEPass()); + if (opts.target_compilation_backend) { + pm.nest().addPass(createConvertFunctionMetadataPass()); + pm.addPass(createLowerGlobalTensorsPass()); + } // Add pass to decompose TFLite mixed quantization to non-quantized variants. pm.addPass(TFL::CreateDecomposeHybridQuantizationPass()); @@ -39,8 +49,8 @@ void createTFLtoTOSALegalizationPipeline( pm.addPass(mlir::createCanonicalizerPass()); pm.addPass(mlir::createCSEPass()); - pm.addPass(mlir::createLoopFusionPass()); - pm.addPass(mlir::createAffineScalarReplacementPass()); + pm.addPass(mlir::affine::createLoopFusionPass()); + pm.addPass(mlir::affine::createAffineScalarReplacementPass()); //---------------------------------------------------------------------------- // Perform main conversion. @@ -62,6 +72,15 @@ void createTFLtoTOSALegalizationPipeline( pm.addPass(mlir::createInlinerPass()); // Clean up with DCE. 
pm.addPass(mlir::createSymbolDCEPass()); + + if (opts.target_compilation_backend) { + pm.nest().addPass(mlir::tosa::createStripQuantTypesPass()); + pm.addPass(createCanonicalizerPass()); + pm.addPass(createReconcileUnrealizedCastsPass()); + pm.nest().addPass(createStripFunctionMetadataPass()); + pm.addPass(createStripModuleMetadataPass()); + pm.addPass(createVerifyFullyConvertedPass()); + } } void registerTFLtoTOSALegalizationPipeline() { diff --git a/tensorflow/compiler/mlir/tosa/tfl_passes.h b/tensorflow/compiler/mlir/tosa/tfl_passes.h index 1d73a655ce0..228b9ec2691 100644 --- a/tensorflow/compiler/mlir/tosa/tfl_passes.h +++ b/tensorflow/compiler/mlir/tosa/tfl_passes.h @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "llvm/Support/CommandLine.h" #include "mlir/Pass/PassManager.h" // from @llvm-project #include "mlir/Pass/PassOptions.h" // from @llvm-project @@ -29,7 +30,16 @@ struct TOSATFLLegalizationPipelineOptions : public PassPipelineOptions { ArrayRef disabled_patterns; ArrayRef enabled_patterns; - bool dequantize_tfl_softmax = false; + + PassOptions::Option target_compilation_backend{ + *this, "target-compilation-backend", + llvm::cl::desc("Whether targetting compilation backend"), + llvm::cl::init(false)}; + + PassOptions::Option dequantize_tfl_softmax{ + *this, "dequantize-tfl-softmax", + llvm::cl::desc("Dequantize the TFLite softmax"), llvm::cl::init(false)}; + TOSATFLLegalizationPipelineOptions() { disabled_patterns = std::nullopt; enabled_patterns = std::nullopt; diff --git a/tensorflow/compiler/mlir/tosa/transforms/convert_metadata.cc b/tensorflow/compiler/mlir/tosa/transforms/convert_metadata.cc new file mode 100644 index 00000000000..f81aee69f55 --- /dev/null +++ b/tensorflow/compiler/mlir/tosa/transforms/convert_metadata.cc @@ -0,0 +1,114 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "llvm/ADT/StringExtras.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project + +namespace mlir::tosa { + +#define GEN_PASS_DEF_CONVERTFUNCTIONMETADATA +#include "tensorflow/compiler/mlir/tosa/transforms/passes.h.inc" + +namespace { + +// Extract the input and output names +static void splitFunctionIONames(StringAttr namesAttr, + llvm::SmallVectorImpl &names) { + SmallVector namesRef; + llvm::SplitString(namesAttr.getValue(), namesRef, ","); + for (auto nameRef : namesRef) { + names.push_back(nameRef.str()); + } +} + +class ConvertFunctionMetadataPass + : public impl::ConvertFunctionMetadataBase { + public: + void runOnOperation() override { + auto funcOp = getOperation(); + + // Setup entry functions for compilation and preserve the + // associated metadata. Note that TFLite uses `tf.entry_function`. 
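+    // The attribute carries comma-separated "inputs"/"outputs" name lists;
+    // they are split below and re-attached as per-argument and per-result
+    // ml_program.identifier attributes so a compilation backend can still
+    // recover the original I/O names once the TF/TFL metadata is stripped.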
+ auto entryFunctionAttr = + funcOp->getAttrOfType("tf.entry_function"); + if (entryFunctionAttr) { + setupEntryPointAttrs(funcOp, entryFunctionAttr); + } + } + + private: + // TF/TFL pack their I/O names in a dictionary, convert into arg attributes. + void setupEntryPointAttrs(func::FuncOp funcOp, + DictionaryAttr entryFunctionAttr) { + funcOp.setPublic(); + + if (funcOp.getNumArguments() > 0) { + auto inputsAttr = + dyn_cast_or_null(entryFunctionAttr.get("inputs")); + if (!inputsAttr) { + funcOp.emitError() << "functions with tf.entry_function must have " + "input names to be handled by backend"; + return signalPassFailure(); + } + SmallVector inputNames; + splitFunctionIONames(inputsAttr, inputNames); + if (inputNames.size() != funcOp.getNumArguments()) { + funcOp.emitError() + << "tf.entry_function attribute malformed: inputs don't " + "match the function signature"; + return signalPassFailure(); + } + for (auto [i, name] : llvm::enumerate(inputNames)) { + funcOp.setArgAttr(i, "ml_program.identifier", + StringAttr::get(&getContext(), name)); + } + } + if (funcOp.getNumResults() > 0) { + auto outputsAttr = + dyn_cast_or_null(entryFunctionAttr.get("outputs")); + if (!outputsAttr) { + funcOp.emitError() << "functions with tf.entry_function must have " + "output names to be handled by backend"; + return signalPassFailure(); + } + SmallVector outputNames; + splitFunctionIONames(outputsAttr, outputNames); + if (outputNames.size() != funcOp.getNumResults()) { + funcOp.emitError() + << "tf.entry_function attribute malformed: outputs don't " + "match the function signature"; + return signalPassFailure(); + } + for (auto [i, name] : llvm::enumerate(outputNames)) { + funcOp.setResultAttr(i, "ml_program.identifier", + StringAttr::get(&getContext(), name)); + } + } + } +}; +} // anonymous namespace + +std::unique_ptr> +createConvertFunctionMetadataPass() { + return std::make_unique(); +} + +} // namespace mlir::tosa diff --git a/tensorflow/compiler/mlir/tosa/transforms/legalize_common.cc b/tensorflow/compiler/mlir/tosa/transforms/legalize_common.cc index 4e8cd06c2bf..54539429695 100644 --- a/tensorflow/compiler/mlir/tosa/transforms/legalize_common.cc +++ b/tensorflow/compiler/mlir/tosa/transforms/legalize_common.cc @@ -36,6 +36,7 @@ limitations under the License. #include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project #include "mlir/Dialect/Tosa/IR/TosaOps.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project #include "mlir/IR/Matchers.h" // from @llvm-project #include "mlir/IR/PatternMatch.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/utils/dynamic_shape_utils.h" @@ -1396,9 +1397,11 @@ std::optional convertSoftmaxOp(PatternRewriter& rewriter, Operation* op, // softmax = exp(logits - max(logits)) / reduce_sum(exp(logits - // max(logits)), -1) // - // We'll use first version for direct fp lowering, and second version for - // quantized lowering since second one we can restrict input to exp() be - // negative, and thus LUT can always be within [0.0, 1.0]. + // Second equation is used for both quantized and fp lowering. + // For quantized case, we can restrict input to exp() be negative, + // and thus LUT can always be within [0.0, 1.0]. + // For fp case, the normalization in the equation is required to prevent + // float overflow in softmax's intermediate calculations. 
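+  // For example, exp(100.0f) already exceeds the fp32 maximum (~3.4e38),
+  // whereas exp(logits - max(logits)) always lies in (0.0, 1.0]; the common
+  // factor exp(max) cancels between numerator and denominator, so the
+  // softmax result is unchanged.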
RankedTensorType output_type = result_value.getType().dyn_cast(); RankedTensorType input_type = @@ -1777,28 +1780,42 @@ std::optional convertSoftmaxOp(PatternRewriter& rewriter, Operation* op, rsum_shape_v[input_rank - 1] = 1; ArrayRef rsum_shape(rsum_shape_v); - // Floating-point loewring is more direct: + // Floating-point lowering is more direct: // - // op1 = exp(logits) - // op2 = reduce_sum(op1, -1) - // op3 = reciprocal(op2) - // op4 = mul(op1, op3) - auto op1_exp_in = CreateOpAndInfer(rewriter, op->getLoc(), - output_type, logits_value); + // op1 = reducemax(logits) + // op2 = sub(logits, op1) + // op3 = exp(op2) + // op4 = reduce_sum(op3, -1) + // op5 = reciprocal(op4) + // op6 = mul(op3, op5) RankedTensorType rsum_type = tensorflow::GetTypeFromTFTensorShape( rsum_shape, output_type.getElementType()); + RankedTensorType logits_type = tensorflow::GetTypeFromTFTensorShape( + logits_shape, output_type.getElementType()); - // Keep dims so we don't need to reshape later - auto op2_reducesum_op1 = CreateOpAndInfer( - rewriter, op->getLoc(), rsum_type, op1_exp_in.getResult(), + // Step 1. get x - max(x) + auto max_logits = CreateOpAndInfer( + rewriter, op->getLoc(), rsum_type, logits_value, rewriter.getI64IntegerAttr(input_rank - 1)); - auto op3_reciprocal_op2 = CreateOpAndInfer( - rewriter, op->getLoc(), op2_reducesum_op1.getType(), - op2_reducesum_op1.getResult()); + auto normalized_logits = + CreateOpAndInfer(rewriter, op->getLoc(), logits_type, + logits_value, max_logits.getResult()); + + // Step 2. get exp(x - max(x)) + auto exp_norm_logits = CreateOpAndInfer( + rewriter, op->getLoc(), output_type, normalized_logits); + + // Step 3. reuse softmax numerator to obtain denominator + // Keep dims so we don't need to reshape later + auto reducesum = CreateOpAndInfer( + rewriter, op->getLoc(), rsum_type, exp_norm_logits.getResult(), + rewriter.getI64IntegerAttr(input_rank - 1)); + auto denominator = CreateOpAndInfer( + rewriter, op->getLoc(), reducesum.getType(), reducesum.getResult()); return CreateOpAndInfer(rewriter, op->getLoc(), output_type, - op1_exp_in.getResult(), - op3_reciprocal_op2.getResult(), 0) + exp_norm_logits.getResult(), + denominator.getResult(), 0) .getResult(); } } @@ -2222,7 +2239,7 @@ std::optional> convertSplitVOp( // the only legal negative stride. static Value reverseNegativeStride(PatternRewriter& rewriter, Operation* op, Value input, ArrayRef strides) { - for (auto it : llvm::enumerate(strides)) { + for (const auto& it : llvm::enumerate(strides)) { auto axis = it.index(); auto stride = it.value(); if (stride != -1) continue; @@ -2321,7 +2338,7 @@ std::optional convertStridedSliceOp( } // Set begin mask values if possible. - for (auto& val : llvm::enumerate(begin)) + for (const auto& val : llvm::enumerate(begin)) begin_mask |= (val.value() == 0) << val.index(); // If all begin/end masks are set and striding is one we can just return @@ -3096,7 +3113,7 @@ std::optional convertResizeOp(PatternRewriter& rewriter, Operation* op, int& border) { // Dimension is length 1, we are just sampling from one value. if (input == 1) { - n = 1; + n = output; d = 1; offset = 0; border = output - 1; @@ -4463,5 +4480,45 @@ std::optional convertSinOp(PatternRewriter& rewriter, Operation* op, .getResult(); } +// Lowers Sign operator to a sequence of TOSA ops. 
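+// Conceptually: sign(x) = select(x > 0, 1, select(0 > x, -1, 0)). The 1, -1
+// and 0 are emitted as broadcastable single-element constants of the output
+// element type (float or integer); quantized element types are rejected for now.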
+std::optional convertSignOp(PatternRewriter& rewriter, Operation* op, + Value input, RankedTensorType output_type) { + auto output_elem_type = output_type.getElementType(); + if (output_elem_type.isa()) { + (void)rewriter.notifyMatchFailure(op, "tfl quantization not yet supported"); + return std::nullopt; + } + + // TOSA greater and select can both broadcast, so simply create a tensor with + // one element. + Value pos_one, neg_one, zero; + ImplicitLocOpBuilder builder(op->getLoc(), rewriter); + if (output_elem_type.isa()) { + pos_one = getTosaConstTensorSingleF32(rewriter, op, 1.0f); + neg_one = getTosaConstTensorSingleF32(rewriter, op, -1.0f); + zero = getTosaConstTensorSingleF32(rewriter, op, 0.0f); + } else { + pos_one = getTosaConstTensorScalarInt(builder, output_elem_type, 1); + neg_one = getTosaConstTensorScalarInt(builder, output_elem_type, -1); + zero = getTosaConstTensorScalarInt(builder, output_elem_type, 0); + } + + ShapedType const_type = output_type.clone(rewriter.getIntegerType(1)); + + auto gt_zero_op = + CreateOpAndInfer(builder, const_type, input, zero); + + auto lt_zero_op = + CreateOpAndInfer(builder, const_type, zero, input); + + auto select_neg_op = CreateOpAndInfer( + builder, output_type, lt_zero_op, neg_one, zero); + + // Select positive one based on the condition tensor. + return CreateOpAndInfer(builder, output_type, gt_zero_op, + pos_one, select_neg_op) + .getResult(); +} + }; // namespace tosa }; // namespace mlir diff --git a/tensorflow/compiler/mlir/tosa/transforms/legalize_common.h b/tensorflow/compiler/mlir/tosa/transforms/legalize_common.h index 33eaaab4202..3dc87952753 100644 --- a/tensorflow/compiler/mlir/tosa/transforms/legalize_common.h +++ b/tensorflow/compiler/mlir/tosa/transforms/legalize_common.h @@ -302,6 +302,10 @@ std::optional convertOneHotOp(PatternRewriter& rewriter, Operation* op, std::optional convertSinOp(PatternRewriter& rewriter, Operation* op, Value input, ShapedType output_type); +// Lowers Sign operator to a sequence of TOSA ops. 
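The convertSignOp lowering added above boils down to two comparisons feeding two selects. A scalar reading of the same logic, as an illustrative C++ sketch rather than the tensor-level TOSA ops:

    // Illustrative only: scalar equivalent of the Sign decomposition above.
    //   gt_zero  = greater(input, 0)
    //   lt_zero  = greater(0, input)
    //   neg_path = select(lt_zero, -1, 0)
    //   result   = select(gt_zero, +1, neg_path)
    template <typename T>
    T SignOf(T x) {
      bool gt_zero = x > T(0);
      bool lt_zero = T(0) > x;
      T neg_or_zero = lt_zero ? T(-1) : T(0);
      return gt_zero ? T(1) : neg_or_zero;
    }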
+std::optional convertSignOp(PatternRewriter& rewriter, Operation* op, + Value input, RankedTensorType output_type); + }; // namespace tosa }; // namespace mlir diff --git a/tensorflow/compiler/mlir/tosa/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/tosa/transforms/legalize_tf.cc index d38958f7fff..5418eab622c 100644 --- a/tensorflow/compiler/mlir/tosa/transforms/legalize_tf.cc +++ b/tensorflow/compiler/mlir/tosa/transforms/legalize_tf.cc @@ -79,6 +79,7 @@ DECL_CONVERT_OP(Sub); DECL_CONVERT_OP(Mul); DECL_CONVERT_OP(Square); DECL_CONVERT_OP(SquaredDifference); +DECL_CONVERT_OP(Sign); DECL_CONVERT_OP(Round); DECL_CONVERT_OP(FloorDiv); DECL_CONVERT_OP(FloorMod); @@ -250,6 +251,21 @@ LogicalResult ConvertTFGreaterEqualOp::matchAndRewrite( return success(); } +LogicalResult ConvertTFSignOp::matchAndRewrite( + Operation* op, PatternRewriter& rewriter) const { + auto tf_sign_op = cast(op); + + RankedTensorType output_type = + tf_sign_op.getResult().getType().cast(); + + std::optional result = + convertSignOp(rewriter, op, tf_sign_op.getX(), output_type); + if (!result) return failure(); + + rewriter.replaceOp(op, {result.value()}); + return success(); +} + LogicalResult ConvertTFSinOp::matchAndRewrite(Operation* op, PatternRewriter& rewriter) const { auto tf_sin_op = cast(op); @@ -748,8 +764,8 @@ LogicalResult ConvertTFRankOp::matchAndRewrite( RankedTensorType rank_type = tensorflow::GetTypeFromTFTensorShape({1}, rewriter.getIntegerType(32)); auto rank_attr = DenseI32ArrayAttr::get(rewriter.getContext(), {rank}); - auto rank_const = CreateOpAndInfer(rewriter, op->getLoc(), - rank_type, rank_attr); + auto rank_const = CreateOpAndInfer( + rewriter, op->getLoc(), rank_type, cast(rank_attr)); rewriter.replaceOp(op, {rank_const.getResult()}); @@ -780,8 +796,8 @@ LogicalResult ConvertTFShapeOp::matchAndRewrite( {static_cast(shape_arr.size())}, rewriter.getIntegerType(32)); auto shape_attr = DenseI32ArrayAttr::get(rewriter.getContext(), llvm::ArrayRef(shape_arr)); - auto shape_const = CreateOpAndInfer(rewriter, op->getLoc(), - shape_type, shape_attr); + auto shape_const = CreateOpAndInfer( + rewriter, op->getLoc(), shape_type, cast(shape_attr)); rewriter.replaceOp(op, {shape_const.getResult()}); @@ -849,11 +865,12 @@ LogicalResult ConvertTFFillOp::matchAndRewrite( return failure(); RankedTensorType fill_type = tensorflow::GetTypeFromTFTensorShape( - ArrayRef(dims_vals), value_elem.getType().getElementType()); + ArrayRef(dims_vals), + value_elem.getShapedType().getElementType()); DenseArrayAttr fill_attr; // Convert to a compatible zero type - if (value_elem.getType().getElementType().isa()) { + if (value_elem.getShapedType().getElementType().isa()) { SmallVector fill_arr( total_size, value_elem.getValues()[0].getValue().convertToFloat()); @@ -866,8 +883,8 @@ LogicalResult ConvertTFFillOp::matchAndRewrite( fill_attr = DenseI32ArrayAttr::get(rewriter.getContext(), llvm::ArrayRef(fill_arr)); } - auto fill_const_op = CreateOpAndInfer(rewriter, op->getLoc(), - fill_type, fill_attr); + auto fill_const_op = CreateOpAndInfer( + rewriter, op->getLoc(), fill_type, fill_attr.cast()); rewriter.replaceOp(op, {fill_const_op.getResult()}); return success(); @@ -2428,6 +2445,7 @@ void populateLegalizeTFPatterns(MLIRContext* ctx, RewritePatternSet& patterns) { patterns.add(ctx); patterns.add(ctx); patterns.add(ctx); + patterns.add(ctx); patterns.add(ctx); patterns.add(ctx); patterns.add(ctx); diff --git a/tensorflow/compiler/mlir/tosa/transforms/legalize_tfl.cc 
b/tensorflow/compiler/mlir/tosa/transforms/legalize_tfl.cc index b0d249495f9..0162ddd4a8a 100644 --- a/tensorflow/compiler/mlir/tosa/transforms/legalize_tfl.cc +++ b/tensorflow/compiler/mlir/tosa/transforms/legalize_tfl.cc @@ -15,6 +15,7 @@ limitations under the License. // Legalize TensorFlow Lite to TOSA +#include #include #include #include @@ -104,6 +105,7 @@ DECL_CONVERT_OP(Sub); DECL_CONVERT_OP(Mul); DECL_CONVERT_OP(Square); DECL_CONVERT_OP(SquaredDifference); +DECL_CONVERT_OP(Sign); DECL_CONVERT_OP(Round); DECL_CONVERT_OP(Div); DECL_CONVERT_OP(Maximum); @@ -123,6 +125,7 @@ DECL_CONVERT_OP(Fill); DECL_CONVERT_OP(Elu); DECL_CONVERT_OP(Softmax); DECL_CONVERT_OP(LogSoftmax); +DECL_CONVERT_OP(Rsqrt); DECL_CONVERT_OP(Sqrt); DECL_CONVERT_OP(L2Normalization); DECL_CONVERT_OP(ReduceAll); @@ -187,6 +190,7 @@ DECL_CONVERT_OP(FakeQuant); DECL_CONVERT_OP(While); DECL_CONVERT_OP(Real); DECL_CONVERT_OP(Imag); +DECL_CONVERT_OP(RFFT2d); #undef DECL_CONVERT_OP @@ -816,6 +820,21 @@ static LogicalResult matchAndRewriteAddSub(Operation* op, return success(); } +LogicalResult ConvertTFLSignOp::matchAndRewrite( + Operation* op, PatternRewriter& rewriter) const { + auto tfl_sign_op = cast(op); + + RankedTensorType output_type = + tfl_sign_op.getResult().getType().cast(); + + std::optional result = + convertSignOp(rewriter, op, tfl_sign_op.getX(), output_type); + if (!result) return failure(); + + rewriter.replaceOp(op, {result.value()}); + return success(); +} + LogicalResult ConvertTFLAddOp::matchAndRewrite( Operation* op, PatternRewriter& rewriter) const { return matchAndRewriteAddSub(op, op->getOperands(), @@ -1251,6 +1270,126 @@ LogicalResult ConvertTFLMaxPool2DOp::matchAndRewrite( return success(); } +// Returns a new type based on an input type and slicing information. +// If the input is quantized per axis, slices the scale and zp arrays. +// In any other case, returns the original type. +RankedTensorType getTypeForSlice(RankedTensorType type, int64_t slice_size, + int64_t offset) { + if (auto per_channel_qtype = + dyn_cast(type.getElementType())) { + SmallVector output_scale_arr( + per_channel_qtype.getScales().begin() + offset, + per_channel_qtype.getScales().begin() + offset + slice_size); + SmallVector output_zp_arr( + per_channel_qtype.getZeroPoints().begin() + offset, + per_channel_qtype.getZeroPoints().begin() + offset + slice_size); + auto output_per_channel_qtype = quant::UniformQuantizedPerAxisType::get( + per_channel_qtype.getFlags(), per_channel_qtype.getStorageType(), + per_channel_qtype.getExpressedType(), output_scale_arr, output_zp_arr, + per_channel_qtype.getQuantizedDimension(), + per_channel_qtype.getStorageTypeMin(), + per_channel_qtype.getStorageTypeMax()); + return RankedTensorType::get(type.getShape(), output_per_channel_qtype); + } + return type; +} + +Value lowerGroupedConvolution(TFL::Conv2DOp op, PatternRewriter& rewriter) { + auto input_type = dyn_cast(op.getInput().getType()); + auto filter_type = dyn_cast(op.getFilter().getType()); + auto bias_type = dyn_cast(op.getBias().getType()); + auto output_type = dyn_cast(op.getResult().getType()); + + // The inputs are NHWC, so the slicing/concatenation is done over dim 3. 
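The bookkeeping used in the rest of lowerGroupedConvolution below can be summarized as follows; this is a standalone sketch with hypothetical helper names, while the actual lowering materializes tosa.slice, tosa.conv2d, and tosa.concat ops.

    // Hypothetical helper (not patch code): when the filter's channel dim
    // divides the input's channel dim, the input is cut into num_groups
    // channel slices, each convolved with its own slice of the filter and
    // bias, and the per-group results are concatenated back on dim 3.
    #include <cassert>
    #include <cstdint>

    struct GroupPlan {
      int64_t num_groups;
      int64_t input_channels_per_group;
      int64_t output_channels_per_group;
    };

    GroupPlan PlanGroupedConv(int64_t input_channels, int64_t filter_channels,
                              int64_t output_channels) {
      assert(filter_channels > 0 && input_channels % filter_channels == 0);
      int64_t num_groups = input_channels / filter_channels;
      return {num_groups, filter_channels, output_channels / num_groups};
    }

    // Example: input_channels = 8, filter_channels = 2, output_channels = 16
    // gives 4 groups, each convolving 2 input channels into 4 output channels.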
+ int64_t in_channels_dim = 3; + int64_t input_channels = input_type.getDimSize(in_channels_dim); + int64_t filter_channels = filter_type.getDimSize(in_channels_dim); + int64_t num_groups = input_channels / filter_channels; + + SmallVector convolutions; + convolutions.reserve(num_groups); + auto rank = input_type.getRank(); + + // Input size vector + SmallVector input_size_vals(input_type.getShape().begin(), + input_type.getShape().end()); + input_size_vals.back() = filter_channels; + DenseI64ArrayAttr input_size = rewriter.getDenseI64ArrayAttr(input_size_vals); + auto input_slice_ty = + RankedTensorType::get(input_size_vals, input_type.getElementType()); + + // Filter size vector + SmallVector filter_size_vals(filter_type.getShape().begin(), + filter_type.getShape().end()); + filter_size_vals.front() = filter_type.getDimSize(0) / num_groups; + DenseI64ArrayAttr filter_size = + rewriter.getDenseI64ArrayAttr(filter_size_vals); + auto filter_slice_ty = + RankedTensorType::get(filter_size_vals, filter_type.getElementType()); + + // Bias size vector + int64_t bias_size_val = bias_type.getDimSize(0) / num_groups; + DenseI64ArrayAttr bias_size = rewriter.getDenseI64ArrayAttr(bias_size_val); + auto bias_slice_ty = + RankedTensorType::get(bias_size_val, bias_type.getElementType()); + + auto per_conv_out_ty = RankedTensorType::get( + {output_type.getDimSize(0), output_type.getDimSize(1), + output_type.getDimSize(2), output_type.getDimSize(3) / num_groups}, + output_type.getElementType()); + + // Create a separate convolution for each group + for (int i = 0; i < num_groups; ++i) { + auto verified_input_slice_ty = + getTypeForSlice(input_slice_ty, filter_channels, i * filter_channels); + auto verified_filter_slice_ty = + getTypeForSlice(filter_slice_ty, filter_channels, i * filter_channels); + auto verified_bias_slice_ty = + getTypeForSlice(bias_slice_ty, filter_channels, i * filter_channels); + auto verified_per_conv_out_ty = + getTypeForSlice(per_conv_out_ty, filter_channels, i * filter_channels); + + // Slice the input + SmallVector input_start_vals(rank, 0); + input_start_vals.back() = i * filter_channels; + DenseI64ArrayAttr input_start = + rewriter.getDenseI64ArrayAttr(input_start_vals); + + auto slice_input = rewriter.createOrFold( + op->getLoc(), verified_input_slice_ty, op.getInput(), input_start, + input_size); + + // Slice the filter + SmallVector filter_start_vals(rank, 0); + filter_start_vals.front() = i * filter_channels; + DenseI64ArrayAttr filter_start = + rewriter.getDenseI64ArrayAttr(filter_start_vals); + + auto slice_filter = rewriter.createOrFold( + op->getLoc(), verified_filter_slice_ty, op.getFilter(), filter_start, + filter_size); + + // Slice the bias + DenseI64ArrayAttr bias_start = + rewriter.getDenseI64ArrayAttr(i * filter_channels); + auto slice_bias = rewriter.createOrFold( + op->getLoc(), verified_bias_slice_ty, op.getBias(), bias_start, + bias_size); + + // Create a convolution for each set of slices + auto conv = rewriter.create( + op->getLoc(), verified_per_conv_out_ty, slice_input, slice_filter, + slice_bias, op.getDilationHFactor(), op.getDilationWFactor(), + op.getFusedActivationFunction(), op.getPadding(), op.getStrideH(), + op.getStrideW()); + + convolutions.push_back(conv.getResult()); + } + + return rewriter.createOrFold(op->getLoc(), output_type, + convolutions, in_channels_dim); +} + LogicalResult ConvertTFLConv2DOp::matchAndRewrite( Operation* op, PatternRewriter& rewriter) const { auto tfl_conv2d_op = cast(op); @@ -1281,6 +1420,14 @@ LogicalResult 
ConvertTFLConv2DOp::matchAndRewrite( "be all quantized or all floating-point"); } + int64_t input_channels = input_type.getDimSize(3); + int64_t filter_channels = filter_type.getDimSize(3); + if (input_channels != filter_channels && + input_channels % filter_channels == 0) { + rewriter.replaceOp(op, lowerGroupedConvolution(tfl_conv2d_op, rewriter)); + return success(); + } + DenseI64ArrayAttr pad; DenseI64ArrayAttr stride; DenseI64ArrayAttr dilation; @@ -2041,8 +2188,8 @@ LogicalResult ConvertTFLRankOp::matchAndRewrite( RankedTensorType rank_type = RankedTensorType::get({1}, rewriter.getIntegerType(32)); auto rank_attr = DenseI32ArrayAttr::get(rewriter.getContext(), {rank}); - auto rank_const = CreateOpAndInfer(rewriter, op->getLoc(), - rank_type, rank_attr); + auto rank_const = CreateOpAndInfer( + rewriter, op->getLoc(), rank_type, rank_attr.cast()); rewriter.replaceOp(op, {rank_const.getResult()}); @@ -2074,8 +2221,8 @@ LogicalResult ConvertTFLShapeOp::matchAndRewrite( {static_cast(shape_arr.size())}, rewriter.getIntegerType(32)); auto shape_attr = DenseI32ArrayAttr::get(rewriter.getContext(), llvm::ArrayRef(shape_arr)); - auto shape_const = CreateOpAndInfer(rewriter, op->getLoc(), - shape_type, shape_attr); + auto shape_const = CreateOpAndInfer( + rewriter, op->getLoc(), shape_type, shape_attr.cast()); rewriter.replaceOp(op, {shape_const.getResult()}); @@ -2142,12 +2289,13 @@ LogicalResult ConvertTFLFillOp::matchAndRewrite( if (!matchPattern(tfl_fill_op.getInput(), m_Constant(&value_elem))) return failure(); - RankedTensorType fill_type = RankedTensorType::get( - ArrayRef(dims_vals), value_elem.getType().getElementType()); + RankedTensorType fill_type = + RankedTensorType::get(ArrayRef(dims_vals), + value_elem.getShapedType().getElementType()); DenseArrayAttr fill_attr; // Convert to a compatible zero type. 
- if (value_elem.getType().getElementType().isa()) { + if (value_elem.getShapedType().getElementType().isa()) { SmallVector fill_arr( total_size, value_elem.getValues()[0].convertToFloat()); fill_attr = @@ -2158,8 +2306,8 @@ LogicalResult ConvertTFLFillOp::matchAndRewrite( fill_attr = DenseI32ArrayAttr::get(rewriter.getContext(), llvm::ArrayRef(fill_arr)); } - auto fill_const_op = CreateOpAndInfer(rewriter, op->getLoc(), - fill_type, fill_attr); + auto fill_const_op = CreateOpAndInfer( + rewriter, op->getLoc(), fill_type, fill_attr.cast()); rewriter.replaceOp(op, {fill_const_op.getResult()}); return success(); @@ -2350,6 +2498,54 @@ LogicalResult ConvertTFLSoftmaxOp::matchAndRewrite( return success(); } +LogicalResult ConvertTFLRsqrtOp::matchAndRewrite( + Operation* op, PatternRewriter& rewriter) const { + auto tfl_rsqrt_op = cast(op); + + RankedTensorType output_type = + tfl_rsqrt_op.getResult().getType().dyn_cast(); + RankedTensorType input_type = + tfl_rsqrt_op.getX().getType().dyn_cast(); + + mlir::quant::UniformQuantizedType input_qtype = + input_type.getElementType() + .dyn_cast_or_null(); + mlir::quant::UniformQuantizedType output_qtype = + output_type.getElementType() + .dyn_cast_or_null(); + + // Quantization case + if (input_qtype && output_qtype) { + auto rsqrt_func = [](double x) -> double { + // Negative numbers are undefined for rsqrt + // 0 should return the max value of the storage data type for rsqrt + if (x <= 0.0) return DBL_MAX; + return 1.0 / std::sqrt(x); + }; + + // 16-bit is pending review for TFL + // https://github.com/tensorflow/tensorflow/pull/58406 + if (input_qtype.getStorageTypeIntegralWidth() != 8) { + return rewriter.notifyMatchFailure(op, + "input qtype storage width is not 8"); + } + + // Implement with 8-bit table lookup. + Value table_const = getTosaConst8bitTable( + rewriter, op, input_qtype.getScale(), input_qtype.getZeroPoint(), + output_qtype.getScale(), output_qtype.getZeroPoint(), rsqrt_func); + + CreateReplaceOpAndInfer(rewriter, op, output_type, + tfl_rsqrt_op.getX(), table_const); + return success(); + } + + CreateReplaceOpAndInfer(rewriter, op, tfl_rsqrt_op.getType(), + tfl_rsqrt_op.getX()); + + return success(); +} + LogicalResult ConvertTFLSqrtOp::matchAndRewrite( Operation* op, PatternRewriter& rewriter) const { auto tfl_rsqrt_op = cast(op); @@ -3365,21 +3561,20 @@ static LogicalResult LegalizeQuantizedPrelu(Operation* op, // Perform an element-wise multiplication on rescaled alpha and input for // PReLU. 
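The quantized Rsqrt path above builds a 256-entry int8 lookup table: each int8 code is dequantized, passed through 1/sqrt(x), and requantized with saturation. A standalone sketch of that construction (hypothetical helper; the saturation guard is a simplified variant of the DBL_MAX check this patch adds to getTosaConst8bitTable):

    // Standalone sketch (not patch code): building an int8 rsqrt table.
    #include <algorithm>
    #include <array>
    #include <cfloat>
    #include <cmath>
    #include <cstdint>

    std::array<int8_t, 256> BuildRsqrtTable(double in_scale, int32_t in_zp,
                                            double out_scale, int32_t out_zp) {
      std::array<int8_t, 256> table;
      for (int32_t i = -128; i < 128; ++i) {
        double x = in_scale * (i - in_zp);                 // dequantize
        double y = (x <= 0.0) ? DBL_MAX : 1.0 / std::sqrt(x);
        if (y / out_scale + out_zp >= 127.0) {             // saturate high
          table[i + 128] = 127;
          continue;
        }
        int64_t q = std::llround(y / out_scale) + out_zp;  // requantize
        table[i + 128] = static_cast<int8_t>(
            std::min<int64_t>(127, std::max<int64_t>(-128, q)));
      }
      return table;
    }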
- Value alpha = tfl_prelu_op.getAlpha(); - ShapedType alpha_type = alpha.getType().cast(); - UniformQuantizedType alpha_qtype = - alpha_type.getElementType().cast(); + Value alpha = tfl_prelu_op.getAlpha(); + ShapedType alpha_type = alpha.getType().cast(); + UniformQuantizedType alpha_qtype = + alpha_type.getElementType().cast(); - Value op_rescale_alpha = removeZeroPointAndCastToInt32( - rewriter, op, alpha, alpha_qtype.getZeroPoint()); + Value op_rescale_alpha = removeZeroPointAndCastToInt32( + rewriter, op, alpha, alpha_qtype.getZeroPoint()); - Value op_mul = - CreateOpAndInfer(rewriter, op->getLoc(), rescale_type, - op_rescale_in, op_rescale_alpha, 0); + Value op_mul = CreateOpAndInfer( + rewriter, op->getLoc(), rescale_type, op_rescale_in, op_rescale_alpha, 0); - op_rescale_slope_in = buildRescale( - rewriter, op, output_type, op_mul, scale_alpha, - /* input_zp = */ 0, output_qtype.getZeroPoint(), true, true); + op_rescale_slope_in = + buildRescale(rewriter, op, output_type, op_mul, scale_alpha, + /* input_zp = */ 0, output_qtype.getZeroPoint(), true, true); Value op_rescale_identity_in = buildRescale( rewriter, op, output_type, input, scale_identity, @@ -3745,7 +3940,7 @@ LogicalResult ConvertTFLConstOp::matchAndRewrite( if (!output_type) return failure(); ElementsAttr elements = tfl_const_op.getValue(); - Type element_type = elements.getType().getElementType(); + Type element_type = elements.getShapedType().getElementType(); if (output_type.getElementType().isa()) { output_type = RankedTensorType::get(output_type.getShape(), element_type); } @@ -4190,6 +4385,74 @@ LogicalResult ConvertTFLImagOp::matchAndRewrite( return success(); } +LogicalResult ConvertTFLRFFT2dOp::matchAndRewrite( + Operation* op, PatternRewriter& rewriter) const { + auto rfft2d_op = cast(op); + auto loc = op->getLoc(); + Value input = rfft2d_op.getInput(); + + auto input_type = dyn_cast(input.getType()); + auto output_type = dyn_cast(rfft2d_op.getResult().getType()); + + if (!input_type || !output_type) { + return rewriter.notifyMatchFailure(op, "ranked input/output required"); + } + + if (!input_type.getElementType().isF32()) { + return rewriter.notifyMatchFailure(op, "input type must be fp32"); + } + + Value fft_length_value = rfft2d_op.getFftLength(); + llvm::SmallVector fft_length; + if (failed(getVectorFromValue32(fft_length_value, fft_length))) { + return rewriter.notifyMatchFailure(op, "fft_length is not a constant"); + } + + auto fp32_ty = UnrankedTensorType::get(rewriter.getF32Type()); + + // Padding is automatically inserted during the lowering when + // fft_length > input shape. However, to take care of the + // case fft_length < input shape we need to crop the input. 
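A compact restatement of the crop-size computation that follows: all leading dims are kept and the trailing two dims become fft_length, while the opposite case (fft_length larger than the input dims) is left to the padding performed inside the lowering. Standalone sketch with a hypothetical helper name:

    // Illustrative only: slice size used to crop the RFFT2d input when
    // fft_length is smaller than the trailing input dims. Assumes rank >= 2.
    #include <cstdint>
    #include <vector>

    std::vector<int64_t> RfftCropSize(const std::vector<int64_t>& input_shape,
                                      int64_t fft_h, int64_t fft_w) {
      std::vector<int64_t> slice_size(input_shape.begin(),
                                      input_shape.end() - 2);
      slice_size.push_back(fft_h);
      slice_size.push_back(fft_w);
      return slice_size;
    }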
+ const int64_t rank = input_type.getRank(); + auto input_shape = input_type.getShape(); + if (fft_length[0] < input_shape[rank - 2] || + fft_length[1] < input_shape[rank - 1]) { + llvm::SmallVector slice_begin(rank, 0); + llvm::SmallVector slice_size; + for (auto dim : input_type.getShape().drop_back(2)) { + slice_size.push_back(dim); + } + slice_size.push_back(fft_length[0]); + slice_size.push_back(fft_length[1]); + input = CreateOpAndInfer( + rewriter, loc, fp32_ty, input, + rewriter.getDenseI64ArrayAttr(slice_begin), + rewriter.getDenseI64ArrayAttr(slice_size)); + } + + auto rfft2d = + CreateOpAndInfer(rewriter, loc, fp32_ty, fp32_ty, input); + + auto output_shape = output_type.getShape(); + llvm::SmallVector new_shape{output_shape}; + new_shape.push_back(1); + auto reshape_1 = CreateOpAndInfer( + rewriter, loc, fp32_ty, rfft2d.getResult(0), + rewriter.getDenseI64ArrayAttr(new_shape)); + auto reshape_2 = CreateOpAndInfer( + rewriter, loc, fp32_ty, rfft2d.getResult(1), + rewriter.getDenseI64ArrayAttr(new_shape)); + + llvm::SmallVector values = {reshape_1, reshape_2}; + auto concat = CreateOpAndInfer(rewriter, loc, fp32_ty, values, + rewriter.getI64IntegerAttr(3)); + + CreateReplaceOpAndInfer( + rewriter, op, output_type, concat.getResult()); + + return success(); +} + LogicalResult LegalizeTFL::initialize(MLIRContext* context) { RewritePatternSet patterns(context); mlir::tosa::populateLegalizeTFLPatterns(context, patterns); @@ -4241,6 +4504,7 @@ void populateLegalizeTFLPatterns(MLIRContext* ctx, DEF_PATTERN_INSERT(TFLMul); DEF_PATTERN_INSERT(TFLSquare); DEF_PATTERN_INSERT(TFLSquaredDifference); + DEF_PATTERN_INSERT(TFLSign); DEF_PATTERN_INSERT(TFLRound); DEF_PATTERN_INSERT(TFLDiv); DEF_PATTERN_INSERT(TFLMaximum); @@ -4325,6 +4589,7 @@ void populateLegalizeTFLPatterns(MLIRContext* ctx, DEF_PATTERN_INSERT(TFLWhile); DEF_PATTERN_INSERT(TFLReal); DEF_PATTERN_INSERT(TFLImag); + DEF_PATTERN_INSERT(TFLRFFT2d); } // Creates an instance of the TensorFlow Lite dialect LegalizeTFL pass. diff --git a/tensorflow/compiler/mlir/tosa/transforms/legalize_utils.cc b/tensorflow/compiler/mlir/tosa/transforms/legalize_utils.cc index 762ba97ed62..ff8616687a2 100644 --- a/tensorflow/compiler/mlir/tosa/transforms/legalize_utils.cc +++ b/tensorflow/compiler/mlir/tosa/transforms/legalize_utils.cc @@ -285,6 +285,13 @@ Value getTosaConst8bitTable(PatternRewriter& rewriter, Operation* op, for (int32_t i = -128; i < 128; i++) { double dequantized = input_scale * (i - input_zp); double transformed = func(dequantized); + + double max = (output_scale > 1.0) ? DBL_MAX : (DBL_MAX * output_scale); + if (transformed >= max) { + table.push_back(INT8_MAX); + continue; + } + int32_t rescaled = std::llround(transformed / output_scale); int32_t quantized = static_cast(rescaled + output_zp); table.push_back( @@ -434,6 +441,21 @@ Value getTosaConstTensorSingleI32(PatternRewriter& rewriter, Operation* op, return const_op.getResult(); } +// Create an expected bitwidth integer constant operator based on the type +// parameter. +Value getTosaConstTensorScalarInt(ImplicitLocOpBuilder& builder, Type type, + int64_t val) { + auto bit_width = type.getIntOrFloatBitWidth(); + auto const_type = tensorflow::GetTypeFromTFTensorShape( + {}, builder.getIntegerType(bit_width)); + auto const_attr = + SplatElementsAttr::get(const_type, builder.getIntegerAttr(type, val)); + + auto const_op = + builder.create(builder.getLoc(), const_type, const_attr); + return const_op.getResult(); +} + // Create a vector from a 32-bit value tensor. 
Returns the size of // the new vector or -1 on error. LogicalResult getVectorFromValue32(Value val, SmallVectorImpl& vec) { @@ -695,7 +717,10 @@ LogicalResult ApplyPatternsWithShapeResolution( // This should be investigate for whether it is still necessary due to quant // type stripping changing. func.walk([&](tosa::ConstOp op) { - auto ety = op.getValue().getType().getElementType(); + if (op.getType().getElementType().isa()) { + return; + } + auto ety = op.getValue().getShapedType().getElementType(); auto new_ty = op.getType().cast().clone(ety); op.getResult().setType(new_ty); }); diff --git a/tensorflow/compiler/mlir/tosa/transforms/legalize_utils.h b/tensorflow/compiler/mlir/tosa/transforms/legalize_utils.h index c67791200f5..b2e76197fb5 100644 --- a/tensorflow/compiler/mlir/tosa/transforms/legalize_utils.h +++ b/tensorflow/compiler/mlir/tosa/transforms/legalize_utils.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_TOSA_TRANSFORMS_LEGALIZE_UTILS_H_ #define TENSORFLOW_COMPILER_MLIR_TOSA_TRANSFORMS_LEGALIZE_UTILS_H_ +#include #include #include #include @@ -28,6 +29,7 @@ limitations under the License. #include "mlir/Dialect/Tosa/Utils/ShapeUtils.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project #include "mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project #include "mlir/Rewrite/FrozenRewritePatternSet.h" // from @llvm-project @@ -103,6 +105,11 @@ Value getTosaConstTensorSingleF32(PatternRewriter& rewriter, Operation* op, Value getTosaConstTensorSingleI32(PatternRewriter& rewriter, Operation* op, int32_t val); +// Create an expected bitwidth integer constant operator based on the type +// parameter. +Value getTosaConstTensorScalarInt(ImplicitLocOpBuilder& builder, Type type, + int64_t val); + // Create a vector from a 32-bit value tensor. Returns vector size on success // or -1 on error. LogicalResult getVectorFromValue32(Value val, SmallVectorImpl& vec); @@ -151,9 +158,9 @@ LogicalResult ApplyPatternsWithShapeResolution( // Creates a TOSA operation and performs shape inference on the individual // op. This allows shape inference during the TFLite to TOSA lowering. template -TosaOp CreateOpAndInfer(PatternRewriter& rewriter, Location loc, Type result_ty, +TosaOp CreateOpAndInfer(ImplicitLocOpBuilder& builder, Type result_ty, Args&&... args) { - auto op = rewriter.create(loc, result_ty, args...); + auto op = builder.create(result_ty, args...); InferShapedTypeOpInterface shapeInterface = dyn_cast(op.getOperation()); @@ -161,8 +168,9 @@ TosaOp CreateOpAndInfer(PatternRewriter& rewriter, Location loc, Type result_ty, SmallVector returnedShapes; if (shapeInterface - .inferReturnTypeComponents(op.getContext(), op.getLoc(), + .inferReturnTypeComponents(op.getContext(), builder.getLoc(), op->getOperands(), op->getAttrDictionary(), + op->getPropertiesStorage(), op->getRegions(), returnedShapes) .failed()) return op; @@ -196,6 +204,13 @@ TosaOp CreateOpAndInfer(PatternRewriter& rewriter, Location loc, Type result_ty, return op; } +template +TosaOp CreateOpAndInfer(PatternRewriter& rewriter, Location loc, Type result_ty, + Args&&... 
args) { + ImplicitLocOpBuilder builder(loc, rewriter); + return CreateOpAndInfer(builder, result_ty, args...); +} + template void CreateReplaceOpAndInfer(PatternRewriter& rewriter, Operation* op, Type result_ty, Args&&... args) { diff --git a/tensorflow/compiler/mlir/tosa/transforms/lower_global_tensors.cc b/tensorflow/compiler/mlir/tosa/transforms/lower_global_tensors.cc new file mode 100644 index 00000000000..de30f7c2fb0 --- /dev/null +++ b/tensorflow/compiler/mlir/tosa/transforms/lower_global_tensors.cc @@ -0,0 +1,206 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "mlir/Dialect/MLProgram/IR/MLProgram.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Transforms/DialectConversion.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" +#include "tensorflow/compiler/mlir/tosa/transforms/passes.h" + +#define PASS_NAME "tosa-lower-global-tensors" +#define DEBUG_TYPE PASS_NAME + +namespace mlir::tosa { + +#define GEN_PASS_DEF_LOWERGLOBALTENSORS +#include "tensorflow/compiler/mlir/tosa/transforms/passes.h.inc" + +namespace { + +class LowerGlobalTensorsPass + : public impl::LowerGlobalTensorsBase { + public: + void getDependentDialects(DialectRegistry& registry) const override { + registry.insert(); + } + + // Converts TFLite state operations to the MLProgram equivalent. + void runOnOperation() override { + auto* context = &getContext(); + auto moduleOp = getOperation(); + mlir::OpBuilder builder(moduleOp.getBodyRegion()); + + DenseMap symNameToFunction; + for (auto func : moduleOp.getOps()) { + symNameToFunction[func.getSymName()] = func; + } + + DenseMap sharedNameToConstant; + DenseMap sharedNameToLoc; + + SmallVector handleOps; + SmallVector assignOps; + SmallVector readOps; + for (auto it : symNameToFunction) { + auto func = std::get<1>(it); + // Look through the initialization functions and find the assigned values + // for each handle, save out the constant value. + for (auto init : func.getOps()) { + auto findInitFunc = + symNameToFunction.find(init.getSessionInitFunction()); + if (findInitFunc == symNameToFunction.end()) { + init.emitError("unable to find initialization function: " + + init.getSessionInitFunction()); + continue; + } + func::FuncOp initFunc = std::get<1>(*findInitFunc); + for (auto assign : initFunc.getOps()) { + auto handle = dyn_cast( + assign.getResourceId().getDefiningOp()); + if (!handle) continue; + + DenseElementsAttr constant; + if (!matchPattern(assign.getValue(), m_Constant(&constant))) { + // Quantized types we can not use the m_Constant matcher. 
+ if (auto constOp = dyn_cast( + assign.getValue().getDefiningOp())) { + constant = constOp.getValue().cast(); + } + } + if (!constant) continue; + + auto name = handle.getSharedName(); + sharedNameToConstant[name] = constant; + sharedNameToLoc[name] = handle.getLoc(); + } + } + + // We also want to grab the list of operations to replace. + for (auto& op : func.getOps()) { + if (auto handle = dyn_cast(op)) + handleOps.push_back(handle); + if (auto assign = dyn_cast(op)) + assignOps.push_back(assign); + if (auto read = dyn_cast(op)) + readOps.push_back(read); + } + } + + // TF::CallOnceOps are no longer needed as we have already extracted their + // state. + SmallVector callOnceOps; + for (auto func : moduleOp.getOps()) { + for (auto init : func.getOps()) { + callOnceOps.push_back(init); + } + } + for (auto op : callOnceOps) op.erase(); + + // Create the ml_program::GlobalOps to store our new global variables. + DenseMap symbolRefMap; + for (auto it : sharedNameToConstant) { + auto name = std::get<0>(it); + auto attribute = std::get<1>(it); + auto locIt = sharedNameToLoc.find(name); + LocationAttr loc = mlir::UnknownLoc(); + if (locIt != sharedNameToLoc.end()) { + loc = std::get<1>(*locIt); + } + + // TODO(suderman): Determine the global type based on all store + // operations. + auto global = builder.create( + loc, name, attribute.getType(), /*is_mutable=*/true, attribute, + nullptr); + global.setPrivate(); + + symbolRefMap[name] = global; + } + + // Replace the assign ops with a global store operation. + for (auto assign : assignOps) { + auto handle = dyn_cast( + assign.getResourceId().getDefiningOp()); + if (!handle) continue; + + Value value = assign.getValue(); + auto globalOpIt = symbolRefMap.find(handle.getSharedName()); + if (globalOpIt == symbolRefMap.end()) { + assign->emitError( + "unable to find corresponding GlobalOp for op's VarHandle"); + continue; + } + auto globalOp = std::get<1>(*globalOpIt); + + builder.setInsertionPoint(assign); + if (globalOp.getType() != value.getType()) { + value = builder + .create( + assign.getLoc(), globalOp.getType(), value) + .getResult(0); + } + + auto globalSymbolRef = SymbolRefAttr::get(context, globalOp.getSymName()); + builder.create(assign.getLoc(), + globalSymbolRef, value); + assign.erase(); + } + + for (auto read : readOps) { + auto handle = dyn_cast( + read.getResourceId().getDefiningOp()); + if (!handle) continue; + + auto globalOpIt = symbolRefMap.find(handle.getSharedName()); + if (globalOpIt == symbolRefMap.end()) continue; + auto globalOp = std::get<1>(*globalOpIt); + + builder.setInsertionPoint(read); + + auto globalSymbolRef = SymbolRefAttr::get(context, globalOp.getSymName()); + Value load = builder.create( + read.getLoc(), globalOp.getType(), globalSymbolRef); + + if (read.getType() != load.getType()) { + load = builder + .create(read.getLoc(), + read.getType(), load) + .getResult(0); + } + read.getResult().replaceAllUsesWith(load); + read.erase(); + } + + for (auto handle : handleOps) { + if (handle.getResult().use_empty()) { + handle.erase(); + } + } + } +}; + +} // namespace + +std::unique_ptr> createLowerGlobalTensorsPass() { + return std::make_unique(); +} + +} // namespace mlir::tosa diff --git a/tensorflow/compiler/mlir/tosa/transforms/passes.h b/tensorflow/compiler/mlir/tosa/transforms/passes.h index 8721c83a50f..99f9465c8a6 100644 --- a/tensorflow/compiler/mlir/tosa/transforms/passes.h +++ b/tensorflow/compiler/mlir/tosa/transforms/passes.h @@ -22,6 +22,7 @@ limitations under the License. 
#include #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project namespace mlir { @@ -55,13 +56,21 @@ std::unique_ptr> createLegalizeTFLPass( ArrayRef disabled_patterns = std::nullopt, ArrayRef enabled_patterns = std::nullopt); -std::unique_ptr> createLegalizeTFTFLPass(); +std::unique_ptr> createLowerGlobalTensorsPass(); +std::unique_ptr> createRetainCallOnceFuncsPass(); +std::unique_ptr> createStripModuleMetadataPass(); std::unique_ptr> createConvertTFLUint8Pass(); -std::unique_ptr> createStripQuantTypesPass(); -std::unique_ptr> createLowerComplexTypesPass(); +std::unique_ptr> +createConvertFunctionMetadataPass(); std::unique_ptr> createDequantizeTFLSoftmaxPass(); +std::unique_ptr> createLegalizeTFTFLPass(); +std::unique_ptr> createLowerComplexTypesPass(); +std::unique_ptr> createStripFunctionMetadataPass(); +std::unique_ptr> createStripQuantTypesPass(); +std::unique_ptr> createVerifyFullyConvertedPass(); #define GEN_PASS_REGISTRATION +#define GEN_PASS_CLASSES #define GEN_PASS_DECL_TOSALEGALIZETFPASS #define GEN_PASS_DECL_TOSALEGALIZETFLPASS #define GEN_PASS_DECL_TOSALEGALIZETFTFLPASS @@ -70,6 +79,12 @@ std::unique_ptr> createDequantizeTFLSoftmaxPass(); #define GEN_PASS_DECL_TOSASTRIPQUANTTYPESPASS #define GEN_PASS_DECL_TOSALOWERCOMPLEXTYPESPASS #define GEN_PASS_DECL_TOSADEQUANTIZETFLSOFTMAXPASS +#define GEN_PASS_DECL_LOWERGLOBALTENSORS +#define GEN_PASS_DECL_RETAINCALLONCEFUNCS +#define GEN_PASS_DECL_STRIPFUNCTIONMETADATA +#define GEN_PASS_DECL_STRIPMODULEMETADATA +#define GEN_PASS_DECL_VERIFYFULLYCONVERTED +#define GEN_PASS_DECL_CONVERTFUNCTIONMETADATA #include "tensorflow/compiler/mlir/tosa/transforms/passes.h.inc" diff --git a/tensorflow/compiler/mlir/tosa/transforms/passes.td b/tensorflow/compiler/mlir/tosa/transforms/passes.td index f2c3fe1d463..e623760a4e9 100644 --- a/tensorflow/compiler/mlir/tosa/transforms/passes.td +++ b/tensorflow/compiler/mlir/tosa/transforms/passes.td @@ -88,3 +88,40 @@ def TosaDequantizeTFLSoftmaxPass : Pass<"tosa-dequantize-tfl-softmax", "mlir::fu let constructor = "createDequantizeTFLSoftmaxPass()"; let dependentDialects = ["mlir::TFL::TFLDialect", "quantfork::QuantizationForkDialect"]; } + +def LowerGlobalTensors : + Pass<"tflite-lower-global-tensors", "mlir::ModuleOp"> { + let summary = "Lowers TFLite global tensors to MLProgram dialect variables."; + let constructor = "createLowerGlobalTensorsPass()"; +} + +def RetainCallOnceFuncs : + Pass<"tflite-retain-call-once-funcs", "mlir::ModuleOp"> { + let summary = "Guarantees that functions used by tfl.call_once are retained."; + let constructor = "createRetainCallOnceFuncsPass()"; +} + +def StripFunctionMetadata : + Pass<"tosa-tflite-strip-function-metadata", "mlir::func::FuncOp"> { + let summary = "Strip all unneeded TF/TFLite specific metadata."; + let constructor = "createStripFunctionMetadataPass()"; +} + +def StripModuleMetadata : + Pass<"tosa-tflite-strip-module-metadata", "mlir::ModuleOp"> { + let summary = "Strip all unneeded TF/TFLite specific metadata."; + let constructor = "createStripModuleMetadataPass()"; +} + +def VerifyFullyConverted : + Pass<"tosa-tflite-verify-fully-converted", "mlir::func::FuncOp"> { + let summary = "Verifies that all TFLite frontend ops were converted and none remain."; + let constructor = "createVerifyFullyConvertedPass()"; +} + +def ConvertFunctionMetadata : + Pass<"tosa-tflite-convert-function-metadata", "mlir::func::FuncOp"> { + let summary = "Converts 
TFLite input attributes to MLProgram arg attributes on functions."; + let constructor = "createConvertFunctionMetadataPass()"; +} + diff --git a/tensorflow/compiler/mlir/tosa/transforms/retain_call_once_funcs.cc b/tensorflow/compiler/mlir/tosa/transforms/retain_call_once_funcs.cc new file mode 100644 index 00000000000..de76f4585ad --- /dev/null +++ b/tensorflow/compiler/mlir/tosa/transforms/retain_call_once_funcs.cc @@ -0,0 +1,68 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "llvm/ADT/StringExtras.h" +#include "llvm/Support/FormatVariadic.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Transforms/DialectConversion.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" +#include "tensorflow/compiler/mlir/tosa/transforms/passes.h" + +#define PASS_NAME "retain-call-once-funcs" +#define DEBUG_TYPE PASS_NAME + +namespace mlir::tosa { + +#define GEN_PASS_DEF_RETAINCALLONCEFUNCS +#include "tensorflow/compiler/mlir/tosa/transforms/passes.h.inc" + +namespace { + +class RetainCallOnceFuncsPass + : public impl::RetainCallOnceFuncsBase { + public: + void getDependentDialects(DialectRegistry& registry) const override { + registry.insert(); + } + + void runOnOperation() override { + auto moduleOp = getOperation(); + + llvm::DenseMap funcMap; + for (auto func : moduleOp.getOps()) { + funcMap[func.getSymName()] = func; + } + + for (auto func : moduleOp.getOps()) { + for (auto callOnce : func.getOps()) { + auto callFunc = funcMap[callOnce.getSessionInitFunction()]; + callOnce->setAttr("session_init_function_symbol", + SymbolRefAttr::get(callFunc)); + } + } + } +}; + +} // anonymous namespace + +std::unique_ptr> createRetainCallOnceFuncsPass() { + return std::make_unique(); +} + +} // namespace mlir::tosa diff --git a/tensorflow/compiler/mlir/tosa/transforms/strip_metadata.cc b/tensorflow/compiler/mlir/tosa/transforms/strip_metadata.cc new file mode 100644 index 00000000000..7960285fdb1 --- /dev/null +++ b/tensorflow/compiler/mlir/tosa/transforms/strip_metadata.cc @@ -0,0 +1,103 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include + +#include "mlir/IR/FunctionInterfaces.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tosa/transforms/passes.h" + +#define PASS_NAME "tosa-strip-metadata" +#define DEBUG_TYPE PASS_NAME + +namespace mlir::tosa { + +#define GEN_PASS_DEF_STRIPM +#include "tensorflow/compiler/mlir/tosa/transforms/passes.h.inc" + +namespace { + +static bool isTFLAttr(NamedAttribute &namedAttr) { + // TFLite uses both tf and tfl in attribute annotations. + auto name = namedAttr.getName().strref(); + // Don't trim attributes from tf_saved_model---they carry ABI information. + if (name.startswith("tf_saved_model.")) return false; + + if (name.startswith("tf.") || name.startswith("tf_") || + name.startswith("tfl.") || name.startswith("tfl_")) { + return true; + } + StringRef attrNamespace = namedAttr.getValue().getDialect().getNamespace(); + return attrNamespace == "tf" || attrNamespace == "tfl"; +} + +class StripModuleMetadataPass + : public StripModuleMetadataBase { + public: + void runOnOperation() override { + auto moduleOp = getOperation(); + auto stripAttrs = llvm::to_vector<4>(llvm::make_filter_range( + moduleOp->getAttrs(), + [](NamedAttribute namedAttr) { return isTFLAttr(namedAttr); })); + for (auto namedAttr : stripAttrs) { + moduleOp->removeAttr(namedAttr.getName()); + } + } +}; + +class StripFunctionMetadataPass + : public StripFunctionMetadataBase { + public: + void runOnOperation() override { + auto funcOp = getOperation(); + auto stripAttrs = llvm::to_vector<4>(llvm::make_filter_range( + funcOp->getAttrs(), + [](NamedAttribute namedAttr) { return isTFLAttr(namedAttr); })); + for (auto namedAttr : stripAttrs) { + funcOp->removeAttr(namedAttr.getName()); + } + + for (int i = 0, e = funcOp.getNumArguments(); i < e; ++i) { + auto stripAttrs = llvm::to_vector<4>(llvm::make_filter_range( + mlir::function_interface_impl::getArgAttrs(funcOp, i), + [](NamedAttribute namedAttr) { return isTFLAttr(namedAttr); })); + for (auto namedAttr : stripAttrs) { + funcOp.removeArgAttr(i, namedAttr.getName()); + } + } + + for (int i = 0, e = funcOp.getNumResults(); i < e; ++i) { + auto stripAttrs = llvm::to_vector<4>(llvm::make_filter_range( + mlir::function_interface_impl::getResultAttrs(funcOp, i), + [](NamedAttribute namedAttr) { return isTFLAttr(namedAttr); })); + for (auto namedAttr : stripAttrs) { + funcOp.removeResultAttr(i, namedAttr.getName()); + } + } + } +}; + +} // anonymous namespace + +std::unique_ptr> createStripModuleMetadataPass() { + return std::make_unique(); +} + +std::unique_ptr> createStripFunctionMetadataPass() { + return std::make_unique(); +} + +} // namespace mlir::tosa diff --git a/tensorflow/compiler/mlir/tosa/transforms/tfl_legalize_patterns.td b/tensorflow/compiler/mlir/tosa/transforms/tfl_legalize_patterns.td index cec441d25ef..1745907f8a8 100644 --- a/tensorflow/compiler/mlir/tosa/transforms/tfl_legalize_patterns.td +++ b/tensorflow/compiler/mlir/tosa/transforms/tfl_legalize_patterns.td @@ -31,7 +31,6 @@ def ConvertTFLCeilOp : Pat<(TFL_CeilOp $arg), (Tosa_CeilOp $arg)>; def ConvertTFLFloorOp : Pat<(TFL_FloorOp $arg), (Tosa_FloorOp $arg)>; def ConvertTFLExpOp : Pat<(TFL_ExpOp $arg), (Tosa_ExpOp $arg)>; def ConvertTFLLogOp : Pat<(TFL_LogOp $arg), (Tosa_LogOp $arg)>; -def ConvertTFLRsqrtOp : Pat<(TFL_RsqrtOp $arg), (Tosa_RsqrtOp $arg)>; def ConvertTFLLogicalNotOp : 
Pat<(TFL_LogicalNotOp $arg), (Tosa_LogicalNotOp $arg)>; def ConvertTFLCastOp: Pat<(TFL_CastOp $in), (Tosa_CastOp $in)>; diff --git a/tensorflow/compiler/mlir/tosa/transforms/verify_fully_converted.cc b/tensorflow/compiler/mlir/tosa/transforms/verify_fully_converted.cc new file mode 100644 index 00000000000..478c77ba61a --- /dev/null +++ b/tensorflow/compiler/mlir/tosa/transforms/verify_fully_converted.cc @@ -0,0 +1,90 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include + +#include "llvm/ADT/StringExtras.h" +#include "llvm/Support/FormatVariadic.h" +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Transforms/DialectConversion.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" + +namespace mlir::tosa { + +#define GEN_PASS_DEF_VERIFYFULLYCONVERTED +#include "tensorflow/compiler/mlir/tosa/transforms/passes.h.inc" + +namespace { + +static void emitLegalizationErrors(Location loc, + const DenseSet &illegalOps) { + // Print op errors for each of the illegal ops that still remain. + llvm::MapVector opNameCounts; + for (Operation *illegalOp : illegalOps) { + StringRef opName = illegalOp->getName().getStringRef(); + opNameCounts[opName]++; + illegalOp->emitOpError() << ": illegal op still exists"; + } + + std::vector errorMessages; + errorMessages.reserve(opNameCounts.size()); + for (const auto &opInfo : opNameCounts) { + errorMessages.push_back( + llvm::formatv("\t{0} (count: {1})", opInfo.first, opInfo.second)); + } + emitError(loc) << "The following illegal operations still remain: \n" + << llvm::join(errorMessages, "\n") << "\n"; +} + +LogicalResult verifyAllOperationsAreLegal(Operation *op, + const ConversionTarget &target) { + DenseSet illegalOps; + op->walk([&](Operation *op) { + if (!target.isLegal(op)) { + illegalOps.insert(op); + } + }); + if (illegalOps.empty()) return success(); + emitLegalizationErrors(op->getLoc(), illegalOps); + return failure(); +} + +class VerifyFullyConvertedPass + : public impl::VerifyFullyConvertedBase { + public: + // Validates that no TFLite frontends ops are in the function. + void runOnOperation() override { + // We don't just use applyPartialConversion with no patterns because this + // pass shouldn't alter the IR at all (including via folding or + // canonicalizations that dialect conversion does automatically). 
+ ConversionTarget target(getContext()); + target.markUnknownOpDynamicallyLegal([](Operation *) { return true; }); + target.addIllegalDialect(); + target.addIllegalOp(); + if (failed(verifyAllOperationsAreLegal(getOperation(), target))) + return signalPassFailure(); + } +}; + +} // anonymous namespace + +std::unique_ptr> createVerifyFullyConvertedPass() { + return std::make_unique(); +} + +} // namespace mlir::tosa diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index f1419a29796..1fb31b9db74 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -50,11 +50,12 @@ py_library( "//tensorflow/python:client", "//tensorflow/python:client_testlib", "//tensorflow/python:framework", - "//tensorflow/python:platform", "//tensorflow/python:random_seed", "//tensorflow/python:session", "//tensorflow/python:variables", "//tensorflow/python/compiler/xla:compiler_py", + "//tensorflow/python/platform:flags", + "//tensorflow/python/platform:tf_logging", "//third_party/py/numpy", ], ) @@ -83,6 +84,7 @@ py_test( ], deps = [ ":xla_test", + "//tensorflow/python/platform:client_testlib", ], ) @@ -101,6 +103,7 @@ tf_xla_py_test( "//tensorflow/python:framework", "//tensorflow/python:platform_test", "//tensorflow/python:training", + "//tensorflow/python/platform:client_testlib", ], ) @@ -120,6 +123,7 @@ tf_xla_py_test( "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", "//tensorflow/python:training", + "//tensorflow/python/platform:client_testlib", ], ) @@ -139,6 +143,7 @@ tf_xla_py_test( "//tensorflow/python:framework", "//tensorflow/python:platform_test", "//tensorflow/python:training", + "//tensorflow/python/platform:client_testlib", ], ) @@ -159,6 +164,7 @@ tf_xla_py_test( "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", "//tensorflow/python:training", + "//tensorflow/python/platform:client_testlib", ], ) @@ -180,6 +186,7 @@ tf_xla_py_test( "//tensorflow/python:list_ops", "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", + "//tensorflow/python/platform:client_testlib", ], ) @@ -199,6 +206,7 @@ tf_xla_py_test( "//tensorflow/python:framework", "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", + "//tensorflow/python/platform:client_testlib", ], ) @@ -224,6 +232,7 @@ tf_xla_py_test( "//tensorflow/python:nn_ops", "//tensorflow/python:nn_ops_gen", "//tensorflow/python:platform_test", + "//tensorflow/python/platform:client_testlib", ], ) @@ -269,6 +278,7 @@ tf_xla_py_test( "//tensorflow/python:framework", "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", + "//tensorflow/python/platform:client_testlib", ], ) @@ -311,6 +321,7 @@ tf_xla_py_test( "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", "//tensorflow/python:training", + "//tensorflow/python/platform:client_testlib", ], ) @@ -326,6 +337,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/compiler/tf2xla/python:xla", + "//tensorflow/python:cond", "//tensorflow/python:control_flow_ops", "//tensorflow/python:control_flow_switch_case", "//tensorflow/python:framework", @@ -333,6 +345,7 @@ tf_xla_py_test( "//tensorflow/python:tensor_array_ops", "//tensorflow/python:training", "//tensorflow/python/eager:function", + "//tensorflow/python/platform:client_testlib", ], ) @@ -354,6 +367,7 @@ tf_xla_py_test( "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", "//tensorflow/python:training", + "//tensorflow/python/platform:client_testlib", "@absl_py//absl/testing:parameterized", ], ) 
@@ -371,6 +385,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:platform_test", + "//tensorflow/python/platform:client_testlib", ], ) @@ -397,6 +412,7 @@ tf_xla_py_test( "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", "//tensorflow/python:training", + "//tensorflow/python/platform:client_testlib", "@absl_py//absl/testing:parameterized", ], ) @@ -458,6 +474,7 @@ tf_xla_py_test( "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", "//tensorflow/python:training", + "//tensorflow/python/platform:client_testlib", ], ) @@ -564,6 +581,7 @@ tf_xla_py_test( "//tensorflow/python:nn_ops", "//tensorflow/python:nn_ops_gen", "//tensorflow/python:platform_test", + "//tensorflow/python/platform:client_testlib", ], ) @@ -581,6 +599,7 @@ tf_xla_py_test( "//tensorflow/compiler/tf2xla/python:xla", "//tensorflow/python:array_ops", "//tensorflow/python:dtypes", + "//tensorflow/python/platform:client_testlib", ], ) @@ -621,6 +640,7 @@ tf_xla_py_test( "//tensorflow/compiler/tf2xla/python:xla", "//tensorflow/python:array_ops", "//tensorflow/python:dtypes", + "//tensorflow/python/platform:test", "@absl_py//absl/testing:parameterized", ], ) @@ -655,6 +675,7 @@ tf_xla_py_test( "//tensorflow/python:array_ops", "//tensorflow/python:framework", "//tensorflow/python:platform_test", + "//tensorflow/python/platform:client_testlib", ], ) @@ -671,6 +692,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", + "//tensorflow/python:cond", "//tensorflow/python:framework", "//tensorflow/python:layers", "//tensorflow/python:math_ops", @@ -697,6 +719,7 @@ tf_xla_py_test( "//tensorflow/python:extra_py_tests_deps", "//tensorflow/python:framework", "//tensorflow/python:platform_test", + "//tensorflow/python/platform:client_testlib", ], ) @@ -747,6 +770,7 @@ tf_xla_py_test( srcs = ["ftrl_test.py"], enable_mlir_bridge = False, python_version = "PY3", + shard_count = 16, tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip ], @@ -757,6 +781,7 @@ tf_xla_py_test( "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", "//tensorflow/python:training", + "//tensorflow/python/platform:client_testlib", ], ) @@ -814,7 +839,6 @@ tf_xla_py_test( tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip "optonly", # Times out frequently in fastbuild mode. 
- "requires-gpu-nvidia", ], deps = [ ":xla_test", @@ -822,6 +846,7 @@ tf_xla_py_test( "//tensorflow/python:framework", "//tensorflow/python:image_ops", "//tensorflow/python:platform_test", + "//tensorflow/python/platform:client_testlib", ], ) @@ -842,6 +867,7 @@ tf_xla_py_test( "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:framework_ops", "//tensorflow/python:platform_test", + "//tensorflow/python/platform:client_testlib", ], ) @@ -898,6 +924,7 @@ tf_xla_py_test( "//tensorflow/python:array_ops", "//tensorflow/python:framework", "//tensorflow/python:platform_test", + "//tensorflow/python/platform:client_testlib", "@absl_py//absl/testing:parameterized", ], ) @@ -937,6 +964,7 @@ tf_xla_py_test( "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", "//tensorflow/python:training", + "//tensorflow/python/platform:client_testlib", ], ) @@ -1012,6 +1040,7 @@ tf_xla_py_test( "//tensorflow/python:nn_ops", "//tensorflow/python:nn_ops_gen", "//tensorflow/python:platform_test", + "//tensorflow/python/platform:client_testlib", ], ) @@ -1074,6 +1103,7 @@ tf_xla_py_test( "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", "//tensorflow/python:training", + "//tensorflow/python/platform:client_testlib", "@absl_py//absl/testing:parameterized", ], ) @@ -1101,6 +1131,7 @@ tf_xla_py_test( "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", "//tensorflow/python:training", + "//tensorflow/python/platform:client_testlib", "@absl_py//absl/testing:parameterized", ], ) @@ -1180,6 +1211,7 @@ tf_xla_py_test( ":xla_test", "//tensorflow/python:array_ops", "//tensorflow/python:framework", + "//tensorflow/python/platform:test", ], ) @@ -1198,6 +1230,7 @@ tf_xla_py_test( "//tensorflow/python:array_ops", "//tensorflow/python:framework", "//tensorflow/python:platform_test", + "//tensorflow/python/platform:client_testlib", ], ) @@ -1221,6 +1254,7 @@ tf_xla_py_test( # "//tensorflow/python:platform_test", # "//tensorflow/python/compat:v2_compat", # "//tensorflow/python/eager:function", +# "//tensorflow/python/platform:client_testlib", # ], # ) # copybara:uncomment_end @@ -1241,6 +1275,7 @@ tf_xla_py_test( "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", "//tensorflow/python:training", + "//tensorflow/python/platform:client_testlib", ], ) @@ -1261,6 +1296,7 @@ tf_xla_py_test( "//tensorflow/python:framework", "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", + "//tensorflow/python/platform:client_testlib", ], ) @@ -1300,6 +1336,7 @@ tf_xla_py_test( "//tensorflow/python:framework", "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", + "//tensorflow/python/platform:client_testlib", ], ) @@ -1318,6 +1355,7 @@ tf_xla_py_test( "//tensorflow/python:framework", "//tensorflow/python:platform_test", "//tensorflow/python:sparse_ops", + "//tensorflow/python/platform:client_testlib", ], ) @@ -1338,6 +1376,7 @@ tf_xla_py_test( "//tensorflow/python:data_flow_ops", "//tensorflow/python:framework", "//tensorflow/python:platform_test", + "//tensorflow/python/platform:client_testlib", ], ) @@ -1361,6 +1400,8 @@ tf_xla_py_test( "//tensorflow/python:standard_ops", "//tensorflow/python:stateful_random_ops", "//tensorflow/python/kernel_tests/random:util", + "//tensorflow/python/platform:client_testlib", + "//tensorflow/python/platform:flags", ], ) @@ -1382,6 +1423,7 @@ tf_xla_py_test( "//tensorflow/python:standard_ops", "//tensorflow/python:stateless_random_ops", "//tensorflow/python/kernel_tests/random:util", + 
"//tensorflow/python/platform:client_testlib", ], ) @@ -1411,6 +1453,7 @@ tf_xla_py_test( "//tensorflow/python:tensor_array_grad", "//tensorflow/python:tensor_array_ops", "//tensorflow/python:training", + "//tensorflow/python/platform:client_testlib", ], ) @@ -1432,6 +1475,7 @@ tf_xla_py_test( "//tensorflow/python:list_ops", "//tensorflow/python:platform_test", "//tensorflow/python/eager:function", + "//tensorflow/python/platform:client_testlib", ], ) @@ -1441,7 +1485,7 @@ tf_xla_py_test( srcs = ["ternary_ops_test.py"], enable_mlir_bridge = True, python_version = "PY3", - shard_count = 4, + shard_count = 16, tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip ], @@ -1462,7 +1506,7 @@ tf_xla_py_test( srcs = ["unary_ops_test.py"], enable_mlir_bridge = True, python_version = "PY3", - shard_count = 4, + shard_count = 32, tags = [ "no_cuda_asan", # times out "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -1499,6 +1543,7 @@ tf_xla_py_test( "//tensorflow/python:nn_ops_gen", "//tensorflow/python:platform_test", "//tensorflow/python:training", + "//tensorflow/python/platform:client_testlib", "@absl_py//absl/testing:parameterized", ], ) @@ -1543,6 +1588,7 @@ tf_xla_py_test( "//tensorflow/python:platform_test", "//tensorflow/python:training", "//tensorflow/python:while_loop", + "//tensorflow/python/platform:client_testlib", ], ) @@ -1566,6 +1612,7 @@ tf_xla_py_test( "//tensorflow/python:image_ops", "//tensorflow/python:platform_test", "//tensorflow/python:training", + "//tensorflow/python/platform:client_testlib", ], ) @@ -1588,6 +1635,7 @@ tf_xla_py_test( "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", "//tensorflow/python:training", + "//tensorflow/python/platform:client_testlib", ], ) @@ -1606,6 +1654,8 @@ tf_xla_py_test( "//tensorflow/python:data_flow_ops", "//tensorflow/python:framework", "//tensorflow/python:platform_test", + "//tensorflow/python/platform:client_testlib", + "//tensorflow/python/platform:flags", ], ) @@ -1623,6 +1673,7 @@ tf_xla_py_test( "//tensorflow/python:array_ops", "//tensorflow/python:framework", "//tensorflow/python:platform_test", + "//tensorflow/python/platform:client_testlib", ], ) @@ -1641,6 +1692,7 @@ tf_xla_py_test( "//tensorflow/python:array_ops", "//tensorflow/python:framework", "//tensorflow/python:platform_test", + "//tensorflow/python/platform:client_testlib", ], ) @@ -1661,6 +1713,7 @@ tf_xla_py_test( "//tensorflow/compiler/tf2xla/python:xla", "//tensorflow/python:array_ops", "//tensorflow/python:dtypes", + "//tensorflow/python/platform:client_testlib", "@absl_py//absl/testing:parameterized", ], ) @@ -1698,6 +1751,7 @@ tf_xla_py_test( "//tensorflow/python:array_ops", "//tensorflow/python:framework", "//tensorflow/python:platform_test", + "//tensorflow/python/platform:client_testlib", ], ) @@ -1737,6 +1791,7 @@ cuda_py_test( "//tensorflow/python:array_ops", "//tensorflow/python:client", "//tensorflow/python:client_testlib", + "//tensorflow/python:cond", "//tensorflow/python:control_flow_ops", "//tensorflow/python:framework", "//tensorflow/python:gradients", @@ -1820,8 +1875,8 @@ tf_cuda_cc_test( shard_count = 20, # This test is randomized, so only run it if explicitly requested. 
tags = [ - "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip "manual", + "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip "notap", ] + tf_cuda_tests_tags(), deps = [":randomized_tests_library"], @@ -1834,8 +1889,8 @@ tf_cuda_cc_test( shard_count = 20, # This test is randomized, so only run it if explicitly requested. tags = [ - "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip "manual", + "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip "notap", ] + tf_cuda_tests_tags(), deps = [":randomized_tests_library"], @@ -1856,8 +1911,8 @@ tf_cuda_cc_test( "config-cuda-only", "no_cuda_asan", # TODO(b/201651800) "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip - "requires-gpu-nvidia", "no_rocm", # ROCmSoftwarePlatform #958 + "requires-gpu-nvidia", ] + tf_cuda_tests_tags(), deps = [":randomized_tests_library"], ) @@ -1877,8 +1932,8 @@ tf_cuda_cc_test( "config-cuda-only", "no_cuda_asan", # TODO(b/201651800) "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip - "requires-gpu-nvidia", "no_rocm", # ROCmSoftwarePlatform #958 + "requires-gpu-nvidia", ] + tf_cuda_tests_tags(), deps = [":randomized_tests_library"], ) @@ -1917,7 +1972,7 @@ py_library( "//tensorflow/python:framework", "//tensorflow/python:math_ops", "//tensorflow/python:random_ops", - "//tensorflow/python:variables", + "//tensorflow/python:variable_v1", "@six_archive//:six", ], ) @@ -1939,7 +1994,6 @@ cuda_py_test( "//tensorflow/python:gradients", "//tensorflow/python:init_ops", "//tensorflow/python:math_ops", - "//tensorflow/python:platform", "//tensorflow/python:variables", ], ) @@ -2062,6 +2116,7 @@ tf_xla_py_test( "//tensorflow/python:framework", "//tensorflow/python:platform_test", "//tensorflow/python:training", + "//tensorflow/python/platform:client_testlib", ], ) @@ -2084,6 +2139,7 @@ tf_xla_py_test( ":xla_test", "//tensorflow/python:framework", "//tensorflow/python:platform_test", + "//tensorflow/python/platform:client_testlib", ], ) @@ -2126,6 +2182,7 @@ tf_xla_py_test( "//tensorflow/python:linalg_ops", "//tensorflow/python:platform_test", "//tensorflow/python:standard_ops", + "//tensorflow/python/platform:client_testlib", ], ) @@ -2142,6 +2199,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:framework", + "//tensorflow/python:gradient_checker_v2", "//tensorflow/python:linalg_ops", "//tensorflow/python:platform_test", "//tensorflow/python:standard_ops", @@ -2161,7 +2219,9 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:extra_py_tests_deps", + "//tensorflow/python:gradient_checker_v2", "//tensorflow/python:math_ops", + "//tensorflow/python/platform:client_testlib", "@absl_py//absl/testing:parameterized", ], ) @@ -2180,6 +2240,7 @@ tf_xla_py_test( ":xla_test", "//tensorflow/python:extra_py_tests_deps", "//tensorflow/python:math_ops", + "//tensorflow/python/platform:client_testlib", "@absl_py//absl/testing:parameterized", ], ) @@ -2202,6 +2263,7 @@ tf_xla_py_test( ":xla_test", "//tensorflow/python:extra_py_tests_deps", "//tensorflow/python:math_ops", + "//tensorflow/python/platform:client_testlib", "@absl_py//absl/testing:parameterized", ], ) @@ -2221,6 +2283,7 @@ tf_xla_py_test( "//tensorflow/python:array_ops", "//tensorflow/python:framework", "//tensorflow/python:platform_test", + "//tensorflow/python/platform:client_testlib", ], ) @@ -2246,6 +2309,7 @@ tf_xla_py_test( "//tensorflow/python:errors", 
"//tensorflow/python:framework", "//tensorflow/python/compiler/xla:compiler_py", + "//tensorflow/python/tpu", ], ) @@ -2253,7 +2317,10 @@ tf_xla_py_test( name = "where_op_tpu_test", size = "small", srcs = ["where_op_test.py"], - args = ["--tpu_use_tfrt=true"], + args = [ + "--tpu_use_tfrt=true", + # TODO(b/274633087): Set tf_use_pjrt=true after fixing bug. + ], disabled_backends = [ "cpu", "cpu_ondemand", @@ -2274,6 +2341,7 @@ tf_xla_py_test( "//tensorflow/python:errors", "//tensorflow/python:framework", "//tensorflow/python/compiler/xla:compiler_py", + "//tensorflow/python/tpu", ], ) @@ -2293,6 +2361,7 @@ tf_xla_py_test( "//tensorflow/python:platform_test", "//tensorflow/python/eager:function", "//tensorflow/python/ops/risc:risc_ops", + "//tensorflow/python/platform:client_testlib", ], ) @@ -2325,6 +2394,7 @@ cuda_py_test( ":xla_test", "//tensorflow/python:constant_op", "//tensorflow/python:framework", + "//tensorflow/python/platform:client_testlib", ], ) @@ -2370,6 +2440,7 @@ tf_xla_py_test( "//tensorflow/python:resource_variable_ops", "//tensorflow/python:session", "//tensorflow/python:variables", + "//tensorflow/python/platform:client_testlib", "//tensorflow/python/tpu:tpu_lib", "@absl_py//absl/testing:parameterized", ], @@ -2404,7 +2475,6 @@ tf_xla_py_test( python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip - "notap", ], use_xla_device = False, # Uses tf.function(jit_compile=True) deps = [ diff --git a/tensorflow/compiler/tests/cond_test.py b/tensorflow/compiler/tests/cond_test.py index 9119095b6a3..db767f0d554 100644 --- a/tensorflow/compiler/tests/cond_test.py +++ b/tensorflow/compiler/tests/cond_test.py @@ -24,6 +24,7 @@ from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops +from tensorflow.python.ops import cond from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import control_flow_switch_case from tensorflow.python.ops import math_ops @@ -44,7 +45,7 @@ class CondTest(xla_test.XLATestCase): @def_function.function def f(): ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=1) - output = control_flow_ops.cond( + output = cond.cond( constant_op.constant(True), lambda: ta.write(0, 5.), lambda: ta.write(0, 10.)) @@ -64,7 +65,7 @@ class CondTest(xla_test.XLATestCase): @def_function.function def f(): ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=1) - output = control_flow_ops.cond( + output = cond.cond( constant_op.constant(False), lambda: ta.write(0, 5.), lambda: ta.write(0, 10.)) @@ -84,7 +85,7 @@ class CondTest(xla_test.XLATestCase): def f(): ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=1) - output = control_flow_ops.cond( + output = cond.cond( constant_op.constant(True), lambda: ta.write(0, 5.), lambda: ta.write(0, 10.)) @@ -112,7 +113,7 @@ class CondTest(xla_test.XLATestCase): def if_false(): return 5. - output = control_flow_ops.cond( + output = cond.cond( constant_op.constant(True), if_true, if_false) self.assertAllEqual(1., @@ -142,7 +143,7 @@ class CondTest(xla_test.XLATestCase): def if_false(): return 5. - return control_flow_ops.cond( + return cond.cond( constant_op.constant(True), if_true, if_false) output = xla.compile(f) @@ -169,7 +170,7 @@ class CondTest(xla_test.XLATestCase): def if_false(): return array_ops.fill([p], 5.) 
- output = control_flow_ops.cond( + output = cond.cond( constant_op.constant(True), if_true, if_false) with self.assertRaisesRegex(errors.InvalidArgumentError, @@ -202,7 +203,7 @@ class CondTest(xla_test.XLATestCase): def if_false(): return array_ops.fill([p], 5.) - return control_flow_ops.cond(condition, if_true, if_false) + return cond.cond(condition, if_true, if_false) output = xla.compile(f) @@ -304,7 +305,7 @@ class CondTest(xla_test.XLATestCase): xla_context.Enter() for pred in True, False: - cond_out = control_flow_ops.cond( + cond_out = cond.cond( array_ops.placeholder_with_default(pred, []), lambda: constant_op.constant(2.), lambda: constant_op.constant(1.)) diff --git a/tensorflow/compiler/tests/eager_test.py b/tensorflow/compiler/tests/eager_test.py index 0b8b5c2d866..6ef5a7c9a9f 100644 --- a/tensorflow/compiler/tests/eager_test.py +++ b/tensorflow/compiler/tests/eager_test.py @@ -21,7 +21,6 @@ from tensorflow.core.protobuf import config_pb2 from tensorflow.python.eager import backprop from tensorflow.python.eager import context from tensorflow.python.eager import def_function -from tensorflow.python.eager import function from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import indexed_slices @@ -30,7 +29,7 @@ from tensorflow.python.layers import convolutional from tensorflow.python.layers import pooling from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops_stack # pylint: disable=g-direct-tensorflow-import -from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import cond from tensorflow.python.ops import embedding_ops from tensorflow.python.ops import functional_ops from tensorflow.python.ops import gen_random_ops @@ -345,12 +344,12 @@ class EagerFunctionTest(xla_test.XLATestCase): v = resource_variable_ops.ResourceVariable(1.0) w = resource_variable_ops.ResourceVariable(0.0) - @function.defun_with_attributes(attributes={'_noinline': True}) + @def_function.function(experimental_attributes={'_noinline': True}) def g(x): w.assign(w.read_value() + x) return v.read_value() + x * w.read_value() - @function.defun_with_attributes(attributes={'_noinline': True}) + @def_function.function(experimental_attributes={'_noinline': True}) def f(): return g(1.0) + g(2.0) + g(3.0) + g(4.0) + g(5.0) @@ -362,11 +361,11 @@ class EagerFunctionTest(xla_test.XLATestCase): with self.test_scope(): v = resource_variable_ops.ResourceVariable(10.0) - @function.defun_with_attributes(attributes={'_noinline': True}) + @def_function.function(experimental_attributes={'_noinline': True}) def g(): return v.read_value() - @function.defun_with_attributes(attributes={'_noinline': True}) + @def_function.function(experimental_attributes={'_noinline': True}) def f(): return g() + g() + g() + g() + g() @@ -376,11 +375,11 @@ class EagerFunctionTest(xla_test.XLATestCase): with self.test_scope(): v = resource_variable_ops.ResourceVariable(0.0) - @function.defun_with_attributes(attributes={'_noinline': True}) + @def_function.function(experimental_attributes={'_noinline': True}) def g(x): v.assign(x) - @function.defun_with_attributes(attributes={'_noinline': True}) + @def_function.function(experimental_attributes={'_noinline': True}) def f(): g(1.0) g(2.0) @@ -637,7 +636,7 @@ class EagerFunctionTest(xla_test.XLATestCase): def f(pred, value): fn1 = lambda: math_ops.add(value, 1.0) fn2 = lambda: math_ops.subtract(value, 1.0) - return control_flow_ops.cond(pred, fn1, fn2) + return 
cond.cond(pred, fn1, fn2) plus_one = f(constant_op.constant(True), constant_op.constant(10.0)) minus_one = f(constant_op.constant(False), constant_op.constant(10.0)) diff --git a/tensorflow/compiler/tests/giant_const_op_test.py b/tensorflow/compiler/tests/giant_const_op_test.py index c0f4b47be01..014b9d5f1eb 100644 --- a/tensorflow/compiler/tests/giant_const_op_test.py +++ b/tensorflow/compiler/tests/giant_const_op_test.py @@ -56,16 +56,6 @@ def get_tpu_strategy(): # tensors. class GiantConstOp(test.TestCase): - def setUp(self): - super(GiantConstOp, self).setUp() - # Make sure TF_XLA_FLAGS is not already set to avoid dropping the existing - # value silently. - assert "TF_XLA_FLAGS" not in os.environ - - # Disable tfxla constant folding that always creates full Tensors and will - # fail for giant tensors. - os.environ["TF_XLA_FLAGS"] = "--tf_xla_disable_constant_folding=true" - # Verifies that graphs containing giant const tensors that won't fit in memory # are compiled correctly to HLO. def testGiantConst(self): @@ -106,4 +96,12 @@ class GiantConstOp(test.TestCase): self.assertAllEqual(output, expected) if __name__ == "__main__": + # Make sure TF_XLA_FLAGS is not already set to avoid dropping the existing + # value silently. + assert "TF_XLA_FLAGS" not in os.environ + + # Disable tfxla constant folding that always creates full Tensors and will + # fail for giant tensors. + os.environ["TF_XLA_FLAGS"] = "--tf_xla_disable_constant_folding=true" + test.main() diff --git a/tensorflow/compiler/tests/jit_test.py b/tensorflow/compiler/tests/jit_test.py index b0f252658a8..7f24610f3e8 100644 --- a/tensorflow/compiler/tests/jit_test.py +++ b/tensorflow/compiler/tests/jit_test.py @@ -28,6 +28,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import function from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops +from tensorflow.python.ops import cond from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import math_ops @@ -358,7 +359,7 @@ class XlaCompilationTest(test.TestCase): c = array_ops.placeholder(dtypes.bool) with jit_scope(): z = x + 1.0 - w = control_flow_ops.cond(c, lambda: z, lambda: y) + w = cond.cond(c, lambda: z, lambda: y) t = math_ops.add(z, w) # If JIT compilation chooses to cluster z and t, then execution will diff --git a/tensorflow/compiler/tests/lstm.py b/tensorflow/compiler/tests/lstm.py index 748d5f0a850..8dd2a786155 100644 --- a/tensorflow/compiler/tests/lstm.py +++ b/tensorflow/compiler/tests/lstm.py @@ -27,7 +27,7 @@ from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops -from tensorflow.python.ops import variables +from tensorflow.python.ops import variable_v1 def Clip(x): @@ -115,7 +115,7 @@ def LSTMLayer(cell_name, weights, m, c, x_seq, pad_seq): def RandomVar(shape, name=None): """Returns a variable of the given shape initialized to random values.""" - return variables.VariableV1( + return variable_v1.VariableV1( random_ops.random_uniform(shape), dtype=dtypes.float32, name=name) diff --git a/tensorflow/compiler/tests/stateless_random_ops_test.py b/tensorflow/compiler/tests/stateless_random_ops_test.py index 012fe158e1c..61c187cf7c4 100644 --- a/tensorflow/compiler/tests/stateless_random_ops_test.py +++ b/tensorflow/compiler/tests/stateless_random_ops_test.py @@ -368,9 +368,11 @@ class 
StatelessRandomOpsTest(xla_test.XLATestCase, parameterized.TestCase): self._testParameterizedTruncatedNormal(-1., 1., -2., 2.) def testParameterizedTruncatedNormalRightTail(self): + self.skipTest('b/276957102') self._testParameterizedTruncatedNormal(0., 1., 4., 20., variance_rtol=2e-2) def testParameterizedTruncatedNormalLeftTail(self): + self.skipTest('b/276957102') self._testParameterizedTruncatedNormal( 0., 1., -20., -4., variance_rtol=5e-2) diff --git a/tensorflow/compiler/tests/tensor_array_ops_test.py b/tensorflow/compiler/tests/tensor_array_ops_test.py index 9a6b3dd0a73..d21a8fb4cc5 100644 --- a/tensorflow/compiler/tests/tensor_array_ops_test.py +++ b/tensorflow/compiler/tests/tensor_array_ops_test.py @@ -860,7 +860,7 @@ class TensorArrayTest(xla_test.XLATestCase): # c = lambda i, acc: i < 5 # def b(i, acc): - # x1 = control_flow_ops.cond( + # x1 = cond.cond( # math_ops.equal(i, 0), lambda: x, # lambda: math_ops.multiply(acc.read(i - 1), 2.0)) # return i + 1, acc.write(i, x1) diff --git a/tensorflow/compiler/tests/where_op_test.py b/tensorflow/compiler/tests/where_op_test.py index 186877c03fe..e150b52567c 100644 --- a/tensorflow/compiler/tests/where_op_test.py +++ b/tensorflow/compiler/tests/where_op_test.py @@ -16,15 +16,23 @@ # pylint: disable=g-direct-tensorflow-import from tensorflow.compiler.tests import xla_test +from tensorflow.python.framework import config from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.platform import test +from tensorflow.python.tpu import tpu # pylint: enable=g-direct-tensorflow-import class WhereOpTest(xla_test.XLATestCase): + def __init__(self, method_name="runTest"): + super(WhereOpTest, self).__init__(method_name) + if config.list_logical_devices("TPU"): + with self.session() as sess: + sess.run(tpu.initialize_system()) + def testWhere(self): """Test first form of where (return indices).""" diff --git a/tensorflow/compiler/tests/xla_call_module_test.py b/tensorflow/compiler/tests/xla_call_module_test.py index e0923f32bac..01f30718217 100644 --- a/tensorflow/compiler/tests/xla_call_module_test.py +++ b/tensorflow/compiler/tests/xla_call_module_test.py @@ -13,12 +13,15 @@ # limitations under the License. 
# ============================================================================== """Tests for XLA call module op wrapper.""" - +from typing import Tuple import unittest + import numpy as np from tensorflow.compiler.tests import xla_test +from tensorflow.compiler.tf2xla.ops import gen_xla_ops from tensorflow.compiler.tf2xla.python import xla + from tensorflow.python.eager import def_function from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors @@ -27,6 +30,14 @@ from tensorflow.python.ops import array_ops from tensorflow.python.platform import googletest +def serialize(module_str: str) -> Tuple[str, int]: + # TODO(b/274838200): error importing xla_extension in OSS + # target_version = '0.9.0' # TODO(gleasonk): use APIs to get this + # return xla_extension.mlir.serialize_portable_artifact( + # module_str, target_version), 4 + return module_str, 3 + + class XlaCallModuleOpTest(xla_test.XLATestCase): def _assertOpOutputMatchesExpected(self, @@ -64,7 +75,7 @@ class XlaCallModuleOpTest(xla_test.XLATestCase): def f(x): # sin(cos(x)) - module = """ + module, version = serialize(""" module @jit_f.0 { func.func public @main(%arg0: tensor<3xf32>) -> tensor<3xf32> { %0 = stablehlo.cosine %arg0 : tensor<3xf32> @@ -72,8 +83,8 @@ module @jit_f.0 { return %1 : tensor<3xf32> } } -""" - return xla.call_module([x], version=2, +""") + return xla.call_module([x], version=version, module=module, Tout=[x.dtype], Sout=[x.shape]) self._assertOpOutputMatchesExpected(f, (x,), (np.sin(np.cos(x)),)) @@ -84,7 +95,7 @@ module @jit_f.0 { def f(x): # return x >= 1 - module = """ + module, version = serialize(""" module @jit_f_jax.0 { func.func public @main(%arg0: tensor) -> tensor { %0 = stablehlo.constant dense<1> : tensor @@ -92,8 +103,8 @@ module @jit_f_jax.0 { return %1 : tensor } } -""" - return xla.call_module([x], version=2, +""") + return xla.call_module([x], version=version, module=module, Tout=[res.dtype], Sout=[res.shape]) @@ -106,7 +117,7 @@ module @jit_f_jax.0 { def f(x, y): # (sin(x), cos(y)) - module = """ + module, version = serialize(""" module @jit_f.0 { func.func public @main(%arg0: tensor<3xf32>, %arg1: tensor<4xf64>) -> (tensor<3xf32>, tensor<4xf64>) { %0 = stablehlo.sine %arg0 : tensor<3xf32> @@ -114,8 +125,8 @@ module @jit_f.0 { return %0, %1 : tensor<3xf32>, tensor<4xf64> } } -""" - return xla.call_module([x, y], version=2, +""") + return xla.call_module([x, y], version=version, module=module, Tout=[x.dtype, y.dtype], Sout=[x.shape, y.shape]) @@ -128,16 +139,15 @@ module @jit_f.0 { def f(x): # x: f32[2, b] # Module takes another argument which is the value of b # (sin(x), x.shape[1]) - module = """ + module, version = serialize(""" module @jit_f.0 { func.func public @main(%arg0: tensor, %arg1: tensor<2x?xf32>) -> (tensor<2x?xf32>, tensor) { %0 = stablehlo.sine %arg1 : tensor<2x?xf32> return %0, %arg0 : tensor<2x?xf32>, tensor } } -""" - return xla.call_module([x], - version=2, +""") + return xla.call_module([x], version=version, module=module, Tout=[x.dtype, np.int32], Sout=[(None, 3), ()], @@ -151,17 +161,16 @@ module @jit_f.0 { def f(x): # x: f32[2, b] # Module takes another argument which is the value of b # (sin(x), x.shape[1]) - module = """ + module, version = serialize(""" module @jit_f.0 { func.func public @main(%arg0: tensor, %arg1: tensor<2x?xf32>) -> (tensor<2x?xf32>, tensor) { %0 = stablehlo.sine %arg1 : tensor<2x?xf32> return %0, %arg0 : tensor<2x?xf32>, tensor } } -""" +""") return xla.call_module([x], - version=2, - module=module, + 
module=module, version=version, Tout=[x.dtype, np.int64], Sout=[(None, 3), ()], dim_args_spec=['0.1']) @@ -174,7 +183,7 @@ module @jit_f.0 { def f(x): # x: f32[2, b] # (sin(x), x.shape[1]) - module = """ + module, version = serialize(""" module @jit_f.0 { func.func public @main(%arg1: tensor<2x?xf32>) -> (tensor<2x?xf32>, tensor) { %arg0_new = "stablehlo.get_dimension_size"(%arg1) {dimension = 1 : i64} : (tensor<2x?xf32>) -> tensor @@ -186,9 +195,9 @@ module @jit_f.0 { return %0, %arg0 : tensor<2x?xf32>, tensor } } -""" - return xla.call_module([x], version=2, - module=module, +""") + return xla.call_module([x], + module=module, version=version, Tout=[x.dtype, np.int32], Sout=[(None, 3), ()]) @@ -201,7 +210,7 @@ module @jit_f.0 { # Module takes two prefix arguments with the values of b and c # return (sin(x + y), x.shape[1]) - module = """ + module, version = serialize(""" module @jit_f.0 { func.func public @main(%arg0: tensor, %arg1: tensor, %arg2: tensor<2x?x?xf32>, %arg3: tensor<2x?x?xf32>) -> (tensor<2x?x?xf32>, tensor) { %0 = stablehlo.add %arg2, %arg3 : tensor<2x?x?xf32> @@ -209,13 +218,12 @@ module @jit_f.0 { return %1, %arg0 : tensor<2x?x?xf32>, tensor } } -""" +""") dim_args_spec = ['0.1', '0.2'] def f(x, y): return xla.call_module([x, y], - version=2, - module=module, + module=module, version=version, Tout=[x.dtype, np.int32], Sout=[(None, 3), ()], dim_args_spec=dim_args_spec) @@ -274,7 +282,7 @@ module @jit_f.0 { x = np.float32(0.) # returns x + 2. on CPU, x + 3. on GPU and x + 4. on TPU - module = """ + module, version = serialize(""" module @jit_f.0 { func.func public @main(%arg_platform_idx: tensor, %arg0: tensor) -> tensor { %to_add = "stablehlo.case"(%arg_platform_idx) ({ @@ -291,12 +299,11 @@ module @jit_f.0 { return %0 : tensor } } -""" +""") platforms = ['CPU', 'CUDA', 'TPU'] def f(x): - return xla.call_module([x], - version=3, + return xla.call_module([x], version=version, module=module, Tout=[np.float32], Sout=[()], @@ -310,7 +317,7 @@ module @jit_f.0 { y = np.arange(3., dtype=np.float32) # returns x + x on CPU and x - x on TPU - module = """ + module, version = serialize(""" module @jit_f.0 { func.func public @main(%arg_platform_idx: tensor, %arg_dim0: tensor, %arg0: tensor, %arg1: tensor) -> tensor { %res = "stablehlo.case"(%arg_platform_idx) ({ @@ -323,10 +330,9 @@ module @jit_f.0 { return %res : tensor } } -""" +""") def f(x, y): - return xla.call_module([x, y], - version=3, + return xla.call_module([x, y], version=version, module=module, Tout=[np.float32], Sout=[(None,)], @@ -341,18 +347,17 @@ module @jit_f.0 { """Error reporting for the platforms attribute.""" x = np.float32(0.) - module = """ + module_str = """ module @jit_f.0 { func.func public @main(%arg_platform_idx: tensor, %arg0: tensor) -> tensor { return %arg0 : tensor } } """ + module, version = serialize(module_str) platforms = [] - version = 3 def f(x): - return xla.call_module([x], - version=version, + return xla.call_module([x], version=version, module=module, Tout=[np.float32], Sout=[()], @@ -376,17 +381,6 @@ module @jit_f.0 { 'and 0 dimension arguments.'): self._assertOpOutputMatchesExpected(f, (x,), (x,)) - # Same if the version is 2 - platforms = ['CPU', 'CUDA', 'TPU'] - version = 2 - with self.assertRaisesRegex( - errors.InvalidArgumentError, - 'Incorrect number of arguments passed to XlaCallModule: 1. 
' - 'The module takes 2 arguments of which 0 platform index arguments ' - 'and 0 dimension arguments.'): - self._assertOpOutputMatchesExpected(f, (x,), (x,)) - - version = 3 platforms = ['RANDOM_PLATFORM_1', 'RANDOM_PLATFORM_2'] with self.assertRaisesRegex( errors.NotFoundError, @@ -403,7 +397,7 @@ module @jit_f.0 { self._assertOpOutputMatchesExpected(f, (x,), (x,)) # The module cannot have i64 %arg_platform_idx - module = module.replace('i32', 'i64') + module, version = serialize(module_str.replace('i32', 'i64')) platforms = ['CPU', 'CUDA', 'TPU'] with self.assertRaisesRegex( errors.InvalidArgumentError, @@ -413,13 +407,13 @@ module @jit_f.0 { self._assertOpOutputMatchesExpected(f, (x,), (x,)) # A module without the platform index argument - module = """ + module, version = serialize(""" module @jit_f.0 { func.func public @main(%arg0: tensor) -> tensor { return %arg0 : tensor } } -""" +""") with self.assertRaisesRegex( errors.InvalidArgumentError, 'The module should have 1 platform index arguments and 0 dimension ' @@ -432,7 +426,7 @@ module @jit_f.0 { def f(x): # x: f32[b, 5] # return np.arange(x.shape[0], dtype=np.int32) - module = """ + module, version = serialize(""" module @jit_fun.1 { func.func public @main(%arg0: tensor, %arg1: tensor) -> tensor { %0 = stablehlo.reshape %arg0 : (tensor) -> tensor<1xi32> @@ -440,8 +434,8 @@ module @jit_fun.1 { return %1 : tensor } } -""" - return xla.call_module([x,], version=2, +""") + return xla.call_module([x,], version=version, module=module, Tout=[res.dtype], Sout=[(None,)], @@ -453,17 +447,16 @@ module @jit_fun.1 { """We can construct the tf.Graph on all platforms.""" x = np.float32(0.) - module = """ + module, version = serialize(""" module @jit_f.0 { func.func public @main(%arg_platform_idx: tensor, %arg0: tensor) -> tensor { return %arg0 : tensor } } -""" +""") platforms = ['TPU'] # the module is compileable only on TPU def f(x): - return xla.call_module([x], - version=3, + return xla.call_module([x], version=version, module=module, Tout=[np.float32], Sout=[()], @@ -476,7 +469,7 @@ module @jit_f.0 { res = x.reshape((-1,)) def f(x): # x: f32[b, 3] - module = """ + module, version = serialize(""" module @jit_fun_flat_jax { func.func public @main(%arg0: tensor, %arg1: tensor) -> tensor { %0 = stablehlo.constant dense<3> : tensor @@ -486,8 +479,8 @@ module @jit_fun_flat_jax { return %3 : tensor } } -""" - return xla.call_module([x], +""") + return xla.call_module([x], version=version, module=module, Tout=[res.dtype], Sout=[(None,)], @@ -500,7 +493,7 @@ module @jit_fun_flat_jax { res = np.ones((3, 2), dtype=np.float32) def f(x): # x: f32[b, 4] - module = """ + module, version = serialize(""" module @jit_fun_flat_jax { func.func public @main(%arg0: tensor, %arg1: tensor) -> tensor { %0 = stablehlo.constant dense<0> : tensor @@ -512,8 +505,8 @@ module @jit_fun_flat_jax { return %5 : tensor } } -""" - return xla.call_module([x], +""") + return xla.call_module([x], version=version, module=module, Tout=[res.dtype], Sout=[(None, 2)], @@ -526,7 +519,7 @@ module @jit_fun_flat_jax { res = x[-1, :] def f(x): # x: f32[b, 4] - module = """ + module, version = serialize(""" module @jit_fun_flat_jax { func.func public @main(%arg0: tensor, %arg1: tensor) -> tensor<4xf32> { %0 = stablehlo.constant dense<-1> : tensor @@ -543,8 +536,8 @@ module @jit_fun_flat_jax { return %12 : tensor<4xf32> } } -""" - return xla.call_module([x], +""") + return xla.call_module([x], version=version, module=module, Tout=[x.dtype], Sout=[(4,)], @@ -558,7 +551,7 @@ module 
@jit_fun_flat_jax { res = x # The update should be a nop def f(x, idx): # x: f32[b, 4] idx: i32 - module = """ + module, version = serialize(""" module @jit_fun_flat_jax { func.func public @main(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { %0 = stablehlo.constant dense<0> : tensor @@ -570,8 +563,8 @@ module @jit_fun_flat_jax { return %5 : tensor } } -""" - return xla.call_module([x, idx], +""") + return xla.call_module([x, idx], version=version, module=module, Tout=[res.dtype], Sout=[(None, 4)], @@ -586,7 +579,7 @@ module @jit_fun_flat_jax { def f(x, y): # x: f32[b, 4] y: f32[2, b, 4] # return (np.broadcast_to(x, y.shape), x + y) - module = """ + module, version = serialize(""" module @jit_fun.0 { func.func public @main(%arg0: tensor, %arg1: tensor, %arg2: tensor<2x?x4xf32>) -> (tensor<2x?x4xf32>, tensor<2x?x4xf32>) { %0 = stablehlo.constant dense<2> : tensor<1xi32> @@ -598,8 +591,8 @@ module @jit_fun.0 { return %5, %6 : tensor<2x?x4xf32>, tensor<2x?x4xf32> } } -""" - return xla.call_module([x, y], version=2, +""") + return xla.call_module([x, y], version=version, module=module, Tout=[res[0].dtype, res[1].dtype], Sout=[(2, None, 4), (2, None, 4)], @@ -613,7 +606,7 @@ module @jit_fun.0 { res = np.sum(x) * x.shape[0] def f(x): # x: i32[b] - module = """ + module, version = serialize(""" module @jit_fun{ func.func public @main(%arg0: tensor, %arg1: tensor) -> tensor { %0 = stablehlo.constant dense<0> : tensor @@ -626,8 +619,8 @@ module @jit_fun{ return %2 : tensor } } -""" - return xla.call_module([x], version=1, +""") + return xla.call_module([x], version=version, module=module, Tout=[res.dtype], Sout=[res.shape], @@ -640,7 +633,7 @@ module @jit_fun{ res = np.arange(3, dtype=np.float32).reshape(3, 1) * 5 def f(x): # x: f32[b, 5] - module = """ + module, version = serialize(""" module @jit_fun_flat_jax { func.func public @main(%arg0: tensor, %arg1: tensor) -> tensor { %0 = stablehlo.constant dense<0.000000e+00> : tensor @@ -656,8 +649,8 @@ module @jit_fun_flat_jax { return %5 : tensor } } -""" - return xla.call_module([x,], +""") + return xla.call_module([x,], version=version, module=module, Tout=[res.dtype], Sout=[(None, 1)], @@ -671,7 +664,7 @@ module @jit_fun_flat_jax { res = np.arange(x.shape[0], dtype=np.int32) def f(x): # x: f32[b] - module = """ + module, version = serialize(""" module @jit_fun_3 { func.func public @main(%arg0: tensor, %arg1: tensor) -> tensor { %0 = call @f(%arg0, %arg1) : (tensor, tensor) -> tensor @@ -683,8 +676,8 @@ module @jit_fun_3 { return %1 : tensor } } -""" - return xla.call_module([x,], version=2, +""") + return xla.call_module([x,], version=version, module=module, Tout=[res.dtype], Sout=[()], @@ -697,15 +690,14 @@ module @jit_fun_3 { res = x def f(x): # x: f32[b] - module = """ + module, version = serialize(""" module @jit_fun_3 { func.func public @main(%arg0: tensor, %arg1: tensor) -> tensor { return %arg1 : tensor } } -""" - return xla.call_module([x], - version=2, +""") + return xla.call_module([x], version=version, module=module, Tout=[res.dtype], Sout=[()], @@ -723,7 +715,7 @@ module @jit_fun_3 { res1 = np.int64(5) def f(x): # x: f32[b] - module = """ + module, version = serialize(""" module @jit_fun_flat_jax { func.func public @main(%arg0: tensor, %arg1: tensor) -> (tensor, tensor) { %0 = stablehlo.constant dense<0> : tensor @@ -744,8 +736,8 @@ module @jit_fun_flat_jax { return %1#0, %1#1 : tensor, tensor } } -""" - return xla.call_module([x,], version=2, +""") + return xla.call_module([x,], version=version, module=module, 
Tout=[res0.dtype, res1.dtype], Sout=[(None,), res1.shape], @@ -753,6 +745,34 @@ module @jit_fun_flat_jax { self._assertOpOutputMatchesExpected(f, (x,), (res0, res1)) + def test_op_backward_compatibility(self): + """Test for ensuring XlaCallModuleOp backward compatibility.""" + x = np.array([1.0, 2.0, 3.0], dtype=np.float32) + + def f(x): + # sin(cos(x)) + module, version = serialize(""" +module @jit_f.0 { + func.func public @main(%arg0: tensor<3xf32>) -> tensor<3xf32> { + %0 = stablehlo.cosine %arg0 : tensor<3xf32> + %1 = stablehlo.sine %0 : tensor<3xf32> + return %1 : tensor<3xf32> + } +} +""") + # Create the raw XlaCallModule op directly instead of calling + # `xla.call_module`, which handles default values for missing + # attributes. + return gen_xla_ops.xla_call_module( + [x], + version=version, + module=module, + Tout=[x.dtype], + Sout=[x.shape], + ) + + self._assertOpOutputMatchesExpected(f, (x,), (np.sin(np.cos(x)),)) + if __name__ == '__main__': # This test is using Tensorflow sessions which are not compatible with eager diff --git a/tensorflow/compiler/tests/xla_ops_test.py b/tensorflow/compiler/tests/xla_ops_test.py index c4fb8e3f44c..f8ce172dc34 100644 --- a/tensorflow/compiler/tests/xla_ops_test.py +++ b/tensorflow/compiler/tests/xla_ops_test.py @@ -289,15 +289,44 @@ class XlaOpsNumericalTest(xla_test.XLATestCase, parameterized.TestCase): padding_value=7, padding_low=[2, 1], padding_high=[1, 2], - padding_interior=[1, 0]) + padding_interior=[1, 0], + ) self._assertOpOutputMatchesExpected( pad_fn, args=(np.arange(4, dtype=np.int32).astype(dtype).reshape([2, 2]),), expected=np.array( - [[7, 7, 7, 7, 7], [7, 7, 7, 7, 7], [7, 0, 1, 7, 7], - [7, 7, 7, 7, 7], [7, 2, 3, 7, 7], [7, 7, 7, 7, 7]], - dtype=dtype)) + [ + [7, 7, 7, 7, 7], + [7, 7, 7, 7, 7], + [7, 0, 1, 7, 7], + [7, 7, 7, 7, 7], + [7, 2, 3, 7, 7], + [7, 7, 7, 7, 7], + ], + dtype=dtype, + ), + ) + + def testSetDynamicDimensionSize(self): + dynamic_size = 7 + + # XLA doesn't support this for bfloat16. + for dtype in set(self.numeric_types).intersection( + set([np.int32, np.float32, np.float64, np.complex64])): + + def xla_set_dynamic_dimension_size_fn(x): + # Tell XLA to cut the array to size=dynamic_size.
+ return gen_xla_ops.xla_set_dynamic_dimension_size( + x, dim_index=0, size=dynamic_size + ) + + a = np.arange(10, dtype=np.int32).astype(dtype) + expected = a[:dynamic_size] + + self._assertOpOutputMatchesExpected( + xla_set_dynamic_dimension_size_fn, args=(a,), expected=expected + ) def testPadNegative(self): for dtype in self.numeric_types: @@ -574,6 +603,41 @@ class XlaOpsNumericalTest(xla_test.XLATestCase, parameterized.TestCase): args=(values_1, values_2), expected=(values_1, values_2)) + @test_util.disable_mlir_bridge('Not supported yet') + def testScatter(self): + test_array = np.arange(9).astype(np.int32).reshape((3, 3)) + scatter_indices = np.array([0, 2], dtype=np.int32) + updates = np.array([[10, 20, 30], [70, 80, 90]], dtype=np.int32) + + dnums = xla_data_pb2.ScatterDimensionNumbers() + dnums.update_window_dims.append(1) + dnums.inserted_window_dims.append(0) + dnums.scatter_dims_to_operand_dims.append(0) + dnums.index_vector_dim = 1 + + add_numbers = function.Defun(np.int32, np.int32)(lambda x, y: x + y) + + def test_fn( + scatter_input, + scatter_indices, + scatter_updates, + ): + return gen_xla_ops.xla_scatter( + scatter_input, + scatter_indices, + scatter_updates, + add_numbers, + dnums.SerializeToString(), + indices_are_sorted=False, + ) + + expected = np.array([[10, 21, 32], [3, 4, 5], [76, 87, 98]], dtype=np.int32) + self._assertOpOutputMatchesExpected( + test_fn, + args=(test_array, scatter_indices, updates), + expected=expected, + ) + def testSelectAndScatter(self): for dtype in set(self.numeric_types).intersection( set([dtypes.bfloat16.as_numpy_dtype, np.float32])): diff --git a/tensorflow/compiler/tf2tensorrt/BUILD b/tensorflow/compiler/tf2tensorrt/BUILD index d0ce575deef..bb48f9e806b 100644 --- a/tensorflow/compiler/tf2tensorrt/BUILD +++ b/tensorflow/compiler/tf2tensorrt/BUILD @@ -102,9 +102,9 @@ tf_cuda_cc_test( ":utils", "//tensorflow/compiler/xla/stream_executor/gpu:gpu_init", "//tensorflow/core:lib", - "//tensorflow/core/platform:stream_executor", "//tensorflow/core:test", "//tensorflow/core:test_main", + "//tensorflow/core/platform:stream_executor", ] + if_tensorrt([ ":tensorrt_lib", ]) + select({ @@ -211,11 +211,11 @@ cc_library( visibility = ["//visibility:private"], deps = [ ":trt_conversion", + "//tensorflow/cc:cc_ops", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/framework:tensor_testutil", "@com_google_absl//absl/strings", "@com_google_googletest//:gtest", - "//tensorflow/core:protos_all_cc", - "//tensorflow/cc:cc_ops", - "//tensorflow/core/framework:tensor_testutil", ] + if_tensorrt([":tensorrt_lib"]), ) @@ -230,8 +230,8 @@ tf_cuda_cc_test( ], deps = [ ":testutils", - "//tensorflow/core:test_main", "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test_main", "//tensorflow/core/platform:protobuf", ] + if_tensorrt([ ":tensorrt_lib", @@ -247,6 +247,7 @@ cc_library( copts = tf_copts(), visibility = ["//visibility:public"], deps = [ + ":common_utils", ":trt_allocator", ":trt_conversion", ":trt_engine_utils", @@ -254,19 +255,18 @@ cc_library( ":trt_plugins", ":trt_resources", ":utils", - ":common_utils", - "@com_google_absl//absl/memory", - "@com_google_absl//absl/strings", "//tensorflow/core:framework", "//tensorflow/core:gpu_headers_lib", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:lib_proto_parsing", - "//tensorflow/core/platform:stream_executor", "//tensorflow/core:stream_executor_headers_lib", "//tensorflow/core/common_runtime:core_cpu_lib_no_ops", "//tensorflow/core/grappler/costs:graph_properties", 
+ "//tensorflow/core/platform:stream_executor", "//tensorflow/core/profiler/lib:traceme", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", ] + if_tensorrt([ ":tensorrt_lib", "@local_config_cuda//cuda:cuda_headers", @@ -285,14 +285,14 @@ cc_library( ":trt_logging", ":trt_plugins", ":trt_resources", - "@com_google_absl//absl/memory", - "@com_google_absl//absl/strings", "//tensorflow/core:framework", "//tensorflow/core:gpu_headers_lib", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:lib_proto_parsing", "//tensorflow/core/profiler/lib:traceme", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", ] + if_tensorrt([":tensorrt_lib"]) + tf_custom_op_library_additional_deps(), alwayslink = 1, ) @@ -342,16 +342,11 @@ tf_cuda_cc_test( "nomac", ], deps = [ + ":testutils", + ":trt_conversion", ":trt_op_kernels", ":trt_op_libs", ":trt_resources", - ":trt_conversion", - ":testutils", - "@com_google_googletest//:gtest", - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:span", - "//third_party/eigen3", "//tensorflow/cc:cc_ops", "//tensorflow/cc:function_ops", "//tensorflow/cc:scope", @@ -362,10 +357,15 @@ tf_cuda_cc_test( "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", "//tensorflow/core:test_main", - "//tensorflow/core/kernels:ops_testutil", - "//tensorflow/core/kernels:function_ops", - "//tensorflow/core/kernels:array", "//tensorflow/core/framework:fake_input", + "//tensorflow/core/kernels:array", + "//tensorflow/core/kernels:function_ops", + "//tensorflow/core/kernels:ops_testutil", + "//third_party/eigen3", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@com_google_googletest//:gtest", ] + if_tensorrt([ "@local_config_cuda//cuda:cuda_headers", ]), @@ -401,17 +401,17 @@ tf_cuda_library( ], deps = [ ":common_utils", - ":trt_logging", - ":utils", ":trt_allocator", + ":trt_logging", ":trt_parameters", - "@com_google_absl//absl/strings", + ":utils", "//tensorflow/core:framework", "//tensorflow/core:framework_headers_lib", "//tensorflow/core:lib", + "//tensorflow/core:stream_executor_headers_lib", "//tensorflow/core/platform:status", "//tensorflow/core/profiler/lib:traceme", - "//tensorflow/core:stream_executor_headers_lib", + "@com_google_absl//absl/strings", ] + if_tensorrt([":tensorrt_lib"]), ) @@ -424,8 +424,8 @@ tf_cuda_library( ":common_utils", ":logger_registry", ":utils", - "@com_google_absl//absl/strings", "//tensorflow/core:lib_proto_parsing", + "@com_google_absl//absl/strings", ] + if_tensorrt([":tensorrt_lib"]), ) @@ -445,7 +445,6 @@ tf_custom_op_py_library( ":trt_ops", "//tensorflow/python:errors", "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:platform", "//tensorflow/python:resources", ], ) @@ -459,9 +458,9 @@ tf_cuda_library( copts = tf_copts(), deps = [ ":utils", - "@com_google_absl//absl/strings", - "//tensorflow/core:lib", "//tensorflow/core:framework", + "//tensorflow/core:lib", + "@com_google_absl//absl/strings", ] + if_tensorrt([":tensorrt_lib"]), ) @@ -485,10 +484,10 @@ tf_cuda_library( ":utils", "//tensorflow/core:framework_headers_lib", "//tensorflow/core:framework_lite", - "//tensorflow/core/grappler:op_types", - "//tensorflow/core:graph", "//tensorflow/core:gpu_runtime", + "//tensorflow/core:graph", "//tensorflow/core:lib_proto_parsing", + "//tensorflow/core/grappler:op_types", ] + if_tensorrt([":tensorrt_lib"]), ) 
@@ -557,8 +556,8 @@ tf_cuda_library( ], copts = tf_copts(), deps = [ - "@com_google_absl//absl/strings", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", ] + if_tensorrt([":tensorrt_lib"]), ) @@ -571,8 +570,8 @@ tf_cuda_library( copts = tf_copts(), deps = [ ":utils", - "//tensorflow/core:lib", "//tensorflow/core:framework", + "//tensorflow/core:lib", ] + if_tensorrt([":tensorrt_lib"]), ) @@ -602,10 +601,10 @@ tf_cuda_library( ], visibility = ["//tensorflow:__subpackages__"], deps = [ - ":utils", ":op_converter", - "@com_google_absl//absl/strings", + ":utils", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", ] + if_tensorrt([":tensorrt_lib"]), ) @@ -616,8 +615,8 @@ tf_cuda_library( ], copts = tf_copts(), deps = [ - ":utils", ":op_converter", + ":utils", "//tensorflow/core:lib", "//tensorflow/core/platform:env", "//tensorflow/core/platform:logging", @@ -704,27 +703,17 @@ tf_cuda_library( ":algorithm_selector", ":common_utils", ":logger_registry", - ":segment", - ":trt_allocator", - ":trt_parameters", - ":trt_plugins", - ":trt_logging", - ":trt_resources", - ":utils", - ":trt_weights", ":op_converter", ":op_converter_registry", - "@com_google_absl//absl/memory", - "@com_google_absl//absl/strings", + ":segment", + ":trt_allocator", + ":trt_logging", + ":trt_parameters", + ":trt_plugins", + ":trt_resources", + ":trt_weights", + ":utils", "//tensorflow/cc:array_ops", - "//tensorflow/core/common_runtime:core_cpu", - "//tensorflow/core/grappler/clusters:cluster", - "//tensorflow/core/grappler/optimizers:custom_graph_optimizer", - "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry", - "//tensorflow/core/grappler:grappler_item", - "//tensorflow/core/grappler:op_types", - "//tensorflow/core/grappler:utils", - "//tensorflow/core/grappler/utils:functions", "//tensorflow/core:framework", "//tensorflow/core:framework_lite", "//tensorflow/core:gpu_runtime", @@ -732,13 +721,23 @@ tf_cuda_library( "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", + "//tensorflow/core/common_runtime:core_cpu", "//tensorflow/core/grappler:devices", + "//tensorflow/core/grappler:grappler_item", + "//tensorflow/core/grappler:op_types", + "//tensorflow/core/grappler:utils", + "//tensorflow/core/grappler/clusters:cluster", "//tensorflow/core/grappler/clusters:virtual_cluster", "//tensorflow/core/grappler/costs:graph_properties", + "//tensorflow/core/grappler/optimizers:custom_graph_optimizer", + "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry", "//tensorflow/core/grappler/optimizers:meta_optimizer", - "//tensorflow/core/profiler/lib:traceme", + "//tensorflow/core/grappler/utils:functions", "//tensorflow/core/profiler/lib:annotated_traceme", + "//tensorflow/core/profiler/lib:traceme", "//tensorflow/tools/graph_transforms:transform_utils", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", ] + if_tensorrt([":tensorrt_lib"]) + tf_custom_op_library_additional_deps() + select({ ":use_efficient_nms_plugin": [":efficient_nms_plugin"], "//conditions:default": [], @@ -756,17 +755,13 @@ tf_cuda_cc_test( "nomac", ], deps = [ + ":testutils", + ":trt_conversion", ":trt_op_kernels", ":trt_op_libs", - ":trt_conversion", - ":testutils", - "@com_google_googletest//:gtest", - "@com_google_absl//absl/strings", "//tensorflow/cc:cc_ops", "//tensorflow/cc:ops", "//tensorflow/cc:scope", - "//tensorflow/core/grappler:grappler_item", - "//tensorflow/core/grappler/clusters:cluster", "//tensorflow/core:core_cpu", 
"//tensorflow/core:core_cpu_base", "//tensorflow/core:direct_session", @@ -775,6 +770,10 @@ tf_cuda_cc_test( "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", "//tensorflow/core:test_main", + "//tensorflow/core/grappler:grappler_item", + "//tensorflow/core/grappler/clusters:cluster", + "@com_google_absl//absl/strings", + "@com_google_googletest//:gtest", ] + if_tensorrt([":tensorrt_lib"]), ) @@ -791,31 +790,31 @@ tf_cuda_cc_test( "nomac", ], deps = [ - ":trt_logging", - ":trt_conversion", - ":trt_plugins", - ":trt_engine_utils", - ":utils", ":testutils", - "@com_google_googletest//:gtest", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:span", + ":trt_conversion", + ":trt_engine_utils", + ":trt_logging", + ":trt_plugins", + ":utils", "//tensorflow/cc:cc_ops", "//tensorflow/cc:cc_ops_internal", "//tensorflow/cc:ops", "//tensorflow/cc:scope", - "//tensorflow/core/grappler/costs:graph_properties", "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_base", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", "//tensorflow/core/framework:tensor_testutil", + "//tensorflow/core/grappler/costs:graph_properties", "//tensorflow/core/kernels:function_ops", "//tensorflow/core/kernels:identity_op", "//tensorflow/core/kernels:resource_variable_ops", - "//tensorflow/core:test", "//tensorflow/core/platform:status_matchers", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@com_google_googletest//:gtest", ] + if_tensorrt([ ":tensorrt_lib", "@local_config_cuda//cuda:cuda_headers", @@ -834,38 +833,38 @@ tf_cuda_cc_test( "nomac", ], deps = [ - ":trt_logging", + ":testutils", ":trt_conversion", ":trt_convert_api", - ":trt_plugins", ":trt_engine_utils", + ":trt_logging", ":trt_op_kernels", + ":trt_plugins", ":trt_resources", ":utils", - ":testutils", - "//tensorflow/compiler/jit:shape_inference", - "@com_google_googletest//:gtest", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:span", "//tensorflow/cc:cc_ops", "//tensorflow/cc:cc_ops_internal", "//tensorflow/cc:ops", "//tensorflow/cc:scope", + "//tensorflow/compiler/jit:shape_inference", "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_base", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:ops", "//tensorflow/core:protos_all_cc", - "//tensorflow/core/framework:tensor_testutil", "//tensorflow/core:test", "//tensorflow/core:test_main", - "//tensorflow/core/platform:status_matchers", - "//tensorflow/core/kernels:ops_testutil", - "//tensorflow/core/kernels:function_ops", + "//tensorflow/core/framework:tensor_testutil", "//tensorflow/core/kernels:array", + "//tensorflow/core/kernels:function_ops", "//tensorflow/core/kernels:nn", + "//tensorflow/core/kernels:ops_testutil", "//tensorflow/core/kernels:pooling_ops", + "//tensorflow/core/platform:status_matchers", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@com_google_googletest//:gtest", ] + if_tensorrt([ ":tensorrt_lib", "@local_config_cuda//cuda:cuda_headers", @@ -963,12 +962,12 @@ cc_library( ], copts = tf_copts(), deps = [ - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/strings", "//tensorflow/core:framework", "//tensorflow/core:graph", - "//tensorflow/core:lib_proto_parsing", "//tensorflow/core:lib", + "//tensorflow/core:lib_proto_parsing", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/strings", ] + if_tensorrt([":tensorrt_lib"]), ) diff 
--git a/tensorflow/compiler/tf2tensorrt/common/utils.cc b/tensorflow/compiler/tf2tensorrt/common/utils.cc index 69ecc84dca7..92166c2e79e 100644 --- a/tensorflow/compiler/tf2tensorrt/common/utils.cc +++ b/tensorflow/compiler/tf2tensorrt/common/utils.cc @@ -213,6 +213,11 @@ std::ostream& operator<<(std::ostream& os, const nvinfer1::DataType& v) { case nvinfer1::DataType::kHALF: os << "kHalf"; break; +#if IS_TRT_VERSION_GE(8, 6, 0, 0) + case nvinfer1::DataType::kFP8: + os << "kFP8"; + break; +#endif case nvinfer1::DataType::kINT8: os << "kINT8"; break; diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc index 79a60d2b1de..e809152c1e7 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc @@ -914,7 +914,7 @@ Status ConvertGraph(const TRTOptimizationPass::ConversionParams& params, } else { // Graph is not modified. LOG_WARNING_WITH_PREFIX << "Cannot replace " << msg - << " reason: " << status.error_message() + << " reason: " << status.message() << " (keeping original segment)."; } if (VLOG_IS_ON(1)) { diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc index 914576f7552..1c3a1903477 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc @@ -1010,11 +1010,11 @@ Status TrtNodeValidator::IsTensorRTCandidate(const Node* node) { &tensor_or_weights); if (!status.ok()) { VLOG(2) << "Failed to convert input `" << src_def.name() << "` to a " - << "TRT_TensorOrWeights: " << status.error_message(); + << "TRT_TensorOrWeights: " << status.message(); return errors::Internal( "Failed to convert at least one input to a TRT_TensorOrWeights: ", - status.error_message()); + status.message()); } inputs.push_back(tensor_or_weights); } @@ -1131,11 +1131,10 @@ Status Converter::ConvertNode(const NodeDef& node_def) { << output.DebugString(); Status status = AddTensorOrWeights(output_name, output); if (!status.ok()) { - return errors::Create( - static_cast(status.code()), - StrCat("Failed to add output for node: ", node_def.name(), ": ", - status.error_message()), - errors::GetPayloads(status)); + return errors::Create(static_cast(status.code()), + StrCat("Failed to add output for node: ", + node_def.name(), ": ", status.message()), + errors::GetPayloads(status)); } } return OkStatus(); @@ -1151,7 +1150,7 @@ Status Converter::AddInputTensor(const string& name, nvinfer1::DataType dtype, status = MaybeUpdateBatchSize(batch_size); if (!status.ok()) { return errors::CreateWithUpdatedMessage( - status, batch_size_error(name, status.error_message())); + status, batch_size_error(name, status.message())); } } ITensorProxyPtr tensor = network()->addInput(name.c_str(), dtype, dims); @@ -1162,8 +1161,8 @@ Status Converter::AddInputTensor(const string& name, nvinfer1::DataType dtype, status = AddTensorOrWeights(name, TRT_TensorOrWeights(tensor)); if (!status.ok()) { return errors::CreateWithUpdatedMessage( - status, StrCat("Failed to add input tensor ", name, ": ", - status.error_message())); + status, + StrCat("Failed to add input tensor ", name, ": ", status.message())); } return OkStatus(); } @@ -1173,8 +1172,8 @@ Status Converter::AddInputResource(const string& name, Status status = AddTensorOrWeights(name, TRT_TensorOrWeights(resource)); if (!status.ok()) { return errors::CreateWithUpdatedMessage( - status, StrCat("Failed 
to add input resource ", name, ": ", - status.error_message())); + status, + StrCat("Failed to add input resource ", name, ": ", status.message())); } return OkStatus(); } @@ -1376,7 +1375,7 @@ Status Converter::BuildCudaEngine( auto cache = registry->LookUp("default_cache", builder_config.get()); if (!cache.ok()) { LOG(WARNING) << "failed to create a timing cache: " - << cache.status().error_message(); + << cache.status().message(); } else { timing_cache = std::move(*cache); builder_config->setTimingCache(*timing_cache, /*ignoreMismatch*/ false); @@ -2537,7 +2536,7 @@ Status Converter::SqueezeTensor(ITensorProxyPtr input, // Reshape tensor. TF_RETURN_IF_ERROR(PrepareTensorForShape( params->converter, TRT_TensorOrWeights(input), DimsAdapter(*input_dims), - /*validation_only=*/false, output, params->node_def)); + /*validation_only=*/false, output, params->node_def, op_instance)); return OkStatus(); } @@ -5945,7 +5944,7 @@ Status ConvertGraphDefToEngine( if (!status.ok()) { const string error_message = StrCat("Validation failed for ", node_name, " and input slot ", - slot_number, ": ", status.error_message()); + slot_number, ": ", status.message()); LOG_WARNING_WITH_PREFIX << error_message; return errors::CreateWithUpdatedMessage(status, error_message); } @@ -6238,7 +6237,7 @@ std::string unexpected_type_error_msg(nvinfer1::DataType type_being_checked, DebugString(type_being_checked) + "."; } -string batch_size_error(const string& name, const string& comment) { +string batch_size_error(absl::string_view name, absl::string_view comment) { return StrCat("Batch size doesn't match for tensor '", name, "' : ", comment); } diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h index be675a1a9c6..e9afd320be9 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h @@ -574,7 +574,7 @@ std::string input_shapes_error_msg(const nvinfer1::Dims& shape1, const nvinfer1::Dims& shape2, const NodeDef& node, bool then_vs_else = false); -std::string batch_size_error(const string& name, const string& comment); +std::string batch_size_error(absl::string_view name, absl::string_view comment); inline bool find_name(const string& name, const std::vector names) { return std::find(names.begin(), names.end(), name) != names.end(); diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc index d1c75833219..91b5b3540eb 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc @@ -1457,7 +1457,7 @@ class OpConverterTest : public ::testing::Test { void RunConversion(const Node* node, absl::StatusCode expected_code = absl::StatusCode::kOk, - const std::string& expected_msg_substr = "") { + absl::string_view expected_msg_substr = "") { EXPECT_THAT(converter_->ConvertNode(node->def()), StatusIs(expected_code, HasSubstr(expected_msg_substr))); if (expected_code == absl::StatusCode::kOk) { @@ -1470,7 +1470,7 @@ class OpConverterTest : public ::testing::Test { void RunValidationAndConversion( const NodeDef& node_def, absl::StatusCode expected_code = absl::StatusCode::kOk, - const std::string& expected_msg_substr = "", + absl::string_view expected_msg_substr = "", bool should_run_conversion = true) { // Add the node to the graph. 
// TODO(laigd): we should accept a function that adds the node using @@ -1505,7 +1505,7 @@ class OpConverterTest : public ::testing::Test { const std::vector>& exp_out_dims) { RunValidationAndConversion(node_def, static_cast(status.code()), - status.error_message(), true); + status.message(), true); if (status.ok()) { // TODO(tfeher): Enable this check in explicit_batch_mode. @@ -9881,6 +9881,48 @@ TEST_P(OpConverter_Select, ConvertSelectV2) { RunTest("SelectV2"); } TEST_P(OpConverter_Select, Convert_Select) { RunTest("Select"); } +TEST_F(OpConverterTest, DuplicateSqueeze) { + // Define a custom converter which performs multiple squeezes. + auto op_converter = [](const OpConverterParams* params) -> Status { + if (params->validation_only) return OkStatus(); + auto input = params->inputs.at(0).tensor(); + ITensorProxyPtr output; + // Squeeze the first dimension. + std::vector new_dims = {0, 1, 2, 3}; + TF_EXPECT_OK(params->converter->SqueezeTensor( + /*input=*/input, /*input_dims=*/&new_dims, /*params=*/params, + /*output=*/&output, /*op_instance=*/0)); + // Squeeze the second dimension. + new_dims = {0, 2, 3}; + TF_EXPECT_OK(params->converter->SqueezeTensor( + /*input=*/output, /*input_dims=*/&new_dims, /*params=*/params, + /*output=*/&output, /*op_instance=*/1)); + params->outputs->push_back(TRT_TensorOrWeights(output)); + return OkStatus(); + }; + // Use a simple unary op for the custom converter and add an input. + NodeDef node_def = CreateUnaryOp(DataType::DT_FLOAT); + AddTestTensor("input", {1, 1, 2, 3}); + // Override the converter for Abs to use the custom converter for this test + // only, and run conversion. + GetOpConverterRegistry()->Register("Abs", kDefaultConverterPriority + 1, + op_converter); + RunValidationAndConversion(node_def); + // Set up the inputs and outputs. + DataVec input_data; + DataVec output_data; + InputOutputData abs_input{ + "input", ConstructTensor(/*data_size=*/6, /*value=*/0, + /*tf_type=*/DataType::DT_FLOAT)}; + InputOutputData abs_output{ + "my_unary", ConstructTensor(/*data_size=*/6, /*value=*/0, + /*tf_type=*/DataType::DT_FLOAT)}; + input_data.push_back(abs_input); + output_data.push_back(abs_output); + // Build and run the cuda engine. 
+ TF_EXPECT_OK(BuildAndRun(input_data, &output_data)); +} + #endif } // namespace convert diff --git a/tensorflow/compiler/tf2tensorrt/convert/ops/layer_utils.h b/tensorflow/compiler/tf2tensorrt/convert/ops/layer_utils.h index 5445df8b51c..e3aadc279d9 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/ops/layer_utils.h +++ b/tensorflow/compiler/tf2tensorrt/convert/ops/layer_utils.h @@ -657,8 +657,8 @@ class TRTNetworkBuilder { nvinfer1::INetworkDefinition* Network() { return network_; } private: - nvinfer1::INetworkDefinition* const network_; - TrtWeightStore* const weight_store_; + nvinfer1::INetworkDefinition* network_; + TrtWeightStore* weight_store_; }; class ShuffleBuilder { diff --git a/tensorflow/compiler/tf2tensorrt/convert/utils.cc b/tensorflow/compiler/tf2tensorrt/convert/utils.cc index ef61ea3fce6..f2cc8be2fd0 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/utils.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/utils.cc @@ -69,6 +69,10 @@ string DebugString(const nvinfer1::DataType trt_dtype) { #if IS_TRT_VERSION_GE(8, 5, 0, 0) case nvinfer1::DataType::kUINT8: return "kUINT8"; +#endif +#if IS_TRT_VERSION_GE(8, 6, 0, 0) + case nvinfer1::DataType::kFP8: + return "kFP8"; #endif default: return "Invalid TRT data type"; @@ -204,6 +208,11 @@ Status TrtTypeToTfType(nvinfer1::DataType trt_type, DataType* tf_type) { case nvinfer1::DataType::kUINT8: *tf_type = DT_UINT8; break; +#endif +#if IS_TRT_VERSION_GE(8, 6, 0, 0) + case nvinfer1::DataType::kFP8: + *tf_type = DT_FLOAT8_E4M3FN; + break; #endif default: return errors::InvalidArgument("Invalid TRT data type"); diff --git a/tensorflow/compiler/tf2tensorrt/convert/weights.cc b/tensorflow/compiler/tf2tensorrt/convert/weights.cc index c608291a0ae..da2157096b5 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/weights.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/weights.cc @@ -68,6 +68,9 @@ size_t TRT_ShapedWeights::size_bytes() const { break; #if IS_TRT_VERSION_GE(8, 5, 0, 0) case nvinfer1::DataType::kUINT8: +#endif +#if IS_TRT_VERSION_GE(8, 6, 0, 0) + case nvinfer1::DataType::kFP8: #endif case nvinfer1::DataType::kINT8: case nvinfer1::DataType::kBOOL: diff --git a/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc b/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc index 1e1d7eab557..abf83f27027 100644 --- a/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc +++ b/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc @@ -244,6 +244,9 @@ class TRTEngineOp : public AsyncOpKernel { // Maximum number of cached engines. 
int max_cached_engines_; + // Flag to detect whether native segment nodes have been deleted from graph + bool native_segment_absent_; + int64 workspace_size_; mutex engine_mutex_; FunctionLibraryRuntime::Handle native_execution_func_handle_; @@ -357,7 +360,7 @@ StatusOr TRTEngineOp::ConstructFunctionHandle( FunctionLibraryRuntime::InstantiateOptions inst_ops; inst_ops.state_handle = ""; inst_ops.target = device_name; - if (allow_soft_placement) { + if (!native_segment_absent_ && allow_soft_placement) { const FunctionDef* fdef = lib->GetFunctionLibraryDefinition()->Find(func_.name()); if (!fdef) { @@ -421,9 +424,6 @@ TRTEngineOp::TRTEngineOp(OpKernelConstruction* context) OP_REQUIRES_OK(context, context->GetAttr("calibration_data", &calibration_data)); OP_REQUIRES_OK(context, context->GetAttr("segment_func", &func_)); - OP_REQUIRES(context, !func_.name().empty(), - errors::InvalidArgument( - "The TF function for the TRT segment could not be empty")); OP_REQUIRES_OK(context, TrtPrecisionModeFromName(precision_string, &precision_mode_)); OP_REQUIRES_OK(context, @@ -468,11 +468,17 @@ TRTEngineOp::TRTEngineOp(OpKernelConstruction* context) use_explicit_precision_ = false; } + // When a TF-TRT converted model without native segments is loaded, + // func_ can be empty. + native_segment_absent_ = (func_.name() == ""); native_execution_func_handle_ = kInvalidHandle; - if (!static_engine_) { - OP_REQUIRES_OK(context, ImportSegmentGraphDef(context->function_library(), - context->device()->name())); + if (!native_segment_absent_) { + if (!static_engine_) { + OP_REQUIRES_OK(context, ImportSegmentGraphDef(context->function_library(), + context->device()->name())); + } } + // TODO(laigd): calibration_data is used in TF v1.x and we keep it only for // backward compatibility reasons. Remove it once all known users switch to // 2.0. @@ -721,7 +727,12 @@ void TRTEngineOp::ExecuteCalibration(OpKernelContext* ctx, VLOG(2) << "Passed calibration data"; } } - ExecuteNativeSegment(ctx, async_helper); + if (!native_segment_absent_) { + ExecuteNativeSegment(ctx, async_helper); + } else { + LOG(ERROR) << "Calibration requires native segment, but is not found in " + "the graph."; + } } Status TRTEngineOp::VerifyInputShapes( @@ -843,11 +854,11 @@ void TRTEngineOp::ComputeAsync(OpKernelContext* ctx, Status verify_input_shape_status = VerifyInputShapes(input_concrete_shapes_filtered); // TODO(bixia): Fix the segmentation. - if (!verify_input_shape_status.ok()) { + if (!verify_input_shape_status.ok() && !native_segment_absent_) { LOG_FIRST_FEW_WARNING_WITH_PREFIX << "Running native segment for" << name() << " due to failure in verifying input shapes: " - << verify_input_shape_status.error_message(); + << verify_input_shape_status.message(); ExecuteNativeSegment(ctx, async_helper); return; } @@ -868,8 +879,14 @@ void TRTEngineOp::ComputeAsync(OpKernelContext* ctx, // Just collect the input shape info and return. The shapes are used to // generate optimization profiles during engine creation. 
cache_res->profiles_.AddShape(input_concrete_shapes); - VLOG(1) << "Native segment is used during collecting shapes for profiles"; - ExecuteNativeSegment(ctx, async_helper); + VLOG(1) + << "Native segment is used during collecting shapes for profiles."; + if (!native_segment_absent_) { + ExecuteNativeSegment(ctx, async_helper); + } else { + LOG(ERROR) << "Native segment is required for profile generation, " + "but is not found in the graph."; + } return; } else if (cache_res->profiles_.GetNumProfiles() == 0 && !static_engine_) { // Add current shape if we did not collect any shapes so far. @@ -926,9 +943,14 @@ void TRTEngineOp::ComputeAsync(OpKernelContext* ctx, EngineContext* engine_context = status.value().first; int trt_context_idx = status.value().second; auto may_execute_native_segment = [&] { - if (!AllowEngineNativeSegmentExecution()) { + if (!native_segment_absent_ && !AllowEngineNativeSegmentExecution()) { ctx->CtxFailure( - errors::Aborted("User disallowed engine native segment execution")); + errors::Aborted("User disallowed engine native segment execution.")); + return false; + } else if (native_segment_absent_) { + ctx->CtxFailure( + errors::Aborted("Native segment execution is enabled but " + " native segment is not found in the graph.")); return false; } return true; @@ -954,14 +976,20 @@ void TRTEngineOp::ComputeAsync(OpKernelContext* ctx, if (!may_execute_native_segment()) { return; } - // Release any outputs that are allocated, ExecuteNativeSegment will - // re-allocate them and fail if they are currently allocated. + // When Native Segment execution is enabled, release any outputs that + // are allocated. ExecuteNativeSegment will re-allocate them and + // fail if they are currently allocated. // The Tensor pointer in the returned TensorValue must be explicitly // deleted. for (int i = 0; i < ctx->num_outputs(); i++) { delete ctx->release_output(i).tensor; } - ExecuteNativeSegment(ctx, async_helper); + if (!native_segment_absent_) { + ExecuteNativeSegment(ctx, async_helper); + } else { + LOG(ERROR) << "Native segment execution is enabled, " + "but native segment is not found in the graph."; + } } Status TRTEngineOp::ExecuteTrtEngine( diff --git a/tensorflow/compiler/tf2tensorrt/segment/segment.cc b/tensorflow/compiler/tf2tensorrt/segment/segment.cc index a9994bc2db3..3e71229888b 100644 --- a/tensorflow/compiler/tf2tensorrt/segment/segment.cc +++ b/tensorflow/compiler/tf2tensorrt/segment/segment.cc @@ -765,7 +765,7 @@ string GenerateNonConversionReport( // Log the error in case of issue, however do not stop execution. LOG(ERROR) << "Problem encountered while generating the TF-TRT " << "Non-Conversion Report in CSV Format:\n" - << status.error_message(); + << status.message(); } show_detailed_conversion_report = true; } else if (std::stoi(detailed_report_var) >= 1) { @@ -949,7 +949,7 @@ Status SegmentGraph(const Graph* tf_graph, } else { const Status status = candidate_fn(node->tf_node()); if (!status.ok()) { - exclude_node(status.error_message()); + exclude_node(status.message()); } else if (tftrt_op_denylist.contains(node->tf_node()->type_string())) { // WARNING verbosity since the user explicitly requests this behavior. 
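Aside on the error_message() -> message() migrations in the hunks above and below: on current Status types, message() returns an absl::string_view rather than a const std::string&, so the view is not guaranteed to be NUL-terminated. Streaming it into LOG works unchanged, but C-style consumers must pass an explicit data()/size() pair (as the light_outside_compilation.cc hunk later in this patch does). A minimal sketch, not part of the patch, with an illustrative helper name:

#include <cstdio>

#include "absl/status/status.h"
#include "absl/strings/string_view.h"

// Illustrative only: forward a Status message to a C-style sink without
// assuming the view is NUL-terminated.
void LogStatusMessage(const absl::Status& s) {
  absl::string_view msg = s.message();
  std::fprintf(stderr, "%.*s\n", static_cast<int>(msg.size()), msg.data());
}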
LOG_WARNING_WITH_PREFIX diff --git a/tensorflow/compiler/tf2tensorrt/trt_convert_api.cc b/tensorflow/compiler/tf2tensorrt/trt_convert_api.cc index c675e1157e0..171798b216a 100644 --- a/tensorflow/compiler/tf2tensorrt/trt_convert_api.cc +++ b/tensorflow/compiler/tf2tensorrt/trt_convert_api.cc @@ -158,8 +158,8 @@ Status RunTfTrt(const MetaGraphDef& meta_graph_def, const RewriterConfig& rewriter_config, GraphDef* segmented_graph_def) { ConfigProto config_proto; - config_proto.mutable_graph_options()->mutable_rewrite_options()->CopyFrom( - rewriter_config); + *config_proto.mutable_graph_options()->mutable_rewrite_options() = + rewriter_config; VLOG(4) << "Setting up Grappler parameters\n" << config_proto.DebugString(); std::unique_ptr cluster; @@ -202,7 +202,7 @@ Status RunSession(Session* session, const std::vector& input_names, std::vector> input_pairs; std::vector prefixed_output_names; auto prefixed_name = [](std::string prefix, std::string name) { - return prefix.size() > 0 ? absl::StrJoin({prefix, name}, "/") : name; + return !prefix.empty() ? absl::StrJoin({prefix, name}, "/") : name; }; for (int i = 0; i < input_names.size(); i++) { input_pairs.push_back( @@ -315,7 +315,7 @@ Status ReadSerializedEngine( // Saves the TRT engines as attributes of the TRTEngineOp nodes. Status ConvertToStaticEngine(const GraphDef graph_def, GraphDef* static_graph_def, Session* session) { - static_graph_def->CopyFrom(graph_def); + *static_graph_def = graph_def; VLOG(1) << "Saving TRT engines as static engine"; std::string op{"TRTEngineOp"}; for (auto& node : *(static_graph_def->mutable_node())) { @@ -397,7 +397,7 @@ StatusOr ConvertAndBuild( const TfTrtConversionParams& conv_params) { TF_RETURN_IF_ERROR(ValidateConversionParams(conv_params, inputs.size())); MetaGraphDef meta_graph; - meta_graph.mutable_graph_def()->CopyFrom(frozen_graph_def); + *meta_graph.mutable_graph_def() = frozen_graph_def; RewriterConfig rewriter_config; TF_RETURN_IF_ERROR( @@ -409,12 +409,12 @@ StatusOr ConvertAndBuild( GraphDef output; - if (inputs.size() > 0 && conv_params.convert_to_static_engine) { + if (!inputs.empty() && conv_params.convert_to_static_engine) { // The TRTOptimization pass has inserted placeholder TRTEngineOps. Here we // trigger conversion by inferring the graph. std::unique_ptr session( tensorflow::NewSession(GetSessionConfg())); - if (!session.get()) { + if (!session) { return errors::Internal("Failed to create build session"); } @@ -424,7 +424,7 @@ StatusOr ConvertAndBuild( TF_RETURN_IF_ERROR( ConvertToStaticEngine(segmented_graph_def, &output, session.get())); } else { - output.CopyFrom(segmented_graph_def); + output = segmented_graph_def; } VLOG(1) << "TF-TRT conversion finished"; return output; @@ -456,9 +456,9 @@ Status FreezeGraph(SavedModelBundle& bundle, MetaGraphDef* frozen_meta_graph) { TF_RETURN_IF_ERROR( FreezeSavedModel(bundle, &frozen_graph_def, &inputs, &outputs)); - frozen_meta_graph->CopyFrom(bundle.meta_graph_def); + *frozen_meta_graph = bundle.meta_graph_def; GraphDef* gdef = frozen_meta_graph->mutable_graph_def(); - gdef->CopyFrom(frozen_graph_def); + *gdef = frozen_graph_def; VLOG(2) << "Graph frozen"; return OkStatus(); @@ -491,7 +491,7 @@ StatusOr ConvertAndBuild( // Replace the graph_def with the inlined graph. Note that bundle->session // still has the original graph. - bundle->meta_graph_def.mutable_graph_def()->CopyFrom(inlined_graph_def); + *bundle->meta_graph_def.mutable_graph_def() = inlined_graph_def; // Freeze variables. 
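Aside on the CopyFrom -> assignment rewrites in trt_convert_api.cc above: generated protobuf message classes define a copy-assignment operator that performs the same deep copy as CopyFrom, so the change is behavior-preserving; the assignment form simply reads better with mutable_*() accessors. A minimal sketch, not part of the patch:

#include "tensorflow/core/framework/graph.pb.h"

// Illustrative only: both forms perform a deep copy of the message.
void CopyGraph(const tensorflow::GraphDef& src, tensorflow::GraphDef* dst) {
  *dst = src;             // preferred spelling
  // dst->CopyFrom(src);  // equivalent
}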
MetaGraphDef frozen_meta_graph; diff --git a/tensorflow/compiler/tf2tensorrt/trt_convert_api_test.cc b/tensorflow/compiler/tf2tensorrt/trt_convert_api_test.cc index 706a8b515e1..5d969614448 100644 --- a/tensorflow/compiler/tf2tensorrt/trt_convert_api_test.cc +++ b/tensorflow/compiler/tf2tensorrt/trt_convert_api_test.cc @@ -107,16 +107,14 @@ class TrtConverterTest {}); FunctionDef fdef; if (use_variable_) { - gdef.add_node()->CopyFrom( + *gdef.add_node() = NDef("my_var", "VarHandleOp", {}, - {{"dtype", DT_FLOAT}, {"shape", value_shape_proto}})); + {{"dtype", DT_FLOAT}, {"shape", value_shape_proto}}); - gdef.add_node()->CopyFrom(NDef("my_var/init", "AssignVariableOp", - {"my_var", "my_const"}, - {{"dtype", DT_FLOAT}})); - gdef.add_node()->CopyFrom(NDef("my_var/Read/ReadVariableOp", - "ReadVariableOp", {"my_var"}, - {{"dtype", DT_FLOAT}})); + *gdef.add_node() = NDef("my_var/init", "AssignVariableOp", + {"my_var", "my_const"}, {{"dtype", DT_FLOAT}}); + *gdef.add_node() = NDef("my_var/Read/ReadVariableOp", "ReadVariableOp", + {"my_var"}, {{"dtype", DT_FLOAT}}); // Define function f(x, v) = x * v + x, where v is a variable. fdef = FunctionDefHelper::Define( "f", // Name @@ -146,7 +144,7 @@ class TrtConverterTest {{"my_add"}, "AddV2", {"x", "my_mul"}, {{"T", DT_FLOAT}}}, {{"q"}, "Identity", {"my_add"}, {{"T", DT_FLOAT}}}}); } - gdef.mutable_library()->add_function()->CopyFrom(fdef); + *gdef.mutable_library()->add_function() = fdef; return gdef; } @@ -166,13 +164,12 @@ class TrtConverterTest SignatureDef signature_def; (*signature_def.mutable_inputs())["input"].set_name("input:0"); (*signature_def.mutable_inputs())["input"].set_dtype(DT_FLOAT); - (*signature_def.mutable_inputs())["input"].mutable_tensor_shape()->CopyFrom( - shape_proto); + *(*signature_def.mutable_inputs())["input"].mutable_tensor_shape() = + shape_proto; (*signature_def.mutable_outputs())["output"].set_name("output:0"); (*signature_def.mutable_outputs())["output"].set_dtype(DT_FLOAT); - (*signature_def.mutable_outputs())["output"] - .mutable_tensor_shape() - ->CopyFrom(shape_proto); + *(*signature_def.mutable_outputs())["output"].mutable_tensor_shape() = + shape_proto; (*out.mutable_signature_def())["serving_default"] = signature_def; VLOG(2) << signature_def.DebugString(); diff --git a/tensorflow/compiler/tf2tensorrt/utils/trt_engine_utils.cc b/tensorflow/compiler/tf2tensorrt/utils/trt_engine_utils.cc index 110a32b1f2f..798ebd8bd0c 100755 --- a/tensorflow/compiler/tf2tensorrt/utils/trt_engine_utils.cc +++ b/tensorflow/compiler/tf2tensorrt/utils/trt_engine_utils.cc @@ -108,6 +108,10 @@ Status SetupBindings(nvinfer1::ICudaEngine* cuda_engine, const Tensor& tensor, case nvinfer1::DataType::kUINT8: buffers[binding_index] = const_cast(tensor.flat().data()); break; +#endif +#if IS_TRT_VERSION_GE(8, 6, 0, 0) + case nvinfer1::DataType::kFP8: + return errors::Internal("FP8 inputs are not supported yet!"); #endif default: return errors::Internal("Unknown TRT data type: ", diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index 8827bd480b1..0bf252386bc 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -35,6 +35,7 @@ package_group( "//tensorflow/compiler/tf2xla/...", "//tensorflow/core/tpu/...", "//tensorflow/python/compiler/...", + "//tensorflow/python/util/...", ], ) @@ -42,6 +43,7 @@ package_group( name = "friends", includes = [":internal"], packages = [ + "//platforms/performance/automl/...", "//tensorflow/...", "//tensorflow_federated/cc/core/impl/executors/...", 
"//tensorflow_models/...", @@ -210,6 +212,7 @@ filegroup( "//tensorflow/compiler/xla:cpu_runtime_hdrs", "//tensorflow/compiler/xla/service:custom_call_status_hdrs", "//tensorflow/compiler/xla/service/cpu:runtime_hdrs", + "//tensorflow/compiler/xla/service/cpu:xla_runtime_runner_hdrs", "//tensorflow/core/kernels:xla_cpu_runtime_hdrs", "//tensorflow/core/platform:xla_cpu_runtime_srcs", "//tensorflow/tsl/framework:xla_cpu_runtime_hdrs", @@ -226,6 +229,7 @@ filegroup( "//tensorflow/compiler/xla:cpu_runtime_srcs", "//tensorflow/compiler/xla/service:custom_call_status_srcs", "//tensorflow/compiler/xla/service/cpu:runtime_srcs", + "//tensorflow/compiler/xla/service/cpu:xla_runtime_runner_srcs", "//tensorflow/core/kernels:xla_cpu_runtime_srcs", "//tensorflow/core/platform:xla_cpu_runtime_srcs", "//tensorflow/tsl/platform:xla_cpu_runtime_srcs", @@ -282,7 +286,9 @@ cc_library( "@com_google_absl//absl/base", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/base:dynamic_annotations", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", "//third_party/eigen3", @@ -371,6 +377,7 @@ cc_library( # binary produced by tfcompile. "//tensorflow/compiler/xla:cpu_function_runtime", "//tensorflow/compiler/xla:executable_run_options", + "//tensorflow/compiler/xla/service/cpu:buffer_desc", "//tensorflow/core/platform:types", ], ) @@ -403,10 +410,10 @@ cc_library( "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/service:platform_util", - "//tensorflow/core:lib", - "//tensorflow/core/platform:errors", - "//tensorflow/core:protos_all_cc", "//tensorflow/compiler/xla/stream_executor:platform", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/platform:errors", ] + if_libtpu( if_false = [ "//tensorflow/compiler/xla/service:cpu_plugin", @@ -417,14 +424,32 @@ cc_library( ), ) +tf_cc_test( + name = "graph_compiler_test", + srcs = ["graph_compiler_test.cc"], + deps = [ + ":graph_compiler_util", + ":tf2xla_proto_cc", + ":xla_compilation_device", + ":xla_compiler", + "//tensorflow/core:core_cpu", + "//tensorflow/core:core_cpu_base", + "//tensorflow/core:framework", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/lib/monitoring:cell_reader", + "//tensorflow/core/platform:refcount", + "@com_google_googletest//:gtest_main", + ], +) + cc_library( name = "xla_compiler", srcs = [ "const_analysis.cc", "graph_compiler.cc", "xla_compiler.cc", - "xla_op_kernel.cc", "xla_cpu_backend.cc", + "xla_op_kernel.cc", ] + if_cuda_is_configured([ "xla_gpu_backend.cc", ]) + if_rocm_is_configured([ @@ -442,8 +467,8 @@ cc_library( visibility = [":friends"], deps = [ ":common", - ":layout_util", ":host_compute_metadata_proto_cc", + ":layout_util", ":rearrange_function_argument", ":sharding_util", ":side_effect_util", @@ -455,22 +480,13 @@ cc_library( ":xla_helpers", ":xla_op_registry", ":xla_resource", - "//tensorflow/compiler/mlir/tf2xla:mlir_bridge_rollout_policy", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/memory", - "@com_google_absl//absl/types:optional", - "@com_google_absl//absl/types:span", - "@com_google_absl//absl/types:variant", "//tensorflow/compiler/jit:common", "//tensorflow/compiler/jit:flags", "//tensorflow/compiler/jit:shape_inference", 
"//tensorflow/compiler/jit:xla_compile_util", + "//tensorflow/compiler/mlir/tf2xla:mlir_bridge_rollout_policy", + "//tensorflow/compiler/mlir/tf2xla/api/v0:compile_mlir_util_no_tf_dialect_passes", "//tensorflow/compiler/mlir/utils:array_container_utils", - "//tensorflow/compiler/xla/translate/mhlo_to_hlo:layout_util", - "//tensorflow/compiler/mlir/tf2xla:compile_mlir_util_no_tf_dialect_passes", - "//tensorflow/compiler/xla/client:value_inference", - "//tensorflow/compiler/xla/service:computation_placer_hdr", "//tensorflow/compiler/xla:executable_run_options", "//tensorflow/compiler/xla:protobuf_util", "//tensorflow/compiler/xla:shape_util", @@ -479,11 +495,12 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto_cc", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:value_inference", "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/hlo/ir:hlo", - "//tensorflow/core/util:overflow", - "//tensorflow/core/tpu:tpu_defs", + "//tensorflow/compiler/xla/service:computation_placer_hdr", + "//tensorflow/compiler/xla/translate/mhlo_to_hlo:layout_util", "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", @@ -491,6 +508,14 @@ cc_library( "//tensorflow/core:lib_internal", "//tensorflow/core:ops", "//tensorflow/core:protos_all_cc", + "//tensorflow/core/tpu:tpu_defs", + "//tensorflow/core/util:overflow", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:span", + "@com_google_absl//absl/types:variant", ] + if_libtpu([ ":xla_tpu_backend_registration", ]), @@ -1336,8 +1361,9 @@ cc_library( deps = [ ":xla_compiler", "//tensorflow/compiler/jit:xla_compile_util", - "//tensorflow/compiler/mlir/tf2xla:compile_mlir_util_no_tf_dialect_passes", + "//tensorflow/compiler/mlir/tf2xla/api/v0:compile_mlir_util_no_tf_dialect_passes", "//tensorflow/compiler/mlir/utils:array_container_utils", + "//tensorflow/core:framework", "@llvm-project//mlir:IR", ], ) @@ -1380,3 +1406,41 @@ tf_cuda_cc_test( "@com_google_absl//absl/memory", ], ) + +filegroup( + name = "tf2xla_opset_hdrs", + srcs = [ + "tf2xla_opset.h", + ], + visibility = ["//tensorflow/python/util:__pkg__"], +) + +cc_library( + name = "tf2xla_opset", + srcs = [ + "tf2xla_opset.cc", + ], + hdrs = ["tf2xla_opset.h"], + visibility = ["//tensorflow/python:__pkg__"], + deps = [ + ":tf2xla_util", + ":xla_op_registry", + "//tensorflow/compiler/jit:xla_device", + "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + ], +) + +tf_cc_test( + name = "tf2xla_opset_test", + srcs = [ + "tf2xla_opset_test.cc", + ], + deps = [ + ":tf2xla_opset", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) diff --git a/tensorflow/compiler/tf2xla/functionalize_cond.cc b/tensorflow/compiler/tf2xla/functionalize_cond.cc index fef312ab635..ee91e574a02 100644 --- a/tensorflow/compiler/tf2xla/functionalize_cond.cc +++ b/tensorflow/compiler/tf2xla/functionalize_cond.cc @@ -1528,10 +1528,10 @@ Status FunctionalizeCond::FunctionalizeInternal() { // nesting. (CondId, AncestorId) is not enough, e.g. 
// pred1 = array_ops.placeholder(dtypes.bool, name='pred1') // pred2 = array_ops.placeholder(dtypes.bool, name='pred2') - // cond1 = control_flow_ops.cond(pred1, ...) - // cond2 = control_flow_ops.cond(pred2, ...) - // cond3 = control_flow_ops.cond(pred1, use cond1 and cond2) - // cond4 = control_flow_ops.cond(pred2, use cond1 and cond2) + // cond1 = cond.cond(pred1, ...) + // cond2 = cond.cond(pred2, ...) + // cond3 = cond.cond(pred1, use cond1 and cond2) + // cond4 = cond.cond(pred2, use cond1 and cond2) // cond3 and cond4 have the same (CondId, AncestorId), but they should not // be merged into one "If" node (because they have different predicates). std::deque> merge_clusters; diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc index 91a9fc63716..b1fd82aeed4 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc @@ -65,7 +65,7 @@ Status FindIfThenAndElse(const GraphDef& graph, string* op_name, // Graph: // x = array_ops.placeholder(dtypes.int32) // y = array_ops.placeholder(dtypes.int32) -// z = control_flow_ops.cond( +// z = cond.cond( // math_ops.less(y, x), lambda: math_ops.multiply(y, 17), // lambda: math_ops.add(x, 23)) // diff --git a/tensorflow/compiler/tf2xla/graph_compiler.cc b/tensorflow/compiler/tf2xla/graph_compiler.cc index 1b07033b5c8..f72a47ace77 100644 --- a/tensorflow/compiler/tf2xla/graph_compiler.cc +++ b/tensorflow/compiler/tf2xla/graph_compiler.cc @@ -47,12 +47,20 @@ limitations under the License. #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/lib/hash/hash.h" +#include "tensorflow/core/lib/monitoring/counter.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/public/version.h" #include "tensorflow/core/util/dump_graph.h" namespace tensorflow { +auto* graph_compiler_failed_compilation_op_count = + tensorflow::monitoring::Counter<1>::New( + /*metric_name=*/ + "/tensorflow/core/tf2xla/graph_compilation_failed_op_count", + /*metric_description=*/"Records an op that failed to compile", + /*metric_label=*/"op_name"); + namespace { Status PrepareArguments(XlaOpKernelContext* ctx, Graph* graph, const std::vector& expressions, @@ -177,6 +185,9 @@ Status GraphCompiler::Compile() { device_->Compute(CHECK_NOTNULL(params.op_kernel), &op_context); Status s = op_context.status(); if (!s.ok()) { + graph_compiler_failed_compilation_op_count + ->GetCell(params.op_kernel->def().op()) + ->IncrementBy(1); return AttachDef(s, n->def()); } } diff --git a/tensorflow/compiler/tf2xla/graph_compiler_test.cc b/tensorflow/compiler/tf2xla/graph_compiler_test.cc new file mode 100644 index 00000000000..6ec8b8f8793 --- /dev/null +++ b/tensorflow/compiler/tf2xla/graph_compiler_test.cc @@ -0,0 +1,150 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/graph_compiler.h" + +#include + +#include +#include +#include "tensorflow/compiler/tf2xla/graph_compiler_util.h" +#include "tensorflow/compiler/tf2xla/tf2xla.pb.h" +#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" +#include "tensorflow/compiler/tf2xla/xla_compiler.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/core/common_runtime/device_mgr.h" +#include "tensorflow/core/common_runtime/process_function_library_runtime.h" +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/attr_value_util.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/graph/node_builder.h" +#include "tensorflow/core/lib/monitoring/cell_reader.h" +#include "tensorflow/core/platform/refcount.h" +#include "tensorflow/core/public/session_options.h" +#include "tensorflow/core/public/version.h" + +namespace tensorflow { +namespace { + +using ::tensorflow::monitoring::testing::CellReader; + +constexpr char kOpCompilationFailureStreamz[] = + "/tensorflow/core/tf2xla/graph_compilation_failed_op_count"; + +class DummyOp : public XlaOpKernel { + public: + explicit DummyOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + void Compile(XlaOpKernelContext* ctx) override {} +}; + +REGISTER_KERNEL_BUILDER(Name("NoOp").Device(DEVICE_DEFAULT), DummyOp); +REGISTER_KERNEL_BUILDER(Name("NoOp").Device("XLA_TPU_JIT"), DummyOp); +REGISTER_KERNEL_BUILDER(Name("NoOp").Device("XLA_CPU_JIT"), DummyOp); + +class MockAlwaysFailsOp : public XlaOpKernel { + public: + explicit MockAlwaysFailsOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + void Compile(XlaOpKernelContext* ctx) override { + ctx->CtxFailure(__FILE__, __LINE__, errors::InvalidArgument("MockBroken")); + } +}; + +REGISTER_OP("MockAlwaysFails") + .SetShapeFn(shape_inference::UnknownShape) + .Doc(R"doc( +A test only Op that always fails to compile. +)doc"); + +REGISTER_KERNEL_BUILDER(Name("MockAlwaysFails").Device(DEVICE_DEFAULT), + MockAlwaysFailsOp); +REGISTER_KERNEL_BUILDER(Name("MockAlwaysFails").Device("XLA_CPU_JIT"), + MockAlwaysFailsOp); +REGISTER_KERNEL_BUILDER(Name("MockAlwaysFails").Device("XLA_TPU_JIT"), + MockAlwaysFailsOp); +REGISTER_XLA_OP(Name("MockAlwaysFails").CompilationOnly(), MockAlwaysFailsOp); + +class GraphCompilerTest : public ::testing::Test { + public: + void SetUp() override { + device_ = new tensorflow::XlaCompilationDevice( + tensorflow::SessionOptions(), tensorflow::DeviceType("XLA_TPU_JIT")); + device_mgr_ = std::make_unique(absl::WrapUnique(device_)); + } + + Status RunGraphCompiler(Graph& graph) { + ProcessFunctionLibraryRuntime runtime( + device_mgr_.get(), Env::Default(), nullptr, TF_GRAPH_DEF_VERSION, + &graph.flib_def(), OptimizerOptions()); + + xla::XlaBuilder builder("test_builder"); + XlaCompiler::Options options; + options.device_type = "XLA_TPU_JIT"; + + XlaCompiler xla_compiler(options); + + // Resource cleanup is messy, see the LINT.ThenChange for comments. 
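Aside on the failed-compilation metric added in graph_compiler.cc above: it follows the usual tfstreamz pattern of a process-wide Counter<1> keyed by op name, incremented on failure, and observed in tests through monitoring::testing::CellReader, as this new test does below. A condensed sketch of that pattern, not part of the patch, with an illustrative metric name:

#include <string>

#include "tensorflow/core/lib/monitoring/cell_reader.h"
#include "tensorflow/core/lib/monitoring/counter.h"

// Process-wide counter with one string label ("op_name").
auto* failed_op_count = tensorflow::monitoring::Counter<1>::New(
    "/test/failed_op_count", "Ops that failed to compile.", "op_name");

void RecordFailure(const std::string& op_name) {
  failed_op_count->GetCell(op_name)->IncrementBy(1);
}

// In a test, the delta since the reader was constructed can then be asserted:
//   tensorflow::monitoring::testing::CellReader<int64_t> reader(
//       "/test/failed_op_count");
//   RecordFailure("MockAlwaysFails");
//   EXPECT_EQ(reader.Delta("MockAlwaysFails"), 1);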
+ // LINT.IfChange + XlaContext* xla_context = new XlaContext(&xla_compiler, &builder, &graph); + core::ScopedUnref context_unref(xla_context); + xla_context->Ref(); + + auto step_container = + std::make_unique(0, [this](const string& name) { + Status status = this->device_->resource_manager()->Cleanup(name); + }); + auto container_status = step_container->Create( + device_->resource_manager(), XlaContext::kXlaContextResourceName, + xla_context); + + GraphCompiler graph_compiler( + device_, &graph, runtime.GetFLR(device_->name()), step_container.get()); + + return graph_compiler.Compile(); + // LINT.ThenChange(//tensorflow/compiler/tf2xla/xla_compiler.cc:ExecuteGraph) + } + + protected: + XlaCompilationDevice* device_; // Owned by device_mgr_ + std::unique_ptr device_mgr_; +}; + +TEST_F(GraphCompilerTest, CompilesGraph) { + Graph graph(OpRegistry::Global()); + + EXPECT_TRUE(RunGraphCompiler(graph).ok()); +} + +TEST_F(GraphCompilerTest, RecordsStreamzFailedCompilationNode) { + Graph graph(OpRegistry::Global()); + Node* mock_fail; + ASSERT_TRUE(NodeBuilder("mock_fail", "MockAlwaysFails") + .Finalize(&graph, &mock_fail) + .ok()); + graph.AddControlEdge(graph.source_node(), mock_fail); + graph.AddControlEdge(mock_fail, graph.sink_node()); + + CellReader op_reader(kOpCompilationFailureStreamz); + + EXPECT_FALSE(RunGraphCompiler(graph).ok()); + + EXPECT_EQ(op_reader.Delta("MockAlwaysFails"), 1); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD index cb4ed43287b..ac616e542a5 100644 --- a/tensorflow/compiler/tf2xla/kernels/BUILD +++ b/tensorflow/compiler/tf2xla/kernels/BUILD @@ -172,28 +172,15 @@ tf_kernel_library( ":case_op", ":conv_op_helpers", ":if_op", + ":rng_converter_utils", ":tensor_list_utils", ":while_op", ":xla_call_module_op", - ":rng_converter_utils", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/types:optional", - "@com_google_absl//absl/types:span", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:Parser", - "@llvm-project//mlir:Pass", - "@llvm-project//mlir:Support", - "@stablehlo//:chlo_ops", "//tensorflow/compiler/jit:xla_activity_listener", "//tensorflow/compiler/jit:xla_activity_proto_cc", - "//tensorflow/compiler/xla/mlir_hlo", "//tensorflow/compiler/mlir/tensorflow:error_util", "//tensorflow/compiler/mlir/tensorflow:tensorflow_passes", "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", - "//tensorflow/compiler/xla/translate/hlo_to_mhlo:hlo_utils", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:mlir_xla_op_kernel", "//tensorflow/compiler/tf2xla:xla_compilation_device", @@ -238,6 +225,8 @@ tf_kernel_library( "//tensorflow/compiler/xla/client/lib:sorting", "//tensorflow/compiler/xla/client/lib:svd", "//tensorflow/compiler/xla/client/lib:tridiagonal", + "//tensorflow/compiler/xla/mlir_hlo", + "//tensorflow/compiler/xla/translate/hlo_to_mhlo:hlo_utils", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", @@ -247,6 +236,17 @@ tf_kernel_library( "//tensorflow/core/kernels:stateless_random_ops_v2_header", "//tensorflow/core/tpu:tpu_defs", "//tensorflow/core/util:overflow", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + 
"@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:span", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Parser", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@stablehlo//:chlo_ops", ] + if_cuda_or_rocm( if_false = [], if_true = [":light_outside_compilation"], @@ -377,13 +377,38 @@ cc_library( ], ) +cc_library( + name = "xla_call_module_loader", + srcs = ["xla_call_module_loader.cc"], + hdrs = ["xla_call_module_loader.h"], + deps = [ + "//tensorflow/compiler/xla/client:xla_computation", + "//tensorflow/compiler/xla/mlir_hlo", + "//tensorflow/compiler/xla/pjrt:mlir_to_hlo", + "//tensorflow/compiler/xla/translate/hlo_to_mhlo:hlo_utils", + "//tensorflow/tsl/platform:errors", + "//tensorflow/tsl/platform:regexp", + "//tensorflow/tsl/platform:statusor", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Parser", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:Transforms", + "@stablehlo//:chlo_ops", + "@stablehlo//:stablehlo_ops", + "@stablehlo//:stablehlo_passes", + "@stablehlo//:stablehlo_serialization", + "@stablehlo//:vhlo_ops", + ], +) + tf_kernel_library( name = "xla_call_module_op", srcs = ["xla_call_module_op.cc"], deps = [ - "//tensorflow/compiler/mlir/tensorflow", - "//tensorflow/compiler/mlir/tensorflow:error_util", - "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", + ":xla_call_module_loader", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:side_effect_util", "//tensorflow/compiler/tf2xla:tf2xla_util", @@ -397,27 +422,14 @@ tf_kernel_library( "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", - "//tensorflow/compiler/xla/mlir_hlo", - "//tensorflow/compiler/xla/pjrt:mlir_to_hlo", "//tensorflow/compiler/xla/service:hlo_proto_cc", "//tensorflow/compiler/xla/translate/hlo_to_mhlo:hlo_utils", "//tensorflow/core:framework", "//tensorflow/core:protos_all_cc", "//tensorflow/core/platform:status", "//tensorflow/core/tpu:tpu_defs", - "//tensorflow/tsl/platform:regexp", "@com_google_absl//absl/strings", - "@llvm-project//mlir:FuncDialect", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:Parser", - "@llvm-project//mlir:Pass", - "@llvm-project//mlir:Support", - "@llvm-project//mlir:Transforms", - "@stablehlo//:chlo_ops", - "@stablehlo//:stablehlo_ops", - "@stablehlo//:stablehlo_passes", - "@stablehlo//:stablehlo_serialization", - "@stablehlo//:vhlo_ops", + "@llvm-project//llvm:Support", ], ) diff --git a/tensorflow/compiler/tf2xla/kernels/data_format_ops.cc b/tensorflow/compiler/tf2xla/kernels/data_format_ops.cc index 1707fd10a74..e2b3e3ffcf5 100644 --- a/tensorflow/compiler/tf2xla/kernels/data_format_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/data_format_ops.cc @@ -35,13 +35,13 @@ class DataFormatDimMapOp : public XlaOpKernel { OP_REQUIRES_OK(context, context->GetAttr("src_format", &src_format)); string dst_format; OP_REQUIRES_OK(context, context->GetAttr("dst_format", &dst_format)); - OP_REQUIRES(context, src_format.size() == 4 or src_format.size() == 5, + OP_REQUIRES(context, src_format.size() == 4 || src_format.size() == 5, errors::InvalidArgument( absl::StrCat("Source format must of length 4 or 5, " "received src_format = ", src_format))); OP_REQUIRES( - context, dst_format.size() == 4 or dst_format.size() == 5, + context, dst_format.size() == 4 || dst_format.size() == 5, errors::InvalidArgument(absl::StrCat( 
"Destination format must of length 4 or 5, received dst_format = ", dst_format))); diff --git a/tensorflow/compiler/tf2xla/kernels/light_outside_compilation.cc b/tensorflow/compiler/tf2xla/kernels/light_outside_compilation.cc index f169d86e8b1..aaf6a8f89eb 100644 --- a/tensorflow/compiler/tf2xla/kernels/light_outside_compilation.cc +++ b/tensorflow/compiler/tf2xla/kernels/light_outside_compilation.cc @@ -521,8 +521,8 @@ void GenericTfCallback(void* stream_handle, void** buffers, const char* opaque, int opaque_len, XlaCustomCallStatus* status) { Status s = CallTfKernel(stream_handle, buffers, opaque, opaque_len); if (!s.ok()) { - XlaCustomCallStatusSetFailure(status, s.error_message().c_str(), - s.error_message().size()); + auto msg = s.message(); + XlaCustomCallStatusSetFailure(status, msg.data(), msg.size()); } } diff --git a/tensorflow/compiler/tf2xla/kernels/topk_op.cc b/tensorflow/compiler/tf2xla/kernels/topk_op.cc index 2b3ae968309..ad0366eba03 100644 --- a/tensorflow/compiler/tf2xla/kernels/topk_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/topk_op.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/sorting.h" @@ -28,12 +29,17 @@ class TopKOp : public XlaOpKernel { public: explicit TopKOp(OpKernelConstruction* context) : XlaOpKernel(context) { OP_REQUIRES_OK(context, context->GetAttr("sorted", &sorted_)); + DataType index_type; + OP_REQUIRES_OK(context, context->GetAttr("index_type", &index_type)); + OP_REQUIRES_OK(context, DataTypeToPrimitiveType(index_type, &index_type_)); } void Compile(XlaOpKernelContext* context) override { - const TensorShape input_shape = context->InputShape(0); - int last_dim = input_shape.dims() - 1; - int last_dim_size = input_shape.dim_size(last_dim); + const StatusOr input_shape_or = context->InputXlaShape(0); + OP_REQUIRES_OK(context, input_shape_or.status()); + const xla::Shape& input_shape = *input_shape_or; + int last_dim = input_shape.dimensions_size() - 1; + int last_dim_size = input_shape.dimensions(last_dim); int64_t k; bool k_bound_inferrable = @@ -49,7 +55,7 @@ class TopKOp : public XlaOpKernel { OP_REQUIRES(context, k >= 0, errors::InvalidArgument("Need k >= 0, got ", k)); - OP_REQUIRES(context, input_shape.dims() >= 1, + OP_REQUIRES(context, input_shape.dimensions_size() >= 1, errors::InvalidArgument("input must be >= 1-D, got shape ", input_shape.DebugString())); @@ -64,7 +70,7 @@ class TopKOp : public XlaOpKernel { bool k_is_dynamic; OP_REQUIRES_OK(context, context->ResolveInputDynamismIntoPred(1, &k_is_dynamic)); - xla::XlaOp output_tuple = TopK(context->Input(0), k); + xla::XlaOp output_tuple = TopK(context->Input(0), k, index_type_); auto values = xla::GetTupleElement(output_tuple, 0); auto indices = xla::GetTupleElement(output_tuple, 1); if (k_is_dynamic) { @@ -78,11 +84,18 @@ class TopKOp : public XlaOpKernel { private: bool sorted_; + xla::PrimitiveType index_type_; }; -REGISTER_XLA_OP(Name("TopKV2").CompileTimeConstantInput("k").TypeConstraint( - "T", {DT_UINT32, DT_INT32, DT_UINT64, DT_INT64, DT_FLOAT, - DT_HALF, DT_DOUBLE, DT_BFLOAT16, DT_UINT8, DT_INT8}), +REGISTER_XLA_OP(Name("TopKV2") + .CompileTimeConstantInput("k") + .TypeConstraint("T", + {DT_UINT32, DT_INT32, DT_UINT64, 
DT_INT64, + DT_FLOAT, DT_HALF, DT_DOUBLE, DT_BFLOAT16, + DT_UINT8, DT_INT8, DT_INT16}) + .TypeConstraint("Tk", {DT_INT16, DT_INT32, DT_INT64}) + .TypeConstraint("index_type", + {DT_INT16, DT_INT32, DT_INT64}), TopKOp); } // namespace diff --git a/tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.cc b/tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.cc new file mode 100644 index 00000000000..c8a82fbfa28 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.cc @@ -0,0 +1,473 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.h" + +#include +#include +#include +#include + +#include "llvm/ADT/SmallVector.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/Block.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinDialect.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/SymbolTable.h" // from @llvm-project +#include "mlir/IR/TypeUtilities.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/IR/Verifier.h" // from @llvm-project +#include "mlir/Parser/Parser.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Support/DebugStringHelper.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Transforms/Passes.h" // from @llvm-project +#include "stablehlo/dialect/ChloOps.h" // from @stablehlo +#include "stablehlo/dialect/Serialization.h" // from @stablehlo +#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo +#include "stablehlo/dialect/VhloOps.h" // from @stablehlo +#include "stablehlo/transforms/Passes.h" // from @stablehlo +#include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/mlir_hlo/mhlo/IR/hlo_ops.h" +#include "tensorflow/compiler/xla/pjrt/mlir_to_hlo.h" +#include "tensorflow/compiler/xla/translate/hlo_to_mhlo/hlo_utils.h" +#include "tensorflow/tsl/platform/errors.h" +#include "tensorflow/tsl/platform/regexp.h" +#include "tensorflow/tsl/platform/statusor.h" + +namespace tensorflow { + +namespace { + +// When adding a new version, write when it was added. Also change the default +// version in the constructor in xla.py. +// Version 1 used MHLO & CHLO, not supported anymore. +// Version 2 supports StableHLO & CHLO. From 10/2022. +const int VERSION_START_STABLE_HLO = 2; +// Version 3 supports platform checking and multiple platforms. From 02/2023. 
+const int VERSION_START_PLATFORMS = 3; +// Version 4 supports StableHLO with compatibility guarantees. +// Used from 03/2023. +const int VERSION_START_STABLE_HLO_COMPATIBILITY = 4; +// Version 5 add support to stablehlo.custom_call for host call tf graph. +// Used from 04/2023. +const int VERSION_SUPPORT_CUSTOM_CALL = 5; +const int VERSION_MINIMUM_SUPPORTED = VERSION_START_STABLE_HLO; +const int VERSION_MAXIMUM_SUPPORTED = VERSION_SUPPORT_CUSTOM_CALL; + +// Computes a dimension value from the dim_arg specification. +// The specification is of the form ".". +tsl::StatusOr ComputeDimensionValue( + int version, std::string dim_arg_spec, std::vector arguments, + mlir::OpBuilder op_builder, mlir::Type dim_arg_type) { + static const LazyRE2 dim_arg_spec_re = {R"((\d+).(\d+))"}; + int arg_idx, arg_axis_idx; + if (!RE2::FullMatch(dim_arg_spec, *dim_arg_spec_re, &arg_idx, + &arg_axis_idx)) { + return tsl::errors::InvalidArgument("Syntax error in dim_args_spec '", + dim_arg_spec, "'"); + } + if (arg_idx < 0 || arg_idx >= arguments.size()) { + return tsl::errors::InvalidArgument( + "Invalid argument index ", arg_idx, + " when the number of non-dimension arguments is ", arguments.size(), + " in dim_arg_spec '", dim_arg_spec, "'"); + } + mlir::RankedTensorType arg_type = + arguments[arg_idx].getType().dyn_cast(); + if (!arg_type) { + return tsl::errors::InvalidArgument( + "Argument ", arg_idx, " referenced in dim_arg_spec '", dim_arg_spec, + "' does not have a RankedTensorType"); + } + if (arg_axis_idx < 0 || arg_axis_idx >= arg_type.getShape().size()) { + return tsl::errors::InvalidArgument( + "Invalid axis index ", arg_axis_idx, + " when the rank of non-dimension argument ", arg_idx, " is ", + arg_type.getShape().size(), " in dim_arg_spec '", dim_arg_spec, "'"); + } + mlir::Value val; + mlir::Type get_dim_type = + mlir::RankedTensorType::get({}, op_builder.getI32Type()); + val = op_builder.create( + arguments[arg_idx].getLoc(), get_dim_type, arguments[arg_idx], + op_builder.getI64IntegerAttr(arg_axis_idx)); + if (dim_arg_type != get_dim_type) { + val = op_builder.create( + arguments[arg_idx].getLoc(), dim_arg_type, val); + } + return val; +} + +} // namespace + +tsl::StatusOr> XlaCallModuleLoader::Create( + mlir::MLIRContext *context, int version, std::string module_str, + std::vector dim_args_spec, int platform_index) { + if (version < VERSION_MINIMUM_SUPPORTED) { + return tsl::errors::InvalidArgument( + "XlaCallModuleOp with version ", version, + " is not supported anymore. Must be >= ", VERSION_MINIMUM_SUPPORTED); + } + if (version > VERSION_MAXIMUM_SUPPORTED) { + return tsl::errors::InvalidArgument( + "XlaCallModuleOp with version ", version, + " is not supported by this build. Must be <= ", + VERSION_MAXIMUM_SUPPORTED); + } + + if (version < VERSION_START_PLATFORMS) { + platform_index = -1; + } + + std::unique_ptr loader(new XlaCallModuleLoader); + TF_RETURN_IF_ERROR(loader->LoadAndPreprocessModule( + context, version, std::move(module_str), std::move(dim_args_spec), + platform_index)); + return loader; +} + +// Adds a wrapper for the "main" function to compute the platform index and the +// dimension arguments. +// +// The input module has the following structure: +// +// func public main(%arg_platform_index: i32, %arg_dim0: i32, %arg_dim1: i32, +// %arg0: f32[?, ?, 8]) { ... 
} +// +// where %arg_platform_index is the index of the current compilation platform +// among the declared `platforms` (missing if version < 3 or if platforms has +// fewer than 2 elements), %arg_dim0 and %arg_dim1 are dimension arguments +// (missing if dim_args_spec is empty). The value of the dimension arguments +// are computed based on the static shapes of the actual arguments +// (%arg0 and following). +// In the above example, the dim_args_spec array would have two elements, one +// for %arg_dim0 and one for %arg_dim1. E.g., ['0.0', '0.1'] specifies that +// %arg_dim0 should be set to the size of axis 0 or array argument 0 (%arg0), +// while %arg_dim1 should be set to the size of axis 1. +// The platform index argument must be a 0-dimensional 32-bit integer, and the +// dimension arguments must be 0-dimensional tensors of integer type. +// +// We create a new "main" function as follows: +// func public main(%arg0: f32[?, ?, 8]) { +// %arg_platform_index = stablehlo.constant +// %arg_dim0 = stablehlo.get_dimension_size(%arg0) dimension=0 +// %arg_dim1 = stablehlo.get_dimension_size(%arg0) dimension=1 +// %res = func.call _wrapped_main(%arg_platform_index, +// %arg_dim0, %arg_dim1, %arg0) +// return %res +// } +// func private _wrapped_main(%arg_platform_index: i32, +// %arg_dim0: i32, %arg_dim1: i32, +// %arg0: f32[?, ?, 8]) { +// ... the original main function ... +// } +// +// and then we run the inliner. This is important because in the +// RefineDynamicShapes method called in Compile we refine the shape of the +// array arguments. This would create a type error at the call to _wrapped_main +// with the expected type of %arg0. +tsl::Status XlaCallModuleLoader::AddMainWrapper() { + int nr_dim_args = dim_args_spec_.size(); + // Locate the 'main' function. + // This is the convention used by MlirToXlaComputation. 
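Aside: the lookup used throughout this loader is the templated mlir::ModuleOp::lookupSymbol, which returns a null op when the symbol is missing or is not of the requested kind. A minimal standalone sketch of the pattern, not part of the patch:

#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"            // from @llvm-project

// Returns the public entry point, or a null FuncOp if "main" is absent.
mlir::func::FuncOp FindMain(mlir::ModuleOp module) {
  return module.lookupSymbol<mlir::func::FuncOp>("main");
}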
+ mlir::func::FuncOp orig_main = + module_->lookupSymbol("main"); + if (!orig_main) { + return tsl::errors::InvalidArgument("Cannot find 'main' in module"); + } + int nr_platform_args = 0; + if (platform_index_ >= 0) { + nr_platform_args = 1; + } + if (orig_main.getNumArguments() <= nr_platform_args + nr_dim_args) { + return tsl::errors::InvalidArgument( + "The module should have ", nr_platform_args, + " platform index arguments and ", nr_dim_args, + " dimension arguments, but it ", "has only ", + orig_main.getNumArguments(), " total arguments"); + } + mlir::Block &orig_main_body = orig_main.front(); + + mlir::SymbolTable::setSymbolVisibility( + orig_main, mlir::SymbolTable::Visibility::Private); + mlir::OpBuilder op_builder(module_->getBodyRegion()); + orig_main.setName(op_builder.getStringAttr("_wrapped_main")); + mlir::Location loc = module_->getLoc(); + std::vector new_main_arg_types( + orig_main.getArgumentTypes().begin() + nr_platform_args + nr_dim_args, + orig_main.getArgumentTypes().end()); + mlir::func::FuncOp new_main = op_builder.create( + loc, "main", + mlir::FunctionType::get(module_->getContext(), + /*inputs=*/new_main_arg_types, + /*results=*/orig_main.getResultTypes())); + mlir::SymbolTable::setSymbolVisibility(new_main, + mlir::SymbolTable::Visibility::Public); + mlir::Block *new_main_block = new_main.addEntryBlock(); + std::vector block_args(new_main_block->getArguments().begin(), + new_main_block->getArguments().end()); + op_builder.setInsertionPointToStart(new_main_block); + + std::vector call_args(orig_main_body.getNumArguments()); + for (int i = 0; i < orig_main_body.getNumArguments(); ++i) { + if (i < nr_platform_args + nr_dim_args) { + mlir::Type arg_type = orig_main.getArgument(i).getType(); + mlir::RankedTensorType arg_ranked_type = + arg_type.dyn_cast(); + if (!arg_ranked_type || + !arg_ranked_type.getElementType().dyn_cast() || + !arg_ranked_type.getShape().empty()) { + std::string argument_type = + (i < nr_platform_args) ? "platform index" : "dimension"; + return tsl::errors::InvalidArgument( + "Module argument at index ", i, + " should be a 0-dimensional integer-tensor ", argument_type, + " argument but has type ", mlir::debugString(arg_type)); + } + if (i < nr_platform_args) { + if (arg_ranked_type.getElementTypeBitWidth() != 32) { + return tsl::errors::InvalidArgument( + "Module argument at index ", i, + " should be a 0-dimensional 32-bit integer-tensor" + " platform index argument but has type ", + mlir::debugString(arg_type)); + } + call_args[i] = op_builder.create( + block_args[0].getLoc(), + op_builder.getI32IntegerAttr(platform_index_)); + } else { + TF_ASSIGN_OR_RETURN( + call_args[i], + ComputeDimensionValue( + version_, dim_args_spec_[i - nr_platform_args], block_args, + op_builder, orig_main.getArgument(i).getType())); + } + } else { + call_args[i] = + new_main_block->getArgument(i - nr_platform_args - nr_dim_args); + } + } + mlir::func::CallOp call_op = op_builder.create( + loc, orig_main.getResultTypes(), orig_main.getSymName(), call_args); + op_builder.create(loc, call_op.getResults()); + VLOG(3) << "XlaCallModule module with wrapper: " + << mlir::debugString(*module_); + + return tsl::OkStatus(); +} + +tsl::Status XlaCallModuleLoader::RefineDynamicShapes( + llvm::ArrayRef input_shapes) { + // Locate the (wrapped) 'main' function. + // This is the convention used by MlirToXlaComputation. + mlir::Block &main_body = main_.front(); + int nr_platform_args = (platform_index_ >= 0 ? 
1 : 0); + int nr_dim_args = dim_args_spec_.size(); + int non_dimension_arguments = input_shapes.size(); + if (non_dimension_arguments != main_body.getNumArguments()) { + return tsl::errors::InvalidArgument( + "Incorrect number of arguments passed to XlaCallModule: ", + non_dimension_arguments, ". The module takes ", + main_body.getNumArguments() + nr_platform_args + nr_dim_args, + " arguments of which ", nr_platform_args, + " platform index arguments and ", nr_dim_args, + " dimension arguments. It must be called with ", + main_body.getNumArguments(), " arguments."); + } + + mlir::Builder builder(module_->getContext()); + std::vector static_array_input_types(non_dimension_arguments); + for (int i = 0, end = non_dimension_arguments; i < end; ++i) { + const xla::Shape &xla_shape = input_shapes[i]; + std::vector xla_dimensions(xla_shape.dimensions().begin(), + xla_shape.dimensions().end()); + TF_ASSIGN_OR_RETURN( + mlir::Type element_type, + ConvertPrimitiveTypeToMLIRType(xla_shape.element_type(), builder)); + mlir::Type type = mlir::RankedTensorType::get(xla_dimensions, element_type); + // TODO(burmako): This fails with an obscure compilation error. + // TF_ASSIGN_OR_RETURN( + // mlir::Type type, + // ConvertShapeToType(xla_shape, builder)); + VLOG(3) << "XlaCallModule static array input type #" << i << ": " + << mlir::debugString(type); + // TODO(b/278273480): Determine whether it's safe to override the element + // type using that from the input shape. + static_array_input_types[i] = type; + } + + // Refine 'main' argument types to use static input types instead. + // This will only change the argument types and will not propagate the + // additional type information further. For that, we'll need to run + // shape refinement as explained below. + // Before refining the argument types it is useful to run the inliner to + // remove calls that may be called with the input arguments. + mlir::PassManager pm_inline(module_->getContext()); + pm_inline.addPass(mlir::createInlinerPass()); + if (!mlir::succeeded(pm_inline.run(*module_))) { + return tsl::errors::InvalidArgument("Module inlining failed"); + } + VLOG(3) << "XlaCallModule module after inlining: " + << mlir::debugString(*module_); + + auto static_array_output_types = llvm::to_vector(main_.getResultTypes()); + for (auto i = 0; i < main_body.getNumArguments(); ++i) { + auto arg = main_body.getArgument(i); + arg.setType(static_array_input_types[i]); + // If the argument is used by `func.return`, then we also need to + // update function result types. It's not great that we need this hack, + // but in the future when we have stablehlo.func, stablehlo.return, etc, + // this will not be needed. + // TODO(burmako): Once https://github.com/openxla/stablehlo/issues/425 is + // fixed, clean this up. + for (mlir::OpOperand &use : arg.getUses()) { + if (auto ret = llvm::dyn_cast(use.getOwner())) { + static_array_output_types[use.getOperandNumber()] = arg.getType(); + } + } + } + main_.setType(builder.getFunctionType(static_array_input_types, + static_array_output_types)); + + // Verify the module before running passes on it. + // If the module doesn't pass verification, all sorts of weirdness might + // happen if we run the pass manager. 
+  if (failed(verify(*module_))) {
+    VLOG(3) << "XlaCallModule module with verification failed: "
+            << mlir::debugString(*module_);
+    return tsl::errors::InvalidArgument("Module verification failed");
+  }
+  mlir::PassManager pm(module_->getContext());
+  if (VLOG_IS_ON(3)) {
+    auto print_before = [](mlir::Pass *, mlir::Operation *) { return true; };
+    auto print_after = [](mlir::Pass *, mlir::Operation *) { return true; };
+    pm.enableIRPrinting(print_before, print_after, /*printModuleScope=*/true,
+                        /*printAfterOnlyOnChange=*/false);
+  }
+  pm.addPass(mlir::createCSEPass());
+  pm.addPass(mlir::stablehlo::createStablehloRefineShapesPass());
+  pm.addNestedPass<mlir::func::FuncOp>(
+      mlir::stablehlo::createStablehloCanonicalizeDynamismPass());
+  if (!mlir::succeeded(pm.run(*module_))) {
+    return tsl::errors::InvalidArgument("Module shape refinement failed");
+  }
+
+  VLOG(3) << "XlaCallModule module with refined shapes: "
+          << mlir::debugString(*module_);
+  return tsl::OkStatus();
+}
+
+tsl::Status XlaCallModuleLoader::LoadAndPreprocessModule(
+    mlir::MLIRContext *context, int version, std::string module_str,
+    std::vector<std::string> dim_args_spec, int platform_index) {
+  context_ = context;
+  version_ = version;
+  dim_args_spec_ = std::move(dim_args_spec);
+  platform_index_ = platform_index;
+
+  // Load a superset of dialects; we should check at serialization time that
+  // we only include allowable dialects.
+  context_->loadDialect<mlir::func::FuncDialect>();
+  context_->loadDialect<mlir::stablehlo::StablehloDialect>();
+  context_->loadDialect<mlir::mhlo::MhloDialect>();
+  context_->loadDialect<mlir::chlo::ChloDialect>();
+  context_->loadDialect<mlir::vhlo::VhloDialect>();
+  // Parses both IR text and bytecode.
+  if (version >= VERSION_START_STABLE_HLO_COMPATIBILITY) {
+    module_ =
+        mlir::stablehlo::deserializePortableArtifact(module_str, context_);
+  } else {
+    module_ = mlir::parseSourceString<mlir::ModuleOp>(module_str, context_);
+  }
+
+  if (!module_) {
+    return tsl::errors::InvalidArgument("Cannot deserialize computation");
+  }
+  VLOG(3) << "Parsed serialized module (version " << version
+          << ", platform_index = " << platform_index_ << ", dim_args_spec = ["
+          << absl::StrJoin(dim_args_spec_, ", ") << "])\n"
+          << mlir::debugString(*module_);
+
+  if (failed(module_->verifyInvariants())) {
+    VLOG(1) << "MLIR verification failed.";
+    module_->dump();
+    return tsl::errors::InvalidArgument("Error verifying module");
+  }
+  main_ = module_->lookupSymbol<mlir::func::FuncOp>("main");
+  if (!main_) {
+    return tsl::errors::InvalidArgument("Cannot find 'main' in module");
+  }
+
+  if (!dim_args_spec_.empty() || platform_index_ >= 0) {
+    TF_RETURN_IF_ERROR(AddMainWrapper());
+    main_ = module_->lookupSymbol<mlir::func::FuncOp>("main");
+  }
+  return tsl::OkStatus();
+}
+
+tsl::Status XlaCallModuleLoader::ValidateModule() {
+  bool moduleHasUnsupportedDialects = false;
+  bool moduleHasDynamicShapes = false;
+
+  module_->walk([&](mlir::Operation *op) {
+    // StableHLO programs created by jax2tf only contain operations
+    // from Builtin, Func and StableHLO dialects.
+    if (!llvm::isa<mlir::BuiltinDialect, mlir::func::FuncDialect,
+                   mlir::stablehlo::StablehloDialect>(op->getDialect())) {
+      moduleHasUnsupportedDialects = true;
+      VLOG(3) << "Operation has unsupported dialects: "
+              << mlir::debugString(*op);
+    }
+
+    // It's sufficient to only check results because operands either come from
+    // results or from block arguments which are checked below.
+    auto hasDynamicShape = [](mlir::Value value) {
+      auto shaped_type = value.getType().dyn_cast<mlir::ShapedType>();
+      return shaped_type ? !shaped_type.hasStaticShape() : false;
+    };
+    bool opHasDynamicShapes = false;
+    opHasDynamicShapes |= llvm::any_of(op->getResults(), hasDynamicShape);
+    for (mlir::Region &region : op->getRegions()) {
+      opHasDynamicShapes |=
+          llvm::any_of(region.getArguments(), hasDynamicShape);
+    }
+    if (opHasDynamicShapes) {
+      moduleHasDynamicShapes = true;
+      VLOG(3) << "Operation has dynamic shapes: " << mlir::debugString(*op);
+    }
+  });
+
+  if (moduleHasUnsupportedDialects)
+    return tsl::errors::InvalidArgument("Module has unsupported dialects");
+  if (moduleHasDynamicShapes)
+    return tsl::errors::InvalidArgument("Module has dynamic shapes");
+  return tsl::OkStatus();
+}
+
+tsl::StatusOr<xla::XlaComputation> XlaCallModuleLoader::ToXlaComputation() {
+  xla::XlaComputation xla_computation;
+  TF_RETURN_IF_ERROR(
+      MlirToXlaComputation(*module_, xla_computation, false, false));
+  return xla_computation;
+}
+
+}  // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.h b/tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.h
new file mode 100644
index 00000000000..6196cfe1f20
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.h
@@ -0,0 +1,85 @@
+/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_TF2XLA_KERNELS_XLA_CALL_MODULE_LOADER_H_
+#define TENSORFLOW_COMPILER_TF2XLA_KERNELS_XLA_CALL_MODULE_LOADER_H_
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
+#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
+#include "mlir/IR/MLIRContext.h"  // from @llvm-project
+#include "mlir/IR/OwningOpRef.h"  // from @llvm-project
+#include "mlir/IR/TypeRange.h"  // from @llvm-project
+#include "tensorflow/compiler/xla/client/xla_computation.h"
+#include "tensorflow/tsl/platform/statusor.h"
+
+namespace tensorflow {
+
+class XlaCallModuleLoader {
+ public:
+  static tsl::StatusOr<std::unique_ptr<XlaCallModuleLoader>> Create(
+      mlir::MLIRContext* context, int version, std::string module_str,
+      std::vector<std::string> dim_args_spec, int platform_index);
+
+  int nr_outputs() { return main_.getNumResults(); }
+  mlir::TypeRange output_types() { return main_.getResultTypes(); }
+
+  // Refines the dynamic module arguments based on the static argument shapes.
+  // This assumes that the module has a "main" function without dimension args,
+  // but possibly with dynamic shapes. We read the static shapes of the inputs,
+  // then set them as the types of the function parameters, and run StableHLO
+  // shape refinement to specialize all dynamic shapes in the StableHLO program
+  // to static shapes.
+  //
+  // This method accepts a list of shapes as `llvm::ArrayRef<xla::Shape>`
+  // instead of `mlir::Type`. This is to prevent callers from accidentally
+  // passing `mlir::Type` owned by a context that's different from the one
+  // passed to `Create`, which could cause lifetime issues.
+  tsl::Status RefineDynamicShapes(llvm::ArrayRef<xla::Shape> input_shapes);
+
+  // Validate that the module represents a statically-shaped StableHLO program,
+  // otherwise all sorts of weirdness might happen in the HLO exporter which is
+  // much easier to detect here.
+  tsl::Status ValidateModule();
+
+  tsl::StatusOr<xla::XlaComputation> ToXlaComputation();
+
+ private:
+  XlaCallModuleLoader() = default;
+
+  // Initializes the loader with the given serialized module string.
+  tsl::Status LoadAndPreprocessModule(mlir::MLIRContext* context, int version,
+                                      std::string module_str,
+                                      std::vector<std::string> dim_args_spec,
+                                      int platform_index);
+
+  // Adds a wrapper for the "main" function to compute the platform index and
+  // the dimension arguments.
+  tsl::Status AddMainWrapper();
+
+  mlir::MLIRContext* context_;
+  int version_;
+  mlir::OwningOpRef<mlir::ModuleOp> module_;
+  int platform_index_;
+  std::vector<std::string> dim_args_spec_;
+  mlir::func::FuncOp main_;
+};
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_COMPILER_TF2XLA_KERNELS_XLA_CALL_MODULE_LOADER_H_
diff --git a/tensorflow/compiler/tf2xla/kernels/xla_call_module_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_call_module_op.cc
index 6b30c17a9b0..fbb853528fc 100644
--- a/tensorflow/compiler/tf2xla/kernels/xla_call_module_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/xla_call_module_op.cc
@@ -14,441 +14,40 @@ limitations under the License.
 ==============================================================================*/
 
 #include
+#include
 #include
+#include
 #include
 
 #include "absl/strings/str_join.h"
-#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
-#include "mlir/IR/BuiltinDialect.h"  // from @llvm-project
-#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
-#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
-#include "mlir/IR/Location.h"  // from @llvm-project
-#include "mlir/IR/MLIRContext.h"  // from @llvm-project
-#include "mlir/IR/OwningOpRef.h"  // from @llvm-project
-#include "mlir/IR/SymbolTable.h"  // from @llvm-project
-#include "mlir/IR/Verifier.h"  // from @llvm-project
-#include "mlir/Parser/Parser.h"  // from @llvm-project
-#include "mlir/Pass/PassManager.h"  // from @llvm-project
-#include "mlir/Support/DebugStringHelper.h"  // from @llvm-project
-#include "mlir/Transforms/Passes.h"  // from @llvm-project
-#include "stablehlo/dialect/ChloOps.h"  // from @stablehlo
-#include "stablehlo/dialect/Serialization.h"  // from @stablehlo
-#include "stablehlo/dialect/StablehloOps.h"  // from @stablehlo
-#include "stablehlo/dialect/VhloOps.h"  // from @stablehlo
-#include "stablehlo/transforms/Passes.h"  // from @stablehlo
+#include "llvm/ADT/ArrayRef.h"
+#include "tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.h"
 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
 #include "tensorflow/compiler/xla/client/xla_builder.h"
 #include "tensorflow/compiler/xla/client/xla_computation.h"
-#include "tensorflow/compiler/xla/mlir_hlo/mhlo/IR/hlo_ops.h"
-#include "tensorflow/compiler/xla/pjrt/mlir_to_hlo.h"
-#include "tensorflow/compiler/xla/service/hlo.pb.h"
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/translate/hlo_to_mhlo/hlo_utils.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/op_requires.h"
-#include "tensorflow/core/platform/status.h"
 #include "tensorflow/core/tpu/tpu_defs.h"
-#include "tensorflow/tsl/platform/regexp.h"
 
 namespace tensorflow {
 namespace {
 
-// Version 1 used MHLO & CHLO, not supported anymore.
-// Version 2 supports StableHLO & CHLO. From 10/2022. Minimum from 03/2023. -const int VERSION_START_STABLE_HLO = 2; -// Version 3 supports platform checking and multiple platforms. From 02/2023. -const int VERSION_START_PLATFORMS = 3; -// Version 4 supports StableHLO with compatibility guarantees. From 03/2023 -const int VERSION_START_STABLE_HLO_COMPATIBILITY = 4; -const int VERSION_MINIMUM_SUPPORTED = VERSION_START_STABLE_HLO; - -// Computes a dimension value from the dim_arg specification. -// The specification is of the form ".". -StatusOr ComputeDimensionValue(int version, string dim_arg_spec, - std::vector arguments, - mlir::OpBuilder op_builder, - mlir::Type dim_arg_type) { - static const LazyRE2 dim_arg_spec_re = {R"((\d+).(\d+))"}; - int arg_idx, arg_axis_idx; - if (!RE2::FullMatch(dim_arg_spec, *dim_arg_spec_re, &arg_idx, - &arg_axis_idx)) { - return errors::InvalidArgument("Syntax error in dim_args_spec '", - dim_arg_spec, "'"); - } - if (arg_idx < 0 || arg_idx >= arguments.size()) { - return errors::InvalidArgument( - "Invalid argument index ", arg_idx, - " when the number of non-dimension arguments is ", arguments.size(), - " in dim_arg_spec '", dim_arg_spec, "'"); - } - mlir::RankedTensorType arg_type = - arguments[arg_idx].getType().dyn_cast(); - if (!arg_type) { - return errors::InvalidArgument( - "Argument ", arg_idx, " referenced in dim_arg_spec '", dim_arg_spec, - "' does not have a RankedTensorType"); - } - if (arg_axis_idx < 0 || arg_axis_idx >= arg_type.getShape().size()) { - return errors::InvalidArgument("Invalid axis index ", arg_axis_idx, - " when the rank of non-dimension argument ", - arg_idx, " is ", arg_type.getShape().size(), - " in dim_arg_spec '", dim_arg_spec, "'"); - } - mlir::Value val; - mlir::Type get_dim_type = - mlir::RankedTensorType::get({}, op_builder.getI32Type()); - val = op_builder.create( - arguments[arg_idx].getLoc(), get_dim_type, arguments[arg_idx], - op_builder.getI64IntegerAttr(arg_axis_idx)); - if (dim_arg_type != get_dim_type) { - val = op_builder.create( - arguments[arg_idx].getLoc(), dim_arg_type, val); - } - return val; -} - -// Adds a wrapper for the "main" function to compute the platform index and the -// dimension arguments. -// -// The input module has the following structure: -// -// func public main(%arg_platform_index: i32, %arg_dim0: i32, %arg_dim1: i32, -// %arg0: f32[?, ?, 8]) { ... } -// -// where %arg_platform_index is the index of the current compilation platform -// among the declared `platforms` (missing if version < 3 or if platforms has -// fewer than 2 elements), %arg_dim0 and %arg_dim1 are dimension arguments -// (missing if dim_args_spec is empty). The value of the dimension arguments -// are computed based on the static shapes of the actual arguments -// (%arg0 and following). -// In the above example, the dim_args_spec array would have two elements, one -// for %arg_dim0 and one for %arg_dim1. E.g., ['0.0', '0.1'] specifies that -// %arg_dim0 should be set to the size of axis 0 or array argument 0 (%arg0), -// while %arg_dim1 should be set to the size of axis 1. -// The platform index argument must be a 0-dimensional 32-bit integer, and the -// dimension arguments must be 0-dimensional tensors of integer type. 
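The "<arg_idx>.<axis_idx>" convention described above (e.g. ['0.0', '0.1']) is unchanged by this patch and is still consumed by the new loader's AddMainWrapper. The snippet below is an illustrative sketch of just the parsing step, reusing the RE2 pattern from the removed ComputeDimensionValue; the helper name ParseDimArgSpec is invented for this sketch and is not part of the patch.

```c++
// Illustrative sketch only: parsing one dim_args_spec entry of the form
// "<arg_idx>.<axis_idx>", e.g. "0.1".
#include <string>

#include "tensorflow/tsl/platform/regexp.h"

bool ParseDimArgSpec(const std::string& spec, int* arg_idx, int* axis_idx) {
  // Same pattern as the (removed) ComputeDimensionValue above.
  static const LazyRE2 kSpecRe = {R"((\d+).(\d+))"};
  // "0.1" -> *arg_idx = 0, *axis_idx = 1, i.e. args[0].shape[1].
  return RE2::FullMatch(spec, *kSpecRe, arg_idx, axis_idx);
}
```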
-// -// We create a new "main" function as follows: -// func public main(%arg0: f32[?, ?, 8]) { -// %arg_platform_index = stablehlo.constant -// %arg_dim0 = stablehlo.get_dimension_size(%arg0) dimension=0 -// %arg_dim1 = stablehlo.get_dimension_size(%arg0) dimension=1 -// %res = func.call _wrapped_main(%arg_platform_index, -// %arg_dim0, %arg_dim1, %arg0) -// return %res -// } -// func private _wrapped_main(%arg_platform_index: i32, -// %arg_dim0: i32, %arg_dim1: i32, -// %arg0: f32[?, ?, 8]) { -// ... the original main function ... -// } -// -// and then we run the inliner. This is important because in the -// RefineDynamicShapes method called in Compile we refine the shape of the -// array arguments. This would create a type error at the call to _wrapped_main -// with the expected type of %arg0. -Status AddMainWrapper(int version, mlir::ModuleOp module, int platform_index, - std::vector dim_args_spec) { - int nr_dim_args = dim_args_spec.size(); - // Locate the 'main' function. - // This is the convention used by MlirToXlaComputation. - mlir::func::FuncOp orig_main = - module.lookupSymbol("main"); - if (!orig_main) { - return errors::InvalidArgument("Cannot find 'main' in module"); - } - int nr_platform_args = 0; - if (platform_index >= 0) { - nr_platform_args = 1; - } - if (orig_main.getNumArguments() <= nr_platform_args + nr_dim_args) { - return errors::InvalidArgument("The module should have ", nr_platform_args, - " platform index arguments and ", - nr_dim_args, " dimension arguments, but it ", - "has only ", orig_main.getNumArguments(), - " total arguments"); - } - mlir::Block &orig_main_body = orig_main.front(); - - mlir::SymbolTable::setSymbolVisibility( - orig_main, mlir::SymbolTable::Visibility::Private); - mlir::OpBuilder op_builder(module.getBodyRegion()); - orig_main.setName(op_builder.getStringAttr("_wrapped_main")); - mlir::Location loc = module.getLoc(); - std::vector new_main_arg_types( - orig_main.getArgumentTypes().begin() + nr_platform_args + nr_dim_args, - orig_main.getArgumentTypes().end()); - mlir::func::FuncOp new_main = op_builder.create( - loc, "main", - mlir::FunctionType::get(module.getContext(), - /*inputs=*/new_main_arg_types, - /*results=*/orig_main.getResultTypes())); - mlir::SymbolTable::setSymbolVisibility(new_main, - mlir::SymbolTable::Visibility::Public); - mlir::Block *new_main_block = new_main.addEntryBlock(); - std::vector block_args(new_main_block->getArguments().begin(), - new_main_block->getArguments().end()); - op_builder.setInsertionPointToStart(new_main_block); - - std::vector call_args(orig_main_body.getNumArguments()); - for (int i = 0; i < orig_main_body.getNumArguments(); ++i) { - if (i < nr_platform_args + nr_dim_args) { - mlir::Type arg_type = orig_main.getArgument(i).getType(); - mlir::RankedTensorType arg_ranked_type = - arg_type.dyn_cast(); - if (!arg_ranked_type || - !arg_ranked_type.getElementType().dyn_cast() || - !arg_ranked_type.getShape().empty()) { - string argument_type = - (i < nr_platform_args) ? 
"platform index" : "dimension"; - return errors::InvalidArgument( - "Module argument at index ", i, - " should be a 0-dimensional integer-tensor ", argument_type, - " argument but has type ", debugString(arg_type)); - } - if (i < nr_platform_args) { - if (arg_ranked_type.getElementTypeBitWidth() != 32) { - return errors::InvalidArgument( - "Module argument at index ", i, - " should be a 0-dimensional 32-bit integer-tensor" - " platform index argument but has type ", - debugString(arg_type)); - } - call_args[i] = op_builder.create( - block_args[0].getLoc(), - op_builder.getI32IntegerAttr(platform_index)); - } else { - TF_ASSIGN_OR_RETURN( - call_args[i], - ComputeDimensionValue(version, dim_args_spec[i - nr_platform_args], - block_args, op_builder, - orig_main.getArgument(i).getType())); - } - } else { - call_args[i] = - new_main_block->getArgument(i - nr_platform_args - nr_dim_args); - } - } - mlir::func::CallOp call_op = op_builder.create( - loc, orig_main.getResultTypes(), orig_main.getSymName(), call_args); - op_builder.create(loc, call_op.getResults()); - VLOG(3) << "XlaCallModule module with wrapper: " << debugString(module); - - return OkStatus(); -} - -// Refines the dynamic module arguments based on the static argument shapes. -// This assumes that the module has a "main" function without dimension args, -// but possibly with dynamic shapes. We read the static shapes of the inputs, -// then set them as the types of the function parameters, and run StableHLO -// shape refinement to specialize all dynamic shapes in the StableHLO program -// to static shapes. -Status RefineDynamicShapes(XlaOpKernelContext *ctx, - mlir::OwningOpRef *module, - int nr_platform_args, int nr_dim_args) { - // Locate the (wrapped) 'main' function. - // This is the convention used by MlirToXlaComputation. - mlir::func::FuncOp main = (*module)->lookupSymbol("main"); - if (!main) { - return errors::InvalidArgument("Cannot find 'main' in module"); - } - mlir::Block &main_body = main.front(); - int non_dimension_arguments = ctx->num_inputs(); - if (non_dimension_arguments != main_body.getNumArguments()) { - return errors::InvalidArgument( - "Incorrect number of arguments passed to XlaCallModule: ", - non_dimension_arguments, ". The module takes ", - main_body.getNumArguments() + nr_platform_args + nr_dim_args, - " arguments of which ", nr_platform_args, - " platform index arguments and ", nr_dim_args, - " dimension arguments. It must be called with ", - main_body.getNumArguments(), " arguments."); - } - - mlir::Builder builder((*module)->getContext()); - std::vector static_array_input_types(non_dimension_arguments); - for (int i = 0, end = non_dimension_arguments; i < end; ++i) { - TF_ASSIGN_OR_RETURN(xla::Shape xla_shape, ctx->InputXlaShape(i)); - std::vector xla_dimensions(xla_shape.dimensions().begin(), - xla_shape.dimensions().end()); - TF_ASSIGN_OR_RETURN( - mlir::Type element_type, - ConvertPrimitiveTypeToMLIRType(xla_shape.element_type(), builder)); - mlir::Type type = mlir::RankedTensorType::get(xla_dimensions, element_type); - // TODO(burmako): This fails with an obscure compilation error. - // OP_REQUIRES_VALUE( - // mlir::Type type, ctx, - // ConvertShapeToType(xla_shape, builder)); - VLOG(3) << "XlaCallModule static array input type #" << i << ": " - << debugString(type); - static_array_input_types[i] = type; - } - - // Refine 'main' argument types to use static input types instead. - // This will only change the argument types and will not propagate the - // additional type information further. 
For that, we'll need to run - // shape refinement as explained below. - // Before refining the argument types it is useful to run the inliner to - // remove calls that may be called with the input arguments. - mlir::PassManager pm_inline((*module)->getContext()); - pm_inline.addPass(mlir::createInlinerPass()); - if (!mlir::succeeded(pm_inline.run(**module))) { - return errors::InvalidArgument("Module inlining failed"); - } - VLOG(3) << "XlaCallModule module after inlining: " << debugString(module); - - auto static_array_output_types = llvm::to_vector(main.getResultTypes()); - for (auto i = 0; i < main_body.getNumArguments(); ++i) { - auto arg = main_body.getArgument(i); - arg.setType(static_array_input_types[i]); - // If the argument is used by `func.return`, then we also need to - // update function result types. It's not great that we need this hack, - // but in the future when we have stablehlo.func, stablehlo.return, etc, - // this will not be needed. - // TODO(burmako): Once https://github.com/openxla/stablehlo/issues/425 is - // fixed, clean this up. - for (mlir::OpOperand &use : arg.getUses()) { - if (auto ret = llvm::dyn_cast(use.getOwner())) { - static_array_output_types[use.getOperandNumber()] = arg.getType(); - } - } - } - main.setType(builder.getFunctionType(static_array_input_types, - static_array_output_types)); - - // Verify the module before running passes on it. - // If the module doesn't pass verification, all sorts of weirdness might - // happen if we run the pass manager. - if (failed(verify(**module))) { - VLOG(3) << "XlaCallModule module with verification failed: " - << debugString(**module); - return errors::InvalidArgument("Module verification failed"); - } - mlir::PassManager pm((*module)->getContext()); - if (VLOG_IS_ON(3)) { - auto print_before = [](mlir::Pass *, mlir::Operation *) { return true; }; - auto print_after = [](mlir::Pass *, mlir::Operation *) { return true; }; - pm.enableIRPrinting(print_before, print_after, /*printModuleScope=*/true, - /*printAfterOnlyOnChange=*/false); - } - pm.addPass(mlir::createCSEPass()); - pm.addPass(mlir::stablehlo::createStablehloRefineShapesPass()); - if (!mlir::succeeded(pm.run(**module))) { - return errors::InvalidArgument("Module shape refinement failed"); - } - - VLOG(3) << "XlaCallModule module with refined shapes: " - << debugString(**module); - return OkStatus(); -} - -Status LoadAndPreprocessModule(int version, - mlir::OwningOpRef *module, - mlir::MLIRContext *context, string module_str, - std::vector dim_args_spec, - std::vector platforms, - int platform_index, int *nr_outputs) { - // Load a superset of dialects; we should check at serialization time that - // we only include allowable dialects. - context->loadDialect(); - context->loadDialect(); - context->loadDialect(); - context->loadDialect(); - context->loadDialect(); - // Parses both IR text and bytecode. 
- if (version >= VERSION_START_STABLE_HLO_COMPATIBILITY) { - *module = mlir::stablehlo::deserializePortableArtifact(module_str, context); - } else { - *module = mlir::parseSourceString(module_str, context); - } - - if (!*module) { - return errors::InvalidArgument("Cannot deserialize computation"); - } - VLOG(3) << "Parsed serialized module (version " << version - << ", platforms = [" << absl::StrJoin(platforms, ", ") << "]" - << ", platform_index = " << platform_index << ", dim_args_spec = [" - << absl::StrJoin(dim_args_spec, ", ") << "])\n" - << debugString(**module); - - if (failed((*module)->verifyInvariants())) { - VLOG(1) << "MLIR verification failed."; - (*module)->dump(); - return errors::InvalidArgument("Error verifying module"); - } - mlir::func::FuncOp main = (*module)->lookupSymbol("main"); - if (!main) { - return errors::InvalidArgument("Cannot find 'main' in module"); - } - - if (!dim_args_spec.empty() || platform_index >= 0) { - TF_RETURN_IF_ERROR( - AddMainWrapper(version, **module, platform_index, dim_args_spec)); - main = (*module)->lookupSymbol("main"); - } - *nr_outputs = main.getNumResults(); - return OkStatus(); -} - -// Validate that the module represents a statically-shaped StableHLO program, -// otherwise all sorts of weirdness might happen in the HLO exporter which -// is much easier to detect here. -Status ValidateModule(mlir::ModuleOp module) { - bool moduleHasUnsupportedDialects = false; - bool moduleHasDynamicShapes = false; - - module.walk([&](mlir::Operation *op) { - // StableHLO programs created by jax2tf only contain operations - // from Builtin, Func and StableHLO dialects. - if (!llvm::isa( - op->getDialect())) { - moduleHasUnsupportedDialects = true; - VLOG(3) << "Operation has unsupported dialects: " << debugString(op); - } - - // It's sufficient to only check results because operands either come from - // results or from block arguments which are checked below. - auto hasDynamicShape = [](mlir::Value value) { - auto shaped_type = value.getType().dyn_cast(); - return shaped_type ? !shaped_type.hasStaticShape() : false; - }; - bool opHasDynamicShapes = false; - opHasDynamicShapes |= llvm::any_of(op->getResults(), hasDynamicShape); - for (mlir::Region ®ion : op->getRegions()) { - opHasDynamicShapes |= - llvm::any_of(region.getArguments(), hasDynamicShape); - } - if (opHasDynamicShapes) { - moduleHasDynamicShapes = true; - VLOG(3) << "Operation has dynamic shapes: " << debugString(op); - } - }); - - if (moduleHasUnsupportedDialects) - return errors::InvalidArgument("Module has unsupported dialects"); - if (moduleHasDynamicShapes) - return errors::InvalidArgument("Module has dynamic shapes"); - return OkStatus(); -} - class XlaCallModuleOp : public XlaOpKernel { public: explicit XlaCallModuleOp(OpKernelConstruction *ctx) : XlaOpKernel(ctx) { - OP_REQUIRES_OK(ctx, ctx->GetAttr("version", &version_)); - OP_REQUIRES( - ctx, version_ >= VERSION_MINIMUM_SUPPORTED, - errors::InvalidArgument("XlaCallModuleOp with version ", version_, - " is not supported anymore. 
Must be >= ", - VERSION_MINIMUM_SUPPORTED)); + int version; + OP_REQUIRES_OK(ctx, ctx->GetAttr("version", &version)); string module_str; OP_REQUIRES_OK(ctx, ctx->GetAttr("module", &module_str)); std::vector expected_output_shapes; OP_REQUIRES_OK(ctx, ctx->GetAttr("Sout", &expected_output_shapes)); std::vector expected_output_dtypes; OP_REQUIRES_OK(ctx, ctx->GetAttr("Tout", &expected_output_dtypes)); - OP_REQUIRES_OK(ctx, ctx->GetAttr("dim_args_spec", &dim_args_spec_)); + std::vector dim_args_spec; + OP_REQUIRES_OK(ctx, ctx->GetAttr("dim_args_spec", &dim_args_spec)); OP_REQUIRES(ctx, expected_output_shapes.size() == expected_output_dtypes.size(), errors::InvalidArgument("The size of Sout (", @@ -456,12 +55,14 @@ class XlaCallModuleOp : public XlaOpKernel { ") must match the size of Tout (", expected_output_dtypes.size(), ")")); std::vector platforms; - platform_index_ = -1; - if (version_ >= VERSION_START_PLATFORMS) { + // Index in platforms of the current platform, or -1 if module does not take + // a platform index arg. + int platform_index = -1; + if (ctx->HasAttr("platforms")) { OP_REQUIRES_OK(ctx, ctx->GetAttr("platforms", &platforms)); if (!platforms.empty()) { - std::string current_device_type = ctx->device_type().type_string(); - std::string current_platform = ""; + string current_device_type = ctx->device_type().type_string(); + string current_platform = ""; if (current_device_type == DEVICE_CPU_XLA_JIT) { current_platform = "CPU"; } else if (current_device_type == DEVICE_GPU_XLA_JIT) { @@ -491,46 +92,51 @@ class XlaCallModuleOp : public XlaOpKernel { // We only use a platform index arguments if we support at least 2 // platforms. if (platforms.size() > 1) { - platform_index_ = found_platform - platforms.begin(); + platform_index = found_platform - platforms.begin(); } } } - OP_REQUIRES_OK( - ctx, LoadAndPreprocessModule(version_, &module_, &context_, module_str, - dim_args_spec_, platforms, platform_index_, - &nr_outputs_)); + + auto loader = + XlaCallModuleLoader::Create(&context_, version, std::move(module_str), + std::move(dim_args_spec), platform_index); + OP_REQUIRES_OK(ctx, loader.status()); + loader_ = *std::move(loader); } void Compile(XlaOpKernelContext *ctx) override { - OP_REQUIRES_OK( - ctx, RefineDynamicShapes(ctx, &module_, (platform_index_ >= 0 ? 
1 : 0), - dim_args_spec_.size())); - OP_REQUIRES_OK(ctx, ValidateModule(*module_)); + std::vector input_shapes; + for (int i = 0; i < ctx->num_inputs(); ++i) { + auto shape = ctx->InputXlaShape(i); + OP_REQUIRES_OK(ctx, shape.status()); + input_shapes.push_back(*std::move(shape)); + } + OP_REQUIRES_OK(ctx, loader_->RefineDynamicShapes(input_shapes)); + OP_REQUIRES_OK(ctx, loader_->ValidateModule()); std::vector inputs(ctx->num_inputs()); for (int i = 0, end = ctx->num_inputs(); i < end; ++i) { inputs[i] = ctx->Input(i); } - xla::XlaComputation xla_computation; - OP_REQUIRES_OK( - ctx, MlirToXlaComputation(*module_, xla_computation, false, false)); + auto xla_computation = loader_->ToXlaComputation(); + OP_REQUIRES_OK(ctx, xla_computation.status()); if (VLOG_IS_ON(3)) { OP_REQUIRES_VALUE( const xla::HloModuleConfig module_config, ctx, xla::HloModule::CreateModuleConfigFromProto( - xla_computation.proto(), xla::GetDebugOptionsFromFlags())); + xla_computation->proto(), xla::GetDebugOptionsFromFlags())); OP_REQUIRES_VALUE(std::unique_ptr hlo_module, ctx, - xla::HloModule::CreateFromProto(xla_computation.proto(), - module_config)); + xla::HloModule::CreateFromProto( + xla_computation->proto(), module_config)); xla::HloPrintOptions options; options = xla::HloPrintOptions::ShortParsable(); VLOG(3) << "XlaCallModule converted to HLO module " << hlo_module->ToString(options); } - xla::XlaOp output = xla::Call(ctx->builder(), xla_computation, inputs); + xla::XlaOp output = xla::Call(ctx->builder(), *xla_computation, inputs); // Check that the resulting computation returns the expected shape OP_REQUIRES_VALUE(xla::Shape found_output_shape, ctx, @@ -538,25 +144,21 @@ class XlaCallModuleOp : public XlaOpKernel { VLOG(3) << "XlaCallModule compiled output shape : " << xla::ShapeUtil::HumanString(found_output_shape); - if (nr_outputs_ == 1) { + if (loader_->nr_outputs() == 1) { ctx->SetOutput(0, output); } else { - for (int i = 0; i < nr_outputs_; ++i) { + for (int i = 0; i < loader_->nr_outputs(); ++i) { ctx->SetOutput(i, xla::GetTupleElement(output, i)); } } } private: - int version_; - int nr_outputs_; - std::vector dim_args_spec_; - int platform_index_; // Index in platforms of the current platform, or -1 - // if module does not take a platform index arg. 
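For orientation, the constructor and Compile() changes above reduce to a fixed call sequence on the new XlaCallModuleLoader. The sketch below is illustrative only: the free-standing wrapper name CompileSerializedModule is invented, and kernel-specific error plumbing (OP_REQUIRES_OK and friends) is replaced by plain status propagation.

```c++
// Illustrative sketch only (not part of this patch): driving the new
// XlaCallModuleLoader in the same order XlaCallModuleOp does.
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.h"
#include "tensorflow/compiler/xla/shape.h"

tsl::StatusOr<xla::XlaComputation> CompileSerializedModule(
    mlir::MLIRContext* context, int version, std::string module_str,
    std::vector<std::string> dim_args_spec, int platform_index,
    const std::vector<xla::Shape>& input_shapes) {
  // Create() deserializes the module and, when dim_args_spec or a platform
  // index argument is present, wraps "main" so those values are computed
  // from the actual array arguments.
  auto loader = tensorflow::XlaCallModuleLoader::Create(
      context, version, std::move(module_str), std::move(dim_args_spec),
      platform_index);
  if (!loader.ok()) return loader.status();
  std::unique_ptr<tensorflow::XlaCallModuleLoader> ld = *std::move(loader);

  // Specialize dynamic shapes to the static shapes of the actual inputs.
  tsl::Status status = ld->RefineDynamicShapes(input_shapes);
  if (!status.ok()) return status;

  // Reject modules that still carry dynamic shapes or unsupported dialects.
  status = ld->ValidateModule();
  if (!status.ok()) return status;

  return ld->ToXlaComputation();
}
```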
mlir::MLIRContext context_{mlir::MLIRContext::Threading::DISABLED}; - mlir::OwningOpRef module_; + std::unique_ptr loader_; }; REGISTER_XLA_OP(Name("XlaCallModule"), XlaCallModuleOp); + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/literal_util_test.cc b/tensorflow/compiler/tf2xla/literal_util_test.cc index 05af57e551a..e1fa0821054 100644 --- a/tensorflow/compiler/tf2xla/literal_util_test.cc +++ b/tensorflow/compiler/tf2xla/literal_util_test.cc @@ -32,10 +32,10 @@ TEST(LiteralUtil, LiteralToHostTensor) { Tensor host_tensor; EXPECT_EQ("Cannot convert literal of type S64 to tensor of type int32", LiteralToHostTensor(int64_values_literal, DT_INT32, &host_tensor) - .error_message()); + .message()); EXPECT_EQ("Cannot convert literal of type S64 to tensor of type qint32", LiteralToHostTensor(int64_values_literal, DT_QINT32, &host_tensor) - .error_message()); + .message()); EXPECT_TRUE( LiteralToHostTensor(int64_values_literal, DT_INT64, &host_tensor).ok()); test::ExpectTensorEqual(host_tensor, diff --git a/tensorflow/compiler/tf2xla/mlir_bridge_pass.cc b/tensorflow/compiler/tf2xla/mlir_bridge_pass.cc index 5ffe2a06f34..3959ebb5771 100644 --- a/tensorflow/compiler/tf2xla/mlir_bridge_pass.cc +++ b/tensorflow/compiler/tf2xla/mlir_bridge_pass.cc @@ -163,42 +163,75 @@ MlirOptimizationPassState MlirBridgePass::GetPassState( const DeviceSet* device_set, const ConfigProto& config_proto, const Graph& graph, const FunctionLibraryDefinition& function_library) const { - // Skip MLIR TF XLA Bridge if no TPU devices found and the non TPU graph is - // not qualified. - if (device_set && !HasTPUDevice(*device_set) && !EnableNonTpuBridge(graph)) { + // Skip MLIR TF/XLA Bridge if no TPU devices and no qualified CPU/GPU + // graphs are found. + bool has_tpu_device = device_set ? HasTPUDevice(*device_set) : false; + // GetPassState is called once before MlirBridgePass starts, and the pass + // gets skipped if it is disabled. Log such cases in this function. The cases + // where the pass is enabled will only be logged during their execution to + // prevent them from being counted twice. + if (device_set && !has_tpu_device && !EnableNonTpuBridge(graph)) { + // Only record CPU/GPU graphs that are qualified but filtered out + if (HasQualifiedNonTPUOp(graph)) { + metrics::UpdateTfMlirBridgeFirstPhaseCounter( + /*device type*/ "cpu/gpu", + /*bridge version*/ "tfxla", + /*fallback_enabled*/ false, + /*result*/ "invalid_graph"); + } return MlirOptimizationPassState::Disabled; } // We set `uses_uninitialized_resource_args` to false here because the first // phase of the bridge is not affected by uninitialized resource args. MlirBridgeRolloutPolicy policy = GetMlirBridgeRolloutPolicy( - graph, &function_library, config_proto, + graph, &function_library, config_proto, /*is_tpu_graph*/ has_tpu_device, /*uses_uninitialized_resource_args=*/false, /*is_v1_compat=*/false, /*record_stats=*/false); + if (has_tpu_device) { + switch (policy) { + case MlirBridgeRolloutPolicy::kEnabledByUser: + return MlirOptimizationPassState::Enabled; + case MlirBridgeRolloutPolicy::kEnabledAfterGraphAnalysis: + return MlirOptimizationPassState::FallbackEnabled; + case MlirBridgeRolloutPolicy::kDisabledByUser: + VLOG(1) << "Skipping MLIR TPU Bridge, disabled by user. 
" + "Old bridge will evaluate."; + metrics::UpdateTfMlirBridgeFirstPhaseCounter("tpu", "v2", true, + "disabled_by_user"); + return MlirOptimizationPassState::Disabled; + case MlirBridgeRolloutPolicy::kDisabledAfterGraphAnalysis: + VLOG(1) << "Skipping MLIR TPU Bridge, disabled because " + "graph has unsupported features. Old bridge will evaluate."; + metrics::UpdateTfMlirBridgeFirstPhaseCounter("tpu", "v2", true, + "invalid_graph"); + // We set `uses_uninitialized_resource_args` to false here because the + // first phase of the bridge is not affected by uninitialized resource + // args. + // For Invalid Graph Analysis we need to log here because Run will not + // be called. + LogGraphFeatures(graph, &function_library, config_proto, + /*uses_uninitialized_resource_args=*/false, + /*is_v1_compat=*/false); + return MlirOptimizationPassState::Disabled; + } + } + // TODO(b/277112519): Have uniform behavior for GPU/CPU and TPU switch (policy) { case MlirBridgeRolloutPolicy::kEnabledByUser: return MlirOptimizationPassState::Enabled; case MlirBridgeRolloutPolicy::kEnabledAfterGraphAnalysis: return MlirOptimizationPassState::FallbackEnabled; case MlirBridgeRolloutPolicy::kDisabledByUser: - VLOG(1) << "Skipping MLIR TPU Bridge, MLIR TPU bridge disabled by user. " - "Old bridge will evaluate."; - metrics::UpdateTfMlirBridgeFirstPhaseCounter("tpu", "v2", true, + VLOG(1) << "Skipping MLIR CPU/GPU Bridge, disabled by user."; + metrics::UpdateTfMlirBridgeFirstPhaseCounter("cpu/gpu", "tfxla", false, "disabled_by_user"); return MlirOptimizationPassState::Disabled; - case MlirBridgeRolloutPolicy::kDisabledAfterGraphAnalysis: - VLOG(1) << "Skipping MLIR TPU Bridge, MLIR TPU bridge disabled because " - "graph has unsupported features. Old bridge will evaluate."; - metrics::UpdateTfMlirBridgeFirstPhaseCounter("tpu", "v2", true, + default: + // This case should never be hit. Added here to be consistent with OSS + // implementation. + metrics::UpdateTfMlirBridgeFirstPhaseCounter("cpu/gpu", "ftxla", false, "invalid_graph"); - // We set `uses_uninitialized_resource_args` to false here because the - // first phase of the bridge is not affected by uninitialized resource - // args. - // For Invalid Graph Analysis we need to log here because Run will not be - // called. - LogGraphFeatures(graph, &function_library, config_proto, - /*uses_uninitialized_resource_args=*/false, - /*is_v1_compat=*/false); return MlirOptimizationPassState::Disabled; } } @@ -209,14 +242,15 @@ MlirOptimizationPassState MlirBridgePass::GetPassState( // and attached to a "compile" operation, whose result is fed to an "execute" // operation. The kernel for these operations is responsible to lower the // encapsulated graph to a particular device. -Status MlirBridgePass::Run(const ConfigProto& config_proto, +Status MlirBridgePass::Run(const std::string& function_name, + const ConfigProto& config_proto, mlir::ModuleOp module, const Graph& graph, const FunctionLibraryDefinition& function_library) { static absl::once_flag flag; absl::call_once(flag, UpdateLogVerbosityIfDefined, "TF_DEBUG_LOG_VERBOSITY"); // Check if there are TPU devices or TPU ops. If not, then check if the - // non TPU graph is qualified to run TF XLA Bridge. + // non TPU graph is qualified to run TF2XLA Bridge. // This check needs to precede GetPassState for instrumentation purposes. 
bool is_qualified_for_tpu_bridge = HasTPUDevicesAndOps(module), is_qualified_for_non_tpu_bridge = false; @@ -224,7 +258,7 @@ Status MlirBridgePass::Run(const ConfigProto& config_proto, is_qualified_for_non_tpu_bridge = EnableNonTpuBridge(graph); if (!is_qualified_for_tpu_bridge && !is_qualified_for_non_tpu_bridge) { VLOG(1) - << "Skipping MLIR TF XLA Bridge, no qualified devices or ops found."; + << "Skipping MLIR TF2XLA Bridge, no qualified devices or ops found."; return OkStatus(); } @@ -259,11 +293,10 @@ Status MlirBridgePass::Run(const ConfigProto& config_proto, } VLOG(1) << "Running MLIR TPU Bridge"; mlir_bridge_gauge_v2->GetCell()->Set(true); - return mlir::TFTPU::TPUBridge(module, /*enable_logging=*/VLOG_IS_ON(1), - fallback_enabled); + return mlir::TFTPU::TPUBridge(module, fallback_enabled, function_name); } - VLOG(1) << "Running MLIR non-TPU Bridge"; - return mlir::TF::RunTFXLABridge(module, VLOG_IS_ON(1)); + VLOG(1) << "Running MLIR CPU/GPU Bridge"; + return mlir::TF::RunTFXLABridge(module, function_name); } MlirOptimizationPassState MlirBridgeV1CompatPass::GetPassState( @@ -277,6 +310,7 @@ MlirOptimizationPassState MlirBridgeV1CompatPass::GetPassState( // phase of the bridge is not affected by uninitialized resource args. MlirBridgeRolloutPolicy policy = GetMlirBridgeRolloutPolicy( graph, /*function_library=*/&function_library, config_proto, + /*is_tpu_graph*/ true, /*uses_uninitialized_resource_args=*/false, /*is_v1_compat=*/true, /*record_stats=*/false); switch (policy) { @@ -356,8 +390,7 @@ Status MlirBridgeV1CompatPass::Run(const GraphOptimizationPassOptions& options, mlir_bridge_gauge_v1->GetCell()->Set(true); - return mlir::TFTPU::TPUBridgeV1Compat( - module, /*enable_logging=*/VLOG_IS_ON(1), fallback_enabled); + return mlir::TFTPU::TPUBridgeV1Compat(module, fallback_enabled); } } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/mlir_bridge_pass.h b/tensorflow/compiler/tf2xla/mlir_bridge_pass.h index f0f8424cab5..ff32bc5f9ad 100644 --- a/tensorflow/compiler/tf2xla/mlir_bridge_pass.h +++ b/tensorflow/compiler/tf2xla/mlir_bridge_pass.h @@ -16,6 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_TF2XLA_MLIR_BRIDGE_PASS_H_ #define TENSORFLOW_COMPILER_TF2XLA_MLIR_BRIDGE_PASS_H_ +#include + #include "tensorflow/compiler/mlir/tf2xla/mlir_bridge_rollout_policy.h" #include "llvm/ADT/StringRef.h" #include "tensorflow/compiler/jit/flags.h" @@ -37,8 +39,8 @@ class MlirBridgePass : public MlirOptimizationPass { // This should be used as a thin mapper around mlir::ModulePass::runOnModule // API integrated with the Tensorflow runtime. - Status Run(const ConfigProto& config_proto, mlir::ModuleOp module, - const Graph& graph, + Status Run(const std::string& function_name, const ConfigProto& config_proto, + mlir::ModuleOp module, const Graph& graph, const FunctionLibraryDefinition& function_library) override; }; diff --git a/tensorflow/compiler/tf2xla/mlir_xla_op_kernel.cc b/tensorflow/compiler/tf2xla/mlir_xla_op_kernel.cc index 0e61e144f93..694a1a15910 100644 --- a/tensorflow/compiler/tf2xla/mlir_xla_op_kernel.cc +++ b/tensorflow/compiler/tf2xla/mlir_xla_op_kernel.cc @@ -15,12 +15,43 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/mlir_xla_op_kernel.h" +#include + +#include "mlir/IR/MLIRContext.h" // from @llvm-project #include "tensorflow/compiler/jit/xla_compile_util.h" #include "tensorflow/compiler/mlir/tf2xla/api/v0/compile_mlir_util.h" #include "tensorflow/compiler/mlir/utils/array_container_utils.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/core/framework/resource_op_kernel.h" namespace tensorflow { +namespace { + +class MLIRContextResource : public ResourceBase { + public: + static constexpr const char* kDefaultResourceName = + "mlir-xla-op-cached-context"; + + static Status Create(MLIRContextResource** resource) { + *resource = new MLIRContextResource(); + return OkStatus(); + } + mlir::MLIRContext* GetContext() { return &mlir_ctx_; } + std::string DebugString() const override { + return "MlirXlaOpKernel MLIRContext resource"; + } + + private: + // Since this kernel implements lowering for a single TF operation, we + // disable MLIR threading for efficiency purpose (avoid starting a large + // number of threads eagerly). + MLIRContextResource() : mlir_ctx_(mlir::MLIRContext::Threading::DISABLED) {} + mlir::MLIRContext mlir_ctx_; +}; + +} // namespace + Status MlirXlaOpKernel::ContextToXlaArgs( XlaOpKernelContext* ctx, std::vector& xla_args) { // Collect arguments that are registered as CompileTimeConstantInput. @@ -57,11 +88,7 @@ Status MlirXlaOpKernel::ContextToXlaArgs( } MlirXlaOpKernel::MlirXlaOpKernel(OpKernelConstruction* ctx) - : XlaOpKernel(ctx), - // Since this kernel implements lowering for a single TF operation, we - // disable MLIR threading for efficiency purpose (avoid starting a large - // number of threads eagerly). - mlir_ctx_(mlir::MLIRContext::Threading::DISABLED) {} + : XlaOpKernel(ctx) {} Status MlirXlaOpKernel::ConstructXlaOp(XlaOpKernelContext* ctx) { // Create input XlaArguments. @@ -99,11 +126,19 @@ Status MlirXlaOpKernel::ConstructXlaOp(XlaOpKernelContext* ctx) { TF_ASSIGN_OR_RETURN(auto graph, CreateSingleOpGraph(def(), xla_args, result_dtypes)); + ResourceMgr* res_manager = ctx->op_kernel_context()->resource_manager(); + MLIRContextResource* ctx_res; + TF_RETURN_IF_ERROR(res_manager->LookupOrCreate( + res_manager->default_container(), + MLIRContextResource::kDefaultResourceName, &ctx_res, + MLIRContextResource::Create)); + core::ScopedUnref unref_ctx(ctx_res); + // Compile the graph to HLO. GraphDebugInfo debug_info; std::vector returns(1); TF_RETURN_IF_ERROR(BuildHloFromGraph( - *graph, *ctx->builder(), mlir_ctx_, xla_params, returns, + *graph, *ctx->builder(), *ctx_res->GetContext(), xla_params, returns, mlir::SpanToArrayRef(xla_args), control_rets, device->device_type(), *ctx->function_library()->GetFunctionLibraryDefinition(), debug_info, diff --git a/tensorflow/compiler/tf2xla/mlir_xla_op_kernel.h b/tensorflow/compiler/tf2xla/mlir_xla_op_kernel.h index e4ece6e692a..ec62bd98a21 100644 --- a/tensorflow/compiler/tf2xla/mlir_xla_op_kernel.h +++ b/tensorflow/compiler/tf2xla/mlir_xla_op_kernel.h @@ -16,7 +16,6 @@ limitations under the License. 
#ifndef TENSORFLOW_COMPILER_TF2XLA_MLIR_XLA_OP_KERNEL_H_ #define TENSORFLOW_COMPILER_TF2XLA_MLIR_XLA_OP_KERNEL_H_ -#include "mlir/IR/MLIRContext.h" // from @llvm-project #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" namespace tensorflow { @@ -32,7 +31,6 @@ class MlirXlaOpKernel : public XlaOpKernel { std::vector& xla_args); void Compile(XlaOpKernelContext* ctx) override; Status ConstructXlaOp(XlaOpKernelContext* ctx); - mlir::MLIRContext mlir_ctx_; }; } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/ops/xla_ops.cc b/tensorflow/compiler/tf2xla/ops/xla_ops.cc index a66c953efeb..e536ffa3746 100644 --- a/tensorflow/compiler/tf2xla/ops/xla_ops.cc +++ b/tensorflow/compiler/tf2xla/ops/xla_ops.cc @@ -1325,6 +1325,7 @@ REGISTER_OP("XlaCallModule") .Attr("Tin: list(type) >= 0") .Attr("dim_args_spec: list(string) = []") .Attr("platforms: list(string) = []") + .Attr("function_list: list(func) = []") .SetShapeFn([](shape_inference::InferenceContext* c) { std::vector args_shapes; TF_RETURN_IF_ERROR(c->input("args", &args_shapes)); @@ -1347,8 +1348,8 @@ REGISTER_OP("XlaCallModule") .Doc(R"doc( Invokes a StableHLO module. -This op is experimental and is intended for use with JAX native serialization -in a TensorFlow context. +This op is used with JAX native serialization in a TensorFlow context with +stability guarantees. args: A list of `Tensor` with possibly different types to be passed as arguments to the `module`. These are the actual arguments and do not include the @@ -1357,7 +1358,10 @@ args: A list of `Tensor` with possibly different types to be passed as arguments version: Tracks changes the semantics of the op, to support backwards compatibility. Minimum supported version is 2. From version 2, the op carries a StableHLO text or bytecode `module`. From - version 3, the op also supports the `platforms` attribute. + version 3, the op also supports the `platforms` attribute. From version 4, + the op carries a StableHLO module with compatibility guarantees. From version + 5, XLACallModule can include `stablehlo.custom_call` op to execute tf + functions. module: A serialized computation, a text or bytecode representation of an mlir.Module. The return type must be a tuple if and only if the `Sout` is a list with 0 or more than 1 elements. The length of `Tout` and @@ -1382,6 +1386,11 @@ dim_args_spec: in presence of dynamic shapes, this is the specification for the string of the form "." that specifies that the value of the corresponding dimension argument must be "args[arg_idx].shape[axis_idx]", where "args" are the actual array arguments. +function_list: This list contains the TensorFlow FunctionDefs that are used by + the XLACallModule. If the XLACallModule contains `stablehlo.custom_call` + operations, they can call TensorFlow graph functions outside of the + XLACallModule. This `function_list` attribute registers the dependency of the + XLACallModule on those functions. This attribute was added in version 5. )doc"); } // namespace diff --git a/tensorflow/compiler/tf2xla/python/xla.py b/tensorflow/compiler/tf2xla/python/xla.py index d1273cd403e..61d2be76ac1 100644 --- a/tensorflow/compiler/tf2xla/python/xla.py +++ b/tensorflow/compiler/tf2xla/python/xla.py @@ -604,12 +604,12 @@ def custom_call_v2( ) -def call_module(args, *, version=2, module, Tout, Sout, - dim_args_spec=(), platforms=()): +def call_module(args, *, version=4, module, Tout, Sout, + dim_args_spec=(), platforms=(), function_list=()): # See documentation for the XlaCallModule op. 
return gen_xla_ops.xla_call_module( args, version=version, module=module, dim_args_spec=dim_args_spec, - Tout=Tout, Sout=Sout, platforms=platforms) + Tout=Tout, Sout=Sout, platforms=platforms, function_list=function_list) def gather(operand, diff --git a/tensorflow/compiler/tf2xla/rearrange_function_argument.cc b/tensorflow/compiler/tf2xla/rearrange_function_argument.cc index 24cd31b3ba9..942b3ef0bdc 100644 --- a/tensorflow/compiler/tf2xla/rearrange_function_argument.cc +++ b/tensorflow/compiler/tf2xla/rearrange_function_argument.cc @@ -366,8 +366,13 @@ Status MaybeRewriteWhileNode( string new_name = fld->UniqueFunctionName(absl::StrCat(attr_value.name(), "_rearrange_")); TF_RETURN_IF_ERROR(GraphToFunctionDef(*fbody->graph, new_name, &new_fdef)); - TF_RETURN_IF_ERROR( - fld->AddFunctionDef(new_fdef, fld->GetStackTraces(attr_value.name()))); + + const StackTracesMap* stack_traces = fld->GetStackTraces(attr_value.name()); + if (stack_traces != nullptr) { + TF_RETURN_IF_ERROR(fld->AddFunctionDef(new_fdef, *stack_traces)); + } else { + TF_RETURN_IF_ERROR(fld->AddFunctionDef(new_fdef, {})); + } // Change node to use rewritten function. attr_value.set_name(new_name); @@ -457,11 +462,18 @@ Status MaybeRewriteIfNode( string new_name = fld->UniqueFunctionName(absl::StrCat(f.name(), "_rearrange_")); TF_RETURN_IF_ERROR(GraphToFunctionDef(*fbody->graph, new_name, &new_fdef)); - const StackTracesMap& stack_traces = - fld->GetStackTraces(f.name()).empty() && global_fld - ? global_fld->GetStackTraces(f.name()) - : fld->GetStackTraces(f.name()); - TF_RETURN_IF_ERROR(fld->AddFunctionDef(new_fdef, stack_traces)); + + const StackTracesMap* global_stack_traces = + global_fld ? global_fld->GetStackTraces(f.name()) : nullptr; + const StackTracesMap* local_stack_traces = fld->GetStackTraces(f.name()); + + if (global_stack_traces != nullptr) { + TF_RETURN_IF_ERROR(fld->AddFunctionDef(new_fdef, *global_stack_traces)); + } else if (local_stack_traces != nullptr) { + TF_RETURN_IF_ERROR(fld->AddFunctionDef(new_fdef, *local_stack_traces)); + } else { + TF_RETURN_IF_ERROR(fld->AddFunctionDef(new_fdef, {})); + } // Change node to use rewritten function. f.set_name(new_name); diff --git a/tensorflow/compiler/tf2xla/tf2xla_opset.cc b/tensorflow/compiler/tf2xla/tf2xla_opset.cc new file mode 100644 index 00000000000..a2a9ddde35b --- /dev/null +++ b/tensorflow/compiler/tf2xla/tf2xla_opset.cc @@ -0,0 +1,96 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/tf2xla_opset.h" + +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" +#include "tensorflow/compiler/tf2xla/tf2xla_util.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/core/framework/kernel_def.pb.h" + +namespace tensorflow { + +const int SUPPORTED_DEVICES_NUM = 2; +static const char* const SUPPORTED_DEVICES[SUPPORTED_DEVICES_NUM] = { + DEVICE_GPU_XLA_JIT, DEVICE_CPU_XLA_JIT}; + +bool IsSupportedBackend(absl::string_view device_name) { + for (int i = 0; i < SUPPORTED_DEVICES_NUM; i++) { + if (SUPPORTED_DEVICES[i] == device_name) return true; + } + return false; +} + +absl::Status RegisterBackends(absl::string_view device_name) { + if (!IsSupportedBackend(device_name)) { + return absl::InvalidArgumentError( + absl::StrCat(device_name, " is not supported. Supported devices are ", + absl::StrJoin(SUPPORTED_DEVICES, ", "))); + } + // All backends need to be registered before DeviceKernels is called + // because it calls RegisterCompilationKernels which will only run 1x, + // meaning if a device is registered afterwards the ops for that device + // will not be included. + auto op_filter = [](KernelDef* kdef) { + if (kdef->op() == "Const") { + AddDtypeToKernelDefConstraint("dtype", DT_STRING, kdef); + } + if (kdef->op() == "Assert") { + AddDtypeToKernelDefConstraint("T", DT_STRING, kdef); + } + return true; + }; + + // Backends might already be registered due to preprocesser macros defined + // in xla_op_registery.h so this first checks to see if they are registered + // already because re-registering the same device will cause a failure. + if (!XlaOpRegistry::IsBackendRegistered(DEVICE_GPU_XLA_JIT)) { + static auto gpu_backend = + XlaBackendRegistrar(DEVICE_GPU_XLA_JIT, kGpuAllTypes, op_filter); + } + if (!XlaOpRegistry::IsBackendRegistered(DEVICE_CPU_XLA_JIT)) { + static auto cpu_backend = + XlaBackendRegistrar(DEVICE_CPU_XLA_JIT, kCpuAllTypes, op_filter); + } + if (!XlaOpRegistry::IsBackendRegistered(std::string(device_name))) { + return absl::InternalError( + absl::StrCat(device_name, " is not registered.")); + } + return absl::OkStatus(); +} + +absl::StatusOr> GetRegisteredXlaOpsForDevice( + absl::string_view device_name) { + auto status = RegisterBackends(device_name); + if (!status.ok()) return status; + + std::vector kernel_defs = + XlaOpRegistry::DeviceKernels(std::string(device_name), true); + std::vector op_names; + op_names.reserve(kernel_defs.size()); + for (const auto& kernel_def : kernel_defs) { + op_names.push_back(kernel_def->op()); + } + std::sort(op_names.begin(), op_names.end()); + return op_names; +} +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/tf2xla_opset.h b/tensorflow/compiler/tf2xla/tf2xla_opset.h new file mode 100644 index 00000000000..37fa8f3940f --- /dev/null +++ b/tensorflow/compiler/tf2xla/tf2xla_opset.h @@ -0,0 +1,30 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_TF2XLA_OPSET_H_ +#define TENSORFLOW_COMPILER_TF2XLA_TF2XLA_OPSET_H_ + +#include +#include + +#include "absl/status/statusor.h" + +namespace tensorflow { + +absl::StatusOr> GetRegisteredXlaOpsForDevice( + absl::string_view device_name); + +} // namespace tensorflow +#endif // TENSORFLOW_COMPILER_TF2XLA_TF2XLA_OPSET_H_ diff --git a/tensorflow/compiler/tf2xla/tf2xla_opset_test.cc b/tensorflow/compiler/tf2xla/tf2xla_opset_test.cc new file mode 100644 index 00000000000..f7031e06a4f --- /dev/null +++ b/tensorflow/compiler/tf2xla/tf2xla_opset_test.cc @@ -0,0 +1,62 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/tf2xla_opset.h" + +#include +#include +#include + +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { + +TEST(GeXlaOpsForDeviceTest, InvalidDeviceToRegister) { + absl::StatusOr> result = + GetRegisteredXlaOpsForDevice("Invalid_Device"); + EXPECT_FALSE(result.ok()); +} +TEST(GeXlaOpsForDeviceTest, GetGpuNames) { + absl::StatusOr> result = + GetRegisteredXlaOpsForDevice("XLA_GPU_JIT"); + EXPECT_GT(result.value().size(), 0); + auto matmul = + std::find(result.value().begin(), result.value().end(), "MatMul"); + auto max = std::find(result.value().begin(), result.value().end(), "Max"); + auto min = std::find(result.value().begin(), result.value().end(), "Min"); + EXPECT_TRUE((matmul != result.value().end())); + EXPECT_TRUE((max != result.value().end())); + EXPECT_TRUE((min != result.value().end())); + EXPECT_LT(matmul, max); + EXPECT_LT(max, min); +} +TEST(GeXlaOpsForDeviceTest, GetCpuNames) { + absl::StatusOr> result = + GetRegisteredXlaOpsForDevice("XLA_CPU_JIT"); + EXPECT_GT(result.value().size(), 0); + auto matmul = + std::find(result.value().begin(), result.value().end(), "MatMul"); + auto max = std::find(result.value().begin(), result.value().end(), "Max"); + auto min = std::find(result.value().begin(), result.value().end(), "Min"); + EXPECT_TRUE((matmul != result.value().end())); + EXPECT_TRUE((max != result.value().end())); + EXPECT_TRUE((min != result.value().end())); + EXPECT_LT(matmul, max); + EXPECT_LT(max, min); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/tf2xla_util.cc b/tensorflow/compiler/tf2xla/tf2xla_util.cc index 2840076a3c3..f896f97a462 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_util.cc +++ 
b/tensorflow/compiler/tf2xla/tf2xla_util.cc @@ -28,6 +28,7 @@ limitations under the License. #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/function_body.h" +#include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/graph_def_util.h" #include "tensorflow/core/framework/graph_to_functiondef.h" @@ -256,10 +257,15 @@ Status PropagateConstIntoFuncAttr( FunctionDef replace_fdef; string new_func_name = fld->UniqueFunctionName(absl::StrCat(func_attr.name(), "_const_")); + const StackTracesMap* stack_traces = + lookup_fld->GetStackTraces(func_attr.name()); TF_RETURN_IF_ERROR( GraphToFunctionDef(*func_graph, new_func_name, &replace_fdef)); - TF_RETURN_IF_ERROR(fld->AddFunctionDef( - replace_fdef, lookup_fld->GetStackTraces(func_attr.name()))); + if (stack_traces != nullptr) { + TF_RETURN_IF_ERROR(fld->AddFunctionDef(replace_fdef, *stack_traces)); + } else { + TF_RETURN_IF_ERROR(fld->AddFunctionDef(replace_fdef, {})); + } VLOG(1) << "replace func " << func_attr.name() << " with " << new_func_name; // Change the node to use rewritten function. @@ -267,9 +273,6 @@ Status PropagateConstIntoFuncAttr( n->ClearAttr(attr_name); n->AddAttr(attr_name, func_attr); - TF_RETURN_IF_ERROR(fld->AddFunctionDef( - replace_fdef, lookup_fld->GetStackTraces(func_attr.name()))); - // Copy associated functions. TF_RETURN_IF_ERROR(CopyAssociatedFunctions(func_graph, lookup_fld, fld)); diff --git a/tensorflow/compiler/tf2xla/tf2xla_util_test.cc b/tensorflow/compiler/tf2xla/tf2xla_util_test.cc index 70ca4576be0..d3ba6133243 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_util_test.cc +++ b/tensorflow/compiler/tf2xla/tf2xla_util_test.cc @@ -42,8 +42,8 @@ namespace { void ExpectErrorContains(const Status& status, absl::string_view str) { EXPECT_NE(OkStatus(), status); - EXPECT_TRUE(absl::StrContains(status.error_message(), str)) - << "expected error: " << status.error_message() << " to contain: " << str; + EXPECT_TRUE(absl::StrContains(status.message(), str)) + << "expected error: " << status.message() << " to contain: " << str; } TEST(ValidateConfig, Good) { diff --git a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc index bd57112ccdc..1c24cffa93d 100644 --- a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc +++ b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc @@ -16,6 +16,7 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h" #include +#include #include "tensorflow/compiler/xla/cpu_function_runtime.h" @@ -24,9 +25,12 @@ namespace tensorflow { XlaCompiledCpuFunction::XlaCompiledCpuFunction(const StaticData& static_data, AllocMode alloc_mode) : raw_function_(static_data.raw_function_), + run_function_(static_data.run_function_), + cpu_executable_(static_data.cpu_executable_), result_index_(static_data.result_index_), buffer_table_(new void*[static_data.num_buffers_]), buffer_infos_(static_data.buffer_infos_), + num_buffers_(static_data.num_buffers_), arg_index_table_(static_data.arg_index_table_), num_args_(static_data.num_args_), num_variables_(static_data.num_variables_), @@ -53,12 +57,29 @@ XlaCompiledCpuFunction::XlaCompiledCpuFunction(const StaticData& static_data, } bool XlaCompiledCpuFunction::Run() { + if (run_function_) { + std::vector descriptor_table = + MakeXlaRuntimeDescriptorTable(); + return run_function_(cpu_executable_, descriptor_table, &run_options_); + } XlaCustomCallStatus status; raw_function_(buffer_table_[result_index_], &run_options_, nullptr, buffer_table_, &status, profile_counters_); return !xla::CustomCallStatusGetMessage(&status).has_value(); } +std::vector +XlaCompiledCpuFunction::MakeXlaRuntimeDescriptorTable() { + std::vector descriptor_table; + descriptor_table.reserve(num_buffers_); + for (int32_t i = 0; i < num_buffers_; ++i) { + void* data = buffer_table_[i]; + uint64_t size = buffer_infos_[i].size(); + descriptor_table.emplace_back(data, size); + } + return descriptor_table; +} + XlaCompiledCpuFunction::~XlaCompiledCpuFunction() { xla::cpu_function_runtime::FreeContiguous(alloc_buffer_table_); delete[] buffer_table_; diff --git a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h index 8e707278ed8..176f203e924 100644 --- a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h +++ b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h @@ -18,9 +18,11 @@ limitations under the License. #include #include +#include #include "tensorflow/compiler/xla/cpu_function_runtime.h" #include "tensorflow/compiler/xla/executable_run_options.h" +#include "tensorflow/compiler/xla/service/cpu/buffer_desc.h" #include "tensorflow/compiler/xla/service/custom_call_status_internal.h" #include "tensorflow/core/platform/types.h" @@ -29,6 +31,10 @@ limitations under the License. namespace xla { class ProgramShapeProto; class HloProfilePrinterData; + +namespace cpu { +class CpuExecutable; +} // namespace cpu } // namespace xla namespace tensorflow { @@ -54,6 +60,10 @@ class XlaCompiledCpuFunction { const xla::ExecutableRunOptions* run_options, const void** args, void** temps, XlaCustomCallStatus*, int64_t* profile_counters); + using RunFunction = + bool (*)(const xla::cpu::CpuExecutable* cpu_executable, + const std::vector& descriptor_table, + const xla::ExecutableRunOptions* run_options); // StaticData represents the state necessary to run an XLA-compiled // function. For JIT this is backed by data in XlaJitCompiledCpuFunction; for @@ -66,9 +76,12 @@ class XlaCompiledCpuFunction { // The raw function to call. RawFunction raw_function_; + RunFunction run_function_ = nullptr; + const xla::cpu::CpuExecutable* cpu_executable_ = nullptr; + // Contains information about the buffers used by the XLA computation. 
const xla::cpu_function_runtime::BufferInfo* buffer_infos_ = nullptr; - size_t num_buffers_ = 0; + int32_t num_buffers_ = 0; // Entry parameter i is described by // buffer_infos[arg_index_table[i]]. @@ -278,6 +291,16 @@ class XlaCompiledCpuFunction { static_data->raw_function_ = raw_function; } + static void set_static_data_run_function(StaticData* static_data, + RunFunction run_function) { + static_data->run_function_ = run_function; + } + + static void set_static_data_cpu_executable( + StaticData* static_data, const xla::cpu::CpuExecutable* cpu_executable) { + static_data->cpu_executable_ = cpu_executable; + } + static void set_static_data_buffer_infos( StaticData* static_data, const xla::cpu_function_runtime::BufferInfo* buffer_infos) { @@ -347,6 +370,12 @@ class XlaCompiledCpuFunction { private: const RawFunction raw_function_; + // TODO(ecg): RunFunction and CpuExecutable should go away. Instead, we should + // have a pointer or reference to a minimal wrapper around CpuExecutable's + // Execute(), without CpuExecutable's dependences. We could call this wrapper + // "XlaRuntimeRunner". + const RunFunction run_function_; + const xla::cpu::CpuExecutable* cpu_executable_; const size_t result_index_; // Array containing pointers to argument and temp buffers (slots corresponding @@ -355,6 +384,7 @@ class XlaCompiledCpuFunction { // Describes the buffers used by the XLA computation. const xla::cpu_function_runtime::BufferInfo* const buffer_infos_; + const int32 num_buffers_; // Argument i needs to be placed in buffer_table_[arg_index_to_temp_index_[i]] // for XLA generated code to be able to find it. @@ -383,6 +413,9 @@ class XlaCompiledCpuFunction { const xla::ProgramShapeProto* program_shape_ = nullptr; const xla::HloProfilePrinterData* hlo_profile_printer_data_ = nullptr; + // Creates a descriptor table for XLA Runtime. + std::vector MakeXlaRuntimeDescriptorTable(); + // Add `XlaJitCompiledCpuFunction` as a friend so that it can access the // `set_static_data_*` static methods above. friend class XlaJitCompiledCpuFunction; diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index d8a227c6de3..ef7c45f0a4b 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -15,8 +15,10 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_compiler.h" +#include #include #include +#include #include #include "tensorflow/compiler/mlir/tf2xla/mlir_bridge_rollout_policy.h" @@ -53,6 +55,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/graph_optimizer.h" #include "tensorflow/core/framework/attr_value_util.h" #include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/graph_debug_info.pb.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/graph/node_builder.h" @@ -61,7 +64,6 @@ limitations under the License. 
#include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/protobuf/error_codes.pb.h" -#include "tensorflow/core/protobuf/graph_debug_info.pb.h" #include "tensorflow/core/tpu/tpu_defs.h" #include "tensorflow/core/util/dump_graph.h" @@ -124,6 +126,10 @@ ComputeArgAndRetvalShardings(const Graph& graph) { return std::make_pair(std::move(arg_shardings), std::move(retval_shardings)); } +// Due to the wonkiness with Resource Cleanup, changing how resources are +// cleaned up here need to change how resources are cleaned up in +// graph_compiler_test. +// LINT.IfChange(ExecuteGraph) Status ExecuteGraph(XlaContext* xla_context, std::unique_ptr graph, XlaCompilationDevice* device, FunctionLibraryRuntime* flib, int64_t step_id) { @@ -150,6 +156,7 @@ Status ExecuteGraph(XlaContext* xla_context, std::unique_ptr graph, step_container.reset(); return status; } +// LINT.ThenChange(//tensorflow/compiler/tf2xla/graph_compiler_test.cc) // Builds the XLA computation. // - `args` is the list of input arguments @@ -570,7 +577,7 @@ Status XlaCompiler::FindFunctionBody(const NameAttrList& function, } TF_RETURN_WITH_CONTEXT_IF_ERROR( GetFunctionBody(function, flib_runtime_, fbody), - "Local lookup failed with: ", status.error_message()); + "Local lookup failed with: ", status.message()); if (config_proto) { *config_proto = flib_runtime_->config_proto(); } @@ -896,8 +903,16 @@ Status XlaCompiler::CompileFunction( } } else { VLOG(1) << "MLIR bridge off. Using the old bridge to compile the function"; - TF_RETURN_IF_ERROR( - CompileGraph(options, function_id, std::move(graph), args, result)); + auto status = + CompileGraph(options, function_id, std::move(graph), args, result); + if (!status.ok()) { + ::tsl::errors::AppendToMessage( + &status, "tf2xla conversion failed while converting ", function_id, + ". Run with TF_DUMP_GRAPH_PREFIX=/path/to/dump/dir and " + "--vmodule=xla_compiler=2 to obtain a dump of the compiled " + "functions."); + return status; + } } VLOG(1) << "===================================================="; @@ -1325,7 +1340,7 @@ Status ValidateGraph(const Graph* graph, std::string errmsg = absl::StrCat( "Detected unsupported operations when trying to compile graph ", name, " on ", device_type.type_string(), ": ", node->def().op(), " (", - s.error_message(), ")", FormatNodeForError(*node)); + s.message(), ")", FormatNodeForError(*node)); if (absl::StrContains(device_type.type_string(), "TPU")) { absl::StrAppend(&errmsg, "\nOne approach is to outside compile the unsupported " @@ -1382,12 +1397,43 @@ void ConvertConstantsToExpressions(xla::XlaBuilder* builder, } // namespace +// A temporary dummy stack trace, used to identify locations where stack trace +// info is being lost, and to clarify how stack trace info is otherwise being +// handled in individual passes. This class and its usage below will be removed +// once we have robust end-to-end metadata handling. 
+// TODO(b/265059672): Remove when end-to-end stack trace handling is in place +class DummyStackTrace : public AbstractStackTrace { + absl::Span ToFrames() const override { return frames_; } + + StackFrame LastUserFrame() const override { return frames_.back(); } + + std::vector GetUserFrames(int /*limit*/) const override { + return frames_; + } + + std::string ToString(const TracePrintingOptions& opts) const override { + auto frame = LastUserFrame(); + return absl::StrCat(frame.file_name, ":", frame.line_number, ":", + frame.function_name); + } + + std::vector frames_{ + StackFrame({"dummy_file_name", 10, "dummy_function_name"})}; +}; + Status XlaCompiler::CompileGraph( const XlaCompiler::CompileOptions& options, string const& name, std::unique_ptr graph, absl::Span args, CompilationResult* result) { VLOG(1) << "Executing graph symbolically to populate XlaBuilder.: " << name; + DummyStackTrace stack_trace; + for (auto node : graph->nodes()) { + if (node->GetStackTrace() == nullptr) { + node->SetStackTrace(std::make_shared(stack_trace)); + } + } + TF_RETURN_IF_ERROR(PropagateConstIntoFunctionalNodes( graph.get(), options_.flib_def, local_flib_def_.get())); TF_RETURN_IF_ERROR(RearrangeFunctionArguments( @@ -1635,7 +1681,7 @@ Status XlaCompiler::GetHostComputeControlDependency( } Status XlaCompiler::SetHostComputeControlDependency( - const string& host_compute_name, const xla::XlaOp& handle) { + const string& host_compute_name, const xla::XlaOp handle) { if (host_compute_control_output_.find(host_compute_name) != host_compute_control_output_.end()) { return errors::InvalidArgument( @@ -1660,8 +1706,7 @@ Status XlaCompiler::PopNodeTokenMapping() { return OkStatus(); } -Status XlaCompiler::SetNodeToken(const string& node_name, - const xla::XlaOp& op) { +Status XlaCompiler::SetNodeToken(const string& node_name, const xla::XlaOp op) { if (node_token_mapping_stack_.empty()) { return errors::FailedPrecondition( "Calling SetNodeToken() when node_token_mapping_stack_ is " diff --git a/tensorflow/compiler/tf2xla/xla_compiler.h b/tensorflow/compiler/tf2xla/xla_compiler.h index d027326239e..a90d705b2b1 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.h +++ b/tensorflow/compiler/tf2xla/xla_compiler.h @@ -301,7 +301,7 @@ class XlaCompiler { Status GetHostComputeControlDependency(const string& host_compute_name, xla::XlaOp* handle); Status SetHostComputeControlDependency(const string& host_compute_name, - const xla::XlaOp& handle); + xla::XlaOp handle); const Options& options() const { return options_; } xla::Client* client() const { return options_.client; } @@ -309,7 +309,7 @@ class XlaCompiler { void PushNodeTokenMapping(); Status PopNodeTokenMapping(); - Status SetNodeToken(const string& node_name, const xla::XlaOp& op); + Status SetNodeToken(const string& node_name, xla::XlaOp op); StatusOr GetNodeToken(const string& node_name); // Sets the function body `fbody` to the one registered as `function`. 
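The CompileFunction change above illustrates a pattern used throughout this patch: when compilation fails, append actionable context (which function was being converted, how to obtain a graph dump) to the Status before propagating it. A rough sketch of that pattern against a generic absl::Status; the AddDebugHint helper is hypothetical, and rebuilding the Status this way drops any payloads:

#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"

// Hypothetical helper: pass an OK status through unchanged, otherwise return a
// status with the same code and the extra hint appended to the message.
absl::Status AddDebugHint(const absl::Status& s, absl::string_view hint) {
  if (s.ok()) return s;
  return absl::Status(s.code(), absl::StrCat(s.message(), "; ", hint));
}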
diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc index 5231d9e4246..ea34895b76b 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc @@ -572,14 +572,13 @@ TEST_F(XlaCompilerTest, HasSaneErrorOnNonCompileTimeConstantInputToReshape) { compiler.CompileGraph(XlaCompiler::CompileOptions(), "reshape", std::move(graph), args, &result); EXPECT_FALSE(status.ok()); + EXPECT_TRUE(absl::StrContains(status.message(), "depends on a parameter")) + << status.message(); + EXPECT_TRUE(absl::StrContains(status.message(), "{{node C}}")) + << status.message(); EXPECT_TRUE( - absl::StrContains(status.error_message(), "depends on a parameter")) - << status.error_message(); - EXPECT_TRUE(absl::StrContains(status.error_message(), "{{node C}}")) - << status.error_message(); - EXPECT_TRUE(absl::StrContains(status.error_message(), - "must be a compile-time constant")) - << status.error_message(); + absl::StrContains(status.message(), "must be a compile-time constant")) + << status.message(); } // Tests handling of compile-time constant outputs. @@ -943,8 +942,8 @@ TEST_F(XlaCompilerTest, UndefinedFunctionFails) { compiler.CompileFunction(XlaCompiler::CompileOptions(), name_attr, /*args=*/{}, &result); EXPECT_FALSE(status.ok()); - EXPECT_TRUE(absl::StrContains(status.error_message(), "is not defined.")) - << status.error_message(); + EXPECT_TRUE(absl::StrContains(status.message(), "is not defined.")) + << status.message(); } FunctionDef FillFn() { @@ -1022,11 +1021,11 @@ TEST_F(XlaCompilerTest, LocalFunctionWithWrongArgumentsFail) { ASSERT_FALSE(status.ok()); // Flib lookup failure. - EXPECT_TRUE(absl::StrContains(status.error_message(), "is not defined.")) - << status.error_message(); + EXPECT_TRUE(absl::StrContains(status.message(), "is not defined.")) + << status.message(); // Local flib lookup failure. - EXPECT_TRUE(absl::StrContains(status.error_message(), "Attr T is not found")) - << status.error_message(); + EXPECT_TRUE(absl::StrContains(status.message(), "Attr T is not found")) + << status.message(); } FunctionDef SliceFn() { @@ -1521,10 +1520,10 @@ TEST_F(XlaCompilerTest, FunctionWithInvalidOp) { status = compiler.CompileGraph(XlaCompiler::CompileOptions(), "fill", std::move(graph), args, &result); ASSERT_FALSE(status.ok()); - EXPECT_TRUE(absl::StrContains(status.error_message(), "InvalidOp")) - << status.error_message(); - EXPECT_TRUE(absl::StrContains(status.error_message(), "{{node fill_fn}}")) - << status.error_message(); + EXPECT_TRUE(absl::StrContains(status.message(), "InvalidOp")) + << status.message(); + EXPECT_TRUE(absl::StrContains(status.message(), "{{node fill_fn}}")) + << status.message(); } // Tests a graph which has a node with invalid data type. 
@@ -1546,11 +1545,11 @@ TEST_F(XlaCompilerTest, NodeWithInvalidDataType) { status = compiler.CompileGraph(XlaCompiler::CompileOptions(), "invalid_type", std::move(graph), args, &result); ASSERT_FALSE(status.ok()); - EXPECT_TRUE(absl::StrContains(status.error_message(), + EXPECT_TRUE(absl::StrContains(status.message(), "is not in the list of allowed values")) - << status.error_message(); - EXPECT_TRUE(absl::StrContains(status.error_message(), "{{node Shape}}")) - << status.error_message(); + << status.message(); + EXPECT_TRUE(absl::StrContains(status.message(), "{{node Shape}}")) + << status.message(); } TEST_F(XlaCompilerTest, SingleOpWithoutInputs) { diff --git a/tensorflow/compiler/tf2xla/xla_helpers.cc b/tensorflow/compiler/tf2xla/xla_helpers.cc index f52e83c8c63..c936d6d7962 100644 --- a/tensorflow/compiler/tf2xla/xla_helpers.cc +++ b/tensorflow/compiler/tf2xla/xla_helpers.cc @@ -90,8 +90,8 @@ xla::XlaOp XlaHelpers::FloatLiteral(xla::XlaBuilder* b, DataType data_type, Status XlaHelpers::OneHot(xla::XlaBuilder* builder, int64_t depth, int axis, DataType index_type, const TensorShape& indices_shape, - const xla::XlaOp& indices, const xla::XlaOp& on_value, - const xla::XlaOp& off_value, xla::XlaOp* one_hot) { + const xla::XlaOp indices, const xla::XlaOp on_value, + const xla::XlaOp off_value, xla::XlaOp* one_hot) { // Broadcast the linspace constant across the indices along the new axis, // and test equality at each position. std::vector broadcast_dims(indices_shape.dims()); @@ -128,7 +128,7 @@ DataType XlaHelpers::SumAccumulationType(const DataType& dtype) { return dtype; } -xla::XlaOp XlaHelpers::ConvertElementType(const xla::XlaOp& operand, +xla::XlaOp XlaHelpers::ConvertElementType(const xla::XlaOp operand, const DataType new_element_type) { xla::PrimitiveType convert_to; TF_CHECK_OK(DataTypeToPrimitiveType(new_element_type, &convert_to)); diff --git a/tensorflow/compiler/tf2xla/xla_helpers.h b/tensorflow/compiler/tf2xla/xla_helpers.h index 0e621995cbc..f2551774f1c 100644 --- a/tensorflow/compiler/tf2xla/xla_helpers.h +++ b/tensorflow/compiler/tf2xla/xla_helpers.h @@ -68,8 +68,8 @@ class XlaHelpers { // respectively. static Status OneHot(xla::XlaBuilder* builder, int64_t depth, int axis, DataType index_type, const TensorShape& indices_shape, - const xla::XlaOp& indices, const xla::XlaOp& on_value, - const xla::XlaOp& off_value, xla::XlaOp* one_hot); + xla::XlaOp indices, xla::XlaOp on_value, + xla::XlaOp off_value, xla::XlaOp* one_hot); // Certain DataTypes should use increased precision DataTypes when performing // reductions. This function remaps a given DataType to a higher precision @@ -78,7 +78,7 @@ class XlaHelpers { // A helper for creating a ConvertElementType xla op given a DataType rather // than the xla::PrimitiveType. 
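A note on the signature changes in xla_helpers and xla_compiler above: xla::XlaOp is a small handle (roughly an op id plus a non-owning builder pointer), so the patch passes it by value instead of by const reference. A toy illustration of why this is a wash performance-wise, using a made-up handle type:

#include <cstdint>

// Made-up handle of about the same size as xla::XlaOp: copying it moves two
// machine words, so pass-by-value is as cheap as pass-by-reference and avoids
// an extra indirection as well as dangling-reference questions.
class OpHandle {
 public:
  OpHandle(int64_t id, const void* builder) : id_(id), builder_(builder) {}
  int64_t id() const { return id_; }

 private:
  int64_t id_ = -1;
  const void* builder_ = nullptr;
};

// Taking the handle by value, matching the new signatures above.
int64_t IdOf(OpHandle op) { return op.id(); }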
- static xla::XlaOp ConvertElementType(const xla::XlaOp& operand, + static xla::XlaOp ConvertElementType(xla::XlaOp operand, const DataType new_element_type); typedef std::function(const TensorShape&, DataType, bool, diff --git a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc index c9d17abe2a7..5c54551707b 100644 --- a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc +++ b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc @@ -79,6 +79,15 @@ void CollectNames(const T& entries, std::vector* nonempty_names, name_ptrs->push_back(nullptr); // array terminator } +bool RunXlaRuntime(const xla::cpu::CpuExecutable* cpu_executable, + const std::vector& descriptor_table, + const xla::ExecutableRunOptions* run_options) { + assert(cpu_executable->IsXlaRuntime()); + Status status = + cpu_executable->ExecuteXlaRuntime(descriptor_table, run_options); + return status.ok(); +} + } // namespace /*static*/ StatusOr> @@ -147,6 +156,12 @@ XlaJitCompiledCpuFunction::Compile( std::make_unique(program_shape->ToProto()); XlaCompiledCpuFunction::set_static_data_raw_function(&jit->static_data_, raw_function); + if (cpu_executable->IsXlaRuntime()) { + XlaCompiledCpuFunction::set_static_data_run_function(&jit->static_data_, + RunXlaRuntime); + XlaCompiledCpuFunction::set_static_data_cpu_executable(&jit->static_data_, + cpu_executable); + } XlaCompiledCpuFunction::set_static_data_buffer_infos( &jit->static_data_, jit->buffer_infos_.data()); XlaCompiledCpuFunction::set_static_data_num_buffers( diff --git a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function_test.cc b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function_test.cc index 6f45dcf1726..f1b838ab882 100644 --- a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function_test.cc +++ b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function_test.cc @@ -331,7 +331,7 @@ TEST(XlaJitCompiledCpuFunction, CanCompileWithAdditionalPlatform) { return std::unique_ptr(nullptr); }); - EXPECT_THAT(xla::PlatformUtil::GetDefaultPlatform().status().error_message(), + EXPECT_THAT(xla::PlatformUtil::GetDefaultPlatform().status().message(), HasSubstr("FakePlatform")); GraphDef graph_def = SumGraph(); diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc index 0f7373659bd..54996de9c24 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc @@ -199,18 +199,20 @@ Status XlaOpKernelContext::ConstantInputReshaped( return OkStatus(); } -// Converts an int32 or int64 scalar literal to an int64. +// Converts an int16, int32 or int64 scalar literal to an int64. 
static Status LiteralToInt64Scalar(const xla::LiteralSlice& literal, int64_t* out) { if (literal.shape().rank() != 0) { return errors::InvalidArgument("value is not a scalar"); } - if (literal.shape().element_type() == xla::S32) { + if (literal.shape().element_type() == xla::S16) { + *out = literal.Get({}); + } else if (literal.shape().element_type() == xla::S32) { *out = literal.Get({}); } else if (literal.shape().element_type() == xla::S64) { *out = literal.Get({}); } else { - return errors::InvalidArgument("value must be either int32 or int64"); + return errors::InvalidArgument("value must be int16, int32, or int64"); } return OkStatus(); } @@ -754,8 +756,7 @@ Status XlaOpKernelContext::AssignVariable(absl::string_view name, DataType type, static Status GetStatusWithStackTrace(const Status& s, const XlaOpKernelContext* ctx) { if (s.code() == error::INVALID_ARGUMENT) { - return Status{s.code(), - absl::StrCat(s.error_message(), "\n", ctx->StackTrace())}; + return Status{s.code(), absl::StrCat(s.message(), "\n", ctx->StackTrace())}; } return s; } diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.cc b/tensorflow/compiler/tf2xla/xla_op_registry.cc index 7f1b5dbd1b9..0bdb03cd76b 100644 --- a/tensorflow/compiler/tf2xla/xla_op_registry.cc +++ b/tensorflow/compiler/tf2xla/xla_op_registry.cc @@ -222,7 +222,7 @@ void XlaOpRegistry::RegisterCompilationKernels() { const OpDef* op_def; Status lookup_status = op_registry->LookUpOpDef(op_name, &op_def); if (!lookup_status.ok()) { - LOG(ERROR) << lookup_status.error_message(); + LOG(ERROR) << lookup_status.message(); XLA_LOG_LINES( ERROR, "Ops registered: \n" + diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD index 347c81c10d2..f29391811ed 100644 --- a/tensorflow/compiler/xla/BUILD +++ b/tensorflow/compiler/xla/BUILD @@ -29,6 +29,7 @@ package_group( "//third_party/mlir_edge/model_curriculum/...", "//third_party/py/jax/...", "//third_party/py/t5x/...", + "//third_party/py/tpu_graphs/...", "//tensorflow/compiler/...", "//tensorflow/python/tpu/...", ], @@ -206,6 +207,7 @@ cc_library( visibility = [":friends"], deps = [ "//third_party/eigen3", + "@com_google_absl//absl/strings:str_format", ], ) @@ -770,6 +772,10 @@ cc_library( hdrs = ["executable_run_options.h"], compatible_with = get_compatible_with_portable(), visibility = ["//visibility:public"], + deps = [ + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status:statusor", + ], ) cc_library( @@ -1177,7 +1183,9 @@ filegroup( "runlit.cfg.py", "runlit.site.cfg.py", ], - visibility = ["//tensorflow/compiler/xla:__subpackages__"], + visibility = [ + "//tensorflow/compiler/xla:__subpackages__", # Scheuklappen: keep + ], ) # ----------------------------------------------------------------------------- diff --git a/tensorflow/compiler/xla/array.h b/tensorflow/compiler/xla/array.h index bdfc8d687e6..3238ffdf53d 100644 --- a/tensorflow/compiler/xla/array.h +++ b/tensorflow/compiler/xla/array.h @@ -18,7 +18,8 @@ limitations under the License. #include #include -#include +#include +#include #include #include #include @@ -26,15 +27,12 @@ limitations under the License. 
#include #include #include -#include #include "absl/functional/function_ref.h" #include "absl/strings/str_cat.h" -#include "absl/strings/str_join.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/tsl/platform/logging.h" namespace xla { @@ -69,29 +67,37 @@ class Array { // nests, especially if one or more dimensions is one as the compiler just // sees a single-element integer initializer. These typedefs allow casting // explicitly with less typing. - using InitializerList1D = std::initializer_list; - using InitializerList2D = std::initializer_list; - using InitializerList3D = std::initializer_list; - using InitializerList4D = std::initializer_list; + template + using InitializerList1D = std::initializer_list; + template + using InitializerList2D = std::initializer_list>; + template + using InitializerList3D = std::initializer_list>; + template + using InitializerList4D = std::initializer_list>; using value_type = T; // Creates a new array with the specified dimensions and initialized elements. explicit Array(absl::Span sizes) - : sizes_(sizes.begin(), sizes.end()), values_(new T[num_elements()]()) {} + : sizes_(sizes.size()), + values_(calculate_elements(sizes), default_init_t{}) { + std::memcpy(sizes_.data.get(), sizes.data(), + sizeof(int64_t) * sizes.size()); + } // Creates a new array with the specified dimensions and specified value for // every cell. Array(absl::Span sizes, T value) - : sizes_(sizes.begin(), sizes.end()), values_(new T[num_elements()]) { + : Array(sizes, no_default_init_t{}) { Fill(value); } // Creates a 2D array from the given nested initializer list. The outer // initializer list is the first dimension, the inner is the second dimension. // For example, {{1, 2, 3}, {4, 5, 6}} results in an array with n1=2 and n2=3. - Array(InitializerList2D values) - : Array(ToInt64Vector({values.size(), values.begin()->size()})) { + Array(InitializerList2D values) + : Array(ToInt64Array(values), no_default_init_t{}) { int64_t idx = 0; for (const auto& it1 : values) { for (const auto& it2 : it1) { @@ -111,7 +117,7 @@ class Array { std::is_same::value) && std::is_same::value>::type> Array(std::initializer_list values) - : Array(ToInt64Vector({values.size()})) { + : Array(ToInt64Array(values), no_default_init_t{}) { int64_t idx = 0; for (const auto& it1 : values) { values_[idx] = static_cast(it1); @@ -131,7 +137,7 @@ class Array { std::is_same::value) && std::is_same::value>::type> Array(std::initializer_list> values) - : Array(ToInt64Vector({values.size(), values.begin()->size()})) { + : Array(ToInt64Array(values), no_default_init_t{}) { int64_t idx = 0; for (const auto& it1 : values) { for (const auto& it2 : it1) { @@ -144,9 +150,8 @@ class Array { // Creates a 3D array from the given nested initializer list. The outer // initializer list is the first dimension, and so on. 
- Array(InitializerList3D values) - : Array(ToInt64Vector({values.size(), values.begin()->size(), - values.begin()->begin()->size()})) { + Array(InitializerList3D values) + : Array(ToInt64Array(values), no_default_init_t{}) { int64_t idx = 0; for (const auto& it1 : values) { for (const auto& it2 : it1) { @@ -169,8 +174,7 @@ class Array { std::is_same::value>::type> Array(std::initializer_list>> values) - : Array(ToInt64Vector({values.size(), values.begin()->size(), - values.begin()->begin()->size()})) { + : Array(ToInt64Array(values), no_default_init_t{}) { int64_t idx = 0; for (const auto& it1 : values) { for (const auto& it2 : it1) { @@ -185,10 +189,8 @@ class Array { // Creates a 4D array from the given nested initializer list. The outer // initializer list is the first dimension, and so on. - Array(InitializerList4D values) - : Array(ToInt64Vector({values.size(), values.begin()->size(), - values.begin()->begin()->size(), - values.begin()->begin()->begin()->size()})) { + Array(InitializerList4D values) + : Array(ToInt64Array(values), no_default_init_t{}) { int64_t idx = 0; for (const auto& it1 : values) { for (const auto& it2 : it1) { @@ -214,9 +216,7 @@ class Array { Array(std::initializer_list< std::initializer_list>>> values) - : Array(ToInt64Vector({values.size(), values.begin()->size(), - values.begin()->begin()->size(), - values.begin()->begin()->begin()->size()})) { + : Array(ToInt64Array(values), no_default_init_t{}) { int64_t idx = 0; for (const auto& it1 : values) { for (const auto& it2 : it1) { @@ -232,43 +232,29 @@ class Array { } Array(const Array& other) - : sizes_(other.sizes_), values_(new T[num_elements()]) { - std::copy(&other.values_[0], &other.values_[0] + num_elements(), - &values_[0]); - } + : sizes_(other.sizes_.Clone()), values_(other.values_.Clone()) {} - Array(Array&& other) - : sizes_(std::move(other.sizes_)), values_(std::move(other.values_)) {} + Array(Array&& other) = default; Array& operator=(const Array& other) { - sizes_ = other.sizes_; - values_.reset(new T[num_elements()]); - std::copy(&other.values_[0], &other.values_[0] + num_elements(), - &values_[0]); + sizes_ = other.sizes_.Clone(); + values_ = other.values_.Clone(); return *this; } - Array& operator=(Array&& other) { - sizes_ = std::move(other.sizes_); - values_ = std::move(other.values_); - return *this; - } + Array& operator=(Array&& other) = default; // Fills the array with the specified value. - void Fill(const T& value) { - std::fill(&values_[0], &values_[0] + num_elements(), value); - } + void Fill(const T& value) { std::fill(begin(), end(), value); } // Fills the array with sequentially increasing values. - void FillIota(const T& value) { - std::iota(&values_[0], &values_[0] + num_elements(), value); - } + void FillIota(const T& value) { std::iota(begin(), end(), value); } // Fills the array with a repeating sequence: // [value, value + 1, ..., value + length - 1, value, ... 
] void FillRepeatedIota(const T& value, int64_t length) { for (int64_t i = 0; i < num_elements(); i += length) { - std::iota(&values_[i], &values_[std::min(i + length, num_elements())], + std::iota(begin() + i, begin() + std::min(i + length, num_elements()), value); } } @@ -324,23 +310,23 @@ class Array { void SetValues(const Container& container) { CHECK_EQ(std::distance(std::begin(container), std::end(container)), num_elements()); - std::copy(std::begin(container), std::end(container), &values_[0]); + std::copy(std::begin(container), std::end(container), begin()); } // Invokes a callback with the (indices, value_ptr) for each cell in the // array. void Each(absl::FunctionRef, T*)> f) { - std::vector index(sizes_.size()); + OwnedBuffer index(sizes_.size, default_init_t{}); for (int64_t i = 0; i < num_elements(); ++i, next_index(&index)) { - f(index, &values_[i]); + f(index.span(), &values_[i]); } } // Invokes a callback with the (indices, value) for each cell in the array. void Each(absl::FunctionRef, T)> f) const { - std::vector index(sizes_.size()); + OwnedBuffer index(sizes_.size, default_init_t{}); for (int64_t i = 0; i < num_elements(); ++i, next_index(&index)) { - f(index, values_[i]); + f(index.span(), values_[i]); } } @@ -349,9 +335,9 @@ class Array { // OkStatus(). Status EachStatus( absl::FunctionRef, T*)> f) { - std::vector index(sizes_.size()); + OwnedBuffer index(sizes_.size, default_init_t{}); for (int64_t i = 0; i < num_elements(); ++i, next_index(&index)) { - Status s = f(index, &values_[i]); + Status s = f(index.span(), &values_[i]); if (!s.ok()) { return s; } @@ -364,9 +350,9 @@ class Array { // OkStatus(). Status EachStatus( absl::FunctionRef, T)> f) const { - std::vector index(sizes_.size()); + OwnedBuffer index(sizes_.size, default_init_t{}); for (int64_t i = 0; i < num_elements(); ++i, next_index(&index)) { - Status s = f(index, values_[i]); + Status s = f(index.span(), values_[i]); if (!s.ok()) { return s; } @@ -384,6 +370,7 @@ class Array { typename std::enable_if::value, const T&>::type operator()(Dims... dims) const { + CHECK_EQ(sizeof...(dims), num_dimensions()); // We are using a std::array to avoid having to allocate memory in this // function for performance reasons. std::array indexes{ @@ -397,23 +384,21 @@ class Array { typename std::enable_if::value, T&>::type operator()(Dims... dims) { - // We are using a std::array to avoid having to allocate memory in this - // function for performance reasons. - std::array indexes{ - {static_cast(dims)...}}; - return values_[calculate_index(indexes)]; + return const_cast(const_cast(this)->operator()( + std::forward(dims)...)); } // Returns the value at the cell specified by the indexes. The number of // arguments have to match with the number of dimensions for the array. const T& operator()(absl::Span indexes) const { + CHECK_EQ(indexes.size(), num_dimensions()); return values_[calculate_index(indexes)]; } // Returns the value at the cell specified by the indexes. The number of // arguments have to match with the number of dimensions for the array. T& operator()(absl::Span indexes) { - return values_[calculate_index(indexes)]; + return const_cast(const_cast(this)->operator()(indexes)); } // Low-level accessor for stuff like memcmp, handle with care. Returns pointer @@ -422,37 +407,33 @@ class Array { // TODO(tberghammer): Get rid of the const_cast. Currently it is needed // because the Eigen backend needs a non-const pointers even for reading // from the array. 
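As a usage-level recap of the Array surface touched above (the nested initializer-list constructors, Each, and the element accessors), a small self-contained example; only the public xla::Array API is assumed:

#include <cstdint>
#include <iostream>

#include "absl/types/span.h"
#include "tensorflow/compiler/xla/array.h"

int main() {
  // 2x3 array built from a nested initializer list.
  xla::Array<int> a({{1, 2, 3}, {4, 5, 6}});
  std::cout << a.num_dimensions() << " dims, " << a.num_elements()
            << " elements\n";

  // Visit every cell; `indices` holds one index per dimension.
  a.Each([](absl::Span<const int64_t> indices, int* value) {
    *value += static_cast<int>(indices[0]);  // add the row index to each cell
  });
  std::cout << a.ToString() << "\n";
  return 0;
}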
- return const_cast(this)->values_.get(); + return const_cast(this)->values_.data.get(); } // Returns the size of the dimension at the given index. int64_t dim(int64_t n) const { - const int64_t sizes_size = sizes_.size(); - CHECK(n < sizes_size); + DCHECK_LT(n, sizes_.size); return sizes_[n]; } // Returns a vector containing the dimensions of the array. - const std::vector& dimensions() const { return sizes_; } + absl::Span dimensions() const { return sizes_.span(); } - int64_t num_dimensions() const { return sizes_.size(); } + int64_t num_dimensions() const { return sizes_.size; } // Returns the total number of elements in the array. - int64_t num_elements() const { - return std::accumulate(sizes_.begin(), sizes_.end(), 1LL, - std::multiplies()); - } + int64_t num_elements() const { return values_.size; } - const T* begin() const { return &values_[0]; } - T* begin() { return &values_[0]; } - const T* end() const { return &values_[num_elements()]; } - T* end() { return &values_[num_elements()]; } + const T* begin() const { return values_.data.get(); } + T* begin() { return values_.data.get(); } + const T* end() const { return values_.data.get() + num_elements(); } + T* end() { return values_.data.get() + num_elements(); } bool operator==(const Array& other) const { - if (sizes_.size() != other.sizes_.size()) { + if (sizes_.size != other.sizes_.size) { return false; } - for (int64_t i = 0, end = sizes_.size(); i < end; ++i) { + for (int64_t i = 0, end = sizes_.size; i < end; ++i) { if (sizes_[i] != other.sizes_[i]) { return false; } @@ -473,16 +454,16 @@ class Array { CHECK_EQ(starts.size(), num_dimensions()); CHECK_EQ(limits.size(), num_dimensions()); - std::vector sizes; - std::transform(starts.begin(), starts.end(), limits.begin(), - std::back_inserter(sizes), - [](int64_t start, int64_t limit) { return limit - start; }); - Array result(sizes); + OwnedBuffer sizes(starts.size()); + for (int64_t i = 0; i < starts.size(); ++i) { + sizes[i] = limits[i] - starts[i]; + } + Array result(sizes.span()); - std::vector index(sizes_.size()); + OwnedBuffer index(sizes_.size, default_init_t{}); int64_t slice_i = 0; for (int64_t i = 0; i < num_elements(); ++i, next_index(&index)) { - if (array_impl::all_inside_range(index, starts, limits)) { + if (array_impl::all_inside_range(index.span(), starts, limits)) { // Even though the bounds of result are different to our bounds, we're // iterating in the same order. So we can simply write successive linear // indices instead of recalculating a multi-dimensional index. @@ -496,14 +477,15 @@ class Array { void UpdateSlice(const Array& from, absl::Span start_indices) { CHECK_EQ(from.num_dimensions(), num_dimensions()); - std::vector limit_indices; - std::transform(start_indices.begin(), start_indices.end(), - from.dimensions().begin(), std::back_inserter(limit_indices), - std::plus{}); - std::vector index(sizes_.size()); + OwnedBuffer limit_indices(start_indices.size()); + for (int64_t i = 0; i < start_indices.size(); ++i) { + limit_indices[i] = from.sizes_[i] + start_indices[i]; + } + OwnedBuffer index(sizes_.size, default_init_t{}); int64_t from_i = 0; for (int64_t i = 0; i < num_elements(); ++i, next_index(&index)) { - if (array_impl::all_inside_range(index, start_indices, limit_indices)) { + if (array_impl::all_inside_range(index.span(), start_indices, + limit_indices)) { // Even though the bounds of from are different to our bounds, we're // iterating in the same order. 
So we can simply write successive linear // indices instead of recalculating a multi-dimensional index. @@ -515,86 +497,167 @@ class Array { // Performs an in-place reshape, modifying the dimensions but not the // underlying data. void Reshape(absl::Span new_dimensions) { - int64_t old_num_elements = num_elements(); - sizes_ = std::vector(new_dimensions.begin(), new_dimensions.end()); - CHECK_EQ(num_elements(), old_num_elements); + const int64_t new_num_elements = + std::accumulate(new_dimensions.begin(), new_dimensions.end(), 1LL, + std::multiplies()); + CHECK_EQ(new_num_elements, num_elements()); + if (sizes_.size != new_dimensions.size()) { + sizes_ = OwnedBuffer(new_dimensions.size()); + } + std::memcpy(sizes_.data.get(), new_dimensions.data(), + new_dimensions.size() * sizeof(int64_t)); } // Performs a permutation of dimensions. void TransposeDimensions(absl::Span permutation) { - std::vector permuted_dims(permutation.size()); + CHECK_EQ(sizes_.size, permutation.size()); + OwnedBuffer permuted_dims(permutation.size()); for (int64_t i = 0; i < permutation.size(); ++i) { permuted_dims[i] = this->dim(permutation[i]); } - Array permuted(permuted_dims); - std::vector src_indices(sizes_.size(), -1); + Array permuted(permuted_dims.span()); + OwnedBuffer src_indices(sizes_.size, -1); permuted.Each([&](absl::Span indices, T* value) { - CHECK_EQ(sizes_.size(), indices.size()); - for (int64_t i = 0; i < sizes_.size(); ++i) { + for (int64_t i = 0; i < sizes_.size; ++i) { src_indices[permutation[i]] = indices[i]; } - *value = (*this)(src_indices); + *value = (*this)(src_indices.span()); }); *this = std::move(permuted); } template friend H AbslHashValue(H h, const Array& array) { - return H::combine(std::move(h), absl::MakeSpan(array.begin(), array.end()), - array.dimensions()); + return H::combine(std::move(h), array.values_.span(), array.dimensions()); } // Returns a string representation of the array suitable for debugging. std::string ToString() const { - if (sizes_.empty()) { + if (sizes_.size == 0) { return ""; } - std::vector pieces; - std::vector index(sizes_.size()); + std::string result; + OwnedBuffer index(sizes_.size, default_init_t{}); do { // Emit leading spaces and opening square brackets - if (index.back() == 0) { - for (int64_t i = sizes_.size() - 1; i >= 0; --i) { + if (index[index.size - 1] == 0) { + for (int64_t i = sizes_.size - 1; i >= 0; --i) { if (i == 0 || index[i - 1] != 0) { - for (int64_t j = 0; j < sizes_.size(); ++j) { - pieces.push_back(j < i ? " " : "["); + for (int64_t j = 0; j < sizes_.size; ++j) { + absl::StrAppend(&result, j < i ? 
" " : "["); } break; } } } - int value_index = calculate_index(index); + int value_index = calculate_index(index.span()); if (value_index < num_elements()) { - pieces.push_back(absl::StrCat(values_[value_index])); + absl::StrAppend(&result, values_[value_index]); } // Emit comma if it isn't the last element - if (index.back() < sizes_.back() - 1) { - pieces.push_back(", "); + if (index[index.size - 1] < sizes_[sizes_.size - 1] - 1) { + absl::StrAppend(&result, ", "); } // Emit closing square brackets - for (int64_t i = sizes_.size() - 1; i >= 0; --i) { + for (int64_t i = sizes_.size - 1; i >= 0; --i) { if (index[i] < sizes_[i] - 1) { break; } - pieces.push_back("]"); + absl::StrAppend(&result, "]"); if (i != 0 && index[i - 1] < sizes_[i - 1] - 1) { - pieces.push_back(",\n"); + absl::StrAppend(&result, ",\n"); } } } while (next_index(&index)); - return absl::StrJoin(pieces, ""); + return result; } private: - // Converts an initializer_list of type U to a vector of type int64_t. Used by - // the initializer list based constructors to convert the size type into - // int64_t to be passed to the size based constructor. - template - static std::vector ToInt64Vector( - const std::initializer_list& data) { - return std::vector(data.begin(), data.end()); + struct default_init_t {}; + struct no_default_init_t {}; + // A fixed sized dynamically allocated buffer to replace std::vector usage. It + // saves one word for storing capacity which is always the same as size and it + // provides the ability to leave its elements uninitialized if the element + // type is trivially destructible. + template + struct OwnedBuffer { + explicit OwnedBuffer(size_t size) + : data(std::is_trivially_destructible_v ? new D[size] + : new D[size]()), + size(size) {} + explicit OwnedBuffer(size_t size, default_init_t) + : data(new D[size]()), size(size) {} + + explicit OwnedBuffer(size_t size, D init) : OwnedBuffer(size) { + std::fill(data.get(), data.get() + size, init); + } + + OwnedBuffer(OwnedBuffer&& other) + : data(std::move(other.data)), size(other.size) { + other.size = 0; + } + + OwnedBuffer& operator=(OwnedBuffer&& other) { + data = std::move(other.data); + size = other.size; + other.size = 0; + return *this; + } + + OwnedBuffer Clone() const { + OwnedBuffer clone(size); + std::memcpy(clone.data.get(), data.get(), size * sizeof(D)); + return clone; + } + + D& operator[](int64_t index) { return data[index]; } + const D& operator[](int64_t index) const { return data[index]; } + + absl::Span span() const { + return absl::MakeConstSpan(data.get(), size); + } + + std::unique_ptr data; + size_t size; + }; + + explicit Array(absl::Span sizes, no_default_init_t) + : sizes_(sizes.size()), values_(calculate_elements(sizes)) { + std::memcpy(sizes_.data.get(), sizes.data(), + sizeof(int64_t) * sizes.size()); + } + + // Extracts the dimensions of an initializer_list to an array type int64_t. + // Used by the initializer list based constructors to convert the size type + // into int64_t to be passed to the size based constructor. 
+ template + static std::array ToInt64Array(const InitializerList1D& data) { + return std::array{static_cast(data.size())}; + } + + template + static std::array ToInt64Array(const InitializerList2D& data) { + return std::array{static_cast(data.size()), + static_cast(data.begin()->size())}; + } + + template + static std::array ToInt64Array(const InitializerList3D& data) { + return std::array{ + static_cast(data.size()), + static_cast(data.begin()->size()), + static_cast(data.begin()->begin()->size())}; + } + + template + static std::array ToInt64Array(const InitializerList4D& data) { + return std::array{ + static_cast(data.size()), + static_cast(data.begin()->size()), + static_cast(data.begin()->begin()->size()), + static_cast(data.begin()->begin()->begin()->size())}; } // Returns the linear index from the list of per-dimension indexes. Function @@ -602,11 +665,10 @@ class Array { // memory allocation. // The returned value may be larger than or equal to the number of elements if // the indexes exceed the array's corresponding dimension size. - template - int64_t calculate_index(const U& indexes) const { - CHECK_EQ(sizes_.size(), indexes.size()); + int64_t calculate_index(absl::Span indexes) const { + DCHECK_EQ(sizes_.size, indexes.size()); int64_t index = 0; - for (int64_t i = 0; i < sizes_.size(); ++i) { + for (int64_t i = 0; i < sizes_.size; ++i) { index *= sizes_[i]; index += indexes[i]; } @@ -615,9 +677,9 @@ class Array { // Advances the specified set of indexes and returns true if we haven't // wrapped around (i.e. result isn't {0, 0, ...}). - bool next_index(std::vector* index) const { - CHECK_EQ(index->size(), sizes_.size()); - for (int64_t i = sizes_.size() - 1; i >= 0; --i) { + bool next_index(OwnedBuffer* index) const { + DCHECK_EQ(index->size, sizes_.size); + for (int64_t i = sizes_.size - 1; i >= 0; --i) { (*index)[i]++; if ((*index)[i] < sizes_[i]) { return true; @@ -627,15 +689,20 @@ class Array { return false; } - std::vector sizes_; - std::unique_ptr values_; + static size_t calculate_elements(absl::Span sizes) { + return std::accumulate(sizes.begin(), sizes.end(), 1LL, + std::multiplies()); + } + + OwnedBuffer sizes_; + OwnedBuffer values_; }; // Specialization of FillRandom() method for complex64 type. Uses real part of // the stddev parameter as the standard deviation value. template <> -void Array::FillRandom(const complex64& stddev, const double mean, - const int seed); +void Array::FillRandom(const complex64& stddev, double mean, + int seed); } // namespace xla diff --git a/tensorflow/compiler/xla/array2d.h b/tensorflow/compiler/xla/array2d.h index 77e3c9c94e8..2409fe6268b 100644 --- a/tensorflow/compiler/xla/array2d.h +++ b/tensorflow/compiler/xla/array2d.h @@ -54,6 +54,7 @@ class Array2D : public Array { // or double) from the given nested initializer list of float values. template ::value || + std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || diff --git a/tensorflow/compiler/xla/autotune_serialize.cc b/tensorflow/compiler/xla/autotune_serialize.cc index 149fdf9f24f..71e0562b6e8 100644 --- a/tensorflow/compiler/xla/autotune_serialize.cc +++ b/tensorflow/compiler/xla/autotune_serialize.cc @@ -40,7 +40,7 @@ Status LoadAutotuneResults(absl::string_view data) { } if (results.version() != kVersion) { return tsl::errors::InvalidArgument(absl::StrFormat( - "Version mismatch in autotune results. Expected %d but was %d", + "Version mismatch in autotune results. 
Expected %d but was %d", kVersion, results.version())); } diff --git a/tensorflow/compiler/xla/backends/interpreter/platform.cc b/tensorflow/compiler/xla/backends/interpreter/platform.cc index 0259baf8221..9e7bfe804f8 100644 --- a/tensorflow/compiler/xla/backends/interpreter/platform.cc +++ b/tensorflow/compiler/xla/backends/interpreter/platform.cc @@ -80,7 +80,7 @@ XlaInterpreterPlatform::GetUncachedExecutor( auto init_status = executor->Init(config.device_options); if (!init_status.ok()) { return tsl::Status{ - tsl::error::INTERNAL, + absl::StatusCode::kInternal, absl::StrFormat( "failed initializing StreamExecutor for device ordinal %d: %s", config.ordinal, init_status.ToString())}; diff --git a/tensorflow/compiler/xla/backends/profiler/cpu/BUILD b/tensorflow/compiler/xla/backends/profiler/cpu/BUILD index 3543ed7a558..c1f28b7bc4d 100644 --- a/tensorflow/compiler/xla/backends/profiler/cpu/BUILD +++ b/tensorflow/compiler/xla/backends/profiler/cpu/BUILD @@ -9,7 +9,6 @@ cc_library( visibility = [ "//tensorflow/compiler/xla/backends/profiler:__pkg__", "//tensorflow/core/profiler:internal", - "//third_party/car/onboard/gpu:__subpackages__", ], deps = [ ":host_tracer_impl", diff --git a/tensorflow/compiler/xla/backends/profiler/gpu/BUILD b/tensorflow/compiler/xla/backends/profiler/gpu/BUILD index 12a4f1a9926..5ae87af9e0a 100644 --- a/tensorflow/compiler/xla/backends/profiler/gpu/BUILD +++ b/tensorflow/compiler/xla/backends/profiler/gpu/BUILD @@ -112,8 +112,8 @@ tsl_gpu_cc_test( size = "small", srcs = ["cupti_error_manager_test.cc"], tags = tf_cuda_tests_tags() + [ - "nomac", "gpu_cupti", + "nomac", ], deps = [ "//tensorflow/tsl/platform:test_main", @@ -240,23 +240,23 @@ tsl_gpu_library( copts = tf_profiler_copts() + tsl_copts(), visibility = ["//visibility:public"], deps = [ + "//tensorflow/tsl/platform:abi", + "//tensorflow/tsl/platform:macros", + "//tensorflow/tsl/platform:mutex", + "//tensorflow/tsl/platform:platform_port", + "//tensorflow/tsl/platform:status", + "//tensorflow/tsl/platform:types", + "//tensorflow/tsl/profiler/protobuf:xplane_proto_cc", + "//tensorflow/tsl/profiler/utils:parse_annotation", + "//tensorflow/tsl/profiler/utils:trace_utils", + "//tensorflow/tsl/profiler/utils:xplane_builder", + "//tensorflow/tsl/profiler/utils:xplane_schema", + "//tensorflow/tsl/profiler/utils:xplane_utils", "@com_google_absl//absl/container:fixed_array", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:node_hash_set", "@com_google_absl//absl/strings", - "//tensorflow/tsl/platform:abi", - "//tensorflow/tsl/platform:platform_port", - "//tensorflow/tsl/platform:mutex", - "//tensorflow/tsl/platform:macros", - "//tensorflow/tsl/platform:status", - "//tensorflow/tsl/platform:types", - "//tensorflow/tsl/profiler/protobuf:xplane_proto_cc", - "//tensorflow/tsl/profiler/utils:parse_annotation", - "//tensorflow/tsl/profiler/utils:xplane_builder", - "//tensorflow/tsl/profiler/utils:xplane_schema", - "//tensorflow/tsl/profiler/utils:xplane_utils", - "//tensorflow/tsl/profiler/utils:trace_utils", ] + tf_additional_cupti_deps(), ) diff --git a/tensorflow/compiler/xla/backends/profiler/gpu/cupti_tracer.cc b/tensorflow/compiler/xla/backends/profiler/gpu/cupti_tracer.cc index 11a13892df9..f3a17b64db2 100644 --- a/tensorflow/compiler/xla/backends/profiler/gpu/cupti_tracer.cc +++ b/tensorflow/compiler/xla/backends/profiler/gpu/cupti_tracer.cc @@ -90,7 +90,7 @@ Status ToStatus(CUresult result) { inline void LogIfError(const Status 
&status) { if (status.ok()) return; - LOG(ERROR) << status.error_message(); + LOG(ERROR) << status.message(); } // Maps an OverheadKind enum to a const string. diff --git a/tensorflow/compiler/xla/backends/profiler/tpu/tpu_tracer.cc b/tensorflow/compiler/xla/backends/profiler/tpu/tpu_tracer.cc index 48f9ec1bead..de05a7e1545 100644 --- a/tensorflow/compiler/xla/backends/profiler/tpu/tpu_tracer.cc +++ b/tensorflow/compiler/xla/backends/profiler/tpu/tpu_tracer.cc @@ -65,7 +65,7 @@ TpuTracer::TpuTracer() { stream_executor::tpu::OpsApiFn()->TpuProfiler_CreateFn(&tpu_profiler_, status.c_status); if (!status.ok()) { - LOG(ERROR) << status.status().error_message(); + LOG(ERROR) << status.status().message(); } } diff --git a/tensorflow/compiler/xla/client/lib/BUILD b/tensorflow/compiler/xla/client/lib/BUILD index e1a9edfb119..9d46f68b6fb 100644 --- a/tensorflow/compiler/xla/client/lib/BUILD +++ b/tensorflow/compiler/xla/client/lib/BUILD @@ -212,6 +212,7 @@ xla_test( "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:test_macros_header", "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/tsl/lib/core:status_test_util", ], ) @@ -293,6 +294,7 @@ cc_library( hdrs = ["prng.h"], deps = [ ":constants", + "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto_cc", "//tensorflow/compiler/xla/client:xla_builder", diff --git a/tensorflow/compiler/xla/client/lib/constants_test.cc b/tensorflow/compiler/xla/client/lib/constants_test.cc index 4bc4c494d87..5b034dde320 100644 --- a/tensorflow/compiler/xla/client/lib/constants_test.cc +++ b/tensorflow/compiler/xla/client/lib/constants_test.cc @@ -40,7 +40,7 @@ XLA_TEST_F(ConstantsTest, ConstantR0WithTypeS32DoesNotAcceptFloats) { ConstantR0WithType(&builder, xla::S32, 4.5); auto statusor = builder.Build(); ASSERT_FALSE(statusor.ok()); - EXPECT_THAT(statusor.status().error_message(), HasSubstr("Invalid cast")); + EXPECT_THAT(statusor.status().message(), HasSubstr("Invalid cast")); } XLA_TEST_F(ConstantsTest, ConstantR0WithTypeF32) { diff --git a/tensorflow/compiler/xla/client/lib/math.cc b/tensorflow/compiler/xla/client/lib/math.cc index fd6a02223ea..25179617548 100644 --- a/tensorflow/compiler/xla/client/lib/math.cc +++ b/tensorflow/compiler/xla/client/lib/math.cc @@ -33,8 +33,12 @@ template XlaOp EvaluatePolynomial(XlaOp x, absl::Span coefficients) { static_assert(std::is_floating_point::value, "Template-argument 'FP' must be a floating-point type"); - XlaOp poly = ScalarLike(x, 0.0); - for (FP c : coefficients) { + if (coefficients.empty()) { + return ScalarLike(x, FP(0.0)); + } + XlaOp poly = ScalarLike(x, coefficients[0]); + for (int i = 1; i < coefficients.size(); ++i) { + FP c = coefficients[i]; poly = poly * x + ScalarLike(x, c); } return poly; @@ -296,23 +300,27 @@ XlaOp Erfc(XlaOp x) { }); } -// Compute a polynomial approximation of the error function. -// This is the same approximation used by Eigen. +// Compute a rational approximation of the error function. 
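A plain restatement of the EvaluatePolynomial change above: the coefficients are ordered from highest to lowest degree and evaluated with Horner's rule, and seeding the accumulator with the leading coefficient rather than zero saves one multiply-add per call. Sketch in ordinary C++:

#include <cstddef>

// Horner's rule: p(x) = ((c0*x + c1)*x + c2)*x + ... + cn, with c0 the
// leading (highest-degree) coefficient. Starting from c0 avoids the extra
// "0*x + c0" step of a zero-seeded accumulator.
double EvaluatePolynomialHorner(double x, const double* coeffs, size_t n) {
  if (n == 0) return 0.0;
  double poly = coeffs[0];
  for (size_t i = 1; i < n; ++i) {
    poly = poly * x + coeffs[i];
  }
  return poly;
}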
static XlaOp ErfImpl32(XlaOp x) { - static const std::array kAlpha{ - -2.72614225801306e-10f, 2.77068142495902e-08f, -2.10102402082508e-06f, - -5.69250639462346e-05f, -7.34990630326855e-04f, -2.95459980854025e-03f, - -1.60960333262415e-02f, - }; + static const std::array kAlpha{ + 0.00022905065861350646f, 0.0034082910107109506f, 0.050955695062380861f, + 0.18520832239976145f, 1.128379143519084f}; - static const std::array kBeta{ - -1.45660718464996e-05f, -2.13374055278905e-04f, -1.68282697438203e-03f, - -7.37332916720468e-03f, -1.42647390514189e-02f, - }; + static const std::array kBeta{-1.1791602954361697e-7, + 0.000023547966471313185f, + 0.0010179625278914885f, + 0.014070470171167667f, + 0.11098505178285362f, + 0.49746925110067538f, + 1.0f}; - x = Clamp(ScalarLike(x, -4.f), x, ScalarLike(x, 4.f)); + // We clamp x to be within [-c;c] where c = erfinv(1-2^-23), outside of + // which x should be +/-1. + constexpr float kErfInvOneMinusHalfULP = 3.7439211627767994f; + x = Clamp(ScalarLike(x, -kErfInvOneMinusHalfULP), x, + ScalarLike(x, kErfInvOneMinusHalfULP)); auto x2 = x * x; - return x * EvaluatePolynomial(x2, kAlpha) / + return (x * EvaluatePolynomial(x2, kAlpha)) / EvaluatePolynomial(x2, kBeta); } @@ -330,10 +338,8 @@ XlaOp Erf(XlaOp x) { } // Erf(c)Impl don't have enough precision when run with bf16 intermediates // (not surprising!), so upcast to f32 in this case. - return DoWithUpcastToF32(x, {BF16, F16}, [](XlaOp x) { - return Select(Lt(Abs(x), ScalarLike(x, 1)), ErfImpl32(x), - ScalarLike(x, 1) - ErfcImpl32(x)); - }); + return DoWithUpcastToF32(x, {BF16, F16}, + [](XlaOp x) { return ErfImpl32(x); }); }); } diff --git a/tensorflow/compiler/xla/client/lib/math_test.cc b/tensorflow/compiler/xla/client/lib/math_test.cc index 92a0eaf5b73..ccd4ee2b1cc 100644 --- a/tensorflow/compiler/xla/client/lib/math_test.cc +++ b/tensorflow/compiler/xla/client/lib/math_test.cc @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/tsl/lib/core/status_test_util.h" namespace xla { namespace { @@ -143,7 +144,7 @@ class MathTypedTest : public MathTest { ComputeAndCompareR1(&b, expected, {}, error_spec_); } - void TestErfEdgeCases() { + void TestErfInvEdgeCases() { SetFastMathDisabled(true); XlaBuilder b(TestName()); @@ -155,6 +156,23 @@ class MathTypedTest : public MathTest { ComputeAndCompareR1(&b, expected, {}, error_spec_); } + + void TestErfEdgeCases() { + SetFastMathDisabled(true); + const T kErfInvOneMinusHalfULP = T(3.832506856900711); + const T inf(std::numeric_limits::infinity()); + + XlaBuilder b(TestName()); + auto x = AddParam(LiteralUtil::CreateR1({T{-inf}, T{inf}, T{-0}, T{0}, + T{-kErfInvOneMinusHalfULP}, + T{kErfInvOneMinusHalfULP}}), + &b); + Erf(x); + + std::vector expected = {T(-1), T(1), T(-0), T(0), T(-1), T(1)}; + + ComputeAndCompareR1(&b, expected, {}, error_spec_); + } }; // TODO(b/123355973): Add bfloat16 to TestTypes once it's working. @@ -178,7 +196,8 @@ XLA_TYPED_TEST(MathTypedTest, IsNegZero) { this->TestIsNegZero(); } XLA_TYPED_TEST(MathTypedTest, SqrtPowInequivalence) { this->TestSqrtPowInequivalence(); } -XLA_TYPED_TEST(MathTypedTest, ErfInvEdgeCases) { this->TestErfEdgeCases(); } +XLA_TYPED_TEST(MathTypedTest, ErfInvEdgeCases) { this->TestErfInvEdgeCases(); } +XLA_TYPED_TEST(MathTypedTest, ErfEdgeCases) { this->TestErfEdgeCases(); } // Check that certain ops only support real, floating-point inputs. 
// @@ -203,7 +222,7 @@ XLA_TEST_F(MathTest, RealFpOnlyOps) { } else { continue; } - if (ty == F8E5M2 || ty == F8E4M3FN) { + if (ty == F8E5M2 || ty == F8E4M3FN || ty == F8E4M3B11FNUZ) { // TODO(b/259609697): Add FP8 support to math ops continue; } @@ -226,7 +245,11 @@ XLA_TEST_F(MathTest, RealFpOnlyOps) { XlaOp p = Parameter(&b, 0, shape, "p0"); test.first(p); - EXPECT_EQ(b.first_error().ok(), primitive_util::IsFloatingPointType(ty)); + if (primitive_util::IsFloatingPointType(ty)) { + TF_EXPECT_OK(b.first_error()); + } else { + EXPECT_FALSE(b.first_error().ok()); + } } } } diff --git a/tensorflow/compiler/xla/client/lib/prng.cc b/tensorflow/compiler/xla/client/lib/prng.cc index 8466b3d51a8..b0b66dd1b0a 100644 --- a/tensorflow/compiler/xla/client/lib/prng.cc +++ b/tensorflow/compiler/xla/client/lib/prng.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/util.h" namespace xla { @@ -255,13 +256,15 @@ RngOutput ThreeFryRngBit32(XlaOp key, XlaOp initial_state, const Shape& shape) { // Generates random 16bits with the given shape using the Three Fry // implementation. Returns the random bits and the new state. -RngOutput ThreeFryRngBit16(XlaOp op_key, XlaOp initial_state, - const Shape& shape) { +RngOutput ThreeFryRngBitNarrow(XlaOp op_key, XlaOp initial_state, + const Shape& shape) { // TODO(b/256713018): Use a better approach to not waste the upper 16 bits. auto new_shape = shape; new_shape.set_element_type(U32); auto output = ThreeFryRngBit32(op_key, initial_state, new_shape); - output.value = ConvertElementType(output.value, U16); + output.value = ConvertElementType( + output.value, primitive_util::UnsignedIntegralTypeForBitWidth( + primitive_util::BitWidth(shape.element_type()))); return output; } @@ -446,15 +449,17 @@ RngOutput PhiloxRngBit32(XlaOp op_key, XlaOp initial_state, // Generates an array of primitive type U16 with the given shape containing // random bits generated by the Philox algorithm. Returns the array and the new // state of the random number generator. -RngOutput PhiloxRngBit16(XlaOp op_key, XlaOp initial_state, - const Shape& shape) { +RngOutput PhiloxRngBitNarrow(XlaOp op_key, XlaOp initial_state, + const Shape& shape) { // We use PhiloxRngBit32 and throw away the upper 16 bits here, to align with // the non-XLA kernels. // TODO(b/256713018): Use a better approach to not waste the upper 16 bits. 
auto new_shape = shape; new_shape.set_element_type(U32); auto output = PhiloxRngBit32(op_key, initial_state, new_shape); - output.value = ConvertElementType(output.value, U16); + output.value = ConvertElementType( + output.value, primitive_util::UnsignedIntegralTypeForBitWidth( + primitive_util::BitWidth(shape.element_type()))); return output; } @@ -593,10 +598,12 @@ RngOutput ThreeFryBitGenerator(XlaOp key, XlaOp initial_state, const Shape& shape) { PrimitiveType type = shape.element_type(); switch (type) { + case S8: + case U8: case F16: case U16: case S16: - return ThreeFryRngBit16(key, initial_state, shape); + return ThreeFryRngBitNarrow(key, initial_state, shape); case F32: case U32: case S32: @@ -619,10 +626,12 @@ RngOutput PhiloxBitGenerator(XlaOp key, XlaOp initial_state, const Shape& shape) { PrimitiveType type = shape.element_type(); switch (type) { + case S8: + case U8: case F16: case U16: case S16: - return PhiloxRngBit16(key, initial_state, shape); + return PhiloxRngBitNarrow(key, initial_state, shape); case F32: case U32: case S32: diff --git a/tensorflow/compiler/xla/client/lib/sorting.cc b/tensorflow/compiler/xla/client/lib/sorting.cc index af4883b1f65..cdd1f4a542a 100644 --- a/tensorflow/compiler/xla/client/lib/sorting.cc +++ b/tensorflow/compiler/xla/client/lib/sorting.cc @@ -22,10 +22,11 @@ limitations under the License. #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" namespace xla { -XlaOp TopK(XlaOp input, int64_t k) { +XlaOp TopK(XlaOp input, int64_t k, PrimitiveType index_type) { XlaBuilder* const builder = input.builder(); return builder->ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(Shape input_shape, builder->GetShape(input)); @@ -41,16 +42,17 @@ XlaOp TopK(XlaOp input, int64_t k) { int64_t num_partitions = CeilOfRatio(last_dim_size - k, kPerPartitionSize - k); if (num_partitions >= kMinNumPartitions) { - return TopKWithPartitions(input, k, num_partitions); + return TopKWithPartitions(input, k, num_partitions, index_type); } } - Shape iota_shape = ShapeUtil::MakeShape(S32, input_shape.dimensions()); - XlaOp iota_s32 = Iota(builder, iota_shape, last_dim); + Shape iota_shape = + ShapeUtil::MakeShape(index_type, input_shape.dimensions()); + XlaOp iota = Iota(builder, iota_shape, last_dim); for (int64_t i = 0; i < input_shape.rank(); ++i) { if (input_shape.is_dynamic_dimension(i)) { // Propagate dynamic dimension from inputs to iota. - iota_s32 = SetDimensionSize(iota_s32, GetDimensionSize(input, i), i); + iota = SetDimensionSize(iota, GetDimensionSize(input, i), i); } } auto input_dims = input_shape.dimensions(); @@ -101,13 +103,14 @@ XlaOp TopK(XlaOp input, int64_t k) { Or(sign_magnitude_to_from_ones_complement( BitcastConvertType(ConvertElementType(input, F32), S32)), ConstantR0(builder, kLow16BitsMask)); - XlaOp input_and_iota = Xor(input_f32_trimmed, iota_s32); + XlaOp input_and_iota = Xor(input_f32_trimmed, iota); // Sort in reverse order so the largest elements are at the beginning. // Breaking ties here is why the index bits need to be inverted. - XlaOp sort_result_raw = Sort( - {input_and_iota}, CreateScalarGtComputation({S32}, builder), last_dim, - /*is_stable=*/false); + XlaOp sort_result_raw = + Sort({input_and_iota}, + CreateScalarGtComputation({index_type}, builder), last_dim, + /*is_stable=*/false); // Slice off the first k values. 
sort_result_raw = @@ -132,9 +135,9 @@ XlaOp TopK(XlaOp input, int64_t k) { ConstantR0(builder, kLow16BitsMask)); } else { XlaOp sort_result = - Sort({input, iota_s32}, - CreateScalarGtComputation({input_shape.element_type(), S32}, - iota_s32.builder()), + Sort({input, iota}, + CreateScalarGtComputation( + {input_shape.element_type(), index_type}, iota.builder()), last_dim, /*is_stable=*/true); values = Slice(GetTupleElement(sort_result, 0), start_indices, limit_indices, strides); @@ -150,7 +153,8 @@ XlaOp TopK(XlaOp input, int64_t k) { }); } -XlaOp TopKWithPartitions(XlaOp input, int64_t k, int64_t num_partitions) { +XlaOp TopKWithPartitions(XlaOp input, int64_t k, int64_t num_partitions, + PrimitiveType index_type) { XlaBuilder* const builder = input.builder(); return builder->ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(Shape input_shape, builder->GetShape(input)); @@ -162,15 +166,16 @@ XlaOp TopKWithPartitions(XlaOp input, int64_t k, int64_t num_partitions) { CeilOfRatio(last_dim_size, num_partitions); // Do normal TopK when per partition size is smaller than or equal to k. if (k >= per_partition_size) { - return TopK(input, k); + return TopK(input, k, index_type); } - Shape iota_shape = ShapeUtil::MakeShape(S32, input_shape.dimensions()); - XlaOp iota_s32 = Iota(builder, iota_shape, last_dim); + Shape iota_shape = + ShapeUtil::MakeShape(index_type, input_shape.dimensions()); + XlaOp iota = Iota(builder, iota_shape, last_dim); for (int64_t i = 0; i < input_shape.rank(); ++i) { if (input_shape.is_dynamic_dimension(i)) { // Propagate dynamic dimension from inputs to iota. - iota_s32 = SetDimensionSize(iota_s32, GetDimensionSize(input, i), i); + iota = SetDimensionSize(iota, GetDimensionSize(input, i), i); } } @@ -180,25 +185,41 @@ XlaOp TopKWithPartitions(XlaOp input, int64_t k, int64_t num_partitions) { auto values = values_and_indices[0]; auto indices = values_and_indices[1]; auto input = values_and_indices[2]; - auto iota_s32 = values_and_indices[3]; + auto iota = values_and_indices[3]; // Slice value and indices for this partition. - XlaOp start = Mul(Add(partition, ConstantR0(builder, 1)), - ConstantR0(builder, per_partition_size)); + XlaOp start; + switch (index_type) { + case PrimitiveType::S16: + start = Mul(Add(partition, ConstantR0(builder, 1)), + ConstantR0(builder, per_partition_size)); + break; + case PrimitiveType::S32: + start = Mul(Add(partition, ConstantR0(builder, 1)), + ConstantR0(builder, per_partition_size)); + break; + case PrimitiveType::S64: + start = Mul(Add(partition, ConstantR0(builder, 1)), + ConstantR0(builder, per_partition_size)); + break; + default: + LOG(FATAL) << "Unsupported index type " + << PrimitiveType_Name(index_type); + } XlaOp sliced_input = DynamicSliceInMinorDims(input, {start}, {per_partition_size}); XlaOp sliced_indices = - DynamicSliceInMinorDims(iota_s32, {start}, {per_partition_size}); + DynamicSliceInMinorDims(iota, {start}, {per_partition_size}); // Concat with previous results. 
sliced_input = ConcatInDim(builder, {values, sliced_input}, last_dim); sliced_indices = ConcatInDim(builder, {indices, sliced_indices}, last_dim); // Sort this slice - XlaOp sort_result = - Sort({sliced_input, sliced_indices}, - CreateScalarGtComputation({input_shape.element_type(), S32}, - sliced_indices.builder()), - last_dim, true); + XlaOp sort_result = Sort( + {sliced_input, sliced_indices}, + CreateScalarGtComputation({input_shape.element_type(), index_type}, + sliced_indices.builder()), + last_dim, true); std::vector start_indices(input_shape.dimensions_size(), 0); std::vector limit_indices(input_dims.begin(), input_dims.end()); @@ -210,7 +231,7 @@ XlaOp TopKWithPartitions(XlaOp input, int64_t k, int64_t num_partitions) { limit_indices, strides); indices = Slice(GetTupleElement(sort_result, 1), start_indices, limit_indices, strides); - return std::vector{values, indices, input, iota_s32}; + return std::vector{values, indices, input, iota}; }; // Get the values and indices for the first topk so that they can @@ -222,12 +243,11 @@ XlaOp TopKWithPartitions(XlaOp input, int64_t k, int64_t num_partitions) { limit_indices[last_dim] = per_partition_size; // Slice value and indices for the first partition. XlaOp sliced_input = Slice(input, start_indices, limit_indices, strides); - XlaOp sliced_indices = - Slice(iota_s32, start_indices, limit_indices, strides); + XlaOp sliced_indices = Slice(iota, start_indices, limit_indices, strides); // Sort this slice XlaOp sort_result = Sort({sliced_input, sliced_indices}, - CreateScalarGtComputation({input_shape.element_type(), S32}, + CreateScalarGtComputation({input_shape.element_type(), index_type}, sliced_indices.builder()), last_dim, /*is_stable=*/true); @@ -241,10 +261,11 @@ XlaOp TopKWithPartitions(XlaOp input, int64_t k, int64_t num_partitions) { // Pass the result of the first TopK to the while loop and do // num_partition - 1 iterations. - TF_ASSIGN_OR_RETURN(auto values_and_indices, - ForEachIndex(num_partitions - 1, S32, topk_body_fn, - {values, indices, input, iota_s32}, - "topk_with_partition", builder)); + TF_ASSIGN_OR_RETURN( + auto values_and_indices, + ForEachIndex(num_partitions - 1, index_type, topk_body_fn, + {values, indices, input, iota}, "topk_with_partition", + builder)); return Tuple(builder, {values_and_indices[0], values_and_indices[1]}); }); } diff --git a/tensorflow/compiler/xla/client/lib/sorting.h b/tensorflow/compiler/xla/client/lib/sorting.h index 0f810ccb365..9fbdf1b9945 100644 --- a/tensorflow/compiler/xla/client/lib/sorting.h +++ b/tensorflow/compiler/xla/client/lib/sorting.h @@ -24,11 +24,14 @@ namespace xla { // Returns a tuple composed of the top `k` values and corresponding indices in // `input`. Output values are in descending order, from largest to smallest. -XlaOp TopK(XlaOp input, int64_t k); +XlaOp TopK(XlaOp input, int64_t k, + PrimitiveType index_type = PrimitiveType::S32); + // Split sort in TopK into smaller sorts. // Returns a tuple composed of the top `k` values and corresponding indices in // `input`. Output values are in descending order, from largest to smallest. 
-XlaOp TopKWithPartitions(XlaOp input, int64_t k, int64_t num_partitions = 1); +XlaOp TopKWithPartitions(XlaOp input, int64_t k, int64_t num_partitions = 1, + PrimitiveType index_type = PrimitiveType::S32); } // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/sorting_test.cc b/tensorflow/compiler/xla/client/lib/sorting_test.cc index e820d5bfe6f..7d5de392067 100644 --- a/tensorflow/compiler/xla/client/lib/sorting_test.cc +++ b/tensorflow/compiler/xla/client/lib/sorting_test.cc @@ -44,6 +44,14 @@ XLA_TEST_F(SortingTest, TopK3From8Indices) { ComputeAndCompareR1(&builder, {0, 1, 2}, {}); } +XLA_TEST_F(SortingTest, TopK3From8Int16Indices) { + XlaBuilder builder(TestName()); + auto x = + ConstantR1(&builder, {0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0}); + xla::GetTupleElement(xla::TopK(x, 3, PrimitiveType::S16), 1); + ComputeAndCompareR1(&builder, {7, 6, 5}, {}); +} + XLA_TEST_F(SortingTest, TopKFullSortMinInt) { XlaBuilder builder(TestName()); auto x_rev = ConstantR1(&builder, {std::numeric_limits::min(), @@ -140,6 +148,16 @@ XLA_TEST_F(SortingTest, TopK3From8Indices5Partitions) { ComputeAndCompareR1(&builder, {0, 1, 2}, {}); } +XLA_TEST_F(SortingTest, TopK3From8Int16Indices5Partitions) { + XlaBuilder builder(TestName()); + auto x_rev = + ConstantR1(&builder, {7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0, 0.0}); + xla::GetTupleElement(xla::TopKWithPartitions(x_rev, 3, /*num_partitions=*/5, + PrimitiveType::S16), + 1); + ComputeAndCompareR1(&builder, {0, 1, 2}, {}); +} + XLA_TEST_F(SortingTest, TopKFullSortWithDuplicates2Partitions) { XlaBuilder builder(TestName()); XlaOp a; diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc index 2b3d972d5a5..339ce5b2ad8 100644 --- a/tensorflow/compiler/xla/client/xla_builder.cc +++ b/tensorflow/compiler/xla/client/xla_builder.cc @@ -1090,7 +1090,7 @@ XlaOp XlaBuilder::TernaryOp(HloOpcode triop, XlaOp lhs, XlaOp rhs, XlaOp ehs) { if (!status_or_shape.status().ok()) { return InvalidArgument( "%s Input scalar shapes may have been changed to non-scalar shapes.", - status_or_shape.status().error_message()); + status_or_shape.status().message()); } return AddOpWithShape(triop, status_or_shape.value(), diff --git a/tensorflow/compiler/xla/client/xla_builder_test.cc b/tensorflow/compiler/xla/client/xla_builder_test.cc index 513670738f1..1b0eb3bc073 100644 --- a/tensorflow/compiler/xla/client/xla_builder_test.cc +++ b/tensorflow/compiler/xla/client/xla_builder_test.cc @@ -181,7 +181,7 @@ TEST_F(XlaBuilderTest, ShiftRightOperatorOnNonIntegerProducesError) { auto statusor = b.Build(); ASSERT_FALSE(statusor.ok()); EXPECT_THAT( - statusor.status().error_message(), + statusor.status().message(), HasSubstr("Argument to >> operator does not have an integral type")); } @@ -226,7 +226,7 @@ TEST_F(XlaBuilderTest, ShapeInferenceError) { Add(x, y); auto statusor = BuildHloModule(&b); ASSERT_FALSE(statusor.ok()); - EXPECT_THAT(statusor.status().error_message(), + EXPECT_THAT(statusor.status().message(), HasSubstr("Shapes must be equal rank")); } @@ -250,7 +250,7 @@ TEST_F(XlaBuilderTest, ParameterAlreadyRegistered) { Add(x, y); auto statusor = BuildHloModule(&b); ASSERT_FALSE(statusor.ok()); - EXPECT_THAT(statusor.status().error_message(), + EXPECT_THAT(statusor.status().message(), HasSubstr("parameter 0 already registered")); } @@ -345,7 +345,7 @@ TEST_F(XlaBuilderTest, BroadcastInDimWithNegativeSize) { /*broadcast_dimensions=*/{0, 1, 2}); auto statusor = BuildHloModule(&b); ASSERT_FALSE(statusor.ok()); - 
EXPECT_THAT(statusor.status().error_message(), HasSubstr("invalid shape")); + EXPECT_THAT(statusor.status().message(), HasSubstr("invalid shape")); } TEST_F(XlaBuilderTest, OperandFromWrongBuilder) { @@ -357,7 +357,7 @@ TEST_F(XlaBuilderTest, OperandFromWrongBuilder) { auto statusor = builder.Build(); ASSERT_FALSE(statusor.ok()); EXPECT_THAT( - statusor.status().error_message(), + statusor.status().message(), HasSubstr( "built by builder 'b1', but is trying to use it in builder 'main'")); } @@ -527,7 +527,7 @@ TEST_F(XlaBuilderTest, ReportError) { Add(b.ReportError(InvalidArgument("a test error")), x); auto statusor = b.Build(); ASSERT_FALSE(statusor.ok()); - EXPECT_THAT(statusor.status().error_message(), HasSubstr("a test error")); + EXPECT_THAT(statusor.status().message(), HasSubstr("a test error")); } TEST_F(XlaBuilderTest, ReportErrorOrReturnHandlesNonErrors) { @@ -545,7 +545,7 @@ TEST_F(XlaBuilderTest, ReportErrorOrReturnHandlesErrors) { Add(b.ReportErrorOrReturn(op), ConstantR0(&b, 2.0)); auto statusor = b.Build(); ASSERT_FALSE(statusor.ok()); - EXPECT_THAT(statusor.status().error_message(), HasSubstr("a test error")); + EXPECT_THAT(statusor.status().message(), HasSubstr("a test error")); } TEST_F(XlaBuilderTest, BuildWithSpecificRoot) { @@ -584,7 +584,7 @@ TEST_F(XlaBuilderTest, BuildWithSpecificRootWithWrongBuilder) { Status status = b.Build(other_param).status(); ASSERT_IS_NOT_OK(status); EXPECT_THAT( - status.error_message(), + status.message(), ::testing::HasSubstr("root operation is not in this computation")); } @@ -1238,7 +1238,7 @@ TEST_F(XlaBuilderTest, AfterAllWithNonTokenOperands) { AfterAll(&b, {CreateToken(&b), ConstantR0(&b, 1.0)}); Status status = b.Build().status(); ASSERT_IS_NOT_OK(status); - EXPECT_THAT(status.error_message(), + EXPECT_THAT(status.message(), ::testing::HasSubstr("All operands to AfterAll must be tokens")); } @@ -1471,9 +1471,7 @@ TEST_F(XlaBuilderTest, OutfeedTokenSharding) { TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); auto it = std::find_if(module->entry_computation()->instructions().begin(), module->entry_computation()->instructions().end(), - [](const HloInstruction* i) { - return i->opcode() == HloOpcode::kOutfeed; - }); + HloPredicateIsOp); EXPECT_NE(it, module->entry_computation()->instructions().end()); auto* outfeed = *it; EXPECT_TRUE(outfeed->has_sharding()); @@ -1507,7 +1505,7 @@ TEST_F(XlaBuilderTest, InvalidSharding) { Parameter(&b, 0, shape2d, "p0"); auto statusor = b.Build(); EXPECT_FALSE(statusor.ok()); - EXPECT_THAT(statusor.status().error_message(), + EXPECT_THAT(statusor.status().message(), HasSubstr("Number of tile assignment dimensions (excluding " "subgroups) is different than the input rank")); } diff --git a/tensorflow/compiler/xla/comparison_util.cc b/tensorflow/compiler/xla/comparison_util.cc index e79d6b75606..69f4c0b2100 100644 --- a/tensorflow/compiler/xla/comparison_util.cc +++ b/tensorflow/compiler/xla/comparison_util.cc @@ -38,6 +38,7 @@ bool IsValidComparison(xla::PrimitiveType type, Comparison::Order order) { case F64: case F8E5M2: case F8E4M3FN: + case F8E4M3B11FNUZ: case C64: case C128: return true; @@ -105,6 +106,7 @@ Comparison::Order DefaultOrdering(PrimitiveType type) { return Comparison::Order::kTotal; case F8E5M2: case F8E4M3FN: + case F8E4M3B11FNUZ: case BF16: case F16: case F32: @@ -187,11 +189,11 @@ std::string ComparisonTypeToString(Comparison::Type type) { } } -std::string ComparisonPrimitiveTypeToString(PrimitiveType type) { +absl::string_view ComparisonPrimitiveTypeToString(PrimitiveType 
type) { return PrimitiveType_Name(type); } -std::string ComparisonOrderToString(Comparison::Order order) { +absl::string_view ComparisonOrderToString(Comparison::Order order) { switch (order) { case Comparison::Order::kPartial: return "PARTIALORDER"; @@ -262,6 +264,7 @@ Comparison::Type Comparison::DefaultComparisonType(PrimitiveType type) { return Type::kUnsigned; case F8E5M2: case F8E4M3FN: + case F8E4M3B11FNUZ: case F16: case F32: case BF16: @@ -316,6 +319,7 @@ std::optional Comparison::Inverse() const { case F64: case F8E5M2: case F8E4M3FN: + case F8E4M3B11FNUZ: case C64: case C128: case S4: diff --git a/tensorflow/compiler/xla/comparison_util.h b/tensorflow/compiler/xla/comparison_util.h index ec97bdc8c2c..1b6f349e8b9 100644 --- a/tensorflow/compiler/xla/comparison_util.h +++ b/tensorflow/compiler/xla/comparison_util.h @@ -234,8 +234,8 @@ inline std::ostream& operator<<(std::ostream& os, const Comparison& cmp) { std::string ComparisonDirectionToString(Comparison::Direction direction); std::string ComparisonTypeToString(Comparison::Type type); -std::string ComparisonPrimitiveTypeToString(PrimitiveType type); -std::string ComparisonOrderToString(Comparison::Order order); +absl::string_view ComparisonPrimitiveTypeToString(PrimitiveType type); +absl::string_view ComparisonOrderToString(Comparison::Order order); StatusOr StringToComparisonDirection( absl::string_view direction); diff --git a/tensorflow/compiler/xla/debug_options_flags.cc b/tensorflow/compiler/xla/debug_options_flags.cc index 0c904033bdc..dabed025c92 100644 --- a/tensorflow/compiler/xla/debug_options_flags.cc +++ b/tensorflow/compiler/xla/debug_options_flags.cc @@ -58,6 +58,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { opts.set_xla_cpu_use_acl(true); #endif opts.set_xla_cpu_use_xla_runtime(false); + opts.set_xla_cpu_sparse_cuda_threads(0); opts.set_xla_cpu_enable_fast_math(false); // Disable forms of fast math that have caused users problems in the past. @@ -75,8 +76,11 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { // flag. opts.set_xla_gpu_enable_cublaslt(false); - // TODO(b/258036887): Remove this flag once CUDA Graphs are fully supported. - opts.set_xla_gpu_enable_cuda_graphs(false); + // TODO(b/258036887): Enable once CUDA Graphs are fully supported. + opts.set_xla_gpu_cuda_graph_level(0); + opts.set_xla_gpu_cuda_graph_instantiation_threshold(2); + opts.set_xla_gpu_enable_persistent_temp_buffers(false); + opts.set_xla_gpu_cuda_graph_capture_threshold(2); // Despite the name, fast min/max on GPUs does not seem to be any faster, and // adds very counter-intuitive "NaN-swallowing" behavior. 
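The defaults hunk above retires the single xla_gpu_enable_cuda_graphs switch in favor of a graph level plus two counters, all defaulted conservatively (level 0, instantiation and capture thresholds of 2). Below is a minimal sketch of how a client could override those defaults programmatically; it relies only on DefaultDebugOptionsIgnoringFlags() and the setters named in this hunk, while the header path and the wrapper function are illustrative assumptions, not part of the change.

// Sketch only: build a DebugOptions proto with CUDA-graph capture turned on.
// The include path is assumed from the .cc file modified above.
#include "tensorflow/compiler/xla/debug_options_flags.h"

namespace xla {

DebugOptions MakeCudaGraphDebugOptions() {
  DebugOptions opts = DefaultDebugOptionsIgnoringFlags();
  opts.set_xla_gpu_cuda_graph_level(1);                    // 0 = off; 1 captures fusions and memcpys.
  opts.set_xla_gpu_cuda_graph_instantiation_threshold(2);  // Instantiate after a capture has run this many times.
  opts.set_xla_gpu_cuda_graph_capture_threshold(2);        // Outline a region once this many instructions move into it.
  return opts;
}

}  // namespace xla

The same knobs are also registered as command-line flags later in this diff (xla_gpu_cuda_graph_level and friends), which is the usual way to flip them without recompiling.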
@@ -87,6 +91,8 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { opts.set_xla_force_host_platform_device_count(1); opts.set_xla_gpu_all_reduce_combine_threshold_bytes(30 * 1024 * 1024); opts.set_xla_gpu_enable_async_all_reduce(true); + opts.set_xla_gpu_enable_reassociation_for_converted_ar(true); + opts.set_xla_cpu_enable_xprof_traceme(false); opts.set_xla_gpu_unsafe_fallback_to_driver_on_ptxas_not_found(false); opts.set_xla_multiheap_size_constraint_per_heap(-1); @@ -104,8 +110,15 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { opts.set_xla_gpu_simplify_all_fp_conversions(true); opts.set_xla_dump_latency_hiding_schedule(false); opts.set_xla_gpu_enable_latency_hiding_scheduler(false); + opts.set_xla_gpu_lhs_enable_gpu_async_tracker(false); + opts.set_xla_gpu_pgle_profile_directory(""); opts.set_xla_cpu_enable_mlir_tiling_and_fusion(true); + opts.set_xla_cpu_enable_custom_matmul_tiling(false); + opts.set_xla_cpu_matmul_tiling_m_dim(8); + opts.set_xla_cpu_matmul_tiling_n_dim(8); + opts.set_xla_cpu_matmul_tiling_k_dim(8); + opts.set_xla_cpu_enable_mlir_fusion_outlining(true); opts.set_xla_cpu_enable_experimental_deallocation(true); opts.set_xla_partitioning_algorithm( @@ -114,6 +127,12 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { opts.set_xla_gpu_enable_triton_gemm(true); opts.set_xla_gpu_enable_cudnn_int8x32_convolution_reordering(true); opts.set_xla_gpu_triton_gemm_any(false); + + // Moving reduce-scatter out of while loops can increase memory footprint, so + // turning it off by default. + opts.set_xla_gpu_enable_while_loop_reduce_scatter_code_motion(false); + + opts.set_xla_gpu_collective_inflation_factor(1); return opts; } @@ -326,14 +345,14 @@ void MakeDebugOptionsFlags(std::vector* flag_list, bool_setter_for(&DebugOptions::set_xla_cpu_fast_math_honor_nans), debug_options->xla_cpu_fast_math_honor_nans(), "When xla_cpu_enable_fast_math is true then this controls whether we " - "allow operations to produce NaNs. Ignored when " + "allow operations to produce NaNs. Ignored when " "xla_cpu_enable_fast_math is false.")); flag_list->push_back(tsl::Flag( "xla_cpu_fast_math_honor_infs", bool_setter_for(&DebugOptions::set_xla_cpu_fast_math_honor_infs), debug_options->xla_cpu_fast_math_honor_infs(), "When xla_cpu_enable_fast_math is true then this controls whether we " - "allow operations to produce infinites. Ignored when " + "allow operations to produce infinites. Ignored when " "xla_cpu_enable_fast_math is false.")); flag_list->push_back(tsl::Flag( "xla_cpu_fast_math_honor_division", @@ -403,10 +422,10 @@ void MakeDebugOptionsFlags(std::vector* flag_list, flag_list->push_back(tsl::Flag( "xla_disable_all_hlo_passes", bool_setter_for(&DebugOptions::set_xla_disable_all_hlo_passes), false, - "Disables all HLO passes. Notes that some passes are necessary for " "correctness and the invariants that must be satisfied by 'fully " "optimized' HLO are different for different devices and may change " "over time. 
The only 'guarantee', such as it is, is that if you compile " "XLA and dump the optimized HLO for some graph, you should be able to " "run it again on the same device with the same build of XLA.")); flag_list->push_back( @@ -486,6 +505,12 @@ void MakeDebugOptionsFlags(std::vector* flag_list, bool_setter_for(&DebugOptions::set_xla_cpu_use_xla_runtime), debug_options->xla_cpu_use_xla_runtime(), "Enable XLA Runtime in the CPU backend.")); + flag_list->push_back(tsl::Flag( + "xla_cpu_sparse_cuda_threads", + int32_setter_for(&DebugOptions::set_xla_cpu_sparse_cuda_threads), + debug_options->xla_cpu_sparse_cuda_threads(), + "Sets number of CUDA threads for sparse GPU acceleration in the CPU " + "backend (0 = off).")); flag_list->push_back(tsl::Flag( "xla_gpu_crash_on_verification_failures", bool_setter_for( @@ -527,7 +552,7 @@ void MakeDebugOptionsFlags(std::vector* flag_list, "If multiple parameters, separate them by comma.")); flag_list->push_back(tsl::Flag( "xla_fuel", setter_for_xla_fuel, /*default_value_for_display=*/"", - "Sets compiler fuel, useful for bisecting bugs in passes. Format " + "Sets compiler fuel, useful for bisecting bugs in passes. Format " "--xla_fuel=PASS1=NUM1,PASS2=NUM2,...")); flag_list->push_back(tsl::Flag( "xla_dump_to", string_setter_for(&DebugOptions::set_xla_dump_to), @@ -727,6 +752,21 @@ void MakeDebugOptionsFlags(std::vector* flag_list, &DebugOptions::set_xla_gpu_enable_async_collective_permute), debug_options->xla_gpu_enable_async_collective_permute(), "Converts synchronous collective-permute ops into asynchronous.")); + flag_list->push_back(tsl::Flag( + "xla_gpu_enable_async_all_gather", + bool_setter_for(&DebugOptions::set_xla_gpu_enable_async_all_gather), + debug_options->xla_gpu_enable_async_all_gather(), + "Converts synchronous all-gather ops into asynchronous.")); + flag_list->push_back(tsl::Flag( + "xla_gpu_enable_async_reduce_scatter", + bool_setter_for(&DebugOptions::set_xla_gpu_enable_async_reduce_scatter), + debug_options->xla_gpu_enable_async_reduce_scatter(), + "Converts synchronous reduce-scatter ops into asynchronous.")); + flag_list->push_back(tsl::Flag( + "xla_gpu_enable_async_all_to_all", + bool_setter_for(&DebugOptions::set_xla_gpu_enable_async_all_to_all), + debug_options->xla_gpu_enable_async_all_to_all(), + "Converts synchronous all-to-all ops into asynchronous.")); flag_list->push_back(tsl::Flag( "xla_gpu_all_reduce_combine_threshold_bytes", int64_setter_for( @@ -749,6 +789,28 @@ void MakeDebugOptionsFlags(std::vector* flag_list, "ReduceScatter-AllReduce-AllGather sequence, with the initial " "ReduceScatter being performed over all of the devices in the same host. " "Set to < 1 to disable all-reduce decomposition.")); + flag_list->push_back(tsl::Flag( + "xla_gpu_enable_while_loop_reduce_scatter_code_motion", + bool_setter_for( + &DebugOptions:: + set_xla_gpu_enable_while_loop_reduce_scatter_code_motion), + debug_options->xla_gpu_enable_while_loop_reduce_scatter_code_motion(), + "Enable hoisting of reduce-scatter outside while loops.")); + flag_list->push_back(tsl::Flag( + "xla_gpu_collective_inflation_factor", + int32_setter_for(&DebugOptions::set_xla_gpu_collective_inflation_factor), + debug_options->xla_gpu_collective_inflation_factor(), + "Inflation factor for collectives. 
If set to > 1, each XLA/GPU " "collective will execute multiple times (will yield incorrect results)")); + flag_list->push_back(tsl::Flag( + "xla_gpu_enable_reassociation_for_converted_ar", + bool_setter_for( + &DebugOptions::set_xla_gpu_enable_reassociation_for_converted_ar), + debug_options->xla_gpu_enable_reassociation_for_converted_ar(), + "Enable allreduce reassociation on allreduces that are converted to a " + "wider type. " + "The reassociated allreduce will be promoted to a wider-typed " + "allreduce.")); flag_list->push_back( tsl::Flag("xla_gpu_dump_llvmir", bool_setter_for(&DebugOptions::set_xla_gpu_dump_llvmir), @@ -764,10 +826,33 @@ void MakeDebugOptionsFlags(std::vector* flag_list, debug_options->xla_gpu_enable_cublaslt(), "Use cuBLASLt for GEMMs when possible.")); flag_list->push_back(tsl::Flag( - "xla_gpu_enable_cuda_graphs", - bool_setter_for(&DebugOptions::set_xla_gpu_enable_cuda_graphs), - debug_options->xla_gpu_enable_cuda_graphs(), - "Use CUDA graphs to execute XLA GPU executables when possible.")); + "xla_gpu_cuda_graph_level", + int32_setter_for(&DebugOptions::set_xla_gpu_cuda_graph_level), + debug_options->xla_gpu_cuda_graph_level(), + "Set CUDA graph level. 0 = off; 1 = capture fusions and memcpys; 2 = " + "capture convolutions and gemms; 3 = capture collectives.")); + flag_list->push_back(tsl::Flag( + "xla_gpu_cuda_graph_instantiation_threshold", + int32_setter_for( + &DebugOptions::set_xla_gpu_cuda_graph_instantiation_threshold), + debug_options->xla_gpu_cuda_graph_instantiation_threshold(), + "Instantiate a cuda graph after the number of times a captured function " + "is executed reaches the threshold.")); + flag_list->push_back(tsl::Flag( + "xla_gpu_cuda_graph_capture_threshold", + int32_setter_for(&DebugOptions::set_xla_gpu_cuda_graph_capture_threshold), + debug_options->xla_gpu_cuda_graph_capture_threshold(), + "Capture a region as a function to be launched as a cuda graph if the " + "number of moved instructions reaches this threshold.")); + + flag_list->push_back(tsl::Flag( + "xla_gpu_enable_persistent_temp_buffers", + bool_setter_for( + &DebugOptions::set_xla_gpu_enable_persistent_temp_buffers), + debug_options->xla_gpu_enable_persistent_temp_buffers(), + "Allocate temp buffers once during the first execution of an executable. " + "Reuse the allocated buffers in subsequent executions. Executables cannot" + " run concurrently if this is enabled.")); flag_list->push_back( tsl::Flag("xla_dump_disable_metadata", bool_setter_for(&DebugOptions::set_xla_dump_disable_metadata), @@ -834,7 +919,7 @@ void MakeDebugOptionsFlags(std::vector* flag_list, bool_setter_for(&DebugOptions::set_xla_cpu_strict_dot_conv_math), debug_options->xla_cpu_strict_dot_conv_math(), "By default, XLA:CPU will run fp16 dot/conv as fp32, as this is " - "generally (much) faster on our hardware. 
Set this flag to true to " "disable this behavior.")); flag_list->push_back(tsl::Flag( "xla_dump_latency_hiding_schedule", @@ -846,6 +931,31 @@ void MakeDebugOptionsFlags(std::vector* flag_list, bool_setter_for(&DebugOptions::set_xla_cpu_enable_mlir_tiling_and_fusion), debug_options->xla_cpu_enable_mlir_tiling_and_fusion(), "Enable MLIR tiling and fusion.")); + flag_list->push_back(tsl::Flag( + "xla_cpu_enable_mlir_fusion_outlining", + bool_setter_for(&DebugOptions::set_xla_cpu_enable_mlir_fusion_outlining), + debug_options->xla_cpu_enable_mlir_fusion_outlining(), + "Enable MLIR fusion outlining (to improve compile time).")); + flag_list->push_back(tsl::Flag( + "xla_cpu_enable_custom_matmul_tiling", + bool_setter_for(&DebugOptions::set_xla_cpu_enable_custom_matmul_tiling), + debug_options->xla_cpu_enable_custom_matmul_tiling(), + "Enable custom tiling given by M, K, N parameters.")); + flag_list->push_back(tsl::Flag( + "xla_cpu_matmul_tiling_m_dim", + int64_setter_for(&DebugOptions::set_xla_cpu_matmul_tiling_m_dim), + debug_options->xla_cpu_matmul_tiling_m_dim(), + "Custom tile size for matmul's M dimension.")); + flag_list->push_back(tsl::Flag( + "xla_cpu_matmul_tiling_n_dim", + int64_setter_for(&DebugOptions::set_xla_cpu_matmul_tiling_n_dim), + debug_options->xla_cpu_matmul_tiling_n_dim(), + "Custom tile size for matmul's N dimension.")); + flag_list->push_back(tsl::Flag( + "xla_cpu_matmul_tiling_k_dim", + int64_setter_for(&DebugOptions::set_xla_cpu_matmul_tiling_k_dim), + debug_options->xla_cpu_matmul_tiling_k_dim(), + "Custom tile size for matmul's K dimension.")); flag_list->push_back(tsl::Flag( "xla_cpu_enable_experimental_deallocation", bool_setter_for( @@ -858,6 +968,16 @@ void MakeDebugOptionsFlags(std::vector* flag_list, &DebugOptions::set_xla_gpu_enable_latency_hiding_scheduler), debug_options->xla_gpu_enable_latency_hiding_scheduler(), "Enable latency-hiding scheduler for XLA:GPU")); + flag_list->push_back(tsl::Flag( + "xla_gpu_pgle_profile_directory", + string_setter_for(&DebugOptions::set_xla_gpu_pgle_profile_directory), + debug_options->xla_gpu_pgle_profile_directory(), + "Directory for PGLE profiles in XLA:GPU")); + flag_list->push_back(tsl::Flag( + "xla_gpu_lhs_enable_gpu_async_tracker", + bool_setter_for(&DebugOptions::set_xla_gpu_lhs_enable_gpu_async_tracker), + debug_options->xla_gpu_lhs_enable_gpu_async_tracker(), + "Enable GPU async tracker for latency-hiding scheduler in XLA:GPU")); flag_list->push_back(tsl::Flag( "xla_partitioning_algorithm", setter_for_xla_partitioning_algorithm, DebugOptions::PartitioningAlgorithm_Name( diff --git a/tensorflow/compiler/xla/executable_run_options.h b/tensorflow/compiler/xla/executable_run_options.h index 8a0aa19dc06..58be6edef78 100644 --- a/tensorflow/compiler/xla/executable_run_options.h +++ b/tensorflow/compiler/xla/executable_run_options.h @@ -21,6 +21,9 @@ limitations under the License. 
#include #include +#include "absl/container/flat_hash_map.h" +#include "absl/status/statusor.h" + // These classes are forward declared so that ExecutableRunOptions can be linked // into an XLA-compiled binary without having to link all of the pointed-to // objects (e.g., for an ahead-of-time compiled CPU binary, the gpu tools don't @@ -38,18 +41,12 @@ struct ThreadPoolDevice; } // namespace Eigen namespace tsl { -class Status; -template -class StatusOr; template class AsyncValueRef; } // namespace tsl namespace xla { -using ::tsl::Status; // TENSORFLOW_STATUS_OK -using ::tsl::StatusOr; // TENSORFLOW_STATUS_OK - class DeviceAssignment; class ExecutionProfile; class Shape; @@ -96,19 +93,23 @@ using ThenExecuteFunction = // Callback for sending device buffer to a channel. Returned event will be // recorded on a `stream` once the send operation is completed and data was -// copied from the `src` memory. +// copied from the `src` memory. `frontend_attrs` contains frontend specific +// attributes for the send. using SendDeviceMemoryFunction = - std::function>( + std::function>( int64_t channel_id, stream_executor::Stream* stream, const Shape& shape, - const stream_executor::DeviceMemoryBase& src)>; + const stream_executor::DeviceMemoryBase& src, + const absl::flat_hash_map& frontend_attrs)>; // Callback for receiving device buffer from a channel. Returned event will be // recorded on a `stream` once the recv operation is completed and data was -// copied into the `dst` memory. +// copied into the `dst` memory. `frontend_attrs` contains frontend specific +// attributes for the receive. using RecvDeviceMemoryFunction = - std::function>( + std::function>( int64_t channel_id, stream_executor::Stream* stream, const Shape& shape, - stream_executor::DeviceMemoryBase* dst)>; + stream_executor::DeviceMemoryBase* dst, + const absl::flat_hash_map& frontend_attrs)>; // Class containing options for running a LocalExecutable. class ExecutableRunOptions { diff --git a/tensorflow/compiler/xla/experimental/conv_emitter/BUILD b/tensorflow/compiler/xla/experimental/conv_emitter/BUILD deleted file mode 100644 index 35b2e106800..00000000000 --- a/tensorflow/compiler/xla/experimental/conv_emitter/BUILD +++ /dev/null @@ -1,92 +0,0 @@ -# Description: -# MLIR-GPU-specific convolution in XLA service implementation. - -load("//tensorflow/tsl/platform:rules_cc.bzl", "cc_library") -load("//tensorflow/compiler/xla:xla.bzl", "xla_cc_test") - -package( - # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = [":friends"], - licenses = ["notice"], -) - -package_group( - name = "friends", - includes = ["//tensorflow/compiler/xla:friends"], -) - -# Filegroup used to collect source files for dependency checking. 
-filegroup( - name = "c_srcs", - data = glob([ - "**/*.cc", - "**/*.h", - ]), -) - -cc_library( - name = "conv_emitter", - srcs = ["conv_emitter.cc"], - hdrs = ["conv_emitter.h"], - deps = [ - ":conv_emitter_transforms", - "//tensorflow/compiler/xla:permutation_util", - "//tensorflow/compiler/xla:window_util", - "//tensorflow/compiler/xla/hlo/ir:hlo", - "//tensorflow/compiler/xla/service/llvm_ir:llvm_type_conversion_util", - "@com_google_absl//absl/types:span", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:AffineDialect", - "@llvm-project//mlir:AffineUtils", - "@llvm-project//mlir:ArithDialect", - "@llvm-project//mlir:FuncDialect", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:MemRefDialect", - "@llvm-project//mlir:TransformUtils", - ], -) - -cc_library( - name = "conv_emitter_transforms", - srcs = ["conv_emitter_transforms.cc"], - hdrs = ["conv_emitter_transforms.h"], - deps = [ - "//tensorflow/tsl/platform:logging", - "//tensorflow/tsl/platform:types", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/types:span", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:AffineDialect", - "@llvm-project//mlir:AffineUtils", - "@llvm-project//mlir:FuncDialect", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:TransformUtils", - ], -) - -xla_cc_test( - name = "conv_emitter_test", - srcs = ["conv_emitter_test.cc"], - deps = [ - ":conv_emitter", - "//tensorflow/compiler/xla/service:hlo_parser", - "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", - "//tensorflow/compiler/xla/tests:filecheck", - "//tensorflow/compiler/xla/tests:verified_hlo_module", - "//tensorflow/tsl/platform:test", - "//tensorflow/tsl/platform:test_main", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:AffineDialect", - "@llvm-project//mlir:AffineToStandard", - "@llvm-project//mlir:AllPassesAndDialects", - "@llvm-project//mlir:ArithDialect", - "@llvm-project//mlir:FuncDialect", - "@llvm-project//mlir:FuncToLLVM", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:MemRefDialect", - "@llvm-project//mlir:MemRefToLLVM", - "@llvm-project//mlir:Pass", - "@llvm-project//mlir:SCFToControlFlow", - "@llvm-project//mlir:Transforms", - ], -) diff --git a/tensorflow/compiler/xla/experimental/conv_emitter/conv_emitter.cc b/tensorflow/compiler/xla/experimental/conv_emitter/conv_emitter.cc deleted file mode 100644 index c5af2884e0a..00000000000 --- a/tensorflow/compiler/xla/experimental/conv_emitter/conv_emitter.cc +++ /dev/null @@ -1,608 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// This is an explorative prototype emitter for convolution using MLIR. -// This prototype is still under construction. -// TODO(timshen): Fix the documentation once it's implemented. -// -// Goals: -// * Autotune-able tiling. -// * Autotune-able memory accesses. -// * Autotune-able lowering logic (from a portable program to thread-oriented -// CUDA program). 
-// * Use milr::AffineExpr to analyze all accesses. It aims to algorithmically -// find memory access strategies for given input layouts and tiling configs. - -#include "tensorflow/compiler/xla/experimental/conv_emitter/conv_emitter.h" - -#include "absl/types/span.h" -#include "llvm/ADT/STLExtras.h" -#include "llvm/Support/raw_ostream.h" -#include "mlir/Dialect/Affine/IR/AffineOps.h" // from @llvm-project -#include "mlir/Dialect/Affine/LoopUtils.h" // from @llvm-project -#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project -#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/Dialect/MemRef/IR/MemRef.h" // from @llvm-project -#include "mlir/IR/AffineExpr.h" // from @llvm-project -#include "mlir/IR/AffineMap.h" // from @llvm-project -#include "mlir/IR/Builders.h" // from @llvm-project -#include "mlir/IR/BuiltinTypes.h" // from @llvm-project -#include "mlir/Transforms/RegionUtils.h" // from @llvm-project -#include "tensorflow/compiler/xla/experimental/conv_emitter/conv_emitter_transforms.h" -#include "tensorflow/compiler/xla/permutation_util.h" -#include "tensorflow/compiler/xla/service/llvm_ir/llvm_type_conversion_util.h" -#include "tensorflow/compiler/xla/window_util.h" - -namespace xla { -namespace experimental { -namespace { - -using mlir::OpBuilder; - -// Various extracted information for input shapes. -struct ShapeInfo { - // Buffer dimensions in the order of NCHW. - std::vector nchw_dimensions; - - // Buffer dimensions in the order of major to minor; - std::vector physical_dimensions; - - // The affine map that takes NCHW indices, and maps to the physical order. - mlir::AffineMap affine_map; - - mlir::Type element_type; -}; - -ShapeInfo GetShapeInfo(const Shape& shape, int64_t n_dim, int64_t c_dim, - absl::Span spatial_dims, - mlir::Builder builder) { - ShapeInfo shape_info; - - std::vector physical_to_logical( - shape.layout().minor_to_major().rbegin(), - shape.layout().minor_to_major().rend()); - - std::vector nchw_to_logical; - - nchw_to_logical.push_back(n_dim); - nchw_to_logical.push_back(c_dim); - for (int64_t dim : spatial_dims) { - nchw_to_logical.push_back(dim); - } - - for (int64_t dim : nchw_to_logical) { - shape_info.nchw_dimensions.push_back(shape.dimensions(dim)); - } - - for (int64_t dim : physical_to_logical) { - shape_info.physical_dimensions.push_back(shape.dimensions(dim)); - } - - std::vector affine_exprs; - // We want physical to nchw order. - for (int64_t dim : ComposePermutations(InversePermutation(nchw_to_logical), - physical_to_logical)) { - affine_exprs.push_back(builder.getAffineDimExpr(dim)); - } - - shape_info.affine_map = mlir::AffineMap::get( - /*dimCount=*/2 + spatial_dims.size(), /*symbolCount=*/0, affine_exprs, - builder.getContext()); - - shape_info.element_type = [&] { - switch (shape.element_type()) { - case xla::F16: - return builder.getF16Type(); - case xla::F32: - return builder.getF32Type(); - default: - break; - } - CHECK(false); - }(); - - return shape_info; -} - -void SetMemRef(mlir::Operation* op, mlir::Value memref) { - if (auto load = mlir::dyn_cast(op)) { - load.setMemRef(memref); - } else if (auto store = mlir::dyn_cast(op)) { - store.setMemRef(memref); - } else { - CHECK(false); - } -} - -// Hoist operations out of `where`. [begin_op, end_op) must be the first -// operations of their parent loop, and `where` must be an ancestor of that -// parent loop. -// -// It always preserves the semantics of the program, therefore it may modify the -// hoisted operations or add extra loops at the hoisted place. 
-mlir::Operation* HoistAndFix(llvm::iplist::iterator begin_op, - llvm::iplist::iterator end_op, - mlir::AffineForOp where) { - // All loops to hoist through. - llvm::SmallVector ancestors; - getPerfectlyNestedLoops(ancestors, where); - { - int i; - for (i = 0; i < ancestors.size(); i++) { - if (&ancestors[i].getBody()->front() == &*begin_op) { - break; - } - } - CHECK(i < ancestors.size()); - ancestors.resize(i + 1); - } - - std::vector ancestor_dimensions; - for (auto ancestor : ancestors) { - CHECK(IsSimpleLoop(ancestor)); - ancestor_dimensions.push_back( - ancestor.getUpperBoundMap().getSingleConstantResult()); - } - - if (auto alloc = mlir::dyn_cast(begin_op)) { - CHECK(std::next(begin_op) == end_op) - << "alloc() needs to be hoisted by its own"; - - OpBuilder builder(where); - mlir::MemRefType type = alloc.getType(); - CHECK(type.getLayout().isIdentity()); - ancestor_dimensions.insert(ancestor_dimensions.end(), - type.getShape().begin(), type.getShape().end()); - mlir::MemRefType new_type = - mlir::MemRefType::get(ancestor_dimensions, type.getElementType()); - auto new_alloc = builder.create( - builder.getUnknownLoc(), new_type); - - std::vector indvars; - for (auto ancestor : ancestors) { - indvars.push_back(ancestor.getInductionVar()); - } - for (auto& use : llvm::make_early_inc_range(alloc.getResult().getUses())) { - mlir::Operation* owner = use.getOwner(); - BoundAffineMap affine_map = GetBoundAffineMapFrom(owner); - affine_map.operands.insert(affine_map.operands.begin(), indvars.begin(), - indvars.end()); - CHECK(affine_map.affine_map.isIdentity()); - affine_map.affine_map = mlir::AffineMap::getMultiDimIdentityMap( - affine_map.operands.size(), builder.getContext()); - - mlir::Operation* new_op = - CloneWithNewAffineMap(owner, affine_map, OpBuilder(owner)); - SetMemRef(new_op, new_alloc); - owner->replaceAllUsesWith(new_op); - owner->erase(); - } - alloc.erase(); - return new_alloc; - } - - const bool any_op_is_loop_variant = [&] { - for (mlir::Operation& op : llvm::make_range(begin_op, end_op)) { - if (mlir::isa(op)) { - return true; - } - } - return false; - }(); - - if (any_op_is_loop_variant) { - auto builder = OpBuilder(where); - std::vector new_loops; - for (auto dim : ancestor_dimensions) { - auto where = - builder.create(builder.getUnknownLoc(), 0, dim); - new_loops.push_back(where); - builder = OpBuilder::atBlockTerminator(where.getBody()); - } - for (mlir::Operation& op : - llvm::make_early_inc_range(llvm::make_range(begin_op, end_op))) { - op.moveBefore(&new_loops.back().getBody()->back()); - } - CHECK_EQ(ancestors.size(), new_loops.size()); - for (int i = 0; i < ancestors.size(); i++) { - replaceAllUsesInRegionWith(ancestors[i].getInductionVar(), - new_loops[i].getInductionVar(), - new_loops.back().getRegion()); - } - return new_loops.front(); - } - CHECK(false); -} - -mlir::Operation* HoistAndFix(mlir::Operation* op, mlir::AffineForOp where) { - return HoistAndFix(op->getIterator(), std::next(op->getIterator()), where); -} - -struct InitialMlirConvAnchors { - std::vector cartesian_product_loops; - std::vector reduction_loops; - mlir::memref::AllocOp output_acc; -}; - -// Return the following IR with the anchors set to corresponding operations. -// for (cartesian loops...) { -// %output_acc = alloc() : memref(f32) -// output_acc[] = 0 -// for (reduction loops...) { -// output_acc[] += input[...] * filter[...] -// } -// output[...] 
= output_acc[] -// } -StatusOr CreateNaiveMlirConv( - mlir::Value input, mlir::Value filter, mlir::Value output, - const ShapeInfo& input_shape_info, const ShapeInfo& filter_shape_info, - const ShapeInfo& output_shape_info, const Window& window, - OpBuilder builder) { - CHECK(input_shape_info.element_type == builder.getF16Type()); - CHECK(filter_shape_info.element_type == builder.getF16Type()); - CHECK(output_shape_info.element_type == builder.getF16Type()); - - auto location = mlir::UnknownLoc::get(builder.getContext()); - - std::vector cartesian_product_loops = - CreateNestedSimpleLoops(output_shape_info.nchw_dimensions, builder); - - builder = - OpBuilder::atBlockTerminator(cartesian_product_loops.back().getBody()); - - auto output_acc = builder.create( - location, mlir::MemRefType::get({}, builder.getF32Type())); - - builder.create( - location, - builder.create( - location, mlir::FloatAttr::get(builder.getF32Type(), 0)), - output_acc, llvm::ArrayRef()); - - std::vector reduction_loops; - reduction_loops = CreateNestedSimpleLoops( - absl::MakeSpan(filter_shape_info.nchw_dimensions).subspan(1), builder); - - mlir::AffineForOp loop_n = cartesian_product_loops[0]; - mlir::AffineForOp loop_o = cartesian_product_loops[1]; - mlir::AffineForOp loop_c = reduction_loops[0]; - - std::vector output_spatial_indvars; - for (auto loop : absl::MakeSpan(cartesian_product_loops).subspan(2)) { - output_spatial_indvars.push_back(loop.getInductionVar()); - } - std::vector filter_spatial_indvars; - for (auto loop : absl::MakeSpan(reduction_loops).subspan(1)) { - filter_spatial_indvars.push_back(loop.getInductionVar()); - } - int num_spatial_dims = output_spatial_indvars.size(); - CHECK_EQ(num_spatial_dims, filter_spatial_indvars.size()); - - builder = OpBuilder::atBlockTerminator(reduction_loops.back().getBody()); - - mlir::Value loaded_input = [&] { - std::vector input_indices; - input_indices.push_back(builder.getAffineDimExpr(0)); - input_indices.push_back(builder.getAffineDimExpr(1)); - - // For spatial dimensions, generate input_index * stride + filter_index - - // left_pad - // - // TODO(timshen): guard out-of-bound loads and stores brought by padding. 
- for (int i = 0; i < num_spatial_dims; i++) { - const WindowDimension& window_dim = window.dimensions(i); - input_indices.push_back( - builder.getAffineDimExpr(i + 2) * window_dim.stride() + - builder.getAffineDimExpr(2 + num_spatial_dims + i) - - window_dim.padding_low()); - } - std::vector input_vars; - input_vars.push_back(loop_n.getInductionVar()); - input_vars.push_back(loop_c.getInductionVar()); - input_vars.insert(input_vars.end(), output_spatial_indvars.begin(), - output_spatial_indvars.end()); - input_vars.insert(input_vars.end(), filter_spatial_indvars.begin(), - filter_spatial_indvars.end()); - - return builder.create( - location, builder.getF32Type(), - builder.createOrFold( - location, input, - mlir::AffineMap(input_shape_info.affine_map) - .compose(mlir::AffineMap::get( - /*dimCount=*/2 + num_spatial_dims * 2, - /*symbolCount=*/0, input_indices, builder.getContext())), - input_vars)); - }(); - - mlir::Value loaded_filter = [&] { - std::vector filter_vars; - filter_vars.push_back(loop_o.getInductionVar()); - filter_vars.push_back(loop_c.getInductionVar()); - filter_vars.insert(filter_vars.end(), filter_spatial_indvars.begin(), - filter_spatial_indvars.end()); - - return builder.create( - location, builder.getF32Type(), - builder.createOrFold( - location, filter, filter_shape_info.affine_map, filter_vars)); - }(); - - auto accum_load_op = - builder.createOrFold(location, output_acc); - builder.createOrFold( - location, - builder.create( - location, accum_load_op, - builder.create(location, loaded_input, - loaded_filter)), - output_acc, llvm::ArrayRef()); - - builder.setInsertionPointAfter(reduction_loops[0]); - { - std::vector output_vars; - output_vars.push_back(loop_n.getInductionVar()); - output_vars.push_back(loop_o.getInductionVar()); - output_vars.insert(output_vars.end(), output_spatial_indvars.begin(), - output_spatial_indvars.end()); - builder.createOrFold( - location, - builder.create( - location, builder.getF16Type(), - builder.createOrFold(location, output_acc)), - output, output_shape_info.affine_map, output_vars); - } - - return InitialMlirConvAnchors{cartesian_product_loops, reduction_loops, - output_acc}; -} - -// Contains the following pattern with anchors: -// for (cartesian loops...) { -// %output_acc = alloc() : memref(..., f32) -// for (reduction loops...) { -// for (tiled cartesian loops...) { -// output_acc[...] = 0 -// } -// for (tiled cartesian loops...) { -// for (reduction loops...) { -// output_acc[] += input[...] * filter[...] -// } -// } -// for (tiled cartesian loops...) { -// output[...] = output_acc[...] -// } -// } -// } -struct TransformedMlirConvAnchors { - std::vector cartesian_product_loops; - std::vector reduction_loops; -}; - -StatusOr TransformMlirConv( - InitialMlirConvAnchors anchors) { - std::vector cartesian_product_loops = - anchors.cartesian_product_loops; - std::vector reduction_loops = anchors.reduction_loops; - mlir::memref::AllocOp output_acc = anchors.output_acc; - - // TODO(timshen): consider using pattern matchers for transformations - // - // Initial form: - // for (cartesian loops...) { - // %output_acc = alloc() : memref(f32) - // output_acc[] = 0 - // for (reduction loops...) { - // output_acc[] += input[...] * filter[...] - // } - // output[...] = output_acc[] - // } - - // Tile cartesian loops to: - // for (cartesian loops...) { - // for (tiled cartesian loops...) { - // %output_acc = alloc() : memref(f32) - // output_acc[] = 0 - // for (reduction loops...) { - // output_acc[] += input[...] * filter[...] 
- // } - // output[...] = output_acc[] - // } - // } - TileLoop(reduction_loops[0], 4, reduction_loops.back()); - - std::vector tiled_cartesian_loops; - tiled_cartesian_loops.push_back( - TileLoop(cartesian_product_loops[1], 32, cartesian_product_loops.back())); - - tiled_cartesian_loops.push_back(TileLoop(cartesian_product_loops.back(), 16, - tiled_cartesian_loops.back())); - - // Two hoist operations to interleave the allocation, computation, and - // writebacks to output_acc: - // After first hoist: - // for (cartesian loops...) { - // %output_acc = alloc() : memref(..., f32) - // for (tiled cartesian loops...) { - // output_acc[...] = 0 - // for (reduction loops...) { - // output_acc[...] += input[...] * filter[...] - // } - // output[...] = output_acc[...] - // } - // } - output_acc = llvm::cast( - HoistAndFix(output_acc, tiled_cartesian_loops.front())); - - // Hoist everything before reduction loops (aka zero initializations of - // output_acc): - // for (cartesian loops...) { - // %output_acc = alloc() : memref(..., f32) - // for (tiled cartesian loops...) { - // output_acc[...] = 0 - // } - // for (tiled cartesian loops...) { - // for (reduction loops...) { - // output_acc[...] += input[...] * filter[...] - // } - // output[...] = output_acc[...] - // } - // } - HoistAndFix(tiled_cartesian_loops.back().getBody()->begin(), - reduction_loops.front().getOperation()->getIterator(), - tiled_cartesian_loops.front()); - - // Now hoist all reduction loops outside of tiled cartesian loops. - // Notice that HoistAndFix automatically add a new set of tiled cartesian - // loops for hoisted reduction loops to keep the semantics correct. - // - // After second hoist: - // for (cartesian loops...) { - // %output_acc = alloc() : memref(..., f32) - // for (tiled cartesian loops...) { - // output_acc[...] = 0 - // } - // for (tiled cartesian loops...) { - // for (reduction loops...) { - // output_acc[] += input[...] * filter[...] - // } - // } // compute loop - // for (tiled cartesian loops...) { - // output[...] = output_acc[...] - // } - // } - { - auto compute_loop = llvm::cast( - HoistAndFix(reduction_loops.front(), tiled_cartesian_loops[0])); - - // Fix tiled_cartesian_loops to make them point to the tiled compute loops, - // not the writeback loops to output buffer. - llvm::SmallVector all_loops; - getPerfectlyNestedLoops(all_loops, compute_loop); - absl::c_copy_n(all_loops, tiled_cartesian_loops.size(), - tiled_cartesian_loops.data()); - } - - // After exchanging tiled cartesian compute loops with reduction loops: - // for (cartesian loops...) { - // %output_acc = alloc() : memref(..., f32) - // for (tiled cartesian loops...) { - // output_acc[...] = 0 - // } - // for (reduction loops...) { - // for (tiled cartesian loops...) { - // output_acc[] += input[...] * filter[...] - // } - // } - // for (tiled cartesian loops...) { - // output[...] = output_acc[...] - // } - // } - // - // ...so that later tiled cartesian loops (with computations in it) can be - // replaced by CUDA MMA instructions. 
- { - std::vector loops; - loops.insert(loops.end(), tiled_cartesian_loops.begin(), - tiled_cartesian_loops.end()); - loops.insert(loops.end(), reduction_loops.begin(), reduction_loops.end()); - SinkPerfectlyNestedLoops(loops, tiled_cartesian_loops.size()); - } - return TransformedMlirConvAnchors{cartesian_product_loops, reduction_loops}; -} - -} // namespace - -StatusOr EmitConvolutionForwardAsMlir( - HloInstruction* conv, absl::string_view function_name, - mlir::MLIRContext* context) { - OpBuilder builder(context); - - const auto& dim_nums = conv->convolution_dimension_numbers(); - ShapeInfo input_shape_info = - GetShapeInfo(conv->operand(0)->shape(), dim_nums.input_batch_dimension(), - dim_nums.input_feature_dimension(), - dim_nums.input_spatial_dimensions(), builder); - - ShapeInfo filter_shape_info = GetShapeInfo( - conv->operand(1)->shape(), dim_nums.kernel_output_feature_dimension(), - dim_nums.kernel_input_feature_dimension(), - dim_nums.kernel_spatial_dimensions(), builder); - - ShapeInfo output_shape_info = GetShapeInfo( - conv->shape().tuple_shapes(0), dim_nums.output_batch_dimension(), - dim_nums.output_feature_dimension(), dim_nums.output_spatial_dimensions(), - builder); - - auto function = mlir::func::FuncOp::create( - mlir::UnknownLoc::get(builder.getContext()), - llvm_ir::AsStringRef(function_name), - builder.getFunctionType( - {mlir::MemRefType::get(output_shape_info.physical_dimensions, - output_shape_info.element_type, - mlir::AffineMap()), - mlir::MemRefType::get(input_shape_info.physical_dimensions, - input_shape_info.element_type, - mlir::AffineMap()), - mlir::MemRefType::get(filter_shape_info.physical_dimensions, - filter_shape_info.element_type, - mlir::AffineMap())}, - {})); - - auto* entry_block = function.addEntryBlock(); - builder.setInsertionPointToStart(entry_block); - builder.create(builder.getUnknownLoc()); - builder.setInsertionPointToStart(entry_block); - - mlir::Value input = entry_block->getArgument(1); - mlir::Value filter = entry_block->getArgument(2); - mlir::Value output = entry_block->getArgument(0); - - TF_RETURN_IF_ERROR(ConvIsImplemented(conv)); - - TF_ASSIGN_OR_RETURN( - InitialMlirConvAnchors initial_anchors, - CreateNaiveMlirConv(input, filter, output, input_shape_info, - filter_shape_info, output_shape_info, conv->window(), - builder)); - - TF_ASSIGN_OR_RETURN(TransformedMlirConvAnchors transformed_anchors, - TransformMlirConv(initial_anchors)); - - // TODO(timshen): Implement a transformation that collects loads to a given - // buffer, create a local alloc() for the accessed part, redirects all loads - // and stores to that local alloc(), and create code to initialize / - // writeback the local alloc() if needed. - - // TODO(timshen): Implement CUDA-specific lowering. 
- - return function; -} - -Status ConvIsImplemented(const HloInstruction* conv) { - if (conv->feature_group_count() != 1 || conv->batch_group_count() != 1) { - return Unimplemented("group count is not implemented."); - } - if (window_util::HasWindowReversal(conv->window())) { - return Unimplemented("Window reversal is not implemented."); - } - if (window_util::HasDilation(conv->window())) { - return Unimplemented("Dilation is not implemented."); - } - return ::tsl::OkStatus(); -} - -} // namespace experimental -} // namespace xla diff --git a/tensorflow/compiler/xla/experimental/conv_emitter/conv_emitter.h b/tensorflow/compiler/xla/experimental/conv_emitter/conv_emitter.h deleted file mode 100644 index a380800b2f7..00000000000 --- a/tensorflow/compiler/xla/experimental/conv_emitter/conv_emitter.h +++ /dev/null @@ -1,49 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_EXPERIMENTAL_CONV_EMITTER_CONV_EMITTER_H_ -#define TENSORFLOW_COMPILER_XLA_EXPERIMENTAL_CONV_EMITTER_CONV_EMITTER_H_ - -#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/IR/BuiltinOps.h" // from @llvm-project -#include "tensorflow/compiler/xla/hlo/ir/hlo_instruction.h" - -namespace xla { -namespace experimental { - -// Builds MLIR using custom_call that represents a foward convolution. -// -// The generated function has the following signature: -// func @(%output: memref, -// %input: memref, -// %filter: memref) { ... } -// -// Note that the custom_call is XLA/GPU-specific, as it calls into cuDNN's -// forward convolution. However, here we are building a MLIR custom emitter, and -// we are not calling into cuDNN. We just want to borrow the HLO representation -// that already exists in XLA/GPU backend. -// -// `input`, `filter`, `output` are convolution inputs. -StatusOr EmitConvolutionForwardAsMlir( - HloInstruction* conv, absl::string_view function_name, - mlir::MLIRContext* context); - -// Returns OkStatus() if convolution can be implemented by this emitter. -Status ConvIsImplemented(const HloInstruction* conv); - -} // namespace experimental -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_EXPERIMENTAL_CONV_EMITTER_CONV_EMITTER_H_ diff --git a/tensorflow/compiler/xla/experimental/conv_emitter/conv_emitter_test.cc b/tensorflow/compiler/xla/experimental/conv_emitter/conv_emitter_test.cc deleted file mode 100644 index c0ab3e283a1..00000000000 --- a/tensorflow/compiler/xla/experimental/conv_emitter/conv_emitter_test.cc +++ /dev/null @@ -1,146 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/xla/experimental/conv_emitter/conv_emitter.h" - -#include - -#include "mlir/Conversion/AffineToStandard/AffineToStandard.h" // from @llvm-project -#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h" // from @llvm-project -#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" // from @llvm-project -#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" // from @llvm-project -#include "mlir/Dialect/Affine/IR/AffineOps.h" // from @llvm-project -#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project -#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/Dialect/MemRef/IR/MemRef.h" // from @llvm-project -#include "mlir/IR/BuiltinOps.h" // from @llvm-project -#include "mlir/IR/Location.h" // from @llvm-project -#include "mlir/IR/MLIRContext.h" // from @llvm-project -#include "mlir/Pass/Pass.h" // from @llvm-project -#include "mlir/Pass/PassManager.h" // from @llvm-project -#include "mlir/Transforms/Passes.h" // from @llvm-project -#include "tensorflow/compiler/xla/service/hlo_parser.h" -#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" -#include "tensorflow/compiler/xla/tests/filecheck.h" -#include "tensorflow/compiler/xla/tests/verified_hlo_module.h" -#include "tensorflow/tsl/platform/test.h" - -namespace xla { -namespace experimental { -namespace { - -std::string CompileHloConvAndGetMlir(absl::string_view hlo_text) { - xla::HloModuleConfig hlo_config; - VerifiedHloModule hlo_module( - "Conv", hlo_config, /*verifier_layout_sensitive=*/false, - /*allow_mixed_precision_in_hlo_verifier=*/true, - /*shape_size_function=*/ShapeUtil::ByteSizeOfElements); - TF_CHECK_OK(hlo_module.ParseHloStringAndVerifyModule(hlo_text)); - xla::HloInstruction* conv = - hlo_module.entry_computation()->root_instruction(); - - mlir::MLIRContext context; - context.loadDialect(); - mlir::OwningOpRef mlir_module( - mlir::ModuleOp::create(mlir::UnknownLoc::get(&context))); - - mlir::func::FuncOp function = - EmitConvolutionForwardAsMlir(conv, "Conv", &context).value(); - - mlir_module->push_back(function); - (void)mlir_module->verifyInvariants(); - - std::string mlir_text = llvm_ir::DumpToString(function); - VLOG(1) << mlir_text; - - { - mlir::PassManager pm(mlir_module->getContext()); - pm.addPass(mlir::createLowerAffinePass()); - pm.addPass(mlir::createConvertSCFToCFPass()); - pm.addPass(mlir::createFinalizeMemRefToLLVMConversionPass()); - pm.addPass(mlir::createConvertFuncToLLVMPass()); - CHECK(mlir::succeeded(pm.run(*mlir_module))); - } - - return mlir_text; -} - -// TODO(timshen): integrate this with mlir's testing infrastructure. 
-TEST(ConvEmitterTest, TestDefault) { - std::string hlo_text = R"(HloModule TestModule -ENTRY %TestComputation { - %param_0 = f16[128,4,224,224]{1,3,2,0} parameter(0) - %param_1 = f16[7,7,64,4]{3,1,0,2} parameter(1) - ROOT %custom-call.1 = (f16[128,64,112,112]{1,3,2,0}, u8[0]{0}) custom-call(%param_0, %param_1), window={size=7x7 stride=2x2 pad=3_3x3_3}, dim_labels=bf01_01oi->bf01, custom_call_target="__cudnn$convForward", backend_config="{conv_result_scale:1}" -})"; - - std::string expected_mlir_pattern = - R"( -CHECK: func @Conv(%arg0: memref<128x112x112x64xf16>, %arg1: memref<128x224x224x4xf16>, %arg2: memref<64x7x7x4xf16>) { -CHECK-NEXT: affine.for %arg3 = 0 to 128 { -CHECK-NEXT: affine.for %arg4 = 0 to 2 { -CHECK-NEXT: affine.for %arg5 = 0 to 112 { -CHECK-NEXT: affine.for %arg6 = 0 to 7 { -CHECK-NEXT: %[[ALLOC:.*]] = memref.alloc() : memref<32x16xf32> -CHECK-NEXT: affine.for %arg7 = 0 to 32 { -CHECK-NEXT: affine.for %arg8 = 0 to 16 { -CHECK-NEXT: %cst = arith.constant 0.000000e+00 : f32 -CHECK-NEXT: affine.store %cst, %[[ALLOC]][%arg7, %arg8] : memref<32x16xf32> -CHECK-NEXT: } -CHECK-NEXT: } -CHECK-NEXT: affine.for %arg7 = 0 to 1 { -CHECK-NEXT: affine.for %arg8 = 0 to 7 { -CHECK-NEXT: affine.for %arg9 = 0 to 7 { -CHECK-NEXT: affine.for %arg10 = 0 to 32 { -CHECK-NEXT: affine.for %arg11 = 0 to 16 { -CHECK-NEXT: affine.for %arg12 = 0 to 4 { -CHECK-NEXT: %[[LOAD0:.*]] = affine.load %arg1[%arg3, %arg5 * 2 + %arg8 - 3, (%arg6 * 16 + %arg11) * 2 + %arg9 - 3, %arg7 * 4 + %arg12] : memref<128x224x224x4xf16> -CHECK-NEXT: %[[EXT0:.*]] = arith.extf %[[LOAD0]] : f16 to f32 -CHECK-NEXT: %[[LOAD1:.*]] = affine.load %arg2[%arg4 * 32 + %arg10, %arg8, %arg9, %arg7 * 4 + %arg12] : memref<64x7x7x4xf16> -CHECK-NEXT: %[[EXT1:.*]] = arith.extf %[[LOAD1]] : f16 to f32 -CHECK-NEXT: %[[LOAD2:.*]] = affine.load %[[ALLOC]][%arg10, %arg11] : memref<32x16xf32> -CHECK-NEXT: %[[MUL:.*]] = arith.mulf %[[EXT0]], %[[EXT1]] : f32 -CHECK-NEXT: %[[ADD:.*]] = arith.addf %[[LOAD2]], %[[MUL]] : f32 -CHECK-NEXT: affine.store %[[ADD]], %[[ALLOC]][%arg10, %arg11] : memref<32x16xf32> -CHECK-NEXT: } -CHECK-NEXT: } -CHECK-NEXT: } -CHECK-NEXT: } -CHECK-NEXT: } -CHECK-NEXT: } -CHECK-NEXT: affine.for %arg7 = 0 to 32 { -CHECK-NEXT: affine.for %arg8 = 0 to 16 { -CHECK-NEXT: %[[LOAD:.*]] = affine.load %[[ALLOC]][%arg7, %arg8] : memref<32x16xf32> -CHECK-NEXT: %[[TRUNC:.*]] = arith.truncf %[[LOAD]] : f32 to f16 -CHECK-NEXT: affine.store %[[TRUNC]], %arg0[%arg3, %arg5, %arg6 * 16 + %arg8, %arg4 * 32 + %arg7] : memref<128x112x112x64xf16> -CHECK-NEXT: } -CHECK-NEXT: } -CHECK-NEXT: } -CHECK-NEXT: } -CHECK-NEXT: } -CHECK-NEXT: } -CHECK-NEXT: return -CHECK-NEXT: } -)"; - - EXPECT_TRUE( - RunFileCheck(CompileHloConvAndGetMlir(hlo_text), expected_mlir_pattern) - .value()); -} - -} // namespace -} // namespace experimental -} // namespace xla diff --git a/tensorflow/compiler/xla/experimental/conv_emitter/conv_emitter_transforms.cc b/tensorflow/compiler/xla/experimental/conv_emitter/conv_emitter_transforms.cc deleted file mode 100644 index 91268062959..00000000000 --- a/tensorflow/compiler/xla/experimental/conv_emitter/conv_emitter_transforms.cc +++ /dev/null @@ -1,153 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/xla/experimental/conv_emitter/conv_emitter_transforms.h" - -#include - -#include "absl/algorithm/container.h" -#include "llvm/ADT/StringRef.h" -#include "mlir/Dialect/Affine/IR/AffineOps.h" // from @llvm-project -#include "mlir/Dialect/Affine/LoopUtils.h" // from @llvm-project -#include "tensorflow/tsl/platform/logging.h" - -namespace xla { -namespace experimental { - -using mlir::OpBuilder; - -BoundAffineMap GetBoundAffineMapFrom(mlir::Operation* op) { - if (auto load = mlir::dyn_cast(op)) { - return {load.getAffineMap(), - std::vector(load.getMapOperands().begin(), - load.getMapOperands().end())}; - } else if (auto store = mlir::dyn_cast(op)) { - return {store.getAffineMap(), - std::vector(store.getMapOperands().begin(), - store.getMapOperands().end())}; - } else { - CHECK(false); - } -} - -mlir::Operation* CloneWithNewAffineMap(mlir::Operation* op, - BoundAffineMap new_affine, - OpBuilder builder) { - if (auto load = mlir::dyn_cast(op)) { - return builder.create( - builder.getUnknownLoc(), load.getMemRef(), new_affine.affine_map, - new_affine.operands); - } else if (auto store = mlir::dyn_cast(op)) { - return builder.create( - builder.getUnknownLoc(), store.getValueToStore(), store.getMemRef(), - new_affine.affine_map, new_affine.operands); - } else { - CHECK(false); - } -} - -bool IsSimpleLoop(mlir::AffineForOp loop) { - return loop.getLowerBoundMap().isSingleConstant() && - loop.getLowerBoundMap().getSingleConstantResult() == 0 && - loop.getStep() == 1 && loop.getUpperBoundMap().getNumResults() == 1 && - std::next(loop.getRegion().begin()) == loop.getRegion().end(); -} - -std::vector CreateNestedSimpleLoops( - absl::Span upper_bounds, OpBuilder builder) { - std::vector loops; - loops.reserve(upper_bounds.size()); - for (int64_t dim : upper_bounds) { - auto loop = - builder.create(builder.getUnknownLoc(), 0, dim); - loops.push_back(loop); - builder = OpBuilder::atBlockTerminator(loop.getBody()); - } - return loops; -} - -void SetBoundForSimpleLoop(mlir::AffineForOp loop, mlir::AffineExpr new_bound, - OpBuilder builder) { - CHECK(IsSimpleLoop(loop)); - - loop.setUpperBoundMap(mlir::AffineMap::get( - loop.getUpperBoundMap().getNumDims(), - loop.getUpperBoundMap().getNumSymbols(), {new_bound})); -} - -mlir::AffineForOp TileLoop(mlir::AffineForOp loop, int64_t size, - mlir::AffineForOp target) { - CHECK(IsSimpleLoop(loop)); - CHECK(IsSimpleLoop(target)); - { - llvm::SmallVector all_loops; - getPerfectlyNestedLoops(all_loops, loop); - CHECK(absl::c_linear_search(all_loops, target)); - } - - auto builder = OpBuilder::atBlockTerminator(target.getBody()); - - auto inner_loop = - builder.create(builder.getUnknownLoc(), 0, size); - { - auto& inner_operations = inner_loop.getBody()->getOperations(); - auto& target_operations = target.getBody()->getOperations(); - - inner_operations.splice(inner_operations.begin(), target_operations, - target_operations.begin(), - std::prev(target_operations.end(), 2)); - - mlir::AffineExpr length = loop.getUpperBoundMap().getResult(0); - CHECK_EQ(0, 
length.cast().getValue() % size); - SetBoundForSimpleLoop(loop, length.ceilDiv(size), builder); - } - - for (auto& use : - llvm::make_early_inc_range(loop.getInductionVar().getUses())) { - mlir::Operation* owner = use.getOwner(); - BoundAffineMap affine_map = GetBoundAffineMapFrom(owner); - unsigned new_dim = affine_map.operands.size(); - affine_map.operands.push_back(inner_loop.getInductionVar()); - std::vector replacements; - for (int i = 0; i < affine_map.affine_map.getNumDims(); i++) { - if (affine_map.operands[i] == loop.getInductionVar()) { - replacements.push_back(builder.getAffineDimExpr(i) * size + - builder.getAffineDimExpr(new_dim)); - } else { - replacements.push_back(builder.getAffineDimExpr(i)); - } - } - affine_map.affine_map = affine_map.affine_map.replaceDimsAndSymbols( - replacements, {}, affine_map.operands.size(), 0); - auto new_op = CloneWithNewAffineMap(owner, affine_map, OpBuilder(owner)); - owner->replaceAllUsesWith(new_op); - owner->erase(); - } - return inner_loop; -} - -void SinkPerfectlyNestedLoops(llvm::MutableArrayRef loops, - int rotate_amount) { - CHECK_GE(rotate_amount, 0); - std::vector permutation(loops.size()); - std::iota(permutation.begin(), permutation.end(), unsigned(0)); - std::rotate(permutation.begin(), - permutation.begin() + loops.size() - rotate_amount, - permutation.end()); - mlir::permuteLoops(loops, permutation); -} - -} // namespace experimental -} // namespace xla diff --git a/tensorflow/compiler/xla/experimental/conv_emitter/conv_emitter_transforms.h b/tensorflow/compiler/xla/experimental/conv_emitter/conv_emitter_transforms.h deleted file mode 100644 index 97c44daa52f..00000000000 --- a/tensorflow/compiler/xla/experimental/conv_emitter/conv_emitter_transforms.h +++ /dev/null @@ -1,102 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_EXPERIMENTAL_CONV_EMITTER_CONV_EMITTER_TRANSFORMS_H_ -#define TENSORFLOW_COMPILER_XLA_EXPERIMENTAL_CONV_EMITTER_CONV_EMITTER_TRANSFORMS_H_ - -#include "absl/types/span.h" -#include "mlir/Dialect/Affine/IR/AffineOps.h" // from @llvm-project -#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/IR/Operation.h" // from @llvm-project -#include "tensorflow/tsl/platform/types.h" - -namespace xla { -namespace experimental { - -struct BoundAffineMap { - mlir::AffineMap affine_map; - std::vector operands; -}; - -BoundAffineMap GetBoundAffineMapFrom(mlir::Operation* op); -mlir::Operation* CloneWithNewAffineMap(mlir::Operation* op, - BoundAffineMap new_affine, - mlir::OpBuilder builder); - -bool IsSimpleLoop(mlir::AffineForOp loop); -std::vector CreateNestedSimpleLoops( - absl::Span upper_bounds, mlir::OpBuilder builder); -void SetBoundForSimpleLoop(mlir::AffineForOp loop, mlir::AffineExpr new_bound, - mlir::OpBuilder builder); - -// Tile a loop with trip count N by `size`. 
For now, N has to be a multiple of -// size, but later this constraint will be removed. -// -// The major loop (with trip count N / size) stays as-is, while the minor loop -// (with trip count `size`) will take over the body of `target`, and be placed -// as the new body of `target`. -// -// `target` has to be within the same "perfectly nested loop group" as `loop`. -// See the documentation for mlir::getPerfectlyNestedLoops. -// -// Example: -// Before tiling `loop` with tile size X: -// for (loop in N) -// for (unrelated_loop in ...) -// for (target in ...) -// // pass loop into affine maps -// After: -// for (loop in N / X) -// for (unrelated_loop in ...) -// for (target in ...) -// for (tiled_loop in X) -// // rewrite all affine exprs from loop to `loop * X + tiled_loop`. -// -// Design note: -// TileLoop is different from mlir::tile. At the moment, mlir::tile is not well -// documented about the exact tiling semantics, but the observed behavior is: -// for (i from 0 to N) -// for (unrelated_loop in ...) -// for (target in ...) -// // pass i into affine maps -// => -// for (i from 0 to N, step = X) -// for (unrelated_loop in ...) -// for (target in ...) -// for (j from i to min(i + X, N), step = 1) -// // pass j into affine maps -// -// There are two differences between mlir::tile and TileLoop: -// * TileLoop always puts the tiling logic "stepping" logic into AffineExprs. -// With that all index calculation is done in AffineExprs and easier to -// analyze in a single place. -// * TileLoop doesn't plan to use max() and min() to resolve the issue when -// N % X != 0. max() and min() are not representable in AffineExprs. -// TODO(timshen): support the case where N % X != 0. -// -// TODO(timshen): consider the possibility to reuse mlir::tile's logic to -// achieve the same goal. -mlir::AffineForOp TileLoop(mlir::AffineForOp loop, int64_t size, - mlir::AffineForOp target); - -// Sinks a segment of perfectly nested loops to the bottom. It implements this -// by rotating the loop nest by rotate_amount. -void SinkPerfectlyNestedLoops(llvm::MutableArrayRef loops, - int rotate_amount); - -} // namespace experimental -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_EXPERIMENTAL_CONV_EMITTER_CONV_EMITTER_TRANSFORMS_H_ diff --git a/tensorflow/compiler/xla/experimental/conv_emitter/g3doc/conv_emitter.md b/tensorflow/compiler/xla/experimental/conv_emitter/g3doc/conv_emitter.md deleted file mode 100644 index 6151357372d..00000000000 --- a/tensorflow/compiler/xla/experimental/conv_emitter/g3doc/conv_emitter.md +++ /dev/null @@ -1,324 +0,0 @@ -# Convolution Emitter - -## Context - -This is a doc that describes a set of patches that are still under review. -TODO(timshen): Change once all patches are checked in. - -The convolution emitter is a prototype with the following goals: - -* The top priority is performance. -* It supports arbitrarily sophisticated layouts. -* It supports platform-specific high-performance instructions. -* It is as portable as possible. -* It enables fusion support in the future. - -## Current Design - -### Overview - -The prototype consists of the following components: - -* The emitter currently focuses on NVIDIA Volta architecture and N(C/4)HW4 - layout. -* An MLIR-based emitter. It takes a set of tuning parameters and a convolution - configuration, then produces a NVVM device function. -* An autotuner, which generates tuning parameters given a convolution - configuration. 
-* A test framework, which executes the generated device function with random
-  inputs and compares the result against cuDNN.
-
-### The Emitter - Naive Implementation
-
-The emitter starts with a hand-built, naive implementation that looks like the
-following ResNet first-layer convolution (pseudocode):
-
-```mlir
-func @Conv(%input : memref<128x1x224x224xvector<4xf16>>,
-           %filter : memref<64x1x7x7xvector<4xf16>>,
-           %output : memref<128x64x224x224xf16>) {
-  affine.parallel (%n, %o, %oh, %ow) = 0 to 128, 0 to 64, 0 to 112, 0 to 112 {
-    %acc = alloc() : memref<f32>
-    affine.store 0, %acc[]
-    affine.for (%c, %fh, %fw) = 0 to 1, 0 to 7, 0 to 7 {
-      %a = affine.padded.load %input[%n, %c, %oh * 2 + %fh - 3, %ow * 2 + %fw - 3]
-      %b = affine.load %filter[%o, %c, %fh, %fw]
-      %old = affine.load %acc[]
-      %d = std.fpext %a to vector<4xf32>
-      %e = std.fpext %b to vector<4xf32>
-      %f = std.multiply %d, %e
-      %g = "reduce" %f
-      %new = %g + %old
-      affine.store %new, %acc[]
-    }
-    %v = affine.load %acc[]
-    affine.store %v, %output[%n, %o, %oh, %ow]
-  }
-}
-```
-
-A few extensions are used in the example above:
-
-* affine.padded.load allows out-of-bounds access, in which case the result is
-  always 0.
-* The "reduce" operation produces the sum of elements in a vector.
-
-Also notice that the input element type is vector<4xf16> only because the
-current implementation does so. A MemRef with <...x4xf16> should work as well,
-provided the buffer is aligned to at least 8 bytes (usually 16).
-
-The emitter then applies a few semantics-preserving transformations to move the
-code towards PTX's structure.
-
-### The Emitter - Tiling
-
-The following is the naive code after loop tiling:
-
-```mlir
-func @Conv(%input : memref<128x1x224x224xvector<4xf16>>,
-           %filter : memref<64x1x7x7xvector<4xf16>>,
-           %output : memref<128x64x224x224xf16>) {
-  affine.parallel (%n0, %o0, %oh0, %ow0) = 0 to 128, 0 to 1, 0 to 7, 0 to 7 {
-    affine.parallel (%n1, %o1, %oh1, %ow1) = 0 to 1, 0 to 64, 0 to 16, 0 to 16 {
-      %acc = alloc() : memref<f32>
-      affine.store 0, %acc[]
-      affine.for (%c0, %fh0, %fw0) = 0 to 1, 0 to 1, 0 to 1 {
-        affine.for (%c1, %fh1, %fw1) = 0 to 1, 0 to 7, 0 to 7 {
-          %a = affine.padded.load %input[
-              %n0 * 1 + %n1,
-              %c0 * 1 + %c1,
-              (%oh0 * 16 + %oh1) * 2 + %fh0 * 7 + %fh1 - 3,
-              (%ow0 * 16 + %ow1) * 2 + %fw0 * 7 + %fw1 - 3]
-          %b = affine.load %filter[
-              %o0 * 64 + %o1,
-              %c0 * 1 + %c1,
-              %fh0 * 7 + %fh1,
-              %fw0 * 7 + %fw1]
-          %old = affine.load %acc[]
-          %d = std.fpext %a to vector<4xf32>
-          %e = std.fpext %b to vector<4xf32>
-          %f = std.multiply %d, %e
-          %g = "reduce" %f
-          %new = %g + %old
-          affine.store %new, %acc[]
-        }
-      }
-      %v = affine.load %acc[]
-      affine.store %v, %output[
-          %n0 * 1 + %n1,
-          %o0 * 64 + %o1,
-          %oh0 * 16 + %oh1,
-          %ow0 * 16 + %ow1]
-    } { ptx_block }
-  } { ptx_grid }
-}
-```
-
-The motivation is straightforward: we need to decide which loops are
-parallelized on the compute units of the PTX architecture. The `ptx_grid` and
-`ptx_block` attributes direct that the loop should be parallelized across a
-grid or a block, respectively.
-
-Also notice that, to keep the code pattern clean, tiling is implemented in the
-following way. Defining a "simple loop" as a loop with lower bound 0 and step
-1, the tiling:
-
-* only takes simple loops.
-* only produces simple loops.
-* generates no extra operations. All altered index calculations are done in
-  each user's AffineMaps.
-
-The contracting dimensions (%c, %fh, %fw) are also tiled once; the significance
-will become clear in the shared memory promotion step.
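The tiling contract described above can be checked with ordinary integer arithmetic. A minimal C++ sketch (illustrative only, not part of the emitter) showing that, when the trip count N is a multiple of the tile size X, the rewritten index expression `outer * X + inner` enumerates exactly the indices of the original simple loop:

```cpp
// Illustration only: the index arithmetic that TileLoop folds into the affine
// maps. With N a multiple of the tile size X, the tiled nest (major loop with
// trip count N / X, minor loop with trip count X) visits exactly the same
// indices as the original simple loop, in the same order.
#include <cassert>
#include <cstdint>
#include <vector>

std::vector<int64_t> UntiledIndices(int64_t n) {
  std::vector<int64_t> out;
  for (int64_t i = 0; i < n; ++i) {
    out.push_back(i);
  }
  return out;
}

std::vector<int64_t> TiledIndices(int64_t n, int64_t x) {
  assert(n % x == 0);  // TileLoop currently requires N % X == 0.
  std::vector<int64_t> out;
  for (int64_t outer = 0; outer < n / x; ++outer) {    // major loop, N / X
    for (int64_t inner = 0; inner < x; ++inner) {      // minor loop, X
      out.push_back(outer * x + inner);                // the rewritten expr
    }
  }
  return out;
}

int main() {
  // e.g. the 224-trip spatial loop tiled by 16, as in the example above.
  assert(UntiledIndices(224) == TiledIndices(224, 16));
  return 0;
}
```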
- -### The Emitter - Splitting - -This step splits the body of the (%n1, %o1, %oh1, %ow1) loop into several parts: - -* The code that sets the accumulators to 0. -* The actual convolution computation code. -* The code that writes back accumulators to the %output buffer. - -This transformation "vectorizes" the accumulator accordingly as the `alloc()` -gets hoisted out of the `affine.parallel` op. - -After splitting: - -```mlir -func @Conv(%input : memref<128x1x224x224xvector<4xf16>>, - %filter : memref<64x1x7x7xvector<4xf16>>, - %output : memref<128x64x224x224xf16>) { - affine.parallel (%n0, %o0, %oh0, %ow0) = 0 to 128, 0 to 1, 0 to 7, 0 to 7 { - %acc = alloc() : memref<1x64x16x16xf32> - affine.parallel (%n1, %o1, %oh1, %ow1) = 0 to 1, 0 to 64, 0 to 16, 0 to 16 { - affine.store 0, %acc[%n1, %o1, %oh1, %ow1] - } { ptx_block } - affine.parallel (%n1, %o1, %oh1, %ow1) = 0 to 1, 0 to 64, 0 to 16, 0 to 16 { - affine.for (%c0, %fh0, %fw0) = 0 to 1, 0 to 1, 0 to 1 { - affine.for (%c1, %fh1, %fw1) = 0 to 1, 0 to 7, 0 to 7 { - %a = affine.padded.load %input[ - %n0 * 1 + %n1, - %c0 * 1 + %c1, - (%oh0 * 16 + %oh1) * 2 + %fh0 * 7 + %fh1 - 3, - (%ow0 * 16 + %ow1) * 2 + %fw0 * 7 + %fw1 - 3] - %b = affine.load %filter[ - %o0 * 64 + %o1, - %c0 * 1 + %c1, - %fh0 * 7 + %fh1, - %fw0 * 7 + %fw1] - %old = affine.load %acc[%n1, %o1, %oh1, %ow1] - %d = std.fpext %a to vector<4xf32> - %e = std.fpext %b to vector<4xf32> - %f = std.multiply %d, %e - %g = "reduce" %f - %new = %g + %old - affine.store %new, %acc[%n1, %o1, %oh1, %ow1] - } - } - } { ptx_block } - affine.parallel (%n1, %o1, %oh1, %ow1) = 0 to 1, 0 to 64, 0 to 16, 0 to 16 { - %v = affine.load %acc[%n1, %o1, %oh1, %ow1] - affine.store %v, %output[ - %n0 * 1 + %n1, - %o0 * 64 + %o1, - %oh0 * 16 + %oh1, - %ow0 * 16 + %ow1] - } { ptx_block } - } { ptx_grid } -} -``` - -To prepare for the next transformations, we'd also like to sink the (%n1, %o1, -%oh1, %ow1), as (%c0, %fh0, %fw0) is not interesting. - -``` -affine.parallel (%n1, %o1, %oh1, %ow1) = 0 to 1, 0 to 64, 0 to 16, 0 to 16 { - affine.for (%c0, %fh0, %fw0) = 0 to 1, 0 to 1, 0 to 1 { - affine.for (%c1, %fh1, %fw1) = 0 to 1, 0 to 7, 0 to 7 { - ... - } - } -} { ptx_block } - -=> - -affine.for (%c0, %fh0, %fw0) = 0 to 1, 0 to 1, 0 to 1 { - affine.for (%c1, %fh1, %fw1) = 0 to 1, 0 to 7, 0 to 7 { - affine.parallel (%n1, %o1, %oh1, %ow1) = 0 to 1, 0 to 64, 0 to 16, 0 to 16 { - ... - } { ptx_block } - } -} -``` - -### The Emitter - Shared Memory Promotion - -This transformation is done by `affineDataCopyGenerate`, which does precise -calculation on how much memory is transferred for a load operation. - -After calculating the sizes of the shared memory buffer (`%promoted_input` and -`%promoted_filter`), the transformation also creates loads and stores to -pre-fetch data from global memory (`%input`, `%filter`) to the promoted, shared -memory. 
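The buffer sizes that `affineDataCopyGenerate` derives can be sanity-checked by hand. A small C++ sketch (illustrative only; the helper name is made up) of the per-dimension input footprint touched by one output tile of a strided convolution; with a 16-wide output tile, stride 2, and a 7-wide filter it yields 37, matching the `memref<1x1x37x37>` promoted input buffer in the "After" snippet below:

```cpp
// Illustration only: input elements touched along one spatial dimension by an
// output tile of `tile` elements, with convolution stride `stride` and filter
// extent `filter`. The promoted shared-memory buffer must cover this span.
#include <cstdio>

int InputFootprint(int tile, int stride, int filter) {
  return (tile - 1) * stride + filter;
}

int main() {
  // 16-wide output tile, stride 2, 7-wide filter -> (16 - 1) * 2 + 7 = 37.
  std::printf("%d\n", InputFootprint(/*tile=*/16, /*stride=*/2, /*filter=*/7));
  return 0;
}
```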
- -```mlir -// Before -affine.for (%c1, %fh1, %fw1) = 0 to 1, 0 to 7, 0 to 7 { - affine.parallel (%n1, %o1, %oh1, %ow1) = 0 to 1, 0 to 64, 0 to 16, 0 to 16 { - %a = affine.padded.load %input[ - %n0 * 1 + %n1, - %c0 * 1 + %c1, - (%oh0 * 16 + %oh1) * 2 + %fh0 * 7 + %fh1 - 3, - (%ow0 * 16 + %ow1) * 2 + %fw0 * 7 + %fw1 - 3] - %b = affine.load %filter[ - %o0 * 64 + %o1, - %c0 * 1 + %c1, - %fh0 * 7 + %fh1, - %fw0 * 7 + %fw1] - %old = affine.load %acc[%n1, %o1, %oh1, %ow1] - %d = std.fpext %a to vector<4xf32> - %e = std.fpext %b to vector<4xf32> - %f = std.multiply %d, %e - %g = "reduce" %f - %new = %g + %old - affine.store %new, %acc[%n1, %o1, %oh1, %ow1] - } { ptx_block } -} -``` - -```mlir -// After - -%promoted_input = alloc() : memref<1x1x37x37, memory_space = 3> -%promoted_filter = alloc() : memref<64x1x7x7, memory_space = 3> -affine.parallel (%i0, %i1, %i2, %i3) = 0 to 1, 0 to 1, 0 to 37, 0 to 37 { - %v = affine.padded.load %input[ - %n0 * 1 + %i0, - %c0 * 1 + %i1, - (%oh0 * 16) * 2 + %fh0 * 7 + %i2 - 3, - (%ow0 * 16) * 2 + %fw0 * 7 + %i3 - 3] - affine.store %v, %promoted_input[%i0, %i1, %i2, %i3] -} { ptx_block } -affine.parallel (%i0, %i1, %i2, %i3) = 0 to 64, 0 to 1, 0 to 7, 0 to 7 { - %v = affine.load %filter[ - %o0 * 64 + %i0, - %c0 * 1 + %i1, - %fh0 * 7 + %i2, - %fw0 * 7 + %i3] - affine.store %v, %promoted_filter[%i0, %i1, %i2, %i3] -} { ptx_block } -affine.for (%c1, %fh1, %fw1) = 0 to 1, 0 to 7, 0 to 7 { - affine.parallel (%n1, %o1, %oh1, %ow1) = 0 to 1, 0 to 64, 0 to 16, 0 to 16 { - %a = affine.load %promoted_input[%n1, %c1, %oh1 * 2 + %fh1, %ow1 * 2 + %fw1] - %b = affine.load %promoted_filter[%o1, %c1, %fh1, %fw1] - %old = affine.load %acc[%n1, %o1, %oh1, %ow1] - %d = std.fpext %a to vector<4xf32> - %e = std.fpext %b to vector<4xf32> - %f = std.multiply %d, %e - %g = "reduce" %f - %new = %g + %old - affine.store %new, %acc[%n1, %o1, %oh1, %ow1] - } { ptx_block } -} -``` - -### The Emitter - Volta MMA Instruction - -This transformation turns the inner loop: - -```mlir -affine.parallel (%n1, %o1, %oh1, %ow1) = 0 to 1, 0 to 64, 0 to 16, 0 to 16 { - %a = affine.load %promoted_input[%n1, %c1, %oh1 * 2 + %fh1, %ow1 * 2 + %fw1] - %b = affine.load %promoted_filter[%o1, %c1, %fh1, %fw1] - %old = affine.load %acc[%n1, %o1, %oh1, %ow1] - %d = std.fpext %a to vector<4xf32> - %e = std.fpext %b to vector<4xf32> - %f = std.multiply %d, %e - %g = "reduce" %f - %new = %g + %old - affine.store %new, %acc[%n1, %o1, %oh1, %ow1] -} { ptx_block } -``` - -to multiple Volta mma.sync instructions. The result is not shown here, because -the prototype currently only hacks it up to achieve benchmark goals. - -### The Autotuner - -As shown above, many parameters dictate how a naive implementation is -transformed. For now, the parameters are all tile sizes. On the top of the -emitter, the prototype includes a simple autotuner that enumerates all good -combinations of tile sizes and invoke the emitter with each of the combinations. -With the assistance of in-process benchmarking, the autotuner is able to pick -the best set of parameters. - -## Future Improvements - -* Explore Linalg/Vector for a higher-level naive implementation. MMA - instruction handling would be much easier with high-level functional - constructs. -* Explore other layouts. The current layout corresponds to NVIDIA - `CUDNN_TENSOR_NCHW_VECT_C` but for fp16s. -* Iron out GPU dialect related lowering. Annotations like `ptx_grid` and - `ptx_block` should be generalized to more architectures. -* Speed up autotuning through more pruning. 
-* Support dynamic shapes. diff --git a/tensorflow/compiler/xla/experiments/BUILD b/tensorflow/compiler/xla/experiments/BUILD new file mode 100644 index 00000000000..d298feaf3f0 --- /dev/null +++ b/tensorflow/compiler/xla/experiments/BUILD @@ -0,0 +1,8 @@ +# Various experiments related to the compiler that are not a part of the final XLA binary. + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + # keep visibility private, if you need to depend on this, move it out of experiments + default_visibility = ["//visibility:private"], + licenses = ["notice"], +) diff --git a/tensorflow/compiler/xla/experiments/README.md b/tensorflow/compiler/xla/experiments/README.md new file mode 100644 index 00000000000..502dcb82913 --- /dev/null +++ b/tensorflow/compiler/xla/experiments/README.md @@ -0,0 +1,24 @@ +# XLA Experiments + +This folder is intended to serve as a place to collaborate on code related to +the XLA compiler, but will not end up being a part of the compiler itself. + +As such, the code here is not necessarily production quality, and should not be +depended on from other parts of the compiler. + +Some examples of code appropriate for this folder are: + +* microbenchmarks that allow us to better understand various architectures +* scripts that help with developing specific features of the compiler, which + might remain useful after the feature is complete (general tools should + instead go into the xla/tools directory) +* experimental code transformations that are not yet integrated into the + compiler + +## Visibility + +As a result of the nature of the content in this folder, its build visibility +is intentionally kept private. + +If you need something from here elsewhere, the recommended approach is to move +it to a more suitable and production-supported location. \ No newline at end of file diff --git a/tensorflow/compiler/xla/experiments/sm_bandwidth_benchmark/BUILD b/tensorflow/compiler/xla/experiments/sm_bandwidth_benchmark/BUILD new file mode 100644 index 00000000000..769e8a76bb0 --- /dev/null +++ b/tensorflow/compiler/xla/experiments/sm_bandwidth_benchmark/BUILD @@ -0,0 +1,32 @@ +load("@local_config_cuda//cuda:build_defs.bzl", "cuda_library", "if_cuda") + +cc_library( + name = "sm_bw_utils", + hdrs = ["sm_bw_utils.h"], + defines = if_cuda(["GOOGLE_CUDA=1"]), + deps = [ + "//tensorflow/tsl/platform:logging", + ] + if_cuda([ + "@local_config_cuda//cuda:cuda_headers", + ]), +) + +cuda_library( + name = "sm_bw_kernels", + srcs = ["sm_bw_kernels.cu.cc"], + hdrs = ["sm_bw_kernels.h"], + deps = [ + ":sm_bw_utils", + ], +) + +cc_test( + name = "sm_bw_test", + srcs = ["sm_bw_test.cc"], + tags = ["requires-gpu-sm80-only"], + deps = [ + ":sm_bw_kernels", + ":sm_bw_utils", + "@com_google_googletest//:gtest_main", + ], +) diff --git a/tensorflow/compiler/xla/experiments/sm_bandwidth_benchmark/sm_bw_kernels.cu.cc b/tensorflow/compiler/xla/experiments/sm_bandwidth_benchmark/sm_bw_kernels.cu.cc new file mode 100644 index 00000000000..d0bc62cd0de --- /dev/null +++ b/tensorflow/compiler/xla/experiments/sm_bandwidth_benchmark/sm_bw_kernels.cu.cc @@ -0,0 +1,131 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#if GOOGLE_CUDA + +#include "tensorflow/compiler/xla/experiments/sm_bandwidth_benchmark/sm_bw_kernels.h" + +#include "tensorflow/compiler/xla/experiments/sm_bandwidth_benchmark/sm_bw_utils.h" + +namespace experiments { +namespace benchmark { +#define DFUNC __forceinline__ __device__ +#define HDFUNC DFUNC __host__ + +constexpr int kMaxBlockSize = 1024; + +template +class Vec { + public: + using ElementType = ET; + constexpr static size_t Size = S; + + template + HDFUNC Vec(Ts... elements) : data_() { + InsertElements(0, elements...); + } + + HDFUNC ElementType& operator[](size_t idx) { return data_[idx]; } + HDFUNC const ElementType& operator[](size_t idx) const { return data_[idx]; } + + private: + template + HDFUNC void InsertElements(size_t idx, T element, Ts... rest) { + data_[idx] = element; + InsertElements(idx + 1, rest...); + } + HDFUNC void InsertElements(size_t idx) {} + + ElementType data_[Size]; +}; + +template +DFUNC void Store(VectorType vx, T* __restrict__ x, size_t id) { + reinterpret_cast(x)[id] = vx; +} +template <> +DFUNC void Store(Vec vx, float* __restrict__ x, size_t id) { + asm("st.global.v4.f32 [%0], {%1, %2, %3, %4};" + : + : "l"(x + 4 * id), "f"(vx[0]), "f"(vx[1]), "f"(vx[2]), "f"(vx[3])); +} + +template +DFUNC void LoadNc(VectorType& vx, const T* __restrict__ x, size_t id) { + vx = reinterpret_cast(x)[id]; +} + +template <> +DFUNC void LoadNc(Vec& vx, const float* __restrict__ x, size_t id) { + asm("ld.global.nc.v4.f32 {%0, %1, %2, %3}, [%4];" + : "=f"(vx[0]), "=f"(vx[1]), "=f"(vx[2]), "=f"(vx[3]) + : "l"(x + 4 * id)); +} + +template +__launch_bounds__(kMaxBlockSize) __global__ + void BenchmarkDeviceCopyKernel(const float* __restrict__ in, + float* __restrict__ out, int64_t size) { + const int64_t lines = size / (blockDim.x * chunks); + const int64_t start_line = lines * blockIdx.x / gridDim.x; + const int64_t end_line = lines * (blockIdx.x + 1) / gridDim.x; + const int64_t start_offset = + start_line * blockDim.x * chunks + 4 * threadIdx.x; + const int64_t end_offset = end_line * blockDim.x * chunks; + Vec buffer[chunks / 4]; + for (int64_t i = start_offset; i < end_offset; i += blockDim.x * chunks) { +#pragma unroll + for (int j = 0; j < chunks; j += 4) { + LoadNc(buffer[j / 4], in + i + blockDim.x * j, 0); + } +#pragma unroll + for (int j = 0; j < chunks; j += 4) { + Store(buffer[j / 4], out + i + blockDim.x * j, 0); + } + } +} + +template +void BenchmarkDeviceCopy(float* in, float* out, int64_t size, int num_blocks, + int num_threads) { + BenchmarkDeviceCopyKernel<<>>(in, out, size); + CHECK_CUDA(cudaGetLastError()); +} + +template void BenchmarkDeviceCopy<1>(float* in, float* out, int64_t size, + int num_blocks, int num_threads); +template void BenchmarkDeviceCopy<1 << 1>(float* in, float* out, int64_t size, + int num_blocks, int num_threads); +template void BenchmarkDeviceCopy<1 << 2>(float* in, float* out, int64_t size, + int num_blocks, int num_threads); +template void BenchmarkDeviceCopy<1 << 3>(float* in, float* out, int64_t size, + int num_blocks, int num_threads); +template 
void BenchmarkDeviceCopy<1 << 4>(float* in, float* out, int64_t size, + int num_blocks, int num_threads); +template void BenchmarkDeviceCopy<1 << 5>(float* in, float* out, int64_t size, + int num_blocks, int num_threads); +template void BenchmarkDeviceCopy<1 << 6>(float* in, float* out, int64_t size, + int num_blocks, int num_threads); +template void BenchmarkDeviceCopy<1 << 7>(float* in, float* out, int64_t size, + int num_blocks, int num_threads); +template void BenchmarkDeviceCopy<1 << 8>(float* in, float* out, int64_t size, + int num_blocks, int num_threads); +template void BenchmarkDeviceCopy<1 << 9>(float* in, float* out, int64_t size, + int num_blocks, int num_threads); +template void BenchmarkDeviceCopy<1 << 10>(float* in, float* out, int64_t size, + int num_blocks, int num_threads); +} // namespace benchmark +} // namespace experiments + +#endif // GOOGLE_CUDA diff --git a/tensorflow/compiler/xla/experiments/sm_bandwidth_benchmark/sm_bw_kernels.h b/tensorflow/compiler/xla/experiments/sm_bandwidth_benchmark/sm_bw_kernels.h new file mode 100644 index 00000000000..ea398f04a06 --- /dev/null +++ b/tensorflow/compiler/xla/experiments/sm_bandwidth_benchmark/sm_bw_kernels.h @@ -0,0 +1,29 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_EXPERIMENTS_SM_BANDWIDTH_BENCHMARK_SM_BW_KERNELS_H_ +#define TENSORFLOW_COMPILER_XLA_EXPERIMENTS_SM_BANDWIDTH_BENCHMARK_SM_BW_KERNELS_H_ + +namespace experiments { +namespace benchmark { + +template +void BenchmarkDeviceCopy(float* in, float* out, int64_t size, int num_blocks, + int num_threads); + +} // namespace benchmark +} // namespace experiments + +#endif // TENSORFLOW_COMPILER_XLA_EXPERIMENTS_SM_BANDWIDTH_BENCHMARK_SM_BW_KERNELS_H_ diff --git a/tensorflow/compiler/xla/experiments/sm_bandwidth_benchmark/sm_bw_test.cc b/tensorflow/compiler/xla/experiments/sm_bandwidth_benchmark/sm_bw_test.cc new file mode 100644 index 00000000000..e4bacd79c3f --- /dev/null +++ b/tensorflow/compiler/xla/experiments/sm_bandwidth_benchmark/sm_bw_test.cc @@ -0,0 +1,315 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#if GOOGLE_CUDA + +#include + +#include +#include "tensorflow/compiler/xla/experiments/sm_bandwidth_benchmark/sm_bw_kernels.h" +#include "tensorflow/compiler/xla/experiments/sm_bandwidth_benchmark/sm_bw_utils.h" + +namespace experiments { +namespace benchmark { +namespace { + +constexpr int kNumSM = 108; +constexpr int kNum32BitRegisters = 64 * 1024; +constexpr int kMaxBlockSize = 1024; + +template +struct DeviceMemoryDeleter { + void operator()(T* ptr) { cudaFree(ptr); } +}; +template +using DeviceMemory = std::unique_ptr>; + +template +DeviceMemory MakeDeviceMemory(size_t size) { + T* gpu_ptr = nullptr; + CHECK_CUDA(cudaMalloc(reinterpret_cast(&gpu_ptr), size * sizeof(T))); + return DeviceMemory(gpu_ptr); +} + +template +struct HostMemoryDeleter { + void operator()(T* ptr) { free(ptr); } +}; +template +using HostMemory = std::unique_ptr>; + +template +HostMemory MakeHostMemory(size_t size) { + T* h_in = (T*)malloc(size * sizeof(T)); + return HostMemory(h_in); +} + +struct EventDeleter { + using pointer = cudaEvent_t; + void operator()(pointer event) { cudaEventDestroy(event); } +}; +using Event = std::unique_ptr; +Event MakeEvent() { + cudaEvent_t event = nullptr; + CHECK_CUDA(cudaEventCreate(&event)); + return Event(event); +} + +bool CheckOutputAndClean(float* h_in, float* h_out, float* d_out, size_t size) { + cudaMemcpy(h_out, d_out, size * sizeof(float), cudaMemcpyDeviceToHost); + + for (size_t i = 0; i < size; i++) { + if ((h_in[i] - h_out[i]) > 1e-6) { + LOG(ERROR) << "mismatch :(, i = " << i << " , values are " << h_in[i] + << ", " << h_out[i]; + return false; + } + h_out[i] = 0; + } + return true; +} + +template +float BenchmarkCustomDeviceCopy(int kReps, float* d_in, float* d_out, + size_t size, int num_blocks = kNumSM, + int num_threads = 64) { + Event start = MakeEvent(); + Event stop = MakeEvent(); + CHECK_CUDA(cudaEventRecord(start.get())); + for (int i = 0; i < kReps; i++) { + BenchmarkDeviceCopy(d_in, d_out, size, num_blocks, num_threads); + } + CHECK_CUDA(cudaEventRecord(stop.get())); + CHECK_CUDA(cudaEventSynchronize(stop.get())); + float time_diff = 0.0f; + CHECK_CUDA(cudaEventElapsedTime(&time_diff, start.get(), stop.get())); + return time_diff / kReps; +} + +float BenchmarkDev2DevCopy(int kReps, float* d_in, float* d_out, size_t size) { + Event start = MakeEvent(); + Event stop = MakeEvent(); + CHECK_CUDA(cudaEventRecord(start.get())); + for (int i = 0; i < kReps; i++) { + CHECK_CUDA(cudaMemcpy(d_out, d_in, size * sizeof(float), + cudaMemcpyDeviceToDevice)); + } + CHECK_CUDA(cudaEventRecord(stop.get())); + CHECK_CUDA(cudaEventSynchronize(stop.get())); + float time_diff = 0.0f; + CHECK_CUDA(cudaEventElapsedTime(&time_diff, start.get(), stop.get())); + return time_diff / kReps; +} + +// B/ms -> TB/s +float TbPerSec(size_t size, float time_diff) { + return 2 * sizeof(float) * size / (1e9 * time_diff); +} + +TEST(SMBandwidthTest, IncreasingMemorySize) { + constexpr int64_t kOneM = 1024 * 1024; + constexpr int64_t kOneG = 1024 * 1024 * 1024; + constexpr int64_t kMaxSize = kOneG; + + DeviceMemory d_in = MakeDeviceMemory(kMaxSize); + DeviceMemory d_out = MakeDeviceMemory(kMaxSize); + + HostMemory h_in = MakeHostMemory(kMaxSize); + HostMemory h_out = MakeHostMemory(kMaxSize); + + for (size_t i = 0; i < kMaxSize; i++) { + h_in.get()[i] = i; + } + CHECK_CUDA(cudaMemcpy(d_in.get(), h_in.get(), kMaxSize * sizeof(float), + cudaMemcpyHostToDevice)); + + constexpr int kReps = 10; + LOG(ERROR) << 
"size,custom TB/s,devTodev TB/s"; + for (size_t size = kOneM; size <= kMaxSize; size *= 2) { + float time_diff_c = + BenchmarkCustomDeviceCopy<1>(kReps, d_in.get(), d_out.get(), size); + EXPECT_TRUE( + CheckOutputAndClean(h_in.get(), h_out.get(), d_out.get(), size)); + + float time_diff_d2d = + BenchmarkDev2DevCopy(kReps, d_in.get(), d_out.get(), size); + EXPECT_TRUE( + CheckOutputAndClean(h_in.get(), h_out.get(), d_out.get(), size)); + + LOG(ERROR) << size << "," << TbPerSec(size, time_diff_c) << "," + << TbPerSec(size, time_diff_d2d); + } +} + +TEST(SMBandwidthTest, IncreasingNumBlocks) { + constexpr size_t kSize = 1 << 28; + constexpr int kReps = 10; + constexpr int kNumThreads = 64; + + DeviceMemory d_in = MakeDeviceMemory(kSize); + DeviceMemory d_out = MakeDeviceMemory(kSize); + + HostMemory h_in = MakeHostMemory(kSize); + HostMemory h_out = MakeHostMemory(kSize); + + for (size_t i = 0; i < kSize; i++) { + h_in.get()[i] = i; + } + CHECK_CUDA(cudaMemcpy(d_in.get(), h_in.get(), kSize * sizeof(float), + cudaMemcpyHostToDevice)); + + LOG(ERROR) << "num_blocks,TB/s"; + for (int64_t num_blocks = kNumSM; num_blocks <= kNumSM * 32; + num_blocks += kNumSM) { + Event start = MakeEvent(); + Event stop = MakeEvent(); + CHECK_CUDA(cudaEventRecord(start.get())); + for (int i = 0; i < kReps; i++) { + BenchmarkDeviceCopy<1>(d_in.get(), d_out.get(), kSize, num_blocks, + kNumThreads); + } + CHECK_CUDA(cudaEventRecord(stop.get())); + CHECK_CUDA(cudaEventSynchronize(stop.get())); + float time_diff = 0.0f; + CHECK_CUDA(cudaEventElapsedTime(&time_diff, start.get(), stop.get())); + time_diff /= kReps; + LOG(ERROR) << num_blocks << "," << TbPerSec(kSize, time_diff); + + CHECK_CUDA(cudaMemcpy(h_out.get(), d_out.get(), kSize * sizeof(float), + cudaMemcpyDeviceToHost)); + EXPECT_TRUE( + CheckOutputAndClean(h_in.get(), h_out.get(), d_out.get(), kSize)); + } +} + +template +struct ForLoop { + template