Summary:
Improve simplification of nested Min and Max patterns.
Specifically, it handles the following pattern simplifications:
* `Max(A, Max(A, Const)) => Max(A, Const)`
* `Max(Min(A, B), Min(A, C)) => Min(A, Max(B, C))`
* `Max(Const, Max(A, OtherConst)) => Max(A, Max(Const, OtherConst))`
- This case can have an arbitrarily long chain of Max ops. For example: `Max(5, Max(x, Max(y, Max(z, 8)))) => Max(Max(Max(x, 8), y), z)`
The analogous simplifications are applied to nested Min patterns (a short sketch of these rewrites follows the commit metadata below).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/44142
Reviewed By: albanD
Differential Revision: D23644486
Pulled By: navahgar
fbshipit-source-id: 42bd241e6c2af820566744c8494e5dee172107f4
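To make the rewrites concrete, here is a minimal sketch of how they surface through the tensorexpr simplifier. It assumes the `Max::make(lhs, rhs, propagate_nans)` / `Min::make(...)` helpers and `IRSimplifier::simplify` from this codebase; the function name below is hypothetical, and the printed forms in the comments are illustrative rather than asserted by this PR.

// Illustrative sketch only; not part of the test file below.
#include <torch/csrc/jit/tensorexpr/ir_simplifier.h>

using namespace torch::jit::tensorexpr;

void sketchNestedMinMaxSimplification() {
  KernelScope kernel_scope; // IR nodes are allocated in a kernel arena
  VarHandle x("x", kInt);
  VarHandle y("y", kInt);
  VarHandle z("z", kInt);

  // Max(A, Max(A, Const)) => Max(A, Const):
  ExprHandle e1 = Max::make(x, Max::make(x, ExprHandle(8), true), true);
  ExprHandle s1 = IRSimplifier::simplify(e1);
  // s1 should print as "Max(x, 8, 1)" (the trailing 1 is propagate_nans).

  // Max(Min(A, B), Min(A, C)) => Min(A, Max(B, C)):
  ExprHandle e2 =
      Max::make(Min::make(x, y, true), Min::make(x, z, true), true);
  ExprHandle s2 = IRSimplifier::simplify(e2);

  // Constants fold across an arbitrarily long chain of Max ops:
  // Max(5, Max(x, Max(y, Max(z, 8)))) => Max(Max(Max(x, 8), y), z),
  // since Max(5, 8) == 8.
  ExprHandle e3 = Max::make(
      ExprHandle(5),
      Max::make(
          x, Max::make(y, Max::make(z, ExprHandle(8), true), true), true),
      true);
  ExprHandle s3 = IRSimplifier::simplify(e3);
}

The `propagate_nans` flag is printed as the third argument of `Max`/`Min` by the IR printer, which is why expected strings such as "Max(Y, Z, 1)" appear in the tests below.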
829 lines
30 KiB
C++
#include <test/cpp/tensorexpr/test_base.h>
#include <memory>
#include <sstream>
#include <stdexcept>
#include <unordered_map>

#include <test/cpp/tensorexpr/padded_buffer.h>
#include <torch/csrc/jit/tensorexpr/analysis.h>
#include <torch/csrc/jit/tensorexpr/bounds_inference.h>
#include <torch/csrc/jit/tensorexpr/buffer.h>
#include <torch/csrc/jit/tensorexpr/eval.h>
#include <torch/csrc/jit/tensorexpr/function.h>
#include <torch/csrc/jit/tensorexpr/ir.h>
#include <torch/csrc/jit/tensorexpr/ir_printer.h>
#include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
#include <torch/csrc/jit/tensorexpr/loopnest.h>
#include <torch/csrc/jit/tensorexpr/tensor.h>

namespace torch {
namespace jit {

using namespace torch::jit::tensorexpr;

static void verifyConstBounds(
    const TensorAccessBoundsInfo& access_info,
    const std::vector<std::pair<int, int>>& ref) {
  size_t ndim = ref.size();
  ASSERT_EQ(access_info.start.size(), ndim);
  ASSERT_EQ(access_info.stop.size(), ndim);
  for (size_t i = 0; i < ndim; i++) {
    if (ref[i].first >= 0) { // Negative values are used to skip the check
      ASSERT_TRUE(access_info.start[i]->isConstant());
      int start_i = immediateAs<int>(access_info.start[i]);
      ASSERT_EQ(start_i, ref[i].first);
    }
    if (ref[i].second >= 0) {
      ASSERT_TRUE(access_info.stop[i]->isConstant());
      int stop_i = immediateAs<int>(access_info.stop[i]);
      ASSERT_EQ(stop_i, ref[i].second);
    }
  }
}

void testBoundsInference_1() {
  // Verify that bounds inference works for the following example:
  // for i in 0..100:
  //   b[i] = a[i]
  // For this loop bounds inference should yield the following:
  // {{b, kStore, 0, 99}, {a, kLoad, 0, 99}}
  KernelScope kernel_scope;
  ExprHandle n(100);
  Buffer a(BufHandle("a", {n}, kFloat));
  Tensor* b =
      Compute("b", {{n, "i"}}, [&](const VarHandle& i) { return a(i); });
  LoopNest l({b});
  auto bounds_info = inferBounds(l.root_stmt());

  // We should have two entries: one for 'b' and one for 'a'.
  ASSERT_EQ(bounds_info.size(), 2);
  ASSERT_EQ(bounds_info.at(a.data()).size(), 1);
  ASSERT_EQ(bounds_info.at(a.data())[0].kind, kLoad);
  verifyConstBounds(bounds_info.at(a.data())[0], {{0, 99}});

  ASSERT_EQ(bounds_info.at(b->buf()).size(), 1);
  ASSERT_EQ(bounds_info.at(b->buf())[0].kind, kStore);
  verifyConstBounds(bounds_info.at(b->buf())[0], {{0, 99}});
}

void testBoundsInference_2() {
  // Verify that bounds inference works for the following example:
  // for i in 0..n:
  //   b[i] = a[i]
  // For this loop bounds inference should yield the following:
  // {{b, kStore, 0, n-1}, {a, kLoad, 0, n-1}}
  KernelScope kernel_scope;
  VarHandle n("n", kInt);
  Buffer a(BufHandle("a", {n}, kFloat));
  Tensor* b =
      Compute("b", {{n, "i"}}, [&](const VarHandle& i) { return a(i); });
  LoopNest l({b});
  auto bounds_info = inferBounds(l.root_stmt());

  // We should have two entries: one for 'b' and one for 'a'.
  ASSERT_EQ(bounds_info.size(), 2);
  ASSERT_EQ(bounds_info.at(a.data()).size(), 1);
  ASSERT_EQ(bounds_info.at(a.data())[0].kind, kLoad);
  verifyConstBounds(bounds_info.at(a.data())[0], {{0, -1}});

  ASSERT_EQ(bounds_info.at(b->buf()).size(), 1);
  ASSERT_EQ(bounds_info.at(b->buf())[0].kind, kStore);
  verifyConstBounds(bounds_info.at(b->buf())[0], {{0, -1}});
}

void testBoundsInference_3() {
  // Verify that bounds inference works for the following example:
  // for i in 0..100:
  //   b[i] = a[i] * a[i+10]
  // For this loop bounds inference should yield the following:
  // {{b, kStore, 0, 99}, {a, kLoad, 0, 109}}
  KernelScope kernel_scope;
  ExprHandle n(100);
  Buffer a(BufHandle("a", {n + 10}, kFloat));
  Tensor* b = Compute(
      "b", {{n, "i"}}, [&](const VarHandle& i) { return a(i) * a(i + 10); });
  LoopNest l({b});
  auto bounds_info = inferBounds(l.root_stmt());

  // We should have two entries: one for 'b' and one for 'a'.
  ASSERT_EQ(bounds_info.size(), 2);
  ASSERT_EQ(bounds_info.at(a.data()).size(), 1);
  ASSERT_EQ(bounds_info.at(a.data())[0].kind, kLoad);
  verifyConstBounds(bounds_info.at(a.data())[0], {{0, 109}});

  ASSERT_EQ(bounds_info.at(b->buf()).size(), 1);
  ASSERT_EQ(bounds_info.at(b->buf())[0].kind, kStore);
  verifyConstBounds(bounds_info.at(b->buf())[0], {{0, 99}});
}

void testBoundsInference_4() {
  // Verify that bounds inference works for the following example:
  //
  // for y in 0..200:
  //   for x in 0..320:
  //     b[y,x] = x*y
  // for y in 0..200:
  //   for x in 0..320:
  //     c[y,x] = a[y,x] * b[y,x]
  KernelScope kernel_scope;
  ExprHandle W(320);
  ExprHandle H(200);
  Buffer a(BufHandle("a", {H, W}, kFloat));
  Tensor* b = Compute(
      "b", {{H, "y"}, {W, "x"}}, [&](const VarHandle& y, const VarHandle& x) {
        return x * y;
      });
  Tensor* c = Compute(
      "c", {{H, "y"}, {W, "x"}}, [&](const VarHandle& y, const VarHandle& x) {
        return a(y, x) * b->call(y, x);
      });
  LoopNest l({c});
  std::vector<For*> loops = l.getLoopStmtsFor(c);
  Stmt* body = l.getLoopBodyFor(c);
  {
    // Infer bounds on the top-level loop scope
    auto bounds_info = inferBounds(loops[0]);
    ASSERT_EQ(bounds_info.size(), 3);

    ASSERT_EQ(bounds_info.at(a.data()).size(), 1);
    ASSERT_EQ(bounds_info.at(a.data())[0].kind, kLoad);
    verifyConstBounds(bounds_info.at(a.data())[0], {{0, 199}, {0, 319}});

    ASSERT_EQ(bounds_info.at(b->buf()).size(), 1);
    ASSERT_EQ(bounds_info.at(b->buf())[0].kind, kLoad);
    verifyConstBounds(bounds_info.at(b->buf())[0], {{0, 199}, {0, 319}});

    ASSERT_EQ(bounds_info.at(c->buf()).size(), 1);
    ASSERT_EQ(bounds_info.at(c->buf())[0].kind, kStore);
    verifyConstBounds(bounds_info.at(c->buf())[0], {{0, 199}, {0, 319}});
  }
  {
    // Infer bounds on the inner loop scope
    auto bounds_info = inferBounds(loops[1]);
    ASSERT_EQ(bounds_info.size(), 3);

    ASSERT_EQ(bounds_info.at(a.data()).size(), 1);
    ASSERT_EQ(bounds_info.at(a.data())[0].kind, kLoad);
    verifyConstBounds(bounds_info.at(a.data())[0], {{-1, -1}, {0, 319}});

    ASSERT_EQ(bounds_info.at(b->buf()).size(), 1);
    ASSERT_EQ(bounds_info.at(b->buf())[0].kind, kLoad);
    verifyConstBounds(bounds_info.at(b->buf())[0], {{-1, -1}, {0, 319}});

    ASSERT_EQ(bounds_info.at(c->buf()).size(), 1);
    ASSERT_EQ(bounds_info.at(c->buf())[0].kind, kStore);
    verifyConstBounds(bounds_info.at(c->buf())[0], {{-1, -1}, {0, 319}});
  }
  {
    // Infer bounds on the inner loop body's scope
    auto bounds_info = inferBounds(body);
    ASSERT_EQ(bounds_info.size(), 3);

    ASSERT_EQ(bounds_info.at(a.data()).size(), 1);
    ASSERT_EQ(bounds_info.at(a.data())[0].kind, kLoad);
    verifyConstBounds(bounds_info.at(a.data())[0], {{-1, -1}, {-1, -1}});

    ASSERT_EQ(bounds_info.at(b->buf()).size(), 1);
    ASSERT_EQ(bounds_info.at(b->buf())[0].kind, kLoad);
    verifyConstBounds(bounds_info.at(b->buf())[0], {{-1, -1}, {-1, -1}});

    ASSERT_EQ(bounds_info.at(c->buf()).size(), 1);
    ASSERT_EQ(bounds_info.at(c->buf())[0].kind, kStore);
    verifyConstBounds(bounds_info.at(c->buf())[0], {{-1, -1}, {-1, -1}});
  }
}

void testBoundsInference_5() {
  // Verify that bounds inference works for the following example:
  // for i in 0..100:
  //   b[i] = a[i]
  //
  // ==> split ==>
  //
  // for i_outer in 0..100/16:
  //   for i_inner in 0..16:
  //     b[i_outer * 16 + i_inner] = a[i_outer * 16 + i_inner]
  // for i_tail in 0..100%16:
  //   b[i_tail + (100/16)*16] = a[i_tail + (100/16)*16];
  KernelScope kernel_scope;
  ExprHandle n(100);
  Buffer a(BufHandle("a", {n}, kFloat));
  Tensor* b =
      Compute("b", {{n, "i"}}, [&](const VarHandle& i) { return a(i); });
  LoopNest l({b});

  For* outer;
  For* inner;
  For* tail;
  std::vector<For*> loops = l.getLoopStmtsFor(b);
  l.splitWithTail(loops[0], 16, &outer, &inner, &tail);

  {
    // Verify inferred bounds for the outer loop
    auto bounds_info = inferBounds(outer);
    ASSERT_EQ(bounds_info.size(), 2);

    ASSERT_EQ(bounds_info.at(a.data()).size(), 1);
    ASSERT_EQ(bounds_info.at(a.data())[0].kind, kLoad);
    verifyConstBounds(bounds_info.at(a.data())[0], {{0, 95}});

    ASSERT_EQ(bounds_info.at(b->buf()).size(), 1);
    ASSERT_EQ(bounds_info.at(b->buf())[0].kind, kStore);
    verifyConstBounds(bounds_info.at(b->buf())[0], {{0, 95}});
  }
  {
    // Verify inferred bounds for the tail loop
    auto bounds_info = inferBounds(tail);
    ASSERT_EQ(bounds_info.size(), 2);

    ASSERT_EQ(bounds_info.at(a.data()).size(), 1);
    ASSERT_EQ(bounds_info.at(a.data())[0].kind, kLoad);
    verifyConstBounds(bounds_info.at(a.data())[0], {{96, 99}});

    ASSERT_EQ(bounds_info.at(b->buf()).size(), 1);
    ASSERT_EQ(bounds_info.at(b->buf())[0].kind, kStore);
    verifyConstBounds(bounds_info.at(b->buf())[0], {{96, 99}});
  }
}

void testBoundsInference_6() {
  // Verify that bounds inference works for the following example:
  //
  // for y in 0..200:
  //   for x in 0..320:
  //     b[y,x] = x*y
  // for y in 0..20:
  //   for x in 0..32:
  //     c[y,x] = a[y+100,x+100] * b[y*2,x*5]
  KernelScope kernel_scope;
  ExprHandle W(320);
  ExprHandle H(200);
  ExprHandle CW(32);
  ExprHandle CH(20);
  Buffer a(BufHandle("a", {H, W}, kFloat));
  Tensor* b = Compute(
      "b", {{H, "y"}, {W, "x"}}, [&](const VarHandle& y, const VarHandle& x) {
        return x * y;
      });
  Tensor* c = Compute(
      "c", {{CH, "y"}, {CW, "x"}}, [&](const VarHandle& y, const VarHandle& x) {
        return a(y + 100, x + 100) * b->call(y * 2, x * 5);
      });
  LoopNest l({c});
  std::vector<For*> loops = l.getLoopStmtsFor(c);
  Stmt* body = l.getLoopBodyFor(c);
  {
    // Infer bounds on the top-level loop scope
    auto bounds_info = inferBounds(loops[0]);
    ASSERT_EQ(bounds_info.size(), 3);

    ASSERT_EQ(bounds_info.at(a.data()).size(), 1);
    ASSERT_EQ(bounds_info.at(a.data())[0].kind, kLoad);
    verifyConstBounds(bounds_info.at(a.data())[0], {{100, 119}, {100, 131}});

    ASSERT_EQ(bounds_info.at(b->buf()).size(), 1);
    ASSERT_EQ(bounds_info.at(b->buf())[0].kind, kLoad);
    verifyConstBounds(bounds_info.at(b->buf())[0], {{0, 38}, {0, 155}});

    ASSERT_EQ(bounds_info.at(c->buf()).size(), 1);
    ASSERT_EQ(bounds_info.at(c->buf())[0].kind, kStore);
    verifyConstBounds(bounds_info.at(c->buf())[0], {{0, 19}, {0, 31}});
  }
  {
    // Infer bounds on the inner loop scope
    auto bounds_info = inferBounds(loops[1]);
    ASSERT_EQ(bounds_info.size(), 3);

    ASSERT_EQ(bounds_info.at(a.data()).size(), 1);
    ASSERT_EQ(bounds_info.at(a.data())[0].kind, kLoad);
    verifyConstBounds(bounds_info.at(a.data())[0], {{-1, -1}, {100, 131}});

    ASSERT_EQ(bounds_info.at(b->buf()).size(), 1);
    ASSERT_EQ(bounds_info.at(b->buf())[0].kind, kLoad);
    verifyConstBounds(bounds_info.at(b->buf())[0], {{-1, -1}, {0, 155}});

    ASSERT_EQ(bounds_info.at(c->buf()).size(), 1);
    ASSERT_EQ(bounds_info.at(c->buf())[0].kind, kStore);
    verifyConstBounds(bounds_info.at(c->buf())[0], {{-1, -1}, {0, 31}});
  }
  {
    // Infer bounds on the inner loop body's scope
    auto bounds_info = inferBounds(body);
    ASSERT_EQ(bounds_info.size(), 3);

    ASSERT_EQ(bounds_info.at(a.data()).size(), 1);
    ASSERT_EQ(bounds_info.at(a.data())[0].kind, kLoad);
    verifyConstBounds(bounds_info.at(a.data())[0], {{-1, -1}, {-1, -1}});

    ASSERT_EQ(bounds_info.at(b->buf()).size(), 1);
    ASSERT_EQ(bounds_info.at(b->buf())[0].kind, kLoad);
    verifyConstBounds(bounds_info.at(b->buf())[0], {{-1, -1}, {-1, -1}});

    ASSERT_EQ(bounds_info.at(c->buf()).size(), 1);
    ASSERT_EQ(bounds_info.at(c->buf())[0].kind, kStore);
    verifyConstBounds(bounds_info.at(c->buf())[0], {{-1, -1}, {-1, -1}});
  }
}

void testBoundsInferenceNonOverlapping() {
  KernelScope kernel_scope;
  ExprHandle H(3);
  Buffer a(BufHandle("a", {10}, kFloat));
  Tensor* b =
      Compute("b", {{H, "x"}}, [&](const VarHandle& x) { return a(x); });
  Tensor* c = Compute(
      "c", {{H, "x"}}, [&](const VarHandle& x) { return a(x + H + 1); });
  LoopNest l({b, c});
  std::vector<For*> loops = NodeFinder<For>::find(l.root_stmt());

  {
    // Infer bounds on the top-level loop scope
    auto bounds_info = inferBounds(loops[0]);
    ASSERT_EQ(bounds_info.size(), 2);

    // reads from a[0:2], writes to b[0:2]
    ASSERT_EQ(bounds_info.at(a.data()).size(), 1);
    ASSERT_EQ(bounds_info.at(a.data())[0].kind, kLoad);
    verifyConstBounds(bounds_info.at(a.data())[0], {{0, 2}});

    ASSERT_EQ(bounds_info.at(b->buf()).size(), 1);
    ASSERT_EQ(bounds_info.at(b->buf())[0].kind, kStore);
    verifyConstBounds(bounds_info.at(b->buf())[0], {{0, 2}});
  }
  {
    // Infer bounds on the second loop scope
    auto bounds_info = inferBounds(loops[1]);
    ASSERT_EQ(bounds_info.size(), 2);

    // reads from a[0+4:2+4], writes to c[0:2]
    ASSERT_EQ(bounds_info.at(a.data()).size(), 1);
    ASSERT_EQ(bounds_info.at(a.data())[0].kind, kLoad);
    verifyConstBounds(bounds_info.at(a.data())[0], {{4, 6}});

    ASSERT_EQ(bounds_info.at(c->buf()).size(), 1);
    ASSERT_EQ(bounds_info.at(c->buf())[0].kind, kStore);
    verifyConstBounds(bounds_info.at(c->buf())[0], {{0, 2}});
  }
  {
    // Infer bounds on the high-level program.
    auto bounds_info = inferBounds(l.root_stmt());
    ASSERT_EQ(bounds_info.size(), 3);

    // Should be the union of the above 2 bounds.
    ASSERT_EQ(bounds_info.at(a.data()).size(), 2);
    ASSERT_EQ(bounds_info.at(a.data())[0].kind, kLoad);
    verifyConstBounds(bounds_info.at(a.data())[0], {{0, 2}});
    ASSERT_EQ(bounds_info.at(a.data())[1].kind, kLoad);
    verifyConstBounds(bounds_info.at(a.data())[1], {{4, 6}});

    ASSERT_EQ(bounds_info.at(b->buf()).size(), 1);
    ASSERT_EQ(bounds_info.at(b->buf())[0].kind, kStore);
    verifyConstBounds(bounds_info.at(b->buf())[0], {{0, 2}});

    ASSERT_EQ(bounds_info.at(c->buf()).size(), 1);
    ASSERT_EQ(bounds_info.at(c->buf())[0].kind, kStore);
    verifyConstBounds(bounds_info.at(c->buf())[0], {{0, 2}});
  }
}

void testBoundsInferenceAdjacent() {
  KernelScope kernel_scope;
  ExprHandle H(6);
  Buffer a(BufHandle("a", {20}, kFloat));
  Tensor* b =
      Compute("b", {{H, "x"}}, [&](const VarHandle& x) { return a(x); });
  Tensor* c =
      Compute("c", {{H, "x"}}, [&](const VarHandle& x) { return a(x + H); });
  LoopNest l({b, c});
  std::vector<For*> loops = NodeFinder<For>::find(l.root_stmt());

  {
    // Infer bounds on the top-level loop scope
    auto bounds_info = inferBounds(loops[0]);
    ASSERT_EQ(bounds_info.size(), 2);

    // reads from a[0:5], writes to b[0:5]
    ASSERT_EQ(bounds_info.at(a.data()).size(), 1);
    ASSERT_EQ(bounds_info.at(a.data())[0].kind, kLoad);
    verifyConstBounds(bounds_info.at(a.data())[0], {{0, 5}});

    ASSERT_EQ(bounds_info.at(b->buf()).size(), 1);
    ASSERT_EQ(bounds_info.at(b->buf())[0].kind, kStore);
    verifyConstBounds(bounds_info.at(b->buf())[0], {{0, 5}});
  }
  {
    // Infer bounds on the second loop scope
    auto bounds_info = inferBounds(loops[1]);
    ASSERT_EQ(bounds_info.size(), 2);

    // reads from a[0+6:5+6], writes to c[0:5]
    ASSERT_EQ(bounds_info.at(a.data()).size(), 1);
    ASSERT_EQ(bounds_info.at(a.data())[0].kind, kLoad);
    verifyConstBounds(bounds_info.at(a.data())[0], {{6, 11}});

    ASSERT_EQ(bounds_info.at(c->buf()).size(), 1);
    ASSERT_EQ(bounds_info.at(c->buf())[0].kind, kStore);
    verifyConstBounds(bounds_info.at(c->buf())[0], {{0, 5}});
  }
  {
    // Infer bounds on the high-level program.
    auto bounds_info = inferBounds(l.root_stmt());
    ASSERT_EQ(bounds_info.size(), 3);

    // Should be the union of the above 2 bounds, but this time the bounds of
    // a can be merged.
    ASSERT_EQ(bounds_info.at(a.data()).size(), 1);
    ASSERT_EQ(bounds_info.at(a.data())[0].kind, kLoad);
    verifyConstBounds(bounds_info.at(a.data())[0], {{0, 11}});

    ASSERT_EQ(bounds_info.at(b->buf()).size(), 1);
    ASSERT_EQ(bounds_info.at(b->buf())[0].kind, kStore);
    verifyConstBounds(bounds_info.at(b->buf())[0], {{0, 5}});

    ASSERT_EQ(bounds_info.at(c->buf()).size(), 1);
    ASSERT_EQ(bounds_info.at(c->buf())[0].kind, kStore);
    verifyConstBounds(bounds_info.at(c->buf())[0], {{0, 5}});
  }
}

void testMergeInferredBounds() {
  KernelScope kernel_scope;
  Buffer a(BufHandle("a", {10}, kFloat));

  // There are seven cases to consider in mergeTensorAccesses(A, B):
  //  * A is lower than B and does not overlap.
  //  * A is higher than B and does not overlap.
  //  * A overlaps B on both ends.
  //  * B overlaps A on both ends.
  //  * A overlaps B on the lower end (equivalent to B overlapping A on the
  //    upper end).
  //  * A overlaps B on the upper end (likewise covers the reverse).
  //  * A and B are the same range.

  BoundsInfo info;
  // Test no overlap, both ways.
  info[a.data()].push_back({kLoad, {new IntImm(1)}, {new IntImm(3)}});
  info[a.data()].push_back({kLoad, {new IntImm(5)}, {new IntImm(7)}});
  info[a.data()].push_back({kLoad, {new IntImm(9)}, {new IntImm(9)}});
  BoundsInfo res = mergeTensorAccesses(info);
  ASSERT_EQ(res.size(), 1);
  ASSERT_EQ(res[a.data()].size(), 3);

  ASSERT_EQ(res.at(a.data())[0].kind, kLoad);
  ASSERT_EQ(res.at(a.data())[1].kind, kLoad);
  ASSERT_EQ(res.at(a.data())[2].kind, kLoad);
  verifyConstBounds(res.at(a.data())[0], {{1, 3}});
  verifyConstBounds(res.at(a.data())[1], {{5, 7}});
  verifyConstBounds(res.at(a.data())[2], {{9, 9}});

  // Test full overlap, A over B.
  info.clear();
  info[a.data()].push_back({kLoad, {new IntImm(1)}, {new IntImm(7)}});
  info[a.data()].push_back({kLoad, {new IntImm(3)}, {new IntImm(6)}});
  res = mergeTensorAccesses(info);
  ASSERT_EQ(res[a.data()].size(), 1);
  verifyConstBounds(res.at(a.data())[0], {{1, 7}});

  // B over A.
  info.clear();
  info[a.data()].push_back({kLoad, {new IntImm(3)}, {new IntImm(6)}});
  info[a.data()].push_back({kLoad, {new IntImm(1)}, {new IntImm(7)}});
  res = mergeTensorAccesses(info);
  ASSERT_EQ(res[a.data()].size(), 1);
  verifyConstBounds(res.at(a.data())[0], {{1, 7}});

  // Test partial overlap on the low end, A over B.
  info.clear();
  info[a.data()].push_back({kLoad, {new IntImm(5)}, {new IntImm(7)}});
  info[a.data()].push_back({kLoad, {new IntImm(3)}, {new IntImm(6)}});
  res = mergeTensorAccesses(info);
  ASSERT_EQ(res[a.data()].size(), 1);
  verifyConstBounds(res.at(a.data())[0], {{3, 7}});

  // Test partial overlap on the high end.
  info.clear();
  info[a.data()].push_back({kLoad, {new IntImm(2)}, {new IntImm(5)}});
  info[a.data()].push_back({kLoad, {new IntImm(4)}, {new IntImm(6)}});
  res = mergeTensorAccesses(info);
  ASSERT_EQ(res[a.data()].size(), 1);
  verifyConstBounds(res.at(a.data())[0], {{2, 6}});

  // Test equal ranges are deduped.
  info.clear();
  info[a.data()].push_back({kLoad, {new IntImm(4)}, {new IntImm(6)}});
  info[a.data()].push_back({kLoad, {new IntImm(4)}, {new IntImm(6)}});
  res = mergeTensorAccesses(info);
  ASSERT_EQ(res[a.data()].size(), 1);
  verifyConstBounds(res.at(a.data())[0], {{4, 6}});
}

void testMergeInferredLoadStoreDiff() {
  KernelScope kernel_scope;
  Buffer a(BufHandle("a", {10}, kFloat));

  // Loads and Stores do not merge:
  BoundsInfo info;
  info[a.data()].push_back({kLoad, {new IntImm(1)}, {new IntImm(7)}});
  info[a.data()].push_back({kStore, {new IntImm(3)}, {new IntImm(9)}});

  BoundsInfo res = mergeTensorAccesses(info);
  ASSERT_EQ(res.size(), 1);
  ASSERT_EQ(res[a.data()].size(), 2);
  ASSERT_EQ(res.at(a.data())[0].kind, kLoad);
  ASSERT_EQ(res.at(a.data())[1].kind, kStore);
  verifyConstBounds(res.at(a.data())[0], {{1, 7}});
  verifyConstBounds(res.at(a.data())[1], {{3, 9}});

  // Do merge around the other kind of access:
  info.clear();
  info[a.data()].push_back({kLoad, {new IntImm(1)}, {new IntImm(3)}});
  info[a.data()].push_back({kStore, {new IntImm(3)}, {new IntImm(4)}});
  info[a.data()].push_back({kLoad, {new IntImm(3)}, {new IntImm(5)}});
  info[a.data()].push_back({kStore, {new IntImm(4)}, {new IntImm(8)}});
  info[a.data()].push_back({kLoad, {new IntImm(5)}, {new IntImm(7)}});
  res = mergeTensorAccesses(info);

  ASSERT_EQ(res[a.data()].size(), 2);
  verifyConstBounds(res.at(a.data())[0], {{1, 7}});
  verifyConstBounds(res.at(a.data())[1], {{3, 8}});
}

void testMergeInferred2DBounds() {
  KernelScope kernel_scope;
  Buffer a(BufHandle("a", {10, 10}, kFloat));

  // Non overlapping in both dimensions:
  BoundsInfo info;
  info[a.data()].push_back(
      {kLoad, {new IntImm(1), new IntImm(1)}, {new IntImm(3), new IntImm(3)}});
  info[a.data()].push_back(
      {kLoad, {new IntImm(5), new IntImm(5)}, {new IntImm(9), new IntImm(9)}});

  BoundsInfo res = mergeTensorAccesses(info);
  ASSERT_EQ(res.size(), 1);
  ASSERT_EQ(res[a.data()].size(), 2);
  ASSERT_EQ(res.at(a.data())[0].kind, kLoad);
  ASSERT_EQ(res.at(a.data())[1].kind, kLoad);
  verifyConstBounds(res.at(a.data())[0], {{1, 3}, {1, 3}});
  verifyConstBounds(res.at(a.data())[1], {{5, 9}, {5, 9}});

  // Overlapping in a single dimension should mean we cannot merge.
  // First dimension:
  info.clear();
  info[a.data()].push_back(
      {kLoad, {new IntImm(1), new IntImm(1)}, {new IntImm(3), new IntImm(3)}});
  info[a.data()].push_back(
      {kLoad, {new IntImm(2), new IntImm(5)}, {new IntImm(9), new IntImm(9)}});

  res = mergeTensorAccesses(info);
  ASSERT_EQ(res[a.data()].size(), 2);
  verifyConstBounds(res.at(a.data())[0], {{1, 3}, {1, 3}});
  verifyConstBounds(res.at(a.data())[1], {{2, 9}, {5, 9}});

  // Second dimension:
  info.clear();
  info[a.data()].push_back(
      {kLoad, {new IntImm(1), new IntImm(1)}, {new IntImm(3), new IntImm(3)}});
  info[a.data()].push_back(
      {kLoad, {new IntImm(5), new IntImm(2)}, {new IntImm(9), new IntImm(9)}});

  res = mergeTensorAccesses(info);
  ASSERT_EQ(res[a.data()].size(), 2);
  verifyConstBounds(res.at(a.data())[0], {{1, 3}, {1, 3}});
  verifyConstBounds(res.at(a.data())[1], {{5, 9}, {2, 9}});

  // Overlapping in both dimensions:
  // {1-6, 1-3} | {4-9, 2-7} => {1-9, 1-7}
  // TODO: this will overestimate and we should fix it.
  info.clear();
  info[a.data()].push_back(
      {kLoad, {new IntImm(1), new IntImm(1)}, {new IntImm(6), new IntImm(3)}});
  info[a.data()].push_back(
      {kLoad, {new IntImm(4), new IntImm(2)}, {new IntImm(9), new IntImm(7)}});

  res = mergeTensorAccesses(info);
  ASSERT_EQ(res[a.data()].size(), 1);
  verifyConstBounds(res.at(a.data())[0], {{1, 9}, {1, 7}});
}

void testMergeAdjacentBounds() {
  KernelScope kernel_scope;
  Buffer a(BufHandle("a", {10}, kFloat));

  // Adjacent but not overlapping bounds can be merged.
  // e.g. {1-4} | {5-9} => {1-9}
  BoundsInfo info;
  info[a.data()].push_back({kLoad, {new IntImm(1)}, {new IntImm(4)}});
  info[a.data()].push_back({kLoad, {new IntImm(5)}, {new IntImm(9)}});
  BoundsInfo res = mergeTensorAccesses(info);
  ASSERT_EQ(res[a.data()].size(), 1);
  verifyConstBounds(res.at(a.data())[0], {{1, 9}});

  // And on the other side:
  info.clear();
  info[a.data()].push_back({kLoad, {new IntImm(5)}, {new IntImm(9)}});
  info[a.data()].push_back({kLoad, {new IntImm(1)}, {new IntImm(4)}});
  res = mergeTensorAccesses(info);
  ASSERT_EQ(res[a.data()].size(), 1);
  verifyConstBounds(res.at(a.data())[0], {{1, 9}});

  // A gap of even one element is enough to prevent merging:
  info.clear();
  info[a.data()].push_back({kLoad, {new IntImm(1)}, {new IntImm(4)}});
  info[a.data()].push_back({kLoad, {new IntImm(6)}, {new IntImm(9)}});
  res = mergeTensorAccesses(info);
  ASSERT_EQ(res[a.data()].size(), 2);
  verifyConstBounds(res.at(a.data())[0], {{1, 4}});
  verifyConstBounds(res.at(a.data())[1], {{6, 9}});
}

std::pair<std::string, std::string> boundAsStringPair(
    TensorAccessBoundsInfo& info,
    size_t idx = 0) {
  std::ostringstream start, stop;
  start << *info.start[idx];
  stop << *info.stop[idx];
  return {start.str(), stop.str()};
}

void testMergeSymbolicBounds() {
  KernelScope kernel_scope;
  Buffer a(BufHandle("a", {10}, kFloat));
  VarHandle W("W", kInt);
  VarHandle X("X", kInt);
  VarHandle Y("Y", kInt);
  VarHandle Z("Z", kInt);

  // Can do nothing with fully symbolic bounds:
  BoundsInfo info;
  info[a.data()].push_back({kLoad, {W.node()}, {Z.node()}});
  info[a.data()].push_back({kLoad, {X.node()}, {Y.node()}});
  BoundsInfo res = mergeTensorAccesses(info);
  ASSERT_EQ(res[a.data()].size(), 2);

  // Can merge if the difference between bounds is constant and enclosing.
  // {X-Y} | {X-5 - Y+10} => {X-5 - Y+10}
  info.clear();
  info[a.data()].push_back({kLoad, {X.node()}, {Y.node()}});
  info[a.data()].push_back({kLoad,
                            {new Sub(X.node(), new IntImm(5))},
                            {new Add(Y.node(), new IntImm(10))}});
  res = mergeTensorAccesses(info);
  ASSERT_EQ(res[a.data()].size(), 1);

  // Cannot merge otherwise.
  // {X-Y} | {X+5 - Y+10} => could be 2 groups if Y < X+5.
  info.clear();
  info[a.data()].push_back({kLoad, {X.node()}, {Y.node()}});
  info[a.data()].push_back({kLoad,
                            {new Add(X.node(), new IntImm(5))},
                            {new Add(Y.node(), new IntImm(10))}});

  res = mergeTensorAccesses(info);
  ASSERT_EQ(res[a.data()].size(), 2);

  // Can't merge if there's a gap of at least one element:
  info.clear();
  info[a.data()].push_back({kLoad, {X.node()}, {new IntImm(4)}});
  info[a.data()].push_back({kLoad, {new IntImm(6)}, {Y.node()}});
  res = mergeTensorAccesses(info);
  ASSERT_EQ(res[a.data()].size(), 2);

  // Can't merge even though the high of the first bound is above the low of
  // the second: X could be 6 and Y could be 4, so the ranges need not overlap
  // in all cases.
  info.clear();
  info[a.data()].push_back({kLoad, {X.node()}, {new IntImm(6)}});
  info[a.data()].push_back({kLoad, {new IntImm(4)}, {Y.node()}});
  res = mergeTensorAccesses(info);
  ASSERT_EQ(res[a.data()].size(), 2);

  // If either side is equal, they must be overlapping.
  info.clear();
  info[a.data()].push_back({kLoad, {X.node()}, {Z.node()}});
  info[a.data()].push_back({kLoad, {X.node()}, {Y.node()}});
  res = mergeTensorAccesses(info);
  ASSERT_EQ(res[a.data()].size(), 1);
  auto pair = boundAsStringPair(res[a.data()][0]);
  ASSERT_EQ(pair.first, "X");
  ASSERT_EQ(pair.second, "Max(Y, Z, 1)");

  info.clear();
  info[a.data()].push_back({kLoad, {X.node()}, {Y.node()}});
  info[a.data()].push_back({kLoad, {Z.node()}, {Y.node()}});
  res = mergeTensorAccesses(info);
  ASSERT_EQ(res[a.data()].size(), 1);
  pair = boundAsStringPair(res[a.data()][0]);
  ASSERT_EQ(pair.first, "Min(X, Z, 1)");
  ASSERT_EQ(pair.second, "Y");

  // If either side is only one apart, they must be adjacent.
  info.clear();
  info[a.data()].push_back(
      {kLoad, {new Add(X.node(), new IntImm(1))}, {Z.node()}});
  info[a.data()].push_back({kLoad, {X.node()}, {Y.node()}});
  res = mergeTensorAccesses(info);
  ASSERT_EQ(res[a.data()].size(), 1);
  pair = boundAsStringPair(res[a.data()][0]);
  ASSERT_EQ(pair.first, "X");
  ASSERT_EQ(pair.second, "Max(Y, Z, 1)");

  info.clear();
  info[a.data()].push_back({kLoad, {X.node()}, {Y.node()}});
  info[a.data()].push_back(
      {kLoad, {Z.node()}, {new Sub(Y.node(), new IntImm(1))}});
  res = mergeTensorAccesses(info);
  ASSERT_EQ(res[a.data()].size(), 1);
  pair = boundAsStringPair(res[a.data()][0]);
  ASSERT_EQ(pair.first, "Min(X, Z, 1)");
  ASSERT_EQ(pair.second, "Y");

  // If either side is 2 apart, they may not be overlapping.
  // In this case, if Y == X+1, they don't overlap.
  info.clear();
  info[a.data()].push_back(
      {kLoad, {new Add(X.node(), new IntImm(2))}, {Z.node()}});
  info[a.data()].push_back(
      {kLoad, {X.node()}, {new Sub(Y.node(), new IntImm(1))}});
  res = mergeTensorAccesses(info);
  ASSERT_EQ(res[a.data()].size(), 2);

  // In this case they may not overlap if X == Y.
  info.clear();
  info[a.data()].push_back({kLoad, {X.node()}, {Y.node()}});
  info[a.data()].push_back(
      {kLoad, {Z.node()}, {new Sub(Y.node(), new IntImm(2))}});
  res = mergeTensorAccesses(info);
  ASSERT_EQ(res[a.data()].size(), 2);
}

void testMergeSymbolicAdjacent() {
  KernelScope kernel_scope;
  Buffer a(BufHandle("a", {10}, kFloat));
  VarHandle X("X", kInt);
  VarHandle Y("Y", kInt);

  BoundsInfo info;
  // Can merge if a range is adjacent:
  // {X-5} | {6-Y} => {X-Y}
  info[a.data()].push_back({kLoad, {X.node()}, {new IntImm(5)}});
  info[a.data()].push_back({kLoad, {new IntImm(6)}, {Y.node()}});
  BoundsInfo res = mergeTensorAccesses(info);

  ASSERT_EQ(res[a.data()].size(), 1);
  auto pair = boundAsStringPair(res[a.data()][0]);
  ASSERT_EQ(pair.first, "X");
  ASSERT_EQ(pair.second, "Y");

  info.clear();
  info[a.data()].push_back({kLoad, {new IntImm(6)}, {Y.node()}});
  info[a.data()].push_back({kLoad, {X.node()}, {new IntImm(5)}});
  res = mergeTensorAccesses(info);

  ASSERT_EQ(res[a.data()].size(), 1);
  pair = boundAsStringPair(res[a.data()][0]);
  ASSERT_EQ(pair.first, "X");
  ASSERT_EQ(pair.second, "Y");

  info.clear();
  info[a.data()].push_back({kLoad, {new IntImm(5)}, {Y.node()}});
  info[a.data()].push_back({kLoad, {X.node()}, {new IntImm(6)}});
  res = mergeTensorAccesses(info);

  ASSERT_EQ(res[a.data()].size(), 1);
  pair = boundAsStringPair(res[a.data()][0]);
  ASSERT_EQ(pair.first, "X");
  ASSERT_EQ(pair.second, "Y");

  info.clear();
  info[a.data()].push_back({kLoad, {X.node()}, {new IntImm(6)}});
  info[a.data()].push_back({kLoad, {new IntImm(5)}, {Y.node()}});
  res = mergeTensorAccesses(info);

  ASSERT_EQ(res[a.data()].size(), 1);
  pair = boundAsStringPair(res[a.data()][0]);
  ASSERT_EQ(pair.first, "X");
  ASSERT_EQ(pair.second, "Y");

  // If either the lower or upper bound is adjacent to the range then they
  // must overlap, even if we don't know the extent.
  info.clear();
  info[a.data()].push_back({kLoad, {new IntImm(6)}, {X.node()}});
  info[a.data()].push_back({kLoad, {new IntImm(5)}, {Y.node()}});
  res = mergeTensorAccesses(info);
  ASSERT_EQ(res[a.data()].size(), 1);
  pair = boundAsStringPair(res[a.data()][0]);
  ASSERT_EQ(pair.first, "5");
  ASSERT_EQ(pair.second, "Max(X, Y, 1)");

  info.clear();
  info[a.data()].push_back({kLoad, {X.node()}, {new IntImm(6)}});
  info[a.data()].push_back({kLoad, {Y.node()}, {new IntImm(5)}});
  res = mergeTensorAccesses(info);
  ASSERT_EQ(res[a.data()].size(), 1);
  pair = boundAsStringPair(res[a.data()][0]);
  ASSERT_EQ(pair.first, "Min(X, Y, 1)");
  ASSERT_EQ(pair.second, "6");
}

} // namespace jit
} // namespace torch