pytorch/test/cpp/tensorexpr/test_graph_opt.cpp
Mikhail Zolotukhin f0d274294d [TensorExpr] Nuke KernelArena and KernelScope. (#63587)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63587

Now that there are no classes using KernelArena for memory management,
we can remove it.

Differential Revision: D30429115

Test Plan: Imported from OSS

Reviewed By: navahgar

Pulled By: ZolotukhinM

fbshipit-source-id: 375f6f9294d27790645eeb7cb5a8e87047a57544
2021-08-24 00:32:16 -07:00

#include <gtest/gtest.h>
#include <test/cpp/tensorexpr/test_base.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/irparser.h>
#include <torch/csrc/jit/tensorexpr/kernel.h>
#include <torch/csrc/jit/testing/file_check.h>
#include <torch/torch.h>
#include <limits>

namespace torch {
namespace jit {

using namespace torch::jit::tensorexpr;
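
// These tests exercise TensorExprKernel's cat optimization.
// getCatWoConditionals() is a global flag exposed by the tensorexpr headers
// included above; when it is set, `aten::cat` is lowered without
// conditionals (roughly, one copy loop per input rather than a single loop
// that selects among the inputs). The fixture below enables the flag for
// each test and restores the previous value in TearDown.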
class GraphOpt : public ::testing::Test {
 public:
  void SetUp() override {
    old_cat_wo_conditionals_ = getCatWoConditionals();
    getCatWoConditionals() = true;
  }

  void TearDown() override {
    getCatWoConditionals() = old_cat_wo_conditionals_;
  }

 private:
  bool old_cat_wo_conditionals_;
};

TEST_F(GraphOpt, OptimizeCat) {
#ifdef TORCH_ENABLE_LLVM
  const auto graph_string = R"IR(
    graph(%x : Float(10, strides=[1], device=cpu),
          %y : Float(20, strides=[1], device=cpu),
          %z : Float(30, strides=[1], device=cpu)):
      %dim : int = prim::Constant[value=0]()
      %xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z)
      %cat : Float(60, strides=[1], device=cpu) = aten::cat(%xyz_list, %dim)
      %5 : Float(60, strides=[1], device=cpu) = aten::log(%cat)
      return (%5))IR";
  auto g = std::make_shared<Graph>();
  torch::jit::parseIR(graph_string, g.get());
  g->lint();

  TensorExprKernel kernel(g);

  // The `aten::log` op must be moved to the inputs of `aten::cat`.
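  // After the transformation the graph should look roughly like this
  // (value names are illustrative, not what the pass actually produces):
  //   %x.log = aten::log(%x)
  //   %y.log = aten::log(%y)
  //   %z.log = aten::log(%z)
  //   %l : Tensor[] = prim::ListConstruct(%x.log, %y.log, %z.log)
  //   %cat = aten::cat(%l, %dim)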
  testing::FileCheck()
      .check("aten::log")
      ->check("aten::log")
      ->check("aten::log")
      ->check("aten::cat")
      ->check_not("aten::log")
      ->run(*kernel.graph());

  auto x = at::rand({10}, at::kFloat);
  auto y = at::rand({20}, at::kFloat);
  auto z = at::rand({30}, at::kFloat);
  auto ref = at::log(at::cat({x, y, z}, 0));

  std::vector<at::Tensor> inputs = {x, y, z};
  std::vector<IValue> stack = fmap<IValue>(inputs);
  kernel.run(stack);
  auto out = stack[0].toTensor();
  ASSERT_EQ(out.sizes(), ref.sizes());
  ASSERT_EQ(out.dtype(), ref.dtype());
  ASSERT_TRUE(at::allclose(out, ref));
#endif
}

TEST_F(GraphOpt, OptimizeCat2) {
#ifdef TORCH_ENABLE_LLVM
  const auto graph_string = R"IR(
    graph(%x : Float(10, strides=[1], device=cpu),
          %y : Float(20, strides=[1], device=cpu),
          %z : Float(30, strides=[1], device=cpu)):
      %dim : int = prim::Constant[value=0]()
      %xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z)
      %cat : Float(60, strides=[1], device=cpu) = aten::cat(%xyz_list, %dim)
      %5 : Float(60, strides=[1], device=cpu) = aten::log(%cat)
      %6 : Float(60, strides=[1], device=cpu) = aten::tanh(%5)
      return (%6))IR";
  auto g = std::make_shared<Graph>();
  torch::jit::parseIR(graph_string, g.get());
  g->lint();

  TensorExprKernel kernel(g);

  // The `aten::log` and `aten::tanh` ops must be moved to the inputs of
  // `aten::cat`.
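  // Expected shape: three `log` nodes (one per input), then three `tanh`
  // nodes applied to the `log` results, and finally a single `cat`.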
  testing::FileCheck()
      .check("aten::log")
      ->check("aten::log")
      ->check("aten::log")
      ->check("aten::tanh")
      ->check("aten::tanh")
      ->check("aten::tanh")
      ->check("aten::cat")
      ->check_not("aten::log")
      ->check_not("aten::tanh")
      ->run(*kernel.graph());

  auto x = at::rand({10}, at::kFloat);
  auto y = at::rand({20}, at::kFloat);
  auto z = at::rand({30}, at::kFloat);
  auto ref = at::tanh(at::log(at::cat({x, y, z}, 0)));

  std::vector<at::Tensor> inputs = {x, y, z};
  std::vector<IValue> stack = fmap<IValue>(inputs);
  kernel.run(stack);
  auto out = stack[0].toTensor();
  ASSERT_EQ(out.sizes(), ref.sizes());
  ASSERT_EQ(out.dtype(), ref.dtype());
  ASSERT_TRUE(at::allclose(out, ref));
#endif
}

TEST_F(GraphOpt, OptimizeCat3) {
#ifdef TORCH_ENABLE_LLVM
  const auto graph_string = R"IR(
    graph(%a : Float(60, strides=[1], device=cpu),
          %x : Float(10, strides=[1], device=cpu),
          %y : Float(20, strides=[1], device=cpu),
          %z : Float(30, strides=[1], device=cpu)):
      %dim : int = prim::Constant[value=0]()
      %xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z)
      %cat : Float(60, strides=[1], device=cpu) = aten::cat(%xyz_list, %dim)
      %5 : Float(60, strides=[1], device=cpu) = aten::tanh(%cat)
      %6 : Float(60, strides=[1], device=cpu) = aten::mul(%a, %5)
      return (%6))IR";
  auto g = std::make_shared<Graph>();
  torch::jit::parseIR(graph_string, g.get());
  g->lint();

  TensorExprKernel kernel(g);

  // The `aten::tanh` op must be moved to the inputs of `aten::cat`.
  // But the `aten::mul` op must not be moved since it is not a single-tensor
  // op (it has 2 tensor inputs).
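  // Roughly, the expected result is (names illustrative):
  //   %x.t = aten::tanh(%x)
  //   %y.t = aten::tanh(%y)
  //   %z.t = aten::tanh(%z)
  //   %cat = aten::cat(...)       // concatenates the tanh results
  //   %out = aten::mul(%a, %cat)  // stays below the cat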
  testing::FileCheck()
      .check("aten::tanh")
      ->check("aten::tanh")
      ->check("aten::tanh")
      ->check("aten::cat")
      ->check("aten::mul")
      ->check_not("aten::tanh")
      ->run(*kernel.graph());

  auto a = at::rand({60}, at::kFloat);
  auto x = at::rand({10}, at::kFloat);
  auto y = at::rand({20}, at::kFloat);
  auto z = at::rand({30}, at::kFloat);
  auto ref = at::tanh(at::cat({x, y, z}, 0)) * a;

  std::vector<at::Tensor> inputs = {a, x, y, z};
  std::vector<IValue> stack = fmap<IValue>(inputs);
  kernel.run(stack);
  auto out = stack[0].toTensor();
  ASSERT_EQ(out.sizes(), ref.sizes());
  ASSERT_EQ(out.dtype(), ref.dtype());
  ASSERT_TRUE(at::allclose(out, ref));
#endif
}

TEST_F(GraphOpt, OptimizeCatWithTypePromotionInUser) {
#ifdef TORCH_ENABLE_LLVM
  const auto graph_string = R"IR(
    graph(%x : Int(10, strides=[1], device=cpu),
          %y : Int(20, strides=[1], device=cpu),
          %z : Int(30, strides=[1], device=cpu)):
      %dim : int = prim::Constant[value=0]()
      %xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z)
      %cat : Int(60, strides=[1], device=cpu) = aten::cat(%xyz_list, %dim)
      %5 : Float(60, strides=[1], device=cpu) = aten::tanh(%cat)
      return (%5))IR";
  auto g = std::make_shared<Graph>();
  torch::jit::parseIR(graph_string, g.get());
  g->lint();

  TensorExprKernel kernel(g);

  // The `aten::tanh` op must be moved to the inputs of `aten::cat`.
  // The scalar type of the inputs to `cat` should now be `Float` since they
  // are the result of `tanh` which does the type promotion.
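  // That is, each moved `tanh` consumes an Int tensor and produces a Float
  // one (tanh promotes integral inputs to floating point), so the rewritten
  // `cat` concatenates Float tensors.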
  testing::FileCheck()
      .check("aten::tanh")
      ->check("aten::tanh")
      ->check("aten::tanh")
      ->check("aten::cat")
      ->check_not("aten::tanh")
      ->run(*kernel.graph());

  auto x = at::randint(std::numeric_limits<int>::max(), {10}, at::kInt);
  auto y = at::randint(std::numeric_limits<int>::max(), {20}, at::kInt);
  auto z = at::randint(std::numeric_limits<int>::max(), {30}, at::kInt);
  auto ref = at::tanh(at::cat({x, y, z}, 0));

  std::vector<at::Tensor> inputs = {x, y, z};
  std::vector<IValue> stack = fmap<IValue>(inputs);
  kernel.run(stack);
  auto out = stack[0].toTensor();
  ASSERT_EQ(out.sizes(), ref.sizes());
  ASSERT_EQ(out.dtype(), ref.dtype());
  ASSERT_TRUE(at::allclose(out, ref));
#endif
}

TEST_F(GraphOpt, OptimizeCatWithTypePromotionInCat) {
#ifdef TORCH_ENABLE_LLVM
  const auto graph_string = R"IR(
    graph(%x : Float(10, strides=[1], device=cpu),
          %y : Float(20, strides=[1], device=cpu),
          %z : Double(30, strides=[1], device=cpu)):
      %dim : int = prim::Constant[value=0]()
      %xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z)
      %cat : Double(60, strides=[1], device=cpu) = aten::cat(%xyz_list, %dim)
      %5 : Double(60, strides=[1], device=cpu) = aten::log(%cat)
      return (%5))IR";
  auto g = std::make_shared<Graph>();
  torch::jit::parseIR(graph_string, g.get());
  g->lint();

  TensorExprKernel kernel(g);

  // No transformation should have happened because the `aten::cat` op
  // performs type promotion. This case is currently not handled.
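  // The check/check_not pairs below assert that `aten::cat` and `aten::log`
  // each appear exactly once, and in the original order.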
  testing::FileCheck()
      .check("aten::cat")
      ->check("aten::log")
      ->check_not("aten::cat")
      ->check_not("aten::log")
      ->run(*kernel.graph());
#endif
}

TEST_F(GraphOpt, OptimizeCatNoSingleTensorElementwiseOp) {
#ifdef TORCH_ENABLE_LLVM
  const auto graph_string = R"IR(
    graph(%0 : Float(60, strides=[1], device=cpu),
          %x : Float(10, strides=[1], device=cpu),
          %y : Float(20, strides=[1], device=cpu),
          %z : Float(30, strides=[1], device=cpu)):
      %dim : int = prim::Constant[value=0]()
      %xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z)
      %cat : Float(60, strides=[1], device=cpu) = aten::cat(%xyz_list, %dim)
      %5 : Float(60, strides=[1], device=cpu) = aten::mul(%0, %cat)
      return (%5))IR";
  auto g = std::make_shared<Graph>();
  torch::jit::parseIR(graph_string, g.get());
  g->lint();

  TensorExprKernel kernel(g);

  // No transformation is expected since the consumer of `cat`, `aten::mul`,
  // is not a single-tensor element-wise op (it takes two tensor inputs).
  testing::FileCheck()
      .check("aten::cat")
      ->check("aten::mul")
      ->check_not("aten::cat")
      ->check_not("aten::mul")
      ->run(*kernel.graph());
#endif
}

TEST_F(GraphOpt, OptimizeCatNoSingleTensorElementwiseOp2) {
#ifdef TORCH_ENABLE_LLVM
  const auto graph_string = R"IR(
    graph(%0 : Float(60, strides=[1], device=cpu),
          %1 : Float(60, strides=[1], device=cpu),
          %x : Float(10, strides=[1], device=cpu),
          %y : Float(20, strides=[1], device=cpu),
          %z : Float(30, strides=[1], device=cpu)):
      %one : int = prim::Constant[value=1]()
      %dim : int = prim::Constant[value=0]()
      %xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z)
      %cat : Float(60, strides=[1], device=cpu) = aten::cat(%xyz_list, %dim)
      %5 : Float(60, strides=[1], device=cpu) = aten::mul(%0, %cat)
      %6 : Float(60, strides=[1], device=cpu) = aten::add(%5, %1, %one)
      return (%6))IR";
  auto g = std::make_shared<Graph>();
  torch::jit::parseIR(graph_string, g.get());
  g->lint();

  TensorExprKernel kernel(g);

  // No transformation is expected since the consumers of `cat` are not
  // single-tensor element-wise ops.
  testing::FileCheck()
      .check("aten::cat")
      ->check("aten::mul")
      ->check("aten::add")
      ->check_not("aten::cat")
      ->check_not("aten::mul")
      ->check_not("aten::add")
      ->run(*kernel.graph());
#endif
}

} // namespace jit
} // namespace torch