Merge branch 'master' into aarch64_build_patch

Aleksandr Nikolaev 2021-03-19 19:45:25 +00:00 committed by GitHub
commit 1e15aa8c83
1030 changed files with 23590 additions and 14911 deletions

.bazelrc

@ -84,16 +84,13 @@
# elinux_armhf: Embedded Linux options for armhf (ARMv7) CPU support.
#
# Release build options (for all operating systems)
# release_common: Common options for all builds on all operating systems.
# release_windows_common: Common options for all builds on Windows.
# release_gpu_common: Common options for GPU builds on Linux and Windows.
# release_cpu_linux: Toolchain and CUDA options for Linux CPU builds.
# release_cpu_macos: Toolchain and CUDA options for MacOS CPU builds.
# release_gpu_linux: Toolchain and CUDA options for Linux GPU builds.
# release_gpu_linux_cuda_10_1: Toolchain and CUDA options for CUDA 10.1 Linux GPU builds.
# release_gpu_linux_cuda_11_2: Toolchain and CUDA options for CUDA 11.2 Linux GPU builds.
# release_cpu_windows: Toolchain and CUDA options for Windows CPU builds.
# release_gpu_windows: Toolchain and CUDA options for Windows GPU builds.
# release_base: Common options for all builds on all operating systems.
# release_gpu_base: Common options for GPU builds on Linux and Windows.
# release_cpu_linux: Toolchain and CUDA options for Linux CPU builds.
# release_cpu_macos: Toolchain and CUDA options for MacOS CPU builds.
# release_gpu_linux: Toolchain and CUDA options for Linux GPU builds.
# release_cpu_windows: Toolchain and CUDA options for Windows CPU builds.
# release_gpu_windows: Toolchain and CUDA options for Windows GPU builds.
# Allow builds using libc++ as a linker library
# This is mostly for OSSFuzz, so we also pass in the flags from environment to clean build file
@ -123,7 +120,13 @@ build:android_x86_64 --cpu=x86_64
build:android_x86_64 --fat_apk_cpu=x86_64
# Sets the default Apple platform to macOS.
build --apple_platform_type=macos
build:macos --apple_platform_type=macos
# gRPC on MacOS requires this #define
build:macos --copt=-DGRPC_BAZEL_BUILD
# Settings for MacOS on ARM CPUs.
build:macos_arm64 --cpu=darwin_arm64
# iOS configs for each architecture and the fat binary builds.
build:ios --apple_platform_type=ios
@ -140,10 +143,10 @@ build:ios_x86_64 --cpu=ios_x86_64
build:ios_fat --config=ios
build:ios_fat --ios_multi_cpus=armv7,arm64,i386,x86_64
# Enables all the macos config options for macos_arm64
build:macos_arm64 --config=macos
build:macos_arm64 --apple_platform_type=macos
build:macos_arm64 --cpu=darwin_arm64
# For projects which use TensorFlow as part of a Bazel build process, putting
# nothing in a bazelrc will default to a monolithic build. The following line
# opts in to modular op registration support by default.
build --define framework_shared_object=true
# Config to use a mostly-static build and disable modular op registration
# support (this will revert to loading TensorFlow with RTLD_GLOBAL in Python).
@ -151,11 +154,6 @@ build:macos_arm64 --cpu=darwin_arm64
# //tensorflow:libtensorflow_framework.so.
build:monolithic --define framework_shared_object=false
# For projects which use TensorFlow as part of a Bazel build process, putting
# nothing in a bazelrc will default to a monolithic build. The following line
# opts in to modular op registration support by default.
build --define framework_shared_object=true
# For workaround https://github.com/bazelbuild/bazel/issues/8772 with Bazel >= 0.29.1
build --java_toolchain=@org_tensorflow//third_party/toolchains/java:tf_java_toolchain
build --host_java_toolchain=@org_tensorflow//third_party/toolchains/java:tf_java_toolchain
@ -197,8 +195,8 @@ build:cuda_clang --config=cuda
build:cuda_clang --repo_env TF_CUDA_CLANG=1
build:cuda_clang --@local_config_cuda//:cuda_compiler=clang
# dbg config, as a shorthand for '--config=opt -c dbg'
build:dbg --config=opt -c dbg
# Debug config
build:dbg -c dbg
# for now, disable arm_neon. see: https://github.com/tensorflow/tensorflow/issues/33360
build:dbg --cxxopt -DTF_LITE_DISABLE_X86_NEON
# AWS SDK must be compiled in release mode. see: https://github.com/tensorflow/tensorflow/issues/37498
@ -266,12 +264,10 @@ build:c++1z --config=c++17
build:c++17_gcc --cxxopt=-std=c++1z
build:c++1z_gcc --config=c++17_gcc
# Enable using platform specific build settings, except when cross-compiling for
# mobile platforms.
# Trigger --config=<host platform>, except when cross-compiling.
build --enable_platform_specific_config
build:android --noenable_platform_specific_config
build:ios --noenable_platform_specific_config
build:macos_arm64 --noenable_platform_specific_config
# Suppress C++ compiler warnings, otherwise build logs become 10s of MBs.
build:android --copt=-w
@ -357,15 +353,15 @@ build:avx_win --copt=/arch=AVX
build:avx2_win --copt=/arch=AVX2
# Options to build TensorFlow 1.x or 2.x.
build:v1 --define=tf_api_version=1
build:v2 --define=tf_api_version=2
build:v1 --action_env=TF2_BEHAVIOR=0
build:v2 --action_env=TF2_BEHAVIOR=1
build:v1 --define=tf_api_version=1 --action_env=TF2_BEHAVIOR=0
build:v2 --define=tf_api_version=2 --action_env=TF2_BEHAVIOR=1
build --config=v2
test --config=v2
# Enable XLA
build:xla --define=with_xla_support=true
# Enable XLA except on mobile.
build --config=xla
build:xla --define=with_xla_support=true
build:android --define=with_xla_support=false
build:ios --define=with_xla_support=false
# BEGIN TF REMOTE BUILD EXECUTION OPTIONS
# Options when using remote execution
@ -399,7 +395,6 @@ build:rbe_linux --host_java_toolchain=@bazel_tools//tools/jdk:toolchain_hostjdk8
build:rbe_linux --java_toolchain=@bazel_tools//tools/jdk:toolchain_hostjdk8
# Non-rbe settings we should include because we do not run configure
build:rbe_linux --config=xla
build:rbe_linux --config=avx_linux
build:rbe_linux --config=short_logs
# TODO(gunan): Check why we need this specified in rbe, but not in other builds.
@ -423,45 +418,9 @@ build:rbe_linux_cuda_base --config=tensorrt
build:rbe_linux_cuda_base --repo_env=REMOTE_GPU_TESTING=1
test:rbe_linux_cuda_base --test_env=LD_LIBRARY_PATH="/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64"
build:rbe_linux_cuda10.1_nvcc_base --config=rbe_linux_cuda_base
build:rbe_linux_cuda10.1_nvcc_base --action_env=TF_CUDA_VERSION=10
build:rbe_linux_cuda10.1_nvcc_base --action_env=TF_CUDNN_VERSION=7
build:rbe_linux_cuda10.1_nvcc_base --host_crosstool_top="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain"
build:rbe_linux_cuda10.1_nvcc_base --crosstool_top="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain"
build:rbe_linux_cuda10.1_nvcc_base --extra_toolchains="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain-linux-x86_64"
build:rbe_linux_cuda10.1_nvcc_base --extra_execution_platforms="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
build:rbe_linux_cuda10.1_nvcc_base --host_platform="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
build:rbe_linux_cuda10.1_nvcc_base --platforms="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
build:rbe_linux_cuda10.1_nvcc_base --repo_env=TF_CUDA_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda"
build:rbe_linux_cuda10.1_nvcc_base --repo_env=TF_TENSORRT_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_tensorrt"
build:rbe_linux_cuda10.1_nvcc_base --repo_env=TF_NCCL_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_nccl"
build:rbe_linux_cuda10.1_nvcc_py2.7 --config=rbe_linux_cuda10.1_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python2.7"
build:rbe_linux_cuda10.1_nvcc_py3.5 --config=rbe_linux_cuda10.1_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.5"
build:rbe_linux_cuda10.1_nvcc_py3.6 --config=rbe_linux_cuda10.1_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.6"
build:rbe_linux_cuda10.1_nvcc_py3.7 --config=rbe_linux_cuda10.1_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.7"
build:rbe_linux_cuda10.1_nvcc_py3.8 --config=rbe_linux_cuda10.1_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.8"
build:rbe_linux_cuda11.0_nvcc_base --config=rbe_linux_cuda_base
build:rbe_linux_cuda11.0_nvcc_base --action_env=TF_CUDA_VERSION=11
build:rbe_linux_cuda11.0_nvcc_base --action_env=TF_CUDNN_VERSION=8
build:rbe_linux_cuda11.0_nvcc_base --host_crosstool_top="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_cuda//crosstool:toolchain"
build:rbe_linux_cuda11.0_nvcc_base --crosstool_top="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_cuda//crosstool:toolchain"
build:rbe_linux_cuda11.0_nvcc_base --extra_toolchains="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_cuda//crosstool:toolchain-linux-x86_64"
build:rbe_linux_cuda11.0_nvcc_base --extra_execution_platforms="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_platform//:platform"
build:rbe_linux_cuda11.0_nvcc_base --host_platform="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_platform//:platform"
build:rbe_linux_cuda11.0_nvcc_base --platforms="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_platform//:platform"
build:rbe_linux_cuda11.0_nvcc_base --repo_env=TF_CUDA_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_cuda"
build:rbe_linux_cuda11.0_nvcc_base --repo_env=TF_TENSORRT_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_tensorrt"
build:rbe_linux_cuda11.0_nvcc_base --repo_env=TF_NCCL_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_nccl"
build:rbe_linux_cuda11.0_nvcc_py2.7 --config=rbe_linux_cuda11.0_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_python2.7"
build:rbe_linux_cuda11.0_nvcc_py3.5 --config=rbe_linux_cuda11.0_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_python3.5"
build:rbe_linux_cuda11.0_nvcc_py3.6 --config=rbe_linux_cuda11.0_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_python3.6"
build:rbe_linux_cuda11.0_nvcc_py3.7 --config=rbe_linux_cuda11.0_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_python3.7"
build:rbe_linux_cuda11.0_nvcc_py3.8 --config=rbe_linux_cuda11.0_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_python3.8"
build:rbe_linux_cuda11.2_nvcc_base --config=rbe_linux_cuda_base
build:rbe_linux_cuda11.2_nvcc_base --action_env=TF_CUDA_VERSION=11.2
build:rbe_linux_cuda11.2_nvcc_base --action_env=TF_CUDNN_VERSION=8.1
build:rbe_linux_cuda11.2_nvcc_base --action_env=TF_CUDA_VERSION=11
build:rbe_linux_cuda11.2_nvcc_base --action_env=TF_CUDNN_VERSION=8
build:rbe_linux_cuda11.2_nvcc_base --host_crosstool_top="@ubuntu18.04-gcc7_manylinux2010-cuda11.2-cudnn8.1-tensorrt7.2_config_cuda//crosstool:toolchain"
build:rbe_linux_cuda11.2_nvcc_base --crosstool_top="@ubuntu18.04-gcc7_manylinux2010-cuda11.2-cudnn8.1-tensorrt7.2_config_cuda//crosstool:toolchain"
build:rbe_linux_cuda11.2_nvcc_base --extra_toolchains="@ubuntu18.04-gcc7_manylinux2010-cuda11.2-cudnn8.1-tensorrt7.2_config_cuda//crosstool:toolchain-linux-x86_64"
@ -471,25 +430,24 @@ build:rbe_linux_cuda11.2_nvcc_base --platforms="@ubuntu18.04-gcc7_manylinux2010-
build:rbe_linux_cuda11.2_nvcc_base --repo_env=TF_CUDA_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda11.2-cudnn8.1-tensorrt7.2_config_cuda"
build:rbe_linux_cuda11.2_nvcc_base --repo_env=TF_TENSORRT_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda11.2-cudnn8.1-tensorrt7.2_config_tensorrt"
build:rbe_linux_cuda11.2_nvcc_base --repo_env=TF_NCCL_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda11.2-cudnn8.1-tensorrt7.2_config_nccl"
build:rbe_linux_cuda11.2_nvcc_py3.5 --config=rbe_linux_cuda11.2_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda11.2-cudnn8.1-tensorrt7.2_config_python3.5"
build:rbe_linux_cuda11.2_nvcc_py3.6 --config=rbe_linux_cuda11.2_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda11.2-cudnn8.1-tensorrt7.2_config_python3.6"
build:rbe_linux_cuda11.2_nvcc_py3.7 --config=rbe_linux_cuda11.2_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda11.2-cudnn8.1-tensorrt7.2_config_python3.7"
build:rbe_linux_cuda11.2_nvcc_py3.8 --config=rbe_linux_cuda11.2_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda11.2-cudnn8.1-tensorrt7.2_config_python3.8"
build:rbe_linux_cuda11.2_nvcc_py3.9 --config=rbe_linux_cuda11.2_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda11.2-cudnn8.1-tensorrt7.2_config_python3.9"
# Map default to CUDA 11.2 for PY35 and greater.
build:rbe_linux_cuda_nvcc_py27 --config=rbe_linux_cuda10.1_nvcc_py2.7
build:rbe_linux_cuda_nvcc_py35 --config=rbe_linux_cuda11.2_nvcc_py3.5
# Map default to CUDA 11.2.
build:rbe_linux_cuda_nvcc_py36 --config=rbe_linux_cuda11.2_nvcc_py3.6
build:rbe_linux_cuda_nvcc_py37 --config=rbe_linux_cuda11.2_nvcc_py3.7
build:rbe_linux_cuda_nvcc_py38 --config=rbe_linux_cuda11.2_nvcc_py3.8
build:rbe_linux_cuda_nvcc_py39 --config=rbe_linux_cuda11.2_nvcc_py3.9
# Deprecated configs that people might still use.
build:rbe_linux_cuda_nvcc --config=rbe_linux_cuda_nvcc_py36
build:rbe_gpu_linux --config=rbe_linux_cuda_nvcc
build:rbe_linux_cuda_clang_base --config=rbe_linux_cuda_base
build:rbe_linux_cuda_clang_base --action_env=TF_CUDA_VERSION=11.2
build:rbe_linux_cuda_clang_base --action_env=TF_CUDNN_VERSION=8.1
build:rbe_linux_cuda_clang_base --action_env=TF_CUDA_VERSION=11
build:rbe_linux_cuda_clang_base --action_env=TF_CUDNN_VERSION=8
build:rbe_linux_cuda_clang_base --crosstool_top="@ubuntu18.04-clang_manylinux2010-cuda11.2-cudnn8.1-tensorrt7.2_config_cuda//crosstool:toolchain"
build:rbe_linux_cuda_clang_base --extra_toolchains="@ubuntu18.04-clang_manylinux2010-cuda11.2-cudnn8.1-tensorrt7.2_config_cuda//crosstool:toolchain-linux-x86_64"
build:rbe_linux_cuda_clang_base --extra_execution_platforms="@ubuntu18.04-clang_manylinux2010-cuda11.2-cudnn8.1-tensorrt7.2_config_platform//:platform"
@ -583,65 +541,42 @@ try-import %workspace%/.tf_configure.bazelrc
try-import %workspace%/.bazelrc.user
# Here are bazelrc configs for release builds
build:release_common --config=opt
build:release_common --config=v2
build:release_common --distinct_host_configuration=false
build:release_common --action_env TF_CONFIGURE_IOS="0"
build:release_base --config=v2
build:release_base --distinct_host_configuration=false
test:release_base --flaky_test_attempts=3
test:release_base --test_size_filters=small,medium
build:release_cpu_linux --config=release_common
build:release_cpu_linux --config=release_base
build:release_cpu_linux --config=avx_linux
# We use the same toolchain for CPU/GPU packages.
# Did not add this to the defaults in case this changes.
build:release_cpu_linux --crosstool_top=@org_tensorflow//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda10.1:toolchain
build:release_cpu_linux --crosstool_top=@org_tensorflow//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda11.2:toolchain
test:release_cpu_linux --test_env=LD_LIBRARY_PATH
build:release_cpu_macos --config=release_common
build:release_cpu_macos --config=release_base
build:release_cpu_macos --config=avx_linux
build:release_gpu_common --config=release_common
build:release_gpu_common --config=cuda
build:release_gpu_common --config=tensorrt
build:release_gpu_common --action_env CUDA_TOOLKIT_PATH="/usr/local/cuda-11.2"
build:release_gpu_common --action_env=TF_CUDA_VERSION="11.2"
build:release_gpu_common --action_env=TF_CUDNN_VERSION="8.1"
build:release_gpu_common --action_env=TF_NEED_TENSORRT="1"
build:release_gpu_common --action_env=TF_CUDA_COMPUTE_CAPABILITIES="sm_35,sm_50,sm_60,sm_70,sm_75,compute_80"
build:release_gpu_common --action_env=TENSORRT_INSTALL_PATH="/usr/local/tensorrt"
build:release_gpu_common --action_env=LD_LIBRARY_PATH="/usr/local/cuda:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:/usr/local/tensorrt/lib"
build:release_gpu_common --action_env=GCC_HOST_COMPILER_PATH="/usr/bin/gcc-5"
build:release_gpu_base --config=release_base
build:release_gpu_base --config=cuda
build:release_gpu_base --action_env=TF_CUDA_VERSION="11"
build:release_gpu_base --action_env=TF_CUDNN_VERSION="8"
build:release_gpu_base --repo_env=TF_CUDA_COMPUTE_CAPABILITIES="sm_35,sm_50,sm_60,sm_70,sm_75,compute_80"
build:release_gpu_linux --config=release_gpu_common
build:release_gpu_linux --config=avx_linux
build:release_gpu_linux --config=release_cpu_linux
build:release_gpu_linux --config=release_gpu_base
build:release_gpu_linux --config=tensorrt
build:release_gpu_linux --action_env=CUDA_TOOLKIT_PATH="/usr/local/cuda-11.2"
build:release_gpu_linux --action_env=LD_LIBRARY_PATH="/usr/local/cuda:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:/usr/local/tensorrt/lib"
build:release_gpu_linux --action_env=GCC_HOST_COMPILER_PATH="/usr/bin/gcc-5"
build:release_gpu_linux --crosstool_top=@org_tensorflow//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda11.2:toolchain
build:release_windows_common --config=release_common
build:release_windows_common --define=no_tensorflow_py_deps=true
build:release_windows_common --announce_rc
build:release_cpu_windows --config=release_base
build:release_cpu_windows --config=avx_win
build:release_cpu_windows --define=no_tensorflow_py_deps=true
# First available in VS 16.4. Speeds Windows compile times by a lot. See
# https://groups.google.com/a/tensorflow.org/d/topic/build/SsW98Eo7l3o/discussion
build:release_windows_common --copt=/d2ReducedOptimizeHugeFunctions --host_copt=/d2ReducedOptimizeHugeFunctions
build:release_cpu_windows --copt=/d2ReducedOptimizeHugeFunctions --host_copt=/d2ReducedOptimizeHugeFunctions
build:release_cpu_windows --config=release_windows_common
build:release_gpu_windows --config=release_windows_common
build:release_gpu_linux_cuda_10_1 --config=release_gpu_linux
build:release_gpu_linux_cuda_10_1 --action_env CUDA_TOOLKIT_PATH="/usr/local/cuda-10.1"
build:release_gpu_linux_cuda_10_1 --crosstool_top=@org_tensorflow//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda10.1:toolchain
build:release_gpu_linux_cuda_10_1 --action_env=TF_CUDA_VERSION="10"
build:release_gpu_linux_cuda_10_1 --action_env=TF_CUDNN_VERSION="7"
build:release_gpu_linux_cuda_11 --config=release_gpu_linux
build:release_gpu_linux_cuda_11 --action_env CUDA_TOOLKIT_PATH="/usr/local/cuda-11.0"
build:release_gpu_linux_cuda_11 --crosstool_top=@org_tensorflow//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda11:toolchain
build:release_gpu_linux_cuda_11 --action_env=TF_CUDA_VERSION="11"
build:release_gpu_linux_cuda_11 --action_env=TF_CUDNN_VERSION="8"
build:release_gpu_linux_cuda_11_2 --config=release_gpu_linux
build:release_gpu_linux_cuda_11_2 --action_env CUDA_TOOLKIT_PATH="/usr/local/cuda-11.2"
build:release_gpu_linux_cuda_11_2 --crosstool_top=@org_tensorflow//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda11.2:toolchain
build:release_gpu_linux_cuda_11_2 --action_env=TF_CUDA_VERSION="11.2"
build:release_gpu_linux_cuda_11_2 --action_env=TF_CUDNN_VERSION="8.1"
build:release_gpu_windows --config=release_cpu_windows
build:release_gpu_windows --config=release_gpu_base
# Address sanitizer
# CC=clang bazel build --config asan


@ -39,11 +39,23 @@
`num_parallel_calls` set, `deterministic` is used to indicate that
outputs can be obtained in the non-deterministic order.
* Options returned by `tf.data.Dataset.options()` are no longer mutable.
* tf.data input pipelines can now be executed in debug mode, which
disables any asynchrony, parallelism, or non-determinism and forces
Python execution (as opposed to trace-compiled graph execution) of
user-defined functions passed into transformations such as `map`. The
debug mode can be enabled through `tf.data.experimental.enable_debug_mode()`.
* `tf.lite`
* Enabled the new MLIR-based quantization backend by default
* The new backend is used for 8-bit full-integer post-training quantization
* The new backend removes redundant rescales and fixes some bugs (shared weight/bias, extremely small scales, etc.)
* Set `experimental_new_quantizer` in `tf.lite.TFLiteConverter` to `False` to disable this change
* `tf.keras`
* Added a new supported input type to `Model.fit`,
`tf.keras.utils.experimental.DatasetCreator`, which takes a
callable, `dataset_fn`.
`DatasetCreator` is intended to work across all `tf.distribute`
strategies, and is the only input type supported for Parameter Server
strategy.
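
To make the `DatasetCreator` note concrete, here is a minimal, hedged Python sketch. The toy `dataset_fn`, model, batch size, and step counts are illustrative assumptions, not part of this change; with `ParameterServerStrategy`, additional cluster setup is required and `steps_per_epoch` must be provided.

import tensorflow as tf

def dataset_fn(input_context):
  # Hypothetical toy pipeline; a real dataset_fn would read and shard real data.
  batch_size = input_context.get_per_replica_batch_size(32)
  x = tf.random.normal([128, 10])
  y = tf.random.uniform([128], maxval=2, dtype=tf.int32)
  return tf.data.Dataset.from_tensor_slices((x, y)).batch(batch_size).repeat()

model = tf.keras.Sequential([tf.keras.layers.Dense(1, activation="sigmoid")])
model.compile(optimizer="adam", loss="binary_crossentropy")
# DatasetCreator wraps the callable; fit() invokes it to build the input pipeline.
model.fit(tf.keras.utils.experimental.DatasetCreator(dataset_fn),
          epochs=1, steps_per_epoch=4)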
## Bug Fixes and Other Changes
@ -74,6 +86,8 @@
* Add `tf.data.experimental.AutoShardingPolicy.HINT` which can be used
to provide hints to tf.distribute-based auto-sharding as to where in
the input pipeline to insert sharding transformations.
* Make tf.data.Options persistent across `tf.function` and `GraphDef`
boundaries.
* XLA compilation:
* `tf.function(experimental_compile=True)` has become a stable API,
renamed `tf.function(jit_compile=True)`.
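
A short hedged example of the renamed flag; the function body below is illustrative only.

import tensorflow as tf

# Request XLA compilation; previously spelled experimental_compile=True.
@tf.function(jit_compile=True)
def scaled_sum(x, y):
  return tf.reduce_sum(x * y)

print(scaled_sum(tf.ones([8]), tf.range(8, dtype=tf.float32)))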


@ -542,7 +542,6 @@ def set_cc_opt_flags(environ_cp):
for opt in cc_opt_flags.split():
write_to_bazelrc('build:opt --copt=%s' % opt)
write_to_bazelrc('build:opt --host_copt=%s' % opt)
write_to_bazelrc('build:opt --define with_default_optimizations=true')
def set_tf_cuda_clang(environ_cp):
@ -1202,13 +1201,12 @@ def config_info_line(name, help_text):
print('\t--config=%-12s\t# %s' % (name, help_text))
def configure_ios():
"""Configures TensorFlow for iOS builds.
This function will only be executed if `is_macos()` is true.
"""
def configure_ios(environ_cp):
"""Configures TensorFlow for iOS builds."""
if not is_macos():
return
if not get_var(environ_cp, 'TF_CONFIGURE_IOS', 'iOS', False):
return
for filepath in APPLE_BAZEL_FILES:
existing_filepath = os.path.join(_TF_WORKSPACE_ROOT, filepath + '.apple')
renamed_filepath = os.path.join(_TF_WORKSPACE_ROOT, filepath)
@ -1327,11 +1325,11 @@ def main():
if is_macos():
environ_cp['TF_NEED_TENSORRT'] = '0'
else:
environ_cp['TF_CONFIGURE_IOS'] = '0'
if environ_cp.get('TF_ENABLE_XLA', '1') == '1':
write_to_bazelrc('build --config=xla')
with_xla_support = environ_cp.get('TF_ENABLE_XLA', None)
if with_xla_support is not None:
write_to_bazelrc('build --define=with_xla_support=%s' % (
'true' if int(with_xla_support) else 'false'))
set_action_env_var(
environ_cp, 'TF_NEED_ROCM', 'ROCm', False, bazel_config_name='rocm')
@ -1450,9 +1448,7 @@ def main():
system_specific_test_config(environ_cp)
set_action_env_var(environ_cp, 'TF_CONFIGURE_IOS', 'iOS', False)
if environ_cp.get('TF_CONFIGURE_IOS') == '1':
configure_ios()
configure_ios(environ_cp)
print('Preconfigured Bazel build configs. You can use any of the below by '
'adding "--config=<>" to your build command. See .bazelrc for more '


@ -936,6 +936,7 @@ tf_cc_shared_object(
"//tensorflow/core/common_runtime:core_cpu_impl",
"//tensorflow/core:framework_internal_impl",
"//tensorflow/core/common_runtime/gpu:gpu_runtime_impl",
"//tensorflow/core/common_runtime/pluggable_device:pluggable_device_runtime_impl",
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry_impl",
"//tensorflow/core:lib_internal_impl",
"//tensorflow/core/profiler:profiler_impl",


@ -85,13 +85,12 @@ if _module_dir:
setattr(_current_module, "estimator", estimator)
if _os.environ.get("_PREFER_OSS_KERAS", False):
try:
from keras.api._v2 import keras
_current_module.__path__ = (
[_module_util.get_parent_dir(keras)] + _current_module.__path__)
setattr(_current_module, "keras", keras)
except ImportError:
pass
_keras_module = "keras.api._v2.keras"
keras = _LazyLoader("keras", globals(), _keras_module)
_module_dir = _module_util.get_parent_dir_for_name(_keras_module)
if _module_dir:
_current_module.__path__ = [_module_dir] + _current_module.__path__
setattr(_current_module, "keras", keras)
else:
try:
from .python.keras.api._v2 import keras
@ -160,14 +159,33 @@ if _running_from_pip_package():
# Add module aliases
if hasattr(_current_module, 'keras'):
losses = keras.losses
metrics = keras.metrics
optimizers = keras.optimizers
initializers = keras.initializers
setattr(_current_module, "losses", losses)
setattr(_current_module, "metrics", metrics)
setattr(_current_module, "optimizers", optimizers)
setattr(_current_module, "initializers", initializers)
# It is possible that keras is a lazily loaded module, which might break when
# actually trying to import it. Use a try/except to make sure it doesn't break
# while it is doing some very early loading, like tf.compat.v2, etc.
if _os.environ.get("_PREFER_OSS_KERAS", False):
try:
_keras_package = "keras.api._v2.keras."
losses = _LazyLoader("losses", globals(), _keras_package + "losses")
metrics = _LazyLoader("metrics", globals(), _keras_package + "metrics")
optimizers = _LazyLoader(
"optimizers", globals(), _keras_package + "optimizers")
initializers = _LazyLoader(
"initializers", globals(), _keras_package + "initializers")
setattr(_current_module, "losses", losses)
setattr(_current_module, "metrics", metrics)
setattr(_current_module, "optimizers", optimizers)
setattr(_current_module, "initializers", initializers)
except ImportError:
pass
else:
losses = keras.losses
metrics = keras.metrics
optimizers = keras.optimizers
initializers = keras.initializers
setattr(_current_module, "losses", losses)
setattr(_current_module, "metrics", metrics)
setattr(_current_module, "optimizers", optimizers)
setattr(_current_module, "initializers", initializers)
# pylint: enable=undefined-variable
# Delete modules that should be hidden from dir().
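
The hunk above swaps an eager `from keras.api._v2 import keras` import for a `_LazyLoader` proxy, so `keras` is only imported when first touched. The snippet below is a minimal, self-contained sketch of that lazy-loading pattern; it is not TensorFlow's actual `LazyLoader` implementation, and the `LazyModule` name is hypothetical.

import importlib
import types


class LazyModule(types.ModuleType):
  """Defers importing `module_name` until an attribute is first accessed."""

  def __init__(self, local_name, parent_globals, module_name):
    super().__init__(local_name)
    self._parent_globals = parent_globals
    self._module_name = module_name

  def _load(self):
    module = importlib.import_module(self._module_name)
    # Replace the proxy in the parent namespace so later lookups are direct.
    self._parent_globals[self.__name__] = module
    self.__dict__.update(module.__dict__)
    return module

  def __getattr__(self, item):
    return getattr(self._load(), item)


# Hypothetical usage mirroring the diff: nothing is imported until, e.g.,
# keras.losses is first accessed.
keras = LazyModule("keras", globals(), "keras.api._v2.keras")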


@ -76,13 +76,12 @@ if _module_dir:
setattr(_current_module, "estimator", estimator)
if _os.environ.get("_PREFER_OSS_KERAS", False):
try:
from keras.api._v1 import keras
_current_module.__path__ = (
[_module_util.get_parent_dir(keras)] + _current_module.__path__)
setattr(_current_module, "keras", keras)
except ImportError:
pass
_keras_module = "keras.api._v1.keras"
keras = _LazyLoader("keras", globals(), _keras_module)
_module_dir = _module_util.get_parent_dir_for_name(_keras_module)
if _module_dir:
_current_module.__path__ = [_module_dir] + _current_module.__path__
setattr(_current_module, "keras", keras)
else:
try:
from .python.keras.api._v1 import keras


@ -429,6 +429,7 @@ tf_cuda_library(
"//tensorflow/core/common_runtime/eager:context",
"//tensorflow/core/common_runtime/eager:core",
"//tensorflow/core/common_runtime/eager:eager_operation",
"//tensorflow/core/common_runtime/pluggable_device:pluggable_device_plugin_init",
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
"//tensorflow/core/platform",
"//tensorflow/core/platform:blocking_counter",


@ -28,6 +28,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
#include "tensorflow/core/common_runtime/pluggable_device/pluggable_device_plugin_init.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
#include "tensorflow/core/framework/collective.h"
#include "tensorflow/core/framework/node_def.pb.h"
@ -735,7 +736,10 @@ TF_Library* TF_LoadPluggableDeviceLibrary(const char* library_filename,
} else {
status->status =
env->LoadDynamicLibrary(library_filename, &lib_handle->lib_handle);
if (!status->status.ok()) {
if (status->status.ok()) {
TF_CHECK_OK(
tensorflow::RegisterPluggableDevicePlugin(lib_handle->lib_handle));
} else {
delete lib_handle;
return nullptr;
}


@ -236,6 +236,8 @@ TEST_F(ShapeInferenceTest, InfersShapesFromInputTensors) {
}
TEST(CAPI_EXPERIMENTAL, LibraryPluggableDeviceLoadFunctions) {
// TODO(penpornk): Enable this test on Windows.
#if !defined(PLATFORM_WINDOWS)
#if !defined(TENSORFLOW_NO_SHARED_OBJECTS)
// Load the library.
TF_Status* status = TF_NewStatus();
@ -250,6 +252,7 @@ TEST(CAPI_EXPERIMENTAL, LibraryPluggableDeviceLoadFunctions) {
ASSERT_EQ(TF_OK, code) << status_msg;
TF_DeletePluggableDeviceLibraryHandle(lib);
#endif // !defined(TENSORFLOW_NO_SHARED_OBJECTS)
#endif // !defined(PLATFORM_WINDOWS)
}
} // namespace


@ -32,11 +32,7 @@ cc_library(
],
deps = [
"//tensorflow/c/eager:abstract_context",
"//tensorflow/c/eager:abstract_operation",
"//tensorflow/c/eager:abstract_tensor_handle",
"//tensorflow/c/eager:c_api_unified_internal",
"//tensorflow/c/eager:gradients_internal",
"//tensorflow/core/lib/llvm_rtti",
],
)
@ -93,7 +89,6 @@ cc_library(
deps = [
"//tensorflow/c/eager:abstract_context",
"//tensorflow/c/eager:gradients_internal",
"@com_google_absl//absl/types:span",
],
)
@ -113,6 +108,7 @@ cc_library(
":math_grad",
":nn_grad",
":not_differentiable",
"//tensorflow/c/eager:abstract_context",
"//tensorflow/c/eager:gradients_internal",
],
)
@ -188,6 +184,8 @@ tf_cuda_cc_test(
":nn_grad",
"//tensorflow/c:tf_status_helper",
"//tensorflow/c/eager:c_api_test_util",
"//tensorflow/c/eager:c_api_unified_internal",
"//tensorflow/c/eager:unified_api_testutil",
"//tensorflow/c/experimental/gradients/tape:tape_context",
"//tensorflow/c/experimental/ops:nn_ops",
"//tensorflow/core/platform:tensor_float_32_utils",
@ -213,6 +211,8 @@ tf_cuda_cc_test(
":math_grad",
"//tensorflow/c:tf_status_helper",
"//tensorflow/c/eager:c_api_test_util",
"//tensorflow/c/eager:c_api_unified_internal",
"//tensorflow/c/eager:unified_api_testutil",
"//tensorflow/c/experimental/gradients/tape:tape_context",
"//tensorflow/c/experimental/ops:math_ops",
"//tensorflow/core/platform:tensor_float_32_utils",
@ -243,6 +243,8 @@ tf_cuda_cc_test(
"//tensorflow/core/platform:tensor_float_32_utils",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/c/eager:c_api_unified_internal",
"//tensorflow/c/eager:unified_api_testutil",
] + if_libtpu(
if_false = ["//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration"],
if_true = [],


@ -17,8 +17,6 @@ cc_library(
deps = [
":tape_operation",
"//tensorflow/c/eager:abstract_context",
"//tensorflow/c/eager:abstract_function",
"//tensorflow/c/eager:abstract_operation",
],
)
@ -33,7 +31,6 @@ cc_library(
],
deps = [
"//tensorflow/c/eager:abstract_context",
"//tensorflow/c/eager:abstract_function",
"//tensorflow/c/eager:abstract_operation",
"//tensorflow/c/eager:gradients_internal",
],
@ -51,6 +48,9 @@ cc_library(
deps = [
":tape_context",
":tape_operation",
"//tensorflow/c/eager:abstract_context",
"//tensorflow/c/eager:abstract_operation",
"//tensorflow/c/eager:gradients_internal",
],
)


@ -60,7 +60,10 @@ cc_library(
"stream_executor.h",
"stream_executor_internal.h",
],
visibility = ["//tensorflow/c:__subpackages__"],
visibility = [
"//tensorflow/c:__subpackages__",
"//tensorflow/core/common_runtime/pluggable_device:__subpackages__",
],
deps = [
"//tensorflow/c:c_api_macros",
"//tensorflow/c:tf_status",
@ -76,6 +79,7 @@ tf_cc_test(
deps = [
":stream_executor",
":stream_executor_internal",
":stream_executor_test_util",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/protobuf:error_codes_proto_impl_cc",
@ -84,3 +88,14 @@ tf_cc_test(
"//tensorflow/stream_executor:stream_executor_pimpl",
],
)
cc_library(
name = "stream_executor_test_util",
srcs = ["stream_executor_test_util.cc"],
hdrs = ["stream_executor_test_util.h"],
visibility = ["//tensorflow:internal"],
deps = [
":stream_executor_hdrs",
"//tensorflow/c:tf_status",
],
)


@ -749,7 +749,9 @@ port::StatusOr<std::unique_ptr<StreamExecutor>> CPlatform::GetUncachedExecutor(
return result;
}
port::Status InitStreamExecutorPlugin(void* dso_handle) {
port::Status InitStreamExecutorPlugin(void* dso_handle,
std::string* device_type,
std::string* platform_name) {
tensorflow::Env* env = tensorflow::Env::Default();
// Step 1: Load symbol for `TF_InitPlugin`
@ -759,10 +761,12 @@ port::Status InitStreamExecutorPlugin(void* dso_handle) {
// Step 2: Call `TF_InitPlugin`
auto init_fn = reinterpret_cast<SEInitPluginFn>(dso_symbol);
return InitStreamExecutorPlugin(init_fn);
return InitStreamExecutorPlugin(init_fn, device_type, platform_name);
}
port::Status InitStreamExecutorPlugin(SEInitPluginFn init_fn) {
port::Status InitStreamExecutorPlugin(SEInitPluginFn init_fn,
std::string* device_type,
std::string* platform_name) {
SE_PlatformRegistrationParams params{
SE_PLATFORM_REGISTRATION_PARAMS_STRUCT_SIZE};
SP_Platform platform{SP_PLATFORM_STRUCT_SIZE};
@ -808,7 +812,8 @@ port::Status InitStreamExecutorPlugin(SEInitPluginFn init_fn) {
TF_RETURN_IF_ERROR(ValidateSPTimerFns(timer_fns));
// Register new platform
std::string platform_name = std::string(platform.name);
*device_type = std::string(platform.type);
*platform_name = std::string(platform.name);
std::unique_ptr<stream_executor::CPlatform> cplatform(
new stream_executor::CPlatform(
std::move(platform), params.destroy_platform, std::move(platform_fns),
@ -816,8 +821,8 @@ port::Status InitStreamExecutorPlugin(SEInitPluginFn init_fn) {
std::move(timer_fns)));
SE_CHECK_OK(stream_executor::MultiPlatformManager::RegisterPlatform(
std::move(cplatform)));
// TODO(annarev): Add pluggable device registration here.
// TODO(annarev): Return `use_bfc_allocator` value in some way so that it is
// available in `PluggableDeviceProcessState` once the latter is checked in.
return port::Status::OK();
}
} // namespace stream_executor


@ -431,10 +431,13 @@ typedef struct SP_Platform {
// Whether this platform supports unified memory.
// Unified memory is a single memory address space accessible from any device.
TF_Bool supports_unified_memory;
// Whether to wrap allocator for this device with an allocator that uses BFC
// (best-fit with coalescing) strategy.
TF_Bool use_bfc_allocator;
} SP_Platform;
#define SP_PLATFORM_STRUCT_SIZE \
TF_OFFSET_OF_END(SP_Platform, supports_unified_memory)
#define SP_PLATFORM_STRUCT_SIZE TF_OFFSET_OF_END(SP_Platform, use_bfc_allocator)
typedef struct SP_PlatformFns {
size_t struct_size;


@ -31,12 +31,17 @@ namespace stream_executor {
typedef void (*SEInitPluginFn)(SE_PlatformRegistrationParams* const,
TF_Status* const);
// Registers StreamExecutor platform.
port::Status InitStreamExecutorPlugin(void* dso_handle);
// Registers StreamExecutor platform. `device_type` and `platform_name` are
// output parameters.
port::Status InitStreamExecutorPlugin(void* dso_handle,
std::string* device_type,
std::string* platform_name);
// Allow registering a StreamExecutor plugin using a function (used for
// testing).
port::Status InitStreamExecutorPlugin(SEInitPluginFn init_fn);
port::Status InitStreamExecutorPlugin(SEInitPluginFn init_fn,
std::string* device_type,
std::string* platform_name);
struct TFStatusDeleter {
void operator()(TF_Status* s) const { TF_DeleteStatus(s); }


@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/c/experimental/stream_executor/stream_executor.h"
#include "tensorflow/c/experimental/stream_executor/stream_executor_internal.h"
#include "tensorflow/c/experimental/stream_executor/stream_executor_test_util.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/protobuf/error_codes.pb.h"
@ -24,209 +25,26 @@ limitations under the License.
#include "tensorflow/stream_executor/stream_executor_pimpl.h"
#include "tensorflow/stream_executor/timer.h"
struct SP_Stream_st {
explicit SP_Stream_st(int id) : stream_id(id) {}
int stream_id;
};
struct SP_Event_st {
explicit SP_Event_st(int id) : event_id(id) {}
int event_id;
};
struct SP_Timer_st {
explicit SP_Timer_st(int id) : timer_id(id) {}
int timer_id;
};
namespace stream_executor {
namespace {
constexpr int kDeviceCount = 2;
constexpr char kDeviceName[] = "MY_DEVICE";
constexpr char kDeviceType[] = "GPU";
/*** Create SP_StreamExecutor (with empty functions) ***/
void allocate(const SP_Device* const device, uint64_t size,
int64_t memory_space, SP_DeviceMemoryBase* const mem) {}
void deallocate(const SP_Device* const device, SP_DeviceMemoryBase* const mem) {
}
void* host_memory_allocate(const SP_Device* const device, uint64_t size) {
return nullptr;
}
void host_memory_deallocate(const SP_Device* const device, void* mem) {}
TF_Bool get_allocator_stats(const SP_Device* const device,
SP_AllocatorStats* const stats) {
return true;
}
TF_Bool device_memory_usage(const SP_Device* const device, int64_t* const free,
int64_t* const total) {
return true;
}
void create_stream(const SP_Device* const device, SP_Stream* stream,
TF_Status* const status) {
stream = nullptr;
}
void destroy_stream(const SP_Device* const device, SP_Stream stream) {}
void create_stream_dependency(const SP_Device* const device,
SP_Stream dependent, SP_Stream other,
TF_Status* const status) {}
void get_stream_status(const SP_Device* const device, SP_Stream stream,
TF_Status* const status) {}
void create_event(const SP_Device* const device, SP_Event* event,
TF_Status* const status) {
event = nullptr;
}
void destroy_event(const SP_Device* const device, SP_Event event) {}
SE_EventStatus get_event_status(const SP_Device* const device, SP_Event event) {
return SE_EVENT_UNKNOWN;
}
void record_event(const SP_Device* const device, SP_Stream stream,
SP_Event event, TF_Status* const status) {}
void wait_for_event(const SP_Device* const device, SP_Stream stream,
SP_Event event, TF_Status* const status) {}
void create_timer(const SP_Device* const device, SP_Timer* timer,
TF_Status* const status) {}
void destroy_timer(const SP_Device* const device, SP_Timer timer) {}
void start_timer(const SP_Device* const device, SP_Stream stream,
SP_Timer timer, TF_Status* const status) {}
void stop_timer(const SP_Device* const device, SP_Stream stream, SP_Timer timer,
TF_Status* const status) {}
void memcpy_dtoh(const SP_Device* const device, SP_Stream stream,
void* host_dst, const SP_DeviceMemoryBase* const device_src,
uint64_t size, TF_Status* const status) {}
void memcpy_htod(const SP_Device* const device, SP_Stream stream,
SP_DeviceMemoryBase* const device_dst, const void* host_src,
uint64_t size, TF_Status* const status) {}
void sync_memcpy_dtoh(const SP_Device* const device, void* host_dst,
const SP_DeviceMemoryBase* const device_src,
uint64_t size, TF_Status* const status) {}
void sync_memcpy_htod(const SP_Device* const device,
SP_DeviceMemoryBase* const device_dst,
const void* host_src, uint64_t size,
TF_Status* const status) {}
void block_host_for_event(const SP_Device* const device, SP_Event event,
TF_Status* const status) {}
void synchronize_all_activity(const SP_Device* const device,
TF_Status* const status) {}
TF_Bool host_callback(const SP_Device* const device, SP_Stream stream,
SE_StatusCallbackFn const callback_fn,
void* const callback_arg) {
return true;
}
void PopulateDefaultStreamExecutor(SP_StreamExecutor* se) {
*se = {SP_STREAMEXECUTOR_STRUCT_SIZE};
se->allocate = allocate;
se->deallocate = deallocate;
se->host_memory_allocate = host_memory_allocate;
se->host_memory_deallocate = host_memory_deallocate;
se->get_allocator_stats = get_allocator_stats;
se->device_memory_usage = device_memory_usage;
se->create_stream = create_stream;
se->destroy_stream = destroy_stream;
se->create_stream_dependency = create_stream_dependency;
se->get_stream_status = get_stream_status;
se->create_event = create_event;
se->destroy_event = destroy_event;
se->get_event_status = get_event_status;
se->record_event = record_event;
se->wait_for_event = wait_for_event;
se->create_timer = create_timer;
se->destroy_timer = destroy_timer;
se->start_timer = start_timer;
se->stop_timer = stop_timer;
se->memcpy_dtoh = memcpy_dtoh;
se->memcpy_htod = memcpy_htod;
se->sync_memcpy_dtoh = sync_memcpy_dtoh;
se->sync_memcpy_htod = sync_memcpy_htod;
se->block_host_for_event = block_host_for_event;
se->synchronize_all_activity = synchronize_all_activity;
se->host_callback = host_callback;
}
void PopulateDefaultDeviceFns(SP_DeviceFns* device_fns) {
*device_fns = {SP_DEVICE_FNS_STRUCT_SIZE};
}
/*** Create SP_TimerFns ***/
uint64_t nanoseconds(SP_Timer timer) { return timer->timer_id; }
void PopulateDefaultTimerFns(SP_TimerFns* timer_fns) {
timer_fns->nanoseconds = nanoseconds;
}
/*** Create SP_Platform ***/
void create_timer_fns(const SP_Platform* platform, SP_TimerFns* timer_fns,
TF_Status* status) {
TF_SetStatus(status, TF_OK, "");
PopulateDefaultTimerFns(timer_fns);
}
void destroy_timer_fns(const SP_Platform* platform, SP_TimerFns* timer_fns) {}
void create_stream_executor(const SP_Platform* platform,
SE_CreateStreamExecutorParams* params,
TF_Status* status) {
TF_SetStatus(status, TF_OK, "");
PopulateDefaultStreamExecutor(params->stream_executor);
}
void destroy_stream_executor(const SP_Platform* platform,
SP_StreamExecutor* se) {}
void get_device_count(const SP_Platform* platform, int* device_count,
TF_Status* status) {
TF_SetStatus(status, TF_OK, "");
*device_count = kDeviceCount;
}
void create_device(const SP_Platform* platform, SE_CreateDeviceParams* params,
TF_Status* status) {
TF_SetStatus(status, TF_OK, "");
params->device->struct_size = {SP_DEVICE_STRUCT_SIZE};
}
void destroy_device(const SP_Platform* platform, SP_Device* device) {}
void create_device_fns(const SP_Platform* platform,
SE_CreateDeviceFnsParams* params, TF_Status* status) {
TF_SetStatus(status, TF_OK, "");
params->device_fns->struct_size = {SP_DEVICE_FNS_STRUCT_SIZE};
}
void destroy_device_fns(const SP_Platform* platform, SP_DeviceFns* device_fns) {
}
void PopulateDefaultPlatform(SP_Platform* platform,
SP_PlatformFns* platform_fns) {
*platform = {SP_PLATFORM_STRUCT_SIZE};
platform->name = kDeviceName;
platform->type = kDeviceType;
platform_fns->get_device_count = get_device_count;
platform_fns->create_device = create_device;
platform_fns->destroy_device = destroy_device;
platform_fns->create_device_fns = create_device_fns;
platform_fns->destroy_device_fns = destroy_device_fns;
platform_fns->create_stream_executor = create_stream_executor;
platform_fns->destroy_stream_executor = destroy_stream_executor;
platform_fns->create_timer_fns = create_timer_fns;
platform_fns->destroy_timer_fns = destroy_timer_fns;
}
void destroy_platform(SP_Platform* const platform) {}
void destroy_platform_fns(SP_PlatformFns* const platform_fns) {}
/*** Registration tests ***/
TEST(StreamExecutor, SuccessfulRegistration) {
auto plugin_init = [](SE_PlatformRegistrationParams* const params,
TF_Status* const status) -> void {
TF_SetStatus(status, TF_OK, "");
PopulateDefaultPlatform(params->platform, params->platform_fns);
params->destroy_platform = destroy_platform;
params->destroy_platform_fns = destroy_platform_fns;
test_util::PopulateDefaultPlatformRegistrationParams(params);
};
port::Status status = InitStreamExecutorPlugin(plugin_init);
string device_type, platform_name;
port::Status status =
InitStreamExecutorPlugin(plugin_init, &device_type, &platform_name);
TF_ASSERT_OK(status);
port::StatusOr<Platform*> maybe_platform =
MultiPlatformManager::PlatformWithName("MY_DEVICE");
TF_ASSERT_OK(maybe_platform.status());
Platform* platform = maybe_platform.ConsumeValueOrDie();
ASSERT_EQ(platform->Name(), kDeviceName);
ASSERT_EQ(platform->VisibleDeviceCount(), kDeviceCount);
ASSERT_EQ(platform->Name(), test_util::kDeviceName);
ASSERT_EQ(platform->VisibleDeviceCount(), test_util::kDeviceCount);
port::StatusOr<StreamExecutor*> maybe_executor =
platform->ExecutorForDevice(0);
@ -237,13 +55,13 @@ TEST(StreamExecutor, NameNotSet) {
auto plugin_init = [](SE_PlatformRegistrationParams* const params,
TF_Status* const status) -> void {
TF_SetStatus(status, TF_OK, "");
PopulateDefaultPlatform(params->platform, params->platform_fns);
test_util::PopulateDefaultPlatformRegistrationParams(params);
params->platform->name = nullptr;
params->destroy_platform = destroy_platform;
params->destroy_platform_fns = destroy_platform_fns;
};
port::Status status = InitStreamExecutorPlugin(plugin_init);
string device_type, platform_name;
port::Status status =
InitStreamExecutorPlugin(plugin_init, &device_type, &platform_name);
ASSERT_EQ(status.code(), tensorflow::error::FAILED_PRECONDITION);
ASSERT_EQ(status.error_message(), "'name' field in SP_Platform must be set.");
}
@ -252,13 +70,13 @@ TEST(StreamExecutor, InvalidNameWithSemicolon) {
auto plugin_init = [](SE_PlatformRegistrationParams* const params,
TF_Status* const status) -> void {
TF_SetStatus(status, TF_OK, "");
PopulateDefaultPlatform(params->platform, params->platform_fns);
test_util::PopulateDefaultPlatformRegistrationParams(params);
params->platform->name = "INVALID:NAME";
params->destroy_platform = destroy_platform;
params->destroy_platform_fns = destroy_platform_fns;
};
port::Status status = InitStreamExecutorPlugin(plugin_init);
string device_type, platform_name;
port::Status status =
InitStreamExecutorPlugin(plugin_init, &device_type, &platform_name);
ASSERT_EQ(status.code(), tensorflow::error::FAILED_PRECONDITION);
EXPECT_THAT(
status.error_message(),
@ -269,13 +87,13 @@ TEST(StreamExecutor, InvalidNameWithSlash) {
auto plugin_init = [](SE_PlatformRegistrationParams* const params,
TF_Status* const status) -> void {
TF_SetStatus(status, TF_OK, "");
PopulateDefaultPlatform(params->platform, params->platform_fns);
test_util::PopulateDefaultPlatformRegistrationParams(params);
params->platform->name = "INVALID/";
params->destroy_platform = destroy_platform;
params->destroy_platform_fns = destroy_platform_fns;
};
port::Status status = InitStreamExecutorPlugin(plugin_init);
string device_type, platform_name;
port::Status status =
InitStreamExecutorPlugin(plugin_init, &device_type, &platform_name);
ASSERT_EQ(status.code(), tensorflow::error::FAILED_PRECONDITION);
EXPECT_THAT(status.error_message(),
testing::ContainsRegex("Device name/type 'INVALID/' must match"));
@ -285,13 +103,13 @@ TEST(StreamExecutor, CreateDeviceNotSet) {
auto plugin_init = [](SE_PlatformRegistrationParams* const params,
TF_Status* const status) -> void {
TF_SetStatus(status, TF_OK, "");
PopulateDefaultPlatform(params->platform, params->platform_fns);
test_util::PopulateDefaultPlatformRegistrationParams(params);
params->platform_fns->create_device = nullptr;
params->destroy_platform = destroy_platform;
params->destroy_platform_fns = destroy_platform_fns;
};
port::Status status = InitStreamExecutorPlugin(plugin_init);
string device_type, platform_name;
port::Status status =
InitStreamExecutorPlugin(plugin_init, &device_type, &platform_name);
ASSERT_EQ(status.code(), tensorflow::error::FAILED_PRECONDITION);
ASSERT_EQ(status.error_message(),
"'create_device' field in SP_PlatformFns must be set.");
@ -301,13 +119,13 @@ TEST(StreamExecutor, UnifiedMemoryAllocateNotSet) {
auto plugin_init = [](SE_PlatformRegistrationParams* const params,
TF_Status* const status) -> void {
TF_SetStatus(status, TF_OK, "");
PopulateDefaultPlatform(params->platform, params->platform_fns);
test_util::PopulateDefaultPlatformRegistrationParams(params);
params->platform->supports_unified_memory = true;
params->destroy_platform = destroy_platform;
params->destroy_platform_fns = destroy_platform_fns;
};
port::Status status = InitStreamExecutorPlugin(plugin_init);
string device_type, platform_name;
port::Status status =
InitStreamExecutorPlugin(plugin_init, &device_type, &platform_name);
ASSERT_EQ(status.code(), tensorflow::error::FAILED_PRECONDITION);
ASSERT_EQ(
status.error_message(),
@ -319,18 +137,18 @@ class StreamExecutorTest : public ::testing::Test {
protected:
StreamExecutorTest() {}
void SetUp() override {
PopulateDefaultPlatform(&platform_, &platform_fns_);
PopulateDefaultDeviceFns(&device_fns_);
PopulateDefaultStreamExecutor(&se_);
PopulateDefaultTimerFns(&timer_fns_);
test_util::PopulateDefaultPlatform(&platform_, &platform_fns_);
test_util::PopulateDefaultDeviceFns(&device_fns_);
test_util::PopulateDefaultStreamExecutor(&se_);
test_util::PopulateDefaultTimerFns(&timer_fns_);
}
void TearDown() override {}
StreamExecutor* GetExecutor(int ordinal) {
if (!cplatform_) {
cplatform_ = absl::make_unique<CPlatform>(
platform_, destroy_platform, platform_fns_, destroy_platform_fns,
device_fns_, se_, timer_fns_);
platform_, test_util::DestroyPlatform, platform_fns_,
test_util::DestroyPlatformFns, device_fns_, se_, timer_fns_);
}
port::StatusOr<StreamExecutor*> maybe_executor =
cplatform_->ExecutorForDevice(ordinal);


@ -0,0 +1,193 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/stream_executor/stream_executor_test_util.h"
#include "tensorflow/c/experimental/stream_executor/stream_executor.h"
namespace stream_executor {
namespace test_util {
/*** Functions for creating SP_StreamExecutor ***/
void Allocate(const SP_Device* const device, uint64_t size,
int64_t memory_space, SP_DeviceMemoryBase* const mem) {}
void Deallocate(const SP_Device* const device, SP_DeviceMemoryBase* const mem) {
}
void* HostMemoryAllocate(const SP_Device* const device, uint64_t size) {
return nullptr;
}
void HostMemoryDeallocate(const SP_Device* const device, void* mem) {}
TF_Bool GetAllocatorStats(const SP_Device* const device,
SP_AllocatorStats* const stats) {
return true;
}
TF_Bool DeviceMemoryUsage(const SP_Device* const device, int64_t* const free,
int64_t* const total) {
return true;
}
void CreateStream(const SP_Device* const device, SP_Stream* stream,
TF_Status* const status) {
stream = nullptr;
}
void DestroyStream(const SP_Device* const device, SP_Stream stream) {}
void CreateStreamDependency(const SP_Device* const device, SP_Stream dependent,
SP_Stream other, TF_Status* const status) {}
void GetStreamStatus(const SP_Device* const device, SP_Stream stream,
TF_Status* const status) {}
void CreateEvent(const SP_Device* const device, SP_Event* event,
TF_Status* const status) {
event = nullptr;
}
void DestroyEvent(const SP_Device* const device, SP_Event event) {}
SE_EventStatus GetEventStatus(const SP_Device* const device, SP_Event event) {
return SE_EVENT_UNKNOWN;
}
void RecordEvent(const SP_Device* const device, SP_Stream stream,
SP_Event event, TF_Status* const status) {}
void WaitForEvent(const SP_Device* const device, SP_Stream stream,
SP_Event event, TF_Status* const status) {}
void CreateTimer(const SP_Device* const device, SP_Timer* timer,
TF_Status* const status) {}
void DestroyTimer(const SP_Device* const device, SP_Timer timer) {}
void StartTimer(const SP_Device* const device, SP_Stream stream, SP_Timer timer,
TF_Status* const status) {}
void StopTimer(const SP_Device* const device, SP_Stream stream, SP_Timer timer,
TF_Status* const status) {}
void MemcpyDToH(const SP_Device* const device, SP_Stream stream, void* host_dst,
const SP_DeviceMemoryBase* const device_src, uint64_t size,
TF_Status* const status) {}
void MemcpyHToD(const SP_Device* const device, SP_Stream stream,
SP_DeviceMemoryBase* const device_dst, const void* host_src,
uint64_t size, TF_Status* const status) {}
void SyncMemcpyDToH(const SP_Device* const device, void* host_dst,
const SP_DeviceMemoryBase* const device_src, uint64_t size,
TF_Status* const status) {}
void SyncMemcpyHToD(const SP_Device* const device,
SP_DeviceMemoryBase* const device_dst, const void* host_src,
uint64_t size, TF_Status* const status) {}
void BlockHostForEvent(const SP_Device* const device, SP_Event event,
TF_Status* const status) {}
void SynchronizeAllActivity(const SP_Device* const device,
TF_Status* const status) {}
TF_Bool HostCallback(const SP_Device* const device, SP_Stream stream,
SE_StatusCallbackFn const callback_fn,
void* const callback_arg) {
return true;
}
void PopulateDefaultStreamExecutor(SP_StreamExecutor* se) {
*se = {SP_STREAMEXECUTOR_STRUCT_SIZE};
se->allocate = Allocate;
se->deallocate = Deallocate;
se->host_memory_allocate = HostMemoryAllocate;
se->host_memory_deallocate = HostMemoryDeallocate;
se->get_allocator_stats = GetAllocatorStats;
se->device_memory_usage = DeviceMemoryUsage;
se->create_stream = CreateStream;
se->destroy_stream = DestroyStream;
se->create_stream_dependency = CreateStreamDependency;
se->get_stream_status = GetStreamStatus;
se->create_event = CreateEvent;
se->destroy_event = DestroyEvent;
se->get_event_status = GetEventStatus;
se->record_event = RecordEvent;
se->wait_for_event = WaitForEvent;
se->create_timer = CreateTimer;
se->destroy_timer = DestroyTimer;
se->start_timer = StartTimer;
se->stop_timer = StopTimer;
se->memcpy_dtoh = MemcpyDToH;
se->memcpy_htod = MemcpyHToD;
se->sync_memcpy_dtoh = SyncMemcpyDToH;
se->sync_memcpy_htod = SyncMemcpyHToD;
se->block_host_for_event = BlockHostForEvent;
se->synchronize_all_activity = SynchronizeAllActivity;
se->host_callback = HostCallback;
}
void PopulateDefaultDeviceFns(SP_DeviceFns* device_fns) {
*device_fns = {SP_DEVICE_FNS_STRUCT_SIZE};
}
/*** Functions for creating SP_TimerFns ***/
uint64_t Nanoseconds(SP_Timer timer) { return timer->timer_id; }
void PopulateDefaultTimerFns(SP_TimerFns* timer_fns) {
timer_fns->nanoseconds = Nanoseconds;
}
/*** Functions for creating SP_Platform ***/
void CreateTimerFns(const SP_Platform* platform, SP_TimerFns* timer_fns,
TF_Status* status) {
TF_SetStatus(status, TF_OK, "");
PopulateDefaultTimerFns(timer_fns);
}
void DestroyTimerFns(const SP_Platform* platform, SP_TimerFns* timer_fns) {}
void CreateStreamExecutor(const SP_Platform* platform,
SE_CreateStreamExecutorParams* params,
TF_Status* status) {
TF_SetStatus(status, TF_OK, "");
PopulateDefaultStreamExecutor(params->stream_executor);
}
void DestroyStreamExecutor(const SP_Platform* platform, SP_StreamExecutor* se) {
}
void GetDeviceCount(const SP_Platform* platform, int* device_count,
TF_Status* status) {
TF_SetStatus(status, TF_OK, "");
*device_count = kDeviceCount;
}
void CreateDevice(const SP_Platform* platform, SE_CreateDeviceParams* params,
TF_Status* status) {
TF_SetStatus(status, TF_OK, "");
params->device->struct_size = {SP_DEVICE_STRUCT_SIZE};
}
void DestroyDevice(const SP_Platform* platform, SP_Device* device) {}
void CreateDeviceFns(const SP_Platform* platform,
SE_CreateDeviceFnsParams* params, TF_Status* status) {
TF_SetStatus(status, TF_OK, "");
params->device_fns->struct_size = {SP_DEVICE_FNS_STRUCT_SIZE};
}
void DestroyDeviceFns(const SP_Platform* platform, SP_DeviceFns* device_fns) {}
void PopulateDefaultPlatform(SP_Platform* platform,
SP_PlatformFns* platform_fns) {
*platform = {SP_PLATFORM_STRUCT_SIZE};
platform->name = kDeviceName;
platform->type = kDeviceType;
platform_fns->get_device_count = GetDeviceCount;
platform_fns->create_device = CreateDevice;
platform_fns->destroy_device = DestroyDevice;
platform_fns->create_device_fns = CreateDeviceFns;
platform_fns->destroy_device_fns = DestroyDeviceFns;
platform_fns->create_stream_executor = CreateStreamExecutor;
platform_fns->destroy_stream_executor = DestroyStreamExecutor;
platform_fns->create_timer_fns = CreateTimerFns;
platform_fns->destroy_timer_fns = DestroyTimerFns;
}
/*** Functions for creating SE_PlatformRegistrationParams ***/
void DestroyPlatform(SP_Platform* platform) {}
void DestroyPlatformFns(SP_PlatformFns* platform_fns) {}
void PopulateDefaultPlatformRegistrationParams(
SE_PlatformRegistrationParams* const params) {
PopulateDefaultPlatform(params->platform, params->platform_fns);
params->destroy_platform = DestroyPlatform;
params->destroy_platform_fns = DestroyPlatformFns;
}
} // namespace test_util
} // namespace stream_executor


@ -0,0 +1,56 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_STREAM_EXECUTOR_STREAM_EXECUTOR_TEST_UTIL_H_
#define TENSORFLOW_C_EXPERIMENTAL_STREAM_EXECUTOR_STREAM_EXECUTOR_TEST_UTIL_H_
#include "tensorflow/c/experimental/stream_executor/stream_executor.h"
struct SP_Stream_st {
explicit SP_Stream_st(int id) : stream_id(id) {}
int stream_id;
};
struct SP_Event_st {
explicit SP_Event_st(int id) : event_id(id) {}
int event_id;
};
struct SP_Timer_st {
explicit SP_Timer_st(int id) : timer_id(id) {}
int timer_id;
};
namespace stream_executor {
namespace test_util {
constexpr int kDeviceCount = 2;
constexpr char kDeviceName[] = "MY_DEVICE";
constexpr char kDeviceType[] = "GPU";
void PopulateDefaultStreamExecutor(SP_StreamExecutor* se);
void PopulateDefaultDeviceFns(SP_DeviceFns* device_fns);
void PopulateDefaultTimerFns(SP_TimerFns* timer_fns);
void PopulateDefaultPlatform(SP_Platform* platform,
SP_PlatformFns* platform_fns);
void PopulateDefaultPlatformRegistrationParams(
SE_PlatformRegistrationParams* const params);
void DestroyPlatform(SP_Platform* platform);
void DestroyPlatformFns(SP_PlatformFns* platform_fns);
} // namespace test_util
} // namespace stream_executor
#endif // TENSORFLOW_C_EXPERIMENTAL_STREAM_EXECUTOR_STREAM_EXECUTOR_TEST_UTIL_H_

View File

@ -13,5 +13,8 @@ tf_cc_shared_object(
name = "test_pluggable_device.so",
srcs = ["test_pluggable_device.cc"],
visibility = ["//tensorflow/c:__subpackages__"],
deps = ["//tensorflow/c/experimental/stream_executor:stream_executor_hdrs"],
deps = [
"//tensorflow/c/experimental/stream_executor:stream_executor_hdrs",
"//tensorflow/c/experimental/stream_executor:stream_executor_test_util",
],
)

View File

@ -14,10 +14,14 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/stream_executor/stream_executor.h"
#include "tensorflow/c/experimental/stream_executor/stream_executor_test_util.h"
extern "C" {
void SE_InitPlugin(SE_PlatformRegistrationParams* const params,
TF_Status* const status) {
params->platform->struct_size = SP_PLATFORM_STRUCT_SIZE;
params->platform->name = "GPU";
params->platform->type = "XGPU";
stream_executor::test_util::PopulateDefaultPlatformRegistrationParams(params);
}
void TF_InitKernel() {}
}
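
For orientation, a hedged sketch of the loading side. TensorFlow's real plugin loader does considerably more (struct-size validation, platform and kernel registration), so this only shows the dlopen/dlsym handshake against the SE_InitPlugin symbol exported above; the library path is hypothetical.

#include <dlfcn.h>
#include <cstdio>

#include "tensorflow/c/experimental/stream_executor/stream_executor.h"
#include "tensorflow/c/tf_status.h"

// Matches the signature of SE_InitPlugin as defined in the plugin above.
typedef void (*SEInitPluginFn)(SE_PlatformRegistrationParams*, TF_Status*);

int main() {
  void* handle = dlopen("./test_pluggable_device.so", RTLD_NOW | RTLD_LOCAL);
  if (!handle) {
    std::fprintf(stderr, "dlopen failed: %s\n", dlerror());
    return 1;
  }
  auto init = reinterpret_cast<SEInitPluginFn>(dlsym(handle, "SE_InitPlugin"));
  if (!init) {
    std::fprintf(stderr, "dlsym failed: %s\n", dlerror());
    dlclose(handle);
    return 1;
  }
  // The real loader allocates SE_PlatformRegistrationParams (and the SP_*
  // structs it points at) with the proper struct sizes before calling init;
  // that setup is omitted here.
  dlclose(handle);
  return 0;
}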

View File

@ -53,13 +53,12 @@ if _module_dir:
setattr(_current_module, "estimator", estimator)
if _os.environ.get("_PREFER_OSS_KERAS", False):
try:
from keras.api._v2 import keras
_current_module.__path__ = (
[_module_util.get_parent_dir(keras)] + _current_module.__path__)
setattr(_current_module, "keras", keras)
except ImportError:
pass
_keras_module = "keras.api._v2.keras"
keras = _LazyLoader("keras", globals(), _keras_module)
_module_dir = _module_util.get_parent_dir_for_name(_keras_module)
if _module_dir:
_current_module.__path__ = [_module_dir] + _current_module.__path__
setattr(_current_module, "keras", keras)
else:
try:
from tensorflow.python.keras.api._v2 import keras
@ -88,11 +87,30 @@ setattr(_current_module, "enable_v2_behavior", enable_v2_behavior)
# Add module aliases
if hasattr(_current_module, 'keras'):
losses = keras.losses
metrics = keras.metrics
optimizers = keras.optimizers
initializers = keras.initializers
setattr(_current_module, "losses", losses)
setattr(_current_module, "metrics", metrics)
setattr(_current_module, "optimizers", optimizers)
setattr(_current_module, "initializers", initializers)
# It is possible that keras is a lazily loaded module, which might break when
# actually trying to import it. Use a try-except to make sure it doesn't break
# when it is doing some very initial loading, like tf.compat.v2, etc.
if _os.environ.get("_PREFER_OSS_KERAS", False):
try:
_keras_package = "keras.api._v2.keras."
losses = _LazyLoader("losses", globals(), _keras_package + "losses")
metrics = _LazyLoader("metrics", globals(), _keras_package + "metrics")
optimizers = _LazyLoader(
"optimizers", globals(), _keras_package + "optimizers")
initializers = _LazyLoader(
"initializers", globals(), _keras_package + "initializers")
setattr(_current_module, "losses", losses)
setattr(_current_module, "metrics", metrics)
setattr(_current_module, "optimizers", optimizers)
setattr(_current_module, "initializers", initializers)
except ImportError:
pass
else:
losses = keras.losses
metrics = keras.metrics
optimizers = keras.optimizers
initializers = keras.initializers
setattr(_current_module, "losses", losses)
setattr(_current_module, "metrics", metrics)
setattr(_current_module, "optimizers", optimizers)
setattr(_current_module, "initializers", initializers)

View File

@ -43,13 +43,12 @@ if _module_dir:
setattr(_current_module, "estimator", estimator)
if _os.environ.get("_PREFER_OSS_KERAS", False):
try:
from keras.api._v1 import keras
_current_module.__path__ = (
[_module_util.get_parent_dir(keras)] + _current_module.__path__)
setattr(_current_module, "keras", keras)
except ImportError:
pass
_keras_module = "keras.api._v1.keras"
keras = _LazyLoader("keras", globals(), _keras_module)
_module_dir = _module_util.get_parent_dir_for_name(_keras_module)
if _module_dir:
_current_module.__path__ = [_module_dir] + _current_module.__path__
setattr(_current_module, "keras", keras)
else:
try:
from tensorflow.python.keras.api._v1 import keras

View File

@ -163,6 +163,7 @@ void AllocateAndParseFlags() {
ops_flags = new XlaOpsCommonFlags;
ops_flags->tf_xla_always_defer_compilation = false;
ops_flags->tf_xla_async_compilation = false;
jitter_flags = new IntroduceFloatingPointJitterPassFlags;
jitter_flags->jitter_amount = 1e-5;
@ -216,6 +217,10 @@ void AllocateAndParseFlags() {
Flag("tf_xla_always_defer_compilation",
&ops_flags->tf_xla_always_defer_compilation, ""),
Flag("tf_xla_async_compilation", &ops_flags->tf_xla_async_compilation,
"When lazy compilation is enabled, asynchronous compilation starts "
"the cluster compilation in the background, and the fallback path "
"is executed until the compilation has finished."),
Flag("tf_introduce_floating_point_jitter_to_tensors",
setter_for_jitter_tensor_names, "",

View File

@ -99,6 +99,9 @@ struct XlaOpsCommonFlags {
// If true, _XlaCompile always refuses to compile the cluster, which means the
// XLA clusters always run in the TF executor. Defaults to false.
bool tf_xla_always_defer_compilation;
// If true, _XlaCompile compiles the cluster asynchronously with respect to
// the main execution. The fallback path is taken while compilation happens.
bool tf_xla_async_compilation;
};
// Flags for the build_xla_ops pass.

View File

@ -37,7 +37,8 @@ static xla::StatusOr<xla::LocalExecutable*> GetLocalExecutable(
const XlaCompiler::Options& options,
const XlaCompiler::CompileOptions& compile_options,
const NameAttrList& function, XlaCompilationCache* cache,
absl::Span<XlaCompiler::Argument const> args, const XlaCompiler& compiler) {
const std::vector<XlaCompiler::Argument>& args,
const XlaCompiler& compiler) {
const XlaCompiler::CompilationResult* compilation_result = nullptr;
xla::LocalExecutable* executable = nullptr;
TF_RETURN_IF_ERROR(cache->Compile(options, function, args, compile_options,
@ -100,12 +101,10 @@ xla::StatusOr<std::string> GetCompilerIr(
}));
core::ScopedUnref cache_ref(cache);
absl::optional<se::TfAllocatorAdapter> tf_allocator_adapter;
XlaCompiler::Options options =
GenerateCompilerOptions(*cache, *flr, dev,
/*stream=*/nullptr, platform_info,
/*has_ref_vars=*/false, &tf_allocator_adapter);
/*has_ref_vars=*/false);
XlaCompiler::CompileOptions compile_options;
compile_options.always_return_tuple = false;

View File

@ -166,8 +166,9 @@ static Status CompileToLocalExecutable(
const XlaPlatformInfo& platform_info,
absl::Span<const Tensor* const> inputs,
absl::Span<VariableInfo const> variable_infos,
absl::Span<const int> constants, bool lazy, bool may_alias_resource_update,
xla::LocalClient** client,
absl::Span<const int> constants,
XlaCompilationCache::CompileMode compile_mode,
bool may_alias_resource_update, xla::LocalClient** client,
const XlaCompiler::CompilationResult** compilation_result,
xla::LocalExecutable** executable) {
// We store information about the JIT-compiled XLA computation
@ -190,11 +191,10 @@ static Status CompileToLocalExecutable(
*client = static_cast<xla::LocalClient*>(cache->client());
absl::optional<se::TfAllocatorAdapter> tf_allocator_adapter;
XlaCompiler::Options options = GenerateCompilerOptions(
*cache, *ctx->function_library(), ctx->device(),
ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr,
platform_info, has_ref_vars, &tf_allocator_adapter);
platform_info, has_ref_vars);
XlaCompiler::CompileOptions compile_options;
compile_options.is_entry_computation = true;
@ -202,7 +202,6 @@ static Status CompileToLocalExecutable(
// rather than a one-element tuple.
compile_options.always_return_tuple = false;
compile_options.alias_resource_update = !has_ref_vars &&
!platform_info.is_on_xla_device() &&
may_alias_resource_update;
xla::StatusOr<std::vector<XlaCompiler::Argument>> args =
@ -210,9 +209,7 @@ static Status CompileToLocalExecutable(
constants, inputs, variable_infos,
static_cast<Device*>(ctx->device()));
TF_RETURN_IF_ERROR(args.status());
return cache->Compile(options, function, *args, compile_options,
lazy ? XlaCompilationCache::CompileMode::kLazy
: XlaCompilationCache::CompileMode::kStrict,
return cache->Compile(options, function, *args, compile_options, compile_mode,
compilation_result, executable);
}
@ -233,7 +230,7 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
OP_REQUIRES_OK(ctx, LockVariables(absl::MakeSpan(variable_infos)));
Status s = CompileToLocalExecutable(
ctx, function_, /*has_ref_vars=*/has_ref_vars_, platform_info_, inputs,
variable_infos, constants_, /*lazy=*/false,
variable_infos, constants_, XlaCompilationCache::CompileMode::kStrict,
/*may_alias_resource_update=*/true, &client, &compilation_result,
&executable);
OP_REQUIRES_OK(ctx, s);
@ -246,12 +243,9 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
se::Stream* stream =
ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
absl::optional<se::TfAllocatorAdapter> tf_allocator_adapter;
se::DeviceMemoryAllocator* allocator = GetAllocator(
&tf_allocator_adapter, ctx->device(),
ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr,
platform_info_);
std::shared_ptr<se::DeviceMemoryAllocator> allocator_ptr =
GetAllocator(ctx->device(), stream, platform_info_);
se::DeviceMemoryAllocator* allocator = allocator_ptr.get();
int device_ordinal = stream ? stream->parent()->device_ordinal()
: client->default_device_ordinal();
XlaComputationLaunchContext launch_context(
@ -381,6 +375,14 @@ void XlaCompileOp::Compute(OpKernelContext* ctx) {
mutex_lock guard(cannot_compile_cluster_mu_);
cannot_compile_cluster = cannot_compile_cluster_;
}
XlaCompilationCache::CompileMode compile_mode = [&] {
if (must_compile_) {
return XlaCompilationCache::CompileMode::kStrict;
}
return GetXlaOpsCommonFlags().tf_xla_async_compilation
? XlaCompilationCache::CompileMode::kAsync
: XlaCompilationCache::CompileMode::kLazy;
}();
if (GetXlaOpsCommonFlags().tf_xla_always_defer_compilation ||
cannot_compile_cluster) {
@ -396,12 +398,12 @@ void XlaCompileOp::Compute(OpKernelContext* ctx) {
// unlocking them in XlaRun may lead to deadlocks.
Status status = CompileToLocalExecutable(
ctx, function_, has_ref_vars_, platform_info_, inputs, variable_infos,
constants_,
/*lazy=*/!must_compile_,
/*may_alias_resource_update=*/false, &client, &kernel, &executable);
constants_, compile_mode, /*may_alias_resource_update=*/false, &client,
&kernel, &executable);
OP_REQUIRES_OK(ctx, SnapshotResourceVariables(ctx, resources_,
variable_infos, &variables));
if (must_compile_ || status.code() != error::UNIMPLEMENTED) {
if (compile_mode != XlaCompilationCache::CompileMode::kLazy ||
status.code() != error::UNIMPLEMENTED) {
OP_REQUIRES_OK(ctx, status);
}
@ -423,6 +425,7 @@ void XlaCompileOp::Compute(OpKernelContext* ctx) {
host_alloc_attrs.set_on_host(true);
Allocator* cpu_allocator = ctx->device()->GetAllocator(host_alloc_attrs);
// Async compilation returns nullptr executable without an error.
if (!executable) {
DCHECK(!must_compile_);
Tensor compilation_key(cpu_allocator, DT_STRING, TensorShape({}));
@ -463,13 +466,11 @@ void XlaRunOp::Compute(OpKernelContext* ctx) {
XlaExecutableClosure closure =
XlaExecutableClosureStore::Global()->Consume(key);
absl::optional<se::TfAllocatorAdapter> tf_allocator_adapter;
se::DeviceMemoryAllocator* allocator = GetAllocator(
&tf_allocator_adapter, ctx->device(),
ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr,
platform_info_);
se::Stream* stream =
ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
std::shared_ptr<se::DeviceMemoryAllocator> allocator_ptr =
GetAllocator(ctx->device(), stream, platform_info_);
se::DeviceMemoryAllocator* allocator = allocator_ptr.get();
int device_ordinal = stream ? stream->parent()->device_ordinal()
: closure.client()->default_device_ordinal();
XlaComputationLaunchContext launch_context(

View File

@ -53,6 +53,9 @@ limitations under the License.
namespace tensorflow {
constexpr int64 XlaCompilationCache::kDefaultCompilationThreshold;
constexpr int64 XlaCompilationCache::AsyncCompilationState::kNumCompilerThreads;
constexpr int64
XlaCompilationCache::AsyncCompilationState::kMaxNumOngoingCompilations;
XlaCompilationCache::XlaCompilationCache(xla::LocalClient* client,
DeviceType device_type)
@ -68,6 +71,12 @@ XlaCompilationCache::~XlaCompilationCache() {
"programs to complete";
}
}
// Wait for all outstanding compilations to finish.
// Reset the pointer explicitly in the top-level destructor.
// Without this, the pointer would only be reset when the AsyncCompilationState
// member is destructed, and that depends on the declaration order of the
// members in the XlaCompilationCache class, which is error-prone if the order
// changes.
async_compilation_state_.compiler_threads.reset();
// TODO(b/110813685): Think about the program ownership model. Programs are
// currently owned by the compilation cache which means we must wait for
// program completion in the destructor. There are multiple compilation caches
@ -170,7 +179,7 @@ Status XlaCompilationCache::BuildExecutable(
? options.device_ordinal
: client_->default_device_ordinal());
build_options.set_result_layout(result.xla_output_shape);
build_options.set_device_allocator(options.device_allocator);
build_options.set_device_allocator(options.device_allocator.get());
build_options.set_alias_passthrough_params(options.alias_passthrough_params);
build_options.mutable_debug_options()->set_xla_detailed_logging(
options.detailed_logging);
@ -184,21 +193,22 @@ Status XlaCompilationCache::BuildExecutable(
Status XlaCompilationCache::Compile(
const XlaCompiler::Options& options, const NameAttrList& function,
absl::Span<const XlaCompiler::Argument> args,
const std::vector<XlaCompiler::Argument>& args,
const XlaCompiler::CompileOptions& compile_options,
CompileMode compile_mode,
const XlaCompiler::CompilationResult** out_compilation_result,
xla::LocalExecutable** out_executable) {
absl::optional<int64> compile_threshold;
if (compile_mode == CompileMode::kLazy) {
compile_threshold = kDefaultCompilationThreshold;
}
auto compile_fn = [&](XlaCompiler* compiler,
// !!Pay attention when additional variables must be captured by this
// lambda!! compile_fn can run asynchronously after this function has
// exited. Make sure that any variable needed inside compile_fn is
// either passed as an argument, or captured by value right here.
auto compile_fn = [compile_options, function](
XlaCompiler* compiler,
const std::vector<XlaCompiler::Argument>& args,
XlaCompiler::CompilationResult* result) {
return compiler->CompileFunction(compile_options, function, args, result);
};
return CompileImpl(options, function, args, compile_fn,
/*compile_threshold=*/compile_threshold,
return CompileImpl(options, function, args, compile_fn, compile_mode,
out_compilation_result, out_executable);
}
@ -261,7 +271,7 @@ static xla::StatusOr<std::unique_ptr<Graph>> CreateGraph(
Status XlaCompilationCache::CompileSingleOp(
const XlaCompiler::Options& options,
absl::Span<const XlaCompiler::Argument> args, OpKernelContext* ctx,
const std::vector<XlaCompiler::Argument>& args, OpKernelContext* ctx,
const XlaCompiler::CompileOptions& compile_options,
const XlaCompiler::CompilationResult** out_compilation_result,
xla::LocalExecutable** out_executable) {
@ -274,6 +284,7 @@ Status XlaCompilationCache::CompileSingleOp(
// and causes false uniqueness between nodes.
name.mutable_attr()->erase("_class");
auto compile_op = [&](XlaCompiler* compiler,
const std::vector<XlaCompiler::Argument>& args,
XlaCompiler::CompilationResult* result) {
std::vector<DataType> result_dtypes(ctx->num_outputs());
for (int i = 0, end = result_dtypes.size(); i < end; ++i) {
@ -308,8 +319,7 @@ Status XlaCompilationCache::CompileSingleOp(
options.device_type.type_string(), compile_options.use_tuple_arg,
*options.flib_def, debug_info, options.shape_representation_fn, result);
};
return CompileImpl(options, name, args, compile_op,
/*compile_threshold=*/absl::nullopt,
return CompileImpl(options, name, args, compile_op, CompileMode::kStrict,
out_compilation_result, out_executable);
}
@ -327,12 +337,113 @@ void LogOnceXlaCompiledFirstCluster() {
}
} // namespace
Status XlaCompilationCache::CompileStrict(
Entry* entry, const XlaCompiler::Options& options,
const std::vector<XlaCompiler::Argument>& args, const string& function_name,
const std::function<Status(XlaCompiler* compiler,
const std::vector<XlaCompiler::Argument>& args,
XlaCompiler::CompilationResult*)>& compile_fn) {
tensorflow::Env* env = tensorflow::Env::Default();
const uint64 compile_start_us = env->NowMicros();
XlaCompiler compiler(options);
entry->compile_state = CompileState::kCompiled;
entry->compilation_status =
compile_fn(&compiler, args, &entry->compilation_result);
TF_RETURN_IF_ERROR(entry->compilation_status);
TF_RET_CHECK(entry->executable.get() == nullptr);
entry->compilation_status =
BuildExecutable(options, entry->compilation_result, &entry->executable);
const uint64 compile_end_us = env->NowMicros();
const uint64 compile_time_us = compile_end_us - compile_start_us;
metrics::UpdateXlaCompilationTime(compile_time_us);
{
mutex_lock lock(cluster_compile_stats_mu_);
auto it = cluster_compile_stats_.find(function_name);
const double compile_time_s = compile_time_us / 1.0e6;
it->second.compile_count++;
it->second.cumulative_compile_time_us += compile_time_us;
LogOnceXlaCompiledFirstCluster();
VLOG(1) << "compiled " << function_name << " " << it->second.compile_count
<< " times, compile time: " << compile_time_us
<< " us, cumulative: " << it->second.cumulative_compile_time_us
<< " us ("
<< tensorflow::strings::HumanReadableElapsedTime(compile_time_s)
<< " / "
<< tensorflow::strings::HumanReadableElapsedTime(
it->second.cumulative_compile_time_us / 1.0e6)
<< ")";
XlaJitCompilationActivity jit_compilation_activity;
jit_compilation_activity.set_cluster_name(function_name);
jit_compilation_activity.set_compile_count(it->second.compile_count);
jit_compilation_activity.set_compile_time_us(compile_time_us);
jit_compilation_activity.set_cumulative_compile_time_us(
it->second.cumulative_compile_time_us);
TF_RETURN_IF_ERROR(
BroadcastXlaActivity(std::move(jit_compilation_activity)));
}
return Status::OK();
}
Status XlaCompilationCache::CompileAsynchronous(
Entry* entry, const XlaCompiler::Options& options,
const std::vector<XlaCompiler::Argument>& args, const string& function_name,
const std::function<Status(XlaCompiler* compiler,
const std::vector<XlaCompiler::Argument>& args,
XlaCompiler::CompilationResult*)>& compile_fn) {
// Explicitly capture all required data by value for async compilation.
entry->compile_state = CompileState::kCompiling;
{
mutex_lock lock(async_compilation_state_.async_compilation_state_mu);
async_compilation_state_.num_ongoing_compilations++;
}
// Don't move the above code into the thread function as it synchronously
// updates the async compilation state!
// When the ThreadPool for the compilation cache is destroyed, it waits for
// compilations to have finished. This means that both 'entry' and 'this' will
// be alive for the duration of the compilation.
// !!Pay attention when additional variables must be captured by this lambda!!
// All values are captured by value. Make sure that all pointer values (like
// entry) do not get freed until the lambda has finished.
async_compilation_state_.compiler_threads->Schedule([=] {
Entry local_entry;
VLOG(2) << "Starting asynchronous compilation of cluster " << function_name
<< '.';
// We don't need to lock local_entry.mu, but do it anyway to satisfy
// thread safety analysis.
mutex_lock entry_lock(local_entry.mu);
(void)CompileStrict(&local_entry, options, args, function_name, compile_fn);
VLOG(2) << "Finished asynchronous compililation of cluster "
<< function_name << '.';
{
mutex_lock lock(async_compilation_state_.async_compilation_state_mu);
async_compilation_state_.num_ongoing_compilations--;
}
{ // Populate original entry with compilation result.
mutex_lock entry_lock(entry->mu);
entry->compilation_result = local_entry.compilation_result;
entry->compile_state = local_entry.compile_state;
entry->compilation_status = local_entry.compilation_status;
entry->executable = std::move(local_entry.executable);
}
});
return Status::OK();
}
Status XlaCompilationCache::CompileImpl(
const XlaCompiler::Options& options, const NameAttrList& function,
absl::Span<const XlaCompiler::Argument> args,
const std::vector<XlaCompiler::Argument>& args,
const std::function<Status(XlaCompiler* compiler,
const std::vector<XlaCompiler::Argument>& args,
XlaCompiler::CompilationResult*)>& compile_fn,
absl::optional<int64> compile_threshold,
CompileMode compile_mode,
const XlaCompiler::CompilationResult** out_compilation_result,
xla::LocalExecutable** out_executable) {
if (FailOnXlaCompilation()) {
@ -348,9 +459,20 @@ Status XlaCompilationCache::CompileImpl(
VLOG(3) << i << ": " << args[i].HumanString();
}
}
absl::optional<int64> compile_threshold;
if (compile_mode == CompileMode::kLazy) {
compile_threshold = kDefaultCompilationThreshold;
} else if (compile_mode == CompileMode::kAsync) {
compile_threshold = 0; // for now, always compile right away.
}
TF_ASSIGN_OR_RETURN(Signature signature, BuildSignature(function, args));
VLOG(2) << "Signature: " << signature.HumanString();
string human_signature;
if (VLOG_IS_ON(2)) {
human_signature = VLOG_IS_ON(3) ? signature.HumanString() : function.name();
VLOG(2) << "Signature: " << human_signature;
}
// The outer lock protects the existence of the cache entry. It does not
// protect the contents of the cache entry.
@ -402,14 +524,18 @@ Status XlaCompilationCache::CompileImpl(
// cache eviction.
mutex_lock entry_lock(entry->mu);
int64 current_request_count = ++entry->request_count;
VLOG(2) << "Compilation cache entry hit: " << entry->compiled
<< " signature: " << signature.HumanString() << " with request count "
VLOG(2) << "Compilation cache entry hit: "
<< static_cast<int>(entry->compile_state)
<< " signature: " << human_signature << " with request count "
<< current_request_count << " and compile threshold "
<< compile_threshold.value_or(0);
if (!entry->compiled) {
// TODO(sanjoy): Refactor this code into helper functions.
bool return_null = false;
CompileState state = entry->compile_state;
if (state == CompileState::kUncompiled) {
XLA_SCOPED_LOGGING_TIMER("Compilation of XLA executable");
const bool should_compile = [&] {
if (!compile_threshold.has_value()) {
if (compile_mode == CompileMode::kStrict) {
// Lazy compilation is disabled.
return true;
}
@ -418,7 +544,7 @@ Status XlaCompilationCache::CompileImpl(
BroadcastOptimizationRemark(XlaOptimizationRemark::MEGAMORPHIC_FUNCTION,
function.name())
.IgnoreError();
VLOG(3) << "Not compiling cluster " << function.name()
VLOG(2) << "Not compiling cluster " << function.name()
<< " because it is megamorphic.";
return false;
}
@ -427,10 +553,21 @@ Status XlaCompilationCache::CompileImpl(
return true;
}
if (compile_mode == CompileMode::kAsync) {
// Asynchronous compilation is enabled.
mutex_lock lock(async_compilation_state_.async_compilation_state_mu);
if (async_compilation_state_.num_ongoing_compilations >=
async_compilation_state_.kMaxNumOngoingCompilations) {
VLOG(2) << "Not asynchronously compiling cluster " << function.name()
<< " because of too many ongoing compilations.";
return false;
}
}
bool reached_compile_threshold =
current_request_count >= *compile_threshold;
if (!reached_compile_threshold) {
VLOG(3)
VLOG(2)
<< "Not compiling cluster " << function.name()
<< " because it has not reached compile threshold; threshold is "
<< *compile_threshold << " execution count "
@ -440,62 +577,34 @@ Status XlaCompilationCache::CompileImpl(
}();
if (!should_compile) {
VLOG(2) << "Not compiling for signature: " << signature.HumanString();
*out_compilation_result = nullptr;
*out_executable = nullptr;
return Status::OK();
}
tensorflow::Env* env = tensorflow::Env::Default();
const uint64 compile_start_us = env->NowMicros();
// Do the actual JIT compilation without holding the lock (it can take
// a long time.)
XlaCompiler compiler(options);
entry->compiled = true;
entry->compilation_status =
compile_fn(&compiler, &entry->compilation_result);
TF_RETURN_IF_ERROR(entry->compilation_status);
CHECK_EQ(entry->executable.get(), nullptr);
entry->compilation_status =
BuildExecutable(options, entry->compilation_result, &entry->executable);
const uint64 compile_end_us = env->NowMicros();
const uint64 compile_time_us = compile_end_us - compile_start_us;
metrics::UpdateXlaCompilationTime(compile_time_us);
{
mutex_lock lock(cluster_compile_stats_mu_);
auto it = cluster_compile_stats_.find(function.name());
it->second.compile_count++;
it->second.cumulative_compile_time_us += compile_time_us;
LogOnceXlaCompiledFirstCluster();
VLOG(1) << "compiled " << function.name() << " "
<< it->second.compile_count
<< " times, compile time: " << compile_time_us
<< " us, cumulative: " << it->second.cumulative_compile_time_us
<< " us ("
<< tensorflow::strings::HumanReadableElapsedTime(compile_time_us /
1.0e6)
<< " / "
<< tensorflow::strings::HumanReadableElapsedTime(
it->second.cumulative_compile_time_us / 1.0e6)
<< ")";
XlaJitCompilationActivity jit_compilation_activity;
jit_compilation_activity.set_cluster_name(function.name());
jit_compilation_activity.set_compile_count(it->second.compile_count);
jit_compilation_activity.set_compile_time_us(compile_time_us);
jit_compilation_activity.set_cumulative_compile_time_us(
it->second.cumulative_compile_time_us);
VLOG(2) << "Not compiling for signature: " << human_signature;
return_null = true;
} else if (compile_mode == CompileMode::kAsync) {
VLOG(2) << "Queueing asynchronous compilation for signature: "
<< human_signature;
TF_RETURN_IF_ERROR(CompileAsynchronous(entry, options, args,
function.name(), compile_fn));
return_null = true;
} else {
VLOG(2) << "Instantly compiling for signature: " << human_signature;
TF_RETURN_IF_ERROR(
BroadcastXlaActivity(std::move(jit_compilation_activity)));
CompileStrict(entry, options, args, function.name(), compile_fn));
}
} else if (state == CompileState::kCompiling) {
VLOG(2) << "Ongoing asynchronous compilation for signature: "
<< human_signature;
return_null = true;
} else if (state == CompileState::kCompiled) {
VLOG(2) << "Already Compiled for signature: " << human_signature;
}
if (return_null) {
*out_compilation_result = nullptr;
*out_executable = nullptr;
} else {
TF_RETURN_IF_ERROR(entry->compilation_status);
*out_compilation_result = &entry->compilation_result;
*out_executable = entry->executable.get();
}
TF_RETURN_IF_ERROR(entry->compilation_status);
*out_compilation_result = &entry->compilation_result;
*out_executable = entry->executable.get();
return Status::OK();
}
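
The two "pay attention" comments above reduce to a general lifetime rule for work handed to a thread pool. Below is a standalone sketch, not TensorFlow code (plain std::thread stands in for tensorflow::thread::ThreadPool and all names are hypothetical), of why the scheduled lambda copies its inputs and why the worker pool is torn down explicitly before any state it might still touch.

#include <string>
#include <thread>
#include <utility>
#include <vector>

class MiniCache {
 public:
  // The lambda may run long after the caller's stack frame is gone, so it
  // captures `name` and `args` by value (moved copies), never by reference.
  void ScheduleCompile(std::string name, std::vector<int> args) {
    workers_.emplace_back(
        [this, name = std::move(name), args = std::move(args)] {
          // Long-running "compilation" that reads only copied state plus
          // members of *this that are guaranteed to outlive the workers.
          std::string target = device_name_ + "/" + name;
          (void)target;
          (void)args;
        });
  }

  ~MiniCache() {
    // Join the workers first, mirroring the explicit
    // async_compilation_state_.compiler_threads.reset() above. Members are
    // destroyed in reverse declaration order, so without this explicit join
    // a still-running lambda could observe already-destroyed members if the
    // declaration order ever changed.
    for (std::thread& t : workers_) t.join();
  }

 private:
  std::string device_name_ = "MY_DEVICE";
  std::vector<std::thread> workers_;
};

int main() {
  MiniCache cache;
  cache.ScheduleCompile("cluster_0", {1, 2, 3});
}  // ~MiniCache joins before any member is destroyed.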

View File

@ -50,6 +50,13 @@ class XlaCompilationCache : public ResourceBase {
enum class CompileMode {
kLazy,
kStrict,
kAsync,
};
enum class CompileState {
kUncompiled,
kCompiling,
kCompiled,
};
// Compiles a function into a XlaCompiler::CompilationResult that can be used
@ -62,7 +69,9 @@ class XlaCompilationCache : public ResourceBase {
// heuristics, the compilation cache may decide not to compile the cluster at
// this time. In this case it returns null into both `out_compilation_result`
// and `out_executable`. If `compile_mode` is `kStrict` then the compilation
// cache always attempts the compilation on a cache miss.
// cache always attempts the compilation on a cache miss. If `compile_mode` is
// `kAsync`, compilation of the cluster happens in the background while the
// fallback path executes.
//
// The result of compilation is written to `*out_compilation_result`, which
// must be non-null. If `out_executable` is non-null, also builds an
@ -71,7 +80,7 @@ class XlaCompilationCache : public ResourceBase {
// non-constant outputs.
Status Compile(const XlaCompiler::Options& options,
const NameAttrList& function,
absl::Span<const XlaCompiler::Argument> args,
const std::vector<XlaCompiler::Argument>& args,
const XlaCompiler::CompileOptions& compile_options,
CompileMode compile_mode,
const XlaCompiler::CompilationResult** out_compilation_result,
@ -83,7 +92,7 @@ class XlaCompilationCache : public ResourceBase {
// XlaCompiler, if possible.
Status CompileSingleOp(
const XlaCompiler::Options& options,
absl::Span<const XlaCompiler::Argument> args, OpKernelContext* ctx,
const std::vector<XlaCompiler::Argument>& args, OpKernelContext* ctx,
const XlaCompiler::CompileOptions& compile_options,
const XlaCompiler::CompilationResult** out_compilation_result,
xla::LocalExecutable** out_executable);
@ -126,10 +135,11 @@ class XlaCompilationCache : public ResourceBase {
// Common implementation of Compile and CompileSingleOp.
Status CompileImpl(
const XlaCompiler::Options& options, const NameAttrList& function,
absl::Span<const XlaCompiler::Argument> args,
const std::vector<XlaCompiler::Argument>& args,
const std::function<Status(XlaCompiler* compiler,
const std::vector<XlaCompiler::Argument>& args,
XlaCompiler::CompilationResult*)>& compile_fn,
absl::optional<int64> compile_threshold,
CompileMode compile_mode,
const XlaCompiler::CompilationResult** out_compilation_result,
xla::LocalExecutable** out_executable);
@ -146,8 +156,8 @@ class XlaCompilationCache : public ResourceBase {
struct Entry {
mutex mu;
// Have we tried compiling this entry?
bool compiled = false;
// The current compilation state for this entry.
CompileState compile_state = CompileState::kUncompiled;
// The number of times a compilation with this signature has been requested.
int64 request_count = 0;
@ -163,6 +173,22 @@ class XlaCompilationCache : public ResourceBase {
std::unique_ptr<xla::LocalExecutable> executable TF_GUARDED_BY(mu);
};
Status CompileStrict(
Entry* entry, const XlaCompiler::Options& options,
const std::vector<XlaCompiler::Argument>& args,
const string& function_name,
const std::function<Status(XlaCompiler* compiler,
const std::vector<XlaCompiler::Argument>& args,
XlaCompiler::CompilationResult*)>& compile_fn)
TF_EXCLUSIVE_LOCKS_REQUIRED(entry->mu);
Status CompileAsynchronous(
Entry* entry, const XlaCompiler::Options& options,
const std::vector<XlaCompiler::Argument>& args,
const string& function_name,
const std::function<Status(XlaCompiler* compiler,
const std::vector<XlaCompiler::Argument>& args,
XlaCompiler::CompilationResult*)>& compile_fn);
mutex compile_cache_mu_;
absl::flat_hash_map<Signature, std::unique_ptr<Entry>, Signature::Hash> cache_
TF_GUARDED_BY(compile_cache_mu_);
@ -189,6 +215,30 @@ class XlaCompilationCache : public ResourceBase {
absl::flat_hash_map<string, ClusterCompileStats> cluster_compile_stats_
TF_GUARDED_BY(cluster_compile_stats_mu_);
struct AsyncCompilationState {
mutex async_compilation_state_mu;
// Number of threads for asynchronous compilations.
static constexpr int64 kNumCompilerThreads = 10;
// Maximum number of ongoing compilations.
static constexpr int64 kMaxNumOngoingCompilations = kNumCompilerThreads;
// Number of ongoing compilations.
int64 num_ongoing_compilations TF_GUARDED_BY(async_compilation_state_mu) =
0;
// Pool of threads for asynchronous compilations.
std::unique_ptr<thread::ThreadPool> compiler_threads;
AsyncCompilationState() {
compiler_threads = absl::make_unique<tensorflow::thread::ThreadPool>(
tensorflow::Env::Default(), "async_compiler_threads",
kNumCompilerThreads);
}
} async_compilation_state_;
// The number of times a lazy compilation must be requested for a specific
// signature before we attempt to compile it.
static constexpr int64 kDefaultCompilationThreshold = 2;
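
Read together with the cache implementation earlier in this change, the three modes boil down to a small decision rule. The sketch below is a compact restatement for illustration only; it ignores the megamorphic check and the cap on ongoing asynchronous compilations that the real CompileImpl also applies.

#include <cstdint>

enum class CompileMode { kLazy, kStrict, kAsync };

// Should a cache miss with `request_count` prior requests be compiled now?
// The lazy threshold mirrors kDefaultCompilationThreshold (2) above.
bool ShouldCompileNow(CompileMode mode, int64_t request_count,
                      int64_t lazy_threshold = 2) {
  switch (mode) {
    case CompileMode::kStrict:
      return true;  // always compile on a miss
    case CompileMode::kAsync:
      return true;  // compile right away, but in the background thread pool
    case CompileMode::kLazy:
      return request_count >= lazy_threshold;  // wait until the cluster is hot
  }
  return false;
}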

View File

@ -48,11 +48,11 @@ Status XlaCompileOnDemandOp::Run(OpKernelContext* ctx,
const ResourceVarsSnapshot& variable_args) {
xla::LocalClient* client = static_cast<xla::LocalClient*>(cache->client());
absl::optional<se::TfAllocatorAdapter> tf_allocator_adapter;
se::DeviceMemoryAllocator* allocator = GetAllocator(
&tf_allocator_adapter, ctx->device(),
ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr,
platform_info_);
se::Stream* stream =
ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
std::shared_ptr<se::DeviceMemoryAllocator> allocator_ptr =
GetAllocator(ctx->device(), stream, platform_info_);
se::DeviceMemoryAllocator* allocator = allocator_ptr.get();
XlaComputationLaunchContext launch_context(
client, allocator, client->default_device_ordinal(),
/*allocate_xla_tensors=*/platform_info_.xla_device_metadata() != nullptr,
@ -74,9 +74,6 @@ Status XlaCompileOnDemandOp::Run(OpKernelContext* ctx,
input_output_alias);
TF_RETURN_IF_ERROR(execution_inputs.status());
se::Stream* stream =
ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
VLOG(2) << "Executing computation: " << name();
xla::ExecutableRunOptions run_options;
run_options.set_stream(stream);
@ -126,12 +123,10 @@ Status XlaCompileOnDemandOp::Compile(
write_into_cache);
}));
absl::optional<se::TfAllocatorAdapter> tf_allocator_adapter;
XlaCompiler::Options options = GenerateCompilerOptions(
**cache, *ctx->function_library(), ctx->device(),
ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr,
platform_info_,
/*has_ref_vars=*/true, &tf_allocator_adapter);
platform_info_, /*has_ref_vars=*/true);
// No detailed logging from on demand op.
options.detailed_logging = false;
XlaCompiler::CompileOptions compile_options;

View File

@ -214,14 +214,18 @@ XlaComputationLaunchContext::XlaComputationLaunchContext(
// Fills in `execution_input` with `buffer` for `index`.
static void PopulateExecutionInputBuffer(xla::ExecutionInput& execution_input,
xla::ShapeIndex index,
se::DeviceMemoryBase& buffer,
se::DeviceMemoryBase buffer,
bool donate_buffer, int device_ordinal,
se::DeviceMemoryAllocator* allocator) {
xla::MaybeOwningDeviceMemory* in_buffer =
execution_input.MutableBuffer(index);
if (donate_buffer) {
// Here we pass ownership of the buffer to execution_input without releasing
// ownership from the caller of PopulateExecutionInputBuffer. If execution
// succeeds, we'll take back that duplicate ownership in
// GetOrCreateTensorForOutput. If execution fails, the ExecutionInput will
// release that duplicate ownership automatically.
*in_buffer = se::OwningDeviceMemory(buffer, device_ordinal, allocator);
buffer = se::DeviceMemoryBase();
} else {
*in_buffer = buffer;
}
@ -308,18 +312,21 @@ static Tensor MakeTensor(DataType dtype, const TensorShape& shape,
return t;
}
// Get aliased tensor, or make a new one for the corresponding output operation.
static Tensor GetOrCreateTensorForOutput(
int output_num, OpKernelContext* ctx, int missing_ctx_input_prefix,
// Get aliased tensor from output, or make a new one for the corresponding
// output operation. Transfers ownership of the buffer from output to the
// returned tensor.
static xla::StatusOr<Tensor> GetOrCreateTensorForOutput(
xla::ScopedShapedBuffer& output, int output_num, OpKernelContext* ctx,
int missing_ctx_input_prefix,
const xla::HloInputOutputAliasConfig& input_output_alias,
absl::Span<const int> input_mapping,
const std::map<int, const Tensor*>& resource_vars_snapshots,
DataType output_dtype, const TensorShape& output_shape,
se::DeviceMemoryBase output_buffer, Allocator* output_allocator) {
Allocator* output_allocator, bool allocate_xla_tensors, se::Stream* stream,
bool use_multiple_streams, std::shared_ptr<se::Event> definition_event) {
xla::ShapeIndex output_index = input_output_alias.shape().IsTuple()
? xla::ShapeIndex({output_num})
: xla::ShapeIndex({});
CHECK(input_output_alias.shape().IsTuple() || output_num == 0);
if (absl::optional<xla::HloInputOutputAliasConfig::Alias> alias =
input_output_alias.GetAliasedParameter(output_index)) {
@ -330,24 +337,39 @@ static Tensor GetOrCreateTensorForOutput(
ctx->input(tf_param).dtype() != DT_RESOURCE
? ctx->input(tf_param)
: *resource_vars_snapshots.at(missing_ctx_input_prefix + tf_param);
if (output_buffer.opaque() == input_tensor.data()) {
se::DeviceMemoryBase input_buffer =
XlaTensor::DeviceMemoryFromTensor(input_tensor);
se::DeviceMemoryBase output_buffer = output.buffer({output_num});
if (input_buffer.opaque() == output_buffer.opaque()) {
// In the case of a donated buffer, both input_tensor and output think
// they have ownership of the buffer (see comment in
// PopulateExecutionInputBuffer). Release ownership from output to avoid
// double free.
output.set_buffer(se::OwningDeviceMemory(), {output_num});
return input_tensor;
}
}
return MakeTensor(output_dtype, output_shape, output_buffer,
output_allocator);
}
static void PopulateXlaTensor(Tensor* output_tensor,
xla::ScopedShapedBuffer* output, int output_num,
se::Stream* stream, bool use_multiple_streams,
std::shared_ptr<se::Event> definition_event) {
XlaTensor* xla_tensor = XlaTensor::FromTensor(output_tensor);
CHECK(xla_tensor);
xla_tensor->set_shaped_buffer(output->TakeSubTree({output_num}));
if (use_multiple_streams) {
xla_tensor->ResetDefinitionEvent(definition_event, stream);
if (allocate_xla_tensors) {
Tensor output_tensor;
TF_RETURN_IF_ERROR(
ctx->allocate_temp(output_dtype, output_shape, &output_tensor));
if (output_tensor.TotalBytes() > 0) {
XlaTensor* xla_tensor = XlaTensor::FromTensor(&output_tensor);
TF_RET_CHECK(xla_tensor);
xla_tensor->set_shaped_buffer(output.TakeSubTree({output_num}));
if (use_multiple_streams) {
xla_tensor->ResetDefinitionEvent(definition_event, stream);
}
}
return output_tensor;
}
se::DeviceMemoryBase output_buffer = output.buffer({output_num});
Tensor output_tensor =
MakeTensor(output_dtype, output_shape, output_buffer, output_allocator);
output.set_buffer(se::OwningDeviceMemory(), {output_num});
return output_tensor;
}
// Sets output `output_num` for `ctx` provided it is known at a compile time.
@ -525,22 +547,15 @@ Status XlaComputationLaunchContext::PopulateOutputs(
<< "Invalid input for outputs " << i << ": " << input_index;
ctx->set_output(i, ctx->input(input_index));
} else {
if (allocate_xla_tensors_) {
Tensor* output_tensor;
TF_RETURN_IF_ERROR(ctx->allocate_output(i, shape, &output_tensor));
if (output_tensor->TotalBytes() > 0) {
PopulateXlaTensor(output_tensor, &output, output_num, stream,
use_multiple_streams_, definition_event);
}
} else {
se::DeviceMemoryBase buffer = output.buffer({output_num});
Tensor output_tensor = GetOrCreateTensorForOutput(
output_num, ctx, missing_ctx_input_prefix, input_output_alias,
compilation_result->input_mapping, resource_vars,
ctx->expected_output_dtype(i), shape, buffer, allocator);
ctx->set_output(i, output_tensor);
}
output.set_buffer(se::OwningDeviceMemory(), {output_num});
TF_ASSIGN_OR_RETURN(
Tensor output_tensor,
GetOrCreateTensorForOutput(
output, output_num, ctx, missing_ctx_input_prefix,
input_output_alias, compilation_result->input_mapping,
resource_vars, ctx->expected_output_dtype(i), shape, allocator,
allocate_xla_tensors_, stream, use_multiple_streams_,
definition_event));
ctx->set_output(i, output_tensor);
++output_num;
}
}
@ -571,22 +586,14 @@ Status XlaComputationLaunchContext::PopulateOutputs(
return errors::Internal("Mismatched type in variable write");
}
Tensor output_tensor;
if (allocate_xla_tensors_) {
TF_RETURN_IF_ERROR(
ctx->allocate_temp(write.type, write.shape, &output_tensor));
if (write.shape.num_elements() > 0) {
PopulateXlaTensor(&output_tensor, &output, output_num, stream,
use_multiple_streams_, definition_event);
}
} else {
se::DeviceMemoryBase buffer = output.buffer({output_num});
output_tensor = GetOrCreateTensorForOutput(
output_num, ctx, missing_ctx_input_prefix, input_output_alias,
compilation_result->input_mapping, resource_vars, write.type,
write.shape, buffer, allocator);
}
output.set_buffer(se::OwningDeviceMemory(), {output_num});
TF_ASSIGN_OR_RETURN(
Tensor output_tensor,
GetOrCreateTensorForOutput(output, output_num, ctx,
missing_ctx_input_prefix, input_output_alias,
compilation_result->input_mapping,
resource_vars, write.type, write.shape,
allocator, allocate_xla_tensors_, stream,
use_multiple_streams_, definition_event));
var->is_initialized |= write.modified;
*var->tensor() = output_tensor;
++output_num;
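
The donation comments in this file describe a deliberate double-ownership handoff. The toy sketch below (an ad-hoc Owner type standing in for se::OwningDeviceMemory; nothing here is XLA API) shows the invariant: exactly one of the two owners must end up freeing the buffer, on both the success and the failure path.

#include <cassert>
#include <cstdlib>

// Toy stand-in for an owning device-memory handle.
struct Owner {
  void* ptr;
  void Release() {
    std::free(ptr);
    ptr = nullptr;
  }
};

int main() {
  void* buffer = std::malloc(64);

  Owner input{buffer};      // the input tensor's view of the donated buffer
  Owner execution{buffer};  // the ExecutionInput's duplicate ownership

  bool execution_succeeded = true;
  if (execution_succeeded) {
    // Mirrors GetOrCreateTensorForOutput: the output aliases the input, so
    // the execution side drops its duplicate ownership (set_buffer with an
    // empty OwningDeviceMemory above) and the input, reused as the output
    // tensor, keeps and eventually frees the buffer.
    execution.ptr = nullptr;
    input.Release();
  } else {
    // Mirrors the failure path: the ExecutionInput cleans up the donated
    // buffer, so the input side must not free it again.
    input.ptr = nullptr;
    execution.Release();
  }
  assert(input.ptr == nullptr && execution.ptr == nullptr);
  return 0;
}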

View File

@ -79,7 +79,7 @@ XlaPlatformInfo XlaPlatformInfoFromDevice(DeviceBase* device_base) {
auto device = static_cast<Device*>(device_base);
se::Platform::Id platform_id = nullptr;
const XlaDevice::Metadata* xla_device_metadata = nullptr;
se::DeviceMemoryAllocator* custom_allocator = nullptr;
std::shared_ptr<se::DeviceMemoryAllocator> custom_allocator;
if (device->device_type() == DEVICE_CPU) {
platform_id = se::host::kHostPlatformId;
@ -101,37 +101,35 @@ XlaPlatformInfo XlaPlatformInfoFromDevice(DeviceBase* device_base) {
// allocator to allocate real buffers.
platform_id = xla_device_metadata->platform()->id();
custom_allocator =
xla_device_metadata->client()->backend().memory_allocator();
xla_device_metadata->client()->backend().shared_memory_allocator();
}
return XlaPlatformInfo(DeviceType(device->device_type()), platform_id,
xla_device_metadata, custom_allocator);
}
se::DeviceMemoryAllocator* GetAllocator(
absl::optional<se::TfAllocatorAdapter>* tf_allocator_adapter,
std::shared_ptr<se::DeviceMemoryAllocator> GetAllocator(
DeviceBase* device, se::Stream* stream,
const XlaPlatformInfo& platform_info) {
if (platform_info.custom_allocator()) {
return platform_info.custom_allocator();
}
auto* alloc = device->GetAllocator({});
if (!stream) {
// Stream is not set for the host platform.
se::Platform* platform =
se::MultiPlatformManager::PlatformWithId(platform_info.platform_id())
.ValueOrDie();
tf_allocator_adapter->emplace(device->GetAllocator({}), platform);
return &tf_allocator_adapter->value();
return std::make_shared<se::TfAllocatorAdapter>(alloc, platform);
}
tf_allocator_adapter->emplace(device->GetAllocator({}), stream);
return &tf_allocator_adapter->value();
return std::make_shared<se::TfAllocatorAdapter>(alloc, stream);
}
XlaCompiler::Options GenerateCompilerOptions(
const XlaCompilationCache& cache,
const FunctionLibraryRuntime& function_library, DeviceBase* device,
se::Stream* stream, const XlaPlatformInfo& platform_info, bool has_ref_vars,
absl::optional<se::TfAllocatorAdapter>* tf_allocator_adapter) {
se::Stream* stream, const XlaPlatformInfo& platform_info,
bool has_ref_vars) {
XlaCompiler::Options options;
options.client = static_cast<xla::LocalClient*>(cache.client());
if (stream != nullptr) {
@ -142,8 +140,7 @@ XlaCompiler::Options GenerateCompilerOptions(
options.graph_def_version = function_library.graph_def_version();
options.allow_cpu_custom_calls =
(platform_info.platform_id() == se::host::kHostPlatformId);
options.device_allocator =
GetAllocator(tf_allocator_adapter, device, stream, platform_info);
options.device_allocator = GetAllocator(device, stream, platform_info);
if (platform_info.xla_device_metadata()) {
options.shape_representation_fn =
platform_info.xla_device_metadata()->shape_representation_fn();

View File

@ -29,10 +29,10 @@ class XlaPlatformInfo {
public:
XlaPlatformInfo() : device_type_("") {}
XlaPlatformInfo(XlaPlatformInfo&&) = default;
explicit XlaPlatformInfo(const DeviceType device_type,
se::Platform::Id platform_id,
const XlaDevice::Metadata* xla_device_metadata,
se::DeviceMemoryAllocator* device_allocator)
explicit XlaPlatformInfo(
const DeviceType device_type, se::Platform::Id platform_id,
const XlaDevice::Metadata* xla_device_metadata,
std::shared_ptr<se::DeviceMemoryAllocator> device_allocator)
: device_type_(device_type),
platform_id_(platform_id),
xla_device_metadata_(xla_device_metadata),
@ -45,7 +45,7 @@ class XlaPlatformInfo {
}
// Non-null only when run on an XLA device.
se::DeviceMemoryAllocator* custom_allocator() const {
std::shared_ptr<se::DeviceMemoryAllocator> custom_allocator() const {
return device_allocator_;
}
@ -74,7 +74,9 @@ class XlaPlatformInfo {
// If the op associated with this XlaPlatformInfo is placed on an XLA device
// then device_allocator_ is the xla::Backend's memory allocator. If the op
// is placed on a regular CPU or GPU device then device_allocator_ is null.
se::DeviceMemoryAllocator* device_allocator_;
// The allocator is of unknown provenance; keep it in a shared pointer to
// set an artificial refcount of one.
std::shared_ptr<se::DeviceMemoryAllocator> device_allocator_;
TF_DISALLOW_COPY_AND_ASSIGN(XlaPlatformInfo);
};
@ -94,8 +96,7 @@ XlaPlatformInfo XlaPlatformInfoFromDevice(DeviceBase* device);
// dummy tensors.
//
// `stream` parameter is nullable when running on host.
se::DeviceMemoryAllocator* GetAllocator(
absl::optional<se::TfAllocatorAdapter>* tf_allocator_adapter,
std::shared_ptr<se::DeviceMemoryAllocator> GetAllocator(
DeviceBase* device, se::Stream* stream,
const XlaPlatformInfo& platform_info);
@ -104,8 +105,8 @@ se::DeviceMemoryAllocator* GetAllocator(
XlaCompiler::Options GenerateCompilerOptions(
const XlaCompilationCache& cache,
const FunctionLibraryRuntime& function_library, DeviceBase* device,
se::Stream* stream, const XlaPlatformInfo& platform_info, bool has_ref_vars,
absl::optional<se::TfAllocatorAdapter>* tf_allocator_adapter);
se::Stream* stream, const XlaPlatformInfo& platform_info,
bool has_ref_vars);
} // namespace tensorflow
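
The switch from a raw pointer plus a caller-owned adapter to a std::shared_ptr return value is purely a lifetime simplification. Below is a standalone sketch of the before/after shape; Allocator and Adapter are placeholder types, not the se:: classes.

#include <memory>

struct Allocator {
  virtual ~Allocator() = default;
};
// Placeholder for an adapter that wraps a device allocator for one stream.
struct Adapter : Allocator {};

// Old shape (simplified): the caller had to keep `storage` alive for as long
// as it used the returned raw pointer.
Allocator* GetAllocatorOld(std::unique_ptr<Adapter>* storage) {
  *storage = std::make_unique<Adapter>();
  return storage->get();
}

// New shape: the returned handle owns (or shares) the adapter, so a holder
// such as XlaCompiler::Options::device_allocator simply keeps the shared_ptr
// and the adapter lives exactly as long as someone still needs it.
std::shared_ptr<Allocator> GetAllocatorNew() {
  return std::make_shared<Adapter>();
}

int main() {
  std::unique_ptr<Adapter> storage;
  Allocator* old_style = GetAllocatorOld(&storage);  // valid only while storage lives
  std::shared_ptr<Allocator> new_style = GetAllocatorNew();
  (void)old_style;
  (void)new_style;
}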

View File

@ -51,6 +51,7 @@ td_library(
compatible_with = get_compatible_with_cloud(),
includes = ["include"],
deps = [
"@llvm-project//mlir:MemRefOpsTdFiles",
"@llvm-project//mlir:OpBaseTdFiles",
"@llvm-project//mlir:SideEffectTdFiles",
],
@ -445,6 +446,7 @@ cc_library(
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:CopyOpInterface",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:MemRefDialect",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:SideEffects",
"@llvm-project//mlir:StandardOps",
@ -598,6 +600,7 @@ cc_library(
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:LinalgOps",
"@llvm-project//mlir:MemRefDialect",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:SCFDialect",
"@llvm-project//mlir:StandardOps",
@ -707,6 +710,7 @@ cc_library(
"@llvm-project//mlir:Affine",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:LinalgTransforms",
"@llvm-project//mlir:MemRefDialect",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:SCFDialect",
"@llvm-project//mlir:StandardOps",
@ -939,6 +943,7 @@ cc_library(
":hlo",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:MemRefDialect",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:TensorDialect",
"@llvm-project//mlir:Transforms",

View File

@ -119,31 +119,30 @@ def HLO_CreateTokenOp : HLO_Op<"create_token", [NoSideEffect]> {
// See https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions
class HLO_UnaryElementwiseOp<string mnemonic, list<OpTrait> traits,
Type TensorType>: HLO_Op<mnemonic,
!listconcat(traits,
[InferShapedTypeOpInterface, InferFusibilityOpInterface,
SameOperandsAndResultShape])> {
let arguments = (ins TensorType:$operand);
let results = (outs TensorType);
let extraClassDeclaration = [{
static LogicalResult inferReturnTypeComponents(
MLIRContext* context, Optional<Location> location,
ValueRange operands, DictionaryAttr attributes, RegionRange regions,
SmallVectorImpl<ShapedTypeComponents>& inferedReturnShapes) {
return failure();
}
LogicalResult reifyReturnTypeShapes(
OpBuilder& builder, SmallVectorImpl<Value>& reifiedReturnShapes) {
return ::mlir::mhlo::deriveShapeFromFirstOperand(&builder, getOperation(),
&reifiedReturnShapes);
}
bool inferInputOutputShapeEquality(int input, int output) {
return true;
}
llvm::Optional<Value> inferEffectiveWorkloadShape() {
return getOperation()->getResult(0);
}
}];
Type TensorType> : HLO_Op<mnemonic, traits # [Elementwise,
InferShapedTypeOpInterface, InferFusibilityOpInterface,
SameOperandsAndResultShape]> {
let arguments = (ins TensorType:$operand);
let results = (outs TensorType);
let extraClassDeclaration = [{
static LogicalResult inferReturnTypeComponents(
MLIRContext* context, Optional<Location> location,
ValueRange operands, DictionaryAttr attributes, RegionRange regions,
SmallVectorImpl<ShapedTypeComponents>& inferedReturnShapes) {
return failure();
}
LogicalResult reifyReturnTypeShapes(
OpBuilder& builder, SmallVectorImpl<Value>& reifiedReturnShapes) {
return ::mlir::mhlo::deriveShapeFromFirstOperand(&builder, getOperation(),
&reifiedReturnShapes);
}
bool inferInputOutputShapeEquality(int input, int output) {
return true;
}
llvm::Optional<Value> inferEffectiveWorkloadShape() {
return getOperation()->getResult(0);
}
}];
}
// Abs supports complex to real, so element type is not guaranteed to match.
@ -826,8 +825,9 @@ def HLO_BroadcastInDimOp : HLO_Op<"broadcast_in_dim",
let hasCustomHLOConverter = 1;
}
def HLO_DynamicBroadcastInDimOp : HLO_Op<"dynamic_broadcast_in_dim",
[NoSideEffect]> {
def HLO_DynamicBroadcastInDimOp : HLO_Op<"dynamic_broadcast_in_dim", [
NoSideEffect, DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
["reifyReturnTypeShapes"]>]> {
string summary = "Broadcast a tensor into the given dynamic shape by adding dimensions.";
string description = [{
This is a generalization of the BroadcastInDimOp which accepts its output

View File

@ -21,6 +21,7 @@ limitations under the License.
#include "llvm/ADT/StringRef.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.h"
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops_structs.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinTypes.h"

View File

@ -33,6 +33,7 @@ limitations under the License.
#ifndef LHLO_OPS
#define LHLO_OPS
include "mlir/Dialect/MemRef/IR/MemRefBase.td"
include "mlir/IR/OpBase.td"
include "mlir/Interfaces/CopyOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
@ -685,7 +686,7 @@ def FusionOp : LHLO_Op<"fusion", [SingleBlockImplicitTerminator<"TerminatorOp">]
let extraClassDeclaration = [{
SmallVector<Value, 4> getInputBuffers() {
SmallVector<Value, 4> buffers;
this->region().walk([&](TensorLoadOp load) {
this->region().walk([&](memref::TensorLoadOp load) {
if (load.memref().getParentRegion()->isProperAncestor(&region()))
buffers.push_back(load.memref());
});
@ -694,7 +695,7 @@ def FusionOp : LHLO_Op<"fusion", [SingleBlockImplicitTerminator<"TerminatorOp">]
SmallVector<Value, 4> getOutputBuffers() {
SmallVector<Value, 4> buffers;
this->region().walk([&](TensorStoreOp store) {
this->region().walk([&](memref::TensorStoreOp store) {
if (store.memref().getParentRegion()->isProperAncestor(&region()))
buffers.push_back(store.memref());
});
@ -703,7 +704,7 @@ def FusionOp : LHLO_Op<"fusion", [SingleBlockImplicitTerminator<"TerminatorOp">]
SmallVector<Value, 4> getFusionParameters() {
SmallVector<Value, 4> buffers;
this->region().walk([&](TensorLoadOp load) {
this->region().walk([&](memref::TensorLoadOp load) {
if (load.memref().getParentRegion()->isProperAncestor(&region()))
buffers.push_back(load);
});
@ -712,7 +713,7 @@ def FusionOp : LHLO_Op<"fusion", [SingleBlockImplicitTerminator<"TerminatorOp">]
SmallVector<Value, 4> getFusionResults() {
SmallVector<Value, 4> buffers;
this->region().walk([&](TensorStoreOp store) {
this->region().walk([&](memref::TensorStoreOp store) {
if (store.memref().getParentRegion()->isProperAncestor(&region()))
buffers.push_back(store.tensor());
});

View File

@ -16,6 +16,7 @@ limitations under the License.
#ifndef LHLO_OPS_BASE
#define LHLO_OPS_BASE
include "mlir/Dialect/MemRef/IR/MemRefBase.td"
include "mlir/IR/OpBase.td"
include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td"

View File

@ -597,12 +597,18 @@ OpFoldResult GetTupleElementOp::fold(ArrayRef<Attribute> operands) {
//===----------------------------------------------------------------------===//
static LogicalResult Verify(TupleOp op) {
SmallVector<Type, 8> operandTypes = {op.operand_type_begin(),
op.operand_type_end()};
auto expectedType = TupleType::get(op.getContext(), operandTypes);
if (op.getType() != expectedType) {
return op.emitOpError(llvm::formatv("has return type {0}, but expected {1}",
op.getType(), expectedType));
auto opType = op.getType().dyn_cast<TupleType>();
if (!opType) return op.emitOpError("tuple op with non-tuple result");
if (op.getNumOperands() != opType.size())
return op.emitOpError(
"number of operands to tuple expected to match number of types in "
"resultant tuple type");
for (auto it : llvm::enumerate(
llvm::zip_first(op.getOperandTypes(), opType.getTypes()))) {
if (std::get<0>(it.value()) != std::get<1>(it.value()))
return op.emitOpError("has return type mismatch at ")
<< it.index() << "th value (" << std::get<0>(it.value())
<< " != " << std::get<1>(it.value()) << ")";
}
return success();
}
@ -902,6 +908,18 @@ void DynamicBroadcastInDimOp::getCanonicalizationPatterns(
context);
}
LogicalResult DynamicBroadcastInDimOp::inferReturnTypeComponents(
MLIRContext*, llvm::Optional<mlir::Location>, ValueRange, DictionaryAttr,
RegionRange, llvm::SmallVectorImpl<mlir::ShapedTypeComponents>&) {
return failure();
}
LogicalResult DynamicBroadcastInDimOp::reifyReturnTypeShapes(
OpBuilder&, SmallVectorImpl<Value>& reifiedReturnShapes) {
reifiedReturnShapes.push_back(output_dimensions());
return success();
}
//===----------------------------------------------------------------------===//
// ClampOp
//===----------------------------------------------------------------------===//

View File

@ -36,12 +36,13 @@ def DynamicBroadcastToOwnShape_4 : Pat<
(HLO_DynamicBroadcastInDimOp:$op $x, (Tensor_CastOp (Shape_ShapeOfOp $x)), $attr),
(Tensor_CastOp $x)>;
def ShapeOfDynamicReshape : Pat<
(Shape_ShapeOfOp (HLO_DynamicReshapeOp $x, $shape)),
(replaceWithValue $shape)>;
def HasSameType : Constraint<CPred<"$0.getType() == $1.getType()">>;
def ShapeOfDynamicReshape : Pat<
(Shape_ShapeOfOp:$op (HLO_DynamicReshapeOp $x, $shape)),
(replaceWithValue $shape),
[(HasSameType $shape, $op)]>;
def IdentityBroadcastReshape : Pat<
(HLO_ReshapeOp:$op (HLO_BroadcastOp $input, $dims)),
(replaceWithValue $input),

View File

@ -33,6 +33,7 @@ limitations under the License.
#include "llvm/Support/FormatVariadic.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_common.h"
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h.inc"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
@ -156,13 +157,13 @@ struct EraseConstOp : public OpRewritePattern<ConstOp> {
LogicalResult matchAndRewrite(ConstOp op,
PatternRewriter& rewriter) const override {
Value memref = op.output();
if (!memref.getDefiningOp<AllocOp>()) {
if (!memref.getDefiningOp<memref::AllocOp>()) {
return failure();
}
// Check that all uses of the memref are either DeallocOps or this op.
for (Operation* user : memref.getUsers())
if (user != op && !isa<DeallocOp>(user)) return failure();
if (user != op && !isa<memref::DeallocOp>(user)) return failure();
rewriter.eraseOp(op);
return success();

View File

@ -71,7 +71,7 @@ Value InsertDynamicAllocAndDealloc(Location loc, Value result,
dynamic_operands.push_back(alloc_operand);
}
return rewriter->create<AllocOp>(loc, memref_type, dynamic_operands);
return rewriter->create<memref::AllocOp>(loc, memref_type, dynamic_operands);
}
Value InsertAlloc(Location loc, OpResult result,
@ -85,7 +85,7 @@ Value InsertAlloc(Location loc, OpResult result,
MemRefType::get(result_type.getShape(), result_type.getElementType());
OpBuilder::InsertionGuard guard(*rewriter);
rewriter->setInsertionPoint(result.getDefiningOp());
auto alloc = rewriter->create<AllocOp>(loc, memref_type);
auto alloc = rewriter->create<memref::AllocOp>(loc, memref_type);
return alloc;
}
@ -207,7 +207,7 @@ class HloToLhloReshapeUnrankedConverter
if (unranked_operand_type == nullptr) return failure();
auto result_type = op.getType().cast<RankedTensorType>();
rewriter.replaceOpWithNewOp<MemRefCastOp>(
rewriter.replaceOpWithNewOp<memref::CastOp>(
op, adaptor.operand(),
MemRefType::get(result_type.getShape(), result_type.getElementType()));
return success();
@ -235,7 +235,7 @@ class HloToLhloDynamicReshapeConverter
return failure();
}
mhlo::DynamicReshapeOp::Adaptor adaptor(operands);
rewriter.replaceOpWithNewOp<MemRefReshapeOp>(
rewriter.replaceOpWithNewOp<memref::ReshapeOp>(
op, result_type, adaptor.operand(), adaptor.output_shape());
return success();
}
@ -273,7 +273,7 @@ class HloToLhloDynamicBroadcastInDimOpConverter
// Inserts a dynamic memref cast that changes the memref layout: wherever a
// size-1 dimension needs to be expanded, the stride is set to 0 and the size
// to that of the target dimension.
MemRefReinterpretCastOp InsertDynamicMemrefCastOp(
memref::ReinterpretCastOp InsertDynamicMemrefCastOp(
mhlo::DynamicBroadcastInDimOp op, Value operand, OpBuilder* b) const {
auto loc = op.getLoc();
auto operand_type = operand.getType().cast<MemRefType>();
@ -295,7 +295,7 @@ class HloToLhloDynamicBroadcastInDimOpConverter
for (int i = operand_rank - 1; i >= 0; --i) {
Value operand_dim_size =
ShapedType::isDynamic(operand_shape[i])
? b->create<DimOp>(loc, operand, i).getResult()
? b->create<memref::DimOp>(loc, operand, i).getResult()
: b->create<ConstantIndexOp>(loc, operand_shape[i]).getResult();
operand_sizes[i] = operand_dim_size;
@ -355,7 +355,7 @@ class HloToLhloDynamicBroadcastInDimOpConverter
makeStridedLinearLayoutMap(dynamic_layout,
/*offset=*/0, b->getContext()));
auto transformed_operand = b->create<MemRefReinterpretCastOp>(
auto transformed_operand = b->create<memref::ReinterpretCastOp>(
loc, type_erased_memref_type, operand,
/*offset=*/b->getI64IntegerAttr(0), sizes, strides);
return transformed_operand;
@ -484,12 +484,12 @@ struct HloToLhloReturnOpConverter : public BaseOpConversion<mhlo::ReturnOp> {
// TODO(b/175789537) Remove this pattern.
class HloToLhloTensorStoreOpLegacyConverter
: public BaseOpConversion<mlir::TensorStoreOp> {
: public BaseOpConversion<mlir::memref::TensorStoreOp> {
public:
using BaseOpConversion<mlir::TensorStoreOp>::BaseOpConversion;
using BaseOpConversion<mlir::memref::TensorStoreOp>::BaseOpConversion;
LogicalResult matchAndRewrite(
mlir::TensorStoreOp op, ArrayRef<Value> operands,
mlir::memref::TensorStoreOp op, ArrayRef<Value> operands,
ConversionPatternRewriter& rewriter) const final {
rewriter.replaceOpWithNewOp<lmhlo::CopyOp>(op, llvm::None, operands.front(),
operands.back());
@ -577,14 +577,16 @@ struct HloLegalizeToLhlo
ConversionTarget target(context);
target.addLegalDialect<lmhlo::LmhloDialect>();
target.addLegalDialect<StandardOpsDialect>();
target.addLegalDialect<memref::MemRefDialect>();
target.addLegalDialect<shape::ShapeDialect>();
target.addLegalDialect<tensor::TensorDialect>();
target.addIllegalDialect<mhlo::MhloDialect>();
// Declare tensor_load and tensor_store illegal.
target.addIllegalOp<mlir::TensorLoadOp, mlir::TensorStoreOp>();
// tensor_to_memref is illegal if it has uses.
// TODO(b/175670649) Make tensor_to_memref illegal.
target.addDynamicallyLegalOp<mlir::TensorToMemrefOp>(
target.addIllegalOp<mlir::memref::TensorLoadOp,
mlir::memref::TensorStoreOp>();
// buffer_cast is illegal if it has uses.
// TODO(b/175670649) Make buffer_cast illegal.
target.addDynamicallyLegalOp<mlir::memref::BufferCastOp>(
[](auto op) { return op->use_empty(); });
BufferizeTypeConverter converter;
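One related detail, shown here only as a hedged sketch (the real pass may declare this through its tablegen pass definition, which is not part of this hunk): a pass that now creates memref ops usually also has to register the MemRef dialect as a dependent dialect so it is loaded before the conversion runs.

#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Pass/Pass.h"

namespace {
struct HloLegalizeToLhloSketch
    : public mlir::PassWrapper<HloLegalizeToLhloSketch, mlir::FunctionPass> {
  void getDependentDialects(mlir::DialectRegistry &registry) const override {
    // Without this, building memref.alloc / memref.reinterpret_cast in the
    // patterns above would trip over an unloaded dialect.
    registry.insert<mlir::memref::MemRefDialect>();
  }
  void runOnFunction() override { /* conversion as in the hunks above */ }
};
}  // namespace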


@ -108,7 +108,7 @@ SmallVector<Value, 2> ExtractDynamicSizes(OpBuilder& b, Location loc,
dyn_sizes.push_back(
b.create<IndexCastOp>(loc, b.getIndexType(), extract));
} else {
dyn_sizes.push_back(b.create<DimOp>(loc, tensor, en.index()));
dyn_sizes.push_back(b.create<memref::DimOp>(loc, tensor, en.index()));
}
}
return dyn_sizes;
@ -324,13 +324,13 @@ class ScalarPointwiseToStandardConverter : public OpConversionPattern<LhloOp> {
}
// Create two loads from the input.
auto lhs = rewriter.create<LoadOp>(loc, lhlo_op.lhs());
auto rhs = rewriter.create<LoadOp>(loc, lhlo_op.rhs());
auto lhs = rewriter.create<memref::LoadOp>(loc, lhlo_op.lhs());
auto rhs = rewriter.create<memref::LoadOp>(loc, lhlo_op.rhs());
// TODO(ravishankarm) : Move this method out of lmhlo namespace.
Value op_result = lmhlo::HloOpToStdScalarOp::map<LhloOp>(
lhlo_op, arg_type.getElementType(), llvm::ArrayRef<Value>{lhs, rhs},
&rewriter);
rewriter.create<StoreOp>(loc, op_result, lhlo_op.out());
rewriter.create<memref::StoreOp>(loc, op_result, lhlo_op.out());
rewriter.eraseOp(lhlo_op);
return success();
}
@ -518,15 +518,20 @@ class HloDynamicBroadcastInDimConverter
LogicalResult matchAndRewrite(
mhlo::DynamicBroadcastInDimOp op, ArrayRef<Value> operands,
ConversionPatternRewriter& rewriter) const final {
// Convert only if the producer is an HLO constant. Ideally the pattern
// (`mhlo.constant` -> `mhlo.dynamic_broadcast_in_dim`) should be converted
// to an Tensor-dialect op similar to TF ConstantLikeOp.
if (!op.operand().getDefiningOp<mhlo::ConstOp>()) return failure();
// If the input has a static shape we know exactly when the broadcast must
// expand (the dimension is 1, which also trivially expands to 1) or will
// never expand (the dimension is not 1). This means we can lower the
// broadcast just as we would lower a fully static broadcast and go directly
// to linalg.generic. This also covers the important case of broadcasting a
// scalar.
// Ideally the pattern (`mhlo.constant` -> `mhlo.dynamic_broadcast_in_dim`)
// should be converted to a Tensor-dialect op similar to TF ConstantLikeOp.
mhlo::DynamicBroadcastInDimOp::Adaptor adaptor(op);
Value operand = adaptor.operand();
auto operand_type = operand.getType().dyn_cast<RankedTensorType>();
if (!operand_type || operand_type.getRank() != 0) return failure();
if (!operand_type || !operand_type.hasStaticShape()) return failure();
Value shape = adaptor.output_dimensions();
auto shape_type = shape.getType().cast<RankedTensorType>();
@ -544,13 +549,27 @@ class HloDynamicBroadcastInDimConverter
}
int64_t nloops = result_type.getRank();
auto operand_shape = operand_type.getShape();
SmallVector<AffineExpr, 4> dim_exprs;
dim_exprs.reserve(nloops);
if (op.broadcast_dimensions()) {
for (const auto& broadcast_dim :
enumerate(op.broadcast_dimensions().getIntValues())) {
int64_t size = broadcast_dim.value().getSExtValue();
bool expansion_needed = operand_shape[broadcast_dim.index()] == 1;
dim_exprs.push_back(expansion_needed ? rewriter.getAffineConstantExpr(0)
: rewriter.getAffineDimExpr(size));
}
}
Value init = rewriter.create<linalg::InitTensorOp>(
loc, dyn_dims, result_type.getShape(), result_type.getElementType());
Operation* generic = rewriter.create<linalg::GenericOp>(
loc, TypeRange{init.getType()}, ValueRange{operand},
/*outputBuffers=*/ValueRange{init},
llvm::makeArrayRef(
{AffineMap::get(/*dimCount=*/nloops, /*symbolCount=*/0, {},
{AffineMap::get(/*dimCount=*/nloops, /*symbolCount=*/0, dim_exprs,
rewriter.getContext()),
rewriter.getMultiDimIdentityMap(nloops)}),
GetNParallelLoopsAttrs(nloops),
@ -590,8 +609,8 @@ class LhloBroadcastInDimConverter
operand_type.getDimSize(0) <
result_type.getDimSize(broadcast_dims.front())) {
Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
Value val =
rewriter.create<LoadOp>(loc, operand, llvm::makeArrayRef({zero}));
Value val = rewriter.create<memref::LoadOp>(loc, operand,
llvm::makeArrayRef({zero}));
rewriter.create<linalg::GenericOp>(
loc, /*inputs=*/ValueRange{},
/*outputBuffers=*/ValueRange{operand_adaptor.output()},
@ -971,7 +990,8 @@ class ReduceConverter : public OpConversionPattern<lmhlo::ReduceOp> {
}
// First fill the output buffer with the init value.
Value init_value = rewriter.create<LoadOp>(loc, adaptor.init_values()[0]);
Value init_value =
rewriter.create<memref::LoadOp>(loc, adaptor.init_values()[0]);
rewriter.create<linalg::FillOp>(loc, adaptor.out()[0], init_value);
DenseIntElementsAttr dimensions_attr = reduce_op.dimensions();
@ -1011,9 +1031,9 @@ class ReduceConverter : public OpConversionPattern<lmhlo::ReduceOp> {
// expects scalar SSA values. Add some allocs around the original op to
// make it compatible.
auto arg_type = block->getArgument(0).getType().cast<MemRefType>();
Value alloc_a = rewriter.create<AllocaOp>(loc, arg_type);
Value alloc_b = rewriter.create<AllocaOp>(loc, arg_type);
Value alloc_res = rewriter.create<AllocaOp>(loc, arg_type);
Value alloc_a = rewriter.create<memref::AllocaOp>(loc, arg_type);
Value alloc_b = rewriter.create<memref::AllocaOp>(loc, arg_type);
Value alloc_res = rewriter.create<memref::AllocaOp>(loc, arg_type);
// Now turn the existing signature
// (memref<X>, memref<X>, memref<X>) -> ()
@ -1030,13 +1050,15 @@ class ReduceConverter : public OpConversionPattern<lmhlo::ReduceOp> {
// Store the arguments into the newly allocated buffers.
rewriter.setInsertionPointAfter(alloc_res.getDefiningOp());
rewriter.create<StoreOp>(loc, entry_block->getArgument(0), alloc_a);
rewriter.create<StoreOp>(loc, entry_block->getArgument(1), alloc_b);
rewriter.create<memref::StoreOp>(loc, entry_block->getArgument(0),
alloc_a);
rewriter.create<memref::StoreOp>(loc, entry_block->getArgument(1),
alloc_b);
rewriter.replaceOp(entry_block->getTerminator(), {});
// Load & yield the result.
rewriter.setInsertionPointToEnd(entry_block);
auto load_res = rewriter.create<LoadOp>(loc, alloc_res);
auto load_res = rewriter.create<memref::LoadOp>(loc, alloc_res);
rewriter.create<linalg::YieldOp>(loc, ValueRange{load_res});
}
@ -1099,8 +1121,8 @@ class SliceConverter : public OpConversionPattern<OpTy> {
slice_op.strides().template getValue<int64_t>(i)));
}
if (isLHLO) {
auto linalg_op =
rewriter.create<SubViewOp>(loc, args[0], offsets, sizes, strides);
auto linalg_op = rewriter.create<memref::SubViewOp>(loc, args[0], offsets,
sizes, strides);
rewriter.create<linalg::CopyOp>(loc, linalg_op, args[1]);
rewriter.eraseOp(slice_op);
} else {
@ -1149,14 +1171,14 @@ SmallVector<Value, 2> GetDotOpInitTensorDynSizes(OpBuilder& b, Location loc,
switch (type) {
case DotOperationType::kMatrixMatrix: {
if (lhs.getType().cast<ShapedType>().isDynamicDim(0))
dyn_shape.push_back(b.create<DimOp>(loc, lhs, 0));
dyn_shape.push_back(b.create<memref::DimOp>(loc, lhs, 0));
if (rhs.getType().cast<ShapedType>().isDynamicDim(1))
dyn_shape.push_back(b.create<DimOp>(loc, rhs, 1));
dyn_shape.push_back(b.create<memref::DimOp>(loc, rhs, 1));
break;
}
case DotOperationType::kMatrixVector: {
if (lhs.getType().cast<ShapedType>().isDynamicDim(0))
dyn_shape.push_back(b.create<DimOp>(loc, lhs, 0));
dyn_shape.push_back(b.create<memref::DimOp>(loc, lhs, 0));
break;
}
case DotOperationType::kVectorDot:
@ -1203,11 +1225,11 @@ SmallVector<Value, 8> GetDotGeneralOpInitTensorDynSizes(
OpBuilder& b, Location loc, Value lhs, Value rhs, ShapedType result_type) {
SmallVector<Value, 8> dyn_shape;
if (result_type.isDynamicDim(0))
dyn_shape.push_back(b.create<DimOp>(loc, lhs, 0));
dyn_shape.push_back(b.create<memref::DimOp>(loc, lhs, 0));
if (result_type.isDynamicDim(1))
dyn_shape.push_back(b.create<DimOp>(loc, lhs, 1));
dyn_shape.push_back(b.create<memref::DimOp>(loc, lhs, 1));
if (result_type.isDynamicDim(2))
dyn_shape.push_back(b.create<DimOp>(loc, rhs, 2));
dyn_shape.push_back(b.create<memref::DimOp>(loc, rhs, 2));
return dyn_shape;
}
@ -1307,7 +1329,7 @@ SmallVector<Value, 8> GetReduceOpInitTensorDynSizes(
for (int i = 0, j = 0; i < rank; ++i) {
if (s.count(i)) continue;
if (!result_type.isDynamicDim(j++)) continue;
dyn_shape.push_back(b.create<DimOp>(loc, arg, i));
dyn_shape.push_back(b.create<memref::DimOp>(loc, arg, i));
}
return dyn_shape;
@ -1467,7 +1489,7 @@ struct NormalConvOpOnTensorsConversion
// The output shape is N spatial_dims F.
SmallVector<Value, 8> dyn_sizes;
if (result_type.isDynamicDim(0)) {
dyn_sizes.push_back(rewriter.create<DimOp>(loc, input, 0));
dyn_sizes.push_back(rewriter.create<memref::DimOp>(loc, input, 0));
}
for (int64_t i = 1, e = rank - 1; i < e; ++i) {
if (result_type.isDynamicDim(i)) {
@ -1476,7 +1498,8 @@ struct NormalConvOpOnTensorsConversion
}
}
if (result_type.isDynamicDim(rank - 1)) {
dyn_sizes.push_back(rewriter.create<DimOp>(loc, filter, rank - 1));
dyn_sizes.push_back(
rewriter.create<memref::DimOp>(loc, filter, rank - 1));
}
Value init_tensor = rewriter.create<linalg::InitTensorOp>(
loc, dyn_sizes, result_type.getShape(), result_type.getElementType());
@ -1769,6 +1792,99 @@ struct ReduceWindowOpOnTensorsConversion
}
};
/// Converts xla-hlo.torch_index_select op to a linalg.indexed_generic op.
struct TorchIndexSelectOpOnTensorsConversion
: public OpConversionPattern<mhlo::TorchIndexSelectOp> {
using OpConversionPattern<mhlo::TorchIndexSelectOp>::OpConversionPattern;
LogicalResult matchAndRewrite(
mhlo::TorchIndexSelectOp op, ArrayRef<Value> args,
ConversionPatternRewriter& rewriter) const final {
mhlo::TorchIndexSelectOp::Adaptor adaptor(args);
int axis = static_cast<int>(op.dim());
int batch = static_cast<int>(op.batch_dims());
auto index_shaped_type = adaptor.index().getType().cast<ShapedType>();
int num_indices = static_cast<int>(index_shaped_type.getRank());
auto input_shaped_type = adaptor.input().getType().cast<ShapedType>();
if (axis < 0) axis += static_cast<int>(input_shaped_type.getRank());
if (batch < 0) batch += num_indices;
Location loc = op.getLoc();
auto result_type = op.getResult().getType().cast<ShapedType>();
int rank = static_cast<int>(result_type.getRank());
SmallVector<AffineMap, 2> indexing_maps;
SmallVector<AffineExpr, 4> exprs;
for (int i = 0; i < batch; ++i) {
exprs.push_back(rewriter.getAffineDimExpr(i));
}
for (int i = 0, e = num_indices - batch; i < e; ++i) {
exprs.push_back(rewriter.getAffineDimExpr(axis + i));
}
indexing_maps.emplace_back(
AffineMap::get(rank, /*symbolCount=*/0, exprs, rewriter.getContext()));
indexing_maps.emplace_back(rewriter.getMultiDimIdentityMap(rank));
// The output shape is
// `params[:axis] + indices[batch_dims:] + params[axis + 1:]`
SmallVector<Value, 4> dyn_sizes;
for (int i = 0; i < rank; ++i) {
if (!result_type.isDynamicDim(i)) continue;
if (i < axis) {
dyn_sizes.push_back(
rewriter.create<memref::DimOp>(loc, adaptor.input(), i));
} else if (i < (axis + num_indices - batch)) {
int idx = i - axis + batch;
dyn_sizes.push_back(
rewriter.create<memref::DimOp>(loc, adaptor.index(), idx));
} else {
int idx = i - (axis + num_indices - batch) + axis + 1;
dyn_sizes.push_back(
rewriter.create<memref::DimOp>(loc, adaptor.input(), idx));
}
}
Value init_op = rewriter.create<linalg::InitTensorOp>(
loc, dyn_sizes, result_type.getShape(), result_type.getElementType());
auto linalg_op = rewriter.create<linalg::IndexedGenericOp>(
loc, /*resultTensors=*/ArrayRef<Type>{result_type},
/*inputs=*/adaptor.index(),
/*outputs=*/init_op, indexing_maps, GetNParallelLoopsAttrs(rank));
SmallVector<Type, 4> body_arg_types;
SmallVector<Value, 2> linalg_op_args = {adaptor.index()};
// Add a block to the region.
auto* region = &linalg_op.region();
auto* block = rewriter.createBlock(region, region->end());
body_arg_types.append(rank, rewriter.getIndexType());
for (auto block_args : linalg_op_args) {
body_arg_types.push_back(
block_args.getType().cast<ShapedType>().getElementType());
}
block->addArguments(body_arg_types);
block->addArguments(result_type.getElementType());
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToEnd(block);
SmallVector<Value, 4> indices;
Value casted_value = rewriter.create<IndexCastOp>(
loc, block->getArgument(rank), rewriter.getIndexType());
for (int i = 0; i < axis; ++i) {
indices.push_back(block->getArgument(i));
}
indices.push_back(casted_value);
for (int i = axis + num_indices - batch; i < rank; ++i) {
indices.push_back(block->getArgument(i));
}
Value res =
rewriter.create<tensor::ExtractOp>(loc, adaptor.input(), indices);
rewriter.create<linalg::YieldOp>(loc, res);
rewriter.replaceOp(op, linalg_op.getResults());
return success();
}
};
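To make the shape formula in the comment above concrete: with params of shape 4x7x8x2, indices of shape 4x1, dim = 2 and batch_dims = 1, the result is params[:2] + indices[1:] + params[3:] = 4x7x1x2, which is exactly what the torch_index_select tests further down expect. A tiny standalone helper (invented for illustration, plain C++) that states the same rule for static shapes:

#include <cstdint>
#include <vector>

std::vector<int64_t> TorchIndexSelectShape(const std::vector<int64_t>& params,
                                           const std::vector<int64_t>& indices,
                                           int64_t axis, int64_t batch_dims) {
  // params[:axis] + indices[batch_dims:] + params[axis + 1:]
  std::vector<int64_t> out(params.begin(), params.begin() + axis);
  out.insert(out.end(), indices.begin() + batch_dims, indices.end());
  out.insert(out.end(), params.begin() + axis + 1, params.end());
  return out;
}
// TorchIndexSelectShape({4, 7, 8, 2}, {4, 1}, /*axis=*/2, /*batch_dims=*/1)
// returns {4, 7, 1, 2}.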
void populateLHLOToLinalgConversionPattern(MLIRContext* context,
OwningRewritePatternList* patterns) {
// clang-format off
@ -1856,8 +1972,8 @@ struct LhloLegalizeToLinalgPass
OwningRewritePatternList patterns;
ConversionTarget target(getContext());
target.addLegalDialect<complex::ComplexDialect, linalg::LinalgDialect,
math::MathDialect, StandardOpsDialect,
AffineDialect>();
math::MathDialect, memref::MemRefDialect,
StandardOpsDialect, AffineDialect>();
auto func = getFunction();
populateLHLOToLinalgConversionPattern(func.getContext(), &patterns);
@ -1881,6 +1997,9 @@ struct HloLegalizeToLinalgPass
math::MathDialect, StandardOpsDialect,
tensor::TensorDialect, scf::SCFDialect>();
// TODO: DimOp shouldn't be in MemRefDialect
target.addLegalOp<memref::DimOp>();
auto func = getFunction();
mhlo::populateHLOToLinalgConversionPattern(func.getContext(), &patterns);
if (failed(applyPartialConversion(func, target, std::move(patterns)))) {
@ -1961,6 +2080,7 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context,
DepthwiseConvOpOnTensorsConversion,
ReduceOnTensorsConversion,
ReduceWindowOpOnTensorsConversion,
TorchIndexSelectOpOnTensorsConversion,
PadOpOnTensorsConversion>(context);
// clang-format on
patterns->insert<ReduceRegionXLAOpConversion<mhlo::AddOp>,

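As a short aside on the broadcast lowering above: for a statically shaped operand, the per-dimension choice that HloDynamicBroadcastInDimConverter now bakes into the operand's indexing map boils down to the rule sketched below (operandDimExpr is an invented name; the pass builds the expressions inline).

#include "mlir/IR/Builders.h"

// A size-1 operand dimension is known to expand, so its index is pinned to
// the constant 0; any other dimension simply reads the result dimension it is
// broadcast into (its broadcast_dimensions entry).
mlir::AffineExpr operandDimExpr(mlir::Builder& b, int64_t operand_dim_size,
                                int64_t result_dim) {
  return operand_dim_size == 1 ? b.getAffineConstantExpr(0)
                               : b.getAffineDimExpr(result_dim);
}
// E.g. a tensor<42xf32> operand with broadcast_dimensions = [1] and a rank-3
// result gets the operand map (d0, d1, d2) -> (d1), matching the
// hlo-legalize-to-linalg test added below.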

@ -74,6 +74,9 @@ class ApproximateOnExtendedF32Lowering : public OpRewritePattern<OpTy> {
}
};
// This approximation resembles Eigen's implementation and, on top of it, uses
// a constant approximation for the +/-1 limits.
// https://gitlab.com/libeigen/eigen/-/blob/master/Eigen/src/Core/MathFunctionsImpl.h
class ApproximateTanhLowering
: public ApproximateOnExtendedF32Lowering<math::TanhOp> {
public:
@ -83,42 +86,18 @@ class ApproximateTanhLowering
// Emits the fast tanh approximation that is also used by XLA.
Value emitApproximation(ValueRange args, Location loc,
PatternRewriter &rewriter) const override {
// For small values of x, we can approximate tanh(x) = x. For extremely
// small values of x (|x| < 1e-37), the other approximation would evaluate
// tanh(x) = 0.
Value input = args.front();
assert(input.getType().isF32());
constexpr float kCanUseApprox = 0.0004;
Value abs_value = rewriter.create<AbsFOp>(loc, input);
Value can_use_approx = rewriter.create<ConstantOp>(
loc, rewriter.getF32FloatAttr(kCanUseApprox));
Value return_input = rewriter.create<CmpFOp>(loc, CmpFPredicate::OLT,
abs_value, can_use_approx);
// Clamp the input to [-c, c].
Value max_clamp = rewriter.create<ConstantOp>(
loc, rewriter.getF32FloatAttr(7.90531110763549805f));
Value smaller_than_max =
rewriter.create<CmpFOp>(loc, CmpFPredicate::ULE, input, max_clamp);
Value clamped_half =
rewriter.create<SelectOp>(loc, smaller_than_max, input, max_clamp);
Value min_clamp = rewriter.create<ConstantOp>(
loc, rewriter.getF32FloatAttr(-7.90531110763549805f));
Value larger_than_min = rewriter.create<CmpFOp>(loc, CmpFPredicate::UGE,
clamped_half, min_clamp);
Value input_clamped = rewriter.create<SelectOp>(loc, larger_than_min,
clamped_half, min_clamp);
static constexpr std::array<float, 7> numerator_coeffs{
-2.76076847742355e-16f, 2.00018790482477e-13f, -8.60467152213735e-11f,
5.12229709037114e-08f, 1.48572235717979e-05f, 6.37261928875436e-04f,
4.89352455891786e-03f};
static constexpr std::array<float, 4> denominator_coeffs{
1.19825839466702e-06f, 1.18534705686654e-04f, 2.26843463243900e-03f,
4.89352518554385e-03f};
Value input_squared =
rewriter.create<MulFOp>(loc, input_clamped, input_clamped);
// Materialize polynomial approximation.
Value input_squared = rewriter.create<MulFOp>(loc, input, input);
Value numerator = rewriter.create<ConstantOp>(
loc, rewriter.getF32FloatAttr(numerator_coeffs[0]));
for (int i = 1; i < numerator_coeffs.size(); i++) {
@ -127,9 +106,7 @@ class ApproximateTanhLowering
rewriter.create<ConstantOp>(
loc, rewriter.getF32FloatAttr(numerator_coeffs[i])));
}
numerator = rewriter.create<MulFOp>(loc, input_clamped, numerator);
numerator = rewriter.create<MulFOp>(loc, input, numerator);
Value denominator = rewriter.create<ConstantOp>(
loc, rewriter.getF32FloatAttr(denominator_coeffs[0]));
for (int i = 1; i < denominator_coeffs.size(); i++) {
@ -138,10 +115,38 @@ class ApproximateTanhLowering
rewriter.create<ConstantOp>(
loc, rewriter.getF32FloatAttr(denominator_coeffs[i])));
}
Value approx = rewriter.create<DivFOp>(loc, numerator, denominator);
return rewriter.create<SelectOp>(loc, return_input, input, approx);
// For small values of |x|, we can approximate tanh(x) = x. For extremely
// small values of x (|x| < 1e-37), the other approximation would evaluate
// tanh(x) = 0.
constexpr float kUseIdentityApprox = 0.0004;
Value abs_input = rewriter.create<AbsFOp>(loc, input);
Value use_identity_approx = rewriter.create<CmpFOp>(
loc, CmpFPredicate::OLT, abs_input,
rewriter.create<ConstantOp>(
loc, rewriter.getF32FloatAttr(kUseIdentityApprox)));
approx = rewriter.create<SelectOp>(loc, use_identity_approx, input, approx);
// For very small/large values, use a constant approximation -1/1.
Value too_large_input = rewriter.create<CmpFOp>(
loc, CmpFPredicate::UGT, input,
rewriter.create<ConstantOp>(
loc, rewriter.getF32FloatAttr(7.90531110763549805f)));
Value too_small_input = rewriter.create<CmpFOp>(
loc, CmpFPredicate::ULT, input,
rewriter.create<ConstantOp>(
loc, rewriter.getF32FloatAttr(-7.90531110763549805f)));
approx = rewriter.create<SelectOp>(
loc, too_large_input,
rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(1.0)),
approx);
approx = rewriter.create<SelectOp>(
loc, too_small_input,
rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(-1.0)),
approx);
return approx;
}
};
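For reference, a plain C++ restatement of the approximation after this change (illustrative only; the pass emits the equivalent std/math ops rather than calling such a function), using the same coefficients and thresholds as above:

#include <array>
#include <cmath>

float ApproxTanhF32(float x) {
  static constexpr std::array<float, 7> kNum = {
      -2.76076847742355e-16f, 2.00018790482477e-13f, -8.60467152213735e-11f,
      5.12229709037114e-08f,  1.48572235717979e-05f, 6.37261928875436e-04f,
      4.89352455891786e-03f};
  static constexpr std::array<float, 4> kDen = {
      1.19825839466702e-06f, 1.18534705686654e-04f, 2.26843463243900e-03f,
      4.89352518554385e-03f};
  float x2 = x * x;
  // Rational approximation p(x^2) * x / q(x^2), evaluated in Horner form.
  float num = kNum[0];
  for (int i = 1; i < 7; ++i) num = num * x2 + kNum[i];
  num *= x;
  float den = kDen[0];
  for (int i = 1; i < 4; ++i) den = den * x2 + kDen[i];
  float approx = num / den;
  // Small inputs: tanh(x) ~= x (the polynomial would otherwise evaluate to 0).
  if (std::fabs(x) < 0.0004f) approx = x;
  // Outside the fitted range, saturate to the +/-1 limits. (NaN handling
  // differs slightly from the unordered compares the pass emits.)
  if (x > 7.90531110763549805f) approx = 1.0f;
  if (x < -7.90531110763549805f) approx = -1.0f;
  return approx;
}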


@ -22,6 +22,7 @@ limitations under the License.
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
@ -95,7 +96,7 @@ class LhloFuseLinalgPass
continue;
}
if (auto tensor_load = dyn_cast<TensorLoadOp>(definingOp)) {
if (auto tensor_load = dyn_cast<memref::TensorLoadOp>(definingOp)) {
auto alias = tensor_load.memref();
if (result_buffers.insert(alias).second) {
worklist.push_back(alias);
@ -103,7 +104,7 @@ class LhloFuseLinalgPass
continue;
}
if (auto tensor_to_memref = dyn_cast<TensorToMemrefOp>(definingOp)) {
if (auto tensor_to_memref = dyn_cast<memref::BufferCastOp>(definingOp)) {
auto alias = tensor_to_memref.tensor();
if (result_buffers.insert(alias).second) {
worklist.push_back(alias);


@ -96,9 +96,10 @@ class LhloReduceToGPULaunchConverter : public OpConversionPattern<ReduceOp> {
// Load the initial value and store it to the output.
for (auto pair : llvm::zip(reduce_op.init_values(), reduce_op.out())) {
auto init_value = rewriter.create<mlir::LoadOp>(loc, std::get<0>(pair));
rewriter.create<mlir::StoreOp>(loc, init_value, std::get<1>(pair),
ArrayRef<Value>{index});
auto init_value =
rewriter.create<mlir::memref::LoadOp>(loc, std::get<0>(pair));
rewriter.create<mlir::memref::StoreOp>(
loc, init_value, std::get<1>(pair), ArrayRef<Value>{index});
}
// Insert a loop into the body to compute the reduction. The loop ranges
@ -128,8 +129,8 @@ class LhloReduceToGPULaunchConverter : public OpConversionPattern<ReduceOp> {
auto oneAttr = rewriter.getI64IntegerAttr(1);
OpFoldResult size = oneAttr;
OpFoldResult stride = oneAttr;
auto accumulator = rewriter.create<SubViewOp>(loc, resType, output,
offset, size, stride);
auto accumulator = rewriter.create<memref::SubViewOp>(
loc, resType, output, offset, size, stride);
llvm::SmallVector<Value, 4> indexings;
auto input_buffer = *reduce_op.operands().begin();
auto input_type_rank =
@ -143,8 +144,8 @@ class LhloReduceToGPULaunchConverter : public OpConversionPattern<ReduceOp> {
}));
SmallVector<OpFoldResult> sizes(input_type_rank, oneAttr);
SmallVector<OpFoldResult> strides(input_type_rank, oneAttr);
auto rhs = rewriter.create<SubViewOp>(loc, accumulator.getType(), input,
offsets, sizes, strides);
auto rhs = rewriter.create<memref::SubViewOp>(
loc, accumulator.getType(), input, offsets, sizes, strides);
// Now copy over the actual body of the reduction, leaving out the
// terminator.
@ -179,8 +180,9 @@ struct LhloLegalizeToGpuPass
void runOnFunction() override {
OwningRewritePatternList patterns;
ConversionTarget target(getContext());
target.addLegalDialect<linalg::LinalgDialect, StandardOpsDialect,
gpu::GPUDialect, scf::SCFDialect, LmhloDialect>();
target.addLegalDialect<linalg::LinalgDialect, memref::MemRefDialect,
StandardOpsDialect, gpu::GPUDialect, scf::SCFDialect,
LmhloDialect>();
target.addIllegalOp<ReduceOp>();
auto func = getFunction();
patterns.insert<LhloReduceToGPULaunchConverter>(func.getContext());


@ -18,6 +18,7 @@ limitations under the License.
#include "llvm/ADT/SmallVector.h"
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/BuiltinTypes.h"
@ -43,10 +44,11 @@ Value ApplySingleResultLhloCode(Location loc, ValueRange operands,
Block* lhlo_block, OpBuilder* b) {
SmallVector<Value, 2> arg_bufs;
for (auto arg_type : lhlo_block->getArgumentTypes()) {
arg_bufs.push_back(b->create<AllocOp>(loc, arg_type.cast<MemRefType>()));
arg_bufs.push_back(
b->create<memref::AllocOp>(loc, arg_type.cast<MemRefType>()));
}
for (auto operand : llvm::enumerate(operands)) {
b->create<StoreOp>(loc, operand.value(), arg_bufs[operand.index()]);
b->create<memref::StoreOp>(loc, operand.value(), arg_bufs[operand.index()]);
}
// Clone the ops from `lhlo_block`.
BlockAndValueMapping mapping;
@ -55,7 +57,7 @@ Value ApplySingleResultLhloCode(Location loc, ValueRange operands,
auto clone = b->clone(nested, mapping);
mapping.map(nested.getResults(), clone->getResults());
}
return b->create<LoadOp>(loc, arg_bufs.back());
return b->create<memref::LoadOp>(loc, arg_bufs.back());
}
// Converts a block with LHLO ops and with signature:
@ -78,7 +80,8 @@ void ConvertToReductionOperator(Location loc, scf::ReduceOp reduce_op,
Value GetStaticOrDynamicDim(mlir::Location loc, Value shaped_value,
size_t dim_index, int64_t dim, OpBuilder* b) {
return dim == ShapedType::kDynamicSize
? b->create<DimOp>(loc, shaped_value, dim_index).getResult()
? b->create<memref::DimOp>(loc, shaped_value, dim_index)
.getResult()
: b->create<ConstantIndexOp>(loc, dim);
}
@ -249,8 +252,8 @@ class ReduceOpConverter : public OpConversionPattern<lmhlo::ReduceOp> {
(is_reducing_dim ? reduce_step : parallel_step).push_back(step);
}
// Load initial value from memref<element_type>.
SmallVector<Value, 1> init_value = {
rewriter->create<LoadOp>(loc, *reduce_op.init_values().begin())};
SmallVector<Value, 1> init_value = {rewriter->create<memref::LoadOp>(
loc, *reduce_op.init_values().begin())};
// Outer ParallelOp is not needed if it is a reduction across all dims.
scf::ParallelOp outer;
if (!parallel_lower.empty()) {
@ -272,7 +275,7 @@ class ReduceOpConverter : public OpConversionPattern<lmhlo::ReduceOp> {
out_indices.push_back(rewriter->create<ConstantIndexOp>(loc, 0));
}
rewriter->create<StoreOp>(loc, reduction_result, out, out_indices);
rewriter->create<memref::StoreOp>(loc, reduction_result, out, out_indices);
// Load the element to reduce.
SmallVector<Value, 2> indices;
@ -290,7 +293,7 @@ class ReduceOpConverter : public OpConversionPattern<lmhlo::ReduceOp> {
}
rewriter->setInsertionPointToStart(inner.getBody());
Value elem = rewriter->create<mlir::LoadOp>(
Value elem = rewriter->create<mlir::memref::LoadOp>(
loc, *reduce_op.operands().begin(), indices);
return rewriter->create<scf::ReduceOp>(loc, elem);
}
@ -385,7 +388,7 @@ class ReduceWindowOpConverter
ConversionPatternRewriter* rewriter) const {
auto loc = reduce_window_op.getLoc();
Value init_value =
rewriter->create<LoadOp>(loc, reduce_window_op.init_value());
rewriter->create<memref::LoadOp>(loc, reduce_window_op.init_value());
Value zero = rewriter->create<ConstantIndexOp>(loc, 0);
Value one = rewriter->create<ConstantIndexOp>(loc, 1);
@ -408,7 +411,8 @@ class ReduceWindowOpConverter
Value reduction_result = *window_loop.getResults().begin();
auto output_ivs = output_loop.getInductionVars();
rewriter->create<StoreOp>(loc, reduction_result, output, output_ivs);
rewriter->create<memref::StoreOp>(loc, reduction_result, output,
output_ivs);
return std::make_pair(output_loop, window_loop);
}
@ -439,7 +443,7 @@ class ReduceWindowOpConverter
OpBuilder then_builder =
elem_or_init.getThenBodyBuilder(rewriter->getListener());
Value elem = then_builder.create<mlir::LoadOp>(
Value elem = then_builder.create<mlir::memref::LoadOp>(
loc, reduce_window_op.operand(), mapped_ivs.ivs);
then_builder.create<scf::YieldOp>(loc, elem);
@ -497,8 +501,8 @@ class SelectAndScatterOpConverter
auto selected_ivs = SelectIvs(s_and_s_op, loop_over_src, &rewriter);
// Load `source[selected_ivs]`.
auto src_elem = rewriter.create<LoadOp>(loc, s_and_s_op.source(),
loop_over_src.getInductionVars());
auto src_elem = rewriter.create<memref::LoadOp>(
loc, s_and_s_op.source(), loop_over_src.getInductionVars());
// Compute `out[selected_ivs] = scatter(out[selected_ivs], src_element)`.
auto rmw = rewriter.create<GenericAtomicRMWOp>(loc, s_and_s_op.out(),
@ -517,14 +521,14 @@ class SelectAndScatterOpConverter
void InitializeOutput(lmhlo::SelectAndScatterOp s_and_s_op,
OpBuilder* b) const {
auto loc = s_and_s_op.getLoc();
Value init_value = b->create<LoadOp>(loc, s_and_s_op.init_value());
Value init_value = b->create<memref::LoadOp>(loc, s_and_s_op.init_value());
scf::ParallelOp loop_over_output =
MakeLoopOverShape(loc, s_and_s_op.out(), b);
OpBuilder::InsertionGuard guard(*b);
b->setInsertionPointToStart(loop_over_output.getBody());
b->create<StoreOp>(loc, init_value, s_and_s_op.out(),
loop_over_output.getInductionVars());
b->create<memref::StoreOp>(loc, init_value, s_and_s_op.out(),
loop_over_output.getInductionVars());
}
struct WindowLoops {
@ -647,7 +651,7 @@ class SelectAndScatterOpConverter
TypeRange iter_arg_types{ivs_val_flag->to_vector()};
Value operand_elem =
b->create<LoadOp>(loc, s_and_s_op.operand(), operand_ivs);
b->create<memref::LoadOp>(loc, s_and_s_op.operand(), operand_ivs);
auto if_init =
b->create<scf::IfOp>(loc, iter_arg_types, ivs_val_flag->is_init(),
/*withElseRegion=*/true);
@ -712,8 +716,8 @@ struct LhloLegalizeToParallelLoopsPass
// clang-format on
ConversionTarget target(getContext());
target.addLegalDialect<linalg::LinalgDialect, StandardOpsDialect,
scf::SCFDialect, LmhloDialect>();
target.addLegalDialect<linalg::LinalgDialect, memref::MemRefDialect,
StandardOpsDialect, scf::SCFDialect, LmhloDialect>();
target.addIllegalOp<lmhlo::ReduceOp, lmhlo::ReduceWindowOp,
lmhlo::SelectAndScatterOp>();


@ -15,6 +15,7 @@ limitations under the License.
#include "llvm/ADT/SmallVector.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Attributes.h"
@ -58,7 +59,8 @@ Value CalculateShapeValue(Location loc, Value operand,
int64_t rank = result_type.getRank();
shape_values.reserve(rank);
for (int64_t i = 0; i < rank; ++i) {
shape_values.push_back(rewriter.create<mlir::DimOp>(loc, operand, i));
shape_values.push_back(
rewriter.create<mlir::memref::DimOp>(loc, operand, i));
}
return rewriter.create<tensor::FromElementsOp>(loc, shape_values);
}


@ -967,10 +967,10 @@ func @unpack_repack_same_tuple_single_element(%arg0: tuple<tensor<i32>>) -> tupl
// CHECK-LABEL: func @erase_dead_lhlo_constant
func @erase_dead_lhlo_constant() {
%M = alloc() : memref<256x1024xf32>
%M = memref.alloc() : memref<256x1024xf32>
// CHECK-NEXT: return
"lmhlo.constant"(%M) {value = dense<0.0> : tensor<f32>} : (memref<256x1024xf32>) -> ()
dealloc %M : memref<256x1024xf32>
memref.dealloc %M : memref<256x1024xf32>
return
}
@ -979,9 +979,9 @@ func @erase_dead_lhlo_constant() {
func @erase_dead_lhlo_constant_negative(%M : memref<4xf32>) -> memref<256x1024xf32> {
// CHECK-NEXT: lmhlo.constant
"lmhlo.constant"(%M) {value = dense<0.0> : tensor<f32>} : (memref<4xf32>) -> ()
// CHECK-NEXT: alloc
// CHECK-NEXT: memref.alloc
// CHECK-NEXT: lmhlo.constant
%N = alloc() : memref<256x1024xf32>
%N = memref.alloc() : memref<256x1024xf32>
"lmhlo.constant"(%N) {value = dense<0.0> : tensor<f32>} : (memref<256x1024xf32>) -> ()
return %N : memref<256x1024xf32>
}


@ -17,7 +17,7 @@ func @dynamic_reshape_from_unranked(
return %reshaped : tensor<?xf32>
}
// CHECK-SAME: ([[ARG:%.*]]: memref<*xf32>, [[SHAPE:%.*]]: memref<1xi32>)
// CHECK-NEXT: memref_reshape [[ARG]]([[SHAPE]])
// CHECK-NEXT: memref.reshape [[ARG]]([[SHAPE]])
// CHECK-SAME: : (memref<*xf32>, memref<1xi32>) -> memref<?xf32>
// -----
@ -30,7 +30,7 @@ func @dynamic_reshape_to_unranked(
return %reshaped : tensor<*xf32>
}
// CHECK-SAME: ([[ARG:%.*]]: memref<?xf32>, [[SHAPE:%.*]]: memref<?xi32>)
// CHECK-NEXT: memref_reshape [[ARG]]([[SHAPE]])
// CHECK-NEXT: memref.reshape [[ARG]]([[SHAPE]])
// CHECK-SAME: : (memref<?xf32>, memref<?xi32>) -> memref<*xf32>
// -----
@ -41,4 +41,4 @@ func @reshape_unranked(%operand: tensor<*xf32>) -> tensor<f32> {
return %reshaped : tensor<f32>
}
// CHECK-SAME: ([[ARG:%.*]]: memref<*xf32>)
// CHECK-NEXT: memref_cast [[ARG]] : memref<*xf32> to memref<f32>
// CHECK-NEXT: memref.cast [[ARG]] : memref<*xf32> to memref<f32>


@ -31,20 +31,20 @@ func @func_op_long(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32>
return %5 : tensor<4xf32>
}
// CHECK: (%[[NEW_ARG0:.*]]: memref<4xf32>, %[[NEW_ARG1:.*]]: memref<4xf32>) -> memref<4xf32>
// CHECK-NEXT: %[[MAX_RESULT:.*]] = alloc() : memref<4xf32>
// CHECK-NEXT: %[[MAX_RESULT:.*]] = memref.alloc() : memref<4xf32>
// CHECK-NEXT: "lmhlo.maximum"(%[[NEW_ARG0]], %[[NEW_ARG1]], %[[MAX_RESULT]])
// CHECK-NEXT: %[[ADD_RESULT:.*]] = alloc() : memref<4xf32>
// CHECK-NEXT: %[[ADD_RESULT:.*]] = memref.alloc() : memref<4xf32>
// CHECK-NEXT: "lmhlo.add"(%[[NEW_ARG0]], %[[MAX_RESULT]], %[[ADD_RESULT]])
// CHECK-NEXT: dealloc %[[MAX_RESULT]] : memref<4xf32>
// CHECK-NEXT: %[[MIN_RESULT:.*]] = alloc() : memref<4xf32>
// CHECK-NEXT: memref.dealloc %[[MAX_RESULT]] : memref<4xf32>
// CHECK-NEXT: %[[MIN_RESULT:.*]] = memref.alloc() : memref<4xf32>
// CHECK-NEXT: "lmhlo.minimum"(%[[NEW_ARG0]], %[[NEW_ARG1]], %[[MIN_RESULT]])
// CHECK-NEXT: %[[SUB_RESULT:.*]] = alloc() : memref<4xf32>
// CHECK-NEXT: %[[SUB_RESULT:.*]] = memref.alloc() : memref<4xf32>
// CHECK-NEXT: "lmhlo.subtract"(%[[NEW_ARG1]], %[[MIN_RESULT]], %[[SUB_RESULT]])
// CHECK-NEXT: dealloc %[[MIN_RESULT]] : memref<4xf32>
// CHECK-NEXT: %[[MUL_RESULT:.*]] = alloc() : memref<4xf32>
// CHECK-NEXT: memref.dealloc %[[MIN_RESULT]] : memref<4xf32>
// CHECK-NEXT: %[[MUL_RESULT:.*]] = memref.alloc() : memref<4xf32>
// CHECK-NEXT: "lmhlo.multiply"(%[[ADD_RESULT]], %[[SUB_RESULT]], %[[MUL_RESULT]])
// CHECK-NEXT: dealloc %[[SUB_RESULT]] : memref<4xf32>
// CHECK-NEXT: dealloc %[[ADD_RESULT]] : memref<4xf32>
// CHECK-NEXT: memref.dealloc %[[SUB_RESULT]] : memref<4xf32>
// CHECK-NEXT: memref.dealloc %[[ADD_RESULT]] : memref<4xf32>
// CHECK-NEXT: return %[[MUL_RESULT]] : memref<4xf32>
// -----
@ -53,15 +53,15 @@ func @func_op_long(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32>
func @fusion(%multiplier: tensor<2x2xf32>, %summand_1: tensor<2x2xf32>,
%summand_2: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: (%{{.*}}: {{.*}}, {{.*}}: {{.*}}, {{.*}}: {{.*}})
// CHECK-NEXT: %[[ADD_RESULT:.*]] = alloc() : memref<2x2xf32>
// CHECK-NEXT: %[[ADD_RESULT:.*]] = memref.alloc() : memref<2x2xf32>
%sum = "mhlo.add"(%summand_1, %summand_2)
: (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
// CHECK-NEXT: "lmhlo.add"(%{{.*}}, %{{.*}}, %[[ADD_RESULT]])
// CHECK-NEXT: %[[MUL_RESULT:.*]] = alloc() : memref<2x2xf32>
// CHECK-NEXT: %[[MUL_RESULT:.*]] = memref.alloc() : memref<2x2xf32>
%result = "mhlo.multiply"(%sum, %multiplier)
: (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
// CHECK-NEXT: "lmhlo.multiply"(%[[ADD_RESULT]], %{{.*}}, %[[MUL_RESULT]])
// CHECK-NEXT: dealloc %[[ADD_RESULT]] : memref<2x2xf32>
// CHECK-NEXT: memref.dealloc %[[ADD_RESULT]] : memref<2x2xf32>
// CHECK-NEXT: return %[[MUL_RESULT]] : memref<2x2xf32>
return %result : tensor<2x2xf32>
}
@ -154,9 +154,9 @@ func @dyn_broadcast(%operand: tensor<?x?xf32>) -> tensor<?x?x?xf32> {
// CHECK: %[[C0:.*]] = constant 0 : index
// CHECK: %[[C1:.*]] = constant 1 : index
// CHECK: %[[OPER_DIM_1:.*]] = dim %[[OPERAND]], %[[C1]] : memref<?x?xf32>
// CHECK: %[[OPER_DIM_1:.*]] = memref.dim %[[OPERAND]], %[[C1]] : memref<?x?xf32>
// CHECK: %[[OP_STRIDE_0:.*]] = muli %[[C1]], %[[OPER_DIM_1]] : index
// CHECK: %[[OPER_DIM_0:.*]] = dim %[[OPERAND]], %[[C0]] : memref<?x?xf32>
// CHECK: %[[OPER_DIM_0:.*]] = memref.dim %[[OPERAND]], %[[C0]] : memref<?x?xf32>
// CHECK: %[[EL0:.*]] = tensor.extract %[[SHAPE]]{{\[}}%[[C0]]] : tensor<3xi64>
// CHECK: %[[SIZE_0:.*]] = index_cast %[[EL0]] : i64 to index
@ -172,9 +172,9 @@ func @dyn_broadcast(%operand: tensor<?x?xf32>) -> tensor<?x?x?xf32> {
// CHECK: %[[EXPAND_2:.*]] = cmpi slt, %[[OPER_DIM_1]], %[[SIZE_2]] : index
// CHECK: %[[STRIDE_2:.*]] = select %[[EXPAND_2]], %[[C0]], %[[C1]] : index
// CHECK: %[[TRANSFORMED_MEMREF:.*]] = memref_reinterpret_cast %[[OPERAND]] to offset: [0], sizes: {{\[}}%[[SIZE_0]], %[[SIZE_1]], %[[SIZE_2]]], strides: {{\[}}%[[C0]], %[[STRIDE_1]], %[[STRIDE_2]]] : memref<?x?xf32> to memref<?x?x?xf32, #map>
// CHECK: %[[TRANSFORMED_MEMREF:.*]] = memref.reinterpret_cast %[[OPERAND]] to offset: [0], sizes: {{\[}}%[[SIZE_0]], %[[SIZE_1]], %[[SIZE_2]]], strides: {{\[}}%[[C0]], %[[STRIDE_1]], %[[STRIDE_2]]] : memref<?x?xf32> to memref<?x?x?xf32, #map>
// CHECK: %[[RESULT:.*]] = alloc(%[[SIZE_0]], %[[SIZE_1]], %[[SIZE_2]]) : memref<?x?x?xf32>
// CHECK: %[[RESULT:.*]] = memref.alloc(%[[SIZE_0]], %[[SIZE_1]], %[[SIZE_2]]) : memref<?x?x?xf32>
// CHECK: "lmhlo.copy"(%[[TRANSFORMED_MEMREF]], %[[RESULT]]) : (memref<?x?x?xf32, #map>, memref<?x?x?xf32>) -> ()
// CHECK: return %[[RESULT]] : memref<?x?x?xf32>
@ -469,7 +469,7 @@ func @add_dyn(%lhs: tensor<?x?xf32>, %rhs: tensor<?x?xf32>) -> tensor<?x?xf32> {
// CHECK: %[[SHAPE:.*]] = shape.shape_of %arg0 : memref<?x?xf32> -> tensor<2xindex>
// CHECK: %[[EE0:.*]] = tensor.extract %[[SHAPE]][%[[C0]]] : tensor<2xindex>
// CHECK: %[[EE1:.*]] = tensor.extract %[[SHAPE]][%[[C1]]] : tensor<2xindex>
// CHECK: %[[RESULT:.*]] = alloc(%[[EE0]], %[[EE1]])
// CHECK: %[[RESULT:.*]] = memref.alloc(%[[EE0]], %[[EE1]])
// CHECK: "lmhlo.add"(%arg0, %arg1, %[[RESULT]]) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
return %result : tensor<?x?xf32>
// CHECK: return %[[RESULT]]
@ -485,7 +485,7 @@ func @tanh_dyn(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
// CHECK: %[[SHAPE:.*]] = shape.shape_of %arg0 : memref<?x?xf32> -> tensor<2xindex>
// CHECK: %[[EE0:.*]] = tensor.extract %[[SHAPE]][%[[C0]]] : tensor<2xindex>
// CHECK: %[[EE1:.*]] = tensor.extract %[[SHAPE]][%[[C1]]] : tensor<2xindex>
// CHECK: %[[RESULT:.*]] = alloc(%[[EE0]], %[[EE1]])
// CHECK: %[[RESULT:.*]] = memref.alloc(%[[EE0]], %[[EE1]])
// CHECK: "lmhlo.tanh"(%arg0, %[[RESULT]]) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
return %result : tensor<?x?xf32>
// CHECK: return %[[RESULT]]
@ -496,7 +496,7 @@ func @tanh_dyn(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
// CHECK-LABEL: func @dot
func @dot(%arg0: tensor<1024x1024xf32>) -> tensor<1024x1024xf32> {
// CHECK-SAME: (%[[ARG0:.*]]: [[TYPE:.*]]) -> [[TYPE]]
// CHECK-NEXT: %[[ALLOC:.*]] = alloc
// CHECK-NEXT: %[[ALLOC:.*]] = memref.alloc
// CHECK: "lmhlo.dot"(%[[ARG0]], %[[ARG0]], %[[ALLOC]]) {
// dot_dimension_numbers = {
// lhs_batching_dimensions = dense<> : tensor<0xi64>,
@ -517,7 +517,7 @@ func @dot(%arg0: tensor<1024x1024xf32>) -> tensor<1024x1024xf32> {
func @conv(%input: tensor<3x5x5x3xf32>, %filter : tensor<2x2x3x4xf32>)
-> tensor<3x5x5x4xf32> {
%c0 = constant 0 : index
// CHECK: %[[OUT:.*]] = alloc() : memref<3x5x5x4xf32>
// CHECK: %[[OUT:.*]] = memref.alloc() : memref<3x5x5x4xf32>
// CHECK: "lmhlo.convolution"(%{{.+}}, %{{.+}}, %[[OUT]])
// CHECK-SAME: padding = dense<[
// CHECK-SAME: [0, 1], [0, 1]]> : tensor<2x2xi64>
@ -548,11 +548,11 @@ func @conv(%input: tensor<3x5x5x3xf32>, %filter : tensor<2x2x3x4xf32>)
// CHECK-LABEL: func @reduce
func @reduce(%arg0: tensor<1x8xf32>, %arg1: tensor<f32>) -> tensor<1xf32> {
// CHECK: %[[OUT:.*]] = alloc() : memref<1xf32>
// CHECK: %[[OUT:.*]] = memref.alloc() : memref<1xf32>
// CHECK: "lmhlo.reduce"(%{{.+}}, %{{.+}}, %[[OUT]]) ( {
// CHECK: ^bb0(%[[ARG1:.*]]: memref<f32>, %[[ARG2:.*]]: memref<f32>,
// CHECK-SAME: %[[ARG3:.*]]: memref<f32>):
// CHECK: %[[TMP:.*]] = alloc() : memref<f32>
// CHECK: %[[TMP:.*]] = memref.alloc() : memref<f32>
// CHECK: "lmhlo.add"(%[[ARG1]], %[[ARG2]], %[[TMP]])
// CHECK: "lmhlo.copy"(%[[TMP]], %[[ARG3]])
// CHECK: "lmhlo.terminator"() : () -> ()


@ -404,7 +404,7 @@ func @broadcast(%arg: tensor<4x?x16xf32>) -> tensor<4x2x1x4x?x16xf32> {
return %0: tensor<4x2x1x4x?x16xf32>
}
// CHECK: %[[C1:.*]] = constant 1 : index
// CHECK: %[[D1:.*]] = dim %{{.*}}, %[[C1]] : tensor<4x?x16xf32>
// CHECK: %[[D1:.*]] = memref.dim %{{.*}}, %[[C1]] : tensor<4x?x16xf32>
// CHECK: linalg.init_tensor [4, 2, 1, 4, %[[D1]], 16] : tensor<4x2x1x4x?x16xf32>
// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]]
// CHECK-NEXT: ^bb0(%[[OPERAND:.*]]: f32, %{{.*}}: f32):
@ -997,19 +997,41 @@ func @dynamic_broadcast_in_dim(%shape: tensor<1xindex>) -> tensor<?xf32> {
// CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-LABEL: func @dynamic_broadcast_in_dim(
// CHECK-SAME: [[SCALAR:%.*]]: tensor<f32>
// CHECK-SAME: [[SHAPE:%.*]]: tensor<2xindex>
func @dynamic_broadcast_in_dim(%shape: tensor<2xindex>) -> tensor<?x32xf32> {
%cst = mhlo.constant dense<0x7F800000> : tensor<f32>
%result = "mhlo.dynamic_broadcast_in_dim"(%cst, %shape) {
func @dynamic_broadcast_in_dim(%scalar: tensor<f32>, %shape: tensor<2xindex>)
-> tensor<?x32xf32> {
%result = "mhlo.dynamic_broadcast_in_dim"(%scalar, %shape) {
broadcast_dimensions = dense<> : tensor<0xi64>
} : (tensor<f32>, tensor<2xindex>) -> tensor<?x32xf32>
return %result : tensor<?x32xf32>
}
// CHECK: [[CST:%.*]] = constant
// CHECK: [[INIT:%.*]] = linalg.init_tensor
// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]]
// CHECK-SAME: ins([[CST]] : tensor<f32>) outs([[INIT]] : tensor<?x32xf32>)
// CHECK-SAME: ins([[SCALAR]] : tensor<f32>) outs([[INIT]] : tensor<?x32xf32>)
// CHECK-NEXT: ^bb0(%[[OPERAND:.*]]: f32, %[[RESULT:.*]]: f32):
// CHECK-NEXT: linalg.yield %[[OPERAND]] : f32
// -----
// CHECK-DAG: #[[OPERAND_MAP:.*]] = affine_map<(d0, d1, d2) -> (d1)>
// CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK-LABEL: func @dynamic_broadcast_in_dim(
// CHECK-SAME: [[VECTOR:%.*]]: tensor<42xf32>
// CHECK-SAME: [[SHAPE:%.*]]: tensor<3xindex>
func @dynamic_broadcast_in_dim(%vector: tensor<42xf32>, %shape: tensor<3xindex>)
-> tensor<?x?x?xf32> {
%result = "mhlo.dynamic_broadcast_in_dim"(%vector, %shape) {
broadcast_dimensions = dense<1> : tensor<1xi64>
} : (tensor<42xf32>, tensor<3xindex>) -> tensor<?x?x?xf32>
return %result : tensor<?x?x?xf32>
}
// CHECK: [[INIT:%.*]] = linalg.init_tensor
// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]]
// CHECK-SAME: ins([[VECTOR]] : tensor<42xf32>) outs([[INIT]] : tensor<?x?x?xf32>)
// CHECK-NEXT: ^bb0(%[[OPERAND:.*]]: f32, %[[RESULT:.*]]: f32):
// CHECK-NEXT: linalg.yield %[[OPERAND]] : f32
@ -1024,7 +1046,7 @@ func @dot_matmul(%arg0: tensor<2x3xf32>,
// CHECK-LABEL: func @dot_matmul(
// CHECK-SAME: %[[ARG0:.*]]: tensor<2x3xf32>, %[[ARG1:.*]]: tensor<3x?xf32>)
// CHECK: %[[C1:.*]] = constant 1 : index
// CHECK: %[[D1:.*]] = dim %[[ARG1]], %[[C1]]
// CHECK: %[[D1:.*]] = memref.dim %[[ARG1]], %[[C1]]
// CHECK: %[[INIT:.*]] = linalg.init_tensor [2, %[[D1]]]
// CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]]
// CHECK: linalg.matmul
@ -1040,7 +1062,7 @@ func @dot_matmul_i8_i8_i32(%arg0: tensor<2x3xi8>,
// CHECK-LABEL: func @dot_matmul_i8_i8_i32(
// CHECK-SAME: %[[ARG0:.*]]: tensor<2x3xi8>, %[[ARG1:.*]]: tensor<3x?xi8>)
// CHECK: %[[C1:.*]] = constant 1 : index
// CHECK: %[[D1:.*]] = dim %[[ARG1]], %[[C1]]
// CHECK: %[[D1:.*]] = memref.dim %[[ARG1]], %[[C1]]
// CHECK: %[[INIT:.*]] = linalg.init_tensor [2, %[[D1]]]
// CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]]
// CHECK: linalg.matmul
@ -1058,7 +1080,7 @@ func @dot_matmul_i16_i16_i32(%arg0: tensor<2x3xi16>,
// CHECK-LABEL: func @dot_matmul_i16_i16_i32(
// CHECK-SAME: %[[ARG0:.*]]: tensor<2x3xi16>, %[[ARG1:.*]]: tensor<3x?xi16>)
// CHECK: %[[C1:.*]] = constant 1 : index
// CHECK: %[[D1:.*]] = dim %[[ARG1]], %[[C1]]
// CHECK: %[[D1:.*]] = memref.dim %[[ARG1]], %[[C1]]
// CHECK: %[[INIT:.*]] = linalg.init_tensor [2, %[[D1]]]
// CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]]
// CHECK: linalg.matmul
@ -1076,7 +1098,7 @@ func @dot_matmul_i32_i32_i32(%arg0: tensor<2x3xi32>,
// CHECK-LABEL: func @dot_matmul_i32_i32_i32(
// CHECK-SAME: %[[ARG0:.*]]: tensor<2x3xi32>, %[[ARG1:.*]]: tensor<3x?xi32>)
// CHECK: %[[C1:.*]] = constant 1 : index
// CHECK: %[[D1:.*]] = dim %[[ARG1]], %[[C1]]
// CHECK: %[[D1:.*]] = memref.dim %[[ARG1]], %[[C1]]
// CHECK: %[[INIT:.*]] = linalg.init_tensor [2, %[[D1]]]
// CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]]
// CHECK: linalg.matmul
@ -1094,7 +1116,7 @@ func @dot_matvec(%arg0: tensor<?x3xf32>,
// CHECK-LABEL: func @dot_matvec(
// CHECK-SAME: %[[ARG0:.*]]: tensor<?x3xf32>, %[[ARG1:.*]]: tensor<3xf32>)
// CHECK: %[[C0:.*]] = constant 0 : index
// CHECK: %[[D0:.*]] = dim %[[ARG0]], %[[C0]]
// CHECK: %[[D0:.*]] = memref.dim %[[ARG0]], %[[C0]]
// CHECK: %[[INIT:.*]] = linalg.init_tensor [%[[D0]]]
// CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]]
// CHECK: linalg.matvec
@ -1134,11 +1156,11 @@ func @dot_general_batch_matmul(%arg0: tensor<?x?x3xf32>,
// CHECK-LABEL: func @dot_general_batch_matmul(
// CHECK-SAME: %[[ARG0:.*]]: tensor<?x?x3xf32>, %[[ARG1:.*]]: tensor<?x3x?xf32>)
// CHECK: %[[C0:.*]] = constant 0 : index
// CHECK: %[[D0:.*]] = dim %[[ARG0]], %[[C0]]
// CHECK: %[[D0:.*]] = memref.dim %[[ARG0]], %[[C0]]
// CHECK: %[[C1:.*]] = constant 1 : index
// CHECK: %[[D1:.*]] = dim %[[ARG0]], %[[C1]]
// CHECK: %[[D1:.*]] = memref.dim %[[ARG0]], %[[C1]]
// CHECK: %[[C2:.*]] = constant 2 : index
// CHECK: %[[D2:.*]] = dim %[[ARG1]], %[[C2]]
// CHECK: %[[D2:.*]] = memref.dim %[[ARG1]], %[[C2]]
// CHECK: %[[INIT:.*]] = linalg.init_tensor [%[[D0]], %[[D1]], %[[D2]]]
// CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]]
// CHECK: linalg.batch_matmul
@ -1163,11 +1185,11 @@ func @dot_general_batch_matmul_i8_i8_i32(%arg0: tensor<?x?x3xi8>,
// CHECK-LABEL: func @dot_general_batch_matmul_i8_i8_i32(
// CHECK-SAME: %[[ARG0:.*]]: tensor<?x?x3xi8>, %[[ARG1:.*]]: tensor<?x3x?xi8>)
// CHECK: %[[C0:.*]] = constant 0 : index
// CHECK: %[[D0:.*]] = dim %[[ARG0]], %[[C0]]
// CHECK: %[[D0:.*]] = memref.dim %[[ARG0]], %[[C0]]
// CHECK: %[[C1:.*]] = constant 1 : index
// CHECK: %[[D1:.*]] = dim %[[ARG0]], %[[C1]]
// CHECK: %[[D1:.*]] = memref.dim %[[ARG0]], %[[C1]]
// CHECK: %[[C2:.*]] = constant 2 : index
// CHECK: %[[D2:.*]] = dim %[[ARG1]], %[[C2]]
// CHECK: %[[D2:.*]] = memref.dim %[[ARG1]], %[[C2]]
// CHECK: %[[INIT:.*]] = linalg.init_tensor [%[[D0]], %[[D1]], %[[D2]]]
// CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]]
// CHECK: linalg.batch_matmul
@ -1192,11 +1214,11 @@ func @dot_general_batch_matmul_i16_i16_i32(%arg0: tensor<?x?x3xi16>,
// CHECK-LABEL: func @dot_general_batch_matmul_i16_i16_i32(
// CHECK-SAME: %[[ARG0:.*]]: tensor<?x?x3xi16>, %[[ARG1:.*]]: tensor<?x3x?xi16>)
// CHECK: %[[C0:.*]] = constant 0 : index
// CHECK: %[[D0:.*]] = dim %[[ARG0]], %[[C0]]
// CHECK: %[[D0:.*]] = memref.dim %[[ARG0]], %[[C0]]
// CHECK: %[[C1:.*]] = constant 1 : index
// CHECK: %[[D1:.*]] = dim %[[ARG0]], %[[C1]]
// CHECK: %[[D1:.*]] = memref.dim %[[ARG0]], %[[C1]]
// CHECK: %[[C2:.*]] = constant 2 : index
// CHECK: %[[D2:.*]] = dim %[[ARG1]], %[[C2]]
// CHECK: %[[D2:.*]] = memref.dim %[[ARG1]], %[[C2]]
// CHECK: %[[INIT:.*]] = linalg.init_tensor [%[[D0]], %[[D1]], %[[D2]]]
// CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]]
// CHECK: linalg.batch_matmul
@ -1420,7 +1442,7 @@ func @reduce_dynamic(%arg0: tensor<?x?xi32>, %arg1: tensor<i32>) -> tensor<?xi32
// CHECK: func @reduce_dynamic(%[[ARG0:.*]]: tensor<?x?xi32>
// CHECK-DAG: %[[INIT:.*]] = tensor.extract %{{.*}} : tensor<i32>
// CHECK-DAG: %[[C0:.*]] = constant 0 : index
// CHECK-DAG: %[[DIM1:.*]] = dim %[[ARG0]], %[[C0]] : tensor<?x?xi32>
// CHECK-DAG: %[[DIM1:.*]] = memref.dim %[[ARG0]], %[[C0]] : tensor<?x?xi32>
// CHECK-DAG: %[[INIT_TENSOR:.*]] = linalg.init_tensor [%[[DIM1]]]
// CHECK-DAG: %[[FILL_TENSOR:.*]] = linalg.fill(%[[INIT_TENSOR]], %[[INIT]])
// CHECK: linalg.generic
@ -1531,9 +1553,9 @@ func @linalg.conv_1d_input_nwc_filter_wcf(%arg0: tensor<?x8x?xf32>, %arg1: tenso
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]
// CHECK: %[[C0:.+]] = constant 0 : index
// CHECK: %[[DIM0:.+]] = dim %[[ARG0]], %[[C0]] : tensor<?x8x?xf32>
// CHECK: %[[DIM0:.+]] = memref.dim %[[ARG0]], %[[C0]] : tensor<?x8x?xf32>
// CHECK: %[[C2:.+]] = constant 2 : index
// CHECK: %[[DIM2:.+]] = dim %[[ARG1]], %[[C2]] : tensor<2x?x?xf32>
// CHECK: %[[DIM2:.+]] = memref.dim %[[ARG1]], %[[C2]] : tensor<2x?x?xf32>
// CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[DIM0]], 7, %[[DIM2]]]
// CHECK: %[[ZERO:.+]] = constant 0.000000e+00 : f32
// CHECK: %[[FILL:.+]] = linalg.fill(%[[INIT]], %[[ZERO]])
@ -1571,9 +1593,9 @@ func @conv_2d_input_nhwc_filter_hwcf(%arg0: tensor<?x4x5x?xf32>, %arg1: tensor<3
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]
// CHECK: %[[C0:.+]] = constant 0 : index
// CHECK: %[[DIM0:.+]] = dim %[[ARG0]], %[[C0]] : tensor<?x4x5x?xf32>
// CHECK: %[[DIM0:.+]] = memref.dim %[[ARG0]], %[[C0]] : tensor<?x4x5x?xf32>
// CHECK: %[[C3:.+]] = constant 3 : index
// CHECK: %[[DIM3:.+]] = dim %[[ARG1]], %[[C3]] : tensor<3x2x?x?xf32>
// CHECK: %[[DIM3:.+]] = memref.dim %[[ARG1]], %[[C3]] : tensor<3x2x?x?xf32>
// CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[DIM0]], 2, 3, %[[DIM3]]]
// CHECK: %[[ZERO:.+]] = constant 0.000000e+00 : f32
// CHECK: %[[FILL:.+]] = linalg.fill(%[[INIT]], %[[ZERO]])
@ -1611,9 +1633,9 @@ func @conv_3d_input_ndhwc_filter_dhwcf(%arg0: tensor<?x8x8x8x?xf32>, %arg1: tens
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]
// CHECK: %[[C0:.+]] = constant 0 : index
// CHECK: %[[DIM0:.+]] = dim %[[ARG0]], %[[C0]] : tensor<?x8x8x8x?xf32>
// CHECK: %[[DIM0:.+]] = memref.dim %[[ARG0]], %[[C0]] : tensor<?x8x8x8x?xf32>
// CHECK: %[[C4:.+]] = constant 4 : index
// CHECK: %[[DIM4:.+]] = dim %[[ARG1]], %[[C4]] : tensor<2x2x2x?x?xf32>
// CHECK: %[[DIM4:.+]] = memref.dim %[[ARG1]], %[[C4]] : tensor<2x2x2x?x?xf32>
// CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[DIM0]], 7, 7, 7, %[[DIM4]]]
// CHECK: %[[ZERO:.+]] = constant 0.000000e+00 : f32
// CHECK: %[[FILL:.+]] = linalg.fill(%[[INIT]], %[[ZERO]])
@ -1826,7 +1848,6 @@ func @reduce_window_max_nhwc_with_cst(%arg0: tensor<1x18x18x64xf32>) -> tensor<1
return %1 : tensor<1x8x8x64xf32>
}
// -----
// CHECK-LABEL: func @reduce_window_max_nhwc
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]
// CHECK-DAG: %[[CST:.+]] = constant dense<0xFF800000> : tensor<f32>
@ -1839,3 +1860,125 @@ func @reduce_window_max_nhwc_with_cst(%arg0: tensor<1x18x18x64xf32>) -> tensor<1
// CHECK-SAME: strides = dense<2> : vector<2xi64>}
// CHECK-SAME: ins(%[[ARG0]], %[[WINDOW]] : tensor<1x18x18x64xf32>, tensor<3x3xf32>)
// CHECK-SAME: outs(%[[FILL]] : tensor<1x8x8x64xf32>) -> tensor<1x8x8x64xf32>
// -----
func @torch_select_index(%arg0: tensor<5x1x5xi32>,
%arg1: tensor<2xi32>) -> tensor<2x1x5xi32> {
%0 = "mhlo.torch_index_select"(%arg0, %arg1) {
dim = 0 : i64,
batch_dims = 0 : i64
} : (tensor<5x1x5xi32>, tensor<2xi32>) -> tensor<2x1x5xi32>
return %0 : tensor<2x1x5xi32>
}
// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK: func @torch_select_index
// CHECK-SAME: %[[INPUT:[a-zA-Z0-9_]*]]
// CHECK-SAME: %[[INDEX:[a-zA-Z0-9_]*]]
// CHECK: linalg.indexed_generic {
// CHECK-SAME: indexing_maps
// CHECK-SAME: #[[MAP0]], #[[MAP1]]
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]
// CHECK-SAME: ins(%[[INDEX]] : tensor<2xi32>)
// CHECK: ^{{.+}}(
// CHECK-SAME: %[[I:.+]]: index, %[[J:.+]]: index, %[[K:.+]]: index
// CHECK-SAME: %[[VAL:.+]]: i32, %{{.+}}: i32):
// CHECK: %[[CAST:.+]] = index_cast %[[VAL]] : i32 to index
// CHECK: %[[VAL2:.+]] = tensor.extract %[[INPUT]][%[[CAST]], %[[J]], %[[K]]] : tensor<5x1x5xi32>
// CHECK: linalg.yield %[[VAL2]] : i32
// -----
func @torch_select_index_scalar(%arg0: tensor<4x8xf32>,
%arg1: tensor<i32>) -> tensor<8xf32> {
%0 = "mhlo.torch_index_select"(%arg0, %arg1) {
batch_dims = 0 : i64,
dim = 0 : i64
} : (tensor<4x8xf32>, tensor<i32>) -> tensor<8xf32>
return %0 : tensor<8xf32>
}
// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0) -> ()>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0) -> (d0)>
// CHECK: func @torch_select_index_scalar
// CHECK-SAME: %[[INPUT:[a-zA-Z0-9_]*]]
// CHECK-SAME: %[[INDEX:[a-zA-Z0-9_]*]]
// CHECK: %[[T0:.+]] = linalg.init_tensor [8] : tensor<8xf32>
// CHECK: linalg.indexed_generic {
// CHECK-SAME: indexing_maps
// CHECK-SAME: #[[MAP0]], #[[MAP1]]
// CHECK-SAME: iterator_types = ["parallel"]
// CHECK-SAME: ins(%[[INDEX]] : tensor<i32>) outs(%[[T0]] : tensor<8xf32>)
// CHECK: ^{{.+}}(
// CHECK-SAME: %[[I:[a-zA-Z0-9_]+]]: index, %[[VAL:[a-zA-Z0-9_]+]]: i32, %{{.+}}: f32):
// CHECK: %[[CAST:.+]] = index_cast %[[VAL]] : i32 to index
// CHECK: %[[VAL2:.+]] = tensor.extract %[[INPUT]][%[[CAST]], %[[I]]] : tensor<4x8xf32>
// CHECK: linalg.yield %[[VAL2]] : f32
// -----
func @torch_select_index_batch(%arg0: tensor<4x7x8x2xf32>,
%arg1: tensor<4x1xi32>) -> tensor<4x7x1x2xf32> {
%0 = "mhlo.torch_index_select"(%arg0, %arg1) {
dim = 2 : i64,
batch_dims = 1 : i64
} : (tensor<4x7x8x2xf32>, tensor<4x1xi32>) -> tensor<4x7x1x2xf32>
return %0 : tensor<4x7x1x2xf32>
}
// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d2)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
// CHECK: func @torch_select_index_batch
// CHECK-SAME: %[[INPUT:[a-zA-Z0-9_]*]]
// CHECK-SAME: %[[INDEX:[a-zA-Z0-9_]*]]
// CHECK: linalg.indexed_generic {
// CHECK-SAME: indexing_maps
// CHECK-SAME: #[[MAP0]], #[[MAP1]]
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"]
// CHECK-SAME: ins(%[[INDEX]] : tensor<4x1xi32>)
// CHECK-NEXT: ^{{.+}}(
// CHECK-SAME: %[[I:[a-zA-Z0-9_]+]]: index, %[[J:[a-zA-Z0-9_]+]]: index,
// CHECK-SAME: %[[K:[a-zA-Z0-9_]+]]: index, %[[L:.+]]: index,
// CHECK-SAME: %[[VAL:.+]]: i32, %{{.+}}: f32):
// CHECK: %[[CAST:.+]] = index_cast %[[VAL]] : i32 to index
// CHECK: %[[VAL2:.+]] = tensor.extract %[[INPUT]][%[[I]], %[[J]], %[[CAST]], %[[L]]] : tensor<4x7x8x2xf32>
// CHECK: linalg.yield %[[VAL2]] : f32
// -----
func @torch_index_select_dynamic(%input: tensor<?x?x?x?xf32>,
%index: tensor<?x?xi32>) -> tensor<?x?x?x?xf32>{
%0 = "mhlo.torch_index_select"(%input, %index) {
batch_dims = 1 : i64,
dim = 2 : i64
} : (tensor<?x?x?x?xf32>, tensor<?x?xi32>) -> tensor<?x?x?x?xf32>
return %0 : tensor<?x?x?x?xf32>
}
// CHECK: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d2)>
// CHECK: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
// CHECK: func @torch_index_select_dynamic
// CHECK-SAME: %[[INPUT:[a-zA-Z0-9_]*]]
// CHECK-SAME: %[[INDEX:[a-zA-Z0-9_]*]]
// CHECK: %[[C0:.+]] = constant 0 : index
// CHECK: %[[D0:.+]] = memref.dim %[[INPUT]], %[[C0]]
// CHECK: %[[C1:.+]] = constant 1 : index
// CHECK: %[[D1:.+]] = memref.dim %[[INPUT]], %[[C1]]
// CHECK: %[[C1:.+]] = constant 1 : index
// CHECK: %[[D2:.+]] = memref.dim %[[INDEX]], %[[C1]]
// CHECK: %[[C3:.+]] = constant 3 : index
// CHECK: %[[D3:.+]] = memref.dim %[[INPUT]], %[[C3]]
// CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[D0]], %[[D1]], %[[D2]], %[[D3]]]
// CHECK: %[[RESULT:.+]] = linalg.indexed_generic
// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"]
// CHECK-SAME: ins(%[[INDEX]] : tensor<?x?xi32>)
// CHECK-SAME: outs(%[[INIT]] : tensor<?x?x?x?xf32>)
// CHECK: ^{{.+}}(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: index,
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index,
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: index
// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: index,
// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: i32, %{{[a-zA-Z0-9_]+}}: f32)
// CHECK: %[[POS:.+]] = index_cast %[[ARG4]]
// CHECK: %[[YIELD:.+]] = tensor.extract %[[INPUT]][%[[ARG0]], %[[ARG1]], %[[POS]], %[[ARG3]]]
// CHECK: linalg.yield %[[YIELD]]
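
For reference, the generic body checked above reads each output element from the input at a position taken from the index tensor. The sketch below is a minimal C++ reference for the dynamic case (batch_dims = 1, dim = 2); the function name and the flat indexing scheme are illustrative, not part of the pass.

#include <cstdint>
#include <vector>

// output[i, j, k, l] = input[i, j, index[i, k], l], matching the
// index_cast and tensor.extract in the checks above.
std::vector<float> TorchIndexSelectDim2Batch1(
    const std::vector<float> &input, const std::vector<int32_t> &index,
    int64_t d0, int64_t d1, int64_t d2_in, int64_t d3, int64_t d2_out) {
  // input: d0 x d1 x d2_in x d3, index: d0 x d2_out,
  // output: d0 x d1 x d2_out x d3.
  std::vector<float> output(d0 * d1 * d2_out * d3);
  for (int64_t i = 0; i < d0; ++i)
    for (int64_t j = 0; j < d1; ++j)
      for (int64_t k = 0; k < d2_out; ++k)
        for (int64_t l = 0; l < d3; ++l) {
          const int64_t pos = index[i * d2_out + k];          // index_cast
          output[((i * d1 + j) * d2_out + k) * d3 + l] =
              input[((i * d1 + j) * d2_in + pos) * d3 + l];   // tensor.extract
        }
  return output;
}
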


@ -1,125 +1,78 @@
// RUN: mlir-hlo-opt --mhlo-legalize-trigonometric-to-approximation --split-input-file %s | FileCheck %s
// CHECK-LABEL: @tanh_f64
func @tanh_f64(%arg0 : f64) -> f64 {
// CHECK: tanh
%res = math.tanh %arg0 : f64
return %res : f64
}
// CHECK-LABEL: @tanh_f64
// CHECK: tanh
// -----
// CHECK-LABEL: @tanh_f32
// CHECK-SAME: (%[[ARG:.*]]: f32) -> f32
func @tanh_f32(%arg0 : f32) -> f32 {
// CHECK-DAG: %[[C:.*]] = constant -2.76076837E-16 : f32
// CHECK-DAG: %[[C0:.*]] = constant 2.00018794E-13 : f32
// CHECK-DAG: %[[C1:.*]] = constant -8.60467184E-11 : f32
// CHECK-DAG: %[[C2:.*]] = constant 5.12229725E-8 : f32
// CHECK-DAG: %[[C3:.*]] = constant 1.48572235E-5 : f32
// CHECK-DAG: %[[C4:.*]] = constant 6.37261954E-4 : f32
// CHECK-DAG: %[[C5:.*]] = constant 0.00489352457 : f32
// CHECK-DAG: %[[C6:.*]] = constant 1.19825836E-6 : f32
// CHECK-DAG: %[[C7:.*]] = constant 1.18534706E-4 : f32
// CHECK-DAG: %[[C8:.*]] = constant 0.00226843474 : f32
// CHECK-DAG: %[[C9:.*]] = constant 0.00489352504 : f32
// CHECK-DAG: %[[C10:.*]] = constant 4.000000e-04 : f32
// CHECK-DAG: %[[C11:.*]] = constant 7.90531111 : f32
// CHECK-DAG: %[[C12:.*]] = constant -7.90531111 : f32
// CHECK-DAG: %[[C13:.*]] = constant 1.000000e+00 : f32
// CHECK-DAG: %[[C14:.*]] = constant -1.000000e+00 : f32
// CHECK-DAG: %[[TMP0:.*]] = mulf %[[ARG]], %[[ARG]] : f32
// CHECK-DAG: %[[TMP1:.*]] = mulf %[[TMP0]], %[[C]] : f32
// CHECK-DAG: %[[TMP2:.*]] = addf %[[TMP1]], %[[C0]] : f32
// CHECK-DAG: %[[TMP3:.*]] = mulf %[[TMP0]], %[[TMP2]] : f32
// CHECK-DAG: %[[TMP4:.*]] = addf %[[TMP3]], %[[C1]] : f32
// CHECK-DAG: %[[TMP5:.*]] = mulf %[[TMP0]], %[[TMP4]] : f32
// CHECK-DAG: %[[TMP6:.*]] = addf %[[TMP5]], %[[C2]] : f32
// CHECK-DAG: %[[TMP7:.*]] = mulf %[[TMP0]], %[[TMP6]] : f32
// CHECK-DAG: %[[TMP8:.*]] = addf %[[TMP7]], %[[C3]] : f32
// CHECK-DAG: %[[TMP9:.*]] = mulf %[[TMP0]], %[[TMP8]] : f32
// CHECK-DAG: %[[TMP10:.*]] = addf %[[TMP9]], %[[C4]] : f32
// CHECK-DAG: %[[TMP11:.*]] = mulf %[[TMP0]], %[[TMP10]] : f32
// CHECK-DAG: %[[TMP12:.*]] = addf %[[TMP11]], %[[C5]] : f32
// CHECK-DAG: %[[TMP13:.*]] = mulf %[[ARG]], %[[TMP12]] : f32
// CHECK-DAG: %[[TMP14:.*]] = mulf %[[TMP0]], %[[C6]] : f32
// CHECK-DAG: %[[TMP15:.*]] = addf %[[TMP14]], %[[C7]] : f32
// CHECK-DAG: %[[TMP16:.*]] = mulf %[[TMP0]], %[[TMP15]] : f32
// CHECK-DAG: %[[TMP17:.*]] = addf %[[TMP16]], %[[C8]] : f32
// CHECK-DAG: %[[TMP18:.*]] = mulf %[[TMP0]], %[[TMP17]] : f32
// CHECK-DAG: %[[TMP19:.*]] = addf %[[TMP18]], %[[C9]] : f32
// CHECK-DAG: %[[TMP20:.*]] = divf %[[TMP13]], %[[TMP19]] : f32
// CHECK-DAG: %[[TMP21:.*]] = absf %[[ARG]] : f32
// CHECK-DAG: %[[TMP22:.*]] = cmpf olt, %[[TMP21]], %[[C10]] : f32
// CHECK-DAG: %[[TMP23:.*]] = select %[[TMP22]], %[[ARG]], %[[TMP20]] : f32
// CHECK-DAG: %[[TMP24:.*]] = cmpf ugt, %[[ARG]], %[[C11]] : f32
// CHECK-DAG: %[[TMP25:.*]] = cmpf ult, %[[ARG]], %[[C12]] : f32
// CHECK-DAG: %[[TMP26:.*]] = select %[[TMP24]], %[[C13]], %[[TMP23]] : f32
// CHECK-DAG: %[[TMP27:.*]] = select %[[TMP25]], %[[C14]], %[[TMP26]] : f32
// CHECK: return %[[TMP27]] : f32
%res = math.tanh %arg0 : f32
return %res : f32
}
// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
// CHECK-LABEL: func @tanh_f32
// CHECK-SAME: (%[[VAL_0:.*]]: f32) -> f32
// CHECK: %[[VAL_1:.*]] = constant 4.000000e-04 : f32
// CHECK: %[[VAL_2:.*]] = constant 7.90531111 : f32
// CHECK: %[[VAL_3:.*]] = constant -7.90531111 : f32
// CHECK: %[[VAL_4:.*]] = constant -2.76076837E-16 : f32
// CHECK: %[[VAL_5:.*]] = constant 2.00018794E-13 : f32
// CHECK: %[[VAL_6:.*]] = constant -8.60467184E-11 : f32
// CHECK: %[[VAL_7:.*]] = constant 5.12229725E-8 : f32
// CHECK: %[[VAL_8:.*]] = constant 1.48572235E-5 : f32
// CHECK: %[[VAL_9:.*]] = constant 6.37261954E-4 : f32
// CHECK: %[[VAL_10:.*]] = constant 0.00489352457 : f32
// CHECK: %[[VAL_11:.*]] = constant 1.19825836E-6 : f32
// CHECK: %[[VAL_12:.*]] = constant 1.18534706E-4 : f32
// CHECK: %[[VAL_13:.*]] = constant 0.00226843474 : f32
// CHECK: %[[VAL_14:.*]] = constant 0.00489352504 : f32
// CHECK: %[[VAL_15:.*]] = absf %[[VAL_0]] : f32
// CHECK: %[[VAL_16:.*]] = cmpf olt, %[[VAL_15]], %[[VAL_1]] : f32
// CHECK: %[[VAL_17:.*]] = cmpf ule, %[[VAL_0]], %[[VAL_2]] : f32
// CHECK: %[[VAL_18:.*]] = select %[[VAL_17]], %[[VAL_0]], %[[VAL_2]] : f32
// CHECK: %[[VAL_19:.*]] = cmpf uge, %[[VAL_18]], %[[VAL_3]] : f32
// CHECK: %[[VAL_20:.*]] = select %[[VAL_19]], %[[VAL_18]], %[[VAL_3]] : f32
// CHECK: %[[VAL_21:.*]] = mulf %[[VAL_20]], %[[VAL_20]] : f32
// CHECK: %[[VAL_22:.*]] = mulf %[[VAL_21]], %[[VAL_4]] : f32
// CHECK: %[[VAL_23:.*]] = addf %[[VAL_22]], %[[VAL_5]] : f32
// CHECK: %[[VAL_24:.*]] = mulf %[[VAL_21]], %[[VAL_23]] : f32
// CHECK: %[[VAL_25:.*]] = addf %[[VAL_24]], %[[VAL_6]] : f32
// CHECK: %[[VAL_26:.*]] = mulf %[[VAL_21]], %[[VAL_25]] : f32
// CHECK: %[[VAL_27:.*]] = addf %[[VAL_26]], %[[VAL_7]] : f32
// CHECK: %[[VAL_28:.*]] = mulf %[[VAL_21]], %[[VAL_27]] : f32
// CHECK: %[[VAL_29:.*]] = addf %[[VAL_28]], %[[VAL_8]] : f32
// CHECK: %[[VAL_30:.*]] = mulf %[[VAL_21]], %[[VAL_29]] : f32
// CHECK: %[[VAL_31:.*]] = addf %[[VAL_30]], %[[VAL_9]] : f32
// CHECK: %[[VAL_32:.*]] = mulf %[[VAL_21]], %[[VAL_31]] : f32
// CHECK: %[[VAL_33:.*]] = addf %[[VAL_32]], %[[VAL_10]] : f32
// CHECK: %[[VAL_34:.*]] = mulf %[[VAL_20]], %[[VAL_33]] : f32
// CHECK: %[[VAL_35:.*]] = mulf %[[VAL_21]], %[[VAL_11]] : f32
// CHECK: %[[VAL_36:.*]] = addf %[[VAL_35]], %[[VAL_12]] : f32
// CHECK: %[[VAL_37:.*]] = mulf %[[VAL_21]], %[[VAL_36]] : f32
// CHECK: %[[VAL_38:.*]] = addf %[[VAL_37]], %[[VAL_13]] : f32
// CHECK: %[[VAL_39:.*]] = mulf %[[VAL_21]], %[[VAL_38]] : f32
// CHECK: %[[VAL_40:.*]] = addf %[[VAL_39]], %[[VAL_14]] : f32
// CHECK: %[[VAL_41:.*]] = divf %[[VAL_34]], %[[VAL_40]] : f32
// CHECK: %[[VAL_42:.*]] = select %[[VAL_16]], %[[VAL_0]], %[[VAL_41]] : f32
// CHECK: return %[[VAL_42]] : f32
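
Taken together, the constants and the mulf/addf chain above implement a clamped rational approximation of tanh. A hedged C++ sketch of the same computation follows; the NaN behavior of the unordered cmpf/select pairs is glossed over and the function name is illustrative.

#include <cmath>

float TanhApproxF32(float x) {
  const float kSmallBound = 4.0e-4f;   // below this, tanh(x) ~= x is returned
  const float kClamp = 7.90531111f;    // range where the polynomial is valid
  const bool is_small = std::fabs(x) < kSmallBound;   // absf + cmpf olt
  float v = (x <= kClamp) ? x : kClamp;                // cmpf ule + select
  v = (v >= -kClamp) ? v : -kClamp;                    // cmpf uge + select
  const float x2 = v * v;
  // Numerator: v times a polynomial in x2, evaluated in Horner form.
  float p = -2.76076837e-16f;
  p = p * x2 + 2.00018794e-13f;
  p = p * x2 - 8.60467184e-11f;
  p = p * x2 + 5.12229725e-8f;
  p = p * x2 + 1.48572235e-5f;
  p = p * x2 + 6.37261954e-4f;
  p = p * x2 + 4.89352457e-3f;
  const float numerator = v * p;
  // Denominator: a shorter polynomial in x2 with the remaining constants.
  float q = 1.19825836e-6f;
  q = q * x2 + 1.18534706e-4f;
  q = q * x2 + 2.26843474e-3f;
  q = q * x2 + 4.89352504e-3f;
  return is_small ? x : numerator / q;
}
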
// -----
func @tanh_f16(%arg0 : f16) -> f16 {
// CHECK-LABEL: func @tanh_f16
// CHECK-SAME: (%[[ARG:.*]]: f16) -> f16
// CHECK: %{{.*}} = fpext %[[ARG]] : f16 to f32
// CHECK: %[[RES:.*]] = fptrunc %{{.*}} : f32 to f16
// CHECK: return %[[RES]] : f16
%res = math.tanh %arg0 : f16
return %res : f16
}
// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
// CHECK-LABEL: func @tanh_f16
// CHECK-SAME: (%[[VAL_0:.*]]: f16) -> f16
// CHECK: %[[VAL_1:.*]] = constant 4.000000e-04 : f32
// CHECK: %[[VAL_2:.*]] = constant 7.90531111 : f32
// CHECK: %[[VAL_3:.*]] = constant -7.90531111 : f32
// CHECK: %[[VAL_4:.*]] = constant -2.76076837E-16 : f32
// CHECK: %[[VAL_5:.*]] = constant 2.00018794E-13 : f32
// CHECK: %[[VAL_6:.*]] = constant -8.60467184E-11 : f32
// CHECK: %[[VAL_7:.*]] = constant 5.12229725E-8 : f32
// CHECK: %[[VAL_8:.*]] = constant 1.48572235E-5 : f32
// CHECK: %[[VAL_9:.*]] = constant 6.37261954E-4 : f32
// CHECK: %[[VAL_10:.*]] = constant 0.00489352457 : f32
// CHECK: %[[VAL_11:.*]] = constant 1.19825836E-6 : f32
// CHECK: %[[VAL_12:.*]] = constant 1.18534706E-4 : f32
// CHECK: %[[VAL_13:.*]] = constant 0.00226843474 : f32
// CHECK: %[[VAL_14:.*]] = constant 0.00489352504 : f32
// CHECK: %[[VAL_15:.*]] = fpext %[[VAL_0]] : f16 to f32
// CHECK: %[[VAL_16:.*]] = absf %[[VAL_15]] : f32
// CHECK: %[[VAL_17:.*]] = cmpf olt, %[[VAL_16]], %[[VAL_1]] : f32
// CHECK: %[[VAL_18:.*]] = cmpf ule, %[[VAL_15]], %[[VAL_2]] : f32
// CHECK: %[[VAL_19:.*]] = select %[[VAL_18]], %[[VAL_15]], %[[VAL_2]] : f32
// CHECK: %[[VAL_20:.*]] = cmpf uge, %[[VAL_19]], %[[VAL_3]] : f32
// CHECK: %[[VAL_21:.*]] = select %[[VAL_20]], %[[VAL_19]], %[[VAL_3]] : f32
// CHECK: %[[VAL_22:.*]] = mulf %[[VAL_21]], %[[VAL_21]] : f32
// CHECK: %[[VAL_23:.*]] = mulf %[[VAL_22]], %[[VAL_4]] : f32
// CHECK: %[[VAL_24:.*]] = addf %[[VAL_23]], %[[VAL_5]] : f32
// CHECK: %[[VAL_25:.*]] = mulf %[[VAL_22]], %[[VAL_24]] : f32
// CHECK: %[[VAL_26:.*]] = addf %[[VAL_25]], %[[VAL_6]] : f32
// CHECK: %[[VAL_27:.*]] = mulf %[[VAL_22]], %[[VAL_26]] : f32
// CHECK: %[[VAL_28:.*]] = addf %[[VAL_27]], %[[VAL_7]] : f32
// CHECK: %[[VAL_29:.*]] = mulf %[[VAL_22]], %[[VAL_28]] : f32
// CHECK: %[[VAL_30:.*]] = addf %[[VAL_29]], %[[VAL_8]] : f32
// CHECK: %[[VAL_31:.*]] = mulf %[[VAL_22]], %[[VAL_30]] : f32
// CHECK: %[[VAL_32:.*]] = addf %[[VAL_31]], %[[VAL_9]] : f32
// CHECK: %[[VAL_33:.*]] = mulf %[[VAL_22]], %[[VAL_32]] : f32
// CHECK: %[[VAL_34:.*]] = addf %[[VAL_33]], %[[VAL_10]] : f32
// CHECK: %[[VAL_35:.*]] = mulf %[[VAL_21]], %[[VAL_34]] : f32
// CHECK: %[[VAL_36:.*]] = mulf %[[VAL_22]], %[[VAL_11]] : f32
// CHECK: %[[VAL_37:.*]] = addf %[[VAL_36]], %[[VAL_12]] : f32
// CHECK: %[[VAL_38:.*]] = mulf %[[VAL_22]], %[[VAL_37]] : f32
// CHECK: %[[VAL_39:.*]] = addf %[[VAL_38]], %[[VAL_13]] : f32
// CHECK: %[[VAL_40:.*]] = mulf %[[VAL_22]], %[[VAL_39]] : f32
// CHECK: %[[VAL_41:.*]] = addf %[[VAL_40]], %[[VAL_14]] : f32
// CHECK: %[[VAL_42:.*]] = divf %[[VAL_35]], %[[VAL_41]] : f32
// CHECK: %[[VAL_43:.*]] = select %[[VAL_17]], %[[VAL_15]], %[[VAL_42]] : f32
// CHECK: %[[VAL_44:.*]] = fptrunc %[[VAL_43]] : f32 to f16
// CHECK: return %[[VAL_44]] : f16
// -----
// CHECK-LABEL: @atan2_f64


@ -7,7 +7,7 @@
iterator_types = ["parallel", "parallel"]}
func @fusion(%multiplier: memref<6x6xf32>, %summand_1: memref<6x6xf32>,
%summand_2: memref<6x6xf32>, %result: memref<6x6xf32>) {
%temp_result = alloc() : memref<6x6xf32>
%temp_result = memref.alloc() : memref<6x6xf32>
linalg.generic #pointwise_2d_trait
ins(%summand_1, %summand_2 : memref<6x6xf32>, memref<6x6xf32>)
outs(%temp_result : memref<6x6xf32>) {
@ -22,7 +22,7 @@ func @fusion(%multiplier: memref<6x6xf32>, %summand_1: memref<6x6xf32>,
%out = mulf %temp_result_in, %multiplier_in : f32
linalg.yield %out : f32
}
dealloc %temp_result : memref<6x6xf32>
memref.dealloc %temp_result : memref<6x6xf32>
return
}
// CHECK-LABEL: func @fusion
@ -62,7 +62,7 @@ func @fusion(%multiplier: memref<6x6xf32>, %summand_1: memref<6x6xf32>,
func @fusion_of_three(%arg0: memref<100x10xf32>,
%arg1: memref<100xf32>,
%arg2: memref<100x10xf32>) {
%0 = alloc() : memref<100x10xf32>
%0 = memref.alloc() : memref<100x10xf32>
linalg.generic {
indexing_maps = [affine_map<(d0, d1) -> (d0)>,
affine_map<(d0, d1) -> (d0, d1)>],
@ -72,7 +72,7 @@ func @fusion_of_three(%arg0: memref<100x10xf32>,
^bb0(%arg3: f32, %arg4: f32): // no predecessors
linalg.yield %arg3 : f32
}
%1 = alloc() : memref<100x10xf32>
%1 = memref.alloc() : memref<100x10xf32>
linalg.generic {
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
affine_map<(d0, d1) -> (d0, d1)>,
@ -84,7 +84,7 @@ func @fusion_of_three(%arg0: memref<100x10xf32>,
%2 = subf %arg3, %arg4 : f32
linalg.yield %2 : f32
}
dealloc %0 : memref<100x10xf32>
memref.dealloc %0 : memref<100x10xf32>
linalg.generic {
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
affine_map<(d0, d1) -> (d0, d1)>],
@ -95,7 +95,7 @@ func @fusion_of_three(%arg0: memref<100x10xf32>,
%2 = math.exp %arg3 : f32
linalg.yield %2 : f32
}
dealloc %1 : memref<100x10xf32>
memref.dealloc %1 : memref<100x10xf32>
return
}
// CHECK-LABEL: func @fusion
@ -141,7 +141,7 @@ func @fusion_of_three(%arg0: memref<100x10xf32>,
"parallel"]}
func @fusion_4d(%multiplier: memref<6x6x6x6xf32>, %summand_1: memref<6x6x6x6xf32>,
%summand_2: memref<6x6x6x6xf32>, %result: memref<6x6x6x6xf32>) {
%temp_result = alloc() : memref<6x6x6x6xf32>
%temp_result = memref.alloc() : memref<6x6x6x6xf32>
linalg.generic #pointwise_4d_trait
ins(%summand_1, %summand_2 : memref<6x6x6x6xf32>, memref<6x6x6x6xf32>)
outs(%temp_result : memref<6x6x6x6xf32>) {
@ -156,7 +156,7 @@ func @fusion_4d(%multiplier: memref<6x6x6x6xf32>, %summand_1: memref<6x6x6x6xf32
%out = mulf %temp_result_in, %multiplier_in : f32
linalg.yield %out : f32
}
dealloc %temp_result : memref<6x6x6x6xf32>
memref.dealloc %temp_result : memref<6x6x6x6xf32>
return
}
// CHECK-LABEL: func @fusion_4d
@ -200,7 +200,7 @@ func @fusion_4d(%multiplier: memref<6x6x6x6xf32>, %summand_1: memref<6x6x6x6xf32
iterator_types = ["parallel", "parallel"]}
func @fusion(%multiplier: memref<6x6xf32>, %summand_1: memref<6x6xf32>,
%summand_2: memref<6x6xf32>) -> memref<6x6xf32> {
%temp_result = alloc() : memref<6x6xf32>
%temp_result = memref.alloc() : memref<6x6xf32>
linalg.generic #pointwise_2d_trait
ins(%summand_1, %summand_2 : memref<6x6xf32>, memref<6x6xf32>)
outs(%temp_result : memref<6x6xf32>) {
@ -208,7 +208,7 @@ func @fusion(%multiplier: memref<6x6xf32>, %summand_1: memref<6x6xf32>,
%out = addf %summand_1_in, %summand_2_in : f32
linalg.yield %out : f32
}
%result = alloc() : memref<6x6xf32>
%result = memref.alloc() : memref<6x6xf32>
linalg.generic #pointwise_2d_trait
ins(%temp_result, %multiplier : memref<6x6xf32>, memref<6x6xf32>)
outs(%result : memref<6x6xf32>) {
@ -216,7 +216,7 @@ func @fusion(%multiplier: memref<6x6xf32>, %summand_1: memref<6x6xf32>,
%out = mulf %temp_result_in, %multiplier_in : f32
linalg.yield %out : f32
}
dealloc %temp_result : memref<6x6xf32>
memref.dealloc %temp_result : memref<6x6xf32>
return %result : memref<6x6xf32>
}
@ -258,7 +258,7 @@ func @view_result(%arg0: memref<?xf32>, %arg1: memref<?xindex>, %arg2: index)
-> memref<*xf32> {
%c1 = constant 1 : index
%c0 = constant 0 : index
%1 = alloc(%arg2) : memref<?xf32>
%1 = memref.alloc(%arg2) : memref<?xf32>
linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>,
affine_map<(d0) -> (d0)>],
iterator_types = ["parallel"]}
@ -267,7 +267,7 @@ func @view_result(%arg0: memref<?xf32>, %arg1: memref<?xindex>, %arg2: index)
%13 = absf %arg3 : f32
linalg.yield %13 : f32
}
%2 = memref_reshape %1(%arg1)
%2 = memref.reshape %1(%arg1)
: (memref<?xf32>, memref<?xindex>) -> memref<*xf32>
return %2 : memref<*xf32>
}
@ -279,7 +279,7 @@ func @view_result(%arg0: memref<?xf32>, %arg1: memref<?xindex>, %arg2: index)
// CHECK-NOT: scf.for
// CHECK: linalg.generic
// CHECK: absf
// CHECK: memref_reshape
// CHECK: memref.reshape
// TILED-LABEL: func @view_result
// TILED-DAG: %[[C2:.*]] = constant 2
@ -288,7 +288,7 @@ func @view_result(%arg0: memref<?xf32>, %arg1: memref<?xindex>, %arg2: index)
// TILED-NOT: scf.for
// TILED: linalg.generic
// TILED: absf
// TILED: memref_reshape
// TILED: memref.reshape
// PLOOP-LABEL: func @view_result
@ -297,20 +297,20 @@ func @view_result(%arg0: memref<?xf32>, %arg1: memref<?xindex>, %arg2: index)
// PLOOP-NOT: scf.parallel
// PLOOP: linalg.generic
// PLOOP: absf
// PLOOP: memref_reshape
// PLOOP: memref.reshape
// -----
// Confirm that tiling information is passed through RegionBranchOpInterfaces.
// This test also uses memref_reshape, just to have a value to return through
// This test also uses memref.reshape, just to have a value to return through
// the if statement.
func @branching_result(%arg0: memref<?xf32>, %arg1: memref<?xindex>, %arg2: index)
-> memref<*xf32> {
%c1 = constant 1 : index
%c0 = constant 0 : index
%1 = alloc(%arg2) : memref<?xf32>
%1 = memref.alloc(%arg2) : memref<?xf32>
linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>,
affine_map<(d0) -> (d0)>],
iterator_types = ["parallel"]}
@ -321,11 +321,11 @@ func @branching_result(%arg0: memref<?xf32>, %arg1: memref<?xindex>, %arg2: inde
}
%true = constant 1 : i1
%3 = scf.if %true -> memref<*xf32> {
%2 = memref_reshape %1(%arg1)
%2 = memref.reshape %1(%arg1)
: (memref<?xf32>, memref<?xindex>) -> memref<*xf32>
scf.yield %2 : memref<*xf32>
} else {
%2 = memref_reshape %1(%arg1)
%2 = memref.reshape %1(%arg1)
: (memref<?xf32>, memref<?xindex>) -> memref<*xf32>
scf.yield %2 : memref<*xf32>
}
@ -340,10 +340,10 @@ func @branching_result(%arg0: memref<?xf32>, %arg1: memref<?xindex>, %arg2: inde
// CHECK: linalg.generic
// CHECK: absf
// CHECK: scf.if
// CHECK: memref_reshape
// CHECK: memref.reshape
// CHECK: scf.yield
// CHECK: else
// CHECK: memref_reshape
// CHECK: memref.reshape
// CHECK: scf.yield
// TILED-LABEL: func @branching_result
@ -354,10 +354,10 @@ func @branching_result(%arg0: memref<?xf32>, %arg1: memref<?xindex>, %arg2: inde
// TILED: linalg.generic
// TILED: absf
// TILED: scf.if
// TILED: memref_reshape
// TILED: memref.reshape
// TILED: scf.yield
// TILED: else
// TILED: memref_reshape
// TILED: memref.reshape
// TILED: scf.yield
// PLOOP-LABEL: func @branching_result
@ -367,10 +367,10 @@ func @branching_result(%arg0: memref<?xf32>, %arg1: memref<?xindex>, %arg2: inde
// PLOOP: linalg.generic
// PLOOP: absf
// PLOOP: scf.if
// PLOOP: memref_reshape
// PLOOP: memref.reshape
// PLOOP: scf.yield
// PLOOP: else
// PLOOP: memref_reshape
// PLOOP: memref.reshape
// PLOOP: scf.yield
// -----
@ -380,7 +380,7 @@ func @branching_result(%arg0: memref<?xf32>, %arg1: memref<?xindex>, %arg2: inde
func @tensor_ops(%arg0: memref<32xf32>, %arg1: memref<32xindex>)
-> memref<?xf32> {
%c1 = constant 1 : index
%1 = alloc() : memref<32xf32>
%1 = memref.alloc() : memref<32xf32>
linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>,
affine_map<(d0) -> (d0)>],
iterator_types = ["parallel"]}
@ -389,9 +389,9 @@ func @tensor_ops(%arg0: memref<32xf32>, %arg1: memref<32xindex>)
%13 = absf %arg3 : f32
linalg.yield %13 : f32
}
%2 = tensor_load %1 : memref<32xf32>
%2 = memref.tensor_load %1 : memref<32xf32>
%3 = tensor.cast %2 : tensor<32xf32> to tensor<?xf32>
%4 = tensor_to_memref %3 : memref<?xf32>
%4 = memref.buffer_cast %3 : memref<?xf32>
return %4 : memref<?xf32>
}
@ -402,9 +402,9 @@ func @tensor_ops(%arg0: memref<32xf32>, %arg1: memref<32xindex>)
// CHECK-NOT: scf.for
// CHECK: linalg.generic
// CHECK: absf
// CHECK: tensor_load
// CHECK: memref.tensor_load
// CHECK: tensor.cast
// CHECK: tensor_to_memref
// CHECK: memref.buffer_cast
// TILED-LABEL: func @tensor_ops
// TILED-DAG: %[[C2:.*]] = constant 2
@ -413,9 +413,9 @@ func @tensor_ops(%arg0: memref<32xf32>, %arg1: memref<32xindex>)
// TILED-NOT: scf.for
// TILED: linalg.generic
// TILED: absf
// TILED: tensor_load
// TILED: memref.tensor_load
// TILED: tensor.cast
// TILED: tensor_to_memref
// TILED: memref.buffer_cast
// PLOOP-LABEL: func @tensor_ops
@ -424,6 +424,6 @@ func @tensor_ops(%arg0: memref<32xf32>, %arg1: memref<32xindex>)
// PLOOP-NOT: scf.parallel
// PLOOP: linalg.generic
// PLOOP: absf
// PLOOP: tensor_load
// PLOOP: memref.tensor_load
// PLOOP: tensor.cast
// PLOOP: tensor_to_memref
// PLOOP: memref.buffer_cast


@ -49,10 +49,10 @@ func @select_and_scatter(%arg: memref<112x112xf32>,
// CHECK-DAG: [[CTRUE:%.*]] = constant true
// Parallel loop to initialize the output buffer.
// CHECK: [[INIT:%.*]] = load [[INIT_BUF]][] : memref<f32>
// CHECK: [[INIT:%.*]] = memref.load [[INIT_BUF]][] : memref<f32>
// CHECK: scf.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]])
// CHECK-SAME: to ([[C112]], [[C112]]) step ([[C1]], [[C1]]) {
// CHECK: store [[INIT]], [[RESULT_BUF]]{{\[}}[[I]], [[J]]]
// CHECK: memref.store [[INIT]], [[RESULT_BUF]]{{\[}}[[I]], [[J]]]
// CHECK: scf.yield
// CHECK: }
@ -101,7 +101,7 @@ func @select_and_scatter(%arg: memref<112x112xf32>,
// INBOUNDS-THEN-BODY, i.e. if INBOUNDS == true
// CHECK: [[ARG_ELEM:%.*]] = load [[ARG_BUF]]{{\[}}[[ARG_I]], [[ARG_J]]]
// CHECK: [[ARG_ELEM:%.*]] = memref.load [[ARG_BUF]]{{\[}}[[ARG_I]], [[ARG_J]]]
// CHECK: [[IF_INIT_RES:%.*]]:4
// CHECK-SAME: = scf.if [[SEL_INIT]] -> (index, index, f32, i1) {
@ -114,16 +114,16 @@ func @select_and_scatter(%arg: memref<112x112xf32>,
// Allocate buffers for ARG element, current selected value to adapt LHLO
// code.
// CHECK: [[ARG_ELEM_BUF:%.*]] = alloc() : memref<f32>
// CHECK: [[SEL_VAL_BUF:%.*]] = alloc() : memref<f32>
// CHECK: [[PRED_BUF:%.*]] = alloc() : memref<i1>
// CHECK: store [[ARG_ELEM]], [[ARG_ELEM_BUF]][] : memref<f32>
// CHECK: store [[SEL_VAL]], [[SEL_VAL_BUF]][] : memref<f32>
// CHECK: [[ARG_ELEM_BUF:%.*]] = memref.alloc() : memref<f32>
// CHECK: [[SEL_VAL_BUF:%.*]] = memref.alloc() : memref<f32>
// CHECK: [[PRED_BUF:%.*]] = memref.alloc() : memref<i1>
// CHECK: memref.store [[ARG_ELEM]], [[ARG_ELEM_BUF]][] : memref<f32>
// CHECK: memref.store [[SEL_VAL]], [[SEL_VAL_BUF]][] : memref<f32>
// Compute PRED.
// CHECK: "lmhlo.compare"(
// CHECK-SAME: [[ARG_ELEM_BUF]], [[SEL_VAL_BUF]], [[PRED_BUF]])
// CHECK: [[PRED:%.*]] = load [[PRED_BUF]][] : memref<i1>
// CHECK: [[PRED:%.*]] = memref.load [[PRED_BUF]][] : memref<i1>
// Depending on PRED, return ARG ivs & elem or current select ivs and value.
@ -165,7 +165,7 @@ func @select_and_scatter(%arg: memref<112x112xf32>,
// CHECK: }
// Use selected ivs to load element from the SRC buffer.
// CHECK: [[SRC_ELEM:%.*]] = load [[SRC_BUF]]{{\[}}[[II]], [[JJ]]]
// CHECK: [[SRC_ELEM:%.*]] = memref.load [[SRC_BUF]]{{\[}}[[II]], [[JJ]]]
// Update of RESULT[SELECTED_I, SELECTED_J] should be done atomically, because
// it may happen that several other threads select the same IVs if the windows
@ -175,16 +175,16 @@ func @select_and_scatter(%arg: memref<112x112xf32>,
// CHECK: ^bb0([[CUR_RES:%.*]]: f32):
// Allocate buffers for ARG element, current selected value to adapt LHLO code.
// CHECK: [[SRC_ELEM_BUF:%.*]] = alloc() : memref<f32>
// CHECK: [[CUR_RES_BUF:%.*]] = alloc() : memref<f32>
// CHECK: [[RES_BUF:%.*]] = alloc() : memref<f32>
// CHECK: store [[SRC_ELEM]], [[SRC_ELEM_BUF]][] : memref<f32>
// CHECK: store [[CUR_RES]], [[CUR_RES_BUF]][] : memref<f32>
// CHECK: [[SRC_ELEM_BUF:%.*]] = memref.alloc() : memref<f32>
// CHECK: [[CUR_RES_BUF:%.*]] = memref.alloc() : memref<f32>
// CHECK: [[RES_BUF:%.*]] = memref.alloc() : memref<f32>
// CHECK: memref.store [[SRC_ELEM]], [[SRC_ELEM_BUF]][] : memref<f32>
// CHECK: memref.store [[CUR_RES]], [[CUR_RES_BUF]][] : memref<f32>
// Compute scatter value.
// CHECK: "lmhlo.add"([[SRC_ELEM_BUF]], [[CUR_RES_BUF]], [[RES_BUF]]) :
// CHECK-SAME: (memref<f32>, memref<f32>, memref<f32>) -> ()
// CHECK: [[RES:%.*]] = load [[RES_BUF]][] : memref<f32>
// CHECK: [[RES:%.*]] = memref.load [[RES_BUF]][] : memref<f32>
// Atomic RMW terminator that returns updated value.
// CHECK: atomic_yield [[RES]] : f32
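
The overlap reasoning above is why the update is emitted as a generic atomic read-modify-write. In C++ terms the region corresponds roughly to a compare-and-swap loop like the following sketch (names illustrative):

#include <atomic>

// Add the selected source element into the result element; retry if another
// thread updated the same location in between (the role of atomic_yield).
void AtomicAddUpdate(std::atomic<float> &result_elem, float src_elem) {
  float cur = result_elem.load();
  while (!result_elem.compare_exchange_weak(cur, cur + src_elem)) {
    // cur now holds the freshly observed value; the loop retries the add.
  }
}
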


@ -19,14 +19,14 @@ func @reduce(%arg: memref<100x10xf32>,
// CHECK-DAG: %[[C100:.*]] = constant 100 : index
// CHECK-DAG: %[[C1:.*]] = constant 1 : index
// CHECK: gpu.launch blocks({{.*}}, {{.*}}, {{.*}}) in ({{.*}} = %[[C1]], {{.*}} = %[[C1]], {{.*}} = %[[C1]]) threads(%[[IDX:.*]], {{.*}}, {{.*}}) in ({{.*}} = %[[C100]], {{.*}} = %[[C1]], {{.*}} = %[[C1]]) {
// CHECK: %[[ACC:.*]] = load %[[ARG1]][] : memref<f32>
// CHECK: %[[ACC:.*]] = memref.load %[[ARG1]][] : memref<f32>
// CHECK: store %[[ACC]], %[[ARG2]][%[[IDX:.*]]] : memref<100xf32>
// CHECK-DAG: %[[LB:.*]] = constant 0 : index
// CHECK-DAG: %[[UB:.*]] = constant 10 : index
// CHECK-DAG: %[[STEP:.*]] = constant 1 : index
// CHECK: scf.for %[[IDX1:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] {
// CHECK: %[[LHS:.*]] = subview %[[ARG2]][%[[IDX]]] [1] [1] : memref<100xf32> to memref<f32, #[[$MAP]]>
// CHECK: %[[RHS:.*]] = subview %[[ARG0]][%[[IDX]], %[[IDX1]]] [1, 1] [1, 1] : memref<100x10xf32> to memref<f32, #[[$MAP]]>
// CHECK: %[[LHS:.*]] = memref.subview %[[ARG2]][%[[IDX]]] [1] [1] : memref<100xf32> to memref<f32, #[[$MAP]]>
// CHECK: %[[RHS:.*]] = memref.subview %[[ARG0]][%[[IDX]], %[[IDX1]]] [1, 1] [1, 1] : memref<100x10xf32> to memref<f32, #[[$MAP]]>
// CHECK: "lmhlo.add"(%[[LHS]], %[[RHS]], %[[LHS]]) : (memref<f32, {{.*}}>, memref<f32, {{.*}}>, memref<f32, {{.*}}>) -> ()
// CHECK: }
// CHECK: gpu.terminator


@ -52,10 +52,10 @@ func @element_wise_scalar(%lhs: memref<f32>, %rhs: memref<f32>,
: (memref<f32>, memref<f32>, memref<f32>) -> ()
return
}
// CHECK: %[[LHS:.*]] = load
// CHECK: %[[RHS:.*]] = load
// CHECK: %[[LHS:.*]] = memref.load
// CHECK: %[[RHS:.*]] = memref.load
// CHECK: %[[RES:.*]] = addf %[[LHS]], %[[RHS]]
// CHECK: store %[[RES]]
// CHECK: memref.store %[[RES]]
// CHECK-NEXT: return
// -----
@ -347,7 +347,7 @@ func @static_broadcast_in_dim_with_one_to_many(%operand: memref<1xf32>,
}
// CHECK-NOT: linalg.reshape
// CHECK: %[[C0:.*]] = constant 0 : index
// CHECK: %[[VALUE:.*]] = load %{{.*}}[[C0]]
// CHECK: %[[VALUE:.*]] = memref.load %{{.*}}[[C0]]
// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[RESULT_MAP]]]
// CHECK-NEXT: ^bb0(%{{.+}}: f32):
// CHECK-NEXT: linalg.yield %[[VALUE]] : f32
@ -785,7 +785,7 @@ func @slice(%operand: memref<?x?xf32>, %result: memref<?x?xf32>) {
} : (memref<?x?xf32>, memref<?x?xf32>) -> ()
return
}
// CHECK: %[[RESULT:.*]] = subview %[[IN]][0, 1] [2, 2] [1, 1] : memref<?x?xf32> to memref<2x2xf32, #{{.*}}>
// CHECK: %[[RESULT:.*]] = memref.subview %[[IN]][0, 1] [2, 2] [1, 1] : memref<?x?xf32> to memref<2x2xf32, #{{.*}}>
// CHECK: linalg.copy(%[[RESULT]], %[[OUT]])
// -----
@ -899,7 +899,7 @@ func @reverse(%arg0: memref<2x3xf32>, %arg1: memref<2x3xf32>) {
func @conv(%input: memref<3x5x5x3xf32>, %filter: memref<2x2x3x4xf32>, %output: memref<3x5x5x4xf32>) {
%c0 = constant 0 : index
%0 = alloc() : memref<3x5x5x4xf32>
%0 = memref.alloc() : memref<3x5x5x4xf32>
// CHECK: linalg.conv(%{{.+}}, %{{.+}}, %{{.+}})
// CHECK-SAME: dilations = [1, 2]
// CHECK-SAME: padding = dense<{{\[\[}}0, 1], [0, 1]]> : tensor<2x2xi64>
@ -948,22 +948,22 @@ func @reduce_add(%arg: memref<100x10xf32>,
: (memref<100x10xf32>, memref<f32>, memref<100xf32>) -> ()
return
}
// CHECK: %[[INIT_VAL:.*]] = load %arg1[] : memref<f32>
// CHECK: %[[INIT_VAL:.*]] = memref.load %arg1[] : memref<f32>
// CHECK: linalg.fill(%arg2, %[[INIT_VAL]])
// CHECK: linalg.generic {
// CHECK-SAME: indexing_maps = [#[[REDUCE_INPUT_MAP]], #[[REDUCE_OUTPUT_MAP]]],
// CHECK-SAME: iterator_types = ["parallel", "reduction"]}
// CHECK-SAME: ins(%arg0 : memref<100x10xf32>) outs(%arg2 : memref<100xf32>) {
// CHECK: alloca
// CHECK-NEXT: alloca
// CHECK-NEXT: alloca
// CHECK-NEXT: store
// CHECK-NEXT: store
// CHECK-NEXT: load
// CHECK-NEXT: load
// CHECK: memref.alloca
// CHECK-NEXT: memref.alloca
// CHECK-NEXT: memref.alloca
// CHECK-NEXT: memref.store
// CHECK-NEXT: memref.store
// CHECK-NEXT: memref.load
// CHECK-NEXT: memref.load
// CHECK-NEXT: addf
// CHECK-NEXT: store
// CHECK-NEXT: load
// CHECK-NEXT: memref.store
// CHECK-NEXT: memref.load
// CHECK-NEXT: linalg.yield
// CHECK-NEXT: }
@ -984,22 +984,22 @@ func @reduce_maximum(%arg: memref<100x10xf32>,
: (memref<100x10xf32>, memref<f32>, memref<100xf32>) -> ()
return
}
// CHECK: %[[INIT_VAL:.*]] = load %arg1[] : memref<f32>
// CHECK: %[[INIT_VAL:.*]] = memref.load %arg1[] : memref<f32>
// CHECK: linalg.fill(%arg2, %[[INIT_VAL]])
// CHECK: linalg.generic {
// CHECK-SAME: indexing_maps = [#[[REDUCE_INPUT_MAP]], #[[REDUCE_OUTPUT_MAP]]],
// CHECK-SAME: iterator_types = ["parallel", "reduction"]}
// CHECK-SAME: ins(%arg0 : memref<100x10xf32>) outs(%arg2 : memref<100xf32>) {
// CHECK: alloca
// CHECK-NEXT: alloca
// CHECK-NEXT: alloca
// CHECK-NEXT: store
// CHECK-NEXT: store
// CHECK-NEXT: load
// CHECK-NEXT: load
// CHECK: memref.alloca
// CHECK-NEXT: memref.alloca
// CHECK-NEXT: memref.alloca
// CHECK-NEXT: memref.store
// CHECK-NEXT: memref.store
// CHECK-NEXT: memref.load
// CHECK-NEXT: memref.load
// CHECK: cmpf
// CHECK: select
// CHECK: store
// CHECK-NEXT: load
// CHECK: memref.store
// CHECK-NEXT: memref.load
// CHECK-NEXT: linalg.yield
// CHECK-NEXT: }


@ -21,27 +21,27 @@ func @reduce(%arg: memref<100x10x5xf32>,
// CHECK-DAG: [[C5:%.*]] = constant 5 : index
// CHECK-DAG: [[C10:%.*]] = constant 10 : index
// CHECK-DAG: [[C100:%.*]] = constant 100 : index
// CHECK: [[INIT:%.*]] = load [[INIT_BUF]]
// CHECK: [[INIT:%.*]] = memref.load [[INIT_BUF]]
// CHECK: scf.parallel ([[I:%.*]], [[K:%.*]]) = ([[C0]], [[C0]])
// CHECK-SAME: to ([[C100]], [[C5]]) step ([[C1]], [[C1]]) {
// CHECK: [[REDUCTION_RESULT:%.*]] = scf.parallel ([[J:%.*]]) =
// CHECK-SAME: ([[C0]]) to ([[C10]]) step ([[C1]]) init ([[INIT]]) -> f32 {
// CHECK: [[ELEM_TO_REDUCE:%.*]] = load [[ARG_BUF]]
// CHECK: [[ELEM_TO_REDUCE:%.*]] = memref.load [[ARG_BUF]]
// CHECK-SAME: {{\[}}[[I]], [[J]], [[K]]] : memref<100x10x5xf32>
// CHECK: scf.reduce([[ELEM_TO_REDUCE]]) : f32 {
// CHECK: ^bb0([[ELEM:%.*]]: f32, [[ACC:%.*]]: f32):
// CHECK: [[ELEM_BUF:%.*]] = alloc() : memref<f32>
// CHECK: [[ACC_BUF:%.*]] = alloc() : memref<f32>
// CHECK: [[ACC_OUT_BUF:%.*]] = alloc() : memref<f32>
// CHECK: store [[ELEM]], [[ELEM_BUF]][] : memref<f32>
// CHECK: store [[ACC]], [[ACC_BUF]][] : memref<f32>
// CHECK: [[ELEM_BUF:%.*]] = memref.alloc() : memref<f32>
// CHECK: [[ACC_BUF:%.*]] = memref.alloc() : memref<f32>
// CHECK: [[ACC_OUT_BUF:%.*]] = memref.alloc() : memref<f32>
// CHECK: memref.store [[ELEM]], [[ELEM_BUF]][] : memref<f32>
// CHECK: memref.store [[ACC]], [[ACC_BUF]][] : memref<f32>
// CHECK: "lmhlo.add"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_OUT_BUF]])
// CHECK: [[ACC_RESULT:%.*]] = load [[ACC_OUT_BUF]][] : memref<f32>
// CHECK: [[ACC_RESULT:%.*]] = memref.load [[ACC_OUT_BUF]][] : memref<f32>
// CHECK: scf.reduce.return [[ACC_RESULT]] : f32
// CHECK: }
// CHECK: scf.yield
// CHECK: }
// CHECK: store [[REDUCTION_RESULT]], [[RESULT_BUF]]{{\[}}[[I]], [[K]]]
// CHECK: memref.store [[REDUCTION_RESULT]], [[RESULT_BUF]]{{\[}}[[I]], [[K]]]
// CHECK: scf.yield
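
As a plain C++ reference for what these checks describe, the lowering reduces the middle dimension of a 100x10x5 buffer into a 100x5 result, starting from the scalar init value (function name and flat indexing illustrative):

#include <vector>

std::vector<float> ReduceMiddleDim(const std::vector<float> &arg, float init) {
  const int kI = 100, kJ = 10, kK = 5;
  std::vector<float> result(kI * kK);
  for (int i = 0; i < kI; ++i)         // outer scf.parallel over (i, k)
    for (int k = 0; k < kK; ++k) {
      float acc = init;                // init loaded from the init buffer
      for (int j = 0; j < kJ; ++j)     // inner scf.parallel with scf.reduce
        acc += arg[(i * kJ + j) * kK + k];   // the lmhlo.add region
      result[i * kK + k] = acc;        // memref.store of the reduction result
    }
  return result;
}
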
// -----
@ -65,23 +65,23 @@ func @reduce_no_outer_loop(%arg: memref<100xf32>,
// CHECK-DAG: [[C0:%.*]] = constant 0 : index
// CHECK-DAG: [[C1:%.*]] = constant 1 : index
// CHECK-DAG: [[C100:%.*]] = constant 100 : index
// CHECK: [[INIT:%.*]] = load [[INIT_BUF]]
// CHECK: [[INIT:%.*]] = memref.load [[INIT_BUF]]
// CHECK: [[REDUCTION_RESULT:%.*]] = scf.parallel ([[I:%.*]]) = ([[C0]])
// CHECK-SAME: to ([[C100]]) step ([[C1]]) init ([[INIT]]) -> f32 {
// CHECK: [[ELEM_TO_REDUCE:%.*]] = load [[ARG_BUF]]{{\[}}[[I]]{{\]}}
// CHECK: [[ELEM_TO_REDUCE:%.*]] = memref.load [[ARG_BUF]]{{\[}}[[I]]{{\]}}
// CHECK: scf.reduce([[ELEM_TO_REDUCE]]) : f32 {
// CHECK: ^bb0([[ELEM:%.*]]: f32, [[ACC:%.*]]: f32):
// CHECK: [[ELEM_BUF:%.*]] = alloc() : memref<f32>
// CHECK: [[ACC_BUF:%.*]] = alloc() : memref<f32>
// CHECK: [[ACC_OUT_BUF:%.*]] = alloc() : memref<f32>
// CHECK: store [[ELEM]], [[ELEM_BUF]][] : memref<f32>
// CHECK: store [[ACC]], [[ACC_BUF]][] : memref<f32>
// CHECK: [[ELEM_BUF:%.*]] = memref.alloc() : memref<f32>
// CHECK: [[ACC_BUF:%.*]] = memref.alloc() : memref<f32>
// CHECK: [[ACC_OUT_BUF:%.*]] = memref.alloc() : memref<f32>
// CHECK: memref.store [[ELEM]], [[ELEM_BUF]][] : memref<f32>
// CHECK: memref.store [[ACC]], [[ACC_BUF]][] : memref<f32>
// CHECK: "lmhlo.add"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_OUT_BUF]])
// CHECK: [[ACC_RESULT:%.*]] = load [[ACC_OUT_BUF]][] : memref<f32>
// CHECK: [[ACC_RESULT:%.*]] = memref.load [[ACC_OUT_BUF]][] : memref<f32>
// CHECK: scf.reduce.return [[ACC_RESULT]]
// CHECK: }
// CHECK: scf.yield
// CHECK: store [[REDUCTION_RESULT]], [[RESULT_BUF]]{{\[}}[[C0]]]
// CHECK: memref.store [[REDUCTION_RESULT]], [[RESULT_BUF]]{{\[}}[[C0]]]
// -----
@ -104,30 +104,30 @@ func @dynamic_reduce(%arg: memref<?x?x?xf32>,
// CHECK-DAG: [[C0:%.*]] = constant 0 : index
// CHECK-DAG: [[C1:%.*]] = constant 1 : index
// CHECK-DAG: [[C2:%.*]] = constant 2 : index
// CHECK: [[DIM0:%.*]] = dim [[ARG_BUF]], [[C0]] : memref<?x?x?xf32>
// CHECK: [[DIM1:%.*]] = dim [[ARG_BUF]], [[C1]] : memref<?x?x?xf32>
// CHECK: [[DIM2:%.*]] = dim [[ARG_BUF]], [[C2]] : memref<?x?x?xf32>
// CHECK: [[INIT:%.*]] = load [[INIT_BUF]]
// CHECK: [[DIM0:%.*]] = memref.dim [[ARG_BUF]], [[C0]] : memref<?x?x?xf32>
// CHECK: [[DIM1:%.*]] = memref.dim [[ARG_BUF]], [[C1]] : memref<?x?x?xf32>
// CHECK: [[DIM2:%.*]] = memref.dim [[ARG_BUF]], [[C2]] : memref<?x?x?xf32>
// CHECK: [[INIT:%.*]] = memref.load [[INIT_BUF]]
// CHECK: scf.parallel ([[I:%.*]], [[K:%.*]]) = ([[C0]], [[C0]])
// CHECK-SAME: to ([[DIM0]], [[DIM2]]) step ([[C1]], [[C1]]) {
// CHECK: [[REDUCTION_RESULT:%.*]] = scf.parallel ([[J:%.*]]) =
// CHECK-SAME: ([[C0]]) to ([[DIM1]]) step ([[C1]]) init ([[INIT]]) -> f32 {
// CHECK: [[ELEM_TO_REDUCE:%.*]] = load [[ARG_BUF]]
// CHECK: [[ELEM_TO_REDUCE:%.*]] = memref.load [[ARG_BUF]]
// CHECK-SAME: {{\[}}[[I]], [[J]], [[K]]] : memref<?x?x?xf32>
// CHECK: scf.reduce([[ELEM_TO_REDUCE]]) : f32 {
// CHECK: ^bb0([[ELEM:%.*]]: f32, [[ACC:%.*]]: f32):
// CHECK: [[ELEM_BUF:%.*]] = alloc() : memref<f32>
// CHECK: [[ACC_BUF:%.*]] = alloc() : memref<f32>
// CHECK: [[ACC_OUT_BUF:%.*]] = alloc() : memref<f32>
// CHECK: store [[ELEM]], [[ELEM_BUF]][] : memref<f32>
// CHECK: store [[ACC]], [[ACC_BUF]][] : memref<f32>
// CHECK: [[ELEM_BUF:%.*]] = memref.alloc() : memref<f32>
// CHECK: [[ACC_BUF:%.*]] = memref.alloc() : memref<f32>
// CHECK: [[ACC_OUT_BUF:%.*]] = memref.alloc() : memref<f32>
// CHECK: memref.store [[ELEM]], [[ELEM_BUF]][] : memref<f32>
// CHECK: memref.store [[ACC]], [[ACC_BUF]][] : memref<f32>
// CHECK: "lmhlo.add"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_OUT_BUF]])
// CHECK: [[ACC_RESULT:%.*]] = load [[ACC_OUT_BUF]][] : memref<f32>
// CHECK: [[ACC_RESULT:%.*]] = memref.load [[ACC_OUT_BUF]][] : memref<f32>
// CHECK: scf.reduce.return [[ACC_RESULT]] : f32
// CHECK: }
// CHECK: scf.yield
// CHECK: }
// CHECK: store [[REDUCTION_RESULT]], [[RESULT_BUF]]{{\[}}[[I]], [[K]]]
// CHECK: memref.store [[REDUCTION_RESULT]], [[RESULT_BUF]]{{\[}}[[I]], [[K]]]
// CHECK: scf.yield
// -----
@ -157,7 +157,7 @@ func @reduce_window(%arg: memref<112x112xf32>,
// CHECK-DAG: [[C3:%.*]] = constant 3 : index
// CHECK-DAG: [[C56:%.*]] = constant 56 : index
// CHECK-DAG: [[C112:%.*]] = constant 112 : index
// CHECK: [[INIT:%.*]] = load [[INIT_BUF]][] : memref<f32>
// CHECK: [[INIT:%.*]] = memref.load [[INIT_BUF]][] : memref<f32>
// CHECK: scf.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]])
// CHECK-SAME: to ([[C56]], [[C56]]) step ([[C1]], [[C1]]) {
// CHECK: [[REDUCTION_RESULT:%.*]] = scf.parallel
@ -176,7 +176,7 @@ func @reduce_window(%arg: memref<112x112xf32>,
// CHECK: [[ELEM_TO_REDUCE:%.*]] = scf.if [[IN_BOUNDS_1]] -> (f32) {
// CHECK: [[OPERAND_ELEM:%.*]] =
// CHECK-SAME: load [[OPERAND_BUF]]{{\[}}[[INDEX_I]], [[INDEX_J]]]
// CHECK-SAME: memref.load [[OPERAND_BUF]]{{\[}}[[INDEX_I]], [[INDEX_J]]]
// CHECK: scf.yield [[OPERAND_ELEM]] : f32
// CHECK: } else {
// CHECK: scf.yield [[INIT]] : f32
@ -184,18 +184,18 @@ func @reduce_window(%arg: memref<112x112xf32>,
// CHECK: scf.reduce([[ELEM_TO_REDUCE]]) : f32 {
// CHECK: ^bb0([[ELEM:%.*]]: f32, [[ACC:%.*]]: f32):
// CHECK: [[ELEM_BUF:%.*]] = alloc() : memref<f32>
// CHECK: [[ACC_BUF:%.*]] = alloc() : memref<f32>
// CHECK: [[ACC_OUT_BUF:%.*]] = alloc() : memref<f32>
// CHECK: store [[ELEM]], [[ELEM_BUF]][] : memref<f32>
// CHECK: store [[ACC]], [[ACC_BUF]][] : memref<f32>
// CHECK: [[ELEM_BUF:%.*]] = memref.alloc() : memref<f32>
// CHECK: [[ACC_BUF:%.*]] = memref.alloc() : memref<f32>
// CHECK: [[ACC_OUT_BUF:%.*]] = memref.alloc() : memref<f32>
// CHECK: memref.store [[ELEM]], [[ELEM_BUF]][] : memref<f32>
// CHECK: memref.store [[ACC]], [[ACC_BUF]][] : memref<f32>
// CHECK: "lmhlo.maximum"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_OUT_BUF]])
// CHECK: [[ACC_RESULT:%.*]] = load [[ACC_OUT_BUF]][] : memref<f32>
// CHECK: [[ACC_RESULT:%.*]] = memref.load [[ACC_OUT_BUF]][] : memref<f32>
// CHECK: scf.reduce.return [[ACC_RESULT]] : f32
// CHECK: }
// CHECK: scf.yield
// CHECK: }
// CHECK: store [[REDUCTION_RESULT]], [[RESULT_BUF]]{{\[}}[[I]], [[J]]]
// CHECK: memref.store [[REDUCTION_RESULT]], [[RESULT_BUF]]{{\[}}[[I]], [[J]]]
// CHECK: scf.yield
// CHECK: }
// CHECK: return


@ -30,7 +30,7 @@ func @batch_norm_training_memrefs(%arg0: memref<8x8x8x8xf32>, %arg1: memref<8xf3
// CHECK-LABEL: func @conv_forward
func @conv_forward(%input : memref<1x1x8x8xf16>, %filter: memref<1x1x2x2xf16>, %output: memref<1x1x7x7xf16>) {
%scratch = alloc() : memref<32xi8>
%scratch = memref.alloc() : memref<32xi8>
// This defined a 2D convolution over a 8x8 single channel input using a 2x2
// filter and with an output of 7x7xf16. The 1x1x8x8 is (N, C, H, W)
"lmhlo_gpu.conv_forward"(%input, %filter, %output, %scratch)
@ -61,7 +61,7 @@ func @conv_forward(%input : memref<1x1x8x8xf16>, %filter: memref<1x1x2x2xf16>, %
// CHECK-LABEL: func @conv_backfilter
func @conv_backfilter(%input : memref<3x56x56x16xf64>, %filter: memref<3x3x3x64xf64>, %output: memref<54x54x16x64xf64>) {
%scratch = alloc() : memref<23328xui8>
%scratch = memref.alloc() : memref<23328xui8>
"lmhlo_gpu.conv_backwardfilter"(%input, %filter, %output, %scratch)
{ backend_config = {algorithm = 1 : i64,
operand_0_layout = [3,2,1,0],
@ -91,7 +91,7 @@ func @conv_backfilter(%input : memref<3x56x56x16xf64>, %filter: memref<3x3x3x64x
// CHECK-LABEL: func @conv_backinput
func @conv_backinput(%input : memref<4x5x16x16xf64>, %filter : memref<5x3x7x7xf64>, %output : memref<4x3x16x16xf64>) {
%scratch = alloc() : memref<32xui8>
%scratch = memref.alloc() : memref<32xui8>
"lmhlo_gpu.conv_backwardinput"(%input, %filter, %output, %scratch)
{ backend_config = {algorithm = 1 : i64,
operand_0_layout = [3,2,1,0],
@ -122,7 +122,7 @@ func @conv_backinput(%input : memref<4x5x16x16xf64>, %filter : memref<5x3x7x7xf6
// CHECK-LABEL: func @conv_fused
func @conv_fused(%input : memref<1x17x9x9xf16>, %filter : memref<3x3x17x32xf16>, %bias : memref<32xf16>, %output : memref<1x32x9x9xf16>) {
%scratch = alloc() : memref<32xui8>
%scratch = memref.alloc() : memref<32xui8>
"lmhlo_gpu.conv_forward_fused"(%input, %filter, %bias, %output, %scratch)
{activation_mode = "Relu",
backend_config = {algorithm = 1 : i64,
@ -153,7 +153,7 @@ func @conv_fused(%input : memref<1x17x9x9xf16>, %filter : memref<3x3x17x32xf16>,
// CHECK-LABEL: func @conv_fused_side_input
func @conv_fused_side_input(%input : memref<1x17x9x9xf16>, %filter : memref<3x3x17x32xf16>, %bias : memref<32xf16>, %side_input: memref<32xf16>, %output : memref<1x32x9x9xf16>) {
%scratch = alloc() : memref<0xui8>
%scratch = memref.alloc() : memref<0xui8>
"lmhlo_gpu.conv_forward_fused_with_side_input"(%input, %filter, %bias, %side_input, %output, %scratch)
{activation_mode = "Relu",
backend_config = {algorithm = 1 : i64,
@ -218,8 +218,8 @@ func @gemm_bias(%lhs: memref<5x4xf32>, %rhs: memref<4x5xf32>,
// CHECK-LABEL: func @cholesky
func @cholesky(%arg : memref<10x10xf32>, %out: memref<10x10xf32>) {
%scratch = alloc() : memref<32xi8>
%info = alloc() : memref<32xi32>
%scratch = memref.alloc() : memref<32xi8>
%info = memref.alloc() : memref<32xi32>
"lmhlo_gpu.cholesky"(%arg, %out, %scratch, %info) { is_lower = true }
: (memref<10x10xf32>, memref<10x10xf32>, memref<32xi8>, memref<32xi32>) -> ()
return


@ -457,12 +457,12 @@ func @reduce_memref(%input: memref<10xf32>, %init: memref<f32>, %out: memref<1xf
// CHECK-LABEL: func @fusion_memref
func @fusion_memref(%input1: memref<10xf32>, %input2: memref<10xf32>, %input3: memref<10xf32>, %out: memref<10xf32>) -> () {
"lmhlo.fusion"() ( {
%0 = tensor_load %input1 : memref<10xf32>
%1 = tensor_load %input2 : memref<10xf32>
%0 = memref.tensor_load %input1 : memref<10xf32>
%1 = memref.tensor_load %input2 : memref<10xf32>
%2 = "mhlo.add"(%0, %1) {name = "add"} : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>
%3 = tensor_load %input3 : memref<10xf32>
%3 = memref.tensor_load %input3 : memref<10xf32>
%4 = "mhlo.multiply"(%2, %3) {name = "multiply"} : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>
tensor_store %4, %out : memref<10xf32>
memref.tensor_store %4, %out : memref<10xf32>
"lmhlo.terminator"() : () -> ()
} ) : () -> ()
return


@ -953,7 +953,7 @@ func @tuple_token(%arg0: tensor<f32>, %arg1: !mhlo.token) -> tuple<tensor<f32>,
// -----
func @tuple_arg_size_mismatch(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tuple<tensor<f32>, tensor<f32>, tensor<f32>> {
// expected-error@+1 {{has return type tuple<tensor<f32>, tensor<f32>, tensor<f32>>, but expected tuple<tensor<f32>, tensor<f32>>}}
// expected-error@+1 {{number of operands to tuple expected to match number of types}}
%0 = "mhlo.tuple"(%arg0, %arg1) : (tensor<f32>, tensor<f32>) -> tuple<tensor<f32>, tensor<f32>, tensor<f32>>
return %0 : tuple<tensor<f32>, tensor<f32>, tensor<f32>>
}
@ -961,7 +961,7 @@ func @tuple_arg_size_mismatch(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tuple<t
// -----
func @tuple_type_mismatch(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tuple<tensor<f32>, tensor<i32>> {
// expected-error@+1 {{has return type tuple<tensor<f32>, tensor<i32>>, but expected tuple<tensor<f32>, tensor<f32>>}}
// expected-error@+1 {{op has return type mismatch at 1th value}}
%0 = "mhlo.tuple"(%arg0, %arg1) : (tensor<f32>, tensor<f32>) -> tuple<tensor<f32>, tensor<i32>>
return %0 : tuple<tensor<f32>, tensor<i32>>
}


@ -108,15 +108,15 @@ func @batchNormInference_dynamic_shape(
// CHECK-DAG: %[[C2:.*]] = constant 2 : index
// CHECK-DAG: %[[C3:.*]] = constant 3 : index
// CHECK-DAG: %[[EPS:.+]] = mhlo.constant dense<1.000000e-03> : tensor<f32>
// CHECK-DAG: %[[DIM:.+]] = dim %[[VARIANCE]], %[[C0]] : tensor<?xf32>
// CHECK-DAG: %[[DIM:.+]] = memref.dim %[[VARIANCE]], %[[C0]] : tensor<?xf32>
// CHECK-DAG: %[[TO_DIM_TENSOR:.+]] = tensor.from_elements %[[DIM]] : tensor<1xindex>
// CHECK-DAG: %[[EPS_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[EPS]], %[[TO_DIM_TENSOR]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK-DAG: %[[VARIANCE_EPS:.+]] = mhlo.add %[[VARIANCE]], %[[EPS_BCAST]] : tensor<?xf32>
// CHECK-DAG: %[[STDDEV:.+]] = "mhlo.sqrt"(%[[VARIANCE_EPS]]) : (tensor<?xf32>) -> tensor<?xf32>
// CHECK-DAG: %[[INPUT_DIM_0:.+]] = dim %[[X]], %[[C0]] : tensor<?x?x?x?xf32>
// CHECK-DAG: %[[INPUT_DIM_1:.+]] = dim %[[X]], %[[C1]] : tensor<?x?x?x?xf32>
// CHECK-DAG: %[[INPUT_DIM_2:.+]] = dim %[[X]], %[[C2]] : tensor<?x?x?x?xf32>
// CHECK-DAG: %[[INPUT_DIM_3:.+]] = dim %[[X]], %[[C3]] : tensor<?x?x?x?xf32>
// CHECK-DAG: %[[INPUT_DIM_0:.+]] = memref.dim %[[X]], %[[C0]] : tensor<?x?x?x?xf32>
// CHECK-DAG: %[[INPUT_DIM_1:.+]] = memref.dim %[[X]], %[[C1]] : tensor<?x?x?x?xf32>
// CHECK-DAG: %[[INPUT_DIM_2:.+]] = memref.dim %[[X]], %[[C2]] : tensor<?x?x?x?xf32>
// CHECK-DAG: %[[INPUT_DIM_3:.+]] = memref.dim %[[X]], %[[C3]] : tensor<?x?x?x?xf32>
// CHECK-DAG: %[[TO_INPUT_DIM_TENSOR:.+]] = tensor.from_elements %[[INPUT_DIM_0]], %[[INPUT_DIM_1]], %[[INPUT_DIM_2]], %[[INPUT_DIM_3]] : tensor<4xindex>
// CHECK-DAG: %[[STDDEV_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[STDDEV]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
// CHECK-DAG: %[[SCALE_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[SCALE]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>


@ -1452,22 +1452,6 @@ Optional<BufferOffset<tflite::SubGraph>> Translator::BuildSubGraph(
results.push_back(tensor_index_map.lookup(result));
}
Operation* real_inst = &inst;
// CustomTfOp is just a wrapper around a TF op, we export the custom Op
// not the wrapper, so we fetch the op from the region.
if (auto custom_op = dyn_cast<mlir::TFL::CustomTfOp>(inst)) {
// If we have custom op with a region, then use the first op in the
// region, if it exists, otherwise just use params for custom op.
if (!custom_op.body().empty()) {
real_inst = &custom_op.body().front().front();
// Use the inputs of the wrapper to reset the inputs.
for (auto idx_op : llvm::enumerate(custom_op->getOperands())) {
real_inst->setOperand(idx_op.index(), idx_op.value());
}
} else {
module_.emitError(
"Invalid CustomTfOp: Custom TF Op have empty region.");
}
}
std::vector<int32_t> operands;
operands.reserve(real_inst->getNumOperands());
for (auto operand : real_inst->getOperands()) {
@ -1481,6 +1465,19 @@ Optional<BufferOffset<tflite::SubGraph>> Translator::BuildSubGraph(
operands.push_back(tensor_index_map.lookup(operand));
}
// CustomTfOp is just a wrapper around a TF op, we export the custom Op
// not the wrapper, so we fetch the op from the region.
if (auto custom_op = dyn_cast<mlir::TFL::CustomTfOp>(inst)) {
// If we have custom op with a region, then use the first op in the
// region, if it exists, otherwise just use params for custom op.
if (!custom_op.body().empty()) {
real_inst = &custom_op.body().front().front();
} else {
module_.emitError(
"Invalid CustomTfOp: Custom TF Op have empty region.");
}
}
if (auto tfl_operator =
BuildOperator(real_inst, operands, results, intermediates))
operators.push_back(*tfl_operator);


@ -1067,7 +1067,8 @@ void DepthwiseConv2DOp::getCanonicalizationPatterns(
//===----------------------------------------------------------------------===//
static void BuildGatherOp(OpBuilder *builder, OperationState &result,
Value params, Value indices, IntegerAttr axis) {
Value params, Value indices, IntegerAttr axis,
IntegerAttr batch_dims) {
auto params_type = params.getType().cast<TensorType>();
auto indices_type = indices.getType().cast<TensorType>();
@ -1075,7 +1076,7 @@ static void BuildGatherOp(OpBuilder *builder, OperationState &result,
if (!params_type.hasRank() || !indices_type.hasRank())
return TFL::GatherOp::build(
*builder, result, UnrankedTensorType::get(params_type.getElementType()),
params, indices, axis);
params, indices, axis, batch_dims);
int64_t params_rank = params_type.getRank();
int64_t indices_rank = indices_type.getRank();
@ -1096,7 +1097,29 @@ static void BuildGatherOp(OpBuilder *builder, OperationState &result,
emitError(result.location, "params must be at least rank axis + 1");
}
if (indices_rank == 0) {
int64_t batch_dims_i = batch_dims.getInt();
if (batch_dims_i < 0) {
batch_dims_i += indices_rank;
}
if (batch_dims_i > axis_i) {
emitError(result.location,
"axis should be bigger than or equal to batch_dims");
}
if (batch_dims_i >= params_rank || batch_dims_i > indices_rank) {
emitError(result.location,
"batch_dims must be smaller than params' rank and smaller than "
"or equal to indices'rank");
}
for (int i = 0; i < batch_dims_i; ++i) {
if (indices_type.getShape()[i] != params_type.getShape()[i]) {
emitError(result.location,
"batch dimensions of params must be equal to batch dimensions "
"of indices");
}
}
if ((indices_rank == 0) || (indices_rank == batch_dims_i)) {
// Scalar indices (output is rank(params) - 1).
// Erase shape[axis]
shape.erase(shape.begin() + axis_i);
@ -1107,21 +1130,21 @@ static void BuildGatherOp(OpBuilder *builder, OperationState &result,
std::end(indices_type.getShape()), std::begin(shape) + axis_i);
} else {
// Higher rank indices (output is rank(params) + rank(indices) - 1).
shape.resize(params_rank + indices_rank - 1);
shape.resize(params_rank + indices_rank - 1 - batch_dims_i);
// Copy params.shape[axis + 1: ] into shape[axis + indices_rank:]
std::copy(std::begin(params_type.getShape()) + axis_i + 1,
std::end(params_type.getShape()),
std::begin(shape) + axis_i + indices_rank);
std::begin(shape) + axis_i + indices_rank - batch_dims_i);
// Copy indices.shape into params.shape[axis]
std::copy(std::begin(indices_type.getShape()),
std::copy(std::begin(indices_type.getShape()) + batch_dims_i,
std::end(indices_type.getShape()), std::begin(shape) + axis_i);
}
TFL::GatherOp::build(
*builder, result,
RankedTensorType::get(shape, params_type.getElementType()), params,
indices, axis);
indices, axis, batch_dims);
}
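
For the higher-rank-indices branch above, the computed result shape amounts to params.shape[:axis] ++ indices.shape[batch_dims:] ++ params.shape[axis+1:]. A small C++ sketch of that rule, assuming static shapes (function name illustrative); for example params 4x7x8x2, indices 4x1, axis = 2, batch_dims = 1 yields 4x7x1x2.

#include <cstdint>
#include <vector>

std::vector<int64_t> GatherResultShape(const std::vector<int64_t> &params,
                                       const std::vector<int64_t> &indices,
                                       int64_t axis, int64_t batch_dims) {
  // Leading params dims up to axis are kept (these include the batch dims,
  // which indices shares with params).
  std::vector<int64_t> shape(params.begin(), params.begin() + axis);
  // Non-batch indices dims replace params[axis].
  shape.insert(shape.end(), indices.begin() + batch_dims, indices.end());
  // Trailing params dims follow.
  shape.insert(shape.end(), params.begin() + axis + 1, params.end());
  return shape;  // rank(params) + rank(indices) - 1 - batch_dims entries
}
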
//===----------------------------------------------------------------------===//


@ -1057,13 +1057,14 @@ def TFL_GatherOp : TFL_Op<"gather", [
let arguments = (ins
TFL_TensorOf<[F32, I1, I8, I32, I64, TFL_Str, UI8, QI8, QUI8, QI16]>:$params,
TFL_TensorOf<[I32, I64]>:$indices,
I32Attr:$axis
I32Attr:$axis,
DefaultValuedAttr<I32Attr, "0">:$batch_dims
);
let builders =
[
OpBuilder<(ins "Value":$params, "Value":$indices, "IntegerAttr":$axis),
[{ BuildGatherOp(&$_builder, $_state, params, indices, axis); }]>
OpBuilder<(ins "Value":$params, "Value":$indices, "IntegerAttr":$axis, "IntegerAttr":$batch_dims),
[{ BuildGatherOp(&$_builder, $_state, params, indices, axis, batch_dims); }]>
];
let results = (outs


@ -115,6 +115,7 @@ cc_library(
"//tensorflow/lite/kernels/internal:tensor_utils",
"//tensorflow/lite/tools/optimize:quantization_utils",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:QuantOps",


@ -55,6 +55,7 @@ cc_library(
deps = [
"//tensorflow/compiler/mlir/lite:tensorflow_lite",
"//tensorflow/compiler/mlir/lite/quantization:quantization_lib",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:QuantOps",


@ -512,7 +512,7 @@ bool QuantizationDriver::SetOperandParams(Operation *op, int index,
void QuantizationDriver::QuantizeOpResult(Operation *op, int index,
QuantParams params) {
builder_.setInsertionPoint(op->getBlock(), ++Block::iterator(op));
builder_.setInsertionPointAfter(op);
Value original_result = op->getResult(index);
QuantizeValue(original_result, params, op->getLoc());
}
@ -741,10 +741,9 @@ void QuantizationDriver::SetupAllStates() {
}
fn_.walk([&](Operation *op) {
if (op->hasTrait<OpTrait::IsTerminator>() ||
op->hasTrait<OpTrait::quant::NoQuantizableResult>() ||
llvm::isa<quant::DequantizeCastOp, quant::QuantizeCastOp>(op))
if (IsOpNotQuantizable(op)) {
return;
}
work_list_.push_back(op);
for (int i = 0, e = op->getNumOperands(); i != e; ++i) {


@ -51,6 +51,25 @@ namespace quant {
constexpr double kNearZeroTolerance = 1.0e-6;
constexpr double kSmallestHalfRange = kNearZeroTolerance / 2;
const char kQuantTraitAttr[] = "_tfl_quant_trait";
const absl::string_view QuantTraitValues[] = {"fully_quantizable",
"not_quantizable"};
bool IsOpNotQuantizable(Operation* op) {
// If it is a terminator, not quantizable, or an op from the mlir quant
// ops dialect, we shouldn't rewrite.
bool attr_enforced_quantizable =
op->hasAttrOfType<StringAttr>(kQuantTraitAttr) &&
op->getAttrOfType<StringAttr>(kQuantTraitAttr).getValue().str() ==
QuantTraitValues[QuantizationTrait::FullyQuantizable];
bool prop_enforced_no_quantizable =
op->hasTrait<OpTrait::quant::NoQuantizableResult>();
return op->hasTrait<OpTrait::IsTerminator>() ||
llvm::isa<quant::QuantizeCastOp, quant::DequantizeCastOp>(op) ||
(!attr_enforced_quantizable && prop_enforced_no_quantizable);
}
// This method expands the range to be larger than or equal to 1.0e-6, if it is
// very small (< 1.0e-6). This is to prevent very large quantized value by this
// range.


@ -22,9 +22,11 @@ limitations under the License.
#include <string>
#include <unordered_map>
#include "absl/strings/string_view.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/Quant/FakeQuantSupport.h" // from @llvm-project
#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project
@ -32,6 +34,7 @@ limitations under the License.
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/BlockAndValueMapping.h" // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
@ -49,6 +52,10 @@ namespace quant {
// losing accuracy.
constexpr char kVolatileOpAttrName[] = "volatile";
enum QuantizationTrait { FullyQuantizable, NotQuantizable };
extern const char kQuantTraitAttr[];
extern const absl::string_view QuantTraitValues[];
using QuantParams = quant::QuantizedType;
using SignedInteger = std::pair<unsigned, unsigned>; // bitwidth and sign
using QuantParamsForResults = llvm::SmallVector<QuantParams, 4>;
@ -91,6 +98,8 @@ QuantizedType DownCastScale(QuantizedType type,
QuantizedType DownCastScale(QuantizedType type, double min, double max,
Location loc);
bool IsOpNotQuantizable(Operation* op);
template <typename Q, typename DQ>
struct ConvertStatsToQDQs : public OpRewritePattern<quant::StatisticsOp> {
ConvertStatsToQDQs(int num_bits, bool narrow_range, bool is_signed,
@ -227,10 +236,7 @@ struct QuantizationPattern : public RewritePattern {
// If it is terminator or not quantizable or any ops form the mlir quant
// ops dialect, we shouldn't rewrite.
if (quantized_op->hasTrait<OpTrait::IsTerminator>() ||
quantized_op->hasTrait<OpTrait::quant::NoQuantizableResult>() ||
llvm::isa<quant::QuantizeCastOp, quant::DequantizeCastOp>(
quantized_op)) {
if (IsOpNotQuantizable(quantized_op)) {
return failure();
}
@ -306,8 +312,9 @@ struct QuantizationPattern : public RewritePattern {
if (quantized_op->getNumRegions() != 0) {
for (auto indexed_regions :
llvm::enumerate(quantized_op->getRegions())) {
new_op->getRegion(indexed_regions.index())
.takeBody(indexed_regions.value());
Region& target_region = new_op->getRegion(indexed_regions.index());
BlockAndValueMapping mapping;
indexed_regions.value().cloneInto(&target_region, mapping);
}
}
for (auto output : outputs_replaced) {


@ -55,3 +55,40 @@ module attributes {tf_saved_model.semantics} {
// CHECK: "tf.ReadVariableOp"(%arg1) : (tensor<!tf.resource<tensor<1x10xf32>>>) -> tensor<1x10xf32>
}
}
// -----
// Test for func with no bound_input.
module attributes {tf_saved_model.semantics} {
"tf_saved_model.global_tensor"() {is_mutable, sym_name = "Variable", type = tensor<1x10xf32>, value = dense<0.000000e+00> : tensor<1x10xf32>} : () -> ()
func @serving_default(%arg0: tensor<1x10xf32> {tf_saved_model.index_path = ["x"]}, %arg1: tensor<!tf.resource<tensor<1x10xf32>>>{tf_saved_model.bound_input = @Variable}) ->
(tensor<1x10xf32> {tf_saved_model.index_path = ["output_0"]}) attributes {tf.entry_function = {control_outputs = "", inputs = "serving_default_x:0", outputs = "StatefulPartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} {
%0 = "tf.ReadVariableOp"(%arg1) : (tensor<!tf.resource<tensor<1x10xf32>>>) -> tensor<1x10xf32>
%1 = tfl.add %0, %arg0 {fused_activation_function = "NONE"} : tensor<1x10xf32>
"tf.AssignVariableOp"(%arg1, %1) : (tensor<!tf.resource<tensor<1x10xf32>>>, tensor<1x10xf32>) -> ()
%2 = "tf.ReadVariableOp"(%arg1) {device = ""} : (tensor<!tf.resource<tensor<1x10xf32>>>) -> tensor<1x10xf32>
return %2 : tensor<1x10xf32>
}
func private @"FuncWithNoBoundInput"(%arg0: tensor<1x10xf32>, %arg1: tensor<!tf.resource<tensor<1x10xf32>>>) -> tensor<1x10xf32> {
"tf.AssignVariableOp"(%arg1, %arg0) {device = ""} : (tensor<!tf.resource<tensor<1x10xf32>>>, tensor<1x10xf32>) -> ()
%0 = "tf.ReadVariableOp"(%arg1) : (tensor<!tf.resource<tensor<1x10xf32>>>) -> tensor<1x10xf32>
return %0 : tensor<1x10xf32>
}
// CHECK: func @SessionInitializerFunction() attributes {tf_saved_model.exported_names = ["SessionInitializerFunction"]} {
// CHECK: %[[RESOURCE:.*]] = "tfl.pseudo_const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
// CHECK: %[[VAL:.*]] = "tfl.pseudo_const"() {value = dense<0.000000e+00> : tensor<1x10xf32>} : () -> tensor<1x10xf32>
// CHECK: "tfl.assign_variable"(%[[RESOURCE]], %[[VAL]]) : (tensor<1xi32>, tensor<1x10xf32>) -> ()
// CHECK: return
// CHECK: }
// CHECK: "tf_saved_model.session_initializer"() {initializers = [@SessionInitializerFunction]} : () -> ()
// CHECK: func @serving_default(%arg0: tensor<1x10xf32> {tf_saved_model.index_path = ["x"]}
// CHECK: "tf.ReadVariableOp"(%arg1) : (tensor<!tf.resource<tensor<1x10xf32>>>) -> tensor<1x10xf32>
//
// CHECK: func private @FuncWithNoBoundInput(%arg0: tensor<1x10xf32>, %arg1: tensor<!tf.resource<tensor<1x10xf32>>>) -> tensor<1x10xf32> {
// CHECK: "tf.AssignVariableOp"(%arg1, %arg0) {device = ""} : (tensor<!tf.resource<tensor<1x10xf32>>>, tensor<1x10xf32>) -> ()
// CHECK: %0 = "tf.ReadVariableOp"(%arg1) : (tensor<!tf.resource<tensor<1x10xf32>>>) -> tensor<1x10xf32>
// CHECK: return %0 : tensor<1x10xf32>
// CHECK: }
}


@ -389,7 +389,7 @@ func @gatherScalarIndices(%arg0 : tensor<3x2xf32>, %arg1 : tensor<i32>) -> tenso
return %0 : tensor<2xf32>
// CHECK-LABEL:gatherScalarIndices
// CHECK: "tfl.gather"(%arg0, %arg1) {axis = 0 : i32} : (tensor<3x2xf32>, tensor<i32>) -> tensor<2xf32>
// CHECK: "tfl.gather"(%arg0, %arg1) {axis = 0 : i32, batch_dims = 0 : i32} : (tensor<3x2xf32>, tensor<i32>) -> tensor<2xf32>
}
func @gatherVectorIndices(%arg0 : tensor<2xf32>, %arg1 : tensor<3xi32>) -> tensor<3xf32> {
@ -397,7 +397,7 @@ func @gatherVectorIndices(%arg0 : tensor<2xf32>, %arg1 : tensor<3xi32>) -> tenso
return %0 : tensor<3xf32>
// CHECK-LABEL:gatherVectorIndices
// CHECK: "tfl.gather"(%arg0, %arg1) {axis = 0 : i32} : (tensor<2xf32>, tensor<3xi32>) -> tensor<3xf32>
// CHECK: "tfl.gather"(%arg0, %arg1) {axis = 0 : i32, batch_dims = 0 : i32} : (tensor<2xf32>, tensor<3xi32>) -> tensor<3xf32>
}
func @gatherHigherRankIndices(%arg0 : tensor<2x3x6xf32>, %arg1 : tensor<4x5xi32>) -> tensor<4x5x3x6xf32> {
@ -405,7 +405,7 @@ func @gatherHigherRankIndices(%arg0 : tensor<2x3x6xf32>, %arg1 : tensor<4x5xi32>
return %0 : tensor<4x5x3x6xf32>
// CHECK-LABEL:gatherHigherRankIndices
// CHECK: "tfl.gather"(%arg0, %arg1) {axis = 0 : i32} : (tensor<2x3x6xf32>, tensor<4x5xi32>) -> tensor<4x5x3x6xf32>
// CHECK: "tfl.gather"(%arg0, %arg1) {axis = 0 : i32, batch_dims = 0 : i32} : (tensor<2x3x6xf32>, tensor<4x5xi32>) -> tensor<4x5x3x6xf32>
}
func @gatherNdVectorIndices(%arg0 : tensor<3x2x2xf32>, %arg1 : tensor<2xi32>) -> tensor<2xf32> {
@ -460,7 +460,7 @@ func @gatherV2VectorIndices(%arg0 : tensor<1x2x20xf32>, %arg1 : tensor<3x5xi32>)
return %1 : tensor<1x3x5x20xf32>
// CHECK-LABEL:gatherV2VectorIndices
// CHECK: "tfl.gather"(%arg0, %arg1) {axis = 1 : i32} : (tensor<1x2x20xf32>, tensor<3x5xi32>) -> tensor<1x3x5x20xf32>
// CHECK: "tfl.gather"(%arg0, %arg1) {axis = 1 : i32, batch_dims = 0 : i32} : (tensor<1x2x20xf32>, tensor<3x5xi32>) -> tensor<1x3x5x20xf32>
}
func @gatherV2VectorIndices_I64Axis(%arg0 : tensor<1x2x20xf32>, %arg1 : tensor<3x5xi32>) -> tensor<1x3x5x20xf32> {
@ -469,7 +469,7 @@ func @gatherV2VectorIndices_I64Axis(%arg0 : tensor<1x2x20xf32>, %arg1 : tensor<3
return %1 : tensor<1x3x5x20xf32>
// CHECK-LABEL:gatherV2VectorIndices_I64Axis
// CHECK: "tfl.gather"(%arg0, %arg1) {axis = 1 : i32} : (tensor<1x2x20xf32>, tensor<3x5xi32>) -> tensor<1x3x5x20xf32>
// CHECK: "tfl.gather"(%arg0, %arg1) {axis = 1 : i32, batch_dims = 0 : i32} : (tensor<1x2x20xf32>, tensor<3x5xi32>) -> tensor<1x3x5x20xf32>
}
func @gatherV2VectorIndicesNegAxis(%arg0 : tensor<1x2x20xf32>, %arg1 : tensor<3x5xi32>) -> tensor<1x2x3x5xf32> {
@ -478,18 +478,20 @@ func @gatherV2VectorIndicesNegAxis(%arg0 : tensor<1x2x20xf32>, %arg1 : tensor<3x
return %1 : tensor<1x2x3x5xf32>
// CHECK-LABEL:gatherV2VectorIndices
// CHECK: "tfl.gather"(%arg0, %arg1) {axis = -1 : i32} : (tensor<1x2x20xf32>, tensor<3x5xi32>) -> tensor<1x2x3x5xf32>
// CHECK: "tfl.gather"(%arg0, %arg1) {axis = -1 : i32, batch_dims = 0 : i32} : (tensor<1x2x20xf32>, tensor<3x5xi32>) -> tensor<1x2x3x5xf32>
}
func @gatherV2NonZeroBatchDims(%arg0 : tensor<1x2x20xf32>, %arg1 : tensor<3x5xi32>) -> tensor<1x2x3x5xf32> {
func @gatherWithBatchDims(%arg0 : tensor<2x3x6xf32>, %arg1 : tensor<2x5xi32>) -> tensor<2x5x3x6xf32> {
%0 = "tf.Const"() { value = dense<[1]> : tensor<1xi32> } : () -> tensor<1xi32>
%1 = "tf.GatherV2"(%arg0, %arg1, %0) {batch_dims = 1 : i64} : (tensor<1x2x20xf32>, tensor<3x5xi32>, tensor<1xi32>) -> tensor<1x2x3x5xf32>
return %1 : tensor<1x2x3x5xf32>
%1 = "tf.GatherV2"(%arg0, %arg1, %0) {batch_dims = 1 : i64} : (tensor<2x3x6xf32>, tensor<2x5xi32>, tensor<1xi32>) -> tensor<2x5x3x6xf32>
return %1 : tensor<2x5x3x6xf32>
// CHECK-LABEL:gatherV2NonZeroBatchDims
// CHECK: tf.GatherV2
// CHECK-LABEL:gatherWithBatchDims
// CHECK: "tfl.gather"(%arg0, %arg1) {axis = 1 : i32, batch_dims = 1 : i32} : (tensor<2x3x6xf32>, tensor<2x5xi32>) -> tensor<2x5x3x6xf32>
}
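An aside on what the new `batch_dims` attribute encodes: with `axis = 1` and `batch_dims = 1`, the gather is applied independently per batch row, i.e. `result[b, n, :] = params[b, indices[b, n], :]`. The standalone C++ sketch below is not part of this change; the function name and the flat row-major layout are assumptions made purely to spell out that indexing for `params` of shape `[B, M, K]` and `indices` of shape `[B, N]`.

#include <cassert>
#include <cstdint>
#include <vector>

// Reference semantics of gather with axis = 1 and batch_dims = 1:
//   out[b, n, k] = params[b, indices[b, n], k]
// params: [B, M, K] (row-major), indices: [B, N], result: [B, N, K].
std::vector<float> GatherAxis1BatchDims1(const std::vector<float>& params,
                                         const std::vector<int32_t>& indices,
                                         int B, int M, int K, int N) {
  std::vector<float> out(static_cast<size_t>(B) * N * K);
  for (int b = 0; b < B; ++b) {
    for (int n = 0; n < N; ++n) {
      const int m = indices[b * N + n];
      assert(0 <= m && m < M);
      for (int k = 0; k < K; ++k) {
        out[(static_cast<size_t>(b) * N + n) * K + k] =
            params[(static_cast<size_t>(b) * M + m) * K + k];
      }
    }
  }
  return out;
}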
func @greater(%arg0: tensor<8x16xf32>, %arg1: tensor<8x16xf32>) -> tensor<8x16xi1> {
%0 = "tf.Greater"(%arg0, %arg1) : (tensor<8x16xf32>, tensor<8x16xf32>) -> tensor<8x16xi1>
return %0 : tensor<8x16xi1>

View File

@ -9,11 +9,11 @@
// Verify tensors in interpreter state:
// ------------------------------------
// CHECK: Tensor 0 pconst kTfLiteInt32 kTfLiteMmapRo 4 bytes
// CHECK-NEXT: Tensor 1 N kTfLiteInt32 kTfLiteMmapRo 4 bytes
// CHECK-NEXT: Tensor 2 val kTfLiteFloat32 kTfLiteMmapRo 4 bytes
// CHECK-NEXT: Tensor 3 tfl.while kTfLiteInt32 kTfLiteArenaRw 4 bytes
// CHECK-NEXT: Tensor 4 result kTfLiteFloat32 kTfLiteArenaRw 4 bytes
// CHECK: Tensor 0 pconst kTfLiteInt32 kTfLiteMmapRo 4B ( 0.0 MB) (null)
// CHECK-NEXT: Tensor 1 N kTfLiteInt32 kTfLiteMmapRo 4B ( 0.0 MB) (null)
// CHECK-NEXT: Tensor 2 val kTfLiteFloat32 kTfLiteMmapRo 4B ( 0.0 MB) [1]
// CHECK-NEXT: Tensor 3 tfl.while kTfLiteInt32 kTfLiteArenaRw 4B ( 0.0 MB) (null)
// CHECK-NEXT: Tensor 4 result kTfLiteFloat32 kTfLiteArenaRw 4B ( 0.0 MB) [1]
// Verify while was not folded away:
// ------------------------------------

View File

@ -82,6 +82,14 @@ func @testGatherUnsupportedRank(%arg0 : tensor<f32>, %arg1 : tensor<1xi32>) -> t
// -----
// CHECK-LABEL: testGatherWithBatchDims
func @testGatherWithBatchDims(%arg0 : tensor<2xf32>, %arg1 : tensor<2xi32>) -> tensor<2xf32> {
%0 = "tfl.gather"(%arg0, %arg1) {axis = 1 : i32, batch_dims = 2 : i32}: (tensor<2xf32>,tensor<2xi32>) -> tensor<2xf32>
return %0 : tensor<2xf32>
}
// -----
// CHECK-LABEL: testAbs
func @testAbs(tensor<? x f32>) -> tensor<? x f32> {
^bb0(%arg0: tensor<? x f32>):

View File

@ -1948,6 +1948,32 @@ func @DontRemoveReshapeBeforeFullyConnectedChangeLastDim(%arg0: tensor<128x64xf3
// CHECK: return %[[FULLY_CONNECTED]] : tensor<256x32xf32>
}
// CHECK-LABEL: RemoveReshapeAfterFullyConnected
func @RemoveReshapeAfterFullyConnected(%arg0: tensor<4x1024x1024xbf16>) -> tensor<4x1024x4096xbf16> {
%cst_0 = constant dense<1.0> : tensor<4096x1024xbf16>
%cst_1 = constant unit
%cst_2 = constant dense<[4, 1024, 4096]> : tensor<3xi32>
%0 = "tfl.fully_connected"(%arg0, %cst_0, %cst_1) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<4x1024x1024xbf16>, tensor<4096x1024xbf16>, none) -> tensor<4096x4096xbf16>
%1 = "tfl.reshape"(%0, %cst_2) : (tensor<4096x4096xbf16>, tensor<3xi32>) -> tensor<4x1024x4096xbf16>
return %1 : tensor<4x1024x4096xbf16>
// CHECK: %[[V0:.*]] = "tfl.fully_connected"(%arg0, {{.*}}) {{.*}}keep_num_dims = true{{.*}} -> tensor<4x1024x4096xbf16>
// CHECK: return %[[V0]]
}
// CHECK-LABEL: RemoveReshapeAfterFullyConnectedAdd
func @RemoveReshapeAfterFullyConnectedAdd(%arg0: tensor<4x1024x1024xbf16>) -> tensor<4x1024x4096xbf16> {
%cst_0 = constant dense<1.0> : tensor<4096x1024xbf16>
%cst_1 = constant unit
%cst_2 = constant dense<[4, 1024, 4096]> : tensor<3xi32>
%0 = "tfl.fully_connected"(%arg0, %cst_0, %cst_1) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<4x1024x1024xbf16>, tensor<4096x1024xbf16>, none) -> tensor<4096x4096xbf16>
%1 = "tfl.reshape"(%0, %cst_2) : (tensor<4096x4096xbf16>, tensor<3xi32>) -> tensor<4x1024x4096xbf16>
%2 = "tfl.mul"(%1, %1) {fused_activation_function = "NONE"} : (tensor<4x1024x4096xbf16>, tensor<4x1024x4096xbf16>) -> tensor<4x1024x4096xbf16>
return %2 : tensor<4x1024x4096xbf16>
// CHECK: %[[V0:.*]] = "tfl.fully_connected"(%arg0, {{.*}}) {{.*}}keep_num_dims = true{{.*}} -> tensor<4x1024x4096xbf16>
// CHECK: %[[V1:.*]] = tfl.mul %[[V0]], %[[V0]] {{.*}} : tensor<4x1024x4096xbf16
// CHECK: return %[[V1]]
}
// CHECK-LABEL: DontFuseAddWithConvActivationFunc
func @DontFuseAddWithConvActivationFunc(%arg0: tensor<1x3x1x1xf32>) -> tensor<1x2x1x3xf32> {
%cst = constant dense<1.5> : tensor<1xf32>

View File

@ -305,3 +305,52 @@ func @NotQuantizePow(%arg0: tensor<4x!quant.uniform<u8:f32, 1.0>>,
// DEBUG-NOT: "tfl.NumericVerify"
}
// CHECK-LABEL: QuantizeCustomTfOp
// DEBUG-LABEL: QuantizeCustomTfOp
func @QuantizeCustomTfOp(%arg0: tensor<128x128x!quant.uniform<u8:f32, 0.1:127>>,
%arg1: tensor<1x!quant.uniform<u8:f32, 0.2:127>>, %arg2: tensor<1x!quant.uniform<u8:f32, 0.4:127>>,
%arg3: tensor<1xi32>) -> (tensor<128x128x!quant.uniform<u8:f32, 0.2:125>>) {
%0 = "tfl.dequantize"(%arg0) : (tensor<128x128x!quant.uniform<u8:f32, 0.1:127>>) -> tensor<128x128xf32>
%1 = "tfl.dequantize"(%arg1) : (tensor<1x!quant.uniform<u8:f32, 0.2:127>>) -> tensor<1xf32>
%2 = "tfl.dequantize"(%arg2) : (tensor<1x!quant.uniform<u8:f32, 0.4:127>>) -> tensor<1xf32>
%3 = "tfl.custom_tf"(%0, %1, %2, %arg3) ( {
^bb0(%a1: tensor<128x128xf32>, %a2: tensor<1xf32>, %a3: tensor<1xf32>, %a4: tensor<1xi32>): // no predecessors
%4 = "tf.LayerNorm"(%a1, %a2, %a3, %a4) {_tfl_quant_trait = "fully_quantizable", device = ""} : (tensor<128x128xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xi32>) -> tensor<128x128xf32>
"tfl.yield"(%4) : (tensor<128x128xf32>) -> ()
}) {_tfl_quant_trait = "fully_quantizable", device = ""} : (tensor<128x128xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xi32>) -> tensor<128x128xf32>
%4 = "tfl.quantize"(%3) {qtype = tensor<128x128x!quant.uniform<u8:f32, 0.2:125>>} : (tensor<128x128xf32>) -> tensor<128x128x!quant.uniform<u8:f32, 0.2:125>>
return %4 : tensor<128x128x!quant.uniform<u8:f32, 0.2:125>>
// CHECK: %4 = "tfl.custom_tf"(%arg0, %arg1, %arg2, %arg3) ( {
// CHECK-NEXT: ^bb0(%arg4: tensor<128x128xf32>, %arg5: tensor<1xf32>, %arg6: tensor<1xf32>, %arg7: tensor<1xi32>): // no predecessors
// CHECK-NEXT: "tf.LayerNorm"(%arg4, %arg5, %arg6, %arg7) {_tfl_quant_trait = "fully_quantizable", device = ""} : (tensor<128x128xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xi32>) -> tensor<128x128xf32>
// CHECK-NEXT: "tfl.yield"
// CHECK-NEXT: }) {_tfl_quant_trait = "fully_quantizable", device = ""} :
// CHECK-SAME: (tensor<128x128x!quant.uniform<u8:f32, 1.000000e-01:127>>, tensor<1x!quant.uniform<u8:f32, 2.000000e-01:127>>, tensor<1x!quant.uniform<u8:f32, 4.000000e-01:127>>, tensor<1xi32>)
// CHECK-SAME: -> tensor<128x128x!quant.uniform<u8:f32, 2.000000e-01:125>>
}
// CHECK-LABEL: NotQuantizeCustomTfOp
// DEBUG-LABEL: NotQuantizeCustomTfOp
func @NotQuantizeCustomTfOp(%arg0: tensor<128x128x!quant.uniform<u8:f32, 0.1:127>>,
%arg1: tensor<1x!quant.uniform<u8:f32, 0.2:127>>, %arg2: tensor<1x!quant.uniform<u8:f32, 0.4:127>>,
%arg3: tensor<1xi32>) -> (tensor<128x128x!quant.uniform<u8:f32, 0.2:125>>) {
%0 = "tfl.dequantize"(%arg0) : (tensor<128x128x!quant.uniform<u8:f32, 0.1:127>>) -> tensor<128x128xf32>
%1 = "tfl.dequantize"(%arg1) : (tensor<1x!quant.uniform<u8:f32, 0.2:127>>) -> tensor<1xf32>
%2 = "tfl.dequantize"(%arg2) : (tensor<1x!quant.uniform<u8:f32, 0.4:127>>) -> tensor<1xf32>
%3 = "tfl.custom_tf"(%0, %1, %2, %arg3) ( {
^bb0(%a1: tensor<128x128xf32>, %a2: tensor<1xf32>, %a3: tensor<1xf32>, %a4: tensor<1xi32>): // no predecessors
%4 = "tf.LayerNorm"(%a1, %a2, %a3, %a4) {device = ""} : (tensor<128x128xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xi32>) -> tensor<128x128xf32>
"tfl.yield"(%4) : (tensor<128x128xf32>) -> ()
}) {device = ""} : (tensor<128x128xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xi32>) -> tensor<128x128xf32>
%4 = "tfl.quantize"(%3) {qtype = tensor<128x128x!quant.uniform<u8:f32, 0.2:125>>} : (tensor<128x128xf32>) -> tensor<128x128x!quant.uniform<u8:f32, 0.2:125>>
return %4 : tensor<128x128x!quant.uniform<u8:f32, 0.2:125>>
// CHECK: "tfl.custom_tf"
// CHECK-NEXT: ^bb0(%arg4: tensor<128x128xf32>, %arg5: tensor<1xf32>, %arg6: tensor<1xf32>, %arg7: tensor<1xi32>): // no predecessors
// CHECK-NEXT: "tf.LayerNorm"(%arg4, %arg5, %arg6, %arg7) {device = ""} : (tensor<128x128xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xi32>) -> tensor<128x128xf32>
// CHECK-NEXT: "tfl.yield"
// CHECK-NEXT: }) {device = ""} : (tensor<128x128xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xi32>) -> tensor<128x128xf32>
}

View File

@ -26,6 +26,7 @@ limitations under the License.
#include "mlir/Parser.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Support/FileUtilities.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "mlir/Transforms/Passes.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/flatbuffer_export.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h"
@ -219,6 +220,9 @@ Status ConvertTFExecutorToTFLOrFlatbuffer(
string(reinterpret_cast<const char*>(q_buffer), q_builder.GetSize());
}
if (mlir::failed(module.verify())) {
return tensorflow::errors::Unknown("Final module is invalid");
}
return Status::OK();
}

View File

@ -123,7 +123,11 @@ class InitializeVariablesPass
// with ops that accepts resource as input.
if (!llvm::isa<TF::ReadVariableOp, TF::AssignVariableOp>(op))
return WalkResult::advance();
tensors_to_initialize.insert(GetGlobalTensorOp(op, symbol_table, func));
auto global_tensor = GetGlobalTensorOp(op, symbol_table, func);
// GetGlobalTensorOp returns nullptr when the function argument has no
// bound_input to a resource; we only need to initialize variables that
// are actually bound.
if (global_tensor) tensors_to_initialize.insert(global_tensor);
return WalkResult::advance();
});
}
@ -154,9 +158,9 @@ class InitializeVariablesPass
void runOnOperation() override {
auto module = getOperation();
// Use ordered container to make sure ids are deterministic if we got tensor
// ids from different part, since we have different passes that touches
// variables.
// Use an ordered container to make sure ids are deterministic if we got
// tensor ids from different parts, since we have different passes that
// touch variables.
// TODO(b/149099381): Remove integer IDs after adding the new variable
// handle type.
std::map<std::string, int> global_tensor_id;

View File

@ -240,15 +240,16 @@ def LegalizeGreaterEqual : Pat<(TF_GreaterEqualOp $l, $r),
// The 'validate_indices' attribute is deprecated.
def LegalizeGather: Pat<
(TF_GatherOp $params, $indices, $ignored_validate_indices),
(TFL_GatherOp $params, $indices, ConstantAttr<I32Attr, "0">)>;
(TFL_GatherOp $params, $indices, ConstantAttr<I32Attr, "0">,
ConstantAttr<I32Attr, "0">)>;
def LegalizeGatherNd : Pat<(TF_GatherNdOp $params, $indices),
(TFL_GatherNdOp $params, $indices)>;
def LegalizeGatherV2 : Pat<
(TF_GatherV2Op $params, $indices, (ConstantOp ElementsAttr:$axis),
ConstantAttr<I64Attr, "0">:$batch_dims),
(TFL_GatherOp $params, $indices, ExtractSingleElementAsInt32:$axis)>;
(TF_GatherV2Op $params, $indices, (ConstantOp ElementsAttr:$axis), $batch_dims),
(TFL_GatherOp $params, $indices, ExtractSingleElementAsInt32:$axis,
(convertIntAttrTo32Bit $batch_dims))>;
def LegalizeFloorDiv : Pat<(TF_FloorDivOp $l, $r), (TFL_FloorDivOp $l, $r)>;

View File

@ -949,7 +949,7 @@ struct ConvertWhile : public OpConversionPattern<TF::WhileOp> {
// Create a new while op with new operands and updated result types.
auto converted = rewriter.create<TF::WhileOp>(op.getLoc(), result_types,
operands, op->getAttrs());
converted.removeAttr("T");
converted->removeAttr("T");
(void)UpdateFunctionTypes(rewriter, converted, tensor_list_args);
rewriter.replaceOp(op, converted.getResults());

View File

@ -38,6 +38,7 @@ limitations under the License.
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/Matchers.h" // from @llvm-project
#include "mlir/IR/TypeUtilities.h" // from @llvm-project
#include "mlir/IR/Value.h" // from @llvm-project
@ -1117,6 +1118,59 @@ struct RemoveReshapeBeforeFullyConnected
}
};
// Remove the Reshape after a FullyConnected with `keep_num_dims=false` when
// the Reshape does not alter the last dimension and merely restores the batch
// dimensions that the FullyConnected op collapsed. For example,
//
// // %input: tensor<4x16x32xf32>
// %fc = tfl.fully_connected(%input, %filter, %bias)
// {keep_num_dims = false, weights_format = "DEFAULT"}
// %shape = constant dense<[4, 16, 32]> : tensor<3xi32>
// %rs = tfl.reshape(%fc, %shape)
//
// can be canonicalized to
//
// %fc = tfl.fully_connected(%input, %filter, %bias)
// {keep_num_dims = true, weights_format = "DEFAULT"}
struct RemoveReshapeAfterFullyConnected
: public OpRewritePattern<TFL::ReshapeOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(TFL::ReshapeOp reshape_op,
PatternRewriter &rewriter) const override {
auto fully_connected_op = llvm::dyn_cast_or_null<TFL::FullyConnectedOp>(
reshape_op.input().getDefiningOp());
if (!fully_connected_op || fully_connected_op.getNumResults() != 1 ||
fully_connected_op.weights_format() != "DEFAULT" ||
fully_connected_op.keep_num_dims())
return failure();
if (!reshape_op.input().getUseList()->hasOneUse()) return failure();
auto input_shape = fully_connected_op.input().getType().cast<ShapedType>();
auto output_shape = fully_connected_op.getType(0).cast<ShapedType>();
auto reshape_shape = reshape_op.getType().cast<ShapedType>();
if (!input_shape.hasStaticShape() || !output_shape.hasStaticShape() ||
!reshape_shape.hasStaticShape())
return failure();
// Check that the reshape doesn't modify the last dimension and it restores
// the input (batch) dimension with the exception of the feature (last)
// dimension.
if (output_shape.getShape().back() != reshape_shape.getShape().back() ||
input_shape.getShape().drop_back() !=
reshape_shape.getShape().drop_back())
return failure();
llvm::SmallVector<Type, 1> output_type{reshape_op.getType()};
rewriter.replaceOpWithNewOp<TFL::FullyConnectedOp>(
reshape_op, output_type, fully_connected_op.input(),
fully_connected_op.filter(), fully_connected_op.bias(),
fully_connected_op.fused_activation_function(),
fully_connected_op.weights_format(), /*keep_num_dims=*/true);
return success();
}
};
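The heart of the match above is the shape test: the reshape has to keep the innermost (feature) dimension of the fully-connected output and reproduce the leading dimensions of the fully-connected input. A minimal standalone C++ sketch of that check follows; the function name and the use of plain vectors in place of ShapedType are assumptions for illustration, not code from this change.

#include <cstdint>
#include <vector>

// True when replacing {fully_connected(keep_num_dims=false); reshape} with
// fully_connected(keep_num_dims=true) is shape-preserving:
//  * the reshape keeps the last (feature) dimension of the FC output, and
//  * its leading dimensions equal the leading dimensions of the FC input.
bool ReshapeRestoresBatchDims(const std::vector<int64_t>& fc_input_shape,
                              const std::vector<int64_t>& fc_output_shape,
                              const std::vector<int64_t>& reshape_shape) {
  if (fc_output_shape.empty() || reshape_shape.empty()) return false;
  if (fc_output_shape.back() != reshape_shape.back()) return false;
  if (fc_input_shape.size() != reshape_shape.size()) return false;
  for (size_t i = 0; i + 1 < reshape_shape.size(); ++i) {
    if (fc_input_shape[i] != reshape_shape[i]) return false;
  }
  return true;
}

For the RemoveReshapeAfterFullyConnected test above, this returns true for fc_input_shape = {4, 1024, 1024}, fc_output_shape = {4096, 4096}, and reshape_shape = {4, 1024, 4096}.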
using FuseBinaryOpToFollowingFullyConnected =
FuseBinaryOpToFollowingAffineOp<FullyConnectedOp>;
using FuseBinaryOpToFollowingDepthwiseConv2D =
@ -1135,6 +1189,13 @@ void OptimizePass::runOnFunction() {
auto *ctx = &getContext();
auto func = getFunction();
// Merge reshapes into fully connected ops before we start moving them past
// binary ops.
OwningRewritePatternList phase_0_patterns;
phase_0_patterns.insert<RemoveReshapeAfterFullyConnected,
RemoveReshapeBeforeFullyConnected>(ctx);
(void)applyPatternsAndFoldGreedily(func, std::move(phase_0_patterns));
// Potentially the binary ops might be fused together, like hard_swish, thus
// we explore these potentially first and then fuse the binary ops with the
// following ops in a second pattern match.
@ -1161,7 +1222,7 @@ void OptimizePass::runOnFunction() {
FuseBinaryOpToFollowingDepthwiseConv2D,
FuseBinaryOpToFollowingFullyConnected, FuseConv2DAndMulWithQDQs,
FuseDepthwiseConv2DAndMulWithQDQs, ConvertTrivialTransposeOpToReshapeOp,
RemoveReshapeBeforeFullyConnected>(ctx);
RemoveReshapeAfterFullyConnected, RemoveReshapeBeforeFullyConnected>(ctx);
if (enable_canonicalization_)
AddCanonicalizationPatterns(ctx, &phase_2_patterns);
(void)applyPatternsAndFoldGreedily(func, std::move(phase_2_patterns));

View File

@ -363,6 +363,7 @@ cc_library(
],
hdrs = [
"ir/tf_attributes.h",
"ir/tf_dialect.h",
],
deps = [
"@llvm-project//llvm:Support",
@ -389,6 +390,7 @@ cc_library(
cc_library(
name = "tensorflow_" + target["name"],
srcs = [
"ir/tf_dialect.h",
"ir/tf_ops.h",
"ir/tfrt_ops.h",
"ir/tf_remaining_ops.h",
@ -433,6 +435,7 @@ cc_library(
cc_library(
name = "tensorflow_remaining_ops",
srcs = [
"ir/tf_dialect.h",
"ir/tf_ops.h",
"ir/tf_remaining_ops.h",
"ir/tf_remaining_ops.cc",
@ -476,6 +479,7 @@ cc_library(
cc_library(
name = "tensorflow_tfrt_ops",
srcs = [
"ir/tf_dialect.h",
"ir/tf_ops.h",
"ir/tfrt_ops.h",
"ir/tfrt_ops.cc",
@ -519,6 +523,7 @@ cc_library(
cc_library(
name = "tensorflow_ops",
srcs = [
"ir/tf_dialect.h",
"ir/tf_ops.cc",
"ir/tf_ops.h",
],
@ -592,6 +597,7 @@ cc_library(
"ir/tf_types.cc",
],
hdrs = [
"ir/tf_dialect.h",
"ir/tf_types.h",
],
textual_hdrs = [
@ -616,6 +622,7 @@ cc_library(
hdrs = [
"dialect_registration.h",
"ir/tf_device.h",
"ir/tf_dialect.h",
"ir/tf_executor.h",
"ir/tf_ops.h",
"ir/tf_saved_model.h",

View File

@ -15,6 +15,8 @@ limitations under the License.
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h"
namespace mlir {
namespace TF {
@ -131,5 +133,9 @@ DictionaryAttr FuncAttr::GetAttrs() const {
return getImpl()->attrs.cast<DictionaryAttr>();
}
void TensorFlowDialect::registerAttributes() {
addAttributes<ShapeAttr, FuncAttr>();
}
} // namespace TF
} // namespace mlir

View File

@ -0,0 +1,138 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This file defines the standard MLIR TensorFlow dialect after control
// dependencies are raised to the standard form.
#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_DIALECT_H_
#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_DIALECT_H_
#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
#include "mlir/IR/Dialect.h" // from @llvm-project
namespace mlir {
namespace TF {
class ResourceType;
class VariantType;
class TensorFlowDialect : public Dialect {
public:
TensorFlowDialect(MLIRContext *context);
static StringRef getDialectNamespace() { return "tf"; }
// Gradient attribute ("tf.gradient") in the list of NamedAttributes in a
// function refers to its gradient function. This attribute in TensorFlow
// Dialect is used to model TF GradientDef. GetGradientAttrName() returns the
// string description of gradient attribute.
static StringRef GetGradientAttrName() { return "tf.gradient"; }
// This attribute marks if a function is stateful.
// Returns the string description of stateful attribute.
static StringRef GetStatefulAttrName() { return "tf.signature.is_stateful"; }
// Returns true if the op can be duplicated during transformations.
static bool CanDuplicate(Operation *op);
// Returns true if the op can have side effects.
static bool CanHaveSideEffects(Operation *op);
Attribute parseAttribute(DialectAsmParser &parser, Type type) const override;
void printAttribute(Attribute attr, DialectAsmPrinter &os) const override;
// Parse a type registered to this dialect.
Type parseType(DialectAsmParser &parser) const override;
// Prints a type registered to this dialect.
void printType(Type ty, DialectAsmPrinter &os) const override;
// Parses resource type with potential subtypes.
Type ParseResourceType(DialectAsmParser &parser) const;
// Prints resource type with potential subtypes.
void PrintResourceType(ResourceType ty, DialectAsmPrinter &os) const;
// Parse and print variant type. It may have subtypes inferred using shape
// inference.
Type ParseVariantType(DialectAsmParser &parser) const;
void PrintVariantType(VariantType ty, DialectAsmPrinter &os) const;
// Registered hook to materialize a constant operation from a given attribute
// value with the desired resultant type.
Operation *materializeConstant(OpBuilder &builder, Attribute value, Type type,
Location loc) override;
typedef std::function<void(TensorFlowDialect &dialect)> AdditionalOpFunction;
// Register an op registration hook which is invoked during construction.
//
// A hook may use the public addOperations() method to add additional
// operations to the dialect. Hooks will only apply to subsequent
// instantiations of the Dialect/MLIRContext.
static void RegisterAdditionalOperationHook(AdditionalOpFunction fn) {
GetAdditionalOperationHooks()->push_back(std::move(fn));
}
// Re-define publicly the protected addOperations() method from the Dialect
// class, usually used in a Dialect constructor. This allows hook
// functions to register operations on the TensorFlow dialect using the
// same interface.
template <typename... Args>
void addOperations() {
Dialect::addOperations<Args...>();
}
using ConstantFoldHook = LogicalResult (*)(Operation *, ArrayRef<Attribute>,
SmallVectorImpl<OpFoldResult> &);
static void RegisterConstantFoldHook(ConstantFoldHook fn) {
constant_fold_hook_ = std::move(fn);
}
static LogicalResult constantFold(Operation *op, ArrayRef<Attribute> operands,
SmallVectorImpl<OpFoldResult> &results) {
if (constant_fold_hook_) return constant_fold_hook_(op, operands, results);
return failure();
}
using DecodeConstantHook = LogicalResult (*)(OpaqueElementsAttr input,
ElementsAttr &output);
static void RegisterDecodeConstantHook(DecodeConstantHook fn) {
decode_constant_hook_ = std::move(fn);
}
static LogicalResult decode(OpaqueElementsAttr input, ElementsAttr &output) {
if (decode_constant_hook_) return decode_constant_hook_(input, output);
return failure();
}
private:
/// Register the attributes of this dialect.
void registerAttributes();
/// Register the types of this dialect.
void registerTypes();
// Hook functions which may add additional operations to the dialect.
// These are invoked at construction time.
static std::vector<AdditionalOpFunction> *GetAdditionalOperationHooks();
static ConstantFoldHook constant_fold_hook_;
static DecodeConstantHook decode_constant_hook_;
};
} // namespace TF
} // namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_DIALECT_H_
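One design note on the hooks declared here: the dialect does not depend on the TensorFlow constant-folding or constant-decoding implementations; clients install them through static registration, and when no hook is installed the dialect simply reports failure. The toy C++ sketch below mirrors that pattern with invented names (ToyDialect, FoldHook); it illustrates the registration idiom only and is not the actual MLIR code.

#include <iostream>
#include <optional>

// A statically registered hook with a graceful fallback when absent.
using FoldHook = std::optional<int> (*)(int lhs, int rhs);

struct ToyDialect {
  static void RegisterFoldHook(FoldHook hook) { fold_hook_ = hook; }
  static std::optional<int> Fold(int lhs, int rhs) {
    if (fold_hook_) return fold_hook_(lhs, rhs);
    return std::nullopt;  // analogous to returning failure()
  }

 private:
  static FoldHook fold_hook_;
};
FoldHook ToyDialect::fold_hook_ = nullptr;

int main() {
  std::cout << ToyDialect::Fold(2, 3).has_value() << "\n";  // 0: no hook installed
  ToyDialect::RegisterFoldHook(
      +[](int a, int b) { return std::optional<int>(a + b); });
  std::cout << ToyDialect::Fold(2, 3).value() << "\n";  // 5
  return 0;
}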

View File

@ -65,6 +65,8 @@ limitations under the License.
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_side_effects.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_def_builder.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/tensor_format.h"
@ -175,9 +177,15 @@ bool TensorFlowDialect::CanDuplicate(Operation *op) {
if (auto is_stateless = op->getAttrOfType<BoolAttr>("is_stateless"))
return is_stateless.getValue();
// Otherwise, assume ops can be duplicated by default if its registered, else
// it cannot be for unknown ops.
return op->isRegistered();
// Assume ops can be duplicated when the given op is not a stateful op.
const tensorflow::OpRegistrationData *op_reg_data = nullptr;
tensorflow::Status s = tensorflow::OpRegistry::Global()->LookUp(
op->getName().stripDialect().str(), &op_reg_data);
if (!s.ok()) {
// Assume unknown ops cannot be duplicated.
return false;
}
return !op_reg_data->op_def.is_stateful();
}
// Returns true if the op can have side effects.
@ -217,14 +225,10 @@ TensorFlowDialect::TensorFlowDialect(MLIRContext *context)
#define GET_OP_LIST
#include "tensorflow/compiler/mlir/tensorflow/ir/tfrt_ops.cc.inc"
>();
addTypes<
#define HANDLE_TF_TYPE(tftype, enumerant, name) tftype##Type,
#define HANDLE_LAST_TF_TYPE(tftype, enumerant, name) tftype##Type
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.def"
>();
registerTypes();
addInterfaces<TFInlinerInterface, TFDecodeAttributesInterface,
TFConstantFoldInterface>();
addAttributes<ShapeAttr, FuncAttr>();
registerAttributes();
// Support unknown operations because not all TensorFlow operations are
// registered.

View File

@ -35,6 +35,7 @@ limitations under the License.
#include "mlir/Interfaces/LoopLikeInterface.h" // from @llvm-project
#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.h"
@ -45,109 +46,4 @@ limitations under the License.
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_verifiers.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tfrt_ops.h"
namespace mlir {
namespace TF {
class TensorFlowDialect : public Dialect {
public:
TensorFlowDialect(MLIRContext *context);
static StringRef getDialectNamespace() { return "tf"; }
// Gradient attribute ("tf.gradient") in the list of NamedAttributes in a
// function references to its gradient function. This attribute in TensorFlow
// Dialect is used to model TF GradientDef. GetGradientAttrName() returns the
// string description of gradient attribute.
static StringRef GetGradientAttrName() { return "tf.gradient"; }
// This attribute marks if a function is stateful.
// Returns the string description of stateful attribute.
static StringRef GetStatefulAttrName() { return "tf.signature.is_stateful"; }
// Returns true if the op can be duplicated during transformations.
static bool CanDuplicate(Operation *op);
// Returns true if the op can have side effects.
static bool CanHaveSideEffects(Operation *op);
Attribute parseAttribute(DialectAsmParser &parser, Type type) const override;
void printAttribute(Attribute attr, DialectAsmPrinter &os) const override;
// Parse a type registered to this dialect.
Type parseType(DialectAsmParser &parser) const override;
// Prints a type registered to this dialect.
void printType(Type ty, DialectAsmPrinter &os) const override;
// Parses resource type with potential subtypes.
Type ParseResourceType(DialectAsmParser &parser) const;
// Prints resource type with potential subtypes.
void PrintResourceType(ResourceType ty, DialectAsmPrinter &os) const;
// Parse and print variant type. It may have subtypes inferred using shape
// inference.
Type ParseVariantType(DialectAsmParser &parser) const;
void PrintVariantType(VariantType ty, DialectAsmPrinter &os) const;
// Registered hook to materialize a constant operation from a given attribute
// value with the desired resultant type.
Operation *materializeConstant(OpBuilder &builder, Attribute value, Type type,
Location loc) override;
typedef std::function<void(TensorFlowDialect &dialect)> AdditionalOpFunction;
// Register an op registration hook which is invoked during construction.
//
// A hook may use the public addOperations() method to add additional
// operations to the dialect. Hooks will only apply to subsequent
// instantations of the Dialect/MLIRContext.
static void RegisterAdditionalOperationHook(AdditionalOpFunction fn) {
GetAdditionalOperationHooks()->push_back(std::move(fn));
}
// Re-define publicly the protected addOperations() method from the Dialect
// class, usually used in a Dialect constructor. This allows hook
// functions to register operations on the TensorFlow dialect using the
// same interface.
template <typename... Args>
void addOperations() {
Dialect::addOperations<Args...>();
}
using ConstantFoldHook = LogicalResult (*)(Operation *, ArrayRef<Attribute>,
SmallVectorImpl<OpFoldResult> &);
static void RegisterConstantFoldHook(ConstantFoldHook fn) {
constant_fold_hook_ = std::move(fn);
}
static LogicalResult constantFold(Operation *op, ArrayRef<Attribute> operands,
SmallVectorImpl<OpFoldResult> &results) {
if (constant_fold_hook_) return constant_fold_hook_(op, operands, results);
return failure();
}
using DecodeConstantHook = LogicalResult (*)(OpaqueElementsAttr input,
ElementsAttr &output);
static void RegisterDecodeConstantHook(DecodeConstantHook fn) {
decode_constant_hook_ = std::move(fn);
}
static LogicalResult decode(OpaqueElementsAttr input, ElementsAttr &output) {
if (decode_constant_hook_) return decode_constant_hook_(input, output);
return failure();
}
private:
// Hook functions which may add additional operations to the dialect.
// These are invoked at construction time.
static std::vector<AdditionalOpFunction> *GetAdditionalOperationHooks();
static ConstantFoldHook constant_fold_hook_;
static DecodeConstantHook decode_constant_hook_;
};
} // namespace TF
} // namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_OPS_H_

View File

@ -1037,6 +1037,25 @@ def TF_InfeedDequeueTupleOp : TF_Op<"InfeedDequeueTuple", []> {
TF_DerivedResultTypeListAttr dtypes = TF_DerivedResultTypeListAttr<0>;
}
// TODO(b/177675373): Make dtypes and shapes derived attributes,
// use more general solution.
def TF_InfeedEnqueueTupleOp : TF_Op<"InfeedEnqueueTuple", []> {
let summary = [{
Feeds multiple Tensor values into the computation as an XLA tuple.
}];
let arguments = (ins
Arg<Variadic<TF_Tensor>, [{A list of tensors that will be provided using the infeed mechanism.}]>:$inputs,
Confined<TypeArrayAttr, [ArrayMinCount<1>]>:$dtypes,
TF_ShapeAttrArray:$shapes,
DefaultValuedAttr<I64ArrayAttr, "{}">:$layouts,
DefaultValuedAttr<I64Attr, "-1">:$device_ordinal
);
let results = (outs);
}
def TF_StringFormatOp : TF_Op<"StringFormat", [NoSideEffect]> {
let summary = "Formats a string template using a list of tensors.";

View File

@ -739,11 +739,9 @@ LogicalResult FoldConstantCaseOp::matchAndRewrite(
auto func = op.branches()[index].cast<SymbolRefAttr>();
auto empty = rewriter.getStringAttr("");
auto call_op = rewriter.create<PartitionedCallOp>(
op.getLoc(), op.getResultTypes(), op.getOperands().drop_front(), func,
ReplaceTfOpWithNewOp<PartitionedCallOp>(
rewriter, op, op.getResultTypes(), op.getOperands().drop_front(), func,
/*config=*/empty, /*config_proto=*/empty, /*executor_type=*/empty);
CopyDeviceAndUnderscoredAttributes(op.getOperation(), call_op);
rewriter.replaceOp(op, call_op.getResults());
return success();
}
@ -2510,11 +2508,9 @@ LogicalResult FoldConstantIfOp::matchAndRewrite(
// Replace IfOp with PartitionedCallOp or StatefulPartitionedCallOp.
auto rewrite = [&](auto op_type) {
auto empty = rewriter.getStringAttr("");
auto call_op = rewriter.create<typename decltype(op_type)::CallOp>(
op.getLoc(), op.getResultTypes(), op.input(), func,
ReplaceTfOpWithNewOp<typename decltype(op_type)::CallOp>(
rewriter, op, op.getResultTypes(), op.input(), func,
/*config=*/empty, /*config_proto=*/empty, /*executor_type=*/empty);
CopyDeviceAndUnderscoredAttributes(op.getOperation(), call_op);
rewriter.replaceOp(op, call_op.getResults());
};
if (op.is_stateless())

View File

@ -548,7 +548,7 @@ struct DropAttributes : public OpRewritePattern<Op> {
// Drop the "output_shapes" attribute.
LogicalResult matchAndRewrite(Op op,
PatternRewriter &rewriter) const override {
bool found = !!op.removeAttr("output_shapes");
bool found = !!op->removeAttr("output_shapes");
return success(found);
}
};
@ -581,3 +581,23 @@ ResourceHandleValueAndId GetResourceHandleValueAndIdBase(
if (emplace_res.second) ++next_id;
return {resource, emplace_res.first->second};
}
// Helper function to create TF op while copying all underscore attributes from
// another TF op.
// TODO(jpienaar): This is a workaround until behavior is established.
template <typename OpTy, typename... Args>
OpTy CreateTfOp(RewriterBase& b, Operation *op, Args &&... args) {
auto ret = b.create<OpTy>(op->getLoc(), std::forward<Args>(args)...);
CopyDeviceAndUnderscoredAttributes(op, ret.getOperation());
return ret;
}
// Helper function to replace TF op with another op while copying all underscore
// attributes from the TF op.
// TODO(jpienaar): This is a workaround until behavior is established.
template <typename OpTy, typename... Args>
OpTy ReplaceTfOpWithNewOp(RewriterBase& b, Operation *op, Args &&... args) {
auto ret = CreateTfOp<OpTy>(b, op, std::forward<Args>(args)...);
b.replaceOp(op, ret.getOperation()->getResults());
return ret;
}
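These helpers fold into one statement the three steps that every call site previously repeated: create the replacement op, copy the device and underscored attributes from the op being replaced, and call replaceOp. That is what enables the shorter call sites in the FoldConstantCaseOp/FoldConstantIfOp and ConvertPackToReshape diffs above. A framework-agnostic C++ sketch of the same variadic, perfect-forwarding shape is shown below; the Rewriter and Node types and all names are invented for illustration and are not part of the MLIR API.

#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

// Toy stand-ins for the rewriter machinery.
struct Node {
  std::unordered_map<std::string, std::string> attrs;
};

struct Rewriter {
  std::vector<std::unique_ptr<Node>> nodes;

  template <typename NodeT, typename... Args>
  NodeT* create(Args&&... args) {
    auto node = std::make_unique<NodeT>(std::forward<Args>(args)...);
    NodeT* raw = node.get();
    nodes.push_back(std::move(node));
    return raw;
  }

  // In the real rewriter this reroutes all uses of the old op's results.
  void replaceOp(Node* /*old_node*/, Node* /*new_node*/) {}
};

// Copy only the "underscored" attributes (e.g. _xla_outside_compilation).
void CopyUnderscoredAttributes(const Node& from, Node& to) {
  for (const auto& kv : from.attrs) {
    if (!kv.first.empty() && kv.first[0] == '_') to.attrs[kv.first] = kv.second;
  }
}

// Analogue of ReplaceTfOpWithNewOp: create, copy attributes, replace.
template <typename NodeT, typename... Args>
NodeT* ReplaceWithNewNode(Rewriter& rewriter, Node* old_node, Args&&... args) {
  NodeT* new_node = rewriter.create<NodeT>(std::forward<Args>(args)...);
  CopyUnderscoredAttributes(*old_node, *new_node);
  rewriter.replaceOp(old_node, new_node);
  return new_node;
}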

View File

@ -335,16 +335,9 @@ struct ConvertPackToReshape : public OpRewritePattern<PackOp> {
auto shape_attr = DenseIntElementsAttr::get(type, output_ty.getShape());
auto shape = rewriter.create<ConstOp>(pack_op.getLoc(), shape_attr);
auto reshape_op = rewriter.create<ReshapeOp>(pack_op.getLoc(), output_ty,
pack_op.getOperand(0), shape);
// Preserve unregistered attributes. Outside compilation relies on
// unregistered attribute `_xla_outside_compilation` to form clusters, so
// they must be preserved during canonicalization.
// TODO(b/173622615): Remove after fixed.
CopyUnderscoredAttributes(pack_op.getOperation(),
reshape_op.getOperation());
rewriter.replaceOp(pack_op, reshape_op.getResult());
ReplaceTfOpWithNewOp<ReshapeOp>(rewriter, pack_op, output_ty,
pack_op.getOperand(0), shape);
return success();
}
};

View File

@ -20,6 +20,7 @@ limitations under the License.
#include "mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/IR/Dialect.h" // from @llvm-project
#include "mlir/IR/TypeUtilities.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h"
namespace {
// Returns the shape of the given value if it's ranked; returns llvm::None
@ -409,5 +410,13 @@ Type DropRefType(Type ty) { return DropTypeHelper<TF::TensorFlowRefType>(ty); }
Type DropRefAndSubTypes(Type ty) { return DropRefType(DropSubTypes(ty)); }
void TensorFlowDialect::registerTypes() {
addTypes<
#define HANDLE_TF_TYPE(tftype, enumerant, name) tftype##Type,
#define HANDLE_LAST_TF_TYPE(tftype, enumerant, name) tftype##Type
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.def"
>();
}
} // namespace TF
} // namespace mlir

View File

@ -9,4 +9,4 @@ module attributes {tf.versions = {producer = 179 : i32}} {
}
// CHECK-LABEL: HloModule main
// CHECK: -> (bf16[8,16,16,64], f32[64], f32[64], f32[64], f32[64], f32[0])
// CHECK: -> (bf16[8,16,16,64], f32[64], f32[64], f32[64], f32[64], /*index=5*/f32[0])

View File

@ -35,6 +35,24 @@ func private @inline_shape_cast_callee(%arg : tensor<*xi32>) -> tensor<*xi32> {
return %arg : tensor<*xi32>
}
func private @custom_callee() -> tensor<2xi32> {
%0 = "tf.CustomTFOp"() : () -> tensor<2xi32>
return %0 : tensor<2xi32>
}
// Test that unregistered user-defined custom TF operations cannot be inlined
// when there are duplicated cases.
// CHECK-LABEL: func @dont_inline_custom_on_duplicated_cases(
func @dont_inline_custom_on_duplicated_cases() -> tensor<2xi32> {
// CHECK-NEXT: "tf.PartitionedCall"
// CHECK-NEXT: "tf.PartitionedCall"
// CHECK-NEXT: return
%0 = "tf.PartitionedCall"() {config = "", config_proto = "", executor_type = "", f = @custom_callee} : () -> tensor<2xi32>
%1 = "tf.PartitionedCall"() {config = "", config_proto = "", executor_type = "", f = @custom_callee} : () -> tensor<2xi32>
return %1: tensor<2xi32>
}
// CHECK-LABEL: func @inline_shape_cast(
// CHECK-SAME: %[[ARG:.*]]: tensor<2xi32>
func @inline_shape_cast(%arg: tensor<2xi32>) -> tensor<2xi32> {

View File

@ -1026,3 +1026,23 @@ func @size_to_prod_shape_i64(%arg0 : tensor<1x?x2x3xf32>) -> tensor<i64> {
// CHECK: %[[PROD:.*]] = "tf.Prod"(%[[SHAPE]], %[[CONSTANT]]) {keep_dims = false} : (tensor<4xi64>, tensor<i64>) -> tensor<i64>
// CHECK: return %[[PROD]]
}
// CHECK-LABEL: @is_finite
func @is_finite(%arg0: tensor<3x4xf32>) -> tensor<3x4xi1> {
%0 = "tf.IsFinite"(%arg0) : (tensor<3x4xf32>) -> tensor<3x4xi1>
return %0 : tensor<3x4xi1>
// CHECK: %[[ZERO:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<f32>} : () -> tensor<f32>
// CHECK: %[[SUB:.*]] = "tf.Sub"(%arg0, %arg0) : (tensor<3x4xf32>, tensor<3x4xf32>) -> tensor<3x4xf32>
// CHECK: %[[RESULT:.*]] = "tf.Equal"(%[[SUB]], %[[ZERO]]) {incompatible_shape_error = true} : (tensor<3x4xf32>, tensor<f32>) -> tensor<3x4xi1>
// CHECK: return %[[RESULT]]
}
// CHECK-LABEL: @is_finite_dynamic
func @is_finite_dynamic(%arg0: tensor<?x4xf32>) -> tensor<?x4xi1> {
%0 = "tf.IsFinite"(%arg0) : (tensor<?x4xf32>) -> tensor<?x4xi1>
return %0 : tensor<?x4xi1>
// CHECK: %[[ZERO:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<f32>} : () -> tensor<f32>
// CHECK: %[[SUB:.*]] = "tf.Sub"(%arg0, %arg0) : (tensor<?x4xf32>, tensor<?x4xf32>) -> tensor<?x4xf32>
// CHECK: %[[RESULT:.*]] = "tf.Equal"(%[[SUB]], %[[ZERO]]) {incompatible_shape_error = true} : (tensor<?x4xf32>, tensor<f32>) -> tensor<?x4xi1>
// CHECK: return %[[RESULT]]
}

View File

@ -359,6 +359,31 @@ func @while_cond(%arg0: tensor<*x!tf.resource<tensor<f32>>>) -> tensor<f32> {
// -----
// Tests that the pass reports error on non-aliasing WhileRegion input/output
// resources. It cannot lift resource ops from such WhileRegion ops and should
// fail with a helpful error message.
func @fail_non_aliasing_resource_input_output() -> () {
%0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<f32>>>
"tf_device.cluster"() ( {
// expected-error@+1 {{Result #0 is not tied to arg #0 of the body}}
%1 = "tf.WhileRegion"(%0) ({
^bb0(%carg0:tensor<*x!tf.resource<tensor<f32>>>):
%cond = "tf.SomeOp"() : () -> tensor<i1>
"tf.Yield"(%cond) : (tensor<i1>) -> ()
}, {
^bb0(%carg0:tensor<*x!tf.resource<tensor<f32>>>):
%body = "tf.VarHandleOp"() {container = "c", shared_name = "v2"} : () -> tensor<*x!tf.resource<tensor<f32>>>
"tf.Yield"(%body) : (tensor<*x!tf.resource<tensor<f32>>>) -> ()
}) { is_stateless = false }
: (tensor<*x!tf.resource<tensor<f32>>>) -> (tensor<*x!tf.resource<tensor<f32>>>)
tf_device.return
}) {cluster_attr = "cluster_attr"} : () -> ()
return
}
// -----
// Tests that pass reports error on unsupported ops in loop cond.
func @cluster_with_loop() -> () {
@ -833,6 +858,9 @@ func @cluster_with_caseregion(%arg0: tensor<i32>) -> tensor<4xf32> {
// -----
// Test that the pass can lift resources out of WhileRegion
!tf_ref = type tensor<*x!tf.resource<tensor<f32>>>
// CHECK-LABEL: func @cluster_with_whileregion
func @cluster_with_whileregion() -> () {
// CHECK: %[[COUNT:.*]] = "tf.Const"() {value = dense<10> : tensor<i32>}
@ -841,16 +869,17 @@ func @cluster_with_whileregion() -> () {
// CHECK: %[[CLUSTER:.*]] = "tf_device.cluster"()
// CHECK: %[[WHILE:.*]]:2 = "tf.WhileRegion"(%[[COUNT]], %[[READ]])
%0 = "tf.Const"() {value = dense<10> : tensor<i32>} : () -> tensor<i32>
%1 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<f32>>>
%unused = "tf.VarHandleOp"() {container = "c", shared_name = "v2"} : () -> tensor<*x!tf.resource<tensor<f32>>>
%1 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> !tf_ref
%pass_through = "tf.VarHandleOp"() {container = "c", shared_name = "v2"} : () -> !tf_ref
%unused = "tf.VarHandleOp"() {container = "c", shared_name = "v3"} : () -> !tf_ref
"tf_device.cluster"() ( {
%2:3 = "tf.WhileRegion"(%0, %1, %unused) ({
%2:4 = "tf.WhileRegion"(%0, %1, %pass_through, %unused) ({
// CHECK: (%[[CARG0:.+]]: tensor<i32>, %[[CARG1:.+]]: tensor<f32>):
// CHECK: %[[CAST:.+]] = "tf.Cast"(%[[CARG1]])
// CHECK: "tf.Less"(%[[CARG0]], %[[CAST]])
// CHECK: "tf.Yield"
^bb0(%carg0: tensor<i32>, %carg1:tensor<*x!tf.resource<tensor<f32>>>, %carg2: tensor<*x!tf.resource<tensor<f32>>>):
%read0 = "tf.ReadVariableOp"(%carg1) : (tensor<*x!tf.resource<tensor<f32>>>) -> tensor<f32>
^bb0(%carg0: tensor<i32>, %carg1: !tf_ref, %carg2: !tf_ref, %carg3: !tf_ref):
%read0 = "tf.ReadVariableOp"(%carg1) : (!tf_ref) -> tensor<f32>
%cast = "tf.Cast"(%read0) : (tensor<f32>) -> tensor<i32>
%cond = "tf.Less"(%carg0, %cast) : (tensor<i32>, tensor<i32>) -> tensor<i1>
"tf.Yield"(%cond) : (tensor<i1>) -> ()
@ -861,20 +890,20 @@ func @cluster_with_whileregion() -> () {
// CHECK-NEXT: %[[DELTA:.*]] = "tf.Const"() {value = dense<-1> : tensor<i32>}
// CHECK-NEXT: %[[ADD2:.*]] = "tf.AddV2"(%[[BARG0]], %[[DELTA]])
// CHECK-NEXT: "tf.Yield"(%[[ADD2]], %[[ADD1]])
^bb1(%barg0: tensor<i32>, %barg1:tensor<*x!tf.resource<tensor<f32>>>, %barg2: tensor<*x!tf.resource<tensor<f32>>>):
%read0 = "tf.ReadVariableOp"(%barg1) : (tensor<*x!tf.resource<tensor<f32>>>) -> tensor<f32>
^bb1(%barg0: tensor<i32>, %barg1: !tf_ref, %barg2: !tf_ref, %barg3: !tf_ref):
%read0 = "tf.ReadVariableOp"(%barg1) : (!tf_ref) -> tensor<f32>
%add0 = "tf.AddV2"(%read0, %read0) : (tensor<f32>, tensor<f32>) -> tensor<f32>
"tf.AssignVariableOp"(%barg1, %add0) : (tensor<*x!tf.resource<tensor<f32>>>, tensor<f32>) -> ()
%read1 = "tf.ReadVariableOp"(%barg1) : (tensor<*x!tf.resource<tensor<f32>>>) -> tensor<f32>
"tf.AssignVariableOp"(%barg1, %add0) : (!tf_ref, tensor<f32>) -> ()
%read1 = "tf.ReadVariableOp"(%barg1) : (!tf_ref) -> tensor<f32>
%add1 = "tf.AddV2"(%read1, %read1) : (tensor<f32>, tensor<f32>) -> tensor<f32>
"tf.AssignVariableOp"(%barg1, %add1) : (tensor<*x!tf.resource<tensor<f32>>>, tensor<f32>) -> ()
"tf.AssignVariableOp"(%barg1, %add1) : (!tf_ref, tensor<f32>) -> ()
%constant = "tf.Const"() {value = dense<-1> : tensor<i32>} : () -> tensor<i32>
%add2 = "tf.AddV2"(%barg0, %constant) : (tensor<i32>, tensor<i32>) -> tensor<i32>
%id = "tf.Identity"(%barg2) : (tensor<*x!tf.resource<tensor<f32>>>) -> tensor<*x!tf.resource<tensor<f32>>>
"tf.Yield"(%add2, %barg1, %id) : (tensor<i32>, tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<f32>>>) -> ()
%id = "tf.Identity"(%barg3) : (!tf_ref) -> !tf_ref
"tf.Yield"(%add2, %barg1, %pass_through, %id) : (tensor<i32>, !tf_ref, !tf_ref, !tf_ref) -> ()
}) {device = "", is_stateless = false}
: (tensor<i32>, tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<f32>>>)
-> (tensor<i32>, tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<f32>>>)
: (tensor<i32>, !tf_ref, !tf_ref, !tf_ref)
-> (tensor<i32>, !tf_ref, !tf_ref, !tf_ref)
tf_device.return
}) {cluster_attr = "cluster_attr"} : () -> ()
// CHECK: tf_device.return %[[WHILE]]#1 : tensor<f32>

View File

@ -40,6 +40,25 @@ func @read_only_resource(%arg0: tensor<!tf.resource<tensor<i32>>>, %arg1: tensor
return %2 : tensor<i32>
}
func private @computation_two_args(%arg0: tensor<i32>, %arg1: tensor<i32>)
// CHECK-LABEL: func @partitioned_variable_multiple_users
// CHECK-SAME: ([[ARG0:%.+]]: tensor<!tf.resource<tensor<i32>>>, [[ARG1:%.+]]: tensor<!tf.resource<tensor<i32>>>)
func @partitioned_variable_multiple_users(%arg0: tensor<!tf.resource<tensor<i32>>>, %arg1: tensor<!tf.resource<tensor<i32>>>) {
// CHECK-DAG: [[READ0:%.+]] = "tf.ReadVariableOp"([[ARG0]])
// CHECK-DAG: [[READ1:%.+]] = "tf.ReadVariableOp"([[ARG1]])
// CHECK: [[INPUT0:%.+]] = "tf.TPUPartitionedInput"([[READ0]], [[READ1]])
// CHECK-DAG: [[READ2:%.+]] = "tf.ReadVariableOp"([[ARG0]])
// CHECK-DAG: [[READ3:%.+]] = "tf.ReadVariableOp"([[ARG1]])
// CHECK: [[INPUT1:%.+]] = "tf.TPUPartitionedInput"([[READ2]], [[READ3]])
%0 = "tf.TPUPartitionedInput"(%arg0, %arg1) {N = 2 : i64, _XlaSharding = "", partition_dim = -1 : i64} : (tensor<!tf.resource<tensor<i32>>>, tensor<!tf.resource<tensor<i32>>>) -> tensor<!tf.resource<tensor<i32>>>
%1 = "tf.ReadVariableOp"(%0) : (tensor<!tf.resource<tensor<i32>>>) -> tensor<i32>
%2 = "tf.ReadVariableOp"(%0) : (tensor<!tf.resource<tensor<i32>>>) -> tensor<i32>
// CHECK: "tf_device.cluster_func"([[INPUT0]], [[INPUT1]])
"tf_device.cluster_func"(%1, %2) {func = @computation_two_args, use_spmd_for_xla_partitioning = true} : (tensor<i32>, tensor<i32>) -> ()
return
}
// Tests unsupported cases and IR are not modified.
// CHECK-LABEL: func @no_spmd
@ -86,16 +105,6 @@ func @resource_read_multiple_users(%arg0: tensor<!tf.resource<tensor<i32>>>, %ar
return %1 : tensor<i32>
}
// CHECK-LABEL: func @partitioned_variable_multiple_users
// CHECK-SAME: ([[ARG0:%.+]]: tensor<!tf.resource<tensor<i32>>>, [[ARG1:%.+]]: tensor<!tf.resource<tensor<i32>>>) -> tensor<!tf.resource<tensor<i32>>>
func @partitioned_variable_multiple_users(%arg0: tensor<!tf.resource<tensor<i32>>>, %arg1: tensor<!tf.resource<tensor<i32>>>) -> tensor<!tf.resource<tensor<i32>>> {
// CHECK: "tf.TPUPartitionedInput"([[ARG0]], [[ARG1]])
%0 = "tf.TPUPartitionedInput"(%arg0, %arg1) {N = 2 : i64, _XlaSharding = "", partition_dim = -1 : i64} : (tensor<!tf.resource<tensor<i32>>>, tensor<!tf.resource<tensor<i32>>>) -> tensor<!tf.resource<tensor<i32>>>
%1 = "tf.ReadVariableOp"(%0) : (tensor<!tf.resource<tensor<i32>>>) -> tensor<i32>
%2 = "tf_device.cluster_func"(%1) {func = @computation} : (tensor<i32>) -> tensor<i32>
return %0 : tensor<!tf.resource<tensor<i32>>>
}
// CHECK-LABEL: func @non_resource_read_input_write_output
func @non_resource_read_input_write_output(%arg0: tensor<i32>) -> tensor<i32> {
// CHECK-NOT: tf.TPUPartitionedInput

View File

@ -409,3 +409,13 @@ def LowerXlog1pyOp : BinaryXopyPat<
def LowerXlogyOp : BinaryXopyPat<
(TF_XlogyOp $x, $y),
(TF_MulOp $x, (TF_LogOp $y))>;
//===----------------------------------------------------------------------===//
// IsFinite op patterns.
//===----------------------------------------------------------------------===//
def LowerIsFiniteOp : Pat<(TF_IsFiniteOp $x),
(TF_EqualOp
(TF_SubOp $x, $x),
(TF_ConstOp (GetScalarOfType<0> $x)),
/*incompatible_shape_error*/ConstBoolAttrTrue)>;
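The lowering works because of an IEEE-754 identity: x - x is exactly 0 for every finite x, while Inf - Inf and NaN - NaN are NaN, and any comparison with NaN, including equality against 0, is false. A tiny standalone C++ check of that identity (not part of the patch, and assuming default IEEE semantics, i.e. no -ffast-math) is:

#include <cassert>
#include <cmath>
#include <limits>

// IsFinite(x) lowered as (x - x) == 0: finite values give 0 == 0 (true),
// while +/-Inf and NaN give NaN, and NaN == 0 is false.
bool IsFiniteViaSub(float x) { return (x - x) == 0.0f; }

int main() {
  const float inf = std::numeric_limits<float>::infinity();
  const float nan = std::numeric_limits<float>::quiet_NaN();
  for (float x : {0.0f, 1.5f, -3.0f, 1e30f}) {
    assert(IsFiniteViaSub(x) == static_cast<bool>(std::isfinite(x)));
  }
  assert(!IsFiniteViaSub(inf) && !IsFiniteViaSub(-inf) && !IsFiniteViaSub(nan));
  return 0;
}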

View File

@ -377,20 +377,21 @@ LogicalResult CanonicalizeWhileRegion(TF::WhileRegionOp op) {
for (OpResult result : llvm::reverse(op.getResults())) {
if (!IsResource(result)) continue;
int result_idx = result.getResultNumber();
auto body_arg = body.front()
.getTerminator()
->getOperand(result_idx)
.dyn_cast<BlockArgument>();
if (!body_arg || body_arg.getArgNumber() != result_idx) {
Operation *yield_op = body.front().getTerminator();
Value yield_operand = yield_op->getOperand(result_idx);
Value while_operand = op.getOperand(result_idx);
Value body_arg = body.getArgument(result_idx);
Value cond_arg = cond.getArgument(result_idx);
if (yield_operand != body_arg && yield_operand != while_operand) {
return op.emitOpError("Result #") << result_idx << " is not tied to arg #"
<< result_idx << " of the body";
}
body.getArgument(result_idx).replaceAllUsesWith(op.getOperand(result_idx));
cond.getArgument(result_idx).replaceAllUsesWith(op.getOperand(result_idx));
body_arg.replaceAllUsesWith(while_operand);
cond_arg.replaceAllUsesWith(while_operand);
result.replaceAllUsesWith(while_operand);
body.front().getTerminator()->eraseOperand(result_idx);
body.eraseArgument(result_idx);
cond.eraseArgument(result_idx);
result.replaceAllUsesWith(op.getOperand(result_idx));
op.getOperation()->eraseOperand(result_idx);
can_eliminate.set(result_idx);
}
@ -434,7 +435,7 @@ LogicalResult CleanupAndCanonicalize(Operation *parent_op) {
if (while_region.cond().walk(check_while_cond).wasInterrupted())
return WalkResult::interrupt();
// For while region, the body input and output arg should match.
(void)CanonicalizeWhileRegion(while_region);
result = CanonicalizeWhileRegion(while_region);
} else if (auto call = dyn_cast<CallOpInterface>(op)) {
FuncOp func = dyn_cast<FuncOp>(call.resolveCallable());
if (!func) return WalkResult::interrupt();

View File

@ -118,7 +118,7 @@ void PartitionResourceReadsWrites(tf_device::ClusterFuncOp cluster_func) {
if (!read_var || !read_var.value().hasOneUse()) continue;
auto partitioned_input = llvm::dyn_cast_or_null<TF::TPUPartitionedInputOp>(
read_var.resource().getDefiningOp());
if (!partitioned_input || !partitioned_input.output().hasOneUse() ||
if (!partitioned_input ||
!AllResourceTypesHaveSubtypes(partitioned_input.inputs().getTypes()))
continue;
@ -135,7 +135,7 @@ void PartitionResourceReadsWrites(tf_device::ClusterFuncOp cluster_func) {
partitioned_input._XlaShardingAttr());
operand.set(partitioned_read);
read_var->erase();
partitioned_input->erase();
if (partitioned_input->use_empty()) partitioned_input->erase();
}
}

View File

@ -112,7 +112,7 @@ void TPUResourceReadForWritePass::runOnOperation() {
auto new_cluster_func = builder.create<tf_device::ClusterFuncOp>(
cluster_func.getLoc(), cluster_func.getResultTypes(), operands,
cluster_func.getAttrs());
cluster_func->getAttrs());
cluster_func.replaceAllUsesWith(new_cluster_func);
FuncOp func = cluster_func.getFunc();
Block& block = func.front();

View File

@ -557,7 +557,7 @@ Status ImporterBase::RemoveBackedges(const Graph& graph) {
graph_ = absl::make_unique<Graph>(graph.flib_def());
GraphConstructorOptions opts;
opts.allow_internal_ops = true;
opts.add_default_attributes = false;
opts.add_default_attributes = true;
TF_RETURN_IF_ERROR(::tensorflow::ConvertGraphDefToGraph(
opts, std::move(graph_def), graph_.get()));
