pytorch/test/cpp/tensorexpr/test_ir_verifier.cpp
Raghavan Raman 59dd12042e [nnc] Removed const from all fields in IR. (#62336)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/62336

This PR was generated by removing `const` from all types of nodes in NNC IR, and fixing the compilation errors that resulted from this change.

This is the first step in making all NNC mutations in-place.

Test Plan: Imported from OSS

Reviewed By: iramazanli

Differential Revision: D30049829

Pulled By: navahgar

fbshipit-source-id: ed14e2d2ca0559ffc0b92ac371f405579c85dd63
2021-08-03 11:44:36 -07:00

194 lines
6.2 KiB
C++

#include <gtest/gtest.h>
#include <stdexcept>
#include "test/cpp/tensorexpr/test_base.h"
#include <torch/csrc/jit/tensorexpr/expr.h>
#include <torch/csrc/jit/tensorexpr/ir.h>
#include <torch/csrc/jit/tensorexpr/ir_verifier.h>
#include <torch/csrc/jit/tensorexpr/loopnest.h>
#include <torch/csrc/jit/tensorexpr/tensor.h>
#include <torch/csrc/jit/testing/file_check.h>
#include <sstream>
namespace torch {
namespace jit {
using namespace torch::jit::tensorexpr;
// Every bitwise/shift node built here combines an integral operand with a
// floating-point operand; the verifier is expected to reject each of them.
TEST(IRVerifier, BitwiseOps) {
  KernelScope kernel_scope;
  Var* int_operand = new Var("x", kInt);
  Var* float_operand = new Var("y", kFloat);

  And* bit_and = new And(int_operand, float_operand);
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks)
  EXPECT_ANY_THROW(verify(bit_and));

  Or* bit_or = new Or(int_operand, float_operand);
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks)
  EXPECT_ANY_THROW(verify(bit_or));

  Xor* bit_xor = new Xor(int_operand, float_operand);
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks)
  EXPECT_ANY_THROW(verify(bit_xor));

  Lshift* shift_left = new Lshift(int_operand, float_operand);
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks)
  EXPECT_ANY_THROW(verify(shift_left));

  Rshift* shift_right = new Rshift(int_operand, float_operand);
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks)
  EXPECT_ANY_THROW(verify(shift_right));
}
// A CompareSelect whose operands disagree in dtype (int vs. float) must be
// rejected by the verifier.
TEST(IRVerifier, CompareSelect) {
  KernelScope kernel_scope;
  Expr* int_imm = new IntImm(1);
  Expr* float_imm = new FloatImm(3.14f);

  // Third and fourth operands differ in dtype.
  CompareSelect* mixed_tail =
      new CompareSelect(int_imm, int_imm, int_imm, float_imm, kEQ);
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks)
  EXPECT_ANY_THROW(verify(mixed_tail));

  // First and second operands differ in dtype.
  CompareSelect* mixed_head =
      new CompareSelect(int_imm, float_imm, int_imm, int_imm, kEQ);
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks)
  EXPECT_ANY_THROW(verify(mixed_head));
}
// A 4-lane Ramp whose first two operands (presumably base and stride) have
// different dtypes (int vs. float) must be rejected by the verifier.
TEST(IRVerifier, Ramp) {
  KernelScope kernel_scope;
  Var* int_var = new Var("i", kInt);
  Var* float_var = new Var("j", kFloat);
  Ramp* mixed_ramp = new Ramp(int_var, float_var, 4);
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks)
  EXPECT_ANY_THROW(verify(mixed_ramp));
}
// Index validation for Load on a 2-D float buffer: mixed integral index
// dtypes are accepted, while float indices and a multi-lane index inside a
// multi-index access are rejected.
TEST(IRVerifier, Load) {
  KernelScope kernel_scope;
  Var* int_idx = new Var("i", kInt);
  Var* long_idx = new Var("j", kLong);
  Var* float_idx = new Var("k", kFloat);
  Buf* buf = new Buf("b", {new IntImm(10), new IntImm(20)}, kFloat);

  // Accepted: kInt and kLong indices may be mixed.
  Load* mixed_int_load = new Load(buf, {int_idx, long_idx});
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks)
  EXPECT_NO_THROW(verify(mixed_int_load));

  // Rejected: floating-point indices.
  Load* float_index_load = new Load(buf, {float_idx, float_idx});
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks)
  EXPECT_ANY_THROW(verify(float_index_load));

  // Rejected: multi-lane indices are only allowed in flattened (single-index)
  // accesses, and this access has two indices.
  Ramp* lane_index = new Ramp(int_idx, new IntImm(1), 4);
  Load* multilane_load = new Load(buf, {int_idx, lane_index});
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks)
  EXPECT_ANY_THROW(verify(multilane_load));
}
// Malformed IfThenElse expressions that the verifier must reject.
TEST(IRVerifier, IfThenElse) {
  KernelScope kernel_scope;
  Var* int_var = new Var("i", kInt);
  Var* long_var = new Var("j", kLong);
  Var* float_var = new Var("k", kFloat);

  // The condition must be integral; here it is a float.
  IfThenElse* float_cond = new IfThenElse(float_var, int_var, int_var);
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks)
  EXPECT_ANY_THROW(verify(float_cond));

  // The true and false expressions must share a dtype; here kInt vs. kLong.
  IfThenElse* mismatched_branches =
      new IfThenElse(int_var, int_var, long_var);
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks)
  EXPECT_ANY_THROW(verify(mismatched_branches));

  // The condition may not have multiple lanes; here it is a 4-lane Broadcast.
  IfThenElse* vector_cond =
      new IfThenElse(new Broadcast(int_var, 4), int_var, int_var);
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks)
  EXPECT_ANY_THROW(verify(vector_cond));
}
// A For statement built with a null loop variable must be rejected by the
// verifier.
TEST(IRVerifier, For) {
  KernelScope kernel_scope;
  Var* i_var = new Var("i", kInt);
  Var* j_var = new Var("j", kInt);
  Stmt* empty_body = new Block({});
  For* null_var_loop = new For(nullptr, i_var, j_var, empty_body);
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
  EXPECT_ANY_THROW(verify(null_var_loop));
}
// A Stmt may belong to only one Block. Inserting the same Store into a second
// Block must make that second Block fail verification.
TEST(IRVerifier, Block) {
  KernelScope kernel_scope;
  Var* I = new Var("i", kInt);
  Buf* B = new Buf("B", {new IntImm(10)}, kInt);
  {
    Stmt* store = new Store(B, {I}, I);
    Stmt* block1 = new Block({store});
    // Positive control: the first Block holding `store` is expected to verify
    // cleanly. The failure below can then be attributed to the statement
    // being shared, not to the Store itself.
    // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks)
    EXPECT_NO_THROW(verify(block1));
    // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
    Stmt* block2 = new Block({store});
    // A Stmt can't have multiple parents, so inserting it into several blocks
    // is illegal.
    // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks)
    EXPECT_ANY_THROW(verify(block2));
  }
}
// Index and value validation for Store into a 2-D float buffer.
TEST(IRVerifier, Store) {
  KernelScope kernel_scope;
  Var* int_idx = new Var("i", kInt);
  Var* long_idx = new Var("j", kLong);
  Var* float_var = new Var("k", kFloat);
  Buf* buf = new Buf("b", {new IntImm(10), new IntImm(20)}, kFloat);

  // Accepted: kInt and kLong indices may be mixed.
  Store* mixed_int_store = new Store(buf, {int_idx, long_idx}, float_var);
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks)
  EXPECT_NO_THROW(verify(mixed_int_store));

  // Rejected: floating-point indices.
  Store* float_index_store = new Store(buf, {float_var, float_var}, float_var);
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks)
  EXPECT_ANY_THROW(verify(float_index_store));

  // Rejected: multi-lane indices are only allowed in flattened (single-index)
  // accesses, and this access has two indices.
  Ramp* lane_index = new Ramp(int_idx, new IntImm(1), 4);
  Store* multilane_store = new Store(buf, {int_idx, lane_index}, float_var);
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks)
  EXPECT_ANY_THROW(verify(multilane_store));

  // Rejected: stored value is kInt but the buffer is kFloat.
  Store* dtype_mismatch_store = new Store(buf, {int_idx}, int_idx);
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks)
  EXPECT_ANY_THROW(verify(dtype_mismatch_store));
}
} // namespace jit
} // namespace torch