#include <gtest/gtest.h>

#include <test/cpp/tensorexpr/test_base.h>

#include <c10/util/irange.h>
#include <test/cpp/tensorexpr/padded_buffer.h>
#include <test/cpp/tensorexpr/test_utils.h>
#include <torch/csrc/jit/tensorexpr/eval.h>
#include <torch/csrc/jit/tensorexpr/ir.h>
#include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
#include <torch/csrc/jit/tensorexpr/llvm_codegen.h>
#include <torch/csrc/jit/tensorexpr/loopnest.h>
#include <torch/csrc/jit/tensorexpr/tensor.h>

namespace torch {
namespace jit {

using namespace torch::jit::tensorexpr;

extern void checkIR(StmtPtr s, const std::string& pattern);

TEST(BufLiveRange, SingleRangeLine) {
  VarHandle i("i", kInt), j("j", kInt);
  BufHandle a("a", {32}, kFloat);
  BufHandle b("b", {32, 32}, kFloat);

  // Construct Stmt:
  // {
  //   for (int i = 0; i < 32; i++) {
  //     a[i] = 0;
  //     for (int j = 0; j < 32; j++) {
  //       a[i] = (a[i]) + (b[i, j]);
  //     }
  //   }
  // }
  StorePtr aInit = Store::make(a, {i}, 0);
  ExprHandle reduce = a.load({i}) + b.load({i, j});
  StorePtr aReduce = Store::make(a, {i}, reduce);
  StmtPtr loop =
      For::make(i, 0, 32, Block::make({aInit, For::make(j, 0, 32, aReduce)}));

  StmtPtr stmt = Block::make({loop});

  // All accesses to 'a' are inside statement 0 of the enclosing block, so its
  // live range is [0, 0].
  auto range = BufLiveRange::liveRange(stmt, a.node());
  ASSERT_TRUE(std::get<0>(range) == 0);
  ASSERT_TRUE(std::get<1>(range) == 0);
}

TEST(BufLiveRange, MulRangeLine) {
  VarHandle i("i", kInt);
  BufHandle a("a", {32}, kFloat);
  BufHandle b("b", {32}, kFloat);

  // Construct Stmt:
  // {
  //   for (int i = 0; i < 32; i++) {
  //     if (i < 10 ? 1 : 0) {
  //       a[i] = i + i;
  //       b[i] = i * i;
  //     }
  //   }
  //   for (int i = 0; i < 32; i++) {
  //     if (i > 10 ? 1 : 0) {
  //       a[i] = i * i;
  //       b[i] = i + i;
  //     }
  //   }
  // }
  StorePtr aStore_1 = Store::make(a, {i}, i + i);
  StorePtr bStore_1 = Store::make(b, {i}, i * i);
  StmtPtr loop_1 = For::make(
      i, 0, 32, Cond::make(i < 10, Block::make({aStore_1, bStore_1}), NULL));

  StorePtr aStore_2 = Store::make(a, {i}, i * i);
  StorePtr bStore_2 = Store::make(b, {i}, i + i);
  StmtPtr loop_2 = For::make(
      i, 0, 32, Cond::make(i > 10, Block::make({aStore_2, bStore_2}), NULL));

  StmtPtr stmt = Block::make({loop_1, loop_2});

  // Both buffers are written in statement 0 and statement 1 of the enclosing
  // block, so their live ranges span [0, 1].
  auto range_a = BufLiveRange::liveRange(stmt, a.node());
  ASSERT_TRUE(std::get<0>(range_a) == 0);
  ASSERT_TRUE(std::get<1>(range_a) == 1);

  auto range_b = BufLiveRange::liveRange(stmt, b.node());
  ASSERT_TRUE(std::get<0>(range_b) == 0);
  ASSERT_TRUE(std::get<1>(range_b) == 1);
}

TEST(MemPlanning, MemReuseWithTypeCast) {
  int M = 4;
  int N = 4;
  int K = 4;

  BufHandle AP("A", {M, K}, kFloat);
  BufHandle BP("B", {K, N}, kFloat);

  Tensor CT = Reduce(
      "gemm",
      {M, N},
      Sum(),
      [&](const ExprHandle& m, const ExprHandle& n, const ExprHandle& k) {
        return AP.load(m, k) * BP.load(k, n);
      },
      {K});
  Tensor DT =
      Compute("relu", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) {
        return CompareSelect::make(
            CT.load(m, n), 0.0f, 0.0f, CT.load(m, n), kLT);
      });
  Tensor ET =
      Compute("E", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) {
        return Cast::make(kQUInt8, DT.load(m, n) + DT.load(m, n));
      });
  Tensor FT =
      Compute("F", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) {
        return ET.load(m, n);
      });
  StmtPtr stmt =
      tensorexpr::Block::make({CT.stmt(), DT.stmt(), ET.stmt(), FT.stmt()});

  // Constructed stmt:
  // Intermediate buffers and their liveness ranges: gemm [0, 1], relu [1, 2],
  // E [2, 3]. The dimensions of 'gemm' and 'E' are the same but their types
  // are different: 'E' type quint8 < 'gemm' type float. We'll reuse 'gemm'
  // for 'E' with typecasting.
  //{
  //  for (int i = 0; i < 4; i++) {
  //    for (int i_1 = 0; i_1 < 4; i_1++) {
  //      gemm[i, i_1] = float(0);
  //      for (int i_2 = 0; i_2 < 4; i_2++) {
  //        gemm[i, i_1] = ReduceOp((gemm[i, i_1]) + (A[i, i_2]) * (B[i_2,
  //        i_1]), reduce_args={i_2});
  //      }
  //    }
  //  }
  //  for (int i_3 = 0; i_3 < 4; i_3++) {
  //    for (int i_4 = 0; i_4 < 4; i_4++) {
  //      relu[i_3, i_4] = (gemm[i_3, i_4])<0.f ?
  //      0.f : (gemm[i_3, i_4]);
  //    }
  //  }
  //  for (int i_5 = 0; i_5 < 4; i_5++) {
  //    for (int i_6 = 0; i_6 < 4; i_6++) {
  //      E[i_5, i_6] = quint8((relu[i_5, i_6]) + (relu[i_5, i_6]));
  //    }
  //  }
  //  for (int i_7 = 0; i_7 < 4; i_7++) {
  //    for (int i_8 = 0; i_8 < 4; i_8++) {
  //      F[i_7, i_8] = E[i_7, i_8];
  //    }
  //  }
  //}

  LoopNest l(stmt, {FT.buf()});
  l.prepareForCodegen();
  SimpleIREvaluator cg(Stmt::clone(l.root_stmt()), {AP, BP, FT});

  checkIR(cg.stmt(), R"IR(
# CHECK: Allocate(gemm); // dtype=float, dims=[4, 4]
# CHECK: Allocate(relu); // dtype=float, dims=[4, 4]
# CHECK: Alias(E,gemm);
# CHECK: Free(relu);
# CHECK: Free(gemm))IR");

  PaddedBuffer<float> a_v(M, K, "a");
  PaddedBuffer<float> b_v(K, N, "b");
  PaddedBuffer<uint8_t> o1(M, N, "e_before");
  PaddedBuffer<uint8_t> o2(M, N, "e_after");

  for (const auto m : c10::irange(M)) {
    for (const auto k : c10::irange(K)) {
      a_v(m, k) = at::randn({1}).item().to<float>();
    }
  }
  for (const auto k : c10::irange(K)) {
    for (const auto n : c10::irange(N)) {
      b_v(k, n) = at::randn({1}).item().to<float>();
    }
  }

  cg.call({a_v, b_v, o1});

#ifdef TORCH_ENABLE_LLVM
  LLVMCodeGen cg_llvm(Stmt::clone(l.root_stmt()), {AP, BP, FT});

  checkIR(cg_llvm.stmt(), R"IR(
# CHECK: Allocate(gemm); // dtype=float, dims=[4, 4]
# CHECK: Allocate(relu); // dtype=float, dims=[4, 4]
# CHECK: Alias(E,gemm);
# CHECK: Free(relu);
# CHECK: Free(gemm))IR");

  cg_llvm.call({a_v, b_v, o2});

  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  ExpectAllNear(o1, o2, 1e-5);
#endif
}

TEST(MemPlanning, NoMemReuseForLargerType) {
  int M = 4;
  int N = 4;
  int K = 4;

  BufHandle AP("A", {M, K}, kShort);
  BufHandle BP("B", {K, N}, kShort);

  Tensor CT = Reduce(
      "gemm",
      {M, N},
      Sum(),
      [&](const ExprHandle& m, const ExprHandle& n, const ExprHandle& k) {
        return AP.load(m, k) * BP.load(k, n);
      },
      {K});
  auto zero = Cast::make(CT.buf()->dtype(), 0);
  Tensor DT =
      Compute("relu", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) {
        return CompareSelect::make(
            CT.load(m, n), zero, zero, CT.load(m, n), kLT);
      });
  Tensor ET =
      Compute("E", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) {
        return Cast::make(kFloat, DT.load(m, n) + DT.load(m, n));
      });
  Tensor FT =
      Compute("F", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) {
        return ET.load(m, n);
      });
  StmtPtr stmt =
      tensorexpr::Block::make({CT.stmt(), DT.stmt(), ET.stmt(), FT.stmt()});

  // Constructed stmt:
  // Intermediate buffers and their liveness ranges: gemm [0, 1], relu [1, 2],
  // E [2, 3]. The dimensions of 'gemm' and 'E' are the same but their types
  // are different: 'E' type float > 'gemm' type int16. We won't reuse 'gemm'
  // for 'E'.
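  // In other words, reuse is only attempted when the new buffer fits in the
  // old allocation: with identical dims, a 4-byte float element of 'E' does
  // not fit in a 2-byte int16 slot of 'gemm', whereas the quint8 'E' in the
  // previous test did fit in a float 'gemm'.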
  //{
  //  for (int i = 0; i < 4; i++) {
  //    for (int i_1 = 0; i_1 < 4; i_1++) {
  //      gemm[i, i_1] = int16_t(0);
  //      for (int i_2 = 0; i_2 < 4; i_2++) {
  //        gemm[i, i_1] = ReduceOp((gemm[i, i_1]) + (A[i, i_2]) * (B[i_2,
  //        i_1]), reduce_args={i_2});
  //      }
  //    }
  //  }
  //  for (int i_3 = 0; i_3 < 4; i_3++) {
  //    for (int i_4 = 0; i_4 < 4; i_4++) {
  //      relu[i_3, i_4] = (gemm[i_3, i_4]) ...
  //  ...
  //}

  LoopNest l(stmt, {FT.buf()});
  l.prepareForCodegen();
  SimpleIREvaluator cg(Stmt::clone(l.root_stmt()), {AP, BP, FT});

  checkIR(cg.stmt(), R"IR(
# CHECK: Allocate(gemm); // dtype=int16_t, dims=[4, 4]
# CHECK: Allocate(relu); // dtype=int16_t, dims=[4, 4]
# CHECK: Allocate(E); // dtype=float, dims=[4, 4]
# CHECK: Free(E);
# CHECK: Free(relu);
# CHECK: Free(gemm))IR");

  PaddedBuffer<int16_t> a_v(M, K, "a");
  PaddedBuffer<int16_t> b_v(K, N, "b");
  PaddedBuffer<float> o1(M, N, "e_before");
  PaddedBuffer<float> o2(M, N, "e_after");

  for (const auto m : c10::irange(M)) {
    for (const auto k : c10::irange(K)) {
      a_v(m, k) = at::randn({1}).item().to<int16_t>();
    }
  }
  for (const auto k : c10::irange(K)) {
    for (const auto n : c10::irange(N)) {
      b_v(k, n) = at::randn({1}).item().to<int16_t>();
    }
  }

  cg.call({a_v, b_v, o1});

#ifdef TORCH_ENABLE_LLVM
  LLVMCodeGen cg_llvm(Stmt::clone(l.root_stmt()), {AP, BP, FT});

  checkIR(cg_llvm.stmt(), R"IR(
# CHECK: Allocate(gemm); // dtype=int16_t, dims=[4, 4]
# CHECK: Allocate(relu); // dtype=int16_t, dims=[4, 4]
# CHECK: Allocate(E); // dtype=float, dims=[4, 4]
# CHECK: Free(E);
# CHECK: Free(relu);
# CHECK: Free(gemm))IR");

  cg_llvm.call({a_v, b_v, o2});

  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  ExpectAllNear(o1, o2, 1e-5);
#endif
}

TEST(MemPlanning, SameBufSizeMemReuse) {
  int M = 1024;
  int N = 1024;
  int K = 2048;

  BufHandle AP("A", {M, K}, kFloat);
  BufHandle BP("B", {K, N}, kFloat);

  Tensor CT = Reduce(
      "gemm",
      {M, N},
      Sum(),
      [&](const ExprHandle& m, const ExprHandle& n, const ExprHandle& k) {
        return AP.load(m, k) * BP.load(k, n);
      },
      {K});
  Tensor DT =
      Compute("relu", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) {
        auto zero = Cast::make(CT.buf()->dtype(), 0);
        return CompareSelect::make(
            CT.load(m, n), zero, zero, CT.load(m, n), kLT);
      });
  Tensor ET =
      Compute("add", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) {
        return DT.load(m, n) + DT.load(m, n);
      });
  Tensor FT =
      Compute("mul", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) {
        return ET.load(m, n) * ET.load(m, n);
      });
  auto stmt = Block::make({CT.stmt(), DT.stmt(), ET.stmt(), FT.stmt()});

  // Constructed stmt:
  // Intermediate buffers and their liveness ranges: gemm [0, 1], relu [1, 2],
  // add [2, 3]. Buffer 'gemm' and 'add' are the same size; we'll reuse 'gemm'
  // for 'add'.
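  // 'gemm' is last read while computing 'relu' (its range ends at statement
  // 1), so it is already dead when 'add' (range [2, 3]) is materialized; the
  // non-overlapping live ranges plus the matching size let 'add' take over
  // gemm's allocation.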
  //{
  //  for (int M = 0; M < 1024; M++) {
  //    for (int N = 0; N < 1024; N++) {
  //      gemm[M, N] = float(0);
  //      for (int K = 0; K < 2048; K++) {
  //        gemm[M, N] = ReduceOp((gemm[M, N]) + (A[M, K]) * (B[K, N]),
  //        reduce_args={K});
  //      }
  //    }
  //  }
  //  for (int M_1 = 0; M_1 < 1024; M_1++) {
  //    for (int N_1 = 0; N_1 < 1024; N_1++) {
  //      relu[M_1, N_1] = (gemm[M_1, N_1]) ...
  //  ...
  //}

  // ...

  Tensor DT =
      Compute("relu", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) {
        auto zero = Cast::make(CT.buf()->dtype(), 0);
        return CompareSelect::make(
            CT.load(m, n), zero, zero, CT.load(m, n), kLT);
      });
  Tensor ET =
      Compute("add", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) {
        return DT.load(m, n) + DT.load(m, n);
      });
  Tensor FT =
      Compute("mul", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) {
        return ET.load(m, n) * ET.load(m, n);
      });
  Tensor GT =
      Compute("sub", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) {
        return FT.load(m, n) - ET.load(m, n);
      });
  auto stmt =
      Block::make({CT.stmt(), DT.stmt(), ET.stmt(), FT.stmt(), GT.stmt()});

  // Constructed stmt:
  // Intermediate buffers and their liveness ranges: gemm [0, 1], relu [1, 2],
  // add [2, 3], mul [3, 4]. Buffers 'gemm', 'relu', 'add' and 'mul' are the
  // same size; we'll reuse 'gemm' for 'add', and reuse 'relu' for 'mul'.
  //{
  //  for (int M = 0; M < 1024; M++) {
  //    for (int N = 0; N < 1024; N++) {
  //      gemm[M, N] = float(0);
  //      for (int K = 0; K < 2048; K++) {
  //        gemm[M, N] = ReduceOp((gemm[M, N]) + (A[M, K]) * (B[K, N]),
  //        reduce_args={K});
  //      }
  //    }
  //  }
  //  for (int M_1 = 0; M_1 < 1024; M_1++) {
  //    for (int N_1 = 0; N_1 < 1024; N_1++) {
  //      relu[M_1, N_1] = (gemm[M_1, N_1]) ...
  //  ...
  //}

  // ...

  Tensor DT =
      Compute("relu", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) {
        auto zero = Cast::make(CT.buf()->dtype(), 0);
        return CompareSelect::make(
            CT.load(m, n), zero, zero, CT.load(m, n), kLT);
      });
  Tensor ET =
      Compute("add", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) {
        return DT.load(m, n) + DT.load(m, n);
      });
  Tensor FT =
      Compute("mul", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) {
        return ET.load(m, n) * ET.load(m, n);
      });
  Tensor GT =
      Compute("sub", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) {
        return FT.load(m, n) - 1;
      });
  Tensor HT =
      Compute("div", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) {
        return GT.load(m, n) / 2;
      });
  auto stmt = Block::make(
      {CT.stmt(), DT.stmt(), ET.stmt(), FT.stmt(), GT.stmt(), HT.stmt()});

  // Constructed stmt:
  // Intermediate buffers and their liveness ranges: gemm [0, 1], relu [1, 2],
  // add [2, 3], mul [3, 4], sub [4, 5]. Buffers 'gemm', 'relu', 'add', 'mul'
  // and 'sub' are the same size; we'll reuse 'gemm' for 'add', reuse 'relu'
  // for 'mul', and reuse 'gemm' for 'sub'.
  //{
  //  for (int M = 0; M < 1024; M++) {
  //    for (int N = 0; N < 1024; N++) {
  //      gemm[M, N] = float(0);
  //      for (int K = 0; K < 2048; K++) {
  //        gemm[M, N] = ReduceOp((gemm[M, N]) + (A[M, K]) * (B[K, N]),
  //        reduce_args={K});
  //      }
  //    }
  //  }
  //  for (int M_1 = 0; M_1 < 1024; M_1++) {
  //    for (int N_1 = 0; N_1 < 1024; N_1++) {
  //      relu[M_1, N_1] = (gemm[M_1, N_1]) ...
  //  ...
  //}

  // ...

  Tensor DT =
      Compute("relu", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) {
        auto zero = Cast::make(CT.buf()->dtype(), 0);
        return CompareSelect::make(
            CT.load(m, n), zero, zero, CT.load(m, n), kLT);
      });
  Tensor ET = Compute(
      "add", {M * 2, N * 2}, [&](const ExprHandle& em, const ExprHandle& en) {
        return DT.load(em / 2, en / 2) + DT.load(em / 2, en / 2);
      });
  Tensor FT = Compute(
      "mul", {M * 2, N * 2}, [&](const ExprHandle& fm, const ExprHandle& fn) {
        return ET.load(fm, fn) * ET.load(fm, fn);
      });
  auto stmt = Block::make({CT.stmt(), DT.stmt(), ET.stmt(), FT.stmt()});

  // Constructed stmt:
  // Intermediate buffers and their liveness ranges: gemm [0, 1], relu [1, 2],
  // add [2, 3]. We do not reuse buffer 'gemm' for 'add' because the size of
  // buffer 'gemm' is smaller.
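  // 'add' covers {2*M, 2*N}, i.e. four times the elements of 'gemm' ({M, N}),
  // so it cannot take over gemm's allocation even though their live ranges do
  // not overlap.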
  //{
  //  for (int M = 0; M < 1024; M++) {
  //    for (int N = 0; N < 1024; N++) {
  //      gemm[M, N] = float(0);
  //      for (int K = 0; K < 2048; K++) {
  //        gemm[M, N] = ReduceOp((gemm[M, N]) + (A[M, K]) * (B[K, N]),
  //        reduce_args={K});
  //      }
  //    }
  //  }
  //  for (int M_1 = 0; M_1 < 1024; M_1++) {
  //    for (int N_1 = 0; N_1 < 1024; N_1++) {
  //      relu[M_1, N_1] = (gemm[M_1, N_1])