Revert "Remove tensorexpr tests (#158928)"
This reverts commit517eebc1dd. Reverted https://github.com/pytorch/pytorch/pull/158928 on behalf of https://github.com/ZainRizvi due to Sorry but this breaks trunk test_jit_fuser_te.py::TestNNCOpInfoCPU::test_nnc_correctness_frac_cpu_bfloat16 [GH job link](https://github.com/pytorch/pytorch/actions/runs/16534544469/job/46768022799) [HUD commit link](517eebc1dd) ([comment](https://github.com/pytorch/pytorch/pull/158928#issuecomment-3122158944))
This commit is contained in:
parent
e2b2685f84
commit
f62772f365
|
|
@@ -50,6 +50,9 @@ if [[ ${BUILD_ENVIRONMENT} == *"parallelnative"* ]]; then
  export ATEN_THREADING=NATIVE
fi

# Enable LLVM dependency for TensorExpr testing
export USE_LLVM=/opt/llvm
export LLVM_DIR=/opt/llvm/lib/cmake/llvm

if ! which conda; then
  # In ROCm CIs, we are doing cross compilation on build machines with
@@ -189,6 +192,7 @@ if [[ "$BUILD_ENVIRONMENT" == *-clang*-asan* ]]; then
  export USE_ASAN=1
  export REL_WITH_DEB_INFO=1
  export UBSAN_FLAGS="-fno-sanitize-recover=all"
  unset USE_LLVM
fi

if [[ "${BUILD_ENVIRONMENT}" == *no-ops* ]]; then
@@ -1037,10 +1037,20 @@ test_libtorch_api() {
    mkdir -p $TEST_REPORTS_DIR

    OMP_NUM_THREADS=2 TORCH_CPP_TEST_MNIST_PATH="${MNIST_DIR}" "$TORCH_BIN_DIR"/test_api --gtest_filter='-IMethodTest.*' --gtest_output=xml:$TEST_REPORTS_DIR/test_api.xml
    "$TORCH_BIN_DIR"/test_tensorexpr --gtest_output=xml:$TEST_REPORTS_DIR/test_tensorexpr.xml
  else
    # Exclude IMethodTest that relies on torch::deploy, which will instead be run in test_deploy
    OMP_NUM_THREADS=2 TORCH_CPP_TEST_MNIST_PATH="${MNIST_DIR}" python test/run_test.py --cpp --verbose -i cpp/test_api -k "not IMethodTest"

    # On s390x, pytorch is built without llvm.
    # Even if it were built with llvm, llvm currently doesn't support the features used on s390x, and
    # the test fails with errors like:
    #   JIT session error: Unsupported target machine architecture in ELF object pytorch-jitted-objectbuffer
    #   unknown file: Failure
    #   C++ exception with description "valOrErr INTERNAL ASSERT FAILED at "/var/lib/jenkins/workspace/torch/csrc/jit/tensorexpr/llvm_jit.h":34, please report a bug to PyTorch. Unexpected failure in LLVM JIT: Failed to materialize symbols: { (main, { func }) }
    if [[ "${BUILD_ENVIRONMENT}" != *s390x* ]]; then
      python test/run_test.py --cpp --verbose -i cpp/test_tensorexpr
    fi
  fi

  # quantization is not fully supported on s390x yet
@@ -1,8 +1,7 @@
#include <gtest/gtest.h>

#include <ATen/ATen.h>
#include <ATen/Parallel.h>
#include <c10/util/irange.h>
#include <test/cpp/tensorexpr/test_base.h>
#include <thread>

@@ -10,7 +9,7 @@
// numbers of threads set and also whether the scheduler
// will throw an exception when multiple threads call
// their first parallel construct.
static void test(int given_num_threads) {
void test(int given_num_threads) {
  auto t = at::ones({1000 * 1000}, at::CPU(at::kFloat));
  ASSERT_TRUE(given_num_threads >= 0);
  ASSERT_EQ(at::get_num_threads(), given_num_threads);
@@ -20,7 +19,7 @@ static void test(int given_num_threads) {
  }
}

TEST(ThreadInitTest, ThreadInit) {
int main() {
  at::init_num_threads();

  at::set_num_threads(4);
@@ -33,11 +32,13 @@ TEST(ThreadInitTest, ThreadInit) {

#if !AT_PARALLEL_NATIVE
  at::set_num_threads(5);
  ASSERT_EQ(at::get_num_threads(), 5);
  ASSERT_TRUE(at::get_num_threads() == 5);
#endif

  // test inter-op settings
  at::set_num_interop_threads(5);
  ASSERT_EQ(at::get_num_interop_threads(), 5);
  ASSERT_ANY_THROW(at::set_num_interop_threads(6));

  return 0;
}
@@ -1346,6 +1346,10 @@ if(BUILD_TEST)
  add_subdirectory(${TORCH_ROOT}/test/cpp/jit ${CMAKE_BINARY_DIR}/test_jit)
  add_subdirectory(${TORCH_ROOT}/test/cpp/nativert ${CMAKE_BINARY_DIR}/test_nativert)
  add_subdirectory(${TORCH_ROOT}/test/inductor ${CMAKE_BINARY_DIR}/test_inductor)
  add_subdirectory(
      ${TORCH_ROOT}/test/cpp/tensorexpr
      ${CMAKE_BINARY_DIR}/test_tensorexpr
  )
  if(USE_DISTRIBUTED)
    add_subdirectory(${TORCH_ROOT}/test/cpp/c10d ${CMAKE_BINARY_DIR}/test_cpp_c10d)
    if(NOT WIN32)
test/cpp/tensorexpr/CMakeLists.txt (new file, 83 lines)
@@ -0,0 +1,83 @@
set(TENSOREXPR_TEST_ROOT ${TORCH_ROOT}/test/cpp/tensorexpr)

set(TENSOREXPR_TEST_SRCS
  ${TENSOREXPR_TEST_ROOT}/test_approx.cpp
  ${TENSOREXPR_TEST_ROOT}/test_aten.cpp
  ${TENSOREXPR_TEST_ROOT}/test_boundsinference.cpp
  ${TENSOREXPR_TEST_ROOT}/test_conv.cpp
  ${TENSOREXPR_TEST_ROOT}/test_cpp_codegen.cpp
  ${TENSOREXPR_TEST_ROOT}/test_dynamic_shapes.cpp
  ${TENSOREXPR_TEST_ROOT}/test_expr.cpp
  ${TENSOREXPR_TEST_ROOT}/test_external_calls.cpp
  ${TENSOREXPR_TEST_ROOT}/test_graph_opt.cpp
  ${TENSOREXPR_TEST_ROOT}/test_ir_printer.cpp
  ${TENSOREXPR_TEST_ROOT}/test_ir_verifier.cpp
  ${TENSOREXPR_TEST_ROOT}/test_kernel.cpp
  ${TENSOREXPR_TEST_ROOT}/test_loopnest.cpp
  ${TENSOREXPR_TEST_ROOT}/test_memdependency.cpp
  ${TENSOREXPR_TEST_ROOT}/test_ops.cpp
  ${TENSOREXPR_TEST_ROOT}/test_quantization.cpp
  ${TENSOREXPR_TEST_ROOT}/test_memplanning.cpp
  ${TENSOREXPR_TEST_ROOT}/test_reductions.cpp
  ${TENSOREXPR_TEST_ROOT}/test_registerizer.cpp
  ${TENSOREXPR_TEST_ROOT}/test_simplify.cpp
  ${TENSOREXPR_TEST_ROOT}/test_te_fuser_pass.cpp
  ${TENSOREXPR_TEST_ROOT}/test_type.cpp
  ${TENSOREXPR_TEST_ROOT}/test_type_specializations.cpp
)

if(USE_CUDA)
  list(APPEND TENSOREXPR_TEST_SRCS ${TENSOREXPR_TEST_ROOT}/test_cuda.cpp)
endif()

if(USE_LLVM AND LLVM_FOUND)
  list(APPEND TENSOREXPR_TEST_SRCS ${TENSOREXPR_TEST_ROOT}/test_llvm.cpp)
endif()

add_executable(test_tensorexpr
  ${TORCH_ROOT}/test/cpp/common/main.cpp
  ${TENSOREXPR_TEST_ROOT}/padded_buffer.cpp
  ${TENSOREXPR_TEST_SRCS})

target_link_libraries(test_tensorexpr PRIVATE torch gtest_main)
target_include_directories(test_tensorexpr PRIVATE ${ATen_CPU_INCLUDE})
target_compile_definitions(test_tensorexpr PRIVATE USE_GTEST)

add_executable(tutorial_tensorexpr ${TENSOREXPR_TEST_ROOT}/tutorial.cpp)
target_link_libraries(tutorial_tensorexpr PRIVATE torch)
target_include_directories(tutorial_tensorexpr PRIVATE ${ATen_CPU_INCLUDE})

# The test case depends on the xnnpack header, which in turn depends on the
# pthreadpool header. For some build environments we need to add the
# dependency explicitly.
if(USE_PTHREADPOOL)
  target_link_libraries(test_tensorexpr PRIVATE pthreadpool_interface)
endif()
if(USE_CUDA)
  target_compile_definitions(test_tensorexpr PRIVATE USE_CUDA)
  target_compile_definitions(tutorial_tensorexpr PRIVATE USE_CUDA)
elseif(USE_ROCM)
  target_link_libraries(test_tensorexpr PRIVATE
    hiprtc::hiprtc
    hip::amdhip64
    ${TORCH_CUDA_LIBRARIES})
  target_compile_definitions(test_tensorexpr PRIVATE USE_ROCM)

  target_link_libraries(tutorial_tensorexpr PRIVATE
    hiprtc::hiprtc
    hip::amdhip64
    ${TORCH_CUDA_LIBRARIES})
  target_compile_definitions(tutorial_tensorexpr PRIVATE USE_ROCM)
endif()

if(INSTALL_TEST)
  set_target_properties(test_tensorexpr PROPERTIES INSTALL_RPATH "${CMAKE_INSTALL_RPATH}:${_rpath_portable_origin}/../lib")
  install(TARGETS test_tensorexpr DESTINATION bin)
  set_target_properties(tutorial_tensorexpr PROPERTIES INSTALL_RPATH "${CMAKE_INSTALL_RPATH}:${_rpath_portable_origin}/../lib")
  install(TARGETS tutorial_tensorexpr DESTINATION bin)
  # Install PDB files for MSVC builds
  if(MSVC AND BUILD_SHARED_LIBS)
    install(FILES $<TARGET_PDB_FILE:test_tensorexpr> DESTINATION bin OPTIONAL)
    install(FILES $<TARGET_PDB_FILE:tutorial_tensorexpr> DESTINATION bin OPTIONAL)
  endif()
endif()
test/cpp/tensorexpr/README.md (new file, 55 lines)
@@ -0,0 +1,55 @@
# TensorExpr C++ Tests

## How to add a new test
First, create a new test file. Test files should be placed in this
directory, with a name that starts with `test_`, like `test_foo.cpp`.

Here is an example test file you can copy-paste.
```cpp
#include <test/cpp/tensorexpr/test_base.h>

// Tests go in torch::jit
namespace torch {
namespace jit {

// 1. Test cases are void() functions.
// 2. They start with the prefix `test`
void testCaseOne() {
  // ...
}

void testCaseTwo() {
  // ...
}
}
}
```

Then, register your test in `tests.h`:
```cpp
// Add to TH_FORALL_TESTS_CUDA instead for CUDA-requiring tests
#define TH_FORALL_TESTS(_) \
  _(ADFormulas)            \
  _(Attributes)            \
  ...
  _(CaseOne)  // note that the `test` prefix is omitted.
  _(CaseTwo)
```

We glob all the test files together in `CMakeLists.txt` so that you don't
have to edit it every time you add a test. Unfortunately, this means that in
order to get the build to pick up your new test file, you need to re-run
cmake:
```bash
CMAKE_FRESH=1 python setup.py build
```

## How do I run the tests?
The following commands assume you are in PyTorch root.

```bash
# (re)build the test binary
ninja build/bin/test_tensorexpr
# run
build/bin/test_tensorexpr --gtest_filter='glob_style_filter*'
```
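Worth noting: the `tests.h` registration mechanism described in the README reflects an older convention; the test files added elsewhere in this commit are compiled with `USE_GTEST` and define cases with the `TEST(...)` macro directly. As a rough, hypothetical sketch of that newer style (the file name `test_foo.cpp` and the test names are illustrative, not part of the commit):

```cpp
// Hypothetical test/cpp/tensorexpr/test_foo.cpp, sketching the gtest-style
// convention used by the test files in this diff.
#include <gtest/gtest.h>
#include <test/cpp/tensorexpr/test_base.h>

namespace torch {
namespace jit {

// No registration in tests.h is needed; gtest discovers TEST blocks itself.
TEST(Foo, CaseOne) {
  ASSERT_EQ(1 + 1, 2);
}

} // namespace jit
} // namespace torch
```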
test/cpp/tensorexpr/gtest_assert_float_eq.h (new file, 119 lines)
@@ -0,0 +1,119 @@
#pragma once

#include <cmath>
#include <cstdint>
// Copyright 2005, Google Inc.
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
//     * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
//     * Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following disclaimer
// in the documentation and/or other materials provided with the
// distribution.
//     * Neither the name of Google Inc. nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
//
// The Google C++ Testing and Mocking Framework (Google Test)
//
// This header file declares functions and macros used internally by
// Google Test. They are subject to change without notice.

using Bits = uint32_t;

// this avoids the "dereferencing type-punned pointer
// will break strict-aliasing rules" error
union Float {
  float float_;
  Bits bits_;
};

// # of bits in a number.
static const size_t kBitCount = 8 * sizeof(Bits);
// The mask for the sign bit.
static const Bits kSignBitMask = static_cast<Bits>(1) << (kBitCount - 1);

// GOOGLETEST_CM0001 DO NOT DELETE

// Converts an integer from the sign-and-magnitude representation to
// the biased representation. More precisely, let N be 2 to the
// power of (kBitCount - 1), an integer x is represented by the
// unsigned number x + N.
//
// For instance,
//
//   -N + 1 (the most negative number representable using
//          sign-and-magnitude) is represented by 1;
//   0      is represented by N; and
//   N - 1  (the biggest number representable using
//          sign-and-magnitude) is represented by 2N - 1.
//
// Read http://en.wikipedia.org/wiki/Signed_number_representations
// for more details on signed number representations.
static Bits SignAndMagnitudeToBiased(const Bits& sam) {
  if (kSignBitMask & sam) {
    // sam represents a negative number.
    return ~sam + 1;
  } else {
    // sam represents a positive number.
    return kSignBitMask | sam;
  }
}

// Given two numbers in the sign-and-magnitude representation,
// returns the distance between them as an unsigned number.
static Bits DistanceBetweenSignAndMagnitudeNumbers(
    const Bits& sam1,
    const Bits& sam2) {
  const Bits biased1 = SignAndMagnitudeToBiased(sam1);
  const Bits biased2 = SignAndMagnitudeToBiased(sam2);
  return (biased1 >= biased2) ? (biased1 - biased2) : (biased2 - biased1);
}

// How many ULP's (Units in the Last Place) we want to tolerate when
// comparing two numbers. The larger the value, the more error we
// allow. A 0 value means that two numbers must be exactly the same
// to be considered equal.
//
// The maximum error of a single floating-point operation is 0.5
// units in the last place. On Intel CPU's, all floating-point
// calculations are done with 80-bit precision, while double has 64
// bits. Therefore, 4 should be enough for ordinary use.
//
// See the following article for more details on ULP:
// http://randomascii.wordpress.com/2012/02/25/comparing-floating-point-numbers-2012-edition/
static const size_t kMaxUlps = 4;

// Returns true if and only if this number is at most kMaxUlps ULP's away
// from rhs. In particular, this function:
//
//   - returns false if either number is (or both are) NAN.
//   - treats really large numbers as almost equal to infinity.
//   - thinks +0.0 and -0.0 are 0 ULP's apart.
inline bool AlmostEquals(float lhs, float rhs) {
  // The IEEE standard says that any comparison operation involving
  // a NAN must return false.
  if (std::isnan(lhs) || std::isnan(rhs))
    return false;

  Float l = {lhs};
  Float r = {rhs};

  return DistanceBetweenSignAndMagnitudeNumbers(l.bits_, r.bits_) <= kMaxUlps;
}
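As a side note on how `AlmostEquals` works: it maps each float's sign-and-magnitude bit pattern to a biased unsigned integer, so the integer distance between two patterns equals the number of representable floats between the two values. A standalone sketch (not part of the commit) that reproduces the conversion with `memcpy` instead of the union:

```cpp
// Standalone sketch: the biased-representation trick behind AlmostEquals.
#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstring>

static uint32_t Biased(float f) {
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits)); // well-defined type punning
  const uint32_t kSignMask = uint32_t{1} << 31;
  // Negative floats map below kSignMask, non-negative floats above it.
  return (bits & kSignMask) ? ~bits + 1 : kSignMask | bits;
}

int main() {
  float a = 1.0f;
  float b = std::nextafter(a, 2.0f); // exactly 1 ULP above a
  assert(Biased(b) - Biased(a) == 1); // so AlmostEquals(a, b) holds for kMaxUlps >= 1
  return 0;
}
```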
test/cpp/tensorexpr/padded_buffer.cpp (new file, 37 lines)
@@ -0,0 +1,37 @@
#include "test/cpp/tensorexpr/padded_buffer.h"

#include <c10/util/Logging.h>
#include <c10/util/irange.h>
#include <sstream>

namespace torch {
namespace jit {
namespace tensorexpr {

int PaddedBufferBase::Index(const std::vector<int>& indices) const {
  TORCH_DCHECK_EQ(dims_.size(), indices.size());
  int total_index = 0;
  for (const auto i : c10::irange(dims_.size())) {
    total_index += indices[i] * strides_[i];
  }
  return total_index;
}

PaddedBufferBase::PaddedBufferBase(
    const std::vector<int>& dims,
    // NOLINTNEXTLINE(modernize-pass-by-value)
    const std::string& name)
    : dims_(dims), name_(name), strides_(dims.size()) {
  for (int i = (int)dims.size() - 1; i >= 0; --i) {
    if (i == (int)dims.size() - 1) {
      strides_[i] = 1;
    } else {
      strides_[i] = strides_[i + 1] * dims[i + 1];
    }
  }
  total_size_ = strides_[0] * dims[0];
}

} // namespace tensorexpr
} // namespace jit
} // namespace torch
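The constructor above computes standard row-major strides: the innermost stride is 1 and each outer stride is the product of all dimensions to its right. A standalone sketch (not part of the commit) with concrete numbers:

```cpp
// Standalone sketch of the row-major stride computation in PaddedBufferBase.
#include <cassert>
#include <vector>

int main() {
  std::vector<int> dims = {2, 3, 4};
  std::vector<int> strides(dims.size());
  for (int i = (int)dims.size() - 1; i >= 0; --i) {
    strides[i] = (i == (int)dims.size() - 1) ? 1 : strides[i + 1] * dims[i + 1];
  }
  assert(strides == (std::vector<int>{12, 4, 1}));
  // Index({1, 2, 3}) = 1*12 + 2*4 + 3*1 = 23, the last element of a 2x3x4 buffer.
  assert(1 * strides[0] + 2 * strides[1] + 3 * strides[2] == 23);
  return 0;
}
```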
test/cpp/tensorexpr/padded_buffer.h (new file, 242 lines)
@@ -0,0 +1,242 @@
#pragma once

#include <string>
#include <vector>

#include <c10/util/irange.h>
#include "torch/csrc/jit/tensorexpr/eval.h"

namespace torch {
namespace jit {
namespace tensorexpr {

template <typename T>
struct DefaultPaddedValue;

template <>
struct DefaultPaddedValue<int> {
  static const int kValue = static_cast<int>(0xDEADBEEF);
};

template <>
struct DefaultPaddedValue<int8_t> {
  static const int8_t kValue = static_cast<int8_t>(0xBE);
};

template <>
struct DefaultPaddedValue<uint8_t> {
  static const uint8_t kValue = static_cast<uint8_t>(0xBE);
};

template <>
struct DefaultPaddedValue<int16_t> {
  static const int16_t kValue = static_cast<int16_t>(0xBEEF);
};

template <>
struct DefaultPaddedValue<int64_t> {
  static const int64_t kValue = static_cast<int64_t>(0xDEADBEEF);
};

template <>
struct DefaultPaddedValue<float> {
  static constexpr float kValue = 0.1357;
};

template <>
struct DefaultPaddedValue<at::Half> {
  // at::Half ctor isn't constexpr, so just fill it with bits.
  static constexpr uint16_t kValue = 1357;
};

template <>
struct DefaultPaddedValue<double> {
  static constexpr double kValue = 0.1357;
};

// A concrete base to be used in PaddedBase.
class PaddedBufferBase {
 public:
  const std::string& name() const {
    return name_;
  }

  int size() const {
    return total_size_;
  }

  int raw_size() const {
    return total_size_ + 2 * kPaddingSize;
  }

  virtual ~PaddedBufferBase() {}

 protected:
  explicit PaddedBufferBase(
      const std::vector<int>& dims,
      const std::string& name);
  int Index(const std::vector<int>& indices) const;

  std::vector<int> dims_;
  std::string name_;
  std::vector<int> strides_;
  int total_size_; // total number of useful elements, does not include the
                   // paddings
  static constexpr int kPaddingSize = 64;
};

// A padded buffer with watermarks for testing.
// The buffer carries padded watermarks on both sides to catch potential
// out-of-bounds writes. For read-only data that are not supposed to change, it
// can also make a backup and be compared later.
template <typename T>
class PaddedBuffer : public PaddedBufferBase {
 public:
  PaddedBuffer(int d0, const std::string& name = "")
      : PaddedBuffer(std::vector<int>({d0}), name) {}
  PaddedBuffer(int d0, int d1, const std::string& name = "")
      : PaddedBuffer(std::vector<int>({d0, d1}), name) {}
  PaddedBuffer(int d0, int d1, int d2, const std::string& name = "")
      : PaddedBuffer(std::vector<int>({d0, d1, d2}), name) {}
  PaddedBuffer(int d0, int d1, int d2, int d3, const std::string& name = "")
      : PaddedBuffer(std::vector<int>({d0, d1, d2, d3}), name) {}
  PaddedBuffer(const std::vector<int>& dims, const std::string& name = "")
      : PaddedBufferBase(dims, name) {
    data_.resize(total_size_ + 2 * kPaddingSize, kPaddingValue);
  }
  PaddedBuffer(const PaddedBuffer& other, const std::string& name)
      : PaddedBuffer(other) {
    this->name_ = name;
  }

  T* data() {
    return data_.data() + kPaddingSize;
  }
  const T* data() const {
    return const_cast<PaddedBuffer*>(this)->data();
  }
  T* raw_data() {
    return data_.data();
  }
  const T* raw_data() const {
    return const_cast<PaddedBuffer*>(this)->raw_data();
  }
  T& operator()(int i0) {
    // There is a bit performance impact with forming a vector here. But this
    // data structure is for testing only, and not performance critical.
    return this->operator()(std::vector<int>({i0}));
  }
  const T& operator()(int i0) const {
    return const_cast<PaddedBuffer*>(this)->operator()(i0);
  }
  T& operator()(int i0, int i1) {
    return this->operator()(std::vector<int>({i0, i1}));
  }
  const T& operator()(int i0, int i1) const {
    return const_cast<PaddedBuffer*>(this)->operator()(i0, i1);
  }
  T& operator()(int i0, int i1, int i2) {
    return this->operator()(std::vector<int>({i0, i1, i2}));
  }
  const T& operator()(int i0, int i1, int i2) const {
    return const_cast<PaddedBuffer*>(this)->operator()(i0, i1, i2);
  }
  T& operator()(int i0, int i1, int i2, int i3) {
    return this->operator()(std::vector<int>({i0, i1, i2, i3}));
  }
  const T& operator()(int i0, int i1, int i2, int i3) const {
    return const_cast<PaddedBuffer*>(this)->operator()(i0, i1, i2, i3);
  }
  T& operator()(const std::vector<int>& indices) {
    return data_[kPaddingSize + Index(indices)];
  }
  const T& operator()(const std::vector<int>& indices) const {
    return const_cast<PaddedBuffer*>(this)->operator()(indices);
  }

  template <typename U>
  friend void ExpectAllNear(
      const PaddedBuffer<U>& v1,
      const PaddedBuffer<U>& v2,
      float abs_error);
  template <typename U>
  friend void ExpectAllEqual(
      const PaddedBuffer<U>& v1,
      const PaddedBuffer<U>& v2);
  void Backup() {
    backup_data_ = data_;
  }

  // Verify the watermarks in the paddings are intact.
  void ValidateWatermark() const {
    for (const auto i : c10::irange(kPaddingSize)) {
      ASSERT_EQ(data_[i], kPaddingValue);
      ASSERT_EQ(data_[i + total_size_ + kPaddingSize], kPaddingValue);
    }
  }

  void CheckBackup() const {
    ValidateWatermark();
    DCHECK(backup_data_.size() == data_.size())
        << "Please make sure you have called Backup() before calling CheckBackup()";
    for (const auto i : c10::irange(total_size_)) {
      ASSERT_EQ(data_[i + kPaddingSize], backup_data_[i + kPaddingSize]);
    }
  }

 private:
  std::vector<T> data_;
  std::vector<T> backup_data_;
  T kPaddingValue = DefaultPaddedValue<T>::kValue;
};

template <typename T>
inline CodeGen::CallArg::CallArg(const PaddedBuffer<T>& buffer)
    : data_(const_cast<T*>(buffer.data())) {}

template <typename T>
std::string CompareErrorMsg(
    const PaddedBuffer<T>& v1,
    const PaddedBuffer<T>& v2,
    int index) {
  std::ostringstream oss;
  oss << "index: " << index << ", v1: (" << v1.name() << ", " << v1(index)
      << ")"
      << ", v2: (" << v2.name() << ", " << v2(index) << ")";
  return oss.str();
}

template <typename T>
void ExpectAllEqual(const PaddedBuffer<T>& f1, const PaddedBuffer<T>& f2) {
  const std::vector<T>& v1 = f1.data_;
  const std::vector<T>& v2 = f2.data_;
  const int kPaddingSize = f1.kPaddingSize;
  const int total_size = f1.total_size_;
  ASSERT_EQ(v1.size(), v2.size());
  f1.ValidateWatermark();
  f2.ValidateWatermark();
  for (const auto i : c10::irange(total_size)) {
    ASSERT_EQ(v1[kPaddingSize + i], v2[kPaddingSize + i]);
  }
}

template <typename T>
void ExpectAllNear(
    const PaddedBuffer<T>& f1,
    const PaddedBuffer<T>& f2,
    float abs_error) {
  const std::vector<T>& v1 = f1.data_;
  const std::vector<T>& v2 = f2.data_;
  const int kPaddingSize = f1.kPaddingSize;
  const int total_size = f1.total_size_;
  ASSERT_EQ(v1.size(), v2.size());
  f1.ValidateWatermark();
  f2.ValidateWatermark();
  for (const auto i : c10::irange(total_size)) {
    ASSERT_NEAR(v1[kPaddingSize + i], v2[kPaddingSize + i], abs_error);
  }
}

} // namespace tensorexpr
} // namespace jit
} // namespace torch
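To make the watermark idea concrete, here is an illustrative usage sketch (hypothetical code, assuming a PyTorch build where this header is available; the buffer name and sizes are made up): writes within `size()` elements leave the padding untouched, while an out-of-bounds write lands in the padding and would trip `ValidateWatermark()`.

```cpp
// Illustrative only: how the watermark padding catches an out-of-bounds write.
#include <test/cpp/tensorexpr/padded_buffer.h>

using torch::jit::tensorexpr::PaddedBuffer;

void watermark_demo() {
  PaddedBuffer<float> buf(4, 4, "buf"); // 4x4 buffer, padded on both sides
  buf(3, 3) = 1.0f;                     // in-bounds write is fine
  buf.ValidateWatermark();              // passes: padding untouched
  buf.data()[16] = 2.0f;                // one element past the 16-element payload...
  // buf.ValidateWatermark();           // ...would now fail an ASSERT_EQ on the rear padding
}
```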
test/cpp/tensorexpr/test_approx.cpp (new file, 96 lines)
@@ -0,0 +1,96 @@
#ifdef TORCH_ENABLE_LLVM

#include <gtest/gtest.h>
#include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
#include <torch/csrc/jit/tensorexpr/llvm_codegen.h>
#include <torch/csrc/jit/tensorexpr/loopnest.h>
#include <torch/csrc/jit/tensorexpr/tensor.h>
#include <torch/torch.h>
#include <cstring>

using namespace torch::indexing;
namespace te = torch::jit::tensorexpr;

static void vectorize(te::LoopNest* ln, te::Tensor target, int width) {
  auto loops = ln->getLoopStmtsFor(target);
  te::ForPtr inner, tail;
  ln->splitWithTail(loops[0], width, &inner, &tail);
  ASSERT_TRUE(te::LoopNest::vectorize(inner));
}

std::string diffs(const at::Tensor& a, const at::Tensor& b) {
  auto diff = torch::abs(a.flatten() - b.flatten());
  auto count_diffs = torch::sum(diff > 0.f);
  auto greatest_diff_index = torch::argmax(diff);
  std::stringstream ss;
  ss << "Found " << count_diffs << " unequal element(s). "
     << "The greatest difference was " << diff.index({greatest_diff_index})
     << " at index " << greatest_diff_index;
  return ss.str();
}

TEST(Approx, log_vml) {
  te::VarHandle N("N", te::kInt);
  te::BufHandle A("A", {N}, te::kFloat);
  te::Tensor B = te::Compute(
      "B", {N}, [&](const te::VarHandle& i) { return log_vml(A.load(i)); });

  te::LoopNest ln({B});
  ln.prepareForCodegen();
  vectorize(&ln, B, 8);
  te::StmtPtr s = ln.root_stmt();
  s = te::IRSimplifier::simplify(s);
  te::LLVMCodeGen cg(s, {A, B, N});

  auto eps = std::numeric_limits<float>::epsilon();
  auto test = [&](const at::Tensor& A_t) {
    at::Tensor B_ref = at::log(A_t);
    at::Tensor B_t = at::empty_like(A_t);
    auto ap = A_t.data_ptr<float>();
    auto bp = B_t.data_ptr<float>();
    cg.call({ap, bp, A_t.numel()});
    // Results should be bit-identical.
    ASSERT_TRUE(torch::allclose(
        B_t, B_ref, /*rtol=*/eps, /*atol=*/0.0f, /*equal_nan=*/true))
        << "Input[:8]\n"
        << A_t.index({Slice(0, 8)}) << "\n"
        << "Test[:8]\n"
        << B_t.index({Slice(0, 8)}) << "\n"
        << "Ref[:8]\n"
        << B_ref.index({Slice(0, 8)}) << diffs(B_t, B_ref);
  };

  // Generate every single-precision FP value in [1.0, 2.0).
  at::Tensor A_t = torch::arange(1.0f, 2.0f, eps);
  ASSERT_EQ(A_t.numel(), 1 << 23);

  test(A_t);

  test(A_t * 2.0f);
  test(A_t * 0.5f);

  test(A_t * 4.0f);
  test(A_t * 0.25f);

  test(A_t * powf(2.0f, 16));
  test(A_t * powf(2.0f, -16));

  test(A_t * powf(2.0f, 126));
  test(A_t * powf(2.0f, -126));

  test(torch::full({32}, INFINITY));
  test(torch::full({32}, NAN));

  auto min = std::numeric_limits<float>::min();
  auto denorm_min = std::numeric_limits<float>::denorm_min();

  // Denormals aren't bit precise, because sleef isn't bit-precise either.
  A_t = torch::arange(0.0f, min, denorm_min);
  ASSERT_EQ(A_t.numel(), 1 << 23);
  auto B_ref = at::log(A_t);
  auto B_t = at::empty_like(B_ref);
  cg.call({A_t.data_ptr<float>(), B_t.data_ptr<float>(), A_t.numel()});
  ASSERT_TRUE(torch::allclose(B_t, B_ref));
}

#endif // TORCH_ENABLE_LLVM
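The `ASSERT_EQ(A_t.numel(), 1 << 23)` above rests on a property of IEEE-754 single precision: every float in [1.0, 2.0) shares one exponent, so consecutive values are spaced exactly one machine epsilon (2^-23) apart, giving 2^23 values in the interval. A standalone check (not part of the commit):

```cpp
// Standalone check of why arange(1.0f, 2.0f, eps) has exactly 1 << 23 elements.
#include <cassert>
#include <cmath>
#include <limits>

int main() {
  const float eps = std::numeric_limits<float>::epsilon();
  assert(eps == std::ldexp(1.0f, -23));          // eps is 2^-23
  assert((2.0f - 1.0f) / eps == float(1 << 23)); // 2^23 uniform steps in [1, 2)
  return 0;
}
```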
test/cpp/tensorexpr/test_aten.cpp (new file, 1068 lines)
File diff suppressed because it is too large.
test/cpp/tensorexpr/test_base.h (new file, 89 lines)
@@ -0,0 +1,89 @@
#pragma once

#if defined(USE_GTEST)
#include <gtest/gtest.h>
#include <test/cpp/common/support.h>
#else
#include <cmath>
#include "c10/util/Exception.h"
#include "test/cpp/tensorexpr/gtest_assert_float_eq.h"
#define ASSERT_EQ(x, y, ...) TORCH_INTERNAL_ASSERT((x) == (y), __VA_ARGS__)
#define ASSERT_FLOAT_EQ(x, y, ...) \
  TORCH_INTERNAL_ASSERT(AlmostEquals((x), (y)), __VA_ARGS__)
#define ASSERT_NE(x, y, ...) TORCH_INTERNAL_ASSERT((x) != (y), __VA_ARGS__)
#define ASSERT_GT(x, y, ...) TORCH_INTERNAL_ASSERT((x) > (y), __VA_ARGS__)
#define ASSERT_GE(x, y, ...) TORCH_INTERNAL_ASSERT((x) >= (y), __VA_ARGS__)
#define ASSERT_LT(x, y, ...) TORCH_INTERNAL_ASSERT((x) < (y), __VA_ARGS__)
#define ASSERT_LE(x, y, ...) TORCH_INTERNAL_ASSERT((x) <= (y), __VA_ARGS__)

#define ASSERT_NEAR(x, y, a, ...) \
  TORCH_INTERNAL_ASSERT(std::fabs((x) - (y)) < (a), __VA_ARGS__)

#define ASSERT_TRUE TORCH_INTERNAL_ASSERT
#define ASSERT_FALSE(x) ASSERT_TRUE(!(x))
#define ASSERT_THROWS_WITH(statement, substring)                          \
  try {                                                                   \
    (void)statement;                                                      \
    ASSERT_TRUE(false);                                                   \
  } catch (const std::exception& e) {                                     \
    ASSERT_NE(std::string(e.what()).find(substring), std::string::npos);  \
  }
#define ASSERT_ANY_THROW(statement)     \
  {                                     \
    bool threw = false;                 \
    try {                               \
      (void)statement;                  \
    } catch (const std::exception& e) { \
      threw = true;                     \
    }                                   \
    ASSERT_TRUE(threw);                 \
  }

#endif // defined(USE_GTEST)
#include <string>
#include <vector>

namespace torch {
namespace jit {
namespace tensorexpr {

template <typename U, typename V>
void ExpectAllNear(
    const std::vector<U>& v1,
    const std::vector<U>& v2,
    V threshold,
    const std::string& name = "") {
  ASSERT_EQ(v1.size(), v2.size());
  for (size_t i = 0; i < v1.size(); i++) {
    ASSERT_NEAR(v1[i], v2[i], threshold);
  }
}

template <typename U, typename V>
void ExpectAllNear(
    const std::vector<U>& vec,
    const U& val,
    V threshold,
    const std::string& name = "") {
  for (size_t i = 0; i < vec.size(); i++) {
    ASSERT_NEAR(vec[i], val, threshold);
  }
}

template <typename T>
static void assertAllEqual(const std::vector<T>& vec, const T& val) {
  for (auto const& elt : vec) {
    ASSERT_EQ(elt, val);
  }
}

template <typename T>
static void assertAllEqual(const std::vector<T>& v1, const std::vector<T>& v2) {
  ASSERT_EQ(v1.size(), v2.size());
  for (size_t i = 0; i < v1.size(); ++i) {
    ASSERT_EQ(v1[i], v2[i]);
  }
}
} // namespace tensorexpr
} // namespace jit
} // namespace torch
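The non-gtest branch above re-implements a handful of `ASSERT_*` macros on top of `TORCH_INTERNAL_ASSERT`, so the same test sources compile with or without gtest. A hedged usage sketch (hypothetical code, assuming a PyTorch build with this header and `USE_GTEST` not defined) of the throw-checking macros:

```cpp
// Illustrative only: exercising the fallback macros from test_base.h.
#include <stdexcept>
#include <test/cpp/tensorexpr/test_base.h>

void fallback_demo() {
  // Expands to TORCH_INTERNAL_ASSERT(std::fabs(x - y) < a).
  ASSERT_NEAR(1.0f, 1.0f + 1e-7f, 1e-5f);
  // Passes if the statement throws and the message contains the substring.
  ASSERT_THROWS_WITH(throw std::runtime_error("bad size"), "bad size");
  // Passes if the statement throws anything derived from std::exception.
  ASSERT_ANY_THROW(throw std::runtime_error("boom"));
}
```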
test/cpp/tensorexpr/test_boundsinference.cpp (new file, 1019 lines)
File diff suppressed because it is too large.
test/cpp/tensorexpr/test_conv.cpp (new file, 234 lines)
@@ -0,0 +1,234 @@
#include <gtest/gtest.h>
#include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
#include <torch/csrc/jit/tensorexpr/llvm_codegen.h>
#include <torch/csrc/jit/tensorexpr/loopnest.h>
#include <torch/csrc/jit/tensorexpr/operators/conv2d.h>
#include <torch/csrc/jit/tensorexpr/tensor.h>
#include <torch/torch.h>

namespace torch {
namespace jit {

namespace te = torch::jit::tensorexpr;
namespace F = torch::nn::functional;

#ifdef TORCH_ENABLE_LLVM

// Generate test data with few bits of precision, to minimize error
// accumulation from floating-point reordering.
static at::Tensor genTestData(c10::IntArrayRef args) {
  return at::trunc(at::randn(args) * 256.0f) / 256.0f;
}

TEST(Conv, DepthwiseConv2D) {
  constexpr int N = 1, C = 72, H = 56, W = 56;
  constexpr int K = 72, R = 3, S = 3;
  constexpr int kPad = 1, kStride = 2, kGroups = C;
  constexpr int CperG = C / kGroups;

  te::BufHandle input("input", {N, C, H, W}, te::kFloat);
  te::BufHandle weight("weight", {K, CperG, R, S}, te::kFloat);
  te::BufHandle bias("bias", {K}, te::kFloat);
  te::Tensor output =
      te::conv2d_depthwise(input, weight, bias, kStride, kPad, kGroups);

  te::LoopNest loop({output});
  loop.simplify();
  loop.prepareForCodegen();
  te::LLVMCodeGen cg(loop.root_stmt(), {input, weight, bias, output});

  auto it = genTestData({N, C, H, W});
  auto wt = genTestData({K, CperG, R, S});
  auto bt = genTestData({K});
  auto ref = at::conv2d(it, wt, bt, kStride, kPad, /*dilation=*/1, kGroups);
  auto ot = at::zeros_like(ref);
  cg.call(
      {it.data_ptr<float>(),
       wt.data_ptr<float>(),
       bt.data_ptr<float>(),
       ot.data_ptr<float>()});

  ASSERT_TRUE(at::allclose(ref, ot));
}

TEST(Conv, DepthwiseConv2DNoBias) {
  constexpr int N = 1, C = 72, H = 56, W = 56;
  constexpr int K = 72, R = 3, S = 3;
  constexpr int kPad = 1, kStride = 2, kGroups = C;
  constexpr int CperG = C / kGroups;

  te::BufHandle input("input", {N, C, H, W}, te::kFloat);
  te::BufHandle weight("weight", {K, CperG, R, S}, te::kFloat);
  te::Tensor output =
      te::conv2d_depthwise(input, weight, kStride, kPad, kGroups);

  te::LoopNest loop({output});
  loop.simplify();
  loop.prepareForCodegen();
  te::LLVMCodeGen cg(loop.root_stmt(), {input, weight, output});

  auto it = genTestData({N, C, H, W});
  auto wt = genTestData({K, CperG, R, S});
  auto ref =
      at::conv2d(it, wt, at::Tensor(), kStride, kPad, /*dilation=*/1, kGroups);
  auto ot = at::zeros_like(ref);
  cg.call({it.data_ptr<float>(), wt.data_ptr<float>(), ot.data_ptr<float>()});

  ASSERT_TRUE(at::allclose(ref, ot));
}

TEST(Conv, DepthwiseConv2DDynamicShapes) {
  te::VarHandle N_var("N", te::kInt);
  te::VarHandle C_var("C", te::kInt);
  te::VarHandle H_var("H", te::kInt);
  te::VarHandle W_var("W", te::kInt);
  te::VarHandle K_var("K", te::kInt);
  te::VarHandle CperG_var("CperG", te::kInt);
  te::VarHandle R_var("R", te::kInt);
  te::VarHandle S_var("S", te::kInt);
  te::VarHandle kPad_var("kPad", te::kInt);
  te::VarHandle kStride_var("kStride", te::kInt);
  te::VarHandle kGroups_var("kGroups", te::kInt);

  te::BufHandle input("input", {N_var, C_var, H_var, W_var}, te::kFloat);
  te::BufHandle weight("weight", {K_var, CperG_var, R_var, S_var}, te::kFloat);
  te::Tensor output = te::conv2d_depthwise(
      input,
      weight,
      N_var,
      C_var,
      H_var,
      W_var,
      K_var,
      CperG_var,
      R_var,
      S_var,
      kStride_var,
      kPad_var,
      kGroups_var);

  te::LoopNest loop({output});
  loop.simplify();
  loop.prepareForCodegen();
  std::vector<te::CodeGen::BufferArg> buffer_args = {
      input,
      weight,
      N_var,
      C_var,
      H_var,
      W_var,
      K_var,
      CperG_var,
      R_var,
      S_var,
      kPad_var,
      kStride_var,
      kGroups_var,
      output};
  te::LLVMCodeGen cg(loop.root_stmt(), buffer_args);

  constexpr int N = 1, C = 72, H = 56, W = 56;
  constexpr int K = 72, R = 3, S = 3;
  constexpr int kPad = 1, kStride = 2, kGroups = C;
  constexpr int CperG = C / kGroups;

  auto it = genTestData({N, C, H, W});
  auto wt = genTestData({K, CperG, R, S});
  auto ref =
      at::conv2d(it, wt, at::Tensor(), kStride, kPad, /*dilation=*/1, kGroups);
  auto ot = at::zeros_like(ref);
  std::vector<te::CodeGen::CallArg> call_args = {
      it.data_ptr<float>(),
      wt.data_ptr<float>(),
      N,
      C,
      H,
      W,
      K,
      CperG,
      R,
      S,
      kPad,
      kStride,
      kGroups,
      ot.data_ptr<float>()};
  cg.call(call_args);

  ASSERT_TRUE(at::allclose(ref, ot));
}

#endif

TEST(Conv, Conv2D) {
  // Input dimensions.
  constexpr int N = 1;
  constexpr int C = 3;
  constexpr int H = 11;
  constexpr int W = 11;

  // Filter dimensions.
  constexpr int K = 8;
  constexpr int R = 3;
  constexpr int S = 3;

  // Output dims.
  constexpr int OH = H - R + 1;
  constexpr int OW = W - S + 1;

  // Compute reference result.
  at::Tensor input = torch::randn({N, C, H, W});
  at::Tensor filter = torch::randn({K, C, R, S});
  at::Tensor ref = F::conv2d(input, filter);

  // Double check the output size is as expected.
  ASSERT_EQ(ref.size(0), N);
  ASSERT_EQ(ref.size(1), K);
  ASSERT_EQ(ref.size(2), OH);
  ASSERT_EQ(ref.size(3), OW);

  te::BufHandle inputB("input", {N, C, H, W}, te::kFloat);
  te::BufHandle filterB("filter", {K, C, R, S}, te::kFloat);

  te::Tensor conv = te::Reduce(
      "conv",
      {N, K, OH, OW},
      te::Sum(),
      // FIXME: We have to use a `std::vector` parameter here and then unpack
      // it, because we don't have an overload allowing for an arbitrary number
      // of ExprHandle/VarHandle parameters.
      [&](const std::vector<te::VarHandle>& v) {
        auto const& n = v[0];
        auto const& k = v[1];
        auto const& oh = v[2];
        auto const& ow = v[3];
        auto const& c = v[4];
        auto const& r = v[5];
        auto const& s = v[6];
        // FIXME: We have to use `call` and construct a `std::vector` here
        // because the `operator()` overload is only specialized for a small
        // number of arguments.
        return inputB.load(n, c, oh + r, ow + s) * filterB.load(k, c, r, s);
      },
      // FIXME: If you forget one of the reduction dims, you get a segfault.
      // Could that be caught by a verifier?
      {C, R, S});

  // FIXME: It'd be nice to have a single header that pulls in things like
  // LoopNest, IRSimplifier, etc.
  te::LoopNest loop({conv});
  loop.prepareForCodegen();
  te::StmtPtr s = loop.root_stmt();
  s = te::IRSimplifier::simplify(s);

  at::Tensor result = at::empty_like(ref);
  te::SimpleIREvaluator cg(s, {inputB, filterB, conv});
  cg.call(
      {input.data_ptr<float>(),
       filter.data_ptr<float>(),
       result.data_ptr<float>()});

  ASSERT_TRUE(at::allclose(ref, result, 1e-3, 1e-3));
}

} // namespace jit
} // namespace torch
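The `genTestData` helper above truncates values to multiples of 1/256, so each carries at most 8 fractional bits; sums of such low-precision values accumulate far less reordering error, which is why `at::allclose` with default tolerances suffices in the depthwise tests. A standalone sketch (not part of the commit) of the quantization step:

```cpp
// Standalone sketch of the trunc(x * 256) / 256 quantization in genTestData.
#include <cassert>
#include <cmath>

int main() {
  float x = 0.123456789f;
  float q = std::trunc(x * 256.0f) / 256.0f; // snap to a multiple of 1/256
  assert(q == 31.0f / 256.0f);               // 0.1234... * 256 ~= 31.6 -> 31
  assert(std::fabs(q - x) < 1.0f / 256.0f);  // error stays under one quantum
  return 0;
}
```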
test/cpp/tensorexpr/test_cpp_codegen.cpp (new file, 259 lines)
@@ -0,0 +1,259 @@
#include <gtest/gtest.h>

#include "test/cpp/tensorexpr/test_base.h"

#include <c10/util/irange.h>
#include <torch/csrc/jit/tensorexpr/cpp_codegen.h>
#include <torch/csrc/jit/tensorexpr/fwd_decls.h>
#include <torch/csrc/jit/tensorexpr/stmt.h>
#include <torch/csrc/jit/tensorexpr/tensor.h>
#include <torch/csrc/jit/testing/file_check.h>

namespace torch {
namespace jit {

using namespace torch::jit::tensorexpr;

#define STR_CHECK(node, expected) \
  std::stringstream ss;           \
  CppPrinter printer(&ss);        \
  printer.visit(node);            \
  ASSERT_EQ(ss.str(), expected)

#define FILE_CHECK(node, pattern) \
  std::stringstream ss;           \
  CppPrinter printer(&ss);        \
  printer.visit(node);            \
  torch::jit::testing::FileCheck().run(pattern, ss.str())

TEST(CppPrinter, IntImm) {
  auto i = alloc<IntImm>(10);
  STR_CHECK(i, "10");
}

TEST(CppPrinter, FloatImm) {
  auto f = alloc<FloatImm>(10);
  STR_CHECK(f, "10.f");
}

TEST(CppPrinter, FloatImm1) {
  auto f = alloc<FloatImm>(10);
  STR_CHECK(f, "10.f");
}

TEST(CppPrinter, DoubleImm) {
  auto d = alloc<DoubleImm>(10);
  STR_CHECK(d, "10.0");
}

TEST(CppPrinter, DoubleImm1) {
  auto d = alloc<DoubleImm>(10.1);
  STR_CHECK(d, "10.1");
}

TEST(CppPrinter, HalfImm) {
  auto h = alloc<HalfImm>(10);
  STR_CHECK(h, "10");
}

TEST(CppPrinter, Add) {
  auto add = alloc<Add>(alloc<IntImm>(1), alloc<IntImm>(2));
  STR_CHECK(add, "1 + 2");
}

TEST(CppPrinter, AddExpr1) {
  auto add = alloc<Add>(
      alloc<Add>(alloc<IntImm>(0), alloc<IntImm>(1)),
      alloc<Sub>(alloc<IntImm>(2), alloc<IntImm>(3)));
  STR_CHECK(add, "(0 + 1) + (2 - 3)");
}

TEST(CppPrinter, AddExpr2) {
  auto add = alloc<Add>(
      alloc<Mul>(alloc<IntImm>(0), alloc<IntImm>(1)),
      alloc<Sub>(alloc<IntImm>(2), alloc<IntImm>(3)));
  STR_CHECK(add, "0 * 1 + (2 - 3)");
}

TEST(CppPrinter, AddExpr3) {
  auto add = alloc<Add>(
      alloc<Add>(alloc<IntImm>(0), alloc<IntImm>(1)),
      alloc<Div>(alloc<IntImm>(2), alloc<IntImm>(3)));
  STR_CHECK(add, "(0 + 1) + 2 / 3");
}

TEST(CppPrinter, Mod) {
  auto mod = alloc<Mod>(alloc<IntImm>(1), alloc<IntImm>(2));
  STR_CHECK(mod, "1 % 2");
}

TEST(CppPrinter, ModFloat) {
  auto mod = alloc<Mod>(alloc<FloatImm>(1), alloc<FloatImm>(2));
  STR_CHECK(mod, "std::fmod(1.f, 2.f)");
}

TEST(CppPrinter, Max) {
  auto max = alloc<Max>(alloc<IntImm>(1), alloc<IntImm>(2), false);
  STR_CHECK(max, "std::max(1, 2)");
}

TEST(CppPrinter, MaxFloat) {
  auto max = alloc<Max>(alloc<FloatImm>(1), alloc<FloatImm>(2), false);
  STR_CHECK(max, "std::max(1.f, 2.f)");
}

TEST(CppPrinter, MaxHalf) {
  auto max = alloc<Max>(alloc<HalfImm>(1), alloc<HalfImm>(2), false);
  STR_CHECK(max, "(1 < 2) ? 2 : 1");
}

TEST(CppPrinter, And) {
  auto v = alloc<And>(alloc<IntImm>(1), alloc<IntImm>(2));
  STR_CHECK(v, "1 & 2");
}

TEST(CppPrinter, CompareSelect) {
  auto cs = alloc<CompareSelect>(
      alloc<IntImm>(1),
      alloc<IntImm>(2),
      alloc<FloatImm>(1),
      alloc<FloatImm>(2),
      CompareSelectOperation::kLE);
  STR_CHECK(cs, "((1 <= 2) ? 1.f : 2.f)");
}

TEST(CppPrinter, IfThenElse) {
  auto cond = alloc<Add>(alloc<IntImm>(1), alloc<IntImm>(2));
  auto true_value = alloc<Sub>(alloc<IntImm>(0), alloc<IntImm>(1));
  auto false_value = alloc<Mul>(alloc<IntImm>(2), alloc<IntImm>(3));
  auto v = alloc<IfThenElse>(cond, true_value, false_value);
  STR_CHECK(v, "((1 + 2) ? 0 - 1 : 2 * 3)");
}

TEST(CppPrinter, AllocateFree) {
  BufHandle buf("x", {2, 3}, kInt);
  AllocatePtr alloc = Allocate::make(buf);
  FreePtr free = Free::make(buf);
  BlockPtr block = Block::make({alloc, free});

  const std::string pattern = R"(
   # CHECK: {
   # CHECK:   int* x = static_cast<int*>(malloc(24));
   # CHECK:   free(x);
   # CHECK: }
  )";
  FILE_CHECK(block, pattern);
}

TEST(CppPrinter, LoadStore) {
  BufHandle a("A", {2, 3}, kInt);
  BufHandle b("B", {3, 4}, kInt);
  auto store = b.store({2, 2}, a.load(1, 1));
  STR_CHECK(
      store, "B[(0 + 2 * (1 * 4)) + 2 * 1] = A[(0 + 1 * (1 * 3)) + 1 * 1];\n");
}

TEST(CppPrinter, Var) {
  auto var = alloc<Var>("x", kInt);
  STR_CHECK(var, "x");
}

TEST(CppPrinter, Cast) {
  auto cast = alloc<Cast>(kFloat, alloc<IntImm>(1));
  STR_CHECK(cast, "static_cast<float>(1)");
}

TEST(CppPrinter, BitCast) {
  auto cast = alloc<BitCast>(kInt, alloc<FloatImm>(20));
  STR_CHECK(cast, "std::bitcast<float, int>(20.f)");
}

TEST(CppPrinter, Let) {
  auto var = alloc<Var>("x", kFloat);
  auto val = alloc<FloatImm>(2);
  auto let = alloc<Let>(var, val);
  STR_CHECK(let, "float x = 2.f;\n");
}

TEST(CppPrinter, For) {
  constexpr int N = 1024;
  BufHandle a("A", {N}, kInt);
  BufHandle b("B", {N}, kInt);
  BufHandle c("C", {N}, kInt);
  VarHandle i("i", kInt);
  auto f = For::make(i, 0, N, c.store({i}, Add::make(a.load(i), b.load(i))));
  const std::string pattern = R"(
   # CHECK: for (int i = 0; i < 1024; i++) {
   # CHECK:   C[i] = (A[i]) + (B[i]);
   # CHECK: }
  )";
  FILE_CHECK(f, pattern);
}

TEST(CppPrinter, Cond) {
  BufHandle x("X", {1}, kInt);
  auto cmp = CompareSelect::make(x.load(0), 10, CompareSelectOperation::kLT);
  auto cond =
      Cond::make(cmp, x.store({0}, x.load(0) + 1), x.store({0}, x.load(0) - 1));
  const std::string pattern = R"(
   # CHECK: if (((X[0] < 10) ? 1 : 0)) {
   # CHECK:   X[0] = (X[0]) + 1;
   # CHECK: } else {
   # CHECK:   X[0] = (X[0]) - 1;
   # CHECK: }
  )";
  FILE_CHECK(cond, pattern);
}

TEST(CppPrinter, Intrinsics) {
  const std::unordered_set<IntrinsicsOp, std::hash<int>> unsupported_ops{
      kRand, kSigmoid};
  for (const auto i : c10::irange(static_cast<uint32_t>(kMaxIntrinsicsOp))) {
    IntrinsicsOp op = static_cast<IntrinsicsOp>(i);
    if (unsupported_ops.count(op)) {
      continue;
    }

    if (Intrinsics::OpArgCount(op) == 1) {
      auto v = alloc<Intrinsics>(op, alloc<FloatImm>(2.0f));
      STR_CHECK(v, "std::" + v->func_name() + "(2.f)");
    } else {
      auto v =
          alloc<Intrinsics>(op, alloc<FloatImm>(1.0f), alloc<FloatImm>(2.0f));
      STR_CHECK(v, "std::" + v->func_name() + "(1.f, 2.f)");
    }
  }
}

TEST(CppPrinter, ExternalCall) {
  std::vector<ExprPtr> dims{alloc<IntImm>(2), alloc<IntImm>(2)};
  auto output = alloc<Buf>("out", dims, kFloat);
  auto buf_arg1 = alloc<Buf>("a", dims, kFloat);
  auto buf_arg2 = alloc<Buf>("b", dims, kFloat);
  auto scalar_arg = alloc<Add>(alloc<IntImm>(1), alloc<IntImm>(2));
  std::vector<BufPtr> buf_args{buf_arg1, buf_arg2};
  std::vector<ExprPtr> scalar_args{scalar_arg};
  auto call =
      alloc<ExternalCall>(output, "nnc_aten_matmul", buf_args, scalar_args);
  const std::string pattern = R"(
   # CHECK: {
   # CHECK:   void* buf_ptrs[]{out, a, b};
   # CHECK:   int64_t buf_ranks[]{2, 2, 2};
   # CHECK:   int64_t buf_dims[]{2, 2, 2, 2, 2, 2};
   # CHECK:   int8_t buf_dtypes[]{6, 6, 6};
   # CHECK:   int64_t extra_args[]{1 + 2};
   # CHECK:   nnc_aten_matmul(
   # CHECK:     3,
   # CHECK:     buf_ptrs,
   # CHECK:     buf_ranks,
   # CHECK:     buf_dims,
   # CHECK:     buf_dtypes,
   # CHECK:     1,
   # CHECK:     extra_args);
   # CHECK: }
  )";
  FILE_CHECK(call, pattern);
}

} // namespace jit
} // namespace torch
test/cpp/tensorexpr/test_cuda.cpp (new file, 2344 lines)
File diff suppressed because it is too large.
701
test/cpp/tensorexpr/test_dynamic_shapes.cpp
Normal file
701
test/cpp/tensorexpr/test_dynamic_shapes.cpp
Normal file
|
|
@ -0,0 +1,701 @@
|
|||
#include <gtest/gtest.h>
|
||||
|
||||
#include <ATen/code_template.h>
|
||||
#include <c10/core/DeviceType.h>
|
||||
#include <test/cpp/tensorexpr/test_base.h>
|
||||
#include <torch/csrc/jit/ir/ir.h>
|
||||
#include <torch/csrc/jit/ir/irparser.h>
|
||||
#include <torch/csrc/jit/passes/symbolic_shape_runtime_fusion.h>
|
||||
#include <torch/csrc/jit/tensorexpr/kernel.h>
|
||||
#include <torch/csrc/jit/testing/file_check.h>
|
||||
#include <torch/torch.h>
|
||||
#include <cmath>
|
||||
#include <sstream>
|
||||
#include <stdexcept>
|
||||
#include <thread>
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
|
||||
using namespace torch::indexing;
|
||||
using namespace torch::jit::tensorexpr;
|
||||
|
||||
TEST(DynamicShapes, SimpleGraph) {
|
||||
#ifdef TORCH_ENABLE_LLVM
|
||||
std::shared_ptr<Graph> graph = std::make_shared<Graph>();
|
||||
const auto graph_string = R"IR(
|
||||
graph(%x : Tensor,
|
||||
%SS_2 : int,
|
||||
%SS_3 : int):
|
||||
%3 : Tensor = aten::tanh(%x)
|
||||
%4 : Tensor = aten::erf(%3)
|
||||
return (%4))IR";
|
||||
torch::jit::parseIR(graph_string, graph.get());
|
||||
|
||||
auto x_inp = graph->inputs()[0];
|
||||
auto x_type = TensorType::create(at::rand({10, 5}));
|
||||
std::vector<ShapeSymbol> x_sym_dims(
|
||||
{c10::ShapeSymbol::newSymbol(), c10::ShapeSymbol::newSymbol()});
|
||||
auto x_sym_type = x_type->withSymbolicShapes(x_sym_dims);
|
||||
graph->inputs().at(0)->setType(x_sym_type);
|
||||
for (const auto n : graph->nodes()) {
|
||||
n->output()->setType(x_sym_type);
|
||||
}
|
||||
|
||||
// Graph with symbolic shapes:
|
||||
//
|
||||
// graph(%x : Float(SS(-2), SS(-3)),
|
||||
// %SS_2 : int,
|
||||
// %SS_3 : int):
|
||||
// %3 : Float(SS(-2), SS(-3)) = aten::tanh(%x)
|
||||
// %4 : Float(SS(-2), SS(-3)) = aten::erf(%3)
|
||||
// return (%4)
|
||||
|
||||
std::vector<torch::jit::StrideInput> input_desc = {
|
||||
torch::jit::StrideInput::TENSOR_CONT};
|
||||
std::unordered_map<
|
||||
const torch::jit::Value*,
|
||||
std::vector<torch::jit::StrideInput>>
|
||||
symbolic_strides;
|
||||
symbolic_strides[x_inp] = input_desc;
|
||||
symbolic_strides[graph->outputs().at(0)] = input_desc;
|
||||
std::vector<int64_t> symbolic_shape_inputs = c10::fmap(
|
||||
x_sym_dims,
|
||||
[](const c10::ShapeSymbol& shapeSym) { return shapeSym.value(); });
|
||||
|
||||
TensorExprKernel kernel(
|
||||
graph, {}, symbolic_shape_inputs, false, symbolic_strides);
|
||||
// Run with the same static dims as the one we initialized the graph with.
|
||||
{
|
||||
auto a = at::rand({10, 5}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
|
||||
auto ref = at::erf(at::tanh(a));
|
||||
|
||||
std::vector<IValue> stack = fmap<IValue>(std::vector<at::Tensor>({a}));
|
||||
stack.push_back(10);
|
||||
stack.push_back(5);
|
||||
kernel.run(stack);
|
||||
|
||||
auto o = stack[0].toTensor();
|
||||
ASSERT_TRUE(at::allclose(o, ref));
|
||||
}
|
||||
|
||||
// Run with inputs having different dims.
|
||||
{
|
||||
auto a = at::rand({50, 100}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
|
||||
auto ref = at::erf(at::tanh(a));
|
||||
|
||||
std::vector<IValue> stack = fmap<IValue>(std::vector<at::Tensor>({a}));
|
||||
stack.push_back(50);
|
||||
stack.push_back(100);
|
||||
kernel.run(stack);
|
||||
|
||||
auto o = stack[0].toTensor();
|
||||
ASSERT_TRUE(at::allclose(o, ref));
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
TEST(DynamicShapes, GraphWith2InputsSameDims) {
|
||||
#ifdef TORCH_ENABLE_LLVM
|
||||
// The two inputs in this graph must have the same dims.
|
||||
std::shared_ptr<Graph> graph = std::make_shared<Graph>();
|
||||
const auto graph_string = R"IR(
|
||||
graph(%x : Tensor,
|
||||
%y : Tensor,
|
||||
%SS_2 : int,
|
||||
%SS_3 : int):
|
||||
%3 : Tensor = aten::tanh(%x)
|
||||
%4 : Tensor = aten::erf(%3)
|
||||
%5 : Tensor = aten::mul(%4, %y)
|
||||
return (%5))IR";
|
||||
torch::jit::parseIR(graph_string, graph.get());
|
||||
|
||||
auto x_inp = graph->inputs()[0];
|
||||
auto y_inp = graph->inputs()[1];
|
||||
auto x_type = TensorType::create(at::rand({10, 5}));
|
||||
std::vector<ShapeSymbol> x_sym_dims(
|
||||
{c10::ShapeSymbol::newSymbol(), c10::ShapeSymbol::newSymbol()});
|
||||
auto x_sym_type = x_type->withSymbolicShapes(x_sym_dims);
|
||||
graph->inputs().at(0)->setType(x_sym_type);
|
||||
graph->inputs().at(1)->setType(x_sym_type);
|
||||
for (const auto n : graph->nodes()) {
|
||||
n->output()->setType(x_sym_type);
|
||||
}
|
||||
|
||||
// Graph with symbolic shapes:
|
||||
//
|
||||
// graph(%x : Float(SS(-4), SS(-5)),
|
||||
// %y : Float(SS(-4), SS(-5)),
|
||||
// %SS_2 : int,
|
||||
// %SS_3 : int):
|
||||
// %4 : Float(SS(-4), SS(-5)) = aten::tanh(%x)
|
||||
// %5 : Float(SS(-4), SS(-5)) = aten::erf(%4)
|
||||
// %6 : Float(SS(-4), SS(-5)) = aten::mul(%5, %y)
|
||||
// return (%6)
|
||||
|
||||
std::vector<int64_t> symbolic_shape_inputs = c10::fmap(
|
||||
x_sym_dims,
|
||||
[](const c10::ShapeSymbol& shapeSym) { return shapeSym.value(); });
|
||||
|
||||
std::vector<torch::jit::StrideInput> input_desc = {
|
||||
torch::jit::StrideInput::TENSOR_CONT};
|
||||
std::unordered_map<
|
||||
const torch::jit::Value*,
|
||||
std::vector<torch::jit::StrideInput>>
|
||||
symbolic_strides;
|
||||
symbolic_strides[x_inp] = input_desc;
|
||||
symbolic_strides[y_inp] = input_desc;
|
||||
symbolic_strides[graph->outputs().at(0)] = input_desc;
|
||||
|
||||
TensorExprKernel kernel(
|
||||
graph, {}, symbolic_shape_inputs, false, symbolic_strides);
|
||||
|
||||
// Run with the same static dims as the one we initialized the graph with.
|
||||
{
|
||||
auto a = at::rand({10, 5}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
|
||||
auto b = at::rand({10, 5}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
|
||||
auto ref = at::mul(at::erf(at::tanh(a)), b);
|
||||
|
||||
std::vector<IValue> stack = fmap<IValue>(std::vector<at::Tensor>({a, b}));
|
||||
stack.push_back(10);
|
||||
stack.push_back(5);
|
||||
kernel.run(stack);
|
||||
|
||||
auto o = stack[0].toTensor();
|
||||
ASSERT_TRUE(at::allclose(o, ref));
|
||||
}
|
||||
|
||||
// Run with inputs having different dims.
|
||||
{
|
||||
auto a = at::rand({50, 100}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
|
||||
auto b = at::rand({50, 100}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
|
||||
auto ref = at::mul(at::erf(at::tanh(a)), b);
|
||||
|
||||
std::vector<IValue> stack = fmap<IValue>(std::vector<at::Tensor>({a, b}));
|
||||
stack.push_back(50);
|
||||
stack.push_back(100);
|
||||
kernel.run(stack);
|
||||
|
||||
auto o = stack[0].toTensor();
|
||||
ASSERT_TRUE(at::allclose(o, ref));
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
TEST(DynamicShapes, GraphWith2InputsAndBroadcast) {
#ifdef TORCH_ENABLE_LLVM
  // The second input to the graph has a dim of size 1 which should be
  // broadcasted in the at::mul op.
  std::shared_ptr<Graph> graph = std::make_shared<Graph>();
  const auto graph_string = R"IR(
      graph(%x : Float(10, 5, requires_grad=0, device=cpu),
            %y : Float(1, 5, requires_grad=0, device=cpu),
            %SS_2 : int,
            %SS_3 : int):
        %3 : Tensor = aten::tanh(%x)
        %4 : Tensor = aten::erf(%3)
        %5 : Tensor = aten::mul(%4, %y)
        return (%5))IR";
  torch::jit::parseIR(graph_string, graph.get());

  auto x_inp = graph->inputs()[0];
  auto y_inp = graph->inputs()[1];
  auto x_type = TensorType::create(at::rand({10, 5}));
  auto y_type = TensorType::create(at::rand({1, 5}));
  auto x_dim0_sym = c10::ShapeSymbol::newSymbol();
  auto x_dim1_sym = c10::ShapeSymbol::newSymbol();
  auto x_sym_type = x_type->withSymbolicShapes(
      std::vector<ShapeSymbol>({x_dim0_sym, x_dim1_sym}));
  auto y_sym_type = y_type->withSymbolicShapes(std::vector<ShapeSymbol>(
      {c10::ShapeSymbol::fromStaticSize(1), x_dim1_sym}));
  graph->inputs().at(0)->setType(x_sym_type);
  graph->inputs().at(1)->setType(y_sym_type);
  for (const auto n : graph->nodes()) {
    n->output()->setType(x_sym_type);
  }

  // Graph with symbolic shapes:
  //
  // graph(%x : Float(SS(-6), SS(-7)),
  //       %y : Float(1, SS(-7)),
  //       %SS_2 : int,
  //       %SS_3 : int):
  //   %4 : Float(SS(-6), SS(-7)) = aten::tanh(%x)
  //   %5 : Float(SS(-6), SS(-7)) = aten::erf(%4)
  //   %6 : Float(SS(-6), SS(-7)) = aten::mul(%5, %y)
  //   return (%6)

  std::vector<int64_t> symbolic_shape_inputs(
      {x_dim0_sym.value(), x_dim1_sym.value()});

  std::vector<torch::jit::StrideInput> input_desc = {
      torch::jit::StrideInput::TENSOR_CONT};
  std::unordered_map<
      const torch::jit::Value*,
      std::vector<torch::jit::StrideInput>>
      symbolic_strides;
  symbolic_strides[x_inp] = input_desc;
  symbolic_strides[y_inp] = input_desc;
  symbolic_strides[graph->outputs().at(0)] = input_desc;

  TensorExprKernel kernel(
      graph, {}, symbolic_shape_inputs, false, symbolic_strides);

  // Run with the same static dims as the ones we used to initialize the
  // graph.
  {
    auto a = at::rand({10, 5}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
    auto b = at::rand({1, 5}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
    auto ref = at::mul(at::erf(at::tanh(a)), b);

    std::vector<IValue> stack = fmap<IValue>(std::vector<at::Tensor>({a, b}));
    stack.push_back(10);
    stack.push_back(5);
    kernel.run(stack);

    auto o = stack[0].toTensor();
    ASSERT_TRUE(at::allclose(o, ref));
  }

  // Run with inputs having different dims.
  {
    auto a = at::rand({50, 100}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
    auto b = at::rand({1, 100}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
    auto ref = at::mul(at::erf(at::tanh(a)), b);

    std::vector<IValue> stack = fmap<IValue>(std::vector<at::Tensor>({a, b}));
    stack.push_back(50);
    stack.push_back(100);
    kernel.run(stack);

    auto o = stack[0].toTensor();
    ASSERT_TRUE(at::allclose(o, ref));
  }
#endif
}
TEST(DynamicShapes, GraphWithPartiallySymbolicOutput) {
#ifdef TORCH_ENABLE_LLVM
  // The first dim of the inputs and of the output is a static size of 1;
  // only the second dim is symbolic.
  std::shared_ptr<Graph> graph = std::make_shared<Graph>();
  const auto graph_string = R"IR(
      graph(%x : Float(1, 5, requires_grad=0, device=cpu),
            %y : Float(1, 5, requires_grad=0, device=cpu),
            %SS_2 : int):
        %4 : Tensor = aten::tanh(%x)
        %5 : Tensor = aten::mul(%4, %y)
        return (%5))IR";
  torch::jit::parseIR(graph_string, graph.get());

  auto x_inp = graph->inputs()[0];
  auto y_inp = graph->inputs()[1];
  auto x_type = TensorType::create(at::rand({1, 5}));
  auto x_dim1_sym = c10::ShapeSymbol::newSymbol();
  auto x_sym_type = x_type->withSymbolicShapes(std::vector<ShapeSymbol>(
      {c10::ShapeSymbol::fromStaticSize(1), x_dim1_sym}));
  graph->inputs().at(0)->setType(x_sym_type);
  graph->inputs().at(1)->setType(x_sym_type);
  for (const auto n : graph->nodes()) {
    n->output()->setType(x_sym_type);
  }

  // Graph with symbolic shapes:
  //
  // graph(%x : Float(1, SS(-2)),
  //       %y : Float(1, SS(-2)),
  //       %SS_2 : int):
  //   %3 : Float(1, SS(-2)) = aten::tanh(%x)
  //   %4 : Float(1, SS(-2)) = aten::mul(%3, %y)
  //   return (%4)

  std::vector<int64_t> symbolic_shape_inputs({x_dim1_sym.value()});

  std::vector<torch::jit::StrideInput> input_desc = {
      torch::jit::StrideInput::TENSOR_CONT};
  std::unordered_map<
      const torch::jit::Value*,
      std::vector<torch::jit::StrideInput>>
      symbolic_strides;
  symbolic_strides[x_inp] = input_desc;
  symbolic_strides[y_inp] = input_desc;
  symbolic_strides[graph->outputs().at(0)] = input_desc;

  TensorExprKernel kernel(
      graph, {}, symbolic_shape_inputs, false, symbolic_strides);

  // Run with the same static dims as the ones we used to initialize the
  // graph.
  {
    auto a = at::rand({1, 5}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
    auto b = at::rand({1, 5}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
    auto ref = at::mul(at::tanh(a), b);

    std::vector<IValue> stack = fmap<IValue>(std::vector<at::Tensor>({a, b}));
    stack.push_back(5);
    kernel.run(stack);

    auto o = stack[0].toTensor();
    ASSERT_TRUE(at::allclose(o, ref));
  }

  // Run with inputs having different dims.
  {
    auto a = at::rand({1, 100}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
    auto b = at::rand({1, 100}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
    auto ref = at::mul(at::tanh(a), b);

    std::vector<IValue> stack = fmap<IValue>(std::vector<at::Tensor>({a, b}));
    stack.push_back(100);
    kernel.run(stack);

    auto o = stack[0].toTensor();
    ASSERT_TRUE(at::allclose(o, ref));
  }
#endif
}
TEST(DynamicShapes, GraphWithSymbolicStrides) {
#ifdef TORCH_ENABLE_LLVM
  std::shared_ptr<Graph> graph = std::make_shared<Graph>();
  const auto graph_string = R"IR(
      graph(%0 : Float(SS(-2), SS(-3), requires_grad=0, device=cpu),
            %1 : Float(SS(-2), SS(-3), requires_grad=0, device=cpu),
            %SS_3 : int,
            %SS_2 : int):
        %15 : int = prim::Constant[value=1]()
        %21 : Float(SS(-2), SS(-3), requires_grad=0, device=cpu) = aten::add(%0, %1, %15)
        %22 : Float(SS(-2), SS(-3), requires_grad=0, device=cpu) = aten::mul(%21, %0)
        return (%22))IR";
  parseIR(graph_string, &*graph);

  std::vector<torch::jit::StrideInput> input_desc = {
      torch::jit::StrideInput::S_AS_ARG, torch::jit::StrideInput::S_ONE};
  std::vector<torch::jit::StrideInput> output_desc = {
      torch::jit::StrideInput::TENSOR_CONT};
  std::unordered_map<
      const torch::jit::Value*,
      std::vector<torch::jit::StrideInput>>
      symbolic_strides;
  symbolic_strides[graph->inputs().at(0)] = input_desc;
  symbolic_strides[graph->inputs().at(1)] = input_desc;
  symbolic_strides[graph->outputs().at(0)] = output_desc;
  std::vector<int64_t> symbolic_shape_inputs = {-3, -2};
  TensorExprKernel k(graph, {}, symbolic_shape_inputs, false, symbolic_strides);

  {
    auto x0 = at::rand({10, 32}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
    auto x1 = at::rand({10, 32}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
    auto ref = at::mul(at::add(x0, x1, 1), x0);

    std::vector<at::Tensor> inputs = {x0, x1};
    std::vector<IValue> stack = at::fmap<at::IValue>(inputs);
    stack.push_back(32);
    stack.push_back(10);
    k.run(stack);

    auto o = stack[0].toTensor();
    ASSERT_TRUE(at::allclose(o, ref));
  }

  {
    auto x0 = at::rand({10, 32}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
    auto x1 = at::rand({10, 32}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
    auto out =
        at::rand({10, 32}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
    auto ref = at::mul(at::add(x0, x1, 1), x0);

    std::vector<at::Tensor> inputs = {out, x0, x1};
    std::vector<IValue> stack = at::fmap<at::IValue>(inputs);
    stack.push_back(32);
    stack.push_back(10);
    k.runWithAllocatedOutputs(stack);

    ASSERT_TRUE(at::allclose(out, ref));
  }
#endif
}
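// A note on the stride descriptors used above: TENSOR_CONT marks a whole
// tensor as contiguous, while the per-dimension entries describe one dim
// each. Judging by their use here, S_AS_ARG means the stride is passed to
// the kernel as a run-time argument, and S_ONE means the stride is
// statically known to be 1.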
TEST(DynamicShapes, GraphWithCatAndBroadcast) {
#ifdef TORCH_ENABLE_LLVM
  std::shared_ptr<Graph> graph = std::make_shared<Graph>();
  const auto graph_string = R"IR(
      graph(%x : Float(10, 5, requires_grad=0, device=cpu),
            %y : Float(4, 5, requires_grad=0, device=cpu),
            %z : Float(1, 1, requires_grad=0, device=cpu),
            %SS_2 : int,
            %SS_3 : int,
            %SS_4 : int,
            %SS_5 : int):
        %11 : int = prim::Constant[value=0]()
        %3 : Tensor = aten::tanh(%x)
        %out1 : Tensor = aten::erf(%3)
        %out2 : Tensor = aten::relu(%y)
        %10 : Tensor[] = prim::ListConstruct(%out1, %out2)
        %25 : Tensor = aten::cat(%10, %11)
        %28 : Tensor = aten::hardswish(%25)
        %29 : Tensor = aten::mul(%28, %z)
        return (%29))IR";
  torch::jit::parseIR(graph_string, graph.get());

  auto x_inp = graph->inputs()[0];
  auto y_inp = graph->inputs()[1];
  auto z_inp = graph->inputs()[2];
  auto x_type = TensorType::create(at::rand({10, 5}));
  auto y_type = TensorType::create(at::rand({4, 5}));
  auto z_type = TensorType::create(at::rand({1, 1}));
  auto x_dim0_sym = c10::ShapeSymbol::newSymbol();
  auto x_dim1_sym = c10::ShapeSymbol::newSymbol();
  auto x_sym_type = x_type->withSymbolicShapes(
      std::vector<ShapeSymbol>({x_dim0_sym, x_dim1_sym}));
  auto y_dim0_sym = c10::ShapeSymbol::newSymbol();
  auto y_sym_type = y_type->withSymbolicShapes(
      std::vector<ShapeSymbol>({y_dim0_sym, x_dim1_sym}));
  graph->inputs().at(0)->setType(x_sym_type);
  graph->inputs().at(1)->setType(y_sym_type);
  auto cat_dim0_sym = c10::ShapeSymbol::newSymbol();
  auto cat_out_type = x_type->withSymbolicShapes(
      std::vector<ShapeSymbol>({cat_dim0_sym, x_dim1_sym}));
  auto nodeIt = graph->nodes().begin();
  ++nodeIt;
  nodeIt->output()->setType(x_sym_type); // aten::tanh
  ++nodeIt;
  nodeIt->output()->setType(x_sym_type); // aten::erf
  ++nodeIt;
  nodeIt->output()->setType(y_sym_type); // aten::relu
  ++nodeIt;
  ++nodeIt;
  nodeIt->output()->setType(cat_out_type); // aten::cat
  ++nodeIt;
  nodeIt->output()->setType(cat_out_type); // aten::hardswish
  ++nodeIt;
  nodeIt->output()->setType(cat_out_type); // aten::mul

  // Graph with symbolic shapes:
  //
  // graph(%x : Float(SS(-2), SS(-3)),
  //       %y : Float(SS(-4), SS(-3)),
  //       %z : Float(1, 1),
  //       %SS_2 : int,
  //       %SS_3 : int,
  //       %SS_4 : int,
  //       %SS_5 : int):
  //   %7 : int = prim::Constant[value=0]()
  //   %8 : Float(SS(-2), SS(-3)) = aten::tanh(%x)
  //   %9 : Float(SS(-2), SS(-3)) = aten::erf(%8)
  //   %10 : Float(SS(-4), SS(-3)) = aten::relu(%y)
  //   %11 : Tensor[] = prim::ListConstruct(%9, %10)
  //   %12 : Float(SS(-5), SS(-3)) = aten::cat(%11, %7)
  //   %13 : Float(SS(-5), SS(-3)) = aten::hardswish(%12)
  //   %14 : Float(SS(-5), SS(-3)) = aten::mul(%13, %z)
  //   return (%14)

  std::vector<int64_t> symbolic_shape_inputs(
      {x_dim0_sym.value(),
       x_dim1_sym.value(),
       y_dim0_sym.value(),
       cat_dim0_sym.value()});

  std::vector<torch::jit::StrideInput> input_desc = {
      torch::jit::StrideInput::TENSOR_CONT};
  std::unordered_map<
      const torch::jit::Value*,
      std::vector<torch::jit::StrideInput>>
      symbolic_strides;
  symbolic_strides[x_inp] = input_desc;
  symbolic_strides[y_inp] = input_desc;
  symbolic_strides[z_inp] = input_desc;
  symbolic_strides[graph->outputs().at(0)] = input_desc;

  TensorExprKernel kernel(
      graph, {}, symbolic_shape_inputs, false, symbolic_strides);

  auto a = at::rand({10, 5}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
  auto b = at::rand({4, 5}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
  auto c = at::rand({1, 1}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
  auto ref = at::mul(
      at::hardswish(at::cat({at::erf(at::tanh(a)), at::relu(b)}, 0)), c);

  std::vector<IValue> stack = fmap<IValue>(std::vector<at::Tensor>({a, b, c}));
  stack.push_back(10);
  stack.push_back(5);
  stack.push_back(4);
  stack.push_back(14);
  kernel.run(stack);

  auto o = stack[0].toTensor();
  ASSERT_TRUE(at::allclose(o, ref));
#endif
}
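// Note that the cat output dim is itself a symbolic shape input here: the 14
// pushed last onto the stack above is the concrete value of SS(-5), i.e. the
// concatenated dim 10 + 4.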
TEST(DynamicShapes, GraphFromModel) {
#ifdef TORCH_ENABLE_LLVM
  std::shared_ptr<Graph> graph = std::make_shared<Graph>();
  const auto graph_string = R"IR(
      graph(%0 : Float(SS(-2), SS(-3), requires_grad=0, device=cpu),
            %1 : Float(SS(-2), SS(-4), requires_grad=0, device=cpu),
            %2 : Float(SS(-2), SS(-5), requires_grad=0, device=cpu),
            %input.4 : Long(SS(-2), SS(-6), requires_grad=0, device=cpu),
            %4 : Float(SS(-7), requires_grad=0, device=cpu),
            %5 : Float(SS(-7), requires_grad=0, device=cpu),
            %SS_10 : int,
            %SS_9 : int,
            %SS_8 : int,
            %SS_7 : int,
            %SS_6 : int,
            %SS_5 : int,
            %SS_4 : int,
            %SS_3 : int,
            %SS_2 : int):
        %15 : int = prim::Constant[value=1]()
        %16 : bool = prim::Constant[value=0]()
        %17 : int = prim::Constant[value=6]()
        %18 : Float(SS(-2), SS(-6), strides=[139, 1], requires_grad=0, device=cpu) = aten::to(%input.4, %17, %16, %16)
        %19 : Tensor[] = prim::ListConstruct(%0, %1, %18, %2)
        %20 : Float(SS(-2), SS(-8), strides=[261, 1], requires_grad=0, device=cpu) = aten::cat(%19, %15)
        %21 : Float(SS(-2), SS(-9), strides=[261, 1], requires_grad=0, device=cpu) = aten::add(%20, %5, %15)
        %22 : Float(SS(-2), SS(-10), requires_grad=0, device=cpu) = aten::mul(%21, %4)
        return (%22))IR";
  parseIR(graph_string, &*graph);

  std::vector<torch::jit::StrideInput> input_desc = {
      torch::jit::StrideInput::TENSOR_CONT};
  std::unordered_map<
      const torch::jit::Value*,
      std::vector<torch::jit::StrideInput>>
      symbolic_strides;
  symbolic_strides[graph->inputs().at(0)] = input_desc;
  symbolic_strides[graph->inputs().at(1)] = input_desc;
  symbolic_strides[graph->inputs().at(2)] = input_desc;
  symbolic_strides[graph->inputs().at(3)] = input_desc;
  symbolic_strides[graph->inputs().at(4)] = input_desc;
  symbolic_strides[graph->inputs().at(5)] = input_desc;
  symbolic_strides[graph->outputs().at(0)] = input_desc;
  std::vector<int64_t> symbolic_shape_inputs = {
      -10, -9, -8, -7, -6, -5, -4, -3, -2};
  TensorExprKernel k(graph, {}, symbolic_shape_inputs, false, symbolic_strides);

  int64_t i2 = 10;
  int64_t i3 = 32;
  int64_t i4 = 19;
  int64_t i5 = 71;
  int64_t i6 = 139;
  int64_t i7 = 261;
  int64_t i8 = 261;
  int64_t i9 = 261;
  int64_t i10 = 261;
  auto x0 = at::rand({i2, i3}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
  auto x1 = at::rand({i2, i4}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
  auto x2 = at::rand({i2, i5}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
  auto x3 = at::ones({i2, i6}, at::TensorOptions(at::kCPU).dtype(at::kLong));
  auto x4 = at::rand({i7}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
  auto x5 = at::rand({i8}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
  auto ref = at::mul(at::add(at::cat({x0, x1, x3, x2}, 1), x5), x4);

  {
    std::vector<at::Tensor> inputs = {x0, x1, x2, x3, x4, x5};
    std::vector<IValue> stack = at::fmap<at::IValue>(inputs);
    stack.emplace_back(i10);
    stack.emplace_back(i9);
    stack.emplace_back(i8);
    stack.emplace_back(i7);
    stack.emplace_back(i6);
    stack.emplace_back(i5);
    stack.emplace_back(i4);
    stack.emplace_back(i3);
    stack.emplace_back(i2);
    k.run(stack);

    auto o = stack[0].toTensor();
    ASSERT_TRUE(at::allclose(o, ref));
  }

  {
    auto out =
        at::rand({i2, i10}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
    std::vector<at::Tensor> inputs = {out, x0, x1, x2, x3, x4, x5};
    std::vector<IValue> stack = at::fmap<at::IValue>(inputs);
    stack.emplace_back(i10);
    stack.emplace_back(i9);
    stack.emplace_back(i8);
    stack.emplace_back(i7);
    stack.emplace_back(i6);
    stack.emplace_back(i5);
    stack.emplace_back(i4);
    stack.emplace_back(i3);
    stack.emplace_back(i2);
    k.runWithAllocatedOutputs(stack);

    ASSERT_TRUE(at::allclose(out, ref));
  }
#endif
}
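// Note the two run paths exercised above: run() takes the inputs plus the
// symbolic dim values and leaves its outputs on the stack (read back via
// stack[0]), while runWithAllocatedOutputs() expects the caller's
// preallocated output tensors at the front of the stack, as with `out` here.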
TEST(DynamicShapes, MultiThreadedExecution) {
#ifdef TORCH_ENABLE_LLVM
  const auto graph_template = R"IR(
      graph(%x : Float(SS(-2), SS(-3), requires_grad=0, device=${device}),
            %y : Float(SS(-2), SS(-3), requires_grad=0, device=${device}),
            %SS_2 : int,
            %SS_3 : int):
        %3 : Float(SS(-2), SS(-3), requires_grad=0, device=${device}) = aten::tanh(%x)
        %4 : Float(SS(-2), SS(-3), requires_grad=0, device=${device}) = aten::erf(%3)
        %5 : Float(SS(-2), SS(-3), requires_grad=0, device=${device}) = aten::mul(%4, %y)
        return (%5))IR";
  for (bool use_cuda : {false, true}) {
    if (!torch::cuda::is_available() && use_cuda) {
      continue;
    }
    auto device = use_cuda ? at::kCUDA : at::kCPU;
    at::jit::TemplateEnv env;
    env.s("device", use_cuda ? "cuda:0" : "cpu");
    const auto graph_string = format(graph_template, env);
    std::shared_ptr<Graph> graph = std::make_shared<Graph>();
    torch::jit::parseIR(graph_string, graph.get());

    std::vector<int64_t> symbolic_shape_inputs = {-2, -3};

    std::vector<torch::jit::StrideInput> input_desc = {
        torch::jit::StrideInput::TENSOR_CONT};
    std::unordered_map<
        const torch::jit::Value*,
        std::vector<torch::jit::StrideInput>>
        symbolic_strides;
    symbolic_strides[graph->inputs().at(0)] = input_desc;
    symbolic_strides[graph->inputs().at(1)] = input_desc;
    symbolic_strides[graph->outputs().at(0)] = input_desc;

    TensorExprKernel kernel(
        graph, {}, symbolic_shape_inputs, false, symbolic_strides);

    auto run_kernel = [&](int dim1, int dim2) {
      auto a =
          at::rand({dim1, dim2}, at::TensorOptions(device).dtype(at::kFloat));
      auto b =
          at::rand({dim1, dim2}, at::TensorOptions(device).dtype(at::kFloat));

      auto ref = at::mul(at::erf(at::tanh(a)), b);

      std::vector<IValue> stack = fmap<IValue>(std::vector<at::Tensor>({a, b}));
      stack.emplace_back(dim1);
      stack.emplace_back(dim2);
      kernel.run(stack);

      auto o = stack[0].toTensor();
      ASSERT_TRUE(at::allclose(o, ref));
    };

    // Run the kernel in parallel to ensure that the run() method calls in
    // TensorExprKernel are not changing any state.
    constexpr size_t kNumThreads = 4;
    std::vector<std::thread> threads;
    for (size_t id = 0; id < kNumThreads; ++id) {
      threads.emplace_back(run_kernel, id + 5, id + 20);
    }
    for (auto& t : threads) {
      t.join();
    }
  }
#endif
}

} // namespace jit
} // namespace torch
836
test/cpp/tensorexpr/test_expr.cpp
Normal file

@ -0,0 +1,836 @@
#include <gtest/gtest.h>

#include <test/cpp/tensorexpr/test_base.h>

#include <c10/util/irange.h>
#include <test/cpp/tensorexpr/padded_buffer.h>
#include <test/cpp/tensorexpr/test_utils.h>
#include <torch/csrc/jit/tensorexpr/eval.h>
#include <torch/csrc/jit/tensorexpr/ir.h>
#include <torch/csrc/jit/tensorexpr/ir_printer.h>
#include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
#include <torch/csrc/jit/tensorexpr/ir_verifier.h>
#include <torch/csrc/jit/tensorexpr/loopnest.h>
#include <torch/csrc/jit/tensorexpr/tensor.h>

#include <cmath>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>

namespace torch {
namespace jit {
using namespace torch::jit::tensorexpr;

using SimpleIRExprEval = ExprEval<SimpleIREvaluator>;
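// SimpleIRExprEval evaluates a single scalar expression by interpreting the
// IR directly (no codegen involved); the tests read the result back with
// value<T>() and, where needed, bind free variables with bindVar() first.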
TEST(Expr, BasicValueTest) {
  ExprHandle a = IntImm::make(2), b = IntImm::make(3);
  ExprHandle c = Add::make(a, b);
  SimpleIRExprEval eval(c);
  ASSERT_EQ(eval.value<int>(), 5);
}

TEST(Expr, BasicValueTest02) {
  ExprHandle a(2.0f);
  ExprHandle b(3.0f);
  ExprHandle c(4.0f);
  ExprHandle d(5.0f);
  ExprHandle f = (a + b) - (c + d);
  SimpleIRExprEval eval(f);
  ASSERT_EQ(eval.value<float>(), -4.0f);
}
TEST(Expr, IsChannelsLastContiguous) {
  std::vector<VarHandle> vars = {
      VarHandle("var1", kLong),
      VarHandle("var2", kLong),
      VarHandle("var3", kLong),
      VarHandle("var4", kLong),
      VarHandle("var5", kLong)};

  // {
  //   key: ndims,
  //   value: [
  //     ...
  //     [dim_2, dim_1, ..., dim_n]
  //   ]
  // }
  using shapGenInfo = std::unordered_map<int, std::vector<std::vector<int>>>;

  // {
  //   size: [ExprHandle_1, ExprHandle_2, ..., ExprHandle_n],
  //   strides: [
  //     ...
  //     [ExprHandle_x, ExprHandle_y, ..., ExprHandle_z]
  //   ]
  // }
  using shapeInfo =
      std::pair<std::vector<ExprHandle>, std::vector<std::vector<ExprHandle>>>;

  std::vector<int> dims = {3, 4, 5};

  std::unordered_map<int, std::vector<ExprHandle>> dims_expr_vec_conf = {
      {3, std::vector<ExprHandle>(vars.begin(), vars.begin() + 2)},
      {4, std::vector<ExprHandle>(vars.begin(), vars.begin() + 3)},
      {5, std::vector<ExprHandle>(vars.begin(), vars.begin() + 4)},
  };

  shapGenInfo channels_last_cont_shape_conf = {
      {3, {{1, 2, 0}}}, {4, {{1, 3, 2, 0}}}, {5, {{1, 4, 3, 2, 0}}}};
  shapGenInfo channels_last_non_cont_shape_conf = {
      {3, {{2, 1, 0}, {1, 0, 2}}},
      {4, {{3, 1, 2, 0}, {1, 2, 3, 0}, {1, 0, 2, 3}}},
      {5, {{4, 3, 2, 1, 0}, {1, 3, 2, 4, 0}, {1, 4, 3, 2, 0}}}};

  shapGenInfo cont_shape_conf = {
      {3, {{0, 1, 2}}}, {4, {{0, 1, 2, 3}}}, {5, {{0, 1, 2, 3, 4}}}};

  auto shape_gen_fn = [dims_expr_vec_conf](
                          int ndims, shapGenInfo shape_gen_info) -> shapeInfo {
    auto dims_expr_vec = dims_expr_vec_conf.at(ndims);
    std::vector<std::vector<ExprHandle>> strides_expr_vec;
    for (size_t i = 0; i < strides_expr_vec.size(); i++) {
      strides_expr_vec[i].resize(ndims);
    }

    auto stride_gen_fn = [](int indicator, ExprHandle a, ExprHandle b) {
      if (indicator % 2 == 0) {
        return a * b;
      } else {
        return b * a;
      }
    };

    auto stride_order_vec = shape_gen_info.at(ndims);
    for (size_t i = 0; i < strides_expr_vec.size(); i++) {
      auto stride_order = stride_order_vec[i];

      strides_expr_vec[i][stride_order[0]] = 1;
      for (size_t j = 1; j < stride_order.size(); j++) {
        auto cur_dim_idx = stride_order[j];
        auto adjacent_dim_idx = stride_order[j - 1];

        strides_expr_vec[i][cur_dim_idx] = stride_gen_fn(
            i,
            dims_expr_vec[adjacent_dim_idx],
            strides_expr_vec[i][adjacent_dim_idx]);
      }
    }

    return {dims_expr_vec, strides_expr_vec};
  };

  auto check_channels_last_fn = [](int ndims, BufHandle buf_handle) -> bool {
    if (ndims == 3) {
      return buf_handle.is_channels_last_1d_contiguous();
    } else if (ndims == 4) {
      return buf_handle.is_contiguous(at::MemoryFormat::ChannelsLast);
    } else {
      return buf_handle.is_contiguous(at::MemoryFormat::ChannelsLast3d);
    }
  };

  // channels-last contiguous
  for (size_t i = 0; i < dims.size(); i++) {
    auto shape_info = shape_gen_fn(dims[i], channels_last_cont_shape_conf);
    for (size_t j = 0; j < shape_info.second.size(); j++) {
      BufHandle buf_handle("a", shape_info.first, shape_info.second[j], kFloat);
      ASSERT_EQ(check_channels_last_fn(dims[i], buf_handle), true);
    }
  }

  // channels-last non-contiguous
  for (size_t i = 0; i < dims.size(); i++) {
    auto shape_info = shape_gen_fn(dims[i], channels_last_non_cont_shape_conf);
    for (size_t j = 0; j < shape_info.second.size(); j++) {
      BufHandle buf_handle("a", shape_info.first, shape_info.second[j], kFloat);
      ASSERT_EQ(check_channels_last_fn(dims[i], buf_handle), false);
    }
  }

  // contiguous
  for (size_t i = 0; i < dims.size(); i++) {
    auto shape_info = shape_gen_fn(dims[i], cont_shape_conf);
    for (size_t j = 0; j < shape_info.second.size(); j++) {
      BufHandle buf_handle("a", shape_info.first, shape_info.second[j], kFloat);
      ASSERT_EQ(buf_handle.is_contiguous(), true);
    }
  }

  // non-contiguous
  for (size_t i = 0; i < dims.size(); i++) {
    auto shape_info = shape_gen_fn(dims[i], channels_last_cont_shape_conf);
    for (size_t j = 0; j < shape_info.second.size(); j++) {
      BufHandle buf_handle("a", shape_info.first, shape_info.second[j], kFloat);
      ASSERT_EQ(buf_handle.is_contiguous(), false);
    }
  }
}
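// For reference, channels-last contiguity for a rank-4 tensor of shape
// (N, C, H, W) means strides (C*H*W, 1, C*W, C); e.g. shape (2, 3, 4, 5)
// gives strides (60, 1, 15, 3). The rank-3 and rank-5 checks above are the
// 1d and 3d analogues of the same layout.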
TEST(Expr, LetTest01) {
  VarHandle x("x", kFloat);
  ExprHandle body = ExprHandle(2.f) + (x * ExprHandle(3.f) + ExprHandle(4.f));
  SimpleIRExprEval eval(body);
  eval.bindVar(x, ExprHandle(3.f));
  ASSERT_EQ(eval.value<float>(), 2 + (3 * 3 + 4));
}

TEST(Expr, LetTest02) {
  VarHandle x("x", kFloat);
  VarHandle y("y", kFloat);
  ExprHandle body =
      ExprHandle(2.f) + (x * ExprHandle(3.f) + ExprHandle(4.f) * y);
  SimpleIRExprEval eval(body);
  eval.bindVar(x, ExprHandle(3.f));
  eval.bindVar(y, ExprHandle(6.f));
  ASSERT_EQ(eval.value<float>(), 2 + (3 * 3 + 4 * 6));
}

TEST(Expr, LetStmtTest01) {
  BufHandle a_buf("a", {1}, kFloat);
  BufHandle b_buf("b", {1}, kFloat);

  ExprHandle load_a = a_buf.load(0);
  VarHandle var = VarHandle("v", kFloat);
  StmtPtr let_store = Let::make(var, load_a);
  StmtPtr store_b = b_buf.store({0}, var);
  BlockPtr block = Block::make({let_store, store_b});

  SimpleIREvaluator eval(block, {a_buf, b_buf});

  PaddedBuffer<float> a_v(1);
  PaddedBuffer<float> b_v(1);
  PaddedBuffer<float> b_ref(1);

  a_v(0) = 23;
  b_ref(0) = a_v(0);
  eval(a_v, b_v);

  ExpectAllNear(b_v, b_ref, 1e-5);
}

TEST(Expr, IntTest) {
  VarHandle x("x", kInt);
  ExprHandle body = ExprHandle(2) + (x * ExprHandle(3) + ExprHandle(4));
  SimpleIRExprEval eval(body);
  eval.bindVar(x, ExprHandle(3));
  ASSERT_EQ(eval.value<int>(), 2 + (3 * 3 + 4));
}

TEST(Expr, FloatTest) {
  VarHandle x("x", kFloat);
  ExprHandle body = ExprHandle(2.f) + (x * ExprHandle(3.f) + ExprHandle(4.f));
  SimpleIRExprEval eval(body);
  eval.bindVar(x, ExprHandle(3.f));
  ASSERT_EQ(eval.value<float>(), 2 + (3 * 3 + 4));
}

TEST(Expr, ByteTest) {
  VarHandle x("x", kByte);
  ExprHandle body = ExprHandle((uint8_t)2) +
      (x * ExprHandle((uint8_t)3) + ExprHandle((uint8_t)4));
  SimpleIRExprEval eval(body);
  eval.bindVar(x, ExprHandle((uint8_t)3));
  ASSERT_EQ(eval.value<uint8_t>(), 2 + (3 * 3 + 4));
}

TEST(Expr, CharTest) {
  VarHandle x("x", kChar);
  ExprHandle body = ExprHandle((int8_t)2) +
      (x * ExprHandle((int8_t)3) + ExprHandle((int8_t)4));
  SimpleIRExprEval eval(body);
  eval.bindVar(x, ExprHandle((int8_t)3));
  ASSERT_EQ(eval.value<int8_t>(), 2 + (3 * 3 + 4));
}

TEST(Expr, ShortTest) {
  VarHandle x("x", kShort);
  ExprHandle body = ExprHandle((int16_t)2) +
      (x * ExprHandle((int16_t)3) + ExprHandle((int16_t)4));
  SimpleIRExprEval eval(body);
  eval.bindVar(x, ExprHandle((int16_t)3));
  ASSERT_EQ(eval.value<int16_t>(), 2 + (3 * 3 + 4));
}

TEST(Expr, LongTest) {
  VarHandle x("x", kLong);
  ExprHandle body = ExprHandle((int64_t)2) +
      (x * ExprHandle((int64_t)3) + ExprHandle((int64_t)4));
  SimpleIRExprEval eval(body);
  eval.bindVar(x, ExprHandle((int64_t)3));
  ASSERT_EQ(eval.value<int64_t>(), 2 + (3 * 3 + 4));
}

TEST(Expr, HalfTest) {
  VarHandle x("x", kHalf);
  ExprHandle body = ExprHandle((at::Half)2) +
      (x * ExprHandle((at::Half)3) + ExprHandle((at::Half)4));
  SimpleIRExprEval eval(body);
  eval.bindVar(x, ExprHandle((at::Half)3));
  ASSERT_EQ(eval.value<at::Half>(), 2 + (3 * 3 + 4));
}

TEST(Expr, DoubleTest) {
  VarHandle x("x", kDouble);
  ExprHandle body = ExprHandle((double)2) +
      (x * ExprHandle((double)3) + ExprHandle((double)4));
  SimpleIRExprEval eval(body);
  eval.bindVar(x, ExprHandle((double)3));
  ASSERT_EQ(eval.value<double>(), 2 + (3 * 3 + 4));
}
TEST(Expr, VectorAdd01) {
  const int kVectorSize = 8;
  const int kVectorCount = 128;
  const int kTotalSize = kVectorSize * kVectorCount;

  BufHandle a_buf("A", {kTotalSize}, kFloat);
  BufHandle b_buf("B", {kTotalSize}, kFloat);
  BufHandle c_buf("C", {kTotalSize}, kFloat);

  /*
  Build the following:
    for (const auto index : c10::irange(kVectorCount)) {
      store(c_buf, ramp(index * 8, 1, 8),
            load(a_buf, ramp(index * 8, 1, 8)) +
                load(b_buf, ramp(index * 8, 1, 8)))
    }
  */
  VarHandle index = VarHandle("index", kInt);
  ExprHandle load_a =
      a_buf.load({Ramp::make(index * kVectorSize, 1, kVectorSize)});
  ExprHandle load_b =
      b_buf.load({Ramp::make(index * kVectorSize, 1, kVectorSize)});
  ExprHandle value = load_a + load_b;
  StmtPtr store_c =
      c_buf.store({Ramp::make(index * kVectorSize, 1, kVectorSize)}, value);
  StmtPtr stmt = For::make(index, 0, kVectorCount, store_c);

  ASSERT_EQ(load_a.dtype(), Dtype(kFloat, kVectorSize));
  ASSERT_EQ(load_b.dtype(), Dtype(kFloat, kVectorSize));
  ASSERT_EQ(value.dtype(), Dtype(kFloat, kVectorSize));

  PaddedBuffer<float> a_v(kTotalSize);
  PaddedBuffer<float> b_v(kTotalSize);
  PaddedBuffer<float> c_v(kTotalSize);
  PaddedBuffer<float> c_ref(kTotalSize);
  for (const auto i : c10::irange(kTotalSize)) {
    a_v(i) = i * i;
    b_v(i) = i * i * 4;
    c_ref(i) = a_v(i) + b_v(i);
  }
  SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf, c_buf});
  ir_eval(a_v, b_v, c_v);
  ExpectAllNear(c_v, c_ref, 1e-5);
}
TEST(Expr, CompareSelectEQ) {
  constexpr int N = 1024;
  BufHandle a("A", {N}, kInt);
  BufHandle b("B", {N}, kInt);
  BufHandle c("C", {N}, kInt);
  std::vector<int> a_buffer(N, 1);
  std::vector<int> b_buffer(N, 1);
  std::vector<int> c_buffer(N, 0);
  std::vector<int> c_ref(N, 0);

  VarHandle i("i", kInt);
  auto memcpy_expr = For::make(
      i,
      0,
      N,
      c.store(
          {i},
          CompareSelect::make(
              a.load(i), b.load(i), CompareSelectOperation::kEQ)));

  SimpleIREvaluator ir_eval(memcpy_expr, {a, b, c});
  ir_eval(a_buffer, b_buffer, c_buffer);

  ASSERT_EQ(a_buffer.size(), N);
  ASSERT_EQ(b_buffer.size(), N);
  ASSERT_EQ(c_buffer.size(), N);

  assertAllEqual(a_buffer, 1);
  assertAllEqual(b_buffer, 1);
  assertAllEqual(c_buffer, 1);
}

TEST(Expr, CompareSelectDtypes) {
  // LHS and RHS expressions should have the same dtype, but this dtype could
  // differ from the dtype of the return values (but dtypes of true and false
  // return values should be the same).
  // This test constructs a CompareSelect expression where the input dtype is
  // different from the output dtype and verifies that it works correctly:
  //   result = ((int)lhs == (int)rhs) ? (float)retval1 : (float)retval2
  constexpr int N = 1024;
  BufHandle a("A", {N}, kInt);
  BufHandle b("B", {N}, kInt);
  BufHandle c("C", {N}, kFloat);
  std::vector<int> a_buffer(N, 1);
  std::vector<int> b_buffer(N, 1);
  std::vector<float> c_buffer(N, 0.0f);
  std::vector<float> c_ref(N, 3.14f);

  VarHandle i("i", kInt);
  // C[i] = (A[i] == B[i]) ? 3.14f : 2.78f
  // A and B are int, C is float.
  auto select_expr = For::make(
      i,
      0,
      N,
      c.store(
          {i},
          CompareSelect::make(
              a.load(i),
              b.load(i),
              FloatImm::make(3.14f),
              FloatImm::make(2.78f),
              CompareSelectOperation::kEQ)));

  SimpleIREvaluator ir_eval(select_expr, {a, b, c});
  ir_eval(a_buffer, b_buffer, c_buffer);

  ASSERT_EQ(a_buffer.size(), N);
  ASSERT_EQ(b_buffer.size(), N);
  ASSERT_EQ(c_buffer.size(), N);

  assertAllEqual(a_buffer, 1);
  assertAllEqual(b_buffer, 1);
  ExpectAllNear(c_buffer, c_ref, 1e-7);
}

TEST(Expr, IntrinsicsDtypes) {
  constexpr int N = 256;
  BufHandle a("A", {N}, kDouble);
  BufHandle b("B", {N}, kDouble);
  std::vector<double> a_buffer(N, -10.0);
  std::vector<double> b_buffer(N, 0.0);
  std::vector<double> b_ref(N, 10.0);

  VarHandle i("i", kInt);
  auto abs_expr = For::make(i, 0, N, b.store({i}, tensorexpr::abs(a.load(i))));

  SimpleIREvaluator ir_eval(abs_expr, {a, b});
  ir_eval(a_buffer, b_buffer);

  ASSERT_EQ(a_buffer.size(), N);
  ASSERT_EQ(b_buffer.size(), N);

  assertAllEqual(a_buffer, -10.0);
  ExpectAllNear(b_buffer, b_ref, 1e-7);
}

TEST(Expr, Substitute01) {
  VarPtr x = alloc<Var>("x", kFloat);
  VarPtr y = alloc<Var>("y", kFloat);
  ExprPtr e =
      alloc<Mul>(alloc<Sub>(x, alloc<FloatImm>(1.0f)), alloc<Add>(x, y));

  VarPtr z = alloc<Var>("z", kFloat);
  ExprPtr e2 = Substitute(e, {{x, alloc<Add>(z, alloc<FloatImm>(5.0f))}});
  ExprPtr e2_ref = alloc<Mul>(
      alloc<Sub>(alloc<Add>(z, alloc<FloatImm>(5.0f)), alloc<FloatImm>(1.0f)),
      alloc<Add>(alloc<Add>(z, alloc<FloatImm>(5.0f)), y));
  std::ostringstream oss;
  oss << *e2;
  std::string e2_str = oss.str();

  oss.str("");
  oss << *e2_ref;
  std::string e2_ref_str = oss.str();
  ASSERT_EQ(e2_str, e2_ref_str);
}

TEST(Expr, Math01) {
  ExprHandle v = sin(ExprHandle(1.0f));

  std::ostringstream oss;
  oss << v;
  ASSERT_EQ(oss.str(), "sin(1.f)");

  SimpleIRExprEval eval(v);
  float v_ref = std::sin(1.0f);
  float res = eval.value<float>();
  ASSERT_NEAR(res, v_ref, 1e-6);
}
TEST(Expr, UnaryMath01) {
  struct TestConfig {
    std::function<ExprHandle(const ExprHandle&)> func;
    std::function<float(float)> ref_func;
  };

  std::vector<TestConfig> test_configs = {
      {[](const ExprHandle& v) { return sin(v); },
       [](float v) { return std::sin(v); }},
      {[](const ExprHandle& v) { return cos(v); },
       [](float v) { return std::cos(v); }},
      {[](const ExprHandle& v) { return tan(v); },
       [](float v) { return std::tan(v); }},
      {[](const ExprHandle& v) { return asin(v); },
       [](float v) { return std::asin(v); }},
      {[](const ExprHandle& v) { return acos(v); },
       [](float v) { return std::acos(v); }},
      {[](const ExprHandle& v) { return atan(v); },
       [](float v) { return std::atan(v); }},
      {[](const ExprHandle& v) { return sinh(v); },
       [](float v) { return std::sinh(v); }},
      {[](const ExprHandle& v) { return cosh(v); },
       [](float v) { return std::cosh(v); }},
      {[](const ExprHandle& v) { return tanh(v); },
       [](float v) { return std::tanh(v); }},
      {[](const ExprHandle& v) { return exp(v); },
       [](float v) { return std::exp(v); }},
      {[](const ExprHandle& v) { return tensorexpr::abs(v); },
       [](float v) { return std::fabs(v); }},
      {[](const ExprHandle& v) { return log(v); },
       [](float v) { return std::log(v); }},
      {[](const ExprHandle& v) { return log2(v); },
       [](float v) { return std::log2(v); }},
      {[](const ExprHandle& v) { return log10(v); },
       [](float v) { return std::log10(v); }},
      {[](const ExprHandle& v) { return erf(v); },
       [](float v) { return std::erf(v); }},
      {[](const ExprHandle& v) { return sqrt(v); },
       [](float v) { return std::sqrt(v); }},
      {[](const ExprHandle& v) { return rsqrt(v); },
       [](float v) { return 1.0f / std::sqrt(v); }},
      {[](const ExprHandle& v) { return ceil(v); },
       [](float v) { return std::ceil(v); }},
      {[](const ExprHandle& v) { return floor(v); },
       [](float v) { return std::floor(v); }},
      {[](const ExprHandle& v) { return round(v); },
       [](float v) { return std::round(v); }},
      {[](const ExprHandle& v) { return trunc(v); },
       [](float v) { return std::trunc(v); }},
  };

  for (const TestConfig& test_config : test_configs) {
    const float input_v = 0.8765f;
    ExprHandle v = test_config.func(ExprHandle(input_v));
    float v_ref = test_config.ref_func(input_v);
    SimpleIRExprEval eval(v);
    ASSERT_NEAR(eval.value<float>(), v_ref, 1e-6);
  }

  // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
  for (float input_v : {std::nan("1"), 0., .5}) {
    ExprHandle v = FloatImm::make(input_v);
    SimpleIRExprEval eval(Intrinsics::make(kIsNan, v));
    ASSERT_NEAR(eval.value<int>(), std::isnan(input_v), 0);
  }
}
TEST(Expr, BinaryMath01) {
  struct TestConfig {
    std::function<ExprHandle(const ExprHandle&, const ExprHandle&)> func;
    std::function<float(float, float)> ref_func;
  };

  std::vector<TestConfig> test_configs = {
      {[](const ExprHandle& v1, const ExprHandle& v2) { return pow(v1, v2); },
       [](float v1, float v2) { return std::pow(v1, v2); }},
      {[](const ExprHandle& v1, const ExprHandle& v2) { return fmod(v1, v2); },
       [](float v1, float v2) { return std::fmod(v1, v2); }},
  };

  for (const TestConfig& test_config : test_configs) {
    const float v1 = 0.8765f;
    float v2 = 1.2345f;
    ExprHandle v_expr = test_config.func(ExprHandle(v1), ExprHandle(v2));
    float v_ref = test_config.ref_func(v1, v2);
    SimpleIRExprEval eval(v_expr);
    ASSERT_NEAR(eval.value<float>(), v_ref, 1e-6);
  }
}

TEST(Expr, LogicalOps01) {
  ExprHandle a(23);
  ExprHandle b(11);
  ExprHandle c(0.72f);
  ExprHandle d(0.69f);
  ExprHandle f1 = (a > b) && (c > d);
  ExprHandle f2 = (a > b) && (c < d);
  ExprHandle f3 = (a < b) && (c > d);
  ExprHandle f4 = (a < b) && (c < d);
  ExprHandle f5 = (a < b) || (c > d);
  ExprHandle f6 = (a < b) || (c < d);
  ExprHandle f7 = (a > b) || (c < d);
  ExprHandle f8 = (a > b) || (c > d);

  SimpleIRExprEval eval1(f1);
  SimpleIRExprEval eval2(f2);
  SimpleIRExprEval eval3(f3);
  SimpleIRExprEval eval4(f4);
  SimpleIRExprEval eval5(f5);
  SimpleIRExprEval eval6(f6);
  SimpleIRExprEval eval7(f7);
  SimpleIRExprEval eval8(f8);
  ASSERT_EQ(eval1.value<int>(), 1);
  ASSERT_EQ(eval2.value<int>(), 0);
  ASSERT_EQ(eval3.value<int>(), 0);
  ASSERT_EQ(eval4.value<int>(), 0);
  ASSERT_EQ(eval5.value<int>(), 1);
  ASSERT_EQ(eval6.value<int>(), 0);
  ASSERT_EQ(eval7.value<int>(), 1);
  ASSERT_EQ(eval8.value<int>(), 1);
}

TEST(Expr, LogicalOps02) {
  ExprHandle a(23);
  ExprHandle b(11);
  ExprHandle c(0.72f);
  ExprHandle d(0.72f);

  ExprHandle f1 = (a > b) || (c > d);
  ExprHandle f2 = (a > b) && (c <= d);
  ExprHandle f3 = (a > b) && (c > d);
  ExprHandle ff1 = f1 && f2;
  ExprHandle ff2 = f2 || f3;

  SimpleIRExprEval eval1(ff1);
  SimpleIRExprEval eval2(ff2);
  ASSERT_EQ(eval1.value<int>(), 1);
  ASSERT_EQ(eval2.value<int>(), 1);
}

TEST(Expr, LogicalOps03) {
  ExprHandle a(23);
  ExprHandle b(11);
  ExprHandle c(0.72f);
  ExprHandle d(0.69f);

  // Bool types
  ExprHandle bool_f1 = (a > b) && BoolImm::make(true);
  ExprHandle bool_f2 = (c <= d) || BoolImm::make(true);

  // Int types
  ExprHandle int_f1 = (a > b) && IntImm::make(1);
  ExprHandle int_f2 = (c <= d) || IntImm::make(1);

  // Short types
  ExprHandle short_f1 = (a > b) && ShortImm::make(1);
  ExprHandle short_f2 = (c <= d) || ShortImm::make(1);

  // Long types
  ExprHandle long_f1 = (a > b) && LongImm::make(1);
  ExprHandle long_f2 = (c <= d) || LongImm::make(1);

  // Char types
  ExprHandle char_f1 = (a > b) && CharImm::make(1);
  ExprHandle char_f2 = (c <= d) || CharImm::make(1);

  // Byte types
  ExprHandle byte_f1 = (a > b) && ByteImm::make(1);
  ExprHandle byte_f2 = (c <= d) || ByteImm::make(1);

  SimpleIRExprEval eval1(bool_f1);
  SimpleIRExprEval eval2(bool_f2);
  SimpleIRExprEval eval3(int_f1);
  SimpleIRExprEval eval4(int_f2);
  SimpleIRExprEval eval5(short_f1);
  SimpleIRExprEval eval6(short_f2);
  SimpleIRExprEval eval7(long_f1);
  SimpleIRExprEval eval8(long_f2);
  SimpleIRExprEval eval9(char_f1);
  SimpleIRExprEval eval10(char_f2);
  SimpleIRExprEval eval11(byte_f1);
  SimpleIRExprEval eval12(byte_f2);

  ASSERT_EQ(eval1.value<bool>(), true);
  ASSERT_EQ(eval2.value<bool>(), true);
  ASSERT_EQ(eval3.value<int>(), 1);
  ASSERT_EQ(eval4.value<int>(), 1);
  ASSERT_EQ(eval5.value<int16_t>(), 1);
  ASSERT_EQ(eval6.value<int16_t>(), 1);
  ASSERT_EQ(eval7.value<int64_t>(), 1);
  ASSERT_EQ(eval8.value<int64_t>(), 1);
  ASSERT_EQ(eval9.value<int8_t>(), 1);
  ASSERT_EQ(eval10.value<int8_t>(), 1);
  ASSERT_EQ(eval11.value<uint8_t>(), 1);
  ASSERT_EQ(eval12.value<uint8_t>(), 1);
}
TEST(Expr, BitwiseOps) {
  ExprHandle a(59);
  ExprHandle b(11);
  ExprHandle c(101);
  ExprHandle d(2);
  ExprHandle f = (((a ^ (b << 1)) & c) >> 2) | d;

  SimpleIRExprEval eval(f);
  ASSERT_EQ(eval.value<int>(), 11);
}
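// Worked out: b << 1 = 22; 59 ^ 22 = 45; 45 & 101 = 37; 37 >> 2 = 9;
// and 9 | 2 = 11, the value asserted above.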
TEST(Expr, DynamicShapeAdd) {
  auto testWithSize = [](int32_t size) {
    VarHandle n("n", kInt);
    BufHandle a("a", {n}, kFloat);
    BufHandle b("b", {n}, kFloat);
    BufHandle c("c", {n}, kFloat);
    VarHandle i("i", kInt);
    StmtPtr s = For::make(i, 0, n, c.store({i}, a.load(i) + b.load(i)));
    std::vector<float> aData(size, 1.0f);
    std::vector<float> bData(size, 2.0f);
    std::vector<float> cData(size, 0.0f);
    SimpleIREvaluator(s, {a, b, c, n})(aData, bData, cData, size);
    ExpectAllNear(cData, std::vector<float>(size, 3.0f), 1e-7);
  };
  testWithSize(1);
  testWithSize(16);
  testWithSize(37);
}

TEST(Expr, OutOfBounds) {
  ExprHandle N(10);
  ExprHandle start(0);
  ExprHandle stop(15);
  VarHandle i("i", kInt);

  BufHandle X("X", {N}, kInt);

  auto body = Store::make(X, {i}, i);
  auto stmt = For::make(i, start, stop, body);

  PaddedBuffer<int> data(20);

  EXPECT_ANY_THROW(SimpleIREvaluator(stmt, {X})(data));
}

TEST(Expr, OutOfBounds2d) {
  std::vector<std::pair<int, int>> size_options = {{10, 15}, {15, 10}};
  for (auto sizes : size_options) {
    ExprHandle N(sizes.first);
    ExprHandle M(sizes.second);
    ExprHandle start(0);
    ExprHandle stopInner(15);
    ExprHandle stopOuter(15);
    VarHandle i("i", kInt);
    VarHandle j("j", kInt);

    BufHandle X("X", {N, M}, kInt);

    auto body = Store::make(X, {i, j}, i);
    auto inner = For::make(j, start, stopInner, body);
    auto stmt = For::make(i, start, stopOuter, inner);

    PaddedBuffer<int> data(400);

    EXPECT_ANY_THROW(SimpleIREvaluator(stmt, {X})(data));
  }
}

TEST(Expr, OutOfBounds2dFlattenedIndex) {
  ExprHandle buf_size(149);
  ExprHandle start(0);
  ExprHandle stopInner(15);
  ExprHandle stopOuter(10);
  VarHandle i("i", kInt);
  VarHandle j("j", kInt);

  BufHandle X("X", {buf_size}, kInt);

  auto idx = Add::make(Mul::make(i, stopInner), j);
  auto body = Store::make(X, {idx}, i);
  auto inner = For::make(j, start, stopInner, body);
  auto stmt = For::make(i, start, stopOuter, inner);

  PaddedBuffer<int> data(400);

  EXPECT_ANY_THROW(SimpleIREvaluator(stmt, {X})(data));
}
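// The three OutOfBounds tests above rely on SimpleIREvaluator bounds-checking
// buffer accesses: each loop intentionally indexes past the declared buffer
// extent, and the evaluator is expected to throw rather than write out of
// bounds.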
void testCond01() {
  const int N = 16;
  PaddedBuffer<float> a_v(N);
  BufHandle a_buf("a", {N}, kFloat);
  VarHandle index = VarHandle("index", kInt);
  StmtPtr assign_x2 = a_buf.store({index}, cast<float>(index) * 2);
  StmtPtr assign_x3 = a_buf.store({index}, cast<float>(index) * 3);
  ExprHandle even_cond = CompareSelect::make(Mod::make(index, 2), 0, kEQ);
  StmtPtr assign = Cond::make(even_cond, assign_x2, assign_x3);
  StmtPtr for_stmt = For::make(index, 0, N, assign);
  SimpleIREvaluator(for_stmt, {a_buf})(a_v);

  PaddedBuffer<float> a_ref(N);
  for (const auto i : c10::irange(N)) {
    if (i % 2 == 0) {
      a_ref(i) = i * 2;
    } else {
      a_ref(i) = i * 3;
    }
  }
  ExpectAllNear(a_v, a_ref, 1e-5);
}

void testIfThenElse01() {
  ExprHandle v = ifThenElse(ExprHandle(1), ExprHandle(1.0f), ExprHandle(2.0f));

  std::ostringstream oss;
  oss << v;
  ASSERT_EQ(oss.str(), "IfThenElse(1, 1.f, 2.f)");

  SimpleIRExprEval eval(v);
  ASSERT_EQ(eval.value<float>(), 1.0f);
}

void testIfThenElse02() {
  ExprHandle v = ifThenElse(ExprHandle(0), ExprHandle(1.0f), ExprHandle(2.0f));

  std::ostringstream oss;
  oss << v;
  ASSERT_EQ(oss.str(), "IfThenElse(0, 1.f, 2.f)");

  SimpleIRExprEval eval(v);
  ASSERT_EQ(eval.value<float>(), 2.0f);
}

void testIfThenElse03() {
  ExprHandle v =
      ifThenElse(BoolImm::make(false), ExprHandle(1.0f), ExprHandle(2.0f));

  std::ostringstream oss;
  oss << v;
  ASSERT_EQ(oss.str(), "IfThenElse(0, 1.f, 2.f)");

  SimpleIRExprEval eval(v);
  ASSERT_EQ(eval.value<float>(), 2.0f);
}

void testStmtClone() {
  const int N = 16;

  BufHandle a_buf("a", {N}, kInt);
  VarHandle index = VarHandle("index", kInt);
  StmtPtr body = a_buf.store({index}, 5);
  StmtPtr loop = For::make(index, 0, N, body);

  StmtPtr cloned_loop = Stmt::clone(loop);
  std::vector<int> orig_loop_results(N);
  std::vector<int> cloned_loop_results(N);
  SimpleIREvaluator(loop, {a_buf})(orig_loop_results);
  SimpleIREvaluator(cloned_loop, {a_buf})(cloned_loop_results);

  assertAllEqual(orig_loop_results, 5);
  assertAllEqual(cloned_loop_results, 5);

  // Let's add another assign to the body in the cloned loop and verify that
  // the original statement hasn't changed while the cloned one has.
  StmtPtr body_addition = a_buf.store({index}, 33);
  BlockPtr cloned_body = static_to<Block>(static_to<For>(cloned_loop)->body());
  cloned_body->append_stmt(body_addition);

  std::vector<int> orig_loop_results_after_mutation(N);
  std::vector<int> cloned_loop_results_after_mutation(N);
  SimpleIREvaluator(loop, {a_buf})(orig_loop_results_after_mutation);
  SimpleIREvaluator(cloned_loop, {a_buf})(cloned_loop_results_after_mutation);

  assertAllEqual(orig_loop_results_after_mutation, 5);
  assertAllEqual(cloned_loop_results_after_mutation, 33);
}

} // namespace jit
} // namespace torch
1061
test/cpp/tensorexpr/test_external_calls.cpp
Normal file
File diff suppressed because it is too large

319
test/cpp/tensorexpr/test_graph_opt.cpp
Normal file

@ -0,0 +1,319 @@
#include <gtest/gtest.h>

#include <test/cpp/tensorexpr/test_base.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/irparser.h>
#include <torch/csrc/jit/passes/lower_tuples.h>
#include <torch/csrc/jit/tensorexpr/graph_opt.h>
#include <torch/csrc/jit/tensorexpr/kernel.h>
#include <torch/csrc/jit/testing/file_check.h>
#include <torch/torch.h>

#include <limits>

namespace torch {
namespace jit {

using namespace torch::jit::tensorexpr;

class GraphOpt : public ::testing::Test {
 public:
  void SetUp() override {
    old_cat_wo_conditionals_ = getCatWoConditionals();
    getCatWoConditionals() = true;
  }

  void TearDown() override {
    getCatWoConditionals() = old_cat_wo_conditionals_;
  }

 private:
  bool old_cat_wo_conditionals_;
};
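// getCatWoConditionals() returns a mutable reference to a global flag, which
// lets the fixture flip it on per test and restore the saved value in
// TearDown(). The flag appears to select the conditional-free lowering of
// aten::cat that the cat-rewriting tests below depend on.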
TEST_F(GraphOpt, OptimizeCat) {
#ifdef TORCH_ENABLE_LLVM
  const auto graph_string = R"IR(
    graph(%x : Float(10, strides=[1], device=cpu),
          %y : Float(20, strides=[1], device=cpu),
          %z : Float(30, strides=[1], device=cpu)):
      %dim : int = prim::Constant[value=0]()
      %xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z)
      %cat : Float(60, strides=[1], device=cpu) = aten::cat(%xyz_list, %dim)
      %5 : Float(60, strides=[1], device=cpu) = aten::log(%cat)
      return (%5))IR";
  auto g = std::make_shared<Graph>();
  torch::jit::parseIR(graph_string, g.get());
  g->lint();

  TensorExprKernel kernel(g);

  // The `aten::log` op must be moved to the inputs of `aten::cat`.
  testing::FileCheck()
      .check("aten::log")
      ->check("aten::log")
      ->check("aten::log")
      ->check("aten::cat")
      ->check_not("aten::log")
      ->run(*kernel.graph());

  auto x = at::rand({10}, at::kFloat);
  auto y = at::rand({20}, at::kFloat);
  auto z = at::rand({30}, at::kFloat);
  auto ref = at::log(at::cat({x, y, z}, 0));

  std::vector<at::Tensor> inputs = {x, y, z};
  std::vector<IValue> stack = fmap<IValue>(inputs);
  kernel.run(stack);
  auto out = stack[0].toTensor();
  ASSERT_EQ(out.sizes(), ref.sizes());
  ASSERT_EQ(out.dtype(), ref.dtype());
  ASSERT_TRUE(at::allclose(out, ref));
#endif
}
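// The FileCheck pattern above captures the cat transformation being tested:
// a single-tensor elementwise op consuming an aten::cat is hoisted onto each
// of the cat inputs, so it shows up once per input and the cat becomes the
// last op in the chain.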
TEST_F(GraphOpt, OptimizeCat2) {
#ifdef TORCH_ENABLE_LLVM
  const auto graph_string = R"IR(
    graph(%x : Float(10, strides=[1], device=cpu),
          %y : Float(20, strides=[1], device=cpu),
          %z : Float(30, strides=[1], device=cpu)):
      %dim : int = prim::Constant[value=0]()
      %xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z)
      %cat : Float(60, strides=[1], device=cpu) = aten::cat(%xyz_list, %dim)
      %5 : Float(60, strides=[1], device=cpu) = aten::log(%cat)
      %6 : Float(60, strides=[1], device=cpu) = aten::tanh(%5)
      return (%6))IR";
  auto g = std::make_shared<Graph>();
  torch::jit::parseIR(graph_string, g.get());
  g->lint();

  TensorExprKernel kernel(g);

  // The `aten::log` and `aten::tanh` ops must be moved to the inputs of
  // `aten::cat`.
  testing::FileCheck()
      .check("aten::log")
      ->check("aten::log")
      ->check("aten::log")
      ->check("aten::tanh")
      ->check("aten::tanh")
      ->check("aten::tanh")
      ->check("aten::cat")
      ->check_not("aten::log")
      ->check_not("aten::tanh")
      ->run(*kernel.graph());

  auto x = at::rand({10}, at::kFloat);
  auto y = at::rand({20}, at::kFloat);
  auto z = at::rand({30}, at::kFloat);
  auto ref = at::tanh(at::log(at::cat({x, y, z}, 0)));

  std::vector<at::Tensor> inputs = {x, y, z};
  std::vector<IValue> stack = fmap<IValue>(inputs);
  kernel.run(stack);
  auto out = stack[0].toTensor();
  ASSERT_EQ(out.sizes(), ref.sizes());
  ASSERT_EQ(out.dtype(), ref.dtype());
  ASSERT_TRUE(at::allclose(out, ref));
#endif
}

TEST_F(GraphOpt, OptimizeCat3) {
#ifdef TORCH_ENABLE_LLVM
  const auto graph_string = R"IR(
    graph(%a : Float(60, strides=[1], device=cpu),
          %x : Float(10, strides=[1], device=cpu),
          %y : Float(20, strides=[1], device=cpu),
          %z : Float(30, strides=[1], device=cpu)):
      %dim : int = prim::Constant[value=0]()
      %xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z)
      %cat : Float(60, strides=[1], device=cpu) = aten::cat(%xyz_list, %dim)
      %5 : Float(60, strides=[1], device=cpu) = aten::tanh(%cat)
      %6 : Float(60, strides=[1], device=cpu) = aten::mul(%a, %5)
      return (%6))IR";
  auto g = std::make_shared<Graph>();
  torch::jit::parseIR(graph_string, g.get());
  g->lint();

  TensorExprKernel kernel(g);

  // The `aten::tanh` op must be moved to the inputs of `aten::cat`.
  // But the `aten::mul` op must not be moved since it is not a single-tensor
  // op (it has 2 tensor inputs).
  testing::FileCheck()
      .check("aten::tanh")
      ->check("aten::tanh")
      ->check("aten::tanh")
      ->check("aten::cat")
      ->check("aten::mul")
      ->check_not("aten::tanh")
      ->run(*kernel.graph());

  auto a = at::rand({60}, at::kFloat);
  auto x = at::rand({10}, at::kFloat);
  auto y = at::rand({20}, at::kFloat);
  auto z = at::rand({30}, at::kFloat);
  auto ref = at::tanh(at::cat({x, y, z}, 0)) * a;

  std::vector<at::Tensor> inputs = {a, x, y, z};
  std::vector<IValue> stack = fmap<IValue>(inputs);
  kernel.run(stack);
  auto out = stack[0].toTensor();
  ASSERT_EQ(out.sizes(), ref.sizes());
  ASSERT_EQ(out.dtype(), ref.dtype());
  ASSERT_TRUE(at::allclose(out, ref));
#endif
}

TEST_F(GraphOpt, OptimizeCatWithTypePromotionInUser) {
#ifdef TORCH_ENABLE_LLVM
  const auto graph_string = R"IR(
    graph(%x : Int(10, strides=[1], device=cpu),
          %y : Int(20, strides=[1], device=cpu),
          %z : Int(30, strides=[1], device=cpu)):
      %dim : int = prim::Constant[value=0]()
      %xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z)
      %cat : Int(60, strides=[1], device=cpu) = aten::cat(%xyz_list, %dim)
      %5 : Float(60, strides=[1], device=cpu) = aten::tanh(%cat)
      return (%5))IR";
  auto g = std::make_shared<Graph>();
  torch::jit::parseIR(graph_string, g.get());
  g->lint();

  TensorExprKernel kernel(g);

  // The `aten::tanh` op must be moved to the inputs of `aten::cat`.
  // The scalar type of the inputs to `cat` should now be `Float` since they
  // are the result of `tanh` which does the type promotion.
  testing::FileCheck()
      .check("aten::tanh")
      ->check("aten::tanh")
      ->check("aten::tanh")
      ->check("aten::cat")
      ->check_not("aten::tanh")
      ->run(*kernel.graph());

  auto x = at::randint(std::numeric_limits<int>::max(), {10}, at::kInt);
  auto y = at::randint(std::numeric_limits<int>::max(), {20}, at::kInt);
  auto z = at::randint(std::numeric_limits<int>::max(), {30}, at::kInt);
  auto ref = at::tanh(at::cat({x, y, z}, 0));

  std::vector<at::Tensor> inputs = {x, y, z};
  std::vector<IValue> stack = fmap<IValue>(inputs);
  kernel.run(stack);
  auto out = stack[0].toTensor();
  ASSERT_EQ(out.sizes(), ref.sizes());
  ASSERT_EQ(out.dtype(), ref.dtype());
  ASSERT_TRUE(at::allclose(out, ref));
#endif
}

TEST_F(GraphOpt, OptimizeCatWithTypePromotionInCat) {
#ifdef TORCH_ENABLE_LLVM
  const auto graph_string = R"IR(
    graph(%x : Float(10, strides=[1], device=cpu),
          %y : Float(20, strides=[1], device=cpu),
          %z : Double(30, strides=[1], device=cpu)):
      %dim : int = prim::Constant[value=0]()
      %xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z)
      %cat : Double(60, strides=[1], device=cpu) = aten::cat(%xyz_list, %dim)
      %5 : Double(60, strides=[1], device=cpu) = aten::log(%cat)
      return (%5))IR";
  auto g = std::make_shared<Graph>();
  torch::jit::parseIR(graph_string, g.get());
  g->lint();

  TensorExprKernel kernel(g);

  // No transformation should have happened because the `aten::cat` op performs
  // type promotion. This case is currently not handled.
  testing::FileCheck()
      .check("aten::cat")
      ->check("aten::log")
      ->check_not("aten::cat")
      ->check_not("aten::log")
      ->run(*kernel.graph());
#endif
}

TEST_F(GraphOpt, OptimizeCatNoSingleTensorElementwiseOp) {
#ifdef TORCH_ENABLE_LLVM
  const auto graph_string = R"IR(
    graph(%0 : Float(60, strides=[1], device=cpu),
          %x : Float(10, strides=[1], device=cpu),
          %y : Float(20, strides=[1], device=cpu),
          %z : Float(30, strides=[1], device=cpu)):
      %dim : int = prim::Constant[value=0]()
      %xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z)
      %cat : Float(60, strides=[1], device=cpu) = aten::cat(%xyz_list, %dim)
      %5 : Float(60, strides=[1], device=cpu) = aten::mul(%0, %cat)
      return (%5))IR";
  auto g = std::make_shared<Graph>();
  torch::jit::parseIR(graph_string, g.get());
  g->lint();

  TensorExprKernel kernel(g);

  // No transformation is expected since the consumers of cat are not
  // single-tensor element-wise ops.
  testing::FileCheck()
      .check("aten::cat")
      ->check("aten::mul")
      ->check_not("aten::cat")
      ->check_not("aten::mul")
      ->run(*kernel.graph());
#endif
}

TEST_F(GraphOpt, OptimizeCatNoSingleTensorElementwiseOp2) {
#ifdef TORCH_ENABLE_LLVM
  const auto graph_string = R"IR(
    graph(%0 : Float(60, strides=[1], device=cpu),
          %1 : Float(60, strides=[1], device=cpu),
          %x : Float(10, strides=[1], device=cpu),
          %y : Float(20, strides=[1], device=cpu),
          %z : Float(30, strides=[1], device=cpu)):
      %one : int = prim::Constant[value=1]()
      %dim : int = prim::Constant[value=0]()
||||
%xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z)
|
||||
%cat : Float(60, strides=[1], device=cpu) = aten::cat(%xyz_list, %dim)
|
||||
%5 : Float(60, strides=[1], device=cpu) = aten::mul(%0, %cat)
|
||||
%6 : Float(60, strides=[1], device=cpu) = aten::add(%5, %1, %one)
|
||||
return (%6))IR";
|
||||
auto g = std::make_shared<Graph>();
|
||||
torch::jit::parseIR(graph_string, g.get());
|
||||
g->lint();
|
||||
|
||||
TensorExprKernel kernel(g);
|
||||
|
||||
// No transformation is expected since the consumers of cat are not
|
||||
// single-tensor element-wise ops.
|
||||
testing::FileCheck()
|
||||
.check("aten::cat")
|
||||
->check("aten::mul")
|
||||
->check("aten::add")
|
||||
->check_not("aten::cat")
|
||||
->check_not("aten::mul")
|
||||
->check_not("aten::add")
|
||||
->run(*kernel.graph());
|
||||
#endif
|
||||
}
|
||||
|
||||
TEST_F(GraphOpt, AOTGraphPrepPasses) {
|
||||
const auto graph_string = R"IR(
|
||||
graph(%x, %y, %z, %i : int):
|
||||
%xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z)
|
||||
return (%xyz_list, %i))IR";
|
||||
auto g = std::make_shared<Graph>();
|
||||
torch::jit::parseIR(graph_string, g.get());
|
||||
|
||||
removeGraphOutput(g, 1);
|
||||
replaceListOutputWithTuple(g);
|
||||
LowerAllTuples(g);
|
||||
|
||||
testing::FileCheck().check("return (%x, %y, %z)")->run(*g);
|
||||
}
|
||||
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
98 test/cpp/tensorexpr/test_ir_printer.cpp Normal file
@ -0,0 +1,98 @@
#include <gtest/gtest.h>

#include <stdexcept>
#include "test/cpp/tensorexpr/test_base.h"

#include <torch/csrc/jit/tensorexpr/expr.h>
#include <torch/csrc/jit/tensorexpr/ir.h>
#include <torch/csrc/jit/tensorexpr/ir_printer.h>
#include <torch/csrc/jit/tensorexpr/loopnest.h>
#include <torch/csrc/jit/tensorexpr/tensor.h>
#include <torch/csrc/jit/testing/file_check.h>

#include <sstream>
namespace torch {
namespace jit {

using namespace torch::jit::tensorexpr;

TEST(IRPrinter, BasicValueTest) {
  ExprHandle a = IntImm::make(2), b = IntImm::make(3);
  ExprHandle c = Add::make(a, b);

  std::stringstream ss;
  ss << c;
  ASSERT_EQ(ss.str(), "2 + 3");
}

TEST(IRPrinter, BasicValueTest02) {
  ExprHandle a(2.0f);
  ExprHandle b(3.0f);
  ExprHandle c(4.0f);
  ExprHandle d(5.0f);
  ExprHandle f = (a + b) - (c + d);

  std::stringstream ss;
  ss << f;
  ASSERT_EQ(ss.str(), "(2.f + 3.f) - (4.f + 5.f)");
}

TEST(IRPrinter, BasicValueTest03) {
  ExprHandle a(3.402823466385289e+38f);
  ExprHandle b(-3.402823466385289e+38f);
  std::stringstream ss;
  ss << a << ", " << b;
  ASSERT_EQ(ss.str(), "3.402823466385289e+38f, -3.402823466385289e+38f");
}

TEST(IRPrinter, CastTest) {
  VarHandle x("x", kHalf);
  VarHandle y("y", kFloat);
  ExprHandle body = ExprHandle(2.f) +
      (Cast::make(kFloat, x) * ExprHandle(3.f) + ExprHandle(4.f) * y);

  std::stringstream ss;
  ss << body;
  ASSERT_EQ(ss.str(), "2.f + (float(x) * 3.f + 4.f * y)");
}

TEST(IRPrinter, FunctionName) {
  int M = 4;
  int N = 20;

  Tensor producer = Compute(
      "producer", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) {
        return m * n;
      });

  Tensor chunk_0 = Compute(
      "chunk_0", {M, N / 2}, [&](const ExprHandle& m, const ExprHandle& n) {
        return producer.load(m, n);
      });

  Tensor chunk_1 = Compute(
      "chunk_1", {M, N / 2}, [&](const ExprHandle& m, const ExprHandle& n) {
        return producer.load(m, n + ExprHandle(N / 2));
      });

  Tensor consumer = Compute(
      "consumer", {M, N / 2}, [&](const ExprHandle& i, const ExprHandle& j) {
        return i * chunk_1.load(i, j);
      });

  LoopNest l({chunk_0, chunk_1, consumer});
  auto body = LoopNest::sanitizeNames(l.root_stmt());

  std::stringstream ss;
  ss << *body;

  const std::string& verification_pattern =
      R"IR(
 # CHECK: for (int i_2
 # CHECK: for (int j_2
 # CHECK: consumer[i_2, j_2] = i_2 * (chunk_1[i_2, j_2])IR";

  torch::jit::testing::FileCheck().run(verification_pattern, ss.str());
}
} // namespace jit
} // namespace torch
191 test/cpp/tensorexpr/test_ir_verifier.cpp Normal file
@ -0,0 +1,191 @@
#include <gtest/gtest.h>

#include <stdexcept>
#include "test/cpp/tensorexpr/test_base.h"

#include <torch/csrc/jit/tensorexpr/expr.h>
#include <torch/csrc/jit/tensorexpr/ir.h>
#include <torch/csrc/jit/tensorexpr/ir_verifier.h>
#include <torch/csrc/jit/tensorexpr/loopnest.h>
#include <torch/csrc/jit/tensorexpr/tensor.h>
#include <torch/csrc/jit/testing/file_check.h>

#include <sstream>
namespace torch {
namespace jit {

using namespace torch::jit::tensorexpr;

TEST(IRVerifier, BitwiseOps) {
  VarPtr X = alloc<Var>("x", kInt);
  VarPtr Y = alloc<Var>("y", kFloat);
  {
    auto a = alloc<And>(X, Y);
    // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks)
    EXPECT_ANY_THROW(verify(a));
  }
  {
    auto a = alloc<Or>(X, Y);
    // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks)
    EXPECT_ANY_THROW(verify(a));
  }
  {
    auto a = alloc<Xor>(X, Y);
    // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks)
    EXPECT_ANY_THROW(verify(a));
  }
  {
    auto a = alloc<Lshift>(X, Y);
    // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks)
    EXPECT_ANY_THROW(verify(a));
  }
  {
    auto a = alloc<Rshift>(X, Y);
    // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks)
    EXPECT_ANY_THROW(verify(a));
  }
}

TEST(IRVerifier, CompareSelect) {
  ExprPtr X = alloc<IntImm>(1);
  ExprPtr Y = alloc<FloatImm>(3.14f);
  {
    auto a = alloc<CompareSelect>(X, X, X, Y, kEQ);
    // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks)
    EXPECT_ANY_THROW(verify(a));
  }
  {
    auto a = alloc<CompareSelect>(X, Y, X, X, kEQ);
    // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks)
    EXPECT_ANY_THROW(verify(a));
  }
}

TEST(IRVerifier, Ramp) {
  VarPtr I = alloc<Var>("i", kInt);
  VarPtr J = alloc<Var>("j", kFloat);
  {
    auto a = alloc<Ramp>(I, J, 4);
    // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks)
    EXPECT_ANY_THROW(verify(a));
  }
}

TEST(IRVerifier, Load) {
  VarPtr I = alloc<Var>("i", kInt);
  VarPtr J = alloc<Var>("j", kLong);
  VarPtr K = alloc<Var>("k", kFloat);
  BufPtr B = alloc<Buf>(
      "b",
      std::vector<ExprPtr>({alloc<IntImm>(10), alloc<IntImm>(20)}),
      kFloat);
  {
    // Indices with different int dtypes (kInt, kLong) are ok
    auto a = alloc<Load>(B, std::vector<ExprPtr>({I, J}));
    // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks)
    EXPECT_NO_THROW(verify(a));
  }
  {
    // Float index
    auto a = alloc<Load>(B, std::vector<ExprPtr>({K, K}));
    // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks)
    EXPECT_ANY_THROW(verify(a));
  }
  {
    // Multilanes are only allowed in flattened indices
    auto multilane_index = alloc<Ramp>(I, alloc<IntImm>(1), 4);
    auto a = alloc<Load>(B, std::vector<ExprPtr>({I, multilane_index}));
    // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks)
    EXPECT_ANY_THROW(verify(a));
  }
}

TEST(IRVerifier, IfThenElse) {
  VarPtr I = alloc<Var>("i", kInt);
  VarPtr J = alloc<Var>("j", kLong);
  VarPtr K = alloc<Var>("k", kFloat);
  {
    // Condition must be integral
    auto a = alloc<IfThenElse>(K, I, I);
    // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks)
    EXPECT_ANY_THROW(verify(a));
  }
  {
    // Dtypes of true and false exprs must match
    auto a = alloc<IfThenElse>(I, I, J);
    // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks)
    EXPECT_ANY_THROW(verify(a));
  }
  {
    // Can't have multiple lanes in condition expr
    auto a = alloc<IfThenElse>(alloc<Broadcast>(I, 4), I, I);
    // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks)
    EXPECT_ANY_THROW(verify(a));
  }
}

TEST(IRVerifier, For) {
  VarPtr I = alloc<Var>("i", kInt);
  VarPtr J = alloc<Var>("j", kInt);
  StmtPtr body = alloc<Block>(std::vector<StmtPtr>({}));
  {
    // Can't have nullptr as a Var
    auto a = alloc<For>(nullptr, I, J, body);
    // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
    EXPECT_ANY_THROW(verify(a));
  }
}

TEST(IRVerifier, Block) {
  VarPtr I = alloc<Var>("i", kInt);
  BufPtr B = alloc<Buf>("B", std::vector<ExprPtr>({alloc<IntImm>(10)}), kInt);
  {
    StmtPtr store = alloc<Store>(B, std::vector<ExprPtr>({I}), I);
    // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores)
    StmtPtr block1 = alloc<Block>(std::vector<StmtPtr>({store}));
    // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
    StmtPtr block2 = alloc<Block>(std::vector<StmtPtr>({store}));
    // Stmt can't have multiple parents, thus inserting it into several blocks
    // is illegal
    // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks)
    EXPECT_ANY_THROW(verify(block2));
  }
}

TEST(IRVerifier, Store) {
  VarPtr I = alloc<Var>("i", kInt);
  VarPtr J = alloc<Var>("j", kLong);
  VarPtr K = alloc<Var>("k", kFloat);
  BufPtr B = alloc<Buf>(
      "b",
      std::vector<ExprPtr>({alloc<IntImm>(10), alloc<IntImm>(20)}),
      kFloat);
  {
    // Indices with different int dtypes (kInt, kLong) are ok
    auto a = alloc<Store>(B, std::vector<ExprPtr>({I, J}), K);
    // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks)
    EXPECT_NO_THROW(verify(a));
  }
  {
    // Float index
    auto a = alloc<Store>(B, std::vector<ExprPtr>({K, K}), K);
    // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks)
    EXPECT_ANY_THROW(verify(a));
  }
  {
    // Multilanes are only allowed in flattened indices
    auto multilane_index = alloc<Ramp>(I, alloc<IntImm>(1), 4);
    auto a = alloc<Store>(B, std::vector<ExprPtr>({I, multilane_index}), K);
    // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks)
    EXPECT_ANY_THROW(verify(a));
  }
  {
    // Value and buf dtypes mismatch
    auto a = alloc<Store>(B, std::vector<ExprPtr>({I}), I);
    // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks)
    EXPECT_ANY_THROW(verify(a));
  }
}

} // namespace jit
} // namespace torch
2133 test/cpp/tensorexpr/test_kernel.cpp Normal file
File diff suppressed because it is too large
1799 test/cpp/tensorexpr/test_llvm.cpp Normal file
File diff suppressed because it is too large
6894 test/cpp/tensorexpr/test_loopnest.cpp Normal file
File diff suppressed because it is too large
3252 test/cpp/tensorexpr/test_memdependency.cpp Normal file
File diff suppressed because it is too large
708 test/cpp/tensorexpr/test_memplanning.cpp Normal file
@ -0,0 +1,708 @@
#include <gtest/gtest.h>
#include <test/cpp/tensorexpr/test_base.h>

#include <c10/util/irange.h>
#include <test/cpp/tensorexpr/padded_buffer.h>
#include <torch/csrc/jit/tensorexpr/ir.h>
#include <torch/csrc/jit/tensorexpr/ir_printer.h>
#include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
#include <torch/csrc/jit/tensorexpr/llvm_codegen.h>
#include <torch/csrc/jit/tensorexpr/loopnest.h>
#include <torch/csrc/jit/tensorexpr/tensor.h>

namespace torch {
namespace jit {

using namespace torch::jit::tensorexpr;

extern void checkIR(StmtPtr s, const std::string& pattern);

TEST(BufLiveRange, SingleRangeLine) {
  VarHandle i("i", kInt), j("j", kInt);
  BufHandle a("a", {32}, kFloat);
  BufHandle b("b", {32, 32}, kFloat);

  // Construct Stmt:
  // {
  //   for (int i = 0; i < 32; i++) {
  //     a[i] = 0;
  //     for (int j = 0; j < 32; j++) {
  //       a[i] = (a[i]) + (b[i, j]);
  //     }
  //   }
  // }

  StorePtr aInit = Store::make(a, {i}, 0);
  ExprHandle reduce = a.load({i}) + b.load({i, j});
  StorePtr aReduce = Store::make(a, {i}, reduce);
  StmtPtr loop =
      For::make(i, 0, 32, Block::make({aInit, For::make(j, 0, 32, aReduce)}));

  StmtPtr stmt = Block::make({loop});

  auto range = BufLiveRange::liveRange(stmt, a.node());
  ASSERT_TRUE(std::get<0>(range) == 0);
  ASSERT_TRUE(std::get<1>(range) == 0);
}

TEST(BufLiveRange, MulRangeLine) {
  VarHandle i("i", kInt);
  BufHandle a("a", {32}, kFloat);
  BufHandle b("b", {32}, kFloat);

  // Construct Stmt:
  // {
  //   for (int i = 0; i < 32; i++) {
  //     if (i<10 ? 1 : 0) {
  //       a[i] = i + i;
  //       b[i] = i * i;
  //     }
  //   }
  //   for (int i = 0; i < 32; i++) {
  //     if (i>10 ? 1 : 0) {
  //       a[i] = i * i;
  //       b[i] = i + i;
  //     }
  //   }
  // }

  StorePtr aStore_1 = Store::make(a, {i}, i + i);
  StorePtr bStore_1 = Store::make(b, {i}, i * i);
  StmtPtr loop_1 = For::make(
      i, 0, 32, Cond::make(i < 10, Block::make({aStore_1, bStore_1}), NULL));

  StorePtr aStore_2 = Store::make(a, {i}, i * i);
  StorePtr bStore_2 = Store::make(b, {i}, i + i);
  StmtPtr loop_2 = For::make(
      i, 0, 32, Cond::make(i > 10, Block::make({aStore_2, bStore_2}), NULL));

  StmtPtr stmt = Block::make({loop_1, loop_2});

  auto range_a = BufLiveRange::liveRange(stmt, a.node());
  ASSERT_TRUE(std::get<0>(range_a) == 0);
  ASSERT_TRUE(std::get<1>(range_a) == 1);

  auto range_b = BufLiveRange::liveRange(stmt, b.node());
  ASSERT_TRUE(std::get<0>(range_b) == 0);
  ASSERT_TRUE(std::get<1>(range_b) == 1);
}

TEST(MemPlanning, MemReuseWithTypeCast) {
  int M = 4;
  int N = 4;
  int K = 4;

  BufHandle AP("A", {M, K}, kFloat);
  BufHandle BP("B", {K, N}, kFloat);

  Tensor CT = Reduce(
      "gemm",
      {M, N},
      Sum(),
      [&](const ExprHandle& m, const ExprHandle& n, const ExprHandle& k) {
        return AP.load(m, k) * BP.load(k, n);
      },
      {K});
  Tensor DT =
      Compute("relu", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) {
        return CompareSelect::make(
            CT.load(m, n), 0.0f, 0.0f, CT.load(m, n), kLT);
      });
  Tensor ET =
      Compute("E", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) {
        return Cast::make(kQUInt8, DT.load(m, n) + DT.load(m, n));
      });
  Tensor FT =
      Compute("F", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) {
        return ET.load(m, n);
      });
  StmtPtr stmt =
      tensorexpr::Block::make({CT.stmt(), DT.stmt(), ET.stmt(), FT.stmt()});

  // Constructed stmt:
  // Intermediate buffers and their liveness ranges: gemm [0, 1], relu [1, 2],
  // E [2, 3]. The dimensions of 'gemm' and 'E' are the same but their types
  // are different: 'E' type quint8 < 'gemm' type float. We'll reuse 'gemm' for
  // 'E' with typecasting.
  //{
  //  for (int i = 0; i < 4; i++) {
  //    for (int i_1 = 0; i_1 < 4; i_1++) {
  //      gemm[i, i_1] = float(0);
  //      for (int i_2 = 0; i_2 < 4; i_2++) {
  //        gemm[i, i_1] = ReduceOp((gemm[i, i_1]) + (A[i, i_2]) * (B[i_2,
  //        i_1]), reduce_args={i_2});
  //      }
  //    }
  //  }
  //  for (int i_3 = 0; i_3 < 4; i_3++) {
  //    for (int i_4 = 0; i_4 < 4; i_4++) {
  //      relu[i_3, i_4] = (gemm[i_3, i_4])<0.f ? 0.f : (gemm[i_3, i_4]);
  //    }
  //  }
  //  for (int i_5 = 0; i_5 < 4; i_5++) {
  //    for (int i_6 = 0; i_6 < 4; i_6++) {
  //      E[i_5, i_6] = quint8((relu[i_5, i_6]) + (relu[i_5, i_6]));
  //    }
  //  }
  //  for (int i_7 = 0; i_7 < 4; i_7++) {
  //    for (int i_8 = 0; i_8 < 4; i_8++) {
  //      F[i_7, i_8] = E[i_7, i_8];
  //    }
  //  }
  //}

  LoopNest l(stmt, {FT.buf()});
  l.prepareForCodegen();
  SimpleIREvaluator cg(Stmt::clone(l.root_stmt()), {AP, BP, FT});

  checkIR(cg.stmt(), R"IR(
# CHECK: Allocate(gemm); // dtype=float, dims=[4, 4]
# CHECK: Allocate(relu); // dtype=float, dims=[4, 4]
# CHECK: Alias(E,gemm);
# CHECK: Free(relu);
# CHECK: Free(gemm))IR");

  PaddedBuffer<float> a_v(M, K, "a");
  PaddedBuffer<float> b_v(K, N, "b");
  PaddedBuffer<uint8_t> o1(M, N, "e_before");
  PaddedBuffer<uint8_t> o2(M, N, "e_after");

  for (const auto m : c10::irange(M)) {
    for (const auto k : c10::irange(K)) {
      a_v(m, k) = at::randn({1}).item().to<float>();
    }
  }

  for (const auto k : c10::irange(K)) {
    for (const auto n : c10::irange(N)) {
      b_v(k, n) = at::randn({1}).item().to<float>();
    }
  }

  cg.call({a_v, b_v, o1});

#ifdef TORCH_ENABLE_LLVM
  LLVMCodeGen cg_llvm(Stmt::clone(l.root_stmt()), {AP, BP, FT});

  checkIR(cg_llvm.stmt(), R"IR(
# CHECK: Allocate(gemm); // dtype=float, dims=[4, 4]
# CHECK: Allocate(relu); // dtype=float, dims=[4, 4]
# CHECK: Alias(E,gemm);
# CHECK: Free(relu);
# CHECK: Free(gemm))IR");

  cg_llvm.call({a_v, b_v, o2});

  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  ExpectAllNear(o1, o2, 1e-5);
#endif
}

TEST(MemPlanning, NoMemReuseForLargerType) {
  int M = 4;
  int N = 4;
  int K = 4;

  BufHandle AP("A", {M, K}, kShort);
  BufHandle BP("B", {K, N}, kShort);

  Tensor CT = Reduce(
      "gemm",
      {M, N},
      Sum(),
      [&](const ExprHandle& m, const ExprHandle& n, const ExprHandle& k) {
        return AP.load(m, k) * BP.load(k, n);
      },
      {K});
  auto zero = Cast::make(CT.buf()->dtype(), 0);
  Tensor DT =
      Compute("relu", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) {
        return CompareSelect::make(
            CT.load(m, n), zero, zero, CT.load(m, n), kLT);
      });
  Tensor ET =
      Compute("E", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) {
        return Cast::make(kFloat, DT.load(m, n) + DT.load(m, n));
      });
  Tensor FT =
      Compute("F", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) {
        return ET.load(m, n);
      });
  StmtPtr stmt =
      tensorexpr::Block::make({CT.stmt(), DT.stmt(), ET.stmt(), FT.stmt()});

  // Constructed stmt:
  // Intermediate buffers and their liveness ranges: gemm [0, 1], relu [1, 2],
  // E [2, 3]. The dimensions of 'gemm' and 'E' are the same but their types
  // are different: 'E' type float > 'gemm' type int16. We won't reuse 'gemm'
  // for 'E'.
  //{
  //  for (int i = 0; i < 4; i++) {
  //    for (int i_1 = 0; i_1 < 4; i_1++) {
  //      gemm[i, i_1] = int16_t(0);
  //      for (int i_2 = 0; i_2 < 4; i_2++) {
  //        gemm[i, i_1] = ReduceOp((gemm[i, i_1]) + (A[i, i_2]) * (B[i_2,
  //        i_1]), reduce_args={i_2});
  //      }
  //    }
  //  }
  //  for (int i_3 = 0; i_3 < 4; i_3++) {
  //    for (int i_4 = 0; i_4 < 4; i_4++) {
  //      relu[i_3, i_4] = (gemm[i_3, i_4])<int16_t(0) ? int16_t(0) : (gemm[i_3,
  //      i_4]);
  //    }
  //  }
  //  for (int i_5 = 0; i_5 < 4; i_5++) {
  //    for (int i_6 = 0; i_6 < 4; i_6++) {
  //      E[i_5, i_6] = float((relu[i_5, i_6]) + (relu[i_5, i_6]));
  //    }
  //  }
  //  for (int i_7 = 0; i_7 < 4; i_7++) {
  //    for (int i_8 = 0; i_8 < 4; i_8++) {
  //      F[i_7, i_8] = E[i_7, i_8];
  //    }
  //  }
  //}

  LoopNest l(stmt, {FT.buf()});
  l.prepareForCodegen();
  SimpleIREvaluator cg(Stmt::clone(l.root_stmt()), {AP, BP, FT.buf()});

  checkIR(cg.stmt(), R"IR(
# CHECK: Allocate(gemm); // dtype=int16_t, dims=[4, 4]
# CHECK: Allocate(relu); // dtype=int16_t, dims=[4, 4]
# CHECK: Allocate(E); // dtype=float, dims=[4, 4]
# CHECK: Free(E);
# CHECK: Free(relu);
# CHECK: Free(gemm))IR");

  PaddedBuffer<short> a_v(M, K, "a");
  PaddedBuffer<short> b_v(K, N, "b");
  PaddedBuffer<float> o1(M, N, "e_before");
  PaddedBuffer<float> o2(M, N, "e_after");

  for (const auto m : c10::irange(M)) {
    for (const auto k : c10::irange(K)) {
      a_v(m, k) = at::randn({1}).item().to<float>();
    }
  }

  for (const auto k : c10::irange(K)) {
    for (const auto n : c10::irange(N)) {
      b_v(k, n) = at::randn({1}).item().to<float>();
    }
  }

  cg.call({a_v, b_v, o1});

#ifdef TORCH_ENABLE_LLVM
  LLVMCodeGen cg_llvm(Stmt::clone(l.root_stmt()), {AP, BP, FT});

  checkIR(cg_llvm.stmt(), R"IR(
# CHECK: Allocate(gemm); // dtype=int16_t, dims=[4, 4]
# CHECK: Allocate(relu); // dtype=int16_t, dims=[4, 4]
# CHECK: Allocate(E); // dtype=float, dims=[4, 4]
# CHECK: Free(E);
# CHECK: Free(relu);
# CHECK: Free(gemm))IR");

  cg_llvm.call({a_v, b_v, o2});

  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  ExpectAllNear(o1, o2, 1e-5);
#endif
}

TEST(MemPlanning, SameBufSizeMemReuse) {
  int M = 1024;
  int N = 1024;
  int K = 2048;

  BufHandle AP("A", {M, K}, kFloat);
  BufHandle BP("B", {K, N}, kFloat);

  Tensor CT = Reduce(
      "gemm",
      {M, N},
      Sum(),
      [&](const ExprHandle& m, const ExprHandle& n, const ExprHandle& k) {
        return AP.load(m, k) * BP.load(k, n);
      },
      {K});
  Tensor DT =
      Compute("relu", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) {
        auto zero = Cast::make(CT.buf()->dtype(), 0);
        return CompareSelect::make(
            CT.load(m, n), zero, zero, CT.load(m, n), kLT);
      });
  Tensor ET =
      Compute("add", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) {
        return DT.load(m, n) + DT.load(m, n);
      });
  Tensor FT =
      Compute("mul", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) {
        return ET.load(m, n) * ET.load(m, n);
      });
  auto stmt = Block::make({CT.stmt(), DT.stmt(), ET.stmt(), FT.stmt()});

  // Constructed stmt:
  // Intermediate buffers and their liveness ranges: gemm [0, 1], relu [1, 2],
  // add [2, 3]. Buffers 'gemm' and 'add' are the same size; we'll reuse 'gemm'
  // for 'add'.
  //{
  //  for (int M = 0; M < 1024; M++) {
  //    for (int N = 0; N < 1024; N++) {
  //      gemm[M, N] = float(0);
  //      for (int K = 0; K < 2048; K++) {
  //        gemm[M, N] = ReduceOp((gemm[M, N]) + (A[M, K]) * (B[K, N]),
  //        reduce_args={K});
  //      }
  //    }
  //  }
  //  for (int M_1 = 0; M_1 < 1024; M_1++) {
  //    for (int N_1 = 0; N_1 < 1024; N_1++) {
  //      relu[M_1, N_1] = (gemm[M_1, N_1])<float(0) ? float(0) : (gemm[M_1,
  //      N_1]);
  //    }
  //  }
  //  for (int M_2 = 0; M_2 < 1024; M_2++) {
  //    for (int N_2 = 0; N_2 < 1024; N_2++) {
  //      add[M_2, N_2] = (relu[M_2, N_2]) + (relu[M_2, N_2]);
  //    }
  //  }
  //  for (int M_3 = 0; M_3 < 1024; M_3++) {
  //    for (int N_3 = 0; N_3 < 1024; N_3++) {
  //      mul[M_3, N_3] = (add[M_3, N_3]) * (add[M_3, N_3]);
  //    }
  //  }
  //}

  SimpleIREvaluator cg(stmt, {AP, BP, FT});

  checkIR(cg.stmt(), R"IR(
# CHECK: Allocate(gemm); // dtype=float, dims=[1024, 1024]
# CHECK: Allocate(relu); // dtype=float, dims=[1024, 1024]
# CHECK: Alias(add,gemm);
# CHECK: Free(relu);
# CHECK: Free(gemm))IR");

#ifdef TORCH_ENABLE_LLVM
  LoopNest loop(Stmt::clone(stmt), {FT.buf()});
  loop.prepareForCodegen();
  LLVMCodeGen cg_llvm(loop.root_stmt(), {AP, BP, FT});

  checkIR(cg_llvm.stmt(), R"IR(
# CHECK: Allocate(gemm); // dtype=float, dims=[1024, 1024]
# CHECK: Allocate(relu); // dtype=float, dims=[1024, 1024]
# CHECK: Alias(add,gemm);
# CHECK: Free(relu);
# CHECK: Free(gemm))IR");
#endif
}

TEST(MemPlanning, SameBufSizeMultiMemReuses) {
  int M = 1024;
  int N = 1024;
  int K = 2048;

  BufHandle AP("A", {M, K}, kFloat);
  BufHandle BP("B", {K, N}, kFloat);

  Tensor CT = Reduce(
      "gemm",
      {M, N},
      Sum(),
      [&](const ExprHandle& m, const ExprHandle& n, const ExprHandle& k) {
        return AP.load(m, k) * BP.load(k, n);
      },
      {K});
  Tensor DT =
      Compute("relu", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) {
        auto zero = Cast::make(CT.buf()->dtype(), 0);
        return CompareSelect::make(
            CT.load(m, n), zero, zero, CT.load(m, n), kLT);
      });
  Tensor ET =
      Compute("add", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) {
        return DT.load(m, n) + DT.load(m, n);
      });
  Tensor FT =
      Compute("mul", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) {
        return ET.load(m, n) * ET.load(m, n);
      });
  Tensor GT =
      Compute("sub", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) {
        return FT.load(m, n) - ET.load(m, n);
      });

  auto stmt =
      Block::make({CT.stmt(), DT.stmt(), ET.stmt(), FT.stmt(), GT.stmt()});

  // Constructed stmt:
  // Intermediate buffers and their liveness ranges: gemm [0, 1], relu [1, 2],
  // add [2, 3], mul [3, 4]. Buffers 'gemm', 'relu', 'add' and 'mul' are the
  // same size; we'll reuse 'gemm' for 'add', and reuse 'relu' for 'mul'.
  //{
  //  for (int M = 0; M < 1024; M++) {
  //    for (int N = 0; N < 1024; N++) {
  //      gemm[M, N] = float(0);
  //      for (int K = 0; K < 2048; K++) {
  //        gemm[M, N] = ReduceOp((gemm[M, N]) + (A[M, K]) * (B[K, N]),
  //        reduce_args={K});
  //      }
  //    }
  //  }
  //  for (int M_1 = 0; M_1 < 1024; M_1++) {
  //    for (int N_1 = 0; N_1 < 1024; N_1++) {
  //      relu[M_1, N_1] = (gemm[M_1, N_1])<float(0) ? float(0) : (gemm[M_1,
  //      N_1]);
  //    }
  //  }
  //  for (int M_2 = 0; M_2 < 1024; M_2++) {
  //    for (int N_2 = 0; N_2 < 1024; N_2++) {
  //      add[M_2, N_2] = (relu[M_2, N_2]) + (relu[M_2, N_2]);
  //    }
  //  }
  //  for (int M_3 = 0; M_3 < 1024; M_3++) {
  //    for (int N_3 = 0; N_3 < 1024; N_3++) {
  //      mul[M_3, N_3] = (add[M_3, N_3]) * (add[M_3, N_3]);
  //    }
  //  }
  //  for (int M_4 = 0; M_4 < 1024; M_4++) {
  //    for (int N_4 = 0; N_4 < 1024; N_4++) {
  //      sub[M_4, N_4] = (mul[M_4, N_4]) - (add[M_4, N_4]);
  //    }
  //  }
  //}

  SimpleIREvaluator cg(stmt, {AP, BP, GT});

  checkIR(cg.stmt(), R"IR(
# CHECK: Allocate(gemm); // dtype=float, dims=[1024, 1024]
# CHECK: Allocate(relu); // dtype=float, dims=[1024, 1024]
# CHECK: Alias(add,gemm);
# CHECK: Alias(mul,relu);
# CHECK: Free(relu);
# CHECK: Free(gemm))IR");

#ifdef TORCH_ENABLE_LLVM
  LoopNest loop(Stmt::clone(stmt), {FT.buf()});
  loop.prepareForCodegen();
  LLVMCodeGen cg_llvm(loop.root_stmt(), {AP, BP, FT});

  checkIR(cg_llvm.stmt(), R"IR(
# CHECK: Allocate(gemm); // dtype=float, dims=[1024, 1024]
# CHECK: Allocate(relu); // dtype=float, dims=[1024, 1024]
# CHECK: Alias(add,gemm);
# CHECK: Alias(mul,relu);
# CHECK: Free(relu);
# CHECK: Free(gemm))IR");
#endif
}

TEST(MemPlanning, SameBufSizeMultiMemReusesOfOneBuf) {
  int M = 1024;
  int N = 1024;
  int K = 2048;

  BufHandle AP("A", {M, K}, kFloat);
  BufHandle BP("B", {K, N}, kFloat);

  Tensor CT = Reduce(
      "gemm",
      {M, N},
      Sum(),
      [&](const ExprHandle& m, const ExprHandle& n, const ExprHandle& k) {
        return AP.load(m, k) * BP.load(k, n);
      },
      {K});
  Tensor DT =
      Compute("relu", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) {
        auto zero = Cast::make(CT.buf()->dtype(), 0);
        return CompareSelect::make(
            CT.load(m, n), zero, zero, CT.load(m, n), kLT);
      });
  Tensor ET =
      Compute("add", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) {
        return DT.load(m, n) + DT.load(m, n);
      });
  Tensor FT =
      Compute("mul", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) {
        return ET.load(m, n) * ET.load(m, n);
      });
  Tensor GT =
      Compute("sub", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) {
        return FT.load(m, n) - 1;
      });
  Tensor HT =
      Compute("div", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) {
        return GT.load(m, n) / 2;
      });

  auto stmt = Block::make(
      {CT.stmt(), DT.stmt(), ET.stmt(), FT.stmt(), GT.stmt(), HT.stmt()});

  // Constructed stmt:
  // Intermediate buffers and their liveness ranges: gemm [0, 1], relu [1, 2],
  // add [2, 3], mul [3, 4], sub [4, 5]. Buffers 'gemm', 'relu', 'add', 'mul'
  // and 'sub' are the same size; we'll reuse 'gemm' for 'add', reuse 'relu'
  // for 'mul', and reuse 'gemm' for 'sub'.
  //{
  //  for (int M = 0; M < 1024; M++) {
  //    for (int N = 0; N < 1024; N++) {
  //      gemm[M, N] = float(0);
  //      for (int K = 0; K < 2048; K++) {
  //        gemm[M, N] = ReduceOp((gemm[M, N]) + (A[M, K]) * (B[K, N]),
  //        reduce_args={K});
  //      }
  //    }
  //  }
  //  for (int M_1 = 0; M_1 < 1024; M_1++) {
  //    for (int N_1 = 0; N_1 < 1024; N_1++) {
  //      relu[M_1, N_1] = (gemm[M_1, N_1])<float(0) ? float(0) : (gemm[M_1,
  //      N_1]);
  //    }
  //  }
  //  for (int M_2 = 0; M_2 < 1024; M_2++) {
  //    for (int N_2 = 0; N_2 < 1024; N_2++) {
  //      add[M_2, N_2] = (relu[M_2, N_2]) + (relu[M_2, N_2]);
  //    }
  //  }
  //  for (int M_3 = 0; M_3 < 1024; M_3++) {
  //    for (int N_3 = 0; N_3 < 1024; N_3++) {
  //      mul[M_3, N_3] = (add[M_3, N_3]) * (add[M_3, N_3]);
  //    }
  //  }
  //  for (int M_4 = 0; M_4 < 1024; M_4++) {
  //    for (int N_4 = 0; N_4 < 1024; N_4++) {
  //      sub[M_4, N_4] = (mul[M_4, N_4]) - float(1);
  //    }
  //  }
  //  for (int M_5 = 0; M_5 < 1024; M_5++) {
  //    for (int N_5 = 0; N_5 < 1024; N_5++) {
  //      div[M_5, N_5] = (sub[M_5, N_5]) / float(2);
  //    }
  //  }
  //}

  SimpleIREvaluator cg(stmt, {AP, BP, HT});

  checkIR(cg.stmt(), R"IR(
# CHECK: Allocate(gemm); // dtype=float, dims=[1024, 1024]
# CHECK: Allocate(relu); // dtype=float, dims=[1024, 1024]
# CHECK: Alias(add,gemm);
# CHECK: Alias(mul,relu);
# CHECK: Alias(sub,gemm);
# CHECK: Free(relu);
# CHECK: Free(gemm))IR");

#ifdef TORCH_ENABLE_LLVM
  LoopNest loop(Stmt::clone(stmt), {FT.buf()});
  loop.prepareForCodegen();
  LLVMCodeGen cg_llvm(loop.root_stmt(), {AP, BP, FT});

  checkIR(cg_llvm.stmt(), R"IR(
# CHECK: Allocate(gemm); // dtype=float, dims=[1024, 1024]
# CHECK: Allocate(relu); // dtype=float, dims=[1024, 1024]
# CHECK: Alias(add,gemm);
# CHECK: Alias(mul,relu);
# CHECK: Alias(sub,gemm);
# CHECK: Free(relu);
# CHECK: Free(gemm))IR");
#endif
}

TEST(MemPlanning, SmallerBufSizeNonMemReuse) {
  int M = 1024;
  int N = 1024;
  int K = 2048;

  BufHandle AP("A", {M, K}, kFloat);
  BufHandle BP("B", {K, N}, kFloat);

  Tensor CT = Reduce(
      "gemm",
      {M, N},
      Sum(),
      [&](const ExprHandle& m, const ExprHandle& n, const ExprHandle& k) {
        return AP.load(m, k) * BP.load(k, n);
      },
      {K});
  Tensor DT =
      Compute("relu", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) {
        auto zero = Cast::make(CT.buf()->dtype(), 0);
        return CompareSelect::make(
            CT.load(m, n), zero, zero, CT.load(m, n), kLT);
      });
  Tensor ET = Compute(
      "add", {M * 2, N * 2}, [&](const ExprHandle& em, const ExprHandle& en) {
        return DT.load(em / 2, en / 2) + DT.load(em / 2, en / 2);
      });
  Tensor FT = Compute(
      "mul", {M * 2, N * 2}, [&](const ExprHandle& fm, const ExprHandle& fn) {
        return ET.load(fm, fn) * ET.load(fm, fn);
      });
  auto stmt = Block::make({CT.stmt(), DT.stmt(), ET.stmt(), FT.stmt()});

  // Constructed stmt:
  // Intermediate buffers and their liveness ranges: gemm [0, 1], relu [1, 2],
  // add [2, 3]. We do not reuse buffer 'gemm' for 'add' because the size of
  // buffer 'gemm' is smaller.
  //{
  //  for (int M = 0; M < 1024; M++) {
  //    for (int N = 0; N < 1024; N++) {
  //      gemm[M, N] = float(0);
  //      for (int K = 0; K < 2048; K++) {
  //        gemm[M, N] = ReduceOp((gemm[M, N]) + (A[M, K]) * (B[K, N]),
  //        reduce_args={K});
  //      }
  //    }
  //  }
  //  for (int M_1 = 0; M_1 < 1024; M_1++) {
  //    for (int N_1 = 0; N_1 < 1024; N_1++) {
  //      relu[M_1, N_1] = (gemm[M_1, N_1])<float(0) ? float(0) : (gemm[M_1,
  //      N_1]);
  //    }
  //  }
  //  for (int EM = 0; EM < 2048; EM++) {
  //    for (int EN = 0; EN < 2048; EN++) {
  //      add[EM, EN] = (relu[EM / 2, EN / 2]) + (relu[EM / 2, EN / 2]);
  //    }
  //  }
  //  for (int FM = 0; FM < 2048; FM++) {
  //    for (int FN = 0; FN < 2048; FN++) {
  //      mul[FM, FN] = (add[FM, FN]) * (add[FM, FN]);
  //    }
  //  }
  //}

  SimpleIREvaluator cg(stmt, {AP, BP, FT});

  checkIR(cg.stmt(), R"IR(
# CHECK: Allocate(gemm); // dtype=float, dims=[1024, 1024]
# CHECK: Allocate(relu); // dtype=float, dims=[1024, 1024]
# CHECK-NOT: Alias(add,gemm);
# CHECK: Allocate(add); // dtype=float, dims=[2048, 2048]
# CHECK: Free(add);
# CHECK: Free(relu);
# CHECK: Free(gemm))IR");

#ifdef TORCH_ENABLE_LLVM
  LoopNest loop(Stmt::clone(stmt), {FT.buf()});
  loop.prepareForCodegen();
  LLVMCodeGen cg_llvm(loop.root_stmt(), {AP, BP, FT});

  checkIR(cg_llvm.stmt(), R"IR(
# CHECK: Allocate(gemm); // dtype=float, dims=[1024, 1024]
# CHECK: Allocate(relu); // dtype=float, dims=[1024, 1024]
# CHECK-NOT: Alias(add,gemm);
# CHECK: Allocate(add); // dtype=float, dims=[2048, 2048]
# CHECK: Free(add);
# CHECK: Free(relu);
# CHECK: Free(gemm))IR");
#endif
}

} // namespace jit
} // namespace torch
78 test/cpp/tensorexpr/test_ops.cpp Normal file
@ -0,0 +1,78 @@
#include <gtest/gtest.h>
#include <torch/csrc/jit/tensorexpr/eval.h>
#include <torch/csrc/jit/tensorexpr/expr.h>
#include <torch/csrc/jit/tensorexpr/loopnest.h>
#include <torch/csrc/jit/tensorexpr/operators/operators.h>
#include <torch/torch.h>

using namespace torch::jit::tensorexpr;

using Tensors = std::vector<Tensor>;
using Args = std::vector<CodeGen::BufferArg>;
std::unique_ptr<SimpleIREvaluator> compile(
    const Args& inputs,
    const Tensors& outputs) {
  LoopNest nest({outputs});
  nest.prepareForCodegen();
  nest.simplify();
  auto join = inputs;
  join.insert(join.end(), outputs.begin(), outputs.end());
  return std::make_unique<SimpleIREvaluator>(nest.root_stmt(), join);
}

TEST(Ops, Sum) {
  constexpr int M = 8;
  constexpr int N = 16;
  std::vector<IntList> testDims = {{0}, {1}, {0, 1}};
  std::vector<std::vector<ExprHandle>> outputShapes = {{N}, {M}, {}};
  for (unsigned idx = 0; idx < testDims.size(); idx++) {
    const auto& dims = testDims[idx];
    const auto& outShape = outputShapes[idx];

    BufHandle a("a", {M, N}, kFloat);
    std::vector<ExprHandle> outStrides =
        c10::fmap<ExprHandle>(make_contiguous_strides(outShape));
    Tensor b = computeSum(
        {a, dims, false}, outShape, outStrides, c10::kFloat, at::kCPU);
    auto cg = compile({a}, {b});

    auto at = at::arange(M * N, at::kFloat).view({M, N});
    auto ref = at::sum(at, dims);
    auto bt = at::empty_like(ref);

    cg->call({at.data_ptr<float>(), bt.data_ptr<float>()});

    ASSERT_TRUE(at::allclose(bt, ref));
  }
}

TEST(Ops, ChannelsLastSum) {
  constexpr int A = 2;
  constexpr int B = 3;
  constexpr int C = 4;
  constexpr int D = 5;
  constexpr int E = 6;
  std::vector<IntList> testDims = {{0}, {1}, {0, 1}};

  std::vector<std::vector<ExprHandle>> outputShapes = {
      {B, C, D, E}, {A, C, D, E}, {C, D, E}};
  for (unsigned idx = 0; idx < testDims.size(); idx++) {
    const auto& dims = testDims[idx];
    const auto& outShape = outputShapes[idx];

    BufHandle a("a", {A, B, C, D, E}, kFloat);
    std::vector<ExprHandle> outStrides =
        c10::fmap<ExprHandle>(make_channels_last_strides(outShape));
    Tensor b = computeSum(
        {a, dims, false}, outShape, outStrides, c10::kFloat, at::kCPU);
    auto cg = compile({a}, {b});

    auto at = at::arange(A * B * C * D * E, at::kFloat).view({A, B, C, D, E});
    auto ref = at::sum(at, dims);
    auto bt = at::empty_like(ref);

    cg->call({at.data_ptr<float>(), bt.data_ptr<float>()});

    ASSERT_TRUE(at::allclose(bt, ref));
  }
}
452 test/cpp/tensorexpr/test_quantization.cpp Normal file
@ -0,0 +1,452 @@
#include <gtest/gtest.h>
|
||||
|
||||
#include <ATen/native/quantized/PackedParams.h>
|
||||
#include <test/cpp/tensorexpr/test_base.h>
|
||||
#include <torch/csrc/jit/ir/ir.h>
|
||||
#include <torch/csrc/jit/ir/irparser.h>
|
||||
#include <torch/csrc/jit/tensorexpr/kernel.h>
|
||||
#include <torch/csrc/jit/tensorexpr/loopnest.h>
|
||||
#include <torch/csrc/jit/tensorexpr/tensor.h>
|
||||
#include <torch/csrc/jit/testing/file_check.h>
|
||||
#include <torch/torch.h>
|
||||
#include <cmath>
|
||||
#include <sstream>
|
||||
#include "torch/csrc/jit/tensorexpr/eval.h"
|
||||
#include "torch/csrc/jit/tensorexpr/ir.h"
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
|
||||
using namespace torch::jit::tensorexpr;
|
||||
using SimpleIRExprEval = ExprEval<SimpleIREvaluator>;
|
||||
using namespace torch::indexing;
|
||||
using namespace torch::jit::tensorexpr;
|
||||
|
||||
class Quantization : public ::testing::Test {
|
||||
public:
|
||||
void SetUp() override {
|
||||
getTEMustUseLLVMOnCPU() = false;
|
||||
}
|
||||
};
|
||||
|
||||
TEST_F(Quantization, QuantDequantInt8) {
|
||||
const auto graph_string = R"IR(
|
||||
graph(%x.1 : Float(2, 2, strides=[2, 1], device=cpu)):
|
||||
%2 : int = prim::Constant[value=12]()
|
||||
%3 : int = prim::Constant[value=13]()
|
||||
%4 : float = prim::Constant[value=0.1]()
|
||||
%q.1 : QInt8(2, 2) = aten::quantize_per_tensor(%x.1, %4, %3, %2)
|
||||
%6 : Float(2, 2) = aten::dequantize(%q.1)
|
||||
return (%6))IR";
|
||||
auto graph = std::make_shared<Graph>();
|
||||
parseIR(graph_string, &*graph);
|
||||
|
||||
auto x = at::rand({2, 2}, TensorOptions(kCPU).dtype(at::kFloat));
|
||||
auto q = at::quantize_per_tensor(x, 0.1f, 13, at::kQInt8);
|
||||
auto y_expected = at::dequantize(q);
|
||||
TensorExprKernel k(graph);
|
||||
std::vector<at::Tensor> inputs = {x};
|
||||
StmtPtr s = k.getCodeGenStmt();
|
||||
|
||||
std::vector<IValue> stack = fmap<IValue>(inputs);
|
||||
k.run(stack);
|
||||
auto y = stack[0].toTensor();
|
||||
bool check = at::allclose(y_expected, y);
|
||||
if (!check) {
|
||||
std::cout << "y_expected:\n" << y_expected << std::endl;
|
||||
std::cout << "y:\n" << y << std::endl;
|
||||
}
|
||||
TORCH_CHECK_EQ(check, 1);
|
||||
}
|
||||
|
||||
TEST_F(Quantization, QuantDequantUInt8) {
|
||||
const auto graph_string = R"IR(
|
||||
graph(%x.1 : Float(2, 2, strides=[2, 1], device=cpu)):
|
||||
%2 : int = prim::Constant[value=13]()
|
||||
%3 : int = prim::Constant[value=122]()
|
||||
%4 : float = prim::Constant[value=0.1]()
|
||||
%q.1 : QUInt8(2, 2) = aten::quantize_per_tensor(%x.1, %4, %3, %2)
|
||||
%6 : Float(2, 2) = aten::dequantize(%q.1)
|
||||
return (%6))IR";
|
||||
auto graph = std::make_shared<Graph>();
|
||||
parseIR(graph_string, &*graph);
|
||||
|
||||
auto x = 2 * at::rand({2, 2}, TensorOptions(kCPU).dtype(at::kFloat));
|
||||
auto q = at::quantize_per_tensor(x, 0.1f, 122, at::kQUInt8);
|
||||
auto y_expected = at::dequantize(q);
|
||||
TensorExprKernel k(graph);
|
||||
std::vector<at::Tensor> inputs = {x};
|
||||
StmtPtr s = k.getCodeGenStmt();
|
||||
|
||||
std::vector<IValue> stack = fmap<IValue>(inputs);
|
||||
k.run(stack);
|
||||
auto y = stack[0].toTensor();
|
||||
bool check = at::allclose(y_expected, y);
|
||||
if (!check) {
|
||||
std::cout << "y_expected:\n" << y_expected << std::endl;
|
||||
std::cout << "y:\n" << y << std::endl;
|
||||
}
|
||||
TORCH_CHECK_EQ(check, 1);
|
||||
}
|
||||
|
||||
TEST_F(Quantization, QuantDequantUInt8_NLC) {
|
||||
const auto graph_string = R"IR(
|
||||
graph(%x.1 : Float(1, 2, 2, strides=[4, 1, 2], device=cpu)):
|
||||
%2 : int = prim::Constant[value=13]()
|
||||
%3 : int = prim::Constant[value=122]()
|
||||
%4 : float = prim::Constant[value=0.1]()
|
||||
%q.1 : QUInt8(1, 2, 2) = aten::quantize_per_tensor(%x.1, %4, %3, %2)
|
||||
%6 : Float(1, 2, 2) = aten::dequantize(%q.1)
|
||||
return (%6))IR";
|
||||
auto graph = std::make_shared<Graph>();
|
||||
parseIR(graph_string, &*graph);
|
||||
|
||||
auto x = 2 * at::rand({1, 2, 2}, TensorOptions(kCPU).dtype(at::kFloat));
|
||||
x.unsafeGetTensorImpl()->set_sizes_and_strides(
|
||||
std::initializer_list<int64_t>{1, 2, 2}, {4, 1, 2});
|
||||
auto q = at::quantize_per_tensor(x, 0.1f, 122, at::kQUInt8);
|
||||
auto y_expected = at::dequantize(q);
|
||||
TensorExprKernel k(graph);
|
||||
std::vector<at::Tensor> inputs = {x};
|
||||
StmtPtr s = k.getCodeGenStmt();
|
||||
|
||||
std::vector<IValue> stack = fmap<IValue>(inputs);
|
||||
k.run(stack);
|
||||
auto y = stack[0].toTensor();
|
||||
bool check = at::allclose(y_expected, y);
|
||||
if (!check) {
|
||||
std::cout << "x:\n" << x << std::endl;
|
||||
std::cout << "y_expected:\n" << y_expected << std::endl;
|
||||
std::cout << "y:\n" << y << std::endl;
|
||||
}
|
||||
TORCH_CHECK_EQ(check, 1);
|
||||
}
|
||||
|
||||
at::Tensor quantized_add(
|
||||
at::Tensor x1,
|
||||
at::Tensor x2,
|
||||
double scale,
|
||||
int64_t zero) {
|
||||
const auto qadd_op =
|
||||
c10::Dispatcher::singleton()
|
||||
.findSchemaOrThrow("quantized::add", "")
|
||||
.typed<at::Tensor(at::Tensor, at::Tensor, double, int64_t)>();
|
||||
return qadd_op.call(x1, x2, scale, zero);
|
||||
}
|
||||
|
||||
TEST_F(Quantization, QuantAddDequantInt8) {
|
||||
const auto graph_string = R"IR(
|
||||
graph(%x1 : Float(2, 2, strides=[2, 1], device=cpu), %x2 : Float(2, 2, strides=[2, 1], device=cpu)):
|
||||
%2 : int = prim::Constant[value=12]()
|
||||
%qz1 : int = prim::Constant[value=13]()
|
||||
%qs1 : float = prim::Constant[value=0.1]()
|
||||
%qz2 : int = prim::Constant[value=13]()
|
||||
%qs2 : float = prim::Constant[value=0.1]()
|
||||
%qza : int = prim::Constant[value=13]()
|
||||
%qsa : float = prim::Constant[value=0.1]()
|
||||
%q1 : QInt8(2, 2) = aten::quantize_per_tensor(%x1, %qs1, %qz1, %2)
|
||||
%q2 : QInt8(2, 2) = aten::quantize_per_tensor(%x2, %qs2, %qz2, %2)
|
||||
%qa : QInt8(2, 2) = quantized::add(%q1, %q2, %qsa, %qza)
|
||||
%6 : Float(2, 2) = aten::dequantize(%qa)
|
||||
return (%6))IR";
|
||||
auto graph = std::make_shared<Graph>();
|
||||
parseIR(graph_string, &*graph);
|
||||
|
||||
auto x1 = at::rand({2, 2}, TensorOptions(kCPU).dtype(at::kFloat));
|
||||
auto x2 = at::rand({2, 2}, TensorOptions(kCPU).dtype(at::kFloat));
|
||||
auto q1 = at::quantize_per_tensor(x1, 0.1f, 13, at::kQInt8);
|
||||
auto q2 = at::quantize_per_tensor(x2, 0.1f, 13, at::kQInt8);
|
||||
auto qa = quantized_add(q1, q2, 0.1f, 13);
|
||||
auto y_expected = at::dequantize(qa);
|
||||
TensorExprKernel k(graph);
|
||||
std::vector<at::Tensor> inputs = {x1, x2};
|
||||
StmtPtr s = k.getCodeGenStmt();
|
||||
|
||||
std::vector<IValue> stack = fmap<IValue>(inputs);
|
||||
k.run(stack);
|
||||
auto y = stack[0].toTensor();
|
||||
bool check = at::allclose(y_expected, y);
|
||||
if (!check) {
|
||||
std::cout << "x1:\n" << x1 << std::endl;
|
||||
std::cout << "q1:\n" << q1 << std::endl;
|
||||
std::cout << "x2:\n" << x2 << std::endl;
|
||||
std::cout << "q2:\n" << q2 << std::endl;
|
||||
std::cout << "y_expected:\n" << y_expected << std::endl;
|
||||
std::cout << "y:\n" << y << std::endl;
|
||||
}
|
||||
TORCH_CHECK_EQ(check, 1);
|
||||
}
|
||||
|
||||
TEST_F(Quantization, QuantAddDequantUInt8) {
|
||||
const auto graph_string = R"IR(
|
||||
graph(%x1 : Float(2, 2, strides=[2, 1], device=cpu), %x2 : Float(2, 2, strides=[2, 1], device=cpu)):
|
||||
%2 : int = prim::Constant[value=13]()
|
||||
%qz1 : int = prim::Constant[value=13]()
|
||||
%qs1 : float = prim::Constant[value=0.1]()
|
||||
%qz2 : int = prim::Constant[value=13]()
|
||||
%qs2 : float = prim::Constant[value=0.1]()
|
||||
%qza : int = prim::Constant[value=13]()
|
||||
%qsa : float = prim::Constant[value=0.1]()
|
||||
%q1 : QUInt8(2, 2) = aten::quantize_per_tensor(%x1, %qs1, %qz1, %2)
|
||||
%q2 : QUInt8(2, 2) = aten::quantize_per_tensor(%x2, %qs2, %qz2, %2)
|
||||
%qa : QUInt8(2, 2) = quantized::add(%q1, %q2, %qsa, %qza)
|
||||
%6 : Float(2, 2) = aten::dequantize(%qa)
|
||||
return (%6))IR";
|
||||
auto graph = std::make_shared<Graph>();
|
||||
parseIR(graph_string, &*graph);
|
||||
|
||||
auto x1 = at::rand({2, 2}, TensorOptions(kCPU).dtype(at::kFloat));
|
||||
auto x2 = at::rand({2, 2}, TensorOptions(kCPU).dtype(at::kFloat));
|
||||
auto q1 = at::quantize_per_tensor(x1, 0.1f, 13, at::kQUInt8);
|
||||
auto q2 = at::quantize_per_tensor(x2, 0.1f, 13, at::kQUInt8);
|
||||
auto qa = quantized_add(q1, q2, 0.1f, 13);
|
||||
auto y_expected = at::dequantize(qa);
|
||||
|
||||
TensorExprKernel k(graph);
|
||||
std::vector<at::Tensor> inputs = {x1, x2};
|
||||
StmtPtr s = k.getCodeGenStmt();
|
||||
|
||||
std::vector<IValue> stack = fmap<IValue>(inputs);
|
||||
k.run(stack);
|
||||
auto y = stack[0].toTensor();
|
||||
bool check = at::allclose(y_expected, y);
|
||||
if (!check) {
|
||||
std::cout << "x1:\n" << x1 << std::endl;
|
||||
std::cout << "q1:\n" << q1 << std::endl;
|
||||
std::cout << "x2:\n" << x2 << std::endl;
|
||||
std::cout << "q2:\n" << q2 << std::endl;
|
||||
std::cout << "y_expected:\n" << y_expected << std::endl;
|
||||
std::cout << "y:\n" << y << std::endl;
|
||||
}
|
||||
TORCH_CHECK_EQ(check, 1);
|
||||
}
|
||||
|
||||
TEST_F(Quantization, QuantSigmoidDequantUInt8) {
|
||||
const auto graph_string = R"IR(
|
||||
graph(%x1 : Float(2, 2, strides=[2, 1], device=cpu)):
|
||||
%2 : int = prim::Constant[value=13]()
|
||||
%qz1 : int = prim::Constant[value=13]()
|
||||
%qs1 : float = prim::Constant[value=0.1]()
|
||||
%q1 : QUInt8(2, 2) = aten::quantize_per_tensor(%x1, %qs1, %qz1, %2)
|
||||
%qa : QUInt8(2, 2) = aten::sigmoid(%q1)
|
||||
%6 : Float(2, 2) = aten::dequantize(%qa)
|
||||
return (%6))IR";
|
||||
auto graph = std::make_shared<Graph>();
|
||||
parseIR(graph_string, &*graph);
|
||||
|
||||
auto x1 = at::rand({2, 2}, TensorOptions(kCPU).dtype(at::kFloat));
|
||||
auto q1 = at::quantize_per_tensor(x1, 0.1f, 13, at::kQUInt8);
|
||||
auto qs = at::sigmoid(q1);
|
||||
auto y_expected = at::dequantize(qs);
|
||||
|
||||
TensorExprKernel k(graph);
|
||||
std::vector<at::Tensor> inputs = {x1};
|
||||
StmtPtr s = k.getCodeGenStmt();
|
||||
|
||||
std::vector<IValue> stack = fmap<IValue>(inputs);
|
||||
k.run(stack);
|
||||
auto y = stack[0].toTensor();
|
||||
bool check = at::allclose(y_expected, y);
|
||||
if (!check) {
|
||||
std::cout << "x1:\n" << x1 << std::endl;
|
||||
std::cout << "q1:\n" << q1 << std::endl;
|
||||
std::cout << "qs:\n" << qs << std::endl;
|
||||
std::cout << "y_expected:\n" << y_expected << std::endl;
|
||||
std::cout << "y:\n" << y << std::endl;
|
||||
}
|
||||
TORCH_CHECK_EQ(check, 1);
|
||||
}
|
||||
|
||||
at::Tensor quantized_mul(
|
||||
at::Tensor x1,
|
||||
at::Tensor x2,
|
||||
double scale,
|
||||
int64_t zero) {
|
||||
const auto op =
|
||||
c10::Dispatcher::singleton()
|
||||
.findSchemaOrThrow("quantized::mul", "")
|
||||
.typed<at::Tensor(at::Tensor, at::Tensor, double, int64_t)>();
|
||||
return op.call(x1, x2, scale, zero);
|
||||
}
|
||||
|
||||
TEST_F(Quantization, QuantMulDequantUInt8) {
  const auto graph_string = R"IR(
      graph(%x1 : Float(2, 2, strides=[2, 1], device=cpu), %x2 : Float(2, 2, strides=[2, 1], device=cpu)):
        %2 : int = prim::Constant[value=13]()
        %qz1 : int = prim::Constant[value=13]()
        %qs1 : float = prim::Constant[value=0.1]()
        %qz2 : int = prim::Constant[value=13]()
        %qs2 : float = prim::Constant[value=0.1]()
        %qza : int = prim::Constant[value=13]()
        %qsa : float = prim::Constant[value=0.1]()
        %q1 : QUInt8(2, 2) = aten::quantize_per_tensor(%x1, %qs1, %qz1, %2)
        %q2 : QUInt8(2, 2) = aten::quantize_per_tensor(%x2, %qs2, %qz2, %2)
        %qa : QUInt8(2, 2) = quantized::mul(%q1, %q2, %qsa, %qza)
        %6 : Float(2, 2) = aten::dequantize(%qa)
        return (%6))IR";
  auto graph = std::make_shared<Graph>();
  parseIR(graph_string, &*graph);

  auto x1 = at::rand({2, 2}, TensorOptions(kCPU).dtype(at::kFloat));
  auto x2 = at::rand({2, 2}, TensorOptions(kCPU).dtype(at::kFloat));
  auto q1 = at::quantize_per_tensor(x1, 0.1f, 13, at::kQUInt8);
  auto q2 = at::quantize_per_tensor(x2, 0.1f, 13, at::kQUInt8);
  auto qa = quantized_mul(q1, q2, 0.1f, 13);
  auto y_expected = at::dequantize(qa);

  TensorExprKernel k(graph);
  std::vector<at::Tensor> inputs = {x1, x2};
  StmtPtr s = k.getCodeGenStmt();

  std::vector<IValue> stack = fmap<IValue>(inputs);
  k.run(stack);
  auto y = stack[0].toTensor();
  bool check = at::allclose(y_expected, y);
  if (!check) {
    std::cout << "x1:\n" << x1 << std::endl;
    std::cout << "q1:\n" << q1 << std::endl;
    std::cout << "x2:\n" << x2 << std::endl;
    std::cout << "q2:\n" << q2 << std::endl;
    std::cout << "y_expected:\n" << y_expected << std::endl;
    std::cout << "y:\n" << y << std::endl;
  }
  TORCH_CHECK_EQ(check, 1);
}

TEST_F(Quantization, QuantUpsampleNearst2dDequantUInt8) {
  const auto graph_string = R"IR(
      graph(%x : Float(1, 1, 4, 4, strides=[16, 16, 4, 1], device=cpu)):
        %2 : int = prim::Constant[value=13]()
        %4 : NoneType = prim::Constant()
        %3 : int[] = prim::Constant[value=[6, 6]]()
        %qz : int = prim::Constant[value=13]()
        %qs : float = prim::Constant[value=0.1]()
        %q : QUInt8(1, 1, 4, 4) = aten::quantize_per_tensor(%x, %qs, %qz, %2)
        %qu : QUInt8(1, 1, 6, 6) = aten::upsample_nearest2d(%q, %3, %4)
        %6 : Float(1, 1, 6, 6) = aten::dequantize(%qu)
        return (%6))IR";
  auto graph = std::make_shared<Graph>();
  parseIR(graph_string, &*graph);

  auto x = at::rand({1, 1, 4, 4}, TensorOptions(kCPU).dtype(at::kFloat));
  auto q = at::quantize_per_tensor(x, 0.1f, 13, at::kQUInt8);
  auto qu = at::upsample_nearest2d(q, {6, 6});
  auto y_expected = at::dequantize(qu);

  TensorExprKernel k(graph);
  std::vector<at::Tensor> inputs = {x};
  StmtPtr s = k.getCodeGenStmt();

  std::vector<IValue> stack = fmap<IValue>(inputs);
  k.run(stack);
  auto y = stack[0].toTensor();
  bool check = at::allclose(y_expected, y);
  if (!check) {
    std::cout << "x:\n" << x << std::endl;
    std::cout << "q:\n" << q << std::endl;
    std::cout << "qu:\n" << qu << std::endl;
    std::cout << "y_expected:\n" << y_expected << std::endl;
    std::cout << "y:\n" << y << std::endl;
  }
  TORCH_CHECK_EQ(check, 1);
}

TEST_F(Quantization, UpsampleNearst2d) {
  const auto graph_string = R"IR(
      graph(%x : Float(1, 1, 2, 2, strides=[2, 2, 2, 1], device=cpu)):
        %4 : NoneType = prim::Constant()
        %3 : int[] = prim::Constant[value=[4, 4]]()
        %u : Float(1, 1, 4, 4) = aten::upsample_nearest2d(%x, %3, %4)
        return (%u))IR";
  auto graph = std::make_shared<Graph>();
  parseIR(graph_string, &*graph);

  auto x = at::rand({1, 1, 2, 2}, TensorOptions(kCPU).dtype(at::kFloat));
  auto y_expected = at::upsample_nearest2d(x, {4, 4});

  TensorExprKernel k(graph);
  std::vector<at::Tensor> inputs = {x};
  StmtPtr s = k.getCodeGenStmt();

  std::vector<IValue> stack = fmap<IValue>(inputs);
  k.run(stack);
  auto y = stack[0].toTensor();
  bool check = at::allclose(y_expected, y);
  if (!check) {
    std::cout << "x:\n" << x << std::endl;
    std::cout << "y_expected:\n" << y_expected << std::endl;
    std::cout << "y:\n" << y << std::endl;
  }
  TORCH_CHECK_EQ(check, 1);
}

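// Helper that redispatches quantized::cat to the QuantizedCPU backend to
// produce a reference result.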
at::Tensor quantized_cat(
    c10::List<at::Tensor> const& xs,
    int64_t dim,
    double scale,
    int64_t zero) {
  const auto op = c10::Dispatcher::singleton()
                      .findSchemaOrThrow("quantized::cat", "")
                      .typed<at::Tensor(
                          c10::List<at::Tensor> const&,
                          int64_t,
                          std::optional<double>,
                          std::optional<int64_t>)>();
  return op.redispatch(
      DispatchKeySet({DispatchKey::QuantizedCPU}), xs, dim, scale, zero);
}

TEST_F(Quantization, QuantCatDequantUInt8) {
  const auto graph_string = R"IR(
      graph(%x : Float(1, 1, 2, 2, strides=[2, 2, 2, 1], device=cpu), %y : Float(1, 1, 2, 2, strides=[2, 2, 2, 1], device=cpu), %z : Float(1, 1, 2, 2, strides=[2, 2, 2, 1], device=cpu)):
        %qdt : int = prim::Constant[value=13]()
        %qxz : int = prim::Constant[value=13]()
        %qxs : float = prim::Constant[value=0.1]()
        %qyz : int = prim::Constant[value=16]()
        %qys : float = prim::Constant[value=0.15]()
        %qzz : int = prim::Constant[value=19]()
        %qzs : float = prim::Constant[value=0.2]()
        %qx : QUInt8(1, 1, 2, 2) = aten::quantize_per_tensor(%x, %qxs, %qxz, %qdt)
        %qy : QUInt8(1, 1, 2, 2) = aten::quantize_per_tensor(%y, %qys, %qyz, %qdt)
        %qz : QUInt8(1, 1, 2, 2) = aten::quantize_per_tensor(%z, %qzs, %qzz, %qdt)
        %catx : Tensor[] = prim::ListConstruct(%qx, %qy, %qz)
        %catd : int = prim::Constant[value=0]()
        %qcat : QUInt8(3, 1, 2, 2) = quantized::cat(%catx, %catd, %qxs, %qxz)
        %cat : Float(3, 1, 2, 2) = aten::dequantize(%qcat)
        return (%cat))IR";
  auto graph = std::make_shared<Graph>();
  parseIR(graph_string, &*graph);

  auto x = at::rand({1, 1, 2, 2}, TensorOptions(kCPU).dtype(at::kFloat));
  auto y = at::rand({1, 1, 2, 2}, TensorOptions(kCPU).dtype(at::kFloat));
  auto z = at::rand({1, 1, 2, 2}, TensorOptions(kCPU).dtype(at::kFloat));
  auto qx = at::quantize_per_tensor(x, 0.1f, 13, at::kQUInt8);
  auto qy = at::quantize_per_tensor(y, 0.15f, 16, at::kQUInt8);
  auto qz = at::quantize_per_tensor(z, 0.2f, 19, at::kQUInt8);
  auto qcat = quantized_cat({qx, qy, qz}, 0, 0.1f, 13);
  auto expected = at::dequantize(qcat);

  TensorExprKernel k(graph);
  std::vector<at::Tensor> inputs = {x, y, z};
  StmtPtr s = k.getCodeGenStmt();

  std::vector<IValue> stack = fmap<IValue>(inputs);
  k.run(stack);
  auto result = stack[0].toTensor();
  bool check = at::allclose(expected, result);
  if (!check) {
    std::cout << "x:\n" << x << std::endl;
    std::cout << "y:\n" << y << std::endl;
    std::cout << "z:\n" << z << std::endl;
    std::cout << "qx:\n" << qx << std::endl;
    std::cout << "qy:\n" << qy << std::endl;
    std::cout << "qz:\n" << qz << std::endl;
    std::cout << "qcat:\n" << qcat << std::endl;
    std::cout << "expected:\n" << expected << std::endl;
    std::cout << "result:\n" << result << std::endl;
  }
  TORCH_CHECK_EQ(check, 1);
}

} // namespace jit
} // namespace torch
1928 test/cpp/tensorexpr/test_reductions.cpp (new file; diff suppressed because it is too large)
3702 test/cpp/tensorexpr/test_registerizer.cpp (new file; diff suppressed because it is too large)
5680 test/cpp/tensorexpr/test_simplify.cpp (new file; diff suppressed because it is too large)

402 test/cpp/tensorexpr/test_te_fuser_pass.cpp (new file)
@@ -0,0 +1,402 @@
#include <gtest/gtest.h>

#include <test/cpp/tensorexpr/test_base.h>
#include <torch/csrc/jit/codegen/fuser/interface.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/irparser.h>
#include <torch/csrc/jit/passes/tensorexpr_fuser.h>
#include <torch/csrc/jit/runtime/interpreter.h>
#include <torch/csrc/jit/testing/file_check.h>
#include <sstream>

namespace torch {
namespace jit {

using namespace torch::jit::tensorexpr;

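// RAII guard that enables (or disables) the CPU fuser for the duration of a
// test and restores the previous setting on destruction.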
struct WithCPUFuser {
  WithCPUFuser(bool val = true) : cpuFuserEnabled(canFuseOnCPU()) {
    overrideCanFuseOnCPU(val);
  }

  ~WithCPUFuser() {
    overrideCanFuseOnCPU(cpuFuserEnabled);
  }

  bool cpuFuserEnabled;
};

TEST(TEFuserPass, FuserPass_1) {
  WithCPUFuser cf;
  const auto graph_string = R"IR(
    graph(%0 : Float(128, strides=[1], device=cpu),
          %1 : Float(128, strides=[1], device=cpu)):
      %12 : int = prim::Constant[value=1]()
      %2.1 : Float(128, strides=[1], device=cpu) = aten::mul(%0, %1)
      %2 : Float(128, strides=[1], device=cpu) = aten::mul(%2.1, %1)
      %3 : Float(128, strides=[1], device=cpu) = aten::add_(%2, %1, %12)
      %4 : Float(128, strides=[1], device=cpu) = aten::mul(%2, %1)
      %5 : Float(128, strides=[1], device=cpu) = aten::add(%2, %4, %12)
      return (%5))IR";
  auto g = std::make_shared<Graph>();
  torch::jit::parseIR(graph_string, g.get());

  g->lint();
  FuseTensorExprs(g);

  // We should not be able to fuse across the in-place operation here.
  testing::FileCheck()
      .check("prim::TensorExprGroup_")
      ->check("aten::add_")
      ->check("prim::TensorExprGroup_")
      ->run(*g);
}

TEST(TEFuserPass, FuserPass_2) {
  WithCPUFuser cf;
  const auto graph_string = R"IR(
    graph(%0 : Float(128, strides=[1], device=cpu),
          %1 : Float(128, strides=[1], device=cpu)):
      %12 : int = prim::Constant[value=1]()
      %a : Float(128, strides=[1], device=cpu) = aten::mul(%0, %1)
      %b : Float(128, strides=[1], device=cpu) = aten::add(%0, %1, %12)
      %c : Float(128, strides=[1], device=cpu) = aten::add_(%b, %1, %12)
      %d : Float(128, strides=[1], device=cpu) = aten::mul(%c, %a)
      return (%d))IR";
  auto g = std::make_shared<Graph>();
  torch::jit::parseIR(graph_string, g.get());

  g->lint();
  FuseTensorExprs(g);

  // We should not be able to fuse across the in-place operation here.
  testing::FileCheck()
      .check("aten::add_")
      ->check("prim::TensorExprGroup_0")
      ->run(*g);
}

TEST(TEFuserPass, FuserPass_3) {
  WithCPUFuser cf;
  const auto graph_string = R"IR(
    graph(%x : Float(128, strides=[1], device=cpu),
          %y : Float(128, strides=[1], device=cpu)):
      %r : Float(128, strides=[1], device=cpu) = aten::mul(%x, %y)
      return (%r))IR";
  {
    auto g = std::make_shared<Graph>();
    torch::jit::parseIR(graph_string, g.get());

    g->lint();
    FuseTensorExprs(g, /* min_group_size= */ 2);

    // We should not create a fusion group since its size would be too small
    testing::FileCheck().check_not("prim::TensorExprGroup")->run(*g);
  }
  {
    auto g = std::make_shared<Graph>();
    torch::jit::parseIR(graph_string, g.get());

    g->lint();
    FuseTensorExprs(g, /* min_group_size= */ 1);

    // We should create a fusion group since its size is above the threshold
    testing::FileCheck().check("prim::TensorExprGroup")->run(*g);
  }
}

TEST(TEFuserPass, FuserPass_0DimInput) {
  WithCPUFuser cf;
  const auto graph_string = R"IR(
    graph(%x : Float(device=cpu),
          %y : Float(device=cpu)):
      %one : int = prim::Constant[value=1]()
      %a : Float(device=cpu) = aten::mul(%x, %y)
      %b : Float(device=cpu) = aten::add(%x, %a, %one)
      return (%b))IR";
  auto g = std::make_shared<Graph>();
  torch::jit::parseIR(graph_string, g.get());

  g->lint();
  FuseTensorExprs(g);

  // We should fuse 0-dim tensors too
  testing::FileCheck().check("prim::TensorExprGroup")->run(*g);
}

TEST(TEFuserPass, FuserPass_UnfusibleDevice) {
  WithCPUFuser cf(false);
  const auto graph_string = R"IR(
    graph(%x : Float(10, strides=[1], device=cpu),
          %y : Float(10, strides=[1], device=cpu)):
      %a : Float(10, strides=[1], device=cpu) = aten::mul(%x, %y)
      return (%a))IR";
  auto g = std::make_shared<Graph>();
  torch::jit::parseIR(graph_string, g.get());

  g->lint();
  FuseTensorExprs(g, /* min_group_size= */ 1);

  // Test that we're not starting fusion groups from nodes with unfusible device
  testing::FileCheck().check_not("prim::TensorExprGroup")->run(*g);
}

TEST(TEFuserPass, FuserPass_UnknownShapes) {
  WithCPUFuser cf;
  const auto graph_string = R"IR(
    graph(%x : Tensor,
          %y : Tensor):
      %a : Tensor = aten::mul(%x, %y)
      %b : Tensor = aten::mul(%x, %a)
      return (%b))IR";
  auto g = std::make_shared<Graph>();
  torch::jit::parseIR(graph_string, g.get());

  g->lint();
  FuseTensorExprs(g);

  // Test that we're not generating fusion groups when shapes are not known
  testing::FileCheck().check_not("prim::TensorExprGroup")->run(*g);
}

TEST(TEFuserPass, FuserPass_Multidevice) {
  {
    WithCPUFuser cf;
    const auto graph_string = R"IR(
      graph(%x : Float(10, strides=[1], device=cpu),
            %y : Float(20, strides=[1], device=cpu),
            %z : Float(30, strides=[1], device=cpu)):
        %dim : int = prim::Constant[value=0]()
        %xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z)
        %cat : Float(60, strides=[1], device=cpu) = aten::cat(%xyz_list, %dim)
        return (%cat))IR";
    auto g = std::make_shared<Graph>();
    torch::jit::parseIR(graph_string, g.get());

    g->lint();
    FuseTensorExprs(g, /* min_group_size= */ 1);

    // We should be able to fuse this
    testing::FileCheck().check("prim::TensorExprGroup")->run(*g);
  }
  {
    WithCPUFuser cf;
    const auto graph_string = R"IR(
      graph(%x : Float(10, strides=[1], device=cpu),
            %y : Float(20, strides=[1], device=cuda:0),
            %z : Float(30, strides=[1], device=cpu)):
        %dim : int = prim::Constant[value=0]()
        %xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z)
        %cat : Float(60, strides=[1], device=cpu) = aten::cat(%xyz_list, %dim)
        return (%cat))IR";
    auto g = std::make_shared<Graph>();
    torch::jit::parseIR(graph_string, g.get())

;
    g->lint();
    FuseTensorExprs(g, /* min_group_size= */ 1);

    // We should not fuse this aten::cat since its inputs are from different
    // devices
    testing::FileCheck().check_not("prim::TensorExprGroup")->run(*g);
  }
  {
    WithCPUFuser cf;
    const auto graph_string = R"IR(
      graph(%x : Float(10, strides=[1], device=cpu),
            %y : Float(20, strides=[1], device=cpu),
            %z : Float(10, strides=[1], device=cuda:0)):
        %dim : int = prim::Constant[value=0]()
        %xy_list : Tensor[] = prim::ListConstruct(%x, %y)
        %xy_cat : Float(30, strides=[1], device=cpu) = aten::cat(%xy_list, %dim)
        %r : Float(30, strides=[1], device=cpu) = aten::mul(%xy_cat, %z)
        return (%r))IR";
    auto g = std::make_shared<Graph>();
    torch::jit::parseIR(graph_string, g.get());

    g->lint();
    FuseTensorExprs(g, /* min_group_size= */ 2);

    // Test that we check device before merging one node (cat) into another
    // (mul)
    testing::FileCheck().check_not("prim::TensorExprGroup")->run(*g);
  }
  {
    WithCPUFuser cf;
    const auto graph_string = R"IR(
      graph(%x : Float(10, strides=[1], device=cpu),
            %y : Float(20, strides=[1], device=cpu),
            %z : Float(10, strides=[1], device=cuda:0)):
        %z2 : Tensor = aten::mul(%z, %z)
        %dim : int = prim::Constant[value=0]()
        %xy_list : Tensor[] = prim::ListConstruct(%x, %y, %z2)
        %cat : Float(60, strides=[1], device=cpu) = aten::cat(%xy_list, %dim)
        return (%cat))IR";
    auto g = std::make_shared<Graph>();
    torch::jit::parseIR(graph_string, g.get());

    g->lint();
    FuseTensorExprs(g, /* min_group_size= */ 2);

    // Test that we check device before merging one node (mul) into another
    // (cat)
    testing::FileCheck().check_not("prim::TensorExprGroup")->run(*g);
  }
  {
    WithCPUFuser cf;
    const auto graph_string = R"IR(
      graph(%x : Float(10, strides=[1], device=cpu),
            %y : Float(20, strides=[1], device=cuda:0)):
        %r : Float(10, strides=[1], device=cpu) = aten::mul(%x, %y)
        return (%r))IR";
    auto g = std::make_shared<Graph>();
    torch::jit::parseIR(graph_string, g.get());

    g->lint();
    FuseTensorExprs(g, /* min_group_size= */ 1);

    // We should not fuse this graph since its inputs are from different devices
    testing::FileCheck().check_not("prim::TensorExprGroup")->run(*g);
  }
  {
    WithCPUFuser cf;
    const auto graph_string = R"IR(
      graph(%x : Float(10, strides=[1], device=cuda:0),
            %y : Float(20, strides=[1], device=cuda:1),
            %z : Float(20, strides=[1], device=cpu)):
        %x2 : Float(10, strides=[1], device=cpu) = aten::mul(%x, %x)
        %y2 : Float(10, strides=[1], device=cpu) = aten::mul(%y, %y)
        %z2 : Float(10, strides=[1], device=cpu) = aten::mul(%z, %z)
        return (%x2, %y2, %z2))IR";
    auto g = std::make_shared<Graph>();
    torch::jit::parseIR(graph_string, g.get());

    g->lint();
    FuseTensorExprs(g, /* min_group_size= */ 2);

    // We should not fuse these two computations since they use different
    // devices
    testing::FileCheck().check_not("prim::TensorExprGroup")->run(*g);
  }
}

TEST(TEFuserPass, FuserPass_MergeGroups) {
  WithCPUFuser cf;
  const auto graph_string = R"IR(
    graph(%a : Float(128, strides=[1], device=cpu),
          %b : Float(128, strides=[1], device=cpu)):
      %x : Float(128, strides=[1], device=cpu) = aten::mul(%a, %a)
      %y : Float(128, strides=[1], device=cpu) = aten::mul(%b, %b)
      return (%x, %y))IR";
  auto g = std::make_shared<Graph>();
  torch::jit::parseIR(graph_string, g.get());

  g->lint();
  FuseTensorExprs(g, /* min_group_size= */ 1);

  // The %x and %y computations are completely independent and yet we should
  // put them into a single fusion group rather than having two separate ones.
  testing::FileCheck()
      .check("= prim::TensorExprGroup_")
      ->check_not("= prim::TensorExprGroup_")
      ->run(*g);
}

TEST(TEFuserPass, FuserPass_IgnoreUnknownShapeAtStart) {
  WithCPUFuser cf;
  const auto graph_string = R"IR(
    graph(%x : Bool(8, strides=[1], device=cpu),
          %y : Bool(8, strides=[1], device=cpu)):
      %a : Bool(8, strides=[1], device=cpu) = aten::__and__(%x, %y)
      %b : Tensor = aten::__or__(%a, %y)
      return (%b)
    )IR";
  auto g = std::make_shared<Graph>();
  torch::jit::parseIR(graph_string, g.get());
  g->lint();
  FuseTensorExprs(g, /* min_group_size= */ 2);
  testing::FileCheck().check_not("prim::TensorExprGroup")->run(*g);
}

TEST(TEFuserPass, FuserPass_Where) {
  WithCPUFuser cf;
  const auto graph_string = R"IR(
    graph(%x : Float(8, strides=[1], device=cpu),
          %y : Float(8, strides=[1], device=cpu),
          %z : Float(8, strides=[1], device=cpu)):
      %cond : Bool(8, strides=[1], device=cpu) = aten::eq(%x, %y)
      %b : Float(8, strides=[1], device=cpu) = aten::where(%cond, %y, %z)
      return (%b)
    )IR";
  auto g = std::make_shared<Graph>();
  torch::jit::parseIR(graph_string, g.get());
  g->lint();
  FuseTensorExprs(g, /* min_group_size= */ 2);
  testing::FileCheck().check("prim::TensorExprGroup")->run(*g);
}

TEST(TEFuserPass, FuserPass_WhereList) {
  WithCPUFuser cf;
  const auto graph_string = R"IR(
    graph(%x : Float(8, strides=[1], device=cpu),
          %y : Float(8, strides=[1], device=cpu),
          %z : Float(8, strides=[1], device=cpu)):
      %cond : Bool(8, strides=[1], device=cpu) = aten::eq(%x, %y)
      %b : Tensor[] = aten::where(%cond)
      return (%b)
    )IR";
  auto g = std::make_shared<Graph>();
  torch::jit::parseIR(graph_string, g.get());
  g->lint();
  FuseTensorExprs(g, /* min_group_size= */ 2);
  testing::FileCheck().check_not("prim::TensorExprGroup")->run(*g);
}

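// With fuse_to_dynamic_shapes=true, the pass emits a TensorExprDynamicGroup
// guarded by TensorExprDynamicGuard; the test below runs the same compiled
// graph on several input sizes.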
TEST(TEFuserPass, DynamicShapeFusion) {
  WithCPUFuser cf;
  const auto graph_string = R"IR(
    graph(%0 : Float(10, 5, strides=[5, 1], device=cpu),
          %1 : Float(10, 5, strides=[5, 1], device=cpu)):
      %2 : Float(10, 5, strides=[5, 1], device=cpu) = aten::mul(%0, %1)
      %3 : Float(10, 5, strides=[5, 1], device=cpu) = aten::mul(%2, %1)
      return (%3))IR";
  auto g = std::make_shared<Graph>();
  torch::jit::parseIR(graph_string, g.get());

  g->lint();
  FuseTensorExprs(
      g,
      /* min_group_size = */ 2,
      /* add_composed_op = */ true,
      /* fuse_to_dynamic_shapes = */ true);
  Code code(g, "");

  testing::FileCheck()
      .check("prim::TensorExprDynamicGroup_")
      ->check("prim::TensorExprDynamicGuard")
      ->check("prim::TensorExprGroup_")
      ->run(*g);

  auto run_and_compare = [&](const std::vector<at::Tensor>& inputs) {
    TORCH_INTERNAL_ASSERT(inputs.size() == 2);

    auto ref = at::mul(at::mul(inputs[0], inputs[1]), inputs[1]);

    InterpreterState interp(code);
    Stack stack(inputs.begin(), inputs.end());
    interp.run(stack);
    at::Tensor out = pop(stack).toTensor();
    ASSERT_TRUE(at::allclose(out, ref));
  };

  std::vector<at::Tensor> inputs = {at::rand({10, 5}), at::rand({10, 5})};
  run_and_compare(inputs);

  std::vector<at::Tensor> inputs2 = {at::rand({20, 5}), at::rand({20, 5})};
  run_and_compare(inputs2);

  std::vector<at::Tensor> inputs3 = {at::rand({25, 60}), at::rand({25, 60})};
  run_and_compare(inputs3);
}

} // namespace jit
} // namespace torch
202 test/cpp/tensorexpr/test_type.cpp (new file)
@@ -0,0 +1,202 @@
#include <gtest/gtest.h>

#include "torch/csrc/jit/tensorexpr/eval.h"
#include "torch/csrc/jit/tensorexpr/ir.h"
#include "torch/csrc/jit/tensorexpr/tensor.h"

namespace torch {
namespace jit {
using namespace torch::jit::tensorexpr;

TEST(Type, Test01) {
  {
    Dtype dt1 = kInt;
    ASSERT_EQ(dt1, kInt);
  }
  {
    Dtype dt2_a(kInt, 8);
    Dtype dt2_b(kInt, 4);
    Dtype dt2_c(ScalarType::Int, 8);
    ASSERT_EQ(dt2_a, dt2_c);
    ASSERT_NE(dt2_a, dt2_b);
  }
  {
    ASSERT_EQ(kInt, ToDtype<int>());
    ASSERT_EQ(kFloat, ToDtype<float>());
    ASSERT_EQ(kByte, ToDtype<uint8_t>());
    ASSERT_EQ(kChar, ToDtype<int8_t>());
    ASSERT_EQ(kShort, ToDtype<int16_t>());
    ASSERT_EQ(kLong, ToDtype<int64_t>());
    ASSERT_EQ(kHalf, ToDtype<at::Half>());
    ASSERT_EQ(kDouble, ToDtype<double>());
    ASSERT_EQ(kBool, ToDtype<bool>());
  }
  {
    Dtype int32x8(kInt, 8);
    Dtype float32x8(kFloat, 8);
    ASSERT_NE(int32x8, float32x8);
    ASSERT_EQ(float32x8, BinaryOpDtype(int32x8, float32x8));
    ASSERT_EQ(float32x8, BinaryOpDtype(float32x8, int32x8));
    ASSERT_EQ(int32x8, BinaryOpDtype(int32x8, int32x8));
    ASSERT_EQ(float32x8, BinaryOpDtype(float32x8, float32x8));
  }
}

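// Verifies that BitCast expressions reinterpret the dtype while preserving the
// underlying bit pattern.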
TEST(Type, BitCasting) {
  {
    VarHandle x("x", kFloat);
    ExprHandle y = bitcast<int32_t>(x);
    // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
    ASSERT_EQ(y.dtype(), kInt);
  }
  {
    VarHandle x("x", kInt);
    ExprHandle y = bitcast<float>(x);
    // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
    ASSERT_EQ(y.dtype(), kFloat);
  }
  {
    VarHandle x("x", kShort);
    ExprHandle y = bitcast<at::Half>(x);
    // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
    ASSERT_EQ(y.dtype(), kHalf);
  }
  {
    VarHandle x("x", kHalf);
    ExprHandle y = bitcast<int16_t>(x);
    // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
    ASSERT_EQ(y.dtype(), kShort);
  }

  constexpr int32_t ref32 = 1337;
  constexpr int64_t ref64 = 1337;
  constexpr float reff32 = 1337.0f;
  constexpr double reff64 = 1337.0f;
  using SimpleIRExprEval = ExprEval<SimpleIREvaluator>;
  // this is broken
  /*{
    constexpr int16_t ref16 = 1337;
    at::Half k_;
    at::Half* k = &k_;
    *reinterpret_cast<int16_t*>(k) = ref16;
    auto a = HalfImm::make(*k);
    auto b = BitCast::make(kShort, a);
    SimpleIRExprEval cg(b);
    ASSERT_EQ(cg.value<int16_t>(), ref16);
  }*/

  {
    float k = raw_bitcast<float>(ref32);
    auto a = FloatImm::make(k);
    auto b = BitCast::make(kInt, a);
    SimpleIRExprEval cg(b);
    ASSERT_EQ(cg.value<int32_t>(), ref32);
  }

  {
    double k = raw_bitcast<double>(ref64);
    auto a = DoubleImm::make(k);
    auto b = BitCast::make(kLong, a);
    SimpleIRExprEval cg(b);
    ASSERT_EQ(cg.value<int64_t>(), ref64);
  }

  {
    int64_t k = raw_bitcast<int64_t>(reff64);
    auto a = LongImm::make(k);
    auto b = BitCast::make(kDouble, a);
    SimpleIRExprEval cg(b);
    ASSERT_EQ(cg.value<double>(), reff64);
  }

  {
    int32_t k = raw_bitcast<int32_t>(reff32);
    auto a = IntImm::make(k);
    auto b = BitCast::make(kFloat, a);
    SimpleIRExprEval cg(b);
    ASSERT_EQ(cg.value<float>(), reff32);
  }

  // This segfaults :(
  /*{
    VarHandle x("x", kDouble);
    ASSERT_ANY_THROW(ExprHandle y = bitcast<int32_t>(x));
  }
  {
    VarHandle x("x", kFloat);
    ASSERT_ANY_THROW(ExprHandle y = bitcast<int64_t>(x));
  }
  {
    VarHandle x("x", kLong);
    ASSERT_ANY_THROW(ExprHandle y = bitcast<float>(x));
  }
  {
    VarHandle x("x", kShort);
    ASSERT_ANY_THROW(ExprHandle y = bitcast<float>(x));
  }
  {
    VarHandle x("x", kInt);
    ASSERT_ANY_THROW(ExprHandle y = bitcast<at::Half>(x));
  }*/
}

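// Verifies dtype propagation rules for arithmetic on mixed-dtype expressions.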
TEST(Type, Propagation) {
  // Same types:
  {
    VarHandle x("x", kFloat);
    VarHandle y("y", kFloat);
    ExprHandle body = FloatImm::make(2.f) +
        (x * FloatImm::make(3.f) + FloatImm::make(4.f) * y);
    ASSERT_EQ(body.dtype(), kFloat);
  }
  // Int to bigger int:
  {
    VarHandle x("x", kShort);
    VarHandle y("y", kLong);
    ExprHandle body =
        ShortImm::make(2.f) + (x * ShortImm::make(3) + ShortImm::make(4) * y);
    ASSERT_EQ(body.dtype(), kLong);
  }
  // Float to bigger float:
  {
    VarHandle x("x", kHalf);
    VarHandle y("y", kDouble);
    ExprHandle body =
        HalfImm::make(2.f) + (x * HalfImm::make(3) + HalfImm::make(4) * y);
    ASSERT_EQ(body.dtype(), kDouble);
  }
  // Int to Float:
  {
    VarHandle x("x", kFloat);
    VarHandle y("y", kInt);
    ExprHandle body =
        IntImm::make(2) + (x * IntImm::make(3) + IntImm::make(4) * y);
    ASSERT_EQ(body.dtype(), kFloat);
  }
  // Smaller float, bigger Int:
  {
    VarHandle x("x", kHalf);
    VarHandle y("y", kLong);
    ExprHandle body =
        HalfImm::make(2) + (x * HalfImm::make(3) + HalfImm::make(4) * y);
    ASSERT_EQ(body.dtype(), kHalf);
  }
  // Bigger float, smaller Int:
  {
    VarHandle x("x", kChar);
    VarHandle y("y", kDouble);
    ExprHandle body =
        CharImm::make(2) + (x * CharImm::make(3) + CharImm::make(4) * y);
    ASSERT_EQ(body.dtype(), kDouble);
  }
  // Sign change char/byte upgrades to short:
  {
    VarHandle x("x", kChar);
    VarHandle y("y", kByte);
    ExprHandle body =
        CharImm::make(2) + (x * CharImm::make(3) + CharImm::make(4) * y);
    ASSERT_EQ(body.dtype(), kShort);
  }
}
} // namespace jit
} // namespace torch
75 test/cpp/tensorexpr/test_type_specializations.cpp (new file)
@@ -0,0 +1,75 @@
#include <gtest/gtest.h>

#include <torch/csrc/autograd/generated/variable_factories.h>
#include <torch/csrc/jit/frontend/ir_emitter.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/irparser.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/pass_manager.h>
#include <torch/csrc/jit/passes/tensorexpr_fuser.h>

// Test that tensor type specializations are available in
// the custom passes

namespace torch {
namespace jit {

namespace {

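// Recursively walks a block (including nested blocks) and reports whether any
// input or output value carries a tensor type specialization.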
bool hasTensorTypeSpecializations(torch::jit::Block* block) {
  for (Value* v : block->inputs()) {
    if (hasTensorTypeSpecialization(v))
      return true;
  }
  for (Node* n : block->nodes()) {
    for (torch::jit::Block* b : n->blocks()) {
      if (hasTensorTypeSpecializations(b))
        return true;
    }
    for (Value* v : n->outputs()) {
      if (hasTensorTypeSpecialization(v))
        return true;
    }
  }
  return false;
}

static bool hasSpecializations = false;
void detectTTSpecializationPass(std::shared_ptr<Graph>& graph) {
  GRAPH_DUMP("In detectTTSpecialization Custom Post Pass: ", graph);
  hasSpecializations = hasTensorTypeSpecializations(graph->block());
}

} // namespace

TEST(SpecializationsInCustomPasses, Basic) {
  RegisterPass p(detectTTSpecializationPass);
  hasSpecializations = false;
  std::shared_ptr<Graph> graph = std::make_shared<Graph>();
  parseIR(
      R"IR(
graph(%a.1 : Tensor,
      %b.1 : Tensor):
  %c.1 : Tensor = aten::mul(%a.1, %b.1) # misc/test_specializations.py:5:8
  %d.1 : Tensor = aten::mul(%c.1, %b.1) # misc/test_specializations.py:6:8
  return (%d.1)
)IR",
      &*graph);

  IValue ival = IValue(torch::randn({22}, at::kCPU));
  std::vector<IValue> stack = {ival, ival};
  auto run = [&](std::shared_ptr<Graph>& graph, std::vector<IValue> stack) {
    GraphExecutor executor(graph, "");
    executor.run(stack);
    return stack;
  };
  run(graph, stack);

  // Profiling mode will not be run with simple executor
  if (!getExecutorMode()) {
    EXPECT_TRUE(hasSpecializations);
  }
}

} // namespace jit
} // namespace torch
78 test/cpp/tensorexpr/test_utils.h (new file)
@@ -0,0 +1,78 @@
#pragma once

#include <memory>
#include <vector>

#include <test/cpp/tensorexpr/test_base.h>
#include <torch/csrc/jit/tensorexpr/fwd_decls.h>
#include <torch/csrc/jit/testing/file_check.h>

namespace torch {
namespace jit {
using namespace torch::jit::tensorexpr;

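// Convenience assertion macros for checking the shape of TensorExpr IR nodes
// in tests; each expands to gtest ASSERT_* checks.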
#define IS_NODE(T, node)         \
  {                              \
    auto node_ = to<T>(node);    \
    ASSERT_NE(nullptr, node_);   \
  }

#define IS_NODE_WITH_NAME(T, node, name) \
  auto name = to<T>(node);               \
  ASSERT_NE(nullptr, name);

#define IS_NODE_WITH_NAME_AND_CAST(T, node, name, Type)          \
  NodePtr<T> name = nullptr;                                     \
  {                                                              \
    auto node_ = to<Cast>(node);                                 \
    ASSERT_NE(nullptr, node_);                                   \
    ASSERT_EQ(node_->dtype().scalar_type(), ScalarType::Type);   \
    name = to<T>(node_->src_value());                            \
  }                                                              \
  ASSERT_NE(nullptr, name);

#define IS_IMM_WITH_VAL(T, node, val) \
  {                                   \
    auto node_ = to<T##Imm>(node);    \
    ASSERT_NE(nullptr, node_);        \
    ASSERT_EQ(node_->value(), val);   \
  }

#define IS_VAR_WITH_NAME(node, name)       \
  {                                        \
    auto node_ = to<Var>(node);            \
    ASSERT_NE(nullptr, node_);             \
    ASSERT_EQ(node_->name_hint(), name);   \
  }

#define IS_BINOP_W_VARS(T, node, name, v1, v2) \
  NodePtr<T> name = nullptr;                   \
  {                                            \
    name = to<T>(node);                        \
    ASSERT_NE(nullptr, name);                  \
    IS_VAR_WITH_NAME(name->lhs(), v1);         \
    IS_VAR_WITH_NAME(name->rhs(), v2);         \
  }

#define IS_BINOP_W_CONST(T, node, name, v, c) \
  NodePtr<T> name = nullptr;                  \
  {                                           \
    name = to<T>(node);                       \
    ASSERT_NE(nullptr, name);                 \
    IS_VAR_WITH_NAME(name->lhs(), v);         \
    IS_IMM_WITH_VAL(Int, name->rhs(), c);     \
  }

#define IS_RAND(node)                      \
  {                                        \
    auto node_ = to<Intrinsics>(node);     \
    ASSERT_NE(nullptr, node_);             \
    ASSERT_EQ(node_->op_type(), kRand);    \
  }

void checkIR(StmtPtr s, const std::string& pattern);
void checkExprIR(ExprPtr e, const std::string& pattern);
void checkExprIR(const ExprHandle& e, const std::string& pattern);

} // namespace jit
} // namespace torch
542 test/cpp/tensorexpr/tutorial.cpp (new file)
@@ -0,0 +1,542 @@
// *** Tensor Expressions ***
//
// This tutorial covers the basics of NNC's tensor expressions, shows basic
// APIs to work with them, and outlines how they are used in the overall
// TorchScript compilation pipeline. This doc is permanently a "work in
// progress" since NNC is under active development and things change fast.
//
// This tutorial's code is compiled in the standard pytorch build, and the
// executable can be found in `build/bin/tutorial_tensorexpr`.
//
// *** What is NNC ***
//
// NNC stands for Neural Net Compiler. It is a component of TorchScript JIT
// and it performs on-the-fly code generation for kernels, which are often a
// combination of multiple aten (torch) operators.
//
// When the JIT interpreter executes a torchscript model, it automatically
// extracts subgraphs from the torchscript IR graph for which specialized code
// can be JIT generated. This usually improves performance as the 'combined'
// kernel created from the subgraph could avoid unnecessary memory traffic that
// is unavoidable when the subgraph is interpreted as-is, operator by operator.
// This optimization is often referred to as 'fusion'. Relatedly, the process
// of finding and extracting subgraphs suitable for NNC code generation is done
// by a JIT pass called 'fuser'.
//
// *** What is TE ***
//
// TE stands for Tensor Expressions. TE is a commonly used approach for
// compiling kernels performing tensor (~matrix) computation. The idea behind
// it is that operators are represented as a mathematical formula describing
// what computation they do (as TEs), and then the TE engine can perform
// mathematical simplification and other optimizations using those formulas
// and eventually generate executable code that would produce the same results
// as the original sequence of operators, but more efficiently.
//
// NNC's design and implementation of TE was heavily inspired by the Halide
// and TVM projects.
#include <iostream>
#include <sstream> // for std::istringstream used by printLinesToFrom below
#include <string>

#include <c10/util/irange.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/irparser.h>
#include <torch/csrc/jit/tensorexpr/eval.h>
#include <torch/csrc/jit/tensorexpr/expr.h>
#include <torch/csrc/jit/tensorexpr/ir.h>
#include <torch/csrc/jit/tensorexpr/ir_printer.h>
#include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
#include <torch/csrc/jit/tensorexpr/kernel.h>
#include <torch/csrc/jit/tensorexpr/loopnest.h>
#include <torch/csrc/jit/tensorexpr/stmt.h>
#include <torch/csrc/jit/tensorexpr/tensor.h>
#include <torch/torch.h>

using namespace torch::jit::tensorexpr;

#ifdef TORCH_ENABLE_LLVM

// Helper function to print a snippet from a big multi-line string
static void printLinesToFrom(const std::string& input_str, int from, int to);

#endif

int main(int argc, char* argv[]) {
  std::cout << "*** Structure of tensor expressions and statements ***"
            << std::endl;
  {
    // A tensor expression is a tree of expressions. Each expression has a
    // type, and that type defines what sub-expressions the current expression
    // has. For instance, an expression of type 'Mul' would have a type 'kMul'
    // and two subexpressions: LHS and RHS. Each of these two sub-expressions
    // could also be a 'Mul' or some other expression.
    //
    // Let's construct a simple TE:
    ExprPtr lhs = alloc<IntImm>(5);
    ExprPtr rhs = alloc<Var>("x", kInt);
    ExprPtr mul = alloc<Mul>(lhs, rhs);
    std::cout << "Tensor expression: " << *mul << std::endl;
    // Prints: Tensor expression: 5 * x

    // Here we created an expression representing a 5*x computation, where x
    // is an int variable.

    // Another, probably more convenient, way to construct tensor expressions
    // is to use so-called expression handles (as opposed to raw expressions
    // like we did in the previous example). Expression handles overload
    // common operations and allow us to express the same semantics in a more
    // natural way:
    ExprHandle l = 5;
    ExprHandle r = Var::make("x", kInt);
    ExprHandle m = l * r;
    std::cout << "Tensor expression: " << *m.node() << std::endl;
    // Prints: Tensor expression: 5 * x

    // Converting from handles to raw expressions and back is easy:
    ExprHandle handle = Var::make("x", kInt);
    ExprPtr raw_expr_from_handle = handle.node();
    ExprPtr raw_expr = alloc<Var>("x", kInt);
    ExprHandle handle_from_raw_expr = ExprHandle(raw_expr);

    // We could construct arbitrarily complex expressions using mathematical
    // and logical operations, casts between various data types, and a bunch
    // of intrinsics.
    ExprHandle a = Var::make("a", kInt);
    ExprHandle b = Var::make("b", kFloat);
    ExprHandle c = Var::make("c", kFloat);
    ExprHandle x = ExprHandle(5) * a + b / (sigmoid(c) - 3.0f);
    std::cout << "Tensor expression: " << *x.node() << std::endl;
    // Prints: Tensor expression: float(5 * a) + b / ((sigmoid(c)) - 3.f)

    // The ultimate purpose of tensor expressions is to optimize tensor
    // computations, and in order to represent accesses to tensor data, there
    // is a special kind of expression - a load.
    // To construct a load we need two pieces: the base and the indices. The
    // base of a load is a Buf expression, which could be thought of as a
    // placeholder similar to Var, but with dimensions info.
    //
    // Let's construct a simple load:
    BufHandle A("A", {64, 32}, kInt);
    VarPtr i_var = alloc<Var>("i", kInt), j_var = alloc<Var>("j", kInt);
    ExprHandle i(i_var), j(j_var);
    ExprHandle load = Load::make(A.dtype(), A, {i, j});
    std::cout << "Tensor expression: " << *load.node() << std::endl;
    // Prints: Tensor expression: A[i, j]

    // Tensor expressions constitute tensor statements, which are used to
    // represent the computation of a given operator or a group of operators
    // from a fusion group.
    //
    // There are three main kinds of tensor statements:
    //  - block
    //  - store
    //  - loop
    //
    // A Store represents a store to a single element of a tensor (or to a
    // group of elements if it's a vectorized store). Store statements,
    // similarly to Load expressions, have a base and indices, but on top of
    // that they also include a value - an expression representing what needs
    // to be stored at the given memory location. Let's create a Store stmt:
    StmtPtr store_a = Store::make(A, {i, j}, i + j);
    std::cout << "Store statement: " << *store_a << std::endl;
    // Prints: Store statement: A[i, j] = i + j;

    // An operator fills the entire tensor, not just a single element, and to
    // represent this we need to use a For stmt: let's wrap our store stmt
    // with two nested loops to represent that variables i and j need to
    // iterate over some ranges.
    ForPtr loop_j_a = For::make(VarHandle(j_var), 0, 32, store_a);
    ForPtr loop_i_a = For::make(VarHandle(i_var), 0, 64, loop_j_a);

    std::cout << "Nested for loops: " << std::endl << *loop_i_a << std::endl;
    // Prints:
    // Nested for loops:
    // for (const auto i : c10::irange(64)) {
    //   for (const auto j : c10::irange(32)) {
    //     A[i, j] = i + j;
    //   }
    // }

    // A Block statement is used when we need a sequence of other statements.
    // E.g. if a fusion group contains several operators, we initially define
    // a separate loopnest for each of them and put them all into a common
    // block:
    BufHandle B("B", {64, 32}, kInt);
    StmtPtr store_b = Store::make(B, {i, j}, A.load(i, j));
    ForPtr loop_j_b = For::make(VarHandle(j_var), 0, 32, store_b);
    ForPtr loop_i_b = For::make(VarHandle(i_var), 0, 64, loop_j_b);

    BlockPtr block = Block::make({loop_i_a, loop_i_b});
    std::cout << "Compound Block statement: " << std::endl
              << *block << std::endl;
    // Prints:
    // Compound Block statement:
    // {
    //   for (const auto i : c10::irange(64)) {
    //     for (const auto j : c10::irange(32)) {
    //       A[i, j] = i + j;
    //     }
    //   }
    //   for (const auto i : c10::irange(64)) {
    //     for (const auto j : c10::irange(32)) {
    //       B[i, j] = A[i, j];
    //     }
    //   }
    // }

    // Manually constructing nested loops and blocks to represent a
    // computation might be laborious, and instead we can use the 'Compute'
    // API. This API requires us to specify dimensions and a lambda to compute
    // a single element of the resulting tensor, and returns a `Tensor`
    // structure. This structure is simply a pair of a buffer that was created
    // to represent the result of the computation (BufPtr) and a statement
    // representing the computation itself (StmtPtr).
    Tensor C =
        Compute("C", {64, 32}, [&](const VarHandle& i, const VarHandle& j) {
          return i * j;
        });
    std::cout << "Stmt produced by 'Compute' API: " << std::endl
              << *C.stmt() << std::endl;
    // Prints:
    // Stmt produced by 'Compute' API:
    // for (const auto i : c10::irange(64)) {
    //   for (const auto j : c10::irange(32)) {
    //     C[i, j] = i * j;
    //   }
    // }

    // To construct statements representing computations with reductions, we
    // can use the 'Reduce' API - it is similar to 'Compute' but takes a
    // couple of extra arguments defining how to perform the reduction. Let's
    // define a simple 2D sum of C using that:
    Tensor D = Reduce(
        "D",
        {},
        Sum(),
        [&](const VarHandle& i, const VarHandle& j) { return C.load(i, j); },
        {64, 32});
    std::cout << "Stmt produced by 'Reduce' API: " << std::endl
              << *D.stmt() << std::endl;
  }

  std::cout << "*** Loopnests transformations ***" << std::endl;
  {
    // When a statement for the computation is generated, we might want to
    // apply some optimizations to it. These transformations allow us to end
    // up with a statement producing the same results, but more efficiently.
    //
    // Let's look at a couple of transformations that are used in NNC. We will
    // begin by constructing a Block statement like we did before.

    Tensor C =
        Compute("C", {64, 32}, [&](const VarHandle& i, const VarHandle& j) {
          return i * (j + 1);
        });
    BufHandle c_buf(C.buf());
    Tensor D =
        Compute("D", {64, 32}, [&](const VarHandle& i, const VarHandle& j) {
          return c_buf.load(i, j) - i;
        });
    StmtPtr block = Block::make({C.stmt(), D.stmt()});
    std::cout << "Stmt produced by 'Compute' API: " << std::endl
              << *block << std::endl;
    // Prints:
    // Stmt produced by 'Compute' API:
    // {
    //   for (const auto i : c10::irange(64)) {
    //     for (const auto j : c10::irange(32)) {
    //       C[i, j] = i * (j + 1);
    //     }
    //   }
    //   for (const auto i_1 : c10::irange(64)) {
    //     for (const auto j_1 : c10::irange(32)) {
    //       D[i_1, j_1] = (C[i_1, j_1]) - i_1;
    //     }
    //   }
    // }

    // One transformation we can apply to this computation is inlining: i.e.
    // taking the expression that defines the values of C and substituting
    // loads from C with it.
    // To do that, we first need to create a special object called LoopNest -
    // all transformations are methods of this class. To create a loopnest we
    // need to provide a list of output buffers and the root statement:
    LoopNest nest(block, {D.buf()});

    // We can always retrieve the Stmt back from a LoopNest:
    std::cout << "LoopNest root stmt: " << std::endl
              << *nest.root_stmt() << std::endl;
    // Prints:
    // LoopNest root stmt:
    // {
    //   for (const auto i : c10::irange(64)) {
    //     for (const auto j : c10::irange(32)) {
    //       C[i, j] = i * (j + 1);
    //     }
    //   }
    //   for (const auto i_1 : c10::irange(64)) {
    //     for (const auto j_1 : c10::irange(32)) {
    //       D[i_1, j_1] = (C[i_1, j_1]) - i_1;
    //     }
    //   }
    // }

    // Now we can apply the inlining transformation:
    nest.computeInline(C.buf());
    std::cout << "Stmt after inlining:" << std::endl
              << *nest.root_stmt() << std::endl;
    // Prints:
    // Stmt after inlining:
    // {
    //   for (const auto i : c10::irange(64)) {
    //     for (const auto j : c10::irange(32)) {
    //       D[i, j] = i * (j + 1) - i;
    //     }
    //   }
    // }

    // We can also apply algebraic simplification to a statement:
    StmtPtr simplified = IRSimplifier::simplify(nest.root_stmt());
    std::cout << "Stmt after simplification:" << std::endl
              << *simplified << std::endl;
    // Prints:
    // Stmt after simplification:
    // {
    //   for (const auto i : c10::irange(64)) {
    //     for (const auto j : c10::irange(32)) {
    //       D[i, j] = i * j;
    //     }
    //   }
    // }

    // Many loopnest transformations are stateless and can be applied without
    // creating a LoopNest object. In fact, we plan to make all
    // transformations stateless.
    // splitWithTail is one such transformation: it splits the iteration space
    // of a given loop into two with a given factor.
    ForPtr outer_loop = to<For>(to<Block>(simplified)->stmts().front());
    LoopNest::splitWithTail(outer_loop, 13);
    // Call the simplifier once more to fold some arithmetic.
    simplified = IRSimplifier::simplify(simplified);
    std::cout << "Stmt after splitWithTail:" << std::endl
              << *simplified << std::endl;
    // Prints:
    // Stmt after splitWithTail:
    // {
    //   for (const auto i_outer : c10::irange(4)) {
    //     for (const auto i_inner : c10::irange(13)) {
    //       for (const auto j : c10::irange(32)) {
    //         D[i_inner + 13 * i_outer, j] = i_inner * j + 13 * (i_outer * j);
    //       }
    //     }
    //   }
    //   for (const auto i_tail : c10::irange(12)) {
    //     for (const auto j : c10::irange(32)) {
    //       D[i_tail + 52, j] = i_tail * j + 52 * j;
    //     }
    //   }
    // }

    // NNC supports a wide range of loop nest transformations, which we are
    // not listing here. Please refer to the documentation in
    // https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/tensorexpr/loopnest.h
    // for more details.
  }

  std::cout << "*** Codegen ***" << std::endl;
  {
    // The ultimate goal of tensor expressions is to provide a mechanism to
    // execute a given computation in the fastest possible way. So far we've
    // looked at how we could describe what computation we're interested in,
    // but we haven't looked at how to actually execute it.
    //
    // All we've been dealing with so far were just symbols with no actual
    // data associated; in this section we will look at how to bridge that
    // gap.

    // Let's start by constructing a simple computation for us to work with:
    BufHandle A("A", {64, 32}, kInt);
    BufHandle B("B", {64, 32}, kInt);
    Tensor X =
        Compute("X", {64, 32}, [&](const VarHandle& i, const VarHandle& j) {
          return A.load(i, j) + B.load(i, j);
        });

    // And let's lower it to a loop nest, as we did in the previous section.
    // We can pass the Tensor object directly:
    LoopNest loopnest({X});
    std::cout << *loopnest.root_stmt() << std::endl;
    // Prints:
    // {
    //   for (const auto i : c10::irange(64)) {
    //     for (const auto j : c10::irange(32)) {
    //       X[i, j] = (A[i, j]) + (B[i, j]);
    //     }
    //   }
    // }

    // Now imagine that we have two actual 64x32 tensors that we want to sum
    // together: how do we pass those tensors to the computation, and how do
    // we carry it out?
    //
    // The Codegen object is aimed at providing exactly that functionality.
    // Codegen is an abstract class and concrete codegens are derived from it.
    // Currently, we have three codegens:
    //  1) Simple Evaluator,
    //  2) LLVM Codegen for CPU,
    //  3) CUDA Codegen.
    // In this example we will be using the Simple Evaluator, since it's
    // available everywhere.

    // To create a codegen, we need to provide the statement - it specifies
    // the computation we want to perform - and a list of placeholders and
    // tensors used in the computation. The latter part is crucial since it is
    // the only way the codegen can correlate symbols in the statement with
    // the actual data arrays we pass in when we perform the computation.
    //
    // Let's create a Simple IR Evaluator codegen for our computation:
    SimpleIREvaluator ir_eval(loopnest.root_stmt(), {A, B, X});

    // We are using the simplest codegen, and in it almost no work is done at
    // the construction step. Real codegens such as CUDA and LLVM perform
    // compilation during that stage so that everything is ready when we're
    // about to run the computation.

    // Let's now create some inputs and run our computation with them:
    std::vector<int> data_A(64 * 32, 3); // This will be the input A
    std::vector<int> data_B(64 * 32, 5); // This will be the input B
    std::vector<int> data_X(64 * 32, 0); // This will be used for the result

    // Now let's invoke our codegen to perform the computation on our data. We
    // need to provide as many arguments as there are placeholders and tensors
    // passed at codegen construction time. A position in these lists defines
    // how the real data arrays from the latter call (these arguments are
    // referred to as 'CallArg's in our codebase) correspond to symbols
    // (placeholders and tensors) used in the tensor expressions we
    // constructed (these are referred to as 'BufferArg').
    // Thus, we will provide three arguments: data_A, data_B, and data_X.
    // data_A contains data for the placeholder A, data_B - for the
    // placeholder B, and data_X would be used for the contents of tensor X.
    ir_eval(data_A, data_B, data_X);

    // Let's print one of the elements from each array to verify that the
    // computation did happen:
    std::cout << "A[10] = " << data_A[10] << std::endl
              << "B[10] = " << data_B[10] << std::endl
              << "X[10] = A[10] + B[10] = " << data_X[10] << std::endl;
    // Prints:
    // A[10] = 3
    // B[10] = 5
    // X[10] = A[10] + B[10] = 8
  }

  std::cout << "*** Lowering TorchScript IR to TensorExpr IR ***" << std::endl;
  {
    // This section requires an LLVM-enabled PyTorch build, so we have to use
    // a guard:
#ifdef TORCH_ENABLE_LLVM

    // Often we would like to convert TorchScript IR to TE rather than
    // construct TE IR from scratch. NNC provides an API to perform such
    // lowering: it takes a TorchScript graph and returns an object that can
    // be used to invoke the generated kernel.
    // This API is currently used by the TorchScript JIT fuser and can also be
    // used ahead of time to pre-compile parts of a model.
    //
    // To get familiar with this API, let's start by defining a simple
    // TorchScript graph:
    const auto graph_string = R"IR(
        graph(%A : Float(5, 3, strides=[3, 1], device=cpu),
              %B : Float(5, 3, strides=[3, 1], device=cpu)):
          %AB : Float(5, 3, strides=[3, 1]) = aten::mul(%A, %B)
          %one : int = prim::Constant[value=1]()
          %AAB : Float(5, 3, strides=[3, 1]) = aten::mul(%A, %AB)
          %AAB_plus_B: Float(5, 3, strides=[3, 1]) = aten::add(%AAB, %B, %one)
          return (%AAB_plus_B))IR";
    auto graph = std::make_shared<torch::jit::Graph>();
    parseIR(graph_string, &*graph);

    // This graph defines a simple computation of A*A*B + B where A and B are
    // input 5x3 tensors.

    // To lower this TorchScript graph to TE, we just need to create a
    // TensorExprKernel object. In its constructor it constructs the
    // corresponding TE IR and compiles it for the given backend (in this
    // example, for CPU using the LLVM compiler).
    TensorExprKernel kernel(graph);

    // We can retrieve the generated TE stmt from the kernel object:
    StmtPtr kernel_stmt = kernel.getCodeGenStmt();
    std::cout << "TE Stmt constructed from TorchScript: " << std::endl
              << *kernel_stmt << std::endl;
    // Prints:
    // TE Stmt constructed from TorchScript:
    // {
    //   for (const auto v : c10::irange(5)) {
    //     for (const auto _tail_tail : c10::irange(3)) {
    //       aten_add[_tail_tail + 3 * v] = (tA[_tail_tail + 3 * v]) *
    //           ((tA[_tail_tail + 3 * v]) * (tB[_tail_tail + 3 * v])) +
    //           (tB[_tail_tail + 3 * v]);
    //     }
    //   }
    // }

    // We can also examine the generated LLVM IR and assembly code:
    std::cout << "Generated LLVM IR: " << std::endl;
    auto ir_str = kernel.getCodeText("ir");
    printLinesToFrom(ir_str, 15, 20);
    // Prints:
    // Generated LLVM IR:
    //   %9 = bitcast float* %2 to <8 x float>*
    //   %10 = load <8 x float>, <8 x float>* %9 ...
    //   %11 = bitcast float* %5 to <8 x float>*
    //   %12 = load <8 x float>, <8 x float>* %11 ...
    //   %13 = fmul <8 x float> %10, %12
    //   %14 = fmul <8 x float> %10, %13

    std::cout << "Generated assembly: " << std::endl;
    auto asm_str = kernel.getCodeText("asm");
    printLinesToFrom(asm_str, 10, 15);
    // Prints:
    // Generated assembly:
    //   vmulps %ymm1, %ymm0, %ymm2
    //   vfmadd213ps %ymm1, %ymm0, %ymm2
    //   vmovups %ymm2, (%rax)
    //   vmovss 32(%rcx), %xmm0
    //   vmovss 32(%rdx), %xmm1
    //   vmulss %xmm1, %xmm0, %xmm2

    // We can also execute the generated kernel:
    auto A =
        at::ones({5, 3}, torch::TensorOptions(torch::kCPU).dtype(at::kFloat)) *
        2.0;
    auto B =
        at::ones({5, 3}, torch::TensorOptions(torch::kCPU).dtype(at::kFloat)) *
        3.0;
    std::vector<at::Tensor> inputs = {A, B};
    std::vector<torch::IValue> stack = torch::fmap<torch::IValue>(inputs);
    kernel.run(stack);
    auto R = stack[0].toTensor();

    // Let's print one of the elements from the result tensor to verify that
    // the computation did happen and was correct:
    std::cout << "R[2][2] = " << R[2][2] << std::endl;
    // Prints:
    // R[2][2] = 15
    // [ CPUFloatType{} ]
#endif
  }
  return 0;
}

void printLinesToFrom(const std::string& input_str, int from, int to) {
  std::istringstream f(input_str);
  std::string s;
  int idx = 0;
  while (getline(f, s)) {
    if (idx > from) {
      std::cout << s << "\n";
    }
    if (idx++ > to) {
      break;
    }
  }
}

@@ -1910,7 +1910,7 @@ REGISTER_OPERATOR_FUNCTOR(aten::div, aten_div, [](Node* n) -> SROperator {
    }
    auto& out_t = p_node->Output(0).toTensor();

    if (te && te->checkInput<float>(in0_t) && in0_t.sizes() == in1_t.sizes() &&
    if (in0_t.sizes() == in1_t.sizes() &&
        in0_t.scalar_type() == in1_t.scalar_type() &&
        in0_t.strides() == in1_t.strides() && in0_t.is_contiguous() &&
        in0_t.scalar_type() == at::kFloat) {