pytorch/test/cpp/tensorexpr/test_cpp_codegen.cpp
Mikhail Zolotukhin 1dc2b52764 [TensorExpr] Add a wrapper for all expr and stmt pointers. (#63195)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63195

This helps us to later switch from using KernelArena with raw pointers
to shared pointers without having to change all our source files at
once.

The changes are mechanical and should not affect any functionality.

With this PR, we're changing the following:
 * `Add*` --> `AddPtr`
 * `new Add(...)` --> `alloc<Add>(...)`
 * `dynamic_cast<Add*>` --> `to<Add>`
 * `static_cast<Add*>` --> `static_to<Add>`

Due to some complications with args forwarding, some places became more
verbose, e.g.:
 * `new Block({})` --> `alloc<Block>(std::vector<StmtPtr>())`

Test Plan: Imported from OSS

Reviewed By: navahgar

Differential Revision: D30292779

Pulled By: ZolotukhinM

fbshipit-source-id: 150301c7d2df56b608b035827b6a9a87f5e2d9e9
2021-08-17 13:44:45 -07:00

59 lines
1.6 KiB
C++

#include <gtest/gtest.h>
#include <test/cpp/tensorexpr/test_base.h>
#include <torch/csrc/jit/tensorexpr/cpp_codegen.h>
#include <torch/csrc/jit/tensorexpr/mem_arena.h>
#include <torch/csrc/jit/tensorexpr/stmt.h>
#include <torch/csrc/jit/testing/file_check.h>
namespace torch {
namespace jit {
using namespace torch::jit::tensorexpr;
// Prints an Allocate/Free pair for a small 2x3 int buffer and checks that:
//  * the allocation is emitted as a local array declaration (`int x[6];`), and
//  * the matching Free statement does NOT produce a `free(...)` call, since
//    calling free() on a stack array would be undefined behavior in the
//    generated code.
TEST(CppPrinter, AllocateOnStackThenFree) {
  KernelScope kernel_scope;
  std::vector<ExprPtr> dims = {alloc<IntImm>(2), alloc<IntImm>(3)};
  BufPtr buf = alloc<Buf>("x", dims, kInt);
  AllocatePtr alloc_ = alloc<Allocate>(buf);
  FreePtr free_ = alloc<Free>(buf);
  BlockPtr block = Block::make({alloc_, free_});
  std::stringstream ss;
  CppPrinter printer(&ss);
  printer.visit(block);
  // 2 * 3 = 6 elements. CHECK-NOT applies to the region between the array
  // declaration and the closing brace, i.e. exactly where the Free is printed.
  const std::string expected = R"(
# CHECK: {
# CHECK: int x[6];
# CHECK-NOT: free(
# CHECK: }
)";
  torch::jit::testing::FileCheck().run(expected, ss.str());
}
// Prints an Allocate/Free pair for a large 20x50x3 int64 buffer and expects a
// heap allocation via malloc() that is later released with free().
TEST(CppPrinter, AllocateOnHeapThenFree) {
  // Arena scope for the IR nodes created by the alloc<>() calls below
  // (KernelScope comes from mem_arena.h).
  KernelScope kernel_scope;

  // Buffer "y" of shape 20 x 50 x 3 with int64 elements.
  std::vector<ExprPtr> shape{
      alloc<IntImm>(20), alloc<IntImm>(50), alloc<IntImm>(3)};
  BufPtr y_buf = alloc<Buf>("y", shape, kLong);

  // A single block that allocates the buffer and then releases it.
  AllocatePtr allocate_y = alloc<Allocate>(y_buf);
  FreePtr free_y = alloc<Free>(y_buf);
  BlockPtr body = Block::make({allocate_y, free_y});

  std::stringstream generated;
  CppPrinter cpp_printer(&generated);
  cpp_printer.visit(body);
  const std::string code = generated.str();

  // malloc size in bytes: 20 * 50 * 3 elements * sizeof(int64_t) (= 8)
  // = 24000.
  const std::string expected = R"(
# CHECK: {
# CHECK: int64_t* y = static_cast<int64_t*>(malloc(24000));
# CHECK: free(y);
# CHECK: }
)";
  torch::jit::testing::FileCheck().run(expected, code);
}
} // namespace jit
} // namespace torch