pytorch/test/cpp/tensorexpr/test_registerizer.cpp
Nick Gibson 1390cad2d8 [NNC] Hook up registerizer to Cuda codegen [2/x] (#42878)
Summary:
Insert the registerizer into the Cuda Codegen pass list, to enable scalar replacement and close the gap in simple reduction performance.

First up the good stuff, benchmark before:
```
          Column sum          Caffe2             NNC          Simple          Better
           (10, 100)          5.7917          9.7037          6.9386          6.0448
          (100, 100)          5.9338          14.972          7.1139          6.3254
        (100, 10000)          21.453          741.54          145.74          12.555
        (1000, 1000)          8.0678          122.75          22.833          9.0778

             Row sum          Caffe2             NNC          Simple          Better
           (10, 100)          5.4502          7.9661          6.1469          5.5587
          (100, 100)          5.7613          13.897           21.49          5.5808
        (100, 10000)          21.702          82.398          75.462          22.793
        (1000, 1000)          22.527             129          176.51          22.517

```

After:
```
          Column sum          Caffe2             NNC          Simple          Better
           (10, 100)          6.0458          9.4966          7.1094           6.056
          (100, 100)          5.9299          9.1482          7.1693           6.593
        (100, 10000)          21.739          121.97          162.63          14.376
        (1000, 1000)          9.2374           29.01          26.883          10.127

             Row sum          Caffe2             NNC          Simple          Better
           (10, 100)          5.9773          8.1792          7.2307          5.8941
          (100, 100)          6.1456          9.3155          24.563          5.8163
        (100, 10000)          25.384          30.212          88.531          27.185
        (1000, 1000)          26.517          32.702          209.31          26.537
```

Speedup about 3-8x depending on the size of the data (increasing with bigger inputs).

The gap between NNC and Simple is narrowed or eliminated; the remaining difference appears to be kernel launch overhead. Next up is getting us closer to the _Better_ kernel.

It required a lot of refactoring and bug fixes on the way:
* Refactored flattening of parallelized loops out of the CudaPrinter and into its own stage, so we can transform the graph in the stage between flattening and printing (where registerization occurs).
* Made AtomicAddFuser less pessimistic: it now recognizes that if an Add to a buffer depends on all used Block and Thread vars, then it has no overlap and does not need to be atomic. This allows registerization to apply to these stores.
* Fixed PrioritizeLoad mutator so that it does not attempt to separate the Store and Load to the same buffer (i.e. reduction case).
* Moved CudaAnalysis earlier in the process, allowing later stages to use the analyzed bufs.
* Fixed a bug in the Registerizer where when adding a default initializer statement it would use the dtype of the underlying var (which is always kHandle) instead of the dtype of the Buf.
* Fixed a bug in the IRMutator where the logic for Allocate statements was inverted, so that they were replaced only if they did not change.
* Added simplification of simple Division patterns to the IRSimplifier.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/42878

Reviewed By: glaringlee

Differential Revision: D23382499

Pulled By: nickgg

fbshipit-source-id: 3640a98fd843723abad9f54e67070d48c96fe949
2020-08-31 10:39:46 -07:00

768 lines
17 KiB
C++

#include "test/cpp/tensorexpr/test_base.h"
#include "test/cpp/tensorexpr/test_utils.h"
#include "torch/csrc/jit/tensorexpr/ir_simplifier.h"
#include "torch/csrc/jit/tensorexpr/registerizer.h"
#include <iostream>
namespace torch {
namespace jit {
using namespace torch::jit::tensorexpr;
// Can replace a simple scalar access with a local variable.
void testRegisterizerSimple() {
KernelScope kernel_scope;
Buffer a(BufHandle("A", {1}, kInt));
VarHandle x("x", kInt);
Stmt* stmt =
Block::make({Store::make(a, {0}, 0, 1),
For::make(
x,
0,
10,
Block::make({Store::make(
a, {0}, Add::make(Load::make(a, {0}, 1), x), 1)}))});
/*
* A[0] = 0;
* for (int x = 0; x < 10; x++) {
* A[0] = (A[0]) + x;
* }
*/
registerize(stmt);
/*
* int A_ = 0;
* for (int x = 0; x < 10; x++) {
* A_ = x + A_;
* }
* A[0] = A_;
*/
std::ostringstream oss;
oss << *stmt;
const std::string& verification_pattern =
R"IR(
# CHECK: int A_ = 0;
# CHECK: for (int x = 0; x < 10; x++)
# CHECK-NOT: A[
# CHECK: A_ =
# CHECK: A[0] = A_;)IR";
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
}
// Won't do replacement of a loop access.
void testRegisterizerLoop() {
KernelScope kernel_scope;
Buffer a(BufHandle("A", {10}, kInt));
VarHandle x("x", kInt);
Stmt* stmt =
Block::make({Store::make(a, {0}, 0, 1),
For::make(
x,
0,
10,
Block::make({Store::make(
a, {x}, Add::make(Load::make(a, {x}, 1), x), 1)}))});
/*
* A[0] = 0;
* for (int x = 0; x < 10; x++) {
* A[x] = (A[x]) + x;
* }
*/
// No change.
registerize(stmt);
/*
* A[0] = 0;
* for (int x = 0; x < 10; x++) {
* A[x] = (A[x]) + x;
* }
*/
std::ostringstream oss;
oss << *stmt;
const std::string& verification_pattern =
R"IR(
# CHECK-NOT: int
# CHECK: A[0] = 0;
# CHECK: for (int x = 0; x < 10; x++)
# CHECK-NOT: A_
# CHECK: A[x] =
# CHECK-NOT: A[0] = A_;)IR";
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
}
// Won't replace the load of A[0] even though its index is fixed, because the
// store to A[x] inside the same loop could overwrite A[0] (when x == 0) and
// invalidate a cached copy.
void testRegisterizerLoopFixedLoad() {
  KernelScope kernel_scope;
  Buffer a(BufHandle("A", {1}, kInt));
  VarHandle x("x", kInt);
  Stmt* stmt =
      Block::make({Store::make(a, {0}, 0, 1),
                   For::make(
                       x,
                       0,
                       10,
                       Block::make({Store::make(
                           a, {x}, Add::make(Load::make(a, {0}, 1), x), 1)}))});
  // Note: the store target is A[x] while the load is A[0]; the IR comments
  // below reflect that.
  /*
   * A[0] = 0;
   * for (int x = 0; x < 10; x++) {
   *   A[x] = (A[0]) + x;
   * }
   */
  // No change.
  registerize(stmt);
  /*
   * A[0] = 0;
   * for (int x = 0; x < 10; x++) {
   *   A[x] = (A[0]) + x;
   * }
   */
  std::ostringstream oss;
  oss << *stmt;
  const std::string& verification_pattern =
      R"IR(
# CHECK-NOT: int
# CHECK: A[0] = 0;
# CHECK: for (int x = 0; x < 10; x++)
# CHECK-NOT: A_
# CHECK: A[x] =
# CHECK-NOT: A[0] = A_;)IR";
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
}
// Will registerize multiple accesses to different fixed elements of the same
// buffer, giving each element its own uniquely named scalar.
void testRegisterizerMultiVar() {
  KernelScope kernel_scope;
  Buffer a(BufHandle("A", {2}, kInt));
  VarHandle x("x", kInt);
  Stmt* stmt = Block::make({
      Store::make(a, {0}, 0, 1),
      Store::make(a, {1}, 0, 1),
      For::make(
          x,
          0,
          10,
          Block::make(
              {Store::make(a, {0}, Add::make(Load::make(a, {0}, 1), x), 1),
               Store::make(a, {1}, Sub::make(Load::make(a, {1}, 1), x), 1)})),
  });
  /*
   * A[0] = 0;
   * A[1] = 0;
   * for (int x = 0; x < 10; x++) {
   *   A[0] = (A[0]) + x;
   *   A[1] = (A[1]) - x;
   * }
   */
  registerize(stmt);
  // A_ holds A[0] and A__1 holds A[1]; the loop body keeps the original
  // statement order, while the writebacks come out in reverse order.
  /*
   * int A_ = 0;
   * int A__1 = 0;
   * for (int x = 0; x < 10; x++) {
   *   A_ = x + A_;
   *   A__1 = A__1 - x;
   * }
   * A[1] = A__1;
   * A[0] = A_;
   */
  std::ostringstream oss;
  oss << *stmt;
  const std::string& verification_pattern =
      R"IR(
# CHECK: int A_ = 0;
# CHECK: int A__1 = 0;
# CHECK: for (int x = 0; x < 10; x++)
# CHECK-NOT: A[
# CHECK: A_ =
# CHECK: A__1 =
# CHECK: A[1] = A__1
# CHECK: A[0] = A_;)IR";
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
}
// Will registerize the valid accesses while skipping invalid replacements.
void testRegisterizerVariableLoad() {
KernelScope kernel_scope;
Buffer a(BufHandle("A", {1}, kInt));
Buffer b(BufHandle("B", {10}, kInt));
VarHandle x("x", kInt);
VarHandle x2("x", kInt);
Stmt* stmt = Block::make(
{Store::make(a, {0}, 0, 1),
For::make(x, 0, 10, Store::make(b, {x}, x, 1)),
For::make(
x2,
0,
10,
Block::make({Store::make(
a,
{0},
Add::make(Load::make(a, {0}, 1), Load::make(b, {x2}, 1)),
1)}))});
/*
* A[0] = 0;
* for (int x = 0; x < 10; x++) {
* B[x] = x;
* }
* for (int x_1 = 0; x_1 < 10; x_1++) {
* A[0] = (A[0]) + (B[x_1]);
* }
*/
registerize(stmt);
/*
* int A_ = 0;
* for (int x = 0; x < 10; x++) {
* B[x] = x;
* }
* for (int x_1 = 0; x_1 < 10; x_1++) {
* A_ = A_ + (B[x_1]);
* }
* A[0] = A_;
*/
std::ostringstream oss;
oss << *stmt;
const std::string& verification_pattern =
R"IR(
# CHECK: int A_ = 0;
# CHECK: for (int x = 0; x < 10; x++)
# CHECK: B[x] = x
# CHECK: for (int x_1 = 0; x_1 < 10; x_1++)
# CHECK-NOT: A[
# CHECK: A_ =
# CHECK: A[0] = A_;)IR";
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
}
// Can registerize variable accesses so long as the variable does not change.
void testRegisterizerSymbolicIndices() {
KernelScope kernel_scope;
VarHandle i("i", kInt);
VarHandle N("N", kInt);
Buffer a(BufHandle("A", {N}, kInt));
VarHandle x("x", kInt);
Stmt* stmt =
Block::make({Store::make(a, {i}, 0, 1),
For::make(
x,
0,
10,
Block::make({Store::make(
a, {i}, Add::make(Load::make(a, {i}, 1), x), 1)}))});
/*
* A[i] = 0;
* for (int x = 0; x < 10; x++) {
* A[i] = (A[i]) + x;
* }
*/
registerize(stmt);
/*
* int A_ = 0;
* for (int x = 0; x < 10; x++) {
* A_ = x + A_;
* }
* A[i] = A_;
*/
std::ostringstream oss;
oss << *stmt;
const std::string& verification_pattern =
R"IR(
# CHECK: int A_ = 0;
# CHECK: for (int x = 0; x < 10; x++)
# CHECK-NOT: A[
# CHECK: A_ =
# CHECK: A[i] = A_;)IR";
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
}
// Will not registerize if a variable usage of the sclar may overlap the target
// scalar.
// TODO: we can support this by writing back to the buffer before the variable
// access, but we'd need temporal analysis of dependencies which we don't have
// yet. Will have to fix soon though.
void testRegisterizerEarlyStop() {
KernelScope kernel_scope;
Buffer a(BufHandle("A", {1}, kInt));
VarHandle x("x", kInt);
Stmt* stmt = Block::make(
{Store::make(a, {0}, 0, 1),
For::make(
x,
0,
10,
Block::make(
{Store::make(a, {0}, Add::make(Load::make(a, {0}, 1), x), 1)})),
For::make(x, 1, 10, Store::make(a, {x}, Load::make(a, {x - 1}, 1), 1))});
std::ostringstream before;
before << *stmt;
// No change.
registerize(stmt);
std::ostringstream after;
after << *stmt;
ASSERT_EQ(before.str(), after.str());
}
// Can registerize a fixed-index access whose updated value depends on several
// nested loop variables.
void testRegisterizerMultiLoop() {
  KernelScope kernel_scope;
  Buffer a(BufHandle("A", {1}, kInt));
  VarHandle x("x", kInt);
  VarHandle y("y", kInt);
  Stmt* stmt = Block::make(
      {Store::make(a, {0}, 0, 1),
       For::make(
           x,
           0,
           10,
           For::make(
               y,
               0,
               10,
               Block::make({Store::make(
                   a,
                   {0},
                   Mul::make(Add::make(Load::make(a, {0}, 1), x), y),
                   1)})))});
  // The input below is the statement as built above, (A[0] + x) * y, not a
  // simplified form of it.
  /*
   * A[0] = 0;
   * for (int x = 0; x < 10; x++) {
   *   for (int y = 0; y < 10; y++) {
   *     A[0] = ((A[0]) + x) * y;
   *   }
   * }
   */
  registerize(stmt);
  /*
   * int A_ = 0;
   * for (int x = 0; x < 10; x++) {
   *   for (int y = 0; y < 10; y++) {
   *     A_ = x * y + y * A_;
   *   }
   * }
   * A[0] = A_;
   */
  std::ostringstream oss;
  oss << *stmt;
  const std::string& verification_pattern =
      R"IR(
# CHECK: int A_ = 0;
# CHECK: for (int x = 0; x < 10; x++)
# CHECK: for (int y = 0; y < 10; y++)
# CHECK-NOT: A[
# CHECK: A_ =
# CHECK: A[0] = A_;)IR";
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
}
// Can registerize correctly if scalars already exist in the program.
void testRegisterizerRepeated() {
KernelScope kernel_scope;
Buffer a(BufHandle("A", {2}, kInt));
VarHandle x("x", kInt);
Stmt* stmt = Block::make({
Store::make(a, {0}, 0, 1),
Store::make(a, {1}, 0, 1),
For::make(
x,
0,
10,
Block::make(
{Store::make(a, {0}, Add::make(Load::make(a, {0}, 1), x), 1),
Store::make(a, {1}, Sub::make(Load::make(a, {1}, 1), x), 1)})),
});
// Registerize manually to make sure we only replace a single target.
{
RegisterizerAnalysis analysis;
stmt->accept(&analysis);
auto candidates = analysis.getCandidates();
ASSERT_EQ(candidates.size(), 2);
RegisterizerReplacer replacer(candidates.front());
stmt = stmt->accept_mutator(&replacer);
}
// Re-analyze and replace the second target.
{
RegisterizerAnalysis analysis;
stmt->accept(&analysis);
auto candidates = analysis.getCandidates();
ASSERT_EQ(candidates.size(), 1);
RegisterizerReplacer replacer(candidates.front());
stmt = stmt->accept_mutator(&replacer);
}
std::ostringstream oss;
oss << *stmt;
const std::string& verification_pattern =
R"IR(
# CHECK: int A_ = 0;
# CHECK: int A__1 = 0;
# CHECK: for (int x = 0; x < 10; x++)
# CHECK-NOT: A[
# CHECK: A_ =
# CHECK: A__1 =
# CHECK: A[1] = A__1
# CHECK: A[0] = A_;)IR";
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
}
// Can registerize an element that is only stored to inside the loop and never
// loaded: repeated stores to A[0] are redirected to a scalar and written back
// once after the loop.
void testRegisterizerNoLoads() {
  KernelScope kernel_scope;
  Buffer a(BufHandle("A", {1}, kInt));
  VarHandle x("x", kInt);
  Stmt* stmt = Block::make(
      {Store::make(a, {0}, 0, 1),
       For::make(
           x, 0, 10, Block::make({Store::make(a, {0}, Add::make(x, 1), 1)}))});
  /*
   * A[0] = 0;
   * for (int x = 0; x < 10; x++) {
   *   A[0] = x + 1;
   * }
   */
  registerize(stmt);
  /*
   * int A_ = 0;
   * for (int x = 0; x < 10; x++) {
   *   A_ = x + 1;
   * }
   * A[0] = A_;
   */
  std::ostringstream oss;
  oss << *stmt;
  const std::string& verification_pattern =
      R"IR(
# CHECK: int A_ = 0;
# CHECK: for (int x = 0; x < 10; x++)
# CHECK-NOT: A[
# CHECK: A_ =
# CHECK: A[0] = A_;)IR";
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
}
// Can registerize the load of A but not the store of B.
void testRegisterizerNoRepeatedStores() {
KernelScope kernel_scope;
Buffer a(BufHandle("A", {1}, kInt));
Buffer b(BufHandle("B", {10}, kInt));
VarHandle x("x", kInt);
Stmt* stmt =
Block::make({Store::make(a, {0}, 0, 1),
For::make(
x,
0,
10,
Block::make({Store::make(
b, {x}, Add::make(Load::make(a, {0}, 1), x), 1)}))});
/*
* A[0] = 0;
* for (int x = 0; x < 10; x++) {
* B[x] = (A[0]) + x;
* }
*/
registerize(stmt);
// TODO: its unnecessary to reorder the initializer of A[0], but it's not
// actually worse so lets not worry for now.
/*
* int A_ = 0;
* for (int x = 0; x < 10; x++) {
* B[x] = x + A_;
* }
* A[0] = A_;
*/
std::ostringstream oss;
oss << *stmt;
const std::string& verification_pattern =
R"IR(
# CHECK: int A_ = 0;
# CHECK: for (int x = 0; x < 10; x++)
# CHECK-NOT: A_
# CHECK: B[x] =
# CHECK: A[0] = A_;)IR";
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
}
// Won't registerize if there are multiple accesses which may overlap.
void testRegisterizerMultiVarOverlap() {
KernelScope kernel_scope;
Buffer a(BufHandle("A", {2}, kInt));
VarHandle x("x", kInt);
Stmt* stmt = Block::make({
Store::make(a, {0}, 0, 1),
Store::make(a, {1}, 0, 1),
For::make(
x,
0,
10,
Block::make(
{Store::make(a, {x}, Add::make(Load::make(a, {0}, 1), x), 1),
Store::make(
a, {x + 1}, Sub::make(Load::make(a, {1}, 1), x), 1)})),
});
std::ostringstream before;
before << *stmt;
// No change.
registerize(stmt);
std::ostringstream after;
after << *stmt;
ASSERT_EQ(before.str(), after.str());
}
void testRegisterizerAllocs() {
KernelScope kernel_scope;
Buffer a(BufHandle("A", {2}, kInt));
Buffer b(BufHandle("B", {1}, kInt));
Buffer c(BufHandle("C", {1}, kInt));
VarHandle x("x", kInt);
VarHandle b_(b.data()->base_handle());
Stmt* stmt = Block::make(
{Allocate::make(b_, kInt, {Load::make(c, {0}, 1)}),
Store::make(a, {0}, Load::make(c, {0}, 1), 1),
Store::make(b, {0}, 0, 1),
For::make(
x,
0,
10,
Block::make(
{Store::make(b, {0}, Add::make(Load::make(b, {0}, 1), x), 1),
Store::make(a, {0}, Load::make(c, {0}, 1), 1)})),
Free::make(b_)});
/*
* Allocate(B, int, {C[0]});
* A[0] = C[0];
* B[0] = 0;
* for (int x = 0; x < 10; x++) {
* B[0] = (B[0]) + x;
* A[0] = C[0];
* }
* Free(B);
*/
registerize(stmt);
/*
* int C_ = C[0];
* Allocate(B, int, {C_});
* int A_ = C_;
* int B_ = 0;
* for (int x = 0; x < 10; x++) {
* B_ = B_ + x;
* A_ = C_;
* }
* B[0] = B_;
* A[0] = A_;
* Free(B);
*/
std::ostringstream oss;
oss << *stmt;
const std::string& verification_pattern =
R"IR(
# CHECK: int C_ = C[0];
# CHECK: Allocate(B
# CHECK: int A_ = C_;
# CHECK: int B_ = 0;
# CHECK: for (int x = 0; x < 10; x++)
# CHECK: B_ =
# CHECK: A_ = C_
# CHECK: B[0] = B_;
# CHECK: A[0] = A_;
# CHECK: Free(B)IR";
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
}
void testRegisterizerNoInitializer() {
KernelScope kernel_scope;
Buffer a(BufHandle("A", {1}, kInt));
VarHandle x("x", kInt);
Stmt* stmt = Block::make({For::make(
x,
0,
10,
Block::make(
{Store::make(a, {0}, Add::make(Load::make(a, {0}, 1), x), 1)}))});
/*
* for (int x = 0; x < 10; x++) {
* A[0] = (A[0]) + x;
* }
*/
registerize(stmt);
/*
* int A_ = A[0];
* for (int x = 0; x < 10; x++) {
* A_ = x + A_;
* }
* A[0] = A_;
*/
std::ostringstream oss;
oss << *stmt;
const std::string& verification_pattern =
R"IR(
# CHECK: int A_ = A[0];
# CHECK: for (int x = 0; x < 10; x++)
# CHECK-NOT: A[
# CHECK: A_ =
# CHECK: A[0] = A_;)IR";
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
}
void testRegisterizerLoadThenStore() {
KernelScope kernel_scope;
Buffer a(BufHandle("A", {1}, kInt));
Buffer b(BufHandle("B", {1}, kInt));
VarHandle x("x", kInt);
Stmt* stmt = Block::make({For::make(
x,
0,
10,
Block::make({Store::make(b, {0}, Add::make(Load::make(a, {0}, 1), x), 1),
Store::make(a, {0}, Load::make(b, {0}, 1), 1)}))});
/*
* for (int x = 0; x < 10; x++) {
* B[0] = (A[0]) + x;
* A[0] = B[0];
* }
*/
registerize(stmt);
/*
* int A_ = A[0];
* int B_ = B[0];
* for (int x = 0; x < 10; x++) {
* B_ = x + A_;
* A_ = B_;
* }
* B[0] = B_;
* A[0] = A_;
*/
std::ostringstream oss;
oss << *stmt;
const std::string& verification_pattern =
R"IR(
# CHECK: int A_ = A[0];
# CHECK: int B_ = B[0];
# CHECK: for (int x = 0; x < 10; x++)
# CHECK-NOT: B[
# CHECK: B_ =
# CHECK-NOT: A[
# CHECK: A_ = B_
# CHECK: B[0] = B_
# CHECK: A[0] = A_;)IR";
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
}
void testRegisterizerParallelized() {
KernelScope kernel_scope;
Buffer a(BufHandle("A", {1}, kInt));
VarHandle x("x", kInt);
LoopOptions loopOpts;
loopOpts.set_gpu_block_index(0);
Stmt* stmt =
Block::make({Store::make(a, {0}, 0, 1),
For::make(
x,
0,
10,
Block::make({Store::make(
a, {0}, Add::make(Load::make(a, {0}, 1), x), 1)}),
loopOpts)});
/*
* A[0] = 0;
* for (int x = 0; x < 10; x++) {
* A[0] = (A[0]) + x;
* }
*/
ASSERT_THROWS_WITH(
registerize(stmt),
"Registerization must occur after parallelism flattening");
}
} // namespace jit
} // namespace torch