Mirror of https://github.com/zebrajr/tensorflow.git (synced 2025-12-06 12:20:11 +01:00)

Merge branch 'master' into aarch64_build_patch
This commit is contained in: commit 1e15aa8c83

.bazelrc (185 changed lines)
@@ -84,16 +84,13 @@
# elinux_armhf: Embedded Linux options for armhf (ARMv7) CPU support.
|
||||
#
|
||||
# Release build options (for all operating systems)
|
||||
# release_common: Common options for all builds on all operating systems.
|
||||
# release_windows_common: Common options for all builds on Windows.
|
||||
# release_gpu_common: Common options for GPU builds on Linux and Windows.
|
||||
# release_cpu_linux: Toolchain and CUDA options for Linux CPU builds.
|
||||
# release_cpu_macos: Toolchain and CUDA options for MacOS CPU builds.
|
||||
# release_gpu_linux: Toolchain and CUDA options for Linux GPU builds.
|
||||
# release_gpu_linux_cuda_10_1: Toolchain and CUDA options for CUDA 10.1 Linux GPU builds.
|
||||
# release_gpu_linux_cuda_11_2: Toolchain and CUDA options for CUDA 11.2 Linux GPU builds.
|
||||
# release_cpu_windows: Toolchain and CUDA options for Windows CPU builds.
|
||||
# release_gpu_windows: Toolchain and CUDA options for Windows GPU builds.
|
||||
# release_base: Common options for all builds on all operating systems.
|
||||
# release_gpu_base: Common options for GPU builds on Linux and Windows.
|
||||
# release_cpu_linux: Toolchain and CUDA options for Linux CPU builds.
|
||||
# release_cpu_macos: Toolchain and CUDA options for MacOS CPU builds.
|
||||
# release_gpu_linux: Toolchain and CUDA options for Linux GPU builds.
|
||||
# release_cpu_windows: Toolchain and CUDA options for Windows CPU builds.
|
||||
# release_gpu_windows: Toolchain and CUDA options for Windows GPU builds.
|
||||
|
||||
# Allow builds using libc++ as a linker library
|
||||
# This is mostly for OSSFuzz, so we also pass in the flags from environment to clean build file
|
||||
@@ -123,7 +120,13 @@ build:android_x86_64 --cpu=x86_64
build:android_x86_64 --fat_apk_cpu=x86_64
|
||||
|
||||
# Sets the default Apple platform to macOS.
|
||||
build --apple_platform_type=macos
|
||||
build:macos --apple_platform_type=macos
|
||||
|
||||
# gRPC on MacOS requires this #define
|
||||
build:macos --copt=-DGRPC_BAZEL_BUILD
|
||||
|
||||
# Settings for MacOS on ARM CPUs.
|
||||
build:macos_arm64 --cpu=darwin_arm64
|
||||
|
||||
# iOS configs for each architecture and the fat binary builds.
|
||||
build:ios --apple_platform_type=ios
|
||||
@@ -140,10 +143,10 @@ build:ios_x86_64 --cpu=ios_x86_64
build:ios_fat --config=ios
|
||||
build:ios_fat --ios_multi_cpus=armv7,arm64,i386,x86_64
|
||||
|
||||
# Enables all the macos config options for macos_arm64
|
||||
build:macos_arm64 --config=macos
|
||||
build:macos_arm64 --apple_platform_type=macos
|
||||
build:macos_arm64 --cpu=darwin_arm64
|
||||
# For projects which use TensorFlow as part of a Bazel build process, putting
|
||||
# nothing in a bazelrc will default to a monolithic build. The following line
|
||||
# opts in to modular op registration support by default.
|
||||
build --define framework_shared_object=true
|
||||
|
||||
# Config to use a mostly-static build and disable modular op registration
|
||||
# support (this will revert to loading TensorFlow with RTLD_GLOBAL in Python).
|
||||
@@ -151,11 +154,6 @@ build:macos_arm64 --cpu=darwin_arm64
# //tensorflow:libtensorflow_framework.so.
|
||||
build:monolithic --define framework_shared_object=false
|
||||
|
||||
# For projects which use TensorFlow as part of a Bazel build process, putting
|
||||
# nothing in a bazelrc will default to a monolithic build. The following line
|
||||
# opts in to modular op registration support by default.
|
||||
build --define framework_shared_object=true
|
||||
|
||||
# For workaround https://github.com/bazelbuild/bazel/issues/8772 with Bazel >= 0.29.1
|
||||
build --java_toolchain=@org_tensorflow//third_party/toolchains/java:tf_java_toolchain
|
||||
build --host_java_toolchain=@org_tensorflow//third_party/toolchains/java:tf_java_toolchain
|
||||
@@ -197,8 +195,8 @@ build:cuda_clang --config=cuda
build:cuda_clang --repo_env TF_CUDA_CLANG=1
|
||||
build:cuda_clang --@local_config_cuda//:cuda_compiler=clang
|
||||
|
||||
# dbg config, as a shorthand for '--config=opt -c dbg'
|
||||
build:dbg --config=opt -c dbg
|
||||
# Debug config
|
||||
build:dbg -c dbg
|
||||
# for now, disable arm_neon. see: https://github.com/tensorflow/tensorflow/issues/33360
|
||||
build:dbg --cxxopt -DTF_LITE_DISABLE_X86_NEON
|
||||
# AWS SDK must be compiled in release mode. see: https://github.com/tensorflow/tensorflow/issues/37498
|
||||
@@ -266,12 +264,10 @@ build:c++1z --config=c++17
build:c++17_gcc --cxxopt=-std=c++1z
|
||||
build:c++1z_gcc --config=c++17_gcc
|
||||
|
||||
# Enable using platform specific build settings, except when cross-compiling for
|
||||
# mobile platforms.
|
||||
# Trigger --config=<host platform>, except when cross-compiling.
|
||||
build --enable_platform_specific_config
|
||||
build:android --noenable_platform_specific_config
|
||||
build:ios --noenable_platform_specific_config
|
||||
build:macos_arm64 --noenable_platform_specific_config
|
||||
|
||||
# Suppress C++ compiler warnings, otherwise build logs become 10s of MBs.
|
||||
build:android --copt=-w
|
||||
@@ -357,15 +353,15 @@ build:avx_win --copt=/arch=AVX
build:avx2_win --copt=/arch=AVX2
|
||||
|
||||
# Options to build TensorFlow 1.x or 2.x.
|
||||
build:v1 --define=tf_api_version=1
|
||||
build:v2 --define=tf_api_version=2
|
||||
build:v1 --action_env=TF2_BEHAVIOR=0
|
||||
build:v2 --action_env=TF2_BEHAVIOR=1
|
||||
build:v1 --define=tf_api_version=1 --action_env=TF2_BEHAVIOR=0
|
||||
build:v2 --define=tf_api_version=2 --action_env=TF2_BEHAVIOR=1
|
||||
build --config=v2
|
||||
test --config=v2
|
||||
|
||||
# Enable XLA
|
||||
build:xla --define=with_xla_support=true
|
||||
# Enable XLA except on mobile.
|
||||
build --config=xla
|
||||
build:xla --define=with_xla_support=true
|
||||
build:android --define=with_xla_support=false
|
||||
build:ios --define=with_xla_support=false
|
||||
|
||||
# BEGIN TF REMOTE BUILD EXECUTION OPTIONS
|
||||
# Options when using remote execution
|
||||
@@ -399,7 +395,6 @@ build:rbe_linux --host_java_toolchain=@bazel_tools//tools/jdk:toolchain_hostjdk8
build:rbe_linux --java_toolchain=@bazel_tools//tools/jdk:toolchain_hostjdk8
|
||||
|
||||
# Non-rbe settings we should include because we do not run configure
|
||||
build:rbe_linux --config=xla
|
||||
build:rbe_linux --config=avx_linux
|
||||
build:rbe_linux --config=short_logs
|
||||
# TODO(gunan): Check why we need this specified in rbe, but not in other builds.
|
||||
@@ -423,45 +418,9 @@ build:rbe_linux_cuda_base --config=tensorrt
build:rbe_linux_cuda_base --repo_env=REMOTE_GPU_TESTING=1
|
||||
test:rbe_linux_cuda_base --test_env=LD_LIBRARY_PATH="/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64"
|
||||
|
||||
build:rbe_linux_cuda10.1_nvcc_base --config=rbe_linux_cuda_base
|
||||
build:rbe_linux_cuda10.1_nvcc_base --action_env=TF_CUDA_VERSION=10
|
||||
build:rbe_linux_cuda10.1_nvcc_base --action_env=TF_CUDNN_VERSION=7
|
||||
build:rbe_linux_cuda10.1_nvcc_base --host_crosstool_top="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain"
|
||||
build:rbe_linux_cuda10.1_nvcc_base --crosstool_top="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain"
|
||||
build:rbe_linux_cuda10.1_nvcc_base --extra_toolchains="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain-linux-x86_64"
|
||||
build:rbe_linux_cuda10.1_nvcc_base --extra_execution_platforms="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
|
||||
build:rbe_linux_cuda10.1_nvcc_base --host_platform="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
|
||||
build:rbe_linux_cuda10.1_nvcc_base --platforms="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
|
||||
build:rbe_linux_cuda10.1_nvcc_base --repo_env=TF_CUDA_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda"
|
||||
build:rbe_linux_cuda10.1_nvcc_base --repo_env=TF_TENSORRT_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_tensorrt"
|
||||
build:rbe_linux_cuda10.1_nvcc_base --repo_env=TF_NCCL_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_nccl"
|
||||
build:rbe_linux_cuda10.1_nvcc_py2.7 --config=rbe_linux_cuda10.1_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python2.7"
|
||||
build:rbe_linux_cuda10.1_nvcc_py3.5 --config=rbe_linux_cuda10.1_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.5"
|
||||
build:rbe_linux_cuda10.1_nvcc_py3.6 --config=rbe_linux_cuda10.1_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.6"
|
||||
build:rbe_linux_cuda10.1_nvcc_py3.7 --config=rbe_linux_cuda10.1_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.7"
|
||||
build:rbe_linux_cuda10.1_nvcc_py3.8 --config=rbe_linux_cuda10.1_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.8"
|
||||
|
||||
build:rbe_linux_cuda11.0_nvcc_base --config=rbe_linux_cuda_base
|
||||
build:rbe_linux_cuda11.0_nvcc_base --action_env=TF_CUDA_VERSION=11
|
||||
build:rbe_linux_cuda11.0_nvcc_base --action_env=TF_CUDNN_VERSION=8
|
||||
build:rbe_linux_cuda11.0_nvcc_base --host_crosstool_top="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_cuda//crosstool:toolchain"
|
||||
build:rbe_linux_cuda11.0_nvcc_base --crosstool_top="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_cuda//crosstool:toolchain"
|
||||
build:rbe_linux_cuda11.0_nvcc_base --extra_toolchains="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_cuda//crosstool:toolchain-linux-x86_64"
|
||||
build:rbe_linux_cuda11.0_nvcc_base --extra_execution_platforms="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_platform//:platform"
|
||||
build:rbe_linux_cuda11.0_nvcc_base --host_platform="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_platform//:platform"
|
||||
build:rbe_linux_cuda11.0_nvcc_base --platforms="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_platform//:platform"
|
||||
build:rbe_linux_cuda11.0_nvcc_base --repo_env=TF_CUDA_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_cuda"
|
||||
build:rbe_linux_cuda11.0_nvcc_base --repo_env=TF_TENSORRT_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_tensorrt"
|
||||
build:rbe_linux_cuda11.0_nvcc_base --repo_env=TF_NCCL_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_nccl"
|
||||
build:rbe_linux_cuda11.0_nvcc_py2.7 --config=rbe_linux_cuda11.0_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_python2.7"
|
||||
build:rbe_linux_cuda11.0_nvcc_py3.5 --config=rbe_linux_cuda11.0_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_python3.5"
|
||||
build:rbe_linux_cuda11.0_nvcc_py3.6 --config=rbe_linux_cuda11.0_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_python3.6"
|
||||
build:rbe_linux_cuda11.0_nvcc_py3.7 --config=rbe_linux_cuda11.0_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_python3.7"
|
||||
build:rbe_linux_cuda11.0_nvcc_py3.8 --config=rbe_linux_cuda11.0_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_python3.8"
|
||||
|
||||
build:rbe_linux_cuda11.2_nvcc_base --config=rbe_linux_cuda_base
|
||||
build:rbe_linux_cuda11.2_nvcc_base --action_env=TF_CUDA_VERSION=11.2
|
||||
build:rbe_linux_cuda11.2_nvcc_base --action_env=TF_CUDNN_VERSION=8.1
|
||||
build:rbe_linux_cuda11.2_nvcc_base --action_env=TF_CUDA_VERSION=11
|
||||
build:rbe_linux_cuda11.2_nvcc_base --action_env=TF_CUDNN_VERSION=8
|
||||
build:rbe_linux_cuda11.2_nvcc_base --host_crosstool_top="@ubuntu18.04-gcc7_manylinux2010-cuda11.2-cudnn8.1-tensorrt7.2_config_cuda//crosstool:toolchain"
|
||||
build:rbe_linux_cuda11.2_nvcc_base --crosstool_top="@ubuntu18.04-gcc7_manylinux2010-cuda11.2-cudnn8.1-tensorrt7.2_config_cuda//crosstool:toolchain"
|
||||
build:rbe_linux_cuda11.2_nvcc_base --extra_toolchains="@ubuntu18.04-gcc7_manylinux2010-cuda11.2-cudnn8.1-tensorrt7.2_config_cuda//crosstool:toolchain-linux-x86_64"
|
||||
@@ -471,25 +430,24 @@ build:rbe_linux_cuda11.2_nvcc_base --platforms="@ubuntu18.04-gcc7_manylinux2010-
build:rbe_linux_cuda11.2_nvcc_base --repo_env=TF_CUDA_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda11.2-cudnn8.1-tensorrt7.2_config_cuda"
|
||||
build:rbe_linux_cuda11.2_nvcc_base --repo_env=TF_TENSORRT_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda11.2-cudnn8.1-tensorrt7.2_config_tensorrt"
|
||||
build:rbe_linux_cuda11.2_nvcc_base --repo_env=TF_NCCL_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda11.2-cudnn8.1-tensorrt7.2_config_nccl"
|
||||
build:rbe_linux_cuda11.2_nvcc_py3.5 --config=rbe_linux_cuda11.2_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda11.2-cudnn8.1-tensorrt7.2_config_python3.5"
|
||||
build:rbe_linux_cuda11.2_nvcc_py3.6 --config=rbe_linux_cuda11.2_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda11.2-cudnn8.1-tensorrt7.2_config_python3.6"
|
||||
build:rbe_linux_cuda11.2_nvcc_py3.7 --config=rbe_linux_cuda11.2_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda11.2-cudnn8.1-tensorrt7.2_config_python3.7"
|
||||
build:rbe_linux_cuda11.2_nvcc_py3.8 --config=rbe_linux_cuda11.2_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda11.2-cudnn8.1-tensorrt7.2_config_python3.8"
|
||||
build:rbe_linux_cuda11.2_nvcc_py3.9 --config=rbe_linux_cuda11.2_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda11.2-cudnn8.1-tensorrt7.2_config_python3.9"
|
||||
|
||||
# Map default to CUDA 11.2 for PY35 and greater.
|
||||
build:rbe_linux_cuda_nvcc_py27 --config=rbe_linux_cuda10.1_nvcc_py2.7
|
||||
build:rbe_linux_cuda_nvcc_py35 --config=rbe_linux_cuda11.2_nvcc_py3.5
|
||||
# Map default to CUDA 11.2.
|
||||
build:rbe_linux_cuda_nvcc_py36 --config=rbe_linux_cuda11.2_nvcc_py3.6
|
||||
build:rbe_linux_cuda_nvcc_py37 --config=rbe_linux_cuda11.2_nvcc_py3.7
|
||||
build:rbe_linux_cuda_nvcc_py38 --config=rbe_linux_cuda11.2_nvcc_py3.8
|
||||
build:rbe_linux_cuda_nvcc_py39 --config=rbe_linux_cuda11.2_nvcc_py3.9
|
||||
|
||||
# Deprecated configs that people might still use.
|
||||
build:rbe_linux_cuda_nvcc --config=rbe_linux_cuda_nvcc_py36
|
||||
build:rbe_gpu_linux --config=rbe_linux_cuda_nvcc
|
||||
|
||||
build:rbe_linux_cuda_clang_base --config=rbe_linux_cuda_base
|
||||
build:rbe_linux_cuda_clang_base --action_env=TF_CUDA_VERSION=11.2
|
||||
build:rbe_linux_cuda_clang_base --action_env=TF_CUDNN_VERSION=8.1
|
||||
build:rbe_linux_cuda_clang_base --action_env=TF_CUDA_VERSION=11
|
||||
build:rbe_linux_cuda_clang_base --action_env=TF_CUDNN_VERSION=8
|
||||
build:rbe_linux_cuda_clang_base --crosstool_top="@ubuntu18.04-clang_manylinux2010-cuda11.2-cudnn8.1-tensorrt7.2_config_cuda//crosstool:toolchain"
|
||||
build:rbe_linux_cuda_clang_base --extra_toolchains="@ubuntu18.04-clang_manylinux2010-cuda11.2-cudnn8.1-tensorrt7.2_config_cuda//crosstool:toolchain-linux-x86_64"
|
||||
build:rbe_linux_cuda_clang_base --extra_execution_platforms="@ubuntu18.04-clang_manylinux2010-cuda11.2-cudnn8.1-tensorrt7.2_config_platform//:platform"
|
||||
@@ -583,65 +541,42 @@ try-import %workspace%/.tf_configure.bazelrc
try-import %workspace%/.bazelrc.user
|
||||
|
||||
# Here are bazelrc configs for release builds
|
||||
build:release_common --config=opt
|
||||
build:release_common --config=v2
|
||||
build:release_common --distinct_host_configuration=false
|
||||
build:release_common --action_env TF_CONFIGURE_IOS="0"
|
||||
build:release_base --config=v2
|
||||
build:release_base --distinct_host_configuration=false
|
||||
test:release_base --flaky_test_attempts=3
|
||||
test:release_base --test_size_filters=small,medium
|
||||
|
||||
build:release_cpu_linux --config=release_common
|
||||
build:release_cpu_linux --config=release_base
|
||||
build:release_cpu_linux --config=avx_linux
|
||||
# We use the same toolchain for CPU/GPU packages.
|
||||
# Did not add this to the defaults in case this changes.
|
||||
build:release_cpu_linux --crosstool_top=@org_tensorflow//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda10.1:toolchain
|
||||
build:release_cpu_linux --crosstool_top=@org_tensorflow//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda11.2:toolchain
|
||||
test:release_cpu_linux --test_env=LD_LIBRARY_PATH
|
||||
|
||||
build:release_cpu_macos --config=release_common
|
||||
build:release_cpu_macos --config=release_base
|
||||
build:release_cpu_macos --config=avx_linux
|
||||
|
||||
build:release_gpu_common --config=release_common
|
||||
build:release_gpu_common --config=cuda
|
||||
build:release_gpu_common --config=tensorrt
|
||||
build:release_gpu_common --action_env CUDA_TOOLKIT_PATH="/usr/local/cuda-11.2"
|
||||
build:release_gpu_common --action_env=TF_CUDA_VERSION="11.2"
|
||||
build:release_gpu_common --action_env=TF_CUDNN_VERSION="8.1"
|
||||
build:release_gpu_common --action_env=TF_NEED_TENSORRT="1"
|
||||
build:release_gpu_common --action_env=TF_CUDA_COMPUTE_CAPABILITIES="sm_35,sm_50,sm_60,sm_70,sm_75,compute_80"
|
||||
build:release_gpu_common --action_env=TENSORRT_INSTALL_PATH="/usr/local/tensorrt"
|
||||
build:release_gpu_common --action_env=LD_LIBRARY_PATH="/usr/local/cuda:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:/usr/local/tensorrt/lib"
|
||||
build:release_gpu_common --action_env=GCC_HOST_COMPILER_PATH="/usr/bin/gcc-5"
|
||||
build:release_gpu_base --config=release_base
|
||||
build:release_gpu_base --config=cuda
|
||||
build:release_gpu_base --action_env=TF_CUDA_VERSION="11"
|
||||
build:release_gpu_base --action_env=TF_CUDNN_VERSION="8"
|
||||
build:release_gpu_base --repo_env=TF_CUDA_COMPUTE_CAPABILITIES="sm_35,sm_50,sm_60,sm_70,sm_75,compute_80"
|
||||
|
||||
|
||||
build:release_gpu_linux --config=release_gpu_common
|
||||
build:release_gpu_linux --config=avx_linux
|
||||
build:release_gpu_linux --config=release_cpu_linux
|
||||
build:release_gpu_linux --config=release_gpu_base
|
||||
build:release_gpu_linux --config=tensorrt
|
||||
build:release_gpu_linux --action_env=CUDA_TOOLKIT_PATH="/usr/local/cuda-11.2"
|
||||
build:release_gpu_linux --action_env=LD_LIBRARY_PATH="/usr/local/cuda:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:/usr/local/tensorrt/lib"
|
||||
build:release_gpu_linux --action_env=GCC_HOST_COMPILER_PATH="/usr/bin/gcc-5"
|
||||
build:release_gpu_linux --crosstool_top=@org_tensorflow//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda11.2:toolchain
|
||||
build:release_windows_common --config=release_common
|
||||
build:release_windows_common --define=no_tensorflow_py_deps=true
|
||||
build:release_windows_common --announce_rc
|
||||
|
||||
build:release_cpu_windows --config=release_base
|
||||
build:release_cpu_windows --config=avx_win
|
||||
build:release_cpu_windows --define=no_tensorflow_py_deps=true
|
||||
# First available in VS 16.4. Speeds Windows compile times by a lot. See
|
||||
# https://groups.google.com/a/tensorflow.org/d/topic/build/SsW98Eo7l3o/discussion
|
||||
build:release_windows_common --copt=/d2ReducedOptimizeHugeFunctions --host_copt=/d2ReducedOptimizeHugeFunctions
|
||||
build:release_cpu_windows --copt=/d2ReducedOptimizeHugeFunctions --host_copt=/d2ReducedOptimizeHugeFunctions
|
||||
|
||||
build:release_cpu_windows --config=release_windows_common
|
||||
|
||||
build:release_gpu_windows --config=release_windows_common
|
||||
|
||||
build:release_gpu_linux_cuda_10_1 --config=release_gpu_linux
|
||||
build:release_gpu_linux_cuda_10_1 --action_env CUDA_TOOLKIT_PATH="/usr/local/cuda-10.1"
|
||||
build:release_gpu_linux_cuda_10_1 --crosstool_top=@org_tensorflow//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda10.1:toolchain
|
||||
build:release_gpu_linux_cuda_10_1 --action_env=TF_CUDA_VERSION="10"
|
||||
build:release_gpu_linux_cuda_10_1 --action_env=TF_CUDNN_VERSION="7"
|
||||
|
||||
build:release_gpu_linux_cuda_11 --config=release_gpu_linux
|
||||
build:release_gpu_linux_cuda_11 --action_env CUDA_TOOLKIT_PATH="/usr/local/cuda-11.0"
|
||||
build:release_gpu_linux_cuda_11 --crosstool_top=@org_tensorflow//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda11:toolchain
|
||||
build:release_gpu_linux_cuda_11 --action_env=TF_CUDA_VERSION="11"
|
||||
build:release_gpu_linux_cuda_11 --action_env=TF_CUDNN_VERSION="8"
|
||||
|
||||
build:release_gpu_linux_cuda_11_2 --config=release_gpu_linux
|
||||
build:release_gpu_linux_cuda_11_2 --action_env CUDA_TOOLKIT_PATH="/usr/local/cuda-11.2"
|
||||
build:release_gpu_linux_cuda_11_2 --crosstool_top=@org_tensorflow//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda11.2:toolchain
|
||||
build:release_gpu_linux_cuda_11_2 --action_env=TF_CUDA_VERSION="11.2"
|
||||
build:release_gpu_linux_cuda_11_2 --action_env=TF_CUDNN_VERSION="8.1"
|
||||
build:release_gpu_windows --config=release_cpu_windows
|
||||
build:release_gpu_windows --config=release_gpu_base
|
||||
|
||||
# Address sanitizer
|
||||
# CC=clang bazel build --config asan

RELEASE.md (14 changed lines)
@@ -39,11 +39,23 @@
`num_parallel_calls` set, `deterministic` is used to indicate that
|
||||
outputs can be obtained in the non-deterministic order.
|
||||
* Options returned by `tf.data.Dataset.options()` are no longer mutable.
|
||||
* tf.data input pipelines can now be executed in debug mode, which
|
||||
disables any asynchrony, parallelism, or non-determinism and forces
|
||||
Python execution (as opposed to trace-compiled graph execution) of
|
||||
user-defined functions passed into transformations such as `map`. The
|
||||
debug mode can be enabled through `tf.data.experimental.enable_debug_mode()`.
|
||||
* `tf.lite`
|
||||
* Enabled the new MLIR-based quantization backend by default
|
||||
* The new backend is used for 8 bits full integer post-training quantization
|
||||
* The new backend removes the redundant rescales and fixes some bugs (shared weight/bias, extremely small scales, etc)
|
||||
* Set `experimental_new_quantizer` in tf.lite.TFLiteConverter to False to disable this change
|
||||
* `tf.keras`
|
||||
* Enabled a new supported input type in `Model.fit`,
|
||||
`tf.keras.utils.experimental.DatasetCreator`, which takes a
|
||||
callable, `dataset_fn`.
|
||||
`DatasetCreator` is intended to work across all `tf.distribute`
|
||||
strategies, and is the only input type supported for Parameter Server
|
||||
strategy.
|
||||
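To ground the tf.data debug-mode bullet above, a minimal sketch, assuming TensorFlow 2.5+ where `tf.data.experimental.enable_debug_mode()` is available; the pipeline itself is illustrative only:

import tensorflow as tf

# Debug mode must be enabled before any tf.data pipeline is built.
tf.data.experimental.enable_debug_mode()

def double(x):
    # In debug mode, user functions run eagerly, so .numpy(), print and pdb work.
    print("processing", x.numpy())
    return x * 2

ds = tf.data.Dataset.range(4).map(double)
print(list(ds.as_numpy_iterator()))  # [0, 2, 4, 6]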
|
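The `tf.lite` bullet above enables the MLIR-based quantizer by default, with `experimental_new_quantizer = False` as the opt-out. A hedged post-training full-integer quantization sketch; the toy model and representative dataset are placeholders:

import numpy as np
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(4, input_shape=(8,))])

def representative_data():
    for _ in range(10):
        yield [np.random.rand(1, 8).astype(np.float32)]

converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_data
# The new MLIR quantization backend is on by default; set this to False only to
# fall back to the old quantizer if a model regresses.
converter.experimental_new_quantizer = False
tflite_model = converter.convert()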
||||
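A sketch of the `DatasetCreator` input type from the `tf.keras` bullet, assuming TensorFlow 2.5+; the toy model and `dataset_fn` are illustrative, and `steps_per_epoch` is passed because Keras cannot infer cardinality from a callable:

import tensorflow as tf

def dataset_fn(input_context):
    # Under tf.distribute, input_context carries per-worker sharding/batching info.
    ds = tf.data.Dataset.from_tensor_slices(([[1.], [2.], [3.], [4.]],
                                             [[2.], [4.], [6.], [8.]]))
    return ds.repeat().batch(2)

model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(1,))])
model.compile(optimizer="sgd", loss="mse")

# Keras calls dataset_fn (once per worker) to build the input pipeline.
model.fit(tf.keras.utils.experimental.DatasetCreator(dataset_fn),
          epochs=2, steps_per_epoch=4)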
## Bug Fixes and Other Changes
|
||||
|
||||
@@ -74,6 +86,8 @@
* Add `tf.data.experimental.AutoShardingPolicy.HINT` which can be used
|
||||
to provide hints to tf.distribute-based auto-sharding as to where in
|
||||
the input pipeline to insert sharding transformations.
|
||||
* Make tf.data.Options persistent across `tf.function` and `GraphDef`
|
||||
boundaries.
|
||||
* XLA compilation:
|
||||
* `tf.function(experimental_compile=True)` has become a stable API,
|
||||
renamed `tf.function(jit_compile=True)`.
|
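To make the XLA bullet above concrete: `jit_compile=True` is the stable spelling of the older `experimental_compile=True` argument. A minimal, self-contained example with illustrative values:

import tensorflow as tf

@tf.function(jit_compile=True)  # compile the whole function with XLA
def scaled_matmul(a, b):
    return 2.0 * tf.matmul(a, b)

x = tf.random.normal((4, 4))
print(scaled_matmul(x, x))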
||||
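For the `AutoShardingPolicy.HINT` bullet, a short sketch of attaching the policy through `tf.data.Options`; the range dataset is a placeholder and the option only takes effect once the dataset is distributed under a `tf.distribute` strategy:

import tensorflow as tf

options = tf.data.Options()
# HINT asks tf.distribute's auto-sharding to honor sharding hints placed in the
# pipeline rather than choosing FILE or DATA sharding on its own.
options.experimental_distribute.auto_shard_policy = (
    tf.data.experimental.AutoShardingPolicy.HINT)

ds = tf.data.Dataset.range(8).with_options(options)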

configure.py (22 changed lines)
@@ -542,7 +542,6 @@ def set_cc_opt_flags(environ_cp):
for opt in cc_opt_flags.split():
|
||||
write_to_bazelrc('build:opt --copt=%s' % opt)
|
||||
write_to_bazelrc('build:opt --host_copt=%s' % opt)
|
||||
write_to_bazelrc('build:opt --define with_default_optimizations=true')
|
||||
|
||||
|
||||
def set_tf_cuda_clang(environ_cp):
|
||||
@@ -1202,13 +1201,12 @@ def config_info_line(name, help_text):
print('\t--config=%-12s\t# %s' % (name, help_text))
|
||||
|
||||
|
||||
def configure_ios():
|
||||
"""Configures TensorFlow for iOS builds.
|
||||
|
||||
This function will only be executed if `is_macos()` is true.
|
||||
"""
|
||||
def configure_ios(environ_cp):
|
||||
"""Configures TensorFlow for iOS builds."""
|
||||
if not is_macos():
|
||||
return
|
||||
if not get_var(environ_cp, 'TF_CONFIGURE_IOS', 'iOS', False):
|
||||
return
|
||||
for filepath in APPLE_BAZEL_FILES:
|
||||
existing_filepath = os.path.join(_TF_WORKSPACE_ROOT, filepath + '.apple')
|
||||
renamed_filepath = os.path.join(_TF_WORKSPACE_ROOT, filepath)
|
||||
@@ -1327,11 +1325,11 @@ def main():
|
||||
if is_macos():
|
||||
environ_cp['TF_NEED_TENSORRT'] = '0'
|
||||
else:
|
||||
environ_cp['TF_CONFIGURE_IOS'] = '0'
|
||||
|
||||
if environ_cp.get('TF_ENABLE_XLA', '1') == '1':
|
||||
write_to_bazelrc('build --config=xla')
|
||||
with_xla_support = environ_cp.get('TF_ENABLE_XLA', None)
|
||||
if with_xla_support is not None:
|
||||
write_to_bazelrc('build --define=with_xla_support=%s' % (
|
||||
'true' if int(with_xla_support) else 'false'))
|
||||
|
||||
set_action_env_var(
|
||||
environ_cp, 'TF_NEED_ROCM', 'ROCm', False, bazel_config_name='rocm')
|
||||
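The hunk above replaces the unconditional `build --config=xla` line with an explicit `--define=with_xla_support=<bool>` derived from `TF_ENABLE_XLA`. A standalone sketch of that logic; `write_to_bazelrc` here is a stand-in for configure.py's helper, which appends to the generated `.tf_configure.bazelrc`:

import os

def write_to_bazelrc(line, path=".tf_configure.bazelrc"):
    # Stand-in helper: append one line to the generated bazelrc fragment.
    with open(path, "a") as f:
        f.write(line + "\n")

environ_cp = dict(os.environ)

with_xla_support = environ_cp.get("TF_ENABLE_XLA", None)
if with_xla_support is not None:
    write_to_bazelrc("build --define=with_xla_support=%s" %
                     ("true" if int(with_xla_support) else "false"))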
@@ -1450,9 +1448,7 @@ def main():
|
||||
system_specific_test_config(environ_cp)
|
||||
|
||||
set_action_env_var(environ_cp, 'TF_CONFIGURE_IOS', 'iOS', False)
|
||||
if environ_cp.get('TF_CONFIGURE_IOS') == '1':
|
||||
configure_ios()
|
||||
configure_ios(environ_cp)
|
||||
|
||||
print('Preconfigured Bazel build configs. You can use any of the below by '
|
||||
'adding "--config=<>" to your build command. See .bazelrc for more '
|
||||
|
|
|
|||
@@ -936,6 +936,7 @@ tf_cc_shared_object(
"//tensorflow/core/common_runtime:core_cpu_impl",
|
||||
"//tensorflow/core:framework_internal_impl",
|
||||
"//tensorflow/core/common_runtime/gpu:gpu_runtime_impl",
|
||||
"//tensorflow/core/common_runtime/pluggable_device:pluggable_device_runtime_impl",
|
||||
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry_impl",
|
||||
"//tensorflow/core:lib_internal_impl",
|
||||
"//tensorflow/core/profiler:profiler_impl",
|
||||
|
|
|
|||
@@ -85,13 +85,12 @@ if _module_dir:
setattr(_current_module, "estimator", estimator)
|
||||
|
||||
if _os.environ.get("_PREFER_OSS_KERAS", False):
|
||||
try:
|
||||
from keras.api._v2 import keras
|
||||
_current_module.__path__ = (
|
||||
[_module_util.get_parent_dir(keras)] + _current_module.__path__)
|
||||
setattr(_current_module, "keras", keras)
|
||||
except ImportError:
|
||||
pass
|
||||
_keras_module = "keras.api._v2.keras"
|
||||
keras = _LazyLoader("keras", globals(), _keras_module)
|
||||
_module_dir = _module_util.get_parent_dir_for_name(_keras_module)
|
||||
if _module_dir:
|
||||
_current_module.__path__ = [_module_dir] + _current_module.__path__
|
||||
setattr(_current_module, "keras", keras)
|
||||
else:
|
||||
try:
|
||||
from .python.keras.api._v2 import keras
|
||||
@@ -160,14 +159,33 @@ if _running_from_pip_package():
|
||||
# Add module aliases
|
||||
if hasattr(_current_module, 'keras'):
|
||||
losses = keras.losses
|
||||
metrics = keras.metrics
|
||||
optimizers = keras.optimizers
|
||||
initializers = keras.initializers
|
||||
setattr(_current_module, "losses", losses)
|
||||
setattr(_current_module, "metrics", metrics)
|
||||
setattr(_current_module, "optimizers", optimizers)
|
||||
setattr(_current_module, "initializers", initializers)
|
||||
# It is possible that keras is a lazily loaded module, which might break when
|
||||
# actually trying to import it. Have a Try-Catch to make sure it doesn't break
|
||||
# when it doing some very initial loading, like tf.compat.v2, etc.
|
||||
if _os.environ.get("_PREFER_OSS_KERAS", False):
|
||||
try:
|
||||
_keras_package = "keras.api._v2.keras."
|
||||
losses = _LazyLoader("losses", globals(), _keras_package + "losses")
|
||||
metrics = _LazyLoader("metrics", globals(), _keras_package + "metrics")
|
||||
optimizers = _LazyLoader(
|
||||
"optimizers", globals(), _keras_package + "optimizers")
|
||||
initializers = _LazyLoader(
|
||||
"initializers", globals(), _keras_package + "initializers")
|
||||
setattr(_current_module, "losses", losses)
|
||||
setattr(_current_module, "metrics", metrics)
|
||||
setattr(_current_module, "optimizers", optimizers)
|
||||
setattr(_current_module, "initializers", initializers)
|
||||
except ImportError:
|
||||
pass
|
||||
else:
|
||||
losses = keras.losses
|
||||
metrics = keras.metrics
|
||||
optimizers = keras.optimizers
|
||||
initializers = keras.initializers
|
||||
setattr(_current_module, "losses", losses)
|
||||
setattr(_current_module, "metrics", metrics)
|
||||
setattr(_current_module, "optimizers", optimizers)
|
||||
setattr(_current_module, "initializers", initializers)
|
||||
# pylint: enable=undefined-variable
|
||||
|
||||
# Delete modules that should be hidden from dir().
|
||||
|
|
|
|||
@@ -76,13 +76,12 @@ if _module_dir:
setattr(_current_module, "estimator", estimator)
|
||||
|
||||
if _os.environ.get("_PREFER_OSS_KERAS", False):
|
||||
try:
|
||||
from keras.api._v1 import keras
|
||||
_current_module.__path__ = (
|
||||
[_module_util.get_parent_dir(keras)] + _current_module.__path__)
|
||||
setattr(_current_module, "keras", keras)
|
||||
except ImportError:
|
||||
pass
|
||||
_keras_module = "keras.api._v1.keras"
|
||||
keras = _LazyLoader("keras", globals(), _keras_module)
|
||||
_module_dir = _module_util.get_parent_dir_for_name(_keras_module)
|
||||
if _module_dir:
|
||||
_current_module.__path__ = [_module_dir] + _current_module.__path__
|
||||
setattr(_current_module, "keras", keras)
|
||||
else:
|
||||
try:
|
||||
from .python.keras.api._v1 import keras
|
||||
|
|
|
|||
@@ -429,6 +429,7 @@ tf_cuda_library(
"//tensorflow/core/common_runtime/eager:context",
|
||||
"//tensorflow/core/common_runtime/eager:core",
|
||||
"//tensorflow/core/common_runtime/eager:eager_operation",
|
||||
"//tensorflow/core/common_runtime/pluggable_device:pluggable_device_plugin_init",
|
||||
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
|
||||
"//tensorflow/core/platform",
|
||||
"//tensorflow/core/platform:blocking_counter",
|
||||
|
|
|
|||
@@ -28,6 +28,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
|
||||
#include "tensorflow/core/common_runtime/eager/context.h"
|
||||
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
|
||||
#include "tensorflow/core/common_runtime/pluggable_device/pluggable_device_plugin_init.h"
|
||||
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
|
||||
#include "tensorflow/core/framework/collective.h"
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
@@ -735,7 +736,10 @@ TF_Library* TF_LoadPluggableDeviceLibrary(const char* library_filename,
} else {
|
||||
status->status =
|
||||
env->LoadDynamicLibrary(library_filename, &lib_handle->lib_handle);
|
||||
if (!status->status.ok()) {
|
||||
if (status->status.ok()) {
|
||||
TF_CHECK_OK(
|
||||
tensorflow::RegisterPluggableDevicePlugin(lib_handle->lib_handle));
|
||||
} else {
|
||||
delete lib_handle;
|
||||
return nullptr;
|
||||
}
|
||||
|
|
|
|||
@@ -236,6 +236,8 @@ TEST_F(ShapeInferenceTest, InfersShapesFromInputTensors) {
}
|
||||
|
||||
TEST(CAPI_EXPERIMENTAL, LibraryPluggableDeviceLoadFunctions) {
|
||||
// TODO(penpornk): Enable this test on Windows.
|
||||
#if !defined(PLATFORM_WINDOWS)
|
||||
#if !defined(TENSORFLOW_NO_SHARED_OBJECTS)
|
||||
// Load the library.
|
||||
TF_Status* status = TF_NewStatus();
|
||||
@@ -250,6 +252,7 @@ TEST(CAPI_EXPERIMENTAL, LibraryPluggableDeviceLoadFunctions) {
ASSERT_EQ(TF_OK, code) << status_msg;
|
||||
TF_DeletePluggableDeviceLibraryHandle(lib);
|
||||
#endif // !defined(TENSORFLOW_NO_SHARED_OBJECTS)
|
||||
#endif // !defined(PLATFORM_WINDOWS)
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
|
|
|||
@@ -32,11 +32,7 @@ cc_library(
],
|
||||
deps = [
|
||||
"//tensorflow/c/eager:abstract_context",
|
||||
"//tensorflow/c/eager:abstract_operation",
|
||||
"//tensorflow/c/eager:abstract_tensor_handle",
|
||||
"//tensorflow/c/eager:c_api_unified_internal",
|
||||
"//tensorflow/c/eager:gradients_internal",
|
||||
"//tensorflow/core/lib/llvm_rtti",
|
||||
],
|
||||
)
|
||||
|
||||
@@ -93,7 +89,6 @@ cc_library(
deps = [
|
||||
"//tensorflow/c/eager:abstract_context",
|
||||
"//tensorflow/c/eager:gradients_internal",
|
||||
"@com_google_absl//absl/types:span",
|
||||
],
|
||||
)
|
||||
|
||||
@@ -113,6 +108,7 @@ cc_library(
":math_grad",
|
||||
":nn_grad",
|
||||
":not_differentiable",
|
||||
"//tensorflow/c/eager:abstract_context",
|
||||
"//tensorflow/c/eager:gradients_internal",
|
||||
],
|
||||
)
|
||||
@@ -188,6 +184,8 @@ tf_cuda_cc_test(
":nn_grad",
|
||||
"//tensorflow/c:tf_status_helper",
|
||||
"//tensorflow/c/eager:c_api_test_util",
|
||||
"//tensorflow/c/eager:c_api_unified_internal",
|
||||
"//tensorflow/c/eager:unified_api_testutil",
|
||||
"//tensorflow/c/experimental/gradients/tape:tape_context",
|
||||
"//tensorflow/c/experimental/ops:nn_ops",
|
||||
"//tensorflow/core/platform:tensor_float_32_utils",
|
||||
@@ -213,6 +211,8 @@ tf_cuda_cc_test(
":math_grad",
|
||||
"//tensorflow/c:tf_status_helper",
|
||||
"//tensorflow/c/eager:c_api_test_util",
|
||||
"//tensorflow/c/eager:c_api_unified_internal",
|
||||
"//tensorflow/c/eager:unified_api_testutil",
|
||||
"//tensorflow/c/experimental/gradients/tape:tape_context",
|
||||
"//tensorflow/c/experimental/ops:math_ops",
|
||||
"//tensorflow/core/platform:tensor_float_32_utils",
|
||||
@@ -243,6 +243,8 @@ tf_cuda_cc_test(
"//tensorflow/core/platform:tensor_float_32_utils",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/c/eager:c_api_unified_internal",
|
||||
"//tensorflow/c/eager:unified_api_testutil",
|
||||
] + if_libtpu(
|
||||
if_false = ["//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration"],
|
||||
if_true = [],
|
||||
|
|
|
|||
@@ -17,8 +17,6 @@ cc_library(
deps = [
|
||||
":tape_operation",
|
||||
"//tensorflow/c/eager:abstract_context",
|
||||
"//tensorflow/c/eager:abstract_function",
|
||||
"//tensorflow/c/eager:abstract_operation",
|
||||
],
|
||||
)
|
||||
|
||||
@@ -33,7 +31,6 @@ cc_library(
],
|
||||
deps = [
|
||||
"//tensorflow/c/eager:abstract_context",
|
||||
"//tensorflow/c/eager:abstract_function",
|
||||
"//tensorflow/c/eager:abstract_operation",
|
||||
"//tensorflow/c/eager:gradients_internal",
|
||||
],
|
||||
@@ -51,6 +48,9 @@ cc_library(
deps = [
|
||||
":tape_context",
|
||||
":tape_operation",
|
||||
"//tensorflow/c/eager:abstract_context",
|
||||
"//tensorflow/c/eager:abstract_operation",
|
||||
"//tensorflow/c/eager:gradients_internal",
|
||||
],
|
||||
)
|
||||
|
||||
|
|
|
|||
@@ -60,7 +60,10 @@ cc_library(
"stream_executor.h",
|
||||
"stream_executor_internal.h",
|
||||
],
|
||||
visibility = ["//tensorflow/c:__subpackages__"],
|
||||
visibility = [
|
||||
"//tensorflow/c:__subpackages__",
|
||||
"//tensorflow/core/common_runtime/pluggable_device:__subpackages__",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/c:c_api_macros",
|
||||
"//tensorflow/c:tf_status",
|
||||
@@ -76,6 +79,7 @@ tf_cc_test(
deps = [
|
||||
":stream_executor",
|
||||
":stream_executor_internal",
|
||||
":stream_executor_test_util",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core/protobuf:error_codes_proto_impl_cc",
|
||||
@@ -84,3 +88,14 @@ tf_cc_test(
"//tensorflow/stream_executor:stream_executor_pimpl",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "stream_executor_test_util",
|
||||
srcs = ["stream_executor_test_util.cc"],
|
||||
hdrs = ["stream_executor_test_util.h"],
|
||||
visibility = ["//tensorflow:internal"],
|
||||
deps = [
|
||||
":stream_executor_hdrs",
|
||||
"//tensorflow/c:tf_status",
|
||||
],
|
||||
)
|
||||
|
|
|
|||
@@ -749,7 +749,9 @@ port::StatusOr<std::unique_ptr<StreamExecutor>> CPlatform::GetUncachedExecutor(
return result;
|
||||
}
|
||||
|
||||
port::Status InitStreamExecutorPlugin(void* dso_handle) {
|
||||
port::Status InitStreamExecutorPlugin(void* dso_handle,
|
||||
std::string* device_type,
|
||||
std::string* platform_name) {
|
||||
tensorflow::Env* env = tensorflow::Env::Default();
|
||||
|
||||
// Step 1: Load symbol for `TF_InitPlugin`
|
||||
@@ -759,10 +761,12 @@ port::Status InitStreamExecutorPlugin(void* dso_handle) {
|
||||
// Step 2: Call `TF_InitPlugin`
|
||||
auto init_fn = reinterpret_cast<SEInitPluginFn>(dso_symbol);
|
||||
return InitStreamExecutorPlugin(init_fn);
|
||||
return InitStreamExecutorPlugin(init_fn, device_type, platform_name);
|
||||
}
|
||||
|
||||
port::Status InitStreamExecutorPlugin(SEInitPluginFn init_fn) {
|
||||
port::Status InitStreamExecutorPlugin(SEInitPluginFn init_fn,
|
||||
std::string* device_type,
|
||||
std::string* platform_name) {
|
||||
SE_PlatformRegistrationParams params{
|
||||
SE_PLATFORM_REGISTRATION_PARAMS_STRUCT_SIZE};
|
||||
SP_Platform platform{SP_PLATFORM_STRUCT_SIZE};
|
||||
@@ -808,7 +812,8 @@ port::Status InitStreamExecutorPlugin(SEInitPluginFn init_fn) {
TF_RETURN_IF_ERROR(ValidateSPTimerFns(timer_fns));
|
||||
|
||||
// Register new platform
|
||||
std::string platform_name = std::string(platform.name);
|
||||
*device_type = std::string(platform.type);
|
||||
*platform_name = std::string(platform.name);
|
||||
std::unique_ptr<stream_executor::CPlatform> cplatform(
|
||||
new stream_executor::CPlatform(
|
||||
std::move(platform), params.destroy_platform, std::move(platform_fns),
|
||||
@@ -816,8 +821,8 @@ port::Status InitStreamExecutorPlugin(SEInitPluginFn init_fn) {
std::move(timer_fns)));
|
||||
SE_CHECK_OK(stream_executor::MultiPlatformManager::RegisterPlatform(
|
||||
std::move(cplatform)));
|
||||
|
||||
// TODO(annarev): Add pluggable device registration here.
|
||||
// TODO(annarev): Return `use_bfc_allocator` value in some way so that it is
|
||||
// available in `PluggableDeviceProcessState` once the latter is checked in.
|
||||
return port::Status::OK();
|
||||
}
|
||||
} // namespace stream_executor
|
||||
|
|
|
|||
@@ -431,10 +431,13 @@ typedef struct SP_Platform {
// Whether this platform supports unified memory.
|
||||
// Unified memory is a single memory address space accessible from any device.
|
||||
TF_Bool supports_unified_memory;
|
||||
|
||||
// Whether to wrap allocator for this device with an allocator that uses BFC
|
||||
// (best-fit with coalescing) strategy.
|
||||
TF_Bool use_bfc_allocator;
|
||||
} SP_Platform;
|
||||
|
||||
#define SP_PLATFORM_STRUCT_SIZE \
|
||||
TF_OFFSET_OF_END(SP_Platform, supports_unified_memory)
|
||||
#define SP_PLATFORM_STRUCT_SIZE TF_OFFSET_OF_END(SP_Platform, use_bfc_allocator)
|
||||
|
||||
typedef struct SP_PlatformFns {
|
||||
size_t struct_size;
|
||||
|
|
|
|||
@@ -31,12 +31,17 @@ namespace stream_executor {
typedef void (*SEInitPluginFn)(SE_PlatformRegistrationParams* const,
|
||||
TF_Status* const);
|
||||
|
||||
// Registers StreamExecutor platform.
|
||||
port::Status InitStreamExecutorPlugin(void* dso_handle);
|
||||
// Registers StreamExecutor platform. `device_type` and `platform_name` are
|
||||
// output parameters.
|
||||
port::Status InitStreamExecutorPlugin(void* dso_handle,
|
||||
std::string* device_type,
|
||||
std::string* platform_name);
|
||||
|
||||
// Allow registering a StreamExecutor plugin using a function (used for
|
||||
// testing).
|
||||
port::Status InitStreamExecutorPlugin(SEInitPluginFn init_fn);
|
||||
port::Status InitStreamExecutorPlugin(SEInitPluginFn init_fn,
|
||||
std::string* device_type,
|
||||
std::string* platform_name);
|
||||
|
||||
struct TFStatusDeleter {
|
||||
void operator()(TF_Status* s) const { TF_DeleteStatus(s); }
|
||||
|
|
|
|||
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/c/experimental/stream_executor/stream_executor.h"
|
||||
|
||||
#include "tensorflow/c/experimental/stream_executor/stream_executor_internal.h"
|
||||
#include "tensorflow/c/experimental/stream_executor/stream_executor_test_util.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/protobuf/error_codes.pb.h"
|
||||
@@ -24,209 +25,26 @@ limitations under the License.
#include "tensorflow/stream_executor/stream_executor_pimpl.h"
|
||||
#include "tensorflow/stream_executor/timer.h"
|
||||
|
||||
struct SP_Stream_st {
|
||||
explicit SP_Stream_st(int id) : stream_id(id) {}
|
||||
int stream_id;
|
||||
};
|
||||
|
||||
struct SP_Event_st {
|
||||
explicit SP_Event_st(int id) : event_id(id) {}
|
||||
int event_id;
|
||||
};
|
||||
|
||||
struct SP_Timer_st {
|
||||
explicit SP_Timer_st(int id) : timer_id(id) {}
|
||||
int timer_id;
|
||||
};
|
||||
|
||||
namespace stream_executor {
|
||||
namespace {
|
||||
constexpr int kDeviceCount = 2;
|
||||
constexpr char kDeviceName[] = "MY_DEVICE";
|
||||
constexpr char kDeviceType[] = "GPU";
|
||||
|
||||
/*** Create SP_StreamExecutor (with empty functions) ***/
|
||||
void allocate(const SP_Device* const device, uint64_t size,
|
||||
int64_t memory_space, SP_DeviceMemoryBase* const mem) {}
|
||||
void deallocate(const SP_Device* const device, SP_DeviceMemoryBase* const mem) {
|
||||
}
|
||||
void* host_memory_allocate(const SP_Device* const device, uint64_t size) {
|
||||
return nullptr;
|
||||
}
|
||||
void host_memory_deallocate(const SP_Device* const device, void* mem) {}
|
||||
TF_Bool get_allocator_stats(const SP_Device* const device,
|
||||
SP_AllocatorStats* const stats) {
|
||||
return true;
|
||||
}
|
||||
TF_Bool device_memory_usage(const SP_Device* const device, int64_t* const free,
|
||||
int64_t* const total) {
|
||||
return true;
|
||||
}
|
||||
void create_stream(const SP_Device* const device, SP_Stream* stream,
|
||||
TF_Status* const status) {
|
||||
stream = nullptr;
|
||||
}
|
||||
void destroy_stream(const SP_Device* const device, SP_Stream stream) {}
|
||||
void create_stream_dependency(const SP_Device* const device,
|
||||
SP_Stream dependent, SP_Stream other,
|
||||
TF_Status* const status) {}
|
||||
void get_stream_status(const SP_Device* const device, SP_Stream stream,
|
||||
TF_Status* const status) {}
|
||||
void create_event(const SP_Device* const device, SP_Event* event,
|
||||
TF_Status* const status) {
|
||||
event = nullptr;
|
||||
}
|
||||
void destroy_event(const SP_Device* const device, SP_Event event) {}
|
||||
SE_EventStatus get_event_status(const SP_Device* const device, SP_Event event) {
|
||||
return SE_EVENT_UNKNOWN;
|
||||
}
|
||||
void record_event(const SP_Device* const device, SP_Stream stream,
|
||||
SP_Event event, TF_Status* const status) {}
|
||||
void wait_for_event(const SP_Device* const device, SP_Stream stream,
|
||||
SP_Event event, TF_Status* const status) {}
|
||||
void create_timer(const SP_Device* const device, SP_Timer* timer,
|
||||
TF_Status* const status) {}
|
||||
void destroy_timer(const SP_Device* const device, SP_Timer timer) {}
|
||||
void start_timer(const SP_Device* const device, SP_Stream stream,
|
||||
SP_Timer timer, TF_Status* const status) {}
|
||||
void stop_timer(const SP_Device* const device, SP_Stream stream, SP_Timer timer,
|
||||
TF_Status* const status) {}
|
||||
void memcpy_dtoh(const SP_Device* const device, SP_Stream stream,
|
||||
void* host_dst, const SP_DeviceMemoryBase* const device_src,
|
||||
uint64_t size, TF_Status* const status) {}
|
||||
void memcpy_htod(const SP_Device* const device, SP_Stream stream,
|
||||
SP_DeviceMemoryBase* const device_dst, const void* host_src,
|
||||
uint64_t size, TF_Status* const status) {}
|
||||
void sync_memcpy_dtoh(const SP_Device* const device, void* host_dst,
|
||||
const SP_DeviceMemoryBase* const device_src,
|
||||
uint64_t size, TF_Status* const status) {}
|
||||
void sync_memcpy_htod(const SP_Device* const device,
|
||||
SP_DeviceMemoryBase* const device_dst,
|
||||
const void* host_src, uint64_t size,
|
||||
TF_Status* const status) {}
|
||||
void block_host_for_event(const SP_Device* const device, SP_Event event,
|
||||
TF_Status* const status) {}
|
||||
void synchronize_all_activity(const SP_Device* const device,
|
||||
TF_Status* const status) {}
|
||||
TF_Bool host_callback(const SP_Device* const device, SP_Stream stream,
|
||||
SE_StatusCallbackFn const callback_fn,
|
||||
void* const callback_arg) {
|
||||
return true;
|
||||
}
|
||||
|
||||
void PopulateDefaultStreamExecutor(SP_StreamExecutor* se) {
|
||||
*se = {SP_STREAMEXECUTOR_STRUCT_SIZE};
|
||||
se->allocate = allocate;
|
||||
se->deallocate = deallocate;
|
||||
se->host_memory_allocate = host_memory_allocate;
|
||||
se->host_memory_deallocate = host_memory_deallocate;
|
||||
se->get_allocator_stats = get_allocator_stats;
|
||||
se->device_memory_usage = device_memory_usage;
|
||||
se->create_stream = create_stream;
|
||||
se->destroy_stream = destroy_stream;
|
||||
se->create_stream_dependency = create_stream_dependency;
|
||||
se->get_stream_status = get_stream_status;
|
||||
se->create_event = create_event;
|
||||
se->destroy_event = destroy_event;
|
||||
se->get_event_status = get_event_status;
|
||||
se->record_event = record_event;
|
||||
se->wait_for_event = wait_for_event;
|
||||
se->create_timer = create_timer;
|
||||
se->destroy_timer = destroy_timer;
|
||||
se->start_timer = start_timer;
|
||||
se->stop_timer = stop_timer;
|
||||
se->memcpy_dtoh = memcpy_dtoh;
|
||||
se->memcpy_htod = memcpy_htod;
|
||||
se->sync_memcpy_dtoh = sync_memcpy_dtoh;
|
||||
se->sync_memcpy_htod = sync_memcpy_htod;
|
||||
se->block_host_for_event = block_host_for_event;
|
||||
se->synchronize_all_activity = synchronize_all_activity;
|
||||
se->host_callback = host_callback;
|
||||
}
|
||||
|
||||
void PopulateDefaultDeviceFns(SP_DeviceFns* device_fns) {
|
||||
*device_fns = {SP_DEVICE_FNS_STRUCT_SIZE};
|
||||
}
|
||||
|
||||
/*** Create SP_TimerFns ***/
|
||||
uint64_t nanoseconds(SP_Timer timer) { return timer->timer_id; }
|
||||
|
||||
void PopulateDefaultTimerFns(SP_TimerFns* timer_fns) {
|
||||
timer_fns->nanoseconds = nanoseconds;
|
||||
}
|
||||
|
||||
/*** Create SP_Platform ***/
|
||||
void create_timer_fns(const SP_Platform* platform, SP_TimerFns* timer_fns,
|
||||
TF_Status* status) {
|
||||
TF_SetStatus(status, TF_OK, "");
|
||||
PopulateDefaultTimerFns(timer_fns);
|
||||
}
|
||||
void destroy_timer_fns(const SP_Platform* platform, SP_TimerFns* timer_fns) {}
|
||||
|
||||
void create_stream_executor(const SP_Platform* platform,
|
||||
SE_CreateStreamExecutorParams* params,
|
||||
TF_Status* status) {
|
||||
TF_SetStatus(status, TF_OK, "");
|
||||
PopulateDefaultStreamExecutor(params->stream_executor);
|
||||
}
|
||||
void destroy_stream_executor(const SP_Platform* platform,
|
||||
SP_StreamExecutor* se) {}
|
||||
void get_device_count(const SP_Platform* platform, int* device_count,
|
||||
TF_Status* status) {
|
||||
TF_SetStatus(status, TF_OK, "");
|
||||
*device_count = kDeviceCount;
|
||||
}
|
||||
void create_device(const SP_Platform* platform, SE_CreateDeviceParams* params,
|
||||
TF_Status* status) {
|
||||
TF_SetStatus(status, TF_OK, "");
|
||||
params->device->struct_size = {SP_DEVICE_STRUCT_SIZE};
|
||||
}
|
||||
void destroy_device(const SP_Platform* platform, SP_Device* device) {}
|
||||
|
||||
void create_device_fns(const SP_Platform* platform,
|
||||
SE_CreateDeviceFnsParams* params, TF_Status* status) {
|
||||
TF_SetStatus(status, TF_OK, "");
|
||||
params->device_fns->struct_size = {SP_DEVICE_FNS_STRUCT_SIZE};
|
||||
}
|
||||
void destroy_device_fns(const SP_Platform* platform, SP_DeviceFns* device_fns) {
|
||||
}
|
||||
|
||||
void PopulateDefaultPlatform(SP_Platform* platform,
|
||||
SP_PlatformFns* platform_fns) {
|
||||
*platform = {SP_PLATFORM_STRUCT_SIZE};
|
||||
platform->name = kDeviceName;
|
||||
platform->type = kDeviceType;
|
||||
platform_fns->get_device_count = get_device_count;
|
||||
platform_fns->create_device = create_device;
|
||||
platform_fns->destroy_device = destroy_device;
|
||||
platform_fns->create_device_fns = create_device_fns;
|
||||
platform_fns->destroy_device_fns = destroy_device_fns;
|
||||
platform_fns->create_stream_executor = create_stream_executor;
|
||||
platform_fns->destroy_stream_executor = destroy_stream_executor;
|
||||
platform_fns->create_timer_fns = create_timer_fns;
|
||||
platform_fns->destroy_timer_fns = destroy_timer_fns;
|
||||
}
|
||||
|
||||
void destroy_platform(SP_Platform* const platform) {}
|
||||
void destroy_platform_fns(SP_PlatformFns* const platform_fns) {}
|
||||
|
||||
/*** Registration tests ***/
|
||||
TEST(StreamExecutor, SuccessfulRegistration) {
|
||||
auto plugin_init = [](SE_PlatformRegistrationParams* const params,
|
||||
TF_Status* const status) -> void {
|
||||
TF_SetStatus(status, TF_OK, "");
|
||||
PopulateDefaultPlatform(params->platform, params->platform_fns);
|
||||
params->destroy_platform = destroy_platform;
|
||||
params->destroy_platform_fns = destroy_platform_fns;
|
||||
test_util::PopulateDefaultPlatformRegistrationParams(params);
|
||||
};
|
||||
port::Status status = InitStreamExecutorPlugin(plugin_init);
|
||||
string device_type, platform_name;
|
||||
port::Status status =
|
||||
InitStreamExecutorPlugin(plugin_init, &device_type, &platform_name);
|
||||
TF_ASSERT_OK(status);
|
||||
port::StatusOr<Platform*> maybe_platform =
|
||||
MultiPlatformManager::PlatformWithName("MY_DEVICE");
|
||||
TF_ASSERT_OK(maybe_platform.status());
|
||||
Platform* platform = maybe_platform.ConsumeValueOrDie();
|
||||
ASSERT_EQ(platform->Name(), kDeviceName);
|
||||
ASSERT_EQ(platform->VisibleDeviceCount(), kDeviceCount);
|
||||
ASSERT_EQ(platform->Name(), test_util::kDeviceName);
|
||||
ASSERT_EQ(platform->VisibleDeviceCount(), test_util::kDeviceCount);
|
||||
|
||||
port::StatusOr<StreamExecutor*> maybe_executor =
|
||||
platform->ExecutorForDevice(0);
|
||||
@@ -237,13 +55,13 @@ TEST(StreamExecutor, NameNotSet) {
auto plugin_init = [](SE_PlatformRegistrationParams* const params,
|
||||
TF_Status* const status) -> void {
|
||||
TF_SetStatus(status, TF_OK, "");
|
||||
PopulateDefaultPlatform(params->platform, params->platform_fns);
|
||||
test_util::PopulateDefaultPlatformRegistrationParams(params);
|
||||
params->platform->name = nullptr;
|
||||
params->destroy_platform = destroy_platform;
|
||||
params->destroy_platform_fns = destroy_platform_fns;
|
||||
};
|
||||
|
||||
port::Status status = InitStreamExecutorPlugin(plugin_init);
|
||||
string device_type, platform_name;
|
||||
port::Status status =
|
||||
InitStreamExecutorPlugin(plugin_init, &device_type, &platform_name);
|
||||
ASSERT_EQ(status.code(), tensorflow::error::FAILED_PRECONDITION);
|
||||
ASSERT_EQ(status.error_message(), "'name' field in SP_Platform must be set.");
|
||||
}
|
||||
@@ -252,13 +70,13 @@ TEST(StreamExecutor, InvalidNameWithSemicolon) {
auto plugin_init = [](SE_PlatformRegistrationParams* const params,
|
||||
TF_Status* const status) -> void {
|
||||
TF_SetStatus(status, TF_OK, "");
|
||||
PopulateDefaultPlatform(params->platform, params->platform_fns);
|
||||
test_util::PopulateDefaultPlatformRegistrationParams(params);
|
||||
params->platform->name = "INVALID:NAME";
|
||||
params->destroy_platform = destroy_platform;
|
||||
params->destroy_platform_fns = destroy_platform_fns;
|
||||
};
|
||||
|
||||
port::Status status = InitStreamExecutorPlugin(plugin_init);
|
||||
string device_type, platform_name;
|
||||
port::Status status =
|
||||
InitStreamExecutorPlugin(plugin_init, &device_type, &platform_name);
|
||||
ASSERT_EQ(status.code(), tensorflow::error::FAILED_PRECONDITION);
|
||||
EXPECT_THAT(
|
||||
status.error_message(),
|
||||
@@ -269,13 +87,13 @@ TEST(StreamExecutor, InvalidNameWithSlash) {
auto plugin_init = [](SE_PlatformRegistrationParams* const params,
|
||||
TF_Status* const status) -> void {
|
||||
TF_SetStatus(status, TF_OK, "");
|
||||
PopulateDefaultPlatform(params->platform, params->platform_fns);
|
||||
test_util::PopulateDefaultPlatformRegistrationParams(params);
|
||||
params->platform->name = "INVALID/";
|
||||
params->destroy_platform = destroy_platform;
|
||||
params->destroy_platform_fns = destroy_platform_fns;
|
||||
};
|
||||
|
||||
port::Status status = InitStreamExecutorPlugin(plugin_init);
|
||||
string device_type, platform_name;
|
||||
port::Status status =
|
||||
InitStreamExecutorPlugin(plugin_init, &device_type, &platform_name);
|
||||
ASSERT_EQ(status.code(), tensorflow::error::FAILED_PRECONDITION);
|
||||
EXPECT_THAT(status.error_message(),
|
||||
testing::ContainsRegex("Device name/type 'INVALID/' must match"));
|
||||
@@ -285,13 +103,13 @@ TEST(StreamExecutor, CreateDeviceNotSet) {
auto plugin_init = [](SE_PlatformRegistrationParams* const params,
|
||||
TF_Status* const status) -> void {
|
||||
TF_SetStatus(status, TF_OK, "");
|
||||
PopulateDefaultPlatform(params->platform, params->platform_fns);
|
||||
test_util::PopulateDefaultPlatformRegistrationParams(params);
|
||||
params->platform_fns->create_device = nullptr;
|
||||
params->destroy_platform = destroy_platform;
|
||||
params->destroy_platform_fns = destroy_platform_fns;
|
||||
};
|
||||
|
||||
port::Status status = InitStreamExecutorPlugin(plugin_init);
|
||||
string device_type, platform_name;
|
||||
port::Status status =
|
||||
InitStreamExecutorPlugin(plugin_init, &device_type, &platform_name);
|
||||
ASSERT_EQ(status.code(), tensorflow::error::FAILED_PRECONDITION);
|
||||
ASSERT_EQ(status.error_message(),
|
||||
"'create_device' field in SP_PlatformFns must be set.");
|
||||
@@ -301,13 +119,13 @@ TEST(StreamExecutor, UnifiedMemoryAllocateNotSet) {
auto plugin_init = [](SE_PlatformRegistrationParams* const params,
|
||||
TF_Status* const status) -> void {
|
||||
TF_SetStatus(status, TF_OK, "");
|
||||
PopulateDefaultPlatform(params->platform, params->platform_fns);
|
||||
test_util::PopulateDefaultPlatformRegistrationParams(params);
|
||||
params->platform->supports_unified_memory = true;
|
||||
params->destroy_platform = destroy_platform;
|
||||
params->destroy_platform_fns = destroy_platform_fns;
|
||||
};
|
||||
|
||||
port::Status status = InitStreamExecutorPlugin(plugin_init);
|
||||
string device_type, platform_name;
|
||||
port::Status status =
|
||||
InitStreamExecutorPlugin(plugin_init, &device_type, &platform_name);
|
||||
ASSERT_EQ(status.code(), tensorflow::error::FAILED_PRECONDITION);
|
||||
ASSERT_EQ(
|
||||
status.error_message(),
|
||||
@@ -319,18 +137,18 @@ class StreamExecutorTest : public ::testing::Test {
protected:
|
||||
StreamExecutorTest() {}
|
||||
void SetUp() override {
|
||||
PopulateDefaultPlatform(&platform_, &platform_fns_);
|
||||
PopulateDefaultDeviceFns(&device_fns_);
|
||||
PopulateDefaultStreamExecutor(&se_);
|
||||
PopulateDefaultTimerFns(&timer_fns_);
|
||||
test_util::PopulateDefaultPlatform(&platform_, &platform_fns_);
|
||||
test_util::PopulateDefaultDeviceFns(&device_fns_);
|
||||
test_util::PopulateDefaultStreamExecutor(&se_);
|
||||
test_util::PopulateDefaultTimerFns(&timer_fns_);
|
||||
}
|
||||
void TearDown() override {}
|
||||
|
||||
StreamExecutor* GetExecutor(int ordinal) {
|
||||
if (!cplatform_) {
|
||||
cplatform_ = absl::make_unique<CPlatform>(
|
||||
platform_, destroy_platform, platform_fns_, destroy_platform_fns,
|
||||
device_fns_, se_, timer_fns_);
|
||||
platform_, test_util::DestroyPlatform, platform_fns_,
|
||||
test_util::DestroyPlatformFns, device_fns_, se_, timer_fns_);
|
||||
}
|
||||
port::StatusOr<StreamExecutor*> maybe_executor =
|
||||
cplatform_->ExecutorForDevice(ordinal);
|
||||
|
|
|
|||
@@ -0,0 +1,193 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/c/experimental/stream_executor/stream_executor_test_util.h"
|
||||
|
||||
#include "tensorflow/c/experimental/stream_executor/stream_executor.h"
|
||||
|
||||
namespace stream_executor {
|
||||
namespace test_util {
|
||||
|
||||
/*** Functions for creating SP_StreamExecutor ***/
|
||||
void Allocate(const SP_Device* const device, uint64_t size,
|
||||
int64_t memory_space, SP_DeviceMemoryBase* const mem) {}
|
||||
void Deallocate(const SP_Device* const device, SP_DeviceMemoryBase* const mem) {
|
||||
}
|
||||
void* HostMemoryAllocate(const SP_Device* const device, uint64_t size) {
|
||||
return nullptr;
|
||||
}
|
||||
void HostMemoryDeallocate(const SP_Device* const device, void* mem) {}
|
||||
TF_Bool GetAllocatorStats(const SP_Device* const device,
|
||||
SP_AllocatorStats* const stats) {
|
||||
return true;
|
||||
}
|
||||
TF_Bool DeviceMemoryUsage(const SP_Device* const device, int64_t* const free,
|
||||
int64_t* const total) {
|
||||
return true;
|
||||
}
|
||||
void CreateStream(const SP_Device* const device, SP_Stream* stream,
|
||||
TF_Status* const status) {
|
||||
stream = nullptr;
|
||||
}
|
||||
void DestroyStream(const SP_Device* const device, SP_Stream stream) {}
|
||||
void CreateStreamDependency(const SP_Device* const device, SP_Stream dependent,
|
||||
SP_Stream other, TF_Status* const status) {}
|
||||
void GetStreamStatus(const SP_Device* const device, SP_Stream stream,
|
||||
TF_Status* const status) {}
|
||||
void CreateEvent(const SP_Device* const device, SP_Event* event,
|
||||
TF_Status* const status) {
|
||||
event = nullptr;
|
||||
}
|
||||
void DestroyEvent(const SP_Device* const device, SP_Event event) {}
|
||||
SE_EventStatus GetEventStatus(const SP_Device* const device, SP_Event event) {
|
||||
return SE_EVENT_UNKNOWN;
|
||||
}
|
||||
void RecordEvent(const SP_Device* const device, SP_Stream stream,
|
||||
SP_Event event, TF_Status* const status) {}
|
||||
void WaitForEvent(const SP_Device* const device, SP_Stream stream,
|
||||
SP_Event event, TF_Status* const status) {}
|
||||
void CreateTimer(const SP_Device* const device, SP_Timer* timer,
|
||||
TF_Status* const status) {}
|
||||
void DestroyTimer(const SP_Device* const device, SP_Timer timer) {}
|
||||
void StartTimer(const SP_Device* const device, SP_Stream stream, SP_Timer timer,
|
||||
TF_Status* const status) {}
|
||||
void StopTimer(const SP_Device* const device, SP_Stream stream, SP_Timer timer,
|
||||
TF_Status* const status) {}
|
||||
void MemcpyDToH(const SP_Device* const device, SP_Stream stream, void* host_dst,
|
||||
const SP_DeviceMemoryBase* const device_src, uint64_t size,
|
||||
TF_Status* const status) {}
|
||||
void MemcpyHToD(const SP_Device* const device, SP_Stream stream,
|
||||
SP_DeviceMemoryBase* const device_dst, const void* host_src,
|
||||
uint64_t size, TF_Status* const status) {}
|
||||
void SyncMemcpyDToH(const SP_Device* const device, void* host_dst,
|
||||
const SP_DeviceMemoryBase* const device_src, uint64_t size,
|
||||
TF_Status* const status) {}
|
||||
void SyncMemcpyHToD(const SP_Device* const device,
|
||||
SP_DeviceMemoryBase* const device_dst, const void* host_src,
|
||||
uint64_t size, TF_Status* const status) {}
|
||||
void BlockHostForEvent(const SP_Device* const device, SP_Event event,
|
||||
TF_Status* const status) {}
|
||||
void SynchronizeAllActivity(const SP_Device* const device,
|
||||
TF_Status* const status) {}
|
||||
TF_Bool HostCallback(const SP_Device* const device, SP_Stream stream,
|
||||
SE_StatusCallbackFn const callback_fn,
|
||||
void* const callback_arg) {
|
||||
return true;
|
||||
}
|
||||
|
||||
void PopulateDefaultStreamExecutor(SP_StreamExecutor* se) {
|
||||
*se = {SP_STREAMEXECUTOR_STRUCT_SIZE};
|
||||
se->allocate = Allocate;
|
||||
se->deallocate = Deallocate;
|
||||
se->host_memory_allocate = HostMemoryAllocate;
|
||||
se->host_memory_deallocate = HostMemoryDeallocate;
|
||||
se->get_allocator_stats = GetAllocatorStats;
|
||||
se->device_memory_usage = DeviceMemoryUsage;
|
||||
se->create_stream = CreateStream;
|
||||
se->destroy_stream = DestroyStream;
|
||||
se->create_stream_dependency = CreateStreamDependency;
|
||||
se->get_stream_status = GetStreamStatus;
|
||||
se->create_event = CreateEvent;
|
||||
se->destroy_event = DestroyEvent;
|
||||
se->get_event_status = GetEventStatus;
|
||||
se->record_event = RecordEvent;
|
||||
se->wait_for_event = WaitForEvent;
|
||||
se->create_timer = CreateTimer;
|
||||
se->destroy_timer = DestroyTimer;
|
||||
se->start_timer = StartTimer;
|
||||
se->stop_timer = StopTimer;
|
||||
se->memcpy_dtoh = MemcpyDToH;
|
||||
se->memcpy_htod = MemcpyHToD;
|
||||
se->sync_memcpy_dtoh = SyncMemcpyDToH;
|
||||
se->sync_memcpy_htod = SyncMemcpyHToD;
|
||||
se->block_host_for_event = BlockHostForEvent;
|
||||
se->synchronize_all_activity = SynchronizeAllActivity;
|
||||
se->host_callback = HostCallback;
|
||||
}
|
||||
|
||||
void PopulateDefaultDeviceFns(SP_DeviceFns* device_fns) {
|
||||
*device_fns = {SP_DEVICE_FNS_STRUCT_SIZE};
|
||||
}
|
||||
|
||||
/*** Functions for creating SP_TimerFns ***/
|
||||
uint64_t Nanoseconds(SP_Timer timer) { return timer->timer_id; }
|
||||
|
||||
void PopulateDefaultTimerFns(SP_TimerFns* timer_fns) {
|
||||
timer_fns->nanoseconds = Nanoseconds;
|
||||
}
|
||||
|
||||
/*** Functions for creating SP_Platform ***/
|
||||
void CreateTimerFns(const SP_Platform* platform, SP_TimerFns* timer_fns,
|
||||
TF_Status* status) {
|
||||
TF_SetStatus(status, TF_OK, "");
|
||||
PopulateDefaultTimerFns(timer_fns);
|
||||
}
|
||||
void DestroyTimerFns(const SP_Platform* platform, SP_TimerFns* timer_fns) {}
|
||||
|
||||
void CreateStreamExecutor(const SP_Platform* platform,
|
||||
SE_CreateStreamExecutorParams* params,
|
||||
TF_Status* status) {
|
||||
TF_SetStatus(status, TF_OK, "");
|
||||
PopulateDefaultStreamExecutor(params->stream_executor);
|
||||
}
|
||||
void DestroyStreamExecutor(const SP_Platform* platform, SP_StreamExecutor* se) {
|
||||
}
|
||||
void GetDeviceCount(const SP_Platform* platform, int* device_count,
|
||||
TF_Status* status) {
|
||||
TF_SetStatus(status, TF_OK, "");
|
||||
*device_count = kDeviceCount;
|
||||
}
|
||||
void CreateDevice(const SP_Platform* platform, SE_CreateDeviceParams* params,
|
||||
TF_Status* status) {
|
||||
TF_SetStatus(status, TF_OK, "");
|
||||
params->device->struct_size = {SP_DEVICE_STRUCT_SIZE};
|
||||
}
|
||||
void DestroyDevice(const SP_Platform* platform, SP_Device* device) {}
|
||||
|
||||
void CreateDeviceFns(const SP_Platform* platform,
|
||||
SE_CreateDeviceFnsParams* params, TF_Status* status) {
|
||||
TF_SetStatus(status, TF_OK, "");
|
||||
params->device_fns->struct_size = {SP_DEVICE_FNS_STRUCT_SIZE};
|
||||
}
|
||||
void DestroyDeviceFns(const SP_Platform* platform, SP_DeviceFns* device_fns) {}
|
||||
|
||||
void PopulateDefaultPlatform(SP_Platform* platform,
|
||||
SP_PlatformFns* platform_fns) {
|
||||
*platform = {SP_PLATFORM_STRUCT_SIZE};
|
||||
platform->name = kDeviceName;
|
||||
platform->type = kDeviceType;
|
||||
platform_fns->get_device_count = GetDeviceCount;
|
||||
platform_fns->create_device = CreateDevice;
|
||||
platform_fns->destroy_device = DestroyDevice;
|
||||
platform_fns->create_device_fns = CreateDeviceFns;
|
||||
platform_fns->destroy_device_fns = DestroyDeviceFns;
|
||||
platform_fns->create_stream_executor = CreateStreamExecutor;
|
||||
platform_fns->destroy_stream_executor = DestroyStreamExecutor;
|
||||
platform_fns->create_timer_fns = CreateTimerFns;
|
||||
platform_fns->destroy_timer_fns = DestroyTimerFns;
|
||||
}
|
||||
|
||||
/*** Functions for creating SE_PlatformRegistrationParams ***/
|
||||
void DestroyPlatform(SP_Platform* platform) {}
|
||||
void DestroyPlatformFns(SP_PlatformFns* platform_fns) {}
|
||||
|
||||
void PopulateDefaultPlatformRegistrationParams(
|
||||
SE_PlatformRegistrationParams* const params) {
|
||||
PopulateDefaultPlatform(params->platform, params->platform_fns);
|
||||
params->destroy_platform = DestroyPlatform;
|
||||
params->destroy_platform_fns = DestroyPlatformFns;
|
||||
}
|
||||
|
||||
} // namespace test_util
|
||||
} // namespace stream_executor
@@ -0,0 +1,56 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_STREAM_EXECUTOR_STREAM_EXECUTOR_TEST_UTIL_H_
#define TENSORFLOW_C_EXPERIMENTAL_STREAM_EXECUTOR_STREAM_EXECUTOR_TEST_UTIL_H_

#include "tensorflow/c/experimental/stream_executor/stream_executor.h"

struct SP_Stream_st {
  explicit SP_Stream_st(int id) : stream_id(id) {}
  int stream_id;
};

struct SP_Event_st {
  explicit SP_Event_st(int id) : event_id(id) {}
  int event_id;
};

struct SP_Timer_st {
  explicit SP_Timer_st(int id) : timer_id(id) {}
  int timer_id;
};

namespace stream_executor {
namespace test_util {

constexpr int kDeviceCount = 2;
constexpr char kDeviceName[] = "MY_DEVICE";
constexpr char kDeviceType[] = "GPU";

void PopulateDefaultStreamExecutor(SP_StreamExecutor* se);
void PopulateDefaultDeviceFns(SP_DeviceFns* device_fns);
void PopulateDefaultTimerFns(SP_TimerFns* timer_fns);
void PopulateDefaultPlatform(SP_Platform* platform,
                             SP_PlatformFns* platform_fns);
void PopulateDefaultPlatformRegistrationParams(
    SE_PlatformRegistrationParams* const params);

void DestroyPlatform(SP_Platform* platform);
void DestroyPlatformFns(SP_PlatformFns* platform_fns);

}  // namespace test_util
}  // namespace stream_executor

#endif  // TENSORFLOW_C_EXPERIMENTAL_STREAM_EXECUTOR_STREAM_EXECUTOR_TEST_UTIL_H_
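To see how these helpers are meant to be combined, here is a minimal sketch of the pattern the tests above follow: populate every struct with working defaults via test_util, then override the one field under test. The nulled-out create_device mirrors the CreateDeviceNotSet test earlier in this diff; the sketch itself is illustrative and not part of the commit.

#include "tensorflow/c/experimental/stream_executor/stream_executor.h"
#include "tensorflow/c/experimental/stream_executor/stream_executor_test_util.h"

namespace stream_executor {
namespace {

// Illustrative plugin-init callback: defaults first, then one targeted change.
void PluginInitMissingCreateDevice(SE_PlatformRegistrationParams* const params,
                                   TF_Status* const status) {
  TF_SetStatus(status, TF_OK, "");
  test_util::PopulateDefaultPlatformRegistrationParams(params);
  params->platform_fns->create_device = nullptr;  // the field under test
}

}  // namespace
}  // namespace stream_executor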
@@ -13,5 +13,8 @@ tf_cc_shared_object(
name = "test_pluggable_device.so",
|
||||
srcs = ["test_pluggable_device.cc"],
|
||||
visibility = ["//tensorflow/c:__subpackages__"],
|
||||
deps = ["//tensorflow/c/experimental/stream_executor:stream_executor_hdrs"],
|
||||
deps = [
|
||||
"//tensorflow/c/experimental/stream_executor:stream_executor_hdrs",
|
||||
"//tensorflow/c/experimental/stream_executor:stream_executor_test_util",
|
||||
],
|
||||
)
@@ -14,10 +14,14 @@ limitations under the License.
==============================================================================*/
|
||||
|
||||
#include "tensorflow/c/experimental/stream_executor/stream_executor.h"
|
||||
#include "tensorflow/c/experimental/stream_executor/stream_executor_test_util.h"
|
||||
|
||||
extern "C" {
|
||||
|
||||
void SE_InitPlugin(SE_PlatformRegistrationParams* const params,
|
||||
TF_Status* const status) {
|
||||
params->platform->struct_size = SP_PLATFORM_STRUCT_SIZE;
|
||||
params->platform->name = "GPU";
|
||||
params->platform->type = "XGPU";
|
||||
stream_executor::test_util::PopulateDefaultPlatformRegistrationParams(params);
|
||||
}
|
||||
|
||||
void TF_InitKernel() {}
|
||||
}
@@ -53,13 +53,12 @@ if _module_dir:
setattr(_current_module, "estimator", estimator)
|
||||
|
||||
if _os.environ.get("_PREFER_OSS_KERAS", False):
|
||||
try:
|
||||
from keras.api._v2 import keras
|
||||
_current_module.__path__ = (
|
||||
[_module_util.get_parent_dir(keras)] + _current_module.__path__)
|
||||
setattr(_current_module, "keras", keras)
|
||||
except ImportError:
|
||||
pass
|
||||
_keras_module = "keras.api._v2.keras"
|
||||
keras = _LazyLoader("keras", globals(), _keras_module)
|
||||
_module_dir = _module_util.get_parent_dir_for_name(_keras_module)
|
||||
if _module_dir:
|
||||
_current_module.__path__ = [_module_dir] + _current_module.__path__
|
||||
setattr(_current_module, "keras", keras)
|
||||
else:
|
||||
try:
|
||||
from tensorflow.python.keras.api._v2 import keras
@@ -88,11 +87,30 @@ setattr(_current_module, "enable_v2_behavior", enable_v2_behavior)
|
||||
# Add module aliases
|
||||
if hasattr(_current_module, 'keras'):
|
||||
losses = keras.losses
|
||||
metrics = keras.metrics
|
||||
optimizers = keras.optimizers
|
||||
initializers = keras.initializers
|
||||
setattr(_current_module, "losses", losses)
|
||||
setattr(_current_module, "metrics", metrics)
|
||||
setattr(_current_module, "optimizers", optimizers)
|
||||
setattr(_current_module, "initializers", initializers)
|
||||
# It is possible that keras is a lazily loaded module, which might break when
|
||||
# actually trying to import it. Have a Try-Catch to make sure it doesn't break
|
||||
# when it is doing some very initial loading, like tf.compat.v2, etc.
|
||||
if _os.environ.get("_PREFER_OSS_KERAS", False):
|
||||
try:
|
||||
_keras_package = "keras.api._v2.keras."
|
||||
losses = _LazyLoader("losses", globals(), _keras_package + "losses")
|
||||
metrics = _LazyLoader("metrics", globals(), _keras_package + "metrics")
|
||||
optimizers = _LazyLoader(
|
||||
"optimizers", globals(), _keras_package + "optimizers")
|
||||
initializers = _LazyLoader(
|
||||
"initializers", globals(), _keras_package + "initializers")
|
||||
setattr(_current_module, "losses", losses)
|
||||
setattr(_current_module, "metrics", metrics)
|
||||
setattr(_current_module, "optimizers", optimizers)
|
||||
setattr(_current_module, "initializers", initializers)
|
||||
except ImportError:
|
||||
pass
|
||||
else:
|
||||
losses = keras.losses
|
||||
metrics = keras.metrics
|
||||
optimizers = keras.optimizers
|
||||
initializers = keras.initializers
|
||||
setattr(_current_module, "losses", losses)
|
||||
setattr(_current_module, "metrics", metrics)
|
||||
setattr(_current_module, "optimizers", optimizers)
|
||||
setattr(_current_module, "initializers", initializers)
@@ -43,13 +43,12 @@ if _module_dir:
setattr(_current_module, "estimator", estimator)
|
||||
|
||||
if _os.environ.get("_PREFER_OSS_KERAS", False):
|
||||
try:
|
||||
from keras.api._v1 import keras
|
||||
_current_module.__path__ = (
|
||||
[_module_util.get_parent_dir(keras)] + _current_module.__path__)
|
||||
setattr(_current_module, "keras", keras)
|
||||
except ImportError:
|
||||
pass
|
||||
_keras_module = "keras.api._v1.keras"
|
||||
keras = _LazyLoader("keras", globals(), _keras_module)
|
||||
_module_dir = _module_util.get_parent_dir_for_name(_keras_module)
|
||||
if _module_dir:
|
||||
_current_module.__path__ = [_module_dir] + _current_module.__path__
|
||||
setattr(_current_module, "keras", keras)
|
||||
else:
|
||||
try:
|
||||
from tensorflow.python.keras.api._v1 import keras
@@ -163,6 +163,7 @@ void AllocateAndParseFlags() {
|
||||
ops_flags = new XlaOpsCommonFlags;
|
||||
ops_flags->tf_xla_always_defer_compilation = false;
|
||||
ops_flags->tf_xla_async_compilation = false;
|
||||
|
||||
jitter_flags = new IntroduceFloatingPointJitterPassFlags;
|
||||
jitter_flags->jitter_amount = 1e-5;
@@ -216,6 +217,10 @@ void AllocateAndParseFlags() {
|
||||
Flag("tf_xla_always_defer_compilation",
|
||||
&ops_flags->tf_xla_always_defer_compilation, ""),
|
||||
Flag("tf_xla_async_compilation", &ops_flags->tf_xla_async_compilation,
|
||||
"When lazy compilation is enabled, asynchronous compilation starts "
|
||||
"the cluster compilation in the background, and the fallback path "
|
||||
"is executed until the compilation has finished."),
|
||||
|
||||
Flag("tf_introduce_floating_point_jitter_to_tensors",
|
||||
setter_for_jitter_tensor_names, "",
@@ -99,6 +99,9 @@ struct XlaOpsCommonFlags {
// If true, _XlaCompile always refuses to compile the cluster, which means the
|
||||
// XLA clusters always run in the TF executor. Defaults to false.
|
||||
bool tf_xla_always_defer_compilation;
|
||||
// If true, _XlaCompile compiles the cluster asynchronously with respect to
|
||||
// the main execution. The fallback path is taken while compilation happens.
|
||||
bool tf_xla_async_compilation;
|
||||
};
|
||||
|
||||
// Flags for the build_xla_ops pass.
@@ -37,7 +37,8 @@ static xla::StatusOr<xla::LocalExecutable*> GetLocalExecutable(
const XlaCompiler::Options& options,
|
||||
const XlaCompiler::CompileOptions& compile_options,
|
||||
const NameAttrList& function, XlaCompilationCache* cache,
|
||||
absl::Span<XlaCompiler::Argument const> args, const XlaCompiler& compiler) {
|
||||
const std::vector<XlaCompiler::Argument>& args,
|
||||
const XlaCompiler& compiler) {
|
||||
const XlaCompiler::CompilationResult* compilation_result = nullptr;
|
||||
xla::LocalExecutable* executable = nullptr;
|
||||
TF_RETURN_IF_ERROR(cache->Compile(options, function, args, compile_options,
@@ -100,12 +101,10 @@ xla::StatusOr<std::string> GetCompilerIr(
}));
|
||||
core::ScopedUnref cache_ref(cache);
|
||||
|
||||
absl::optional<se::TfAllocatorAdapter> tf_allocator_adapter;
|
||||
|
||||
XlaCompiler::Options options =
|
||||
GenerateCompilerOptions(*cache, *flr, dev,
|
||||
/*stream=*/nullptr, platform_info,
|
||||
/*has_ref_vars=*/false, &tf_allocator_adapter);
|
||||
/*has_ref_vars=*/false);
|
||||
|
||||
XlaCompiler::CompileOptions compile_options;
|
||||
compile_options.always_return_tuple = false;
@@ -166,8 +166,9 @@ static Status CompileToLocalExecutable(
const XlaPlatformInfo& platform_info,
|
||||
absl::Span<const Tensor* const> inputs,
|
||||
absl::Span<VariableInfo const> variable_infos,
|
||||
absl::Span<const int> constants, bool lazy, bool may_alias_resource_update,
|
||||
xla::LocalClient** client,
|
||||
absl::Span<const int> constants,
|
||||
XlaCompilationCache::CompileMode compile_mode,
|
||||
bool may_alias_resource_update, xla::LocalClient** client,
|
||||
const XlaCompiler::CompilationResult** compilation_result,
|
||||
xla::LocalExecutable** executable) {
|
||||
// We store information about the JIT-compiled XLA computation
@@ -190,11 +191,10 @@ static Status CompileToLocalExecutable(
|
||||
*client = static_cast<xla::LocalClient*>(cache->client());
|
||||
|
||||
absl::optional<se::TfAllocatorAdapter> tf_allocator_adapter;
|
||||
XlaCompiler::Options options = GenerateCompilerOptions(
|
||||
*cache, *ctx->function_library(), ctx->device(),
|
||||
ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr,
|
||||
platform_info, has_ref_vars, &tf_allocator_adapter);
|
||||
platform_info, has_ref_vars);
|
||||
|
||||
XlaCompiler::CompileOptions compile_options;
|
||||
compile_options.is_entry_computation = true;
@@ -202,7 +202,6 @@ static Status CompileToLocalExecutable(
// rather than a one-element tuple.
|
||||
compile_options.always_return_tuple = false;
|
||||
compile_options.alias_resource_update = !has_ref_vars &&
|
||||
!platform_info.is_on_xla_device() &&
|
||||
may_alias_resource_update;
|
||||
|
||||
xla::StatusOr<std::vector<XlaCompiler::Argument>> args =
@@ -210,9 +209,7 @@ static Status CompileToLocalExecutable(
constants, inputs, variable_infos,
|
||||
static_cast<Device*>(ctx->device()));
|
||||
TF_RETURN_IF_ERROR(args.status());
|
||||
return cache->Compile(options, function, *args, compile_options,
|
||||
lazy ? XlaCompilationCache::CompileMode::kLazy
|
||||
: XlaCompilationCache::CompileMode::kStrict,
|
||||
return cache->Compile(options, function, *args, compile_options, compile_mode,
|
||||
compilation_result, executable);
|
||||
}
@@ -233,7 +230,7 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
OP_REQUIRES_OK(ctx, LockVariables(absl::MakeSpan(variable_infos)));
|
||||
Status s = CompileToLocalExecutable(
|
||||
ctx, function_, /*has_ref_vars=*/has_ref_vars_, platform_info_, inputs,
|
||||
variable_infos, constants_, /*lazy=*/false,
|
||||
variable_infos, constants_, XlaCompilationCache::CompileMode::kStrict,
|
||||
/*may_alias_resource_update=*/true, &client, &compilation_result,
|
||||
&executable);
|
||||
OP_REQUIRES_OK(ctx, s);
@@ -246,12 +243,9 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
|
||||
se::Stream* stream =
|
||||
ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
|
||||
|
||||
absl::optional<se::TfAllocatorAdapter> tf_allocator_adapter;
|
||||
se::DeviceMemoryAllocator* allocator = GetAllocator(
|
||||
&tf_allocator_adapter, ctx->device(),
|
||||
ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr,
|
||||
platform_info_);
|
||||
std::shared_ptr<se::DeviceMemoryAllocator> allocator_ptr =
|
||||
GetAllocator(ctx->device(), stream, platform_info_);
|
||||
se::DeviceMemoryAllocator* allocator = allocator_ptr.get();
|
||||
int device_ordinal = stream ? stream->parent()->device_ordinal()
|
||||
: client->default_device_ordinal();
|
||||
XlaComputationLaunchContext launch_context(
@@ -381,6 +375,14 @@ void XlaCompileOp::Compute(OpKernelContext* ctx) {
mutex_lock guard(cannot_compile_cluster_mu_);
|
||||
cannot_compile_cluster = cannot_compile_cluster_;
|
||||
}
|
||||
XlaCompilationCache::CompileMode compile_mode = [&] {
|
||||
if (must_compile_) {
|
||||
return XlaCompilationCache::CompileMode::kStrict;
|
||||
}
|
||||
return GetXlaOpsCommonFlags().tf_xla_async_compilation
|
||||
? XlaCompilationCache::CompileMode::kAsync
|
||||
: XlaCompilationCache::CompileMode::kLazy;
|
||||
}();
|
||||
|
||||
if (GetXlaOpsCommonFlags().tf_xla_always_defer_compilation ||
|
||||
cannot_compile_cluster) {
@@ -396,12 +398,12 @@ void XlaCompileOp::Compute(OpKernelContext* ctx) {
// unlocking them in XlaRun may lead to deadlocks.
|
||||
Status status = CompileToLocalExecutable(
|
||||
ctx, function_, has_ref_vars_, platform_info_, inputs, variable_infos,
|
||||
constants_,
|
||||
/*lazy=*/!must_compile_,
|
||||
/*may_alias_resource_update=*/false, &client, &kernel, &executable);
|
||||
constants_, compile_mode, /*may_alias_resource_update=*/false, &client,
|
||||
&kernel, &executable);
|
||||
OP_REQUIRES_OK(ctx, SnapshotResourceVariables(ctx, resources_,
|
||||
variable_infos, &variables));
|
||||
if (must_compile_ || status.code() != error::UNIMPLEMENTED) {
|
||||
if (compile_mode != XlaCompilationCache::CompileMode::kLazy ||
|
||||
status.code() != error::UNIMPLEMENTED) {
|
||||
OP_REQUIRES_OK(ctx, status);
|
||||
}
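The contract relied on here is worth making explicit: with kLazy below the compile threshold, or kAsync while compilation is still in flight, Compile() returns OK while leaving both outputs null, so callers must branch on the executable rather than on the status alone. A hedged caller-side sketch (a hypothetical helper, not code from this commit):

// Hypothetical helper mirroring how callers consume the new Compile() API.
Status CompileOrFallBack(XlaCompilationCache* cache,
                         const XlaCompiler::Options& options,
                         const NameAttrList& function,
                         const std::vector<XlaCompiler::Argument>& args,
                         const XlaCompiler::CompileOptions& compile_options,
                         bool* used_fallback) {
  const XlaCompiler::CompilationResult* result = nullptr;
  xla::LocalExecutable* executable = nullptr;
  TF_RETURN_IF_ERROR(cache->Compile(options, function, args, compile_options,
                                    XlaCompilationCache::CompileMode::kAsync,
                                    &result, &executable));
  *used_fallback = (executable == nullptr);  // fall back until compile finishes
  return Status::OK();
}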
@@ -423,6 +425,7 @@ void XlaCompileOp::Compute(OpKernelContext* ctx) {
host_alloc_attrs.set_on_host(true);
|
||||
Allocator* cpu_allocator = ctx->device()->GetAllocator(host_alloc_attrs);
|
||||
|
||||
// Async compilation returns nullptr executable without an error.
|
||||
if (!executable) {
|
||||
DCHECK(!must_compile_);
|
||||
Tensor compilation_key(cpu_allocator, DT_STRING, TensorShape({}));
@@ -463,13 +466,11 @@ void XlaRunOp::Compute(OpKernelContext* ctx) {
XlaExecutableClosure closure =
|
||||
XlaExecutableClosureStore::Global()->Consume(key);
|
||||
|
||||
absl::optional<se::TfAllocatorAdapter> tf_allocator_adapter;
|
||||
se::DeviceMemoryAllocator* allocator = GetAllocator(
|
||||
&tf_allocator_adapter, ctx->device(),
|
||||
ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr,
|
||||
platform_info_);
|
||||
se::Stream* stream =
|
||||
ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
|
||||
std::shared_ptr<se::DeviceMemoryAllocator> allocator_ptr =
|
||||
GetAllocator(ctx->device(), stream, platform_info_);
|
||||
se::DeviceMemoryAllocator* allocator = allocator_ptr.get();
|
||||
int device_ordinal = stream ? stream->parent()->device_ordinal()
|
||||
: closure.client()->default_device_ordinal();
|
||||
XlaComputationLaunchContext launch_context(
@@ -53,6 +53,9 @@ limitations under the License.
namespace tensorflow {
|
||||
|
||||
constexpr int64 XlaCompilationCache::kDefaultCompilationThreshold;
|
||||
constexpr int64 XlaCompilationCache::AsyncCompilationState::kNumCompilerThreads;
|
||||
constexpr int64
|
||||
XlaCompilationCache::AsyncCompilationState::kMaxNumOngoingCompilations;
|
||||
|
||||
XlaCompilationCache::XlaCompilationCache(xla::LocalClient* client,
|
||||
DeviceType device_type)
@@ -68,6 +71,12 @@ XlaCompilationCache::~XlaCompilationCache() {
"programs to complete";
|
||||
}
|
||||
}
|
||||
// Wait for all outstanding compilations to finish.
|
||||
// Resetting the pointer explicitly in the top level destructor.
|
||||
// Without this, the pointer would be reset when the AsyncCompilationState
|
||||
// is destructed, which is dependent on the order of the members in the
|
||||
// XlaCompilationCache class, which is error prone if the order changes.
|
||||
async_compilation_state_.compiler_threads.reset();
|
||||
// TODO(b/110813685): Think about the program ownership model. Programs are
|
||||
// currently owned by the compilation cache which means we must wait for
|
||||
// program completion in the destructor. There are multiple compilation caches
@@ -170,7 +179,7 @@ Status XlaCompilationCache::BuildExecutable(
? options.device_ordinal
|
||||
: client_->default_device_ordinal());
|
||||
build_options.set_result_layout(result.xla_output_shape);
|
||||
build_options.set_device_allocator(options.device_allocator);
|
||||
build_options.set_device_allocator(options.device_allocator.get());
|
||||
build_options.set_alias_passthrough_params(options.alias_passthrough_params);
|
||||
build_options.mutable_debug_options()->set_xla_detailed_logging(
|
||||
options.detailed_logging);
@@ -184,21 +193,22 @@ Status XlaCompilationCache::BuildExecutable(
|
||||
Status XlaCompilationCache::Compile(
|
||||
const XlaCompiler::Options& options, const NameAttrList& function,
|
||||
absl::Span<const XlaCompiler::Argument> args,
|
||||
const std::vector<XlaCompiler::Argument>& args,
|
||||
const XlaCompiler::CompileOptions& compile_options,
|
||||
CompileMode compile_mode,
|
||||
const XlaCompiler::CompilationResult** out_compilation_result,
|
||||
xla::LocalExecutable** out_executable) {
|
||||
absl::optional<int64> compile_threshold;
|
||||
if (compile_mode == CompileMode::kLazy) {
|
||||
compile_threshold = kDefaultCompilationThreshold;
|
||||
}
|
||||
auto compile_fn = [&](XlaCompiler* compiler,
|
||||
// !!Pay attention when additional variables must be captured by this
|
||||
// lambda!! compile_fn can run asynchronously after this function has
|
||||
// exited. Make sure that any variable needed inside compile_fn is
|
||||
// either passed as an argument, or captured by value right here.
|
||||
auto compile_fn = [compile_options, function](
|
||||
XlaCompiler* compiler,
|
||||
const std::vector<XlaCompiler::Argument>& args,
|
||||
XlaCompiler::CompilationResult* result) {
|
||||
return compiler->CompileFunction(compile_options, function, args, result);
|
||||
};
|
||||
return CompileImpl(options, function, args, compile_fn,
|
||||
/*compile_threshold=*/compile_threshold,
|
||||
return CompileImpl(options, function, args, compile_fn, compile_mode,
|
||||
out_compilation_result, out_executable);
|
||||
}
@@ -261,7 +271,7 @@ static xla::StatusOr<std::unique_ptr<Graph>> CreateGraph(
|
||||
Status XlaCompilationCache::CompileSingleOp(
|
||||
const XlaCompiler::Options& options,
|
||||
absl::Span<const XlaCompiler::Argument> args, OpKernelContext* ctx,
|
||||
const std::vector<XlaCompiler::Argument>& args, OpKernelContext* ctx,
|
||||
const XlaCompiler::CompileOptions& compile_options,
|
||||
const XlaCompiler::CompilationResult** out_compilation_result,
|
||||
xla::LocalExecutable** out_executable) {
@@ -274,6 +284,7 @@ Status XlaCompilationCache::CompileSingleOp(
// and causes false uniqueness between nodes.
|
||||
name.mutable_attr()->erase("_class");
|
||||
auto compile_op = [&](XlaCompiler* compiler,
|
||||
const std::vector<XlaCompiler::Argument>& args,
|
||||
XlaCompiler::CompilationResult* result) {
|
||||
std::vector<DataType> result_dtypes(ctx->num_outputs());
|
||||
for (int i = 0, end = result_dtypes.size(); i < end; ++i) {
@@ -308,8 +319,7 @@ Status XlaCompilationCache::CompileSingleOp(
options.device_type.type_string(), compile_options.use_tuple_arg,
|
||||
*options.flib_def, debug_info, options.shape_representation_fn, result);
|
||||
};
|
||||
return CompileImpl(options, name, args, compile_op,
|
||||
/*compile_threshold=*/absl::nullopt,
|
||||
return CompileImpl(options, name, args, compile_op, CompileMode::kStrict,
|
||||
out_compilation_result, out_executable);
|
||||
}
@@ -327,12 +337,113 @@ void LogOnceXlaCompiledFirstCluster() {
}
|
||||
} // namespace
|
||||
|
||||
Status XlaCompilationCache::CompileStrict(
|
||||
Entry* entry, const XlaCompiler::Options& options,
|
||||
const std::vector<XlaCompiler::Argument>& args, const string& function_name,
|
||||
const std::function<Status(XlaCompiler* compiler,
|
||||
const std::vector<XlaCompiler::Argument>& args,
|
||||
XlaCompiler::CompilationResult*)>& compile_fn) {
|
||||
tensorflow::Env* env = tensorflow::Env::Default();
|
||||
const uint64 compile_start_us = env->NowMicros();
|
||||
|
||||
XlaCompiler compiler(options);
|
||||
entry->compile_state = CompileState::kCompiled;
|
||||
|
||||
entry->compilation_status =
|
||||
compile_fn(&compiler, args, &entry->compilation_result);
|
||||
TF_RETURN_IF_ERROR(entry->compilation_status);
|
||||
TF_RET_CHECK(entry->executable.get() == nullptr);
|
||||
entry->compilation_status =
|
||||
BuildExecutable(options, entry->compilation_result, &entry->executable);
|
||||
|
||||
const uint64 compile_end_us = env->NowMicros();
|
||||
const uint64 compile_time_us = compile_end_us - compile_start_us;
|
||||
metrics::UpdateXlaCompilationTime(compile_time_us);
|
||||
{
|
||||
mutex_lock lock(cluster_compile_stats_mu_);
|
||||
auto it = cluster_compile_stats_.find(function_name);
|
||||
const uint64 compile_time_s = compile_time_us / 1.0e6;
|
||||
it->second.compile_count++;
|
||||
it->second.cumulative_compile_time_us += compile_time_us;
|
||||
|
||||
LogOnceXlaCompiledFirstCluster();
|
||||
VLOG(1) << "compiled " << function_name << " " << it->second.compile_count
|
||||
<< " times, compile time: " << compile_time_us
|
||||
<< " us, cumulative: " << it->second.cumulative_compile_time_us
|
||||
<< " us ("
|
||||
<< tensorflow::strings::HumanReadableElapsedTime(compile_time_s)
|
||||
<< " / "
|
||||
<< tensorflow::strings::HumanReadableElapsedTime(
|
||||
it->second.cumulative_compile_time_us / 1.0e6)
|
||||
<< ")";
|
||||
|
||||
XlaJitCompilationActivity jit_compilation_activity;
|
||||
jit_compilation_activity.set_cluster_name(function_name);
|
||||
jit_compilation_activity.set_compile_count(it->second.compile_count);
|
||||
jit_compilation_activity.set_compile_time_us(compile_time_us);
|
||||
jit_compilation_activity.set_cumulative_compile_time_us(
|
||||
it->second.cumulative_compile_time_us);
|
||||
TF_RETURN_IF_ERROR(
|
||||
BroadcastXlaActivity(std::move(jit_compilation_activity)));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status XlaCompilationCache::CompileAsynchronous(
|
||||
Entry* entry, const XlaCompiler::Options& options,
|
||||
const std::vector<XlaCompiler::Argument>& args, const string& function_name,
|
||||
const std::function<Status(XlaCompiler* compiler,
|
||||
const std::vector<XlaCompiler::Argument>& args,
|
||||
XlaCompiler::CompilationResult*)>& compile_fn) {
|
||||
// Explicitly capture all required data by value for async compilation.
|
||||
entry->compile_state = CompileState::kCompiling;
|
||||
{
|
||||
mutex_lock lock(async_compilation_state_.async_compilation_state_mu);
|
||||
async_compilation_state_.num_ongoing_compilations++;
|
||||
}
|
||||
// Don't move the above code into the thread function as it synchronously
|
||||
// updates the async compilation state!
|
||||
|
||||
// When the ThreadPool for the compilation cache is destroyed, it waits for
|
||||
// compilations to have finished. This means that both 'entry' and 'this' will
|
||||
// be alive for the duration of the compilation.
|
||||
// !!Pay attention when additional variables must be captured by this lambda!!
|
||||
// All values are captured by value. Make sure that all pointer values (like
|
||||
// entry) do not get freed until the lambda has finished.
|
||||
async_compilation_state_.compiler_threads->Schedule([=] {
|
||||
Entry local_entry;
|
||||
VLOG(2) << "Starting asynchronous compilation of cluster " << function_name
|
||||
<< '.';
|
||||
// We don't need to lock local_entry.mu, but do it anyway to satisfy
|
||||
// thread safety analysis.
|
||||
mutex_lock entry_lock(local_entry.mu);
|
||||
(void)CompileStrict(&local_entry, options, args, function_name, compile_fn);
|
||||
|
||||
VLOG(2) << "Finished asynchronous compililation of cluster "
|
||||
<< function_name << '.';
|
||||
{
|
||||
mutex_lock lock(async_compilation_state_.async_compilation_state_mu);
|
||||
async_compilation_state_.num_ongoing_compilations--;
|
||||
}
|
||||
{ // Populate original entry with compilation result.
|
||||
mutex_lock entry_lock(entry->mu);
|
||||
entry->compilation_result = local_entry.compilation_result;
|
||||
entry->compile_state = local_entry.compile_state;
|
||||
entry->compilation_status = local_entry.compilation_status;
|
||||
entry->executable = std::move(local_entry.executable);
|
||||
}
|
||||
});
|
||||
return Status::OK();
|
||||
}
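The capture warnings in the two lambdas above come down to a standard C++ lifetime rule: a callback handed to a thread pool can outlive the frame that created it. A tiny generic illustration (plain C++, unrelated to the TF types here):

#include <functional>
#include <string>

// The returned callback may run long after MakeTask's frame is gone, so it
// must copy what it needs; capturing `name` by reference would dangle.
std::function<std::size_t()> MakeTask(std::string name) {
  return [name]() { return name.size(); };  // capture by value: safe copy
}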
|
||||
|
||||
Status XlaCompilationCache::CompileImpl(
|
||||
const XlaCompiler::Options& options, const NameAttrList& function,
|
||||
absl::Span<const XlaCompiler::Argument> args,
|
||||
const std::vector<XlaCompiler::Argument>& args,
|
||||
const std::function<Status(XlaCompiler* compiler,
|
||||
const std::vector<XlaCompiler::Argument>& args,
|
||||
XlaCompiler::CompilationResult*)>& compile_fn,
|
||||
absl::optional<int64> compile_threshold,
|
||||
CompileMode compile_mode,
|
||||
const XlaCompiler::CompilationResult** out_compilation_result,
|
||||
xla::LocalExecutable** out_executable) {
|
||||
if (FailOnXlaCompilation()) {
@@ -348,9 +459,20 @@ Status XlaCompilationCache::CompileImpl(
VLOG(3) << i << ": " << args[i].HumanString();
|
||||
}
|
||||
}
|
||||
absl::optional<int64> compile_threshold;
|
||||
if (compile_mode == CompileMode::kLazy) {
|
||||
compile_threshold = kDefaultCompilationThreshold;
|
||||
} else if (compile_mode == CompileMode::kAsync) {
|
||||
compile_threshold = 0; // for now, always compile right away.
|
||||
}
|
||||
|
||||
TF_ASSIGN_OR_RETURN(Signature signature, BuildSignature(function, args));
|
||||
VLOG(2) << "Signature: " << signature.HumanString();
|
||||
|
||||
string human_signature;
|
||||
if (VLOG_IS_ON(2)) {
|
||||
human_signature = VLOG_IS_ON(3) ? signature.HumanString() : function.name();
|
||||
VLOG(2) << "Signature: " << human_signature;
|
||||
}
|
||||
|
||||
// The outer lock protects the existence of the cache entry. It does not
|
||||
// protect the contents of the cache entry.
@@ -402,14 +524,18 @@ Status XlaCompilationCache::CompileImpl(
// cache eviction.
|
||||
mutex_lock entry_lock(entry->mu);
|
||||
int64 current_request_count = ++entry->request_count;
|
||||
VLOG(2) << "Compilation cache entry hit: " << entry->compiled
|
||||
<< " signature: " << signature.HumanString() << " with request count "
|
||||
VLOG(2) << "Compilation cache entry hit: "
|
||||
<< static_cast<int>(entry->compile_state)
|
||||
<< " signature: " << human_signature << " with request count "
|
||||
<< current_request_count << " and compile threshold "
|
||||
<< compile_threshold.value_or(0);
|
||||
if (!entry->compiled) {
|
||||
// TODO(sanjoy): Refactor this code into helper functions.
|
||||
bool return_null = false;
|
||||
CompileState state = entry->compile_state;
|
||||
if (state == CompileState::kUncompiled) {
|
||||
XLA_SCOPED_LOGGING_TIMER("Compilation of XLA executable");
|
||||
const bool should_compile = [&] {
|
||||
if (!compile_threshold.has_value()) {
|
||||
if (compile_mode == CompileMode::kStrict) {
|
||||
// Lazy compilation is disabled.
|
||||
return true;
|
||||
}
@@ -418,7 +544,7 @@ Status XlaCompilationCache::CompileImpl(
BroadcastOptimizationRemark(XlaOptimizationRemark::MEGAMORPHIC_FUNCTION,
|
||||
function.name())
|
||||
.IgnoreError();
|
||||
VLOG(3) << "Not compiling cluster " << function.name()
|
||||
VLOG(2) << "Not compiling cluster " << function.name()
|
||||
<< " because it is megamorphic.";
|
||||
return false;
|
||||
}
@@ -427,10 +553,21 @@ Status XlaCompilationCache::CompileImpl(
return true;
|
||||
}
|
||||
|
||||
if (compile_mode == CompileMode::kAsync) {
|
||||
// Asynchronous compilation is enabled.
|
||||
mutex_lock lock(async_compilation_state_.async_compilation_state_mu);
|
||||
if (async_compilation_state_.num_ongoing_compilations >=
|
||||
async_compilation_state_.kMaxNumOngoingCompilations) {
|
||||
VLOG(2) << "Not asynchronously compiling cluster " << function.name()
|
||||
<< " because of too many ongoing compilations.";
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
bool reached_compile_threshold =
|
||||
current_request_count >= *compile_threshold;
|
||||
if (!reached_compile_threshold) {
|
||||
VLOG(3)
|
||||
VLOG(2)
|
||||
<< "Not compiling cluster " << function.name()
|
||||
<< " because it has not reached compile threshold; threshold is "
|
||||
<< *compile_threshold << " execution count "
@@ -440,62 +577,34 @@ Status XlaCompilationCache::CompileImpl(
}();
|
||||
|
||||
if (!should_compile) {
|
||||
VLOG(2) << "Not compiling for signature: " << signature.HumanString();
|
||||
*out_compilation_result = nullptr;
|
||||
*out_executable = nullptr;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
tensorflow::Env* env = tensorflow::Env::Default();
|
||||
const uint64 compile_start_us = env->NowMicros();
|
||||
// Do the actual JIT compilation without holding the lock (it can take
|
||||
// a long time.)
|
||||
|
||||
XlaCompiler compiler(options);
|
||||
entry->compiled = true;
|
||||
|
||||
entry->compilation_status =
|
||||
compile_fn(&compiler, &entry->compilation_result);
|
||||
TF_RETURN_IF_ERROR(entry->compilation_status);
|
||||
CHECK_EQ(entry->executable.get(), nullptr);
|
||||
entry->compilation_status =
|
||||
BuildExecutable(options, entry->compilation_result, &entry->executable);
|
||||
|
||||
const uint64 compile_end_us = env->NowMicros();
|
||||
const uint64 compile_time_us = compile_end_us - compile_start_us;
|
||||
metrics::UpdateXlaCompilationTime(compile_time_us);
|
||||
{
|
||||
mutex_lock lock(cluster_compile_stats_mu_);
|
||||
auto it = cluster_compile_stats_.find(function.name());
|
||||
it->second.compile_count++;
|
||||
it->second.cumulative_compile_time_us += compile_time_us;
|
||||
LogOnceXlaCompiledFirstCluster();
|
||||
VLOG(1) << "compiled " << function.name() << " "
|
||||
<< it->second.compile_count
|
||||
<< " times, compile time: " << compile_time_us
|
||||
<< " us, cumulative: " << it->second.cumulative_compile_time_us
|
||||
<< " us ("
|
||||
<< tensorflow::strings::HumanReadableElapsedTime(compile_time_us /
|
||||
1.0e6)
|
||||
<< " / "
|
||||
<< tensorflow::strings::HumanReadableElapsedTime(
|
||||
it->second.cumulative_compile_time_us / 1.0e6)
|
||||
<< ")";
|
||||
|
||||
XlaJitCompilationActivity jit_compilation_activity;
|
||||
jit_compilation_activity.set_cluster_name(function.name());
|
||||
jit_compilation_activity.set_compile_count(it->second.compile_count);
|
||||
jit_compilation_activity.set_compile_time_us(compile_time_us);
|
||||
jit_compilation_activity.set_cumulative_compile_time_us(
|
||||
it->second.cumulative_compile_time_us);
|
||||
|
||||
VLOG(2) << "Not compiling for signature: " << human_signature;
|
||||
return_null = true;
|
||||
} else if (compile_mode == CompileMode::kAsync) {
|
||||
VLOG(2) << "Queueing asynchronous compilation for signature: "
|
||||
<< human_signature;
|
||||
TF_RETURN_IF_ERROR(CompileAsynchronous(entry, options, args,
|
||||
function.name(), compile_fn));
|
||||
return_null = true;
|
||||
} else {
|
||||
VLOG(2) << "Instantly compiling for signature: " << human_signature;
|
||||
TF_RETURN_IF_ERROR(
|
||||
BroadcastXlaActivity(std::move(jit_compilation_activity)));
|
||||
CompileStrict(entry, options, args, function.name(), compile_fn));
|
||||
}
|
||||
} else if (state == CompileState::kCompiling) {
|
||||
VLOG(2) << "Ongoing asynchronous compilation for signature: "
|
||||
<< human_signature;
|
||||
return_null = true;
|
||||
} else if (state == CompileState::kCompiled) {
|
||||
VLOG(2) << "Already Compiled for signature: " << human_signature;
|
||||
}
|
||||
if (return_null) {
|
||||
*out_compilation_result = nullptr;
|
||||
*out_executable = nullptr;
|
||||
} else {
|
||||
TF_RETURN_IF_ERROR(entry->compilation_status);
|
||||
*out_compilation_result = &entry->compilation_result;
|
||||
*out_executable = entry->executable.get();
|
||||
}
|
||||
TF_RETURN_IF_ERROR(entry->compilation_status);
|
||||
*out_compilation_result = &entry->compilation_result;
|
||||
*out_executable = entry->executable.get();
|
||||
return Status::OK();
|
||||
}
@@ -50,6 +50,13 @@ class XlaCompilationCache : public ResourceBase {
enum class CompileMode {
|
||||
kLazy,
|
||||
kStrict,
|
||||
kAsync,
|
||||
};
|
||||
|
||||
enum class CompileState {
|
||||
kUncompiled,
|
||||
kCompiling,
|
||||
kCompiled,
|
||||
};
|
||||
|
||||
// Compiles a function into a XlaCompiler::CompilationResult that can be used
@@ -62,7 +69,9 @@ class XlaCompilationCache : public ResourceBase {
// heuristics, the compilation cache may decide not to compile the cluster at
|
||||
// this time. In this case it returns null into both `out_compilation_result`
|
||||
// and `out_executable`. If `compile_mode` is `kStrict` then the compilation
|
||||
// cache always attempts the compilation on a cache miss.
|
||||
// cache always attempts the compilation on a cache miss. If compilation mode
|
||||
// is 'kAsync', compilation of the cluster happens in the background while the
|
||||
// fallback path executes.
|
||||
//
|
||||
// The result of compilation is written to `*out_compilation_result`, which
|
||||
// must be non-null. If `out_executable` is non-null, also builds an
@@ -71,7 +80,7 @@ class XlaCompilationCache : public ResourceBase {
// non-constant outputs.
|
||||
Status Compile(const XlaCompiler::Options& options,
|
||||
const NameAttrList& function,
|
||||
absl::Span<const XlaCompiler::Argument> args,
|
||||
const std::vector<XlaCompiler::Argument>& args,
|
||||
const XlaCompiler::CompileOptions& compile_options,
|
||||
CompileMode compile_mode,
|
||||
const XlaCompiler::CompilationResult** out_compilation_result,
@@ -83,7 +92,7 @@ class XlaCompilationCache : public ResourceBase {
// XlaCompiler, if possible.
|
||||
Status CompileSingleOp(
|
||||
const XlaCompiler::Options& options,
|
||||
absl::Span<const XlaCompiler::Argument> args, OpKernelContext* ctx,
|
||||
const std::vector<XlaCompiler::Argument>& args, OpKernelContext* ctx,
|
||||
const XlaCompiler::CompileOptions& compile_options,
|
||||
const XlaCompiler::CompilationResult** out_compilation_result,
|
||||
xla::LocalExecutable** out_executable);
@@ -126,10 +135,11 @@ class XlaCompilationCache : public ResourceBase {
// Common implementation of Compile and CompileSingleOp.
|
||||
Status CompileImpl(
|
||||
const XlaCompiler::Options& options, const NameAttrList& function,
|
||||
absl::Span<const XlaCompiler::Argument> args,
|
||||
const std::vector<XlaCompiler::Argument>& args,
|
||||
const std::function<Status(XlaCompiler* compiler,
|
||||
const std::vector<XlaCompiler::Argument>& args,
|
||||
XlaCompiler::CompilationResult*)>& compile_fn,
|
||||
absl::optional<int64> compile_threshold,
|
||||
CompileMode compile_mode,
|
||||
const XlaCompiler::CompilationResult** out_compilation_result,
|
||||
xla::LocalExecutable** out_executable);
@@ -146,8 +156,8 @@ class XlaCompilationCache : public ResourceBase {
struct Entry {
|
||||
mutex mu;
|
||||
|
||||
// Have we tried compiling this entry?
|
||||
bool compiled = false;
|
||||
// The current compilation state for this entry.
|
||||
CompileState compile_state = CompileState::kUncompiled;
|
||||
|
||||
// The number of times a compilation with this signature has been requested.
|
||||
int64 request_count = 0;
@@ -163,6 +173,22 @@ class XlaCompilationCache : public ResourceBase {
std::unique_ptr<xla::LocalExecutable> executable TF_GUARDED_BY(mu);
|
||||
};
|
||||
|
||||
Status CompileStrict(
|
||||
Entry* entry, const XlaCompiler::Options& options,
|
||||
const std::vector<XlaCompiler::Argument>& args,
|
||||
const string& function_name,
|
||||
const std::function<Status(XlaCompiler* compiler,
|
||||
const std::vector<XlaCompiler::Argument>& args,
|
||||
XlaCompiler::CompilationResult*)>& compile_fn)
|
||||
TF_EXCLUSIVE_LOCKS_REQUIRED(entry->mu);
|
||||
Status CompileAsynchronous(
|
||||
Entry* entry, const XlaCompiler::Options& options,
|
||||
const std::vector<XlaCompiler::Argument>& args,
|
||||
const string& function_name,
|
||||
const std::function<Status(XlaCompiler* compiler,
|
||||
const std::vector<XlaCompiler::Argument>& args,
|
||||
XlaCompiler::CompilationResult*)>& compile_fn);
|
||||
|
||||
mutex compile_cache_mu_;
|
||||
absl::flat_hash_map<Signature, std::unique_ptr<Entry>, Signature::Hash> cache_
|
||||
TF_GUARDED_BY(compile_cache_mu_);
@@ -189,6 +215,30 @@ class XlaCompilationCache : public ResourceBase {
absl::flat_hash_map<string, ClusterCompileStats> cluster_compile_stats_
|
||||
TF_GUARDED_BY(cluster_compile_stats_mu_);
|
||||
|
||||
struct AsyncCompilationState {
|
||||
mutex async_compilation_state_mu;
|
||||
|
||||
// Number of threads for asynchronous compilations.
|
||||
static constexpr int64 kNumCompilerThreads = 10;
|
||||
|
||||
// Maximum number of ongoing compilations.
|
||||
static constexpr int64 kMaxNumOngoingCompilations = kNumCompilerThreads;
|
||||
|
||||
// Number of ongoing compilations.
|
||||
int64 num_ongoing_compilations TF_GUARDED_BY(async_compilation_state_mu) =
|
||||
0;
|
||||
|
||||
// Pool of threads for asynchronous compilations.
|
||||
std::unique_ptr<thread::ThreadPool> compiler_threads;
|
||||
|
||||
AsyncCompilationState() {
|
||||
compiler_threads = absl::make_unique<tensorflow::thread::ThreadPool>(
|
||||
tensorflow::Env::Default(), "async_compiler_threads",
|
||||
kNumCompilerThreads);
|
||||
}
|
||||
|
||||
} async_compilation_state_;
|
||||
|
||||
// The number of times a lazy compilation must be requested for a specific
|
||||
// signature before we attempt to compile it.
|
||||
static constexpr int64 kDefaultCompilationThreshold = 2;
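Putting these constants together with CompileMode: a simplified restatement of the decision made in CompileImpl, written as a hypothetical helper (not code from this commit; the megamorphic-cluster and ongoing-compilation checks are omitted).

// kStrict compiles on every cache miss; kLazy waits for the second request of
// a signature (kDefaultCompilationThreshold == 2); kAsync uses a threshold of
// 0, i.e. it queues a background compile on the first request, capped by
// kMaxNumOngoingCompilations.
bool ShouldCompileNow(XlaCompilationCache::CompileMode mode,
                      int64 request_count) {
  switch (mode) {
    case XlaCompilationCache::CompileMode::kStrict:
      return true;
    case XlaCompilationCache::CompileMode::kLazy:
      return request_count >= 2;
    case XlaCompilationCache::CompileMode::kAsync:
      return true;  // threshold 0: any request count qualifies
  }
  return false;
}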
@@ -48,11 +48,11 @@ Status XlaCompileOnDemandOp::Run(OpKernelContext* ctx,
const ResourceVarsSnapshot& variable_args) {
|
||||
xla::LocalClient* client = static_cast<xla::LocalClient*>(cache->client());
|
||||
|
||||
absl::optional<se::TfAllocatorAdapter> tf_allocator_adapter;
|
||||
se::DeviceMemoryAllocator* allocator = GetAllocator(
|
||||
&tf_allocator_adapter, ctx->device(),
|
||||
ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr,
|
||||
platform_info_);
|
||||
se::Stream* stream =
|
||||
ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
|
||||
std::shared_ptr<se::DeviceMemoryAllocator> allocator_ptr =
|
||||
GetAllocator(ctx->device(), stream, platform_info_);
|
||||
se::DeviceMemoryAllocator* allocator = allocator_ptr.get();
|
||||
XlaComputationLaunchContext launch_context(
|
||||
client, allocator, client->default_device_ordinal(),
|
||||
/*allocate_xla_tensors=*/platform_info_.xla_device_metadata() != nullptr,
@@ -74,9 +74,6 @@ Status XlaCompileOnDemandOp::Run(OpKernelContext* ctx,
input_output_alias);
|
||||
TF_RETURN_IF_ERROR(execution_inputs.status());
|
||||
|
||||
se::Stream* stream =
|
||||
ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
|
||||
|
||||
VLOG(2) << "Executing computation: " << name();
|
||||
xla::ExecutableRunOptions run_options;
|
||||
run_options.set_stream(stream);
@@ -126,12 +123,10 @@ Status XlaCompileOnDemandOp::Compile(
write_into_cache);
|
||||
}));
|
||||
|
||||
absl::optional<se::TfAllocatorAdapter> tf_allocator_adapter;
|
||||
XlaCompiler::Options options = GenerateCompilerOptions(
|
||||
**cache, *ctx->function_library(), ctx->device(),
|
||||
ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr,
|
||||
platform_info_,
|
||||
/*has_ref_vars=*/true, &tf_allocator_adapter);
|
||||
platform_info_, /*has_ref_vars=*/true);
|
||||
// No detailed logging from on demand op.
|
||||
options.detailed_logging = false;
|
||||
XlaCompiler::CompileOptions compile_options;
@@ -214,14 +214,18 @@ XlaComputationLaunchContext::XlaComputationLaunchContext(
// Fills in `execution_input` with `buffer` for `index`.
|
||||
static void PopulateExecutionInputBuffer(xla::ExecutionInput& execution_input,
|
||||
xla::ShapeIndex index,
|
||||
se::DeviceMemoryBase& buffer,
|
||||
se::DeviceMemoryBase buffer,
|
||||
bool donate_buffer, int device_ordinal,
|
||||
se::DeviceMemoryAllocator* allocator) {
|
||||
xla::MaybeOwningDeviceMemory* in_buffer =
|
||||
execution_input.MutableBuffer(index);
|
||||
if (donate_buffer) {
|
||||
// Here we pass ownership of the buffer to execution_input without releasing
|
||||
// ownership from the caller of PopulateExecutionInputBuffer. If execution
|
||||
// succeeds, we'll take back that duplicate ownership in
|
||||
// GetOrCreateTensorForOutput. If execution fails, the ExecutionInput will
|
||||
// release that duplicate ownership automatically.
|
||||
*in_buffer = se::OwningDeviceMemory(buffer, device_ordinal, allocator);
|
||||
buffer = se::DeviceMemoryBase();
|
||||
} else {
|
||||
*in_buffer = buffer;
|
||||
}
@@ -308,18 +312,21 @@ static Tensor MakeTensor(DataType dtype, const TensorShape& shape,
return t;
|
||||
}
|
||||
|
||||
// Get aliased tensor, or make a new one for the corresponding output operation.
|
||||
static Tensor GetOrCreateTensorForOutput(
|
||||
int output_num, OpKernelContext* ctx, int missing_ctx_input_prefix,
|
||||
// Get aliased tensor from output, or make a new one for the corresponding
|
||||
// output operation. Transfers ownership of the buffer from output to the
|
||||
// returned tensor.
|
||||
static xla::StatusOr<Tensor> GetOrCreateTensorForOutput(
|
||||
xla::ScopedShapedBuffer& output, int output_num, OpKernelContext* ctx,
|
||||
int missing_ctx_input_prefix,
|
||||
const xla::HloInputOutputAliasConfig& input_output_alias,
|
||||
absl::Span<const int> input_mapping,
|
||||
const std::map<int, const Tensor*>& resource_vars_snapshots,
|
||||
DataType output_dtype, const TensorShape& output_shape,
|
||||
se::DeviceMemoryBase output_buffer, Allocator* output_allocator) {
|
||||
Allocator* output_allocator, bool allocate_xla_tensors, se::Stream* stream,
|
||||
bool use_multiple_streams, std::shared_ptr<se::Event> definition_event) {
|
||||
xla::ShapeIndex output_index = input_output_alias.shape().IsTuple()
|
||||
? xla::ShapeIndex({output_num})
|
||||
: xla::ShapeIndex({});
|
||||
|
||||
CHECK(input_output_alias.shape().IsTuple() || output_num == 0);
|
||||
if (absl::optional<xla::HloInputOutputAliasConfig::Alias> alias =
|
||||
input_output_alias.GetAliasedParameter(output_index)) {
@@ -330,24 +337,39 @@ static Tensor GetOrCreateTensorForOutput(
ctx->input(tf_param).dtype() != DT_RESOURCE
|
||||
? ctx->input(tf_param)
|
||||
: *resource_vars_snapshots.at(missing_ctx_input_prefix + tf_param);
|
||||
if (output_buffer.opaque() == input_tensor.data()) {
|
||||
se::DeviceMemoryBase input_buffer =
|
||||
XlaTensor::DeviceMemoryFromTensor(input_tensor);
|
||||
se::DeviceMemoryBase output_buffer = output.buffer({output_num});
|
||||
if (input_buffer.opaque() == output_buffer.opaque()) {
|
||||
// In the case of a donated buffer, both input_tensor and output think
|
||||
// they have ownership of the buffer (see comment in
|
||||
// PopulateExecutionInputBuffer). Release ownership from output to avoid
|
||||
// double free.
|
||||
output.set_buffer(se::OwningDeviceMemory(), {output_num});
|
||||
return input_tensor;
|
||||
}
|
||||
}
|
||||
return MakeTensor(output_dtype, output_shape, output_buffer,
|
||||
output_allocator);
|
||||
}
|
||||
|
||||
static void PopulateXlaTensor(Tensor* output_tensor,
|
||||
xla::ScopedShapedBuffer* output, int output_num,
|
||||
se::Stream* stream, bool use_multiple_streams,
|
||||
std::shared_ptr<se::Event> definition_event) {
|
||||
XlaTensor* xla_tensor = XlaTensor::FromTensor(output_tensor);
|
||||
CHECK(xla_tensor);
|
||||
xla_tensor->set_shaped_buffer(output->TakeSubTree({output_num}));
|
||||
if (use_multiple_streams) {
|
||||
xla_tensor->ResetDefinitionEvent(definition_event, stream);
|
||||
if (allocate_xla_tensors) {
|
||||
Tensor output_tensor;
|
||||
TF_RETURN_IF_ERROR(
|
||||
ctx->allocate_temp(output_dtype, output_shape, &output_tensor));
|
||||
if (output_tensor.TotalBytes() > 0) {
|
||||
XlaTensor* xla_tensor = XlaTensor::FromTensor(&output_tensor);
|
||||
TF_RET_CHECK(xla_tensor);
|
||||
xla_tensor->set_shaped_buffer(output.TakeSubTree({output_num}));
|
||||
if (use_multiple_streams) {
|
||||
xla_tensor->ResetDefinitionEvent(definition_event, stream);
|
||||
}
|
||||
}
|
||||
return output_tensor;
|
||||
}
|
||||
|
||||
se::DeviceMemoryBase output_buffer = output.buffer({output_num});
|
||||
Tensor output_tensor =
|
||||
MakeTensor(output_dtype, output_shape, output_buffer, output_allocator);
|
||||
output.set_buffer(se::OwningDeviceMemory(), {output_num});
|
||||
return output_tensor;
|
||||
}
|
||||
|
||||
// Sets output `output_num` for `ctx` provided it is known at a compile time.
@@ -525,22 +547,15 @@ Status XlaComputationLaunchContext::PopulateOutputs(
<< "Invalid input for outputs " << i << ": " << input_index;
|
||||
ctx->set_output(i, ctx->input(input_index));
|
||||
} else {
|
||||
if (allocate_xla_tensors_) {
|
||||
Tensor* output_tensor;
|
||||
TF_RETURN_IF_ERROR(ctx->allocate_output(i, shape, &output_tensor));
|
||||
if (output_tensor->TotalBytes() > 0) {
|
||||
PopulateXlaTensor(output_tensor, &output, output_num, stream,
|
||||
use_multiple_streams_, definition_event);
|
||||
}
|
||||
} else {
|
||||
se::DeviceMemoryBase buffer = output.buffer({output_num});
|
||||
Tensor output_tensor = GetOrCreateTensorForOutput(
|
||||
output_num, ctx, missing_ctx_input_prefix, input_output_alias,
|
||||
compilation_result->input_mapping, resource_vars,
|
||||
ctx->expected_output_dtype(i), shape, buffer, allocator);
|
||||
ctx->set_output(i, output_tensor);
|
||||
}
|
||||
output.set_buffer(se::OwningDeviceMemory(), {output_num});
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
Tensor output_tensor,
|
||||
GetOrCreateTensorForOutput(
|
||||
output, output_num, ctx, missing_ctx_input_prefix,
|
||||
input_output_alias, compilation_result->input_mapping,
|
||||
resource_vars, ctx->expected_output_dtype(i), shape, allocator,
|
||||
allocate_xla_tensors_, stream, use_multiple_streams_,
|
||||
definition_event));
|
||||
ctx->set_output(i, output_tensor);
|
||||
++output_num;
|
||||
}
|
||||
}
@@ -571,22 +586,14 @@ Status XlaComputationLaunchContext::PopulateOutputs(
return errors::Internal("Mismatched type in variable write");
|
||||
}
|
||||
|
||||
Tensor output_tensor;
|
||||
if (allocate_xla_tensors_) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
ctx->allocate_temp(write.type, write.shape, &output_tensor));
|
||||
if (write.shape.num_elements() > 0) {
|
||||
PopulateXlaTensor(&output_tensor, &output, output_num, stream,
|
||||
use_multiple_streams_, definition_event);
|
||||
}
|
||||
} else {
|
||||
se::DeviceMemoryBase buffer = output.buffer({output_num});
|
||||
output_tensor = GetOrCreateTensorForOutput(
|
||||
output_num, ctx, missing_ctx_input_prefix, input_output_alias,
|
||||
compilation_result->input_mapping, resource_vars, write.type,
|
||||
write.shape, buffer, allocator);
|
||||
}
|
||||
output.set_buffer(se::OwningDeviceMemory(), {output_num});
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
Tensor output_tensor,
|
||||
GetOrCreateTensorForOutput(output, output_num, ctx,
|
||||
missing_ctx_input_prefix, input_output_alias,
|
||||
compilation_result->input_mapping,
|
||||
resource_vars, write.type, write.shape,
|
||||
allocator, allocate_xla_tensors_, stream,
|
||||
use_multiple_streams_, definition_event));
|
||||
var->is_initialized |= write.modified;
|
||||
*var->tensor() = output_tensor;
|
||||
++output_num;
|
||||
|
|
|
|||
|
|
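The two hunks above collapse the old allocate_xla_tensors_ / else branches into a single GetOrCreateTensorForOutput call bound through TF_ASSIGN_OR_RETURN. For readers unfamiliar with that macro family, the sketch below is a minimal stand-in showing the early-return pattern it relies on; the Status, StatusOr, and ASSIGN_OR_RETURN definitions here are simplified illustrations, not TensorFlow's real implementations.

// --- illustrative sketch (not part of this change) ---
#include <iostream>
#include <string>
#include <utility>

// Minimal stand-ins; TensorFlow's real Status and StatusOr are richer.
struct Status {
  bool ok_flag = true;
  std::string message;
  bool ok() const { return ok_flag; }
};

template <typename T>
class StatusOr {
 public:
  StatusOr(T value) : value_(std::move(value)) {}          // success
  StatusOr(Status status) : status_(std::move(status)) {}  // failure
  bool ok() const { return status_.ok(); }
  const Status& status() const { return status_; }
  T& value() { return value_; }

 private:
  T value_{};
  Status status_;
};

// Simplified ASSIGN_OR_RETURN: evaluate the expression once, return its error
// status early, otherwise bind the value to `lhs`. (The real macro also gives
// the temporary a collision-free name; omitted here for brevity.)
#define ASSIGN_OR_RETURN(lhs, rexpr)                     \
  auto statusor_tmp = (rexpr);                           \
  if (!statusor_tmp.ok()) return statusor_tmp.status();  \
  lhs = std::move(statusor_tmp.value());

StatusOr<int> ComputeOutput(bool fail) {
  if (fail) return Status{false, "allocation failed"};
  return 42;
}

Status Populate(bool fail) {
  ASSIGN_OR_RETURN(int output, ComputeOutput(fail));
  std::cout << "output = " << output << "\n";
  return Status{};
}

int main() {
  Populate(/*fail=*/false);                           // prints output = 42
  Status s = Populate(/*fail=*/true);                 // early-returns the error
  std::cout << "second call ok? " << s.ok() << "\n";  // prints 0
}
// --- end sketch ---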
@@ -79,7 +79,7 @@ XlaPlatformInfo XlaPlatformInfoFromDevice(DeviceBase* device_base) {
auto device = static_cast<Device*>(device_base);
se::Platform::Id platform_id = nullptr;
const XlaDevice::Metadata* xla_device_metadata = nullptr;
se::DeviceMemoryAllocator* custom_allocator = nullptr;
std::shared_ptr<se::DeviceMemoryAllocator> custom_allocator;

if (device->device_type() == DEVICE_CPU) {
platform_id = se::host::kHostPlatformId;
@@ -101,37 +101,35 @@ XlaPlatformInfo XlaPlatformInfoFromDevice(DeviceBase* device_base) {
// allocator to allocate real buffers.
platform_id = xla_device_metadata->platform()->id();
custom_allocator =
xla_device_metadata->client()->backend().memory_allocator();
xla_device_metadata->client()->backend().shared_memory_allocator();
}

return XlaPlatformInfo(DeviceType(device->device_type()), platform_id,
xla_device_metadata, custom_allocator);
}

se::DeviceMemoryAllocator* GetAllocator(
absl::optional<se::TfAllocatorAdapter>* tf_allocator_adapter,
std::shared_ptr<se::DeviceMemoryAllocator> GetAllocator(
DeviceBase* device, se::Stream* stream,
const XlaPlatformInfo& platform_info) {
if (platform_info.custom_allocator()) {
return platform_info.custom_allocator();
}
auto* alloc = device->GetAllocator({});
if (!stream) {
// Stream is not set for the host platform.
se::Platform* platform =
se::MultiPlatformManager::PlatformWithId(platform_info.platform_id())
.ValueOrDie();
tf_allocator_adapter->emplace(device->GetAllocator({}), platform);
return &tf_allocator_adapter->value();
return std::make_shared<se::TfAllocatorAdapter>(alloc, platform);
}
tf_allocator_adapter->emplace(device->GetAllocator({}), stream);
return &tf_allocator_adapter->value();
return std::make_shared<se::TfAllocatorAdapter>(alloc, stream);
}

XlaCompiler::Options GenerateCompilerOptions(
const XlaCompilationCache& cache,
const FunctionLibraryRuntime& function_library, DeviceBase* device,
se::Stream* stream, const XlaPlatformInfo& platform_info, bool has_ref_vars,
absl::optional<se::TfAllocatorAdapter>* tf_allocator_adapter) {
se::Stream* stream, const XlaPlatformInfo& platform_info,
bool has_ref_vars) {
XlaCompiler::Options options;
options.client = static_cast<xla::LocalClient*>(cache.client());
if (stream != nullptr) {
@@ -142,8 +140,7 @@ XlaCompiler::Options GenerateCompilerOptions(
options.graph_def_version = function_library.graph_def_version();
options.allow_cpu_custom_calls =
(platform_info.platform_id() == se::host::kHostPlatformId);
options.device_allocator =
GetAllocator(tf_allocator_adapter, device, stream, platform_info);
options.device_allocator = GetAllocator(device, stream, platform_info);
if (platform_info.xla_device_metadata()) {
options.shape_representation_fn =
platform_info.xla_device_metadata()->shape_representation_fn();

@@ -29,10 +29,10 @@ class XlaPlatformInfo {
public:
XlaPlatformInfo() : device_type_("") {}
XlaPlatformInfo(XlaPlatformInfo&&) = default;
explicit XlaPlatformInfo(const DeviceType device_type,
se::Platform::Id platform_id,
const XlaDevice::Metadata* xla_device_metadata,
se::DeviceMemoryAllocator* device_allocator)
explicit XlaPlatformInfo(
const DeviceType device_type, se::Platform::Id platform_id,
const XlaDevice::Metadata* xla_device_metadata,
std::shared_ptr<se::DeviceMemoryAllocator> device_allocator)
: device_type_(device_type),
platform_id_(platform_id),
xla_device_metadata_(xla_device_metadata),
@@ -45,7 +45,7 @@ class XlaPlatformInfo {
}

// Non-null only when run on an XLA device.
se::DeviceMemoryAllocator* custom_allocator() const {
std::shared_ptr<se::DeviceMemoryAllocator> custom_allocator() const {
return device_allocator_;
}

@@ -74,7 +74,9 @@ class XlaPlatformInfo {
// If the op associated with this XlaPlatformInfo is placed on an XLA device
// then device_allocator_ is the xla::Backend's memory allocator. If the op
// is placed on a regular CPU or GPU device then device_allocator_ is null.
se::DeviceMemoryAllocator* device_allocator_;
// The allocator is of unknown provenance; keep it in a shared pointer to
// set an artificial refcount of one.
std::shared_ptr<se::DeviceMemoryAllocator> device_allocator_;

TF_DISALLOW_COPY_AND_ASSIGN(XlaPlatformInfo);
};
@@ -94,8 +96,7 @@ XlaPlatformInfo XlaPlatformInfoFromDevice(DeviceBase* device);
// dummy tensors.
//
// `stream` parameter is nullable when running on host.
se::DeviceMemoryAllocator* GetAllocator(
absl::optional<se::TfAllocatorAdapter>* tf_allocator_adapter,
std::shared_ptr<se::DeviceMemoryAllocator> GetAllocator(
DeviceBase* device, se::Stream* stream,
const XlaPlatformInfo& platform_info);

@@ -104,8 +105,8 @@ se::DeviceMemoryAllocator* GetAllocator(
XlaCompiler::Options GenerateCompilerOptions(
const XlaCompilationCache& cache,
const FunctionLibraryRuntime& function_library, DeviceBase* device,
se::Stream* stream, const XlaPlatformInfo& platform_info, bool has_ref_vars,
absl::optional<se::TfAllocatorAdapter>* tf_allocator_adapter);
se::Stream* stream, const XlaPlatformInfo& platform_info,
bool has_ref_vars);

} // namespace tensorflow

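The new XlaPlatformInfo stores its allocator as std::shared_ptr<se::DeviceMemoryAllocator>, and the comment above describes keeping a borrowed backend allocator "with an artificial refcount of one". One common way to achieve that is a shared_ptr with a no-op deleter, so borrowed backend allocators and freshly created adapters can both flow through the same shared_ptr-returning GetAllocator signature. The sketch below illustrates the idea with placeholder types; the class names are assumptions for illustration, not the real StreamExecutor API.

// --- illustrative sketch (not part of this change) ---
#include <iostream>
#include <memory>

// Placeholder for se::DeviceMemoryAllocator.
struct DeviceMemoryAllocator {
  virtual ~DeviceMemoryAllocator() = default;
  virtual const char* name() const { return "backend allocator"; }
};

// Placeholder for an owned adapter such as se::TfAllocatorAdapter.
struct TfAllocatorAdapter : DeviceMemoryAllocator {
  const char* name() const override { return "owned adapter"; }
};

// Case 1: the allocator is owned elsewhere (e.g. by the XLA backend).
// Wrap it with a no-op deleter so the shared_ptr never frees it.
std::shared_ptr<DeviceMemoryAllocator> Borrow(DeviceMemoryAllocator* backend_owned) {
  return std::shared_ptr<DeviceMemoryAllocator>(backend_owned,
                                                [](DeviceMemoryAllocator*) {});
}

// Case 2: the adapter is created here; the shared_ptr owns and deletes it.
std::shared_ptr<DeviceMemoryAllocator> Own() {
  return std::make_shared<TfAllocatorAdapter>();
}

int main() {
  DeviceMemoryAllocator backend;     // lifetime managed elsewhere
  auto borrowed = Borrow(&backend);  // refcount 1, deleter is a no-op
  auto owned = Own();                // refcount 1, deleted on release
  std::cout << borrowed->name() << " / " << owned->name() << "\n";
}
// --- end sketch ---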
@ -51,6 +51,7 @@ td_library(
|
|||
compatible_with = get_compatible_with_cloud(),
|
||||
includes = ["include"],
|
||||
deps = [
|
||||
"@llvm-project//mlir:MemRefOpsTdFiles",
|
||||
"@llvm-project//mlir:OpBaseTdFiles",
|
||||
"@llvm-project//mlir:SideEffectTdFiles",
|
||||
],
|
||||
|
|
@ -445,6 +446,7 @@ cc_library(
|
|||
"@llvm-project//mlir:Analysis",
|
||||
"@llvm-project//mlir:CopyOpInterface",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:MemRefDialect",
|
||||
"@llvm-project//mlir:Pass",
|
||||
"@llvm-project//mlir:SideEffects",
|
||||
"@llvm-project//mlir:StandardOps",
|
||||
|
|
@ -598,6 +600,7 @@ cc_library(
|
|||
"@llvm-project//llvm:Support",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:LinalgOps",
|
||||
"@llvm-project//mlir:MemRefDialect",
|
||||
"@llvm-project//mlir:Pass",
|
||||
"@llvm-project//mlir:SCFDialect",
|
||||
"@llvm-project//mlir:StandardOps",
|
||||
|
|
@ -707,6 +710,7 @@ cc_library(
|
|||
"@llvm-project//mlir:Affine",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:LinalgTransforms",
|
||||
"@llvm-project//mlir:MemRefDialect",
|
||||
"@llvm-project//mlir:Pass",
|
||||
"@llvm-project//mlir:SCFDialect",
|
||||
"@llvm-project//mlir:StandardOps",
|
||||
|
|
@ -939,6 +943,7 @@ cc_library(
|
|||
":hlo",
|
||||
"@llvm-project//llvm:Support",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:MemRefDialect",
|
||||
"@llvm-project//mlir:StandardOps",
|
||||
"@llvm-project//mlir:TensorDialect",
|
||||
"@llvm-project//mlir:Transforms",
|
||||
|
|
|
|||
|
|
@ -119,31 +119,30 @@ def HLO_CreateTokenOp : HLO_Op<"create_token", [NoSideEffect]> {
|
|||
// See https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions
|
||||
|
||||
class HLO_UnaryElementwiseOp<string mnemonic, list<OpTrait> traits,
|
||||
Type TensorType>: HLO_Op<mnemonic,
|
||||
!listconcat(traits,
|
||||
[InferShapedTypeOpInterface, InferFusibilityOpInterface,
|
||||
SameOperandsAndResultShape])> {
|
||||
let arguments = (ins TensorType:$operand);
|
||||
let results = (outs TensorType);
|
||||
let extraClassDeclaration = [{
|
||||
static LogicalResult inferReturnTypeComponents(
|
||||
MLIRContext* context, Optional<Location> location,
|
||||
ValueRange operands, DictionaryAttr attributes, RegionRange regions,
|
||||
SmallVectorImpl<ShapedTypeComponents>& inferedReturnShapes) {
|
||||
return failure();
|
||||
}
|
||||
LogicalResult reifyReturnTypeShapes(
|
||||
OpBuilder& builder, SmallVectorImpl<Value>& reifiedReturnShapes) {
|
||||
return ::mlir::mhlo::deriveShapeFromFirstOperand(&builder, getOperation(),
|
||||
&reifiedReturnShapes);
|
||||
}
|
||||
bool inferInputOutputShapeEquality(int input, int output) {
|
||||
return true;
|
||||
}
|
||||
llvm::Optional<Value> inferEffectiveWorkloadShape() {
|
||||
return getOperation()->getResult(0);
|
||||
}
|
||||
}];
|
||||
Type TensorType> : HLO_Op<mnemonic, traits # [Elementwise,
|
||||
InferShapedTypeOpInterface, InferFusibilityOpInterface,
|
||||
SameOperandsAndResultShape]> {
|
||||
let arguments = (ins TensorType:$operand);
|
||||
let results = (outs TensorType);
|
||||
let extraClassDeclaration = [{
|
||||
static LogicalResult inferReturnTypeComponents(
|
||||
MLIRContext* context, Optional<Location> location,
|
||||
ValueRange operands, DictionaryAttr attributes, RegionRange regions,
|
||||
SmallVectorImpl<ShapedTypeComponents>& inferedReturnShapes) {
|
||||
return failure();
|
||||
}
|
||||
LogicalResult reifyReturnTypeShapes(
|
||||
OpBuilder& builder, SmallVectorImpl<Value>& reifiedReturnShapes) {
|
||||
return ::mlir::mhlo::deriveShapeFromFirstOperand(&builder, getOperation(),
|
||||
&reifiedReturnShapes);
|
||||
}
|
||||
bool inferInputOutputShapeEquality(int input, int output) {
|
||||
return true;
|
||||
}
|
||||
llvm::Optional<Value> inferEffectiveWorkloadShape() {
|
||||
return getOperation()->getResult(0);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
// Abs supports complex to real, so element type is not guaranteed to match.
|
||||
|
|
@ -826,8 +825,9 @@ def HLO_BroadcastInDimOp : HLO_Op<"broadcast_in_dim",
|
|||
let hasCustomHLOConverter = 1;
|
||||
}
|
||||
|
||||
def HLO_DynamicBroadcastInDimOp : HLO_Op<"dynamic_broadcast_in_dim",
|
||||
[NoSideEffect]> {
|
||||
def HLO_DynamicBroadcastInDimOp : HLO_Op<"dynamic_broadcast_in_dim", [
|
||||
NoSideEffect, DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
|
||||
["reifyReturnTypeShapes"]>]> {
|
||||
string summary = "Broadcast a tensor into the given dynamic shape by adding dimensions.";
|
||||
string description = [{
|
||||
This is a generalization of the BroadcastInDimOp which accepts its output
|
||||
|
|
|
|||
|
|
@ -21,6 +21,7 @@ limitations under the License.
|
|||
#include "llvm/ADT/StringRef.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops_structs.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/IR/Attributes.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
|
|
|
|||
|
|
@ -33,6 +33,7 @@ limitations under the License.
|
|||
#ifndef LHLO_OPS
|
||||
#define LHLO_OPS
|
||||
|
||||
include "mlir/Dialect/MemRef/IR/MemRefBase.td"
|
||||
include "mlir/IR/OpBase.td"
|
||||
include "mlir/Interfaces/CopyOpInterface.td"
|
||||
include "mlir/Interfaces/SideEffectInterfaces.td"
|
||||
|
|
@ -685,7 +686,7 @@ def FusionOp : LHLO_Op<"fusion", [SingleBlockImplicitTerminator<"TerminatorOp">]
|
|||
let extraClassDeclaration = [{
|
||||
SmallVector<Value, 4> getInputBuffers() {
|
||||
SmallVector<Value, 4> buffers;
|
||||
this->region().walk([&](TensorLoadOp load) {
|
||||
this->region().walk([&](memref::TensorLoadOp load) {
|
||||
if (load.memref().getParentRegion()->isProperAncestor(®ion()))
|
||||
buffers.push_back(load.memref());
|
||||
});
|
||||
|
|
@ -694,7 +695,7 @@ def FusionOp : LHLO_Op<"fusion", [SingleBlockImplicitTerminator<"TerminatorOp">]
|
|||
|
||||
SmallVector<Value, 4> getOutputBuffers() {
|
||||
SmallVector<Value, 4> buffers;
|
||||
this->region().walk([&](TensorStoreOp store) {
|
||||
this->region().walk([&](memref::TensorStoreOp store) {
|
||||
if (store.memref().getParentRegion()->isProperAncestor(®ion()))
|
||||
buffers.push_back(store.memref());
|
||||
});
|
||||
|
|
@ -703,7 +704,7 @@ def FusionOp : LHLO_Op<"fusion", [SingleBlockImplicitTerminator<"TerminatorOp">]
|
|||
|
||||
SmallVector<Value, 4> getFusionParameters() {
|
||||
SmallVector<Value, 4> buffers;
|
||||
this->region().walk([&](TensorLoadOp load) {
|
||||
this->region().walk([&](memref::TensorLoadOp load) {
|
||||
if (load.memref().getParentRegion()->isProperAncestor(®ion()))
|
||||
buffers.push_back(load);
|
||||
});
|
||||
|
|
@ -712,7 +713,7 @@ def FusionOp : LHLO_Op<"fusion", [SingleBlockImplicitTerminator<"TerminatorOp">]
|
|||
|
||||
SmallVector<Value, 4> getFusionResults() {
|
||||
SmallVector<Value, 4> buffers;
|
||||
this->region().walk([&](TensorStoreOp store) {
|
||||
this->region().walk([&](memref::TensorStoreOp store) {
|
||||
if (store.memref().getParentRegion()->isProperAncestor(®ion()))
|
||||
buffers.push_back(store.tensor());
|
||||
});
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ limitations under the License.
|
|||
#ifndef LHLO_OPS_BASE
|
||||
#define LHLO_OPS_BASE
|
||||
|
||||
include "mlir/Dialect/MemRef/IR/MemRefBase.td"
|
||||
include "mlir/IR/OpBase.td"
|
||||
include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td"
|
||||
|
||||
|
|
|
|||
|
|
@ -597,12 +597,18 @@ OpFoldResult GetTupleElementOp::fold(ArrayRef<Attribute> operands) {
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static LogicalResult Verify(TupleOp op) {
|
||||
SmallVector<Type, 8> operandTypes = {op.operand_type_begin(),
|
||||
op.operand_type_end()};
|
||||
auto expectedType = TupleType::get(op.getContext(), operandTypes);
|
||||
if (op.getType() != expectedType) {
|
||||
return op.emitOpError(llvm::formatv("has return type {0}, but expected {1}",
|
||||
op.getType(), expectedType));
|
||||
auto opType = op.getType().dyn_cast<TupleType>();
|
||||
if (!opType) return op.emitOpError("tuple op with non-tuple result");
|
||||
if (op.getNumOperands() != opType.size())
|
||||
return op.emitOpError(
|
||||
"number of operands to tuple expected to match number of types in "
|
||||
"resultant tuple type");
|
||||
for (auto it : llvm::enumerate(
|
||||
llvm::zip_first(op.getOperandTypes(), opType.getTypes()))) {
|
||||
if (std::get<0>(it.value()) != std::get<1>(it.value()))
|
||||
return op.emitOpError("has return type mismatch at ")
|
||||
<< it.index() << "th value (" << std::get<0>(it.value())
|
||||
<< " != " << std::get<1>(it.value()) << ")";
|
||||
}
|
||||
return success();
|
||||
}
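The rewritten Verify(TupleOp) above checks the result tuple element by element so a mismatch can be reported with its position instead of comparing one rebuilt TupleType wholesale. The sketch below mirrors that check in plain C++ with strings standing in for MLIR types; it is illustrative only.

// --- illustrative sketch (not part of this change) ---
#include <iostream>
#include <string>
#include <vector>

// Strings stand in for MLIR types; returns an empty string on success,
// otherwise a diagnostic shaped like the verifier's messages above.
std::string VerifyTuple(const std::vector<std::string>& operand_types,
                        const std::vector<std::string>& result_tuple_types) {
  if (operand_types.size() != result_tuple_types.size())
    return "number of operands to tuple expected to match number of types "
           "in resultant tuple type";
  for (size_t i = 0; i < operand_types.size(); ++i) {
    if (operand_types[i] != result_tuple_types[i])
      return "has return type mismatch at " + std::to_string(i) +
             "th value (" + operand_types[i] + " != " +
             result_tuple_types[i] + ")";
  }
  return "";  // success
}

int main() {
  std::cout << VerifyTuple({"i32", "f32"}, {"i32", "f32"}) << "\n";  // ok
  std::cout << VerifyTuple({"i32", "f32"}, {"i32", "i32"}) << "\n";  // mismatch at 1
}
// --- end sketch ---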
|
||||
|
|
@ -902,6 +908,18 @@ void DynamicBroadcastInDimOp::getCanonicalizationPatterns(
|
|||
context);
|
||||
}
|
||||
|
||||
LogicalResult DynamicBroadcastInDimOp::inferReturnTypeComponents(
|
||||
MLIRContext*, llvm::Optional<mlir::Location>, ValueRange, DictionaryAttr,
|
||||
RegionRange, llvm::SmallVectorImpl<mlir::ShapedTypeComponents>&) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
LogicalResult DynamicBroadcastInDimOp::reifyReturnTypeShapes(
|
||||
OpBuilder&, SmallVectorImpl<Value>& reifiedReturnShapes) {
|
||||
reifiedReturnShapes.push_back(output_dimensions());
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ClampOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
|||
|
|
@ -36,12 +36,13 @@ def DynamicBroadcastToOwnShape_4 : Pat<
|
|||
(HLO_DynamicBroadcastInDimOp:$op $x, (Tensor_CastOp (Shape_ShapeOfOp $x)), $attr),
|
||||
(Tensor_CastOp $x)>;
|
||||
|
||||
def ShapeOfDynamicReshape : Pat<
|
||||
(Shape_ShapeOfOp (HLO_DynamicReshapeOp $x, $shape)),
|
||||
(replaceWithValue $shape)>;
|
||||
|
||||
def HasSameType : Constraint<CPred<"$0.getType() == $1.getType()">>;
|
||||
|
||||
def ShapeOfDynamicReshape : Pat<
|
||||
(Shape_ShapeOfOp:$op (HLO_DynamicReshapeOp $x, $shape)),
|
||||
(replaceWithValue $shape),
|
||||
[(HasSameType $shape, $op)]>;
|
||||
|
||||
def IdentityBroadcastReshape : Pat<
|
||||
(HLO_ReshapeOp:$op (HLO_BroadcastOp $input, $dims)),
|
||||
(replaceWithValue $input),
|
||||
|
|
|
|||
|
|
@ -33,6 +33,7 @@ limitations under the License.
|
|||
#include "llvm/Support/FormatVariadic.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_common.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h.inc"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/IR/Attributes.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
|
|
@ -156,13 +157,13 @@ struct EraseConstOp : public OpRewritePattern<ConstOp> {
|
|||
LogicalResult matchAndRewrite(ConstOp op,
|
||||
PatternRewriter& rewriter) const override {
|
||||
Value memref = op.output();
|
||||
if (!memref.getDefiningOp<AllocOp>()) {
|
||||
if (!memref.getDefiningOp<memref::AllocOp>()) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
// Check that all uses of the memref are either DeallocOps or this op.
|
||||
for (Operation* user : memref.getUsers())
|
||||
if (user != op && !isa<DeallocOp>(user)) return failure();
|
||||
if (user != op && !isa<memref::DeallocOp>(user)) return failure();
|
||||
|
||||
rewriter.eraseOp(op);
|
||||
return success();
|
||||
|
|
|
|||
|
|
@ -71,7 +71,7 @@ Value InsertDynamicAllocAndDealloc(Location loc, Value result,
|
|||
dynamic_operands.push_back(alloc_operand);
|
||||
}
|
||||
|
||||
return rewriter->create<AllocOp>(loc, memref_type, dynamic_operands);
|
||||
return rewriter->create<memref::AllocOp>(loc, memref_type, dynamic_operands);
|
||||
}
|
||||
|
||||
Value InsertAlloc(Location loc, OpResult result,
|
||||
|
|
@ -85,7 +85,7 @@ Value InsertAlloc(Location loc, OpResult result,
|
|||
MemRefType::get(result_type.getShape(), result_type.getElementType());
|
||||
OpBuilder::InsertionGuard guard(*rewriter);
|
||||
rewriter->setInsertionPoint(result.getDefiningOp());
|
||||
auto alloc = rewriter->create<AllocOp>(loc, memref_type);
|
||||
auto alloc = rewriter->create<memref::AllocOp>(loc, memref_type);
|
||||
return alloc;
|
||||
}
|
||||
|
||||
|
|
@ -207,7 +207,7 @@ class HloToLhloReshapeUnrankedConverter
|
|||
if (unranked_operand_type == nullptr) return failure();
|
||||
|
||||
auto result_type = op.getType().cast<RankedTensorType>();
|
||||
rewriter.replaceOpWithNewOp<MemRefCastOp>(
|
||||
rewriter.replaceOpWithNewOp<memref::CastOp>(
|
||||
op, adaptor.operand(),
|
||||
MemRefType::get(result_type.getShape(), result_type.getElementType()));
|
||||
return success();
|
||||
|
|
@ -235,7 +235,7 @@ class HloToLhloDynamicReshapeConverter
|
|||
return failure();
|
||||
}
|
||||
mhlo::DynamicReshapeOp::Adaptor adaptor(operands);
|
||||
rewriter.replaceOpWithNewOp<MemRefReshapeOp>(
|
||||
rewriter.replaceOpWithNewOp<memref::ReshapeOp>(
|
||||
op, result_type, adaptor.operand(), adaptor.output_shape());
|
||||
return success();
|
||||
}
|
||||
|
|
@ -273,7 +273,7 @@ class HloToLhloDynamicBroadcastInDimOpConverter
|
|||
// Inserts dynamic memref to change the layout of the memref to put 0-stride
|
||||
// and size of the target dimension if size-1 dimension expansion is
|
||||
// necessary.
|
||||
MemRefReinterpretCastOp InsertDynamicMemrefCastOp(
|
||||
memref::ReinterpretCastOp InsertDynamicMemrefCastOp(
|
||||
mhlo::DynamicBroadcastInDimOp op, Value operand, OpBuilder* b) const {
|
||||
auto loc = op.getLoc();
|
||||
auto operand_type = operand.getType().cast<MemRefType>();
|
||||
|
|
@ -295,7 +295,7 @@ class HloToLhloDynamicBroadcastInDimOpConverter
|
|||
for (int i = operand_rank - 1; i >= 0; --i) {
|
||||
Value operand_dim_size =
|
||||
ShapedType::isDynamic(operand_shape[i])
|
||||
? b->create<DimOp>(loc, operand, i).getResult()
|
||||
? b->create<memref::DimOp>(loc, operand, i).getResult()
|
||||
: b->create<ConstantIndexOp>(loc, operand_shape[i]).getResult();
|
||||
operand_sizes[i] = operand_dim_size;
|
||||
|
||||
|
|
@ -355,7 +355,7 @@ class HloToLhloDynamicBroadcastInDimOpConverter
|
|||
makeStridedLinearLayoutMap(dynamic_layout,
|
||||
/*offset=*/0, b->getContext()));
|
||||
|
||||
auto transformed_operand = b->create<MemRefReinterpretCastOp>(
|
||||
auto transformed_operand = b->create<memref::ReinterpretCastOp>(
|
||||
loc, type_erased_memref_type, operand,
|
||||
/*offset=*/b->getI64IntegerAttr(0), sizes, strides);
|
||||
return transformed_operand;
|
||||
|
|
@ -484,12 +484,12 @@ struct HloToLhloReturnOpConverter : public BaseOpConversion<mhlo::ReturnOp> {
|
|||
|
||||
// TODO(b/175789537) Remove this pattern.
|
||||
class HloToLhloTensorStoreOpLegacyConverter
|
||||
: public BaseOpConversion<mlir::TensorStoreOp> {
|
||||
: public BaseOpConversion<mlir::memref::TensorStoreOp> {
|
||||
public:
|
||||
using BaseOpConversion<mlir::TensorStoreOp>::BaseOpConversion;
|
||||
using BaseOpConversion<mlir::memref::TensorStoreOp>::BaseOpConversion;
|
||||
|
||||
LogicalResult matchAndRewrite(
|
||||
mlir::TensorStoreOp op, ArrayRef<Value> operands,
|
||||
mlir::memref::TensorStoreOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter& rewriter) const final {
|
||||
rewriter.replaceOpWithNewOp<lmhlo::CopyOp>(op, llvm::None, operands.front(),
|
||||
operands.back());
|
||||
|
|
@ -577,14 +577,16 @@ struct HloLegalizeToLhlo
|
|||
ConversionTarget target(context);
|
||||
target.addLegalDialect<lmhlo::LmhloDialect>();
|
||||
target.addLegalDialect<StandardOpsDialect>();
|
||||
target.addLegalDialect<memref::MemRefDialect>();
|
||||
target.addLegalDialect<shape::ShapeDialect>();
|
||||
target.addLegalDialect<tensor::TensorDialect>();
|
||||
target.addIllegalDialect<mhlo::MhloDialect>();
|
||||
// Declare tensor_load and tensor_store illegal.
|
||||
target.addIllegalOp<mlir::TensorLoadOp, mlir::TensorStoreOp>();
|
||||
// tensor_to_memref is illegal if it has uses.
|
||||
// TODO(b/175670649) Make tensor_to_memref illegal.
|
||||
target.addDynamicallyLegalOp<mlir::TensorToMemrefOp>(
|
||||
target.addIllegalOp<mlir::memref::TensorLoadOp,
|
||||
mlir::memref::TensorStoreOp>();
|
||||
// buffer_cast is illegal if it has uses.
|
||||
// TODO(b/175670649) Make buffer_cast illegal.
|
||||
target.addDynamicallyLegalOp<mlir::memref::BufferCastOp>(
|
||||
[](auto op) { return op->use_empty(); });
|
||||
|
||||
BufferizeTypeConverter converter;
|
||||
|
|
|
|||
|
|
@ -108,7 +108,7 @@ SmallVector<Value, 2> ExtractDynamicSizes(OpBuilder& b, Location loc,
|
|||
dyn_sizes.push_back(
|
||||
b.create<IndexCastOp>(loc, b.getIndexType(), extract));
|
||||
} else {
|
||||
dyn_sizes.push_back(b.create<DimOp>(loc, tensor, en.index()));
|
||||
dyn_sizes.push_back(b.create<memref::DimOp>(loc, tensor, en.index()));
|
||||
}
|
||||
}
|
||||
return dyn_sizes;
|
||||
|
|
@ -324,13 +324,13 @@ class ScalarPointwiseToStandardConverter : public OpConversionPattern<LhloOp> {
|
|||
}
|
||||
|
||||
// Create two loads from the input.
|
||||
auto lhs = rewriter.create<LoadOp>(loc, lhlo_op.lhs());
|
||||
auto rhs = rewriter.create<LoadOp>(loc, lhlo_op.rhs());
|
||||
auto lhs = rewriter.create<memref::LoadOp>(loc, lhlo_op.lhs());
|
||||
auto rhs = rewriter.create<memref::LoadOp>(loc, lhlo_op.rhs());
|
||||
// TODO(ravishankarm) : Move this method out of lmhlo namespace.
|
||||
Value op_result = lmhlo::HloOpToStdScalarOp::map<LhloOp>(
|
||||
lhlo_op, arg_type.getElementType(), llvm::ArrayRef<Value>{lhs, rhs},
|
||||
&rewriter);
|
||||
rewriter.create<StoreOp>(loc, op_result, lhlo_op.out());
|
||||
rewriter.create<memref::StoreOp>(loc, op_result, lhlo_op.out());
|
||||
rewriter.eraseOp(lhlo_op);
|
||||
return success();
|
||||
}
|
||||
|
|
@ -518,15 +518,20 @@ class HloDynamicBroadcastInDimConverter
|
|||
LogicalResult matchAndRewrite(
|
||||
mhlo::DynamicBroadcastInDimOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter& rewriter) const final {
|
||||
// Convert only if the producer is an HLO constant. Ideally the pattern
|
||||
// (`mhlo.constant` -> `mhlo.dynamic_broadcast_in_dim`) should be converted
|
||||
// to an Tensor-dialect op similar to TF ConstantLikeOp.
|
||||
if (!op.operand().getDefiningOp<mhlo::ConstOp>()) return failure();
|
||||
// If the input has a static shape we know exactly when the broadcast must
|
||||
// expand (the dimension is 1, which also trivially expands to 1) or will
|
||||
// never expand (the dimension is not 1). This means we can lower the
|
||||
// broadcast just as we would lower a fully static broadcast and go directly
|
||||
// to linalg.generic. This also covers the important case of broadcasting a
|
||||
// scalar.
|
||||
|
||||
// Ideally the pattern (`mhlo.constant` -> `mhlo.dynamic_broadcast_in_dim`)
|
||||
// should be converted to an Tensor-dialect op similar to TF ConstantLikeOp.
|
||||
|
||||
mhlo::DynamicBroadcastInDimOp::Adaptor adaptor(op);
|
||||
Value operand = adaptor.operand();
|
||||
auto operand_type = operand.getType().dyn_cast<RankedTensorType>();
|
||||
if (!operand_type || operand_type.getRank() != 0) return failure();
|
||||
if (!operand_type || !operand_type.hasStaticShape()) return failure();
|
||||
|
||||
Value shape = adaptor.output_dimensions();
|
||||
auto shape_type = shape.getType().cast<RankedTensorType>();
|
||||
|
|
@ -544,13 +549,27 @@ class HloDynamicBroadcastInDimConverter
|
|||
}
|
||||
|
||||
int64_t nloops = result_type.getRank();
|
||||
auto operand_shape = operand_type.getShape();
|
||||
SmallVector<AffineExpr, 4> dim_exprs;
|
||||
dim_exprs.reserve(nloops);
|
||||
|
||||
if (op.broadcast_dimensions()) {
|
||||
for (const auto& broadcast_dim :
|
||||
enumerate(op.broadcast_dimensions().getIntValues())) {
|
||||
int64_t size = broadcast_dim.value().getSExtValue();
|
||||
bool expansion_needed = operand_shape[broadcast_dim.index()] == 1;
|
||||
dim_exprs.push_back(expansion_needed ? rewriter.getAffineConstantExpr(0)
|
||||
: rewriter.getAffineDimExpr(size));
|
||||
}
|
||||
}
|
||||
|
||||
Value init = rewriter.create<linalg::InitTensorOp>(
|
||||
loc, dyn_dims, result_type.getShape(), result_type.getElementType());
|
||||
Operation* generic = rewriter.create<linalg::GenericOp>(
|
||||
loc, TypeRange{init.getType()}, ValueRange{operand},
|
||||
/*outputBuffers=*/ValueRange{init},
|
||||
llvm::makeArrayRef(
|
||||
{AffineMap::get(/*dimCount=*/nloops, /*symbolCount=*/0, {},
|
||||
{AffineMap::get(/*dimCount=*/nloops, /*symbolCount=*/0, dim_exprs,
|
||||
rewriter.getContext()),
|
||||
rewriter.getMultiDimIdentityMap(nloops)}),
|
||||
GetNParallelLoopsAttrs(nloops),
|
||||
|
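The updated HloDynamicBroadcastInDimConverter relies on the operand having a static shape: a size-1 operand dimension always expands and gets the constant affine expression 0, while any other dimension maps straight to its target result dimension from broadcast_dimensions. The sketch below reproduces that dim_exprs decision for a concrete case; it is a plain C++ illustration, not the MLIR builder code.

// --- illustrative sketch (not part of this change) ---
#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

// Mirrors the dim_exprs loop above: operand dimension i maps either to the
// constant 0 (size 1, always expands) or to result dimension
// broadcast_dimensions[i] (never expands).
std::string OperandIndexingMap(const std::vector<int64_t>& operand_shape,
                               const std::vector<int64_t>& broadcast_dimensions) {
  std::string map = "(";
  for (size_t i = 0; i < operand_shape.size(); ++i) {
    bool expansion_needed = operand_shape[i] == 1;
    map += expansion_needed ? "0" : "d" + std::to_string(broadcast_dimensions[i]);
    if (i + 1 < operand_shape.size()) map += ", ";
  }
  return map + ")";
}

int main() {
  // Broadcasting a <1x5> operand into a rank-3 result along dims {1, 2}:
  // dim 0 has size 1 and expands, dim 1 follows result dim 2.
  std::cout << OperandIndexingMap({1, 5}, {1, 2}) << "\n";  // prints (0, d2)
}
// --- end sketch ---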
|
@ -590,8 +609,8 @@ class LhloBroadcastInDimConverter
|
|||
operand_type.getDimSize(0) <
|
||||
result_type.getDimSize(broadcast_dims.front())) {
|
||||
Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
|
||||
Value val =
|
||||
rewriter.create<LoadOp>(loc, operand, llvm::makeArrayRef({zero}));
|
||||
Value val = rewriter.create<memref::LoadOp>(loc, operand,
|
||||
llvm::makeArrayRef({zero}));
|
||||
rewriter.create<linalg::GenericOp>(
|
||||
loc, /*inputs=*/ValueRange{},
|
||||
/*outputBuffers=*/ValueRange{operand_adaptor.output()},
|
||||
|
|
@ -971,7 +990,8 @@ class ReduceConverter : public OpConversionPattern<lmhlo::ReduceOp> {
|
|||
}
|
||||
|
||||
// First fill the output buffer with the init value.
|
||||
Value init_value = rewriter.create<LoadOp>(loc, adaptor.init_values()[0]);
|
||||
Value init_value =
|
||||
rewriter.create<memref::LoadOp>(loc, adaptor.init_values()[0]);
|
||||
rewriter.create<linalg::FillOp>(loc, adaptor.out()[0], init_value);
|
||||
|
||||
DenseIntElementsAttr dimensions_attr = reduce_op.dimensions();
|
||||
|
|
@ -1011,9 +1031,9 @@ class ReduceConverter : public OpConversionPattern<lmhlo::ReduceOp> {
|
|||
// expects scalar SSA values. Add some allocs around the original op to
|
||||
// make it compatible.
|
||||
auto arg_type = block->getArgument(0).getType().cast<MemRefType>();
|
||||
Value alloc_a = rewriter.create<AllocaOp>(loc, arg_type);
|
||||
Value alloc_b = rewriter.create<AllocaOp>(loc, arg_type);
|
||||
Value alloc_res = rewriter.create<AllocaOp>(loc, arg_type);
|
||||
Value alloc_a = rewriter.create<memref::AllocaOp>(loc, arg_type);
|
||||
Value alloc_b = rewriter.create<memref::AllocaOp>(loc, arg_type);
|
||||
Value alloc_res = rewriter.create<memref::AllocaOp>(loc, arg_type);
|
||||
|
||||
// Now turn the existing signature
|
||||
// (memref<X>, memref<X>, memref<X>) -> ()
|
||||
|
|
@ -1030,13 +1050,15 @@ class ReduceConverter : public OpConversionPattern<lmhlo::ReduceOp> {
|
|||
|
||||
// Store the arguments into the newly allocated buffers.
|
||||
rewriter.setInsertionPointAfter(alloc_res.getDefiningOp());
|
||||
rewriter.create<StoreOp>(loc, entry_block->getArgument(0), alloc_a);
|
||||
rewriter.create<StoreOp>(loc, entry_block->getArgument(1), alloc_b);
|
||||
rewriter.create<memref::StoreOp>(loc, entry_block->getArgument(0),
|
||||
alloc_a);
|
||||
rewriter.create<memref::StoreOp>(loc, entry_block->getArgument(1),
|
||||
alloc_b);
|
||||
rewriter.replaceOp(entry_block->getTerminator(), {});
|
||||
|
||||
// Load & yield the result.
|
||||
rewriter.setInsertionPointToEnd(entry_block);
|
||||
auto load_res = rewriter.create<LoadOp>(loc, alloc_res);
|
||||
auto load_res = rewriter.create<memref::LoadOp>(loc, alloc_res);
|
||||
rewriter.create<linalg::YieldOp>(loc, ValueRange{load_res});
|
||||
}
|
||||
|
||||
|
|
@ -1099,8 +1121,8 @@ class SliceConverter : public OpConversionPattern<OpTy> {
|
|||
slice_op.strides().template getValue<int64_t>(i)));
|
||||
}
|
||||
if (isLHLO) {
|
||||
auto linalg_op =
|
||||
rewriter.create<SubViewOp>(loc, args[0], offsets, sizes, strides);
|
||||
auto linalg_op = rewriter.create<memref::SubViewOp>(loc, args[0], offsets,
|
||||
sizes, strides);
|
||||
rewriter.create<linalg::CopyOp>(loc, linalg_op, args[1]);
|
||||
rewriter.eraseOp(slice_op);
|
||||
} else {
|
||||
|
|
@ -1149,14 +1171,14 @@ SmallVector<Value, 2> GetDotOpInitTensorDynSizes(OpBuilder& b, Location loc,
|
|||
switch (type) {
|
||||
case DotOperationType::kMatrixMatrix: {
|
||||
if (lhs.getType().cast<ShapedType>().isDynamicDim(0))
|
||||
dyn_shape.push_back(b.create<DimOp>(loc, lhs, 0));
|
||||
dyn_shape.push_back(b.create<memref::DimOp>(loc, lhs, 0));
|
||||
if (rhs.getType().cast<ShapedType>().isDynamicDim(1))
|
||||
dyn_shape.push_back(b.create<DimOp>(loc, rhs, 1));
|
||||
dyn_shape.push_back(b.create<memref::DimOp>(loc, rhs, 1));
|
||||
break;
|
||||
}
|
||||
case DotOperationType::kMatrixVector: {
|
||||
if (lhs.getType().cast<ShapedType>().isDynamicDim(0))
|
||||
dyn_shape.push_back(b.create<DimOp>(loc, lhs, 0));
|
||||
dyn_shape.push_back(b.create<memref::DimOp>(loc, lhs, 0));
|
||||
break;
|
||||
}
|
||||
case DotOperationType::kVectorDot:
|
||||
|
|
@ -1203,11 +1225,11 @@ SmallVector<Value, 8> GetDotGeneralOpInitTensorDynSizes(
|
|||
OpBuilder& b, Location loc, Value lhs, Value rhs, ShapedType result_type) {
|
||||
SmallVector<Value, 8> dyn_shape;
|
||||
if (result_type.isDynamicDim(0))
|
||||
dyn_shape.push_back(b.create<DimOp>(loc, lhs, 0));
|
||||
dyn_shape.push_back(b.create<memref::DimOp>(loc, lhs, 0));
|
||||
if (result_type.isDynamicDim(1))
|
||||
dyn_shape.push_back(b.create<DimOp>(loc, lhs, 1));
|
||||
dyn_shape.push_back(b.create<memref::DimOp>(loc, lhs, 1));
|
||||
if (result_type.isDynamicDim(2))
|
||||
dyn_shape.push_back(b.create<DimOp>(loc, rhs, 2));
|
||||
dyn_shape.push_back(b.create<memref::DimOp>(loc, rhs, 2));
|
||||
return dyn_shape;
|
||||
}
|
||||
|
||||
|
|
@ -1307,7 +1329,7 @@ SmallVector<Value, 8> GetReduceOpInitTensorDynSizes(
|
|||
for (int i = 0, j = 0; i < rank; ++i) {
|
||||
if (s.count(i)) continue;
|
||||
if (!result_type.isDynamicDim(j++)) continue;
|
||||
dyn_shape.push_back(b.create<DimOp>(loc, arg, i));
|
||||
dyn_shape.push_back(b.create<memref::DimOp>(loc, arg, i));
|
||||
}
|
||||
|
||||
return dyn_shape;
|
||||
|
|
@ -1467,7 +1489,7 @@ struct NormalConvOpOnTensorsConversion
|
|||
// The output shape is N spatial_dims F.
|
||||
SmallVector<Value, 8> dyn_sizes;
|
||||
if (result_type.isDynamicDim(0)) {
|
||||
dyn_sizes.push_back(rewriter.create<DimOp>(loc, input, 0));
|
||||
dyn_sizes.push_back(rewriter.create<memref::DimOp>(loc, input, 0));
|
||||
}
|
||||
for (int64_t i = 1, e = rank - 1; i < e; ++i) {
|
||||
if (result_type.isDynamicDim(i)) {
|
||||
|
|
@ -1476,7 +1498,8 @@ struct NormalConvOpOnTensorsConversion
|
|||
}
|
||||
}
|
||||
if (result_type.isDynamicDim(rank - 1)) {
|
||||
dyn_sizes.push_back(rewriter.create<DimOp>(loc, filter, rank - 1));
|
||||
dyn_sizes.push_back(
|
||||
rewriter.create<memref::DimOp>(loc, filter, rank - 1));
|
||||
}
|
||||
Value init_tensor = rewriter.create<linalg::InitTensorOp>(
|
||||
loc, dyn_sizes, result_type.getShape(), result_type.getElementType());
|
||||
|
|
@ -1769,6 +1792,99 @@ struct ReduceWindowOpOnTensorsConversion
|
|||
}
|
||||
};
|
||||
|
||||
/// Converts xla-hlo.torch_index_select op to a linalg.indexed_generic op.
|
||||
struct TorchIndexSelectOpOnTensorsConversion
|
||||
: public OpConversionPattern<mhlo::TorchIndexSelectOp> {
|
||||
using OpConversionPattern<mhlo::TorchIndexSelectOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult matchAndRewrite(
|
||||
mhlo::TorchIndexSelectOp op, ArrayRef<Value> args,
|
||||
ConversionPatternRewriter& rewriter) const final {
|
||||
mhlo::TorchIndexSelectOp::Adaptor adaptor(args);
|
||||
int axis = static_cast<int>(op.dim());
|
||||
int batch = static_cast<int>(op.batch_dims());
|
||||
auto index_shaped_type = adaptor.index().getType().cast<ShapedType>();
|
||||
int num_indices = static_cast<int>(index_shaped_type.getRank());
|
||||
auto input_shaped_type = adaptor.input().getType().cast<ShapedType>();
|
||||
if (axis < 0) axis += static_cast<int>(input_shaped_type.getRank());
|
||||
if (batch < 0) batch += num_indices;
|
||||
|
||||
Location loc = op.getLoc();
|
||||
auto result_type = op.getResult().getType().cast<ShapedType>();
|
||||
int rank = static_cast<int>(result_type.getRank());
|
||||
|
||||
SmallVector<AffineMap, 2> indexing_maps;
|
||||
SmallVector<AffineExpr, 4> exprs;
|
||||
for (int i = 0; i < batch; ++i) {
|
||||
exprs.push_back(rewriter.getAffineDimExpr(i));
|
||||
}
|
||||
for (int i = 0, e = num_indices - batch; i < e; ++i) {
|
||||
exprs.push_back(rewriter.getAffineDimExpr(axis + i));
|
||||
}
|
||||
indexing_maps.emplace_back(
|
||||
AffineMap::get(rank, /*symbolCount=*/0, exprs, rewriter.getContext()));
|
||||
indexing_maps.emplace_back(rewriter.getMultiDimIdentityMap(rank));
|
||||
|
||||
// The output shape is
|
||||
// `params[:axis] + indices[batch_dims:] + params[axis + 1:]`
|
||||
SmallVector<Value, 4> dyn_sizes;
|
||||
for (int i = 0; i < rank; ++i) {
|
||||
if (!result_type.isDynamicDim(i)) continue;
|
||||
if (i < axis) {
|
||||
dyn_sizes.push_back(
|
||||
rewriter.create<memref::DimOp>(loc, adaptor.input(), i));
|
||||
} else if (i < (axis + num_indices - batch)) {
|
||||
int idx = i - axis + batch;
|
||||
dyn_sizes.push_back(
|
||||
rewriter.create<memref::DimOp>(loc, adaptor.index(), idx));
|
||||
} else {
|
||||
int idx = i - (axis + num_indices - batch) + axis + 1;
|
||||
dyn_sizes.push_back(
|
||||
rewriter.create<memref::DimOp>(loc, adaptor.input(), idx));
|
||||
}
|
||||
}
|
||||
Value init_op = rewriter.create<linalg::InitTensorOp>(
|
||||
loc, dyn_sizes, result_type.getShape(), result_type.getElementType());
|
||||
auto linalg_op = rewriter.create<linalg::IndexedGenericOp>(
|
||||
loc, /*resultTensors=*/ArrayRef<Type>{result_type},
|
||||
/*inputs=*/adaptor.index(),
|
||||
/*outputs=*/init_op, indexing_maps, GetNParallelLoopsAttrs(rank));
|
||||
|
||||
SmallVector<Type, 4> body_arg_types;
|
||||
SmallVector<Value, 2> linalg_op_args = {adaptor.index()};
|
||||
// Add a block to the region.
|
||||
auto* region = &linalg_op.region();
|
||||
auto* block = rewriter.createBlock(region, region->end());
|
||||
body_arg_types.append(rank, rewriter.getIndexType());
|
||||
for (auto block_args : linalg_op_args) {
|
||||
body_arg_types.push_back(
|
||||
block_args.getType().cast<ShapedType>().getElementType());
|
||||
}
|
||||
block->addArguments(body_arg_types);
|
||||
block->addArguments(result_type.getElementType());
|
||||
OpBuilder::InsertionGuard guard(rewriter);
|
||||
rewriter.setInsertionPointToEnd(block);
|
||||
|
||||
SmallVector<Value, 4> indices;
|
||||
Value casted_value = rewriter.create<IndexCastOp>(
|
||||
loc, block->getArgument(rank), rewriter.getIndexType());
|
||||
for (int i = 0; i < axis; ++i) {
|
||||
indices.push_back(block->getArgument(i));
|
||||
}
|
||||
indices.push_back(casted_value);
|
||||
for (int i = axis + num_indices - batch; i < rank; ++i) {
|
||||
indices.push_back(block->getArgument(i));
|
||||
}
|
||||
|
||||
Value res =
|
||||
rewriter.create<tensor::ExtractOp>(loc, adaptor.input(), indices);
|
||||
rewriter.create<linalg::YieldOp>(loc, res);
|
||||
|
||||
rewriter.replaceOp(op, linalg_op.getResults());
|
||||
return success();
|
||||
}
|
||||
};
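The converter above derives the result shape as params[:axis] + indices[batch_dims:] + params[axis + 1:] and fetches each dynamic extent from the corresponding operand. The sketch below computes just that shape for a concrete case (plain C++, illustrative only).

// --- illustrative sketch (not part of this change) ---
#include <cstdint>
#include <iostream>
#include <vector>

// result = params[:axis] + indices[batch_dims:] + params[axis + 1:]
std::vector<int64_t> TorchIndexSelectShape(const std::vector<int64_t>& params,
                                           const std::vector<int64_t>& indices,
                                           int axis, int batch_dims) {
  std::vector<int64_t> result(params.begin(), params.begin() + axis);
  result.insert(result.end(), indices.begin() + batch_dims, indices.end());
  result.insert(result.end(), params.begin() + axis + 1, params.end());
  return result;
}

int main() {
  // params: 4x8x16, indices: 4x3, axis = 1, batch_dims = 1
  // -> 4 (params[0]) x 3 (indices[1]) x 16 (params[2])
  for (int64_t d : TorchIndexSelectShape({4, 8, 16}, {4, 3}, /*axis=*/1,
                                         /*batch_dims=*/1))
    std::cout << d << " ";
  std::cout << "\n";  // prints: 4 3 16
}
// --- end sketch ---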
|
||||
|
||||
void populateLHLOToLinalgConversionPattern(MLIRContext* context,
|
||||
OwningRewritePatternList* patterns) {
|
||||
// clang-format off
|
||||
|
|
@ -1856,8 +1972,8 @@ struct LhloLegalizeToLinalgPass
|
|||
OwningRewritePatternList patterns;
|
||||
ConversionTarget target(getContext());
|
||||
target.addLegalDialect<complex::ComplexDialect, linalg::LinalgDialect,
|
||||
math::MathDialect, StandardOpsDialect,
|
||||
AffineDialect>();
|
||||
math::MathDialect, memref::MemRefDialect,
|
||||
StandardOpsDialect, AffineDialect>();
|
||||
|
||||
auto func = getFunction();
|
||||
populateLHLOToLinalgConversionPattern(func.getContext(), &patterns);
|
||||
|
|
@ -1881,6 +1997,9 @@ struct HloLegalizeToLinalgPass
|
|||
math::MathDialect, StandardOpsDialect,
|
||||
tensor::TensorDialect, scf::SCFDialect>();
|
||||
|
||||
// TODO: DimOp shouldn't be in MemRefDialect
|
||||
target.addLegalOp<memref::DimOp>();
|
||||
|
||||
auto func = getFunction();
|
||||
mhlo::populateHLOToLinalgConversionPattern(func.getContext(), &patterns);
|
||||
if (failed(applyPartialConversion(func, target, std::move(patterns)))) {
|
||||
|
|
@ -1961,6 +2080,7 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context,
|
|||
DepthwiseConvOpOnTensorsConversion,
|
||||
ReduceOnTensorsConversion,
|
||||
ReduceWindowOpOnTensorsConversion,
|
||||
TorchIndexSelectOpOnTensorsConversion,
|
||||
PadOpOnTensorsConversion>(context);
|
||||
// clang-format on
|
||||
patterns->insert<ReduceRegionXLAOpConversion<mhlo::AddOp>,
|
||||
|
|
|
|||
|
|
@@ -74,6 +74,9 @@ class ApproximateOnExtendedF32Lowering : public OpRewritePattern<OpTy> {
}
};

// This approximation resembles Eigen and realizes a constant approximation for
// the +/-1 limits on top.
// https://gitlab.com/libeigen/eigen/-/blob/master/Eigen/src/Core/MathFunctionsImpl.h
class ApproximateTanhLowering
: public ApproximateOnExtendedF32Lowering<math::TanhOp> {
public:
@@ -83,42 +86,18 @@ class ApproximateTanhLowering
// Emits the fast tanh approximation that is also used by XLA.
Value emitApproximation(ValueRange args, Location loc,
PatternRewriter &rewriter) const override {
// For small values of x, we can approximate tanh(x) = x. For extremely
// small values of x (|x| < 1e-37), the other approximation would evaluate
// tanh(x) = 0.
Value input = args.front();
assert(input.getType().isF32());
constexpr float kCanUseApprox = 0.0004;
Value abs_value = rewriter.create<AbsFOp>(loc, input);
Value can_use_approx = rewriter.create<ConstantOp>(
loc, rewriter.getF32FloatAttr(kCanUseApprox));
Value return_input = rewriter.create<CmpFOp>(loc, CmpFPredicate::OLT,
abs_value, can_use_approx);
// Clamp the input to [-c, c].
Value max_clamp = rewriter.create<ConstantOp>(
loc, rewriter.getF32FloatAttr(7.90531110763549805f));
Value smaller_than_max =
rewriter.create<CmpFOp>(loc, CmpFPredicate::ULE, input, max_clamp);
Value clamped_half =
rewriter.create<SelectOp>(loc, smaller_than_max, input, max_clamp);
Value min_clamp = rewriter.create<ConstantOp>(
loc, rewriter.getF32FloatAttr(-7.90531110763549805f));
Value larger_than_min = rewriter.create<CmpFOp>(loc, CmpFPredicate::UGE,
clamped_half, min_clamp);
Value input_clamped = rewriter.create<SelectOp>(loc, larger_than_min,
clamped_half, min_clamp);

static constexpr std::array<float, 7> numerator_coeffs{
-2.76076847742355e-16f, 2.00018790482477e-13f, -8.60467152213735e-11f,
5.12229709037114e-08f, 1.48572235717979e-05f, 6.37261928875436e-04f,
4.89352455891786e-03f};

static constexpr std::array<float, 4> denominator_coeffs{
1.19825839466702e-06f, 1.18534705686654e-04f, 2.26843463243900e-03f,
4.89352518554385e-03f};

Value input_squared =
rewriter.create<MulFOp>(loc, input_clamped, input_clamped);
// Materialize polynomial approximation.
Value input_squared = rewriter.create<MulFOp>(loc, input, input);
Value numerator = rewriter.create<ConstantOp>(
loc, rewriter.getF32FloatAttr(numerator_coeffs[0]));
for (int i = 1; i < numerator_coeffs.size(); i++) {
@@ -127,9 +106,7 @@ class ApproximateTanhLowering
rewriter.create<ConstantOp>(
loc, rewriter.getF32FloatAttr(numerator_coeffs[i])));
}

numerator = rewriter.create<MulFOp>(loc, input_clamped, numerator);

numerator = rewriter.create<MulFOp>(loc, input, numerator);
Value denominator = rewriter.create<ConstantOp>(
loc, rewriter.getF32FloatAttr(denominator_coeffs[0]));
for (int i = 1; i < denominator_coeffs.size(); i++) {
@@ -138,10 +115,38 @@ class ApproximateTanhLowering
rewriter.create<ConstantOp>(
loc, rewriter.getF32FloatAttr(denominator_coeffs[i])));
}

Value approx = rewriter.create<DivFOp>(loc, numerator, denominator);

return rewriter.create<SelectOp>(loc, return_input, input, approx);
// For small values of |x|, we can approximate tanh(x) = x. For extremely
// small values of x (|x| < 1e-37), the other approximation would evaluate
// tanh(x) = 0.
constexpr float kUseIdentityApprox = 0.0004;
Value abs_input = rewriter.create<AbsFOp>(loc, input);
Value use_identity_approx = rewriter.create<CmpFOp>(
loc, CmpFPredicate::OLT, abs_input,
rewriter.create<ConstantOp>(
loc, rewriter.getF32FloatAttr(kUseIdentityApprox)));
approx = rewriter.create<SelectOp>(loc, use_identity_approx, input, approx);

// For very small/large values, use a constant approximation -1/1.
Value too_large_input = rewriter.create<CmpFOp>(
loc, CmpFPredicate::UGT, input,
rewriter.create<ConstantOp>(
loc, rewriter.getF32FloatAttr(7.90531110763549805f)));
Value too_small_input = rewriter.create<CmpFOp>(
loc, CmpFPredicate::ULT, input,
rewriter.create<ConstantOp>(
loc, rewriter.getF32FloatAttr(-7.90531110763549805f)));
approx = rewriter.create<SelectOp>(
loc, too_large_input,
rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(1.0)),
approx);
approx = rewriter.create<SelectOp>(
loc, too_small_input,
rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(-1.0)),
approx);

return approx;
}
};

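Assembled from the pieces above, the rewritten lowering evaluates a rational polynomial in x (numerator and denominator accumulated coefficient by coefficient in x^2, in the usual Horner style of XLA's fast tanh), substitutes the identity for very small |x|, and saturates to +/-1 beyond the old clamp bound instead of clamping the input first. The scalar sketch below uses the same coefficients to show the formula end to end; the Horner accumulation is an assumption about the elided loop bodies, and this is an illustration of the math rather than the MLIR pattern itself.

// --- illustrative sketch (not part of this change) ---
#include <array>
#include <cmath>
#include <cstdio>

// Scalar version of the fast tanh approximation emitted above.
float ApproximateTanh(float x) {
  // Outside the saturation range the true tanh is effectively +/-1.
  constexpr float kClamp = 7.90531110763549805f;
  if (x > kClamp) return 1.0f;
  if (x < -kClamp) return -1.0f;
  // For very small |x|, tanh(x) ~= x and the rational form would underflow.
  constexpr float kUseIdentityApprox = 0.0004f;
  if (std::fabs(x) < kUseIdentityApprox) return x;

  static constexpr std::array<float, 7> numerator_coeffs{
      -2.76076847742355e-16f, 2.00018790482477e-13f, -8.60467152213735e-11f,
      5.12229709037114e-08f,  1.48572235717979e-05f, 6.37261928875436e-04f,
      4.89352455891786e-03f};
  static constexpr std::array<float, 4> denominator_coeffs{
      1.19825839466702e-06f, 1.18534705686654e-04f, 2.26843463243900e-03f,
      4.89352518554385e-03f};

  // Horner evaluation of both polynomials in x^2 (assumed form of the loops).
  float x2 = x * x;
  float numerator = numerator_coeffs[0];
  for (size_t i = 1; i < numerator_coeffs.size(); ++i)
    numerator = numerator * x2 + numerator_coeffs[i];
  numerator *= x;  // odd polynomial: final multiply by x
  float denominator = denominator_coeffs[0];
  for (size_t i = 1; i < denominator_coeffs.size(); ++i)
    denominator = denominator * x2 + denominator_coeffs[i];
  return numerator / denominator;
}

int main() {
  for (float x : {-3.0f, -0.5f, 0.0001f, 0.5f, 3.0f, 10.0f})
    std::printf("x=%8.4f  approx=%9.6f  std=%9.6f\n", x, ApproximateTanh(x),
                std::tanh(x));
}
// --- end sketch ---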
@ -22,6 +22,7 @@ limitations under the License.
|
|||
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
||||
#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
|
||||
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/SCF/SCF.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
|
|
@ -95,7 +96,7 @@ class LhloFuseLinalgPass
|
|||
continue;
|
||||
}
|
||||
|
||||
if (auto tensor_load = dyn_cast<TensorLoadOp>(definingOp)) {
|
||||
if (auto tensor_load = dyn_cast<memref::TensorLoadOp>(definingOp)) {
|
||||
auto alias = tensor_load.memref();
|
||||
if (result_buffers.insert(alias).second) {
|
||||
worklist.push_back(alias);
|
||||
|
|
@ -103,7 +104,7 @@ class LhloFuseLinalgPass
|
|||
continue;
|
||||
}
|
||||
|
||||
if (auto tensor_to_memref = dyn_cast<TensorToMemrefOp>(definingOp)) {
|
||||
if (auto tensor_to_memref = dyn_cast<memref::BufferCastOp>(definingOp)) {
|
||||
auto alias = tensor_to_memref.tensor();
|
||||
if (result_buffers.insert(alias).second) {
|
||||
worklist.push_back(alias);
|
||||
|
|
|
|||
|
|
@ -96,9 +96,10 @@ class LhloReduceToGPULaunchConverter : public OpConversionPattern<ReduceOp> {
|
|||
|
||||
// Load the initial value and store it to the output.
|
||||
for (auto pair : llvm::zip(reduce_op.init_values(), reduce_op.out())) {
|
||||
auto init_value = rewriter.create<mlir::LoadOp>(loc, std::get<0>(pair));
|
||||
rewriter.create<mlir::StoreOp>(loc, init_value, std::get<1>(pair),
|
||||
ArrayRef<Value>{index});
|
||||
auto init_value =
|
||||
rewriter.create<mlir::memref::LoadOp>(loc, std::get<0>(pair));
|
||||
rewriter.create<mlir::memref::StoreOp>(
|
||||
loc, init_value, std::get<1>(pair), ArrayRef<Value>{index});
|
||||
}
|
||||
|
||||
// Insert a loop into the body to compute the reduction. The loop ranges
|
||||
|
|
@ -128,8 +129,8 @@ class LhloReduceToGPULaunchConverter : public OpConversionPattern<ReduceOp> {
|
|||
auto oneAttr = rewriter.getI64IntegerAttr(1);
|
||||
OpFoldResult size = oneAttr;
|
||||
OpFoldResult stride = oneAttr;
|
||||
auto accumulator = rewriter.create<SubViewOp>(loc, resType, output,
|
||||
offset, size, stride);
|
||||
auto accumulator = rewriter.create<memref::SubViewOp>(
|
||||
loc, resType, output, offset, size, stride);
|
||||
llvm::SmallVector<Value, 4> indexings;
|
||||
auto input_buffer = *reduce_op.operands().begin();
|
||||
auto input_type_rank =
|
||||
|
|
@ -143,8 +144,8 @@ class LhloReduceToGPULaunchConverter : public OpConversionPattern<ReduceOp> {
|
|||
}));
|
||||
SmallVector<OpFoldResult> sizes(input_type_rank, oneAttr);
|
||||
SmallVector<OpFoldResult> strides(input_type_rank, oneAttr);
|
||||
auto rhs = rewriter.create<SubViewOp>(loc, accumulator.getType(), input,
|
||||
offsets, sizes, strides);
|
||||
auto rhs = rewriter.create<memref::SubViewOp>(
|
||||
loc, accumulator.getType(), input, offsets, sizes, strides);
|
||||
|
||||
// Now copy over the actual body of the reduction, leaving out the
|
||||
// terminator.
|
||||
|
|
@ -179,8 +180,9 @@ struct LhloLegalizeToGpuPass
|
|||
void runOnFunction() override {
|
||||
OwningRewritePatternList patterns;
|
||||
ConversionTarget target(getContext());
|
||||
target.addLegalDialect<linalg::LinalgDialect, StandardOpsDialect,
|
||||
gpu::GPUDialect, scf::SCFDialect, LmhloDialect>();
|
||||
target.addLegalDialect<linalg::LinalgDialect, memref::MemRefDialect,
|
||||
StandardOpsDialect, gpu::GPUDialect, scf::SCFDialect,
|
||||
LmhloDialect>();
|
||||
target.addIllegalOp<ReduceOp>();
|
||||
auto func = getFunction();
|
||||
patterns.insert<LhloReduceToGPULaunchConverter>(func.getContext());
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ limitations under the License.
|
|||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
|
||||
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/SCF/SCF.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
|
|
@ -43,10 +44,11 @@ Value ApplySingleResultLhloCode(Location loc, ValueRange operands,
|
|||
Block* lhlo_block, OpBuilder* b) {
|
||||
SmallVector<Value, 2> arg_bufs;
|
||||
for (auto arg_type : lhlo_block->getArgumentTypes()) {
|
||||
arg_bufs.push_back(b->create<AllocOp>(loc, arg_type.cast<MemRefType>()));
|
||||
arg_bufs.push_back(
|
||||
b->create<memref::AllocOp>(loc, arg_type.cast<MemRefType>()));
|
||||
}
|
||||
for (auto operand : llvm::enumerate(operands)) {
|
||||
b->create<StoreOp>(loc, operand.value(), arg_bufs[operand.index()]);
|
||||
b->create<memref::StoreOp>(loc, operand.value(), arg_bufs[operand.index()]);
|
||||
}
|
||||
// Clone the ops from `lhlo_block`.
|
||||
BlockAndValueMapping mapping;
|
||||
|
|
@ -55,7 +57,7 @@ Value ApplySingleResultLhloCode(Location loc, ValueRange operands,
|
|||
auto clone = b->clone(nested, mapping);
|
||||
mapping.map(nested.getResults(), clone->getResults());
|
||||
}
|
||||
return b->create<LoadOp>(loc, arg_bufs.back());
|
||||
return b->create<memref::LoadOp>(loc, arg_bufs.back());
|
||||
}
|
||||
|
||||
// Converts a block with LHLO ops and with signature:
|
||||
|
|
@ -78,7 +80,8 @@ void ConvertToReductionOperator(Location loc, scf::ReduceOp reduce_op,
|
|||
Value GetStaticOrDynamicDim(mlir::Location loc, Value shaped_value,
|
||||
size_t dim_index, int64_t dim, OpBuilder* b) {
|
||||
return dim == ShapedType::kDynamicSize
|
||||
? b->create<DimOp>(loc, shaped_value, dim_index).getResult()
|
||||
? b->create<memref::DimOp>(loc, shaped_value, dim_index)
|
||||
.getResult()
|
||||
: b->create<ConstantIndexOp>(loc, dim);
|
||||
}
|
||||
|
||||
|
|
@ -249,8 +252,8 @@ class ReduceOpConverter : public OpConversionPattern<lmhlo::ReduceOp> {
|
|||
(is_reducing_dim ? reduce_step : parallel_step).push_back(step);
|
||||
}
|
||||
// Load initial value from memref<element_type>.
|
||||
SmallVector<Value, 1> init_value = {
|
||||
rewriter->create<LoadOp>(loc, *reduce_op.init_values().begin())};
|
||||
SmallVector<Value, 1> init_value = {rewriter->create<memref::LoadOp>(
|
||||
loc, *reduce_op.init_values().begin())};
|
||||
// Outer ParallelOp is not needed if it is a reduction across all dims.
|
||||
scf::ParallelOp outer;
|
||||
if (!parallel_lower.empty()) {
|
||||
|
|
@ -272,7 +275,7 @@ class ReduceOpConverter : public OpConversionPattern<lmhlo::ReduceOp> {
|
|||
out_indices.push_back(rewriter->create<ConstantIndexOp>(loc, 0));
|
||||
}
|
||||
|
||||
rewriter->create<StoreOp>(loc, reduction_result, out, out_indices);
|
||||
rewriter->create<memref::StoreOp>(loc, reduction_result, out, out_indices);
|
||||
|
||||
// Load the element to reduce.
|
||||
SmallVector<Value, 2> indices;
|
||||
|
|
@ -290,7 +293,7 @@ class ReduceOpConverter : public OpConversionPattern<lmhlo::ReduceOp> {
|
|||
}
|
||||
|
||||
rewriter->setInsertionPointToStart(inner.getBody());
|
||||
Value elem = rewriter->create<mlir::LoadOp>(
|
||||
Value elem = rewriter->create<mlir::memref::LoadOp>(
|
||||
loc, *reduce_op.operands().begin(), indices);
|
||||
return rewriter->create<scf::ReduceOp>(loc, elem);
|
||||
}
|
||||
|
|
@ -385,7 +388,7 @@ class ReduceWindowOpConverter
|
|||
ConversionPatternRewriter* rewriter) const {
|
||||
auto loc = reduce_window_op.getLoc();
|
||||
Value init_value =
|
||||
rewriter->create<LoadOp>(loc, reduce_window_op.init_value());
|
||||
rewriter->create<memref::LoadOp>(loc, reduce_window_op.init_value());
|
||||
|
||||
Value zero = rewriter->create<ConstantIndexOp>(loc, 0);
|
||||
Value one = rewriter->create<ConstantIndexOp>(loc, 1);
|
||||
|
|
@ -408,7 +411,8 @@ class ReduceWindowOpConverter
|
|||
|
||||
Value reduction_result = *window_loop.getResults().begin();
|
||||
auto output_ivs = output_loop.getInductionVars();
|
||||
rewriter->create<StoreOp>(loc, reduction_result, output, output_ivs);
|
||||
rewriter->create<memref::StoreOp>(loc, reduction_result, output,
|
||||
output_ivs);
|
||||
return std::make_pair(output_loop, window_loop);
|
||||
}
|
||||
|
||||
|
|
@ -439,7 +443,7 @@ class ReduceWindowOpConverter
|
|||
|
||||
OpBuilder then_builder =
|
||||
elem_or_init.getThenBodyBuilder(rewriter->getListener());
|
||||
Value elem = then_builder.create<mlir::LoadOp>(
|
||||
Value elem = then_builder.create<mlir::memref::LoadOp>(
|
||||
loc, reduce_window_op.operand(), mapped_ivs.ivs);
|
||||
then_builder.create<scf::YieldOp>(loc, elem);
|
||||
|
||||
|
|
@ -497,8 +501,8 @@ class SelectAndScatterOpConverter
|
|||
auto selected_ivs = SelectIvs(s_and_s_op, loop_over_src, &rewriter);
|
||||
|
||||
// Load `source[selected_ivs]`.
|
||||
auto src_elem = rewriter.create<LoadOp>(loc, s_and_s_op.source(),
|
||||
loop_over_src.getInductionVars());
|
||||
auto src_elem = rewriter.create<memref::LoadOp>(
|
||||
loc, s_and_s_op.source(), loop_over_src.getInductionVars());
|
||||
|
||||
// Compute `out[selected_ivs]` = scatter(out[selected_ivs], src_element)`.
|
||||
auto rmw = rewriter.create<GenericAtomicRMWOp>(loc, s_and_s_op.out(),
|
||||
|
|
@ -517,14 +521,14 @@ class SelectAndScatterOpConverter
|
|||
void InitializeOutput(lmhlo::SelectAndScatterOp s_and_s_op,
|
||||
OpBuilder* b) const {
|
||||
auto loc = s_and_s_op.getLoc();
|
||||
Value init_value = b->create<LoadOp>(loc, s_and_s_op.init_value());
|
||||
Value init_value = b->create<memref::LoadOp>(loc, s_and_s_op.init_value());
|
||||
|
||||
scf::ParallelOp loop_over_output =
|
||||
MakeLoopOverShape(loc, s_and_s_op.out(), b);
|
||||
OpBuilder::InsertionGuard guard(*b);
|
||||
b->setInsertionPointToStart(loop_over_output.getBody());
|
||||
b->create<StoreOp>(loc, init_value, s_and_s_op.out(),
|
||||
loop_over_output.getInductionVars());
|
||||
b->create<memref::StoreOp>(loc, init_value, s_and_s_op.out(),
|
||||
loop_over_output.getInductionVars());
|
||||
}
|
||||
|
||||
struct WindowLoops {
|
||||
|
|
@ -647,7 +651,7 @@ class SelectAndScatterOpConverter
|
|||
|
||||
TypeRange iter_arg_types{ivs_val_flag->to_vector()};
|
||||
Value operand_elem =
|
||||
b->create<LoadOp>(loc, s_and_s_op.operand(), operand_ivs);
|
||||
b->create<memref::LoadOp>(loc, s_and_s_op.operand(), operand_ivs);
|
||||
auto if_init =
|
||||
b->create<scf::IfOp>(loc, iter_arg_types, ivs_val_flag->is_init(),
|
||||
/*withElseRegion=*/true);
|
||||
|
|
@ -712,8 +716,8 @@ struct LhloLegalizeToParallelLoopsPass
|
|||
// clang-format on
|
||||
|
||||
ConversionTarget target(getContext());
|
||||
target.addLegalDialect<linalg::LinalgDialect, StandardOpsDialect,
|
||||
scf::SCFDialect, LmhloDialect>();
|
||||
target.addLegalDialect<linalg::LinalgDialect, memref::MemRefDialect,
|
||||
StandardOpsDialect, scf::SCFDialect, LmhloDialect>();
|
||||
target.addIllegalOp<lmhlo::ReduceOp, lmhlo::ReduceWindowOp,
|
||||
lmhlo::SelectAndScatterOp>();
|
||||
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ limitations under the License.
|
|||
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/Attributes.h"
|
||||
|
|
@ -58,7 +59,8 @@ Value CalculateShapeValue(Location loc, Value operand,
|
|||
int64_t rank = result_type.getRank();
|
||||
shape_values.reserve(rank);
|
||||
for (int64_t i = 0; i < rank; ++i) {
|
||||
shape_values.push_back(rewriter.create<mlir::DimOp>(loc, operand, i));
|
||||
shape_values.push_back(
|
||||
rewriter.create<mlir::memref::DimOp>(loc, operand, i));
|
||||
}
|
||||
return rewriter.create<tensor::FromElementsOp>(loc, shape_values);
|
||||
}
|
||||
|
|
|
|||
|
|
@@ -967,10 +967,10 @@ func @unpack_repack_same_tuple_single_element(%arg0: tuple<tensor<i32>>) -> tupl
// CHECK-LABEL: func @erase_dead_lhlo_constant
func @erase_dead_lhlo_constant() {
%M = alloc() : memref<256x1024xf32>
%M = memref.alloc() : memref<256x1024xf32>
// CHECK-NEXT: return
"lmhlo.constant"(%M) {value = dense<0.0> : tensor<f32>} : (memref<256x1024xf32>) -> ()
dealloc %M : memref<256x1024xf32>
memref.dealloc %M : memref<256x1024xf32>
return
}
@@ -979,9 +979,9 @@ func @erase_dead_lhlo_constant() {
func @erase_dead_lhlo_constant_negative(%M : memref<4xf32>) -> memref<256x1024xf32> {
// CHECK-NEXT: lmhlo.constant
"lmhlo.constant"(%M) {value = dense<0.0> : tensor<f32>} : (memref<4xf32>) -> ()
// CHECK-NEXT: alloc
// CHECK-NEXT: memref.alloc
// CHECK-NEXT: lmhlo.constant
%N = alloc() : memref<256x1024xf32>
%N = memref.alloc() : memref<256x1024xf32>
"lmhlo.constant"(%N) {value = dense<0.0> : tensor<f32>} : (memref<256x1024xf32>) -> ()
return %N : memref<256x1024xf32>
}
@@ -17,7 +17,7 @@ func @dynamic_reshape_from_unranked(
return %reshaped : tensor<?xf32>
}
// CHECK-SAME: ([[ARG:%.*]]: memref<*xf32>, [[SHAPE:%.*]]: memref<1xi32>)
// CHECK-NEXT: memref_reshape [[ARG]]([[SHAPE]])
// CHECK-NEXT: memref.reshape [[ARG]]([[SHAPE]])
// CHECK-SAME: : (memref<*xf32>, memref<1xi32>) -> memref<?xf32>
// -----
@@ -30,7 +30,7 @@ func @dynamic_reshape_to_unranked(
return %reshaped : tensor<*xf32>
}
// CHECK-SAME: ([[ARG:%.*]]: memref<?xf32>, [[SHAPE:%.*]]: memref<?xi32>)
// CHECK-NEXT: memref_reshape [[ARG]]([[SHAPE]])
// CHECK-NEXT: memref.reshape [[ARG]]([[SHAPE]])
// CHECK-SAME: : (memref<?xf32>, memref<?xi32>) -> memref<*xf32>
// -----
@@ -41,4 +41,4 @@ func @reshape_unranked(%operand: tensor<*xf32>) -> tensor<f32> {
return %reshaped : tensor<f32>
}
// CHECK-SAME: ([[ARG:%.*]]: memref<*xf32>)
// CHECK-NEXT: memref_cast [[ARG]] : memref<*xf32> to memref<f32>
// CHECK-NEXT: memref.cast [[ARG]] : memref<*xf32> to memref<f32>
@@ -31,20 +31,20 @@ func @func_op_long(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32>
return %5 : tensor<4xf32>
}
// CHECK: (%[[NEW_ARG0:.*]]: memref<4xf32>, %[[NEW_ARG1:.*]]: memref<4xf32>) -> memref<4xf32>
// CHECK-NEXT: %[[MAX_RESULT:.*]] = alloc() : memref<4xf32>
// CHECK-NEXT: %[[MAX_RESULT:.*]] = memref.alloc() : memref<4xf32>
// CHECK-NEXT: "lmhlo.maximum"(%[[NEW_ARG0]], %[[NEW_ARG1]], %[[MAX_RESULT]])
// CHECK-NEXT: %[[ADD_RESULT:.*]] = alloc() : memref<4xf32>
// CHECK-NEXT: %[[ADD_RESULT:.*]] = memref.alloc() : memref<4xf32>
// CHECK-NEXT: "lmhlo.add"(%[[NEW_ARG0]], %[[MAX_RESULT]], %[[ADD_RESULT]])
// CHECK-NEXT: dealloc %[[MAX_RESULT]] : memref<4xf32>
// CHECK-NEXT: %[[MIN_RESULT:.*]] = alloc() : memref<4xf32>
// CHECK-NEXT: memref.dealloc %[[MAX_RESULT]] : memref<4xf32>
// CHECK-NEXT: %[[MIN_RESULT:.*]] = memref.alloc() : memref<4xf32>
// CHECK-NEXT: "lmhlo.minimum"(%[[NEW_ARG0]], %[[NEW_ARG1]], %[[MIN_RESULT]])
// CHECK-NEXT: %[[SUB_RESULT:.*]] = alloc() : memref<4xf32>
// CHECK-NEXT: %[[SUB_RESULT:.*]] = memref.alloc() : memref<4xf32>
// CHECK-NEXT: "lmhlo.subtract"(%[[NEW_ARG1]], %[[MIN_RESULT]], %[[SUB_RESULT]])
// CHECK-NEXT: dealloc %[[MIN_RESULT]] : memref<4xf32>
// CHECK-NEXT: %[[MUL_RESULT:.*]] = alloc() : memref<4xf32>
// CHECK-NEXT: memref.dealloc %[[MIN_RESULT]] : memref<4xf32>
// CHECK-NEXT: %[[MUL_RESULT:.*]] = memref.alloc() : memref<4xf32>
// CHECK-NEXT: "lmhlo.multiply"(%[[ADD_RESULT]], %[[SUB_RESULT]], %[[MUL_RESULT]])
// CHECK-NEXT: dealloc %[[SUB_RESULT]] : memref<4xf32>
// CHECK-NEXT: dealloc %[[ADD_RESULT]] : memref<4xf32>
// CHECK-NEXT: memref.dealloc %[[SUB_RESULT]] : memref<4xf32>
// CHECK-NEXT: memref.dealloc %[[ADD_RESULT]] : memref<4xf32>
// CHECK-NEXT: return %[[MUL_RESULT]] : memref<4xf32>
// -----
@@ -53,15 +53,15 @@ func @func_op_long(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32>
func @fusion(%multiplier: tensor<2x2xf32>, %summand_1: tensor<2x2xf32>,
%summand_2: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: (%{{.*}}: {{.*}}, {{.*}}: {{.*}}, {{.*}}: {{.*}})
// CHECK-NEXT: %[[ADD_RESULT:.*]] = alloc() : memref<2x2xf32>
// CHECK-NEXT: %[[ADD_RESULT:.*]] = memref.alloc() : memref<2x2xf32>
%sum = "mhlo.add"(%summand_1, %summand_2)
: (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
// CHECK-NEXT: "lmhlo.add"(%{{.*}}, %{{.*}}, %[[ADD_RESULT]])
// CHECK-NEXT: %[[MUL_RESULT:.*]] = alloc() : memref<2x2xf32>
// CHECK-NEXT: %[[MUL_RESULT:.*]] = memref.alloc() : memref<2x2xf32>
%result = "mhlo.multiply"(%sum, %multiplier)
: (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
// CHECK-NEXT: "lmhlo.multiply"(%[[ADD_RESULT]], %{{.*}}, %[[MUL_RESULT]])
// CHECK-NEXT: dealloc %[[ADD_RESULT]] : memref<2x2xf32>
// CHECK-NEXT: memref.dealloc %[[ADD_RESULT]] : memref<2x2xf32>
// CHECK-NEXT: return %[[MUL_RESULT]] : memref<2x2xf32>
return %result : tensor<2x2xf32>
}
@@ -154,9 +154,9 @@ func @dyn_broadcast(%operand: tensor<?x?xf32>) -> tensor<?x?x?xf32> {
// CHECK: %[[C0:.*]] = constant 0 : index
// CHECK: %[[C1:.*]] = constant 1 : index
// CHECK: %[[OPER_DIM_1:.*]] = dim %[[OPERAND]], %[[C1]] : memref<?x?xf32>
// CHECK: %[[OPER_DIM_1:.*]] = memref.dim %[[OPERAND]], %[[C1]] : memref<?x?xf32>
// CHECK: %[[OP_STRIDE_0:.*]] = muli %[[C1]], %[[OPER_DIM_1]] : index
// CHECK: %[[OPER_DIM_0:.*]] = dim %[[OPERAND]], %[[C0]] : memref<?x?xf32>
// CHECK: %[[OPER_DIM_0:.*]] = memref.dim %[[OPERAND]], %[[C0]] : memref<?x?xf32>
// CHECK: %[[EL0:.*]] = tensor.extract %[[SHAPE]]{{\[}}%[[C0]]] : tensor<3xi64>
// CHECK: %[[SIZE_0:.*]] = index_cast %[[EL0]] : i64 to index
@@ -172,9 +172,9 @@ func @dyn_broadcast(%operand: tensor<?x?xf32>) -> tensor<?x?x?xf32> {
// CHECK: %[[EXPAND_2:.*]] = cmpi slt, %[[OPER_DIM_1]], %[[SIZE_2]] : index
// CHECK: %[[STRIDE_2:.*]] = select %[[EXPAND_2]], %[[C0]], %[[C1]] : index
// CHECK: %[[TRANSFORMED_MEMREF:.*]] = memref_reinterpret_cast %[[OPERAND]] to offset: [0], sizes: {{\[}}%[[SIZE_0]], %[[SIZE_1]], %[[SIZE_2]]], strides: {{\[}}%[[C0]], %[[STRIDE_1]], %[[STRIDE_2]]] : memref<?x?xf32> to memref<?x?x?xf32, #map>
// CHECK: %[[TRANSFORMED_MEMREF:.*]] = memref.reinterpret_cast %[[OPERAND]] to offset: [0], sizes: {{\[}}%[[SIZE_0]], %[[SIZE_1]], %[[SIZE_2]]], strides: {{\[}}%[[C0]], %[[STRIDE_1]], %[[STRIDE_2]]] : memref<?x?xf32> to memref<?x?x?xf32, #map>
// CHECK: %[[RESULT:.*]] = alloc(%[[SIZE_0]], %[[SIZE_1]], %[[SIZE_2]]) : memref<?x?x?xf32>
// CHECK: %[[RESULT:.*]] = memref.alloc(%[[SIZE_0]], %[[SIZE_1]], %[[SIZE_2]]) : memref<?x?x?xf32>
// CHECK: "lmhlo.copy"(%[[TRANSFORMED_MEMREF]], %[[RESULT]]) : (memref<?x?x?xf32, #map>, memref<?x?x?xf32>) -> ()
// CHECK: return %[[RESULT]] : memref<?x?x?xf32>
@@ -469,7 +469,7 @@ func @add_dyn(%lhs: tensor<?x?xf32>, %rhs: tensor<?x?xf32>) -> tensor<?x?xf32> {
// CHECK: %[[SHAPE:.*]] = shape.shape_of %arg0 : memref<?x?xf32> -> tensor<2xindex>
// CHECK: %[[EE0:.*]] = tensor.extract %[[SHAPE]][%[[C0]]] : tensor<2xindex>
// CHECK: %[[EE1:.*]] = tensor.extract %[[SHAPE]][%[[C1]]] : tensor<2xindex>
// CHECK: %[[RESULT:.*]] = alloc(%[[EE0]], %[[EE1]])
// CHECK: %[[RESULT:.*]] = memref.alloc(%[[EE0]], %[[EE1]])
// CHECK: "lmhlo.add"(%arg0, %arg1, %[[RESULT]]) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
return %result : tensor<?x?xf32>
// CHECK: return %[[RESULT]]
@@ -485,7 +485,7 @@ func @tanh_dyn(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
// CHECK: %[[SHAPE:.*]] = shape.shape_of %arg0 : memref<?x?xf32> -> tensor<2xindex>
// CHECK: %[[EE0:.*]] = tensor.extract %[[SHAPE]][%[[C0]]] : tensor<2xindex>
// CHECK: %[[EE1:.*]] = tensor.extract %[[SHAPE]][%[[C1]]] : tensor<2xindex>
// CHECK: %[[RESULT:.*]] = alloc(%[[EE0]], %[[EE1]])
// CHECK: %[[RESULT:.*]] = memref.alloc(%[[EE0]], %[[EE1]])
// CHECK: "lmhlo.tanh"(%arg0, %[[RESULT]]) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
return %result : tensor<?x?xf32>
// CHECK: return %[[RESULT]]
@@ -496,7 +496,7 @@ func @tanh_dyn(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
// CHECK-LABEL: func @dot
func @dot(%arg0: tensor<1024x1024xf32>) -> tensor<1024x1024xf32> {
// CHECK-SAME: (%[[ARG0:.*]]: [[TYPE:.*]]) -> [[TYPE]]
// CHECK-NEXT: %[[ALLOC:.*]] = alloc
// CHECK-NEXT: %[[ALLOC:.*]] = memref.alloc
// CHECK: "lmhlo.dot"(%[[ARG0]], %[[ARG0]], %[[ALLOC]]) {
// dot_dimension_numbers = {
// lhs_batching_dimensions = dense<> : tensor<0xi64>,
@@ -517,7 +517,7 @@ func @dot(%arg0: tensor<1024x1024xf32>) -> tensor<1024x1024xf32> {
func @conv(%input: tensor<3x5x5x3xf32>, %filter : tensor<2x2x3x4xf32>)
-> tensor<3x5x5x4xf32> {
%c0 = constant 0 : index
// CHECK: %[[OUT:.*]] = alloc() : memref<3x5x5x4xf32>
// CHECK: %[[OUT:.*]] = memref.alloc() : memref<3x5x5x4xf32>
// CHECK: "lmhlo.convolution"(%{{.+}}, %{{.+}}, %[[OUT]])
// CHECK-SAME: padding = dense<[
// CHECK-SAME: [0, 1], [0, 1]]> : tensor<2x2xi64>
@@ -548,11 +548,11 @@ func @conv(%input: tensor<3x5x5x3xf32>, %filter : tensor<2x2x3x4xf32>)
// CHECK-LABEL: func @reduce
func @reduce(%arg0: tensor<1x8xf32>, %arg1: tensor<f32>) -> tensor<1xf32> {
// CHECK: %[[OUT:.*]] = alloc() : memref<1xf32>
// CHECK: %[[OUT:.*]] = memref.alloc() : memref<1xf32>
// CHECK: "lmhlo.reduce"(%{{.+}}, %{{.+}}, %[[OUT]]) ( {
// CHECK: ^bb0(%[[ARG1:.*]]: memref<f32>, %[[ARG2:.*]]: memref<f32>,
// CHECK-SAME: %[[ARG3:.*]]: memref<f32>):
// CHECK: %[[TMP:.*]] = alloc() : memref<f32>
// CHECK: %[[TMP:.*]] = memref.alloc() : memref<f32>
// CHECK: "lmhlo.add"(%[[ARG1]], %[[ARG2]], %[[TMP]])
// CHECK: "lmhlo.copy"(%[[TMP]], %[[ARG3]])
// CHECK: "lmhlo.terminator"() : () -> ()
@@ -404,7 +404,7 @@ func @broadcast(%arg: tensor<4x?x16xf32>) -> tensor<4x2x1x4x?x16xf32> {
return %0: tensor<4x2x1x4x?x16xf32>
}
// CHECK: %[[C1:.*]] = constant 1 : index
// CHECK: %[[D1:.*]] = dim %{{.*}}, %[[C1]] : tensor<4x?x16xf32>
// CHECK: %[[D1:.*]] = memref.dim %{{.*}}, %[[C1]] : tensor<4x?x16xf32>
// CHECK: linalg.init_tensor [4, 2, 1, 4, %[[D1]], 16] : tensor<4x2x1x4x?x16xf32>
// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]]
// CHECK-NEXT: ^bb0(%[[OPERAND:.*]]: f32, %{{.*}}: f32):
@@ -997,19 +997,41 @@ func @dynamic_broadcast_in_dim(%shape: tensor<1xindex>) -> tensor<?xf32> {
// CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-LABEL: func @dynamic_broadcast_in_dim(
// CHECK-SAME: [[SCALAR:%.*]]: tensor<f32>
// CHECK-SAME: [[SHAPE:%.*]]: tensor<2xindex>
func @dynamic_broadcast_in_dim(%shape: tensor<2xindex>) -> tensor<?x32xf32> {
%cst = mhlo.constant dense<0x7F800000> : tensor<f32>
%result = "mhlo.dynamic_broadcast_in_dim"(%cst, %shape) {
func @dynamic_broadcast_in_dim(%scalar: tensor<f32>, %shape: tensor<2xindex>)
-> tensor<?x32xf32> {
%result = "mhlo.dynamic_broadcast_in_dim"(%scalar, %shape) {
broadcast_dimensions = dense<> : tensor<0xi64>
} : (tensor<f32>, tensor<2xindex>) -> tensor<?x32xf32>
return %result : tensor<?x32xf32>
}
// CHECK: [[CST:%.*]] = constant
// CHECK: [[INIT:%.*]] = linalg.init_tensor
// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]]
// CHECK-SAME: ins([[CST]] : tensor<f32>) outs([[INIT]] : tensor<?x32xf32>)
// CHECK-SAME: ins([[SCALAR]] : tensor<f32>) outs([[INIT]] : tensor<?x32xf32>)
// CHECK-NEXT: ^bb0(%[[OPERAND:.*]]: f32, %[[RESULT:.*]]: f32):
// CHECK-NEXT: linalg.yield %[[OPERAND]] : f32
// -----
// CHECK-DAG: #[[OPERAND_MAP:.*]] = affine_map<(d0, d1, d2) -> (d1)>
// CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK-LABEL: func @dynamic_broadcast_in_dim(
// CHECK-SAME: [[VECTOR:%.*]]: tensor<42xf32>
// CHECK-SAME: [[SHAPE:%.*]]: tensor<3xindex>
func @dynamic_broadcast_in_dim(%vector: tensor<42xf32>, %shape: tensor<3xindex>)
-> tensor<?x?x?xf32> {
%result = "mhlo.dynamic_broadcast_in_dim"(%vector, %shape) {
broadcast_dimensions = dense<1> : tensor<1xi64>
} : (tensor<42xf32>, tensor<3xindex>) -> tensor<?x?x?xf32>
return %result : tensor<?x?x?xf32>
}
// CHECK: [[INIT:%.*]] = linalg.init_tensor
// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]]
// CHECK-SAME: ins([[VECTOR]] : tensor<42xf32>) outs([[INIT]] : tensor<?x?x?xf32>)
// CHECK-NEXT: ^bb0(%[[OPERAND:.*]]: f32, %[[RESULT:.*]]: f32):
// CHECK-NEXT: linalg.yield %[[OPERAND]] : f32
@@ -1024,7 +1046,7 @@ func @dot_matmul(%arg0: tensor<2x3xf32>,
// CHECK-LABEL: func @dot_matmul(
// CHECK-SAME: %[[ARG0:.*]]: tensor<2x3xf32>, %[[ARG1:.*]]: tensor<3x?xf32>)
// CHECK: %[[C1:.*]] = constant 1 : index
// CHECK: %[[D1:.*]] = dim %[[ARG1]], %[[C1]]
// CHECK: %[[D1:.*]] = memref.dim %[[ARG1]], %[[C1]]
// CHECK: %[[INIT:.*]] = linalg.init_tensor [2, %[[D1]]]
// CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]]
// CHECK: linalg.matmul
@@ -1040,7 +1062,7 @@ func @dot_matmul_i8_i8_i32(%arg0: tensor<2x3xi8>,
// CHECK-LABEL: func @dot_matmul_i8_i8_i32(
// CHECK-SAME: %[[ARG0:.*]]: tensor<2x3xi8>, %[[ARG1:.*]]: tensor<3x?xi8>)
// CHECK: %[[C1:.*]] = constant 1 : index
// CHECK: %[[D1:.*]] = dim %[[ARG1]], %[[C1]]
// CHECK: %[[D1:.*]] = memref.dim %[[ARG1]], %[[C1]]
// CHECK: %[[INIT:.*]] = linalg.init_tensor [2, %[[D1]]]
// CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]]
// CHECK: linalg.matmul
@@ -1058,7 +1080,7 @@ func @dot_matmul_i16_i16_i32(%arg0: tensor<2x3xi16>,
// CHECK-LABEL: func @dot_matmul_i16_i16_i32(
// CHECK-SAME: %[[ARG0:.*]]: tensor<2x3xi16>, %[[ARG1:.*]]: tensor<3x?xi16>)
// CHECK: %[[C1:.*]] = constant 1 : index
// CHECK: %[[D1:.*]] = dim %[[ARG1]], %[[C1]]
// CHECK: %[[D1:.*]] = memref.dim %[[ARG1]], %[[C1]]
// CHECK: %[[INIT:.*]] = linalg.init_tensor [2, %[[D1]]]
// CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]]
// CHECK: linalg.matmul
@@ -1076,7 +1098,7 @@ func @dot_matmul_i32_i32_i32(%arg0: tensor<2x3xi32>,
// CHECK-LABEL: func @dot_matmul_i32_i32_i32(
// CHECK-SAME: %[[ARG0:.*]]: tensor<2x3xi32>, %[[ARG1:.*]]: tensor<3x?xi32>)
// CHECK: %[[C1:.*]] = constant 1 : index
// CHECK: %[[D1:.*]] = dim %[[ARG1]], %[[C1]]
// CHECK: %[[D1:.*]] = memref.dim %[[ARG1]], %[[C1]]
// CHECK: %[[INIT:.*]] = linalg.init_tensor [2, %[[D1]]]
// CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]]
// CHECK: linalg.matmul
@@ -1094,7 +1116,7 @@ func @dot_matvec(%arg0: tensor<?x3xf32>,
// CHECK-LABEL: func @dot_matvec(
// CHECK-SAME: %[[ARG0:.*]]: tensor<?x3xf32>, %[[ARG1:.*]]: tensor<3xf32>)
// CHECK: %[[C0:.*]] = constant 0 : index
// CHECK: %[[D0:.*]] = dim %[[ARG0]], %[[C0]]
// CHECK: %[[D0:.*]] = memref.dim %[[ARG0]], %[[C0]]
// CHECK: %[[INIT:.*]] = linalg.init_tensor [%[[D0]]]
// CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]]
// CHECK: linalg.matvec
@@ -1134,11 +1156,11 @@ func @dot_general_batch_matmul(%arg0: tensor<?x?x3xf32>,
// CHECK-LABEL: func @dot_general_batch_matmul(
// CHECK-SAME: %[[ARG0:.*]]: tensor<?x?x3xf32>, %[[ARG1:.*]]: tensor<?x3x?xf32>)
// CHECK: %[[C0:.*]] = constant 0 : index
// CHECK: %[[D0:.*]] = dim %[[ARG0]], %[[C0]]
// CHECK: %[[D0:.*]] = memref.dim %[[ARG0]], %[[C0]]
// CHECK: %[[C1:.*]] = constant 1 : index
// CHECK: %[[D1:.*]] = dim %[[ARG0]], %[[C1]]
// CHECK: %[[D1:.*]] = memref.dim %[[ARG0]], %[[C1]]
// CHECK: %[[C2:.*]] = constant 2 : index
// CHECK: %[[D2:.*]] = dim %[[ARG1]], %[[C2]]
// CHECK: %[[D2:.*]] = memref.dim %[[ARG1]], %[[C2]]
// CHECK: %[[INIT:.*]] = linalg.init_tensor [%[[D0]], %[[D1]], %[[D2]]]
// CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]]
// CHECK: linalg.batch_matmul
@@ -1163,11 +1185,11 @@ func @dot_general_batch_matmul_i8_i8_i32(%arg0: tensor<?x?x3xi8>,
// CHECK-LABEL: func @dot_general_batch_matmul_i8_i8_i32(
// CHECK-SAME: %[[ARG0:.*]]: tensor<?x?x3xi8>, %[[ARG1:.*]]: tensor<?x3x?xi8>)
// CHECK: %[[C0:.*]] = constant 0 : index
// CHECK: %[[D0:.*]] = dim %[[ARG0]], %[[C0]]
// CHECK: %[[D0:.*]] = memref.dim %[[ARG0]], %[[C0]]
// CHECK: %[[C1:.*]] = constant 1 : index
// CHECK: %[[D1:.*]] = dim %[[ARG0]], %[[C1]]
// CHECK: %[[D1:.*]] = memref.dim %[[ARG0]], %[[C1]]
// CHECK: %[[C2:.*]] = constant 2 : index
// CHECK: %[[D2:.*]] = dim %[[ARG1]], %[[C2]]
// CHECK: %[[D2:.*]] = memref.dim %[[ARG1]], %[[C2]]
// CHECK: %[[INIT:.*]] = linalg.init_tensor [%[[D0]], %[[D1]], %[[D2]]]
// CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]]
// CHECK: linalg.batch_matmul
@@ -1192,11 +1214,11 @@ func @dot_general_batch_matmul_i16_i16_i32(%arg0: tensor<?x?x3xi16>,
// CHECK-LABEL: func @dot_general_batch_matmul_i16_i16_i32(
// CHECK-SAME: %[[ARG0:.*]]: tensor<?x?x3xi16>, %[[ARG1:.*]]: tensor<?x3x?xi16>)
// CHECK: %[[C0:.*]] = constant 0 : index
// CHECK: %[[D0:.*]] = dim %[[ARG0]], %[[C0]]
// CHECK: %[[D0:.*]] = memref.dim %[[ARG0]], %[[C0]]
// CHECK: %[[C1:.*]] = constant 1 : index
// CHECK: %[[D1:.*]] = dim %[[ARG0]], %[[C1]]
// CHECK: %[[D1:.*]] = memref.dim %[[ARG0]], %[[C1]]
// CHECK: %[[C2:.*]] = constant 2 : index
// CHECK: %[[D2:.*]] = dim %[[ARG1]], %[[C2]]
// CHECK: %[[D2:.*]] = memref.dim %[[ARG1]], %[[C2]]
// CHECK: %[[INIT:.*]] = linalg.init_tensor [%[[D0]], %[[D1]], %[[D2]]]
// CHECK: %[[FILL:.*]] = linalg.fill(%[[INIT]]
// CHECK: linalg.batch_matmul
@@ -1420,7 +1442,7 @@ func @reduce_dynamic(%arg0: tensor<?x?xi32>, %arg1: tensor<i32>) -> tensor<?xi32
// CHECK: func @reduce_dynamic(%[[ARG0:.*]]: tensor<?x?xi32>
// CHECK-DAG: %[[INIT:.*]] = tensor.extract %{{.*}} : tensor<i32>
// CHECK-DAG: %[[C0:.*]] = constant 0 : index
// CHECK-DAG: %[[DIM1:.*]] = dim %[[ARG0]], %[[C0]] : tensor<?x?xi32>
// CHECK-DAG: %[[DIM1:.*]] = memref.dim %[[ARG0]], %[[C0]] : tensor<?x?xi32>
// CHECK-DAG: %[[INIT_TENSOR:.*]] = linalg.init_tensor [%[[DIM1]]]
// CHECK-DAG: %[[FILL_TENSOR:.*]] = linalg.fill(%[[INIT_TENSOR]], %[[INIT]])
// CHECK: linalg.generic
@@ -1531,9 +1553,9 @@ func @linalg.conv_1d_input_nwc_filter_wcf(%arg0: tensor<?x8x?xf32>, %arg1: tenso
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]
// CHECK: %[[C0:.+]] = constant 0 : index
// CHECK: %[[DIM0:.+]] = dim %[[ARG0]], %[[C0]] : tensor<?x8x?xf32>
// CHECK: %[[DIM0:.+]] = memref.dim %[[ARG0]], %[[C0]] : tensor<?x8x?xf32>
// CHECK: %[[C2:.+]] = constant 2 : index
// CHECK: %[[DIM2:.+]] = dim %[[ARG1]], %[[C2]] : tensor<2x?x?xf32>
// CHECK: %[[DIM2:.+]] = memref.dim %[[ARG1]], %[[C2]] : tensor<2x?x?xf32>
// CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[DIM0]], 7, %[[DIM2]]]
// CHECK: %[[ZERO:.+]] = constant 0.000000e+00 : f32
// CHECK: %[[FILL:.+]] = linalg.fill(%[[INIT]], %[[ZERO]])
@@ -1571,9 +1593,9 @@ func @conv_2d_input_nhwc_filter_hwcf(%arg0: tensor<?x4x5x?xf32>, %arg1: tensor<3
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]
// CHECK: %[[C0:.+]] = constant 0 : index
// CHECK: %[[DIM0:.+]] = dim %[[ARG0]], %[[C0]] : tensor<?x4x5x?xf32>
// CHECK: %[[DIM0:.+]] = memref.dim %[[ARG0]], %[[C0]] : tensor<?x4x5x?xf32>
// CHECK: %[[C3:.+]] = constant 3 : index
// CHECK: %[[DIM3:.+]] = dim %[[ARG1]], %[[C3]] : tensor<3x2x?x?xf32>
// CHECK: %[[DIM3:.+]] = memref.dim %[[ARG1]], %[[C3]] : tensor<3x2x?x?xf32>
// CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[DIM0]], 2, 3, %[[DIM3]]]
// CHECK: %[[ZERO:.+]] = constant 0.000000e+00 : f32
// CHECK: %[[FILL:.+]] = linalg.fill(%[[INIT]], %[[ZERO]])
@@ -1611,9 +1633,9 @@ func @conv_3d_input_ndhwc_filter_dhwcf(%arg0: tensor<?x8x8x8x?xf32>, %arg1: tens
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]
// CHECK: %[[C0:.+]] = constant 0 : index
// CHECK: %[[DIM0:.+]] = dim %[[ARG0]], %[[C0]] : tensor<?x8x8x8x?xf32>
// CHECK: %[[DIM0:.+]] = memref.dim %[[ARG0]], %[[C0]] : tensor<?x8x8x8x?xf32>
// CHECK: %[[C4:.+]] = constant 4 : index
// CHECK: %[[DIM4:.+]] = dim %[[ARG1]], %[[C4]] : tensor<2x2x2x?x?xf32>
// CHECK: %[[DIM4:.+]] = memref.dim %[[ARG1]], %[[C4]] : tensor<2x2x2x?x?xf32>
// CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[DIM0]], 7, 7, 7, %[[DIM4]]]
// CHECK: %[[ZERO:.+]] = constant 0.000000e+00 : f32
// CHECK: %[[FILL:.+]] = linalg.fill(%[[INIT]], %[[ZERO]])
@ -1826,7 +1848,6 @@ func @reduce_window_max_nhwc_with_cst(%arg0: tensor<1x18x18x64xf32>) -> tensor<1
|
|||
return %1 : tensor<1x8x8x64xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: func @reduce_window_max_nhwc
|
||||
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]
|
||||
// CHECK-DAG: %[[CST:.+]] = constant dense<0xFF800000> : tensor<f32>
|
||||
|
|
@ -1839,3 +1860,125 @@ func @reduce_window_max_nhwc_with_cst(%arg0: tensor<1x18x18x64xf32>) -> tensor<1
|
|||
// CHECK-SAME: strides = dense<2> : vector<2xi64>}
|
||||
// CHECK-SAME: ins(%[[ARG0]], %[[WINDOW]] : tensor<1x18x18x64xf32>, tensor<3x3xf32>)
|
||||
// CHECK-SAME: outs(%[[FILL]] : tensor<1x8x8x64xf32>) -> tensor<1x8x8x64xf32>
|
||||
|
||||
// -----
|
||||
|
||||
func @torch_select_index(%arg0: tensor<5x1x5xi32>,
|
||||
%arg1: tensor<2xi32>) -> tensor<2x1x5xi32> {
|
||||
%0 = "mhlo.torch_index_select"(%arg0, %arg1) {
|
||||
dim = 0 : i64,
|
||||
batch_dims = 0 : i64
|
||||
} : (tensor<5x1x5xi32>, tensor<2xi32>) -> tensor<2x1x5xi32>
|
||||
return %0 : tensor<2x1x5xi32>
|
||||
}
|
||||
// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0)>
|
||||
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
|
||||
// CHECK: func @torch_select_index
|
||||
// CHECK-SAME: %[[INPUT:[a-zA-Z0-9_]*]]
|
||||
// CHECK-SAME: %[[INDEX:[a-zA-Z0-9_]*]]
|
||||
// CHECK: linalg.indexed_generic {
|
||||
// CHECK-SAME: indexing_maps
|
||||
// CHECK-SAME: #[[MAP0]], #[[MAP1]]
|
||||
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]
|
||||
// CHECK-SAME: ins(%[[INDEX]] : tensor<2xi32>)
|
||||
// CHECK: ^{{.+}}(
|
||||
// CHECK-SAME: %[[I:.+]]: index, %[[J:.+]]: index, %[[K:.+]]: index
|
||||
// CHECK-SAME: %[[VAL:.+]]: i32, %{{.+}}: i32):
|
||||
// CHECK: %[[CAST:.+]] = index_cast %[[VAL]] : i32 to index
|
||||
// CHECK: %[[VAL2:.+]] = tensor.extract %[[INPUT]][%[[CAST]], %[[J]], %[[K]]] : tensor<5x1x5xi32>
|
||||
// CHECK: linalg.yield %[[VAL2]] : i32
|
||||
|
||||
// -----
|
||||
|
||||
func @torch_select_index_scalar(%arg0: tensor<4x8xf32>,
|
||||
%arg1: tensor<i32>) -> tensor<8xf32> {
|
||||
%0 = "mhlo.torch_index_select"(%arg0, %arg1) {
|
||||
batch_dims = 0 : i64,
|
||||
dim = 0 : i64
|
||||
} : (tensor<4x8xf32>, tensor<i32>) -> tensor<8xf32>
|
||||
return %0 : tensor<8xf32>
|
||||
}
|
||||
|
||||
// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0) -> ()>
|
||||
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0) -> (d0)>
|
||||
// CHECK: func @torch_select_index_scalar
|
||||
// CHECK-SAME: %[[INPUT:[a-zA-Z0-9_]*]]
|
||||
// CHECK-SAME: %[[INDEX:[a-zA-Z0-9_]*]]
|
||||
// CHECK: %[[T0:.+]] = linalg.init_tensor [8] : tensor<8xf32>
|
||||
// CHECK: linalg.indexed_generic {
|
||||
// CHECK-SAME: indexing_maps
|
||||
// CHECK-SAME: #[[MAP0]], #[[MAP1]]
|
||||
// CHECK-SAME: iterator_types = ["parallel"]
|
||||
// CHECK-SAME: ins(%[[INDEX]] : tensor<i32>) outs(%[[T0]] : tensor<8xf32>)
|
||||
// CHECK: ^{{.+}}(
|
||||
// CHECK-SAME: %[[I:[a-zA-Z0-9_]+]]: index, %[[VAL:[a-zA-Z0-9_]+]]: i32, %{{.+}}: f32):
|
||||
// CHECK: %[[CAST:.+]] = index_cast %[[VAL]] : i32 to index
|
||||
// CHECK: %[[VAL2:.+]] = tensor.extract %[[INPUT]][%[[CAST]], %[[I]]] : tensor<4x8xf32>
|
||||
// CHECK: linalg.yield %[[VAL2]] : f32
|
||||
|
||||
// -----
|
||||
|
||||
func @torch_select_index_batch(%arg0: tensor<4x7x8x2xf32>,
|
||||
%arg1: tensor<4x1xi32>) -> tensor<4x7x1x2xf32> {
|
||||
%0 = "mhlo.torch_index_select"(%arg0, %arg1) {
|
||||
dim = 2 : i64,
|
||||
batch_dims = 1 : i64
|
||||
} : (tensor<4x7x8x2xf32>, tensor<4x1xi32>) -> tensor<4x7x1x2xf32>
|
||||
return %0 : tensor<4x7x1x2xf32>
|
||||
}
|
||||
// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d2)>
|
||||
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
|
||||
// CHECK: func @torch_select_index_batch
|
||||
// CHECK-SAME: %[[INPUT:[a-zA-Z0-9_]*]]
|
||||
// CHECK-SAME: %[[INDEX:[a-zA-Z0-9_]*]]
|
||||
// CHECK: linalg.indexed_generic {
|
||||
// CHECK-SAME: indexing_maps
|
||||
// CHECK-SAME: #[[MAP0]], #[[MAP1]]
|
||||
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"]
|
||||
// CHECK-SAME: ins(%[[INDEX]] : tensor<4x1xi32>)
|
||||
// CHECK-NEXT: ^{{.+}}(
|
||||
// CHECK-SAME: %[[I:[a-zA-Z0-9_]+]]: index, %[[J:[a-zA-Z0-9_]+]]: index,
|
||||
// CHECK-SAME: %[[K:[a-zA-Z0-9_]+]]: index, %[[L:.+]]: index,
|
||||
// CHECK-SAME: %[[VAL:.+]]: i32, %{{.+}}: f32):
|
||||
// CHECK: %[[CAST:.+]] = index_cast %[[VAL]] : i32 to index
|
||||
// CHECK: %[[VAL2:.+]] = tensor.extract %[[INPUT]][%[[I]], %[[J]], %[[CAST]], %[[L]]] : tensor<4x7x8x2xf32>
|
||||
// CHECK: linalg.yield %[[VAL2]] : f32
|
||||
|
||||
// -----
|
||||
|
||||
func @torch_index_select_dynamic(%input: tensor<?x?x?x?xf32>,
|
||||
%index: tensor<?x?xi32>) -> tensor<?x?x?x?xf32>{
|
||||
%0 = "mhlo.torch_index_select"(%input, %index) {
|
||||
batch_dims = 1 : i64,
|
||||
dim = 2 : i64
|
||||
} : (tensor<?x?x?x?xf32>, tensor<?x?xi32>) -> tensor<?x?x?x?xf32>
|
||||
return %0 : tensor<?x?x?x?xf32>
|
||||
}
|
||||
// CHECK: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d2)>
|
||||
// CHECK: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
|
||||
// CHECK: func @torch_index_select_dynamic
|
||||
// CHECK-SAME: %[[INPUT:[a-zA-Z0-9_]*]]
|
||||
// CHECK-SAME: %[[INDEX:[a-zA-Z0-9_]*]]
|
||||
// CHECK: %[[C0:.+]] = constant 0 : index
|
||||
// CHECK: %[[D0:.+]] = memref.dim %[[INPUT]], %[[C0]]
|
||||
// CHECK: %[[C1:.+]] = constant 1 : index
|
||||
// CHECK: %[[D1:.+]] = memref.dim %[[INPUT]], %[[C1]]
|
||||
// CHECK: %[[C1:.+]] = constant 1 : index
|
||||
// CHECK: %[[D2:.+]] = memref.dim %[[INDEX]], %[[C1]]
|
||||
// CHECK: %[[C3:.+]] = constant 3 : index
|
||||
// CHECK: %[[D3:.+]] = memref.dim %[[INPUT]], %[[C3]]
|
||||
// CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[D0]], %[[D1]], %[[D2]], %[[D3]]]
|
||||
// CHECK: %[[RESULT:.+]] = linalg.indexed_generic
|
||||
// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
|
||||
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"]
|
||||
// CHECK-SAME: ins(%[[INDEX]] : tensor<?x?xi32>)
|
||||
// CHECK-SAME: outs(%[[INIT]] : tensor<?x?x?x?xf32>)
|
||||
// CHECK: ^{{.+}}(
|
||||
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: index,
|
||||
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index,
|
||||
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: index
|
||||
// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: index,
|
||||
// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: i32, %{{[a-zA-Z0-9_]+}}: f32)
|
||||
// CHECK: %[[POS:.+]] = index_cast %[[ARG4]]
|
||||
// CHECK: %[[YIELD:.+]] = tensor.extract %[[INPUT]][%[[ARG0]], %[[ARG1]], %[[POS]], %[[ARG3]]]
|
||||
// CHECK: linalg.yield %[[YIELD]]
|
||||
|
|
|
|||
|
|
@ -1,125 +1,78 @@
|
|||
// RUN: mlir-hlo-opt --mhlo-legalize-trigonometric-to-approximation --split-input-file %s | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: @tanh_f64
|
||||
func @tanh_f64(%arg0 : f64) -> f64 {
|
||||
// CHECK: tanh
|
||||
%res = math.tanh %arg0 : f64
|
||||
return %res : f64
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @tanh_f64
|
||||
// CHECK: tanh
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @tanh_f32
|
||||
// CHECK-SAME: (%[[ARG:.*]]: f32) -> f32
|
||||
func @tanh_f32(%arg0 : f32) -> f32 {
|
||||
// CHECK-DAG: %[[C:.*]] = constant -2.76076837E-16 : f32
|
||||
// CHECK-DAG: %[[C0:.*]] = constant 2.00018794E-13 : f32
|
||||
// CHECK-DAG: %[[C1:.*]] = constant -8.60467184E-11 : f32
|
||||
// CHECK-DAG: %[[C2:.*]] = constant 5.12229725E-8 : f32
|
||||
// CHECK-DAG: %[[C3:.*]] = constant 1.48572235E-5 : f32
|
||||
// CHECK-DAG: %[[C4:.*]] = constant 6.37261954E-4 : f32
|
||||
// CHECK-DAG: %[[C5:.*]] = constant 0.00489352457 : f32
|
||||
// CHECK-DAG: %[[C6:.*]] = constant 1.19825836E-6 : f32
|
||||
// CHECK-DAG: %[[C7:.*]] = constant 1.18534706E-4 : f32
|
||||
// CHECK-DAG: %[[C8:.*]] = constant 0.00226843474 : f32
|
||||
// CHECK-DAG: %[[C9:.*]] = constant 0.00489352504 : f32
|
||||
// CHECK-DAG: %[[C10:.*]] = constant 4.000000e-04 : f32
|
||||
// CHECK-DAG: %[[C11:.*]] = constant 7.90531111 : f32
|
||||
// CHECK-DAG: %[[C12:.*]] = constant -7.90531111 : f32
|
||||
// CHECK-DAG: %[[C13:.*]] = constant 1.000000e+00 : f32
|
||||
// CHECK-DAG: %[[C14:.*]] = constant -1.000000e+00 : f32
|
||||
// CHECK-DAG: %[[TMP0:.*]] = mulf %[[ARG]], %[[ARG]] : f32
|
||||
// CHECK-DAG: %[[TMP1:.*]] = mulf %[[TMP0]], %[[C]] : f32
|
||||
// CHECK-DAG: %[[TMP2:.*]] = addf %[[TMP1]], %[[C0]] : f32
|
||||
// CHECK-DAG: %[[TMP3:.*]] = mulf %[[TMP0]], %[[TMP2]] : f32
|
||||
// CHECK-DAG: %[[TMP4:.*]] = addf %[[TMP3]], %[[C1]] : f32
|
||||
// CHECK-DAG: %[[TMP5:.*]] = mulf %[[TMP0]], %[[TMP4]] : f32
|
||||
// CHECK-DAG: %[[TMP6:.*]] = addf %[[TMP5]], %[[C2]] : f32
|
||||
// CHECK-DAG: %[[TMP7:.*]] = mulf %[[TMP0]], %[[TMP6]] : f32
|
||||
// CHECK-DAG: %[[TMP8:.*]] = addf %[[TMP7]], %[[C3]] : f32
|
||||
// CHECK-DAG: %[[TMP9:.*]] = mulf %[[TMP0]], %[[TMP8]] : f32
|
||||
// CHECK-DAG: %[[TMP10:.*]] = addf %[[TMP9]], %[[C4]] : f32
|
||||
// CHECK-DAG: %[[TMP11:.*]] = mulf %[[TMP0]], %[[TMP10]] : f32
|
||||
// CHECK-DAG: %[[TMP12:.*]] = addf %[[TMP11]], %[[C5]] : f32
|
||||
// CHECK-DAG: %[[TMP13:.*]] = mulf %[[ARG]], %[[TMP12]] : f32
|
||||
// CHECK-DAG: %[[TMP14:.*]] = mulf %[[TMP0]], %[[C6]] : f32
|
||||
// CHECK-DAG: %[[TMP15:.*]] = addf %[[TMP14]], %[[C7]] : f32
|
||||
// CHECK-DAG: %[[TMP16:.*]] = mulf %[[TMP0]], %[[TMP15]] : f32
|
||||
// CHECK-DAG: %[[TMP17:.*]] = addf %[[TMP16]], %[[C8]] : f32
|
||||
// CHECK-DAG: %[[TMP18:.*]] = mulf %[[TMP0]], %[[TMP17]] : f32
|
||||
// CHECK-DAG: %[[TMP19:.*]] = addf %[[TMP18]], %[[C9]] : f32
|
||||
// CHECK-DAG: %[[TMP20:.*]] = divf %[[TMP13]], %[[TMP19]] : f32
|
||||
// CHECK-DAG: %[[TMP21:.*]] = absf %[[ARG]] : f32
|
||||
// CHECK-DAG: %[[TMP22:.*]] = cmpf olt, %[[TMP21]], %[[C10]] : f32
|
||||
// CHECK-DAG: %[[TMP23:.*]] = select %[[TMP22]], %[[ARG]], %[[TMP20]] : f32
|
||||
// CHECK-DAG: %[[TMP24:.*]] = cmpf ugt, %[[ARG]], %[[C11]] : f32
|
||||
// CHECK-DAG: %[[TMP25:.*]] = cmpf ult, %[[ARG]], %[[C12]] : f32
|
||||
// CHECK-DAG: %[[TMP26:.*]] = select %[[TMP24]], %[[C13]], %[[TMP23]] : f32
|
||||
// CHECK-DAG: %[[TMP27:.*]] = select %[[TMP25]], %[[C14]], %[[TMP26]] : f32
|
||||
// CHECK: return %[[TMP27]] : f32
|
||||
%res = math.tanh %arg0 : f32
|
||||
return %res : f32
|
||||
}
|
||||
|
||||
// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
|
||||
|
||||
// CHECK-LABEL: func @tanh_f32
|
||||
// CHECK-SAME: (%[[VAL_0:.*]]: f32) -> f32
|
||||
// CHECK: %[[VAL_1:.*]] = constant 4.000000e-04 : f32
|
||||
// CHECK: %[[VAL_2:.*]] = constant 7.90531111 : f32
|
||||
// CHECK: %[[VAL_3:.*]] = constant -7.90531111 : f32
|
||||
// CHECK: %[[VAL_4:.*]] = constant -2.76076837E-16 : f32
|
||||
// CHECK: %[[VAL_5:.*]] = constant 2.00018794E-13 : f32
|
||||
// CHECK: %[[VAL_6:.*]] = constant -8.60467184E-11 : f32
|
||||
// CHECK: %[[VAL_7:.*]] = constant 5.12229725E-8 : f32
|
||||
// CHECK: %[[VAL_8:.*]] = constant 1.48572235E-5 : f32
|
||||
// CHECK: %[[VAL_9:.*]] = constant 6.37261954E-4 : f32
|
||||
// CHECK: %[[VAL_10:.*]] = constant 0.00489352457 : f32
|
||||
// CHECK: %[[VAL_11:.*]] = constant 1.19825836E-6 : f32
|
||||
// CHECK: %[[VAL_12:.*]] = constant 1.18534706E-4 : f32
|
||||
// CHECK: %[[VAL_13:.*]] = constant 0.00226843474 : f32
|
||||
// CHECK: %[[VAL_14:.*]] = constant 0.00489352504 : f32
|
||||
// CHECK: %[[VAL_15:.*]] = absf %[[VAL_0]] : f32
|
||||
// CHECK: %[[VAL_16:.*]] = cmpf olt, %[[VAL_15]], %[[VAL_1]] : f32
|
||||
// CHECK: %[[VAL_17:.*]] = cmpf ule, %[[VAL_0]], %[[VAL_2]] : f32
|
||||
// CHECK: %[[VAL_18:.*]] = select %[[VAL_17]], %[[VAL_0]], %[[VAL_2]] : f32
|
||||
// CHECK: %[[VAL_19:.*]] = cmpf uge, %[[VAL_18]], %[[VAL_3]] : f32
|
||||
// CHECK: %[[VAL_20:.*]] = select %[[VAL_19]], %[[VAL_18]], %[[VAL_3]] : f32
|
||||
// CHECK: %[[VAL_21:.*]] = mulf %[[VAL_20]], %[[VAL_20]] : f32
|
||||
// CHECK: %[[VAL_22:.*]] = mulf %[[VAL_21]], %[[VAL_4]] : f32
|
||||
// CHECK: %[[VAL_23:.*]] = addf %[[VAL_22]], %[[VAL_5]] : f32
|
||||
// CHECK: %[[VAL_24:.*]] = mulf %[[VAL_21]], %[[VAL_23]] : f32
|
||||
// CHECK: %[[VAL_25:.*]] = addf %[[VAL_24]], %[[VAL_6]] : f32
|
||||
// CHECK: %[[VAL_26:.*]] = mulf %[[VAL_21]], %[[VAL_25]] : f32
|
||||
// CHECK: %[[VAL_27:.*]] = addf %[[VAL_26]], %[[VAL_7]] : f32
|
||||
// CHECK: %[[VAL_28:.*]] = mulf %[[VAL_21]], %[[VAL_27]] : f32
|
||||
// CHECK: %[[VAL_29:.*]] = addf %[[VAL_28]], %[[VAL_8]] : f32
|
||||
// CHECK: %[[VAL_30:.*]] = mulf %[[VAL_21]], %[[VAL_29]] : f32
|
||||
// CHECK: %[[VAL_31:.*]] = addf %[[VAL_30]], %[[VAL_9]] : f32
|
||||
// CHECK: %[[VAL_32:.*]] = mulf %[[VAL_21]], %[[VAL_31]] : f32
|
||||
// CHECK: %[[VAL_33:.*]] = addf %[[VAL_32]], %[[VAL_10]] : f32
|
||||
// CHECK: %[[VAL_34:.*]] = mulf %[[VAL_20]], %[[VAL_33]] : f32
|
||||
// CHECK: %[[VAL_35:.*]] = mulf %[[VAL_21]], %[[VAL_11]] : f32
|
||||
// CHECK: %[[VAL_36:.*]] = addf %[[VAL_35]], %[[VAL_12]] : f32
|
||||
// CHECK: %[[VAL_37:.*]] = mulf %[[VAL_21]], %[[VAL_36]] : f32
|
||||
// CHECK: %[[VAL_38:.*]] = addf %[[VAL_37]], %[[VAL_13]] : f32
|
||||
// CHECK: %[[VAL_39:.*]] = mulf %[[VAL_21]], %[[VAL_38]] : f32
|
||||
// CHECK: %[[VAL_40:.*]] = addf %[[VAL_39]], %[[VAL_14]] : f32
|
||||
// CHECK: %[[VAL_41:.*]] = divf %[[VAL_34]], %[[VAL_40]] : f32
|
||||
// CHECK: %[[VAL_42:.*]] = select %[[VAL_16]], %[[VAL_0]], %[[VAL_41]] : f32
|
||||
// CHECK: return %[[VAL_42]] : f32
|
||||
|
||||
// -----
|
||||
|
||||
func @tanh_f16(%arg0 : f16) -> f16 {
|
||||
// CHECK-LABEL: func @tanh_f16
|
||||
// CHECK-SAME: (%[[ARG:.*]]: f16) -> f16
|
||||
// CHECK: %{{.*}} = fpext %[[ARG]] : f16 to f32
|
||||
// CHECK: %[[RES:.*]] = fptrunc %{{.*}} : f32 to f16
|
||||
// CHECK: return %[[RES]] : f16
|
||||
%res = math.tanh %arg0 : f16
|
||||
return %res : f16
|
||||
}
|
||||
|
||||
// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
|
||||
|
||||
// CHECK-LABEL: func @tanh_f16
|
||||
// CHECK-SAME: (%[[VAL_0:.*]]: f16) -> f16
|
||||
// CHECK: %[[VAL_1:.*]] = constant 4.000000e-04 : f32
|
||||
// CHECK: %[[VAL_2:.*]] = constant 7.90531111 : f32
|
||||
// CHECK: %[[VAL_3:.*]] = constant -7.90531111 : f32
|
||||
// CHECK: %[[VAL_4:.*]] = constant -2.76076837E-16 : f32
|
||||
// CHECK: %[[VAL_5:.*]] = constant 2.00018794E-13 : f32
|
||||
// CHECK: %[[VAL_6:.*]] = constant -8.60467184E-11 : f32
|
||||
// CHECK: %[[VAL_7:.*]] = constant 5.12229725E-8 : f32
|
||||
// CHECK: %[[VAL_8:.*]] = constant 1.48572235E-5 : f32
|
||||
// CHECK: %[[VAL_9:.*]] = constant 6.37261954E-4 : f32
|
||||
// CHECK: %[[VAL_10:.*]] = constant 0.00489352457 : f32
|
||||
// CHECK: %[[VAL_11:.*]] = constant 1.19825836E-6 : f32
|
||||
// CHECK: %[[VAL_12:.*]] = constant 1.18534706E-4 : f32
|
||||
// CHECK: %[[VAL_13:.*]] = constant 0.00226843474 : f32
|
||||
// CHECK: %[[VAL_14:.*]] = constant 0.00489352504 : f32
|
||||
// CHECK: %[[VAL_15:.*]] = fpext %[[VAL_0]] : f16 to f32
|
||||
// CHECK: %[[VAL_16:.*]] = absf %[[VAL_15]] : f32
|
||||
// CHECK: %[[VAL_17:.*]] = cmpf olt, %[[VAL_16]], %[[VAL_1]] : f32
|
||||
// CHECK: %[[VAL_18:.*]] = cmpf ule, %[[VAL_15]], %[[VAL_2]] : f32
|
||||
// CHECK: %[[VAL_19:.*]] = select %[[VAL_18]], %[[VAL_15]], %[[VAL_2]] : f32
|
||||
// CHECK: %[[VAL_20:.*]] = cmpf uge, %[[VAL_19]], %[[VAL_3]] : f32
|
||||
// CHECK: %[[VAL_21:.*]] = select %[[VAL_20]], %[[VAL_19]], %[[VAL_3]] : f32
|
||||
// CHECK: %[[VAL_22:.*]] = mulf %[[VAL_21]], %[[VAL_21]] : f32
|
||||
// CHECK: %[[VAL_23:.*]] = mulf %[[VAL_22]], %[[VAL_4]] : f32
|
||||
// CHECK: %[[VAL_24:.*]] = addf %[[VAL_23]], %[[VAL_5]] : f32
|
||||
// CHECK: %[[VAL_25:.*]] = mulf %[[VAL_22]], %[[VAL_24]] : f32
|
||||
// CHECK: %[[VAL_26:.*]] = addf %[[VAL_25]], %[[VAL_6]] : f32
|
||||
// CHECK: %[[VAL_27:.*]] = mulf %[[VAL_22]], %[[VAL_26]] : f32
|
||||
// CHECK: %[[VAL_28:.*]] = addf %[[VAL_27]], %[[VAL_7]] : f32
|
||||
// CHECK: %[[VAL_29:.*]] = mulf %[[VAL_22]], %[[VAL_28]] : f32
|
||||
// CHECK: %[[VAL_30:.*]] = addf %[[VAL_29]], %[[VAL_8]] : f32
|
||||
// CHECK: %[[VAL_31:.*]] = mulf %[[VAL_22]], %[[VAL_30]] : f32
|
||||
// CHECK: %[[VAL_32:.*]] = addf %[[VAL_31]], %[[VAL_9]] : f32
|
||||
// CHECK: %[[VAL_33:.*]] = mulf %[[VAL_22]], %[[VAL_32]] : f32
|
||||
// CHECK: %[[VAL_34:.*]] = addf %[[VAL_33]], %[[VAL_10]] : f32
|
||||
// CHECK: %[[VAL_35:.*]] = mulf %[[VAL_21]], %[[VAL_34]] : f32
|
||||
// CHECK: %[[VAL_36:.*]] = mulf %[[VAL_22]], %[[VAL_11]] : f32
|
||||
// CHECK: %[[VAL_37:.*]] = addf %[[VAL_36]], %[[VAL_12]] : f32
|
||||
// CHECK: %[[VAL_38:.*]] = mulf %[[VAL_22]], %[[VAL_37]] : f32
|
||||
// CHECK: %[[VAL_39:.*]] = addf %[[VAL_38]], %[[VAL_13]] : f32
|
||||
// CHECK: %[[VAL_40:.*]] = mulf %[[VAL_22]], %[[VAL_39]] : f32
|
||||
// CHECK: %[[VAL_41:.*]] = addf %[[VAL_40]], %[[VAL_14]] : f32
|
||||
// CHECK: %[[VAL_42:.*]] = divf %[[VAL_35]], %[[VAL_41]] : f32
|
||||
// CHECK: %[[VAL_43:.*]] = select %[[VAL_17]], %[[VAL_15]], %[[VAL_42]] : f32
|
||||
// CHECK: %[[VAL_44:.*]] = fptrunc %[[VAL_43]] : f32 to f16
|
||||
// CHECK: return %[[VAL_44]] : f16
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @atan2_f64
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@
|
|||
iterator_types = ["parallel", "parallel"]}
|
||||
func @fusion(%multiplier: memref<6x6xf32>, %summand_1: memref<6x6xf32>,
|
||||
%summand_2: memref<6x6xf32>, %result: memref<6x6xf32>) {
|
||||
%temp_result = alloc() : memref<6x6xf32>
|
||||
%temp_result = memref.alloc() : memref<6x6xf32>
|
||||
linalg.generic #pointwise_2d_trait
|
||||
ins(%summand_1, %summand_2 : memref<6x6xf32>, memref<6x6xf32>)
|
||||
outs(%temp_result : memref<6x6xf32>) {
|
||||
|
|
@ -22,7 +22,7 @@ func @fusion(%multiplier: memref<6x6xf32>, %summand_1: memref<6x6xf32>,
|
|||
%out = mulf %temp_result_in, %multiplier_in : f32
|
||||
linalg.yield %out : f32
|
||||
}
|
||||
dealloc %temp_result : memref<6x6xf32>
|
||||
memref.dealloc %temp_result : memref<6x6xf32>
|
||||
return
|
||||
}
|
||||
// CHECK-LABEL: func @fusion
|
||||
|
|
@ -62,7 +62,7 @@ func @fusion(%multiplier: memref<6x6xf32>, %summand_1: memref<6x6xf32>,
|
|||
func @fusion_of_three(%arg0: memref<100x10xf32>,
|
||||
%arg1: memref<100xf32>,
|
||||
%arg2: memref<100x10xf32>) {
|
||||
%0 = alloc() : memref<100x10xf32>
|
||||
%0 = memref.alloc() : memref<100x10xf32>
|
||||
linalg.generic {
|
||||
indexing_maps = [affine_map<(d0, d1) -> (d0)>,
|
||||
affine_map<(d0, d1) -> (d0, d1)>],
|
||||
|
|
@ -72,7 +72,7 @@ func @fusion_of_three(%arg0: memref<100x10xf32>,
|
|||
^bb0(%arg3: f32, %arg4: f32): // no predecessors
|
||||
linalg.yield %arg3 : f32
|
||||
}
|
||||
%1 = alloc() : memref<100x10xf32>
|
||||
%1 = memref.alloc() : memref<100x10xf32>
|
||||
linalg.generic {
|
||||
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
|
||||
affine_map<(d0, d1) -> (d0, d1)>,
|
||||
|
|
@ -84,7 +84,7 @@ func @fusion_of_three(%arg0: memref<100x10xf32>,
|
|||
%2 = subf %arg3, %arg4 : f32
|
||||
linalg.yield %2 : f32
|
||||
}
|
||||
dealloc %0 : memref<100x10xf32>
|
||||
memref.dealloc %0 : memref<100x10xf32>
|
||||
linalg.generic {
|
||||
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
|
||||
affine_map<(d0, d1) -> (d0, d1)>],
|
||||
|
|
@ -95,7 +95,7 @@ func @fusion_of_three(%arg0: memref<100x10xf32>,
|
|||
%2 = math.exp %arg3 : f32
|
||||
linalg.yield %2 : f32
|
||||
}
|
||||
dealloc %1 : memref<100x10xf32>
|
||||
memref.dealloc %1 : memref<100x10xf32>
|
||||
return
|
||||
}
|
||||
// CHECK-LABEL: func @fusion
|
||||
|
|
@ -141,7 +141,7 @@ func @fusion_of_three(%arg0: memref<100x10xf32>,
|
|||
"parallel"]}
|
||||
func @fusion_4d(%multiplier: memref<6x6x6x6xf32>, %summand_1: memref<6x6x6x6xf32>,
|
||||
%summand_2: memref<6x6x6x6xf32>, %result: memref<6x6x6x6xf32>) {
|
||||
%temp_result = alloc() : memref<6x6x6x6xf32>
|
||||
%temp_result = memref.alloc() : memref<6x6x6x6xf32>
|
||||
linalg.generic #pointwise_4d_trait
|
||||
ins(%summand_1, %summand_2 : memref<6x6x6x6xf32>, memref<6x6x6x6xf32>)
|
||||
outs(%temp_result : memref<6x6x6x6xf32>) {
|
||||
|
|
@ -156,7 +156,7 @@ func @fusion_4d(%multiplier: memref<6x6x6x6xf32>, %summand_1: memref<6x6x6x6xf32
|
|||
%out = mulf %temp_result_in, %multiplier_in : f32
|
||||
linalg.yield %out : f32
|
||||
}
|
||||
dealloc %temp_result : memref<6x6x6x6xf32>
|
||||
memref.dealloc %temp_result : memref<6x6x6x6xf32>
|
||||
return
|
||||
}
|
||||
// CHECK-LABEL: func @fusion_4d
|
||||
|
|
@ -200,7 +200,7 @@ func @fusion_4d(%multiplier: memref<6x6x6x6xf32>, %summand_1: memref<6x6x6x6xf32
|
|||
iterator_types = ["parallel", "parallel"]}
|
||||
func @fusion(%multiplier: memref<6x6xf32>, %summand_1: memref<6x6xf32>,
|
||||
%summand_2: memref<6x6xf32>) -> memref<6x6xf32> {
|
||||
%temp_result = alloc() : memref<6x6xf32>
|
||||
%temp_result = memref.alloc() : memref<6x6xf32>
|
||||
linalg.generic #pointwise_2d_trait
|
||||
ins(%summand_1, %summand_2 : memref<6x6xf32>, memref<6x6xf32>)
|
||||
outs(%temp_result : memref<6x6xf32>) {
|
||||
|
|
@ -208,7 +208,7 @@ func @fusion(%multiplier: memref<6x6xf32>, %summand_1: memref<6x6xf32>,
|
|||
%out = addf %summand_1_in, %summand_2_in : f32
|
||||
linalg.yield %out : f32
|
||||
}
|
||||
%result = alloc() : memref<6x6xf32>
|
||||
%result = memref.alloc() : memref<6x6xf32>
|
||||
linalg.generic #pointwise_2d_trait
|
||||
ins(%temp_result, %multiplier : memref<6x6xf32>, memref<6x6xf32>)
|
||||
outs(%result : memref<6x6xf32>) {
|
||||
|
|
@ -216,7 +216,7 @@ func @fusion(%multiplier: memref<6x6xf32>, %summand_1: memref<6x6xf32>,
|
|||
%out = mulf %temp_result_in, %multiplier_in : f32
|
||||
linalg.yield %out : f32
|
||||
}
|
||||
dealloc %temp_result : memref<6x6xf32>
|
||||
memref.dealloc %temp_result : memref<6x6xf32>
|
||||
return %result : memref<6x6xf32>
|
||||
}
|
||||
|
||||
|
|
@ -258,7 +258,7 @@ func @view_result(%arg0: memref<?xf32>, %arg1: memref<?xindex>, %arg2: index)
|
|||
-> memref<*xf32> {
|
||||
%c1 = constant 1 : index
|
||||
%c0 = constant 0 : index
|
||||
%1 = alloc(%arg2) : memref<?xf32>
|
||||
%1 = memref.alloc(%arg2) : memref<?xf32>
|
||||
linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>,
|
||||
affine_map<(d0) -> (d0)>],
|
||||
iterator_types = ["parallel"]}
|
||||
|
|
@ -267,7 +267,7 @@ func @view_result(%arg0: memref<?xf32>, %arg1: memref<?xindex>, %arg2: index)
|
|||
%13 = absf %arg3 : f32
|
||||
linalg.yield %13 : f32
|
||||
}
|
||||
%2 = memref_reshape %1(%arg1)
|
||||
%2 = memref.reshape %1(%arg1)
|
||||
: (memref<?xf32>, memref<?xindex>) -> memref<*xf32>
|
||||
return %2 : memref<*xf32>
|
||||
}
|
||||
|
|
@ -279,7 +279,7 @@ func @view_result(%arg0: memref<?xf32>, %arg1: memref<?xindex>, %arg2: index)
|
|||
// CHECK-NOT: scf.for
|
||||
// CHECK: linalg.generic
|
||||
// CHECK: absf
|
||||
// CHECK: memref_reshape
|
||||
// CHECK: memref.reshape
|
||||
|
||||
// TILED-LABEL: func @view_result
|
||||
// TILED-DAG: %[[C2:.*]] = constant 2
|
||||
|
|
@ -288,7 +288,7 @@ func @view_result(%arg0: memref<?xf32>, %arg1: memref<?xindex>, %arg2: index)
|
|||
// TILED-NOT: scf.for
|
||||
// TILED: linalg.generic
|
||||
// TILED: absf
|
||||
// TILED: memref_reshape
|
||||
// TILED: memref.reshape
|
||||
|
||||
|
||||
// PLOOP-LABEL: func @view_result
|
||||
|
|
@ -297,20 +297,20 @@ func @view_result(%arg0: memref<?xf32>, %arg1: memref<?xindex>, %arg2: index)
|
|||
// PLOOP-NOT: scf.parallel
|
||||
// PLOOP: linalg.generic
|
||||
// PLOOP: absf
|
||||
// PLOOP: memref_reshape
|
||||
// PLOOP: memref.reshape
|
||||
|
||||
|
||||
|
||||
// -----
|
||||
|
||||
// Confirm that tiling information is passed through RegionBranchOpInterfaces.
|
||||
// This test also uses memref_reshape, just to have a value to return through
|
||||
// This test also uses memref.reshape, just to have a value to return through
|
||||
// the if statement.
|
||||
func @branching_result(%arg0: memref<?xf32>, %arg1: memref<?xindex>, %arg2: index)
|
||||
-> memref<*xf32> {
|
||||
%c1 = constant 1 : index
|
||||
%c0 = constant 0 : index
|
||||
%1 = alloc(%arg2) : memref<?xf32>
|
||||
%1 = memref.alloc(%arg2) : memref<?xf32>
|
||||
linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>,
|
||||
affine_map<(d0) -> (d0)>],
|
||||
iterator_types = ["parallel"]}
|
||||
|
|
@ -321,11 +321,11 @@ func @branching_result(%arg0: memref<?xf32>, %arg1: memref<?xindex>, %arg2: inde
|
|||
}
|
||||
%true = constant 1 : i1
|
||||
%3 = scf.if %true -> memref<*xf32> {
|
||||
%2 = memref_reshape %1(%arg1)
|
||||
%2 = memref.reshape %1(%arg1)
|
||||
: (memref<?xf32>, memref<?xindex>) -> memref<*xf32>
|
||||
scf.yield %2 : memref<*xf32>
|
||||
} else {
|
||||
%2 = memref_reshape %1(%arg1)
|
||||
%2 = memref.reshape %1(%arg1)
|
||||
: (memref<?xf32>, memref<?xindex>) -> memref<*xf32>
|
||||
scf.yield %2 : memref<*xf32>
|
||||
}
|
||||
|
|
@ -340,10 +340,10 @@ func @branching_result(%arg0: memref<?xf32>, %arg1: memref<?xindex>, %arg2: inde
|
|||
// CHECK: linalg.generic
|
||||
// CHECK: absf
|
||||
// CHECK: scf.if
|
||||
// CHECK: memref_reshape
|
||||
// CHECK: memref.reshape
|
||||
// CHECK: scf.yield
|
||||
// CHECK: else
|
||||
// CHECK: memref_reshape
|
||||
// CHECK: memref.reshape
|
||||
// CHECK: scf.yield
|
||||
|
||||
// TILED-LABEL: func @branching_result
|
||||
|
|
@ -354,10 +354,10 @@ func @branching_result(%arg0: memref<?xf32>, %arg1: memref<?xindex>, %arg2: inde
|
|||
// TILED: linalg.generic
|
||||
// TILED: absf
|
||||
// TILED: scf.if
|
||||
// TILED: memref_reshape
|
||||
// TILED: memref.reshape
|
||||
// TILED: scf.yield
|
||||
// TILED: else
|
||||
// TILED: memref_reshape
|
||||
// TILED: memref.reshape
|
||||
// TILED: scf.yield
|
||||
|
||||
// PLOOP-LABEL: func @branching_result
|
||||
|
|
@ -367,10 +367,10 @@ func @branching_result(%arg0: memref<?xf32>, %arg1: memref<?xindex>, %arg2: inde
|
|||
// PLOOP: linalg.generic
|
||||
// PLOOP: absf
|
||||
// PLOOP: scf.if
|
||||
// PLOOP: memref_reshape
|
||||
// PLOOP: memref.reshape
|
||||
// PLOOP: scf.yield
|
||||
// PLOOP: else
|
||||
// PLOOP: memref_reshape
|
||||
// PLOOP: memref.reshape
|
||||
// PLOOP: scf.yield
|
||||
|
||||
// -----
|
||||
|
|
@ -380,7 +380,7 @@ func @branching_result(%arg0: memref<?xf32>, %arg1: memref<?xindex>, %arg2: inde
|
|||
func @tensor_ops(%arg0: memref<32xf32>, %arg1: memref<32xindex>)
|
||||
-> memref<?xf32> {
|
||||
%c1 = constant 1 : index
|
||||
%1 = alloc() : memref<32xf32>
|
||||
%1 = memref.alloc() : memref<32xf32>
|
||||
linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>,
|
||||
affine_map<(d0) -> (d0)>],
|
||||
iterator_types = ["parallel"]}
|
||||
|
|
@ -389,9 +389,9 @@ func @tensor_ops(%arg0: memref<32xf32>, %arg1: memref<32xindex>)
|
|||
%13 = absf %arg3 : f32
|
||||
linalg.yield %13 : f32
|
||||
}
|
||||
%2 = tensor_load %1 : memref<32xf32>
|
||||
%2 = memref.tensor_load %1 : memref<32xf32>
|
||||
%3 = tensor.cast %2 : tensor<32xf32> to tensor<?xf32>
|
||||
%4 = tensor_to_memref %3 : memref<?xf32>
|
||||
%4 = memref.buffer_cast %3 : memref<?xf32>
|
||||
return %4 : memref<?xf32>
|
||||
}
|
||||
|
||||
|
|
@ -402,9 +402,9 @@ func @tensor_ops(%arg0: memref<32xf32>, %arg1: memref<32xindex>)
|
|||
// CHECK-NOT: scf.for
|
||||
// CHECK: linalg.generic
|
||||
// CHECK: absf
|
||||
// CHECK: tensor_load
|
||||
// CHECK: memref.tensor_load
|
||||
// CHECK: tensor.cast
|
||||
// CHECK: tensor_to_memref
|
||||
// CHECK: memref.buffer_cast
|
||||
|
||||
// TILED-LABEL: func @tensor_ops
|
||||
// TILED-DAG: %[[C2:.*]] = constant 2
|
||||
|
|
@ -413,9 +413,9 @@ func @tensor_ops(%arg0: memref<32xf32>, %arg1: memref<32xindex>)
|
|||
// TILED-NOT: scf.for
|
||||
// TILED: linalg.generic
|
||||
// TILED: absf
|
||||
// TILED: tensor_load
|
||||
// TILED: memref.tensor_load
|
||||
// TILED: tensor.cast
|
||||
// TILED: tensor_to_memref
|
||||
// TILED: memref.buffer_cast
|
||||
|
||||
|
||||
// PLOOP-LABEL: func @tensor_ops
|
||||
|
|
@ -424,6 +424,6 @@ func @tensor_ops(%arg0: memref<32xf32>, %arg1: memref<32xindex>)
|
|||
// PLOOP-NOT: scf.parallel
|
||||
// PLOOP: linalg.generic
|
||||
// PLOOP: absf
|
||||
// PLOOP: tensor_load
|
||||
// PLOOP: memref.tensor_load
|
||||
// PLOOP: tensor.cast
|
||||
// PLOOP: tensor_to_memref
|
||||
// PLOOP: memref.buffer_cast
|
||||
|
|
|
|||
|
|
@@ -49,10 +49,10 @@ func @select_and_scatter(%arg: memref<112x112xf32>,
// CHECK-DAG: [[CTRUE:%.*]] = constant true
// Parallel loop to initialize the output buffer.
// CHECK: [[INIT:%.*]] = load [[INIT_BUF]][] : memref<f32>
// CHECK: [[INIT:%.*]] = memref.load [[INIT_BUF]][] : memref<f32>
// CHECK: scf.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]])
// CHECK-SAME: to ([[C112]], [[C112]]) step ([[C1]], [[C1]]) {
// CHECK: store [[INIT]], [[RESULT_BUF]]{{\[}}[[I]], [[J]]]
// CHECK: memref.store [[INIT]], [[RESULT_BUF]]{{\[}}[[I]], [[J]]]
// CHECK: scf.yield
// CHECK: }
@@ -101,7 +101,7 @@ func @select_and_scatter(%arg: memref<112x112xf32>,
// INBOUNDS-THEN-BODY, i.e. if INBOUNDS == true
// CHECK: [[ARG_ELEM:%.*]] = load [[ARG_BUF]]{{\[}}[[ARG_I]], [[ARG_J]]]
// CHECK: [[ARG_ELEM:%.*]] = memref.load [[ARG_BUF]]{{\[}}[[ARG_I]], [[ARG_J]]]
// CHECK: [[IF_INIT_RES:%.*]]:4
// CHECK-SAME: = scf.if [[SEL_INIT]] -> (index, index, f32, i1) {
@@ -114,16 +114,16 @@ func @select_and_scatter(%arg: memref<112x112xf32>,
// Allocate buffers for ARG element, current selected value to adapt LHLO
// code.
// CHECK: [[ARG_ELEM_BUF:%.*]] = alloc() : memref<f32>
// CHECK: [[SEL_VAL_BUF:%.*]] = alloc() : memref<f32>
// CHECK: [[PRED_BUF:%.*]] = alloc() : memref<i1>
// CHECK: store [[ARG_ELEM]], [[ARG_ELEM_BUF]][] : memref<f32>
// CHECK: store [[SEL_VAL]], [[SEL_VAL_BUF]][] : memref<f32>
// CHECK: [[ARG_ELEM_BUF:%.*]] = memref.alloc() : memref<f32>
// CHECK: [[SEL_VAL_BUF:%.*]] = memref.alloc() : memref<f32>
// CHECK: [[PRED_BUF:%.*]] = memref.alloc() : memref<i1>
// CHECK: memref.store [[ARG_ELEM]], [[ARG_ELEM_BUF]][] : memref<f32>
// CHECK: memref.store [[SEL_VAL]], [[SEL_VAL_BUF]][] : memref<f32>
// Compute PRED.
// CHECK: "lmhlo.compare"(
// CHECK-SAME: [[ARG_ELEM_BUF]], [[SEL_VAL_BUF]], [[PRED_BUF]])
// CHECK: [[PRED:%.*]] = load [[PRED_BUF]][] : memref<i1>
// CHECK: [[PRED:%.*]] = memref.load [[PRED_BUF]][] : memref<i1>
// Depending on PRED, return ARG ivs & elem or current select ivs and value.
@@ -165,7 +165,7 @@ func @select_and_scatter(%arg: memref<112x112xf32>,
// CHECK: }
// Use selected ivs to load element from the SRC buffer.
// CHECK: [[SRC_ELEM:%.*]] = load [[SRC_BUF]]{{\[}}[[II]], [[JJ]]]
// CHECK: [[SRC_ELEM:%.*]] = memref.load [[SRC_BUF]]{{\[}}[[II]], [[JJ]]]
// Update of RESULT[SELECTED_I, SELECTED_J] should be done atomically, because
// it may happen that several other threads select the same IVs if the windows
@@ -175,16 +175,16 @@ func @select_and_scatter(%arg: memref<112x112xf32>,
// CHECK: ^bb0([[CUR_RES:%.*]]: f32):
// Allocate buffers for ARG element, current selected value to adapt LHLO code.
// CHECK: [[SRC_ELEM_BUF:%.*]] = alloc() : memref<f32>
// CHECK: [[CUR_RES_BUF:%.*]] = alloc() : memref<f32>
// CHECK: [[RES_BUF:%.*]] = alloc() : memref<f32>
// CHECK: store [[SRC_ELEM]], [[SRC_ELEM_BUF]][] : memref<f32>
// CHECK: store [[CUR_RES]], [[CUR_RES_BUF]][] : memref<f32>
// CHECK: [[SRC_ELEM_BUF:%.*]] = memref.alloc() : memref<f32>
// CHECK: [[CUR_RES_BUF:%.*]] = memref.alloc() : memref<f32>
// CHECK: [[RES_BUF:%.*]] = memref.alloc() : memref<f32>
// CHECK: memref.store [[SRC_ELEM]], [[SRC_ELEM_BUF]][] : memref<f32>
// CHECK: memref.store [[CUR_RES]], [[CUR_RES_BUF]][] : memref<f32>
// Compute scatter value.
// CHECK: "lmhlo.add"([[SRC_ELEM_BUF]], [[CUR_RES_BUF]], [[RES_BUF]]) :
// CHECK-SAME: (memref<f32>, memref<f32>, memref<f32>) -> ()
// CHECK: [[RES:%.*]] = load [[RES_BUF]][] : memref<f32>
// CHECK: [[RES:%.*]] = memref.load [[RES_BUF]][] : memref<f32>
// Atomic RMW terminator that returns updated value.
// CHECK: atomic_yield [[RES]] : f32
@@ -19,14 +19,14 @@ func @reduce(%arg: memref<100x10xf32>,
// CHECK-DAG: %[[C100:.*]] = constant 100 : index
// CHECK-DAG: %[[C1:.*]] = constant 1 : index
// CHECK: gpu.launch blocks({{.*}}, {{.*}}, {{.*}}) in ({{.*}} = %[[C1]], {{.*}} = %[[C1]], {{.*}} = %[[C1]]) threads(%[[IDX:.*]], {{.*}}, {{.*}}) in ({{.*}} = %[[C100]], {{.*}} = %[[C1]], {{.*}} = %[[C1]]) {
// CHECK: %[[ACC:.*]] = load %[[ARG1]][] : memref<f32>
// CHECK: %[[ACC:.*]] = memref.load %[[ARG1]][] : memref<f32>
// CHECK: store %[[ACC]], %[[ARG2]][%[[IDX:.*]]] : memref<100xf32>
// CHECK-DAG: %[[LB:.*]] = constant 0 : index
// CHECK-DAG: %[[UB:.*]] = constant 10 : index
// CHECK-DAG: %[[STEP:.*]] = constant 1 : index
// CHECK: scf.for %[[IDX1:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] {
// CHECK: %[[LHS:.*]] = subview %[[ARG2]][%[[IDX]]] [1] [1] : memref<100xf32> to memref<f32, #[[$MAP]]>
// CHECK: %[[RHS:.*]] = subview %[[ARG0]][%[[IDX]], %[[IDX1]]] [1, 1] [1, 1] : memref<100x10xf32> to memref<f32, #[[$MAP]]>
// CHECK: %[[LHS:.*]] = memref.subview %[[ARG2]][%[[IDX]]] [1] [1] : memref<100xf32> to memref<f32, #[[$MAP]]>
// CHECK: %[[RHS:.*]] = memref.subview %[[ARG0]][%[[IDX]], %[[IDX1]]] [1, 1] [1, 1] : memref<100x10xf32> to memref<f32, #[[$MAP]]>
// CHECK: "lmhlo.add"(%[[LHS]], %[[RHS]], %[[LHS]]) : (memref<f32, {{.*}}>, memref<f32, {{.*}}>, memref<f32, {{.*}}>) -> ()
// CHECK: }
// CHECK: gpu.terminator
@@ -52,10 +52,10 @@ func @element_wise_scalar(%lhs: memref<f32>, %rhs: memref<f32>,
: (memref<f32>, memref<f32>, memref<f32>) -> ()
return
}
// CHECK: %[[LHS:.*]] = load
// CHECK: %[[RHS:.*]] = load
// CHECK: %[[LHS:.*]] = memref.load
// CHECK: %[[RHS:.*]] = memref.load
// CHECK: %[[RES:.*]] = addf %[[LHS]], %[[RHS]]
// CHECK: store %[[RES]]
// CHECK: memref.store %[[RES]]
// CHECK-NEXT: return
// -----
|
@ -347,7 +347,7 @@ func @static_broadcast_in_dim_with_one_to_many(%operand: memref<1xf32>,
|
|||
}
|
||||
// CHECK-NOT: linalg.reshape
|
||||
// CHECK: %[[C0:.*]] = constant 0 : index
|
||||
// CHECK: %[[VALUE:.*]] = load %{{.*}}[[C0]]
|
||||
// CHECK: %[[VALUE:.*]] = memref.load %{{.*}}[[C0]]
|
||||
// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[RESULT_MAP]]]
|
||||
// CHECK-NEXT: ^bb0(%{{.+}}: f32):
|
||||
// CHECK-NEXT: linalg.yield %[[VALUE]] : f32
|
||||
|
|
@ -785,7 +785,7 @@ func @slice(%operand: memref<?x?xf32>, %result: memref<?x?xf32>) {
|
|||
} : (memref<?x?xf32>, memref<?x?xf32>) -> ()
|
||||
return
|
||||
}
|
||||
// CHECK: %[[RESULT:.*]] = subview %[[IN]][0, 1] [2, 2] [1, 1] : memref<?x?xf32> to memref<2x2xf32, #{{.*}}>
|
||||
// CHECK: %[[RESULT:.*]] = memref.subview %[[IN]][0, 1] [2, 2] [1, 1] : memref<?x?xf32> to memref<2x2xf32, #{{.*}}>
|
||||
// CHECK: linalg.copy(%[[RESULT]], %[[OUT]])
|
||||
|
||||
// -----
|
||||
|
|
@ -899,7 +899,7 @@ func @reverse(%arg0: memref<2x3xf32>, %arg1: memref<2x3xf32>) {
|
|||
|
||||
func @conv(%input: memref<3x5x5x3xf32>, %filter: memref<2x2x3x4xf32>, %output: memref<3x5x5x4xf32>) {
|
||||
%c0 = constant 0 : index
|
||||
%0 = alloc() : memref<3x5x5x4xf32>
|
||||
%0 = memref.alloc() : memref<3x5x5x4xf32>
|
||||
// CHECK: linalg.conv(%{{.+}}, %{{.+}}, %{{.+}})
|
||||
// CHECK-SAME: dilations = [1, 2]
|
||||
// CHECK-SAME: padding = dense<{{\[\[}}0, 1], [0, 1]]> : tensor<2x2xi64>
|
||||
|
|
@ -948,22 +948,22 @@ func @reduce_add(%arg: memref<100x10xf32>,
|
|||
: (memref<100x10xf32>, memref<f32>, memref<100xf32>) -> ()
|
||||
return
|
||||
}
|
||||
// CHECK: %[[INIT_VAL:.*]] = load %arg1[] : memref<f32>
|
||||
// CHECK: %[[INIT_VAL:.*]] = memref.load %arg1[] : memref<f32>
|
||||
// CHECK: linalg.fill(%arg2, %[[INIT_VAL]])
|
||||
// CHECK: linalg.generic {
|
||||
// CHECK-SAME: indexing_maps = [#[[REDUCE_INPUT_MAP]], #[[REDUCE_OUTPUT_MAP]]],
|
||||
// CHECK-SAME: iterator_types = ["parallel", "reduction"]}
|
||||
// CHECK-SAME: ins(%arg0 : memref<100x10xf32>) outs(%arg2 : memref<100xf32>) {
|
||||
// CHECK: alloca
|
||||
// CHECK-NEXT: alloca
|
||||
// CHECK-NEXT: alloca
|
||||
// CHECK-NEXT: store
|
||||
// CHECK-NEXT: store
|
||||
// CHECK-NEXT: load
|
||||
// CHECK-NEXT: load
|
||||
// CHECK: memref.alloca
|
||||
// CHECK-NEXT: memref.alloca
|
||||
// CHECK-NEXT: memref.alloca
|
||||
// CHECK-NEXT: memref.store
|
||||
// CHECK-NEXT: memref.store
|
||||
// CHECK-NEXT: memref.load
|
||||
// CHECK-NEXT: memref.load
|
||||
// CHECK-NEXT: addf
|
||||
// CHECK-NEXT: store
|
||||
// CHECK-NEXT: load
|
||||
// CHECK-NEXT: memref.store
|
||||
// CHECK-NEXT: memref.load
|
||||
// CHECK-NEXT: linalg.yield
|
||||
// CHECK-NEXT: }
|
||||
|
||||
|
|
@ -984,22 +984,22 @@ func @reduce_maximum(%arg: memref<100x10xf32>,
|
|||
: (memref<100x10xf32>, memref<f32>, memref<100xf32>) -> ()
|
||||
return
|
||||
}
|
||||
// CHECK: %[[INIT_VAL:.*]] = load %arg1[] : memref<f32>
|
||||
// CHECK: %[[INIT_VAL:.*]] = memref.load %arg1[] : memref<f32>
|
||||
// CHECK: linalg.fill(%arg2, %[[INIT_VAL]])
|
||||
// CHECK: linalg.generic {
|
||||
// CHECK-SAME: indexing_maps = [#[[REDUCE_INPUT_MAP]], #[[REDUCE_OUTPUT_MAP]]],
|
||||
// CHECK-SAME: iterator_types = ["parallel", "reduction"]}
|
||||
// CHECK-SAME: ins(%arg0 : memref<100x10xf32>) outs(%arg2 : memref<100xf32>) {
|
||||
// CHECK: alloca
|
||||
// CHECK-NEXT: alloca
|
||||
// CHECK-NEXT: alloca
|
||||
// CHECK-NEXT: store
|
||||
// CHECK-NEXT: store
|
||||
// CHECK-NEXT: load
|
||||
// CHECK-NEXT: load
|
||||
// CHECK: memref.alloca
|
||||
// CHECK-NEXT: memref.alloca
|
||||
// CHECK-NEXT: memref.alloca
|
||||
// CHECK-NEXT: memref.store
|
||||
// CHECK-NEXT: memref.store
|
||||
// CHECK-NEXT: memref.load
|
||||
// CHECK-NEXT: memref.load
|
||||
// CHECK: cmpf
|
||||
// CHECK: select
|
||||
// CHECK: store
|
||||
// CHECK-NEXT: load
|
||||
// CHECK: memref.store
|
||||
// CHECK-NEXT: memref.load
|
||||
// CHECK-NEXT: linalg.yield
|
||||
// CHECK-NEXT: }
|
||||
|
|
|
|||
|
|
@ -21,27 +21,27 @@ func @reduce(%arg: memref<100x10x5xf32>,
|
|||
// CHECK-DAG: [[C5:%.*]] = constant 5 : index
|
||||
// CHECK-DAG: [[C10:%.*]] = constant 10 : index
|
||||
// CHECK-DAG: [[C100:%.*]] = constant 100 : index
|
||||
// CHECK: [[INIT:%.*]] = load [[INIT_BUF]]
|
||||
// CHECK: [[INIT:%.*]] = memref.load [[INIT_BUF]]
|
||||
// CHECK: scf.parallel ([[I:%.*]], [[K:%.*]]) = ([[C0]], [[C0]])
|
||||
// CHECK-SAME: to ([[C100]], [[C5]]) step ([[C1]], [[C1]]) {
|
||||
// CHECK: [[REDUCTION_RESULT:%.*]] = scf.parallel ([[J:%.*]]) =
|
||||
// CHECK-SAME: ([[C0]]) to ([[C10]]) step ([[C1]]) init ([[INIT]]) -> f32 {
|
||||
// CHECK: [[ELEM_TO_REDUCE:%.*]] = load [[ARG_BUF]]
|
||||
// CHECK: [[ELEM_TO_REDUCE:%.*]] = memref.load [[ARG_BUF]]
|
||||
// CHECK-SAME: {{\[}}[[I]], [[J]], [[K]]] : memref<100x10x5xf32>
|
||||
// CHECK: scf.reduce([[ELEM_TO_REDUCE]]) : f32 {
|
||||
// CHECK: ^bb0([[ELEM:%.*]]: f32, [[ACC:%.*]]: f32):
|
||||
// CHECK: [[ELEM_BUF:%.*]] = alloc() : memref<f32>
|
||||
// CHECK: [[ACC_BUF:%.*]] = alloc() : memref<f32>
|
||||
// CHECK: [[ACC_OUT_BUF:%.*]] = alloc() : memref<f32>
|
||||
// CHECK: store [[ELEM]], [[ELEM_BUF]][] : memref<f32>
|
||||
// CHECK: store [[ACC]], [[ACC_BUF]][] : memref<f32>
|
||||
// CHECK: [[ELEM_BUF:%.*]] = memref.alloc() : memref<f32>
|
||||
// CHECK: [[ACC_BUF:%.*]] = memref.alloc() : memref<f32>
|
||||
// CHECK: [[ACC_OUT_BUF:%.*]] = memref.alloc() : memref<f32>
|
||||
// CHECK: memref.store [[ELEM]], [[ELEM_BUF]][] : memref<f32>
|
||||
// CHECK: memref.store [[ACC]], [[ACC_BUF]][] : memref<f32>
|
||||
// CHECK: "lmhlo.add"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_OUT_BUF]])
|
||||
// CHECK: [[ACC_RESULT:%.*]] = load [[ACC_OUT_BUF]][] : memref<f32>
|
||||
// CHECK: [[ACC_RESULT:%.*]] = memref.load [[ACC_OUT_BUF]][] : memref<f32>
|
||||
// CHECK: scf.reduce.return [[ACC_RESULT]] : f32
|
||||
// CHECK: }
|
||||
// CHECK: scf.yield
|
||||
// CHECK: }
|
||||
// CHECK: store [[REDUCTION_RESULT]], [[RESULT_BUF]]{{\[}}[[I]], [[K]]]
|
||||
// CHECK: memref.store [[REDUCTION_RESULT]], [[RESULT_BUF]]{{\[}}[[I]], [[K]]]
|
||||
// CHECK: scf.yield
|
||||
|
||||
// -----
|
||||
|
|
@ -65,23 +65,23 @@ func @reduce_no_outer_loop(%arg: memref<100xf32>,
|
|||
// CHECK-DAG: [[C0:%.*]] = constant 0 : index
|
||||
// CHECK-DAG: [[C1:%.*]] = constant 1 : index
|
||||
// CHECK-DAG: [[C100:%.*]] = constant 100 : index
|
||||
// CHECK: [[INIT:%.*]] = load [[INIT_BUF]]
|
||||
// CHECK: [[INIT:%.*]] = memref.load [[INIT_BUF]]
|
||||
// CHECK: [[REDUCTION_RESULT:%.*]] = scf.parallel ([[I:%.*]]) = ([[C0]])
|
||||
// CHECK-SAME: to ([[C100]]) step ([[C1]]) init ([[INIT]]) -> f32 {
|
||||
// CHECK: [[ELEM_TO_REDUCE:%.*]] = load [[ARG_BUF]]{{\[}}[[I]]{{\]}}
|
||||
// CHECK: [[ELEM_TO_REDUCE:%.*]] = memref.load [[ARG_BUF]]{{\[}}[[I]]{{\]}}
|
||||
// CHECK: scf.reduce([[ELEM_TO_REDUCE]]) : f32 {
|
||||
// CHECK: ^bb0([[ELEM:%.*]]: f32, [[ACC:%.*]]: f32):
|
||||
// CHECK: [[ELEM_BUF:%.*]] = alloc() : memref<f32>
|
||||
// CHECK: [[ACC_BUF:%.*]] = alloc() : memref<f32>
|
||||
// CHECK: [[ACC_OUT_BUF:%.*]] = alloc() : memref<f32>
|
||||
// CHECK: store [[ELEM]], [[ELEM_BUF]][] : memref<f32>
|
||||
// CHECK: store [[ACC]], [[ACC_BUF]][] : memref<f32>
|
||||
// CHECK: [[ELEM_BUF:%.*]] = memref.alloc() : memref<f32>
|
||||
// CHECK: [[ACC_BUF:%.*]] = memref.alloc() : memref<f32>
|
||||
// CHECK: [[ACC_OUT_BUF:%.*]] = memref.alloc() : memref<f32>
|
||||
// CHECK: memref.store [[ELEM]], [[ELEM_BUF]][] : memref<f32>
|
||||
// CHECK: memref.store [[ACC]], [[ACC_BUF]][] : memref<f32>
|
||||
// CHECK: "lmhlo.add"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_OUT_BUF]])
|
||||
// CHECK: [[ACC_RESULT:%.*]] = load [[ACC_OUT_BUF]][] : memref<f32>
|
||||
// CHECK: [[ACC_RESULT:%.*]] = memref.load [[ACC_OUT_BUF]][] : memref<f32>
|
||||
// CHECK: scf.reduce.return [[ACC_RESULT]]
|
||||
// CHECK: }
|
||||
// CHECK: scf.yield
|
||||
// CHECK: store [[REDUCTION_RESULT]], [[RESULT_BUF]]{{\[}}[[C0]]]
|
||||
// CHECK: memref.store [[REDUCTION_RESULT]], [[RESULT_BUF]]{{\[}}[[C0]]]
|
||||
|
||||
// -----
|
||||
|
||||
|
|
@ -104,30 +104,30 @@ func @dynamic_reduce(%arg: memref<?x?x?xf32>,
|
|||
// CHECK-DAG: [[C0:%.*]] = constant 0 : index
|
||||
// CHECK-DAG: [[C1:%.*]] = constant 1 : index
|
||||
// CHECK-DAG: [[C2:%.*]] = constant 2 : index
|
||||
// CHECK: [[DIM0:%.*]] = dim [[ARG_BUF]], [[C0]] : memref<?x?x?xf32>
|
||||
// CHECK: [[DIM1:%.*]] = dim [[ARG_BUF]], [[C1]] : memref<?x?x?xf32>
|
||||
// CHECK: [[DIM2:%.*]] = dim [[ARG_BUF]], [[C2]] : memref<?x?x?xf32>
|
||||
// CHECK: [[INIT:%.*]] = load [[INIT_BUF]]
|
||||
// CHECK: [[DIM0:%.*]] = memref.dim [[ARG_BUF]], [[C0]] : memref<?x?x?xf32>
|
||||
// CHECK: [[DIM1:%.*]] = memref.dim [[ARG_BUF]], [[C1]] : memref<?x?x?xf32>
|
||||
// CHECK: [[DIM2:%.*]] = memref.dim [[ARG_BUF]], [[C2]] : memref<?x?x?xf32>
|
||||
// CHECK: [[INIT:%.*]] = memref.load [[INIT_BUF]]
|
||||
// CHECK: scf.parallel ([[I:%.*]], [[K:%.*]]) = ([[C0]], [[C0]])
|
||||
// CHECK-SAME: to ([[DIM0]], [[DIM2]]) step ([[C1]], [[C1]]) {
|
||||
// CHECK: [[REDUCTION_RESULT:%.*]] = scf.parallel ([[J:%.*]]) =
|
||||
// CHECK-SAME: ([[C0]]) to ([[DIM1]]) step ([[C1]]) init ([[INIT]]) -> f32 {
|
||||
// CHECK: [[ELEM_TO_REDUCE:%.*]] = load [[ARG_BUF]]
|
||||
// CHECK: [[ELEM_TO_REDUCE:%.*]] = memref.load [[ARG_BUF]]
|
||||
// CHECK-SAME: {{\[}}[[I]], [[J]], [[K]]] : memref<?x?x?xf32>
|
||||
// CHECK: scf.reduce([[ELEM_TO_REDUCE]]) : f32 {
|
||||
// CHECK: ^bb0([[ELEM:%.*]]: f32, [[ACC:%.*]]: f32):
|
||||
// CHECK: [[ELEM_BUF:%.*]] = alloc() : memref<f32>
|
||||
// CHECK: [[ACC_BUF:%.*]] = alloc() : memref<f32>
|
||||
// CHECK: [[ACC_OUT_BUF:%.*]] = alloc() : memref<f32>
|
||||
// CHECK: store [[ELEM]], [[ELEM_BUF]][] : memref<f32>
|
||||
// CHECK: store [[ACC]], [[ACC_BUF]][] : memref<f32>
|
||||
// CHECK: [[ELEM_BUF:%.*]] = memref.alloc() : memref<f32>
|
||||
// CHECK: [[ACC_BUF:%.*]] = memref.alloc() : memref<f32>
|
||||
// CHECK: [[ACC_OUT_BUF:%.*]] = memref.alloc() : memref<f32>
|
||||
// CHECK: memref.store [[ELEM]], [[ELEM_BUF]][] : memref<f32>
|
||||
// CHECK: memref.store [[ACC]], [[ACC_BUF]][] : memref<f32>
|
||||
// CHECK: "lmhlo.add"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_OUT_BUF]])
|
||||
// CHECK: [[ACC_RESULT:%.*]] = load [[ACC_OUT_BUF]][] : memref<f32>
|
||||
// CHECK: [[ACC_RESULT:%.*]] = memref.load [[ACC_OUT_BUF]][] : memref<f32>
|
||||
// CHECK: scf.reduce.return [[ACC_RESULT]] : f32
|
||||
// CHECK: }
|
||||
// CHECK: scf.yield
|
||||
// CHECK: }
|
||||
// CHECK: store [[REDUCTION_RESULT]], [[RESULT_BUF]]{{\[}}[[I]], [[K]]]
|
||||
// CHECK: memref.store [[REDUCTION_RESULT]], [[RESULT_BUF]]{{\[}}[[I]], [[K]]]
|
||||
// CHECK: scf.yield
|
||||
|
||||
// -----
|
||||
|
|
@ -157,7 +157,7 @@ func @reduce_window(%arg: memref<112x112xf32>,
|
|||
// CHECK-DAG: [[C3:%.*]] = constant 3 : index
|
||||
// CHECK-DAG: [[C56:%.*]] = constant 56 : index
|
||||
// CHECK-DAG: [[C112:%.*]] = constant 112 : index
|
||||
// CHECK: [[INIT:%.*]] = load [[INIT_BUF]][] : memref<f32>
|
||||
// CHECK: [[INIT:%.*]] = memref.load [[INIT_BUF]][] : memref<f32>
|
||||
// CHECK: scf.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]])
|
||||
// CHECK-SAME: to ([[C56]], [[C56]]) step ([[C1]], [[C1]]) {
|
||||
// CHECK: [[REDUCTION_RESULT:%.*]] = scf.parallel
|
||||
|
|
@ -176,7 +176,7 @@ func @reduce_window(%arg: memref<112x112xf32>,
|
|||
|
||||
// CHECK: [[ELEM_TO_REDUCE:%.*]] = scf.if [[IN_BOUNDS_1]] -> (f32) {
|
||||
// CHECK: [[OPERAND_ELEM:%.*]] =
|
||||
// CHECK-SAME: load [[OPERAND_BUF]]{{\[}}[[INDEX_I]], [[INDEX_J]]]
|
||||
// CHECK-SAME: memref.load [[OPERAND_BUF]]{{\[}}[[INDEX_I]], [[INDEX_J]]]
|
||||
// CHECK: scf.yield [[OPERAND_ELEM]] : f32
|
||||
// CHECK: } else {
|
||||
// CHECK: scf.yield [[INIT]] : f32
|
||||
|
|
@ -184,18 +184,18 @@ func @reduce_window(%arg: memref<112x112xf32>,
|
|||
|
||||
// CHECK: scf.reduce([[ELEM_TO_REDUCE]]) : f32 {
|
||||
// CHECK: ^bb0([[ELEM:%.*]]: f32, [[ACC:%.*]]: f32):
|
||||
// CHECK: [[ELEM_BUF:%.*]] = alloc() : memref<f32>
|
||||
// CHECK: [[ACC_BUF:%.*]] = alloc() : memref<f32>
|
||||
// CHECK: [[ACC_OUT_BUF:%.*]] = alloc() : memref<f32>
|
||||
// CHECK: store [[ELEM]], [[ELEM_BUF]][] : memref<f32>
|
||||
// CHECK: store [[ACC]], [[ACC_BUF]][] : memref<f32>
|
||||
// CHECK: [[ELEM_BUF:%.*]] = memref.alloc() : memref<f32>
|
||||
// CHECK: [[ACC_BUF:%.*]] = memref.alloc() : memref<f32>
|
||||
// CHECK: [[ACC_OUT_BUF:%.*]] = memref.alloc() : memref<f32>
|
||||
// CHECK: memref.store [[ELEM]], [[ELEM_BUF]][] : memref<f32>
|
||||
// CHECK: memref.store [[ACC]], [[ACC_BUF]][] : memref<f32>
|
||||
// CHECK: "lmhlo.maximum"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_OUT_BUF]])
|
||||
// CHECK: [[ACC_RESULT:%.*]] = load [[ACC_OUT_BUF]][] : memref<f32>
|
||||
// CHECK: [[ACC_RESULT:%.*]] = memref.load [[ACC_OUT_BUF]][] : memref<f32>
|
||||
// CHECK: scf.reduce.return [[ACC_RESULT]] : f32
|
||||
// CHECK: }
|
||||
// CHECK: scf.yield
|
||||
// CHECK: }
|
||||
// CHECK: store [[REDUCTION_RESULT]], [[RESULT_BUF]]{{\[}}[[I]], [[J]]]
|
||||
// CHECK: memref.store [[REDUCTION_RESULT]], [[RESULT_BUF]]{{\[}}[[I]], [[J]]]
|
||||
// CHECK: scf.yield
|
||||
// CHECK: }
|
||||
// CHECK: return
|
||||
|
|
|
|||
|
|
@ -30,7 +30,7 @@ func @batch_norm_training_memrefs(%arg0: memref<8x8x8x8xf32>, %arg1: memref<8xf3
|
|||
|
||||
// CHECK-LABEL: func @conv_forward
|
||||
func @conv_forward(%input : memref<1x1x8x8xf16>, %filter: memref<1x1x2x2xf16>, %output: memref<1x1x7x7xf16>) {
|
||||
%scratch = alloc() : memref<32xi8>
|
||||
%scratch = memref.alloc() : memref<32xi8>
|
||||
// This defined a 2D convolution over a 8x8 single channel input using a 2x2
|
||||
// filter and with an output of 7x7xf16. The 1x1x8x8 is (N, C, H, W)
|
||||
"lmhlo_gpu.conv_forward"(%input, %filter, %output, %scratch)
|
||||
|
|
@ -61,7 +61,7 @@ func @conv_forward(%input : memref<1x1x8x8xf16>, %filter: memref<1x1x2x2xf16>, %
|
|||
|
||||
// CHECK-LABEL: func @conv_backfilter
|
||||
func @conv_backfilter(%input : memref<3x56x56x16xf64>, %filter: memref<3x3x3x64xf64>, %output: memref<54x54x16x64xf64>) {
|
||||
%scratch = alloc() : memref<23328xui8>
|
||||
%scratch = memref.alloc() : memref<23328xui8>
|
||||
"lmhlo_gpu.conv_backwardfilter"(%input, %filter, %output, %scratch)
|
||||
{ backend_config = {algorithm = 1 : i64,
|
||||
operand_0_layout = [3,2,1,0],
|
||||
|
|
@ -91,7 +91,7 @@ func @conv_backfilter(%input : memref<3x56x56x16xf64>, %filter: memref<3x3x3x64x
|
|||
|
||||
// CHECK-LABEL: func @conv_backinput
|
||||
func @conv_backinput(%input : memref<4x5x16x16xf64>, %filter : memref<5x3x7x7xf64>, %output : memref<4x3x16x16xf64>) {
|
||||
%scratch = alloc() : memref<32xui8>
|
||||
%scratch = memref.alloc() : memref<32xui8>
|
||||
"lmhlo_gpu.conv_backwardinput"(%input, %filter, %output, %scratch)
|
||||
{ backend_config = {algorithm = 1 : i64,
|
||||
operand_0_layout = [3,2,1,0],
|
||||
|
|
@ -122,7 +122,7 @@ func @conv_backinput(%input : memref<4x5x16x16xf64>, %filter : memref<5x3x7x7xf6
|
|||
|
||||
// CHECK-LABEL: func @conv_fused
|
||||
func @conv_fused(%input : memref<1x17x9x9xf16>, %filter : memref<3x3x17x32xf16>, %bias : memref<32xf16>, %output : memref<1x32x9x9xf16>) {
|
||||
%scratch = alloc() : memref<32xui8>
|
||||
%scratch = memref.alloc() : memref<32xui8>
|
||||
"lmhlo_gpu.conv_forward_fused"(%input, %filter, %bias, %output, %scratch)
|
||||
{activation_mode = "Relu",
|
||||
backend_config = {algorithm = 1 : i64,
|
||||
|
|
@ -153,7 +153,7 @@ func @conv_fused(%input : memref<1x17x9x9xf16>, %filter : memref<3x3x17x32xf16>,
|
|||
|
||||
// CHECK-LABEL: func @conv_fused_side_input
|
||||
func @conv_fused_side_input(%input : memref<1x17x9x9xf16>, %filter : memref<3x3x17x32xf16>, %bias : memref<32xf16>, %side_input: memref<32xf16>, %output : memref<1x32x9x9xf16>) {
|
||||
%scratch = alloc() : memref<0xui8>
|
||||
%scratch = memref.alloc() : memref<0xui8>
|
||||
"lmhlo_gpu.conv_forward_fused_with_side_input"(%input, %filter, %bias, %side_input, %output, %scratch)
|
||||
{activation_mode = "Relu",
|
||||
backend_config = {algorithm = 1 : i64,
|
||||
|
|
@ -218,8 +218,8 @@ func @gemm_bias(%lhs: memref<5x4xf32>, %rhs: memref<4x5xf32>,
|
|||
|
||||
// CHECK-LABEL: func @cholesky
|
||||
func @cholesky(%arg : memref<10x10xf32>, %out: memref<10x10xf32>) {
|
||||
%scratch = alloc() : memref<32xi8>
|
||||
%info = alloc() : memref<32xi32>
|
||||
%scratch = memref.alloc() : memref<32xi8>
|
||||
%info = memref.alloc() : memref<32xi32>
|
||||
"lmhlo_gpu.cholesky"(%arg, %out, %scratch, %info) { is_lower = true }
|
||||
: (memref<10x10xf32>, memref<10x10xf32>, memref<32xi8>, memref<32xi32>) -> ()
|
||||
return
|
||||
|
|
|
|||
|
|
@ -457,12 +457,12 @@ func @reduce_memref(%input: memref<10xf32>, %init: memref<f32>, %out: memref<1xf
|
|||
// CHECK-LABEL: func @fusion_memref
|
||||
func @fusion_memref(%input1: memref<10xf32>, %input2: memref<10xf32>, %input3: memref<10xf32>, %out: memref<10xf32>) -> () {
|
||||
"lmhlo.fusion"() ( {
|
||||
%0 = tensor_load %input1 : memref<10xf32>
|
||||
%1 = tensor_load %input2 : memref<10xf32>
|
||||
%0 = memref.tensor_load %input1 : memref<10xf32>
|
||||
%1 = memref.tensor_load %input2 : memref<10xf32>
|
||||
%2 = "mhlo.add"(%0, %1) {name = "add"} : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>
|
||||
%3 = tensor_load %input3 : memref<10xf32>
|
||||
%3 = memref.tensor_load %input3 : memref<10xf32>
|
||||
%4 = "mhlo.multiply"(%2, %3) {name = "multiply"} : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>
|
||||
tensor_store %4, %out : memref<10xf32>
|
||||
memref.tensor_store %4, %out : memref<10xf32>
|
||||
"lmhlo.terminator"() : () -> ()
|
||||
} ) : () -> ()
|
||||
return
|
||||
|
|
|
|||
|
|
@ -953,7 +953,7 @@ func @tuple_token(%arg0: tensor<f32>, %arg1: !mhlo.token) -> tuple<tensor<f32>,
|
|||
// -----
|
||||
|
||||
func @tuple_arg_size_mismatch(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tuple<tensor<f32>, tensor<f32>, tensor<f32>> {
|
||||
// expected-error@+1 {{has return type tuple<tensor<f32>, tensor<f32>, tensor<f32>>, but expected tuple<tensor<f32>, tensor<f32>>}}
|
||||
// expected-error@+1 {{number of operands to tuple expected to match number of types}}
|
||||
%0 = "mhlo.tuple"(%arg0, %arg1) : (tensor<f32>, tensor<f32>) -> tuple<tensor<f32>, tensor<f32>, tensor<f32>>
|
||||
return %0 : tuple<tensor<f32>, tensor<f32>, tensor<f32>>
|
||||
}
|
||||
|
|
@ -961,7 +961,7 @@ func @tuple_arg_size_mismatch(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tuple<t
|
|||
// -----
|
||||
|
||||
func @tuple_type_mismatch(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tuple<tensor<f32>, tensor<i32>> {
|
||||
// expected-error@+1 {{has return type tuple<tensor<f32>, tensor<i32>>, but expected tuple<tensor<f32>, tensor<f32>>}}
|
||||
// expected-error@+1 {{op has return type mismatch at 1th value}}
|
||||
%0 = "mhlo.tuple"(%arg0, %arg1) : (tensor<f32>, tensor<f32>) -> tuple<tensor<f32>, tensor<i32>>
|
||||
return %0 : tuple<tensor<f32>, tensor<i32>>
|
||||
}
|
||||
|
|
|
|||
|
|
@ -108,15 +108,15 @@ func @batchNormInference_dynamic_shape(
|
|||
// CHECK-DAG: %[[C2:.*]] = constant 2 : index
|
||||
// CHECK-DAG: %[[C3:.*]] = constant 3 : index
|
||||
// CHECK-DAG: %[[EPS:.+]] = mhlo.constant dense<1.000000e-03> : tensor<f32>
|
||||
// CHECK-DAG: %[[DIM:.+]] = dim %[[VARIANCE]], %[[C0]] : tensor<?xf32>
|
||||
// CHECK-DAG: %[[DIM:.+]] = memref.dim %[[VARIANCE]], %[[C0]] : tensor<?xf32>
|
||||
// CHECK-DAG: %[[TO_DIM_TENSOR:.+]] = tensor.from_elements %[[DIM]] : tensor<1xindex>
|
||||
// CHECK-DAG: %[[EPS_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[EPS]], %[[TO_DIM_TENSOR]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>, tensor<1xindex>) -> tensor<?xf32>
|
||||
// CHECK-DAG: %[[VARIANCE_EPS:.+]] = mhlo.add %[[VARIANCE]], %[[EPS_BCAST]] : tensor<?xf32>
|
||||
// CHECK-DAG: %[[STDDEV:.+]] = "mhlo.sqrt"(%[[VARIANCE_EPS]]) : (tensor<?xf32>) -> tensor<?xf32>
|
||||
// CHECK-DAG: %[[INPUT_DIM_0:.+]] = dim %[[X]], %[[C0]] : tensor<?x?x?x?xf32>
|
||||
// CHECK-DAG: %[[INPUT_DIM_1:.+]] = dim %[[X]], %[[C1]] : tensor<?x?x?x?xf32>
|
||||
// CHECK-DAG: %[[INPUT_DIM_2:.+]] = dim %[[X]], %[[C2]] : tensor<?x?x?x?xf32>
|
||||
// CHECK-DAG: %[[INPUT_DIM_3:.+]] = dim %[[X]], %[[C3]] : tensor<?x?x?x?xf32>
|
||||
// CHECK-DAG: %[[INPUT_DIM_0:.+]] = memref.dim %[[X]], %[[C0]] : tensor<?x?x?x?xf32>
|
||||
// CHECK-DAG: %[[INPUT_DIM_1:.+]] = memref.dim %[[X]], %[[C1]] : tensor<?x?x?x?xf32>
|
||||
// CHECK-DAG: %[[INPUT_DIM_2:.+]] = memref.dim %[[X]], %[[C2]] : tensor<?x?x?x?xf32>
|
||||
// CHECK-DAG: %[[INPUT_DIM_3:.+]] = memref.dim %[[X]], %[[C3]] : tensor<?x?x?x?xf32>
|
||||
// CHECK-DAG: %[[TO_INPUT_DIM_TENSOR:.+]] = tensor.from_elements %[[INPUT_DIM_0]], %[[INPUT_DIM_1]], %[[INPUT_DIM_2]], %[[INPUT_DIM_3]] : tensor<4xindex>
|
||||
// CHECK-DAG: %[[STDDEV_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[STDDEV]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
|
||||
// CHECK-DAG: %[[SCALE_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[SCALE]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
|
||||
|
|
|
|||
|
|
@ -1452,22 +1452,6 @@ Optional<BufferOffset<tflite::SubGraph>> Translator::BuildSubGraph(
|
|||
results.push_back(tensor_index_map.lookup(result));
|
||||
}
|
||||
Operation* real_inst = &inst;
|
||||
// CustomTfOp is just a wrapper around a TF op, we export the custom Op
|
||||
// not the wrapper, so we fetch the op from the region.
|
||||
if (auto custom_op = dyn_cast<mlir::TFL::CustomTfOp>(inst)) {
|
||||
// If we have custom op with a region, then use the first op in the
|
||||
// region, if it exists, otherwise just use params for custom op.
|
||||
if (!custom_op.body().empty()) {
|
||||
real_inst = &custom_op.body().front().front();
|
||||
// Use the inputs of the wrapper to reset the inputs.
|
||||
for (auto idx_op : llvm::enumerate(custom_op->getOperands())) {
|
||||
real_inst->setOperand(idx_op.index(), idx_op.value());
|
||||
}
|
||||
} else {
|
||||
module_.emitError(
|
||||
"Invalid CustomTfOp: Custom TF Op have empty region.");
|
||||
}
|
||||
}
|
||||
std::vector<int32_t> operands;
|
||||
operands.reserve(real_inst->getNumOperands());
|
||||
for (auto operand : real_inst->getOperands()) {
|
||||
|
|
@ -1481,6 +1465,19 @@ Optional<BufferOffset<tflite::SubGraph>> Translator::BuildSubGraph(
|
|||
operands.push_back(tensor_index_map.lookup(operand));
|
||||
}
|
||||
|
||||
// CustomTfOp is just a wrapper around a TF op, we export the custom Op
|
||||
// not the wrapper, so we fetch the op from the region.
|
||||
if (auto custom_op = dyn_cast<mlir::TFL::CustomTfOp>(inst)) {
|
||||
// If we have custom op with a region, then use the first op in the
|
||||
// region, if it exists, otherwise just use params for custom op.
|
||||
if (!custom_op.body().empty()) {
|
||||
real_inst = &custom_op.body().front().front();
|
||||
} else {
|
||||
module_.emitError(
|
||||
"Invalid CustomTfOp: Custom TF Op have empty region.");
|
||||
}
|
||||
}
|
||||
|
||||
if (auto tfl_operator =
|
||||
BuildOperator(real_inst, operands, results, intermediates))
|
||||
operators.push_back(*tfl_operator);
|
||||
|
|
|
|||
|
|
@@ -1067,7 +1067,8 @@ void DepthwiseConv2DOp::getCanonicalizationPatterns(
//===----------------------------------------------------------------------===//

static void BuildGatherOp(OpBuilder *builder, OperationState &result,
                          Value params, Value indices, IntegerAttr axis) {
                          Value params, Value indices, IntegerAttr axis,
                          IntegerAttr batch_dims) {
  auto params_type = params.getType().cast<TensorType>();
  auto indices_type = indices.getType().cast<TensorType>();
@@ -1075,7 +1076,7 @@ static void BuildGatherOp(OpBuilder *builder, OperationState &result,
  if (!params_type.hasRank() || !indices_type.hasRank())
    return TFL::GatherOp::build(
        *builder, result, UnrankedTensorType::get(params_type.getElementType()),
        params, indices, axis);
        params, indices, axis, batch_dims);

  int64_t params_rank = params_type.getRank();
  int64_t indices_rank = indices_type.getRank();
@@ -1096,7 +1097,29 @@ static void BuildGatherOp(OpBuilder *builder, OperationState &result,
    emitError(result.location, "params must be at least rank axis + 1");
  }

  if (indices_rank == 0) {
  int64_t batch_dims_i = batch_dims.getInt();
  if (batch_dims_i < 0) {
    batch_dims_i += indices_rank;
  }

  if (batch_dims_i > axis_i) {
    emitError(result.location,
              "axis should be bigger than or equal to batch_dims");
  }
  if (batch_dims_i >= params_rank || batch_dims_i > indices_rank) {
    emitError(result.location,
              "batch_dims must be smaller than params' rank and smaller than "
              "or equal to indices'rank");
  }
  for (int i = 0; i < batch_dims_i; ++i) {
    if (indices_type.getShape()[i] != params_type.getShape()[i]) {
      emitError(result.location,
                "batch dimensions of params must be equal to batch dimensions "
                "of indices");
    }
  }

  if ((indices_rank == 0) || (indices_rank == batch_dims_i)) {
    // Scalar indices (output is rank(params) - 1).
    // Erase shape[axis]
    shape.erase(shape.begin() + axis_i);
@@ -1107,21 +1130,21 @@ static void BuildGatherOp(OpBuilder *builder, OperationState &result,
              std::end(indices_type.getShape()), std::begin(shape) + axis_i);
  } else {
    // Higher rank indices (output is rank(params) + rank(indices) - 1).
    shape.resize(params_rank + indices_rank - 1);
    shape.resize(params_rank + indices_rank - 1 - batch_dims_i);
    // Copy params.shape[axis + 1: ] into shape[axis + indices_rank:]
    std::copy(std::begin(params_type.getShape()) + axis_i + 1,
              std::end(params_type.getShape()),
              std::begin(shape) + axis_i + indices_rank);
              std::begin(shape) + axis_i + indices_rank - batch_dims_i);

    // Copy indices.shape into params.shape[axis]
    std::copy(std::begin(indices_type.getShape()),
    std::copy(std::begin(indices_type.getShape()) + batch_dims_i,
              std::end(indices_type.getShape()), std::begin(shape) + axis_i);
  }

  TFL::GatherOp::build(
      *builder, result,
      RankedTensorType::get(shape, params_type.getElementType()), params,
      indices, axis);
      indices, axis, batch_dims);
}

//===----------------------------------------------------------------------===//
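For readers following the shape logic above, the following C++ sketch restates the result-shape rule that the new batch_dims handling implements for the higher-rank-indices branch. It is an illustration only: ComputeGatherShape, the plain std::vector types, and the small driver are mine, not part of the TensorFlow sources.

// Sketch of the tfl.gather result-shape rule with batch_dims, mirroring the
// builder logic above (higher-rank-indices branch only).
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

std::vector<int64_t> ComputeGatherShape(const std::vector<int64_t>& params,
                                        const std::vector<int64_t>& indices,
                                        int64_t axis, int64_t batch_dims) {
  // Result rank = rank(params) + rank(indices) - 1 - batch_dims.
  std::vector<int64_t> shape(params.begin(), params.end());
  shape.resize(params.size() + indices.size() - 1 - batch_dims);
  // params.shape[axis + 1:] lands after the gathered indices dimensions.
  std::copy(params.begin() + axis + 1, params.end(),
            shape.begin() + axis + indices.size() - batch_dims);
  // indices.shape[batch_dims:] replaces params.shape[axis].
  std::copy(indices.begin() + batch_dims, indices.end(),
            shape.begin() + axis);
  return shape;
}

int main() {
  // params 4x10x8, indices 4x6, axis = 1, batch_dims = 1  ->  result 4x6x8.
  for (int64_t d : ComputeGatherShape({4, 10, 8}, {4, 6}, 1, 1))
    std::cout << d << ' ';
  std::cout << '\n';
}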
@@ -1057,13 +1057,14 @@ def TFL_GatherOp : TFL_Op<"gather", [
  let arguments = (ins
    TFL_TensorOf<[F32, I1, I8, I32, I64, TFL_Str, UI8, QI8, QUI8, QI16]>:$params,
    TFL_TensorOf<[I32, I64]>:$indices,
    I32Attr:$axis
    I32Attr:$axis,
    DefaultValuedAttr<I32Attr, "0">:$batch_dims
  );

  let builders =
  [
    OpBuilder<(ins "Value":$params, "Value":$indices, "IntegerAttr":$axis),
    [{ BuildGatherOp(&$_builder, $_state, params, indices, axis); }]>
    OpBuilder<(ins "Value":$params, "Value":$indices, "IntegerAttr":$axis, "IntegerAttr":$batch_dims),
    [{ BuildGatherOp(&$_builder, $_state, params, indices, axis, batch_dims); }]>
  ];

  let results = (outs
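As a usage note, the extra builder above lets C++ callers create a gather with batch_dims directly. A minimal sketch follows; the wrapper function, its arguments, and the concrete attribute values are illustrative and not part of this change.

// Hypothetical call site: build a tfl.gather with axis = 1 and batch_dims = 1
// from existing SSA values; dispatches to BuildGatherOp above, which also
// infers the result type.
#include "mlir/IR/Builders.h"
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"

mlir::Value CreateBatchedGather(mlir::OpBuilder& builder, mlir::Location loc,
                                mlir::Value params, mlir::Value indices) {
  auto axis = builder.getI32IntegerAttr(1);
  auto batch_dims = builder.getI32IntegerAttr(1);
  return builder.create<mlir::TFL::GatherOp>(loc, params, indices, axis,
                                             batch_dims);
}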
@ -115,6 +115,7 @@ cc_library(
|
|||
"//tensorflow/lite/kernels/internal:tensor_utils",
|
||||
"//tensorflow/lite/tools/optimize:quantization_utils",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@llvm-project//llvm:Support",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:QuantOps",
|
||||
|
|
|
|||
|
|
@ -55,6 +55,7 @@ cc_library(
|
|||
deps = [
|
||||
"//tensorflow/compiler/mlir/lite:tensorflow_lite",
|
||||
"//tensorflow/compiler/mlir/lite/quantization:quantization_lib",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@llvm-project//llvm:Support",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:QuantOps",
|
||||
|
|
|
|||
|
|
@@ -512,7 +512,7 @@ bool QuantizationDriver::SetOperandParams(Operation *op, int index,

void QuantizationDriver::QuantizeOpResult(Operation *op, int index,
                                          QuantParams params) {
  builder_.setInsertionPoint(op->getBlock(), ++Block::iterator(op));
  builder_.setInsertionPointAfter(op);
  Value original_result = op->getResult(index);
  QuantizeValue(original_result, params, op->getLoc());
}
@@ -741,10 +741,9 @@ void QuantizationDriver::SetupAllStates() {
  }

  fn_.walk([&](Operation *op) {
    if (op->hasTrait<OpTrait::IsTerminator>() ||
        op->hasTrait<OpTrait::quant::NoQuantizableResult>() ||
        llvm::isa<quant::DequantizeCastOp, quant::QuantizeCastOp>(op))
    if (IsOpNotQuantizable(op)) {
      return;
    }
    work_list_.push_back(op);

    for (int i = 0, e = op->getNumOperands(); i != e; ++i) {
@@ -51,6 +51,25 @@ namespace quant {
constexpr double kNearZeroTolerance = 1.0e-6;
constexpr double kSmallestHalfRange = kNearZeroTolerance / 2;

const char kQuantTraitAttr[] = "_tfl_quant_trait";
const absl::string_view QuantTraitValues[] = {"fully_quantizable",
                                              "not_quantizable"};

bool IsOpNotQuantizable(Operation* op) {
  // If it is terminator or not quantizable or any ops form the mlir quant
  // ops dialect, we shouldn't rewrite.
  bool attr_enforced_quantizable =
      op->hasAttrOfType<StringAttr>(kQuantTraitAttr) &&
      op->getAttrOfType<StringAttr>(kQuantTraitAttr).getValue().str() ==
          QuantTraitValues[QuantizationTrait::FullyQuantizable];
  bool prop_enforced_no_quantizable =
      op->hasTrait<OpTrait::quant::NoQuantizableResult>();

  return op->hasTrait<OpTrait::IsTerminator>() ||
         llvm::isa<quant::QuantizeCastOp, quant::DequantizeCastOp>(op) ||
         (!attr_enforced_quantizable && prop_enforced_no_quantizable);
}

// This method expands the range to be larger than or equal to 1.0e-6, if it is
// very small (< 1.0e-6). This is to prevent very large quantized value by this
// range.
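The new helper keys off both the NoQuantizableResult trait and the _tfl_quant_trait string attribute, so an op can opt back in by carrying the attribute. The short C++ sketch below shows the intended shape of a call site; MarkFullyQuantizable and MaybeRewrite are hypothetical names, while the attribute key, the "fully_quantizable" value, and the guard mirror what this diff introduces.

// Illustrative only: tag an op as fully quantizable and guard a rewrite with
// IsOpNotQuantizable(), mirroring the call sites changed elsewhere in this diff.
#include "mlir/IR/Builders.h"
#include "mlir/IR/Operation.h"
#include "mlir/Support/LogicalResult.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"

void MarkFullyQuantizable(mlir::Operation* op, mlir::OpBuilder& builder) {
  // kQuantTraitAttr is "_tfl_quant_trait"; "fully_quantizable" matches
  // QuantTraitValues[FullyQuantizable] above.
  op->setAttr(mlir::quant::kQuantTraitAttr,
              builder.getStringAttr("fully_quantizable"));
}

mlir::LogicalResult MaybeRewrite(mlir::Operation* op) {
  if (mlir::quant::IsOpNotQuantizable(op)) return mlir::failure();
  // ... perform the quantization rewrite here ...
  return mlir::success();
}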
@ -22,9 +22,11 @@ limitations under the License.
|
|||
#include <string>
|
||||
#include <unordered_map>
|
||||
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/ADT/Twine.h"
|
||||
#include "llvm/Support/Casting.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
#include "mlir/Dialect/Quant/FakeQuantSupport.h" // from @llvm-project
|
||||
#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project
|
||||
|
|
@ -32,6 +34,7 @@ limitations under the License.
|
|||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
|
||||
#include "mlir/IR/Attributes.h" // from @llvm-project
|
||||
#include "mlir/IR/BlockAndValueMapping.h" // from @llvm-project
|
||||
#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
|
||||
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
|
||||
#include "mlir/IR/BuiltinTypes.h" // from @llvm-project
|
||||
#include "mlir/IR/MLIRContext.h" // from @llvm-project
|
||||
|
|
@@ -49,6 +52,10 @@ namespace quant {
// losing accuracy.
constexpr char kVolatileOpAttrName[] = "volatile";

enum QuantizationTrait { FullyQuantizable, NotQuantizable };
extern const char kQuantTraitAttr[];
extern const absl::string_view QuantTraitValues[];

using QuantParams = quant::QuantizedType;
using SignedInteger = std::pair<unsigned, unsigned>;  // bitwidth and sign
using QuantParamsForResults = llvm::SmallVector<QuantParams, 4>;
@@ -91,6 +98,8 @@ QuantizedType DownCastScale(QuantizedType type,
QuantizedType DownCastScale(QuantizedType type, double min, double max,
                            Location loc);

bool IsOpNotQuantizable(Operation* op);

template <typename Q, typename DQ>
struct ConvertStatsToQDQs : public OpRewritePattern<quant::StatisticsOp> {
  ConvertStatsToQDQs(int num_bits, bool narrow_range, bool is_signed,
@@ -227,10 +236,7 @@ struct QuantizationPattern : public RewritePattern {

    // If it is terminator or not quantizable or any ops form the mlir quant
    // ops dialect, we shouldn't rewrite.
    if (quantized_op->hasTrait<OpTrait::IsTerminator>() ||
        quantized_op->hasTrait<OpTrait::quant::NoQuantizableResult>() ||
        llvm::isa<quant::QuantizeCastOp, quant::DequantizeCastOp>(
            quantized_op)) {
    if (IsOpNotQuantizable(quantized_op)) {
      return failure();
    }

@@ -306,8 +312,9 @@ struct QuantizationPattern : public RewritePattern {
      if (quantized_op->getNumRegions() != 0) {
        for (auto indexed_regions :
             llvm::enumerate(quantized_op->getRegions())) {
          new_op->getRegion(indexed_regions.index())
              .takeBody(indexed_regions.value());
          Region& target_region = new_op->getRegion(indexed_regions.index());
          BlockAndValueMapping mapping;
          indexed_regions.value().cloneInto(&target_region, mapping);
        }
      }
      for (auto output : outputs_replaced) {
@ -55,3 +55,40 @@ module attributes {tf_saved_model.semantics} {
|
|||
// CHECK: "tf.ReadVariableOp"(%arg1) : (tensor<!tf.resource<tensor<1x10xf32>>>) -> tensor<1x10xf32>
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Test for func with no bound_input.
|
||||
module attributes {tf_saved_model.semantics} {
|
||||
"tf_saved_model.global_tensor"() {is_mutable, sym_name = "Variable", type = tensor<1x10xf32>, value = dense<0.000000e+00> : tensor<1x10xf32>} : () -> ()
|
||||
func @serving_default(%arg0: tensor<1x10xf32> {tf_saved_model.index_path = ["x"]}, %arg1: tensor<!tf.resource<tensor<1x10xf32>>>{tf_saved_model.bound_input = @Variable}) ->
|
||||
(tensor<1x10xf32> {tf_saved_model.index_path = ["output_0"]}) attributes {tf.entry_function = {control_outputs = "", inputs = "serving_default_x:0", outputs = "StatefulPartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} {
|
||||
%0 = "tf.ReadVariableOp"(%arg1) : (tensor<!tf.resource<tensor<1x10xf32>>>) -> tensor<1x10xf32>
|
||||
%1 = tfl.add %0, %arg0 {fused_activation_function = "NONE"} : tensor<1x10xf32>
|
||||
"tf.AssignVariableOp"(%arg1, %1) : (tensor<!tf.resource<tensor<1x10xf32>>>, tensor<1x10xf32>) -> ()
|
||||
%2 = "tf.ReadVariableOp"(%arg1) {device = ""} : (tensor<!tf.resource<tensor<1x10xf32>>>) -> tensor<1x10xf32>
|
||||
return %2 : tensor<1x10xf32>
|
||||
}
|
||||
|
||||
func private @"FuncWithNoBoundInput"(%arg0: tensor<1x10xf32>, %arg1: tensor<!tf.resource<tensor<1x10xf32>>>) -> tensor<1x10xf32> {
|
||||
"tf.AssignVariableOp"(%arg1, %arg0) {device = ""} : (tensor<!tf.resource<tensor<1x10xf32>>>, tensor<1x10xf32>) -> ()
|
||||
%0 = "tf.ReadVariableOp"(%arg1) : (tensor<!tf.resource<tensor<1x10xf32>>>) -> tensor<1x10xf32>
|
||||
return %0 : tensor<1x10xf32>
|
||||
}
|
||||
|
||||
// CHECK: func @SessionInitializerFunction() attributes {tf_saved_model.exported_names = ["SessionInitializerFunction"]} {
|
||||
// CHECK: %[[RESOURCE:.*]] = "tfl.pseudo_const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
|
||||
// CHECK: %[[VAL:.*]] = "tfl.pseudo_const"() {value = dense<0.000000e+00> : tensor<1x10xf32>} : () -> tensor<1x10xf32>
|
||||
// CHECK: "tfl.assign_variable"(%[[RESOURCE]], %[[VAL]]) : (tensor<1xi32>, tensor<1x10xf32>) -> ()
|
||||
// CHECK: return
|
||||
// CHECK: }
|
||||
// CHECK: "tf_saved_model.session_initializer"() {initializers = [@SessionInitializerFunction]} : () -> ()
|
||||
// CHECK: func @serving_default(%arg0: tensor<1x10xf32> {tf_saved_model.index_path = ["x"]}
|
||||
// CHECK: "tf.ReadVariableOp"(%arg1) : (tensor<!tf.resource<tensor<1x10xf32>>>) -> tensor<1x10xf32>
|
||||
//
|
||||
// CHECK: func private @FuncWithNoBoundInput(%arg0: tensor<1x10xf32>, %arg1: tensor<!tf.resource<tensor<1x10xf32>>>) -> tensor<1x10xf32> {
|
||||
// CHECK: "tf.AssignVariableOp"(%arg1, %arg0) {device = ""} : (tensor<!tf.resource<tensor<1x10xf32>>>, tensor<1x10xf32>) -> ()
|
||||
// CHECK: %0 = "tf.ReadVariableOp"(%arg1) : (tensor<!tf.resource<tensor<1x10xf32>>>) -> tensor<1x10xf32>
|
||||
// CHECK: return %0 : tensor<1x10xf32>
|
||||
// CHECK: }
|
||||
}
|
||||
|
|
|
|||
|
|
@ -389,7 +389,7 @@ func @gatherScalarIndices(%arg0 : tensor<3x2xf32>, %arg1 : tensor<i32>) -> tenso
|
|||
return %0 : tensor<2xf32>
|
||||
|
||||
// CHECK-LABEL:gatherScalarIndices
|
||||
// CHECK: "tfl.gather"(%arg0, %arg1) {axis = 0 : i32} : (tensor<3x2xf32>, tensor<i32>) -> tensor<2xf32>
|
||||
// CHECK: "tfl.gather"(%arg0, %arg1) {axis = 0 : i32, batch_dims = 0 : i32} : (tensor<3x2xf32>, tensor<i32>) -> tensor<2xf32>
|
||||
}
|
||||
|
||||
func @gatherVectorIndices(%arg0 : tensor<2xf32>, %arg1 : tensor<3xi32>) -> tensor<3xf32> {
|
||||
|
|
@ -397,7 +397,7 @@ func @gatherVectorIndices(%arg0 : tensor<2xf32>, %arg1 : tensor<3xi32>) -> tenso
|
|||
return %0 : tensor<3xf32>
|
||||
|
||||
// CHECK-LABEL:gatherVectorIndices
|
||||
// CHECK: "tfl.gather"(%arg0, %arg1) {axis = 0 : i32} : (tensor<2xf32>, tensor<3xi32>) -> tensor<3xf32>
|
||||
// CHECK: "tfl.gather"(%arg0, %arg1) {axis = 0 : i32, batch_dims = 0 : i32} : (tensor<2xf32>, tensor<3xi32>) -> tensor<3xf32>
|
||||
}
|
||||
|
||||
func @gatherHigherRankIndices(%arg0 : tensor<2x3x6xf32>, %arg1 : tensor<4x5xi32>) -> tensor<4x5x3x6xf32> {
|
||||
|
|
@ -405,7 +405,7 @@ func @gatherHigherRankIndices(%arg0 : tensor<2x3x6xf32>, %arg1 : tensor<4x5xi32>
|
|||
return %0 : tensor<4x5x3x6xf32>
|
||||
|
||||
// CHECK-LABEL:gatherHigherRankIndices
|
||||
// CHECK: "tfl.gather"(%arg0, %arg1) {axis = 0 : i32} : (tensor<2x3x6xf32>, tensor<4x5xi32>) -> tensor<4x5x3x6xf32>
|
||||
// CHECK: "tfl.gather"(%arg0, %arg1) {axis = 0 : i32, batch_dims = 0 : i32} : (tensor<2x3x6xf32>, tensor<4x5xi32>) -> tensor<4x5x3x6xf32>
|
||||
}
|
||||
|
||||
func @gatherNdVectorIndices(%arg0 : tensor<3x2x2xf32>, %arg1 : tensor<2xi32>) -> tensor<2xf32> {
|
||||
|
|
@ -460,7 +460,7 @@ func @gatherV2VectorIndices(%arg0 : tensor<1x2x20xf32>, %arg1 : tensor<3x5xi32>)
|
|||
return %1 : tensor<1x3x5x20xf32>
|
||||
|
||||
// CHECK-LABEL:gatherV2VectorIndices
|
||||
// CHECK: "tfl.gather"(%arg0, %arg1) {axis = 1 : i32} : (tensor<1x2x20xf32>, tensor<3x5xi32>) -> tensor<1x3x5x20xf32>
|
||||
// CHECK: "tfl.gather"(%arg0, %arg1) {axis = 1 : i32, batch_dims = 0 : i32} : (tensor<1x2x20xf32>, tensor<3x5xi32>) -> tensor<1x3x5x20xf32>
|
||||
}
|
||||
|
||||
func @gatherV2VectorIndices_I64Axis(%arg0 : tensor<1x2x20xf32>, %arg1 : tensor<3x5xi32>) -> tensor<1x3x5x20xf32> {
|
||||
|
|
@ -469,7 +469,7 @@ func @gatherV2VectorIndices_I64Axis(%arg0 : tensor<1x2x20xf32>, %arg1 : tensor<3
|
|||
return %1 : tensor<1x3x5x20xf32>
|
||||
|
||||
// CHECK-LABEL:gatherV2VectorIndices_I64Axis
|
||||
// CHECK: "tfl.gather"(%arg0, %arg1) {axis = 1 : i32} : (tensor<1x2x20xf32>, tensor<3x5xi32>) -> tensor<1x3x5x20xf32>
|
||||
// CHECK: "tfl.gather"(%arg0, %arg1) {axis = 1 : i32, batch_dims = 0 : i32} : (tensor<1x2x20xf32>, tensor<3x5xi32>) -> tensor<1x3x5x20xf32>
|
||||
}
|
||||
|
||||
func @gatherV2VectorIndicesNegAxis(%arg0 : tensor<1x2x20xf32>, %arg1 : tensor<3x5xi32>) -> tensor<1x2x3x5xf32> {
|
||||
|
|
@ -478,18 +478,20 @@ func @gatherV2VectorIndicesNegAxis(%arg0 : tensor<1x2x20xf32>, %arg1 : tensor<3x
|
|||
return %1 : tensor<1x2x3x5xf32>
|
||||
|
||||
// CHECK-LABEL:gatherV2VectorIndices
|
||||
// CHECK: "tfl.gather"(%arg0, %arg1) {axis = -1 : i32} : (tensor<1x2x20xf32>, tensor<3x5xi32>) -> tensor<1x2x3x5xf32>
|
||||
// CHECK: "tfl.gather"(%arg0, %arg1) {axis = -1 : i32, batch_dims = 0 : i32} : (tensor<1x2x20xf32>, tensor<3x5xi32>) -> tensor<1x2x3x5xf32>
|
||||
}
|
||||
|
||||
func @gatherV2NonZeroBatchDims(%arg0 : tensor<1x2x20xf32>, %arg1 : tensor<3x5xi32>) -> tensor<1x2x3x5xf32> {
|
||||
func @gatherWithBatchDims(%arg0 : tensor<2x3x6xf32>, %arg1 : tensor<2x5xi32>) -> tensor<2x5x3x6xf32> {
|
||||
%0 = "tf.Const"() { value = dense<[1]> : tensor<1xi32> } : () -> tensor<1xi32>
|
||||
%1 = "tf.GatherV2"(%arg0, %arg1, %0) {batch_dims = 1 : i64} : (tensor<1x2x20xf32>, tensor<3x5xi32>, tensor<1xi32>) -> tensor<1x2x3x5xf32>
|
||||
return %1 : tensor<1x2x3x5xf32>
|
||||
%1 = "tf.GatherV2"(%arg0, %arg1, %0) {batch_dims = 1 : i64} : (tensor<2x3x6xf32>, tensor<2x5xi32>, tensor<1xi32>) -> tensor<2x5x3x6xf32>
|
||||
return %1 : tensor<2x5x3x6xf32>
|
||||
|
||||
// CHECK-LABEL:gatherV2NonZeroBatchDims
|
||||
// CHECK: tf.GatherV2
|
||||
// CHECK-LABEL:gatherWithBatchDims
|
||||
// CHECK: "tfl.gather"(%arg0, %arg1) {axis = 1 : i32, batch_dims = 1 : i32} : (tensor<2x3x6xf32>, tensor<2x5xi32>) -> tensor<2x5x3x6xf32>
|
||||
}
|
||||
|
||||
|
||||
|
||||
func @greater(%arg0: tensor<8x16xf32>, %arg1: tensor<8x16xf32>) -> tensor<8x16xi1> {
|
||||
%0 = "tf.Greater"(%arg0, %arg1) : (tensor<8x16xf32>, tensor<8x16xf32>) -> tensor<8x16xi1>
|
||||
return %0 : tensor<8x16xi1>
|
||||
|
|
|
|||
|
|
@ -9,11 +9,11 @@
|
|||
|
||||
// Verify tensors in interpreter state:
|
||||
// ------------------------------------
|
||||
// CHECK: Tensor 0 pconst kTfLiteInt32 kTfLiteMmapRo 4 bytes
|
||||
// CHECK-NEXT: Tensor 1 N kTfLiteInt32 kTfLiteMmapRo 4 bytes
|
||||
// CHECK-NEXT: Tensor 2 val kTfLiteFloat32 kTfLiteMmapRo 4 bytes
|
||||
// CHECK-NEXT: Tensor 3 tfl.while kTfLiteInt32 kTfLiteArenaRw 4 bytes
|
||||
// CHECK-NEXT: Tensor 4 result kTfLiteFloat32 kTfLiteArenaRw 4 bytes
|
||||
// CHECK: Tensor 0 pconst kTfLiteInt32 kTfLiteMmapRo 4B ( 0.0 MB) (null)
|
||||
// CHECK-NEXT: Tensor 1 N kTfLiteInt32 kTfLiteMmapRo 4B ( 0.0 MB) (null)
|
||||
// CHECK-NEXT: Tensor 2 val kTfLiteFloat32 kTfLiteMmapRo 4B ( 0.0 MB) [1]
|
||||
// CHECK-NEXT: Tensor 3 tfl.while kTfLiteInt32 kTfLiteArenaRw 4B ( 0.0 MB) (null)
|
||||
// CHECK-NEXT: Tensor 4 result kTfLiteFloat32 kTfLiteArenaRw 4B ( 0.0 MB) [1]
|
||||
|
||||
// Verify while was not folded away:
|
||||
// ------------------------------------
|
||||
|
|
|
|||
|
|
@ -82,6 +82,14 @@ func @testGatherUnsupportedRank(%arg0 : tensor<f32>, %arg1 : tensor<1xi32>) -> t
|
|||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: testGatherWithBatchDims
|
||||
func @testGatherWithBatchDims(%arg0 : tensor<2xf32>, %arg1 : tensor<2xi32>) -> tensor<2xf32> {
|
||||
%0 = "tfl.gather"(%arg0, %arg1) {axis = 1 : i32, batch_dims = 2 : i32}: (tensor<2xf32>,tensor<2xi32>) -> tensor<2xf32>
|
||||
return %0 : tensor<2xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: testAbs
|
||||
func @testAbs(tensor<? x f32>) -> tensor<? x f32> {
|
||||
^bb0(%arg0: tensor<? x f32>):
|
||||
|
|
|
|||
|
|
@ -1948,6 +1948,32 @@ func @DontRemoveReshapeBeforeFullyConnectedChangeLastDim(%arg0: tensor<128x64xf3
|
|||
// CHECK: return %[[FULLY_CONNECTED]] : tensor<256x32xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: RemoveReshapeAfterFullyConnected
|
||||
func @RemoveReshapeAfterFullyConnected(%arg0: tensor<4x1024x1024xbf16>) -> tensor<4x1024x4096xbf16> {
|
||||
%cst_0 = constant dense<1.0> : tensor<4096x1024xbf16>
|
||||
%cst_1 = constant unit
|
||||
%cst_2 = constant dense<[4, 1024, 4096]> : tensor<3xi32>
|
||||
%0 = "tfl.fully_connected"(%arg0, %cst_0, %cst_1) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<4x1024x1024xbf16>, tensor<4096x1024xbf16>, none) -> tensor<4096x4096xbf16>
|
||||
%1 = "tfl.reshape"(%0, %cst_2) : (tensor<4096x4096xbf16>, tensor<3xi32>) -> tensor<4x1024x4096xbf16>
|
||||
return %1 : tensor<4x1024x4096xbf16>
|
||||
// CHECK: %[[V0:.*]] = "tfl.fully_connected"(%arg0, {{.*}}) {{.*}}keep_num_dims = true{{.*}} -> tensor<4x1024x4096xbf16>
|
||||
// CHECK: return %[[V0]]
|
||||
}
|
||||
|
||||
// CHECK-LABEL: RemoveReshapeAfterFullyConnectedAdd
|
||||
func @RemoveReshapeAfterFullyConnectedAdd(%arg0: tensor<4x1024x1024xbf16>) -> tensor<4x1024x4096xbf16> {
|
||||
%cst_0 = constant dense<1.0> : tensor<4096x1024xbf16>
|
||||
%cst_1 = constant unit
|
||||
%cst_2 = constant dense<[4, 1024, 4096]> : tensor<3xi32>
|
||||
%0 = "tfl.fully_connected"(%arg0, %cst_0, %cst_1) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<4x1024x1024xbf16>, tensor<4096x1024xbf16>, none) -> tensor<4096x4096xbf16>
|
||||
%1 = "tfl.reshape"(%0, %cst_2) : (tensor<4096x4096xbf16>, tensor<3xi32>) -> tensor<4x1024x4096xbf16>
|
||||
%2 = "tfl.mul"(%1, %1) {fused_activation_function = "NONE"} : (tensor<4x1024x4096xbf16>, tensor<4x1024x4096xbf16>) -> tensor<4x1024x4096xbf16>
|
||||
return %2 : tensor<4x1024x4096xbf16>
|
||||
// CHECK: %[[V0:.*]] = "tfl.fully_connected"(%arg0, {{.*}}) {{.*}}keep_num_dims = true{{.*}} -> tensor<4x1024x4096xbf16>
|
||||
// CHECK: %[[V1:.*]] = tfl.mul %[[V0]], %[[V0]] {{.*}} : tensor<4x1024x4096xbf16
|
||||
// CHECK: return %[[V1]]
|
||||
}
|
||||
|
||||
// CHECK-LABEL: DontFuseAddWithConvActivationFunc
|
||||
func @DontFuseAddWithConvActivationFunc(%arg0: tensor<1x3x1x1xf32>) -> tensor<1x2x1x3xf32> {
|
||||
%cst = constant dense<1.5> : tensor<1xf32>
|
||||
|
|
|
|||
|
|
@ -305,3 +305,52 @@ func @NotQuantizePow(%arg0: tensor<4x!quant.uniform<u8:f32, 1.0>>,
|
|||
|
||||
// DEBUG-NOT: "tfl.NumericVerify"
|
||||
}
|
||||
|
||||
// CHECK-LABEL: QuantizeCustomTfOp
|
||||
// DEBUG-LABEL: QuantizeCustomTfOp
|
||||
func @QuantizeCustomTfOp(%arg0: tensor<128x128x!quant.uniform<u8:f32, 0.1:127>>,
|
||||
%arg1: tensor<1x!quant.uniform<u8:f32, 0.2:127>>, %arg2: tensor<1x!quant.uniform<u8:f32, 0.4:127>>,
|
||||
%arg3: tensor<1xi32>) -> (tensor<128x128x!quant.uniform<u8:f32, 0.2:125>>) {
|
||||
%0 = "tfl.dequantize"(%arg0) : (tensor<128x128x!quant.uniform<u8:f32, 0.1:127>>) -> tensor<128x128xf32>
|
||||
%1 = "tfl.dequantize"(%arg1) : (tensor<1x!quant.uniform<u8:f32, 0.2:127>>) -> tensor<1xf32>
|
||||
%2 = "tfl.dequantize"(%arg2) : (tensor<1x!quant.uniform<u8:f32, 0.4:127>>) -> tensor<1xf32>
|
||||
%3 = "tfl.custom_tf"(%0, %1, %2, %arg3) ( {
|
||||
^bb0(%a1: tensor<128x128xf32>, %a2: tensor<1xf32>, %a3: tensor<1xf32>, %a4: tensor<1xi32>): // no predecessors
|
||||
%4 = "tf.LayerNorm"(%a1, %a2, %a3, %a4) {_tfl_quant_trait = "fully_quantizable", device = ""} : (tensor<128x128xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xi32>) -> tensor<128x128xf32>
|
||||
"tfl.yield"(%4) : (tensor<128x128xf32>) -> ()
|
||||
}) {_tfl_quant_trait = "fully_quantizable", device = ""} : (tensor<128x128xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xi32>) -> tensor<128x128xf32>
|
||||
%4 = "tfl.quantize"(%3) {qtype = tensor<128x128x!quant.uniform<u8:f32, 0.2:125>>} : (tensor<128x128xf32>) -> tensor<128x128x!quant.uniform<u8:f32, 0.2:125>>
|
||||
return %4 : tensor<128x128x!quant.uniform<u8:f32, 0.2:125>>
|
||||
|
||||
// CHECK: %4 = "tfl.custom_tf"(%arg0, %arg1, %arg2, %arg3) ( {
|
||||
// CHECK-NEXT: ^bb0(%arg4: tensor<128x128xf32>, %arg5: tensor<1xf32>, %arg6: tensor<1xf32>, %arg7: tensor<1xi32>): // no predecessors
|
||||
// CHECK-NEXT: "tf.LayerNorm"(%arg4, %arg5, %arg6, %arg7) {_tfl_quant_trait = "fully_quantizable", device = ""} : (tensor<128x128xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xi32>) -> tensor<128x128xf32>
|
||||
// CHECK-NEXT: "tfl.yield"
|
||||
// CHECK-NEXT: }) {_tfl_quant_trait = "fully_quantizable", device = ""} :
|
||||
// CHECK-SAME: (tensor<128x128x!quant.uniform<u8:f32, 1.000000e-01:127>>, tensor<1x!quant.uniform<u8:f32, 2.000000e-01:127>>, tensor<1x!quant.uniform<u8:f32, 4.000000e-01:127>>, tensor<1xi32>)
|
||||
// CHECK-SAME: -> tensor<128x128x!quant.uniform<u8:f32, 2.000000e-01:125>>
|
||||
}
|
||||
|
||||
|
||||
// CHECK-LABEL: NotQuantizeCustomTfOp
|
||||
// DEBUG-LABEL: NotQuantizeCustomTfOp
|
||||
func @NotQuantizeCustomTfOp(%arg0: tensor<128x128x!quant.uniform<u8:f32, 0.1:127>>,
|
||||
%arg1: tensor<1x!quant.uniform<u8:f32, 0.2:127>>, %arg2: tensor<1x!quant.uniform<u8:f32, 0.4:127>>,
|
||||
%arg3: tensor<1xi32>) -> (tensor<128x128x!quant.uniform<u8:f32, 0.2:125>>) {
|
||||
%0 = "tfl.dequantize"(%arg0) : (tensor<128x128x!quant.uniform<u8:f32, 0.1:127>>) -> tensor<128x128xf32>
|
||||
%1 = "tfl.dequantize"(%arg1) : (tensor<1x!quant.uniform<u8:f32, 0.2:127>>) -> tensor<1xf32>
|
||||
%2 = "tfl.dequantize"(%arg2) : (tensor<1x!quant.uniform<u8:f32, 0.4:127>>) -> tensor<1xf32>
|
||||
%3 = "tfl.custom_tf"(%0, %1, %2, %arg3) ( {
|
||||
^bb0(%a1: tensor<128x128xf32>, %a2: tensor<1xf32>, %a3: tensor<1xf32>, %a4: tensor<1xi32>): // no predecessors
|
||||
%4 = "tf.LayerNorm"(%a1, %a2, %a3, %a4) {device = ""} : (tensor<128x128xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xi32>) -> tensor<128x128xf32>
|
||||
"tfl.yield"(%4) : (tensor<128x128xf32>) -> ()
|
||||
}) {device = ""} : (tensor<128x128xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xi32>) -> tensor<128x128xf32>
|
||||
%4 = "tfl.quantize"(%3) {qtype = tensor<128x128x!quant.uniform<u8:f32, 0.2:125>>} : (tensor<128x128xf32>) -> tensor<128x128x!quant.uniform<u8:f32, 0.2:125>>
|
||||
return %4 : tensor<128x128x!quant.uniform<u8:f32, 0.2:125>>
|
||||
|
||||
// CHECK: "tfl.custom_tf"
|
||||
// CHECK-NEXT: ^bb0(%arg4: tensor<128x128xf32>, %arg5: tensor<1xf32>, %arg6: tensor<1xf32>, %arg7: tensor<1xi32>): // no predecessors
|
||||
// CHECK-NEXT: "tf.LayerNorm"(%arg4, %arg5, %arg6, %arg7) {device = ""} : (tensor<128x128xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xi32>) -> tensor<128x128xf32>
|
||||
// CHECK-NEXT: "tfl.yield"
|
||||
// CHECK-NEXT: }) {device = ""} : (tensor<128x128xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xi32>) -> tensor<128x128xf32>
|
||||
}
|
||||
|
|
|
|||
|
|
@ -26,6 +26,7 @@ limitations under the License.
|
|||
#include "mlir/Parser.h" // from @llvm-project
|
||||
#include "mlir/Pass/Pass.h" // from @llvm-project
|
||||
#include "mlir/Support/FileUtilities.h" // from @llvm-project
|
||||
#include "mlir/Support/LogicalResult.h" // from @llvm-project
|
||||
#include "mlir/Transforms/Passes.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/lite/flatbuffer_export.h"
|
||||
#include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h"
|
||||
|
|
@ -219,6 +220,9 @@ Status ConvertTFExecutorToTFLOrFlatbuffer(
|
|||
string(reinterpret_cast<const char*>(q_buffer), q_builder.GetSize());
|
||||
}
|
||||
|
||||
if (mlir::failed(module.verify())) {
|
||||
return tensorflow::errors::Unknown("Final module is invalid");
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -123,7 +123,11 @@ class InitializeVariablesPass
|
|||
// with ops that accepts resource as input.
|
||||
if (!llvm::isa<TF::ReadVariableOp, TF::AssignVariableOp>(op))
|
||||
return WalkResult::advance();
|
||||
tensors_to_initialize.insert(GetGlobalTensorOp(op, symbol_table, func));
|
||||
auto global_tensor = GetGlobalTensorOp(op, symbol_table, func);
|
||||
// In case the function doesn't have bound_input to a resource
|
||||
// then we return nullptr.
|
||||
// We need only to initialize the variables that are bounded.
|
||||
if (global_tensor) tensors_to_initialize.insert(global_tensor);
|
||||
return WalkResult::advance();
|
||||
});
|
||||
}
|
||||
|
|
@ -154,9 +158,9 @@ class InitializeVariablesPass
|
|||
|
||||
void runOnOperation() override {
|
||||
auto module = getOperation();
|
||||
// Use ordered container to make sure ids are deterministic if we got tensor
|
||||
// ids from different part, since we have different passes that touches
|
||||
// variables.
|
||||
// Use ordered container to make sure ids are deterministic if we got
|
||||
// tensor ids from different part, since we have different passes that
|
||||
// touches variables.
|
||||
// TODO(b/149099381): Remove integer IDs after adding the new variable
|
||||
// handle type.
|
||||
std::map<std::string, int> global_tensor_id;
|
||||
|
|
|
|||
|
|
@@ -240,15 +240,16 @@ def LegalizeGreaterEqual : Pat<(TF_GreaterEqualOp $l, $r),
// The 'validate_indices' attribute is deprecated.
def LegalizeGather: Pat<
  (TF_GatherOp $params, $indices, $ignored_validate_indices),
  (TFL_GatherOp $params, $indices, ConstantAttr<I32Attr, "0">)>;
  (TFL_GatherOp $params, $indices, ConstantAttr<I32Attr, "0">,
    ConstantAttr<I32Attr, "0">)>;

def LegalizeGatherNd : Pat<(TF_GatherNdOp $params, $indices),
                           (TFL_GatherNdOp $params, $indices)>;

def LegalizeGatherV2 : Pat<
  (TF_GatherV2Op $params, $indices, (ConstantOp ElementsAttr:$axis),
    ConstantAttr<I64Attr, "0">:$batch_dims),
  (TFL_GatherOp $params, $indices, ExtractSingleElementAsInt32:$axis)>;
  (TF_GatherV2Op $params, $indices, (ConstantOp ElementsAttr:$axis), $batch_dims),
  (TFL_GatherOp $params, $indices, ExtractSingleElementAsInt32:$axis,
    (convertIntAttrTo32Bit $batch_dims))>;

def LegalizeFloorDiv : Pat<(TF_FloorDivOp $l, $r), (TFL_FloorDivOp $l, $r)>;
@@ -949,7 +949,7 @@ struct ConvertWhile : public OpConversionPattern<TF::WhileOp> {
    // Create a new while op with new operands and updated result types.
    auto converted = rewriter.create<TF::WhileOp>(op.getLoc(), result_types,
                                                  operands, op->getAttrs());
    converted.removeAttr("T");
    converted->removeAttr("T");
    (void)UpdateFunctionTypes(rewriter, converted, tensor_list_args);

    rewriter.replaceOp(op, converted.getResults());
@ -38,6 +38,7 @@ limitations under the License.
|
|||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
|
||||
#include "mlir/IR/Attributes.h" // from @llvm-project
|
||||
#include "mlir/IR/BuiltinTypes.h" // from @llvm-project
|
||||
#include "mlir/IR/MLIRContext.h" // from @llvm-project
|
||||
#include "mlir/IR/Matchers.h" // from @llvm-project
|
||||
#include "mlir/IR/TypeUtilities.h" // from @llvm-project
|
||||
#include "mlir/IR/Value.h" // from @llvm-project
|
||||
|
|
@ -1117,6 +1118,59 @@ struct RemoveReshapeBeforeFullyConnected
|
|||
}
|
||||
};
|
||||
|
||||
// Remove Reshape after FullyConnected when `keep_num_dims=false`, the Reshaoe
|
||||
// does not alter the last dimension and it restores the batch dimensions
|
||||
// collapsed by the FullyConnected op due to `keep_num_dims=false`. For example,
|
||||
//
|
||||
// // %input: tensor<4x16x32xf32>
|
||||
// %fc = tfl.fully_connected(%input, %filter, %bias)
|
||||
// {keep_num_dims = false, weights_format = "DEFAULT"}
|
||||
// %shape = constant dense<[4, 16, 32]> : tensor<3xi32>
|
||||
// %rs = tfl.reshape(%fc, %shape)
|
||||
//
|
||||
// can be canonicalized to
|
||||
//
|
||||
// %fc = tfl.fully_connected(%input, %filter, %bias)
|
||||
// {keep_num_dims = true, weights_format = "DEFAULT"}
|
||||
struct RemoveReshapeAfterFullyConnected
|
||||
: public OpRewritePattern<TFL::ReshapeOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(TFL::ReshapeOp reshape_op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto fully_connected_op = llvm::dyn_cast_or_null<TFL::FullyConnectedOp>(
|
||||
reshape_op.input().getDefiningOp());
|
||||
if (!fully_connected_op || fully_connected_op.getNumResults() != 1 ||
|
||||
fully_connected_op.weights_format() != "DEFAULT" ||
|
||||
fully_connected_op.keep_num_dims())
|
||||
return failure();
|
||||
if (!reshape_op.input().getUseList()->hasOneUse()) return failure();
|
||||
|
||||
auto input_shape = fully_connected_op.input().getType().cast<ShapedType>();
|
||||
auto output_shape = fully_connected_op.getType(0).cast<ShapedType>();
|
||||
auto reshape_shape = reshape_op.getType().cast<ShapedType>();
|
||||
if (!input_shape.hasStaticShape() || !output_shape.hasStaticShape() ||
|
||||
!reshape_shape.hasStaticShape())
|
||||
return failure();
|
||||
|
||||
// Check that the reshape doesn't modify the last dimension and it restores
|
||||
// the input (batch) dimension with the exception of the feature (last)
|
||||
// dimension.
|
||||
if (output_shape.getShape().back() != reshape_shape.getShape().back() ||
|
||||
input_shape.getShape().drop_back() !=
|
||||
reshape_shape.getShape().drop_back())
|
||||
return failure();
|
||||
|
||||
llvm::SmallVector<Type, 1> output_type{reshape_op.getType()};
|
||||
rewriter.replaceOpWithNewOp<TFL::FullyConnectedOp>(
|
||||
reshape_op, output_type, fully_connected_op.input(),
|
||||
fully_connected_op.filter(), fully_connected_op.bias(),
|
||||
fully_connected_op.fused_activation_function(),
|
||||
fully_connected_op.weights_format(), /*keep_num_dims=*/true);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
using FuseBinaryOpToFollowingFullyConnected =
|
||||
FuseBinaryOpToFollowingAffineOp<FullyConnectedOp>;
|
||||
using FuseBinaryOpToFollowingDepthwiseConv2D =
|
||||
|
|
@ -1135,6 +1189,13 @@ void OptimizePass::runOnFunction() {
|
|||
auto *ctx = &getContext();
|
||||
auto func = getFunction();
|
||||
|
||||
// Merge reshapes into fully connected ops before we start moving them past
|
||||
// binary ops.
|
||||
OwningRewritePatternList phase_0_patterns;
|
||||
phase_0_patterns.insert<RemoveReshapeAfterFullyConnected,
|
||||
RemoveReshapeBeforeFullyConnected>(ctx);
|
||||
(void)applyPatternsAndFoldGreedily(func, std::move(phase_0_patterns));
|
||||
|
||||
// Potentially the binary ops might be fused together, like hard_swish, thus
|
||||
// we explore these potentially first and then fuse the binary ops with the
|
||||
// following ops in a second pattern match.
|
||||
|
|
@ -1161,7 +1222,7 @@ void OptimizePass::runOnFunction() {
|
|||
FuseBinaryOpToFollowingDepthwiseConv2D,
|
||||
FuseBinaryOpToFollowingFullyConnected, FuseConv2DAndMulWithQDQs,
|
||||
FuseDepthwiseConv2DAndMulWithQDQs, ConvertTrivialTransposeOpToReshapeOp,
|
||||
RemoveReshapeBeforeFullyConnected>(ctx);
|
||||
RemoveReshapeAfterFullyConnected, RemoveReshapeBeforeFullyConnected>(ctx);
|
||||
if (enable_canonicalization_)
|
||||
AddCanonicalizationPatterns(ctx, &phase_2_patterns);
|
||||
(void)applyPatternsAndFoldGreedily(func, std::move(phase_2_patterns));
|
||||
|
|
|
|||
|
|
@@ -363,6 +363,7 @@ cc_library(
    ],
    hdrs = [
        "ir/tf_attributes.h",
        "ir/tf_dialect.h",
    ],
    deps = [
        "@llvm-project//llvm:Support",
@@ -389,6 +390,7 @@ cc_library(
cc_library(
    name = "tensorflow_" + target["name"],
    srcs = [
        "ir/tf_dialect.h",
        "ir/tf_ops.h",
        "ir/tfrt_ops.h",
        "ir/tf_remaining_ops.h",
@@ -433,6 +435,7 @@ cc_library(
cc_library(
    name = "tensorflow_remaining_ops",
    srcs = [
        "ir/tf_dialect.h",
        "ir/tf_ops.h",
        "ir/tf_remaining_ops.h",
        "ir/tf_remaining_ops.cc",
@@ -476,6 +479,7 @@ cc_library(
cc_library(
    name = "tensorflow_tfrt_ops",
    srcs = [
        "ir/tf_dialect.h",
        "ir/tf_ops.h",
        "ir/tfrt_ops.h",
        "ir/tfrt_ops.cc",
@@ -519,6 +523,7 @@ cc_library(
cc_library(
    name = "tensorflow_ops",
    srcs = [
        "ir/tf_dialect.h",
        "ir/tf_ops.cc",
        "ir/tf_ops.h",
    ],
@@ -592,6 +597,7 @@ cc_library(
        "ir/tf_types.cc",
    ],
    hdrs = [
        "ir/tf_dialect.h",
        "ir/tf_types.h",
    ],
    textual_hdrs = [
@@ -616,6 +622,7 @@ cc_library(
    hdrs = [
        "dialect_registration.h",
        "ir/tf_device.h",
        "ir/tf_dialect.h",
        "ir/tf_executor.h",
        "ir/tf_ops.h",
        "ir/tf_saved_model.h",
@@ -15,6 +15,8 @@ limitations under the License.

#include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h"

#include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h"

namespace mlir {
namespace TF {
@@ -131,5 +133,9 @@ DictionaryAttr FuncAttr::GetAttrs() const {
  return getImpl()->attrs.cast<DictionaryAttr>();
}

void TensorFlowDialect::registerAttributes() {
  addAttributes<ShapeAttr, FuncAttr>();
}

}  // namespace TF
}  // namespace mlir
138  tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h  Normal file
@@ -0,0 +1,138 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This file defines the standard MLIR TensorFlow dialect after control
// dependences are raised to the standard form.

#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_DIALECT_H_
#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_DIALECT_H_

#include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
#include "mlir/IR/Dialect.h"  // from @llvm-project

namespace mlir {
namespace TF {

class ResourceType;
class VariantType;

class TensorFlowDialect : public Dialect {
 public:
  TensorFlowDialect(MLIRContext *context);

  static StringRef getDialectNamespace() { return "tf"; }

  // Gradient attribute ("tf.gradient") in the list of NamedAttributes in a
  // function references to its gradient function. This attribute in TensorFlow
  // Dialect is used to model TF GradientDef. GetGradientAttrName() returns the
  // string description of gradient attribute.
  static StringRef GetGradientAttrName() { return "tf.gradient"; }

  // This attribute marks if a function is stateful.
  // Returns the string description of stateful attribute.
  static StringRef GetStatefulAttrName() { return "tf.signature.is_stateful"; }

  // Returns true if the op can be duplicated during transformations.
  static bool CanDuplicate(Operation *op);

  // Returns true if the op can have side effects.
  static bool CanHaveSideEffects(Operation *op);

  Attribute parseAttribute(DialectAsmParser &parser, Type type) const override;

  void printAttribute(Attribute attr, DialectAsmPrinter &os) const override;

  // Parse a type registered to this dialect.
  Type parseType(DialectAsmParser &parser) const override;

  // Prints a type registered to this dialect.
  void printType(Type ty, DialectAsmPrinter &os) const override;

  // Parses resource type with potential subtypes.
  Type ParseResourceType(DialectAsmParser &parser) const;

  // Prints resource type with potential subtypes.
  void PrintResourceType(ResourceType ty, DialectAsmPrinter &os) const;

  // Parse and print variant type. It may have subtypes inferred using shape
  // inference.
  Type ParseVariantType(DialectAsmParser &parser) const;
  void PrintVariantType(VariantType ty, DialectAsmPrinter &os) const;

  // Registered hook to materialize a constant operation from a given attribute
  // value with the desired resultant type.
  Operation *materializeConstant(OpBuilder &builder, Attribute value, Type type,
                                 Location loc) override;

  typedef std::function<void(TensorFlowDialect &dialect)> AdditionalOpFunction;

  // Register an op registration hook which is invoked during construction.
  //
  // A hook may use the public addOperations() method to add additional
  // operations to the dialect. Hooks will only apply to subsequent
  // instantiations of the Dialect/MLIRContext.
  static void RegisterAdditionalOperationHook(AdditionalOpFunction fn) {
    GetAdditionalOperationHooks()->push_back(std::move(fn));
  }

  // Re-define publicly the protected addOperations() method from the Dialect
  // class, usually used in a Dialect constructor. This allows hook
  // functions to register operations on the TensorFlow dialect using the
  // same interface.
  template <typename... Args>
  void addOperations() {
    Dialect::addOperations<Args...>();
  }

  using ConstantFoldHook = LogicalResult (*)(Operation *, ArrayRef<Attribute>,
                                             SmallVectorImpl<OpFoldResult> &);
  static void RegisterConstantFoldHook(ConstantFoldHook fn) {
    constant_fold_hook_ = std::move(fn);
  }

  static LogicalResult constantFold(Operation *op, ArrayRef<Attribute> operands,
                                    SmallVectorImpl<OpFoldResult> &results) {
    if (constant_fold_hook_) return constant_fold_hook_(op, operands, results);
    return failure();
  }

  using DecodeConstantHook = LogicalResult (*)(OpaqueElementsAttr input,
                                               ElementsAttr &output);
  static void RegisterDecodeConstantHook(DecodeConstantHook fn) {
    decode_constant_hook_ = std::move(fn);
  }
  static LogicalResult decode(OpaqueElementsAttr input, ElementsAttr &output) {
    if (decode_constant_hook_) return decode_constant_hook_(input, output);
    return failure();
  }

 private:
  /// Register the attributes of this dialect.
  void registerAttributes();
  /// Register the types of this dialect.
  void registerTypes();

  // Hook functions which may add additional operations to the dialect.
  // These are invoked at construction time.
  static std::vector<AdditionalOpFunction> *GetAdditionalOperationHooks();

  static ConstantFoldHook constant_fold_hook_;
  static DecodeConstantHook decode_constant_hook_;
};

}  // namespace TF
}  // namespace mlir

#endif  // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_DIALECT_H_
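The constant-fold and decode-constant hooks declared in this header are plain function pointers, so installing one is a single static call. A minimal sketch, not taken from the commit, is shown below; the namespace, function name, and the do-nothing fold behaviour are all hypothetical, and the includes assume the TensorFlow/MLIR build environment.

// Hypothetical client code that installs a constant-fold hook on the TF
// dialect using the static registration API declared above.
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h"

namespace example {

// Matches TensorFlowDialect::ConstantFoldHook. A real hook would evaluate
// `op` on `operands` and append folded values to `results`; returning
// failure() makes TensorFlowDialect::constantFold() give up gracefully.
mlir::LogicalResult NoOpConstantFold(
    mlir::Operation *op, llvm::ArrayRef<mlir::Attribute> operands,
    llvm::SmallVectorImpl<mlir::OpFoldResult> &results) {
  (void)op;
  (void)operands;
  (void)results;
  return mlir::failure();
}

// Typically called from a registration/initializer function before any
// MLIRContext that loads the TF dialect is created.
void InstallTfDialectHooks() {
  mlir::TF::TensorFlowDialect::RegisterConstantFoldHook(&NoOpConstantFold);
}

}  // namespace example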
@@ -65,6 +65,8 @@ limitations under the License.
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_side_effects.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_def_builder.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/tensor_format.h"
@@ -175,9 +177,15 @@ bool TensorFlowDialect::CanDuplicate(Operation *op) {
  if (auto is_stateless = op->getAttrOfType<BoolAttr>("is_stateless"))
    return is_stateless.getValue();

  // Otherwise, assume ops can be duplicated by default if its registered, else
  // it cannot be for unknown ops.
  return op->isRegistered();
  // Assume ops can be duplicated when the given op is not a stateful op.
  const tensorflow::OpRegistrationData *op_reg_data = nullptr;
  tensorflow::Status s = tensorflow::OpRegistry::Global()->LookUp(
      op->getName().stripDialect().str(), &op_reg_data);
  if (!s.ok()) {
    // Assume unknown ops can not be duplicated.
    return false;
  }
  return !op_reg_data->op_def.is_stateful();
}

// Returns true if the op can have side effects.
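Restated outside MLIR (this sketch is illustrative and not part of the commit), the new CanDuplicate decision is: trust an explicit is_stateless attribute if one is present, otherwise consult the registered OpDef, and conservatively refuse to duplicate ops that are unknown to the registry. The OpInfo struct and its boolean fields below are hypothetical stand-ins for the MLIR op and the tensorflow::OpRegistry lookup.

// Standalone, runnable, illustrative restatement of the duplication rule
// shown above (not from the commit).
#include <cassert>
#include <optional>

struct OpInfo {
  std::optional<bool> is_stateless_attr;     // "is_stateless" attr, if set.
  std::optional<bool> registry_is_stateful;  // OpDef statefulness, if known.
};

bool CanDuplicate(const OpInfo &op) {
  if (op.is_stateless_attr) return *op.is_stateless_attr;
  if (!op.registry_is_stateful) return false;  // Unknown op: do not duplicate.
  return !*op.registry_is_stateful;
}

int main() {
  assert((CanDuplicate({std::nullopt, /*registry_is_stateful=*/false})));
  assert((!CanDuplicate({std::nullopt, /*registry_is_stateful=*/true})));
  assert((!CanDuplicate({std::nullopt, std::nullopt})));  // Unregistered op.
  assert((CanDuplicate({/*is_stateless_attr=*/true, /*registry_is_stateful=*/true})));
  return 0;
}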
@@ -217,14 +225,10 @@ TensorFlowDialect::TensorFlowDialect(MLIRContext *context)
#define GET_OP_LIST
#include "tensorflow/compiler/mlir/tensorflow/ir/tfrt_ops.cc.inc"
      >();
  addTypes<
#define HANDLE_TF_TYPE(tftype, enumerant, name) tftype##Type,
#define HANDLE_LAST_TF_TYPE(tftype, enumerant, name) tftype##Type
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.def"
      >();
  registerTypes();
  addInterfaces<TFInlinerInterface, TFDecodeAttributesInterface,
                TFConstantFoldInterface>();
  addAttributes<ShapeAttr, FuncAttr>();
  registerAttributes();

  // Support unknown operations because not all TensorFlow operations are
  // registered.
@@ -35,6 +35,7 @@ limitations under the License.
#include "mlir/Interfaces/LoopLikeInterface.h"  // from @llvm-project
#include "mlir/Interfaces/SideEffectInterfaces.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.h"
@@ -45,109 +46,4 @@ limitations under the License.
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_verifiers.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tfrt_ops.h"

namespace mlir {
namespace TF {

class TensorFlowDialect : public Dialect {
 public:
  TensorFlowDialect(MLIRContext *context);

  static StringRef getDialectNamespace() { return "tf"; }

  // Gradient attribute ("tf.gradient") in the list of NamedAttributes in a
  // function references to its gradient function. This attribute in TensorFlow
  // Dialect is used to model TF GradientDef. GetGradientAttrName() returns the
  // string description of gradient attribute.
  static StringRef GetGradientAttrName() { return "tf.gradient"; }

  // This attribute marks if a function is stateful.
  // Returns the string description of stateful attribute.
  static StringRef GetStatefulAttrName() { return "tf.signature.is_stateful"; }

  // Returns true if the op can be duplicated during transformations.
  static bool CanDuplicate(Operation *op);

  // Returns true if the op can have side effects.
  static bool CanHaveSideEffects(Operation *op);

  Attribute parseAttribute(DialectAsmParser &parser, Type type) const override;

  void printAttribute(Attribute attr, DialectAsmPrinter &os) const override;

  // Parse a type registered to this dialect.
  Type parseType(DialectAsmParser &parser) const override;

  // Prints a type registered to this dialect.
  void printType(Type ty, DialectAsmPrinter &os) const override;

  // Parses resource type with potential subtypes.
  Type ParseResourceType(DialectAsmParser &parser) const;

  // Prints resource type with potential subtypes.
  void PrintResourceType(ResourceType ty, DialectAsmPrinter &os) const;

  // Parse and print variant type. It may have subtypes inferred using shape
  // inference.
  Type ParseVariantType(DialectAsmParser &parser) const;
  void PrintVariantType(VariantType ty, DialectAsmPrinter &os) const;

  // Registered hook to materialize a constant operation from a given attribute
  // value with the desired resultant type.
  Operation *materializeConstant(OpBuilder &builder, Attribute value, Type type,
                                 Location loc) override;

  typedef std::function<void(TensorFlowDialect &dialect)> AdditionalOpFunction;

  // Register an op registration hook which is invoked during construction.
  //
  // A hook may use the public addOperations() method to add additional
  // operations to the dialect. Hooks will only apply to subsequent
  // instantations of the Dialect/MLIRContext.
  static void RegisterAdditionalOperationHook(AdditionalOpFunction fn) {
    GetAdditionalOperationHooks()->push_back(std::move(fn));
  }

  // Re-define publicly the protected addOperations() method from the Dialect
  // class, usually used in a Dialect constructor. This allows hook
  // functions to register operations on the TensorFlow dialect using the
  // same interface.
  template <typename... Args>
  void addOperations() {
    Dialect::addOperations<Args...>();
  }

  using ConstantFoldHook = LogicalResult (*)(Operation *, ArrayRef<Attribute>,
                                             SmallVectorImpl<OpFoldResult> &);
  static void RegisterConstantFoldHook(ConstantFoldHook fn) {
    constant_fold_hook_ = std::move(fn);
  }

  static LogicalResult constantFold(Operation *op, ArrayRef<Attribute> operands,
                                    SmallVectorImpl<OpFoldResult> &results) {
    if (constant_fold_hook_) return constant_fold_hook_(op, operands, results);
    return failure();
  }

  using DecodeConstantHook = LogicalResult (*)(OpaqueElementsAttr input,
                                               ElementsAttr &output);
  static void RegisterDecodeConstantHook(DecodeConstantHook fn) {
    decode_constant_hook_ = std::move(fn);
  }
  static LogicalResult decode(OpaqueElementsAttr input, ElementsAttr &output) {
    if (decode_constant_hook_) return decode_constant_hook_(input, output);
    return failure();
  }

 private:
  // Hook functions which may add additional operations to the dialect.
  // These are invoked at construction time.
  static std::vector<AdditionalOpFunction> *GetAdditionalOperationHooks();

  static ConstantFoldHook constant_fold_hook_;
  static DecodeConstantHook decode_constant_hook_;
};

}  // namespace TF
}  // namespace mlir

#endif  // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_OPS_H_
@@ -1037,6 +1037,25 @@ def TF_InfeedDequeueTupleOp : TF_Op<"InfeedDequeueTuple", []> {
  TF_DerivedResultTypeListAttr dtypes = TF_DerivedResultTypeListAttr<0>;
}

// TODO(b/177675373): Make dtypes and shapes derived attributes,
// use more general solution.
def TF_InfeedEnqueueTupleOp : TF_Op<"InfeedEnqueueTuple", []> {
  let summary = [{
Feeds multiple Tensor values into the computation as an XLA tuple.
  }];

  let arguments = (ins
    Arg<Variadic<TF_Tensor>, [{A list of tensors that will be provided using the infeed mechanism.}]>:$inputs,

    Confined<TypeArrayAttr, [ArrayMinCount<1>]>:$dtypes,
    TF_ShapeAttrArray:$shapes,
    DefaultValuedAttr<I64ArrayAttr, "{}">:$layouts,
    DefaultValuedAttr<I64Attr, "-1">:$device_ordinal
  );

  let results = (outs);
}

def TF_StringFormatOp : TF_Op<"StringFormat", [NoSideEffect]> {
  let summary = "Formats a string template using a list of tensors.";
@@ -739,11 +739,9 @@ LogicalResult FoldConstantCaseOp::matchAndRewrite(

  auto func = op.branches()[index].cast<SymbolRefAttr>();
  auto empty = rewriter.getStringAttr("");
  auto call_op = rewriter.create<PartitionedCallOp>(
      op.getLoc(), op.getResultTypes(), op.getOperands().drop_front(), func,
  ReplaceTfOpWithNewOp<PartitionedCallOp>(
      rewriter, op, op.getResultTypes(), op.getOperands().drop_front(), func,
      /*config=*/empty, /*config_proto=*/empty, /*executor_type=*/empty);
  CopyDeviceAndUnderscoredAttributes(op.getOperation(), call_op);
  rewriter.replaceOp(op, call_op.getResults());
  return success();
}
@@ -2510,11 +2508,9 @@ LogicalResult FoldConstantIfOp::matchAndRewrite(
  // Replace IfOp with PartitionedCallOp or StatefulPartitionedCallOp.
  auto rewrite = [&](auto op_type) {
    auto empty = rewriter.getStringAttr("");
    auto call_op = rewriter.create<typename decltype(op_type)::CallOp>(
        op.getLoc(), op.getResultTypes(), op.input(), func,
    ReplaceTfOpWithNewOp<typename decltype(op_type)::CallOp>(
        rewriter, op, op.getResultTypes(), op.input(), func,
        /*config=*/empty, /*config_proto=*/empty, /*executor_type=*/empty);
    CopyDeviceAndUnderscoredAttributes(op.getOperation(), call_op);
    rewriter.replaceOp(op, call_op.getResults());
  };

  if (op.is_stateless())
@@ -548,7 +548,7 @@ struct DropAttributes : public OpRewritePattern<Op> {
  // Drop the "output_shapes" attribute.
  LogicalResult matchAndRewrite(Op op,
                                PatternRewriter &rewriter) const override {
    bool found = !!op.removeAttr("output_shapes");
    bool found = !!op->removeAttr("output_shapes");
    return success(found);
  }
};
@@ -581,3 +581,23 @@ ResourceHandleValueAndId GetResourceHandleValueAndIdBase(
  if (emplace_res.second) ++next_id;
  return {resource, emplace_res.first->second};
}

// Helper function to create TF op while copying all underscore attributes from
// another TF op.
// TODO(jpienaar): This is a workaround until behavior is established.
template <typename OpTy, typename... Args>
OpTy CreateTfOp(RewriterBase& b, Operation *op, Args &&... args) {
  auto ret = b.create<OpTy>(op->getLoc(), std::forward<Args>(args)...);
  CopyDeviceAndUnderscoredAttributes(op, ret.getOperation());
  return ret;
}

// Helper function to replace TF op with another op while copying all underscore
// attributes from the TF op.
// TODO(jpienaar): This is a workaround until behavior is established.
template <typename OpTy, typename... Args>
OpTy ReplaceTfOpWithNewOp(RewriterBase& b, Operation *op, Args &&... args) {
  auto ret = CreateTfOp<OpTy>(b, op, std::forward<Args>(args)...);
  b.replaceOp(op, ret.getOperation()->getResults());
  return ret;
}
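What these helpers buy a rewrite is that the replacement op keeps the "device" attribute and the underscore-prefixed attributes (such as _xla_outside_compilation) of the op it replaces, which is what the hand-written CopyDeviceAndUnderscoredAttributes calls in the surrounding hunks do explicitly. The sketch below is a standalone illustration of that copying rule using plain strings; it is not taken from the commit, and the attribute names and values are hypothetical.

// Runnable toy model of the attribute-preservation rule used by CreateTfOp
// and ReplaceTfOpWithNewOp: copy "device" plus any "_"-prefixed attribute.
// Not from the commit; attribute names/values are made up.
#include <cassert>
#include <map>
#include <string>

using AttrMap = std::map<std::string, std::string>;

AttrMap CopyDeviceAndUnderscoredAttrs(const AttrMap &old_op, AttrMap new_op) {
  for (const auto &attr : old_op) {
    const std::string &name = attr.first;
    if (name == "device" || (!name.empty() && name[0] == '_'))
      new_op[name] = attr.second;
  }
  return new_op;
}

int main() {
  AttrMap old_op = {{"device", "/device:TPU:0"},
                    {"_xla_outside_compilation", "cluster_0"},
                    {"T", "f32"}};
  AttrMap new_op = CopyDeviceAndUnderscoredAttrs(old_op, /*new_op=*/{});
  assert(new_op.count("device") == 1);
  assert(new_op.count("_xla_outside_compilation") == 1);
  assert(new_op.count("T") == 0);  // Ordinary attributes are not carried over.
  return 0;
}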
@@ -335,16 +335,9 @@ struct ConvertPackToReshape : public OpRewritePattern<PackOp> {
    auto shape_attr = DenseIntElementsAttr::get(type, output_ty.getShape());
    auto shape = rewriter.create<ConstOp>(pack_op.getLoc(), shape_attr);

    auto reshape_op = rewriter.create<ReshapeOp>(pack_op.getLoc(), output_ty,
                                                 pack_op.getOperand(0), shape);
    // Preserve unregistered attributes. Outside compilation relies on
    // unregistered attribute `_xla_outside_compilation` to form clusters, so
    // they must be preserved during canonicalization.
    // TODO(b/173622615): Remove after fixed.
    CopyUnderscoredAttributes(pack_op.getOperation(),
                              reshape_op.getOperation());

    rewriter.replaceOp(pack_op, reshape_op.getResult());
    ReplaceTfOpWithNewOp<ReshapeOp>(rewriter, pack_op, output_ty,
                                    pack_op.getOperand(0), shape);
    return success();
  }
};
@@ -20,6 +20,7 @@ limitations under the License.
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/Dialect.h"  // from @llvm-project
#include "mlir/IR/TypeUtilities.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h"

namespace {
// Returns the shape of the given value if it's ranked; returns llvm::None
@@ -409,5 +410,13 @@ Type DropRefType(Type ty) { return DropTypeHelper<TF::TensorFlowRefType>(ty); }

Type DropRefAndSubTypes(Type ty) { return DropRefType(DropSubTypes(ty)); }

void TensorFlowDialect::registerTypes() {
  addTypes<
#define HANDLE_TF_TYPE(tftype, enumerant, name) tftype##Type,
#define HANDLE_LAST_TF_TYPE(tftype, enumerant, name) tftype##Type
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.def"
      >();
}

}  // namespace TF
}  // namespace mlir
@@ -9,4 +9,4 @@ module attributes {tf.versions = {producer = 179 : i32}} {
}

// CHECK-LABEL: HloModule main
// CHECK: -> (bf16[8,16,16,64], f32[64], f32[64], f32[64], f32[64], f32[0])
// CHECK: -> (bf16[8,16,16,64], f32[64], f32[64], f32[64], f32[64], /*index=5*/f32[0])
@@ -35,6 +35,24 @@ func private @inline_shape_cast_callee(%arg : tensor<*xi32>) -> tensor<*xi32> {
  return %arg : tensor<*xi32>
}

func private @custom_callee() -> tensor<2xi32> {
  %0 = "tf.CustomTFOp"() : () -> tensor<2xi32>
  return %0 : tensor<2xi32>
}

// Test that unregistered user-defined custom TF operations can not be inlined
// when there are duplicated cases.

// CHECK-LABEL: func @dont_inline_custom_on_duplicated_cases(
func @dont_inline_custom_on_duplicated_cases() -> tensor<2xi32> {
  // CHECK-NEXT: "tf.PartitionedCall"
  // CHECK-NEXT: "tf.PartitionedCall"
  // CHECK-NEXT: return
  %0 = "tf.PartitionedCall"() {config = "", config_proto = "", executor_type = "", f = @custom_callee} : () -> tensor<2xi32>
  %1 = "tf.PartitionedCall"() {config = "", config_proto = "", executor_type = "", f = @custom_callee} : () -> tensor<2xi32>
  return %1: tensor<2xi32>
}

// CHECK-LABEL: func @inline_shape_cast(
// CHECK-SAME: %[[ARG:.*]]: tensor<2xi32>
func @inline_shape_cast(%arg: tensor<2xi32>) -> tensor<2xi32> {
@@ -1026,3 +1026,23 @@ func @size_to_prod_shape_i64(%arg0 : tensor<1x?x2x3xf32>) -> tensor<i64> {
  // CHECK: %[[PROD:.*]] = "tf.Prod"(%[[SHAPE]], %[[CONSTANT]]) {keep_dims = false} : (tensor<4xi64>, tensor<i64>) -> tensor<i64>
  // CHECK: return %[[PROD]]
}

// CHECK-LABEL: @is_finite
func @is_finite(%arg0: tensor<3x4xf32>) -> tensor<3x4xi1> {
  %0 = "tf.IsFinite"(%arg0) : (tensor<3x4xf32>) -> tensor<3x4xi1>
  return %0 : tensor<3x4xi1>
  // CHECK: %[[ZERO:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<f32>} : () -> tensor<f32>
  // CHECK: %[[SUB:.*]] = "tf.Sub"(%arg0, %arg0) : (tensor<3x4xf32>, tensor<3x4xf32>) -> tensor<3x4xf32>
  // CHECK: %[[RESULT:.*]] = "tf.Equal"(%[[SUB]], %[[ZERO]]) {incompatible_shape_error = true} : (tensor<3x4xf32>, tensor<f32>) -> tensor<3x4xi1>
  // CHECK: return %[[RESULT]]
}

// CHECK-LABEL: @is_finite_dynamic
func @is_finite_dynamic(%arg0: tensor<?x4xf32>) -> tensor<?x4xi1> {
  %0 = "tf.IsFinite"(%arg0) : (tensor<?x4xf32>) -> tensor<?x4xi1>
  return %0 : tensor<?x4xi1>
  // CHECK: %[[ZERO:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<f32>} : () -> tensor<f32>
  // CHECK: %[[SUB:.*]] = "tf.Sub"(%arg0, %arg0) : (tensor<?x4xf32>, tensor<?x4xf32>) -> tensor<?x4xf32>
  // CHECK: %[[RESULT:.*]] = "tf.Equal"(%[[SUB]], %[[ZERO]]) {incompatible_shape_error = true} : (tensor<?x4xf32>, tensor<f32>) -> tensor<?x4xi1>
  // CHECK: return %[[RESULT]]
}
@@ -359,6 +359,31 @@ func @while_cond(%arg0: tensor<*x!tf.resource<tensor<f32>>>) -> tensor<f32> {

// -----

// Tests that the pass reports error on non-aliasing WhileRegion input/output
// resources. It cannot lift resource ops from such WhileRegion ops and should
// fail with a helpful error message.

func @fail_non_aliasing_resource_input_output() -> () {
  %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<f32>>>
  "tf_device.cluster"() ( {
    // expected-error@+1 {{Result #0 is not tied to arg #0 of the body}}
    %1 = "tf.WhileRegion"(%0) ({
      ^bb0(%carg0:tensor<*x!tf.resource<tensor<f32>>>):
        %cond = "tf.SomeOp"() : () -> tensor<i1>
        "tf.Yield"(%cond) : (tensor<i1>) -> ()
      }, {
      ^bb0(%carg0:tensor<*x!tf.resource<tensor<f32>>>):
        %body = "tf.VarHandleOp"() {container = "c", shared_name = "v2"} : () -> tensor<*x!tf.resource<tensor<f32>>>
        "tf.Yield"(%body) : (tensor<*x!tf.resource<tensor<f32>>>) -> ()
      }) { is_stateless = false }
      : (tensor<*x!tf.resource<tensor<f32>>>) -> (tensor<*x!tf.resource<tensor<f32>>>)
    tf_device.return
  }) {cluster_attr = "cluster_attr"} : () -> ()
  return
}

// -----

// Tests that pass reports error on unsupported ops in loop cond.

func @cluster_with_loop() -> () {
@@ -833,6 +858,9 @@ func @cluster_with_caseregion(%arg0: tensor<i32>) -> tensor<4xf32> {
// -----

// Test that the pass can lift resources out of WhileRegion

!tf_ref = type tensor<*x!tf.resource<tensor<f32>>>

// CHECK-LABEL: func @cluster_with_whileregion
func @cluster_with_whileregion() -> () {
  // CHECK: %[[COUNT:.*]] = "tf.Const"() {value = dense<10> : tensor<i32>}
@@ -841,16 +869,17 @@ func @cluster_with_whileregion() -> () {
  // CHECK: %[[CLUSTER:.*]] = "tf_device.cluster"()
  // CHECK: %[[WHILE:.*]]:2 = "tf.WhileRegion"(%[[COUNT]], %[[READ]])
  %0 = "tf.Const"() {value = dense<10> : tensor<i32>} : () -> tensor<i32>
  %1 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<f32>>>
  %unused = "tf.VarHandleOp"() {container = "c", shared_name = "v2"} : () -> tensor<*x!tf.resource<tensor<f32>>>
  %1 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> !tf_ref
  %pass_through = "tf.VarHandleOp"() {container = "c", shared_name = "v2"} : () -> !tf_ref
  %unused = "tf.VarHandleOp"() {container = "c", shared_name = "v3"} : () -> !tf_ref
  "tf_device.cluster"() ( {
    %2:3 = "tf.WhileRegion"(%0, %1, %unused) ({
    %2:4 = "tf.WhileRegion"(%0, %1, %pass_through, %unused) ({
    // CHECK: (%[[CARG0:.+]]: tensor<i32>, %[[CARG1:.+]]: tensor<f32>):
    // CHECK: %[[CAST:.+]] = "tf.Cast"(%[[CARG1]])
    // CHECK: "tf.Less"(%[[CARG0]], %[[CAST]])
    // CHECK: "tf.Yield"
    ^bb0(%carg0: tensor<i32>, %carg1:tensor<*x!tf.resource<tensor<f32>>>, %carg2: tensor<*x!tf.resource<tensor<f32>>>):
      %read0 = "tf.ReadVariableOp"(%carg1) : (tensor<*x!tf.resource<tensor<f32>>>) -> tensor<f32>
    ^bb0(%carg0: tensor<i32>, %carg1: !tf_ref, %carg2: !tf_ref, %carg3: !tf_ref):
      %read0 = "tf.ReadVariableOp"(%carg1) : (!tf_ref) -> tensor<f32>
      %cast = "tf.Cast"(%read0) : (tensor<f32>) -> tensor<i32>
      %cond = "tf.Less"(%carg0, %cast) : (tensor<i32>, tensor<i32>) -> tensor<i1>
      "tf.Yield"(%cond) : (tensor<i1>) -> ()
@@ -861,20 +890,20 @@ func @cluster_with_whileregion() -> () {
    // CHECK-NEXT: %[[DELTA:.*]] = "tf.Const"() {value = dense<-1> : tensor<i32>}
    // CHECK-NEXT: %[[ADD2:.*]] = "tf.AddV2"(%[[BARG0]], %[[DELTA]])
    // CHECK-NEXT: "tf.Yield"(%[[ADD2]], %[[ADD1]])
    ^bb1(%barg0: tensor<i32>, %barg1:tensor<*x!tf.resource<tensor<f32>>>, %barg2: tensor<*x!tf.resource<tensor<f32>>>):
      %read0 = "tf.ReadVariableOp"(%barg1) : (tensor<*x!tf.resource<tensor<f32>>>) -> tensor<f32>
    ^bb1(%barg0: tensor<i32>, %barg1: !tf_ref, %barg2: !tf_ref, %barg3: !tf_ref):
      %read0 = "tf.ReadVariableOp"(%barg1) : (!tf_ref) -> tensor<f32>
      %add0 = "tf.AddV2"(%read0, %read0) : (tensor<f32>, tensor<f32>) -> tensor<f32>
      "tf.AssignVariableOp"(%barg1, %add0) : (tensor<*x!tf.resource<tensor<f32>>>, tensor<f32>) -> ()
      %read1 = "tf.ReadVariableOp"(%barg1) : (tensor<*x!tf.resource<tensor<f32>>>) -> tensor<f32>
      "tf.AssignVariableOp"(%barg1, %add0) : (!tf_ref, tensor<f32>) -> ()
      %read1 = "tf.ReadVariableOp"(%barg1) : (!tf_ref) -> tensor<f32>
      %add1 = "tf.AddV2"(%read1, %read1) : (tensor<f32>, tensor<f32>) -> tensor<f32>
      "tf.AssignVariableOp"(%barg1, %add1) : (tensor<*x!tf.resource<tensor<f32>>>, tensor<f32>) -> ()
      "tf.AssignVariableOp"(%barg1, %add1) : (!tf_ref, tensor<f32>) -> ()
      %constant = "tf.Const"() {value = dense<-1> : tensor<i32>} : () -> tensor<i32>
      %add2 = "tf.AddV2"(%barg0, %constant) : (tensor<i32>, tensor<i32>) -> tensor<i32>
      %id = "tf.Identity"(%barg2) : (tensor<*x!tf.resource<tensor<f32>>>) -> tensor<*x!tf.resource<tensor<f32>>>
      "tf.Yield"(%add2, %barg1, %id) : (tensor<i32>, tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<f32>>>) -> ()
      %id = "tf.Identity"(%barg3) : (!tf_ref) -> !tf_ref
      "tf.Yield"(%add2, %barg1, %pass_through, %id) : (tensor<i32>, !tf_ref, !tf_ref, !tf_ref) -> ()
    }) {device = "", is_stateless = false}
         : (tensor<i32>, tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<f32>>>)
        -> (tensor<i32>, tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<f32>>>)
         : (tensor<i32>, !tf_ref, !tf_ref, !tf_ref)
        -> (tensor<i32>, !tf_ref, !tf_ref, !tf_ref)
    tf_device.return
  }) {cluster_attr = "cluster_attr"} : () -> ()
  // CHECK: tf_device.return %[[WHILE]]#1 : tensor<f32>
@@ -40,6 +40,25 @@ func @read_only_resource(%arg0: tensor<!tf.resource<tensor<i32>>>, %arg1: tensor
  return %2 : tensor<i32>
}

func private @computation_two_args(%arg0: tensor<i32>, %arg1: tensor<i32>)

// CHECK-LABEL: func @partitioned_variable_multiple_users
// CHECK-SAME: ([[ARG0:%.+]]: tensor<!tf.resource<tensor<i32>>>, [[ARG1:%.+]]: tensor<!tf.resource<tensor<i32>>>)
func @partitioned_variable_multiple_users(%arg0: tensor<!tf.resource<tensor<i32>>>, %arg1: tensor<!tf.resource<tensor<i32>>>) {
  // CHECK-DAG: [[READ0:%.+]] = "tf.ReadVariableOp"([[ARG0]])
  // CHECK-DAG: [[READ1:%.+]] = "tf.ReadVariableOp"([[ARG1]])
  // CHECK: [[INPUT0:%.+]] = "tf.TPUPartitionedInput"([[READ0]], [[READ1]])
  // CHECK-DAG: [[READ2:%.+]] = "tf.ReadVariableOp"([[ARG0]])
  // CHECK-DAG: [[READ3:%.+]] = "tf.ReadVariableOp"([[ARG1]])
  // CHECK: [[INPUT1:%.+]] = "tf.TPUPartitionedInput"([[READ2]], [[READ3]])
  %0 = "tf.TPUPartitionedInput"(%arg0, %arg1) {N = 2 : i64, _XlaSharding = "", partition_dim = -1 : i64} : (tensor<!tf.resource<tensor<i32>>>, tensor<!tf.resource<tensor<i32>>>) -> tensor<!tf.resource<tensor<i32>>>
  %1 = "tf.ReadVariableOp"(%0) : (tensor<!tf.resource<tensor<i32>>>) -> tensor<i32>
  %2 = "tf.ReadVariableOp"(%0) : (tensor<!tf.resource<tensor<i32>>>) -> tensor<i32>
  // CHECK: "tf_device.cluster_func"([[INPUT0]], [[INPUT1]])
  "tf_device.cluster_func"(%1, %2) {func = @computation_two_args, use_spmd_for_xla_partitioning = true} : (tensor<i32>, tensor<i32>) -> ()
  return
}

// Tests unsupported cases and IR are not modified.

// CHECK-LABEL: func @no_spmd
@@ -86,16 +105,6 @@ func @resource_read_multiple_users(%arg0: tensor<!tf.resource<tensor<i32>>>, %ar
  return %1 : tensor<i32>
}

// CHECK-LABEL: func @partitioned_variable_multiple_users
// CHECK-SAME: ([[ARG0:%.+]]: tensor<!tf.resource<tensor<i32>>>, [[ARG1:%.+]]: tensor<!tf.resource<tensor<i32>>>) -> tensor<!tf.resource<tensor<i32>>>
func @partitioned_variable_multiple_users(%arg0: tensor<!tf.resource<tensor<i32>>>, %arg1: tensor<!tf.resource<tensor<i32>>>) -> tensor<!tf.resource<tensor<i32>>> {
  // CHECK: "tf.TPUPartitionedInput"([[ARG0]], [[ARG1]])
  %0 = "tf.TPUPartitionedInput"(%arg0, %arg1) {N = 2 : i64, _XlaSharding = "", partition_dim = -1 : i64} : (tensor<!tf.resource<tensor<i32>>>, tensor<!tf.resource<tensor<i32>>>) -> tensor<!tf.resource<tensor<i32>>>
  %1 = "tf.ReadVariableOp"(%0) : (tensor<!tf.resource<tensor<i32>>>) -> tensor<i32>
  %2 = "tf_device.cluster_func"(%1) {func = @computation} : (tensor<i32>) -> tensor<i32>
  return %0 : tensor<!tf.resource<tensor<i32>>>
}

// CHECK-LABEL: func @non_resource_read_input_write_output
func @non_resource_read_input_write_output(%arg0: tensor<i32>) -> tensor<i32> {
  // CHECK-NOT: tf.TPUPartitionedInput
@@ -409,3 +409,13 @@ def LowerXlog1pyOp : BinaryXopyPat<
def LowerXlogyOp : BinaryXopyPat<
    (TF_XlogyOp $x, $y),
    (TF_MulOp $x, (TF_LogOp $y))>;

//===----------------------------------------------------------------------===//
// IsFinite op patterns.
//===----------------------------------------------------------------------===//

def LowerIsFiniteOp : Pat<(TF_IsFiniteOp $x),
                          (TF_EqualOp
                            (TF_SubOp $x, $x),
                            (TF_ConstOp (GetScalarOfType<0> $x)),
                            /*incompatible_shape_error*/ConstBoolAttrTrue)>;
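The TableGen pattern above and the is_finite tests earlier in the diff rely on an IEEE-754 identity: x - x is exactly 0 for every finite x, while inf - inf and NaN - NaN are NaN, and NaN never compares equal to 0. A standalone check of that identity (not part of the commit, and assuming strict IEEE semantics, i.e. no -ffast-math) is shown below.

// Runnable demonstration that Equal(Sub(x, x), 0) behaves like IsFinite.
#include <cassert>
#include <limits>

bool LoweredIsFinite(float x) { return (x - x) == 0.0f; }

int main() {
  assert(LoweredIsFinite(1.5f));
  assert(LoweredIsFinite(0.0f));
  assert(LoweredIsFinite(std::numeric_limits<float>::max()));
  assert(!LoweredIsFinite(std::numeric_limits<float>::infinity()));   // inf - inf = NaN
  assert(!LoweredIsFinite(-std::numeric_limits<float>::infinity()));
  assert(!LoweredIsFinite(std::numeric_limits<float>::quiet_NaN()));  // NaN != 0
  return 0;
}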
@@ -377,20 +377,21 @@ LogicalResult CanonicalizeWhileRegion(TF::WhileRegionOp op) {
  for (OpResult result : llvm::reverse(op.getResults())) {
    if (!IsResource(result)) continue;
    int result_idx = result.getResultNumber();
    auto body_arg = body.front()
                        .getTerminator()
                        ->getOperand(result_idx)
                        .dyn_cast<BlockArgument>();
    if (!body_arg || body_arg.getArgNumber() != result_idx) {
    Operation *yield_op = body.front().getTerminator();
    Value yield_operand = yield_op->getOperand(result_idx);
    Value while_operand = op.getOperand(result_idx);
    Value body_arg = body.getArgument(result_idx);
    Value cond_arg = cond.getArgument(result_idx);
    if (yield_operand != body_arg && yield_operand != while_operand) {
      return op.emitOpError("Result #") << result_idx << " is not tied to arg #"
                                        << result_idx << " of the body";
    }
    body.getArgument(result_idx).replaceAllUsesWith(op.getOperand(result_idx));
    cond.getArgument(result_idx).replaceAllUsesWith(op.getOperand(result_idx));
    body_arg.replaceAllUsesWith(while_operand);
    cond_arg.replaceAllUsesWith(while_operand);
    result.replaceAllUsesWith(while_operand);
    body.front().getTerminator()->eraseOperand(result_idx);
    body.eraseArgument(result_idx);
    cond.eraseArgument(result_idx);
    result.replaceAllUsesWith(op.getOperand(result_idx));
    op.getOperation()->eraseOperand(result_idx);
    can_eliminate.set(result_idx);
  }
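Compared to the old code, which only accepted a yield of the body block argument at the same index, the new check also accepts a yield of the original while operand (a pass-through captured resource). A standalone restatement of that condition, not taken from the commit and using integer IDs as stand-ins for mlir::Value, is sketched below.

// Runnable restatement of the relaxed elimination condition used above.
#include <cassert>

struct ResourceResult {
  int yield_operand;  // Value the body yields for this result index.
  int body_arg;       // Body block argument at the same index.
  int while_operand;  // Operand passed to the WhileRegion at the same index.
};

bool CanEliminate(const ResourceResult &r) {
  return r.yield_operand == r.body_arg || r.yield_operand == r.while_operand;
}

int main() {
  // Body forwards its own argument: eliminable (the previous behavior).
  assert((CanEliminate({/*yield_operand=*/1, /*body_arg=*/1, /*while_operand=*/7})));
  // Body yields the captured while operand directly: now also eliminable.
  assert((CanEliminate({/*yield_operand=*/7, /*body_arg=*/1, /*while_operand=*/7})));
  // Body yields an unrelated resource: the pass reports
  // "Result #i is not tied to arg #i of the body".
  assert((!CanEliminate({/*yield_operand=*/9, /*body_arg=*/1, /*while_operand=*/7})));
  return 0;
}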
@@ -434,7 +435,7 @@ LogicalResult CleanupAndCanonicalize(Operation *parent_op) {
      if (while_region.cond().walk(check_while_cond).wasInterrupted())
        return WalkResult::interrupt();
      // For while region, the body input and output arg should match.
      (void)CanonicalizeWhileRegion(while_region);
      result = CanonicalizeWhileRegion(while_region);
    } else if (auto call = dyn_cast<CallOpInterface>(op)) {
      FuncOp func = dyn_cast<FuncOp>(call.resolveCallable());
      if (!func) return WalkResult::interrupt();
@@ -118,7 +118,7 @@ void PartitionResourceReadsWrites(tf_device::ClusterFuncOp cluster_func) {
    if (!read_var || !read_var.value().hasOneUse()) continue;
    auto partitioned_input = llvm::dyn_cast_or_null<TF::TPUPartitionedInputOp>(
        read_var.resource().getDefiningOp());
    if (!partitioned_input || !partitioned_input.output().hasOneUse() ||
    if (!partitioned_input ||
        !AllResourceTypesHaveSubtypes(partitioned_input.inputs().getTypes()))
      continue;
@@ -135,7 +135,7 @@ void PartitionResourceReadsWrites(tf_device::ClusterFuncOp cluster_func) {
        partitioned_input._XlaShardingAttr());
    operand.set(partitioned_read);
    read_var->erase();
    partitioned_input->erase();
    if (partitioned_input->use_empty()) partitioned_input->erase();
  }
}
@@ -112,7 +112,7 @@ void TPUResourceReadForWritePass::runOnOperation() {

    auto new_cluster_func = builder.create<tf_device::ClusterFuncOp>(
        cluster_func.getLoc(), cluster_func.getResultTypes(), operands,
        cluster_func.getAttrs());
        cluster_func->getAttrs());
    cluster_func.replaceAllUsesWith(new_cluster_func);
    FuncOp func = cluster_func.getFunc();
    Block& block = func.front();
@@ -557,7 +557,7 @@ Status ImporterBase::RemoveBackedges(const Graph& graph) {
  graph_ = absl::make_unique<Graph>(graph.flib_def());
  GraphConstructorOptions opts;
  opts.allow_internal_ops = true;
  opts.add_default_attributes = false;
  opts.add_default_attributes = true;
  TF_RETURN_IF_ERROR(::tensorflow::ConvertGraphDefToGraph(
      opts, std::move(graph_def), graph_.get()));
Some files were not shown because too many files have changed in this diff.